模型转换 #

转换概述 #

ONNX 作为中间格式,支持从多种深度学习框架导出模型。

text
┌─────────────────────────────────────────────────────────────┐
│                    模型转换流程                              │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│   训练框架                    ONNX                    推理引擎 │
│   ┌─────────┐                ┌─────┐                ┌─────┐│
│   │PyTorch  │──┐            │     │            ┌──│ORT  ││
│   └─────────┘  │            │     │            │  └─────┘│
│   ┌─────────┐  │            │     │            │  ┌─────┐│
│   │TensorFlow│──┼──────────>│ONNX │──────────┼──│TRT  ││
│   └─────────┘  │            │     │            │  └─────┘│
│   ┌─────────┐  │            │     │            │  ┌─────┐│
│   │  JAX    │──┘            │     │            └──│OV   ││
│   └─────────┘                └─────┘               └─────┘│
│                                                             │
└─────────────────────────────────────────────────────────────┘

PyTorch 模型转换 #

基本转换 #

python
import torch
import torch.onnx

class SimpleModel(torch.nn.Module):
    """Small CNN: conv -> batch norm -> ReLU -> global average pool -> linear (10 logits).

    Takes 3-channel image batches; the adaptive pooling makes the output
    independent of the input height/width.
    """

    def __init__(self):
        super().__init__()
        # Attribute names become state_dict keys (conv1.weight, bn1.running_mean, ...);
        # keep them stable so saved checkpoints still load.
        self.conv1 = torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
        self.bn1 = torch.nn.BatchNorm2d(num_features=64)
        self.relu = torch.nn.ReLU()
        self.pool = torch.nn.AdaptiveAvgPool2d(output_size=1)
        self.fc = torch.nn.Linear(in_features=64, out_features=10)

    def forward(self, x):
        pooled = self.pool(self.relu(self.bn1(self.conv1(x))))
        # Pooled features are (N, 64, 1, 1); collapse everything after the batch dim.
        batch_size = pooled.size(0)
        return self.fc(pooled.view(batch_size, -1))

model = SimpleModel()
model.eval()  # export inference behaviour, not training behaviour

# Tracing input: its shape is what gets recorded in the graph (no dynamic axes here).
dummy_input = torch.randn(1, 3, 224, 224)

torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    opset_version=17,
    input_names=["input"],
    output_names=["output"],
    verbose=True,
)

export 参数详解 #

text
┌─────────────────────────────────────────────────────────────┐
│                    torch.onnx.export 参数                    │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  必需参数:                                                 │
│  ├── model: 要导出的 PyTorch 模型                          │
│  ├── args: 模型输入(张量或张量元组)                       │
│  └── f: 输出文件路径或文件对象                             │
│                                                             │
│  常用可选参数:                                             │
│  ├── input_names: 输入张量名称列表                         │
│  ├── output_names: 输出张量名称列表                        │
│  ├── opset_version: ONNX 算子集版本                        │
│  ├── dynamic_axes: 动态维度定义                            │
│  ├── verbose: 是否打印详细信息                             │
│  ├── training: 导出模式 (TrainingMode)                     │
│  ├── do_constant_folding: 是否常量折叠                     │
│  ├── export_params: 是否导出参数                           │
│  └── custom_opsets: 自定义算子集                           │
│                                                             │
└─────────────────────────────────────────────────────────────┘

动态形状 #

python
import torch

model = SimpleModel()
model.eval()

dummy_input = torch.randn(1, 3, 224, 224)

# Symbolic names for the axes that may vary at runtime; the dummy input only
# supplies concrete values for them during tracing.
dynamic_axes = {
    "input": {0: "batch_size", 2: "height", 3: "width"},
    "output": {0: "batch_size"},
}

torch.onnx.export(
    model,
    dummy_input,
    "model_dynamic.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes=dynamic_axes,
    opset_version=17,
)

多输入多输出 #

python
class MultiInputModel(torch.nn.Module):
    """Toy two-input / two-output module returning (x1 + x2, x1 * x2)."""

    def forward(self, x1, x2):
        summed = torch.add(x1, x2)
        multiplied = torch.mul(x1, x2)
        return summed, multiplied

model = MultiInputModel()
model.eval()

dummy_input1 = torch.randn(1, 3, 224, 224)
dummy_input2 = torch.randn(1, 3, 224, 224)

# Every graph input and output shares a dynamic batch dimension.
io_names = ["input1", "input2", "output1", "output2"]

torch.onnx.export(
    model,
    (dummy_input1, dummy_input2),  # multiple inputs are passed as a tuple
    "multi_io.onnx",
    input_names=io_names[:2],
    output_names=io_names[2:],
    dynamic_axes={name: {0: "batch_size"} for name in io_names},
    opset_version=17,
)

导出验证 #

python
import torch
import onnx
import onnxruntime as ort
import numpy as np

model = SimpleModel()
model.eval()

dummy_input = torch.randn(1, 3, 224, 224)

torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
    opset_version=17
)

# Structural check of the exported graph; raises if the model is invalid.
onnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model)
print("ONNX 模型验证通过")

# Pass the execution provider explicitly: since onnxruntime 1.9, builds that ship
# providers besides CPU (e.g. onnxruntime-gpu) refuse to create a session without it.
session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

# Feed by the input name recorded in the graph instead of repeating the literal.
input_name = session.get_inputs()[0].name
input_numpy = dummy_input.numpy()
outputs = session.run(None, {input_name: input_numpy})

with torch.no_grad():
    torch_output = model(dummy_input)

# Compare the first ONNX output with the PyTorch reference within float tolerance.
np.testing.assert_allclose(
    torch_output.numpy(),
    outputs[0],
    rtol=1e-3,
    atol=1e-5
)
print("输出一致性验证通过")

TensorFlow 模型转换 #

使用 tf2onnx #

bash
pip install tf2onnx onnx onnxruntime

从 SavedModel 转换 #

python
import subprocess
import sys

import tensorflow as tf
import tf2onnx

model = tf.keras.applications.ResNet50(weights=None, input_shape=(224, 224, 3))

# Write a TF SavedModel directory. Under Keras 3 (TF >= 2.16) model.save() only
# accepts ".keras"/".h5" paths; model.export() produces a SavedModel (TF >= 2.13).
model.export("saved_model")

# Run the converter with the *current* interpreter (a bare "python" may resolve to a
# different environment) and raise CalledProcessError if the conversion fails.
subprocess.run([
    sys.executable, "-m", "tf2onnx.convert",
    "--saved-model", "saved_model",
    "--output", "model.onnx",
    "--opset", "17"
], check=True)

从 Keras 模型转换 #

python
import tensorflow as tf
import tf2onnx
import onnx

# Small CNN with an explicit Input layer (equivalent to passing input_shape to the
# first layer): 224x224 images with 3 channels in the last axis.
model = tf.keras.Sequential(
    [
        tf.keras.Input(shape=(224, 224, 3)),
        tf.keras.layers.Conv2D(64, 3, padding='same'),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.ReLU(),
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(10),
    ]
)

# Leading None = dynamic batch dimension; the spec name is used for the ONNX input.
input_signature = [tf.TensorSpec([None, 224, 224, 3], tf.float32, name='input')]

# NOTE(review): tf2onnx.convert.from_keras targets tf.keras (Keras 2); with Keras 3
# it may need the tf_keras package — confirm for your TF/tf2onnx versions.
onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature=input_signature, opset=17)

onnx.save(onnx_model, "keras_model.onnx")

从 Concrete Function 转换 #

python
import tensorflow as tf
import tf2onnx
import onnx


class MyModel(tf.Module):
    """tf.Module whose traced __call__ applies a single 3x3 convolution."""

    def __init__(self):
        super().__init__()
        # Create the layer once so its weights are owned (and tracked) by the module
        # instead of being built anew inside the traced function.
        self.conv = tf.keras.layers.Conv2D(64, 3)

    @tf.function(input_signature=[tf.TensorSpec([None, 224, 224, 3], tf.float32)])
    def __call__(self, x):
        return self.conv(x)


model = MyModel()

# from_function expects the tf.function itself: it calls get_concrete_function()
# with the given input_signature internally, so a ConcreteFunction cannot be passed.
onnx_model, _ = tf2onnx.convert.from_function(
    model.__call__,
    input_signature=[tf.TensorSpec([None, 224, 224, 3], tf.float32, name='input')],
    opset=17
)

onnx.save(onnx_model, "tf_model.onnx")

命令行转换 #

bash
# SavedModel directory
python -m tf2onnx.convert --saved-model path/to/saved_model --output model.onnx

# Frozen GraphDef (.pb): tf2onnx's flag is --graphdef; inputs/outputs are tensor names (op:index)
python -m tf2onnx.convert --graphdef frozen.pb --output model.onnx --inputs input:0 --outputs output:0

# TF1 checkpoint: pass the checkpoint's .meta graph file
python -m tf2onnx.convert --checkpoint model.ckpt.meta --output model.onnx --inputs input:0 --outputs output:0

scikit-learn 模型转换 #

使用 skl2onnx #

bash
pip install skl2onnx onnxmltools

转换示例 #

python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
import onnxruntime as ort
import numpy as np

iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

model = RandomForestClassifier(n_estimators=100)
model.fit(X_train, y_train)

# One float32 input with a dynamic batch dimension (None) and 4 features.
initial_type = [('float_input', FloatTensorType([None, 4]))]

onnx_model = convert_sklearn(
    model,
    initial_types=initial_type,
    target_opset=17
)

with open("rf_iris.onnx", "wb") as f:
    f.write(onnx_model.SerializeToString())

# Explicit provider: required since onnxruntime 1.9 on builds with non-CPU providers.
session = ort.InferenceSession("rf_iris.onnx", providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name  # skl2onnx classifiers list the label output first

# The graph input was declared float32, so cast the float64 sklearn data.
predictions = session.run([output_name], {input_name: X_test.astype(np.float32)})[0]

print(f"sklearn 预测: {model.predict(X_test[:5])}")
print(f"ONNX 预测: {predictions[:5]}")

支持的模型 #

text
┌─────────────────────────────────────────────────────────────┐
│              skl2onnx 支持的模型类型                         │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  分类模型:                                                 │
│  ├── LogisticRegression                                    │
│  ├── SVC, NuSVC                                            │
│  ├── RandomForestClassifier                                │
│  ├── GradientBoostingClassifier                            │
│  ├── DecisionTreeClassifier                                │
│  ├── KNeighborsClassifier                                  │
│  └── MLPClassifier                                         │
│                                                             │
│  回归模型:                                                 │
│  ├── LinearRegression                                      │
│  ├── SVR, NuSVR                                            │
│  ├── RandomForestRegressor                                 │
│  ├── GradientBoostingRegressor                             │
│  └── DecisionTreeRegressor                                 │
│                                                             │
│  聚类模型:                                                 │
│  ├── KMeans                                                │
│  └── MiniBatchKMeans                                       │
│                                                             │
│  预处理:                                                   │
│  ├── StandardScaler                                        │
│  ├── MinMaxScaler                                          │
│  ├── Normalizer                                            │
│  └── PCA                                                   │
│                                                             │
└─────────────────────────────────────────────────────────────┘

其他框架转换 #

JAX 转换 #

python
import jax.numpy as jnp
import onnx
import tensorflow as tf
import tf2onnx
from jax.experimental import jax2tf

# JAX has no built-in ONNX exporter (there is no jax.experimental.jax2onnx).
# A common route is JAX -> TensorFlow (jax2tf) -> ONNX (tf2onnx).


def model(x):
    return jnp.maximum(x, 0)


spec = [tf.TensorSpec([1, 3, 224, 224], tf.float32, name="input")]

# tf2onnx needs plain TF ops, so disable XLA-specific lowering and native
# (StableHLO) serialization.
# NOTE(review): these flags vary across jax2tf releases (native_serialization=False
# is deprecated in recent JAX) — confirm against the installed JAX version.
tf_fn = tf.function(
    jax2tf.convert(model, native_serialization=False, enable_xla=False),
    input_signature=spec,
    autograph=False,
)

onnx_model, _ = tf2onnx.convert.from_function(tf_fn, input_signature=spec, opset=17)

# ModelProto has no .save(); serialize with onnx.save.
onnx.save(onnx_model, "jax_model.onnx")

MXNet 转换 #

python
import mxnet as mx
import numpy as np
from mxnet.contrib import onnx as onnx_mxnet

# Symbol graph (JSON) and trained parameters produced by MXNet's model export.
sym = 'model-symbol.json'
params = 'model-0000.params'

onnx_file = 'mxnet_model.onnx'

# One shape per model input.
input_shape = (1, 3, 224, 224)

# np.float32 is the input dtype, so numpy must be imported in this snippet.
# NOTE(review): newer MXNet releases expose this as mx.onnx.export_model — confirm
# for your MXNet version.
onnx_mxnet.export_model(sym, params, [input_shape], np.float32, onnx_file)

转换最佳实践 #

1. 选择合适的 Opset 版本 #

python
import torch
import onnx
from onnx import defs

# onnx_opset_version() reports the newest opset the *installed onnx package* knows.
# It is not the exporter's default. Choose an opset supported by both the exporter
# and the target inference engine, and pin it explicitly.
current_opset = defs.onnx_opset_version()
print(f"当前 onnx 包支持的最高 opset: {current_opset}")

# Reuses `model` / `dummy_input` from the earlier snippets. Naming the graph I/O keeps
# the file usable by code that feeds {"input": ...}.
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
    opset_version=17
)

2. 处理动态形状 #

python
# Two common dynamic-shape presets: only the batch dimension varies, or the
# batch plus spatial (height/width) dimensions of the input vary.
dynamic_axes_options = {
    "batch_only": {
        "input": {0: "batch_size"},
        "output": {0: "batch_size"}
    },
    "full_dynamic": {
        "input": {
            0: "batch_size",
            2: "height",
            3: "width"
        },
        "output": {0: "batch_size"}
    }
}

# Keys of each preset must match input_names / output_names below.
# Pin the opset like the other exports, so the result does not depend on the
# installed PyTorch's default.
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    dynamic_axes=dynamic_axes_options["batch_only"],
    input_names=["input"],
    output_names=["output"],
    opset_version=17
)

3. 自定义算子处理 #

python
import torch


class CustomOp(torch.autograd.Function):
    """Toy autograd Function (y = 2 * x) exported as a custom ONNX node.

    For a torch.autograd.Function, the exporter uses the Function's static
    `symbolic` method. register_custom_op_symbolic only applies to ATen/TorchScript
    ops named "namespace::op" and cannot target a Function like this one.
    """

    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def symbolic(g, x):
        # Emit a node in the custom domain (its version is declared via custom_opsets).
        # Copy x's type onto the output so downstream shape/type inference has a type.
        return g.op("CustomDomain::CustomOp", x).setType(x.type())


class CustomModel(torch.nn.Module):
    def forward(self, x):
        return CustomOp.apply(x)


model = CustomModel()
# NOTE(review): `symbolic` is honoured by the TorchScript-based exporter. If your
# PyTorch version's torch.onnx.export defaults to the dynamo-based exporter, pass
# dynamo=False (or use that exporter's custom-op mechanism) — confirm for your version.
torch.onnx.export(
    model,
    torch.randn(1, 10),
    "custom_op.onnx",
    opset_version=17,
    custom_opsets={"CustomDomain": 1}
)

4. 模型简化 #

python
import onnx
from onnxsim import simplify

model = onnx.load("model.onnx")

# simplify() returns the simplified model and whether its self-check passed.
simplified, is_valid = simplify(model)
model_simplified, check = simplified, is_valid

if not check:
    print("模型简化验证失败")
else:
    onnx.save(model_simplified, "model_simplified.onnx")
    print("模型简化成功")

常见问题 #

问题 1:算子不支持 #

text
错误信息:
RuntimeError: ONNX export failed: Couldn't export operator xxx

解决方案:
1. 检查算子是否在当前 opset 版本中支持
2. 使用更高版本的 opset
3. 自定义算子导出
4. 替换为支持的等价操作

问题 2:动态形状问题 #

text
错误信息:
The shape of input does not match the model

解决方案:
1. 正确定义 dynamic_axes
2. 在 TensorFlow(TensorSpec)或 skl2onnx(FloatTensorType)中用 None 表示动态维度;PyTorch 则通过 dynamic_axes 指定
3. 检查输入形状是否匹配

问题 3:精度问题 #

python
import torch
import onnxruntime as ort
import numpy as np

def compare_outputs(model, onnx_path, input_tensor, rtol=1e-3, atol=1e-5, input_name=None):
    """Check that a PyTorch model and its ONNX export agree on one input.

    Args:
        model: the torch.nn.Module that was exported; it is switched to eval mode.
        onnx_path: path of the exported .onnx file.
        input_tensor: the single input fed to both models (any device).
        rtol: relative tolerance for np.testing.assert_allclose.
        atol: absolute tolerance for np.testing.assert_allclose.
        input_name: graph input to feed. Defaults to the session's first input
            (a single-input export named "input" behaves exactly as before).

    Returns:
        True if the first outputs match within tolerance, otherwise False
        (the mismatch report is printed).
    """
    model.eval()
    with torch.no_grad():
        torch_output = model(input_tensor)

    # Explicit provider: required since onnxruntime 1.9 on builds with non-CPU providers.
    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    if input_name is None:
        input_name = session.get_inputs()[0].name
    # onnxruntime consumes host numpy arrays; .cpu() also makes CUDA tensors work.
    onnx_output = session.run(None, {input_name: input_tensor.detach().cpu().numpy()})[0]

    try:
        np.testing.assert_allclose(
            torch_output.detach().cpu().numpy(),
            onnx_output,
            rtol=rtol,
            atol=atol
        )
    except AssertionError as e:
        print(f"输出不一致: {e}")
        return False
    print("输出一致")
    return True

问题 4:BatchNorm 问题 #

python
# BatchNorm must be exported in inference mode, so the graph normalizes with the
# stored running_mean / running_var instead of per-batch statistics.
# model.eval() is what switches BatchNorm to its running statistics. Do NOT set
# track_running_stats = False on a trained model as a "fix": depending on the
# PyTorch version it is either ignored in eval mode or makes BatchNorm fall back
# to per-batch statistics, silently changing the exported numerics.
model.eval()

torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
    opset_version=17,
    # Explicitly request inference-mode export (also the exporter's default).
    training=torch.onnx.TrainingMode.EVAL,
)

下一步 #

现在你已经掌握了模型转换,接下来学习 Python API,深入了解 ONNX 的 Python 接口!

最后更新:2026-04-04