react-native-sherpa-onnx 0.3.0 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66) hide show
  1. package/README.md +21 -7
  2. package/SherpaOnnx.podspec +1 -1
  3. package/android/build.gradle +35 -26
  4. package/android/prebuilt-download.gradle +27 -14
  5. package/android/src/main/cpp/CMakeLists.txt +51 -17
  6. package/android/src/main/cpp/jni/archive/sherpa-onnx-archive-helper.cpp +14 -0
  7. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-helper.cpp +16 -0
  8. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-helper.h +3 -0
  9. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-stt.cpp +19 -2
  10. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect.h +2 -1
  11. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-stt-wrapper.cpp +1 -0
  12. package/android/src/main/java/com/sherpaonnx/SherpaOnnxModule.kt +114 -8
  13. package/android/src/main/java/com/sherpaonnx/SherpaOnnxOnlineSttHelper.kt +535 -0
  14. package/android/src/main/java/com/sherpaonnx/SherpaOnnxTtsHelper.kt +10 -10
  15. package/ios/SherpaOnnx+OnlineSTT.mm +365 -0
  16. package/ios/SherpaOnnx+TTS.mm +35 -9
  17. package/ios/SherpaOnnx.mm +6 -0
  18. package/ios/model_detect/sherpa-onnx-model-detect-helper.h +3 -0
  19. package/ios/model_detect/sherpa-onnx-model-detect-helper.mm +16 -0
  20. package/ios/model_detect/sherpa-onnx-model-detect-stt.mm +19 -2
  21. package/ios/model_detect/sherpa-onnx-model-detect.h +2 -1
  22. package/ios/online_stt/sherpa-onnx-online-stt-wrapper.h +85 -0
  23. package/ios/online_stt/sherpa-onnx-online-stt-wrapper.mm +270 -0
  24. package/lib/module/NativeSherpaOnnx.js.map +1 -1
  25. package/lib/module/index.js +2 -2
  26. package/lib/module/stt/index.js +4 -0
  27. package/lib/module/stt/index.js.map +1 -1
  28. package/lib/module/stt/streaming.js +257 -0
  29. package/lib/module/stt/streaming.js.map +1 -0
  30. package/lib/module/stt/streamingTypes.js +38 -0
  31. package/lib/module/stt/streamingTypes.js.map +1 -0
  32. package/lib/module/tts/index.js +4 -43
  33. package/lib/module/tts/index.js.map +1 -1
  34. package/lib/module/tts/streaming.js +220 -0
  35. package/lib/module/tts/streaming.js.map +1 -0
  36. package/lib/module/tts/streamingTypes.js +4 -0
  37. package/lib/module/tts/streamingTypes.js.map +1 -0
  38. package/lib/module/tts/types.js +8 -1
  39. package/lib/module/tts/types.js.map +1 -1
  40. package/lib/typescript/src/NativeSherpaOnnx.d.ts +66 -1
  41. package/lib/typescript/src/NativeSherpaOnnx.d.ts.map +1 -1
  42. package/lib/typescript/src/stt/index.d.ts +3 -0
  43. package/lib/typescript/src/stt/index.d.ts.map +1 -1
  44. package/lib/typescript/src/stt/streaming.d.ts +42 -0
  45. package/lib/typescript/src/stt/streaming.d.ts.map +1 -0
  46. package/lib/typescript/src/stt/streamingTypes.d.ts +122 -0
  47. package/lib/typescript/src/stt/streamingTypes.d.ts.map +1 -0
  48. package/lib/typescript/src/tts/index.d.ts +3 -1
  49. package/lib/typescript/src/tts/index.d.ts.map +1 -1
  50. package/lib/typescript/src/tts/streaming.d.ts +24 -0
  51. package/lib/typescript/src/tts/streaming.d.ts.map +1 -0
  52. package/lib/typescript/src/tts/streamingTypes.d.ts +27 -0
  53. package/lib/typescript/src/tts/streamingTypes.d.ts.map +1 -0
  54. package/lib/typescript/src/tts/types.d.ts +19 -6
  55. package/lib/typescript/src/tts/types.d.ts.map +1 -1
  56. package/package.json +1 -2
  57. package/src/NativeSherpaOnnx.ts +95 -0
  58. package/src/index.tsx +2 -2
  59. package/src/stt/index.ts +17 -0
  60. package/src/stt/streaming.ts +361 -0
  61. package/src/stt/streamingTypes.ts +151 -0
  62. package/src/tts/index.ts +6 -66
  63. package/src/tts/streaming.ts +336 -0
  64. package/src/tts/streamingTypes.ts +54 -0
  65. package/src/tts/types.ts +20 -10
  66. package/android/codegen.gradle +0 -57
@@ -0,0 +1,365 @@
1
+ /**
2
+ * SherpaOnnx+OnlineSTT.mm
3
+ *
4
+ * Purpose: iOS TurboModule methods for streaming (online) STT: initializeOnlineSttWithOptions,
5
+ * createSttStream, acceptSttWaveform, decodeSttStream, getSttStreamResult, etc.
6
+ * Uses sherpa-onnx-online-stt-wrapper for native OnlineRecognizer.
7
+ */
8
+
9
+ #import "SherpaOnnx.h"
10
+ #import <React/RCTLog.h>
11
+
12
+ #include "sherpa-onnx-online-stt-wrapper.h"
13
+ #include <memory>
14
+ #include <mutex>
15
+ #include <string>
16
+ #include <unordered_map>
17
+
18
+ static std::unordered_map<std::string, std::unique_ptr<sherpaonnx::OnlineSttWrapper>> g_online_stt_instances;
19
+ static std::unordered_map<std::string, std::string> g_online_stt_stream_to_instance;
20
+ static std::mutex g_online_stt_mutex;
21
+
22
/**
 * Looks up the online STT wrapper registered under `instanceId`.
 *
 * Returns nullptr when the id is nil or empty, when no entry exists for it, or
 * when the stored pointer is null. The registry mutex is held only while the
 * map is searched; the raw pointer handed back is not protected by the mutex
 * once this function returns.
 */
static sherpaonnx::OnlineSttWrapper* getOnlineSttInstance(NSString* instanceId) {
    // Messaging nil yields 0, so this single test covers both nil and @"".
    if ([instanceId length] == 0) {
        return nullptr;
    }
    const std::string registryKey([instanceId UTF8String]);

    std::lock_guard<std::mutex> guard(g_online_stt_mutex);
    const auto entry = g_online_stt_instances.find(registryKey);
    if (entry == g_online_stt_instances.end()) {
        return nullptr;
    }
    // unique_ptr::get() already yields nullptr for an empty holder.
    return entry->second.get();
}
29
+
30
/**
 * Resolves the online STT wrapper that owns the stream `streamId`.
 *
 * The stream id is first mapped to its instance id, then the instance id is
 * looked up in the instance registry. Returns nullptr when the id is nil or
 * empty, when the stream is unknown, or when its instance is no longer
 * registered. Both lookups happen under the registry mutex; the returned raw
 * pointer is not protected by the mutex after this function returns.
 */
static sherpaonnx::OnlineSttWrapper* getOnlineSttInstanceForStream(NSString* streamId) {
    // Messaging nil yields 0, so this single test covers both nil and @"".
    if ([streamId length] == 0) {
        return nullptr;
    }
    const std::string streamKey([streamId UTF8String]);

    std::lock_guard<std::mutex> guard(g_online_stt_mutex);
    const auto owner = g_online_stt_stream_to_instance.find(streamKey);
    if (owner == g_online_stt_stream_to_instance.end()) {
        return nullptr;
    }
    const auto entry = g_online_stt_instances.find(owner->second);
    if (entry == g_online_stt_instances.end()) {
        return nullptr;
    }
    return entry->second.get();
}
39
+
40
+
41
+ @implementation SherpaOnnx (OnlineSTT)
42
+
43
/**
 * Creates a streaming (online) STT recognizer and registers it under
 * `instanceId`.
 *
 * Rejects with INIT_ERROR when `instanceId` or `modelDir` is missing, when an
 * instance with the same id is already registered, or when the wrapper reports
 * an initialization failure. Resolves with `{ success: YES }` on success.
 *
 * Options that JS did not supply fall back to the defaults written out in the
 * `initialize(...)` call below. The arguments are positional, so their order
 * must stay in sync with OnlineSttWrapper::initialize.
 */
- (void)initializeOnlineSttWithOptions:(NSString *)instanceId
                               options:(JS::NativeSherpaOnnx::SpecInitializeOnlineSttWithOptionsOptions &)options
                               resolve:(RCTPromiseResolveBlock)resolve
                                reject:(RCTPromiseRejectBlock)reject
{
    if (instanceId == nil || [instanceId length] == 0) {
        reject(@"INIT_ERROR", @"instanceId is required", nil);
        return;
    }
    NSString *modelDir = options.modelDir();
    NSString *modelType = options.modelType();
    RCTLogInfo(@"[SherpaOnnx OnlineSTT] initializeOnlineSttWithOptions instanceId=%@ modelDir=%@ modelType=%@",
               instanceId, modelDir, modelType);
    if (modelDir == nil || [modelDir length] == 0) {
        reject(@"INIT_ERROR", @"modelDir is required", nil);
        return;
    }
    std::string instanceIdStr = [instanceId UTF8String];
    std::string modelDirStr = [modelDir UTF8String];
    std::string modelTypeStr = (modelType != nil && [modelType length] > 0) ? [modelType UTF8String] : "transducer";

    // Read every option up front; each is either an NSString (nil when absent)
    // or an optional-like value queried with has_value()/value().
    auto enableEndpoint = options.enableEndpoint();
    NSString *decodingMethod = options.decodingMethod();
    auto maxActivePaths = options.maxActivePaths();
    NSString *hotwordsFile = options.hotwordsFile();
    auto hotwordsScore = options.hotwordsScore();
    auto numThreads = options.numThreads();
    NSString *provider = options.provider();
    NSString *ruleFsts = options.ruleFsts();
    NSString *ruleFars = options.ruleFars();
    auto blankPenalty = options.blankPenalty();
    auto debug = options.debug();
    auto rule1MustContainNonSilence = options.rule1MustContainNonSilence();
    auto rule1MinTrailingSilence = options.rule1MinTrailingSilence();
    auto rule1MinUtteranceLength = options.rule1MinUtteranceLength();
    auto rule2MustContainNonSilence = options.rule2MustContainNonSilence();
    auto rule2MinTrailingSilence = options.rule2MinTrailingSilence();
    auto rule2MinUtteranceLength = options.rule2MinUtteranceLength();
    auto rule3MustContainNonSilence = options.rule3MustContainNonSilence();
    auto rule3MinTrailingSilence = options.rule3MinTrailingSilence();
    auto rule3MinUtteranceLength = options.rule3MinUtteranceLength();

    @try {
        // The registry mutex stays locked for the whole block, including the
        // wrapper->initialize call, so the duplicate check and the insertion
        // below are atomic with respect to other registry users.
        std::lock_guard<std::mutex> lock(g_online_stt_mutex);
        if (g_online_stt_instances.find(instanceIdStr) != g_online_stt_instances.end()) {
            reject(@"INIT_ERROR", @"Online STT instance already exists", nil);
            return;
        }
        RCTLogInfo(@"[SherpaOnnx OnlineSTT] creating wrapper and calling initialize");
        auto wrapper = std::make_unique<sherpaonnx::OnlineSttWrapper>();
        sherpaonnx::OnlineSttInitResult result = wrapper->initialize(
            modelDirStr,
            modelTypeStr,
            enableEndpoint.has_value() && enableEndpoint.value(),
            decodingMethod != nil ? [decodingMethod UTF8String] : "greedy_search",
            maxActivePaths.has_value() ? (int32_t)maxActivePaths.value() : 4,
            hotwordsFile != nil ? [hotwordsFile UTF8String] : "",
            hotwordsScore.has_value() ? (float)hotwordsScore.value() : 1.5f,
            numThreads.has_value() ? (int32_t)numThreads.value() : 1,
            provider != nil ? [provider UTF8String] : "cpu",
            ruleFsts != nil ? [ruleFsts UTF8String] : "",
            ruleFars != nil ? [ruleFars UTF8String] : "",
            blankPenalty.has_value() ? (float)blankPenalty.value() : 0.f,
            debug.has_value() && debug.value(),
            rule1MustContainNonSilence.has_value() && rule1MustContainNonSilence.value(),
            rule1MinTrailingSilence.has_value() ? (float)rule1MinTrailingSilence.value() : 2.4f,
            rule1MinUtteranceLength.has_value() ? (float)rule1MinUtteranceLength.value() : 0.f,
            rule2MustContainNonSilence.has_value() && rule2MustContainNonSilence.value(),
            rule2MinTrailingSilence.has_value() ? (float)rule2MinTrailingSilence.value() : 1.2f,
            rule2MinUtteranceLength.has_value() ? (float)rule2MinUtteranceLength.value() : 0.f,
            rule3MustContainNonSilence.has_value() && rule3MustContainNonSilence.value(),
            rule3MinTrailingSilence.has_value() ? (float)rule3MinTrailingSilence.value() : 0.f,
            rule3MinUtteranceLength.has_value() ? (float)rule3MinUtteranceLength.value() : 20.f
        );
        if (!result.success) {
            RCTLogError(@"[SherpaOnnx OnlineSTT] initialize failed: %s", result.error.c_str());
            // stringWithUTF8String: returns nil for bytes that are not valid
            // UTF-8; never reject with a nil message.
            NSString *message = [NSString stringWithUTF8String:result.error.c_str()]
                ?: @"Online STT initialization failed";
            reject(@"INIT_ERROR", message, nil);
            return;
        }
        g_online_stt_instances[instanceIdStr] = std::move(wrapper);
        RCTLogInfo(@"[SherpaOnnx OnlineSTT] init success for instanceId=%@", instanceId);
        resolve(@{ @"success": @YES });
    } @catch (NSException *exception) {
        NSString *errorMsg = [NSString stringWithFormat:@"Online STT init failed: %@", exception.reason];
        RCTLogError(@"%@", errorMsg);
        reject(@"INIT_ERROR", errorMsg, nil);
    }
}
131
+
132
/**
 * Creates a recognition stream named `streamId` on the instance `instanceId`
 * and records which instance owns it so later per-stream calls can find it.
 *
 * Rejects with STREAM_ERROR when the instance is unknown, when `streamId` is
 * nil or empty, or when the wrapper refuses to create the stream. `hotwords`
 * may be nil, in which case an empty string is passed to the wrapper.
 */
- (void)createSttStream:(NSString *)instanceId
               streamId:(NSString *)streamId
               hotwords:(NSString *)hotwords
                resolve:(RCTPromiseResolveBlock)resolve
                 reject:(RCTPromiseRejectBlock)reject
{
    sherpaonnx::OnlineSttWrapper* wrapper = getOnlineSttInstance(instanceId);
    if (wrapper == nullptr) {
        reject(@"STREAM_ERROR", @"Online STT instance not found", nil);
        return;
    }
    // A nil id would construct std::string from a null pointer (undefined
    // behaviour); an empty id would create a stream that no later call can
    // address, because every per-stream lookup rejects empty ids.
    if (streamId == nil || [streamId length] == 0) {
        reject(@"STREAM_ERROR", @"streamId is required", nil);
        return;
    }
    std::string instanceIdStr = [instanceId UTF8String];
    std::string streamIdStr = [streamId UTF8String];
    std::string hotwordsStr = hotwords != nil ? [hotwords UTF8String] : "";
    if (!wrapper->createStream(streamIdStr, hotwordsStr)) {
        reject(@"STREAM_ERROR", @"Stream already exists or create failed", nil);
        return;
    }
    {
        // Keep the critical section limited to the registry update.
        std::lock_guard<std::mutex> lock(g_online_stt_mutex);
        g_online_stt_stream_to_instance[streamIdStr] = instanceIdStr;
    }
    resolve(nil);
}
154
+
155
/**
 * Feeds a chunk of audio samples to the stream `streamId`.
 *
 * `samples` is converted element by element to float. Elements that are not
 * NSNumber and do not respond to doubleValue are treated as 0.0f, matching the
 * conversion used by processSttAudioChunk, instead of raising an
 * unrecognized-selector exception. `sampleRate` is truncated to int32_t.
 * Rejects with STREAM_ERROR when the stream is unknown.
 */
- (void)acceptSttWaveform:(NSString *)streamId
                  samples:(NSArray *)samples
               sampleRate:(double)sampleRate
                  resolve:(RCTPromiseResolveBlock)resolve
                   reject:(RCTPromiseRejectBlock)reject
{
    sherpaonnx::OnlineSttWrapper* wrapper = getOnlineSttInstanceForStream(streamId);
    if (wrapper == nullptr) {
        reject(@"STREAM_ERROR", @"Stream not found", nil);
        return;
    }
    std::vector<float> floatSamples;
    floatSamples.reserve([samples count]);
    for (id element in samples) {
        float value = 0.0f;
        if ([element isKindOfClass:[NSNumber class]]) {
            value = [(NSNumber *)element floatValue];
        } else if ([element respondsToSelector:@selector(doubleValue)]) {
            value = (float)[element doubleValue];
        }
        floatSamples.push_back(value);
    }
    std::string streamIdStr = [streamId UTF8String];
    wrapper->acceptWaveform(streamIdStr, (int32_t)sampleRate, floatSamples.data(), floatSamples.size());
    resolve(nil);
}
175
+
176
/**
 * Tells the wrapper that no more audio will be supplied for `streamId`.
 * Rejects with STREAM_ERROR when the stream is unknown; resolves with nil
 * otherwise.
 */
- (void)sttStreamInputFinished:(NSString *)streamId
                       resolve:(RCTPromiseResolveBlock)resolve
                        reject:(RCTPromiseRejectBlock)reject
{
    sherpaonnx::OnlineSttWrapper *owner = getOnlineSttInstanceForStream(streamId);
    if (!owner) {
        reject(@"STREAM_ERROR", @"Stream not found", nil);
        return;
    }

    // streamId is non-empty here: the lookup above returns nullptr otherwise.
    std::string streamKey([streamId UTF8String]);
    owner->inputFinished(streamKey);
    resolve(nil);
}
189
+
190
/**
 * Runs a single decode call on the stream `streamId`.
 * Rejects with STREAM_ERROR when the stream is unknown; resolves with nil
 * otherwise.
 */
- (void)decodeSttStream:(NSString *)streamId
                resolve:(RCTPromiseResolveBlock)resolve
                 reject:(RCTPromiseRejectBlock)reject
{
    sherpaonnx::OnlineSttWrapper *owner = getOnlineSttInstanceForStream(streamId);
    if (!owner) {
        reject(@"STREAM_ERROR", @"Stream not found", nil);
        return;
    }

    // streamId is non-empty here: the lookup above returns nullptr otherwise.
    std::string streamKey([streamId UTF8String]);
    owner->decode(streamKey);
    resolve(nil);
}
203
+
204
/**
 * Resolves with a boolean NSNumber holding the wrapper's isReady() answer for
 * the stream `streamId`. Rejects with STREAM_ERROR when the stream is unknown.
 */
- (void)isSttStreamReady:(NSString *)streamId
                 resolve:(RCTPromiseResolveBlock)resolve
                  reject:(RCTPromiseRejectBlock)reject
{
    sherpaonnx::OnlineSttWrapper *owner = getOnlineSttInstanceForStream(streamId);
    if (!owner) {
        reject(@"STREAM_ERROR", @"Stream not found", nil);
        return;
    }

    // streamId is non-empty here: the lookup above returns nullptr otherwise.
    std::string streamKey([streamId UTF8String]);
    resolve(owner->isReady(streamKey) ? @YES : @NO);
}
217
+
218
/**
 * Resolves with the current recognition result of the stream `streamId` as
 * `{ text, tokens, timestamps }`. Rejects with STREAM_ERROR when the stream is
 * unknown.
 *
 * Every string is converted with stringWithUTF8String:, which returns nil for
 * bytes that are not valid UTF-8. Such values are replaced by @"" so that
 * -addObject: never receives nil (which would throw) and `tokens` keeps one
 * entry per wrapper token.
 */
- (void)getSttStreamResult:(NSString *)streamId
                   resolve:(RCTPromiseResolveBlock)resolve
                    reject:(RCTPromiseRejectBlock)reject
{
    sherpaonnx::OnlineSttWrapper* wrapper = getOnlineSttInstanceForStream(streamId);
    if (wrapper == nullptr) {
        reject(@"STREAM_ERROR", @"Stream not found", nil);
        return;
    }
    std::string streamIdStr = [streamId UTF8String];
    sherpaonnx::OnlineSttStreamResult r = wrapper->getResult(streamIdStr);

    NSMutableArray* tokens = [NSMutableArray arrayWithCapacity:r.tokens.size()];
    for (const auto& t : r.tokens) {
        NSString *token = [NSString stringWithUTF8String:t.c_str()];
        [tokens addObject:(token ?: @"")];
    }
    NSMutableArray* timestamps = [NSMutableArray arrayWithCapacity:r.timestamps.size()];
    for (float ts : r.timestamps) {
        [timestamps addObject:@(ts)];
    }
    resolve(@{
        @"text": [NSString stringWithUTF8String:r.text.c_str()] ?: @"",
        @"tokens": tokens,
        @"timestamps": timestamps
    });
}
243
+
244
/**
 * Resolves with a boolean NSNumber holding the wrapper's isEndpoint() answer
 * for the stream `streamId`. Rejects with STREAM_ERROR when the stream is
 * unknown.
 */
- (void)isSttStreamEndpoint:(NSString *)streamId
                    resolve:(RCTPromiseResolveBlock)resolve
                     reject:(RCTPromiseRejectBlock)reject
{
    sherpaonnx::OnlineSttWrapper *owner = getOnlineSttInstanceForStream(streamId);
    if (!owner) {
        reject(@"STREAM_ERROR", @"Stream not found", nil);
        return;
    }

    // streamId is non-empty here: the lookup above returns nullptr otherwise.
    std::string streamKey([streamId UTF8String]);
    resolve(owner->isEndpoint(streamKey) ? @YES : @NO);
}
257
+
258
/**
 * Asks the wrapper to reset the stream `streamId`.
 * Rejects with STREAM_ERROR when the stream is unknown; resolves with nil
 * otherwise.
 */
- (void)resetSttStream:(NSString *)streamId
               resolve:(RCTPromiseResolveBlock)resolve
                reject:(RCTPromiseRejectBlock)reject
{
    sherpaonnx::OnlineSttWrapper *owner = getOnlineSttInstanceForStream(streamId);
    if (!owner) {
        reject(@"STREAM_ERROR", @"Stream not found", nil);
        return;
    }

    // streamId is non-empty here: the lookup above returns nullptr otherwise.
    std::string streamKey([streamId UTF8String]);
    owner->resetStream(streamKey);
    resolve(nil);
}
271
+
272
/**
 * Releases the stream `streamId` and forgets which instance owned it.
 *
 * Releasing is best-effort: the promise always resolves with nil, including
 * when the stream or its instance is unknown. A nil or empty `streamId` is a
 * no-op; it can never name a registered stream, and converting a nil id would
 * construct std::string from a null pointer (undefined behaviour).
 */
- (void)releaseSttStream:(NSString *)streamId
                 resolve:(RCTPromiseResolveBlock)resolve
                  reject:(RCTPromiseRejectBlock)reject
{
    if (streamId == nil || [streamId length] == 0) {
        resolve(nil);
        return;
    }
    sherpaonnx::OnlineSttWrapper* wrapper = getOnlineSttInstanceForStream(streamId);
    std::string streamIdStr = [streamId UTF8String];
    if (wrapper != nullptr) {
        wrapper->releaseStream(streamIdStr);
    }
    {
        // Drop the stream -> instance mapping even when the wrapper is gone.
        std::lock_guard<std::mutex> lock(g_online_stt_mutex);
        g_online_stt_stream_to_instance.erase(streamIdStr);
    }
    resolve(nil);
}
287
+
288
/**
 * Unloads the online STT instance `instanceId` and removes it, together with
 * every stream mapping that points at it, from the registries.
 *
 * Resolves with nil when the id is nil/empty or unknown as well as on success.
 * Rejects with RELEASE_ERROR only when an NSException is caught.
 */
- (void)unloadOnlineStt:(NSString *)instanceId
                resolve:(RCTPromiseResolveBlock)resolve
                 reject:(RCTPromiseRejectBlock)reject
{
    // Messaging nil yields 0, so this single test covers both nil and @"".
    if ([instanceId length] == 0) {
        resolve(nil);
        return;
    }
    const std::string registryKey([instanceId UTF8String]);

    @try {
        std::lock_guard<std::mutex> guard(g_online_stt_mutex);
        const auto entry = g_online_stt_instances.find(registryKey);
        if (entry != g_online_stt_instances.end()) {
            entry->second->unload();

            // Forget every stream that belonged to this instance.
            auto mapping = g_online_stt_stream_to_instance.begin();
            while (mapping != g_online_stt_stream_to_instance.end()) {
                if (mapping->second == registryKey) {
                    mapping = g_online_stt_stream_to_instance.erase(mapping);
                } else {
                    ++mapping;
                }
            }

            g_online_stt_instances.erase(entry);
        }
        resolve(nil);
    } @catch (NSException *exception) {
        NSString *failure = [NSString stringWithFormat:@"unloadOnlineStt failed: %@", exception.reason];
        reject(@"RELEASE_ERROR", failure, nil);
    }
}
313
+
314
/**
 * Convenience call that feeds one audio chunk to the stream `streamId`,
 * decodes for as long as the wrapper reports the stream ready, and resolves
 * with `{ text, tokens, timestamps, isEndpoint }`.
 *
 * Elements of `samples` that are neither NSNumber nor respond to doubleValue
 * are treated as 0.0f. `sampleRate` is truncated to int32_t. Rejects with
 * STREAM_ERROR when the stream is unknown.
 *
 * Strings are converted with stringWithUTF8String:, which returns nil for
 * bytes that are not valid UTF-8. Such values are replaced by @"" so that
 * -addObject: never receives nil (which would throw) and `tokens` keeps one
 * entry per wrapper token.
 */
- (void)processSttAudioChunk:(NSString *)streamId
                     samples:(NSArray *)samples
                  sampleRate:(double)sampleRate
                     resolve:(RCTPromiseResolveBlock)resolve
                      reject:(RCTPromiseRejectBlock)reject
{
    sherpaonnx::OnlineSttWrapper* wrapper = getOnlineSttInstanceForStream(streamId);
    if (wrapper == nullptr) {
        reject(@"STREAM_ERROR", @"Stream not found", nil);
        return;
    }
    std::string streamIdStr = [streamId UTF8String];

    std::vector<float> floatSamples;
    NSUInteger count = [samples count];
    floatSamples.reserve(count);
    for (id obj in samples) {
        float val = 0.0f;
        if ([obj isKindOfClass:[NSNumber class]]) {
            val = [(NSNumber *)obj floatValue];
        } else if ([obj respondsToSelector:@selector(doubleValue)]) {
            val = (float)[obj doubleValue];
        }
        floatSamples.push_back(val);
    }
    if (floatSamples.empty()) {
        // An empty chunk is only logged; it is still passed on with size 0.
        RCTLogWarn(@"[SherpaOnnx OnlineSTT] processSttAudioChunk: no samples (count=%lu)", (unsigned long)count);
    }

    wrapper->acceptWaveform(streamIdStr, (int32_t)sampleRate, floatSamples.data(), floatSamples.size());
    while (wrapper->isReady(streamIdStr)) {
        wrapper->decode(streamIdStr);
    }
    sherpaonnx::OnlineSttStreamResult r = wrapper->getResult(streamIdStr);
    BOOL isEndpoint = wrapper->isEndpoint(streamIdStr);

    NSMutableArray* tokens = [NSMutableArray arrayWithCapacity:r.tokens.size()];
    for (const auto& t : r.tokens) {
        NSString *token = [NSString stringWithUTF8String:t.c_str()];
        [tokens addObject:(token ?: @"")];
    }
    NSMutableArray* timestamps = [NSMutableArray arrayWithCapacity:r.timestamps.size()];
    for (float ts : r.timestamps) {
        [timestamps addObject:@(ts)];
    }
    resolve(@{
        @"text": [NSString stringWithUTF8String:r.text.c_str()] ?: @"",
        @"tokens": tokens,
        @"timestamps": timestamps,
        @"isEndpoint": @(isEndpoint)
    });
}
364
+
365
+ @end
@@ -515,10 +515,11 @@ std::vector<std::string> SplitTtsTokens(const std::string &text) {
515
515
  }
516
516
 
517
517
  - (void)generateTtsStream:(NSString *)instanceId
518
- text:(NSString *)text
518
+ requestId:(NSString *)requestId
519
+ text:(NSString *)text
519
520
  options:(NSDictionary *)options
520
- resolve:(RCTPromiseResolveBlock)resolve
521
- reject:(RCTPromiseRejectBlock)reject
521
+ resolve:(RCTPromiseResolveBlock)resolve
522
+ reject:(RCTPromiseRejectBlock)reject
522
523
  {
523
524
  if (instanceId == nil || [instanceId length] == 0) {
524
525
  reject(@"TTS_STREAM_ERROR", @"instanceId is required", nil);
@@ -551,6 +552,7 @@ std::vector<std::string> SplitTtsTokens(const std::string &text) {
551
552
  std::string textStr = [text UTF8String];
552
553
  int32_t sampleRate = instRef->wrapper->getSampleRate();
553
554
  NSString *instanceIdCopy = [instanceId copy];
555
+ NSString *requestIdCopy = (requestId != nil && [requestId length] > 0) ? [requestId copy] : nil;
554
556
 
555
557
  __weak SherpaOnnx *weakSelf = self;
556
558
  dispatch_async(dispatch_get_global_queue(QOS_CLASS_USER_INITIATED, 0), ^{
@@ -560,7 +562,7 @@ std::vector<std::string> SplitTtsTokens(const std::string &text) {
560
562
  textStr,
561
563
  static_cast<int32_t>(sid),
562
564
  static_cast<float>(speed),
563
- [weakSelf, sampleRate, instanceIdCopy, instRef](const float *samples, int32_t numSamples, float progress) -> int32_t {
565
+ [weakSelf, sampleRate, instanceIdCopy, requestIdCopy, instRef](const float *samples, int32_t numSamples, float progress) -> int32_t {
564
566
  if (instRef->streamCancelled.load()) {
565
567
  return 0;
566
568
  }
@@ -570,13 +572,14 @@ std::vector<std::string> SplitTtsTokens(const std::string &text) {
570
572
  [samplesArray addObject:@(samples[i])];
571
573
  }
572
574
 
573
- NSDictionary *payload = @{
575
+ NSMutableDictionary *payload = [NSMutableDictionary dictionaryWithDictionary:@{
574
576
  @"instanceId": instanceIdCopy,
575
577
  @"samples": samplesArray,
576
578
  @"sampleRate": @(sampleRate),
577
579
  @"progress": @(progress),
578
580
  @"isFinal": @NO
579
- };
581
+ }];
582
+ if (requestIdCopy != nil) payload[@"requestId"] = requestIdCopy;
580
583
 
581
584
  dispatch_async(dispatch_get_main_queue(), ^{
582
585
  if (weakSelf) {
@@ -589,25 +592,48 @@ std::vector<std::string> SplitTtsTokens(const std::string &text) {
589
592
  );
590
593
  } @catch (NSException *exception) {
591
594
  NSString *errorMsg = [NSString stringWithFormat:@"TTS streaming failed: %@", exception.reason];
595
+ NSMutableDictionary *errPayload = [NSMutableDictionary dictionaryWithDictionary:@{ @"instanceId": instanceIdCopy, @"message": errorMsg }];
596
+ if (requestIdCopy != nil) errPayload[@"requestId"] = requestIdCopy;
592
597
  dispatch_async(dispatch_get_main_queue(), ^{
593
598
  if (weakSelf) {
594
- [weakSelf sendEventWithName:@"ttsStreamError" body:@{ @"instanceId": instanceIdCopy, @"message": errorMsg }];
599
+ [weakSelf sendEventWithName:@"ttsStreamError" body:errPayload];
595
600
  }
596
601
  });
597
602
  }
598
603
 
599
604
  bool cancelled = instRef->streamCancelled.load();
600
605
  if (!success && !cancelled) {
606
+ NSMutableDictionary *errPayload = [NSMutableDictionary dictionaryWithDictionary:@{ @"instanceId": instanceIdCopy, @"message": @"TTS streaming generation failed" }];
607
+ if (requestIdCopy != nil) errPayload[@"requestId"] = requestIdCopy;
608
+ dispatch_async(dispatch_get_main_queue(), ^{
609
+ if (weakSelf) {
610
+ [weakSelf sendEventWithName:@"ttsStreamError" body:errPayload];
611
+ }
612
+ });
613
+ }
614
+
615
+ // Emit final chunk (empty, progress 1, isFinal YES) when not cancelled, matching Android behaviour
616
+ if (!cancelled) {
617
+ NSMutableDictionary *finalPayload = [NSMutableDictionary dictionaryWithDictionary:@{
618
+ @"instanceId": instanceIdCopy,
619
+ @"samples": @[],
620
+ @"sampleRate": @(sampleRate),
621
+ @"progress": @1.0f,
622
+ @"isFinal": @YES
623
+ }];
624
+ if (requestIdCopy != nil) finalPayload[@"requestId"] = requestIdCopy;
601
625
  dispatch_async(dispatch_get_main_queue(), ^{
602
626
  if (weakSelf) {
603
- [weakSelf sendEventWithName:@"ttsStreamError" body:@{ @"instanceId": instanceIdCopy, @"message": @"TTS streaming generation failed" }];
627
+ [weakSelf sendEventWithName:@"ttsStreamChunk" body:finalPayload];
604
628
  }
605
629
  });
606
630
  }
607
631
 
632
+ NSMutableDictionary *endPayload = [NSMutableDictionary dictionaryWithDictionary:@{ @"instanceId": instanceIdCopy, @"cancelled": @(cancelled) }];
633
+ if (requestIdCopy != nil) endPayload[@"requestId"] = requestIdCopy;
608
634
  dispatch_async(dispatch_get_main_queue(), ^{
609
635
  if (weakSelf) {
610
- [weakSelf sendEventWithName:@"ttsStreamEnd" body:@{ @"instanceId": instanceIdCopy, @"cancelled": @(cancelled) }];
636
+ [weakSelf sendEventWithName:@"ttsStreamEnd" body:endPayload];
611
637
  }
612
638
  });
613
639
 
package/ios/SherpaOnnx.mm CHANGED
@@ -22,6 +22,12 @@
22
22
  return @"SherpaOnnx";
23
23
  }
24
24
 
25
/**
 * Default initializer; forwards to the superclass's
 * -initWithDisabledObservation.
 * NOTE(review): presumably this is RCTEventEmitter's initializer that skips
 * listener-observation bookkeeping — confirm against the superclass declared
 * in SherpaOnnx.h.
 */
- (instancetype)init
{
    return [super initWithDisabledObservation];
}
30
+
25
31
  - (std::shared_ptr<facebook::react::TurboModule>)getTurboModule:
26
32
  (const facebook::react::ObjCTurboModule::InitParams &)params
27
33
  {
@@ -43,6 +43,9 @@ std::string FindLargestOnnxExcludingTokens(
43
43
  const std::vector<std::string>& excludeTokens
44
44
  );
45
45
 
46
+ /** Returns true if \p word appears in \p haystack as a standalone token (surrounded by separators: / - _ . space). */
47
+ bool ContainsWord(const std::string& haystack, const std::string& word);
48
+
46
49
  } // namespace model_detect
47
50
  } // namespace sherpaonnx
48
51
 
@@ -156,6 +156,22 @@ std::string FindFileByName(const std::string& baseDir, const std::string& fileNa
156
156
  return "";
157
157
  }
158
158
 
159
/**
 * Reports whether `word` occurs in `haystack` as a standalone token.
 *
 * An occurrence counts only when the character on each side of it is a
 * separator ('/', '-', '_', '.', ' ' or NUL) or the occurrence touches the
 * start/end of `haystack`. An empty `word` never matches.
 */
bool ContainsWord(const std::string& haystack, const std::string& word) {
    if (word.empty()) {
        return false;
    }

    const auto isBoundary = [](char ch) {
        switch (ch) {
            case '\0':
            case '/':
            case '-':
            case '_':
            case '.':
            case ' ':
                return true;
            default:
                return false;
        }
    };

    // Visit every occurrence, including overlapping ones (restart one
    // character after the previous hit).
    for (std::string::size_type hit = haystack.find(word);
         hit != std::string::npos;
         hit = haystack.find(word, hit + 1)) {
        const std::string::size_type tail = hit + word.size();
        const bool leftOk = (hit == 0) || isBoundary(haystack[hit - 1]);
        const bool rightOk = (tail >= haystack.size()) || isBoundary(haystack[tail]);
        if (leftOk && rightOk) {
            return true;
        }
    }
    return false;
}
174
+
159
175
  std::string FindDirectoryByName(const std::string& baseDir, const std::string& dirName, int maxDepth) {
160
176
  std::string target = ToLower(dirName);
161
177
  std::vector<std::string> toVisit = ListDirectories(baseDir);
@@ -32,6 +32,7 @@ SttModelKind ParseSttModelType(const std::string& modelType) {
32
32
  if (modelType == "omnilingual") return SttModelKind::kOmnilingual;
33
33
  if (modelType == "medasr") return SttModelKind::kMedAsr;
34
34
  if (modelType == "telespeech_ctc") return SttModelKind::kTeleSpeechCtc;
35
+ if (modelType == "tone_ctc") return SttModelKind::kToneCtc;
35
36
  return SttModelKind::kUnknown;
36
37
  }
37
38
 
@@ -118,6 +119,10 @@ SttDetectResult DetectSttModel(
118
119
  bool isLikelyOmnilingual = modelDirLower.find("omnilingual") != std::string::npos;
119
120
  bool isLikelyMedAsr = modelDirLower.find("medasr") != std::string::npos;
120
121
  bool isLikelyTeleSpeech = modelDirLower.find("telespeech") != std::string::npos;
122
+ // Tone CTC: match "tone" only as standalone word (not e.g. "cantonese"); also accept "t-one" / "t_one"
123
+ bool isLikelyToneCtc = modelDirLower.find("t-one") != std::string::npos ||
124
+ modelDirLower.find("t_one") != std::string::npos ||
125
+ ContainsWord(modelDirLower, "tone");
121
126
 
122
127
  bool hasMoonshine = !moonshinePreprocess.empty() && !moonshineUncachedDecode.empty() &&
123
128
  !moonshineCachedDecode.empty() && !moonshineEncode.empty();
@@ -127,6 +132,7 @@ SttDetectResult DetectSttModel(
127
132
  bool hasOmnilingual = !ctcModelPath.empty() && isLikelyOmnilingual;
128
133
  bool hasMedAsr = !ctcModelPath.empty() && isLikelyMedAsr;
129
134
  bool hasTeleSpeechCtc = (!ctcModelPath.empty() || !paraformerModelPath.empty()) && isLikelyTeleSpeech;
135
+ bool hasToneCtc = !ctcModelPath.empty() && isLikelyToneCtc;
130
136
 
131
137
  if (hasTransducer) {
132
138
  if (isLikelyNemo || isLikelyTdt) {
@@ -178,6 +184,9 @@ SttDetectResult DetectSttModel(
178
184
  if (hasTeleSpeechCtc) {
179
185
  result.detectedModels.push_back({"telespeech_ctc", modelDir});
180
186
  }
187
+ if (hasToneCtc) {
188
+ result.detectedModels.push_back({"tone_ctc", modelDir});
189
+ }
181
190
 
182
191
  SttModelKind selected = SttModelKind::kUnknown;
183
192
 
@@ -201,7 +210,8 @@ SttDetectResult DetectSttModel(
201
210
  return result;
202
211
  }
203
212
  if ((selected == SttModelKind::kNemoCtc || selected == SttModelKind::kWenetCtc ||
204
- selected == SttModelKind::kSenseVoice || selected == SttModelKind::kZipformerCtc) &&
213
+ selected == SttModelKind::kSenseVoice || selected == SttModelKind::kZipformerCtc ||
214
+ selected == SttModelKind::kToneCtc) &&
205
215
  ctcModelPath.empty()) {
206
216
  result.error = "CTC model requested but model.onnx not found in " + modelDir;
207
217
  return result;
@@ -242,6 +252,10 @@ SttDetectResult DetectSttModel(
242
252
  result.error = "TeleSpeech CTC model requested but model not found in " + modelDir;
243
253
  return result;
244
254
  }
255
+ if (selected == SttModelKind::kToneCtc && !hasToneCtc) {
256
+ result.error = "Tone CTC model requested but path does not contain 'tone' (as a word), 't-one', or 't_one' (e.g. sherpa-onnx-streaming-t-one-*) in " + modelDir;
257
+ return result;
258
+ }
245
259
  } else {
246
260
  if (hasTransducer) {
247
261
  selected = (isLikelyNemo || isLikelyTdt) ? SttModelKind::kNemoTransducer : SttModelKind::kTransducer;
@@ -279,6 +293,8 @@ SttDetectResult DetectSttModel(
279
293
  selected = SttModelKind::kMedAsr;
280
294
  } else if (hasTeleSpeechCtc) {
281
295
  selected = SttModelKind::kTeleSpeechCtc;
296
+ } else if (hasToneCtc) {
297
+ selected = SttModelKind::kToneCtc;
282
298
  } else if (!ctcModelPath.empty()) {
283
299
  selected = SttModelKind::kZipformerCtc;
284
300
  }
@@ -299,7 +315,8 @@ SttDetectResult DetectSttModel(
299
315
  } else if (selected == SttModelKind::kParaformer) {
300
316
  result.paths.paraformerModel = paraformerModelPath;
301
317
  } else if (selected == SttModelKind::kNemoCtc || selected == SttModelKind::kWenetCtc ||
302
- selected == SttModelKind::kSenseVoice || selected == SttModelKind::kZipformerCtc) {
318
+ selected == SttModelKind::kSenseVoice || selected == SttModelKind::kZipformerCtc ||
319
+ selected == SttModelKind::kToneCtc) {
303
320
  result.paths.ctcModel = ctcModelPath;
304
321
  } else if (selected == SttModelKind::kWhisper) {
305
322
  result.paths.whisperEncoder = encoderPath;
@@ -25,7 +25,8 @@ enum class SttModelKind {
25
25
  kCanary,
26
26
  kOmnilingual,
27
27
  kMedAsr,
28
- kTeleSpeechCtc
28
+ kTeleSpeechCtc,
29
+ kToneCtc
29
30
  };
30
31
 
31
32
  enum class TtsModelKind {