JIT 编译 #
什么是 JIT? #
JIT(Just-In-Time)编译是 JAX 提供的性能优化功能,它通过 XLA(Accelerated Linear Algebra)编译器将 Python 函数编译成高效的机器码,显著提升执行速度。
JIT 的优势 #
text
┌─────────────────────────────────────────────────────────────┐
│ JIT 编译优势 │
├─────────────────────────────────────────────────────────────┤
│ │
│ ✅ 性能提升 │
│ - 编译优化 │
│ - 算子融合 │
│ - 内存优化 │
│ │
│ ✅ 硬件加速 │
│ - GPU 并行 │
│ - TPU 加速 │
│ - 向量化 │
│ │
│ ✅ 代码简化 │
│ - 一行代码启用 │
│ - 无需修改逻辑 │
│ │
└─────────────────────────────────────────────────────────────┘
基本用法 #
装饰器方式 #
python
import jax
import jax.numpy as jnp


@jax.jit
def fast_function(x):
    """Sum of squares, compiled to XLA via the @jax.jit decorator."""
    squared = x * x
    return jnp.sum(squared)


x = jnp.array([1.0, 2.0, 3.0])
result = fast_function(x)
print(f"结果: {result}")
函数调用方式 #
python
import jax
import jax.numpy as jnp


def slow_function(x):
    """Plain (uncompiled) sum of squares."""
    return jnp.sum(x * x)


# jax.jit may also be applied as an ordinary function call instead of a
# decorator; the wrapped function is compiled on first use.
fast_function = jax.jit(slow_function)

x = jnp.array([1.0, 2.0, 3.0])
result = fast_function(x)
print(f"结果: {result}")
带参数的 JIT #
python
import jax
import jax.numpy as jnp


@jax.jit
def train_step(params, x, y):
    """One SGD step on a linear least-squares loss.

    Args:
        params: weight vector, shape (features,).
        x: input batch, shape (batch, features).
        y: targets, shape (batch,).

    Returns:
        Updated parameters after one gradient step (learning rate 0.01).
    """
    def loss_fn(p):
        predict = jnp.dot(x, p)
        return jnp.mean((predict - y) ** 2)

    grads = jax.grad(loss_fn)(params)
    # jax.tree_map was deprecated and removed; jax.tree_util.tree_map is
    # the stable public API for mapping over pytrees.
    new_params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)
    return new_params


params = jnp.array([1.0, 2.0, 3.0])
x = jnp.array([[1, 2, 3], [4, 5, 6]])
y = jnp.array([6, 15])
new_params = train_step(params, x, y)
print(f"更新后参数: {new_params}")
编译过程 #
第一次调用 #
python
import jax
import jax.numpy as jnp
import time


@jax.jit
def compute(x):
    # 1000x1000 matmul; compiled into a single XLA computation on first call.
    return jnp.dot(x, x.T)


x = jnp.ones((1000, 1000))

# First call: JAX traces the function and XLA compiles it, so this timing
# includes compilation overhead.
start = time.time()
result = compute(x)
# JAX dispatches asynchronously; block so the timing measures the real work.
result.block_until_ready()
print(f"第一次调用(编译): {time.time() - start:.4f}s")

# Second call with identical shapes/dtypes hits the compilation cache.
start = time.time()
result = compute(x)
result.block_until_ready()
print(f"第二次调用(缓存): {time.time() - start:.4f}s")
编译缓存 #
python
import jax
import jax.numpy as jnp


@jax.jit
def process(x):
    return x + 1


x_int = jnp.array([1, 2, 3])
x_float = jnp.array([1.0, 2.0, 3.0])

# The compilation cache is keyed on argument shapes and dtypes:
result1 = process(x_int)    # first int32 call: traces and compiles
result2 = process(x_float)  # new dtype (float32): triggers a recompile
result3 = process(x_int)    # same shape/dtype as result1: cache hit, no compile
静态参数 #
static_argnums #
python
import jax
import jax.numpy as jnp
from functools import partial


# jax.jit takes the function as its first positional argument, so
# `@jax.jit(static_argnums=(1,))` raises a TypeError. Wrap jit with
# functools.partial to pass static-argument options as a decorator.
@partial(jax.jit, static_argnums=(1,))
def process(x, use_relu):
    """Apply ReLU when `use_relu` is True, otherwise return x unchanged.

    `use_relu` is static: the Python `if` is resolved at trace time, and
    each distinct value of use_relu gets its own compiled version.
    """
    if use_relu:
        return jnp.maximum(x, 0)
    else:
        return x


x = jnp.array([-1, 2, -3, 4])
result_relu = process(x, True)
result_linear = process(x, False)
print(f"ReLU: {result_relu}")
print(f"Linear: {result_linear}")
static_argnames #
python
import jax
import jax.numpy as jnp
@jax.jit(static_argnames=['mode'])
def process(x, mode='relu'):
if mode == 'relu':
return jnp.maximum(x, 0)
elif mode == 'sigmoid':
return jax.nn.sigmoid(x)
else:
return x
x = jnp.array([-1, 2, -3, 4])
result_relu = process(x, mode='relu')
result_sigmoid = process(x, mode='sigmoid')
print(f"ReLU: {result_relu}")
print(f"Sigmoid: {result_sigmoid}")
追踪与编译 #
追踪过程 #
python
import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    # Plain Python print runs only while JAX traces the function; at that
    # point `x` is an abstract Tracer, not a concrete array.
    print(f"追踪中: x = {x}")
    return x + 1


print("第一次调用:")
result = f(jnp.array([1, 2, 3]))  # triggers tracing: the print above fires

print("\n第二次调用:")
result = f(jnp.array([4, 5, 6]))  # same shape/dtype: cached, no re-trace, no print
抽象值 #
python
import jax
import jax.numpy as jnp


def inspect_tracer(x):
    # Under make_jaxpr, `x` is an abstract tracer: it carries shape and
    # dtype information but no concrete values.
    print(f"类型: {type(x)}")
    print(f"形状: {x.shape}")
    print(f"数据类型: {x.dtype}")
    return x + 1


# Trace the function into its jaxpr representation (the prints run once,
# during tracing).
jax.make_jaxpr(inspect_tracer)(jnp.array([1, 2, 3]))
JIT 限制 #
条件语句限制 #
python
import jax
import jax.numpy as jnp


@jax.jit
def bad_condition(x):
    # Fails if called with an array under jit: `x > 0` is a traced value,
    # and Python's `if` needs a concrete boolean (shown here as a
    # counter-example; it is intentionally never called).
    if x > 0:
        return x
    else:
        return -x


@jax.jit
def good_condition(x):
    # jnp.where performs the element-wise selection inside the traced
    # computation, so it works under jit.
    return jnp.where(x > 0, x, -x)


x = jnp.array([1, -2, 3, -4])
result = good_condition(x)
print(f"结果: {result}")
循环限制 #
python
import jax
import jax.numpy as jnp


@jax.jit
def bad_loop(x, n):
    # Fails when `n` is a traced argument: range(n) needs a concrete
    # Python int at trace time (counter-example; intentionally not called).
    for i in range(n):
        x = x + 1
    return x


@jax.jit
def good_loop(x, n):
    # fori_loop expresses the loop as a structured XLA primitive, so the
    # trip count may be a traced value.
    def body(i, val):
        return val + 1

    return jax.lax.fori_loop(0, n, body, x)


x = jnp.array([1, 2, 3])
result = good_loop(x, 5)
print(f"结果: {result}")
副作用限制 #
python
import jax
import jax.numpy as jnp

counter = 0


@jax.jit
def bad_side_effect(x):
    # Side effects run only during tracing: `counter` is incremented once
    # at compile time, not on every call — the effect is silently lost on
    # cached calls.
    global counter
    counter += 1
    return x + 1


@jax.jit
def good_pure(x):
    # Pure function: output depends only on its inputs, safe under jit.
    return x + 1
调试技巧 #
禁用 JIT #
python
import jax
import jax.numpy as jnp

# Globally disable jit so decorated functions run eagerly — Python-level
# prints and debuggers then see concrete values. NOTE: this affects the
# entire process, not just this function.
jax.config.update('jax_disable_jit', True)


@jax.jit
def debug_function(x):
    # With jit disabled, this print runs on every call with real values.
    print(f"调试: x = {x}")
    return x + 1


x = jnp.array([1, 2, 3])
result = debug_function(x)
jax.debug.print #
python
import jax
import jax.numpy as jnp


@jax.jit
def debug_function(x):
    # jax.debug.print works inside jit: it prints runtime values on every
    # call (unlike Python's print, which fires only at trace time).
    jax.debug.print("x = {}", x)
    y = x + 1
    jax.debug.print("y = {}", y)
    return y


x = jnp.array([1, 2, 3])
result = debug_function(x)
jax.debug.breakpoint #
python
import jax
import jax.numpy as jnp


@jax.jit
def debug_function(x):
    y = x + 1
    # Pauses execution inside the compiled function and drops into an
    # interactive debugger-style prompt — for manual debugging sessions
    # only; do not leave this in production code.
    jax.debug.breakpoint()
    return y


x = jnp.array([1, 2, 3])
result = debug_function(x)
性能优化 #
算子融合 #
python
import jax
import jax.numpy as jnp
import time


def slow_fusion(x):
    # Without jit, each step materializes a full intermediate array.
    a = x * 2
    b = a + 1
    c = b ** 2
    return c


@jax.jit
def fast_fusion(x):
    # Under jit, XLA fuses these elementwise ops into a single kernel,
    # avoiding the intermediate arrays.
    a = x * 2
    b = a + 1
    c = b ** 2
    return c


# NOTE(review): a 10000x10000 float32 array is ~400 MB, and the unfused
# version allocates several temporaries of that size — consider a smaller
# array on memory-constrained machines.
x = jnp.ones((10000, 10000))

start = time.time()
for _ in range(10):
    result = slow_fusion(x)
    result.block_until_ready()
print(f"无 JIT: {time.time() - start:.4f}s")

start = time.time()
for _ in range(10):
    result = fast_fusion(x)
    result.block_until_ready()
print(f"有 JIT: {time.time() - start:.4f}s")
内存优化 #
python
import jax
import jax.numpy as jnp


@jax.jit
def memory_efficient(x):
    # einsum expresses x @ x.T as a single contraction; under jit, XLA can
    # avoid materializing the explicit transpose.
    return jnp.einsum('ij,jk->ik', x, x.T)


x = jnp.ones((1000, 1000))
result = memory_efficient(x)
编译选项 #
python
import jax
import jax.numpy as jnp


@jax.jit
def compute(x):
    return jnp.dot(x, x.T)


x = jnp.ones((100, 100))
# Inspect the traced computation (jaxpr) that jit would hand to XLA.
jax.make_jaxpr(compute)(x)
实际应用 #
训练循环 #
python
import jax
import jax.numpy as jnp


@jax.jit
def train_step(params, x, y):
    """One SGD step (learning rate 0.01) for a linear model x @ w + b.

    Args:
        params: (w, b) tuple; w has shape (in, out), b has shape (out,).
        x: input batch, shape (batch, in).
        y: target batch, shape (batch, out).

    Returns:
        Updated (w, b) tuple.
    """
    def loss_fn(p):
        w, b = p
        pred = jnp.dot(x, w) + b
        return jnp.mean((pred - y) ** 2)

    grads = jax.grad(loss_fn)(params)
    # jax.tree_map was deprecated and removed; jax.tree_util.tree_map is
    # the stable public API for mapping over pytrees.
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - 0.01 * g,
        params,
        grads
    )
    return new_params


key = jax.random.PRNGKey(0)
params = (
    jax.random.normal(key, (10, 5)),
    jax.random.normal(key, (5,))
)
x = jax.random.normal(key, (32, 10))
y = jax.random.normal(key, (32, 5))

for epoch in range(100):
    params = train_step(params, x, y)
推理加速 #
python
import jax
import jax.numpy as jnp


@jax.jit
def predict(params, x):
    """Forward pass of a small MLP, applying ReLU after every layer."""
    activations = x
    for weight, bias in params:
        pre_activation = jnp.dot(activations, weight) + bias
        activations = jnp.maximum(pre_activation, 0)
    return activations


params = [
    (jax.random.normal(jax.random.PRNGKey(0), (10, 20)),
     jax.random.normal(jax.random.PRNGKey(1), (20,))),
    (jax.random.normal(jax.random.PRNGKey(2), (20, 5)),
     jax.random.normal(jax.random.PRNGKey(3), (5,)))
]
x = jax.random.normal(jax.random.PRNGKey(4), (1000, 10))
predictions = predict(params, x)
常见问题 #
问题 1: 编译时间过长 #
python
import jax
import jax.numpy as jnp


@jax.jit
def large_function(x):
    # A Python loop is fully unrolled during tracing: 100 separate add ops
    # reach the compiler, inflating trace and compile time.
    result = x
    for i in range(100):
        result = result + 1
    return result


@jax.jit
def optimized_function(x):
    # Same computation collapsed into a single op — compiles quickly.
    return x + 100
问题 2: 内存不足 #
python
import os

# Cap XLA's GPU memory preallocation at 50% of device memory. The variable
# is read when JAX initializes its backend, so it must be set before any
# JAX computation runs — setting it after `import jax` and first use is
# too late to take effect. Set it before the import to be safe.
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.5'

import jax
import jax.numpy as jnp


@jax.jit
def memory_intensive(x):
    # x @ x.T: peak memory scales with the square of x's leading dimension.
    return jnp.dot(x, x.T)
下一步 #
现在你已经掌握了 JIT 编译,接下来学习 函数变换组合,了解如何组合使用各种变换!
最后更新:2026-04-04