自定义操作 #

概述 #

JAX 允许用户定义自定义操作：既可以为函数指定自定义微分规则（`custom_jvp` / `custom_vjp`），也可以用 `lax` 组合出新的操作，还可以定义新的底层原语（Primitive），以便与原生代码集成。

自定义微分 #

custom_jvp #

python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def my_function(x):
    """Return ``x`` squared; its derivative comes from the custom JVP rule below."""
    return x ** 2

@my_function.defjvp
def my_function_jvp(primals, tangents):
    """Forward-mode rule for ``my_function``: d(x**2) = 2 * x * dx.

    Args:
        primals: one-element tuple ``(x,)`` holding the primal input.
        tangents: one-element tuple ``(x_dot,)`` holding the input tangent.

    Returns:
        The pair ``(primal_out, tangent_out)``.
    """
    (x,), (x_dot,) = primals, tangents
    return my_function(x), 2 * x * x_dot

grad_fn = jax.grad(my_function)
print(f"f'(3) = {grad_fn(3.0)}")

custom_vjp #

python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def my_function(x):
    """Square ``x``; reverse-mode derivative is given by ``f_fwd`` / ``f_bwd``."""
    return x ** 2

def f_fwd(x):
    """Forward pass: return the primal output and keep ``x`` as the residual."""
    residual = x
    return my_function(x), residual

def f_bwd(x, g):
    """Backward pass: turn the output cotangent ``g`` into the input cotangent.

    ``x`` is the residual saved by ``f_fwd`` (the original input). The result
    is a 1-tuple because ``my_function`` takes a single argument.
    """
    cotangent_x = 2 * x * g
    return (cotangent_x,)

my_function.defvjp(f_fwd, f_bwd)

grad_fn = jax.grad(my_function)
print(f"f'(3) = {grad_fn(3.0)}")

自定义梯度 #

python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def clip_gradient(x):
    """Identity in the forward pass; the backward pass clips the gradient to [-1, 1]."""
    return x

def clip_fwd(x):
    """Forward pass: pass ``x`` through unchanged; no residual is needed."""
    no_residual = None
    return x, no_residual

def clip_bwd(_, g):
    """Backward pass: clip the cotangent ``g`` elementwise to [-1, 1]."""
    clipped = jnp.clip(g, -1, 1)
    return (clipped,)

clip_gradient.defvjp(clip_fwd, clip_bwd)

def loss_fn(x):
    """Square of ``x``, routed through ``clip_gradient`` so dL/dx gets clipped."""
    return clip_gradient(x) ** 2

grad_fn = jax.grad(loss_fn)
print(f"梯度: {grad_fn(10.0)}")

纯 Python 自定义操作 #

使用 lax #

python
import jax
import jax.numpy as jnp
from jax import lax

def my_cumsum(x):
    """Return the inclusive running sum of ``x`` along its leading axis.

    Built from ``lax.scan`` to show how a custom operation can be composed
    from lax primitives.

    Args:
        x: array-like (anything ``jnp.asarray`` accepts). Arrays of any rank
            are supported: the scan runs over axis 0 and each step adds one
            slice ``x[i]`` to the running total.

    Returns:
        An array with the same shape as ``x`` whose element ``i`` is
        ``x[0] + ... + x[i]``.
    """
    # Lists/tuples would otherwise be scanned as a pytree of scalars and fail.
    x = jnp.asarray(x)

    def scan_fn(carry, val):
        new_carry = carry + val
        # The updated total is both the next carry and this step's output.
        return new_carry, new_carry

    # lax.scan requires the carry to keep the same shape/dtype every step.
    # A bare Python 0 only satisfies that for 1-D input; a zero slice shaped
    # like x[i] works for any rank. result_type(x, 0) gives the dtype that
    # ``0 + x`` would have (e.g. bool input is summed as integers).
    init = jnp.zeros(x.shape[1:], dtype=jnp.result_type(x, 0))
    _, result = lax.scan(scan_fn, init, x)
    return result

x = jnp.array([1, 2, 3, 4, 5])
print(f"累积和: {my_cumsum(x)}")

使用 jax.core #

python
import jax
from jax import core

# Prefer the public ``jax.extend.core`` home of ``Primitive``: accessing
# ``jax.core.Primitive`` is deprecated in recent JAX releases. Fall back to
# ``jax.core`` on older versions that predate ``jax.extend.core``.
try:
    from jax.extend.core import Primitive as _Primitive
except ImportError:
    _Primitive = core.Primitive

from jax.interpreters import ad, mlir

my_p = _Primitive('my_op')

def my_op(x):
    """Apply the ``my_op`` primitive (elementwise ``x * 2``) to ``x``."""
    return my_p.bind(x)

def my_op_impl(x):
    """Eager implementation, used when ``x`` holds concrete values."""
    return x * 2

my_p.def_impl(my_op_impl)

# Abstract evaluation is required whenever my_op is traced (e.g. jax.jit).
# Doubling is elementwise, so the output's abstract value (shape/dtype) is
# identical to the input's and can be returned as-is.
my_p.def_abstract_eval(lambda x_aval: x_aval)

# Lowering for compiled execution: reuse the impl, which is written in
# ordinary jax.numpy ops that JAX already knows how to lower.
mlir.register_lowering(my_p, mlir.lower_fun(my_op_impl, multiple_results=False))

def my_op_jvp(primals, tangents):
    """Forward-mode rule: d(2 * x) = 2 * dx."""
    x, = primals
    x_dot, = tangents
    return my_op(x), 2 * x_dot

ad.primitive_jvps[my_p] = my_op_jvp

下一步 #

现在你已经掌握了自定义操作，接下来学习「性能优化」，了解如何优化 JAX 代码！

最后更新：2026-04-04