模型并行 #

概述 #

当模型太大、单个设备的内存无法容纳时,就需要使用模型并行。模型并行将模型的参数和计算切分到多个设备上,主要有张量并行和流水线并行两种形式。

模型并行类型 #

text
┌─────────────────────────────────────────────────────────────┐
│                    模型并行类型                              │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  张量并行(Tensor Parallelism)                              │
│  ─── 将单个张量分割到多个设备                                │
│  ─── 适合大型线性层                                          │
│  ─── 需要频繁的设备间通信                                    │
│                                                             │
│  流水线并行(Pipeline Parallelism)                          │
│  ─── 将模型层分割到多个设备                                  │
│  ─── 适合深层网络                                            │
│  ─── 通信较少但有流水线气泡                                  │
│                                                             │
│  混合并行                                                    │
│  ─── 结合数据并行和模型并行                                  │
│  ─── 大模型训练的标准方案                                    │
│                                                             │
└─────────────────────────────────────────────────────────────┘

张量并行 #

矩阵乘法并行 #

python
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from jax.sharding import NamedSharding

# 1-D device mesh over (at most) the first two devices; its single axis is 'tp'.
mesh = jax.sharding.Mesh(jax.devices()[:2], axis_names=('tp',))

def parallel_matmul(x, w, mesh):
    """Compute ``jnp.dot(x, w)`` with ``w`` split row-wise over mesh axis ``'tp'``.

    ``x`` is placed fully replicated on every device of ``mesh``. ``w`` is
    partitioned along its first (contraction) dimension, so each device holds
    a slice of the rows of ``w``. The compiler inserts the cross-device
    communication needed to produce the product.

    Args:
        x: left operand; placed replicated.
        w: right operand; sharded along axis 0 over mesh axis ``'tp'``.
        mesh: a ``jax.sharding.Mesh`` that defines an axis named ``'tp'``.

    Returns:
        The product ``jnp.dot(x, w)`` computed on the placed operands.
    """
    replicated = NamedSharding(mesh, P())
    rows_over_tp = NamedSharding(mesh, P('tp', None))

    x_on_mesh = jax.device_put(x, replicated)
    w_on_mesh = jax.device_put(w, rows_over_tp)
    return jnp.dot(x_on_mesh, w_on_mesh)

# JAX PRNG keys should never be reused: split the root key so that x and w
# are drawn from independent random streams.
key = jax.random.PRNGKey(0)
x_key, w_key = jax.random.split(key)
x = jax.random.normal(x_key, (32, 1024))
w = jax.random.normal(w_key, (1024, 4096))

result = parallel_matmul(x, w, mesh)
print(f"结果形状: {result.shape}")

并行线性层 #

python
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from jax.sharding import NamedSharding

class ParallelLinear:
    """Linear layer ``y = x @ w + b`` whose weight is sharded over mesh axis ``'tp'``.

    The weight has shape ``(in_features, out_features)``. It is split along its
    first (input-feature) dimension across the ``'tp'`` axis of ``mesh``. The
    bias is not sharded.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        mesh: a ``jax.sharding.Mesh`` with an axis named ``'tp'``;
            ``in_features`` should be divisible by that axis' size.
        name: label for the layer, stored as ``self.name``.
        key: PRNG key used to initialise the weight. Defaults to
            ``jax.random.PRNGKey(0)``. Pass distinct keys when building several
            layers, otherwise they all start from identical weights.
    """

    def __init__(self, in_features, out_features, mesh, name='linear', key=None):
        self.mesh = mesh
        self.name = name
        self.in_features = in_features
        self.out_features = out_features

        if key is None:
            key = jax.random.PRNGKey(0)
        # Small-scale normal initialisation (scaled by 0.01).
        w = jax.random.normal(key, (in_features, out_features)) * 0.01
        self.b = jnp.zeros(out_features)

        # Row-parallel layout: each device holds a slice of the input-feature rows.
        w_sharding = NamedSharding(mesh, P('tp', None))
        self.w = jax.device_put(w, w_sharding)

    def __call__(self, x):
        """Return ``x @ w + b``; the last dimension of ``x`` must equal ``in_features``."""
        return jnp.dot(x, self.w) + self.b

# Build a 1-D 'tp' mesh over (at most) two devices and run one forward pass.
tp_devices = jax.devices()[:2]
mesh = jax.sharding.Mesh(tp_devices, axis_names=('tp',))
layer = ParallelLinear(1024, 4096, mesh)
input_key = jax.random.PRNGKey(0)
x = jax.random.normal(input_key, (32, 1024))
output = layer(x)
print(f"输出形状: {output.shape}")

流水线并行 #

基本流水线 #

python
import jax
import jax.numpy as jnp
import jax.nn as nn

def stage1(params1, x):
    """First pipeline stage: affine map ``x @ w + b`` followed by ReLU.

    ``params1`` is a dict with keys ``'w'`` and ``'b'``.
    """
    pre_activation = jnp.dot(x, params1['w']) + params1['b']
    return nn.relu(pre_activation)

def stage2(params2, x):
    """Second pipeline stage: a plain affine map ``x @ w + b`` with no activation.

    ``params2`` is a dict with keys ``'w'`` and ``'b'``.
    """
    weight, bias = params2['w'], params2['b']
    return jnp.dot(x, weight) + bias

def pipeline_forward(params, x, devices):
    """Run a two-stage forward pass with each stage placed on its own device.

    Stage 1's parameters and input are moved to ``device1``. Its activations
    are then transferred to ``device2``, where stage 2 runs with its own
    parameters. Every array is placed explicitly with ``jax.device_put``.
    ``jax.default_device`` only affects uncommitted arrays, so it would not
    move an ``x`` or parameters that were already committed elsewhere.

    Args:
        params: a pair ``(params1, params2)`` of per-stage parameter pytrees.
        x: the input batch.
        devices: a pair ``(device1, device2)``.

    Returns:
        The output of stage 2, placed on ``device2``.
    """
    params1, params2 = params
    device1, device2 = devices

    # Stage 1 on device1.
    params1 = jax.device_put(params1, device1)
    x = jax.device_put(x, device1)
    x = stage1(params1, x)

    # Hand the activations to device2 and run stage 2 there.
    params2 = jax.device_put(params2, device2)
    x = jax.device_put(x, device2)
    return stage2(params2, x)

devices = jax.devices()[:2]

# One parameter dict per stage: 784 -> 256 (stage 1) and 256 -> 10 (stage 2).
stage1_params = {'w': jax.random.normal(jax.random.PRNGKey(0), (784, 256)) * 0.01,
                 'b': jnp.zeros(256)}
stage2_params = {'w': jax.random.normal(jax.random.PRNGKey(1), (256, 10)) * 0.01,
                 'b': jnp.zeros(10)}
params = [stage1_params, stage2_params]

x = jax.random.normal(jax.random.PRNGKey(2), (32, 784))
output = pipeline_forward(params, x, devices)
print(f"输出形状: {output.shape}")

微批次流水线 #

将一个批次切分为若干微批次,依次送入流水线。注意:下面的示例按顺序逐个处理微批次,只演示切分与拼接方式,并未显式实现 GPipe 等交错调度。

python
import jax
import jax.numpy as jnp
import jax.nn as nn

def pipeline_with_microbatches(params, x, num_microbatches, devices):
    """Split ``x`` into micro-batches, run each through ``pipeline_forward``, and re-join them.

    The micro-batches are dispatched one after another from Python. This
    function does not itself implement an interleaved pipeline schedule.
    ``jnp.array_split`` allows the batch size not to be a multiple of
    ``num_microbatches``.

    Returns:
        The per-micro-batch outputs concatenated along axis 0, in input order.
    """
    chunks = jnp.array_split(x, num_microbatches)
    return jnp.concatenate(
        [pipeline_forward(params, chunk, devices) for chunk in chunks],
        axis=0,
    )

num_microbatches = 4
batch_key = jax.random.PRNGKey(0)
x = jax.random.normal(batch_key, (128, 784))
output = pipeline_with_microbatches(
    params, x, num_microbatches=num_microbatches, devices=devices
)
print(f"输出形状: {output.shape}")

混合并行 #

2D 并行 #

python
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from jax.sharding import NamedSharding

import numpy as np

# jax.devices() returns a plain Python list, which has no .reshape. Convert it
# to a NumPy array first, then lay the devices out as a 2x2 grid.
# Mesh axis 0 is the data-parallel axis 'dp' and axis 1 the tensor-parallel
# axis 'tp'. Requires at least 4 devices; on CPU they can be simulated with
# XLA_FLAGS=--xla_force_host_platform_device_count=4.
devices = np.array(jax.devices()[:4])
mesh = jax.sharding.Mesh(
    devices.reshape(2, 2),
    ('dp', 'tp')
)

def create_2d_parallel_params(key, in_features, out_features, mesh):
    """Create an ``(in_features, out_features)`` weight sharded column-wise over ``'tp'``.

    Values are drawn from a standard normal and scaled by 0.01. The second
    (output-feature) dimension is split across mesh axis ``'tp'``. Mesh axes
    not named in the partition spec (e.g. ``'dp'``) hold replicas.

    Returns:
        The placed weight array.
    """
    column_sharding = NamedSharding(mesh, P(None, 'tp'))
    init = jax.random.normal(key, (in_features, out_features)) * 0.01
    return jax.device_put(init, column_sharding)

key = jax.random.PRNGKey(0)
w = create_2d_parallel_params(key, in_features=1024, out_features=4096, mesh=mesh)
print(f"权重形状: {w.shape}")

3D 并行 #

python
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from jax.sharding import NamedSharding

import numpy as np

# jax.devices() returns a Python list (no .reshape), so convert it to a NumPy
# array before arranging the devices as a 2x2x2 grid with mesh axes
# ('dp', 'tp', 'pp'). Requires at least 8 devices; on CPU they can be
# simulated with XLA_FLAGS=--xla_force_host_platform_device_count=8.
devices = np.array(jax.devices()[:8])
mesh = jax.sharding.Mesh(
    devices.reshape(2, 2, 2),
    ('dp', 'tp', 'pp')
)

def forward_3d_parallel(params, x, mesh):
    """Apply a stack of bias-free ReLU layers using data- and tensor-parallel shardings.

    The leading (batch) dimension of ``x`` is split over mesh axis ``'dp'``.
    The second dimension of every weight is split over ``'tp'``. This
    simplified example does not use a ``'pp'`` (pipeline) axis even if
    ``mesh`` defines one: every layer is sharded the same way rather than
    being assigned to a pipeline stage.

    Args:
        params: sequence of weight matrices applied in order.
        x: input batch.
        mesh: a ``jax.sharding.Mesh`` with axes named ``'dp'`` and ``'tp'``.

    Returns:
        ``relu(... relu(x @ w_0) ... @ w_{n-1})``; just the placed ``x`` if
        ``params`` is empty.
    """
    x = jax.device_put(x, NamedSharding(mesh, P('dp', None)))

    # The weight layout is identical for every layer, so build it once.
    w_sharding = NamedSharding(mesh, P(None, 'tp'))
    for w in params:
        x = jnp.dot(x, jax.device_put(w, w_sharding))
        # Reach ReLU through the `jax` import; this snippet never imports `nn`.
        x = jax.nn.relu(x)

    return x

# Four square 1024x1024 layers, each initialised from its own key.
layer_keys = [jax.random.PRNGKey(i) for i in range(4)]
params = [jax.random.normal(k, (1024, 1024)) * 0.01 for k in layer_keys]

x = jax.random.normal(jax.random.PRNGKey(0), (32, 1024))
output = forward_3d_parallel(params, x, mesh)
print(f"输出形状: {output.shape}")

大模型训练 #

内存优化 #

梯度检查点(gradient checkpointing)在反向传播时重新计算前向的中间结果,而不是全部保存,用额外计算换取更低的内存占用。

python
import jax
import jax.numpy as jnp

def gradient_checkpointing(forward_fn, params, x):
    """Evaluate ``forward_fn(params, x)`` under ``jax.remat`` (gradient checkpointing).

    Only ``params`` is an argument of the rematerialised function. ``x`` is
    captured from the enclosing scope. When this call is differentiated, JAX
    recomputes forward intermediates during the backward pass instead of
    storing them. A forward-only call returns the same value as a plain call.

    NOTE(review): the whole model is wrapped as a single checkpoint, so the
    entire forward is recomputed at once during backward. This may save
    little peak memory; per-layer checkpointing is the usual pattern for deep
    models.

    Returns:
        Whatever ``forward_fn(params, x)`` returns.
    """
    return jax.remat(lambda p: forward_fn(p, x))(params)

def forward_large_model(params, x):
    """Apply a stack of dense ReLU layers.

    ``params`` is a sequence of ``(w, b)`` pairs. Each layer computes
    ``relu(x @ w + b)``. Returns the final activation (``x`` itself when
    ``params`` is empty).
    """
    activation = x
    for weight, bias in params:
        activation = jax.nn.relu(jnp.dot(activation, weight) + bias)
    return activation

# 100 layers, each a (1024x1024 weight, zero bias) pair with its own weight key.
num_layers = 100
params = [
    (jax.random.normal(jax.random.PRNGKey(i), (1024, 1024)) * 0.01, jnp.zeros(1024))
    for i in range(num_layers)
]

x = jax.random.normal(jax.random.PRNGKey(0), (32, 1024))

output = gradient_checkpointing(forward_large_model, params, x)
print(f"输出形状: {output.shape}")

ZeRO 优化 #

ZeRO(Zero Redundancy Optimizer)通过在数据并行设备之间分片优化器状态、梯度乃至参数,来消除冗余的内存占用。下面的示例只给出更新步骤的接口框架,尚未实现实际的分片逻辑。

python
import jax
import jax.numpy as jnp

def zero_optimization_step(params, grads, optimizer_state, mesh, learning_rate=0.001):
    """Apply one plain SGD update ``p - learning_rate * g`` to every leaf of ``params``.

    NOTE: despite its name, this is only a placeholder for a ZeRO-style step.
    It does not shard parameters, gradients, or optimizer state. ``mesh`` is
    accepted but unused, and ``optimizer_state`` is returned unchanged.

    Args:
        params: pytree of parameters.
        grads: pytree of gradients with the same structure as ``params``.
        optimizer_state: passed through untouched.
        mesh: currently unused.
        learning_rate: SGD step size (default 0.001).

    Returns:
        ``(new_params, optimizer_state)``.
    """
    # jax.tree_map is deprecated and removed from the top-level jax namespace
    # in recent releases; jax.tree_util.tree_map is the stable spelling.
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g,
        params, grads,
    )
    return new_params, optimizer_state

下一步 #

现在你已经掌握了模型并行,接下来学习 TPU 加速,了解如何在 TPU 上训练模型!

最后更新:2026-04-04