自定义算子 #
概述 #
当标准 ONNX 算子无法满足需求时,可以通过自定义算子扩展 ONNX 功能。
text
┌─────────────────────────────────────────────────────────────┐
│ 自定义算子应用场景 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 需要自定义算子的场景: │
│ ├── 使用了框架特有的算子 │
│ ├── 需要特殊的数学运算 │
│ ├── 业务特定的处理逻辑 │
│ └── 性能优化的专用实现 │
│ │
│ 自定义算子组成: │
│ ├── 算子定义 - 定义算子签名和属性 │
│ ├── 算子实现 - 在推理引擎中实现算子 │
│ └── 导出支持 - 在框架中支持导出 │
│ │
└─────────────────────────────────────────────────────────────┘
PyTorch 自定义算子导出 #
注册符号函数 #
python
import torch
import torch.onnx
from torch.onnx import register_custom_op_symbolic


class CustomAdd(torch.autograd.Function):
    """Autograd function computing x + y + 1, exported as CustomDomain::CustomAdd."""

    @staticmethod
    def forward(ctx, x, y):
        # The constant +1 matches the alpha_f=1.0 attribute recorded on the node.
        return x + y + 1

    @staticmethod
    def symbolic(g, x, y):
        # Emitting in a non-empty domain marks the node as a custom operator;
        # the "_f" suffix declares a float attribute named "alpha".
        return g.op("CustomDomain::CustomAdd", x, y, alpha_f=1.0)


class MyModel(torch.nn.Module):
    """Wraps CustomAdd so it can be traced by torch.onnx.export."""

    def forward(self, x, y):
        return CustomAdd.apply(x, y)


def custom_add_symbolic(g, x, y):
    """Standalone symbolic equivalent of CustomAdd.symbolic.

    Kept for reference: register_custom_op_symbolic would attach it to an
    EXISTING torch operator, and the name must be qualified as
    "<namespace>::<op>" (e.g. "mynamespace::custom_add").
    """
    return g.op("CustomDomain::CustomAdd", x, y, alpha_f=1.0)


# FIX: the original called
#   register_custom_op_symbolic("CustomAdd", custom_add_symbolic, 17)
# which raises ValueError at runtime — the symbolic name must be in the
# "<namespace>::<op>" form and refer to a registered torch operator.
# An autograd.Function with a `symbolic` staticmethod (above) needs no
# separate registration, so the broken call is removed.

model = MyModel()
x = torch.randn(1, 3)
y = torch.randn(1, 3)

torch.onnx.export(
    model,
    (x, y),
    "custom_add.onnx",
    input_names=["x", "y"],
    output_names=["output"],
    opset_version=17,  # pin the default-domain opset explicitly
    custom_opsets={"CustomDomain": 1},  # version of the custom domain
)
使用内部注册装饰器 #
python
import torch
# NOTE(review): torch.onnx._internal is a PRIVATE module; this decorator's
# availability and signature vary across torch versions — verify before use.
from torch.onnx._internal import registration


# Registers a symbolic function under the name "CustomDomain::MyRelu".
# NOTE(review): this does not hook torch.relu itself — the model below will
# still export through the builtin aten::relu symbolic; confirm the intended
# wiring against the installed torch version.
@registration.onnx_symbolic("CustomDomain::MyRelu")
def my_relu_symbolic(g, x):
    # Maps to the standard ONNX Relu operator in the default domain.
    return g.op("Relu", x)


class MyReluModel(torch.nn.Module):
    # Simple module whose forward is a plain torch.relu call.
    def forward(self, x):
        return torch.relu(x)


model = MyReluModel()
# Traces a 1x3 float input and writes my_relu.onnx to the working directory.
torch.onnx.export(model, torch.randn(1, 3), "my_relu.onnx")
带属性的自定义算子 #
python
import torch
from torch.onnx import register_custom_op_symbolic


class CustomScale(torch.autograd.Function):
    """Autograd function computing x * scale, exported as CustomDomain::Scale."""

    @staticmethod
    def forward(ctx, x, scale):
        # Saved only for a (not implemented) backward pass.
        ctx.save_for_backward(torch.tensor(scale))
        return x * scale

    @staticmethod
    def symbolic(g, x, scale):
        # `scale` is a Python float here, so it is recorded as a node
        # attribute (scale_f) rather than as a graph input.
        return g.op("CustomDomain::Scale", x, scale_f=scale)


def custom_scale_symbolic(g, x, scale):
    """Standalone symbolic equivalent of CustomScale.symbolic (reference only)."""
    return g.op("CustomDomain::Scale", x, scale_f=scale)


# FIX: the original called
#   register_custom_op_symbolic("CustomScale", custom_scale_symbolic, 17)
# which raises ValueError — the name must be "<namespace>::<op>" and refer
# to a registered torch operator. CustomScale.symbolic already handles the
# export, so no separate registration is needed.


class ScaleModel(torch.nn.Module):
    """Module wrapper applying CustomScale with a fixed scale factor."""

    def __init__(self, scale=2.0):
        super().__init__()
        self.scale = scale

    def forward(self, x):
        return CustomScale.apply(x, self.scale)


model = ScaleModel(scale=3.0)
torch.onnx.export(
    model,
    torch.randn(1, 3),
    "custom_scale.onnx",
    custom_opsets={"CustomDomain": 1},
)
ONNX Runtime 自定义算子 #
Python 实现 #
python
import onnxruntime as ort
import numpy as np


class CustomOpModel:
    # Placeholder container; holds no state.
    def __init__(self):
        pass


def custom_add_impl(x, y, alpha):
    # Reference NumPy implementation of CustomAdd: elementwise x + y + alpha.
    return x + y + alpha


def get_custom_ops():
    # NOTE(review): `ort.custom_op_types` is NOT part of the public
    # onnxruntime Python API — this mapping is illustrative pseudo-code.
    # Python-level custom ops are normally provided via the
    # onnxruntime-extensions package; verify against your ORT version.
    return {
        "CustomDomain.CustomAdd": (
            custom_add_impl,
            [
                ort.custom_op_types.float32,
                ort.custom_op_types.float32,
            ],
            [ort.custom_op_types.float32],
        )
    }


so = ort.SessionOptions()
# Registers a COMPILED (C/C++) custom-op shared library with the session;
# the Python functions above are not what gets registered here.
so.register_custom_ops_library("custom_ops_library.so")
session = ort.InferenceSession("custom_add.onnx", so)
C++ 实现 #
cpp
// Minimal ONNX Runtime custom-op kernel for "CustomAdd":
// output[i] = x[i] + y[i] + alpha (alpha fixed at 1.0f here).
#include "onnxruntime_c_api.h"
#include "onnxruntime_cxx_api.h"
#include <cmath>

struct CustomAddKernel {
  // NOTE(review): OrtApi is taken by value; ORT samples conventionally
  // store `const OrtApi&` — confirm against the targeted ORT version.
  CustomAddKernel(OrtApi api) : api_(api) {}

  // Invoked by ORT per inference; reads two float inputs and writes one
  // float output shaped like the first input.
  void Compute(OrtKernelContext* context) {
    Ort::KernelContext ctx(context);
    auto input_x = ctx.GetInput(0);
    auto input_y = ctx.GetInput(1);
    auto x_data = input_x.GetTensorData<float>();
    auto y_data = input_y.GetTensorData<float>();
    // Output allocated with the first input's dimensions; assumes x and y
    // have equal shapes (no broadcasting is performed).
    auto dimensions = input_x.GetTensorTypeAndShapeInfo().GetShape();
    auto output = ctx.GetOutput(0, dimensions);
    auto output_data = output.GetTensorMutableData<float>();
    // Total element count = product of all dimensions.
    size_t size = 1;
    for (auto dim : dimensions) {
      size *= dim;
    }
    float alpha = 1.0f;  // matches the alpha_f=1.0 attribute used at export
    for (size_t i = 0; i < size; ++i) {
      output_data[i] = x_data[i] + y_data[i] + alpha;
    }
  }

 private:
  OrtApi api_;
};

// Op descriptor: name, execution provider, and input/output type signature.
// NOTE(review): a registrable custom op normally derives from
// Ort::CustomOpBase<CustomAddOp, CustomAddKernel>; verify before shipping.
struct CustomAddOp {
  void* CreateKernel(OrtApi api, const OrtKernelInfo* info) {
    // ORT takes ownership of the returned kernel pointer.
    return new CustomAddKernel(api);
  }
  const char* GetName() const { return "CustomAdd"; }
  const char* GetExecutionProviderType() const { return "CPUExecutionProvider"; }
  size_t GetInputTypeCount() const { return 2; }
  ONNXTensorElementDataType GetInputType(size_t index) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }
  size_t GetOutputTypeCount() const { return 1; }
  ONNXTensorElementDataType GetOutputType(size_t index) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }
};
注册自定义算子库 #
python
import onnxruntime as ort

# Session options must be configured BEFORE the session is created; the
# path points at a compiled custom-op shared library (.so / .dll / .dylib).
sess_options = ort.SessionOptions()
sess_options.register_custom_ops_library("/path/to/custom_ops.so")
session = ort.InferenceSession("model.onnx", sess_options)
完整示例 #
定义自定义算子 #
python
import torch
import torch.onnx
import onnx
from onnx import helper, TensorProto, numpy_helper
import numpy as np


class Swish(torch.autograd.Function):
    """Swish activation x * sigmoid(beta * x) with a custom ONNX export."""

    @staticmethod
    def forward(ctx, x, beta=1.0):
        ctx.beta = beta
        return x * torch.sigmoid(beta * x)

    @staticmethod
    def symbolic(g, x, beta=1.0):
        # beta is a Python float, so it becomes a float attribute (beta_f).
        return g.op("CustomDomain::Swish", x, beta_f=beta)


def swish_symbolic(g, x, beta=1.0):
    """Standalone symbolic equivalent of Swish.symbolic (reference only)."""
    return g.op("CustomDomain::Swish", x, beta_f=beta)


# FIX: the original called
#   torch.onnx.register_custom_op_symbolic("Swish", swish_symbolic, 17)
# which raises ValueError — the symbolic name must be "<namespace>::<op>"
# and refer to a registered torch operator. Swish.symbolic already covers
# the export, so the broken registration is removed.


class SwishModel(torch.nn.Module):
    """Module applying the custom Swish activation with a fixed beta."""

    def __init__(self, beta=1.0):
        super().__init__()
        self.beta = beta

    def forward(self, x):
        return Swish.apply(x, self.beta)


model = SwishModel(beta=1.0)
model.eval()
dummy_input = torch.randn(1, 3, 224, 224)

torch.onnx.export(
    model,
    dummy_input,
    "swish_model.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    opset_version=17,
    custom_opsets={"CustomDomain": 1},
)

onnx_model = onnx.load("swish_model.onnx")
# NOTE(review): the basic checker validates graph structure; it cannot
# validate the custom domain's Swish node (no schema is registered for it).
onnx.checker.check_model(onnx_model)
print("模型导出成功")
ONNX Runtime 实现 #
python
import onnxruntime as ort
import numpy as np
from onnxruntime.capi import _pybind_state as C
def swish_impl(x, beta=1.0):
    """Reference NumPy Swish: x * sigmoid(beta * x)."""
    sigmoid = 1.0 / (1.0 + np.exp(-beta * x))
    return x * sigmoid
class SwishKernel:
    """Minimal kernel object exposing Swish as a compute() method."""

    def __init__(self):
        # Stateless: the kernel is a pure function of its inputs.
        pass

    def compute(self, x, beta=1.0):
        # Same math as the module-level swish_impl: x * sigmoid(beta * x),
        # inlined so the kernel is self-contained.
        return x * (1.0 / (1.0 + np.exp(-beta * x)))
def create_custom_op_library():
    """Placeholder: building a real custom-op library requires native code."""
    return None
# Build session options; to actually execute the custom Swish node a
# compiled custom-op library would have to be registered here first.
sess_options = ort.SessionOptions()
session = ort.InferenceSession("swish_model.onnx", sess_options)

# Random input matching the export-time shape (NCHW, float32).
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
算子定义规范 #
算子签名 #
python
from onnx import defs, helper

# Declarative schema for the custom Swish operator.
# NOTE(review): constructing defs.OpSchema from Python requires a recent
# onnx release, and the attribute keyword (here `default`) may actually be
# `default_value` expecting an AttributeProto depending on the version —
# verify against the installed onnx package before relying on this.
op_schema = defs.OpSchema(
    name="Swish",
    domain="custom_domain",
    since_version=1,
    doc="Swish activation function: x * sigmoid(beta * x)",
    inputs=[
        # FormalParameter(name, type_str, description)
        defs.OpSchema.FormalParameter("X", "T", "Input tensor")
    ],
    outputs=[
        defs.OpSchema.FormalParameter("Y", "T", "Output tensor")
    ],
    attributes=[
        defs.OpSchema.Attribute("beta", defs.OpSchema.AttrType.FLOAT, "Beta parameter", default=1.0)
    ],
    type_constraints=[
        # "T" may bind to float or double tensors.
        defs.OpSchema.TypeConstraint("T", ["float", "double"], "Float types")
    ]
)
创建带自定义算子的模型 #
python
import onnx
from onnx import helper, TensorProto, numpy_helper
import numpy as np

# Graph boundary: one float tensor of fixed shape on each side.
tensor_shape = [1, 3, 224, 224]
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, tensor_shape)
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, tensor_shape)

# Single node in the custom domain; `beta` is stored as a node attribute.
swish_node = helper.make_node(
    "Swish", inputs=["X"], outputs=["Y"], domain="custom_domain", beta=1.0
)

graph = helper.make_graph([swish_node], "swish_graph", [X], [Y])

# Declare both the default ONNX opset and the custom domain's version.
opsets = [
    helper.make_opsetid("", 17),
    helper.make_opsetid("custom_domain", 1),
]
model = helper.make_model(
    graph, opset_imports=opsets, producer_name="custom_op_example"
)

onnx.save(model, "swish_custom.onnx")
调试技巧 #
检查自定义算子 #
python
import onnx

model = onnx.load("custom_model.onnx")

# Nodes with a non-empty domain field are custom operators; standard ONNX
# ops live in the empty ("") default domain.
for node in model.graph.node:
    if not node.domain:
        continue
    print(f"自定义算子: {node.domain}::{node.op_type}")
    print(f"  输入: {list(node.input)}")
    print(f"  输出: {list(node.output)}")
    print(f"  属性: {[attr.name for attr in node.attribute]}")
验证导出 #
python
import torch
import onnx
import onnxruntime as ort
import numpy as np
def validate_custom_op_export(model, input_data, onnx_path):
    """Export `model` to ONNX and validate the result.

    Args:
        model: torch.nn.Module to export.
        input_data: example input tensor used for tracing and comparison.
        onnx_path: destination path for the exported model.

    Returns:
        True when the export passes the structural check; for models
        without custom nodes, additionally True only if ONNX Runtime's
        output matches PyTorch's within tolerance.
    """
    model.eval()
    with torch.no_grad():
        torch_output = model(input_data)
    torch.onnx.export(
        model,
        input_data,
        onnx_path,
        custom_opsets={"CustomDomain": 1},
    )
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)
    custom_nodes = [n for n in onnx_model.graph.node if n.domain]
    print("自定义算子节点:")
    for node in custom_nodes:
        print(f"  {node.domain}::{node.op_type}")
    # FIX: the original computed torch_output and then never used it, so the
    # function "validated" nothing numerically. Graphs without custom nodes
    # are runnable by stock ONNX Runtime, so compare outputs there; graphs
    # with custom nodes would need a registered kernel, so only the
    # structural checks above apply.
    if not custom_nodes:
        sess = ort.InferenceSession(onnx_path)
        input_name = sess.get_inputs()[0].name
        ort_output = sess.run(None, {input_name: input_data.numpy()})[0]
        return bool(np.allclose(torch_output.numpy(), ort_output, atol=1e-5))
    return True
# NOTE(review): SwishModel is defined in the earlier export example of this
# document; this snippet assumes it is already in scope.
model = SwishModel()
input_data = torch.randn(1, 3, 224, 224)
validate_custom_op_export(model, input_data, "swish.onnx")
最佳实践 #
命名规范 #
text
┌─────────────────────────────────────────────────────────────┐
│ 自定义算子命名规范 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 域名: │
│ ├── 使用有意义的名称,如公司名或项目名 │
│ ├── 避免使用空字符串(保留给标准算子) │
│ └── 示例: mycompany.custom_ops │
│ │
│ 算子名: │
│ ├── 使用 PascalCase │
│ ├── 描述算子功能 │
│ └── 示例: Swish, FusedConvBN │
│ │
│ 属性名: │
│ ├── 使用 snake_case │
│ ├── 添加类型后缀(如 _f 表示 float) │
│ └── 示例: kernel_size_i, alpha_f, axes_is │
│ │
└─────────────────────────────────────────────────────────────┘
版本管理 #
python
import onnx
from onnx import helper

# NOTE(review): `graph` is not defined in this snippet; it is assumed to be
# a GraphProto built earlier (e.g. via helper.make_graph).
model = helper.make_model(
    graph,
    opset_imports=[
        helper.make_opsetid("", 17),             # default ONNX domain, opset 17
        helper.make_opsetid("custom_domain", 1)  # custom domain, version 1
    ]
)

# Each opset_import entry records a (domain, version) pair; the empty
# domain string denotes the standard ONNX operator set.
for opset in model.opset_import:
    print(f"域: {opset.domain or 'default'}, 版本: {opset.version}")
下一步 #
现在你已经了解了自定义算子,接下来学习 模型量化,学习如何对模型进行量化以提升推理性能!
最后更新:2026-04-04