Python API #

API 概述 #

ONNX Python 库提供了完整的模型操作接口,包括加载、保存、验证、修改等功能。

text
┌─────────────────────────────────────────────────────────────┐
│                    ONNX Python API 结构                      │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  onnx                                                       │
│  ├── 模型操作                                               │
│  │   ├── load() - 加载模型                                  │
│  │   ├── save() - 保存模型                                  │
│  │   ├── load_model() - 加载模型(与 load() 等价)          │
│  │   └── save_model() - 保存模型(与 save() 等价)          │
│  │                                                          │
│  ├── 模型验证                                               │
│  │   ├── checker.check_model() - 验证模型                   │
│  │   └── checker.check_model(full_check=True) - 完整验证    │
│  │                                                          │
│  ├── 模型工具                                               │
│  │   ├── helper - 创建模型组件                              │
│  │   ├── numpy_helper - NumPy 转换                          │
│  │   └── shape_inference - 形状推断                         │
│  │                                                          │
│  └── 模型优化                                               │
│      └── onnxoptimizer - 图优化(独立包)                     │
│                                                             │
└─────────────────────────────────────────────────────────────┘

模型加载与保存 #

基本加载 #

python
import onnx

# Read the .onnx file from disk and parse it into a ModelProto.
model = onnx.load("model.onnx")
graph = model.graph

print(f"IR 版本: {model.ir_version}")
print(f"节点数量: {len(graph.node)}")

加载选项 #

python
import onnx
from onnx import ModelProto

# Option 1: deserialize the raw protobuf bytes by hand.
with open("model.onnx", "rb") as fh:
    raw = fh.read()
model = ModelProto()
model.ParseFromString(raw)

# Option 2: let onnx deserialize, naming the on-disk format explicitly.
model = onnx.load_model("model.onnx", format="protobuf")

# Option 3: skip loading external tensor data (load_external_data=False).
try:
    model = onnx.load("model.onnx", load_external_data=False)
except Exception as err:
    print(f"加载失败: {err}")

基本保存 #

python
import onnx

model = onnx.load("model.onnx")

# onnx.save is an alias of onnx.save_model: serialize the ModelProto to disk.
onnx.save_model(model, "output.onnx")

保存大型模型 #

Protobuf 序列化的单个文件不能超过 2GB,因此大型模型需要把权重转存为外部数据文件,模型文件中只保留对这些数据的引用。

python
import onnx

model = onnx.load("large_model.onnx")

# save_model can do the external-data conversion itself. With
# save_as_external_data=True it applies the same conversion as
# convert_model_to_external_data and then writes the .onnx file:
#   - every tensor whose data is >= 1024 bytes goes into one file, "weights.bin";
#   - attribute tensors are left inline (convert_attribute=False).
onnx.save_model(
    model,
    "large_model_external.onnx",
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    location="weights.bin",
    size_threshold=1024,
    convert_attribute=False,
)

加载外部数据模型 #

python
import onnx

# load_external_data=True reads the tensors kept in the side file back into
# the in-memory ModelProto.
model = onnx.load_model("large_model_external.onnx", load_external_data=True)

模型验证 #

基本验证 #

python
import onnx

model = onnx.load("model.onnx")

# check_model raises ValidationError when the model is invalid; the else
# branch runs only when no error was raised.
try:
    onnx.checker.check_model(model)
except onnx.checker.ValidationError as err:
    print(f"模型验证失败: {err}")
else:
    print("模型验证通过")

完整验证 #

python
import onnx

model = onnx.load("model.onnx")

try:
    # full_check=True also runs strict shape inference on top of the
    # structural checks. Shape-inference failures surface as InferenceError
    # rather than ValidationError, so both are caught. Any other exception
    # (a real bug) is allowed to propagate.
    onnx.checker.check_model(model, full_check=True)
except (onnx.checker.ValidationError, onnx.shape_inference.InferenceError) as e:
    print(f"验证失败: {e}")
else:
    print("完整验证通过")

验证文件 #

python
from onnx import checker

# check_model also accepts a file path instead of an in-memory ModelProto.
# Passing the path is the way to validate models larger than the 2GB
# protobuf limit.
checker.check_model("model.onnx")

Helper 模块 #

创建张量值信息 #

python
from onnx import helper, TensorProto

# Fixed shape: every dimension is a concrete integer.
input_tensor = helper.make_tensor_value_info(
    "input", TensorProto.FLOAT, [1, 3, 224, 224]
)

# Dynamic batch: a string entry becomes a symbolic dim_param ("N") and
# integers become dim_value. This yields the same proto as appending the
# dims one by one to an unshaped value info.
dynamic_input = helper.make_tensor_value_info(
    "input", TensorProto.FLOAT, ["N", 3, 224, 224]
)

创建节点 #

python
from onnx import helper

# make_node(op_type, inputs, outputs, name=..., **attributes): any extra
# keyword arguments become node attributes.
conv_node = helper.make_node(
    "Conv", ["X", "W", "B"], ["Y"],
    name="conv1",
    kernel_shape=[3, 3],
    strides=[1, 1],
    pads=[1, 1, 1, 1],  # begin/end padding for each of the two spatial axes
)

# Nodes are wired purely by tensor name: Relu consumes the Conv output "Y".
relu_node = helper.make_node("Relu", ["Y"], ["Z"], name="relu1")

add_node = helper.make_node("Add", ["A", "B"], ["C"], name="add1")

创建图 #

python
import numpy as np
from onnx import helper, numpy_helper, TensorProto

X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 224, 224])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 64, 112, 112])

# Conv weight: 64 output channels, 3 input channels, 7x7 kernel.
# "W" must be defined somewhere in the graph (here as an initializer).
# Otherwise the Conv input refers to an unknown tensor and
# onnx.checker.check_model rejects the graph.
W = numpy_helper.from_array(
    np.random.randn(64, 3, 7, 7).astype(np.float32),
    name="W",
)

conv_node = helper.make_node(
    "Conv",
    inputs=["X", "W"],
    outputs=["conv_out"],
    kernel_shape=[7, 7],
    strides=[2, 2],
    pads=[3, 3, 3, 3],
)

relu_node = helper.make_node(
    "Relu",
    inputs=["conv_out"],
    outputs=["Y"]
)

# Output spatial size: (224 + 3 + 3 - 7) // 2 + 1 = 112, matching Y.
graph = helper.make_graph(
    [conv_node, relu_node],
    "simple_graph",
    [X],
    [Y],
    initializer=[W],
)

创建模型 #

python
from onnx import helper

# `graph` is the GraphProto built in the previous example.
# make_model assigns any extra keyword arguments onto the ModelProto, so
# ir_version can be pinned here together with the producer fields.
# Pinning IR 8 is presumably for compatibility with older runtimes (confirm).
model = helper.make_model(
    graph,
    opset_imports=[helper.make_opsetid("", 17)],
    producer_name="my_tool",
    producer_version="1.0.0",
    ir_version=8,
)

创建 OpsetID #

python
from onnx import helper

# One OperatorSetIdProto per domain:
#   - "" is the default ONNX operator domain;
#   - the second entry registers version 1 of a custom domain.
opsets = [
    helper.make_opsetid("", 17),
    helper.make_opsetid("custom_domain", 1),
]

# `graph` is the GraphProto built in an earlier example.
model = helper.make_model(graph, opset_imports=opsets)

NumPy Helper 模块 #

张量转 NumPy #

python
import onnx
from onnx import numpy_helper

model = onnx.load("model.onnx")

# Convert every stored weight (TensorProto) into a NumPy array for inspection.
for tensor in model.graph.initializer:
    array = numpy_helper.to_array(tensor)
    print(f"{tensor.name}: shape={array.shape}, dtype={array.dtype}")

NumPy 转张量 #

python
import numpy as np
from onnx import numpy_helper

weights = np.random.randn(64, 3, 7, 7).astype(np.float32)

# from_array copies the array's data, dims and element type into a TensorProto.
tensor = numpy_helper.from_array(weights, name="conv_weight")

dims = list(tensor.dims)
print(f"名称: {tensor.name}")
print(f"形状: {dims}")
print(f"类型: {tensor.data_type}")  # integer TensorProto.DataType code (FLOAT == 1)

处理不同数据类型 #

python
import numpy as np
from onnx import numpy_helper, TensorProto

# from_array maps each NumPy dtype to the matching TensorProto element type:
# float32 -> FLOAT, int64 -> INT64, float16 -> FLOAT16, int8 -> INT8.
float32_data = np.random.randn(10, 10).astype(np.float32)
float32_tensor = numpy_helper.from_array(float32_data, name="float32_data")

int64_data = np.array([1, 2, 3, 4, 5], dtype=np.int64)
int64_tensor = numpy_helper.from_array(int64_data, name="int64_data")

float16_data = np.random.randn(10, 10).astype(np.float16)
float16_tensor = numpy_helper.from_array(float16_data, name="float16_data")

int8_data = np.array([1, 2, 3], dtype=np.int8)
int8_tensor = numpy_helper.from_array(int8_data, name="int8_data")

形状推断 #

基本形状推断 #

python
import onnx
from onnx import shape_inference

model = onnx.load("model.onnx")

# infer_shapes returns a model whose graph.value_info carries the shapes and
# types inferred for intermediate tensors; persist it straight away.
onnx.save(shape_inference.infer_shapes(model), "model_with_shapes.onnx")

检查形状信息 #

python
import onnx
from onnx import shape_inference


def format_shape(value_info):
    """Return a printable dim list for a ValueInfoProto.

    Each dim is rendered as its concrete size (dim_value), its symbolic name
    (dim_param), or "?" when neither is set. The populated oneof field is
    checked explicitly rather than relying on truthiness. That way a
    legitimate size of 0 is shown as 0 instead of "?".
    """
    dims = []
    for d in value_info.type.tensor_type.shape.dim:
        kind = d.WhichOneof("value")
        if kind == "dim_value":
            dims.append(d.dim_value)
        elif kind == "dim_param" and d.dim_param:
            dims.append(d.dim_param)
        else:
            dims.append("?")
    return dims


model = onnx.load("model.onnx")

inferred_model = shape_inference.infer_shapes(model)

# Graph outputs first, then the intermediate tensors annotated by inference.
for vi in [*inferred_model.graph.output, *inferred_model.graph.value_info]:
    print(f"{vi.name}: {format_shape(vi)}")

形状推断选项 #

python
import onnx
from onnx import shape_inference

# Load the model here so the example runs on its own.
model = onnx.load("model.onnx")

try:
    inferred_model = shape_inference.infer_shapes(
        model,
        check_type=True,   # check input/output type consistency against op schemas
        strict_mode=True,  # raise on inference errors instead of skipping them
        data_prop=True,    # propagate constant data so shape-computing ops can be resolved
    )
except shape_inference.InferenceError as e:
    # strict_mode=True turns inference problems into exceptions; report them.
    print(f"形状推断失败: {e}")

模型修改 #

添加节点 #

python
import onnx
from onnx import helper

model = onnx.load("model.onnx")
graph = model.graph

# Append a Relu after the first graph output. Reading the current output
# name (instead of hard-coding "old_output") guarantees the new node
# consumes a tensor that actually exists in this graph.
old_output = graph.output[0].name

new_node = helper.make_node(
    "Relu",
    inputs=[old_output],
    outputs=["new_output"],
    name="new_relu"
)

# Appending at the end keeps topological order: old_output is produced by an
# earlier node (or is a graph input).
graph.node.append(new_node)

# Relu is elementwise, so the existing output ValueInfo (type/shape) still
# applies; only its name must point at the new tensor.
graph.output[0].name = "new_output"

onnx.checker.check_model(model)
onnx.save(model, "model_with_relu.onnx")

删除节点 #

python
import onnx

model = onnx.load("model.onnx")
graph = model.graph

# A set gives O(1) membership tests while filtering the node list.
nodes_to_remove = {"node_to_remove_1", "node_to_remove_2"}

kept_nodes = [node for node in graph.node if node.name not in nodes_to_remove]

# Rebuild the repeated field in place.
# NOTE: nodes that consumed the removed nodes' outputs are NOT rewired here.
# Reconnect them (or remove them too) before saving; otherwise
# onnx.checker.check_model will report inputs that no node produces.
del graph.node[:]
graph.node.extend(kept_nodes)

修改节点属性 #

python
import onnx
from onnx import helper, AttributeProto

model = onnx.load("model.onnx")

# Assumes 2-D convolutions: strides needs one entry per spatial axis.
NEW_STRIDES = [2, 2]

for node in model.graph.node:
    if node.op_type != "Conv":
        continue
    for attr in node.attribute:
        if attr.name == "strides":
            attr.ints[:] = NEW_STRIDES
            break
    else:
        # "strides" is optional on Conv. When the node relies on the default,
        # add the attribute explicitly so every Conv gets the requested value.
        node.attribute.append(helper.make_attribute("strides", NEW_STRIDES))

onnx.save(model, "model_modified.onnx")

替换初始化器 #

python
import onnx
from onnx import numpy_helper
import numpy as np

model = onnx.load("model.onnx")

TARGET = "conv1.weight"

for initializer in model.graph.initializer:
    if initializer.name != TARGET:
        continue
    old_weights = numpy_helper.to_array(initializer)
    # Cast back to the original dtype so the tensor's element type is
    # preserved (e.g. an integer tensor * 0.5 would otherwise become float64).
    new_weights = (old_weights * 0.5).astype(old_weights.dtype)
    # Overwrite the matched TensorProto in place. The loop already holds it,
    # so no second search over the initializer list is needed.
    initializer.CopyFrom(numpy_helper.from_array(new_weights, name=initializer.name))
    break

onnx.save(model, "model_new_weights.onnx")

模型优化 #

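使用优化器 #

注意:onnx.optimizer 已在 ONNX 1.9 中移出主库,现由独立的 onnxoptimizer 包提供(pip install onnxoptimizer),get_available_passes() 与 optimize() 的用法保持不变。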

python
import onnx

# onnx.optimizer was removed from the onnx package in ONNX 1.9 and now lives in
# the separate "onnxoptimizer" distribution (pip install onnxoptimizer).
# Fall back to the in-tree module only for old ONNX releases.
try:
    import onnxoptimizer as optimizer
except ImportError:
    from onnx import optimizer

model = onnx.load("model.onnx")

all_passes = optimizer.get_available_passes()
print(f"可用优化: {all_passes}")

passes = [
    "eliminate_identity",
    "eliminate_nop_transpose",
    "fuse_bn_into_conv",
    "fuse_consecutive_transposes",
    "eliminate_unused_initializer"
]

optimized_model = optimizer.optimize(model, passes)

onnx.save(optimized_model, "model_optimized.onnx")

常用优化 Pass #

text
┌─────────────────────────────────────────────────────────────┐
│                    常用优化 Pass                             │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  消除优化:                                                 │
│  ├── eliminate_identity - 移除 Identity 节点               │
│  ├── eliminate_nop_transpose - 移除无效 Transpose          │
│  ├── eliminate_nop_pad - 移除无效 Pad                      │
│  ├── eliminate_unused_initializer - 移除未使用参数         │
│  └── eliminate_deadend - 移除死代码                        │
│                                                             │
│  融合优化:                                                 │
│  ├── fuse_bn_into_conv - BatchNorm 融入 Conv               │
│  ├── fuse_consecutive_concats - 合并连续 Concat            │
│  ├── fuse_consecutive_reduce_unsqueeze - 合并归约操作      │
│  ├── fuse_consecutive_squeezes - 合并连续 Squeeze          │
│  └── fuse_consecutive_transposes - 合并连续 Transpose      │
│                                                             │
│  其他优化:                                                 │
│  ├── extract_constant_to_initializer - 提取常量            │
│  ├── lift_lexical_references - 提升词法引用                │
│  └── split_init - 分离初始化                               │
│                                                             │
└─────────────────────────────────────────────────────────────┘

模型信息提取 #

获取模型摘要 #

python
import onnx


def _shape_of(value_info):
    """Return the dims of a ValueInfoProto as ints, symbolic names, or "?".

    The populated oneof field is checked explicitly. A concrete size of 0 is
    therefore reported as 0 rather than collapsing to "?".
    """
    dims = []
    for d in value_info.type.tensor_type.shape.dim:
        kind = d.WhichOneof("value")
        if kind == "dim_value":
            dims.append(d.dim_value)
        elif kind == "dim_param" and d.dim_param:
            dims.append(d.dim_param)
        else:
            dims.append("?")
    return dims


def _type_name(value_info):
    """Return the TensorProto.DataType name of a ValueInfoProto's element type."""
    return onnx.TensorProto.DataType.Name(value_info.type.tensor_type.elem_type)


def get_model_summary(model_path):
    """Load an ONNX model and collect a plain-dict summary of its structure.

    Args:
        model_path: path to the .onnx file.

    Returns:
        A dict with these keys:
          - "ir_version", "producer", "opset", "graph_name";
          - "inputs" / "outputs": lists of {"name", "shape", "type"};
          - "nodes": list of {"name", "type", "inputs", "outputs"};
          - "initializers": list of {"name", "shape"}.
    """
    model = onnx.load(model_path)
    graph = model.graph

    return {
        "ir_version": model.ir_version,
        "producer": f"{model.producer_name} {model.producer_version}",
        # An empty domain string denotes the default ONNX operator set.
        "opset": [f"{op.domain or 'ai.onnx'}:{op.version}"
                  for op in model.opset_import],
        "graph_name": graph.name,
        # NOTE(review): older models (IR < 4) usually also list every
        # initializer as a graph input; filter them out if that matters.
        "inputs": [{"name": inp.name, "shape": _shape_of(inp), "type": _type_name(inp)}
                   for inp in graph.input],
        "outputs": [{"name": out.name, "shape": _shape_of(out), "type": _type_name(out)}
                    for out in graph.output],
        "nodes": [{"name": node.name,
                   "type": node.op_type,
                   "inputs": list(node.input),
                   "outputs": list(node.output)}
                  for node in graph.node],
        "initializers": [{"name": init.name, "shape": list(init.dims)}
                         for init in graph.initializer],
    }


summary = get_model_summary("model.onnx")
print(f"节点数: {len(summary['nodes'])}")
# Counts initializer tensors, not individual scalar parameters.
print(f"参数数: {len(summary['initializers'])}")

统计模型信息 #

python
import math
from collections import Counter

import onnx
from onnx import numpy_helper


def get_model_stats(model_path):
    """Load an ONNX model and report parameter counts and op-type frequencies.

    Args:
        model_path: path to the .onnx file.

    Returns:
        A dict with these keys:
          - "total_params": total element count across all initializers;
          - "param_details": per-initializer {"shape", "size", "dtype"};
          - "node_count": number of graph nodes;
          - "node_types": op_type -> occurrence count.
    """
    model = onnx.load(model_path)

    total_params = 0
    param_sizes = {}

    for init in model.graph.initializer:
        dims = list(init.dims)
        # Count elements from the declared dims instead of materializing every
        # weight as a NumPy array. A scalar has dims [] and math.prod gives 1.
        size = math.prod(dims)
        total_params += size
        param_sizes[init.name] = {
            "shape": dims,
            "size": size,
            # Same NumPy dtype name numpy_helper.to_array would produce.
            "dtype": str(onnx.helper.tensor_dtype_to_np_dtype(init.data_type)),
        }

    # Plain dict (not Counter) keeps the printed representation unchanged;
    # insertion order follows first occurrence in the graph.
    node_types = dict(Counter(node.op_type for node in model.graph.node))

    return {
        "total_params": total_params,
        "param_details": param_sizes,
        "node_count": len(model.graph.node),
        "node_types": node_types
    }

stats = get_model_stats("model.onnx")
print(f"总参数量: {stats['total_params']:,}")
print(f"节点类型: {stats['node_types']}")

下一步 #

现在你已经了解了 ONNX Python API,接下来将学习如何使用 ONNX Runtime 进行高性能推理!

最后更新:2026-04-04