状态管理 #
概述 #
JAX 的函数式设计要求显式管理状态。神经网络中的状态包括模型参数、BatchNorm 的统计量、Dropout 的随机状态等。
参数管理 #
参数结构 #
python
import jax
import jax.numpy as jnp
def init_params(key):
    """Initialize parameters for a 3-layer MLP (784 -> 256 -> 128 -> 10).

    Args:
        key: a JAX PRNG key.

    Returns:
        A nested dict pytree {layer_name: {'w': weights, 'b': biases}}.
    """
    # Split into exactly 3 subkeys, one per weight matrix (the original
    # split off a 4th key that was never used).
    k1, k2, k3 = jax.random.split(key, 3)
    params = {
        'layer1': {
            'w': jax.random.normal(k1, (784, 256)) * 0.01,
            'b': jnp.zeros(256)
        },
        'layer2': {
            'w': jax.random.normal(k2, (256, 128)) * 0.01,
            'b': jnp.zeros(128)
        },
        'output': {
            'w': jax.random.normal(k3, (128, 10)) * 0.01,
            'b': jnp.zeros(10)
        }
    }
    return params
key = jax.random.PRNGKey(0)
params = init_params(key)
print(jax.tree_util.tree_structure(params))
参数更新 #
python
import jax
def update_params(params, grads, lr=0.01):
    """One SGD step applied leaf-wise over a parameter pytree.

    Args:
        params: pytree of parameters.
        grads: pytree of gradients with the same structure as `params`.
        lr: learning rate (default 0.01).

    Returns:
        A new pytree with p - lr * g at every leaf; inputs are not mutated.
    """
    # jax.tree_map was deprecated and removed in recent JAX releases;
    # jax.tree_util.tree_map is the supported spelling.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
params = {'w': jnp.array([1.0, 2.0]), 'b': jnp.array([0.1])}
grads = {'w': jnp.array([0.1, 0.2]), 'b': jnp.array([0.01])}
new_params = update_params(params, grads)
print(f"更新后参数: {new_params}")
参数复制 #
python
import jax
def copy_params(params):
    """Return a deep copy of a parameter pytree (each leaf array is copied).

    Args:
        params: pytree of arrays.

    Returns:
        A new pytree whose leaves are independent copies of the originals.
    """
    # jax.tree_map was deprecated and removed in recent JAX releases;
    # jax.tree_util.tree_map is the supported spelling.
    return jax.tree_util.tree_map(lambda x: x.copy(), params)
params = {'w': jnp.array([1.0, 2.0])}
copied = copy_params(params)
print(f"原参数: {params}")
print(f"复制参数: {copied}")
BatchNorm #
BatchNorm 实现 #
python
import jax
import jax.numpy as jnp
def batchnorm_forward(params, state, x, train=True):
    """Batch normalization forward pass with running-statistics update.

    Args:
        params: {'scale', 'bias'} learnable parameters, shape (features,).
        state: {'mean', 'var'} running statistics, shape (features,).
        x: inputs; assumed (batch, features) — TODO confirm against callers.
        train: if True, normalize with batch statistics and return updated
            running stats; if False, normalize with the stored running stats.

    Returns:
        (output, new_state) — output has the same shape as x; new_state keeps
        the (features,) shape of the incoming state.
    """
    scale, bias = params['scale'], params['bias']
    mean, var = state['mean'], state['var']
    if train:
        # keepdims=False so the running stats stay shape (features,).
        # The original used keepdims=True, which made new_mean/new_var
        # broadcast to (1, features) and silently change the state's shape
        # after the first update.
        batch_mean = jnp.mean(x, axis=0)
        batch_var = jnp.var(x, axis=0)
        momentum = 0.1  # EMA rate for the running statistics
        new_mean = (1 - momentum) * mean + momentum * batch_mean
        new_var = (1 - momentum) * var + momentum * batch_var
        x_norm = (x - batch_mean) / jnp.sqrt(batch_var + 1e-5)
        new_state = {'mean': new_mean, 'var': new_var}
    else:
        x_norm = (x - mean) / jnp.sqrt(var + 1e-5)
        new_state = state
    output = scale * x_norm + bias
    return output, new_state
def init_batchnorm(key, features):
    """Create BatchNorm learnable parameters and running statistics.

    Args:
        key: PRNG key, accepted for API symmetry with other initializers but
            not consumed — every value here is deterministic.
        features: number of channels being normalized.

    Returns:
        {'params': {'scale', 'bias'}, 'state': {'mean', 'var'}}, each leaf
        of shape (features,).
    """
    learnable = {
        'scale': jnp.ones(features),
        'bias': jnp.zeros(features),
    }
    running = {
        'mean': jnp.zeros(features),
        'var': jnp.ones(features),
    }
    return {'params': learnable, 'state': running}
# Demo: initialize a 256-feature BatchNorm and run one training-mode pass.
key = jax.random.PRNGKey(0)
bn = init_batchnorm(key, 256)
# NOTE(review): the same key is reused for the demo input; harmless here,
# but real code should split the key before reuse.
x = jax.random.normal(key, (32, 256))
output, new_state = batchnorm_forward(bn['params'], bn['state'], x)
print(f"输出形状: {output.shape}")
集成到网络 #
python
import jax
import jax.numpy as jnp
import jax.nn as nn
def init_network_with_bn(key, features):
    """Initialize a Dense -> BatchNorm -> Dense network.

    Args:
        key: JAX PRNG key.
        features: sequence of 3 sizes (input, hidden, output).

    Returns:
        Nested dict with 'dense1', 'bn1' and 'dense2' entries.
    """
    k_dense1, k_bn, k_dense2 = jax.random.split(key, 3)

    def dense_init(k, n_in, n_out):
        # Small random weights, zero biases.
        return {
            'w': jax.random.normal(k, (n_in, n_out)) * 0.01,
            'b': jnp.zeros(n_out),
        }

    return {
        'dense1': dense_init(k_dense1, features[0], features[1]),
        'bn1': init_batchnorm(k_bn, features[1]),
        'dense2': dense_init(k_dense2, features[1], features[2]),
    }
def forward_with_bn(params, states, x, train=True):
    """Forward pass: dense1 -> BatchNorm -> relu -> dense2.

    Args:
        params: dict with 'dense1', 'bn1' (holding 'params'), 'dense2'.
        states: dict with the BatchNorm running stats under 'bn1'.
        x: input batch.
        train: forwarded to the BatchNorm (batch stats vs. running stats).

    Returns:
        (logits, new_states) where new_states carries the updated BatchNorm
        running statistics under the 'bn1' key.
    """
    d1 = params['dense1']
    h = jnp.dot(x, d1['w']) + d1['b']
    h, bn1_state = batchnorm_forward(
        params['bn1']['params'], states['bn1'], h, train)
    h = nn.relu(h)
    d2 = params['dense2']
    logits = jnp.dot(h, d2['w']) + d2['b']
    return logits, {'bn1': bn1_state}
Dropout #
Dropout 实现 #
python
import jax
import jax.numpy as jnp
def dropout(key, x, rate=0.5, train=True):
    """Inverted dropout: zero elements with probability `rate` and rescale.

    Args:
        key: PRNG key; consumed, and an advanced key is returned.
        x: input array.
        rate: probability of zeroing each element.
        train: if False, x is returned untouched (identity at eval time).

    Returns:
        (output, new_key)
    """
    if not train:
        return x, key
    new_key, mask_key = jax.random.split(key)
    keep_prob = 1 - rate
    keep = jax.random.bernoulli(mask_key, keep_prob, x.shape)
    # Scale survivors by 1/keep_prob so the expected activation matches
    # eval mode (inverted dropout).
    return jnp.where(keep, x / keep_prob, 0), new_key
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 256))
output, key = dropout(key, x, rate=0.5, train=True)
print(f"Dropout 后形状: {output.shape}")
print(f"非零元素比例: {jnp.mean(output != 0):.4f}")
集成到网络 #
python
import jax
import jax.numpy as jnp
import jax.nn as nn
def forward_with_dropout(params, x, key, train=True):
    """Forward pass: dense -> relu -> dropout -> dense.

    Args:
        params: dict with 'w1', 'b1', 'w2', 'b2'.
        x: input batch of shape (batch, 784).
        key: PRNG key for the dropout mask; the advanced key is returned.
        train: dropout is applied only when True.

    Returns:
        (logits, new_key)
    """
    # The original split off an unused `subkey1` here; dropout() already
    # splits the key internally, so no pre-split is needed.
    x = jnp.dot(x, params['w1']) + params['b1']
    x = nn.relu(x)
    x, key = dropout(key, x, rate=0.5, train=train)
    x = jnp.dot(x, params['w2']) + params['b2']
    return x, key
key = jax.random.PRNGKey(0)
params = {
    'w1': jax.random.normal(key, (784, 256)) * 0.01,
    'b1': jnp.zeros(256),
    'w2': jax.random.normal(key, (256, 10)) * 0.01,
    'b2': jnp.zeros(10)
}
x = jax.random.normal(key, (32, 784))
output, key = forward_with_dropout(params, x, key, train=True)
print(f"输出形状: {output.shape}")
完整训练状态 #
状态结构 #
python
import jax
import jax.numpy as jnp
def init_training_state(key, layer_sizes):
    """Initialize parameters, BatchNorm params/state, and optimizer slots.

    Args:
        key: JAX PRNG key.
        layer_sizes: sequence of (in_size, out_size) pairs, one per layer.

    Returns:
        dict with 'params', 'bn_params', 'bn_states' (parallel lists, one
        entry per layer) and 'optimizer_state' (momentum/velocity buffers
        for the weight matrices).
    """
    # One key per layer (the original split off an extra, unused key).
    keys = jax.random.split(key, len(layer_sizes))
    params = []
    bn_params = []
    bn_states = []
    for i, (in_size, out_size) in enumerate(layer_sizes):
        params.append({
            'w': jax.random.normal(keys[i], (in_size, out_size)) * 0.01,
            'b': jnp.zeros(out_size)
        })
        bn_params.append({
            'scale': jnp.ones(out_size),
            'bias': jnp.zeros(out_size)
        })
        bn_states.append({
            'mean': jnp.zeros(out_size),
            'var': jnp.ones(out_size)
        })
    # NOTE(review): optimizer buffers cover only the 'w' leaves, not 'b' —
    # kept as-is for structural compatibility with existing consumers.
    return {
        'params': params,
        'bn_params': bn_params,
        'bn_states': bn_states,
        'optimizer_state': {'momentum': [jnp.zeros_like(p['w']) for p in params],
                            'velocity': [jnp.zeros_like(p['w']) for p in params]}
    }
训练步骤 #
python
import jax
import jax.numpy as jnp
import jax.nn as nn
@jax.jit
def train_step(state, x, y, key, lr=0.01):
    """One jitted training step: forward, cross-entropy loss, grads, SGD.

    Args:
        state: dict produced by init_training_state.
        x: input batch, shape (batch, features).
        y: integer class labels, shape (batch,).
        key: PRNG key used to derive per-layer dropout masks.
        lr: learning rate.

    Returns:
        (new_state, loss) — new_state carries updated params and BatchNorm
        running statistics; bn_params and optimizer_state pass through.
    """
    def loss_fn(params, bn_params, bn_states):
        h = x
        new_bn_states = []
        for i, (p, bn_p, bn_s) in enumerate(zip(params, bn_params, bn_states)):
            h = jnp.dot(h, p['w']) + p['b']
            h, new_bn_s = batchnorm_forward(bn_p, bn_s, h)
            new_bn_states.append(new_bn_s)
            if i < len(params) - 1:
                h = nn.relu(h)
                # Derive a distinct key per layer. The original re-split the
                # closed-over `key` inside loss_fn, which makes `key` local
                # and raises UnboundLocalError on first use.
                layer_key = jax.random.fold_in(key, i)
                h, _ = dropout(layer_key, h, rate=0.5)
        logits = h
        log_probs = nn.log_softmax(logits)
        one_hot = nn.one_hot(y, logits.shape[-1])
        loss = -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))
        return loss, new_bn_states
    # Differentiate w.r.t. params only (argnums=0); BN states come back
    # through the aux output.
    (loss, new_bn_states), grads = jax.value_and_grad(loss_fn, has_aux=True)(
        state['params'], state['bn_params'], state['bn_states']
    )
    # jax.tree_map was deprecated/removed; use jax.tree_util.tree_map.
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - lr * g, state['params'], grads)
    return {
        'params': new_params,
        'bn_params': state['bn_params'],
        'bn_states': new_bn_states,
        'optimizer_state': state['optimizer_state']
    }, loss
优化器状态 #
SGD with Momentum #
python
import jax
import jax.numpy as jnp
def init_sgd_momentum(params):
    """Create a zero-filled momentum buffer matching params' pytree structure."""
    # jax.tree_map was deprecated and removed in recent JAX releases;
    # jax.tree_util.tree_map is the supported spelling.
    return jax.tree_util.tree_map(jnp.zeros_like, params)
@jax.jit
def sgd_momentum_step(params, grads, state, lr=0.01, momentum=0.9):
    """One SGD-with-momentum update.

    Args:
        params: pytree of parameters.
        grads: pytree of gradients (same structure).
        state: momentum buffer pytree (same structure).
        lr: learning rate.
        momentum: momentum coefficient.

    Returns:
        (new_params, new_state)
    """
    # jax.tree_map was deprecated and removed in recent JAX releases;
    # jax.tree_util.tree_map is the supported spelling.
    new_state = jax.tree_util.tree_map(
        lambda s, g: momentum * s + g,
        state, grads
    )
    new_params = jax.tree_util.tree_map(
        lambda p, s: p - lr * s,
        params, new_state
    )
    return new_params, new_state
# Demo: one momentum-SGD step on a randomly initialized weight matrix.
key = jax.random.PRNGKey(0)
params = {'w': jax.random.normal(key, (10, 5))}
grads = {'w': jnp.ones((10, 5))}
state = init_sgd_momentum(params)
new_params, new_state = sgd_momentum_step(params, grads, state)
print(f"更新后参数形状: {new_params['w'].shape}")
Adam #
python
import jax
import jax.numpy as jnp
def init_adam(params):
    """Create Adam optimizer state: first/second moment buffers and step count.

    Args:
        params: pytree of parameters.

    Returns:
        {'m': zeros pytree, 'v': zeros pytree, 't': 0}
    """
    # jax.tree_map was deprecated and removed in recent JAX releases;
    # jax.tree_util.tree_map is the supported spelling.
    return {
        'm': jax.tree_util.tree_map(jnp.zeros_like, params),
        'v': jax.tree_util.tree_map(jnp.zeros_like, params),
        't': 0
    }
@jax.jit
def adam_step(params, grads, state, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update with bias-corrected moment estimates.

    Args:
        params: pytree of parameters.
        grads: pytree of gradients (same structure).
        state: {'m', 'v', 't'} from init_adam.
        lr: learning rate.
        beta1, beta2: exponential decay rates for the moment estimates.
        eps: denominator fuzz factor.

    Returns:
        (new_params, new_state)
    """
    t = state['t'] + 1
    # jax.tree_map was deprecated and removed in recent JAX releases;
    # jax.tree_util.tree_map is the supported spelling.
    new_m = jax.tree_util.tree_map(
        lambda m, g: beta1 * m + (1 - beta1) * g,
        state['m'], grads
    )
    new_v = jax.tree_util.tree_map(
        lambda v, g: beta2 * v + (1 - beta2) * g ** 2,
        state['v'], grads
    )
    # Bias correction: the zero-initialized moments underestimate early steps.
    m_hat = jax.tree_util.tree_map(lambda m: m / (1 - beta1 ** t), new_m)
    v_hat = jax.tree_util.tree_map(lambda v: v / (1 - beta2 ** t), new_v)
    new_params = jax.tree_util.tree_map(
        lambda p, m, v: p - lr * m / (jnp.sqrt(v) + eps),
        params, m_hat, v_hat
    )
    new_state = {'m': new_m, 'v': new_v, 't': t}
    return new_params, new_state
# Demo: one Adam step on a randomly initialized weight matrix.
key = jax.random.PRNGKey(0)
params = {'w': jax.random.normal(key, (10, 5))}
grads = {'w': jnp.ones((10, 5))}
state = init_adam(params)
new_params, new_state = adam_step(params, grads, state)
print(f"更新后参数形状: {new_params['w'].shape}")
下一步 #
现在你已经掌握了状态管理,接下来学习 训练循环,了解如何构建完整的训练流程!
最后更新:2026-04-04