自定义操作 #
概述 #
JAX 允许用户创建自定义操作,包括定义自定义微分规则以及与原生代码集成。
自定义微分 #
custom_jvp #
python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def my_function(x):
    """Return ``x ** 2``; its derivative is supplied by ``my_function_jvp``."""
    return x ** 2


@my_function.defjvp
def my_function_jvp(primals, tangents):
    """Forward-mode (JVP) rule for ``my_function``.

    Args:
        primals: 1-tuple ``(x,)`` holding the primal input.
        tangents: 1-tuple ``(x_dot,)`` holding the matching tangent input.

    Returns:
        ``(primal_out, tangent_out)`` with ``tangent_out = 2 * x * x_dot``.
    """
    (x,) = primals
    (x_dot,) = tangents
    primal_out = my_function(x)
    # d(x**2)/dx = 2x, scaled by the incoming tangent.
    tangent_out = 2 * x * x_dot
    return primal_out, tangent_out


# jax.grad works on custom_jvp functions: reverse mode is obtained by
# linearizing and transposing the JVP rule above.
grad_fn = jax.grad(my_function)
print(f"f'(3) = {grad_fn(3.0)}")
custom_vjp #
python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def my_function(x):
    """Return ``x ** 2``; its gradient is supplied by ``f_fwd`` / ``f_bwd``."""
    return x ** 2


def f_fwd(x):
    """Forward pass: return the primal output plus residuals for ``f_bwd``.

    The only residual needed is the input ``x`` itself.
    """
    return my_function(x), x


def f_bwd(res, g):
    """Backward pass: map the output cotangent ``g`` to the input cotangent.

    Args:
        res: residuals saved by ``f_fwd`` (here, the input ``x``).
        g: cotangent of the output.

    Returns:
        A tuple with one cotangent per primal argument: ``(2 * x * g,)``.
    """
    x = res
    return (2 * x * g,)


my_function.defvjp(f_fwd, f_bwd)

grad_fn = jax.grad(my_function)
print(f"f'(3) = {grad_fn(3.0)}")
自定义梯度 #
python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def clip_gradient(x):
    """Identity on the forward pass; clips the backward gradient to [-1, 1]."""
    return x


def clip_fwd(x):
    # The output is x itself, and the backward rule needs no residuals.
    return x, None


def clip_bwd(_, g):
    # The residual (None) is unused; clamp the cotangent elementwise to [-1, 1].
    return (jnp.clip(g, -1, 1),)


clip_gradient.defvjp(clip_fwd, clip_bwd)


def loss_fn(x):
    """Return ``x ** 2``, with the gradient reaching ``x`` clipped to [-1, 1]."""
    y = clip_gradient(x)
    return y ** 2


grad_fn = jax.grad(loss_fn)
# The unclipped gradient at 10.0 would be 2 * 10 = 20; clipping yields 1.0.
print(f"梯度: {grad_fn(10.0)}")
纯 Python 自定义操作 #
使用 lax #
python
import jax
import jax.numpy as jnp
from jax import lax


def my_cumsum(x):
    """Cumulative sum of ``x`` along axis 0, implemented with ``lax.scan``.

    For numeric input this matches ``jnp.cumsum(x, axis=0)``. It is written
    by hand here to show how ``lax.scan`` threads a carry through a sequence.

    Args:
        x: array-like; scanned over its leading axis.

    Returns:
        An array with the same shape and dtype as ``jnp.asarray(x)`` whose
        i-th slice is ``x[0] + ... + x[i]``.
    """
    x = jnp.asarray(x)

    def scan_fn(carry, val):
        new_carry = carry + val
        # The running total is both the next carry and this step's output.
        return new_carry, new_carry

    # The carry must match one slice of x (shape x.shape[1:], dtype x.dtype);
    # a bare Python 0 only fits 1-D input and depends on weak-type promotion.
    init = jnp.zeros(x.shape[1:], dtype=x.dtype)
    _, result = lax.scan(scan_fn, init, x)
    return result


x = jnp.array([1, 2, 3, 4, 5])
print(f"累积和: {my_cumsum(x)}")
使用 jax.core #
python
import jax
# jax.extend.core is the supported public home of Primitive
# (direct use of jax.core internals is being phased out).
from jax.extend import core
from jax.interpreters import ad, mlir

# A new primitive; user code invokes it through `my_p.bind`.
my_p = core.Primitive('my_op')


def my_op(x):
    """Return ``2 * x`` by binding the ``my_op`` primitive."""
    return my_p.bind(x)


def my_op_impl(x):
    """Eager implementation, used when ``x`` is a concrete value."""
    return x * 2


def my_op_abstract_eval(x):
    """Shape/dtype rule used while tracing (``jax.jit``, ``jax.grad``).

    Doubling is elementwise, so the output's abstract value equals the
    input's abstract value.
    """
    return x


def my_op_jvp(primals, tangents):
    """Forward-mode rule: d(2x) = 2 dx."""
    (x,) = primals
    (x_dot,) = tangents
    return my_op(x), 2 * x_dot


my_p.def_impl(my_op_impl)
# Without an abstract-eval rule the primitive cannot be traced at all.
my_p.def_abstract_eval(my_op_abstract_eval)
# Lower to XLA for jit by tracing the reference implementation.
mlir.register_lowering(my_p, mlir.lower_fun(my_op_impl, multiple_results=False))
ad.primitive_jvps[my_p] = my_op_jvp
下一步 #
现在你已经掌握了自定义操作,接下来请学习「性能优化」一节,了解如何优化 JAX 代码!
最后更新:2026-04-04