react-native-sherpa-onnx 0.3.0 → 0.3.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +21 -7
- package/SherpaOnnx.podspec +1 -1
- package/android/build.gradle +35 -26
- package/android/prebuilt-download.gradle +27 -14
- package/android/src/main/cpp/CMakeLists.txt +51 -17
- package/android/src/main/cpp/jni/archive/sherpa-onnx-archive-helper.cpp +14 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-helper.cpp +16 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-helper.h +3 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-stt.cpp +19 -2
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect.h +2 -1
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-stt-wrapper.cpp +1 -0
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxModule.kt +114 -8
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxOnlineSttHelper.kt +535 -0
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxTtsHelper.kt +10 -10
- package/ios/SherpaOnnx+OnlineSTT.mm +365 -0
- package/ios/SherpaOnnx+TTS.mm +35 -9
- package/ios/SherpaOnnx.mm +6 -0
- package/ios/model_detect/sherpa-onnx-model-detect-helper.h +3 -0
- package/ios/model_detect/sherpa-onnx-model-detect-helper.mm +16 -0
- package/ios/model_detect/sherpa-onnx-model-detect-stt.mm +19 -2
- package/ios/model_detect/sherpa-onnx-model-detect.h +2 -1
- package/ios/online_stt/sherpa-onnx-online-stt-wrapper.h +85 -0
- package/ios/online_stt/sherpa-onnx-online-stt-wrapper.mm +270 -0
- package/lib/module/NativeSherpaOnnx.js.map +1 -1
- package/lib/module/index.js +2 -2
- package/lib/module/stt/index.js +4 -0
- package/lib/module/stt/index.js.map +1 -1
- package/lib/module/stt/streaming.js +257 -0
- package/lib/module/stt/streaming.js.map +1 -0
- package/lib/module/stt/streamingTypes.js +38 -0
- package/lib/module/stt/streamingTypes.js.map +1 -0
- package/lib/module/tts/index.js +4 -43
- package/lib/module/tts/index.js.map +1 -1
- package/lib/module/tts/streaming.js +220 -0
- package/lib/module/tts/streaming.js.map +1 -0
- package/lib/module/tts/streamingTypes.js +4 -0
- package/lib/module/tts/streamingTypes.js.map +1 -0
- package/lib/module/tts/types.js +8 -1
- package/lib/module/tts/types.js.map +1 -1
- package/lib/typescript/src/NativeSherpaOnnx.d.ts +66 -1
- package/lib/typescript/src/NativeSherpaOnnx.d.ts.map +1 -1
- package/lib/typescript/src/stt/index.d.ts +3 -0
- package/lib/typescript/src/stt/index.d.ts.map +1 -1
- package/lib/typescript/src/stt/streaming.d.ts +42 -0
- package/lib/typescript/src/stt/streaming.d.ts.map +1 -0
- package/lib/typescript/src/stt/streamingTypes.d.ts +122 -0
- package/lib/typescript/src/stt/streamingTypes.d.ts.map +1 -0
- package/lib/typescript/src/tts/index.d.ts +3 -1
- package/lib/typescript/src/tts/index.d.ts.map +1 -1
- package/lib/typescript/src/tts/streaming.d.ts +24 -0
- package/lib/typescript/src/tts/streaming.d.ts.map +1 -0
- package/lib/typescript/src/tts/streamingTypes.d.ts +27 -0
- package/lib/typescript/src/tts/streamingTypes.d.ts.map +1 -0
- package/lib/typescript/src/tts/types.d.ts +19 -6
- package/lib/typescript/src/tts/types.d.ts.map +1 -1
- package/package.json +1 -2
- package/src/NativeSherpaOnnx.ts +95 -0
- package/src/index.tsx +2 -2
- package/src/stt/index.ts +17 -0
- package/src/stt/streaming.ts +361 -0
- package/src/stt/streamingTypes.ts +151 -0
- package/src/tts/index.ts +6 -66
- package/src/tts/streaming.ts +336 -0
- package/src/tts/streamingTypes.ts +54 -0
- package/src/tts/types.ts +20 -10
- package/android/codegen.gradle +0 -57
|
@@ -0,0 +1,365 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* SherpaOnnx+OnlineSTT.mm
|
|
3
|
+
*
|
|
4
|
+
* Purpose: iOS TurboModule methods for streaming (online) STT: initializeOnlineSttWithOptions,
|
|
5
|
+
* createSttStream, acceptSttWaveform, decodeSttStream, getSttStreamResult, etc.
|
|
6
|
+
* Uses sherpa-onnx-online-stt-wrapper for native OnlineRecognizer.
|
|
7
|
+
*/
|
|
8
|
+
|
|
9
|
+
#import "SherpaOnnx.h"
|
|
10
|
+
#import <React/RCTLog.h>
|
|
11
|
+
|
|
12
|
+
#include "sherpa-onnx-online-stt-wrapper.h"
|
|
13
|
+
#include <memory>
|
|
14
|
+
#include <mutex>
|
|
15
|
+
#include <string>
|
|
16
|
+
#include <unordered_map>
|
|
17
|
+
|
|
18
|
+
static std::unordered_map<std::string, std::unique_ptr<sherpaonnx::OnlineSttWrapper>> g_online_stt_instances;
|
|
19
|
+
static std::unordered_map<std::string, std::string> g_online_stt_stream_to_instance;
|
|
20
|
+
static std::mutex g_online_stt_mutex;
|
|
21
|
+
|
|
22
|
+
/**
 * Looks up the online STT wrapper registered under instanceId.
 *
 * @param instanceId Registry key; nil or empty yields nullptr.
 * @return The wrapper owned by g_online_stt_instances, or nullptr when no entry exists.
 *
 * The registry mutex is held only for the duration of the lookup. The returned raw
 * pointer stays owned by the registry and is not protected after this function returns.
 */
static sherpaonnx::OnlineSttWrapper* getOnlineSttInstance(NSString* instanceId) {
    // Messaging nil returns 0, so this single test covers both nil and empty ids.
    if ([instanceId length] == 0) {
        return nullptr;
    }
    const std::string registryKey([instanceId UTF8String]);
    std::lock_guard<std::mutex> guard(g_online_stt_mutex);
    const auto entry = g_online_stt_instances.find(registryKey);
    if (entry == g_online_stt_instances.end()) {
        return nullptr;
    }
    // get() on an empty unique_ptr already yields nullptr.
    return entry->second.get();
}
|
|
29
|
+
|
|
30
|
+
/**
 * Resolves a stream id to the online STT wrapper that owns it.
 *
 * @param streamId Stream key; nil or empty yields nullptr.
 * @return The owning wrapper, or nullptr when the stream is not mapped or its
 *         instance is no longer registered.
 *
 * Both lookups (stream -> instance id, instance id -> wrapper) happen under the
 * registry mutex; the returned raw pointer is not protected afterwards.
 */
static sherpaonnx::OnlineSttWrapper* getOnlineSttInstanceForStream(NSString* streamId) {
    // Messaging nil returns 0, so this single test covers both nil and empty ids.
    if ([streamId length] == 0) {
        return nullptr;
    }
    const std::string streamKey([streamId UTF8String]);
    std::lock_guard<std::mutex> guard(g_online_stt_mutex);

    const auto mapping = g_online_stt_stream_to_instance.find(streamKey);
    if (mapping == g_online_stt_stream_to_instance.end()) {
        return nullptr;
    }
    const auto owner = g_online_stt_instances.find(mapping->second);
    if (owner == g_online_stt_instances.end()) {
        return nullptr;
    }
    // get() on an empty unique_ptr already yields nullptr.
    return owner->second.get();
}
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@implementation SherpaOnnx (OnlineSTT)
|
|
42
|
+
|
|
43
|
+
/**
 * Creates and registers a streaming (online) STT recognizer under instanceId.
 *
 * Rejects with INIT_ERROR when instanceId or options.modelDir is missing, when an
 * instance with the same id is already registered, or when the wrapper reports an
 * initialization failure. Resolves with { success: YES } otherwise.
 *
 * Defaults applied for absent options: modelType "transducer", decodingMethod
 * "greedy_search", maxActivePaths 4, hotwordsScore 1.5, numThreads 1, provider "cpu",
 * blankPenalty 0, rule1/2/3 min trailing silence 2.4 / 1.2 / 0, rule1/2/3 min
 * utterance length 0 / 0 / 20; every boolean option defaults to false.
 */
- (void)initializeOnlineSttWithOptions:(NSString *)instanceId
                               options:(JS::NativeSherpaOnnx::SpecInitializeOnlineSttWithOptionsOptions &)options
                               resolve:(RCTPromiseResolveBlock)resolve
                                reject:(RCTPromiseRejectBlock)reject
{
    if ([instanceId length] == 0) {
        reject(@"INIT_ERROR", @"instanceId is required", nil);
        return;
    }

    NSString *modelDir = options.modelDir();
    NSString *modelType = options.modelType();
    RCTLogInfo(@"[SherpaOnnx OnlineSTT] initializeOnlineSttWithOptions instanceId=%@ modelDir=%@ modelType=%@",
               instanceId, modelDir, modelType);
    if ([modelDir length] == 0) {
        reject(@"INIT_ERROR", @"modelDir is required", nil);
        return;
    }

    const std::string instanceIdStr([instanceId UTF8String]);
    const std::string modelDirStr([modelDir UTF8String]);
    const std::string modelTypeStr([modelType length] > 0 ? [modelType UTF8String] : "transducer");

    // Snapshot every optional field up front so the native call below reads as a flat list.
    auto enableEndpoint = options.enableEndpoint();
    NSString *decodingMethod = options.decodingMethod();
    auto maxActivePaths = options.maxActivePaths();
    NSString *hotwordsFile = options.hotwordsFile();
    auto hotwordsScore = options.hotwordsScore();
    auto numThreads = options.numThreads();
    NSString *provider = options.provider();
    NSString *ruleFsts = options.ruleFsts();
    NSString *ruleFars = options.ruleFars();
    auto blankPenalty = options.blankPenalty();
    auto debug = options.debug();
    auto rule1MustContainNonSilence = options.rule1MustContainNonSilence();
    auto rule1MinTrailingSilence = options.rule1MinTrailingSilence();
    auto rule1MinUtteranceLength = options.rule1MinUtteranceLength();
    auto rule2MustContainNonSilence = options.rule2MustContainNonSilence();
    auto rule2MinTrailingSilence = options.rule2MinTrailingSilence();
    auto rule2MinUtteranceLength = options.rule2MinUtteranceLength();
    auto rule3MustContainNonSilence = options.rule3MustContainNonSilence();
    auto rule3MinTrailingSilence = options.rule3MinTrailingSilence();
    auto rule3MinUtteranceLength = options.rule3MinUtteranceLength();

    // Small adapters that collapse "optional with fallback" into one expression each.
    const auto flagOrFalse = [](const auto &opt) -> bool {
        return opt.has_value() && opt.value();
    };
    const auto floatOr = [](const auto &opt, float fallback) -> float {
        return opt.has_value() ? (float)opt.value() : fallback;
    };
    const auto intOr = [](const auto &opt, int32_t fallback) -> int32_t {
        return opt.has_value() ? (int32_t)opt.value() : fallback;
    };

    @try {
        // The lock spans the existence check, the native initialization and the insert.
        std::lock_guard<std::mutex> registryGuard(g_online_stt_mutex);
        if (g_online_stt_instances.count(instanceIdStr) > 0) {
            reject(@"INIT_ERROR", @"Online STT instance already exists", nil);
            return;
        }

        RCTLogInfo(@"[SherpaOnnx OnlineSTT] creating wrapper and calling initialize");
        auto wrapper = std::make_unique<sherpaonnx::OnlineSttWrapper>();
        sherpaonnx::OnlineSttInitResult initResult = wrapper->initialize(
            modelDirStr,
            modelTypeStr,
            flagOrFalse(enableEndpoint),
            decodingMethod != nil ? [decodingMethod UTF8String] : "greedy_search",
            intOr(maxActivePaths, 4),
            hotwordsFile != nil ? [hotwordsFile UTF8String] : "",
            floatOr(hotwordsScore, 1.5f),
            intOr(numThreads, 1),
            provider != nil ? [provider UTF8String] : "cpu",
            ruleFsts != nil ? [ruleFsts UTF8String] : "",
            ruleFars != nil ? [ruleFars UTF8String] : "",
            floatOr(blankPenalty, 0.f),
            flagOrFalse(debug),
            flagOrFalse(rule1MustContainNonSilence),
            floatOr(rule1MinTrailingSilence, 2.4f),
            floatOr(rule1MinUtteranceLength, 0.f),
            flagOrFalse(rule2MustContainNonSilence),
            floatOr(rule2MinTrailingSilence, 1.2f),
            floatOr(rule2MinUtteranceLength, 0.f),
            flagOrFalse(rule3MustContainNonSilence),
            floatOr(rule3MinTrailingSilence, 0.f),
            floatOr(rule3MinUtteranceLength, 20.f)
        );

        if (!initResult.success) {
            RCTLogError(@"[SherpaOnnx OnlineSTT] initialize failed: %s", initResult.error.c_str());
            reject(@"INIT_ERROR", [NSString stringWithUTF8String:initResult.error.c_str()], nil);
            return;
        }

        // Ownership moves into the registry only after a successful initialize.
        g_online_stt_instances[instanceIdStr] = std::move(wrapper);
        RCTLogInfo(@"[SherpaOnnx OnlineSTT] init success for instanceId=%@", instanceId);
        resolve(@{ @"success": @YES });
    } @catch (NSException *exception) {
        NSString *errorMsg = [NSString stringWithFormat:@"Online STT init failed: %@", exception.reason];
        RCTLogError(@"%@", errorMsg);
        reject(@"INIT_ERROR", errorMsg, nil);
    }
}
|
|
131
|
+
|
|
132
|
+
/**
 * Creates a recognition stream on an existing online STT instance and records which
 * instance owns it so later per-stream calls can be routed by streamId alone.
 *
 * @param instanceId Id of an instance previously created by initializeOnlineSttWithOptions.
 * @param streamId   Caller-chosen stream key; must be non-empty.
 * @param hotwords   Optional hotwords string; nil is passed to the wrapper as "".
 *
 * Rejects with STREAM_ERROR when the instance is unknown, when streamId is nil/empty,
 * or when the wrapper refuses to create the stream. Resolves with nil on success.
 */
- (void)createSttStream:(NSString *)instanceId
               streamId:(NSString *)streamId
               hotwords:(NSString *)hotwords
                resolve:(RCTPromiseResolveBlock)resolve
                 reject:(RCTPromiseRejectBlock)reject
{
    sherpaonnx::OnlineSttWrapper* wrapper = getOnlineSttInstance(instanceId);
    if (wrapper == nullptr) {
        reject(@"STREAM_ERROR", @"Online STT instance not found", nil);
        return;
    }
    // A nil streamId would make UTF8String return NULL (undefined behaviour when
    // building std::string), and an empty id could never be found again by
    // getOnlineSttInstanceForStream, leaving the stream unusable and unreleasable.
    if ([streamId length] == 0) {
        reject(@"STREAM_ERROR", @"streamId is required", nil);
        return;
    }

    std::string instanceIdStr = [instanceId UTF8String];
    std::string streamIdStr = [streamId UTF8String];
    std::string hotwordsStr = hotwords != nil ? [hotwords UTF8String] : "";
    if (!wrapper->createStream(streamIdStr, hotwordsStr)) {
        reject(@"STREAM_ERROR", @"Stream already exists or create failed", nil);
        return;
    }

    {
        // Keep the critical section to the map update; resolve runs without the lock.
        std::lock_guard<std::mutex> lock(g_online_stt_mutex);
        g_online_stt_stream_to_instance[streamIdStr] = instanceIdStr;
    }
    resolve(nil);
}
|
|
154
|
+
|
|
155
|
+
/**
 * Feeds a chunk of audio samples into the stream identified by streamId.
 *
 * @param streamId   Stream created by createSttStream.
 * @param samples    Array whose elements are read with -floatValue.
 * @param sampleRate Sample rate of the chunk; truncated to int32_t for the wrapper.
 *
 * Rejects with STREAM_ERROR when the stream is unknown; resolves with nil otherwise.
 */
- (void)acceptSttWaveform:(NSString *)streamId
                  samples:(NSArray *)samples
               sampleRate:(double)sampleRate
                  resolve:(RCTPromiseResolveBlock)resolve
                   reject:(RCTPromiseRejectBlock)reject
{
    sherpaonnx::OnlineSttWrapper* owner = getOnlineSttInstanceForStream(streamId);
    if (owner == nullptr) {
        reject(@"STREAM_ERROR", @"Stream not found", nil);
        return;
    }

    // Copy the boxed numbers into a contiguous float buffer for the native call.
    std::vector<float> pcm;
    pcm.reserve([samples count]);
    for (NSNumber* sample in samples) {
        pcm.push_back([sample floatValue]);
    }

    const std::string streamKey([streamId UTF8String]);
    owner->acceptWaveform(streamKey, (int32_t)sampleRate, pcm.data(), pcm.size());
    resolve(nil);
}
|
|
175
|
+
|
|
176
|
+
/**
 * Tells the wrapper that no more audio will be supplied for streamId.
 *
 * Rejects with STREAM_ERROR when the stream is unknown; resolves with nil otherwise.
 */
- (void)sttStreamInputFinished:(NSString *)streamId
                       resolve:(RCTPromiseResolveBlock)resolve
                        reject:(RCTPromiseRejectBlock)reject
{
    sherpaonnx::OnlineSttWrapper* owner = getOnlineSttInstanceForStream(streamId);
    if (owner != nullptr) {
        const std::string streamKey([streamId UTF8String]);
        owner->inputFinished(streamKey);
        resolve(nil);
        return;
    }
    reject(@"STREAM_ERROR", @"Stream not found", nil);
}
|
|
189
|
+
|
|
190
|
+
/**
 * Runs one decode call on the stream identified by streamId.
 *
 * Rejects with STREAM_ERROR when the stream is unknown; resolves with nil otherwise.
 */
- (void)decodeSttStream:(NSString *)streamId
                resolve:(RCTPromiseResolveBlock)resolve
                 reject:(RCTPromiseRejectBlock)reject
{
    sherpaonnx::OnlineSttWrapper* owner = getOnlineSttInstanceForStream(streamId);
    if (owner != nullptr) {
        const std::string streamKey([streamId UTF8String]);
        owner->decode(streamKey);
        resolve(nil);
        return;
    }
    reject(@"STREAM_ERROR", @"Stream not found", nil);
}
|
|
203
|
+
|
|
204
|
+
/**
 * Reports the wrapper's isReady() state for streamId.
 *
 * Rejects with STREAM_ERROR when the stream is unknown; otherwise resolves with the
 * boxed boolean returned by the wrapper.
 */
- (void)isSttStreamReady:(NSString *)streamId
                 resolve:(RCTPromiseResolveBlock)resolve
                  reject:(RCTPromiseRejectBlock)reject
{
    sherpaonnx::OnlineSttWrapper* owner = getOnlineSttInstanceForStream(streamId);
    if (owner != nullptr) {
        const std::string streamKey([streamId UTF8String]);
        BOOL ready = owner->isReady(streamKey);
        resolve(@(ready));
        return;
    }
    reject(@"STREAM_ERROR", @"Stream not found", nil);
}
|
|
217
|
+
|
|
218
|
+
/**
 * Returns the current recognition result for streamId.
 *
 * Rejects with STREAM_ERROR when the stream is unknown. Otherwise resolves with
 * { text, tokens, timestamps } built from the wrapper's result.
 *
 * Tokens that cannot be converted from UTF-8 are emitted as "" rather than dropped,
 * so the tokens array keeps the same length as the wrapper's token list.
 */
- (void)getSttStreamResult:(NSString *)streamId
                   resolve:(RCTPromiseResolveBlock)resolve
                    reject:(RCTPromiseRejectBlock)reject
{
    sherpaonnx::OnlineSttWrapper* wrapper = getOnlineSttInstanceForStream(streamId);
    if (wrapper == nullptr) {
        reject(@"STREAM_ERROR", @"Stream not found", nil);
        return;
    }
    std::string streamIdStr = [streamId UTF8String];
    sherpaonnx::OnlineSttStreamResult r = wrapper->getResult(streamIdStr);

    NSMutableArray* tokens = [NSMutableArray arrayWithCapacity:r.tokens.size()];
    for (const auto& t : r.tokens) {
        // stringWithUTF8String: returns nil for bytes that are not valid UTF-8, and
        // -addObject: raises on nil; fall back to "" exactly like the text field does.
        NSString *tokenString = [NSString stringWithUTF8String:t.c_str()];
        [tokens addObject:tokenString ?: @""];
    }

    NSMutableArray* timestamps = [NSMutableArray arrayWithCapacity:r.timestamps.size()];
    for (float ts : r.timestamps) {
        [timestamps addObject:@(ts)];
    }

    resolve(@{
        @"text": [NSString stringWithUTF8String:r.text.c_str()] ?: @"",
        @"tokens": tokens,
        @"timestamps": timestamps
    });
}
|
|
243
|
+
|
|
244
|
+
/**
 * Reports the wrapper's isEndpoint() state for streamId.
 *
 * Rejects with STREAM_ERROR when the stream is unknown; otherwise resolves with the
 * boxed boolean returned by the wrapper.
 */
- (void)isSttStreamEndpoint:(NSString *)streamId
                    resolve:(RCTPromiseResolveBlock)resolve
                     reject:(RCTPromiseRejectBlock)reject
{
    sherpaonnx::OnlineSttWrapper* owner = getOnlineSttInstanceForStream(streamId);
    if (owner != nullptr) {
        const std::string streamKey([streamId UTF8String]);
        BOOL endpoint = owner->isEndpoint(streamKey);
        resolve(@(endpoint));
        return;
    }
    reject(@"STREAM_ERROR", @"Stream not found", nil);
}
|
|
257
|
+
|
|
258
|
+
/**
 * Asks the wrapper to reset the stream identified by streamId.
 *
 * Rejects with STREAM_ERROR when the stream is unknown; resolves with nil otherwise.
 */
- (void)resetSttStream:(NSString *)streamId
               resolve:(RCTPromiseResolveBlock)resolve
                reject:(RCTPromiseRejectBlock)reject
{
    sherpaonnx::OnlineSttWrapper* owner = getOnlineSttInstanceForStream(streamId);
    if (owner != nullptr) {
        const std::string streamKey([streamId UTF8String]);
        owner->resetStream(streamKey);
        resolve(nil);
        return;
    }
    reject(@"STREAM_ERROR", @"Stream not found", nil);
}
|
|
271
|
+
|
|
272
|
+
/**
 * Releases the stream identified by streamId and forgets its owning instance.
 *
 * Best-effort: always resolves with nil, whether or not the stream or its instance
 * is still registered. A nil/empty streamId is treated as "nothing to release".
 */
- (void)releaseSttStream:(NSString *)streamId
                 resolve:(RCTPromiseResolveBlock)resolve
                  reject:(RCTPromiseRejectBlock)reject
{
    // Guard first: with a nil streamId, UTF8String returns NULL and constructing a
    // std::string from it is undefined behaviour. Empty ids are never registered
    // lookups anyway, so resolving here gives the same outcome without the risk.
    if ([streamId length] == 0) {
        resolve(nil);
        return;
    }

    sherpaonnx::OnlineSttWrapper* wrapper = getOnlineSttInstanceForStream(streamId);
    std::string streamIdStr = [streamId UTF8String];
    if (wrapper != nullptr) {
        wrapper->releaseStream(streamIdStr);
    }
    {
        // Drop the routing entry even when the owning instance is already gone.
        std::lock_guard<std::mutex> lock(g_online_stt_mutex);
        g_online_stt_stream_to_instance.erase(streamIdStr);
    }
    resolve(nil);
}
|
|
287
|
+
|
|
288
|
+
/**
 * Unloads an online STT instance and removes every stream mapping that points at it.
 *
 * Resolves with nil when instanceId is nil/empty, when no such instance exists, or
 * after a successful unload. Rejects with RELEASE_ERROR if an NSException is raised.
 */
- (void)unloadOnlineStt:(NSString *)instanceId
                resolve:(RCTPromiseResolveBlock)resolve
                 reject:(RCTPromiseRejectBlock)reject
{
    if ([instanceId length] == 0) {
        resolve(nil);
        return;
    }

    const std::string instanceKey([instanceId UTF8String]);
    @try {
        std::lock_guard<std::mutex> guard(g_online_stt_mutex);
        auto found = g_online_stt_instances.find(instanceKey);
        if (found != g_online_stt_instances.end()) {
            found->second->unload();

            // Sweep the stream table; erase() hands back the next valid iterator.
            auto mapping = g_online_stt_stream_to_instance.begin();
            while (mapping != g_online_stt_stream_to_instance.end()) {
                if (mapping->second == instanceKey) {
                    mapping = g_online_stt_stream_to_instance.erase(mapping);
                } else {
                    ++mapping;
                }
            }

            g_online_stt_instances.erase(found);
        }
        resolve(nil);
    } @catch (NSException *exception) {
        reject(@"RELEASE_ERROR", [NSString stringWithFormat:@"unloadOnlineStt failed: %@", exception.reason], nil);
    }
}
|
|
313
|
+
|
|
314
|
+
/**
 * Convenience call that feeds one audio chunk, decodes while the stream reports
 * ready, and returns the current result in a single round trip.
 *
 * @param streamId   Stream created by createSttStream.
 * @param samples    Array of sample values; elements that are neither NSNumber nor
 *                   respond to -doubleValue are treated as 0.
 * @param sampleRate Sample rate of the chunk; truncated to int32_t for the wrapper.
 *
 * Rejects with STREAM_ERROR when the stream is unknown. Otherwise resolves with
 * { text, tokens, timestamps, isEndpoint }. Tokens that cannot be converted from
 * UTF-8 are emitted as "" so the tokens array keeps the wrapper's token count.
 */
- (void)processSttAudioChunk:(NSString *)streamId
                     samples:(NSArray *)samples
                  sampleRate:(double)sampleRate
                     resolve:(RCTPromiseResolveBlock)resolve
                      reject:(RCTPromiseRejectBlock)reject
{
    sherpaonnx::OnlineSttWrapper* wrapper = getOnlineSttInstanceForStream(streamId);
    if (wrapper == nullptr) {
        reject(@"STREAM_ERROR", @"Stream not found", nil);
        return;
    }
    std::string streamIdStr = [streamId UTF8String];

    // Convert the boxed samples into a contiguous float buffer.
    std::vector<float> floatSamples;
    NSUInteger count = [samples count];
    floatSamples.reserve(count);
    for (NSUInteger i = 0; i < count; i++) {
        id obj = [samples objectAtIndex:i];
        float val = 0.0f;
        if ([obj isKindOfClass:[NSNumber class]]) {
            val = [(NSNumber *)obj floatValue];
        } else if ([obj respondsToSelector:@selector(doubleValue)]) {
            val = (float)[(id)obj doubleValue];
        }
        floatSamples.push_back(val);
    }
    if (floatSamples.empty()) {
        RCTLogWarn(@"[SherpaOnnx OnlineSTT] processSttAudioChunk: no samples (count=%lu)", (unsigned long)count);
    }

    wrapper->acceptWaveform(streamIdStr, (int32_t)sampleRate, floatSamples.data(), floatSamples.size());
    // Drain everything the recognizer can decode from the audio received so far.
    while (wrapper->isReady(streamIdStr)) {
        wrapper->decode(streamIdStr);
    }
    sherpaonnx::OnlineSttStreamResult r = wrapper->getResult(streamIdStr);
    BOOL isEndpoint = wrapper->isEndpoint(streamIdStr);

    NSMutableArray* tokens = [NSMutableArray arrayWithCapacity:r.tokens.size()];
    for (const auto& t : r.tokens) {
        // stringWithUTF8String: returns nil for bytes that are not valid UTF-8, and
        // -addObject: raises on nil; fall back to "" exactly like the text field does.
        NSString *tokenString = [NSString stringWithUTF8String:t.c_str()];
        [tokens addObject:tokenString ?: @""];
    }
    NSMutableArray* timestamps = [NSMutableArray arrayWithCapacity:r.timestamps.size()];
    for (float ts : r.timestamps) {
        [timestamps addObject:@(ts)];
    }

    resolve(@{
        @"text": [NSString stringWithUTF8String:r.text.c_str()] ?: @"",
        @"tokens": tokens,
        @"timestamps": timestamps,
        @"isEndpoint": @(isEndpoint)
    });
}
|
|
364
|
+
|
|
365
|
+
@end
|
package/ios/SherpaOnnx+TTS.mm
CHANGED
|
@@ -515,10 +515,11 @@ std::vector<std::string> SplitTtsTokens(const std::string &text) {
|
|
|
515
515
|
}
|
|
516
516
|
|
|
517
517
|
- (void)generateTtsStream:(NSString *)instanceId
|
|
518
|
-
|
|
518
|
+
requestId:(NSString *)requestId
|
|
519
|
+
text:(NSString *)text
|
|
519
520
|
options:(NSDictionary *)options
|
|
520
|
-
|
|
521
|
-
|
|
521
|
+
resolve:(RCTPromiseResolveBlock)resolve
|
|
522
|
+
reject:(RCTPromiseRejectBlock)reject
|
|
522
523
|
{
|
|
523
524
|
if (instanceId == nil || [instanceId length] == 0) {
|
|
524
525
|
reject(@"TTS_STREAM_ERROR", @"instanceId is required", nil);
|
|
@@ -551,6 +552,7 @@ std::vector<std::string> SplitTtsTokens(const std::string &text) {
|
|
|
551
552
|
std::string textStr = [text UTF8String];
|
|
552
553
|
int32_t sampleRate = instRef->wrapper->getSampleRate();
|
|
553
554
|
NSString *instanceIdCopy = [instanceId copy];
|
|
555
|
+
NSString *requestIdCopy = (requestId != nil && [requestId length] > 0) ? [requestId copy] : nil;
|
|
554
556
|
|
|
555
557
|
__weak SherpaOnnx *weakSelf = self;
|
|
556
558
|
dispatch_async(dispatch_get_global_queue(QOS_CLASS_USER_INITIATED, 0), ^{
|
|
@@ -560,7 +562,7 @@ std::vector<std::string> SplitTtsTokens(const std::string &text) {
|
|
|
560
562
|
textStr,
|
|
561
563
|
static_cast<int32_t>(sid),
|
|
562
564
|
static_cast<float>(speed),
|
|
563
|
-
[weakSelf, sampleRate, instanceIdCopy, instRef](const float *samples, int32_t numSamples, float progress) -> int32_t {
|
|
565
|
+
[weakSelf, sampleRate, instanceIdCopy, requestIdCopy, instRef](const float *samples, int32_t numSamples, float progress) -> int32_t {
|
|
564
566
|
if (instRef->streamCancelled.load()) {
|
|
565
567
|
return 0;
|
|
566
568
|
}
|
|
@@ -570,13 +572,14 @@ std::vector<std::string> SplitTtsTokens(const std::string &text) {
|
|
|
570
572
|
[samplesArray addObject:@(samples[i])];
|
|
571
573
|
}
|
|
572
574
|
|
|
573
|
-
|
|
575
|
+
NSMutableDictionary *payload = [NSMutableDictionary dictionaryWithDictionary:@{
|
|
574
576
|
@"instanceId": instanceIdCopy,
|
|
575
577
|
@"samples": samplesArray,
|
|
576
578
|
@"sampleRate": @(sampleRate),
|
|
577
579
|
@"progress": @(progress),
|
|
578
580
|
@"isFinal": @NO
|
|
579
|
-
};
|
|
581
|
+
}];
|
|
582
|
+
if (requestIdCopy != nil) payload[@"requestId"] = requestIdCopy;
|
|
580
583
|
|
|
581
584
|
dispatch_async(dispatch_get_main_queue(), ^{
|
|
582
585
|
if (weakSelf) {
|
|
@@ -589,25 +592,48 @@ std::vector<std::string> SplitTtsTokens(const std::string &text) {
|
|
|
589
592
|
);
|
|
590
593
|
} @catch (NSException *exception) {
|
|
591
594
|
NSString *errorMsg = [NSString stringWithFormat:@"TTS streaming failed: %@", exception.reason];
|
|
595
|
+
NSMutableDictionary *errPayload = [NSMutableDictionary dictionaryWithDictionary:@{ @"instanceId": instanceIdCopy, @"message": errorMsg }];
|
|
596
|
+
if (requestIdCopy != nil) errPayload[@"requestId"] = requestIdCopy;
|
|
592
597
|
dispatch_async(dispatch_get_main_queue(), ^{
|
|
593
598
|
if (weakSelf) {
|
|
594
|
-
[weakSelf sendEventWithName:@"ttsStreamError" body
|
|
599
|
+
[weakSelf sendEventWithName:@"ttsStreamError" body:errPayload];
|
|
595
600
|
}
|
|
596
601
|
});
|
|
597
602
|
}
|
|
598
603
|
|
|
599
604
|
bool cancelled = instRef->streamCancelled.load();
|
|
600
605
|
if (!success && !cancelled) {
|
|
606
|
+
NSMutableDictionary *errPayload = [NSMutableDictionary dictionaryWithDictionary:@{ @"instanceId": instanceIdCopy, @"message": @"TTS streaming generation failed" }];
|
|
607
|
+
if (requestIdCopy != nil) errPayload[@"requestId"] = requestIdCopy;
|
|
608
|
+
dispatch_async(dispatch_get_main_queue(), ^{
|
|
609
|
+
if (weakSelf) {
|
|
610
|
+
[weakSelf sendEventWithName:@"ttsStreamError" body:errPayload];
|
|
611
|
+
}
|
|
612
|
+
});
|
|
613
|
+
}
|
|
614
|
+
|
|
615
|
+
// Emit final chunk (empty, progress 1, isFinal YES) when not cancelled, matching Android behaviour
|
|
616
|
+
if (!cancelled) {
|
|
617
|
+
NSMutableDictionary *finalPayload = [NSMutableDictionary dictionaryWithDictionary:@{
|
|
618
|
+
@"instanceId": instanceIdCopy,
|
|
619
|
+
@"samples": @[],
|
|
620
|
+
@"sampleRate": @(sampleRate),
|
|
621
|
+
@"progress": @1.0f,
|
|
622
|
+
@"isFinal": @YES
|
|
623
|
+
}];
|
|
624
|
+
if (requestIdCopy != nil) finalPayload[@"requestId"] = requestIdCopy;
|
|
601
625
|
dispatch_async(dispatch_get_main_queue(), ^{
|
|
602
626
|
if (weakSelf) {
|
|
603
|
-
[weakSelf sendEventWithName:@"
|
|
627
|
+
[weakSelf sendEventWithName:@"ttsStreamChunk" body:finalPayload];
|
|
604
628
|
}
|
|
605
629
|
});
|
|
606
630
|
}
|
|
607
631
|
|
|
632
|
+
NSMutableDictionary *endPayload = [NSMutableDictionary dictionaryWithDictionary:@{ @"instanceId": instanceIdCopy, @"cancelled": @(cancelled) }];
|
|
633
|
+
if (requestIdCopy != nil) endPayload[@"requestId"] = requestIdCopy;
|
|
608
634
|
dispatch_async(dispatch_get_main_queue(), ^{
|
|
609
635
|
if (weakSelf) {
|
|
610
|
-
[weakSelf sendEventWithName:@"ttsStreamEnd" body
|
|
636
|
+
[weakSelf sendEventWithName:@"ttsStreamEnd" body:endPayload];
|
|
611
637
|
}
|
|
612
638
|
});
|
|
613
639
|
|
package/ios/SherpaOnnx.mm
CHANGED
|
@@ -22,6 +22,12 @@
|
|
|
22
22
|
return @"SherpaOnnx";
|
|
23
23
|
}
|
|
24
24
|
|
|
25
|
+
/**
 * Default initializer; forwards to the superclass's -initWithDisabledObservation.
 * NOTE(review): the superclass and that initializer are declared outside this view;
 * presumably this opts out of listener-count gating for emitted events — confirm.
 */
- (instancetype)init
{
    return [super initWithDisabledObservation];
}
|
|
30
|
+
|
|
25
31
|
- (std::shared_ptr<facebook::react::TurboModule>)getTurboModule:
|
|
26
32
|
(const facebook::react::ObjCTurboModule::InitParams &)params
|
|
27
33
|
{
|
|
@@ -43,6 +43,9 @@ std::string FindLargestOnnxExcludingTokens(
|
|
|
43
43
|
const std::vector<std::string>& excludeTokens
|
|
44
44
|
);
|
|
45
45
|
|
|
46
|
+
/** Returns true if \p word appears in \p haystack as a standalone token (surrounded by separators: / - _ . space). */
|
|
47
|
+
bool ContainsWord(const std::string& haystack, const std::string& word);
|
|
48
|
+
|
|
46
49
|
} // namespace model_detect
|
|
47
50
|
} // namespace sherpaonnx
|
|
48
51
|
|
|
@@ -156,6 +156,22 @@ std::string FindFileByName(const std::string& baseDir, const std::string& fileNa
|
|
|
156
156
|
return "";
|
|
157
157
|
}
|
|
158
158
|
|
|
159
|
+
/**
 * Returns true when `word` occurs in `haystack` as a standalone token, i.e. the
 * characters immediately before and after the match are each either a string
 * boundary or one of the separators: NUL, '/', '-', '_', '.', ' '.
 * An empty `word` never matches.
 */
bool ContainsWord(const std::string& haystack, const std::string& word) {
    if (word.empty()) {
        return false;
    }

    const auto isBoundaryChar = [](char c) -> bool {
        switch (c) {
            case '\0':
            case '/':
            case '-':
            case '_':
            case '.':
            case ' ':
                return true;
            default:
                return false;
        }
    };

    // Visit every occurrence, including overlapping ones, by restarting one past each hit.
    for (size_t hit = haystack.find(word); hit != std::string::npos; hit = haystack.find(word, hit + 1)) {
        const size_t tail = hit + word.size();
        const bool openLeft = (hit == 0) || isBoundaryChar(haystack[hit - 1]);
        const bool openRight = (tail >= haystack.size()) || isBoundaryChar(haystack[tail]);
        if (openLeft && openRight) {
            return true;
        }
    }
    return false;
}
|
|
174
|
+
|
|
159
175
|
std::string FindDirectoryByName(const std::string& baseDir, const std::string& dirName, int maxDepth) {
|
|
160
176
|
std::string target = ToLower(dirName);
|
|
161
177
|
std::vector<std::string> toVisit = ListDirectories(baseDir);
|
|
@@ -32,6 +32,7 @@ SttModelKind ParseSttModelType(const std::string& modelType) {
|
|
|
32
32
|
if (modelType == "omnilingual") return SttModelKind::kOmnilingual;
|
|
33
33
|
if (modelType == "medasr") return SttModelKind::kMedAsr;
|
|
34
34
|
if (modelType == "telespeech_ctc") return SttModelKind::kTeleSpeechCtc;
|
|
35
|
+
if (modelType == "tone_ctc") return SttModelKind::kToneCtc;
|
|
35
36
|
return SttModelKind::kUnknown;
|
|
36
37
|
}
|
|
37
38
|
|
|
@@ -118,6 +119,10 @@ SttDetectResult DetectSttModel(
|
|
|
118
119
|
bool isLikelyOmnilingual = modelDirLower.find("omnilingual") != std::string::npos;
|
|
119
120
|
bool isLikelyMedAsr = modelDirLower.find("medasr") != std::string::npos;
|
|
120
121
|
bool isLikelyTeleSpeech = modelDirLower.find("telespeech") != std::string::npos;
|
|
122
|
+
// Tone CTC: match "tone" only as standalone word (not e.g. "cantonese"); also accept "t-one" / "t_one"
|
|
123
|
+
bool isLikelyToneCtc = modelDirLower.find("t-one") != std::string::npos ||
|
|
124
|
+
modelDirLower.find("t_one") != std::string::npos ||
|
|
125
|
+
ContainsWord(modelDirLower, "tone");
|
|
121
126
|
|
|
122
127
|
bool hasMoonshine = !moonshinePreprocess.empty() && !moonshineUncachedDecode.empty() &&
|
|
123
128
|
!moonshineCachedDecode.empty() && !moonshineEncode.empty();
|
|
@@ -127,6 +132,7 @@ SttDetectResult DetectSttModel(
|
|
|
127
132
|
bool hasOmnilingual = !ctcModelPath.empty() && isLikelyOmnilingual;
|
|
128
133
|
bool hasMedAsr = !ctcModelPath.empty() && isLikelyMedAsr;
|
|
129
134
|
bool hasTeleSpeechCtc = (!ctcModelPath.empty() || !paraformerModelPath.empty()) && isLikelyTeleSpeech;
|
|
135
|
+
bool hasToneCtc = !ctcModelPath.empty() && isLikelyToneCtc;
|
|
130
136
|
|
|
131
137
|
if (hasTransducer) {
|
|
132
138
|
if (isLikelyNemo || isLikelyTdt) {
|
|
@@ -178,6 +184,9 @@ SttDetectResult DetectSttModel(
|
|
|
178
184
|
if (hasTeleSpeechCtc) {
|
|
179
185
|
result.detectedModels.push_back({"telespeech_ctc", modelDir});
|
|
180
186
|
}
|
|
187
|
+
if (hasToneCtc) {
|
|
188
|
+
result.detectedModels.push_back({"tone_ctc", modelDir});
|
|
189
|
+
}
|
|
181
190
|
|
|
182
191
|
SttModelKind selected = SttModelKind::kUnknown;
|
|
183
192
|
|
|
@@ -201,7 +210,8 @@ SttDetectResult DetectSttModel(
|
|
|
201
210
|
return result;
|
|
202
211
|
}
|
|
203
212
|
if ((selected == SttModelKind::kNemoCtc || selected == SttModelKind::kWenetCtc ||
|
|
204
|
-
selected == SttModelKind::kSenseVoice || selected == SttModelKind::kZipformerCtc
|
|
213
|
+
selected == SttModelKind::kSenseVoice || selected == SttModelKind::kZipformerCtc ||
|
|
214
|
+
selected == SttModelKind::kToneCtc) &&
|
|
205
215
|
ctcModelPath.empty()) {
|
|
206
216
|
result.error = "CTC model requested but model.onnx not found in " + modelDir;
|
|
207
217
|
return result;
|
|
@@ -242,6 +252,10 @@ SttDetectResult DetectSttModel(
|
|
|
242
252
|
result.error = "TeleSpeech CTC model requested but model not found in " + modelDir;
|
|
243
253
|
return result;
|
|
244
254
|
}
|
|
255
|
+
if (selected == SttModelKind::kToneCtc && !hasToneCtc) {
|
|
256
|
+
result.error = "Tone CTC model requested but path does not contain 'tone' (as a word), 't-one', or 't_one' (e.g. sherpa-onnx-streaming-t-one-*) in " + modelDir;
|
|
257
|
+
return result;
|
|
258
|
+
}
|
|
245
259
|
} else {
|
|
246
260
|
if (hasTransducer) {
|
|
247
261
|
selected = (isLikelyNemo || isLikelyTdt) ? SttModelKind::kNemoTransducer : SttModelKind::kTransducer;
|
|
@@ -279,6 +293,8 @@ SttDetectResult DetectSttModel(
|
|
|
279
293
|
selected = SttModelKind::kMedAsr;
|
|
280
294
|
} else if (hasTeleSpeechCtc) {
|
|
281
295
|
selected = SttModelKind::kTeleSpeechCtc;
|
|
296
|
+
} else if (hasToneCtc) {
|
|
297
|
+
selected = SttModelKind::kToneCtc;
|
|
282
298
|
} else if (!ctcModelPath.empty()) {
|
|
283
299
|
selected = SttModelKind::kZipformerCtc;
|
|
284
300
|
}
|
|
@@ -299,7 +315,8 @@ SttDetectResult DetectSttModel(
|
|
|
299
315
|
} else if (selected == SttModelKind::kParaformer) {
|
|
300
316
|
result.paths.paraformerModel = paraformerModelPath;
|
|
301
317
|
} else if (selected == SttModelKind::kNemoCtc || selected == SttModelKind::kWenetCtc ||
|
|
302
|
-
selected == SttModelKind::kSenseVoice || selected == SttModelKind::kZipformerCtc
|
|
318
|
+
selected == SttModelKind::kSenseVoice || selected == SttModelKind::kZipformerCtc ||
|
|
319
|
+
selected == SttModelKind::kToneCtc) {
|
|
303
320
|
result.paths.ctcModel = ctcModelPath;
|
|
304
321
|
} else if (selected == SttModelKind::kWhisper) {
|
|
305
322
|
result.paths.whisperEncoder = encoderPath;
|