模型保存与加载 #
概述 #
JAX 本身不内置专门的模型持久化接口,但可以借助 pickle、NumPy 以及 Orbax、Flax 等生态库来保存和加载模型参数(PyTree)。本节介绍如何使用序列化和检查点来持久化模型。
基本保存与加载 #
使用 pickle #
python
import pickle
import jax
import jax.numpy as jnp
def save_params_pickle(params, filepath):
    """Serialize a parameter PyTree to ``filepath`` with pickle.

    Every leaf is first converted to a NumPy array, so the pickle file holds
    plain ``numpy.ndarray`` objects rather than JAX arrays.

    Args:
        params: PyTree (e.g. a nested dict) whose leaves are array-like.
        filepath: Destination file; it is opened in ``'wb'`` mode, so an
            existing file is overwritten.
    """
    # ``jax.tree_map`` has been deprecated and removed from recent JAX
    # releases; ``jax.tree_util.tree_map`` works on old and new versions.
    params_np = jax.tree_util.tree_map(np.array, params)
    with open(filepath, 'wb') as f:
        pickle.dump(params_np, f)
def load_params_pickle(filepath):
    """Load a pickled parameter PyTree and convert its leaves to JAX arrays.

    Args:
        filepath: Path of a pickle file containing a PyTree of array-like
            leaves.

    Returns:
        The unpickled PyTree with every leaf passed through ``jnp.array``.
    """
    # SECURITY: unpickling can execute arbitrary code. Only load files that
    # come from a trusted source.
    with open(filepath, 'rb') as f:
        params_np = pickle.load(f)
    # ``jax.tree_map`` has been deprecated and removed from recent JAX
    # releases; ``jax.tree_util.tree_map`` works on old and new versions.
    return jax.tree_util.tree_map(jnp.array, params_np)
import numpy as np
# Demo: round-trip a small parameter dict through the pickle helpers.
params = dict(w=jnp.array([1.0, 2.0]), b=jnp.array([0.1]))
save_params_pickle(params, 'params.pkl')

loaded = load_params_pickle('params.pkl')
print(f"加载的参数: {loaded}")
使用 numpy #
python
import numpy as np
import jax
import jax.numpy as jnp
def save_params_npz(params, filepath):
    """Write the leaves of ``params`` to a NumPy ``.npz`` archive.

    Only the flattened leaves are passed to ``np.savez`` (as positional
    arrays, which NumPy names ``arr_0``, ``arr_1``, ...); the tree structure
    itself is not written to the file.

    Args:
        params: PyTree of array-like leaves.
        filepath: Destination path handed to ``np.savez``.
    """
    leaves, _ = jax.tree_util.tree_flatten(params)
    np.savez(filepath, *leaves)
def load_params_npz(filepath, params_template):
    """Rebuild a parameter PyTree from an ``.npz`` archive of its leaves.

    Args:
        filepath: Path of an archive whose entries are named ``arr_0``,
            ``arr_1``, ... (the names ``np.savez`` gives positional arrays).
        params_template: PyTree with the structure to restore; only its
            structure is used, not its leaf values.

    Returns:
        A PyTree shaped like ``params_template`` whose leaves are JAX arrays
        read from the archive in index order.
    """
    # ``np.load`` on an .npz returns an NpzFile that keeps the file open;
    # the context manager closes it once every array has been read.
    with np.load(filepath) as data:
        # Index by name rather than iterating ``data.files`` so the leaf
        # order is always arr_0, arr_1, ...
        flat_params = [jnp.array(data[f'arr_{i}']) for i in range(len(data.files))]
    structure = jax.tree_util.tree_structure(params_template)
    return jax.tree_util.tree_unflatten(structure, flat_params)
# Demo: round-trip a small parameter dict through the .npz helpers.
# The original dict doubles as the structure template when loading.
params = dict(w=jnp.array([1.0, 2.0]), b=jnp.array([0.1]))
save_params_npz(params, 'params.npz')

loaded = load_params_npz('params.npz', params)
print(f"加载的参数: {loaded}")
使用 JAX 检查点 #
orbax 检查点 #
python
import orbax.checkpoint as ocp
import jax
import jax.numpy as jnp
def save_checkpoint_orbax(params, state, step, checkpoint_dir):
    """Save ``params``, ``state`` and ``step`` as one Orbax PyTree checkpoint.

    Args:
        params: PyTree of model parameters.
        state: PyTree of additional training state.
        step: Step counter stored alongside the trees.
        checkpoint_dir: Target directory. A relative local path is made
            absolute; an existing checkpoint there is overwritten.
    """
    import os

    # Orbax requires absolute local paths. URL-style locations (e.g.
    # ``gs://...``) are passed through untouched.
    if '://' not in str(checkpoint_dir):
        checkpoint_dir = os.path.abspath(checkpoint_dir)

    checkpointer = ocp.PyTreeCheckpointer()
    item = {'params': params, 'state': state, 'step': step}
    # ``ocp.args.PyTreeSave`` is the documented argument wrapper for PyTree
    # checkpoints. ``force=True`` lets the example be re-run against the same
    # directory, matching ``overwrite=True`` in the Flax variant.
    checkpointer.save(checkpoint_dir, args=ocp.args.PyTreeSave(item), force=True)
def load_checkpoint_orbax(checkpoint_dir):
    """Restore the PyTree stored in an Orbax checkpoint directory.

    Args:
        checkpoint_dir: Directory previously written by an Orbax PyTree
            checkpointer. A relative local path is made absolute.

    Returns:
        Whatever ``PyTreeCheckpointer.restore`` returns for that directory.
    """
    import os

    # Orbax requires absolute local paths. URL-style locations (e.g.
    # ``gs://...``) are passed through untouched.
    if '://' not in str(checkpoint_dir):
        checkpoint_dir = os.path.abspath(checkpoint_dir)

    checkpointer = ocp.PyTreeCheckpointer()
    # Orbax checkpointers expose ``restore``; there is no ``load`` method.
    return checkpointer.restore(checkpoint_dir)
# Demo: write one Orbax checkpoint and read it back.
params = dict(w=jnp.array([1.0, 2.0]), b=jnp.array([0.1]))
state = dict(step=0)
save_checkpoint_orbax(params, state, 0, 'checkpoint')

loaded = load_checkpoint_orbax('checkpoint')
print(f"加载的检查点: {loaded}")
flax 检查点 #
python
from flax.training import checkpoints
import jax
import jax.numpy as jnp
def save_checkpoint_flax(params, state, step, checkpoint_dir):
    """Save ``params`` and ``state`` with ``flax.training.checkpoints``.

    Args:
        params: PyTree of model parameters.
        state: PyTree of additional training state.
        step: Step number passed to ``save_checkpoint``.
        checkpoint_dir: Directory for checkpoint files. A relative local
            path is made absolute.
    """
    import os

    # The Orbax-backed Flax checkpointing backend rejects relative local
    # paths. URL-style locations (e.g. ``gs://...``) are passed through.
    if '://' not in str(checkpoint_dir):
        checkpoint_dir = os.path.abspath(checkpoint_dir)

    checkpoints.save_checkpoint(
        checkpoint_dir,
        {'params': params, 'state': state},
        step,
        overwrite=True,
    )
def load_checkpoint_flax(checkpoint_dir, step=None):
    """Restore a checkpoint written by ``flax.training.checkpoints``.

    Args:
        checkpoint_dir: Directory holding the checkpoint files. A relative
            local path is made absolute.
        step: Step to restore; forwarded unchanged to
            ``restore_checkpoint`` (``None`` by default).

    Returns:
        The result of ``restore_checkpoint`` called with ``target=None``.
    """
    import os

    # The Orbax-backed Flax checkpointing backend rejects relative local
    # paths. URL-style locations (e.g. ``gs://...``) are passed through.
    if '://' not in str(checkpoint_dir):
        checkpoint_dir = os.path.abspath(checkpoint_dir)

    return checkpoints.restore_checkpoint(checkpoint_dir, None, step=step)
# Demo: write one Flax checkpoint at step 0 and read it back.
params = dict(w=jnp.array([1.0, 2.0]), b=jnp.array([0.1]))
state = dict(step=0)
save_checkpoint_flax(params, state, 0, 'flax_checkpoint')

loaded = load_checkpoint_flax('flax_checkpoint')
print(f"加载的检查点: {loaded}")
完整训练检查点 #
保存训练状态 #
python
import jax
import jax.numpy as jnp
import pickle
import os
class CheckpointManager:
    """Pickle-based checkpoint store that keeps one file per training step.

    Checkpoints are written to ``<checkpoint_dir>/checkpoint_<step>.pkl``.
    """

    _PREFIX = 'checkpoint_'
    _SUFFIX = '.pkl'

    def __init__(self, checkpoint_dir):
        """Remember ``checkpoint_dir`` and create it if it does not exist."""
        self.checkpoint_dir = checkpoint_dir
        os.makedirs(checkpoint_dir, exist_ok=True)

    @classmethod
    def _step_from_name(cls, filename):
        """Return the step encoded in ``filename``, or ``None`` if it is not a checkpoint file."""
        if not (filename.startswith(cls._PREFIX) and filename.endswith(cls._SUFFIX)):
            return None
        try:
            return int(filename[len(cls._PREFIX):-len(cls._SUFFIX)])
        except ValueError:
            # e.g. "checkpoint_best.pkl": not a step-numbered checkpoint.
            return None

    def save(self, state, step):
        """Write ``state`` to ``checkpoint_<step>.pkl``.

        Args:
            state: Mapping with ``'params'`` and ``'optimizer_state'``
                PyTrees. Array leaves are converted to NumPy before pickling.
            step: Step number, used for the filename and stored in the file.
        """
        filepath = os.path.join(self.checkpoint_dir, f'checkpoint_{step}.pkl')
        # ``jax.tree_map`` has been deprecated and removed from recent JAX
        # releases; ``jax.tree_util.tree_map`` works on old and new versions.
        save_state = {
            'params': jax.tree_util.tree_map(np.array, state['params']),
            'optimizer_state': jax.tree_util.tree_map(
                # Leaves without a shape (plain Python values) are kept as-is.
                lambda x: np.array(x) if hasattr(x, 'shape') else x,
                state['optimizer_state'],
            ),
            'step': step,
        }
        # Write to a temporary name and rename, so an interrupted save never
        # leaves a truncated file under the real checkpoint name.
        tmp_path = filepath + '.tmp'
        with open(tmp_path, 'wb') as f:
            pickle.dump(save_state, f)
        os.replace(tmp_path, filepath)
        print(f"检查点已保存: {filepath}")

    def load(self, step=None):
        """Load a checkpoint and return it with array leaves as JAX arrays.

        Args:
            step: Step to load. ``None`` selects the checkpoint file with the
                highest step number in the directory.

        Returns:
            A dict with ``'params'``, ``'optimizer_state'`` and ``'step'``,
            or ``None`` when no matching checkpoint file exists.
        """
        if step is None:
            # Files that merely share the prefix (leftover ".tmp" files,
            # "checkpoint_best.pkl", ...) are skipped instead of crashing
            # the integer parse.
            candidates = [
                (parsed, name)
                for name in os.listdir(self.checkpoint_dir)
                for parsed in [self._step_from_name(name)]
                if parsed is not None
            ]
            if not candidates:
                return None
            _, latest = max(candidates)
            filepath = os.path.join(self.checkpoint_dir, latest)
        else:
            filepath = os.path.join(self.checkpoint_dir, f'checkpoint_{step}.pkl')
        if not os.path.exists(filepath):
            return None
        # SECURITY: unpickling can execute arbitrary code. Only load
        # checkpoints from a trusted directory.
        with open(filepath, 'rb') as f:
            save_state = pickle.load(f)
        state = {
            'params': jax.tree_util.tree_map(jnp.array, save_state['params']),
            # Mirror the params handling: arrays that ``save`` converted to
            # NumPy come back as JAX arrays, other leaves are left untouched.
            'optimizer_state': jax.tree_util.tree_map(
                lambda x: jnp.array(x) if isinstance(x, np.ndarray) else x,
                save_state['optimizer_state'],
            ),
            'step': save_state['step'],
        }
        print(f"检查点已加载: {filepath}")
        return state
import numpy as np
# Demo: save one training state at step 100, then load the latest checkpoint.
manager = CheckpointManager('checkpoints')
state = dict(
    params={'w': jnp.array([1.0, 2.0])},
    optimizer_state={'m': jnp.array([0.0, 0.0])},
    step=100,
)
manager.save(state, 100)

loaded_state = manager.load()
print(f"加载的步骤: {loaded_state['step']}")
训练中断恢复 #
python
import jax
import jax.numpy as jnp
def train_with_checkpoint(params, train_loader, epochs, checkpoint_manager, resume=True):
    """Train for ``epochs`` epochs, optionally resuming from a checkpoint.

    Args:
        params: Initial parameters; replaced by the checkpointed ones when a
            checkpoint is restored.
        train_loader: Passed through to ``train_epoch`` (defined elsewhere).
        epochs: Exclusive upper bound of the epoch counter.
        checkpoint_manager: Object providing ``load()`` and
            ``save(state, step)``.
        resume: When true, try to continue after the latest checkpoint.

    Returns:
        The parameters after the last executed epoch.
    """
    start_epoch = 0
    restored = checkpoint_manager.load() if resume else None
    if restored is not None:
        params = restored['params']
        # The stored step is the last finished epoch; continue with the next.
        start_epoch = restored['step'] + 1
        print(f"从 epoch {start_epoch} 恢复训练")

    for epoch in range(start_epoch, epochs):
        params, loss = train_epoch(params, train_loader)

        # Checkpoint every tenth epoch (0, 10, 20, ...).
        if epoch % 10 == 0:
            snapshot = {'params': params, 'optimizer_state': {}, 'step': epoch}
            checkpoint_manager.save(snapshot, epoch)

        print(f"Epoch {epoch}: loss={loss:.4f}")

    return params
模型导出 #
导出为 ONNX(先导出 TensorFlow SavedModel,再用 tf2onnx 等工具转换) #
python
import jax
import jax.numpy as jnp
from jax.experimental import jax2tf
import tensorflow as tf
def export_to_onnx(params, forward_fn, input_shape, filepath):
    """Export ``forward_fn`` with ``params`` bound as a TensorFlow SavedModel.

    NOTE: despite its name, this function does not write an ``.onnx`` file.
    It produces the SavedModel that a separate converter (for example
    ``python -m tf2onnx.convert --saved-model <filepath> --output model.onnx``)
    can turn into ONNX.

    Args:
        params: Parameters closed over by the exported function.
        forward_fn: Callable ``forward_fn(params, x)`` written in JAX.
        input_shape: Shape of ``x``; ``None`` marks a dimension of unknown
            size (e.g. the batch dimension).
        filepath: Directory the SavedModel is written to.
    """
    # jax2tf cannot trace through unknown (None) dimensions unless they are
    # declared symbolically, e.g. (None, 10) -> "d0, 10".
    dims = [f'd{i}' if dim is None else str(dim) for i, dim in enumerate(input_shape)]
    polymorphic_shapes = [', '.join(dims)] if None in input_shape else None

    tf_fn = jax2tf.convert(
        lambda x: forward_fn(params, x),
        polymorphic_shapes=polymorphic_shapes,
    )

    # Documented jax2tf pattern: attach the converted function to a
    # tf.Module and save the module. The function is reachable as ``.f``
    # on the loaded model.
    module = tf.Module()
    module.f = tf.function(
        tf_fn,
        autograph=False,
        input_signature=[tf.TensorSpec(input_shape, tf.float32)],
    )
    tf.saved_model.save(module, filepath)
    print(f"模型已导出: {filepath}")
# Demo: a single affine layer exported with a variable-size leading dimension.
params = dict(w=jnp.ones((10, 5)), b=jnp.zeros(5))


def forward(params, x):
    """Return ``jnp.dot(x, params['w']) + params['b']``."""
    weights, bias = params['w'], params['b']
    return jnp.dot(x, weights) + bias


export_to_onnx(params, forward, (None, 10), 'saved_model')
导出为 TensorFlow SavedModel #
python
import jax
import jax.numpy as jnp
from jax.experimental import jax2tf
import tensorflow as tf
def export_to_savedmodel(params, forward_fn, input_shape, filepath):
    """Export ``forward_fn`` with ``params`` bound as a TensorFlow SavedModel.

    The function is converted with ``with_gradient=False``.

    Args:
        params: Parameters closed over by the exported function.
        forward_fn: Callable ``forward_fn(params, x)`` written in JAX.
        input_shape: Shape of ``x``; ``None`` marks a dimension of unknown
            size (e.g. the batch dimension).
        filepath: Directory the SavedModel is written to.
    """
    # jax2tf cannot trace through unknown (None) dimensions unless they are
    # declared symbolically, e.g. (None, 10) -> "d0, 10".
    dims = [f'd{i}' if dim is None else str(dim) for i, dim in enumerate(input_shape)]
    polymorphic_shapes = [', '.join(dims)] if None in input_shape else None

    tf_fn = jax2tf.convert(
        lambda x: forward_fn(params, x),
        polymorphic_shapes=polymorphic_shapes,
        with_gradient=False,
    )

    # Documented jax2tf pattern: attach the converted function to a
    # tf.Module and save the module. The function is reachable as ``.f``
    # on the loaded model.
    module = tf.Module()
    module.f = tf.function(
        tf_fn,
        autograph=False,
        input_signature=[tf.TensorSpec(input_shape, tf.float32)],
    )
    tf.saved_model.save(module, filepath)
    print(f"SavedModel 已导出: {filepath}")
最佳实践 #
1. 定期保存检查点 #
python
import os
def train_with_regular_checkpoints(params, train_loader, epochs, checkpoint_dir, save_every=5):
    """Train for ``epochs`` epochs, checkpointing every ``save_every`` epochs.

    A checkpoint of the final epoch is always written, unless that epoch was
    already saved by the periodic schedule or no epoch ran at all.

    Args:
        params: Initial parameters.
        train_loader: Passed through to ``train_epoch`` (defined elsewhere).
        epochs: Number of epochs to run.
        checkpoint_dir: Directory handed to ``CheckpointManager``.
        save_every: Save when ``epoch % save_every == 0``.

    Returns:
        The parameters after the last epoch.
    """
    manager = CheckpointManager(checkpoint_dir)
    last_saved = None

    for epoch in range(epochs):
        params, loss = train_epoch(params, train_loader)

        if epoch % save_every == 0:
            manager.save(
                {'params': params, 'optimizer_state': {}, 'step': epoch},
                epoch,
            )
            last_saved = epoch

        print(f"Epoch {epoch}: loss={loss:.4f}")

    final_epoch = epochs - 1
    # Skip the final save when nothing was trained (it would otherwise write
    # "checkpoint_-1.pkl") or when the loop already saved this epoch.
    if final_epoch >= 0 and last_saved != final_epoch:
        manager.save(
            {'params': params, 'optimizer_state': {}, 'step': final_epoch},
            final_epoch,
        )

    return params
2. 保存最佳模型 #
python
import os
def train_save_best(params, train_loader, val_loader, epochs, checkpoint_dir):
    """Train for ``epochs`` epochs and pickle the params with the lowest validation loss.

    Args:
        params: Initial parameters.
        train_loader: Passed through to ``train_epoch`` (defined elsewhere).
        val_loader: Passed through to ``evaluate`` (defined elsewhere).
        epochs: Number of epochs to run.
        checkpoint_dir: Directory that receives ``best_params.pkl``; created
            if missing.

    Returns:
        The parameters after the last epoch (not necessarily the best ones).
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    best_path = os.path.join(checkpoint_dir, 'best_params.pkl')
    best_val_loss = float('inf')

    for epoch in range(epochs):
        params, train_loss = train_epoch(params, train_loader)
        val_loss = evaluate(params, val_loader)

        improved = val_loss < best_val_loss
        if improved:
            best_val_loss = val_loss
            save_params_pickle(params, best_path)
            print(f"Epoch {epoch}: 保存最佳模型, val_loss={val_loss:.4f}")

        print(f"Epoch {epoch}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")

    return params
3. 版本控制 #
python
import os
import json
from datetime import datetime
def save_with_metadata(params, checkpoint_dir, metadata=None):
    """Pickle ``params`` under a timestamped name, with an optional JSON sidecar.

    Args:
        params: Parameters handed to ``save_params_pickle`` (defined
            elsewhere).
        checkpoint_dir: Output directory; created if missing.
        metadata: JSON-serializable mapping. When truthy it is written to
            ``metadata_<timestamp>.json`` next to the params file.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)

    # One timestamp (second resolution) is shared by both file names.
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

    checkpoint_path = os.path.join(checkpoint_dir, f'params_{timestamp}.pkl')
    save_params_pickle(params, checkpoint_path)

    if metadata:
        meta_path = os.path.join(checkpoint_dir, f'metadata_{timestamp}.json')
        with open(meta_path, 'w') as f:
            f.write(json.dumps(metadata, indent=2))

    print(f"检查点已保存: {checkpoint_path}")
# Demo: save parameters together with a small JSON description of the run.
params = dict(w=jnp.array([1.0, 2.0]))
metadata = dict(
    model='MLP',
    epochs=100,
    lr=0.001,
    train_loss=0.05,
    val_loss=0.08,
)
save_with_metadata(params, 'checkpoints', metadata)
下一步 #
现在你已经掌握了模型保存与加载,接下来学习 多设备计算,了解 JAX 的分布式计算能力!
最后更新:2026-04-04