PyTorch 模型转换实战 #

ResNet 模型转换 #

基本转换 #

python
import torch
import torchvision
import torch.onnx

# Load a pretrained ResNet-50 and switch it to inference mode before export.
model = torchvision.models.resnet50(pretrained=True)
model.eval()

# Tracing-based export needs one example input of the expected shape
# (NCHW, standard ImageNet resolution).
example_input = torch.randn(1, 3, 224, 224)

torch.onnx.export(model, example_input, "resnet50.onnx",
                  input_names=["input"], output_names=["output"],
                  do_constant_folding=True, opset_version=17)

print("ResNet50 模型导出成功")

动态 Batch Size #

python
import torch
import torchvision

# Pretrained ResNet-50 in eval mode; export must not run with training behavior.
model = torchvision.models.resnet50(pretrained=True)
model.eval()

# One example input is enough for tracing; the batch axis is declared dynamic below.
sample = torch.randn(1, 3, 224, 224)

torch.onnx.export(
    model,
    sample,
    "resnet50_dynamic.onnx",
    input_names=["input"],
    output_names=["output"],
    # Mark axis 0 of both tensors as a free "batch_size" dimension.
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    opset_version=17,
)

完整转换流程 #

python
import torch
import torchvision
import torch.onnx
import onnx
import onnxruntime as ort
import numpy as np

def export_resnet_to_onnx(
    model_name="resnet50",
    output_path="resnet50.onnx",
    opset_version=17,
    dynamic_batch=True
):
    """Export a torchvision classification model to ONNX and verify the result.

    Loads the pretrained model, exports it (optionally with a dynamic batch
    axis), validates the ONNX graph, and compares ONNX Runtime output against
    PyTorch output on one random input.

    Args:
        model_name: attribute name under torchvision.models (e.g. "resnet50").
        output_path: destination path of the .onnx file.
        opset_version: ONNX opset to target.
        dynamic_batch: when True, axis 0 of input/output is exported dynamic.

    Returns:
        The path of the exported model file.
    """
    print(f"加载 {model_name} 模型...")
    model = getattr(torchvision.models, model_name)(pretrained=True)
    model.eval()

    print("创建输入张量...")
    trace_input = torch.randn(1, 3, 224, 224)

    print("配置导出参数...")
    # dynamic_axes=None is torch.onnx.export's default, i.e. fully static shapes.
    axes = (
        {"input": {0: "batch_size"}, "output": {0: "batch_size"}}
        if dynamic_batch
        else None
    )

    print("导出 ONNX 模型...")
    torch.onnx.export(
        model,
        trace_input,
        output_path,
        input_names=["input"],
        output_names=["output"],
        opset_version=opset_version,
        do_constant_folding=True,
        verbose=False,
        dynamic_axes=axes,
    )

    print("验证 ONNX 模型...")
    # Structural validation of the serialized graph.
    onnx.checker.check_model(onnx.load(output_path))

    print("验证输出一致性...")
    session = ort.InferenceSession(output_path)
    input_name = session.get_inputs()[0].name

    test_input = np.random.randn(1, 3, 224, 224).astype(np.float32)

    with torch.no_grad():
        torch_output = model(torch.from_numpy(test_input)).numpy()

    onnx_output = session.run(None, {input_name: test_input})[0]

    # Raises AssertionError if the two runtimes disagree beyond tolerance.
    np.testing.assert_allclose(torch_output, onnx_output, rtol=1e-3, atol=1e-5)
    print("✅ 输出一致性验证通过")

    print(f"\n模型信息:")
    print(f"  文件: {output_path}")
    print(f"  Opset: {opset_version}")
    print(f"  动态 Batch: {dynamic_batch}")

    return output_path

export_resnet_to_onnx()

BERT 模型转换 #

HuggingFace Transformers 转换 #

python
import torch
from transformers import BertModel, BertTokenizer
import onnx

# Load the pretrained encoder and its tokenizer, then freeze for export.
model_name = "bert-base-uncased"
model = BertModel.from_pretrained(model_name)
tokenizer = BertTokenizer.from_pretrained(model_name)
model.eval()

# Build one tokenized example to trace the graph with.
text = "Hello, how are you?"
inputs = tokenizer(text, return_tensors="pt")

input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
token_type_ids = inputs.get("token_type_ids", torch.zeros_like(input_ids))

class BertWrapper(torch.nn.Module):
    """Adapter around BERT: takes plain positional tensors and returns plain
    tensors, which is the calling convention torch.onnx.export expects."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask, token_type_ids):
        result = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        return result.last_hidden_state, result.pooler_output

wrapped_model = BertWrapper(model)

# Both batch and sequence axes stay dynamic so any input length works at runtime.
seq_axes = {0: "batch_size", 1: "sequence_length"}
torch.onnx.export(
    wrapped_model,
    (input_ids, attention_mask, token_type_ids),
    "bert_base.onnx",
    input_names=["input_ids", "attention_mask", "token_type_ids"],
    output_names=["last_hidden_state", "pooler_output"],
    dynamic_axes={
        "input_ids": dict(seq_axes),
        "attention_mask": dict(seq_axes),
        "token_type_ids": dict(seq_axes),
        "last_hidden_state": dict(seq_axes),
        "pooler_output": {0: "batch_size"},
    },
    opset_version=17,
)

print("BERT 模型导出成功")

优化 BERT 导出 #

python
import torch
from transformers import BertModel, BertTokenizer
import onnx
from onnxruntime.transformers import optimizer

model_name = "bert-base-uncased"
model = BertModel.from_pretrained(model_name)
tokenizer = BertTokenizer.from_pretrained(model_name)
model.eval()

# The sequence length is fixed here; only the batch axis stays dynamic,
# which lets the ONNX Runtime optimizer fuse more aggressively.
max_seq_length = 128

dummy_input_ids = torch.zeros((1, max_seq_length), dtype=torch.long)
dummy_attention_mask = torch.ones((1, max_seq_length), dtype=torch.long)
dummy_token_type_ids = torch.zeros((1, max_seq_length), dtype=torch.long)

class OptimizedBertWrapper(torch.nn.Module):
    """Expose only last_hidden_state so the exported graph has a single output."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask, token_type_ids):
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        ).last_hidden_state

wrapped_model = OptimizedBertWrapper(model)

torch.onnx.export(
    wrapped_model,
    (dummy_input_ids, dummy_attention_mask, dummy_token_type_ids),
    "bert_optimized.onnx",
    input_names=["input_ids", "attention_mask", "token_type_ids"],
    output_names=["last_hidden_state"],
    dynamic_axes={
        "input_ids": {0: "batch_size"},
        "attention_mask": {0: "batch_size"},
        "token_type_ids": {0: "batch_size"},
        "last_hidden_state": {0: "batch_size"},
    },
    opset_version=17,
)

# Run the ONNX Runtime transformer fusion passes (attention, layernorm, gelu).
optimized_model = optimizer.optimize_model(
    "bert_optimized.onnx",
    model_type="bert",
    num_heads=12,     # bert-base: 12 attention heads
    hidden_size=768,  # bert-base: hidden dimension 768
)

optimized_model.save_model_to_file("bert_optimized_final.onnx")
print("优化后的 BERT 模型已保存")

YOLO 模型转换 #

YOLOv5 转换 #

python
import torch

# Pull YOLOv5-small with pretrained COCO weights from torch.hub.
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
model.eval()

# Standard YOLOv5 input resolution: 1 x 3 x 640 x 640.
sample = torch.zeros(1, 3, 640, 640)

torch.onnx.export(
    model,
    sample,
    "yolov5s.onnx",
    input_names=["images"],
    output_names=["output"],
    dynamic_axes={"images": {0: "batch_size"}, "output": {0: "batch_size"}},
    opset_version=17,
)

print("YOLOv5 模型导出成功")

YOLOv8 转换 #

python
from ultralytics import YOLO

# The ultralytics package drives the ONNX export itself: dynamic=True keeps
# axes flexible, simplify=True runs the onnx-simplifier pass on the result.
detector = YOLO("yolov8n.pt")
detector.export(format="onnx", dynamic=True, simplify=True)

print("YOLOv8 模型导出成功")

自定义模型转换 #

带控制流的模型 #

python
import torch
import torch.onnx

class ConditionalModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(10, 20)
        self.fc2 = torch.nn.Linear(20, 10)
        self.fc3 = torch.nn.Linear(10, 5)
    
    def forward(self, x, condition):
        x = torch.relu(self.fc1(x))
        
        if condition.sum() > 0:
            x = torch.relu(self.fc2(x))
        
        x = self.fc3(x)
        return x

model = ConditionalModel()
model.eval()

# A positive condition means the traced graph will contain the fc2 branch.
example_inputs = (torch.randn(1, 10), torch.tensor([1.0]))

torch.onnx.export(
    model,
    example_inputs,
    "conditional_model.onnx",
    input_names=["input", "condition"],
    output_names=["output"],
    opset_version=17,
)

带循环的模型 #

python
import torch
import torch.onnx

class RNNModel(torch.nn.Module):
    """LSTM encoder followed by a linear head over the final time step.

    Args:
        input_size: feature size of each time step.
        hidden_size: LSTM hidden dimension.
        num_layers: number of stacked LSTM layers.
    """

    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        self.rnn = torch.nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = torch.nn.Linear(hidden_size, 10)

    def forward(self, x):
        """Return (batch, 10) logits from a (batch, seq_len, input_size) input."""
        # FIX: allocate the zero initial states on x's device/dtype via
        # new_zeros — the original used torch.zeros, which lands on the CPU
        # default device and breaks as soon as the model/input live on GPU
        # (or use a non-default float dtype). Behavior is unchanged on CPU.
        state_shape = (self.rnn.num_layers, x.size(0), self.rnn.hidden_size)
        h0 = x.new_zeros(state_shape)
        c0 = x.new_zeros(state_shape)

        out, _ = self.rnn(x, (h0, c0))
        # Classify from the hidden state of the final time step only.
        return self.fc(out[:, -1, :])

model = RNNModel(input_size=32, hidden_size=64, num_layers=2)
model.eval()

# Example input: batch of 1, 10 time steps, 32 features per step.
dummy_input = torch.randn(1, 10, 32)

torch.onnx.export(
    model,
    dummy_input,
    "rnn_model.onnx",
    input_names=["input"],
    output_names=["output"],
    # Both the batch and time axes may vary at inference time.
    dynamic_axes={"input": {0: "batch_size", 1: "sequence_length"},
                  "output": {0: "batch_size"}},
    opset_version=17,
)

验证与测试 #

完整验证脚本 #

python
import torch
import onnx
import onnxruntime as ort
import numpy as np

def validate_onnx_export(
    pytorch_model,
    onnx_path,
    input_shape,
    rtol=1e-3,
    atol=1e-5
):
    """Validate an exported ONNX model against its PyTorch source.

    Checks the serialized graph, prints the session's input/output signatures,
    and compares outputs on five random inputs.

    Args:
        pytorch_model: source torch.nn.Module (switched to eval mode here).
        onnx_path: path of the exported .onnx file.
        input_shape: shape of the single model input, e.g. (1, 3, 224, 224).
        rtol, atol: tolerances for numpy.testing.assert_allclose.

    Returns:
        True when every sample satisfies the rtol/atol tolerance, else False.
    """
    print("=" * 60)
    print("ONNX 导出验证")
    print("=" * 60)

    pytorch_model.eval()

    print("\n1. 验证 ONNX 模型格式...")
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)
    print("   ✅ 模型格式正确")

    print("\n2. 获取模型信息...")
    session = ort.InferenceSession(onnx_path)

    print("   输入:")
    for inp in session.get_inputs():
        print(f"     - {inp.name}: {inp.shape}")

    print("   输出:")
    for out in session.get_outputs():
        print(f"     - {out.name}: {out.shape}")

    print("\n3. 比较输出一致性...")
    input_name = session.get_inputs()[0].name

    test_inputs = [
        np.random.randn(*input_shape).astype(np.float32)
        for _ in range(5)
    ]

    max_diff = 0.0
    all_close = True  # flips to False as soon as one sample exceeds tolerance
    for i, test_input in enumerate(test_inputs):
        with torch.no_grad():
            torch_output = pytorch_model(torch.from_numpy(test_input)).numpy()

        onnx_output = session.run(None, {input_name: test_input})[0]

        diff = np.abs(torch_output - onnx_output).max()
        max_diff = max(max_diff, diff)

        try:
            np.testing.assert_allclose(
                torch_output, onnx_output,
                rtol=rtol, atol=atol
            )
            print(f"   样本 {i+1}: ✅ 一致 (最大差异: {diff:.6f})")
        except AssertionError as e:
            all_close = False
            print(f"   样本 {i+1}: ❌ 不一致")
            print(f"     错误: {e}")

    print(f"\n4. 最大差异: {max_diff:.6f}")

    # BUG FIX: the original verdict tested `max_diff < atol`, ignoring rtol,
    # so a model that passed every assert_allclose above could still be
    # reported as failing. The verdict now agrees with the per-sample checks.
    if all_close:
        print("\n✅ 验证通过!模型可以安全部署。")
        return True
    else:
        print("\n⚠️ 存在精度差异,请检查模型。")
        return False

import torchvision

# Smoke test — assumes resnet18.onnx was already exported to the working dir.
validate_onnx_export(
    torchvision.models.resnet18(pretrained=True),
    "resnet18.onnx",
    (1, 3, 224, 224),
)

性能测试 #

python
import torch
import onnxruntime as ort
import numpy as np
import time

def benchmark_pytorch_vs_onnx(
    pytorch_model,
    onnx_path,
    input_shape,
    num_runs=100,
    warmup=10
):
    """Compare average inference latency of a PyTorch model vs its ONNX export.

    Args:
        pytorch_model: source torch.nn.Module (fed CPU tensors here).
        onnx_path: path to the exported .onnx file.
        input_shape: input tensor shape, e.g. (1, 3, 224, 224).
        num_runs: timed iterations per framework.
        warmup: untimed iterations to amortize one-time setup costs.

    Returns:
        dict with per-run latency in ms for both frameworks and the speedup.
    """
    pytorch_model.eval()
    session = ort.InferenceSession(onnx_path)
    input_name = session.get_inputs()[0].name

    test_input = np.random.randn(*input_shape).astype(np.float32)
    torch_input = torch.from_numpy(test_input)

    # Warm up both runtimes so lazy initialization doesn't pollute the timing.
    for _ in range(warmup):
        with torch.no_grad():
            _ = pytorch_model(torch_input)
        _ = session.run(None, {input_name: test_input})

    # IDIOM FIX: a plain `if` instead of a conditional expression used only
    # for its side effect; flush pending CUDA work before timing starts.
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    # FIX: time.perf_counter() is monotonic and higher-resolution than
    # time.time(), which can jump with wall-clock adjustments mid-benchmark.
    start = time.perf_counter()
    for _ in range(num_runs):
        with torch.no_grad():
            _ = pytorch_model(torch_input)
    torch_time = (time.perf_counter() - start) / num_runs * 1000

    start = time.perf_counter()
    for _ in range(num_runs):
        _ = session.run(None, {input_name: test_input})
    onnx_time = (time.perf_counter() - start) / num_runs * 1000

    print("=" * 60)
    print("性能对比")
    print("=" * 60)
    print(f"PyTorch 推理时间: {torch_time:.2f} ms")
    print(f"ONNX Runtime 推理时间: {onnx_time:.2f} ms")
    print(f"加速比: {torch_time / onnx_time:.2f}x")
    print("=" * 60)

    return {
        "pytorch_time": torch_time,
        "onnx_time": onnx_time,
        "speedup": torch_time / onnx_time
    }

常见问题解决 #

问题 1:导出失败 #

python
import torch
import torch.onnx

class ProblematicModel(torch.nn.Module):
    """Deliberately broken example: this module cannot be exported to ONNX."""

    def forward(self, x):
        # .item() converts the tensor to a Python scalar, which escapes the
        # traced graph — export needs every value to remain a tensor node.
        return x.item()

model = ProblematicModel()

class FixedModel(torch.nn.Module):
    """Fixed example: the reduction stays in tensor space."""

    def forward(self, x):
        # sum() returns a 0-dim tensor, which maps to an ONNX reduction node.
        return x.sum()

fixed_model = FixedModel()
# Export with default settings (opset chosen by the installed torch version).
torch.onnx.export(fixed_model, torch.randn(1), "fixed.onnx")

问题 2:动态形状问题 #

python
import torch

class DynamicModel(torch.nn.Module):
    """Flatten all trailing dimensions into one, keeping the batch axis.

    GENERALIZATION: torch.flatten replaces `.view(x.size(0), -1)` — .view
    raises on non-contiguous inputs (e.g. the result of a transpose), while
    flatten handles them and exports as a single ONNX Flatten node. For
    contiguous inputs the result is identical.
    """

    def forward(self, x):
        return torch.flatten(x, start_dim=1)

model = DynamicModel()

# BUG FIX: dynamic_axes keys must match declared tensor names. The original
# omitted input_names/output_names, so the keys "input"/"output" matched
# nothing and the dynamic axes were silently dropped (the exported graph kept
# fixed shapes). NOTE(review): output axis 1 equals 3*height*width, so it
# also varies when height/width do — confirm consumers can handle that.
torch.onnx.export(
    model,
    torch.randn(1, 3, 224, 224),
    "dynamic.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={
        "input": {0: "batch_size", 2: "height", 3: "width"},
        "output": {0: "batch_size"}
    }
)

下一步 #

现在你已经掌握了 PyTorch 模型转换,接下来学习 TensorFlow 模型转换,了解如何转换 TensorFlow 模型!

最后更新:2026-04-04