JAX #
什么是 JAX? #
JAX 是由 Google 开发的高性能机器学习研究框架,它提供了自动微分、自动向量化、JIT 编译等功能,能够在 CPU、GPU 和 TPU 上高效运行。JAX 的设计理念是"可组合的函数变换",让研究者能够灵活地组合各种变换来构建复杂的机器学习系统。
JAX 的核心优势 #
| 优势 | 说明 |
|---|---|
| 自动微分 | 支持 grad、hessian、jacobian 等多种微分方式 |
| 自动向量化 | vmap 自动批处理,无需手动编写循环 |
| JIT 编译 | XLA 编译优化,获得接近 C++ 的性能 |
| 函数式设计 | 纯函数、不可变状态,易于推理和测试 |
| 多设备支持 | 统一的 API 支持 CPU、GPU、TPU |
| NumPy 兼容 | 几乎完全兼容 NumPy API,学习成本低 |
文档结构 #
本指南按以下结构组织,适合初学者按顺序学习:
1. 基础入门 #
| 主题 | 描述 | 文档链接 |
|---|---|---|
| JAX 简介 | JAX 的发展历史、核心特点、应用场景 | intro.md |
| 安装与配置 | 各平台安装、GPU/TPU 配置、环境验证 | installation.md |
| JAX 基础概念 | 数组、设备、计算图、函数变换 | basics.md |
| NumPy 兼容性 | jax.numpy 与 numpy 的差异与迁移 | numpy.md |
2. 核心功能 #
| 主题 | 描述 | 文档链接 |
|---|---|---|
| 自动微分 (grad) | 梯度计算、高阶导数、停止梯度 | grad.md |
| 自动向量化 (vmap) | 批处理、嵌套 vmap、性能优化 | vmap.md |
| JIT 编译 | 编译原理、缓存机制、调试技巧 | jit.md |
| 函数变换组合 | grad + vmap + jit 组合使用 | transformations.md |
3. 数值计算 #
| 主题 | 描述 | 文档链接 |
|---|---|---|
| 数组操作 | 创建、索引、切片、变形 | arrays.md |
| 线性代数 | 矩阵运算、分解、求解 | linalg.md |
| 随机数生成 | PRNGKey 设计、随机数使用 | random.md |
| 控制流 | lax.cond、lax.while_loop、lax.scan | control-flow.md |
4. 神经网络 #
| 主题 | 描述 | 文档链接 |
|---|---|---|
| 构建神经网络 | 线性层、激活函数、模型定义 | nn-basics.md |
| 状态管理 | 参数管理、BatchNorm、Dropout | state-management.md |
| 训练循环 | 损失函数、优化器、训练步骤 | training.md |
| 模型保存与加载 | 序列化、检查点、模型恢复 | serialization.md |
5. 分布式计算 #
| 主题 | 描述 | 文档链接 |
|---|---|---|
| 多设备计算 | 设备管理、数据放置、设备选择 | multi-device.md |
| 数据并行 | pmap、pjit、分布式训练 | data-parallel.md |
| 模型并行 | 大模型切分、流水线并行 | model-parallel.md |
| TPU 加速 | TPU 使用、性能调优 | tpu.md |
6. 高级特性 #
| 主题 | 描述 | 文档链接 |
|---|---|---|
| 自定义操作 | 自定义 C++/CUDA 算子 | custom-ops.md |
| 性能优化 | 性能分析、内存优化、编译优化 | performance.md |
| 调试技巧 | 常见错误、调试工具、最佳实践 | debugging.md |
| 常见问题 | FAQ、陷阱、解决方案 | faq.md |
7. 实战案例 #
| 主题 | 描述 | 文档链接 |
|---|---|---|
| 线性回归 | 从零实现线性回归 | linear-regression.md |
| 图像分类 | CNN 图像分类实战 | image-classification.md |
| 文本生成 | Transformer 文本生成 | text-generation.md |
| 强化学习 | DQN/PPO 实现 | rl.md |
学习路线 #
text
入门阶段
├── JAX 简介
├── 安装与配置
├── JAX 基础概念
└── NumPy 兼容性
基础阶段
├── 自动微分 (grad)
├── 自动向量化 (vmap)
├── JIT 编译
└── 函数变换组合
进阶阶段
├── 数组操作
├── 线性代数
├── 随机数生成
└── 控制流
高级阶段
├── 构建神经网络
├── 状态管理
├── 训练循环
└── 模型保存与加载
专家阶段
├── 分布式计算
├── 性能优化
├── 自定义操作
└── 实战项目
JAX 核心概念 #
函数变换 #
JAX 的核心是函数变换(Function Transformations),它们可以组合使用:
text
┌─────────────────────────────────────────────────────────────┐
│ JAX 函数变换 │
├─────────────────────────────────────────────────────────────┤
│ │
│ grad ─── 自动微分,计算梯度 │
│ vmap ─── 自动向量化,批处理 │
│ jit ─── 即时编译,性能优化 │
│ pmap ─── 并行映射,多设备 │
│ │
│ 组合示例: │
│ │
│ jit(grad(loss)) ─── 编译优化梯度计算 │
│ vmap(grad(loss)) ─── 批量计算梯度 │
│ jit(vmap(grad(loss))) ─── 编译优化的批量梯度 │
│ │
└─────────────────────────────────────────────────────────────┘
与其他框架对比 #
| 特性 | JAX | PyTorch | TensorFlow |
|---|---|---|---|
| 自动微分 | 函数式 | 对象式 | 图/函数式 |
| 向量化 | vmap 自动 | 手动批处理 | 手动/自动 |
| 编译优化 | JIT + XLA | TorchScript | XLA |
| 分布式 | pmap/pjit | DDP/FSDP | Strategy |
| TPU 支持 | 原生 | 有限 | 原生 |
| 调试体验 | Python 原生 | Python 原生 | 较复杂 |
适用人群 #
| 人群 | 建议 |
|---|---|
| 初学者 | 从基础入门开始,了解函数变换概念 |
| 研究者 | 重点学习自动微分、向量化、JIT |
| 工程师 | 重点学习性能优化、分布式训练 |
| TPU 用户 | 重点学习 TPU 加速、大规模训练 |
学习建议 #
- 理解函数式编程:JAX 采用函数式设计,理解纯函数和不可变性很重要
- 掌握 NumPy:JAX 的 API 与 NumPy 高度兼容,熟悉 NumPy 有助于快速上手
- 动手实践:边学边做,通过实际代码理解概念
- 组合变换:学会组合 grad、vmap、jit 是掌握 JAX 的关键
- 关注性能:理解 JIT 编译原理,写出高性能代码
生态系统 #
JAX 拥有丰富的生态系统:
| 库 | 描述 |
|---|---|
| Flax | 神经网络库,提供灵活的模型定义 |
| Optax | 优化器库,包含各种优化算法 |
| Haiku | Sonnet 风格的神经网络库 |
| Equinox | 函数式神经网络库 |
| JAX-MD | 分子动力学模拟 |
| RLax | 强化学习算法 |
学习资源 #
- 官方文档:https://jax.readthedocs.io
- GitHub 仓库:https://github.com/google/jax
- JAX 论文:https://arxiv.org/abs/1912.10054
- JAX 社区:https://github.com/google/jax/discussions
开始学习 #
准备好了吗?让我们从 JAX 简介 开始你的 JAX 学习之旅!
最后更新:2026-04-04