react-native-sherpa-onnx 0.2.0 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175) hide show
  1. package/README.md +232 -236
  2. package/SherpaOnnx.podspec +68 -64
  3. package/android/build.gradle +182 -192
  4. package/android/codegen.gradle +57 -0
  5. package/android/prebuilt-download.gradle +428 -0
  6. package/android/prebuilt-versions.gradle +43 -0
  7. package/android/proguard-rules.pro +10 -0
  8. package/android/src/main/assets/testModels/add_mul_add.onnx +28 -0
  9. package/android/src/main/assets/testModels/nnapi_internal_uint8_support.onnx +0 -0
  10. package/android/src/main/assets/testModels/qnn_multi_ctx_embed.onnx +0 -0
  11. package/android/src/main/cpp/CMakeLists.txt +166 -129
  12. package/android/src/main/cpp/CMakePresets.json +54 -0
  13. package/android/src/main/cpp/crypto/sha256.cpp +174 -0
  14. package/android/src/main/cpp/crypto/sha256.h +16 -0
  15. package/android/src/main/cpp/jni/archive/sherpa-onnx-archive-helper.cpp +404 -0
  16. package/android/src/main/cpp/jni/archive/sherpa-onnx-archive-helper.h +56 -0
  17. package/android/src/main/cpp/jni/archive/sherpa-onnx-archive-jni.cpp +181 -0
  18. package/android/src/main/cpp/jni/audio/sherpa-onnx-audio-convert-jni.cpp +888 -0
  19. package/{ios → android/src/main/cpp/jni/model_detect}/sherpa-onnx-common.h +18 -18
  20. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-detect-jni-common.cpp +86 -0
  21. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-detect-jni-common.h +20 -0
  22. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-helper.cpp +423 -0
  23. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-helper.h +55 -0
  24. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-stt.cpp +399 -0
  25. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-tts.cpp +238 -0
  26. package/{ios → android/src/main/cpp/jni/model_detect}/sherpa-onnx-model-detect.h +122 -89
  27. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-stt-wrapper.cpp +99 -0
  28. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-stt-wrapper.h +16 -0
  29. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-tts-wrapper.cpp +78 -0
  30. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-tts-wrapper.h +16 -0
  31. package/android/src/main/cpp/jni/module/sherpa-onnx-module-jni.cpp +190 -0
  32. package/android/src/main/cpp/jni/tts/sherpa-onnx-tts-zipvoice-jni.cpp +301 -0
  33. package/android/src/main/java/com/sherpaonnx/SherpaOnnxArchiveHelper.kt +94 -0
  34. package/android/src/main/java/com/sherpaonnx/{SherpaOnnxCoreHelper.kt → SherpaOnnxAssetHelper.kt} +350 -236
  35. package/android/src/main/java/com/sherpaonnx/SherpaOnnxModule.kt +791 -483
  36. package/android/src/main/java/com/sherpaonnx/SherpaOnnxSttHelper.kt +699 -109
  37. package/android/src/main/java/com/sherpaonnx/SherpaOnnxTtsHelper.kt +1123 -668
  38. package/android/src/main/java/com/sherpaonnx/ZipvoiceTtsWrapper.kt +187 -0
  39. package/ios/SherpaOnnx+Assets.h +11 -0
  40. package/ios/SherpaOnnx+Assets.mm +325 -0
  41. package/ios/SherpaOnnx+STT.mm +455 -118
  42. package/ios/SherpaOnnx+TTS.mm +1101 -712
  43. package/ios/SherpaOnnx.h +17 -6
  44. package/ios/SherpaOnnx.mm +206 -311
  45. package/ios/SherpaOnnx.xcconfig +19 -19
  46. package/ios/SherpaOnnxCoreMLHelper.swift +24 -0
  47. package/ios/archive/sherpa-onnx-archive-helper.h +21 -0
  48. package/ios/archive/sherpa-onnx-archive-helper.mm +296 -0
  49. package/ios/libarchive_darwin_config.h +153 -0
  50. package/{android/src/main/cpp/jni → ios/model_detect}/sherpa-onnx-common.h +18 -18
  51. package/ios/model_detect/sherpa-onnx-model-detect-helper.h +49 -0
  52. package/ios/model_detect/sherpa-onnx-model-detect-helper.mm +210 -0
  53. package/ios/model_detect/sherpa-onnx-model-detect-stt.mm +344 -0
  54. package/ios/model_detect/sherpa-onnx-model-detect-tts.mm +201 -0
  55. package/{android/src/main/cpp/jni → ios/model_detect}/sherpa-onnx-model-detect.h +117 -89
  56. package/ios/scripts/patch-libarchive-includes.sh +61 -0
  57. package/ios/scripts/setup-ios-libarchive.sh +98 -0
  58. package/ios/stt/sherpa-onnx-stt-wrapper.h +129 -0
  59. package/ios/stt/sherpa-onnx-stt-wrapper.mm +523 -0
  60. package/ios/{sherpa-onnx-tts-wrapper.h → tts/sherpa-onnx-tts-wrapper.h} +90 -85
  61. package/ios/{sherpa-onnx-tts-wrapper.mm → tts/sherpa-onnx-tts-wrapper.mm} +376 -345
  62. package/lib/module/NativeSherpaOnnx.js +3 -0
  63. package/lib/module/NativeSherpaOnnx.js.map +1 -1
  64. package/lib/module/audio/index.js +22 -0
  65. package/lib/module/audio/index.js.map +1 -0
  66. package/lib/module/diarization/index.js +1 -1
  67. package/lib/module/diarization/index.js.map +1 -1
  68. package/lib/module/download/ModelDownloadManager.js +918 -0
  69. package/lib/module/download/ModelDownloadManager.js.map +1 -0
  70. package/lib/module/download/extractTarBz2.js +53 -0
  71. package/lib/module/download/extractTarBz2.js.map +1 -0
  72. package/lib/module/download/index.js +6 -0
  73. package/lib/module/download/index.js.map +1 -0
  74. package/lib/module/download/validation.js +178 -0
  75. package/lib/module/download/validation.js.map +1 -0
  76. package/lib/module/enhancement/index.js +1 -1
  77. package/lib/module/enhancement/index.js.map +1 -1
  78. package/lib/module/index.js +41 -3
  79. package/lib/module/index.js.map +1 -1
  80. package/lib/module/separation/index.js +1 -1
  81. package/lib/module/separation/index.js.map +1 -1
  82. package/lib/module/stt/index.js +127 -60
  83. package/lib/module/stt/index.js.map +1 -1
  84. package/lib/module/stt/sttModelLanguages.js +512 -0
  85. package/lib/module/stt/sttModelLanguages.js.map +1 -0
  86. package/lib/module/stt/types.js +53 -1
  87. package/lib/module/stt/types.js.map +1 -1
  88. package/lib/module/tts/index.js +216 -289
  89. package/lib/module/tts/index.js.map +1 -1
  90. package/lib/module/tts/types.js +86 -1
  91. package/lib/module/tts/types.js.map +1 -1
  92. package/lib/module/types.js.map +1 -1
  93. package/lib/module/utils.js +86 -73
  94. package/lib/module/utils.js.map +1 -1
  95. package/lib/module/vad/index.js +1 -1
  96. package/lib/module/vad/index.js.map +1 -1
  97. package/lib/typescript/src/NativeSherpaOnnx.d.ts +192 -38
  98. package/lib/typescript/src/NativeSherpaOnnx.d.ts.map +1 -1
  99. package/lib/typescript/src/audio/index.d.ts +13 -0
  100. package/lib/typescript/src/audio/index.d.ts.map +1 -0
  101. package/lib/typescript/src/diarization/index.d.ts +3 -2
  102. package/lib/typescript/src/diarization/index.d.ts.map +1 -1
  103. package/lib/typescript/src/download/ModelDownloadManager.d.ts +108 -0
  104. package/lib/typescript/src/download/ModelDownloadManager.d.ts.map +1 -0
  105. package/lib/typescript/src/download/extractTarBz2.d.ts +14 -0
  106. package/lib/typescript/src/download/extractTarBz2.d.ts.map +1 -0
  107. package/lib/typescript/src/download/index.d.ts +7 -0
  108. package/lib/typescript/src/download/index.d.ts.map +1 -0
  109. package/lib/typescript/src/download/validation.d.ts +57 -0
  110. package/lib/typescript/src/download/validation.d.ts.map +1 -0
  111. package/lib/typescript/src/enhancement/index.d.ts +3 -2
  112. package/lib/typescript/src/enhancement/index.d.ts.map +1 -1
  113. package/lib/typescript/src/index.d.ts +26 -2
  114. package/lib/typescript/src/index.d.ts.map +1 -1
  115. package/lib/typescript/src/separation/index.d.ts +3 -2
  116. package/lib/typescript/src/separation/index.d.ts.map +1 -1
  117. package/lib/typescript/src/stt/index.d.ts +31 -43
  118. package/lib/typescript/src/stt/index.d.ts.map +1 -1
  119. package/lib/typescript/src/stt/sttModelLanguages.d.ts +52 -0
  120. package/lib/typescript/src/stt/sttModelLanguages.d.ts.map +1 -0
  121. package/lib/typescript/src/stt/types.d.ts +196 -9
  122. package/lib/typescript/src/stt/types.d.ts.map +1 -1
  123. package/lib/typescript/src/tts/index.d.ts +25 -211
  124. package/lib/typescript/src/tts/index.d.ts.map +1 -1
  125. package/lib/typescript/src/tts/types.d.ts +148 -25
  126. package/lib/typescript/src/tts/types.d.ts.map +1 -1
  127. package/lib/typescript/src/types.d.ts +0 -32
  128. package/lib/typescript/src/types.d.ts.map +1 -1
  129. package/lib/typescript/src/utils.d.ts +28 -13
  130. package/lib/typescript/src/utils.d.ts.map +1 -1
  131. package/lib/typescript/src/vad/index.d.ts +3 -2
  132. package/lib/typescript/src/vad/index.d.ts.map +1 -1
  133. package/package.json +250 -222
  134. package/scripts/check-qnn-support.sh +78 -0
  135. package/scripts/setup-ios-framework.sh +379 -282
  136. package/src/NativeSherpaOnnx.ts +474 -251
  137. package/src/audio/index.ts +32 -0
  138. package/src/diarization/index.ts +4 -2
  139. package/src/download/ModelDownloadManager.ts +1325 -0
  140. package/src/download/extractTarBz2.ts +78 -0
  141. package/src/download/index.ts +43 -0
  142. package/src/download/validation.ts +279 -0
  143. package/src/enhancement/index.ts +4 -2
  144. package/src/index.tsx +78 -27
  145. package/src/separation/index.ts +4 -2
  146. package/src/stt/index.ts +249 -89
  147. package/src/stt/sttModelLanguages.ts +237 -0
  148. package/src/stt/types.ts +263 -9
  149. package/src/tts/index.ts +470 -458
  150. package/src/tts/types.ts +373 -218
  151. package/src/types.ts +0 -44
  152. package/src/utils.ts +145 -131
  153. package/src/vad/index.ts +4 -2
  154. package/third_party/ffmpeg_prebuilt/ANDROID_RELEASE_TAG +1 -0
  155. package/third_party/libarchive_prebuilt/ANDROID_RELEASE_TAG +1 -0
  156. package/third_party/libarchive_prebuilt/IOS_RELEASE_TAG +1 -0
  157. package/third_party/sherpa-onnx-prebuilt/ANDROID_RELEASE_TAG +1 -0
  158. package/third_party/sherpa-onnx-prebuilt/IOS_RELEASE_TAG +1 -0
  159. package/android/src/main/cpp/include/sherpa-onnx/c-api/c-api.h +0 -1918
  160. package/android/src/main/cpp/include/sherpa-onnx/c-api/cxx-api.h +0 -841
  161. package/android/src/main/cpp/jni/sherpa-onnx-model-detect.cpp +0 -541
  162. package/android/src/main/cpp/jni/sherpa-onnx-stt-jni.cpp +0 -336
  163. package/android/src/main/cpp/jni/sherpa-onnx-stt-wrapper.cpp +0 -222
  164. package/android/src/main/cpp/jni/sherpa-onnx-stt-wrapper.h +0 -68
  165. package/android/src/main/cpp/jni/sherpa-onnx-tts-jni.cpp +0 -823
  166. package/android/src/main/cpp/jni/sherpa-onnx-tts-wrapper.cpp +0 -387
  167. package/android/src/main/cpp/jni/sherpa-onnx-tts-wrapper.h +0 -147
  168. package/ios/Frameworks/sherpa_onnx.xcframework.zip +0 -0
  169. package/ios/include/sherpa-onnx/c-api/c-api.h +0 -1918
  170. package/ios/include/sherpa-onnx/c-api/cxx-api.h +0 -841
  171. package/ios/sherpa-onnx-model-detect.mm +0 -441
  172. package/ios/sherpa-onnx-stt-wrapper.h +0 -48
  173. package/ios/sherpa-onnx-stt-wrapper.mm +0 -201
  174. package/scripts/copy-headers.js +0 -184
  175. package/scripts/setup-assets.js +0 -323
@@ -0,0 +1,210 @@
1
+ /**
2
+ * sherpa-onnx-model-detect-helper.mm
3
+ *
4
+ * Purpose: Shared filesystem and string helpers for model detection (file/dir listing, token-based
5
+ * ONNX search, path resolution). Used by sherpa-onnx-model-detect-stt.mm and -tts.mm on iOS.
6
+ */
7
+
8
+ #include "sherpa-onnx-model-detect-helper.h"
9
+
10
+ #include <algorithm>
11
+ #include <cctype>
12
+ #include <cstdint>
13
+ #include <filesystem>
14
+ #include <string>
15
+ #include <vector>
16
+
17
+ namespace fs = std::filesystem;
18
+
19
+ namespace sherpaonnx {
20
+ namespace model_detect {
21
+
22
+ namespace {
23
+
24
// Returns true when `value` terminates with `suffix` (case-sensitive, byte-wise).
// An empty suffix matches every value; a suffix longer than the value never matches.
bool EndsWith(const std::string& value, const std::string& suffix) {
  const std::size_t valueLen = value.size();
  const std::size_t suffixLen = suffix.size();
  if (valueLen < suffixLen) {
    return false;
  }
  // Compare only the tail of `value` that has the same length as `suffix`.
  return value.compare(valueLen - suffixLen, suffixLen, suffix) == 0;
}
28
+
29
// Returns true when `token` occurs anywhere inside `value` (case-sensitive substring test).
// An empty token is reported as contained, mirroring std::string::find.
bool ContainsToken(const std::string& value, const std::string& token) {
  const std::string::size_type position = value.find(token);
  const bool found = (position != std::string::npos);
  return found;
}
32
+
33
+ std::string ChooseLargest(const std::vector<FileEntry>& files,
34
+ const std::vector<std::string>& excludeTokens, bool onlyInt8, bool onlyNonInt8) {
35
+ std::string chosen;
36
+ std::uint64_t bestSize = 0;
37
+ for (const auto& entry : files) {
38
+ if (!EndsWith(entry.nameLower, ".onnx")) continue;
39
+ bool hasExcluded = false;
40
+ for (const auto& token : excludeTokens) {
41
+ if (ContainsToken(entry.nameLower, token)) { hasExcluded = true; break; }
42
+ }
43
+ if (hasExcluded) continue;
44
+ bool isInt8 = ContainsToken(entry.nameLower, "int8");
45
+ if (onlyInt8 && !isInt8) continue;
46
+ if (onlyNonInt8 && isInt8) continue;
47
+ if (entry.size >= bestSize) {
48
+ bestSize = entry.size;
49
+ chosen = entry.path;
50
+ }
51
+ }
52
+ return chosen;
53
+ }
54
+
55
+ } // namespace
56
+
57
// Returns true when `path` refers to an existing filesystem object (file, directory, ...).
//
// Uses the non-throwing std::filesystem overload: any error while querying the path
// (for example a permission failure on a parent directory) is reported as "does not exist"
// instead of propagating a std::filesystem::filesystem_error to callers that treat this
// function as a plain boolean predicate.
bool FileExists(const std::string& path) {
  std::error_code ec;
  const bool exists = std::filesystem::exists(std::filesystem::path(path), ec);
  return !ec && exists;
}
60
+
61
// Returns true when `path` exists and is a directory.
//
// Uses the non-throwing std::filesystem overload: any error while querying the path is
// reported as "not a directory" instead of propagating a std::filesystem::filesystem_error
// to callers that treat this function as a plain boolean predicate.
bool IsDirectory(const std::string& path) {
  std::error_code ec;
  const bool isDir = std::filesystem::is_directory(std::filesystem::path(path), ec);
  return !ec && isDir;
}
64
+
65
// Returns a copy of `value` with every byte passed through std::tolower.
// Each char is widened via unsigned char first so that bytes >= 0x80 are valid
// arguments for std::tolower.
std::string ToLower(std::string value) {
  for (char& ch : value) {
    const unsigned char asUnsigned = static_cast<unsigned char>(ch);
    ch = static_cast<char>(std::tolower(asUnsigned));
  }
  return value;
}
71
+
72
// Lists the immediate sub-directories of `path` (non-recursive) as path strings.
//
// Best effort: if iterating the directory throws (missing path, permission error, ...),
// the exception is swallowed and whatever was collected up to that point is returned.
std::vector<std::string> ListDirectories(const std::string& path) {
  std::vector<std::string> subdirs;
  try {
    const std::filesystem::directory_iterator endOfDir;
    for (std::filesystem::directory_iterator it(path); it != endOfDir; ++it) {
      if (!it->is_directory()) {
        continue;
      }
      subdirs.push_back(it->path().string());
    }
  } catch (const std::exception&) {
    // Intentionally ignored: callers only need the directories that could be read.
  }
  return subdirs;
}
81
+
82
+ std::vector<FileEntry> ListFiles(const std::string& dir) {
83
+ std::vector<FileEntry> results;
84
+ try {
85
+ for (const auto& entry : fs::directory_iterator(dir)) {
86
+ if (!entry.is_regular_file()) continue;
87
+ FileEntry file;
88
+ file.path = entry.path().string();
89
+ std::string name = entry.path().filename().string();
90
+ file.nameLower = ToLower(name);
91
+ file.size = static_cast<std::uint64_t>(entry.file_size());
92
+ results.push_back(file);
93
+ }
94
+ } catch (const std::exception&) {}
95
+ return results;
96
+ }
97
+
98
+ std::vector<FileEntry> ListFilesRecursive(const std::string& path, int maxDepth) {
99
+ std::vector<FileEntry> results = ListFiles(path);
100
+ if (maxDepth <= 0) return results;
101
+ for (const auto& dir : ListDirectories(path)) {
102
+ auto nested = ListFilesRecursive(dir, maxDepth - 1);
103
+ results.insert(results.end(), nested.begin(), nested.end());
104
+ }
105
+ return results;
106
+ }
107
+
108
+ std::string FindLargestOnnxExcludingTokens(const std::vector<FileEntry>& files,
109
+ const std::vector<std::string>& excludeTokens) {
110
+ return ChooseLargest(files, excludeTokens, false, false);
111
+ }
112
+
113
+ std::string FindOnnxByToken(const std::vector<FileEntry>& files,
114
+ const std::string& token, const std::optional<bool>& preferInt8) {
115
+ std::string tokenLower = ToLower(token);
116
+ std::vector<FileEntry> matches;
117
+ for (const auto& entry : files) {
118
+ if (!EndsWith(entry.nameLower, ".onnx")) continue;
119
+ if (ContainsToken(entry.nameLower, tokenLower)) matches.push_back(entry);
120
+ }
121
+ if (matches.empty()) return "";
122
+ std::vector<std::string> emptyTokens;
123
+ bool wantInt8 = preferInt8.has_value() && preferInt8.value();
124
+ bool wantNonInt8 = preferInt8.has_value() && !preferInt8.value();
125
+ std::string preferred = ChooseLargest(matches, emptyTokens, wantInt8, wantNonInt8);
126
+ if (!preferred.empty()) return preferred;
127
+ return ChooseLargest(matches, emptyTokens, false, false);
128
+ }
129
+
130
+ std::string FindOnnxByAnyToken(const std::vector<FileEntry>& files,
131
+ const std::vector<std::string>& tokens, const std::optional<bool>& preferInt8) {
132
+ for (const auto& token : tokens) {
133
+ std::string match = FindOnnxByToken(files, token, preferInt8);
134
+ if (!match.empty()) return match;
135
+ }
136
+ return "";
137
+ }
138
+
139
+ std::string FindFileEndingWith(const std::vector<FileEntry>& files, const std::string& suffix) {
140
+ std::string targetSuffix = ToLower(suffix);
141
+ for (const auto& entry : files) {
142
+ if (entry.nameLower == targetSuffix) return entry.path;
143
+ }
144
+ for (const auto& entry : files) {
145
+ if (EndsWith(entry.nameLower, targetSuffix)) return entry.path;
146
+ }
147
+ return "";
148
+ }
149
+
150
+ std::string FindFileByName(const std::string& baseDir, const std::string& fileName, int maxDepth) {
151
+ std::string target = ToLower(fileName);
152
+ auto files = ListFilesRecursive(baseDir, maxDepth);
153
+ for (const auto& entry : files) {
154
+ if (entry.nameLower == target) return entry.path;
155
+ }
156
+ return "";
157
+ }
158
+
159
+ std::string FindDirectoryByName(const std::string& baseDir, const std::string& dirName, int maxDepth) {
160
+ std::string target = ToLower(dirName);
161
+ std::vector<std::string> toVisit = ListDirectories(baseDir);
162
+ int depth = 0;
163
+ while (!toVisit.empty() && depth <= maxDepth) {
164
+ std::vector<std::string> next;
165
+ for (const auto& dir : toVisit) {
166
+ std::string name = fs::path(dir).filename().string();
167
+ if (ToLower(name) == target) return dir;
168
+ if (depth < maxDepth) {
169
+ auto nested = ListDirectories(dir);
170
+ next.insert(next.end(), nested.begin(), nested.end());
171
+ }
172
+ }
173
+ toVisit.swap(next);
174
+ depth += 1;
175
+ }
176
+ return "";
177
+ }
178
+
179
+ std::string ResolveTokenizerDir(const std::string& modelDir) {
180
+ std::string vocabInMain = modelDir + "/vocab.json";
181
+ if (FileExists(vocabInMain)) {
182
+ return modelDir;
183
+ }
184
+
185
+ try {
186
+ for (const auto& entry : fs::directory_iterator(modelDir)) {
187
+ if (entry.is_directory()) {
188
+ std::string dirName = entry.path().filename().string();
189
+ std::string dirNameLower = ToLower(dirName);
190
+ if (dirNameLower.find("qwen3") != std::string::npos) {
191
+ std::string vocabPath = entry.path().string() + "/vocab.json";
192
+ if (FileExists(vocabPath)) {
193
+ return entry.path().string();
194
+ }
195
+ }
196
+ }
197
+ }
198
+ } catch (const std::exception&) {
199
+ }
200
+
201
+ std::string commonPath = modelDir + "/Qwen3-0.6B";
202
+ if (FileExists(commonPath + "/vocab.json")) {
203
+ return commonPath;
204
+ }
205
+
206
+ return "";
207
+ }
208
+
209
+ } // namespace model_detect
210
+ } // namespace sherpaonnx
@@ -0,0 +1,344 @@
1
+ /**
2
+ * sherpa-onnx-model-detect-stt.mm
3
+ *
4
+ * Purpose: Detects STT (speech-to-text) model type and fills SttModelPaths from a model directory.
5
+ * Supports transducer, paraformer, whisper, and other STT variants. Used by the STT wrapper on iOS.
6
+ */
7
+
8
+ #include "sherpa-onnx-model-detect.h"
9
+ #include "sherpa-onnx-model-detect-helper.h"
10
+
11
+ #include <string>
12
+
13
+ namespace sherpaonnx {
14
+ namespace {
15
+
16
+ using namespace model_detect;
17
+
18
+ SttModelKind ParseSttModelType(const std::string& modelType) {
19
+ if (modelType == "transducer" || modelType == "zipformer") return SttModelKind::kTransducer;
20
+ if (modelType == "nemo_transducer") return SttModelKind::kNemoTransducer;
21
+ if (modelType == "paraformer") return SttModelKind::kParaformer;
22
+ if (modelType == "nemo_ctc") return SttModelKind::kNemoCtc;
23
+ if (modelType == "wenet_ctc") return SttModelKind::kWenetCtc;
24
+ if (modelType == "sense_voice") return SttModelKind::kSenseVoice;
25
+ if (modelType == "zipformer_ctc" || modelType == "ctc") return SttModelKind::kZipformerCtc;
26
+ if (modelType == "whisper") return SttModelKind::kWhisper;
27
+ if (modelType == "funasr_nano") return SttModelKind::kFunAsrNano;
28
+ if (modelType == "fire_red_asr") return SttModelKind::kFireRedAsr;
29
+ if (modelType == "moonshine") return SttModelKind::kMoonshine;
30
+ if (modelType == "dolphin") return SttModelKind::kDolphin;
31
+ if (modelType == "canary") return SttModelKind::kCanary;
32
+ if (modelType == "omnilingual") return SttModelKind::kOmnilingual;
33
+ if (modelType == "medasr") return SttModelKind::kMedAsr;
34
+ if (modelType == "telespeech_ctc") return SttModelKind::kTeleSpeechCtc;
35
+ return SttModelKind::kUnknown;
36
+ }
37
+
38
+ } // namespace
39
+
40
+ SttDetectResult DetectSttModel(
41
+ const std::string& modelDir,
42
+ const std::optional<bool>& preferInt8,
43
+ const std::optional<std::string>& modelType,
44
+ bool debug /* = false */
45
+ ) {
46
+ using namespace model_detect;
47
+
48
+ SttDetectResult result;
49
+
50
+ if (modelDir.empty()) {
51
+ result.error = "Model directory is empty";
52
+ return result;
53
+ }
54
+
55
+ if (!FileExists(modelDir) || !IsDirectory(modelDir)) {
56
+ result.error = "Model directory does not exist or is not a directory: " + modelDir;
57
+ return result;
58
+ }
59
+
60
+ const int kMaxSearchDepth = 4;
61
+ const std::vector<FileEntry> files = ListFilesRecursive(modelDir, kMaxSearchDepth);
62
+
63
+ std::string encoderPath = FindOnnxByAnyToken(files, {"encoder"}, preferInt8);
64
+ std::string decoderPath = FindOnnxByAnyToken(files, {"decoder"}, preferInt8);
65
+ std::string joinerPath = FindOnnxByAnyToken(files, {"joiner"}, preferInt8);
66
+ std::string tokensPath = FindFileEndingWith(files, "tokens.txt");
67
+
68
+ std::vector<std::string> modelExcludes = {
69
+ "encoder", "decoder", "joiner", "vocoder", "acoustic", "embedding", "llm",
70
+ "encoder_adaptor", "encoder-adaptor"
71
+ };
72
+ std::string paraformerModelPath = FindOnnxByAnyToken(files, {"model"}, preferInt8);
73
+ if (paraformerModelPath.empty()) {
74
+ paraformerModelPath = FindLargestOnnxExcludingTokens(files, modelExcludes);
75
+ }
76
+ std::string ctcModelPath = FindOnnxByAnyToken(files, {"model"}, preferInt8);
77
+ if (ctcModelPath.empty()) {
78
+ ctcModelPath = FindLargestOnnxExcludingTokens(files, modelExcludes);
79
+ }
80
+
81
+ std::string funasrEncoderAdaptor = FindOnnxByAnyToken(files, {"encoder_adaptor", "encoder-adaptor"}, preferInt8);
82
+ std::string funasrLLM = FindOnnxByAnyToken(files, {"llm"}, preferInt8);
83
+ std::string funasrEmbedding = FindOnnxByAnyToken(files, {"embedding"}, preferInt8);
84
+ std::string funasrTokenizerDir = ResolveTokenizerDir(modelDir);
85
+
86
+ std::string moonshinePreprocess = FindOnnxByAnyToken(files, {"preprocess", "preprocessor"}, preferInt8);
87
+ std::string moonshineEncode = FindOnnxByAnyToken(files, {"encode"}, preferInt8);
88
+ std::string moonshineUncachedDecode = FindOnnxByAnyToken(files, {"uncached_decode", "uncached"}, preferInt8);
89
+ std::string moonshineCachedDecode = FindOnnxByAnyToken(files, {"cached_decode", "cached"}, preferInt8);
90
+
91
+ bool hasTransducer = !encoderPath.empty() && !decoderPath.empty() && !joinerPath.empty();
92
+
93
+ bool hasWhisperEncoder = !encoderPath.empty();
94
+ bool hasWhisperDecoder = !decoderPath.empty();
95
+ bool hasWhisper = hasWhisperEncoder && hasWhisperDecoder && joinerPath.empty();
96
+
97
+ bool hasFunAsrEncoderAdaptor = !funasrEncoderAdaptor.empty();
98
+ bool hasFunAsrLLM = !funasrLLM.empty();
99
+ bool hasFunAsrEmbedding = !funasrEmbedding.empty();
100
+ bool hasFunAsrTokenizer = !funasrTokenizerDir.empty() && FileExists(funasrTokenizerDir + "/vocab.json");
101
+ bool hasFunAsrNano = hasFunAsrEncoderAdaptor && hasFunAsrLLM && hasFunAsrEmbedding && hasFunAsrTokenizer;
102
+
103
+ std::string modelDirLower = ToLower(modelDir);
104
+ bool isLikelyNemo = modelDirLower.find("nemo") != std::string::npos ||
105
+ modelDirLower.find("parakeet") != std::string::npos;
106
+ bool isLikelyTdt = modelDirLower.find("tdt") != std::string::npos;
107
+ bool isLikelyWenetCtc = modelDirLower.find("wenet") != std::string::npos;
108
+ bool isLikelySenseVoice = modelDirLower.find("sense") != std::string::npos ||
109
+ modelDirLower.find("sensevoice") != std::string::npos;
110
+ bool isLikelyFunAsrNano = modelDirLower.find("funasr") != std::string::npos ||
111
+ modelDirLower.find("funasr-nano") != std::string::npos;
112
+ bool isLikelyZipformer = modelDirLower.find("zipformer") != std::string::npos;
113
+ bool isLikelyMoonshine = modelDirLower.find("moonshine") != std::string::npos;
114
+ bool isLikelyDolphin = modelDirLower.find("dolphin") != std::string::npos;
115
+ bool isLikelyFireRedAsr = modelDirLower.find("fire_red") != std::string::npos ||
116
+ modelDirLower.find("fire-red") != std::string::npos;
117
+ bool isLikelyCanary = modelDirLower.find("canary") != std::string::npos;
118
+ bool isLikelyOmnilingual = modelDirLower.find("omnilingual") != std::string::npos;
119
+ bool isLikelyMedAsr = modelDirLower.find("medasr") != std::string::npos;
120
+ bool isLikelyTeleSpeech = modelDirLower.find("telespeech") != std::string::npos;
121
+
122
+ bool hasMoonshine = !moonshinePreprocess.empty() && !moonshineUncachedDecode.empty() &&
123
+ !moonshineCachedDecode.empty() && !moonshineEncode.empty();
124
+ bool hasDolphin = isLikelyDolphin && !ctcModelPath.empty();
125
+ bool hasFireRedAsr = hasTransducer && isLikelyFireRedAsr;
126
+ bool hasCanary = hasWhisperEncoder && hasWhisperDecoder && joinerPath.empty() && isLikelyCanary;
127
+ bool hasOmnilingual = !ctcModelPath.empty() && isLikelyOmnilingual;
128
+ bool hasMedAsr = !ctcModelPath.empty() && isLikelyMedAsr;
129
+ bool hasTeleSpeechCtc = (!ctcModelPath.empty() || !paraformerModelPath.empty()) && isLikelyTeleSpeech;
130
+
131
+ if (hasTransducer) {
132
+ if (isLikelyNemo || isLikelyTdt) {
133
+ result.detectedModels.push_back({"nemo_transducer", modelDir});
134
+ } else {
135
+ result.detectedModels.push_back({isLikelyZipformer ? "zipformer" : "transducer", modelDir});
136
+ }
137
+ }
138
+
139
+ if (!ctcModelPath.empty() && (isLikelyNemo || isLikelyWenetCtc || isLikelySenseVoice)) {
140
+ if (isLikelyNemo) {
141
+ result.detectedModels.push_back({"nemo_ctc", modelDir});
142
+ } else if (isLikelyWenetCtc) {
143
+ result.detectedModels.push_back({"wenet_ctc", modelDir});
144
+ } else if (isLikelySenseVoice) {
145
+ result.detectedModels.push_back({"sense_voice", modelDir});
146
+ } else {
147
+ result.detectedModels.push_back({"ctc", modelDir});
148
+ }
149
+ } else if (!paraformerModelPath.empty()) {
150
+ result.detectedModels.push_back({"paraformer", modelDir});
151
+ }
152
+
153
+ if (hasWhisper) {
154
+ result.detectedModels.push_back({"whisper", modelDir});
155
+ }
156
+
157
+ if (hasFunAsrNano) {
158
+ result.detectedModels.push_back({"funasr_nano", modelDir});
159
+ }
160
+ if (hasMoonshine) {
161
+ result.detectedModels.push_back({"moonshine", modelDir});
162
+ }
163
+ if (hasDolphin) {
164
+ result.detectedModels.push_back({"dolphin", modelDir});
165
+ }
166
+ if (hasFireRedAsr) {
167
+ result.detectedModels.push_back({"fire_red_asr", modelDir});
168
+ }
169
+ if (hasCanary) {
170
+ result.detectedModels.push_back({"canary", modelDir});
171
+ }
172
+ if (hasOmnilingual) {
173
+ result.detectedModels.push_back({"omnilingual", modelDir});
174
+ }
175
+ if (hasMedAsr) {
176
+ result.detectedModels.push_back({"medasr", modelDir});
177
+ }
178
+ if (hasTeleSpeechCtc) {
179
+ result.detectedModels.push_back({"telespeech_ctc", modelDir});
180
+ }
181
+
182
+ SttModelKind selected = SttModelKind::kUnknown;
183
+
184
+ if (modelType.has_value() && modelType.value() != "auto") {
185
+ selected = ParseSttModelType(modelType.value());
186
+ if (selected == SttModelKind::kUnknown) {
187
+ result.error = "Unknown model type: " + modelType.value();
188
+ return result;
189
+ }
190
+
191
+ if (selected == SttModelKind::kTransducer && !hasTransducer) {
192
+ result.error = "Transducer model requested but files not found in " + modelDir;
193
+ return result;
194
+ }
195
+ if (selected == SttModelKind::kNemoTransducer && !hasTransducer) {
196
+ result.error = "NeMo Transducer model requested but encoder/decoder/joiner not found in " + modelDir;
197
+ return result;
198
+ }
199
+ if (selected == SttModelKind::kParaformer && paraformerModelPath.empty()) {
200
+ result.error = "Paraformer model requested but model.onnx not found in " + modelDir;
201
+ return result;
202
+ }
203
+ if ((selected == SttModelKind::kNemoCtc || selected == SttModelKind::kWenetCtc ||
204
+ selected == SttModelKind::kSenseVoice || selected == SttModelKind::kZipformerCtc) &&
205
+ ctcModelPath.empty()) {
206
+ result.error = "CTC model requested but model.onnx not found in " + modelDir;
207
+ return result;
208
+ }
209
+ if (selected == SttModelKind::kWhisper && !hasWhisper) {
210
+ result.error = "Whisper model requested but encoder/decoder not found in " + modelDir;
211
+ return result;
212
+ }
213
+ if (selected == SttModelKind::kFunAsrNano && !hasFunAsrNano) {
214
+ result.error = "FunASR Nano model requested but required files not found in " + modelDir;
215
+ return result;
216
+ }
217
+ if (selected == SttModelKind::kMoonshine && !hasMoonshine) {
218
+ result.error = "Moonshine model requested but preprocess/encode/uncached_decode/cached_decode not found in " + modelDir;
219
+ return result;
220
+ }
221
+ if (selected == SttModelKind::kDolphin && !hasDolphin) {
222
+ result.error = "Dolphin model requested but model not found in " + modelDir;
223
+ return result;
224
+ }
225
+ if (selected == SttModelKind::kFireRedAsr && !hasFireRedAsr) {
226
+ result.error = "FireRed ASR model requested but encoder/decoder not found in " + modelDir;
227
+ return result;
228
+ }
229
+ if (selected == SttModelKind::kCanary && !hasCanary) {
230
+ result.error = "Canary model requested but encoder/decoder not found in " + modelDir;
231
+ return result;
232
+ }
233
+ if (selected == SttModelKind::kOmnilingual && !hasOmnilingual) {
234
+ result.error = "Omnilingual model requested but model not found in " + modelDir;
235
+ return result;
236
+ }
237
+ if (selected == SttModelKind::kMedAsr && !hasMedAsr) {
238
+ result.error = "MedASR model requested but model not found in " + modelDir;
239
+ return result;
240
+ }
241
+ if (selected == SttModelKind::kTeleSpeechCtc && !hasTeleSpeechCtc) {
242
+ result.error = "TeleSpeech CTC model requested but model not found in " + modelDir;
243
+ return result;
244
+ }
245
+ } else {
246
+ if (hasTransducer) {
247
+ selected = (isLikelyNemo || isLikelyTdt) ? SttModelKind::kNemoTransducer : SttModelKind::kTransducer;
248
+ } else if (!ctcModelPath.empty() && (isLikelyNemo || isLikelyWenetCtc || isLikelySenseVoice)) {
249
+ if (isLikelyNemo) {
250
+ selected = SttModelKind::kNemoCtc;
251
+ } else if (isLikelyWenetCtc) {
252
+ selected = SttModelKind::kWenetCtc;
253
+ } else {
254
+ selected = SttModelKind::kSenseVoice;
255
+ }
256
+ } else if (hasFunAsrNano && isLikelyFunAsrNano) {
257
+ selected = SttModelKind::kFunAsrNano;
258
+ } else if (!paraformerModelPath.empty()) {
259
+ selected = SttModelKind::kParaformer;
260
+ } else if (hasCanary) {
261
+ selected = SttModelKind::kCanary;
262
+ } else if (hasFireRedAsr) {
263
+ selected = SttModelKind::kFireRedAsr;
264
+ } else if (hasWhisper) {
265
+ selected = SttModelKind::kWhisper;
266
+ } else if (hasFunAsrNano) {
267
+ selected = SttModelKind::kFunAsrNano;
268
+ } else if (hasMoonshine && isLikelyMoonshine) {
269
+ selected = SttModelKind::kMoonshine;
270
+ } else if (hasDolphin) {
271
+ selected = SttModelKind::kDolphin;
272
+ } else if (hasFireRedAsr) {
273
+ selected = SttModelKind::kFireRedAsr;
274
+ } else if (hasCanary) {
275
+ selected = SttModelKind::kCanary;
276
+ } else if (hasOmnilingual) {
277
+ selected = SttModelKind::kOmnilingual;
278
+ } else if (hasMedAsr) {
279
+ selected = SttModelKind::kMedAsr;
280
+ } else if (hasTeleSpeechCtc) {
281
+ selected = SttModelKind::kTeleSpeechCtc;
282
+ } else if (!ctcModelPath.empty()) {
283
+ selected = SttModelKind::kZipformerCtc;
284
+ }
285
+ }
286
+
287
+ if (selected == SttModelKind::kUnknown) {
288
+ result.error = "No compatible model type detected in " + modelDir;
289
+ return result;
290
+ }
291
+
292
+ result.selectedKind = selected;
293
+ result.tokensRequired = (selected != SttModelKind::kFunAsrNano);
294
+
295
+ if (selected == SttModelKind::kTransducer || selected == SttModelKind::kNemoTransducer) {
296
+ result.paths.encoder = encoderPath;
297
+ result.paths.decoder = decoderPath;
298
+ result.paths.joiner = joinerPath;
299
+ } else if (selected == SttModelKind::kParaformer) {
300
+ result.paths.paraformerModel = paraformerModelPath;
301
+ } else if (selected == SttModelKind::kNemoCtc || selected == SttModelKind::kWenetCtc ||
302
+ selected == SttModelKind::kSenseVoice || selected == SttModelKind::kZipformerCtc) {
303
+ result.paths.ctcModel = ctcModelPath;
304
+ } else if (selected == SttModelKind::kWhisper) {
305
+ result.paths.whisperEncoder = encoderPath;
306
+ result.paths.whisperDecoder = decoderPath;
307
+ } else if (selected == SttModelKind::kFunAsrNano) {
308
+ result.paths.funasrEncoderAdaptor = funasrEncoderAdaptor;
309
+ result.paths.funasrLLM = funasrLLM;
310
+ result.paths.funasrEmbedding = funasrEmbedding;
311
+ result.paths.funasrTokenizer = funasrTokenizerDir;
312
+ } else if (selected == SttModelKind::kMoonshine) {
313
+ result.paths.moonshinePreprocessor = moonshinePreprocess;
314
+ result.paths.moonshineEncoder = moonshineEncode;
315
+ result.paths.moonshineUncachedDecoder = moonshineUncachedDecode;
316
+ result.paths.moonshineCachedDecoder = moonshineCachedDecode;
317
+ } else if (selected == SttModelKind::kDolphin) {
318
+ result.paths.dolphinModel = ctcModelPath.empty() ? paraformerModelPath : ctcModelPath;
319
+ } else if (selected == SttModelKind::kFireRedAsr) {
320
+ result.paths.fireRedEncoder = encoderPath;
321
+ result.paths.fireRedDecoder = decoderPath;
322
+ } else if (selected == SttModelKind::kCanary) {
323
+ result.paths.canaryEncoder = encoderPath;
324
+ result.paths.canaryDecoder = decoderPath;
325
+ } else if (selected == SttModelKind::kOmnilingual) {
326
+ result.paths.omnilingualModel = ctcModelPath;
327
+ } else if (selected == SttModelKind::kMedAsr) {
328
+ result.paths.medasrModel = ctcModelPath;
329
+ } else if (selected == SttModelKind::kTeleSpeechCtc) {
330
+ result.paths.telespeechCtcModel = ctcModelPath.empty() ? paraformerModelPath : ctcModelPath;
331
+ }
332
+
333
+ if (!tokensPath.empty() && FileExists(tokensPath)) {
334
+ result.paths.tokens = tokensPath;
335
+ } else if (result.tokensRequired) {
336
+ result.error = "Tokens file not found in " + modelDir;
337
+ return result;
338
+ }
339
+
340
+ result.ok = true;
341
+ return result;
342
+ }
343
+
344
+ } // namespace sherpaonnx