控制流 #

概述 #

JAX 的控制流需要特殊处理,因为 JIT 编译需要静态的计算图。JAX 提供了 jax.lax 模块中的函数式控制流操作。

为什么需要函数式控制流? #

text
┌─────────────────────────────────────────────────────────────┐
│                    函数式控制流                              │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Python 控制流问题:                                         │
│  ❌ if/else 依赖运行时值                                    │
│  ❌ while 循环次数不确定                                    │
│  ❌ 无法被 JIT 编译追踪                                     │
│                                                             │
│  JAX 解决方案:                                              │
│  ✅ lax.cond - 条件分支                                     │
│  ✅ lax.while_loop - 循环                                   │
│  ✅ lax.scan - 带状态的循环                                 │
│  ✅ lax.fori_loop - 索引循环                                │
│                                                             │
└─────────────────────────────────────────────────────────────┘

条件分支 (lax.cond) #

基本用法 #

python
import jax
import jax.numpy as jnp

def conditional_abs(x):
    """Absolute value via lax.cond: selects the identity or negation branch.

    Both branch callables receive the operand `x`; only the chosen one runs.
    """
    def keep(v):
        return v

    def negate(v):
        return -v

    return jax.lax.cond(x >= 0, keep, negate, x)

print(f"|3| = {conditional_abs(3.0)}")
print(f"|-3| = {conditional_abs(-3.0)}")

多条件 #

python
import jax
import jax.numpy as jnp

def piecewise(x):
    """Three-piece function: x**2 for x < 0, x on [0, 1), sqrt(x) for x >= 1.

    Demonstrates nesting lax.cond to express multi-way branches.
    """
    def nonneg_branch(v):
        # Inner cond splits the non-negative range at 1.
        return jax.lax.cond(v < 1, lambda u: u, lambda u: u ** 0.5, v)

    return jax.lax.cond(x < 0, lambda v: v ** 2, nonneg_branch, x)

print(f"x=-2: {piecewise(-2.0)}")
print(f"x=0.5: {piecewise(0.5)}")
print(f"x=4: {piecewise(4.0)}")

使用 jnp.where #

python
import jax.numpy as jnp

def simple_conditional(x):
    """Absolute value with jnp.where; both branches are evaluated elementwise."""
    negated = -x
    return jnp.where(x >= 0, x, negated)

print(f"|3| = {simple_conditional(3.0)}")
print(f"|-3| = {simple_conditional(-3.0)}")

使用 jnp.select #

python
import jax.numpy as jnp

def multi_conditional(x):
    conditions = [
        x < 0,
        (x >= 0) & (x < 1),
        x >= 1
    ]
    choices = [
        x ** 2,
        x,
        x ** 0.5
    ]
    return jnp.select(conditions, choices, default=x)

print(f"x=-2: {multi_conditional(-2.0)}")
print(f"x=0.5: {multi_conditional(0.5)}")
print(f"x=4: {multi_conditional(4.0)}")

循环 (lax.while_loop) #

基本用法 #

python
import jax

def count_up_to(limit):
    """Sum the integers 0 .. limit-1 with lax.while_loop.

    The loop state is a (counter, running_total) pair.
    """
    def still_counting(state):
        index, _ = state
        return index < limit

    def accumulate(state):
        index, total = state
        return (index + 1, total + index)

    _, total = jax.lax.while_loop(still_counting, accumulate, (0, 0))
    return total

print(f"求和 0-9: {count_up_to(10)}")

收敛循环 #

python
import jax
import jax.numpy as jnp

def newton_sqrt(x, initial_guess=1.0, tolerance=1e-6):
    """Approximate sqrt(x) by Newton's method (assumes x > 0 — TODO confirm).

    Iterates until successive guesses differ by at most `tolerance`.
    """
    def not_converged(state):
        current, previous = state
        return jnp.abs(current - previous) > tolerance

    def refine(state):
        current, _ = state
        # Newton step for f(g) = g**2 - x.
        return (0.5 * (current + x / current), current)

    root, _ = jax.lax.while_loop(not_converged, refine, (initial_guess, 0.0))
    return root

print(f"sqrt(2) ≈ {newton_sqrt(2.0)}")
print(f"sqrt(9) ≈ {newton_sqrt(9.0)}")

索引循环 (lax.fori_loop) #

基本用法 #

python
import jax

def sum_range(n):
    """Sum of 0 .. n-1 using lax.fori_loop with an integer accumulator."""
    add_index = lambda i, total: total + i
    return jax.lax.fori_loop(0, n, add_index, 0)

print(f"求和 0-9: {sum_range(10)}")

带状态循环 #

python
import jax
import jax.numpy as jnp

def fibonacci(n):
    """n-th Fibonacci number (fib(0)=0, fib(1)=1) via lax.fori_loop.

    The carry holds consecutive pair (fib(i), fib(i+1)); the index is unused.
    """
    def advance(_, pair):
        a, b = pair
        return (b, a + b)

    result, _ = jax.lax.fori_loop(0, n, advance, (0, 1))
    return result

for i in range(10):
    print(f"fib({i}) = {fibonacci(i)}")

矩阵幂 #

python
import jax
import jax.numpy as jnp

def matrix_power(A, n):
    """Compute A**n for a square matrix A and integer n >= 0.

    Starts from the identity matrix and multiplies by A exactly n times, so
    n == 0 correctly yields I. (The original started from A with n - 1
    iterations, which returned A itself for n == 0.)
    """
    def multiply(_, M):
        return M @ A

    identity = jnp.eye(A.shape[0], dtype=A.dtype)
    return jax.lax.fori_loop(0, n, multiply, identity)

A = jnp.array([[1, 1], [1, 0]], dtype=jnp.float32)
print(f"A^5:\n{matrix_power(A, 5)}")

扫描循环 (lax.scan) #

基本用法 #

python
import jax
import jax.numpy as jnp

def cumulative_sum(arr):
    """Prefix sums of arr computed with lax.scan.

    Each step emits the updated running total both as the carry and as the
    per-step output, so the stacked outputs form the cumulative sum.
    """
    def add(total, value):
        total = total + value
        return total, total

    _, sums = jax.lax.scan(add, 0, arr)
    return sums

arr = jnp.array([1, 2, 3, 4, 5])
print(f"累积和: {cumulative_sum(arr)}")

保存中间状态 #

python
import jax
import jax.numpy as jnp

def rnn_step(carry, x):
    """One scan step: fold the input vector x into the running state.

    Uses an elementwise update so the carry keeps the same shape on every
    iteration — lax.scan requires a fixed carry shape/dtype. The original
    used jnp.dot(x, carry), which contracts two (5,) vectors to a scalar on
    the first step and makes scan fail with a carry-shape mismatch.
    """
    new_carry = jnp.tanh(carry + x)
    return new_carry, new_carry

def run_rnn(inputs, initial_state):
    """Scan rnn_step over the leading (time) axis of inputs.

    Returns the final state and the stacked state after every step.
    """
    last_state, all_states = jax.lax.scan(rnn_step, initial_state, inputs)
    return last_state, all_states

inputs = jax.random.normal(jax.random.PRNGKey(0), (10, 5))
initial_state = jnp.zeros(5)
final_state, all_states = run_rnn(inputs, initial_state)
print(f"所有状态形状: {all_states.shape}")

梯度下降轨迹 #

python
import jax
import jax.numpy as jnp

def gradient_descent_trajectory(f, initial_x, lr, num_steps):
    """Run num_steps of gradient descent on f from initial_x.

    Returns (final_x, trajectory) where trajectory stacks the iterate after
    each update; scan's (carry, ys) return matches that pair directly.
    """
    df = jax.grad(f)

    def descend(x, _):
        updated = x - lr * df(x)
        return updated, updated

    return jax.lax.scan(descend, initial_x, None, length=num_steps)

def quadratic(x):
    """Convex test objective: sum of squared components of x."""
    squared = x ** 2
    return jnp.sum(squared)

initial_x = jnp.array([5.0, 5.0])
final_x, trajectory = gradient_descent_trajectory(quadratic, initial_x, 0.1, 20)
print(f"最终位置: {final_x}")
print(f"轨迹形状: {trajectory.shape}")

实际应用 #

RNN 前向传播 #

python
import jax
import jax.numpy as jnp

def rnn_forward(params, inputs, initial_state):
    """Vanilla RNN forward pass over a sequence via lax.scan.

    Cell update: h' = tanh(h @ W_hh + x @ W_xh + b_h); output y = h' @ W_hy + b_y.
    Returns the final hidden state and the stacked per-step outputs.
    """
    W_hh, W_xh, b_h, W_hy, b_y = params

    def cell(h, x):
        pre_activation = h @ W_hh + x @ W_xh + b_h
        h_next = jnp.tanh(pre_activation)
        output = h_next @ W_hy + b_y
        return h_next, output

    return jax.lax.scan(cell, initial_state, inputs)

key = jax.random.PRNGKey(0)
hidden_size = 10
input_size = 5
output_size = 3

params = (
    jax.random.normal(key, (hidden_size, hidden_size)) * 0.01,
    jax.random.normal(key, (input_size, hidden_size)) * 0.01,
    jnp.zeros(hidden_size),
    jax.random.normal(key, (hidden_size, output_size)) * 0.01,
    jnp.zeros(output_size)
)

inputs = jax.random.normal(key, (20, input_size))
initial_state = jnp.zeros(hidden_size)

final_state, outputs = rnn_forward(params, inputs, initial_state)
print(f"输出形状: {outputs.shape}")

梯度下降优化 #

python
import jax
import jax.numpy as jnp

def optimize(loss_fn, initial_params, lr, num_steps):
    """Plain gradient descent for num_steps steps using lax.scan.

    Returns (final_params, losses), where losses[i] is the loss *after*
    update step i. Uses jax.tree_util.tree_map because the jax.tree_map
    alias is deprecated and removed in recent JAX releases.
    """
    grad_fn = jax.grad(loss_fn)

    def step_fn(params, _):
        grads = grad_fn(params)
        # One SGD step applied leaf-wise across the parameter pytree.
        new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
        return new_params, loss_fn(new_params)

    final_params, losses = jax.lax.scan(step_fn, initial_params, None, length=num_steps)
    return final_params, losses

def mse_loss(params):
    """Squared distance of the point (x, y) from the target (1, 2)."""
    x, y = params
    dx = x - 1
    dy = y - 2
    return dx ** 2 + dy ** 2

initial_params = (5.0, 5.0)
final_params, losses = optimize(mse_loss, initial_params, 0.1, 50)
print(f"最终参数: {final_params}")
print(f"最终损失: {losses[-1]:.6f}")

蒙特卡洛模拟 #

python
import jax
import jax.numpy as jnp

def monte_carlo_pi(key, num_samples):
    """Estimate pi by sampling uniform points in the unit square.

    Each scan step consumes one fresh key and splits it into two independent
    subkeys, one per coordinate. The original drew x directly from the
    per-step key and then re-split that same key for y, reusing key
    material — an anti-pattern with JAX's counter-based PRNG.
    """
    keys = jax.random.split(key, num_samples)

    def sample_fn(carry, sample_key):
        kx, ky = jax.random.split(sample_key)
        x = jax.random.uniform(kx)
        y = jax.random.uniform(ky)
        # True when the point lies inside the quarter unit circle.
        return carry, x ** 2 + y ** 2 <= 1

    _, inside = jax.lax.scan(sample_fn, None, keys)

    # Quarter-circle area / square area = pi / 4.
    return 4 * jnp.mean(inside)

key = jax.random.PRNGKey(0)
pi_estimate = monte_carlo_pi(key, 100000)
print(f"π 估计值: {pi_estimate}")

性能对比 #

fori_loop vs Python for #

python
import jax
import jax.numpy as jnp
import time

def python_sum(n):
    """Baseline: sum of 0 .. n-1 using plain Python iteration."""
    return sum(range(n))

@jax.jit
def jax_sum(n):
    """Sum of 0 .. n-1 inside JIT using lax.fori_loop."""
    add_index = lambda i, total: total + i
    return jax.lax.fori_loop(0, n, add_index, 0)

n = 10000

start = time.time()
result = python_sum(n)
print(f"Python for: {time.time() - start:.4f}s")

start = time.time()
result = jax_sum(n)
result.block_until_ready()
print(f"JAX fori_loop (首次): {time.time() - start:.4f}s")

start = time.time()
result = jax_sum(n)
result.block_until_ready()
print(f"JAX fori_loop (编译后): {time.time() - start:.4f}s")

最佳实践 #

1. 使用 scan 替代 fori_loop #

python
import jax
import jax.numpy as jnp

@jax.jit
def good_practice(arr):
    """Cumulative sum via lax.scan — the recommended pattern for carrying state."""
    def accumulate(total, value):
        updated = total + value
        return updated, updated

    _, running = jax.lax.scan(accumulate, 0, arr)
    return running

@jax.jit
def bad_practice(arr):
    """Anti-pattern for comparison: fori_loop with per-step dynamic indexing."""
    def add_element(i, total):
        # Indexing arr[i] each iteration is what scan avoids.
        return total + arr[i]

    return jax.lax.fori_loop(0, len(arr), add_element, 0)

2. 使用 where 替代简单 cond #

python
import jax
import jax.numpy as jnp

def simple_cond(x):
    """Absolute value via jnp.where — preferred for cheap elementwise branches."""
    negated = jnp.negative(x)
    return jnp.where(x > 0, x, negated)

def complex_cond(x):
    """lax.cond with genuinely different branches: x**2 when positive, -x otherwise."""
    def when_positive(v):
        return v ** 2

    def otherwise(v):
        return -v

    return jax.lax.cond(x > 0, when_positive, otherwise, x)

常见问题 #

问题 1: 循环次数必须是静态的 #

python
import jax
import jax.numpy as jnp

@jax.jit
def bad_loop(x, n):
    # Anti-example: under jit, `n` arrives as a tracer with no concrete
    # value, so Python's range(n) fails at trace time — a Python `for`
    # inside a jitted function needs a static trip count.
    result = 0
    for i in range(n):  
        result += x
    return result

@jax.jit
def good_loop(x, n):
    """Add x to itself n times (i.e. x * n) using lax.fori_loop.

    The accumulator is initialised with jnp.zeros_like(x) rather than the
    Python int 0: fori_loop requires the carry to keep the same dtype on
    every iteration, and an integer 0 carry clashes with a float x.
    """
    def body_fn(i, acc):
        return acc + x
    return jax.lax.fori_loop(0, n, body_fn, jnp.zeros_like(x))

问题 2: 条件值必须是标量 #

python
import jax
import jax.numpy as jnp

def bad_cond(x):
    """Works only for scalar x: lax.cond's predicate must be a scalar boolean."""
    is_positive = x > 0
    return jax.lax.cond(is_positive, lambda v: v, lambda v: -v, x)

def good_cond(x):
    """Elementwise absolute value; jnp.where handles array-valued x too."""
    negated = jnp.negative(x)
    return jnp.where(x > 0, x, negated)

下一步 #

现在你已经掌握了控制流,接下来学习 构建神经网络,了解如何在 JAX 中构建神经网络!

最后更新:2026-04-04