PyTorch 模型保存与加载 #
保存与加载概述 #
在深度学习中,模型保存与加载是必不可少的环节,用于模型持久化、断点续训和模型部署。
text
┌─────────────────────────────────────────────────────────────┐
│ 模型保存与加载场景 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 1. 训练中断恢复 │
│ - 保存检查点 │
│ - 加载继续训练 │
│ │
│ 2. 模型共享 │
│ - 保存训练好的模型 │
│ - 他人加载使用 │
│ │
│ 3. 模型部署 │
│ - 保存推理模型 │
│ - 生产环境加载 │
│ │
│ 4. 迁移学习 │
│ - 加载预训练权重 │
│ - 微调新任务 │
│ │
└─────────────────────────────────────────────────────────────┘
state_dict #
什么是 state_dict #
python
import torch
import torch.nn as nn


class SimpleNet(nn.Module):
    """Two-layer fully connected network: 784 -> 256 -> 10."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)


model = SimpleNet()

# A state_dict maps each parameter/buffer name to its tensor.
print("模型的 state_dict:")
for name, tensor in model.state_dict().items():
    print(f" {name}: {tensor.shape}")

state_dict = model.state_dict()
print(f"\nstate_dict 类型: {type(state_dict)}")
保存和加载 state_dict #
python
import torch
import torch.nn as nn

# Save only the parameters (recommended): small file, no pickled code.
model = SimpleNet()
torch.save(model.state_dict(), 'model_weights.pth')

# Load: re-create the architecture first, then restore the weights.
# weights_only=True refuses to unpickle arbitrary objects (safe loading;
# this is the default since PyTorch 2.6).
model = SimpleNet()
model.load_state_dict(torch.load('model_weights.pth', weights_only=True))
model.eval()  # switch dropout/batchnorm layers to inference behavior
完整模型保存 #
保存整个模型 #
python
import torch
import torch.nn as nn

model = SimpleNet()
# Pickles the entire module object (architecture + weights).
torch.save(model, 'model_complete.pth')

# Unpickling a whole module requires weights_only=False on PyTorch >= 2.6
# (and the class definition must be importable at load time). Only load
# files you trust -- pickle can execute arbitrary code.
model = torch.load('model_complete.pth', weights_only=False)
model.eval()
保存与加载的区别 #
text
┌─────────────────────────────────────────────────────────────┐
│ 保存方式对比 │
├─────────────────────────────────────────────────────────────┤
│ │
│ state_dict 方式: │
│ ✅ 文件小 │
│ ✅ 灵活性高 │
│ ✅ 推荐使用 │
│ ⚠️ 需要模型类定义 │
│ │
│ 完整模型方式: │
│ ✅ 简单直接 │
│ ✅ 不需要模型类定义 │
│ ❌ 文件大 │
│ ❌ 依赖 Python pickle │
│ ❌ 可能有兼容性问题 │
│ │
└─────────────────────────────────────────────────────────────┘
检查点保存 #
保存训练检查点 #
python
import torch
import torch.optim as optim

model = SimpleNet()
optimizer = optim.Adam(model.parameters(), lr=0.001)
epoch = 10
loss = 0.5

# A checkpoint bundles everything needed to resume training:
# model weights, optimizer state, and training progress.
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}
torch.save(checkpoint, 'checkpoint.pth')

# The checkpoint contains only tensors and plain containers, so safe
# loading (weights_only=True) works.
checkpoint = torch.load('checkpoint.pth', weights_only=True)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
print(f"从 epoch {epoch} 恢复训练,loss: {loss}")
完整训练循环示例 #
python
import torch
import torch.nn as nn
import torch.optim as optim
def save_checkpoint(model, optimizer, epoch, loss, path):
    """Persist model + optimizer state and training progress to *path*."""
    torch.save(
        {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
        },
        path,
    )
def load_checkpoint(model, optimizer, path):
    """Restore model/optimizer state from *path*; return (epoch, loss)."""
    state = torch.load(path)
    model.load_state_dict(state['model_state_dict'])
    optimizer.load_state_dict(state['optimizer_state_dict'])
    return state['epoch'], state['loss']
def train(model, dataloader, optimizer, criterion, device, start_epoch=0, epochs=10):
    """Standard training loop that checkpoints every 5 epochs."""
    for epoch in range(start_epoch, epochs):
        model.train()
        running = 0.0
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()
            running += loss.item()
        avg_loss = running / len(dataloader)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
        # Periodic checkpoint so an interrupted run can be resumed.
        if (epoch + 1) % 5 == 0:
            save_checkpoint(model, optimizer, epoch + 1, avg_loss, f'checkpoint_epoch_{epoch+1}.pth')
# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleNet().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
start_epoch = 0
# NOTE(review): `dataloader` is not defined anywhere in this snippet --
# a torch.utils.data.DataLoader must be constructed before this call.
train(model, dataloader, optimizer, criterion, device, start_epoch, epochs=20)
保存最佳模型 #
python
import torch
import torch

# Track the best validation loss seen so far; start at +inf so the
# first epoch always produces a save.
best_loss = float('inf')

for epoch in range(epochs):
    train_loss = train_one_epoch(model, dataloader, criterion, optimizer, device)
    val_loss = validate(model, val_loader, criterion, device)

    # Keep only the checkpoint with the lowest validation loss.
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_loss,
            },
            'best_model.pth',
        )
        print(f"保存最佳模型,验证损失: {val_loss:.4f}")
跨设备加载 #
CPU 加载 GPU 模型 #
python
import torch

# Remap GPU-saved tensors onto the CPU at load time.
checkpoint = torch.load('model_gpu.pth', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
# The same map_location works when unpickling a whole module.
model = torch.load('model_gpu.pth', map_location='cpu')
# A dict remaps between specific devices, e.g. cuda:0 -> cuda:1.
checkpoint = torch.load('model.pth', map_location={'cuda:0': 'cuda:1'})
GPU 加载模型 #
python
import torch

# Load tensors directly onto the current device (GPU when available).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint = torch.load('model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
# A callable gives full control: here every storage is moved to GPU 0.
model = torch.load('model.pth', map_location=lambda storage, loc: storage.cuda(0))
TorchScript #
脚本化模型 #
python
import torch
import torch.nn as nn


class SimpleNet(nn.Module):
    """Two-layer MLP used to demonstrate TorchScript export."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


model = SimpleNet()
model.eval()

# Scripting compiles the module from its Python source, so Python
# control flow (if/for/while) is preserved in the exported graph.
scripted_model = torch.jit.script(model)
scripted_model.save('scripted_model.pt')

# A scripted module loads without the original class definition.
loaded_model = torch.jit.load('scripted_model.pt')
x = torch.randn(1, 784)
output = loaded_model(x)
print(f"输出: {output.shape}")
追踪模型 #
python
import torch
import torch.nn as nn

model = SimpleNet()
model.eval()

# Tracing records the ops executed for one example input; any Python
# control flow is baked in along the traced path only.
example_input = torch.randn(1, 784)
traced_model = torch.jit.trace(model, example_input)
traced_model.save('traced_model.pt')

loaded_model = torch.jit.load('traced_model.pt')
output = loaded_model(example_input)
print(f"输出: {output.shape}")
Script vs Trace #
text
┌─────────────────────────────────────────────────────────────┐
│ TorchScript 方式对比 │
├─────────────────────────────────────────────────────────────┤
│ │
│ torch.jit.script: │
│ ✅ 支持控制流(if/for/while) │
│ ✅ 更灵活 │
│ ⚠️ 需要代码兼容 TorchScript │
│ │
│ torch.jit.trace: │
│ ✅ 简单易用 │
│ ✅ 无需修改代码 │
│ ❌ 不支持控制流 │
│ ❌ 只记录追踪路径 │
│ │
│ 选择建议: │
│ - 无控制流:使用 trace │
│ - 有控制流:使用 script │
│ - 复杂模型:混合使用 │
│ │
└─────────────────────────────────────────────────────────────┘
ONNX 导出 #
导出为 ONNX #
python
import torch
import torch.nn as nn

model = SimpleNet()
model.eval()

# ONNX export runs the model once on a dummy input to record the graph.
dummy_input = torch.randn(1, 784)
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    export_params=True,        # embed the trained weights in the file
    opset_version=11,          # ONNX operator-set version
    do_constant_folding=True,  # pre-compute constant subgraphs
    input_names=['input'],
    output_names=['output'],
    # Mark dim 0 as dynamic so any batch size works at inference time.
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'}
    }
)
print("模型已导出为 ONNX 格式")
加载 ONNX 模型 #
python
import onnx
import onnxruntime as ort
import numpy as np

# Structural validity check of the exported graph.
onnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model)

# Run inference through ONNX Runtime (no PyTorch required).
session = ort.InferenceSession("model.onnx")
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

dummy_input = np.random.randn(1, 784).astype(np.float32)
outputs = session.run([output_name], {input_name: dummy_input})
print(f"ONNX 输出: {outputs[0].shape}")
安全加载 #
安全加载权重 #
python
import torch

# weights_only=True restricts unpickling to tensors and plain containers,
# blocking arbitrary-code execution from a malicious checkpoint file.
torch.load('model.pth', weights_only=True)

try:
    checkpoint = torch.load('model.pth', weights_only=True)
except Exception as e:
    print(f"加载失败: {e}")

# safetensors stores raw tensor data with no pickle involved at all.
from safetensors.torch import load_file, save_file

save_file(model.state_dict(), "model.safetensors")
state_dict = load_file("model.safetensors")
model.load_state_dict(state_dict)
模型压缩 #
量化保存 #
python
import torch

model = SimpleNet()
model.eval()

# Dynamic quantization: weights are stored as int8 and activations are
# quantized on the fly; only the listed module types are converted.
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
torch.save(quantized_model.state_dict(), 'quantized_model.pth')

# NOTE(review): quantized Linear layers no longer expose their weights via
# parameters(), so the second count mostly reflects remaining fp32 params.
print(f"原始模型大小: {sum(p.numel() for p in model.parameters())}")
print(f"量化后参数: {sum(p.numel() for p in quantized_model.parameters())}")
剪枝保存 #
python
import torch
import torch.nn as nn  # fix: `nn` is used below but was never imported
import torch.nn.utils.prune as prune

model = SimpleNet()

# L1-unstructured pruning: zero out the 20% smallest-magnitude weights of
# every Linear layer (adds weight_orig + weight_mask behind the scenes).
for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.2)

# Make the pruning permanent on fc1, so its state_dict holds a plain
# 'weight' tensor again instead of weight_orig/weight_mask pairs.
prune.remove(model.fc1, 'weight')
torch.save(model.state_dict(), 'pruned_model.pth')
完整示例 #
python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os
class Net(nn.Module):
    """Small CNN for 28x28 single-channel images (MNIST), 10 classes."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)  # 64 channels * 12 * 12 after pooling
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # conv -> relu -> conv -> relu -> pool -> dropout
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.relu(self.conv2(x))
        x = nn.functional.max_pool2d(x, 2)
        x = self.dropout1(x)
        # flatten -> fc -> relu -> dropout -> logits
        x = torch.flatten(x, 1)
        x = nn.functional.relu(self.fc1(x))
        x = self.dropout2(x)
        return self.fc2(x)
def save_checkpoint(state, filename='checkpoint.pth'):
    """Serialize the checkpoint dict *state* to *filename*."""
    torch.save(state, filename)
def load_checkpoint(model, optimizer, filename='checkpoint.pth'):
    """Resume from *filename* if it exists.

    Returns (epoch, best_loss); (0, inf) when no checkpoint is found.
    """
    if not os.path.isfile(filename):
        return 0, float('inf')
    state = torch.load(filename)
    model.load_state_dict(state['state_dict'])
    optimizer.load_state_dict(state['optimizer'])
    return state['epoch'], state['best_loss']
# Device, model, optimizer, and loss for the full example.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# MNIST pipeline with the dataset's standard normalization constants.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Resume from an earlier checkpoint when one exists on disk.
start_epoch, best_loss = load_checkpoint(model, optimizer)
def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Run one pass over *dataloader*; return (avg_loss, accuracy_percent)."""
    model.train()
    loss_sum = 0.0
    n_correct = 0
    for data, target in dataloader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        loss_sum += loss.item()
        n_correct += output.argmax(dim=1).eq(target).sum().item()
    return loss_sum / len(dataloader), 100. * n_correct / len(dataloader.dataset)
# Main loop: train, report, and checkpoint whenever the loss improves.
for epoch in range(start_epoch, 10):
    loss, acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
    print(f"Epoch {epoch+1}, Loss: {loss:.4f}, Acc: {acc:.2f}%")
    if loss < best_loss:
        best_loss = loss
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_loss': best_loss,
                'optimizer': optimizer.state_dict(),
            },
            'best_model.pth',
        )

# Also export a TorchScript version for deployment.
torch.jit.script(model).save('model_scripted.pt')
print("模型已保存为 TorchScript 格式")
下一步 #
现在你已经掌握了 PyTorch 模型保存与加载的核心概念,接下来学习 高级主题,了解分布式训练、模型部署等进阶内容!
最后更新:2026-03-29