Training Loops #

Overview #

This section shows how to build a complete training loop in JAX, covering data loading, loss computation, parameter updates, and model evaluation.

Basic Training Steps #

Defining the Model and Loss #

python
import jax
import jax.numpy as jnp
import jax.nn as nn

def init_params(key, input_size, hidden_size, output_size):
    keys = jax.random.split(key, 2)
    return {
        'w1': jax.random.normal(keys[0], (input_size, hidden_size)) * 0.01,
        'b1': jnp.zeros(hidden_size),
        'w2': jax.random.normal(keys[1], (hidden_size, output_size)) * 0.01,
        'b2': jnp.zeros(output_size)
    }

def forward(params, x):
    x = jnp.dot(x, params['w1']) + params['b1']
    x = nn.relu(x)
    x = jnp.dot(x, params['w2']) + params['b2']
    return x

def loss_fn(params, x, y):
    logits = forward(params, x)
    log_probs = nn.log_softmax(logits)
    one_hot = nn.one_hot(y, logits.shape[-1])
    return -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))
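
Before wiring these pieces into a loop, it is worth a quick smoke test of shapes and the initial loss on random data. A minimal sketch (the 784/128/10 sizes are just illustrative, MNIST-like dimensions):

python
key = jax.random.PRNGKey(0)
params = init_params(key, 784, 128, 10)

# A fake batch: 32 flattened images with integer class labels
x = jax.random.normal(jax.random.PRNGKey(1), (32, 784))
y = jax.random.randint(jax.random.PRNGKey(2), (32,), 0, 10)

print(forward(params, x).shape)  # (32, 10)
print(loss_fn(params, x, y))     # ≈ ln(10) ≈ 2.303, since tiny weights give near-uniform logits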

The Training Step #

python
import jax

@jax.jit
def train_step(params, x, y, lr=0.01):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree.map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss
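
As a sanity check that gradients flow, a few steps on a single synthetic batch (reusing params, x, y from the sketch above) should push the loss downward:

python
for step in range(5):
    params, loss = train_step(params, x, y)
    print(f"step {step}: loss={loss:.4f}")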

The Complete Training Loop #

python
import jax
import jax.numpy as jnp

def train_epoch(params, data_loader, lr=0.01):
    total_loss = 0
    num_batches = 0
    
    for x, y in data_loader:
        params, loss = train_step(params, x, y, lr)
        total_loss += loss
        num_batches += 1
    
    return params, total_loss / num_batches

def train(params, train_loader, val_loader, epochs=10, lr=0.01):
    for epoch in range(epochs):
        params, train_loss = train_epoch(params, train_loader, lr)
        val_loss = evaluate(params, val_loader)
        
        print(f"Epoch {epoch}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")
    
    return params

def evaluate(params, data_loader):
    total_loss = 0
    num_batches = 0
    
    for x, y in data_loader:
        loss = loss_fn(params, x, y)
        total_loss += loss
        num_batches += 1
    
    return total_loss / num_batches
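
These loops only assume that data_loader is an iterable of (x, y) batches; where the batches come from is up to you. One hypothetical way to provide that interface from in-memory NumPy arrays:

python
import numpy as np

def numpy_loader(x, y, batch_size=32, shuffle=True, seed=0):
    """Yield (x, y) mini-batches from in-memory arrays."""
    rng = np.random.default_rng(seed)
    idx = np.arange(len(x))
    if shuffle:
        rng.shuffle(idx)
    for start in range(0, len(x), batch_size):
        batch = idx[start:start + batch_size]
        yield x[batch], y[batch]

Note that a generator is exhausted after one pass, so for multi-epoch training either materialize it with list(...) or construct a fresh generator each epoch.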

Optimizers #

SGD #

python
import jax
import jax.numpy as jnp

def sgd_step(params, grads, lr=0.01):
    return jax.tree.map(lambda p, g: p - lr * g, params, grads)

SGD with Momentum #

python
import jax
import jax.numpy as jnp

def init_momentum_state(params):
    return jax.tree.map(jnp.zeros_like, params)

@jax.jit
def momentum_step(params, grads, state, lr=0.01, momentum=0.9):
    new_state = jax.tree.map(
        lambda s, g: momentum * s + g,
        state, grads
    )
    new_params = jax.tree.map(
        lambda p, s: p - lr * s,
        params, new_state
    )
    return new_params, new_state

Adam #

python
import jax
import jax.numpy as jnp

def init_adam_state(params):
    return {
        'm': jax.tree.map(jnp.zeros_like, params),
        'v': jax.tree.map(jnp.zeros_like, params),
        't': 0
    }

@jax.jit
def adam_step(params, grads, state, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    t = state['t'] + 1
    
    new_m = jax.tree.map(
        lambda m, g: beta1 * m + (1 - beta1) * g,
        state['m'], grads
    )
    new_v = jax.tree.map(
        lambda v, g: beta2 * v + (1 - beta2) * g ** 2,
        state['v'], grads
    )
    
    m_hat = jax.tree.map(lambda m: m / (1 - beta1 ** t), new_m)
    v_hat = jax.tree.map(lambda v: v / (1 - beta2 ** t), new_v)
    
    new_params = jax.tree.map(
        lambda p, m, v: p - lr * m / (jnp.sqrt(v) + eps),
        params, m_hat, v_hat
    )
    
    return new_params, {'m': new_m, 'v': new_v, 't': t}
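
Unlike plain SGD, Adam is stateful, so the caller must thread (params, state) through every step. A quick check on a toy quadratic loss ||w||², where the first bias-corrected step moves each coordinate by roughly lr against the gradient's sign:

python
params = {'w': jnp.array([1.0, -2.0])}
state = init_adam_state(params)
grads = {'w': 2 * params['w']}  # gradient of ||w||²

params, state = adam_step(params, grads, state)
print(params['w'])  # ≈ [0.999, -1.999]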

Training with an Optimizer #

python
import jax

@jax.jit
def train_step_with_adam(params, state, x, y, lr=0.001):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params, new_state = adam_step(params, grads, state, lr)
    return new_params, new_state, loss

def train_with_adam(params, train_loader, epochs=10, lr=0.001):
    state = init_adam_state(params)
    
    for epoch in range(epochs):
        total_loss = 0
        num_batches = 0
        
        for x, y in train_loader:
            params, state, loss = train_step_with_adam(params, state, x, y, lr)
            total_loss += loss
            num_batches += 1
        
        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch}: loss={avg_loss:.4f}")
    
    return params

Learning Rate Schedules #

Step Decay #

python
def step_decay(epoch, initial_lr, decay_rate=0.1, decay_epochs=10):
    return initial_lr * (decay_rate ** (epoch // decay_epochs))

def train_with_lr_schedule(params, train_loader, epochs=30, initial_lr=0.01):
    for epoch in range(epochs):
        lr = step_decay(epoch, initial_lr)
        params, loss = train_epoch(params, train_loader, lr)
        print(f"Epoch {epoch}: lr={lr:.6f}, loss={loss:.4f}")
    
    return params

Exponential Decay #

python
import jax.numpy as jnp

def exponential_decay(epoch, initial_lr, decay_rate=0.95):
    return initial_lr * (decay_rate ** epoch)

def train_with_exp_decay(params, train_loader, epochs=30, initial_lr=0.01):
    for epoch in range(epochs):
        lr = exponential_decay(epoch, initial_lr)
        params, loss = train_epoch(params, train_loader, lr)
        print(f"Epoch {epoch}: lr={lr:.6f}, loss={loss:.4f}")
    
    return params

Cosine Annealing #

python
import jax.numpy as jnp

def cosine_annealing(epoch, total_epochs, initial_lr, min_lr=0):
    return min_lr + 0.5 * (initial_lr - min_lr) * (1 + jnp.cos(jnp.pi * epoch / total_epochs))

def train_with_cosine_annealing(params, train_loader, epochs=30, initial_lr=0.01):
    for epoch in range(epochs):
        lr = cosine_annealing(epoch, epochs, initial_lr)
        params, loss = train_epoch(params, train_loader, lr)
        print(f"Epoch {epoch}: lr={lr:.6f}, loss={loss:.4f}")
    
    return params
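
The three schedules differ mainly in how aggressively they decay. Printing a few values side by side makes the curves concrete (initial_lr=0.01 over 30 epochs):

python
for epoch in [0, 10, 20, 29]:
    print(f"epoch {epoch:2d}: "
          f"step={step_decay(epoch, 0.01):.5f}  "
          f"exp={exponential_decay(epoch, 0.01):.5f}  "
          f"cos={cosine_annealing(epoch, 30, 0.01):.5f}")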

Gradient Clipping #

Global Norm Clipping #

python
import jax
import jax.numpy as jnp

def clip_grads_global_norm(grads, max_norm=1.0):
    leaves = jax.tree_util.tree_leaves(grads)
    norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
    
    scale = jnp.minimum(1.0, max_norm / (norm + 1e-6))
    
    return jax.tree.map(lambda g: g * scale, grads)

@jax.jit
def train_step_with_clip(params, x, y, lr=0.01, max_norm=1.0):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    clipped_grads = clip_grads_global_norm(grads, max_norm)
    new_params = jax.tree.map(lambda p, g: p - lr * g, params, clipped_grads)
    return new_params, loss

Clipping by Value #

python
import jax
import jax.numpy as jnp

def clip_grads_by_value(grads, min_val=-1.0, max_val=1.0):
    return jax.tree.map(lambda g: jnp.clip(g, min_val, max_val), grads)
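
The two strategies behave differently: global-norm clipping rescales the entire gradient pytree uniformly, preserving its direction, while value clipping modifies each coordinate independently. A small comparison:

python
grads = {'w': jnp.array([3.0, 4.0])}  # global norm = 5

print(clip_grads_global_norm(grads, max_norm=1.0))  # ≈ [0.6, 0.8], direction preserved
print(clip_grads_by_value(grads))                   # [1.0, 1.0], direction changed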

A Complete Training Example #

python
import jax
import jax.numpy as jnp
import jax.nn as nn
from functools import partial

class Trainer:
    def __init__(self, model_fn, loss_fn, optimizer='adam', lr=0.001):
        self.model_fn = model_fn
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.lr = lr
        
        if optimizer == 'adam':
            self.init_optimizer_state = init_adam_state
            self.optimizer_step = adam_step
        elif optimizer == 'sgd':
            self.init_optimizer_state = lambda p: {}
            self.optimizer_step = lambda p, g, s, **kw: (sgd_step(p, g, kw.get('lr', lr)), s)
        else:
            raise ValueError(f"Unknown optimizer: {optimizer}")
    
    @partial(jax.jit, static_argnums=(0,))
    def train_step(self, params, state, x, y):
        def loss_wrapper(p):
            return self.loss_fn(p, x, y)
        
        loss, grads = jax.value_and_grad(loss_wrapper)(params)
        grads = clip_grads_global_norm(grads, max_norm=1.0)
        
        new_params, new_state = self.optimizer_step(params, grads, state, lr=self.lr)
        return new_params, new_state, loss
    
    def train(self, params, train_loader, val_loader, epochs=10):
        state = self.init_optimizer_state(params)
        
        for epoch in range(epochs):
            total_loss = 0
            num_batches = 0
            
            for x, y in train_loader:
                params, state, loss = self.train_step(params, state, x, y)
                total_loss += loss
                num_batches += 1
            
            train_loss = total_loss / num_batches
            val_loss = self.evaluate(params, val_loader)
            
            print(f"Epoch {epoch}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")
        
        return params
    
    def evaluate(self, params, data_loader):
        total_loss = 0
        num_batches = 0
        
        for x, y in data_loader:
            loss = self.loss_fn(params, x, y)
            total_loss += loss
            num_batches += 1
        
        return total_loss / num_batches
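
Putting the Trainer to work might look like this, where train_x, train_y, val_x, val_y stand in for whatever arrays you have loaded, and the batches are materialized as lists so they can be re-iterated every epoch (see the numpy_loader sketch above):

python
key = jax.random.PRNGKey(0)
params = init_params(key, 784, 128, 10)

# Hypothetical data; any re-iterable sequence of (x, y) batches works
train_batches = list(numpy_loader(train_x, train_y, batch_size=64))
val_batches = list(numpy_loader(val_x, val_y, batch_size=64, shuffle=False))

trainer = Trainer(forward, loss_fn, optimizer='adam', lr=0.001)
params = trainer.train(params, train_batches, val_batches, epochs=10)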

Metric Tracking #

python
import jax.numpy as jnp

def accuracy(params, x, y):
    logits = forward(params, x)
    predictions = jnp.argmax(logits, axis=-1)
    return jnp.mean(predictions == y)

def train_with_metrics(params, train_loader, val_loader, epochs=10, lr=0.01):
    for epoch in range(epochs):
        total_loss = 0
        total_acc = 0
        num_batches = 0
        
        for x, y in train_loader:
            params, loss = train_step(params, x, y, lr)
            acc = accuracy(params, x, y)
            
            total_loss += loss
            total_acc += acc
            num_batches += 1
        
        train_loss = total_loss / num_batches
        train_acc = total_acc / num_batches
        
        val_loss = evaluate(params, val_loader)
        # Average accuracy over validation batches (data_loader yields (x, y) pairs)
        val_accs = [accuracy(params, x, y) for x, y in val_loader]
        val_acc = sum(val_accs) / len(val_accs)
        
        print(f"Epoch {epoch}:")
        print(f"  Train - loss: {train_loss:.4f}, acc: {train_acc:.4f}")
        print(f"  Val   - loss: {val_loss:.4f}, acc: {val_acc:.4f}")
    
    return params
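
If evaluation cost matters, loss and accuracy can be fused into a single jitted call so each batch runs only one forward pass. A sketch of such a helper (the eval_step name is mine, not part of the code above):

python
@jax.jit
def eval_step(params, x, y):
    logits = forward(params, x)
    log_probs = nn.log_softmax(logits)
    one_hot = nn.one_hot(y, logits.shape[-1])
    loss = -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))
    acc = jnp.mean(jnp.argmax(logits, axis=-1) == y)
    return loss, acc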

Next Steps #

Now that you have training loops down, continue to Saving and Loading Models to learn how to persist your model!

Last updated: 2026-04-04