com.github.asus4.onnxruntime 0.1.10 → 0.1.12

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34) hide show
  1. package/Plugins/Android/onnxruntime-android.aar +0 -0
  2. package/Plugins/Linux/x64/libonnxruntime.so +0 -0
  3. package/Plugins/Windows/x64/onnxruntime.dll +0 -0
  4. package/Plugins/iOS~/onnxruntime.xcframework/Info.plist +13 -0
  5. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Headers/onnxruntime_c_api.h +182 -15
  6. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Headers/onnxruntime_cxx_api.h +110 -4
  7. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Headers/onnxruntime_cxx_inline.h +189 -0
  8. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Headers/onnxruntime_run_options_config_keys.h +32 -0
  9. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Headers/onnxruntime_session_options_config_keys.h +258 -0
  10. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Info.plist +2 -2
  11. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/onnxruntime +0 -0
  12. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Headers/onnxruntime_c_api.h +182 -15
  13. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Headers/onnxruntime_cxx_api.h +110 -4
  14. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Headers/onnxruntime_cxx_inline.h +189 -0
  15. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Headers/onnxruntime_run_options_config_keys.h +32 -0
  16. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Headers/onnxruntime_session_options_config_keys.h +258 -0
  17. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Info.plist +2 -2
  18. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/onnxruntime +0 -0
  19. package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/Headers/coreml_provider_factory.h +45 -0
  20. package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/Headers/cpu_provider_factory.h +19 -0
  21. package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/Headers/onnxruntime_c_api.h +4717 -0
  22. package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/Headers/onnxruntime_cxx_api.h +2372 -0
  23. package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/Headers/onnxruntime_cxx_inline.h +2075 -0
  24. package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/Headers/onnxruntime_float16.h +540 -0
  25. package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/Headers/onnxruntime_run_options_config_keys.h +32 -0
  26. package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/Headers/onnxruntime_session_options_config_keys.h +258 -0
  27. package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/Info.plist +20 -0
  28. package/Plugins/iOS~/onnxruntime.xcframework/macos-arm64_x86_64/onnxruntime.framework/onnxruntime +0 -0
  29. package/Plugins/macOS/libonnxruntime.dylib +0 -0
  30. package/README.md +8 -8
  31. package/Runtime/NativeMethods.shared.cs +270 -276
  32. package/Runtime/OrtValue.shared.cs +7 -3
  33. package/Runtime/Training/NativeTrainingMethods.shared.cs +2 -2
  34. package/package.json +1 -1
Binary file
Binary file
@@ -4,6 +4,19 @@
4
4
  <dict>
5
5
  <key>AvailableLibraries</key>
6
6
  <array>
7
+ <dict>
8
+ <key>LibraryIdentifier</key>
9
+ <string>macos-arm64_x86_64</string>
10
+ <key>LibraryPath</key>
11
+ <string>onnxruntime.framework</string>
12
+ <key>SupportedArchitectures</key>
13
+ <array>
14
+ <string>arm64</string>
15
+ <string>x86_64</string>
16
+ </array>
17
+ <key>SupportedPlatform</key>
18
+ <string>macos</string>
19
+ </dict>
7
20
  <dict>
8
21
  <key>LibraryIdentifier</key>
9
22
  <string>ios-arm64</string>
@@ -29,15 +29,16 @@
29
29
  */
30
30
 
31
31
  #pragma once
32
- #include <stdlib.h>
32
+ #include <stdbool.h>
33
33
  #include <stdint.h>
34
+ #include <stdlib.h>
34
35
  #include <string.h>
35
36
 
36
37
  /** \brief The API version defined in this header
37
38
  *
38
39
  * This value is used by some API functions to behave as this version of the header expects.
39
40
  */
40
- #define ORT_API_VERSION 16
41
+ #define ORT_API_VERSION 17
41
42
 
42
43
  #ifdef __cplusplus
43
44
  extern "C" {
@@ -299,6 +300,7 @@ ORT_RUNTIME_CLASS(DnnlProviderOptions);
299
300
  ORT_RUNTIME_CLASS(Op);
300
301
  ORT_RUNTIME_CLASS(OpAttr);
301
302
  ORT_RUNTIME_CLASS(Logger);
303
+ ORT_RUNTIME_CLASS(ShapeInferContext);
302
304
 
303
305
  #ifdef _WIN32
304
306
  typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr;
@@ -598,9 +600,11 @@ typedef struct OrtTensorRTProviderOptions {
598
600
  * \see OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX
599
601
  */
600
602
  typedef struct OrtMIGraphXProviderOptions {
601
- int device_id; // hip device id.
602
- int migraphx_fp16_enable; // enable MIGraphX FP16 precision. Default 0 = false, nonzero = true
603
- int migraphx_int8_enable; // enable MIGraphX INT8 precision. Default 0 = false, nonzero = true
603
+ int device_id; // hip device id.
604
+ int migraphx_fp16_enable; // MIGraphX FP16 precision. Default 0 = false, nonzero = true
605
+ int migraphx_int8_enable; // MIGraphX INT8 precision. Default 0 = false, nonzero = true
606
+ int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, noznero = true
607
+ const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name
604
608
  } OrtMIGraphXProviderOptions;
605
609
 
606
610
  /** \brief OpenVINO Provider Options
@@ -610,7 +614,7 @@ typedef struct OrtMIGraphXProviderOptions {
610
614
  typedef struct OrtOpenVINOProviderOptions {
611
615
  #ifdef __cplusplus
612
616
  OrtOpenVINOProviderOptions() : device_type{},
613
- enable_vpu_fast_compile{},
617
+ enable_npu_fast_compile{},
614
618
  device_id{},
615
619
  num_of_threads{},
616
620
  cache_dir{},
@@ -623,7 +627,7 @@ typedef struct OrtOpenVINOProviderOptions {
623
627
  * Valid settings are one of: "CPU_FP32", "CPU_FP16", "GPU_FP32", "GPU_FP16"
624
628
  */
625
629
  const char* device_type;
626
- unsigned char enable_vpu_fast_compile; ///< 0 = disabled, nonzero = enabled
630
+ unsigned char enable_npu_fast_compile; ///< 0 = disabled, nonzero = enabled
627
631
  const char* device_id;
628
632
  size_t num_of_threads; ///< 0 = Use default number of threads
629
633
  const char* cache_dir; // path is set to empty by default
@@ -745,6 +749,8 @@ struct OrtApi {
745
749
 
746
750
  /** \brief Create an OrtEnv
747
751
  *
752
+ * \note Invoking this function will return the same instance of the environment as that returned by a previous call
753
+ * to another env creation function; all arguments to this function will be ignored.
748
754
  * \param[in] log_severity_level The log severity level.
749
755
  * \param[in] logid The log identifier.
750
756
  * \param[out] out Returned newly created OrtEnv. Must be freed with OrtApi::ReleaseEnv
@@ -755,17 +761,20 @@ struct OrtApi {
755
761
 
756
762
  /** \brief Create an OrtEnv
757
763
  *
764
+ * \note Invoking this function will return the same instance of the environment as that returned by a previous call
765
+ * to another env creation function; all arguments to this function will be ignored. If you want to provide your
766
+ * own logging function, consider setting it using the SetUserLoggingFunction API instead.
758
767
  * \param[in] logging_function A pointer to a logging function.
759
768
  * \param[in] logger_param A pointer to arbitrary data passed as the ::OrtLoggingFunction `param` parameter to
760
- * `logging_function`.
769
+ * `logging_function`. This parameter is optional.
761
770
  * \param[in] log_severity_level The log severity level.
762
771
  * \param[in] logid The log identifier.
763
772
  * \param[out] out Returned newly created OrtEnv. Must be freed with OrtApi::ReleaseEnv
764
773
  *
765
774
  * \snippet{doc} snippets.dox OrtStatus Return Value
766
775
  */
767
- ORT_API2_STATUS(CreateEnvWithCustomLogger, OrtLoggingFunction logging_function, _In_opt_ void* logger_param,
768
- OrtLoggingLevel log_severity_level, _In_ const char* logid, _Outptr_ OrtEnv** out);
776
+ ORT_API2_STATUS(CreateEnvWithCustomLogger, _In_ OrtLoggingFunction logging_function, _In_opt_ void* logger_param,
777
+ _In_ OrtLoggingLevel log_severity_level, _In_ const char* logid, _Outptr_ OrtEnv** out);
769
778
 
770
779
  /** \brief Enable Telemetry
771
780
  *
@@ -3585,13 +3594,28 @@ struct OrtApi {
3585
3594
  *
3586
3595
  * QNN supported keys:
3587
3596
  * "backend_path": file path to QNN backend library.
3588
- * "qnn_context_cache_enable": 1 to enable QNN graph creation from cached QNN context file. If it's enabled: QNN EP will
3589
- * load from cached QNN context binary if it exist. It will generate a context binary file if it's not exist
3590
- * "qnn_context_cache_path": explicitly provide the QNN context cache file. Default to model_file.onnx.bin if not provided.
3591
3597
  * "profiling_level": QNN profiling level, options: "off", "basic", "detailed". Default to off.
3592
3598
  * "rpc_control_latency": QNN RPC control latency.
3599
+ * "vtcm_mb": QNN VTCM size in MB. default to 0(not set).
3593
3600
  * "htp_performance_mode": QNN performance mode, options: "burst", "balanced", "default", "high_performance",
3594
- * "high_power_saver", "low_balanced", "low_power_saver", "power_saver", "sustained_high_performance". Default to "default".
3601
+ * "high_power_saver", "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver", "sustained_high_performance". Default to "default".
3602
+ * "qnn_saver_path": File path to the QNN Saver backend library. If specified, QNN Saver will be enabled and will
3603
+ * dump QNN API calls to disk for replay/debugging. QNN Saver produces incorrect model inference results and
3604
+ * may alter model/EP partitioning. Use only for debugging.
3605
+ * "qnn_context_priority": QNN context priority, options: "low", "normal", "normal_high", "high". Default to "normal".
3606
+ * "htp_graph_finalization_optimization_mode": Set the optimization mode for graph finalization on the HTP backend. Available options:
3607
+ * - "0": Default.
3608
+ * - "1": Faster preparation time, less optimal graph.
3609
+ * - "2": Longer preparation time, more optimal graph.
3610
+ * - "3": Longest preparation time, most likely even more optimal graph. See QNN SDK documentation for specific details.
3611
+ * "soc_model": The SoC model number. Refer to the QNN SDK documentation for valid values. Defaults to "0" (unknown).
3612
+ * "htp_arch": The minimum HTP architecture the driver will use to select compatible QNN operators. Available options:
3613
+ * - "0": Default (none).
3614
+ * - "68"
3615
+ * - "69"
3616
+ * - "73"
3617
+ * - "75"
3618
+ * "device_id": The ID of the device to use when setting 'htp_arch'. Defaults to "0" (for single device).
3595
3619
  *
3596
3620
  * SNPE supported keys:
3597
3621
  * "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16",
@@ -4402,7 +4426,7 @@ struct OrtApi {
4402
4426
  ORT_API2_STATUS(GetCUDAProviderOptionsByName, _In_ const OrtCUDAProviderOptionsV2* cuda_options, _In_ const char* key, _Outptr_ void** ptr);
4403
4427
 
4404
4428
  /**
4405
- * Get a EP resoure.
4429
+ * Get a EP resource.
4406
4430
  * E.g. a cuda stream or a cublas handle
4407
4431
  *
4408
4432
  * \param context - Kernel context
@@ -4413,6 +4437,135 @@ struct OrtApi {
4413
4437
  * \since Version 1.16.
4414
4438
  */
4415
4439
  ORT_API2_STATUS(KernelContext_GetResource, _In_ const OrtKernelContext* context, _In_ int resouce_version, _In_ int resource_id, _Outptr_ void** resource);
4440
+
4441
+ /** \brief Set user logging function
4442
+ *
4443
+ * By default the logger created by the CreateEnv* functions is used to create the session logger as well.
4444
+ * This function allows a user to override this default session logger with a logger of their own choosing. This way
4445
+ * the user doesn't have to create a separate environment with a custom logger. This addresses the problem when
4446
+ * the user already created an env but now wants to use a different logger for a specific session (for debugging or
4447
+ * other reasons).
4448
+ *
4449
+ * \param[in] options
4450
+ * \param[in] user_logging_function A pointer to a logging function.
4451
+ * \param[in] user_logging_param A pointer to arbitrary data passed as the ::OrtLoggingFunction `param` parameter to
4452
+ * `user_logging_function`. This parameter is optional.
4453
+ *
4454
+ * \snippet{doc} snippets.dox OrtStatus Return Value
4455
+ *
4456
+ * \since Version 1.17.
4457
+ */
4458
+ ORT_API2_STATUS(SetUserLoggingFunction, _Inout_ OrtSessionOptions* options,
4459
+ _In_ OrtLoggingFunction user_logging_function, _In_opt_ void* user_logging_param);
4460
+
4461
+ /**
4462
+ * Get number of input from OrtShapeInferContext
4463
+ *
4464
+ * \param[in] context
4465
+ * \param[out] out The number of inputs
4466
+ *
4467
+ * \since Version 1.17.
4468
+ */
4469
+ ORT_API2_STATUS(ShapeInferContext_GetInputCount, _In_ const OrtShapeInferContext* context, _Out_ size_t* out);
4470
+
4471
+ /**
4472
+ * Get type and shape info of an input
4473
+ *
4474
+ * \param[in] context
4475
+ * \param[in] index The index of the input
4476
+ * \param[out] info Type shape info of the input
4477
+ *
4478
+ * \since Version 1.17.
4479
+ */
4480
+ ORT_API2_STATUS(ShapeInferContext_GetInputTypeShape, _In_ const OrtShapeInferContext* context, _In_ size_t index, _Outptr_ OrtTensorTypeAndShapeInfo** info);
4481
+
4482
+ /**
4483
+ * Get attribute from OrtShapeInferContext. Note that OrtShapeInferContext is a per-node context, one could only read attribute from current node.
4484
+ *
4485
+ * \param[in] context
4486
+ * \param[in] attr_name Name of the attribute
4487
+ * \param[out] attr Handle of the attribute fetched
4488
+ *
4489
+ * \since Version 1.17.
4490
+ */
4491
+ ORT_API2_STATUS(ShapeInferContext_GetAttribute, _In_ const OrtShapeInferContext* context, _In_ const char* attr_name, _Outptr_ const OrtOpAttr** attr);
4492
+
4493
+ /**
4494
+ * Set type and shape info of an ouput
4495
+ *
4496
+ * \param[in] context
4497
+ * \param[in] index The index of the ouput
4498
+ * \param[out] info Type shape info of the output
4499
+ *
4500
+ * \since Version 1.17.
4501
+ */
4502
+ ORT_API2_STATUS(ShapeInferContext_SetOutputTypeShape, _In_ const OrtShapeInferContext* context, _In_ size_t index, _In_ const OrtTensorTypeAndShapeInfo* info);
4503
+
4504
+ /**
4505
+ * Set symbolic shape to type shape info
4506
+ *
4507
+ * \param[in] info Type shape info
4508
+ * \param[in] dim_params Symbolic strings
4509
+ * \param[in] dim_params_length Number of strings
4510
+ *
4511
+ * \since Version 1.17.
4512
+ */
4513
+ ORT_API2_STATUS(SetSymbolicDimensions, _In_ OrtTensorTypeAndShapeInfo* info, _In_ const char* dim_params[], _In_ size_t dim_params_length);
4514
+
4515
+ /**
4516
+ * Read contents of an attribute to data
4517
+ *
4518
+ * \param[in] op_attr
4519
+ * \param[in] type Attribute type
4520
+ * \param[out] data Memory address to save raw content of the attribute
4521
+ * \param[in] len Number of bytes allowed to store in data
4522
+ * \param[out] out Number of bytes required to save the data when the call failed, or the real number of bytes saved to data on success
4523
+ *
4524
+ * \since Version 1.17.
4525
+ */
4526
+ ORT_API2_STATUS(ReadOpAttr, _In_ const OrtOpAttr* op_attr, _In_ OrtOpAttrType type, _Inout_ void* data, _In_ size_t len, _Out_ size_t* out);
4527
+
4528
+ /** \brief Set whether to use deterministic compute.
4529
+ *
4530
+ * Default is false. If set to true, this will enable deterministic compute for GPU kernels where possible.
4531
+ * Note that this most likely will have a performance cost.
4532
+ *
4533
+ * \param[in] options
4534
+ * \param[in] value
4535
+ *
4536
+ * \since Version 1.17.
4537
+ */
4538
+ ORT_API2_STATUS(SetDeterministicCompute, _Inout_ OrtSessionOptions* options, bool value);
4539
+
4540
+ /**
4541
+ * Run fn in parallel
4542
+ *
4543
+ * \param[in] context
4544
+ * \param[in] fn Function accepting usr_data and an integer as iterator
4545
+ * \param[in] total The number of times fn is to be invoked
4546
+ * \param[in] num_batch Number of batches by which the "total" is to be divided in maximum. When zero, there is no limit
4547
+ * \param[in] usr_data User data to be passed back to fn
4548
+ *
4549
+ * \since Version 1.17.
4550
+ */
4551
+ ORT_API2_STATUS(KernelContext_ParallelFor, _In_ const OrtKernelContext* context, _In_ void (*fn)(void*, size_t), _In_ size_t total, _In_ size_t num_batch, _In_ void* usr_data);
4552
+
4553
+ /** \brief Append OpenVINO execution provider to the session options
4554
+ *
4555
+ * If OpenVINO is not available (due to a non OpenVINO enabled build, or if OpenVINO is not installed on the system), this function will fail.
4556
+ *
4557
+ * \param[in] options
4558
+ * \param[in] provider_options_keys
4559
+ * \param[in] provider_options_values
4560
+ * \param[in] num_keys
4561
+ *
4562
+ * \snippet{doc} snippets.dox OrtStatus Return Value
4563
+ */
4564
+ ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_OpenVINO_V2,
4565
+ _In_ OrtSessionOptions* options,
4566
+ _In_reads_(num_keys) const char* const* provider_options_keys,
4567
+ _In_reads_(num_keys) const char* const* provider_options_values,
4568
+ _In_ size_t num_keys);
4416
4569
  };
4417
4570
 
4418
4571
  /*
@@ -4504,6 +4657,12 @@ struct OrtCustomOp {
4504
4657
 
4505
4658
  // Perform the computation step.
4506
4659
  OrtStatusPtr(ORT_API_CALL* KernelComputeV2)(_In_ void* op_kernel, _In_ OrtKernelContext* context);
4660
+
4661
+ OrtStatusPtr(ORT_API_CALL* InferOutputShapeFn)(_In_ const struct OrtCustomOp* op, _In_ OrtShapeInferContext*);
4662
+
4663
+ // Get start range
4664
+ int(ORT_API_CALL* GetStartVersion)(_In_ const struct OrtCustomOp* op);
4665
+ int(ORT_API_CALL* GetEndVersion)(_In_ const struct OrtCustomOp* op);
4507
4666
  };
4508
4667
 
4509
4668
  /*
@@ -4544,6 +4703,14 @@ ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_MIGraphX, _In_ OrtSessio
4544
4703
  */
4545
4704
  ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Dnnl, _In_ OrtSessionOptions* options, int use_arena);
4546
4705
 
4706
+ /*
4707
+ * This is the old way to add the TensorRT provider to the session, please use SessionOptionsAppendExecutionProvider_TensorRT_V2 above to access the latest functionality
4708
+ * This function always exists, but will only succeed if Onnxruntime was built with TensorRT support and the TensorRT provider shared library exists
4709
+ *
4710
+ * \param device_id CUDA device id, starts from zero.
4711
+ */
4712
+ ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtSessionOptions* options, int device_id);
4713
+
4547
4714
  #ifdef __cplusplus
4548
4715
  }
4549
4716
  #endif
@@ -845,6 +845,7 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl<T> {
845
845
  SessionOptionsImpl& SetIntraOpNumThreads(int intra_op_num_threads); ///< Wraps OrtApi::SetIntraOpNumThreads
846
846
  SessionOptionsImpl& SetInterOpNumThreads(int inter_op_num_threads); ///< Wraps OrtApi::SetInterOpNumThreads
847
847
  SessionOptionsImpl& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); ///< Wraps OrtApi::SetSessionGraphOptimizationLevel
848
+ SessionOptionsImpl& SetDeterministicCompute(bool value); ///< Wraps OrtApi::SetDeterministicCompute
848
849
 
849
850
  SessionOptionsImpl& EnableCpuMemArena(); ///< Wraps OrtApi::EnableCpuMemArena
850
851
  SessionOptionsImpl& DisableCpuMemArena(); ///< Wraps OrtApi::DisableCpuMemArena
@@ -873,10 +874,12 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl<T> {
873
874
  SessionOptionsImpl& AddInitializer(const char* name, const OrtValue* ort_val); ///< Wraps OrtApi::AddInitializer
874
875
  SessionOptionsImpl& AddExternalInitializers(const std::vector<std::string>& names, const std::vector<Value>& ort_values); ///< Wraps OrtApi::AddExternalInitializers
875
876
 
876
- SessionOptionsImpl& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA
877
- SessionOptionsImpl& AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA_V2
878
- SessionOptionsImpl& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM
879
- SessionOptionsImpl& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO
877
+ SessionOptionsImpl& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA
878
+ SessionOptionsImpl& AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA_V2
879
+ SessionOptionsImpl& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM
880
+ SessionOptionsImpl& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO
881
+ ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO_V2
882
+ SessionOptionsImpl& AppendExecutionProvider_OpenVINO_V2(const std::unordered_map<std::string, std::string>& provider_options = {});
880
883
  SessionOptionsImpl& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT
881
884
  SessionOptionsImpl& AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT
882
885
  SessionOptionsImpl& AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX
@@ -2055,6 +2058,8 @@ struct KernelContext {
2055
2058
  void* GetGPUComputeStream() const;
2056
2059
  Logger GetLogger() const;
2057
2060
  OrtAllocator* GetAllocator(const OrtMemoryInfo& memory_info) const;
2061
+ OrtKernelContext* GetOrtKernelContext() const { return ctx_; }
2062
+ void ParallelFor(void (*fn)(void*, size_t), size_t total, size_t num_batch, void* usr_data) const;
2058
2063
 
2059
2064
  private:
2060
2065
  OrtKernelContext* ctx_;
@@ -2155,6 +2160,80 @@ struct Op : detail::Base<OrtOp> {
2155
2160
  size_t output_count);
2156
2161
  };
2157
2162
 
2163
+ /// <summary>
2164
+ /// Provide access to per-node attributes and input shapes, so one could compute and set output shapes.
2165
+ /// </summary>
2166
+ struct ShapeInferContext {
2167
+ struct SymbolicInteger {
2168
+ SymbolicInteger(int64_t i) : i_(i), is_int_(true){};
2169
+ SymbolicInteger(const char* s) : s_(s), is_int_(false){};
2170
+ SymbolicInteger(const SymbolicInteger&) = default;
2171
+ SymbolicInteger(SymbolicInteger&&) = default;
2172
+
2173
+ SymbolicInteger& operator=(const SymbolicInteger&) = default;
2174
+ SymbolicInteger& operator=(SymbolicInteger&&) = default;
2175
+
2176
+ bool operator==(const SymbolicInteger& dim) const {
2177
+ if (is_int_ == dim.is_int_) {
2178
+ if (is_int_) {
2179
+ return i_ == dim.i_;
2180
+ } else {
2181
+ return std::string{s_} == std::string{dim.s_};
2182
+ }
2183
+ }
2184
+ return false;
2185
+ }
2186
+
2187
+ bool IsInt() const { return is_int_; }
2188
+ int64_t AsInt() const { return i_; }
2189
+ const char* AsSym() const { return s_; }
2190
+
2191
+ static constexpr int INVALID_INT_DIM = -2;
2192
+
2193
+ private:
2194
+ union {
2195
+ int64_t i_;
2196
+ const char* s_;
2197
+ };
2198
+ bool is_int_;
2199
+ };
2200
+
2201
+ using Shape = std::vector<SymbolicInteger>;
2202
+
2203
+ ShapeInferContext(const OrtApi* ort_api, OrtShapeInferContext* ctx);
2204
+
2205
+ const Shape& GetInputShape(size_t indice) const { return input_shapes_.at(indice); }
2206
+
2207
+ size_t GetInputCount() const { return input_shapes_.size(); }
2208
+
2209
+ Status SetOutputShape(size_t indice, const Shape& shape);
2210
+
2211
+ int64_t GetAttrInt(const char* attr_name);
2212
+
2213
+ using Ints = std::vector<int64_t>;
2214
+ Ints GetAttrInts(const char* attr_name);
2215
+
2216
+ float GetAttrFloat(const char* attr_name);
2217
+
2218
+ using Floats = std::vector<float>;
2219
+ Floats GetAttrFloats(const char* attr_name);
2220
+
2221
+ std::string GetAttrString(const char* attr_name);
2222
+
2223
+ using Strings = std::vector<std::string>;
2224
+ Strings GetAttrStrings(const char* attr_name);
2225
+
2226
+ private:
2227
+ const OrtOpAttr* GetAttrHdl(const char* attr_name) const;
2228
+ const OrtApi* ort_api_;
2229
+ OrtShapeInferContext* ctx_;
2230
+ std::vector<Shape> input_shapes_;
2231
+ };
2232
+
2233
+ using ShapeInferFn = Ort::Status (*)(Ort::ShapeInferContext&);
2234
+
2235
+ #define MAX_CUSTOM_OP_END_VER (1UL << 31) - 1
2236
+
2158
2237
  template <typename TOp, typename TKernel, bool WithStatus = false>
2159
2238
  struct CustomOpBase : OrtCustomOp {
2160
2239
  CustomOpBase() {
@@ -2205,6 +2284,16 @@ struct CustomOpBase : OrtCustomOp {
2205
2284
  static_cast<TKernel*>(op_kernel)->Compute(context);
2206
2285
  };
2207
2286
  }
2287
+
2288
+ SetShapeInferFn<TOp>(0);
2289
+
2290
+ OrtCustomOp::GetStartVersion = [](const OrtCustomOp* this_) {
2291
+ return static_cast<const TOp*>(this_)->start_ver_;
2292
+ };
2293
+
2294
+ OrtCustomOp::GetEndVersion = [](const OrtCustomOp* this_) {
2295
+ return static_cast<const TOp*>(this_)->end_ver_;
2296
+ };
2208
2297
  }
2209
2298
 
2210
2299
  // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
@@ -2256,9 +2345,26 @@ struct CustomOpBase : OrtCustomOp {
2256
2345
  return std::vector<std::string>{};
2257
2346
  }
2258
2347
 
2348
+ template <typename C>
2349
+ decltype(&C::InferOutputShape) SetShapeInferFn(decltype(&C::InferOutputShape)) {
2350
+ OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp*, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
2351
+ ShapeInferContext ctx(&GetApi(), ort_ctx);
2352
+ return C::InferOutputShape(ctx);
2353
+ };
2354
+ return {};
2355
+ }
2356
+
2357
+ template <typename C>
2358
+ void SetShapeInferFn(...) {
2359
+ OrtCustomOp::InferOutputShapeFn = {};
2360
+ }
2361
+
2259
2362
  protected:
2260
2363
  // Helper function that returns a map of session config entries specified by CustomOpBase::GetSessionConfigKeys.
2261
2364
  void GetSessionConfigs(std::unordered_map<std::string, std::string>& out, ConstSessionOptions options) const;
2365
+
2366
+ int start_ver_ = 1;
2367
+ int end_ver_ = MAX_CUSTOM_OP_END_VER;
2262
2368
  };
2263
2369
 
2264
2370
  } // namespace Ort