react-native-sherpa-onnx 0.3.9 → 0.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +17 -4
- package/SherpaOnnx.podspec +1 -0
- package/android/prebuilt-download.gradle +67 -27
- package/android/prebuilt-versions.gradle +1 -1
- package/android/src/main/assets/model_licenses/speech-enhancement-models-license-status.csv +7 -0
- package/android/src/main/cpp/CMakeLists.txt +3 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-enhancement-wrapper.cpp +68 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-enhancement-wrapper.h +17 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-enhancement.cpp +119 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect.h +31 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-validate-enhancement.cpp +68 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-validate-enhancement.h +30 -0
- package/android/src/main/cpp/jni/module/sherpa-onnx-module-jni.cpp +21 -0
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxAssetHelper.kt +6 -0
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxEnhancementHelper.kt +377 -0
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxModule.kt +106 -0
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxTtsHelper.kt +66 -13
- package/ios/Resources/model_licenses/speech-enhancement-models-license-status.csv +7 -0
- package/ios/SherpaOnnx+Assets.mm +5 -0
- package/ios/SherpaOnnx+Enhancement.mm +435 -0
- package/ios/enhancement/sherpa-onnx-enhancement-wrapper.h +85 -0
- package/ios/enhancement/sherpa-onnx-enhancement-wrapper.mm +218 -0
- package/ios/model_detect/sherpa-onnx-model-detect-enhancement.mm +92 -0
- package/ios/model_detect/sherpa-onnx-model-detect.h +23 -0
- package/ios/model_detect/sherpa-onnx-validate-enhancement.h +30 -0
- package/ios/model_detect/sherpa-onnx-validate-enhancement.mm +69 -0
- package/lib/module/NativeSherpaOnnx.js.map +1 -1
- package/lib/module/download/localModels.js +2 -3
- package/lib/module/download/localModels.js.map +1 -1
- package/lib/module/download/paths.js +2 -1
- package/lib/module/download/paths.js.map +1 -1
- package/lib/module/enhancement/index.js +63 -48
- package/lib/module/enhancement/index.js.map +1 -1
- package/lib/module/enhancement/streaming.js +60 -0
- package/lib/module/enhancement/streaming.js.map +1 -0
- package/lib/module/enhancement/streamingTypes.js +4 -0
- package/lib/module/enhancement/streamingTypes.js.map +1 -0
- package/lib/module/enhancement/types.js +4 -0
- package/lib/module/enhancement/types.js.map +1 -0
- package/lib/module/licenses.js +9 -3
- package/lib/module/licenses.js.map +1 -1
- package/lib/typescript/src/NativeSherpaOnnx.d.ts +45 -0
- package/lib/typescript/src/NativeSherpaOnnx.d.ts.map +1 -1
- package/lib/typescript/src/download/localModels.d.ts.map +1 -1
- package/lib/typescript/src/download/paths.d.ts +2 -1
- package/lib/typescript/src/download/paths.d.ts.map +1 -1
- package/lib/typescript/src/enhancement/index.d.ts +9 -46
- package/lib/typescript/src/enhancement/index.d.ts.map +1 -1
- package/lib/typescript/src/enhancement/streaming.d.ts +6 -0
- package/lib/typescript/src/enhancement/streaming.d.ts.map +1 -0
- package/lib/typescript/src/enhancement/streamingTypes.d.ts +12 -0
- package/lib/typescript/src/enhancement/streamingTypes.d.ts.map +1 -0
- package/lib/typescript/src/enhancement/types.d.ts +31 -0
- package/lib/typescript/src/enhancement/types.d.ts.map +1 -0
- package/lib/typescript/src/licenses.d.ts.map +1 -1
- package/package.json +1 -1
- package/scripts/ci/check-model-csvs.sh +27 -2
- package/scripts/ci/collect_all_sherpa_model_streams.sh +3 -1
- package/scripts/ci/collect_one_sherpa_release_stream.sh +3 -1
- package/scripts/ci/sherpa_speech_enhancement_model_release_streams.json +13 -0
- package/scripts/ci/update_model_license_csv.sh +1 -1
- package/src/NativeSherpaOnnx.ts +71 -0
- package/src/download/localModels.ts +1 -3
- package/src/download/paths.ts +2 -1
- package/src/enhancement/index.ts +120 -58
- package/src/enhancement/streaming.ts +105 -0
- package/src/enhancement/streamingTypes.ts +14 -0
- package/src/enhancement/types.ts +36 -0
- package/src/licenses.ts +13 -2
- package/third_party/sherpa-onnx-prebuilt/ANDROID_RELEASE_TAG +1 -1
|
@@ -0,0 +1,435 @@
|
|
|
1
|
+
#import "SherpaOnnx.h"
|
|
2
|
+
#import <React/RCTLog.h>
|
|
3
|
+
|
|
4
|
+
#include "sherpa-onnx-enhancement-wrapper.h"
|
|
5
|
+
#include "sherpa-onnx-model-detect.h"
|
|
6
|
+
#include "sherpa-onnx/c-api/cxx-api.h"
|
|
7
|
+
|
|
8
|
+
#include <memory>
|
|
9
|
+
#include <mutex>
|
|
10
|
+
#include <optional>
|
|
11
|
+
#include <string>
|
|
12
|
+
#include <unordered_map>
|
|
13
|
+
#include <vector>
|
|
14
|
+
|
|
15
|
+
/// Registry entry stored in g_enhancement_instances for one instanceId.
/// Owns the native wrapper; `wrapper` stays null until initializeEnhancement
/// creates it.
struct EnhancementInstanceState {
  std::unique_ptr<sherpaonnx::EnhancementWrapper> wrapper{};
};
|
|
18
|
+
|
|
19
|
+
/// Registry entry stored in g_online_enhancement_instances for one instanceId.
/// Owns the native online wrapper; `wrapper` stays null until
/// initializeOnlineEnhancement creates it.
struct OnlineEnhancementInstanceState {
  std::unique_ptr<sherpaonnx::OnlineEnhancementWrapper> wrapper{};
};
|
|
22
|
+
|
|
23
|
+
// File-local registries, keyed by the instanceId string passed to the methods below.
static std::unordered_map<std::string,
                          std::unique_ptr<EnhancementInstanceState>>
    g_enhancement_instances;

static std::unordered_map<std::string,
                          std::unique_ptr<OnlineEnhancementInstanceState>>
    g_online_enhancement_instances;

// Single mutex taken by every method in this file before touching either registry.
static std::mutex g_enhancement_mutex;
|
|
26
|
+
|
|
27
|
+
namespace {
|
|
28
|
+
|
|
29
|
+
/// Converts an enhancement model kind into its string name:
/// kGtcrn -> "gtcrn", kDpdfNet -> "dpdfnet", every other value -> "unknown".
static NSString *enhancementKindToNSString(sherpaonnx::EnhancementModelKind kind) {
  if (kind == sherpaonnx::EnhancementModelKind::kGtcrn) {
    return @"gtcrn";
  }
  if (kind == sherpaonnx::EnhancementModelKind::kDpdfNet) {
    return @"dpdfnet";
  }
  return @"unknown";
}
|
|
37
|
+
|
|
38
|
+
/// Packs an enhancement result into the dictionary handed to the promise:
/// "samples" is an array of boxed floats, "sampleRate" is the boxed rate.
static NSDictionary *enhancedAudioToDict(const sherpaonnx::EnhancedAudioResult& r) {
  const auto sampleCount = r.samples.size();
  NSMutableArray *boxedSamples = [NSMutableArray arrayWithCapacity:sampleCount];
  for (auto it = r.samples.begin(); it != r.samples.end(); ++it) {
    const float value = *it;
    [boxedSamples addObject:@(value)];
  }

  NSNumber *boxedRate = @(r.sampleRate);
  return @{
    @"samples": boxedSamples,
    @"sampleRate": boxedRate
  };
}
|
|
48
|
+
|
|
49
|
+
/// Converts a detection result into the dictionary resolved to the caller.
/// Always contains "success", "detectedModels" (array of {type, modelDir}) and
/// "modelType"; "error" is added only when detection failed and carries text.
static NSDictionary *enhancementDetectResultToDict(const sherpaonnx::EnhancementDetectResult& result) {
  NSMutableArray *models = [NSMutableArray array];
  for (const auto& entry : result.detectedModels) {
    // stringWithUTF8String: may return nil; fall back to an empty string.
    NSString *typeName = [NSString stringWithUTF8String:entry.type.c_str()];
    NSString *directory = [NSString stringWithUTF8String:entry.modelDir.c_str()];
    [models addObject:@{
      @"type": typeName ?: @"",
      @"modelDir": directory ?: @""
    }];
  }

  NSMutableDictionary *payload = [NSMutableDictionary dictionary];
  payload[@"success"] = @(result.ok);
  payload[@"detectedModels"] = models;
  payload[@"modelType"] = enhancementKindToNSString(result.selectedKind);

  const bool failedWithText = !result.ok && !result.error.empty();
  if (failedWithText) {
    NSString *message = [NSString stringWithUTF8String:result.error.c_str()];
    payload[@"error"] = message ?: @"Enhancement model detection failed";
  }
  return payload;
}
|
|
68
|
+
|
|
69
|
+
} // namespace
|
|
70
|
+
|
|
71
|
+
@implementation SherpaOnnx (Enhancement)
|
|
72
|
+
|
|
73
|
+
/// Runs enhancement model detection for `modelDir` and resolves with the
/// dictionary built by enhancementDetectResultToDict.
/// A nil `modelDir` is passed on as an empty string; a nil or empty `modelType`
/// is passed on as "auto". NSExceptions are turned into a DETECT_ERROR rejection.
- (void)detectEnhancementModel:(NSString *)modelDir
                     modelType:(NSString *)modelType
                       resolve:(RCTPromiseResolveBlock)resolve
                        reject:(RCTPromiseRejectBlock)reject
{
  @try {
    std::string dirArg;
    if (modelDir != nil) {
      dirArg = [modelDir UTF8String];
    }

    std::string typeArg = "auto";
    if (modelType != nil && [modelType length] > 0) {
      typeArg = [modelType UTF8String];
    }

    auto detection = sherpaonnx::DetectEnhancementModel(dirArg, typeArg);
    resolve(enhancementDetectResultToDict(detection));
  } @catch (NSException *exception) {
    NSString *message =
        [NSString stringWithFormat:@"Enhancement detect failed: %@", exception.reason];
    reject(@"DETECT_ERROR", message, nil);
  }
}
|
|
89
|
+
|
|
90
|
+
/// Creates (or reuses) the enhancement instance registered under `instanceId`
/// and initializes its wrapper from the model in `modelDir`.
///
/// Defaults applied here: nil/empty `modelType` -> "auto", nil `numThreads` -> 1,
/// nil `debug` -> false, nil/empty `provider` -> no provider value.
/// Resolves with { success, detectedModels, modelType, sampleRate }.
/// Rejects with ENHANCEMENT_INIT_ERROR when `instanceId` or `modelDir` is
/// missing, when the wrapper reports failure, or when an NSException is raised.
- (void)initializeEnhancement:(NSString *)instanceId
                     modelDir:(NSString *)modelDir
                    modelType:(NSString *)modelType
                   numThreads:(NSNumber *)numThreads
                     provider:(NSString *)provider
                        debug:(NSNumber *)debug
                      resolve:(RCTPromiseResolveBlock)resolve
                       reject:(RCTPromiseRejectBlock)reject
{
  if (instanceId == nil || [instanceId length] == 0) {
    reject(@"ENHANCEMENT_INIT_ERROR", @"instanceId is required", nil);
    return;
  }
  if (modelDir == nil || [modelDir length] == 0) {
    reject(@"ENHANCEMENT_INIT_ERROR", @"modelDir is required", nil);
    return;
  }

  std::string instanceIdStr = [instanceId UTF8String];
  std::string modelDirStr = [modelDir UTF8String];
  std::string modelTypeStr = (modelType != nil && [modelType length] > 0) ? [modelType UTF8String] : "auto";
  int32_t numThreadsVal = numThreads != nil ? [numThreads intValue] : 1;
  bool debugVal = debug != nil && [debug boolValue];
  std::optional<std::string> providerOpt = std::nullopt;
  if (provider != nil && [provider length] > 0) {
    providerOpt = std::string([provider UTF8String]);
  }

  @try {
    std::lock_guard<std::mutex> lock(g_enhancement_mutex);

    // Single hash lookup: operator[] inserts an empty slot the first time this
    // id is seen, and the same reference is reused below.
    auto &slot = g_enhancement_instances[instanceIdStr];
    if (slot == nullptr) {
      slot = std::make_unique<EnhancementInstanceState>();
    }
    if (slot->wrapper == nullptr) {
      slot->wrapper = std::make_unique<sherpaonnx::EnhancementWrapper>();
    }

    auto result = slot->wrapper->initialize(
      modelDirStr,
      modelTypeStr,
      numThreadsVal,
      providerOpt,
      debugVal
    );

    if (!result.success) {
      // stringWithUTF8String: returns nil for bytes that are not valid UTF-8;
      // never hand a nil message to the promise.
      NSString *errorMsg = result.error.empty()
        ? nil
        : [NSString stringWithUTF8String:result.error.c_str()];
      // NOTE(review): the registry entry is kept after a failed initialize, so
      // later calls still find a wrapper for this id — confirm that is intended.
      reject(@"ENHANCEMENT_INIT_ERROR", errorMsg ?: @"Failed to initialize enhancement", nil);
      return;
    }

    NSMutableArray *detectedModelsArray = [NSMutableArray array];
    for (const auto& model : result.detectedModels) {
      [detectedModelsArray addObject:@{
        @"type": [NSString stringWithUTF8String:model.type.c_str()] ?: @"",
        @"modelDir": [NSString stringWithUTF8String:model.modelDir.c_str()] ?: @""
      }];
    }

    resolve(@{
      @"success": @YES,
      @"detectedModels": detectedModelsArray,
      @"modelType": [NSString stringWithUTF8String:result.modelType.c_str()] ?: @"unknown",
      @"sampleRate": @(result.sampleRate)
    });
  } @catch (NSException *exception) {
    reject(@"ENHANCEMENT_INIT_ERROR",
           [NSString stringWithFormat:@"Enhancement init failed: %@", exception.reason],
           nil);
  }
}
|
|
165
|
+
|
|
166
|
+
/// Runs the wrapper registered under `instanceId` on `samples` and resolves with
/// { samples, sampleRate } as built by enhancedAudioToDict.
///
/// `sampleRate` is truncated to int32_t before being passed to the wrapper.
/// Rejects with ENHANCEMENT_ERROR when `instanceId` is missing, when `samples`
/// contains a non-number element, when no instance with a wrapper exists for
/// the id, or when an NSException is raised.
- (void)enhanceSamples:(NSString *)instanceId
               samples:(NSArray *)samples
            sampleRate:(double)sampleRate
               resolve:(RCTPromiseResolveBlock)resolve
                reject:(RCTPromiseRejectBlock)reject
{
  if (instanceId == nil || [instanceId length] == 0) {
    reject(@"ENHANCEMENT_ERROR", @"instanceId is required", nil);
    return;
  }

  std::string instanceIdStr = [instanceId UTF8String];

  @try {
    // Convert inside @try and before taking the lock: a malformed payload must
    // reject the promise rather than escape as an uncaught NSException.
    std::vector<float> floatSamples;
    floatSamples.reserve([samples count]);
    for (id element in samples) {
      if (![element isKindOfClass:[NSNumber class]]) {
        reject(@"ENHANCEMENT_ERROR", @"samples must contain only numbers", nil);
        return;
      }
      floatSamples.push_back([(NSNumber *)element floatValue]);
    }

    std::lock_guard<std::mutex> lock(g_enhancement_mutex);
    auto it = g_enhancement_instances.find(instanceIdStr);
    if (it == g_enhancement_instances.end() || it->second->wrapper == nullptr) {
      reject(@"ENHANCEMENT_ERROR", @"Enhancement instance not found", nil);
      return;
    }
    auto out = it->second->wrapper->runSamples(floatSamples, static_cast<int32_t>(sampleRate));
    resolve(enhancedAudioToDict(out));
  } @catch (NSException *exception) {
    reject(@"ENHANCEMENT_ERROR",
           [NSString stringWithFormat:@"Enhance samples failed: %@", exception.reason],
           nil);
  }
}
|
|
199
|
+
|
|
200
|
+
/// Reads the wave file at `inputPath`, runs it through the wrapper registered
/// under `instanceId`, optionally writes the result to `outputPath`, and
/// resolves with { samples, sampleRate } as built by enhancedAudioToDict.
///
/// `outputPath` is optional: when nil or empty nothing is written.
/// Rejects with ENHANCEMENT_ERROR when `instanceId` or `inputPath` is missing,
/// the input yields no samples or a non-positive sample rate, no instance with
/// a wrapper exists for the id, the output file cannot be written, or an
/// NSException is raised.
- (void)enhanceFile:(NSString *)instanceId
          inputPath:(NSString *)inputPath
         outputPath:(NSString *)outputPath
            resolve:(RCTPromiseResolveBlock)resolve
             reject:(RCTPromiseRejectBlock)reject
{
  if (instanceId == nil || [instanceId length] == 0) {
    reject(@"ENHANCEMENT_ERROR", @"instanceId is required", nil);
    return;
  }
  if (inputPath == nil || [inputPath length] == 0) {
    reject(@"ENHANCEMENT_ERROR", @"inputPath is required", nil);
    return;
  }

  std::string instanceIdStr = [instanceId UTF8String];
  std::string inputPathStr = [inputPath UTF8String];

  @try {
    sherpa_onnx::cxx::Wave wave = sherpa_onnx::cxx::ReadWave(inputPathStr);
    if (wave.samples.empty() || wave.sample_rate <= 0) {
      reject(@"ENHANCEMENT_ERROR", @"Failed to read input wave file", nil);
      return;
    }

    sherpaonnx::EnhancedAudioResult out;
    {
      // Hold the registry lock only while the wrapper is in use; the file write
      // below works on a private copy of the result.
      std::lock_guard<std::mutex> lock(g_enhancement_mutex);
      auto it = g_enhancement_instances.find(instanceIdStr);
      if (it == g_enhancement_instances.end() || it->second->wrapper == nullptr) {
        reject(@"ENHANCEMENT_ERROR", @"Enhancement instance not found", nil);
        return;
      }
      out = it->second->wrapper->runSamples(wave.samples, wave.sample_rate);
    }

    if (outputPath != nil && [outputPath length] > 0) {
      sherpa_onnx::cxx::Wave outputWave;
      outputWave.samples = out.samples;
      outputWave.sample_rate = out.sampleRate;
      std::string outputPathStr = [outputPath UTF8String];
      // WriteWave reports failure through its return value; surface it instead
      // of resolving as if the file had been written.
      if (!sherpa_onnx::cxx::WriteWave(outputPathStr, outputWave)) {
        reject(@"ENHANCEMENT_ERROR", @"Failed to write output wave file", nil);
        return;
      }
    }

    resolve(enhancedAudioToDict(out));
  } @catch (NSException *exception) {
    reject(@"ENHANCEMENT_ERROR",
           [NSString stringWithFormat:@"Enhance file failed: %@", exception.reason],
           nil);
  }
}
|
|
248
|
+
|
|
249
|
+
/// Resolves with the value of getSampleRate() from the wrapper registered under
/// `instanceId`. Rejects with ENHANCEMENT_ERROR when the id is missing or no
/// instance with a wrapper exists for it.
- (void)getEnhancementSampleRate:(NSString *)instanceId
                         resolve:(RCTPromiseResolveBlock)resolve
                          reject:(RCTPromiseRejectBlock)reject
{
  const bool hasId = (instanceId != nil) && ([instanceId length] > 0);
  if (!hasId) {
    reject(@"ENHANCEMENT_ERROR", @"instanceId is required", nil);
    return;
  }
  const std::string key = [instanceId UTF8String];

  std::lock_guard<std::mutex> guard(g_enhancement_mutex);
  const auto entry = g_enhancement_instances.find(key);
  const bool usable =
      entry != g_enhancement_instances.end() && entry->second->wrapper != nullptr;
  if (!usable) {
    reject(@"ENHANCEMENT_ERROR", @"Enhancement instance not found", nil);
    return;
  }

  const int32_t rate = entry->second->wrapper->getSampleRate();
  resolve(@(rate));
}
|
|
267
|
+
|
|
268
|
+
/// Releases and removes the instance registered under `instanceId`.
/// Always resolves with nil: a missing id or an unknown instance is a no-op.
- (void)unloadEnhancement:(NSString *)instanceId
                  resolve:(RCTPromiseResolveBlock)resolve
                   reject:(RCTPromiseRejectBlock)reject
{
  const bool hasId = (instanceId != nil) && ([instanceId length] > 0);
  if (!hasId) {
    resolve(nil);
    return;
  }

  const std::string key = [instanceId UTF8String];
  std::lock_guard<std::mutex> guard(g_enhancement_mutex);
  auto entry = g_enhancement_instances.find(key);
  const bool loaded =
      entry != g_enhancement_instances.end() && entry->second->wrapper != nullptr;
  if (loaded) {
    // Release the wrapper explicitly, then drop the registry entry.
    entry->second->wrapper->release();
    g_enhancement_instances.erase(entry);
  }
  resolve(nil);
}
|
|
285
|
+
|
|
286
|
+
/// Creates (or reuses) the online enhancement instance registered under
/// `instanceId` and initializes its wrapper from the model in `modelDir`.
///
/// Defaults applied here: nil/empty `modelType` -> "auto", nil `numThreads` -> 1,
/// nil `debug` -> false, nil/empty `provider` -> no provider value.
/// Resolves with { success, sampleRate, frameShiftInSamples }.
/// Rejects with ONLINE_ENHANCEMENT_INIT_ERROR when `instanceId` or `modelDir`
/// is missing, when the wrapper reports failure, or when an NSException is raised.
- (void)initializeOnlineEnhancement:(NSString *)instanceId
                           modelDir:(NSString *)modelDir
                          modelType:(NSString *)modelType
                         numThreads:(NSNumber *)numThreads
                           provider:(NSString *)provider
                              debug:(NSNumber *)debug
                            resolve:(RCTPromiseResolveBlock)resolve
                             reject:(RCTPromiseRejectBlock)reject
{
  if (instanceId == nil || [instanceId length] == 0) {
    reject(@"ONLINE_ENHANCEMENT_INIT_ERROR", @"instanceId is required", nil);
    return;
  }
  if (modelDir == nil || [modelDir length] == 0) {
    reject(@"ONLINE_ENHANCEMENT_INIT_ERROR", @"modelDir is required", nil);
    return;
  }

  std::string instanceIdStr = [instanceId UTF8String];
  std::string modelDirStr = [modelDir UTF8String];
  std::string modelTypeStr = (modelType != nil && [modelType length] > 0) ? [modelType UTF8String] : "auto";
  int32_t numThreadsVal = numThreads != nil ? [numThreads intValue] : 1;
  bool debugVal = debug != nil && [debug boolValue];
  std::optional<std::string> providerOpt = std::nullopt;
  if (provider != nil && [provider length] > 0) {
    providerOpt = std::string([provider UTF8String]);
  }

  @try {
    std::lock_guard<std::mutex> lock(g_enhancement_mutex);

    // Single hash lookup: operator[] inserts an empty slot the first time this
    // id is seen, and the same reference is reused below.
    auto &slot = g_online_enhancement_instances[instanceIdStr];
    if (slot == nullptr) {
      slot = std::make_unique<OnlineEnhancementInstanceState>();
    }
    if (slot->wrapper == nullptr) {
      slot->wrapper = std::make_unique<sherpaonnx::OnlineEnhancementWrapper>();
    }

    auto result = slot->wrapper->initialize(
      modelDirStr,
      modelTypeStr,
      numThreadsVal,
      providerOpt,
      debugVal
    );
    if (!result.success) {
      // stringWithUTF8String: returns nil for bytes that are not valid UTF-8;
      // never hand a nil message to the promise.
      NSString *errorMsg = result.error.empty()
        ? nil
        : [NSString stringWithUTF8String:result.error.c_str()];
      // NOTE(review): the registry entry is kept after a failed initialize, so
      // later calls still find a wrapper for this id — confirm that is intended.
      reject(@"ONLINE_ENHANCEMENT_INIT_ERROR",
             errorMsg ?: @"Failed to initialize online enhancement",
             nil);
      return;
    }

    resolve(@{
      @"success": @YES,
      @"sampleRate": @(result.sampleRate),
      @"frameShiftInSamples": @(result.frameShiftInSamples)
    });
  } @catch (NSException *exception) {
    reject(@"ONLINE_ENHANCEMENT_INIT_ERROR",
           [NSString stringWithFormat:@"Online enhancement init failed: %@", exception.reason],
           nil);
  }
}
|
|
351
|
+
|
|
352
|
+
/// Feeds `samples` to the online wrapper registered under `instanceId` and
/// resolves with { samples, sampleRate } as built by enhancedAudioToDict.
///
/// `sampleRate` is truncated to int32_t before being passed to the wrapper.
/// Rejects with ONLINE_ENHANCEMENT_ERROR when `instanceId` is missing, when
/// `samples` contains a non-number element, when no instance with a wrapper
/// exists for the id, or when an NSException is raised.
- (void)feedEnhancementSamples:(NSString *)instanceId
                       samples:(NSArray *)samples
                    sampleRate:(double)sampleRate
                       resolve:(RCTPromiseResolveBlock)resolve
                        reject:(RCTPromiseRejectBlock)reject
{
  if (instanceId == nil || [instanceId length] == 0) {
    reject(@"ONLINE_ENHANCEMENT_ERROR", @"instanceId is required", nil);
    return;
  }
  std::string instanceIdStr = [instanceId UTF8String];

  // Same @try/@catch shape as enhanceSamples, so a malformed payload rejects
  // the promise rather than escaping as an uncaught NSException.
  @try {
    std::vector<float> floatSamples;
    floatSamples.reserve([samples count]);
    for (id element in samples) {
      if (![element isKindOfClass:[NSNumber class]]) {
        reject(@"ONLINE_ENHANCEMENT_ERROR", @"samples must contain only numbers", nil);
        return;
      }
      floatSamples.push_back([(NSNumber *)element floatValue]);
    }

    std::lock_guard<std::mutex> lock(g_enhancement_mutex);
    auto it = g_online_enhancement_instances.find(instanceIdStr);
    if (it == g_online_enhancement_instances.end() || it->second->wrapper == nullptr) {
      reject(@"ONLINE_ENHANCEMENT_ERROR", @"Online enhancement instance not found", nil);
      return;
    }
    auto out = it->second->wrapper->runSamples(floatSamples, static_cast<int32_t>(sampleRate));
    resolve(enhancedAudioToDict(out));
  } @catch (NSException *exception) {
    reject(@"ONLINE_ENHANCEMENT_ERROR",
           [NSString stringWithFormat:@"Feed enhancement samples failed: %@", exception.reason],
           nil);
  }
}
|
|
378
|
+
|
|
379
|
+
/// Calls flush() on the online wrapper registered under `instanceId` and
/// resolves with { samples, sampleRate } as built by enhancedAudioToDict.
/// Rejects with ONLINE_ENHANCEMENT_ERROR when the id is missing or no instance
/// with a wrapper exists for it.
- (void)flushOnlineEnhancement:(NSString *)instanceId
                       resolve:(RCTPromiseResolveBlock)resolve
                        reject:(RCTPromiseRejectBlock)reject
{
  const bool hasId = (instanceId != nil) && ([instanceId length] > 0);
  if (!hasId) {
    reject(@"ONLINE_ENHANCEMENT_ERROR", @"instanceId is required", nil);
    return;
  }

  const std::string key = [instanceId UTF8String];
  std::lock_guard<std::mutex> guard(g_enhancement_mutex);
  auto entry = g_online_enhancement_instances.find(key);
  const bool usable =
      entry != g_online_enhancement_instances.end() && entry->second->wrapper != nullptr;
  if (!usable) {
    reject(@"ONLINE_ENHANCEMENT_ERROR", @"Online enhancement instance not found", nil);
    return;
  }

  auto flushed = entry->second->wrapper->flush();
  resolve(enhancedAudioToDict(flushed));
}
|
|
397
|
+
|
|
398
|
+
/// Calls reset() on the online wrapper registered under `instanceId`.
/// A nil or empty id resolves with nil without doing anything; an id with no
/// registered wrapper rejects with ONLINE_ENHANCEMENT_ERROR.
- (void)resetOnlineEnhancement:(NSString *)instanceId
                       resolve:(RCTPromiseResolveBlock)resolve
                        reject:(RCTPromiseRejectBlock)reject
{
  const bool hasId = (instanceId != nil) && ([instanceId length] > 0);
  if (!hasId) {
    resolve(nil);
    return;
  }

  const std::string key = [instanceId UTF8String];
  std::lock_guard<std::mutex> guard(g_enhancement_mutex);
  auto entry = g_online_enhancement_instances.find(key);
  const bool usable =
      entry != g_online_enhancement_instances.end() && entry->second->wrapper != nullptr;
  if (!usable) {
    reject(@"ONLINE_ENHANCEMENT_ERROR", @"Online enhancement instance not found", nil);
    return;
  }

  entry->second->wrapper->reset();
  resolve(nil);
}
|
|
416
|
+
|
|
417
|
+
/// Releases and removes the online instance registered under `instanceId`.
/// Always resolves with nil: a missing id or an unknown instance is a no-op.
- (void)unloadOnlineEnhancement:(NSString *)instanceId
                        resolve:(RCTPromiseResolveBlock)resolve
                         reject:(RCTPromiseRejectBlock)reject
{
  const bool hasId = (instanceId != nil) && ([instanceId length] > 0);
  if (!hasId) {
    resolve(nil);
    return;
  }

  const std::string key = [instanceId UTF8String];
  std::lock_guard<std::mutex> guard(g_enhancement_mutex);
  auto entry = g_online_enhancement_instances.find(key);
  const bool loaded =
      entry != g_online_enhancement_instances.end() && entry->second->wrapper != nullptr;
  if (loaded) {
    // Release the wrapper explicitly, then drop the registry entry.
    entry->second->wrapper->release();
    g_online_enhancement_instances.erase(entry);
  }
  resolve(nil);
}
|
|
434
|
+
|
|
435
|
+
@end
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
#ifndef SHERPA_ONNX_ENHANCEMENT_WRAPPER_H
|
|
2
|
+
#define SHERPA_ONNX_ENHANCEMENT_WRAPPER_H
|
|
3
|
+
|
|
4
|
+
#include "sherpa-onnx-common.h"
|
|
5
|
+
#include "sherpa-onnx-model-detect.h"
|
|
6
|
+
#include <cstdint>
|
|
7
|
+
#include <memory>
|
|
8
|
+
#include <optional>
|
|
9
|
+
#include <string>
|
|
10
|
+
#include <vector>
|
|
11
|
+
|
|
12
|
+
namespace sherpaonnx {
|
|
13
|
+
|
|
14
|
+
struct EnhancementInitializeResult {
|
|
15
|
+
bool success = false;
|
|
16
|
+
std::vector<DetectedModel> detectedModels;
|
|
17
|
+
std::string error;
|
|
18
|
+
std::string modelType;
|
|
19
|
+
int32_t sampleRate = 0;
|
|
20
|
+
int32_t frameShiftInSamples = 0;
|
|
21
|
+
};
|
|
22
|
+
|
|
23
|
+
/// Audio returned by an enhancement call: the samples plus their sample rate.
/// A default-constructed value has no samples and a rate of 0.
struct EnhancedAudioResult {
  /// Output samples as 32-bit floats.
  std::vector<float> samples{};
  /// Sample rate associated with `samples`; 0 by default.
  int32_t sampleRate{0};
};
|
|
27
|
+
|
|
28
|
+
class EnhancementWrapper {
|
|
29
|
+
public:
|
|
30
|
+
EnhancementWrapper();
|
|
31
|
+
~EnhancementWrapper();
|
|
32
|
+
|
|
33
|
+
EnhancementInitializeResult initialize(
|
|
34
|
+
const std::string& modelDir,
|
|
35
|
+
const std::string& modelType = "auto",
|
|
36
|
+
int32_t numThreads = 1,
|
|
37
|
+
const std::optional<std::string>& provider = std::nullopt,
|
|
38
|
+
bool debug = false
|
|
39
|
+
);
|
|
40
|
+
|
|
41
|
+
EnhancedAudioResult runSamples(const std::vector<float>& samples, int32_t sampleRate);
|
|
42
|
+
|
|
43
|
+
int32_t getSampleRate() const;
|
|
44
|
+
|
|
45
|
+
bool isInitialized() const;
|
|
46
|
+
|
|
47
|
+
void release();
|
|
48
|
+
|
|
49
|
+
private:
|
|
50
|
+
class Impl;
|
|
51
|
+
std::unique_ptr<Impl> pImpl;
|
|
52
|
+
};
|
|
53
|
+
|
|
54
|
+
class OnlineEnhancementWrapper {
|
|
55
|
+
public:
|
|
56
|
+
OnlineEnhancementWrapper();
|
|
57
|
+
~OnlineEnhancementWrapper();
|
|
58
|
+
|
|
59
|
+
EnhancementInitializeResult initialize(
|
|
60
|
+
const std::string& modelDir,
|
|
61
|
+
const std::string& modelType = "auto",
|
|
62
|
+
int32_t numThreads = 1,
|
|
63
|
+
const std::optional<std::string>& provider = std::nullopt,
|
|
64
|
+
bool debug = false
|
|
65
|
+
);
|
|
66
|
+
|
|
67
|
+
EnhancedAudioResult runSamples(const std::vector<float>& samples, int32_t sampleRate);
|
|
68
|
+
EnhancedAudioResult flush();
|
|
69
|
+
void reset();
|
|
70
|
+
|
|
71
|
+
int32_t getSampleRate() const;
|
|
72
|
+
int32_t getFrameShiftInSamples() const;
|
|
73
|
+
|
|
74
|
+
bool isInitialized() const;
|
|
75
|
+
|
|
76
|
+
void release();
|
|
77
|
+
|
|
78
|
+
private:
|
|
79
|
+
class Impl;
|
|
80
|
+
std::unique_ptr<Impl> pImpl;
|
|
81
|
+
};
|
|
82
|
+
|
|
83
|
+
} // namespace sherpaonnx
|
|
84
|
+
|
|
85
|
+
#endif // SHERPA_ONNX_ENHANCEMENT_WRAPPER_H
|