Python API #
API 概述 #
ONNX Python 库提供了完整的模型操作接口,包括加载、保存、验证、修改等功能。
text
┌─────────────────────────────────────────────────────────────┐
│ ONNX Python API 结构 │
├─────────────────────────────────────────────────────────────┤
│ │
│ onnx │
│ ├── 模型操作 │
│ │ ├── load() - 加载模型 │
│ │ ├── save() - 保存模型 │
│ │ ├── load_model() - 加载模型(带选项) │
│ │ └── save_model() - 保存模型(带选项) │
│ │ │
│ ├── 模型验证 │
│ │ ├── checker.check_model() - 验证模型 │
│ │ └── checker.check_model(full_check=True) - 完整验证 │
│ │ │
│ ├── 模型工具 │
│ │ ├── helper - 创建模型组件 │
│ │ ├── numpy_helper - NumPy 转换 │
│ │ └── shape_inference - 形状推断 │
│ │ │
│ └── 模型优化 │
│ └── optimizer - 图优化(onnx 1.9 起移至独立的 onnxoptimizer 包) │
│ │
└─────────────────────────────────────────────────────────────┘
模型加载与保存 #
基本加载 #
python
import onnx
model = onnx.load("model.onnx")
print(f"IR 版本: {model.ir_version}")
print(f"节点数量: {len(model.graph.node)}")
加载选项 #
python
import onnx
from onnx import ModelProto
model = ModelProto()
with open("model.onnx", "rb") as f:
model.ParseFromString(f.read())
model = onnx.load_model("model.onnx", format="protobuf")
try:
model = onnx.load("model.onnx", load_external_data=False)
except Exception as e:
print(f"加载失败: {e}")
基本保存 #
python
import onnx
model = onnx.load("model.onnx")
onnx.save(model, "output.onnx")
保存大型模型 #
python
import onnx
from onnx.external_data_helper import convert_model_to_external_data
model = onnx.load("large_model.onnx")
convert_model_to_external_data(
model,
all_tensors_to_one_file=True,
location="weights.bin",
size_threshold=1024,
convert_attribute=False
)
onnx.save(model, "large_model_external.onnx")
加载外部数据模型 #
python
import onnx
model = onnx.load("large_model_external.onnx", load_external_data=True)
模型验证 #
基本验证 #
python
import onnx
model = onnx.load("model.onnx")
try:
onnx.checker.check_model(model)
print("模型验证通过")
except onnx.checker.ValidationError as e:
print(f"模型验证失败: {e}")
完整验证 #
python
import onnx
model = onnx.load("model.onnx")
try:
onnx.checker.check_model(model, full_check=True)
print("完整验证通过")
except Exception as e:
print(f"验证失败: {e}")
验证文件 #
python
import onnx
onnx.checker.check_model("model.onnx")
Helper 模块 #
创建张量值信息 #
python
from onnx import helper, TensorProto
input_tensor = helper.make_tensor_value_info(
"input",
TensorProto.FLOAT,
[1, 3, 224, 224]
)
dynamic_input = helper.make_tensor_value_info(
"input",
TensorProto.FLOAT,
None
)
dynamic_input.type.tensor_type.shape.dim.add().dim_param = "N"
dynamic_input.type.tensor_type.shape.dim.add().dim_value = 3
dynamic_input.type.tensor_type.shape.dim.add().dim_value = 224
dynamic_input.type.tensor_type.shape.dim.add().dim_value = 224
创建节点 #
python
from onnx import helper
conv_node = helper.make_node(
"Conv",
inputs=["X", "W", "B"],
outputs=["Y"],
name="conv1",
kernel_shape=[3, 3],
strides=[1, 1],
pads=[1, 1, 1, 1]
)
relu_node = helper.make_node(
"Relu",
inputs=["Y"],
outputs=["Z"],
name="relu1"
)
add_node = helper.make_node(
"Add",
inputs=["A", "B"],
outputs=["C"],
name="add1"
)
创建图 #
python
# Fix: TensorProto is used below but was not imported, so this snippet
# raised NameError when run on its own.
from onnx import helper, TensorProto

# Declare the graph-level input and output tensor signatures.
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 224, 224])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 64, 112, 112])
conv_node = helper.make_node(
    "Conv",
    inputs=["X", "W"],
    outputs=["conv_out"],
    kernel_shape=[7, 7],
    strides=[2, 2],
    pads=[3, 3, 3, 3]
)
relu_node = helper.make_node(
    "Relu",
    inputs=["conv_out"],
    outputs=["Y"]
)
# Assemble the nodes into a named graph with explicit inputs/outputs.
graph = helper.make_graph(
    [conv_node, relu_node],
    "simple_graph",
    [X],
    [Y],
    initializer=[]
)
创建模型 #
python
from onnx import helper
model = helper.make_model(
graph,
producer_name="my_tool",
producer_version="1.0.0",
opset_imports=[helper.make_opsetid("", 17)]
)
model.ir_version = 8
创建 OpsetID #
python
from onnx import helper
default_opset = helper.make_opsetid("", 17)
custom_opset = helper.make_opsetid("custom_domain", 1)
model = helper.make_model(
graph,
opset_imports=[default_opset, custom_opset]
)
NumPy Helper 模块 #
张量转 NumPy #
python
import onnx
from onnx import numpy_helper
model = onnx.load("model.onnx")
for initializer in model.graph.initializer:
weights = numpy_helper.to_array(initializer)
print(f"{initializer.name}: shape={weights.shape}, dtype={weights.dtype}")
NumPy 转张量 #
python
import numpy as np
from onnx import numpy_helper
weights = np.random.randn(64, 3, 7, 7).astype(np.float32)
tensor = numpy_helper.from_array(weights, name="conv_weight")
print(f"名称: {tensor.name}")
print(f"形状: {list(tensor.dims)}")
print(f"类型: {tensor.data_type}")
处理不同数据类型 #
python
import numpy as np
from onnx import numpy_helper, TensorProto
float32_tensor = numpy_helper.from_array(
np.random.randn(10, 10).astype(np.float32),
name="float32_data"
)
int64_tensor = numpy_helper.from_array(
np.array([1, 2, 3, 4, 5], dtype=np.int64),
name="int64_data"
)
float16_tensor = numpy_helper.from_array(
np.random.randn(10, 10).astype(np.float16),
name="float16_data"
)
int8_tensor = numpy_helper.from_array(
np.array([1, 2, 3], dtype=np.int8),
name="int8_data"
)
形状推断 #
基本形状推断 #
python
import onnx
from onnx import shape_inference
model = onnx.load("model.onnx")
inferred_model = shape_inference.infer_shapes(model)
onnx.save(inferred_model, "model_with_shapes.onnx")
检查形状信息 #
python
import onnx
from onnx import shape_inference
model = onnx.load("model.onnx")
inferred_model = shape_inference.infer_shapes(model)
for output in inferred_model.graph.output:
shape = [d.dim_value or d.dim_param or "?"
for d in output.type.tensor_type.shape.dim]
print(f"{output.name}: {shape}")
for vi in inferred_model.graph.value_info:
shape = [d.dim_value or d.dim_param or "?"
for d in vi.type.tensor_type.shape.dim]
print(f"{vi.name}: {shape}")
形状推断选项 #
python
from onnx import shape_inference
inferred_model = shape_inference.infer_shapes(
model,
check_type=True,
strict_mode=True,
data_prop=True
)
模型修改 #
添加节点 #
python
import onnx
from onnx import helper
model = onnx.load("model.onnx")
graph = model.graph
new_node = helper.make_node(
"Relu",
inputs=["old_output"],
outputs=["new_output"],
name="new_relu"
)
graph.node.append(new_node)
graph.output[0].name = "new_output"
删除节点 #
python
import onnx
model = onnx.load("model.onnx")
graph = model.graph
nodes_to_remove = ["node_to_remove_1", "node_to_remove_2"]
new_nodes = [node for node in graph.node if node.name not in nodes_to_remove]
del graph.node[:]
graph.node.extend(new_nodes)
修改节点属性 #
python
import onnx
from onnx import helper, AttributeProto
model = onnx.load("model.onnx")
for node in model.graph.node:
if node.op_type == "Conv":
for attr in node.attribute:
if attr.name == "strides":
attr.ints[:] = [2, 2]
break
onnx.save(model, "model_modified.onnx")
替换初始化器 #
python
import onnx
from onnx import numpy_helper
import numpy as np
model = onnx.load("model.onnx")
for initializer in model.graph.initializer:
if initializer.name == "conv1.weight":
old_weights = numpy_helper.to_array(initializer)
new_weights = old_weights * 0.5
new_initializer = numpy_helper.from_array(
new_weights,
name=initializer.name
)
for i, init in enumerate(model.graph.initializer):
if init.name == "conv1.weight":
model.graph.initializer[i].CopyFrom(new_initializer)
break
onnx.save(model, "model_new_weights.onnx")
模型优化 #
使用优化器 #
python
import onnx
# NOTE(review): `onnx.optimizer` was removed from the onnx package in
# release 1.9 — this import fails on any recent onnx version. The same
# passes now live in the standalone `onnxoptimizer` package
# (pip install onnxoptimizer); confirm which onnx version this tutorial
# targets before relying on this snippet.
from onnx import optimizer
model = onnx.load("model.onnx")
# Enumerate every optimization pass the optimizer exposes.
all_passes = optimizer.get_available_passes()
print(f"可用优化: {all_passes}")
# Chosen subset: dead-node elimination plus common operator fusions.
passes = [
    "eliminate_identity",
    "eliminate_nop_transpose",
    "fuse_bn_into_conv",
    "fuse_consecutive_transposes",
    "eliminate_unused_initializer"
]
optimized_model = optimizer.optimize(model, passes)
onnx.save(optimized_model, "model_optimized.onnx")
常用优化 Pass #
text
┌─────────────────────────────────────────────────────────────┐
│ 常用优化 Pass │
├─────────────────────────────────────────────────────────────┤
│ │
│ 消除优化: │
│ ├── eliminate_identity - 移除 Identity 节点 │
│ ├── eliminate_nop_transpose - 移除无效 Transpose │
│ ├── eliminate_nop_pad - 移除无效 Pad │
│ ├── eliminate_unused_initializer - 移除未使用参数 │
│ └── eliminate_deadend - 移除死代码 │
│ │
│ 融合优化: │
│ ├── fuse_bn_into_conv - BatchNorm 融入 Conv │
│ ├── fuse_consecutive_concats - 合并连续 Concat │
│ ├── fuse_consecutive_reduce_unsqueeze - 合并归约操作 │
│ ├── fuse_consecutive_squeezes - 合并连续 Squeeze │
│ └── fuse_consecutive_transposes - 合并连续 Transpose │
│ │
│ 其他优化: │
│ ├── extract_constant_to_initializer - 提取常量 │
│ ├── lift_lexical_references - 提升词法引用 │
│ └── split_init - 分离初始化 │
│ │
└─────────────────────────────────────────────────────────────┘
模型信息提取 #
获取模型摘要 #
python
import onnx
def _shape_of(value_info):
    # Human-readable shape: concrete dim value, symbolic dim name, or "?"
    # when the dimension carries neither.
    return [d.dim_value or d.dim_param or "?"
            for d in value_info.type.tensor_type.shape.dim]

def get_model_summary(model_path):
    """Load an ONNX model and return a dict summarizing its structure.

    The returned dict contains IR/producer/opset metadata plus per-input,
    per-output, per-node, and per-initializer details.
    """
    model = onnx.load(model_path)
    graph = model.graph
    summary = {
        "ir_version": model.ir_version,
        "producer": f"{model.producer_name} {model.producer_version}",
        # An empty domain string denotes the default 'ai.onnx' operator set.
        "opset": [f"{op.domain or 'ai.onnx'}:{op.version}"
                  for op in model.opset_import],
        "graph_name": graph.name,
        "inputs": [],
        "outputs": [],
        "nodes": [],
        "initializers": []
    }
    for inp in graph.input:
        summary["inputs"].append({
            "name": inp.name,
            "shape": _shape_of(inp),
            "type": onnx.TensorProto.DataType.Name(inp.type.tensor_type.elem_type)
        })
    for out in graph.output:
        summary["outputs"].append({
            "name": out.name,
            "shape": _shape_of(out),
            # Record the element type for outputs too, mirroring "inputs"
            # (the original version omitted it only here).
            "type": onnx.TensorProto.DataType.Name(out.type.tensor_type.elem_type)
        })
    for node in graph.node:
        summary["nodes"].append({
            "name": node.name,
            "type": node.op_type,
            "inputs": list(node.input),
            "outputs": list(node.output)
        })
    for init in graph.initializer:
        summary["initializers"].append({
            "name": init.name,
            "shape": list(init.dims)
        })
    return summary
summary = get_model_summary("model.onnx")
print(f"节点数: {len(summary['nodes'])}")
print(f"参数数: {len(summary['initializers'])}")
统计模型信息 #
python
import onnx
from onnx import numpy_helper
def get_model_stats(model_path):
    """Load an ONNX model and return parameter and node statistics.

    Returns a dict with the total parameter count, per-initializer
    details, the node count, and a histogram of operator types.
    """
    model = onnx.load(model_path)
    total_params = 0
    param_sizes = {}
    # Walk every initializer, recording its shape/size/dtype and
    # accumulating the overall element count.
    for tensor in model.graph.initializer:
        values = numpy_helper.to_array(tensor)
        total_params += values.size
        param_sizes[tensor.name] = {
            "shape": list(values.shape),
            "size": values.size,
            "dtype": str(values.dtype)
        }
    # Histogram of operator types appearing in the graph.
    node_types = {}
    for node in model.graph.node:
        if node.op_type in node_types:
            node_types[node.op_type] += 1
        else:
            node_types[node.op_type] = 1
    return {
        "total_params": total_params,
        "param_details": param_sizes,
        "node_count": len(model.graph.node),
        "node_types": node_types
    }
stats = get_model_stats("model.onnx")
print(f"总参数量: {stats['total_params']:,}")
print(f"节点类型: {stats['node_types']}")
下一步 #
现在你已经了解了 ONNX Python API,接下来学习 ONNX Runtime,学习如何使用 ONNX Runtime 进行高性能推理!
最后更新:2026-04-04