Multi-GPU Training #

Distributed Training Overview #

text
┌─────────────────────────────────────────────────────────────┐
│               Distributed Training Strategies               │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Data parallelism:                                          │
│  ├── Every device holds a full copy of the model            │
│  ├── The data is split across the devices                   │
│  └── Gradients are aggregated before each update (see below)│
│                                                             │
│  Model parallelism:                                         │
│  ├── The model itself is split across devices               │
│  ├── Suited to very large models                            │
│  └── More complex to implement                              │
│                                                             │
│  Keras primarily uses data parallelism                      │
│                                                             │
└─────────────────────────────────────────────────────────────┘
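
The "gradients aggregated" step is exactly what tf.distribute automates. Below is a minimal sketch of what one data-parallel training step looks like under MirroredStrategy, assuming a toy Dense model; model.fit, used throughout this page, does all of this for you.

python
import keras
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

# Variables must be created inside the scope so they are mirrored.
with strategy.scope():
    model = keras.Sequential([keras.Input(shape=(784,)),
                              keras.layers.Dense(10)])  # toy model
    optimizer = keras.optimizers.Adam()

@tf.function
def train_step(dist_batch):
    def step_fn(batch):
        x, y = batch
        with tf.GradientTape() as tape:
            logits = model(x, training=True)
            per_example = keras.losses.sparse_categorical_crossentropy(
                y, logits, from_logits=True)
            # Scale by the global batch size so that summing gradients
            # across replicas matches the full-batch average gradient.
            loss = tf.nn.compute_average_loss(per_example)
        grads = tape.gradient(loss, model.trainable_variables)
        # apply_gradients triggers the cross-replica all-reduce.
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return loss

    per_replica_loss = strategy.run(step_fn, args=(dist_batch,))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)

Here dist_batch would come from a dataset wrapped with strategy.experimental_distribute_dataset; the point is only to show where the split, the per-replica gradients, and the aggregation happen.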

MirroredStrategy #

Single-Machine, Multi-GPU #

python
import keras
import tensorflow as tf

# MirroredStrategy replicates the model onto every visible GPU on this
# machine and synchronizes gradients after each step.
strategy = tf.distribute.MirroredStrategy()
print(f'Number of devices: {strategy.num_replicas_in_sync}')

# Example data: MNIST, flattened to 784 features per image.
(x_train, y_train), _ = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32') / 255.0

# Variables (weights, optimizer state) must be created inside the scope.
with strategy.scope():
    model = keras.Sequential([
        keras.Input(shape=(784,)),
        keras.layers.Dense(512, activation='relu'),
        keras.layers.Dropout(0.2),
        keras.layers.Dense(10, activation='softmax')
    ])

    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

# fit() can run outside the scope; each batch is split across the replicas.
model.fit(x_train, y_train, epochs=10, batch_size=64)
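
For larger datasets, a tf.data pipeline is usually preferable to raw NumPy arrays. Batch it at the global batch size and MirroredStrategy splits each batch across the replicas; a sketch, reusing x_train/y_train from above:

python
global_batch = 64 * strategy.num_replicas_in_sync

# Each global batch is split evenly across the replicas.
dataset = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
           .shuffle(10_000)
           .batch(global_batch)
           .prefetch(tf.data.AUTOTUNE))

model.fit(dataset, epochs=10)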

Selecting Specific GPUs #

python
import keras
import tensorflow as tf

# Make only the first two physical GPUs visible to TensorFlow.
# This must run before any op initializes the GPUs.
gpus = tf.config.list_physical_devices('GPU')
tf.config.set_visible_devices(gpus[:2], 'GPU')

# Visible devices are renumbered, so they are addressed as
# /gpu:0 and /gpu:1 regardless of their physical indices.
strategy = tf.distribute.MirroredStrategy(
    devices=['/gpu:0', '/gpu:1']
)
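
A setting commonly paired with this, though not required by MirroredStrategy: enable memory growth so each visible GPU allocates memory on demand instead of reserving all of it upfront. Like set_visible_devices, it must run before the GPUs are initialized.

python
import tensorflow as tf

# Must run before any op touches the GPUs.
for gpu in tf.config.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, True)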

MultiWorkerMirroredStrategy #

Training Across Multiple Machines #

python
import keras
import tensorflow as tf

# TF_CONFIG (see "Configuring the Cluster" below) must already be set
# in the environment of every worker before this line runs.
strategy = tf.distribute.MultiWorkerMirroredStrategy()

with strategy.scope():
    model = keras.Sequential([
        keras.Input(shape=(784,)),
        keras.layers.Dense(512, activation='relu'),
        keras.layers.Dense(10, activation='softmax')
    ])

    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

# Every worker runs the same script and the same fit() call;
# gradients are synchronized across machines after each step.
model.fit(x_train, y_train, epochs=10, batch_size=64)
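
With several workers it is usually better to feed a tf.data.Dataset, which tf.distribute shards across machines; the sharding policy can be set explicitly. A sketch, assuming x_train/y_train are available on every worker:

python
options = tf.data.Options()
# DATA sharding: every worker iterates the whole dataset but keeps
# only its assigned shard of the elements.
options.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardPolicy.DATA)

dataset = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
           .shuffle(10_000)
           .batch(64 * strategy.num_replicas_in_sync)
           .with_options(options))

model.fit(dataset, epochs=10)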

Configuring the Cluster #

python
import os

# Set this on every worker before creating the strategy. All workers
# share the same "cluster" block; only "index" differs per worker.
os.environ['TF_CONFIG'] = '''
{
    "cluster": {
        "worker": ["host1:port", "host2:port"]
    },
    "task": {"type": "worker", "index": 0}
}
'''
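
In practice it is less error-prone to build the string with json.dumps. A sketch for a hypothetical two-worker cluster; the addresses are placeholders:

python
import json
import os

# Placeholder addresses; use your workers' real host:port pairs.
cluster = {"worker": ["10.0.0.1:12345", "10.0.0.2:12345"]}

os.environ['TF_CONFIG'] = json.dumps({
    "cluster": cluster,
    "task": {"type": "worker", "index": 0},  # this process is worker 0 (the chief)
})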

TPUStrategy #

python
import keras
import tensorflow as tf

# Locate the TPU cluster and initialize the TPU system; in managed
# environments such as Colab, the default resolver typically finds
# the TPU automatically.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)

strategy = tf.distribute.TPUStrategy(resolver)

with strategy.scope():
    model = keras.Sequential([
        keras.Input(shape=(784,)),
        keras.layers.Dense(512, activation='relu'),
        keras.layers.Dense(10, activation='softmax')
    ])

    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )
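
TPUs compile with XLA, which needs static batch shapes, so feed a tf.data pipeline batched with drop_remainder=True. A sketch, assuming flattened x_train/y_train as in the earlier examples:

python
batch_size = 128 * strategy.num_replicas_in_sync

train_ds = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
            .shuffle(10_000)
            # drop_remainder keeps every batch the same (static) shape.
            .batch(batch_size, drop_remainder=True)
            .prefetch(tf.data.AUTOTUNE))

model.fit(train_ds, epochs=10)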

Choosing the Batch Size #

text
┌─────────────────────────────────────────────────────────────┐
│                    Batch Size Guidelines                    │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Global batch size = per-GPU batch size × number of GPUs    │
│                                                             │
│  Example:                                                   │
│  Per-GPU batch size: 32                                     │
│  Number of GPUs: 4                                          │
│  Global batch size: 128                                     │
│                                                             │
│  Caveats:                                                   │
│  ├── The learning rate may need scaling (see below)         │
│  ├── Very large batches can hurt generalization             │
│  └── Make sure you have enough data                         │
│                                                             │
└─────────────────────────────────────────────────────────────┘
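
For the learning-rate caveat, a common heuristic is the linear scaling rule: multiply the learning rate tuned on one GPU by the number of replicas. A sketch, assuming the model and strategy from the earlier examples; treat it as a starting point to validate, not a guarantee:

python
base_lr = 1e-3                                        # tuned on a single GPU
scaled_lr = base_lr * strategy.num_replicas_in_sync   # linear scaling rule

with strategy.scope():
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=scaled_lr),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )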

Complete Example #

python
import keras
import tensorflow as tf

# Load CIFAR-10 and scale pixel values to [0, 1].
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

strategy = tf.distribute.MirroredStrategy()
print(f'Training on {strategy.num_replicas_in_sync} GPU(s)')

# Scale the global batch size with the number of replicas (64 per GPU).
batch_size = 64 * strategy.num_replicas_in_sync

with strategy.scope():
    model = keras.Sequential([
        keras.Input(shape=(32, 32, 3)),
        keras.layers.Conv2D(32, 3, padding='same', activation='relu'),
        keras.layers.BatchNormalization(),
        keras.layers.Conv2D(32, 3, activation='relu'),
        keras.layers.BatchNormalization(),
        keras.layers.MaxPooling2D(),
        keras.layers.Dropout(0.25),

        keras.layers.Conv2D(64, 3, padding='same', activation='relu'),
        keras.layers.BatchNormalization(),
        keras.layers.Conv2D(64, 3, activation='relu'),
        keras.layers.BatchNormalization(),
        keras.layers.MaxPooling2D(),
        keras.layers.Dropout(0.25),

        keras.layers.Flatten(),
        keras.layers.Dense(512, activation='relu'),
        keras.layers.BatchNormalization(),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(10, activation='softmax')
    ])

    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=0.001),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

history = model.fit(
    x_train, y_train,
    validation_data=(x_test, y_test),
    epochs=50,
    batch_size=batch_size,
    callbacks=[
        keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True),
        keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=5)
    ]
)
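
After training, the mirrored model behaves like a normal single-device model, so evaluation and saving need nothing special (the filename below is just an example):

python
model.evaluate(x_test, y_test, batch_size=batch_size)
model.save('cifar10_cnn.keras')  # example filename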

Next Steps #

Now that you have multi-GPU training down, continue with Image Classification in Practice and apply what you have learned!

Last updated: 2026-04-04