模型保存与加载 #

保存选项 #

text
┌─────────────────────────────────────────────────────────────┐
│                    模型保存选项                              │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  完整模型:                                                  │
│  ├── 架构 + 权重 + 优化器状态                              │
│  └── 可以继续训练                                          │
│                                                             │
│  仅权重:                                                    │
│  ├── 只保存参数值                                          │
│  └── 需要相同架构才能加载                                  │
│                                                             │
│  仅架构:                                                    │
│  ├── 只保存模型结构                                        │
│  └── 需要重新训练                                          │
│                                                             │
└─────────────────────────────────────────────────────────────┘

保存完整模型 #

Keras 格式 #

python
import keras

# Declare the input shape with keras.Input; passing `input_shape=` to the
# first layer is deprecated in Keras 3.
model = keras.Sequential([
    keras.Input(shape=(784,)),
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dense(10, activation='softmax'),
])

model.compile(optimizer='adam', loss='categorical_crossentropy')
# x_train / y_train are assumed to be prepared beforehand (y_train one-hot
# encoded, to match categorical_crossentropy).
model.fit(x_train, y_train, epochs=5)

# A `.keras` archive stores the architecture, the weights, the compile
# configuration and the optimizer state, so the reloaded model can resume
# training directly.
model.save('model.keras')

loaded_model = keras.models.load_model('model.keras')

SavedModel 格式(导出用于推理/部署)#

python
import keras

# In Keras 3, model.save() only accepts `.keras` (or legacy `.h5`) paths;
# model.save('saved_model') raises a ValueError. To produce a TensorFlow
# SavedModel (e.g. for TF Serving / TFLite), export it instead.
model.export('saved_model')

# keras.models.load_model() cannot read a SavedModel in Keras 3. Reload it
# as an inference-only layer (requires TensorFlow to be installed); the
# result can be called on inputs but cannot be trained or re-compiled.
reloaded_layer = keras.layers.TFSMLayer('saved_model')

HDF5 格式(旧版格式)#

python
import keras

# HDF5 (`.h5`) is a legacy format in Keras 3: saving still works but emits a
# warning, and the native `.keras` format is the recommended choice. Kept
# here for compatibility with older tooling.
model.save('model.h5')

loaded_model = keras.models.load_model('model.h5')

仅保存权重 #

python
import keras

# Keras 3 requires weight files to end in `.weights.h5`; a bare name such as
# save_weights('weights') raises a ValueError.
model.save_weights('weights.weights.h5')

model.load_weights('weights.weights.h5')

# Weights only contain parameter values, so they must be loaded into a model
# with the same architecture, e.g. a fresh clone of the original model.
new_model = keras.models.clone_model(model)
new_model.load_weights('weights.weights.h5')

仅保存架构 #

python
import keras

# get_config() returns a plain Python dict that describes the layers only
# (no trained weights). The model in this tutorial is a Sequential, so it is
# rebuilt with Sequential.from_config; keras.Model.from_config is meant for
# functional-model configs and cannot rebuild a Sequential one. The rebuilt
# model starts from freshly initialised weights.
config = model.get_config()

new_model = keras.Sequential.from_config(config)

# The same architecture as a JSON string. The JSON records the model's class
# name, so model_from_json rebuilds the correct model type.
json_config = model.to_json()

new_model = keras.models.model_from_json(json_config)

训练时保存 #

ModelCheckpoint #

python
import keras

# Save the complete model, overwriting the file only when val_loss reaches a
# new minimum, so best_model.keras always holds the best epoch so far.
best_checkpoint = keras.callbacks.ModelCheckpoint(
    filepath='best_model.keras',
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    save_weights_only=False,
)

model.fit(
    x_train,
    y_train,
    epochs=100,
    validation_data=(x_val, y_val),
    callbacks=[best_checkpoint],
)

定期保存 #

python
import keras

# Write a full-model snapshot at the end of every epoch; `{epoch:02d}` puts
# the zero-padded epoch number into each file name.
periodic_checkpoint = keras.callbacks.ModelCheckpoint(
    filepath='model_{epoch:02d}.keras',
    save_freq='epoch',
)

model.fit(
    x_train,
    y_train,
    epochs=10,
    callbacks=[periodic_checkpoint],
)

自定义对象 #

python
import keras

class CustomLayer(keras.layers.Layer):
    """Bias-free linear layer computing ``dot(inputs, kernel)``.

    Args:
        units: Size of the output feature dimension (columns of ``kernel``).
    """

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        # The kernel depends on the incoming feature size, so it is created
        # here rather than in __init__; without it, call() would fail on the
        # missing `self.kernel` attribute.
        self.kernel = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer='glorot_uniform',
            trainable=True,
            name='kernel',
        )

    def call(self, inputs):
        return keras.ops.dot(inputs, self.kernel)

    def get_config(self):
        # Serialize `units` so load_model can recreate the layer via
        # CustomLayer(**config).
        config = super().get_config()
        config.update({'units': self.units})
        return config

# keras.Input makes the model (and CustomLayer) build immediately, so the
# saved file actually contains the layer weights.
model = keras.Sequential([
    keras.Input(shape=(784,)),
    keras.layers.Dense(64, activation='relu'),
    CustomLayer(32),
    keras.layers.Dense(10)
])

model.save('custom_model.keras')

# The saved file only records the class name, so the Python class must be
# supplied at load time (alternatively, decorate the class with
# @keras.saving.register_keras_serializable()).
loaded_model = keras.models.load_model(
    'custom_model.keras',
    custom_objects={'CustomLayer': CustomLayer}
)

完整示例 #

python
import keras
import numpy as np

# Load MNIST, flatten each 28x28 image into a 784-dim float vector scaled to
# [0, 1], and one-hot encode the labels to match categorical_crossentropy.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32') / 255.0
x_test = x_test.reshape(-1, 784).astype('float32') / 255.0
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)

# Declare the input with keras.Input instead of passing `input_shape=` to the
# first layer (deprecated in Keras 3).
model = keras.Sequential([
    keras.Input(shape=(784,)),
    keras.layers.Dense(512, activation='relu'),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(256, activation='relu'),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10, activation='softmax'),
])

model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# Keep only the model from the epoch with the highest validation accuracy.
checkpoint = keras.callbacks.ModelCheckpoint(
    'best_model.keras',
    monitor='val_accuracy',
    save_best_only=True,
    mode='max',
)

# validation_split=0.1 holds out 10% of the training data for validation,
# which produces the val_accuracy metric monitored by the checkpoint.
history = model.fit(
    x_train, y_train,
    validation_split=0.1,
    epochs=20,
    batch_size=128,
    callbacks=[checkpoint],
)

# final_model.keras is the state after the last epoch; best_model.keras is
# the best epoch by val_accuracy, which is the one evaluated on the test set.
model.save('final_model.keras')

loaded_model = keras.models.load_model('best_model.keras')
test_loss, test_acc = loaded_model.evaluate(x_test, y_test)
print(f'测试准确率: {test_acc:.4f}')

下一步 #

现在你已经掌握了模型的保存与加载,接下来学习如何使用 TensorBoard 可视化训练过程!

最后更新:2026-04-04