react-native-sherpa-onnx 0.3.8 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134):
  1. package/README.md +20 -5
  2. package/SherpaOnnx.podspec +5 -1
  3. package/android/prebuilt-download.gradle +89 -49
  4. package/android/prebuilt-versions.gradle +1 -1
  5. package/android/src/main/assets/model_licenses/asr-models-license-status.csv +1 -0
  6. package/android/src/main/assets/model_licenses/speech-enhancement-models-license-status.csv +7 -0
  7. package/android/src/main/cpp/CMakeLists.txt +3 -0
  8. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-enhancement-wrapper.cpp +68 -0
  9. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-enhancement-wrapper.h +17 -0
  10. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-enhancement.cpp +119 -0
  11. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-helper.cpp +23 -0
  12. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-helper.h +9 -0
  13. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-stt.cpp +51 -8
  14. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect.h +41 -0
  15. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-stt-wrapper.cpp +5 -0
  16. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-validate-enhancement.cpp +68 -0
  17. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-validate-enhancement.h +30 -0
  18. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-validate-stt.cpp +11 -0
  19. package/android/src/main/cpp/jni/module/sherpa-onnx-module-jni.cpp +21 -0
  20. package/android/src/main/java/com/sherpaonnx/SherpaOnnxArchiveHelper.kt +110 -35
  21. package/android/src/main/java/com/sherpaonnx/SherpaOnnxAssetHelper.kt +6 -0
  22. package/android/src/main/java/com/sherpaonnx/SherpaOnnxEnhancementHelper.kt +377 -0
  23. package/android/src/main/java/com/sherpaonnx/SherpaOnnxExtractionNotificationHelper.kt +102 -0
  24. package/android/src/main/java/com/sherpaonnx/SherpaOnnxModule.kt +198 -18
  25. package/android/src/main/java/com/sherpaonnx/SherpaOnnxSttHelper.kt +22 -0
  26. package/ios/Resources/model_licenses/asr-models-license-status.csv +1 -0
  27. package/ios/Resources/model_licenses/speech-enhancement-models-license-status.csv +7 -0
  28. package/ios/SherpaOnnx+Assets.mm +5 -0
  29. package/ios/SherpaOnnx+Enhancement.mm +435 -0
  30. package/ios/SherpaOnnx+STT.mm +13 -1
  31. package/ios/SherpaOnnx.mm +87 -17
  32. package/ios/enhancement/sherpa-onnx-enhancement-wrapper.h +85 -0
  33. package/ios/enhancement/sherpa-onnx-enhancement-wrapper.mm +218 -0
  34. package/ios/model_detect/sherpa-onnx-model-detect-enhancement.mm +92 -0
  35. package/ios/model_detect/sherpa-onnx-model-detect-helper.h +5 -0
  36. package/ios/model_detect/sherpa-onnx-model-detect-helper.mm +23 -0
  37. package/ios/model_detect/sherpa-onnx-model-detect-stt.mm +51 -7
  38. package/ios/model_detect/sherpa-onnx-model-detect.h +33 -0
  39. package/ios/model_detect/sherpa-onnx-validate-enhancement.h +30 -0
  40. package/ios/model_detect/sherpa-onnx-validate-enhancement.mm +69 -0
  41. package/ios/model_detect/sherpa-onnx-validate-stt.mm +11 -0
  42. package/ios/stt/sherpa-onnx-stt-wrapper.h +11 -1
  43. package/ios/stt/sherpa-onnx-stt-wrapper.mm +30 -2
  44. package/ios/tts/sherpa-onnx-tts-wrapper.mm +16 -0
  45. package/lib/module/NativeSherpaOnnx.js.map +1 -1
  46. package/lib/module/download/localModels.js +2 -3
  47. package/lib/module/download/localModels.js.map +1 -1
  48. package/lib/module/download/paths.js +2 -1
  49. package/lib/module/download/paths.js.map +1 -1
  50. package/lib/module/download/postDownloadProcessing.js +17 -4
  51. package/lib/module/download/postDownloadProcessing.js.map +1 -1
  52. package/lib/module/enhancement/index.js +63 -48
  53. package/lib/module/enhancement/index.js.map +1 -1
  54. package/lib/module/enhancement/streaming.js +60 -0
  55. package/lib/module/enhancement/streaming.js.map +1 -0
  56. package/lib/module/enhancement/streamingTypes.js +4 -0
  57. package/lib/module/enhancement/streamingTypes.js.map +1 -0
  58. package/lib/module/enhancement/types.js +4 -0
  59. package/lib/module/enhancement/types.js.map +1 -0
  60. package/lib/module/extraction/extractTarBz2.js +2 -2
  61. package/lib/module/extraction/extractTarBz2.js.map +1 -1
  62. package/lib/module/extraction/extractTarZst.js +2 -2
  63. package/lib/module/extraction/extractTarZst.js.map +1 -1
  64. package/lib/module/extraction/index.js +10 -5
  65. package/lib/module/extraction/index.js.map +1 -1
  66. package/lib/module/licenses.js +9 -3
  67. package/lib/module/licenses.js.map +1 -1
  68. package/lib/module/stt/index.js +4 -2
  69. package/lib/module/stt/index.js.map +1 -1
  70. package/lib/module/stt/streaming.js +2 -1
  71. package/lib/module/stt/streaming.js.map +1 -1
  72. package/lib/module/stt/types.js +3 -1
  73. package/lib/module/stt/types.js.map +1 -1
  74. package/lib/module/tts/index.js +4 -2
  75. package/lib/module/tts/index.js.map +1 -1
  76. package/lib/module/tts/streaming.js +3 -1
  77. package/lib/module/tts/streaming.js.map +1 -1
  78. package/lib/typescript/src/NativeSherpaOnnx.d.ts +70 -9
  79. package/lib/typescript/src/NativeSherpaOnnx.d.ts.map +1 -1
  80. package/lib/typescript/src/download/localModels.d.ts.map +1 -1
  81. package/lib/typescript/src/download/paths.d.ts +2 -1
  82. package/lib/typescript/src/download/paths.d.ts.map +1 -1
  83. package/lib/typescript/src/download/postDownloadProcessing.d.ts +9 -0
  84. package/lib/typescript/src/download/postDownloadProcessing.d.ts.map +1 -1
  85. package/lib/typescript/src/enhancement/index.d.ts +9 -46
  86. package/lib/typescript/src/enhancement/index.d.ts.map +1 -1
  87. package/lib/typescript/src/enhancement/streaming.d.ts +6 -0
  88. package/lib/typescript/src/enhancement/streaming.d.ts.map +1 -0
  89. package/lib/typescript/src/enhancement/streamingTypes.d.ts +12 -0
  90. package/lib/typescript/src/enhancement/streamingTypes.d.ts.map +1 -0
  91. package/lib/typescript/src/enhancement/types.d.ts +31 -0
  92. package/lib/typescript/src/enhancement/types.d.ts.map +1 -0
  93. package/lib/typescript/src/extraction/extractTarBz2.d.ts +2 -1
  94. package/lib/typescript/src/extraction/extractTarBz2.d.ts.map +1 -1
  95. package/lib/typescript/src/extraction/extractTarZst.d.ts +2 -1
  96. package/lib/typescript/src/extraction/extractTarZst.d.ts.map +1 -1
  97. package/lib/typescript/src/extraction/index.d.ts +1 -1
  98. package/lib/typescript/src/extraction/index.d.ts.map +1 -1
  99. package/lib/typescript/src/extraction/types.d.ts +12 -0
  100. package/lib/typescript/src/extraction/types.d.ts.map +1 -1
  101. package/lib/typescript/src/licenses.d.ts.map +1 -1
  102. package/lib/typescript/src/stt/index.d.ts +1 -1
  103. package/lib/typescript/src/stt/index.d.ts.map +1 -1
  104. package/lib/typescript/src/stt/streaming.d.ts.map +1 -1
  105. package/lib/typescript/src/stt/types.d.ts +16 -1
  106. package/lib/typescript/src/stt/types.d.ts.map +1 -1
  107. package/lib/typescript/src/tts/index.d.ts.map +1 -1
  108. package/lib/typescript/src/tts/streaming.d.ts.map +1 -1
  109. package/package.json +1 -1
  110. package/scripts/ci/check-model-csvs.sh +27 -2
  111. package/scripts/ci/collect_all_sherpa_model_streams.sh +3 -1
  112. package/scripts/ci/collect_one_sherpa_release_stream.sh +3 -1
  113. package/scripts/ci/sherpa_speech_enhancement_model_release_streams.json +13 -0
  114. package/scripts/ci/update_model_license_csv.sh +17 -17
  115. package/src/NativeSherpaOnnx.ts +108 -10
  116. package/src/download/localModels.ts +1 -3
  117. package/src/download/paths.ts +2 -1
  118. package/src/download/postDownloadProcessing.ts +24 -1
  119. package/src/enhancement/index.ts +120 -58
  120. package/src/enhancement/streaming.ts +105 -0
  121. package/src/enhancement/streamingTypes.ts +14 -0
  122. package/src/enhancement/types.ts +36 -0
  123. package/src/extraction/extractTarBz2.ts +7 -2
  124. package/src/extraction/extractTarZst.ts +7 -2
  125. package/src/extraction/index.ts +29 -6
  126. package/src/extraction/types.ts +16 -0
  127. package/src/licenses.ts +13 -2
  128. package/src/stt/index.ts +8 -7
  129. package/src/stt/streaming.ts +7 -1
  130. package/src/stt/types.ts +18 -0
  131. package/src/tts/index.ts +7 -7
  132. package/src/tts/streaming.ts +6 -3
  133. package/third_party/sherpa-onnx-prebuilt/ANDROID_RELEASE_TAG +1 -1
  134. package/third_party/sherpa-onnx-prebuilt/IOS_RELEASE_TAG +1 -1
@@ -0,0 +1,435 @@
1
+ #import "SherpaOnnx.h"
2
+ #import <React/RCTLog.h>
3
+
4
+ #include "sherpa-onnx-enhancement-wrapper.h"
5
+ #include "sherpa-onnx-model-detect.h"
6
+ #include "sherpa-onnx/c-api/cxx-api.h"
7
+
8
+ #include <memory>
9
+ #include <mutex>
10
+ #include <optional>
11
+ #include <string>
12
+ #include <unordered_map>
13
+ #include <vector>
14
+
15
// State stored for one entry of g_enhancement_instances: the wrapper that
// serves initializeEnhancement / enhanceSamples / enhanceFile for that id.
struct EnhancementInstanceState {
    std::unique_ptr<sherpaonnx::EnhancementWrapper> wrapper = nullptr;
};

// State stored for one entry of g_online_enhancement_instances: the wrapper
// that serves the feed / flush / reset calls for that id.
struct OnlineEnhancementInstanceState {
    std::unique_ptr<sherpaonnx::OnlineEnhancementWrapper> wrapper = nullptr;
};

// Registries keyed by the instanceId string supplied by the caller.
static std::unordered_map<std::string,
                          std::unique_ptr<EnhancementInstanceState>>
    g_enhancement_instances;
static std::unordered_map<std::string,
                          std::unique_ptr<OnlineEnhancementInstanceState>>
    g_online_enhancement_instances;

// Single lock taken around every access to either registry above.
static std::mutex g_enhancement_mutex;
26
+
27
namespace {

// Converts an enhancement model kind into the string stored under the
// "modelType" key: kGtcrn -> "gtcrn", kDpdfNet -> "dpdfnet", and every other
// value -> "unknown".
static NSString *enhancementKindToNSString(sherpaonnx::EnhancementModelKind kind) {
    if (kind == sherpaonnx::EnhancementModelKind::kGtcrn) {
        return @"gtcrn";
    }
    if (kind == sherpaonnx::EnhancementModelKind::kDpdfNet) {
        return @"dpdfnet";
    }
    return @"unknown";
}

// Packs an enhanced-audio result into a dictionary with two keys:
//   "samples"    - NSArray of NSNumber, one per element of r.samples
//   "sampleRate" - NSNumber boxed from r.sampleRate
static NSDictionary *enhancedAudioToDict(const sherpaonnx::EnhancedAudioResult& r) {
    const auto& pcm = r.samples;
    NSMutableArray *boxedSamples = [NSMutableArray arrayWithCapacity:pcm.size()];
    for (auto cursor = pcm.begin(); cursor != pcm.end(); ++cursor) {
        const float value = *cursor;
        [boxedSamples addObject:@(value)];
    }

    NSNumber *rate = @(r.sampleRate);
    return @{
        @"samples": boxedSamples,
        @"sampleRate": rate
    };
}

// Packs a detection result into a mutable dictionary with the keys
// "success", "detectedModels" (array of {type, modelDir}) and "modelType".
// An "error" key is added only when detection failed and carries a message.
static NSDictionary *enhancementDetectResultToDict(const sherpaonnx::EnhancementDetectResult& result) {
    NSMutableArray *models = [NSMutableArray array];
    for (const auto& entry : result.detectedModels) {
        // stringWithUTF8String: returns nil for invalid UTF-8; fall back to "".
        NSString *typeName = [NSString stringWithUTF8String:entry.type.c_str()];
        NSString *directory = [NSString stringWithUTF8String:entry.modelDir.c_str()];
        [models addObject:@{
            @"type": (typeName != nil) ? typeName : @"",
            @"modelDir": (directory != nil) ? directory : @""
        }];
    }

    NSMutableDictionary *payload = [NSMutableDictionary dictionary];
    payload[@"success"] = @(result.ok);
    payload[@"detectedModels"] = models;
    payload[@"modelType"] = enhancementKindToNSString(result.selectedKind);

    const bool failedWithMessage = !result.ok && !result.error.empty();
    if (failedWithMessage) {
        NSString *message = [NSString stringWithUTF8String:result.error.c_str()];
        payload[@"error"] = (message != nil) ? message : @"Enhancement model detection failed";
    }
    return payload;
}

}  // namespace
70
+
71
@implementation SherpaOnnx (Enhancement)

#pragma mark - Private helpers

/// Copies every element of `samples` into a float vector using -floatValue.
/// A nil array yields an empty vector.
static std::vector<float> SherpaOnnxEnhancementFloatsFromArray(NSArray *samples) {
    std::vector<float> floats;
    floats.reserve([samples count]);
    for (NSNumber *number in samples) {
        floats.push_back([number floatValue]);
    }
    return floats;
}

/// Converts a C++ error string to NSString, substituting `fallback` when the
/// string is empty or is not valid UTF-8 (stringWithUTF8String: returns nil).
static NSString *SherpaOnnxEnhancementErrorMessage(const std::string &error, NSString *fallback) {
    if (error.empty()) {
        return fallback;
    }
    NSString *message = [NSString stringWithUTF8String:error.c_str()];
    return message != nil ? message : fallback;
}

#pragma mark - Model detection

/// Runs sherpaonnx::DetectEnhancementModel on `modelDir` and resolves with the
/// dictionary built by enhancementDetectResultToDict.
/// A nil `modelDir` is passed as ""; a nil or empty `modelType` as "auto".
/// Rejects with DETECT_ERROR if an NSException is raised.
- (void)detectEnhancementModel:(NSString *)modelDir
                     modelType:(NSString *)modelType
                       resolve:(RCTPromiseResolveBlock)resolve
                        reject:(RCTPromiseRejectBlock)reject
{
    @try {
        std::string modelDirStr = (modelDir != nil) ? [modelDir UTF8String] : "";
        std::string modelTypeStr =
            (modelType != nil && [modelType length] > 0) ? [modelType UTF8String] : "auto";
        auto result = sherpaonnx::DetectEnhancementModel(modelDirStr, modelTypeStr);
        resolve(enhancementDetectResultToDict(result));
    } @catch (NSException *exception) {
        reject(@"DETECT_ERROR",
               [NSString stringWithFormat:@"Enhancement detect failed: %@", exception.reason],
               nil);
    }
}

#pragma mark - Enhancement instances

/// Creates (or reuses) the instance registered under `instanceId` and
/// initializes its wrapper.
///
/// Defaults: `modelType` -> "auto", `numThreads` -> 1, `debug` -> false,
/// `provider` -> not set.
/// Resolves with {success, detectedModels, modelType, sampleRate}.
/// Rejects with ENHANCEMENT_INIT_ERROR when `instanceId` or `modelDir` is
/// missing, when the wrapper reports failure, or when an NSException is raised.
- (void)initializeEnhancement:(NSString *)instanceId
                     modelDir:(NSString *)modelDir
                    modelType:(NSString *)modelType
                   numThreads:(NSNumber *)numThreads
                     provider:(NSString *)provider
                        debug:(NSNumber *)debug
                      resolve:(RCTPromiseResolveBlock)resolve
                       reject:(RCTPromiseRejectBlock)reject
{
    if (instanceId == nil || [instanceId length] == 0) {
        reject(@"ENHANCEMENT_INIT_ERROR", @"instanceId is required", nil);
        return;
    }
    if (modelDir == nil || [modelDir length] == 0) {
        reject(@"ENHANCEMENT_INIT_ERROR", @"modelDir is required", nil);
        return;
    }

    std::string instanceIdStr = [instanceId UTF8String];
    std::string modelDirStr = [modelDir UTF8String];
    std::string modelTypeStr =
        (modelType != nil && [modelType length] > 0) ? [modelType UTF8String] : "auto";
    int32_t numThreadsVal = numThreads != nil ? [numThreads intValue] : 1;
    bool debugVal = debug != nil && [debug boolValue];
    std::optional<std::string> providerOpt = std::nullopt;
    if (provider != nil && [provider length] > 0) {
        providerOpt = std::string([provider UTF8String]);
    }

    @try {
        std::lock_guard<std::mutex> lock(g_enhancement_mutex);

        auto it = g_enhancement_instances.find(instanceIdStr);
        // Remember whether this call registered the entry so that a failed
        // initialization does not leave a half-initialized instance behind.
        const bool createdEntry = (it == g_enhancement_instances.end());
        if (createdEntry) {
            it = g_enhancement_instances
                     .emplace(instanceIdStr, std::make_unique<EnhancementInstanceState>())
                     .first;
        }
        EnhancementInstanceState *inst = it->second.get();
        if (inst->wrapper == nullptr) {
            inst->wrapper = std::make_unique<sherpaonnx::EnhancementWrapper>();
        }

        auto result = inst->wrapper->initialize(modelDirStr,
                                                modelTypeStr,
                                                numThreadsVal,
                                                providerOpt,
                                                debugVal);

        if (!result.success) {
            NSString *errorMsg =
                SherpaOnnxEnhancementErrorMessage(result.error, @"Failed to initialize enhancement");
            if (createdEntry) {
                // Same teardown sequence as unloadEnhancement.
                inst->wrapper->release();
                g_enhancement_instances.erase(it);
            }
            reject(@"ENHANCEMENT_INIT_ERROR", errorMsg, nil);
            return;
        }

        NSMutableArray *detectedModelsArray = [NSMutableArray array];
        for (const auto& model : result.detectedModels) {
            [detectedModelsArray addObject:@{
                @"type": [NSString stringWithUTF8String:model.type.c_str()] ?: @"",
                @"modelDir": [NSString stringWithUTF8String:model.modelDir.c_str()] ?: @""
            }];
        }

        resolve(@{
            @"success": @YES,
            @"detectedModels": detectedModelsArray,
            @"modelType": [NSString stringWithUTF8String:result.modelType.c_str()] ?: @"unknown",
            @"sampleRate": @(result.sampleRate)
        });
    } @catch (NSException *exception) {
        reject(@"ENHANCEMENT_INIT_ERROR",
               [NSString stringWithFormat:@"Enhancement init failed: %@", exception.reason],
               nil);
    }
}

/// Runs the wrapper registered under `instanceId` on `samples` and resolves
/// with {samples, sampleRate}. `sampleRate` is truncated to int32_t.
/// Rejects with ENHANCEMENT_ERROR when `instanceId` is missing, the instance
/// is not registered, or an NSException is raised.
- (void)enhanceSamples:(NSString *)instanceId
               samples:(NSArray *)samples
            sampleRate:(double)sampleRate
               resolve:(RCTPromiseResolveBlock)resolve
                reject:(RCTPromiseRejectBlock)reject
{
    if (instanceId == nil || [instanceId length] == 0) {
        reject(@"ENHANCEMENT_ERROR", @"instanceId is required", nil);
        return;
    }

    std::string instanceIdStr = [instanceId UTF8String];
    std::vector<float> floatSamples = SherpaOnnxEnhancementFloatsFromArray(samples);

    @try {
        std::lock_guard<std::mutex> lock(g_enhancement_mutex);
        auto it = g_enhancement_instances.find(instanceIdStr);
        if (it == g_enhancement_instances.end() || it->second->wrapper == nullptr) {
            reject(@"ENHANCEMENT_ERROR", @"Enhancement instance not found", nil);
            return;
        }
        auto out = it->second->wrapper->runSamples(floatSamples, static_cast<int32_t>(sampleRate));
        resolve(enhancedAudioToDict(out));
    } @catch (NSException *exception) {
        reject(@"ENHANCEMENT_ERROR",
               [NSString stringWithFormat:@"Enhance samples failed: %@", exception.reason],
               nil);
    }
}

/// Reads the wave file at `inputPath`, runs it through the wrapper registered
/// under `instanceId`, optionally writes the result to `outputPath`, and
/// resolves with {samples, sampleRate}.
///
/// Rejects with ENHANCEMENT_ERROR when an argument is missing, the input
/// cannot be read (no samples or non-positive sample rate), the instance is
/// not registered, the output file cannot be written, or an NSException is
/// raised.
- (void)enhanceFile:(NSString *)instanceId
          inputPath:(NSString *)inputPath
         outputPath:(NSString *)outputPath
            resolve:(RCTPromiseResolveBlock)resolve
             reject:(RCTPromiseRejectBlock)reject
{
    if (instanceId == nil || [instanceId length] == 0) {
        reject(@"ENHANCEMENT_ERROR", @"instanceId is required", nil);
        return;
    }
    if (inputPath == nil || [inputPath length] == 0) {
        reject(@"ENHANCEMENT_ERROR", @"inputPath is required", nil);
        return;
    }

    std::string instanceIdStr = [instanceId UTF8String];
    std::string inputPathStr = [inputPath UTF8String];

    @try {
        // The file is read before the lock is taken, so disk I/O on the input
        // does not block other enhancement calls.
        sherpa_onnx::cxx::Wave wave = sherpa_onnx::cxx::ReadWave(inputPathStr);
        if (wave.samples.empty() || wave.sample_rate <= 0) {
            reject(@"ENHANCEMENT_ERROR", @"Failed to read input wave file", nil);
            return;
        }

        std::lock_guard<std::mutex> lock(g_enhancement_mutex);
        auto it = g_enhancement_instances.find(instanceIdStr);
        if (it == g_enhancement_instances.end() || it->second->wrapper == nullptr) {
            reject(@"ENHANCEMENT_ERROR", @"Enhancement instance not found", nil);
            return;
        }
        auto out = it->second->wrapper->runSamples(wave.samples, wave.sample_rate);

        if (outputPath != nil && [outputPath length] > 0) {
            sherpa_onnx::cxx::Wave outputWave;
            outputWave.samples = out.samples;
            outputWave.sample_rate = out.sampleRate;
            std::string outputPathStr = [outputPath UTF8String];
            // WriteWave reports failure through its return value; surface it
            // instead of resolving as if the file had been written.
            if (!sherpa_onnx::cxx::WriteWave(outputPathStr, outputWave)) {
                reject(@"ENHANCEMENT_ERROR", @"Failed to write output wave file", nil);
                return;
            }
        }

        resolve(enhancedAudioToDict(out));
    } @catch (NSException *exception) {
        reject(@"ENHANCEMENT_ERROR",
               [NSString stringWithFormat:@"Enhance file failed: %@", exception.reason],
               nil);
    }
}

/// Resolves with the value of getSampleRate() for the wrapper registered under
/// `instanceId`. Rejects with ENHANCEMENT_ERROR when `instanceId` is missing
/// or the instance is not registered.
- (void)getEnhancementSampleRate:(NSString *)instanceId
                         resolve:(RCTPromiseResolveBlock)resolve
                          reject:(RCTPromiseRejectBlock)reject
{
    if (instanceId == nil || [instanceId length] == 0) {
        reject(@"ENHANCEMENT_ERROR", @"instanceId is required", nil);
        return;
    }
    std::string instanceIdStr = [instanceId UTF8String];

    std::lock_guard<std::mutex> lock(g_enhancement_mutex);
    auto it = g_enhancement_instances.find(instanceIdStr);
    if (it == g_enhancement_instances.end() || it->second->wrapper == nullptr) {
        reject(@"ENHANCEMENT_ERROR", @"Enhancement instance not found", nil);
        return;
    }
    resolve(@(it->second->wrapper->getSampleRate()));
}

/// Releases and unregisters the instance stored under `instanceId`.
/// Always resolves with nil, including when the id is missing or unknown.
- (void)unloadEnhancement:(NSString *)instanceId
                  resolve:(RCTPromiseResolveBlock)resolve
                   reject:(RCTPromiseRejectBlock)reject
{
    if (instanceId == nil || [instanceId length] == 0) {
        resolve(nil);
        return;
    }
    std::string instanceIdStr = [instanceId UTF8String];
    std::lock_guard<std::mutex> lock(g_enhancement_mutex);
    auto it = g_enhancement_instances.find(instanceIdStr);
    if (it != g_enhancement_instances.end() && it->second->wrapper != nullptr) {
        it->second->wrapper->release();
        g_enhancement_instances.erase(it);
    }
    resolve(nil);
}

#pragma mark - Online enhancement instances

/// Creates (or reuses) the online instance registered under `instanceId` and
/// initializes its wrapper.
///
/// Defaults: `modelType` -> "auto", `numThreads` -> 1, `debug` -> false,
/// `provider` -> not set.
/// Resolves with {success, sampleRate, frameShiftInSamples}.
/// Rejects with ONLINE_ENHANCEMENT_INIT_ERROR when `instanceId` or `modelDir`
/// is missing, when the wrapper reports failure, or when an NSException is
/// raised.
- (void)initializeOnlineEnhancement:(NSString *)instanceId
                           modelDir:(NSString *)modelDir
                          modelType:(NSString *)modelType
                         numThreads:(NSNumber *)numThreads
                           provider:(NSString *)provider
                              debug:(NSNumber *)debug
                            resolve:(RCTPromiseResolveBlock)resolve
                             reject:(RCTPromiseRejectBlock)reject
{
    if (instanceId == nil || [instanceId length] == 0) {
        reject(@"ONLINE_ENHANCEMENT_INIT_ERROR", @"instanceId is required", nil);
        return;
    }
    if (modelDir == nil || [modelDir length] == 0) {
        reject(@"ONLINE_ENHANCEMENT_INIT_ERROR", @"modelDir is required", nil);
        return;
    }

    std::string instanceIdStr = [instanceId UTF8String];
    std::string modelDirStr = [modelDir UTF8String];
    std::string modelTypeStr =
        (modelType != nil && [modelType length] > 0) ? [modelType UTF8String] : "auto";
    int32_t numThreadsVal = numThreads != nil ? [numThreads intValue] : 1;
    bool debugVal = debug != nil && [debug boolValue];
    std::optional<std::string> providerOpt = std::nullopt;
    if (provider != nil && [provider length] > 0) {
        providerOpt = std::string([provider UTF8String]);
    }

    @try {
        std::lock_guard<std::mutex> lock(g_enhancement_mutex);

        auto it = g_online_enhancement_instances.find(instanceIdStr);
        // Remember whether this call registered the entry so that a failed
        // initialization does not leave a half-initialized instance behind.
        const bool createdEntry = (it == g_online_enhancement_instances.end());
        if (createdEntry) {
            it = g_online_enhancement_instances
                     .emplace(instanceIdStr, std::make_unique<OnlineEnhancementInstanceState>())
                     .first;
        }
        OnlineEnhancementInstanceState *inst = it->second.get();
        if (inst->wrapper == nullptr) {
            inst->wrapper = std::make_unique<sherpaonnx::OnlineEnhancementWrapper>();
        }

        auto result = inst->wrapper->initialize(modelDirStr,
                                                modelTypeStr,
                                                numThreadsVal,
                                                providerOpt,
                                                debugVal);
        if (!result.success) {
            NSString *errorMsg = SherpaOnnxEnhancementErrorMessage(
                result.error, @"Failed to initialize online enhancement");
            if (createdEntry) {
                // Same teardown sequence as unloadOnlineEnhancement.
                inst->wrapper->release();
                g_online_enhancement_instances.erase(it);
            }
            reject(@"ONLINE_ENHANCEMENT_INIT_ERROR", errorMsg, nil);
            return;
        }

        resolve(@{
            @"success": @YES,
            @"sampleRate": @(result.sampleRate),
            @"frameShiftInSamples": @(result.frameShiftInSamples)
        });
    } @catch (NSException *exception) {
        reject(@"ONLINE_ENHANCEMENT_INIT_ERROR",
               [NSString stringWithFormat:@"Online enhancement init failed: %@", exception.reason],
               nil);
    }
}

/// Passes `samples` to runSamples() on the online wrapper registered under
/// `instanceId` and resolves with {samples, sampleRate}. `sampleRate` is
/// truncated to int32_t.
/// Rejects with ONLINE_ENHANCEMENT_ERROR when `instanceId` is missing or the
/// instance is not registered.
- (void)feedEnhancementSamples:(NSString *)instanceId
                       samples:(NSArray *)samples
                    sampleRate:(double)sampleRate
                       resolve:(RCTPromiseResolveBlock)resolve
                        reject:(RCTPromiseRejectBlock)reject
{
    if (instanceId == nil || [instanceId length] == 0) {
        reject(@"ONLINE_ENHANCEMENT_ERROR", @"instanceId is required", nil);
        return;
    }
    std::string instanceIdStr = [instanceId UTF8String];
    std::vector<float> floatSamples = SherpaOnnxEnhancementFloatsFromArray(samples);

    std::lock_guard<std::mutex> lock(g_enhancement_mutex);
    auto it = g_online_enhancement_instances.find(instanceIdStr);
    if (it == g_online_enhancement_instances.end() || it->second->wrapper == nullptr) {
        reject(@"ONLINE_ENHANCEMENT_ERROR", @"Online enhancement instance not found", nil);
        return;
    }
    auto out = it->second->wrapper->runSamples(floatSamples, static_cast<int32_t>(sampleRate));
    resolve(enhancedAudioToDict(out));
}

/// Calls flush() on the online wrapper registered under `instanceId` and
/// resolves with {samples, sampleRate}.
/// Rejects with ONLINE_ENHANCEMENT_ERROR when `instanceId` is missing or the
/// instance is not registered.
- (void)flushOnlineEnhancement:(NSString *)instanceId
                       resolve:(RCTPromiseResolveBlock)resolve
                        reject:(RCTPromiseRejectBlock)reject
{
    if (instanceId == nil || [instanceId length] == 0) {
        reject(@"ONLINE_ENHANCEMENT_ERROR", @"instanceId is required", nil);
        return;
    }
    std::string instanceIdStr = [instanceId UTF8String];
    std::lock_guard<std::mutex> lock(g_enhancement_mutex);
    auto it = g_online_enhancement_instances.find(instanceIdStr);
    if (it == g_online_enhancement_instances.end() || it->second->wrapper == nullptr) {
        reject(@"ONLINE_ENHANCEMENT_ERROR", @"Online enhancement instance not found", nil);
        return;
    }
    auto out = it->second->wrapper->flush();
    resolve(enhancedAudioToDict(out));
}

/// Calls reset() on the online wrapper registered under `instanceId` and
/// resolves with nil. A missing `instanceId` also resolves with nil; an
/// unknown one rejects with ONLINE_ENHANCEMENT_ERROR.
- (void)resetOnlineEnhancement:(NSString *)instanceId
                       resolve:(RCTPromiseResolveBlock)resolve
                        reject:(RCTPromiseRejectBlock)reject
{
    if (instanceId == nil || [instanceId length] == 0) {
        resolve(nil);
        return;
    }
    std::string instanceIdStr = [instanceId UTF8String];
    std::lock_guard<std::mutex> lock(g_enhancement_mutex);
    auto it = g_online_enhancement_instances.find(instanceIdStr);
    if (it == g_online_enhancement_instances.end() || it->second->wrapper == nullptr) {
        reject(@"ONLINE_ENHANCEMENT_ERROR", @"Online enhancement instance not found", nil);
        return;
    }
    it->second->wrapper->reset();
    resolve(nil);
}

/// Releases and unregisters the online instance stored under `instanceId`.
/// Always resolves with nil, including when the id is missing or unknown.
- (void)unloadOnlineEnhancement:(NSString *)instanceId
                        resolve:(RCTPromiseResolveBlock)resolve
                         reject:(RCTPromiseRejectBlock)reject
{
    if (instanceId == nil || [instanceId length] == 0) {
        resolve(nil);
        return;
    }
    std::string instanceIdStr = [instanceId UTF8String];
    std::lock_guard<std::mutex> lock(g_enhancement_mutex);
    auto it = g_online_enhancement_instances.find(instanceIdStr);
    if (it != g_online_enhancement_instances.end() && it->second->wrapper != nullptr) {
        it->second->wrapper->release();
        g_online_enhancement_instances.erase(it);
    }
    resolve(nil);
}

@end
@@ -36,6 +36,7 @@ static NSString *sttModelKindToNSString(sherpaonnx::SttModelKind kind) {
36
36
  case K::kZipformerCtc: return @"zipformer_ctc";
37
37
  case K::kWhisper: return @"whisper";
38
38
  case K::kFunAsrNano: return @"funasr_nano";
39
+ case K::kQwen3Asr: return @"qwen3_asr";
39
40
  case K::kFireRedAsr: return @"fire_red_asr";
40
41
  case K::kMoonshine: return @"moonshine";
41
42
  case K::kMoonshineV2: return @"moonshine_v2";
@@ -164,10 +165,12 @@ static NSDictionary *sttResultToDict(const sherpaonnx::SttRecognitionResult& r)
164
165
  sherpaonnx::SttSenseVoiceOptions senseVoiceOpts;
165
166
  sherpaonnx::SttCanaryOptions canaryOpts;
166
167
  sherpaonnx::SttFunAsrNanoOptions funasrNanoOpts;
168
+ sherpaonnx::SttQwen3AsrOptions qwen3AsrOpts;
167
169
  const sherpaonnx::SttWhisperOptions *whisperOptsPtr = nullptr;
168
170
  const sherpaonnx::SttSenseVoiceOptions *senseVoiceOptsPtr = nullptr;
169
171
  const sherpaonnx::SttCanaryOptions *canaryOptsPtr = nullptr;
170
172
  const sherpaonnx::SttFunAsrNanoOptions *funasrNanoOptsPtr = nullptr;
173
+ const sherpaonnx::SttQwen3AsrOptions *qwen3AsrOptsPtr = nullptr;
171
174
  if (modelOptions != nil && [modelOptions isKindOfClass:[NSDictionary class]]) {
172
175
  NSDictionary *w = modelOptions[@"whisper"];
173
176
  if ([w isKindOfClass:[NSDictionary class]]) {
@@ -202,12 +205,21 @@ static NSDictionary *sttResultToDict(const sherpaonnx::SttRecognitionResult& r)
202
205
  if (fn[@"hotwords"] != nil) funasrNanoOpts.hotwords = std::string([(NSString *)fn[@"hotwords"] UTF8String]);
203
206
  funasrNanoOptsPtr = &funasrNanoOpts;
204
207
  }
208
+ NSDictionary *q3 = modelOptions[@"qwen3Asr"];
209
+ if ([q3 isKindOfClass:[NSDictionary class]]) {
210
+ if (q3[@"maxTotalLen"] != nil) qwen3AsrOpts.max_total_len = [(NSNumber *)q3[@"maxTotalLen"] intValue];
211
+ if (q3[@"maxNewTokens"] != nil) qwen3AsrOpts.max_new_tokens = [(NSNumber *)q3[@"maxNewTokens"] intValue];
212
+ if (q3[@"temperature"] != nil) qwen3AsrOpts.temperature = [(NSNumber *)q3[@"temperature"] floatValue];
213
+ if (q3[@"topP"] != nil) qwen3AsrOpts.top_p = [(NSNumber *)q3[@"topP"] floatValue];
214
+ if (q3[@"seed"] != nil) qwen3AsrOpts.seed = [(NSNumber *)q3[@"seed"] intValue];
215
+ qwen3AsrOptsPtr = &qwen3AsrOpts;
216
+ }
205
217
  }
206
218
 
207
219
  sherpaonnx::SttInitializeResult result = inst->wrapper->initialize(
208
220
  modelDirStr, preferInt8Opt, modelTypeOpt, debugVal, hotwordsFileOpt, hotwordsScoreOpt,
209
221
  numThreadsOpt, providerOpt, ruleFstsOpt, ruleFarsOpt, ditherOpt,
210
- whisperOptsPtr, senseVoiceOptsPtr, canaryOptsPtr, funasrNanoOptsPtr);
222
+ whisperOptsPtr, senseVoiceOptsPtr, canaryOptsPtr, funasrNanoOptsPtr, qwen3AsrOptsPtr);
211
223
 
212
224
  if (result.success) {
213
225
  RCTLogInfo(@"Sherpa-onnx initialized successfully");
package/ios/SherpaOnnx.mm CHANGED
@@ -138,9 +138,15 @@
138
138
  - (void)extractTarBz2:(NSString *)sourcePath
139
139
  targetPath:(NSString *)targetPath
140
140
  force:(BOOL)force
141
- resolve:(RCTPromiseResolveBlock)resolve
142
- reject:(RCTPromiseRejectBlock)reject
141
+ showNotificationsEnabled:(NSNumber *)showNotificationsEnabled
142
+ notificationTitle:(NSString *)notificationTitle
143
+ notificationText:(NSString *)notificationText
144
+ resolve:(RCTPromiseResolveBlock)resolve
145
+ reject:(RCTPromiseRejectBlock)reject
143
146
  {
147
+ (void)showNotificationsEnabled;
148
+ (void)notificationTitle;
149
+ (void)notificationText;
144
150
  SherpaOnnxArchiveHelper *helper = [SherpaOnnxArchiveHelper new];
145
151
  NSDictionary *result = [helper extractTarBz2:sourcePath
146
152
  targetPath:targetPath
@@ -165,9 +171,15 @@
165
171
  - (void)extractTarZst:(NSString *)sourcePath
166
172
  targetPath:(NSString *)targetPath
167
173
  force:(BOOL)force
168
- resolve:(RCTPromiseResolveBlock)resolve
169
- reject:(RCTPromiseRejectBlock)reject
174
+ showNotificationsEnabled:(NSNumber *)showNotificationsEnabled
175
+ notificationTitle:(NSString *)notificationTitle
176
+ notificationText:(NSString *)notificationText
177
+ resolve:(RCTPromiseResolveBlock)resolve
178
+ reject:(RCTPromiseRejectBlock)reject
170
179
  {
180
+ (void)showNotificationsEnabled;
181
+ (void)notificationTitle;
182
+ (void)notificationText;
171
183
  SherpaOnnxArchiveHelper *helper = [SherpaOnnxArchiveHelper new];
172
184
  NSDictionary *result = [helper extractTarZst:sourcePath
173
185
  targetPath:targetPath
@@ -229,19 +241,33 @@
229
241
 
230
242
  - (void)extractTarZstFromAsset:(NSString *)assetPath
231
243
  targetPath:(NSString *)targetPath
232
- force:(NSNumber *)force
233
- resolve:(RCTPromiseResolveBlock)resolve
234
- reject:(RCTPromiseRejectBlock)reject
244
+ force:(BOOL)force
245
+ showNotificationsEnabled:(NSNumber *)showNotificationsEnabled
246
+ notificationTitle:(NSString *)notificationTitle
247
+ notificationText:(NSString *)notificationText
248
+ resolve:(RCTPromiseResolveBlock)resolve
249
+ reject:(RCTPromiseRejectBlock)reject
235
250
  {
251
+ (void)force;
252
+ (void)showNotificationsEnabled;
253
+ (void)notificationTitle;
254
+ (void)notificationText;
236
255
  resolve(@{ @"success": @NO, @"reason": @"Not supported on iOS; use path-based extraction." });
237
256
  }
238
257
 
239
258
  - (void)extractTarBz2FromAsset:(NSString *)assetPath
240
259
  targetPath:(NSString *)targetPath
241
- force:(NSNumber *)force
242
- resolve:(RCTPromiseResolveBlock)resolve
243
- reject:(RCTPromiseRejectBlock)reject
260
+ force:(BOOL)force
261
+ showNotificationsEnabled:(NSNumber *)showNotificationsEnabled
262
+ notificationTitle:(NSString *)notificationTitle
263
+ notificationText:(NSString *)notificationText
264
+ resolve:(RCTPromiseResolveBlock)resolve
265
+ reject:(RCTPromiseRejectBlock)reject
244
266
  {
267
+ (void)force;
268
+ (void)showNotificationsEnabled;
269
+ (void)notificationTitle;
270
+ (void)notificationText;
245
271
  resolve(@{ @"success": @NO, @"reason": @"Not supported on iOS; use path-based extraction." });
246
272
  }
247
273
 
@@ -329,15 +355,59 @@
329
355
  nil);
330
356
  return;
331
357
  }
332
- NSString *resourcePath = [[NSBundle mainBundle] resourcePath];
333
- NSString *fullPath = [resourcePath stringByAppendingPathComponent:assetPath];
358
+ NSString *fullPath = nil;
359
+ NSBundle *mainBundle = [NSBundle mainBundle];
360
+ NSString *assetDir = [assetPath stringByDeletingLastPathComponent];
361
+ NSString *assetNameWithExt = [assetPath lastPathComponent];
362
+ NSString *assetName = [assetNameWithExt stringByDeletingPathExtension];
363
+ NSString *assetExt = [assetNameWithExt pathExtension];
364
+
365
+ // 1) App bundle: regular nested path (keeps generic asset support)
366
+ NSString *mainPath = [mainBundle pathForResource:assetName
367
+ ofType:assetExt.length > 0 ? assetExt : nil
368
+ inDirectory:assetDir.length > 0 ? assetDir : nil];
369
+ if (mainPath.length > 0) {
370
+ fullPath = mainPath;
371
+ }
372
+
373
+ // 2) CocoaPods resource bundle: files are flattened into bundle root
374
+ if (!fullPath) {
375
+ NSString *resBundlePath = [mainBundle pathForResource:@"SherpaOnnxResources"
376
+ ofType:@"bundle"];
377
+ if (resBundlePath.length > 0) {
378
+ NSBundle *resBundle = [NSBundle bundleWithPath:resBundlePath];
379
+ if (resBundle) {
380
+ NSString *bundleRootPath = [resBundle pathForResource:assetName
381
+ ofType:assetExt.length > 0 ? assetExt : nil];
382
+ if (bundleRootPath.length > 0) {
383
+ fullPath = bundleRootPath;
384
+ }
385
+ }
386
+ }
387
+ }
388
+
389
+ if (!fullPath) {
390
+ reject(@"ASSET_READ_ERROR",
391
+ [NSString stringWithFormat:@"Failed to locate asset %@", assetPath],
392
+ nil);
393
+ return;
394
+ }
395
+
334
396
  NSError *error = nil;
335
- NSString *content = [NSString stringWithContentsOfFile:fullPath encoding:NSUTF8StringEncoding error:&error];
336
- if (error) {
337
- reject(@"ASSET_READ_ERROR", [NSString stringWithFormat:@"Failed to read asset %@: %@", assetPath, error.localizedDescription], error);
338
- } else {
339
- resolve(content);
397
+ NSString *content = [NSString stringWithContentsOfFile:fullPath
398
+ encoding:NSUTF8StringEncoding
399
+ error:&error];
400
+ if (error || content == nil) {
401
+ reject(@"ASSET_READ_ERROR",
402
+ [NSString stringWithFormat:@"Failed to read asset %@ at %@: %@",
403
+ assetPath,
404
+ fullPath,
405
+ error.localizedDescription ?: @"Unknown error"],
406
+ error);
407
+ return;
340
408
  }
409
+
410
+ resolve(content);
341
411
  }
342
412
 
343
413
  @end