调试技巧 #

概述 #

本节介绍 JAX 调试的技巧和工具,帮助你快速定位和解决问题。

常见错误 #

形状不匹配 #

python
import jax
import jax.numpy as jnp

def compute(x, w):
    return jnp.dot(x, w)

try:
    x = jnp.ones((10, 5))
    w = jnp.ones((3, 4))
    result = compute(x, w)
except Exception as e:
    print(f"错误: {e}")

x = jnp.ones((10, 5))
w = jnp.ones((5, 4))
result = compute(x, w)
print(f"正确结果形状: {result.shape}")

类型不匹配 #

python
import jax
import jax.numpy as jnp

def compute(x):
    return x + 1

x_int = jnp.array([1, 2, 3])
x_float = jnp.array([1.0, 2.0, 3.0])

print(f"整数结果: {compute(x_int)}")
print(f"浮点结果: {compute(x_float)}")

JIT 追踪错误 #

python
import jax
import jax.numpy as jnp

@jax.jit
def bad_function(x):
    if x > 0:  
        return x
    else:
        return -x

@jax.jit
def good_function(x):
    return jnp.where(x > 0, x, -x)

x = jnp.array([-1, 2, -3, 4])
result = good_function(x)
print(f"结果: {result}")

调试工具 #

jax.debug.print #

python
import jax
import jax.numpy as jnp

@jax.jit
def debug_function(x):
    jax.debug.print("x = {}", x)
    y = x + 1
    jax.debug.print("y = {}", y)
    return y

x = jnp.array([1, 2, 3])
result = debug_function(x)

jax.debug.breakpoint #

python
import jax
import jax.numpy as jnp

@jax.jit
def debug_with_breakpoint(x):
    y = x + 1
    jax.debug.breakpoint()  
    return y

x = jnp.array([1, 2, 3])
result = debug_with_breakpoint(x)

禁用 JIT #

python
import jax

jax.config.update('jax_disable_jit', True)

@jax.jit
def now_debuggable(x):
    print(f"调试: x = {x}")
    return x + 1

x = jnp.array([1, 2, 3])
result = now_debuggable(x)

检查中间值 #

make_jaxpr #

python
import jax
import jax.numpy as jnp

def compute(x):
    y = x + 1
    z = y ** 2
    return z

jaxpr = jax.make_jaxpr(compute)(jnp.array([1.0, 2.0, 3.0]))
print(f"JAXPR:\n{jaxpr}")

检查梯度 #

python
import jax
import jax.numpy as jnp

def loss_fn(x):
    return jnp.sum(x ** 2)

x = jnp.array([1.0, 2.0, 3.0])
grad = jax.grad(loss_fn)(x)
print(f"梯度: {grad}")

expected_grad = 2 * x
print(f"预期梯度: {expected_grad}")
print(f"匹配: {jnp.allclose(grad, expected_grad)}")

下一步 #

现在你已经掌握了调试技巧,接下来学习 常见问题,了解更多解决方案!

最后更新:2026-04-04