Multi-Device Computation #
Overview #
JAX provides a unified API for multi-device computation, so the same code can run on CPUs, GPUs, and TPUs. This section covers how to manage devices and run computations across them.
Device Management #
Listing Devices #
python
import jax

print(f"All devices: {jax.devices()}")
print(f"CPU devices: {jax.devices('cpu')}")
print(f"GPU devices: {jax.devices('gpu')}")
print(f"Device count: {jax.device_count()}")
print(f"Default backend: {jax.default_backend()}")
Device Attributes #
python
import jax

for device in jax.devices():
    print(f"Device: {device}")
    print(f"  Platform: {device.platform}")
    print(f"  Device kind: {device.device_kind}")
    print(f"  ID: {device.id}")
Selecting a Device #
python
import jax
import jax.numpy as jnp

# Use the first GPU as the default device inside this block.
with jax.default_device(jax.devices('gpu')[0]):
    x = jnp.array([1, 2, 3])
    print(f"Array devices: {x.devices()}")
Data Placement #
Device Placement #
python
import jax
import jax.numpy as jnp

# Place an array on the first GPU.
x = jax.device_put(jnp.array([1, 2, 3]), jax.devices('gpu')[0])
print(f"Array devices: {x.devices()}")

# Transfer it to the CPU.
y = jax.device_put(x, jax.devices('cpu')[0])
print(f"Devices after transfer: {y.devices()}")
Batch Placement #
python
import jax
import jax.numpy as jnp

def put_on_devices(arrays, devices):
    # Distribute the arrays round-robin across the available devices.
    return [jax.device_put(arr, devices[i % len(devices)])
            for i, arr in enumerate(arrays)]

arrays = [jnp.ones((100, 100)) for _ in range(4)]
devices = jax.devices('gpu')
placed = put_on_devices(arrays, devices)
Computing on Devices #
Single-Device Computation #
python
import jax
import jax.numpy as jnp

@jax.jit
def compute_on_device(x):
    return jnp.dot(x, x.T)

x = jnp.ones((1000, 1000))
result = compute_on_device(x)
print(f"Result devices: {result.devices()}")
Computing on a Specified Device #
python
import jax
import jax.numpy as jnp

@jax.jit
def compute(x):
    return jnp.dot(x, x.T)

x = jnp.ones((1000, 1000))

# Run the jitted computation with the first GPU as the default device.
with jax.default_device(jax.devices('gpu')[0]):
    result = compute(x)

print(f"Result devices: {result.devices()}")
Multi-Device Parallelism #
pmap Basics #
python
import jax
import jax.numpy as jnp

def process(x):
    return x ** 2

# pmap maps the function over the leading axis, one slice per device.
parallel_process = jax.pmap(process)

num_devices = jax.device_count()
batch_x = jnp.ones((num_devices, 100))
result = parallel_process(batch_x)
print(f"Result shape: {result.shape}")
print(f"Result devices: {result.devices()}")
pmap with Arguments #
python
import jax
import jax.numpy as jnp

def train_step(params, x, y):
    def loss_fn(p):
        return jnp.mean((jnp.dot(x, p) - y) ** 2)
    grads = jax.grad(loss_fn)(params)
    return params - 0.01 * grads

# in_axes=(None, 0, 0): broadcast params, split x and y across devices.
parallel_train = jax.pmap(train_step, in_axes=(None, 0, 0))

params = jnp.ones(10)
batch_x = jnp.ones((jax.device_count(), 32, 10))
batch_y = jnp.ones((jax.device_count(), 32))
new_params = parallel_train(params, batch_x, batch_y)
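Note that with this setup each device computes its own update from its shard of the data, so the per-device copies of new_params can drift apart over repeated steps. The multi-GPU training example below keeps the replicas in sync by averaging the gradients across devices with jax.lax.pmean.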
Inter-Device Communication #
psum #
python
import jax
import jax.numpy as jnp
from functools import partial

# axis_name='i' names the mapped axis so collectives can refer to it.
@partial(jax.pmap, axis_name='i')
def sum_across_devices(x):
    return jax.lax.psum(x, 'i')

x = jnp.arange(jax.device_count())
result = sum_across_devices(x)
print(f"Sum across devices: {result}")
pmean #
python
import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.pmap, axis_name='i')
def mean_across_devices(x):
    return jax.lax.pmean(x, 'i')

x = jnp.arange(jax.device_count())
result = mean_across_devices(x)
print(f"Mean across devices: {result}")
all_gather #
python
import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.pmap, axis_name='i')
def gather_all(x):
    # Each device receives the values from all devices along axis 'i'.
    return jax.lax.all_gather(x, 'i')

x = jnp.arange(jax.device_count())
result = gather_all(x)
print(f"Gathered result: {result}")
Practical Applications #
Multi-GPU Training #
python
import jax
import jax.numpy as jnp
from functools import partial

def init_params(key, input_size, output_size):
    w = jax.random.normal(key, (input_size, output_size)) * 0.01
    b = jnp.zeros(output_size)
    return {'w': w, 'b': b}

def forward(params, x):
    return jnp.dot(x, params['w']) + params['b']

def loss_fn(params, x, y):
    pred = forward(params, x)
    return jnp.mean((pred - y) ** 2)

@partial(jax.pmap, axis_name='device')
def train_step(params, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    # Average gradients across devices so every replica applies the same update.
    grads = jax.lax.pmean(grads, 'device')
    new_params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)
    loss = loss_fn(params, x, y)
    return new_params, loss

key = jax.random.PRNGKey(0)
params = init_params(key, 10, 5)

num_devices = jax.device_count()
# Replicate the parameters so every device holds its own copy.
params = jax.device_put_replicated(params, jax.devices())

batch_x = jax.random.normal(key, (num_devices, 32, 10))
batch_y = jax.random.normal(key, (num_devices, 32, 5))

for step in range(100):
    params, loss = train_step(params, batch_x, batch_y)
    if step % 20 == 0:
        print(f"Step {step}: loss={loss[0]:.4f}")
Next Steps #
Now that you have a handle on multi-device computation, continue with Data Parallelism for a deeper look at distributed training!
Last updated: 2026-04-04