随机数生成 #

概述 #

JAX 使用显式的随机数状态(PRNGKey),这与 NumPy 的全局随机状态不同:同一个 key 总是产生完全相同的随机数,因此每次需要新的随机数时,都应先用 jax.random.split 分裂出新的 key。这种设计使得随机数生成更加可控和可复现。

为什么使用显式状态? #

text
┌─────────────────────────────────────────────────────────────┐
│                    显式状态优势                              │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  ✅ 可复现性                                                │
│     - 相同的 key 产生相同的随机数                            │
│     - 实验结果可验证                                        │
│                                                             │
│  ✅ 可组合性                                                │
│     - 函数变换友好                                          │
│     - JIT 编译安全                                          │
│                                                             │
│  ✅ 并行安全                                                │
│     - 无全局状态竞争                                        │
│     - 分布式计算友好                                        │
│                                                             │
└─────────────────────────────────────────────────────────────┘

基本用法 #

创建 PRNGKey #

python
import jax

# Build a PRNG key from an integer seed; the same seed always rebuilds the same key.
SEED = 42
key = jax.random.PRNGKey(SEED)

# Show the raw key value and its concrete array type.
for label, value in (("Key", key), ("Key 类型", type(key))):
    print(f"{label}: {value}")

分裂 Key #

python
import jax

# Split one root key into three keys in a single call: the first replaces
# the root key for future use, the other two are independent subkeys.
root = jax.random.PRNGKey(42)
split_keys = jax.random.split(root, 3)
key, subkey1, subkey2 = split_keys

labels = ("主 key", "子 key 1", "子 key 2")
for label, k in zip(labels, (key, subkey1, subkey2)):
    print(f"{label}: {k}")

生成随机数 #

python
import jax


def _fresh(k):
    """Advance key ``k``; return ``(new_key, subkey_for_immediate_use)``."""
    new_key, sub = jax.random.split(k)
    return new_key, sub


key = jax.random.PRNGKey(42)

# 3x3 standard-normal samples.
key, subkey = _fresh(key)
x = jax.random.normal(subkey, shape=(3, 3))
print(f"正态分布:\n{x}")

# 3x3 uniform samples with the default bounds.
key, subkey = _fresh(key)
y = jax.random.uniform(subkey, shape=(3, 3))
print(f"均匀分布:\n{y}")

随机分布 #

正态分布 #

python
import jax
import jax.numpy as jnp

SHAPE = (100,)
key = jax.random.PRNGKey(0)

# Standard normal N(0, 1).
key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, shape=SHAPE)

# Same distribution, requesting float32 explicitly.
key, subkey = jax.random.split(key)
y = jax.random.normal(subkey, shape=SHAPE, dtype=jnp.float32)

# Scale-and-shift a standard normal: std 2, mean 3, i.e. N(3, 2**2).
key, subkey = jax.random.split(key)
standard = jax.random.normal(subkey, shape=SHAPE)
z = standard * 2 + 3

均匀分布 #

python
import jax

key = jax.random.PRNGKey(0)

# Default bounds.
key, subkey = jax.random.split(key)
x = jax.random.uniform(subkey, shape=(10,))

# Custom bounds: low inclusive, high exclusive.
low, high = 0, 100
key, subkey = jax.random.split(key)
y = jax.random.uniform(subkey, shape=(10,), minval=low, maxval=high)

整数随机 #

python
import jax

key = jax.random.PRNGKey(0)

# Integer bounds: minval inclusive, maxval exclusive.
bounds_1d = dict(minval=0, maxval=100)
bounds_2d = dict(minval=0, maxval=10)

key, subkey = jax.random.split(key)
x = jax.random.randint(subkey, shape=(10,), **bounds_1d)

key, subkey = jax.random.split(key)
y = jax.random.randint(subkey, shape=(5, 5), **bounds_2d)

其他分布 #

python
import jax

SIZE = (10,)
key = jax.random.PRNGKey(0)

# Bernoulli trials with the default success probability.
key, subkey = jax.random.split(key)
bernoulli = jax.random.bernoulli(subkey, shape=SIZE)

# Poisson counts with rate lam = 5.
key, subkey = jax.random.split(key)
poisson = jax.random.poisson(subkey, lam=5.0, shape=SIZE)

# Exponential samples with the default rate.
key, subkey = jax.random.split(key)
exponential = jax.random.exponential(subkey, shape=SIZE)

# Gamma samples with shape parameter a = 2.
key, subkey = jax.random.split(key)
gamma = jax.random.gamma(subkey, a=2.0, shape=SIZE)

Key 管理模式 #

模式 1: 顺序分裂 #

python
import jax


def init_params(key):
    """Create a (10, 5) weight and a (5,) bias, both standard normal.

    Keys are split sequentially: each split produces one key consumed now
    and one carried forward. Returns ``(params, carry_key)`` so the caller
    can keep drawing fresh randomness from the unconsumed carry.
    """
    carry, weight_key = jax.random.split(key)
    carry, bias_key = jax.random.split(carry)
    params = {
        'w': jax.random.normal(weight_key, (10, 5)),
        'b': jax.random.normal(bias_key, (5,)),
    }
    return params, carry


key = jax.random.PRNGKey(0)
params, key = init_params(key)

模式 2: 批量分裂 #

python
import jax


def init_layers(key, num_layers):
    """Split ``key`` into ``num_layers`` per-layer keys plus one carry key.

    Returns ``(layer_keys, carry)``: ``layer_keys`` holds one key per layer,
    ``carry`` is kept back for later use.
    """
    all_keys = jax.random.split(key, num_layers + 1)
    layer_keys, carry = all_keys[:-1], all_keys[-1]
    return layer_keys, carry


key = jax.random.PRNGKey(0)
layer_keys, key = init_layers(key, 3)

for layer_idx, layer_key in enumerate(layer_keys):
    print(f"Layer {layer_idx} key: {layer_key}")

模式 3: 嵌套分裂 #

python
import jax
import jax.numpy as jnp


def init_network(key, layer_sizes, scale=0.01):
    """Initialise one ``{'w', 'b'}`` dict per layer using nested key splits.

    The root ``key`` is split into one key per layer. Each layer key is
    then split again into a weight key and a bias key.

    Args:
        key: PRNG key from which all layer parameters are derived.
        layer_sizes: sequence of ``(in_size, out_size)`` pairs.
        scale: multiplier applied to the standard-normal weights
            (default 0.01).

    Returns:
        A list with one dict per layer. ``w`` has shape
        ``(in_size, out_size)``; ``b`` is zeros of shape ``(out_size,)``.
    """
    layer_keys = jax.random.split(key, len(layer_sizes))
    params = []
    for layer_key, (in_size, out_size) in zip(layer_keys, layer_sizes):
        # The second half of the split is reserved for the bias. Biases are
        # zero-initialised here, but splitting anyway keeps w_key (and so
        # the weight values) identical to the two-key layout.
        w_key, _b_key = jax.random.split(layer_key)
        w = jax.random.normal(w_key, (in_size, out_size)) * scale
        b = jnp.zeros(out_size)
        params.append({'w': w, 'b': b})
    return params


key = jax.random.PRNGKey(0)
layer_sizes = [(10, 20), (20, 30), (30, 5)]
params = init_network(key, layer_sizes)

可复现性 #

确定性结果 #

python
import jax
import jax.numpy as jnp  # needed for jnp.allclose below (was missing -> NameError)


def experiment(seed):
    """Return three standard-normal samples derived deterministically from ``seed``.

    The same seed always rebuilds the same key, so repeated calls with the
    same seed return identical arrays.
    """
    _, subkey = jax.random.split(jax.random.PRNGKey(seed))
    return jax.random.normal(subkey, (3,))


result1 = experiment(42)
result2 = experiment(42)

print(f"第一次: {result1}")
print(f"第二次: {result2}")
print(f"相同: {jnp.allclose(result1, result2)}")

JIT 中的可复现性 #

python
import jax


@jax.jit
def random_function(key):
    """Draw three normal samples and return them with the advanced key.

    Returning the carry key lets the caller thread PRNG state explicitly,
    so the jitted function stays pure while each call sees a new key.
    """
    next_key, draw_key = jax.random.split(key)
    samples = jax.random.normal(draw_key, (3,))
    return samples, next_key


key = jax.random.PRNGKey(42)

outputs = []
for _ in range(2):
    sample, key = random_function(key)
    outputs.append(sample)
result1, result2 = outputs

for n, r in enumerate((result1, result2), start=1):
    print(f"结果 {n}: {r}")

随机采样 #

随机选择 #

python
import jax
import jax.numpy as jnp  # jnp.array is used below (was missing -> NameError)

key = jax.random.PRNGKey(0)
array = jnp.array([10, 20, 30, 40, 50])

# A single random element.
key, subkey = jax.random.split(key)
choice = jax.random.choice(subkey, array)
print(f"随机选择: {choice}")

# Three draws with replacement (duplicates possible).
key, subkey = jax.random.split(key)
choices = jax.random.choice(subkey, array, shape=(3,))
print(f"多次选择: {choices}")

# Three draws without replacement (all distinct).
key, subkey = jax.random.split(key)
choices_no_replace = jax.random.choice(subkey, array, shape=(3,), replace=False)
print(f"无放回选择: {choices_no_replace}")

随机排列 #

python
import jax
import jax.numpy as jnp  # jnp.array is used below (was missing -> NameError)

key = jax.random.PRNGKey(0)
array = jnp.array([1, 2, 3, 4, 5])

# Return a randomly reordered copy of `array`; the input is not modified.
key, subkey = jax.random.split(key)
perm = jax.random.permutation(subkey, array)
print(f"随机排列: {perm}")

随机打乱 #

python
import jax
import jax.numpy as jnp  # jnp.array is used below (was missing -> NameError)

key = jax.random.PRNGKey(0)
array = jnp.array([1, 2, 3, 4, 5])

# jax.random.shuffle was deprecated and then removed from JAX;
# jax.random.permutation is the supported way to shuffle an array along its
# first axis. It returns a new array and leaves `array` unchanged.
key, subkey = jax.random.split(key)
shuffled = jax.random.permutation(subkey, array)
print(f"随机打乱: {shuffled}")

批量随机数 #

使用 vmap #

python
import jax


def single_random(key):
    """Draw three standard-normal samples from one key."""
    return jax.random.normal(key, (3,))


# Ten keys -> ten rows of samples, one row per key, via vmap.
keys = jax.random.split(jax.random.PRNGKey(0), 10)
batch_random = jax.vmap(single_random)
batch_results = batch_random(keys)
print(f"批量随机数形状: {batch_results.shape}")

分割多个 Key #

python
import jax

key = jax.random.PRNGKey(0)
# One call yields an array of 5 keys, one per row.
keys = jax.random.split(key, 5)

lines = [f"Key {i}: {k}" for i, k in enumerate(keys)]
print("\n".join(lines))

实际应用 #

参数初始化 #

python
import jax
import jax.numpy as jnp


def init_mlp_params(key, layer_sizes):
    """Initialise MLP parameters with He (Kaiming) normal scaling.

    For each ``(in_size, out_size)`` pair, the weight is drawn from a standard
    normal and scaled by ``sqrt(2 / in_size)``; the bias starts at zero.

    Args:
        key: PRNG key. It is split sequentially, one three-way split per layer.
        layer_sizes: sequence of ``(in_size, out_size)`` pairs.

    Returns:
        A list of ``{'w': (in_size, out_size), 'b': (out_size,)}`` dicts.
    """
    params = []
    for in_size, out_size in layer_sizes:
        # b_key is unused because biases are zeros; it is still split off so
        # the key sequence (and thus the weights) stays the same.
        key, w_key, b_key = jax.random.split(key, 3)
        w = jax.random.normal(w_key, (in_size, out_size)) * jnp.sqrt(2.0 / in_size)
        b = jnp.zeros(out_size)
        params.append({'w': w, 'b': b})
    return params


key = jax.random.PRNGKey(0)
params = init_mlp_params(key, [(784, 256), (256, 128), (128, 10)])
print(f"层数: {len(params)}")

Dropout #

python
import jax
import jax.numpy as jnp


def dropout(key, x, rate=0.5):
    """Apply inverted dropout to ``x``.

    Each element is kept with probability ``1 - rate``. Kept elements are
    scaled by ``1 / (1 - rate)`` so the expected value is unchanged, and
    dropped elements become 0.

    Args:
        key: PRNG key. It is split; one half drives the mask and the other is
            returned for later use.
        x: input array.
        rate: drop probability, expected in ``[0, 1]``.

    Returns:
        ``(output, new_key)``.
    """
    key, subkey = jax.random.split(key)
    keep_prob = 1 - rate
    mask = jax.random.bernoulli(subkey, keep_prob, x.shape)
    # With rate == 1 nothing is kept. Divide by 1 instead of 0 so the result
    # is exactly zero rather than 0/0 = NaN. jnp.where keeps the check
    # trace-safe under jit. Selecting with jnp.where instead of computing
    # x * mask also stops dropped inf inputs from turning into NaN (inf * 0).
    safe_keep = jnp.where(keep_prob > 0, keep_prob, 1.0)
    return jnp.where(mask, x / safe_keep, 0.0), key


key = jax.random.PRNGKey(0)
x = jnp.ones((3, 4))
dropped, key = dropout(key, x, rate=0.5)
print(f"Dropout 后:\n{dropped}")

数据增强 #

python
import jax
import jax.numpy as jnp


def random_flip(key, image):
    """Randomly mirror ``image`` along axis 1 based on one Bernoulli draw.

    Returns ``(result, new_key)``, where ``result`` is either the mirrored
    image or the original, and ``new_key`` is the unconsumed half of the split.
    """
    new_key, coin_key = jax.random.split(key)
    should_flip = jax.random.bernoulli(coin_key)
    mirrored = jnp.flip(image, axis=1)
    return jnp.where(should_flip, mirrored, image), new_key


key = jax.random.PRNGKey(0)
image = jnp.arange(12).reshape(3, 4)
flipped, key = random_flip(key, image)
print(f"随机翻转后:\n{flipped}")

常见问题 #

问题 1: Key 重用 #

python
import jax
import jax.numpy as jnp  # jnp.allclose is used below (was missing -> NameError)

# Pitfall: reusing one key for two draws produces identical "random" numbers.
key = jax.random.PRNGKey(0)

x = jax.random.normal(key, (3,))
y = jax.random.normal(key, (3,))
print(f"相同: {jnp.allclose(x, y)}")

# Fix: split off a fresh subkey for every draw.
key = jax.random.PRNGKey(0)
key, subkey1 = jax.random.split(key)
key, subkey2 = jax.random.split(key)
x = jax.random.normal(subkey1, (3,))
y = jax.random.normal(subkey2, (3,))
print(f"不同: {not jnp.allclose(x, y)}")

问题 2: JIT 中的 Key #

python
import jax
import jax.numpy as jnp  # jnp.zeros is used below (was missing -> NameError)


@jax.jit
def good_practice(key, x):
    """Add standard-normal noise to ``x`` inside a jitted function.

    The key is passed in and the advanced key is returned, so no random
    state is hidden inside the compiled function.

    Returns:
        ``(x + noise, new_key)``.
    """
    key, subkey = jax.random.split(key)
    noise = jax.random.normal(subkey, x.shape)
    return x + noise, key


key = jax.random.PRNGKey(0)
x = jnp.zeros((3,))

result, key = good_practice(key, x)
print(f"结果: {result}")

下一步 #

现在你已经掌握了随机数生成,接下来学习 控制流,了解 JAX 中的条件与循环!

最后更新:2026-04-04