模型并行 #
概述 #
当模型太大无法放入单个设备的内存时,需要使用模型并行。模型并行将模型分割到多个设备上,包括张量并行和流水线并行。
模型并行类型 #
text
┌─────────────────────────────────────────────────────────────┐
│ 模型并行类型 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 张量并行(Tensor Parallelism) │
│ ─── 将单个张量分割到多个设备 │
│ ─── 适合大型线性层 │
│ ─── 需要频繁的设备间通信 │
│ │
│ 流水线并行(Pipeline Parallelism) │
│ ─── 将模型层分割到多个设备 │
│ ─── 适合深层网络 │
│ ─── 通信较少但有流水线气泡 │
│ │
│ 混合并行 │
│ ─── 结合数据并行和模型并行 │
│ ─── 大模型训练的标准方案 │
│ │
└─────────────────────────────────────────────────────────────┘
张量并行 #
矩阵乘法并行 #
python
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from jax.sharding import NamedSharding

# 1-D device mesh with a single tensor-parallel axis named 'tp'.
mesh = jax.sharding.Mesh(jax.devices()[:2], ('tp',))

def parallel_matmul(x, w, mesh):
    """Compute x @ w with w row-sharded across the 'tp' mesh axis.

    x is replicated on every device, while w is split along its first
    (input) dimension, so each device holds a slice of the weight and
    XLA combines the partial products automatically.
    """
    replicated = NamedSharding(mesh, P())
    row_sharded = NamedSharding(mesh, P('tp', None))
    x = jax.device_put(x, replicated)
    w = jax.device_put(w, row_sharded)
    return jnp.dot(x, w)

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 1024))
w = jax.random.normal(key, (1024, 4096))
result = parallel_matmul(x, w, mesh)
print(f"结果形状: {result.shape}")
并行线性层 #
python
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from jax.sharding import NamedSharding

class ParallelLinear:
    """Linear layer whose weight is row-sharded over the 'tp' mesh axis.

    The bias stays replicated; only the (in_features, out_features)
    weight matrix is distributed across the tensor-parallel devices.
    """

    def __init__(self, in_features, out_features, mesh, name='linear'):
        self.mesh = mesh
        self.in_features = in_features
        self.out_features = out_features
        init_key = jax.random.PRNGKey(0)
        weight = jax.random.normal(init_key, (in_features, out_features)) * 0.01
        self.b = jnp.zeros(out_features)
        # Shard along the input dimension: each device owns a slab of rows.
        self.w = jax.device_put(weight, NamedSharding(mesh, P('tp', None)))

    def __call__(self, x):
        return jnp.dot(x, self.w) + self.b

mesh = jax.sharding.Mesh(jax.devices()[:2], ('tp',))
layer = ParallelLinear(1024, 4096, mesh)
x = jax.random.normal(jax.random.PRNGKey(0), (32, 1024))
output = layer(x)
print(f"输出形状: {output.shape}")
流水线并行 #
基本流水线 #
python
import jax
import jax.numpy as jnp
import jax.nn as nn

def stage1(params1, x):
    """First pipeline stage: dense layer followed by ReLU."""
    return nn.relu(jnp.dot(x, params1['w']) + params1['b'])

def stage2(params2, x):
    """Second pipeline stage: output dense layer (no activation)."""
    return jnp.dot(x, params2['w']) + params2['b']

def pipeline_forward(params, x, devices):
    """Run a 2-stage pipelined forward pass, one stage per device.

    Args:
        params: sequence of two dicts with 'w' and 'b' for each stage.
        x: input batch, shape (batch, in_features).
        devices: sequence of at least one jax.Device; the first and last
            entries host stage 1 and stage 2 respectively.
    Returns:
        Stage-2 output, shape (batch, out_features).
    """
    params1, params2 = params
    # Using [0]/[-1] instead of strict 2-tuple unpacking keeps the function
    # usable when only one device is available (both stages colocated).
    device1, device2 = devices[0], devices[-1]
    with jax.default_device(device1):
        x = stage1(params1, x)
    # Transfer activations between pipeline stages.
    x = jax.device_put(x, device2)
    with jax.default_device(device2):
        output = stage2(params2, x)
    return output

devices = jax.devices()[:2]
# Fall back gracefully on single-device hosts (e.g. CPU-only machines):
# run both stages on the same device instead of crashing.
if len(devices) < 2:
    devices = [devices[0], devices[0]]
params = [
    {'w': jax.random.normal(jax.random.PRNGKey(0), (784, 256)) * 0.01,
     'b': jnp.zeros(256)},
    {'w': jax.random.normal(jax.random.PRNGKey(1), (256, 10)) * 0.01,
     'b': jnp.zeros(10)}
]
x = jax.random.normal(jax.random.PRNGKey(2), (32, 784))
output = pipeline_forward(params, x, devices)
print(f"输出形状: {output.shape}")
微批次流水线 #
python
import jax
import jax.numpy as jnp
import jax.nn as nn

def pipeline_with_microbatches(params, x, num_microbatches, devices):
    """Split the batch into micro-batches, push each through the
    two-stage pipeline in turn, then reassemble the full output."""
    pieces = jnp.array_split(x, num_microbatches)
    results = [pipeline_forward(params, piece, devices) for piece in pieces]
    return jnp.concatenate(results, axis=0)

num_microbatches = 4
x = jax.random.normal(jax.random.PRNGKey(0), (128, 784))
output = pipeline_with_microbatches(params, x, num_microbatches, devices)
print(f"输出形状: {output.shape}")
混合并行 #
2D 并行 #
python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import PartitionSpec as P
from jax.sharding import NamedSharding

# jax.devices() returns a plain Python list, which has no .reshape();
# wrap it in a NumPy array before laying out the 2x2 mesh.
devices = np.array(jax.devices()[:4])
if devices.size >= 4:
    # 2-way data parallel x 2-way tensor parallel.
    mesh_devices = devices[:4].reshape(2, 2)
else:
    # Degenerate 1x1 mesh so the example also runs on single-device hosts.
    mesh_devices = devices[:1].reshape(1, 1)
mesh = jax.sharding.Mesh(mesh_devices, ('dp', 'tp'))

def create_2d_parallel_params(key, in_features, out_features, mesh):
    """Create a weight matrix sharded column-wise along the 'tp' mesh axis.

    The 'dp' axis is left unsharded here: data parallelism replicates the
    weights and shards only the batch.
    """
    w = jax.random.normal(key, (in_features, out_features)) * 0.01
    w_sharding = NamedSharding(mesh, P(None, 'tp'))
    w = jax.device_put(w, w_sharding)
    return w

key = jax.random.PRNGKey(0)
w = create_2d_parallel_params(key, 1024, 4096, mesh)
print(f"权重形状: {w.shape}")
3D 并行 #
python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import PartitionSpec as P
from jax.sharding import NamedSharding

# jax.devices() returns a plain list; convert to a NumPy array so it can
# be reshaped into a 2x2x2 (dp, tp, pp) mesh.
devices = np.array(jax.devices()[:8])
if devices.size >= 8:
    mesh_devices = devices[:8].reshape(2, 2, 2)
else:
    # Degenerate 1x1x1 mesh so the example also runs on single-device hosts.
    mesh_devices = devices[:1].reshape(1, 1, 1)
mesh = jax.sharding.Mesh(mesh_devices, ('dp', 'tp', 'pp'))

def forward_3d_parallel(params, x, mesh):
    """Forward pass with the batch sharded on 'dp' and each weight
    matrix sharded column-wise on 'tp'.

    Args:
        params: list of (features, features) weight matrices.
        x: input batch, shape (batch, features); batch split over 'dp'.
    Returns:
        Activations after all layers, same shape as x.
    """
    x_sharding = NamedSharding(mesh, P('dp', None))
    x = jax.device_put(x, x_sharding)
    for w in params:
        w_sharding = NamedSharding(mesh, P(None, 'tp'))
        x = jnp.dot(x, jax.device_put(w, w_sharding))
        # Original snippet used bare `nn.relu` without importing `nn`;
        # jax.nn.relu is always in scope once jax is imported.
        x = jax.nn.relu(x)
    return x

params = [
    jax.random.normal(jax.random.PRNGKey(i), (1024, 1024)) * 0.01
    for i in range(4)
]
x = jax.random.normal(jax.random.PRNGKey(0), (32, 1024))
output = forward_3d_parallel(params, x, mesh)
print(f"输出形状: {output.shape}")
大模型训练 #
内存优化 #
python
import jax
import jax.numpy as jnp

def gradient_checkpointing(forward_fn, params, x):
    """Run forward_fn(params, x) under jax.remat so that intermediate
    activations are recomputed in the backward pass instead of stored,
    trading compute for memory on deep models."""
    rematted = jax.remat(lambda p: forward_fn(p, x))
    return rematted(params)

def forward_large_model(params, x):
    """Plain MLP: a stack of dense layers, ReLU after every layer."""
    activations = x
    for weight, bias in params:
        activations = jax.nn.relu(jnp.dot(activations, weight) + bias)
    return activations

params = [
    (jax.random.normal(jax.random.PRNGKey(i), (1024, 1024)) * 0.01,
     jnp.zeros(1024))
    for i in range(100)
]
x = jax.random.normal(jax.random.PRNGKey(0), (32, 1024))
output = gradient_checkpointing(forward_large_model, params, x)
print(f"输出形状: {output.shape}")
ZeRO 优化 #
python
import jax
import jax.numpy as jnp

def zero_optimization_step(params, grads, optimizer_state, mesh):
    """One plain SGD step (lr = 0.001) applied leaf-wise over the
    parameter pytree.

    Placeholder for a ZeRO-style sharded optimizer: `optimizer_state`
    and `mesh` are accepted so the signature matches a real sharded
    update, but the state is currently returned unchanged and the mesh
    is unused.

    Returns:
        (new_params, optimizer_state) tuple.
    """
    # jax.tree_map was deprecated and has been removed in recent JAX
    # releases; jax.tree_util.tree_map is the stable spelling.
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - 0.001 * g,
        params, grads
    )
    return new_params, optimizer_state
下一步 #
现在你已经掌握了模型并行,接下来学习 TPU 加速,了解如何在 TPU 上训练模型!
最后更新:2026-04-04