安装与配置 #
安装概述 #
JAX 支持多种安装方式和平台。根据你的硬件环境,可以选择 CPU 版本或 GPU 版本。
安装要求 #
| 要求 | 说明 |
|---|---|
| Python | 3.9 - 3.12 |
| 操作系统 | Linux, macOS, Windows (WSL) |
| pip | 最新版本 |
CPU 版本安装 #
使用 pip 安装 #
bash
pip install -U jax
使用 conda 安装 #
bash
conda install -c conda-forge jax
验证安装 #
python
import jax
import jax.numpy as jnp
print(f"JAX 版本: {jax.__version__}")
print(f"默认后端: {jax.default_backend()}")
x = jnp.array([1.0, 2.0, 3.0])
print(f"数组: {x}")
print(f"设备: {x.devices()}")
GPU 版本安装 #
CUDA 支持 #
JAX 支持 NVIDIA GPU,需要 CUDA 和 cuDNN。
检查 CUDA 版本 #
bash
nvidia-smi
nvcc --version
安装 GPU 版本 #
bash
pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
验证 GPU 安装 #
python
import jax
print(f"设备列表: {jax.devices()}")
print(f"设备数量: {jax.device_count()}")
print(f"默认设备: {jax.default_backend()}")
import jax.numpy as jnp
x = jnp.array([1.0, 2.0, 3.0])
print(f"数组设备: {x.devices()}")
CUDA 版本对应 #
| JAX 版本 | CUDA 版本 | cuDNN 版本 |
|---|---|---|
| 最新 | 12.x | 8.6+ |
| 最新 | 11.8 | 8.6+ |
安装特定 CUDA 版本 #
bash
pip install -U "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
TPU 版本配置 #
Google Cloud TPU #
在 Google Cloud TPU 上使用 JAX:
python
import jax
print(f"TPU 设备: {jax.devices()}")
print(f"TPU 数量: {jax.device_count()}")
TPU VM 设置 #
bash
pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Colab TPU #
在 Google Colab 中使用 TPU:
python
import jax
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
print(f"TPU 设备: {jax.devices()}")
开发版本安装 #
从源码安装 #
bash
git clone https://github.com/google/jax.git
cd jax
pip install -e .
安装最新开发版 #
bash
pip install -U jaxlib==0.4.24+cuda12.cudnn89 \
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
环境配置 #
虚拟环境 #
推荐使用虚拟环境:
bash
python -m venv jax-env
source jax-env/bin/activate # Linux/macOS
jax-env\Scripts\activate # Windows
pip install -U jax
Conda 环境 #
bash
conda create -n jax-env python=3.11
conda activate jax-env
conda install -c conda-forge jax
requirements.txt #
text
jax>=0.4.24
jaxlib>=0.4.24
numpy
设备选择 #
指定默认设备 #
python
import jax
jax.config.update('jax_default_device', jax.devices('gpu')[0])
强制使用 CPU #
python
import jax
with jax.default_device(jax.devices('cpu')[0]):
x = jax.numpy.array([1, 2, 3])
设备上下文 #
python
import jax
devices = jax.devices('gpu')
with jax.default_device(devices[0]):
result = some_computation()
配置选项 #
常用配置 #
python
import jax
jax.config.update('jax_enable_x64', True)
jax.config.update('jax_disable_jit', True)
jax.config.update('jax_debug_nans', True)
环境变量配置 #
bash
export JAX_ENABLE_X64=1
export JAX_DISABLE_JIT=0
export JAX_DEBUG_NANS=1
配置文件 #
python
from jax import config
config.update('jax_enable_x64', True)
config.update('jax_platform_name', 'gpu')
性能配置 #
内存预分配 #
python
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.8'
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'
内存碎片整理 #
python
import jax
import gc
jax.clear_backends()
gc.collect()
验证安装 #
完整验证脚本 #
python
import jax
import jax.numpy as jnp
import time
def verify_installation():
print("=" * 50)
print("JAX 安装验证")
print("=" * 50)
print(f"\n1. 版本信息")
print(f" JAX 版本: {jax.__version__}")
print(f"\n2. 设备信息")
print(f" 默认后端: {jax.default_backend()}")
print(f" 设备列表: {jax.devices()}")
print(f" 设备数量: {jax.device_count()}")
print(f"\n3. 基本操作测试")
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])
print(f" 数组创建: {x}")
print(f" 数组运算: {x + y}")
print(f"\n4. JIT 编译测试")
@jax.jit
def compute(x):
return jnp.sum(x ** 2)
result = compute(x)
print(f" JIT 结果: {result}")
print(f"\n5. 自动微分测试")
grad_fn = jax.grad(lambda x: jnp.sum(x ** 2))
grad_result = grad_fn(x)
print(f" 梯度结果: {grad_result}")
print(f"\n6. 性能测试")
large_x = jnp.ones((10000, 10000))
start = time.time()
result = jnp.dot(large_x, large_x.T)
result.block_until_ready()
end = time.time()
print(f" 矩阵乘法时间: {end - start:.4f} 秒")
print("\n" + "=" * 50)
print("验证完成!")
print("=" * 50)
if __name__ == "__main__":
verify_installation()
常见问题 #
问题 1: GPU 未被识别 #
python
import jax
print(jax.devices())
如果只显示 CPU 设备:
bash
pip uninstall jax jaxlib
pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
问题 2: CUDA 版本不匹配 #
bash
pip install jax[cuda11_local] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
问题 3: 内存不足 #
python
import os
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.5'
问题 4: JIT 编译错误 #
python
import jax
jax.config.update('jax_disable_jit', True)
下一步 #
安装完成后,继续学习 JAX 基础概念,了解 JAX 的核心概念!
最后更新:2026-04-04