NumPy 兼容性 #

概述 #

JAX 提供了 jax.numpy 模块,它与 NumPy API 高度兼容,使得从 NumPy 迁移到 JAX 非常简单。

基本使用 #

python
import jax.numpy as jnp
import numpy as np

x_jax = jnp.array([1, 2, 3])
x_np = np.array([1, 2, 3])

print(jnp.sum(x_jax))  
print(np.sum(x_np))    

兼容性概览 #

text
┌─────────────────────────────────────────────────────────────┐
│                  JAX vs NumPy 兼容性                         │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  ✅ 高度兼容                                                │
│     - 大多数 NumPy API 可直接使用                           │
│     - 广播机制完全兼容                                      │
│     - 数学运算一致                                          │
│                                                             │
│  ⚠️ 关键差异                                                │
│     - 数组不可变                                            │
│     - 默认 float32                                          │
│     - 随机数 API 不同                                       │
│     - 部分函数未实现                                        │
│                                                             │
└─────────────────────────────────────────────────────────────┘

主要差异 #

1. 不可变性 #

NumPy 允许就地修改数组,JAX 不允许:

python
import numpy as np
import jax.numpy as jnp

x_np = np.array([1, 2, 3])
x_np[0] = 10  
print(x_np)  

x_jax = jnp.array([1, 2, 3])

x_jax = x_jax.at[0].set(10)
print(x_jax)  

2. 索引更新 #

python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 4, 5])

x = x.at[0].set(10)

x = x.at[1:3].set(0)

x = x.at[::2].add(1)

x = x.at[x > 3].set(0)

3. 数据类型 #

python
import numpy as np
import jax.numpy as jnp

x_np = np.array([1.0])
print(x_np.dtype)  

x_jax = jnp.array([1.0])
print(x_jax.dtype)  

x_jax_64 = jnp.array([1.0], dtype=jnp.float64)
print(x_jax_64.dtype)  

4. 随机数 #

python
import numpy as np
import jax
import jax.numpy as jnp

x_np = np.random.randn(3, 3)

key = jax.random.PRNGKey(42)
x_jax = jax.random.normal(key, (3, 3))

5. 外部副作用 #

python
import jax
import jax.numpy as jnp

@jax.jit
def bad_function(x):
    print(f"x = {x}")  
    return x + 1

@jax.jit
def good_function(x):
    return x + 1

API 对照表 #

数组创建 #

NumPy JAX 说明
np.array() jnp.array() 创建数组
np.zeros() jnp.zeros() 全零数组
np.ones() jnp.ones() 全一数组
np.arange() jnp.arange() 范围数组
np.linspace() jnp.linspace() 线性空间
np.eye() jnp.eye() 单位矩阵
np.random.rand() jax.random.uniform() 随机数组

数组操作 #

NumPy JAX 说明
np.reshape() jnp.reshape() 变形
np.transpose() jnp.transpose() 转置
np.concatenate() jnp.concatenate() 拼接
np.split() jnp.split() 分割
np.stack() jnp.stack() 堆叠

数学运算 #

NumPy JAX 说明
np.add() jnp.add() 加法
np.subtract() jnp.subtract() 减法
np.multiply() jnp.multiply() 乘法
np.divide() jnp.divide() 除法
np.dot() jnp.dot() 点积
np.matmul() jnp.matmul() 矩阵乘法

统计函数 #

NumPy JAX 说明
np.sum() jnp.sum() 求和
np.mean() jnp.mean() 均值
np.std() jnp.std() 标准差
np.var() jnp.var() 方差
np.max() jnp.max() 最大值
np.min() jnp.min() 最小值

线性代数 #

NumPy JAX 说明
np.linalg.inv() jnp.linalg.inv() 矩阵求逆
np.linalg.det() jnp.linalg.det() 行列式
np.linalg.eig() jnp.linalg.eig() 特征值
np.linalg.svd() jnp.linalg.svd() 奇异值分解
np.linalg.solve() jnp.linalg.solve() 线性方程组

迁移指南 #

步骤 1: 替换导入 #

python
import numpy as np

import jax.numpy as jnp

步骤 2: 修改数组操作 #

python
import numpy as np
import jax.numpy as jnp

x = np.array([1, 2, 3])
x[0] = 10

x = jnp.array([1, 2, 3])
x = x.at[0].set(10)

步骤 3: 修改随机数 #

python
import numpy as np
import jax

x = np.random.randn(3, 3)

key = jax.random.PRNGKey(42)
x = jax.random.normal(key, (3, 3))

步骤 4: 添加 JIT #

python
import jax
import jax.numpy as jnp

def compute(x):
    return jnp.sum(x ** 2)

fast_compute = jax.jit(compute)

未实现的函数 #

某些 NumPy 函数在 JAX 中未实现:

python
import numpy as np
import jax.numpy as jnp

x = np.array([1, 2, 3, 4, 5])

np.sort(x, kind='quicksort')  
jnp.sort(x)  

np.histogram(x, bins=10)  

性能对比 #

CPU 性能 #

python
import time
import numpy as np
import jax.numpy as jnp
import jax

size = 10000

x_np = np.random.randn(size, size)
x_jax = jnp.array(x_np)

start = time.time()
result_np = np.dot(x_np, x_np.T)
print(f"NumPy: {time.time() - start:.4f}s")

@jax.jit
def compute(x):
    return jnp.dot(x, x.T)

compute(x_jax).block_until_ready()  
start = time.time()
result_jax = compute(x_jax)
result_jax.block_until_ready()
print(f"JAX JIT: {time.time() - start:.4f}s")

GPU 性能 #

python
import time
import jax
import jax.numpy as jnp

size = 10000
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (size, size))

@jax.jit
def compute(x):
    return jnp.dot(x, x.T)

compute(x).block_until_ready()

start = time.time()
result = compute(x)
result.block_until_ready()
print(f"JAX GPU: {time.time() - start:.4f}s")

最佳实践 #

1. 使用 JAX 风格的索引 #

python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 4, 5])

x = x.at[0].set(10)
x = x.at[1:3].mul(2)
x = x.at[x > 5].set(5)

2. 使用显式随机数 #

python
import jax

def init_params(key, shape):
    return jax.random.normal(key, shape)

key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)
params = init_params(subkey, (10, 5))

3. 使用 JIT 加速 #

python
import jax
import jax.numpy as jnp

@jax.jit
def train_step(params, x, y):
    def loss_fn(p):
        predict = jnp.dot(x, p)
        return jnp.mean((predict - y) ** 2)
    
    grads = jax.grad(loss_fn)(params)
    return jax.tree_map(lambda p, g: p - 0.01 * g, params, grads)

4. 避免设备间传输 #

python
import numpy as np
import jax.numpy as jnp

x_np = np.array([1, 2, 3])

x_jax = jnp.array(x_np)

result = x_jax + 1

常见陷阱 #

陷阱 1: 就地修改 #

python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

x = x.at[0].set(10)

陷阱 2: 隐式类型转换 #

python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])  
y = jnp.array([1.0, 2.0, 3.0])  

z = x + y  
print(z.dtype)  

陷阱 3: 副作用 #

python
import jax

counter = 0

@jax.jit
def bad_function(x):
    global counter
    counter += 1  
    return x + 1

@jax.jit
def good_function(x, counter):
    return x + 1, counter + 1

下一步 #

现在你已经了解了 JAX 与 NumPy 的兼容性,接下来学习 自动微分 (grad),深入了解 JAX 的核心功能!

最后更新:2026-04-04