常见问题 #
安装问题 #
Q: GPU 未被识别? #
python
import jax
print(jax.devices())
解决方案:
bash
pip uninstall jax jaxlib
pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Q: CUDA 版本不匹配? #
检查 CUDA 版本:
bash
nvidia-smi
nvcc --version
安装对应版本:
bash
pip install jax[cuda11_local] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
数组操作问题 #
Q: 如何修改数组元素? #
JAX 数组是不可变的:
python
import jax.numpy as jnp
x = jnp.array([1, 2, 3])
x = x.at[0].set(10)
print(x)
Q: 如何处理整数数组? #
JAX 默认使用 float32:
python
import jax.numpy as jnp
x = jnp.array([1, 2, 3], dtype=jnp.int32)
print(f"类型: {x.dtype}")
JIT 问题 #
Q: JIT 编译报错? #
条件语句问题:
python
import jax
import jax.numpy as jnp
@jax.jit
def bad(x):
if x > 0:
return x
return -x
@jax.jit
def good(x):
return jnp.where(x > 0, x, -x)
Q: 如何调试 JIT 函数? #
python
import jax
jax.config.update('jax_disable_jit', True)
@jax.jit
def debuggable(x):
print(f"x = {x}")
return x + 1
梯度问题 #
Q: 梯度为 NaN? #
python
import jax
import jax.numpy as jnp
jax.config.update('jax_debug_nans', True)
def loss_fn(x):
return jnp.sqrt(x)
x = jnp.array(-1.0)
try:
grad = jax.grad(loss_fn)(x)
except Exception as e:
print(f"错误: {e}")
Q: 如何停止梯度? #
python
import jax
import jax.numpy as jnp
def loss_fn(x, y):
y_stop = jax.lax.stop_gradient(y)
return jnp.sum((x - y_stop) ** 2)
性能问题 #
Q: 代码运行慢? #
使用 JIT:
python
import jax
@jax.jit
def fast_compute(x):
return x @ x.T
Q: 内存不足? #
python
import os
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.5'
import jax
随机数问题 #
Q: 随机数不可复现? #
python
import jax
key = jax.random.PRNGKey(42)
key, subkey1 = jax.random.split(key)
key, subkey2 = jax.random.split(key)
x1 = jax.random.normal(subkey1, (3,))
x2 = jax.random.normal(subkey2, (3,))
print(f"x1: {x1}")
print(f"x2: {x2}")
下一步 #
恭喜你完成了 JAX 文档的学习!现在可以开始 线性回归实战,将所学知识应用到实际项目中!
最后更新:2026-04-04