com.github.asus4.onnxruntime 0.1.14 → 0.2.0
This diff shows the changes between the two package versions as they were published to their public registry, and is provided for informational purposes only.
- package/Plugins/Android/onnxruntime-android.aar +0 -0
- package/Plugins/Linux/x64/libonnxruntime.so +0 -0
- package/Plugins/Windows/x64/onnxruntime.dll +0 -0
- package/Plugins/iOS~/onnxruntime.xcframework/Info.plist +6 -6
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Headers/coreml_provider_factory.h +4 -1
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Headers/onnxruntime_c_api.h +134 -19
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Headers/onnxruntime_cxx_api.h +18 -3
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Headers/onnxruntime_cxx_inline.h +68 -15
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Headers/onnxruntime_lite_custom_op.h +1119 -0
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Headers/onnxruntime_run_options_config_keys.h +19 -0
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Headers/onnxruntime_session_options_config_keys.h +32 -9
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Info.plist +2 -2
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/onnxruntime +0 -0
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Headers/coreml_provider_factory.h +4 -1
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Headers/onnxruntime_c_api.h +134 -19
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Headers/onnxruntime_cxx_api.h +18 -3
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Headers/onnxruntime_cxx_inline.h +68 -15
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Headers/onnxruntime_lite_custom_op.h +1119 -0
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Headers/onnxruntime_run_options_config_keys.h +19 -0
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Headers/onnxruntime_session_options_config_keys.h +32 -9
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Info.plist +2 -2
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/onnxruntime +0 -0
- package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/{Headers → Versions/A/Headers}/coreml_provider_factory.h +4 -1
- package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/{Headers → Versions/A/Headers}/onnxruntime_c_api.h +134 -19
- package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/{Headers → Versions/A/Headers}/onnxruntime_cxx_api.h +18 -3
- package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/{Headers → Versions/A/Headers}/onnxruntime_cxx_inline.h +68 -15
- package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/Versions/A/Headers/onnxruntime_lite_custom_op.h +1119 -0
- package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/{Headers → Versions/A/Headers}/onnxruntime_run_options_config_keys.h +19 -0
- package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/{Headers → Versions/A/Headers}/onnxruntime_session_options_config_keys.h +32 -9
- package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/{Info.plist → Versions/A/Resources/Info.plist} +2 -2
- package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/{onnxruntime → Versions/A/onnxruntime} +0 -0
- package/Plugins/macOS/libonnxruntime.dylib +0 -0
- package/README.md +8 -8
- package/Runtime/AssemblyInfo.shared.cs +1 -11
- package/Runtime/NativeMethods.shared.cs +37 -2
- package/Runtime/OrtValue.shared.cs +38 -38
- package/Runtime/SessionOptions.shared.cs +14 -0
- package/Runtime/Training/NativeTrainingMethods.shared.cs +20 -2
- package/Runtime/Training/TrainingSession.shared.cs +107 -0
- package/package.json +1 -1
- /package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/{Headers → Versions/A/Headers}/cpu_provider_factory.h +0 -0
- /package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/{Headers → Versions/A/Headers}/onnxruntime_float16.h +0 -0
@@ -0,0 +1,1119 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+// Summary
+// The header has APIs to save custom op authors the trouble of defining schemas,
+// which will be inferred by functions' signature, as long as their argument list has types supported here.
+// Input could be:
+// 1. Tensor of onnx data types.
+// 2. Span of onnx data types.
+// 3. Scalar of onnx data types.
+// An input could be optional if indicated as std::optional<...>.
+// For an output, it must be a tensor of onnx data types.
+// Further, the header also has utility for a simple custom struct, where resources could be kept, to be registered as a custom op.
+// For concrete examples, please search keyword "LiteCustomOpTest" under "<cloned_src_dir>/onnxruntime/test/".
+// Note - all APIs in this header are ABI.
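To make that summary concrete, here is a minimal sketch (not part of the package) of a kernel this header can register from a bare function. The op name and the doubling logic are illustrative; the schema (one float tensor in, one float tensor out) is inferred from the signature alone:

```cpp
#include "onnxruntime_lite_custom_op.h"

// Hypothetical custom op: multiply every element by two.
void MulTwo(const Ort::Custom::Tensor<float>& in,
            Ort::Custom::Tensor<float>& out) {
  const float* src = in.Data();
  float* dst = out.Allocate(in.Shape());  // output takes the input's shape
  for (int64_t i = 0; i < in.NumberOfElement(); ++i) {
    dst[i] = src[i] * 2.0f;
  }
}
```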
+
+#pragma once
+#include "onnxruntime_cxx_api.h"
+#include <optional>
+#include <numeric>
+#include <functional>
+#include <unordered_set>
+
+namespace Ort {
+namespace Custom {
+
+class ArgBase {
+ public:
+  ArgBase(OrtKernelContext* ctx,
+          size_t indice,
+          bool is_input) : ctx_(ctx), indice_(indice), is_input_(is_input) {}
+  virtual ~ArgBase() {};
+
+ protected:
+  struct KernelContext ctx_;
+  size_t indice_;
+  bool is_input_;
+};
+
+using ArgPtr = std::unique_ptr<Custom::ArgBase>;
+using ArgPtrs = std::vector<ArgPtr>;
+
+class TensorBase : public ArgBase {
+ public:
+  TensorBase(OrtKernelContext* ctx,
+             size_t indice,
+             bool is_input) : ArgBase(ctx, indice, is_input) {}
+
+  operator bool() const {
+    return shape_.has_value();
+  }
+
+  const std::vector<int64_t>& Shape() const {
+    if (!shape_.has_value()) {
+      ORT_CXX_API_THROW("tensor shape is not yet initialized", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
+    }
+    return shape_.value();
+  }
+
+  ONNXTensorElementDataType Type() const {
+    return type_;
+  }
+
+  int64_t NumberOfElement() const {
+    if (shape_.has_value()) {
+      return std::accumulate(shape_->begin(), shape_->end(), 1LL, std::multiplies<int64_t>());
+    } else {
+      return 0;
+    }
+  }
+
+  std::string Shape2Str() const {
+    if (shape_.has_value()) {
+      std::string shape_str;
+      for (const auto& dim : *shape_) {
+        shape_str.append(std::to_string(dim));
+        shape_str.append(", ");
+      }
+      return shape_str;
+    } else {
+      return "empty";
+    }
+  }
+
+  bool IsCpuTensor() const {
+    return strcmp("Cpu", mem_type_) == 0;
+  }
+
+  virtual const void* DataRaw() const = 0;
+  virtual size_t SizeInBytes() const = 0;
+
+ protected:
+  std::optional<std::vector<int64_t>> shape_;
+  ONNXTensorElementDataType type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
+  const char* mem_type_ = "Cpu";
+};
+
+template <typename T>
+struct Span {
+  const T* data_ = {};
+  size_t size_ = {};
+  void Assign(const T* data, size_t size) {
+    data_ = data;
+    size_ = size;
+  }
+  size_t size() const { return size_; }
+  T operator[](size_t indice) const {
+    return data_[indice];
+  }
+  const T* data() const { return data_; }
+};
+
+template <typename T>
+class Tensor : public TensorBase {
+ public:
+  using TT = typename std::remove_reference<T>::type;
+  Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) {
+    if (is_input_) {
+      if (indice >= ctx_.GetInputCount()) {
+        ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT);
+      }
+      const_value_ = ctx_.GetInput(indice);
+      auto type_shape_info = const_value_.GetTensorTypeAndShapeInfo();
+      shape_ = type_shape_info.GetShape();
+    }
+  }
+  const TT* Data() const {
+    return reinterpret_cast<const TT*>(const_value_.GetTensorRawData());
+  }
+  TT* Allocate(const std::vector<int64_t>& shape) {
+    shape_ = shape;
+    if (!data_) {
+      shape_ = shape;
+      data_ = ctx_.GetOutput(indice_, shape).template GetTensorMutableData<TT>();
+    }
+    return data_;
+  }
+  static TT GetT() { return (TT)0; }
+  const Span<T>& AsSpan() {
+    if (!shape_.has_value() || shape_->size() != 1) {
+      ORT_CXX_API_THROW("invalid shape while trying to get a span out of Ort::Custom::Tensor",
+                        OrtErrorCode::ORT_RUNTIME_EXCEPTION);
+    }
+    span_.Assign(Data(), static_cast<size_t>((*shape_)[0]));
+    return span_;
+  }
+  const T& AsScalar() {
+    if (!shape_.has_value() || shape_->size() != 1 || (*shape_)[0] != 1) {
+      ORT_CXX_API_THROW("invalid shape while trying to get a scalar from Ort::Custom::Tensor",
+                        OrtErrorCode::ORT_RUNTIME_EXCEPTION);
+    }
+    return *Data();
+  }
+  const void* DataRaw() const override {
+    return reinterpret_cast<const void*>(Data());
+  }
+
+  size_t SizeInBytes() const override {
+    return sizeof(TT) * static_cast<size_t>(NumberOfElement());
+  }
+
+ private:
+  ConstValue const_value_;  // for input
+  TT* data_{};              // for output
+  Span<T> span_;
+};
+
+template <>
+class Tensor<std::string> : public TensorBase {
+ public:
+  using strings = std::vector<std::string>;
+
+  Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) {
+    if (is_input_) {
+      if (indice >= ctx_.GetInputCount()) {
+        ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT);
+      }
+      auto const_value = ctx_.GetInput(indice);
+      auto type_shape_info = const_value.GetTensorTypeAndShapeInfo();
+      shape_ = type_shape_info.GetShape();
+      auto num_chars = const_value.GetStringTensorDataLength();
+      // note - there will be copy ...
+      auto num_strings = static_cast<size_t>(NumberOfElement());
+      if (num_strings) {
+        std::vector<char> chars(num_chars + 1, '\0');
+        std::vector<size_t> offsets(num_strings);
+        const_value.GetStringTensorContent(static_cast<void*>(chars.data()), num_chars, offsets.data(), offsets.size());
+        auto upper_bound = num_strings - 1;
+        input_strings_.resize(num_strings);
+        for (size_t i = upper_bound;; --i) {
+          if (i < upper_bound) {
+            chars[offsets[i + 1]] = '\0';
+          }
+          input_strings_[i] = chars.data() + offsets[i];
+          if (0 == i) {
+            break;
+          }
+        }
+      }
+    }
+  }
+  const strings& Data() const {
+    return input_strings_;
+  }
+  const void* DataRaw() const override {
+    if (input_strings_.size() != 1) {
+      ORT_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
+    }
+    return reinterpret_cast<const void*>(input_strings_[0].c_str());
+  }
+  size_t SizeInBytes() const override {
+    if (input_strings_.size() != 1) {
+      ORT_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
+    }
+    return input_strings_[0].size();
+  }
+  void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
+    shape_ = dims;
+    std::vector<const char*> raw;
+    for (const auto& s : ss) {
+      raw.push_back(s.data());
+    }
+    auto output = ctx_.GetOutput(indice_, dims.data(), dims.size());
+    // note - there will be copy ...
+    output.FillStringTensor(raw.data(), raw.size());
+  }
+  const Span<std::string>& AsSpan() {
+    ORT_CXX_API_THROW("span for TensorT of string not implemented", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
+  }
+  const std::string& AsScalar() {
+    if (input_strings_.size() != 1) {
+      ORT_CXX_API_THROW("invalid shape while trying to get a scalar string from Ort::Custom::Tensor",
+                        OrtErrorCode::ORT_RUNTIME_EXCEPTION);
+    }
+    return input_strings_[0];
+  }
+
+ private:
+  std::vector<std::string> input_strings_;  // for input
+};
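A sketch of how this string specialization is used on the output side — `SetStringOutput` copies the strings into the context's output tensor, as the "there will be copy" notes above point out. The op itself is hypothetical:

```cpp
#include <algorithm>
#include <cctype>

// Hypothetical custom op: upper-case every string element.
void ToUpper(const Ort::Custom::Tensor<std::string>& in,
             Ort::Custom::Tensor<std::string>& out) {
  std::vector<std::string> results = in.Data();  // copies the input strings
  for (auto& s : results) {
    std::transform(s.begin(), s.end(), s.begin(),
                   [](unsigned char c) { return static_cast<char>(std::toupper(c)); });
  }
  out.SetStringOutput(results, in.Shape());  // same shape as the input
}
```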
+
+template <>
+class Tensor<std::string_view> : public TensorBase {
+ public:
+  using strings = std::vector<std::string>;
+  using string_views = std::vector<std::string_view>;
+
+  Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) {
+    if (is_input_) {
+      if (indice >= ctx_.GetInputCount()) {
+        ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT);
+      }
+      auto const_value = ctx_.GetInput(indice);
+      auto type_shape_info = const_value.GetTensorTypeAndShapeInfo();
+      shape_ = type_shape_info.GetShape();
+      auto num_chars = const_value.GetStringTensorDataLength();
+      chars_.resize(num_chars + 1, '\0');
+      auto num_strings = static_cast<size_t>(NumberOfElement());
+      if (num_strings) {
+        std::vector<size_t> offsets(num_strings);
+        const_value.GetStringTensorContent(static_cast<void*>(chars_.data()), num_chars, offsets.data(), offsets.size());
+        offsets.push_back(num_chars);
+        for (size_t i = 0; i < num_strings; ++i) {
+          input_string_views_.emplace_back(chars_.data() + offsets[i], offsets[i + 1] - offsets[i]);
+        }
+      }
+    }
+  }
+  const string_views& Data() const {
+    return input_string_views_;
+  }
+  const void* DataRaw() const override {
+    if (input_string_views_.size() != 1) {
+      ORT_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
+    }
+    return reinterpret_cast<const void*>(input_string_views_[0].data());
+  }
+  size_t SizeInBytes() const override {
+    if (input_string_views_.size() != 1) {
+      ORT_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
+    }
+    return input_string_views_[0].size();
+  }
+  void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
+    shape_ = dims;
+    std::vector<const char*> raw;
+    for (const auto& s : ss) {
+      raw.push_back(s.data());
+    }
+    auto output = ctx_.GetOutput(indice_, dims.data(), dims.size());
+    // note - there will be copy ...
+    output.FillStringTensor(raw.data(), raw.size());
+  }
+  const Span<std::string_view>& AsSpan() {
+    ORT_CXX_API_THROW("span for TensorT of string view not implemented", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
+  }
+  std::string_view AsScalar() {
+    if (input_string_views_.size() != 1) {
+      ORT_CXX_API_THROW("invalid shape while trying to get a scalar string view from Ort::Custom::Tensor",
+                        OrtErrorCode::ORT_RUNTIME_EXCEPTION);
+    }
+    return input_string_views_[0];
+  }
+
+ private:
+  std::vector<char> chars_;                           // for input
+  std::vector<std::string_view> input_string_views_;  // for input
+};
+
+using TensorPtr = std::unique_ptr<Custom::TensorBase>;
+using TensorPtrs = std::vector<TensorPtr>;
+
+struct TensorArray : public ArgBase {
+  TensorArray(OrtKernelContext* ctx,
+              size_t start_indice,
+              bool is_input) : ArgBase(ctx,
+                                       start_indice,
+                                       is_input) {
+    if (is_input) {
+      auto input_count = ctx_.GetInputCount();
+      for (size_t ith_input = start_indice; ith_input < input_count; ++ith_input) {
+        auto const_value = ctx_.GetInput(start_indice);
+        auto type_shape_info = const_value.GetTensorTypeAndShapeInfo();
+        auto type = type_shape_info.GetElementType();
+        TensorPtr tensor;
+        switch (type) {
+          case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
+            tensor = std::make_unique<Custom::Tensor<bool>>(ctx, ith_input, true);
+            break;
+          case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
+            tensor = std::make_unique<Custom::Tensor<float>>(ctx, ith_input, true);
+            break;
+          case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
+            tensor = std::make_unique<Custom::Tensor<double>>(ctx, ith_input, true);
+            break;
+          case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
+            tensor = std::make_unique<Custom::Tensor<uint8_t>>(ctx, ith_input, true);
+            break;
+          case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
+            tensor = std::make_unique<Custom::Tensor<int8_t>>(ctx, ith_input, true);
+            break;
+          case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
+            tensor = std::make_unique<Custom::Tensor<uint16_t>>(ctx, ith_input, true);
+            break;
+          case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
+            tensor = std::make_unique<Custom::Tensor<int16_t>>(ctx, ith_input, true);
+            break;
+          case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
+            tensor = std::make_unique<Custom::Tensor<uint32_t>>(ctx, ith_input, true);
+            break;
+          case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
+            tensor = std::make_unique<Custom::Tensor<int32_t>>(ctx, ith_input, true);
+            break;
+          case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
+            tensor = std::make_unique<Custom::Tensor<uint64_t>>(ctx, ith_input, true);
+            break;
+          case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
+            tensor = std::make_unique<Custom::Tensor<int64_t>>(ctx, ith_input, true);
+            break;
+          case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
+            tensor = std::make_unique<Custom::Tensor<std::string>>(ctx, ith_input, true);
+            break;
+          default:
+            ORT_CXX_API_THROW("unknow input type", ORT_RUNTIME_EXCEPTION);
+            break;
+        }
+        tensors_.emplace_back(tensor.release());
+      }  // for
+    }
+  }
+  template <typename T>
+  T* AllocateOutput(size_t ith_output, const std::vector<int64_t>& shape) {
+    // ith_output is the indice of output relative to the tensor array
+    // indice_ + ith_output is the indice relative to context
+    auto tensor = std::make_unique<Tensor<T>>(ctx_.GetOrtKernelContext(), indice_ + ith_output, false);
+    auto raw_output = tensor.get()->Allocate(shape);
+    tensors_.emplace_back(tensor.release());
+    return raw_output;
+  }
+  Tensor<std::string>& AllocateStringTensor(size_t ith_output) {
+    // ith_output is the indice of output relative to the tensor array
+    // indice_ + ith_output is the indice relative to context
+    auto tensor = std::make_unique<Tensor<std::string>>(ctx_.GetOrtKernelContext(), indice_ + ith_output, false);
+    Tensor<std::string>& output = *tensor;
+    tensors_.emplace_back(tensor.release());
+    return output;
+  }
+  size_t Size() const {
+    return tensors_.size();
+  }
+  const TensorPtr& operator[](size_t ith_input) const {
+    // ith_input is the indice of output relative to the tensor array
+    return tensors_.at(ith_input);
+  }
+
+ private:
+  TensorPtrs tensors_;
+};
+
+using Variadic = TensorArray;
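A quick sketch of what the `Variadic` alias enables, assuming a CPU kernel that concatenates an arbitrary number of flat float inputs. The op is hypothetical; `Size()`, `operator[]`, `DataRaw()`, `NumberOfElement()`, and `AllocateOutput` are the members defined above:

```cpp
#include <cstring>  // std::memcpy

// Hypothetical variadic op: concatenate all float inputs into output 0.
void ConcatAll(const Ort::Custom::Variadic& inputs,
               Ort::Custom::Variadic& outputs) {
  int64_t total = 0;
  for (size_t i = 0; i < inputs.Size(); ++i) {
    total += inputs[i]->NumberOfElement();
  }
  float* out = outputs.AllocateOutput<float>(0, {total});
  for (size_t i = 0; i < inputs.Size(); ++i) {
    const int64_t n = inputs[i]->NumberOfElement();
    std::memcpy(out, inputs[i]->DataRaw(), n * sizeof(float));
    out += n;
  }
}
```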
+
+/*
+Note:
+OrtLiteCustomOp inherits from OrtCustomOp to bridge between a custom func/struct and ort core.
+The lifetime of an OrtLiteCustomOp instance is managed by customer code, not ort, so:
+1. DO NOT cast OrtLiteCustomOp to OrtCustomOp and release since there is no virtual destructor in the hierarchy.
+2. OrtLiteCustomFunc and OrtLiteCustomStruct, as two sub-structs, can be released in form of OrtLiteCustomOp since all members are kept in the OrtLiteCustomOp,
+   hence memory could still be recycled properly.
+Further, OrtCustomOp is a c struct bearing no v-table, so offspring structs are by design to be of zero virtual functions to maintain cast safety.
+*/
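Put concretely, the note implies an ownership pattern like the following sketch, which reuses the `Filter` example from the `OrtLiteCustomFunc` comment further down (`CreateLiteCustomOp` is the factory this header's documentation refers to):

```cpp
// Hold the op as an OrtLiteCustomOp so it is destroyed through the correct
// (non-virtual) type; keep it alive as long as any session uses it.
std::unique_ptr<Ort::Custom::OrtLiteCustomOp> fil_op_ptr{
    Ort::Custom::CreateLiteCustomOp("Filter", "CPUExecutionProvider", Filter)};

Ort::CustomOpDomain v2_domain{"v2"};
v2_domain.Add(fil_op_ptr.get());

Ort::SessionOptions session_options;
session_options.Add(v2_domain);
```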
+struct OrtLiteCustomOp : public OrtCustomOp {
+  using ConstOptionalFloatTensor = std::optional<const Custom::Tensor<float>&>;
+  using OptionalFloatTensor = std::optional<Custom::Tensor<float>>;
+
+  // CreateTuple
+  template <size_t ith_input, size_t ith_output, typename... Ts>
+  static typename std::enable_if<sizeof...(Ts) == 0, std::tuple<>>::type
+  CreateTuple(OrtKernelContext*, ArgPtrs&, size_t, size_t, const std::string&) {
+    return std::make_tuple();
+  }
+
+  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
+  static typename std::enable_if<std::is_same<T, OrtKernelContext*>::value, std::tuple<T, Ts...>>::type
+  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
+    std::tuple<T> current = std::tuple<OrtKernelContext*>{context};
+    auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
+    return std::tuple_cat(current, next);
+  }
+
+  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
+  static typename std::enable_if<std::is_same<T, OrtKernelContext&>::value, std::tuple<T, Ts...>>::type
+  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
+    std::tuple<T> current = std::tuple<OrtKernelContext&>{*context};
+    auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
+    return std::tuple_cat(current, next);
+  }
+
+#ifdef ORT_CUDA_CTX
+  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
+  static typename std::enable_if<std::is_same<T, const CudaContext&>::value, std::tuple<T, Ts...>>::type
+  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
+    thread_local CudaContext cuda_context;
+    cuda_context.Init(*context);
+    std::tuple<T> current = std::tuple<const CudaContext&>{cuda_context};
+    auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
+    return std::tuple_cat(current, next);
+  }
+#endif
+
+#ifdef ORT_ROCM_CTX
+  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
+  static typename std::enable_if<std::is_same<T, const RocmContext&>::value, std::tuple<T, Ts...>>::type
+  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
+    thread_local RocmContext rocm_context;
+    rocm_context.Init(*context);
+    std::tuple<T> current = std::tuple<const RocmContext&>{rocm_context};
+    auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
+    return std::tuple_cat(current, next);
+  }
+#endif
+
+  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
+  static typename std::enable_if<std::is_same<T, const TensorArray*>::value, std::tuple<T, Ts...>>::type
+  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
+    args.push_back(std::make_unique<TensorArray>(context, ith_input, true));
+    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())};
+    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep);
+    return std::tuple_cat(current, next);
+  }
+
+  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
+  static typename std::enable_if<std::is_same<T, const TensorArray&>::value, std::tuple<T, Ts...>>::type
+  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
+    args.push_back(std::make_unique<TensorArray>(context, ith_input, true));
+    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())};
+    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep);
+    return std::tuple_cat(current, next);
+  }
+
+  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
+  static typename std::enable_if<std::is_same<T, TensorArray*>::value, std::tuple<T, Ts...>>::type
+  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
+    args.push_back(std::make_unique<TensorArray>(context, ith_output, false));
+    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())};
+    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep);
+    return std::tuple_cat(current, next);
+  }
+
+  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
+  static typename std::enable_if<std::is_same<T, TensorArray&>::value, std::tuple<T, Ts...>>::type
+  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
+    args.push_back(std::make_unique<TensorArray>(context, ith_output, false));
+    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())};
+    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep);
+    return std::tuple_cat(current, next);
+  }
+
+#define CREATE_TUPLE_INPUT(data_type) \
+  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
+  static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
+  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
+    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
+    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())}; \
+    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
+    return std::tuple_cat(current, next); \
+  } \
+  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
+  static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
+  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
+    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
+    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())}; \
+    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
+    return std::tuple_cat(current, next); \
+  } \
+  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
+  static typename std::enable_if<std::is_same<T, std::optional<const Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
+  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
+    if (ith_input < num_input) { \
+      args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
+      std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())}; \
+      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
+      return std::tuple_cat(current, next); \
+    } else { \
+      std::tuple<T> current = std::tuple<T>{}; \
+      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
+      return std::tuple_cat(current, next); \
+    } \
+  } \
+  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
+  static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>*>::value, std::tuple<T, Ts...>>::type \
+  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
+    if ("CPUExecutionProvider" != ep) { \
+      ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
+    } \
+    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
+    std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsSpan()}; \
+    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
+    return std::tuple_cat(current, next); \
+  } \
+  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
+  static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>&>::value, std::tuple<T, Ts...>>::type \
+  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
+    if ("CPUExecutionProvider" != ep) { \
+      ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
+    } \
+    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
+    std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsSpan()}; \
+    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
+    return std::tuple_cat(current, next); \
+  } \
+  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
+  static typename std::enable_if<std::is_same<T, std::optional<const Custom::Span<data_type>*>>::value, std::tuple<T, Ts...>>::type \
+  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
+    if (ith_input < num_input) { \
+      if ("CPUExecutionProvider" != ep) { \
+        ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
+      } \
+      args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
+      std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsSpan()}; \
+      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
+      return std::tuple_cat(current, next); \
+    } else { \
+      std::tuple<T> current = std::tuple<T>{}; \
+      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
+      return std::tuple_cat(current, next); \
+    } \
+  } \
+  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
+  static typename std::enable_if<std::is_same<T, data_type>::value, std::tuple<T, Ts...>>::type \
+  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
+    if ("CPUExecutionProvider" != ep) { \
+      ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
+    } \
+    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
+    std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsScalar()}; \
+    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
+    return std::tuple_cat(current, next); \
+  } \
+  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
+  static typename std::enable_if<std::is_same<T, std::optional<data_type>>::value, std::tuple<T, Ts...>>::type \
+  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
+    if (ith_input < num_input) { \
+      if ("CPUExecutionProvider" != ep) { \
+        ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
+      } \
+      args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
+      std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsScalar()}; \
+      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
+      return std::tuple_cat(current, next); \
+    } else { \
+      std::tuple<T> current = std::tuple<T>{}; \
+      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
+      return std::tuple_cat(current, next); \
+    } \
+  }
+#define CREATE_TUPLE_OUTPUT(data_type) \
+  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
+  static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
+  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
+    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_output, false)); \
+    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())}; \
+    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
+    return std::tuple_cat(current, next); \
+  } \
+  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
+  static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
+  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
+    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_output, false)); \
+    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())}; \
+    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
+    return std::tuple_cat(current, next); \
+  } \
+  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
+  static typename std::enable_if<std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
+  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
+    if (ith_output < num_output) { \
+      args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_output, false)); \
+      std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())}; \
+      auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
+      return std::tuple_cat(current, next); \
+    } else { \
+      std::tuple<T> current = std::tuple<T>{}; \
+      auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
+      return std::tuple_cat(current, next); \
+    } \
+  }
+#define CREATE_TUPLE(data_type) \
+  CREATE_TUPLE_INPUT(data_type) \
+  CREATE_TUPLE_OUTPUT(data_type)
+
+  CREATE_TUPLE(bool)
+  CREATE_TUPLE(float)
+  CREATE_TUPLE(Ort::Float16_t)
+  CREATE_TUPLE(Ort::BFloat16_t)
+  CREATE_TUPLE(double)
+  CREATE_TUPLE(int8_t)
+  CREATE_TUPLE(int16_t)
+  CREATE_TUPLE(int32_t)
+  CREATE_TUPLE(int64_t)
+  CREATE_TUPLE(uint8_t)
+  CREATE_TUPLE(uint16_t)
+  CREATE_TUPLE(uint32_t)
+  CREATE_TUPLE(uint64_t)
+  CREATE_TUPLE(std::string)
+  CREATE_TUPLE_INPUT(std::string_view)
+  CREATE_TUPLE(Ort::Float8E4M3FN_t)
+  CREATE_TUPLE(Ort::Float8E4M3FNUZ_t)
+  CREATE_TUPLE(Ort::Float8E5M2_t)
+  CREATE_TUPLE(Ort::Float8E5M2FNUZ_t)
+
+  // ParseArgs ...
+  template <typename... Ts>
+  static typename std::enable_if<0 == sizeof...(Ts)>::type
+  ParseArgs(std::vector<ONNXTensorElementDataType>&, std::vector<ONNXTensorElementDataType>&) {
+  }
+
+  template <typename T, typename... Ts>
+  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext*>::value>::type
+  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
+    ParseArgs<Ts...>(input_types, output_types);
+  }
+
+  template <typename T, typename... Ts>
+  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext&>::value>::type
+  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
+    ParseArgs<Ts...>(input_types, output_types);
+  }
+
+#ifdef ORT_CUDA_CTX
+  template <typename T, typename... Ts>
+  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const CudaContext&>::value>::type
+  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
+    ParseArgs<Ts...>(input_types, output_types);
+  }
+#endif
+
+#ifdef ORT_ROCM_CTX
+  template <typename T, typename... Ts>
+  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const RocmContext&>::value>::type
+  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
+    ParseArgs<Ts...>(input_types, output_types);
+  }
+#endif
+
+  template <typename T, typename... Ts>
+  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const TensorArray&>::value>::type
+  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
+    input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
+    ParseArgs<Ts...>(input_types, output_types);
+  }
+
+  template <typename T, typename... Ts>
+  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const TensorArray*>::value>::type
+  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
+    input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
+    ParseArgs<Ts...>(input_types, output_types);
+  }
+
+  template <typename T, typename... Ts>
+  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, TensorArray&>::value>::type
+  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
+    output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
+    ParseArgs<Ts...>(input_types, output_types);
+  }
+
+  template <typename T, typename... Ts>
+  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, TensorArray*>::value>::type
+  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
+    output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
+    ParseArgs<Ts...>(input_types, output_types);
+  }
+
+#define PARSE_INPUT_BASE(pack_type, onnx_type) \
+  template <typename T, typename... Ts> \
+  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, pack_type>::value>::type \
+  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
+    input_types.push_back(onnx_type); \
+    ParseArgs<Ts...>(input_types, output_types); \
+  } \
+  template <typename T, typename... Ts> \
+  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const std::optional<pack_type>>::value>::type \
+  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
+    input_types.push_back(onnx_type); \
+    ParseArgs<Ts...>(input_types, output_types); \
+  } \
+  template <typename T, typename... Ts> \
+  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, std::optional<pack_type>>::value>::type \
+  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
+    input_types.push_back(onnx_type); \
+    ParseArgs<Ts...>(input_types, output_types); \
+  }
+
+#define PARSE_INPUT(data_type, onnx_type) \
+  PARSE_INPUT_BASE(const Custom::Tensor<data_type>*, onnx_type) \
+  PARSE_INPUT_BASE(const Custom::Tensor<data_type>&, onnx_type) \
+  PARSE_INPUT_BASE(const Custom::Span<data_type>*, onnx_type) \
+  PARSE_INPUT_BASE(const Custom::Span<data_type>&, onnx_type) \
+  PARSE_INPUT_BASE(data_type, onnx_type)
+
+#define PARSE_OUTPUT(data_type, onnx_type) \
+  template <typename T, typename... Ts> \
+  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Custom::Tensor<data_type>*>::value>::type \
+  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
+    output_types.push_back(onnx_type); \
+    ParseArgs<Ts...>(input_types, output_types); \
+  } \
+  template <typename T, typename... Ts> \
+  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Custom::Tensor<data_type>&>::value>::type \
+  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
+    output_types.push_back(onnx_type); \
+    ParseArgs<Ts...>(input_types, output_types); \
+  } \
+  template <typename T, typename... Ts> \
+  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value>::type \
+  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
+    output_types.push_back(onnx_type); \
+    ParseArgs<Ts...>(input_types, output_types); \
+  }
+
+#define PARSE_ARGS(data_type, onnx_type) \
+  PARSE_INPUT(data_type, onnx_type) \
+  PARSE_OUTPUT(data_type, onnx_type)
+
+  PARSE_ARGS(bool, ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL)
+  PARSE_ARGS(float, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
+  PARSE_ARGS(Ort::Float16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16)
+  PARSE_ARGS(Ort::BFloat16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16)
+  PARSE_ARGS(double, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE)
+  PARSE_ARGS(int8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8)
+  PARSE_ARGS(int16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16)
+  PARSE_ARGS(int32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)
+  PARSE_ARGS(int64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)
+  PARSE_ARGS(uint8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8)
+  PARSE_ARGS(uint16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16)
+  PARSE_ARGS(uint32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32)
+  PARSE_ARGS(uint64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64)
+  PARSE_ARGS(std::string, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING)
+  PARSE_ARGS(std::string_view, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING)  // todo - remove string_view output
+  PARSE_ARGS(Ort::Float8E4M3FN_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN)
+  PARSE_ARGS(Ort::Float8E4M3FNUZ_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ)
+  PARSE_ARGS(Ort::Float8E5M2_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2)
+  PARSE_ARGS(Ort::Float8E5M2FNUZ_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ)
+
+  OrtLiteCustomOp(const char* op_name,
+                  const char* execution_provider,
+                  ShapeInferFn shape_infer_fn,
+                  int start_ver = 1,
+                  int end_ver = MAX_CUSTOM_OP_END_VER) : op_name_(op_name),
+                                                         execution_provider_(execution_provider),
+                                                         shape_infer_fn_(shape_infer_fn),
+                                                         start_ver_(start_ver),
+                                                         end_ver_(end_ver) {
+    OrtCustomOp::version = ORT_API_VERSION;
+
+    OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->op_name_.c_str(); };
+    OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* op) { return ((OrtLiteCustomOp*)op)->execution_provider_.c_str(); };
+    OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp*, size_t) { return OrtMemTypeDefault; };
+
+    OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* op) {
+      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
+      return self->input_types_.size();
+    };
+
+    OrtCustomOp::GetInputType = [](const OrtCustomOp* op, size_t indice) {
+      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
+      return self->input_types_[indice];
+    };
+
+    OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* op) {
+      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
+      return self->output_types_.size();
+    };
+
+    OrtCustomOp::GetOutputType = [](const OrtCustomOp* op, size_t indice) {
+      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
+      return self->output_types_[indice];
+    };
+
+    OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* op, size_t indice) {
+      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
+      return self->input_types_[indice] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ? INPUT_OUTPUT_VARIADIC : INPUT_OUTPUT_OPTIONAL;
+    };
+
+    OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t indice) {
+      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
+      return self->output_types_[indice] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ? INPUT_OUTPUT_VARIADIC : INPUT_OUTPUT_OPTIONAL;
+    };
+
+    OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) {
+      return 1;
+    };
+
+    OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) {
+      return 0;
+    };
+
+    OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) {
+      return 1;
+    };
+
+    OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) {
+      return 0;
+    };
+
+    OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) { return 0; };
+    OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) { return 0; };
+    OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) { return 0; };
+    OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) { return 0; };
+
+    OrtCustomOp::CreateKernelV2 = {};
+    OrtCustomOp::KernelComputeV2 = {};
+    OrtCustomOp::KernelCompute = {};
+
+    OrtCustomOp::InferOutputShapeFn = {};
+
+    OrtCustomOp::GetStartVersion = [](const OrtCustomOp* op) {
+      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
+      return self->start_ver_;
+    };
+
+    OrtCustomOp::GetEndVersion = [](const OrtCustomOp* op) {
+      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
+      return self->end_ver_;
+    };
+
+    OrtCustomOp::GetMayInplace = {};
+    OrtCustomOp::ReleaseMayInplace = {};
+    OrtCustomOp::GetAliasMap = {};
+    OrtCustomOp::ReleaseAliasMap = {};
+  }
+
+  const std::string op_name_;
+  const std::string execution_provider_;
+
+  std::vector<ONNXTensorElementDataType> input_types_;
+  std::vector<ONNXTensorElementDataType> output_types_;
+
+  ShapeInferFn shape_infer_fn_ = {};
+
+  int start_ver_ = 1;
+  int end_ver_ = MAX_CUSTOM_OP_END_VER;
+
+  void* compute_fn_ = {};
+  void* compute_fn_return_status_ = {};
+};
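As a worked illustration of how `ParseArgs` and `CreateTuple` turn a signature into a schema (the signature below is hypothetical, not one shipped in the package):

```cpp
// void MyOp(const Ort::Custom::Tensor<float>& x,                     // input 0:  FLOAT
//           std::optional<const Ort::Custom::Tensor<int64_t>*> ids,  // input 1:  INT64, may be absent
//           Ort::Custom::Tensor<float>& y);                          // output 0: FLOAT
//
// ParseArgs over these three argument types fills
//   input_types_  == {FLOAT, INT64} and output_types_ == {FLOAT},
// and at run time CreateTuple<0, 0, ...> walks the same list, wrapping each
// kernel-context input/output in a Custom::Tensor before invoking the function.
```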
+
+//////////////////////////// OrtLiteCustomFunc ////////////////////////////////
+// The struct is to implement function-as-op.
+// E.g. a function might be defined as:
+//   void Filter(const Ort::Custom::Tensor<float>& floats_in, Ort::Custom::Tensor<float>& floats_out) { ... }
+// It could be registered this way:
+//   Ort::CustomOpDomain v2_domain{"v2"};
+//   std::unique_ptr<OrtLiteCustomOp> fil_op_ptr{Ort::Custom::CreateLiteCustomOp("Filter", "CPUExecutionProvider", Filter)};
+//   v2_domain.Add(fil_op_ptr.get());
+//   session_options.Add(v2_domain);
+// For the complete example, please search keyword "LiteCustomOpTest" under "<cloned_src_dir>/onnxruntime/test/".
+template <typename... Args>
+struct OrtLiteCustomFunc : public OrtLiteCustomOp {
+  using ComputeFn = void (*)(Args...);
+  using ComputeFnReturnStatus = Status (*)(Args...);
+  using MyType = OrtLiteCustomFunc<Args...>;
+
+  struct Kernel {
+    size_t num_input_{};
+    size_t num_output_{};
+    ComputeFn compute_fn_{};
+    ComputeFnReturnStatus compute_fn_return_status_{};
+    std::string ep_{};
+  };
+
+  OrtLiteCustomFunc(const char* op_name,
+                    const char* execution_provider,
+                    ComputeFn compute_fn,
+                    ShapeInferFn shape_infer_fn = {},
+                    int start_ver = 1,
+                    int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) {
+    compute_fn_ = reinterpret_cast<void*>(compute_fn);
+    ParseArgs<Args...>(input_types_, output_types_);
+
+    OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
+      auto kernel = reinterpret_cast<Kernel*>(op_kernel);
+      std::vector<ArgPtr> args;
+      auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
+      std::apply([kernel](Args const&... t_args) { kernel->compute_fn_(t_args...); }, t);
+    };
+
+    OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
+      auto kernel = std::make_unique<Kernel>();
+      auto me = static_cast<const MyType*>(this_);
+      kernel->compute_fn_ = reinterpret_cast<ComputeFn>(me->compute_fn_);
+      Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
+      Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
+      auto self = static_cast<const OrtLiteCustomFunc*>(this_);
+      kernel->ep_ = self->execution_provider_;
+      return reinterpret_cast<void*>(kernel.release());
+    };
+
+    OrtCustomOp::KernelDestroy = [](void* op_kernel) {
+      delete reinterpret_cast<Kernel*>(op_kernel);
+    };
+
+    if (shape_infer_fn_) {
+      OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp* op, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
+        auto shape_info_fn = static_cast<const MyType*>(op)->shape_infer_fn_;
+        ShapeInferContext ctx(&GetApi(), ort_ctx);
+        return shape_info_fn(ctx);
+      };
+    }
+  }
+
+  OrtLiteCustomFunc(const char* op_name,
+                    const char* execution_provider,
+                    ComputeFnReturnStatus compute_fn_return_status,
+                    ShapeInferFn shape_infer_fn = {},
+                    int start_ver = 1,
+                    int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) {
+    compute_fn_return_status_ = reinterpret_cast<void*>(compute_fn_return_status);
+    ParseArgs<Args...>(input_types_, output_types_);
+
+    OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
+      auto kernel = reinterpret_cast<Kernel*>(op_kernel);
+      std::vector<ArgPtr> args;
+      auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
+      return std::apply([kernel](Args const&... t_args) { Status status = kernel->compute_fn_return_status_(t_args...); return status.release(); }, t);
+    };
+
+    OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
+      auto kernel = std::make_unique<Kernel>();
+      auto me = static_cast<const MyType*>(this_);
+      kernel->compute_fn_return_status_ = reinterpret_cast<ComputeFnReturnStatus>(me->compute_fn_return_status_);
+      Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
+      Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
+      auto self = static_cast<const OrtLiteCustomFunc*>(this_);
+      kernel->ep_ = self->execution_provider_;
+      return reinterpret_cast<void*>(kernel.release());
+    };
+
+    OrtCustomOp::KernelDestroy = [](void* op_kernel) {
+      delete reinterpret_cast<Kernel*>(op_kernel);
+    };
+
+    if (shape_infer_fn_) {
+      OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp* op, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
+        auto shape_info_fn = static_cast<const MyType*>(op)->shape_infer_fn_;
+        ShapeInferContext ctx(&GetApi(), ort_ctx);
+        return shape_info_fn(ctx);
+      };
+    }
+  }
+};  // struct OrtLiteCustomFunc
|
|
991
|
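The second constructor above binds when the registered function returns a Status instead of void, routing errors through KernelComputeV2 rather than exceptions. A minimal sketch of such a kernel, using only the Tensor<T> accessors (Data, Shape, NumberOfElement, Allocate) declared earlier in this header; the op itself ("Clamp01") is hypothetical:

    #include "onnxruntime_lite_custom_op.h"

    // Hypothetical status-returning kernel; matches ComputeFnReturnStatus = Status (*)(Args...).
    static Ort::Status Clamp01(const Ort::Custom::Tensor<float>& in,
                               Ort::Custom::Tensor<float>& out) {
      const float* src = in.Data();
      float* dst = out.Allocate(in.Shape());  // allocate output with the input's shape
      for (int64_t i = 0; i < in.NumberOfElement(); ++i) {
        dst[i] = src[i] < 0.f ? 0.f : (src[i] > 1.f ? 1.f : src[i]);
      }
      return Ort::Status{nullptr};  // a null OrtStatus* signals success
    }

Overload resolution on the function-pointer type alone picks this constructor over the void-returning one, so registration looks identical to the Filter example in the comment block above.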
+
+/////////////////////////// OrtLiteCustomStruct ///////////////////////////
+// The struct is to implement struct-as-op.
+// E.g. a struct might be defined as:
+//   struct Merge {
+//     Merge(const OrtApi* ort_api, const OrtKernelInfo* info) {...}
+//     void Compute(const Ort::Custom::Tensor<std::string_view>& strings_in,
+//                  std::string_view string_in,
+//                  Ort::Custom::Tensor<std::string>* strings_out) {...}
+//     bool reverse_ = false;
+//   };
+// It could be registered this way:
+//   Ort::CustomOpDomain v2_domain{"v2"};
+//   std::unique_ptr<OrtLiteCustomOp> mrg_op_ptr{Ort::Custom::CreateLiteCustomOp<Merge>("Merge", "CPUExecutionProvider")};
+//   v2_domain.Add(mrg_op_ptr.get());
+//   session_options.Add(v2_domain);
+// For the complete example, please search keyword "LiteCustomOpTest" under "<cloned_src_dir>/onnxruntime/test/".
+template <typename CustomOp>
+struct OrtLiteCustomStruct : public OrtLiteCustomOp {
+  template <typename... Args>
+  using CustomComputeFn = void (CustomOp::*)(Args...);
+
+  template <typename... Args>
+  using CustomComputeFnReturnStatus = Status (CustomOp::*)(Args...);
+
+  using MyType = OrtLiteCustomStruct<CustomOp>;
+
+  struct Kernel {
+    size_t num_input_{};
+    size_t num_output_{};
+    std::unique_ptr<CustomOp> custom_op_;
+    std::string ep_{};
+  };
+
+  OrtLiteCustomStruct(const char* op_name,
+                      const char* execution_provider,
+                      int start_ver = 1,
+                      int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, {}, start_ver, end_ver) {
+    SetCompute(&CustomOp::Compute);
+
+    OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
+      auto kernel = std::make_unique<Kernel>();
+      Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
+      Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
+      kernel->custom_op_ = std::make_unique<CustomOp>(ort_api, info);
+      auto self = static_cast<const OrtLiteCustomStruct*>(this_);
+      kernel->ep_ = self->execution_provider_;
+      return reinterpret_cast<void*>(kernel.release());
+    };
+
+    OrtCustomOp::KernelDestroy = [](void* op_kernel) {
+      delete reinterpret_cast<Kernel*>(op_kernel);
+    };
+
+    SetShapeInfer<CustomOp>(0);
+  }
+
+  template <typename... Args>
+  void SetCompute(CustomComputeFn<Args...>) {
+    ParseArgs<Args...>(input_types_, output_types_);
+    OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
+      auto kernel = reinterpret_cast<Kernel*>(op_kernel);
+      ArgPtrs args;
+      auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
+      std::apply([kernel](Args const&... t_args) { kernel->custom_op_->Compute(t_args...); }, t);
+    };
+  }
+
+  template <typename... Args>
+  void SetCompute(CustomComputeFnReturnStatus<Args...>) {
+    ParseArgs<Args...>(input_types_, output_types_);
+    OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
+      auto kernel = reinterpret_cast<Kernel*>(op_kernel);
+      ArgPtrs args;
+      auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
+      return std::apply([kernel](Args const&... t_args) { Status status = kernel->custom_op_->Compute(t_args...); return status.release(); }, t);
+    };
+  }
+
+  template <typename C>
+  decltype(&C::InferOutputShape) SetShapeInfer(decltype(&C::InferOutputShape)) {
+    OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp*, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
+      ShapeInferContext ctx(&GetApi(), ort_ctx);
+      return C::InferOutputShape(ctx);
+    };
+    return {};
+  }
+
+  template <typename C>
+  void SetShapeInfer(...) {
+    OrtCustomOp::InferOutputShapeFn = {};
+  }
+}; // struct OrtLiteCustomStruct
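CreateKernel above constructs the user's struct with (ort_api, info), so per-node state can be read once from the OrtKernelInfo, and SetShapeInfer installs InferOutputShapeFn only when decltype(&C::InferOutputShape) is well-formed, i.e. when the struct declares a static InferOutputShape. A minimal sketch of a conforming struct, assuming the two-argument ShapeInferContext::SetOutputShape overload; the "coef" attribute and "Scale" op are hypothetical:

    #include "onnxruntime_lite_custom_op.h"

    struct Scale {
      // Constructor signature required by CreateKernel: (const OrtApi*, const OrtKernelInfo*).
      Scale(const OrtApi* ort_api, const OrtKernelInfo* info) {
        // "coef" is a hypothetical node attribute, read once at kernel creation.
        Ort::ThrowOnError(ort_api->KernelInfoGetAttribute_float(info, "coef", &coef_));
      }

      void Compute(const Ort::Custom::Tensor<float>& in, Ort::Custom::Tensor<float>& out) {
        const float* src = in.Data();
        float* dst = out.Allocate(in.Shape());
        for (int64_t i = 0; i < in.NumberOfElement(); ++i) dst[i] = src[i] * coef_;
      }

      // Optional; detected via decltype(&C::InferOutputShape) in SetShapeInfer above.
      static Ort::Status InferOutputShape(Ort::ShapeInferContext& ctx) {
        return ctx.SetOutputShape(0, ctx.GetInputShape(0));  // output shape == input shape
      }

      float coef_ = 1.f;
    };

If InferOutputShape is omitted, the variadic SetShapeInfer(...) fallback leaves InferOutputShapeFn unset and the runtime skips custom shape inference for the op.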
+
+/////////////////////////// CreateLiteCustomOp ////////////////////////////
+
+template <typename... Args>
+OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
+                                    const char* execution_provider,
+                                    void (*custom_compute_fn)(Args...),
+                                    Status (*shape_infer_fn)(ShapeInferContext&) = {},
+                                    int start_ver = 1,
+                                    int end_ver = MAX_CUSTOM_OP_END_VER) {
+  using LiteOp = OrtLiteCustomFunc<Args...>;
+  return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn, shape_infer_fn, start_ver, end_ver).release();
+}
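The optional shape_infer_fn parameter lets a function-as-op publish its output shape without defining a struct; when supplied, the OrtLiteCustomFunc constructor wires it into OrtCustomOp::InferOutputShapeFn. A sketch of such a callback for an element-wise op, under the same ShapeInferContext assumptions as the Scale sketch above:

    // Hypothetical callback matching Status (*)(ShapeInferContext&).
    static Ort::Status PassThroughShape(Ort::ShapeInferContext& ctx) {
      return ctx.SetOutputShape(0, ctx.GetInputShape(0));  // element-wise op: shape passthrough
    }

    // Passed alongside the compute function at registration time, e.g.:
    //   Ort::Custom::CreateLiteCustomOp("Clamp01", "CPUExecutionProvider",
    //                                   Clamp01, PassThroughShape);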
+
+template <typename... Args>
+OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
+                                    const char* execution_provider,
+                                    Status (*custom_compute_fn_v2)(Args...),
+                                    Status (*shape_infer_fn)(ShapeInferContext&) = {},
+                                    int start_ver = 1,
+                                    int end_ver = MAX_CUSTOM_OP_END_VER) {
+  using LiteOp = OrtLiteCustomFunc<Args...>;
+  return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn_v2, shape_infer_fn, start_ver, end_ver).release();
+}
+
+template <typename CustomOp>
+OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
+                                    const char* execution_provider,
+                                    int start_ver = 1,
+                                    int end_ver = MAX_CUSTOM_OP_END_VER) {
+  using LiteOp = OrtLiteCustomStruct<CustomOp>;
+  return std::make_unique<LiteOp>(op_name, execution_provider, start_ver, end_ver).release();
+}
+
+} // namespace Custom
+} // namespace Ort
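All three factories hand ownership of the constructed op to the caller via std::make_unique(...).release(), so callers should reclaim it, as the header's own comment examples do with std::unique_ptr. A sketch of the end-to-end wiring, reusing the hypothetical Scale struct from above; the "v2" domain name follows the examples in this header:

    #include "onnxruntime_lite_custom_op.h"

    Ort::SessionOptions session_options;
    Ort::CustomOpDomain v2_domain{"v2"};

    // The caller owns the returned op, hence the unique_ptr wrapper.
    std::unique_ptr<Ort::Custom::OrtLiteCustomOp> scale_op{
        Ort::Custom::CreateLiteCustomOp<Scale>("Scale", "CPUExecutionProvider")};

    v2_domain.Add(scale_op.get());
    session_options.Add(v2_domain);
    // v2_domain and scale_op should outlive any session created from session_options.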