Composing Function Transformations #

Overview #

One of JAX's core strengths is that its function transformations compose freely: grad, vmap, and jit can be combined with one another to build powerful functionality.
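
Because each transformation takes a function and returns a function, composing them is just nested calls. A minimal sketch (the function f below is an illustrative example, not from the sections that follow):

python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)

# grad returns a function, vmap batches it, jit compiles the result.
per_example_grad = jax.jit(jax.vmap(jax.grad(f)))

xs = jnp.array([[1.0, 2.0], [3.0, 4.0]])
print(per_example_grad(xs))  # gradient of f for each row: [[2. 4.] [6. 8.]]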

Benefits of Composition #

text
┌──────────────────────────────────────────────────────┐
│        Benefits of composing transformations         │
├──────────────────────────────────────────────────────┤
│                                                      │
│  grad + vmap     ─── batched gradient computation    │
│  grad + jit      ─── fast compiled gradients         │
│  vmap + jit      ─── fast compiled batching          │
│  grad+vmap+jit   ─── fast compiled batched gradients │
│                                                      │
└──────────────────────────────────────────────────────┘

Basic Combinations #

grad + jit #

python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    pred = jnp.dot(x, params)
    return jnp.mean((pred - y) ** 2)

# Plain gradient function.
grad_loss = jax.grad(loss)

# Compiled gradient function: same result, much less per-call overhead.
jit_grad_loss = jax.jit(jax.grad(loss))

# More commonly, jit the whole training step so the gradient and the
# parameter update compile into a single program.
@jax.jit
def train_step(params, x, y):
    grads = jax.grad(loss)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)

params = jnp.array([1.0, 2.0])
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([5.0, 11.0])

new_params = train_step(params, x, y)
print(f"Updated params: {new_params}")

vmap + jit #

python
import jax
import jax.numpy as jnp

def process(x):
    return jnp.sum(x ** 2)

# Batch the function over a leading axis.
process_batch = jax.vmap(process)

# Batch it and compile the batched computation.
fast_process_batch = jax.jit(jax.vmap(process))

# Equivalent decorator form.
@jax.jit
def fast_batch_process(batch_x):
    return jax.vmap(process)(batch_x)

batch_x = jnp.array([[1, 2], [3, 4], [5, 6]])
result = fast_process_batch(batch_x)
print(f"Result: {result}")

grad + vmap #

python
import jax
import jax.numpy as jnp

def single_loss(params, x, y):
    pred = jnp.dot(x, params)
    return (pred - y) ** 2

# in_axes=(None, 0, 0): share params across the batch, map over x and y.
batch_grad = jax.vmap(jax.grad(single_loss), in_axes=(None, 0, 0))

params = jnp.array([1.0, 2.0])
x_batch = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y_batch = jnp.array([5.0, 11.0, 17.0])

grads = batch_grad(params, x_batch, y_batch)
print(f"Per-example gradients: {grads}")

Triple Composition #

grad + vmap + jit #

python
import jax
import jax.numpy as jnp

def single_loss(params, x, y):
    pred = jnp.dot(x, params)
    return (pred - y) ** 2

# Build up the composition one transformation at a time.
single_grad = jax.grad(single_loss)                       # gradient for one example
batch_grad = jax.vmap(single_grad, in_axes=(None, 0, 0))  # map over the batch
fast_batch_grad = jax.jit(batch_grad)                     # compile the whole pipeline

params = jnp.array([1.0, 2.0])
x_batch = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y_batch = jnp.array([5.0, 11.0, 17.0])

grads = fast_batch_grad(params, x_batch, y_batch)
print(f"Compiled per-example gradients: {grads}")

Concise Form #

python
import jax
import jax.numpy as jnp

def single_loss(params, x, y):
    pred = jnp.dot(x, params)
    return (pred - y) ** 2

@jax.jit
def fast_batch_grad(params, x_batch, y_batch):
    # Close over params; take the gradient of each example's loss with respect to it.
    def single_grad(x, y):
        return jax.grad(lambda p: single_loss(p, x, y))(params)
    return jax.vmap(single_grad)(x_batch, y_batch)
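
Assuming the definitions and data from the previous snippet are still in scope, a quick sanity check confirms the concise form agrees with the explicit jit(vmap(grad(...))) composition:

python
explicit = jax.jit(jax.vmap(jax.grad(single_loss), in_axes=(None, 0, 0)))
print(jnp.allclose(explicit(params, x_batch, y_batch),
                   fast_batch_grad(params, x_batch, y_batch)))  # True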

Practical Applications #

Neural Network Training #

python
import jax
import jax.numpy as jnp

def init_params(key, input_dim, hidden_dim, output_dim):
    # One subkey per weight matrix; biases start at zero.
    k1, k2 = jax.random.split(key)
    return [
        (jax.random.normal(k1, (input_dim, hidden_dim)) * 0.01,
         jnp.zeros(hidden_dim)),
        (jax.random.normal(k2, (hidden_dim, output_dim)) * 0.01,
         jnp.zeros(output_dim))
    ]

def forward(params, x):
    # ReLU hidden layers followed by a linear output layer.
    for w, b in params[:-1]:
        x = jnp.maximum(jnp.dot(x, w) + b, 0)
    w, b = params[-1]
    return jnp.dot(x, w) + b

def loss_fn(params, x, y):
    pred = forward(params, x)
    return jnp.mean((pred - y) ** 2)

@jax.jit
def train_step(params, x, y, lr=0.01):
    grads = jax.grad(loss_fn)(params, x, y)
    # Apply the same SGD update to every leaf of the parameter pytree.
    return jax.tree_util.tree_map(
        lambda p, g: p - lr * g,
        params,
        grads
    )

key = jax.random.PRNGKey(0)
key, x_key, y_key = jax.random.split(key, 3)  # separate keys for params and data
params = init_params(key, 10, 20, 5)

x = jax.random.normal(x_key, (32, 10))
y = jax.random.normal(y_key, (32, 5))

for epoch in range(100):
    params = train_step(params, x, y)
    if epoch % 20 == 0:
        loss = loss_fn(params, x, y)
        print(f"Epoch {epoch}, Loss: {loss:.4f}")

Batch Training #

python
import jax
import jax.numpy as jnp

def forward(params, x):
    w, b = params
    return jnp.dot(x, w) + b

def single_loss(params, x, y):
    pred = forward(params, x)
    return jnp.mean((pred - y) ** 2)

@jax.jit
def batch_train_step(params, x_batch, y_batch, lr=0.01):
    def single_grad(x, y):
        return jax.grad(lambda p: single_loss(p, x, y))(params)

    # Per-example gradients: one pytree of grads per batch element.
    all_grads = jax.vmap(single_grad)(x_batch, y_batch)

    # Average each leaf over the batch axis.
    avg_grads = jax.tree_util.tree_map(lambda g: jnp.mean(g, axis=0), all_grads)

    return jax.tree_util.tree_map(
        lambda p, g: p - lr * g,
        params,
        avg_grads
    )

key = jax.random.PRNGKey(0)
key, x_key, y_key = jax.random.split(key, 3)
params = (jax.random.normal(key, (10, 5)), jnp.zeros(5))

x_batch = jax.random.normal(x_key, (32, 10))
y_batch = jax.random.normal(y_key, (32, 5))

new_params = batch_train_step(params, x_batch, y_batch)
print(f"Updated weight shape: {new_params[0].shape}")

Higher-Order Optimization #

python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    pred = jnp.dot(x, params)
    return jnp.mean((pred - y) ** 2)

@jax.jit
def newton_step(params, x, y):
    g = jax.grad(loss_fn)(params, x, y)
    H = jax.hessian(loss_fn)(params, x, y)

    # Solve H @ update = g directly; more stable than forming the inverse.
    update = jnp.linalg.solve(H, g)

    return params - update

# Three data points so the Hessian of this 3-parameter quadratic loss is
# full rank (with only two points it would be singular).
params = jnp.array([0.0, 0.0, 0.0])
x = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 10.0]])
y = jnp.array([14.0, 32.0, 53.0])

new_params = newton_step(params, x, y)
print(f"After Newton step: {new_params}")  # quadratic loss: converges to [1. 2. 3.] in one step

Advanced Compositions #

Nested vmap + grad #

python
import jax
import jax.numpy as jnp

def process(x, y):
    return jnp.sum(x * y)

# Map over every (row of x, row of y) pair: an outer product of batches.
process_2d = jax.vmap(jax.vmap(process, in_axes=(0, None)), in_axes=(None, 0))

# The same nesting works with grad inside (gradient with respect to x).
grad_process_2d = jax.vmap(jax.vmap(jax.grad(process), in_axes=(0, None)), in_axes=(None, 0))

# grad requires floating-point inputs, so use float arrays here.
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([[5.0, 6.0], [7.0, 8.0]])

grads = grad_process_2d(x, y)
print(f"Nested gradient shape: {grads.shape}")  # (2, 2, 2): y-row, x-row, grad dim

Multi-Parameter Gradients #

python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    w1, w2 = params
    h = jnp.dot(x, w1)
    pred = jnp.dot(h, w2)
    return jnp.mean((pred - y) ** 2)

@jax.jit
def multi_grad_step(params, x, y, lr=0.01):
    # argnums=(0,) returns a one-element tuple of gradients; index it out.
    grads = jax.grad(loss, argnums=(0,))(params, x, y)[0]
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

params = (jnp.array([[1.0, 2.0], [3.0, 4.0]]), jnp.array([[5.0], [6.0]]))
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([[17.0], [39.0]])

new_params = multi_grad_step(params, x, y)
print(f"Updated params: {new_params}")

Value and Gradient Together #

python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    pred = jnp.dot(x, params)
    return jnp.mean((pred - y) ** 2)

@jax.jit
def value_and_grad_step(params, x, y, lr=0.01):
    # One pass computes both the loss value and its gradient.
    val, grads = jax.value_and_grad(loss)(params, x, y)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, val

params = jnp.array([1.0, 2.0])
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([5.0, 11.0])

new_params, loss_val = value_and_grad_step(params, x, y)
print(f"Loss: {loss_val}")
print(f"Updated params: {new_params}")

Performance Comparison #

Effect of Composition Order #

python
import jax
import jax.numpy as jnp
import time

def f(x):
    return jnp.sum(x ** 2)

x = jax.random.normal(jax.random.PRNGKey(0), (1000, 1000))

grad_then_jit = jax.jit(jax.grad(f))  # the entire gradient computation is compiled
jit_then_grad = jax.grad(jax.jit(f))  # grad wraps a jitted forward pass

# Warm-up calls so compilation time is excluded from the timings.
grad_then_jit(x).block_until_ready()
jit_then_grad(x).block_until_ready()

start = time.time()
for _ in range(100):
    grad_then_jit(x).block_until_ready()
print(f"grad -> jit: {time.time() - start:.4f}s")

start = time.time()
for _ in range(100):
    jit_then_grad(x).block_until_ready()
print(f"jit -> grad: {time.time() - start:.4f}s")

Best Practices #

1. Sensible Composition Order #

python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    return jnp.mean((jnp.dot(x, params) - y) ** 2)

def update(params, grads):
    return params - 0.01 * grads

# jit outermost: the gradient and the update compile into one program.
@jax.jit
def train_step(params, x, y):
    grads = jax.grad(loss)(params, x, y)
    return update(params, grads)

2. Use Static Arguments #

python
import jax
import jax.numpy as jnp
from functools import partial

# jax.jit takes the function as its first positional argument, so use
# functools.partial to supply static_argnames in decorator position.
@partial(jax.jit, static_argnames=['mode'])
def process(x, mode='train'):
    if mode == 'train':
        return x * 2
    else:
        return x

# Each distinct static value compiles its own specialization.
print(process(jnp.ones(3)))               # compiles for mode='train'
print(process(jnp.ones(3), mode='eval'))  # recompiles for mode='eval'

3. Avoid Recompilation #

python
import jax
import jax.numpy as jnp

def train_step(params, x, y):
    grads = jax.grad(lambda p: jnp.mean((jnp.dot(x, p) - y) ** 2))(params)
    return params - 0.01 * grads

# Wrap once, outside the loop; the cached compilation is reused each iteration.
train_step = jax.jit(train_step)

params, x, y = jnp.array([1.0, 2.0]), jnp.ones((2, 2)), jnp.array([5.0, 11.0])
for epoch in range(100):
    params = train_step(params, x, y)

Common Pitfalls #

Pitfall 1: Suboptimal Composition Order #

python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)

better = jax.jit(jax.vmap(f))  # preferred: the whole batched computation compiles
slower = jax.vmap(jax.jit(f))  # still valid, but the outer vmap runs uncompiled

Pitfall 2: Mismatched in_axes #

python
import jax
import jax.numpy as jnp

def f(x, y):
    return x + y

# in_axes=(0, None): map over x's leading axis, broadcast y unchanged.
f_vmap = jax.vmap(f, in_axes=(0, None))

x = jnp.array([[1, 2], [3, 4]])
y = jnp.array([5, 6])

result = f_vmap(x, y)  # shape (2, 2): each row of x plus the shared y
print(result)
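
For contrast, here is a sketch of the failure mode: if in_axes marks an argument as batched but its axis size disagrees with the other batched arguments, vmap raises an error at call time (reusing f and x from above):

python
f_bad = jax.vmap(f, in_axes=(0, 0))  # now claims y is batched too

try:
    f_bad(x, jnp.array([5, 6, 7]))  # x has batch size 2, y has size 3
except ValueError as err:
    print(f"vmap error: {err}")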

Next Steps #

Now that you have learned how to compose function transformations, continue to Array Operations to explore JAX's numerical computing capabilities in depth!

Last updated: 2026-04-04