性能优化 #

概述 #

本节介绍 JAX 性能优化的常用技巧，帮助你写出高效的 JAX 代码。

JIT 编译优化 #

编译缓存 #

python
import jax
import jax.numpy as jnp


@jax.jit
def compute(x):
    """Return ``x @ x.T``.

    ``jax.jit`` compiles one executable per distinct input shape/dtype and
    caches it, so repeated calls with the same signature skip compilation.
    """
    return jnp.dot(x, x.T)


# NOTE: without ``jax_enable_x64``, JAX silently truncates float64 to float32,
# so a "float64" array would hit the same cache entry as the float32 one and
# no recompilation would happen. int32 is a genuinely different dtype in the
# default configuration, so it does force a second compilation.
x_float32 = jnp.ones((100, 100), dtype=jnp.float32)
x_int32 = jnp.ones((100, 100), dtype=jnp.int32)

compute(x_float32)  # first call: trace + compile for float32[100,100]
compute(x_int32)    # new dtype: trace + compile again for int32[100,100]
compute(x_float32)  # cache hit: reuses the float32 executable

静态参数 #

python
import functools

import jax
import jax.numpy as jnp


# Bind ``static_argnums`` via functools.partial: calling
# ``jax.jit(static_argnums=...)`` without a function raises TypeError in many
# JAX releases, while the partial form works across versions.
@functools.partial(jax.jit, static_argnums=(1,))
def process(x, mode):
    """Apply the activation selected by ``mode`` to ``x``.

    ``mode`` is static: it must be hashable, the Python ``if`` below is
    resolved at trace time, and every distinct ``mode`` value gets its own
    compiled executable.

    Args:
        x: input array.
        mode: ``'relu'`` applies ``max(x, 0)``; any other value returns ``x``
            unchanged.
    """
    if mode == 'relu':
        return jnp.maximum(x, 0)
    else:
        return x


x = jnp.array([-1, 2, -3, 4])
print(process(x, 'relu'))
print(process(x, 'linear'))

内存优化 #

内存预分配 #

python
import os

# Configure the XLA client allocator through environment variables. They are
# assigned before ``import jax`` below so they are already in place when JAX
# initializes its backend.
os.environ.update(
    # Allocate device memory on demand instead of reserving a large pool up front.
    XLA_PYTHON_CLIENT_PREALLOCATE='false',
    # Fraction of device memory for the allocator. NOTE(review): the JAX docs
    # describe this as the preallocation size; confirm its effect when
    # preallocation is disabled.
    XLA_PYTHON_CLIENT_MEM_FRACTION='0.8',
)

import jax

梯度检查点 #

python
import jax
import jax.numpy as jnp


def large_forward(params, x):
    """Apply a stack of dense + ReLU layers to ``x`` (no checkpointing).

    Args:
        params: sequence of weight matrices, applied in order as ``x @ w``.
        x: input activations.

    Returns:
        The activations after the last layer.
    """
    for w in params:
        x = jnp.dot(x, w)
        x = jax.nn.relu(x)
    return x


@jax.checkpoint
def _checkpointed_layer(w, x):
    """One dense + ReLU layer whose internals are recomputed during backprop."""
    return jax.nn.relu(jnp.dot(x, w))


def memory_efficient_forward(params, x):
    """Compute the same result as ``large_forward`` with per-layer checkpointing.

    Each layer is checkpointed separately, so the backward pass keeps only the
    per-layer inputs and recomputes each layer's intermediates one layer at a
    time. Wrapping the whole network in a single ``jax.remat`` would instead
    rematerialize every layer at once during the backward pass, which does not
    meaningfully reduce peak memory.

    Checkpointing only matters under differentiation (``jax.grad`` /
    ``jax.vjp``). A plain forward call computes exactly what ``large_forward``
    computes.
    """
    for w in params:
        x = _checkpointed_layer(w, x)
    return x


def loss_fn(params, x):
    """Scalar loss used to demonstrate checkpointing under ``jax.grad``."""
    return jnp.sum(memory_efficient_forward(params, x) ** 2)


params = [jax.random.normal(jax.random.PRNGKey(i), (1024, 1024)) * 0.01
          for i in range(50)]
x = jax.random.normal(jax.random.PRNGKey(100), (32, 1024))

output = memory_efficient_forward(params, x)
print(f"输出形状: {output.shape}")

# This is where checkpointing pays off: the gradient is computed while storing
# fewer intermediate activations than the un-checkpointed forward would need.
grads = jax.grad(loss_fn)(params, x)
print(f"梯度数量: {len(grads)}")

并行化 #

vmap 批处理 #

python
import jax
import jax.numpy as jnp


def single_process(x):
    """Reduce one sample to the sum of its squared entries."""
    squared = x ** 2
    return squared.sum()


# vmap maps single_process over the leading axis; jit then compiles the whole
# batched computation into a single XLA program.
batch_process = jax.jit(jax.vmap(single_process, in_axes=0))

key = jax.random.PRNGKey(0)
batch_x = jax.random.normal(key, shape=(10000, 100))
result = batch_process(batch_x)
print(f"结果形状: {result.shape}")

pmap 多设备 #

python
import jax
import jax.numpy as jnp


def _sum_of_squares(shard):
    """Per-device work: sum of squares of this device's slice."""
    return (shard ** 2).sum()


# pmap splits the leading axis across devices and runs one copy of the function
# per slice, so the leading axis is sized to the device count below.
parallel_compute = jax.pmap(_sum_of_squares)

num_devices = jax.device_count()
key = jax.random.PRNGKey(0)
batch_x = jax.random.normal(key, (num_devices, 1000, 100))
result = parallel_compute(batch_x)
print(f"结果形状: {result.shape}")

性能分析 #

使用 profiler #

python
import jax
import jax.numpy as jnp


@jax.jit
def compute(x):
    """Return ``x @ x.T``; it is compiled on its first call, inside the trace."""
    return jnp.dot(x, x.T)


# Work executed inside this context manager is recorded by the JAX profiler
# and written under '/tmp/jax_trace' when the block exits.
with jax.profiler.trace('/tmp/jax_trace'):
    key = jax.random.PRNGKey(0)
    x = jax.random.normal(key, (1000, 1000))
    result = compute(x)
    # JAX dispatches asynchronously; wait for the result so the computation
    # finishes before the trace is stopped.
    result.block_until_ready()

时间测量 #

python
import time

import jax
import jax.numpy as jnp


@jax.jit
def compute(x):
    """Return ``x @ x.T`` (the function being benchmarked)."""
    return jnp.dot(x, x.T)


n_iters = 100  # number of timed calls to average over

x = jax.random.normal(jax.random.PRNGKey(0), (1000, 1000))

# Warm-up: the first call traces and compiles, so keep it out of the timed region.
compute(x).block_until_ready()

# perf_counter is monotonic and high-resolution; time.time() follows the system
# clock, which can jump and has coarser resolution on some platforms.
start = time.perf_counter()
for _ in range(n_iters):
    result = compute(x)
    # JAX dispatches asynchronously; block so each iteration measures the
    # actual device work rather than only the dispatch.
    result.block_until_ready()
end = time.perf_counter()

print(f"平均时间: {(end - start) / n_iters:.4f}s")

下一步 #

现在你已经掌握了性能优化的基本技巧，接下来学习「调试技巧」，了解如何调试 JAX 代码！

最后更新:2026-04-04