State Management #

Overview #

JAX's functional design requires state to be managed explicitly. In a neural network, state includes the model parameters, BatchNorm running statistics, Dropout's PRNG state, and so on.
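
The same pattern recurs throughout this chapter: a pure function takes the current state as an argument and returns the updated state alongside its output, instead of mutating anything in place. A minimal sketch of that convention (stateful_step and its fields are illustrative, not a JAX API):

python
import jax.numpy as jnp

def stateful_step(state, x):
    # Pure function: state comes in as an argument and the new state
    # is returned alongside the output; nothing is mutated in place
    new_state = {'count': state['count'] + 1}
    output = x * 2.0
    return output, new_state

state = {'count': 0}
output, state = stateful_step(state, jnp.ones(3))
print(state)  # {'count': 1}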

Parameter Management #

Parameter Structure #

python
import jax
import jax.numpy as jnp

def init_params(key):
    # One PRNG key per weight matrix
    keys = jax.random.split(key, 3)
    params = {
        'layer1': {
            'w': jax.random.normal(keys[0], (784, 256)) * 0.01,
            'b': jnp.zeros(256)
        },
        'layer2': {
            'w': jax.random.normal(keys[1], (256, 128)) * 0.01,
            'b': jnp.zeros(128)
        },
        'output': {
            'w': jax.random.normal(keys[2], (128, 10)) * 0.01,
            'b': jnp.zeros(10)
        }
    }
    return params

key = jax.random.PRNGKey(0)
params = init_params(key)

print(jax.tree_util.tree_structure(params))
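
Because the parameters are a plain pytree, the standard tree utilities apply directly; for instance, counting the total number of scalar parameters:

python
# Total parameter count: sum the sizes of all leaf arrays in the pytree
num_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
print(num_params)  # 235146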

Parameter Updates #

python
import jax
import jax.numpy as jnp

def update_params(params, grads, lr=0.01):
    # Plain SGD: the update is applied leaf-wise across the parameter pytree
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

params = {'w': jnp.array([1.0, 2.0]), 'b': jnp.array([0.1])}
grads = {'w': jnp.array([0.1, 0.2]), 'b': jnp.array([0.01])}

new_params = update_params(params, grads)
print(f"更新后参数: {new_params}")

Parameter Copying #

python
import jax
import jax.numpy as jnp

def copy_params(params):
    # Copy every leaf array; tree_map rebuilds the dict structure automatically
    return jax.tree_util.tree_map(lambda x: x.copy(), params)

params = {'w': jnp.array([1.0, 2.0])}
copied = copy_params(params)
print(f"原参数: {params}")
print(f"复制参数: {copied}")

BatchNorm #

BatchNorm Implementation #

python
import jax
import jax.numpy as jnp

def batchnorm_forward(params, state, x, train=True):
    scale, bias = params['scale'], params['bias']
    mean, var = state['mean'], state['var']

    if train:
        # Normalize with the statistics of the current batch.
        # No keepdims: the stats stay shape (features,), matching the state
        batch_mean = jnp.mean(x, axis=0)
        batch_var = jnp.var(x, axis=0)

        # Update the running statistics (PyTorch-style momentum convention)
        momentum = 0.1
        new_mean = (1 - momentum) * mean + momentum * batch_mean
        new_var = (1 - momentum) * var + momentum * batch_var

        x_norm = (x - batch_mean) / jnp.sqrt(batch_var + 1e-5)
        new_state = {'mean': new_mean, 'var': new_var}
    else:
        # Inference: use the running statistics; state is unchanged
        x_norm = (x - mean) / jnp.sqrt(var + 1e-5)
        new_state = state

    output = scale * x_norm + bias
    return output, new_state

def init_batchnorm(key, features):
    # key is unused (all initial values are constants) but kept for a uniform init signature
    return {
        'params': {
            'scale': jnp.ones(features),
            'bias': jnp.zeros(features)
        },
        'state': {
            'mean': jnp.zeros(features),
            'var': jnp.ones(features)
        }
    }

key = jax.random.PRNGKey(0)
bn = init_batchnorm(key, 256)

x = jax.random.normal(key, (32, 256))
output, new_state = batchnorm_forward(bn['params'], bn['state'], x)
print(f"输出形状: {output.shape}")

Integrating into a Network #

python
import jax
import jax.numpy as jnp
import jax.nn as nn

def init_network_with_bn(key, features):
    keys = jax.random.split(key, 3)
    return {
        'dense1': {
            'w': jax.random.normal(keys[0], (features[0], features[1])) * 0.01,
            'b': jnp.zeros(features[1])
        },
        'bn1': init_batchnorm(keys[1], features[1]),
        'dense2': {
            'w': jax.random.normal(keys[2], (features[1], features[2])) * 0.01,
            'b': jnp.zeros(features[2])
        }
    }

def forward_with_bn(params, states, x, train=True):
    x = jnp.dot(x, params['dense1']['w']) + params['dense1']['b']
    # BatchNorm returns both its output and its updated running statistics
    x, new_bn_state = batchnorm_forward(
        params['bn1']['params'], 
        states['bn1'], 
        x, 
        train
    )
    x = nn.relu(x)
    x = jnp.dot(x, params['dense2']['w']) + params['dense2']['b']
    # The caller must thread the new state into the next call
    return x, {'bn1': new_bn_state}
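
A quick usage sketch, pulling the initial BatchNorm state out of the init dictionary into its own pytree (the variable names here are illustrative):

python
key = jax.random.PRNGKey(0)
net = init_network_with_bn(key, (784, 256, 10))

# Keep the mutable part of the network in a separate pytree
states = {'bn1': net['bn1']['state']}

x = jax.random.normal(key, (32, 784))
logits, states = forward_with_bn(net, states, x, train=True)
print(f"Logits shape: {logits.shape}")  # (32, 10)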

Dropout #

Dropout Implementation #

python
import jax
import jax.numpy as jnp

def dropout(key, x, rate=0.5, train=True):
    if not train:
        # Inference: identity, no randomness consumed
        return x, key
    
    key, subkey = jax.random.split(key)
    # Keep each element with probability 1 - rate, scaling to preserve the expectation
    mask = jax.random.bernoulli(subkey, 1 - rate, x.shape)
    output = jnp.where(mask, x / (1 - rate), 0)
    return output, key

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 256))

output, key = dropout(key, x, rate=0.5, train=True)
print(f"Dropout 后形状: {output.shape}")
print(f"非零元素比例: {jnp.mean(output != 0):.4f}")

Integrating into a Network #

python
import jax
import jax.numpy as jnp
import jax.nn as nn

def forward_with_dropout(params, x, key, train=True):
    x = jnp.dot(x, params['w1']) + params['b1']
    x = nn.relu(x)
    # dropout consumes and returns the PRNG key so it can be threaded onward
    x, key = dropout(key, x, rate=0.5, train=train)
    
    x = jnp.dot(x, params['w2']) + params['b2']
    return x, key

key = jax.random.PRNGKey(0)
key, w1_key, w2_key = jax.random.split(key, 3)
params = {
    'w1': jax.random.normal(w1_key, (784, 256)) * 0.01,
    'b1': jnp.zeros(256),
    'w2': jax.random.normal(w2_key, (256, 10)) * 0.01,
    'b2': jnp.zeros(10)
}

x = jax.random.normal(key, (32, 784))
output, key = forward_with_dropout(params, x, key, train=True)
print(f"输出形状: {output.shape}")

Complete Training State #

State Structure #

python
import jax
import jax.numpy as jnp

def init_training_state(key, layer_sizes):
    # layer_sizes is a list of (in_size, out_size) pairs, one per layer
    keys = jax.random.split(key, len(layer_sizes))
    
    params = []
    bn_params = []
    bn_states = []
    
    for i, (in_size, out_size) in enumerate(layer_sizes):
        params.append({
            'w': jax.random.normal(keys[i], (in_size, out_size)) * 0.01,
            'b': jnp.zeros(out_size)
        })
        bn_params.append({
            'scale': jnp.ones(out_size),
            'bias': jnp.zeros(out_size)
        })
        bn_states.append({
            'mean': jnp.zeros(out_size),
            'var': jnp.ones(out_size)
        })
    
    return {
        'params': params,
        'bn_params': bn_params,
        'bn_states': bn_states,
        # Optimizer slots mirror the full parameter pytree (weights and biases)
        'optimizer_state': {
            'momentum': jax.tree_util.tree_map(jnp.zeros_like, params),
            'velocity': jax.tree_util.tree_map(jnp.zeros_like, params)
        }
    }
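
A quick sanity check of the resulting structure (the layer sizes are arbitrary):

python
key = jax.random.PRNGKey(0)
state = init_training_state(key, [(784, 256), (256, 10)])
print(len(state['params']))                 # 2 layers
print(state['params'][0]['w'].shape)        # (784, 256)
print(state['bn_states'][1]['mean'].shape)  # (10,)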

Training Step #

python
import jax
import jax.numpy as jnp
import jax.nn as nn

@jax.jit
def train_step(state, x, y, key, lr=0.01):
    def loss_fn(params, bn_params, bn_states):
        h = x
        new_bn_states = []
        # One dropout key per layer, derived up front so the closure
        # only reads the outer `key` and never reassigns it
        layer_keys = jax.random.split(key, len(params))
        
        for i, (p, bn_p, bn_s) in enumerate(zip(params, bn_params, bn_states)):
            h = jnp.dot(h, p['w']) + p['b']
            h, new_bn_s = batchnorm_forward(bn_p, bn_s, h)
            new_bn_states.append(new_bn_s)
            
            if i < len(params) - 1:
                h = nn.relu(h)
                h, _ = dropout(layer_keys[i], h, rate=0.5)
        
        logits = h
        log_probs = nn.log_softmax(logits)
        one_hot = nn.one_hot(y, logits.shape[-1])
        loss = -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))
        
        return loss, new_bn_states
    
    # Differentiate with respect to params only; the updated BatchNorm
    # state comes back as auxiliary output
    (loss, new_bn_states), grads = jax.value_and_grad(loss_fn, has_aux=True)(
        state['params'], state['bn_params'], state['bn_states']
    )
    
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, state['params'], grads)
    
    return {
        'params': new_params,
        'bn_params': state['bn_params'],
        'bn_states': new_bn_states,
        'optimizer_state': state['optimizer_state']
    }, loss
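
A single step on random data confirms that the loss comes back and the state threads through (the batch below is a hypothetical stand-in for real data):

python
key = jax.random.PRNGKey(0)
key, data_key, step_key = jax.random.split(key, 3)

state = init_training_state(key, [(784, 256), (256, 10)])
x = jax.random.normal(data_key, (32, 784))
y = jnp.zeros(32, dtype=jnp.int32)  # dummy labels

state, loss = train_step(state, x, y, step_key)
print(f"Loss: {loss:.4f}")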

Optimizer State #

SGD with Momentum #

python
import jax
import jax.numpy as jnp

def init_sgd_momentum(params):
    # One velocity buffer per parameter leaf, initialized to zero
    return jax.tree_util.tree_map(jnp.zeros_like, params)

@jax.jit
def sgd_momentum_step(params, grads, state, lr=0.01, momentum=0.9):
    # v <- momentum * v + g
    new_state = jax.tree_util.tree_map(
        lambda s, g: momentum * s + g,
        state, grads
    )
    
    # p <- p - lr * v
    new_params = jax.tree_util.tree_map(
        lambda p, s: p - lr * s,
        params, new_state
    )
    
    return new_params, new_state

key = jax.random.PRNGKey(0)
params = {'w': jax.random.normal(key, (10, 5))}
grads = {'w': jnp.ones((10, 5))}
state = init_sgd_momentum(params)

new_params, new_state = sgd_momentum_step(params, grads, state)
print(f"更新后参数形状: {new_params['w'].shape}")

Adam #

python
import jax
import jax.numpy as jnp

def init_adam(params):
    return {
        'm': jax.tree_util.tree_map(jnp.zeros_like, params),  # first moment
        'v': jax.tree_util.tree_map(jnp.zeros_like, params),  # second moment
        't': 0                                                # step count
    }

@jax.jit
def adam_step(params, grads, state, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    t = state['t'] + 1
    
    # Exponential moving averages of the gradient and squared gradient
    new_m = jax.tree_util.tree_map(
        lambda m, g: beta1 * m + (1 - beta1) * g,
        state['m'], grads
    )
    
    new_v = jax.tree_util.tree_map(
        lambda v, g: beta2 * v + (1 - beta2) * g ** 2,
        state['v'], grads
    )
    
    # Bias correction for the zero initialization
    m_hat = jax.tree_util.tree_map(lambda m: m / (1 - beta1 ** t), new_m)
    v_hat = jax.tree_util.tree_map(lambda v: v / (1 - beta2 ** t), new_v)
    
    new_params = jax.tree_util.tree_map(
        lambda p, m, v: p - lr * m / (jnp.sqrt(v) + eps),
        params, m_hat, v_hat
    )
    
    new_state = {'m': new_m, 'v': new_v, 't': t}
    
    return new_params, new_state

key = jax.random.PRNGKey(0)
params = {'w': jax.random.normal(key, (10, 5))}
grads = {'w': jnp.ones((10, 5))}
state = init_adam(params)

new_params, new_state = adam_step(params, grads, state)
print(f"更新后参数形状: {new_params['w'].shape}")

Next Steps #

Now that you have a handle on state management, continue with Training Loops to learn how to build a complete training pipeline!

Last updated: 2026-04-04