自定义算子 #

概述 #

当标准 ONNX 算子无法满足需求时,可以通过自定义算子扩展 ONNX 功能。

text
┌─────────────────────────────────────────────────────────────┐
│                    自定义算子应用场景                        │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  需要自定义算子的场景:                                     │
│  ├── 使用了框架特有的算子                                   │
│  ├── 需要特殊的数学运算                                     │
│  ├── 业务特定的处理逻辑                                     │
│  └── 性能优化的专用实现                                     │
│                                                             │
│  自定义算子组成:                                           │
│  ├── 算子定义 - 定义算子签名和属性                         │
│  ├── 算子实现 - 在推理引擎中实现算子                       │
│  └── 导出支持 - 在框架中支持导出                           │
│                                                             │
└─────────────────────────────────────────────────────────────┘

PyTorch 自定义算子导出 #

注册符号函数 #

python
import torch
import torch.onnx
from torch.onnx import register_custom_op_symbolic

class CustomAdd(torch.autograd.Function):
    """Compute ``x + y + 1`` and export it as ``CustomDomain::CustomAdd``.

    The constant 1 is mirrored on the exported node as the float attribute
    ``alpha`` (written ``alpha_f`` in ``g.op`` syntax).
    """

    @staticmethod
    def forward(ctx, x, y):
        return x + y + 1

    @staticmethod
    def backward(ctx, grad_output):
        # d(x + y + 1)/dx = d(x + y + 1)/dy = 1. Inputs that do not need a
        # gradient (e.g. a Python scalar y) must receive None.
        grad_x = grad_output if ctx.needs_input_grad[0] else None
        grad_y = grad_output if ctx.needs_input_grad[1] else None
        return grad_x, grad_y

    @staticmethod
    def symbolic(g, x, y):
        # Used by the TorchScript-based ONNX exporter in place of forward.
        return g.op("CustomDomain::CustomAdd", x, y, alpha_f=1.0)

class MyModel(torch.nn.Module):
    """Two-input model whose only operation is ``CustomAdd``."""

    def forward(self, x, y):
        result = CustomAdd.apply(x, y)
        return result

def custom_add_symbolic(g, x, y):
    """Emit a ``CustomDomain::CustomAdd`` node with float attribute ``alpha=1.0``."""
    qualified_op = "CustomDomain::CustomAdd"
    return g.op(qualified_op, x, y, alpha_f=1.0)

# CustomAdd.symbolic (defined on the autograd.Function) already tells the
# TorchScript-based exporter how to emit CustomDomain::CustomAdd, so no extra
# registration is needed here. register_custom_op_symbolic registers symbolics
# for TorchScript op names and only accepts a qualified "domain::name"
# (e.g. "aten::relu"); the bare name "CustomAdd" is rejected with an error.
# NOTE(review): on PyTorch releases where torch.onnx.export defaults to the
# dynamo-based exporter, autograd.Function.symbolic is not consulted; pass
# dynamo=False there.
model = MyModel()
x = torch.randn(1, 3)
y = torch.randn(1, 3)

torch.onnx.export(
    model,
    (x, y),
    "custom_add.onnx",
    input_names=["x", "y"],
    output_names=["output"],
    opset_version=17,
    # Version of the custom domain referenced by the emitted node.
    custom_opsets={"CustomDomain": 1},
)

使用符号函数注册 API #

python
import torch
from torch.onnx._internal import registration

import torch.onnx  # public symbolic-registration API


def my_relu_symbolic(g, x):
    """Symbolic for ``aten::relu`` that emits ``CustomDomain::MyRelu`` instead of ONNX ``Relu``."""
    return g.op("CustomDomain::MyRelu", x)


class MyReluModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)


# The symbolic must be registered under the op the model actually calls
# (torch.relu -> aten::relu); a symbolic registered under an unused name is
# never invoked during export.
torch.onnx.register_custom_op_symbolic("aten::relu", my_relu_symbolic, 17)
try:
    model = MyReluModel()
    torch.onnx.export(
        model,
        torch.randn(1, 3),
        "my_relu.onnx",
        opset_version=17,
        custom_opsets={"CustomDomain": 1},
    )
finally:
    # The registration is process-wide; undo it so later exports of torch.relu
    # produce the standard Relu node again.
    torch.onnx.unregister_custom_op_symbolic("aten::relu", 17)

带属性的自定义算子 #

python
import torch
from torch.onnx import register_custom_op_symbolic

class CustomScale(torch.autograd.Function):
    """Multiply a tensor by a Python scalar; exported as ``CustomDomain::Scale``."""

    @staticmethod
    def forward(ctx, x, scale):
        # A non-tensor constant only needs a plain ctx attribute;
        # save_for_backward is meant for tensors.
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, grad_output):
        # d(x * scale)/dx = scale; scale is treated as a constant (no gradient).
        return grad_output * ctx.scale, None

    @staticmethod
    def symbolic(g, x, scale):
        # scale must be a Python number here: it becomes the float attribute
        # "scale" on the exported node.
        return g.op("CustomDomain::Scale", x, scale_f=scale)

def custom_scale_symbolic(g, x, scale):
    """Emit a ``CustomDomain::Scale`` node carrying ``scale`` as a float attribute."""
    node = g.op("CustomDomain::Scale", x, scale_f=scale)
    return node

# No register_custom_op_symbolic call is needed here: CustomScale.symbolic is
# used directly by the TorchScript-based exporter. register_custom_op_symbolic
# only accepts qualified TorchScript op names of the form "domain::name", so
# registering the bare name "CustomScale" raises instead of registering anything.

class ScaleModel(torch.nn.Module):
    """Multiply the input by a fixed scalar through ``CustomScale``."""

    def __init__(self, scale=2.0):
        super().__init__()
        # Stored as a plain Python value, not as a parameter or buffer.
        self.scale = scale

    def forward(self, x):
        factor = self.scale
        return CustomScale.apply(x, factor)

model = ScaleModel(scale=3.0)

# The scale (3.0) is written onto the CustomDomain::Scale node as an attribute.
torch.onnx.export(
    model,
    torch.randn(1, 3),
    f="custom_scale.onnx",
    custom_opsets={"CustomDomain": 1},
)

ONNX Runtime 自定义算子 #

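Python 实现 #

注意:ONNX Runtime 的 Python 包本身不会执行用 Python 编写的算子函数。下面的函数只描述算子的计算逻辑。推理时仍需通过 register_custom_ops_library 加载编译好的自定义算子库,或借助 onnxruntime-extensions 等扩展。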

python
import onnxruntime as ort
import numpy as np

class CustomOpModel:
    """Empty placeholder: takes no constructor arguments and holds no state."""

def custom_add_impl(x, y, alpha):
    """Reference computation for CustomAdd: ``x + y + alpha``."""
    partial_sum = x + y
    return partial_sum + alpha

def get_custom_ops():
    """Describe the Python-side custom ops of this example.

    Returns:
        dict mapping ``"<domain>.<op>"`` to ``(impl, input_dtypes, output_dtypes)``.
        The dtypes are given as NumPy scalar types.

    NOTE(review): nothing in this snippet passes this table to ONNX Runtime.
    The session below only sees ops from the library registered with
    ``register_custom_ops_library``.
    """
    return {
        "CustomDomain.CustomAdd": (
            custom_add_impl,
            [np.float32, np.float32],
            [np.float32],
        )
    }

so = ort.SessionOptions()
# Kernels for the custom domain come from a compiled shared library, which is
# registered on the options object before the session is built.
so.register_custom_ops_library("custom_ops_library.so")

session = ort.InferenceSession("custom_add.onnx", sess_options=so)

C++ 实现 #

cpp
#include "onnxruntime_c_api.h"
#include "onnxruntime_cxx_api.h"
#include <cmath>

struct CustomAddKernel {
    CustomAddKernel(OrtApi api) : api_(api) {}
    
    void Compute(OrtKernelContext* context) {
        Ort::KernelContext ctx(context);
        
        auto input_x = ctx.GetInput(0);
        auto input_y = ctx.GetInput(1);
        
        auto x_data = input_x.GetTensorData<float>();
        auto y_data = input_y.GetTensorData<float>();
        
        auto dimensions = input_x.GetTensorTypeAndShapeInfo().GetShape();
        
        auto output = ctx.GetOutput(0, dimensions);
        auto output_data = output.GetTensorMutableData<float>();
        
        size_t size = 1;
        for (auto dim : dimensions) {
            size *= dim;
        }
        
        float alpha = 1.0f;
        
        for (size_t i = 0; i < size; ++i) {
            output_data[i] = x_data[i] + y_data[i] + alpha;
        }
    }
    
private:
    OrtApi api_;
};

struct CustomAddOp {
    void* CreateKernel(OrtApi api, const OrtKernelInfo* info) {
        return new CustomAddKernel(api);
    }
    
    const char* GetName() const { return "CustomAdd"; }
    
    const char* GetExecutionProviderType() const { return "CPUExecutionProvider"; }
    
    size_t GetInputTypeCount() const { return 2; }
    
    ONNXTensorElementDataType GetInputType(size_t index) const {
        return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
    }
    
    size_t GetOutputTypeCount() const { return 1; }
    
    ONNXTensorElementDataType GetOutputType(size_t index) const {
        return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
    }
};

注册自定义算子库 #

python
import onnxruntime as ort

sess_options = ort.SessionOptions()
# Register the shared library on the options first, so its ops are known
# when model.onnx is loaded by the session.
sess_options.register_custom_ops_library("/path/to/custom_ops.so")
session = ort.InferenceSession("model.onnx", sess_options=sess_options)

完整示例 #

定义自定义算子 #

python
import torch
import torch.onnx
import onnx
from onnx import helper, TensorProto, numpy_helper
import numpy as np

class Swish(torch.autograd.Function):
    """Swish activation ``x * sigmoid(beta * x)``, exported as ``CustomDomain::Swish``."""

    @staticmethod
    def forward(ctx, x, beta=1.0):
        ctx.beta = beta
        ctx.save_for_backward(x)
        return x * torch.sigmoid(beta * x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        beta = ctx.beta
        sig = torch.sigmoid(beta * x)
        # d/dx [x * s(beta*x)] = s + beta * x * s * (1 - s)
        grad_x = grad_output * (sig + beta * x * sig * (1 - sig))
        # beta is treated as a constant attribute, so it gets no gradient.
        return grad_x, None

    @staticmethod
    def symbolic(g, x, beta=1.0):
        return g.op("CustomDomain::Swish", x, beta_f=beta)

def swish_symbolic(g, x, beta=1.0):
    """Emit a ``CustomDomain::Swish`` node with ``beta`` as a float attribute."""
    attrs = {"beta_f": beta}
    return g.op("CustomDomain::Swish", x, **attrs)

# No registration is needed: Swish.symbolic is used directly by the
# TorchScript-based exporter. torch.onnx.register_custom_op_symbolic only
# accepts qualified TorchScript op names ("domain::name"), so the bare name
# "Swish" would raise an error instead of registering anything.

class SwishModel(torch.nn.Module):
    """Apply the custom ``Swish`` activation with a fixed ``beta``."""

    def __init__(self, beta=1.0):
        super().__init__()
        # Plain Python value; it is passed to Swish as a constant.
        self.beta = beta

    def forward(self, x):
        activation_beta = self.beta
        return Swish.apply(x, activation_beta)

model = SwishModel(beta=1.0)
model.eval()  # export in inference mode

dummy_input = torch.randn(1, 3, 224, 224)

torch.onnx.export(
    model,
    dummy_input,
    f="swish_model.onnx",
    opset_version=17,
    input_names=["input"],
    output_names=["output"],
    # Only axis 0 (batch) is symbolic; the remaining dims stay fixed.
    dynamic_axes={
        "input": {0: "batch_size"},
        "output": {0: "batch_size"},
    },
    custom_opsets={"CustomDomain": 1},
)

# Reload the file and run ONNX's structural checker on it.
onnx_model = onnx.load("swish_model.onnx")
onnx.checker.check_model(onnx_model)
print("模型导出成功")

ONNX Runtime 实现 #

python
import onnxruntime as ort
import numpy as np
from onnxruntime.capi import _pybind_state as C

def swish_impl(x, beta=1.0):
    """NumPy reference for Swish: ``x * sigmoid(beta * x)``.

    The sigmoid is evaluated as ``0.5 * (1 + tanh(beta * x / 2))``. This is
    mathematically identical to ``1 / (1 + exp(-beta * x))`` but never
    overflows, whereas that ``exp`` overflows (emitting a RuntimeWarning) for
    large negative ``beta * x``. Float32 input stays float32.
    """
    return x * (0.5 * (1.0 + np.tanh(0.5 * beta * x)))

class SwishKernel:
    """Object wrapper that forwards to ``swish_impl``."""

    def compute(self, x, beta=1.0):
        """Return ``swish_impl(x, beta)``."""
        return swish_impl(x, beta)

def create_custom_op_library():
    """Placeholder for building a custom-op library; intentionally does nothing.

    Returns:
        None.
    """

sess_options = ort.SessionOptions()
# CustomDomain::Swish has no built-in ONNX Runtime kernel. A compiled
# custom-op library implementing it must be registered before the session is
# created; otherwise loading the model fails. Replace the placeholder path
# with the library you built.
sess_options.register_custom_ops_library("/path/to/custom_ops.so")

session = ort.InferenceSession(
    "swish_model.onnx", sess_options, providers=["CPUExecutionProvider"]
)

input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)

# Run the model and check it against the NumPy reference implementation.
input_name = session.get_inputs()[0].name
ort_output = session.run(None, {input_name: input_data})[0]
np.testing.assert_allclose(ort_output, swish_impl(input_data), rtol=1e-5, atol=1e-6)

算子定义规范 #

算子签名 #

python
from onnx import defs, helper

# Schema for custom_domain::Swish. It is only constructed here; registering it
# (e.g. with onnx.defs.register_schema) is a separate, process-wide step.
op_schema = defs.OpSchema(
    name="Swish",
    domain="custom_domain",
    since_version=1,
    doc="Swish activation function: x * sigmoid(beta * x)",
    inputs=[
        defs.OpSchema.FormalParameter("X", "T", "Input tensor")
    ],
    outputs=[
        defs.OpSchema.FormalParameter("Y", "T", "Output tensor")
    ],
    attributes=[
        # An optional attribute with a default is declared by passing an
        # AttributeProto as default_value; there is no `default=` keyword.
        defs.OpSchema.Attribute(
            "beta",
            helper.make_attribute("beta", 1.0),
            "Beta parameter",
        )
    ],
    # Type constraints are (type_param, allowed ONNX type strings, description)
    # tuples; allowed types use the "tensor(<elem>)" spelling.
    type_constraints=[
        ("T", ["tensor(float)", "tensor(double)"], "Float types")
    ],
)

创建带自定义算子的模型 #

python
import onnx
from onnx import helper, TensorProto, numpy_helper
import numpy as np

# Graph signature: one float tensor in, one tensor of the same shape out.
X = helper.make_tensor_value_info(name="X", elem_type=TensorProto.FLOAT, shape=[1, 3, 224, 224])
Y = helper.make_tensor_value_info(name="Y", elem_type=TensorProto.FLOAT, shape=[1, 3, 224, 224])

# A node in a non-default domain; that domain is also listed in opset_imports below.
swish_node = helper.make_node(
    op_type="Swish",
    inputs=["X"],
    outputs=["Y"],
    domain="custom_domain",
    beta=1.0,
)

graph = helper.make_graph(nodes=[swish_node], name="swish_graph", inputs=[X], outputs=[Y])

model = helper.make_model(
    graph,
    opset_imports=[
        helper.make_opsetid(domain="", version=17),
        helper.make_opsetid(domain="custom_domain", version=1),
    ],
    producer_name="custom_op_example",
)

onnx.save(model, "swish_custom.onnx")

调试技巧 #

检查自定义算子 #

python
import onnx

model = onnx.load("custom_model.onnx")

# Standard ONNX nodes normally leave `domain` empty; any node with a domain
# set is reported as a custom operator.
for node in model.graph.node:
    if not node.domain:
        continue
    print(f"自定义算子: {node.domain}::{node.op_type}")
    print(f"  输入: {list(node.input)}")
    print(f"  输出: {list(node.output)}")
    print(f"  属性: {[a.name for a in node.attribute]}")

验证导出 #

python
import torch
import onnx
import onnxruntime as ort
import numpy as np

def validate_custom_op_export(
    model,
    input_data,
    onnx_path,
    custom_ops_library=None,
    opset_version=None,
    rtol=1e-4,
    atol=1e-5,
):
    """Export *model* to ONNX, check the graph and optionally compare outputs.

    Args:
        model: torch.nn.Module called with a single input (``model(input_data)``).
        input_data: Example input tensor used for export and comparison.
        onnx_path: Path the ONNX file is written to.
        custom_ops_library: Optional path to a compiled ONNX Runtime custom-op
            library. When given, the exported model is run with ONNX Runtime
            and its output is compared with PyTorch's. Without it the numeric
            comparison is skipped, because ONNX Runtime cannot execute
            CustomDomain nodes without a kernel library.
        opset_version: Forwarded to ``torch.onnx.export``; None keeps its default.
        rtol, atol: Tolerances for the numeric comparison.

    Returns:
        True when all performed checks pass.

    Raises:
        onnx.checker.ValidationError: if the exported model is structurally invalid.
        AssertionError: if ONNX Runtime and PyTorch outputs differ beyond tolerance.
    """
    model.eval()

    with torch.no_grad():
        torch_output = model(input_data)

    torch.onnx.export(
        model,
        input_data,
        onnx_path,
        opset_version=opset_version,
        custom_opsets={"CustomDomain": 1},
    )

    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)

    print("自定义算子节点:")
    for node in onnx_model.graph.node:
        if node.domain:
            print(f"  {node.domain}::{node.op_type}")

    if custom_ops_library is None:
        return True

    sess_options = ort.SessionOptions()
    sess_options.register_custom_ops_library(custom_ops_library)
    session = ort.InferenceSession(
        onnx_path, sess_options, providers=["CPUExecutionProvider"]
    )
    input_name = session.get_inputs()[0].name
    ort_output = session.run(None, {input_name: input_data.detach().cpu().numpy()})[0]
    np.testing.assert_allclose(
        ort_output, torch_output.detach().cpu().numpy(), rtol=rtol, atol=atol
    )
    return True

# Export a default SwishModel and run the structural checks on the result.
model = SwishModel()
input_data = torch.randn(1, 3, 224, 224)
validate_custom_op_export(model, input_data, onnx_path="swish.onnx")

最佳实践 #

命名规范 #

text
┌─────────────────────────────────────────────────────────────┐
│                    自定义算子命名规范                        │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  域名:                                                     │
│  ├── 使用有意义的名称,如公司名或项目名                     │
│  ├── 避免使用空字符串(保留给标准算子)                     │
│  └── 示例: mycompany.custom_ops                            │
│                                                             │
│  算子名:                                                   │
│  ├── 使用 PascalCase                                       │
│  ├── 描述算子功能                                          │
│  └── 示例: Swish, FusedConvBN                              │
│                                                             │
│  属性名:                                                   │
│  ├── 使用 snake_case                                       │
│  ├── 在 g.op 中加类型后缀(如 _f 表示 float)              │
│  └── 示例: kernel_size_i, alpha_f, axes_i(列表也用 _i)    │
│                                                             │
└─────────────────────────────────────────────────────────────┘

版本管理 #

python
import onnx
from onnx import helper

# `graph` is the GraphProto built in the earlier "创建带自定义算子的模型" example.
model = helper.make_model(
    graph,
    opset_imports=[
        helper.make_opsetid(domain="", version=17),
        helper.make_opsetid(domain="custom_domain", version=1),
    ],
)

# The default ONNX domain is stored as an empty string.
for opset in model.opset_import:
    print(f"域: {opset.domain if opset.domain else 'default'}, 版本: {opset.version}")

下一步 #

现在你已经了解了自定义算子,接下来请学习「模型量化」,了解如何对模型进行量化以提升推理性能!

最后更新:2026-04-04