自动微分 (grad) #

什么是自动微分? #

自动微分（Automatic Differentiation，简称 AD）是一种计算函数导数的技术：它把程序拆解为一系列基本运算，再按链式法则组合各运算的导数，因此得到的导数只受浮点舍入误差影响，既不同于符号微分，也没有数值差分的截断误差。JAX 提供了强大的自动微分功能，可以自动计算各种复杂函数的梯度。

手动微分 vs 自动微分 #

python
import jax
import jax.numpy as jnp

def f(x):
    """Quadratic test function: x**2 + 2*x + 1."""
    return x ** 2 + 2 * x + 1


def manual_grad(x):
    """Hand-derived derivative of ``f``: d/dx (x^2 + 2x + 1) = 2x + 2."""
    return 2 * x + 2


# jax.grad(f) is a new function that evaluates f'(x); nothing is derived by hand.
auto_grad = jax.grad(f)

x = 3.0
print(f"手动梯度: {manual_grad(x)}")  # 2 * 3 + 2 = 8.0
print(f"自动梯度: {auto_grad(x)}")    # same value, computed by JAX

基本用法 #

一阶导数 #

python
import jax
import jax.numpy as jnp

def f(x):
    """Cube function f(x) = x^3."""
    return x ** 3


# Analytically f'(x) = 3x^2, so f'(2) = 12.
grad_f = jax.grad(f)

print(f"f(2) = {f(2.0)}")        # 8.0
print(f"f'(2) = {grad_f(2.0)}")  # 12.0

多变量函数 #

python
import jax
import jax.numpy as jnp

def f(x, y):
    """Sum of squares x^2 + y^2 of two scalars."""
    return x ** 2 + y ** 2


# argnums=(0, 1) differentiates w.r.t. both arguments; the result is a tuple
# (df/dx, df/dy).
grad_f = jax.grad(f, argnums=(0, 1))

x = 3.0
y = 4.0
partials = grad_f(x, y)
gx, gy = partials
print(f"∂f/∂x = {gx}")  # 2x = 6.0
print(f"∂f/∂y = {gy}")  # 2y = 8.0

指定参数 #

python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    """Mean squared error of the linear model ``x @ params`` against ``y``."""
    residual = jnp.dot(x, params) - y
    return jnp.mean(residual ** 2)


# argnums=0: differentiate only w.r.t. params; x and y are treated as data.
grad_loss = jax.grad(loss, argnums=0)

params = jnp.array([1.0, 2.0])
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([5.0, 11.0])  # exactly x @ params, so the gradient here is zero

grads = grad_loss(params, x, y)
print(f"参数梯度: {grads}")

高阶导数 #

二阶导数 #

python
import jax
import jax.numpy as jnp

def f(x):
    """Quartic f(x) = x^4."""
    return x ** 4


# Derivatives compose: differentiating grad_f again gives f''.
grad_f = jax.grad(f)           # 4x^3
grad_grad_f = jax.grad(grad_f)  # 12x^2

x = 2.0
print(f"f({x}) = {f(x)}")            # 16.0
print(f"f'({x}) = {grad_f(x)}")      # 32.0
print(f"f''({x}) = {grad_grad_f(x)}")  # 48.0

任意 n 阶导数 #

python
import jax
import jax.numpy as jnp

def nth_derivative(f, n):
    """Return the ``n``-th derivative of the scalar function ``f``.

    ``jax.grad`` is applied ``n`` times in a loop; ``n == 0`` returns ``f``
    itself.

    Raises:
        ValueError: if ``n`` is negative.
        TypeError: if ``n`` is not an integer (raised by ``range``).
    """
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")
    deriv = f
    for _ in range(n):
        deriv = jax.grad(deriv)
    return deriv


def f(x):
    """f(x) = x^5; its derivatives at 2 are 32, 80, 160, 240, 240, 120."""
    return x ** 5


for n in range(6):
    deriv = nth_derivative(f, n)
    print(f"f^({n})(2) = {deriv(2.0)}")

Hessian 矩阵 #

python
import jax
import jax.numpy as jnp

def f(x):
    """f(x) = x0^2 + x1^2 + x0*x1 built from the first two entries of x."""
    a, b = x[0], x[1]
    return a ** 2 + b ** 2 + a * b


# The matrix of second partial derivatives; for this quadratic it is the
# constant [[2, 1], [1, 2]].
hessian_f = jax.hessian(f)

x = jnp.array([1.0, 2.0])
H = hessian_f(x)
print(f"Hessian 矩阵:\n{H}")

Jacobian 矩阵 #

python
import jax
import jax.numpy as jnp

def f(x):
    """Vector-valued map of the first two entries of x: (x0^2, x0*x1, x1^2)."""
    x0, x1 = x[0], x[1]
    return jnp.array([x0 ** 2, x0 * x1, x1 ** 2])


# Row i of the Jacobian is the gradient of output i: [[2x0, 0], [x1, x0], [0, 2x1]].
jacobian_f = jax.jacobian(f)

x = jnp.array([1.0, 2.0])
J = jacobian_f(x)
print(f"Jacobian 矩阵:\n{J}")

梯度类型 #

标量输出梯度 #

python
import jax
import jax.numpy as jnp

def scalar_loss(params):
    """Scalar L2 penalty: the sum of squared entries of ``params``."""
    return (params ** 2).sum()


# A scalar output is what jax.grad requires; the gradient has params' shape.
grad_fn = jax.grad(scalar_loss)

params = jnp.array([1.0, 2.0, 3.0])
grads = grad_fn(params)
print(f"梯度: {grads}")  # 2 * params = [2. 4. 6.]

向量输出梯度 #

python
import jax
import jax.numpy as jnp

def vector_fn(x):
    """Element-wise square; the output has the same shape as x (not a scalar)."""
    return x ** 2


# A vector output needs jax.jacobian rather than jax.grad. For an element-wise
# function the Jacobian is diagonal: diag(2 * x).
jacobian_fn = jax.jacobian(vector_fn)

x = jnp.array([1.0, 2.0, 3.0])
J = jacobian_fn(x)
print(f"Jacobian:\n{J}")

值和梯度 #

python
import jax
import jax.numpy as jnp

def f(x):
    """f(x) = x^2."""
    return x ** 2


# One call returns the pair (f(x), f'(x)).
value_and_grad = jax.value_and_grad(f)

x = 3.0
(val, grad) = value_and_grad(x)
print(f"值: {val}")    # 9.0
print(f"梯度: {grad}")  # 6.0

辅助函数 #

jax.grad #

python
import jax
import jax.numpy as jnp

def f(x):
    """Sum of squares of all entries of x."""
    squares = x ** 2
    return jnp.sum(squares)


# grad_f(x) has the same shape as x and equals 2 * x.
grad_f = jax.grad(f)

jax.value_and_grad #

python
import jax
import jax.numpy as jnp

def f(x):
    """Sum of squares of all entries of x."""
    return (x ** 2).sum()


# A single call yields both f(x) and its gradient, without a separate call to f.
val_grad_f = jax.value_and_grad(f)
value, gradient = val_grad_f(jnp.array([1.0, 2.0, 3.0]))  # 14.0, [2. 4. 6.]

jax.jacfwd 和 jax.jacrev #

python
import jax
import jax.numpy as jnp

def f(x):
    """Vector-valued map of the first two entries of x: (x0^2, x0*x1, x1^2)."""
    first, second = x[0], x[1]
    return jnp.array([first ** 2, first * second, second ** 2])


# Both build the same 3x2 Jacobian. Forward mode (jacfwd) pushes one tangent per
# input column, reverse mode (jacrev) pulls one cotangent per output row, so
# jacfwd tends to be cheaper for "tall" Jacobians and jacrev for "wide" ones.
jac_forward = jax.jacfwd(f)
jac_reverse = jax.jacrev(f)

x = jnp.array([1.0, 2.0])
print(f"前向模式 Jacobian:\n{jac_forward(x)}")
print(f"反向模式 Jacobian:\n{jac_reverse(x)}")

停止梯度 #

stop_gradient #

python
import jax
import jax.numpy as jnp

def loss(x, y):
    """Squared distance between x and y, with y treated as a constant.

    ``stop_gradient`` hides ``y`` from differentiation, so the gradient
    w.r.t. ``y`` is always zero.
    """
    target = jax.lax.stop_gradient(y)
    return jnp.sum((x - target) ** 2)


grad_loss_x = jax.grad(loss, argnums=0)
grad_loss_y = jax.grad(loss, argnums=1)

x, y = 2.0, 3.0
print(f"∂loss/∂x = {grad_loss_x(x, y)}")  # 2 * (x - y) = -2.0
print(f"∂loss/∂y = {grad_loss_y(x, y)}")  # 0.0: y sits behind stop_gradient

应用场景 #

python
import jax
import jax.numpy as jnp

def contrastive_loss(z1, z2, temperature=0.1):
    """InfoNCE-style contrastive loss of anchor ``z1`` against candidates ``z2``.

    Args:
        z1: anchor embedding, shape (d,).
        z2: candidate embeddings, shape (d,) or (k, d). Row 0 is the positive
            pair and any further rows are negatives.
        temperature: softmax temperature; smaller values sharpen the distribution.

    Returns:
        Scalar loss ``-log_softmax(z2 @ z1 / temperature)[0]``.

    ``z2`` is wrapped in ``stop_gradient`` so it acts as a fixed target branch
    and gradients flow only into ``z1``. Applying ``stop_gradient`` to the
    similarity itself would instead make the loss constant w.r.t. both inputs.
    """
    # Treat the candidate branch as a constant target.
    targets = jax.lax.stop_gradient(jnp.atleast_2d(z2))
    logits = targets @ z1 / temperature
    # log_softmax is numerically stabler than log(softmax(...)).
    return -jax.nn.log_softmax(logits)[0]


z1 = jnp.array([0.6, 0.8, 0.0])
# Row 0 is the positive for z1; row 1 is a negative.
z2 = jnp.array([[0.8, 0.6, 0.0], [0.0, 0.6, 0.8]])

grad_z1 = jax.grad(contrastive_loss, argnums=0)(z1, z2)
print(f"梯度: {grad_z1}")
grad_z2 = jax.grad(contrastive_loss, argnums=1)(z1, z2)
print(f"z2 的梯度 (被 stop_gradient 截断): {grad_z2}")

自定义梯度 #

custom_jvp #

python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x):
    """sin(x) with a hand-written forward-mode (JVP) rule."""
    return jnp.sin(x)


@f.defjvp
def f_jvp(primals, tangents):
    """JVP rule for sin: returns (sin(x), cos(x) * dx)."""
    (x,) = primals
    (dx,) = tangents
    return f(x), jnp.cos(x) * dx


# jax.grad obtains reverse mode from the custom JVP rule.
grad_f = jax.grad(f)
print(f"f'(π/2) = {grad_f(jnp.pi / 2)}")

custom_vjp #

python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
    """sin(x) with a hand-written reverse-mode (VJP) rule."""
    return jnp.sin(x)


def f_fwd(x):
    """Forward pass: return f(x) plus the residual x that f_bwd needs."""
    return f(x), x


def f_bwd(x, g):
    """Backward pass: map cotangent g to g * cos(x) (one entry per input of f)."""
    cotangent = g * jnp.cos(x)
    return (cotangent,)


f.defvjp(f_fwd, f_bwd)

grad_f = jax.grad(f)
print(f"f'(π/2) = {grad_f(jnp.pi / 2)}")

实际应用 #

神经网络训练 #

python
import jax
import jax.numpy as jnp

def predict(params, x):
    """Forward pass of a two-layer MLP with a ReLU hidden layer.

    Args:
        params: sequence ``(w1, b1, w2, b2)``. ``x @ w1 + b1`` feeds the hidden
            layer and ``h @ w2 + b2`` produces the output.
        x: input batch whose last axis matches ``w1.shape[0]``.
    """
    w1, b1, w2, b2 = params
    h = jnp.maximum(jnp.dot(x, w1) + b1, 0)  # ReLU
    return jnp.dot(h, w2) + b2


def mse_loss(params, x, y):
    """Mean squared error between ``predict(params, x)`` and ``y``."""
    pred = predict(params, x)
    return jnp.mean((pred - y) ** 2)


# Gradient w.r.t. params (argnums=0); it has the same structure as params.
grad_loss = jax.grad(mse_loss)

# Every random draw gets its own subkey: a key that has been split must not be
# reused, otherwise the draws are not independent.
key = jax.random.PRNGKey(0)
key1, key2, key3, key4, key_x, key_y = jax.random.split(key, 6)
params = [
    jax.random.normal(key1, (10, 20)),
    jax.random.normal(key2, (20,)),
    jax.random.normal(key3, (20, 5)),
    jax.random.normal(key4, (5,)),
]

x = jax.random.normal(key_x, (32, 10))
y = jax.random.normal(key_y, (32, 5))

grads = grad_loss(params, x, y)
print(f"梯度数量: {len(grads)}")
for i, g in enumerate(grads):
    print(f"参数 {i} 梯度形状: {g.shape}")

优化问题 #

python
import jax
import jax.numpy as jnp

def objective(x):
    """sum(x^2) + sum(sin(x)); strictly convex since d²/dx² = 2 - sin(x) > 0."""
    return jnp.sum(x ** 2) + jnp.sum(jnp.sin(x))


grad_obj = jax.grad(objective)

x = jnp.array([1.0, 2.0, 3.0])
lr = 0.1

# Plain gradient descent; every entry converges to the root of 2x + cos(x) = 0.
for i in range(100):
    x -= lr * grad_obj(x)
    if not i % 20:
        print(f"迭代 {i}: 目标值 = {objective(x):.6f}")

性能优化 #

JIT 编译梯度 #

python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    """Mean squared error of the linear model ``x @ params`` against ``y``."""
    residual = jnp.dot(x, params) - y
    return jnp.mean(residual ** 2)


# Wrap the gradient function in jit so it is compiled once and the compiled
# version is reused on later calls with the same input shapes/dtypes.
grad_loss = jax.jit(jax.grad(loss))

params = jnp.array([1.0, 2.0])
x = jax.random.normal(jax.random.PRNGKey(0), (100, 2))
y = jax.random.normal(jax.random.PRNGKey(1), (100,))

grads = grad_loss(params, x, y)

批量梯度计算 #

python
import jax
import jax.numpy as jnp

def single_loss(params, x, y):
    """Squared error of one example: (x · params - y)^2."""
    err = jnp.dot(x, params) - y
    return err ** 2


# vmap maps over axis 0 of x and y while sharing params (in_axes=None),
# producing one gradient per example.
batch_grad = jax.vmap(jax.grad(single_loss), in_axes=(None, 0, 0))

params = jnp.array([1.0, 2.0])
x = jax.random.normal(jax.random.PRNGKey(0), (10, 2))
y = jax.random.normal(jax.random.PRNGKey(1), (10,))

grads = batch_grad(params, x, y)
print(f"批量梯度形状: {grads.shape}")  # (10, 2)

常见问题 #

问题 1: 非标量输出 #

python
import jax
import jax.numpy as jnp

def f(x):
    """Element-wise square; the output is an array, so jax.grad rejects it."""
    return x ** 2


grad_f = jax.grad(f)
try:
    grad_f(jnp.array([1.0, 2.0]))
except TypeError as e:  # jax.grad only accepts scalar-output functions
    print(f"错误: {e}")

# Fix: reduce to a scalar (here with sum) before differentiating.
grad_f = jax.grad(lambda x: jnp.sum(x ** 2))
print(f"梯度: {grad_f(jnp.array([1.0, 2.0]))}")

问题 2: 整数输入 #

python
import jax
import jax.numpy as jnp

def f(x):
    """Square function used to show that jax.grad needs floating-point inputs."""
    return x ** 2


# jax.grad only differentiates w.r.t. inexact (float/complex) inputs; a Python
# int is rejected with a TypeError.
grad_f = jax.grad(f)
try:
    grad_f(2)
except TypeError as e:
    print(f"错误: {e}")

print(f"正确: {jax.grad(f)(2.0)}")

下一步 #

现在你已经掌握了自动微分，接下来学习自动向量化 (vmap)，了解如何自动进行批处理！

最后更新:2026-04-04