Data Parallelism #
Overview #
Data parallelism is the most common strategy in distributed training: the data is split across multiple devices, each device computes in parallel, and the gradients are then aggregated to update the model.
How Data Parallelism Works #
text
┌──────────────────────────────────────────────────────────┐
│               Data-Parallel Training Flow                │
├──────────────────────────────────────────────────────────┤
│                                                          │
│  1. Split the data                                       │
│     Batch → [Batch/N, Batch/N, ..., Batch/N]             │
│                                                          │
│  2. Parallel forward pass                                │
│     Each device runs the forward pass independently      │
│                                                          │
│  3. Parallel backward pass                               │
│     Each device computes its gradients independently     │
│                                                          │
│  4. Gradient aggregation                                 │
│     All-Reduce averages gradients across devices         │
│                                                          │
│  5. Parameter update                                     │
│     Every device applies the same averaged gradients     │
│                                                          │
└──────────────────────────────────────────────────────────┘
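In practice, step 1 means reshaping a host-side global batch so its leading axis equals the device count before handing it to pmap. A minimal sketch (the shapes here are illustrative, and the batch size must divide evenly by the device count):
python
import jax
import jax.numpy as jnp

num_devices = jax.device_count()
global_batch = jnp.zeros((256, 784))  # hypothetical global batch

# (256, 784) -> (num_devices, 256 // num_devices, 784): one slice per device.
per_device_batch = global_batch.reshape(num_devices, -1, *global_batch.shape[1:])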
Using pmap #
Basic Data Parallelism #
python
import functools

import jax
import jax.numpy as jnp

# pmap needs an explicit axis_name so lax.pmean can refer to it.
@functools.partial(jax.pmap, axis_name='device')
def train_step(params, batch):
    def loss_fn(p):
        x, y = batch
        pred = jnp.dot(x, p)
        return jnp.mean((pred - y) ** 2)

    grads = jax.grad(loss_fn)(params)
    # Average gradients across devices so every replica applies the same update.
    grads = jax.lax.pmean(grads, 'device')
    return jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)

key = jax.random.PRNGKey(0)
params = jax.random.normal(key, (10, 5))

num_devices = jax.device_count()
# Replicate the parameters onto every device (adds a leading device axis).
params = jax.device_put_replicated(params, jax.devices())

# Each input's leading axis must equal the device count.
batch_x = jax.random.normal(key, (num_devices, 32, 10))
batch_y = jax.random.normal(key, (num_devices, 32, 5))
batch = (batch_x, batch_y)

for step in range(10):
    params = train_step(params, batch)
Complete Training Example #
python
import functools

import jax
import jax.numpy as jnp
import jax.nn as nn

def init_mlp_params(key, layer_sizes):
    params = []
    for in_size, out_size in layer_sizes:
        key, w_key = jax.random.split(key)
        w = jax.random.normal(w_key, (in_size, out_size)) * 0.01
        b = jnp.zeros(out_size)
        params.append({'w': w, 'b': b})
    return params

def forward(params, x):
    for layer in params[:-1]:
        x = jnp.dot(x, layer['w']) + layer['b']
        x = nn.relu(x)
    return jnp.dot(x, params[-1]['w']) + params[-1]['b']

def loss_fn(params, batch):
    x, y = batch
    logits = forward(params, x)
    log_probs = nn.log_softmax(logits)
    one_hot = nn.one_hot(y, logits.shape[-1])
    return -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))

@functools.partial(jax.pmap, axis_name='device')
def train_step(params, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # Average both the gradients and the reported loss across devices.
    grads = jax.lax.pmean(grads, 'device')
    loss = jax.lax.pmean(loss, 'device')
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - 0.001 * g,
        params, grads
    )
    return new_params, loss

key = jax.random.PRNGKey(0)
params = init_mlp_params(key, [(784, 256), (256, 128), (128, 10)])

num_devices = jax.device_count()
params = jax.device_put_replicated(params, jax.devices())

def create_batch(key, batch_size, num_devices):
    x_key, y_key = jax.random.split(key)
    batch_x = jax.random.normal(x_key, (num_devices, batch_size, 784))
    batch_y = jax.random.randint(y_key, (num_devices, batch_size), 0, 10)
    return batch_x, batch_y

for epoch in range(10):
    key, subkey = jax.random.split(key)
    batch = create_batch(subkey, 32, num_devices)
    params, loss = train_step(params, batch)
    # loss is identical on every device after pmean; read replica 0.
    print(f"Epoch {epoch}: loss={loss[0]:.4f}")
Using pjit #
pjit Basics #
python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1-D mesh spanning all available devices.
mesh = Mesh(jax.devices(), ('devices',))

# In recent JAX, jit subsumes pjit: with sharded inputs, the computation
# is partitioned across devices automatically.
@jax.jit
def train_step_pjit(params, batch):
    x, y = batch
    def loss_fn(p):
        pred = jnp.dot(x, p)
        return jnp.mean((pred - y) ** 2)
    grads = jax.grad(loss_fn)(params)
    return jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)

key = jax.random.PRNGKey(0)
params = jax.random.normal(key, (10, 5))
batch_x = jax.random.normal(key, (32, 10))
batch_y = jax.random.normal(key, (32, 5))

# Shard the batch along its leading axis and replicate the parameters.
# The batch size (32) must be divisible by the device count.
batch_sharding = NamedSharding(mesh, P('devices'))
replicated = NamedSharding(mesh, P())
batch_x = jax.device_put(batch_x, batch_sharding)
batch_y = jax.device_put(batch_y, batch_sharding)
params = jax.device_put(params, replicated)

new_params = train_step_pjit(params, (batch_x, batch_y))
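To check how jit actually laid the arrays out, jax.debug.visualize_array_sharding prints a per-device map of an array (a quick sanity check; it assumes a multi-device runtime):
python
# Show which devices hold which rows of the sharded batch.
jax.debug.visualize_array_sharding(batch_x)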
pjit with Sharding #
python
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

def create_sharded_params(key, shape, mesh):
    # An empty PartitionSpec replicates the parameters on every device.
    params = jax.random.normal(key, shape)
    sharding = NamedSharding(mesh, P())
    return jax.device_put(params, sharding)

def create_sharded_batch(key, batch_size, input_size, mesh):
    # Split the batch along its leading axis across the 'devices' mesh axis.
    x = jax.random.normal(key, (batch_size, input_size))
    sharding = NamedSharding(mesh, P('devices'))
    return jax.device_put(x, sharding)
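A minimal usage sketch for the two helpers above, assuming a 1-D mesh whose axis name matches the 'devices' used in their specs (and a batch size divisible by the device count):
python
from jax.sharding import Mesh

mesh = Mesh(jax.devices(), ('devices',))
key = jax.random.PRNGKey(0)

params = create_sharded_params(key, (784, 256), mesh)
batch = create_sharded_batch(key, 64, 784, mesh)

print(params.sharding)  # replicated across the mesh
print(batch.sharding)   # split along the batch axis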
Gradient Aggregation Strategies #
All-Reduce #
python
import functools

import jax
import jax.numpy as jnp

@functools.partial(jax.pmap, axis_name='i')
def all_reduce_example(x):
    # Sum x across all devices; every device receives the full total.
    return jax.lax.psum(x, 'i')

x = jnp.arange(jax.device_count())
result = all_reduce_example(x)
print(f"All-Reduce result: {result}")
All-Gather #
python
import functools

import jax
import jax.numpy as jnp

@functools.partial(jax.pmap, axis_name='i')
def all_gather_example(x):
    # Every device receives the concatenation of all per-device values.
    return jax.lax.all_gather(x, 'i')

x = jnp.arange(jax.device_count())
result = all_gather_example(x)
print(f"All-Gather result: {result}")
Performance Optimization #
Overlapping Computation and Communication #
Keeping the gradient collective inside the compiled step gives XLA's scheduler the freedom to overlap the pmean communication with computation that does not depend on it:
python
import functools

import jax
import jax.numpy as jnp

@functools.partial(jax.pmap, axis_name='device')
def optimized_train_step(params, batch):
    def loss_fn(p):
        x, y = batch
        pred = jnp.dot(x, p)
        return jnp.mean((pred - y) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    # Collectives stay inside the compiled step, so XLA can schedule them
    # to overlap with independent computation.
    grads = jax.lax.pmean(grads, 'device')
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - 0.01 * g,
        params, grads
    )
    return new_params, jax.lax.pmean(loss, 'device')
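A related optimization is buffer donation: marking inputs that the step no longer needs (such as the old parameters) so XLA may reuse their memory for the outputs. A sketch using pmap's donate_argnums (donation is a hint and is ignored on some backends, e.g. CPU):
python
import functools

import jax
import jax.numpy as jnp

# Donate the params buffer so XLA can write the updated params in place.
@functools.partial(jax.pmap, axis_name='device', donate_argnums=0)
def donating_train_step(params, batch):
    def loss_fn(p):
        x, y = batch
        return jnp.mean((jnp.dot(x, p) - y) ** 2)
    grads = jax.lax.pmean(jax.grad(loss_fn)(params), 'device')
    return jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)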
Gradient Accumulation #
python
import functools

import jax
import jax.numpy as jnp

@functools.partial(jax.pmap, axis_name='device')
def accumulate_gradients(params, batches):
    # batches is an (xs, ys) pair whose per-device leading axis is the
    # number of accumulation steps: xs has shape (num_steps, batch, ...).
    def scan_fn(grads, batch):
        def loss_fn(p):
            x, y = batch
            return jnp.mean((jnp.dot(x, p) - y) ** 2)
        batch_grads = jax.grad(loss_fn)(params)
        return jax.tree_util.tree_map(lambda g1, g2: g1 + g2, grads, batch_grads), None

    init_grads = jax.tree_util.tree_map(jnp.zeros_like, params)
    final_grads, _ = jax.lax.scan(scan_fn, init_grads, batches)
    # Average over devices, then over accumulation steps.
    final_grads = jax.lax.pmean(final_grads, 'device')
    num_steps = batches[0].shape[0]
    final_grads = jax.tree_util.tree_map(lambda g: g / num_steps, final_grads)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - 0.01 * g,
        params, final_grads
    )
    return new_params
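A sketch of how inputs to accumulate_gradients might be shaped: each array carries a device axis for pmap and a step axis for the scan, so xs has shape (num_devices, num_steps, batch_size, features). The sizes below are illustrative:
python
import jax
import jax.numpy as jnp

num_devices = jax.device_count()
num_steps = 4    # gradient-accumulation steps per update (hypothetical)
batch_size = 32

key = jax.random.PRNGKey(0)
xs = jax.random.normal(key, (num_devices, num_steps, batch_size, 10))
ys = jax.random.normal(key, (num_devices, num_steps, batch_size, 5))

params = jax.device_put_replicated(jax.random.normal(key, (10, 5)), jax.devices())
new_params = accumulate_gradients(params, (xs, ys))  # defined above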
Next Steps #
Now that you have data parallelism under your belt, continue to Model Parallelism to learn how to handle large models!
Last updated: 2026-04-04