模型训练 #

训练流程 #

TensorFlow 提供了多种训练模型的方式,从简单的 fit 方法到完全自定义的训练循环。

text
┌─────────────────────────────────────────────────────────────┐
│                    训练方式选择                              │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  model.fit()                                                │
│  ├── 最简单                                                 │
│  ├── 内置训练循环                                           │
│  └── 适合大多数场景                                         │
│                                                             │
│  自定义 train_step                                          │
│  ├── 中等复杂度                                             │
│  ├── 重写训练步骤                                           │
│  └── 适合特殊训练逻辑                                       │
│                                                             │
│  完全自定义训练循环                                         │
│  ├── 最灵活                                                 │
│  ├── 使用 GradientTape                                      │
│  └── 适合研究/复杂训练                                      │
│                                                             │
└─────────────────────────────────────────────────────────────┘

fit 方法 #

基本训练 #

python
import tensorflow as tf
import numpy as np

# A small fully connected classifier: 784 inputs -> 10-way softmax.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
])

# 'sparse_categorical_crossentropy' expects integer class labels
# (not one-hot vectors), matching y_train below.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Generate random example data: 1000 samples, 784 features each.
x_train = np.random.random((1000, 784)).astype(np.float32)
y_train = np.random.randint(10, size=(1000,))

# Train; fit() returns a History object holding per-epoch metrics.
history = model.fit(
    x_train, y_train,
    batch_size=32,
    epochs=10,
    verbose=1
)

使用验证数据 #

python
import tensorflow as tf
import numpy as np  # needed: this snippet builds data with np.random

# Hold-out validation set with the same feature shape as the training data.
x_val = np.random.random((200, 784)).astype(np.float32)
y_val = np.random.randint(10, size=(200,))

# Option 1: pass an explicit validation set; fit() evaluates it each epoch.
history = model.fit(
    x_train, y_train,
    batch_size=32,
    epochs=10,
    validation_data=(x_val, y_val),
    verbose=1
)

# Option 2: let fit() carve the validation set out of the training data
# (the last 20% of x_train/y_train, taken before any shuffling).
history = model.fit(
    x_train, y_train,
    batch_size=32,
    epochs=10,
    validation_split=0.2,
    verbose=1
)

使用 tf.data #

python
import tensorflow as tf

# Input pipeline: shuffle (buffer of 1000), batch into 32s, and prefetch
# so data preparation overlaps with training on the accelerator.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(1000).batch(32).prefetch(tf.data.AUTOTUNE)

# Validation pipeline: batching only — no shuffling needed for evaluation.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(32)

# fit() takes the batch size from the dataset, so batch_size is not passed.
history = model.fit(
    train_dataset,
    epochs=10,
    validation_data=val_dataset
)

训练历史 #

python
import tensorflow as tf
import matplotlib.pyplot as plt

history = model.fit(x_train, y_train, epochs=10, validation_split=0.2)

# History.history maps metric names to per-epoch value lists;
# validation metrics are prefixed with 'val_'.
print(f"历史键: {history.history.keys()}")
print(f"训练损失: {history.history['loss']}")
print(f"验证损失: {history.history['val_loss']}")

# Plot loss and accuracy curves side by side and save to disk.
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Val Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(history.history['accuracy'], label='Train Acc')
plt.plot(history.history['val_accuracy'], label='Val Acc')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()

plt.tight_layout()
plt.savefig('training_history.png')

回调函数 #

常用回调 #

python
import tensorflow as tf

callbacks = [
    # Early stopping: halt when val_loss has not improved for 10 epochs,
    # then roll back to the best weights seen so far.
    tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=10,
        restore_best_weights=True,
        verbose=1
    ),
    
    # Learning-rate decay: halve the LR after 5 stagnant epochs,
    # never dropping below 1e-7.
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=5,
        min_lr=1e-7,
        verbose=1
    ),
    
    # Checkpoint: keep only the model with the highest validation
    # accuracy ('max' mode because higher accuracy is better).
    tf.keras.callbacks.ModelCheckpoint(
        filepath='best_model.keras',
        monitor='val_accuracy',
        save_best_only=True,
        mode='max',
        verbose=1
    ),
    
    # TensorBoard logging: scalars, weight histograms, graph, images.
    tf.keras.callbacks.TensorBoard(
        log_dir='./logs',
        histogram_freq=1,
        write_graph=True,
        write_images=True,
        update_freq='epoch'
    ),
    
    # CSV log of per-epoch metrics (overwritten each run: append=False).
    tf.keras.callbacks.CSVLogger(
        'training_log.csv',
        separator=',',
        append=False
    ),
    
    # Abort training immediately if the loss becomes NaN.
    tf.keras.callbacks.TerminateOnNaN()
]

history = model.fit(
    x_train, y_train,
    epochs=100,
    validation_split=0.2,
    callbacks=callbacks
)

自定义回调 #

python
import tensorflow as tf

class CustomCallback(tf.keras.callbacks.Callback):
    """Tracks the best validation loss and restores those weights at the end."""

    def __init__(self):
        super().__init__()
        self.best_weights = None
        # Initialize here as well so the callback is safe even if
        # on_train_begin is never invoked.
        self.best_val_loss = float('inf')
    
    def on_train_begin(self, logs=None):
        print("训练开始!")
        self.best_val_loss = float('inf')
    
    def on_epoch_end(self, epoch, logs=None):
        # 'val_loss' is absent when fit() runs without validation data;
        # guard against None to avoid a TypeError on the comparison.
        val_loss = (logs or {}).get('val_loss')
        if val_loss is not None and val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            self.best_weights = self.model.get_weights()
            print(f"\nEpoch {epoch+1}: val_loss 改进到 {val_loss:.4f}")
    
    def on_train_end(self, logs=None):
        print(f"\n训练结束! 最佳 val_loss: {self.best_val_loss:.4f}")
        # Explicit None check: a list of weight arrays is the stored value.
        if self.best_weights is not None:
            self.model.set_weights(self.best_weights)

# Train with the callback; the best weights are restored when training
# ends (handled by CustomCallback.on_train_end).
history = model.fit(
    x_train, y_train,
    epochs=10,
    validation_split=0.2,
    callbacks=[CustomCallback()]
)

学习率调度回调 #

python
import tensorflow as tf

# Learning-rate warm-up followed by exponential decay: linear ramp for the
# first `warmup_epochs` epochs, then the rate halves every
# (total_epochs - warmup_epochs) / 3 epochs.
initial_learning_rate = 0.001
warmup_epochs = 5
total_epochs = 50

def lr_schedule(epoch):
    """Return the learning rate for the given (0-based) epoch index."""
    if epoch < warmup_epochs:
        # Ramp linearly from lr/warmup_epochs up to the full initial rate.
        return initial_learning_rate * (epoch + 1) / warmup_epochs
    # Past warm-up: halve the rate every third of the remaining schedule.
    half_life = (total_epochs - warmup_epochs) / 3
    steps_past_warmup = epoch - warmup_epochs
    return initial_learning_rate * 0.5 ** (steps_past_warmup / half_life)

# LearningRateScheduler calls lr_schedule at the start of every epoch and
# applies the returned rate to the optimizer.
lr_scheduler = tf.keras.callbacks.LearningRateScheduler(lr_schedule, verbose=1)

history = model.fit(
    x_train, y_train,
    epochs=total_epochs,
    callbacks=[lr_scheduler]
)

自定义 train_step #

重写训练步骤 #

python
import tensorflow as tf

class CustomModel(tf.keras.Sequential):
    """Sequential model with a customized train/test step.

    Subclasses tf.keras.Sequential (rather than tf.keras.Model) so it can be
    constructed from a plain list of layers, as done in the usage below;
    tf.keras.Model's constructor expects functional-API inputs/outputs and
    does not accept a layer list.
    """

    def train_step(self, data):
        x, y = data
        
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            # Include layer regularization losses in the training objective.
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        
        # Standard gradient step on all trainable weights.
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        
        self.compiled_metrics.update_state(y, y_pred)
        
        # fit() displays whatever mapping is returned here.
        return {m.name: m.result() for m in self.metrics}
    
    def test_step(self, data):
        x, y = data
        y_pred = self(x, training=False)
        # Updates the loss tracker; evaluation never applies gradients.
        self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

# Build the custom model from a list of layers, then use the standard
# compile/fit workflow — only train_step/test_step are overridden.
model = CustomModel([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

model.fit(x_train, y_train, epochs=10)

带梯度的训练步骤 #

python
import tensorflow as tf

class GradientLoggingModel(tf.keras.Model):
    """Model whose train_step clips gradients and reports their norm."""

    def train_step(self, data):
        x, y = data
        
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred)
        
        gradients = tape.gradient(loss, self.trainable_variables)
        
        # Per-tensor clipping: each gradient is clipped to norm <= 1.0
        # independently (note: this is not global-norm clipping).
        gradients = [tf.clip_by_norm(g, 1.0) for g in gradients]
        
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        
        # Global L2 norm of the *clipped* gradients, reported per step.
        grad_norm = tf.sqrt(sum([tf.reduce_sum(tf.square(g)) for g in gradients]))
        
        self.compiled_metrics.update_state(y, y_pred)
        
        # Extend the standard metric dict with the logged gradient norm.
        results = {m.name: m.result() for m in self.metrics}
        results['gradient_norm'] = grad_norm
        return results

完全自定义训练循环 #

基本训练循环 #

python
import tensorflow as tf

# Two-layer MLP emitting raw logits (no softmax on the final layer).
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10)
])

optimizer = tf.keras.optimizers.Adam()
# from_logits=True because the model outputs unnormalized scores.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
val_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()

batch_size = 32
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(1000).batch(batch_size)

epochs = 10
for epoch in range(epochs):
    print(f"\nEpoch {epoch + 1}/{epochs}")
    
    for step, (features, labels) in enumerate(train_dataset):
        # Record the forward pass so gradients can be taken afterwards.
        with tf.GradientTape() as tape:
            logits = model(features, training=True)
            loss = loss_fn(labels, logits)
        
        # Backward pass + optimizer update.
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        
        # Accumulate running accuracy over the whole epoch.
        train_acc_metric.update_state(labels, logits)
        
        if step % 100 == 0:
            print(f"Step {step}: loss = {loss:.4f}")
    
    train_acc = train_acc_metric.result()
    print(f"Training accuracy: {train_acc:.4f}")
    # Clear the metric so the next epoch starts fresh.
    train_acc_metric.reset_states()

带验证的训练循环 #

python
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10)
])

optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
# from_logits=True: the final Dense layer has no softmax.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Separate metric instances so train/val accuracies accumulate independently.
train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
val_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(1000).batch(32)

val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(32)

# @tf.function compiles the step into a graph for faster execution.
@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss = loss_fn(y, logits)
        # Add any layer regularization losses to the objective.
        loss += sum(model.losses)
    
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    
    train_acc_metric.update_state(y, logits)
    return loss

@tf.function
def val_step(x, y):
    # Forward pass only; no gradients are computed during validation.
    val_logits = model(x, training=False)
    val_acc_metric.update_state(y, val_logits)

epochs = 10
for epoch in range(epochs):
    print(f"\nEpoch {epoch + 1}/{epochs}")
    
    for x_batch, y_batch in train_dataset:
        loss = train_step(x_batch, y_batch)
    
    for x_batch, y_batch in val_dataset:
        val_step(x_batch, y_batch)
    
    train_acc = train_acc_metric.result()
    val_acc = val_acc_metric.result()
    
    print(f"Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}")
    
    # Reset so each epoch's accuracy reflects only that epoch.
    train_acc_metric.reset_states()
    val_acc_metric.reset_states()

带进度条的训练循环 #

python
import tensorflow as tf
from tqdm import tqdm

model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10)
])

optimizer = tf.keras.optimizers.Adam()
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(1000).batch(32)

# Batches per epoch: ceiling division, because .batch(32) emits a final
# partial batch when len(x_train) is not a multiple of 32.
# (The original `len(x_train) // 32` undercounted by one in that case.)
steps_per_epoch = (len(x_train) + 31) // 32

epochs = 10
for epoch in range(epochs):
    print(f"\nEpoch {epoch + 1}/{epochs}")
    
    pbar = tqdm(train_dataset, total=steps_per_epoch)
    epoch_loss = 0.0
    
    for step, (x_batch, y_batch) in enumerate(pbar):
        with tf.GradientTape() as tape:
            logits = model(x_batch, training=True)
            loss = loss_fn(y_batch, logits)
        
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        
        # Convert to a Python float so accumulation happens outside the graph.
        epoch_loss += loss.numpy()
        pbar.set_postfix({'loss': f'{loss:.4f}'})
    
    print(f"Epoch {epoch + 1} Average Loss: {epoch_loss / (step + 1):.4f}")

下一步 #

现在你已经掌握了模型训练,接下来学习 数据管道,了解如何高效处理训练数据!

最后更新:2026-04-04