Model Optimization #

Optimization Overview #

ONNX model optimization applies graph transformations to reduce computation and memory footprint, improving inference performance.

text
┌─────────────────────────────────────────────────────────────┐
│                    Optimization Types                       │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Operator-level optimizations:                              │
│  ├── Operator fusion - merge multiple operators             │
│  ├── Operator replacement - swap in more efficient ops      │
│  └── Operator simplification - simplify computation logic   │
│                                                             │
│  Graph-level optimizations:                                 │
│  ├── Constant folding - evaluate constants ahead of time    │
│  ├── Dead code elimination - remove unused nodes            │
│  ├── Common subexpression elimination - reuse results       │
│  └── Memory optimization - optimize allocations             │
│                                                             │
│  Data-level optimizations:                                  │
│  ├── Quantization - reduce numeric precision                │
│  ├── Pruning - remove redundant connections                 │
│  └── Knowledge distillation - model compression             │
│                                                             │
└─────────────────────────────────────────────────────────────┘
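Common subexpression elimination, for instance, boils down to deduplicating nodes that compute the same (op, inputs) signature and rewiring later consumers. A minimal pure-Python sketch on a toy node representation (the `Node` tuple and graph below are illustrative, not the ONNX API):

```python
from collections import namedtuple

# Toy node: an op applied to named input tensors, producing one output
Node = namedtuple("Node", ["op", "inputs", "output"])

def eliminate_common_subexpressions(nodes):
    """Drop nodes that recompute an already-seen (op, inputs) expression,
    rewiring later consumers to the surviving node's output."""
    seen = {}      # (op, inputs) signature -> surviving output name
    rename = {}    # eliminated output name -> surviving output name
    result = []
    for n in nodes:
        # Apply renames propagated from earlier eliminations
        inputs = tuple(rename.get(i, i) for i in n.inputs)
        sig = (n.op, inputs)
        if sig in seen:
            rename[n.output] = seen[sig]   # duplicate: drop and rewire
        else:
            seen[sig] = n.output
            result.append(Node(n.op, inputs, n.output))
    return result

nodes = [
    Node("Add", ("x", "y"), "t1"),
    Node("Add", ("x", "y"), "t2"),   # same expression as t1
    Node("Mul", ("t1",), "out1"),
    Node("Mul", ("t2",), "out2"),    # becomes Mul(t1), a duplicate of out1
]
print([n.output for n in eliminate_common_subexpressions(nodes)])  # → ['t1', 'out1']
```

Real implementations also hash node attributes into the signature; the two-table structure (signature map plus rename map) is the core idea.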

ONNX Optimizer #

Installation #

bash
pip install onnxoptimizer

Basic Usage #

python
import onnx
import onnxoptimizer  # the optimizer lives in this separate package (onnx.optimizer was removed in onnx 1.9)

model = onnx.load("model.onnx")

# List every pass this installation supports
all_passes = onnxoptimizer.get_available_passes()
print(f"Available passes: {all_passes}")

passes = [
    "eliminate_identity",
    "eliminate_nop_transpose",
    "eliminate_nop_pad",
    "eliminate_unused_initializer",
    "fuse_bn_into_conv",
    "fuse_consecutive_transposes",
    "fuse_consecutive_squeezes",
    "fuse_consecutive_reduce_unsqueeze",
    "extract_constant_to_initializer"
]

optimized_model = onnxoptimizer.optimize(model, passes)

onnx.save(optimized_model, "model_optimized.onnx")

Optimization Passes in Detail #

text
┌─────────────────────────────────────────────────────────────┐
│                    Elimination Passes                       │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  eliminate_identity:                                        │
│  Remove Identity nodes                                      │
│  X ──> Identity ──> Y    becomes    X ──> Y                 │
│                                                             │
│  eliminate_nop_transpose:                                   │
│  Remove no-op Transpose (perm = [0,1,2,...])                │
│                                                             │
│  eliminate_nop_pad:                                         │
│  Remove no-op Pad (all pads are 0)                          │
│                                                             │
│  eliminate_unused_initializer:                              │
│  Remove initializers that are never referenced              │
│                                                             │
│  eliminate_deadend:                                         │
│  Remove nodes whose outputs are never consumed              │
│                                                             │
│  eliminate_duplicate_initializer:                           │
│  Merge duplicate initializers                               │
│                                                             │
└─────────────────────────────────────────────────────────────┘
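The check behind eliminate_nop_transpose is simply "is perm the identity permutation?". A one-function sketch (the function name is illustrative):

```python
def is_nop_transpose(perm):
    """A Transpose is a no-op exactly when its perm attribute is the
    identity permutation [0, 1, 2, ...]."""
    return list(perm) == list(range(len(perm)))

print(is_nop_transpose([0, 1, 2, 3]))  # True  -> safe to eliminate
print(is_nop_transpose([0, 2, 1, 3]))  # False -> a real transpose
```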

text
┌─────────────────────────────────────────────────────────────┐
│                      Fusion Passes                          │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  fuse_bn_into_conv:                                         │
│  Fold BatchNorm into the preceding Conv                     │
│                                                             │
│  Conv ──> BN ──>    becomes    Conv_fused ──>               │
│                                                             │
│  Fusion formulas:                                           │
│  W_fused = W * (gamma / sqrt(var + eps))                    │
│  B_fused = (B - mean) * gamma / sqrt(var + eps) + beta      │
│                                                             │
│  fuse_consecutive_transposes:                               │
│  Merge consecutive Transpose nodes                          │
│                                                             │
│  Transpose(perm1) ──> Transpose(perm2)                      │
│  becomes Transpose(combine(perm1, perm2))                   │
│                                                             │
│  fuse_consecutive_squeezes:                                 │
│  Merge consecutive Squeeze nodes                            │
│                                                             │
│  fuse_consecutive_reduce_unsqueeze:                         │
│  Merge Reduce + Unsqueeze pairs                             │
│                                                             │
│  fuse_add_bias_into_conv:                                   │
│  Fold a following Add bias into Conv                        │
│                                                             │
└─────────────────────────────────────────────────────────────┘
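Both fusion rules above can be checked numerically with NumPy. The sketch below verifies the BN-into-Conv formula on a per-channel linear model (a 1x1 view of the convolution) and the perm-combination rule for consecutive transposes; all variable names are illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)

# --- fuse_bn_into_conv: check W_fused / B_fused against BN(Conv(x)) ---
C, eps = 4, 1e-5
x = rng.standard_normal(C)
W = rng.standard_normal(C)                # per-channel "1x1 conv" weight
B = rng.standard_normal(C)                # conv bias
gamma, beta = rng.standard_normal(C), rng.standard_normal(C)
mean, var = rng.standard_normal(C), rng.random(C) + 0.1

std = np.sqrt(var + eps)
y_ref = gamma * ((W * x + B) - mean) / std + beta   # Conv then BN
W_f = W * gamma / std                               # W_fused
B_f = (B - mean) * gamma / std + beta               # B_fused
y_fused = W_f * x + B_f                             # fused Conv only
assert np.allclose(y_ref, y_fused)

# --- fuse_consecutive_transposes: combined perm[i] = perm1[perm2[i]] ---
t = rng.standard_normal((2, 3, 4, 5))
perm1, perm2 = [0, 2, 3, 1], [0, 3, 1, 2]
combined = [perm1[i] for i in perm2]
assert np.array_equal(np.transpose(np.transpose(t, perm1), perm2),
                      np.transpose(t, combined))

print("both fusion rules verified")
```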

Custom Pass Ordering #

python
import onnx
import onnxoptimizer

model = onnx.load("model.onnx")

# Step 1: clean up trivial nodes first so fusion patterns become visible
fixed_model = onnxoptimizer.optimize(model, ["eliminate_identity", "eliminate_nop_transpose"])

# Step 2: fuse BatchNorm into Conv
fused_model = onnxoptimizer.optimize(fixed_model, ["fuse_bn_into_conv"])

# Step 3: drop initializers orphaned by the fusion
final_model = onnxoptimizer.optimize(fused_model, ["eliminate_unused_initializer"])

onnx.save(final_model, "model_optimized.onnx")

ONNX Simplifier #

Installation #

bash
pip install onnx-simplifier

Command-Line Usage #

bash
# Basic usage
onnxsim input.onnx output.onnx

# Pin the input shape (helps shape inference remove dynamic-shape nodes);
# newer onnxsim releases rename this flag to --overwrite-input-shape
onnxsim input.onnx output.onnx --input-shape 1,3,224,224

# Skip BatchNorm-into-Conv fusion
onnxsim input.onnx output.onnx --skip-fuse-bn

# Multiple named inputs
onnxsim input.onnx output.onnx --input-shape input:1,3,224,224 input2:1,10

Python API #

python
import onnx
from onnxsim import simplify

model = onnx.load("model.onnx")

model_simplified, check = simplify(model)

if check:
    onnx.save(model_simplified, "model_simplified.onnx")
    print("Simplification succeeded")
else:
    print("Simplification check failed")

# Pin input shapes explicitly
model_simplified, check = simplify(
    model,
    input_shapes={"input": [1, 3, 224, 224]}
)

# Control optimization behavior
model_simplified, check = simplify(
    model,
    perform_optimization=True,
    skip_fuse_bn=False
)

Simplification Results #

text
┌─────────────────────────────────────────────────────────────┐
│               Before/After Simplification                   │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Before:                                                    │
│  Nodes: 156                                                 │
│  Parameters: 25.6M                                          │
│  File size: 98MB                                            │
│                                                             │
│  After:                                                     │
│  Nodes: 89                                                  │
│  Parameters: 25.6M                                          │
│  File size: 98MB                                            │
│                                                             │
│  Result:                                                    │
│  - 67 redundant nodes removed                               │
│  - ~15% faster inference                                    │
│  (parameters and file size are unchanged: simplification    │
│   removes structural nodes, not weights)                    │
│                                                             │
└─────────────────────────────────────────────────────────────┘

ONNX Runtime Optimization #

Built-in Graph Optimizations #

python
import onnxruntime as ort

sess_options = ort.SessionOptions()

sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

sess_options.optimized_model_filepath = "model_optimized_by_ort.onnx"

session = ort.InferenceSession("model.onnx", sess_options)

Optimization Levels #

python
import onnxruntime as ort

sess_options = ort.SessionOptions()

# Disable all graph optimizations
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL

# Basic: semantics-preserving rewrites such as constant folding and
# redundant-node elimination
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC

# Extended: adds more complex node fusions
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED

# All: extended plus layout optimizations (the default)
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

Saving the Optimized Model #

python
import onnxruntime as ort

sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.optimized_model_filepath = "optimized_model.onnx"

session = ort.InferenceSession("model.onnx", sess_options)

Common Optimization Patterns #

BatchNorm Fusion #

python
import onnx
import onnxoptimizer
from onnx import numpy_helper
import numpy as np

def fuse_bn_into_conv_manual(model):
    """Walk the graph and compute the fused Conv parameters for each
    Conv -> BN pair to illustrate the math; the actual graph rewrite is
    delegated to the onnxoptimizer pass at the end."""
    graph = model.graph
    
    # Index each BN node by the tensor it consumes
    bn_nodes = {}
    for node in graph.node:
        if node.op_type == "BatchNormalization":
            bn_nodes[node.input[0]] = node
    
    for node in list(graph.node):
        if node.op_type == "Conv" and node.output[0] in bn_nodes:
            bn_node = bn_nodes[node.output[0]]
            
            # Fetch the Conv weight (and optional bias) initializers
            conv_weight = None
            conv_bias = None
            for init in graph.initializer:
                if init.name == node.input[1]:
                    conv_weight = numpy_helper.to_array(init)
                if len(node.input) > 2 and init.name == node.input[2]:
                    conv_bias = numpy_helper.to_array(init)
            
            # Fetch BN parameters: scale (gamma), bias (beta), mean, variance
            bn_scale = bn_bias = bn_mean = bn_var = None
            for init in graph.initializer:
                if init.name == bn_node.input[1]:
                    bn_scale = numpy_helper.to_array(init)
                if init.name == bn_node.input[2]:
                    bn_bias = numpy_helper.to_array(init)
                if init.name == bn_node.input[3]:
                    bn_mean = numpy_helper.to_array(init)
                if init.name == bn_node.input[4]:
                    bn_var = numpy_helper.to_array(init)
            
            if conv_weight is None or bn_scale is None:
                continue
            
            eps = 1e-5  # BatchNormalization default epsilon
            for attr in bn_node.attribute:
                if attr.name == "epsilon":
                    eps = attr.f
            
            # Fused parameters (see the formulas above); computed here for
            # illustration only -- the pass below performs the real rewrite
            std = np.sqrt(bn_var + eps)
            new_weight = conv_weight * (bn_scale / std).reshape(-1, 1, 1, 1)
            
            if conv_bias is not None:
                new_bias = (conv_bias - bn_mean) / std * bn_scale + bn_bias
            else:
                new_bias = -bn_mean / std * bn_scale + bn_bias
            
            print(f"Fusing Conv {node.name} with BN {bn_node.name}")
    
    return onnxoptimizer.optimize(model, ["fuse_bn_into_conv"])

model = onnx.load("model.onnx")
optimized = fuse_bn_into_conv_manual(model)
onnx.save(optimized, "model_bn_fused.onnx")

Constant Folding #

python
import onnx
import onnxoptimizer
from onnx import numpy_helper

def constant_folding(model):
    """Identify nodes whose inputs come from Constant nodes (folding
    candidates), then run the extract_constant_to_initializer pass."""
    graph = model.graph
    
    # Collect the values produced by Constant nodes
    const_values = {}
    for node in graph.node:
        if node.op_type == "Constant":
            for attr in node.attribute:
                if attr.name == "value":
                    const_values[node.output[0]] = numpy_helper.to_array(attr.t)
    
    # Report nodes that consume at least one constant input
    for node in list(graph.node):
        if any(inp in const_values for inp in node.input):
            print(f"Node {node.name} is a constant-folding candidate")
    
    return onnxoptimizer.optimize(model, ["extract_constant_to_initializer"])

model = onnx.load("model.onnx")
optimized = constant_folding(model)
onnx.save(optimized, "model_const_folded.onnx")
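Actual folding means evaluating any node whose inputs are all known constants and replacing it with the precomputed result. A minimal pure-Python sketch on a toy node format (the tuple representation and op table are illustrative, not the ONNX API):

```python
import numpy as np

# Toy op table mapping op names to their NumPy implementations
OPS = {"Add": np.add, "Mul": np.multiply}

def fold_constants(nodes, constants):
    """Evaluate every node whose inputs are all known constants, move its
    result into the constants table, and return the surviving nodes.
    Each node is (op, input names, output name)."""
    remaining = []
    for op, inputs, output in nodes:
        if all(i in constants for i in inputs):
            # All inputs known: compute now, at "compile time"
            constants[output] = OPS[op](*(constants[i] for i in inputs))
        else:
            remaining.append((op, inputs, output))
    return remaining

consts = {"two": np.array(2.0), "three": np.array(3.0)}
nodes = [
    ("Mul", ("two", "three"), "six"),   # foldable: 2 * 3
    ("Add", ("x", "six"), "y"),         # depends on runtime input x
]
rest = fold_constants(nodes, consts)
print(rest, consts["six"])   # only the Add survives; six == 6.0
```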

Dead Code Elimination #

python
import onnx
import onnxoptimizer

model = onnx.load("model.onnx")

# Remove dead nodes, unused initializers, and duplicate initializers
passes = [
    "eliminate_deadend",
    "eliminate_unused_initializer",
    "eliminate_duplicate_initializer"
]

optimized = onnxoptimizer.optimize(model, passes)

onnx.save(optimized, "model_dce.onnx")
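Conceptually, eliminate_deadend is a reachability walk backwards from the graph outputs: any node whose output is never reached is dead. A pure-Python sketch on a toy node list (names and format are illustrative):

```python
def eliminate_dead_nodes(nodes, graph_outputs):
    """Keep only nodes reachable backwards from the graph outputs.
    Each node is (op, input names, output name)."""
    producer = {out: (op, inputs, out) for op, inputs, out in nodes}
    live, stack = set(), list(graph_outputs)
    while stack:
        name = stack.pop()
        if name in live or name not in producer:
            continue  # already visited, or a graph input / initializer
        live.add(name)
        stack.extend(producer[name][1])  # walk into this node's inputs
    return [n for n in nodes if n[2] in live]

nodes = [
    ("Relu", ("x",), "a"),
    ("Relu", ("x",), "dead"),   # output never consumed anywhere
    ("Add", ("a", "x"), "y"),
]
print([n[2] for n in eliminate_dead_nodes(nodes, ["y"])])  # → ['a', 'y']
```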

Performance Comparison #

Before/After Benchmark #

python
import onnxruntime as ort
import numpy as np
import time

def benchmark_model(model_path, num_runs=100):
    session = ort.InferenceSession(model_path)
    input_name = session.get_inputs()[0].name
    input_shape = session.get_inputs()[0].shape
    
    # Replace symbolic (dynamic) dimensions with a concrete size of 1
    if any(isinstance(d, str) for d in input_shape):
        input_shape = [1 if isinstance(d, str) else d for d in input_shape]
    
    input_data = np.random.randn(*input_shape).astype(np.float32)
    
    # Warm up so one-time initialization costs don't skew the measurement
    for _ in range(10):
        session.run(None, {input_name: input_data})
    
    start = time.time()
    for _ in range(num_runs):
        session.run(None, {input_name: input_data})
    end = time.time()
    
    avg_time = (end - start) / num_runs * 1000  # ms per run
    return avg_time

original_time = benchmark_model("model.onnx")
optimized_time = benchmark_model("model_optimized.onnx")

print(f"Original model: {original_time:.2f} ms")
print(f"Optimized model: {optimized_time:.2f} ms")
print(f"Improvement: {(original_time - optimized_time) / original_time * 100:.1f}%")
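Average latency can hide tail behavior; collecting per-run timings and reporting percentiles gives a fuller picture. A stdlib-only sketch (the timings list here is synthetic, not measured):

```python
import statistics

def latency_summary(timings_ms):
    """Summarize a list of per-run latencies (ms) with mean, p50, and p95."""
    qs = statistics.quantiles(timings_ms, n=100)  # 99 percentile cut points
    return {
        "mean": statistics.fmean(timings_ms),
        "p50": statistics.median(timings_ms),
        "p95": qs[94],   # 95th percentile
    }

# Synthetic timings: mostly ~10 ms with a few slow outliers
timings = [10.0] * 95 + [30.0] * 5
summary = latency_summary(timings)
print(summary)
```

To use this with the benchmark above, collect each run's elapsed time into a list instead of timing the whole loop.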

Best Practices #

Optimization Workflow #

text
┌─────────────────────────────────────────────────────────────┐
│                  Recommended Workflow                       │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  1. Export the ONNX model                                   │
│     └── use the right opset version                         │
│                                                             │
│  2. Run onnx-simplifier                                     │
│     └── onnxsim model.onnx model_sim.onnx                   │
│                                                             │
│  3. Run onnxoptimizer                                       │
│     └── apply the appropriate passes                        │
│                                                             │
│  4. Validate the model                                      │
│     └── check output consistency                            │
│                                                             │
│  5. Benchmark                                               │
│     └── compare before/after performance                    │
│                                                             │
│  6. Deploy                                                  │
│     └── ship the optimized model                            │
│                                                             │
└─────────────────────────────────────────────────────────────┘
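Validating output consistency usually means running the original and optimized models on the same input and comparing outputs within a tolerance. The comparison itself is independent of the runtime; a sketch (session creation is elided, and the arrays below stand in for two onnxruntime.InferenceSession runs):

```python
import numpy as np

def outputs_match(ref_outputs, test_outputs, rtol=1e-3, atol=1e-5):
    """Compare two lists of output arrays; return whether every pair is
    within tolerance, along with the worst absolute difference seen."""
    ok = True
    worst = 0.0
    for ref, test in zip(ref_outputs, test_outputs):
        worst = max(worst, float(np.max(np.abs(ref - test))))
        ok = ok and bool(np.allclose(ref, test, rtol=rtol, atol=atol))
    return ok, worst

# In practice these would come from session.run(...) on the original and
# optimized models with the same input; here we fake a tiny difference
ref = [np.array([1.0, 2.0, 3.0])]
test = [np.array([1.0, 2.0, 3.0 + 1e-6])]
ok, worst = outputs_match(ref, test)
print(ok, worst)
```

Graph optimizations can legitimately reorder floating-point arithmetic, so exact equality is the wrong test; pick tolerances that match your model's precision requirements.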

Complete Optimization Script #

python
import onnx
import onnxoptimizer
from onnxsim import simplify

def optimize_model(input_path, output_path, input_shapes=None):
    print(f"Loading model: {input_path}")
    model = onnx.load(input_path)
    
    print("Validating original model...")
    onnx.checker.check_model(model)
    
    print("Simplifying model...")
    # input_shapes is the onnx-simplifier 0.3.x keyword; newer releases
    # call it overwrite_input_shapes
    model_simplified, check = simplify(model, input_shapes=input_shapes)
    if not check:
        print("Warning: simplification check failed")
    
    print("Optimizing model...")
    passes = [
        "eliminate_identity",
        "eliminate_nop_transpose",
        "eliminate_nop_pad",
        "eliminate_unused_initializer",
        "eliminate_deadend",
        "fuse_bn_into_conv",
        "fuse_consecutive_transposes",
        "fuse_consecutive_squeezes",
        "extract_constant_to_initializer"
    ]
    model_optimized = onnxoptimizer.optimize(model_simplified, passes)
    
    print("Validating optimized model...")
    onnx.checker.check_model(model_optimized)
    
    print(f"Saving optimized model: {output_path}")
    onnx.save(model_optimized, output_path)
    
    print("Done!")
    print(f"Original node count: {len(model.graph.node)}")
    print(f"Optimized node count: {len(model_optimized.graph.node)}")
    
    return model_optimized

optimize_model("model.onnx", "model_optimized.onnx")

Next Steps #

Now that you understand model optimization, head to Custom Operators to learn how to extend the ONNX operator set!

Last updated: 2026-04-04