调试技巧 #
概述 #
本节介绍 JAX 调试的技巧和工具,帮助你快速定位和解决问题。
常见错误 #
形状不匹配 #
python
import jax
import jax.numpy as jnp
def compute(x, w):
return jnp.dot(x, w)
try:
x = jnp.ones((10, 5))
w = jnp.ones((3, 4))
result = compute(x, w)
except Exception as e:
print(f"错误: {e}")
x = jnp.ones((10, 5))
w = jnp.ones((5, 4))
result = compute(x, w)
print(f"正确结果形状: {result.shape}")
类型不匹配 #
python
import jax
import jax.numpy as jnp
def compute(x):
return x + 1
x_int = jnp.array([1, 2, 3])
x_float = jnp.array([1.0, 2.0, 3.0])
print(f"整数结果: {compute(x_int)}")
print(f"浮点结果: {compute(x_float)}")
JIT 追踪错误 #
python
import jax
import jax.numpy as jnp
@jax.jit
def bad_function(x):
if x > 0:
return x
else:
return -x
@jax.jit
def good_function(x):
return jnp.where(x > 0, x, -x)
x = jnp.array([-1, 2, -3, 4])
result = good_function(x)
print(f"结果: {result}")
调试工具 #
jax.debug.print #
python
import jax
import jax.numpy as jnp
@jax.jit
def debug_function(x):
jax.debug.print("x = {}", x)
y = x + 1
jax.debug.print("y = {}", y)
return y
x = jnp.array([1, 2, 3])
result = debug_function(x)
jax.debug.breakpoint #
python
import jax
import jax.numpy as jnp
@jax.jit
def debug_with_breakpoint(x):
y = x + 1
jax.debug.breakpoint()
return y
x = jnp.array([1, 2, 3])
result = debug_with_breakpoint(x)
禁用 JIT #
python
import jax
jax.config.update('jax_disable_jit', True)
@jax.jit
def now_debuggable(x):
print(f"调试: x = {x}")
return x + 1
x = jnp.array([1, 2, 3])
result = now_debuggable(x)
检查中间值 #
make_jaxpr #
python
import jax
import jax.numpy as jnp
def compute(x):
y = x + 1
z = y ** 2
return z
jaxpr = jax.make_jaxpr(compute)(jnp.array([1.0, 2.0, 3.0]))
print(f"JAXPR:\n{jaxpr}")
检查梯度 #
python
import jax
import jax.numpy as jnp
def loss_fn(x):
return jnp.sum(x ** 2)
x = jnp.array([1.0, 2.0, 3.0])
grad = jax.grad(loss_fn)(x)
print(f"梯度: {grad}")
expected_grad = 2 * x
print(f"预期梯度: {expected_grad}")
print(f"匹配: {jnp.allclose(grad, expected_grad)}")
下一步 #
现在你已经掌握了调试技巧,接下来学习 常见问题,了解更多解决方案!
最后更新:2026-04-04