react-native-sherpa-onnx 0.3.8 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +20 -5
- package/SherpaOnnx.podspec +5 -1
- package/android/prebuilt-download.gradle +89 -49
- package/android/prebuilt-versions.gradle +1 -1
- package/android/src/main/assets/model_licenses/asr-models-license-status.csv +1 -0
- package/android/src/main/assets/model_licenses/speech-enhancement-models-license-status.csv +7 -0
- package/android/src/main/cpp/CMakeLists.txt +3 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-enhancement-wrapper.cpp +68 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-enhancement-wrapper.h +17 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-enhancement.cpp +119 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-helper.cpp +23 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-helper.h +9 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-stt.cpp +51 -8
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect.h +41 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-stt-wrapper.cpp +5 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-validate-enhancement.cpp +68 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-validate-enhancement.h +30 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-validate-stt.cpp +11 -0
- package/android/src/main/cpp/jni/module/sherpa-onnx-module-jni.cpp +21 -0
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxArchiveHelper.kt +110 -35
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxAssetHelper.kt +6 -0
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxEnhancementHelper.kt +377 -0
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxExtractionNotificationHelper.kt +102 -0
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxModule.kt +198 -18
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxSttHelper.kt +22 -0
- package/ios/Resources/model_licenses/asr-models-license-status.csv +1 -0
- package/ios/Resources/model_licenses/speech-enhancement-models-license-status.csv +7 -0
- package/ios/SherpaOnnx+Assets.mm +5 -0
- package/ios/SherpaOnnx+Enhancement.mm +435 -0
- package/ios/SherpaOnnx+STT.mm +13 -1
- package/ios/SherpaOnnx.mm +87 -17
- package/ios/enhancement/sherpa-onnx-enhancement-wrapper.h +85 -0
- package/ios/enhancement/sherpa-onnx-enhancement-wrapper.mm +218 -0
- package/ios/model_detect/sherpa-onnx-model-detect-enhancement.mm +92 -0
- package/ios/model_detect/sherpa-onnx-model-detect-helper.h +5 -0
- package/ios/model_detect/sherpa-onnx-model-detect-helper.mm +23 -0
- package/ios/model_detect/sherpa-onnx-model-detect-stt.mm +51 -7
- package/ios/model_detect/sherpa-onnx-model-detect.h +33 -0
- package/ios/model_detect/sherpa-onnx-validate-enhancement.h +30 -0
- package/ios/model_detect/sherpa-onnx-validate-enhancement.mm +69 -0
- package/ios/model_detect/sherpa-onnx-validate-stt.mm +11 -0
- package/ios/stt/sherpa-onnx-stt-wrapper.h +11 -1
- package/ios/stt/sherpa-onnx-stt-wrapper.mm +30 -2
- package/ios/tts/sherpa-onnx-tts-wrapper.mm +16 -0
- package/lib/module/NativeSherpaOnnx.js.map +1 -1
- package/lib/module/download/localModels.js +2 -3
- package/lib/module/download/localModels.js.map +1 -1
- package/lib/module/download/paths.js +2 -1
- package/lib/module/download/paths.js.map +1 -1
- package/lib/module/download/postDownloadProcessing.js +17 -4
- package/lib/module/download/postDownloadProcessing.js.map +1 -1
- package/lib/module/enhancement/index.js +63 -48
- package/lib/module/enhancement/index.js.map +1 -1
- package/lib/module/enhancement/streaming.js +60 -0
- package/lib/module/enhancement/streaming.js.map +1 -0
- package/lib/module/enhancement/streamingTypes.js +4 -0
- package/lib/module/enhancement/streamingTypes.js.map +1 -0
- package/lib/module/enhancement/types.js +4 -0
- package/lib/module/enhancement/types.js.map +1 -0
- package/lib/module/extraction/extractTarBz2.js +2 -2
- package/lib/module/extraction/extractTarBz2.js.map +1 -1
- package/lib/module/extraction/extractTarZst.js +2 -2
- package/lib/module/extraction/extractTarZst.js.map +1 -1
- package/lib/module/extraction/index.js +10 -5
- package/lib/module/extraction/index.js.map +1 -1
- package/lib/module/licenses.js +9 -3
- package/lib/module/licenses.js.map +1 -1
- package/lib/module/stt/index.js +4 -2
- package/lib/module/stt/index.js.map +1 -1
- package/lib/module/stt/streaming.js +2 -1
- package/lib/module/stt/streaming.js.map +1 -1
- package/lib/module/stt/types.js +3 -1
- package/lib/module/stt/types.js.map +1 -1
- package/lib/module/tts/index.js +4 -2
- package/lib/module/tts/index.js.map +1 -1
- package/lib/module/tts/streaming.js +3 -1
- package/lib/module/tts/streaming.js.map +1 -1
- package/lib/typescript/src/NativeSherpaOnnx.d.ts +70 -9
- package/lib/typescript/src/NativeSherpaOnnx.d.ts.map +1 -1
- package/lib/typescript/src/download/localModels.d.ts.map +1 -1
- package/lib/typescript/src/download/paths.d.ts +2 -1
- package/lib/typescript/src/download/paths.d.ts.map +1 -1
- package/lib/typescript/src/download/postDownloadProcessing.d.ts +9 -0
- package/lib/typescript/src/download/postDownloadProcessing.d.ts.map +1 -1
- package/lib/typescript/src/enhancement/index.d.ts +9 -46
- package/lib/typescript/src/enhancement/index.d.ts.map +1 -1
- package/lib/typescript/src/enhancement/streaming.d.ts +6 -0
- package/lib/typescript/src/enhancement/streaming.d.ts.map +1 -0
- package/lib/typescript/src/enhancement/streamingTypes.d.ts +12 -0
- package/lib/typescript/src/enhancement/streamingTypes.d.ts.map +1 -0
- package/lib/typescript/src/enhancement/types.d.ts +31 -0
- package/lib/typescript/src/enhancement/types.d.ts.map +1 -0
- package/lib/typescript/src/extraction/extractTarBz2.d.ts +2 -1
- package/lib/typescript/src/extraction/extractTarBz2.d.ts.map +1 -1
- package/lib/typescript/src/extraction/extractTarZst.d.ts +2 -1
- package/lib/typescript/src/extraction/extractTarZst.d.ts.map +1 -1
- package/lib/typescript/src/extraction/index.d.ts +1 -1
- package/lib/typescript/src/extraction/index.d.ts.map +1 -1
- package/lib/typescript/src/extraction/types.d.ts +12 -0
- package/lib/typescript/src/extraction/types.d.ts.map +1 -1
- package/lib/typescript/src/licenses.d.ts.map +1 -1
- package/lib/typescript/src/stt/index.d.ts +1 -1
- package/lib/typescript/src/stt/index.d.ts.map +1 -1
- package/lib/typescript/src/stt/streaming.d.ts.map +1 -1
- package/lib/typescript/src/stt/types.d.ts +16 -1
- package/lib/typescript/src/stt/types.d.ts.map +1 -1
- package/lib/typescript/src/tts/index.d.ts.map +1 -1
- package/lib/typescript/src/tts/streaming.d.ts.map +1 -1
- package/package.json +1 -1
- package/scripts/ci/check-model-csvs.sh +27 -2
- package/scripts/ci/collect_all_sherpa_model_streams.sh +3 -1
- package/scripts/ci/collect_one_sherpa_release_stream.sh +3 -1
- package/scripts/ci/sherpa_speech_enhancement_model_release_streams.json +13 -0
- package/scripts/ci/update_model_license_csv.sh +17 -17
- package/src/NativeSherpaOnnx.ts +108 -10
- package/src/download/localModels.ts +1 -3
- package/src/download/paths.ts +2 -1
- package/src/download/postDownloadProcessing.ts +24 -1
- package/src/enhancement/index.ts +120 -58
- package/src/enhancement/streaming.ts +105 -0
- package/src/enhancement/streamingTypes.ts +14 -0
- package/src/enhancement/types.ts +36 -0
- package/src/extraction/extractTarBz2.ts +7 -2
- package/src/extraction/extractTarZst.ts +7 -2
- package/src/extraction/index.ts +29 -6
- package/src/extraction/types.ts +16 -0
- package/src/licenses.ts +13 -2
- package/src/stt/index.ts +8 -7
- package/src/stt/streaming.ts +7 -1
- package/src/stt/types.ts +18 -0
- package/src/tts/index.ts +7 -7
- package/src/tts/streaming.ts +6 -3
- package/third_party/sherpa-onnx-prebuilt/ANDROID_RELEASE_TAG +1 -1
- package/third_party/sherpa-onnx-prebuilt/IOS_RELEASE_TAG +1 -1
|
@@ -0,0 +1,435 @@
|
|
|
1
|
+
#import "SherpaOnnx.h"
|
|
2
|
+
#import <React/RCTLog.h>
|
|
3
|
+
|
|
4
|
+
#include "sherpa-onnx-enhancement-wrapper.h"
|
|
5
|
+
#include "sherpa-onnx-model-detect.h"
|
|
6
|
+
#include "sherpa-onnx/c-api/cxx-api.h"
|
|
7
|
+
|
|
8
|
+
#include <memory>
|
|
9
|
+
#include <mutex>
|
|
10
|
+
#include <optional>
|
|
11
|
+
#include <string>
|
|
12
|
+
#include <unordered_map>
|
|
13
|
+
#include <vector>
|
|
14
|
+
|
|
15
|
+
// State kept for one enhancement instance addressed by the methods that use
// g_enhancement_instances (initialize / enhanceSamples / enhanceFile / unload).
struct EnhancementInstanceState {
  // Owned native wrapper; null until an initialize call creates it.
  std::unique_ptr<sherpaonnx::EnhancementWrapper> wrapper;
};

// State kept for one online enhancement instance addressed by the methods that
// use g_online_enhancement_instances (initialize / feed / flush / reset / unload).
struct OnlineEnhancementInstanceState {
  // Owned native wrapper; null until an initialize call creates it.
  std::unique_ptr<sherpaonnx::OnlineEnhancementWrapper> wrapper;
};

// Registry types: instance id (as passed from JS) -> owned instance state.
using EnhancementInstanceMap =
    std::unordered_map<std::string, std::unique_ptr<EnhancementInstanceState>>;
using OnlineEnhancementInstanceMap =
    std::unordered_map<std::string, std::unique_ptr<OnlineEnhancementInstanceState>>;

static EnhancementInstanceMap g_enhancement_instances;
static OnlineEnhancementInstanceMap g_online_enhancement_instances;

// Single mutex taken by every method in this file before touching either registry.
static std::mutex g_enhancement_mutex;
|
|
26
|
+
|
|
27
|
+
namespace {
|
|
28
|
+
|
|
29
|
+
// Maps a detected enhancement model kind to the string reported to JS.
// Any kind other than kGtcrn / kDpdfNet is reported as "unknown".
static NSString *enhancementKindToNSString(sherpaonnx::EnhancementModelKind kind) {
  if (kind == sherpaonnx::EnhancementModelKind::kGtcrn) {
    return @"gtcrn";
  }
  if (kind == sherpaonnx::EnhancementModelKind::kDpdfNet) {
    return @"dpdfnet";
  }
  return @"unknown";
}
|
|
37
|
+
|
|
38
|
+
// Converts an enhancement result into the dictionary handed to the JS promise:
//   "samples"    -> NSArray of NSNumber (one boxed float per sample)
//   "sampleRate" -> NSNumber
static NSDictionary *enhancedAudioToDict(const sherpaonnx::EnhancedAudioResult& r) {
  NSMutableArray *boxedSamples = [[NSMutableArray alloc] initWithCapacity:r.samples.size()];
  for (const float value : r.samples) {
    [boxedSamples addObject:[NSNumber numberWithFloat:value]];
  }

  NSNumber *rate = @(r.sampleRate);
  return @{
    @"sampleRate": rate,
    @"samples": boxedSamples
  };
}
|
|
48
|
+
|
|
49
|
+
// Converts a model-detection result into the dictionary handed to the JS promise.
// Keys: "success", "detectedModels" (array of {type, modelDir}), "modelType",
// plus "error" only when detection failed and carried a non-empty message.
static NSDictionary *enhancementDetectResultToDict(const sherpaonnx::EnhancementDetectResult& result) {
  NSMutableArray *models = [NSMutableArray array];
  for (const auto& entry : result.detectedModels) {
    // stringWithUTF8String: returns nil for invalid UTF-8; fall back to "".
    NSString *type = [NSString stringWithUTF8String:entry.type.c_str()];
    NSString *dir = [NSString stringWithUTF8String:entry.modelDir.c_str()];
    [models addObject:@{
      @"type": type ?: @"",
      @"modelDir": dir ?: @""
    }];
  }

  NSMutableDictionary *payload = [NSMutableDictionary dictionaryWithCapacity:4];
  payload[@"success"] = @(result.ok);
  payload[@"detectedModels"] = models;
  payload[@"modelType"] = enhancementKindToNSString(result.selectedKind);

  const bool failedWithMessage = !result.ok && !result.error.empty();
  if (failedWithMessage) {
    NSString *message = [NSString stringWithUTF8String:result.error.c_str()];
    payload[@"error"] = message ?: @"Enhancement model detection failed";
  }
  return payload;
}
|
|
68
|
+
|
|
69
|
+
} // namespace
|
|
70
|
+
|
|
71
|
+
@implementation SherpaOnnx (Enhancement)
|
|
72
|
+
|
|
73
|
+
/// Detects which enhancement model lives in `modelDir` and resolves with the
/// dictionary built by enhancementDetectResultToDict.
/// @param modelDir  Model directory; nil is treated as an empty path.
/// @param modelType Requested model type; nil or empty means "auto".
/// @param resolve   Called with the detection dictionary (also for unsuccessful
///                  detections, which carry success = NO).
/// @param reject    Called with code DETECT_ERROR when an exception is raised.
- (void)detectEnhancementModel:(NSString *)modelDir
                     modelType:(NSString *)modelType
                       resolve:(RCTPromiseResolveBlock)resolve
                        reject:(RCTPromiseRejectBlock)reject
{
  @try {
    std::string modelDirStr = (modelDir != nil) ? [modelDir UTF8String] : "";
    std::string modelTypeStr = (modelType != nil && [modelType length] > 0) ? [modelType UTF8String] : "auto";
    auto result = sherpaonnx::DetectEnhancementModel(modelDirStr, modelTypeStr);
    resolve(enhancementDetectResultToDict(result));
  } @catch (NSException *exception) {
    reject(@"DETECT_ERROR",
           [NSString stringWithFormat:@"Enhancement detect failed: %@", exception.reason],
           nil);
  } @catch (...) {
    // @catch (NSException *) does not match C++ exceptions thrown by the native
    // detector; without this clause they would escape and the promise would
    // never be settled.
    reject(@"DETECT_ERROR",
           @"Enhancement detect failed: unexpected native exception",
           nil);
  }
}
|
|
89
|
+
|
|
90
|
+
/// Creates (or re-initializes) the enhancement instance registered under
/// `instanceId` and loads the model found in `modelDir`.
/// @param instanceId Registry key; required, non-empty.
/// @param modelDir   Model directory; required, non-empty.
/// @param modelType  Requested model type; nil or empty means "auto".
/// @param numThreads Thread count passed to the wrapper; nil means 1.
/// @param provider   Optional execution provider; nil or empty means unset.
/// @param debug      Optional debug flag; nil means NO.
/// @param resolve    Called with {success, detectedModels, modelType, sampleRate}.
/// @param reject     Called with code ENHANCEMENT_INIT_ERROR on invalid
///                   arguments, failed initialization, or a raised exception.
- (void)initializeEnhancement:(NSString *)instanceId
                     modelDir:(NSString *)modelDir
                    modelType:(NSString *)modelType
                   numThreads:(NSNumber *)numThreads
                     provider:(NSString *)provider
                        debug:(NSNumber *)debug
                      resolve:(RCTPromiseResolveBlock)resolve
                       reject:(RCTPromiseRejectBlock)reject
{
  if (instanceId == nil || [instanceId length] == 0) {
    reject(@"ENHANCEMENT_INIT_ERROR", @"instanceId is required", nil);
    return;
  }
  if (modelDir == nil || [modelDir length] == 0) {
    reject(@"ENHANCEMENT_INIT_ERROR", @"modelDir is required", nil);
    return;
  }

  std::string instanceIdStr = [instanceId UTF8String];
  std::string modelDirStr = [modelDir UTF8String];
  std::string modelTypeStr = (modelType != nil && [modelType length] > 0) ? [modelType UTF8String] : "auto";
  int32_t numThreadsVal = numThreads != nil ? [numThreads intValue] : 1;
  bool debugVal = debug != nil && [debug boolValue];
  std::optional<std::string> providerOpt = std::nullopt;
  if (provider != nil && [provider length] > 0) {
    providerOpt = std::string([provider UTF8String]);
  }

  @try {
    std::lock_guard<std::mutex> lock(g_enhancement_mutex);

    // Look the instance up once; remember whether this call created it so a
    // failed first-time init does not leave a half-built entry behind.
    auto it = g_enhancement_instances.find(instanceIdStr);
    const bool createdHere = (it == g_enhancement_instances.end());
    if (createdHere) {
      it = g_enhancement_instances
               .emplace(instanceIdStr, std::make_unique<EnhancementInstanceState>())
               .first;
    }
    auto *inst = it->second.get();
    if (inst->wrapper == nullptr) {
      inst->wrapper = std::make_unique<sherpaonnx::EnhancementWrapper>();
    }

    auto result = inst->wrapper->initialize(
        modelDirStr,
        modelTypeStr,
        numThreadsVal,
        providerOpt,
        debugVal);

    if (!result.success) {
      // stringWithUTF8String: returns nil on invalid UTF-8, so always fall back
      // to the generic message (same pattern as enhancementDetectResultToDict).
      NSString *errorMsg = result.error.empty()
          ? nil
          : [NSString stringWithUTF8String:result.error.c_str()];
      if (createdHere) {
        // Otherwise later enhance calls would find a wrapper that never
        // initialized. A pre-existing instance is deliberately left in place.
        // NOTE(review): wrapper state after a failed re-initialize is not
        // visible from this file - confirm whether it should be dropped too.
        g_enhancement_instances.erase(it);
      }
      reject(@"ENHANCEMENT_INIT_ERROR", errorMsg ?: @"Failed to initialize enhancement", nil);
      return;
    }

    NSMutableArray *detectedModelsArray = [NSMutableArray array];
    for (const auto& model : result.detectedModels) {
      [detectedModelsArray addObject:@{
        @"type": [NSString stringWithUTF8String:model.type.c_str()] ?: @"",
        @"modelDir": [NSString stringWithUTF8String:model.modelDir.c_str()] ?: @""
      }];
    }

    resolve(@{
      @"success": @YES,
      @"detectedModels": detectedModelsArray,
      @"modelType": [NSString stringWithUTF8String:result.modelType.c_str()] ?: @"unknown",
      @"sampleRate": @(result.sampleRate)
    });
  } @catch (NSException *exception) {
    reject(@"ENHANCEMENT_INIT_ERROR",
           [NSString stringWithFormat:@"Enhancement init failed: %@", exception.reason],
           nil);
  } @catch (...) {
    // C++ exceptions from the native wrapper are not NSExceptions; settle the
    // promise instead of letting them escape.
    reject(@"ENHANCEMENT_INIT_ERROR",
           @"Enhancement init failed: unexpected native exception",
           nil);
  }
}
|
|
165
|
+
|
|
166
|
+
/// Runs the enhancement instance `instanceId` over an in-memory sample buffer.
/// @param instanceId Registry key; required, non-empty.
/// @param samples    Array of numbers, each read with -floatValue.
/// @param sampleRate Sample rate of `samples`; truncated to int32_t.
/// @param resolve    Called with {samples, sampleRate} from enhancedAudioToDict.
/// @param reject     Called with code ENHANCEMENT_ERROR when the id is missing,
///                   the instance is unknown, or an NSException is raised.
- (void)enhanceSamples:(NSString *)instanceId
               samples:(NSArray *)samples
            sampleRate:(double)sampleRate
               resolve:(RCTPromiseResolveBlock)resolve
                reject:(RCTPromiseRejectBlock)reject
{
  // Messaging nil yields 0, so this covers both nil and the empty string.
  if (instanceId.length == 0) {
    reject(@"ENHANCEMENT_ERROR", @"instanceId is required", nil);
    return;
  }

  const std::string key = [instanceId UTF8String];

  // Unbox the JS number array into a contiguous float buffer.
  std::vector<float> pcm;
  pcm.reserve(samples.count);
  for (NSNumber *value in samples) {
    pcm.push_back(value.floatValue);
  }
  const int32_t rate = static_cast<int32_t>(sampleRate);

  @try {
    std::lock_guard<std::mutex> guard(g_enhancement_mutex);
    auto found = g_enhancement_instances.find(key);
    const bool missing =
        (found == g_enhancement_instances.end()) || (found->second->wrapper == nullptr);
    if (missing) {
      reject(@"ENHANCEMENT_ERROR", @"Enhancement instance not found", nil);
      return;
    }
    auto enhanced = found->second->wrapper->runSamples(pcm, rate);
    resolve(enhancedAudioToDict(enhanced));
  } @catch (NSException *exception) {
    NSString *message =
        [NSString stringWithFormat:@"Enhance samples failed: %@", exception.reason];
    reject(@"ENHANCEMENT_ERROR", message, nil);
  }
}
|
|
199
|
+
|
|
200
|
+
/// Reads a wave file, runs the enhancement instance `instanceId` over it and
/// optionally writes the enhanced audio to `outputPath`.
/// @param instanceId Registry key; required, non-empty.
/// @param inputPath  Path of the wave file to read; required, non-empty.
/// @param outputPath Optional path for the enhanced wave; nil or empty skips writing.
/// @param resolve    Called with {samples, sampleRate} from enhancedAudioToDict.
/// @param reject     Called with code ENHANCEMENT_ERROR on invalid arguments,
///                   unreadable input, unknown instance, failed output write,
///                   or a raised exception.
- (void)enhanceFile:(NSString *)instanceId
          inputPath:(NSString *)inputPath
         outputPath:(NSString *)outputPath
            resolve:(RCTPromiseResolveBlock)resolve
             reject:(RCTPromiseRejectBlock)reject
{
  if (instanceId == nil || [instanceId length] == 0) {
    reject(@"ENHANCEMENT_ERROR", @"instanceId is required", nil);
    return;
  }
  if (inputPath == nil || [inputPath length] == 0) {
    reject(@"ENHANCEMENT_ERROR", @"inputPath is required", nil);
    return;
  }

  std::string instanceIdStr = [instanceId UTF8String];
  std::string inputPathStr = [inputPath UTF8String];

  @try {
    // Read before taking the lock so file I/O does not block other instances.
    sherpa_onnx::cxx::Wave wave = sherpa_onnx::cxx::ReadWave(inputPathStr);
    if (wave.samples.empty() || wave.sample_rate <= 0) {
      reject(@"ENHANCEMENT_ERROR", @"Failed to read input wave file", nil);
      return;
    }

    std::lock_guard<std::mutex> lock(g_enhancement_mutex);
    auto it = g_enhancement_instances.find(instanceIdStr);
    if (it == g_enhancement_instances.end() || it->second->wrapper == nullptr) {
      reject(@"ENHANCEMENT_ERROR", @"Enhancement instance not found", nil);
      return;
    }
    auto out = it->second->wrapper->runSamples(wave.samples, wave.sample_rate);

    if (outputPath != nil && [outputPath length] > 0) {
      sherpa_onnx::cxx::Wave outputWave;
      outputWave.samples = out.samples;
      outputWave.sample_rate = out.sampleRate;
      std::string outputPathStr = [outputPath UTF8String];
      // WriteWave reports failure through its return value; surface it rather
      // than resolving as if the file had been written.
      if (!sherpa_onnx::cxx::WriteWave(outputPathStr, outputWave)) {
        reject(@"ENHANCEMENT_ERROR", @"Failed to write output wave file", nil);
        return;
      }
    }

    resolve(enhancedAudioToDict(out));
  } @catch (NSException *exception) {
    reject(@"ENHANCEMENT_ERROR",
           [NSString stringWithFormat:@"Enhance file failed: %@", exception.reason],
           nil);
  } @catch (...) {
    // C++ exceptions from the native calls are not NSExceptions; settle the
    // promise instead of letting them escape.
    reject(@"ENHANCEMENT_ERROR",
           @"Enhance file failed: unexpected native exception",
           nil);
  }
}
|
|
248
|
+
|
|
249
|
+
/// Resolves with the sample rate reported by the wrapper of `instanceId`.
/// @param instanceId Registry key; required, non-empty.
/// @param resolve    Called with the wrapper's getSampleRate() boxed as NSNumber.
/// @param reject     Called with code ENHANCEMENT_ERROR when the id is missing
///                   or no such instance is registered.
- (void)getEnhancementSampleRate:(NSString *)instanceId
                         resolve:(RCTPromiseResolveBlock)resolve
                          reject:(RCTPromiseRejectBlock)reject
{
  // Messaging nil yields 0, so this covers both nil and the empty string.
  if (instanceId.length == 0) {
    reject(@"ENHANCEMENT_ERROR", @"instanceId is required", nil);
    return;
  }
  const std::string key = [instanceId UTF8String];

  std::lock_guard<std::mutex> guard(g_enhancement_mutex);
  auto entry = g_enhancement_instances.find(key);
  if (entry != g_enhancement_instances.end() && entry->second->wrapper != nullptr) {
    resolve(@(entry->second->wrapper->getSampleRate()));
    return;
  }
  reject(@"ENHANCEMENT_ERROR", @"Enhancement instance not found", nil);
}
|
|
267
|
+
|
|
268
|
+
/// Releases and unregisters the enhancement instance `instanceId`.
/// Always resolves with nil: a nil/empty id or an unknown instance is a no-op.
/// @param instanceId Registry key of the instance to drop.
/// @param resolve    Called with nil once the instance (if any) is removed.
/// @param reject     Unused; unloading never rejects.
- (void)unloadEnhancement:(NSString *)instanceId
                  resolve:(RCTPromiseResolveBlock)resolve
                   reject:(RCTPromiseRejectBlock)reject
{
  if (instanceId == nil || [instanceId length] == 0) {
    resolve(nil);
    return;
  }
  std::string instanceIdStr = [instanceId UTF8String];
  std::lock_guard<std::mutex> lock(g_enhancement_mutex);
  auto it = g_enhancement_instances.find(instanceIdStr);
  if (it != g_enhancement_instances.end()) {
    if (it->second != nullptr && it->second->wrapper != nullptr) {
      it->second->wrapper->release();
    }
    // Erase even when the wrapper was never created, so a stale entry cannot
    // linger in the registry after an unload request.
    g_enhancement_instances.erase(it);
  }
  resolve(nil);
}
|
|
285
|
+
|
|
286
|
+
/// Creates (or re-initializes) the online enhancement instance registered under
/// `instanceId` and loads the model found in `modelDir`.
/// @param instanceId Registry key; required, non-empty.
/// @param modelDir   Model directory; required, non-empty.
/// @param modelType  Requested model type; nil or empty means "auto".
/// @param numThreads Thread count passed to the wrapper; nil means 1.
/// @param provider   Optional execution provider; nil or empty means unset.
/// @param debug      Optional debug flag; nil means NO.
/// @param resolve    Called with {success, sampleRate, frameShiftInSamples}.
/// @param reject     Called with code ONLINE_ENHANCEMENT_INIT_ERROR on invalid
///                   arguments, failed initialization, or a raised exception.
- (void)initializeOnlineEnhancement:(NSString *)instanceId
                           modelDir:(NSString *)modelDir
                          modelType:(NSString *)modelType
                         numThreads:(NSNumber *)numThreads
                           provider:(NSString *)provider
                              debug:(NSNumber *)debug
                            resolve:(RCTPromiseResolveBlock)resolve
                             reject:(RCTPromiseRejectBlock)reject
{
  if (instanceId == nil || [instanceId length] == 0) {
    reject(@"ONLINE_ENHANCEMENT_INIT_ERROR", @"instanceId is required", nil);
    return;
  }
  if (modelDir == nil || [modelDir length] == 0) {
    reject(@"ONLINE_ENHANCEMENT_INIT_ERROR", @"modelDir is required", nil);
    return;
  }

  std::string instanceIdStr = [instanceId UTF8String];
  std::string modelDirStr = [modelDir UTF8String];
  std::string modelTypeStr = (modelType != nil && [modelType length] > 0) ? [modelType UTF8String] : "auto";
  int32_t numThreadsVal = numThreads != nil ? [numThreads intValue] : 1;
  bool debugVal = debug != nil && [debug boolValue];
  std::optional<std::string> providerOpt = std::nullopt;
  if (provider != nil && [provider length] > 0) {
    providerOpt = std::string([provider UTF8String]);
  }

  @try {
    std::lock_guard<std::mutex> lock(g_enhancement_mutex);

    // Look the instance up once; remember whether this call created it so a
    // failed first-time init does not leave a half-built entry behind.
    auto it = g_online_enhancement_instances.find(instanceIdStr);
    const bool createdHere = (it == g_online_enhancement_instances.end());
    if (createdHere) {
      it = g_online_enhancement_instances
               .emplace(instanceIdStr, std::make_unique<OnlineEnhancementInstanceState>())
               .first;
    }
    auto *inst = it->second.get();
    if (inst->wrapper == nullptr) {
      inst->wrapper = std::make_unique<sherpaonnx::OnlineEnhancementWrapper>();
    }

    auto result = inst->wrapper->initialize(
        modelDirStr,
        modelTypeStr,
        numThreadsVal,
        providerOpt,
        debugVal);
    if (!result.success) {
      // stringWithUTF8String: returns nil on invalid UTF-8, so always fall back
      // to the generic message.
      NSString *errorMsg = result.error.empty()
          ? nil
          : [NSString stringWithUTF8String:result.error.c_str()];
      if (createdHere) {
        // Otherwise later feed/flush calls would find a wrapper that never
        // initialized. A pre-existing instance is deliberately left in place.
        // NOTE(review): wrapper state after a failed re-initialize is not
        // visible from this file - confirm whether it should be dropped too.
        g_online_enhancement_instances.erase(it);
      }
      reject(@"ONLINE_ENHANCEMENT_INIT_ERROR",
             errorMsg ?: @"Failed to initialize online enhancement",
             nil);
      return;
    }

    resolve(@{
      @"success": @YES,
      @"sampleRate": @(result.sampleRate),
      @"frameShiftInSamples": @(result.frameShiftInSamples)
    });
  } @catch (NSException *exception) {
    reject(@"ONLINE_ENHANCEMENT_INIT_ERROR",
           [NSString stringWithFormat:@"Online enhancement init failed: %@", exception.reason],
           nil);
  } @catch (...) {
    // C++ exceptions from the native wrapper are not NSExceptions; settle the
    // promise instead of letting them escape.
    reject(@"ONLINE_ENHANCEMENT_INIT_ERROR",
           @"Online enhancement init failed: unexpected native exception",
           nil);
  }
}
|
|
351
|
+
|
|
352
|
+
/// Feeds a chunk of samples to the online enhancement instance `instanceId`
/// and resolves with whatever enhanced audio the wrapper returns for it.
/// @param instanceId Registry key; required, non-empty.
/// @param samples    Array of numbers, each read with -floatValue.
/// @param sampleRate Sample rate of `samples`; truncated to int32_t.
/// @param resolve    Called with {samples, sampleRate} from enhancedAudioToDict.
/// @param reject     Called with code ONLINE_ENHANCEMENT_ERROR when the id is
///                   missing, the instance is unknown, or an exception is raised.
- (void)feedEnhancementSamples:(NSString *)instanceId
                       samples:(NSArray *)samples
                    sampleRate:(double)sampleRate
                       resolve:(RCTPromiseResolveBlock)resolve
                        reject:(RCTPromiseRejectBlock)reject
{
  if (instanceId == nil || [instanceId length] == 0) {
    reject(@"ONLINE_ENHANCEMENT_ERROR", @"instanceId is required", nil);
    return;
  }
  std::string instanceIdStr = [instanceId UTF8String];

  // Same guarded structure as enhanceSamples: any exception raised while
  // converting or processing settles the promise instead of escaping.
  @try {
    std::vector<float> floatSamples;
    floatSamples.reserve([samples count]);
    for (NSNumber *n in samples) {
      floatSamples.push_back([n floatValue]);
    }

    std::lock_guard<std::mutex> lock(g_enhancement_mutex);
    auto it = g_online_enhancement_instances.find(instanceIdStr);
    if (it == g_online_enhancement_instances.end() || it->second->wrapper == nullptr) {
      reject(@"ONLINE_ENHANCEMENT_ERROR", @"Online enhancement instance not found", nil);
      return;
    }
    auto out = it->second->wrapper->runSamples(floatSamples, static_cast<int32_t>(sampleRate));
    resolve(enhancedAudioToDict(out));
  } @catch (NSException *exception) {
    reject(@"ONLINE_ENHANCEMENT_ERROR",
           [NSString stringWithFormat:@"Feed enhancement samples failed: %@", exception.reason],
           nil);
  } @catch (...) {
    // C++ exceptions from the native wrapper are not NSExceptions.
    reject(@"ONLINE_ENHANCEMENT_ERROR",
           @"Feed enhancement samples failed: unexpected native exception",
           nil);
  }
}
|
|
378
|
+
|
|
379
|
+
/// Calls flush() on the online enhancement instance `instanceId` and resolves
/// with the audio it returns.
/// @param instanceId Registry key; required, non-empty.
/// @param resolve    Called with {samples, sampleRate} from enhancedAudioToDict.
/// @param reject     Called with code ONLINE_ENHANCEMENT_ERROR when the id is
///                   missing or no such instance is registered.
- (void)flushOnlineEnhancement:(NSString *)instanceId
                       resolve:(RCTPromiseResolveBlock)resolve
                        reject:(RCTPromiseRejectBlock)reject
{
  // Messaging nil yields 0, so this covers both nil and the empty string.
  if (instanceId.length == 0) {
    reject(@"ONLINE_ENHANCEMENT_ERROR", @"instanceId is required", nil);
    return;
  }
  const std::string key = [instanceId UTF8String];

  std::lock_guard<std::mutex> guard(g_enhancement_mutex);
  auto entry = g_online_enhancement_instances.find(key);
  const bool available =
      (entry != g_online_enhancement_instances.end()) && (entry->second->wrapper != nullptr);
  if (!available) {
    reject(@"ONLINE_ENHANCEMENT_ERROR", @"Online enhancement instance not found", nil);
    return;
  }

  auto flushed = entry->second->wrapper->flush();
  resolve(enhancedAudioToDict(flushed));
}
|
|
397
|
+
|
|
398
|
+
/// Calls reset() on the online enhancement instance `instanceId`.
/// A nil/empty id resolves with nil without doing anything; an unknown
/// instance rejects.
/// @param instanceId Registry key of the instance to reset.
/// @param resolve    Called with nil after the reset (or for a nil/empty id).
/// @param reject     Called with code ONLINE_ENHANCEMENT_ERROR when no such
///                   instance is registered.
- (void)resetOnlineEnhancement:(NSString *)instanceId
                       resolve:(RCTPromiseResolveBlock)resolve
                        reject:(RCTPromiseRejectBlock)reject
{
  // Messaging nil yields 0, so this covers both nil and the empty string.
  if (instanceId.length == 0) {
    resolve(nil);
    return;
  }
  const std::string key = [instanceId UTF8String];

  std::lock_guard<std::mutex> guard(g_enhancement_mutex);
  auto entry = g_online_enhancement_instances.find(key);
  if (entry != g_online_enhancement_instances.end() && entry->second->wrapper != nullptr) {
    entry->second->wrapper->reset();
    resolve(nil);
    return;
  }
  reject(@"ONLINE_ENHANCEMENT_ERROR", @"Online enhancement instance not found", nil);
}
|
|
416
|
+
|
|
417
|
+
/// Releases and unregisters the online enhancement instance `instanceId`.
/// Always resolves with nil: a nil/empty id or an unknown instance is a no-op.
/// @param instanceId Registry key of the instance to drop.
/// @param resolve    Called with nil once the instance (if any) is removed.
/// @param reject     Unused; unloading never rejects.
- (void)unloadOnlineEnhancement:(NSString *)instanceId
                        resolve:(RCTPromiseResolveBlock)resolve
                         reject:(RCTPromiseRejectBlock)reject
{
  if (instanceId == nil || [instanceId length] == 0) {
    resolve(nil);
    return;
  }
  std::string instanceIdStr = [instanceId UTF8String];
  std::lock_guard<std::mutex> lock(g_enhancement_mutex);
  auto it = g_online_enhancement_instances.find(instanceIdStr);
  if (it != g_online_enhancement_instances.end()) {
    if (it->second != nullptr && it->second->wrapper != nullptr) {
      it->second->wrapper->release();
    }
    // Erase even when the wrapper was never created, so a stale entry cannot
    // linger in the registry after an unload request.
    g_online_enhancement_instances.erase(it);
  }
  resolve(nil);
}
|
|
434
|
+
|
|
435
|
+
@end
|
package/ios/SherpaOnnx+STT.mm
CHANGED
|
@@ -36,6 +36,7 @@ static NSString *sttModelKindToNSString(sherpaonnx::SttModelKind kind) {
|
|
|
36
36
|
case K::kZipformerCtc: return @"zipformer_ctc";
|
|
37
37
|
case K::kWhisper: return @"whisper";
|
|
38
38
|
case K::kFunAsrNano: return @"funasr_nano";
|
|
39
|
+
case K::kQwen3Asr: return @"qwen3_asr";
|
|
39
40
|
case K::kFireRedAsr: return @"fire_red_asr";
|
|
40
41
|
case K::kMoonshine: return @"moonshine";
|
|
41
42
|
case K::kMoonshineV2: return @"moonshine_v2";
|
|
@@ -164,10 +165,12 @@ static NSDictionary *sttResultToDict(const sherpaonnx::SttRecognitionResult& r)
|
|
|
164
165
|
sherpaonnx::SttSenseVoiceOptions senseVoiceOpts;
|
|
165
166
|
sherpaonnx::SttCanaryOptions canaryOpts;
|
|
166
167
|
sherpaonnx::SttFunAsrNanoOptions funasrNanoOpts;
|
|
168
|
+
sherpaonnx::SttQwen3AsrOptions qwen3AsrOpts;
|
|
167
169
|
const sherpaonnx::SttWhisperOptions *whisperOptsPtr = nullptr;
|
|
168
170
|
const sherpaonnx::SttSenseVoiceOptions *senseVoiceOptsPtr = nullptr;
|
|
169
171
|
const sherpaonnx::SttCanaryOptions *canaryOptsPtr = nullptr;
|
|
170
172
|
const sherpaonnx::SttFunAsrNanoOptions *funasrNanoOptsPtr = nullptr;
|
|
173
|
+
const sherpaonnx::SttQwen3AsrOptions *qwen3AsrOptsPtr = nullptr;
|
|
171
174
|
if (modelOptions != nil && [modelOptions isKindOfClass:[NSDictionary class]]) {
|
|
172
175
|
NSDictionary *w = modelOptions[@"whisper"];
|
|
173
176
|
if ([w isKindOfClass:[NSDictionary class]]) {
|
|
@@ -202,12 +205,21 @@ static NSDictionary *sttResultToDict(const sherpaonnx::SttRecognitionResult& r)
|
|
|
202
205
|
if (fn[@"hotwords"] != nil) funasrNanoOpts.hotwords = std::string([(NSString *)fn[@"hotwords"] UTF8String]);
|
|
203
206
|
funasrNanoOptsPtr = &funasrNanoOpts;
|
|
204
207
|
}
|
|
208
|
+
NSDictionary *q3 = modelOptions[@"qwen3Asr"];
|
|
209
|
+
if ([q3 isKindOfClass:[NSDictionary class]]) {
|
|
210
|
+
if (q3[@"maxTotalLen"] != nil) qwen3AsrOpts.max_total_len = [(NSNumber *)q3[@"maxTotalLen"] intValue];
|
|
211
|
+
if (q3[@"maxNewTokens"] != nil) qwen3AsrOpts.max_new_tokens = [(NSNumber *)q3[@"maxNewTokens"] intValue];
|
|
212
|
+
if (q3[@"temperature"] != nil) qwen3AsrOpts.temperature = [(NSNumber *)q3[@"temperature"] floatValue];
|
|
213
|
+
if (q3[@"topP"] != nil) qwen3AsrOpts.top_p = [(NSNumber *)q3[@"topP"] floatValue];
|
|
214
|
+
if (q3[@"seed"] != nil) qwen3AsrOpts.seed = [(NSNumber *)q3[@"seed"] intValue];
|
|
215
|
+
qwen3AsrOptsPtr = &qwen3AsrOpts;
|
|
216
|
+
}
|
|
205
217
|
}
|
|
206
218
|
|
|
207
219
|
sherpaonnx::SttInitializeResult result = inst->wrapper->initialize(
|
|
208
220
|
modelDirStr, preferInt8Opt, modelTypeOpt, debugVal, hotwordsFileOpt, hotwordsScoreOpt,
|
|
209
221
|
numThreadsOpt, providerOpt, ruleFstsOpt, ruleFarsOpt, ditherOpt,
|
|
210
|
-
whisperOptsPtr, senseVoiceOptsPtr, canaryOptsPtr, funasrNanoOptsPtr);
|
|
222
|
+
whisperOptsPtr, senseVoiceOptsPtr, canaryOptsPtr, funasrNanoOptsPtr, qwen3AsrOptsPtr);
|
|
211
223
|
|
|
212
224
|
if (result.success) {
|
|
213
225
|
RCTLogInfo(@"Sherpa-onnx initialized successfully");
|
package/ios/SherpaOnnx.mm
CHANGED
|
@@ -138,9 +138,15 @@
|
|
|
138
138
|
- (void)extractTarBz2:(NSString *)sourcePath
|
|
139
139
|
targetPath:(NSString *)targetPath
|
|
140
140
|
force:(BOOL)force
|
|
141
|
-
|
|
142
|
-
|
|
141
|
+
showNotificationsEnabled:(NSNumber *)showNotificationsEnabled
|
|
142
|
+
notificationTitle:(NSString *)notificationTitle
|
|
143
|
+
notificationText:(NSString *)notificationText
|
|
144
|
+
resolve:(RCTPromiseResolveBlock)resolve
|
|
145
|
+
reject:(RCTPromiseRejectBlock)reject
|
|
143
146
|
{
|
|
147
|
+
(void)showNotificationsEnabled;
|
|
148
|
+
(void)notificationTitle;
|
|
149
|
+
(void)notificationText;
|
|
144
150
|
SherpaOnnxArchiveHelper *helper = [SherpaOnnxArchiveHelper new];
|
|
145
151
|
NSDictionary *result = [helper extractTarBz2:sourcePath
|
|
146
152
|
targetPath:targetPath
|
|
@@ -165,9 +171,15 @@
|
|
|
165
171
|
- (void)extractTarZst:(NSString *)sourcePath
|
|
166
172
|
targetPath:(NSString *)targetPath
|
|
167
173
|
force:(BOOL)force
|
|
168
|
-
|
|
169
|
-
|
|
174
|
+
showNotificationsEnabled:(NSNumber *)showNotificationsEnabled
|
|
175
|
+
notificationTitle:(NSString *)notificationTitle
|
|
176
|
+
notificationText:(NSString *)notificationText
|
|
177
|
+
resolve:(RCTPromiseResolveBlock)resolve
|
|
178
|
+
reject:(RCTPromiseRejectBlock)reject
|
|
170
179
|
{
|
|
180
|
+
(void)showNotificationsEnabled;
|
|
181
|
+
(void)notificationTitle;
|
|
182
|
+
(void)notificationText;
|
|
171
183
|
SherpaOnnxArchiveHelper *helper = [SherpaOnnxArchiveHelper new];
|
|
172
184
|
NSDictionary *result = [helper extractTarZst:sourcePath
|
|
173
185
|
targetPath:targetPath
|
|
@@ -229,19 +241,33 @@
|
|
|
229
241
|
|
|
230
242
|
- (void)extractTarZstFromAsset:(NSString *)assetPath
|
|
231
243
|
targetPath:(NSString *)targetPath
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
244
|
+
force:(BOOL)force
|
|
245
|
+
showNotificationsEnabled:(NSNumber *)showNotificationsEnabled
|
|
246
|
+
notificationTitle:(NSString *)notificationTitle
|
|
247
|
+
notificationText:(NSString *)notificationText
|
|
248
|
+
resolve:(RCTPromiseResolveBlock)resolve
|
|
249
|
+
reject:(RCTPromiseRejectBlock)reject
|
|
235
250
|
{
|
|
251
|
+
(void)force;
|
|
252
|
+
(void)showNotificationsEnabled;
|
|
253
|
+
(void)notificationTitle;
|
|
254
|
+
(void)notificationText;
|
|
236
255
|
resolve(@{ @"success": @NO, @"reason": @"Not supported on iOS; use path-based extraction." });
|
|
237
256
|
}
|
|
238
257
|
|
|
239
258
|
- (void)extractTarBz2FromAsset:(NSString *)assetPath
|
|
240
259
|
targetPath:(NSString *)targetPath
|
|
241
|
-
force:(
|
|
242
|
-
|
|
243
|
-
|
|
260
|
+
force:(BOOL)force
|
|
261
|
+
showNotificationsEnabled:(NSNumber *)showNotificationsEnabled
|
|
262
|
+
notificationTitle:(NSString *)notificationTitle
|
|
263
|
+
notificationText:(NSString *)notificationText
|
|
264
|
+
resolve:(RCTPromiseResolveBlock)resolve
|
|
265
|
+
reject:(RCTPromiseRejectBlock)reject
|
|
244
266
|
{
|
|
267
|
+
(void)force;
|
|
268
|
+
(void)showNotificationsEnabled;
|
|
269
|
+
(void)notificationTitle;
|
|
270
|
+
(void)notificationText;
|
|
245
271
|
resolve(@{ @"success": @NO, @"reason": @"Not supported on iOS; use path-based extraction." });
|
|
246
272
|
}
|
|
247
273
|
|
|
@@ -329,15 +355,59 @@
|
|
|
329
355
|
nil);
|
|
330
356
|
return;
|
|
331
357
|
}
|
|
332
|
-
NSString *
|
|
333
|
-
|
|
358
|
+
NSString *fullPath = nil;
|
|
359
|
+
NSBundle *mainBundle = [NSBundle mainBundle];
|
|
360
|
+
NSString *assetDir = [assetPath stringByDeletingLastPathComponent];
|
|
361
|
+
NSString *assetNameWithExt = [assetPath lastPathComponent];
|
|
362
|
+
NSString *assetName = [assetNameWithExt stringByDeletingPathExtension];
|
|
363
|
+
NSString *assetExt = [assetNameWithExt pathExtension];
|
|
364
|
+
|
|
365
|
+
// 1) App bundle: regular nested path (keeps generic asset support)
|
|
366
|
+
NSString *mainPath = [mainBundle pathForResource:assetName
|
|
367
|
+
ofType:assetExt.length > 0 ? assetExt : nil
|
|
368
|
+
inDirectory:assetDir.length > 0 ? assetDir : nil];
|
|
369
|
+
if (mainPath.length > 0) {
|
|
370
|
+
fullPath = mainPath;
|
|
371
|
+
}
|
|
372
|
+
|
|
373
|
+
// 2) CocoaPods resource bundle: files are flattened into bundle root
|
|
374
|
+
if (!fullPath) {
|
|
375
|
+
NSString *resBundlePath = [mainBundle pathForResource:@"SherpaOnnxResources"
|
|
376
|
+
ofType:@"bundle"];
|
|
377
|
+
if (resBundlePath.length > 0) {
|
|
378
|
+
NSBundle *resBundle = [NSBundle bundleWithPath:resBundlePath];
|
|
379
|
+
if (resBundle) {
|
|
380
|
+
NSString *bundleRootPath = [resBundle pathForResource:assetName
|
|
381
|
+
ofType:assetExt.length > 0 ? assetExt : nil];
|
|
382
|
+
if (bundleRootPath.length > 0) {
|
|
383
|
+
fullPath = bundleRootPath;
|
|
384
|
+
}
|
|
385
|
+
}
|
|
386
|
+
}
|
|
387
|
+
}
|
|
388
|
+
|
|
389
|
+
if (!fullPath) {
|
|
390
|
+
reject(@"ASSET_READ_ERROR",
|
|
391
|
+
[NSString stringWithFormat:@"Failed to locate asset %@", assetPath],
|
|
392
|
+
nil);
|
|
393
|
+
return;
|
|
394
|
+
}
|
|
395
|
+
|
|
334
396
|
NSError *error = nil;
|
|
335
|
-
NSString *content = [NSString stringWithContentsOfFile:fullPath
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
397
|
+
NSString *content = [NSString stringWithContentsOfFile:fullPath
|
|
398
|
+
encoding:NSUTF8StringEncoding
|
|
399
|
+
error:&error];
|
|
400
|
+
if (error || content == nil) {
|
|
401
|
+
reject(@"ASSET_READ_ERROR",
|
|
402
|
+
[NSString stringWithFormat:@"Failed to read asset %@ at %@: %@",
|
|
403
|
+
assetPath,
|
|
404
|
+
fullPath,
|
|
405
|
+
error.localizedDescription ?: @"Unknown error"],
|
|
406
|
+
error);
|
|
407
|
+
return;
|
|
340
408
|
}
|
|
409
|
+
|
|
410
|
+
resolve(content);
|
|
341
411
|
}
|
|
342
412
|
|
|
343
413
|
@end
|