com.github.asus4.onnxruntime 0.1.11 → 0.1.12
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Plugins/Android/onnxruntime-android.aar +0 -0
- package/Plugins/Linux/x64/libonnxruntime.so +0 -0
- package/Plugins/Windows/x64/onnxruntime.dll +0 -0
- package/Plugins/iOS~/onnxruntime.xcframework/Info.plist +13 -0
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Headers/onnxruntime_c_api.h +182 -15
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Headers/onnxruntime_cxx_api.h +110 -4
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Headers/onnxruntime_cxx_inline.h +189 -0
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Headers/onnxruntime_run_options_config_keys.h +32 -0
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Headers/onnxruntime_session_options_config_keys.h +258 -0
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Info.plist +2 -2
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/onnxruntime +0 -0
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Headers/onnxruntime_c_api.h +182 -15
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Headers/onnxruntime_cxx_api.h +110 -4
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Headers/onnxruntime_cxx_inline.h +189 -0
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Headers/onnxruntime_run_options_config_keys.h +32 -0
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Headers/onnxruntime_session_options_config_keys.h +258 -0
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Info.plist +2 -2
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/onnxruntime +0 -0
- package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/Headers/coreml_provider_factory.h +45 -0
- package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/Headers/cpu_provider_factory.h +19 -0
- package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/Headers/onnxruntime_c_api.h +4717 -0
- package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/Headers/onnxruntime_cxx_api.h +2372 -0
- package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/Headers/onnxruntime_cxx_inline.h +2075 -0
- package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/Headers/onnxruntime_float16.h +540 -0
- package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/Headers/onnxruntime_run_options_config_keys.h +32 -0
- package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/Headers/onnxruntime_session_options_config_keys.h +258 -0
- package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/Info.plist +20 -0
- package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/onnxruntime +0 -0
- package/Plugins/macOS/libonnxruntime.dylib +0 -0
- package/README.md +7 -7
- package/Runtime/NativeMethods.shared.cs +270 -276
- package/Runtime/OrtValue.shared.cs +7 -3
- package/Runtime/Training/NativeTrainingMethods.shared.cs +2 -2
- package/package.json +1 -1
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
@@ -4,6 +4,19 @@
|
|
|
4
4
|
<dict>
|
|
5
5
|
<key>AvailableLibraries</key>
|
|
6
6
|
<array>
|
|
7
|
+
<dict>
|
|
8
|
+
<key>LibraryIdentifier</key>
|
|
9
|
+
<string>macos-arm64_x86_64</string>
|
|
10
|
+
<key>LibraryPath</key>
|
|
11
|
+
<string>onnxruntime.framework</string>
|
|
12
|
+
<key>SupportedArchitectures</key>
|
|
13
|
+
<array>
|
|
14
|
+
<string>arm64</string>
|
|
15
|
+
<string>x86_64</string>
|
|
16
|
+
</array>
|
|
17
|
+
<key>SupportedPlatform</key>
|
|
18
|
+
<string>macos</string>
|
|
19
|
+
</dict>
|
|
7
20
|
<dict>
|
|
8
21
|
<key>LibraryIdentifier</key>
|
|
9
22
|
<string>ios-arm64</string>
|
|
@@ -29,15 +29,16 @@
|
|
|
29
29
|
*/
|
|
30
30
|
|
|
31
31
|
#pragma once
|
|
32
|
-
#include <
|
|
32
|
+
#include <stdbool.h>
|
|
33
33
|
#include <stdint.h>
|
|
34
|
+
#include <stdlib.h>
|
|
34
35
|
#include <string.h>
|
|
35
36
|
|
|
36
37
|
/** \brief The API version defined in this header
|
|
37
38
|
*
|
|
38
39
|
* This value is used by some API functions to behave as this version of the header expects.
|
|
39
40
|
*/
|
|
40
|
-
#define ORT_API_VERSION
|
|
41
|
+
#define ORT_API_VERSION 17
|
|
41
42
|
|
|
42
43
|
#ifdef __cplusplus
|
|
43
44
|
extern "C" {
|
|
@@ -299,6 +300,7 @@ ORT_RUNTIME_CLASS(DnnlProviderOptions);
|
|
|
299
300
|
ORT_RUNTIME_CLASS(Op);
|
|
300
301
|
ORT_RUNTIME_CLASS(OpAttr);
|
|
301
302
|
ORT_RUNTIME_CLASS(Logger);
|
|
303
|
+
ORT_RUNTIME_CLASS(ShapeInferContext);
|
|
302
304
|
|
|
303
305
|
#ifdef _WIN32
|
|
304
306
|
typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr;
|
|
@@ -598,9 +600,11 @@ typedef struct OrtTensorRTProviderOptions {
|
|
|
598
600
|
* \see OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX
|
|
599
601
|
*/
|
|
600
602
|
typedef struct OrtMIGraphXProviderOptions {
|
|
601
|
-
int device_id;
|
|
602
|
-
int migraphx_fp16_enable;
|
|
603
|
-
int migraphx_int8_enable;
|
|
603
|
+
int device_id; // hip device id.
|
|
604
|
+
int migraphx_fp16_enable; // MIGraphX FP16 precision. Default 0 = false, nonzero = true
|
|
605
|
+
int migraphx_int8_enable; // MIGraphX INT8 precision. Default 0 = false, nonzero = true
|
|
606
|
+
int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, noznero = true
|
|
607
|
+
const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name
|
|
604
608
|
} OrtMIGraphXProviderOptions;
|
|
605
609
|
|
|
606
610
|
/** \brief OpenVINO Provider Options
|
|
@@ -610,7 +614,7 @@ typedef struct OrtMIGraphXProviderOptions {
|
|
|
610
614
|
typedef struct OrtOpenVINOProviderOptions {
|
|
611
615
|
#ifdef __cplusplus
|
|
612
616
|
OrtOpenVINOProviderOptions() : device_type{},
|
|
613
|
-
|
|
617
|
+
enable_npu_fast_compile{},
|
|
614
618
|
device_id{},
|
|
615
619
|
num_of_threads{},
|
|
616
620
|
cache_dir{},
|
|
@@ -623,7 +627,7 @@ typedef struct OrtOpenVINOProviderOptions {
|
|
|
623
627
|
* Valid settings are one of: "CPU_FP32", "CPU_FP16", "GPU_FP32", "GPU_FP16"
|
|
624
628
|
*/
|
|
625
629
|
const char* device_type;
|
|
626
|
-
unsigned char
|
|
630
|
+
unsigned char enable_npu_fast_compile; ///< 0 = disabled, nonzero = enabled
|
|
627
631
|
const char* device_id;
|
|
628
632
|
size_t num_of_threads; ///< 0 = Use default number of threads
|
|
629
633
|
const char* cache_dir; // path is set to empty by default
|
|
@@ -745,6 +749,8 @@ struct OrtApi {
|
|
|
745
749
|
|
|
746
750
|
/** \brief Create an OrtEnv
|
|
747
751
|
*
|
|
752
|
+
* \note Invoking this function will return the same instance of the environment as that returned by a previous call
|
|
753
|
+
* to another env creation function; all arguments to this function will be ignored.
|
|
748
754
|
* \param[in] log_severity_level The log severity level.
|
|
749
755
|
* \param[in] logid The log identifier.
|
|
750
756
|
* \param[out] out Returned newly created OrtEnv. Must be freed with OrtApi::ReleaseEnv
|
|
@@ -755,17 +761,20 @@ struct OrtApi {
|
|
|
755
761
|
|
|
756
762
|
/** \brief Create an OrtEnv
|
|
757
763
|
*
|
|
764
|
+
* \note Invoking this function will return the same instance of the environment as that returned by a previous call
|
|
765
|
+
* to another env creation function; all arguments to this function will be ignored. If you want to provide your
|
|
766
|
+
* own logging function, consider setting it using the SetUserLoggingFunction API instead.
|
|
758
767
|
* \param[in] logging_function A pointer to a logging function.
|
|
759
768
|
* \param[in] logger_param A pointer to arbitrary data passed as the ::OrtLoggingFunction `param` parameter to
|
|
760
|
-
* `logging_function`.
|
|
769
|
+
* `logging_function`. This parameter is optional.
|
|
761
770
|
* \param[in] log_severity_level The log severity level.
|
|
762
771
|
* \param[in] logid The log identifier.
|
|
763
772
|
* \param[out] out Returned newly created OrtEnv. Must be freed with OrtApi::ReleaseEnv
|
|
764
773
|
*
|
|
765
774
|
* \snippet{doc} snippets.dox OrtStatus Return Value
|
|
766
775
|
*/
|
|
767
|
-
ORT_API2_STATUS(CreateEnvWithCustomLogger, OrtLoggingFunction logging_function, _In_opt_ void* logger_param,
|
|
768
|
-
OrtLoggingLevel log_severity_level, _In_ const char* logid, _Outptr_ OrtEnv** out);
|
|
776
|
+
ORT_API2_STATUS(CreateEnvWithCustomLogger, _In_ OrtLoggingFunction logging_function, _In_opt_ void* logger_param,
|
|
777
|
+
_In_ OrtLoggingLevel log_severity_level, _In_ const char* logid, _Outptr_ OrtEnv** out);
|
|
769
778
|
|
|
770
779
|
/** \brief Enable Telemetry
|
|
771
780
|
*
|
|
@@ -3585,13 +3594,28 @@ struct OrtApi {
|
|
|
3585
3594
|
*
|
|
3586
3595
|
* QNN supported keys:
|
|
3587
3596
|
* "backend_path": file path to QNN backend library.
|
|
3588
|
-
* "qnn_context_cache_enable": 1 to enable QNN graph creation from cached QNN context file. If it's enabled: QNN EP will
|
|
3589
|
-
* load from cached QNN context binary if it exist. It will generate a context binary file if it's not exist
|
|
3590
|
-
* "qnn_context_cache_path": explicitly provide the QNN context cache file. Default to model_file.onnx.bin if not provided.
|
|
3591
3597
|
* "profiling_level": QNN profiling level, options: "off", "basic", "detailed". Default to off.
|
|
3592
3598
|
* "rpc_control_latency": QNN RPC control latency.
|
|
3599
|
+
* "vtcm_mb": QNN VTCM size in MB. default to 0(not set).
|
|
3593
3600
|
* "htp_performance_mode": QNN performance mode, options: "burst", "balanced", "default", "high_performance",
|
|
3594
|
-
* "high_power_saver", "low_balanced", "low_power_saver", "power_saver", "sustained_high_performance". Default to "default".
|
|
3601
|
+
* "high_power_saver", "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver", "sustained_high_performance". Default to "default".
|
|
3602
|
+
* "qnn_saver_path": File path to the QNN Saver backend library. If specified, QNN Saver will be enabled and will
|
|
3603
|
+
* dump QNN API calls to disk for replay/debugging. QNN Saver produces incorrect model inference results and
|
|
3604
|
+
* may alter model/EP partitioning. Use only for debugging.
|
|
3605
|
+
* "qnn_context_priority": QNN context priority, options: "low", "normal", "normal_high", "high". Default to "normal".
|
|
3606
|
+
* "htp_graph_finalization_optimization_mode": Set the optimization mode for graph finalization on the HTP backend. Available options:
|
|
3607
|
+
* - "0": Default.
|
|
3608
|
+
* - "1": Faster preparation time, less optimal graph.
|
|
3609
|
+
* - "2": Longer preparation time, more optimal graph.
|
|
3610
|
+
* - "3": Longest preparation time, most likely even more optimal graph. See QNN SDK documentation for specific details.
|
|
3611
|
+
* "soc_model": The SoC model number. Refer to the QNN SDK documentation for valid values. Defaults to "0" (unknown).
|
|
3612
|
+
* "htp_arch": The minimum HTP architecture the driver will use to select compatible QNN operators. Available options:
|
|
3613
|
+
* - "0": Default (none).
|
|
3614
|
+
* - "68"
|
|
3615
|
+
* - "69"
|
|
3616
|
+
* - "73"
|
|
3617
|
+
* - "75"
|
|
3618
|
+
* "device_id": The ID of the device to use when setting 'htp_arch'. Defaults to "0" (for single device).
|
|
3595
3619
|
*
|
|
3596
3620
|
* SNPE supported keys:
|
|
3597
3621
|
* "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16",
|
|
@@ -4402,7 +4426,7 @@ struct OrtApi {
|
|
|
4402
4426
|
ORT_API2_STATUS(GetCUDAProviderOptionsByName, _In_ const OrtCUDAProviderOptionsV2* cuda_options, _In_ const char* key, _Outptr_ void** ptr);
|
|
4403
4427
|
|
|
4404
4428
|
/**
|
|
4405
|
-
* Get a EP
|
|
4429
|
+
* Get a EP resource.
|
|
4406
4430
|
* E.g. a cuda stream or a cublas handle
|
|
4407
4431
|
*
|
|
4408
4432
|
* \param context - Kernel context
|
|
@@ -4413,6 +4437,135 @@ struct OrtApi {
|
|
|
4413
4437
|
* \since Version 1.16.
|
|
4414
4438
|
*/
|
|
4415
4439
|
ORT_API2_STATUS(KernelContext_GetResource, _In_ const OrtKernelContext* context, _In_ int resouce_version, _In_ int resource_id, _Outptr_ void** resource);
|
|
4440
|
+
|
|
4441
|
+
/** \brief Set user logging function
|
|
4442
|
+
*
|
|
4443
|
+
* By default the logger created by the CreateEnv* functions is used to create the session logger as well.
|
|
4444
|
+
* This function allows a user to override this default session logger with a logger of their own choosing. This way
|
|
4445
|
+
* the user doesn't have to create a separate environment with a custom logger. This addresses the problem when
|
|
4446
|
+
* the user already created an env but now wants to use a different logger for a specific session (for debugging or
|
|
4447
|
+
* other reasons).
|
|
4448
|
+
*
|
|
4449
|
+
* \param[in] options
|
|
4450
|
+
* \param[in] user_logging_function A pointer to a logging function.
|
|
4451
|
+
* \param[in] user_logging_param A pointer to arbitrary data passed as the ::OrtLoggingFunction `param` parameter to
|
|
4452
|
+
* `user_logging_function`. This parameter is optional.
|
|
4453
|
+
*
|
|
4454
|
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
|
4455
|
+
*
|
|
4456
|
+
* \since Version 1.17.
|
|
4457
|
+
*/
|
|
4458
|
+
ORT_API2_STATUS(SetUserLoggingFunction, _Inout_ OrtSessionOptions* options,
|
|
4459
|
+
_In_ OrtLoggingFunction user_logging_function, _In_opt_ void* user_logging_param);
|
|
4460
|
+
|
|
4461
|
+
/**
|
|
4462
|
+
* Get number of input from OrtShapeInferContext
|
|
4463
|
+
*
|
|
4464
|
+
* \param[in] context
|
|
4465
|
+
* \param[out] out The number of inputs
|
|
4466
|
+
*
|
|
4467
|
+
* \since Version 1.17.
|
|
4468
|
+
*/
|
|
4469
|
+
ORT_API2_STATUS(ShapeInferContext_GetInputCount, _In_ const OrtShapeInferContext* context, _Out_ size_t* out);
|
|
4470
|
+
|
|
4471
|
+
/**
|
|
4472
|
+
* Get type and shape info of an input
|
|
4473
|
+
*
|
|
4474
|
+
* \param[in] context
|
|
4475
|
+
* \param[in] index The index of the input
|
|
4476
|
+
* \param[out] info Type shape info of the input
|
|
4477
|
+
*
|
|
4478
|
+
* \since Version 1.17.
|
|
4479
|
+
*/
|
|
4480
|
+
ORT_API2_STATUS(ShapeInferContext_GetInputTypeShape, _In_ const OrtShapeInferContext* context, _In_ size_t index, _Outptr_ OrtTensorTypeAndShapeInfo** info);
|
|
4481
|
+
|
|
4482
|
+
/**
|
|
4483
|
+
* Get attribute from OrtShapeInferContext. Note that OrtShapeInferContext is a per-node context, one could only read attribute from current node.
|
|
4484
|
+
*
|
|
4485
|
+
* \param[in] context
|
|
4486
|
+
* \param[in] attr_name Name of the attribute
|
|
4487
|
+
* \param[out] attr Handle of the attribute fetched
|
|
4488
|
+
*
|
|
4489
|
+
* \since Version 1.17.
|
|
4490
|
+
*/
|
|
4491
|
+
ORT_API2_STATUS(ShapeInferContext_GetAttribute, _In_ const OrtShapeInferContext* context, _In_ const char* attr_name, _Outptr_ const OrtOpAttr** attr);
|
|
4492
|
+
|
|
4493
|
+
/**
|
|
4494
|
+
* Set type and shape info of an ouput
|
|
4495
|
+
*
|
|
4496
|
+
* \param[in] context
|
|
4497
|
+
* \param[in] index The index of the ouput
|
|
4498
|
+
* \param[out] info Type shape info of the output
|
|
4499
|
+
*
|
|
4500
|
+
* \since Version 1.17.
|
|
4501
|
+
*/
|
|
4502
|
+
ORT_API2_STATUS(ShapeInferContext_SetOutputTypeShape, _In_ const OrtShapeInferContext* context, _In_ size_t index, _In_ const OrtTensorTypeAndShapeInfo* info);
|
|
4503
|
+
|
|
4504
|
+
/**
|
|
4505
|
+
* Set symbolic shape to type shape info
|
|
4506
|
+
*
|
|
4507
|
+
* \param[in] info Type shape info
|
|
4508
|
+
* \param[in] dim_params Symbolic strings
|
|
4509
|
+
* \param[in] dim_params_length Number of strings
|
|
4510
|
+
*
|
|
4511
|
+
* \since Version 1.17.
|
|
4512
|
+
*/
|
|
4513
|
+
ORT_API2_STATUS(SetSymbolicDimensions, _In_ OrtTensorTypeAndShapeInfo* info, _In_ const char* dim_params[], _In_ size_t dim_params_length);
|
|
4514
|
+
|
|
4515
|
+
/**
|
|
4516
|
+
* Read contents of an attribute to data
|
|
4517
|
+
*
|
|
4518
|
+
* \param[in] op_attr
|
|
4519
|
+
* \param[in] type Attribute type
|
|
4520
|
+
* \param[out] data Memory address to save raw content of the attribute
|
|
4521
|
+
* \param[in] len Number of bytes allowed to store in data
|
|
4522
|
+
* \param[out] out Number of bytes required to save the data when the call failed, or the real number of bytes saved to data on success
|
|
4523
|
+
*
|
|
4524
|
+
* \since Version 1.17.
|
|
4525
|
+
*/
|
|
4526
|
+
ORT_API2_STATUS(ReadOpAttr, _In_ const OrtOpAttr* op_attr, _In_ OrtOpAttrType type, _Inout_ void* data, _In_ size_t len, _Out_ size_t* out);
|
|
4527
|
+
|
|
4528
|
+
/** \brief Set whether to use deterministic compute.
|
|
4529
|
+
*
|
|
4530
|
+
* Default is false. If set to true, this will enable deterministic compute for GPU kernels where possible.
|
|
4531
|
+
* Note that this most likely will have a performance cost.
|
|
4532
|
+
*
|
|
4533
|
+
* \param[in] options
|
|
4534
|
+
* \param[in] value
|
|
4535
|
+
*
|
|
4536
|
+
* \since Version 1.17.
|
|
4537
|
+
*/
|
|
4538
|
+
ORT_API2_STATUS(SetDeterministicCompute, _Inout_ OrtSessionOptions* options, bool value);
|
|
4539
|
+
|
|
4540
|
+
/**
|
|
4541
|
+
* Run fn in parallel
|
|
4542
|
+
*
|
|
4543
|
+
* \param[in] context
|
|
4544
|
+
* \param[in] fn Function accepting usr_data and an integer as iterator
|
|
4545
|
+
* \param[in] total The number of times fn is to be invoked
|
|
4546
|
+
* \param[in] num_batch Number of batches by which the "total" is to be divided in maximum. When zero, there is no limit
|
|
4547
|
+
* \param[in] usr_data User data to be passed back to fn
|
|
4548
|
+
*
|
|
4549
|
+
* \since Version 1.17.
|
|
4550
|
+
*/
|
|
4551
|
+
ORT_API2_STATUS(KernelContext_ParallelFor, _In_ const OrtKernelContext* context, _In_ void (*fn)(void*, size_t), _In_ size_t total, _In_ size_t num_batch, _In_ void* usr_data);
|
|
4552
|
+
|
|
4553
|
+
/** \brief Append OpenVINO execution provider to the session options
|
|
4554
|
+
*
|
|
4555
|
+
* If OpenVINO is not available (due to a non OpenVINO enabled build, or if OpenVINO is not installed on the system), this function will fail.
|
|
4556
|
+
*
|
|
4557
|
+
* \param[in] options
|
|
4558
|
+
* \param[in] provider_options_keys
|
|
4559
|
+
* \param[in] provider_options_values
|
|
4560
|
+
* \param[in] num_keys
|
|
4561
|
+
*
|
|
4562
|
+
* \snippet{doc} snippets.dox OrtStatus Return Value
|
|
4563
|
+
*/
|
|
4564
|
+
ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_OpenVINO_V2,
|
|
4565
|
+
_In_ OrtSessionOptions* options,
|
|
4566
|
+
_In_reads_(num_keys) const char* const* provider_options_keys,
|
|
4567
|
+
_In_reads_(num_keys) const char* const* provider_options_values,
|
|
4568
|
+
_In_ size_t num_keys);
|
|
4416
4569
|
};
|
|
4417
4570
|
|
|
4418
4571
|
/*
|
|
@@ -4504,6 +4657,12 @@ struct OrtCustomOp {
|
|
|
4504
4657
|
|
|
4505
4658
|
// Perform the computation step.
|
|
4506
4659
|
OrtStatusPtr(ORT_API_CALL* KernelComputeV2)(_In_ void* op_kernel, _In_ OrtKernelContext* context);
|
|
4660
|
+
|
|
4661
|
+
OrtStatusPtr(ORT_API_CALL* InferOutputShapeFn)(_In_ const struct OrtCustomOp* op, _In_ OrtShapeInferContext*);
|
|
4662
|
+
|
|
4663
|
+
// Get start range
|
|
4664
|
+
int(ORT_API_CALL* GetStartVersion)(_In_ const struct OrtCustomOp* op);
|
|
4665
|
+
int(ORT_API_CALL* GetEndVersion)(_In_ const struct OrtCustomOp* op);
|
|
4507
4666
|
};
|
|
4508
4667
|
|
|
4509
4668
|
/*
|
|
@@ -4544,6 +4703,14 @@ ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_MIGraphX, _In_ OrtSessio
|
|
|
4544
4703
|
*/
|
|
4545
4704
|
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Dnnl, _In_ OrtSessionOptions* options, int use_arena);
|
|
4546
4705
|
|
|
4706
|
+
/*
|
|
4707
|
+
* This is the old way to add the TensorRT provider to the session, please use SessionOptionsAppendExecutionProvider_TensorRT_V2 above to access the latest functionality
|
|
4708
|
+
* This function always exists, but will only succeed if Onnxruntime was built with TensorRT support and the TensorRT provider shared library exists
|
|
4709
|
+
*
|
|
4710
|
+
* \param device_id CUDA device id, starts from zero.
|
|
4711
|
+
*/
|
|
4712
|
+
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtSessionOptions* options, int device_id);
|
|
4713
|
+
|
|
4547
4714
|
#ifdef __cplusplus
|
|
4548
4715
|
}
|
|
4549
4716
|
#endif
|
|
@@ -845,6 +845,7 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl<T> {
|
|
|
845
845
|
SessionOptionsImpl& SetIntraOpNumThreads(int intra_op_num_threads); ///< Wraps OrtApi::SetIntraOpNumThreads
|
|
846
846
|
SessionOptionsImpl& SetInterOpNumThreads(int inter_op_num_threads); ///< Wraps OrtApi::SetInterOpNumThreads
|
|
847
847
|
SessionOptionsImpl& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); ///< Wraps OrtApi::SetSessionGraphOptimizationLevel
|
|
848
|
+
SessionOptionsImpl& SetDeterministicCompute(bool value); ///< Wraps OrtApi::SetDeterministicCompute
|
|
848
849
|
|
|
849
850
|
SessionOptionsImpl& EnableCpuMemArena(); ///< Wraps OrtApi::EnableCpuMemArena
|
|
850
851
|
SessionOptionsImpl& DisableCpuMemArena(); ///< Wraps OrtApi::DisableCpuMemArena
|
|
@@ -873,10 +874,12 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl<T> {
|
|
|
873
874
|
SessionOptionsImpl& AddInitializer(const char* name, const OrtValue* ort_val); ///< Wraps OrtApi::AddInitializer
|
|
874
875
|
SessionOptionsImpl& AddExternalInitializers(const std::vector<std::string>& names, const std::vector<Value>& ort_values); ///< Wraps OrtApi::AddExternalInitializers
|
|
875
876
|
|
|
876
|
-
SessionOptionsImpl& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options);
|
|
877
|
-
SessionOptionsImpl& AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options);
|
|
878
|
-
SessionOptionsImpl& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options);
|
|
879
|
-
SessionOptionsImpl& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options);
|
|
877
|
+
SessionOptionsImpl& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA
|
|
878
|
+
SessionOptionsImpl& AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA_V2
|
|
879
|
+
SessionOptionsImpl& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM
|
|
880
|
+
SessionOptionsImpl& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO
|
|
881
|
+
///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO_V2
|
|
882
|
+
SessionOptionsImpl& AppendExecutionProvider_OpenVINO_V2(const std::unordered_map<std::string, std::string>& provider_options = {});
|
|
880
883
|
SessionOptionsImpl& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT
|
|
881
884
|
SessionOptionsImpl& AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT
|
|
882
885
|
SessionOptionsImpl& AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX
|
|
@@ -2055,6 +2058,8 @@ struct KernelContext {
|
|
|
2055
2058
|
void* GetGPUComputeStream() const;
|
|
2056
2059
|
Logger GetLogger() const;
|
|
2057
2060
|
OrtAllocator* GetAllocator(const OrtMemoryInfo& memory_info) const;
|
|
2061
|
+
OrtKernelContext* GetOrtKernelContext() const { return ctx_; }
|
|
2062
|
+
void ParallelFor(void (*fn)(void*, size_t), size_t total, size_t num_batch, void* usr_data) const;
|
|
2058
2063
|
|
|
2059
2064
|
private:
|
|
2060
2065
|
OrtKernelContext* ctx_;
|
|
@@ -2155,6 +2160,80 @@ struct Op : detail::Base<OrtOp> {
|
|
|
2155
2160
|
size_t output_count);
|
|
2156
2161
|
};
|
|
2157
2162
|
|
|
2163
|
+
/// <summary>
|
|
2164
|
+
/// Provide access to per-node attributes and input shapes, so one could compute and set output shapes.
|
|
2165
|
+
/// </summary>
|
|
2166
|
+
struct ShapeInferContext {
|
|
2167
|
+
struct SymbolicInteger {
|
|
2168
|
+
SymbolicInteger(int64_t i) : i_(i), is_int_(true){};
|
|
2169
|
+
SymbolicInteger(const char* s) : s_(s), is_int_(false){};
|
|
2170
|
+
SymbolicInteger(const SymbolicInteger&) = default;
|
|
2171
|
+
SymbolicInteger(SymbolicInteger&&) = default;
|
|
2172
|
+
|
|
2173
|
+
SymbolicInteger& operator=(const SymbolicInteger&) = default;
|
|
2174
|
+
SymbolicInteger& operator=(SymbolicInteger&&) = default;
|
|
2175
|
+
|
|
2176
|
+
bool operator==(const SymbolicInteger& dim) const {
|
|
2177
|
+
if (is_int_ == dim.is_int_) {
|
|
2178
|
+
if (is_int_) {
|
|
2179
|
+
return i_ == dim.i_;
|
|
2180
|
+
} else {
|
|
2181
|
+
return std::string{s_} == std::string{dim.s_};
|
|
2182
|
+
}
|
|
2183
|
+
}
|
|
2184
|
+
return false;
|
|
2185
|
+
}
|
|
2186
|
+
|
|
2187
|
+
bool IsInt() const { return is_int_; }
|
|
2188
|
+
int64_t AsInt() const { return i_; }
|
|
2189
|
+
const char* AsSym() const { return s_; }
|
|
2190
|
+
|
|
2191
|
+
static constexpr int INVALID_INT_DIM = -2;
|
|
2192
|
+
|
|
2193
|
+
private:
|
|
2194
|
+
union {
|
|
2195
|
+
int64_t i_;
|
|
2196
|
+
const char* s_;
|
|
2197
|
+
};
|
|
2198
|
+
bool is_int_;
|
|
2199
|
+
};
|
|
2200
|
+
|
|
2201
|
+
using Shape = std::vector<SymbolicInteger>;
|
|
2202
|
+
|
|
2203
|
+
ShapeInferContext(const OrtApi* ort_api, OrtShapeInferContext* ctx);
|
|
2204
|
+
|
|
2205
|
+
const Shape& GetInputShape(size_t indice) const { return input_shapes_.at(indice); }
|
|
2206
|
+
|
|
2207
|
+
size_t GetInputCount() const { return input_shapes_.size(); }
|
|
2208
|
+
|
|
2209
|
+
Status SetOutputShape(size_t indice, const Shape& shape);
|
|
2210
|
+
|
|
2211
|
+
int64_t GetAttrInt(const char* attr_name);
|
|
2212
|
+
|
|
2213
|
+
using Ints = std::vector<int64_t>;
|
|
2214
|
+
Ints GetAttrInts(const char* attr_name);
|
|
2215
|
+
|
|
2216
|
+
float GetAttrFloat(const char* attr_name);
|
|
2217
|
+
|
|
2218
|
+
using Floats = std::vector<float>;
|
|
2219
|
+
Floats GetAttrFloats(const char* attr_name);
|
|
2220
|
+
|
|
2221
|
+
std::string GetAttrString(const char* attr_name);
|
|
2222
|
+
|
|
2223
|
+
using Strings = std::vector<std::string>;
|
|
2224
|
+
Strings GetAttrStrings(const char* attr_name);
|
|
2225
|
+
|
|
2226
|
+
private:
|
|
2227
|
+
const OrtOpAttr* GetAttrHdl(const char* attr_name) const;
|
|
2228
|
+
const OrtApi* ort_api_;
|
|
2229
|
+
OrtShapeInferContext* ctx_;
|
|
2230
|
+
std::vector<Shape> input_shapes_;
|
|
2231
|
+
};
|
|
2232
|
+
|
|
2233
|
+
using ShapeInferFn = Ort::Status (*)(Ort::ShapeInferContext&);
|
|
2234
|
+
|
|
2235
|
+
#define MAX_CUSTOM_OP_END_VER (1UL << 31) - 1
|
|
2236
|
+
|
|
2158
2237
|
template <typename TOp, typename TKernel, bool WithStatus = false>
|
|
2159
2238
|
struct CustomOpBase : OrtCustomOp {
|
|
2160
2239
|
CustomOpBase() {
|
|
@@ -2205,6 +2284,16 @@ struct CustomOpBase : OrtCustomOp {
|
|
|
2205
2284
|
static_cast<TKernel*>(op_kernel)->Compute(context);
|
|
2206
2285
|
};
|
|
2207
2286
|
}
|
|
2287
|
+
|
|
2288
|
+
SetShapeInferFn<TOp>(0);
|
|
2289
|
+
|
|
2290
|
+
OrtCustomOp::GetStartVersion = [](const OrtCustomOp* this_) {
|
|
2291
|
+
return static_cast<const TOp*>(this_)->start_ver_;
|
|
2292
|
+
};
|
|
2293
|
+
|
|
2294
|
+
OrtCustomOp::GetEndVersion = [](const OrtCustomOp* this_) {
|
|
2295
|
+
return static_cast<const TOp*>(this_)->end_ver_;
|
|
2296
|
+
};
|
|
2208
2297
|
}
|
|
2209
2298
|
|
|
2210
2299
|
// Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
|
|
@@ -2256,9 +2345,26 @@ struct CustomOpBase : OrtCustomOp {
|
|
|
2256
2345
|
return std::vector<std::string>{};
|
|
2257
2346
|
}
|
|
2258
2347
|
|
|
2348
|
+
template <typename C>
|
|
2349
|
+
decltype(&C::InferOutputShape) SetShapeInferFn(decltype(&C::InferOutputShape)) {
|
|
2350
|
+
OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp*, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
|
|
2351
|
+
ShapeInferContext ctx(&GetApi(), ort_ctx);
|
|
2352
|
+
return C::InferOutputShape(ctx);
|
|
2353
|
+
};
|
|
2354
|
+
return {};
|
|
2355
|
+
}
|
|
2356
|
+
|
|
2357
|
+
template <typename C>
|
|
2358
|
+
void SetShapeInferFn(...) {
|
|
2359
|
+
OrtCustomOp::InferOutputShapeFn = {};
|
|
2360
|
+
}
|
|
2361
|
+
|
|
2259
2362
|
protected:
|
|
2260
2363
|
// Helper function that returns a map of session config entries specified by CustomOpBase::GetSessionConfigKeys.
|
|
2261
2364
|
void GetSessionConfigs(std::unordered_map<std::string, std::string>& out, ConstSessionOptions options) const;
|
|
2365
|
+
|
|
2366
|
+
int start_ver_ = 1;
|
|
2367
|
+
int end_ver_ = MAX_CUSTOM_OP_END_VER;
|
|
2262
2368
|
};
|
|
2263
2369
|
|
|
2264
2370
|
} // namespace Ort
|