自动向量化 (vmap) #

什么是 vmap? #

vmap(vectorizing map)是 JAX 提供的自动向量化功能,它可以将处理单个样本的函数自动转换为处理批量样本的函数,无需手动编写循环。vmap 并不是在内部逐个样本地执行循环,而是把批量维度传递给函数中的每个底层运算,因此得到的批量函数通常与手工编写的批量化代码一样高效。

手动批处理 vs vmap #

python
import jax
import jax.numpy as jnp

def process_single(x):
    """Transform one sample: square every element, then add one."""
    return 1 + x ** 2


def process_batch_manual(batch_x):
    """Reference batching: call ``process_single`` on each row in Python."""
    return jnp.stack([process_single(row) for row in batch_x])


# vmap builds the batched version automatically (maps over axis 0 by default).
process_batch_vmap = jax.vmap(process_single)

batch_x = jnp.array([[1, 2], [3, 4], [5, 6]])

result_manual, result_vmap = (
    process_batch_manual(batch_x),
    process_batch_vmap(batch_x),
)

print(f"手动结果: {result_manual}")
print(f"vmap 结果: {result_vmap}")

基本用法 #

简单向量化 #

python
import jax
import jax.numpy as jnp

def f(x):
    """Sum of the squared entries of one (unbatched) vector."""
    return jnp.sum(jnp.square(x))


# Lift f so that it works over a leading batch axis.
f_batch = jax.vmap(f)

x_single = jnp.array([1, 2, 3])
x_batch = jnp.stack([x_single, x_single + 3, x_single + 6])

print(f"单个输入: {f(x_single)}")
print(f"批量输入: {f_batch(x_batch)}")

多参数向量化 #

python
import jax
import jax.numpy as jnp

def dot_product(x, y):
    """Return the dot product of two vectors."""
    return jnp.dot(x, y)


# in_axes=0 (the default) maps both arguments over their leading axis.
dot_batch = jax.vmap(dot_product, in_axes=0)

x_batch = jnp.array([[1, 2], [3, 4], [5, 6]])
y_batch = jnp.array([[1, 1], [2, 2], [3, 3]])

result = dot_batch(x_batch, y_batch)
print(f"批量点积: {result}")

轴映射 (in_axes) #

指定批处理轴 #

python
import jax
import jax.numpy as jnp

def f(x, y):
    """Add two samples together."""
    total = x + y
    return total


# Both arguments carry the batch on axis 0 (same as the default in_axes).
both_batched = (0, 0)
f_vmap = jax.vmap(f, in_axes=both_batched)

x = jnp.array([[1, 2], [3, 4]])
y = jnp.array([[5, 6], [7, 8]])

result = f_vmap(x, y)
print(f"结果: {result}")

部分参数批处理 #

python
import jax
import jax.numpy as jnp

def apply_weight(x, weight):
    """Scale one sample elementwise by a shared weight vector."""
    return x * weight


# Map over the rows of x; `None` gives the same weight to every row.
apply_batch = jax.vmap(apply_weight, in_axes=(0, None))

x_batch = jnp.arange(1, 7).reshape(3, 2)
weight = jnp.array([10, 20])

result = apply_batch(x_batch, weight)
print(f"结果:\n{result}")

不同轴映射 #

python
import jax
import jax.numpy as jnp

def f(x, y):
    """Outer product of one x-vector with one y-vector."""
    return jnp.outer(x, y)


x = jnp.array([[1, 2], [3, 4]])
y = jnp.array([[5, 6], [7, 8]])


def _mapped(axes):
    """Return ``f`` vectorized with the given ``in_axes`` pair."""
    return jax.vmap(f, in_axes=axes)


# in_axes selects the slice fed to f: 0 iterates over rows, 1 over columns.
#   (0, 0): rows of x with rows of y
#   (0, 1): rows of x with columns of y
#   (1, 0): columns of x with rows of y
f_vmap_0_0, f_vmap_0_1, f_vmap_1_0 = map(_mapped, [(0, 0), (0, 1), (1, 0)])

print(f"in_axes=(0, 0):\n{f_vmap_0_0(x, y)}")
print(f"in_axes=(0, 1):\n{f_vmap_0_1(x, y)}")
print(f"in_axes=(1, 0):\n{f_vmap_1_0(x, y)}")

输出轴 (out_axes) #

默认输出轴 #

python
import jax
import jax.numpy as jnp

def f(x):
    """Elementwise square of a single sample."""
    return x * x


# With no out_axes given, vmap stacks the per-sample results along axis 0.
f_vmap = jax.vmap(f)

x = jnp.array([[1, 2], [3, 4]])
result = f_vmap(x)
print(f"默认 out_axes=0:\n{result}")

指定输出轴 #

python
import jax
import jax.numpy as jnp

def f(x):
    """Elementwise square of a single sample."""
    return x * x


# out_axes sets where the batch axis goes in the result: 0 stacks results as
# rows, 1 puts the batch on the second axis, so the results become columns.
f_vmap_0, f_vmap_1 = (jax.vmap(f, out_axes=axis) for axis in (0, 1))

x = jnp.array([[1, 2], [3, 4]])

print(f"out_axes=0:\n{f_vmap_0(x)}")
print(f"out_axes=1:\n{f_vmap_1(x)}")

嵌套 vmap #

多维批处理 #

python
import jax
import jax.numpy as jnp

def process(x):
    """Elementwise square of a single sample."""
    return x * x


# Each vmap maps one leading axis; stacking two maps over a 2-D grid of samples.
_over_inner = jax.vmap(process)
process_2d = jax.vmap(_over_inner)

x_2d = jnp.arange(1, 9).reshape(2, 2, 2)
result = process_2d(x_2d)
print(f"2D 批处理结果:\n{result}")

不同轴嵌套 #

python
import jax
import jax.numpy as jnp

def outer_product(x, y):
    """Outer product of two vectors."""
    return jnp.outer(x, y)


# Inner map: every row of x against one fixed y vector.
over_x_rows = jax.vmap(outer_product, in_axes=(0, None))
# Outer map: repeat that for every row of y while x is shared.
outer_2d = jax.vmap(over_x_rows, in_axes=(None, 0))

x = jnp.array([[1, 2], [3, 4]])
y = jnp.array([[5, 6], [7, 8]])

result = outer_2d(x, y)
print(f"嵌套 vmap 结果形状: {result.shape}")

实际应用 #

神经网络推理 #

python
import jax
import jax.numpy as jnp

def predict(params, x):
    """Linear layer applied to a single sample: ``x @ w + b``.

    Args:
        params: ``(w, b)`` tuple; in this example ``w`` has shape (10, 5)
            and ``b`` has shape (5,).
        x: one input vector (shape (10,) here).

    Returns:
        The prediction for that one sample.
    """
    w, b = params
    return jnp.dot(x, w) + b


# params is shared by all samples (None); x is batched along axis 0.
predict_batch = jax.vmap(predict, in_axes=(None, 0))

key = jax.random.PRNGKey(0)
# JAX PRNG keys must not be reused: drawing w, b and x from the same key gives
# correlated "random" values. Split once into independent sub-keys.
key_w, key_b, key_x = jax.random.split(key, 3)
params = (
    jax.random.normal(key_w, (10, 5)),
    jax.random.normal(key_b, (5,)),
)

x_batch = jax.random.normal(key_x, (32, 10))
predictions = predict_batch(params, x_batch)
print(f"预测形状: {predictions.shape}")

批量损失计算 #

python
import jax
import jax.numpy as jnp

def single_loss(params, x, y):
    """Squared error of a linear model on one sample.

    Args:
        params: ``(w, b)``, a weight vector and a bias.
        x: one input vector.
        y: the target for that sample.

    Returns:
        ``(x @ w + b - y) ** 2``. This is a scalar when ``b`` is a scalar.
    """
    w, b = params
    pred = jnp.dot(x, w) + b
    return (pred - y) ** 2


batch_loss = jax.vmap(single_loss, in_axes=(None, 0, 0))

# The bias is a scalar, not jnp.array([0.1]). A shape-(1,) bias would make each
# per-sample loss shape (1,) and the batched result (3, 1) instead of (3,).
params = (jnp.array([1.0, 2.0]), jnp.array(0.1))
x_batch = jnp.array([[1, 2], [3, 4], [5, 6]])
y_batch = jnp.array([5, 11, 17])

losses = batch_loss(params, x_batch, y_batch)
print(f"批量损失: {losses}")

批量梯度计算 #

python
import jax
import jax.numpy as jnp

def single_grad(params, x, y):
    """Gradient of one sample's squared error with respect to ``params``.

    Args:
        params: ``(w, b)``, a weight vector and a bias.
        x: one input vector.
        y: the target for that sample.

    Returns:
        A ``(dw, db)`` tuple with the same structure as ``params``.
    """
    def loss_fn(p):
        w, b = p
        pred = jnp.dot(x, w) + b
        # jax.grad only accepts scalar-output functions. With a shape-(1,) bias
        # the squared error has shape (1,), so reduce it explicitly. jnp.sum
        # changes nothing when the value is already a scalar.
        return jnp.sum((pred - y) ** 2)
    return jax.grad(loss_fn)(params)


batch_grad = jax.vmap(single_grad, in_axes=(None, 0, 0))

# Scalar bias, so the gradient for b is one scalar per sample.
params = (jnp.array([1.0, 2.0]), jnp.array(0.1))
x_batch = jnp.array([[1, 2], [3, 4], [5, 6]])
y_batch = jnp.array([5, 11, 17])

grads = batch_grad(params, x_batch, y_batch)
print(f"批量梯度: {grads}")

性能对比 #

vmap vs 循环 #

python
import time
import jax
import jax.numpy as jnp

def process(x):
    """Return the sum of squares of a single sample ``x``."""
    return jnp.sum(x ** 2)


# jit + vmap: the whole batch becomes one compiled XLA computation.
process_vmap = jax.jit(jax.vmap(process))


def process_loop(batch_x):
    """Apply ``process`` to each row of ``batch_x`` with a Python loop."""
    return jnp.stack([process(x) for x in batch_x])


# process_loop is deliberately NOT wrapped in jax.jit. Tracing would unroll the
# Python loop into separate ops for every row, which makes compilation very slow
# for large batches and means the benchmark no longer measures a loop.


def time_per_call(fn, x, n_runs=5):
    """Return the mean wall-clock seconds per call of ``fn(x)``.

    One warm-up call runs first so that JIT compilation is not timed.
    ``block_until_ready`` waits for JAX's asynchronous dispatch to finish.
    """
    fn(x).block_until_ready()
    start = time.perf_counter()
    for _ in range(n_runs):
        fn(x).block_until_ready()
    return (time.perf_counter() - start) / n_runs


batch_x = jax.random.normal(jax.random.PRNGKey(0), (1000, 100))

print(f"结果一致: {bool(jnp.allclose(process_vmap(batch_x), process_loop(batch_x)))}")
print(f"vmap 时间: {time_per_call(process_vmap, batch_x):.6f}s / 次")
print(f"循环 时间: {time_per_call(process_loop, batch_x):.6f}s / 次")

高级用法 #

条件批处理 #

python
import jax
import jax.numpy as jnp

def process_with_condition(x):
    """Square positive entries; for all other entries return minus the square."""
    squared = x ** 2
    return jnp.where(x > 0, squared, -squared)


# jnp.where is applied per element, so the conditional vectorizes cleanly.
process_batch = jax.vmap(process_with_condition)

x = jnp.array([[-1, 2], [3, -4], [-5, 6]])
result = process_batch(x)
print(f"条件批处理结果:\n{result}")

状态处理 #

python
import jax
import jax.numpy as jnp

def update_state(state, x):
    """One update step: return ``(state + x, state * x)``."""
    return state + x, state * x


# vmap with in_axes=(None, 0) combines every sample with the SAME initial state,
# independently. Nothing carries over from one sample to the next, so the first
# result is a batch of per-sample states, not one accumulated state.
update_batch = jax.vmap(update_state, in_axes=(None, 0))

state = jnp.array([1.0, 2.0])
x_batch = jnp.array([[1, 2], [3, 4], [5, 6]])

new_state, outputs = update_batch(state, x_batch)
print(f"新状态 (每个样本独立计算): {new_state}")
print(f"输出: {outputs}")

# To thread the state through the samples in order, use jax.lax.scan instead.
# The carry returned by one step becomes the state passed to the next step.
final_state, scan_outputs = jax.lax.scan(update_state, state, x_batch)
print(f"scan 最终状态: {final_state}")
print(f"scan 输出: {scan_outputs}")

常见问题 #

问题 1: 形状不匹配 #

python
import jax
import jax.numpy as jnp

def f(x, y):
    """Add two samples."""
    return x + y


# Both arguments are mapped over axis 0, but x has 2 rows and y has 3 elements.
# vmap cannot pair them up and raises ValueError ("inconsistent sizes").
f_vmap = jax.vmap(f, in_axes=(0, 0))
x = jnp.array([[1, 2], [3, 4]])
y = jnp.array([1, 2, 3])
try:
    result = f_vmap(x, y)
except ValueError as e:
    print(f"错误: {e}")

# Fix: y is one vector shared by every row of x, so leave it unbatched.
f_vmap = jax.vmap(f, in_axes=(0, None))
x = jnp.array([[1, 2], [3, 4]])
y = jnp.array([1, 2])
result = f_vmap(x, y)
print(f"正确结果:\n{result}")

问题 2: 非数组输入 #

python
import jax
import jax.numpy as jnp

def f(x, scale):
    """Multiply one sample by a scalar factor."""
    return x * scale


# A plain Python float has no axis to map over, so mark it unbatched with None.
f_vmap = jax.vmap(f, in_axes=(0, None))

x = jnp.array([[1, 2], [3, 4]])
scale = 2.0

result = f_vmap(x, scale)
print(f"结果:\n{result}")

最佳实践 #

1. 使用 JIT 加速 #

python
import jax
import jax.numpy as jnp

def process(x):
    """Per-sample reduction: the sum of squared entries."""
    return jnp.sum(jnp.square(x))


# Vectorize first, then compile the batched function once with jit.
_batched = jax.vmap(process)
process_batch = jax.jit(_batched)

2. 合理设置轴映射 #

python
import jax
import jax.numpy as jnp

def apply_weights(x, weights):
    """Return ``jnp.dot(x, weights)`` for a single sample ``x``."""
    return jnp.dot(x, weights)


# Only x carries the batch axis; every sample shares the same weights.
shared_weights_axes = (0, None)
apply_batch = jax.vmap(apply_weights, in_axes=shared_weights_axes)

3. 组合其他变换 #

python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    """Mean squared error of the linear prediction ``x @ params`` against ``y``."""
    residual = jnp.dot(x, params) - y
    return jnp.mean(residual ** 2)


# Gradient with respect to the first argument (params).
grad_loss = jax.grad(loss)

# Per-sample gradients: params is shared; x and y are taken one row at a time.
batch_grad = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# Compile the composed transform once for reuse.
fast_batch_grad = jax.jit(batch_grad)

下一步 #

现在你已经掌握了自动向量化,接下来学习 JIT 编译,了解如何进一步优化性能!

最后更新:2026-04-04