com.github.asus4.onnxruntime 0.4.1 → 0.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. package/Plugins/Android/onnxruntime.aar +0 -0
  2. package/Plugins/Linux/x64/libonnxruntime.so +0 -0
  3. package/Plugins/Windows/arm64/onnxruntime.dll +0 -0
  4. package/Plugins/Windows/x64/onnxruntime.dll +0 -0
  5. package/Plugins/Windows/x86/onnxruntime.dll +0 -0
  6. package/Plugins/iOS~/onnxruntime.xcframework/Info.plist +11 -5
  7. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Headers/onnxruntime_c_api.h +1307 -64
  8. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Headers/onnxruntime_cxx_api.h +425 -24
  9. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Headers/onnxruntime_cxx_inline.h +614 -6
  10. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Headers/onnxruntime_run_options_config_keys.h +3 -0
  11. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Headers/onnxruntime_session_options_config_keys.h +29 -5
  12. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Info.plist +2 -2
  13. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/onnxruntime +0 -0
  14. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-maccatalyst/onnxruntime.framework/Versions/A/Headers/onnxruntime_c_api.h +1307 -64
  15. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-maccatalyst/onnxruntime.framework/Versions/A/Headers/onnxruntime_cxx_api.h +425 -24
  16. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-maccatalyst/onnxruntime.framework/Versions/A/Headers/onnxruntime_cxx_inline.h +614 -6
  17. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-maccatalyst/onnxruntime.framework/Versions/A/Headers/onnxruntime_run_options_config_keys.h +3 -0
  18. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-maccatalyst/onnxruntime.framework/Versions/A/Headers/onnxruntime_session_options_config_keys.h +29 -5
  19. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-maccatalyst/onnxruntime.framework/Versions/A/Resources/Info.plist +2 -2
  20. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-maccatalyst/onnxruntime.framework/Versions/A/onnxruntime +0 -0
  21. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Headers/onnxruntime_c_api.h +1307 -64
  22. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Headers/onnxruntime_cxx_api.h +425 -24
  23. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Headers/onnxruntime_cxx_inline.h +614 -6
  24. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Headers/onnxruntime_run_options_config_keys.h +3 -0
  25. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Headers/onnxruntime_session_options_config_keys.h +29 -5
  26. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Info.plist +2 -2
  27. package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/onnxruntime +0 -0
  28. package/Plugins/macOS/arm64/libonnxruntime.dylib +0 -0
  29. package/Plugins/macOS/x64/libonnxruntime.dylib +0 -0
  30. package/README.md +6 -5
  31. package/Runtime/ManagedProjections.shared.cs +1 -2
  32. package/Runtime/NativeMethods.shared.cs +20 -2
  33. package/Runtime/OrtValue.shared.cs +1 -1
  34. package/Runtime/SessionOptions.shared.cs +10 -0
  35. package/Runtime/Training/NativeTrainingMethods.shared.cs +1 -1
  36. package/package.json +1 -1
@@ -10,7 +10,9 @@
10
10
  #include <algorithm>
11
11
  #include <functional>
12
12
  #include <iterator>
13
+ #include <string>
13
14
  #include <type_traits>
15
+ #include <vector>
14
16
 
15
17
  // Convert OrtStatus to Ort::Status and return
16
18
  // instead of throwing
@@ -477,6 +479,125 @@ inline ThreadingOptions& ThreadingOptions::SetGlobalCustomJoinThreadFn(OrtCustom
477
479
  return *this;
478
480
  }
479
481
 
482
+ namespace detail {
483
+ template <typename T>
484
+ inline const char* KeyValuePairsImpl<T>::GetValue(const char* key) const {
485
+ return GetApi().GetKeyValue(this->p_, key);
486
+ }
487
+
488
+ template <typename T>
489
+ inline std::unordered_map<std::string, std::string> KeyValuePairsImpl<T>::GetKeyValuePairs() const {
490
+ std::unordered_map<std::string, std::string> out;
491
+
492
+ size_t num_pairs = 0;
493
+ const char* const* keys = nullptr;
494
+ const char* const* values = nullptr;
495
+ GetApi().GetKeyValuePairs(this->p_, &keys, &values, &num_pairs);
496
+ if (num_pairs > 0) {
497
+ out.reserve(num_pairs);
498
+ for (size_t i = 0; i < num_pairs; ++i) {
499
+ out.emplace(keys[i], values[i]);
500
+ }
501
+ }
502
+
503
+ return out;
504
+ }
505
+
506
+ template <typename T>
507
+ inline void KeyValuePairsImpl<T>::GetKeyValuePairs(std::vector<const char*>& keys,
508
+ std::vector<const char*>& values) const {
509
+ keys.clear();
510
+ values.clear();
511
+
512
+ size_t num_pairs = 0;
513
+ const char* const* keys_ptr = nullptr;
514
+ const char* const* values_ptr = nullptr;
515
+ GetApi().GetKeyValuePairs(this->p_, &keys_ptr, &values_ptr, &num_pairs);
516
+ if (num_pairs > 0) {
517
+ keys.resize(num_pairs);
518
+ values.resize(num_pairs);
519
+ std::copy(keys_ptr, keys_ptr + num_pairs, keys.begin());
520
+ std::copy(values_ptr, values_ptr + num_pairs, values.begin());
521
+ }
522
+ }
523
+ } // namespace detail
524
+
525
+ inline KeyValuePairs::KeyValuePairs() {
526
+ GetApi().CreateKeyValuePairs(&p_);
527
+ }
528
+
529
+ inline KeyValuePairs::KeyValuePairs(const std::unordered_map<std::string, std::string>& kv_pairs) {
530
+ GetApi().CreateKeyValuePairs(&p_);
531
+ for (const auto& kv : kv_pairs) {
532
+ GetApi().AddKeyValuePair(this->p_, kv.first.c_str(), kv.second.c_str());
533
+ }
534
+ }
535
+
536
+ inline void KeyValuePairs::Add(const char* key, const char* value) {
537
+ GetApi().AddKeyValuePair(this->p_, key, value);
538
+ }
539
+
540
+ inline void KeyValuePairs::Remove(const char* key) {
541
+ GetApi().RemoveKeyValuePair(this->p_, key);
542
+ }
543
+
544
+ namespace detail {
545
+ template <typename T>
546
+ inline OrtHardwareDeviceType HardwareDeviceImpl<T>::Type() const {
547
+ return GetApi().HardwareDevice_Type(this->p_);
548
+ }
549
+
550
+ template <typename T>
551
+ inline uint32_t HardwareDeviceImpl<T>::VendorId() const {
552
+ return GetApi().HardwareDevice_VendorId(this->p_);
553
+ }
554
+
555
+ template <typename T>
556
+ inline uint32_t HardwareDeviceImpl<T>::DeviceId() const {
557
+ return GetApi().HardwareDevice_DeviceId(this->p_);
558
+ }
559
+
560
+ template <typename T>
561
+ inline const char* HardwareDeviceImpl<T>::Vendor() const {
562
+ return GetApi().HardwareDevice_Vendor(this->p_);
563
+ }
564
+
565
+ template <typename T>
566
+ inline ConstKeyValuePairs HardwareDeviceImpl<T>::Metadata() const {
567
+ return ConstKeyValuePairs{GetApi().HardwareDevice_Metadata(this->p_)};
568
+ }
569
+
570
+ template <typename T>
571
+ inline const char* EpDeviceImpl<T>::EpName() const {
572
+ return GetApi().EpDevice_EpName(this->p_);
573
+ }
574
+
575
+ template <typename T>
576
+ inline const char* EpDeviceImpl<T>::EpVendor() const {
577
+ return GetApi().EpDevice_EpVendor(this->p_);
578
+ }
579
+
580
+ template <typename T>
581
+ inline ConstKeyValuePairs EpDeviceImpl<T>::EpMetadata() const {
582
+ return ConstKeyValuePairs(GetApi().EpDevice_EpMetadata(this->p_));
583
+ }
584
+
585
+ template <typename T>
586
+ inline ConstKeyValuePairs EpDeviceImpl<T>::EpOptions() const {
587
+ return ConstKeyValuePairs(GetApi().EpDevice_EpOptions(this->p_));
588
+ }
589
+
590
+ template <typename T>
591
+ inline ConstHardwareDevice EpDeviceImpl<T>::Device() const {
592
+ return ConstHardwareDevice(GetApi().EpDevice_Device(this->p_));
593
+ }
594
+ } // namespace detail
595
+
596
+ inline EpDevice::EpDevice(OrtEpFactory& ep_factory, ConstHardwareDevice& hardware_device,
597
+ ConstKeyValuePairs ep_metadata, ConstKeyValuePairs ep_options) {
598
+ ThrowOnError(GetEpApi().CreateEpDevice(&ep_factory, hardware_device, ep_metadata, ep_options, &p_));
599
+ }
600
+
480
601
  inline Env::Env(OrtLoggingLevel logging_level, _In_ const char* logid) {
481
602
  ThrowOnError(GetApi().CreateEnv(logging_level, logid, &p_));
482
603
  if (strcmp(logid, "onnxruntime-node") == 0) {
@@ -549,6 +670,33 @@ inline Env& Env::CreateAndRegisterAllocatorV2(const std::string& provider_type,
549
670
  return *this;
550
671
  }
551
672
 
673
+ inline Env& Env::RegisterExecutionProviderLibrary(const char* registration_name,
674
+ const std::basic_string<ORTCHAR_T>& path) {
675
+ ThrowOnError(GetApi().RegisterExecutionProviderLibrary(p_, registration_name, path.c_str()));
676
+ return *this;
677
+ }
678
+
679
+ inline Env& Env::UnregisterExecutionProviderLibrary(const char* registration_name) {
680
+ ThrowOnError(GetApi().UnregisterExecutionProviderLibrary(p_, registration_name));
681
+ return *this;
682
+ }
683
+
684
+ inline std::vector<ConstEpDevice> Env::GetEpDevices() const {
685
+ size_t num_devices = 0;
686
+ const OrtEpDevice* const* device_ptrs = nullptr;
687
+ ThrowOnError(GetApi().GetEpDevices(p_, &device_ptrs, &num_devices));
688
+
689
+ std::vector<ConstEpDevice> devices;
690
+ if (num_devices > 0) {
691
+ devices.reserve(num_devices);
692
+ for (size_t i = 0; i < num_devices; ++i) {
693
+ devices.emplace_back(device_ptrs[i]);
694
+ }
695
+ }
696
+
697
+ return devices;
698
+ }
699
+
552
700
  inline CustomOpDomain::CustomOpDomain(const char* domain) {
553
701
  ThrowOnError(GetApi().CreateCustomOpDomain(domain, &p_));
554
702
  }
@@ -628,6 +776,62 @@ inline RunOptions& RunOptions::AddActiveLoraAdapter(const LoraAdapter& adapter)
628
776
  return *this;
629
777
  }
630
778
 
779
+ inline ModelCompilationOptions::ModelCompilationOptions(const Env& env, const SessionOptions& session_options) {
780
+ ThrowOnError(GetCompileApi().CreateModelCompilationOptionsFromSessionOptions(env, session_options, &this->p_));
781
+ }
782
+
783
+ inline ModelCompilationOptions::ModelCompilationOptions(const Env& env, ConstSessionOptions session_options) {
784
+ ThrowOnError(GetCompileApi().CreateModelCompilationOptionsFromSessionOptions(env, session_options, &this->p_));
785
+ }
786
+
787
+ inline Status CompileModel(const Env& env, const ModelCompilationOptions& model_compilation_options) {
788
+ return Ort::Status(GetCompileApi().CompileModel(env, model_compilation_options));
789
+ }
790
+
791
+ inline ModelCompilationOptions& ModelCompilationOptions::SetInputModelPath(
792
+ const ORTCHAR_T* input_model_path) {
793
+ Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetInputModelPath(this->p_, input_model_path));
794
+ return *this;
795
+ }
796
+
797
+ inline ModelCompilationOptions& ModelCompilationOptions::SetInputModelFromBuffer(
798
+ const void* input_model_data, size_t input_model_data_size) {
799
+ Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetInputModelFromBuffer(this->p_, input_model_data,
800
+ input_model_data_size));
801
+ return *this;
802
+ }
803
+
804
+ inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelPath(
805
+ const ORTCHAR_T* output_model_path) {
806
+ Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelPath(this->p_, output_model_path));
807
+ return *this;
808
+ }
809
+
810
+ inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelExternalInitializersFile(
811
+ const ORTCHAR_T* file_path, size_t initializer_size_threshold) {
812
+ Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelExternalInitializersFile(
813
+ this->p_,
814
+ file_path,
815
+ initializer_size_threshold));
816
+ return *this;
817
+ }
818
+
819
+ inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelBuffer(
820
+ OrtAllocator* allocator, void** output_model_buffer_ptr, size_t* output_model_buffer_size_ptr) {
821
+ Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelBuffer(this->p_, allocator,
822
+ output_model_buffer_ptr,
823
+ output_model_buffer_size_ptr));
824
+ return *this;
825
+ }
826
+
827
+ inline ModelCompilationOptions& ModelCompilationOptions::SetEpContextEmbedMode(
828
+ bool embed_ep_context_in_model) {
829
+ Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetEpContextEmbedMode(
830
+ this->p_,
831
+ embed_ep_context_in_model));
832
+ return *this;
833
+ }
834
+
631
835
  namespace detail {
632
836
 
633
837
  template <typename T>
@@ -659,7 +863,8 @@ inline bool ConstSessionOptionsImpl<T>::HasConfigEntry(const char* config_key) c
659
863
  }
660
864
 
661
865
  template <typename T>
662
- inline std::string ConstSessionOptionsImpl<T>::GetConfigEntryOrDefault(const char* config_key, const std::string& def) {
866
+ inline std::string ConstSessionOptionsImpl<T>::GetConfigEntryOrDefault(const char* config_key,
867
+ const std::string& def) const {
663
868
  if (!this->HasConfigEntry(config_key)) {
664
869
  return def;
665
870
  }
@@ -745,6 +950,12 @@ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetExecutionMode(ExecutionM
745
950
  return *this;
746
951
  }
747
952
 
953
+ template <typename T>
954
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetLoadCancellationFlag(bool value) {
955
+ ThrowOnError(GetApi().SessionOptionsSetLoadCancellationFlag(this->p_, value));
956
+ return *this;
957
+ }
958
+
748
959
  template <typename T>
749
960
  inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetLogId(const char* logid) {
750
961
  ThrowOnError(GetApi().SetSessionLogId(this->p_, logid));
@@ -891,6 +1102,65 @@ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider(
891
1102
  return *this;
892
1103
  }
893
1104
 
1105
+ namespace {
1106
+ template <typename T>
1107
+ void SessionOptionsAppendEP(detail::SessionOptionsImpl<T>& session_options,
1108
+ Env& env, const std::vector<ConstEpDevice>& ep_devices,
1109
+ const std::vector<const char*>& ep_options_keys,
1110
+ const std::vector<const char*>& ep_options_values) {
1111
+ std::vector<const OrtEpDevice*> ep_devices_ptrs;
1112
+ ep_devices_ptrs.reserve(ep_devices.size());
1113
+ for (const auto& ep_device : ep_devices) {
1114
+ ep_devices_ptrs.push_back(ep_device);
1115
+ }
1116
+
1117
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_V2(
1118
+ session_options, env, ep_devices_ptrs.data(), ep_devices_ptrs.size(),
1119
+ ep_options_keys.data(), ep_options_values.data(), ep_options_keys.size()));
1120
+ }
1121
+ } // namespace
1122
+
1123
+ template <typename T>
1124
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_V2(
1125
+ Env& env, const std::vector<ConstEpDevice>& ep_devices, const KeyValuePairs& ep_options) {
1126
+ std::vector<const char*> ep_options_keys, ep_options_values;
1127
+ ep_options.GetKeyValuePairs(ep_options_keys, ep_options_values);
1128
+
1129
+ SessionOptionsAppendEP(*this, env, ep_devices, ep_options_keys, ep_options_values);
1130
+
1131
+ return *this;
1132
+ }
1133
+
1134
+ template <typename T>
1135
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_V2(
1136
+ Env& env, const std::vector<ConstEpDevice>& ep_devices,
1137
+ const std::unordered_map<std::string, std::string>& ep_options) {
1138
+ std::vector<const char*> ep_options_keys, ep_options_values;
1139
+ ep_options_keys.reserve(ep_options.size());
1140
+ ep_options_values.reserve(ep_options.size());
1141
+
1142
+ for (const auto& [key, value] : ep_options) {
1143
+ ep_options_keys.push_back(key.c_str());
1144
+ ep_options_values.push_back(value.c_str());
1145
+ }
1146
+
1147
+ SessionOptionsAppendEP(*this, env, ep_devices, ep_options_keys, ep_options_values);
1148
+
1149
+ return *this;
1150
+ }
1151
+
1152
+ template <typename T>
1153
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy policy) {
1154
+ ThrowOnError(GetApi().SessionOptionsSetEpSelectionPolicy(this->p_, policy));
1155
+ return *this;
1156
+ }
1157
+
1158
+ template <typename T>
1159
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetEpSelectionPolicy(EpSelectionDelegate delegate, void* state) {
1160
+ ThrowOnError(GetApi().SessionOptionsSetEpSelectionPolicyDelegate(this->p_, delegate, state));
1161
+ return *this;
1162
+ }
1163
+
894
1164
  template <typename T>
895
1165
  inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
896
1166
  ThrowOnError(GetApi().SessionOptionsSetCustomCreateThreadFn(this->p_, ort_custom_create_thread_fn));
@@ -995,6 +1265,59 @@ inline size_t ConstSessionImpl<T>::GetOverridableInitializerCount() const {
995
1265
  return out;
996
1266
  }
997
1267
 
1268
+ template <typename T>
1269
+ inline std::vector<std::string> ConstSessionImpl<T>::GetInputNames() const {
1270
+ AllocatorWithDefaultOptions allocator;
1271
+
1272
+ auto num_inputs = GetInputCount();
1273
+ std::vector<std::string> input_names;
1274
+ input_names.reserve(num_inputs);
1275
+
1276
+ for (size_t i = 0; i < num_inputs; ++i) {
1277
+ char* name = nullptr;
1278
+ ThrowOnError(GetApi().SessionGetInputName(this->p_, i, allocator, &name));
1279
+ input_names.push_back(name);
1280
+ allocator.Free(name);
1281
+ }
1282
+
1283
+ return input_names;
1284
+ }
1285
+
1286
+ template <typename T>
1287
+ inline std::vector<std::string> ConstSessionImpl<T>::GetOutputNames() const {
1288
+ AllocatorWithDefaultOptions allocator;
1289
+
1290
+ auto num_inputs = GetOutputCount();
1291
+ std::vector<std::string> output_names;
1292
+ output_names.reserve(num_inputs);
1293
+
1294
+ for (size_t i = 0; i < num_inputs; ++i) {
1295
+ char* name = nullptr;
1296
+ ThrowOnError(GetApi().SessionGetOutputName(this->p_, i, allocator, &name));
1297
+ output_names.push_back(name);
1298
+ allocator.Free(name);
1299
+ }
1300
+
1301
+ return output_names;
1302
+ }
1303
+
1304
+ template <typename T>
1305
+ inline std::vector<std::string> ConstSessionImpl<T>::GetOverridableInitializerNames() const {
1306
+ AllocatorWithDefaultOptions allocator;
1307
+
1308
+ auto num_initializers = GetOverridableInitializerCount();
1309
+ std::vector<std::string> initializer_names;
1310
+ initializer_names.reserve(num_initializers);
1311
+
1312
+ for (size_t i = 0; i < num_initializers; ++i) {
1313
+ char* name = nullptr;
1314
+ ThrowOnError(GetApi().SessionGetOverridableInitializerName(this->p_, i, allocator, &name));
1315
+ initializer_names.push_back(name);
1316
+ }
1317
+
1318
+ return initializer_names;
1319
+ }
1320
+
998
1321
  template <typename T>
999
1322
  inline AllocatedStringPtr ConstSessionImpl<T>::GetInputNameAllocated(size_t index, OrtAllocator* allocator) const {
1000
1323
  char* out;
@@ -1051,6 +1374,45 @@ inline TypeInfo ConstSessionImpl<T>::GetOverridableInitializerTypeInfo(size_t in
1051
1374
  return TypeInfo{out};
1052
1375
  }
1053
1376
 
1377
+ #if !defined(ORT_MINIMAL_BUILD)
1378
+ template <typename T>
1379
+ inline int ConstSessionImpl<T>::GetOpset(const std::string& domain) const {
1380
+ int opset;
1381
+ ThrowOnError(GetModelEditorApi().SessionGetOpsetForDomain(this->p_, domain.c_str(), &opset));
1382
+ return opset;
1383
+ }
1384
+ #endif // !defined(ORT_MINIMAL_BUILD)
1385
+
1386
+ template <typename T>
1387
+ std::vector<ValueInfo> ConstSessionImpl<T>::GetInputs() const {
1388
+ const std::vector<std::string> input_names = GetInputNames();
1389
+
1390
+ std::vector<ValueInfo> inputs;
1391
+ inputs.reserve(input_names.size());
1392
+
1393
+ for (size_t i = 0; i < input_names.size(); ++i) {
1394
+ auto type_info = GetInputTypeInfo(i);
1395
+ inputs.emplace_back(ValueInfo{input_names[i], type_info.GetConst()});
1396
+ }
1397
+
1398
+ return inputs;
1399
+ }
1400
+
1401
+ template <typename T>
1402
+ std::vector<ValueInfo> ConstSessionImpl<T>::GetOutputs() const {
1403
+ const std::vector<std::string> output_names = GetOutputNames();
1404
+
1405
+ std::vector<ValueInfo> outputs;
1406
+ outputs.reserve(output_names.size());
1407
+
1408
+ for (size_t i = 0; i < output_names.size(); ++i) {
1409
+ auto type_info = GetOutputTypeInfo(i);
1410
+ outputs.emplace_back(ValueInfo{output_names[i], type_info.GetConst()});
1411
+ }
1412
+
1413
+ return outputs;
1414
+ }
1415
+
1054
1416
  template <typename T>
1055
1417
  inline std::vector<Value> SessionImpl<T>::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
1056
1418
  const char* const* output_names, size_t output_count) {
@@ -1098,6 +1460,15 @@ inline void SessionImpl<T>::SetEpDynamicOptions(const char* const* keys, const c
1098
1460
  ThrowOnError(GetApi().SetEpDynamicOptions(this->p_, keys, values, kv_len));
1099
1461
  }
1100
1462
 
1463
+ #if !defined(ORT_MINIMAL_BUILD)
1464
+ template <typename T>
1465
+ inline void SessionImpl<T>::FinalizeModelEditorSession(const Model& model, const SessionOptions& options,
1466
+ OrtPrepackedWeightsContainer* prepacked_weights_container) {
1467
+ ThrowOnError(GetModelEditorApi().ApplyModelToModelEditorSession(this->p_, model));
1468
+ ThrowOnError(GetModelEditorApi().FinalizeModelEditorSession(this->p_, options, prepacked_weights_container));
1469
+ }
1470
+ #endif // #if !defined(ORT_MINIMAL_BUILD)
1471
+
1101
1472
  } // namespace detail
1102
1473
 
1103
1474
  inline SessionOptions::SessionOptions() {
@@ -1144,6 +1515,32 @@ inline Session::Session(const Env& env, const void* model_data, size_t model_dat
1144
1515
  prepacked_weights_container, &this->p_));
1145
1516
  }
1146
1517
 
1518
+ #if !defined(ORT_MINIMAL_BUILD)
1519
+ inline Session::Session(const Env& env, const Model& model, const SessionOptions& options) {
1520
+ ThrowOnError(GetModelEditorApi().CreateSessionFromModel(env, model.GetConst(), options, &this->p_));
1521
+ }
1522
+
1523
+ // static
1524
+ inline Session Session::CreateModelEditorSession(const Env& env, const ORTCHAR_T* model_path,
1525
+ const SessionOptions& options) {
1526
+ OrtSession* session = nullptr;
1527
+ ThrowOnError(GetModelEditorApi().CreateModelEditorSession(env, model_path, options, &session));
1528
+ return Session(session);
1529
+ }
1530
+
1531
+ // static
1532
+ inline Session Session::CreateModelEditorSession(const Env& env, const void* model_data, size_t model_data_length,
1533
+ const SessionOptions& options) {
1534
+ OrtSession* session = nullptr;
1535
+ ThrowOnError(GetModelEditorApi().CreateModelEditorSessionFromArray(env, model_data, model_data_length, options,
1536
+ &session));
1537
+ return Session(session);
1538
+ }
1539
+
1540
// Intentionally no free-function FinalizeModelEditorSession declaration here: the operation is provided by
// the member detail::SessionImpl<T>::FinalizeModelEditorSession, and a namespace-scope declaration with no
// definition would only produce link errors for anyone who called it.
1542
+ #endif // #if !defined(ORT_MINIMAL_BUILD)
1543
+
1147
1544
  inline AllocatedStringPtr ModelMetadata::GetProducerNameAllocated(OrtAllocator* allocator) const {
1148
1545
  char* out;
1149
1546
  ThrowOnError(GetApi().ModelMetadataGetProducerName(p_, allocator, &out));
@@ -1211,6 +1608,59 @@ inline int64_t ModelMetadata::GetVersion() const {
1211
1608
  return out;
1212
1609
  }
1213
1610
 
1611
+ inline TensorTypeAndShapeInfo::TensorTypeAndShapeInfo(ONNXTensorElementDataType element_type,
1612
+ const std::vector<int64_t>& dims,
1613
+ const std::vector<std::string>* symbolic_dims) {
1614
+ ThrowOnError(GetApi().CreateTensorTypeAndShapeInfo(&p_));
1615
+ ThrowOnError(GetApi().SetTensorElementType(p_, element_type));
1616
+ ThrowOnError(GetApi().SetDimensions(p_, dims.data(), dims.size()));
1617
+
1618
+ if (symbolic_dims) {
1619
+ std::vector<const char*> symbolic_dims_cstr;
1620
+ symbolic_dims_cstr.reserve(symbolic_dims->size());
1621
+ std::transform(symbolic_dims->begin(), symbolic_dims->end(), std::back_inserter(symbolic_dims_cstr),
1622
+ [](const std::string& s) { return s.c_str(); });
1623
+ ThrowOnError(GetApi().SetSymbolicDimensions(p_, symbolic_dims_cstr.data(), symbolic_dims_cstr.size()));
1624
+ }
1625
+ }
1626
+
1627
+ #if !defined(ORT_MINIMAL_BUILD)
1628
+ // static
1629
+ inline TypeInfo TypeInfo::CreateTensorInfo(ConstTensorTypeAndShapeInfo tensor_type_and_shape_info) {
1630
+ OrtTypeInfo* output = nullptr;
1631
+ ThrowOnError(GetModelEditorApi().CreateTensorTypeInfo(tensor_type_and_shape_info, &output));
1632
+ return TypeInfo{output};
1633
+ }
1634
+
1635
+ // static
1636
+ inline TypeInfo TypeInfo::CreateSparseTensorInfo(ConstTensorTypeAndShapeInfo sparse_tensor_type_and_shape_info) {
1637
+ OrtTypeInfo* output = nullptr;
1638
+ ThrowOnError(GetModelEditorApi().CreateSparseTensorTypeInfo(sparse_tensor_type_and_shape_info, &output));
1639
+ return TypeInfo{output};
1640
+ }
1641
+
1642
+ // static
1643
+ inline TypeInfo TypeInfo::CreateSequenceTypeInfo(ConstTypeInfo sequence_type) {
1644
+ OrtTypeInfo* output;
1645
+ ThrowOnError(GetModelEditorApi().CreateSequenceTypeInfo(sequence_type, &output));
1646
+ return TypeInfo{output};
1647
+ }
1648
+
1649
+ // static
1650
+ inline TypeInfo TypeInfo::CreateMapTypeInfo(ONNXTensorElementDataType key_type, ConstTypeInfo value_type) {
1651
+ OrtTypeInfo* output;
1652
+ ThrowOnError(GetModelEditorApi().CreateMapTypeInfo(key_type, value_type, &output));
1653
+ return TypeInfo{output};
1654
+ }
1655
+
1656
+ // static
1657
+ inline TypeInfo TypeInfo::CreateOptionalTypeInfo(ConstTypeInfo contained_type) {
1658
+ OrtTypeInfo* output;
1659
+ ThrowOnError(GetModelEditorApi().CreateOptionalTypeInfo(contained_type, &output));
1660
+ return TypeInfo{output};
1661
+ }
1662
+ #endif // #if !defined(ORT_MINIMAL_BUILD)
1663
+
1214
1664
  namespace detail {
1215
1665
 
1216
1666
  template <typename T>
@@ -1244,9 +1694,16 @@ inline void TensorTypeAndShapeInfoImpl<T>::GetSymbolicDimensions(const char** va
1244
1694
  ThrowOnError(GetApi().GetSymbolicDimensions(this->p_, values, values_count));
1245
1695
  }
1246
1696
 
1697
+ template <typename T>
1698
+ inline std::vector<const char*> TensorTypeAndShapeInfoImpl<T>::GetSymbolicDimensions() const {
1699
+ std::vector<const char*> out(GetDimensionsCount(), nullptr);
1700
+ ThrowOnError(GetApi().GetSymbolicDimensions(this->p_, out.data(), out.size()));
1701
+ return out;
1702
+ }
1703
+
1247
1704
  template <typename T>
1248
1705
  inline std::vector<int64_t> TensorTypeAndShapeInfoImpl<T>::GetShape() const {
1249
- std::vector<int64_t> out(GetDimensionsCount(), 0);
1706
+ std::vector<int64_t> out(GetDimensionsCount(), -1);
1250
1707
  ThrowOnError(GetApi().GetDimensions(this->p_, out.data(), out.size()));
1251
1708
  return out;
1252
1709
  }
@@ -1560,23 +2017,35 @@ void ValueImpl<T>::FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_inf
1560
2017
  } // namespace detail
1561
2018
 
1562
2019
  template <typename T>
1563
- inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len) {
2020
+ inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count,
2021
+ const int64_t* shape, size_t shape_len) {
1564
2022
  return CreateTensor(info, p_data, p_data_element_count * sizeof(T), shape, shape_len, TypeToTensorType<T>::type);
1565
2023
  }
1566
2024
 
1567
- inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
2025
+ inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count,
2026
+ const int64_t* shape, size_t shape_len,
1568
2027
  ONNXTensorElementDataType type) {
1569
2028
  OrtValue* out;
1570
2029
  ThrowOnError(GetApi().CreateTensorWithDataAsOrtValue(info, p_data, p_data_byte_count, shape, shape_len, type, &out));
1571
2030
  return Value{out};
1572
2031
  }
1573
2032
 
2033
+ inline Value Value::CreateTensor(OrtAllocator* deleter, void* p_data, size_t p_data_byte_count,
2034
+ const int64_t* shape, size_t shape_len,
2035
+ ONNXTensorElementDataType type) {
2036
+ OrtValue* out;
2037
+ ThrowOnError(GetApi().CreateTensorWithDataAndDeleterAsOrtValue(deleter, p_data, p_data_byte_count,
2038
+ shape, shape_len, type, &out));
2039
+ return Value{out};
2040
+ }
2041
+
1574
2042
  template <typename T>
1575
2043
  inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len) {
1576
2044
  return CreateTensor(allocator, shape, shape_len, TypeToTensorType<T>::type);
1577
2045
  }
1578
2046
 
1579
- inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type) {
2047
+ inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len,
2048
+ ONNXTensorElementDataType type) {
1580
2049
  OrtValue* out;
1581
2050
  ThrowOnError(GetApi().CreateTensorAsOrtValue(allocator, shape, shape_len, type, &out));
1582
2051
  return Value{out};
@@ -1594,7 +2063,8 @@ inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, void* p_data,
1594
2063
  const Shape& values_shape, ONNXTensorElementDataType type) {
1595
2064
  OrtValue* out;
1596
2065
  ThrowOnError(GetApi().CreateSparseTensorWithValuesAsOrtValue(info, p_data, dense_shape.shape, dense_shape.shape_len,
1597
- values_shape.shape, values_shape.shape_len, type, &out));
2066
+ values_shape.shape, values_shape.shape_len, type,
2067
+ &out));
1598
2068
  return Value{out};
1599
2069
  }
1600
2070
 
@@ -2167,4 +2637,142 @@ inline const OrtOpAttr* ShapeInferContext::GetAttrHdl(const char* attr_name) con
2167
2637
  return attr_hdl;
2168
2638
  }
2169
2639
 
2640
namespace detail {
// Returns a vector with one c_str() pointer per element of `strings`, in the same order.
// The pointers refer to the storage of `strings`, so they are only valid while those strings are alive
// and unmodified.
inline std::vector<const char*> StringsToCharPtrs(const std::vector<std::string>& strings) {
  std::vector<const char*> result;
  result.reserve(strings.size());
  for (const std::string& text : strings) {
    result.push_back(text.c_str());
  }

  return result;
}
}  // namespace detail
2650
+
2651
+ #if !defined(ORT_MINIMAL_BUILD)
2652
+ // static
2653
+ inline void Node::Init(const std::string& operator_name, const std::string& operator_domain,
2654
+ const std::string& node_name,
2655
+ const std::vector<std::string>& input_names,
2656
+ const std::vector<std::string>& output_names,
2657
+ std::vector<OpAttr>& attributes,
2658
+ OrtNode*& node) {
2659
+ auto inputs = detail::StringsToCharPtrs(input_names);
2660
+ auto outputs = detail::StringsToCharPtrs(output_names);
2661
+
2662
+ std::vector<OrtOpAttr*> attributes_ptrs;
2663
+ attributes_ptrs.reserve(attributes.size());
2664
+ std::transform(attributes.begin(), attributes.end(), std::back_inserter(attributes_ptrs),
2665
+ [](OpAttr& attr) -> OrtOpAttr* { return attr; });
2666
+
2667
+ ThrowOnError(GetModelEditorApi().CreateNode(operator_name.c_str(), operator_domain.c_str(), node_name.c_str(),
2668
+ inputs.data(), inputs.size(),
2669
+ outputs.data(), outputs.size(),
2670
+ attributes_ptrs.data(), attributes_ptrs.size(),
2671
+ &node));
2672
+
2673
+ // Node now owns the attributes
2674
+ std::for_each(attributes.begin(), attributes.end(), [](OpAttr& attr) { attr.release(); });
2675
+ }
2676
+
2677
+ inline Node::Node(const std::string& operator_name, const std::string& operator_domain,
2678
+ const std::string& node_name,
2679
+ const std::vector<std::string>& input_names,
2680
+ const std::vector<std::string>& output_names,
2681
+ std::vector<OpAttr>& attributes) {
2682
+ Init(operator_name, operator_domain, node_name, input_names, output_names, attributes, p_);
2683
+ }
2684
+
2685
+ inline Node::Node(const std::string& operator_name, const std::string& operator_domain,
2686
+ const std::string& node_name,
2687
+ const std::vector<std::string>& input_names,
2688
+ const std::vector<std::string>& output_names) {
2689
+ std::vector<OpAttr> empty_attributes;
2690
+ Init(operator_name, operator_domain, node_name, input_names, output_names, empty_attributes, p_);
2691
+ }
2692
+
2693
+ inline Graph::Graph() {
2694
+ ThrowOnError(GetModelEditorApi().CreateGraph(&p_));
2695
+ }
2696
+
2697
+ inline Model::Model(const std::vector<DomainOpsetPair>& opsets) {
2698
+ std::vector<const char*> domains;
2699
+ std::vector<int> versions;
2700
+ domains.reserve(opsets.size());
2701
+ versions.reserve(opsets.size());
2702
+
2703
+ for (const auto& pair : opsets) {
2704
+ domains.push_back(pair.first.c_str());
2705
+ versions.push_back(pair.second);
2706
+ }
2707
+
2708
+ ThrowOnError(GetModelEditorApi().CreateModel(domains.data(), versions.data(), opsets.size(), &p_));
2709
+ }
2710
+
2711
+ inline ValueInfo::ValueInfo(const std::string& name, const ConstTypeInfo& type_info) {
2712
+ ThrowOnError(GetModelEditorApi().CreateValueInfo(name.c_str(), type_info, &p_));
2713
+ }
2714
+ #endif // !defined(ORT_MINIMAL_BUILD)
2715
+
2716
+ namespace detail {
2717
+ template <>
2718
+ inline std::string ValueInfoImpl<OrtValueInfo>::Name() const {
2719
+ const char* name = nullptr;
2720
+ ThrowOnError(GetApi().GetValueInfoName(this->p_, &name));
2721
+ return name;
2722
+ }
2723
+
2724
+ template <>
2725
+ inline ConstTypeInfo ValueInfoImpl<OrtValueInfo>::TypeInfo() const {
2726
+ const OrtTypeInfo* type_info = nullptr;
2727
+ ThrowOnError(GetApi().GetValueInfoTypeInfo(this->p_, &type_info));
2728
+ return ConstTypeInfo{type_info};
2729
+ }
2730
+
2731
+ #if !defined(ORT_MINIMAL_BUILD)
2732
+ template <>
2733
+ inline void GraphImpl<OrtGraph>::SetInputs(std::vector<ValueInfo>& inputs) {
2734
+ std::vector<OrtValueInfo*> inputs_ptrs;
2735
+ inputs_ptrs.reserve(inputs.size());
2736
+ std::transform(inputs.begin(), inputs.end(), std::back_inserter(inputs_ptrs),
2737
+ [](ValueInfo& vi) -> OrtValueInfo* { return vi; });
2738
+
2739
+ ThrowOnError(GetModelEditorApi().SetGraphInputs(p_, inputs_ptrs.data(), inputs_ptrs.size()));
2740
+
2741
+ // Graph now owns the inputs
2742
+ std::for_each(inputs.begin(), inputs.end(), [](ValueInfo& vi) { vi.release(); });
2743
+ }
2744
+
2745
+ template <>
2746
+ inline void GraphImpl<OrtGraph>::SetOutputs(std::vector<ValueInfo>& outputs) {
2747
+ std::vector<OrtValueInfo*> outputs_ptrs;
2748
+ outputs_ptrs.reserve(outputs.size());
2749
+ std::transform(outputs.begin(), outputs.end(), std::back_inserter(outputs_ptrs),
2750
+ [](ValueInfo& vi) -> OrtValueInfo* { return vi; });
2751
+
2752
+ ThrowOnError(GetModelEditorApi().SetGraphOutputs(p_, outputs_ptrs.data(), outputs_ptrs.size()));
2753
+
2754
+ // Graph now owns the outputs
2755
+ std::for_each(outputs.begin(), outputs.end(), [](ValueInfo& vi) { vi.release(); });
2756
+ }
2757
+
2758
+ template <>
2759
+ inline void GraphImpl<OrtGraph>::AddInitializer(const std::string& name, Value& initializer, bool data_is_external) {
2760
+ // Graph takes ownership of `initializer`
2761
+ ThrowOnError(GetModelEditorApi().AddInitializerToGraph(p_, name.c_str(), initializer.release(), data_is_external));
2762
+ }
2763
+
2764
+ template <>
2765
+ inline void GraphImpl<OrtGraph>::AddNode(Node& node) {
2766
+ // Graph takes ownership of `node`
2767
+ ThrowOnError(GetModelEditorApi().AddNodeToGraph(p_, node.release()));
2768
+ }
2769
+
2770
+ template <>
2771
+ inline void ModelImpl<OrtModel>::AddGraph(Graph& graph) {
2772
+ // Model takes ownership of `graph`
2773
+ ThrowOnError(GetModelEditorApi().AddGraphToModel(p_, graph.release()));
2774
+ }
2775
+ #endif // !defined(ORT_MINIMAL_BUILD)
2776
+
2777
+ } // namespace detail
2170
2778
  } // namespace Ort