PyTorch 模型保存与加载 #

保存与加载概述 #

在深度学习中,模型保存与加载是必不可少的环节,用于模型持久化、断点续训和模型部署。

text
┌─────────────────────────────────────────────────────────────┐
│                    模型保存与加载场景                        │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  1. 训练中断恢复                                            │
│     - 保存检查点                                            │
│     - 加载继续训练                                          │
│                                                             │
│  2. 模型共享                                                │
│     - 保存训练好的模型                                      │
│     - 他人加载使用                                          │
│                                                             │
│  3. 模型部署                                                │
│     - 保存推理模型                                          │
│     - 生产环境加载                                          │
│                                                             │
│  4. 迁移学习                                                │
│     - 加载预训练权重                                        │
│     - 微调新任务                                            │
│                                                             │
└─────────────────────────────────────────────────────────────┘

state_dict #

什么是 state_dict #

python
import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    """Two-layer fully connected network: 784 -> 256 -> 10 logits."""

    def __init__(self):
        super().__init__()
        # Hidden projection and classification head.
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)

model = SimpleNet()

# Every learnable tensor the module registered shows up in the state_dict.
print("模型的 state_dict:")
state_dict = model.state_dict()
for key in state_dict:
    print(f"  {key}: {state_dict[key].shape}")

# state_dict is an OrderedDict mapping parameter names to tensors.
print(f"\nstate_dict 类型: {type(state_dict)}")

保存和加载 state_dict #

python
import torch
import torch.nn as nn

model = SimpleNet()

# Persist only the parameter tensors (the recommended approach).
torch.save(model.state_dict(), 'model_weights.pth')

# Rebuild the architecture, then restore the weights.
# weights_only=True restricts unpickling to plain tensors/containers,
# blocking arbitrary-code execution; it is also the default since PyTorch 2.6,
# so being explicit keeps the snippet correct on older and newer versions.
model = SimpleNet()
model.load_state_dict(torch.load('model_weights.pth', weights_only=True))

# Switch to inference mode (disables dropout, freezes BatchNorm stats).
model.eval()

完整模型保存 #

保存整个模型 #

python
import torch
import torch.nn as nn

model = SimpleNet()

# Pickle the entire module object (architecture + weights).
torch.save(model, 'model_complete.pth')

# Loading a whole pickled nn.Module requires weights_only=False:
# PyTorch >= 2.6 defaults to weights_only=True, which would raise here.
# Only load trusted files this way — pickle can execute arbitrary code.
model = torch.load('model_complete.pth', weights_only=False)
model.eval()

保存与加载的区别 #

text
┌─────────────────────────────────────────────────────────────┐
│                    保存方式对比                              │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  state_dict 方式:                                          │
│  ✅ 文件小                                                  │
│  ✅ 灵活性高                                                │
│  ✅ 推荐使用                                                │
│  ⚠️ 需要模型类定义                                         │
│                                                             │
│  完整模型方式:                                              │
│  ✅ 简单直接                                                │
│  ✅ 不需要模型类定义                                        │
│  ❌ 文件大                                                  │
│  ❌ 依赖 Python pickle                                     │
│  ❌ 可能有兼容性问题                                        │
│                                                             │
└─────────────────────────────────────────────────────────────┘

检查点保存 #

保存训练检查点 #

python
import torch
import torch.optim as optim

model = SimpleNet()
optimizer = optim.Adam(model.parameters(), lr=0.001)
epoch = 10
loss = 0.5

# Bundle everything needed to resume training into a single dict.
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}

torch.save(checkpoint, 'checkpoint.pth')

# weights_only=True is safe here: the checkpoint holds only tensors and
# plain Python containers (and it is the default since PyTorch 2.6).
checkpoint = torch.load('checkpoint.pth', weights_only=True)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

print(f"从 epoch {epoch} 恢复训练,loss: {loss}")

完整训练循环示例 #

python
import torch
import torch.nn as nn
import torch.optim as optim

def save_checkpoint(model, optimizer, epoch, loss, path):
    """Serialize model/optimizer state plus training progress to *path*."""
    torch.save(
        {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
        },
        path,
    )

def load_checkpoint(model, optimizer, path):
    """Restore model/optimizer state from *path*; return (epoch, loss)."""
    state = torch.load(path)
    model.load_state_dict(state['model_state_dict'])
    optimizer.load_state_dict(state['optimizer_state_dict'])
    return state['epoch'], state['loss']

def train(model, dataloader, optimizer, criterion, device, start_epoch=0, epochs=10):
    """Standard training loop that checkpoints every 5 epochs."""
    for epoch in range(start_epoch, epochs):
        model.train()
        running = 0.0

        for data, target in dataloader:
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()

            running += loss.item()

        avg_loss = running / len(dataloader)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

        # Checkpoint periodically so an interrupted run can resume.
        if (epoch + 1) % 5 == 0:
            save_checkpoint(model, optimizer, epoch + 1, avg_loss, f'checkpoint_epoch_{epoch+1}.pth')

# Pick the GPU when available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleNet().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# NOTE(review): `dataloader` is not defined in this snippet — it must be
# created beforehand (e.g. a torch.utils.data.DataLoader over your dataset).
start_epoch = 0
train(model, dataloader, optimizer, criterion, device, start_epoch, epochs=20)

保存最佳模型 #

python
import torch

# Track the smallest validation loss observed so far.
best_loss = float('inf')

# NOTE(review): `epochs`, `train_one_epoch`, `validate`, the loaders, the
# model and optimizer are assumed to be defined earlier in the full script.
for epoch in range(epochs):
    train_loss = train_one_epoch(model, dataloader, criterion, optimizer, device)
    val_loss = validate(model, val_loader, criterion, device)

    # Save only when validation improves: keeps exactly one best checkpoint.
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
        }, 'best_model.pth')
        print(f"保存最佳模型,验证损失: {val_loss:.4f}")

跨设备加载 #

CPU 加载 GPU 模型 #

python
import torch

# Remap CUDA tensors onto the CPU while loading a GPU-saved checkpoint.
# weights_only=True is safe for dict-of-tensors checkpoints (default in >=2.6).
checkpoint = torch.load('model_gpu.pth', map_location='cpu', weights_only=True)
model.load_state_dict(checkpoint['model_state_dict'])

# Loading a whole pickled module requires weights_only=False
# (the default flipped to True in PyTorch 2.6). Trusted files only.
model = torch.load('model_gpu.pth', map_location='cpu', weights_only=False)

# Move tensors saved on cuda:0 onto cuda:1 during loading.
checkpoint = torch.load('model.pth', map_location={'cuda:0': 'cuda:1'}, weights_only=True)

GPU 加载模型 #

python
import torch

# Load the checkpoint directly onto the best available device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint = torch.load('model.pth', map_location=device, weights_only=True)
model.load_state_dict(checkpoint['model_state_dict'])

# Full-module pickle load: map every storage onto GPU 0.
# weights_only=False is required for pickled nn.Module objects (trusted files only).
model = torch.load('model.pth', map_location=lambda storage, loc: storage.cuda(0), weights_only=False)

TorchScript #

脚本化模型 #

python
import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    """Minimal MLP (784 -> 256 -> 10) used for the TorchScript demo."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

model = SimpleNet()
model.eval()

# Compile the module to TorchScript; scripting handles Python control flow.
scripted_model = torch.jit.script(model)
scripted_model.save('scripted_model.pt')

# A scripted model reloads without needing the Python class definition.
loaded_model = torch.jit.load('scripted_model.pt')

x = torch.randn(1, 784)
output = loaded_model(x)
print(f"输出: {output.shape}")

追踪模型 #

python
import torch
import torch.nn as nn

model = SimpleNet()
model.eval()

# Tracing records the ops executed for this one example input; any Python
# control flow is baked in along the traced path.
example_input = torch.randn(1, 784)
traced_model = torch.jit.trace(model, example_input)
traced_model.save('traced_model.pt')

loaded_model = torch.jit.load('traced_model.pt')
output = loaded_model(example_input)
print(f"输出: {output.shape}")

Script vs Trace #

text
┌─────────────────────────────────────────────────────────────┐
│                    TorchScript 方式对比                      │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  torch.jit.script:                                         │
│  ✅ 支持控制流(if/for/while)                              │
│  ✅ 更灵活                                                  │
│  ⚠️ 需要代码兼容 TorchScript                               │
│                                                             │
│  torch.jit.trace:                                          │
│  ✅ 简单易用                                                │
│  ✅ 无需修改代码                                            │
│  ❌ 不支持控制流                                            │
│  ❌ 只记录追踪路径                                          │
│                                                             │
│  选择建议:                                                  │
│  - 无控制流:使用 trace                                     │
│  - 有控制流:使用 script                                    │
│  - 复杂模型:混合使用                                       │
│                                                             │
└─────────────────────────────────────────────────────────────┘

ONNX 导出 #

导出为 ONNX #

python
import torch
import torch.nn as nn

# NOTE(review): relies on the SimpleNet class defined earlier on this page.
model = SimpleNet()
model.eval()

# The example input fixes the graph's input signature during export.
dummy_input = torch.randn(1, 784)

torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    export_params=True,        # embed the trained weights in the file
    opset_version=11,          # ONNX operator-set version to target
    do_constant_folding=True,  # pre-compute constant subgraphs
    input_names=['input'],
    output_names=['output'],
    # Mark dim 0 as dynamic so any batch size works at inference time.
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'}
    }
)

print("模型已导出为 ONNX 格式")

加载 ONNX 模型 #

python
import onnx
import onnxruntime as ort
import numpy as np

# Structural validation of the exported graph.
onnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model)

# ONNX Runtime inference session (CPU execution provider by default).
session = ort.InferenceSession("model.onnx")

input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

# ONNX Runtime consumes NumPy arrays; this model expects float32 input.
dummy_input = np.random.randn(1, 784).astype(np.float32)

outputs = session.run([output_name], {input_name: dummy_input})
print(f"ONNX 输出: {outputs[0].shape}")

安全加载 #

安全加载权重 #

python
import torch

# weights_only=True restricts unpickling to tensors and plain containers,
# blocking arbitrary-code execution (default behaviour since PyTorch 2.6).
# NOTE(review): the return value is discarded here — this line only
# demonstrates the flag.
torch.load('model.pth', weights_only=True)

try:
    checkpoint = torch.load('model.pth', weights_only=True)
except Exception as e:
    print(f"加载失败: {e}")

# safetensors stores raw tensor data with no pickle at all —
# the safest interchange format for bare weights.
from safetensors.torch import load_file, save_file

save_file(model.state_dict(), "model.safetensors")

state_dict = load_file("model.safetensors")
model.load_state_dict(state_dict)

模型压缩 #

量化保存 #

python
import torch

# NOTE(review): uses the SimpleNet class defined earlier on this page.
model = SimpleNet()
model.eval()

# Dynamic quantization: weights of every nn.Linear become int8;
# activations are quantized on the fly at inference time.
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)

torch.save(quantized_model.state_dict(), 'quantized_model.pth')

print(f"原始模型大小: {sum(p.numel() for p in model.parameters())}")
# NOTE(review): quantized Linear layers hold packed weights that are not
# exposed through .parameters(), so this count can be misleading — verify.
print(f"量化后参数: {sum(p.numel() for p in quantized_model.parameters())}")

剪枝保存 #

python
import torch
import torch.nn.utils.prune as prune

import torch.nn as nn  # needed for isinstance below; missing from the original snippet

model = SimpleNet()

# Zero out the 20% smallest-magnitude weights (L1 criterion) in each Linear.
# Pruning reparameterizes `weight` into `weight_orig` + `weight_mask`.
for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.2)

# Make pruning permanent on EVERY pruned layer before saving, so the
# state_dict contains plain `weight` tensors instead of orig/mask pairs.
# (The original snippet only did this for fc1, leaving fc2 reparameterized.)
for module in model.modules():
    if isinstance(module, nn.Linear):
        prune.remove(module, 'weight')

torch.save(model.state_dict(), 'pruned_model.pth')

完整示例 #

python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os

class Net(nn.Module):
    """MNIST CNN: two conv layers, max-pool + dropout, two FC layers."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)  # 64 channels * 12 * 12 after pooling
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # Feature extraction: conv -> relu -> conv -> relu -> pool -> dropout.
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.relu(self.conv2(x))
        x = self.dropout1(nn.functional.max_pool2d(x, 2))
        # Classifier head on the flattened feature map.
        x = torch.flatten(x, 1)
        x = self.dropout2(nn.functional.relu(self.fc1(x)))
        return self.fc2(x)

def save_checkpoint(state, filename='checkpoint.pth'):
    """Write a checkpoint dict (weights, optimizer state, metadata) to disk."""
    torch.save(state, filename)

def load_checkpoint(model, optimizer, filename='checkpoint.pth'):
    """Resume from *filename* when it exists; return (start_epoch, best_loss)."""
    if not os.path.isfile(filename):
        # Fresh run: start at epoch 0 with no best loss recorded yet.
        return 0, float('inf')
    checkpoint = torch.load(filename)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return checkpoint['epoch'], checkpoint['best_loss']

# Device selection plus model / optimizer / loss setup.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Standard MNIST normalization constants (dataset mean / std).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
# download=True fetches MNIST on first run (requires network access).
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Resume from 'checkpoint.pth' when present; otherwise start at (0, inf).
start_epoch, best_loss = load_checkpoint(model, optimizer)

def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Run one training epoch; return (mean batch loss, accuracy %)."""
    model.train()
    running_loss = 0.0
    n_correct = 0

    for data, target in dataloader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        n_correct += output.argmax(dim=1).eq(target).sum().item()

    return running_loss / len(dataloader), 100. * n_correct / len(dataloader.dataset)

# Train up to 10 epochs, resuming from start_epoch if a checkpoint existed.
for epoch in range(start_epoch, 10):
    loss, acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
    print(f"Epoch {epoch+1}, Loss: {loss:.4f}, Acc: {acc:.2f}%")

    # Keep only the checkpoint with the lowest loss seen so far.
    # NOTE(review): selection uses *training* loss here; a held-out
    # validation loss would be the usual criterion — confirm intent.
    if loss < best_loss:
        best_loss = loss
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_loss': best_loss,
            'optimizer': optimizer.state_dict(),
        }, 'best_model.pth')

# Export a TorchScript version for Python-free deployment.
torch.jit.script(model).save('model_scripted.pt')
print("模型已保存为 TorchScript 格式")

下一步 #

现在你已经掌握了 PyTorch 模型保存与加载的核心概念,接下来学习 高级主题,了解分布式训练、模型部署等进阶内容!

最后更新:2026-03-29