JIT Compilation #

What Is JIT? #

JIT (Just-In-Time) compilation is JAX's main performance feature: it uses the XLA (Accelerated Linear Algebra) compiler to turn Python functions into efficient machine code, often speeding up execution significantly.

Benefits of JIT #

text
┌─────────────────────────────────────────────────────────────┐
│                 Benefits of JIT Compilation                 │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  ✅ Performance gains                                       │
│     - Compiler optimizations                                │
│     - Operator fusion                                       │
│     - Memory optimization                                   │
│                                                             │
│  ✅ Hardware acceleration                                   │
│     - GPU parallelism                                       │
│     - TPU acceleration                                      │
│     - Vectorization                                         │
│                                                             │
│  ✅ Code simplicity                                         │
│     - Enabled with one line                                 │
│     - No changes to your logic                              │
│                                                             │
└─────────────────────────────────────────────────────────────┘

Basic Usage #

As a Decorator #

python
import jax
import jax.numpy as jnp

@jax.jit
def fast_function(x):
    return jnp.sum(x ** 2)

x = jnp.array([1.0, 2.0, 3.0])
result = fast_function(x)
print(f"结果: {result}")

As a Function Call #

python
import jax
import jax.numpy as jnp

def slow_function(x):
    return jnp.sum(x ** 2)

fast_function = jax.jit(slow_function)

x = jnp.array([1.0, 2.0, 3.0])
result = fast_function(x)
print(f"结果: {result}")

JIT with Parameters #

python
import jax
import jax.numpy as jnp

@jax.jit
def train_step(params, x, y):
    def loss_fn(p):
        predict = jnp.dot(x, p)
        return jnp.mean((predict - y) ** 2)
    
    grads = jax.grad(loss_fn)(params)
    new_params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)
    return new_params

params = jnp.array([1.0, 2.0, 3.0])
x = jnp.array([[1, 2, 3], [4, 5, 6]])
y = jnp.array([6, 15])

new_params = train_step(params, x, y)
print(f"更新后参数: {new_params}")

The Compilation Process #

First Call #

python
import jax
import jax.numpy as jnp
import time

@jax.jit
def compute(x):
    return jnp.dot(x, x.T)

x = jnp.ones((1000, 1000))

start = time.time()
result = compute(x)
result.block_until_ready()  # wait for async dispatch before stopping the timer
print(f"First call (includes compilation): {time.time() - start:.4f}s")

start = time.time()
result = compute(x)
result.block_until_ready()
print(f"Second call (cached): {time.time() - start:.4f}s")

Compilation Cache #

python
import jax
import jax.numpy as jnp

@jax.jit
def process(x):
    return x + 1

x_int = jnp.array([1, 2, 3])
x_float = jnp.array([1.0, 2.0, 3.0])

result1 = process(x_int)    # traces and compiles for int32
result2 = process(x_float)  # new dtype: traces and compiles again
result3 = process(x_int)    # int32 version is cached: no recompilation
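
A plain Python print placed inside the function makes this caching behavior visible, because it only executes while the function is being traced. A short sketch:

python
import jax
import jax.numpy as jnp

@jax.jit
def process(x):
    print(f"tracing for dtype={x.dtype}")  # runs at trace time only
    return x + 1

process(jnp.array([1, 2, 3]))        # prints: traced for int32
process(jnp.array([1.0, 2.0, 3.0]))  # prints: traced again for float32
process(jnp.array([4, 5, 6]))        # silent: the int32 version is cached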

Static Arguments #

static_argnums #

python
import jax
import jax.numpy as jnp
from functools import partial

# jax.jit takes the function as its first argument, so use functools.partial
# to pre-bind static_argnums
@partial(jax.jit, static_argnums=(1,))
def process(x, use_relu):
    if use_relu:
        return jnp.maximum(x, 0)
    else:
        return x

x = jnp.array([-1, 2, -3, 4])

result_relu = process(x, True)
result_linear = process(x, False)

print(f"ReLU: {result_relu}")
print(f"Linear: {result_linear}")

static_argnames #

python
import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.jit, static_argnames=['mode'])
def process(x, mode='relu'):
    if mode == 'relu':
        return jnp.maximum(x, 0)
    elif mode == 'sigmoid':
        return jax.nn.sigmoid(x)
    else:
        return x

x = jnp.array([-1, 2, -3, 4])

result_relu = process(x, mode='relu')
result_sigmoid = process(x, mode='sigmoid')

print(f"ReLU: {result_relu}")
print(f"Sigmoid: {result_sigmoid}")

Tracing and Compilation #

The Tracing Process #

python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    print(f"tracing: x = {x}")  # a plain Python print runs only during tracing
    return x + 1

print("First call:")
result = f(jnp.array([1, 2, 3]))  # prints the tracing message (x is a tracer)

print("\nSecond call:")
result = f(jnp.array([4, 5, 6]))  # same shape and dtype: cached, no message

Abstract Values #

python
import jax
import jax.numpy as jnp

def inspect_tracer(x):
    print(f"type: {type(x)}")
    print(f"shape: {x.shape}")
    print(f"dtype: {x.dtype}")
    return x + 1

# During tracing, x is an abstract tracer carrying only shape and dtype
print(jax.make_jaxpr(inspect_tracer)(jnp.array([1, 2, 3])))

JIT Limitations #

Restrictions on Conditionals #

python
import jax
import jax.numpy as jnp

# Bad: branching on a traced value raises an error at trace time
@jax.jit
def bad_condition(x):
    if x > 0:  # a Python `if` needs a concrete bool, but x is a tracer
        return x
    else:
        return -x

# Good: jnp.where selects elementwise with no Python control flow
@jax.jit
def good_condition(x):
    return jnp.where(x > 0, x, -x)

x = jnp.array([1, -2, 3, -4])
result = good_condition(x)
print(f"结果: {result}")

Restrictions on Loops #

python
import jax
import jax.numpy as jnp

# Bad: a Python loop needs a concrete n and is unrolled at trace time
@jax.jit
def bad_loop(x, n):
    for i in range(n):  # fails if n is traced; unrolls if n is concrete
        x = x + 1
    return x

# Good: lax.fori_loop compiles the body once and loops at run time
@jax.jit
def good_loop(x, n):
    def body(i, val):
        return val + 1
    return jax.lax.fori_loop(0, n, body, x)

x = jnp.array([1, 2, 3])
result = good_loop(x, 5)
print(f"结果: {result}")

Restrictions on Side Effects #

python
import jax
import jax.numpy as jnp

counter = 0

# Bad: the global update happens at trace time only, not on every call
@jax.jit
def bad_side_effect(x):
    global counter
    counter += 1  # executed once while tracing, then never again
    return x + 1

# Good: a pure function with no side effects
@jax.jit
def good_pure(x):
    return x + 1
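
The danger is that the side effect fires only while the function is traced, so it silently stops counting after the first call. A short demonstration:

python
import jax
import jax.numpy as jnp

calls = 0

@jax.jit
def f(x):
    global calls
    calls += 1  # runs at trace time only
    return x + 1

x = jnp.array([1.0, 2.0])
f(x); f(x); f(x)
print(calls)  # 1 -- traced once, then served from the compile cache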

Debugging Tips #

Disabling JIT #

python
import jax
import jax.numpy as jnp

# Globally disable JIT so functions run eagerly and Python prints work
jax.config.update('jax_disable_jit', True)

@jax.jit
def debug_function(x):
    print(f"debug: x = {x}")  # prints concrete values since jit is off
    return x + 1

x = jnp.array([1, 2, 3])
result = debug_function(x)

# Re-enable JIT when done debugging
jax.config.update('jax_disable_jit', False)
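
Instead of flipping the global flag, jax.disable_jit() also works as a context manager, which limits the effect to a single block:

python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    print(f"x = {x}")  # visible: jit is disabled inside the block
    return x + 1

with jax.disable_jit():
    f(jnp.array([1, 2, 3]))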

jax.debug.print #

python
import jax
import jax.numpy as jnp

@jax.jit
def debug_function(x):
    jax.debug.print("x = {}", x)  # prints run-time values even under jit
    y = x + 1
    jax.debug.print("y = {}", y)
    return y

x = jnp.array([1, 2, 3])
result = debug_function(x)

jax.debug.breakpoint #

python
import jax
import jax.numpy as jnp

@jax.jit
def debug_function(x):
    y = x + 1
    jax.debug.breakpoint()  # pause here with an interactive debugger prompt
    return y

x = jnp.array([1, 2, 3])
result = debug_function(x)

Performance Optimization #

Operator Fusion #

python
import jax
import jax.numpy as jnp
import time

# Without JIT, each operation runs separately and materializes a, b, c
def slow_fusion(x):
    a = x * 2
    b = a + 1
    c = b ** 2
    return c

# With JIT, XLA fuses the three elementwise ops into a single kernel
@jax.jit
def fast_fusion(x):
    a = x * 2
    b = a + 1
    c = b ** 2
    return c

x = jnp.ones((10000, 10000))

start = time.time()
for _ in range(10):
    result = slow_fusion(x)
    result.block_until_ready()
print(f"无 JIT: {time.time() - start:.4f}s")

start = time.time()
for _ in range(10):
    result = fast_fusion(x)
    result.block_until_ready()
print(f"有 JIT: {time.time() - start:.4f}s")

Memory Optimization #

python
import jax
import jax.numpy as jnp

@jax.jit
def memory_efficient(x):
    # einsum states the contraction in one op, which XLA compiles to a fused matmul
    return jnp.einsum('ij,jk->ik', x, x.T)

x = jnp.ones((1000, 1000))
result = memory_efficient(x)
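
When an input buffer is no longer needed after the call, donate_argnums lets XLA reuse its memory for the output instead of allocating a new buffer. A hedged sketch (donation mainly takes effect on GPU/TPU backends; on CPU it is ignored with a warning):

python
import jax
import jax.numpy as jnp
from functools import partial

# Donating argument 0 tells XLA it may overwrite x's buffer with the result
@partial(jax.jit, donate_argnums=(0,))
def update(x, delta):
    return x + delta

x = jnp.ones((1000, 1000))
x = update(x, 0.1)  # the donated (old) x must not be used after this call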

Inspecting the Compilation #

python
import jax
import jax.numpy as jnp

@jax.jit
def compute(x):
    return jnp.dot(x, x.T)

x = jnp.ones((100, 100))

# Print the jaxpr that jit traces before handing the program to XLA
print(jax.make_jaxpr(compute)(x))
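
Beyond the jaxpr, the ahead-of-time API exposes the lowered program that is actually handed to XLA; the textual form is useful for checking what the compiler sees. A sketch (the exact output format varies across JAX versions):

python
import jax
import jax.numpy as jnp

def compute(x):
    return jnp.dot(x, x.T)

x = jnp.ones((100, 100))
lowered = jax.jit(compute).lower(x)
print(lowered.as_text()[:500])  # first lines of the StableHLO program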

Practical Applications #

Training Loop #

python
import jax
import jax.numpy as jnp

@jax.jit
def train_step(params, x, y):
    def loss_fn(p):
        w, b = p
        pred = jnp.dot(x, w) + b
        return jnp.mean((pred - y) ** 2)
    
    grads = jax.grad(loss_fn)(params)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - 0.01 * g,
        params,
        grads
    )
    return new_params

key = jax.random.PRNGKey(0)
k1, k2, k3, k4 = jax.random.split(key, 4)  # use distinct keys for each draw
params = (
    jax.random.normal(k1, (10, 5)),
    jax.random.normal(k2, (5,))
)

x = jax.random.normal(k3, (32, 10))
y = jax.random.normal(k4, (32, 5))

for epoch in range(100):
    params = train_step(params, x, y)
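
To monitor the loss without running a second forward pass, jax.value_and_grad returns the loss together with the gradients; a hedged variant of the step above:

python
import jax
import jax.numpy as jnp

@jax.jit
def train_step_with_loss(params, x, y):
    def loss_fn(p):
        w, b = p
        pred = jnp.dot(x, w) + b
        return jnp.mean((pred - y) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    new_params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)
    return new_params, loss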

Faster Inference #

python
import jax
import jax.numpy as jnp

@jax.jit
def predict(params, x):
    # The Python loop unrolls over the fixed list of layers at trace time
    for w, b in params:
        x = jnp.maximum(jnp.dot(x, w) + b, 0)  # linear layer + ReLU
    return x

params = [
    (jax.random.normal(jax.random.PRNGKey(0), (10, 20)),
     jax.random.normal(jax.random.PRNGKey(1), (20,))),
    (jax.random.normal(jax.random.PRNGKey(2), (20, 5)),
     jax.random.normal(jax.random.PRNGKey(3), (5,)))
]

x = jax.random.normal(jax.random.PRNGKey(4), (1000, 10))
predictions = predict(params, x)

Common Issues #

Issue 1: Compilation Takes Too Long #

python
import jax
import jax.numpy as jnp

# Slow to compile: the Python loop unrolls into 100 separate add operations
@jax.jit
def large_function(x):
    result = x
    for i in range(100):
        result = result + 1
    return result

# Fast to compile: the same computation expressed as a single operation
@jax.jit
def optimized_function(x):
    return x + 100

Issue 2: Out of Memory #

python
import os

# Must be set before JAX initializes its backend, i.e. before importing jax
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.5'  # cap GPU memory at 50%

import jax
import jax.numpy as jnp

@jax.jit
def memory_intensive(x):
    return jnp.dot(x, x.T)

Next Steps #

Now that you have a handle on JIT compilation, continue with Composing Function Transformations to learn how to combine the different transformations!

Last updated: 2026-04-04