常见问题 #

安装问题 #

Q: GPU 未被识别? #

python
import jax
print(jax.devices())

解决方案:

bash
pip uninstall jax jaxlib
pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Q: CUDA 版本不匹配? #

检查 CUDA 版本:

bash
nvidia-smi
nvcc --version

安装对应版本:

bash
pip install jax[cuda11_local] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

数组操作问题 #

Q: 如何修改数组元素? #

JAX 数组是不可变的:

python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

x = x.at[0].set(10)
print(x)  

Q: 如何处理整数数组? #

JAX 默认使用 float32:

python
import jax.numpy as jnp

x = jnp.array([1, 2, 3], dtype=jnp.int32)
print(f"类型: {x.dtype}")

JIT 问题 #

Q: JIT 编译报错? #

条件语句问题:

python
import jax
import jax.numpy as jnp

@jax.jit
def bad(x):
    if x > 0:  
        return x
    return -x

@jax.jit
def good(x):
    return jnp.where(x > 0, x, -x)

Q: 如何调试 JIT 函数? #

python
import jax

jax.config.update('jax_disable_jit', True)

@jax.jit
def debuggable(x):
    print(f"x = {x}")  
    return x + 1

梯度问题 #

Q: 梯度为 NaN? #

python
import jax
import jax.numpy as jnp

jax.config.update('jax_debug_nans', True)

def loss_fn(x):
    return jnp.sqrt(x)  

x = jnp.array(-1.0)
try:
    grad = jax.grad(loss_fn)(x)
except Exception as e:
    print(f"错误: {e}")

Q: 如何停止梯度? #

python
import jax
import jax.numpy as jnp

def loss_fn(x, y):
    y_stop = jax.lax.stop_gradient(y)
    return jnp.sum((x - y_stop) ** 2)

性能问题 #

Q: 代码运行慢? #

使用 JIT:

python
import jax

@jax.jit
def fast_compute(x):
    return x @ x.T

Q: 内存不足? #

python
import os
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.5'

import jax

随机数问题 #

Q: 随机数不可复现? #

python
import jax

key = jax.random.PRNGKey(42)

key, subkey1 = jax.random.split(key)
key, subkey2 = jax.random.split(key)

x1 = jax.random.normal(subkey1, (3,))
x2 = jax.random.normal(subkey2, (3,))

print(f"x1: {x1}")
print(f"x2: {x2}")

下一步 #

恭喜你完成了 JAX 文档的学习!现在可以开始 线性回归实战,将所学知识应用到实际项目中!

最后更新:2026-04-04