Model Checkpoints
Why Use Checkpoints?
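```text
┌────────────────────────────────────────────────────────────┐
│ Purpose of checkpoints                                     │
├────────────────────────────────────────────────────────────┤
│                                                            │
│ Guard against accidents:                                   │
│ ├── Training interruptions                                 │
│ ├── Hardware failures                                      │
│ └── Power outages and other unexpected events              │
│                                                            │
│ Keep the best model:                                       │
│ ├── Best validation performance                            │
│ ├── Avoid losing the best model once overfitting sets in   │
│ └── Easy reuse later                                       │
│                                                            │
│ Resume training:                                           │
│ ├── Continue from where training stopped                   │
│ └── Save time                                              │
│                                                            │
└────────────────────────────────────────────────────────────┘
```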
ModelCheckpoint
Saving the Best Model
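```python
import keras

# Assumes `model` is already built and compiled, and that
# x_train, y_train, x_val, y_val are loaded.
checkpoint = keras.callbacks.ModelCheckpoint(
    filepath='best_model.keras',  # Keras 3 requires .keras for full-model saves
    monitor='val_loss',           # metric to track
    save_best_only=True,          # overwrite the file only when val_loss improves
    mode='min',                   # lower val_loss is better
    verbose=1                     # print a message whenever a checkpoint is written
)

model.fit(
    x_train, y_train,
    validation_data=(x_val, y_val),  # val_loss is only logged when validation data is given
    epochs=100,
    callbacks=[checkpoint]
)
```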
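To see what `monitor`, `mode` and `save_best_only` do together, here is a simplified pure-Python model of the decision rule: a checkpoint is written only when the monitored value beats the best value seen so far. It is a sketch for intuition, not Keras's actual source; the function name `epochs_that_save` is hypothetical, and it ignores details such as `initial_value_threshold`.

```python
import math


def epochs_that_save(values, mode='min'):
    """Return the 1-based epochs at which a save_best_only checkpoint
    would be written for the given per-epoch metric values.

    Hypothetical helper that approximates ModelCheckpoint's rule.
    """
    best = math.inf if mode == 'min' else -math.inf
    saved = []
    for epoch, value in enumerate(values, start=1):
        improved = value < best if mode == 'min' else value > best
        if improved:
            best = value
            saved.append(epoch)
    return saved


# val_loss improves in epochs 1, 2 and 4, then gets worse (overfitting)
val_loss = [0.52, 0.41, 0.43, 0.38, 0.40, 0.45]
print(epochs_that_save(val_loss, mode='min'))  # [1, 2, 4]

# With mode='max' the same logic tracks the highest value instead
val_acc = [0.90, 0.93, 0.92, 0.95, 0.94]
print(epochs_that_save(val_acc, mode='max'))  # [1, 2, 4]
```

Because `best_model.keras` is overwritten at each of those epochs, after training it holds the epoch-4 model, even though training ran two more epochs.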
Saving Periodically
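```python
import keras

# Save the full model at the end of every epoch; {epoch:02d} becomes the
# 1-based, zero-padded epoch number (model_01.keras, model_02.keras, ...).
checkpoint = keras.callbacks.ModelCheckpoint(
    filepath='model_{epoch:02d}.keras',
    save_freq='epoch'  # the default; an integer N saves every N batches instead
)
model.fit(x_train, y_train, epochs=10, callbacks=[checkpoint])
```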
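Per-epoch checkpoints are what make resuming training possible. A minimal sketch, assuming Keras 3 with a working backend and a tiny throwaway model on random data (the `build_model` helper and the data are illustrative only): train for 2 epochs, reload the latest `.keras` file with `keras.models.load_model`, and continue with `initial_epoch` so epoch numbering and checkpoint names pick up where they left off. A `.keras` file stores the architecture, the weights and the optimizer state, so the reloaded model can keep training directly. Files are written to the current directory.

```python
import os

import numpy as np
import keras


def build_model():
    # Hypothetical tiny regression model, only used for illustration
    model = keras.Sequential([
        keras.Input(shape=(4,)),
        keras.layers.Dense(8, activation='relu'),
        keras.layers.Dense(1)
    ])
    model.compile(optimizer='adam', loss='mse')
    return model


# Random stand-in data
rng = np.random.default_rng(0)
x = rng.random((64, 4), dtype=np.float32)
y = rng.random((64, 1), dtype=np.float32)

checkpoint = keras.callbacks.ModelCheckpoint(
    filepath='model_{epoch:02d}.keras',
    save_freq='epoch'
)

# First run: epochs 1-2, writes model_01.keras and model_02.keras
model = build_model()
model.fit(x, y, epochs=2, callbacks=[checkpoint], verbose=0)

# Simulate a restart: load the latest checkpoint (architecture, weights
# and optimizer state) instead of building a fresh model
resumed = keras.models.load_model('model_02.keras')

# initial_epoch=2 continues with the third epoch, so this run writes
# model_03.keras and model_04.keras
history = resumed.fit(x, y, initial_epoch=2, epochs=4,
                      callbacks=[checkpoint], verbose=0)

print(history.epoch)                     # [2, 3] (0-based epoch indices)
print(os.path.exists('model_04.keras'))  # True
```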
Saving Weights Only
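```python
import keras

# Save only the weights after every epoch. In Keras 3 the path must end
# in .weights.h5 when save_weights_only=True.
checkpoint = keras.callbacks.ModelCheckpoint(
    filepath='weights_{epoch:02d}.weights.h5',
    save_weights_only=True,
    save_freq='epoch'
)
model.fit(x_train, y_train, epochs=10, callbacks=[checkpoint])
```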
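Placeholders such as `{epoch:02d}` are filled with Python's `str.format`, using the 1-based epoch number and any logged metrics (for example `{val_loss:.2f}`), and Keras 3 checks the file extension against `save_weights_only`. The hypothetical `checkpoint_path` helper below mimics both behaviors so you can preview file names; it is an approximation for illustration, not a Keras API.

```python
def checkpoint_path(template, epoch_index, save_weights_only=False, **logs):
    """Preview the file name ModelCheckpoint would write.

    epoch_index is 0-based (as Keras passes it to callbacks); the name
    uses epoch_index + 1. Hypothetical helper, not part of Keras.
    """
    required = '.weights.h5' if save_weights_only else '.keras'
    if not template.endswith(required):
        raise ValueError(f'filepath must end in {required!r}')
    return template.format(epoch=epoch_index + 1, **logs)


print(checkpoint_path('weights_{epoch:02d}.weights.h5', 0, save_weights_only=True))
# weights_01.weights.h5
print(checkpoint_path('model_{epoch:02d}-{val_loss:.2f}.keras', 9, val_loss=0.1234))
# model_10-0.12.keras
```

Because a weights file carries no architecture, you need to rebuild the same model in code and call `model.load_weights(path)` before using it.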
Complete Example
```python
import keras

# Load MNIST and flatten each 28x28 image into a 784-dim vector in [0, 1]
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32') / 255.0
x_test = x_test.reshape(-1, 784).astype('float32') / 255.0

model = keras.Sequential([
    keras.Input(shape=(784,)),  # preferred over passing input_shape to Dense
    keras.layers.Dense(512, activation='relu'),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(256, activation='relu'),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10, activation='softmax')
])

model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',  # integer labels 0-9
    metrics=['accuracy']
)

callbacks = [
    # Keep the model with the highest validation accuracy on disk
    keras.callbacks.ModelCheckpoint(
        filepath='best_model.keras',
        monitor='val_accuracy',
        save_best_only=True,
        mode='max',
        verbose=1
    ),
    # Stop after 5 epochs without improvement and restore the best weights
    keras.callbacks.EarlyStopping(
        monitor='val_accuracy',
        patience=5,
        restore_best_weights=True
    )
]

history = model.fit(
    x_train, y_train,
    validation_split=0.1,  # hold out the last 10% of the training data
    epochs=50,
    batch_size=128,
    callbacks=callbacks
)

# Reload the best checkpoint from disk and evaluate it on the test set
best_model = keras.models.load_model('best_model.keras')
test_loss, test_acc = best_model.evaluate(x_test, y_test)
print(f'Test accuracy: {test_acc:.4f}')
```
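In the example above both callbacks watch `val_accuracy`. A simplified pure-Python sketch of how they interact (the `simulate` helper is hypothetical and ignores `min_delta`, `baseline` and `start_from_epoch`): the checkpoint is rewritten whenever accuracy reaches a new high, and EarlyStopping halts once `patience` consecutive epochs pass without improvement.

```python
def simulate(val_accuracy, patience=5):
    """Return (best_epoch, stopped_epoch), both 1-based, for mode='max'.

    Approximates ModelCheckpoint(save_best_only=True) plus
    EarlyStopping(patience=...); hypothetical helper, not Keras code.
    """
    best, best_epoch, wait = float('-inf'), 0, 0
    for epoch, acc in enumerate(val_accuracy, start=1):
        if acc > best:
            # New best: the checkpoint file is overwritten, patience resets
            best, best_epoch, wait = acc, epoch, 0
        else:
            # No improvement: count towards patience
            wait += 1
            if wait >= patience:
                return best_epoch, epoch  # training stops here
    return best_epoch, len(val_accuracy)  # ran all epochs


acc = [0.970, 0.975, 0.979, 0.978, 0.979, 0.977, 0.978, 0.976, 0.980]
print(simulate(acc, patience=5))  # (3, 8)
```

Here `best_model.keras` would hold the epoch-3 model and training stops at epoch 8, so the 0.980 in epoch 9 is never reached. With `restore_best_weights=True`, the in-memory model also ends up with the epoch-3 weights and matches the reloaded file.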
Next Steps
Now that you know how to use model checkpoints, move on to transfer learning and build on pretrained models!
Last updated: 2026-04-04