TPU Acceleration #

Overview #

TPU (Tensor Processing Unit) is an accelerator designed by Google specifically for machine learning. JAX has native TPU support and can run on TPUs seamlessly.

TPU Basics #

Checking for TPUs #

python
import jax

# jax.devices('tpu') raises an error if no TPU backend is present
print(f"TPU devices: {jax.devices('tpu')}")
print(f"Number of TPUs: {len(jax.devices('tpu'))}")
print(f"Default backend: {jax.default_backend()}")

TPU Initialization #

python
import jax

# The default backend reports 'tpu' when JAX has found TPU devices
if jax.default_backend() == 'tpu':
    print("TPU is available")
else:
    print("No TPU available, falling back to CPU/GPU")

Colab TPU Setup #

python
import jax
import jax.tools.colab_tpu

# Needed only on the legacy Colab TPU runtime; on newer TPU VM runtimes
# (and recent JAX releases) this module may be absent and no setup call is required
jax.tools.colab_tpu.setup_tpu()

print(f"TPU devices: {jax.devices()}")

TPU Architecture #

text
┌──────────────────────────────────────────────────┐
│                 TPU Architecture                 │
├──────────────────────────────────────────────────┤
│                                                  │
│  TPU v2                                          │
│  ─── 8 GB HBM                                    │
│  ─── 180 TFLOPS (bf16)                           │
│  ─── Suited to research and development          │
│                                                  │
│  TPU v3                                          │
│  ─── 16 GB HBM                                   │
│  ─── 420 TFLOPS (bf16)                           │
│  ─── Suited to large-scale training              │
│                                                  │
│  TPU v4                                          │
│  ─── 32 GB HBM                                   │
│  ─── 275 TFLOPS (bf16)                           │
│  ─── Highest performance                         │
│                                                  │
│  TPU Pod                                         │
│  ─── Multiple TPUs interconnected                │
│  ─── High-speed interconnect                     │
│  ─── Very large-scale training                   │
│                                                  │
└──────────────────────────────────────────────────┘
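
The exact memory and compute figures depend on the TPU generation and topology you are allocated, so it helps to query them at runtime. A small sketch; treat `memory_stats()` as an assumption, since it is only exposed on recent jaxlib versions and some backends:

python
import jax

# Inspect which TPU generation is attached and how many devices there are
for d in jax.devices():
    print(d.id, d.platform, d.device_kind)

# Per-device memory statistics, guarded because not every backend provides them
dev = jax.devices()[0]
stats = dev.memory_stats() if hasattr(dev, "memory_stats") else None
if stats:
    print(f"bytes_limit: {stats.get('bytes_limit')}, bytes_in_use: {stats.get('bytes_in_use')}")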

TPU Training #

Basic TPU Training #

python
import jax
import jax.numpy as jnp
import jax.nn as nn

def init_params(key, layer_sizes):
    params = []
    for in_size, out_size in layer_sizes:
        key, w_key, b_key = jax.random.split(key, 3)
        w = jax.random.normal(w_key, (in_size, out_size)) * 0.01
        b = jnp.zeros(out_size)
        params.append({'w': w, 'b': b})
    return params

def forward(params, x):
    for layer in params[:-1]:
        x = nn.relu(jnp.dot(x, layer['w']) + layer['b'])
    return jnp.dot(x, params[-1]['w']) + params[-1]['b']

def loss_fn(params, x, y):
    pred = forward(params, x)
    return jnp.mean((pred - y) ** 2)

@jax.jit
def train_step(params, x, y, lr=0.001):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

key = jax.random.PRNGKey(0)
params = init_params(key, [(784, 256), (256, 10)])

x = jax.random.normal(key, (1024, 784))
y = jax.random.normal(key, (1024, 10))

for step in range(100):
    params, loss = train_step(params, x, y)
    if step % 20 == 0:
        print(f"Step {step}: loss={loss:.4f}")

TPU Data Parallelism #

python
import jax
import jax.numpy as jnp
from functools import partial

# axis_name must be given to pmap so that pmean can reduce over it
@partial(jax.pmap, axis_name='device')
def tpu_train_step(params, batch):
    x, y = batch

    def loss_fn(p):
        pred = forward(p, x)   # `forward` as defined above
        return jnp.mean((pred - y) ** 2)

    grads = jax.grad(loss_fn)(params)

    # Average gradients across all TPU devices
    grads = jax.lax.pmean(grads, 'device')

    new_params = jax.tree_util.tree_map(
        lambda p, g: p - 0.001 * g,
        params, grads
    )

    return new_params

num_devices = jax.device_count()
# Replicate the parameters from the previous example onto every device
params = jax.device_put_replicated(params, jax.devices())

# Leading axis = number of devices; each device gets a 128-example shard
batch_x = jax.random.normal(key, (num_devices, 128, 784))
batch_y = jax.random.normal(key, (num_devices, 128, 10))

for step in range(100):
    params = tpu_train_step(params, (batch_x, batch_y))

TPU Pod #

Pod Training #

python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from jax.sharding import NamedSharding

num_devices = jax.device_count()
# jax.devices() returns a list, so wrap it in a NumPy array before reshaping;
# this assumes the device count is a multiple of 8
mesh = Mesh(
    np.array(jax.devices()).reshape(-1, 8),
    ('data', 'model')
)

def create_sharded_batch(key, batch_size, input_size, mesh):
    x = jax.random.normal(key, (batch_size, input_size))
    # Shard the batch dimension across the 'data' mesh axis
    sharding = NamedSharding(mesh, P('data', None))
    return jax.device_put(x, sharding)
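
A quick usage sketch for the helper above (it assumes the `mesh` just built, so the batch size must divide evenly across the 'data' axis; `jax.debug.visualize_array_sharding` only prints how the rows are distributed):

python
import jax

key = jax.random.PRNGKey(0)
batch = create_sharded_batch(key, batch_size=1024, input_size=784, mesh=mesh)

print(batch.sharding)                       # NamedSharding over the 'data' axis
jax.debug.visualize_array_sharding(batch)   # per-device layout of the batch rows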

Large-Scale Training #

python
import jax
import jax.numpy as jnp

def large_scale_train_step(params, batch, mesh):
    # Note: defining the jitted `step` inside this wrapper re-creates it on every
    # call, which forces recompilation; in real code, jit once at module level.
    @jax.jit
    def step(p, x, y):
        def loss_fn(params):
            pred = forward(params, x)
            return jnp.mean((pred - y) ** 2)

        grads = jax.grad(loss_fn)(p)
        return jax.tree_util.tree_map(lambda p, g: p - 0.001 * g, p, grads)

    x, y = batch
    return step(params, x, y)
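
The wrapper above never actually uses `mesh`. Below is a minimal sketch of how the batch sharding from the previous section can drive a jitted step; it assumes the un-replicated `params`, `forward`, and `mesh` defined earlier, and a batch size that splits evenly across the 'data' axis:

python
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

# Shard the batch dimension over the 'data' axis of the mesh defined above
data_sharding = NamedSharding(mesh, P('data', None))

@jax.jit
def sharded_step(p, x, y):
    def loss_fn(params):
        pred = forward(params, x)
        return jnp.mean((pred - y) ** 2)

    grads = jax.grad(loss_fn)(p)
    return jax.tree_util.tree_map(lambda w, g: w - 0.001 * g, p, grads)

key = jax.random.PRNGKey(0)
x = jax.device_put(jax.random.normal(key, (1024, 784)), data_sharding)
y = jax.device_put(jax.random.normal(key, (1024, 10)), data_sharding)

# XLA inserts the cross-device reductions needed for the gradients
params = sharded_step(params, x, y)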

Performance Optimization #

Memory Management #

python
import os

# Must be set before JAX is imported; controls the fraction of device memory
# the XLA client pre-allocates (documented primarily for GPU backends)
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.9'

import jax

Batch Size Optimization #

python
import jax
import jax.numpy as jnp

def optimal_batch_size(model_size, memory_gb):
    # Very rough heuristic: assume fp32 parameters and that each example's
    # activations cost roughly as much memory as one copy of the parameters
    bytes_per_param = 4
    available_memory = memory_gb * 1e9

    model_memory = model_size * bytes_per_param

    # Reserve roughly 30% of device memory for activations
    activation_memory = available_memory * 0.3

    return max(1, int(activation_memory / model_memory))

batch_size = optimal_batch_size(model_size=1e9, memory_gb=16)
print(f"Suggested batch size: {batch_size}")

Mixed Precision #

python
import jax
import jax.numpy as jnp

def mixed_precision_forward(params, x):
    x = x.astype(jnp.bfloat16)
    
    for w, b in params:
        w = w.astype(jnp.bfloat16)
        # Keep the bias in bfloat16 too, so the addition does not promote back to float32
        b = b.astype(jnp.bfloat16)
        x = jnp.dot(x, w) + b
        x = jax.nn.relu(x)
    
    return x.astype(jnp.float32)

params = [
    (jax.random.normal(jax.random.PRNGKey(i), (1024, 1024)),
     jnp.zeros(1024))
    for i in range(10)
]

x = jax.random.normal(jax.random.PRNGKey(0), (32, 1024))
output = mixed_precision_forward(params, x)
print(f"输出形状: {output.shape}")
print(f"输出类型: {output.dtype}")

TPU Best Practices #

1. Use pmap or pjit #

python
import jax

@jax.pmap
def tpu_parallel_step(params, batch):
    x, y = batch
    # `train_step` as defined in the basic training example
    return train_step(params, x, y)
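
In recent JAX releases, pjit's functionality lives in jax.jit itself via in_shardings/out_shardings. A minimal sketch with explicit shardings; the 1-D mesh and the `predict` function are constructed just for this example, and the batch size is assumed to divide evenly by the device count:

python
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1-D mesh over all available devices, built only for this example
mesh_1d = Mesh(np.array(jax.devices()), ('data',))
batch_sharding = NamedSharding(mesh_1d, P('data'))
replicated = NamedSharding(mesh_1d, P())

# pjit-style: declare how inputs and the output are laid out across the mesh
@partial(jax.jit, in_shardings=(replicated, batch_sharding), out_shardings=batch_sharding)
def predict(w, x):
    return jnp.dot(x, w)

w = jnp.ones((784, 10))
x = jax.device_put(jnp.ones((1024, 784)), batch_sharding)
print(predict(w, x).sharding)  # output rows stay sharded along 'data'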

2. Maximize Batch Size #

python
import jax.numpy as jnp

def create_large_batch(data_loader, target_batch_size):
    # Accumulate loader batches until the requested global batch size is reached
    batches = []
    current_size = 0
    
    for batch in data_loader:
        batches.append(batch)
        current_size += batch[0].shape[0]
        
        if current_size >= target_batch_size:
            break
    
    x = jnp.concatenate([b[0] for b in batches], axis=0)
    y = jnp.concatenate([b[1] for b in batches], axis=0)
    
    return x[:target_batch_size], y[:target_batch_size]
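
If the assembled batch then goes through pmap, its leading dimension also has to be reshaped to (device_count, per_device_batch, ...). A small sketch of that step; the helper name is ours:

python
import jax

def shard_for_pmap(x, y):
    n = jax.device_count()
    # Drop the remainder so the batch splits evenly across devices
    usable = (x.shape[0] // n) * n
    x, y = x[:usable], y[:usable]
    return (x.reshape(n, -1, *x.shape[1:]),
            y.reshape(n, -1, *y.shape[1:]))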

3. Prefetch Data #

python
import jax
import jax.numpy as jnp

def prefetch_to_tpu(batch, device):
    # jax.device_put dispatches asynchronously, so the host-to-TPU copy can
    # overlap with other host-side work (e.g. preparing the next batch)
    return jax.device_put(batch, device)
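
A minimal prefetching loop built on that helper; this is a sketch, and `data_loader` is assumed to be any iterable of (x, y) NumPy batches:

python
import jax

def prefetch_iterator(data_loader, device, lookahead=2):
    # Keep `lookahead` batches already transferred to the TPU while the
    # current training step is still running
    queue = []
    for batch in data_loader:
        queue.append(prefetch_to_tpu(batch, device))
        if len(queue) > lookahead:
            yield queue.pop(0)
    # Drain whatever is left in the queue
    yield from queue

# Usage (hypothetical data_loader):
# device = jax.devices()[0]
# for x, y in prefetch_iterator(data_loader, device):
#     params, loss = train_step(params, x, y)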

Next Steps #

Now that you have TPU acceleration under your belt, continue with Custom Operations to learn how to extend JAX!

Last updated: 2026-04-04