react-native-sherpa-onnx 0.2.0 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +232 -236
- package/SherpaOnnx.podspec +68 -64
- package/android/build.gradle +182 -192
- package/android/codegen.gradle +57 -0
- package/android/prebuilt-download.gradle +428 -0
- package/android/prebuilt-versions.gradle +43 -0
- package/android/proguard-rules.pro +10 -0
- package/android/src/main/assets/testModels/add_mul_add.onnx +28 -0
- package/android/src/main/assets/testModels/nnapi_internal_uint8_support.onnx +0 -0
- package/android/src/main/assets/testModels/qnn_multi_ctx_embed.onnx +0 -0
- package/android/src/main/cpp/CMakeLists.txt +166 -129
- package/android/src/main/cpp/CMakePresets.json +54 -0
- package/android/src/main/cpp/crypto/sha256.cpp +174 -0
- package/android/src/main/cpp/crypto/sha256.h +16 -0
- package/android/src/main/cpp/jni/archive/sherpa-onnx-archive-helper.cpp +404 -0
- package/android/src/main/cpp/jni/archive/sherpa-onnx-archive-helper.h +56 -0
- package/android/src/main/cpp/jni/archive/sherpa-onnx-archive-jni.cpp +181 -0
- package/android/src/main/cpp/jni/audio/sherpa-onnx-audio-convert-jni.cpp +888 -0
- package/{ios → android/src/main/cpp/jni/model_detect}/sherpa-onnx-common.h +18 -18
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-detect-jni-common.cpp +86 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-detect-jni-common.h +20 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-helper.cpp +423 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-helper.h +55 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-stt.cpp +399 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-tts.cpp +238 -0
- package/{ios → android/src/main/cpp/jni/model_detect}/sherpa-onnx-model-detect.h +122 -89
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-stt-wrapper.cpp +99 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-stt-wrapper.h +16 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-tts-wrapper.cpp +78 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-tts-wrapper.h +16 -0
- package/android/src/main/cpp/jni/module/sherpa-onnx-module-jni.cpp +190 -0
- package/android/src/main/cpp/jni/tts/sherpa-onnx-tts-zipvoice-jni.cpp +301 -0
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxArchiveHelper.kt +94 -0
- package/android/src/main/java/com/sherpaonnx/{SherpaOnnxCoreHelper.kt → SherpaOnnxAssetHelper.kt} +350 -236
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxModule.kt +791 -483
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxSttHelper.kt +699 -109
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxTtsHelper.kt +1123 -668
- package/android/src/main/java/com/sherpaonnx/ZipvoiceTtsWrapper.kt +187 -0
- package/ios/SherpaOnnx+Assets.h +11 -0
- package/ios/SherpaOnnx+Assets.mm +325 -0
- package/ios/SherpaOnnx+STT.mm +455 -118
- package/ios/SherpaOnnx+TTS.mm +1101 -712
- package/ios/SherpaOnnx.h +17 -6
- package/ios/SherpaOnnx.mm +206 -311
- package/ios/SherpaOnnx.xcconfig +19 -19
- package/ios/SherpaOnnxCoreMLHelper.swift +24 -0
- package/ios/archive/sherpa-onnx-archive-helper.h +21 -0
- package/ios/archive/sherpa-onnx-archive-helper.mm +296 -0
- package/ios/libarchive_darwin_config.h +153 -0
- package/{android/src/main/cpp/jni → ios/model_detect}/sherpa-onnx-common.h +18 -18
- package/ios/model_detect/sherpa-onnx-model-detect-helper.h +49 -0
- package/ios/model_detect/sherpa-onnx-model-detect-helper.mm +210 -0
- package/ios/model_detect/sherpa-onnx-model-detect-stt.mm +344 -0
- package/ios/model_detect/sherpa-onnx-model-detect-tts.mm +201 -0
- package/{android/src/main/cpp/jni → ios/model_detect}/sherpa-onnx-model-detect.h +117 -89
- package/ios/scripts/patch-libarchive-includes.sh +61 -0
- package/ios/scripts/setup-ios-libarchive.sh +98 -0
- package/ios/stt/sherpa-onnx-stt-wrapper.h +129 -0
- package/ios/stt/sherpa-onnx-stt-wrapper.mm +523 -0
- package/ios/{sherpa-onnx-tts-wrapper.h → tts/sherpa-onnx-tts-wrapper.h} +90 -85
- package/ios/{sherpa-onnx-tts-wrapper.mm → tts/sherpa-onnx-tts-wrapper.mm} +376 -345
- package/lib/module/NativeSherpaOnnx.js +3 -0
- package/lib/module/NativeSherpaOnnx.js.map +1 -1
- package/lib/module/audio/index.js +22 -0
- package/lib/module/audio/index.js.map +1 -0
- package/lib/module/diarization/index.js +1 -1
- package/lib/module/diarization/index.js.map +1 -1
- package/lib/module/download/ModelDownloadManager.js +918 -0
- package/lib/module/download/ModelDownloadManager.js.map +1 -0
- package/lib/module/download/extractTarBz2.js +53 -0
- package/lib/module/download/extractTarBz2.js.map +1 -0
- package/lib/module/download/index.js +6 -0
- package/lib/module/download/index.js.map +1 -0
- package/lib/module/download/validation.js +178 -0
- package/lib/module/download/validation.js.map +1 -0
- package/lib/module/enhancement/index.js +1 -1
- package/lib/module/enhancement/index.js.map +1 -1
- package/lib/module/index.js +41 -3
- package/lib/module/index.js.map +1 -1
- package/lib/module/separation/index.js +1 -1
- package/lib/module/separation/index.js.map +1 -1
- package/lib/module/stt/index.js +127 -60
- package/lib/module/stt/index.js.map +1 -1
- package/lib/module/stt/sttModelLanguages.js +512 -0
- package/lib/module/stt/sttModelLanguages.js.map +1 -0
- package/lib/module/stt/types.js +53 -1
- package/lib/module/stt/types.js.map +1 -1
- package/lib/module/tts/index.js +216 -289
- package/lib/module/tts/index.js.map +1 -1
- package/lib/module/tts/types.js +86 -1
- package/lib/module/tts/types.js.map +1 -1
- package/lib/module/types.js.map +1 -1
- package/lib/module/utils.js +86 -73
- package/lib/module/utils.js.map +1 -1
- package/lib/module/vad/index.js +1 -1
- package/lib/module/vad/index.js.map +1 -1
- package/lib/typescript/src/NativeSherpaOnnx.d.ts +192 -38
- package/lib/typescript/src/NativeSherpaOnnx.d.ts.map +1 -1
- package/lib/typescript/src/audio/index.d.ts +13 -0
- package/lib/typescript/src/audio/index.d.ts.map +1 -0
- package/lib/typescript/src/diarization/index.d.ts +3 -2
- package/lib/typescript/src/diarization/index.d.ts.map +1 -1
- package/lib/typescript/src/download/ModelDownloadManager.d.ts +108 -0
- package/lib/typescript/src/download/ModelDownloadManager.d.ts.map +1 -0
- package/lib/typescript/src/download/extractTarBz2.d.ts +14 -0
- package/lib/typescript/src/download/extractTarBz2.d.ts.map +1 -0
- package/lib/typescript/src/download/index.d.ts +7 -0
- package/lib/typescript/src/download/index.d.ts.map +1 -0
- package/lib/typescript/src/download/validation.d.ts +57 -0
- package/lib/typescript/src/download/validation.d.ts.map +1 -0
- package/lib/typescript/src/enhancement/index.d.ts +3 -2
- package/lib/typescript/src/enhancement/index.d.ts.map +1 -1
- package/lib/typescript/src/index.d.ts +26 -2
- package/lib/typescript/src/index.d.ts.map +1 -1
- package/lib/typescript/src/separation/index.d.ts +3 -2
- package/lib/typescript/src/separation/index.d.ts.map +1 -1
- package/lib/typescript/src/stt/index.d.ts +31 -43
- package/lib/typescript/src/stt/index.d.ts.map +1 -1
- package/lib/typescript/src/stt/sttModelLanguages.d.ts +52 -0
- package/lib/typescript/src/stt/sttModelLanguages.d.ts.map +1 -0
- package/lib/typescript/src/stt/types.d.ts +196 -9
- package/lib/typescript/src/stt/types.d.ts.map +1 -1
- package/lib/typescript/src/tts/index.d.ts +25 -211
- package/lib/typescript/src/tts/index.d.ts.map +1 -1
- package/lib/typescript/src/tts/types.d.ts +148 -25
- package/lib/typescript/src/tts/types.d.ts.map +1 -1
- package/lib/typescript/src/types.d.ts +0 -32
- package/lib/typescript/src/types.d.ts.map +1 -1
- package/lib/typescript/src/utils.d.ts +28 -13
- package/lib/typescript/src/utils.d.ts.map +1 -1
- package/lib/typescript/src/vad/index.d.ts +3 -2
- package/lib/typescript/src/vad/index.d.ts.map +1 -1
- package/package.json +250 -222
- package/scripts/check-qnn-support.sh +78 -0
- package/scripts/setup-ios-framework.sh +379 -282
- package/src/NativeSherpaOnnx.ts +474 -251
- package/src/audio/index.ts +32 -0
- package/src/diarization/index.ts +4 -2
- package/src/download/ModelDownloadManager.ts +1325 -0
- package/src/download/extractTarBz2.ts +78 -0
- package/src/download/index.ts +43 -0
- package/src/download/validation.ts +279 -0
- package/src/enhancement/index.ts +4 -2
- package/src/index.tsx +78 -27
- package/src/separation/index.ts +4 -2
- package/src/stt/index.ts +249 -89
- package/src/stt/sttModelLanguages.ts +237 -0
- package/src/stt/types.ts +263 -9
- package/src/tts/index.ts +470 -458
- package/src/tts/types.ts +373 -218
- package/src/types.ts +0 -44
- package/src/utils.ts +145 -131
- package/src/vad/index.ts +4 -2
- package/third_party/ffmpeg_prebuilt/ANDROID_RELEASE_TAG +1 -0
- package/third_party/libarchive_prebuilt/ANDROID_RELEASE_TAG +1 -0
- package/third_party/libarchive_prebuilt/IOS_RELEASE_TAG +1 -0
- package/third_party/sherpa-onnx-prebuilt/ANDROID_RELEASE_TAG +1 -0
- package/third_party/sherpa-onnx-prebuilt/IOS_RELEASE_TAG +1 -0
- package/android/src/main/cpp/include/sherpa-onnx/c-api/c-api.h +0 -1918
- package/android/src/main/cpp/include/sherpa-onnx/c-api/cxx-api.h +0 -841
- package/android/src/main/cpp/jni/sherpa-onnx-model-detect.cpp +0 -541
- package/android/src/main/cpp/jni/sherpa-onnx-stt-jni.cpp +0 -336
- package/android/src/main/cpp/jni/sherpa-onnx-stt-wrapper.cpp +0 -222
- package/android/src/main/cpp/jni/sherpa-onnx-stt-wrapper.h +0 -68
- package/android/src/main/cpp/jni/sherpa-onnx-tts-jni.cpp +0 -823
- package/android/src/main/cpp/jni/sherpa-onnx-tts-wrapper.cpp +0 -387
- package/android/src/main/cpp/jni/sherpa-onnx-tts-wrapper.h +0 -147
- package/ios/Frameworks/sherpa_onnx.xcframework.zip +0 -0
- package/ios/include/sherpa-onnx/c-api/c-api.h +0 -1918
- package/ios/include/sherpa-onnx/c-api/cxx-api.h +0 -841
- package/ios/sherpa-onnx-model-detect.mm +0 -441
- package/ios/sherpa-onnx-stt-wrapper.h +0 -48
- package/ios/sherpa-onnx-stt-wrapper.mm +0 -201
- package/scripts/copy-headers.js +0 -184
- package/scripts/setup-assets.js +0 -323
|
@@ -0,0 +1,210 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* sherpa-onnx-model-detect-helper.mm
|
|
3
|
+
*
|
|
4
|
+
* Purpose: Shared filesystem and string helpers for model detection (file/dir listing, token-based
|
|
5
|
+
* ONNX search, path resolution). Used by sherpa-onnx-model-detect-stt.mm and -tts.mm on iOS.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
#include "sherpa-onnx-model-detect-helper.h"
|
|
9
|
+
|
|
10
|
+
#include <algorithm>
|
|
11
|
+
#include <cctype>
|
|
12
|
+
#include <cstdint>
|
|
13
|
+
#include <filesystem>
|
|
14
|
+
#include <string>
|
|
15
|
+
#include <vector>
|
|
16
|
+
|
|
17
|
+
namespace fs = std::filesystem;
|
|
18
|
+
|
|
19
|
+
namespace sherpaonnx {
|
|
20
|
+
namespace model_detect {
|
|
21
|
+
|
|
22
|
+
namespace {
|
|
23
|
+
|
|
24
|
+
/**
 * Reports whether `value` terminates with `suffix`.
 *
 * The comparison is byte-wise and case-sensitive. An empty `suffix` matches
 * every `value`; a `suffix` longer than `value` never matches.
 */
bool EndsWith(const std::string& value, const std::string& suffix) {
    const std::string::size_type suffixLen = suffix.size();
    const std::string::size_type valueLen = value.size();
    if (valueLen < suffixLen) {
        return false;
    }
    // Compare only the tail of `value` that lines up with `suffix`.
    return value.compare(valueLen - suffixLen, suffixLen, suffix) == 0;
}
|
|
28
|
+
|
|
29
|
+
/**
 * Reports whether `token` occurs anywhere inside `value`.
 *
 * The search is a plain case-sensitive substring search; an empty `token`
 * is found in every `value` (including an empty one).
 */
bool ContainsToken(const std::string& value, const std::string& token) {
    const std::string::size_type position = value.find(token);
    const bool found = (position != std::string::npos);
    return found;
}
|
|
32
|
+
|
|
33
|
+
std::string ChooseLargest(const std::vector<FileEntry>& files,
|
|
34
|
+
const std::vector<std::string>& excludeTokens, bool onlyInt8, bool onlyNonInt8) {
|
|
35
|
+
std::string chosen;
|
|
36
|
+
std::uint64_t bestSize = 0;
|
|
37
|
+
for (const auto& entry : files) {
|
|
38
|
+
if (!EndsWith(entry.nameLower, ".onnx")) continue;
|
|
39
|
+
bool hasExcluded = false;
|
|
40
|
+
for (const auto& token : excludeTokens) {
|
|
41
|
+
if (ContainsToken(entry.nameLower, token)) { hasExcluded = true; break; }
|
|
42
|
+
}
|
|
43
|
+
if (hasExcluded) continue;
|
|
44
|
+
bool isInt8 = ContainsToken(entry.nameLower, "int8");
|
|
45
|
+
if (onlyInt8 && !isInt8) continue;
|
|
46
|
+
if (onlyNonInt8 && isInt8) continue;
|
|
47
|
+
if (entry.size >= bestSize) {
|
|
48
|
+
bestSize = entry.size;
|
|
49
|
+
chosen = entry.path;
|
|
50
|
+
}
|
|
51
|
+
}
|
|
52
|
+
return chosen;
|
|
53
|
+
}
|
|
54
|
+
|
|
55
|
+
} // namespace
|
|
56
|
+
|
|
57
|
+
/**
 * Reports whether `path` names an existing filesystem entry of any type.
 *
 * @param path Path to probe.
 * @return true when the entry exists; false when it does not exist or when its
 *         status cannot be determined.
 *
 * The throwing overload of fs::exists raises fs::filesystem_error for failures
 * other than "not found" (for example an over-long name or a permission
 * error). Such failures are reported as `false` here, matching the
 * swallow-and-continue error handling used by the listing helpers in this file.
 */
bool FileExists(const std::string& path) {
    try {
        return fs::exists(path);
    } catch (const std::exception&) {
        // Status could not be queried: treat the path as absent.
        return false;
    }
}
|
|
60
|
+
|
|
61
|
+
/**
 * Reports whether `path` names an existing directory.
 *
 * @param path Path to probe.
 * @return true when the entry is a directory; false when it is not, when it
 *         does not exist, or when its status cannot be determined.
 *
 * The throwing overload of fs::is_directory raises fs::filesystem_error for
 * failures other than "not found" (for example an over-long name or a
 * permission error). Such failures are reported as `false` here, matching the
 * swallow-and-continue error handling used by the listing helpers in this file.
 */
bool IsDirectory(const std::string& path) {
    try {
        return fs::is_directory(path);
    } catch (const std::exception&) {
        // Status could not be queried: treat the path as "not a directory".
        return false;
    }
}
|
|
64
|
+
|
|
65
|
+
/**
 * Returns a copy of `value` with every byte passed through std::tolower.
 *
 * Each char is widened via `unsigned char` before the call so that bytes with
 * the high bit set are never handed to std::tolower as negative values.
 */
std::string ToLower(std::string value) {
    for (char& ch : value) {
        const unsigned char raw = static_cast<unsigned char>(ch);
        ch = static_cast<char>(std::tolower(raw));
    }
    return value;
}
|
|
71
|
+
|
|
72
|
+
/**
 * Lists the immediate sub-directories of `path` (non-recursive).
 *
 * @param path Directory to scan.
 * @return Paths of the entries for which `is_directory()` is true, in
 *         iteration order. Any exception raised while opening or walking the
 *         directory ends the scan; whatever was collected so far is returned
 *         (an empty vector if the directory could not be opened).
 */
std::vector<std::string> ListDirectories(const std::string& path) {
    std::vector<std::string> subdirs;
    try {
        const fs::directory_iterator end;
        for (fs::directory_iterator it(path); it != end; ++it) {
            if (!it->is_directory()) {
                continue;
            }
            subdirs.push_back(it->path().string());
        }
    } catch (const std::exception&) {
        // Best effort: keep what was gathered before the failure.
    }
    return subdirs;
}
|
|
81
|
+
|
|
82
|
+
std::vector<FileEntry> ListFiles(const std::string& dir) {
|
|
83
|
+
std::vector<FileEntry> results;
|
|
84
|
+
try {
|
|
85
|
+
for (const auto& entry : fs::directory_iterator(dir)) {
|
|
86
|
+
if (!entry.is_regular_file()) continue;
|
|
87
|
+
FileEntry file;
|
|
88
|
+
file.path = entry.path().string();
|
|
89
|
+
std::string name = entry.path().filename().string();
|
|
90
|
+
file.nameLower = ToLower(name);
|
|
91
|
+
file.size = static_cast<std::uint64_t>(entry.file_size());
|
|
92
|
+
results.push_back(file);
|
|
93
|
+
}
|
|
94
|
+
} catch (const std::exception&) {}
|
|
95
|
+
return results;
|
|
96
|
+
}
|
|
97
|
+
|
|
98
|
+
std::vector<FileEntry> ListFilesRecursive(const std::string& path, int maxDepth) {
|
|
99
|
+
std::vector<FileEntry> results = ListFiles(path);
|
|
100
|
+
if (maxDepth <= 0) return results;
|
|
101
|
+
for (const auto& dir : ListDirectories(path)) {
|
|
102
|
+
auto nested = ListFilesRecursive(dir, maxDepth - 1);
|
|
103
|
+
results.insert(results.end(), nested.begin(), nested.end());
|
|
104
|
+
}
|
|
105
|
+
return results;
|
|
106
|
+
}
|
|
107
|
+
|
|
108
|
+
std::string FindLargestOnnxExcludingTokens(const std::vector<FileEntry>& files,
|
|
109
|
+
const std::vector<std::string>& excludeTokens) {
|
|
110
|
+
return ChooseLargest(files, excludeTokens, false, false);
|
|
111
|
+
}
|
|
112
|
+
|
|
113
|
+
std::string FindOnnxByToken(const std::vector<FileEntry>& files,
|
|
114
|
+
const std::string& token, const std::optional<bool>& preferInt8) {
|
|
115
|
+
std::string tokenLower = ToLower(token);
|
|
116
|
+
std::vector<FileEntry> matches;
|
|
117
|
+
for (const auto& entry : files) {
|
|
118
|
+
if (!EndsWith(entry.nameLower, ".onnx")) continue;
|
|
119
|
+
if (ContainsToken(entry.nameLower, tokenLower)) matches.push_back(entry);
|
|
120
|
+
}
|
|
121
|
+
if (matches.empty()) return "";
|
|
122
|
+
std::vector<std::string> emptyTokens;
|
|
123
|
+
bool wantInt8 = preferInt8.has_value() && preferInt8.value();
|
|
124
|
+
bool wantNonInt8 = preferInt8.has_value() && !preferInt8.value();
|
|
125
|
+
std::string preferred = ChooseLargest(matches, emptyTokens, wantInt8, wantNonInt8);
|
|
126
|
+
if (!preferred.empty()) return preferred;
|
|
127
|
+
return ChooseLargest(matches, emptyTokens, false, false);
|
|
128
|
+
}
|
|
129
|
+
|
|
130
|
+
std::string FindOnnxByAnyToken(const std::vector<FileEntry>& files,
|
|
131
|
+
const std::vector<std::string>& tokens, const std::optional<bool>& preferInt8) {
|
|
132
|
+
for (const auto& token : tokens) {
|
|
133
|
+
std::string match = FindOnnxByToken(files, token, preferInt8);
|
|
134
|
+
if (!match.empty()) return match;
|
|
135
|
+
}
|
|
136
|
+
return "";
|
|
137
|
+
}
|
|
138
|
+
|
|
139
|
+
std::string FindFileEndingWith(const std::vector<FileEntry>& files, const std::string& suffix) {
|
|
140
|
+
std::string targetSuffix = ToLower(suffix);
|
|
141
|
+
for (const auto& entry : files) {
|
|
142
|
+
if (entry.nameLower == targetSuffix) return entry.path;
|
|
143
|
+
}
|
|
144
|
+
for (const auto& entry : files) {
|
|
145
|
+
if (EndsWith(entry.nameLower, targetSuffix)) return entry.path;
|
|
146
|
+
}
|
|
147
|
+
return "";
|
|
148
|
+
}
|
|
149
|
+
|
|
150
|
+
std::string FindFileByName(const std::string& baseDir, const std::string& fileName, int maxDepth) {
|
|
151
|
+
std::string target = ToLower(fileName);
|
|
152
|
+
auto files = ListFilesRecursive(baseDir, maxDepth);
|
|
153
|
+
for (const auto& entry : files) {
|
|
154
|
+
if (entry.nameLower == target) return entry.path;
|
|
155
|
+
}
|
|
156
|
+
return "";
|
|
157
|
+
}
|
|
158
|
+
|
|
159
|
+
std::string FindDirectoryByName(const std::string& baseDir, const std::string& dirName, int maxDepth) {
|
|
160
|
+
std::string target = ToLower(dirName);
|
|
161
|
+
std::vector<std::string> toVisit = ListDirectories(baseDir);
|
|
162
|
+
int depth = 0;
|
|
163
|
+
while (!toVisit.empty() && depth <= maxDepth) {
|
|
164
|
+
std::vector<std::string> next;
|
|
165
|
+
for (const auto& dir : toVisit) {
|
|
166
|
+
std::string name = fs::path(dir).filename().string();
|
|
167
|
+
if (ToLower(name) == target) return dir;
|
|
168
|
+
if (depth < maxDepth) {
|
|
169
|
+
auto nested = ListDirectories(dir);
|
|
170
|
+
next.insert(next.end(), nested.begin(), nested.end());
|
|
171
|
+
}
|
|
172
|
+
}
|
|
173
|
+
toVisit.swap(next);
|
|
174
|
+
depth += 1;
|
|
175
|
+
}
|
|
176
|
+
return "";
|
|
177
|
+
}
|
|
178
|
+
|
|
179
|
+
std::string ResolveTokenizerDir(const std::string& modelDir) {
|
|
180
|
+
std::string vocabInMain = modelDir + "/vocab.json";
|
|
181
|
+
if (FileExists(vocabInMain)) {
|
|
182
|
+
return modelDir;
|
|
183
|
+
}
|
|
184
|
+
|
|
185
|
+
try {
|
|
186
|
+
for (const auto& entry : fs::directory_iterator(modelDir)) {
|
|
187
|
+
if (entry.is_directory()) {
|
|
188
|
+
std::string dirName = entry.path().filename().string();
|
|
189
|
+
std::string dirNameLower = ToLower(dirName);
|
|
190
|
+
if (dirNameLower.find("qwen3") != std::string::npos) {
|
|
191
|
+
std::string vocabPath = entry.path().string() + "/vocab.json";
|
|
192
|
+
if (FileExists(vocabPath)) {
|
|
193
|
+
return entry.path().string();
|
|
194
|
+
}
|
|
195
|
+
}
|
|
196
|
+
}
|
|
197
|
+
}
|
|
198
|
+
} catch (const std::exception&) {
|
|
199
|
+
}
|
|
200
|
+
|
|
201
|
+
std::string commonPath = modelDir + "/Qwen3-0.6B";
|
|
202
|
+
if (FileExists(commonPath + "/vocab.json")) {
|
|
203
|
+
return commonPath;
|
|
204
|
+
}
|
|
205
|
+
|
|
206
|
+
return "";
|
|
207
|
+
}
|
|
208
|
+
|
|
209
|
+
} // namespace model_detect
|
|
210
|
+
} // namespace sherpaonnx
|
|
@@ -0,0 +1,344 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* sherpa-onnx-model-detect-stt.mm
|
|
3
|
+
*
|
|
4
|
+
* Purpose: Detects STT (speech-to-text) model type and fills SttModelPaths from a model directory.
|
|
5
|
+
* Supports transducer, paraformer, whisper, and other STT variants. Used by the STT wrapper on iOS.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
#include "sherpa-onnx-model-detect.h"
|
|
9
|
+
#include "sherpa-onnx-model-detect-helper.h"
|
|
10
|
+
|
|
11
|
+
#include <string>
|
|
12
|
+
|
|
13
|
+
namespace sherpaonnx {
|
|
14
|
+
namespace {
|
|
15
|
+
|
|
16
|
+
using namespace model_detect;
|
|
17
|
+
|
|
18
|
+
/**
 * Maps a model-type string to its SttModelKind.
 *
 * Matching is exact and case-sensitive. "zipformer" is accepted as an alias
 * of "transducer", and "ctc" as an alias of "zipformer_ctc".
 *
 * @param modelType Model-type name.
 * @return The matching kind, or SttModelKind::kUnknown for any other string.
 */
SttModelKind ParseSttModelType(const std::string& modelType) {
    struct NamedKind {
        const char* name;
        SttModelKind kind;
    };
    static const NamedKind kKnownTypes[] = {
        {"transducer", SttModelKind::kTransducer},
        {"zipformer", SttModelKind::kTransducer},
        {"nemo_transducer", SttModelKind::kNemoTransducer},
        {"paraformer", SttModelKind::kParaformer},
        {"nemo_ctc", SttModelKind::kNemoCtc},
        {"wenet_ctc", SttModelKind::kWenetCtc},
        {"sense_voice", SttModelKind::kSenseVoice},
        {"zipformer_ctc", SttModelKind::kZipformerCtc},
        {"ctc", SttModelKind::kZipformerCtc},
        {"whisper", SttModelKind::kWhisper},
        {"funasr_nano", SttModelKind::kFunAsrNano},
        {"fire_red_asr", SttModelKind::kFireRedAsr},
        {"moonshine", SttModelKind::kMoonshine},
        {"dolphin", SttModelKind::kDolphin},
        {"canary", SttModelKind::kCanary},
        {"omnilingual", SttModelKind::kOmnilingual},
        {"medasr", SttModelKind::kMedAsr},
        {"telespeech_ctc", SttModelKind::kTeleSpeechCtc},
    };

    for (const NamedKind& known : kKnownTypes) {
        if (modelType == known.name) {
            return known.kind;
        }
    }
    return SttModelKind::kUnknown;
}
|
|
37
|
+
|
|
38
|
+
} // namespace
|
|
39
|
+
|
|
40
|
+
SttDetectResult DetectSttModel(
|
|
41
|
+
const std::string& modelDir,
|
|
42
|
+
const std::optional<bool>& preferInt8,
|
|
43
|
+
const std::optional<std::string>& modelType,
|
|
44
|
+
bool debug /* = false */
|
|
45
|
+
) {
|
|
46
|
+
using namespace model_detect;
|
|
47
|
+
|
|
48
|
+
SttDetectResult result;
|
|
49
|
+
|
|
50
|
+
if (modelDir.empty()) {
|
|
51
|
+
result.error = "Model directory is empty";
|
|
52
|
+
return result;
|
|
53
|
+
}
|
|
54
|
+
|
|
55
|
+
if (!FileExists(modelDir) || !IsDirectory(modelDir)) {
|
|
56
|
+
result.error = "Model directory does not exist or is not a directory: " + modelDir;
|
|
57
|
+
return result;
|
|
58
|
+
}
|
|
59
|
+
|
|
60
|
+
const int kMaxSearchDepth = 4;
|
|
61
|
+
const std::vector<FileEntry> files = ListFilesRecursive(modelDir, kMaxSearchDepth);
|
|
62
|
+
|
|
63
|
+
std::string encoderPath = FindOnnxByAnyToken(files, {"encoder"}, preferInt8);
|
|
64
|
+
std::string decoderPath = FindOnnxByAnyToken(files, {"decoder"}, preferInt8);
|
|
65
|
+
std::string joinerPath = FindOnnxByAnyToken(files, {"joiner"}, preferInt8);
|
|
66
|
+
std::string tokensPath = FindFileEndingWith(files, "tokens.txt");
|
|
67
|
+
|
|
68
|
+
std::vector<std::string> modelExcludes = {
|
|
69
|
+
"encoder", "decoder", "joiner", "vocoder", "acoustic", "embedding", "llm",
|
|
70
|
+
"encoder_adaptor", "encoder-adaptor"
|
|
71
|
+
};
|
|
72
|
+
std::string paraformerModelPath = FindOnnxByAnyToken(files, {"model"}, preferInt8);
|
|
73
|
+
if (paraformerModelPath.empty()) {
|
|
74
|
+
paraformerModelPath = FindLargestOnnxExcludingTokens(files, modelExcludes);
|
|
75
|
+
}
|
|
76
|
+
std::string ctcModelPath = FindOnnxByAnyToken(files, {"model"}, preferInt8);
|
|
77
|
+
if (ctcModelPath.empty()) {
|
|
78
|
+
ctcModelPath = FindLargestOnnxExcludingTokens(files, modelExcludes);
|
|
79
|
+
}
|
|
80
|
+
|
|
81
|
+
std::string funasrEncoderAdaptor = FindOnnxByAnyToken(files, {"encoder_adaptor", "encoder-adaptor"}, preferInt8);
|
|
82
|
+
std::string funasrLLM = FindOnnxByAnyToken(files, {"llm"}, preferInt8);
|
|
83
|
+
std::string funasrEmbedding = FindOnnxByAnyToken(files, {"embedding"}, preferInt8);
|
|
84
|
+
std::string funasrTokenizerDir = ResolveTokenizerDir(modelDir);
|
|
85
|
+
|
|
86
|
+
std::string moonshinePreprocess = FindOnnxByAnyToken(files, {"preprocess", "preprocessor"}, preferInt8);
|
|
87
|
+
std::string moonshineEncode = FindOnnxByAnyToken(files, {"encode"}, preferInt8);
|
|
88
|
+
std::string moonshineUncachedDecode = FindOnnxByAnyToken(files, {"uncached_decode", "uncached"}, preferInt8);
|
|
89
|
+
std::string moonshineCachedDecode = FindOnnxByAnyToken(files, {"cached_decode", "cached"}, preferInt8);
|
|
90
|
+
|
|
91
|
+
bool hasTransducer = !encoderPath.empty() && !decoderPath.empty() && !joinerPath.empty();
|
|
92
|
+
|
|
93
|
+
bool hasWhisperEncoder = !encoderPath.empty();
|
|
94
|
+
bool hasWhisperDecoder = !decoderPath.empty();
|
|
95
|
+
bool hasWhisper = hasWhisperEncoder && hasWhisperDecoder && joinerPath.empty();
|
|
96
|
+
|
|
97
|
+
bool hasFunAsrEncoderAdaptor = !funasrEncoderAdaptor.empty();
|
|
98
|
+
bool hasFunAsrLLM = !funasrLLM.empty();
|
|
99
|
+
bool hasFunAsrEmbedding = !funasrEmbedding.empty();
|
|
100
|
+
bool hasFunAsrTokenizer = !funasrTokenizerDir.empty() && FileExists(funasrTokenizerDir + "/vocab.json");
|
|
101
|
+
bool hasFunAsrNano = hasFunAsrEncoderAdaptor && hasFunAsrLLM && hasFunAsrEmbedding && hasFunAsrTokenizer;
|
|
102
|
+
|
|
103
|
+
std::string modelDirLower = ToLower(modelDir);
|
|
104
|
+
bool isLikelyNemo = modelDirLower.find("nemo") != std::string::npos ||
|
|
105
|
+
modelDirLower.find("parakeet") != std::string::npos;
|
|
106
|
+
bool isLikelyTdt = modelDirLower.find("tdt") != std::string::npos;
|
|
107
|
+
bool isLikelyWenetCtc = modelDirLower.find("wenet") != std::string::npos;
|
|
108
|
+
bool isLikelySenseVoice = modelDirLower.find("sense") != std::string::npos ||
|
|
109
|
+
modelDirLower.find("sensevoice") != std::string::npos;
|
|
110
|
+
bool isLikelyFunAsrNano = modelDirLower.find("funasr") != std::string::npos ||
|
|
111
|
+
modelDirLower.find("funasr-nano") != std::string::npos;
|
|
112
|
+
bool isLikelyZipformer = modelDirLower.find("zipformer") != std::string::npos;
|
|
113
|
+
bool isLikelyMoonshine = modelDirLower.find("moonshine") != std::string::npos;
|
|
114
|
+
bool isLikelyDolphin = modelDirLower.find("dolphin") != std::string::npos;
|
|
115
|
+
bool isLikelyFireRedAsr = modelDirLower.find("fire_red") != std::string::npos ||
|
|
116
|
+
modelDirLower.find("fire-red") != std::string::npos;
|
|
117
|
+
bool isLikelyCanary = modelDirLower.find("canary") != std::string::npos;
|
|
118
|
+
bool isLikelyOmnilingual = modelDirLower.find("omnilingual") != std::string::npos;
|
|
119
|
+
bool isLikelyMedAsr = modelDirLower.find("medasr") != std::string::npos;
|
|
120
|
+
bool isLikelyTeleSpeech = modelDirLower.find("telespeech") != std::string::npos;
|
|
121
|
+
|
|
122
|
+
bool hasMoonshine = !moonshinePreprocess.empty() && !moonshineUncachedDecode.empty() &&
|
|
123
|
+
!moonshineCachedDecode.empty() && !moonshineEncode.empty();
|
|
124
|
+
bool hasDolphin = isLikelyDolphin && !ctcModelPath.empty();
|
|
125
|
+
bool hasFireRedAsr = hasTransducer && isLikelyFireRedAsr;
|
|
126
|
+
bool hasCanary = hasWhisperEncoder && hasWhisperDecoder && joinerPath.empty() && isLikelyCanary;
|
|
127
|
+
bool hasOmnilingual = !ctcModelPath.empty() && isLikelyOmnilingual;
|
|
128
|
+
bool hasMedAsr = !ctcModelPath.empty() && isLikelyMedAsr;
|
|
129
|
+
bool hasTeleSpeechCtc = (!ctcModelPath.empty() || !paraformerModelPath.empty()) && isLikelyTeleSpeech;
|
|
130
|
+
|
|
131
|
+
if (hasTransducer) {
|
|
132
|
+
if (isLikelyNemo || isLikelyTdt) {
|
|
133
|
+
result.detectedModels.push_back({"nemo_transducer", modelDir});
|
|
134
|
+
} else {
|
|
135
|
+
result.detectedModels.push_back({isLikelyZipformer ? "zipformer" : "transducer", modelDir});
|
|
136
|
+
}
|
|
137
|
+
}
|
|
138
|
+
|
|
139
|
+
if (!ctcModelPath.empty() && (isLikelyNemo || isLikelyWenetCtc || isLikelySenseVoice)) {
|
|
140
|
+
if (isLikelyNemo) {
|
|
141
|
+
result.detectedModels.push_back({"nemo_ctc", modelDir});
|
|
142
|
+
} else if (isLikelyWenetCtc) {
|
|
143
|
+
result.detectedModels.push_back({"wenet_ctc", modelDir});
|
|
144
|
+
} else if (isLikelySenseVoice) {
|
|
145
|
+
result.detectedModels.push_back({"sense_voice", modelDir});
|
|
146
|
+
} else {
|
|
147
|
+
result.detectedModels.push_back({"ctc", modelDir});
|
|
148
|
+
}
|
|
149
|
+
} else if (!paraformerModelPath.empty()) {
|
|
150
|
+
result.detectedModels.push_back({"paraformer", modelDir});
|
|
151
|
+
}
|
|
152
|
+
|
|
153
|
+
if (hasWhisper) {
|
|
154
|
+
result.detectedModels.push_back({"whisper", modelDir});
|
|
155
|
+
}
|
|
156
|
+
|
|
157
|
+
if (hasFunAsrNano) {
|
|
158
|
+
result.detectedModels.push_back({"funasr_nano", modelDir});
|
|
159
|
+
}
|
|
160
|
+
if (hasMoonshine) {
|
|
161
|
+
result.detectedModels.push_back({"moonshine", modelDir});
|
|
162
|
+
}
|
|
163
|
+
if (hasDolphin) {
|
|
164
|
+
result.detectedModels.push_back({"dolphin", modelDir});
|
|
165
|
+
}
|
|
166
|
+
if (hasFireRedAsr) {
|
|
167
|
+
result.detectedModels.push_back({"fire_red_asr", modelDir});
|
|
168
|
+
}
|
|
169
|
+
if (hasCanary) {
|
|
170
|
+
result.detectedModels.push_back({"canary", modelDir});
|
|
171
|
+
}
|
|
172
|
+
if (hasOmnilingual) {
|
|
173
|
+
result.detectedModels.push_back({"omnilingual", modelDir});
|
|
174
|
+
}
|
|
175
|
+
if (hasMedAsr) {
|
|
176
|
+
result.detectedModels.push_back({"medasr", modelDir});
|
|
177
|
+
}
|
|
178
|
+
if (hasTeleSpeechCtc) {
|
|
179
|
+
result.detectedModels.push_back({"telespeech_ctc", modelDir});
|
|
180
|
+
}
|
|
181
|
+
|
|
182
|
+
SttModelKind selected = SttModelKind::kUnknown;
|
|
183
|
+
|
|
184
|
+
if (modelType.has_value() && modelType.value() != "auto") {
|
|
185
|
+
selected = ParseSttModelType(modelType.value());
|
|
186
|
+
if (selected == SttModelKind::kUnknown) {
|
|
187
|
+
result.error = "Unknown model type: " + modelType.value();
|
|
188
|
+
return result;
|
|
189
|
+
}
|
|
190
|
+
|
|
191
|
+
if (selected == SttModelKind::kTransducer && !hasTransducer) {
|
|
192
|
+
result.error = "Transducer model requested but files not found in " + modelDir;
|
|
193
|
+
return result;
|
|
194
|
+
}
|
|
195
|
+
if (selected == SttModelKind::kNemoTransducer && !hasTransducer) {
|
|
196
|
+
result.error = "NeMo Transducer model requested but encoder/decoder/joiner not found in " + modelDir;
|
|
197
|
+
return result;
|
|
198
|
+
}
|
|
199
|
+
if (selected == SttModelKind::kParaformer && paraformerModelPath.empty()) {
|
|
200
|
+
result.error = "Paraformer model requested but model.onnx not found in " + modelDir;
|
|
201
|
+
return result;
|
|
202
|
+
}
|
|
203
|
+
if ((selected == SttModelKind::kNemoCtc || selected == SttModelKind::kWenetCtc ||
|
|
204
|
+
selected == SttModelKind::kSenseVoice || selected == SttModelKind::kZipformerCtc) &&
|
|
205
|
+
ctcModelPath.empty()) {
|
|
206
|
+
result.error = "CTC model requested but model.onnx not found in " + modelDir;
|
|
207
|
+
return result;
|
|
208
|
+
}
|
|
209
|
+
if (selected == SttModelKind::kWhisper && !hasWhisper) {
|
|
210
|
+
result.error = "Whisper model requested but encoder/decoder not found in " + modelDir;
|
|
211
|
+
return result;
|
|
212
|
+
}
|
|
213
|
+
if (selected == SttModelKind::kFunAsrNano && !hasFunAsrNano) {
|
|
214
|
+
result.error = "FunASR Nano model requested but required files not found in " + modelDir;
|
|
215
|
+
return result;
|
|
216
|
+
}
|
|
217
|
+
if (selected == SttModelKind::kMoonshine && !hasMoonshine) {
|
|
218
|
+
result.error = "Moonshine model requested but preprocess/encode/uncached_decode/cached_decode not found in " + modelDir;
|
|
219
|
+
return result;
|
|
220
|
+
}
|
|
221
|
+
if (selected == SttModelKind::kDolphin && !hasDolphin) {
|
|
222
|
+
result.error = "Dolphin model requested but model not found in " + modelDir;
|
|
223
|
+
return result;
|
|
224
|
+
}
|
|
225
|
+
if (selected == SttModelKind::kFireRedAsr && !hasFireRedAsr) {
|
|
226
|
+
result.error = "FireRed ASR model requested but encoder/decoder not found in " + modelDir;
|
|
227
|
+
return result;
|
|
228
|
+
}
|
|
229
|
+
if (selected == SttModelKind::kCanary && !hasCanary) {
|
|
230
|
+
result.error = "Canary model requested but encoder/decoder not found in " + modelDir;
|
|
231
|
+
return result;
|
|
232
|
+
}
|
|
233
|
+
if (selected == SttModelKind::kOmnilingual && !hasOmnilingual) {
|
|
234
|
+
result.error = "Omnilingual model requested but model not found in " + modelDir;
|
|
235
|
+
return result;
|
|
236
|
+
}
|
|
237
|
+
if (selected == SttModelKind::kMedAsr && !hasMedAsr) {
|
|
238
|
+
result.error = "MedASR model requested but model not found in " + modelDir;
|
|
239
|
+
return result;
|
|
240
|
+
}
|
|
241
|
+
if (selected == SttModelKind::kTeleSpeechCtc && !hasTeleSpeechCtc) {
|
|
242
|
+
result.error = "TeleSpeech CTC model requested but model not found in " + modelDir;
|
|
243
|
+
return result;
|
|
244
|
+
}
|
|
245
|
+
} else {
|
|
246
|
+
if (hasTransducer) {
|
|
247
|
+
selected = (isLikelyNemo || isLikelyTdt) ? SttModelKind::kNemoTransducer : SttModelKind::kTransducer;
|
|
248
|
+
} else if (!ctcModelPath.empty() && (isLikelyNemo || isLikelyWenetCtc || isLikelySenseVoice)) {
|
|
249
|
+
if (isLikelyNemo) {
|
|
250
|
+
selected = SttModelKind::kNemoCtc;
|
|
251
|
+
} else if (isLikelyWenetCtc) {
|
|
252
|
+
selected = SttModelKind::kWenetCtc;
|
|
253
|
+
} else {
|
|
254
|
+
selected = SttModelKind::kSenseVoice;
|
|
255
|
+
}
|
|
256
|
+
} else if (hasFunAsrNano && isLikelyFunAsrNano) {
|
|
257
|
+
selected = SttModelKind::kFunAsrNano;
|
|
258
|
+
} else if (!paraformerModelPath.empty()) {
|
|
259
|
+
selected = SttModelKind::kParaformer;
|
|
260
|
+
} else if (hasCanary) {
|
|
261
|
+
selected = SttModelKind::kCanary;
|
|
262
|
+
} else if (hasFireRedAsr) {
|
|
263
|
+
selected = SttModelKind::kFireRedAsr;
|
|
264
|
+
} else if (hasWhisper) {
|
|
265
|
+
selected = SttModelKind::kWhisper;
|
|
266
|
+
} else if (hasFunAsrNano) {
|
|
267
|
+
selected = SttModelKind::kFunAsrNano;
|
|
268
|
+
} else if (hasMoonshine && isLikelyMoonshine) {
|
|
269
|
+
selected = SttModelKind::kMoonshine;
|
|
270
|
+
} else if (hasDolphin) {
|
|
271
|
+
selected = SttModelKind::kDolphin;
|
|
272
|
+
} else if (hasFireRedAsr) {
|
|
273
|
+
selected = SttModelKind::kFireRedAsr;
|
|
274
|
+
} else if (hasCanary) {
|
|
275
|
+
selected = SttModelKind::kCanary;
|
|
276
|
+
} else if (hasOmnilingual) {
|
|
277
|
+
selected = SttModelKind::kOmnilingual;
|
|
278
|
+
} else if (hasMedAsr) {
|
|
279
|
+
selected = SttModelKind::kMedAsr;
|
|
280
|
+
} else if (hasTeleSpeechCtc) {
|
|
281
|
+
selected = SttModelKind::kTeleSpeechCtc;
|
|
282
|
+
} else if (!ctcModelPath.empty()) {
|
|
283
|
+
selected = SttModelKind::kZipformerCtc;
|
|
284
|
+
}
|
|
285
|
+
}
|
|
286
|
+
|
|
287
|
+
if (selected == SttModelKind::kUnknown) {
|
|
288
|
+
result.error = "No compatible model type detected in " + modelDir;
|
|
289
|
+
return result;
|
|
290
|
+
}
|
|
291
|
+
|
|
292
|
+
result.selectedKind = selected;
|
|
293
|
+
result.tokensRequired = (selected != SttModelKind::kFunAsrNano);
|
|
294
|
+
|
|
295
|
+
if (selected == SttModelKind::kTransducer || selected == SttModelKind::kNemoTransducer) {
|
|
296
|
+
result.paths.encoder = encoderPath;
|
|
297
|
+
result.paths.decoder = decoderPath;
|
|
298
|
+
result.paths.joiner = joinerPath;
|
|
299
|
+
} else if (selected == SttModelKind::kParaformer) {
|
|
300
|
+
result.paths.paraformerModel = paraformerModelPath;
|
|
301
|
+
} else if (selected == SttModelKind::kNemoCtc || selected == SttModelKind::kWenetCtc ||
|
|
302
|
+
selected == SttModelKind::kSenseVoice || selected == SttModelKind::kZipformerCtc) {
|
|
303
|
+
result.paths.ctcModel = ctcModelPath;
|
|
304
|
+
} else if (selected == SttModelKind::kWhisper) {
|
|
305
|
+
result.paths.whisperEncoder = encoderPath;
|
|
306
|
+
result.paths.whisperDecoder = decoderPath;
|
|
307
|
+
} else if (selected == SttModelKind::kFunAsrNano) {
|
|
308
|
+
result.paths.funasrEncoderAdaptor = funasrEncoderAdaptor;
|
|
309
|
+
result.paths.funasrLLM = funasrLLM;
|
|
310
|
+
result.paths.funasrEmbedding = funasrEmbedding;
|
|
311
|
+
result.paths.funasrTokenizer = funasrTokenizerDir;
|
|
312
|
+
} else if (selected == SttModelKind::kMoonshine) {
|
|
313
|
+
result.paths.moonshinePreprocessor = moonshinePreprocess;
|
|
314
|
+
result.paths.moonshineEncoder = moonshineEncode;
|
|
315
|
+
result.paths.moonshineUncachedDecoder = moonshineUncachedDecode;
|
|
316
|
+
result.paths.moonshineCachedDecoder = moonshineCachedDecode;
|
|
317
|
+
} else if (selected == SttModelKind::kDolphin) {
|
|
318
|
+
result.paths.dolphinModel = ctcModelPath.empty() ? paraformerModelPath : ctcModelPath;
|
|
319
|
+
} else if (selected == SttModelKind::kFireRedAsr) {
|
|
320
|
+
result.paths.fireRedEncoder = encoderPath;
|
|
321
|
+
result.paths.fireRedDecoder = decoderPath;
|
|
322
|
+
} else if (selected == SttModelKind::kCanary) {
|
|
323
|
+
result.paths.canaryEncoder = encoderPath;
|
|
324
|
+
result.paths.canaryDecoder = decoderPath;
|
|
325
|
+
} else if (selected == SttModelKind::kOmnilingual) {
|
|
326
|
+
result.paths.omnilingualModel = ctcModelPath;
|
|
327
|
+
} else if (selected == SttModelKind::kMedAsr) {
|
|
328
|
+
result.paths.medasrModel = ctcModelPath;
|
|
329
|
+
} else if (selected == SttModelKind::kTeleSpeechCtc) {
|
|
330
|
+
result.paths.telespeechCtcModel = ctcModelPath.empty() ? paraformerModelPath : ctcModelPath;
|
|
331
|
+
}
|
|
332
|
+
|
|
333
|
+
if (!tokensPath.empty() && FileExists(tokensPath)) {
|
|
334
|
+
result.paths.tokens = tokensPath;
|
|
335
|
+
} else if (result.tokensRequired) {
|
|
336
|
+
result.error = "Tokens file not found in " + modelDir;
|
|
337
|
+
return result;
|
|
338
|
+
}
|
|
339
|
+
|
|
340
|
+
result.ok = true;
|
|
341
|
+
return result;
|
|
342
|
+
}
|
|
343
|
+
|
|
344
|
+
} // namespace sherpaonnx
|