react-native-sherpa-onnx 0.3.8 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134)
  1. package/README.md +20 -5
  2. package/SherpaOnnx.podspec +5 -1
  3. package/android/prebuilt-download.gradle +89 -49
  4. package/android/prebuilt-versions.gradle +1 -1
  5. package/android/src/main/assets/model_licenses/asr-models-license-status.csv +1 -0
  6. package/android/src/main/assets/model_licenses/speech-enhancement-models-license-status.csv +7 -0
  7. package/android/src/main/cpp/CMakeLists.txt +3 -0
  8. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-enhancement-wrapper.cpp +68 -0
  9. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-enhancement-wrapper.h +17 -0
  10. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-enhancement.cpp +119 -0
  11. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-helper.cpp +23 -0
  12. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-helper.h +9 -0
  13. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-stt.cpp +51 -8
  14. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect.h +41 -0
  15. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-stt-wrapper.cpp +5 -0
  16. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-validate-enhancement.cpp +68 -0
  17. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-validate-enhancement.h +30 -0
  18. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-validate-stt.cpp +11 -0
  19. package/android/src/main/cpp/jni/module/sherpa-onnx-module-jni.cpp +21 -0
  20. package/android/src/main/java/com/sherpaonnx/SherpaOnnxArchiveHelper.kt +110 -35
  21. package/android/src/main/java/com/sherpaonnx/SherpaOnnxAssetHelper.kt +6 -0
  22. package/android/src/main/java/com/sherpaonnx/SherpaOnnxEnhancementHelper.kt +377 -0
  23. package/android/src/main/java/com/sherpaonnx/SherpaOnnxExtractionNotificationHelper.kt +102 -0
  24. package/android/src/main/java/com/sherpaonnx/SherpaOnnxModule.kt +198 -18
  25. package/android/src/main/java/com/sherpaonnx/SherpaOnnxSttHelper.kt +22 -0
  26. package/ios/Resources/model_licenses/asr-models-license-status.csv +1 -0
  27. package/ios/Resources/model_licenses/speech-enhancement-models-license-status.csv +7 -0
  28. package/ios/SherpaOnnx+Assets.mm +5 -0
  29. package/ios/SherpaOnnx+Enhancement.mm +435 -0
  30. package/ios/SherpaOnnx+STT.mm +13 -1
  31. package/ios/SherpaOnnx.mm +87 -17
  32. package/ios/enhancement/sherpa-onnx-enhancement-wrapper.h +85 -0
  33. package/ios/enhancement/sherpa-onnx-enhancement-wrapper.mm +218 -0
  34. package/ios/model_detect/sherpa-onnx-model-detect-enhancement.mm +92 -0
  35. package/ios/model_detect/sherpa-onnx-model-detect-helper.h +5 -0
  36. package/ios/model_detect/sherpa-onnx-model-detect-helper.mm +23 -0
  37. package/ios/model_detect/sherpa-onnx-model-detect-stt.mm +51 -7
  38. package/ios/model_detect/sherpa-onnx-model-detect.h +33 -0
  39. package/ios/model_detect/sherpa-onnx-validate-enhancement.h +30 -0
  40. package/ios/model_detect/sherpa-onnx-validate-enhancement.mm +69 -0
  41. package/ios/model_detect/sherpa-onnx-validate-stt.mm +11 -0
  42. package/ios/stt/sherpa-onnx-stt-wrapper.h +11 -1
  43. package/ios/stt/sherpa-onnx-stt-wrapper.mm +30 -2
  44. package/ios/tts/sherpa-onnx-tts-wrapper.mm +16 -0
  45. package/lib/module/NativeSherpaOnnx.js.map +1 -1
  46. package/lib/module/download/localModels.js +2 -3
  47. package/lib/module/download/localModels.js.map +1 -1
  48. package/lib/module/download/paths.js +2 -1
  49. package/lib/module/download/paths.js.map +1 -1
  50. package/lib/module/download/postDownloadProcessing.js +17 -4
  51. package/lib/module/download/postDownloadProcessing.js.map +1 -1
  52. package/lib/module/enhancement/index.js +63 -48
  53. package/lib/module/enhancement/index.js.map +1 -1
  54. package/lib/module/enhancement/streaming.js +60 -0
  55. package/lib/module/enhancement/streaming.js.map +1 -0
  56. package/lib/module/enhancement/streamingTypes.js +4 -0
  57. package/lib/module/enhancement/streamingTypes.js.map +1 -0
  58. package/lib/module/enhancement/types.js +4 -0
  59. package/lib/module/enhancement/types.js.map +1 -0
  60. package/lib/module/extraction/extractTarBz2.js +2 -2
  61. package/lib/module/extraction/extractTarBz2.js.map +1 -1
  62. package/lib/module/extraction/extractTarZst.js +2 -2
  63. package/lib/module/extraction/extractTarZst.js.map +1 -1
  64. package/lib/module/extraction/index.js +10 -5
  65. package/lib/module/extraction/index.js.map +1 -1
  66. package/lib/module/licenses.js +9 -3
  67. package/lib/module/licenses.js.map +1 -1
  68. package/lib/module/stt/index.js +4 -2
  69. package/lib/module/stt/index.js.map +1 -1
  70. package/lib/module/stt/streaming.js +2 -1
  71. package/lib/module/stt/streaming.js.map +1 -1
  72. package/lib/module/stt/types.js +3 -1
  73. package/lib/module/stt/types.js.map +1 -1
  74. package/lib/module/tts/index.js +4 -2
  75. package/lib/module/tts/index.js.map +1 -1
  76. package/lib/module/tts/streaming.js +3 -1
  77. package/lib/module/tts/streaming.js.map +1 -1
  78. package/lib/typescript/src/NativeSherpaOnnx.d.ts +70 -9
  79. package/lib/typescript/src/NativeSherpaOnnx.d.ts.map +1 -1
  80. package/lib/typescript/src/download/localModels.d.ts.map +1 -1
  81. package/lib/typescript/src/download/paths.d.ts +2 -1
  82. package/lib/typescript/src/download/paths.d.ts.map +1 -1
  83. package/lib/typescript/src/download/postDownloadProcessing.d.ts +9 -0
  84. package/lib/typescript/src/download/postDownloadProcessing.d.ts.map +1 -1
  85. package/lib/typescript/src/enhancement/index.d.ts +9 -46
  86. package/lib/typescript/src/enhancement/index.d.ts.map +1 -1
  87. package/lib/typescript/src/enhancement/streaming.d.ts +6 -0
  88. package/lib/typescript/src/enhancement/streaming.d.ts.map +1 -0
  89. package/lib/typescript/src/enhancement/streamingTypes.d.ts +12 -0
  90. package/lib/typescript/src/enhancement/streamingTypes.d.ts.map +1 -0
  91. package/lib/typescript/src/enhancement/types.d.ts +31 -0
  92. package/lib/typescript/src/enhancement/types.d.ts.map +1 -0
  93. package/lib/typescript/src/extraction/extractTarBz2.d.ts +2 -1
  94. package/lib/typescript/src/extraction/extractTarBz2.d.ts.map +1 -1
  95. package/lib/typescript/src/extraction/extractTarZst.d.ts +2 -1
  96. package/lib/typescript/src/extraction/extractTarZst.d.ts.map +1 -1
  97. package/lib/typescript/src/extraction/index.d.ts +1 -1
  98. package/lib/typescript/src/extraction/index.d.ts.map +1 -1
  99. package/lib/typescript/src/extraction/types.d.ts +12 -0
  100. package/lib/typescript/src/extraction/types.d.ts.map +1 -1
  101. package/lib/typescript/src/licenses.d.ts.map +1 -1
  102. package/lib/typescript/src/stt/index.d.ts +1 -1
  103. package/lib/typescript/src/stt/index.d.ts.map +1 -1
  104. package/lib/typescript/src/stt/streaming.d.ts.map +1 -1
  105. package/lib/typescript/src/stt/types.d.ts +16 -1
  106. package/lib/typescript/src/stt/types.d.ts.map +1 -1
  107. package/lib/typescript/src/tts/index.d.ts.map +1 -1
  108. package/lib/typescript/src/tts/streaming.d.ts.map +1 -1
  109. package/package.json +1 -1
  110. package/scripts/ci/check-model-csvs.sh +27 -2
  111. package/scripts/ci/collect_all_sherpa_model_streams.sh +3 -1
  112. package/scripts/ci/collect_one_sherpa_release_stream.sh +3 -1
  113. package/scripts/ci/sherpa_speech_enhancement_model_release_streams.json +13 -0
  114. package/scripts/ci/update_model_license_csv.sh +17 -17
  115. package/src/NativeSherpaOnnx.ts +108 -10
  116. package/src/download/localModels.ts +1 -3
  117. package/src/download/paths.ts +2 -1
  118. package/src/download/postDownloadProcessing.ts +24 -1
  119. package/src/enhancement/index.ts +120 -58
  120. package/src/enhancement/streaming.ts +105 -0
  121. package/src/enhancement/streamingTypes.ts +14 -0
  122. package/src/enhancement/types.ts +36 -0
  123. package/src/extraction/extractTarBz2.ts +7 -2
  124. package/src/extraction/extractTarZst.ts +7 -2
  125. package/src/extraction/index.ts +29 -6
  126. package/src/extraction/types.ts +16 -0
  127. package/src/licenses.ts +13 -2
  128. package/src/stt/index.ts +8 -7
  129. package/src/stt/streaming.ts +7 -1
  130. package/src/stt/types.ts +18 -0
  131. package/src/tts/index.ts +7 -7
  132. package/src/tts/streaming.ts +6 -3
  133. package/third_party/sherpa-onnx-prebuilt/ANDROID_RELEASE_TAG +1 -1
  134. package/third_party/sherpa-onnx-prebuilt/IOS_RELEASE_TAG +1 -1
@@ -0,0 +1,85 @@
1
+ #ifndef SHERPA_ONNX_ENHANCEMENT_WRAPPER_H
2
+ #define SHERPA_ONNX_ENHANCEMENT_WRAPPER_H
3
+
4
+ #include "sherpa-onnx-common.h"
5
+ #include "sherpa-onnx-model-detect.h"
6
+ #include <cstdint>
7
+ #include <memory>
8
+ #include <optional>
9
+ #include <string>
10
+ #include <vector>
11
+
12
+ namespace sherpaonnx {
13
+
14
+ struct EnhancementInitializeResult {
15
+ bool success = false;
16
+ std::vector<DetectedModel> detectedModels;
17
+ std::string error;
18
+ std::string modelType;
19
+ int32_t sampleRate = 0;
20
+ int32_t frameShiftInSamples = 0;
21
+ };
22
+
23
// Enhanced audio returned by the wrappers: the processed samples plus the
// sample rate they are expressed in. Default-constructed (empty samples,
// sampleRate 0) when no model is loaded.
struct EnhancedAudioResult {
  std::vector<float> samples;
  int32_t sampleRate{0};
};
27
+
28
+ class EnhancementWrapper {
29
+ public:
30
+ EnhancementWrapper();
31
+ ~EnhancementWrapper();
32
+
33
+ EnhancementInitializeResult initialize(
34
+ const std::string& modelDir,
35
+ const std::string& modelType = "auto",
36
+ int32_t numThreads = 1,
37
+ const std::optional<std::string>& provider = std::nullopt,
38
+ bool debug = false
39
+ );
40
+
41
+ EnhancedAudioResult runSamples(const std::vector<float>& samples, int32_t sampleRate);
42
+
43
+ int32_t getSampleRate() const;
44
+
45
+ bool isInitialized() const;
46
+
47
+ void release();
48
+
49
+ private:
50
+ class Impl;
51
+ std::unique_ptr<Impl> pImpl;
52
+ };
53
+
54
+ class OnlineEnhancementWrapper {
55
+ public:
56
+ OnlineEnhancementWrapper();
57
+ ~OnlineEnhancementWrapper();
58
+
59
+ EnhancementInitializeResult initialize(
60
+ const std::string& modelDir,
61
+ const std::string& modelType = "auto",
62
+ int32_t numThreads = 1,
63
+ const std::optional<std::string>& provider = std::nullopt,
64
+ bool debug = false
65
+ );
66
+
67
+ EnhancedAudioResult runSamples(const std::vector<float>& samples, int32_t sampleRate);
68
+ EnhancedAudioResult flush();
69
+ void reset();
70
+
71
+ int32_t getSampleRate() const;
72
+ int32_t getFrameShiftInSamples() const;
73
+
74
+ bool isInitialized() const;
75
+
76
+ void release();
77
+
78
+ private:
79
+ class Impl;
80
+ std::unique_ptr<Impl> pImpl;
81
+ };
82
+
83
+ } // namespace sherpaonnx
84
+
85
+ #endif // SHERPA_ONNX_ENHANCEMENT_WRAPPER_H
@@ -0,0 +1,218 @@
1
#include "sherpa-onnx-enhancement-wrapper.h"

#include "sherpa-onnx-model-detect.h"

#include <optional>
#include <utility>

#include "sherpa-onnx/c-api/cxx-api.h"
8
+
9
+ namespace sherpaonnx {
10
+ namespace {
11
+
12
+ std::string EnhancementKindToString(EnhancementModelKind kind) {
13
+ switch (kind) {
14
+ case EnhancementModelKind::kGtcrn:
15
+ return "gtcrn";
16
+ case EnhancementModelKind::kDpdfNet:
17
+ return "dpdfnet";
18
+ default:
19
+ return "unknown";
20
+ }
21
+ }
22
+
23
+ sherpa_onnx::cxx::OfflineSpeechDenoiserModelConfig BuildModelConfig(
24
+ const EnhancementDetectResult& detect,
25
+ int32_t numThreads,
26
+ const std::optional<std::string>& provider,
27
+ bool debug
28
+ ) {
29
+ sherpa_onnx::cxx::OfflineSpeechDenoiserModelConfig cfg;
30
+ cfg.num_threads = numThreads;
31
+ cfg.debug = debug;
32
+ if (provider.has_value() && !provider->empty()) {
33
+ cfg.provider = *provider;
34
+ }
35
+ switch (detect.selectedKind) {
36
+ case EnhancementModelKind::kGtcrn:
37
+ cfg.gtcrn.model = detect.paths.model;
38
+ break;
39
+ case EnhancementModelKind::kDpdfNet:
40
+ cfg.dpdfnet.model = detect.paths.model;
41
+ break;
42
+ default:
43
+ break;
44
+ }
45
+ return cfg;
46
+ }
47
+
48
+ EnhancedAudioResult ToEnhancedAudioResult(
49
+ const sherpa_onnx::cxx::DenoisedAudio& audio
50
+ ) {
51
+ EnhancedAudioResult out;
52
+ out.samples = audio.samples;
53
+ out.sampleRate = audio.sample_rate;
54
+ return out;
55
+ }
56
+
57
+ } // namespace
58
+
59
+ class EnhancementWrapper::Impl {
60
+ public:
61
+ bool initialized = false;
62
+ std::optional<sherpa_onnx::cxx::OfflineSpeechDenoiser> denoiser;
63
+ };
64
+
65
+ EnhancementWrapper::EnhancementWrapper() : pImpl(std::make_unique<Impl>()) {}
66
+
67
+ EnhancementWrapper::~EnhancementWrapper() { release(); }
68
+
69
+ EnhancementInitializeResult EnhancementWrapper::initialize(
70
+ const std::string& modelDir,
71
+ const std::string& modelType,
72
+ int32_t numThreads,
73
+ const std::optional<std::string>& provider,
74
+ bool debug
75
+ ) {
76
+ EnhancementInitializeResult result;
77
+ if (pImpl->initialized) {
78
+ release();
79
+ }
80
+ if (modelDir.empty()) {
81
+ result.error = "Enhancement model directory is empty";
82
+ return result;
83
+ }
84
+
85
+ auto detect = DetectEnhancementModel(modelDir, modelType);
86
+ result.detectedModels = detect.detectedModels;
87
+ result.modelType = EnhancementKindToString(detect.selectedKind);
88
+ if (!detect.ok) {
89
+ result.error = detect.error;
90
+ return result;
91
+ }
92
+
93
+ sherpa_onnx::cxx::OfflineSpeechDenoiserConfig config;
94
+ config.model = BuildModelConfig(detect, numThreads, provider, debug);
95
+ pImpl->denoiser = sherpa_onnx::cxx::OfflineSpeechDenoiser::Create(config);
96
+ pImpl->initialized = true;
97
+
98
+ result.success = true;
99
+ result.sampleRate = pImpl->denoiser->GetSampleRate();
100
+ return result;
101
+ }
102
+
103
+ EnhancedAudioResult EnhancementWrapper::runSamples(
104
+ const std::vector<float>& samples,
105
+ int32_t sampleRate
106
+ ) {
107
+ if (!pImpl->initialized || !pImpl->denoiser.has_value()) {
108
+ return {};
109
+ }
110
+ return ToEnhancedAudioResult(
111
+ pImpl->denoiser->Run(samples.data(), static_cast<int32_t>(samples.size()), sampleRate)
112
+ );
113
+ }
114
+
115
+ int32_t EnhancementWrapper::getSampleRate() const {
116
+ if (!pImpl->initialized || !pImpl->denoiser.has_value()) return 0;
117
+ return pImpl->denoiser->GetSampleRate();
118
+ }
119
+
120
+ bool EnhancementWrapper::isInitialized() const { return pImpl->initialized; }
121
+
122
+ void EnhancementWrapper::release() {
123
+ if (pImpl->denoiser.has_value()) {
124
+ pImpl->denoiser.reset();
125
+ }
126
+ pImpl->initialized = false;
127
+ }
128
+
129
+ class OnlineEnhancementWrapper::Impl {
130
+ public:
131
+ bool initialized = false;
132
+ std::optional<sherpa_onnx::cxx::OnlineSpeechDenoiser> denoiser;
133
+ };
134
+
135
+ OnlineEnhancementWrapper::OnlineEnhancementWrapper()
136
+ : pImpl(std::make_unique<Impl>()) {}
137
+
138
+ OnlineEnhancementWrapper::~OnlineEnhancementWrapper() { release(); }
139
+
140
+ EnhancementInitializeResult OnlineEnhancementWrapper::initialize(
141
+ const std::string& modelDir,
142
+ const std::string& modelType,
143
+ int32_t numThreads,
144
+ const std::optional<std::string>& provider,
145
+ bool debug
146
+ ) {
147
+ EnhancementInitializeResult result;
148
+ if (pImpl->initialized) {
149
+ release();
150
+ }
151
+ if (modelDir.empty()) {
152
+ result.error = "Enhancement model directory is empty";
153
+ return result;
154
+ }
155
+
156
+ auto detect = DetectEnhancementModel(modelDir, modelType);
157
+ result.detectedModels = detect.detectedModels;
158
+ result.modelType = EnhancementKindToString(detect.selectedKind);
159
+ if (!detect.ok) {
160
+ result.error = detect.error;
161
+ return result;
162
+ }
163
+
164
+ sherpa_onnx::cxx::OnlineSpeechDenoiserConfig config;
165
+ config.model = BuildModelConfig(detect, numThreads, provider, debug);
166
+ pImpl->denoiser = sherpa_onnx::cxx::OnlineSpeechDenoiser::Create(config);
167
+ pImpl->initialized = true;
168
+
169
+ result.success = true;
170
+ result.sampleRate = pImpl->denoiser->GetSampleRate();
171
+ result.frameShiftInSamples = pImpl->denoiser->GetFrameShiftInSamples();
172
+ return result;
173
+ }
174
+
175
+ EnhancedAudioResult OnlineEnhancementWrapper::runSamples(
176
+ const std::vector<float>& samples,
177
+ int32_t sampleRate
178
+ ) {
179
+ if (!pImpl->initialized || !pImpl->denoiser.has_value()) {
180
+ return {};
181
+ }
182
+ return ToEnhancedAudioResult(
183
+ pImpl->denoiser->Run(samples.data(), static_cast<int32_t>(samples.size()), sampleRate)
184
+ );
185
+ }
186
+
187
+ EnhancedAudioResult OnlineEnhancementWrapper::flush() {
188
+ if (!pImpl->initialized || !pImpl->denoiser.has_value()) {
189
+ return {};
190
+ }
191
+ return ToEnhancedAudioResult(pImpl->denoiser->Flush());
192
+ }
193
+
194
+ void OnlineEnhancementWrapper::reset() {
195
+ if (!pImpl->initialized || !pImpl->denoiser.has_value()) return;
196
+ pImpl->denoiser->Reset();
197
+ }
198
+
199
+ int32_t OnlineEnhancementWrapper::getSampleRate() const {
200
+ if (!pImpl->initialized || !pImpl->denoiser.has_value()) return 0;
201
+ return pImpl->denoiser->GetSampleRate();
202
+ }
203
+
204
+ int32_t OnlineEnhancementWrapper::getFrameShiftInSamples() const {
205
+ if (!pImpl->initialized || !pImpl->denoiser.has_value()) return 0;
206
+ return pImpl->denoiser->GetFrameShiftInSamples();
207
+ }
208
+
209
+ bool OnlineEnhancementWrapper::isInitialized() const { return pImpl->initialized; }
210
+
211
+ void OnlineEnhancementWrapper::release() {
212
+ if (pImpl->denoiser.has_value()) {
213
+ pImpl->denoiser.reset();
214
+ }
215
+ pImpl->initialized = false;
216
+ }
217
+
218
+ } // namespace sherpaonnx
@@ -0,0 +1,92 @@
1
+ #include "sherpa-onnx-model-detect.h"
2
+ #include "sherpa-onnx-model-detect-helper.h"
3
+ #include "sherpa-onnx-validate-enhancement.h"
4
+
5
+ #include <optional>
6
+ #include <string>
7
+ #include <vector>
8
+
9
+ namespace sherpaonnx {
10
+ namespace {
11
+
12
+ using namespace model_detect;
13
+
14
+ EnhancementModelKind ParseEnhancementModelType(const std::string& modelType) {
15
+ if (modelType == "gtcrn") return EnhancementModelKind::kGtcrn;
16
+ if (modelType == "dpdfnet") return EnhancementModelKind::kDpdfNet;
17
+ return EnhancementModelKind::kUnknown;
18
+ }
19
+
20
+ } // namespace
21
+
22
+ EnhancementDetectResult DetectEnhancementModel(
23
+ const std::string& modelDir,
24
+ const std::string& modelType
25
+ ) {
26
+ EnhancementDetectResult result;
27
+
28
+ if (modelDir.empty()) {
29
+ result.error = "Enhancement: model directory is empty";
30
+ return result;
31
+ }
32
+ if (!FileExists(modelDir) || !IsDirectory(modelDir)) {
33
+ result.error =
34
+ "Enhancement: model directory does not exist or is not a directory: " +
35
+ modelDir;
36
+ return result;
37
+ }
38
+
39
+ const std::vector<FileEntry> files = ListFilesRecursive(modelDir, 4);
40
+ const std::string gtcrnModel =
41
+ FindOnnxByAnyToken(files, {"gtcrn"}, std::nullopt);
42
+ const std::string dpdfnetModel =
43
+ FindOnnxByAnyToken(files, {"dpdfnet"}, std::nullopt);
44
+
45
+ if (!gtcrnModel.empty()) {
46
+ result.detectedModels.push_back({"gtcrn", modelDir});
47
+ }
48
+ if (!dpdfnetModel.empty()) {
49
+ result.detectedModels.push_back({"dpdfnet", modelDir});
50
+ }
51
+
52
+ EnhancementModelKind selected = EnhancementModelKind::kUnknown;
53
+ if (modelType == "auto" || modelType.empty()) {
54
+ if (!gtcrnModel.empty()) {
55
+ selected = EnhancementModelKind::kGtcrn;
56
+ } else if (!dpdfnetModel.empty()) {
57
+ selected = EnhancementModelKind::kDpdfNet;
58
+ }
59
+ } else {
60
+ selected = ParseEnhancementModelType(modelType);
61
+ if (selected == EnhancementModelKind::kUnknown) {
62
+ result.error = "Enhancement: unknown model type: " + modelType;
63
+ return result;
64
+ }
65
+ }
66
+
67
+ switch (selected) {
68
+ case EnhancementModelKind::kGtcrn:
69
+ result.paths.model = gtcrnModel;
70
+ break;
71
+ case EnhancementModelKind::kDpdfNet:
72
+ result.paths.model = dpdfnetModel;
73
+ break;
74
+ default:
75
+ result.error = "Enhancement: no compatible model type detected in " +
76
+ modelDir;
77
+ return result;
78
+ }
79
+
80
+ auto validation =
81
+ ValidateEnhancementPaths(selected, result.paths, modelDir);
82
+ if (!validation.ok) {
83
+ result.error = validation.error;
84
+ return result;
85
+ }
86
+
87
+ result.selectedKind = selected;
88
+ result.ok = true;
89
+ return result;
90
+ }
91
+
92
+ } // namespace sherpaonnx
@@ -80,6 +80,11 @@ std::vector<LexiconCandidate> FindLexiconCandidates(
80
80
  const std::string& rootDir
81
81
  );
82
82
 
83
+ bool Qwen3TokenizerDirHasVocabAndMerges(
84
+ const std::vector<FileEntry>& files,
85
+ const std::string& dir
86
+ );
87
+
83
88
  } // namespace model_detect
84
89
  } // namespace sherpaonnx
85
90
 
@@ -257,5 +257,28 @@ std::vector<LexiconCandidate> FindLexiconCandidates(
257
257
  return candidates;
258
258
  }
259
259
 
260
+ bool Qwen3TokenizerDirHasVocabAndMerges(
261
+ const std::vector<FileEntry>& files,
262
+ const std::string& dirRaw
263
+ ) {
264
+ std::string dir = dirRaw;
265
+ while (!dir.empty() && (dir.back() == '/' || dir.back() == '\\'))
266
+ dir.pop_back();
267
+ if (dir.empty()) return false;
268
+ bool hasVocab = false;
269
+ bool hasMerges = false;
270
+ const std::string prefix = dir + "/";
271
+ for (const auto& e : files) {
272
+ if (e.path.size() <= prefix.size()) continue;
273
+ if (e.path.compare(0, prefix.size(), prefix) != 0) continue;
274
+ std::string rest = e.path.substr(prefix.size());
275
+ if (rest.find('/') != std::string::npos || rest.find('\\') != std::string::npos) continue;
276
+ if (e.nameLower == "vocab.json") hasVocab = true;
277
+ if (e.nameLower == "merges.txt") hasMerges = true;
278
+ }
279
+ if (hasVocab && hasMerges) return true;
280
+ return FileExists(dir + "/vocab.json") && FileExists(dir + "/merges.txt");
281
+ }
282
+
260
283
  } // namespace model_detect
261
284
  } // namespace sherpaonnx
@@ -58,6 +58,7 @@ static const char* KindToName(SttModelKind k) {
58
58
  case SttModelKind::kZipformerCtc: return "zipformer_ctc";
59
59
  case SttModelKind::kWhisper: return "whisper";
60
60
  case SttModelKind::kFunAsrNano: return "funasr_nano";
61
+ case SttModelKind::kQwen3Asr: return "qwen3_asr";
61
62
  case SttModelKind::kFireRedAsr: return "fire_red_asr";
62
63
  case SttModelKind::kMoonshine: return "moonshine";
63
64
  case SttModelKind::kMoonshineV2: return "moonshine_v2";
@@ -85,6 +86,7 @@ SttModelKind ParseSttModelType(const std::string& modelType) {
85
86
  if (modelType == "zipformer_ctc" || modelType == "ctc") return SttModelKind::kZipformerCtc;
86
87
  if (modelType == "whisper") return SttModelKind::kWhisper;
87
88
  if (modelType == "funasr_nano") return SttModelKind::kFunAsrNano;
89
+ if (modelType == "qwen3_asr") return SttModelKind::kQwen3Asr;
88
90
  if (modelType == "fire_red_asr") return SttModelKind::kFireRedAsr;
89
91
  if (modelType == "moonshine") return SttModelKind::kMoonshine;
90
92
  if (modelType == "moonshine_v2") return SttModelKind::kMoonshineV2;
@@ -123,6 +125,8 @@ static bool CapabilitySupportsKind(
123
125
  return cap.hasWhisper;
124
126
  case SttModelKind::kFunAsrNano:
125
127
  return cap.hasFunAsrNano;
128
+ case SttModelKind::kQwen3Asr:
129
+ return cap.hasQwen3Asr;
126
130
  case SttModelKind::kFireRedAsr:
127
131
  return cap.hasFireRedAsr;
128
132
  case SttModelKind::kMoonshine:
@@ -185,6 +189,8 @@ static std::vector<SttModelKind> GetKindsFromDirName(const std::string& modelDir
185
189
  add(SttModelKind::kTransducer);
186
190
  add(SttModelKind::kZipformerCtc);
187
191
  }
192
+ if (lower.find("qwen3-asr") != std::string::npos || lower.find("qwen3_asr") != std::string::npos)
193
+ add(SttModelKind::kQwen3Asr);
188
194
  if (lower.find("funasr") != std::string::npos)
189
195
  add(SttModelKind::kFunAsrNano);
190
196
  if (lower.find("canary") != std::string::npos)
@@ -245,6 +251,19 @@ static SttCandidatePaths GatherSttCandidatePaths(
245
251
  p.funasrTokenizerDir = vocabInSubdir.substr(0, lastSlash);
246
252
  }
247
253
  }
254
+ p.qwen3ConvFrontend = FindOnnxByAnyToken(files, {"conv_frontend"}, preferInt8);
255
+ {
256
+ for (const auto& entry : files) {
257
+ if (entry.nameLower != "tokenizer_config.json") continue;
258
+ size_t slash = entry.path.find_last_of("/\\");
259
+ if (slash == std::string::npos) continue;
260
+ std::string dir = entry.path.substr(0, slash);
261
+ if (Qwen3TokenizerDirHasVocabAndMerges(files, dir)) {
262
+ p.qwen3TokenizerDir = dir;
263
+ break;
264
+ }
265
+ }
266
+ }
248
267
  p.moonshinePreprocessor = FindOnnxByAnyToken(files, {"preprocess", "preprocessor"}, preferInt8);
249
268
  p.moonshineEncoder = FindOnnxByAnyToken(files, {"encode", "encoder_model"}, preferInt8);
250
269
  p.moonshineUncachedDecoder = FindOnnxByAnyToken(files, {"uncached_decode", "uncached"}, preferInt8);
@@ -254,7 +273,8 @@ static SttCandidatePaths GatherSttCandidatePaths(
254
273
  static const std::vector<std::string> modelExcludes = {
255
274
  "encoder", "decoder", "joiner", "vocoder", "acoustic", "embedding", "llm",
256
275
  "encoder_adaptor", "encoder-adaptor", "encoder_model", "decoder_model",
257
- "merged_decoder", "decoder_model_merged", "preprocess", "encode", "uncached", "cached"
276
+ "merged_decoder", "decoder_model_merged", "preprocess", "encode", "uncached", "cached",
277
+ "conv_frontend"
258
278
  };
259
279
  p.paraformerModel = FindOnnxByAnyToken(files, {"model"}, preferInt8);
260
280
  if (!p.paraformerModel.empty()) {
@@ -297,6 +317,7 @@ static SttPathHints GetSttPathHints(const std::string& modelDir) {
297
317
  h.isLikelyWenetCtc = lower.find("wenet") != std::string::npos;
298
318
  h.isLikelySenseVoice = lower.find("sense") != std::string::npos || lower.find("sensevoice") != std::string::npos;
299
319
  h.isLikelyFunAsrNano = lower.find("funasr") != std::string::npos || lower.find("funasr-nano") != std::string::npos;
320
+ h.isLikelyQwen3Asr = lower.find("qwen3-asr") != std::string::npos || lower.find("qwen3_asr") != std::string::npos;
300
321
  h.isLikelyZipformer = lower.find("zipformer") != std::string::npos;
301
322
  h.isLikelyMoonshine = lower.find("moonshine") != std::string::npos;
302
323
  h.isLikelyDolphin = lower.find("dolphin") != std::string::npos;
@@ -338,7 +359,9 @@ static SttCapabilities ComputeSttCapabilities(const SttCandidatePaths& paths, co
338
359
  c.hasTransducer = !paths.encoder.empty() && !paths.decoder.empty() && !paths.joiner.empty();
339
360
  bool hasWhisperEnc = !paths.encoder.empty();
340
361
  bool hasWhisperDec = !paths.decoder.empty();
341
- c.hasWhisper = hasWhisperEnc && hasWhisperDec && paths.joiner.empty();
362
+ bool hasQwen3Tok = !paths.qwen3TokenizerDir.empty();
363
+ c.hasQwen3Asr = !paths.qwen3ConvFrontend.empty() && hasWhisperEnc && hasWhisperDec && hasQwen3Tok;
364
+ c.hasWhisper = hasWhisperEnc && hasWhisperDec && paths.joiner.empty() && !c.hasQwen3Asr;
342
365
  bool hasFunAsrTok = !paths.funasrTokenizerDir.empty();
343
366
  c.hasFunAsrNano = !paths.funasrEncoderAdaptor.empty() && !paths.funasrLLM.empty() &&
344
367
  !paths.funasrEmbedding.empty() && hasFunAsrTok;
@@ -378,6 +401,7 @@ static void CollectDetectedModels(
378
401
  out.push_back({"paraformer", modelDir});
379
402
  }
380
403
  if (cap.hasWhisper) out.push_back({"whisper", modelDir});
404
+ if (cap.hasQwen3Asr) out.push_back({"qwen3_asr", modelDir});
381
405
  if (cap.hasFunAsrNano) out.push_back({"funasr_nano", modelDir});
382
406
  if (cap.hasMoonshine) out.push_back({"moonshine", modelDir});
383
407
  if (cap.hasMoonshineV2) out.push_back({"moonshine_v2", modelDir});
@@ -439,6 +463,10 @@ static SttModelKind ResolveSttKind(
439
463
  outError = "FunASR Nano model requested but required files not found in " + modelDir;
440
464
  return SttModelKind::kUnknown;
441
465
  }
466
+ if (selected == SttModelKind::kQwen3Asr && !cap.hasQwen3Asr) {
467
+ outError = "Qwen3-ASR model requested but conv_frontend/encoder/decoder/tokenizer not found in " + modelDir;
468
+ return SttModelKind::kUnknown;
469
+ }
442
470
  if (selected == SttModelKind::kMoonshine && !cap.hasMoonshine) {
443
471
  outError = "Moonshine v1 model requested but preprocess/encode/uncached_decode/cached_decode not found in " + modelDir;
444
472
  return SttModelKind::kUnknown;
@@ -505,7 +533,9 @@ static SttModelKind ResolveSttKind(
505
533
  if (!paths.paraformerModel.empty()) return SttModelKind::kParaformer;
506
534
  if (cap.hasCanary) return SttModelKind::kCanary;
507
535
  if (cap.hasFireRedAsr) return SttModelKind::kFireRedAsr;
536
+ if (cap.hasQwen3Asr && hints.isLikelyQwen3Asr) return SttModelKind::kQwen3Asr;
508
537
  if (cap.hasWhisper) return SttModelKind::kWhisper;
538
+ if (cap.hasQwen3Asr) return SttModelKind::kQwen3Asr;
509
539
  if (cap.hasFunAsrNano) return SttModelKind::kFunAsrNano;
510
540
  if (cap.hasMoonshineV2) return SttModelKind::kMoonshineV2;
511
541
  if (cap.hasDolphin) return SttModelKind::kDolphin;
@@ -551,6 +581,12 @@ static void ApplyPathsForSttKind(SttModelKind kind, const SttCandidatePaths& can
551
581
  resultPaths.funasrEmbedding = candidate.funasrEmbedding;
552
582
  resultPaths.funasrTokenizer = candidate.funasrTokenizerDir;
553
583
  break;
584
+ case SttModelKind::kQwen3Asr:
585
+ resultPaths.qwen3ConvFrontend = candidate.qwen3ConvFrontend;
586
+ resultPaths.qwen3Encoder = candidate.encoder;
587
+ resultPaths.qwen3Decoder = candidate.decoder;
588
+ resultPaths.qwen3Tokenizer = candidate.qwen3TokenizerDir;
589
+ break;
554
590
  case SttModelKind::kMoonshine:
555
591
  resultPaths.moonshinePreprocessor = candidate.moonshinePreprocessor;
556
592
  resultPaths.moonshineEncoder = candidate.moonshineEncoder;
@@ -624,13 +660,15 @@ SttDetectResult DetectSttModel(
624
660
  EmptyOrPath(candidate.encoder), EmptyOrPath(candidate.decoder));
625
661
  LOGI("DetectSttModel: funasr encoderAdaptor=%s llm=%s embedding=%s tokenizerDir=%s",
626
662
  EmptyOrPath(candidate.funasrEncoderAdaptor), EmptyOrPath(candidate.funasrLLM), EmptyOrPath(candidate.funasrEmbedding), EmptyOrPath(candidate.funasrTokenizerDir));
627
- LOGI("DetectSttModel: hasTransducer=%d hasWhisper=%d hasMoonshine=%d hasMoonshineV2=%d hasParaformer=%d hasFunAsrNano=%d hasDolphin=%d hasFireRedAsr=%d hasFireRedCtc=%d hasCanary=%d hasOmnilingual=%d hasMedAsr=%d hasTeleSpeechCtc=%d hasToneCtc=%d",
663
+ LOGI("DetectSttModel: qwen3_asr conv=%s tokenizerDir=%s",
664
+ EmptyOrPath(candidate.qwen3ConvFrontend), EmptyOrPath(candidate.qwen3TokenizerDir));
665
+ LOGI("DetectSttModel: hasTransducer=%d hasWhisper=%d hasMoonshine=%d hasMoonshineV2=%d hasParaformer=%d hasFunAsrNano=%d hasQwen3Asr=%d hasDolphin=%d hasFireRedAsr=%d hasFireRedCtc=%d hasCanary=%d hasOmnilingual=%d hasMedAsr=%d hasTeleSpeechCtc=%d hasToneCtc=%d",
628
666
  (int)cap.hasTransducer, (int)cap.hasWhisper, (int)cap.hasMoonshine, (int)cap.hasMoonshineV2,
629
- (int)cap.hasParaformer, (int)cap.hasFunAsrNano, (int)cap.hasDolphin, (int)cap.hasFireRedAsr, (int)cap.hasFireRedCtc,
667
+ (int)cap.hasParaformer, (int)cap.hasFunAsrNano, (int)cap.hasQwen3Asr, (int)cap.hasDolphin, (int)cap.hasFireRedAsr, (int)cap.hasFireRedCtc,
630
668
  (int)cap.hasCanary, (int)cap.hasOmnilingual, (int)cap.hasMedAsr, (int)cap.hasTeleSpeechCtc, (int)cap.hasToneCtc);
631
- LOGI("DetectSttModel: hints isLikelyNemo=%d isLikelyTdt=%d isLikelyWenetCtc=%d isLikelySenseVoice=%d isLikelyFunAsrNano=%d isLikelyZipformer=%d isLikelyMoonshine=%d isLikelyDolphin=%d isLikelyFireRedAsr=%d isLikelyCanary=%d isLikelyOmnilingual=%d isLikelyMedAsr=%d isLikelyTeleSpeech=%d isLikelyToneCtc=%d isLikelyParaformer=%d isLikelyVad=%d isLikelyTdnn=%d",
669
+ LOGI("DetectSttModel: hints isLikelyNemo=%d isLikelyTdt=%d isLikelyWenetCtc=%d isLikelySenseVoice=%d isLikelyFunAsrNano=%d isLikelyQwen3Asr=%d isLikelyZipformer=%d isLikelyMoonshine=%d isLikelyDolphin=%d isLikelyFireRedAsr=%d isLikelyCanary=%d isLikelyOmnilingual=%d isLikelyMedAsr=%d isLikelyTeleSpeech=%d isLikelyToneCtc=%d isLikelyParaformer=%d isLikelyVad=%d isLikelyTdnn=%d",
632
670
  (int)hints.isLikelyNemo, (int)hints.isLikelyTdt, (int)hints.isLikelyWenetCtc, (int)hints.isLikelySenseVoice,
633
- (int)hints.isLikelyFunAsrNano, (int)hints.isLikelyZipformer, (int)hints.isLikelyMoonshine, (int)hints.isLikelyDolphin,
671
+ (int)hints.isLikelyFunAsrNano, (int)hints.isLikelyQwen3Asr, (int)hints.isLikelyZipformer, (int)hints.isLikelyMoonshine, (int)hints.isLikelyDolphin,
634
672
  (int)hints.isLikelyFireRedAsr, (int)hints.isLikelyCanary, (int)hints.isLikelyOmnilingual, (int)hints.isLikelyMedAsr,
635
673
  (int)hints.isLikelyTeleSpeech, (int)hints.isLikelyToneCtc, (int)hints.isLikelyParaformer, (int)hints.isLikelyVad, (int)hints.isLikelyTdnn);
636
674
  }
@@ -653,7 +691,8 @@ SttDetectResult DetectSttModel(
653
691
  }
654
692
 
655
693
  LOGI("DetectSttModel: selected kind=%d (%s)", static_cast<int>(result.selectedKind), KindToName(result.selectedKind));
656
- result.tokensRequired = (result.selectedKind != SttModelKind::kFunAsrNano);
694
+ result.tokensRequired = (result.selectedKind != SttModelKind::kFunAsrNano &&
695
+ result.selectedKind != SttModelKind::kQwen3Asr);
657
696
  ApplyPathsForSttKind(result.selectedKind, candidate, result.paths);
658
697
 
659
698
  if (!candidate.tokens.empty() && FileExists(candidate.tokens)) {
@@ -711,6 +750,11 @@ SttDetectResult DetectSttModel(
711
750
  EmptyOrPath(result.paths.funasrEncoderAdaptor), EmptyOrPath(result.paths.funasrLLM),
712
751
  EmptyOrPath(result.paths.funasrEmbedding), EmptyOrPath(result.paths.funasrTokenizer));
713
752
  break;
753
+ case SttModelKind::kQwen3Asr:
754
+ LOGI("DetectSttModel: paths set qwen3_asr conv=%s encoder=%s decoder=%s tokenizer=%s",
755
+ EmptyOrPath(result.paths.qwen3ConvFrontend), EmptyOrPath(result.paths.qwen3Encoder),
756
+ EmptyOrPath(result.paths.qwen3Decoder), EmptyOrPath(result.paths.qwen3Tokenizer));
757
+ break;
714
758
  default:
715
759
  break;
716
760
  }