随机数生成 #
概述 #
JAX 使用显式的随机数状态(PRNGKey),这与 NumPy 的全局随机状态不同。这种设计使得随机数生成更加可控和可复现。
为什么使用显式状态? #
text
┌─────────────────────────────────────────────────────────────┐
│ 显式状态优势 │
├─────────────────────────────────────────────────────────────┤
│ │
│ ✅ 可复现性 │
│ - 相同的 key 产生相同的随机数 │
│ - 实验结果可验证 │
│ │
│ ✅ 可组合性 │
│ - 函数变换友好 │
│ - JIT 编译安全 │
│ │
│ ✅ 并行安全 │
│ - 无全局状态竞争 │
│ - 分布式计算友好 │
│ │
└─────────────────────────────────────────────────────────────┘
基本用法 #
创建 PRNGKey #
python
import jax

# The seed fully determines the key, and therefore every random value
# that is later derived from it.
SEED = 42
key = jax.random.PRNGKey(SEED)

print(f"Key: {key}")
print(f"Key 类型: {type(key)}")
分裂 Key #
python
import jax

key = jax.random.PRNGKey(42)

# A single split call yields three independent keys: the first becomes the
# new "main" key, the other two are handed out as single-use subkeys.
key, subkey1, subkey2 = jax.random.split(key, 3)

for label, value in (("主 key", key), ("子 key 1", subkey1), ("子 key 2", subkey2)):
    print(f"{label}: {value}")
生成随机数 #
python
import jax

key = jax.random.PRNGKey(42)

# Rule of thumb: split off a fresh subkey for every draw and keep the
# carry key for the next split; never feed the same key to two samplers.
key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, shape=(3, 3))  # standard normal samples
print(f"正态分布:\n{x}")

key, subkey = jax.random.split(key)
y = jax.random.uniform(subkey, shape=(3, 3))  # uniform samples in [0, 1)
print(f"均匀分布:\n{y}")
随机分布 #
正态分布 #
python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Standard normal N(0, 1) with the default floating dtype.
key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, shape=(100,))

# Same distribution, dtype requested explicitly.
key, subkey = jax.random.split(key)
y = jax.random.normal(subkey, shape=(100,), dtype=jnp.float32)

# Affine transform of a standard normal: mean 3, standard deviation 2.
key, subkey = jax.random.split(key)
z = 3 + 2 * jax.random.normal(subkey, shape=(100,))
均匀分布 #
python
import jax

key = jax.random.PRNGKey(0)

# Default support is the half-open interval [0, 1).
key, subkey = jax.random.split(key)
x = jax.random.uniform(subkey, shape=(10,))

# minval / maxval rescale the support to [0, 100).
key, subkey = jax.random.split(key)
y = jax.random.uniform(subkey, shape=(10,), minval=0, maxval=100)
整数随机 #
python
import jax

key = jax.random.PRNGKey(0)

# randint samples from [minval, maxval): maxval itself is never produced.
key, subkey = jax.random.split(key)
x = jax.random.randint(subkey, shape=(10,), minval=0, maxval=100)

# Any shape works; here a 5x5 grid of integers in [0, 10).
key, subkey = jax.random.split(key)
y = jax.random.randint(subkey, shape=(5, 5), minval=0, maxval=10)
其他分布 #
python
import jax

key = jax.random.PRNGKey(0)

# Bernoulli with the default success probability; yields booleans.
key, subkey = jax.random.split(key)
bernoulli = jax.random.bernoulli(subkey, shape=(10,))

# Poisson counts with rate lam.
key, subkey = jax.random.split(key)
poisson = jax.random.poisson(subkey, lam=5.0, shape=(10,))

# Exponential samples (non-negative).
key, subkey = jax.random.split(key)
exponential = jax.random.exponential(subkey, shape=(10,))

# Gamma samples with shape parameter a.
key, subkey = jax.random.split(key)
gamma = jax.random.gamma(subkey, a=2.0, shape=(10,))
Key 管理模式 #
模式 1: 顺序分裂 #
python
import jax


def init_params(key):
    """Initialise one dense layer by splitting ``key`` sequentially.

    Every draw consumes its own subkey. The leftover key is returned so the
    caller can keep splitting it for later random operations.

    Args:
        key: a PRNG key.

    Returns:
        ``(params, key)``. ``params`` is ``{'w': (10, 5) array, 'b': (5,) array}``
        and ``key`` is the unused carry key.
    """
    key, w_key = jax.random.split(key)
    key, b_key = jax.random.split(key)
    w = jax.random.normal(w_key, (10, 5))
    b = jax.random.normal(b_key, (5,))
    return {'w': w, 'b': b}, key


key = jax.random.PRNGKey(0)
params, key = init_params(key)
模式 2: 批量分裂 #
python
import jax


def init_layers(key, num_layers):
    """Split ``key`` once into ``num_layers`` layer keys plus one carry key.

    Args:
        key: a PRNG key.
        num_layers: number of per-layer keys to produce.

    Returns:
        ``(layer_keys, key)``. ``layer_keys`` stacks ``num_layers`` keys along
        its leading axis, and ``key`` is the extra key kept for future splits.
    """
    keys = jax.random.split(key, num_layers + 1)
    # Reserve the last key as the carry so no key is handed out twice.
    return keys[:-1], keys[-1]


key = jax.random.PRNGKey(0)
layer_keys, key = init_layers(key, 3)
for i, k in enumerate(layer_keys):
    print(f"Layer {i} key: {k}")
模式 3: 嵌套分裂 #
python
import jax
import jax.numpy as jnp


def init_network(key, layer_sizes):
    """Initialise a stack of dense layers using nested key splitting.

    ``key`` is first split into one key per layer. Each layer key is then
    split again for that layer's own draws.

    Args:
        key: a PRNG key.
        layer_sizes: sequence of ``(in_size, out_size)`` pairs, one per layer.

    Returns:
        A list of ``{'w': (in_size, out_size) array, 'b': (out_size,) zeros}``
        dicts, in layer order.
    """
    keys = jax.random.split(key, len(layer_sizes))
    params = []
    for layer_key, (in_size, out_size) in zip(keys, layer_sizes):
        # Biases are zero-initialised, so the second subkey is not consumed.
        w_key, _ = jax.random.split(layer_key)
        # Small random weights (scaled by 0.01).
        w = jax.random.normal(w_key, (in_size, out_size)) * 0.01
        b = jnp.zeros(out_size)
        params.append({'w': w, 'b': b})
    return params


key = jax.random.PRNGKey(0)
layer_sizes = [(10, 20), (20, 30), (30, 5)]
params = init_network(key, layer_sizes)
可复现性 #
确定性结果 #
python
import jax
import jax.numpy as jnp


def experiment(seed):
    """Return 3 normal samples whose values are fully determined by ``seed``.

    Calling this twice with the same seed gives identical results.
    """
    # Only the subkey is used; the carry key is not needed here.
    _, subkey = jax.random.split(jax.random.PRNGKey(seed))
    return jax.random.normal(subkey, (3,))


result1 = experiment(42)
result2 = experiment(42)
print(f"第一次: {result1}")
print(f"第二次: {result2}")
print(f"相同: {jnp.allclose(result1, result2)}")
JIT 中的可复现性 #
python
import jax


@jax.jit
def random_function(key):
    """Draw 3 normal samples and return them together with the next key.

    The key is an explicit argument, so the jitted function is a pure
    function of its input. The same key always gives the same samples.
    Threading the returned key into the next call gives fresh samples.

    Returns:
        ``(samples, key)``. ``samples`` has shape ``(3,)`` and ``key`` is the
        carry key for the next call.
    """
    key, subkey = jax.random.split(key)
    return jax.random.normal(subkey, (3,)), key


key = jax.random.PRNGKey(42)
result1, key = random_function(key)
result2, key = random_function(key)
print(f"结果 1: {result1}")
print(f"结果 2: {result2}")
随机采样 #
随机选择 #
python
import jax
import jax.numpy as jnp  # needed for jnp.array below

key = jax.random.PRNGKey(0)
array = jnp.array([10, 20, 30, 40, 50])

# A single element drawn uniformly from `array`.
key, subkey = jax.random.split(key)
choice = jax.random.choice(subkey, array)
print(f"随机选择: {choice}")

# Several draws with replacement (the default), so repeats are possible.
key, subkey = jax.random.split(key)
choices = jax.random.choice(subkey, array, shape=(3,))
print(f"多次选择: {choices}")

# Draws without replacement: all results are distinct, so the requested
# size must not exceed len(array).
key, subkey = jax.random.split(key)
choices_no_replace = jax.random.choice(subkey, array, shape=(3,), replace=False)
print(f"无放回选择: {choices_no_replace}")
随机排列 #
python
import jax
import jax.numpy as jnp  # needed for jnp.array below

key = jax.random.PRNGKey(0)
array = jnp.array([1, 2, 3, 4, 5])

# permutation returns a reordered copy; `array` itself is left untouched.
key, subkey = jax.random.split(key)
perm = jax.random.permutation(subkey, array)
print(f"随机排列: {perm}")
随机打乱 #
python
import jax
import jax.numpy as jnp  # needed for jnp.array below

key = jax.random.PRNGKey(0)
array = jnp.array([1, 2, 3, 4, 5])

key, subkey = jax.random.split(key)
# jax.random.shuffle is deprecated (and removed in recent JAX releases);
# permutation(..., independent=True) is its documented replacement.
shuffled = jax.random.permutation(subkey, array, independent=True)
print(f"随机打乱: {shuffled}")
批量随机数 #
使用 vmap #
python
import jax


def single_random(key):
    """Draw 3 standard-normal samples from a single key."""
    return jax.random.normal(key, (3,))


# vmap maps single_random over the leading axis of a stacked key array,
# giving one independent row of samples per key.
batch_random = jax.vmap(single_random)
keys = jax.random.split(jax.random.PRNGKey(0), 10)
batch_results = batch_random(keys)
print(f"批量随机数形状: {batch_results.shape}")
一次分裂出多个 Key #
python
import jax

key = jax.random.PRNGKey(0)

# split(key, 5) returns the five new keys stacked along the leading axis.
keys = jax.random.split(key, 5)

print("\n".join(f"Key {i}: {k}" for i, k in enumerate(keys)))
实际应用 #
参数初始化 #
python
import jax
import jax.numpy as jnp


def init_mlp_params(key, layer_sizes):
    """Initialise MLP parameters with He-normal weights and zero biases.

    Args:
        key: a PRNG key; it is split once per layer.
        layer_sizes: sequence of ``(in_size, out_size)`` pairs, one per layer.

    Returns:
        A list of ``{'w': (in_size, out_size) array, 'b': (out_size,) zeros}``
        dicts, in layer order.
    """
    params = []
    for in_size, out_size in layer_sizes:
        # Three-way split: carry key, weight key, and a bias key that stays
        # unused because biases are zero-initialised.
        key, w_key, _ = jax.random.split(key, 3)
        # std = sqrt(2 / fan_in), the He initialisation scale.
        w = jax.random.normal(w_key, (in_size, out_size)) * jnp.sqrt(2.0 / in_size)
        b = jnp.zeros(out_size)
        params.append({'w': w, 'b': b})
    return params


key = jax.random.PRNGKey(0)
params = init_mlp_params(key, [(784, 256), (256, 128), (128, 10)])
print(f"层数: {len(params)}")
Dropout #
python
import jax
import jax.numpy as jnp


def dropout(key, x, rate=0.5):
    """Apply inverted dropout to ``x``.

    Each element is kept with probability ``1 - rate`` and zeroed otherwise.
    Kept elements are divided by ``1 - rate``, so the expected value of each
    element is unchanged.

    Args:
        key: a PRNG key; it is split internally and never reused.
        x: input array.
        rate: drop probability. It must be < 1, because ``rate == 1`` would
            divide by zero.

    Returns:
        ``(output, key)``. ``output`` has the shape of ``x`` and ``key`` is
        the carry key for the next random operation.
    """
    key, subkey = jax.random.split(key)
    mask = jax.random.bernoulli(subkey, 1 - rate, x.shape)  # True = keep
    return x * mask / (1 - rate), key


key = jax.random.PRNGKey(0)
x = jnp.ones((3, 4))
dropped, key = dropout(key, x, rate=0.5)
print(f"Dropout 后:\n{dropped}")
数据增强 #
python
import jax
import jax.numpy as jnp


def random_flip(key, image):
    """Flip ``image`` along axis 1 with probability 0.5.

    Args:
        key: a PRNG key; it is split internally and never reused.
        image: array with at least two dimensions (it is flipped on axis 1).

    Returns:
        ``(image_or_flipped, key)``, where ``key`` is the carry key.
    """
    key, subkey = jax.random.split(key)
    flip = jax.random.bernoulli(subkey)  # scalar boolean, p = 0.5
    # jnp.where instead of a Python `if` keeps the choice traceable, so the
    # function also works under jax.jit.
    return jnp.where(flip, jnp.flip(image, axis=1), image), key


key = jax.random.PRNGKey(0)
image = jnp.arange(12).reshape(3, 4)
flipped, key = random_flip(key, image)
print(f"随机翻转后:\n{flipped}")
常见问题 #
问题 1: Key 重用 #
python
import jax
import jax.numpy as jnp  # needed for jnp.allclose below

# Anti-pattern: feeding the same key to two samplers gives identical draws.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (3,))
y = jax.random.normal(key, (3,))
print(f"相同: {jnp.allclose(x, y)}")

# Correct: split off a distinct subkey for each draw.
key = jax.random.PRNGKey(0)
key, subkey1 = jax.random.split(key)
key, subkey2 = jax.random.split(key)
x = jax.random.normal(subkey1, (3,))
y = jax.random.normal(subkey2, (3,))
print(f"不同: {not jnp.allclose(x, y)}")
问题 2: JIT 中的 Key #
python
import jax
import jax.numpy as jnp  # needed for jnp.zeros below


@jax.jit
def good_practice(key, x):
    """Add standard-normal noise to ``x`` inside a jitted function.

    The key is passed in explicitly and the carry key is returned, so there
    is no hidden state and repeated calls can be chained.

    Returns:
        ``(x + noise, key)``, where ``noise`` has the shape of ``x``.
    """
    key, subkey = jax.random.split(key)
    noise = jax.random.normal(subkey, x.shape)
    return x + noise, key


key = jax.random.PRNGKey(0)
x = jnp.zeros((3,))
result, key = good_practice(key, x)
print(f"结果: {result}")
下一步 #
现在你已经掌握了随机数生成,接下来学习 控制流,了解 JAX 中的条件与循环!
最后更新:2026-04-04