How to Implement Custom Operators in ONNX Runtime Using the Custom Operator API
To implement custom operators in ONNX Runtime, define an OrtCustomOp struct describing the operator's metadata and callbacks, implement a kernel class containing the Compute logic, and register the operator domain using CreateCustomRegistry before attaching it to SessionOptions.
The Custom Operator API in microsoft/onnxruntime enables extending the inference engine with user-defined operators compiled as shared libraries and loaded at runtime. This workflow centers on three core components defined in onnxruntime/core/session/custom_ops.h and custom_ops.cc: the operator descriptor, the kernel implementation, and the domain registration mechanism.
Define the Operator Descriptor (OrtCustomOp)
The OrtCustomOp struct acts as the bridge between ONNX Runtime and your implementation. It is declared in the public C API header, include/onnxruntime/core/session/onnxruntime_c_api.h, and holds function pointers for kernel lifecycle management, input and output types, and naming.
Create a static instance that populates the required callbacks. The CreateKernel callback receives the OrtApi and OrtKernelInfo, allowing you to capture attributes or verify domain metadata as demonstrated in onnxruntime/test/shared_lib/custom_op_utils.cc.
// MyAddKernel (defined in the next section) must be visible before this point.
// Initializers are positional, so they follow the member order declared in
// onnxruntime_c_api.h; GetExecutionProviderType comes right after GetName.
static const OrtCustomOp c_CustomAdd = {
    /* version */ 1,
    /* CreateKernel */
    [](const OrtCustomOp* /*self*/, const OrtApi* api,
       const OrtKernelInfo* info) -> void* {
      return new MyAddKernel(*api, info);
    },
    /* GetName */
    [](const OrtCustomOp* /*self*/) -> const char* { return "MyAdd"; },
    /* GetExecutionProviderType (nullptr selects the CPU provider) */
    [](const OrtCustomOp* /*self*/) -> const char* { return nullptr; },
    /* GetInputType */
    [](const OrtCustomOp* /*self*/, size_t /*idx*/) -> ONNXTensorElementDataType {
      return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
    },
    /* GetInputTypeCount */
    [](const OrtCustomOp* /*self*/) -> size_t { return 2; },
    /* GetOutputType */
    [](const OrtCustomOp* /*self*/, size_t /*idx*/) -> ONNXTensorElementDataType {
      return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
    },
    /* GetOutputTypeCount */
    [](const OrtCustomOp* /*self*/) -> size_t { return 1; },
    /* KernelCompute */
    [](void* op_kernel, OrtKernelContext* ctx) {
      static_cast<MyAddKernel*>(op_kernel)->Compute(ctx);
    },
    /* KernelDestroy */
    [](void* op_kernel) { delete static_cast<MyAddKernel*>(op_kernel); },
    // Later optional members, such as InferOutputShapeFn (API v17+), are
    // left zero-initialized.
};
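The descriptor above is an instance of a general pattern: a C-style table of function pointers plus an opaque kernel pointer, which lets a runtime drive a C++ object without knowing its type. The following is a minimal, ONNX Runtime-free sketch of that pattern. MiniCustomOp, MiniAddKernel, kMiniAdd, and RunMiniAddOnce are hypothetical names invented for illustration, and the callback signatures are simplified compared with the real OrtCustomOp.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical, simplified stand-in for OrtCustomOp: a table of callbacks
// that a runtime can invoke without knowing the kernel's concrete type.
struct MiniCustomOp {
  void* (*CreateKernel)(const MiniCustomOp* self);
  const char* (*GetName)(const MiniCustomOp* self);
  void (*KernelCompute)(void* kernel, const float* x, const float* y,
                        float* z, std::size_t n);
  void (*KernelDestroy)(void* kernel);
};

// The kernel object that the opaque void* refers to.
struct MiniAddKernel {
  void Compute(const float* x, const float* y, float* z, std::size_t n) const {
    for (std::size_t i = 0; i < n; ++i) z[i] = x[i] + y[i];
  }
};

// Captureless lambdas convert implicitly to plain function pointers,
// which is why they can populate a C struct.
static const MiniCustomOp kMiniAdd = {
    [](const MiniCustomOp*) -> void* { return new MiniAddKernel(); },
    [](const MiniCustomOp*) -> const char* { return "MyAdd"; },
    [](void* k, const float* x, const float* y, float* z, std::size_t n) {
      static_cast<MiniAddKernel*>(k)->Compute(x, y, z, n);
    },
    [](void* k) { delete static_cast<MiniAddKernel*>(k); },
};

// The lifecycle a runtime would follow: create, compute, destroy.
inline float RunMiniAddOnce(float a, float b) {
  void* kernel = kMiniAdd.CreateKernel(&kMiniAdd);
  assert(kernel != nullptr);
  float out = 0.0f;
  kMiniAdd.KernelCompute(kernel, &a, &b, &out, 1);
  kMiniAdd.KernelDestroy(kernel);
  return out;
}
```

Calling RunMiniAddOnce(1.5f, 2.0f) walks through the same create, compute, destroy sequence that the runtime applies to a real custom operator.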
Implement the Kernel Class
The kernel encapsulates the actual compute logic. It typically stores a reference to the OrtApi and implements a Compute method accepting an OrtKernelContext*. The implementation in onnxruntime/test/shared_lib/custom_op_utils.cc provides concrete examples of attribute handling and type checking.
#include "onnxruntime_cxx_api.h"

class MyAddKernel {
 public:
  MyAddKernel(const OrtApi& api, const OrtKernelInfo* info) : api_(api) {
    // Optional: verify operator metadata using Ort::ConstKernelInfo
    Ort::ConstKernelInfo kinfo(info);
    // kinfo.GetOperatorDomain(), GetOperatorType(), GetOperatorSinceVersion()
  }

  void Compute(OrtKernelContext* ctx) const {
    Ort::KernelContext context(ctx);
    const float* X = context.GetInput(0).GetTensorData<float>();
    const float* Y = context.GetInput(1).GetTensorData<float>();
    // This example assumes both inputs have the same shape (no broadcasting).
    auto info = context.GetInput(0).GetTensorTypeAndShapeInfo();
    auto shape = info.GetShape();
    // GetOutput needs the output dimensions, so take the element count
    // from the input instead of calling GetOutput a second time.
    size_t count = info.GetElementCount();
    float* Z = context.GetOutput(0, shape).GetTensorMutableData<float>();
    for (size_t i = 0; i < count; ++i) Z[i] = X[i] + Y[i];
  }

 private:
  const OrtApi& api_;
};
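The Compute method above trusts that both inputs have the same shape. As a hedged illustration of the checks a more defensive kernel might perform, here is a standalone sketch that works on plain vectors rather than ONNX Runtime tensors. ElementCount and AddSameShape are hypothetical helpers, not part of any ONNX Runtime API.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Number of elements in a tensor with the given dimensions (1 for a scalar).
inline std::size_t ElementCount(const std::vector<int64_t>& shape) {
  std::size_t count = 1;
  for (int64_t dim : shape) {
    assert(dim >= 0 && "dimensions must be concrete at compute time");
    count *= static_cast<std::size_t>(dim);
  }
  return count;
}

// Hypothetical helper: element-wise x + y for inputs of identical shape.
// Throws instead of reading out of bounds when the inputs disagree.
inline std::vector<float> AddSameShape(const std::vector<float>& x,
                                       const std::vector<int64_t>& x_shape,
                                       const std::vector<float>& y,
                                       const std::vector<int64_t>& y_shape) {
  if (x_shape != y_shape) {
    throw std::invalid_argument("MyAdd: input shapes differ");
  }
  const std::size_t count = ElementCount(x_shape);
  if (x.size() != count || y.size() != count) {
    throw std::invalid_argument("MyAdd: buffer size does not match shape");
  }
  std::vector<float> z(count);
  for (std::size_t i = 0; i < count; ++i) z[i] = x[i] + y[i];
  return z;
}
```

In a real kernel the same comparison could be made between the shapes returned by GetTensorTypeAndShapeInfo().GetShape() for each input before writing to the output.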
Accessing Advanced Context Features
The Ort::KernelContext wrapper (implemented via ORT_API_STATUS_IMPL macros in onnxruntime/core/session/custom_ops.cc) exposes several utilities:
- GetInput(n) / GetOutput(n, shape): Access input tensors and allocate output storage with specified dimensions.
- KernelContext_GetGPUComputeStream: Retrieve the CUDA compute stream for GPU execution providers (lines 58-71 in custom_ops.cc).
- KernelContext_GetAllocator: Obtain an allocator for temporary buffers inside the kernel (lines 48-55 in custom_ops.cc).
Register the Custom Operator Domain
Registration follows a three-step sequence implemented in onnxruntime/core/session/custom_ops.cc:
- Create an OrtCustomOpDomain using OrtApis::CreateCustomOpDomain.
- Add operators via OrtApis::CustomOpDomain_Add.
- Instantiate a CustomRegistry using CreateCustomRegistry (declared in custom_ops.h, lines 14-16) and attach it to SessionOptions.
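// 1. Create the domain and add the operator
OrtCustomOpDomain* domain = nullptr;
Ort::ThrowOnError(OrtApis::CreateCustomOpDomain("my_custom_ops", &domain));
Ort::ThrowOnError(OrtApis::CustomOpDomain_Add(domain, &c_CustomAdd));
// 2. Create the registry. CreateCustomRegistry is an internal helper that
//    returns onnxruntime::common::Status rather than OrtStatus*.
std::shared_ptr<onnxruntime::CustomRegistry> custom_registry;
std::vector<OrtCustomOpDomain*> domains = {domain};  // elements must be non-const
onnxruntime::common::Status status =
    onnxruntime::CreateCustomRegistry(domains, custom_registry);
if (!status.IsOK()) throw std::runtime_error(status.ErrorMessage());
// 3. Attach to session options. The public C++ wrapper accepts the domain
//    (Ort::SessionOptions::Add(OrtCustomOpDomain*)), not the registry; the
//    runtime derives its registry from the attached domains.
Ort::SessionOptions session_options;
session_options.Add(domain);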
Optional: Implement Shape Inference
For operators with dynamic output shapes, implement OrtCustomOp::InferOutputShapeFn, available from API version 17. The callback receives an OrtShapeInferContext (wrapper defined in custom_ops.cc, lines 63-127), allowing you to read input shapes and call SetOutputTypeShape. The CreateSchema helper function in custom_ops.cc (lines 1125-1155) demonstrates how ONNX Runtime constructs the schema from your callback.
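To make the idea of a shape inference callback concrete, the sketch below shows the kind of logic such a callback might compute for a broadcasting binary operator. It deliberately does not use OrtShapeInferContext: shapes are plain vectors, an unknown (symbolic) dimension is represented by the assumed sentinel kUnknownDim, and InferBroadcastShape is a hypothetical helper. The broadcasting rules follow the usual NumPy-style convention, which is an assumption rather than something this article specifies for MyAdd.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <utility>
#include <vector>

// Assumed sentinel for a dimension whose size is not known at load time.
constexpr int64_t kUnknownDim = -1;

// Hypothetical helper: output shape of a broadcasting binary operator.
inline std::vector<int64_t> InferBroadcastShape(std::vector<int64_t> a,
                                                std::vector<int64_t> b) {
  if (a.size() < b.size()) std::swap(a, b);
  // Right-align the shorter shape by left-padding it with 1s.
  b.insert(b.begin(), a.size() - b.size(), int64_t{1});
  assert(a.size() == b.size());

  std::vector<int64_t> out(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    const int64_t da = a[i];
    const int64_t db = b[i];
    if (da == db) {
      out[i] = da;  // equal sizes, or both unknown
    } else if (da == 1) {
      out[i] = db;  // a is broadcast along this axis
    } else if (db == 1) {
      out[i] = da;  // b is broadcast along this axis
    } else if (da == kUnknownDim) {
      out[i] = db;  // db > 1 here, so a valid result must have size db
    } else if (db == kUnknownDim) {
      out[i] = da;
    } else {
      throw std::invalid_argument("shapes are not broadcast-compatible");
    }
  }
  return out;
}
```

For example, inputs of shape {2, 1, 4} and {3, 1} yield {2, 3, 4}, while {kUnknownDim, 4} and {4} yield {kUnknownDim, 4}, leaving the first dimension to be resolved at run time.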
Build and Load the Shared Library
Export a RegisterCustomOps entry point so ONNX Runtime can locate and initialize your library at runtime. This is the symbol that register_custom_ops_library looks up; inside it, use only the public C API obtained from OrtApiBase.
// my_custom_op.cc
#include "onnxruntime_c_api.h"

// c_CustomAdd is the OrtCustomOp descriptor defined earlier.
extern "C" OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options,
                                                     const OrtApiBase* api_base) {
  const OrtApi* api = api_base->GetApi(ORT_API_VERSION);

  OrtCustomOpDomain* domain = nullptr;
  OrtStatus* status = api->CreateCustomOpDomain("my_custom_ops", &domain);
  if (status) return status;

  status = api->CustomOpDomain_Add(domain, &c_CustomAdd);
  if (status) return status;

  // Hand the domain to the session options. The domain must outlive every
  // session created from these options. Returns nullptr on success.
  return api->AddCustomOpDomain(options, domain);
}
Compile the source into a shared library:
g++ -fPIC -shared -std=c++17 my_custom_op.cc -I${ORT_HOME}/include \
-L${ORT_HOME}/lib -lonnxruntime -o libmycustomop.so
Load the library before creating the session, for example from Python:
import onnxruntime as ort
opts = ort.SessionOptions()
opts.register_custom_ops_library("./libmycustomop.so")
sess = ort.InferenceSession("model_with_custom_op.onnx", sess_options=opts)
Summary
- Define metadata using the OrtCustomOp struct with callbacks for CreateKernel, KernelCompute, GetName, and type constraints.
- Implement compute logic in a kernel class that uses Ort::KernelContext (wrapping OrtKernelContext) to access tensors via GetInput and GetOutput.
- Register the domain by creating an OrtCustomOpDomain, adding it to a CustomRegistry via CreateCustomRegistry, and attaching the registry to SessionOptions.
- Deploy dynamically by exporting RegisterCustomOps from a shared library and loading it via register_custom_ops_library.
Frequently Asked Questions
How do I access CUDA streams within a custom operator kernel?
According to the implementation in onnxruntime/core/session/custom_ops.cc (lines 58-71), you can query the current GPU compute stream via KernelContext_GetGPUComputeStream. This allows custom CUDA kernels to synchronize with the CUDA Execution Provider's stream, ensuring proper memory ordering and execution overlap.
Can I register multiple custom operators in a single shared library?
Yes. Create multiple OrtCustomOp descriptors and add them to the same OrtCustomOpDomain using repeated calls to OrtApis::CustomOpDomain_Add before constructing the CustomRegistry. You can also use multiple domains by passing a vector of domains to CreateCustomRegistry, which is useful for organizing operators by functionality or versioning.
What ONNX Runtime API version is required for shape inference callbacks?
The InferOutputShapeFn callback in OrtCustomOp requires API version 17 or later. This allows ONNX Runtime to perform shape inference during model loading rather than at runtime. For earlier versions, you must omit this callback, and shapes will be inferred dynamically during the first inference run.
How does ONNX Runtime dispatch model nodes to my custom kernel?
When loading a model, ONNX Runtime matches node domains and op types against registered custom domains. If a match is found in the CustomRegistry, the runtime invokes your CreateKernel callback to instantiate the kernel object, then calls KernelCompute during the inference phase. This dispatch is transparent to client code; models reference custom ops by domain and name exactly like built-in operators.
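The dispatch described above amounts to a lookup keyed by domain and op type that yields a kernel factory. The sketch below models that idea in plain C++ without ONNX Runtime. MiniRegistry, Kernel, AddKernel, MulKernel, and MakeDemoRegistry are hypothetical names, and the real CustomRegistry also tracks details such as versions and type constraints that are omitted here.

```cpp
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>

// Minimal kernel interface for the sketch.
struct Kernel {
  virtual ~Kernel() = default;
  virtual float Compute(float a, float b) const = 0;
};
struct AddKernel : Kernel {
  float Compute(float a, float b) const override { return a + b; }
};
struct MulKernel : Kernel {
  float Compute(float a, float b) const override { return a * b; }
};

using KernelFactory = std::function<std::unique_ptr<Kernel>()>;

// Hypothetical registry: maps (domain, op type) to a kernel factory.
class MiniRegistry {
 public:
  void Register(const std::string& domain, const std::string& op_type,
                KernelFactory factory) {
    factories_[std::make_pair(domain, op_type)] = std::move(factory);
  }

  // Returns nullptr when no operator matches the node's domain and type.
  std::unique_ptr<Kernel> CreateKernelFor(const std::string& domain,
                                          const std::string& op_type) const {
    auto it = factories_.find(std::make_pair(domain, op_type));
    if (it == factories_.end()) return nullptr;
    return it->second();
  }

 private:
  std::map<std::pair<std::string, std::string>, KernelFactory> factories_;
};

// Two operators registered under one domain.
inline MiniRegistry MakeDemoRegistry() {
  MiniRegistry registry;
  registry.Register("my_custom_ops", "MyAdd",
                    [] { return std::unique_ptr<Kernel>(new AddKernel()); });
  registry.Register("my_custom_ops", "MyMul",
                    [] { return std::unique_ptr<Kernel>(new MulKernel()); });
  return registry;
}
```

A node with domain "my_custom_ops" and type "MyAdd" resolves to an AddKernel, while the same op type under a different domain finds nothing, which is why the domain name in the model must match the one used at registration.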