Model Checkpoints

Why Use Checkpoints?

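```text
┌─────────────────────────────────────────────────────────────┐
│                     What checkpoints do                     │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Guard against failures:                                    │
│  ├── Interrupted training runs                              │
│  ├── Hardware failures                                      │
│  └── Power outages and other accidents                      │
│                                                             │
│  Keep the best model:                                       │
│  ├── Best validation performance                            │
│  ├── Not lost when the model starts to overfit              │
│  └── Easy to reuse later                                    │
│                                                             │
│  Resume training:                                           │
│  ├── Continue from where it stopped                         │
│  └── Save time                                              │
│                                                             │
└─────────────────────────────────────────────────────────────┘
```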

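ModelCheckpoint

`keras.callbacks.ModelCheckpoint` is a callback that saves the model, or only its weights, while `fit()` runs.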

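Saving the Best Model

With `save_best_only=True`, the file is written only when the monitored metric (`monitor`) improves, so `best_model.keras` always holds the best model seen so far. `mode='min'` tells the callback that a lower `val_loss` is better. For accuracy-style metrics, use `mode='max'`.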

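```python
import keras

# Assumes `model` is already compiled and x_train, y_train, x_val, y_val are loaded
checkpoint = keras.callbacks.ModelCheckpoint(
    filepath='best_model.keras',  # the same file is overwritten on every improvement
    monitor='val_loss',           # metric that defines "best"
    save_best_only=True,          # save only when val_loss improves
    mode='min',                   # lower val_loss is better
    verbose=1                     # print a message whenever a checkpoint is saved
)

model.fit(
    x_train, y_train,
    validation_data=(x_val, y_val),  # val_loss is only available with validation data
    epochs=100,
    callbacks=[checkpoint]
)
```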
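To see which epochs actually produce a file, the comparison behind `save_best_only` can be modeled in plain Python. This is a simplified sketch, not the Keras source. It assumes that only a strict improvement over the best value so far triggers a save, and it ignores options such as `initial_value_threshold`. `epochs_that_save` is a hypothetical helper.

```python
import math


def epochs_that_save(history, mode="min"):
    """Hypothetical helper: return the 1-based epochs at which a
    save_best_only checkpoint would be written for a metric history."""
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    # Start from the worst possible value so the first epoch always saves
    best = math.inf if mode == "min" else -math.inf
    saved = []
    for epoch, value in enumerate(history, start=1):
        # Only a strict improvement counts; a tie does not overwrite the file
        improved = value < best if mode == "min" else value > best
        if improved:
            best = value
            saved.append(epoch)
    return saved


# Made-up val_loss values for illustration
val_loss = [0.52, 0.41, 0.43, 0.38, 0.39]
print(epochs_that_save(val_loss, mode="min"))  # [1, 2, 4]
```

The filename has no placeholder, so each save overwrites `best_model.keras`. With these values, the file ends up holding the epoch-4 model.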

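Saving Periodically

`save_freq='epoch'` (the default) saves at the end of every epoch. The `{epoch:02d}` placeholder is filled in each time, so every epoch gets its own file instead of overwriting the previous one.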

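```python
import keras

# Write a new file at the end of every epoch: model_01.keras, model_02.keras, ...
checkpoint = keras.callbacks.ModelCheckpoint(
    filepath='model_{epoch:02d}.keras',  # {epoch:02d} is replaced with the epoch number
    save_freq='epoch'                    # default; an integer N would save every N batches
)

model.fit(x_train, y_train, epochs=10, callbacks=[checkpoint])
```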
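The `filepath` placeholders use Python's `str.format` syntax. The sketch below previews the resulting filenames. It assumes Keras fills `{epoch}` with the 1-based epoch number and exposes the epoch's logged metrics (such as `val_loss`) as additional fields. The metric value is made up.

```python
# Preview the names produced by 'model_{epoch:02d}.keras' for a 10-epoch run
template = "model_{epoch:02d}.keras"
names = [template.format(epoch=epoch) for epoch in range(1, 11)]
print(names[0], names[-1])  # model_01.keras model_10.keras

# A metric can also be embedded in the name; 0.41234 is a made-up val_loss
with_metric = "model_{epoch:02d}-{val_loss:.2f}.keras"
print(with_metric.format(epoch=3, val_loss=0.41234))  # model_03-0.41.keras

# Padding only guarantees two digits, so epoch 100 sorts before epoch 11 as a string
print(sorted(["model_11.keras", "model_100.keras"]))  # ['model_100.keras', 'model_11.keras']
```

A field such as `{val_loss}` can only be filled if that metric is logged, which requires passing validation data to `fit()`.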

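Saving Weights Only

`save_weights_only=True` saves the model weights rather than the full model. To reload them, you need a model with the same architecture and `model.load_weights()`. In Keras 3, the filename must end in `.weights.h5` when this option is set.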

```python
import keras

# Save only the weights at the end of every epoch
checkpoint = keras.callbacks.ModelCheckpoint(
    filepath='weights_{epoch:02d}.weights.h5',  # Keras 3 requires the .weights.h5 suffix here
    save_weights_only=True,
    save_freq='epoch'
)

model.fit(x_train, y_train, epochs=10, callbacks=[checkpoint])
```
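To resume from these files, a model with the same architecture can call `model.load_weights(path)` and then continue with `model.fit(..., initial_epoch=epoch)`. The safest way to find the newest file is to parse the epoch out of the name, because plain string sorting breaks past epoch 99. This sketch assumes the `weights_{epoch:02d}.weights.h5` naming used above. `latest_weights` is a hypothetical helper.

```python
import re
import tempfile
from pathlib import Path

# Matches names produced by 'weights_{epoch:02d}.weights.h5'
PATTERN = re.compile(r"weights_(\d+)\.weights\.h5")


def latest_weights(directory):
    """Hypothetical helper: return (path, epoch) of the newest weights file,
    or (None, 0) if there is none; the epoch can be passed as initial_epoch."""
    best_path, best_epoch = None, 0
    for path in Path(directory).iterdir():
        match = PATTERN.fullmatch(path.name)
        if match and int(match.group(1)) > best_epoch:
            best_path, best_epoch = path, int(match.group(1))
    return best_path, best_epoch


# Demo with empty placeholder files instead of real checkpoints
with tempfile.TemporaryDirectory() as tmp:
    for name in ["weights_09.weights.h5", "weights_10.weights.h5", "weights_100.weights.h5"]:
        (Path(tmp) / name).touch()
    path, epoch = latest_weights(tmp)
    print(path.name, epoch)  # weights_100.weights.h5 100
```

A file named for epoch N means N epochs have finished. Passing `initial_epoch=N` to `fit()` therefore continues with the next epoch.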

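Complete Example

The example below trains a multilayer perceptron on MNIST and keeps the model with the highest validation accuracy in `best_model.keras`. It stops training after 5 epochs without improvement and then evaluates the saved best model on the test set.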

```python
import keras

# Load MNIST and flatten each 28x28 image into a 784-dim vector scaled to [0, 1]
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32') / 255.0
x_test = x_test.reshape(-1, 784).astype('float32') / 255.0

model = keras.Sequential([
    keras.Input(shape=(784,)),  # preferred over passing input_shape to the first layer
    keras.layers.Dense(512, activation='relu'),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(256, activation='relu'),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10, activation='softmax')
])

model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',  # labels are integer class indices
    metrics=['accuracy']
)

callbacks = [
    # Keep the model with the highest validation accuracy on disk
    keras.callbacks.ModelCheckpoint(
        filepath='best_model.keras',
        monitor='val_accuracy',
        save_best_only=True,
        mode='max',
        verbose=1
    ),
    # Stop after 5 epochs without improvement and roll back to the best weights in memory
    keras.callbacks.EarlyStopping(
        monitor='val_accuracy',
        mode='max',  # higher accuracy is better
        patience=5,
        restore_best_weights=True
    )
]

history = model.fit(
    x_train, y_train,
    validation_split=0.1,  # hold out the last 10% of the training data for validation
    epochs=50,
    batch_size=128,
    callbacks=callbacks
)

# Reload the best checkpoint from disk and evaluate it on the test set
best_model = keras.models.load_model('best_model.keras')
test_loss, test_acc = best_model.evaluate(x_test, y_test)
print(f'Test accuracy: {test_acc:.4f}')
```
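In this example, both callbacks watch `val_accuracy` with a strict "greater than" comparison. The simplified simulation below sketches when `EarlyStopping` halts training and which epoch it treats as best. It assumes the default `min_delta=0`, no `baseline`, and no `start_from_epoch`. `simulate_early_stopping` is a hypothetical helper, not part of Keras.

```python
def simulate_early_stopping(values, patience, mode="max"):
    """Hypothetical helper: return (stop_epoch, best_epoch), both 1-based.
    stop_epoch is None when training would run through every epoch."""
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    best, best_epoch, wait = None, None, 0
    for epoch, value in enumerate(values, start=1):
        wait += 1
        # A strict improvement resets the patience counter (min_delta=0, no baseline)
        if best is None or (value > best if mode == "max" else value < best):
            best, best_epoch, wait = value, epoch, 0
            continue
        # Patience exhausted: training would stop after this epoch
        if wait >= patience:
            return epoch, best_epoch
    return None, best_epoch


# Made-up val_accuracy values; epoch 6 only ties the best, and epoch 9 is never reached
val_acc = [0.950, 0.970, 0.975, 0.974, 0.973, 0.975, 0.972, 0.971, 0.976]
print(simulate_early_stopping(val_acc, patience=5))  # (8, 3)
```

For these values, training stops after epoch 8. Epoch 3 is both the weights that `restore_best_weights=True` puts back and the last epoch `ModelCheckpoint` wrote to `best_model.keras`, since both callbacks use the same strict comparison.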

Next Steps

Now that you know how to use model checkpoints, continue with Transfer Learning to build on pretrained models!

Last updated: 2026-04-04