com.github.asus4.onnxruntime 0.1.10 → 0.1.12
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Plugins/Android/onnxruntime-android.aar +0 -0
- package/Plugins/Linux/x64/libonnxruntime.so +0 -0
- package/Plugins/Windows/x64/onnxruntime.dll +0 -0
- package/Plugins/iOS~/onnxruntime.xcframework/Info.plist +13 -0
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Headers/onnxruntime_c_api.h +182 -15
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Headers/onnxruntime_cxx_api.h +110 -4
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Headers/onnxruntime_cxx_inline.h +189 -0
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Headers/onnxruntime_run_options_config_keys.h +32 -0
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Headers/onnxruntime_session_options_config_keys.h +258 -0
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Info.plist +2 -2
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/onnxruntime +0 -0
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Headers/onnxruntime_c_api.h +182 -15
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Headers/onnxruntime_cxx_api.h +110 -4
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Headers/onnxruntime_cxx_inline.h +189 -0
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Headers/onnxruntime_run_options_config_keys.h +32 -0
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Headers/onnxruntime_session_options_config_keys.h +258 -0
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Info.plist +2 -2
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/onnxruntime +0 -0
- package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/Headers/coreml_provider_factory.h +45 -0
- package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/Headers/cpu_provider_factory.h +19 -0
- package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/Headers/onnxruntime_c_api.h +4717 -0
- package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/Headers/onnxruntime_cxx_api.h +2372 -0
- package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/Headers/onnxruntime_cxx_inline.h +2075 -0
- package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/Headers/onnxruntime_float16.h +540 -0
- package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/Headers/onnxruntime_run_options_config_keys.h +32 -0
- package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/Headers/onnxruntime_session_options_config_keys.h +258 -0
- package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/Info.plist +20 -0
- package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/onnxruntime +0 -0
- package/Plugins/macOS/libonnxruntime.dylib +0 -0
- package/README.md +8 -8
- package/Runtime/NativeMethods.shared.cs +270 -276
- package/Runtime/OrtValue.shared.cs +7 -3
- package/Runtime/Training/NativeTrainingMethods.shared.cs +2 -2
- package/package.json +1 -1
|
@@ -0,0 +1,2075 @@
|
|
|
1
|
+
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
// Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
// Do not include this file directly. Please include "onnxruntime_cxx_api.h" instead.
|
|
5
|
+
// If interested in trying out features of the new experimental C++ API, include "experimental_onnxruntime_cxx_api.h" instead.
|
|
6
|
+
//
|
|
7
|
+
// These are the inline implementations of the C++ header APIs. They're in this separate file so as not to clutter
|
|
8
|
+
// the main C++ file with implementation details.
|
|
9
|
+
|
|
10
|
+
#include <cstring>
|
|
11
|
+
#include <functional>
|
|
12
|
+
|
|
13
|
+
// Evaluates `expression` (expected to yield an OrtStatus*). If the result is
// non-null, i.e. the call failed, it is returned from the enclosing function
// wrapped as `Status(...)`, so the enclosing function must return a Status.
//
// The body is wrapped in do { ... } while (0) so that the macro expands to a
// single statement and stays correct inside an unbraced if/else. The
// temporary uses a deliberately unusual name so that `expression` may refer
// to a caller variable such as `err` without self-initialization errors.
#define RETURN_ON_API_FAIL(expression)              \
  do {                                              \
    auto ort_api_fail_status_ = (expression);       \
    if (ort_api_fail_status_) {                     \
      return Status(ort_api_fail_status_);          \
    }                                               \
  } while (0)
|
|
20
|
+
|
|
21
|
+
namespace Ort {
|
|
22
|
+
|
|
23
|
+
namespace detail {
|
|
24
|
+
inline void ThrowStatus(const Status& st) {
|
|
25
|
+
std::string error_message = st.GetErrorMessage();
|
|
26
|
+
OrtErrorCode error_code = st.GetErrorCode();
|
|
27
|
+
ORT_CXX_API_THROW(std::move(error_message), error_code);
|
|
28
|
+
}
|
|
29
|
+
} // namespace detail
|
|
30
|
+
|
|
31
|
+
inline void ThrowOnError(OrtStatus* ort_status) {
|
|
32
|
+
if (ort_status) {
|
|
33
|
+
Ort::Status st(ort_status);
|
|
34
|
+
detail::ThrowStatus(st);
|
|
35
|
+
}
|
|
36
|
+
}
|
|
37
|
+
|
|
38
|
+
inline void ThrowOnError(const Status& st) {
|
|
39
|
+
if (st) {
|
|
40
|
+
detail::ThrowStatus(st);
|
|
41
|
+
}
|
|
42
|
+
}
|
|
43
|
+
|
|
44
|
+
inline Status::Status(OrtStatus* status) noexcept : Base<OrtStatus>{status} {
|
|
45
|
+
}
|
|
46
|
+
|
|
47
|
+
inline Status::Status(const std::exception& e) noexcept {
|
|
48
|
+
p_ = GetApi().CreateStatus(ORT_FAIL, e.what());
|
|
49
|
+
}
|
|
50
|
+
|
|
51
|
+
inline Status::Status(const Exception& e) noexcept {
|
|
52
|
+
p_ = GetApi().CreateStatus(e.GetOrtErrorCode(), e.what());
|
|
53
|
+
}
|
|
54
|
+
|
|
55
|
+
inline Status::Status(const char* message, OrtErrorCode code) noexcept {
|
|
56
|
+
p_ = GetApi().CreateStatus(code, message);
|
|
57
|
+
}
|
|
58
|
+
|
|
59
|
+
inline std::string Status::GetErrorMessage() const {
|
|
60
|
+
std::string message(GetApi().GetErrorMessage(p_));
|
|
61
|
+
return message;
|
|
62
|
+
}
|
|
63
|
+
|
|
64
|
+
inline OrtErrorCode Status::GetErrorCode() const {
|
|
65
|
+
return GetApi().GetErrorCode(p_);
|
|
66
|
+
}
|
|
67
|
+
|
|
68
|
+
inline bool Status::IsOK() const noexcept {
|
|
69
|
+
return (p_ == nullptr);
|
|
70
|
+
}
|
|
71
|
+
|
|
72
|
+
// This template converts a C++ type into it's ONNXTensorElementDataType
|
|
73
|
+
template <typename T>
|
|
74
|
+
struct TypeToTensorType;
|
|
75
|
+
template <>
|
|
76
|
+
struct TypeToTensorType<float> {
|
|
77
|
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
|
78
|
+
};
|
|
79
|
+
template <>
|
|
80
|
+
struct TypeToTensorType<Float16_t> {
|
|
81
|
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
|
|
82
|
+
};
|
|
83
|
+
template <>
|
|
84
|
+
struct TypeToTensorType<BFloat16_t> {
|
|
85
|
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16;
|
|
86
|
+
};
|
|
87
|
+
template <>
|
|
88
|
+
struct TypeToTensorType<double> {
|
|
89
|
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
|
|
90
|
+
};
|
|
91
|
+
template <>
|
|
92
|
+
struct TypeToTensorType<int8_t> {
|
|
93
|
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
|
|
94
|
+
};
|
|
95
|
+
template <>
|
|
96
|
+
struct TypeToTensorType<int16_t> {
|
|
97
|
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
|
|
98
|
+
};
|
|
99
|
+
template <>
|
|
100
|
+
struct TypeToTensorType<int32_t> {
|
|
101
|
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
|
|
102
|
+
};
|
|
103
|
+
template <>
|
|
104
|
+
struct TypeToTensorType<int64_t> {
|
|
105
|
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
|
106
|
+
};
|
|
107
|
+
template <>
|
|
108
|
+
struct TypeToTensorType<uint8_t> {
|
|
109
|
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
|
|
110
|
+
};
|
|
111
|
+
template <>
|
|
112
|
+
struct TypeToTensorType<uint16_t> {
|
|
113
|
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16;
|
|
114
|
+
};
|
|
115
|
+
template <>
|
|
116
|
+
struct TypeToTensorType<uint32_t> {
|
|
117
|
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32;
|
|
118
|
+
};
|
|
119
|
+
template <>
|
|
120
|
+
struct TypeToTensorType<uint64_t> {
|
|
121
|
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64;
|
|
122
|
+
};
|
|
123
|
+
template <>
|
|
124
|
+
struct TypeToTensorType<bool> {
|
|
125
|
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
|
|
126
|
+
};
|
|
127
|
+
|
|
128
|
+
template <>
|
|
129
|
+
struct TypeToTensorType<Float8E4M3FN_t> {
|
|
130
|
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN;
|
|
131
|
+
};
|
|
132
|
+
template <>
|
|
133
|
+
struct TypeToTensorType<Float8E4M3FNUZ_t> {
|
|
134
|
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ;
|
|
135
|
+
};
|
|
136
|
+
template <>
|
|
137
|
+
struct TypeToTensorType<Float8E5M2_t> {
|
|
138
|
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2;
|
|
139
|
+
};
|
|
140
|
+
template <>
|
|
141
|
+
struct TypeToTensorType<Float8E5M2FNUZ_t> {
|
|
142
|
+
static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ;
|
|
143
|
+
};
|
|
144
|
+
|
|
145
|
+
inline bool BFloat16_t::operator==(const BFloat16_t& rhs) const noexcept {
|
|
146
|
+
if (IsNaN() || rhs.IsNaN()) {
|
|
147
|
+
// IEEE defines that NaN is not equal to anything, including itself.
|
|
148
|
+
return false;
|
|
149
|
+
}
|
|
150
|
+
return val == rhs.val;
|
|
151
|
+
}
|
|
152
|
+
|
|
153
|
+
inline bool BFloat16_t::operator<(const BFloat16_t& rhs) const noexcept {
|
|
154
|
+
if (IsNaN() || rhs.IsNaN()) {
|
|
155
|
+
// IEEE defines that NaN is unordered with respect to everything, including itself.
|
|
156
|
+
return false;
|
|
157
|
+
}
|
|
158
|
+
|
|
159
|
+
const bool left_is_negative = IsNegative();
|
|
160
|
+
if (left_is_negative != rhs.IsNegative()) {
|
|
161
|
+
// When the signs of left and right differ, we know that left is less than right if it is
|
|
162
|
+
// the negative value. The exception to this is if both values are zero, in which case IEEE
|
|
163
|
+
// says they should be equal, even if the signs differ.
|
|
164
|
+
return left_is_negative && !AreZero(*this, rhs);
|
|
165
|
+
}
|
|
166
|
+
return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative);
|
|
167
|
+
}
|
|
168
|
+
|
|
169
|
+
// Takes ownership of buffer `p` (`size` bytes) obtained from `allocator`.
inline MemoryAllocation::MemoryAllocation(OrtAllocator* allocator, void* p, size_t size)
    : allocator_(allocator), p_(p), size_(size) {
}

// Returns the owned buffer, if any, to its allocator.
inline MemoryAllocation::~MemoryAllocation() {
  if (p_ != nullptr) {
    // We do not throw out of a destructor, so a failure to free is ignored.
    // A non-null OrtStatus* returned by the C API is still owned by the caller
    // (elsewhere it is adopted by Ort::Status), so release it here instead of
    // leaking it.
    auto free_status = GetApi().AllocatorFree(allocator_, p_);
    if (free_status != nullptr) {
      GetApi().ReleaseStatus(free_status);
    }
  }
}

// Move-constructs by starting empty and move-assigning from `o`.
inline MemoryAllocation::MemoryAllocation(MemoryAllocation&& o) noexcept : allocator_(nullptr), p_(nullptr), size_(0) {
  *this = std::move(o);
}

// Takes over `o`'s buffer and releases the one previously held by *this.
// `o` is left empty.
inline MemoryAllocation& MemoryAllocation::operator=(MemoryAllocation&& o) noexcept {
  if (this == &o) {
    // Self-move: without this guard the swap sequence below would free the
    // buffer we still own and leave *this empty.
    return *this;
  }

  OrtAllocator* alloc = nullptr;
  void* p = nullptr;
  size_t sz = 0;

  // Swap out this
  std::swap(alloc, allocator_);
  std::swap(p, p_);
  std::swap(sz, size_);

  // Swap with incoming
  std::swap(allocator_, o.allocator_);
  std::swap(p_, o.p_);
  std::swap(size_, o.size_);

  // The temporary takes the previous buffer and frees it when it goes out of scope.
  MemoryAllocation this_alloc(alloc, p, sz);
  return *this;
}
|
|
204
|
+
|
|
205
|
+
namespace detail {
|
|
206
|
+
|
|
207
|
+
template <typename T>
|
|
208
|
+
inline void* AllocatorImpl<T>::Alloc(size_t size) {
|
|
209
|
+
void* out;
|
|
210
|
+
ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out));
|
|
211
|
+
return out;
|
|
212
|
+
}
|
|
213
|
+
|
|
214
|
+
template <typename T>
|
|
215
|
+
inline MemoryAllocation AllocatorImpl<T>::GetAllocation(size_t size) {
|
|
216
|
+
void* out;
|
|
217
|
+
ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out));
|
|
218
|
+
MemoryAllocation result(this->p_, out, size);
|
|
219
|
+
return result;
|
|
220
|
+
}
|
|
221
|
+
|
|
222
|
+
template <typename T>
|
|
223
|
+
inline void AllocatorImpl<T>::Free(void* p) {
|
|
224
|
+
ThrowOnError(GetApi().AllocatorFree(this->p_, p));
|
|
225
|
+
}
|
|
226
|
+
|
|
227
|
+
template <typename T>
|
|
228
|
+
inline ConstMemoryInfo AllocatorImpl<T>::GetInfo() const {
|
|
229
|
+
const OrtMemoryInfo* out;
|
|
230
|
+
ThrowOnError(GetApi().AllocatorGetInfo(this->p_, &out));
|
|
231
|
+
return ConstMemoryInfo{out};
|
|
232
|
+
}
|
|
233
|
+
|
|
234
|
+
} // namespace detail
|
|
235
|
+
|
|
236
|
+
inline AllocatorWithDefaultOptions::AllocatorWithDefaultOptions() {
|
|
237
|
+
ThrowOnError(GetApi().GetAllocatorWithDefaultOptions(&this->p_));
|
|
238
|
+
}
|
|
239
|
+
|
|
240
|
+
inline Allocator::Allocator(const Session& sess, const OrtMemoryInfo* mem_info) {
|
|
241
|
+
ThrowOnError(GetApi().CreateAllocator(sess, mem_info, &this->p_));
|
|
242
|
+
}
|
|
243
|
+
|
|
244
|
+
namespace detail {
|
|
245
|
+
|
|
246
|
+
template <typename T>
|
|
247
|
+
inline std::string MemoryInfoImpl<T>::GetAllocatorName() const {
|
|
248
|
+
const char* name = nullptr;
|
|
249
|
+
ThrowOnError(GetApi().MemoryInfoGetName(this->p_, &name));
|
|
250
|
+
return std::string(name);
|
|
251
|
+
}
|
|
252
|
+
|
|
253
|
+
template <typename T>
|
|
254
|
+
inline OrtAllocatorType MemoryInfoImpl<T>::GetAllocatorType() const {
|
|
255
|
+
OrtAllocatorType type;
|
|
256
|
+
ThrowOnError(GetApi().MemoryInfoGetType(this->p_, &type));
|
|
257
|
+
return type;
|
|
258
|
+
}
|
|
259
|
+
|
|
260
|
+
template <typename T>
|
|
261
|
+
inline int MemoryInfoImpl<T>::GetDeviceId() const {
|
|
262
|
+
int id = 0;
|
|
263
|
+
ThrowOnError(GetApi().MemoryInfoGetId(this->p_, &id));
|
|
264
|
+
return id;
|
|
265
|
+
}
|
|
266
|
+
|
|
267
|
+
template <typename T>
|
|
268
|
+
inline OrtMemoryInfoDeviceType MemoryInfoImpl<T>::GetDeviceType() const {
|
|
269
|
+
OrtMemoryInfoDeviceType type;
|
|
270
|
+
GetApi().MemoryInfoGetDeviceType(this->p_, &type);
|
|
271
|
+
return type;
|
|
272
|
+
}
|
|
273
|
+
|
|
274
|
+
template <typename T>
|
|
275
|
+
inline OrtMemType MemoryInfoImpl<T>::GetMemoryType() const {
|
|
276
|
+
OrtMemType type;
|
|
277
|
+
ThrowOnError(GetApi().MemoryInfoGetMemType(this->p_, &type));
|
|
278
|
+
return type;
|
|
279
|
+
}
|
|
280
|
+
|
|
281
|
+
template <typename T>
|
|
282
|
+
template <typename U>
|
|
283
|
+
inline bool MemoryInfoImpl<T>::operator==(const MemoryInfoImpl<U>& o) const {
|
|
284
|
+
int comp_result = 0;
|
|
285
|
+
ThrowOnError(Ort::GetApi().CompareMemoryInfo(this->p_, o, &comp_result));
|
|
286
|
+
return comp_result == 0;
|
|
287
|
+
}
|
|
288
|
+
|
|
289
|
+
} // namespace detail
|
|
290
|
+
|
|
291
|
+
// Builds a CPU memory info via OrtApi::CreateCpuMemoryInfo; throws on failure.
inline MemoryInfo MemoryInfo::CreateCpu(OrtAllocatorType type, OrtMemType mem_type) {
  OrtMemoryInfo* created = nullptr;
  ThrowOnError(GetApi().CreateCpuMemoryInfo(type, mem_type, &created));
  return MemoryInfo(created);
}

// Builds a memory info for the named allocator/device via OrtApi::CreateMemoryInfo.
inline MemoryInfo::MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) {
  const auto& api = GetApi();
  ThrowOnError(api.CreateMemoryInfo(name, type, id, mem_type, &this->p_));
}
|
|
300
|
+
|
|
301
|
+
namespace detail {
|
|
302
|
+
template <typename T>
|
|
303
|
+
inline std::vector<std::string> ConstIoBindingImpl<T>::GetOutputNames() const {
|
|
304
|
+
AllocatorWithDefaultOptions allocator;
|
|
305
|
+
return binding_utils::GetOutputNamesHelper(this->p_, allocator);
|
|
306
|
+
}
|
|
307
|
+
|
|
308
|
+
template <typename T>
|
|
309
|
+
inline std::vector<std::string> ConstIoBindingImpl<T>::GetOutputNames(OrtAllocator* allocator) const {
|
|
310
|
+
return binding_utils::GetOutputNamesHelper(this->p_, allocator);
|
|
311
|
+
}
|
|
312
|
+
|
|
313
|
+
template <typename T>
|
|
314
|
+
inline std::vector<Value> ConstIoBindingImpl<T>::GetOutputValues() const {
|
|
315
|
+
AllocatorWithDefaultOptions allocator;
|
|
316
|
+
return binding_utils::GetOutputValuesHelper(this->p_, allocator);
|
|
317
|
+
}
|
|
318
|
+
|
|
319
|
+
template <typename T>
|
|
320
|
+
inline std::vector<Value> ConstIoBindingImpl<T>::GetOutputValues(OrtAllocator* allocator) const {
|
|
321
|
+
return binding_utils::GetOutputValuesHelper(this->p_, allocator);
|
|
322
|
+
}
|
|
323
|
+
|
|
324
|
+
template <typename T>
|
|
325
|
+
inline void IoBindingImpl<T>::BindInput(const char* name, const Value& value) {
|
|
326
|
+
ThrowOnError(GetApi().BindInput(this->p_, name, value));
|
|
327
|
+
}
|
|
328
|
+
|
|
329
|
+
template <typename T>
|
|
330
|
+
inline void IoBindingImpl<T>::BindOutput(const char* name, const Value& value) {
|
|
331
|
+
ThrowOnError(GetApi().BindOutput(this->p_, name, value));
|
|
332
|
+
}
|
|
333
|
+
|
|
334
|
+
template <typename T>
|
|
335
|
+
inline void IoBindingImpl<T>::BindOutput(const char* name, const OrtMemoryInfo* mem_info) {
|
|
336
|
+
ThrowOnError(GetApi().BindOutputToDevice(this->p_, name, mem_info));
|
|
337
|
+
}
|
|
338
|
+
|
|
339
|
+
template <typename T>
|
|
340
|
+
inline void IoBindingImpl<T>::ClearBoundInputs() {
|
|
341
|
+
GetApi().ClearBoundInputs(this->p_);
|
|
342
|
+
}
|
|
343
|
+
|
|
344
|
+
template <typename T>
|
|
345
|
+
inline void IoBindingImpl<T>::ClearBoundOutputs() {
|
|
346
|
+
GetApi().ClearBoundOutputs(this->p_);
|
|
347
|
+
}
|
|
348
|
+
|
|
349
|
+
template <typename T>
|
|
350
|
+
inline void IoBindingImpl<T>::SynchronizeInputs() {
|
|
351
|
+
ThrowOnError(GetApi().SynchronizeBoundInputs(this->p_));
|
|
352
|
+
}
|
|
353
|
+
|
|
354
|
+
template <typename T>
|
|
355
|
+
inline void IoBindingImpl<T>::SynchronizeOutputs() {
|
|
356
|
+
ThrowOnError(GetApi().SynchronizeBoundOutputs(this->p_));
|
|
357
|
+
}
|
|
358
|
+
|
|
359
|
+
namespace binding_utils {
|
|
360
|
+
inline std::vector<std::string> GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) {
|
|
361
|
+
std::vector<std::string> result;
|
|
362
|
+
auto free_fn = detail::AllocatedFree(allocator);
|
|
363
|
+
using Ptr = std::unique_ptr<void, decltype(free_fn)>;
|
|
364
|
+
|
|
365
|
+
char* buffer = nullptr;
|
|
366
|
+
size_t* lengths = nullptr;
|
|
367
|
+
size_t count = 0;
|
|
368
|
+
ThrowOnError(GetApi().GetBoundOutputNames(binding, allocator, &buffer, &lengths, &count));
|
|
369
|
+
|
|
370
|
+
if (count == 0) {
|
|
371
|
+
return result;
|
|
372
|
+
}
|
|
373
|
+
|
|
374
|
+
Ptr buffer_g(buffer, free_fn);
|
|
375
|
+
Ptr lengths_g(lengths, free_fn);
|
|
376
|
+
|
|
377
|
+
result.reserve(count);
|
|
378
|
+
for (size_t i = 0; i < count; ++i) {
|
|
379
|
+
auto sz = *lengths;
|
|
380
|
+
result.emplace_back(buffer, sz);
|
|
381
|
+
buffer += sz;
|
|
382
|
+
++lengths;
|
|
383
|
+
}
|
|
384
|
+
return result;
|
|
385
|
+
}
|
|
386
|
+
|
|
387
|
+
inline std::vector<Value> GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) {
|
|
388
|
+
std::vector<Value> result;
|
|
389
|
+
size_t owned = 0;
|
|
390
|
+
size_t output_count = 0;
|
|
391
|
+
// Lambda to release the buffer when no longer needed and
|
|
392
|
+
// make sure that we destroy all instances on exception
|
|
393
|
+
auto free_fn = [&owned, &output_count, allocator](OrtValue** buffer) {
|
|
394
|
+
if (buffer) {
|
|
395
|
+
while (owned < output_count) {
|
|
396
|
+
auto* p = buffer + owned++;
|
|
397
|
+
GetApi().ReleaseValue(*p);
|
|
398
|
+
}
|
|
399
|
+
allocator->Free(allocator, buffer);
|
|
400
|
+
}
|
|
401
|
+
};
|
|
402
|
+
using Ptr = std::unique_ptr<OrtValue*, decltype(free_fn)>;
|
|
403
|
+
|
|
404
|
+
OrtValue** output_buffer = nullptr;
|
|
405
|
+
ThrowOnError(GetApi().GetBoundOutputValues(binding, allocator, &output_buffer, &output_count));
|
|
406
|
+
if (output_count == 0) {
|
|
407
|
+
return result;
|
|
408
|
+
}
|
|
409
|
+
|
|
410
|
+
Ptr buffer_g(output_buffer, free_fn);
|
|
411
|
+
|
|
412
|
+
result.reserve(output_count);
|
|
413
|
+
for (size_t i = 0; i < output_count; ++i) {
|
|
414
|
+
result.emplace_back(output_buffer[i]);
|
|
415
|
+
++owned;
|
|
416
|
+
}
|
|
417
|
+
return result;
|
|
418
|
+
}
|
|
419
|
+
|
|
420
|
+
} // namespace binding_utils
|
|
421
|
+
} // namespace detail
|
|
422
|
+
|
|
423
|
+
inline IoBinding::IoBinding(Session& session) {
|
|
424
|
+
ThrowOnError(GetApi().CreateIoBinding(session, &this->p_));
|
|
425
|
+
}
|
|
426
|
+
|
|
427
|
+
inline ArenaCfg::ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk) {
|
|
428
|
+
ThrowOnError(GetApi().CreateArenaCfg(max_mem, arena_extend_strategy, initial_chunk_size_bytes, max_dead_bytes_per_chunk, &p_));
|
|
429
|
+
}
|
|
430
|
+
|
|
431
|
+
// Each method below forwards to the same-named OrtApi function, throws on
// failure, and returns *this so calls can be chained.

inline ThreadingOptions::ThreadingOptions() {
  const auto& api = GetApi();
  ThrowOnError(api.CreateThreadingOptions(&p_));
}

inline ThreadingOptions& ThreadingOptions::SetGlobalIntraOpNumThreads(int intra_op_num_threads) {
  const auto& api = GetApi();
  ThrowOnError(api.SetGlobalIntraOpNumThreads(p_, intra_op_num_threads));
  return *this;
}

inline ThreadingOptions& ThreadingOptions::SetGlobalInterOpNumThreads(int inter_op_num_threads) {
  const auto& api = GetApi();
  ThrowOnError(api.SetGlobalInterOpNumThreads(p_, inter_op_num_threads));
  return *this;
}

inline ThreadingOptions& ThreadingOptions::SetGlobalSpinControl(int allow_spinning) {
  const auto& api = GetApi();
  ThrowOnError(api.SetGlobalSpinControl(p_, allow_spinning));
  return *this;
}

inline ThreadingOptions& ThreadingOptions::SetGlobalDenormalAsZero() {
  const auto& api = GetApi();
  ThrowOnError(api.SetGlobalDenormalAsZero(p_));
  return *this;
}

inline ThreadingOptions& ThreadingOptions::SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
  const auto& api = GetApi();
  ThrowOnError(api.SetGlobalCustomCreateThreadFn(p_, ort_custom_create_thread_fn));
  return *this;
}

inline ThreadingOptions& ThreadingOptions::SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options) {
  const auto& api = GetApi();
  ThrowOnError(api.SetGlobalCustomThreadCreationOptions(p_, ort_custom_thread_creation_options));
  return *this;
}

inline ThreadingOptions& ThreadingOptions::SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) {
  const auto& api = GetApi();
  ThrowOnError(api.SetGlobalCustomJoinThreadFn(p_, ort_custom_join_thread_fn));
  return *this;
}
|
|
469
|
+
|
|
470
|
+
inline Env::Env(OrtLoggingLevel logging_level, _In_ const char* logid) {
|
|
471
|
+
ThrowOnError(GetApi().CreateEnv(logging_level, logid, &p_));
|
|
472
|
+
if (strcmp(logid, "onnxruntime-node") == 0) {
|
|
473
|
+
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
|
|
474
|
+
} else {
|
|
475
|
+
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
|
|
476
|
+
}
|
|
477
|
+
}
|
|
478
|
+
|
|
479
|
+
inline Env::Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param) {
|
|
480
|
+
ThrowOnError(GetApi().CreateEnvWithCustomLogger(logging_function, logger_param, logging_level, logid, &p_));
|
|
481
|
+
if (strcmp(logid, "onnxruntime-node") == 0) {
|
|
482
|
+
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
|
|
483
|
+
} else {
|
|
484
|
+
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
|
|
485
|
+
}
|
|
486
|
+
}
|
|
487
|
+
|
|
488
|
+
inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level, _In_ const char* logid) {
|
|
489
|
+
ThrowOnError(GetApi().CreateEnvWithGlobalThreadPools(logging_level, logid, tp_options, &p_));
|
|
490
|
+
if (strcmp(logid, "onnxruntime-node") == 0) {
|
|
491
|
+
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
|
|
492
|
+
} else {
|
|
493
|
+
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
|
|
494
|
+
}
|
|
495
|
+
}
|
|
496
|
+
|
|
497
|
+
inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param,
|
|
498
|
+
OrtLoggingLevel logging_level, _In_ const char* logid) {
|
|
499
|
+
ThrowOnError(GetApi().CreateEnvWithCustomLoggerAndGlobalThreadPools(logging_function, logger_param, logging_level, logid, tp_options, &p_));
|
|
500
|
+
if (strcmp(logid, "onnxruntime-node") == 0) {
|
|
501
|
+
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
|
|
502
|
+
} else {
|
|
503
|
+
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
|
|
504
|
+
}
|
|
505
|
+
}
|
|
506
|
+
|
|
507
|
+
// The Env methods below forward to the same-named OrtApi function, throw on
// failure, and return *this for chaining.

inline Env& Env::EnableTelemetryEvents() {
  const auto& api = GetApi();
  ThrowOnError(api.EnableTelemetryEvents(p_));
  return *this;
}

inline Env& Env::DisableTelemetryEvents() {
  const auto& api = GetApi();
  ThrowOnError(api.DisableTelemetryEvents(p_));
  return *this;
}

inline Env& Env::UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level) {
  const auto& api = GetApi();
  ThrowOnError(api.UpdateEnvWithCustomLogLevel(p_, log_severity_level));
  return *this;
}

inline Env& Env::CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg) {
  const auto& api = GetApi();
  ThrowOnError(api.CreateAndRegisterAllocator(p_, mem_info, arena_cfg));
  return *this;
}

// Flattens `options` into parallel key/value C-string arrays for the C API.
// The pointers reference strings owned by `options`, which outlives the call.
inline Env& Env::CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo* mem_info, const std::unordered_map<std::string, std::string>& options, const OrtArenaCfg* arena_cfg) {
  std::vector<const char*> option_keys;
  std::vector<const char*> option_values;
  option_keys.reserve(options.size());
  option_values.reserve(options.size());
  for (const auto& kv : options) {
    option_keys.push_back(kv.first.c_str());
    option_values.push_back(kv.second.c_str());
  }
  ThrowOnError(GetApi().CreateAndRegisterAllocatorV2(p_, provider_type.c_str(), mem_info, arena_cfg,
                                                     option_keys.data(), option_values.data(), options.size()));
  return *this;
}
|
|
541
|
+
|
|
542
|
+
inline CustomOpDomain::CustomOpDomain(const char* domain) {
|
|
543
|
+
ThrowOnError(GetApi().CreateCustomOpDomain(domain, &p_));
|
|
544
|
+
}
|
|
545
|
+
|
|
546
|
+
inline void CustomOpDomain::Add(const OrtCustomOp* op) {
|
|
547
|
+
ThrowOnError(GetApi().CustomOpDomain_Add(p_, op));
|
|
548
|
+
}
|
|
549
|
+
|
|
550
|
+
inline RunOptions::RunOptions() {
|
|
551
|
+
ThrowOnError(GetApi().CreateRunOptions(&p_));
|
|
552
|
+
}
|
|
553
|
+
|
|
554
|
+
inline RunOptions& RunOptions::SetRunLogVerbosityLevel(int level) {
|
|
555
|
+
ThrowOnError(GetApi().RunOptionsSetRunLogVerbosityLevel(p_, level));
|
|
556
|
+
return *this;
|
|
557
|
+
}
|
|
558
|
+
|
|
559
|
+
inline RunOptions& RunOptions::SetRunLogSeverityLevel(int level) {
|
|
560
|
+
ThrowOnError(GetApi().RunOptionsSetRunLogSeverityLevel(p_, level));
|
|
561
|
+
return *this;
|
|
562
|
+
}
|
|
563
|
+
|
|
564
|
+
inline int RunOptions::GetRunLogVerbosityLevel() const {
|
|
565
|
+
int out;
|
|
566
|
+
ThrowOnError(GetApi().RunOptionsGetRunLogVerbosityLevel(p_, &out));
|
|
567
|
+
return out;
|
|
568
|
+
}
|
|
569
|
+
|
|
570
|
+
inline int RunOptions::GetRunLogSeverityLevel() const {
|
|
571
|
+
int out;
|
|
572
|
+
ThrowOnError(GetApi().RunOptionsGetRunLogSeverityLevel(p_, &out));
|
|
573
|
+
return out;
|
|
574
|
+
}
|
|
575
|
+
|
|
576
|
+
inline RunOptions& RunOptions::SetRunTag(const char* run_tag) {
|
|
577
|
+
ThrowOnError(GetApi().RunOptionsSetRunTag(p_, run_tag));
|
|
578
|
+
return *this;
|
|
579
|
+
}
|
|
580
|
+
|
|
581
|
+
inline const char* RunOptions::GetRunTag() const {
|
|
582
|
+
const char* out;
|
|
583
|
+
ThrowOnError(GetApi().RunOptionsGetRunTag(p_, &out));
|
|
584
|
+
return out;
|
|
585
|
+
}
|
|
586
|
+
|
|
587
|
+
inline RunOptions& RunOptions::AddConfigEntry(const char* config_key, const char* config_value) {
|
|
588
|
+
ThrowOnError(GetApi().AddRunConfigEntry(p_, config_key, config_value));
|
|
589
|
+
return *this;
|
|
590
|
+
}
|
|
591
|
+
|
|
592
|
+
inline RunOptions& RunOptions::SetTerminate() {
|
|
593
|
+
ThrowOnError(GetApi().RunOptionsSetTerminate(p_));
|
|
594
|
+
return *this;
|
|
595
|
+
}
|
|
596
|
+
|
|
597
|
+
inline RunOptions& RunOptions::UnsetTerminate() {
|
|
598
|
+
ThrowOnError(GetApi().RunOptionsUnsetTerminate(p_));
|
|
599
|
+
return *this;
|
|
600
|
+
}
|
|
601
|
+
|
|
602
|
+
namespace detail {
|
|
603
|
+
|
|
604
|
+
template <typename T>
|
|
605
|
+
inline Ort::SessionOptions ConstSessionOptionsImpl<T>::Clone() const {
|
|
606
|
+
OrtSessionOptions* out;
|
|
607
|
+
ThrowOnError(GetApi().CloneSessionOptions(this->p_, &out));
|
|
608
|
+
return SessionOptions{out};
|
|
609
|
+
}
|
|
610
|
+
|
|
611
|
+
template <typename T>
|
|
612
|
+
inline std::string ConstSessionOptionsImpl<T>::GetConfigEntry(const char* config_key) const {
|
|
613
|
+
size_t size = 0;
|
|
614
|
+
// Feed nullptr for the data buffer to query the true size of the string value
|
|
615
|
+
Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, nullptr, &size));
|
|
616
|
+
|
|
617
|
+
std::string out;
|
|
618
|
+
out.resize(size);
|
|
619
|
+
Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, &out[0], &size));
|
|
620
|
+
out.resize(size - 1); // remove the terminating character '\0'
|
|
621
|
+
|
|
622
|
+
return out;
|
|
623
|
+
}
|
|
624
|
+
|
|
625
|
+
template <typename T>
|
|
626
|
+
inline bool ConstSessionOptionsImpl<T>::HasConfigEntry(const char* config_key) const {
|
|
627
|
+
int out = 0;
|
|
628
|
+
Ort::ThrowOnError(GetApi().HasSessionConfigEntry(this->p_, config_key, &out));
|
|
629
|
+
return static_cast<bool>(out);
|
|
630
|
+
}
|
|
631
|
+
|
|
632
|
+
template <typename T>
|
|
633
|
+
inline std::string ConstSessionOptionsImpl<T>::GetConfigEntryOrDefault(const char* config_key, const std::string& def) {
|
|
634
|
+
if (!this->HasConfigEntry(config_key)) {
|
|
635
|
+
return def;
|
|
636
|
+
}
|
|
637
|
+
|
|
638
|
+
return this->GetConfigEntry(config_key);
|
|
639
|
+
}
|
|
640
|
+
|
|
641
|
+
template <typename T>
|
|
642
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetIntraOpNumThreads(int intra_op_num_threads) {
|
|
643
|
+
ThrowOnError(GetApi().SetIntraOpNumThreads(this->p_, intra_op_num_threads));
|
|
644
|
+
return *this;
|
|
645
|
+
}
|
|
646
|
+
|
|
647
|
+
template <typename T>
|
|
648
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetInterOpNumThreads(int inter_op_num_threads) {
|
|
649
|
+
ThrowOnError(GetApi().SetInterOpNumThreads(this->p_, inter_op_num_threads));
|
|
650
|
+
return *this;
|
|
651
|
+
}
|
|
652
|
+
|
|
653
|
+
template <typename T>
|
|
654
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) {
|
|
655
|
+
ThrowOnError(GetApi().SetSessionGraphOptimizationLevel(this->p_, graph_optimization_level));
|
|
656
|
+
return *this;
|
|
657
|
+
}
|
|
658
|
+
|
|
659
|
+
template <typename T>
|
|
660
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetDeterministicCompute(bool value) {
|
|
661
|
+
ThrowOnError(GetApi().SetDeterministicCompute(this->p_, value));
|
|
662
|
+
return *this;
|
|
663
|
+
}
|
|
664
|
+
|
|
665
|
+
template <typename T>
|
|
666
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_filepath) {
|
|
667
|
+
ThrowOnError(GetApi().SetOptimizedModelFilePath(this->p_, optimized_model_filepath));
|
|
668
|
+
return *this;
|
|
669
|
+
}
|
|
670
|
+
|
|
671
|
+
template <typename T>
|
|
672
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableProfiling(const ORTCHAR_T* profile_file_prefix) {
|
|
673
|
+
ThrowOnError(GetApi().EnableProfiling(this->p_, profile_file_prefix));
|
|
674
|
+
return *this;
|
|
675
|
+
}
|
|
676
|
+
|
|
677
|
+
template <typename T>
|
|
678
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisableProfiling() {
|
|
679
|
+
ThrowOnError(GetApi().DisableProfiling(this->p_));
|
|
680
|
+
return *this;
|
|
681
|
+
}
|
|
682
|
+
|
|
683
|
+
template <typename T>
|
|
684
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableOrtCustomOps() {
|
|
685
|
+
ThrowOnError(GetApi().EnableOrtCustomOps(this->p_));
|
|
686
|
+
return *this;
|
|
687
|
+
}
|
|
688
|
+
|
|
689
|
+
template <typename T>
|
|
690
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableMemPattern() {
|
|
691
|
+
ThrowOnError(GetApi().EnableMemPattern(this->p_));
|
|
692
|
+
return *this;
|
|
693
|
+
}
|
|
694
|
+
|
|
695
|
+
template <typename T>
|
|
696
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisableMemPattern() {
|
|
697
|
+
ThrowOnError(GetApi().DisableMemPattern(this->p_));
|
|
698
|
+
return *this;
|
|
699
|
+
}
|
|
700
|
+
|
|
701
|
+
template <typename T>
|
|
702
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableCpuMemArena() {
|
|
703
|
+
ThrowOnError(GetApi().EnableCpuMemArena(this->p_));
|
|
704
|
+
return *this;
|
|
705
|
+
}
|
|
706
|
+
|
|
707
|
+
template <typename T>
|
|
708
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisableCpuMemArena() {
|
|
709
|
+
ThrowOnError(GetApi().DisableCpuMemArena(this->p_));
|
|
710
|
+
return *this;
|
|
711
|
+
}
|
|
712
|
+
|
|
713
|
+
template <typename T>
|
|
714
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetExecutionMode(ExecutionMode execution_mode) {
|
|
715
|
+
ThrowOnError(GetApi().SetSessionExecutionMode(this->p_, execution_mode));
|
|
716
|
+
return *this;
|
|
717
|
+
}
|
|
718
|
+
|
|
719
|
+
template <typename T>
|
|
720
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetLogId(const char* logid) {
|
|
721
|
+
ThrowOnError(GetApi().SetSessionLogId(this->p_, logid));
|
|
722
|
+
return *this;
|
|
723
|
+
}
|
|
724
|
+
|
|
725
|
+
template <typename T>
|
|
726
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetLogSeverityLevel(int level) {
|
|
727
|
+
ThrowOnError(GetApi().SetSessionLogSeverityLevel(this->p_, level));
|
|
728
|
+
return *this;
|
|
729
|
+
}
|
|
730
|
+
|
|
731
|
+
template <typename T>
|
|
732
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::Add(OrtCustomOpDomain* custom_op_domain) {
|
|
733
|
+
ThrowOnError(GetApi().AddCustomOpDomain(this->p_, custom_op_domain));
|
|
734
|
+
return *this;
|
|
735
|
+
}
|
|
736
|
+
|
|
737
|
+
template <typename T>
|
|
738
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddConfigEntry(const char* config_key, const char* config_value) {
|
|
739
|
+
ThrowOnError(GetApi().AddSessionConfigEntry(this->p_, config_key, config_value));
|
|
740
|
+
return *this;
|
|
741
|
+
}
|
|
742
|
+
|
|
743
|
+
template <typename T>
|
|
744
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddInitializer(const char* name, const OrtValue* ort_val) {
|
|
745
|
+
ThrowOnError(GetApi().AddInitializer(this->p_, name, ort_val));
|
|
746
|
+
return *this;
|
|
747
|
+
}
|
|
748
|
+
|
|
749
|
+
template <typename T>
|
|
750
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisablePerSessionThreads() {
|
|
751
|
+
ThrowOnError(GetApi().DisablePerSessionThreads(this->p_));
|
|
752
|
+
return *this;
|
|
753
|
+
}
|
|
754
|
+
|
|
755
|
+
template <typename T>
|
|
756
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddExternalInitializers(const std::vector<std::string>& names,
|
|
757
|
+
const std::vector<Value>& ort_values) {
|
|
758
|
+
const size_t inputs_num = names.size();
|
|
759
|
+
if (inputs_num != ort_values.size()) {
|
|
760
|
+
ORT_CXX_API_THROW("Expecting names and ort_values to have the same length", ORT_INVALID_ARGUMENT);
|
|
761
|
+
}
|
|
762
|
+
std::vector<const char*> names_ptr;
|
|
763
|
+
std::vector<const OrtValue*> ort_values_ptrs;
|
|
764
|
+
names_ptr.reserve(inputs_num);
|
|
765
|
+
ort_values_ptrs.reserve(inputs_num);
|
|
766
|
+
for (size_t i = 0; i < inputs_num; ++i) {
|
|
767
|
+
names_ptr.push_back(names[i].c_str());
|
|
768
|
+
ort_values_ptrs.push_back(ort_values[i]);
|
|
769
|
+
}
|
|
770
|
+
ThrowOnError(GetApi().AddExternalInitializers(this->p_, names_ptr.data(), ort_values_ptrs.data(), inputs_num));
|
|
771
|
+
return *this;
|
|
772
|
+
}
|
|
773
|
+
|
|
774
|
+
template <typename T>
|
|
775
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options) {
|
|
776
|
+
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA(this->p_, &provider_options));
|
|
777
|
+
return *this;
|
|
778
|
+
}
|
|
779
|
+
|
|
780
|
+
template <typename T>
|
|
781
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options) {
|
|
782
|
+
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA_V2(this->p_, &provider_options));
|
|
783
|
+
return *this;
|
|
784
|
+
}
|
|
785
|
+
|
|
786
|
+
template <typename T>
|
|
787
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options) {
|
|
788
|
+
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_ROCM(this->p_, &provider_options));
|
|
789
|
+
return *this;
|
|
790
|
+
}
|
|
791
|
+
|
|
792
|
+
template <typename T>
|
|
793
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options) {
|
|
794
|
+
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT(this->p_, &provider_options));
|
|
795
|
+
return *this;
|
|
796
|
+
}
|
|
797
|
+
|
|
798
|
+
template <typename T>
|
|
799
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options) {
|
|
800
|
+
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT_V2(this->p_, &provider_options));
|
|
801
|
+
return *this;
|
|
802
|
+
}
|
|
803
|
+
|
|
804
|
+
template <typename T>
|
|
805
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options) {
|
|
806
|
+
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_MIGraphX(this->p_, &provider_options));
|
|
807
|
+
return *this;
|
|
808
|
+
}
|
|
809
|
+
|
|
810
|
+
template <typename T>
|
|
811
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options) {
|
|
812
|
+
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CANN(this->p_, &provider_options));
|
|
813
|
+
return *this;
|
|
814
|
+
}
|
|
815
|
+
|
|
816
|
+
template <typename T>
|
|
817
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options) {
|
|
818
|
+
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_Dnnl(this->p_, &provider_options));
|
|
819
|
+
return *this;
|
|
820
|
+
}
|
|
821
|
+
|
|
822
|
+
template <typename T>
|
|
823
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider(
|
|
824
|
+
const std::string& provider_name,
|
|
825
|
+
const std::unordered_map<std::string, std::string>& provider_options) {
|
|
826
|
+
auto num_entries = provider_options.size();
|
|
827
|
+
std::vector<const char*> keys, values;
|
|
828
|
+
if (num_entries > 0) {
|
|
829
|
+
keys.reserve(num_entries);
|
|
830
|
+
values.reserve(num_entries);
|
|
831
|
+
|
|
832
|
+
for (const auto& entry : provider_options) {
|
|
833
|
+
keys.push_back(entry.first.c_str());
|
|
834
|
+
values.push_back(entry.second.c_str());
|
|
835
|
+
}
|
|
836
|
+
}
|
|
837
|
+
|
|
838
|
+
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider(this->p_, provider_name.c_str(),
|
|
839
|
+
keys.data(), values.data(), num_entries));
|
|
840
|
+
|
|
841
|
+
return *this;
|
|
842
|
+
}
|
|
843
|
+
|
|
844
|
+
template <typename T>
|
|
845
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
|
|
846
|
+
ThrowOnError(GetApi().SessionOptionsSetCustomCreateThreadFn(this->p_, ort_custom_create_thread_fn));
|
|
847
|
+
return *this;
|
|
848
|
+
}
|
|
849
|
+
|
|
850
|
+
template <typename T>
|
|
851
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options) {
|
|
852
|
+
ThrowOnError(GetApi().SessionOptionsSetCustomThreadCreationOptions(this->p_, ort_custom_thread_creation_options));
|
|
853
|
+
return *this;
|
|
854
|
+
}
|
|
855
|
+
|
|
856
|
+
template <typename T>
|
|
857
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) {
|
|
858
|
+
ThrowOnError(GetApi().SessionOptionsSetCustomJoinThreadFn(this->p_, ort_custom_join_thread_fn));
|
|
859
|
+
return *this;
|
|
860
|
+
}
|
|
861
|
+
|
|
862
|
+
template <typename T>
|
|
863
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options) {
|
|
864
|
+
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_OpenVINO(this->p_, &provider_options));
|
|
865
|
+
return *this;
|
|
866
|
+
}
|
|
867
|
+
|
|
868
|
+
template <typename T>
|
|
869
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_OpenVINO_V2(const std::unordered_map<std::string, std::string>& provider_options) {
|
|
870
|
+
auto num_entries = provider_options.size();
|
|
871
|
+
std::vector<const char*> keys, values;
|
|
872
|
+
if (num_entries > 0) {
|
|
873
|
+
keys.reserve(num_entries);
|
|
874
|
+
values.reserve(num_entries);
|
|
875
|
+
|
|
876
|
+
for (const auto& entry : provider_options) {
|
|
877
|
+
keys.push_back(entry.first.c_str());
|
|
878
|
+
values.push_back(entry.second.c_str());
|
|
879
|
+
}
|
|
880
|
+
}
|
|
881
|
+
|
|
882
|
+
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_OpenVINO_V2(this->p_,
|
|
883
|
+
keys.data(), values.data(), num_entries));
|
|
884
|
+
|
|
885
|
+
return *this;
|
|
886
|
+
}
|
|
887
|
+
|
|
888
|
+
template <typename T>
|
|
889
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::RegisterCustomOpsLibrary(const ORTCHAR_T* library_name,
|
|
890
|
+
const CustomOpConfigs& custom_op_configs) {
|
|
891
|
+
// Add custom op config entries before registering the custom op library. Otherwise, the config entries _may_ be ignored by
|
|
892
|
+
// the custom op library.
|
|
893
|
+
for (const auto& config_iter : custom_op_configs.GetFlattenedConfigs()) {
|
|
894
|
+
AddConfigEntry(config_iter.first.c_str(), config_iter.second.c_str());
|
|
895
|
+
}
|
|
896
|
+
|
|
897
|
+
ThrowOnError(GetApi().RegisterCustomOpsLibrary_V2(this->p_, library_name));
|
|
898
|
+
return *this;
|
|
899
|
+
}
|
|
900
|
+
|
|
901
|
+
template <typename T>
|
|
902
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::RegisterCustomOpsUsingFunction(const char* registration_function_name) {
|
|
903
|
+
ThrowOnError(GetApi().RegisterCustomOpsUsingFunction(this->p_, registration_function_name));
|
|
904
|
+
return *this;
|
|
905
|
+
}
|
|
906
|
+
|
|
907
|
+
/// Session
|
|
908
|
+
template <typename T>
|
|
909
|
+
inline size_t ConstSessionImpl<T>::GetInputCount() const {
|
|
910
|
+
size_t out;
|
|
911
|
+
ThrowOnError(GetApi().SessionGetInputCount(this->p_, &out));
|
|
912
|
+
return out;
|
|
913
|
+
}
|
|
914
|
+
|
|
915
|
+
template <typename T>
|
|
916
|
+
inline size_t ConstSessionImpl<T>::GetOutputCount() const {
|
|
917
|
+
size_t out;
|
|
918
|
+
ThrowOnError(GetApi().SessionGetOutputCount(this->p_, &out));
|
|
919
|
+
return out;
|
|
920
|
+
}
|
|
921
|
+
|
|
922
|
+
template <typename T>
|
|
923
|
+
inline size_t ConstSessionImpl<T>::GetOverridableInitializerCount() const {
|
|
924
|
+
size_t out;
|
|
925
|
+
ThrowOnError(GetApi().SessionGetOverridableInitializerCount(this->p_, &out));
|
|
926
|
+
return out;
|
|
927
|
+
}
|
|
928
|
+
|
|
929
|
+
template <typename T>
|
|
930
|
+
inline AllocatedStringPtr ConstSessionImpl<T>::GetInputNameAllocated(size_t index, OrtAllocator* allocator) const {
|
|
931
|
+
char* out;
|
|
932
|
+
ThrowOnError(GetApi().SessionGetInputName(this->p_, index, allocator, &out));
|
|
933
|
+
return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
|
|
934
|
+
}
|
|
935
|
+
|
|
936
|
+
template <typename T>
|
|
937
|
+
inline AllocatedStringPtr ConstSessionImpl<T>::GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const {
|
|
938
|
+
char* out;
|
|
939
|
+
ThrowOnError(GetApi().SessionGetOutputName(this->p_, index, allocator, &out));
|
|
940
|
+
return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
|
|
941
|
+
}
|
|
942
|
+
|
|
943
|
+
template <typename T>
|
|
944
|
+
inline AllocatedStringPtr ConstSessionImpl<T>::GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const {
|
|
945
|
+
char* out;
|
|
946
|
+
ThrowOnError(GetApi().SessionGetOverridableInitializerName(this->p_, index, allocator, &out));
|
|
947
|
+
return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
|
|
948
|
+
}
|
|
949
|
+
|
|
950
|
+
template <typename T>
|
|
951
|
+
inline uint64_t ConstSessionImpl<T>::GetProfilingStartTimeNs() const {
|
|
952
|
+
uint64_t out;
|
|
953
|
+
ThrowOnError(GetApi().SessionGetProfilingStartTimeNs(this->p_, &out));
|
|
954
|
+
return out;
|
|
955
|
+
}
|
|
956
|
+
|
|
957
|
+
template <typename T>
|
|
958
|
+
inline ModelMetadata ConstSessionImpl<T>::GetModelMetadata() const {
|
|
959
|
+
OrtModelMetadata* out;
|
|
960
|
+
ThrowOnError(GetApi().SessionGetModelMetadata(this->p_, &out));
|
|
961
|
+
return ModelMetadata{out};
|
|
962
|
+
}
|
|
963
|
+
|
|
964
|
+
template <typename T>
|
|
965
|
+
inline TypeInfo ConstSessionImpl<T>::GetInputTypeInfo(size_t index) const {
|
|
966
|
+
OrtTypeInfo* out;
|
|
967
|
+
ThrowOnError(GetApi().SessionGetInputTypeInfo(this->p_, index, &out));
|
|
968
|
+
return TypeInfo{out};
|
|
969
|
+
}
|
|
970
|
+
|
|
971
|
+
template <typename T>
|
|
972
|
+
inline TypeInfo ConstSessionImpl<T>::GetOutputTypeInfo(size_t index) const {
|
|
973
|
+
OrtTypeInfo* out;
|
|
974
|
+
ThrowOnError(GetApi().SessionGetOutputTypeInfo(this->p_, index, &out));
|
|
975
|
+
return TypeInfo{out};
|
|
976
|
+
}
|
|
977
|
+
|
|
978
|
+
template <typename T>
|
|
979
|
+
inline TypeInfo ConstSessionImpl<T>::GetOverridableInitializerTypeInfo(size_t index) const {
|
|
980
|
+
OrtTypeInfo* out;
|
|
981
|
+
ThrowOnError(GetApi().SessionGetOverridableInitializerTypeInfo(this->p_, index, &out));
|
|
982
|
+
return TypeInfo{out};
|
|
983
|
+
}
|
|
984
|
+
|
|
985
|
+
template <typename T>
|
|
986
|
+
inline std::vector<Value> SessionImpl<T>::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
|
|
987
|
+
const char* const* output_names, size_t output_count) {
|
|
988
|
+
std::vector<Value> output_values;
|
|
989
|
+
output_values.reserve(output_count);
|
|
990
|
+
for (size_t i = 0; i < output_count; i++)
|
|
991
|
+
output_values.emplace_back(nullptr);
|
|
992
|
+
Run(run_options, input_names, input_values, input_count, output_names, output_values.data(), output_count);
|
|
993
|
+
return output_values;
|
|
994
|
+
}
|
|
995
|
+
|
|
996
|
+
template <typename T>
|
|
997
|
+
inline void SessionImpl<T>::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
|
|
998
|
+
const char* const* output_names, Value* output_values, size_t output_count) {
|
|
999
|
+
static_assert(sizeof(Value) == sizeof(OrtValue*), "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
|
|
1000
|
+
auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
|
|
1001
|
+
auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
|
|
1002
|
+
ThrowOnError(GetApi().Run(this->p_, run_options, input_names, ort_input_values, input_count, output_names, output_count, ort_output_values));
|
|
1003
|
+
}
|
|
1004
|
+
|
|
1005
|
+
template <typename T>
|
|
1006
|
+
inline void SessionImpl<T>::Run(const RunOptions& run_options, const IoBinding& io_binding) {
|
|
1007
|
+
ThrowOnError(GetApi().RunWithBinding(this->p_, run_options, io_binding));
|
|
1008
|
+
}
|
|
1009
|
+
|
|
1010
|
+
template <typename T>
|
|
1011
|
+
inline void SessionImpl<T>::RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
|
|
1012
|
+
const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data) {
|
|
1013
|
+
auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
|
|
1014
|
+
auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
|
|
1015
|
+
ThrowOnError(GetApi().RunAsync(this->p_, run_options, input_names,
|
|
1016
|
+
ort_input_values, input_count, output_names, output_count,
|
|
1017
|
+
ort_output_values, callback, user_data));
|
|
1018
|
+
}
|
|
1019
|
+
|
|
1020
|
+
template <typename T>
|
|
1021
|
+
inline AllocatedStringPtr SessionImpl<T>::EndProfilingAllocated(OrtAllocator* allocator) {
|
|
1022
|
+
char* out = nullptr;
|
|
1023
|
+
ThrowOnError(GetApi().SessionEndProfiling(this->p_, allocator, &out));
|
|
1024
|
+
return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
|
|
1025
|
+
}
|
|
1026
|
+
|
|
1027
|
+
} // namespace detail
|
|
1028
|
+
|
|
1029
|
+
/// Creates a new, owned OrtSessionOptions via OrtApi::CreateSessionOptions.
/// Failures are reported via ThrowOnError.
inline SessionOptions::SessionOptions() {
  auto* const status = GetApi().CreateSessionOptions(&this->p_);
  ThrowOnError(status);
}
|
|
1032
|
+
|
|
1033
|
+
/// CustomOpConfigs
|
|
1034
|
+
inline std::string detail::MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config) {
|
|
1035
|
+
std::string config_key = "custom_op.";
|
|
1036
|
+
|
|
1037
|
+
config_key += custom_op_name;
|
|
1038
|
+
config_key += ".";
|
|
1039
|
+
config_key += config;
|
|
1040
|
+
|
|
1041
|
+
return config_key;
|
|
1042
|
+
}
|
|
1043
|
+
|
|
1044
|
+
/// Records @p config_value under the flattened key built by detail::MakeCustomOpConfigEntryKey,
/// overwriting any previous value for that key.
/// @return *this, for call chaining.
inline CustomOpConfigs& CustomOpConfigs::AddConfig(const char* custom_op_name, const char* config_key, const char* config_value) {
  flat_configs_[detail::MakeCustomOpConfigEntryKey(custom_op_name, config_key)] = config_value;
  return *this;
}
|
|
1049
|
+
|
|
1050
|
+
inline const std::unordered_map<std::string, std::string>& CustomOpConfigs::GetFlattenedConfigs() const {
|
|
1051
|
+
return flat_configs_;
|
|
1052
|
+
}
|
|
1053
|
+
|
|
1054
|
+
/// Creates a session from the model file at @p model_path via OrtApi::CreateSession.
/// Failures are reported via ThrowOnError.
inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) {
  auto* const status = GetApi().CreateSession(env, model_path, options, &this->p_);
  ThrowOnError(status);
}
|
|
1057
|
+
|
|
1058
|
+
/// Creates a session from the model file at @p model_path, sharing @p prepacked_weights_container,
/// via OrtApi::CreateSessionWithPrepackedWeightsContainer. Failures are reported via ThrowOnError.
inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options,
                        OrtPrepackedWeightsContainer* prepacked_weights_container) {
  auto* const status = GetApi().CreateSessionWithPrepackedWeightsContainer(env, model_path, options,
                                                                           prepacked_weights_container, &this->p_);
  ThrowOnError(status);
}
|
|
1062
|
+
|
|
1063
|
+
/// Creates a session from an in-memory model buffer via OrtApi::CreateSessionFromArray.
/// Failures are reported via ThrowOnError.
inline Session::Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options) {
  auto* const status = GetApi().CreateSessionFromArray(env, model_data, model_data_length, options, &this->p_);
  ThrowOnError(status);
}
|
|
1066
|
+
|
|
1067
|
+
/// Creates a session from an in-memory model buffer, sharing @p prepacked_weights_container, via
/// OrtApi::CreateSessionFromArrayWithPrepackedWeightsContainer. Failures are reported via ThrowOnError.
inline Session::Session(const Env& env, const void* model_data, size_t model_data_length,
                        const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container) {
  auto* const status = GetApi().CreateSessionFromArrayWithPrepackedWeightsContainer(
      env, model_data, model_data_length, options, prepacked_weights_container, &this->p_);
  ThrowOnError(status);
}
|
|
1072
|
+
|
|
1073
|
+
/// Returns the producer name via OrtApi::ModelMetadataGetProducerName, allocated by @p allocator and
/// owned by the returned AllocatedStringPtr. Failures are reported via ThrowOnError.
inline AllocatedStringPtr ModelMetadata::GetProducerNameAllocated(OrtAllocator* allocator) const {
  char* producer_name = nullptr;
  ThrowOnError(GetApi().ModelMetadataGetProducerName(p_, allocator, &producer_name));
  return AllocatedStringPtr(producer_name, detail::AllocatedFree(allocator));
}
|
|
1078
|
+
|
|
1079
|
+
/// Returns the graph name via OrtApi::ModelMetadataGetGraphName, allocated by @p allocator and owned
/// by the returned AllocatedStringPtr. Failures are reported via ThrowOnError.
inline AllocatedStringPtr ModelMetadata::GetGraphNameAllocated(OrtAllocator* allocator) const {
  char* graph_name = nullptr;
  ThrowOnError(GetApi().ModelMetadataGetGraphName(p_, allocator, &graph_name));
  return AllocatedStringPtr(graph_name, detail::AllocatedFree(allocator));
}
|
|
1084
|
+
|
|
1085
|
+
/// Returns the domain via OrtApi::ModelMetadataGetDomain, allocated by @p allocator and owned by the
/// returned AllocatedStringPtr. Failures are reported via ThrowOnError.
inline AllocatedStringPtr ModelMetadata::GetDomainAllocated(OrtAllocator* allocator) const {
  char* domain = nullptr;
  ThrowOnError(GetApi().ModelMetadataGetDomain(p_, allocator, &domain));
  return AllocatedStringPtr(domain, detail::AllocatedFree(allocator));
}
|
|
1090
|
+
|
|
1091
|
+
/// Returns the description via OrtApi::ModelMetadataGetDescription, allocated by @p allocator and
/// owned by the returned AllocatedStringPtr. Failures are reported via ThrowOnError.
inline AllocatedStringPtr Ort::ModelMetadata::GetDescriptionAllocated(OrtAllocator* allocator) const {
  char* description = nullptr;
  ThrowOnError(GetApi().ModelMetadataGetDescription(p_, allocator, &description));
  return AllocatedStringPtr(description, detail::AllocatedFree(allocator));
}
|
|
1096
|
+
|
|
1097
|
+
/// Returns the graph description via OrtApi::ModelMetadataGetGraphDescription, allocated by
/// @p allocator and owned by the returned AllocatedStringPtr. Failures are reported via ThrowOnError.
inline AllocatedStringPtr ModelMetadata::GetGraphDescriptionAllocated(OrtAllocator* allocator) const {
  char* graph_description = nullptr;
  ThrowOnError(GetApi().ModelMetadataGetGraphDescription(p_, allocator, &graph_description));
  return AllocatedStringPtr(graph_description, detail::AllocatedFree(allocator));
}
|
|
1102
|
+
|
|
1103
|
+
/// Looks up @p key in the custom metadata map via OrtApi::ModelMetadataLookupCustomMetadataMap.
/// The result is allocated by @p allocator and owned by the returned AllocatedStringPtr.
/// NOTE(review): the result may presumably be null when the key is absent -- confirm in the C API docs.
/// Failures are reported via ThrowOnError.
inline AllocatedStringPtr ModelMetadata::LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const {
  char* value = nullptr;
  ThrowOnError(GetApi().ModelMetadataLookupCustomMetadataMap(p_, allocator, key, &value));
  return AllocatedStringPtr(value, detail::AllocatedFree(allocator));
}
|
|
1108
|
+
|
|
1109
|
+
/// Returns every key of the custom metadata map, each owned by an AllocatedStringPtr that frees it
/// with @p allocator. Wraps OrtApi::ModelMetadataGetCustomMetadataMapKeys.
/// Exception safety: the pointer array is always freed on exit, and if reserve() throws, every
/// key string is freed as well. Failures are reported via ThrowOnError.
inline std::vector<AllocatedStringPtr> ModelMetadata::GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const {
  auto deletor = detail::AllocatedFree(allocator);
  std::vector<AllocatedStringPtr> keys;

  char** raw_keys = nullptr;
  int64_t key_count = 0;
  ThrowOnError(GetApi().ModelMetadataGetCustomMetadataMapKeys(p_, allocator, &raw_keys, &key_count));
  if (key_count <= 0) {
    return keys;
  }

  // The array of pointers itself is released when this function exits, on every path.
  std::unique_ptr<void, decltype(deletor)> array_guard(raw_keys, deletor);

  // Until each string's ownership moves into `keys`, free all of them if reserve() throws.
  auto free_strings = [&deletor, key_count](char** strings) {
    for (int64_t i = 0; i < key_count; ++i) deletor(strings[i]);
  };
  std::unique_ptr<char*, decltype(free_strings)> strings_guard(raw_keys, free_strings);
  keys.reserve(static_cast<size_t>(key_count));
  strings_guard.release();

  // Capacity is already reserved, so these insertions do not reallocate.
  for (int64_t i = 0; i < key_count; ++i) {
    keys.emplace_back(raw_keys[i], deletor);
  }
  return keys;
}
|
|
1133
|
+
|
|
1134
|
+
/// Returns the version reported by OrtApi::ModelMetadataGetVersion.
/// Failures are reported via ThrowOnError.
inline int64_t ModelMetadata::GetVersion() const {
  int64_t version = 0;
  ThrowOnError(GetApi().ModelMetadataGetVersion(p_, &version));
  return version;
}
|
|
1139
|
+
|
|
1140
|
+
namespace detail {
|
|
1141
|
+
|
|
1142
|
+
template <typename T>
|
|
1143
|
+
inline ONNXTensorElementDataType TensorTypeAndShapeInfoImpl<T>::GetElementType() const {
|
|
1144
|
+
ONNXTensorElementDataType out;
|
|
1145
|
+
ThrowOnError(GetApi().GetTensorElementType(this->p_, &out));
|
|
1146
|
+
return out;
|
|
1147
|
+
}
|
|
1148
|
+
|
|
1149
|
+
template <typename T>
|
|
1150
|
+
inline size_t TensorTypeAndShapeInfoImpl<T>::GetElementCount() const {
|
|
1151
|
+
size_t out;
|
|
1152
|
+
ThrowOnError(GetApi().GetTensorShapeElementCount(this->p_, &out));
|
|
1153
|
+
return static_cast<size_t>(out);
|
|
1154
|
+
}
|
|
1155
|
+
|
|
1156
|
+
template <typename T>
|
|
1157
|
+
inline size_t TensorTypeAndShapeInfoImpl<T>::GetDimensionsCount() const {
|
|
1158
|
+
size_t out;
|
|
1159
|
+
ThrowOnError(GetApi().GetDimensionsCount(this->p_, &out));
|
|
1160
|
+
return out;
|
|
1161
|
+
}
|
|
1162
|
+
|
|
1163
|
+
template <typename T>
|
|
1164
|
+
inline void TensorTypeAndShapeInfoImpl<T>::GetDimensions(int64_t* values, size_t values_count) const {
|
|
1165
|
+
ThrowOnError(GetApi().GetDimensions(this->p_, values, values_count));
|
|
1166
|
+
}
|
|
1167
|
+
|
|
1168
|
+
template <typename T>
|
|
1169
|
+
inline void TensorTypeAndShapeInfoImpl<T>::GetSymbolicDimensions(const char** values, size_t values_count) const {
|
|
1170
|
+
ThrowOnError(GetApi().GetSymbolicDimensions(this->p_, values, values_count));
|
|
1171
|
+
}
|
|
1172
|
+
|
|
1173
|
+
template <typename T>
|
|
1174
|
+
inline std::vector<int64_t> TensorTypeAndShapeInfoImpl<T>::GetShape() const {
|
|
1175
|
+
std::vector<int64_t> out(GetDimensionsCount(), 0);
|
|
1176
|
+
ThrowOnError(GetApi().GetDimensions(this->p_, out.data(), out.size()));
|
|
1177
|
+
return out;
|
|
1178
|
+
}
|
|
1179
|
+
|
|
1180
|
+
template <typename T>
|
|
1181
|
+
inline ConstTensorTypeAndShapeInfo TypeInfoImpl<T>::GetTensorTypeAndShapeInfo() const {
|
|
1182
|
+
const OrtTensorTypeAndShapeInfo* out;
|
|
1183
|
+
ThrowOnError(GetApi().CastTypeInfoToTensorInfo(this->p_, &out));
|
|
1184
|
+
return ConstTensorTypeAndShapeInfo{out};
|
|
1185
|
+
}
|
|
1186
|
+
|
|
1187
|
+
template <typename T>
|
|
1188
|
+
inline ConstSequenceTypeInfo TypeInfoImpl<T>::GetSequenceTypeInfo() const {
|
|
1189
|
+
const OrtSequenceTypeInfo* out;
|
|
1190
|
+
ThrowOnError(GetApi().CastTypeInfoToSequenceTypeInfo(this->p_, &out));
|
|
1191
|
+
return ConstSequenceTypeInfo{out};
|
|
1192
|
+
}
|
|
1193
|
+
|
|
1194
|
+
template <typename T>
|
|
1195
|
+
inline ConstMapTypeInfo TypeInfoImpl<T>::GetMapTypeInfo() const {
|
|
1196
|
+
const OrtMapTypeInfo* out;
|
|
1197
|
+
ThrowOnError(GetApi().CastTypeInfoToMapTypeInfo(this->p_, &out));
|
|
1198
|
+
return ConstMapTypeInfo{out};
|
|
1199
|
+
}
|
|
1200
|
+
|
|
1201
|
+
template <typename T>
|
|
1202
|
+
inline ONNXType TypeInfoImpl<T>::GetONNXType() const {
|
|
1203
|
+
ONNXType out;
|
|
1204
|
+
ThrowOnError(GetApi().GetOnnxTypeFromTypeInfo(this->p_, &out));
|
|
1205
|
+
return out;
|
|
1206
|
+
}
|
|
1207
|
+
|
|
1208
|
+
template <typename T>
|
|
1209
|
+
inline TypeInfo SequenceTypeInfoImpl<T>::GetSequenceElementType() const {
|
|
1210
|
+
OrtTypeInfo* output;
|
|
1211
|
+
ThrowOnError(GetApi().GetSequenceElementType(this->p_, &output));
|
|
1212
|
+
return TypeInfo{output};
|
|
1213
|
+
}
|
|
1214
|
+
|
|
1215
|
+
template <typename T>
|
|
1216
|
+
inline TypeInfo OptionalTypeInfoImpl<T>::GetOptionalElementType() const {
|
|
1217
|
+
OrtTypeInfo* info;
|
|
1218
|
+
ThrowOnError(GetApi().GetOptionalContainedTypeInfo(this->p_, &info));
|
|
1219
|
+
return TypeInfo{info};
|
|
1220
|
+
}
|
|
1221
|
+
|
|
1222
|
+
template <typename T>
|
|
1223
|
+
inline ONNXTensorElementDataType MapTypeInfoImpl<T>::GetMapKeyType() const {
|
|
1224
|
+
ONNXTensorElementDataType out;
|
|
1225
|
+
ThrowOnError(GetApi().GetMapKeyType(this->p_, &out));
|
|
1226
|
+
return out;
|
|
1227
|
+
}
|
|
1228
|
+
|
|
1229
|
+
template <typename T>
|
|
1230
|
+
inline TypeInfo MapTypeInfoImpl<T>::GetMapValueType() const {
|
|
1231
|
+
OrtTypeInfo* output;
|
|
1232
|
+
ThrowOnError(GetApi().GetMapValueType(this->p_, &output));
|
|
1233
|
+
return TypeInfo{output};
|
|
1234
|
+
}
|
|
1235
|
+
|
|
1236
|
+
template <typename T>
|
|
1237
|
+
inline ConstOptionalTypeInfo TypeInfoImpl<T>::GetOptionalTypeInfo() const {
|
|
1238
|
+
const OrtOptionalTypeInfo* info;
|
|
1239
|
+
ThrowOnError(GetApi().CastTypeInfoToOptionalTypeInfo(this->p_, &info));
|
|
1240
|
+
return ConstOptionalTypeInfo{info};
|
|
1241
|
+
}
|
|
1242
|
+
|
|
1243
|
+
} // namespace detail
|
|
1244
|
+
|
|
1245
|
+
namespace detail {
|
|
1246
|
+
|
|
1247
|
+
template <typename T>
|
|
1248
|
+
template <typename R>
|
|
1249
|
+
inline void ConstValueImpl<T>::GetOpaqueData(const char* domain, const char* type_name, R& out) const {
|
|
1250
|
+
ThrowOnError(GetApi().GetOpaqueValue(domain, type_name, this->p_, &out, sizeof(R)));
|
|
1251
|
+
}
|
|
1252
|
+
|
|
1253
|
+
template <typename T>
|
|
1254
|
+
inline bool ConstValueImpl<T>::IsTensor() const {
|
|
1255
|
+
int out;
|
|
1256
|
+
ThrowOnError(GetApi().IsTensor(this->p_, &out));
|
|
1257
|
+
return out != 0;
|
|
1258
|
+
}
|
|
1259
|
+
|
|
1260
|
+
template <typename T>
|
|
1261
|
+
inline bool ConstValueImpl<T>::HasValue() const {
|
|
1262
|
+
int out;
|
|
1263
|
+
ThrowOnError(GetApi().HasValue(this->p_, &out));
|
|
1264
|
+
return out != 0;
|
|
1265
|
+
}
|
|
1266
|
+
|
|
1267
|
+
template <typename T>
|
|
1268
|
+
inline size_t ConstValueImpl<T>::GetCount() const {
|
|
1269
|
+
size_t out;
|
|
1270
|
+
ThrowOnError(GetApi().GetValueCount(this->p_, &out));
|
|
1271
|
+
return out;
|
|
1272
|
+
}
|
|
1273
|
+
|
|
1274
|
+
template <typename T>
|
|
1275
|
+
inline Value ConstValueImpl<T>::GetValue(int index, OrtAllocator* allocator) const {
|
|
1276
|
+
OrtValue* out;
|
|
1277
|
+
ThrowOnError(GetApi().GetValue(this->p_, index, allocator, &out));
|
|
1278
|
+
return Value{out};
|
|
1279
|
+
}
|
|
1280
|
+
|
|
1281
|
+
template <typename T>
|
|
1282
|
+
inline size_t ConstValueImpl<T>::GetStringTensorDataLength() const {
|
|
1283
|
+
size_t out;
|
|
1284
|
+
ThrowOnError(GetApi().GetStringTensorDataLength(this->p_, &out));
|
|
1285
|
+
return out;
|
|
1286
|
+
}
|
|
1287
|
+
|
|
1288
|
+
template <typename T>
|
|
1289
|
+
inline size_t ConstValueImpl<T>::GetStringTensorElementLength(size_t element_index) const {
|
|
1290
|
+
size_t out;
|
|
1291
|
+
ThrowOnError(GetApi().GetStringTensorElementLength(this->p_, element_index, &out));
|
|
1292
|
+
return out;
|
|
1293
|
+
}
|
|
1294
|
+
|
|
1295
|
+
template <typename T>
|
|
1296
|
+
template <typename R>
|
|
1297
|
+
inline const R* ConstValueImpl<T>::GetTensorData() const {
|
|
1298
|
+
R* out;
|
|
1299
|
+
ThrowOnError(GetApi().GetTensorMutableData(const_cast<OrtValue*>(this->p_), (void**)&out));
|
|
1300
|
+
return out;
|
|
1301
|
+
}
|
|
1302
|
+
|
|
1303
|
+
template <typename T>
|
|
1304
|
+
inline const void* ConstValueImpl<T>::GetTensorRawData() const {
|
|
1305
|
+
void* out;
|
|
1306
|
+
ThrowOnError(GetApi().GetTensorMutableData(const_cast<OrtValue*>(this->p_), &out));
|
|
1307
|
+
return out;
|
|
1308
|
+
}
|
|
1309
|
+
|
|
1310
|
+
template <typename T>
|
|
1311
|
+
inline TypeInfo ConstValueImpl<T>::GetTypeInfo() const {
|
|
1312
|
+
OrtTypeInfo* output;
|
|
1313
|
+
ThrowOnError(GetApi().GetTypeInfo(this->p_, &output));
|
|
1314
|
+
return TypeInfo{output};
|
|
1315
|
+
}
|
|
1316
|
+
|
|
1317
|
+
template <typename T>
|
|
1318
|
+
inline TensorTypeAndShapeInfo ConstValueImpl<T>::GetTensorTypeAndShapeInfo() const {
|
|
1319
|
+
OrtTensorTypeAndShapeInfo* output;
|
|
1320
|
+
ThrowOnError(GetApi().GetTensorTypeAndShape(this->p_, &output));
|
|
1321
|
+
return TensorTypeAndShapeInfo{output};
|
|
1322
|
+
}
|
|
1323
|
+
|
|
1324
|
+
template <typename T>
|
|
1325
|
+
inline ConstMemoryInfo ConstValueImpl<T>::GetTensorMemoryInfo() const {
|
|
1326
|
+
const OrtMemoryInfo* mem_info;
|
|
1327
|
+
ThrowOnError(GetApi().GetTensorMemoryInfo(this->p_, &mem_info));
|
|
1328
|
+
return ConstMemoryInfo(mem_info);
|
|
1329
|
+
}
|
|
1330
|
+
|
|
1331
|
+
template <typename T>
|
|
1332
|
+
inline void ConstValueImpl<T>::GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const {
|
|
1333
|
+
ThrowOnError(GetApi().GetStringTensorElement(this->p_, buffer_length, element_index, buffer));
|
|
1334
|
+
}
|
|
1335
|
+
|
|
1336
|
+
template <typename T>
|
|
1337
|
+
inline std::string ConstValueImpl<T>::GetStringTensorElement(size_t element_index) const {
|
|
1338
|
+
size_t buffer_length;
|
|
1339
|
+
ThrowOnError(GetApi().GetStringTensorElementLength(this->p_, element_index, &buffer_length));
|
|
1340
|
+
|
|
1341
|
+
std::string s;
|
|
1342
|
+
s.resize(buffer_length);
|
|
1343
|
+
ThrowOnError(GetApi().GetStringTensorElement(this->p_, buffer_length, element_index, &s[0]));
|
|
1344
|
+
return s;
|
|
1345
|
+
}
|
|
1346
|
+
|
|
1347
|
+
template <typename T>
|
|
1348
|
+
inline void ConstValueImpl<T>::GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const {
|
|
1349
|
+
ThrowOnError(GetApi().GetStringTensorContent(this->p_, buffer, buffer_length, offsets, offsets_count));
|
|
1350
|
+
}
|
|
1351
|
+
|
|
1352
|
+
#if !defined(DISABLE_SPARSE_TENSORS)
|
|
1353
|
+
template <typename T>
|
|
1354
|
+
inline OrtSparseFormat ConstValueImpl<T>::GetSparseFormat() const {
|
|
1355
|
+
OrtSparseFormat format;
|
|
1356
|
+
ThrowOnError(GetApi().GetSparseTensorFormat(this->p_, &format));
|
|
1357
|
+
return format;
|
|
1358
|
+
}
|
|
1359
|
+
|
|
1360
|
+
template <typename T>
|
|
1361
|
+
inline TensorTypeAndShapeInfo ConstValueImpl<T>::GetSparseTensorValuesTypeAndShapeInfo() const {
|
|
1362
|
+
OrtTensorTypeAndShapeInfo* output;
|
|
1363
|
+
ThrowOnError(GetApi().GetSparseTensorValuesTypeAndShape(this->p_, &output));
|
|
1364
|
+
return TensorTypeAndShapeInfo{output};
|
|
1365
|
+
}
|
|
1366
|
+
|
|
1367
|
+
template <typename T>
|
|
1368
|
+
inline TensorTypeAndShapeInfo ConstValueImpl<T>::GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat indices_format) const {
|
|
1369
|
+
OrtTensorTypeAndShapeInfo* output;
|
|
1370
|
+
ThrowOnError(GetApi().GetSparseTensorIndicesTypeShape(this->p_, indices_format, &output));
|
|
1371
|
+
return TensorTypeAndShapeInfo{output};
|
|
1372
|
+
}
|
|
1373
|
+
|
|
1374
|
+
template <typename T>
|
|
1375
|
+
template <typename R>
|
|
1376
|
+
inline const R* ConstValueImpl<T>::GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const {
|
|
1377
|
+
const void* out;
|
|
1378
|
+
ThrowOnError(GetApi().GetSparseTensorIndices(this->p_, indices_format, &num_indices, &out));
|
|
1379
|
+
return reinterpret_cast<const R*>(out);
|
|
1380
|
+
}
|
|
1381
|
+
|
|
1382
|
+
template <typename T>
|
|
1383
|
+
inline bool ConstValueImpl<T>::IsSparseTensor() const {
|
|
1384
|
+
int out;
|
|
1385
|
+
ThrowOnError(GetApi().IsSparseTensor(this->p_, &out));
|
|
1386
|
+
return out != 0;
|
|
1387
|
+
}
|
|
1388
|
+
|
|
1389
|
+
template <typename T>
|
|
1390
|
+
template <typename R>
|
|
1391
|
+
inline const R* ConstValueImpl<T>::GetSparseTensorValues() const {
|
|
1392
|
+
const void* out;
|
|
1393
|
+
ThrowOnError(GetApi().GetSparseTensorValues(this->p_, &out));
|
|
1394
|
+
return reinterpret_cast<const R*>(out);
|
|
1395
|
+
}
|
|
1396
|
+
|
|
1397
|
+
#endif
|
|
1398
|
+
|
|
1399
|
+
template <typename T>
|
|
1400
|
+
void ValueImpl<T>::FillStringTensor(const char* const* s, size_t s_len) {
|
|
1401
|
+
ThrowOnError(GetApi().FillStringTensor(this->p_, s, s_len));
|
|
1402
|
+
}
|
|
1403
|
+
|
|
1404
|
+
template <typename T>
|
|
1405
|
+
void ValueImpl<T>::FillStringTensorElement(const char* s, size_t index) {
|
|
1406
|
+
ThrowOnError(GetApi().FillStringTensorElement(this->p_, s, index));
|
|
1407
|
+
}
|
|
1408
|
+
|
|
1409
|
+
template <typename T>
|
|
1410
|
+
inline char* ValueImpl<T>::GetResizedStringTensorElementBuffer(size_t index, size_t buffer_length) {
|
|
1411
|
+
char* result;
|
|
1412
|
+
ThrowOnError(GetApi().GetResizedStringTensorElementBuffer(this->p_, index, buffer_length, &result));
|
|
1413
|
+
return result;
|
|
1414
|
+
}
|
|
1415
|
+
|
|
1416
|
+
template <typename T>
|
|
1417
|
+
void* ValueImpl<T>::GetTensorMutableRawData() {
|
|
1418
|
+
void* out;
|
|
1419
|
+
ThrowOnError(GetApi().GetTensorMutableData(this->p_, &out));
|
|
1420
|
+
return out;
|
|
1421
|
+
}
|
|
1422
|
+
|
|
1423
|
+
template <typename T>
|
|
1424
|
+
template <typename R>
|
|
1425
|
+
R* ValueImpl<T>::GetTensorMutableData() {
|
|
1426
|
+
R* out;
|
|
1427
|
+
ThrowOnError(GetApi().GetTensorMutableData(this->p_, (void**)&out));
|
|
1428
|
+
return out;
|
|
1429
|
+
}
|
|
1430
|
+
|
|
1431
|
+
template <typename T>
|
|
1432
|
+
template <typename R>
|
|
1433
|
+
R& ValueImpl<T>::At(const std::vector<int64_t>& location) {
|
|
1434
|
+
static_assert(!std::is_same<T, std::string>::value, "this api does not support std::string");
|
|
1435
|
+
R* out;
|
|
1436
|
+
ThrowOnError(GetApi().TensorAt(this->p_, location.data(), location.size(), (void**)&out));
|
|
1437
|
+
return *out;
|
|
1438
|
+
}
|
|
1439
|
+
|
|
1440
|
+
#if !defined(DISABLE_SPARSE_TENSORS)
|
|
1441
|
+
template <typename T>
|
|
1442
|
+
void ValueImpl<T>::UseCooIndices(int64_t* indices_data, size_t indices_num) {
|
|
1443
|
+
ThrowOnError(GetApi().UseCooIndices(this->p_, indices_data, indices_num));
|
|
1444
|
+
}
|
|
1445
|
+
|
|
1446
|
+
template <typename T>
|
|
1447
|
+
void ValueImpl<T>::UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num) {
|
|
1448
|
+
ThrowOnError(GetApi().UseCsrIndices(this->p_, inner_data, inner_num, outer_data, outer_num));
|
|
1449
|
+
}
|
|
1450
|
+
|
|
1451
|
+
template <typename T>
|
|
1452
|
+
void ValueImpl<T>::UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data) {
|
|
1453
|
+
ThrowOnError(GetApi().UseBlockSparseIndices(this->p_, indices_shape.shape, indices_shape.shape_len, indices_data));
|
|
1454
|
+
}
|
|
1455
|
+
|
|
1456
|
+
template <typename T>
|
|
1457
|
+
void ValueImpl<T>::FillSparseTensorCoo(const OrtMemoryInfo* mem_info, const OrtSparseValuesParam& values_param,
|
|
1458
|
+
const int64_t* indices_data, size_t indices_num) {
|
|
1459
|
+
ThrowOnError(GetApi().FillSparseTensorCoo(this->p_, mem_info, values_param.values_shape,
|
|
1460
|
+
values_param.values_shape_len, values_param.data.p_data,
|
|
1461
|
+
indices_data, indices_num));
|
|
1462
|
+
}
|
|
1463
|
+
|
|
1464
|
+
template <typename T>
|
|
1465
|
+
void ValueImpl<T>::FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info,
|
|
1466
|
+
const OrtSparseValuesParam& values,
|
|
1467
|
+
const int64_t* inner_indices_data, size_t inner_indices_num,
|
|
1468
|
+
const int64_t* outer_indices_data, size_t outer_indices_num) {
|
|
1469
|
+
ThrowOnError(GetApi().FillSparseTensorCsr(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data,
|
|
1470
|
+
inner_indices_data, inner_indices_num,
|
|
1471
|
+
outer_indices_data, outer_indices_num));
|
|
1472
|
+
}
|
|
1473
|
+
|
|
1474
|
+
template <typename T>
|
|
1475
|
+
void ValueImpl<T>::FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info,
|
|
1476
|
+
const OrtSparseValuesParam& values,
|
|
1477
|
+
const Shape& indices_shape,
|
|
1478
|
+
const int32_t* indices_data) {
|
|
1479
|
+
ThrowOnError(GetApi().FillSparseTensorBlockSparse(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data,
|
|
1480
|
+
indices_shape.shape, indices_shape.shape_len,
|
|
1481
|
+
indices_data));
|
|
1482
|
+
}
|
|
1483
|
+
|
|
1484
|
+
#endif // !defined(DISABLE_SPARSE_TENSORS)
|
|
1485
|
+
|
|
1486
|
+
} // namespace detail
|
|
1487
|
+
|
|
1488
|
+
template <typename T>
|
|
1489
|
+
inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len) {
|
|
1490
|
+
return CreateTensor(info, p_data, p_data_element_count * sizeof(T), shape, shape_len, TypeToTensorType<T>::type);
|
|
1491
|
+
}
|
|
1492
|
+
|
|
1493
|
+
/// Creates a tensor of element `type` over `p_data_byte_count` bytes of caller-provided memory
/// (OrtApi::CreateTensorWithDataAsOrtValue).
/// @throws Ort::Exception when the C API call fails.
inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
                                 ONNXTensorElementDataType type) {
  const auto& api = GetApi();
  OrtValue* tensor = nullptr;
  ThrowOnError(api.CreateTensorWithDataAsOrtValue(info, p_data, p_data_byte_count, shape, shape_len, type, &tensor));
  return Value(tensor);
}
|
|
1499
|
+
|
|
1500
|
+
template <typename T>
|
|
1501
|
+
inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len) {
|
|
1502
|
+
return CreateTensor(allocator, shape, shape_len, TypeToTensorType<T>::type);
|
|
1503
|
+
}
|
|
1504
|
+
|
|
1505
|
+
/// Creates a tensor of element `type` whose buffer comes from `allocator` (OrtApi::CreateTensorAsOrtValue).
/// @throws Ort::Exception when the C API call fails.
inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type) {
  const auto& api = GetApi();
  OrtValue* tensor = nullptr;
  ThrowOnError(api.CreateTensorAsOrtValue(allocator, shape, shape_len, type, &tensor));
  return Value(tensor);
}
|
|
1510
|
+
|
|
1511
|
+
#if !defined(DISABLE_SPARSE_TENSORS)
|
|
1512
|
+
|
|
1513
|
+
template <typename T>
|
|
1514
|
+
inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape,
|
|
1515
|
+
const Shape& values_shape) {
|
|
1516
|
+
return CreateSparseTensor(info, p_data, dense_shape, values_shape, TypeToTensorType<T>::type);
|
|
1517
|
+
}
|
|
1518
|
+
|
|
1519
|
+
/// Creates a sparse tensor with the given dense shape over caller-provided values of element `type`
/// (OrtApi::CreateSparseTensorWithValuesAsOrtValue).
/// @throws Ort::Exception when the C API call fails.
inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape,
                                       const Shape& values_shape, ONNXTensorElementDataType type) {
  const auto& api = GetApi();
  OrtValue* sparse = nullptr;
  ThrowOnError(api.CreateSparseTensorWithValuesAsOrtValue(info, p_data,
                                                          dense_shape.shape, dense_shape.shape_len,
                                                          values_shape.shape, values_shape.shape_len,
                                                          type, &sparse));
  return Value(sparse);
}
|
|
1526
|
+
|
|
1527
|
+
template <typename T>
|
|
1528
|
+
inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape) {
|
|
1529
|
+
return CreateSparseTensor(allocator, dense_shape, TypeToTensorType<T>::type);
|
|
1530
|
+
}
|
|
1531
|
+
|
|
1532
|
+
/// Creates a sparse tensor of element `type` with the given dense shape, backed by `allocator`
/// (OrtApi::CreateSparseTensorAsOrtValue).
/// @throws Ort::Exception when the C API call fails.
inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape,
                                       ONNXTensorElementDataType type) {
  const auto& api = GetApi();
  OrtValue* sparse = nullptr;
  ThrowOnError(api.CreateSparseTensorAsOrtValue(allocator, dense_shape.shape, dense_shape.shape_len, type, &sparse));
  return Value(sparse);
}
|
|
1538
|
+
#endif // !defined(DISABLE_SPARSE_TENSORS)
|
|
1539
|
+
|
|
1540
|
+
/// Builds an ONNX map value from a keys value and a values value (OrtApi::CreateValue with ONNX_TYPE_MAP).
/// @throws Ort::Exception when the C API call fails.
inline Value Value::CreateMap(const Value& keys, const Value& values) {
  // The C API expects the pair as an array ordered {keys, values}.
  const OrtValue* map_parts[] = {keys, values};
  OrtValue* map_value = nullptr;
  ThrowOnError(GetApi().CreateValue(map_parts, 2, ONNX_TYPE_MAP, &map_value));
  return Value(map_value);
}
|
|
1546
|
+
|
|
1547
|
+
/// Builds an ONNX sequence value from `values` (OrtApi::CreateValue with ONNX_TYPE_SEQUENCE).
/// @throws Ort::Exception when the C API call fails.
inline Value Value::CreateSequence(const std::vector<Value>& values) {
  // Convert each Value to its underlying OrtValue pointer (iterator-range constructor;
  // parentheses make it explicit that this is not an initializer list).
  std::vector<const OrtValue*> raw_values(values.begin(), values.end());
  OrtValue* sequence = nullptr;
  ThrowOnError(GetApi().CreateValue(raw_values.data(), raw_values.size(), ONNX_TYPE_SEQUENCE, &sequence));
  return Value(sequence);
}
|
|
1553
|
+
|
|
1554
|
+
template <typename T>
|
|
1555
|
+
inline Value Value::CreateOpaque(const char* domain, const char* type_name, const T& data_container) {
|
|
1556
|
+
OrtValue* out;
|
|
1557
|
+
ThrowOnError(GetApi().CreateOpaqueValue(domain, type_name, &data_container, sizeof(T), &out));
|
|
1558
|
+
return Value{out};
|
|
1559
|
+
}
|
|
1560
|
+
|
|
1561
|
+
//
|
|
1562
|
+
// Custom OP Inlines
|
|
1563
|
+
//
|
|
1564
|
+
inline Logger::Logger(const OrtLogger* logger) : logger_(logger) {
|
|
1565
|
+
Ort::ThrowOnError(GetApi().Logger_GetLoggingSeverityLevel(this->logger_, &this->cached_severity_level_));
|
|
1566
|
+
}
|
|
1567
|
+
|
|
1568
|
+
/// Returns the severity level cached when this Logger was constructed.
inline OrtLoggingLevel Logger::GetLoggingSeverityLevel() const noexcept {
  return this->cached_severity_level_;
}
|
|
1571
|
+
|
|
1572
|
+
/// Logs `message` at `log_severity_level` with source-location details; the C API's status is returned
/// wrapped in a Status rather than thrown.
inline Status Logger::LogMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number,
                                 const char* func_name, const char* message) const noexcept {
  const auto& api = GetApi();
  // Note the C API argument order: the message comes before the source-location fields.
  return Status{api.Logger_LogMessage(logger_, log_severity_level, message, file_path, line_number, func_name)};
}
|
|
1578
|
+
|
|
1579
|
+
// Disable warnings about the format string not being a literal (-Wformat-nonliteral and -Wformat-security)
|
|
1580
|
+
// for gcc and clang. The alternative is to use actual C-style variadic parameters and apply
|
|
1581
|
+
// __attribute__(format(printf...)), which does not work with variadic templates.
|
|
1582
|
+
#if defined(__GNUC__)
|
|
1583
|
+
#pragma GCC diagnostic push
|
|
1584
|
+
#pragma GCC diagnostic ignored "-Wformat-nonliteral"
|
|
1585
|
+
#pragma GCC diagnostic ignored "-Wformat-security"
|
|
1586
|
+
#elif defined(__clang__)
|
|
1587
|
+
#pragma clang diagnostic push
|
|
1588
|
+
#pragma clang diagnostic ignored "-Wformat-nonliteral"
|
|
1589
|
+
#pragma clang diagnostic ignored "-Wformat-security"
|
|
1590
|
+
#endif
|
|
1591
|
+
template <typename... Args>
|
|
1592
|
+
inline Status Logger::LogFormattedMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path,
|
|
1593
|
+
int line_number, const char* func_name, const char* format,
|
|
1594
|
+
Args&&... args) const noexcept {
|
|
1595
|
+
int msg_len = std::snprintf(nullptr, 0U, format, std::forward<Args>(args)...);
|
|
1596
|
+
|
|
1597
|
+
if (msg_len < 0) { // Formatting error
|
|
1598
|
+
return Status("Failed to log message due to formatting error", OrtErrorCode::ORT_FAIL);
|
|
1599
|
+
}
|
|
1600
|
+
|
|
1601
|
+
OrtStatus* status = nullptr;
|
|
1602
|
+
const size_t buffer_size = static_cast<size_t>(msg_len) + 1U;
|
|
1603
|
+
|
|
1604
|
+
constexpr size_t kStackBufferSize = 1024;
|
|
1605
|
+
|
|
1606
|
+
if (buffer_size < kStackBufferSize) {
|
|
1607
|
+
char buffer[kStackBufferSize];
|
|
1608
|
+
snprintf(buffer, kStackBufferSize, format, std::forward<Args>(args)...);
|
|
1609
|
+
status = GetApi().Logger_LogMessage(logger_, log_severity_level, buffer, file_path, line_number, func_name);
|
|
1610
|
+
} else {
|
|
1611
|
+
// std::make_unique is only supported starting at C++14.
|
|
1612
|
+
#if (__cplusplus >= 201402L) || (_MSC_VER >= 1900)
|
|
1613
|
+
auto buffer = std::make_unique<char[]>(buffer_size);
|
|
1614
|
+
#else
|
|
1615
|
+
std::unique_ptr<char[]> buffer(new char[buffer_size]);
|
|
1616
|
+
#endif
|
|
1617
|
+
std::snprintf(buffer.get(), buffer_size, format, std::forward<Args>(args)...);
|
|
1618
|
+
status = GetApi().Logger_LogMessage(logger_, log_severity_level, buffer.get(), file_path, line_number, func_name);
|
|
1619
|
+
}
|
|
1620
|
+
|
|
1621
|
+
return Status{status};
|
|
1622
|
+
}
|
|
1623
|
+
// Re-enable -Wformat-nonliteral and -Wformat-security
|
|
1624
|
+
#if defined(__GNUC__)
|
|
1625
|
+
#pragma GCC diagnostic pop
|
|
1626
|
+
#elif defined(__clang__)
|
|
1627
|
+
#pragma clang diagnostic pop
|
|
1628
|
+
#endif
|
|
1629
|
+
|
|
1630
|
+
inline KernelContext::KernelContext(OrtKernelContext* context) : ctx_(context) {
|
|
1631
|
+
}
|
|
1632
|
+
|
|
1633
|
+
/// Returns the number of inputs of the running kernel (OrtApi::KernelContext_GetInputCount).
/// @throws Ort::Exception when the C API call fails.
inline size_t KernelContext::GetInputCount() const {
  const auto& api = GetApi();
  size_t input_count = 0;
  Ort::ThrowOnError(api.KernelContext_GetInputCount(ctx_, &input_count));
  return input_count;
}
|
|
1638
|
+
|
|
1639
|
+
/// Returns the number of outputs of the running kernel (OrtApi::KernelContext_GetOutputCount).
/// @throws Ort::Exception when the C API call fails.
inline size_t KernelContext::GetOutputCount() const {
  const auto& api = GetApi();
  size_t output_count = 0;
  Ort::ThrowOnError(api.KernelContext_GetOutputCount(ctx_, &output_count));
  return output_count;
}
|
|
1644
|
+
|
|
1645
|
+
/// Returns input `index` of the running kernel as a ConstValue (OrtApi::KernelContext_GetInput).
/// @throws Ort::Exception when the C API call fails.
inline ConstValue KernelContext::GetInput(size_t index) const {
  const auto& api = GetApi();
  const OrtValue* input = nullptr;
  Ort::ThrowOnError(api.KernelContext_GetInput(ctx_, index, &input));
  return ConstValue{input};
}
|
|
1650
|
+
|
|
1651
|
+
/// Obtains output `index` with the shape given by `dim_count` dimensions at `dim_values`
/// (OrtApi::KernelContext_GetOutput) and returns it as an UnownedValue.
/// @throws Ort::Exception when the C API call fails.
inline UnownedValue KernelContext::GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const {
  const auto& api = GetApi();
  OrtValue* output = nullptr;
  Ort::ThrowOnError(api.KernelContext_GetOutput(ctx_, index, dim_values, dim_count, &output));
  return UnownedValue(output);
}
|
|
1656
|
+
|
|
1657
|
+
/// Convenience overload taking the output shape as a vector of dimensions.
/// @throws Ort::Exception when the C API call fails.
inline UnownedValue KernelContext::GetOutput(size_t index, const std::vector<int64_t>& dims) const {
  OrtValue* output = nullptr;
  const auto& api = GetApi();
  Ort::ThrowOnError(api.KernelContext_GetOutput(ctx_, index, dims.data(), dims.size(), &output));
  return UnownedValue(output);
}
|
|
1662
|
+
|
|
1663
|
+
/// Returns the opaque GPU compute-stream handle reported by OrtApi::KernelContext_GetGPUComputeStream.
/// @throws Ort::Exception when the C API call fails.
inline void* KernelContext::GetGPUComputeStream() const {
  const auto& api = GetApi();
  void* stream = nullptr;
  Ort::ThrowOnError(api.KernelContext_GetGPUComputeStream(ctx_, &stream));
  return stream;
}
|
|
1668
|
+
|
|
1669
|
+
/// Returns the allocator matching `memory_info` (OrtApi::KernelContext_GetAllocator).
/// @throws Ort::Exception when the C API call fails.
inline OrtAllocator* KernelContext::GetAllocator(const OrtMemoryInfo& memory_info) const {
  const auto& api = GetApi();
  OrtAllocator* allocator = nullptr;
  Ort::ThrowOnError(api.KernelContext_GetAllocator(ctx_, &memory_info, &allocator));
  return allocator;
}
|
|
1674
|
+
|
|
1675
|
+
/// Returns a Logger wrapping the kernel context's logger (OrtApi::KernelContext_GetLogger).
/// @throws Ort::Exception when the C API call fails.
inline Logger KernelContext::GetLogger() const {
  const auto& api = GetApi();
  const OrtLogger* ort_logger = nullptr;
  ThrowOnError(api.KernelContext_GetLogger(this->ctx_, &ort_logger));
  return Logger(ort_logger);
}
|
|
1680
|
+
|
|
1681
|
+
/// Runs `fn` over `total` iterations via OrtApi::KernelContext_ParallelFor, passing `usr_data` through.
/// @param num_batch  Batch count forwarded unchanged to the C API.
/// @throws Ort::Exception when the C API call fails.
inline void KernelContext::ParallelFor(void (*fn)(void*, size_t), size_t total, size_t num_batch, void* usr_data) const {
  const auto& api = GetApi();
  ThrowOnError(api.KernelContext_ParallelFor(ctx_, fn, total, num_batch, usr_data));
}
|
|
1684
|
+
|
|
1685
|
+
/// Creates an operator attribute named `name` of `type` from `len` items at `data` (OrtApi::CreateOpAttr).
/// @throws Ort::Exception when the C API call fails.
inline OpAttr::OpAttr(const char* name, const void* data, int len, OrtOpAttrType type) {
  const auto& api = GetApi();
  Ort::ThrowOnError(api.CreateOpAttr(name, data, len, type, &p_));
}
|
|
1688
|
+
|
|
1689
|
+
namespace detail {
|
|
1690
|
+
template <typename T>
|
|
1691
|
+
inline KernelInfo KernelInfoImpl<T>::Copy() const {
|
|
1692
|
+
OrtKernelInfo* info_copy = nullptr;
|
|
1693
|
+
Ort::ThrowOnError(GetApi().CopyKernelInfo(this->p_, &info_copy));
|
|
1694
|
+
return KernelInfo{info_copy};
|
|
1695
|
+
}
|
|
1696
|
+
|
|
1697
|
+
template <typename T>
|
|
1698
|
+
inline size_t KernelInfoImpl<T>::GetInputCount() const {
|
|
1699
|
+
size_t out = 0;
|
|
1700
|
+
ThrowOnError(GetApi().KernelInfo_GetInputCount(this->p_, &out));
|
|
1701
|
+
return out;
|
|
1702
|
+
}
|
|
1703
|
+
|
|
1704
|
+
template <typename T>
|
|
1705
|
+
inline size_t KernelInfoImpl<T>::GetOutputCount() const {
|
|
1706
|
+
size_t out = 0;
|
|
1707
|
+
ThrowOnError(GetApi().KernelInfo_GetOutputCount(this->p_, &out));
|
|
1708
|
+
return out;
|
|
1709
|
+
}
|
|
1710
|
+
|
|
1711
|
+
template <typename T>
|
|
1712
|
+
inline std::string KernelInfoImpl<T>::GetInputName(size_t index) const {
|
|
1713
|
+
size_t size = 0;
|
|
1714
|
+
|
|
1715
|
+
// Feed nullptr for the data buffer to query the true size of the string value
|
|
1716
|
+
Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, nullptr, &size));
|
|
1717
|
+
|
|
1718
|
+
std::string out;
|
|
1719
|
+
out.resize(size);
|
|
1720
|
+
Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, &out[0], &size));
|
|
1721
|
+
out.resize(size - 1); // remove the terminating character '\0'
|
|
1722
|
+
|
|
1723
|
+
return out;
|
|
1724
|
+
}
|
|
1725
|
+
|
|
1726
|
+
template <typename T>
|
|
1727
|
+
inline std::string KernelInfoImpl<T>::GetOutputName(size_t index) const {
|
|
1728
|
+
size_t size = 0;
|
|
1729
|
+
|
|
1730
|
+
// Feed nullptr for the data buffer to query the true size of the string value
|
|
1731
|
+
Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, nullptr, &size));
|
|
1732
|
+
|
|
1733
|
+
std::string out;
|
|
1734
|
+
out.resize(size);
|
|
1735
|
+
Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, &out[0], &size));
|
|
1736
|
+
out.resize(size - 1); // remove the terminating character '\0'
|
|
1737
|
+
|
|
1738
|
+
return out;
|
|
1739
|
+
}
|
|
1740
|
+
|
|
1741
|
+
template <typename T>
|
|
1742
|
+
inline TypeInfo KernelInfoImpl<T>::GetInputTypeInfo(size_t index) const {
|
|
1743
|
+
OrtTypeInfo* out = nullptr;
|
|
1744
|
+
ThrowOnError(GetApi().KernelInfo_GetInputTypeInfo(this->p_, index, &out));
|
|
1745
|
+
return TypeInfo{out};
|
|
1746
|
+
}
|
|
1747
|
+
|
|
1748
|
+
template <typename T>
|
|
1749
|
+
inline TypeInfo KernelInfoImpl<T>::GetOutputTypeInfo(size_t index) const {
|
|
1750
|
+
OrtTypeInfo* out = nullptr;
|
|
1751
|
+
ThrowOnError(GetApi().KernelInfo_GetOutputTypeInfo(this->p_, index, &out));
|
|
1752
|
+
return TypeInfo{out};
|
|
1753
|
+
}
|
|
1754
|
+
|
|
1755
|
+
template <typename T>
|
|
1756
|
+
inline Value KernelInfoImpl<T>::GetTensorAttribute(const char* name, OrtAllocator* allocator) const {
|
|
1757
|
+
OrtValue* out = nullptr;
|
|
1758
|
+
ThrowOnError(GetApi().KernelInfoGetAttribute_tensor(this->p_, name, allocator, &out));
|
|
1759
|
+
return Value{out};
|
|
1760
|
+
}
|
|
1761
|
+
|
|
1762
|
+
template <typename T>
|
|
1763
|
+
inline ConstValue KernelInfoImpl<T>::GetTensorConstantInput(size_t index, int* is_constant) const {
|
|
1764
|
+
const OrtValue* out = nullptr;
|
|
1765
|
+
ThrowOnError(GetApi().KernelInfoGetConstantInput_tensor(this->p_, index, is_constant, &out));
|
|
1766
|
+
return ConstValue{out};
|
|
1767
|
+
}
|
|
1768
|
+
|
|
1769
|
+
template <typename T>
|
|
1770
|
+
inline std::string KernelInfoImpl<T>::GetNodeName() const {
|
|
1771
|
+
size_t size = 0;
|
|
1772
|
+
|
|
1773
|
+
// Feed nullptr for the data buffer to query the true size of the string value
|
|
1774
|
+
Ort::ThrowOnError(GetApi().KernelInfo_GetNodeName(this->p_, nullptr, &size));
|
|
1775
|
+
|
|
1776
|
+
std::string out;
|
|
1777
|
+
out.resize(size);
|
|
1778
|
+
Ort::ThrowOnError(GetApi().KernelInfo_GetNodeName(this->p_, &out[0], &size));
|
|
1779
|
+
out.resize(size - 1); // remove the terminating character '\0'
|
|
1780
|
+
|
|
1781
|
+
return out;
|
|
1782
|
+
}
|
|
1783
|
+
|
|
1784
|
+
template <typename T>
|
|
1785
|
+
inline Logger KernelInfoImpl<T>::GetLogger() const {
|
|
1786
|
+
const OrtLogger* out = nullptr;
|
|
1787
|
+
ThrowOnError(GetApi().KernelInfo_GetLogger(this->p_, &out));
|
|
1788
|
+
return Logger{out};
|
|
1789
|
+
}
|
|
1790
|
+
|
|
1791
|
+
inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, float& out) {
|
|
1792
|
+
Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_float(p, name, &out));
|
|
1793
|
+
}
|
|
1794
|
+
|
|
1795
|
+
inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, int64_t& out) {
|
|
1796
|
+
Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_int64(p, name, &out));
|
|
1797
|
+
}
|
|
1798
|
+
|
|
1799
|
+
inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, std::string& result) {
|
|
1800
|
+
size_t size = 0;
|
|
1801
|
+
// Feed nullptr for the data buffer to query the true size of the string attribute
|
|
1802
|
+
Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, nullptr, &size));
|
|
1803
|
+
|
|
1804
|
+
std::string out;
|
|
1805
|
+
out.resize(size);
|
|
1806
|
+
Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, &out[0], &size));
|
|
1807
|
+
out.resize(size - 1); // remove the terminating character '\0'
|
|
1808
|
+
out.swap(result);
|
|
1809
|
+
}
|
|
1810
|
+
|
|
1811
|
+
inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<float>& result) {
|
|
1812
|
+
size_t size = 0;
|
|
1813
|
+
// Feed nullptr for the data buffer to query the true size of the attribute
|
|
1814
|
+
Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, nullptr, &size));
|
|
1815
|
+
|
|
1816
|
+
std::vector<float> out;
|
|
1817
|
+
out.resize(size);
|
|
1818
|
+
Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, out.data(), &size));
|
|
1819
|
+
out.swap(result);
|
|
1820
|
+
}
|
|
1821
|
+
|
|
1822
|
+
inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<int64_t>& result) {
|
|
1823
|
+
size_t size = 0;
|
|
1824
|
+
|
|
1825
|
+
// Feed nullptr for the data buffer to query the true size of the attribute
|
|
1826
|
+
Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, nullptr, &size));
|
|
1827
|
+
|
|
1828
|
+
std::vector<int64_t> out;
|
|
1829
|
+
out.resize(size);
|
|
1830
|
+
Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, out.data(), &size));
|
|
1831
|
+
out.swap(result);
|
|
1832
|
+
}
|
|
1833
|
+
} // namespace detail
|
|
1834
|
+
|
|
1835
|
+
/// Wraps an OrtKernelInfo handle.
inline KernelInfo::KernelInfo(OrtKernelInfo* info) : detail::KernelInfoImpl<OrtKernelInfo>(info) {}
|
|
1836
|
+
|
|
1837
|
+
inline Op::Op(OrtOp* p) : Base<OrtOp>(p) {}
|
|
1838
|
+
|
|
1839
|
+
/// Creates an operator instance (OrtApi::CreateOp) from the given kernel info, op identity,
/// type constraints and attributes. Size arguments are narrowed to int for the C API.
/// @throws Ort::Exception when the C API call fails.
inline Op Op::Create(const OrtKernelInfo* info, const char* op_name, const char* domain, int version,
                     const char** type_constraint_names,
                     const ONNXTensorElementDataType* type_constraint_values,
                     size_t type_constraint_count,
                     const OpAttr* attr_values, size_t attr_count,
                     size_t input_count, size_t output_count) {
  static_assert(sizeof(OpAttr) == sizeof(OrtOpAttr*),
                "OpAttr's is expected to be just an array of OrtOpAttr in memory so we can reinterpret safely");
  // OpAttr is a single-pointer wrapper, so an OpAttr array can be viewed as an OrtOpAttr* array.
  const auto* const* raw_attrs = reinterpret_cast<const OrtOpAttr* const*>(attr_values);

  const int constraint_count_i = static_cast<int>(type_constraint_count);
  const int attr_count_i = static_cast<int>(attr_count);
  const int input_count_i = static_cast<int>(input_count);
  const int output_count_i = static_cast<int>(output_count);

  OrtOp* created_op;
  Ort::ThrowOnError(GetApi().CreateOp(info, op_name, domain, version,
                                      type_constraint_names, type_constraint_values, constraint_count_i,
                                      raw_attrs, attr_count_i,
                                      input_count_i, output_count_i, &created_op));
  return Op(created_op);
}
|
|
1857
|
+
|
|
1858
|
+
inline void Op::Invoke(const OrtKernelContext* context,
|
|
1859
|
+
const Value* input_values,
|
|
1860
|
+
size_t input_count,
|
|
1861
|
+
Value* output_values,
|
|
1862
|
+
size_t output_count) {
|
|
1863
|
+
static_assert(sizeof(Value) == sizeof(OrtValue*),
|
|
1864
|
+
"Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
|
|
1865
|
+
auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
|
|
1866
|
+
auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
|
|
1867
|
+
Ort::ThrowOnError(GetApi().InvokeOp(context, p_, ort_input_values, static_cast<int>(input_count),
|
|
1868
|
+
ort_output_values, static_cast<int>(output_count)));
|
|
1869
|
+
}
|
|
1870
|
+
|
|
1871
|
+
inline void Op::Invoke(const OrtKernelContext* context,
|
|
1872
|
+
const OrtValue* const* input_values,
|
|
1873
|
+
size_t input_count,
|
|
1874
|
+
OrtValue* const* output_values,
|
|
1875
|
+
size_t output_count) {
|
|
1876
|
+
Ort::ThrowOnError(GetApi().InvokeOp(context, p_, input_values, static_cast<int>(input_count),
|
|
1877
|
+
output_values, static_cast<int>(output_count)));
|
|
1878
|
+
}
|
|
1879
|
+
|
|
1880
|
+
inline std::string GetVersionString() {
|
|
1881
|
+
return OrtGetApiBase()->GetVersionString();
|
|
1882
|
+
}
|
|
1883
|
+
|
|
1884
|
+
inline std::string GetBuildInfoString() {
|
|
1885
|
+
return GetApi().GetBuildInfoString();
|
|
1886
|
+
}
|
|
1887
|
+
|
|
1888
|
+
inline std::vector<std::string> GetAvailableProviders() {
|
|
1889
|
+
char** providers;
|
|
1890
|
+
int len;
|
|
1891
|
+
|
|
1892
|
+
auto release_fn = [&len](char** providers) {
|
|
1893
|
+
// This should always return nullptr.
|
|
1894
|
+
ThrowOnError(GetApi().ReleaseAvailableProviders(providers, len));
|
|
1895
|
+
};
|
|
1896
|
+
|
|
1897
|
+
ThrowOnError(GetApi().GetAvailableProviders(&providers, &len));
|
|
1898
|
+
std::unique_ptr<char*, decltype(release_fn)> guard(providers, release_fn);
|
|
1899
|
+
std::vector<std::string> available_providers;
|
|
1900
|
+
available_providers.reserve(static_cast<size_t>(len));
|
|
1901
|
+
for (int i = 0; i < len; ++i) {
|
|
1902
|
+
available_providers.emplace_back(providers[i]);
|
|
1903
|
+
}
|
|
1904
|
+
return available_providers;
|
|
1905
|
+
}
|
|
1906
|
+
|
|
1907
|
+
template <typename TOp, typename TKernel, bool WithStatus>
|
|
1908
|
+
void CustomOpBase<TOp, TKernel, WithStatus>::GetSessionConfigs(std::unordered_map<std::string, std::string>& out,
|
|
1909
|
+
ConstSessionOptions options) const {
|
|
1910
|
+
const TOp* derived = static_cast<const TOp*>(this);
|
|
1911
|
+
std::vector<std::string> keys = derived->GetSessionConfigKeys();
|
|
1912
|
+
|
|
1913
|
+
out.reserve(keys.size());
|
|
1914
|
+
|
|
1915
|
+
std::string config_entry_key = detail::MakeCustomOpConfigEntryKey(derived->GetName(), "");
|
|
1916
|
+
const size_t prefix_size = config_entry_key.length();
|
|
1917
|
+
|
|
1918
|
+
for (const auto& key : keys) {
|
|
1919
|
+
config_entry_key.resize(prefix_size);
|
|
1920
|
+
config_entry_key.append(key);
|
|
1921
|
+
out[key] = options.GetConfigEntryOrDefault(config_entry_key.c_str(), "");
|
|
1922
|
+
}
|
|
1923
|
+
}
|
|
1924
|
+
|
|
1925
|
+
/// Captures the shape of every input of the shape-inference context into `input_shapes_`.
/// Each dimension is stored as its symbolic name when it has a non-empty one, otherwise as its integer value.
/// @throws Ort::Exception when any C API call fails.
inline ShapeInferContext::ShapeInferContext(const OrtApi* ort_api,
                                            OrtShapeInferContext* ctx) : ort_api_(ort_api), ctx_(ctx) {
  size_t input_count = 0;
  Ort::ThrowOnError(ort_api_->ShapeInferContext_GetInputCount(ctx_, &input_count));
  for (size_t ith_input = 0; ith_input < input_count; ++ith_input) {
    OrtTensorTypeAndShapeInfo* info{};
    Ort::ThrowOnError(ort_api_->ShapeInferContext_GetInputTypeShape(ctx, ith_input, &info));
    TensorTypeAndShapeInfo type_shape_info(info);
    auto integer_shape = type_shape_info.GetShape();
    std::vector<const char*> symbolic_shape(integer_shape.size(), {});
    // Rank-0 (scalar) inputs have no dimensions. Taking &symbolic_shape[0] on the empty vector
    // would be undefined behaviour, so the symbolic query is skipped for them.
    if (!integer_shape.empty()) {
      type_shape_info.GetSymbolicDimensions(symbolic_shape.data(), integer_shape.size());
    }
    Shape shape;
    for (size_t ith = 0; ith < integer_shape.size(); ++ith) {
      if (symbolic_shape[ith] && std::string{symbolic_shape[ith]}.size() > 0) {
        shape.emplace_back(symbolic_shape[ith]);
      } else {
        shape.emplace_back(integer_shape[ith]);
      }
    }
    input_shapes_.push_back(std::move(shape));
    // Detach without releasing. Presumably the context retains ownership of `info`, which is why
    // the wrapper must not free it — TODO confirm against ShapeInferContext_GetInputTypeShape docs.
    type_shape_info.release();
  }
}
|
|
1948
|
+
|
|
1949
|
+
/// Sets the shape of output `indice`. Integer dimensions are passed as values; symbolic dimensions are
/// passed by name, with SymbolicInteger::INVALID_INT_DIM in the integer slot.
/// @return A non-OK Status if any C API call fails.
/// @throws Ort::Exception (via ORT_CXX_API_THROW) if a symbolic dimension is null or empty.
inline Status ShapeInferContext::SetOutputShape(size_t indice, const Shape& shape) {
  OrtTensorTypeAndShapeInfo* info = {};
  RETURN_ON_API_FAIL(ort_api_->CreateTensorTypeAndShapeInfo(&info));

  using InfoPtr = std::unique_ptr<OrtTensorTypeAndShapeInfo, std::function<void(OrtTensorTypeAndShapeInfo*)>>;

  // Owns `info` so it is released on every exit path, including early returns and the throw below.
  InfoPtr info_ptr(info, [this](OrtTensorTypeAndShapeInfo* obj) {
    ort_api_->ReleaseTensorTypeAndShapeInfo(obj);
  });

  std::vector<int64_t> integer_dims;
  std::vector<const char*> symbolic_dims;

  for (const auto& dim : shape) {
    if (dim.IsInt()) {
      // Record the dimension's integer value (not the bool returned by IsInt()).
      integer_dims.push_back(dim.AsInt());
      symbolic_dims.push_back("");
    } else {
      if (!dim.AsSym() || std::string{dim.AsSym()}.empty()) {
        ORT_CXX_API_THROW("Symbolic dim must not be an empty string", ORT_INVALID_ARGUMENT);
      }
      integer_dims.push_back(SymbolicInteger::INVALID_INT_DIM);
      symbolic_dims.push_back(dim.AsSym());
    }
  }

  RETURN_ON_API_FAIL(ort_api_->SetDimensions(info, integer_dims.data(), integer_dims.size()));
  RETURN_ON_API_FAIL(ort_api_->SetSymbolicDimensions(info, symbolic_dims.data(), symbolic_dims.size()));
  RETURN_ON_API_FAIL(ort_api_->ShapeInferContext_SetOutputTypeShape(ctx_, indice, info));
  return Status{nullptr};
}
|
|
1980
|
+
|
|
1981
|
+
inline int64_t ShapeInferContext::GetAttrInt(const char* attr_name) {
|
|
1982
|
+
const auto* attr = GetAttrHdl(attr_name);
|
|
1983
|
+
int64_t i = {};
|
|
1984
|
+
size_t out = {};
|
|
1985
|
+
Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INT, &i, sizeof(i), &out));
|
|
1986
|
+
return i;
|
|
1987
|
+
}
|
|
1988
|
+
|
|
1989
|
+
/// Reads an INTS (int64 list) attribute of the node being inferred.
///
/// The first ReadOpAttr call offers room for a single element. If that suffices,
/// the result is taken directly from it. Otherwise the call fails, `out` reports
/// the byte count required, and a second call fills a buffer of that size.
///
/// @param attr_name  Attribute name.
/// @return The attribute's values; empty for an empty list.
/// @throws Ort::Exception if the handle lookup or the sized read reports an error.
inline ShapeInferContext::Ints ShapeInferContext::GetAttrInts(const char* attr_name) {
  const auto* attr = GetAttrHdl(attr_name);
  int64_t i = {};
  size_t out = {};
  // first call: either reads a 0/1-element list, or reports the bytes needed
  OrtStatus* status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INTS, &i, sizeof(i), &out);
  if (status == nullptr) {
    // On success `out` is the byte count actually written; an empty list writes
    // nothing, so don't fabricate a single zero element.
    return out == 0 ? Ints{} : Ints{i};
  }
  // The probe failure only signals "buffer too small". Release it so it is not
  // leaked; a genuine error will resurface from the sized call below.
  ort_api_->ReleaseStatus(status);

  Ints ints(out / sizeof(int64_t), 0);
  Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INTS, ints.data(), out, &out));
  return ints;
}
|
|
2004
|
+
|
|
2005
|
+
inline float ShapeInferContext::GetAttrFloat(const char* attr_name) {
|
|
2006
|
+
const auto* attr = GetAttrHdl(attr_name);
|
|
2007
|
+
float f = {};
|
|
2008
|
+
size_t out = {};
|
|
2009
|
+
Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOAT, &f, sizeof(f), &out));
|
|
2010
|
+
return f;
|
|
2011
|
+
}
|
|
2012
|
+
|
|
2013
|
+
/// Reads a FLOATS (float list) attribute of the node being inferred.
///
/// The first ReadOpAttr call offers room for a single element. If that suffices,
/// the result is taken directly from it. Otherwise the call fails, `out` reports
/// the byte count required, and a second call fills a buffer of that size.
///
/// @param attr_name  Attribute name.
/// @return The attribute's values; empty for an empty list.
/// @throws Ort::Exception if the handle lookup or the sized read reports an error.
inline ShapeInferContext::Floats ShapeInferContext::GetAttrFloats(const char* attr_name) {
  const auto* attr = GetAttrHdl(attr_name);
  float f = {};
  size_t out = {};
  // first call: either reads a 0/1-element list, or reports the bytes needed
  OrtStatus* status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOATS, &f, sizeof(f), &out);
  if (status == nullptr) {
    // On success `out` is the byte count actually written; an empty list writes
    // nothing, so don't fabricate a single zero element.
    return out == 0 ? Floats{} : Floats{f};
  }
  // The probe failure only signals "buffer too small". Release it so it is not
  // leaked; a genuine error will resurface from the sized call below.
  ort_api_->ReleaseStatus(status);

  Floats floats(out / sizeof(float), 0);
  Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOATS, floats.data(), out, &out));
  return floats;
}
|
|
2028
|
+
|
|
2029
|
+
inline std::string ShapeInferContext::GetAttrString(const char* attr_name) {
|
|
2030
|
+
const auto* attr = GetAttrHdl(attr_name);
|
|
2031
|
+
char c = {};
|
|
2032
|
+
size_t out = {};
|
|
2033
|
+
// first call to get the bytes needed
|
|
2034
|
+
auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRING, &c, sizeof(char), &out);
|
|
2035
|
+
if (status) {
|
|
2036
|
+
std::vector<char> chars(out, '\0');
|
|
2037
|
+
Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRING, chars.data(), out, &out));
|
|
2038
|
+
return {chars.data()};
|
|
2039
|
+
} else {
|
|
2040
|
+
return {c};
|
|
2041
|
+
}
|
|
2042
|
+
}
|
|
2043
|
+
|
|
2044
|
+
/// Reads a STRINGS (string list) attribute of the node being inferred.
///
/// The value comes back as consecutive '\0'-terminated strings in one byte
/// buffer, which is split into separate strings here. A single byte is offered
/// first. If that is too small, `out` reports the bytes required and a second
/// call fills a buffer of that size.
///
/// @param attr_name  Attribute name.
/// @return The attribute's strings; empty for an empty list.
/// @throws Ort::Exception if the handle lookup or the sized read reports an error.
inline ShapeInferContext::Strings ShapeInferContext::GetAttrStrings(const char* attr_name) {
  const auto* attr = GetAttrHdl(attr_name);
  char c = {};
  size_t out = {};
  // first call: either reads a tiny (<= 1 byte) payload, or reports the bytes needed
  OrtStatus* status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRINGS, &c, sizeof(char), &out);

  std::vector<char> chars;
  const char* buffer = &c;
  if (status != nullptr) {
    // The probe failure only signals "buffer too small"; release it so it is not leaked.
    ort_api_->ReleaseStatus(status);
    chars.assign(out, '\0');
    Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRINGS, chars.data(), out, &out));
    buffer = chars.data();
  }

  // Split [buffer, buffer + out) at each '\0'. Using `out` (bytes written) means
  // an empty list yields {} and a list holding one empty string yields {""}.
  ShapeInferContext::Strings strings;
  const char* cursor = buffer;
  const char* const end = buffer + out;
  while (cursor < end) {
    const char* terminator = cursor;
    while (terminator < end && *terminator != '\0') {
      ++terminator;
    }
    strings.emplace_back(cursor, terminator);
    if (terminator == end) {
      break;  // unterminated tail: stop without stepping past the buffer
    }
    cursor = terminator + 1;
  }
  return strings;
}
|
|
2068
|
+
|
|
2069
|
+
/// Looks up the handle of attribute `attr_name` on the node being inferred.
/// @param attr_name  Attribute name.
/// @return The attribute handle reported by the inference context.
/// @throws Ort::Exception if ShapeInferContext_GetAttribute reports an error.
inline const OrtOpAttr* ShapeInferContext::GetAttrHdl(const char* attr_name) const {
  const OrtOpAttr* handle = nullptr;
  OrtStatus* const lookup_status = ort_api_->ShapeInferContext_GetAttribute(ctx_, attr_name, &handle);
  Ort::ThrowOnError(lookup_status);
  return handle;
}
|
|
2074
|
+
|
|
2075
|
+
} // namespace Ort
|