Training Loop #
Overview #
This section shows how to build a complete training loop in JAX, covering data loading, loss computation, parameter updates, and model evaluation.
Basic Training Step #
Defining the Model and Loss #
python
import jax
import jax.numpy as jnp
import jax.nn as nn

def init_params(key, input_size, hidden_size, output_size):
    k1, k2 = jax.random.split(key)  # one key per weight matrix
    return {
        'w1': jax.random.normal(k1, (input_size, hidden_size)) * 0.01,
        'b1': jnp.zeros(hidden_size),
        'w2': jax.random.normal(k2, (hidden_size, output_size)) * 0.01,
        'b2': jnp.zeros(output_size),
    }

def forward(params, x):
    # Two-layer MLP: linear -> ReLU -> linear
    x = jnp.dot(x, params['w1']) + params['b1']
    x = nn.relu(x)
    x = jnp.dot(x, params['w2']) + params['b2']
    return x

def loss_fn(params, x, y):
    # Cross-entropy over integer class labels
    logits = forward(params, x)
    log_probs = nn.log_softmax(logits)
    one_hot = nn.one_hot(y, logits.shape[-1])
    return -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))
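A quick smoke test on random data (the shapes here are illustrative assumptions, not requirements):

python
key = jax.random.PRNGKey(0)
params = init_params(key, input_size=784, hidden_size=128, output_size=10)

# A fake batch of 32 examples, just to check shapes and the initial loss
x = jax.random.normal(jax.random.PRNGKey(1), (32, 784))
y = jax.random.randint(jax.random.PRNGKey(2), (32,), 0, 10)
print(loss_fn(params, x, y))  # roughly ln(10) ≈ 2.30 at initialization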
Training Step #
python
import jax

@jax.jit
def train_step(params, x, y, lr=0.01):
    # Compute the loss and its gradients in one pass
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Vanilla gradient descent, applied leaf-wise over the parameter pytree
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss
Full Training Loop #
python
import jax
import jax.numpy as jnp

def train_epoch(params, data_loader, lr=0.01):
    total_loss = 0.0
    num_batches = 0
    for x, y in data_loader:
        params, loss = train_step(params, x, y, lr)
        total_loss += loss
        num_batches += 1
    return params, total_loss / num_batches

def evaluate(params, data_loader):
    total_loss = 0.0
    num_batches = 0
    for x, y in data_loader:
        loss = loss_fn(params, x, y)
        total_loss += loss
        num_batches += 1
    return total_loss / num_batches

def train(params, train_loader, val_loader, epochs=10, lr=0.01):
    for epoch in range(epochs):
        params, train_loss = train_epoch(params, train_loader, lr)
        val_loss = evaluate(params, val_loader)
        print(f"Epoch {epoch}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")
    return params
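The loaders above are assumed to yield `(x, y)` batches. A minimal sketch of such a loader over in-memory arrays (`ArrayLoader` is a helper introduced here, not part of JAX); note it must be re-iterable rather than a one-shot generator, because `train` iterates it once per epoch:

python
import numpy as np

class ArrayLoader:
    """Minimal re-iterable mini-batch loader over in-memory arrays."""
    def __init__(self, x, y, batch_size=32, shuffle=True, seed=0):
        self.x, self.y = x, y
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = np.random.default_rng(seed)

    def __iter__(self):
        idx = np.arange(len(self.x))
        if self.shuffle:
            self.rng.shuffle(idx)  # fresh order each epoch
        for start in range(0, len(self.x), self.batch_size):
            batch = idx[start:start + self.batch_size]
            yield self.x[batch], self.y[batch]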
Optimizers #
SGD #
python
import jax

def sgd_step(params, grads, lr=0.01):
    # One gradient-descent update, applied leaf-wise over the pytree
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
SGD with Momentum #
python
import jax
import jax.numpy as jnp

def init_momentum_state(params):
    # One velocity buffer per parameter leaf
    return jax.tree_util.tree_map(jnp.zeros_like, params)

@jax.jit
def momentum_step(params, grads, state, lr=0.01, momentum=0.9):
    # v <- momentum * v + g
    new_state = jax.tree_util.tree_map(
        lambda s, g: momentum * s + g,
        state, grads
    )
    # p <- p - lr * v
    new_params = jax.tree_util.tree_map(
        lambda p, s: p - lr * s,
        params, new_state
    )
    return new_params, new_state
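Integrating this into a training step looks just like the Adam version shown later; `train_step_with_momentum` is a name introduced here for illustration:

python
import jax

@jax.jit
def train_step_with_momentum(params, state, x, y, lr=0.01):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params, new_state = momentum_step(params, grads, state, lr)
    return new_params, new_state, loss

# state = init_momentum_state(params)
# params, state, loss = train_step_with_momentum(params, state, x, y)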
Adam #
python
import jax
import jax.numpy as jnp

def init_adam_state(params):
    return {
        'm': jax.tree_util.tree_map(jnp.zeros_like, params),  # first-moment estimates
        'v': jax.tree_util.tree_map(jnp.zeros_like, params),  # second-moment estimates
        't': 0                                                 # step counter
    }

@jax.jit
def adam_step(params, grads, state, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    t = state['t'] + 1
    new_m = jax.tree_util.tree_map(
        lambda m, g: beta1 * m + (1 - beta1) * g,
        state['m'], grads
    )
    new_v = jax.tree_util.tree_map(
        lambda v, g: beta2 * v + (1 - beta2) * g ** 2,
        state['v'], grads
    )
    # Bias correction for the zero-initialized moment estimates
    m_hat = jax.tree_util.tree_map(lambda m: m / (1 - beta1 ** t), new_m)
    v_hat = jax.tree_util.tree_map(lambda v: v / (1 - beta2 ** t), new_v)
    new_params = jax.tree_util.tree_map(
        lambda p, m, v: p - lr * m / (jnp.sqrt(v) + eps),
        params, m_hat, v_hat
    )
    return new_params, {'m': new_m, 'v': new_v, 't': t}
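For reference, these are the standard Adam updates the code implements:

latex
m_t = \beta_1 m_{t-1} + (1-\beta_1)\,g_t, \qquad
v_t = \beta_2 v_{t-1} + (1-\beta_2)\,g_t^2
\hat{m}_t = \frac{m_t}{1-\beta_1^t}, \qquad
\hat{v}_t = \frac{v_t}{1-\beta_2^t}, \qquad
\theta_t = \theta_{t-1} - \frac{\mathrm{lr}\,\hat{m}_t}{\sqrt{\hat{v}_t}+\epsilon}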
Training with an Optimizer #
python
import jax

@jax.jit
def train_step_with_adam(params, state, x, y, lr=0.001):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params, new_state = adam_step(params, grads, state, lr)
    return new_params, new_state, loss

def train_with_adam(params, train_loader, epochs=10, lr=0.001):
    state = init_adam_state(params)
    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        for x, y in train_loader:
            params, state, loss = train_step_with_adam(params, state, x, y, lr)
            total_loss += loss
            num_batches += 1
        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch}: loss={avg_loss:.4f}")
    return params
Learning Rate Schedules #
Step Decay #
python
def step_decay(epoch, initial_lr, decay_rate=0.1, decay_epochs=10):
    # Multiply the learning rate by decay_rate every decay_epochs epochs
    return initial_lr * (decay_rate ** (epoch // decay_epochs))

def train_with_lr_schedule(params, train_loader, epochs=30, initial_lr=0.01):
    for epoch in range(epochs):
        lr = step_decay(epoch, initial_lr)
        params, loss = train_epoch(params, train_loader, lr)
        print(f"Epoch {epoch}: lr={lr:.6f}, loss={loss:.4f}")
    return params
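With the defaults, an initial learning rate of 0.01 stays at 0.01 for epochs 0–9, drops to 0.001 for epochs 10–19, and to 0.0001 for epochs 20–29.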
Exponential Decay #
python
def exponential_decay(epoch, initial_lr, decay_rate=0.95):
    # Multiply the learning rate by decay_rate once per epoch
    return initial_lr * (decay_rate ** epoch)

def train_with_exp_decay(params, train_loader, epochs=30, initial_lr=0.01):
    for epoch in range(epochs):
        lr = exponential_decay(epoch, initial_lr)
        params, loss = train_epoch(params, train_loader, lr)
        print(f"Epoch {epoch}: lr={lr:.6f}, loss={loss:.4f}")
    return params
Cosine Annealing #
python
import jax.numpy as jnp

def cosine_annealing(epoch, total_epochs, initial_lr, min_lr=0.0):
    # Anneal from initial_lr down to min_lr along a half cosine wave
    return min_lr + 0.5 * (initial_lr - min_lr) * (1 + jnp.cos(jnp.pi * epoch / total_epochs))

def train_with_cosine_annealing(params, train_loader, epochs=30, initial_lr=0.01):
    for epoch in range(epochs):
        lr = cosine_annealing(epoch, epochs, initial_lr)
        params, loss = train_epoch(params, train_loader, lr)
        print(f"Epoch {epoch}: lr={lr:.6f}, loss={loss:.4f}")
    return params
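The schedule implemented above is the standard cosine form, with t the current epoch and T the total number of epochs:

latex
\eta_t = \eta_{\min} + \tfrac{1}{2}\,(\eta_0 - \eta_{\min})\left(1 + \cos\frac{\pi t}{T}\right)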
Gradient Clipping #
Global Norm Clipping #
python
import jax
import jax.numpy as jnp

def clip_grads_global_norm(grads, max_norm=1.0):
    # Scale all gradients together so their global L2 norm is at most max_norm
    leaves = jax.tree_util.tree_leaves(grads)
    norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
    scale = jnp.minimum(1.0, max_norm / (norm + 1e-6))
    return jax.tree_util.tree_map(lambda g: g * scale, grads)

@jax.jit
def train_step_with_clip(params, x, y, lr=0.01, max_norm=1.0):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    clipped_grads = clip_grads_global_norm(grads, max_norm)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, clipped_grads)
    return new_params, loss
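In symbols, with the global norm computed over all leaves, c the clipping threshold, and ε = 1e-6 guarding against division by zero:

latex
g \leftarrow g \cdot \min\!\left(1,\ \frac{c}{\lVert g\rVert_2 + \varepsilon}\right)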
Clipping by Value #
python
import jax
import jax.numpy as jnp

def clip_grads_by_value(grads, min_val=-1.0, max_val=1.0):
    # Clamp each gradient element independently to [min_val, max_val]
    return jax.tree_util.tree_map(lambda g: jnp.clip(g, min_val, max_val), grads)
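Note that per-element value clipping can change the direction of the overall gradient, whereas global-norm clipping only rescales it; global-norm clipping is therefore the more common default.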
Complete Training Example #
python
import jax
from functools import partial

class Trainer:
    def __init__(self, model_fn, loss_fn, optimizer='adam', lr=0.001):
        self.model_fn = model_fn
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.lr = lr
        if optimizer == 'adam':
            self.init_optimizer_state = init_adam_state
            self.optimizer_step = adam_step
        elif optimizer == 'sgd':
            self.init_optimizer_state = lambda p: {}
            self.optimizer_step = lambda p, g, s, **kw: (sgd_step(p, g, kw.get('lr', lr)), s)
        else:
            raise ValueError(f"Unknown optimizer: {optimizer}")

    # self is marked static, so jit caches the compiled step per Trainer instance
    @partial(jax.jit, static_argnums=(0,))
    def train_step(self, params, state, x, y):
        def loss_wrapper(p):
            return self.loss_fn(p, x, y)
        loss, grads = jax.value_and_grad(loss_wrapper)(params)
        grads = clip_grads_global_norm(grads, max_norm=1.0)
        new_params, new_state = self.optimizer_step(params, grads, state, lr=self.lr)
        return new_params, new_state, loss

    def train(self, params, train_loader, val_loader, epochs=10):
        state = self.init_optimizer_state(params)
        for epoch in range(epochs):
            total_loss = 0.0
            num_batches = 0
            for x, y in train_loader:
                params, state, loss = self.train_step(params, state, x, y)
                total_loss += loss
                num_batches += 1
            train_loss = total_loss / num_batches
            val_loss = self.evaluate(params, val_loader)
            print(f"Epoch {epoch}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")
        return params

    def evaluate(self, params, data_loader):
        total_loss = 0.0
        num_batches = 0
        for x, y in data_loader:
            loss = self.loss_fn(params, x, y)
            total_loss += loss
            num_batches += 1
        return total_loss / num_batches
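Putting it together, a sketch that reuses `forward`, `loss_fn`, and the `ArrayLoader` helper introduced above; the synthetic data is purely for demonstration:

python
import numpy as np
import jax

rng = np.random.default_rng(0)
train_x = rng.normal(size=(512, 784)).astype('float32')
train_y = rng.integers(0, 10, size=512)
val_x = rng.normal(size=(128, 784)).astype('float32')
val_y = rng.integers(0, 10, size=128)

params = init_params(jax.random.PRNGKey(0), 784, 128, 10)
trainer = Trainer(model_fn=forward, loss_fn=loss_fn, optimizer='adam', lr=1e-3)
params = trainer.train(params,
                       ArrayLoader(train_x, train_y),
                       ArrayLoader(val_x, val_y, shuffle=False),
                       epochs=3)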
Tracking Metrics #
python
import jax.numpy as jnp

def accuracy(params, x, y):
    logits = forward(params, x)
    predictions = jnp.argmax(logits, axis=-1)
    return jnp.mean(predictions == y)

def evaluate_with_metrics(params, data_loader):
    # Average loss and accuracy over the loader's batches
    total_loss, total_acc, num_batches = 0.0, 0.0, 0
    for x, y in data_loader:
        total_loss += loss_fn(params, x, y)
        total_acc += accuracy(params, x, y)
        num_batches += 1
    return total_loss / num_batches, total_acc / num_batches

def train_with_metrics(params, train_loader, val_loader, epochs=10, lr=0.01):
    for epoch in range(epochs):
        total_loss, total_acc, num_batches = 0.0, 0.0, 0
        for x, y in train_loader:
            params, loss = train_step(params, x, y, lr)
            acc = accuracy(params, x, y)
            total_loss += loss
            total_acc += acc
            num_batches += 1
        train_loss = total_loss / num_batches
        train_acc = total_acc / num_batches
        val_loss, val_acc = evaluate_with_metrics(params, val_loader)
        print(f"Epoch {epoch}:")
        print(f"  Train - loss: {train_loss:.4f}, acc: {train_acc:.4f}")
        print(f"  Val   - loss: {val_loss:.4f}, acc: {val_acc:.4f}")
    return params
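Per-batch metric computation can also be compiled. A sketch that computes loss and accuracy from a single forward pass (`eval_step` is a name introduced here):

python
import jax
import jax.numpy as jnp
import jax.nn as nn

@jax.jit
def eval_step(params, x, y):
    # One forward pass shared by both the loss and the accuracy
    logits = forward(params, x)
    log_probs = nn.log_softmax(logits)
    loss = -jnp.mean(jnp.sum(nn.one_hot(y, logits.shape[-1]) * log_probs, axis=-1))
    acc = jnp.mean(jnp.argmax(logits, axis=-1) == y)
    return loss, acc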
Next Steps #
Now that you have training loops down, continue with Saving and Loading Models to learn how to persist your model!
Last updated: 2026-04-04