com.github.asus4.onnxruntime 0.1.10 → 0.1.12

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34) hide show
  1. package/Plugins/Android/onnxruntime-android.aar +0 -0
  2. package/Plugins/Linux/x64/libonnxruntime.so +0 -0
  3. package/Plugins/Windows/x64/onnxruntime.dll +0 -0
  4. package/Plugins/iOS~/onnxruntime.xcframework/Info.plist +13 -0
  5. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Headers/onnxruntime_c_api.h +182 -15
  6. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Headers/onnxruntime_cxx_api.h +110 -4
  7. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Headers/onnxruntime_cxx_inline.h +189 -0
  8. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Headers/onnxruntime_run_options_config_keys.h +32 -0
  9. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Headers/onnxruntime_session_options_config_keys.h +258 -0
  10. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Info.plist +2 -2
  11. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/onnxruntime +0 -0
  12. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Headers/onnxruntime_c_api.h +182 -15
  13. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Headers/onnxruntime_cxx_api.h +110 -4
  14. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Headers/onnxruntime_cxx_inline.h +189 -0
  15. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Headers/onnxruntime_run_options_config_keys.h +32 -0
  16. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Headers/onnxruntime_session_options_config_keys.h +258 -0
  17. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Info.plist +2 -2
  18. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/onnxruntime +0 -0
  19. package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/Headers/coreml_provider_factory.h +45 -0
  20. package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/Headers/cpu_provider_factory.h +19 -0
  21. package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/Headers/onnxruntime_c_api.h +4717 -0
  22. package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/Headers/onnxruntime_cxx_api.h +2372 -0
  23. package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/Headers/onnxruntime_cxx_inline.h +2075 -0
  24. package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/Headers/onnxruntime_float16.h +540 -0
  25. package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/Headers/onnxruntime_run_options_config_keys.h +32 -0
  26. package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/Headers/onnxruntime_session_options_config_keys.h +258 -0
  27. package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/Info.plist +20 -0
  28. package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/onnxruntime +0 -0
  29. package/Plugins/macOS/libonnxruntime.dylib +0 -0
  30. package/README.md +8 -8
  31. package/Runtime/NativeMethods.shared.cs +270 -276
  32. package/Runtime/OrtValue.shared.cs +7 -3
  33. package/Runtime/Training/NativeTrainingMethods.shared.cs +2 -2
  34. package/package.json +1 -1
@@ -0,0 +1,2372 @@
1
+ // Copyright (c) Microsoft Corporation. All rights reserved.
2
+ // Licensed under the MIT License.
3
+
4
+ // Summary: The Ort C++ API is a header only wrapper around the Ort C API.
5
+ //
6
+ // The C++ API simplifies usage by returning values directly instead of error codes, throwing exceptions on errors
7
+ // and automatically releasing resources in the destructors. The primary purpose of C++ API is exception safety so
8
+ // all the resources follow RAII and do not leak memory.
9
+ //
10
+ // Each of the C++ wrapper classes holds only a pointer to the C internal object. Treat them like smart pointers.
11
+ // To create an empty object, pass 'nullptr' to the constructor (for example, Env e{nullptr};). However, you can't use them
12
+ // until you assign an instance that actually holds an underlying object.
13
+ //
14
+ // For Ort objects only move assignment between objects is allowed, there are no copy constructors.
15
+ // Some objects have explicit 'Clone' methods for this purpose.
16
+ //
17
+ // ConstXXXX types are copyable since they do not own the underlying C object, so you can pass them to functions as arguments
18
+ // by value or by reference. ConstXXXX types are restricted to const only interfaces.
19
+ //
20
+ // UnownedXXXX are similar to ConstXXXX but also allow non-const interfaces.
21
+ //
22
+ // The lifetime of the corresponding owning object must eclipse the lifetimes of the ConstXXXX/UnownedXXXX types. They exist so you do not
23
+ // have to fall back to C types and the C API with its usual pitfalls. In general, do not use the C API from your C++ code.
24
+
25
+ #pragma once
26
+ #include "onnxruntime_c_api.h"
27
+ #include "onnxruntime_float16.h"
28
+
29
+ #include <cstddef>
30
+ #include <cstdio>
31
+ #include <array>
32
+ #include <memory>
33
+ #include <stdexcept>
34
+ #include <string>
35
+ #include <vector>
36
+ #include <unordered_map>
37
+ #include <utility>
38
+ #include <type_traits>
39
+
40
+ #ifdef ORT_NO_EXCEPTIONS
41
+ #include <iostream>
42
+ #endif
43
+
44
+ /** \brief All C++ Onnxruntime APIs are defined inside this namespace
45
+ *
46
+ */
47
+ namespace Ort {
48
+
49
+ /** \brief All C++ methods that can fail will throw an exception of this type
50
+ *
51
+ * If <tt>ORT_NO_EXCEPTIONS</tt> is defined, nothing is thrown: the default ORT_CXX_API_THROW prints what() to std::cerr and calls abort()
52
+ */
53
+ struct Exception : std::exception {
54
+ Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {}  ///< Moves the message text in and records the error code
55
+
56
+ OrtErrorCode GetOrtErrorCode() const { return code_; }  ///< Error code that was passed to the constructor
57
+ const char* what() const noexcept override { return message_.c_str(); }  ///< Points into message_, so it is only valid while this object is alive
58
+
59
+ private:
60
+ std::string message_;  // human-readable error text returned by what()
61
+ OrtErrorCode code_;  // error code returned by GetOrtErrorCode()
62
+ };
63
+
64
+ #ifdef ORT_NO_EXCEPTIONS
65
+ // The #ifndef is for the very special case where the user of this library wants to define their own way of handling errors.
66
+ // NOTE: This header expects control flow to not continue after calling ORT_CXX_API_THROW
67
+ #ifndef ORT_CXX_API_THROW
68
+ #define ORT_CXX_API_THROW(string, code) \
69
+ do { \
70
+ std::cerr << Ort::Exception(string, code) \
71
+ .what() \
72
+ << std::endl; \
73
+ abort(); \
74
+ } while (false)
75
+ #endif
76
+ #else
77
+ #define ORT_CXX_API_THROW(string, code) \
78
+ throw Ort::Exception(string, code)
79
+ #endif
80
+
81
+ // This is used internally by the C++ API. This class holds the global variable that points to the OrtApi,
82
+ // it's in a template so that we can define a global variable in a header and make
83
+ // it transparent to the users of the API.
84
+ template <typename T>
85
+ struct Global {
86
+ static const OrtApi* api_;
87
+ };
88
+
89
+ // If macro ORT_API_MANUAL_INIT is defined, no static initialization will be performed. Instead, user must call InitApi() before using it.
90
+ template <typename T>
91
+ #ifdef ORT_API_MANUAL_INIT
92
+ const OrtApi* Global<T>::api_{};
93
+ inline void InitApi() noexcept { Global<void>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); }
94
+
95
+ // Used by custom operator libraries that are not linked to onnxruntime. Sets the global API object, which is
96
+ // required by C++ APIs.
97
+ //
98
+ // Example mycustomop.cc:
99
+ //
100
+ // #define ORT_API_MANUAL_INIT
101
+ // #include <onnxruntime_cxx_api.h>
102
+ // #undef ORT_API_MANUAL_INIT
103
+ //
104
+ // OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api_base) {
105
+ // Ort::InitApi(api_base->GetApi(ORT_API_VERSION));
106
+ // // ...
107
+ // }
108
+ //
109
+ inline void InitApi(const OrtApi* api) noexcept { Global<void>::api_ = api; }
110
+ #else
111
+ #if defined(_MSC_VER) && !defined(__clang__)
112
+ #pragma warning(push)
113
+ // "Global initializer calls a non-constexpr function." Therefore you can't use ORT APIs in the other global initializers.
114
+ // Please define ORT_API_MANUAL_INIT if it concerns you.
115
+ #pragma warning(disable : 26426)
116
+ #endif
117
+ const OrtApi* Global<T>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION);
118
+ #if defined(_MSC_VER) && !defined(__clang__)
119
+ #pragma warning(pop)
120
+ #endif
121
+ #endif
122
+
123
+ /// This returns a reference to the OrtApi interface in use
124
+ inline const OrtApi& GetApi() noexcept { return *Global<void>::api_; }
125
+
126
+ /// <summary>
127
+ /// This function returns the onnxruntime version string
128
+ /// </summary>
129
+ /// <returns>version string major.minor.rev</returns>
130
+ std::string GetVersionString();
131
+
132
+ /// <summary>
133
+ /// This function returns the onnxruntime build information: including git branch,
134
+ /// git commit id, build type(Debug/Release/RelWithDebInfo) and cmake cpp flags.
135
+ /// </summary>
136
+ /// <returns>string</returns>
137
+ std::string GetBuildInfoString();
138
+
139
+ /// <summary>
140
+ /// This is a C++ wrapper for OrtApi::GetAvailableProviders() and
141
+ /// returns a vector of strings representing the available execution providers.
142
+ /// </summary>
143
+ /// <returns>vector of strings</returns>
144
+ std::vector<std::string> GetAvailableProviders();
145
+
146
+ /** \brief IEEE 754 half-precision floating point data type
147
+ *
148
+ * \details This struct is used for converting float to float16 and back
149
+ * so the user can feed inputs and fetch outputs using this type.
150
+ *
151
+ * The size of the structure should align with uint16_t and one can freely cast
152
+ * uint16_t buffers to/from Ort::Float16_t to feed and retrieve data.
153
+ *
154
+ * \code{.unparsed}
155
+ * // This example demonstrates conversion from float to float16
156
+ * constexpr float values[] = {1.f, 2.f, 3.f, 4.f, 5.f};
157
+ * std::vector<Ort::Float16_t> fp16_values;
158
+ * fp16_values.reserve(std::size(values));
159
+ * std::transform(std::begin(values), std::end(values), std::back_inserter(fp16_values),
160
+ * [](float value) { return Ort::Float16_t(value); });
161
+ *
162
+ * \endcode
163
+ */
164
+ struct Float16_t : onnxruntime_float16::Float16Impl<Float16_t> {
165
+ private:
166
+ /// <summary>
167
+ /// Constructor from a 16-bit representation of a float16 value
168
+ /// No conversion is done here.
169
+ /// </summary>
170
+ /// <param name="v">16-bit representation</param>
171
+ constexpr explicit Float16_t(uint16_t v) noexcept { val = v; }
172
+
173
+ public:
174
+ using Base = onnxruntime_float16::Float16Impl<Float16_t>;
175
+
176
+ /// <summary>
177
+ /// Default constructor
178
+ /// </summary>
179
+ Float16_t() = default;
180
+
181
+ /// <summary>
182
+ /// Explicit construction from the uint16_t bit representation of float16. The bits are stored as-is, no conversion is done.
183
+ /// </summary>
184
+ /// <param name="v">uint16_t bit representation of float16</param>
185
+ /// <returns>new instance of Float16_t</returns>
186
+ constexpr static Float16_t FromBits(uint16_t v) noexcept { return Float16_t(v); }
187
+
188
+ /// <summary>
189
+ /// __ctor from float. Float is converted into float16 16-bit representation.
190
+ /// </summary>
191
+ /// <param name="v">float value</param>
192
+ explicit Float16_t(float v) noexcept { val = Base::ToUint16Impl(v); }
193
+
194
+ /// <summary>
195
+ /// Converts float16 to float
196
+ /// </summary>
197
+ /// <returns>float representation of float16 value</returns>
198
+ float ToFloat() const noexcept { return Base::ToFloatImpl(); }
199
+
200
+ /// <summary>
201
+ /// Checks if the value is negative
202
+ /// </summary>
203
+ /// <returns>true if negative</returns>
204
+ using Base::IsNegative;
205
+
206
+ /// <summary>
207
+ /// Tests if the value is NaN
208
+ /// </summary>
209
+ /// <returns>true if NaN</returns>
210
+ using Base::IsNaN;
211
+
212
+ /// <summary>
213
+ /// Tests if the value is finite
214
+ /// </summary>
215
+ /// <returns>true if finite</returns>
216
+ using Base::IsFinite;
217
+
218
+ /// <summary>
219
+ /// Tests if the value represents positive infinity.
220
+ /// </summary>
221
+ /// <returns>true if positive infinity</returns>
222
+ using Base::IsPositiveInfinity;
223
+
224
+ /// <summary>
225
+ /// Tests if the value represents negative infinity
226
+ /// </summary>
227
+ /// <returns>true if negative infinity</returns>
228
+ using Base::IsNegativeInfinity;
229
+
230
+ /// <summary>
231
+ /// Tests if the value is either positive or negative infinity.
232
+ /// </summary>
233
+ /// <returns>True if absolute value is infinity</returns>
234
+ using Base::IsInfinity;
235
+
236
+ /// <summary>
237
+ /// Tests if the value is NaN or zero. Useful for comparisons.
238
+ /// </summary>
239
+ /// <returns>True if NaN or zero.</returns>
240
+ using Base::IsNaNOrZero;
241
+
242
+ /// <summary>
243
+ /// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
244
+ /// </summary>
245
+ /// <returns>True if so</returns>
246
+ using Base::IsNormal;
247
+
248
+ /// <summary>
249
+ /// Tests if the value is subnormal (denormal).
250
+ /// </summary>
251
+ /// <returns>True if so</returns>
252
+ using Base::IsSubnormal;
253
+
254
+ /// <summary>
255
+ /// Creates an instance that represents absolute value.
256
+ /// </summary>
257
+ /// <returns>Absolute value</returns>
258
+ using Base::Abs;
259
+
260
+ /// <summary>
261
+ /// Creates a new instance with the sign flipped.
262
+ /// </summary>
263
+ /// <returns>Flipped sign instance</returns>
264
+ using Base::Negate;
265
+
266
+ /// <summary>
267
+ /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
268
+ /// for two values by or'ing the private bits together and stripping the sign. They are both zero,
269
+ /// and therefore equivalent, if the resulting value is still zero.
270
+ /// </summary>
271
+ /// <param name="lhs">first value</param>
272
+ /// <param name="rhs">second value</param>
273
+ /// <returns>True if both arguments represent zero</returns>
274
+ using Base::AreZero;
275
+
276
+ /// <summary>
277
+ /// User defined conversion operator. Converts Float16_t to float.
278
+ /// </summary>
279
+ explicit operator float() const noexcept { return ToFloat(); }
280
+
281
+ using Base::operator==;
282
+ using Base::operator!=;
283
+ using Base::operator<;
284
+ };
285
+
286
+ static_assert(sizeof(Float16_t) == sizeof(uint16_t), "Sizes must match");
287
+
288
+ /** \brief bfloat16 (Brain Floating Point) data type
289
+ *
290
+ * \details This struct is used for converting float to bfloat16 and back
291
+ * so the user can feed inputs and fetch outputs using this type.
292
+ *
293
+ * The size of the structure should align with uint16_t and one can freely cast
294
+ * uint16_t buffers to/from Ort::BFloat16_t to feed and retrieve data.
295
+ *
296
+ * \code{.unparsed}
297
+ * // This example demonstrates conversion from float to bfloat16
298
+ * constexpr float values[] = {1.f, 2.f, 3.f, 4.f, 5.f};
299
+ * std::vector<Ort::BFloat16_t> bfp16_values;
300
+ * bfp16_values.reserve(std::size(values));
301
+ * std::transform(std::begin(values), std::end(values), std::back_inserter(bfp16_values),
302
+ * [](float value) { return Ort::BFloat16_t(value); });
303
+ *
304
+ * \endcode
305
+ */
306
+ struct BFloat16_t : onnxruntime_float16::BFloat16Impl<BFloat16_t> {
307
+ private:
308
+ /// <summary>
309
+ /// Constructor from a uint16_t representation of bfloat16
310
+ /// used in FromBits() to escape overload resolution issue with
311
+ /// constructor from float.
312
+ /// No conversion is done.
313
+ /// </summary>
314
+ /// <param name="v">16-bit bfloat16 value</param>
315
+ constexpr explicit BFloat16_t(uint16_t v) noexcept { val = v; }
316
+
317
+ public:
318
+ using Base = onnxruntime_float16::BFloat16Impl<BFloat16_t>;
319
+
320
+ BFloat16_t() = default;
321
+
322
+ /// <summary>
323
+ /// Explicit construction from the uint16_t bit representation of bfloat16. The bits are stored as-is, no conversion is done.
324
+ /// </summary>
325
+ /// <param name="v">uint16_t bit representation of bfloat16</param>
326
+ /// <returns>new instance of BFloat16_t</returns>
327
+ static constexpr BFloat16_t FromBits(uint16_t v) noexcept { return BFloat16_t(v); }
328
+
329
+ /// <summary>
330
+ /// __ctor from float. Float is converted into bfloat16 16-bit representation.
331
+ /// </summary>
332
+ /// <param name="v">float value</param>
333
+ explicit BFloat16_t(float v) noexcept { val = Base::ToUint16Impl(v); }
334
+
335
+ /// <summary>
336
+ /// Converts bfloat16 to float
337
+ /// </summary>
338
+ /// <returns>float representation of bfloat16 value</returns>
339
+ float ToFloat() const noexcept { return Base::ToFloatImpl(); }
340
+
341
+ /// <summary>
342
+ /// Checks if the value is negative
343
+ /// </summary>
344
+ /// <returns>true if negative</returns>
345
+ using Base::IsNegative;
346
+
347
+ /// <summary>
348
+ /// Tests if the value is NaN
349
+ /// </summary>
350
+ /// <returns>true if NaN</returns>
351
+ using Base::IsNaN;
352
+
353
+ /// <summary>
354
+ /// Tests if the value is finite
355
+ /// </summary>
356
+ /// <returns>true if finite</returns>
357
+ using Base::IsFinite;
358
+
359
+ /// <summary>
360
+ /// Tests if the value represents positive infinity.
361
+ /// </summary>
362
+ /// <returns>true if positive infinity</returns>
363
+ using Base::IsPositiveInfinity;
364
+
365
+ /// <summary>
366
+ /// Tests if the value represents negative infinity
367
+ /// </summary>
368
+ /// <returns>true if negative infinity</returns>
369
+ using Base::IsNegativeInfinity;
370
+
371
+ /// <summary>
372
+ /// Tests if the value is either positive or negative infinity.
373
+ /// </summary>
374
+ /// <returns>True if absolute value is infinity</returns>
375
+ using Base::IsInfinity;
376
+
377
+ /// <summary>
378
+ /// Tests if the value is NaN or zero. Useful for comparisons.
379
+ /// </summary>
380
+ /// <returns>True if NaN or zero.</returns>
381
+ using Base::IsNaNOrZero;
382
+
383
+ /// <summary>
384
+ /// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
385
+ /// </summary>
386
+ /// <returns>True if so</returns>
387
+ using Base::IsNormal;
388
+
389
+ /// <summary>
390
+ /// Tests if the value is subnormal (denormal).
391
+ /// </summary>
392
+ /// <returns>True if so</returns>
393
+ using Base::IsSubnormal;
394
+
395
+ /// <summary>
396
+ /// Creates an instance that represents absolute value.
397
+ /// </summary>
398
+ /// <returns>Absolute value</returns>
399
+ using Base::Abs;
400
+
401
+ /// <summary>
402
+ /// Creates a new instance with the sign flipped.
403
+ /// </summary>
404
+ /// <returns>Flipped sign instance</returns>
405
+ using Base::Negate;
406
+
407
+ /// <summary>
408
+ /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
409
+ /// for two values by or'ing the private bits together and stripping the sign. They are both zero,
410
+ /// and therefore equivalent, if the resulting value is still zero.
411
+ /// </summary>
412
+ /// <param name="lhs">first value</param>
413
+ /// <param name="rhs">second value</param>
414
+ /// <returns>True if both arguments represent zero</returns>
415
+ using Base::AreZero;
416
+
417
+ /// <summary>
418
+ /// User defined conversion operator. Converts BFloat16_t to float.
419
+ /// </summary>
420
+ explicit operator float() const noexcept { return ToFloat(); }
421
+
422
+ // We do not have an inherited impl for the below operators
423
+ // as the internal class implements them a little differently
424
+ bool operator==(const BFloat16_t& rhs) const noexcept;
425
+ bool operator!=(const BFloat16_t& rhs) const noexcept { return !(*this == rhs); }
426
+ bool operator<(const BFloat16_t& rhs) const noexcept;
427
+ };
428
+
429
+ static_assert(sizeof(BFloat16_t) == sizeof(uint16_t), "Sizes must match");
430
+
431
+ /** \brief float8e4m3fn (Float8 Floating Point) data type
432
+ * \details A distinct type is needed so that the C++ API can dispatch on it; it only stores the raw 8-bit encoding.
433
+ * The type is implicitly convertible to/from uint8_t.
434
+ * See https://onnx.ai/onnx/technical/float8.html for further details.
435
+ */
436
+ struct Float8E4M3FN_t {
437
+ uint8_t value;  // raw 8-bit encoding, stored as given and never interpreted by this struct
438
+ constexpr Float8E4M3FN_t() noexcept : value(0) {}  // all bits zero
439
+ constexpr Float8E4M3FN_t(uint8_t v) noexcept : value(v) {}  // deliberately implicit: adopts the given bits unchanged
440
+ constexpr operator uint8_t() const noexcept { return value; }  // implicit conversion back to the raw bits
441
+ // NaN encodings get no special treatment: operator== and operator!= compare the raw byte only
442
+ constexpr bool operator==(const Float8E4M3FN_t& rhs) const noexcept { return value == rhs.value; };
443
+ constexpr bool operator!=(const Float8E4M3FN_t& rhs) const noexcept { return value != rhs.value; };
444
+ };
445
+
446
+ static_assert(sizeof(Float8E4M3FN_t) == sizeof(uint8_t), "Sizes must match");
447
+
448
+ /** \brief float8e4m3fnuz (Float8 Floating Point) data type
449
+ * \details A distinct type is needed so that the C++ API can dispatch on it; it only stores the raw 8-bit encoding.
450
+ * The type is implicitly convertible to/from uint8_t.
451
+ * See https://onnx.ai/onnx/technical/float8.html for further details.
452
+ */
453
+ struct Float8E4M3FNUZ_t {
454
+ uint8_t value;  // raw 8-bit encoding, stored as given and never interpreted by this struct
455
+ constexpr Float8E4M3FNUZ_t() noexcept : value(0) {}  // all bits zero
456
+ constexpr Float8E4M3FNUZ_t(uint8_t v) noexcept : value(v) {}  // deliberately implicit: adopts the given bits unchanged
457
+ constexpr operator uint8_t() const noexcept { return value; }  // implicit conversion back to the raw bits
458
+ // NaN encodings get no special treatment: operator== and operator!= compare the raw byte only
459
+ constexpr bool operator==(const Float8E4M3FNUZ_t& rhs) const noexcept { return value == rhs.value; };
460
+ constexpr bool operator!=(const Float8E4M3FNUZ_t& rhs) const noexcept { return value != rhs.value; };
461
+ };
462
+
463
+ static_assert(sizeof(Float8E4M3FNUZ_t) == sizeof(uint8_t), "Sizes must match");
464
+
465
+ /** \brief float8e5m2 (Float8 Floating Point) data type
466
+ * \details A distinct type is needed so that the C++ API can dispatch on it; it only stores the raw 8-bit encoding.
467
+ * The type is implicitly convertible to/from uint8_t.
468
+ * See https://onnx.ai/onnx/technical/float8.html for further details.
469
+ */
470
+ struct Float8E5M2_t {
471
+ uint8_t value;  // raw 8-bit encoding, stored as given and never interpreted by this struct
472
+ constexpr Float8E5M2_t() noexcept : value(0) {}  // all bits zero
473
+ constexpr Float8E5M2_t(uint8_t v) noexcept : value(v) {}  // deliberately implicit: adopts the given bits unchanged
474
+ constexpr operator uint8_t() const noexcept { return value; }  // implicit conversion back to the raw bits
475
+ // NaN encodings get no special treatment: operator== and operator!= compare the raw byte only
476
+ constexpr bool operator==(const Float8E5M2_t& rhs) const noexcept { return value == rhs.value; };
477
+ constexpr bool operator!=(const Float8E5M2_t& rhs) const noexcept { return value != rhs.value; };
478
+ };
479
+
480
+ static_assert(sizeof(Float8E5M2_t) == sizeof(uint8_t), "Sizes must match");
481
+
482
+ /** \brief float8e5m2fnuz (Float8 Floating Point) data type
483
+ * \details A distinct type is needed so that the C++ API can dispatch on it; it only stores the raw 8-bit encoding.
484
+ * The type is implicitly convertible to/from uint8_t.
485
+ * See https://onnx.ai/onnx/technical/float8.html for further details.
486
+ */
487
+ struct Float8E5M2FNUZ_t {
488
+ uint8_t value;  // raw 8-bit encoding, stored as given and never interpreted by this struct
489
+ constexpr Float8E5M2FNUZ_t() noexcept : value(0) {}  // all bits zero
490
+ constexpr Float8E5M2FNUZ_t(uint8_t v) noexcept : value(v) {}  // deliberately implicit: adopts the given bits unchanged
491
+ constexpr operator uint8_t() const noexcept { return value; }  // implicit conversion back to the raw bits
492
+ // NaN encodings get no special treatment: operator== and operator!= compare the raw byte only
493
+ constexpr bool operator==(const Float8E5M2FNUZ_t& rhs) const noexcept { return value == rhs.value; };
494
+ constexpr bool operator!=(const Float8E5M2FNUZ_t& rhs) const noexcept { return value != rhs.value; };
495
+ };
496
+
497
+ static_assert(sizeof(Float8E5M2FNUZ_t) == sizeof(uint8_t), "Sizes must match");
498
+
499
+ namespace detail {
500
+ // This is used internally by the C++ API. This macro is to make it easy to generate overloaded methods for all of the various OrtRelease* functions for every Ort* type
501
+ // This can't be done in the C API since C doesn't have function overloading.
502
+ #define ORT_DEFINE_RELEASE(NAME) \
503
+ inline void OrtRelease(Ort##NAME* ptr) { GetApi().Release##NAME(ptr); }
504
+
505
+ ORT_DEFINE_RELEASE(Allocator);
506
+ ORT_DEFINE_RELEASE(MemoryInfo);
507
+ ORT_DEFINE_RELEASE(CustomOpDomain);
508
+ ORT_DEFINE_RELEASE(ThreadingOptions);
509
+ ORT_DEFINE_RELEASE(Env);
510
+ ORT_DEFINE_RELEASE(RunOptions);
511
+ ORT_DEFINE_RELEASE(Session);
512
+ ORT_DEFINE_RELEASE(SessionOptions);
513
+ ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo);
514
+ ORT_DEFINE_RELEASE(SequenceTypeInfo);
515
+ ORT_DEFINE_RELEASE(MapTypeInfo);
516
+ ORT_DEFINE_RELEASE(TypeInfo);
517
+ ORT_DEFINE_RELEASE(Value);
518
+ ORT_DEFINE_RELEASE(ModelMetadata);
519
+ ORT_DEFINE_RELEASE(IoBinding);
520
+ ORT_DEFINE_RELEASE(ArenaCfg);
521
+ ORT_DEFINE_RELEASE(Status);
522
+ ORT_DEFINE_RELEASE(OpAttr);
523
+ ORT_DEFINE_RELEASE(Op);
524
+ ORT_DEFINE_RELEASE(KernelInfo);
525
+
526
+ #undef ORT_DEFINE_RELEASE
527
+
528
+ /** \brief This is a tagging template type. Use it with Base<T> to indicate that the C++ interface object
529
+ * has no ownership of the underlying C object.
530
+ */
531
+ template <typename T>
532
+ struct Unowned {
533
+ using Type = T;
534
+ };
535
+
536
+ /** \brief Used internally by the C++ API. C++ wrapper types inherit from this.
537
+ * This is a zero cost abstraction to wrap the C API objects and delete them on destruction.
538
+ *
539
+ * All of the C++ classes
540
+ * a) serve as containers for pointers to objects that are created by the underlying C API.
541
+ * Their size is just a pointer size, no need to dynamically allocate them. Use them by value.
542
+ * b) Instances of each struct XXXX function as smart pointers to the underlying C API objects;
543
+ * they automatically release the objects they own when going out of scope, and they are move-only.
544
+ * c) ConstXXXX and UnownedXXX structs function as non-owning, copyable containers for the above pointers.
545
+ * ConstXXXX allow calling const interfaces only. They give access to objects that are owned by somebody else
546
+ * such as Onnxruntime or instances of XXXX classes.
547
+ * d) serve as convenient interfaces that return C++ objects and further enhance exception and type safety so they can be used
548
+ * in C++ code.
549
+ *
550
+ */
551
+
552
+ /// <summary>
553
+ /// Move-only holder of a non-const pointer to a C API object. The destructor and move assignment pass the held pointer to OrtRelease.
554
+ /// </summary>
555
+ template <typename T>
556
+ struct Base {
557
+ using contained_type = T;
558
+
559
+ constexpr Base() = default;
560
+ constexpr explicit Base(contained_type* p) noexcept : p_{p} {}  // adopts p; it will be passed to OrtRelease on destruction
561
+ ~Base() { OrtRelease(p_); }  // p_ can be nullptr here (empty or moved-from) -- presumably the Release* functions accept null; TODO confirm
562
+
563
+ Base(const Base&) = delete;
564
+ Base& operator=(const Base&) = delete;
565
+
566
+ Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }  // steals the pointer and leaves v empty
567
+ Base& operator=(Base&& v) noexcept {
568
+ OrtRelease(p_);  // drop whatever this instance currently holds
569
+ p_ = v.release();  // NOTE(review): no self-assignment guard -- for x = std::move(x) the pointer released above is re-adopted here
570
+ return *this;
571
+ }
572
+
573
+ constexpr operator contained_type*() const noexcept { return p_; }  // implicit conversion to the raw pointer; ownership stays with this object
574
+
575
+ /// \brief Relinquishes ownership of the contained C object pointer and returns it
576
+ /// The underlying object is not destroyed; this instance holds nullptr afterwards
577
+ contained_type* release() {
578
+ T* p = p_;
579
+ p_ = nullptr;
580
+ return p;
581
+ }
582
+
583
+ protected:
584
+ contained_type* p_{};  // held pointer; nullptr when the wrapper is empty
585
+ };
586
+
587
+ // Undefined. For const types use Base<Unowned<const T>>
588
+ template <typename T>
589
+ struct Base<const T>;
590
+
591
+ /// <summary>
592
+ /// Holds a pointer that this wrapper does not own: the object belongs either to ORT
593
+ /// or to some other instance of the C++ wrappers, and is never released here.
594
+ /// Used for ConstXXX and UnownedXXXX types that are copyable.
595
+ /// Also convenient to wrap raw OrtXX pointers.
596
+ /// </summary>
597
+ /// <typeparam name="T">Type of the referenced C API object (const-qualified for ConstXXX wrappers)</typeparam>
598
+ template <typename T>
599
+ struct Base<Unowned<T>> {
600
+ using contained_type = typename Unowned<T>::Type;
601
+
602
+ constexpr Base() = default;
603
+ constexpr explicit Base(contained_type* p) noexcept : p_{p} {}  // stores p without taking ownership
604
+
605
+ ~Base() = default;  // nothing to release: the pointer is not owned
606
+
607
+ Base(const Base&) = default;
608
+ Base& operator=(const Base&) = default;
609
+
610
+ Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }  // copies the pointer and leaves v empty
611
+ Base& operator=(Base&& v) noexcept {
612
+ p_ = nullptr;
613
+ std::swap(p_, v.p_);  // takes v's pointer and leaves v holding nullptr; a self-move leaves this object empty
614
+ return *this;
615
+ }
616
+
617
+ constexpr operator contained_type*() const noexcept { return p_; }  // implicit conversion to the raw pointer
618
+
619
+ protected:
620
+ contained_type* p_{};  // non-owning pointer; nullptr when empty
621
+ };
622
+
623
+ // Light functor that returns memory to an OrtAllocator; it is the deleter type of AllocatedStringPtr
624
+ struct AllocatedFree {
625
+ OrtAllocator* allocator_;  // not owned; must still be valid when operator() runs
626
+ explicit AllocatedFree(OrtAllocator* allocator)
627
+ : allocator_(allocator) {}
628
+ void operator()(void* ptr) const {
629
+ if (ptr) allocator_->Free(allocator_, ptr);  // null pointers are skipped, so Free is never called with nullptr
630
+ }
631
+ };
632
+
633
+ } // namespace detail
634
+
635
+ struct AllocatorWithDefaultOptions;
636
+ struct Env;
637
+ struct TypeInfo;
638
+ struct Value;
639
+ struct ModelMetadata;
640
+
641
+ /** \brief unique_ptr typedef used to own strings allocated by OrtAllocators
642
+ * and release them at the end of the scope. The lifespan of the given allocator
643
+ * must eclipse the lifespan of AllocatedStringPtr instance
644
+ */
645
+ using AllocatedStringPtr = std::unique_ptr<char, detail::AllocatedFree>;
646
+
647
+ /** \brief The Status that holds ownership of OrtStatus received from C API
648
+ * Use it to safely destroy OrtStatus* returned from the C API. Use appropriate
649
+ * constructors to construct an instance of a Status object from exceptions.
650
+ */
651
+ struct Status : detail::Base<OrtStatus> {
652
+ explicit Status(std::nullptr_t) noexcept {} ///< Create an empty object, must be assigned a valid one to be used
653
+ explicit Status(OrtStatus* status) noexcept; ///< Takes ownership of OrtStatus instance returned from the C API.
654
+ explicit Status(const Exception&) noexcept; ///< Creates status instance out of exception
655
+ explicit Status(const std::exception&) noexcept; ///< Creates status instance out of exception
656
+ Status(const char* message, OrtErrorCode code) noexcept; ///< Creates status instance out of null-terminated string message.
657
+ std::string GetErrorMessage() const;
658
+ OrtErrorCode GetErrorCode() const;
659
+ bool IsOK() const noexcept; ///< Returns true if instance represents an OK (non-error) status.
660
+ };
661
+
662
+ /** \brief The ThreadingOptions
663
+ *
664
+ * ThreadingOptions is used to set the options of the global thread pools of the Env.
665
+ */
666
+ struct ThreadingOptions : detail::Base<OrtThreadingOptions> {
667
+ /// \brief Wraps OrtApi::CreateThreadingOptions
668
+ ThreadingOptions();
669
+
670
+ /// \brief Wraps OrtApi::SetGlobalIntraOpNumThreads
671
+ ThreadingOptions& SetGlobalIntraOpNumThreads(int intra_op_num_threads);
672
+
673
+ /// \brief Wraps OrtApi::SetGlobalInterOpNumThreads
674
+ ThreadingOptions& SetGlobalInterOpNumThreads(int inter_op_num_threads);
675
+
676
+ /// \brief Wraps OrtApi::SetGlobalSpinControl
677
+ ThreadingOptions& SetGlobalSpinControl(int allow_spinning);
678
+
679
+ /// \brief Wraps OrtApi::SetGlobalDenormalAsZero
680
+ ThreadingOptions& SetGlobalDenormalAsZero();
681
+
682
+ /// \brief Wraps OrtApi::SetGlobalCustomCreateThreadFn
683
+ ThreadingOptions& SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn);
684
+
685
+ /// \brief Wraps OrtApi::SetGlobalCustomThreadCreationOptions
686
+ ThreadingOptions& SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options);
687
+
688
+ /// \brief Wraps OrtApi::SetGlobalCustomJoinThreadFn
689
+ ThreadingOptions& SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn);
690
+ };
691
+
692
+ /** \brief The Env (Environment)
693
+ *
694
+ * The Env holds the logging state used by all other objects.
695
+ * <b>Note:</b> One Env must be created before using any other Onnxruntime functionality.
696
+ */
697
+ struct Env : detail::Base<OrtEnv> {
698
+ explicit Env(std::nullptr_t) {} ///< Creates an empty Env object; a valid one must be assigned before it is used.
699
+
700
+ /// \brief Wraps OrtApi::CreateEnv
701
+ Env(OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
702
+
703
+ /// \brief Wraps OrtApi::CreateEnvWithCustomLogger
704
+ Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param);
705
+
706
+ /// \brief Wraps OrtApi::CreateEnvWithGlobalThreadPools
707
+ Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
708
+
709
+ /// \brief Wraps OrtApi::CreateEnvWithCustomLoggerAndGlobalThreadPools
710
+ Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param,
711
+ OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
712
+
713
+ /// \brief C interop helper: constructs an Env from an existing OrtEnv pointer
714
+ explicit Env(OrtEnv* p) : Base<OrtEnv>{p} {}
715
+
716
+ Env& EnableTelemetryEvents(); ///< Wraps OrtApi::EnableTelemetryEvents
717
+ Env& DisableTelemetryEvents(); ///< Wraps OrtApi::DisableTelemetryEvents
718
+
719
+ Env& UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level); ///< Wraps OrtApi::UpdateEnvWithCustomLogLevel
720
+
721
+ Env& CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocator
722
+
723
+ Env& CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo* mem_info, const std::unordered_map<std::string, std::string>& options, const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocatorV2
724
+ };
725
+
726
+ /** \brief Custom Op Domain
727
+ *
728
+ */
729
+ struct CustomOpDomain : detail::Base<OrtCustomOpDomain> {
730
+ explicit CustomOpDomain(std::nullptr_t) {} ///< Creates an empty CustomOpDomain object; a valid one must be assigned before it is used.
731
+
732
+ /// \brief Wraps OrtApi::CreateCustomOpDomain
733
+ explicit CustomOpDomain(const char* domain);
734
+
735
+ // This does not take ownership of the op; it simply registers it.
736
+ void Add(const OrtCustomOp* op); ///< Wraps OrtApi::CustomOpDomain_Add
737
+ };
738
+
739
+ /** \brief RunOptions
740
+ *
741
+ */
742
+ struct RunOptions : detail::Base<OrtRunOptions> {
743
+ explicit RunOptions(std::nullptr_t) {} ///< Creates an empty RunOptions object; a valid one must be assigned before it is used.
744
+ RunOptions(); ///< Wraps OrtApi::CreateRunOptions
745
+
746
+ RunOptions& SetRunLogVerbosityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogVerbosityLevel
747
+ int GetRunLogVerbosityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogVerbosityLevel
748
+
749
+ RunOptions& SetRunLogSeverityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogSeverityLevel
750
+ int GetRunLogSeverityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogSeverityLevel
751
+
752
+ RunOptions& SetRunTag(const char* run_tag); ///< Wraps OrtApi::RunOptionsSetRunTag
753
+ const char* GetRunTag() const; ///< Wraps OrtApi::RunOptionsGetRunTag
754
+
755
+ RunOptions& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddRunConfigEntry
756
+
757
+ /** \brief Terminates all currently executing Session::Run calls that were made using this RunOptions instance
758
+ *
759
+ * If a currently executing session needs to be force terminated, this can be called from another thread to force it to fail with an error.
760
+ * Wraps OrtApi::RunOptionsSetTerminate
761
+ */
762
+ RunOptions& SetTerminate();
763
+
764
+ /** \brief Clears the terminate flag so this RunOptions instance can be used in a new Session::Run call without it instantly terminating
765
+ *
766
+ * Wraps OrtApi::RunOptionsUnsetTerminate
767
+ */
768
+ RunOptions& UnsetTerminate();
769
+ };
770
+
771
+ namespace detail {
772
+ // Utility function that returns the SessionOptions config entry key for one config of a specific custom operator.
773
+ // Example key format: custom_op.[custom_op_name].[config]
774
+ std::string MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config);
775
+ } // namespace detail
776
+
777
+ /// <summary>
778
+ /// Class that represents session configuration entries for one or more custom operators.
779
+ ///
780
+ /// Example:
781
+ /// Ort::CustomOpConfigs op_configs;
782
+ /// op_configs.AddConfig("my_custom_op", "device_type", "CPU");
783
+ ///
784
+ /// Passed to Ort::SessionOptions::RegisterCustomOpsLibrary.
785
+ /// </summary>
786
+ struct CustomOpConfigs {
787
+ CustomOpConfigs() = default;
788
+ ~CustomOpConfigs() = default;
789
+ CustomOpConfigs(const CustomOpConfigs&) = default;
790
+ CustomOpConfigs& operator=(const CustomOpConfigs&) = default;
791
+ CustomOpConfigs(CustomOpConfigs&& o) = default;
792
+ CustomOpConfigs& operator=(CustomOpConfigs&& o) = default;
793
+
794
+ /** \brief Adds a session configuration entry/value for a specific custom operator.
795
+ *
796
+ * \param custom_op_name The name of the custom operator for which to add a configuration entry.
797
+ * Must match the name returned by the CustomOp's GetName() method.
798
+ * \param config_key The name of the configuration entry.
799
+ * \param config_value The value of the configuration entry.
800
+ * \return A reference to this object to enable call chaining.
801
+ */
802
+ CustomOpConfigs& AddConfig(const char* custom_op_name, const char* config_key, const char* config_value);
803
+
804
+ /** \brief Returns a flattened map of custom operator configuration entries and their values.
805
+ *
806
+ * The keys have been flattened to include both the custom operator name and the configuration entry key name.
807
+ * For example, a prior call to AddConfig("my_op", "key", "value") corresponds to the flattened key/value pair
808
+ * {"my_op.key", "value"}.
809
+ * NOTE(review): detail::MakeCustomOpConfigEntryKey documents the key format as custom_op.[custom_op_name].[config]; confirm whether the flattened key carries the "custom_op." prefix.
810
+ * \return An unordered map of flattened configurations.
811
+ */
812
+ const std::unordered_map<std::string, std::string>& GetFlattenedConfigs() const;
813
+
814
+ private:
815
+ std::unordered_map<std::string, std::string> flat_configs_;
816
+ };
817
+
818
+ /** \brief Options object used when creating a new Session object
819
+ *
820
+ * Wraps the ::OrtSessionOptions object and its methods.
821
+ */
822
+
823
+ struct SessionOptions; // Forward declaration: detail::ConstSessionOptionsImpl::Clone() returns SessionOptions by value.
824
+
825
+ namespace detail {
826
+ // We separate const-only methods because passing a const pointer to non-const methods
827
+ // is only discovered when inline methods are compiled, which is counter-intuitive.
828
+ template <typename T>
829
+ struct ConstSessionOptionsImpl : Base<T> {
830
+ using B = Base<T>;
831
+ using B::B;
832
+
833
+ SessionOptions Clone() const; ///< Creates and returns a copy of this SessionOptions object. Wraps OrtApi::CloneSessionOptions
834
+
835
+ std::string GetConfigEntry(const char* config_key) const; ///< Wraps OrtApi::GetSessionConfigEntry
836
+ bool HasConfigEntry(const char* config_key) const; ///< Wraps OrtApi::HasSessionConfigEntry
837
+ std::string GetConfigEntryOrDefault(const char* config_key, const std::string& def); ///< NOTE(review): not const-qualified, unlike the other accessors in this const-only struct; confirm whether that is intentional.
838
+ };
839
+
840
+ template <typename T>
841
+ struct SessionOptionsImpl : ConstSessionOptionsImpl<T> {
842
+ using B = ConstSessionOptionsImpl<T>;
843
+ using B::B;
844
+
845
+ SessionOptionsImpl& SetIntraOpNumThreads(int intra_op_num_threads); ///< Wraps OrtApi::SetIntraOpNumThreads
846
+ SessionOptionsImpl& SetInterOpNumThreads(int inter_op_num_threads); ///< Wraps OrtApi::SetInterOpNumThreads
847
+ SessionOptionsImpl& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); ///< Wraps OrtApi::SetSessionGraphOptimizationLevel
848
+ SessionOptionsImpl& SetDeterministicCompute(bool value); ///< Wraps OrtApi::SetDeterministicCompute
849
+
850
+ SessionOptionsImpl& EnableCpuMemArena(); ///< Wraps OrtApi::EnableCpuMemArena
851
+ SessionOptionsImpl& DisableCpuMemArena(); ///< Wraps OrtApi::DisableCpuMemArena
852
+
853
+ SessionOptionsImpl& SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_file); ///< Wraps OrtApi::SetOptimizedModelFilePath
854
+
855
+ SessionOptionsImpl& EnableProfiling(const ORTCHAR_T* profile_file_prefix); ///< Wraps OrtApi::EnableProfiling
856
+ SessionOptionsImpl& DisableProfiling(); ///< Wraps OrtApi::DisableProfiling
857
+
858
+ SessionOptionsImpl& EnableOrtCustomOps(); ///< Wraps OrtApi::EnableOrtCustomOps
859
+
860
+ SessionOptionsImpl& EnableMemPattern(); ///< Wraps OrtApi::EnableMemPattern
861
+ SessionOptionsImpl& DisableMemPattern(); ///< Wraps OrtApi::DisableMemPattern
862
+
863
+ SessionOptionsImpl& SetExecutionMode(ExecutionMode execution_mode); ///< Wraps OrtApi::SetSessionExecutionMode
864
+
865
+ SessionOptionsImpl& SetLogId(const char* logid); ///< Wraps OrtApi::SetSessionLogId
866
+ SessionOptionsImpl& SetLogSeverityLevel(int level); ///< Wraps OrtApi::SetSessionLogSeverityLevel
867
+
868
+ SessionOptionsImpl& Add(OrtCustomOpDomain* custom_op_domain); ///< Wraps OrtApi::AddCustomOpDomain
869
+
870
+ SessionOptionsImpl& DisablePerSessionThreads(); ///< Wraps OrtApi::DisablePerSessionThreads
871
+
872
+ SessionOptionsImpl& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddSessionConfigEntry
873
+
874
+ SessionOptionsImpl& AddInitializer(const char* name, const OrtValue* ort_val); ///< Wraps OrtApi::AddInitializer
875
+ SessionOptionsImpl& AddExternalInitializers(const std::vector<std::string>& names, const std::vector<Value>& ort_values); ///< Wraps OrtApi::AddExternalInitializers
876
+
877
+ SessionOptionsImpl& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA
878
+ SessionOptionsImpl& AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA_V2
879
+ SessionOptionsImpl& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM
880
+ SessionOptionsImpl& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO
881
+ /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO_V2
882
+ SessionOptionsImpl& AppendExecutionProvider_OpenVINO_V2(const std::unordered_map<std::string, std::string>& provider_options = {});
883
+ SessionOptionsImpl& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT
884
+ SessionOptionsImpl& AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT_V2
885
+ SessionOptionsImpl& AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX
886
+ /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_CANN
887
+ SessionOptionsImpl& AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options);
888
+ /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_Dnnl
889
+ SessionOptionsImpl& AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options);
890
+ /// Wraps OrtApi::SessionOptionsAppendExecutionProvider. Currently supports QNN, SNPE and XNNPACK.
891
+ SessionOptionsImpl& AppendExecutionProvider(const std::string& provider_name,
892
+ const std::unordered_map<std::string, std::string>& provider_options = {});
893
+
894
+ SessionOptionsImpl& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn
895
+ SessionOptionsImpl& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options); ///< Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions
896
+ SessionOptionsImpl& SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomJoinThreadFn
897
+
898
+ /// Registers the custom operator from the specified shared library via OrtApi::RegisterCustomOpsLibrary_V2.
899
+ /// The custom operator configurations are optional. If provided, custom operator configs are set via
900
+ /// OrtApi::AddSessionConfigEntry.
901
+ SessionOptionsImpl& RegisterCustomOpsLibrary(const ORTCHAR_T* library_name, const CustomOpConfigs& custom_op_configs = {});
902
+
903
+ SessionOptionsImpl& RegisterCustomOpsUsingFunction(const char* function_name); ///< Wraps OrtApi::RegisterCustomOpsUsingFunction
904
+ };
905
+ } // namespace detail
906
+
907
+ using UnownedSessionOptions = detail::SessionOptionsImpl<detail::Unowned<OrtSessionOptions>>;
908
+ using ConstSessionOptions = detail::ConstSessionOptionsImpl<detail::Unowned<const OrtSessionOptions>>;
909
+
910
+ /** \brief Wrapper around ::OrtSessionOptions
911
+ *
912
+ */
913
+ struct SessionOptions : detail::SessionOptionsImpl<OrtSessionOptions> {
914
+ explicit SessionOptions(std::nullptr_t) {} ///< Creates an empty SessionOptions object; a valid one must be assigned before it is used.
915
+ SessionOptions(); ///< Wraps OrtApi::CreateSessionOptions
916
+ explicit SessionOptions(OrtSessionOptions* p) : SessionOptionsImpl<OrtSessionOptions>{p} {} ///< Used for interop with the C API
917
+ UnownedSessionOptions GetUnowned() const { return UnownedSessionOptions{this->p_}; } ///< Returns an UnownedSessionOptions wrapping this object's pointer.
918
+ ConstSessionOptions GetConst() const { return ConstSessionOptions{this->p_}; } ///< Returns a ConstSessionOptions wrapping this object's pointer.
919
+ };
920
+
921
+ /** \brief Wrapper around ::OrtModelMetadata
922
+ *
923
+ */
924
+ struct ModelMetadata : detail::Base<OrtModelMetadata> {
925
+ explicit ModelMetadata(std::nullptr_t) {} ///< Creates an empty ModelMetadata object; a valid one must be assigned before it is used.
926
+ explicit ModelMetadata(OrtModelMetadata* p) : Base<OrtModelMetadata>{p} {} ///< Used for interop with the C API
927
+
928
+ /** \brief Returns a copy of the producer name.
929
+ *
930
+ * \param allocator to allocate memory for the copy of the name returned
931
+ * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
932
+ * The OrtAllocator instance must be valid at the point of memory release.
933
+ */
934
+ AllocatedStringPtr GetProducerNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetProducerName
935
+
936
+ /** \brief Returns a copy of the graph name.
937
+ *
938
+ * \param allocator to allocate memory for the copy of the name returned
939
+ * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
940
+ * The OrtAllocator instance must be valid at the point of memory release.
941
+ */
942
+ AllocatedStringPtr GetGraphNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphName
943
+
944
+ /** \brief Returns a copy of the domain name.
945
+ *
946
+ * \param allocator to allocate memory for the copy of the name returned
947
+ * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
948
+ * The OrtAllocator instance must be valid at the point of memory release.
949
+ */
950
+ AllocatedStringPtr GetDomainAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDomain
951
+
952
+ /** \brief Returns a copy of the description.
953
+ *
954
+ * \param allocator to allocate memory for the copy of the string returned
955
+ * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
956
+ * The OrtAllocator instance must be valid at the point of memory release.
957
+ */
958
+ AllocatedStringPtr GetDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDescription
959
+
960
+ /** \brief Returns a copy of the graph description.
961
+ *
962
+ * \param allocator to allocate memory for the copy of the string returned
963
+ * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
964
+ * The OrtAllocator instance must be valid at the point of memory release.
965
+ */
966
+ AllocatedStringPtr GetGraphDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphDescription
967
+
968
+ /** \brief Returns a vector of copies of the custom metadata keys.
969
+ *
970
+ * \param allocator to allocate memory for the copy of the string returned
971
+ * \return a std::vector of smart pointers that deallocate the buffers when they go out of scope.
972
+ * The OrtAllocator instance must be valid at the point of memory release.
973
+ */
974
+ std::vector<AllocatedStringPtr> GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetCustomMetadataMapKeys
975
+
976
+ /** \brief Looks up a value by a key in the Custom Metadata map
977
+ *
978
+ * \param key zero terminated string key to lookup
979
+ * \param allocator to allocate memory for the copy of the string returned
980
+ * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
981
+ * May be nullptr if the key is not found.
982
+ *
983
+ * The OrtAllocator instance must be valid at the point of memory release.
984
+ */
985
+ AllocatedStringPtr LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataLookupCustomMetadataMap
986
+
987
+ int64_t GetVersion() const; ///< Wraps OrtApi::ModelMetadataGetVersion
988
+ };
989
+
990
+ struct IoBinding;
991
+
992
+ namespace detail {
993
+
994
+ // We separate const-only methods because passing a const pointer to non-const methods
995
+ // is only discovered when inline methods are compiled, which is counter-intuitive.
996
+ template <typename T>
997
+ struct ConstSessionImpl : Base<T> {
998
+ using B = Base<T>;
999
+ using B::B;
1000
+
1001
+ size_t GetInputCount() const; ///< Returns the number of model inputs
1002
+ size_t GetOutputCount() const; ///< Returns the number of model outputs
1003
+ size_t GetOverridableInitializerCount() const; ///< Returns the number of inputs that have defaults that can be overridden
1004
+
1005
+ /** \brief Returns a copy of the input name at the specified index.
1006
+ *
1007
+ * \param index must be less than the value returned by GetInputCount()
1008
+ * \param allocator to allocate memory for the copy of the name returned
1009
+ * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
1010
+ * The OrtAllocator instance must be valid at the point of memory release.
1011
+ */
1012
+ AllocatedStringPtr GetInputNameAllocated(size_t index, OrtAllocator* allocator) const;
1013
+
1014
+ /** \brief Returns a copy of the output name at the specified index.
1015
+ *
1016
+ * \param index must be less than the value returned by GetOutputCount()
1017
+ * \param allocator to allocate memory for the copy of the name returned
1018
+ * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
1019
+ * The OrtAllocator instance must be valid at the point of memory release.
1020
+ */
1021
+ AllocatedStringPtr GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const;
1022
+
1023
+ /** \brief Returns a copy of the overridable initializer name at the specified index.
1024
+ *
1025
+ * \param index must be less than the value returned by GetOverridableInitializerCount()
1026
+ * \param allocator to allocate memory for the copy of the name returned
1027
+ * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
1028
+ * The OrtAllocator instance must be valid at the point of memory release.
1029
+ */
1030
+ AllocatedStringPtr GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetOverridableInitializerName
1031
+
1032
+ uint64_t GetProfilingStartTimeNs() const; ///< Wraps OrtApi::SessionGetProfilingStartTimeNs
1033
+ ModelMetadata GetModelMetadata() const; ///< Wraps OrtApi::SessionGetModelMetadata
1034
+
1035
+ TypeInfo GetInputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetInputTypeInfo
1036
+ TypeInfo GetOutputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOutputTypeInfo
1037
+ TypeInfo GetOverridableInitializerTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOverridableInitializerTypeInfo
1038
+ };
1039
+
1040
+ template <typename T>
1041
+ struct SessionImpl : ConstSessionImpl<T> {
1042
+ using B = ConstSessionImpl<T>;
1043
+ using B::B;
1044
+
1045
+ /** \brief Run the model returning results in an Ort allocated vector.
1046
+ *
1047
+ * Wraps OrtApi::Run
1048
+ *
1049
+ * The caller provides a list of inputs and a list of the desired outputs to return.
1050
+ *
1051
+ * See the output logs for more information on warnings/errors that occur while processing the model.
1052
+ * Common errors are.. (TODO)
1053
+ *
1054
+ * \param[in] run_options
1055
+ * \param[in] input_names Array of null terminated strings of length input_count that is the list of input names
1056
+ * \param[in] input_values Array of Value objects of length input_count that is the list of input values
1057
+ * \param[in] input_count Number of inputs (the size of the input_names & input_values arrays)
1058
+ * \param[in] output_names Array of C style strings of length output_count that is the list of output names
1059
+ * \param[in] output_count Number of outputs (the size of the output_names array)
1060
+ * \return A std::vector of Value objects that directly maps to the output_names array (e.g. output_names[0] is the first entry of the returned vector)
1061
+ */
1062
+ std::vector<Value> Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
1063
+ const char* const* output_names, size_t output_count);
1064
+
1065
+ /** \brief Run the model returning results in user provided outputs
1066
+ * Same as Run(const RunOptions&, const char* const*, const Value*, size_t, const char* const*, size_t)
1067
+ */
1068
+ void Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
1069
+ const char* const* output_names, Value* output_values, size_t output_count);
1070
+
1071
+ void Run(const RunOptions& run_options, const IoBinding&); ///< Wraps OrtApi::RunWithBinding
1072
+
1073
+ /** \brief Run the model asynchronously in a thread owned by intra op thread pool
1074
+ *
1075
+ * Wraps OrtApi::RunAsync
1076
+ *
1077
+ * \param[in] run_options
1078
+ * \param[in] input_names Array of null terminated UTF8 encoded strings of the input names
1079
+ * \param[in] input_values Array of Value objects of length input_count
1080
+ * \param[in] input_count Number of elements in the input_names and inputs arrays
1081
+ * \param[in] output_names Array of null terminated UTF8 encoded strings of the output names
1082
+ * \param[out] output_values Array of provided Values to be filled with outputs.
1083
+ * On calling RunAsync, output_values[i] could either be initialized by a null pointer or a preallocated OrtValue*.
1084
+ * Later, on invoking the callback, each output_values[i] that is null will be filled with an OrtValue* allocated by onnxruntime.
1085
+ * Then, an OrtValue** pointer will be cast from output_values and passed to the callback.
1086
+ * NOTE: it is the customer's duty to finally release output_values and each of its members,
1087
+ * regardless of whether the member (Ort::Value) is allocated by onnxruntime or preallocated by the customer.
1088
+ * \param[in] output_count Number of elements in the output_names and outputs array
1089
+ * \param[in] callback Callback function on model run completion
1090
+ * \param[in] user_data User data that is passed back to the callback
1091
+ */
1092
+ void RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
1093
+ const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data);
1094
+
1095
+ /** \brief End profiling and return a copy of the profiling file name.
1096
+ *
1097
+ * \param allocator to allocate memory for the copy of the string returned
1098
+ * \return an instance of a smart pointer that deallocates the buffer when it goes out of scope.
1099
+ * The OrtAllocator instance must be valid at the point of memory release.
1100
+ */
1101
+ AllocatedStringPtr EndProfilingAllocated(OrtAllocator* allocator); ///< Wraps OrtApi::SessionEndProfiling
1102
+ };
1103
+
1104
+ } // namespace detail
1105
+
1106
+ using ConstSession = detail::ConstSessionImpl<detail::Unowned<const OrtSession>>;
1107
+ using UnownedSession = detail::SessionImpl<detail::Unowned<OrtSession>>;
1108
+
1109
+ /** \brief Wrapper around ::OrtSession
1110
+ *
1111
+ */
1112
+ struct Session : detail::SessionImpl<OrtSession> {
1113
+ explicit Session(std::nullptr_t) {} ///< Creates an empty Session object; a valid one must be assigned before it is used.
1114
+ Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); ///< Wraps OrtApi::CreateSession
1115
+ Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options,
1116
+ OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer
1117
+ Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); ///< Wraps OrtApi::CreateSessionFromArray
1118
+ Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options,
1119
+ OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionFromArrayWithPrepackedWeightsContainer
1120
+
1121
+ ConstSession GetConst() const { return ConstSession{this->p_}; } ///< Returns a ConstSession wrapping this session's pointer.
1122
+ UnownedSession GetUnowned() const { return UnownedSession{this->p_}; } ///< Returns an UnownedSession wrapping this session's pointer.
1123
+ };
1124
+
1125
+ namespace detail {
1126
+ template <typename T>
1127
+ struct MemoryInfoImpl : Base<T> {
1128
+ using B = Base<T>;
1129
+ using B::B;
1130
+
1131
+ std::string GetAllocatorName() const;
1132
+ OrtAllocatorType GetAllocatorType() const;
1133
+ int GetDeviceId() const;
1134
+ OrtMemoryInfoDeviceType GetDeviceType() const;
1135
+ OrtMemType GetMemoryType() const;
1136
+
1137
+ template <typename U>
1138
+ bool operator==(const MemoryInfoImpl<U>& o) const; ///< Templated on U so that a MemoryInfoImpl with a different holder type can be passed as the right-hand side.
1139
+ };
1140
+ } // namespace detail
1141
+
1142
+ // Const object holder that does not own the underlying object.
1143
+ using ConstMemoryInfo = detail::MemoryInfoImpl<detail::Unowned<const OrtMemoryInfo>>;
1144
+
1145
+ /** \brief Wrapper around ::OrtMemoryInfo
1146
+ *
1147
+ */
1148
+ struct MemoryInfo : detail::MemoryInfoImpl<OrtMemoryInfo> {
1149
+ static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1);
1150
+ explicit MemoryInfo(std::nullptr_t) {} ///< No instance is created
1151
+ explicit MemoryInfo(OrtMemoryInfo* p) : MemoryInfoImpl<OrtMemoryInfo>{p} {} ///< Takes ownership of a pointer created by the C API
1152
+ MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type);
1153
+ ConstMemoryInfo GetConst() const { return ConstMemoryInfo{this->p_}; } ///< Returns a ConstMemoryInfo wrapping this object's pointer.
1154
+ };
1155
+
1156
+ namespace detail {
1157
+ template <typename T>
1158
+ struct TensorTypeAndShapeInfoImpl : Base<T> {
1159
+ using B = Base<T>;
1160
+ using B::B;
1161
+
1162
+ ONNXTensorElementDataType GetElementType() const; ///< Wraps OrtApi::GetTensorElementType
1163
+ size_t GetElementCount() const; ///< Wraps OrtApi::GetTensorShapeElementCount
1164
+
1165
+ size_t GetDimensionsCount() const; ///< Wraps OrtApi::GetDimensionsCount
1166
+
1167
+ /** \deprecated Use GetShape(), which returns a std::vector, instead.
1168
+ * [[deprecated]]
1169
+ * This interface is unsafe to use.
1170
+ */
1171
+ [[deprecated("use GetShape()")]] void GetDimensions(int64_t* values, size_t values_count) const; ///< Wraps OrtApi::GetDimensions
1172
+
1173
+ void GetSymbolicDimensions(const char** values, size_t values_count) const; ///< Wraps OrtApi::GetSymbolicDimensions
1174
+
1175
+ std::vector<int64_t> GetShape() const; ///< Uses GetDimensionsCount & GetDimensions to return a std::vector of the shape
1176
+ };
1177
+
1178
+ } // namespace detail
1179
+
1180
+ using ConstTensorTypeAndShapeInfo = detail::TensorTypeAndShapeInfoImpl<detail::Unowned<const OrtTensorTypeAndShapeInfo>>;
1181
+
1182
+ /** \brief Wrapper around ::OrtTensorTypeAndShapeInfo
1183
+ *
1184
+ */
1185
+ struct TensorTypeAndShapeInfo : detail::TensorTypeAndShapeInfoImpl<OrtTensorTypeAndShapeInfo> {
1186
+ explicit TensorTypeAndShapeInfo(std::nullptr_t) {} ///< Creates an empty TensorTypeAndShapeInfo object; a valid one must be assigned before it is used.
1187
+ explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : TensorTypeAndShapeInfoImpl{p} {} ///< Used for interop with the C API
1188
+ ConstTensorTypeAndShapeInfo GetConst() const { return ConstTensorTypeAndShapeInfo{this->p_}; } ///< Returns a ConstTensorTypeAndShapeInfo wrapping this object's pointer.
1189
+ };
1190
+
1191
+ namespace detail {
1192
+ template <typename T>
1193
+ struct SequenceTypeInfoImpl : Base<T> {
1194
+ using B = Base<T>;
1195
+ using B::B;
1196
+ TypeInfo GetSequenceElementType() const; ///< Wraps OrtApi::GetSequenceElementType
1197
+ };
1198
+
1199
+ } // namespace detail
1200
+
1201
+ using ConstSequenceTypeInfo = detail::SequenceTypeInfoImpl<detail::Unowned<const OrtSequenceTypeInfo>>;
1202
+
1203
+ /** \brief Wrapper around ::OrtSequenceTypeInfo
1204
+ *
1205
+ */
1206
+ struct SequenceTypeInfo : detail::SequenceTypeInfoImpl<OrtSequenceTypeInfo> {
1207
+ explicit SequenceTypeInfo(std::nullptr_t) {} ///< Creates an empty SequenceTypeInfo object; a valid one must be assigned before it is used.
1208
+ explicit SequenceTypeInfo(OrtSequenceTypeInfo* p) : SequenceTypeInfoImpl<OrtSequenceTypeInfo>{p} {} ///< Used for interop with the C API
1209
+ ConstSequenceTypeInfo GetConst() const { return ConstSequenceTypeInfo{this->p_}; } ///< Returns a ConstSequenceTypeInfo wrapping this object's pointer.
1210
+ };
1211
+
1212
+ namespace detail {
1213
+ template <typename T>
1214
+ struct OptionalTypeInfoImpl : Base<T> {
1215
+ using B = Base<T>;
1216
+ using B::B;
1217
+ TypeInfo GetOptionalElementType() const; ///< Wraps OrtApi::CastOptionalTypeToContainedTypeInfo
1218
+ };
1219
+
1220
+ } // namespace detail
1221
+
1222
+ // A ConstOptionalTypeInfo is always owned by the TypeInfo and can only be obtained from it.
1223
+ using ConstOptionalTypeInfo = detail::OptionalTypeInfoImpl<detail::Unowned<const OrtOptionalTypeInfo>>;
1224
+
1225
+ namespace detail {
1226
+ template <typename T>
1227
+ struct MapTypeInfoImpl : detail::Base<T> {
1228
+ using B = Base<T>;
1229
+ using B::B;
1230
+ ONNXTensorElementDataType GetMapKeyType() const; ///< Wraps OrtApi::GetMapKeyType
1231
+ TypeInfo GetMapValueType() const; ///< Wraps OrtApi::GetMapValueType
1232
+ };
1233
+
1234
+ } // namespace detail
1235
+
1236
+ using ConstMapTypeInfo = detail::MapTypeInfoImpl<detail::Unowned<const OrtMapTypeInfo>>;
1237
+
1238
+ /** \brief Wrapper around ::OrtMapTypeInfo
1239
+ *
1240
+ */
1241
+ struct MapTypeInfo : detail::MapTypeInfoImpl<OrtMapTypeInfo> {
1242
+ explicit MapTypeInfo(std::nullptr_t) {} ///< Creates an empty MapTypeInfo object; a valid one must be assigned before it is used.
1243
+ explicit MapTypeInfo(OrtMapTypeInfo* p) : MapTypeInfoImpl<OrtMapTypeInfo>{p} {} ///< Used for interop with the C API
1244
+ ConstMapTypeInfo GetConst() const { return ConstMapTypeInfo{this->p_}; } ///< Returns a ConstMapTypeInfo wrapping this object's pointer.
1245
+ };
1246
+
1247
+ namespace detail {
1248
+ template <typename T>
1249
+ struct TypeInfoImpl : detail::Base<T> {
1250
+ using B = Base<T>;
1251
+ using B::B;
1252
+
1253
+ ConstTensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const; ///< Wraps OrtApi::CastTypeInfoToTensorInfo
1254
+ ConstSequenceTypeInfo GetSequenceTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToSequenceTypeInfo
1255
+ ConstMapTypeInfo GetMapTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToMapTypeInfo
1256
+ ConstOptionalTypeInfo GetOptionalTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToOptionalTypeInfo
1257
+
1258
+ ONNXType GetONNXType() const;
1259
+ };
1260
+ } // namespace detail
1261
+
1262
+ /// <summary>
1263
+ /// Contains a constant, unowned OrtTypeInfo that can be copied and passed around by value.
1264
+ /// Provides access to const OrtTypeInfo APIs.
1265
+ /// </summary>
1266
+ using ConstTypeInfo = detail::TypeInfoImpl<detail::Unowned<const OrtTypeInfo>>;
1267
+
1268
+ /// <summary>
1269
+ /// Type information that may contain either TensorTypeAndShapeInfo or
1270
+ /// the information about contained sequence or map depending on the ONNXType.
1271
+ /// </summary>
1272
+ struct TypeInfo : detail::TypeInfoImpl<OrtTypeInfo> {
1273
+ explicit TypeInfo(std::nullptr_t) {} ///< Create an empty TypeInfo object, must be assigned a valid one to be used
1274
+ explicit TypeInfo(OrtTypeInfo* p) : TypeInfoImpl<OrtTypeInfo>{p} {} ///< C API Interop
1275
+
1276
+ ConstTypeInfo GetConst() const { return ConstTypeInfo{this->p_}; }
1277
+ };
1278
+
1279
+ namespace detail {
1280
+ // This structure is used to feed sparse tensor values
1281
+ // information for use with FillSparseTensor<Format>() API
1282
+ // if the data type for the sparse tensor values is numeric
1283
+ // use data.p_data, otherwise, use data.str pointer to feed
1284
+ // values. data.str is an array of const char* that are zero terminated.
1285
+ // number of strings in the array must match shape size.
1286
+ // For fully sparse tensors use shape {0} and set p_data/str
1287
+ // to nullptr.
1288
+ struct OrtSparseValuesParam {
1289
+ const int64_t* values_shape;
1290
+ size_t values_shape_len;
1291
+ union {
1292
+ const void* p_data;
1293
+ const char** str;
1294
+ } data;
1295
+ };
1296
+
1297
+ // Provides a way to pass shape in a single
1298
+ // argument
1299
+ struct Shape {
1300
+ const int64_t* shape;
1301
+ size_t shape_len;
1302
+ };
1303
+
1304
+ template <typename T>
1305
+ struct ConstValueImpl : Base<T> {
1306
+ using B = Base<T>;
1307
+ using B::B;
1308
+
1309
+ /// <summary>
1310
+ /// Obtains a pointer to a user defined data for experimental purposes
1311
+ /// </summary>
1312
+ template <typename R>
1313
+ void GetOpaqueData(const char* domain, const char* type_name, R&) const; ///< Wraps OrtApi::GetOpaqueValue
1314
+
1315
+ bool IsTensor() const; ///< Returns true if Value is a tensor, false for other types like map/sequence/etc
1316
+ bool HasValue() const; ///< Returns true if the OrtValue contains data and false if the OrtValue is a None
1317
+
1318
+ size_t GetCount() const; // If a non tensor, returns 2 for map and N for sequence, where N is the number of elements
1319
+ Value GetValue(int index, OrtAllocator* allocator) const;
1320
+
1321
+ /// <summary>
1322
+ /// This API returns a full length of string data contained within either a tensor or a sparse Tensor.
1323
+ /// For sparse tensor it returns a full length of stored non-empty strings (values). The API is useful
1324
+ /// for allocating necessary memory and calling GetStringTensorContent().
1325
+ /// </summary>
1326
+ /// <returns>total length of UTF-8 encoded bytes contained. No zero terminators counted.</returns>
1327
+ size_t GetStringTensorDataLength() const;
1328
+
1329
+ /// <summary>
1330
+ /// The API copies all of the UTF-8 encoded string data contained within a tensor or a sparse tensor
1331
+ /// into a supplied buffer. Use GetStringTensorDataLength() to find out the length of the buffer to allocate.
1332
+ /// The user must also allocate offsets buffer with the number of entries equal to that of the contained
1333
+ /// strings.
1334
+ ///
1335
+ /// Strings are always assumed to be on CPU, no X-device copy.
1336
+ /// </summary>
1337
+ /// <param name="buffer">user allocated buffer</param>
1338
+ /// <param name="buffer_length">length in bytes of the allocated buffer</param>
1339
+ /// <param name="offsets">a pointer to the offsets user allocated buffer</param>
1340
+ /// <param name="offsets_count">count of offsets, must be equal to the number of strings contained.
1341
+ /// that can be obtained from the shape of the tensor or from GetSparseTensorValuesTypeAndShapeInfo()
1342
+ /// for sparse tensors</param>
1343
+ void GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const;
1344
+
1345
+ /// <summary>
1346
+ /// Returns a const typed pointer to the tensor contained data.
1347
+ /// No type checking is performed, the caller must ensure the type matches the tensor type.
1348
+ /// </summary>
1349
+ /// <typeparam name="R">element type of the returned pointer</typeparam>
1350
+ /// <returns>const pointer to data, no copies made</returns>
1351
+ template <typename R>
1352
+ const R* GetTensorData() const; ///< Wraps OrtApi::GetTensorMutableData
1353
+
1354
+ /// <summary>
1355
+ /// Returns a non-typed pointer to a tensor contained data.
1356
+ /// </summary>
1357
+ /// <returns>const pointer to data, no copies made</returns>
1358
+ const void* GetTensorRawData() const;
1359
+
1360
+ /// <summary>
1361
+ /// The API returns type information for data contained in a tensor. For sparse
1362
+ /// tensors it returns type information for contained non-zero values.
1363
+ /// It returns dense shape for sparse tensors.
1364
+ /// </summary>
1365
+ /// <returns>TypeInfo</returns>
1366
+ TypeInfo GetTypeInfo() const;
1367
+
1368
+ /// <summary>
1369
+ /// The API returns type information for data contained in a tensor. For sparse
1370
+ /// tensors it returns type information for contained non-zero values.
1371
+ /// It returns dense shape for sparse tensors.
1372
+ /// </summary>
1373
+ /// <returns>TensorTypeAndShapeInfo</returns>
1374
+ TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const;
1375
+
1376
+ /// <summary>
1377
+ /// This API returns information about the memory allocation used to hold data.
1378
+ /// </summary>
1379
+ /// <returns>Non owning instance of MemoryInfo</returns>
1380
+ ConstMemoryInfo GetTensorMemoryInfo() const;
1381
+
1382
+ /// <summary>
1383
+ /// The API copies UTF-8 encoded bytes for the requested string element
1384
+ /// contained within a tensor or a sparse tensor into a provided buffer.
1385
+ /// Use GetStringTensorElementLength() to obtain the length of the buffer to allocate.
1386
+ /// </summary>
1387
+ /// <param name="buffer_length"></param>
1388
+ /// <param name="element_index"></param>
1389
+ /// <param name="buffer"></param>
1390
+ void GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const;
1391
+
1392
+ /// <summary>
1393
+ /// Returns string tensor UTF-8 encoded string element.
1394
+ /// Use of this API is recommended over GetStringTensorElement() that takes void* buffer pointer.
1395
+ /// </summary>
1396
+ /// <param name="element_index"></param>
1397
+ /// <returns>std::string</returns>
1398
+ std::string GetStringTensorElement(size_t element_index) const;
1399
+
1400
+ /// <summary>
1401
+ /// The API returns a byte length of UTF-8 encoded string element
1402
+ /// contained in either a tensor or a sparse tensor's values.
1403
+ /// </summary>
1404
+ /// <param name="element_index"></param>
1405
+ /// <returns>byte length for the specified string element</returns>
1406
+ size_t GetStringTensorElementLength(size_t element_index) const;
1407
+
1408
+ #if !defined(DISABLE_SPARSE_TENSORS)
1409
+ /// <summary>
1410
+ /// The API returns the sparse data format this OrtValue holds in a sparse tensor.
1411
+ /// If the sparse tensor was not fully constructed, i.e. Use*() or Fill*() API were not used
1412
+ /// the value returned is ORT_SPARSE_UNDEFINED.
1413
+ /// </summary>
1414
+ /// <returns>Format enum</returns>
1415
+ OrtSparseFormat GetSparseFormat() const;
1416
+
1417
+ /// <summary>
1418
+ /// The API returns type and shape information for stored non-zero values of the
1419
+ /// sparse tensor. Use GetSparseTensorValues() to obtain values buffer pointer.
1420
+ /// </summary>
1421
+ /// <returns>TensorTypeAndShapeInfo values information</returns>
1422
+ TensorTypeAndShapeInfo GetSparseTensorValuesTypeAndShapeInfo() const;
1423
+
1424
+ /// <summary>
1425
+ /// The API returns type and shape information for the specified indices. Each supported
1426
+ /// indices have their own enum values even if a given format has more than one kind of indices.
1427
+ /// Use GetSparseTensorIndicesData() to obtain pointer to indices buffer.
1428
+ /// </summary>
1429
+ /// <param name="format">enum requested</param>
1430
+ /// <returns>type and shape information</returns>
1431
+ TensorTypeAndShapeInfo GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat format) const;
1432
+
1433
+ /// <summary>
1434
+ /// The API retrieves a pointer to the internal indices buffer. The API merely performs
1435
+ /// a convenience data type casting on the return type pointer. Make sure you are requesting
1436
+ /// the right type, use GetSparseTensorIndicesTypeShapeInfo();
1437
+ /// </summary>
1438
+ /// <typeparam name="R">type to cast to</typeparam>
1439
+ /// <param name="indices_format">requested indices kind</param>
1440
+ /// <param name="num_indices">number of indices entries</param>
1441
+ /// <returns>Pointer to the internal sparse tensor buffer containing indices. Do not free this pointer.</returns>
1442
+ template <typename R>
1443
+ const R* GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const;
1444
+
1445
+ /// <summary>
1446
+ /// Returns true if the OrtValue contains a sparse tensor
1447
+ /// </summary>
1448
+ /// <returns></returns>
1449
+ bool IsSparseTensor() const;
1450
+
1451
+ /// <summary>
1452
+ /// The API returns a pointer to an internal buffer of the sparse tensor
1453
+ /// containing non-zero values. The API merely does casting. Make sure you
1454
+ /// are requesting the right data type by calling GetSparseTensorValuesTypeAndShapeInfo()
1455
+ /// first.
1456
+ /// </summary>
1457
+ /// <typeparam name="R">numeric data types only. Use GetStringTensor*() to retrieve strings.</typeparam>
1458
+ /// <returns>a pointer to the internal values buffer. Do not free this pointer.</returns>
1459
+ template <typename R>
1460
+ const R* GetSparseTensorValues() const;
1461
+
1462
+ #endif
1463
+ };
1464
+
1465
+ template <typename T>
1466
+ struct ValueImpl : ConstValueImpl<T> {
1467
+ using B = ConstValueImpl<T>;
1468
+ using B::B;
1469
+
1470
+ /// <summary>
1471
+ /// Returns a non-const typed pointer to an OrtValue/Tensor contained buffer
1472
+ /// No type checking is performed, the caller must ensure the type matches the tensor type.
1473
+ /// </summary>
1474
+ /// <returns>non-const pointer to data, no copies made</returns>
1475
+ template <typename R>
1476
+ R* GetTensorMutableData();
1477
+
1478
+ /// <summary>
1479
+ /// Returns a non-typed non-const pointer to a tensor contained data.
1480
+ /// </summary>
1481
+ /// <returns>pointer to data, no copies made</returns>
1482
+ void* GetTensorMutableRawData();
1483
+
1484
+ /// <summary>
1485
+ /// Obtain a reference to an element of data at the location specified
1486
+ /// by the vector of dims.
1487
+ /// </summary>
1488
+ /// <typeparam name="R"></typeparam>
1489
+ /// <param name="location">[in] expressed by a vector of dimension offsets</param>
1490
+ /// <returns></returns>
1491
+ template <typename R>
1492
+ R& At(const std::vector<int64_t>& location);
1493
+
1494
+ /// <summary>
1495
+ /// Set all strings at once in a string tensor
1496
+ /// </summary>
1497
+ /// <param name="s">[in] An array of strings. Each string in this array must be null terminated.</param>
1498
+ /// <param name="s_len">[in] Count of strings in s (Must match the size of \p value's tensor shape)</param>
1499
+ void FillStringTensor(const char* const* s, size_t s_len);
1500
+
1501
+ /// <summary>
1502
+ /// Set a single string in a string tensor
1503
+ /// </summary>
1504
+ /// <param name="s">[in] A null terminated UTF-8 encoded string</param>
1505
+ /// <param name="index">[in] Index of the string in the tensor to set</param>
1506
+ void FillStringTensorElement(const char* s, size_t index);
1507
+
1508
+ /// <summary>
1509
+ /// Allocate if necessary and obtain a pointer to a UTF-8
1510
+ /// encoded string element buffer indexed by the flat element index,
1511
+ /// of the specified length.
1512
+ ///
1513
+ /// This API is for advanced usage. It avoids a need to construct
1514
+ /// an auxiliary array of string pointers, and allows to write data directly
1515
+ /// (do not zero terminate).
1516
+ /// </summary>
1517
+ /// <param name="index"></param>
1518
+ /// <param name="buffer_length"></param>
1519
+ /// <returns>a pointer to a writable buffer</returns>
1520
+ char* GetResizedStringTensorElementBuffer(size_t index, size_t buffer_length);
1521
+
1522
+ #if !defined(DISABLE_SPARSE_TENSORS)
1523
+ /// <summary>
1524
+ /// Supplies COO format specific indices and marks the contained sparse tensor as being a COO format tensor.
1525
+ /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
1526
+ /// allocated buffers lifespan must eclipse that of the OrtValue.
1527
+ /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
1528
+ /// </summary>
1529
+ /// <param name="indices_data">pointer to the user allocated buffer with indices. Use nullptr for fully sparse tensors.</param>
1530
+ /// <param name="indices_num">number of indices entries. Use 0 for fully sparse tensors</param>
1531
+ void UseCooIndices(int64_t* indices_data, size_t indices_num);
1532
+
1533
+ /// <summary>
1534
+ /// Supplies CSR format specific indices and marks the contained sparse tensor as being a CSR format tensor.
1535
+ /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
1536
+ /// allocated buffers lifespan must eclipse that of the OrtValue.
1537
+ /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
1538
+ /// </summary>
1539
+ /// <param name="inner_data">pointer to the user allocated buffer with inner indices or nullptr for fully sparse tensors</param>
1540
+ /// <param name="inner_num">number of csr inner indices or 0 for fully sparse tensors</param>
1541
+ /// <param name="outer_data">pointer to the user allocated buffer with outer indices or nullptr for fully sparse tensors</param>
1542
+ /// <param name="outer_num">number of csr outer indices or 0 for fully sparse tensors</param>
1543
+ void UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num);
1544
+
1545
+ /// <summary>
1546
+ /// Supplies BlockSparse format specific indices and marks the contained sparse tensor as being a BlockSparse format tensor.
1547
+ /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
1548
+ /// allocated buffers lifespan must eclipse that of the OrtValue.
1549
+ /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
1550
+ /// </summary>
1551
+ /// <param name="indices_shape">indices shape or a {0} for fully sparse</param>
1552
+ /// <param name="indices_data">user allocated buffer with indices or nullptr for fully sparse tensors</param>
1553
+ void UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data);
1554
+
1555
+ /// <summary>
1556
+ /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
1557
+ /// and copy the values and COO indices into it. If data_mem_info specifies that the data is located
1558
+ /// at a different device than the allocator, a X-device copy will be performed if possible.
1559
+ /// </summary>
1560
+ /// <param name="data_mem_info">specified buffer memory description</param>
1561
+ /// <param name="values_param">values buffer information.</param>
1562
+ /// <param name="indices_data">coo indices buffer or nullptr for fully sparse data</param>
1563
+ /// <param name="indices_num">number of COO indices or 0 for fully sparse data</param>
1564
+ void FillSparseTensorCoo(const OrtMemoryInfo* data_mem_info, const OrtSparseValuesParam& values_param,
1565
+ const int64_t* indices_data, size_t indices_num);
1566
+
1567
+ /// <summary>
1568
+ /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
1569
+ /// and copy the values and CSR indices into it. If data_mem_info specifies that the data is located
1570
+ /// at a different device than the allocator, a X-device copy will be performed if possible.
1571
+ /// </summary>
1572
+ /// <param name="data_mem_info">specified buffer memory description</param>
1573
+ /// <param name="values">values buffer information</param>
1574
+ /// <param name="inner_indices_data">csr inner indices pointer or nullptr for fully sparse tensors</param>
1575
+ /// <param name="inner_indices_num">number of csr inner indices or 0 for fully sparse tensors</param>
1576
+ /// <param name="outer_indices_data">pointer to csr indices data or nullptr for fully sparse tensors</param>
1577
+ /// <param name="outer_indices_num">number of csr outer indices or 0</param>
1578
+ void FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info,
1579
+ const OrtSparseValuesParam& values,
1580
+ const int64_t* inner_indices_data, size_t inner_indices_num,
1581
+ const int64_t* outer_indices_data, size_t outer_indices_num);
1582
+
1583
+ /// <summary>
1584
+ /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
1585
+ /// and copy the values and BlockSparse indices into it. If data_mem_info specifies that the data is located
1586
+ /// at a different device than the allocator, a X-device copy will be performed if possible.
1587
+ /// </summary>
1588
+ /// <param name="data_mem_info">specified buffer memory description</param>
1589
+ /// <param name="values">values buffer information</param>
1590
+ /// <param name="indices_shape">indices shape. use {0} for fully sparse tensors</param>
1591
+ /// <param name="indices_data">pointer to indices data or nullptr for fully sparse tensors</param>
1592
+ void FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info,
1593
+ const OrtSparseValuesParam& values,
1594
+ const Shape& indices_shape,
1595
+ const int32_t* indices_data);
1596
+
1597
+ #endif
1598
+ };
1599
+
1600
+ } // namespace detail
1601
+
1602
+ using ConstValue = detail::ConstValueImpl<detail::Unowned<const OrtValue>>;
1603
+ using UnownedValue = detail::ValueImpl<detail::Unowned<OrtValue>>;
1604
+
1605
+ /** \brief Wrapper around ::OrtValue
1606
+ *
1607
+ */
1608
+ struct Value : detail::ValueImpl<OrtValue> {
1609
+ using Base = detail::ValueImpl<OrtValue>;
1610
+ using OrtSparseValuesParam = detail::OrtSparseValuesParam;
1611
+ using Shape = detail::Shape;
1612
+
1613
+ explicit Value(std::nullptr_t) {} ///< Create an empty Value object, must be assigned a valid one to be used
1614
+ explicit Value(OrtValue* p) : Base{p} {} ///< Used for interop with the C API
1615
+ Value(Value&&) = default;
1616
+ Value& operator=(Value&&) = default;
1617
+
1618
+ ConstValue GetConst() const { return ConstValue{this->p_}; }
1619
+ UnownedValue GetUnowned() const { return UnownedValue{this->p_}; }
1620
+
1621
+ /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue.
1622
+ * \tparam T The numeric datatype. This API is not suitable for strings.
1623
+ * \param info Memory description of where the p_data buffer resides (CPU vs GPU etc).
1624
+ * \param p_data Pointer to the data buffer.
1625
+ * \param p_data_element_count The number of elements in the data buffer.
1626
+ * \param shape Pointer to the tensor shape dimensions.
1627
+ * \param shape_len The number of tensor shape dimensions.
1628
+ */
1629
+ template <typename T>
1630
+ static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len);
1631
+
1632
+ /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue.
1633
+ *
1634
+ * \param info Memory description of where the p_data buffer resides (CPU vs GPU etc).
1635
+ * \param p_data Pointer to the data buffer.
1636
+ * \param p_data_byte_count The number of bytes in the data buffer.
1637
+ * \param shape Pointer to the tensor shape dimensions.
1638
+ * \param shape_len The number of tensor shape dimensions.
1639
+ * \param type The data type.
1640
+ */
1641
+ static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
1642
+ ONNXTensorElementDataType type);
1643
+
1644
+ /** \brief Creates an OrtValue with a tensor using a supplied OrtAllocator. Wraps OrtApi::CreateTensorAsOrtValue.
1645
+ * This overload will allocate the buffer for the tensor according to the supplied shape and data type.
1646
+ * The allocated buffer will be owned by the returned OrtValue and will be freed when the OrtValue is released.
1647
+ * The input data would need to be copied into the allocated buffer.
1648
+ * This API is not suitable for strings.
1649
+ *
1650
+ * \tparam T The numeric datatype. This API is not suitable for strings.
1651
+ * \param allocator The allocator to use.
1652
+ * \param shape Pointer to the tensor shape dimensions.
1653
+ * \param shape_len The number of tensor shape dimensions.
1654
+ */
1655
+ template <typename T>
1656
+ static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len);
1657
+
1658
+ /** \brief Creates an OrtValue with a tensor using the supplied OrtAllocator.
1659
+ * Wraps OrtApi::CreateTensorAsOrtValue.
1660
+ * The allocated buffer will be owned by the returned OrtValue and will be freed when the OrtValue is released.
1661
+ * The input data would need to be copied into the allocated buffer.
1662
+ * This API is not suitable for strings.
1663
+ *
1664
+ * \param allocator The allocator to use.
1665
+ * \param shape Pointer to the tensor shape dimensions.
1666
+ * \param shape_len The number of tensor shape dimensions.
1667
+ * \param type The data type.
1668
+ */
1669
+ static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type);
1670
+
1671
+ /** \brief Creates an OrtValue with a Map Onnx type representation.
1672
+ * The API would ref-count the supplied OrtValues and they will be released
1673
+ * when the returned OrtValue is released. The caller may release keys and values after the call
1674
+ * returns.
1675
+ *
1676
+ * \param keys an OrtValue containing a tensor with primitive data type keys.
1677
+ * \param values an OrtValue that may contain a tensor. Ort currently supports only primitive data type values.
1678
+ */
1679
+ static Value CreateMap(const Value& keys, const Value& values); ///< Wraps OrtApi::CreateValue
1680
+
1681
+ /** \brief Creates an OrtValue with a Sequence Onnx type representation.
1682
+ * The API would ref-count the supplied OrtValues and they will be released
1683
+ * when the returned OrtValue is released. The caller may release the values after the call
1684
+ * returns.
1685
+ *
1686
+ * \param values a vector of OrtValues that must have the same Onnx value type.
1687
+ */
1688
+ static Value CreateSequence(const std::vector<Value>& values); ///< Wraps OrtApi::CreateValue
1689
+
1690
+ /** \brief Creates an OrtValue wrapping an Opaque type.
1691
+ * This is used for experimental support of non-tensor types.
1692
+ *
1693
+ * \tparam T - the type of the value.
1694
+ * \param domain - zero terminated utf-8 string. Domain of the type.
1695
+ * \param type_name - zero terminated utf-8 string. Name of the type.
1696
+ * \param value - the value to be wrapped.
1697
+ */
1698
+ template <typename T>
1699
+ static Value CreateOpaque(const char* domain, const char* type_name, const T& value); ///< Wraps OrtApi::CreateOpaqueValue
1700
+
1701
+ #if !defined(DISABLE_SPARSE_TENSORS)
1702
+ /// <summary>
1703
+ /// This is a simple forwarding method to the other overload that helps deducing
1704
+ /// data type enum value from the type of the buffer.
1705
+ /// </summary>
1706
+ /// <typeparam name="T">numeric datatype. This API is not suitable for strings.</typeparam>
1707
+ /// <param name="info">Memory description where the user buffers reside (CPU vs GPU etc)</param>
1708
+ /// <param name="p_data">pointer to the user supplied buffer, use nullptr for fully sparse tensors</param>
1709
+ /// <param name="dense_shape">a would be dense shape of the tensor</param>
1710
+ /// <param name="values_shape">non zero values shape. Use a single 0 shape for fully sparse tensors.</param>
1711
+ /// <returns></returns>
1712
+ template <typename T>
1713
+ static Value CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape,
1714
+ const Shape& values_shape);
1715
+
1716
+ /// <summary>
1717
+ /// Creates an OrtValue instance containing SparseTensor. This constructs
1718
+ /// a sparse tensor that makes use of user allocated buffers. It does not make copies
1719
+ /// of the user provided data and does not modify it. The lifespan of user provided buffers should
1720
+ /// eclipse the life span of the resulting OrtValue. This call constructs an instance that only contains
1721
+ /// a pointer to non-zero values. To fully populate the sparse tensor call Use<Format>Indices() API below
1722
+ /// to supply sparse format specific indices.
1723
+ /// This API is not suitable for string data. Use CreateSparseTensor() with allocator specified so strings
1724
+ /// can be properly copied into the allocated buffer.
1725
+ /// </summary>
1726
+ /// <param name="info">Memory description where the user buffers reside (CPU vs GPU etc)</param>
1727
+ /// <param name="p_data">pointer to the user supplied buffer, use nullptr for fully sparse tensors</param>
1728
+ /// <param name="dense_shape">a would be dense shape of the tensor</param>
1729
+ /// <param name="values_shape">non zero values shape. Use a single 0 shape for fully sparse tensors.</param>
1730
+ /// <param name="type">data type</param>
1731
+ /// <returns>Ort::Value instance containing SparseTensor</returns>
1732
+ static Value CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape,
1733
+ const Shape& values_shape, ONNXTensorElementDataType type);
1734
+
1735
+ /// <summary>
1736
+ /// This is a simple forwarding method to the below CreateSparseTensor.
1737
+ /// This helps to specify data type enum in terms of C++ data type.
1738
+ /// Use CreateSparseTensor<T>
1739
+ /// </summary>
1740
+ /// <typeparam name="T">numeric data type only. String data enum must be specified explicitly.</typeparam>
1741
+ /// <param name="allocator">allocator to use</param>
1742
+ /// <param name="dense_shape">a would be dense shape of the tensor</param>
1743
+ /// <returns>Ort::Value</returns>
1744
+ template <typename T>
1745
+ static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape);
1746
+
1747
+ /// <summary>
1748
+ /// Creates an instance of OrtValue containing sparse tensor. The created instance has no data.
1749
+ /// The data must be supplied by one of the FillSparseTensor<Format>() methods that take both non-zero values
1750
+ /// and indices. The data will be copied into a buffer that would be allocated using the supplied allocator.
1751
+ /// Use this API to create OrtValues that contain sparse tensors with all supported data types including
1752
+ /// strings.
1753
+ /// </summary>
1754
+ /// <param name="allocator">allocator to use. The allocator lifespan must eclipse that of the resulting OrtValue</param>
1755
+ /// <param name="dense_shape">a would be dense shape of the tensor</param>
1756
+ /// <param name="type">data type</param>
1757
+ /// <returns>an instance of Ort::Value</returns>
1758
+ static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape, ONNXTensorElementDataType type);
1759
+
1760
+ #endif // !defined(DISABLE_SPARSE_TENSORS)
1761
+ };
1762
+
1763
+ /// <summary>
1764
+ /// Represents native memory allocation coming from one of the
1765
+ /// OrtAllocators registered with OnnxRuntime.
1766
+ /// Use it to wrap an allocation made by an allocator
1767
+ /// so it can be automatically released when no longer needed.
1768
+ /// </summary>
1769
+ struct MemoryAllocation {
1770
+ MemoryAllocation(OrtAllocator* allocator, void* p, size_t size);
1771
+ ~MemoryAllocation();
1772
+ MemoryAllocation(const MemoryAllocation&) = delete;
1773
+ MemoryAllocation& operator=(const MemoryAllocation&) = delete;
1774
+ MemoryAllocation(MemoryAllocation&&) noexcept;
1775
+ MemoryAllocation& operator=(MemoryAllocation&&) noexcept;
1776
+
1777
+ void* get() { return p_; }
1778
+ size_t size() const { return size_; }
1779
+
1780
+ private:
1781
+ OrtAllocator* allocator_;
1782
+ void* p_;
1783
+ size_t size_;
1784
+ };
1785
+
1786
+ namespace detail {
1787
+ template <typename T>
1788
+ struct AllocatorImpl : Base<T> {
1789
+ using B = Base<T>;
1790
+ using B::B;
1791
+
1792
+ void* Alloc(size_t size);
1793
+ MemoryAllocation GetAllocation(size_t size);
1794
+ void Free(void* p);
1795
+ ConstMemoryInfo GetInfo() const;
1796
+ };
1797
+
1798
+ } // namespace detail
1799
+
1800
+ /** \brief Wrapper around ::OrtAllocator default instance that is owned by Onnxruntime
1801
+ *
1802
+ */
1803
+ struct AllocatorWithDefaultOptions : detail::AllocatorImpl<detail::Unowned<OrtAllocator>> {
1804
+ explicit AllocatorWithDefaultOptions(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance
1805
+ AllocatorWithDefaultOptions();
1806
+ };
1807
+
1808
+ /** \brief Wrapper around ::OrtAllocator
1809
+ *
1810
+ */
1811
+ struct Allocator : detail::AllocatorImpl<OrtAllocator> {
1812
+ explicit Allocator(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance
1813
+ Allocator(const Session& session, const OrtMemoryInfo*);
1814
+ };
1815
+
1816
+ using UnownedAllocator = detail::AllocatorImpl<detail::Unowned<OrtAllocator>>;
1817
+
1818
+ namespace detail {
1819
+ namespace binding_utils {
1820
+ // Bring these out of template
1821
+ std::vector<std::string> GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator*);
1822
+ std::vector<Value> GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator*);
1823
+ } // namespace binding_utils
1824
+
1825
+ template <typename T>
1826
+ struct ConstIoBindingImpl : Base<T> {
1827
+ using B = Base<T>;
1828
+ using B::B;
1829
+
1830
+ std::vector<std::string> GetOutputNames() const;
1831
+ std::vector<std::string> GetOutputNames(OrtAllocator*) const;
1832
+ std::vector<Value> GetOutputValues() const;
1833
+ std::vector<Value> GetOutputValues(OrtAllocator*) const;
1834
+ };
1835
+
1836
+ template <typename T>
1837
+ struct IoBindingImpl : ConstIoBindingImpl<T> {
1838
+ using B = ConstIoBindingImpl<T>;
1839
+ using B::B;
1840
+
1841
+ void BindInput(const char* name, const Value&);
1842
+ void BindOutput(const char* name, const Value&);
1843
+ void BindOutput(const char* name, const OrtMemoryInfo*);
1844
+ void ClearBoundInputs();
1845
+ void ClearBoundOutputs();
1846
+ void SynchronizeInputs();
1847
+ void SynchronizeOutputs();
1848
+ };
1849
+
1850
+ } // namespace detail
1851
+
1852
+ using ConstIoBinding = detail::ConstIoBindingImpl<detail::Unowned<const OrtIoBinding>>;
1853
+ using UnownedIoBinding = detail::IoBindingImpl<detail::Unowned<OrtIoBinding>>;
1854
+
1855
+ /** \brief Wrapper around ::OrtIoBinding
1856
+ * GetConst() and GetUnowned() return ConstIoBinding / UnownedIoBinding objects constructed from this object's pointer (p_).
1857
+ */
1858
+ struct IoBinding : detail::IoBindingImpl<OrtIoBinding> {
1859
+ explicit IoBinding(std::nullptr_t) {} ///< Create an empty object for convenience. Sometimes, we want to initialize members later.
1860
+ explicit IoBinding(Session& session); ///< Declared only in this section; constructs the binding from a session
1861
+ ConstIoBinding GetConst() const { return ConstIoBinding{this->p_}; }
1862
+ UnownedIoBinding GetUnowned() const { return UnownedIoBinding{this->p_}; }
1863
+ };
1864
+
1865
+ /*! \struct Ort::ArenaCfg
1866
+ * \brief Represents the configuration of an arena-based allocator
1867
+ * \details Please see docs/C_API.md for details
1868
+ */
1869
+ struct ArenaCfg : detail::Base<OrtArenaCfg> {
1870
+ explicit ArenaCfg(std::nullptr_t) {} ///< Create an empty ArenaCfg object, must be assigned a valid one to be used
1871
+ /**
1872
+ * Wraps OrtApi::CreateArenaCfg
1873
+ * \param max_mem - use 0 to allow ORT to choose the default
1874
+ * \param arena_extend_strategy - use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested
1875
+ * \param initial_chunk_size_bytes - use -1 to allow ORT to choose the default
1876
+ * \param max_dead_bytes_per_chunk - use -1 to allow ORT to choose the default
1877
+ * See docs/C_API.md for details on what the parameters above mean and how to choose their values
1878
+ */
1879
+ ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk);
1880
+ };
1881
+
1882
+ //
1883
+ // Custom OPs (only needed to implement custom OPs)
1884
+ //
1885
+
1886
+ /// <summary>
1887
+ /// This struct provides lifetime management for a custom op attribute (::OrtOpAttr)
1888
+ /// </summary>
1889
+ struct OpAttr : detail::Base<OrtOpAttr> {
1890
+ OpAttr(const char* name, const void* data, int len, OrtOpAttrType type); ///< Declared only in this section; name, payload pointer with its length, and type tag (meaning of len per type is not visible here -- confirm against OrtApi::CreateOpAttr)
1891
+ };
1892
+
1893
+ /**
1894
+ * Macro that logs a message using the provided logger. Throws an exception if OrtApi::Logger_LogMessage fails.
1895
+ * Example: ORT_CXX_LOG(logger, ORT_LOGGING_LEVEL_INFO, "Log a message");
1896
+ * The logger and message_severity arguments are each expanded twice, so pass expressions without side effects.
1897
+ * \param logger The Ort::Logger instance to use. Must be a value or reference (parenthesized in the expansion).
1898
+ * \param message_severity The logging severity level of the message (parenthesized in the comparison).
1899
+ * \param message A null-terminated UTF-8 message to log.
1900
+ */
1901
+ #define ORT_CXX_LOG(logger, message_severity, message) \
1902
+ do { \
1903
+ if ((message_severity) >= (logger).GetLoggingSeverityLevel()) { \
1904
+ Ort::ThrowOnError((logger).LogMessage(message_severity, ORT_FILE, __LINE__, \
1905
+ static_cast<const char*>(__FUNCTION__), message)); \
1906
+ } \
1907
+ } while (false)
1908
+
1909
+ /**
1910
+ * Macro that logs a message using the provided logger. Can be used in noexcept code since errors are silently ignored.
1911
+ * Example: ORT_CXX_LOG_NOEXCEPT(logger, ORT_LOGGING_LEVEL_INFO, "Log a message");
1912
+ * The logger and message_severity arguments are each expanded twice, so pass expressions without side effects.
1913
+ * \param logger The Ort::Logger instance to use. Must be a value or reference (parenthesized in the expansion).
1914
+ * \param message_severity The logging severity level of the message (parenthesized in the comparison).
1915
+ * \param message A null-terminated UTF-8 message to log.
1916
+ */
1917
+ #define ORT_CXX_LOG_NOEXCEPT(logger, message_severity, message) \
1918
+ do { \
1919
+ if ((message_severity) >= (logger).GetLoggingSeverityLevel()) { \
1920
+ static_cast<void>((logger).LogMessage(message_severity, ORT_FILE, __LINE__, \
1921
+ static_cast<const char*>(__FUNCTION__), message)); \
1922
+ } \
1923
+ } while (false)
1924
+
1925
+ /**
1926
+ * Macro that logs a printf-like formatted message using the provided logger. Throws an exception if
1927
+ * OrtApi::Logger_LogMessage fails or if a formatting error occurs.
1928
+ * Example: ORT_CXX_LOGF(logger, ORT_LOGGING_LEVEL_INFO, "Log an int: %d", 12);
1929
+ * The logger and message_severity arguments are each expanded twice, so pass expressions without side effects.
1930
+ * \param logger The Ort::Logger instance to use. Must be a value or reference (parenthesized in the expansion).
1931
+ * \param message_severity The logging severity level of the message (parenthesized in the comparison).
1932
+ * \param format A null-terminated UTF-8 format string forwarded to a printf-like function.
1933
+ * Refer to https://en.cppreference.com/w/cpp/io/c/fprintf for information on valid formats.
1934
+ * \param ... Zero or more variadic arguments referenced by the format string.
1935
+ */
1936
+ #define ORT_CXX_LOGF(logger, message_severity, /*format,*/...) \
1937
+ do { \
1938
+ if ((message_severity) >= (logger).GetLoggingSeverityLevel()) { \
1939
+ Ort::ThrowOnError((logger).LogFormattedMessage(message_severity, ORT_FILE, __LINE__, \
1940
+ static_cast<const char*>(__FUNCTION__), __VA_ARGS__)); \
1941
+ } \
1942
+ } while (false)
1943
+
1944
+ /**
1945
+ * Macro that logs a printf-like formatted message using the provided logger. Can be used in noexcept code since errors
1946
+ * are silently ignored.
1947
+ * Example: ORT_CXX_LOGF_NOEXCEPT(logger, ORT_LOGGING_LEVEL_INFO, "Log an int: %d", 12);
1948
+ * The logger and message_severity arguments are each expanded twice, so pass expressions without side effects.
1949
+ * \param logger The Ort::Logger instance to use. Must be a value or reference (parenthesized in the expansion).
1950
+ * \param message_severity The logging severity level of the message (parenthesized in the comparison).
1951
+ * \param format A null-terminated UTF-8 format string forwarded to a printf-like function.
1952
+ * Refer to https://en.cppreference.com/w/cpp/io/c/fprintf for information on valid formats.
1953
+ * \param ... Zero or more variadic arguments referenced by the format string.
1954
+ */
1955
+ #define ORT_CXX_LOGF_NOEXCEPT(logger, message_severity, /*format,*/...) \
1956
+ do { \
1957
+ if ((message_severity) >= (logger).GetLoggingSeverityLevel()) { \
1958
+ static_cast<void>((logger).LogFormattedMessage(message_severity, ORT_FILE, __LINE__, \
1959
+ static_cast<const char*>(__FUNCTION__), __VA_ARGS__)); \
1960
+ } \
1961
+ } while (false)
1962
+
1963
+ /// <summary>
1964
+ /// This class represents an ONNX Runtime logger that can be used to log information with an
1965
+ /// associated severity level and source code location (file path, line number, function name).
1966
+ ///
1967
+ /// A Logger can be obtained from within custom operators by calling Ort::KernelInfo::GetLogger() or Ort::KernelContext::GetLogger().
1968
+ /// Instances of Ort::Logger hold one pointer and one cached severity value, are copyable, and can be passed by value.
1969
+ ///
1970
+ /// Use the ORT_CXX_LOG macros to ensure the source code location is set properly from the callsite
1971
+ /// and to take advantage of a cached logging severity level that can bypass calls to the underlying C API.
1972
+ /// </summary>
1973
+ struct Logger {
1974
+ /**
1975
+ * Creates an empty Ort::Logger. Must be initialized from a valid Ort::Logger before use.
1976
+ */
1977
+ Logger() = default;
1978
+
1979
+ /**
1980
+ * Creates an empty Ort::Logger from nullptr (same member state as the default constructor). Must be initialized from a valid Ort::Logger before use.
1981
+ */
1982
+ explicit Logger(std::nullptr_t) {}
1983
+
1984
+ /**
1985
+ * Creates a logger from an ::OrtLogger instance. Caches the logger's current severity level by calling
1986
+ * OrtApi::Logger_GetLoggingSeverityLevel. Throws an exception if OrtApi::Logger_GetLoggingSeverityLevel fails.
1987
+ *
1988
+ * \param logger The ::OrtLogger to wrap.
1989
+ */
1990
+ explicit Logger(const OrtLogger* logger);
1991
+
1992
+ ~Logger() = default;
1993
+
1994
+ Logger(const Logger&) = default;
1995
+ Logger& operator=(const Logger&) = default;
1996
+
1997
+ Logger(Logger&& v) noexcept = default;
1998
+ Logger& operator=(Logger&& v) noexcept = default;
1999
+
2000
+ /**
2001
+ * Returns the logger's current severity level from the cached member.
2002
+ *
2003
+ * \return The current ::OrtLoggingLevel.
2004
+ */
2005
+ OrtLoggingLevel GetLoggingSeverityLevel() const noexcept;
2006
+
2007
+ /**
2008
+ * Logs the provided message via OrtApi::Logger_LogMessage. Use the ORT_CXX_LOG or ORT_CXX_LOG_NOEXCEPT
2009
+ * macros to properly set the source code location and to use the cached severity level to potentially bypass
2010
+ * calls to the underlying C API.
2011
+ *
2012
+ * \param log_severity_level The message's logging severity level.
2013
+ * \param file_path The filepath of the file in which the message is logged. Usually the value of ORT_FILE.
2014
+ * \param line_number The file line number in which the message is logged. Usually the value of __LINE__.
2015
+ * \param func_name The name of the function in which the message is logged. Usually the value of __FUNCTION__.
2016
+ * \param message The message to log.
2017
+ * \return A Ort::Status value to indicate error or success.
2018
+ */
2019
+ Status LogMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number,
2020
+ const char* func_name, const char* message) const noexcept;
2021
+
2022
+ /**
2023
+ * Logs a printf-like formatted message via OrtApi::Logger_LogMessage. Use the ORT_CXX_LOGF or ORT_CXX_LOGF_NOEXCEPT
2024
+ * macros to properly set the source code location and to use the cached severity level to potentially bypass
2025
+ * calls to the underlying C API. Returns an error status if a formatting error occurs.
2026
+ *
2027
+ * \param log_severity_level The message's logging severity level.
2028
+ * \param file_path The filepath of the file in which the message is logged. Usually the value of ORT_FILE.
2029
+ * \param line_number The file line number in which the message is logged. Usually the value of __LINE__.
2030
+ * \param func_name The name of the function in which the message is logged. Usually the value of __FUNCTION__.
2031
+ * \param format A null-terminated UTF-8 format string forwarded to a printf-like function.
2032
+ * Refer to https://en.cppreference.com/w/cpp/io/c/fprintf for information on valid formats.
2033
+ * \param args Zero or more variadic arguments referenced by the format string.
2034
+ * \return A Ort::Status value to indicate error or success.
2035
+ */
2036
+ template <typename... Args>
2037
+ Status LogFormattedMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number,
2038
+ const char* func_name, const char* format, Args&&... args) const noexcept;
2039
+
2040
+ private:
2041
+ const OrtLogger* logger_{}; ///< Wrapped C logger; value-initialized to null for an empty Logger
2042
+ OrtLoggingLevel cached_severity_level_{}; ///< Cached severity level; value-initialized for an empty Logger
2043
+ };
2044
+
2045
+ /// <summary>
2046
+ /// This class wraps a raw pointer OrtKernelContext* that is being passed
2047
+ /// to the custom kernel Compute() method. Use it to safely access context
2048
+ /// attributes, input and output parameters with exception safety guarantees.
2049
+ /// See usage example in onnxruntime/test/testdata/custom_op_library/custom_op_library.cc
2050
+ /// </summary>
2051
+ struct KernelContext {
2052
+ explicit KernelContext(OrtKernelContext* context);
2053
+ size_t GetInputCount() const;
2054
+ size_t GetOutputCount() const;
2055
+ ConstValue GetInput(size_t index) const;
2056
+ UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const; ///< Output shape given as a pointer plus element count
2057
+ UnownedValue GetOutput(size_t index, const std::vector<int64_t>& dims) const; ///< Output shape given as a vector of dimensions
2058
+ void* GetGPUComputeStream() const;
2059
+ Logger GetLogger() const;
2060
+ OrtAllocator* GetAllocator(const OrtMemoryInfo& memory_info) const;
2061
+ OrtKernelContext* GetOrtKernelContext() const { return ctx_; } ///< Returns the stored raw context pointer (ctx_)
2062
+ void ParallelFor(void (*fn)(void*, size_t), size_t total, size_t num_batch, void* usr_data) const; ///< fn takes (void*, size_t) -- presumably usr_data and an iteration index; confirm against OrtApi::KernelContext_ParallelFor
2063
+
2064
+ private:
2065
+ OrtKernelContext* ctx_; ///< Raw context pointer; has no default member initializer
2066
+ };
2067
+
2068
+ struct KernelInfo; // Forward declaration: detail::KernelInfoImpl::Copy() below returns KernelInfo before the struct is defined
2069
+
2070
+ namespace detail {
2071
+ namespace attr_utils {
2072
+ void GetAttr(const OrtKernelInfo* p, const char* name, float&); // Scalar attribute readers: the result is written through the reference parameter
2073
+ void GetAttr(const OrtKernelInfo* p, const char* name, int64_t&);
2074
+ void GetAttr(const OrtKernelInfo* p, const char* name, std::string&);
2075
+ void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<float>&); // Array attribute readers: the result is written into the vector parameter
2076
+ void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<int64_t>&);
2077
+ } // namespace attr_utils
2078
+
2079
+ template <typename T>
2080
+ struct KernelInfoImpl : Base<T> {
2081
+ using B = Base<T>;
2082
+ using B::B; // Inherit Base<T>'s constructors
2083
+
2084
+ KernelInfo Copy() const;
2085
+
2086
+ template <typename R> // R must be float, int64_t, or std::string (the only attr_utils::GetAttr overloads)
2087
+ R GetAttribute(const char* name) const {
2088
+ R val;
2089
+ attr_utils::GetAttr(this->p_, name, val);
2090
+ return val;
2091
+ }
2092
+
2093
+ template <typename R> // R is the element type: only float and int64_t have attr_utils::GetAttrs overloads; returns std::vector<R>
2094
+ std::vector<R> GetAttributes(const char* name) const {
2095
+ std::vector<R> result;
2096
+ attr_utils::GetAttrs(this->p_, name, result);
2097
+ return result;
2098
+ }
2099
+
2100
+ Value GetTensorAttribute(const char* name, OrtAllocator* allocator) const;
2101
+
2102
+ size_t GetInputCount() const;
2103
+ size_t GetOutputCount() const;
2104
+
2105
+ std::string GetInputName(size_t index) const;
2106
+ std::string GetOutputName(size_t index) const;
2107
+
2108
+ TypeInfo GetInputTypeInfo(size_t index) const;
2109
+ TypeInfo GetOutputTypeInfo(size_t index) const;
2110
+
2111
+ ConstValue GetTensorConstantInput(size_t index, int* is_constant) const; // is_constant is an out-parameter (exact values not visible here -- confirm against OrtApi::KernelInfoGetConstantInput_tensor)
2112
+
2113
+ std::string GetNodeName() const;
2114
+ Logger GetLogger() const;
2115
+ };
2116
+
2117
+ } // namespace detail
2118
+
2119
+ using ConstKernelInfo = detail::KernelInfoImpl<detail::Unowned<const OrtKernelInfo>>; ///< KernelInfoImpl interface instantiated over detail::Unowned<const OrtKernelInfo>
2120
+
2121
+ /// <summary>
2122
+ /// This struct owns the OrtKernelInfo* pointer when a copy is made.
2123
+ /// For convenient wrapping of OrtKernelInfo* passed to kernel constructor
2124
+ /// and query attributes, wrap the pointer with Ort::Unowned<KernelInfo> instance
2125
+ /// so it does not destroy the pointer the kernel does not own.
2126
+ /// </summary>
2127
+ struct KernelInfo : detail::KernelInfoImpl<OrtKernelInfo> {
2128
+ explicit KernelInfo(std::nullptr_t) {} ///< Create an empty instance to initialize later
2129
+ explicit KernelInfo(OrtKernelInfo* info); ///< Take ownership of the instance
2130
+ ConstKernelInfo GetConst() const { return ConstKernelInfo{this->p_}; } ///< Returns a ConstKernelInfo constructed from this object's pointer (p_)
2131
+ };
2132
+
2133
+ /// <summary>
2134
+ /// Create and own custom defined operation.
2135
+ /// </summary>
2136
+ struct Op : detail::Base<OrtOp> {
2137
+ explicit Op(std::nullptr_t) {} ///< Create an empty Operator object, must be assigned a valid one to be used
2138
+
2139
+ explicit Op(OrtOp*); ///< Take ownership of the OrtOp
2140
+
2141
+ static Op Create(const OrtKernelInfo* info, const char* op_name, const char* domain, // Factory; type_constraint_names/values are parallel arrays of length type_constraint_count
2142
+ int version, const char** type_constraint_names,
2143
+ const ONNXTensorElementDataType* type_constraint_values,
2144
+ size_t type_constraint_count,
2145
+ const OpAttr* attr_values, // Array of attr_count attributes
2146
+ size_t attr_count,
2147
+ size_t input_count, size_t output_count);
2148
+
2149
+ void Invoke(const OrtKernelContext* context, // Overload taking Ort::Value arrays
2150
+ const Value* input_values,
2151
+ size_t input_count,
2152
+ Value* output_values,
2153
+ size_t output_count);
2154
+
2155
+ // For easier refactoring: same call shape as above, but taking raw OrtValue pointer arrays
2156
+ void Invoke(const OrtKernelContext* context,
2157
+ const OrtValue* const* input_values,
2158
+ size_t input_count,
2159
+ OrtValue* const* output_values,
2160
+ size_t output_count);
2161
+ };
2162
+
2163
+ /// <summary>
2164
+ /// Provide access to per-node attributes and input shapes, so one could compute and set output shapes.
2165
+ /// </summary>
2166
+ struct ShapeInferContext {
2167
+ struct SymbolicInteger { // One dimension: either an integer value or a symbolic name, tagged by is_int_
2168
+ SymbolicInteger(int64_t i) : i_(i), is_int_(true){}; // Implicit conversion from an integer dimension
2169
+ SymbolicInteger(const char* s) : s_(s), is_int_(false){}; // Implicit conversion from a symbol name; only the pointer is stored, the string is not copied
2170
+ SymbolicInteger(const SymbolicInteger&) = default;
2171
+ SymbolicInteger(SymbolicInteger&&) = default;
2172
+
2173
+ SymbolicInteger& operator=(const SymbolicInteger&) = default;
2174
+ SymbolicInteger& operator=(SymbolicInteger&&) = default;
2175
+
2176
+ bool operator==(const SymbolicInteger& dim) const { // Equal only when both are the same kind
2177
+ if (is_int_ == dim.is_int_) {
2178
+ if (is_int_) {
2179
+ return i_ == dim.i_; // Integers compare by value
2180
+ } else {
2181
+ return std::string{s_} == std::string{dim.s_}; // Symbols compare by string content, not by pointer
2182
+ }
2183
+ }
2184
+ return false;
2185
+ }
2186
+
2187
+ bool IsInt() const { return is_int_; }
2188
+ int64_t AsInt() const { return i_; } // Returns the union member without checking IsInt()
2189
+ const char* AsSym() const { return s_; } // Returns the union member without checking IsInt()
2190
+
2191
+ static constexpr int INVALID_INT_DIM = -2; // Not referenced elsewhere in this struct
2192
+
2193
+ private:
2194
+ union { // Which member is active is recorded in is_int_
2195
+ int64_t i_;
2196
+ const char* s_;
2197
+ };
2198
+ bool is_int_;
2199
+ };
2200
+
2201
+ using Shape = std::vector<SymbolicInteger>;
2202
+
2203
+ ShapeInferContext(const OrtApi* ort_api, OrtShapeInferContext* ctx);
2204
+
2205
+ const Shape& GetInputShape(size_t indice) const { return input_shapes_.at(indice); } // vector::at: an out-of-range index throws std::out_of_range
2206
+
2207
+ size_t GetInputCount() const { return input_shapes_.size(); }
2208
+
2209
+ Status SetOutputShape(size_t indice, const Shape& shape);
2210
+
2211
+ int64_t GetAttrInt(const char* attr_name);
2212
+
2213
+ using Ints = std::vector<int64_t>;
2214
+ Ints GetAttrInts(const char* attr_name);
2215
+
2216
+ float GetAttrFloat(const char* attr_name);
2217
+
2218
+ using Floats = std::vector<float>;
2219
+ Floats GetAttrFloats(const char* attr_name);
2220
+
2221
+ std::string GetAttrString(const char* attr_name);
2222
+
2223
+ using Strings = std::vector<std::string>;
2224
+ Strings GetAttrStrings(const char* attr_name);
2225
+
2226
+ private:
2227
+ const OrtOpAttr* GetAttrHdl(const char* attr_name) const; // Attribute lookup helper (presumably shared by the GetAttr* accessors above -- confirm in the inline header)
2228
+ const OrtApi* ort_api_;
2229
+ OrtShapeInferContext* ctx_;
2230
+ std::vector<Shape> input_shapes_; // Backing store for GetInputShape() and GetInputCount()
2231
+ };
2232
+
2233
+ using ShapeInferFn = Ort::Status (*)(Ort::ShapeInferContext&); ///< Pointer to a shape-inference function taking the context by reference and returning an Ort::Status
2234
+
2235
+ #define MAX_CUSTOM_OP_END_VER ((1UL << 31) - 1) // 2^31 - 1, the default for CustomOpBase::end_ver_; fully parenthesized so the macro expands as a single operand in any expression
2236
+
2237
+ template <typename TOp, typename TKernel, bool WithStatus = false>
2238
+ struct CustomOpBase : OrtCustomOp { // Fills the OrtCustomOp function-pointer table with lambdas that forward to TOp (the op) and TKernel (the kernel)
2239
+ CustomOpBase() {
2240
+ OrtCustomOp::version = ORT_API_VERSION;
2241
+ OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetName(); }; // Each callback casts this_ back to const TOp* and forwards the call
2242
+
2243
+ OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetExecutionProviderType(); };
2244
+
2245
+ OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetInputTypeCount(); };
2246
+ OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputType(index); };
2247
+ OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputMemoryType(index); };
2248
+
2249
+ OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetOutputTypeCount(); };
2250
+ OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputType(index); };
2251
+
2252
+ #if defined(_MSC_VER) && !defined(__clang__)
2253
+ #pragma warning(push)
2254
+ #pragma warning(disable : 26409)
2255
+ #endif
2256
+ OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<TKernel*>(op_kernel); }; // Deletes the kernel through its concrete TKernel type; MSVC warning 26409 is suppressed around this explicit delete
2257
+ #if defined(_MSC_VER) && !defined(__clang__)
2258
+ #pragma warning(pop)
2259
+ #endif
2260
+ OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputCharacteristic(index); };
2261
+ OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputCharacteristic(index); };
2262
+
2263
+ OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetVariadicInputMinArity(); };
2264
+ OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp* this_) { return static_cast<int>(static_cast<const TOp*>(this_)->GetVariadicInputHomogeneity()); }; // The bool result is converted to int for the C callback
2265
+ OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetVariadicOutputMinArity(); };
2266
+ OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp* this_) { return static_cast<int>(static_cast<const TOp*>(this_)->GetVariadicOutputHomogeneity()); }; // The bool result is converted to int for the C callback
2267
+ #ifdef __cpp_if_constexpr
2268
+ if constexpr (WithStatus) { // Status-returning path: only the V2 callbacks are assigned; CreateKernel/KernelCompute are not written in this branch
2269
+ #else
2270
+ if (WithStatus) { // Fallback without if constexpr: both branches are compiled, so TOp/TKernel must provide both sets of members
2271
+ #endif
2272
+ OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info, void** op_kernel) -> OrtStatusPtr {
2273
+ return static_cast<const TOp*>(this_)->CreateKernelV2(*api, info, op_kernel);
2274
+ };
2275
+ OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
2276
+ return static_cast<TKernel*>(op_kernel)->ComputeV2(context);
2277
+ };
2278
+ } else { // Non-status path: clear the V2 entry points and install CreateKernel/KernelCompute
2279
+ OrtCustomOp::CreateKernelV2 = nullptr;
2280
+ OrtCustomOp::KernelComputeV2 = nullptr;
2281
+
2282
+ OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast<const TOp*>(this_)->CreateKernel(*api, info); };
2283
+ OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
2284
+ static_cast<TKernel*>(op_kernel)->Compute(context);
2285
+ };
2286
+ }
2287
+
2288
+ SetShapeInferFn<TOp>(0); // The literal 0 selects the decltype(&TOp::InferOutputShape) overload when TOp declares InferOutputShape; otherwise substitution fails and the (...) overload is chosen
2289
+
2290
+ OrtCustomOp::GetStartVersion = [](const OrtCustomOp* this_) {
2291
+ return static_cast<const TOp*>(this_)->start_ver_;
2292
+ };
2293
+
2294
+ OrtCustomOp::GetEndVersion = [](const OrtCustomOp* this_) {
2295
+ return static_cast<const TOp*>(this_)->end_ver_;
2296
+ };
2297
+ }
2298
+
2299
+ // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
2300
+ const char* GetExecutionProviderType() const { return nullptr; }
2301
+
2302
+ // Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below
2303
+ // (inputs and outputs are required by default)
2304
+ OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t /*index*/) const {
2305
+ return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
2306
+ }
2307
+
2308
+ OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const {
2309
+ return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
2310
+ }
2311
+
2312
+ // Default implementation of GetInputMemoryType() that returns OrtMemTypeDefault
2313
+ OrtMemType GetInputMemoryType(size_t /*index*/) const {
2314
+ return OrtMemTypeDefault;
2315
+ }
2316
+
2317
+ // Default implementation of GetVariadicInputMinArity() returns 1 to specify that a variadic input
2318
+ // should expect at least 1 argument.
2319
+ int GetVariadicInputMinArity() const {
2320
+ return 1;
2321
+ }
2322
+
2323
+ // Default implementation of GetVariadicInputHomogeneity() returns true to specify that all arguments
2324
+ // to a variadic input should be of the same type.
2325
+ bool GetVariadicInputHomogeneity() const {
2326
+ return true;
2327
+ }
2328
+
2329
+ // Default implementation of GetVariadicOutputMinArity() returns 1 to specify that a variadic output
2330
+ // should produce at least 1 output value.
2331
+ int GetVariadicOutputMinArity() const {
2332
+ return 1;
2333
+ }
2334
+
2335
+ // Default implementation of GetVariadicOutputHomogeneity() returns true to specify that all output values
2336
+ // produced by a variadic output should be of the same type.
2337
+ bool GetVariadicOutputHomogeneity() const {
2338
+ return true;
2339
+ }
2340
+
2341
+ // Declare list of session config entries used by this Custom Op.
2342
+ // Implement this function in order to get configs from CustomOpBase::GetSessionConfigs().
2343
+ // This default implementation returns an empty vector of config entries.
2344
+ std::vector<std::string> GetSessionConfigKeys() const {
2345
+ return std::vector<std::string>{};
2346
+ }
2347
+
2348
+ template <typename C>
2349
+ decltype(&C::InferOutputShape) SetShapeInferFn(decltype(&C::InferOutputShape)) { // Viable only when C::InferOutputShape exists; installs a callback that wraps ort_ctx in a ShapeInferContext
2350
+ OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp*, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
2351
+ ShapeInferContext ctx(&GetApi(), ort_ctx);
2352
+ return C::InferOutputShape(ctx);
2353
+ };
2354
+ return {}; // The return value is a value-initialized (null) pointer; the constructor ignores it
2355
+ }
2356
+
2357
+ template <typename C>
2358
+ void SetShapeInferFn(...) { // Fallback overload: no shape-inference callback is installed
2359
+ OrtCustomOp::InferOutputShapeFn = {};
2360
+ }
2361
+
2362
+ protected:
2363
+ // Helper function that returns a map of session config entries specified by CustomOpBase::GetSessionConfigKeys.
2364
+ void GetSessionConfigs(std::unordered_map<std::string, std::string>& out, ConstSessionOptions options) const;
2365
+
2366
+ int start_ver_ = 1; // Returned by the GetStartVersion callback
2367
+ int end_ver_ = MAX_CUSTOM_OP_END_VER; // Returned by the GetEndVersion callback
2368
+ };
2369
+
2370
+ } // namespace Ort
2371
+
2372
+ #include "onnxruntime_cxx_inline.h"