NumPy 兼容性 #
概述 #
JAX 提供了 jax.numpy 模块,它与 NumPy API 高度兼容,使得从 NumPy 迁移到 JAX 非常简单。
基本使用 #
python
import jax.numpy as jnp
import numpy as np
x_jax = jnp.array([1, 2, 3])
x_np = np.array([1, 2, 3])
print(jnp.sum(x_jax))
print(np.sum(x_np))
兼容性概览 #
text
┌─────────────────────────────────────────────────────────────┐
│ JAX vs NumPy 兼容性 │
├─────────────────────────────────────────────────────────────┤
│ │
│ ✅ 高度兼容 │
│ - 大多数 NumPy API 可直接使用 │
│ - 广播机制完全兼容 │
│ - 数学运算一致 │
│ │
│ ⚠️ 关键差异 │
│ - 数组不可变 │
│ - 默认 float32 │
│ - 随机数 API 不同 │
│ - 部分函数未实现 │
│ │
└─────────────────────────────────────────────────────────────┘
主要差异 #
1. 不可变性 #
NumPy 允许就地修改数组,JAX 不允许:
python
import numpy as np
import jax.numpy as jnp
x_np = np.array([1, 2, 3])
x_np[0] = 10
print(x_np)
x_jax = jnp.array([1, 2, 3])
x_jax = x_jax.at[0].set(10)
print(x_jax)
2. 索引更新 #
python
import jax.numpy as jnp
x = jnp.array([1, 2, 3, 4, 5])
x = x.at[0].set(10)
x = x.at[1:3].set(0)
x = x.at[::2].add(1)
x = x.at[x > 3].set(0)
3. 数据类型 #
python
import numpy as np
import jax.numpy as jnp
x_np = np.array([1.0])
print(x_np.dtype)
x_jax = jnp.array([1.0])
print(x_jax.dtype)
x_jax_64 = jnp.array([1.0], dtype=jnp.float64)
print(x_jax_64.dtype)
4. 随机数 #
python
import numpy as np
import jax
import jax.numpy as jnp
x_np = np.random.randn(3, 3)
key = jax.random.PRNGKey(42)
x_jax = jax.random.normal(key, (3, 3))
5. 外部副作用 #
python
import jax
import jax.numpy as jnp
@jax.jit
def bad_function(x):
print(f"x = {x}")
return x + 1
@jax.jit
def good_function(x):
return x + 1
API 对照表 #
数组创建 #
| NumPy | JAX | 说明 |
|---|---|---|
np.array() |
jnp.array() |
创建数组 |
np.zeros() |
jnp.zeros() |
全零数组 |
np.ones() |
jnp.ones() |
全一数组 |
np.arange() |
jnp.arange() |
范围数组 |
np.linspace() |
jnp.linspace() |
线性空间 |
np.eye() |
jnp.eye() |
单位矩阵 |
np.random.rand() |
jax.random.uniform() |
随机数组 |
数组操作 #
| NumPy | JAX | 说明 |
|---|---|---|
np.reshape() |
jnp.reshape() |
变形 |
np.transpose() |
jnp.transpose() |
转置 |
np.concatenate() |
jnp.concatenate() |
拼接 |
np.split() |
jnp.split() |
分割 |
np.stack() |
jnp.stack() |
堆叠 |
数学运算 #
| NumPy | JAX | 说明 |
|---|---|---|
np.add() |
jnp.add() |
加法 |
np.subtract() |
jnp.subtract() |
减法 |
np.multiply() |
jnp.multiply() |
乘法 |
np.divide() |
jnp.divide() |
除法 |
np.dot() |
jnp.dot() |
点积 |
np.matmul() |
jnp.matmul() |
矩阵乘法 |
统计函数 #
| NumPy | JAX | 说明 |
|---|---|---|
np.sum() |
jnp.sum() |
求和 |
np.mean() |
jnp.mean() |
均值 |
np.std() |
jnp.std() |
标准差 |
np.var() |
jnp.var() |
方差 |
np.max() |
jnp.max() |
最大值 |
np.min() |
jnp.min() |
最小值 |
线性代数 #
| NumPy | JAX | 说明 |
|---|---|---|
np.linalg.inv() |
jnp.linalg.inv() |
矩阵求逆 |
np.linalg.det() |
jnp.linalg.det() |
行列式 |
np.linalg.eig() |
jnp.linalg.eig() |
特征值 |
np.linalg.svd() |
jnp.linalg.svd() |
奇异值分解 |
np.linalg.solve() |
jnp.linalg.solve() |
线性方程组 |
迁移指南 #
步骤 1: 替换导入 #
python
import numpy as np
import jax.numpy as jnp
步骤 2: 修改数组操作 #
python
import numpy as np
import jax.numpy as jnp
x = np.array([1, 2, 3])
x[0] = 10
x = jnp.array([1, 2, 3])
x = x.at[0].set(10)
步骤 3: 修改随机数 #
python
import numpy as np
import jax
x = np.random.randn(3, 3)
key = jax.random.PRNGKey(42)
x = jax.random.normal(key, (3, 3))
步骤 4: 添加 JIT #
python
import jax
import jax.numpy as jnp
def compute(x):
return jnp.sum(x ** 2)
fast_compute = jax.jit(compute)
未实现的函数 #
某些 NumPy 函数在 JAX 中未实现:
python
import numpy as np
import jax.numpy as jnp
x = np.array([1, 2, 3, 4, 5])
np.sort(x, kind='quicksort')
jnp.sort(x)
np.histogram(x, bins=10)
性能对比 #
CPU 性能 #
python
import time
import numpy as np
import jax.numpy as jnp
import jax
size = 10000
x_np = np.random.randn(size, size)
x_jax = jnp.array(x_np)
start = time.time()
result_np = np.dot(x_np, x_np.T)
print(f"NumPy: {time.time() - start:.4f}s")
@jax.jit
def compute(x):
return jnp.dot(x, x.T)
compute(x_jax).block_until_ready()
start = time.time()
result_jax = compute(x_jax)
result_jax.block_until_ready()
print(f"JAX JIT: {time.time() - start:.4f}s")
GPU 性能 #
python
import time
import jax
import jax.numpy as jnp
size = 10000
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (size, size))
@jax.jit
def compute(x):
return jnp.dot(x, x.T)
compute(x).block_until_ready()
start = time.time()
result = compute(x)
result.block_until_ready()
print(f"JAX GPU: {time.time() - start:.4f}s")
最佳实践 #
1. 使用 JAX 风格的索引 #
python
import jax.numpy as jnp
x = jnp.array([1, 2, 3, 4, 5])
x = x.at[0].set(10)
x = x.at[1:3].mul(2)
x = x.at[x > 5].set(5)
2. 使用显式随机数 #
python
import jax
def init_params(key, shape):
return jax.random.normal(key, shape)
key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)
params = init_params(subkey, (10, 5))
3. 使用 JIT 加速 #
python
import jax
import jax.numpy as jnp
@jax.jit
def train_step(params, x, y):
def loss_fn(p):
predict = jnp.dot(x, p)
return jnp.mean((predict - y) ** 2)
grads = jax.grad(loss_fn)(params)
return jax.tree_map(lambda p, g: p - 0.01 * g, params, grads)
4. 避免设备间传输 #
python
import numpy as np
import jax.numpy as jnp
x_np = np.array([1, 2, 3])
x_jax = jnp.array(x_np)
result = x_jax + 1
常见陷阱 #
陷阱 1: 就地修改 #
python
import jax.numpy as jnp
x = jnp.array([1, 2, 3])
x = x.at[0].set(10)
陷阱 2: 隐式类型转换 #
python
import jax.numpy as jnp
x = jnp.array([1, 2, 3])
y = jnp.array([1.0, 2.0, 3.0])
z = x + y
print(z.dtype)
陷阱 3: 副作用 #
python
import jax
counter = 0
@jax.jit
def bad_function(x):
global counter
counter += 1
return x + 1
@jax.jit
def good_function(x, counter):
return x + 1, counter + 1
下一步 #
现在你已经了解了 JAX 与 NumPy 的兼容性,接下来学习 自动微分 (grad),深入了解 JAX 的核心功能!
最后更新:2026-04-04