Distributed Training #

Overview #

Distributed training parallelizes model training across multiple devices or machines, which speeds up training and makes larger datasets tractable.
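
Before choosing a strategy, it helps to confirm which accelerators TensorFlow can actually see; a quick sanity check:

python
import tensorflow as tf

# List the accelerators visible to this TensorFlow process
gpus = tf.config.list_physical_devices('GPU')
tpus = tf.config.list_logical_devices('TPU')
print(f"GPUs: {len(gpus)}, TPUs: {len(tpus)}")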

Distribution Strategies #

text
┌─────────────────────────────────────────────────────────────┐
│                   Distribution Strategies                   │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  MirroredStrategy                                           │
│  ├── Single machine, multiple GPUs                          │
│  ├── Synchronous training                                   │
│  └── The most common choice                                 │
│                                                             │
│  MultiWorkerMirroredStrategy                                │
│  ├── Multiple machines, multiple GPUs                       │
│  ├── Synchronous training                                   │
│  └── Cluster training                                       │
│                                                             │
│  TPUStrategy                                                │
│  ├── TPU devices                                            │
│  ├── High performance                                       │
│  └── Google Cloud                                           │
│                                                             │
│  ParameterServerStrategy                                    │
│  ├── Parameter-server architecture                          │
│  ├── Asynchronous training                                  │
│  └── Large-scale training                                   │
│                                                             │
└─────────────────────────────────────────────────────────────┘
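
Which strategy applies depends on the hardware at hand. A small helper like the following (a sketch, not part of tf.distribute) is one common way to fall back from TPU to multi-GPU to the default strategy:

python
import tensorflow as tf

def choose_strategy():
    """Hypothetical helper: prefer TPU, then multi-GPU, else the default."""
    try:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        return tf.distribute.TPUStrategy(resolver)
    except ValueError:
        pass  # no TPU available
    if len(tf.config.list_physical_devices('GPU')) > 1:
        return tf.distribute.MirroredStrategy()
    return tf.distribute.get_strategy()  # default single-device strategy

strategy = choose_strategy()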

MirroredStrategy #

Basic usage #

python
import tensorflow as tf
import numpy as np

# MirroredStrategy replicates the model on every visible GPU
strategy = tf.distribute.MirroredStrategy()
print(f"Number of replicas: {strategy.num_replicas_in_sync}")

with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation='relu', input_shape=(784,)),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

# Dummy data for illustration
x_train = np.random.random((1000, 784)).astype(np.float32)
y_train = np.random.randint(10, size=(1000,))

# The global batch is split across replicas, so scale it by the replica count
batch_size = 64 * strategy.num_replicas_in_sync
model.fit(x_train, y_train, epochs=10, batch_size=batch_size)

Selecting specific GPUs #

python
import tensorflow as tf

# Use specific GPUs only
strategy = tf.distribute.MirroredStrategy(devices=['/gpu:0', '/gpu:1'])

# Use all visible GPUs (the default)
strategy = tf.distribute.MirroredStrategy()
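
If the process should only ever see a subset of the GPUs, an alternative (a sketch; it must run before the GPUs are initialized) is to limit device visibility instead of passing devices=:

python
import tensorflow as tf

# Hide all but the first two GPUs from TensorFlow, then mirror across them
gpus = tf.config.list_physical_devices('GPU')
if len(gpus) >= 2:
    tf.config.set_visible_devices(gpus[:2], 'GPU')

strategy = tf.distribute.MirroredStrategy()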

Custom training loop #

python
import tensorflow as tf
import numpy as np

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation='relu', input_shape=(784,)),
        tf.keras.layers.Dense(10)
    ])

    optimizer = tf.keras.optimizers.Adam()
    # Reduction.NONE: average manually against the GLOBAL batch size, so
    # summing gradients across replicas yields the correct overall scale
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

# Dummy data for illustration
x_train = np.random.random((1000, 784)).astype(np.float32)
y_train = np.random.randint(10, size=(1000,))

global_batch_size = 64
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(1000).batch(global_batch_size)
dist_dataset = strategy.experimental_distribute_dataset(train_dataset)

@tf.function
def train_step(dist_inputs):
    def step_fn(inputs):
        x, y = inputs
        with tf.GradientTape() as tape:
            logits = model(x, training=True)
            per_example_loss = loss_fn(y, logits)
            loss = tf.nn.compute_average_loss(
                per_example_loss, global_batch_size=global_batch_size)

        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        return loss

    per_replica_losses = strategy.run(step_fn, args=(dist_inputs,))
    # Summing the per-replica (global-batch-averaged) losses gives the
    # mean loss over the whole global batch
    return strategy.reduce(tf.distribute.ReduceOp.SUM,
                           per_replica_losses, axis=None)

for epoch in range(10):
    total_loss, num_batches = 0.0, 0
    for dist_inputs in dist_dataset:
        total_loss += train_step(dist_inputs)
        num_batches += 1
    print(f"Epoch {epoch + 1}: Loss = {float(total_loss) / num_batches:.4f}")
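
The same pattern extends to evaluation. A sketch of a distributed test step, reusing the model, strategy, and dist_dataset from the loop above and adding an accuracy metric:

python
import tensorflow as tf

with strategy.scope():
    # Metrics created in scope aggregate their state across replicas
    accuracy = tf.keras.metrics.SparseCategoricalAccuracy()

@tf.function
def test_step(dist_inputs):
    def step_fn(inputs):
        x, y = inputs
        logits = model(x, training=False)
        accuracy.update_state(y, logits)
    strategy.run(step_fn, args=(dist_inputs,))

accuracy.reset_state()
for dist_inputs in dist_dataset:
    test_step(dist_inputs)
print(f"Accuracy: {float(accuracy.result()):.4f}")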

MultiWorkerMirroredStrategy #

Cluster configuration #

python
import tensorflow as tf
import numpy as np
import os
import json

# TF_CONFIG must be set before the strategy is constructed
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["192.168.1.1:12345", "192.168.1.2:12345"]
    },
    'task': {'type': 'worker', 'index': 0}
})

strategy = tf.distribute.MultiWorkerMirroredStrategy()

with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation='relu', input_shape=(784,)),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

# Dummy data for illustration
x_train = np.random.random((1000, 784)).astype(np.float32)
y_train = np.random.randint(10, size=(1000,))

per_worker_batch_size = 64
global_batch_size = per_worker_batch_size * strategy.num_replicas_in_sync

# Batch by the global size; the dataset is sharded across workers automatically
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.batch(global_batch_size)

model.fit(train_dataset, epochs=10)
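
When the input is not file-based, tf.data may warn that it cannot auto-shard by file; switching the auto-shard policy to DATA before calling fit avoids that (a sketch applied to the dataset above):

python
import tensorflow as tf

options = tf.data.Options()
# DATA sharding: every worker reads the full dataset but keeps 1/N of the
# elements; the default FILE policy requires a file-based dataset
options.experimental_distribute.auto_shard_policy = \
    tf.data.experimental.AutoShardPolicy.DATA
train_dataset = train_dataset.with_options(options)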

Launching multi-worker training #

bash
# On worker 0
export TF_CONFIG='{"cluster":{"worker":["localhost:12345","localhost:12346"]},"task":{"type":"worker","index":0}}'
python train.py

# On worker 1
export TF_CONFIG='{"cluster":{"worker":["localhost:12345","localhost:12346"]},"task":{"type":"worker","index":1}}'
python train.py
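
For long multi-worker jobs it is also worth enabling fault tolerance, so a restarted worker resumes instead of starting over. One option is the BackupAndRestore callback (a sketch; the backup path is an example):

python
import tensorflow as tf

# Checkpoints training state and resumes automatically after a worker restart
callbacks = [tf.keras.callbacks.BackupAndRestore(backup_dir='/tmp/backup')]
model.fit(train_dataset, epochs=10, callbacks=callbacks)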

TPUStrategy #

Connecting to a TPU #

python
import tensorflow as tf

# Connect to and initialize the TPU system ('your-tpu-address' is a placeholder)
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='your-tpu-address')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)

strategy = tf.distribute.TPUStrategy(resolver)
print(f"Number of TPU cores: {strategy.num_replicas_in_sync}")

with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation='relu', input_shape=(784,)),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )
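
TPUs compile programs for static shapes, so input pipelines for TPUStrategy should batch with drop_remainder=True to keep every batch the same size (a sketch with placeholder data):

python
import tensorflow as tf
import numpy as np

x_train = np.random.random((1000, 784)).astype(np.float32)
y_train = np.random.randint(10, size=(1000,))

# drop_remainder=True discards the final partial batch, keeping shapes static
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.batch(128, drop_remainder=True)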

Google Colab TPU #

python
import tensorflow as tf

try:
    # In Colab, TPUClusterResolver() finds the runtime's TPU automatically
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)
except ValueError:
    # No TPU available: fall back to the default (single-device) strategy
    strategy = tf.distribute.get_strategy()

print(f"Number of replicas: {strategy.num_replicas_in_sync}")

ParameterServerStrategy #

Basic configuration #

python
import tensorflow as tf
import os
import json

# The coordinator's TF_CONFIG: the cluster lists workers, parameter
# servers ('ps'), and the coordinator itself ('chief')
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:12345", "localhost:12346"],
        'ps': ["localhost:12347"],
        'chief': ["localhost:12348"]
    },
    'task': {'type': 'chief', 'index': 0}
})

# ParameterServerStrategy requires an explicit cluster resolver
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
strategy = tf.distribute.ParameterServerStrategy(cluster_resolver)

with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation='relu', input_shape=(784,)),
        tf.keras.layers.Dense(10)
    ])

    optimizer = tf.keras.optimizers.Adam()
    model.compile(optimizer=optimizer, loss='mse')
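
The snippet above runs on the coordinator (the 'chief' task). The 'worker' and 'ps' tasks typically just start a server and block, waiting for the coordinator to dispatch work (a sketch; each task sets its own TF_CONFIG):

python
import tensorflow as tf

cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()

if cluster_resolver.task_type in ('worker', 'ps'):
    # Start a gRPC server for this task and wait forever
    server = tf.distribute.Server(
        cluster_resolver.cluster_spec(),
        job_name=cluster_resolver.task_type,
        task_index=cluster_resolver.task_id,
        protocol='grpc',
        start=True)
    server.join()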

Distributed Datasets #

Distributing a dataset #

python
import tensorflow as tf
import numpy as np

strategy = tf.distribute.MirroredStrategy()

global_batch_size = 64 * strategy.num_replicas_in_sync

# Dummy data for illustration
x_train = np.random.random((1000, 784)).astype(np.float32)
y_train = np.random.randint(10, size=(1000,))

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(10000).batch(global_batch_size)

# Distribute the dataset: each replica receives a slice of every global batch
dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
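
For finer control over per-replica batching and sharding, distribute_datasets_from_function hands each input pipeline an InputContext (a sketch, reusing x_train/y_train, strategy, and global_batch_size from above):

python
import tensorflow as tf

def dataset_fn(input_context):
    # Each pipeline gets its per-replica batch size and its shard index
    batch_size = input_context.get_per_replica_batch_size(global_batch_size)
    dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    dataset = dataset.shard(input_context.num_input_pipelines,
                            input_context.input_pipeline_id)
    return dataset.batch(batch_size)

dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)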

Distributed iterators #

python
import tensorflow as tf

# Reuses the strategy, model, optimizer, and loss_fn (Reduction.NONE)
# from the custom training loop above
global_batch_size = 64
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.batch(global_batch_size)
dist_dataset = strategy.experimental_distribute_dataset(train_dataset)

@tf.function
def train_step(iterator):
    def step_fn(inputs):
        x, y = inputs
        with tf.GradientTape() as tape:
            predictions = model(x, training=True)
            per_example_loss = loss_fn(y, predictions)
            loss = tf.nn.compute_average_loss(
                per_example_loss, global_batch_size=global_batch_size)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        return loss

    # next() on a distributed iterator works inside a tf.function
    per_replica_losses = strategy.run(step_fn, args=(next(iterator),))
    return strategy.reduce(tf.distribute.ReduceOp.SUM,
                           per_replica_losses, axis=None)

steps_per_epoch = len(x_train) // global_batch_size
for epoch in range(10):
    # A distributed iterator is exhausted after one pass; recreate it per epoch
    iterator = iter(dist_dataset)
    for _ in range(steps_per_epoch):
        loss = train_step(iterator)
    print(f"Epoch {epoch + 1}: Loss = {float(loss):.4f}")

Distributed metrics #

python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation='relu', input_shape=(784,)),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

# Metrics are aggregated across replicas automatically
# (train_dataset as built in the sections above)
history = model.fit(train_dataset, epochs=10)
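
Evaluation behaves the same way; model.evaluate runs on all replicas and returns aggregated metrics (a sketch with hypothetical held-out data):

python
import tensorflow as tf
import numpy as np

# Hypothetical held-out data for illustration
x_test = np.random.random((200, 784)).astype(np.float32)
y_test = np.random.randint(10, size=(200,))

test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(64)
loss, accuracy = model.evaluate(test_dataset)
print(f"Test accuracy: {accuracy:.4f}")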

Saving and Restoring #

Saving a distributed model #

python
import tensorflow as tf
import numpy as np

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation='relu', input_shape=(784,)),
        tf.keras.layers.Dense(10)
    ])
    
    model.compile(optimizer='adam', loss='mse')

# Dummy data for illustration
x_train = np.random.random((1000, 784)).astype(np.float32)
y_train = np.random.random((1000, 10)).astype(np.float32)

# Train
model.fit(x_train, y_train, epochs=10)

# Save: on multi-worker setups only the chief should write the final model
# (check for None first; single-machine strategies have no cluster_resolver)
if strategy.cluster_resolver is None or strategy.cluster_resolver.task_type == 'chief':
    model.save('model.keras')

# Load inside the scope so the restored variables are distributed
with strategy.scope():
    loaded_model = tf.keras.models.load_model('model.keras')
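
For long-running jobs, tf.train.Checkpoint is often more convenient than whole-model saves, since it can be written periodically and restored mid-training (a sketch; the directory is an example):

python
import tensorflow as tf

checkpoint = tf.train.Checkpoint(model=model)
manager = tf.train.CheckpointManager(checkpoint, directory='/tmp/ckpt',
                                     max_to_keep=3)

# Save periodically during training...
manager.save()

# ...and restore inside the scope to resume where training left off
with strategy.scope():
    checkpoint.restore(manager.latest_checkpoint)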

Best Practices #

Batch size #

python
import tensorflow as tf
import numpy as np

strategy = tf.distribute.MirroredStrategy()

# Keep the per-GPU batch fixed and grow the global batch with the replica count
per_gpu_batch_size = 64
global_batch_size = per_gpu_batch_size * strategy.num_replicas_in_sync

# Dummy data for illustration
x_train = np.random.random((1000, 784)).astype(np.float32)
y_train = np.random.randint(10, size=(1000,))

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.batch(global_batch_size)

Learning-rate scaling #

python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

# Linear scaling rule: grow the learning rate with the number of replicas,
# since the global batch grows by the same factor
base_learning_rate = 0.001
scaled_learning_rate = base_learning_rate * strategy.num_replicas_in_sync

with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation='relu', input_shape=(784,)),
        tf.keras.layers.Dense(10)
    ])
    
    optimizer = tf.keras.optimizers.Adam(learning_rate=scaled_learning_rate)
    model.compile(optimizer=optimizer, loss='mse')
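
Linear scaling usually pairs with a warmup phase, so the enlarged rate does not destabilize the first updates. A minimal sketch of a linear-warmup schedule (the 1000-step warmup is an arbitrary example):

python
import tensorflow as tf

class LinearWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Ramp the learning rate from 0 to target_lr over warmup_steps."""

    def __init__(self, target_lr, warmup_steps):
        self.target_lr = target_lr
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        return tf.minimum(self.target_lr * step / self.warmup_steps,
                          self.target_lr)

optimizer = tf.keras.optimizers.Adam(
    learning_rate=LinearWarmup(scaled_learning_rate, warmup_steps=1000))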

Next Steps #

Now that you have distributed training down, continue to CNN image classification and start putting it to work on a real project!

Last updated: 2026-04-04