回调函数 #

什么是回调函数? #

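回调函数是传给 model.fit(callbacks=[...]) 的对象,Keras 会在训练过程中的特定时间点自动调用它们,用于监控和干预训练过程(例如提前停止、保存模型、调整学习率)。epoch 级方法在每个 epoch 各调用一次,batch 级方法在每个 batch 各调用一次,因此下图中的 batch 与 epoch 两层实际上是循环执行的。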

text
┌─────────────────────────────────────────────────────────────┐
│                    回调函数调用时机                          │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  训练开始 ──► on_train_begin                                │
│       │                                                     │
│       ▼                                                     │
│  Epoch 开始 ──► on_epoch_begin                              │
│       │                                                     │
│       ▼                                                     │
│  Batch 开始 ──► on_train_batch_begin                        │
│       │                                                     │
│       ▼                                                     │
│  Batch 结束 ──► on_train_batch_end                          │
│       │                                                     │
│       ▼                                                     │
│  Epoch 结束 ──► on_epoch_end                                │
│       │                                                     │
│       ▼                                                     │
│  训练结束 ──► on_train_end                                  │
│                                                             │
└─────────────────────────────────────────────────────────────┘

EarlyStopping #

基本用法 #

python
import keras

# Stop once val_loss has gone 5 epochs without improving, and roll the model
# back to the weights from its best epoch.
stop_early = keras.callbacks.EarlyStopping(
    monitor='val_loss',
    restore_best_weights=True,
    patience=5,
)

# The callback watches 'val_loss', so validation data must be supplied.
model.fit(
    x_train,
    y_train,
    epochs=100,
    validation_data=(x_val, y_val),
    callbacks=[stop_early],
)

参数详解 #

python
# EarlyStopping with every argument spelled out at its default value.
keras.callbacks.EarlyStopping(
    monitor='val_loss',          # name of the logged metric to watch
    min_delta=0,                 # changes smaller than this do not count as improvement
    patience=0,                  # epochs without improvement before training stops
    verbose=0,                   # 1 prints a message when the callback acts
    mode='auto',                 # 'min', 'max' or 'auto' (infer direction)
    baseline=None,               # optional value the metric has to beat
    restore_best_weights=False,  # True: reload the weights of the best epoch
    start_from_epoch=0           # warm-up epochs before monitoring begins
)
text
┌─────────────────────────────────────────────────────────────┐
│                    EarlyStopping 参数                        │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  monitor: 监控指标                                          │
│  ├── 'val_loss': 验证损失                                  │
│  └── 'val_accuracy': 验证准确率                            │
│                                                             │
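│  patience: 容忍轮数                                         │
│  └── 指标连续多少轮没有改善后停止训练                       │
│                                                             │
│  min_delta: 最小改善量                                      │
│  └── 小于此值不算改善                                      │
│                                                             │
│  mode: 模式                                                 │
│  ├── 'auto': 根据监控指标自动推断方向                       │
│  ├── 'min': 指标越小越好                                   │
│  └── 'max': 指标越大越好                                   │
│                                                             │
│  baseline: 基线值                                           │
│  └── 若设置,指标须优于此值才算改善                          │
│                                                             │
│  restore_best_weights: 是否恢复最佳权重                     │
│  ├── True: 恢复监控指标最佳那一轮的权重                     │
│  └── False: 保留最后一轮的权重                              │
│                                                             │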
└─────────────────────────────────────────────────────────────┘

ModelCheckpoint #

基本用法 #

python
import keras

# Write a file only when val_loss reaches a new minimum; the epoch number and
# the val_loss value are formatted into the filename.
best_ckpt = keras.callbacks.ModelCheckpoint(
    filepath='model_{epoch:02d}_{val_loss:.4f}.keras',
    mode='min',
    monitor='val_loss',
    save_best_only=True,
)

model.fit(
    x_train,
    y_train,
    epochs=100,
    validation_data=(x_val, y_val),
    callbacks=[best_ckpt],
)

参数详解 #

python
# ModelCheckpoint with its arguments listed in signature order.
# filepath has no default; 'model.keras' is only an example value.
keras.callbacks.ModelCheckpoint(
    filepath='model.keras',          # target path; may contain {epoch} / metric placeholders
    monitor='val_loss',              # metric compared when save_best_only=True
    verbose=0,                       # 1 prints a message on each save
    save_best_only=False,            # True: save only when the monitored metric improves
    save_weights_only=False,         # True: save weights only (path must end in '.weights.h5')
    mode='auto',                     # 'min', 'max' or 'auto' direction for `monitor`
    save_freq='epoch',               # 'epoch' or an integer number of batches
    initial_value_threshold=None     # optional starting "best" value that must be beaten
)

保存权重 #

python
import keras

# Weights-only checkpoints: Keras 3 expects the path to end in '.weights.h5'.
# save_best_only is left off, so one file per epoch is written.
weights_ckpt = keras.callbacks.ModelCheckpoint(
    filepath='weights_{epoch:02d}.weights.h5',
    save_freq='epoch',
    save_weights_only=True,
)

ReduceLROnPlateau #

python
import keras

# Halve the learning rate whenever val_loss stalls for 3 epochs, never going
# below 1e-6; verbose=1 reports each reduction.
plateau_lr = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    patience=3,
    factor=0.5,
    min_lr=1e-6,
    verbose=1,
)

model.fit(
    x_train,
    y_train,
    epochs=100,
    validation_data=(x_val, y_val),
    callbacks=[plateau_lr],
)
text
┌─────────────────────────────────────────────────────────────┐
│                    ReduceLROnPlateau 参数                    │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  monitor: 监控指标                                          │
│                                                             │
│  factor: 学习率衰减因子                                     │
│  └── 新学习率 = 旧学习率 × factor                          │
│                                                             │
│  patience: 容忍轮数                                         │
│  └── 指标不改善多少轮后降低学习率                          │
│                                                             │
│  min_lr: 最小学习率                                         │
│  └── 学习率不会被降到此值以下                               │
│                                                             │
│  min_delta: 最小改善量                                      │
│  └── 变化量小于此值不算改善                                 │
│                                                             │
└─────────────────────────────────────────────────────────────┘

LearningRateScheduler #

基本用法 #

python
import keras

def lr_schedule(epoch, lr):
    """Step-decay schedule: divide the learning rate by 10 at epochs 10 and 20.

    ``LearningRateScheduler`` calls this at the start of every epoch with the
    optimizer's *current* learning rate. Multiplying ``lr`` on every epoch
    would therefore compound (x0.1 per epoch). Instead the rate is scaled only
    on the epoch where a new phase begins, which yields:

    * epochs 0-9:   initial rate
    * epochs 10-19: initial rate * 0.1
    * epochs 20+:   initial rate * 0.01

    This assumes the scheduler sees the boundary epochs. If training is resumed
    past epoch 10 or 20 with a freshly reset optimizer, the drop is not
    re-applied.

    Args:
        epoch: 0-based index of the epoch about to start.
        lr: learning rate currently set on the optimizer.

    Returns:
        The learning rate to use for this epoch.
    """
    if epoch in (10, 20):
        return lr * 0.1
    return lr

# verbose=1 makes the scheduler report the learning rate it sets each epoch.
lr_scheduler = keras.callbacks.LearningRateScheduler(schedule=lr_schedule, verbose=1)

model.fit(
    x_train,
    y_train,
    callbacks=[lr_scheduler],
    epochs=30,
)

余弦退火 #

python
import keras
import math

def cosine_annealing(epoch, lr, *, epochs_total=100, lr_min=1e-6, lr_max=0.1):
    """Cosine-annealing schedule for ``keras.callbacks.LearningRateScheduler``.

    The rate follows half a cosine period from ``lr_max`` at epoch 0 down to
    ``lr_min`` at ``epochs_total``::

        lr(t) = lr_min + (lr_max - lr_min) * (1 + cos(pi * t / epochs_total)) / 2

    Args:
        epoch: 0-based index of the epoch about to start.
        lr: current learning rate supplied by the scheduler. It is ignored,
            because the rate depends on ``epoch`` alone.
        epochs_total: epoch at which the schedule reaches ``lr_min``. It should
            normally equal the ``epochs`` passed to ``model.fit``.
        lr_min: final (minimum) learning rate.
        lr_max: initial (maximum) learning rate.

    Returns:
        The learning rate for ``epoch``. Epochs past ``epochs_total`` are
        clamped, so the rate stays at ``lr_min`` instead of climbing back up
        the cosine curve.

    Raises:
        ValueError: if ``epochs_total`` is not positive.
    """
    if epochs_total <= 0:
        raise ValueError(f'epochs_total must be positive, got {epochs_total}')
    t = min(epoch, epochs_total)
    return lr_min + (lr_max - lr_min) * (1 + math.cos(math.pi * t / epochs_total)) / 2

# Pass the schedule by keyword. Pair it with model.fit(..., epochs=100) so the
# run length matches the schedule's epochs_total.
lr_scheduler = keras.callbacks.LearningRateScheduler(schedule=cosine_annealing)

TensorBoard #

python
import keras

# Write TensorBoard logs under ./logs (view with `tensorboard --logdir ./logs`).
# histogram_freq=1 records weight histograms every epoch; profile_batch=2
# profiles batch 2.
tb_callback = keras.callbacks.TensorBoard(
    log_dir='./logs',
    update_freq='epoch',
    histogram_freq=1,
    write_graph=True,
    write_images=True,
    profile_batch=2,
)

model.fit(
    x_train,
    y_train,
    epochs=100,
    validation_data=(x_val, y_val),
    callbacks=[tb_callback],
)

CSVLogger #

python
import keras

# Stream each epoch's metrics as a row of training_log.csv; append=False
# overwrites an existing file instead of adding to it.
csv_logger = keras.callbacks.CSVLogger(
    filename='training_log.csv',
    append=False,
    separator=',',
)

model.fit(
    x_train,
    y_train,
    epochs=100,
    validation_data=(x_val, y_val),
    callbacks=[csv_logger],
)

LambdaCallback #

python
import keras

# LambdaCallback builds a callback from plain functions instead of a Callback
# subclass. Keras passes a 0-based epoch index, so 1 is added to match the
# Keras progress bar and the CustomCallback example.
lambda_callback = keras.callbacks.LambdaCallback(
    on_epoch_begin=lambda epoch, logs: print(f'Epoch {epoch + 1} 开始'),
    on_epoch_end=lambda epoch, logs: print(f'Epoch {epoch + 1} 结束, loss: {logs["loss"]:.4f}'),
    # Batch-level hooks use the same on_train_batch_* names as the lifecycle
    # diagram; these no-op placeholders show where per-batch logic would go.
    on_train_batch_begin=lambda batch, logs: None,
    on_train_batch_end=lambda batch, logs: None,
    on_train_begin=lambda logs: print('训练开始'),
    on_train_end=lambda logs: print('训练结束'),
)

model.fit(x_train, y_train, epochs=10, callbacks=[lambda_callback])

自定义回调函数 #

python
import keras

class CustomCallback(keras.callbacks.Callback):
    """Print progress messages from the train- and epoch-level hooks.

    Epoch numbers are printed 1-based (``epoch + 1``). Every hook declares
    ``logs=None``, so the epoch-end hook tolerates a missing ``logs`` dict or
    missing keys instead of raising.
    """

    def on_train_begin(self, logs=None):
        print('训练开始')

    def on_train_end(self, logs=None):
        print('训练结束')

    def on_epoch_begin(self, epoch, logs=None):
        print(f'\nEpoch {epoch + 1} 开始')

    def on_epoch_end(self, epoch, logs=None):
        # Guard against the hook being invoked without a logs dict.
        logs = logs or {}
        print(f'Epoch {epoch + 1} 结束')
        # 'loss' is normally present during model.fit; skip the line rather
        # than crash if it is not.
        if 'loss' in logs:
            print(f'训练损失: {logs["loss"]:.4f}')
        # 'val_loss' only appears when validation is run during fit.
        if 'val_loss' in logs:
            print(f'验证损失: {logs["val_loss"]:.4f}')

    def on_batch_begin(self, batch, logs=None):
        # Per-batch placeholder; Keras's default on_train_batch_begin
        # delegates to this older hook name.
        pass

    def on_batch_end(self, batch, logs=None):
        # Per-batch placeholder (counterpart of on_batch_begin).
        pass

printer = CustomCallback()

model.fit(
    x_train,
    y_train,
    epochs=10,
    validation_data=(x_val, y_val),
    callbacks=[printer],
)

组合使用多个回调 #

python
import keras

# A common combination:
# - ReduceLROnPlateau (patience=5) gets a chance to lower the learning rate
#   before EarlyStopping (patience=10) gives up.
# - The checkpoint keeps the best model on disk.
# - TensorBoard records the run.
stop_early = keras.callbacks.EarlyStopping(
    monitor='val_loss',
    restore_best_weights=True,
    patience=10,
)
save_best = keras.callbacks.ModelCheckpoint(
    filepath='best_model.keras',
    save_best_only=True,
    monitor='val_loss',
)
shrink_lr = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    patience=5,
    factor=0.5,
    min_lr=1e-6,
)
tb_logs = keras.callbacks.TensorBoard(log_dir='./logs')

# Callbacks are invoked in list order at each hook.
callbacks = [stop_early, save_best, shrink_lr, tb_logs]

model.fit(
    x_train,
    y_train,
    batch_size=32,
    epochs=100,
    validation_data=(x_val, y_val),
    callbacks=callbacks,
)

下一步 #

现在你已经掌握了回调函数,接下来学习 数据预处理,了解如何处理训练数据!

最后更新:2026-04-04