迁移学习 #
什么是迁移学习? #
text
┌─────────────────────────────────────────────────────────────┐
│ 迁移学习原理 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 传统训练: │
│ 随机初始化 ──► 从零训练 ──► 大量数据 + 长时间 │
│ │
│ 迁移学习: │
│ 预训练模型 ──► 微调 ──► 少量数据 + 短时间 │
│ │
│ 为什么有效? │
│ ├── 浅层学习通用特征(边缘、纹理) │
│ ├── 深层学习特定特征(形状、对象) │
│ └── 知识可以迁移到相关任务 │
│ │
└─────────────────────────────────────────────────────────────┘
使用预训练模型 #
加载预训练模型 #
python
import keras

# Load an ImageNet-pretrained ResNet50 backbone without its classifier head.
backbone = keras.applications.ResNet50(
    weights='imagenet',
    include_top=False,
    input_shape=(224, 224, 3)
)
# Freeze the backbone: only the newly added head will be trained.
backbone.trainable = False

# Stack a fresh classification head on top of the frozen backbone.
model = keras.Sequential()
model.add(backbone)
model.add(keras.layers.GlobalAveragePooling2D())
model.add(keras.layers.Dense(256, activation='relu'))
model.add(keras.layers.Dropout(0.5))
model.add(keras.layers.Dense(10, activation='softmax'))

# categorical_crossentropy: labels are expected as one-hot vectors.
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy']
)
可用预训练模型 #
text
┌─────────────────────────────────────────────────────────────┐
│ 常用预训练模型 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 图像分类: │
│ ├── ResNet50/101/152 │
│ ├── VGG16/19 │
│ ├── EfficientNet B0-B7 │
│ ├── MobileNet V1/V2/V3 │
│ ├── DenseNet 121/169/201 │
│ └── Vision Transformer (ViT) │
│ │
│ 选择建议: │
│ ├── 移动端: MobileNet, EfficientNet-B0 │
│ ├── 高精度: EfficientNet-B7, ResNet152 │
│ └── 平衡: ResNet50, EfficientNet-B3 │
│ │
└─────────────────────────────────────────────────────────────┘
特征提取 #
冻结基础模型 #
python
import keras

# EfficientNetB0 pretrained on ImageNet, classification head removed.
base_model = keras.applications.EfficientNetB0(
    weights='imagenet',
    include_top=False,
    input_shape=(224, 224, 3)
)
base_model.trainable = False  # feature extraction: backbone stays frozen

# Functional API: frozen backbone + a small trainable classifier.
inputs = keras.Input(shape=(224, 224, 3))
# training=False keeps BatchNorm layers in inference mode even during fit().
features = base_model(inputs, training=False)
pooled = keras.layers.GlobalAveragePooling2D()(features)
hidden = keras.layers.Dropout(0.3)(
    keras.layers.Dense(128, activation='relu')(pooled)
)
outputs = keras.layers.Dense(10, activation='softmax')(hidden)
model = keras.Model(inputs, outputs)

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)
# train_dataset / val_dataset are assumed to be defined by the reader.
model.fit(train_dataset, epochs=10, validation_data=val_dataset)
微调(Fine-tuning) #
解冻部分层 #
python
import keras

# Load an ImageNet-pretrained ResNet50 backbone without its classifier head.
base_model = keras.applications.ResNet50(
    weights='imagenet',
    include_top=False,
    input_shape=(224, 224, 3)
)

# Unfreeze the backbone, then re-freeze all but the last 20 layers so that
# fine-tuning only adapts the top (most task-specific) layers.
base_model.trainable = True
for layer in base_model.layers[:-20]:
    layer.trainable = False

# BUG FIX: the original snippet called model.compile() without ever
# building `model` (NameError when run standalone). Attach a
# classification head to the backbone first.
inputs = keras.Input(shape=(224, 224, 3))
x = base_model(inputs, training=False)  # keep BatchNorm statistics frozen
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(10, activation='softmax')(x)
model = keras.Model(inputs, outputs)

# Very small learning rate: large updates would destroy the pretrained
# weights being fine-tuned.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-5),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)
# train_dataset / val_dataset are assumed to be defined by the reader.
model.fit(train_dataset, epochs=10, validation_data=val_dataset)
两阶段训练 #
python
import keras


def _compile(net, lr):
    """Compile *net* with Adam at learning rate *lr* for one-hot labels."""
    net.compile(
        optimizer=keras.optimizers.Adam(learning_rate=lr),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )


# Pretrained EfficientNetB0 backbone, classifier head removed.
base_model = keras.applications.EfficientNetB0(
    weights='imagenet',
    include_top=False,
    input_shape=(224, 224, 3)
)
base_model.trainable = False

# Minimal classification head on top of the frozen backbone.
inputs = keras.Input(shape=(224, 224, 3))
x = base_model(inputs, training=False)  # BatchNorm stays in inference mode
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(10, activation='softmax')(x)
model = keras.Model(inputs, outputs)

# Stage 1: with the backbone frozen, train only the new head.
_compile(model, 0.001)
print("阶段 1: 训练分类头")
model.fit(train_dataset, epochs=5, validation_data=val_dataset)

# Stage 2: unfreeze the top 20 backbone layers and fine-tune them together
# with the head, at a much smaller learning rate.
print("阶段 2: 微调")
base_model.trainable = True
for layer in base_model.layers[:-20]:
    layer.trainable = False
_compile(model, 1e-5)
model.fit(train_dataset, epochs=10, validation_data=val_dataset)
完整示例 #
python
import keras

img_size = (224, 224)
batch_size = 32
num_classes = 10

# Data augmentation plus an 80/20 train/validation split from one directory.
# BUG FIX: Keras EfficientNet models embed input preprocessing (rescaling /
# normalization) inside the network and expect raw pixels in [0, 255].
# The original `rescale=1./255` would feed near-zero inputs to the
# pretrained weights and cripple accuracy, so it has been removed.
# (The unused `import tensorflow as tf` was also dropped.)
train_datagen = keras.preprocessing.image.ImageDataGenerator(
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
    validation_split=0.2
)
train_generator = train_datagen.flow_from_directory(
    'data/train',
    target_size=img_size,
    batch_size=batch_size,
    class_mode='categorical',
    subset='training'
)
val_generator = train_datagen.flow_from_directory(
    'data/train',
    target_size=img_size,
    batch_size=batch_size,
    class_mode='categorical',
    subset='validation'
)

# ImageNet-pretrained EfficientNetB3 backbone without its classifier head.
base_model = keras.applications.EfficientNetB3(
    weights='imagenet',
    include_top=False,
    input_shape=(224, 224, 3)
)
base_model.trainable = False  # stage 1 trains only the new head

# Classification head: pooling + BatchNorm + dense + dropout.
inputs = keras.Input(shape=(224, 224, 3))
x = base_model(inputs, training=False)  # keep backbone BatchNorm in inference mode
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Dense(256, activation='relu')(x)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.Dropout(0.5)(x)
outputs = keras.layers.Dense(num_classes, activation='softmax')(x)
model = keras.Model(inputs, outputs)

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)
print("阶段 1: 训练分类头")
history1 = model.fit(
    train_generator,
    epochs=5,
    validation_data=val_generator
)

print("阶段 2: 微调")
# Unfreeze only the last 30 backbone layers; the tiny learning rate
# avoids destroying the pretrained weights.
base_model.trainable = True
for layer in base_model.layers[:-30]:
    layer.trainable = False
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-5),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)
history2 = model.fit(
    train_generator,
    epochs=15,
    validation_data=val_generator,
    callbacks=[
        # Stop when validation stalls and keep the best weights seen.
        keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True),
        keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=3)
    ]
)
下一步 #
现在你已经掌握了迁移学习,接下来学习 模型保存与加载,了解如何持久化模型!
最后更新:2026-04-04