多 GPU 训练 #
分布式训练概述 #
text
┌─────────────────────────────────────────────────────────────┐
│ 分布式训练策略 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 数据并行: │
│ ├── 每个设备有完整模型副本 │
│ ├── 数据分割到各设备 │
│ └── 梯度聚合后更新 │
│ │
│ 模型并行: │
│ ├── 模型分割到各设备 │
│ ├── 适合超大模型 │
│ └── 实现复杂 │
│ │
│ Keras 主要使用数据并行 │
│ │
└─────────────────────────────────────────────────────────────┘
MirroredStrategy #
单机多 GPU #
python
import keras
import tensorflow as tf

# MirroredStrategy = single-machine data parallelism: every visible GPU
# gets a full copy of the model, each step's gradients are all-reduced.
strategy = tf.distribute.MirroredStrategy()
print(f'设备数量: {strategy.num_replicas_in_sync}')

# Scale the global batch with the replica count so every GPU keeps a
# per-device batch of 64 — matches the batch-size guidance in this doc.
per_replica_batch = 64
global_batch = per_replica_batch * strategy.num_replicas_in_sync

# Variables must be created inside the strategy scope so they are
# mirrored onto every device.
with strategy.scope():
    model = keras.Sequential([
        keras.layers.Dense(512, activation='relu', input_shape=(784,)),
        keras.layers.Dropout(0.2),
        keras.layers.Dense(10, activation='softmax')
    ])
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

# fit() automatically splits each global batch across the replicas.
# NOTE(review): x_train / y_train are assumed to be defined earlier on
# the page (e.g. a flattened MNIST-style dataset of shape (N, 784)).
model.fit(x_train, y_train, epochs=10, batch_size=global_batch)
指定 GPU #
python
import keras
import tensorflow as tf

# Restrict TensorFlow to the first two physical GPUs.  This must run
# before any GPU has been initialized, otherwise set_visible_devices
# raises RuntimeError.
gpus = tf.config.list_physical_devices('GPU')

# Fail fast with a clear message instead of letting MirroredStrategy
# error out cryptically when fewer than two GPUs are present.
if len(gpus) < 2:
    raise RuntimeError(f'需要至少 2 个 GPU,当前可用: {len(gpus)}')

tf.config.set_visible_devices(gpus[:2], 'GPU')

# Mirror only across the two devices kept visible above.
strategy = tf.distribute.MirroredStrategy(
    devices=['/gpu:0', '/gpu:1']
)
MultiWorkerMirroredStrategy #
多机训练 #
python
import keras
import tensorflow as tf

# Multi-machine synchronous data parallelism.  Each worker process must
# have the TF_CONFIG environment variable set (cluster layout + its own
# task index) BEFORE this strategy object is constructed.
strategy = tf.distribute.MultiWorkerMirroredStrategy()

# Create and compile the model inside the scope so its variables are
# replicated on every worker.
with strategy.scope():
    model = keras.Sequential([
        keras.layers.Dense(512, activation='relu', input_shape=(784,)),
        keras.layers.Dense(10, activation='softmax')
    ])
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

# The global batch of 64 is sharded across all workers' replicas.
model.fit(x_train, y_train, epochs=10, batch_size=64)
配置集群 #
python
import os
# TF_CONFIG describes the multi-worker cluster to TensorFlow.  It must be
# set in every worker's environment BEFORE MultiWorkerMirroredStrategy is
# created.  "worker" lists every host:port in the cluster; "index" is this
# process's position in that list (0 is conventionally the chief).
# NOTE(review): "host1:port" / "host2:port" are placeholders — replace with
# real addresses.
os.environ['TF_CONFIG'] = '''
{
"cluster": {
"worker": ["host1:port", "host2:port"]
},
"task": {"type": "worker", "index": 0}
}
'''
TPUStrategy #
python
import keras
import tensorflow as tf

# Locate the TPU cluster, connect to it, and initialize the TPU system —
# all three steps must happen before the strategy is built.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

# Variables created inside the scope are placed on the TPU cores.
with strategy.scope():
    model = keras.Sequential([
        keras.layers.Dense(512, activation='relu', input_shape=(784,)),
        keras.layers.Dense(10, activation='softmax')
    ])
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )
批次大小调整 #
text
┌─────────────────────────────────────────────────────────────┐
│ 批次大小建议 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 全局批次大小 = 单 GPU 批次 × GPU 数量 │
│ │
│ 例: │
│ 单 GPU 批次: 32 │
│ GPU 数量: 4 │
│ 全局批次: 128 │
│ │
│ 注意事项: │
│ ├── 学习率可能需要相应调整 │
│ ├── 批次太大可能影响泛化 │
│ └── 确保数据量足够 │
│ │
└─────────────────────────────────────────────────────────────┘
完整示例 #
python
import keras
import tensorflow as tf

# Load CIFAR-10 and normalize pixel values to [0, 1].
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# One replica per visible GPU.  Scale the global batch with the replica
# count so each device keeps a per-replica batch of 64.
strategy = tf.distribute.MirroredStrategy()
print(f'使用 {strategy.num_replicas_in_sync} 个 GPU')
batch_size = 64 * strategy.num_replicas_in_sync

# Build and compile inside the scope so the variables are mirrored.
with strategy.scope():
    model = keras.Sequential([
        keras.layers.Conv2D(32, 3, padding='same', activation='relu', input_shape=(32, 32, 3)),
        keras.layers.BatchNormalization(),
        keras.layers.Conv2D(32, 3, activation='relu'),
        keras.layers.BatchNormalization(),
        keras.layers.MaxPooling2D(),
        keras.layers.Dropout(0.25),
        keras.layers.Conv2D(64, 3, padding='same', activation='relu'),
        keras.layers.BatchNormalization(),
        keras.layers.Conv2D(64, 3, activation='relu'),
        keras.layers.BatchNormalization(),
        keras.layers.MaxPooling2D(),
        keras.layers.Dropout(0.25),
        keras.layers.Flatten(),
        keras.layers.Dense(512, activation='relu'),
        keras.layers.BatchNormalization(),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(10, activation='softmax')
    ])
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=0.001),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

# Train with early stopping (best weights restored) and LR decay on
# validation plateau.  NOTE(review): the test split doubles as the
# validation set here — common in tutorials, but it leaks the test set
# into model selection.
history = model.fit(
    x_train, y_train,
    validation_data=(x_test, y_test),
    epochs=50,
    batch_size=batch_size,
    callbacks=[
        keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True),
        keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=5)
    ]
)
下一步 #
现在你已经掌握了多 GPU 训练,接下来学习 图像分类实战,应用所学知识!
最后更新:2026-04-04