JAX 简介 #
什么是 JAX? #
JAX 是由 Google 开发的高性能机器学习研究框架。它的名字来源于关键组件的首字母缩写:JIT(即时编译)、Autograd(自动微分)和 XLA(加速线性代数)。JAX 提供了一个统一的 API,能够在 CPU、GPU 和 TPU 上高效运行。
核心定位 #
text
┌─────────────────────────────────────────────────────────────┐
│ JAX │
├─────────────────────────────────────────────────────────────┤
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ 自动微分 │ │ 自动向量化 │ │ JIT 编译 │ │
│ │ (grad) │ │ (vmap) │ │ (jit) │ │
│ └─────────────┘ └─────────────┘ └─────────────┘ │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ NumPy API │ │ 函数式设计 │ │ 多设备支持 │ │
│ └─────────────┘ └─────────────┘ └─────────────┘ │
└─────────────────────────────────────────────────────────────┘
JAX 的历史 #
发展历程 #
text
2018年 ─── JAX 项目启动
│
│ Google Research 内部项目
│ 结合 Autograd 和 XLA
│
2019年 ─── 开源发布
│
│ GitHub 开源
│ 社区开始关注
│
2020年 ─── 生态系统发展
│
│ Flax、Haiku 等库发布
│ 研究论文广泛使用
│
2021年 ─── 功能增强
│
│ pjit 分布式支持
│ TPU 支持完善
│
至今 ─── 广泛应用
│
│ DeepMind 大量使用
│ 学术研究首选框架
│ 大模型训练
设计目标 #
JAX 的设计目标是:
- 可组合的函数变换:不同的变换可以自由组合
- 高性能:通过 JIT 编译获得接近 C++ 的性能
- 可移植性:同一代码在 CPU、GPU、TPU 上运行
- 研究友好:灵活、表达力强、易于实验
JAX 的核心特点 #
1. 函数变换(Function Transformations) #
JAX 的核心是函数变换,它们可以组合使用:
python
import jax
import jax.numpy as jnp
def f(x):
return jnp.sum(x ** 2)
x = jnp.array([1.0, 2.0, 3.0])
grad_f = jax.grad(f) f_grad = grad_f(x)
vmap_f = jax.vmap(f) batch_result = vmap_f(batch_x)
jit_f = jax.jit(f) jit_result = jit_f(x)
jit_grad_f = jax.jit(jax.grad(f))
2. NumPy 兼容 API #
python
import jax.numpy as jnp
x = jnp.array([1, 2, 3])
y = jnp.zeros((3, 4))
z = jnp.dot(x, y)
3. 自动微分 #
python
def loss(params, x, y):
predict = jnp.dot(x, params)
return jnp.mean((predict - y) ** 2)
grad_loss = jax.grad(loss)
gradients = grad_loss(params, x, y)
4. 自动向量化 #
python
def process_single(x):
return jnp.sum(x ** 2)
process_batch = jax.vmap(process_single)
batch_x = jnp.array([[1, 2], [3, 4], [5, 6]])
results = process_batch(batch_x)
5. JIT 编译 #
python
@jax.jit
def fast_function(x):
return jnp.dot(x, x.T)
result = fast_function(x)
JAX 解决的问题 #
传统框架的痛点 #
text
┌─────────────────────────────────────────────────────────────┐
│ 传统框架的问题 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 1. 性能优化复杂 │
│ - 需要手动优化 │
│ - 编译和执行分离 │
│ │
│ 2. 批处理繁琐 │
│ - 手动编写批处理代码 │
│ - 代码重复 │
│ │
│ 3. 高阶微分困难 │
│ - 二阶导数实现复杂 │
│ - 性能不佳 │
│ │
│ 4. 设备迁移成本高 │
│ - CPU/GPU/TPU 代码不同 │
│ - 需要重写 │
│ │
└─────────────────────────────────────────────────────────────┘
JAX 的解决方案 #
text
┌─────────────────────────────────────────────────────────────┐
│ JAX 的解决方案 │
├─────────────────────────────────────────────────────────────┤
│ │
│ ✅ JIT 自动优化 │
│ - 一行代码启用编译 │
│ - XLA 自动优化 │
│ │
│ ✅ vmap 自动向量化 │
│ - 无需修改原函数 │
│ - 自动批处理 │
│ │
│ ✅ 高阶微分简单 │
│ - grad 可以嵌套 │
│ - grad(grad(f)) 即可 │
│ │
│ ✅ 设备无关代码 │
│ - 同一代码多设备运行 │
│ - 无需修改 │
│ │
└─────────────────────────────────────────────────────────────┘
JAX 的应用场景 #
1. 学术研究 #
python
@jax.jit
def neural_ode(params, x, t):
"""神经 ODE 研究"""
def dynamics(x, t):
return jnp.dot(params, x)
return jax.experimental.ode.odeint(dynamics, x, t)
2. 大规模机器学习 #
python
@jax.pmap
def train_step(params, batch):
"""多设备并行训练"""
grads = jax.grad(loss)(params, batch)
return jax.tree_map(lambda p, g: p - 0.01 * g, params, grads)
3. 科学计算 #
python
def simulate(params, state):
"""物理模拟"""
return jax.lax.fori_loop(0, 1000, step_fn, state)
4. 生成模型 #
python
def diffusion_step(params, x, t):
"""扩散模型"""
return jax.grad(score_function)(params, x, t)
JAX vs 其他框架 #
与 PyTorch 对比 #
| 特性 | JAX | PyTorch |
|---|---|---|
| 编程范式 | 函数式 | 面向对象 |
| 自动微分 | grad 函数 | autograd 模块 |
| 向量化 | vmap 自动 | 手动批处理 |
| 编译优化 | JIT + XLA | TorchScript |
| 调试 | Python 原生 | Python 原生 |
| TPU 支持 | 原生优秀 | 有限支持 |
| 社区规模 | 较小但活跃 | 非常大 |
与 TensorFlow 对比 #
| 特性 | JAX | TensorFlow |
|---|---|---|
| 编程范式 | 函数式 | 图/函数式 |
| 即时执行 | 默认 | tf.function |
| 自动微分 | grad 函数 | GradientTape |
| 分布式 | pmap/pjit | Strategy |
| 生产部署 | 有限 | 完善 |
| 研究友好 | 非常好 | 一般 |
JAX 的优势与局限 #
优势 #
text
✅ 可组合的函数变换
- grad、vmap、jit 可以任意组合
- 代码简洁优雅
✅ 高性能
- JIT 编译优化
- XLA 后端加速
✅ 多设备支持
- CPU、GPU、TPU 统一 API
- 分布式训练简单
✅ 函数式设计
- 纯函数易于推理
- 易于测试和调试
✅ 高阶微分
- 轻松计算高阶导数
- 适合研究场景
局限性 #
text
⚠️ 学习曲线
- 函数式编程思维
- 需要理解变换概念
⚠️ 生态系统
- 比 PyTorch 小
- 部分领域库不完善
⚠️ 生产部署
- 不如 TensorFlow 成熟
- 服务化支持有限
⚠️ 副作用处理
- 需要理解纯函数
- 状态管理需要技巧
JAX 生态系统 #
核心库 #
| 库 | 描述 |
|---|---|
| jax | 核心库 |
| jaxlib | 编译后端 |
| jax.numpy | NumPy 兼容 API |
神经网络库 #
| 库 | 描述 |
|---|---|
| Flax | 灵活的神经网络库 |
| Haiku | Sonnet 风格的神经网络库 |
| Equinox | 函数式神经网络库 |
| Trax | 端到端深度学习 |
工具库 #
| 库 | 描述 |
|---|---|
| Optax | 优化器库 |
| Chex | 测试工具 |
| JAX-MD | 分子动力学 |
| RLax | 强化学习 |
学习路径 #
text
入门阶段
├── JAX 简介(本文)
├── 安装与配置
├── JAX 基础概念
└── NumPy 兼容性
基础阶段
├── 自动微分 (grad)
├── 自动向量化 (vmap)
├── JIT 编译
└── 函数变换组合
进阶阶段
├── 数组操作
├── 线性代数
├── 随机数生成
└── 控制流
高级阶段
├── 构建神经网络
├── 状态管理
├── 训练循环
└── 模型保存与加载
下一步 #
现在你已经了解了 JAX 的基本概念,接下来学习 安装与配置,开始你的 JAX 实践之旅!
最后更新:2026-04-04