Multi-Device Computation #

Overview #

JAX provides a unified API for multi-device computation: the same code runs on CPUs, GPUs, and TPUs. This section covers how to manage devices and run operations across them.

Device Management #

Listing Devices #

python
import jax

print(f"所有设备: {jax.devices()}")
print(f"CPU 设备: {jax.devices('cpu')}")
print(f"GPU 设备: {jax.devices('gpu')}")
print(f"设备数量: {jax.device_count()}")
print(f"默认后端: {jax.default_backend()}")

Device Attributes #

python
import jax

for device in jax.devices():
    print(f"Device: {device}")
    print(f"  Platform: {device.platform}")
    print(f"  Device kind: {device.device_kind}")
    print(f"  ID: {device.id}")

Selecting a Device #

python
import jax
import jax.numpy as jnp

# Assumes at least one GPU is available.
with jax.default_device(jax.devices('gpu')[0]):
    x = jnp.array([1, 2, 3])
    print(f"Array devices: {x.devices()}")

Data Placement #

Device Placement #

python
import jax
import jax.numpy as jnp

# Place an array on the first GPU (assumes a GPU is present).
x = jax.device_put(jnp.array([1, 2, 3]), jax.devices('gpu')[0])
print(f"Array devices: {x.devices()}")

# Transfer it to the CPU.
y = jax.device_put(x, jax.devices('cpu')[0])
print(f"Devices after transfer: {y.devices()}")

Batch Placement #

python
import jax
import jax.numpy as jnp

# Distribute arrays round-robin across the available devices
# (the example below assumes a GPU backend).
def put_on_devices(arrays, devices):
    return [jax.device_put(arr, devices[i % len(devices)])
            for i, arr in enumerate(arrays)]

arrays = [jnp.ones((100, 100)) for _ in range(4)]
devices = jax.devices('gpu')
placed = put_on_devices(arrays, devices)
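
Continuing the snippet above, a quick check that the round-robin placement landed where expected:

python
for i, arr in enumerate(placed):
    print(f"arrays[{i}] on: {arr.devices()}")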

Computation on Devices #

Single-Device Computation #

python
import jax
import jax.numpy as jnp

@jax.jit
def compute_on_device(x):
    return jnp.dot(x, x.T)

x = jnp.ones((1000, 1000))
result = compute_on_device(x)
print(f"结果设备: {result.devices()}")

Computing on a Specific Device #

python
import jax
import jax.numpy as jnp

@jax.jit
def compute(x):
    return jnp.dot(x, x.T)

x = jnp.ones((1000, 1000))

# Run the jitted computation on the first GPU (assumes a GPU is present).
with jax.default_device(jax.devices('gpu')[0]):
    result = compute(x)
    print(f"Result devices: {result.devices()}")

Multi-Device Parallelism #

pmap Basics #

python
import jax
import jax.numpy as jnp

def process(x):
    return x ** 2

parallel_process = jax.pmap(process)

# The leading axis of the input must equal the number of devices;
# each device receives one slice along axis 0.
num_devices = jax.device_count()
batch_x = jnp.ones((num_devices, 100))

result = parallel_process(batch_x)
print(f"Result shape: {result.shape}")
print(f"Result devices: {result.devices()}")

pmap with in_axes #

python
import jax
import jax.numpy as jnp

def train_step(params, x, y):
    def loss_fn(p):
        return jnp.mean((jnp.dot(x, p) - y) ** 2)
    
    grads = jax.grad(loss_fn)(params)
    return params - 0.01 * grads

# in_axes=(None, 0, 0): broadcast params to every device;
# shard x and y along their leading axis.
parallel_train = jax.pmap(train_step, in_axes=(None, 0, 0))

params = jnp.ones(10)
batch_x = jnp.ones((jax.device_count(), 32, 10))
batch_y = jnp.ones((jax.device_count(), 32))

new_params = parallel_train(params, batch_x, batch_y)
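
Because out_axes defaults to 0, parallel_train returns one updated parameter vector per device, stacked along a new leading axis:

python
print(new_params.shape)  # (num_devices, 10): one copy per device

In this example every per-device batch is identical, so all the copies agree; with real sharded data each device would compute different gradients, which is why data-parallel training averages them with jax.lax.pmean, shown in the next section.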

Cross-Device Communication #

psum #

python
import jax
import jax.numpy as jnp
from functools import partial

# psum refers to the axis name 'i', which must be bound via pmap's axis_name.
@partial(jax.pmap, axis_name='i')
def sum_across_devices(x):
    return jax.lax.psum(x, 'i')

x = jnp.arange(jax.device_count())
result = sum_across_devices(x)
# Every device receives the same total, replicated once per device.
print(f"Sum across devices: {result}")

pmean #

python
import jax
import jax.numpy as jnp
from functools import partial

# pmean also needs its axis name bound by pmap.
@partial(jax.pmap, axis_name='i')
def mean_across_devices(x):
    return jax.lax.pmean(x, 'i')

x = jnp.arange(jax.device_count())
result = mean_across_devices(x)
print(f"Mean across devices: {result}")

all_gather #

python
import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.pmap, axis_name='i')
def gather_all(x):
    return jax.lax.all_gather(x, 'i')

x = jnp.arange(jax.device_count())
result = gather_all(x)
# Each device gathers every device's value, so result has
# shape (num_devices, num_devices).
print(f"Gathered result: {result}")

Practical Application #

Multi-GPU Training #

python
import jax
import jax.numpy as jnp
from functools import partial

def init_params(key, input_size, output_size):
    w = jax.random.normal(key, (input_size, output_size)) * 0.01
    b = jnp.zeros(output_size)
    return {'w': w, 'b': b}

def forward(params, x):
    return jnp.dot(x, params['w']) + params['b']

def loss_fn(params, x, y):
    pred = forward(params, x)
    return jnp.mean((pred - y) ** 2)

# Bind the axis name 'device' so pmean can average across it.
@partial(jax.pmap, axis_name='device')
def train_step(params, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    
    # Average gradients across devices so every replica applies the same update.
    grads = jax.lax.pmean(grads, 'device')
    
    new_params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)
    loss = loss_fn(params, x, y)
    
    return new_params, loss

key = jax.random.PRNGKey(0)
params = init_params(key, 10, 5)

# Replicate the parameters onto every device.
num_devices = jax.device_count()
params = jax.device_put_replicated(params, jax.devices())

# Use distinct keys for inputs and targets.
key_x, key_y = jax.random.split(key)
batch_x = jax.random.normal(key_x, (num_devices, 32, 10))
batch_y = jax.random.normal(key_y, (num_devices, 32, 5))

for step in range(100):
    params, loss = train_step(params, batch_x, batch_y)
    if step % 20 == 0:
        print(f"Step {step}: loss={loss[0]:.4f}")

Next Steps #

Now that you have multi-device computation down, continue with Data Parallelism for a deeper look at distributed training!

Last updated: 2026-04-04