react-native-sherpa-onnx 0.3.9 → 0.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70) hide show
  1. package/README.md +17 -4
  2. package/SherpaOnnx.podspec +1 -0
  3. package/android/prebuilt-download.gradle +67 -27
  4. package/android/prebuilt-versions.gradle +1 -1
  5. package/android/src/main/assets/model_licenses/speech-enhancement-models-license-status.csv +7 -0
  6. package/android/src/main/cpp/CMakeLists.txt +3 -0
  7. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-enhancement-wrapper.cpp +68 -0
  8. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-enhancement-wrapper.h +17 -0
  9. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-enhancement.cpp +119 -0
  10. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect.h +31 -0
  11. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-validate-enhancement.cpp +68 -0
  12. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-validate-enhancement.h +30 -0
  13. package/android/src/main/cpp/jni/module/sherpa-onnx-module-jni.cpp +21 -0
  14. package/android/src/main/java/com/sherpaonnx/SherpaOnnxAssetHelper.kt +6 -0
  15. package/android/src/main/java/com/sherpaonnx/SherpaOnnxEnhancementHelper.kt +377 -0
  16. package/android/src/main/java/com/sherpaonnx/SherpaOnnxModule.kt +106 -0
  17. package/android/src/main/java/com/sherpaonnx/SherpaOnnxTtsHelper.kt +66 -13
  18. package/ios/Resources/model_licenses/speech-enhancement-models-license-status.csv +7 -0
  19. package/ios/SherpaOnnx+Assets.mm +5 -0
  20. package/ios/SherpaOnnx+Enhancement.mm +435 -0
  21. package/ios/enhancement/sherpa-onnx-enhancement-wrapper.h +85 -0
  22. package/ios/enhancement/sherpa-onnx-enhancement-wrapper.mm +218 -0
  23. package/ios/model_detect/sherpa-onnx-model-detect-enhancement.mm +92 -0
  24. package/ios/model_detect/sherpa-onnx-model-detect.h +23 -0
  25. package/ios/model_detect/sherpa-onnx-validate-enhancement.h +30 -0
  26. package/ios/model_detect/sherpa-onnx-validate-enhancement.mm +69 -0
  27. package/lib/module/NativeSherpaOnnx.js.map +1 -1
  28. package/lib/module/download/localModels.js +2 -3
  29. package/lib/module/download/localModels.js.map +1 -1
  30. package/lib/module/download/paths.js +2 -1
  31. package/lib/module/download/paths.js.map +1 -1
  32. package/lib/module/enhancement/index.js +63 -48
  33. package/lib/module/enhancement/index.js.map +1 -1
  34. package/lib/module/enhancement/streaming.js +60 -0
  35. package/lib/module/enhancement/streaming.js.map +1 -0
  36. package/lib/module/enhancement/streamingTypes.js +4 -0
  37. package/lib/module/enhancement/streamingTypes.js.map +1 -0
  38. package/lib/module/enhancement/types.js +4 -0
  39. package/lib/module/enhancement/types.js.map +1 -0
  40. package/lib/module/licenses.js +9 -3
  41. package/lib/module/licenses.js.map +1 -1
  42. package/lib/typescript/src/NativeSherpaOnnx.d.ts +45 -0
  43. package/lib/typescript/src/NativeSherpaOnnx.d.ts.map +1 -1
  44. package/lib/typescript/src/download/localModels.d.ts.map +1 -1
  45. package/lib/typescript/src/download/paths.d.ts +2 -1
  46. package/lib/typescript/src/download/paths.d.ts.map +1 -1
  47. package/lib/typescript/src/enhancement/index.d.ts +9 -46
  48. package/lib/typescript/src/enhancement/index.d.ts.map +1 -1
  49. package/lib/typescript/src/enhancement/streaming.d.ts +6 -0
  50. package/lib/typescript/src/enhancement/streaming.d.ts.map +1 -0
  51. package/lib/typescript/src/enhancement/streamingTypes.d.ts +12 -0
  52. package/lib/typescript/src/enhancement/streamingTypes.d.ts.map +1 -0
  53. package/lib/typescript/src/enhancement/types.d.ts +31 -0
  54. package/lib/typescript/src/enhancement/types.d.ts.map +1 -0
  55. package/lib/typescript/src/licenses.d.ts.map +1 -1
  56. package/package.json +1 -1
  57. package/scripts/ci/check-model-csvs.sh +27 -2
  58. package/scripts/ci/collect_all_sherpa_model_streams.sh +3 -1
  59. package/scripts/ci/collect_one_sherpa_release_stream.sh +3 -1
  60. package/scripts/ci/sherpa_speech_enhancement_model_release_streams.json +13 -0
  61. package/scripts/ci/update_model_license_csv.sh +1 -1
  62. package/src/NativeSherpaOnnx.ts +71 -0
  63. package/src/download/localModels.ts +1 -3
  64. package/src/download/paths.ts +2 -1
  65. package/src/enhancement/index.ts +120 -58
  66. package/src/enhancement/streaming.ts +105 -0
  67. package/src/enhancement/streamingTypes.ts +14 -0
  68. package/src/enhancement/types.ts +36 -0
  69. package/src/licenses.ts +13 -2
  70. package/third_party/sherpa-onnx-prebuilt/ANDROID_RELEASE_TAG +1 -1
@@ -0,0 +1,435 @@
1
+ #import "SherpaOnnx.h"
2
+ #import <React/RCTLog.h>
3
+
4
+ #include "sherpa-onnx-enhancement-wrapper.h"
5
+ #include "sherpa-onnx-model-detect.h"
6
+ #include "sherpa-onnx/c-api/cxx-api.h"
7
+
8
+ #include <memory>
9
+ #include <mutex>
10
+ #include <optional>
11
+ #include <string>
12
+ #include <unordered_map>
13
+ #include <vector>
14
+
15
// Bookkeeping for the enhancement bridge. The state types are only used inside this file,
// so they live in an anonymous namespace (internal linkage) instead of at global scope,
// which rules out clashes with identically named types in other translation units.
namespace {

/// Owns the offline enhancement wrapper registered under one instance id.
struct EnhancementInstanceState {
    std::unique_ptr<sherpaonnx::EnhancementWrapper> wrapper;
};

/// Owns the online enhancement wrapper registered under one instance id.
struct OnlineEnhancementInstanceState {
    std::unique_ptr<sherpaonnx::OnlineEnhancementWrapper> wrapper;
};

}  // namespace

/// Offline instances, keyed by the instanceId string passed in by the caller.
static std::unordered_map<std::string, std::unique_ptr<EnhancementInstanceState>> g_enhancement_instances;

/// Online instances, keyed by the instanceId string passed in by the caller.
static std::unordered_map<std::string, std::unique_ptr<OnlineEnhancementInstanceState>> g_online_enhancement_instances;

/// Single mutex guarding both maps above; the methods in this file lock it before any lookup.
static std::mutex g_enhancement_mutex;
26
+
27
+ namespace {
28
+
29
/// Converts an enhancement model kind into the string tag placed in result dictionaries.
/// Any kind other than kGtcrn / kDpdfNet maps to "unknown".
static NSString *enhancementKindToNSString(sherpaonnx::EnhancementModelKind kind) {
    if (kind == sherpaonnx::EnhancementModelKind::kGtcrn) {
        return @"gtcrn";
    }
    if (kind == sherpaonnx::EnhancementModelKind::kDpdfNet) {
        return @"dpdfnet";
    }
    return @"unknown";
}
37
+
38
/// Packs an enhanced-audio result into a dictionary with two keys:
/// "samples" (array of boxed floats) and "sampleRate" (boxed integer).
static NSDictionary *enhancedAudioToDict(const sherpaonnx::EnhancedAudioResult& r) {
    const auto &pcm = r.samples;

    // Box every float sample; capacity is known up front.
    NSMutableArray *boxedSamples = [NSMutableArray arrayWithCapacity:pcm.size()];
    for (size_t i = 0; i < pcm.size(); ++i) {
        [boxedSamples addObject:[NSNumber numberWithFloat:pcm[i]]];
    }

    NSNumber *rate = @(r.sampleRate);
    return @{
        @"sampleRate": rate,
        @"samples": boxedSamples
    };
}
48
+
49
/// Converts a detection result into the dictionary returned to the promise.
/// Keys: "success", "detectedModels" (array of {type, modelDir}), "modelType", and,
/// only when detection failed with a non-empty message, "error".
static NSDictionary *enhancementDetectResultToDict(const sherpaonnx::EnhancementDetectResult& result) {
    NSMutableArray *models = [NSMutableArray array];
    for (const auto& entry : result.detectedModels) {
        // stringWithUTF8String: yields nil for invalid UTF-8, hence the empty-string fallbacks.
        NSString *typeName = [NSString stringWithUTF8String:entry.type.c_str()];
        NSString *directory = [NSString stringWithUTF8String:entry.modelDir.c_str()];
        [models addObject:@{
            @"type": typeName ?: @"",
            @"modelDir": directory ?: @""
        }];
    }

    NSMutableDictionary *payload = [NSMutableDictionary dictionary];
    payload[@"success"] = @(result.ok);
    payload[@"detectedModels"] = models;
    payload[@"modelType"] = enhancementKindToNSString(result.selectedKind);

    const bool hasErrorText = !result.ok && !result.error.empty();
    if (hasErrorText) {
        NSString *message = [NSString stringWithUTF8String:result.error.c_str()];
        payload[@"error"] = message ?: @"Enhancement model detection failed";
    }
    return payload;
}
68
+
69
+ } // namespace
70
+
71
+ @implementation SherpaOnnx (Enhancement)
72
+
73
/// Runs enhancement-model detection on `modelDir` and resolves with the detection dictionary.
/// A nil `modelDir` is passed on as an empty path; a nil/empty `modelType` becomes "auto".
/// Rejects with DETECT_ERROR if an NSException is raised along the way.
- (void)detectEnhancementModel:(NSString *)modelDir
                     modelType:(NSString *)modelType
                       resolve:(RCTPromiseResolveBlock)resolve
                        reject:(RCTPromiseRejectBlock)reject
{
    @try {
        std::string directory;
        if (modelDir != nil) {
            directory = [modelDir UTF8String];
        }

        std::string requestedType = "auto";
        if (modelType != nil && [modelType length] > 0) {
            requestedType = [modelType UTF8String];
        }

        auto detection = sherpaonnx::DetectEnhancementModel(directory, requestedType);
        resolve(enhancementDetectResultToDict(detection));
    } @catch (NSException *exception) {
        NSString *message =
            [NSString stringWithFormat:@"Enhancement detect failed: %@", exception.reason];
        reject(@"DETECT_ERROR", message, nil);
    }
}
89
+
90
/// Creates (or reuses) the offline enhancement wrapper registered under `instanceId` and
/// initializes it from `modelDir`.
///
/// Defaults: nil/empty `modelType` -> "auto", nil `numThreads` -> 1, nil `debug` -> false,
/// nil/empty `provider` -> no provider override.
///
/// Resolves with {success, detectedModels, modelType, sampleRate}.
/// Rejects with ENHANCEMENT_INIT_ERROR when `instanceId`/`modelDir` is missing, when the
/// wrapper reports failure, or when an NSException is raised.
- (void)initializeEnhancement:(NSString *)instanceId
                     modelDir:(NSString *)modelDir
                    modelType:(NSString *)modelType
                   numThreads:(NSNumber *)numThreads
                     provider:(NSString *)provider
                        debug:(NSNumber *)debug
                      resolve:(RCTPromiseResolveBlock)resolve
                       reject:(RCTPromiseRejectBlock)reject
{
    if (instanceId == nil || [instanceId length] == 0) {
        reject(@"ENHANCEMENT_INIT_ERROR", @"instanceId is required", nil);
        return;
    }
    if (modelDir == nil || [modelDir length] == 0) {
        reject(@"ENHANCEMENT_INIT_ERROR", @"modelDir is required", nil);
        return;
    }

    std::string instanceIdStr = [instanceId UTF8String];
    std::string modelDirStr = [modelDir UTF8String];
    std::string modelTypeStr = (modelType != nil && [modelType length] > 0) ? [modelType UTF8String] : "auto";
    int32_t numThreadsVal = numThreads != nil ? [numThreads intValue] : 1;
    bool debugVal = debug != nil && [debug boolValue];
    std::optional<std::string> providerOpt = std::nullopt;
    if (provider != nil && [provider length] > 0) {
        providerOpt = std::string([provider UTF8String]);
    }

    @try {
        std::lock_guard<std::mutex> lock(g_enhancement_mutex);

        // One lookup instead of find + operator[] twice; `second` tells us whether this
        // call created the map entry, which matters for cleanup on failure below.
        const auto inserted = g_enhancement_instances.try_emplace(instanceIdStr);
        const bool createdEntry = inserted.second;
        auto &state = inserted.first->second;
        if (state == nullptr) {
            state = std::make_unique<EnhancementInstanceState>();
        }
        if (state->wrapper == nullptr) {
            state->wrapper = std::make_unique<sherpaonnx::EnhancementWrapper>();
        }

        auto result = state->wrapper->initialize(
            modelDirStr,
            modelTypeStr,
            numThreadsVal,
            providerOpt,
            debugVal
        );

        if (!result.success) {
            // Do not leave behind an entry whose wrapper never initialized: later calls only
            // check for a non-null wrapper and would run against it. Entries that existed
            // before this call are left untouched. `state` dangles after the erase and is
            // not used again.
            if (createdEntry) {
                g_enhancement_instances.erase(instanceIdStr);
            }
            NSString *errorMsg = nil;
            if (!result.error.empty()) {
                // nil for invalid UTF-8; fall through to the generic message in that case.
                errorMsg = [NSString stringWithUTF8String:result.error.c_str()];
            }
            reject(@"ENHANCEMENT_INIT_ERROR", errorMsg ?: @"Failed to initialize enhancement", nil);
            return;
        }

        NSMutableArray *detectedModelsArray = [NSMutableArray array];
        for (const auto& model : result.detectedModels) {
            [detectedModelsArray addObject:@{
                @"type": [NSString stringWithUTF8String:model.type.c_str()] ?: @"",
                @"modelDir": [NSString stringWithUTF8String:model.modelDir.c_str()] ?: @""
            }];
        }

        resolve(@{
            @"success": @YES,
            @"detectedModels": detectedModelsArray,
            @"modelType": [NSString stringWithUTF8String:result.modelType.c_str()] ?: @"unknown",
            @"sampleRate": @(result.sampleRate)
        });
    } @catch (NSException *exception) {
        reject(@"ENHANCEMENT_INIT_ERROR",
               [NSString stringWithFormat:@"Enhancement init failed: %@", exception.reason],
               nil);
    }
}
165
+
166
/// Runs the offline wrapper registered under `instanceId` over `samples` and resolves with
/// {samples, sampleRate}. `sampleRate` is truncated to a 32-bit integer before the call.
/// Rejects with ENHANCEMENT_ERROR when the id is missing, the instance is unknown, or an
/// NSException is raised while enhancing.
- (void)enhanceSamples:(NSString *)instanceId
               samples:(NSArray *)samples
            sampleRate:(double)sampleRate
               resolve:(RCTPromiseResolveBlock)resolve
                reject:(RCTPromiseRejectBlock)reject
{
    const BOOL hasInstanceId = (instanceId != nil) && ([instanceId length] != 0);
    if (!hasInstanceId) {
        reject(@"ENHANCEMENT_ERROR", @"instanceId is required", nil);
        return;
    }

    const std::string key([instanceId UTF8String]);

    // Unbox the NSNumber array into a contiguous float buffer.
    std::vector<float> pcm;
    pcm.reserve([samples count]);
    for (NSNumber *value in samples) {
        pcm.push_back([value floatValue]);
    }

    @try {
        std::lock_guard<std::mutex> guard(g_enhancement_mutex);
        const auto found = g_enhancement_instances.find(key);
        const bool usable =
            (found != g_enhancement_instances.end()) && (found->second->wrapper != nullptr);
        if (!usable) {
            reject(@"ENHANCEMENT_ERROR", @"Enhancement instance not found", nil);
            return;
        }
        const int32_t rate = static_cast<int32_t>(sampleRate);
        auto enhanced = found->second->wrapper->runSamples(pcm, rate);
        resolve(enhancedAudioToDict(enhanced));
    } @catch (NSException *exception) {
        NSString *message =
            [NSString stringWithFormat:@"Enhance samples failed: %@", exception.reason];
        reject(@"ENHANCEMENT_ERROR", message, nil);
    }
}
199
+
200
/// Reads the wave file at `inputPath`, enhances it with the offline wrapper registered under
/// `instanceId`, optionally writes the result to `outputPath` (skipped when nil/empty), and
/// resolves with {samples, sampleRate}.
///
/// Rejects with ENHANCEMENT_ERROR when an argument is missing, the input cannot be read,
/// the instance is unknown, the output file cannot be written, or an NSException is raised.
- (void)enhanceFile:(NSString *)instanceId
          inputPath:(NSString *)inputPath
         outputPath:(NSString *)outputPath
            resolve:(RCTPromiseResolveBlock)resolve
             reject:(RCTPromiseRejectBlock)reject
{
    if (instanceId == nil || [instanceId length] == 0) {
        reject(@"ENHANCEMENT_ERROR", @"instanceId is required", nil);
        return;
    }
    if (inputPath == nil || [inputPath length] == 0) {
        reject(@"ENHANCEMENT_ERROR", @"inputPath is required", nil);
        return;
    }

    std::string instanceIdStr = [instanceId UTF8String];
    std::string inputPathStr = [inputPath UTF8String];

    @try {
        // Read before taking the lock so file input does not hold up other callers.
        sherpa_onnx::cxx::Wave wave = sherpa_onnx::cxx::ReadWave(inputPathStr);
        if (wave.samples.empty() || wave.sample_rate <= 0) {
            reject(@"ENHANCEMENT_ERROR", @"Failed to read input wave file", nil);
            return;
        }

        std::lock_guard<std::mutex> lock(g_enhancement_mutex);
        auto it = g_enhancement_instances.find(instanceIdStr);
        if (it == g_enhancement_instances.end() || it->second->wrapper == nullptr) {
            reject(@"ENHANCEMENT_ERROR", @"Enhancement instance not found", nil);
            return;
        }
        auto out = it->second->wrapper->runSamples(wave.samples, wave.sample_rate);

        if (outputPath != nil && [outputPath length] > 0) {
            sherpa_onnx::cxx::Wave outputWave;
            outputWave.samples = out.samples;
            outputWave.sample_rate = out.sampleRate;
            std::string outputPathStr = [outputPath UTF8String];
            // WriteWave reports failure through its return value; surface it instead of
            // resolving as if the requested file had been written.
            if (!sherpa_onnx::cxx::WriteWave(outputPathStr, outputWave)) {
                reject(@"ENHANCEMENT_ERROR", @"Failed to write output wave file", nil);
                return;
            }
        }

        resolve(enhancedAudioToDict(out));
    } @catch (NSException *exception) {
        reject(@"ENHANCEMENT_ERROR",
               [NSString stringWithFormat:@"Enhance file failed: %@", exception.reason],
               nil);
    }
}
248
+
249
/// Resolves with the sample rate reported by the offline wrapper registered under
/// `instanceId`. Rejects with ENHANCEMENT_ERROR when the id is missing or unknown.
- (void)getEnhancementSampleRate:(NSString *)instanceId
                         resolve:(RCTPromiseResolveBlock)resolve
                          reject:(RCTPromiseRejectBlock)reject
{
    const BOOL hasInstanceId = (instanceId != nil) && ([instanceId length] != 0);
    if (!hasInstanceId) {
        reject(@"ENHANCEMENT_ERROR", @"instanceId is required", nil);
        return;
    }

    const std::string key([instanceId UTF8String]);

    std::lock_guard<std::mutex> guard(g_enhancement_mutex);
    const auto found = g_enhancement_instances.find(key);
    const bool usable =
        (found != g_enhancement_instances.end()) && (found->second->wrapper != nullptr);
    if (!usable) {
        reject(@"ENHANCEMENT_ERROR", @"Enhancement instance not found", nil);
        return;
    }

    NSNumber *rate = @(found->second->wrapper->getSampleRate());
    resolve(rate);
}
267
+
268
/// Releases and forgets the offline wrapper registered under `instanceId`.
/// Always resolves with nil: a missing id or an unknown instance is treated as a no-op.
- (void)unloadEnhancement:(NSString *)instanceId
                  resolve:(RCTPromiseResolveBlock)resolve
                   reject:(RCTPromiseRejectBlock)reject
{
    const BOOL hasInstanceId = (instanceId != nil) && ([instanceId length] != 0);
    if (!hasInstanceId) {
        resolve(nil);
        return;
    }

    const std::string key([instanceId UTF8String]);

    std::lock_guard<std::mutex> guard(g_enhancement_mutex);
    auto found = g_enhancement_instances.find(key);
    const bool loaded =
        (found != g_enhancement_instances.end()) && (found->second->wrapper != nullptr);
    if (loaded) {
        // Release explicitly first, then drop the map entry.
        found->second->wrapper->release();
        g_enhancement_instances.erase(found);
    }
    resolve(nil);
}
285
+
286
/// Creates (or reuses) the online enhancement wrapper registered under `instanceId` and
/// initializes it from `modelDir`.
///
/// Defaults: nil/empty `modelType` -> "auto", nil `numThreads` -> 1, nil `debug` -> false,
/// nil/empty `provider` -> no provider override.
///
/// Resolves with {success, sampleRate, frameShiftInSamples}.
/// Rejects with ONLINE_ENHANCEMENT_INIT_ERROR when `instanceId`/`modelDir` is missing, when
/// the wrapper reports failure, or when an NSException is raised.
- (void)initializeOnlineEnhancement:(NSString *)instanceId
                           modelDir:(NSString *)modelDir
                          modelType:(NSString *)modelType
                         numThreads:(NSNumber *)numThreads
                           provider:(NSString *)provider
                              debug:(NSNumber *)debug
                            resolve:(RCTPromiseResolveBlock)resolve
                             reject:(RCTPromiseRejectBlock)reject
{
    if (instanceId == nil || [instanceId length] == 0) {
        reject(@"ONLINE_ENHANCEMENT_INIT_ERROR", @"instanceId is required", nil);
        return;
    }
    if (modelDir == nil || [modelDir length] == 0) {
        reject(@"ONLINE_ENHANCEMENT_INIT_ERROR", @"modelDir is required", nil);
        return;
    }

    std::string instanceIdStr = [instanceId UTF8String];
    std::string modelDirStr = [modelDir UTF8String];
    std::string modelTypeStr = (modelType != nil && [modelType length] > 0) ? [modelType UTF8String] : "auto";
    int32_t numThreadsVal = numThreads != nil ? [numThreads intValue] : 1;
    bool debugVal = debug != nil && [debug boolValue];
    std::optional<std::string> providerOpt = std::nullopt;
    if (provider != nil && [provider length] > 0) {
        providerOpt = std::string([provider UTF8String]);
    }

    @try {
        std::lock_guard<std::mutex> lock(g_enhancement_mutex);

        // One lookup instead of find + operator[] twice; `second` tells us whether this
        // call created the map entry, which matters for cleanup on failure below.
        const auto inserted = g_online_enhancement_instances.try_emplace(instanceIdStr);
        const bool createdEntry = inserted.second;
        auto &state = inserted.first->second;
        if (state == nullptr) {
            state = std::make_unique<OnlineEnhancementInstanceState>();
        }
        if (state->wrapper == nullptr) {
            state->wrapper = std::make_unique<sherpaonnx::OnlineEnhancementWrapper>();
        }

        auto result = state->wrapper->initialize(
            modelDirStr,
            modelTypeStr,
            numThreadsVal,
            providerOpt,
            debugVal
        );
        if (!result.success) {
            // Do not leave behind an entry whose wrapper never initialized: feed/flush/reset
            // only check for a non-null wrapper and would run against it. Entries that
            // existed before this call are left untouched. `state` dangles after the erase
            // and is not used again.
            if (createdEntry) {
                g_online_enhancement_instances.erase(instanceIdStr);
            }
            NSString *errorMsg = nil;
            if (!result.error.empty()) {
                // nil for invalid UTF-8; fall through to the generic message in that case.
                errorMsg = [NSString stringWithUTF8String:result.error.c_str()];
            }
            reject(@"ONLINE_ENHANCEMENT_INIT_ERROR",
                   errorMsg ?: @"Failed to initialize online enhancement",
                   nil);
            return;
        }

        resolve(@{
            @"success": @YES,
            @"sampleRate": @(result.sampleRate),
            @"frameShiftInSamples": @(result.frameShiftInSamples)
        });
    } @catch (NSException *exception) {
        reject(@"ONLINE_ENHANCEMENT_INIT_ERROR",
               [NSString stringWithFormat:@"Online enhancement init failed: %@", exception.reason],
               nil);
    }
}
351
+
352
/// Feeds `samples` to the online wrapper registered under `instanceId` and resolves with the
/// {samples, sampleRate} dictionary built from what the wrapper returns for this call.
/// `sampleRate` is truncated to a 32-bit integer before the call.
///
/// Rejects with ONLINE_ENHANCEMENT_ERROR when the id is missing, the instance is unknown, or
/// an NSException is raised while processing (same guard as the offline enhanceSamples path).
- (void)feedEnhancementSamples:(NSString *)instanceId
                       samples:(NSArray *)samples
                    sampleRate:(double)sampleRate
                       resolve:(RCTPromiseResolveBlock)resolve
                        reject:(RCTPromiseRejectBlock)reject
{
    if (instanceId == nil || [instanceId length] == 0) {
        reject(@"ONLINE_ENHANCEMENT_ERROR", @"instanceId is required", nil);
        return;
    }
    std::string instanceIdStr = [instanceId UTF8String];

    // Unbox the NSNumber array into a contiguous float buffer.
    std::vector<float> floatSamples;
    floatSamples.reserve([samples count]);
    for (NSNumber *n in samples) {
        floatSamples.push_back([n floatValue]);
    }

    @try {
        std::lock_guard<std::mutex> lock(g_enhancement_mutex);
        auto it = g_online_enhancement_instances.find(instanceIdStr);
        if (it == g_online_enhancement_instances.end() || it->second->wrapper == nullptr) {
            reject(@"ONLINE_ENHANCEMENT_ERROR", @"Online enhancement instance not found", nil);
            return;
        }
        auto out = it->second->wrapper->runSamples(floatSamples, static_cast<int32_t>(sampleRate));
        resolve(enhancedAudioToDict(out));
    } @catch (NSException *exception) {
        // Settle the promise instead of letting the exception escape the bridge method.
        reject(@"ONLINE_ENHANCEMENT_ERROR",
               [NSString stringWithFormat:@"Feed enhancement samples failed: %@", exception.reason],
               nil);
    }
}
378
+
379
/// Calls flush() on the online wrapper registered under `instanceId` and resolves with the
/// {samples, sampleRate} dictionary built from its result.
///
/// Rejects with ONLINE_ENHANCEMENT_ERROR when the id is missing, the instance is unknown, or
/// an NSException is raised while flushing (same guard as the offline enhanceSamples path).
- (void)flushOnlineEnhancement:(NSString *)instanceId
                       resolve:(RCTPromiseResolveBlock)resolve
                        reject:(RCTPromiseRejectBlock)reject
{
    if (instanceId == nil || [instanceId length] == 0) {
        reject(@"ONLINE_ENHANCEMENT_ERROR", @"instanceId is required", nil);
        return;
    }
    std::string instanceIdStr = [instanceId UTF8String];

    @try {
        std::lock_guard<std::mutex> lock(g_enhancement_mutex);
        auto it = g_online_enhancement_instances.find(instanceIdStr);
        if (it == g_online_enhancement_instances.end() || it->second->wrapper == nullptr) {
            reject(@"ONLINE_ENHANCEMENT_ERROR", @"Online enhancement instance not found", nil);
            return;
        }
        auto out = it->second->wrapper->flush();
        resolve(enhancedAudioToDict(out));
    } @catch (NSException *exception) {
        // Settle the promise instead of letting the exception escape the bridge method.
        reject(@"ONLINE_ENHANCEMENT_ERROR",
               [NSString stringWithFormat:@"Flush online enhancement failed: %@", exception.reason],
               nil);
    }
}
397
+
398
/// Calls reset() on the online wrapper registered under `instanceId`, then resolves with nil.
/// A nil/empty id resolves with nil without doing anything; an unknown instance rejects with
/// ONLINE_ENHANCEMENT_ERROR.
- (void)resetOnlineEnhancement:(NSString *)instanceId
                       resolve:(RCTPromiseResolveBlock)resolve
                        reject:(RCTPromiseRejectBlock)reject
{
    const BOOL hasInstanceId = (instanceId != nil) && ([instanceId length] != 0);
    if (!hasInstanceId) {
        resolve(nil);
        return;
    }

    const std::string key([instanceId UTF8String]);

    std::lock_guard<std::mutex> guard(g_enhancement_mutex);
    const auto found = g_online_enhancement_instances.find(key);
    const bool usable =
        (found != g_online_enhancement_instances.end()) && (found->second->wrapper != nullptr);
    if (!usable) {
        reject(@"ONLINE_ENHANCEMENT_ERROR", @"Online enhancement instance not found", nil);
        return;
    }

    found->second->wrapper->reset();
    resolve(nil);
}
416
+
417
/// Releases and forgets the online wrapper registered under `instanceId`.
/// Always resolves with nil: a missing id or an unknown instance is treated as a no-op.
- (void)unloadOnlineEnhancement:(NSString *)instanceId
                        resolve:(RCTPromiseResolveBlock)resolve
                         reject:(RCTPromiseRejectBlock)reject
{
    const BOOL hasInstanceId = (instanceId != nil) && ([instanceId length] != 0);
    if (!hasInstanceId) {
        resolve(nil);
        return;
    }

    const std::string key([instanceId UTF8String]);

    std::lock_guard<std::mutex> guard(g_enhancement_mutex);
    auto found = g_online_enhancement_instances.find(key);
    const bool loaded =
        (found != g_online_enhancement_instances.end()) && (found->second->wrapper != nullptr);
    if (loaded) {
        // Release explicitly first, then drop the map entry.
        found->second->wrapper->release();
        g_online_enhancement_instances.erase(found);
    }
    resolve(nil);
}
434
+
435
+ @end
@@ -0,0 +1,85 @@
1
+ #ifndef SHERPA_ONNX_ENHANCEMENT_WRAPPER_H
2
+ #define SHERPA_ONNX_ENHANCEMENT_WRAPPER_H
3
+
4
+ #include "sherpa-onnx-common.h"
5
+ #include "sherpa-onnx-model-detect.h"
6
+ #include <cstdint>
7
+ #include <memory>
8
+ #include <optional>
9
+ #include <string>
10
+ #include <vector>
11
+
12
+ namespace sherpaonnx {
13
+
14
+ struct EnhancementInitializeResult {
15
+ bool success = false;
16
+ std::vector<DetectedModel> detectedModels;
17
+ std::string error;
18
+ std::string modelType;
19
+ int32_t sampleRate = 0;
20
+ int32_t frameShiftInSamples = 0;
21
+ };
22
+
23
/// Audio produced by an enhancement call: float samples plus their sample rate.
/// A default-constructed value has no samples and a sample rate of 0.
struct EnhancedAudioResult {
    /// Enhanced audio samples.
    std::vector<float> samples;
    /// Sample rate of `samples`.
    int32_t sampleRate{0};
};
27
+
28
+ class EnhancementWrapper {
29
+ public:
30
+ EnhancementWrapper();
31
+ ~EnhancementWrapper();
32
+
33
+ EnhancementInitializeResult initialize(
34
+ const std::string& modelDir,
35
+ const std::string& modelType = "auto",
36
+ int32_t numThreads = 1,
37
+ const std::optional<std::string>& provider = std::nullopt,
38
+ bool debug = false
39
+ );
40
+
41
+ EnhancedAudioResult runSamples(const std::vector<float>& samples, int32_t sampleRate);
42
+
43
+ int32_t getSampleRate() const;
44
+
45
+ bool isInitialized() const;
46
+
47
+ void release();
48
+
49
+ private:
50
+ class Impl;
51
+ std::unique_ptr<Impl> pImpl;
52
+ };
53
+
54
+ class OnlineEnhancementWrapper {
55
+ public:
56
+ OnlineEnhancementWrapper();
57
+ ~OnlineEnhancementWrapper();
58
+
59
+ EnhancementInitializeResult initialize(
60
+ const std::string& modelDir,
61
+ const std::string& modelType = "auto",
62
+ int32_t numThreads = 1,
63
+ const std::optional<std::string>& provider = std::nullopt,
64
+ bool debug = false
65
+ );
66
+
67
+ EnhancedAudioResult runSamples(const std::vector<float>& samples, int32_t sampleRate);
68
+ EnhancedAudioResult flush();
69
+ void reset();
70
+
71
+ int32_t getSampleRate() const;
72
+ int32_t getFrameShiftInSamples() const;
73
+
74
+ bool isInitialized() const;
75
+
76
+ void release();
77
+
78
+ private:
79
+ class Impl;
80
+ std::unique_ptr<Impl> pImpl;
81
+ };
82
+
83
+ } // namespace sherpaonnx
84
+
85
+ #endif // SHERPA_ONNX_ENHANCEMENT_WRAPPER_H