安装与配置 #

安装概述 #

JAX 支持多种安装方式和平台。根据你的硬件环境,可以选择 CPU 版本或 GPU 版本。

安装要求 #

要求 说明
Python 3.9 - 3.12
操作系统 Linux, macOS, Windows (WSL)
pip 最新版本

CPU 版本安装 #

使用 pip 安装 #

bash
pip install -U jax

使用 conda 安装 #

bash
conda install -c conda-forge jax

验证安装 #

python
import jax
import jax.numpy as jnp

print(f"JAX 版本: {jax.__version__}")
print(f"默认后端: {jax.default_backend()}")

x = jnp.array([1.0, 2.0, 3.0])
print(f"数组: {x}")
print(f"设备: {x.devices()}")

GPU 版本安装 #

CUDA 支持 #

JAX 支持 NVIDIA GPU,需要 CUDA 和 cuDNN。

检查 CUDA 版本 #

bash
nvidia-smi
nvcc --version

安装 GPU 版本 #

bash
pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

验证 GPU 安装 #

python
import jax

print(f"设备列表: {jax.devices()}")
print(f"设备数量: {jax.device_count()}")
print(f"默认设备: {jax.default_backend()}")

import jax.numpy as jnp
x = jnp.array([1.0, 2.0, 3.0])
print(f"数组设备: {x.devices()}")

CUDA 版本对应 #

JAX 版本 CUDA 版本 cuDNN 版本
最新 12.x 8.6+
最新 11.8 8.6+

安装特定 CUDA 版本 #

bash
pip install -U "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

TPU 版本配置 #

Google Cloud TPU #

在 Google Cloud TPU 上使用 JAX:

python
import jax

print(f"TPU 设备: {jax.devices()}")
print(f"TPU 数量: {jax.device_count()}")

TPU VM 设置 #

bash
pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Colab TPU #

在 Google Colab 中使用 TPU:

python
import jax
import jax.tools.colab_tpu

jax.tools.colab_tpu.setup_tpu()

print(f"TPU 设备: {jax.devices()}")

开发版本安装 #

从源码安装 #

bash
git clone https://github.com/google/jax.git
cd jax
pip install -e .

安装最新开发版 #

bash
pip install -U jaxlib==0.4.24+cuda12.cudnn89 \
  -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

环境配置 #

虚拟环境 #

推荐使用虚拟环境:

bash
python -m venv jax-env
source jax-env/bin/activate  # Linux/macOS
jax-env\Scripts\activate     # Windows

pip install -U jax

Conda 环境 #

bash
conda create -n jax-env python=3.11
conda activate jax-env
conda install -c conda-forge jax

requirements.txt #

text
jax>=0.4.24
jaxlib>=0.4.24
numpy

设备选择 #

指定默认设备 #

python
import jax

jax.config.update('jax_default_device', jax.devices('gpu')[0])

强制使用 CPU #

python
import jax

with jax.default_device(jax.devices('cpu')[0]):
    x = jax.numpy.array([1, 2, 3])

设备上下文 #

python
import jax

devices = jax.devices('gpu')
with jax.default_device(devices[0]):
    result = some_computation()

配置选项 #

常用配置 #

python
import jax

jax.config.update('jax_enable_x64', True)  
jax.config.update('jax_disable_jit', True)  
jax.config.update('jax_debug_nans', True)  

环境变量配置 #

bash
export JAX_ENABLE_X64=1
export JAX_DISABLE_JIT=0
export JAX_DEBUG_NANS=1

配置文件 #

python
from jax import config

config.update('jax_enable_x64', True)
config.update('jax_platform_name', 'gpu')

性能配置 #

内存预分配 #

python
import os

os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.8'
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'

内存碎片整理 #

python
import jax
import gc

jax.clear_backends()
gc.collect()

验证安装 #

完整验证脚本 #

python
import jax
import jax.numpy as jnp
import time

def verify_installation():
    print("=" * 50)
    print("JAX 安装验证")
    print("=" * 50)
    
    print(f"\n1. 版本信息")
    print(f"   JAX 版本: {jax.__version__}")
    
    print(f"\n2. 设备信息")
    print(f"   默认后端: {jax.default_backend()}")
    print(f"   设备列表: {jax.devices()}")
    print(f"   设备数量: {jax.device_count()}")
    
    print(f"\n3. 基本操作测试")
    x = jnp.array([1.0, 2.0, 3.0])
    y = jnp.array([4.0, 5.0, 6.0])
    print(f"   数组创建: {x}")
    print(f"   数组运算: {x + y}")
    
    print(f"\n4. JIT 编译测试")
    @jax.jit
    def compute(x):
        return jnp.sum(x ** 2)
    
    result = compute(x)
    print(f"   JIT 结果: {result}")
    
    print(f"\n5. 自动微分测试")
    grad_fn = jax.grad(lambda x: jnp.sum(x ** 2))
    grad_result = grad_fn(x)
    print(f"   梯度结果: {grad_result}")
    
    print(f"\n6. 性能测试")
    large_x = jnp.ones((10000, 10000))
    
    start = time.time()
    result = jnp.dot(large_x, large_x.T)
    result.block_until_ready()
    end = time.time()
    
    print(f"   矩阵乘法时间: {end - start:.4f} 秒")
    
    print("\n" + "=" * 50)
    print("验证完成!")
    print("=" * 50)

if __name__ == "__main__":
    verify_installation()

常见问题 #

问题 1: GPU 未被识别 #

python
import jax
print(jax.devices())

如果只显示 CPU 设备:

bash
pip uninstall jax jaxlib
pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

问题 2: CUDA 版本不匹配 #

bash
pip install jax[cuda11_local] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

问题 3: 内存不足 #

python
import os
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.5'

问题 4: JIT 编译错误 #

python
import jax
jax.config.update('jax_disable_jit', True)  

下一步 #

安装完成后,继续学习 JAX 基础概念,了解 JAX 的核心概念!

最后更新:2026-04-04