Composing Function Transformations #
Overview #
One of JAX's core strengths is that its function transformations compose freely: grad, vmap, and jit can be stacked to build powerful functionality.
Benefits of Composition #
text
┌─────────────────────────────────────────────────────────────┐
│ Benefits of composing transformations                       │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│ grad + vmap       ── batched (per-example) gradients        │
│ grad + jit        ── high-performance gradients             │
│ vmap + jit        ── high-performance batching              │
│ grad + vmap + jit ── high-performance batched gradients     │
│                                                             │
└─────────────────────────────────────────────────────────────┘
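As a minimal sketch (using a toy function f), the snippet below stacks all three transformations at once; the sections that follow break each combination down individually.
python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)

# jit(vmap(grad(f))): compiled, batched gradients of a scalar function
batched_grad_f = jax.jit(jax.vmap(jax.grad(f)))
xs = jnp.arange(6.0).reshape(3, 2)  # a batch of 3 inputs
print(batched_grad_f(xs))           # gradient of sum(x**2) is 2*x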
Basic Combinations #
grad + jit #
python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    pred = jnp.dot(x, params)
    return jnp.mean((pred - y) ** 2)

# Compose grad with jit: the gradient function itself gets compiled
grad_loss = jax.grad(loss)
jit_grad_loss = jax.jit(jax.grad(loss))

# Or jit a whole training step that calls grad internally
@jax.jit
def train_step(params, x, y):
    grads = jax.grad(loss)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)

params = jnp.array([1.0, 2.0])
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([5.0, 11.0])
new_params = train_step(params, x, y)
print(f"Updated params: {new_params}")
vmap + jit #
python
import jax
import jax.numpy as jnp

def process(x):
    return jnp.sum(x ** 2)

# vmap for batching, jit for compilation
process_batch = jax.vmap(process)
fast_process_batch = jax.jit(jax.vmap(process))

# Equivalent: jit a function that applies vmap internally
@jax.jit
def fast_batch_process(batch_x):
    return jax.vmap(process)(batch_x)

batch_x = jnp.array([[1, 2], [3, 4], [5, 6]])
result = fast_process_batch(batch_x)
print(f"Result: {result}")
grad + vmap #
python
import jax
import jax.numpy as jnp

def single_loss(params, x, y):
    pred = jnp.dot(x, params)
    return (pred - y) ** 2

# Per-example gradients: params is shared (None), x and y are batched (axis 0)
batch_grad = jax.vmap(jax.grad(single_loss), in_axes=(None, 0, 0))

params = jnp.array([1.0, 2.0])
x_batch = jnp.array([[1, 2], [3, 4], [5, 6]])
y_batch = jnp.array([5, 11, 17])
grads = batch_grad(params, x_batch, y_batch)
print(f"Per-example gradients: {grads}")
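A quick sanity check, reusing the values above: the mean of the per-example gradients equals the gradient of the mean batch loss.
python
# Gradient of the mean loss over the whole batch
mean_loss = lambda p: jnp.mean(
    jax.vmap(single_loss, in_axes=(None, 0, 0))(p, x_batch, y_batch))
avg_grad = jax.grad(mean_loss)(params)
print(jnp.allclose(avg_grad, jnp.mean(grads, axis=0)))  # True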
Triple Composition #
grad + vmap + jit #
python
import jax
import jax.numpy as jnp

def single_loss(params, x, y):
    pred = jnp.dot(x, params)
    return (pred - y) ** 2

# Build the composition up step by step
single_grad = jax.grad(single_loss)
batch_grad = jax.vmap(single_grad, in_axes=(None, 0, 0))
fast_batch_grad = jax.jit(batch_grad)

params = jnp.array([1.0, 2.0])
x_batch = jnp.array([[1, 2], [3, 4], [5, 6]])
y_batch = jnp.array([5, 11, 17])
grads = fast_batch_grad(params, x_batch, y_batch)
print(f"Fast per-example gradients: {grads}")
A More Compact Form #
python
import jax
import jax.numpy as jnp

# Reuses single_loss from the previous snippet; the inner function closes over params
@jax.jit
def fast_batch_grad(params, x_batch, y_batch):
    def single_grad(x, y):
        return jax.grad(lambda p: single_loss(p, x, y))(params)
    return jax.vmap(single_grad)(x_batch, y_batch)
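Called with the same data as the previous snippet, it produces identical per-example gradients:
python
grads = fast_batch_grad(params, x_batch, y_batch)
print(f"Per-example gradients: {grads}")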
Practical Applications #
Neural Network Training #
python
import jax
import jax.numpy as jnp

def init_params(key, input_dim, hidden_dim, output_dim):
    # One key per weight matrix
    keys = jax.random.split(key, 2)
    return [
        (jax.random.normal(keys[0], (input_dim, hidden_dim)) * 0.01,
         jnp.zeros(hidden_dim)),
        (jax.random.normal(keys[1], (hidden_dim, output_dim)) * 0.01,
         jnp.zeros(output_dim))
    ]

def forward(params, x):
    # ReLU hidden layers, linear output layer
    for w, b in params[:-1]:
        x = jnp.maximum(jnp.dot(x, w) + b, 0)
    w, b = params[-1]
    return jnp.dot(x, w) + b

def loss_fn(params, x, y):
    pred = forward(params, x)
    return jnp.mean((pred - y) ** 2)

@jax.jit
def train_step(params, x, y, lr=0.01):
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(
        lambda p, g: p - lr * g,
        params,
        grads
    )

key = jax.random.PRNGKey(0)
param_key, x_key, y_key = jax.random.split(key, 3)
params = init_params(param_key, 10, 20, 5)
x = jax.random.normal(x_key, (32, 10))
y = jax.random.normal(y_key, (32, 5))
for epoch in range(100):
    params = train_step(params, x, y)
    if epoch % 20 == 0:
        loss = loss_fn(params, x, y)
        print(f"Epoch {epoch}, Loss: {loss:.4f}")
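After training, predictions come from the same forward function; a quick shape check on the training inputs:
python
preds = forward(params, x)
print(f"Prediction shape: {preds.shape}")  # (32, 5)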
Batch Training #
python
import jax
import jax.numpy as jnp

def forward(params, x):
    w, b = params
    return jnp.dot(x, w) + b

def single_loss(params, x, y):
    pred = forward(params, x)
    return jnp.mean((pred - y) ** 2)

@jax.jit
def batch_train_step(params, x_batch, y_batch, lr=0.01):
    # Per-example gradients via vmap, then averaged across the batch
    def single_grad(x, y):
        return jax.grad(lambda p: single_loss(p, x, y))(params)
    all_grads = jax.vmap(single_grad)(x_batch, y_batch)
    avg_grads = jax.tree_util.tree_map(lambda g: jnp.mean(g, axis=0), all_grads)
    return jax.tree_util.tree_map(
        lambda p, g: p - lr * g,
        params,
        avg_grads
    )

key = jax.random.PRNGKey(0)
param_key, x_key, y_key = jax.random.split(key, 3)
params = (jax.random.normal(param_key, (10, 5)), jnp.zeros(5))
x_batch = jax.random.normal(x_key, (32, 10))
y_batch = jax.random.normal(y_key, (32, 5))
new_params = batch_train_step(params, x_batch, y_batch)
print(f"Updated weight shape: {new_params[0].shape}")
Higher-Order Optimization #
python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    pred = jnp.dot(x, params)
    return jnp.mean((pred - y) ** 2)

@jax.jit
def newton_step(params, x, y):
    # Second-order update: solve H @ update = g rather than forming H^{-1}
    g = jax.grad(loss_fn)(params, x, y)
    H = jax.hessian(loss_fn)(params, x, y)
    update = jnp.linalg.solve(H, g)
    return params - update

# Three samples in 2D keep the Hessian (2/n * X^T X) non-singular
params = jnp.array([0.5, 1.0])
x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([5.0, 11.0, 17.0])
new_params = newton_step(params, x, y)
print(f"After Newton step: {new_params}")  # quadratic loss: converges to [1.0, 2.0] in one step
Advanced Combinations #
Nested vmap + grad #
python
import jax
import jax.numpy as jnp

def process(x, y):
    return jnp.sum(x * y)

# Map over rows of x (inner) and rows of y (outer): all pairwise combinations
process_2d = jax.vmap(jax.vmap(process, in_axes=(0, None)), in_axes=(None, 0))
grad_process_2d = jax.vmap(jax.vmap(jax.grad(process), in_axes=(0, None)), in_axes=(None, 0))

# grad requires floating-point inputs for the differentiated argument
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([[5.0, 6.0], [7.0, 8.0]])
grads = grad_process_2d(x, y)
print(f"Nested gradient shape: {grads.shape}")  # (2, 2, 2)
Multi-Parameter Gradients #
python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    w1, w2 = params
    h = jnp.dot(x, w1)
    pred = jnp.dot(h, w2)
    return jnp.mean((pred - y) ** 2)

@jax.jit
def multi_grad_step(params, x, y, lr=0.01):
    # grad differentiates through the whole params pytree, covering both w1 and w2
    grads = jax.grad(loss)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

params = (jnp.array([[1.0, 2.0], [3.0, 4.0]]), jnp.array([[5.0], [6.0]]))
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([[17.0], [39.0]])
new_params = multi_grad_step(params, x, y)
print(f"Updated params: {new_params}")
Value and Gradient Together #
python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    pred = jnp.dot(x, params)
    return jnp.mean((pred - y) ** 2)

@jax.jit
def value_and_grad_step(params, x, y, lr=0.01):
    # One pass computes both the loss value and its gradient
    val, grads = jax.value_and_grad(loss)(params, x, y)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, val

params = jnp.array([1.0, 2.0])
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([5.0, 11.0])
new_params, loss_val = value_and_grad_step(params, x, y)
print(f"Loss: {loss_val}")
print(f"Updated params: {new_params}")
Performance Comparison #
Effect of Composition Order #
python
import jax
import jax.numpy as jnp
import time

def f(x):
    return jnp.sum(x ** 2)

x = jax.random.normal(jax.random.PRNGKey(0), (1000, 1000))

grad_then_jit = jax.jit(jax.grad(f))
jit_then_grad = jax.grad(jax.jit(f))

# Warm up so compilation time is excluded from the measurement
grad_then_jit(x).block_until_ready()
jit_then_grad(x).block_until_ready()

start = time.time()
for _ in range(100):
    grad_then_jit(x).block_until_ready()
print(f"grad -> jit: {time.time() - start:.4f}s")

start = time.time()
for _ in range(100):
    jit_then_grad(x).block_until_ready()
print(f"jit -> grad: {time.time() - start:.4f}s")
Best Practices #
1. Sensible Composition Order #
python
import jax
import jax.numpy as jnp

# Apply jit outermost so the entire step, including grad, compiles as one unit
# (loss and update are assumed to be defined elsewhere)
@jax.jit
def train_step(params, x, y):
    grads = jax.grad(loss)(params, x, y)
    return update(params, grads)
2. Use Static Arguments #
python
import jax
import jax.numpy as jnp
from functools import partial

# Mark Python-level arguments static so jit can specialize on them
@partial(jax.jit, static_argnames=['mode'])
def process(x, mode='train'):
    if mode == 'train':
        return x * 2
    else:
        return x
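Each distinct value of a static argument triggers its own compilation, so keep the set of values small:
python
x = jnp.ones(3)
print(process(x, mode='train'))  # compiles a 'train' version
print(process(x, mode='eval'))   # compiles a separate 'eval' version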
3. Avoid Recompilation #
python
import jax
import jax.numpy as jnp

# jit once, outside the loop; the compiled function is then reused every epoch
# (train_step, params, x, y as defined earlier)
train_step = jax.jit(train_step)
for epoch in range(100):
    params = train_step(params, x, y)
Common Pitfalls #
Pitfall 1: Composition Order #
python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)

# Recommended: jit outermost, so the whole batched computation compiles as one unit
recommended = jax.jit(jax.vmap(f))
# This also composes correctly, but the usual guidance is to keep jit outermost
inner_jit = jax.vmap(jax.jit(f))
Pitfall 2: Mismatched in_axes #
python
import jax
import jax.numpy as jnp

def f(x, y):
    return x + y

# in_axes declares which arguments are batched: x along axis 0, y broadcast as-is
f_vmap = jax.vmap(f, in_axes=(0, None))
x = jnp.array([[1, 2], [3, 4]])
y = jnp.array([5, 6])
result = f_vmap(x, y)  # shape (2, 2)
print(result)
Next Steps #
Now that you have function transformation composition under your belt, continue with Array Operations to dig deeper into JAX's numerical computing capabilities!
Last updated: 2026-04-04