核心概念 #

ONNX 模型结构 #

ONNX 模型是一个有向无环图(DAG),由节点(Node)、边(Edge)和属性(Attribute)组成。

text
┌─────────────────────────────────────────────────────────────┐
│                    ONNX 模型层次结构                         │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  ModelProto (模型)                                          │
│  ├── ir_version: IR 版本号                                 │
│  ├── opset_import: 算子集版本                              │
│  ├── producer_name: 生成器名称                             │
│  ├── producer_version: 生成器版本                          │
│  │                                                          │
│  └── GraphProto (计算图)                                    │
│      ├── name: 图名称                                      │
│      ├── input: 输入张量列表                               │
│      ├── output: 输出张量列表                              │
│      ├── initializer: 初始化参数                           │
│      │                                                      │
│      └── NodeProto (节点)                                   │
│          ├── name: 节点名称                                │
│          ├── op_type: 算子类型                             │
│          ├── input: 输入张量名                             │
│          ├── output: 输出张量名                            │
│          └── attribute: 算子属性                           │
│                                                             │
└─────────────────────────────────────────────────────────────┘

模型元数据 #

python
import onnx

model = onnx.load("model.onnx")

print(f"IR 版本: {model.ir_version}")
print(f"生成器: {model.producer_name}")
print(f"生成器版本: {model.producer_version}")

for opset in model.opset_import:
    print(f"算子集: domain={opset.domain}, version={opset.version}")

计算图(Graph) #

计算图是 ONNX 模型的核心,定义了数据流和计算过程。

图的组成 #

text
┌─────────────────────────────────────────────────────────────┐
│                    计算图示例                                │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│   输入                                                      │
│   ┌─────┐                                                   │
│   │  X  │                                                   │
│   └──┬──┘                                                   │
│      │                                                      │
│      ▼                                                      │
│   ┌─────────┐    W1    ┌─────────┐                         │
│   │  Conv   │◄─────────│ Weight1 │                          │
│   └────┬────┘          └─────────┘                         │
│        │                                                    │
│        ▼                                                    │
│   ┌─────────┐                                              │
│   │  ReLU   │                                              │
│   └────┬────┘                                              │
│        │                                                    │
│        ▼                                                    │
│   ┌─────────┐    W2    ┌─────────┐                         │
│   │  Conv   │◄─────────│ Weight2 │                          │
│   └────┬────┘          └─────────┘                         │
│        │                                                    │
│        ▼                                                    │
│   ┌─────────┐                                              │
│   │  ReLU   │                                              │
│   └────┬────┘                                              │
│        │                                                    │
│        ▼                                                    │
│   ┌─────┐                                                   │
│   │  Y  │  输出                                            │
│   └─────┘                                                   │
│                                                             │
└─────────────────────────────────────────────────────────────┘

图的基本操作 #

python
import onnx
from onnx import numpy_helper

model = onnx.load("model.onnx")
graph = model.graph

print(f"图名称: {graph.name}")
print(f"输入数量: {len(graph.input)}")
print(f"输出数量: {len(graph.output)}")
print(f"节点数量: {len(graph.node)}")
print(f"初始化参数数量: {len(graph.initializer)}")

print("\n输入:")
for input_tensor in graph.input:
    print(f"  - {input_tensor.name}: {[d.dim_value for d in input_tensor.type.tensor_type.shape.dim]}")

print("\n输出:")
for output_tensor in graph.output:
    print(f"  - {output_tensor.name}")

print("\n节点:")
for node in graph.node:
    print(f"  - {node.op_type}: {node.input} -> {node.output}")

张量(Tensor) #

张量是 ONNX 中的基本数据单元,用于表示输入、输出和中间结果。

张量类型 #

text
┌─────────────────────────────────────────────────────────────┐
│                    ONNX 数据类型                             │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  浮点类型:                                                 │
│  ├── TensorProto.FLOAT (1)      - float32                  │
│  ├── TensorProto.FLOAT16 (10)   - float16                  │
│  ├── TensorProto.BFLOAT16 (16)  - bfloat16                 │
│  └── TensorProto.DOUBLE (11)    - float64                  │
│                                                             │
│  整数类型:                                                 │
│  ├── TensorProto.INT8 (3)       - int8                     │
│  ├── TensorProto.INT16 (5)      - int16                    │
│  ├── TensorProto.INT32 (6)      - int32                    │
│  ├── TensorProto.INT64 (7)      - int64                    │
│  ├── TensorProto.UINT8 (2)      - uint8                    │
│  ├── TensorProto.UINT16 (9)     - uint16                   │
│  ├── TensorProto.UINT32 (12)    - uint32                   │
│  └── TensorProto.UINT64 (13)    - uint64                   │
│                                                             │
│  其他类型:                                                 │
│  ├── TensorProto.BOOL (9)       - bool                     │
│  ├── TensorProto.STRING (8)     - string                   │
│  └── TensorProto.COMPLEX64/128  - 复数                      │
│                                                             │
└─────────────────────────────────────────────────────────────┘

张量形状 #

python
import onnx
from onnx import TensorProto

model = onnx.load("model.onnx")
graph = model.graph

for input_tensor in graph.input:
    tensor_type = input_tensor.type.tensor_type
    
    shape = []
    for dim in tensor_type.shape.dim:
        if dim.dim_value:
            shape.append(dim.dim_value)
        elif dim.dim_param:
            shape.append(dim.dim_param)
        else:
            shape.append(-1)
    
    print(f"{input_tensor.name}: shape={shape}, dtype={TensorProto.DataType.Name(tensor_type.elem_type)}")

动态形状 #

text
┌─────────────────────────────────────────────────────────────┐
│                    动态形状定义                              │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  静态形状:                                                 │
│  input: [1, 3, 224, 224]  # batch 固定为 1                 │
│                                                             │
│  动态形状:                                                 │
│  input: [batch_size, 3, 224, 224]                          │
│                                                             │
│  定义方式:                                                 │
│  dynamic_axes = {                                           │
│      "input": {0: "batch_size"},                           │
│      "output": {0: "batch_size"}                           │
│  }                                                          │
│                                                             │
│  多维度动态:                                               │
│  dynamic_axes = {                                           │
│      "input": {                                             │
│          0: "batch_size",                                  │
│          2: "height",                                      │
│          3: "width"                                        │
│      }                                                      │
│  }                                                          │
│                                                             │
└─────────────────────────────────────────────────────────────┘

算子(Operator) #

算子是计算图中的节点,定义了具体的计算操作。

标准算子分类 #

text
┌─────────────────────────────────────────────────────────────┐
│                    ONNX 算子分类                             │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  神经网络层:                                               │
│  ├── Conv - 卷积                                           │
│  ├── ConvTranspose - 转置卷积                              │
│  ├── MaxPool / AveragePool - 池化                          │
│  ├── BatchNormalization - 批归一化                         │
│  ├── Dropout - 随机失活                                    │
│  └── Gemm - 通用矩阵乘法                                   │
│                                                             │
│  激活函数:                                                 │
│  ├── Relu - ReLU                                          │
│  ├── Sigmoid - Sigmoid                                    │
│  ├── Tanh - Tanh                                          │
│  ├── LeakyRelu - Leaky ReLU                               │
│  ├── Softmax - Softmax                                    │
│  └── PRelu - PReLU                                        │
│                                                             │
│  数学运算:                                                 │
│  ├── Add, Sub, Mul, Div - 四则运算                        │
│  ├── MatMul - 矩阵乘法                                    │
│  ├── Pow, Sqrt, Exp - 指数运算                            │
│  ├── Log - 对数                                           │
│  └── ReduceMean, ReduceSum - 归约运算                     │
│                                                             │
│  张量操作:                                                 │
│  ├── Reshape - 形状变换                                    │
│  ├── Transpose - 转置                                      │
│  ├── Concat - 拼接                                        │
│  ├── Split - 分割                                         │
│  ├── Slice - 切片                                         │
│  ├── Gather - 收集                                        │
│  └── Flatten - 展平                                       │
│                                                             │
│  控制流:                                                   │
│  ├── If - 条件分支                                        │
│  └── Loop - 循环                                          │
│                                                             │
└─────────────────────────────────────────────────────────────┘

算子属性 #

python
import onnx

model = onnx.load("model.onnx")

for node in model.graph.node:
    if node.op_type == "Conv":
        for attr in node.attribute:
            print(f"属性: {attr.name}")
            print(f"  类型: {onnx.AttributeProto.AttributeType.Name(attr.type)}")
            
            if attr.type == onnx.AttributeProto.INTS:
                print(f"  值: {list(attr.ints)}")
            elif attr.type == onnx.AttributeProto.INT:
                print(f"  值: {attr.i}")
            elif attr.type == onnx.AttributeProto.FLOATS:
                print(f"  值: {list(attr.floats)}")
            elif attr.type == onnx.AttributeProto.FLOAT:
                print(f"  值: {attr.f}")

常见算子示例 #

text
Conv 算子:
├── 输入: X (特征), W (权重), B (偏置,可选)
├── 输出: Y (输出特征)
└── 属性:
    ├── kernel_shape: 卷积核大小
    ├── strides: 步长
    ├── pads: 填充
    ├── dilations: 膨胀
    └── group: 分组数

Gemm 算子(全连接层):
├── 输入: A, B, C (可选偏置)
├── 输出: Y
└── 属性:
    ├── alpha: A 的缩放因子
    ├── beta: C 的缩放因子
    ├── transA: 是否转置 A
    └── transB: 是否转置 B

算子集(Opset) #

算子集定义了算子的版本和语义,确保模型的可移植性。

Opset 版本 #

text
┌─────────────────────────────────────────────────────────────┐
│                    Opset 版本演进                            │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  opset 17 (ONNX 1.13+)                                     │
│  ├── 新增: LayerNormalization, Reshape 优化                │
│  └── 修改: 若干算子行为                                    │
│                                                             │
│  opset 16 (ONNX 1.11+)                                     │
│  ├── 新增: Identity, ScatterND 改进                        │
│  └── 修改: Round 算子                                      │
│                                                             │
│  opset 15 (ONNX 1.10+)                                     │
│  ├── 新增: DFT, HannWindow                                │
│  └── 修改: 若干算子签名                                    │
│                                                             │
│  opset 14 (ONNX 1.9+)                                      │
│  ├── 新增: Trilu, CumSum 改进                             │
│  └── 修改: Reshape 支持动态形状                            │
│                                                             │
│  opset 13 (ONNX 1.8+)                                      │
│  ├── 新增: Einsum, Col2Im                                 │
│  └── 修改: 若干算子类型支持                                │
│                                                             │
│  opset 11 (ONNX 1.6+)                                      │
│  ├── 动态形状支持                                          │
│  └── 循环和条件支持                                        │
│                                                             │
└─────────────────────────────────────────────────────────────┘

Opset 兼容性 #

python
import onnx
from onnx import version_converter

model = onnx.load("model.onnx")

print(f"当前 opset 版本: {model.opset_import[0].version}")

converted_model = version_converter.convert_version(model, 14)

onnx.save(converted_model, "model_opset14.onnx")

初始化器(Initializer) #

初始化器存储模型的权重和偏置等参数。

访问初始化器 #

python
import onnx
from onnx import numpy_helper

model = onnx.load("model.onnx")

for initializer in model.graph.initializer:
    name = initializer.name
    dims = list(initializer.dims)
    dtype = initializer.data_type
    
    weights = numpy_helper.to_array(initializer)
    
    print(f"参数: {name}")
    print(f"  形状: {dims}")
    print(f"  类型: {dtype}")
    print(f"  NumPy 形状: {weights.shape}")
    print(f"  值范围: [{weights.min():.4f}, {weights.max():.4f}]")

创建初始化器 #

python
import onnx
from onnx import numpy_helper, TensorProto
import numpy as np

weights = np.random.randn(64, 3, 3, 3).astype(np.float32)

initializer = numpy_helper.from_array(weights, name="conv1.weight")

model.graph.initializer.append(initializer)

值信息(ValueInfo) #

值信息描述了张量的名称、类型和形状。

python
import onnx

model = onnx.load("model.onnx")
graph = model.graph

def print_value_info(value_info, prefix=""):
    name = value_info.name
    tensor_type = value_info.type.tensor_type
    
    shape = []
    for dim in tensor_type.shape.dim:
        if dim.dim_value:
            shape.append(dim.dim_value)
        elif dim.dim_param:
            shape.append(dim.dim_param)
        else:
            shape.append("?")
    
    dtype = onnx.TensorProto.DataType.Name(tensor_type.elem_type)
    
    print(f"{prefix}{name}: shape={shape}, dtype={dtype}")

print("输入:")
for vi in graph.input:
    print_value_info(vi, "  ")

print("\n输出:")
for vi in graph.output:
    print_value_info(vi, "  ")

模型验证 #

ONNX 提供了模型验证工具,确保模型符合规范。

基本验证 #

python
import onnx

model = onnx.load("model.onnx")

try:
    onnx.checker.check_model(model)
    print("模型验证通过!")
except onnx.checker.ValidationError as e:
    print(f"模型验证失败: {e}")

详细验证 #

python
import onnx
from onnx import checker, helper

model = onnx.load("model.onnx")

checker.check_model(model, full_check=True)

try:
    onnx.helper.printable_graph(model.graph)
except Exception as e:
    print(f"图打印失败: {e}")

形状推断 #

python
import onnx
from onnx import shape_inference

model = onnx.load("model.onnx")

model_with_shapes = shape_inference.infer_shapes(model)

onnx.save(model_with_shapes, "model_with_shapes.onnx")

模型可视化 #

使用 Netron #

python
import netron

netron.start("model.onnx")

netron.start("model.onnx", port=8080)

使用 ONNX 自带工具 #

python
import onnx

model = onnx.load("model.onnx")

print(onnx.helper.printable_graph(model.graph))

下一步 #

现在你已经掌握了 ONNX 的核心概念,接下来学习 模型结构,深入了解 ONNX 模型的详细结构!

最后更新:2026-04-04