TPU Acceleration #
Overview #
A TPU (Tensor Processing Unit) is an accelerator designed by Google specifically for machine learning. JAX has native TPU support and runs on TPUs with little or no code change.
TPU Basics #
Checking for TPUs #
python
import jax

# Lists TPU devices; raises an error if no TPU backend is available
tpu_devices = jax.devices('tpu')
print(f"TPU devices: {tpu_devices}")
print(f"Number of TPUs: {len(tpu_devices)}")
print(f"Default backend: {jax.default_backend()}")
TPU Initialization #
python
import jax

# jax.default_backend() returns 'tpu', 'gpu', or 'cpu'
if jax.default_backend() == 'tpu':
    print("TPU available")
else:
    print("TPU not available, falling back to CPU/GPU")
Colab TPU Setup #
python
import jax

# Only needed on older Colab TPU runtimes; newer Colab TPU VMs are detected
# automatically, and recent JAX releases may no longer ship this module.
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

print(f"TPU devices: {jax.devices()}")
TPU Architecture #
text
┌────────────────────────────────────────────────────────┐
│                    TPU Architecture                    │
├────────────────────────────────────────────────────────┤
│                                                        │
│  TPU v2                                                │
│  ─── 8 GB HBM per core                                 │
│  ─── 180 TFLOPS (bf16, per board)                      │
│  ─── Suited to research and development                │
│                                                        │
│  TPU v3                                                │
│  ─── 16 GB HBM per core                                │
│  ─── 420 TFLOPS (bf16, per board)                      │
│  ─── Suited to large-scale training                    │
│                                                        │
│  TPU v4                                                │
│  ─── 32 GB HBM per chip                                │
│  ─── 275 TFLOPS (bf16, per chip)                       │
│  ─── Highest performance                               │
│                                                        │
│  TPU Pod                                               │
│  ─── Many TPU chips connected together                 │
│  ─── High-speed interconnect (ICI)                     │
│  ─── Used for very large-scale training                │
│                                                        │
└────────────────────────────────────────────────────────┘
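To see which TPU generation you are actually running on, you can query the device objects directly. The snippet below is a minimal sketch: attributes such as coords exist only on TPU devices, and memory_stats() is not implemented on every backend or jaxlib version, so both are guarded.
python
import jax

for d in jax.devices():
    # device_kind reports the hardware generation, e.g. "TPU v3"
    print(d.id, d.platform, d.device_kind)
    # coords (the chip's position in the pod slice) exists only on TPU devices
    if hasattr(d, "coords"):
        print("  coords:", d.coords)
    # memory_stats() may be missing on some backends/jaxlib versions
    if hasattr(d, "memory_stats"):
        stats = d.memory_stats()
        if stats:
            print("  bytes_in_use:", stats.get("bytes_in_use"))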
TPU Training #
Basic TPU Training #
python
import jax
import jax.numpy as jnp
import jax.nn as nn

def init_params(key, layer_sizes):
    params = []
    for in_size, out_size in layer_sizes:
        key, w_key, b_key = jax.random.split(key, 3)
        w = jax.random.normal(w_key, (in_size, out_size)) * 0.01
        b = jnp.zeros(out_size)
        params.append({'w': w, 'b': b})
    return params

def forward(params, x):
    for layer in params[:-1]:
        x = nn.relu(jnp.dot(x, layer['w']) + layer['b'])
    return jnp.dot(x, params[-1]['w']) + params[-1]['b']

def loss_fn(params, x, y):
    pred = forward(params, x)
    return jnp.mean((pred - y) ** 2)

@jax.jit
def train_step(params, x, y, lr=0.001):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

key = jax.random.PRNGKey(0)
params = init_params(key, [(784, 256), (256, 10)])
x = jax.random.normal(key, (1024, 784))
y = jax.random.normal(key, (1024, 10))

for step in range(100):
    params, loss = train_step(params, x, y)
    if step % 20 == 0:
        print(f"Step {step}: loss={loss:.4f}")
TPU Data Parallelism #
python
import functools
import jax
import jax.numpy as jnp

# axis_name must be passed to pmap so that pmean can refer to it
@functools.partial(jax.pmap, axis_name='device')
def tpu_train_step(params, batch):
    x, y = batch
    def loss_fn(p):
        pred = forward(p, x)
        return jnp.mean((pred - y) ** 2)
    grads = jax.grad(loss_fn)(params)
    # Average gradients across all TPU devices
    grads = jax.lax.pmean(grads, 'device')
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - 0.001 * g,
        params, grads
    )
    return new_params

num_devices = jax.device_count()
# Replicate the parameters onto every device (adds a leading device axis)
params = jax.device_put_replicated(params, jax.devices())
batch_x = jax.random.normal(key, (num_devices, 128, 784))
batch_y = jax.random.normal(key, (num_devices, 128, 10))

for step in range(100):
    params = tpu_train_step(params, (batch_x, batch_y))
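After a pmap training loop the parameters still carry that leading device axis. A common follow-up (a sketch, not part of the original example) is to take the copy from the first device when you need a single, unreplicated set of parameters for evaluation or checkpointing:
python
import jax

# Each leaf has shape (num_devices, ...); indexing 0 selects one replica
single_params = jax.tree_util.tree_map(lambda x: x[0], params)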
TPU Pod #
Pod Training #
python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import PartitionSpec as P
from jax.sharding import NamedSharding

num_devices = jax.device_count()

# jax.devices() returns a list, so convert it to a NumPy array before reshaping
mesh = jax.sharding.Mesh(
    np.array(jax.devices()).reshape(-1, 8),
    ('data', 'model')
)

def create_sharded_batch(key, batch_size, input_size, mesh):
    x = jax.random.normal(key, (batch_size, input_size))
    # Shard the batch dimension across the 'data' mesh axis
    sharding = NamedSharding(mesh, P('data', None))
    return jax.device_put(x, sharding)
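A quick way to confirm the layout is to build one batch and inspect its sharding. This sketch assumes the mesh above (device count divisible by 8) and a batch size divisible by the 'data' axis size; visualize_array_sharding also requires the optional rich package.
python
import jax

key = jax.random.PRNGKey(0)
batch = create_sharded_batch(key, batch_size=1024, input_size=784, mesh=mesh)
print(batch.sharding)                      # NamedSharding with PartitionSpec('data', None)
jax.debug.visualize_array_sharding(batch)  # shows which device holds which slice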
Large-Scale Training #
python
import jax
import jax.numpy as jnp

def large_scale_train_step(params, batch, mesh):
    # With sharded inputs (see create_sharded_batch above), jit lets the
    # compiler propagate the sharding through the whole computation
    @jax.jit
    def step(p, x, y):
        def loss_fn(p):
            pred = forward(p, x)
            return jnp.mean((pred - y) ** 2)
        grads = jax.grad(loss_fn)(p)
        return jax.tree_util.tree_map(lambda p, g: p - 0.001 * g, p, grads)

    x, y = batch
    return step(params, x, y)
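The step above shards only the data. If you also want to shard the parameters over the 'model' mesh axis, one option (a sketch, assuming the dict-of-layers parameter structure from the earlier sections and output dimensions divisible by the 'model' axis size) is to place each weight with a NamedSharding before calling the step:
python
import jax
from jax.sharding import NamedSharding, PartitionSpec as P

def shard_params(params, mesh):
    # Split each weight matrix along its output dimension over the 'model' axis;
    # keep the biases replicated on every device
    w_sharding = NamedSharding(mesh, P(None, 'model'))
    b_sharding = NamedSharding(mesh, P())
    return [
        {'w': jax.device_put(layer['w'], w_sharding),
         'b': jax.device_put(layer['b'], b_sharding)}
        for layer in params
    ]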
Performance Optimization #
Memory Management #
python
import os

# Limit the fraction of accelerator memory preallocated by the XLA client.
# Note: this flag mainly affects the GPU allocator; TPU memory is managed by
# the runtime. Set it before importing jax.
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.9'

import jax
Batch Size Optimization #
python
import jax
import jax.numpy as jnp

def optimal_batch_size(model_size, memory_gb):
    # Rough heuristic: reserve ~30% of device memory for activations and
    # assume per-example activation memory scales with the parameter count
    bytes_per_param = 4
    available_memory = memory_gb * 1e9
    model_memory = model_size * bytes_per_param
    activation_memory = available_memory * 0.3
    return max(1, int(activation_memory / model_memory))

batch_size = optimal_batch_size(model_size=1e9, memory_gb=16)
print(f"Suggested batch size: {batch_size}")
Mixed Precision #
python
import jax
import jax.numpy as jnp

def mixed_precision_forward(params, x):
    # Compute in bfloat16 (the TPU's native matmul type), return float32
    x = x.astype(jnp.bfloat16)
    for w, b in params:
        w = w.astype(jnp.bfloat16)
        b = b.astype(jnp.bfloat16)
        x = jnp.dot(x, w) + b
        x = jax.nn.relu(x)
    return x.astype(jnp.float32)

params = [
    (jax.random.normal(jax.random.PRNGKey(i), (1024, 1024)),
     jnp.zeros(1024))
    for i in range(10)
]
x = jax.random.normal(jax.random.PRNGKey(0), (32, 1024))

output = mixed_precision_forward(params, x)
print(f"Output shape: {output.shape}")
print(f"Output dtype: {output.dtype}")
TPU Best Practices #
1. Use pmap or pjit #
python
import jax

# Wrap the single-device train_step defined earlier so that each TPU core
# processes one slice of the batch
@jax.pmap
def tpu_parallel_step(params, batch):
    x, y = batch
    return train_step(params, x, y)
2. Maximize the Batch Size #
python
import jax.numpy as jnp

def create_large_batch(data_loader, target_batch_size):
    # Accumulate small batches from the loader until the target size is reached
    batches = []
    current_size = 0
    for batch in data_loader:
        batches.append(batch)
        current_size += batch[0].shape[0]
        if current_size >= target_batch_size:
            break
    x = jnp.concatenate([b[0] for b in batches], axis=0)
    y = jnp.concatenate([b[1] for b in batches], axis=0)
    return x[:target_batch_size], y[:target_batch_size]
3. Prefetch Data #
python
import jax

def prefetch_to_tpu(batch, device):
    # Start the host-to-device transfer early so it can overlap with compute
    return jax.device_put(batch, device)
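Because jax.device_put is asynchronous, a simple way to get overlap in practice is to keep one batch already in flight. The generator below is a minimal sketch; the data_loader and device names are placeholders, not part of the original text.
python
import jax

def prefetch_iterator(data_loader, device, lookahead=1):
    # Keep `lookahead` batches already transferred to the device
    queue = []
    for batch in data_loader:
        queue.append(jax.device_put(batch, device))
        if len(queue) > lookahead:
            yield queue.pop(0)
    # Drain the remaining batches at the end of the epoch
    for batch in queue:
        yield batch

# Usage (assumes train_step and a data_loader exist):
# device = jax.devices()[0]
# for x, y in prefetch_iterator(data_loader, device):
#     params, loss = train_step(params, x, y)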
Next Steps #
Now that you've covered TPU acceleration, continue with Custom Operations to learn how to extend JAX!
Last updated: 2026-04-04