# 性能优化
## 概述
本节介绍 JAX 性能优化的常用技巧，涵盖 JIT 编译、内存管理、并行化与性能分析，帮助你写出高效的 JAX 代码。
## JIT 编译优化
### 编译缓存
python
import jax
import jax.numpy as jnp

# float64 arrays require x64 mode. Without it, jnp truncates float64 requests to
# float32 (with a warning), so the "float64" call below would silently reuse the
# float32 compilation and the example would not show a recompile.
jax.config.update("jax_enable_x64", True)


@jax.jit
def compute(x):
    """Return x @ x.T; jit compiles (and caches) one version per input shape/dtype."""
    # Python side effects run only while tracing, so this line prints on cache
    # misses (new shape/dtype) and stays silent when a cached compilation is reused.
    print(f"编译 compute: shape={x.shape}, dtype={x.dtype}")
    return jnp.dot(x, x.T)


x_float32 = jnp.ones((100, 100), dtype=jnp.float32)
x_float64 = jnp.ones((100, 100), dtype=jnp.float64)
compute(x_float32)  # first call: trace + compile for float32
compute(x_float64)  # new dtype: trace + compile again
compute(x_float32)  # same shape/dtype as the first call: served from the cache
### 静态参数
python
import functools

import jax
import jax.numpy as jnp


# Bind the jit options with functools.partial: this decorator form works on every
# JAX version, whereas calling jax.jit with only keyword arguments is not
# supported by all releases. Argument 1 (`mode`) is static: it must be hashable,
# each distinct value gets its own compilation, and the Python `if` below is
# evaluated at trace time rather than inside the compiled program.
@functools.partial(jax.jit, static_argnums=(1,))
def process(x, mode):
    """Apply ReLU to ``x`` when ``mode == 'relu'``; otherwise return ``x`` unchanged."""
    if mode == 'relu':
        return jnp.maximum(x, 0)
    return x


x = jnp.array([-1, 2, -3, 4])
print(process(x, 'relu'))
print(process(x, 'linear'))
## 内存优化
### 内存预分配
python
import os

# Configure XLA's GPU memory behaviour through environment variables. They are
# set before `import jax` below, because JAX reads them when its backend starts.
os.environ.update(
    # Allocate GPU memory on demand instead of reserving a large pool up front.
    XLA_PYTHON_CLIENT_PREALLOCATE='false',
    # NOTE(review): per the JAX GPU-memory docs this fraction sizes the
    # *preallocated* pool; with preallocation disabled above it may have no
    # effect — confirm which of the two settings is intended.
    XLA_PYTHON_CLIENT_MEM_FRACTION='0.8',
)

import jax
### 梯度检查点
python
import jax
import jax.numpy as jnp


def large_forward(params, x):
    """Apply each weight matrix in ``params`` as a dense layer followed by ReLU."""
    for w in params:
        x = jnp.dot(x, w)
        x = jax.nn.relu(x)
    return x


@jax.checkpoint
def _layer(w, x):
    """One dense + ReLU layer whose intermediates are recomputed in the backward pass."""
    return jax.nn.relu(jnp.dot(x, w))


def memory_efficient_forward(params, x):
    """Compute the same result as ``large_forward``, checkpointing layer by layer.

    Wrapping the whole forward pass in a single remat only defers the work: the
    backward pass recomputes the entire network and still holds every layer's
    activations at once. Checkpointing each layer keeps roughly only each
    layer's input and recomputes the rest one layer at a time.
    """
    for w in params:
        x = _layer(w, x)
    return x


params = [jax.random.normal(jax.random.PRNGKey(i), (1024, 1024)) * 0.01
          for i in range(50)]
x = jax.random.normal(jax.random.PRNGKey(100), (32, 1024))
output = memory_efficient_forward(params, x)
print(f"输出形状: {output.shape}")


def loss(params, x):
    """Scalar loss used to drive the backward pass."""
    return jnp.sum(memory_efficient_forward(params, x))


# Rematerialization only changes what is stored for differentiation, so its
# memory savings show up here, not in the plain forward call above.
grads = jax.grad(loss)(params, x)
print(f"梯度数量: {len(grads)}")
## 并行化
### vmap 批处理
python
import jax
import jax.numpy as jnp


def single_process(x):
    """Return the sum of squares of one (unbatched) sample."""
    squared = x ** 2
    return jnp.sum(squared)


# vmap maps the per-sample function over the leading axis (its default in_axes=0);
# jit then compiles the batched function once for this input shape.
vectorized = jax.vmap(single_process)
batch_process = jax.jit(vectorized)

key = jax.random.PRNGKey(0)
batch_x = jax.random.normal(key, (10000, 100))
result = batch_process(batch_x)
print(f"结果形状: {result.shape}")
### pmap 多设备
python
import jax
import jax.numpy as jnp


@jax.pmap
def parallel_compute(x):
    """Return the sum of squares of one device's shard (one copy runs per device)."""
    return jnp.sum(x ** 2)


# pmap maps over the devices attached to *this* process, so the leading axis must
# match jax.local_device_count(). jax.device_count() counts devices on all hosts
# and would give the wrong size in multi-process runs.
# NOTE(review): newer JAX docs steer new code toward jit + sharding / shard_map;
# pmap is kept here for the tutorial.
num_devices = jax.local_device_count()
batch_x = jax.random.normal(jax.random.PRNGKey(0), (num_devices, 1000, 100))
result = parallel_compute(batch_x)
print(f"结果形状: {result.shape}")
## 性能分析
### 使用 profiler
python
import jax
import jax.numpy as jnp


@jax.jit
def compute(x):
    """Return x @ x.T — the workload being profiled."""
    return jnp.dot(x, x.T)


x = jax.random.normal(jax.random.PRNGKey(0), (1000, 1000))

# Warm up outside the trace: the first call traces and compiles `compute`, and
# that one-time cost would otherwise be mixed into the captured profile.
# (Move this call inside the `with` block if you want to profile compilation.)
compute(x).block_until_ready()

with jax.profiler.trace('/tmp/jax_trace'):
    result = compute(x)
    # JAX dispatches work asynchronously; wait inside the trace so the device
    # execution is actually recorded.
    result.block_until_ready()
### 时间测量
python
import time

import jax
import jax.numpy as jnp


@jax.jit
def compute(x):
    """Return x @ x.T — the workload being timed."""
    return jnp.dot(x, x.T)


n_iters = 100
x = jax.random.normal(jax.random.PRNGKey(0), (1000, 1000))

# Warm-up: the first call includes tracing and compilation; keep it out of the timing.
compute(x).block_until_ready()

# perf_counter is monotonic and high-resolution; time.time() is wall-clock and
# can jump if the system clock is adjusted.
start = time.perf_counter()
for _ in range(n_iters):
    result = compute(x)
# Calls return before the device finishes (async dispatch). Blocking on the last
# result makes the measurement include the device work — this assumes calls on
# one device execute in dispatch order.
result.block_until_ready()
end = time.perf_counter()
print(f"平均时间: {(end - start) / n_iters:.4f}s")
## 下一步
现在你已经掌握了 JAX 性能优化的基本技巧，接下来学习「调试技巧」，了解如何调试 JAX 代码！
最后更新：2026-04-04