Automatic Differentiation (grad) #
What Is Automatic Differentiation? #
Automatic differentiation (AD) is a technique for computing the derivatives of functions programmatically. JAX provides powerful AD facilities that can compute gradients of arbitrarily complex functions for you.
Manual vs. Automatic Differentiation #
python
import jax
import jax.numpy as jnp

def f(x):
    return x ** 2 + 2 * x + 1

# Derivative worked out by hand: f'(x) = 2x + 2
def manual_grad(x):
    return 2 * x + 2

# jax.grad builds the derivative function automatically
auto_grad = jax.grad(f)

x = 3.0
print(f"Manual gradient:    {manual_grad(x)}")
print(f"Automatic gradient: {auto_grad(x)}")
Basic Usage #
First-Order Derivatives #
python
import jax
import jax.numpy as jnp

def f(x):
    return x ** 3

grad_f = jax.grad(f)

print(f"f(2) = {f(2.0)}")
print(f"f'(2) = {grad_f(2.0)}")
Functions of Several Variables #
python
import jax
import jax.numpy as jnp

def f(x, y):
    return x ** 2 + y ** 2

# argnums selects which positional arguments to differentiate with respect to
grad_f = jax.grad(f, argnums=(0, 1))

x, y = 3.0, 4.0
gx, gy = grad_f(x, y)
print(f"∂f/∂x = {gx}")
print(f"∂f/∂y = {gy}")
Differentiating a Specific Argument #
python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    predict = jnp.dot(x, params)
    return jnp.mean((predict - y) ** 2)

# argnums=0: differentiate with respect to params only
grad_loss = jax.grad(loss, argnums=0)

params = jnp.array([1.0, 2.0])
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([5.0, 11.0])
grads = grad_loss(params, x, y)
print(f"Parameter gradients: {grads}")
Higher-Order Derivatives #
Second Derivatives #
python
import jax
import jax.numpy as jnp

def f(x):
    return x ** 4

grad_f = jax.grad(f)            # f'(x) = 4x^3
grad_grad_f = jax.grad(grad_f)  # f''(x) = 12x^2

x = 2.0
print(f"f({x}) = {f(x)}")
print(f"f'({x}) = {grad_f(x)}")
print(f"f''({x}) = {grad_grad_f(x)}")
n-th Derivatives #
python
import jax
import jax.numpy as jnp

def nth_derivative(f, n):
    # Compose jax.grad n times
    if n == 0:
        return f
    return nth_derivative(jax.grad(f), n - 1)

def f(x):
    return x ** 5

for n in range(6):
    deriv = nth_derivative(f, n)
    print(f"f^({n})(2) = {deriv(2.0)}")
The Hessian Matrix #
python
import jax
import jax.numpy as jnp

def f(x):
    return x[0] ** 2 + x[1] ** 2 + x[0] * x[1]

hessian_f = jax.hessian(f)

x = jnp.array([1.0, 2.0])
H = hessian_f(x)
print(f"Hessian matrix:\n{H}")
The Jacobian Matrix #
python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.array([x[0] ** 2, x[0] * x[1], x[1] ** 2])

jacobian_f = jax.jacobian(f)

x = jnp.array([1.0, 2.0])
J = jacobian_f(x)
print(f"Jacobian matrix:\n{J}")
Kinds of Gradients #
Gradients of Scalar Outputs #
python
import jax
import jax.numpy as jnp

def scalar_loss(params):
    return jnp.sum(params ** 2)

grad_fn = jax.grad(scalar_loss)

params = jnp.array([1.0, 2.0, 3.0])
grads = grad_fn(params)
print(f"Gradient: {grads}")
Derivatives of Vector Outputs #
python
import jax
import jax.numpy as jnp

def vector_fn(x):
    return x ** 2

# jax.grad requires a scalar output; use jax.jacobian for vector outputs
jacobian_fn = jax.jacobian(vector_fn)

x = jnp.array([1.0, 2.0, 3.0])
J = jacobian_fn(x)
print(f"Jacobian:\n{J}")
Value and Gradient Together #
python
import jax
import jax.numpy as jnp

def f(x):
    return x ** 2

# Evaluate the function and its gradient in a single call
value_and_grad = jax.value_and_grad(f)

x = 3.0
val, grad = value_and_grad(x)
print(f"Value: {val}")
print(f"Gradient: {grad}")
Related Functions #
jax.grad #
python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)

grad_f = jax.grad(f)
print(grad_f(jnp.array([1.0, 2.0, 3.0])))  # [2. 4. 6.]
jax.value_and_grad #
python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)

val_grad_f = jax.value_and_grad(f)
value, gradient = val_grad_f(jnp.array([1.0, 2.0, 3.0]))
print(value, gradient)  # 14.0 [2. 4. 6.]
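When the loss function also returns auxiliary data (metrics, intermediate predictions), both jax.grad and jax.value_and_grad accept has_aux=True so that only the first output is differentiated. A minimal sketch, where the auxiliary value is simply the raw predictions:
python
import jax
import jax.numpy as jnp

def loss_with_aux(params, x, y):
    pred = jnp.dot(x, params)
    loss = jnp.mean((pred - y) ** 2)
    return loss, pred  # (scalar to differentiate, auxiliary output)

val_grad_f = jax.value_and_grad(loss_with_aux, has_aux=True)

params = jnp.array([1.0, 2.0])
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([5.0, 11.0])
(loss, pred), grads = val_grad_f(params, x, y)
print(loss, pred, grads)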
jax.jacfwd and jax.jacrev #
jax.jacfwd computes the Jacobian with forward-mode AD, which is efficient for "tall" Jacobians (more outputs than inputs); jax.jacrev uses reverse-mode AD, which is efficient for "wide" Jacobians (more inputs than outputs). Both return the same matrix.
python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.array([x[0] ** 2, x[0] * x[1], x[1] ** 2])

jac_forward = jax.jacfwd(f)
jac_reverse = jax.jacrev(f)

x = jnp.array([1.0, 2.0])
print(f"Forward-mode Jacobian:\n{jac_forward(x)}")
print(f"Reverse-mode Jacobian:\n{jac_reverse(x)}")
Stopping Gradients #
stop_gradient #
python
import jax
import jax.numpy as jnp

def loss(x, y):
    # Gradients will not flow through y
    return jnp.sum((x - jax.lax.stop_gradient(y)) ** 2)

grad_loss_x = jax.grad(loss, argnums=0)
grad_loss_y = jax.grad(loss, argnums=1)

x, y = 2.0, 3.0
print(f"∂loss/∂x = {grad_loss_x(x, y)}")  # -2.0
print(f"∂loss/∂y = {grad_loss_y(x, y)}")  # 0.0
Example Use Case #
A common use of stop_gradient is to block gradients through one branch of a loss, for example treating one embedding as a fixed target:
python
import jax
import jax.numpy as jnp

def similarity_loss(z1, z2, temperature=0.1):
    # Treat z2 as a constant target: gradients flow only through z1
    target = jax.lax.stop_gradient(z2)
    similarity = jnp.dot(z1, target) / temperature
    return -similarity

z1 = jnp.array([1.0, 2.0, 3.0])
z2 = jnp.array([1.0, 2.0, 3.0])
grad_z1 = jax.grad(similarity_loss, argnums=0)(z1, z2)
print(f"Gradient: {grad_z1}")
Custom Gradients #
custom_jvp #
python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x):
    return jnp.sin(x)

@f.defjvp
def f_jvp(primals, tangents):
    # Forward-mode rule: return the primal output and its tangent
    x, = primals
    x_dot, = tangents
    primal_out = f(x)
    tangent_out = jnp.cos(x) * x_dot
    return primal_out, tangent_out

grad_f = jax.grad(f)
print(f"f'(π/2) = {grad_f(jnp.pi / 2)}")
custom_vjp #
python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
    return jnp.sin(x)

def f_fwd(x):
    # Return the output plus residuals needed by the backward pass
    return f(x), x

def f_bwd(x, g):
    # x is the residual, g is the incoming cotangent
    return (g * jnp.cos(x),)

f.defvjp(f_fwd, f_bwd)

grad_f = jax.grad(f)
print(f"f'(π/2) = {grad_f(jnp.pi / 2)}")
Practical Applications #
Training a Neural Network #
python
import jax
import jax.numpy as jnp

def predict(params, x):
    w1, b1, w2, b2 = params
    h = jnp.maximum(jnp.dot(x, w1) + b1, 0)  # ReLU hidden layer
    return jnp.dot(h, w2) + b2

def mse_loss(params, x, y):
    pred = predict(params, x)
    return jnp.mean((pred - y) ** 2)

grad_loss = jax.grad(mse_loss)

key = jax.random.PRNGKey(0)
key1, key2, key3, key4, key5, key6 = jax.random.split(key, 6)
params = [
    jax.random.normal(key1, (10, 20)),
    jax.random.normal(key2, (20,)),
    jax.random.normal(key3, (20, 5)),
    jax.random.normal(key4, (5,)),
]
x = jax.random.normal(key5, (32, 10))
y = jax.random.normal(key6, (32, 5))

grads = grad_loss(params, x, y)
print(f"Number of gradient arrays: {len(grads)}")
for i, g in enumerate(grads):
    print(f"Parameter {i} gradient shape: {g.shape}")
An Optimization Problem #
python
import jax
import jax.numpy as jnp

def objective(x):
    return jnp.sum(x ** 2) + jnp.sum(jnp.sin(x))

grad_obj = jax.grad(objective)

# Plain gradient descent
x = jnp.array([1.0, 2.0, 3.0])
lr = 0.1
for i in range(100):
    g = grad_obj(x)
    x = x - lr * g
    if i % 20 == 0:
        print(f"Iteration {i}: objective = {objective(x):.6f}")
Performance Tips #
JIT-Compiling Gradients #
python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    return jnp.mean((jnp.dot(x, params) - y) ** 2)

# Composing jit and grad compiles the gradient computation
grad_loss = jax.jit(jax.grad(loss))

params = jnp.array([1.0, 2.0])
x = jax.random.normal(jax.random.PRNGKey(0), (100, 2))
y = jax.random.normal(jax.random.PRNGKey(1), (100,))
grads = grad_loss(params, x, y)
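In a training loop it usually pays to jit the whole update step (loss value, gradient, and parameter update together) rather than only the gradient. A minimal sketch building on the same loss, with the learning rate fixed for simplicity:
python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    return jnp.mean((jnp.dot(x, params) - y) ** 2)

@jax.jit
def update(params, x, y):
    # One compiled step: loss value, gradient, and SGD update together
    val, grads = jax.value_and_grad(loss)(params, x, y)
    return params - 0.1 * grads, val

params = jnp.array([1.0, 2.0])
x = jax.random.normal(jax.random.PRNGKey(0), (100, 2))
y = jax.random.normal(jax.random.PRNGKey(1), (100,))
params, val = update(params, x, y)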
Batched Gradient Computation #
python
import jax
import jax.numpy as jnp

def single_loss(params, x, y):
    # Loss for a single example
    return (jnp.dot(x, params) - y) ** 2

# vmap over the example axis gives one gradient per example
batch_grad = jax.vmap(jax.grad(single_loss), in_axes=(None, 0, 0))

params = jnp.array([1.0, 2.0])
x = jax.random.normal(jax.random.PRNGKey(0), (10, 2))
y = jax.random.normal(jax.random.PRNGKey(1), (10,))
grads = batch_grad(params, x, y)
print(f"Per-example gradient shape: {grads.shape}")  # (10, 2)
Common Pitfalls #
Pitfall 1: Non-Scalar Outputs #
python
import jax
import jax.numpy as jnp

def f(x):
    return x ** 2

# jax.grad requires the function to return a scalar
try:
    grad_f = jax.grad(f)
    grad_f(jnp.array([1.0, 2.0]))
except Exception as e:
    print(f"Error: {e}")

# Fix: reduce to a scalar (or use jax.jacobian instead)
grad_f = jax.grad(lambda x: jnp.sum(x ** 2))
print(f"Gradient: {grad_f(jnp.array([1.0, 2.0]))}")
Pitfall 2: Integer Inputs #
python
import jax
import jax.numpy as jnp

def f(x):
    return x ** 2

# Differentiating with respect to an integer argument raises an error
try:
    grad_f = jax.grad(f)
    grad_f(2)
except Exception as e:
    print(f"Error: {e}")

# Fix: pass a floating-point value
print(f"Correct: {jax.grad(f)(2.0)}")
Next Steps #
Now that you have automatic differentiation under your belt, continue with Automatic Vectorization (vmap) to learn how to batch computations automatically!
Last updated: 2026-04-04