回调函数 #
什么是回调函数? #
回调函数在训练过程中的特定时间点被调用,用于监控和干预训练过程。
text
┌─────────────────────────────────────────────────────────────┐
│ 回调函数调用时机 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 训练开始 ──► on_train_begin │
│ │ │
│ ▼ │
│ Epoch 开始 ──► on_epoch_begin │
│ │ │
│ ▼ │
│ Batch 开始 ──► on_train_batch_begin │
│ │ │
│ ▼ │
│ Batch 结束 ──► on_train_batch_end │
│ │ │
│ ▼ │
│ Epoch 结束 ──► on_epoch_end │
│ │ │
│ ▼ │
│ 训练结束 ──► on_train_end │
│ │
└─────────────────────────────────────────────────────────────┘
EarlyStopping #
基本用法 #
python
import keras
# Stop training automatically once the monitored metric stops improving.
early_stopping = keras.callbacks.EarlyStopping(
# Watch the validation loss reported at the end of each epoch.
monitor='val_loss',
# Tolerate up to 5 epochs without improvement before stopping.
patience=5,
# On stop, roll the model back to the weights from the best epoch.
restore_best_weights=True
)
# NOTE(review): `model`, `x_train`, `y_train`, `x_val`, `y_val` are assumed to
# be defined earlier in the tutorial.
model.fit(
x_train, y_train,
validation_data=(x_val, y_val),
epochs=100,
callbacks=[early_stopping]
)
参数详解 #
python
# Full EarlyStopping signature with its default values.
keras.callbacks.EarlyStopping(
# Metric to monitor, e.g. 'val_loss' or 'val_accuracy'.
monitor='val_loss',
# Minimum change that counts as an improvement; smaller changes are ignored.
min_delta=0,
# Number of epochs without improvement to tolerate before stopping.
patience=0,
# Verbosity of the callback's own log output.
verbose=0,
# 'min' = smaller metric is better, 'max' = larger is better, 'auto' = infer.
mode='auto',
# Optional threshold the metric must reach; training stops if it does not.
baseline=None,
# Whether to restore the best-epoch weights when training stops.
restore_best_weights=False
)
text
┌─────────────────────────────────────────────────────────────┐
│ EarlyStopping 参数 │
├─────────────────────────────────────────────────────────────┤
│ │
│ monitor: 监控指标 │
│ ├── 'val_loss': 验证损失 │
│ └── 'val_accuracy': 验证准确率 │
│ │
│ patience: 容忍轮数 │
│ └── 指标不改善多少轮后停止 │
│ │
│ min_delta: 最小改善量 │
│ └── 小于此值不算改善 │
│ │
│ mode: 模式 │
│ ├── 'auto': 自动判断 │
│ ├── 'min': 指标越小越好 │
│ └── 'max': 指标越大越好 │
│ │
│ restore_best_weights: 是否恢复最佳权重 │
│ │
└─────────────────────────────────────────────────────────────┘
ModelCheckpoint #
基本用法 #
python
import keras
# Save the model during training; the filepath template is filled in with the
# epoch number and the current val_loss (e.g. model_03_0.1234.keras).
checkpoint = keras.callbacks.ModelCheckpoint(
filepath='model_{epoch:02d}_{val_loss:.4f}.keras',
# Decide "best" by validation loss.
monitor='val_loss',
# Only write a file when the monitored metric improves.
save_best_only=True,
# Lower val_loss is better.
mode='min'
)
model.fit(
x_train, y_train,
validation_data=(x_val, y_val),
epochs=100,
callbacks=[checkpoint]
)
参数详解 #
python
# Full ModelCheckpoint signature with its default values.
keras.callbacks.ModelCheckpoint(
# Destination path; may contain format placeholders such as {epoch}.
filepath='model.keras',
# Metric used to decide whether this checkpoint is the best so far.
monitor='val_loss',
verbose=0,
# False = save every time; True = save only on metric improvement.
save_best_only=False,
# False = save the whole model; True = save weights only.
save_weights_only=False,
mode='auto',
# When to save; 'epoch' saves at the end of each epoch.
save_freq='epoch',
# Optional initial "best" value the metric must beat before the first save —
# see the Keras docs for details.
initial_value_threshold=None
)
保存权重 #
python
import keras
# Save only the weights (not the full model), once per epoch.
checkpoint = keras.callbacks.ModelCheckpoint(
filepath='weights_{epoch:02d}.weights.h5',
save_weights_only=True,
save_freq='epoch'
)
ReduceLROnPlateau #
python
import keras
# Reduce the learning rate when val_loss plateaus:
# new_lr = old_lr * factor, never going below min_lr.
lr_scheduler = keras.callbacks.ReduceLROnPlateau(
monitor='val_loss',
# Halve the learning rate on each reduction.
factor=0.5,
# Wait 3 epochs without improvement before reducing.
patience=3,
# Lower bound for the learning rate.
min_lr=1e-6,
verbose=1
)
model.fit(
x_train, y_train,
validation_data=(x_val, y_val),
epochs=100,
callbacks=[lr_scheduler]
)
text
┌─────────────────────────────────────────────────────────────┐
│ ReduceLROnPlateau 参数 │
├─────────────────────────────────────────────────────────────┤
│ │
│ monitor: 监控指标 │
│ │
│ factor: 学习率衰减因子 │
│ └── 新学习率 = 旧学习率 × factor │
│ │
│ patience: 容忍轮数 │
│ └── 指标不改善多少轮后降低学习率 │
│ │
│ min_lr: 最小学习率 │
│ │
│ min_delta: 最小改善量 │
│ │
└─────────────────────────────────────────────────────────────┘
LearningRateScheduler #
基本用法 #
python
import keras
def lr_schedule(epoch, lr):
    """Step-decay learning-rate schedule for LearningRateScheduler.

    Keeps the learning rate unchanged for epochs 0-9, multiplies the
    current rate by 0.1 for epochs 10-19, and by 0.01 from epoch 20 on.

    NOTE(review): because `lr` is the *current* optimizer rate, returning
    `lr * 0.1` on every epoch in 10-19 compounds the decay epoch after
    epoch; confirm this is the intended behavior rather than a one-time
    step to 10% of the initial rate.

    Args:
        epoch: 0-based epoch index supplied by LearningRateScheduler.
        lr: current learning rate of the optimizer.

    Returns:
        The learning rate to use for this epoch.
    """
    if epoch < 10:
        return lr
    if epoch < 20:
        return lr * 0.1
    return lr * 0.01
# Apply the step-decay function above once per epoch; verbose=1 logs each
# learning-rate update.
lr_scheduler = keras.callbacks.LearningRateScheduler(lr_schedule, verbose=1)
model.fit(
x_train, y_train,
epochs=30,
callbacks=[lr_scheduler]
)
余弦退火 #
python
import keras
import math
def cosine_annealing(epoch, lr, epochs_total=100, lr_min=1e-6, lr_max=0.1):
    """Cosine-annealing schedule for LearningRateScheduler.

    Smoothly decays the learning rate from ``lr_max`` (epoch 0) down to
    ``lr_min`` (epoch ``epochs_total``) along a half cosine curve. The
    current ``lr`` argument passed in by LearningRateScheduler is ignored:
    the rate depends only on the epoch index.

    Args:
        epoch: 0-based epoch index supplied by LearningRateScheduler.
        lr: current learning rate (unused; required by the callback API).
        epochs_total: length of the annealing cycle in epochs.
        lr_min: learning rate at the end of the cycle.
        lr_max: learning rate at the start of the cycle.

    Returns:
        The learning rate to use for this epoch.
    """
    cosine = (1 + math.cos(math.pi * epoch / epochs_total)) / 2
    return lr_min + (lr_max - lr_min) * cosine
# Attach the cosine schedule to training via LearningRateScheduler.
lr_scheduler = keras.callbacks.LearningRateScheduler(cosine_annealing)
TensorBoard #
python
import keras
# Log training metrics (and optionally graphs/histograms) for TensorBoard;
# view with: tensorboard --logdir ./logs
tensorboard = keras.callbacks.TensorBoard(
log_dir='./logs',
# Frequency (in epochs) for computing weight histograms; 1 = every epoch.
histogram_freq=1,
write_graph=True,
write_images=True,
# Write metric logs once per epoch.
update_freq='epoch',
# Which batch to profile — presumably the 2nd batch; see Keras docs.
profile_batch=2
)
model.fit(
x_train, y_train,
validation_data=(x_val, y_val),
epochs=100,
callbacks=[tensorboard]
)
CSVLogger #
python
import keras
# Append each epoch's metrics as a row of a CSV file.
csv_logger = keras.callbacks.CSVLogger(
'training_log.csv',
# Column separator for the CSV output.
separator=',',
# False = overwrite the file each run; True = append to an existing file.
append=False
)
model.fit(
x_train, y_train,
validation_data=(x_val, y_val),
epochs=100,
callbacks=[csv_logger]
)
LambdaCallback #
python
import keras
# Build a lightweight callback from plain lambdas — one per training hook —
# without subclassing keras.callbacks.Callback.
lambda_callback = keras.callbacks.LambdaCallback(
# Hook signatures: epoch/batch hooks receive (index, logs); train hooks (logs).
on_epoch_begin=lambda epoch, logs: print(f'Epoch {epoch} 开始'),
on_epoch_end=lambda epoch, logs: print(f'Epoch {epoch} 结束, loss: {logs["loss"]:.4f}'),
# No-op batch hooks, shown here for completeness.
on_batch_begin=lambda batch, logs: None,
on_batch_end=lambda batch, logs: None,
on_train_begin=lambda logs: print('训练开始'),
on_train_end=lambda logs: print('训练结束')
)
model.fit(x_train, y_train, epochs=10, callbacks=[lambda_callback])
自定义回调函数 #
python
import keras
class CustomCallback(keras.callbacks.Callback):
    """Minimal custom callback that logs training progress to stdout.

    Override any subset of the `on_*` hooks; Keras invokes them at the
    corresponding points of the training loop (see the timing diagram
    earlier on this page).
    """

    def on_train_begin(self, logs=None):
        # Called once, before the first epoch starts.
        print('训练开始')

    def on_train_end(self, logs=None):
        # Called once, after training finishes.
        print('训练结束')

    def on_epoch_begin(self, epoch, logs=None):
        # `epoch` is 0-based; print it 1-based for readability.
        print(f'\nEpoch {epoch + 1} 开始')

    def on_epoch_end(self, epoch, logs=None):
        # `logs` carries this epoch's metrics: 'loss' always, 'val_loss'
        # only when validation data was supplied to fit().
        print(f'Epoch {epoch + 1} 结束')
        print(f'训练损失: {logs["loss"]:.4f}')
        if 'val_loss' in logs:
            print(f'验证损失: {logs["val_loss"]:.4f}')

    def on_batch_begin(self, batch, logs=None):
        # Intentionally a no-op; shown for completeness.
        pass

    def on_batch_end(self, batch, logs=None):
        # Intentionally a no-op; shown for completeness.
        pass
# Train with the custom callback attached; its hooks fire around each epoch.
model.fit(
x_train, y_train,
validation_data=(x_val, y_val),
epochs=10,
callbacks=[CustomCallback()]
)
组合使用多个回调 #
python
import keras
# A typical production combination: early stopping + best-model checkpointing
# + plateau-based LR reduction + TensorBoard logging, all monitoring val_loss.
callbacks = [
# Stop after 10 epochs without val_loss improvement, restoring best weights.
keras.callbacks.EarlyStopping(
monitor='val_loss',
patience=10,
restore_best_weights=True
),
# Keep only the best model seen so far on disk.
keras.callbacks.ModelCheckpoint(
filepath='best_model.keras',
monitor='val_loss',
save_best_only=True
),
# Halve the LR after 5 stagnant epochs, down to a floor of 1e-6.
# Note: its patience (5) is shorter than EarlyStopping's (10), so the LR
# is reduced before training is allowed to stop.
keras.callbacks.ReduceLROnPlateau(
monitor='val_loss',
factor=0.5,
patience=5,
min_lr=1e-6
),
# Log metrics for TensorBoard visualization.
keras.callbacks.TensorBoard(
log_dir='./logs'
)
]
model.fit(
x_train, y_train,
validation_data=(x_val, y_val),
epochs=100,
batch_size=32,
callbacks=callbacks
)
下一步 #
现在你已经掌握了回调函数,接下来学习 数据预处理,了解如何处理训练数据!
最后更新:2026-04-04