com.github.asus4.onnxruntime 0.4.1 → 0.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Plugins/Android/onnxruntime.aar +0 -0
- package/Plugins/Linux/x64/libonnxruntime.so +0 -0
- package/Plugins/Windows/arm64/onnxruntime.dll +0 -0
- package/Plugins/Windows/x64/onnxruntime.dll +0 -0
- package/Plugins/Windows/x86/onnxruntime.dll +0 -0
- package/Plugins/iOS~/onnxruntime.xcframework/Info.plist +11 -5
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Headers/onnxruntime_c_api.h +1307 -64
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Headers/onnxruntime_cxx_api.h +425 -24
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Headers/onnxruntime_cxx_inline.h +614 -6
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Headers/onnxruntime_run_options_config_keys.h +3 -0
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Headers/onnxruntime_session_options_config_keys.h +29 -5
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/Info.plist +2 -2
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64/onnxruntime.framework/onnxruntime +0 -0
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-maccatalyst/onnxruntime.framework/Versions/A/Headers/onnxruntime_c_api.h +1307 -64
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-maccatalyst/onnxruntime.framework/Versions/A/Headers/onnxruntime_cxx_api.h +425 -24
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-maccatalyst/onnxruntime.framework/Versions/A/Headers/onnxruntime_cxx_inline.h +614 -6
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-maccatalyst/onnxruntime.framework/Versions/A/Headers/onnxruntime_run_options_config_keys.h +3 -0
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-maccatalyst/onnxruntime.framework/Versions/A/Headers/onnxruntime_session_options_config_keys.h +29 -5
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-maccatalyst/onnxruntime.framework/Versions/A/Resources/Info.plist +2 -2
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-maccatalyst/onnxruntime.framework/Versions/A/onnxruntime +0 -0
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Headers/onnxruntime_c_api.h +1307 -64
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Headers/onnxruntime_cxx_api.h +425 -24
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Headers/onnxruntime_cxx_inline.h +614 -6
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Headers/onnxruntime_run_options_config_keys.h +3 -0
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Headers/onnxruntime_session_options_config_keys.h +29 -5
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/Info.plist +2 -2
- package/Plugins/iOS~/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.framework/onnxruntime +0 -0
- package/Plugins/macOS/arm64/libonnxruntime.dylib +0 -0
- package/Plugins/macOS/x64/libonnxruntime.dylib +0 -0
- package/README.md +6 -5
- package/Runtime/ManagedProjections.shared.cs +1 -2
- package/Runtime/NativeMethods.shared.cs +20 -2
- package/Runtime/OrtValue.shared.cs +1 -1
- package/Runtime/SessionOptions.shared.cs +10 -0
- package/Runtime/Training/NativeTrainingMethods.shared.cs +1 -1
- package/package.json +1 -1
|
@@ -10,7 +10,9 @@
|
|
|
10
10
|
#include <algorithm>
|
|
11
11
|
#include <functional>
|
|
12
12
|
#include <iterator>
|
|
13
|
+
#include <string>
|
|
13
14
|
#include <type_traits>
|
|
15
|
+
#include <vector>
|
|
14
16
|
|
|
15
17
|
// Convert OrtStatus to Ort::Status and return
|
|
16
18
|
// instead of throwing
|
|
@@ -477,6 +479,125 @@ inline ThreadingOptions& ThreadingOptions::SetGlobalCustomJoinThreadFn(OrtCustom
|
|
|
477
479
|
return *this;
|
|
478
480
|
}
|
|
479
481
|
|
|
482
|
+
namespace detail {
|
|
483
|
+
template <typename T>
|
|
484
|
+
inline const char* KeyValuePairsImpl<T>::GetValue(const char* key) const {
|
|
485
|
+
return GetApi().GetKeyValue(this->p_, key);
|
|
486
|
+
}
|
|
487
|
+
|
|
488
|
+
template <typename T>
|
|
489
|
+
inline std::unordered_map<std::string, std::string> KeyValuePairsImpl<T>::GetKeyValuePairs() const {
|
|
490
|
+
std::unordered_map<std::string, std::string> out;
|
|
491
|
+
|
|
492
|
+
size_t num_pairs = 0;
|
|
493
|
+
const char* const* keys = nullptr;
|
|
494
|
+
const char* const* values = nullptr;
|
|
495
|
+
GetApi().GetKeyValuePairs(this->p_, &keys, &values, &num_pairs);
|
|
496
|
+
if (num_pairs > 0) {
|
|
497
|
+
out.reserve(num_pairs);
|
|
498
|
+
for (size_t i = 0; i < num_pairs; ++i) {
|
|
499
|
+
out.emplace(keys[i], values[i]);
|
|
500
|
+
}
|
|
501
|
+
}
|
|
502
|
+
|
|
503
|
+
return out;
|
|
504
|
+
}
|
|
505
|
+
|
|
506
|
+
template <typename T>
|
|
507
|
+
inline void KeyValuePairsImpl<T>::GetKeyValuePairs(std::vector<const char*>& keys,
|
|
508
|
+
std::vector<const char*>& values) const {
|
|
509
|
+
keys.clear();
|
|
510
|
+
values.clear();
|
|
511
|
+
|
|
512
|
+
size_t num_pairs = 0;
|
|
513
|
+
const char* const* keys_ptr = nullptr;
|
|
514
|
+
const char* const* values_ptr = nullptr;
|
|
515
|
+
GetApi().GetKeyValuePairs(this->p_, &keys_ptr, &values_ptr, &num_pairs);
|
|
516
|
+
if (num_pairs > 0) {
|
|
517
|
+
keys.resize(num_pairs);
|
|
518
|
+
values.resize(num_pairs);
|
|
519
|
+
std::copy(keys_ptr, keys_ptr + num_pairs, keys.begin());
|
|
520
|
+
std::copy(values_ptr, values_ptr + num_pairs, values.begin());
|
|
521
|
+
}
|
|
522
|
+
}
|
|
523
|
+
} // namespace detail
|
|
524
|
+
|
|
525
|
+
inline KeyValuePairs::KeyValuePairs() {
|
|
526
|
+
GetApi().CreateKeyValuePairs(&p_);
|
|
527
|
+
}
|
|
528
|
+
|
|
529
|
+
inline KeyValuePairs::KeyValuePairs(const std::unordered_map<std::string, std::string>& kv_pairs) {
|
|
530
|
+
GetApi().CreateKeyValuePairs(&p_);
|
|
531
|
+
for (const auto& kv : kv_pairs) {
|
|
532
|
+
GetApi().AddKeyValuePair(this->p_, kv.first.c_str(), kv.second.c_str());
|
|
533
|
+
}
|
|
534
|
+
}
|
|
535
|
+
|
|
536
|
+
inline void KeyValuePairs::Add(const char* key, const char* value) {
|
|
537
|
+
GetApi().AddKeyValuePair(this->p_, key, value);
|
|
538
|
+
}
|
|
539
|
+
|
|
540
|
+
inline void KeyValuePairs::Remove(const char* key) {
|
|
541
|
+
GetApi().RemoveKeyValuePair(this->p_, key);
|
|
542
|
+
}
|
|
543
|
+
|
|
544
|
+
namespace detail {
|
|
545
|
+
template <typename T>
|
|
546
|
+
inline OrtHardwareDeviceType HardwareDeviceImpl<T>::Type() const {
|
|
547
|
+
return GetApi().HardwareDevice_Type(this->p_);
|
|
548
|
+
}
|
|
549
|
+
|
|
550
|
+
template <typename T>
|
|
551
|
+
inline uint32_t HardwareDeviceImpl<T>::VendorId() const {
|
|
552
|
+
return GetApi().HardwareDevice_VendorId(this->p_);
|
|
553
|
+
}
|
|
554
|
+
|
|
555
|
+
template <typename T>
|
|
556
|
+
inline uint32_t HardwareDeviceImpl<T>::DeviceId() const {
|
|
557
|
+
return GetApi().HardwareDevice_DeviceId(this->p_);
|
|
558
|
+
}
|
|
559
|
+
|
|
560
|
+
template <typename T>
|
|
561
|
+
inline const char* HardwareDeviceImpl<T>::Vendor() const {
|
|
562
|
+
return GetApi().HardwareDevice_Vendor(this->p_);
|
|
563
|
+
}
|
|
564
|
+
|
|
565
|
+
template <typename T>
|
|
566
|
+
inline ConstKeyValuePairs HardwareDeviceImpl<T>::Metadata() const {
|
|
567
|
+
return ConstKeyValuePairs{GetApi().HardwareDevice_Metadata(this->p_)};
|
|
568
|
+
}
|
|
569
|
+
|
|
570
|
+
template <typename T>
|
|
571
|
+
inline const char* EpDeviceImpl<T>::EpName() const {
|
|
572
|
+
return GetApi().EpDevice_EpName(this->p_);
|
|
573
|
+
}
|
|
574
|
+
|
|
575
|
+
template <typename T>
|
|
576
|
+
inline const char* EpDeviceImpl<T>::EpVendor() const {
|
|
577
|
+
return GetApi().EpDevice_EpVendor(this->p_);
|
|
578
|
+
}
|
|
579
|
+
|
|
580
|
+
template <typename T>
|
|
581
|
+
inline ConstKeyValuePairs EpDeviceImpl<T>::EpMetadata() const {
|
|
582
|
+
return ConstKeyValuePairs(GetApi().EpDevice_EpMetadata(this->p_));
|
|
583
|
+
}
|
|
584
|
+
|
|
585
|
+
template <typename T>
|
|
586
|
+
inline ConstKeyValuePairs EpDeviceImpl<T>::EpOptions() const {
|
|
587
|
+
return ConstKeyValuePairs(GetApi().EpDevice_EpOptions(this->p_));
|
|
588
|
+
}
|
|
589
|
+
|
|
590
|
+
template <typename T>
|
|
591
|
+
inline ConstHardwareDevice EpDeviceImpl<T>::Device() const {
|
|
592
|
+
return ConstHardwareDevice(GetApi().EpDevice_Device(this->p_));
|
|
593
|
+
}
|
|
594
|
+
} // namespace detail
|
|
595
|
+
|
|
596
|
+
inline EpDevice::EpDevice(OrtEpFactory& ep_factory, ConstHardwareDevice& hardware_device,
|
|
597
|
+
ConstKeyValuePairs ep_metadata, ConstKeyValuePairs ep_options) {
|
|
598
|
+
ThrowOnError(GetEpApi().CreateEpDevice(&ep_factory, hardware_device, ep_metadata, ep_options, &p_));
|
|
599
|
+
}
|
|
600
|
+
|
|
480
601
|
inline Env::Env(OrtLoggingLevel logging_level, _In_ const char* logid) {
|
|
481
602
|
ThrowOnError(GetApi().CreateEnv(logging_level, logid, &p_));
|
|
482
603
|
if (strcmp(logid, "onnxruntime-node") == 0) {
|
|
@@ -549,6 +670,33 @@ inline Env& Env::CreateAndRegisterAllocatorV2(const std::string& provider_type,
|
|
|
549
670
|
return *this;
|
|
550
671
|
}
|
|
551
672
|
|
|
673
|
+
inline Env& Env::RegisterExecutionProviderLibrary(const char* registration_name,
|
|
674
|
+
const std::basic_string<ORTCHAR_T>& path) {
|
|
675
|
+
ThrowOnError(GetApi().RegisterExecutionProviderLibrary(p_, registration_name, path.c_str()));
|
|
676
|
+
return *this;
|
|
677
|
+
}
|
|
678
|
+
|
|
679
|
+
inline Env& Env::UnregisterExecutionProviderLibrary(const char* registration_name) {
|
|
680
|
+
ThrowOnError(GetApi().UnregisterExecutionProviderLibrary(p_, registration_name));
|
|
681
|
+
return *this;
|
|
682
|
+
}
|
|
683
|
+
|
|
684
|
+
inline std::vector<ConstEpDevice> Env::GetEpDevices() const {
|
|
685
|
+
size_t num_devices = 0;
|
|
686
|
+
const OrtEpDevice* const* device_ptrs = nullptr;
|
|
687
|
+
ThrowOnError(GetApi().GetEpDevices(p_, &device_ptrs, &num_devices));
|
|
688
|
+
|
|
689
|
+
std::vector<ConstEpDevice> devices;
|
|
690
|
+
if (num_devices > 0) {
|
|
691
|
+
devices.reserve(num_devices);
|
|
692
|
+
for (size_t i = 0; i < num_devices; ++i) {
|
|
693
|
+
devices.emplace_back(device_ptrs[i]);
|
|
694
|
+
}
|
|
695
|
+
}
|
|
696
|
+
|
|
697
|
+
return devices;
|
|
698
|
+
}
|
|
699
|
+
|
|
552
700
|
inline CustomOpDomain::CustomOpDomain(const char* domain) {
|
|
553
701
|
ThrowOnError(GetApi().CreateCustomOpDomain(domain, &p_));
|
|
554
702
|
}
|
|
@@ -628,6 +776,62 @@ inline RunOptions& RunOptions::AddActiveLoraAdapter(const LoraAdapter& adapter)
|
|
|
628
776
|
return *this;
|
|
629
777
|
}
|
|
630
778
|
|
|
779
|
+
inline ModelCompilationOptions::ModelCompilationOptions(const Env& env, const SessionOptions& session_options) {
|
|
780
|
+
ThrowOnError(GetCompileApi().CreateModelCompilationOptionsFromSessionOptions(env, session_options, &this->p_));
|
|
781
|
+
}
|
|
782
|
+
|
|
783
|
+
inline ModelCompilationOptions::ModelCompilationOptions(const Env& env, ConstSessionOptions session_options) {
|
|
784
|
+
ThrowOnError(GetCompileApi().CreateModelCompilationOptionsFromSessionOptions(env, session_options, &this->p_));
|
|
785
|
+
}
|
|
786
|
+
|
|
787
|
+
inline Status CompileModel(const Env& env, const ModelCompilationOptions& model_compilation_options) {
|
|
788
|
+
return Ort::Status(GetCompileApi().CompileModel(env, model_compilation_options));
|
|
789
|
+
}
|
|
790
|
+
|
|
791
|
+
inline ModelCompilationOptions& ModelCompilationOptions::SetInputModelPath(
|
|
792
|
+
const ORTCHAR_T* input_model_path) {
|
|
793
|
+
Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetInputModelPath(this->p_, input_model_path));
|
|
794
|
+
return *this;
|
|
795
|
+
}
|
|
796
|
+
|
|
797
|
+
inline ModelCompilationOptions& ModelCompilationOptions::SetInputModelFromBuffer(
|
|
798
|
+
const void* input_model_data, size_t input_model_data_size) {
|
|
799
|
+
Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetInputModelFromBuffer(this->p_, input_model_data,
|
|
800
|
+
input_model_data_size));
|
|
801
|
+
return *this;
|
|
802
|
+
}
|
|
803
|
+
|
|
804
|
+
inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelPath(
|
|
805
|
+
const ORTCHAR_T* output_model_path) {
|
|
806
|
+
Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelPath(this->p_, output_model_path));
|
|
807
|
+
return *this;
|
|
808
|
+
}
|
|
809
|
+
|
|
810
|
+
inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelExternalInitializersFile(
|
|
811
|
+
const ORTCHAR_T* file_path, size_t initializer_size_threshold) {
|
|
812
|
+
Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelExternalInitializersFile(
|
|
813
|
+
this->p_,
|
|
814
|
+
file_path,
|
|
815
|
+
initializer_size_threshold));
|
|
816
|
+
return *this;
|
|
817
|
+
}
|
|
818
|
+
|
|
819
|
+
inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelBuffer(
|
|
820
|
+
OrtAllocator* allocator, void** output_model_buffer_ptr, size_t* output_model_buffer_size_ptr) {
|
|
821
|
+
Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelBuffer(this->p_, allocator,
|
|
822
|
+
output_model_buffer_ptr,
|
|
823
|
+
output_model_buffer_size_ptr));
|
|
824
|
+
return *this;
|
|
825
|
+
}
|
|
826
|
+
|
|
827
|
+
inline ModelCompilationOptions& ModelCompilationOptions::SetEpContextEmbedMode(
|
|
828
|
+
bool embed_ep_context_in_model) {
|
|
829
|
+
Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetEpContextEmbedMode(
|
|
830
|
+
this->p_,
|
|
831
|
+
embed_ep_context_in_model));
|
|
832
|
+
return *this;
|
|
833
|
+
}
|
|
834
|
+
|
|
631
835
|
namespace detail {
|
|
632
836
|
|
|
633
837
|
template <typename T>
|
|
@@ -659,7 +863,8 @@ inline bool ConstSessionOptionsImpl<T>::HasConfigEntry(const char* config_key) c
|
|
|
659
863
|
}
|
|
660
864
|
|
|
661
865
|
template <typename T>
|
|
662
|
-
inline std::string ConstSessionOptionsImpl<T>::GetConfigEntryOrDefault(const char* config_key,
|
|
866
|
+
inline std::string ConstSessionOptionsImpl<T>::GetConfigEntryOrDefault(const char* config_key,
|
|
867
|
+
const std::string& def) const {
|
|
663
868
|
if (!this->HasConfigEntry(config_key)) {
|
|
664
869
|
return def;
|
|
665
870
|
}
|
|
@@ -745,6 +950,12 @@ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetExecutionMode(ExecutionM
|
|
|
745
950
|
return *this;
|
|
746
951
|
}
|
|
747
952
|
|
|
953
|
+
template <typename T>
|
|
954
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetLoadCancellationFlag(bool value) {
|
|
955
|
+
ThrowOnError(GetApi().SessionOptionsSetLoadCancellationFlag(this->p_, value));
|
|
956
|
+
return *this;
|
|
957
|
+
}
|
|
958
|
+
|
|
748
959
|
template <typename T>
|
|
749
960
|
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetLogId(const char* logid) {
|
|
750
961
|
ThrowOnError(GetApi().SetSessionLogId(this->p_, logid));
|
|
@@ -891,6 +1102,65 @@ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider(
|
|
|
891
1102
|
return *this;
|
|
892
1103
|
}
|
|
893
1104
|
|
|
1105
|
+
namespace {
|
|
1106
|
+
template <typename T>
|
|
1107
|
+
void SessionOptionsAppendEP(detail::SessionOptionsImpl<T>& session_options,
|
|
1108
|
+
Env& env, const std::vector<ConstEpDevice>& ep_devices,
|
|
1109
|
+
const std::vector<const char*>& ep_options_keys,
|
|
1110
|
+
const std::vector<const char*>& ep_options_values) {
|
|
1111
|
+
std::vector<const OrtEpDevice*> ep_devices_ptrs;
|
|
1112
|
+
ep_devices_ptrs.reserve(ep_devices.size());
|
|
1113
|
+
for (const auto& ep_device : ep_devices) {
|
|
1114
|
+
ep_devices_ptrs.push_back(ep_device);
|
|
1115
|
+
}
|
|
1116
|
+
|
|
1117
|
+
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_V2(
|
|
1118
|
+
session_options, env, ep_devices_ptrs.data(), ep_devices_ptrs.size(),
|
|
1119
|
+
ep_options_keys.data(), ep_options_values.data(), ep_options_keys.size()));
|
|
1120
|
+
}
|
|
1121
|
+
} // namespace
|
|
1122
|
+
|
|
1123
|
+
template <typename T>
|
|
1124
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_V2(
|
|
1125
|
+
Env& env, const std::vector<ConstEpDevice>& ep_devices, const KeyValuePairs& ep_options) {
|
|
1126
|
+
std::vector<const char*> ep_options_keys, ep_options_values;
|
|
1127
|
+
ep_options.GetKeyValuePairs(ep_options_keys, ep_options_values);
|
|
1128
|
+
|
|
1129
|
+
SessionOptionsAppendEP(*this, env, ep_devices, ep_options_keys, ep_options_values);
|
|
1130
|
+
|
|
1131
|
+
return *this;
|
|
1132
|
+
}
|
|
1133
|
+
|
|
1134
|
+
template <typename T>
|
|
1135
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_V2(
|
|
1136
|
+
Env& env, const std::vector<ConstEpDevice>& ep_devices,
|
|
1137
|
+
const std::unordered_map<std::string, std::string>& ep_options) {
|
|
1138
|
+
std::vector<const char*> ep_options_keys, ep_options_values;
|
|
1139
|
+
ep_options_keys.reserve(ep_options.size());
|
|
1140
|
+
ep_options_values.reserve(ep_options.size());
|
|
1141
|
+
|
|
1142
|
+
for (const auto& [key, value] : ep_options) {
|
|
1143
|
+
ep_options_keys.push_back(key.c_str());
|
|
1144
|
+
ep_options_values.push_back(value.c_str());
|
|
1145
|
+
}
|
|
1146
|
+
|
|
1147
|
+
SessionOptionsAppendEP(*this, env, ep_devices, ep_options_keys, ep_options_values);
|
|
1148
|
+
|
|
1149
|
+
return *this;
|
|
1150
|
+
}
|
|
1151
|
+
|
|
1152
|
+
template <typename T>
|
|
1153
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy policy) {
|
|
1154
|
+
ThrowOnError(GetApi().SessionOptionsSetEpSelectionPolicy(this->p_, policy));
|
|
1155
|
+
return *this;
|
|
1156
|
+
}
|
|
1157
|
+
|
|
1158
|
+
template <typename T>
|
|
1159
|
+
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetEpSelectionPolicy(EpSelectionDelegate delegate, void* state) {
|
|
1160
|
+
ThrowOnError(GetApi().SessionOptionsSetEpSelectionPolicyDelegate(this->p_, delegate, state));
|
|
1161
|
+
return *this;
|
|
1162
|
+
}
|
|
1163
|
+
|
|
894
1164
|
template <typename T>
|
|
895
1165
|
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
|
|
896
1166
|
ThrowOnError(GetApi().SessionOptionsSetCustomCreateThreadFn(this->p_, ort_custom_create_thread_fn));
|
|
@@ -995,6 +1265,59 @@ inline size_t ConstSessionImpl<T>::GetOverridableInitializerCount() const {
|
|
|
995
1265
|
return out;
|
|
996
1266
|
}
|
|
997
1267
|
|
|
1268
|
+
template <typename T>
|
|
1269
|
+
inline std::vector<std::string> ConstSessionImpl<T>::GetInputNames() const {
|
|
1270
|
+
AllocatorWithDefaultOptions allocator;
|
|
1271
|
+
|
|
1272
|
+
auto num_inputs = GetInputCount();
|
|
1273
|
+
std::vector<std::string> input_names;
|
|
1274
|
+
input_names.reserve(num_inputs);
|
|
1275
|
+
|
|
1276
|
+
for (size_t i = 0; i < num_inputs; ++i) {
|
|
1277
|
+
char* name = nullptr;
|
|
1278
|
+
ThrowOnError(GetApi().SessionGetInputName(this->p_, i, allocator, &name));
|
|
1279
|
+
input_names.push_back(name);
|
|
1280
|
+
allocator.Free(name);
|
|
1281
|
+
}
|
|
1282
|
+
|
|
1283
|
+
return input_names;
|
|
1284
|
+
}
|
|
1285
|
+
|
|
1286
|
+
template <typename T>
|
|
1287
|
+
inline std::vector<std::string> ConstSessionImpl<T>::GetOutputNames() const {
|
|
1288
|
+
AllocatorWithDefaultOptions allocator;
|
|
1289
|
+
|
|
1290
|
+
auto num_inputs = GetOutputCount();
|
|
1291
|
+
std::vector<std::string> output_names;
|
|
1292
|
+
output_names.reserve(num_inputs);
|
|
1293
|
+
|
|
1294
|
+
for (size_t i = 0; i < num_inputs; ++i) {
|
|
1295
|
+
char* name = nullptr;
|
|
1296
|
+
ThrowOnError(GetApi().SessionGetOutputName(this->p_, i, allocator, &name));
|
|
1297
|
+
output_names.push_back(name);
|
|
1298
|
+
allocator.Free(name);
|
|
1299
|
+
}
|
|
1300
|
+
|
|
1301
|
+
return output_names;
|
|
1302
|
+
}
|
|
1303
|
+
|
|
1304
|
+
template <typename T>
|
|
1305
|
+
inline std::vector<std::string> ConstSessionImpl<T>::GetOverridableInitializerNames() const {
|
|
1306
|
+
AllocatorWithDefaultOptions allocator;
|
|
1307
|
+
|
|
1308
|
+
auto num_initializers = GetOverridableInitializerCount();
|
|
1309
|
+
std::vector<std::string> initializer_names;
|
|
1310
|
+
initializer_names.reserve(num_initializers);
|
|
1311
|
+
|
|
1312
|
+
for (size_t i = 0; i < num_initializers; ++i) {
|
|
1313
|
+
char* name = nullptr;
|
|
1314
|
+
ThrowOnError(GetApi().SessionGetOverridableInitializerName(this->p_, i, allocator, &name));
|
|
1315
|
+
initializer_names.push_back(name);
|
|
1316
|
+
}
|
|
1317
|
+
|
|
1318
|
+
return initializer_names;
|
|
1319
|
+
}
|
|
1320
|
+
|
|
998
1321
|
template <typename T>
|
|
999
1322
|
inline AllocatedStringPtr ConstSessionImpl<T>::GetInputNameAllocated(size_t index, OrtAllocator* allocator) const {
|
|
1000
1323
|
char* out;
|
|
@@ -1051,6 +1374,45 @@ inline TypeInfo ConstSessionImpl<T>::GetOverridableInitializerTypeInfo(size_t in
|
|
|
1051
1374
|
return TypeInfo{out};
|
|
1052
1375
|
}
|
|
1053
1376
|
|
|
1377
|
+
#if !defined(ORT_MINIMAL_BUILD)
|
|
1378
|
+
template <typename T>
|
|
1379
|
+
inline int ConstSessionImpl<T>::GetOpset(const std::string& domain) const {
|
|
1380
|
+
int opset;
|
|
1381
|
+
ThrowOnError(GetModelEditorApi().SessionGetOpsetForDomain(this->p_, domain.c_str(), &opset));
|
|
1382
|
+
return opset;
|
|
1383
|
+
}
|
|
1384
|
+
#endif // !defined(ORT_MINIMAL_BUILD)
|
|
1385
|
+
|
|
1386
|
+
template <typename T>
|
|
1387
|
+
std::vector<ValueInfo> ConstSessionImpl<T>::GetInputs() const {
|
|
1388
|
+
const std::vector<std::string> input_names = GetInputNames();
|
|
1389
|
+
|
|
1390
|
+
std::vector<ValueInfo> inputs;
|
|
1391
|
+
inputs.reserve(input_names.size());
|
|
1392
|
+
|
|
1393
|
+
for (size_t i = 0; i < input_names.size(); ++i) {
|
|
1394
|
+
auto type_info = GetInputTypeInfo(i);
|
|
1395
|
+
inputs.emplace_back(ValueInfo{input_names[i], type_info.GetConst()});
|
|
1396
|
+
}
|
|
1397
|
+
|
|
1398
|
+
return inputs;
|
|
1399
|
+
}
|
|
1400
|
+
|
|
1401
|
+
template <typename T>
|
|
1402
|
+
std::vector<ValueInfo> ConstSessionImpl<T>::GetOutputs() const {
|
|
1403
|
+
const std::vector<std::string> output_names = GetOutputNames();
|
|
1404
|
+
|
|
1405
|
+
std::vector<ValueInfo> outputs;
|
|
1406
|
+
outputs.reserve(output_names.size());
|
|
1407
|
+
|
|
1408
|
+
for (size_t i = 0; i < output_names.size(); ++i) {
|
|
1409
|
+
auto type_info = GetOutputTypeInfo(i);
|
|
1410
|
+
outputs.emplace_back(ValueInfo{output_names[i], type_info.GetConst()});
|
|
1411
|
+
}
|
|
1412
|
+
|
|
1413
|
+
return outputs;
|
|
1414
|
+
}
|
|
1415
|
+
|
|
1054
1416
|
template <typename T>
|
|
1055
1417
|
inline std::vector<Value> SessionImpl<T>::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
|
|
1056
1418
|
const char* const* output_names, size_t output_count) {
|
|
@@ -1098,6 +1460,15 @@ inline void SessionImpl<T>::SetEpDynamicOptions(const char* const* keys, const c
|
|
|
1098
1460
|
ThrowOnError(GetApi().SetEpDynamicOptions(this->p_, keys, values, kv_len));
|
|
1099
1461
|
}
|
|
1100
1462
|
|
|
1463
|
+
#if !defined(ORT_MINIMAL_BUILD)
|
|
1464
|
+
template <typename T>
|
|
1465
|
+
inline void SessionImpl<T>::FinalizeModelEditorSession(const Model& model, const SessionOptions& options,
|
|
1466
|
+
OrtPrepackedWeightsContainer* prepacked_weights_container) {
|
|
1467
|
+
ThrowOnError(GetModelEditorApi().ApplyModelToModelEditorSession(this->p_, model));
|
|
1468
|
+
ThrowOnError(GetModelEditorApi().FinalizeModelEditorSession(this->p_, options, prepacked_weights_container));
|
|
1469
|
+
}
|
|
1470
|
+
#endif // #if !defined(ORT_MINIMAL_BUILD)
|
|
1471
|
+
|
|
1101
1472
|
} // namespace detail
|
|
1102
1473
|
|
|
1103
1474
|
inline SessionOptions::SessionOptions() {
|
|
@@ -1144,6 +1515,32 @@ inline Session::Session(const Env& env, const void* model_data, size_t model_dat
|
|
|
1144
1515
|
prepacked_weights_container, &this->p_));
|
|
1145
1516
|
}
|
|
1146
1517
|
|
|
1518
|
+
#if !defined(ORT_MINIMAL_BUILD)
|
|
1519
|
+
inline Session::Session(const Env& env, const Model& model, const SessionOptions& options) {
|
|
1520
|
+
ThrowOnError(GetModelEditorApi().CreateSessionFromModel(env, model.GetConst(), options, &this->p_));
|
|
1521
|
+
}
|
|
1522
|
+
|
|
1523
|
+
// static
|
|
1524
|
+
inline Session Session::CreateModelEditorSession(const Env& env, const ORTCHAR_T* model_path,
|
|
1525
|
+
const SessionOptions& options) {
|
|
1526
|
+
OrtSession* session = nullptr;
|
|
1527
|
+
ThrowOnError(GetModelEditorApi().CreateModelEditorSession(env, model_path, options, &session));
|
|
1528
|
+
return Session(session);
|
|
1529
|
+
}
|
|
1530
|
+
|
|
1531
|
+
// static
|
|
1532
|
+
inline Session Session::CreateModelEditorSession(const Env& env, const void* model_data, size_t model_data_length,
|
|
1533
|
+
const SessionOptions& options) {
|
|
1534
|
+
OrtSession* session = nullptr;
|
|
1535
|
+
ThrowOnError(GetModelEditorApi().CreateModelEditorSessionFromArray(env, model_data, model_data_length, options,
|
|
1536
|
+
&session));
|
|
1537
|
+
return Session(session);
|
|
1538
|
+
}
|
|
1539
|
+
|
|
1540
|
+
void FinalizeModelEditorSession(const Model& model, const SessionOptions& options,
|
|
1541
|
+
OrtPrepackedWeightsContainer* prepacked_weights_container);
|
|
1542
|
+
#endif // #if !defined(ORT_MINIMAL_BUILD)
|
|
1543
|
+
|
|
1147
1544
|
inline AllocatedStringPtr ModelMetadata::GetProducerNameAllocated(OrtAllocator* allocator) const {
|
|
1148
1545
|
char* out;
|
|
1149
1546
|
ThrowOnError(GetApi().ModelMetadataGetProducerName(p_, allocator, &out));
|
|
@@ -1211,6 +1608,59 @@ inline int64_t ModelMetadata::GetVersion() const {
|
|
|
1211
1608
|
return out;
|
|
1212
1609
|
}
|
|
1213
1610
|
|
|
1611
|
+
inline TensorTypeAndShapeInfo::TensorTypeAndShapeInfo(ONNXTensorElementDataType element_type,
|
|
1612
|
+
const std::vector<int64_t>& dims,
|
|
1613
|
+
const std::vector<std::string>* symbolic_dims) {
|
|
1614
|
+
ThrowOnError(GetApi().CreateTensorTypeAndShapeInfo(&p_));
|
|
1615
|
+
ThrowOnError(GetApi().SetTensorElementType(p_, element_type));
|
|
1616
|
+
ThrowOnError(GetApi().SetDimensions(p_, dims.data(), dims.size()));
|
|
1617
|
+
|
|
1618
|
+
if (symbolic_dims) {
|
|
1619
|
+
std::vector<const char*> symbolic_dims_cstr;
|
|
1620
|
+
symbolic_dims_cstr.reserve(symbolic_dims->size());
|
|
1621
|
+
std::transform(symbolic_dims->begin(), symbolic_dims->end(), std::back_inserter(symbolic_dims_cstr),
|
|
1622
|
+
[](const std::string& s) { return s.c_str(); });
|
|
1623
|
+
ThrowOnError(GetApi().SetSymbolicDimensions(p_, symbolic_dims_cstr.data(), symbolic_dims_cstr.size()));
|
|
1624
|
+
}
|
|
1625
|
+
}
|
|
1626
|
+
|
|
1627
|
+
#if !defined(ORT_MINIMAL_BUILD)
|
|
1628
|
+
// static
|
|
1629
|
+
inline TypeInfo TypeInfo::CreateTensorInfo(ConstTensorTypeAndShapeInfo tensor_type_and_shape_info) {
|
|
1630
|
+
OrtTypeInfo* output = nullptr;
|
|
1631
|
+
ThrowOnError(GetModelEditorApi().CreateTensorTypeInfo(tensor_type_and_shape_info, &output));
|
|
1632
|
+
return TypeInfo{output};
|
|
1633
|
+
}
|
|
1634
|
+
|
|
1635
|
+
// static
|
|
1636
|
+
inline TypeInfo TypeInfo::CreateSparseTensorInfo(ConstTensorTypeAndShapeInfo sparse_tensor_type_and_shape_info) {
|
|
1637
|
+
OrtTypeInfo* output = nullptr;
|
|
1638
|
+
ThrowOnError(GetModelEditorApi().CreateSparseTensorTypeInfo(sparse_tensor_type_and_shape_info, &output));
|
|
1639
|
+
return TypeInfo{output};
|
|
1640
|
+
}
|
|
1641
|
+
|
|
1642
|
+
// static
|
|
1643
|
+
inline TypeInfo TypeInfo::CreateSequenceTypeInfo(ConstTypeInfo sequence_type) {
|
|
1644
|
+
OrtTypeInfo* output;
|
|
1645
|
+
ThrowOnError(GetModelEditorApi().CreateSequenceTypeInfo(sequence_type, &output));
|
|
1646
|
+
return TypeInfo{output};
|
|
1647
|
+
}
|
|
1648
|
+
|
|
1649
|
+
// static
|
|
1650
|
+
inline TypeInfo TypeInfo::CreateMapTypeInfo(ONNXTensorElementDataType key_type, ConstTypeInfo value_type) {
|
|
1651
|
+
OrtTypeInfo* output;
|
|
1652
|
+
ThrowOnError(GetModelEditorApi().CreateMapTypeInfo(key_type, value_type, &output));
|
|
1653
|
+
return TypeInfo{output};
|
|
1654
|
+
}
|
|
1655
|
+
|
|
1656
|
+
// static
|
|
1657
|
+
inline TypeInfo TypeInfo::CreateOptionalTypeInfo(ConstTypeInfo contained_type) {
|
|
1658
|
+
OrtTypeInfo* output;
|
|
1659
|
+
ThrowOnError(GetModelEditorApi().CreateOptionalTypeInfo(contained_type, &output));
|
|
1660
|
+
return TypeInfo{output};
|
|
1661
|
+
}
|
|
1662
|
+
#endif // #if !defined(ORT_MINIMAL_BUILD)
|
|
1663
|
+
|
|
1214
1664
|
namespace detail {
|
|
1215
1665
|
|
|
1216
1666
|
template <typename T>
|
|
@@ -1244,9 +1694,16 @@ inline void TensorTypeAndShapeInfoImpl<T>::GetSymbolicDimensions(const char** va
|
|
|
1244
1694
|
ThrowOnError(GetApi().GetSymbolicDimensions(this->p_, values, values_count));
|
|
1245
1695
|
}
|
|
1246
1696
|
|
|
1697
|
+
template <typename T>
|
|
1698
|
+
inline std::vector<const char*> TensorTypeAndShapeInfoImpl<T>::GetSymbolicDimensions() const {
|
|
1699
|
+
std::vector<const char*> out(GetDimensionsCount(), nullptr);
|
|
1700
|
+
ThrowOnError(GetApi().GetSymbolicDimensions(this->p_, out.data(), out.size()));
|
|
1701
|
+
return out;
|
|
1702
|
+
}
|
|
1703
|
+
|
|
1247
1704
|
template <typename T>
|
|
1248
1705
|
inline std::vector<int64_t> TensorTypeAndShapeInfoImpl<T>::GetShape() const {
|
|
1249
|
-
std::vector<int64_t> out(GetDimensionsCount(),
|
|
1706
|
+
std::vector<int64_t> out(GetDimensionsCount(), -1);
|
|
1250
1707
|
ThrowOnError(GetApi().GetDimensions(this->p_, out.data(), out.size()));
|
|
1251
1708
|
return out;
|
|
1252
1709
|
}
|
|
@@ -1560,23 +2017,35 @@ void ValueImpl<T>::FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_inf
|
|
|
1560
2017
|
} // namespace detail
|
|
1561
2018
|
|
|
1562
2019
|
template <typename T>
|
|
1563
|
-
inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count,
|
|
2020
|
+
inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count,
|
|
2021
|
+
const int64_t* shape, size_t shape_len) {
|
|
1564
2022
|
return CreateTensor(info, p_data, p_data_element_count * sizeof(T), shape, shape_len, TypeToTensorType<T>::type);
|
|
1565
2023
|
}
|
|
1566
2024
|
|
|
1567
|
-
inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count,
|
|
2025
|
+
inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count,
|
|
2026
|
+
const int64_t* shape, size_t shape_len,
|
|
1568
2027
|
ONNXTensorElementDataType type) {
|
|
1569
2028
|
OrtValue* out;
|
|
1570
2029
|
ThrowOnError(GetApi().CreateTensorWithDataAsOrtValue(info, p_data, p_data_byte_count, shape, shape_len, type, &out));
|
|
1571
2030
|
return Value{out};
|
|
1572
2031
|
}
|
|
1573
2032
|
|
|
2033
|
+
inline Value Value::CreateTensor(OrtAllocator* deleter, void* p_data, size_t p_data_byte_count,
|
|
2034
|
+
const int64_t* shape, size_t shape_len,
|
|
2035
|
+
ONNXTensorElementDataType type) {
|
|
2036
|
+
OrtValue* out;
|
|
2037
|
+
ThrowOnError(GetApi().CreateTensorWithDataAndDeleterAsOrtValue(deleter, p_data, p_data_byte_count,
|
|
2038
|
+
shape, shape_len, type, &out));
|
|
2039
|
+
return Value{out};
|
|
2040
|
+
}
|
|
2041
|
+
|
|
1574
2042
|
template <typename T>
|
|
1575
2043
|
inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len) {
|
|
1576
2044
|
return CreateTensor(allocator, shape, shape_len, TypeToTensorType<T>::type);
|
|
1577
2045
|
}
|
|
1578
2046
|
|
|
1579
|
-
inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len,
|
|
2047
|
+
inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len,
|
|
2048
|
+
ONNXTensorElementDataType type) {
|
|
1580
2049
|
OrtValue* out;
|
|
1581
2050
|
ThrowOnError(GetApi().CreateTensorAsOrtValue(allocator, shape, shape_len, type, &out));
|
|
1582
2051
|
return Value{out};
|
|
@@ -1594,7 +2063,8 @@ inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, void* p_data,
|
|
|
1594
2063
|
const Shape& values_shape, ONNXTensorElementDataType type) {
|
|
1595
2064
|
OrtValue* out;
|
|
1596
2065
|
ThrowOnError(GetApi().CreateSparseTensorWithValuesAsOrtValue(info, p_data, dense_shape.shape, dense_shape.shape_len,
|
|
1597
|
-
values_shape.shape, values_shape.shape_len, type,
|
|
2066
|
+
values_shape.shape, values_shape.shape_len, type,
|
|
2067
|
+
&out));
|
|
1598
2068
|
return Value{out};
|
|
1599
2069
|
}
|
|
1600
2070
|
|
|
@@ -2167,4 +2637,142 @@ inline const OrtOpAttr* ShapeInferContext::GetAttrHdl(const char* attr_name) con
|
|
|
2167
2637
|
return attr_hdl;
|
|
2168
2638
|
}
|
|
2169
2639
|
|
|
2640
|
+
namespace detail {
|
|
2641
|
+
inline std::vector<const char*> StringsToCharPtrs(const std::vector<std::string>& strings) {
|
|
2642
|
+
std::vector<const char*> ptrs;
|
|
2643
|
+
ptrs.reserve(strings.size());
|
|
2644
|
+
std::transform(strings.begin(), strings.end(), std::back_inserter(ptrs),
|
|
2645
|
+
[](const std::string& s) { return s.c_str(); });
|
|
2646
|
+
|
|
2647
|
+
return ptrs;
|
|
2648
|
+
}
|
|
2649
|
+
} // namespace detail
|
|
2650
|
+
|
|
2651
|
+
#if !defined(ORT_MINIMAL_BUILD)
|
|
2652
|
+
// static
|
|
2653
|
+
inline void Node::Init(const std::string& operator_name, const std::string& operator_domain,
|
|
2654
|
+
const std::string& node_name,
|
|
2655
|
+
const std::vector<std::string>& input_names,
|
|
2656
|
+
const std::vector<std::string>& output_names,
|
|
2657
|
+
std::vector<OpAttr>& attributes,
|
|
2658
|
+
OrtNode*& node) {
|
|
2659
|
+
auto inputs = detail::StringsToCharPtrs(input_names);
|
|
2660
|
+
auto outputs = detail::StringsToCharPtrs(output_names);
|
|
2661
|
+
|
|
2662
|
+
std::vector<OrtOpAttr*> attributes_ptrs;
|
|
2663
|
+
attributes_ptrs.reserve(attributes.size());
|
|
2664
|
+
std::transform(attributes.begin(), attributes.end(), std::back_inserter(attributes_ptrs),
|
|
2665
|
+
[](OpAttr& attr) -> OrtOpAttr* { return attr; });
|
|
2666
|
+
|
|
2667
|
+
ThrowOnError(GetModelEditorApi().CreateNode(operator_name.c_str(), operator_domain.c_str(), node_name.c_str(),
|
|
2668
|
+
inputs.data(), inputs.size(),
|
|
2669
|
+
outputs.data(), outputs.size(),
|
|
2670
|
+
attributes_ptrs.data(), attributes_ptrs.size(),
|
|
2671
|
+
&node));
|
|
2672
|
+
|
|
2673
|
+
// Node now owns the attributes
|
|
2674
|
+
std::for_each(attributes.begin(), attributes.end(), [](OpAttr& attr) { attr.release(); });
|
|
2675
|
+
}
|
|
2676
|
+
|
|
2677
|
+
inline Node::Node(const std::string& operator_name, const std::string& operator_domain,
|
|
2678
|
+
const std::string& node_name,
|
|
2679
|
+
const std::vector<std::string>& input_names,
|
|
2680
|
+
const std::vector<std::string>& output_names,
|
|
2681
|
+
std::vector<OpAttr>& attributes) {
|
|
2682
|
+
Init(operator_name, operator_domain, node_name, input_names, output_names, attributes, p_);
|
|
2683
|
+
}
|
|
2684
|
+
|
|
2685
|
+
inline Node::Node(const std::string& operator_name, const std::string& operator_domain,
|
|
2686
|
+
const std::string& node_name,
|
|
2687
|
+
const std::vector<std::string>& input_names,
|
|
2688
|
+
const std::vector<std::string>& output_names) {
|
|
2689
|
+
std::vector<OpAttr> empty_attributes;
|
|
2690
|
+
Init(operator_name, operator_domain, node_name, input_names, output_names, empty_attributes, p_);
|
|
2691
|
+
}
|
|
2692
|
+
|
|
2693
|
+
inline Graph::Graph() {
|
|
2694
|
+
ThrowOnError(GetModelEditorApi().CreateGraph(&p_));
|
|
2695
|
+
}
|
|
2696
|
+
|
|
2697
|
+
inline Model::Model(const std::vector<DomainOpsetPair>& opsets) {
|
|
2698
|
+
std::vector<const char*> domains;
|
|
2699
|
+
std::vector<int> versions;
|
|
2700
|
+
domains.reserve(opsets.size());
|
|
2701
|
+
versions.reserve(opsets.size());
|
|
2702
|
+
|
|
2703
|
+
for (const auto& pair : opsets) {
|
|
2704
|
+
domains.push_back(pair.first.c_str());
|
|
2705
|
+
versions.push_back(pair.second);
|
|
2706
|
+
}
|
|
2707
|
+
|
|
2708
|
+
ThrowOnError(GetModelEditorApi().CreateModel(domains.data(), versions.data(), opsets.size(), &p_));
|
|
2709
|
+
}
|
|
2710
|
+
|
|
2711
|
+
inline ValueInfo::ValueInfo(const std::string& name, const ConstTypeInfo& type_info) {
|
|
2712
|
+
ThrowOnError(GetModelEditorApi().CreateValueInfo(name.c_str(), type_info, &p_));
|
|
2713
|
+
}
|
|
2714
|
+
#endif // !defined(ORT_MINIMAL_BUILD)
|
|
2715
|
+
|
|
2716
|
+
namespace detail {
|
|
2717
|
+
template <>
|
|
2718
|
+
inline std::string ValueInfoImpl<OrtValueInfo>::Name() const {
|
|
2719
|
+
const char* name = nullptr;
|
|
2720
|
+
ThrowOnError(GetApi().GetValueInfoName(this->p_, &name));
|
|
2721
|
+
return name;
|
|
2722
|
+
}
|
|
2723
|
+
|
|
2724
|
+
template <>
|
|
2725
|
+
inline ConstTypeInfo ValueInfoImpl<OrtValueInfo>::TypeInfo() const {
|
|
2726
|
+
const OrtTypeInfo* type_info = nullptr;
|
|
2727
|
+
ThrowOnError(GetApi().GetValueInfoTypeInfo(this->p_, &type_info));
|
|
2728
|
+
return ConstTypeInfo{type_info};
|
|
2729
|
+
}
|
|
2730
|
+
|
|
2731
|
+
#if !defined(ORT_MINIMAL_BUILD)
|
|
2732
|
+
template <>
|
|
2733
|
+
inline void GraphImpl<OrtGraph>::SetInputs(std::vector<ValueInfo>& inputs) {
|
|
2734
|
+
std::vector<OrtValueInfo*> inputs_ptrs;
|
|
2735
|
+
inputs_ptrs.reserve(inputs.size());
|
|
2736
|
+
std::transform(inputs.begin(), inputs.end(), std::back_inserter(inputs_ptrs),
|
|
2737
|
+
[](ValueInfo& vi) -> OrtValueInfo* { return vi; });
|
|
2738
|
+
|
|
2739
|
+
ThrowOnError(GetModelEditorApi().SetGraphInputs(p_, inputs_ptrs.data(), inputs_ptrs.size()));
|
|
2740
|
+
|
|
2741
|
+
// Graph now owns the inputs
|
|
2742
|
+
std::for_each(inputs.begin(), inputs.end(), [](ValueInfo& vi) { vi.release(); });
|
|
2743
|
+
}
|
|
2744
|
+
|
|
2745
|
+
template <>
|
|
2746
|
+
inline void GraphImpl<OrtGraph>::SetOutputs(std::vector<ValueInfo>& outputs) {
|
|
2747
|
+
std::vector<OrtValueInfo*> outputs_ptrs;
|
|
2748
|
+
outputs_ptrs.reserve(outputs.size());
|
|
2749
|
+
std::transform(outputs.begin(), outputs.end(), std::back_inserter(outputs_ptrs),
|
|
2750
|
+
[](ValueInfo& vi) -> OrtValueInfo* { return vi; });
|
|
2751
|
+
|
|
2752
|
+
ThrowOnError(GetModelEditorApi().SetGraphOutputs(p_, outputs_ptrs.data(), outputs_ptrs.size()));
|
|
2753
|
+
|
|
2754
|
+
// Graph now owns the outputs
|
|
2755
|
+
std::for_each(outputs.begin(), outputs.end(), [](ValueInfo& vi) { vi.release(); });
|
|
2756
|
+
}
|
|
2757
|
+
|
|
2758
|
+
template <>
|
|
2759
|
+
inline void GraphImpl<OrtGraph>::AddInitializer(const std::string& name, Value& initializer, bool data_is_external) {
|
|
2760
|
+
// Graph takes ownership of `initializer`
|
|
2761
|
+
ThrowOnError(GetModelEditorApi().AddInitializerToGraph(p_, name.c_str(), initializer.release(), data_is_external));
|
|
2762
|
+
}
|
|
2763
|
+
|
|
2764
|
+
template <>
|
|
2765
|
+
inline void GraphImpl<OrtGraph>::AddNode(Node& node) {
|
|
2766
|
+
// Graph takes ownership of `node`
|
|
2767
|
+
ThrowOnError(GetModelEditorApi().AddNodeToGraph(p_, node.release()));
|
|
2768
|
+
}
|
|
2769
|
+
|
|
2770
|
+
template <>
|
|
2771
|
+
inline void ModelImpl<OrtModel>::AddGraph(Graph& graph) {
|
|
2772
|
+
// Model takes ownership of `graph`
|
|
2773
|
+
ThrowOnError(GetModelEditorApi().AddGraphToModel(p_, graph.release()));
|
|
2774
|
+
}
|
|
2775
|
+
#endif // !defined(ORT_MINIMAL_BUILD)
|
|
2776
|
+
|
|
2777
|
+
} // namespace detail
|
|
2170
2778
|
} // namespace Ort
|