Automatic Vectorization (vmap) #
What is vmap? #
vmap (vectorizing map) is JAX's automatic vectorization transform. It turns a function written for a single sample into one that handles a whole batch of samples, with no hand-written loops.
Manual Batching vs. vmap #
python
import jax
import jax.numpy as jnp

def process_single(x):
    return x ** 2 + 1

def process_batch_manual(batch_x):
    # manual batching: loop over the batch in Python
    return jnp.array([process_single(x) for x in batch_x])

# vmap batching: let JAX vectorize the single-sample function
process_batch_vmap = jax.vmap(process_single)

batch_x = jnp.array([[1, 2], [3, 4], [5, 6]])
result_manual = process_batch_manual(batch_x)
result_vmap = process_batch_vmap(batch_x)
print(f"Manual result: {result_manual}")
print(f"vmap result: {result_vmap}")
Basic Usage #
Simple Vectorization #
python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)

f_batch = jax.vmap(f)

x_single = jnp.array([1, 2, 3])
x_batch = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(f"Single input: {f(x_single)}")
print(f"Batched input: {f_batch(x_batch)}")
Vectorizing Over Multiple Arguments #
python
import jax
import jax.numpy as jnp

def dot_product(x, y):
    return jnp.dot(x, y)

dot_batch = jax.vmap(dot_product)

x_batch = jnp.array([[1, 2], [3, 4], [5, 6]])
y_batch = jnp.array([[1, 1], [2, 2], [3, 3]])
result = dot_batch(x_batch, y_batch)
print(f"Batched dot products: {result}")
Axis Mapping (in_axes) #
Specifying the Batch Axis #
python
import jax
import jax.numpy as jnp

def f(x, y):
    return x + y

# map both arguments over axis 0
f_vmap = jax.vmap(f, in_axes=(0, 0))

x = jnp.array([[1, 2], [3, 4]])
y = jnp.array([[5, 6], [7, 8]])
result = f_vmap(x, y)
print(f"Result: {result}")
Batching Only Some Arguments #
python
import jax
import jax.numpy as jnp

def apply_weight(x, weight):
    return x * weight

# in_axes=(0, None): batch over x, share the same weight across the batch
apply_batch = jax.vmap(apply_weight, in_axes=(0, None))

x_batch = jnp.array([[1, 2], [3, 4], [5, 6]])
weight = jnp.array([10, 20])
result = apply_batch(x_batch, weight)
print(f"Result:\n{result}")
Mapping Over Different Axes #
python
import jax
import jax.numpy as jnp

def f(x, y):
    return jnp.outer(x, y)

x = jnp.array([[1, 2], [3, 4]])
y = jnp.array([[5, 6], [7, 8]])

f_vmap_0_0 = jax.vmap(f, in_axes=(0, 0))
f_vmap_0_1 = jax.vmap(f, in_axes=(0, 1))
f_vmap_1_0 = jax.vmap(f, in_axes=(1, 0))
print(f"in_axes=(0, 0):\n{f_vmap_0_0(x, y)}")
print(f"in_axes=(0, 1):\n{f_vmap_0_1(x, y)}")
print(f"in_axes=(1, 0):\n{f_vmap_1_0(x, y)}")
Output Axes (out_axes) #
Default Output Axis #
python
import jax
import jax.numpy as jnp

def f(x):
    return x ** 2

f_vmap = jax.vmap(f)

x = jnp.array([[1, 2], [3, 4]])
result = f_vmap(x)
print(f"Default out_axes=0:\n{result}")
Specifying the Output Axis #
python
import jax
import jax.numpy as jnp

def f(x):
    return x ** 2

f_vmap_0 = jax.vmap(f, out_axes=0)
f_vmap_1 = jax.vmap(f, out_axes=1)

x = jnp.array([[1, 2], [3, 4]])
print(f"out_axes=0:\n{f_vmap_0(x)}")
print(f"out_axes=1:\n{f_vmap_1(x)}")
Nested vmap #
Multi-Dimensional Batching #
python
import jax
import jax.numpy as jnp

def process(x):
    return x ** 2

# one vmap per batch dimension
process_2d = jax.vmap(jax.vmap(process))

x_2d = jnp.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
result = process_2d(x_2d)
print(f"2D batch result:\n{result}")
Nesting Over Different Axes #
python
import jax
import jax.numpy as jnp

def outer_product(x, y):
    return jnp.outer(x, y)

# inner vmap maps over rows of x; outer vmap maps over rows of y
outer_2d = jax.vmap(jax.vmap(outer_product, in_axes=(0, None)), in_axes=(None, 0))

x = jnp.array([[1, 2], [3, 4]])
y = jnp.array([[5, 6], [7, 8]])
result = outer_2d(x, y)
print(f"Nested vmap result shape: {result.shape}")
Practical Applications #
Neural Network Inference #
python
import jax
import jax.numpy as jnp

def predict(params, x):
    w, b = params
    return jnp.dot(x, w) + b

# share the parameters, batch over the inputs
predict_batch = jax.vmap(predict, in_axes=(None, 0))

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
params = (
    jax.random.normal(k1, (10, 5)),
    jax.random.normal(k2, (5,)),
)
x_batch = jax.random.normal(k3, (32, 10))
predictions = predict_batch(params, x_batch)
print(f"Prediction shape: {predictions.shape}")
Batched Loss Computation #
python
import jax
import jax.numpy as jnp

def single_loss(params, x, y):
    w, b = params
    pred = jnp.dot(x, w) + b
    return (pred - y) ** 2

# share the parameters, batch over (x, y) pairs
batch_loss = jax.vmap(single_loss, in_axes=(None, 0, 0))

params = (jnp.array([1.0, 2.0]), jnp.array(0.1))  # scalar bias, so each loss is a scalar
x_batch = jnp.array([[1, 2], [3, 4], [5, 6]])
y_batch = jnp.array([5, 11, 17])
losses = batch_loss(params, x_batch, y_batch)
print(f"Batched losses: {losses}")
Batched Gradient Computation #
python
import jax
import jax.numpy as jnp

def single_grad(params, x, y):
    def loss_fn(p):
        w, b = p
        pred = jnp.dot(x, w) + b
        return (pred - y) ** 2
    return jax.grad(loss_fn)(params)

# per-example gradients: share the parameters, batch over (x, y) pairs
batch_grad = jax.vmap(single_grad, in_axes=(None, 0, 0))

params = (jnp.array([1.0, 2.0]), jnp.array(0.1))  # bias must be a scalar: jax.grad needs a scalar-valued loss
x_batch = jnp.array([[1, 2], [3, 4], [5, 6]])
y_batch = jnp.array([5, 11, 17])
grads = batch_grad(params, x_batch, y_batch)
print(f"Per-example gradients: {grads}")
Performance Comparison #
vmap vs. Loops #
python
import time
import jax
import jax.numpy as jnp

def process(x):
    return jnp.sum(x ** 2)

process_vmap = jax.jit(jax.vmap(process))

def process_loop(batch_x):
    return jnp.array([process(x) for x in batch_x])

process_loop_jit = jax.jit(process_loop)

batch_x = jax.random.normal(jax.random.PRNGKey(0), (10000, 100))

# warm up (compile) both versions before timing
process_vmap(batch_x).block_until_ready()
process_loop_jit(batch_x).block_until_ready()

start = time.time()
for _ in range(100):
    result_vmap = process_vmap(batch_x)
result_vmap.block_until_ready()
print(f"vmap time: {time.time() - start:.4f}s")

start = time.time()
for _ in range(100):
    result_loop = process_loop_jit(batch_x)
result_loop.block_until_ready()
print(f"Loop time: {time.time() - start:.4f}s")
Advanced Usage #
Conditional Batching #
python
import jax
import jax.numpy as jnp

def process_with_condition(x):
    # elementwise branch: jnp.where works under vmap, unlike a Python `if`
    return jnp.where(x > 0, x ** 2, -x ** 2)

process_batch = jax.vmap(process_with_condition)

x = jnp.array([[-1, 2], [3, -4], [-5, 6]])
result = process_batch(x)
print(f"Conditional batching result:\n{result}")
Handling State #
python
import jax
import jax.numpy as jnp

def update_state(state, x):
    return state + x, state * x

# share one state across the whole batch of inputs
update_batch = jax.vmap(update_state, in_axes=(None, 0))

state = jnp.array([1.0, 2.0])
x_batch = jnp.array([[1, 2], [3, 4], [5, 6]])
new_state, outputs = update_batch(state, x_batch)
print(f"New states: {new_state}")
print(f"Outputs: {outputs}")
Common Pitfalls #
Pitfall 1: Shape Mismatch #
python
import jax
import jax.numpy as jnp

def f(x, y):
    return x + y

try:
    # batch sizes disagree: x has 2 rows but y has 3 elements along axis 0
    f_vmap = jax.vmap(f, in_axes=(0, 0))
    x = jnp.array([[1, 2], [3, 4]])
    y = jnp.array([1, 2, 3])
    result = f_vmap(x, y)
except Exception as e:
    print(f"Error: {e}")

# fix: only batch x and share y across the batch
f_vmap = jax.vmap(f, in_axes=(0, None))
x = jnp.array([[1, 2], [3, 4]])
y = jnp.array([1, 2])
result = f_vmap(x, y)
print(f"Correct result:\n{result}")
Pitfall 2: Non-Array Inputs #
python
import jax
import jax.numpy as jnp

def f(x, scale):
    return x * scale

# scale is a plain Python float, so mark it as unbatched with None
f_vmap = jax.vmap(f, in_axes=(0, None))

x = jnp.array([[1, 2], [3, 4]])
scale = 2.0
result = f_vmap(x, scale)
print(f"Result:\n{result}")
Best Practices #
1. Speed Things Up with JIT #
python
import jax
import jax.numpy as jnp

def process(x):
    return jnp.sum(x ** 2)

# compose the transforms: vectorize first, then compile the batched function
process_batch = jax.jit(jax.vmap(process))
2. Choose Axis Mappings Deliberately #
python
import jax
import jax.numpy as jnp

def apply_weights(x, weights):
    return jnp.dot(x, weights)

# batch over the inputs, share the weights
apply_batch = jax.vmap(apply_weights, in_axes=(0, None))
3. Compose with Other Transforms #
python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    return jnp.mean((jnp.dot(x, params) - y) ** 2)

grad_loss = jax.grad(loss)                              # gradient of the single-example loss
batch_grad = jax.vmap(grad_loss, in_axes=(None, 0, 0))  # per-example gradients
fast_batch_grad = jax.jit(batch_grad)                   # compile the whole pipeline
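To sanity-check the composed transform, here is a minimal self-contained sketch with made-up shapes (the loss definition is repeated for completeness):
python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    return jnp.mean((jnp.dot(x, params) - y) ** 2)

fast_batch_grad = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
params = jax.random.normal(k1, (4,))      # made-up parameter vector
x_batch = jax.random.normal(k2, (8, 4))   # 8 examples of size 4
y_batch = jax.random.normal(k3, (8,))     # 8 targets
per_example_grads = fast_batch_grad(params, x_batch, y_batch)
print(per_example_grads.shape)  # (8, 4): one gradient per example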
Next Steps #
Now that you have automatic vectorization under your belt, move on to JIT compilation to learn how to push performance even further!
Last updated: 2026-04-04