Model Optimization #
Optimization Overview #
ONNX model optimization applies graph-transformation techniques to cut computation and memory use, improving inference performance.
text
┌─────────────────────────────────────────────────────────────┐
│                     Optimization Types                      │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Operator-level optimizations:                              │
│  ├── Operator fusion - merge multiple operators             │
│  ├── Operator replacement - swap in more efficient ops      │
│  └── Operator simplification - simplify the computation     │
│                                                             │
│  Graph-level optimizations:                                 │
│  ├── Constant folding - compute constants at compile time   │
│  ├── Dead code elimination - remove unused nodes            │
│  ├── Common subexpression elimination - reuse results       │
│  └── Memory optimization - optimize allocations             │
│                                                             │
│  Data-level optimizations:                                  │
│  ├── Quantization - reduce numeric precision                │
│  ├── Pruning - remove redundant connections                 │
│  └── Knowledge distillation - compress the model            │
│                                                             │
└─────────────────────────────────────────────────────────────┘
ONNX Optimizer #
Installation #
bash
pip install onnxoptimizer
Basic Usage #
python
import onnx
# Note: the optimizer was removed from the onnx package itself;
# use the standalone onnxoptimizer package installed above
import onnxoptimizer

model = onnx.load("model.onnx")

# List every optimization pass that onnxoptimizer ships with
all_passes = onnxoptimizer.get_available_passes()
print(f"Available passes: {all_passes}")

# Apply a curated subset of passes
passes = [
    "eliminate_identity",
    "eliminate_nop_transpose",
    "eliminate_nop_pad",
    "eliminate_unused_initializer",
    "fuse_bn_into_conv",
    "fuse_consecutive_transposes",
    "fuse_consecutive_squeezes",
    "fuse_consecutive_reduce_unsqueeze",
    "extract_constant_to_initializer"
]
optimized_model = onnxoptimizer.optimize(model, passes)
onnx.save(optimized_model, "model_optimized.onnx")
Optimization Passes in Detail #
text
┌─────────────────────────────────────────────────────────────┐
│                     Elimination Passes                      │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  eliminate_identity:                                        │
│    Removes Identity nodes                                   │
│    X ──> Identity ──> Y   becomes   X ──> Y                 │
│                                                             │
│  eliminate_nop_transpose:                                   │
│    Removes no-op Transpose (perm = [0,1,2,...])             │
│                                                             │
│  eliminate_nop_pad:                                         │
│    Removes no-op Pad (all pads are 0)                       │
│                                                             │
│  eliminate_unused_initializer:                              │
│    Removes initializers that are never used                 │
│                                                             │
│  eliminate_deadend:                                         │
│    Removes nodes whose outputs are never used               │
│                                                             │
│  eliminate_duplicate_initializer:                           │
│    Merges duplicate initializers                            │
│                                                             │
└─────────────────────────────────────────────────────────────┘
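To see an elimination pass in action, here is a minimal sketch that builds a toy graph containing an Identity node and runs eliminate_identity on it; the graph, tensor names, and opset choice are made up for illustration:
python
import onnx
import onnxoptimizer
from onnx import helper, TensorProto

# Toy graph: x ──> Identity ──> Relu ──> y
inp = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 4])
out = helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 4])
identity = helper.make_node("Identity", ["x"], ["x_id"])
relu = helper.make_node("Relu", ["x_id"], ["y"])
graph = helper.make_graph([identity, relu], "demo", [inp], [out])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])

print([n.op_type for n in model.graph.node])      # ['Identity', 'Relu']
optimized = onnxoptimizer.optimize(model, ["eliminate_identity"])
print([n.op_type for n in optimized.graph.node])  # ['Relu']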
text
┌─────────────────────────────────────────────────────────────┐
│                        Fusion Passes                        │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  fuse_bn_into_conv:                                         │
│    Folds BatchNorm into the preceding Conv                  │
│                                                             │
│    Conv ──> BN ──>   becomes   Conv_fused ──>               │
│                                                             │
│    Fusion formulas:                                         │
│    W_fused = W * (gamma / sqrt(var + eps))                  │
│    B_fused = (B - mean) * gamma / sqrt(var + eps) + beta    │
│                                                             │
│  fuse_consecutive_transposes:                               │
│    Merges consecutive Transpose nodes                       │
│                                                             │
│    Transpose(perm1) ──> Transpose(perm2)                    │
│    becomes Transpose(combine(perm1, perm2))                 │
│                                                             │
│  fuse_consecutive_squeezes:                                 │
│    Merges consecutive Squeeze nodes                         │
│                                                             │
│  fuse_consecutive_reduce_unsqueeze:                         │
│    Merges Reduce + Unsqueeze pairs                          │
│                                                             │
│  fuse_add_bias_into_conv:                                   │
│    Folds a bias Add into the preceding Conv                 │
│                                                             │
└─────────────────────────────────────────────────────────────┘
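The BN fusion formulas above are easy to sanity-check numerically. A small sketch with random per-channel parameters, using a 1x1 convolution so the Conv reduces to a matrix product:
python
import numpy as np

# Random Conv (1x1, so it is just a per-channel linear map) and BN params
cout, cin = 8, 4
W = np.random.randn(cout, cin)
B = np.random.randn(cout)
gamma, beta = np.random.randn(cout), np.random.randn(cout)
mean, var = np.random.randn(cout), np.random.rand(cout) + 0.5
eps = 1e-5

x = np.random.randn(cin)
# Reference: Conv followed by BatchNorm
conv_bn = gamma * ((W @ x + B) - mean) / np.sqrt(var + eps) + beta

# Fused weights per the formulas above
scale = gamma / np.sqrt(var + eps)
W_fused = W * scale[:, None]
B_fused = (B - mean) * scale + beta
assert np.allclose(conv_bn, W_fused @ x + B_fused)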
Custom Optimization Order #
python
import onnx
import onnxoptimizer

model = onnx.load("model.onnx")

# Stage 1: clean up trivial nodes first
fixed_model = onnxoptimizer.optimize(model, ["eliminate_identity", "eliminate_nop_transpose"])
# Stage 2: fuse, now that the graph is tidy
fused_model = onnxoptimizer.optimize(fixed_model, ["fuse_bn_into_conv"])
# Stage 3: drop initializers orphaned by the fusion
final_model = onnxoptimizer.optimize(fused_model, ["eliminate_unused_initializer"])
onnx.save(final_model, "model_optimized.onnx")
ONNX Simplifier #
Installation #
bash
pip install onnx-simplifier
Command-Line Usage #
bash
onnxsim input.onnx output.onnx

# Pin a dynamic input shape
onnxsim input.onnx output.onnx --input-shape 1,3,224,224

# Skip BatchNorm fusion
onnxsim input.onnx output.onnx --skip-fuse-bn

# Several named inputs
onnxsim input.onnx output.onnx --input-shape input:1,3,224,224 input2:1,10

# Note: flag names vary across onnx-simplifier releases (newer versions
# use --overwrite-input-shape); check `onnxsim --help` for your install.
Python API #
python
import onnx
from onnxsim import simplify

model = onnx.load("model.onnx")

# simplify() returns the simplified model plus a flag telling whether its
# outputs still match the original model on random inputs
model_simplified, check = simplify(model)
if check:
    onnx.save(model_simplified, "model_simplified.onnx")
    print("Simplification succeeded")
else:
    print("Simplified model failed the consistency check")

# Pin dynamic input shapes during simplification
model_simplified, check = simplify(
    model,
    input_shapes={"input": [1, 3, 224, 224]}
)

# Fine-grained control over the built-in optimization
model_simplified, check = simplify(
    model,
    perform_optimization=True,
    skip_fuse_bn=False
)
Simplification Results #
text
┌─────────────────────────────────────────────────────────────┐
│               Before vs. After Simplification               │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Before:                                                    │
│    Nodes:       156                                         │
│    Parameters:  25.6M                                       │
│    File size:   98MB                                        │
│                                                             │
│  After:                                                     │
│    Nodes:       89                                          │
│    Parameters:  25.6M                                       │
│    File size:   98MB                                        │
│                                                             │
│  Net effect:                                                │
│    - 67 redundant nodes removed                             │
│    - ~15% faster inference                                  │
│                                                             │
└─────────────────────────────────────────────────────────────┘
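Numbers like those in the table can be collected with a few lines of onnx and os calls. A small sketch; the file names follow the examples on this page:
python
import os
import onnx
import numpy as np

def model_stats(path):
    model = onnx.load(path)
    n_nodes = len(model.graph.node)
    # Parameter count: product of each initializer's dimensions
    n_params = sum(int(np.prod(init.dims)) for init in model.graph.initializer)
    size_mb = os.path.getsize(path) / (1024 * 1024)
    return n_nodes, n_params, size_mb

for path in ["model.onnx", "model_simplified.onnx"]:
    nodes, params, mb = model_stats(path)
    print(f"{path}: {nodes} nodes, {params / 1e6:.1f}M params, {mb:.0f} MB")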
ONNX Runtime Optimization #
Built-in Graph Optimizations #
python
import onnxruntime as ort

sess_options = ort.SessionOptions()
# Enable the full set of graph optimizations
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
session = ort.InferenceSession("model.onnx", sess_options)
Optimization Levels #
python
import onnxruntime as ort

sess_options = ort.SessionOptions()

# Choose exactly one of the four levels:
# Disable all graph optimizations
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
# Basic: constant folding, redundant node/computation removal
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
# Extended: basic plus more complex node fusions
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
# All: extended plus layout optimizations (the default)
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
Saving the Optimized Model #
python
import onnxruntime as ort

sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
# ONNX Runtime writes the optimized graph to this path at session creation
sess_options.optimized_model_filepath = "optimized_model.onnx"
session = ort.InferenceSession("model.onnx", sess_options)
Common Optimization Patterns #
BatchNorm Fusion #
python
import onnx
import onnxoptimizer
from onnx import numpy_helper
import numpy as np

def fuse_bn_into_conv_manual(model):
    """Walk Conv -> BN pairs and compute the folded weights by hand, then
    delegate the actual graph rewrite to the fuse_bn_into_conv pass."""
    graph = model.graph

    # Index each BN node by the tensor it consumes
    bn_nodes = {}
    for node in graph.node:
        if node.op_type == "BatchNormalization":
            bn_nodes[node.input[0]] = node

    for node in list(graph.node):
        if node.op_type == "Conv" and node.output[0] in bn_nodes:
            bn_node = bn_nodes[node.output[0]]

            # Fetch the Conv weight (and optional bias) initializers
            conv_weight = None
            conv_bias = None
            for init in graph.initializer:
                if init.name == node.input[1]:
                    conv_weight = numpy_helper.to_array(init)
                if len(node.input) > 2 and init.name == node.input[2]:
                    conv_bias = numpy_helper.to_array(init)

            # BN inputs: scale (gamma), bias (beta), running mean, running var
            bn_scale = bn_bias = bn_mean = bn_var = None
            for init in graph.initializer:
                if init.name == bn_node.input[1]:
                    bn_scale = numpy_helper.to_array(init)
                if init.name == bn_node.input[2]:
                    bn_bias = numpy_helper.to_array(init)
                if init.name == bn_node.input[3]:
                    bn_mean = numpy_helper.to_array(init)
                if init.name == bn_node.input[4]:
                    bn_var = numpy_helper.to_array(init)
            if conv_weight is None or bn_scale is None or bn_var is None:
                continue

            eps = 1e-5
            for attr in bn_node.attribute:
                if attr.name == "epsilon":
                    eps = attr.f

            # Fold BN into the Conv weights, per output channel
            std = np.sqrt(bn_var + eps)
            new_weight = conv_weight * (bn_scale / std).reshape(-1, 1, 1, 1)
            if conv_bias is not None:
                new_bias = (conv_bias - bn_mean) / std * bn_scale + bn_bias
            else:
                new_bias = -bn_mean / std * bn_scale + bn_bias
            print(f"Fusing Conv {node.name} with BN {bn_node.name}")

    # The math above is illustrative; the pass below performs the actual
    # node rewiring and initializer replacement.
    return onnxoptimizer.optimize(model, ["fuse_bn_into_conv"])

model = onnx.load("model.onnx")
optimized = fuse_bn_into_conv_manual(model)
onnx.save(optimized, "model_bn_fused.onnx")
Constant Folding #
python
import onnx
import onnxoptimizer
from onnx import numpy_helper

def constant_folding(model):
    """Report nodes fed by Constant outputs, then let onnxoptimizer lift
    Constant nodes into initializers."""
    graph = model.graph

    # Collect the value of every Constant node's output
    const_nodes = {}
    for node in graph.node:
        if node.op_type == "Constant":
            for attr in node.attribute:
                if attr.name == "value":
                    const_nodes[node.output[0]] = numpy_helper.to_array(attr.t)

    # Flag nodes that consume at least one constant input
    for node in list(graph.node):
        if any(inp in const_nodes for inp in node.input):
            print(f"Node {node.name} is a constant-folding candidate")

    # extract_constant_to_initializer turns Constant nodes into initializers;
    # runtimes then fold the resulting constant subgraphs at load time.
    return onnxoptimizer.optimize(model, ["extract_constant_to_initializer"])

model = onnx.load("model.onnx")
optimized = constant_folding(model)
onnx.save(optimized, "model_const_folded.onnx")
Dead Code Elimination #
python
import onnx
import onnxoptimizer

model = onnx.load("model.onnx")
passes = [
    "eliminate_deadend",
    "eliminate_unused_initializer",
    "eliminate_duplicate_initializer"
]
optimized = onnxoptimizer.optimize(model, passes)
onnx.save(optimized, "model_dce.onnx")
Performance Comparison #
Benchmarking Before and After #
python
import onnxruntime as ort
import numpy as np
import time

def benchmark_model(model_path, num_runs=100):
    session = ort.InferenceSession(model_path)
    input_name = session.get_inputs()[0].name
    input_shape = session.get_inputs()[0].shape

    # Substitute 1 for symbolic (dynamic) dimensions
    if any(isinstance(d, str) for d in input_shape):
        input_shape = [1 if isinstance(d, str) else d for d in input_shape]
    input_data = np.random.randn(*input_shape).astype(np.float32)

    # Warm-up runs so lazy initialization does not skew the timing
    for _ in range(10):
        session.run(None, {input_name: input_data})

    start = time.time()
    for _ in range(num_runs):
        session.run(None, {input_name: input_data})
    end = time.time()

    avg_time = (end - start) / num_runs * 1000  # ms per inference
    return avg_time

original_time = benchmark_model("model.onnx")
optimized_time = benchmark_model("model_optimized.onnx")
print(f"Original model:  {original_time:.2f} ms")
print(f"Optimized model: {optimized_time:.2f} ms")
print(f"Speedup: {(original_time - optimized_time) / original_time * 100:.1f}%")
Best Practices #
Optimization Workflow #
text
┌─────────────────────────────────────────────────────────────┐
│                Recommended Optimization Flow                │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  1. Export the ONNX model                                   │
│     └── use the correct opset version                       │
│                                                             │
│  2. Run onnx-simplifier                                     │
│     └── onnxsim model.onnx model_sim.onnx                   │
│                                                             │
│  3. Run onnxoptimizer                                       │
│     └── apply a suitable set of passes                      │
│                                                             │
│  4. Validate the model (see the sketch below)               │
│     └── confirm output consistency                          │
│                                                             │
│  5. Benchmark                                               │
│     └── compare performance before and after                │
│                                                             │
│  6. Deploy                                                  │
│     └── ship the optimized model                            │
│                                                             │
└─────────────────────────────────────────────────────────────┘
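Step 4 of the flow, verifying output consistency, can be done by running both models on the same random input and comparing outputs. A minimal sketch, assuming a single-input model and the file names used on this page:
python
import onnxruntime as ort
import numpy as np

def outputs_match(original_path, optimized_path, atol=1e-5):
    sess_a = ort.InferenceSession(original_path)
    sess_b = ort.InferenceSession(optimized_path)
    inp = sess_a.get_inputs()[0]
    # Substitute 1 for any dynamic dimension
    shape = [d if isinstance(d, int) else 1 for d in inp.shape]
    x = np.random.randn(*shape).astype(np.float32)
    out_a = sess_a.run(None, {inp.name: x})
    out_b = sess_b.run(None, {inp.name: x})
    return all(np.allclose(a, b, atol=atol) for a, b in zip(out_a, out_b))

print("Outputs match:", outputs_match("model.onnx", "model_optimized.onnx"))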
Complete Optimization Script #
python
import onnx
import onnxoptimizer
from onnxsim import simplify

def optimize_model(input_path, output_path, input_shapes=None):
    print(f"Loading model: {input_path}")
    model = onnx.load(input_path)

    print("Validating original model...")
    onnx.checker.check_model(model)

    print("Simplifying model...")
    model_simplified, check = simplify(model, input_shapes=input_shapes)
    if not check:
        print("Warning: simplified model failed the consistency check")

    print("Optimizing model...")
    passes = [
        "eliminate_identity",
        "eliminate_nop_transpose",
        "eliminate_nop_pad",
        "eliminate_unused_initializer",
        "eliminate_deadend",
        "fuse_bn_into_conv",
        "fuse_consecutive_transposes",
        "fuse_consecutive_squeezes",
        "extract_constant_to_initializer"
    ]
    model_optimized = onnxoptimizer.optimize(model_simplified, passes)

    print("Validating optimized model...")
    onnx.checker.check_model(model_optimized)

    print(f"Saving optimized model: {output_path}")
    onnx.save(model_optimized, output_path)

    print("Done!")
    print(f"Original node count:  {len(model.graph.node)}")
    print(f"Optimized node count: {len(model_optimized.graph.node)}")
    return model_optimized

optimize_model("model.onnx", "model_optimized.onnx")
Next Steps #
Now that you have model optimization under your belt, continue to Custom Operators to learn how to extend the ONNX operator set!
Last updated: 2026-04-04