Data Parallelism #

Overview #

Data parallelism is the most common strategy in distributed training: the input batch is split across multiple devices, each device computes the forward and backward pass independently, and the gradients are then aggregated to update a shared copy of the model.

How Data Parallelism Works #

text
┌─────────────────────────────────────────────────────────────┐
│                 Data-parallel training flow                 │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  1. Split the data                                          │
│     Batch → [Batch/N, Batch/N, ..., Batch/N]                │
│                                                             │
│  2. Parallel forward pass                                   │
│     Each device runs the forward pass independently         │
│                                                             │
│  3. Parallel backward pass                                  │
│     Each device computes gradients independently            │
│                                                             │
│  4. Gradient aggregation                                    │
│     All-Reduce averages the gradients                       │
│                                                             │
│  5. Parameter update                                        │
│     Every device applies the same averaged gradients        │
│                                                             │
└─────────────────────────────────────────────────────────────┘

Using pmap #

Basic Data Parallelism #

python
import functools

import jax
import jax.numpy as jnp

# axis_name must be declared on pmap so that pmean('device') can refer to it
@functools.partial(jax.pmap, axis_name='device')
def train_step(params, batch):
    def loss_fn(p):
        x, y = batch
        pred = jnp.dot(x, p)
        return jnp.mean((pred - y) ** 2)

    # Per-device gradients
    grads = jax.grad(loss_fn)(params)

    # Average the gradients across devices (All-Reduce)
    grads = jax.lax.pmean(grads, 'device')

    # SGD update; every device applies the same averaged gradients
    new_params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)

    return new_params

key = jax.random.PRNGKey(0)
params = jax.random.normal(key, (10, 5))

# Replicate the parameters onto every device
num_devices = jax.device_count()
params = jax.device_put_replicated(params, jax.devices())

# Per-device batches: the leading axis must equal the number of devices
key, x_key, y_key = jax.random.split(key, 3)
batch_x = jax.random.normal(x_key, (num_devices, 32, 10))
batch_y = jax.random.normal(y_key, (num_devices, 32, 5))
batch = (batch_x, batch_y)

for step in range(10):
    params = train_step(params, batch)
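
pmap consumes the leading axis of every input as the device axis, so a global batch must first be reshaped into per-device shards. A minimal sketch, assuming the global batch size is an exact multiple of the device count:

python
# Reshape a global batch of num_devices * 32 examples into per-device shards
key, gx_key, gy_key = jax.random.split(key, 3)
global_x = jax.random.normal(gx_key, (num_devices * 32, 10))
global_y = jax.random.normal(gy_key, (num_devices * 32, 5))

per_device_x = global_x.reshape(num_devices, 32, 10)
per_device_y = global_y.reshape(num_devices, 32, 5)

params = train_step(params, (per_device_x, per_device_y))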

Complete Training Example #

python
import functools

import jax
import jax.numpy as jnp
import jax.nn as nn

def init_mlp_params(key, layer_sizes):
    """Initialize a list of {'w', 'b'} dicts, one per layer."""
    params = []
    for in_size, out_size in layer_sizes:
        key, w_key = jax.random.split(key)
        w = jax.random.normal(w_key, (in_size, out_size)) * 0.01
        b = jnp.zeros(out_size)
        params.append({'w': w, 'b': b})
    return params

def forward(params, x):
    for layer in params[:-1]:
        x = jnp.dot(x, layer['w']) + layer['b']
        x = nn.relu(x)
    x = jnp.dot(x, params[-1]['w']) + params[-1]['b']
    return x

def loss_fn(params, batch):
    x, y = batch
    logits = forward(params, x)
    log_probs = nn.log_softmax(logits)
    one_hot = nn.one_hot(y, logits.shape[-1])
    return -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))

@functools.partial(jax.pmap, axis_name='device')
def train_step(params, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)

    # Average loss and gradients across devices
    grads = jax.lax.pmean(grads, 'device')
    loss = jax.lax.pmean(loss, 'device')

    new_params = jax.tree_util.tree_map(
        lambda p, g: p - 0.001 * g,
        params, grads
    )

    return new_params, loss

key = jax.random.PRNGKey(0)
key, init_key = jax.random.split(key)
params = init_mlp_params(init_key, [(784, 256), (256, 128), (128, 10)])

# Replicate the parameters onto every device
num_devices = jax.device_count()
params = jax.device_put_replicated(params, jax.devices())

def create_batch(key, batch_size, num_devices):
    x_key, y_key = jax.random.split(key)
    batch_x = jax.random.normal(x_key, (num_devices, batch_size, 784))
    batch_y = jax.random.randint(y_key, (num_devices, batch_size), 0, 10)
    return batch_x, batch_y

for epoch in range(10):
    key, subkey = jax.random.split(key)
    batch = create_batch(subkey, 32, num_devices)
    params, loss = train_step(params, batch)
    # loss is replicated across devices; print the copy from device 0
    print(f"Epoch {epoch}: loss={loss[0]:.4f}")

Using pjit #

pjit Basics #

python
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from jax.sharding import NamedSharding

# A 1-D device mesh with a single named axis, 'devices'
mesh = jax.sharding.Mesh(
    jax.devices(),
    ('devices',)
)

# pjit has been merged into jax.jit; jax.jit now accepts and propagates shardings
@jax.jit
def train_step_pjit(params, batch):
    x, y = batch

    def loss_fn(p):
        pred = jnp.dot(x, p)
        return jnp.mean((pred - y) ** 2)

    grads = jax.grad(loss_fn)(params)
    return jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)

key = jax.random.PRNGKey(0)
key, p_key, x_key, y_key = jax.random.split(key, 4)
params = jax.random.normal(p_key, (10, 5))
batch_x = jax.random.normal(x_key, (32, 10))
batch_y = jax.random.normal(y_key, (32, 5))

new_params = train_step_pjit(params, (batch_x, batch_y))
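
Without explicit shardings, the call above behaves like an ordinary single-device jit. A minimal sketch of running the same step data-parallel over the mesh, assuming the batch size (32) is divisible by the device count; data_sharding and replicated are local names introduced here, not JAX API:

python
# Shard the batch along its leading axis; keep the parameters replicated
data_sharding = NamedSharding(mesh, P('devices'))
replicated = NamedSharding(mesh, P())

sharded_params = jax.device_put(params, replicated)
sharded_x = jax.device_put(batch_x, data_sharding)
sharded_y = jax.device_put(batch_y, data_sharding)

# jax.jit propagates the input shardings through the computation
new_params = train_step_pjit(sharded_params, (sharded_x, sharded_y))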

pjit with Sharding #

python
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

def create_sharded_params(key, shape, mesh):
    # P() = no partitioning: the parameters are replicated on every device
    params = jax.random.normal(key, shape)
    spec = P()
    sharding = jax.sharding.NamedSharding(mesh, spec)
    return jax.device_put(params, sharding)

def create_sharded_batch(key, batch_size, input_size, mesh):
    # P('devices') shards the leading (batch) axis across the 'devices' mesh axis
    x = jax.random.normal(key, (batch_size, input_size))
    spec = P('devices')
    sharding = jax.sharding.NamedSharding(mesh, spec)
    return jax.device_put(x, sharding)
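
A short usage sketch for these helpers, assuming the same one-axis mesh as above and a batch size that is a multiple of the device count; the shapes are illustrative:

python
mesh = jax.sharding.Mesh(jax.devices(), ('devices',))

key = jax.random.PRNGKey(0)
key, p_key, b_key = jax.random.split(key, 3)

params = create_sharded_params(p_key, (784, 10), mesh)                  # replicated on all devices
batch = create_sharded_batch(b_key, 8 * jax.device_count(), 784, mesh)  # sharded along the batch axis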

Gradient Aggregation Strategies #

All-Reduce #

python
import functools

import jax
import jax.numpy as jnp

@functools.partial(jax.pmap, axis_name='i')
def all_reduce_example(x):
    # Every device receives the sum of x over all devices
    summed = jax.lax.psum(x, 'i')
    return summed

# One scalar per device: device d holds the value d
x = jnp.arange(jax.device_count())
result = all_reduce_example(x)
print(f"All-Reduce result: {result}")

All-Gather #

python
import functools

import jax
import jax.numpy as jnp

@functools.partial(jax.pmap, axis_name='i')
def all_gather_example(x):
    # Every device receives the stacked values from all devices
    gathered = jax.lax.all_gather(x, 'i')
    return gathered

x = jnp.arange(jax.device_count())
result = all_gather_example(x)
print(f"All-Gather result: {result}")

Performance Optimization #

Overlapping Computation and Communication #

python
import functools

import jax
import jax.numpy as jnp

@functools.partial(jax.pmap, axis_name='device')
def optimized_train_step(params, batch):
    def loss_fn(p):
        x, y = batch
        pred = jnp.dot(x, p)
        return jnp.mean((pred - y) ** 2)

    # One pass computes both the loss and the gradients
    loss, grads = jax.value_and_grad(loss_fn)(params)

    # Cross-device All-Reduce of the gradients
    grads = jax.lax.pmean(grads, 'device')

    new_params = jax.tree_util.tree_map(
        lambda p, g: p - 0.01 * g,
        params, grads
    )

    return new_params, jax.lax.pmean(loss, 'device')
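
Note that this step does not enforce any overlap by itself: the whole step compiles to a single XLA program, and the compiler is free to schedule the pmean collectives alongside computation that does not depend on them. For larger models a common refinement is to split the gradients into buckets and issue one pmean per bucket, so communication for early buckets can overlap with the remaining backward computation; the single pmean shown here is the simplest correct baseline.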

Gradient Accumulation #

python
import functools

import jax
import jax.numpy as jnp

@functools.partial(jax.pmap, axis_name='device')
def accumulate_gradients(params, batches):
    # batches is a tuple (xs, ys); on each device, xs has shape
    # (num_accum_steps, batch_size, in_dim) and ys (num_accum_steps, batch_size, out_dim)
    def scan_fn(grads, batch):
        def loss_fn(p):
            x, y = batch
            return jnp.mean((jnp.dot(x, p) - y) ** 2)
        batch_grads = jax.grad(loss_fn)(params)
        return jax.tree_util.tree_map(lambda g1, g2: g1 + g2, grads, batch_grads), None

    # Sum the gradients over the accumulation steps with a scan
    init_grads = jax.tree_util.tree_map(jnp.zeros_like, params)
    final_grads, _ = jax.lax.scan(scan_fn, init_grads, batches)

    # Average across devices, then across accumulation steps
    final_grads = jax.lax.pmean(final_grads, 'device')
    num_steps = batches[0].shape[0]
    final_grads = jax.tree_util.tree_map(lambda g: g / num_steps, final_grads)

    new_params = jax.tree_util.tree_map(
        lambda p, g: p - 0.01 * g,
        params, final_grads
    )

    return new_params
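
A minimal usage sketch, assuming 4 accumulation steps, a per-device micro-batch of 8, and the same (10, 5) linear model as the earlier examples; all shapes are illustrative:

python
num_devices = jax.device_count()
num_steps, micro_batch = 4, 8

key = jax.random.PRNGKey(0)
key, p_key, x_key, y_key = jax.random.split(key, 4)

# Replicate the parameters, then build (devices, steps, micro_batch, dim) inputs
params = jax.device_put_replicated(jax.random.normal(p_key, (10, 5)), jax.devices())
xs = jax.random.normal(x_key, (num_devices, num_steps, micro_batch, 10))
ys = jax.random.normal(y_key, (num_devices, num_steps, micro_batch, 5))

params = accumulate_gradients(params, (xs, ys))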

Next Steps #

Now that you have data parallelism under your belt, move on to Model Parallelism to learn how to handle models that are too large for a single device!

Last updated: 2026-04-04