react-native-sherpa-onnx 0.2.0 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. package/README.md +232 -236
  2. package/SherpaOnnx.podspec +68 -64
  3. package/android/build.gradle +182 -192
  4. package/android/codegen.gradle +57 -0
  5. package/android/prebuilt-download.gradle +428 -0
  6. package/android/prebuilt-versions.gradle +43 -0
  7. package/android/proguard-rules.pro +10 -0
  8. package/android/src/main/assets/testModels/add_mul_add.onnx +28 -0
  9. package/android/src/main/assets/testModels/nnapi_internal_uint8_support.onnx +0 -0
  10. package/android/src/main/assets/testModels/qnn_multi_ctx_embed.onnx +0 -0
  11. package/android/src/main/cpp/CMakeLists.txt +166 -129
  12. package/android/src/main/cpp/CMakePresets.json +54 -0
  13. package/android/src/main/cpp/crypto/sha256.cpp +174 -0
  14. package/android/src/main/cpp/crypto/sha256.h +16 -0
  15. package/android/src/main/cpp/jni/archive/sherpa-onnx-archive-helper.cpp +404 -0
  16. package/android/src/main/cpp/jni/archive/sherpa-onnx-archive-helper.h +56 -0
  17. package/android/src/main/cpp/jni/archive/sherpa-onnx-archive-jni.cpp +181 -0
  18. package/android/src/main/cpp/jni/audio/sherpa-onnx-audio-convert-jni.cpp +888 -0
  19. package/{ios → android/src/main/cpp/jni/model_detect}/sherpa-onnx-common.h +18 -18
  20. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-detect-jni-common.cpp +86 -0
  21. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-detect-jni-common.h +20 -0
  22. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-helper.cpp +423 -0
  23. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-helper.h +55 -0
  24. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-stt.cpp +399 -0
  25. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-tts.cpp +238 -0
  26. package/{ios → android/src/main/cpp/jni/model_detect}/sherpa-onnx-model-detect.h +122 -89
  27. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-stt-wrapper.cpp +99 -0
  28. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-stt-wrapper.h +16 -0
  29. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-tts-wrapper.cpp +78 -0
  30. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-tts-wrapper.h +16 -0
  31. package/android/src/main/cpp/jni/module/sherpa-onnx-module-jni.cpp +190 -0
  32. package/android/src/main/cpp/jni/tts/sherpa-onnx-tts-zipvoice-jni.cpp +301 -0
  33. package/android/src/main/java/com/sherpaonnx/SherpaOnnxArchiveHelper.kt +94 -0
  34. package/android/src/main/java/com/sherpaonnx/{SherpaOnnxCoreHelper.kt → SherpaOnnxAssetHelper.kt} +350 -236
  35. package/android/src/main/java/com/sherpaonnx/SherpaOnnxModule.kt +791 -483
  36. package/android/src/main/java/com/sherpaonnx/SherpaOnnxSttHelper.kt +699 -109
  37. package/android/src/main/java/com/sherpaonnx/SherpaOnnxTtsHelper.kt +1123 -668
  38. package/android/src/main/java/com/sherpaonnx/ZipvoiceTtsWrapper.kt +187 -0
  39. package/ios/SherpaOnnx+Assets.h +11 -0
  40. package/ios/SherpaOnnx+Assets.mm +325 -0
  41. package/ios/SherpaOnnx+STT.mm +455 -118
  42. package/ios/SherpaOnnx+TTS.mm +1101 -712
  43. package/ios/SherpaOnnx.h +17 -6
  44. package/ios/SherpaOnnx.mm +206 -311
  45. package/ios/SherpaOnnx.xcconfig +19 -19
  46. package/ios/SherpaOnnxCoreMLHelper.swift +24 -0
  47. package/ios/archive/sherpa-onnx-archive-helper.h +21 -0
  48. package/ios/archive/sherpa-onnx-archive-helper.mm +296 -0
  49. package/ios/libarchive_darwin_config.h +153 -0
  50. package/{android/src/main/cpp/jni → ios/model_detect}/sherpa-onnx-common.h +18 -18
  51. package/ios/model_detect/sherpa-onnx-model-detect-helper.h +49 -0
  52. package/ios/model_detect/sherpa-onnx-model-detect-helper.mm +210 -0
  53. package/ios/model_detect/sherpa-onnx-model-detect-stt.mm +344 -0
  54. package/ios/model_detect/sherpa-onnx-model-detect-tts.mm +201 -0
  55. package/{android/src/main/cpp/jni → ios/model_detect}/sherpa-onnx-model-detect.h +117 -89
  56. package/ios/scripts/patch-libarchive-includes.sh +61 -0
  57. package/ios/scripts/setup-ios-libarchive.sh +98 -0
  58. package/ios/stt/sherpa-onnx-stt-wrapper.h +129 -0
  59. package/ios/stt/sherpa-onnx-stt-wrapper.mm +523 -0
  60. package/ios/{sherpa-onnx-tts-wrapper.h → tts/sherpa-onnx-tts-wrapper.h} +90 -85
  61. package/ios/{sherpa-onnx-tts-wrapper.mm → tts/sherpa-onnx-tts-wrapper.mm} +376 -345
  62. package/lib/module/NativeSherpaOnnx.js +3 -0
  63. package/lib/module/NativeSherpaOnnx.js.map +1 -1
  64. package/lib/module/audio/index.js +22 -0
  65. package/lib/module/audio/index.js.map +1 -0
  66. package/lib/module/diarization/index.js +1 -1
  67. package/lib/module/diarization/index.js.map +1 -1
  68. package/lib/module/download/ModelDownloadManager.js +918 -0
  69. package/lib/module/download/ModelDownloadManager.js.map +1 -0
  70. package/lib/module/download/extractTarBz2.js +53 -0
  71. package/lib/module/download/extractTarBz2.js.map +1 -0
  72. package/lib/module/download/index.js +6 -0
  73. package/lib/module/download/index.js.map +1 -0
  74. package/lib/module/download/validation.js +178 -0
  75. package/lib/module/download/validation.js.map +1 -0
  76. package/lib/module/enhancement/index.js +1 -1
  77. package/lib/module/enhancement/index.js.map +1 -1
  78. package/lib/module/index.js +41 -3
  79. package/lib/module/index.js.map +1 -1
  80. package/lib/module/separation/index.js +1 -1
  81. package/lib/module/separation/index.js.map +1 -1
  82. package/lib/module/stt/index.js +127 -60
  83. package/lib/module/stt/index.js.map +1 -1
  84. package/lib/module/stt/sttModelLanguages.js +512 -0
  85. package/lib/module/stt/sttModelLanguages.js.map +1 -0
  86. package/lib/module/stt/types.js +53 -1
  87. package/lib/module/stt/types.js.map +1 -1
  88. package/lib/module/tts/index.js +216 -289
  89. package/lib/module/tts/index.js.map +1 -1
  90. package/lib/module/tts/types.js +86 -1
  91. package/lib/module/tts/types.js.map +1 -1
  92. package/lib/module/types.js.map +1 -1
  93. package/lib/module/utils.js +86 -73
  94. package/lib/module/utils.js.map +1 -1
  95. package/lib/module/vad/index.js +1 -1
  96. package/lib/module/vad/index.js.map +1 -1
  97. package/lib/typescript/src/NativeSherpaOnnx.d.ts +192 -38
  98. package/lib/typescript/src/NativeSherpaOnnx.d.ts.map +1 -1
  99. package/lib/typescript/src/audio/index.d.ts +13 -0
  100. package/lib/typescript/src/audio/index.d.ts.map +1 -0
  101. package/lib/typescript/src/diarization/index.d.ts +3 -2
  102. package/lib/typescript/src/diarization/index.d.ts.map +1 -1
  103. package/lib/typescript/src/download/ModelDownloadManager.d.ts +108 -0
  104. package/lib/typescript/src/download/ModelDownloadManager.d.ts.map +1 -0
  105. package/lib/typescript/src/download/extractTarBz2.d.ts +14 -0
  106. package/lib/typescript/src/download/extractTarBz2.d.ts.map +1 -0
  107. package/lib/typescript/src/download/index.d.ts +7 -0
  108. package/lib/typescript/src/download/index.d.ts.map +1 -0
  109. package/lib/typescript/src/download/validation.d.ts +57 -0
  110. package/lib/typescript/src/download/validation.d.ts.map +1 -0
  111. package/lib/typescript/src/enhancement/index.d.ts +3 -2
  112. package/lib/typescript/src/enhancement/index.d.ts.map +1 -1
  113. package/lib/typescript/src/index.d.ts +26 -2
  114. package/lib/typescript/src/index.d.ts.map +1 -1
  115. package/lib/typescript/src/separation/index.d.ts +3 -2
  116. package/lib/typescript/src/separation/index.d.ts.map +1 -1
  117. package/lib/typescript/src/stt/index.d.ts +31 -43
  118. package/lib/typescript/src/stt/index.d.ts.map +1 -1
  119. package/lib/typescript/src/stt/sttModelLanguages.d.ts +52 -0
  120. package/lib/typescript/src/stt/sttModelLanguages.d.ts.map +1 -0
  121. package/lib/typescript/src/stt/types.d.ts +196 -9
  122. package/lib/typescript/src/stt/types.d.ts.map +1 -1
  123. package/lib/typescript/src/tts/index.d.ts +25 -211
  124. package/lib/typescript/src/tts/index.d.ts.map +1 -1
  125. package/lib/typescript/src/tts/types.d.ts +148 -25
  126. package/lib/typescript/src/tts/types.d.ts.map +1 -1
  127. package/lib/typescript/src/types.d.ts +0 -32
  128. package/lib/typescript/src/types.d.ts.map +1 -1
  129. package/lib/typescript/src/utils.d.ts +28 -13
  130. package/lib/typescript/src/utils.d.ts.map +1 -1
  131. package/lib/typescript/src/vad/index.d.ts +3 -2
  132. package/lib/typescript/src/vad/index.d.ts.map +1 -1
  133. package/package.json +250 -222
  134. package/scripts/check-qnn-support.sh +78 -0
  135. package/scripts/setup-ios-framework.sh +379 -282
  136. package/src/NativeSherpaOnnx.ts +474 -251
  137. package/src/audio/index.ts +32 -0
  138. package/src/diarization/index.ts +4 -2
  139. package/src/download/ModelDownloadManager.ts +1325 -0
  140. package/src/download/extractTarBz2.ts +78 -0
  141. package/src/download/index.ts +43 -0
  142. package/src/download/validation.ts +279 -0
  143. package/src/enhancement/index.ts +4 -2
  144. package/src/index.tsx +78 -27
  145. package/src/separation/index.ts +4 -2
  146. package/src/stt/index.ts +249 -89
  147. package/src/stt/sttModelLanguages.ts +237 -0
  148. package/src/stt/types.ts +263 -9
  149. package/src/tts/index.ts +470 -458
  150. package/src/tts/types.ts +373 -218
  151. package/src/types.ts +0 -44
  152. package/src/utils.ts +145 -131
  153. package/src/vad/index.ts +4 -2
  154. package/third_party/ffmpeg_prebuilt/ANDROID_RELEASE_TAG +1 -0
  155. package/third_party/libarchive_prebuilt/ANDROID_RELEASE_TAG +1 -0
  156. package/third_party/libarchive_prebuilt/IOS_RELEASE_TAG +1 -0
  157. package/third_party/sherpa-onnx-prebuilt/ANDROID_RELEASE_TAG +1 -0
  158. package/third_party/sherpa-onnx-prebuilt/IOS_RELEASE_TAG +1 -0
  159. package/android/src/main/cpp/include/sherpa-onnx/c-api/c-api.h +0 -1918
  160. package/android/src/main/cpp/include/sherpa-onnx/c-api/cxx-api.h +0 -841
  161. package/android/src/main/cpp/jni/sherpa-onnx-model-detect.cpp +0 -541
  162. package/android/src/main/cpp/jni/sherpa-onnx-stt-jni.cpp +0 -336
  163. package/android/src/main/cpp/jni/sherpa-onnx-stt-wrapper.cpp +0 -222
  164. package/android/src/main/cpp/jni/sherpa-onnx-stt-wrapper.h +0 -68
  165. package/android/src/main/cpp/jni/sherpa-onnx-tts-jni.cpp +0 -823
  166. package/android/src/main/cpp/jni/sherpa-onnx-tts-wrapper.cpp +0 -387
  167. package/android/src/main/cpp/jni/sherpa-onnx-tts-wrapper.h +0 -147
  168. package/ios/Frameworks/sherpa_onnx.xcframework.zip +0 -0
  169. package/ios/include/sherpa-onnx/c-api/c-api.h +0 -1918
  170. package/ios/include/sherpa-onnx/c-api/cxx-api.h +0 -841
  171. package/ios/sherpa-onnx-model-detect.mm +0 -441
  172. package/ios/sherpa-onnx-stt-wrapper.h +0 -48
  173. package/ios/sherpa-onnx-stt-wrapper.mm +0 -201
  174. package/scripts/copy-headers.js +0 -184
  175. package/scripts/setup-assets.js +0 -323
@@ -1,345 +1,376 @@
1
- #include "sherpa-onnx-tts-wrapper.h"
2
- #include "sherpa-onnx-model-detect.h"
3
- #include <algorithm>
4
- #include <cctype>
5
- #include <cstring>
6
- #include <fstream>
7
- #include <optional>
8
- #include <sstream>
9
-
10
- // iOS logging
11
- #ifdef __APPLE__
12
- #include <Foundation/Foundation.h>
13
- #include <cstdio>
14
- #define LOGI(fmt, ...) NSLog(@"TtsWrapper: " fmt, ##__VA_ARGS__)
15
- #define LOGE(fmt, ...) NSLog(@"TtsWrapper ERROR: " fmt, ##__VA_ARGS__)
16
- #else
17
- #define LOGI(...)
18
- #define LOGE(...)
19
- #endif
20
-
21
- // Use C++17 filesystem (podspec enforces C++17)
22
- #include <filesystem>
23
- namespace fs = std::filesystem;
24
-
25
- // sherpa-onnx headers - use C++ API (RAII wrapper around C API)
26
- #include "sherpa-onnx/c-api/cxx-api.h"
27
-
28
- namespace sherpaonnx {
29
-
30
- class TtsWrapper::Impl {
31
- public:
32
- bool initialized = false;
33
- std::string modelDir;
34
- std::optional<sherpa_onnx::cxx::OfflineTts> tts;
35
- };
36
-
37
- TtsWrapper::TtsWrapper() : pImpl(std::make_unique<Impl>()) {
38
- LOGI("TtsWrapper created");
39
- }
40
-
41
- TtsWrapper::~TtsWrapper() {
42
- release();
43
- LOGI("TtsWrapper destroyed");
44
- }
45
-
46
- TtsInitializeResult TtsWrapper::initialize(
47
- const std::string& modelDir,
48
- const std::string& modelType,
49
- int32_t numThreads,
50
- bool debug,
51
- const std::optional<float>& noiseScale,
52
- const std::optional<float>& noiseScaleW,
53
- const std::optional<float>& lengthScale
54
- ) {
55
- TtsInitializeResult result;
56
- result.success = false;
57
-
58
- if (pImpl->initialized) {
59
- release();
60
- }
61
-
62
- if (modelDir.empty()) {
63
- LOGE("TTS: Model directory is empty");
64
- return result;
65
- }
66
-
67
- try {
68
- sherpa_onnx::cxx::OfflineTtsConfig config;
69
- config.model.num_threads = numThreads;
70
- config.model.debug = debug;
71
-
72
- auto detect = DetectTtsModel(modelDir, modelType);
73
- if (!detect.ok) {
74
- LOGE("%s", detect.error.c_str());
75
- return result;
76
- }
77
-
78
- switch (detect.selectedKind) {
79
- case TtsModelKind::kVits:
80
- config.model.vits.model = detect.paths.ttsModel;
81
- config.model.vits.tokens = detect.paths.tokens;
82
- config.model.vits.data_dir = detect.paths.dataDir;
83
- if (noiseScale.has_value()) {
84
- config.model.vits.noise_scale = *noiseScale;
85
- }
86
- if (noiseScaleW.has_value()) {
87
- config.model.vits.noise_scale_w = *noiseScaleW;
88
- }
89
- if (lengthScale.has_value()) {
90
- config.model.vits.length_scale = *lengthScale;
91
- }
92
- break;
93
- case TtsModelKind::kMatcha:
94
- config.model.matcha.acoustic_model = detect.paths.acousticModel;
95
- config.model.matcha.vocoder = detect.paths.vocoder;
96
- config.model.matcha.tokens = detect.paths.tokens;
97
- config.model.matcha.data_dir = detect.paths.dataDir;
98
- if (noiseScale.has_value()) {
99
- config.model.matcha.noise_scale = *noiseScale;
100
- }
101
- if (lengthScale.has_value()) {
102
- config.model.matcha.length_scale = *lengthScale;
103
- }
104
- break;
105
- case TtsModelKind::kKokoro:
106
- config.model.kokoro.model = detect.paths.ttsModel;
107
- config.model.kokoro.tokens = detect.paths.tokens;
108
- config.model.kokoro.data_dir = detect.paths.dataDir;
109
- config.model.kokoro.voices = detect.paths.voices;
110
- if (!detect.paths.lexicon.empty()) {
111
- config.model.kokoro.lexicon = detect.paths.lexicon;
112
- }
113
- if (lengthScale.has_value()) {
114
- config.model.kokoro.length_scale = *lengthScale;
115
- }
116
- break;
117
- case TtsModelKind::kKitten:
118
- config.model.kitten.model = detect.paths.ttsModel;
119
- config.model.kitten.tokens = detect.paths.tokens;
120
- config.model.kitten.data_dir = detect.paths.dataDir;
121
- config.model.kitten.voices = detect.paths.voices;
122
- if (lengthScale.has_value()) {
123
- config.model.kitten.length_scale = *lengthScale;
124
- }
125
- break;
126
- case TtsModelKind::kZipvoice:
127
- config.model.zipvoice.encoder = detect.paths.encoder;
128
- config.model.zipvoice.decoder = detect.paths.decoder;
129
- config.model.zipvoice.vocoder = detect.paths.vocoder;
130
- config.model.zipvoice.tokens = detect.paths.tokens;
131
- config.model.zipvoice.data_dir = detect.paths.dataDir;
132
- break;
133
- case TtsModelKind::kUnknown:
134
- default:
135
- LOGE("TTS: Unknown model type: %s", modelType.c_str());
136
- return result;
137
- }
138
-
139
- LOGI("TTS: Creating OfflineTts instance...");
140
- pImpl->tts = sherpa_onnx::cxx::OfflineTts::Create(config);
141
-
142
- if (!pImpl->tts.has_value()) {
143
- LOGE("TTS: Failed to create OfflineTts instance");
144
- return result;
145
- }
146
-
147
- pImpl->initialized = true;
148
- pImpl->modelDir = modelDir;
149
-
150
- LOGI("TTS: Initialization successful");
151
- LOGI("TTS: Sample rate: %d Hz", pImpl->tts.value().SampleRate());
152
- LOGI("TTS: Number of speakers: %d", pImpl->tts.value().NumSpeakers());
153
-
154
- result.success = true;
155
- result.detectedModels = detect.detectedModels;
156
- return result;
157
- } catch (const std::exception& e) {
158
- LOGE("TTS: Exception during initialization: %s", e.what());
159
- return result;
160
- } catch (...) {
161
- LOGE("TTS: Unknown exception during initialization");
162
- return result;
163
- }
164
- }
165
-
166
- TtsWrapper::AudioResult TtsWrapper::generate(
167
- const std::string& text,
168
- int32_t sid,
169
- float speed
170
- ) {
171
- AudioResult result;
172
- result.sampleRate = 0;
173
-
174
- if (!pImpl->initialized || !pImpl->tts.has_value()) {
175
- LOGE("TTS: Not initialized. Call initialize() first.");
176
- return result;
177
- }
178
-
179
- if (text.empty()) {
180
- LOGE("TTS: Input text is empty");
181
- return result;
182
- }
183
-
184
- try {
185
- LOGI("TTS: Generating speech for text: %s (sid=%d, speed=%.2f)",
186
- text.c_str(), sid, speed);
187
-
188
- auto audio = pImpl->tts.value().Generate(text, sid, speed);
189
-
190
- result.samples = std::move(audio.samples);
191
- result.sampleRate = audio.sample_rate;
192
-
193
- LOGI("TTS: Generated %zu samples at %d Hz",
194
- result.samples.size(), result.sampleRate);
195
-
196
- return result;
197
- } catch (const std::exception& e) {
198
- LOGE("TTS: Exception during generation: %s", e.what());
199
- return result;
200
- } catch (...) {
201
- LOGE("TTS: Unknown exception during generation");
202
- return result;
203
- }
204
- }
205
-
206
- bool TtsWrapper::generateStream(
207
- const std::string& text,
208
- int32_t sid,
209
- float speed,
210
- const TtsStreamCallback& callback
211
- ) {
212
- if (!pImpl->initialized || !pImpl->tts.has_value()) {
213
- LOGE("TTS: Not initialized. Call initialize() first.");
214
- return false;
215
- }
216
-
217
- if (text.empty()) {
218
- LOGE("TTS: Input text is empty");
219
- return false;
220
- }
221
-
222
- try {
223
- LOGI("TTS: Streaming generation for text: %s (sid=%d, speed=%.2f)",
224
- text.c_str(), sid, speed);
225
-
226
- auto callbackCopy = callback;
227
- auto shim = [](const float *samples, int32_t numSamples, float progress, void *arg) -> int32_t {
228
- auto *cb = reinterpret_cast<TtsStreamCallback*>(arg);
229
- if (!cb || !(*cb)) return 0;
230
- return (*cb)(samples, numSamples, progress);
231
- };
232
-
233
- pImpl->tts.value().Generate(
234
- text,
235
- sid,
236
- speed,
237
- callbackCopy ? shim : nullptr,
238
- callbackCopy ? &callbackCopy : nullptr
239
- );
240
-
241
- return true;
242
- } catch (const std::exception& e) {
243
- LOGE("TTS: Exception during streaming generation: %s", e.what());
244
- return false;
245
- } catch (...) {
246
- LOGE("TTS: Unknown exception during streaming generation");
247
- return false;
248
- }
249
- }
250
-
251
- int32_t TtsWrapper::getSampleRate() const {
252
- if (!pImpl->initialized || !pImpl->tts.has_value()) {
253
- LOGE("TTS: Not initialized. Call initialize() first.");
254
- return 0;
255
- }
256
- return pImpl->tts.value().SampleRate();
257
- }
258
-
259
- int32_t TtsWrapper::getNumSpeakers() const {
260
- if (!pImpl->initialized || !pImpl->tts.has_value()) {
261
- LOGE("TTS: Not initialized. Call initialize() first.");
262
- return 0;
263
- }
264
- return pImpl->tts.value().NumSpeakers();
265
- }
266
-
267
- bool TtsWrapper::isInitialized() const {
268
- return pImpl->initialized;
269
- }
270
-
271
- void TtsWrapper::release() {
272
- if (pImpl->initialized) {
273
- pImpl->tts.reset();
274
- pImpl->initialized = false;
275
- pImpl->modelDir.clear();
276
- LOGI("TTS: Resources released");
277
- }
278
- }
279
-
280
- bool TtsWrapper::saveToWavFile(
281
- const std::vector<float>& samples,
282
- int32_t sampleRate,
283
- const std::string& filePath
284
- ) {
285
- if (samples.empty()) {
286
- LOGE("TTS: Cannot save empty audio samples");
287
- return false;
288
- }
289
-
290
- if (sampleRate <= 0) {
291
- LOGE("TTS: Invalid sample rate: %d", sampleRate);
292
- return false;
293
- }
294
-
295
- try {
296
- std::ofstream outfile(filePath, std::ios::binary);
297
- if (!outfile) {
298
- LOGE("TTS: Failed to open output file: %s", filePath.c_str());
299
- return false;
300
- }
301
-
302
- const int32_t numChannels = 1;
303
- const int32_t bitsPerSample = 16;
304
- const int32_t byteRate = sampleRate * numChannels * bitsPerSample / 8;
305
- const int32_t blockAlign = numChannels * bitsPerSample / 8;
306
- const int32_t dataSize = static_cast<int32_t>(samples.size()) * bitsPerSample / 8;
307
- const int32_t chunkSize = 36 + dataSize;
308
-
309
- outfile.write("RIFF", 4);
310
- outfile.write(reinterpret_cast<const char*>(&chunkSize), 4);
311
- outfile.write("WAVE", 4);
312
-
313
- outfile.write("fmt ", 4);
314
- const int32_t subchunk1Size = 16;
315
- outfile.write(reinterpret_cast<const char*>(&subchunk1Size), 4);
316
- const int16_t audioFormat = 1;
317
- outfile.write(reinterpret_cast<const char*>(&audioFormat), 2);
318
- const int16_t numChannelsInt16 = static_cast<int16_t>(numChannels);
319
- outfile.write(reinterpret_cast<const char*>(&numChannelsInt16), 2);
320
- outfile.write(reinterpret_cast<const char*>(&sampleRate), 4);
321
- outfile.write(reinterpret_cast<const char*>(&byteRate), 4);
322
- const int16_t blockAlignInt16 = static_cast<int16_t>(blockAlign);
323
- outfile.write(reinterpret_cast<const char*>(&blockAlignInt16), 2);
324
- const int16_t bitsPerSampleInt16 = static_cast<int16_t>(bitsPerSample);
325
- outfile.write(reinterpret_cast<const char*>(&bitsPerSampleInt16), 2);
326
-
327
- outfile.write("data", 4);
328
- outfile.write(reinterpret_cast<const char*>(&dataSize), 4);
329
-
330
- for (float sample : samples) {
331
- float clamped = std::max(-1.0f, std::min(1.0f, sample));
332
- int16_t intSample = static_cast<int16_t>(clamped * 32767.0f);
333
- outfile.write(reinterpret_cast<const char*>(&intSample), sizeof(int16_t));
334
- }
335
-
336
- outfile.close();
337
- LOGI("TTS: Successfully saved %zu samples to %s", samples.size(), filePath.c_str());
338
- return true;
339
- } catch (const std::exception& e) {
340
- LOGE("TTS: Exception while saving WAV file: %s", e.what());
341
- return false;
342
- }
343
- }
344
-
345
- } // namespace sherpaonnx
1
+ /**
2
+ * sherpa-onnx-tts-wrapper.mm
3
+ *
4
+ * Purpose: Wraps the sherpa-onnx C++ OfflineTts for iOS. Builds config from TtsModelPaths, creates
5
+ * TTS instance, generates audio from text. Used by SherpaOnnx+TTS.mm.
6
+ */
7
+
8
+ #include "sherpa-onnx-tts-wrapper.h"
9
+ #include "sherpa-onnx-model-detect.h"
10
+ #include <algorithm>
11
+ #include <cctype>
12
+ #include <cstring>
13
+ #include <fstream>
14
+ #include <optional>
15
+ #include <sstream>
16
+
17
+ // iOS logging
18
+ #ifdef __APPLE__
19
+ #include <Foundation/Foundation.h>
20
+ #include <cstdio>
21
+ #define LOGI(fmt, ...) NSLog(@"TtsWrapper: " fmt, ##__VA_ARGS__)
22
+ #define LOGE(fmt, ...) NSLog(@"TtsWrapper ERROR: " fmt, ##__VA_ARGS__)
23
+ #else
24
+ #define LOGI(...)
25
+ #define LOGE(...)
26
+ #endif
27
+
28
+ // Use C++17 filesystem (podspec enforces C++17)
29
+ #include <filesystem>
30
+ namespace fs = std::filesystem;
31
+
32
+ // sherpa-onnx headers - use C++ API (RAII wrapper around C API)
33
+ #include "sherpa-onnx/c-api/cxx-api.h"
34
+
35
+ namespace sherpaonnx {
36
+
37
+ class TtsWrapper::Impl {
38
+ public:
39
+ bool initialized = false;
40
+ std::string modelDir;
41
+ std::optional<sherpa_onnx::cxx::OfflineTts> tts;
42
+ };
43
+
44
+ TtsWrapper::TtsWrapper() : pImpl(std::make_unique<Impl>()) {
45
+ LOGI("TtsWrapper created");
46
+ }
47
+
48
+ TtsWrapper::~TtsWrapper() {
49
+ release();
50
+ LOGI("TtsWrapper destroyed");
51
+ }
52
+
53
+ TtsInitializeResult TtsWrapper::initialize(
54
+ const std::string& modelDir,
55
+ const std::string& modelType,
56
+ int32_t numThreads,
57
+ bool debug,
58
+ const std::optional<float>& noiseScale,
59
+ const std::optional<float>& noiseScaleW,
60
+ const std::optional<float>& lengthScale,
61
+ const std::optional<std::string>& ruleFsts,
62
+ const std::optional<std::string>& ruleFars,
63
+ const std::optional<int32_t>& maxNumSentences,
64
+ const std::optional<float>& silenceScale,
65
+ const std::optional<std::string>& provider
66
+ ) {
67
+ TtsInitializeResult result;
68
+ result.success = false;
69
+
70
+ if (pImpl->initialized) {
71
+ release();
72
+ }
73
+
74
+ if (modelDir.empty()) {
75
+ LOGE("TTS: Model directory is empty");
76
+ return result;
77
+ }
78
+
79
+ try {
80
+ sherpa_onnx::cxx::OfflineTtsConfig config;
81
+ config.model.num_threads = numThreads;
82
+ config.model.debug = debug;
83
+ if (provider.has_value() && !provider->empty()) {
84
+ config.model.provider = *provider;
85
+ }
86
+
87
+ auto detect = DetectTtsModel(modelDir, modelType);
88
+ if (!detect.ok) {
89
+ LOGE("%s", detect.error.c_str());
90
+ return result;
91
+ }
92
+
93
+ switch (detect.selectedKind) {
94
+ case TtsModelKind::kVits:
95
+ config.model.vits.model = detect.paths.ttsModel;
96
+ config.model.vits.tokens = detect.paths.tokens;
97
+ config.model.vits.data_dir = detect.paths.dataDir;
98
+ if (noiseScale.has_value()) {
99
+ config.model.vits.noise_scale = *noiseScale;
100
+ }
101
+ if (noiseScaleW.has_value()) {
102
+ config.model.vits.noise_scale_w = *noiseScaleW;
103
+ }
104
+ if (lengthScale.has_value()) {
105
+ config.model.vits.length_scale = *lengthScale;
106
+ }
107
+ break;
108
+ case TtsModelKind::kMatcha:
109
+ config.model.matcha.acoustic_model = detect.paths.acousticModel;
110
+ config.model.matcha.vocoder = detect.paths.vocoder;
111
+ config.model.matcha.tokens = detect.paths.tokens;
112
+ config.model.matcha.data_dir = detect.paths.dataDir;
113
+ if (noiseScale.has_value()) {
114
+ config.model.matcha.noise_scale = *noiseScale;
115
+ }
116
+ if (lengthScale.has_value()) {
117
+ config.model.matcha.length_scale = *lengthScale;
118
+ }
119
+ break;
120
+ case TtsModelKind::kKokoro:
121
+ config.model.kokoro.model = detect.paths.ttsModel;
122
+ config.model.kokoro.tokens = detect.paths.tokens;
123
+ config.model.kokoro.data_dir = detect.paths.dataDir;
124
+ config.model.kokoro.voices = detect.paths.voices;
125
+ if (!detect.paths.lexicon.empty()) {
126
+ config.model.kokoro.lexicon = detect.paths.lexicon;
127
+ }
128
+ if (lengthScale.has_value()) {
129
+ config.model.kokoro.length_scale = *lengthScale;
130
+ }
131
+ break;
132
+ case TtsModelKind::kKitten:
133
+ config.model.kitten.model = detect.paths.ttsModel;
134
+ config.model.kitten.tokens = detect.paths.tokens;
135
+ config.model.kitten.data_dir = detect.paths.dataDir;
136
+ config.model.kitten.voices = detect.paths.voices;
137
+ if (lengthScale.has_value()) {
138
+ config.model.kitten.length_scale = *lengthScale;
139
+ }
140
+ break;
141
+ case TtsModelKind::kZipvoice:
142
+ config.model.zipvoice.encoder = detect.paths.encoder;
143
+ config.model.zipvoice.decoder = detect.paths.decoder;
144
+ config.model.zipvoice.vocoder = detect.paths.vocoder;
145
+ config.model.zipvoice.tokens = detect.paths.tokens;
146
+ config.model.zipvoice.data_dir = detect.paths.dataDir;
147
+ break;
148
+ case TtsModelKind::kPocket:
149
+ LOGE("TTS: Pocket model type is detected but not yet supported on iOS");
150
+ return result;
151
+ case TtsModelKind::kUnknown:
152
+ default:
153
+ LOGE("TTS: Unknown model type: %s", modelType.c_str());
154
+ return result;
155
+ }
156
+
157
+ if (ruleFsts.has_value() && !ruleFsts->empty()) {
158
+ config.rule_fsts = *ruleFsts;
159
+ }
160
+ if (ruleFars.has_value() && !ruleFars->empty()) {
161
+ config.rule_fars = *ruleFars;
162
+ }
163
+ if (maxNumSentences.has_value() && *maxNumSentences >= 1) {
164
+ config.max_num_sentences = *maxNumSentences;
165
+ }
166
+ if (silenceScale.has_value()) {
167
+ config.silence_scale = *silenceScale;
168
+ }
169
+
170
+ LOGI("TTS: Creating OfflineTts instance...");
171
+ pImpl->tts = sherpa_onnx::cxx::OfflineTts::Create(config);
172
+
173
+ if (!pImpl->tts.has_value()) {
174
+ LOGE("TTS: Failed to create OfflineTts instance");
175
+ return result;
176
+ }
177
+
178
+ pImpl->initialized = true;
179
+ pImpl->modelDir = modelDir;
180
+
181
+ LOGI("TTS: Initialization successful");
182
+ LOGI("TTS: Sample rate: %d Hz", pImpl->tts.value().SampleRate());
183
+ LOGI("TTS: Number of speakers: %d", pImpl->tts.value().NumSpeakers());
184
+
185
+ result.success = true;
186
+ result.detectedModels = detect.detectedModels;
187
+ return result;
188
+ } catch (const std::exception& e) {
189
+ LOGE("TTS: Exception during initialization: %s", e.what());
190
+ return result;
191
+ } catch (...) {
192
+ LOGE("TTS: Unknown exception during initialization");
193
+ return result;
194
+ }
195
+ }
196
+
197
+ TtsWrapper::AudioResult TtsWrapper::generate(
198
+ const std::string& text,
199
+ int32_t sid,
200
+ float speed
201
+ ) {
202
+ AudioResult result;
203
+ result.sampleRate = 0;
204
+
205
+ if (!pImpl->initialized || !pImpl->tts.has_value()) {
206
+ LOGE("TTS: Not initialized. Call initialize() first.");
207
+ return result;
208
+ }
209
+
210
+ if (text.empty()) {
211
+ LOGE("TTS: Input text is empty");
212
+ return result;
213
+ }
214
+
215
+ try {
216
+ LOGI("TTS: Generating speech for text: %s (sid=%d, speed=%.2f)",
217
+ text.c_str(), sid, speed);
218
+
219
+ auto audio = pImpl->tts.value().Generate(text, sid, speed);
220
+
221
+ result.samples = std::move(audio.samples);
222
+ result.sampleRate = audio.sample_rate;
223
+
224
+ LOGI("TTS: Generated %zu samples at %d Hz",
225
+ result.samples.size(), result.sampleRate);
226
+
227
+ return result;
228
+ } catch (const std::exception& e) {
229
+ LOGE("TTS: Exception during generation: %s", e.what());
230
+ return result;
231
+ } catch (...) {
232
+ LOGE("TTS: Unknown exception during generation");
233
+ return result;
234
+ }
235
+ }
236
+
237
+ bool TtsWrapper::generateStream(
238
+ const std::string& text,
239
+ int32_t sid,
240
+ float speed,
241
+ const TtsStreamCallback& callback
242
+ ) {
243
+ if (!pImpl->initialized || !pImpl->tts.has_value()) {
244
+ LOGE("TTS: Not initialized. Call initialize() first.");
245
+ return false;
246
+ }
247
+
248
+ if (text.empty()) {
249
+ LOGE("TTS: Input text is empty");
250
+ return false;
251
+ }
252
+
253
+ try {
254
+ LOGI("TTS: Streaming generation for text: %s (sid=%d, speed=%.2f)",
255
+ text.c_str(), sid, speed);
256
+
257
+ auto callbackCopy = callback;
258
+ auto shim = [](const float *samples, int32_t numSamples, float progress, void *arg) -> int32_t {
259
+ auto *cb = reinterpret_cast<TtsStreamCallback*>(arg);
260
+ if (!cb || !(*cb)) return 0;
261
+ return (*cb)(samples, numSamples, progress);
262
+ };
263
+
264
+ pImpl->tts.value().Generate(
265
+ text,
266
+ sid,
267
+ speed,
268
+ callbackCopy ? shim : nullptr,
269
+ callbackCopy ? &callbackCopy : nullptr
270
+ );
271
+
272
+ return true;
273
+ } catch (const std::exception& e) {
274
+ LOGE("TTS: Exception during streaming generation: %s", e.what());
275
+ return false;
276
+ } catch (...) {
277
+ LOGE("TTS: Unknown exception during streaming generation");
278
+ return false;
279
+ }
280
+ }
281
+
282
+ int32_t TtsWrapper::getSampleRate() const {
283
+ if (!pImpl->initialized || !pImpl->tts.has_value()) {
284
+ LOGE("TTS: Not initialized. Call initialize() first.");
285
+ return 0;
286
+ }
287
+ return pImpl->tts.value().SampleRate();
288
+ }
289
+
290
+ int32_t TtsWrapper::getNumSpeakers() const {
291
+ if (!pImpl->initialized || !pImpl->tts.has_value()) {
292
+ LOGE("TTS: Not initialized. Call initialize() first.");
293
+ return 0;
294
+ }
295
+ return pImpl->tts.value().NumSpeakers();
296
+ }
297
+
298
+ bool TtsWrapper::isInitialized() const {
299
+ return pImpl->initialized;
300
+ }
301
+
302
+ void TtsWrapper::release() {
303
+ if (pImpl->initialized) {
304
+ pImpl->tts.reset();
305
+ pImpl->initialized = false;
306
+ pImpl->modelDir.clear();
307
+ LOGI("TTS: Resources released");
308
+ }
309
+ }
310
+
311
+ bool TtsWrapper::saveToWavFile(
312
+ const std::vector<float>& samples,
313
+ int32_t sampleRate,
314
+ const std::string& filePath
315
+ ) {
316
+ if (samples.empty()) {
317
+ LOGE("TTS: Cannot save empty audio samples");
318
+ return false;
319
+ }
320
+
321
+ if (sampleRate <= 0) {
322
+ LOGE("TTS: Invalid sample rate: %d", sampleRate);
323
+ return false;
324
+ }
325
+
326
+ try {
327
+ std::ofstream outfile(filePath, std::ios::binary);
328
+ if (!outfile) {
329
+ LOGE("TTS: Failed to open output file: %s", filePath.c_str());
330
+ return false;
331
+ }
332
+
333
+ const int32_t numChannels = 1;
334
+ const int32_t bitsPerSample = 16;
335
+ const int32_t byteRate = sampleRate * numChannels * bitsPerSample / 8;
336
+ const int32_t blockAlign = numChannels * bitsPerSample / 8;
337
+ const int32_t dataSize = static_cast<int32_t>(samples.size()) * bitsPerSample / 8;
338
+ const int32_t chunkSize = 36 + dataSize;
339
+
340
+ outfile.write("RIFF", 4);
341
+ outfile.write(reinterpret_cast<const char*>(&chunkSize), 4);
342
+ outfile.write("WAVE", 4);
343
+
344
+ outfile.write("fmt ", 4);
345
+ const int32_t subchunk1Size = 16;
346
+ outfile.write(reinterpret_cast<const char*>(&subchunk1Size), 4);
347
+ const int16_t audioFormat = 1;
348
+ outfile.write(reinterpret_cast<const char*>(&audioFormat), 2);
349
+ const int16_t numChannelsInt16 = static_cast<int16_t>(numChannels);
350
+ outfile.write(reinterpret_cast<const char*>(&numChannelsInt16), 2);
351
+ outfile.write(reinterpret_cast<const char*>(&sampleRate), 4);
352
+ outfile.write(reinterpret_cast<const char*>(&byteRate), 4);
353
+ const int16_t blockAlignInt16 = static_cast<int16_t>(blockAlign);
354
+ outfile.write(reinterpret_cast<const char*>(&blockAlignInt16), 2);
355
+ const int16_t bitsPerSampleInt16 = static_cast<int16_t>(bitsPerSample);
356
+ outfile.write(reinterpret_cast<const char*>(&bitsPerSampleInt16), 2);
357
+
358
+ outfile.write("data", 4);
359
+ outfile.write(reinterpret_cast<const char*>(&dataSize), 4);
360
+
361
+ for (float sample : samples) {
362
+ float clamped = std::max(-1.0f, std::min(1.0f, sample));
363
+ int16_t intSample = static_cast<int16_t>(clamped * 32767.0f);
364
+ outfile.write(reinterpret_cast<const char*>(&intSample), sizeof(int16_t));
365
+ }
366
+
367
+ outfile.close();
368
+ LOGI("TTS: Successfully saved %zu samples to %s", samples.size(), filePath.c_str());
369
+ return true;
370
+ } catch (const std::exception& e) {
371
+ LOGE("TTS: Exception while saving WAV file: %s", e.what());
372
+ return false;
373
+ }
374
+ }
375
+
376
+ } // namespace sherpaonnx