JAX 简介 #

什么是 JAX? #

JAX 是由 Google 开发的高性能机器学习研究框架。它的名字来源于关键组件的首字母缩写:JIT(即时编译)、Autograd(自动微分)和 XLA(加速线性代数)。JAX 提供了一个统一的 API,能够在 CPU、GPU 和 TPU 上高效运行。

核心定位 #

text
┌─────────────────────────────────────────────────────────────┐
│                         JAX                                  │
├─────────────────────────────────────────────────────────────┤
│  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐         │
│  │  自动微分    │  │  自动向量化  │  │  JIT 编译   │         │
│  │   (grad)    │  │   (vmap)    │  │   (jit)     │         │
│  └─────────────┘  └─────────────┘  └─────────────┘         │
│  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐         │
│  │  NumPy API  │  │  函数式设计  │  │  多设备支持  │         │
│  └─────────────┘  └─────────────┘  └─────────────┘         │
└─────────────────────────────────────────────────────────────┘

JAX 的历史 #

发展历程 #

text
2018年 ─── JAX 项目启动
    │
    │      Google Research 内部项目
    │      结合 Autograd 和 XLA
    │
2019年 ─── 开源发布
    │
    │      GitHub 开源
    │      社区开始关注
    │
2020年 ─── 生态系统发展
    │
    │      Flax、Haiku 等库发布
    │      研究论文广泛使用
    │
2021年 ─── 功能增强
    │
    │      pjit 分布式支持
    │      TPU 支持完善
    │
至今   ─── 广泛应用
    │
    │      DeepMind 大量使用
    │      学术研究首选框架
    │      大模型训练

设计目标 #

JAX 的设计目标是:

  1. 可组合的函数变换:不同的变换可以自由组合
  2. 高性能:通过 JIT 编译获得接近 C++ 的性能
  3. 可移植性:同一代码在 CPU、GPU、TPU 上运行
  4. 研究友好:灵活、表达力强、易于实验

JAX 的核心特点 #

1. 函数变换(Function Transformations) #

JAX 的核心是函数变换,它们可以组合使用:

python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)

x = jnp.array([1.0, 2.0, 3.0])

grad_f = jax.grad(f)           f_grad = grad_f(x)             
vmap_f = jax.vmap(f)           batch_result = vmap_f(batch_x) 
jit_f = jax.jit(f)             jit_result = jit_f(x)          

jit_grad_f = jax.jit(jax.grad(f))

2. NumPy 兼容 API #

python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])
y = jnp.zeros((3, 4))
z = jnp.dot(x, y)

3. 自动微分 #

python
def loss(params, x, y):
    predict = jnp.dot(x, params)
    return jnp.mean((predict - y) ** 2)

grad_loss = jax.grad(loss)
gradients = grad_loss(params, x, y)

4. 自动向量化 #

python
def process_single(x):
    return jnp.sum(x ** 2)

process_batch = jax.vmap(process_single)
batch_x = jnp.array([[1, 2], [3, 4], [5, 6]])
results = process_batch(batch_x)  

5. JIT 编译 #

python
@jax.jit
def fast_function(x):
    return jnp.dot(x, x.T)

result = fast_function(x)  

JAX 解决的问题 #

传统框架的痛点 #

text
┌─────────────────────────────────────────────────────────────┐
│                    传统框架的问题                            │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  1. 性能优化复杂                                             │
│     - 需要手动优化                                           │
│     - 编译和执行分离                                         │
│                                                             │
│  2. 批处理繁琐                                               │
│     - 手动编写批处理代码                                     │
│     - 代码重复                                               │
│                                                             │
│  3. 高阶微分困难                                             │
│     - 二阶导数实现复杂                                       │
│     - 性能不佳                                               │
│                                                             │
│  4. 设备迁移成本高                                           │
│     - CPU/GPU/TPU 代码不同                                   │
│     - 需要重写                                              │
│                                                             │
└─────────────────────────────────────────────────────────────┘

JAX 的解决方案 #

text
┌─────────────────────────────────────────────────────────────┐
│                    JAX 的解决方案                            │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  ✅ JIT 自动优化                                             │
│     - 一行代码启用编译                                       │
│     - XLA 自动优化                                          │
│                                                             │
│  ✅ vmap 自动向量化                                          │
│     - 无需修改原函数                                         │
│     - 自动批处理                                            │
│                                                             │
│  ✅ 高阶微分简单                                             │
│     - grad 可以嵌套                                         │
│     - grad(grad(f)) 即可                                    │
│                                                             │
│  ✅ 设备无关代码                                             │
│     - 同一代码多设备运行                                     │
│     - 无需修改                                              │
│                                                             │
└─────────────────────────────────────────────────────────────┘

JAX 的应用场景 #

1. 学术研究 #

python
@jax.jit
def neural_ode(params, x, t):
    """神经 ODE 研究"""
    def dynamics(x, t):
        return jnp.dot(params, x)
    return jax.experimental.ode.odeint(dynamics, x, t)

2. 大规模机器学习 #

python
@jax.pmap
def train_step(params, batch):
    """多设备并行训练"""
    grads = jax.grad(loss)(params, batch)
    return jax.tree_map(lambda p, g: p - 0.01 * g, params, grads)

3. 科学计算 #

python
def simulate(params, state):
    """物理模拟"""
    return jax.lax.fori_loop(0, 1000, step_fn, state)

4. 生成模型 #

python
def diffusion_step(params, x, t):
    """扩散模型"""
    return jax.grad(score_function)(params, x, t)

JAX vs 其他框架 #

与 PyTorch 对比 #

特性 JAX PyTorch
编程范式 函数式 面向对象
自动微分 grad 函数 autograd 模块
向量化 vmap 自动 手动批处理
编译优化 JIT + XLA TorchScript
调试 Python 原生 Python 原生
TPU 支持 原生优秀 有限支持
社区规模 较小但活跃 非常大

与 TensorFlow 对比 #

特性 JAX TensorFlow
编程范式 函数式 图/函数式
即时执行 默认 tf.function
自动微分 grad 函数 GradientTape
分布式 pmap/pjit Strategy
生产部署 有限 完善
研究友好 非常好 一般

JAX 的优势与局限 #

优势 #

text
✅ 可组合的函数变换
   - grad、vmap、jit 可以任意组合
   - 代码简洁优雅

✅ 高性能
   - JIT 编译优化
   - XLA 后端加速

✅ 多设备支持
   - CPU、GPU、TPU 统一 API
   - 分布式训练简单

✅ 函数式设计
   - 纯函数易于推理
   - 易于测试和调试

✅ 高阶微分
   - 轻松计算高阶导数
   - 适合研究场景

局限性 #

text
⚠️ 学习曲线
   - 函数式编程思维
   - 需要理解变换概念

⚠️ 生态系统
   - 比 PyTorch 小
   - 部分领域库不完善

⚠️ 生产部署
   - 不如 TensorFlow 成熟
   - 服务化支持有限

⚠️ 副作用处理
   - 需要理解纯函数
   - 状态管理需要技巧

JAX 生态系统 #

核心库 #

描述
jax 核心库
jaxlib 编译后端
jax.numpy NumPy 兼容 API

神经网络库 #

描述
Flax 灵活的神经网络库
Haiku Sonnet 风格的神经网络库
Equinox 函数式神经网络库
Trax 端到端深度学习

工具库 #

描述
Optax 优化器库
Chex 测试工具
JAX-MD 分子动力学
RLax 强化学习

学习路径 #

text
入门阶段
├── JAX 简介(本文)
├── 安装与配置
├── JAX 基础概念
└── NumPy 兼容性

基础阶段
├── 自动微分 (grad)
├── 自动向量化 (vmap)
├── JIT 编译
└── 函数变换组合

进阶阶段
├── 数组操作
├── 线性代数
├── 随机数生成
└── 控制流

高级阶段
├── 构建神经网络
├── 状态管理
├── 训练循环
└── 模型保存与加载

下一步 #

现在你已经了解了 JAX 的基本概念,接下来学习 安装与配置,开始你的 JAX 实践之旅!

最后更新:2026-04-04