模型保存与加载 #

概述 #

JAX 本身并不提供专门的模型序列化格式:模型参数通常以 pytree(嵌套的 dict/list/tuple 等)的形式存在,可以借助 pickle、NumPy、Orbax、Flax 等工具来保存和加载。本节介绍如何使用序列化和检查点来持久化模型。

基本保存与加载 #

使用 pickle #

python
import pickle

import jax
import jax.numpy as jnp
import numpy as np

def save_params_pickle(params, filepath):
    """Serialize a parameter pytree to ``filepath`` with pickle.

    Every leaf is converted to a host-side NumPy array first, so the pickle
    does not contain JAX device buffers.

    Args:
        params: A pytree (nested dict/list/tuple) whose leaves are arrays.
        filepath: Destination file; an existing file is overwritten.
    """
    # ``jax.tree_map`` was deprecated and then removed from JAX; the
    # ``jax.tree_util.tree_map`` spelling works on old and new releases.
    params_np = jax.tree_util.tree_map(np.asarray, params)
    with open(filepath, 'wb') as f:
        pickle.dump(params_np, f)

def load_params_pickle(filepath):
    """Load a parameter pytree written by ``save_params_pickle``.

    Security note: ``pickle.load`` can execute arbitrary code while
    unpickling — only load files that come from a trusted source.

    Args:
        filepath: Path of the pickle file.

    Returns:
        The stored pytree with every leaf converted to a ``jnp`` array.
    """
    with open(filepath, 'rb') as f:
        params_np = pickle.load(f)
    # ``jax.tree_map`` no longer exists in current JAX; use ``jax.tree_util``.
    return jax.tree_util.tree_map(jnp.asarray, params_np)

import numpy as np

# Round-trip a small parameter dict through the pickle helpers.
params = {}
params['w'] = jnp.array([1.0, 2.0])
params['b'] = jnp.array([0.1])
save_params_pickle(params, filepath='params.pkl')
loaded = load_params_pickle(filepath='params.pkl')
print(f"加载的参数: {loaded}")

使用 numpy #

python
import numpy as np
import jax
import jax.numpy as jnp

def save_params_npz(params, filepath):
    """Write the leaves of a parameter pytree into a NumPy ``.npz`` archive.

    Leaves are stored positionally as ``arr_0``, ``arr_1``, ... in
    ``jax.tree_util`` flattening order. The tree structure itself is not
    stored, so loading requires a template pytree with the same structure.
    """
    leaves, _treedef = jax.tree_util.tree_flatten(params)
    np.savez(filepath, *leaves)

def load_params_npz(filepath, params_template):
    """Rebuild a parameter pytree from an archive written by ``save_params_npz``.

    Args:
        filepath: Path to the ``.npz`` archive.
        params_template: Pytree with the same structure as the saved one;
            only its structure is used, its leaf values are ignored.

    Returns:
        A pytree shaped like ``params_template`` whose leaves are ``jnp`` arrays.

    Raises:
        ValueError: If the archive's array count differs from the number of
            leaves in ``params_template``.
    """
    structure = jax.tree_util.tree_structure(params_template)
    # ``np.load`` on an .npz returns an NpzFile that keeps the file open;
    # the ``with`` block closes it once all arrays have been copied out.
    with np.load(filepath) as data:
        num_arrays = len(data.files)
        if num_arrays != structure.num_leaves:
            raise ValueError(
                f"{filepath} contains {num_arrays} arrays, but the template "
                f"has {structure.num_leaves} leaves"
            )
        # Index explicitly by name so leaves line up with flattening order
        # regardless of how ``data.files`` happens to be ordered.
        flat_params = [jnp.array(data[f'arr_{i}']) for i in range(num_arrays)]
    return jax.tree_util.tree_unflatten(structure, flat_params)

# Round-trip through the .npz helpers; ``params`` is also passed as the
# structure template for loading.
params = {'w': jnp.array([1.0, 2.0]), 'b': jnp.array([0.1])}
save_params_npz(params, filepath='params.npz')
loaded = load_params_npz('params.npz', params_template=params)
print(f"加载的参数: {loaded}")

使用 JAX 检查点 #

orbax 检查点 #

python
import orbax.checkpoint as ocp
import jax
import jax.numpy as jnp

def save_checkpoint_orbax(params, state, step, checkpoint_dir, force=True):
    """Save params, auxiliary state and the step counter as one Orbax pytree.

    Args:
        params: Parameter pytree.
        state: Additional training-state pytree.
        step: Step value stored under the ``'step'`` key.
        checkpoint_dir: Target directory; converted to an absolute path.
        force: Overwrite an existing checkpoint at ``checkpoint_dir`` so the
            example can be re-run (mirrors ``overwrite=True`` in the Flax
            example).
    """
    import os

    checkpointer = ocp.PyTreeCheckpointer()
    item = {'params': params, 'state': state, 'step': step}
    # ``ocp.save_args.PyTreeSaveArgs`` is not an Orbax API; the tree is
    # wrapped with ``ocp.args.PyTreeSave`` and passed via ``args=``.
    checkpointer.save(
        os.path.abspath(checkpoint_dir),
        args=ocp.args.PyTreeSave(item),
        force=force,
    )

def load_checkpoint_orbax(checkpoint_dir):
    """Restore the pytree written by ``save_checkpoint_orbax``.

    Args:
        checkpoint_dir: Checkpoint directory; converted to an absolute path.

    Returns:
        The restored pytree (keys ``'params'``, ``'state'``, ``'step'``).
        No target is given, so leaf types are chosen by Orbax — TODO confirm
        whether arrays come back as NumPy or JAX arrays for your version.
    """
    import os

    checkpointer = ocp.PyTreeCheckpointer()
    # Orbax checkpointers expose ``restore``; there is no ``load`` method.
    return checkpointer.restore(os.path.abspath(checkpoint_dir))

# Save a toy checkpoint with Orbax and read it back.
params = {'w': jnp.array([1.0, 2.0]), 'b': jnp.array([0.1])}
state = dict(step=0)
save_checkpoint_orbax(params, state, 0, checkpoint_dir='checkpoint')
loaded = load_checkpoint_orbax(checkpoint_dir='checkpoint')
print(f"加载的检查点: {loaded}")

Flax 检查点(旧版 API,Flax 官方推荐迁移到 Orbax) #

python
from flax.training import checkpoints
import jax
import jax.numpy as jnp

def save_checkpoint_flax(params, state, step, checkpoint_dir):
    """Write ``params`` and ``state`` as a Flax checkpoint for ``step``.

    ``overwrite=True`` is passed so an existing checkpoint at the same or a
    later step does not block the save.
    """
    target = {'params': params, 'state': state}
    checkpoints.save_checkpoint(checkpoint_dir, target, step, overwrite=True)

def load_checkpoint_flax(checkpoint_dir, step=None):
    """Restore a checkpoint written by ``save_checkpoint_flax``.

    No target is supplied, so the raw restored state is returned;
    ``step=None`` lets Flax pick the latest checkpoint.
    """
    restored = checkpoints.restore_checkpoint(checkpoint_dir, target=None, step=step)
    return restored

# Save a toy checkpoint with Flax and read the latest one back.
params = {'w': jnp.array([1.0, 2.0]), 'b': jnp.array([0.1])}
state = dict(step=0)
save_checkpoint_flax(params, state, 0, checkpoint_dir='flax_checkpoint')
loaded = load_checkpoint_flax(checkpoint_dir='flax_checkpoint')
print(f"加载的检查点: {loaded}")

完整训练检查点 #

保存训练状态 #

python
import jax
import jax.numpy as jnp
import pickle
import os

class CheckpointManager:
    """Store training checkpoints as ``checkpoint_<step>.pkl`` pickle files.

    Args:
        checkpoint_dir: Directory holding the checkpoint files; created if it
            does not exist.

    Security note: checkpoints are read back with ``pickle.load``, which can
    execute arbitrary code — only use directories whose contents you trust.
    """

    _PREFIX = 'checkpoint_'
    _SUFFIX = '.pkl'

    def __init__(self, checkpoint_dir):
        self.checkpoint_dir = checkpoint_dir
        os.makedirs(checkpoint_dir, exist_ok=True)

    def _path_for(self, step):
        """Return the file path used for the checkpoint of ``step``."""
        return os.path.join(self.checkpoint_dir, f'{self._PREFIX}{step}{self._SUFFIX}')

    def _list_checkpoints(self):
        """Map step -> file name for every ``checkpoint_<int>.pkl`` in the directory.

        Files that only share the prefix (``checkpoint_best.pkl``, leftover
        ``.tmp`` files, ...) are skipped instead of crashing on ``int()``.
        """
        found = {}
        for name in os.listdir(self.checkpoint_dir):
            if not (name.startswith(self._PREFIX) and name.endswith(self._SUFFIX)):
                continue
            try:
                step = int(name[len(self._PREFIX):-len(self._SUFFIX)])
            except ValueError:
                continue
            found[step] = name
        return found

    def save(self, state, step):
        """Pickle ``state`` as the checkpoint for ``step``.

        Args:
            state: Dict with ``'params'`` and ``'optimizer_state'`` pytrees.
                Any ``'step'`` entry in it is ignored; ``step`` is stored instead.
            step: Step number used in the file name and stored in the file.
        """
        filepath = self._path_for(step)

        save_state = {
            # Host-side NumPy copies, so the pickle holds no device buffers.
            'params': jax.tree_util.tree_map(lambda x: np.array(x), state['params']),
            # Optimizer state may mix arrays and plain Python values; only
            # array-like leaves (anything with ``shape``) are converted.
            'optimizer_state': jax.tree_util.tree_map(
                lambda x: np.array(x) if hasattr(x, 'shape') else x,
                state['optimizer_state'],
            ),
            'step': step,
        }

        # Write to a temp file, then move it into place with os.replace, so an
        # interrupted save cannot leave a truncated file that load() picks as
        # the latest checkpoint.
        tmp_path = filepath + '.tmp'
        try:
            with open(tmp_path, 'wb') as f:
                pickle.dump(save_state, f)
            os.replace(tmp_path, filepath)
        except BaseException:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

        print(f"检查点已保存: {filepath}")

    def load(self, step=None):
        """Load the checkpoint for ``step``, or the newest one when ``step`` is None.

        Returns:
            Dict with ``'params'`` (leaves converted to ``jnp`` arrays),
            ``'optimizer_state'`` (as stored, i.e. NumPy / plain leaves) and
            ``'step'``; or None if no matching checkpoint file exists.
        """
        if step is None:
            available = self._list_checkpoints()
            if not available:
                return None
            filepath = os.path.join(self.checkpoint_dir, available[max(available)])
        else:
            filepath = self._path_for(step)

        if not os.path.exists(filepath):
            return None

        with open(filepath, 'rb') as f:
            save_state = pickle.load(f)

        state = {
            'params': jax.tree_util.tree_map(lambda x: jnp.array(x), save_state['params']),
            'optimizer_state': save_state['optimizer_state'],
            'step': save_state['step'],
        }

        print(f"检查点已加载: {filepath}")
        return state

import numpy as np

manager = CheckpointManager('checkpoints')

# Toy training state saved at step 100, then reloaded as the latest checkpoint.
state = dict(
    params={'w': jnp.array([1.0, 2.0])},
    optimizer_state={'m': jnp.array([0.0, 0.0])},
    step=100,
)
manager.save(state, step=100)

loaded_state = manager.load()
print(f"加载的步骤: {loaded_state['step']}")

训练中断恢复 #

python
import jax
import jax.numpy as jnp

def train_with_checkpoint(params, train_loader, epochs, checkpoint_manager, resume=True, save_every=10):
    """Train for ``epochs`` epochs with periodic checkpoints and optional resume.

    Args:
        params: Initial parameters; replaced by the checkpointed ones on resume.
        train_loader: Passed through to ``train_epoch``.
        epochs: Total epoch count (epoch indices run up to ``epochs - 1``).
        checkpoint_manager: Object providing ``save(state, step)`` and
            ``load()``, where ``load()`` returning None means nothing to resume.
        resume: Whether to restore the latest checkpoint before training.
        save_every: Checkpoint interval in epochs (previously fixed at 10).

    Returns:
        The trained parameters.

    Raises:
        ValueError: If ``save_every`` is not positive.

    Note: the optimizer state is not checkpointed (an empty dict is stored).
    """
    if save_every <= 0:
        raise ValueError(f"save_every must be positive, got {save_every}")

    start_epoch = 0
    if resume:
        state = checkpoint_manager.load()
        if state is not None:
            params = state['params']
            # The stored step is the last epoch that finished training.
            start_epoch = state['step'] + 1
            print(f"从 epoch {start_epoch} 恢复训练")

    final_epoch = epochs - 1
    for epoch in range(start_epoch, epochs):
        params, loss = train_epoch(params, train_loader)

        periodic = epoch % save_every == 0
        # Also checkpoint the final epoch; otherwise the epochs after the last
        # multiple of ``save_every`` are lost and retrained on resume.
        if periodic or epoch == final_epoch:
            checkpoint_manager.save(
                {'params': params, 'optimizer_state': {}, 'step': epoch},
                epoch,
            )
        if periodic:
            print(f"Epoch {epoch}: loss={loss:.4f}")

    return params

模型导出 #

导出为 ONNX(先经 jax2tf 导出为 TensorFlow SavedModel,再用 tf2onnx 等工具转换) #

python
import jax
import jax.numpy as jnp
from jax.experimental import jax2tf
import tensorflow as tf

def export_to_onnx(params, forward_fn, input_shape, filepath):
    """Convert ``forward_fn(params, x)`` to TensorFlow and save it as a SavedModel.

    NOTE: despite its name, this does not write an ``.onnx`` file. It writes a
    TensorFlow SavedModel directory at ``filepath``. Converting that to ONNX
    is a separate step with an external tool (e.g. tf2onnx), not done here.

    Args:
        params: Parameter pytree closed over by the exported function.
        forward_fn: Callable invoked as ``forward_fn(params, x)``.
        input_shape: Input shape; ``None`` entries mark dimensions that may
            vary at call time (e.g. the batch dimension).
        filepath: Output directory of the SavedModel.
    """
    # jax2tf cannot trace an input whose TF signature has unknown (None)
    # dimensions unless they are declared symbolic via ``polymorphic_shapes``.
    if any(dim is None for dim in input_shape):
        dims = ', '.join(
            f'd{i}' if dim is None else str(dim) for i, dim in enumerate(input_shape)
        )
        polymorphic_shapes = [f'({dims})']
    else:
        polymorphic_shapes = None

    tf_fn = jax2tf.convert(
        lambda x: forward_fn(params, x),
        polymorphic_shapes=polymorphic_shapes,
    )

    input_signature = [tf.TensorSpec(input_shape, tf.float32)]
    # autograph=False as recommended for jax2tf-converted functions.
    tf_model = tf.function(tf_fn, input_signature=input_signature, autograph=False)

    tf.saved_model.save(tf_model, filepath)
    print(f"模型已导出: {filepath}")

params = dict(w=jnp.ones((10, 5)), b=jnp.zeros(5))

def forward(params, x):
    """Affine layer: ``dot(x, params['w']) + params['b']``."""
    weight, bias = params['w'], params['b']
    return jnp.dot(x, weight) + bias

export_to_onnx(params, forward, input_shape=(None, 10), filepath='saved_model')

导出为 TensorFlow SavedModel #

python
import jax
import jax.numpy as jnp
from jax.experimental import jax2tf
import tensorflow as tf

def export_to_savedmodel(params, forward_fn, input_shape, filepath):
    """Export ``forward_fn(params, x)`` as an inference-only TensorFlow SavedModel.

    Args:
        params: Parameter pytree closed over by the exported function.
        forward_fn: Callable invoked as ``forward_fn(params, x)``.
        input_shape: Input shape; ``None`` entries mark dimensions that may
            vary at call time (e.g. the batch dimension).
        filepath: Output directory of the SavedModel.
    """
    # Unknown (None) dimensions in the TF signature must be declared symbolic
    # for jax2tf via ``polymorphic_shapes``; each gets its own name.
    if any(dim is None for dim in input_shape):
        dims = ', '.join(
            f'd{i}' if dim is None else str(dim) for i, dim in enumerate(input_shape)
        )
        polymorphic_shapes = [f'({dims})']
    else:
        polymorphic_shapes = None

    tf_fn = jax2tf.convert(
        lambda x: forward_fn(params, x),
        # No gradients are exported: the SavedModel is for inference only.
        with_gradient=False,
        polymorphic_shapes=polymorphic_shapes,
    )

    input_signature = [tf.TensorSpec(input_shape, tf.float32)]
    # autograph=False as recommended for jax2tf-converted functions.
    tf_model = tf.function(tf_fn, input_signature=input_signature, autograph=False)

    tf.saved_model.save(tf_model, filepath)
    print(f"SavedModel 已导出: {filepath}")

最佳实践 #

1. 定期保存检查点 #

python
import os

def train_with_regular_checkpoints(params, train_loader, epochs, checkpoint_dir, save_every=5):
    """Train for ``epochs`` epochs, checkpointing every ``save_every`` epochs and at the end.

    Args:
        params: Initial parameters.
        train_loader: Passed through to ``train_epoch``.
        epochs: Number of epochs to run.
        checkpoint_dir: Directory for ``CheckpointManager``.
        save_every: Checkpoint interval in epochs.

    Returns:
        The trained parameters.

    Raises:
        ValueError: If ``save_every`` is not positive.
    """
    if save_every <= 0:
        raise ValueError(f"save_every must be positive, got {save_every}")

    manager = CheckpointManager(checkpoint_dir)
    final_epoch = epochs - 1

    for epoch in range(epochs):
        params, loss = train_epoch(params, train_loader)

        # The final save is folded into the loop. This avoids writing the last
        # epoch twice when it is also a multiple of ``save_every``, and avoids
        # a bogus "step -1" checkpoint when ``epochs == 0``.
        if epoch % save_every == 0 or epoch == final_epoch:
            manager.save(
                {'params': params, 'optimizer_state': {}, 'step': epoch},
                epoch,
            )

        print(f"Epoch {epoch}: loss={loss:.4f}")

    return params

2. 保存最佳模型 #

python
import os

def train_save_best(params, train_loader, val_loader, epochs, checkpoint_dir):
    """Train for ``epochs`` epochs, keeping the best-validation parameters on disk.

    After every epoch the parameters are evaluated on ``val_loader``; whenever
    the validation loss strictly improves they overwrite
    ``<checkpoint_dir>/best_params.pkl``.

    Returns:
        The parameters after the final epoch (not necessarily the best ones).
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    best_path = os.path.join(checkpoint_dir, 'best_params.pkl')
    best_val_loss = float('inf')

    for epoch in range(epochs):
        params, train_loss = train_epoch(params, train_loader)
        val_loss = evaluate(params, val_loader)

        improved = val_loss < best_val_loss
        if improved:
            best_val_loss = val_loss
            save_params_pickle(params, best_path)
            print(f"Epoch {epoch}: 保存最佳模型, val_loss={val_loss:.4f}")

        print(f"Epoch {epoch}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}")

    return params

3. 版本控制 #

python
import os
import json
from datetime import datetime

def save_with_metadata(params, checkpoint_dir, metadata=None):
    """Save ``params`` under a timestamped name, plus an optional JSON metadata file.

    Args:
        params: Parameter pytree, written with ``save_params_pickle``.
        checkpoint_dir: Output directory; created if missing.
        metadata: Optional JSON-serializable dict; written to
            ``metadata_<timestamp>.json`` when non-empty.

    Returns:
        Path of the written parameter file (the function previously returned
        None, so existing callers are unaffected).
    """
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Include microseconds: with second resolution, two saves within the same
    # second silently overwrote each other, defeating the versioning.
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
    checkpoint_path = os.path.join(checkpoint_dir, f'params_{timestamp}.pkl')

    save_params_pickle(params, checkpoint_path)

    if metadata:
        meta_path = os.path.join(checkpoint_dir, f'metadata_{timestamp}.json')
        # Explicit UTF-8 plus ensure_ascii=False keeps non-ASCII text
        # (e.g. Chinese notes) readable in the JSON file on every platform.
        with open(meta_path, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, indent=2, ensure_ascii=False)

    print(f"检查点已保存: {checkpoint_path}")
    return checkpoint_path

# Save toy parameters together with descriptive training metadata.
params = {'w': jnp.array([1.0, 2.0])}
metadata = dict(
    model='MLP',
    epochs=100,
    lr=0.001,
    train_loss=0.05,
    val_loss=0.08,
)
save_with_metadata(params, 'checkpoints', metadata=metadata)

下一步 #

现在你已经掌握了模型保存与加载,接下来学习「多设备计算」,了解 JAX 的分布式计算能力!

最后更新:2026-04-04