模型转换 #
转换概述 #
ONNX 作为中间格式,支持从多种深度学习框架导出模型。
text
┌─────────────────────────────────────────────────────────────┐
│ 模型转换流程 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 训练框架 ONNX 推理引擎 │
│ ┌─────────┐ ┌─────┐ ┌─────┐│
│ │PyTorch │──┐ │ │ ┌──│ORT ││
│ └─────────┘ │ │ │ │ └─────┘│
│ ┌─────────┐ │ │ │ │ ┌─────┐│
│ │TensorFlow│──┼──────────>│ONNX │──────────┼──│TRT ││
│ └─────────┘ │ │ │ │ └─────┘│
│ ┌─────────┐ │ │ │ │ ┌─────┐│
│ │ JAX │──┘ │ │ └──│OV ││
│ └─────────┘ └─────┘ └─────┘│
│ │
└─────────────────────────────────────────────────────────────┘
PyTorch 模型转换 #
基本转换 #
python
import torch
import torch.onnx
class SimpleModel(torch.nn.Module):
    """Minimal CNN: 3x3 conv + BatchNorm + ReLU, global average pool, linear head.

    Accepts (N, 3, H, W) float input and returns (N, 10) logits.
    """

    def __init__(self):
        super().__init__()
        # Sub-module attribute names are kept stable (state_dict keys).
        self.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1 = torch.nn.BatchNorm2d(64)
        self.relu = torch.nn.ReLU()
        self.pool = torch.nn.AdaptiveAvgPool2d(1)
        self.fc = torch.nn.Linear(64, 10)

    def forward(self, x):
        # Feature extraction: conv -> batch-norm -> activation.
        features = self.relu(self.bn1(self.conv1(x)))
        # Collapse spatial dims to 1x1, flatten, then classify.
        pooled = self.pool(features)
        return self.fc(pooled.flatten(1))
# Instantiate the model and switch it to inference mode; eval() makes
# BatchNorm use its running statistics, which is required for a stable export.
model = SimpleModel()
model.eval()
# Example input: the exporter traces the model once with this tensor.
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["input"],      # names assigned to the graph inputs
    output_names=["output"],    # names assigned to the graph outputs
    opset_version=17,           # ONNX operator-set version to target
    verbose=True                # print the exported graph description
)
export 参数详解 #
text
┌─────────────────────────────────────────────────────────────┐
│ torch.onnx.export 参数 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 必需参数: │
│ ├── model: 要导出的 PyTorch 模型 │
│ ├── args: 模型输入(张量或张量元组) │
│ └── f: 输出文件路径或文件对象 │
│ │
│ 常用可选参数: │
│ ├── input_names: 输入张量名称列表 │
│ ├── output_names: 输出张量名称列表 │
│ ├── opset_version: ONNX 算子集版本 │
│ ├── dynamic_axes: 动态维度定义 │
│ ├── verbose: 是否打印详细信息 │
│ ├── training: 导出模式 (TrainingMode) │
│ ├── do_constant_folding: 是否常量折叠 │
│ ├── export_params: 是否导出参数 │
│ └── custom_opsets: 自定义算子集 │
│ │
└─────────────────────────────────────────────────────────────┘
动态形状 #
python
import torch

# Re-create the model in eval mode; the dummy input fixes the example shape
# used for tracing (the dynamic_axes mapping below relaxes it afterwards).
# NOTE(review): SimpleModel is defined in the earlier snippet on this page.
model = SimpleModel()
model.eval()
dummy_input = torch.randn(1, 3, 224, 224)
# Declare which dimensions may vary at inference time.  Keys are the
# input/output names passed to export; values map axis index -> symbolic name.
dynamic_axes = {
    "input": {
        0: "batch_size",   # variable batch
        2: "height",       # variable spatial height
        3: "width"         # variable spatial width
    },
    "output": {
        0: "batch_size"
    }
}
torch.onnx.export(
    model,
    dummy_input,
    "model_dynamic.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes=dynamic_axes,
    opset_version=17
)
多输入多输出 #
python
class MultiInputModel(torch.nn.Module):
    """Toy two-input model returning the element-wise sum and product."""

    def forward(self, x1, x2):
        total = x1 + x2
        product = x1 * x2
        return total, product
model = MultiInputModel()
model.eval()
# Two example inputs; the exporter maps them to the two forward() arguments.
dummy_input1 = torch.randn(1, 3, 224, 224)
dummy_input2 = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    model,
    (dummy_input1, dummy_input2),   # tuple -> multiple graph inputs
    "multi_io.onnx",
    input_names=["input1", "input2"],
    output_names=["output1", "output2"],
    # Allow a variable batch dimension on every input and output.
    dynamic_axes={
        "input1": {0: "batch_size"},
        "input2": {0: "batch_size"},
        "output1": {0: "batch_size"},
        "output2": {0: "batch_size"}
    },
    opset_version=17
)
导出验证 #
python
import torch
import onnx
import onnxruntime as ort
import numpy as np
# Export, then verify the ONNX file in two steps:
#   1) structural check with onnx.checker
#   2) numerical comparison between PyTorch and ONNX Runtime outputs
model = SimpleModel()
model.eval()
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
    opset_version=17
)
# Structural validation: raises if the graph or opset usage is malformed.
onnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model)
print("ONNX 模型验证通过")
# Run the exported model with ONNX Runtime on the same example input.
session = ort.InferenceSession("model.onnx")
input_numpy = dummy_input.numpy()
outputs = session.run(None, {"input": input_numpy})
# Reference output from the original PyTorch model (no gradients needed).
with torch.no_grad():
    torch_output = model(dummy_input)
# Tolerances allow for minor floating-point differences between runtimes.
np.testing.assert_allclose(
    torch_output.numpy(),
    outputs[0],
    rtol=1e-3,
    atol=1e-5
)
print("输出一致性验证通过")
TensorFlow 模型转换 #
使用 tf2onnx #
bash
pip install tf2onnx onnx onnxruntime
从 SavedModel 转换 #
python
import subprocess
import sys

import tensorflow as tf
import tf2onnx

# Build an (untrained) ResNet50 and save it in TensorFlow SavedModel format.
model = tf.keras.applications.ResNet50(weights=None, input_shape=(224, 224, 3))
model.save("saved_model")

# Invoke the tf2onnx CLI on the SavedModel directory.
# - sys.executable guarantees the same interpreter/venv as this script,
#   instead of whatever "python" resolves to on PATH.
# - check=True raises CalledProcessError on a failed conversion instead of
#   silently ignoring the non-zero exit status.
subprocess.run([
    sys.executable, "-m", "tf2onnx.convert",
    "--saved-model", "saved_model",
    "--output", "model.onnx",
    "--opset", "17"
], check=True)
从 Keras 模型转换 #
python
import tensorflow as tf
import tf2onnx
import onnx

# A small Keras CNN: conv + BN + ReLU + global pooling + dense head.
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(64, 3, padding='same', input_shape=(224, 224, 3)),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.ReLU(),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(10)
])
# A TensorSpec with a leading None gives the ONNX model a dynamic batch axis.
input_signature = [tf.TensorSpec([None, 224, 224, 3], tf.float32, name='input')]
# In-process conversion (no CLI round-trip); returns an onnx ModelProto.
onnx_model, _ = tf2onnx.convert.from_keras(
    model,
    input_signature=input_signature,
    opset=17
)
onnx.save(onnx_model, "keras_model.onnx")
从 Concrete Function 转换 #
python
import tensorflow as tf
import tf2onnx
import onnx  # fixed: onnx.save below needs this import (NameError otherwise)


class MyModel(tf.Module):
    """tf.Module wrapping a single Conv2D, exported via a concrete function."""

    def __init__(self):
        super().__init__()
        # Create the layer ONCE here.  Creating it inside the tf.function
        # body (as the original snippet did) allocates fresh variables on
        # every trace, which TF2 rejects ("tried to create variables on a
        # non-first call") and would give each trace different random weights.
        self.conv = tf.keras.layers.Conv2D(64, 3)

    @tf.function(input_signature=[tf.TensorSpec([None, 224, 224, 3], tf.float32)])
    def __call__(self, x):
        return self.conv(x)


model = MyModel()
# Materialize the traced graph for the declared input signature.
concrete_func = model.__call__.get_concrete_function()
onnx_model, _ = tf2onnx.convert.from_function(
    concrete_func,
    input_signature=[tf.TensorSpec([None, 224, 224, 3], tf.float32, name='input')],
    opset=17
)
onnx.save(onnx_model, "tf_model.onnx")
命令行转换 #
bash
python -m tf2onnx.convert --saved-model path/to/saved_model --output model.onnx
python -m tf2onnx.convert --frozen-graph frozen.pb --output model.onnx --inputs input:0 --outputs output:0
python -m tf2onnx.convert --checkpoint checkpoint.ckpt --output model.onnx --inputs input:0 --outputs output:0
scikit-learn 模型转换 #
使用 skl2onnx #
bash
pip install skl2onnx onnxmltools
转换示例 #
python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
import onnxruntime as ort
import numpy as np

# Train a small random forest on the iris dataset (4 float features).
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
model = RandomForestClassifier(n_estimators=100)
model.fit(X_train, y_train)
# Declare the ONNX graph input: name + element type/shape (None = dynamic batch).
initial_type = [('float_input', FloatTensorType([None, 4]))]
onnx_model = convert_sklearn(
    model,
    initial_types=initial_type,
    target_opset=17
)
# Serialize the ModelProto to disk.
with open("rf_iris.onnx", "wb") as f:
    f.write(onnx_model.SerializeToString())
# Sanity check: run the ONNX model and compare with sklearn's predictions.
session = ort.InferenceSession("rf_iris.onnx")
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
predictions = session.run([output_name], {input_name: X_test.astype(np.float32)})[0]
print(f"sklearn 预测: {model.predict(X_test[:5])}")
print(f"ONNX 预测: {predictions[:5]}")
支持的模型 #
text
┌─────────────────────────────────────────────────────────────┐
│ skl2onnx 支持的模型类型 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 分类模型: │
│ ├── LogisticRegression │
│ ├── SVC, NuSVC │
│ ├── RandomForestClassifier │
│ ├── GradientBoostingClassifier │
│ ├── DecisionTreeClassifier │
│ ├── KNeighborsClassifier │
│ └── MLPClassifier │
│ │
│ 回归模型: │
│ ├── LinearRegression │
│ ├── SVR, NuSVR │
│ ├── RandomForestRegressor │
│ ├── GradientBoostingRegressor │
│ └── DecisionTreeRegressor │
│ │
│ 聚类模型: │
│ ├── KMeans │
│ └── MiniBatchKMeans │
│ │
│ 预处理: │
│ ├── StandardScaler │
│ ├── MinMaxScaler │
│ ├── Normalizer │
│ └── PCA │
│ │
└─────────────────────────────────────────────────────────────┘
其他框架转换 #
JAX 转换 #
python
import jax
import jax.numpy as jnp
# NOTE(review): upstream JAX does not ship jax2onnx under jax.experimental --
# jax2onnx is a separate third-party package (`pip install jax2onnx`, then
# `import jax2onnx`).  Verify this import path against the installed version.
from jax.experimental import jax2onnx

def model(x):
    # Simple ReLU, written with jnp so it is traceable under jax.jit.
    return jnp.maximum(x, 0)

# Convert the jitted function, tracing with a (1, 3, 224, 224) example input.
onnx_model = jax2onnx.convert(
    jax.jit(model),
    jnp.ones((1, 3, 224, 224))
)
# NOTE(review): this assumes the returned object exposes .save(); if it is a
# plain onnx ModelProto, use onnx.save(onnx_model, path) instead -- confirm.
onnx_model.save("jax_model.onnx")
MXNet 转换 #
python
import mxnet as mx
import numpy as np  # fixed: the original snippet used np.float32 without importing numpy
from mxnet.contrib import onnx as onnx_mxnet

# Paths to the serialized MXNet symbol/params produced by a checkpoint save.
sym = 'model-symbol.json'
params = 'model-0000.params'
onnx_file = 'mxnet_model.onnx'
# Fixed example input shape used for the export.
input_shape = (1, 3, 224, 224)

# Convert symbol + params to an ONNX file; the dtype argument declares the
# input element type of the exported graph.
onnx_mxnet.export_model(sym, params, [input_shape], np.float32, onnx_file)
转换最佳实践 #
1. 选择合适的 Opset 版本 #
python
import torch
import onnx
from onnx import defs

# Highest opset version known to the installed onnx package.
current_opset = defs.onnx_opset_version()
print(f"当前 ONNX 默认 opset: {current_opset}")
# Pin the opset explicitly instead of relying on torch's default, so exports
# are reproducible across torch versions.
# NOTE(review): `model` and `dummy_input` come from the earlier snippets on
# this page; this fragment is not runnable stand-alone.
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    opset_version=17
)
2. 处理动态形状 #
python
# Two common dynamic-axes presets: batch-only (fixed spatial size) and
# fully dynamic (batch, height, and width may all vary at inference time).
dynamic_axes_options = {
    "batch_only": {
        "input": {0: "batch_size"},
        "output": {0: "batch_size"}
    },
    "full_dynamic": {
        "input": {
            0: "batch_size",
            2: "height",
            3: "width"
        },
        "output": {0: "batch_size"}
    }
}
# Export with the batch-only preset; swap the key to "full_dynamic" to also
# allow variable input resolution.
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    dynamic_axes=dynamic_axes_options["batch_only"],
    input_names=["input"],
    output_names=["output"]
)
3. 自定义算子处理 #
python
import torch


class CustomOp(torch.autograd.Function):
    """Doubles its input; exported to ONNX as CustomDomain::CustomOp."""

    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def symbolic(g, x):
        # For torch.autograd.Function subclasses the exporter looks up a
        # `symbolic` staticmethod on the Function itself.  The original
        # snippet called register_custom_op_symbolic("CustomOp", ...), but
        # that API targets TorchScript/ATen ops and requires a
        # "domain::op_name" string, so it fails with an invalid symbolic
        # name error for a bare "CustomOp".
        return g.op("CustomDomain::CustomOp", x)


class CustomModel(torch.nn.Module):
    """Wraps CustomOp so it can be traced/exported like a normal module."""

    def forward(self, x):
        return CustomOp.apply(x)


model = CustomModel()
torch.onnx.export(
    model,
    torch.randn(1, 10),
    "custom_op.onnx",
    # Declare the opset version of the custom domain used by the symbolic.
    custom_opsets={"CustomDomain": 1}
)
4. 模型简化 #
python
import onnx
from onnxsim import simplify

# onnx-simplifier folds constants and removes redundant nodes; the returned
# `check` flag reports whether the simplified graph was numerically validated
# against the original on random inputs.
model = onnx.load("model.onnx")
model_simplified, check = simplify(model)
if check:
    onnx.save(model_simplified, "model_simplified.onnx")
    print("模型简化成功")
else:
    print("模型简化验证失败")
常见问题 #
问题 1:算子不支持 #
text
错误信息:
RuntimeError: ONNX export failed: Couldn't export operator xxx
解决方案:
1. 检查算子是否在当前 opset 版本中支持
2. 使用更高版本的 opset
3. 自定义算子导出
4. 替换为支持的等价操作
问题 2:动态形状问题 #
text
错误信息:
The shape of input does not match the model
解决方案:
1. 正确定义 dynamic_axes
2. 使用 None 作为动态维度
3. 检查输入形状是否匹配
问题 3:精度问题 #
python
import torch
import onnxruntime as ort
import numpy as np
def compare_outputs(model, onnx_path, input_tensor, rtol=1e-3, atol=1e-5):
    """Compare a PyTorch model's output with its ONNX export.

    Args:
        model: source torch.nn.Module (switched to eval mode here).
        onnx_path: path to the exported .onnx file.
        input_tensor: example input tensor fed to both runtimes.
        rtol: relative tolerance for np.testing.assert_allclose.
        atol: absolute tolerance for np.testing.assert_allclose.

    Returns:
        True when the first ONNX output matches the PyTorch output within
        tolerance, False otherwise.
    """
    model.eval()
    with torch.no_grad():
        torch_output = model(input_tensor)
    session = ort.InferenceSession(onnx_path)
    # Read the input name from the session instead of hard-coding "input",
    # so the check also works for models exported with other input_names.
    input_name = session.get_inputs()[0].name
    onnx_output = session.run(None, {input_name: input_tensor.numpy()})[0]
    try:
        np.testing.assert_allclose(
            torch_output.numpy(),
            onnx_output,
            rtol=rtol,
            atol=atol
        )
        print("输出一致")
        return True
    except AssertionError as e:
        print(f"输出不一致: {e}")
        return False
问题 4:BatchNorm 问题 #
python
# Workaround snippet for BatchNorm-related export discrepancies.
model.eval()
for module in model.modules():
    if isinstance(module, torch.nn.BatchNorm2d):
        # NOTE(review): with track_running_stats=False, BatchNorm normalizes
        # with the CURRENT batch statistics even in eval mode, which changes
        # inference behavior.  For most exports, model.eval() alone (using
        # the stored running stats) is the correct setting -- confirm this
        # workaround is really intended before shipping.
        module.track_running_stats = False
torch.onnx.export(model, dummy_input, "model.onnx")
下一步 #
现在你已经掌握了模型转换,接下来学习 Python API,深入了解 ONNX 的 Python 接口!
最后更新:2026-04-04