控制流 #
概述 #
JAX 的控制流需要特殊处理,因为 JIT 编译需要静态的计算图。JAX 提供了 jax.lax 模块中的函数式控制流操作。
为什么需要函数式控制流? #
text
┌─────────────────────────────────────────────────────────────┐
│ 函数式控制流 │
├─────────────────────────────────────────────────────────────┤
│ │
│ Python 控制流问题: │
│ ❌ if/else 依赖运行时值 │
│ ❌ while 循环次数不确定 │
│ ❌ 无法被 JIT 编译追踪 │
│ │
│ JAX 解决方案: │
│ ✅ lax.cond - 条件分支 │
│ ✅ lax.while_loop - 循环 │
│ ✅ lax.scan - 带状态的循环 │
│ ✅ lax.fori_loop - 索引循环 │
│ │
└─────────────────────────────────────────────────────────────┘
条件分支 (lax.cond) #
基本用法 #
python
import jax
import jax.numpy as jnp


def conditional_abs(x):
    """Absolute value computed with lax.cond.

    lax.cond picks one of two branch functions based on a scalar predicate;
    both branches receive the operand `x`.
    """
    on_nonneg = lambda v: v
    on_neg = lambda v: -v
    return jax.lax.cond(x >= 0, on_nonneg, on_neg, x)


print(f"|3| = {conditional_abs(3.0)}")
print(f"|-3| = {conditional_abs(-3.0)}")
多条件 #
python
import jax
import jax.numpy as jnp


def piecewise(x):
    """Piecewise function via nested lax.cond.

    Returns x**2 for x < 0, x for 0 <= x < 1, and sqrt(x) for x >= 1.
    """
    def upper_piece(v):
        # Second-level branch: identity on [0, 1), square root on [1, inf).
        return jax.lax.cond(v < 1, lambda u: u, lambda u: u ** 0.5, v)

    return jax.lax.cond(x < 0, lambda v: v ** 2, upper_piece, x)


print(f"x=-2: {piecewise(-2.0)}")
print(f"x=0.5: {piecewise(0.5)}")
print(f"x=4: {piecewise(4.0)}")
使用 jnp.where #
python
import jax.numpy as jnp


def simple_conditional(x):
    """Absolute value expressed with jnp.where (elementwise select)."""
    negated = -x
    return jnp.where(x >= 0, x, negated)


print(f"|3| = {simple_conditional(3.0)}")
print(f"|-3| = {simple_conditional(-3.0)}")
使用 jnp.select #
python
import jax.numpy as jnp
def multi_conditional(x):
conditions = [
x < 0,
(x >= 0) & (x < 1),
x >= 1
]
choices = [
x ** 2,
x,
x ** 0.5
]
return jnp.select(conditions, choices, default=x)
print(f"x=-2: {multi_conditional(-2.0)}")
print(f"x=0.5: {multi_conditional(0.5)}")
print(f"x=4: {multi_conditional(4.0)}")
循环 (lax.while_loop) #
基本用法 #
python
import jax


def count_up_to(limit):
    """Sum the integers 0..limit-1 using lax.while_loop.

    The loop state is a (counter, running_total) pair; the loop runs while
    the counter stays below `limit`.
    """
    def keep_going(state):
        counter, _ = state
        return counter < limit

    def advance(state):
        counter, total = state
        return counter + 1, total + counter

    _, total = jax.lax.while_loop(keep_going, advance, (0, 0))
    return total


print(f"求和 0-9: {count_up_to(10)}")
收敛循环 #
python
import jax
import jax.numpy as jnp


def newton_sqrt(x, initial_guess=1.0, tolerance=1e-6):
    """Approximate sqrt(x) with Newton's method.

    Iterates guess <- (guess + x/guess) / 2 until two successive guesses
    differ by at most `tolerance`.
    """
    def not_converged(state):
        current, previous = state
        return jnp.abs(current - previous) > tolerance

    def refine(state):
        current, _ = state
        improved = 0.5 * (current + x / current)
        # Keep the old guess around so the convergence test can compare.
        return (improved, current)

    final_guess, _ = jax.lax.while_loop(not_converged, refine, (initial_guess, 0.0))
    return final_guess


print(f"sqrt(2) ≈ {newton_sqrt(2.0)}")
print(f"sqrt(9) ≈ {newton_sqrt(9.0)}")
索引循环 (lax.fori_loop) #
基本用法 #
python
import jax


def sum_range(n):
    """Sum 0..n-1 with lax.fori_loop (index loop with an accumulator)."""
    accumulate = lambda i, running: running + i
    return jax.lax.fori_loop(0, n, accumulate, 0)


print(f"求和 0-9: {sum_range(10)}")
带状态循环 #
python
import jax
import jax.numpy as jnp


def fibonacci(n):
    """Return the n-th Fibonacci number (fib(0) = 0, fib(1) = 1).

    The loop carries the pair (fib(i), fib(i+1)) and shifts it n times.
    """
    def shift(_, pair):
        prev, curr = pair
        return curr, prev + curr

    result, _ = jax.lax.fori_loop(0, n, shift, (0, 1))
    return result


for i in range(10):
    print(f"fib({i}) = {fibonacci(i)}")
矩阵幂 #
python
import jax
import jax.numpy as jnp


def matrix_power(A, n):
    """Return A**n for a square matrix A and non-negative integer n.

    Folds n multiplications starting from the identity, so n == 0 correctly
    yields the identity matrix (the original started from A with n-1 steps,
    which returned A for n == 0).
    """
    identity = jnp.eye(A.shape[0], dtype=A.dtype)

    def multiply(_, M):
        return M @ A

    return jax.lax.fori_loop(0, n, multiply, identity)


A = jnp.array([[1, 1], [1, 0]], dtype=jnp.float32)
print(f"A^5:\n{matrix_power(A, 5)}")
扫描循环 (lax.scan) #
基本用法 #
python
import jax
import jax.numpy as jnp


def cumulative_sum(arr):
    """Running (prefix) sum of arr computed with lax.scan.

    The carry is the running total; emitting it at every step stacks the
    prefix sums into the output array.
    """
    def accumulate(total, value):
        total = total + value
        return total, total

    _, running = jax.lax.scan(accumulate, 0, arr)
    return running


arr = jnp.array([1, 2, 3, 4, 5])
print(f"累积和: {cumulative_sum(arr)}")
保存中间状态 #
python
import jax
import jax.numpy as jnp


def rnn_step(carry, x):
    """One step of a toy recurrence that preserves the carry's shape.

    Fix: the original computed jnp.tanh(jnp.dot(x, carry)), but the dot of
    two 1-D vectors is a scalar, which violates lax.scan's requirement that
    the carry keep the same shape every step (the initial carry is shape
    (5,)) and raises at trace time. An elementwise update keeps the shape.
    """
    new_carry = jnp.tanh(x + carry)
    # Emit the new state so scan also records it as a per-step output.
    return new_carry, new_carry


def run_rnn(inputs, initial_state):
    """Scan the recurrence over the leading axis of `inputs`.

    Returns (final_state, all_states); all_states stacks the state after
    each step along a new leading axis of length len(inputs).
    """
    final_state, all_states = jax.lax.scan(rnn_step, initial_state, inputs)
    return final_state, all_states


inputs = jax.random.normal(jax.random.PRNGKey(0), (10, 5))
initial_state = jnp.zeros(5)
final_state, all_states = run_rnn(inputs, initial_state)
print(f"所有状态形状: {all_states.shape}")
梯度下降轨迹 #
python
import jax
import jax.numpy as jnp


def gradient_descent_trajectory(f, initial_x, lr, num_steps):
    """Run `num_steps` steps of gradient descent on f, recording each iterate.

    Returns (final point, trajectory) where trajectory has shape
    (num_steps,) + initial_x.shape.
    """
    df = jax.grad(f)

    def descend(x, _):
        updated = x - lr * df(x)
        # Emit the iterate so scan stacks the full optimization path.
        return updated, updated

    final_x, trajectory = jax.lax.scan(descend, initial_x, None, length=num_steps)
    return final_x, trajectory


def quadratic(x):
    """Convex test objective: sum of squares, minimized at the origin."""
    return jnp.sum(x ** 2)


initial_x = jnp.array([5.0, 5.0])
final_x, trajectory = gradient_descent_trajectory(quadratic, initial_x, 0.1, 20)
print(f"最终位置: {final_x}")
print(f"轨迹形状: {trajectory.shape}")
实际应用 #
RNN 前向传播 #
python
import jax
import jax.numpy as jnp


def rnn_forward(params, inputs, initial_state):
    """Vanilla RNN forward pass over a sequence.

    params: (W_hh, W_xh, b_h, W_hy, b_y)
    inputs: array of shape (seq_len, input_size)
    initial_state: array of shape (hidden_size,)
    Returns (final hidden state, outputs of shape (seq_len, output_size)).
    """
    W_hh, W_xh, b_h, W_hy, b_y = params

    def step_fn(h, x):
        # Hidden-state update followed by a linear readout.
        h_new = jnp.tanh(jnp.dot(h, W_hh) + jnp.dot(x, W_xh) + b_h)
        y = jnp.dot(h_new, W_hy) + b_y
        return h_new, y

    final_state, outputs = jax.lax.scan(step_fn, initial_state, inputs)
    return final_state, outputs


key = jax.random.PRNGKey(0)
hidden_size = 10
input_size = 5
output_size = 3
# Fix: split the key so each parameter tensor (and the inputs) gets an
# independent subkey; the original reused the same key for every draw,
# which produces correlated random values.
k_hh, k_xh, k_hy, k_in = jax.random.split(key, 4)
params = (
    jax.random.normal(k_hh, (hidden_size, hidden_size)) * 0.01,
    jax.random.normal(k_xh, (input_size, hidden_size)) * 0.01,
    jnp.zeros(hidden_size),
    jax.random.normal(k_hy, (hidden_size, output_size)) * 0.01,
    jnp.zeros(output_size)
)
inputs = jax.random.normal(k_in, (20, input_size))
initial_state = jnp.zeros(hidden_size)
final_state, outputs = rnn_forward(params, inputs, initial_state)
print(f"输出形状: {outputs.shape}")
梯度下降优化 #
python
import jax
import jax.numpy as jnp


def optimize(loss_fn, initial_params, lr, num_steps):
    """Plain gradient descent driven by lax.scan.

    Returns (final_params, losses) where losses[i] is the loss value after
    step i. Works on any pytree of parameters.
    """
    grad_fn = jax.grad(loss_fn)

    def step_fn(params, _):
        grads = grad_fn(params)
        # Fix: jax.tree_map was deprecated and later removed; the supported
        # spelling is jax.tree_util.tree_map.
        new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
        return new_params, loss_fn(new_params)

    final_params, losses = jax.lax.scan(step_fn, initial_params, None, length=num_steps)
    return final_params, losses


def mse_loss(params):
    """Quadratic bowl with its minimum at (x, y) = (1, 2)."""
    x, y = params
    return (x - 1) ** 2 + (y - 2) ** 2


initial_params = (5.0, 5.0)
final_params, losses = optimize(mse_loss, initial_params, 0.1, 50)
print(f"最终参数: {final_params}")
print(f"最终损失: {losses[-1]:.6f}")
蒙特卡洛模拟 #
python
import jax
import jax.numpy as jnp


def monte_carlo_pi(key, num_samples):
    """Estimate π by sampling points uniformly in the unit square.

    Fix: each per-sample key is split into two independent subkeys, one for
    each coordinate. The original drew x directly from a key and then also
    split that same key for y — reusing a consumed key is a JAX PRNG
    anti-pattern that correlates the draws.
    """
    keys = jax.random.split(key, num_samples)

    def sample_fn(carry, sample_key):
        kx, ky = jax.random.split(sample_key)
        x = jax.random.uniform(kx)
        y = jax.random.uniform(ky)
        # True when the point lands inside the quarter unit circle.
        return carry, x ** 2 + y ** 2 <= 1

    _, inside = jax.lax.scan(sample_fn, 0, keys)
    return 4 * jnp.mean(inside)


key = jax.random.PRNGKey(0)
pi_estimate = monte_carlo_pi(key, 100000)
print(f"π 估计值: {pi_estimate}")
性能对比 #
fori_loop vs Python for #
python
import jax
import jax.numpy as jnp
import time


def python_sum(n):
    """Reference implementation: a plain Python accumulation loop."""
    total = 0
    for value in range(n):
        total += value
    return total


@jax.jit
def jax_sum(n):
    """Same sum expressed with lax.fori_loop so it can be JIT-compiled."""
    return jax.lax.fori_loop(0, n, lambda i, acc: acc + i, 0)


n = 10000

start = time.time()
result = python_sum(n)
print(f"Python for: {time.time() - start:.4f}s")

# First call pays the trace + compile cost.
start = time.time()
result = jax_sum(n)
result.block_until_ready()
print(f"JAX fori_loop (首次): {time.time() - start:.4f}s")

# Second call reuses the compiled executable.
start = time.time()
result = jax_sum(n)
result.block_until_ready()
print(f"JAX fori_loop (编译后): {time.time() - start:.4f}s")
最佳实践 #
1. 使用 scan 替代 fori_loop #
python
import jax
import jax.numpy as jnp


@jax.jit
def good_practice(arr):
    """Cumulative sum via lax.scan — the idiomatic accumulate-and-emit loop."""
    def running_total(total, value):
        updated = total + value
        return updated, updated

    _, cumulative = jax.lax.scan(running_total, 0, arr)
    return cumulative


@jax.jit
def bad_practice(arr):
    """Total via fori_loop with per-index gathers — works, but less idiomatic."""
    return jax.lax.fori_loop(0, len(arr), lambda i, total: total + arr[i], 0)
2. 使用 where 替代简单 cond #
python
import jax
import jax.numpy as jnp


def simple_cond(x):
    """Elementwise branch with jnp.where — works on arrays and under vmap."""
    return jnp.where(x > 0, x, -x)


def complex_cond(x):
    """Scalar branch with lax.cond — only the selected branch is executed.

    Fix: the original snippet used jax.lax.cond without importing jax
    (only jax.numpy was imported), so it raised NameError when called.
    """
    return jax.lax.cond(x > 0, lambda x: x ** 2, lambda x: -x, x)
常见问题 #
问题 1: 循环次数必须是静态的 #
python
import jax
import jax.numpy as jnp


@jax.jit
def bad_loop(x, n):
    """Anti-example: a Python for-loop over a traced `n` cannot be jitted,
    because range() needs a concrete integer at trace time."""
    total = 0
    for _ in range(n):
        total += x
    return total


@jax.jit
def good_loop(x, n):
    """Correct version: lax.fori_loop accepts a traced trip count."""
    return jax.lax.fori_loop(0, n, lambda i, acc: acc + x, 0)
问题 2: 条件值必须是标量 #
python
import jax
import jax.numpy as jnp


def bad_cond(x):
    """Anti-example: lax.cond needs a scalar predicate, so an array `x`
    makes the predicate non-scalar and fails."""
    return jax.lax.cond(x > 0, lambda v: v, lambda v: -v, x)


def good_cond(x):
    """Array-friendly version: jnp.where selects elementwise."""
    return jnp.where(x > 0, x, -x)
下一步 #
现在你已经掌握了控制流,接下来学习 构建神经网络,了解如何在 JAX 中构建神经网络!
最后更新:2026-04-04