PyTorch Model Conversion in Practice #
ResNet Model Conversion #
Basic Conversion #
python
import torch
import torchvision
import torch.onnx
model = torchvision.models.resnet50(weights="DEFAULT")  # pretrained=True is deprecated in torchvision >= 0.13
model.eval()
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
model,
dummy_input,
"resnet50.onnx",
input_names=["input"],
output_names=["output"],
opset_version=17,
do_constant_folding=True
)
print("ResNet50 模型导出成功")
Dynamic Batch Size #
python
import torch
import torchvision
model = torchvision.models.resnet50(weights="DEFAULT")
model.eval()
dummy_input = torch.randn(1, 3, 224, 224)
dynamic_axes = {
"input": {0: "batch_size"},
"output": {0: "batch_size"}
}
torch.onnx.export(
model,
dummy_input,
"resnet50_dynamic.onnx",
input_names=["input"],
output_names=["output"],
dynamic_axes=dynamic_axes,
opset_version=17
)
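To confirm that the dynamic batch dimension actually works, you can feed the exported model batches of different sizes through ONNX Runtime. A small check, assuming onnxruntime is installed:
python
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("resnet50_dynamic.onnx")
for batch in (1, 4, 8):
    x = np.random.randn(batch, 3, 224, 224).astype(np.float32)
    out = session.run(None, {"input": x})[0]
    print(f"batch={batch} -> output shape {out.shape}")  # expect (batch, 1000)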
Complete Conversion Workflow #
python
import torch
import torchvision
import torch.onnx
import onnx
import onnxruntime as ort
import numpy as np
def export_resnet_to_onnx(
model_name="resnet50",
output_path="resnet50.onnx",
opset_version=17,
dynamic_batch=True
):
print(f"加载 {model_name} 模型...")
model_fn = getattr(torchvision.models, model_name)
model = model_fn(pretrained=True)
model.eval()
print("创建输入张量...")
dummy_input = torch.randn(1, 3, 224, 224)
print("配置导出参数...")
export_args = {
"model": model,
"args": dummy_input,
"f": output_path,
"input_names": ["input"],
"output_names": ["output"],
"opset_version": opset_version,
"do_constant_folding": True,
"verbose": False
}
if dynamic_batch:
export_args["dynamic_axes"] = {
"input": {0: "batch_size"},
"output": {0: "batch_size"}
}
print("导出 ONNX 模型...")
torch.onnx.export(**export_args)
print("验证 ONNX 模型...")
onnx_model = onnx.load(output_path)
onnx.checker.check_model(onnx_model)
print("验证输出一致性...")
session = ort.InferenceSession(output_path)
input_name = session.get_inputs()[0].name
test_input = np.random.randn(1, 3, 224, 224).astype(np.float32)
with torch.no_grad():
torch_output = model(torch.from_numpy(test_input)).numpy()
onnx_output = session.run(None, {input_name: test_input})[0]
np.testing.assert_allclose(torch_output, onnx_output, rtol=1e-3, atol=1e-5)
print("✅ 输出一致性验证通过")
print(f"\n模型信息:")
print(f" 文件: {output_path}")
print(f" Opset: {opset_version}")
print(f" 动态 Batch: {dynamic_batch}")
return output_path
export_resnet_to_onnx()
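The helper is not tied to ResNet-50; any torchvision classifier that takes a 224x224 input should work the same way. A hypothetical second call, exporting ResNet-18 and reporting the resulting file size:
python
import os

path = export_resnet_to_onnx(model_name="resnet18", output_path="resnet18_dynamic.onnx")
print(f"{path}: {os.path.getsize(path) / 1e6:.1f} MB")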
BERT Model Conversion #
HuggingFace Transformers Conversion #
python
import torch
from transformers import BertModel, BertTokenizer
import onnx
model_name = "bert-base-uncased"
model = BertModel.from_pretrained(model_name)
tokenizer = BertTokenizer.from_pretrained(model_name)
model.eval()
text = "Hello, how are you?"
inputs = tokenizer(text, return_tensors="pt")
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
token_type_ids = inputs.get("token_type_ids", torch.zeros_like(input_ids))
class BertWrapper(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, input_ids, attention_mask, token_type_ids):
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids
)
return outputs.last_hidden_state, outputs.pooler_output
wrapped_model = BertWrapper(model)
torch.onnx.export(
wrapped_model,
(input_ids, attention_mask, token_type_ids),
"bert_base.onnx",
input_names=["input_ids", "attention_mask", "token_type_ids"],
output_names=["last_hidden_state", "pooler_output"],
dynamic_axes={
"input_ids": {0: "batch_size", 1: "sequence_length"},
"attention_mask": {0: "batch_size", 1: "sequence_length"},
"token_type_ids": {0: "batch_size", 1: "sequence_length"},
"last_hidden_state": {0: "batch_size", 1: "sequence_length"},
"pooler_output": {0: "batch_size"}
},
opset_version=17
)
print("BERT 模型导出成功")
Optimized BERT Export #
python
import torch
from transformers import BertModel, BertTokenizer
import onnx
from onnxruntime.transformers import optimizer
model_name = "bert-base-uncased"
model = BertModel.from_pretrained(model_name)
tokenizer = BertTokenizer.from_pretrained(model_name)
model.eval()
max_seq_length = 128
dummy_input_ids = torch.zeros(1, max_seq_length, dtype=torch.long)
dummy_attention_mask = torch.ones(1, max_seq_length, dtype=torch.long)
dummy_token_type_ids = torch.zeros(1, max_seq_length, dtype=torch.long)
class OptimizedBertWrapper(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, input_ids, attention_mask, token_type_ids):
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids
)
return outputs.last_hidden_state
wrapped_model = OptimizedBertWrapper(model)
torch.onnx.export(
wrapped_model,
(dummy_input_ids, dummy_attention_mask, dummy_token_type_ids),
"bert_optimized.onnx",
input_names=["input_ids", "attention_mask", "token_type_ids"],
output_names=["last_hidden_state"],
dynamic_axes={
"input_ids": {0: "batch_size"},
"attention_mask": {0: "batch_size"},
"token_type_ids": {0: "batch_size"},
"last_hidden_state": {0: "batch_size"}
},
opset_version=17
)
optimized_model = optimizer.optimize_model(
"bert_optimized.onnx",
model_type="bert",
num_heads=12,
hidden_size=768
)
optimized_model.save_model_to_file("bert_optimized_final.onnx")
print("优化后的 BERT 模型已保存")
YOLO Model Conversion #
YOLOv5 Conversion #
python
import torch
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
model.eval()
img = torch.zeros(1, 3, 640, 640)
torch.onnx.export(
model,
img,
"yolov5s.onnx",
input_names=["images"],
output_names=["output"],
dynamic_axes={
"images": {0: "batch_size"},
"output": {0: "batch_size"}
},
opset_version=17
)
print("YOLOv5 模型导出成功")
YOLOv8 Conversion #
python
from ultralytics import YOLO
model = YOLO("yolov8n.pt")
model.export(format="onnx", dynamic=True, simplify=True)
print("YOLOv8 模型导出成功")
Custom Model Conversion #
Models with Control Flow #
python
import torch
import torch.onnx
class ConditionalModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(10, 20)
        self.fc2 = torch.nn.Linear(20, 20)  # same width in and out, so fc3 works whether or not the branch runs
        self.fc3 = torch.nn.Linear(20, 5)
    def forward(self, x, condition):
        x = torch.relu(self.fc1(x))
        # Data-dependent branch: the default tracing exporter records only the path
        # taken for the example inputs, so this `if` is baked into the exported graph.
        if condition.sum() > 0:
            x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
model = ConditionalModel()
model.eval()
x = torch.randn(1, 10)
condition = torch.tensor([1.0])
torch.onnx.export(
model,
(x, condition),
"conditional_model.onnx",
input_names=["input", "condition"],
output_names=["output"],
opset_version=17
)
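Tracing bakes the branch above into the graph. To keep it as real control flow in the ONNX file, you can try scripting the model before export. A minimal sketch, assuming both branches return tensors of the same shape and that your PyTorch version can lower the scripted `if` to an ONNX If node:
python
# torch.jit.script preserves the data-dependent `if` instead of tracing a single path
scripted = torch.jit.script(model)
torch.onnx.export(
    scripted,
    (x, condition),
    "conditional_model_scripted.onnx",
    input_names=["input", "condition"],
    output_names=["output"],
    opset_version=17
)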
Models with Loops #
python
import torch
import torch.onnx
class RNNModel(torch.nn.Module):
def __init__(self, input_size, hidden_size, num_layers):
super().__init__()
self.rnn = torch.nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
self.fc = torch.nn.Linear(hidden_size, 10)
def forward(self, x):
h0 = torch.zeros(self.rnn.num_layers, x.size(0), self.rnn.hidden_size)
c0 = torch.zeros(self.rnn.num_layers, x.size(0), self.rnn.hidden_size)
out, _ = self.rnn(x, (h0, c0))
out = self.fc(out[:, -1, :])
return out
model = RNNModel(input_size=32, hidden_size=64, num_layers=2)
model.eval()
dummy_input = torch.randn(1, 10, 32)
torch.onnx.export(
model,
dummy_input,
"rnn_model.onnx",
input_names=["input"],
output_names=["output"],
dynamic_axes={
"input": {0: "batch_size", 1: "sequence_length"},
"output": {0: "batch_size"}
},
opset_version=17
)
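Because the sequence dimension was exported as dynamic, the same graph should accept different sequence lengths. A quick check with ONNX Runtime (keeping batch size 1, which matches the dummy input used for tracing):
python
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("rnn_model.onnx")
for seq_len in (5, 10, 30):
    x = np.random.randn(1, seq_len, 32).astype(np.float32)
    out = session.run(None, {"input": x})[0]
    print(f"seq_len={seq_len} -> output shape {out.shape}")  # expect (1, 10)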
Validation and Testing #
Complete Validation Script #
python
import torch
import onnx
import onnxruntime as ort
import numpy as np
def validate_onnx_export(
pytorch_model,
onnx_path,
input_shape,
rtol=1e-3,
atol=1e-5
):
print("=" * 60)
print("ONNX 导出验证")
print("=" * 60)
pytorch_model.eval()
print("\n1. 验证 ONNX 模型格式...")
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
print(" ✅ 模型格式正确")
print("\n2. 获取模型信息...")
session = ort.InferenceSession(onnx_path)
print(" 输入:")
for inp in session.get_inputs():
print(f" - {inp.name}: {inp.shape}")
print(" 输出:")
for out in session.get_outputs():
print(f" - {out.name}: {out.shape}")
print("\n3. 比较输出一致性...")
input_name = session.get_inputs()[0].name
test_inputs = [
np.random.randn(*input_shape).astype(np.float32)
for _ in range(5)
]
max_diff = 0
for i, test_input in enumerate(test_inputs):
with torch.no_grad():
torch_output = pytorch_model(torch.from_numpy(test_input)).numpy()
onnx_output = session.run(None, {input_name: test_input})[0]
diff = np.abs(torch_output - onnx_output).max()
max_diff = max(max_diff, diff)
try:
np.testing.assert_allclose(
torch_output, onnx_output,
rtol=rtol, atol=atol
)
print(f" 样本 {i+1}: ✅ 一致 (最大差异: {diff:.6f})")
except AssertionError as e:
print(f" 样本 {i+1}: ❌ 不一致")
print(f" 错误: {e}")
print(f"\n4. 最大差异: {max_diff:.6f}")
if max_diff < atol:
print("\n✅ 验证通过!模型可以安全部署。")
return True
else:
print("\n⚠️ 存在精度差异,请检查模型。")
return False
import torchvision
# Export a ResNet-18 first so that the file referenced below actually exists
model = torchvision.models.resnet18(weights="DEFAULT")
model.eval()
torch.onnx.export(model, torch.randn(1, 3, 224, 224), "resnet18.onnx",
                  input_names=["input"], output_names=["output"], opset_version=17)
validate_onnx_export(model, "resnet18.onnx", (1, 3, 224, 224))
Performance Benchmark #
python
import torch
import onnxruntime as ort
import numpy as np
import time
def benchmark_pytorch_vs_onnx(
pytorch_model,
onnx_path,
input_shape,
num_runs=100,
warmup=10
):
pytorch_model.eval()
session = ort.InferenceSession(onnx_path)
input_name = session.get_inputs()[0].name
test_input = np.random.randn(*input_shape).astype(np.float32)
torch_input = torch.from_numpy(test_input)
for _ in range(warmup):
with torch.no_grad():
_ = pytorch_model(torch_input)
_ = session.run(None, {input_name: test_input})
    if torch.cuda.is_available():
        torch.cuda.synchronize()
start = time.time()
for _ in range(num_runs):
with torch.no_grad():
_ = pytorch_model(torch_input)
torch_time = (time.time() - start) / num_runs * 1000
start = time.time()
for _ in range(num_runs):
_ = session.run(None, {input_name: test_input})
onnx_time = (time.time() - start) / num_runs * 1000
print("=" * 60)
print("性能对比")
print("=" * 60)
print(f"PyTorch 推理时间: {torch_time:.2f} ms")
print(f"ONNX Runtime 推理时间: {onnx_time:.2f} ms")
print(f"加速比: {torch_time / onnx_time:.2f}x")
print("=" * 60)
return {
"pytorch_time": torch_time,
"onnx_time": onnx_time,
"speedup": torch_time / onnx_time
}
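A usage example, reusing the `resnet18.onnx` file exported in the validation section above:
python
import torchvision

model = torchvision.models.resnet18(weights="DEFAULT")
model.eval()
benchmark_pytorch_vs_onnx(model, "resnet18.onnx", (1, 3, 224, 224))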
Common Issues and Fixes #
Issue 1: Export Failure #
python
import torch
import torch.onnx
class ProblematicModel(torch.nn.Module):
    def forward(self, x):
        # .item() turns the tensor into a Python scalar, so the exporter cannot trace it
        return x.item()
model = ProblematicModel()
# Fix: keep everything as tensor operations so the exporter can record the graph
class FixedModel(torch.nn.Module):
    def forward(self, x):
        return x.sum()
fixed_model = FixedModel()
torch.onnx.export(fixed_model, torch.randn(1), "fixed.onnx")
Issue 2: Dynamic Shape Problems #
python
import torch
class DynamicModel(torch.nn.Module):
def forward(self, x):
return x.view(x.size(0), -1)
model = DynamicModel()
torch.onnx.export(
    model,
    torch.randn(1, 3, 224, 224),
    "dynamic.onnx",
    input_names=["input"],    # dynamic_axes keys must match these names
    output_names=["output"],
    dynamic_axes={
        "input": {0: "batch_size", 2: "height", 3: "width"},
        "output": {0: "batch_size"}
    },
    opset_version=17
)
Next Steps #
Now that you have mastered PyTorch model conversion, continue with TensorFlow model conversion to learn how to convert TensorFlow models!
Last updated: 2026-04-04