react-native-sherpa-onnx 0.2.0 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. package/README.md +232 -236
  2. package/SherpaOnnx.podspec +68 -64
  3. package/android/build.gradle +182 -192
  4. package/android/codegen.gradle +57 -0
  5. package/android/prebuilt-download.gradle +428 -0
  6. package/android/prebuilt-versions.gradle +43 -0
  7. package/android/proguard-rules.pro +10 -0
  8. package/android/src/main/assets/testModels/add_mul_add.onnx +28 -0
  9. package/android/src/main/assets/testModels/nnapi_internal_uint8_support.onnx +0 -0
  10. package/android/src/main/assets/testModels/qnn_multi_ctx_embed.onnx +0 -0
  11. package/android/src/main/cpp/CMakeLists.txt +166 -129
  12. package/android/src/main/cpp/CMakePresets.json +54 -0
  13. package/android/src/main/cpp/crypto/sha256.cpp +174 -0
  14. package/android/src/main/cpp/crypto/sha256.h +16 -0
  15. package/android/src/main/cpp/jni/archive/sherpa-onnx-archive-helper.cpp +404 -0
  16. package/android/src/main/cpp/jni/archive/sherpa-onnx-archive-helper.h +56 -0
  17. package/android/src/main/cpp/jni/archive/sherpa-onnx-archive-jni.cpp +181 -0
  18. package/android/src/main/cpp/jni/audio/sherpa-onnx-audio-convert-jni.cpp +888 -0
  19. package/{ios → android/src/main/cpp/jni/model_detect}/sherpa-onnx-common.h +18 -18
  20. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-detect-jni-common.cpp +86 -0
  21. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-detect-jni-common.h +20 -0
  22. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-helper.cpp +423 -0
  23. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-helper.h +55 -0
  24. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-stt.cpp +399 -0
  25. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-tts.cpp +238 -0
  26. package/{ios → android/src/main/cpp/jni/model_detect}/sherpa-onnx-model-detect.h +122 -89
  27. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-stt-wrapper.cpp +99 -0
  28. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-stt-wrapper.h +16 -0
  29. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-tts-wrapper.cpp +78 -0
  30. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-tts-wrapper.h +16 -0
  31. package/android/src/main/cpp/jni/module/sherpa-onnx-module-jni.cpp +190 -0
  32. package/android/src/main/cpp/jni/tts/sherpa-onnx-tts-zipvoice-jni.cpp +301 -0
  33. package/android/src/main/java/com/sherpaonnx/SherpaOnnxArchiveHelper.kt +94 -0
  34. package/android/src/main/java/com/sherpaonnx/{SherpaOnnxCoreHelper.kt → SherpaOnnxAssetHelper.kt} +350 -236
  35. package/android/src/main/java/com/sherpaonnx/SherpaOnnxModule.kt +791 -483
  36. package/android/src/main/java/com/sherpaonnx/SherpaOnnxSttHelper.kt +699 -109
  37. package/android/src/main/java/com/sherpaonnx/SherpaOnnxTtsHelper.kt +1123 -668
  38. package/android/src/main/java/com/sherpaonnx/ZipvoiceTtsWrapper.kt +187 -0
  39. package/ios/SherpaOnnx+Assets.h +11 -0
  40. package/ios/SherpaOnnx+Assets.mm +325 -0
  41. package/ios/SherpaOnnx+STT.mm +455 -118
  42. package/ios/SherpaOnnx+TTS.mm +1101 -712
  43. package/ios/SherpaOnnx.h +17 -6
  44. package/ios/SherpaOnnx.mm +206 -311
  45. package/ios/SherpaOnnx.xcconfig +19 -19
  46. package/ios/SherpaOnnxCoreMLHelper.swift +24 -0
  47. package/ios/archive/sherpa-onnx-archive-helper.h +21 -0
  48. package/ios/archive/sherpa-onnx-archive-helper.mm +296 -0
  49. package/ios/libarchive_darwin_config.h +153 -0
  50. package/{android/src/main/cpp/jni → ios/model_detect}/sherpa-onnx-common.h +18 -18
  51. package/ios/model_detect/sherpa-onnx-model-detect-helper.h +49 -0
  52. package/ios/model_detect/sherpa-onnx-model-detect-helper.mm +210 -0
  53. package/ios/model_detect/sherpa-onnx-model-detect-stt.mm +344 -0
  54. package/ios/model_detect/sherpa-onnx-model-detect-tts.mm +201 -0
  55. package/{android/src/main/cpp/jni → ios/model_detect}/sherpa-onnx-model-detect.h +117 -89
  56. package/ios/scripts/patch-libarchive-includes.sh +61 -0
  57. package/ios/scripts/setup-ios-libarchive.sh +98 -0
  58. package/ios/stt/sherpa-onnx-stt-wrapper.h +129 -0
  59. package/ios/stt/sherpa-onnx-stt-wrapper.mm +523 -0
  60. package/ios/{sherpa-onnx-tts-wrapper.h → tts/sherpa-onnx-tts-wrapper.h} +90 -85
  61. package/ios/{sherpa-onnx-tts-wrapper.mm → tts/sherpa-onnx-tts-wrapper.mm} +376 -345
  62. package/lib/module/NativeSherpaOnnx.js +3 -0
  63. package/lib/module/NativeSherpaOnnx.js.map +1 -1
  64. package/lib/module/audio/index.js +22 -0
  65. package/lib/module/audio/index.js.map +1 -0
  66. package/lib/module/diarization/index.js +1 -1
  67. package/lib/module/diarization/index.js.map +1 -1
  68. package/lib/module/download/ModelDownloadManager.js +918 -0
  69. package/lib/module/download/ModelDownloadManager.js.map +1 -0
  70. package/lib/module/download/extractTarBz2.js +53 -0
  71. package/lib/module/download/extractTarBz2.js.map +1 -0
  72. package/lib/module/download/index.js +6 -0
  73. package/lib/module/download/index.js.map +1 -0
  74. package/lib/module/download/validation.js +178 -0
  75. package/lib/module/download/validation.js.map +1 -0
  76. package/lib/module/enhancement/index.js +1 -1
  77. package/lib/module/enhancement/index.js.map +1 -1
  78. package/lib/module/index.js +41 -3
  79. package/lib/module/index.js.map +1 -1
  80. package/lib/module/separation/index.js +1 -1
  81. package/lib/module/separation/index.js.map +1 -1
  82. package/lib/module/stt/index.js +127 -60
  83. package/lib/module/stt/index.js.map +1 -1
  84. package/lib/module/stt/sttModelLanguages.js +512 -0
  85. package/lib/module/stt/sttModelLanguages.js.map +1 -0
  86. package/lib/module/stt/types.js +53 -1
  87. package/lib/module/stt/types.js.map +1 -1
  88. package/lib/module/tts/index.js +216 -289
  89. package/lib/module/tts/index.js.map +1 -1
  90. package/lib/module/tts/types.js +86 -1
  91. package/lib/module/tts/types.js.map +1 -1
  92. package/lib/module/types.js.map +1 -1
  93. package/lib/module/utils.js +86 -73
  94. package/lib/module/utils.js.map +1 -1
  95. package/lib/module/vad/index.js +1 -1
  96. package/lib/module/vad/index.js.map +1 -1
  97. package/lib/typescript/src/NativeSherpaOnnx.d.ts +192 -38
  98. package/lib/typescript/src/NativeSherpaOnnx.d.ts.map +1 -1
  99. package/lib/typescript/src/audio/index.d.ts +13 -0
  100. package/lib/typescript/src/audio/index.d.ts.map +1 -0
  101. package/lib/typescript/src/diarization/index.d.ts +3 -2
  102. package/lib/typescript/src/diarization/index.d.ts.map +1 -1
  103. package/lib/typescript/src/download/ModelDownloadManager.d.ts +108 -0
  104. package/lib/typescript/src/download/ModelDownloadManager.d.ts.map +1 -0
  105. package/lib/typescript/src/download/extractTarBz2.d.ts +14 -0
  106. package/lib/typescript/src/download/extractTarBz2.d.ts.map +1 -0
  107. package/lib/typescript/src/download/index.d.ts +7 -0
  108. package/lib/typescript/src/download/index.d.ts.map +1 -0
  109. package/lib/typescript/src/download/validation.d.ts +57 -0
  110. package/lib/typescript/src/download/validation.d.ts.map +1 -0
  111. package/lib/typescript/src/enhancement/index.d.ts +3 -2
  112. package/lib/typescript/src/enhancement/index.d.ts.map +1 -1
  113. package/lib/typescript/src/index.d.ts +26 -2
  114. package/lib/typescript/src/index.d.ts.map +1 -1
  115. package/lib/typescript/src/separation/index.d.ts +3 -2
  116. package/lib/typescript/src/separation/index.d.ts.map +1 -1
  117. package/lib/typescript/src/stt/index.d.ts +31 -43
  118. package/lib/typescript/src/stt/index.d.ts.map +1 -1
  119. package/lib/typescript/src/stt/sttModelLanguages.d.ts +52 -0
  120. package/lib/typescript/src/stt/sttModelLanguages.d.ts.map +1 -0
  121. package/lib/typescript/src/stt/types.d.ts +196 -9
  122. package/lib/typescript/src/stt/types.d.ts.map +1 -1
  123. package/lib/typescript/src/tts/index.d.ts +25 -211
  124. package/lib/typescript/src/tts/index.d.ts.map +1 -1
  125. package/lib/typescript/src/tts/types.d.ts +148 -25
  126. package/lib/typescript/src/tts/types.d.ts.map +1 -1
  127. package/lib/typescript/src/types.d.ts +0 -32
  128. package/lib/typescript/src/types.d.ts.map +1 -1
  129. package/lib/typescript/src/utils.d.ts +28 -13
  130. package/lib/typescript/src/utils.d.ts.map +1 -1
  131. package/lib/typescript/src/vad/index.d.ts +3 -2
  132. package/lib/typescript/src/vad/index.d.ts.map +1 -1
  133. package/package.json +250 -222
  134. package/scripts/check-qnn-support.sh +78 -0
  135. package/scripts/setup-ios-framework.sh +379 -282
  136. package/src/NativeSherpaOnnx.ts +474 -251
  137. package/src/audio/index.ts +32 -0
  138. package/src/diarization/index.ts +4 -2
  139. package/src/download/ModelDownloadManager.ts +1325 -0
  140. package/src/download/extractTarBz2.ts +78 -0
  141. package/src/download/index.ts +43 -0
  142. package/src/download/validation.ts +279 -0
  143. package/src/enhancement/index.ts +4 -2
  144. package/src/index.tsx +78 -27
  145. package/src/separation/index.ts +4 -2
  146. package/src/stt/index.ts +249 -89
  147. package/src/stt/sttModelLanguages.ts +237 -0
  148. package/src/stt/types.ts +263 -9
  149. package/src/tts/index.ts +470 -458
  150. package/src/tts/types.ts +373 -218
  151. package/src/types.ts +0 -44
  152. package/src/utils.ts +145 -131
  153. package/src/vad/index.ts +4 -2
  154. package/third_party/ffmpeg_prebuilt/ANDROID_RELEASE_TAG +1 -0
  155. package/third_party/libarchive_prebuilt/ANDROID_RELEASE_TAG +1 -0
  156. package/third_party/libarchive_prebuilt/IOS_RELEASE_TAG +1 -0
  157. package/third_party/sherpa-onnx-prebuilt/ANDROID_RELEASE_TAG +1 -0
  158. package/third_party/sherpa-onnx-prebuilt/IOS_RELEASE_TAG +1 -0
  159. package/android/src/main/cpp/include/sherpa-onnx/c-api/c-api.h +0 -1918
  160. package/android/src/main/cpp/include/sherpa-onnx/c-api/cxx-api.h +0 -841
  161. package/android/src/main/cpp/jni/sherpa-onnx-model-detect.cpp +0 -541
  162. package/android/src/main/cpp/jni/sherpa-onnx-stt-jni.cpp +0 -336
  163. package/android/src/main/cpp/jni/sherpa-onnx-stt-wrapper.cpp +0 -222
  164. package/android/src/main/cpp/jni/sherpa-onnx-stt-wrapper.h +0 -68
  165. package/android/src/main/cpp/jni/sherpa-onnx-tts-jni.cpp +0 -823
  166. package/android/src/main/cpp/jni/sherpa-onnx-tts-wrapper.cpp +0 -387
  167. package/android/src/main/cpp/jni/sherpa-onnx-tts-wrapper.h +0 -147
  168. package/ios/Frameworks/sherpa_onnx.xcframework.zip +0 -0
  169. package/ios/include/sherpa-onnx/c-api/c-api.h +0 -1918
  170. package/ios/include/sherpa-onnx/c-api/cxx-api.h +0 -841
  171. package/ios/sherpa-onnx-model-detect.mm +0 -441
  172. package/ios/sherpa-onnx-stt-wrapper.h +0 -48
  173. package/ios/sherpa-onnx-stt-wrapper.mm +0 -201
  174. package/scripts/copy-headers.js +0 -184
  175. package/scripts/setup-assets.js +0 -323
@@ -1,712 +1,1101 @@
1
- #import "SherpaOnnx.h"
2
- #import <React/RCTLog.h>
3
- #import <React/RCTUtils.h>
4
- #import <UIKit/UIKit.h>
5
- #import <AVFoundation/AVFoundation.h>
6
-
7
- #include "sherpa-onnx-tts-wrapper.h"
8
- #include <atomic>
9
- #include <memory>
10
- #include <sstream>
11
- #include <string>
12
- #include <vector>
13
-
14
// ---------------------------------------------------------------------------
// Module-wide TTS state shared by every method of the SherpaOnnx (TTS) category.
// ---------------------------------------------------------------------------

// Engine wrapper: created lazily by initializeTts, released by unloadTts.
static std::unique_ptr<sherpaonnx::TtsWrapper> g_tts_wrapper;

// Streaming flags. They are read and written both by the module methods and by
// the background block dispatched in generateTtsStream, hence std::atomic.
static std::atomic<bool> g_tts_stream_running(false);
static std::atomic<bool> g_tts_stream_cancelled(false);

// PCM playback graph: built in startTtsPcmPlayer, torn down in stopTtsPcmPlayer.
static AVAudioEngine *g_tts_engine = nil;
static AVAudioPlayerNode *g_tts_player = nil;
static AVAudioFormat *g_tts_format = nil;

// Arguments of the last successful initializeTts / updateTtsParams call; they
// are replayed when updateTtsParams re-initializes the wrapper.
static NSString *g_tts_model_dir = nil;
static NSString *g_tts_model_type = nil;
static int32_t g_tts_num_threads = 2;
static BOOL g_tts_debug = NO;
static NSNumber *g_tts_noise_scale = nil;
static NSNumber *g_tts_noise_scale_w = nil;
static NSNumber *g_tts_length_scale = nil;
28
-
29
namespace {
// Splits `text` into whitespace-separated words using stream extraction.
// If `text` is non-empty but contains no word at all (whitespace only), the
// whole string is returned as a single token, so non-empty input never yields
// an empty vector.
std::vector<std::string> SplitTtsTokens(const std::string &text) {
  std::vector<std::string> pieces;
  std::istringstream stream(text);
  for (std::string word; stream >> word;) {
    pieces.push_back(word);
  }

  const bool onlyWhitespace = pieces.empty() && !text.empty();
  if (onlyWhitespace) {
    pieces.push_back(text);
  }
  return pieces;
}
}  // namespace
43
-
44
- @implementation SherpaOnnx (TTS)
45
-
46
/// Creates the global TTS wrapper on first use and (re)initializes it.
///
/// @param modelDir    Model directory forwarded to TtsWrapper::initialize; must be a non-nil string.
/// @param modelType   Model type forwarded to TtsWrapper::initialize; must be a non-nil string.
/// @param numThreads  Inference thread count; truncated to int32_t.
/// @param debug       Forwarded as the wrapper's debug flag.
/// @param noiseScale  Optional override; nil passes std::nullopt.
/// @param noiseScaleW Optional override; nil passes std::nullopt.
/// @param lengthScale Optional override; nil passes std::nullopt.
///
/// On success the arguments are remembered (updateTtsParams replays them) and the
/// promise resolves with @{ @"success": @YES, @"detectedModels": @[ @{ @"type", @"modelDir" } ] }.
/// Rejects with TTS_INIT_ERROR on invalid strings, while a stream is running,
/// when the wrapper reports failure, or when an NSException is raised.
- (void)initializeTts:(NSString *)modelDir
            modelType:(NSString *)modelType
           numThreads:(double)numThreads
                debug:(BOOL)debug
           noiseScale:(NSNumber *)noiseScale
          noiseScaleW:(NSNumber *)noiseScaleW
          lengthScale:(NSNumber *)lengthScale
         withResolver:(RCTPromiseResolveBlock)resolve
         withRejecter:(RCTPromiseRejectBlock)reject
{
  RCTLogInfo(@"Initializing TTS with modelDir: %@, modelType: %@", modelDir, modelType);

  // generateTtsStream's background block calls into g_tts_wrapper; calling
  // initialize() on the same object concurrently would race with it
  // (updateTtsParams refuses for the same reason).
  if (g_tts_stream_running.load()) {
    reject(@"TTS_INIT_ERROR", @"Cannot initialize TTS while streaming", nil);
    return;
  }

  // -UTF8String returns NULL for nil (and possibly unencodable) strings, and
  // constructing std::string from NULL is undefined behavior.
  const char *modelDirUtf8 = [modelDir UTF8String];
  const char *modelTypeUtf8 = [modelType UTF8String];
  if (modelDirUtf8 == NULL || modelTypeUtf8 == NULL) {
    NSString *errorMsg = @"modelDir and modelType must be non-null strings";
    RCTLogError(@"%@", errorMsg);
    reject(@"TTS_INIT_ERROR", errorMsg, nil);
    return;
  }

  @try {
    if (g_tts_wrapper == nullptr) {
      g_tts_wrapper = std::make_unique<sherpaonnx::TtsWrapper>();
    }

    std::string modelDirStr(modelDirUtf8);
    std::string modelTypeStr(modelTypeUtf8);

    std::optional<float> noiseScaleOpt = std::nullopt;
    std::optional<float> noiseScaleWOpt = std::nullopt;
    std::optional<float> lengthScaleOpt = std::nullopt;
    if (noiseScale != nil) {
      noiseScaleOpt = [noiseScale floatValue];
    }
    if (noiseScaleW != nil) {
      noiseScaleWOpt = [noiseScaleW floatValue];
    }
    if (lengthScale != nil) {
      lengthScaleOpt = [lengthScale floatValue];
    }

    sherpaonnx::TtsInitializeResult result = g_tts_wrapper->initialize(
        modelDirStr,
        modelTypeStr,
        static_cast<int32_t>(numThreads),
        debug,
        noiseScaleOpt,
        noiseScaleWOpt,
        lengthScaleOpt);

    if (result.success) {
      RCTLogInfo(@"TTS initialization successful");

      // Remember the configuration so updateTtsParams can re-initialize with it.
      g_tts_model_dir = [modelDir copy];
      g_tts_model_type = [modelType copy];
      g_tts_num_threads = static_cast<int32_t>(numThreads);
      g_tts_debug = debug;
      g_tts_noise_scale = noiseScale ? [noiseScale copy] : nil;
      g_tts_noise_scale_w = noiseScaleW ? [noiseScaleW copy] : nil;
      g_tts_length_scale = lengthScale ? [lengthScale copy] : nil;

      NSMutableArray *detectedModelsArray = [NSMutableArray array];
      for (const auto &model : result.detectedModels) {
        NSDictionary *modelDict = @{
          @"type": [NSString stringWithUTF8String:model.type.c_str()],
          @"modelDir": [NSString stringWithUTF8String:model.modelDir.c_str()]
        };
        [detectedModelsArray addObject:modelDict];
      }

      NSDictionary *resultDict = @{
        @"success": @YES,
        @"detectedModels": detectedModelsArray
      };

      resolve(resultDict);
    } else {
      NSString *errorMsg = @"Failed to initialize TTS";
      RCTLogError(@"%@", errorMsg);
      reject(@"TTS_INIT_ERROR", errorMsg, nil);
    }
  } @catch (NSException *exception) {
    NSString *errorMsg = [NSString stringWithFormat:@"Exception during TTS init: %@", exception.reason];
    RCTLogError(@"%@", errorMsg);
    reject(@"TTS_INIT_ERROR", errorMsg, nil);
  }
}
126
-
127
/// Re-initializes the already-loaded model with new synthesis parameters.
///
/// Each of noiseScale, noiseScaleW and lengthScale is interpreted as:
///   nil   -> drop the override (std::nullopt is passed to the wrapper);
///   NaN   -> keep the value remembered from the previous init/update;
///   other -> use this value.
/// The model directory, model type, thread count and debug flag remembered by
/// initializeTts are reused. Rejects with TTS_UPDATE_ERROR while a stream is
/// running, before initialization, when re-initialization fails, or when an
/// NSException is raised.
- (void)updateTtsParams:(NSNumber *)noiseScale
            noiseScaleW:(NSNumber *)noiseScaleW
            lengthScale:(NSNumber *)lengthScale
           withResolver:(RCTPromiseResolveBlock)resolve
           withRejecter:(RCTPromiseRejectBlock)reject
{
  if (g_tts_stream_running.load()) {
    reject(@"TTS_UPDATE_ERROR", @"Cannot update params while streaming", nil);
    return;
  }

  const bool haveModel =
      g_tts_wrapper != nullptr && g_tts_model_dir != nil && g_tts_model_type != nil;
  if (!haveModel) {
    reject(@"TTS_UPDATE_ERROR", @"TTS not initialized", nil);
    return;
  }

  // Applies the nil / NaN / value rule documented above.
  NSNumber *(^choose)(NSNumber *, NSNumber *) = ^NSNumber *(NSNumber *requested, NSNumber *previous) {
    if (requested == nil) {
      return nil;
    }
    return isnan([requested doubleValue]) ? previous : requested;
  };

  NSNumber *nextNoiseScale = choose(noiseScale, g_tts_noise_scale);
  NSNumber *nextNoiseScaleW = choose(noiseScaleW, g_tts_noise_scale_w);
  NSNumber *nextLengthScale = choose(lengthScale, g_tts_length_scale);

  // Maps an optional NSNumber onto the wrapper's std::optional<float>.
  auto asOptional = [](NSNumber *value) -> std::optional<float> {
    if (value == nil) {
      return std::nullopt;
    }
    return [value floatValue];
  };

  @try {
    sherpaonnx::TtsInitializeResult result = g_tts_wrapper->initialize(
        std::string([g_tts_model_dir UTF8String]),
        std::string([g_tts_model_type UTF8String]),
        g_tts_num_threads,
        g_tts_debug,
        asOptional(nextNoiseScale),
        asOptional(nextNoiseScaleW),
        asOptional(nextLengthScale));

    if (!result.success) {
      NSString *errorMsg = @"Failed to update TTS params";
      RCTLogError(@"%@", errorMsg);
      reject(@"TTS_UPDATE_ERROR", errorMsg, nil);
      return;
    }

    // Remember what was applied so a later NaN ("keep") resolves to it.
    g_tts_noise_scale = [nextNoiseScale copy];
    g_tts_noise_scale_w = [nextNoiseScaleW copy];
    g_tts_length_scale = [nextLengthScale copy];

    NSMutableArray *models = [NSMutableArray arrayWithCapacity:result.detectedModels.size()];
    for (const auto &detected : result.detectedModels) {
      [models addObject:@{
        @"type": [NSString stringWithUTF8String:detected.type.c_str()],
        @"modelDir": [NSString stringWithUTF8String:detected.modelDir.c_str()]
      }];
    }

    resolve(@{ @"success": @YES, @"detectedModels": models });
  } @catch (NSException *exception) {
    NSString *errorMsg = [NSString stringWithFormat:@"Exception during TTS update: %@", exception.reason];
    RCTLogError(@"%@", errorMsg);
    reject(@"TTS_UPDATE_ERROR", errorMsg, nil);
  }
}
226
-
227
/// Synthesizes `text` in one shot with speaker `sid` at `speed`.
///
/// Resolves with @{ @"samples": NSArray<NSNumber *> of float samples,
/// @"sampleRate": NSNumber }. Rejects with TTS_NOT_INITIALIZED when no model is
/// loaded, or with TTS_GENERATE_ERROR when the text cannot be converted to
/// UTF-8, the wrapper returns no audio, or an NSException is raised.
- (void)generateTts:(NSString *)text
                sid:(double)sid
              speed:(double)speed
       withResolver:(RCTPromiseResolveBlock)resolve
       withRejecter:(RCTPromiseRejectBlock)reject
{
  @try {
    if (g_tts_wrapper == nullptr || !g_tts_wrapper->isInitialized()) {
      NSString *errorMsg = @"TTS not initialized. Call initializeTts() first.";
      RCTLogError(@"%@", errorMsg);
      reject(@"TTS_NOT_INITIALIZED", errorMsg, nil);
      return;
    }

    // -UTF8String may return NULL (nil text, or text that cannot be encoded);
    // constructing std::string from NULL is undefined behavior.
    const char *textUtf8 = [text UTF8String];
    if (textUtf8 == NULL) {
      NSString *errorMsg = @"Text is missing or cannot be encoded as UTF-8";
      RCTLogError(@"%@", errorMsg);
      reject(@"TTS_GENERATE_ERROR", errorMsg, nil);
      return;
    }
    std::string textStr(textUtf8);

    auto result = g_tts_wrapper->generate(
        textStr,
        static_cast<int32_t>(sid),
        static_cast<float>(speed));

    if (result.samples.empty() || result.sampleRate == 0) {
      NSString *errorMsg = @"Failed to generate speech or result is empty";
      RCTLogError(@"%@", errorMsg);
      reject(@"TTS_GENERATE_ERROR", errorMsg, nil);
      return;
    }

    NSMutableArray *samplesArray = [NSMutableArray arrayWithCapacity:result.samples.size()];
    for (float sample : result.samples) {
      [samplesArray addObject:@(sample)];
    }

    NSDictionary *resultDict = @{
      @"samples": samplesArray,
      @"sampleRate": @(result.sampleRate)
    };

    RCTLogInfo(@"TTS: Generated %lu samples at %d Hz",
               (unsigned long)result.samples.size(), result.sampleRate);

    resolve(resultDict);
  } @catch (NSException *exception) {
    NSString *errorMsg = [NSString stringWithFormat:@"Exception during TTS generation: %@", exception.reason];
    RCTLogError(@"%@", errorMsg);
    reject(@"TTS_GENERATE_ERROR", errorMsg, nil);
  }
}
276
-
277
/// Synthesizes `text` like generateTts and adds estimated word subtitles.
///
/// Subtitles are not aligned to the audio: the total duration
/// (samples / sampleRate) is divided evenly across the whitespace-separated
/// tokens produced by SplitTtsTokens. That is why the result carries
/// @"estimated": @YES. Resolves with @{ samples, sampleRate,
/// subtitles: [{ text, start, end }], estimated }. Rejects with
/// TTS_NOT_INITIALIZED or TTS_GENERATE_ERROR.
- (void)generateTtsWithTimestamps:(NSString *)text
                              sid:(double)sid
                            speed:(double)speed
                     withResolver:(RCTPromiseResolveBlock)resolve
                     withRejecter:(RCTPromiseRejectBlock)reject
{
  @try {
    if (g_tts_wrapper == nullptr || !g_tts_wrapper->isInitialized()) {
      NSString *errorMsg = @"TTS not initialized. Call initializeTts() first.";
      RCTLogError(@"%@", errorMsg);
      reject(@"TTS_NOT_INITIALIZED", errorMsg, nil);
      return;
    }

    // -UTF8String may return NULL (nil text, or text that cannot be encoded);
    // constructing std::string from NULL is undefined behavior.
    const char *textUtf8 = [text UTF8String];
    if (textUtf8 == NULL) {
      NSString *errorMsg = @"Text is missing or cannot be encoded as UTF-8";
      RCTLogError(@"%@", errorMsg);
      reject(@"TTS_GENERATE_ERROR", errorMsg, nil);
      return;
    }
    std::string textStr(textUtf8);

    auto result = g_tts_wrapper->generate(
        textStr,
        static_cast<int32_t>(sid),
        static_cast<float>(speed));

    if (result.samples.empty() || result.sampleRate == 0) {
      NSString *errorMsg = @"Failed to generate speech or result is empty";
      RCTLogError(@"%@", errorMsg);
      reject(@"TTS_GENERATE_ERROR", errorMsg, nil);
      return;
    }

    NSMutableArray *samplesArray = [NSMutableArray arrayWithCapacity:result.samples.size()];
    for (float sample : result.samples) {
      [samplesArray addObject:@(sample)];
    }

    std::vector<std::string> tokens = SplitTtsTokens(textStr);
    NSMutableArray *subtitlesArray = [NSMutableArray array];
    if (!tokens.empty()) {
      // Even split of the audio duration over the tokens (an estimate only).
      double totalSeconds = static_cast<double>(result.samples.size()) /
                            static_cast<double>(result.sampleRate);
      double perToken = totalSeconds / static_cast<double>(tokens.size());

      for (size_t i = 0; i < tokens.size(); ++i) {
        double start = perToken * static_cast<double>(i);
        double end = perToken * static_cast<double>(i + 1);
        NSDictionary *item = @{
          @"text": [NSString stringWithUTF8String:tokens[i].c_str()],
          @"start": @(start),
          @"end": @(end)
        };
        [subtitlesArray addObject:item];
      }
    }

    NSDictionary *resultDict = @{
      @"samples": samplesArray,
      @"sampleRate": @(result.sampleRate),
      @"subtitles": subtitlesArray,
      @"estimated": @YES
    };

    resolve(resultDict);
  } @catch (NSException *exception) {
    NSString *errorMsg = [NSString stringWithFormat:@"Exception during TTS generation: %@", exception.reason];
    RCTLogError(@"%@", errorMsg);
    reject(@"TTS_GENERATE_ERROR", errorMsg, nil);
  }
}
344
-
345
/// Starts streaming synthesis of `text` on a background queue and resolves
/// immediately. Audio and status are delivered through events posted from the
/// main queue:
///   - "ttsStreamChunk": { samples, sampleRate, progress, isFinal: NO }
///   - "ttsStreamError": { message } (at most once per stream)
///   - "ttsStreamEnd":   { cancelled }
/// Only one stream may run at a time; cancelTtsStream() stops it early.
/// Rejects with TTS_STREAM_ERROR (already streaming / unusable text) or
/// TTS_NOT_INITIALIZED.
- (void)generateTtsStream:(NSString *)text
                      sid:(double)sid
                    speed:(double)speed
             withResolver:(RCTPromiseResolveBlock)resolve
             withRejecter:(RCTPromiseRejectBlock)reject
{
  // Claim the single streaming slot atomically. A separate load() followed by
  // store(true) would let two concurrent callers both start a stream.
  bool expected = false;
  if (!g_tts_stream_running.compare_exchange_strong(expected, true)) {
    reject(@"TTS_STREAM_ERROR", @"TTS streaming already in progress", nil);
    return;
  }

  if (g_tts_wrapper == nullptr || !g_tts_wrapper->isInitialized()) {
    g_tts_stream_running.store(false);
    reject(@"TTS_NOT_INITIALIZED", @"TTS not initialized. Call initializeTts() first.", nil);
    return;
  }

  // -UTF8String may return NULL (nil or unencodable text); std::string(NULL) is UB.
  const char *textUtf8 = [text UTF8String];
  if (textUtf8 == NULL) {
    g_tts_stream_running.store(false);
    reject(@"TTS_STREAM_ERROR", @"Text is missing or cannot be encoded as UTF-8", nil);
    return;
  }

  g_tts_stream_cancelled.store(false);

  std::string textStr(textUtf8);
  int32_t sampleRate = g_tts_wrapper->getSampleRate();

  __weak SherpaOnnx *weakSelf = self;
  // Posts an event from the main queue; messaging a nil weakSelf is a no-op.
  void (^emit)(NSString *, NSDictionary *) = ^(NSString *name, NSDictionary *body) {
    dispatch_async(dispatch_get_main_queue(), ^{
      [weakSelf sendEventWithName:name body:body];
    });
  };

  dispatch_async(dispatch_get_global_queue(QOS_CLASS_USER_INITIATED, 0), ^{
    bool success = false;
    bool errorReported = false;
    @try {
      success = g_tts_wrapper->generateStream(
          textStr,
          static_cast<int32_t>(sid),
          static_cast<float>(speed),
          [emit, sampleRate](const float *samples, int32_t numSamples, float progress) -> int32_t {
            if (g_tts_stream_cancelled.load()) {
              return 0;
            }

            NSMutableArray *samplesArray = [NSMutableArray arrayWithCapacity:numSamples];
            for (int32_t i = 0; i < numSamples; i++) {
              [samplesArray addObject:@(samples[i])];
            }

            emit(@"ttsStreamChunk", @{
              @"samples": samplesArray,
              @"sampleRate": @(sampleRate),
              @"progress": @(progress),
              @"isFinal": @NO
            });

            // 0 once cancellation was requested (presumably tells the wrapper to stop), else 1.
            return g_tts_stream_cancelled.load() ? 0 : 1;
          });
    } @catch (NSException *exception) {
      errorReported = true;
      emit(@"ttsStreamError", @{
        @"message": [NSString stringWithFormat:@"TTS streaming failed: %@", exception.reason]
      });
    }

    const bool cancelled = g_tts_stream_cancelled.load();
    // Report a generic failure only if the exception path has not already reported one.
    if (!success && !cancelled && !errorReported) {
      emit(@"ttsStreamError", @{ @"message": @"TTS streaming generation failed" });
    }

    // Release the slot before announcing the end, so a ttsStreamEnd handler that
    // immediately starts another stream is not rejected as "already in progress".
    g_tts_stream_running.store(false);
    emit(@"ttsStreamEnd", @{ @"cancelled": @(cancelled) });
  });

  resolve(nil);
}
430
-
431
/// Requests cancellation of the running stream, if any.
///
/// Only sets g_tts_stream_cancelled; the chunk callback in generateTtsStream
/// polls that flag. The flag is cleared again when the next stream starts.
/// Resolves with nil.
- (void)cancelTtsStream:(RCTPromiseResolveBlock)resolve
           withRejecter:(RCTPromiseRejectBlock)reject
{
  NSString *failure = nil;
  @try {
    g_tts_stream_cancelled.store(true);
    resolve(nil);
  } @catch (NSException *exception) {
    failure = [NSString stringWithFormat:@"Failed to cancel TTS stream: %@", exception.reason];
  }

  if (failure != nil) {
    reject(@"TTS_STREAM_ERROR", failure, nil);
  }
}
442
-
443
/// Builds a fresh AVAudioEngine -> AVAudioPlayerNode graph for mono float
/// samples at `sampleRate` and starts playback, replacing any previous graph.
///
/// All work runs on the main queue. The new engine, player and format are
/// published to the globals only after the engine has started, so
/// writeTtsPcmChunk never observes a half-built graph. Rejects with
/// TTS_PCM_ERROR for non-mono input, a non-positive sample rate, an
/// unsupported format, an engine start failure, or an NSException.
- (void)startTtsPcmPlayer:(double)sampleRate
                 channels:(double)channels
             withResolver:(RCTPromiseResolveBlock)resolve
             withRejecter:(RCTPromiseRejectBlock)reject
{
  dispatch_async(dispatch_get_main_queue(), ^{
    // Tears down the current graph synchronously. Calling stopTtsPcmPlayer here
    // does not work: it dispatch_asyncs its teardown onto this same main queue,
    // so that teardown would run after this block and destroy the new player.
    void (^teardown)(void) = ^{
      [g_tts_player stop];
      [g_tts_engine stop];
      [g_tts_engine reset];
      g_tts_player = nil;
      g_tts_engine = nil;
      g_tts_format = nil;
    };

    AVAudioEngine *engine = nil;
    AVAudioPlayerNode *player = nil;
    @try {
      if (channels != 1.0) {
        reject(@"TTS_PCM_ERROR", @"PCM playback supports mono only", nil);
        return;
      }
      if (!(sampleRate > 0)) {
        reject(@"TTS_PCM_ERROR", @"Sample rate must be positive", nil);
        return;
      }

      teardown();

      // Session configuration is best-effort (as before), but failures are logged.
      AVAudioSession *session = [AVAudioSession sharedInstance];
      NSError *sessionError = nil;
      if (![session setCategory:AVAudioSessionCategoryPlayback error:&sessionError]) {
        RCTLogWarn(@"TTS PCM: failed to set audio session category: %@", sessionError.localizedDescription);
      }
      sessionError = nil;
      if (![session setActive:YES error:&sessionError]) {
        RCTLogWarn(@"TTS PCM: failed to activate audio session: %@", sessionError.localizedDescription);
      }

      AVAudioFormat *format = [[AVAudioFormat alloc] initStandardFormatWithSampleRate:sampleRate channels:1];
      if (format == nil) {
        reject(@"TTS_PCM_ERROR", @"Unsupported PCM format", nil);
        return;
      }

      engine = [[AVAudioEngine alloc] init];
      player = [[AVAudioPlayerNode alloc] init];
      [engine attachNode:player];
      [engine connect:player to:engine.mainMixerNode format:format];

      NSError *startError = nil;
      if (![engine startAndReturnError:&startError]) {
        NSString *errorMsg = [NSString stringWithFormat:@"Failed to start audio engine: %@", startError.localizedDescription];
        reject(@"TTS_PCM_ERROR", errorMsg, startError);
        return;
      }

      [player play];

      g_tts_engine = engine;
      g_tts_player = player;
      g_tts_format = format;
      resolve(nil);
    } @catch (NSException *exception) {
      [player stop];
      [engine stop];
      teardown();
      NSString *errorMsg = [NSString stringWithFormat:@"Failed to start PCM player: %@", exception.reason];
      reject(@"TTS_PCM_ERROR", errorMsg, nil);
    }
  });
}
484
-
485
/// Copies `samples` (mono float values) into a new PCM buffer and schedules it
/// on the player created by startTtsPcmPlayer.
///
/// An empty chunk is a no-op and resolves immediately. Rejects with
/// TTS_PCM_ERROR when no player is running, the buffer cannot be allocated,
/// or an NSException is raised.
- (void)writeTtsPcmChunk:(NSArray<NSNumber *> *)samples
            withResolver:(RCTPromiseResolveBlock)resolve
            withRejecter:(RCTPromiseRejectBlock)reject
{
  @try {
    // Read each global once. startTtsPcmPlayer / stopTtsPcmPlayer replace them
    // on the main queue; this narrows (but does not eliminate) the window in
    // which they change underneath this method.
    AVAudioEngine *engine = g_tts_engine;
    AVAudioPlayerNode *player = g_tts_player;
    AVAudioFormat *format = g_tts_format;
    if (engine == nil || player == nil || format == nil) {
      reject(@"TTS_PCM_ERROR", @"PCM player not initialized", nil);
      return;
    }

    const NSUInteger count = [samples count];
    if (count == 0) {
      resolve(nil);
      return;
    }

    AVAudioFrameCount frameCount = (AVAudioFrameCount)count;
    AVAudioPCMBuffer *buffer = [[AVAudioPCMBuffer alloc] initWithPCMFormat:format frameCapacity:frameCount];
    // The initializer is failable; floatChannelData of a nil buffer is NULL and
    // indexing it would crash.
    if (buffer == nil || buffer.floatChannelData == NULL) {
      reject(@"TTS_PCM_ERROR", @"Failed to allocate PCM buffer", nil);
      return;
    }
    buffer.frameLength = frameCount;

    float *channelData = buffer.floatChannelData[0];
    NSUInteger i = 0;
    for (NSNumber *value in samples) {
      channelData[i++] = [value floatValue];
    }

    [player scheduleBuffer:buffer completionHandler:nil];
    resolve(nil);
  } @catch (NSException *exception) {
    NSString *errorMsg = [NSString stringWithFormat:@"Failed to write PCM chunk: %@", exception.reason];
    reject(@"TTS_PCM_ERROR", errorMsg, nil);
  }
}
511
-
512
/// Stops playback and drops the PCM player graph.
///
/// The teardown is dispatched asynchronously to the main queue; resolve or
/// reject is invoked from there once it has run.
- (void)stopTtsPcmPlayer:(RCTPromiseResolveBlock)resolve
            withRejecter:(RCTPromiseRejectBlock)reject
{
  dispatch_async(dispatch_get_main_queue(), ^{
    @try {
      // Messaging nil is a no-op, so no explicit nil checks are needed.
      [g_tts_player stop];
      [g_tts_engine stop];
      [g_tts_engine reset];

      g_tts_player = nil;
      g_tts_engine = nil;
      g_tts_format = nil;
      resolve(nil);
    } @catch (NSException *exception) {
      reject(@"TTS_PCM_ERROR",
             [NSString stringWithFormat:@"Failed to stop PCM player: %@", exception.reason],
             nil);
    }
  });
}
534
-
535
/// Resolves with the loaded model's output sample rate (NSNumber), or rejects
/// with TTS_NOT_INITIALIZED when no model is loaded.
- (void)getTtsSampleRate:(RCTPromiseResolveBlock)resolve
            withRejecter:(RCTPromiseRejectBlock)reject
{
  @try {
    const bool ready = g_tts_wrapper != nullptr && g_tts_wrapper->isInitialized();
    if (!ready) {
      reject(@"TTS_NOT_INITIALIZED", @"TTS not initialized. Call initializeTts() first.", nil);
      return;
    }
    const int32_t rate = g_tts_wrapper->getSampleRate();
    resolve(@(rate));
  } @catch (NSException *exception) {
    reject(@"TTS_ERROR",
           [NSString stringWithFormat:@"Exception getting sample rate: %@", exception.reason],
           nil);
  }
}
552
-
553
/// Resolves with the loaded model's speaker count (NSNumber), or rejects with
/// TTS_NOT_INITIALIZED when no model is loaded.
- (void)getTtsNumSpeakers:(RCTPromiseResolveBlock)resolve
             withRejecter:(RCTPromiseRejectBlock)reject
{
  @try {
    const bool ready = g_tts_wrapper != nullptr && g_tts_wrapper->isInitialized();
    if (!ready) {
      reject(@"TTS_NOT_INITIALIZED", @"TTS not initialized. Call initializeTts() first.", nil);
      return;
    }
    const int32_t speakers = g_tts_wrapper->getNumSpeakers();
    resolve(@(speakers));
  } @catch (NSException *exception) {
    reject(@"TTS_ERROR",
           [NSString stringWithFormat:@"Exception getting num speakers: %@", exception.reason],
           nil);
  }
}
570
-
571
/// Stops PCM playback, releases the TTS wrapper and forgets the remembered
/// configuration (model dir/type, threads, debug, all three scale overrides).
///
/// Refuses while a stream is running: generateTtsStream's background block
/// dereferences g_tts_wrapper, so destroying it then would be a use-after-free.
/// In that case cancellation is requested and the promise rejects with
/// TTS_CLEANUP_ERROR; call again after the ttsStreamEnd event.
- (void)unloadTts:(RCTPromiseResolveBlock)resolve
     withRejecter:(RCTPromiseRejectBlock)reject
{
  if (g_tts_stream_running.load()) {
    g_tts_stream_cancelled.store(true);
    reject(@"TTS_CLEANUP_ERROR", @"Cannot unload TTS while streaming (stream cancellation requested)", nil);
    return;
  }

  @try {
    // Fire-and-forget: stopTtsPcmPlayer performs its teardown asynchronously on main.
    [self stopTtsPcmPlayer:^(__unused id result) {}
              withRejecter:^(__unused NSString *code, __unused NSString *message, __unused NSError *error) {}];

    if (g_tts_wrapper != nullptr) {
      g_tts_wrapper->release();
      g_tts_wrapper.reset();
    }

    g_tts_model_dir = nil;
    g_tts_model_type = nil;
    g_tts_num_threads = 2;
    g_tts_debug = NO;
    g_tts_noise_scale = nil;
    g_tts_noise_scale_w = nil;
    g_tts_length_scale = nil;

    RCTLogInfo(@"TTS resources released");
    resolve(nil);
  } @catch (NSException *exception) {
    NSString *errorMsg = [NSString stringWithFormat:@"Exception during TTS cleanup: %@", exception.reason];
    RCTLogError(@"%@", errorMsg);
    reject(@"TTS_CLEANUP_ERROR", errorMsg, nil);
  }
}
596
-
597
/// Writes `samples` (float values) to `filePath` via TtsWrapper::saveToWavFile.
///
/// Resolves with `filePath` on success. Rejects with TTS_SAVE_ERROR when the
/// path is nil/empty or cannot be converted to UTF-8, the sample rate is not
/// positive, the wrapper reports failure, or an NSException is raised.
- (void)saveTtsAudioToFile:(NSArray<NSNumber *> *)samples
            withSampleRate:(double)sampleRate
              withFilePath:(NSString *)filePath
              withResolver:(RCTPromiseResolveBlock)resolve
              withRejecter:(RCTPromiseRejectBlock)reject
{
  @try {
    // -UTF8String may return NULL; std::string(NULL) is undefined behavior.
    const char *filePathUtf8 = [filePath UTF8String];
    if (filePathUtf8 == NULL || filePathUtf8[0] == '\0') {
      reject(@"TTS_SAVE_ERROR", @"File path must be a non-empty string", nil);
      return;
    }
    // Casting a non-positive, NaN or out-of-range double to int32_t is meaningless or UB.
    if (!(sampleRate > 0 && sampleRate <= INT32_MAX)) {
      reject(@"TTS_SAVE_ERROR", @"Sample rate must be positive", nil);
      return;
    }

    std::vector<float> samplesVec;
    samplesVec.reserve([samples count]);
    for (NSNumber *num in samples) {
      samplesVec.push_back([num floatValue]);
    }

    std::string filePathStr(filePathUtf8);

    bool success = sherpaonnx::TtsWrapper::saveToWavFile(
        samplesVec,
        static_cast<int32_t>(sampleRate),
        filePathStr);

    if (success) {
      resolve(filePath);
    } else {
      reject(@"TTS_SAVE_ERROR", @"Failed to save audio to file", nil);
    }
  } @catch (NSException *exception) {
    NSString *errorMsg = [NSString stringWithFormat:@"Exception saving TTS audio: %@", exception.reason];
    reject(@"TTS_SAVE_ERROR", errorMsg, nil);
  }
}
628
-
629
/**
 * Writes `text` (UTF-8) to `filename` inside the directory `directoryUri`.
 *
 * `directoryUri` may be a `file://` URL or a plain path; `content://` URIs are an Android
 * concept and are rejected. `mimeType` is accepted for API parity and is not used on iOS.
 * Resolves with the written file path; rejects with TTS_SAVE_ERROR on any failure.
 */
- (void)saveTtsTextToContentUri:(NSString *)text
                   directoryUri:(NSString *)directoryUri
                       filename:(NSString *)filename
                       mimeType:(NSString *)mimeType
                   withResolver:(RCTPromiseResolveBlock)resolve
                   withRejecter:(RCTPromiseRejectBlock)reject
{
  @try {
    if ([directoryUri hasPrefix:@"content://"]) {
      reject(@"TTS_SAVE_ERROR", @"Content URIs are not supported on iOS", nil);
      return;
    }
    // +fileURLWithPath: and -stringByAppendingPathComponent: raise on nil, so report a clear
    // error instead of surfacing a generic exception message.
    if (directoryUri == nil || [directoryUri length] == 0) {
      reject(@"TTS_SAVE_ERROR", @"Invalid directory URL", nil);
      return;
    }
    if (filename == nil || [filename length] == 0) {
      reject(@"TTS_SAVE_ERROR", @"filename is required", nil);
      return;
    }

    NSURL *directoryUrl = [directoryUri hasPrefix:@"file://"]
                              ? [NSURL URLWithString:directoryUri]
                              : [NSURL fileURLWithPath:directoryUri];
    if (!directoryUrl) {
      reject(@"TTS_SAVE_ERROR", @"Invalid directory URL", nil);
      return;
    }

    NSString *filePath = [[directoryUrl path] stringByAppendingPathComponent:filename];

    NSError *writeError = nil;
    // The BOOL result is authoritative; the NSError out-param is only meaningful on failure.
    if (![text writeToFile:filePath atomically:YES encoding:NSUTF8StringEncoding error:&writeError]) {
      reject(@"TTS_SAVE_ERROR", @"Failed to save text to file", writeError);
      return;
    }

    resolve(filePath);
  } @catch (NSException *exception) {
    NSString *errorMsg = [NSString stringWithFormat:@"Exception saving text file: %@", exception.reason];
    reject(@"TTS_SAVE_ERROR", errorMsg, nil);
  }
}
674
-
675
/**
 * Presents the system share sheet for the audio file at `fileUri`.
 *
 * `fileUri` may be a URL string (`file://`, `content://`) or a plain filesystem path.
 * `mimeType` is accepted for API parity and is not used on iOS. Resolves with nil once the
 * sheet has been presented; rejects with TTS_SHARE_ERROR otherwise.
 */
- (void)shareTtsAudio:(NSString *)fileUri
             mimeType:(NSString *)mimeType
         withResolver:(RCTPromiseResolveBlock)resolve
         withRejecter:(RCTPromiseRejectBlock)reject
{
  @try {
    NSURL *url = nil;
    if ([fileUri hasPrefix:@"file://"] || [fileUri hasPrefix:@"content://"]) {
      url = [NSURL URLWithString:fileUri];
    } else {
      url = [NSURL fileURLWithPath:fileUri];
    }

    if (!url) {
      reject(@"TTS_SHARE_ERROR", @"Invalid file URL", nil);
      return;
    }

    dispatch_async(dispatch_get_main_queue(), ^{
      // The outer @try does not cover this block, which runs later on the main queue.
      @try {
        UIViewController *controller = RCTPresentedViewController();
        if (!controller) {
          reject(@"TTS_SHARE_ERROR", @"No active view controller", nil);
          return;
        }

        UIActivityViewController *activity =
            [[UIActivityViewController alloc] initWithActivityItems:@[url]
                                              applicationActivities:nil];

        // On iPad the activity controller is shown as a popover and UIKit raises if it has no
        // anchor. On iPhone popoverPresentationController is nil and these sends are no-ops.
        UIPopoverPresentationController *popover = activity.popoverPresentationController;
        UIView *anchorView = controller.view;
        popover.sourceView = anchorView;
        popover.sourceRect = CGRectMake(CGRectGetMidX(anchorView.bounds),
                                        CGRectGetMidY(anchorView.bounds), 0, 0);
        popover.permittedArrowDirections = 0;

        [controller presentViewController:activity animated:YES completion:nil];
        resolve(nil);
      } @catch (NSException *exception) {
        NSString *errorMsg = [NSString stringWithFormat:@"Failed to share audio: %@", exception.reason];
        reject(@"TTS_SHARE_ERROR", errorMsg, nil);
      }
    });
  } @catch (NSException *exception) {
    NSString *errorMsg = [NSString stringWithFormat:@"Failed to share audio: %@", exception.reason];
    reject(@"TTS_SHARE_ERROR", errorMsg, nil);
  }
}
711
-
712
- @end
1
/**
 * SherpaOnnx+TTS.mm
 *
 * Purpose: TTS (text-to-speech) TurboModule methods: initializeTts, updateTtsParams, unloadTts,
 * generateTts (plus timestamp and streaming variants), PCM playback helpers, and stream event
 * emission. Uses sherpa-onnx-tts-wrapper for native synthesis and sherpa-onnx-model-detect for
 * model detection.
 */
7
+
8
+ #import "SherpaOnnx.h"
9
+ #import <React/RCTLog.h>
10
+ #import <React/RCTUtils.h>
11
+ #import <UIKit/UIKit.h>
12
+ #import <AVFoundation/AVFoundation.h>
13
+
14
+ #include "sherpa-onnx-tts-wrapper.h"
15
+ #include "sherpa-onnx-model-detect.h"
16
+ #include <atomic>
17
+ #include <condition_variable>
18
+ #include <memory>
19
+ #include <mutex>
20
+ #include <sstream>
21
+ #include <string>
22
+ #include <unordered_map>
23
+ #include <vector>
24
+ #include <chrono>
25
+
26
// Per-instance TTS state, keyed by the JS-provided instanceId in g_tts_instances.
struct TtsInstanceState {
  // Native synthesis engine.
  std::unique_ptr<sherpaonnx::TtsWrapper> wrapper;

  // Streaming flags. They are read from the background streaming queue, hence atomic.
  std::atomic<bool> streamRunning{false};
  std::atomic<bool> streamCancelled{false};

  // PCM playback pipeline created by startTtsPcmPlayer; all nil when no player is active.
  __strong AVAudioEngine *engine = nil;
  __strong AVAudioPlayerNode *player = nil;
  __strong AVAudioFormat *format = nil;

  // Configuration accepted by the last successful initialize, reused by updateTtsParams.
  __strong NSString *modelDir = nil;
  __strong NSString *modelType = nil;
  int32_t numThreads = 2;
  BOOL debug = NO;
  __strong NSNumber *noiseScale = nil;
  __strong NSNumber *noiseScaleW = nil;
  __strong NSNumber *lengthScale = nil;
  __strong NSString *ruleFsts = nil;
  __strong NSString *ruleFars = nil;
  __strong NSNumber *maxNumSentences = nil;
  __strong NSNumber *silenceScale = nil;
  __strong NSString *provider = nil;
};

using TtsInstanceMap = std::unordered_map<std::string, std::shared_ptr<TtsInstanceState>>;

// Registry of live instances; every access goes through g_tts_mutex.
static TtsInstanceMap g_tts_instances;
static std::mutex g_tts_mutex;
// Notified (under g_tts_mutex) when a streaming generation finishes.
static std::condition_variable g_tts_stream_cv;
50
+
51
// Maps a detected TTS model kind to the lowercase name reported to JS ("unknown" otherwise).
static NSString *ttsModelKindToNSString(sherpaonnx::TtsModelKind kind) {
  NSString *name = @"unknown";
  switch (kind) {
    case sherpaonnx::TtsModelKind::kVits:
      name = @"vits";
      break;
    case sherpaonnx::TtsModelKind::kMatcha:
      name = @"matcha";
      break;
    case sherpaonnx::TtsModelKind::kKokoro:
      name = @"kokoro";
      break;
    case sherpaonnx::TtsModelKind::kKitten:
      name = @"kitten";
      break;
    case sherpaonnx::TtsModelKind::kZipvoice:
      name = @"zipvoice";
      break;
    default:
      break;
  }
  return name;
}
62
+
63
namespace {
// Splits `text` on whitespace into word tokens (used for estimated subtitle timing).
// Empty input yields no tokens; non-empty input made only of whitespace yields one token
// holding the original text unchanged.
std::vector<std::string> SplitTtsTokens(const std::string &text) {
  std::vector<std::string> words;
  std::istringstream stream(text);
  for (std::string word; stream >> word;) {
    words.push_back(word);
  }
  if (!words.empty() || text.empty()) {
    return words;
  }
  return {text};
}
}  // namespace
77
+
78
+ @implementation SherpaOnnx (TTS)
79
+
80
/**
 * Creates the TTS instance `instanceId` if needed and (re)initializes its engine.
 *
 * Optional values are forwarded only when meaningful: numbers when non-nil, strings when
 * non-empty, and maxNumSentences when >= 1. On success the accepted configuration is cached
 * on the instance (updateTtsParams re-initializes from it) and the promise resolves with
 * `{ success: YES, detectedModels: [{ type, modelDir }, ...] }`. Failures reject with
 * TTS_INIT_ERROR.
 */
- (void)initializeTts:(NSString *)instanceId
             modelDir:(NSString *)modelDir
            modelType:(NSString *)modelType
           numThreads:(double)numThreads
                debug:(BOOL)debug
           noiseScale:(NSNumber *)noiseScale
          noiseScaleW:(NSNumber *)noiseScaleW
          lengthScale:(NSNumber *)lengthScale
             ruleFsts:(NSString *)ruleFsts
             ruleFars:(NSString *)ruleFars
      maxNumSentences:(NSNumber *)maxNumSentences
         silenceScale:(NSNumber *)silenceScale
             provider:(NSString *)provider
              resolve:(RCTPromiseResolveBlock)resolve
               reject:(RCTPromiseRejectBlock)reject
{
  if (instanceId == nil || [instanceId length] == 0) {
    reject(@"TTS_INIT_ERROR", @"instanceId is required", nil);
    return;
  }
  // -UTF8String on nil returns NULL and std::string(NULL) is undefined behavior (a crash the
  // @catch below cannot intercept), so reject missing strings before converting them.
  if (modelDir == nil) {
    reject(@"TTS_INIT_ERROR", @"modelDir is required", nil);
    return;
  }
  if (modelType == nil) {
    reject(@"TTS_INIT_ERROR", @"modelType is required", nil);
    return;
  }
  std::string instanceIdStr = [instanceId UTF8String];
  RCTLogInfo(@"Initializing TTS instance %@ with modelDir: %@, modelType: %@", instanceId, modelDir, modelType);

  @try {
    std::lock_guard<std::mutex> lock(g_tts_mutex);
    auto it = g_tts_instances.find(instanceIdStr);
    const bool createdHere = (it == g_tts_instances.end());
    if (createdHere) {
      it = g_tts_instances.emplace(instanceIdStr, std::make_shared<TtsInstanceState>()).first;
    }
    TtsInstanceState *inst = it->second.get();

    // generateTtsStream drives inst->wrapper from a background queue without holding
    // g_tts_mutex; re-initializing the wrapper underneath it would race with that thread.
    if (inst->streamRunning.load()) {
      reject(@"TTS_INIT_ERROR", @"Cannot re-initialize TTS while streaming", nil);
      return;
    }
    if (inst->wrapper == nullptr) {
      inst->wrapper = std::make_unique<sherpaonnx::TtsWrapper>();
    }

    std::string modelDirStr = [modelDir UTF8String];
    std::string modelTypeStr = [modelType UTF8String];

    const bool hasRuleFsts = (ruleFsts != nil && [ruleFsts length] > 0);
    const bool hasRuleFars = (ruleFars != nil && [ruleFars length] > 0);
    const bool hasMaxNumSentences = (maxNumSentences != nil && [maxNumSentences intValue] >= 1);
    const bool hasProvider = (provider != nil && [provider length] > 0);

    std::optional<float> noiseScaleOpt = std::nullopt;
    std::optional<float> noiseScaleWOpt = std::nullopt;
    std::optional<float> lengthScaleOpt = std::nullopt;
    std::optional<float> silenceScaleOpt = std::nullopt;
    if (noiseScale != nil) {
      noiseScaleOpt = [noiseScale floatValue];
    }
    if (noiseScaleW != nil) {
      noiseScaleWOpt = [noiseScaleW floatValue];
    }
    if (lengthScale != nil) {
      lengthScaleOpt = [lengthScale floatValue];
    }
    if (silenceScale != nil) {
      silenceScaleOpt = [silenceScale floatValue];
    }

    std::optional<std::string> ruleFstsOpt = std::nullopt;
    std::optional<std::string> ruleFarsOpt = std::nullopt;
    std::optional<int32_t> maxNumSentencesOpt = std::nullopt;
    std::optional<std::string> providerOpt = std::nullopt;
    if (hasRuleFsts) {
      ruleFstsOpt = std::string([ruleFsts UTF8String]);
    }
    if (hasRuleFars) {
      ruleFarsOpt = std::string([ruleFars UTF8String]);
    }
    if (hasMaxNumSentences) {
      maxNumSentencesOpt = static_cast<int32_t>([maxNumSentences intValue]);
    }
    if (hasProvider) {
      providerOpt = std::string([provider UTF8String]);
    }

    sherpaonnx::TtsInitializeResult result = inst->wrapper->initialize(
        modelDirStr,
        modelTypeStr,
        static_cast<int32_t>(numThreads),
        debug,
        noiseScaleOpt,
        noiseScaleWOpt,
        lengthScaleOpt,
        ruleFstsOpt,
        ruleFarsOpt,
        maxNumSentencesOpt,
        silenceScaleOpt,
        providerOpt);

    if (!result.success) {
      // Do not leave behind an entry that this call created but could not initialize.
      if (createdHere) {
        g_tts_instances.erase(it);
      }
      NSString *errorMsg = @"Failed to initialize TTS";
      RCTLogError(@"%@", errorMsg);
      reject(@"TTS_INIT_ERROR", errorMsg, nil);
      return;
    }

    RCTLogInfo(@"TTS initialization successful for instance %@", instanceId);

    // Cache the accepted configuration so updateTtsParams can re-initialize with it.
    inst->modelDir = [modelDir copy];
    inst->modelType = [modelType copy];
    inst->numThreads = static_cast<int32_t>(numThreads);
    inst->debug = debug;
    inst->noiseScale = noiseScale ? [noiseScale copy] : nil;
    inst->noiseScaleW = noiseScaleW ? [noiseScaleW copy] : nil;
    inst->lengthScale = lengthScale ? [lengthScale copy] : nil;
    inst->ruleFsts = hasRuleFsts ? [ruleFsts copy] : nil;
    inst->ruleFars = hasRuleFars ? [ruleFars copy] : nil;
    inst->maxNumSentences = hasMaxNumSentences ? [maxNumSentences copy] : nil;
    inst->silenceScale = silenceScale ? [silenceScale copy] : nil;
    inst->provider = hasProvider ? [provider copy] : nil;

    NSMutableArray *detectedModelsArray = [NSMutableArray array];
    for (const auto &model : result.detectedModels) {
      [detectedModelsArray addObject:@{
        @"type": [NSString stringWithUTF8String:model.type.c_str()],
        @"modelDir": [NSString stringWithUTF8String:model.modelDir.c_str()]
      }];
    }

    resolve(@{
      @"success": @YES,
      @"detectedModels": detectedModelsArray
    });
  } @catch (NSException *exception) {
    NSString *errorMsg = [NSString stringWithFormat:@"Exception during TTS init: %@", exception.reason];
    RCTLogError(@"%@", errorMsg);
    reject(@"TTS_INIT_ERROR", errorMsg, nil);
  }
}
208
+
209
/**
 * Inspects `modelDir` and reports which TTS model kinds it contains.
 *
 * A nil, empty or "auto" `modelType` requests automatic detection. Resolves with
 * `{ success, error?, detectedModels: [{ type, modelDir }], modelType }`; rejects with
 * DETECT_ERROR when `modelDir` is missing or an NSException occurs.
 */
- (void)detectTtsModel:(NSString *)modelDir
             modelType:(NSString *)modelType
               resolve:(RCTPromiseResolveBlock)resolve
                reject:(RCTPromiseRejectBlock)reject
{
  RCTLogInfo(@"Detecting TTS model in: %@", modelDir);
  // -UTF8String on nil returns NULL and std::string(NULL) is undefined behavior, which the
  // @catch below cannot intercept.
  if (modelDir == nil) {
    reject(@"DETECT_ERROR", @"modelDir is required", nil);
    return;
  }
  @try {
    std::string modelDirStr = [modelDir UTF8String];
    const bool explicitType =
        modelType != nil && [modelType length] > 0 && ![modelType isEqualToString:@"auto"];
    std::string modelTypeStr = explicitType ? std::string([modelType UTF8String]) : std::string("auto");
    sherpaonnx::TtsDetectResult result = sherpaonnx::DetectTtsModel(modelDirStr, modelTypeStr);

    NSMutableDictionary *resultDict = [NSMutableDictionary dictionary];
    resultDict[@"success"] = @(result.ok);
    if (!result.error.empty()) {
      resultDict[@"error"] = [NSString stringWithUTF8String:result.error.c_str()];
    }
    NSMutableArray *detectedModelsArray = [NSMutableArray array];
    for (const auto &model : result.detectedModels) {
      [detectedModelsArray addObject:@{
        @"type": [NSString stringWithUTF8String:model.type.c_str()],
        @"modelDir": [NSString stringWithUTF8String:model.modelDir.c_str()]
      }];
    }
    resultDict[@"detectedModels"] = detectedModelsArray;
    resultDict[@"modelType"] = ttsModelKindToNSString(result.selectedKind);
    resolve(resultDict);
  } @catch (NSException *exception) {
    NSString *errorMsg = [NSString stringWithFormat:@"TTS model detection failed: %@", exception.reason];
    RCTLogError(@"%@", errorMsg);
    reject(@"DETECT_ERROR", errorMsg, nil);
  }
}
242
+
243
/**
 * Re-initializes an already-initialized instance with new noise/length scales while keeping
 * every other cached setting.
 *
 * For each scale: nil clears the override, NaN keeps the currently cached value, and any other
 * number replaces it. Not allowed while a stream is running. Resolves with
 * `{ success: YES, detectedModels: [...] }`; rejects with TTS_UPDATE_ERROR otherwise.
 */
- (void)updateTtsParams:(NSString *)instanceId
             noiseScale:(NSNumber *)noiseScale
            noiseScaleW:(NSNumber *)noiseScaleW
            lengthScale:(NSNumber *)lengthScale
                resolve:(RCTPromiseResolveBlock)resolve
                 reject:(RCTPromiseRejectBlock)reject
{
  if (instanceId == nil || [instanceId length] == 0) {
    reject(@"TTS_UPDATE_ERROR", @"instanceId is required", nil);
    return;
  }
  const std::string key([instanceId UTF8String]);
  std::lock_guard<std::mutex> guard(g_tts_mutex);

  const auto entry = g_tts_instances.find(key);
  const bool usable = entry != g_tts_instances.end() && entry->second->wrapper != nullptr &&
                      entry->second->modelDir != nil && entry->second->modelType != nil;
  if (!usable) {
    reject(@"TTS_UPDATE_ERROR", @"TTS instance not found or not initialized", nil);
    return;
  }
  TtsInstanceState &state = *entry->second;
  if (state.streamRunning.load()) {
    reject(@"TTS_UPDATE_ERROR", @"Cannot update params while streaming", nil);
    return;
  }

  // nil -> clear, NaN -> keep current, otherwise -> take the incoming value.
  auto pickValue = [](NSNumber *incoming, NSNumber *current) -> NSNumber * {
    if (incoming == nil) {
      return nil;
    }
    return isnan([incoming doubleValue]) ? current : incoming;
  };
  NSNumber *nextNoiseScale = pickValue(noiseScale, state.noiseScale);
  NSNumber *nextNoiseScaleW = pickValue(noiseScaleW, state.noiseScaleW);
  NSNumber *nextLengthScale = pickValue(lengthScale, state.lengthScale);

  auto asFloat = [](NSNumber *value) -> std::optional<float> {
    if (value == nil) {
      return std::nullopt;
    }
    return [value floatValue];
  };
  auto asText = [](NSString *value) -> std::optional<std::string> {
    if (value == nil || [value length] == 0) {
      return std::nullopt;
    }
    return std::string([value UTF8String]);
  };

  @try {
    std::optional<int32_t> sentenceLimit = std::nullopt;
    if (state.maxNumSentences != nil && [state.maxNumSentences intValue] >= 1) {
      sentenceLimit = static_cast<int32_t>([state.maxNumSentences intValue]);
    }

    sherpaonnx::TtsInitializeResult outcome = state.wrapper->initialize(
        std::string([state.modelDir UTF8String]),
        std::string([state.modelType UTF8String]),
        state.numThreads,
        state.debug,
        asFloat(nextNoiseScale),
        asFloat(nextNoiseScaleW),
        asFloat(nextLengthScale),
        asText(state.ruleFsts),
        asText(state.ruleFars),
        sentenceLimit,
        asFloat(state.silenceScale),
        asText(state.provider));

    if (!outcome.success) {
      NSString *failure = @"Failed to update TTS params";
      RCTLogError(@"%@", failure);
      reject(@"TTS_UPDATE_ERROR", failure, nil);
      return;
    }

    // Remember the values that are now in effect.
    state.noiseScale = nextNoiseScale ? [nextNoiseScale copy] : nil;
    state.noiseScaleW = nextNoiseScaleW ? [nextNoiseScaleW copy] : nil;
    state.lengthScale = nextLengthScale ? [nextLengthScale copy] : nil;

    NSMutableArray *models = [NSMutableArray array];
    for (const auto &detected : outcome.detectedModels) {
      [models addObject:@{
        @"type": [NSString stringWithUTF8String:detected.type.c_str()],
        @"modelDir": [NSString stringWithUTF8String:detected.modelDir.c_str()]
      }];
    }
    resolve(@{
      @"success": @YES,
      @"detectedModels": models
    });
  } @catch (NSException *exception) {
    NSString *failure = [NSString stringWithFormat:@"Exception during TTS update: %@", exception.reason];
    RCTLogError(@"%@", failure);
    reject(@"TTS_UPDATE_ERROR", failure, nil);
  }
}
376
+
377
/**
 * Synthesizes `text` in one shot with instance `instanceId`.
 *
 * `options` may carry numeric `sid` (speaker id, default 0) and `speed` (default 1.0); values
 * that cannot be converted to a number (e.g. JS null) are ignored. Resolves with
 * `{ samples: [float], sampleRate }`; rejects with TTS_NOT_INITIALIZED or TTS_GENERATE_ERROR.
 * g_tts_mutex is held for the whole synthesis, so the instance cannot be torn down
 * mid-generation (this also serializes generations across instances).
 */
- (void)generateTts:(NSString *)instanceId
               text:(NSString *)text
            options:(NSDictionary *)options
            resolve:(RCTPromiseResolveBlock)resolve
             reject:(RCTPromiseRejectBlock)reject
{
  if (instanceId == nil || [instanceId length] == 0) {
    reject(@"TTS_GENERATE_ERROR", @"instanceId is required", nil);
    return;
  }
  // -UTF8String on nil returns NULL; std::string(NULL) is undefined behavior.
  if (text == nil) {
    reject(@"TTS_GENERATE_ERROR", @"text is required", nil);
    return;
  }
  // JS null arrives as NSNull, which does not implement -doubleValue. Sending it would raise
  // here, outside any @try, and crash the app, so only convert values that support it.
  id sidValue = options[@"sid"];
  id speedValue = options[@"speed"];
  const double sid = [sidValue respondsToSelector:@selector(doubleValue)] ? [sidValue doubleValue] : 0;
  const double speed = [speedValue respondsToSelector:@selector(doubleValue)] ? [speedValue doubleValue] : 1.0;

  std::string instanceIdStr = [instanceId UTF8String];
  std::lock_guard<std::mutex> lock(g_tts_mutex);
  auto it = g_tts_instances.find(instanceIdStr);
  if (it == g_tts_instances.end() || it->second->wrapper == nullptr || !it->second->wrapper->isInitialized()) {
    reject(@"TTS_NOT_INITIALIZED", @"TTS not initialized. Call initializeTts() first.", nil);
    return;
  }
  sherpaonnx::TtsWrapper *wrapper = it->second->wrapper.get();
  @try {
    std::string textStr = [text UTF8String];

    auto result = wrapper->generate(textStr, static_cast<int32_t>(sid), static_cast<float>(speed));

    if (result.samples.empty() || result.sampleRate == 0) {
      NSString *errorMsg = @"Failed to generate speech or result is empty";
      RCTLogError(@"%@", errorMsg);
      reject(@"TTS_GENERATE_ERROR", errorMsg, nil);
      return;
    }

    NSMutableArray *samplesArray = [NSMutableArray arrayWithCapacity:result.samples.size()];
    for (float sample : result.samples) {
      [samplesArray addObject:@(sample)];
    }

    NSDictionary *resultDict = @{
      @"samples": samplesArray,
      @"sampleRate": @(result.sampleRate)
    };

    RCTLogInfo(@"TTS: Generated %lu samples at %d Hz",
               (unsigned long)result.samples.size(), result.sampleRate);

    resolve(resultDict);
  } @catch (NSException *exception) {
    NSString *errorMsg = [NSString stringWithFormat:@"Exception during TTS generation: %@", exception.reason];
    RCTLogError(@"%@", errorMsg);
    reject(@"TTS_GENERATE_ERROR", errorMsg, nil);
  }
}
437
+
438
/**
 * Synthesizes `text` and attaches estimated word-level subtitles.
 *
 * Timings are not produced by the model: the total audio duration is divided evenly across
 * the whitespace-separated tokens of `text`, and the result is flagged with
 * `estimated: YES`. `options` may carry numeric `sid` (default 0) and `speed` (default 1.0);
 * non-numeric values such as JS null are ignored. Resolves with
 * `{ samples, sampleRate, subtitles: [{ text, start, end }], estimated }`.
 */
- (void)generateTtsWithTimestamps:(NSString *)instanceId
                             text:(NSString *)text
                          options:(NSDictionary *)options
                          resolve:(RCTPromiseResolveBlock)resolve
                           reject:(RCTPromiseRejectBlock)reject
{
  if (instanceId == nil || [instanceId length] == 0) {
    reject(@"TTS_GENERATE_ERROR", @"instanceId is required", nil);
    return;
  }
  // -UTF8String on nil returns NULL; std::string(NULL) is undefined behavior.
  if (text == nil) {
    reject(@"TTS_GENERATE_ERROR", @"text is required", nil);
    return;
  }
  // JS null arrives as NSNull, which does not implement -doubleValue. Sending it would raise
  // outside any @try and crash the app.
  id sidValue = options[@"sid"];
  id speedValue = options[@"speed"];
  const double sid = [sidValue respondsToSelector:@selector(doubleValue)] ? [sidValue doubleValue] : 0;
  const double speed = [speedValue respondsToSelector:@selector(doubleValue)] ? [speedValue doubleValue] : 1.0;

  std::string instanceIdStr = [instanceId UTF8String];
  std::lock_guard<std::mutex> lock(g_tts_mutex);
  auto it = g_tts_instances.find(instanceIdStr);
  if (it == g_tts_instances.end() || it->second->wrapper == nullptr || !it->second->wrapper->isInitialized()) {
    reject(@"TTS_NOT_INITIALIZED", @"TTS not initialized. Call initializeTts() first.", nil);
    return;
  }
  sherpaonnx::TtsWrapper *wrapper = it->second->wrapper.get();
  @try {
    std::string textStr = [text UTF8String];

    auto result = wrapper->generate(textStr, static_cast<int32_t>(sid), static_cast<float>(speed));

    if (result.samples.empty() || result.sampleRate == 0) {
      NSString *errorMsg = @"Failed to generate speech or result is empty";
      RCTLogError(@"%@", errorMsg);
      reject(@"TTS_GENERATE_ERROR", errorMsg, nil);
      return;
    }

    NSMutableArray *samplesArray = [NSMutableArray arrayWithCapacity:result.samples.size()];
    for (float sample : result.samples) {
      [samplesArray addObject:@(sample)];
    }

    // Even split of the audio duration across tokens (an estimate, not aligned timing).
    std::vector<std::string> tokens = SplitTtsTokens(textStr);
    NSMutableArray *subtitlesArray = [NSMutableArray array];
    if (!tokens.empty()) {
      double totalSeconds = static_cast<double>(result.samples.size()) /
                            static_cast<double>(result.sampleRate);
      double perToken = totalSeconds / static_cast<double>(tokens.size());

      for (size_t i = 0; i < tokens.size(); ++i) {
        [subtitlesArray addObject:@{
          @"text": [NSString stringWithUTF8String:tokens[i].c_str()],
          @"start": @(perToken * static_cast<double>(i)),
          @"end": @(perToken * static_cast<double>(i + 1))
        }];
      }
    }

    resolve(@{
      @"samples": samplesArray,
      @"sampleRate": @(result.sampleRate),
      @"subtitles": subtitlesArray,
      @"estimated": @YES
    });
  } @catch (NSException *exception) {
    NSString *errorMsg = [NSString stringWithFormat:@"Exception during TTS generation: %@", exception.reason];
    RCTLogError(@"%@", errorMsg);
    reject(@"TTS_GENERATE_ERROR", errorMsg, nil);
  }
}
516
+
517
/**
 * Starts streaming synthesis of `text` on a background queue and resolves immediately.
 *
 * Audio is delivered through events carrying `instanceId`:
 *  - `ttsStreamChunk` `{ samples, sampleRate, progress, isFinal: NO }` for each chunk,
 *  - at most one `ttsStreamError` `{ message }` if generation fails (not sent when cancelled),
 *  - one `ttsStreamEnd` `{ cancelled }` when the stream finishes.
 * Only one stream per instance may run at a time. `options` may carry numeric `sid`
 * (default 0) and `speed` (default 1.0); non-numeric values such as JS null are ignored.
 */
- (void)generateTtsStream:(NSString *)instanceId
                     text:(NSString *)text
                  options:(NSDictionary *)options
                  resolve:(RCTPromiseResolveBlock)resolve
                   reject:(RCTPromiseRejectBlock)reject
{
  if (instanceId == nil || [instanceId length] == 0) {
    reject(@"TTS_STREAM_ERROR", @"instanceId is required", nil);
    return;
  }
  // Validate before marking the stream as running. A nil text would otherwise crash in
  // std::string(NULL) with streamRunning left stuck at true.
  if (text == nil) {
    reject(@"TTS_STREAM_ERROR", @"text is required", nil);
    return;
  }
  // JS null arrives as NSNull, which does not implement -doubleValue. Sending it would raise
  // outside any @try and crash the app.
  id sidValue = options[@"sid"];
  id speedValue = options[@"speed"];
  const double sid = [sidValue respondsToSelector:@selector(doubleValue)] ? [sidValue doubleValue] : 0;
  const double speed = [speedValue respondsToSelector:@selector(doubleValue)] ? [speedValue doubleValue] : 1.0;

  std::string instanceIdStr = [instanceId UTF8String];
  std::shared_ptr<TtsInstanceState> instRef;
  {
    std::lock_guard<std::mutex> lock(g_tts_mutex);
    auto it = g_tts_instances.find(instanceIdStr);
    if (it == g_tts_instances.end() || it->second->wrapper == nullptr || !it->second->wrapper->isInitialized()) {
      reject(@"TTS_NOT_INITIALIZED", @"TTS not initialized. Call initializeTts() first.", nil);
      return;
    }
    instRef = it->second;  // shared_ptr copy keeps TtsInstanceState alive during streaming
    if (instRef->streamRunning.load()) {
      reject(@"TTS_STREAM_ERROR", @"TTS streaming already in progress", nil);
      return;
    }
    instRef->streamCancelled.store(false);
    instRef->streamRunning.store(true);
  }

  std::string textStr = [text UTF8String];
  int32_t sampleRate = instRef->wrapper->getSampleRate();
  NSString *instanceIdCopy = [instanceId copy];

  __weak SherpaOnnx *weakSelf = self;
  dispatch_async(dispatch_get_global_queue(QOS_CLASS_USER_INITIATED, 0), ^{
    bool success = false;
    // Set when the @catch below has already emitted ttsStreamError, so the generic failure
    // event is not sent a second time for the same stream.
    bool errorReported = false;
    @try {
      success = instRef->wrapper->generateStream(
          textStr,
          static_cast<int32_t>(sid),
          static_cast<float>(speed),
          [weakSelf, sampleRate, instanceIdCopy, instRef](const float *samples, int32_t numSamples, float progress) -> int32_t {
            // Returning 0 tells the generator to stop.
            if (instRef->streamCancelled.load()) {
              return 0;
            }

            NSMutableArray *samplesArray = [NSMutableArray arrayWithCapacity:numSamples];
            for (int32_t i = 0; i < numSamples; i++) {
              [samplesArray addObject:@(samples[i])];
            }

            NSDictionary *payload = @{
              @"instanceId": instanceIdCopy,
              @"samples": samplesArray,
              @"sampleRate": @(sampleRate),
              @"progress": @(progress),
              @"isFinal": @NO
            };

            dispatch_async(dispatch_get_main_queue(), ^{
              if (weakSelf) {
                [weakSelf sendEventWithName:@"ttsStreamChunk" body:payload];
              }
            });

            return instRef->streamCancelled.load() ? 0 : 1;
          });
    } @catch (NSException *exception) {
      errorReported = true;
      NSString *errorMsg = [NSString stringWithFormat:@"TTS streaming failed: %@", exception.reason];
      dispatch_async(dispatch_get_main_queue(), ^{
        if (weakSelf) {
          [weakSelf sendEventWithName:@"ttsStreamError" body:@{ @"instanceId": instanceIdCopy, @"message": errorMsg }];
        }
      });
    }

    bool cancelled = instRef->streamCancelled.load();
    if (!success && !cancelled && !errorReported) {
      dispatch_async(dispatch_get_main_queue(), ^{
        if (weakSelf) {
          [weakSelf sendEventWithName:@"ttsStreamError" body:@{ @"instanceId": instanceIdCopy, @"message": @"TTS streaming generation failed" }];
        }
      });
    }

    dispatch_async(dispatch_get_main_queue(), ^{
      if (weakSelf) {
        [weakSelf sendEventWithName:@"ttsStreamEnd" body:@{ @"instanceId": instanceIdCopy, @"cancelled": @(cancelled) }];
      }
    });

    instRef->streamRunning.store(false);
    {
      std::lock_guard<std::mutex> lock(g_tts_mutex);
      g_tts_stream_cv.notify_all();
    }
  });

  resolve(nil);
}
623
+
624
/**
 * Requests cancellation of the running stream for `instanceId`.
 *
 * Only sets a flag that the streaming callback polls; always resolves with nil, including for a
 * missing id or an unknown instance.
 */
- (void)cancelTtsStream:(NSString *)instanceId
                resolve:(RCTPromiseResolveBlock)resolve
                 reject:(RCTPromiseRejectBlock)reject
{
  if ([instanceId length] > 0) {
    const std::string key([instanceId UTF8String]);
    std::lock_guard<std::mutex> guard(g_tts_mutex);
    const auto entry = g_tts_instances.find(key);
    if (entry != g_tts_instances.end()) {
      entry->second->streamCancelled.store(true);
    }
  }
  resolve(nil);
}
640
+
641
/**
 * Starts a mono Float32 PCM playback pipeline (AVAudioEngine + AVAudioPlayerNode) for
 * `instanceId`, replacing any player already attached to that instance.
 *
 * Work runs on the main queue. Rejects with TTS_PCM_ERROR when the instance is unknown,
 * `channels` is not 1, `sampleRate` is not positive, or the audio engine fails to start.
 * Audio-session configuration is best effort: failures are logged, not rejected.
 */
- (void)startTtsPcmPlayer:(NSString *)instanceId
               sampleRate:(double)sampleRate
                 channels:(double)channels
                  resolve:(RCTPromiseResolveBlock)resolve
                   reject:(RCTPromiseRejectBlock)reject
{
  if (instanceId == nil || [instanceId length] == 0) {
    reject(@"TTS_PCM_ERROR", @"instanceId is required", nil);
    return;
  }
  std::string instanceIdStr = [instanceId UTF8String];
  dispatch_async(dispatch_get_main_queue(), ^{
    // Stops and drops the playback objects of one instance. Caller must hold g_tts_mutex.
    auto teardownPlayer = [](TtsInstanceState *inst) {
      [inst->player stop];
      [inst->engine stop];
      [inst->engine reset];
      inst->player = nil;
      inst->engine = nil;
      inst->format = nil;
    };

    NSString *errorMsg = nil;
    NSError *startError = nil;
    @try {
      // Phase 1: validate and release any previous pipeline.
      {
        std::lock_guard<std::mutex> lock(g_tts_mutex);
        auto it = g_tts_instances.find(instanceIdStr);
        if (it == g_tts_instances.end()) {
          errorMsg = @"TTS instance not found";
        } else if (channels != 1.0) {
          errorMsg = @"PCM playback supports mono only";
        } else if (!(sampleRate > 0)) {
          // AVAudioFormat cannot be built for a non-positive (or NaN) rate.
          errorMsg = @"sampleRate must be positive";
        } else {
          teardownPlayer(it->second.get());
        }
      }

      // Phase 2: configure the shared audio session (outside the lock; may be slow).
      if (errorMsg == nil) {
        AVAudioSession *session = [AVAudioSession sharedInstance];
        NSError *sessionError = nil;
        if (![session setCategory:AVAudioSessionCategoryPlayback error:&sessionError]) {
          RCTLogWarn(@"TTS PCM: failed to set audio session category: %@", sessionError.localizedDescription);
        }
        sessionError = nil;
        if (![session setActive:YES error:&sessionError]) {
          RCTLogWarn(@"TTS PCM: failed to activate audio session: %@", sessionError.localizedDescription);
        }
      }

      // Phase 3: build and start the new pipeline. The instance may have been removed meanwhile.
      if (errorMsg == nil) {
        std::lock_guard<std::mutex> lock(g_tts_mutex);
        auto it = g_tts_instances.find(instanceIdStr);
        if (it == g_tts_instances.end()) {
          errorMsg = @"TTS instance not found";
        } else {
          TtsInstanceState *inst = it->second.get();
          inst->engine = [[AVAudioEngine alloc] init];
          inst->player = [[AVAudioPlayerNode alloc] init];
          inst->format = [[AVAudioFormat alloc] initStandardFormatWithSampleRate:sampleRate channels:1];

          [inst->engine attachNode:inst->player];
          [inst->engine connect:inst->player to:inst->engine.mainMixerNode format:inst->format];

          if ([inst->engine startAndReturnError:&startError]) {
            [inst->player play];
          } else {
            errorMsg = [NSString stringWithFormat:@"Failed to start audio engine: %@", startError.localizedDescription];
            // Do not leave a stopped pipeline behind: writeTtsPcmChunk treats non-nil
            // engine/player/format as "player ready".
            teardownPlayer(inst);
          }
        }
      }
    } @catch (NSException *exception) {
      errorMsg = [NSString stringWithFormat:@"Failed to start PCM player: %@", exception.reason];
      startError = nil;
    }

    if (errorMsg != nil) {
      reject(@"TTS_PCM_ERROR", errorMsg, startError);
    } else {
      resolve(nil);
    }
  });
}
721
+
722
/**
 * Schedules a chunk of mono float samples on the instance's PCM player.
 *
 * An empty chunk is a no-op and resolves immediately. Rejects with TTS_PCM_ERROR when the player
 * has not been started, the PCM buffer cannot be allocated, or an NSException occurs.
 */
- (void)writeTtsPcmChunk:(NSString *)instanceId
                 samples:(NSArray<NSNumber *> *)samples
                 resolve:(RCTPromiseResolveBlock)resolve
                  reject:(RCTPromiseRejectBlock)reject
{
  if (instanceId == nil || [instanceId length] == 0) {
    reject(@"TTS_PCM_ERROR", @"instanceId is required", nil);
    return;
  }
  std::string instanceIdStr = [instanceId UTF8String];
  std::lock_guard<std::mutex> lock(g_tts_mutex);
  auto it = g_tts_instances.find(instanceIdStr);
  if (it == g_tts_instances.end() || it->second->engine == nil || it->second->player == nil || it->second->format == nil) {
    reject(@"TTS_PCM_ERROR", @"PCM player not initialized", nil);
    return;
  }
  TtsInstanceState *inst = it->second.get();
  @try {
    const NSUInteger sampleCount = [samples count];
    // Nothing to play. Also avoids requesting a zero-capacity buffer, which may come back nil.
    if (sampleCount == 0) {
      resolve(nil);
      return;
    }

    AVAudioFrameCount frameCount = (AVAudioFrameCount)sampleCount;
    AVAudioPCMBuffer *buffer = [[AVAudioPCMBuffer alloc] initWithPCMFormat:inst->format frameCapacity:frameCount];
    // Indexing floatChannelData of a nil buffer would dereference NULL and crash outright
    // (not an NSException), so check before touching it.
    if (buffer == nil || buffer.floatChannelData == NULL) {
      reject(@"TTS_PCM_ERROR", @"Failed to allocate PCM buffer", nil);
      return;
    }
    buffer.frameLength = frameCount;

    float *channelData = buffer.floatChannelData[0];
    NSUInteger index = 0;
    for (NSNumber *sample in samples) {
      channelData[index++] = [sample floatValue];
    }

    [inst->player scheduleBuffer:buffer completionHandler:nil];
    resolve(nil);
  } @catch (NSException *exception) {
    NSString *errorMsg = [NSString stringWithFormat:@"Failed to write PCM chunk: %@", exception.reason];
    reject(@"TTS_PCM_ERROR", errorMsg, nil);
  }
}
756
+
757
/**
 * Stops and discards the PCM playback pipeline of `instanceId` on the main queue.
 *
 * Resolves with nil, also when the id is missing or the instance/player does not exist.
 * Rejects with TTS_PCM_ERROR only if an NSException is raised while stopping.
 */
- (void)stopTtsPcmPlayer:(NSString *)instanceId
                 resolve:(RCTPromiseResolveBlock)resolve
                  reject:(RCTPromiseRejectBlock)reject
{
  if ([instanceId length] == 0) {
    resolve(nil);
    return;
  }
  const std::string key([instanceId UTF8String]);
  dispatch_async(dispatch_get_main_queue(), ^{
    @try {
      std::lock_guard<std::mutex> guard(g_tts_mutex);
      const auto entry = g_tts_instances.find(key);
      if (entry != g_tts_instances.end()) {
        TtsInstanceState &state = *entry->second;
        // Messages to nil are no-ops, so absent objects need no explicit checks.
        [state.player stop];
        [state.engine stop];
        [state.engine reset];
        state.player = nil;
        state.engine = nil;
        state.format = nil;
      }
      resolve(nil);
    } @catch (NSException *exception) {
      reject(@"TTS_PCM_ERROR",
             [NSString stringWithFormat:@"Failed to stop PCM player: %@", exception.reason],
             nil);
    }
  });
}
790
+
791
/**
 * Resolves with the sample rate reported by the initialized TTS wrapper of `instanceId`.
 *
 * Rejects with TTS_ERROR when `instanceId` is nil or empty. Rejects with TTS_NOT_INITIALIZED
 * when the instance is unknown, has no wrapper, or its wrapper reports it is not initialized.
 * The lookup and the read both happen while holding g_tts_mutex.
 */
- (void)getTtsSampleRate:(NSString *)instanceId
                 resolve:(RCTPromiseResolveBlock)resolve
                  reject:(RCTPromiseRejectBlock)reject
{
  // -length sent to nil yields 0, so this single test covers both nil and empty ids.
  if ([instanceId length] == 0) {
    reject(@"TTS_ERROR", @"instanceId is required", nil);
    return;
  }

  const std::string key([instanceId UTF8String]);
  std::lock_guard<std::mutex> guard(g_tts_mutex);
  const auto entry = g_tts_instances.find(key);
  const bool ready = entry != g_tts_instances.end()
      && entry->second->wrapper != nullptr
      && entry->second->wrapper->isInitialized();
  if (!ready) {
    reject(@"TTS_NOT_INITIALIZED", @"TTS not initialized. Call initializeTts() first.", nil);
    return;
  }

  const int32_t rate = entry->second->wrapper->getSampleRate();
  resolve(@(rate));
}
809
+
810
/**
 * Resolves with the speaker count reported by the initialized TTS wrapper of `instanceId`.
 *
 * Rejects with TTS_ERROR when `instanceId` is nil or empty. Rejects with TTS_NOT_INITIALIZED
 * when the instance is unknown, has no wrapper, or its wrapper reports it is not initialized.
 * The lookup and the read both happen while holding g_tts_mutex.
 */
- (void)getTtsNumSpeakers:(NSString *)instanceId
                  resolve:(RCTPromiseResolveBlock)resolve
                   reject:(RCTPromiseRejectBlock)reject
{
  // -length sent to nil yields 0, so this single test covers both nil and empty ids.
  if ([instanceId length] == 0) {
    reject(@"TTS_ERROR", @"instanceId is required", nil);
    return;
  }

  const std::string key([instanceId UTF8String]);
  std::lock_guard<std::mutex> guard(g_tts_mutex);
  const auto entry = g_tts_instances.find(key);
  const bool ready = entry != g_tts_instances.end()
      && entry->second->wrapper != nullptr
      && entry->second->wrapper->isInitialized();
  if (!ready) {
    reject(@"TTS_NOT_INITIALIZED", @"TTS not initialized. Call initializeTts() first.", nil);
    return;
  }

  const int32_t speakers = entry->second->wrapper->getNumSpeakers();
  resolve(@(speakers));
}
828
+
829
/**
 * Unloads the TTS instance `instanceId` and frees its native resources.
 *
 * Phase 1 runs on the main queue while holding g_tts_mutex:
 *   - stop the PCM player node;
 *   - stop and reset the audio engine;
 *   - drop the player/engine/format references;
 *   - set `streamCancelled`.
 * Phase 2 runs on a background queue while holding g_tts_mutex:
 *   - wait up to 5 s on g_tts_stream_cv for `streamRunning` to become false;
 *   - release the native wrapper and clear the cached configuration;
 *   - erase the map entry.
 *
 * The promise resolves with nil in every non-exceptional path. That includes:
 *   - an empty or unknown id;
 *   - the entry being erased or replaced by another caller while this call waits.
 * An Objective-C exception raised while scheduling the work rejects with TTS_CLEANUP_ERROR.
 */
- (void)unloadTts:(NSString *)instanceId
          resolve:(RCTPromiseResolveBlock)resolve
           reject:(RCTPromiseRejectBlock)reject
{
  if (instanceId == nil || [instanceId length] == 0) {
    resolve(nil);
    return;
  }
  std::string instanceIdStr = [instanceId UTF8String];
  RCTPromiseResolveBlock resolveCopy = resolve;
  RCTPromiseRejectBlock rejectCopy = reject;
  NSString *instanceIdCopy = [instanceId copy];
  @try {
    dispatch_async(dispatch_get_main_queue(), ^{
      // The exact instance this call is unloading. Phase 2 only touches it while the map
      // entry still owns it.
      TtsInstanceState *target = nullptr;
      {
        std::lock_guard<std::mutex> lock(g_tts_mutex);
        auto it = g_tts_instances.find(instanceIdStr);
        if (it == g_tts_instances.end()) {
          resolveCopy(nil);
          return;
        }
        target = it->second.get();
        if (target->player != nil) [target->player stop];
        if (target->engine != nil) {
          [target->engine stop];
          [target->engine reset];
        }
        target->player = nil;
        target->engine = nil;
        target->format = nil;
        target->streamCancelled.store(true);
      }
      dispatch_async(dispatch_get_global_queue(QOS_CLASS_USER_INITIATED, 0), ^{
        {
          std::unique_lock<std::mutex> lock(g_tts_mutex);
          // wait_for() drops g_tts_mutex while blocked. Meanwhile another thread may:
          //   - erase the entry (e.g. a concurrent unload of the same id);
          //   - replace it with a new instance;
          //   - insert other ids, which may invalidate iterators.
          // So iterators and `target` are only trusted after re-checking ownership
          // under the lock.
          auto stillOwned = [&]() -> bool {
            auto cur = g_tts_instances.find(instanceIdStr);
            return cur != g_tts_instances.end() && cur->second.get() == target;
          };
          if (!stillOwned()) {
            dispatch_async(dispatch_get_main_queue(), ^{ resolveCopy(nil); });
            return;
          }
          // NOTE(review): assumes the stream worker notifies g_tts_stream_cv after clearing
          // streamRunning. Otherwise this simply times out after 5 s.
          bool done = g_tts_stream_cv.wait_for(
              lock,
              std::chrono::seconds(5),
              [&] { return !stillOwned() || !target->streamRunning.load(); });

          // Fresh lookup: the pre-wait iterator may be stale.
          auto it = g_tts_instances.find(instanceIdStr);
          if (it == g_tts_instances.end() || it->second.get() != target) {
            // Another caller already unloaded or replaced this instance while we waited.
            dispatch_async(dispatch_get_main_queue(), ^{ resolveCopy(nil); });
            return;
          }
          if (!done) {
            RCTLogWarn(@"TTS unload: stream did not stop within 5s, releasing anyway");
          }
          if (target->wrapper != nullptr) {
            target->wrapper->release();
            target->wrapper.reset();
          }
          target->modelDir = nil;
          target->modelType = nil;
          target->provider = nil;
          target->noiseScale = nil;
          target->noiseScaleW = nil;
          target->lengthScale = nil;
          target->ruleFsts = nil;
          target->ruleFars = nil;
          target->maxNumSentences = nil;
          target->silenceScale = nil;
          g_tts_instances.erase(it);
        }
        RCTLogInfo(@"TTS instance %@ released", instanceIdCopy);
        dispatch_async(dispatch_get_main_queue(), ^{ resolveCopy(nil); });
      });
    });
  } @catch (NSException *exception) {
    NSString *errorMsg = [NSString stringWithFormat:@"Exception during TTS cleanup: %@", exception.reason];
    RCTLogError(@"%@", errorMsg);
    rejectCopy(@"TTS_CLEANUP_ERROR", errorMsg, nil);
  }
}
905
+
906
/**
 * Writes `samples` as a WAV file at `filePath` via TtsWrapper::saveToWavFile.
 *
 * Resolves with `filePath` on success. Rejects with TTS_SAVE_ERROR in these cases:
 *   - the path is missing;
 *   - the sample rate cannot be represented as a positive int32_t;
 *   - the wrapper reports failure;
 *   - an Objective-C exception is raised.
 * Both input checks run first. They guard against C++ undefined behaviour that @catch cannot
 * intercept: std::string(NULL), and casting a NaN or out-of-range double to int32_t.
 */
- (void)saveTtsAudioToFile:(NSArray<NSNumber *> *)samples
                sampleRate:(double)sampleRate
                  filePath:(NSString *)filePath
                   resolve:(RCTPromiseResolveBlock)resolve
                    reject:(RCTPromiseRejectBlock)reject
{
  @try {
    if (filePath.length == 0) {
      reject(@"TTS_SAVE_ERROR", @"filePath is required", nil);
      return;
    }
    // The negated range test also rejects NaN.
    if (!(sampleRate >= 1.0 && sampleRate <= static_cast<double>(INT32_MAX))) {
      reject(@"TTS_SAVE_ERROR", @"Invalid sample rate", nil);
      return;
    }

    std::vector<float> samplesVec;
    samplesVec.reserve([samples count]);
    for (NSNumber *num in samples) {
      samplesVec.push_back([num floatValue]);
    }

    std::string filePathStr = std::string([filePath UTF8String]);

    bool success = sherpaonnx::TtsWrapper::saveToWavFile(
        samplesVec,
        static_cast<int32_t>(sampleRate),
        filePathStr);

    if (success) {
      resolve(filePath);
    } else {
      reject(@"TTS_SAVE_ERROR", @"Failed to save audio to file", nil);
    }
  } @catch (NSException *exception) {
    NSString *errorMsg = [NSString stringWithFormat:@"Exception saving TTS audio: %@", exception.reason];
    reject(@"TTS_SAVE_ERROR", errorMsg, nil);
  }
}
937
+
938
/**
 * Writes `samples` as a WAV file named `filename` inside a local directory.
 *
 * `directoryUri` may be a `file://` URL or a plain filesystem path. `content://` URIs are
 * rejected as unsupported on iOS. Resolves with the full file path on success. Rejects with
 * TTS_SAVE_ERROR in these cases:
 *   - the directory cannot be resolved (e.g. an unparsable `file://` URL, which previously
 *     produced a nil path and crashed in std::string(NULL));
 *   - `filename` is empty;
 *   - the sample rate is not a positive int32_t;
 *   - the wrapper fails to write.
 */
- (void)saveTtsAudioToContentUri:(NSArray<NSNumber *> *)samples
                      sampleRate:(double)sampleRate
                    directoryUri:(NSString *)directoryUri
                        filename:(NSString *)filename
                         resolve:(RCTPromiseResolveBlock)resolve
                          reject:(RCTPromiseRejectBlock)reject
{
  @try {
    if ([directoryUri hasPrefix:@"content://"]) {
      reject(@"TTS_SAVE_ERROR", @"Content URIs are not supported on iOS", nil);
      return;
    }
    // +URLWithString: returns nil for malformed strings (e.g. unescaped spaces), so the
    // resolved path must be checked before use.
    NSString *dirPath = [directoryUri hasPrefix:@"file://"]
        ? [[NSURL URLWithString:directoryUri] path]
        : directoryUri;
    if (dirPath.length == 0) {
      reject(@"TTS_SAVE_ERROR", @"Invalid directory URL", nil);
      return;
    }
    if (filename.length == 0) {
      reject(@"TTS_SAVE_ERROR", @"filename is required", nil);
      return;
    }
    // The negated range test also rejects NaN; the int32_t cast below is UB otherwise.
    if (!(sampleRate >= 1.0 && sampleRate <= static_cast<double>(INT32_MAX))) {
      reject(@"TTS_SAVE_ERROR", @"Invalid sample rate", nil);
      return;
    }

    std::vector<float> samplesVec;
    samplesVec.reserve([samples count]);
    for (NSNumber *num in samples) {
      samplesVec.push_back([num floatValue]);
    }
    NSString *filePath = [dirPath stringByAppendingPathComponent:filename];
    std::string filePathStr = std::string([filePath UTF8String]);
    bool success = sherpaonnx::TtsWrapper::saveToWavFile(
        samplesVec,
        static_cast<int32_t>(sampleRate),
        filePathStr);
    if (success) {
      resolve(filePath);
    } else {
      reject(@"TTS_SAVE_ERROR", @"Failed to save audio to file", nil);
    }
  } @catch (NSException *exception) {
    NSString *errorMsg = [NSString stringWithFormat:@"Exception saving TTS audio: %@", exception.reason];
    reject(@"TTS_SAVE_ERROR", errorMsg, nil);
  }
}
975
+
976
/**
 * Copies a local file to `<Caches>/sherpa_tts/<filename>` and resolves with the destination path.
 *
 * `fileUri` may be a `file://` URL or a plain filesystem path. `content://` URIs are rejected
 * as unsupported on iOS. Any existing file at the destination is replaced. Every failure
 * rejects with TTS_SAVE_ERROR.
 *
 * Two inputs used to destroy data and are now handled:
 *   - A source that already is the destination is resolved unchanged. The remove-then-copy
 *     sequence would otherwise delete the only copy and then fail.
 *   - An empty `filename` is rejected. It would otherwise make the destination the
 *     `sherpa_tts` directory itself, which would then be removed.
 */
- (void)copyTtsContentUriToCache:(NSString *)fileUri
                        filename:(NSString *)filename
                         resolve:(RCTPromiseResolveBlock)resolve
                          reject:(RCTPromiseRejectBlock)reject
{
  @try {
    if ([fileUri hasPrefix:@"content://"]) {
      reject(@"TTS_SAVE_ERROR", @"Content URIs are not supported on iOS", nil);
      return;
    }
    if (filename.length == 0) {
      reject(@"TTS_SAVE_ERROR", @"filename is required", nil);
      return;
    }
    NSString *srcPath = [fileUri hasPrefix:@"file://"]
        ? [[NSURL URLWithString:fileUri] path]
        : fileUri;
    NSFileManager *fm = [NSFileManager defaultManager];
    if (srcPath.length == 0 || ![fm fileExistsAtPath:srcPath]) {
      reject(@"TTS_SAVE_ERROR", @"Source file does not exist", nil);
      return;
    }
    NSArray<NSString *> *caches =
        NSSearchPathForDirectoriesInDomains(NSCachesDirectory, NSUserDomainMask, YES);
    NSString *cacheDir = caches.firstObject;
    if (cacheDir.length == 0) {
      reject(@"TTS_SAVE_ERROR", @"Cache directory unavailable", nil);
      return;
    }
    NSString *destPath =
        [[cacheDir stringByAppendingPathComponent:@"sherpa_tts"] stringByAppendingPathComponent:filename];

    // Already in place: nothing to copy, and removing destPath would delete the source.
    if ([[srcPath stringByStandardizingPath] isEqualToString:[destPath stringByStandardizingPath]]) {
      resolve(destPath);
      return;
    }

    NSError *err = nil;
    // Judge success by the return value, not by the error out-parameter.
    if (![fm createDirectoryAtPath:[destPath stringByDeletingLastPathComponent]
       withIntermediateDirectories:YES
                        attributes:nil
                             error:&err]) {
      reject(@"TTS_SAVE_ERROR", err ? err.localizedDescription : @"Failed to create cache directory", err);
      return;
    }
    if ([fm fileExistsAtPath:destPath]) {
      // Best effort: if removal fails, the copy below fails and reports why.
      [fm removeItemAtPath:destPath error:nil];
    }
    err = nil;
    if (![fm copyItemAtPath:srcPath toPath:destPath error:&err]) {
      reject(@"TTS_SAVE_ERROR", err ? err.localizedDescription : @"Copy failed", err);
      return;
    }
    resolve(destPath);
  } @catch (NSException *exception) {
    NSString *errorMsg = [NSString stringWithFormat:@"Exception copying file: %@", exception.reason];
    reject(@"TTS_SAVE_ERROR", errorMsg, nil);
  }
}
1017
+
1018
/**
 * Writes `text` as UTF-8 to `filename` inside a local directory and resolves with the full path.
 *
 * `directoryUri` is parsed as a URL when it starts with `file://` and is otherwise treated as
 * a filesystem path. `content://` URIs are rejected as unsupported on iOS. The write is atomic.
 * `mimeType` is accepted for cross-platform API parity but is not used here.
 * Failures reject with TTS_SAVE_ERROR; a write failure forwards the underlying NSError.
 */
- (void)saveTtsTextToContentUri:(NSString *)text
                   directoryUri:(NSString *)directoryUri
                       filename:(NSString *)filename
                       mimeType:(NSString *)mimeType
                        resolve:(RCTPromiseResolveBlock)resolve
                         reject:(RCTPromiseRejectBlock)reject
{
  @try {
    if ([directoryUri hasPrefix:@"content://"]) {
      reject(@"TTS_SAVE_ERROR", @"Content URIs are not supported on iOS", nil);
      return;
    }

    NSURL *targetDirectory = [directoryUri hasPrefix:@"file://"]
        ? [NSURL URLWithString:directoryUri]
        : [NSURL fileURLWithPath:directoryUri];
    if (targetDirectory == nil) {
      reject(@"TTS_SAVE_ERROR", @"Invalid directory URL", nil);
      return;
    }

    NSString *targetPath = [[targetDirectory path] stringByAppendingPathComponent:filename];
    NSError *error = nil;
    const BOOL written = [text writeToFile:targetPath
                                atomically:YES
                                  encoding:NSUTF8StringEncoding
                                     error:&error];
    if (written && error == nil) {
      resolve(targetPath);
    } else {
      reject(@"TTS_SAVE_ERROR", @"Failed to save text to file", error);
    }
  } @catch (NSException *exception) {
    reject(@"TTS_SAVE_ERROR",
           [NSString stringWithFormat:@"Exception saving text file: %@", exception.reason],
           nil);
  }
}
1063
+
1064
/**
 * Presents the system share sheet for the file at `fileUri`.
 *
 * `fileUri` is parsed as a URL when it starts with `file://` or `content://` and is otherwise
 * treated as a filesystem path. `mimeType` is accepted for API parity but is not used here.
 * Resolves with nil once presentation has been requested. Rejects with TTS_SHARE_ERROR when:
 *   - the URL is invalid;
 *   - there is no view controller to present from;
 *   - an Objective-C exception is raised.
 *
 * On iPad the activity sheet is shown as a popover, and UIKit raises an exception unless the
 * popover has an anchor. It is therefore anchored to the centre of the presenting view.
 * The presentation block runs after the outer @try has exited, so it has its own handler.
 */
- (void)shareTtsAudio:(NSString *)fileUri
             mimeType:(NSString *)mimeType
              resolve:(RCTPromiseResolveBlock)resolve
               reject:(RCTPromiseRejectBlock)reject
{
  @try {
    NSURL *url = nil;
    if ([fileUri hasPrefix:@"file://"] || [fileUri hasPrefix:@"content://"]) {
      url = [NSURL URLWithString:fileUri];
    } else {
      url = [NSURL fileURLWithPath:fileUri];
    }

    if (!url) {
      reject(@"TTS_SHARE_ERROR", @"Invalid file URL", nil);
      return;
    }

    dispatch_async(dispatch_get_main_queue(), ^{
      @try {
        UIViewController *controller = RCTPresentedViewController();
        if (!controller) {
          reject(@"TTS_SHARE_ERROR", @"No active view controller", nil);
          return;
        }

        UIActivityViewController *activity =
            [[UIActivityViewController alloc] initWithActivityItems:@[url]
                                              applicationActivities:nil];
        UIPopoverPresentationController *popover = activity.popoverPresentationController;
        if (popover != nil) {
          UIView *anchor = controller.view;
          popover.sourceView = anchor;
          popover.sourceRect = CGRectMake(CGRectGetMidX(anchor.bounds), CGRectGetMidY(anchor.bounds), 0, 0);
          popover.permittedArrowDirections = (UIPopoverArrowDirection)0;
        }
        [controller presentViewController:activity animated:YES completion:nil];
        resolve(nil);
      } @catch (NSException *exception) {
        NSString *errorMsg = [NSString stringWithFormat:@"Failed to share audio: %@", exception.reason];
        reject(@"TTS_SHARE_ERROR", errorMsg, nil);
      }
    });
  } @catch (NSException *exception) {
    NSString *errorMsg = [NSString stringWithFormat:@"Failed to share audio: %@", exception.reason];
    reject(@"TTS_SHARE_ERROR", errorMsg, nil);
  }
}
1100
+
1101
+ @end