react-native-sherpa-onnx 0.2.0 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. package/README.md +232 -236
  2. package/SherpaOnnx.podspec +68 -64
  3. package/android/build.gradle +182 -192
  4. package/android/codegen.gradle +57 -0
  5. package/android/prebuilt-download.gradle +428 -0
  6. package/android/prebuilt-versions.gradle +43 -0
  7. package/android/proguard-rules.pro +10 -0
  8. package/android/src/main/assets/testModels/add_mul_add.onnx +28 -0
  9. package/android/src/main/assets/testModels/nnapi_internal_uint8_support.onnx +0 -0
  10. package/android/src/main/assets/testModels/qnn_multi_ctx_embed.onnx +0 -0
  11. package/android/src/main/cpp/CMakeLists.txt +166 -129
  12. package/android/src/main/cpp/CMakePresets.json +54 -0
  13. package/android/src/main/cpp/crypto/sha256.cpp +174 -0
  14. package/android/src/main/cpp/crypto/sha256.h +16 -0
  15. package/android/src/main/cpp/jni/archive/sherpa-onnx-archive-helper.cpp +404 -0
  16. package/android/src/main/cpp/jni/archive/sherpa-onnx-archive-helper.h +56 -0
  17. package/android/src/main/cpp/jni/archive/sherpa-onnx-archive-jni.cpp +181 -0
  18. package/android/src/main/cpp/jni/audio/sherpa-onnx-audio-convert-jni.cpp +888 -0
  19. package/{ios → android/src/main/cpp/jni/model_detect}/sherpa-onnx-common.h +18 -18
  20. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-detect-jni-common.cpp +86 -0
  21. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-detect-jni-common.h +20 -0
  22. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-helper.cpp +423 -0
  23. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-helper.h +55 -0
  24. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-stt.cpp +399 -0
  25. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-tts.cpp +238 -0
  26. package/{ios → android/src/main/cpp/jni/model_detect}/sherpa-onnx-model-detect.h +122 -89
  27. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-stt-wrapper.cpp +99 -0
  28. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-stt-wrapper.h +16 -0
  29. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-tts-wrapper.cpp +78 -0
  30. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-tts-wrapper.h +16 -0
  31. package/android/src/main/cpp/jni/module/sherpa-onnx-module-jni.cpp +190 -0
  32. package/android/src/main/cpp/jni/tts/sherpa-onnx-tts-zipvoice-jni.cpp +301 -0
  33. package/android/src/main/java/com/sherpaonnx/SherpaOnnxArchiveHelper.kt +94 -0
  34. package/android/src/main/java/com/sherpaonnx/{SherpaOnnxCoreHelper.kt → SherpaOnnxAssetHelper.kt} +350 -236
  35. package/android/src/main/java/com/sherpaonnx/SherpaOnnxModule.kt +791 -483
  36. package/android/src/main/java/com/sherpaonnx/SherpaOnnxSttHelper.kt +699 -109
  37. package/android/src/main/java/com/sherpaonnx/SherpaOnnxTtsHelper.kt +1123 -668
  38. package/android/src/main/java/com/sherpaonnx/ZipvoiceTtsWrapper.kt +187 -0
  39. package/ios/SherpaOnnx+Assets.h +11 -0
  40. package/ios/SherpaOnnx+Assets.mm +325 -0
  41. package/ios/SherpaOnnx+STT.mm +455 -118
  42. package/ios/SherpaOnnx+TTS.mm +1101 -712
  43. package/ios/SherpaOnnx.h +17 -6
  44. package/ios/SherpaOnnx.mm +206 -311
  45. package/ios/SherpaOnnx.xcconfig +19 -19
  46. package/ios/SherpaOnnxCoreMLHelper.swift +24 -0
  47. package/ios/archive/sherpa-onnx-archive-helper.h +21 -0
  48. package/ios/archive/sherpa-onnx-archive-helper.mm +296 -0
  49. package/ios/libarchive_darwin_config.h +153 -0
  50. package/{android/src/main/cpp/jni → ios/model_detect}/sherpa-onnx-common.h +18 -18
  51. package/ios/model_detect/sherpa-onnx-model-detect-helper.h +49 -0
  52. package/ios/model_detect/sherpa-onnx-model-detect-helper.mm +210 -0
  53. package/ios/model_detect/sherpa-onnx-model-detect-stt.mm +344 -0
  54. package/ios/model_detect/sherpa-onnx-model-detect-tts.mm +201 -0
  55. package/{android/src/main/cpp/jni → ios/model_detect}/sherpa-onnx-model-detect.h +117 -89
  56. package/ios/scripts/patch-libarchive-includes.sh +61 -0
  57. package/ios/scripts/setup-ios-libarchive.sh +98 -0
  58. package/ios/stt/sherpa-onnx-stt-wrapper.h +129 -0
  59. package/ios/stt/sherpa-onnx-stt-wrapper.mm +523 -0
  60. package/ios/{sherpa-onnx-tts-wrapper.h → tts/sherpa-onnx-tts-wrapper.h} +90 -85
  61. package/ios/{sherpa-onnx-tts-wrapper.mm → tts/sherpa-onnx-tts-wrapper.mm} +376 -345
  62. package/lib/module/NativeSherpaOnnx.js +3 -0
  63. package/lib/module/NativeSherpaOnnx.js.map +1 -1
  64. package/lib/module/audio/index.js +22 -0
  65. package/lib/module/audio/index.js.map +1 -0
  66. package/lib/module/diarization/index.js +1 -1
  67. package/lib/module/diarization/index.js.map +1 -1
  68. package/lib/module/download/ModelDownloadManager.js +918 -0
  69. package/lib/module/download/ModelDownloadManager.js.map +1 -0
  70. package/lib/module/download/extractTarBz2.js +53 -0
  71. package/lib/module/download/extractTarBz2.js.map +1 -0
  72. package/lib/module/download/index.js +6 -0
  73. package/lib/module/download/index.js.map +1 -0
  74. package/lib/module/download/validation.js +178 -0
  75. package/lib/module/download/validation.js.map +1 -0
  76. package/lib/module/enhancement/index.js +1 -1
  77. package/lib/module/enhancement/index.js.map +1 -1
  78. package/lib/module/index.js +41 -3
  79. package/lib/module/index.js.map +1 -1
  80. package/lib/module/separation/index.js +1 -1
  81. package/lib/module/separation/index.js.map +1 -1
  82. package/lib/module/stt/index.js +127 -60
  83. package/lib/module/stt/index.js.map +1 -1
  84. package/lib/module/stt/sttModelLanguages.js +512 -0
  85. package/lib/module/stt/sttModelLanguages.js.map +1 -0
  86. package/lib/module/stt/types.js +53 -1
  87. package/lib/module/stt/types.js.map +1 -1
  88. package/lib/module/tts/index.js +216 -289
  89. package/lib/module/tts/index.js.map +1 -1
  90. package/lib/module/tts/types.js +86 -1
  91. package/lib/module/tts/types.js.map +1 -1
  92. package/lib/module/types.js.map +1 -1
  93. package/lib/module/utils.js +86 -73
  94. package/lib/module/utils.js.map +1 -1
  95. package/lib/module/vad/index.js +1 -1
  96. package/lib/module/vad/index.js.map +1 -1
  97. package/lib/typescript/src/NativeSherpaOnnx.d.ts +192 -38
  98. package/lib/typescript/src/NativeSherpaOnnx.d.ts.map +1 -1
  99. package/lib/typescript/src/audio/index.d.ts +13 -0
  100. package/lib/typescript/src/audio/index.d.ts.map +1 -0
  101. package/lib/typescript/src/diarization/index.d.ts +3 -2
  102. package/lib/typescript/src/diarization/index.d.ts.map +1 -1
  103. package/lib/typescript/src/download/ModelDownloadManager.d.ts +108 -0
  104. package/lib/typescript/src/download/ModelDownloadManager.d.ts.map +1 -0
  105. package/lib/typescript/src/download/extractTarBz2.d.ts +14 -0
  106. package/lib/typescript/src/download/extractTarBz2.d.ts.map +1 -0
  107. package/lib/typescript/src/download/index.d.ts +7 -0
  108. package/lib/typescript/src/download/index.d.ts.map +1 -0
  109. package/lib/typescript/src/download/validation.d.ts +57 -0
  110. package/lib/typescript/src/download/validation.d.ts.map +1 -0
  111. package/lib/typescript/src/enhancement/index.d.ts +3 -2
  112. package/lib/typescript/src/enhancement/index.d.ts.map +1 -1
  113. package/lib/typescript/src/index.d.ts +26 -2
  114. package/lib/typescript/src/index.d.ts.map +1 -1
  115. package/lib/typescript/src/separation/index.d.ts +3 -2
  116. package/lib/typescript/src/separation/index.d.ts.map +1 -1
  117. package/lib/typescript/src/stt/index.d.ts +31 -43
  118. package/lib/typescript/src/stt/index.d.ts.map +1 -1
  119. package/lib/typescript/src/stt/sttModelLanguages.d.ts +52 -0
  120. package/lib/typescript/src/stt/sttModelLanguages.d.ts.map +1 -0
  121. package/lib/typescript/src/stt/types.d.ts +196 -9
  122. package/lib/typescript/src/stt/types.d.ts.map +1 -1
  123. package/lib/typescript/src/tts/index.d.ts +25 -211
  124. package/lib/typescript/src/tts/index.d.ts.map +1 -1
  125. package/lib/typescript/src/tts/types.d.ts +148 -25
  126. package/lib/typescript/src/tts/types.d.ts.map +1 -1
  127. package/lib/typescript/src/types.d.ts +0 -32
  128. package/lib/typescript/src/types.d.ts.map +1 -1
  129. package/lib/typescript/src/utils.d.ts +28 -13
  130. package/lib/typescript/src/utils.d.ts.map +1 -1
  131. package/lib/typescript/src/vad/index.d.ts +3 -2
  132. package/lib/typescript/src/vad/index.d.ts.map +1 -1
  133. package/package.json +250 -222
  134. package/scripts/check-qnn-support.sh +78 -0
  135. package/scripts/setup-ios-framework.sh +379 -282
  136. package/src/NativeSherpaOnnx.ts +474 -251
  137. package/src/audio/index.ts +32 -0
  138. package/src/diarization/index.ts +4 -2
  139. package/src/download/ModelDownloadManager.ts +1325 -0
  140. package/src/download/extractTarBz2.ts +78 -0
  141. package/src/download/index.ts +43 -0
  142. package/src/download/validation.ts +279 -0
  143. package/src/enhancement/index.ts +4 -2
  144. package/src/index.tsx +78 -27
  145. package/src/separation/index.ts +4 -2
  146. package/src/stt/index.ts +249 -89
  147. package/src/stt/sttModelLanguages.ts +237 -0
  148. package/src/stt/types.ts +263 -9
  149. package/src/tts/index.ts +470 -458
  150. package/src/tts/types.ts +373 -218
  151. package/src/types.ts +0 -44
  152. package/src/utils.ts +145 -131
  153. package/src/vad/index.ts +4 -2
  154. package/third_party/ffmpeg_prebuilt/ANDROID_RELEASE_TAG +1 -0
  155. package/third_party/libarchive_prebuilt/ANDROID_RELEASE_TAG +1 -0
  156. package/third_party/libarchive_prebuilt/IOS_RELEASE_TAG +1 -0
  157. package/third_party/sherpa-onnx-prebuilt/ANDROID_RELEASE_TAG +1 -0
  158. package/third_party/sherpa-onnx-prebuilt/IOS_RELEASE_TAG +1 -0
  159. package/android/src/main/cpp/include/sherpa-onnx/c-api/c-api.h +0 -1918
  160. package/android/src/main/cpp/include/sherpa-onnx/c-api/cxx-api.h +0 -841
  161. package/android/src/main/cpp/jni/sherpa-onnx-model-detect.cpp +0 -541
  162. package/android/src/main/cpp/jni/sherpa-onnx-stt-jni.cpp +0 -336
  163. package/android/src/main/cpp/jni/sherpa-onnx-stt-wrapper.cpp +0 -222
  164. package/android/src/main/cpp/jni/sherpa-onnx-stt-wrapper.h +0 -68
  165. package/android/src/main/cpp/jni/sherpa-onnx-tts-jni.cpp +0 -823
  166. package/android/src/main/cpp/jni/sherpa-onnx-tts-wrapper.cpp +0 -387
  167. package/android/src/main/cpp/jni/sherpa-onnx-tts-wrapper.h +0 -147
  168. package/ios/Frameworks/sherpa_onnx.xcframework.zip +0 -0
  169. package/ios/include/sherpa-onnx/c-api/c-api.h +0 -1918
  170. package/ios/include/sherpa-onnx/c-api/cxx-api.h +0 -841
  171. package/ios/sherpa-onnx-model-detect.mm +0 -441
  172. package/ios/sherpa-onnx-stt-wrapper.h +0 -48
  173. package/ios/sherpa-onnx-stt-wrapper.mm +0 -201
  174. package/scripts/copy-headers.js +0 -184
  175. package/scripts/setup-assets.js +0 -323
@@ -1,387 +0,0 @@
1
- #include "sherpa-onnx-tts-wrapper.h"
2
- #include "sherpa-onnx-model-detect.h"
3
- #include <android/log.h>
4
- #include <algorithm>
5
- #include <cctype>
6
- #include <fstream>
7
- #include <mutex>
8
- #include <optional>
9
- #include <unordered_map>
10
- #include <sstream>
11
- #include <sys/stat.h>
12
-
13
- // Use filesystem if available (C++17), otherwise fallback
14
- #if __cplusplus >= 201703L && __has_include(<filesystem>)
15
- #include <filesystem>
16
- namespace fs = std::filesystem;
17
- #elif __has_include(<experimental/filesystem>)
18
- #include <experimental/filesystem>
19
- namespace fs = std::experimental::filesystem;
20
- #else
21
- // Fallback: use stat/opendir for older compilers
22
- #include <dirent.h>
23
- #include <sys/stat.h>
24
- #endif
25
-
26
- // sherpa-onnx headers - use cxx-api which is compatible with libsherpa-onnx-cxx-api.so
27
- #include "sherpa-onnx/c-api/cxx-api.h"
28
-
29
// Android logcat helpers for this translation unit. Both levels share one
// helper so the tag is applied in a single place.
#define LOG_TAG "TtsWrapper"
#define TTS_LOG_AT(priority, ...) __android_log_print(priority, LOG_TAG, __VA_ARGS__)
#define LOGI(...) TTS_LOG_AT(ANDROID_LOG_INFO, __VA_ARGS__)
#define LOGE(...) TTS_LOG_AT(ANDROID_LOG_ERROR, __VA_ARGS__)
32
-
33
- namespace sherpaonnx {
34
-
35
- class TtsWrapper::Impl {
36
- public:
37
- bool initialized = false;
38
- std::string modelDir;
39
- std::optional<sherpa_onnx::cxx::OfflineTts> tts;
40
- // Hold active stream callbacks to ensure they remain alive while native code may call them
41
- std::unordered_map<uint64_t, std::shared_ptr<TtsStreamCallback>> activeStreamCallbacks;
42
- std::mutex streamMutex;
43
- };
44
-
45
- TtsWrapper::TtsWrapper() : pImpl(std::make_unique<Impl>()) {
46
- LOGI("TtsWrapper created");
47
- }
48
-
49
- TtsWrapper::~TtsWrapper() {
50
- release();
51
- LOGI("TtsWrapper destroyed");
52
- }
53
-
54
- TtsInitializeResult TtsWrapper::initialize(
55
- const std::string& modelDir,
56
- const std::string& modelType,
57
- int32_t numThreads,
58
- bool debug,
59
- std::optional<float> noiseScale,
60
- std::optional<float> noiseScaleW,
61
- std::optional<float> lengthScale
62
- ) {
63
- TtsInitializeResult result;
64
- result.success = false;
65
-
66
- if (pImpl->initialized) {
67
- release();
68
- }
69
-
70
- if (modelDir.empty()) {
71
- LOGE("TTS: Model directory is empty");
72
- return result;
73
- }
74
-
75
- try {
76
- sherpa_onnx::cxx::OfflineTtsConfig config;
77
- config.model.num_threads = numThreads;
78
- config.model.debug = debug;
79
-
80
- auto detect = DetectTtsModel(modelDir, modelType);
81
- if (!detect.ok) {
82
- LOGE("%s", detect.error.c_str());
83
- return result;
84
- }
85
-
86
- switch (detect.selectedKind) {
87
- case TtsModelKind::kVits:
88
- config.model.vits.model = detect.paths.ttsModel;
89
- config.model.vits.tokens = detect.paths.tokens;
90
- config.model.vits.data_dir = detect.paths.dataDir;
91
- if (noiseScale.has_value()) {
92
- config.model.vits.noise_scale = noiseScale.value();
93
- }
94
- if (noiseScaleW.has_value()) {
95
- config.model.vits.noise_scale_w = noiseScaleW.value();
96
- }
97
- if (lengthScale.has_value()) {
98
- config.model.vits.length_scale = lengthScale.value();
99
- }
100
- break;
101
- case TtsModelKind::kMatcha:
102
- config.model.matcha.acoustic_model = detect.paths.acousticModel;
103
- config.model.matcha.vocoder = detect.paths.vocoder;
104
- config.model.matcha.tokens = detect.paths.tokens;
105
- config.model.matcha.data_dir = detect.paths.dataDir;
106
- if (noiseScale.has_value()) {
107
- config.model.matcha.noise_scale = noiseScale.value();
108
- }
109
- if (lengthScale.has_value()) {
110
- config.model.matcha.length_scale = lengthScale.value();
111
- }
112
- break;
113
- case TtsModelKind::kKokoro:
114
- config.model.kokoro.model = detect.paths.ttsModel;
115
- config.model.kokoro.tokens = detect.paths.tokens;
116
- config.model.kokoro.data_dir = detect.paths.dataDir;
117
- config.model.kokoro.voices = detect.paths.voices;
118
- if (!detect.paths.lexicon.empty()) {
119
- config.model.kokoro.lexicon = detect.paths.lexicon;
120
- }
121
- if (lengthScale.has_value()) {
122
- config.model.kokoro.length_scale = lengthScale.value();
123
- }
124
- break;
125
- case TtsModelKind::kKitten:
126
- config.model.kitten.model = detect.paths.ttsModel;
127
- config.model.kitten.tokens = detect.paths.tokens;
128
- config.model.kitten.data_dir = detect.paths.dataDir;
129
- config.model.kitten.voices = detect.paths.voices;
130
- if (lengthScale.has_value()) {
131
- config.model.kitten.length_scale = lengthScale.value();
132
- }
133
- break;
134
- case TtsModelKind::kZipvoice:
135
- config.model.zipvoice.encoder = detect.paths.encoder;
136
- config.model.zipvoice.decoder = detect.paths.decoder;
137
- config.model.zipvoice.vocoder = detect.paths.vocoder;
138
- config.model.zipvoice.tokens = detect.paths.tokens;
139
- config.model.zipvoice.data_dir = detect.paths.dataDir;
140
- break;
141
- case TtsModelKind::kUnknown:
142
- default:
143
- LOGE("TTS: Unknown model type: %s", modelType.c_str());
144
- return result;
145
- }
146
-
147
- // Create TTS instance
148
- LOGI("TTS: Creating OfflineTts instance...");
149
- pImpl->tts = sherpa_onnx::cxx::OfflineTts::Create(config);
150
-
151
- if (!pImpl->tts.has_value()) {
152
- LOGE("TTS: Failed to create OfflineTts instance");
153
- return result;
154
- }
155
-
156
- pImpl->initialized = true;
157
- pImpl->modelDir = modelDir;
158
-
159
- LOGI("TTS: Initialization successful");
160
- LOGI("TTS: Sample rate: %d Hz", pImpl->tts.value().SampleRate());
161
- LOGI("TTS: Number of speakers: %d", pImpl->tts.value().NumSpeakers());
162
-
163
- // Success - return detected models
164
- result.success = true;
165
- result.detectedModels = detect.detectedModels;
166
- return result;
167
- } catch (const std::exception& e) {
168
- LOGE("TTS: Exception during initialization: %s", e.what());
169
- return result;
170
- } catch (...) {
171
- LOGE("TTS: Unknown exception during initialization");
172
- return result;
173
- }
174
- }
175
-
176
- TtsWrapper::AudioResult TtsWrapper::generate(
177
- const std::string& text,
178
- int32_t sid,
179
- float speed
180
- ) {
181
- AudioResult result;
182
- result.sampleRate = 0;
183
-
184
- if (!pImpl->initialized || !pImpl->tts.has_value()) {
185
- LOGE("TTS: Not initialized. Call initialize() first.");
186
- return result;
187
- }
188
-
189
- if (text.empty()) {
190
- LOGE("TTS: Input text is empty");
191
- return result;
192
- }
193
-
194
- try {
195
- LOGI("TTS: Generating speech for text: %s (sid=%d, speed=%.2f)",
196
- text.c_str(), sid, speed);
197
-
198
- // Generate audio using cxx-api
199
- auto audio = pImpl->tts.value().Generate(text, sid, speed);
200
-
201
- // Copy samples to result
202
- result.samples = std::move(audio.samples);
203
- result.sampleRate = audio.sample_rate;
204
-
205
- LOGI("TTS: Generated %zu samples at %d Hz",
206
- result.samples.size(), result.sampleRate);
207
-
208
- return result;
209
- } catch (const std::exception& e) {
210
- LOGE("TTS: Exception during generation: %s", e.what());
211
- return result;
212
- } catch (...) {
213
- LOGE("TTS: Unknown exception during generation");
214
- return result;
215
- }
216
- }
217
-
218
- bool TtsWrapper::generateStream(
219
- const std::string& text,
220
- int32_t sid,
221
- float speed,
222
- StreamId streamId,
223
- const TtsStreamCallback& callback
224
- ) {
225
- if (!pImpl->initialized || !pImpl->tts.has_value()) {
226
- LOGE("TTS: Not initialized. Call initialize() first.");
227
- return false;
228
- }
229
-
230
- if (text.empty()) {
231
- LOGE("TTS: Input text is empty");
232
- return false;
233
- }
234
-
235
- try {
236
- LOGI("TTS: Streaming generation for text: %s (sid=%d, speed=%.2f)",
237
- text.c_str(), sid, speed);
238
-
239
- // Keep a shared_ptr to the callback so it remains valid while native code may call it.
240
- std::shared_ptr<TtsStreamCallback> cbPtr = nullptr;
241
- if (callback) {
242
- cbPtr = std::make_shared<TtsStreamCallback>(callback);
243
- std::lock_guard<std::mutex> lock(pImpl->streamMutex);
244
- pImpl->activeStreamCallbacks.emplace(streamId, cbPtr);
245
- }
246
-
247
- auto shim = [](const float *samples, int32_t numSamples, float progress, void *arg) -> int32_t {
248
- auto *cb = reinterpret_cast<TtsStreamCallback*>(arg);
249
- if (!cb || !(*cb)) return 0;
250
- return (*cb)(samples, numSamples, progress);
251
- };
252
-
253
- pImpl->tts.value().Generate(
254
- text,
255
- sid,
256
- speed,
257
- cbPtr ? shim : nullptr,
258
- cbPtr ? cbPtr.get() : nullptr
259
- );
260
-
261
- return true;
262
- } catch (const std::exception& e) {
263
- LOGE("TTS: Exception during streaming generation: %s", e.what());
264
- return false;
265
- } catch (...) {
266
- LOGE("TTS: Unknown exception during streaming generation");
267
- return false;
268
- }
269
- }
270
-
271
- void TtsWrapper::cancelStream(StreamId streamId) {
272
- if (streamId == 0) return;
273
- std::lock_guard<std::mutex> lock(pImpl->streamMutex);
274
- pImpl->activeStreamCallbacks.erase(streamId);
275
- }
276
-
277
- void TtsWrapper::endStream(StreamId streamId) {
278
- if (streamId == 0) return;
279
- std::lock_guard<std::mutex> lock(pImpl->streamMutex);
280
- pImpl->activeStreamCallbacks.erase(streamId);
281
- }
282
-
283
- int32_t TtsWrapper::getSampleRate() const {
284
- if (!pImpl->initialized || !pImpl->tts.has_value()) {
285
- LOGE("TTS: Not initialized. Call initialize() first.");
286
- return 0;
287
- }
288
- return pImpl->tts.value().SampleRate();
289
- }
290
-
291
- int32_t TtsWrapper::getNumSpeakers() const {
292
- if (!pImpl->initialized || !pImpl->tts.has_value()) {
293
- LOGE("TTS: Not initialized. Call initialize() first.");
294
- return 0;
295
- }
296
- return pImpl->tts.value().NumSpeakers();
297
- }
298
-
299
- bool TtsWrapper::isInitialized() const {
300
- return pImpl->initialized;
301
- }
302
-
303
- void TtsWrapper::release() {
304
- if (pImpl->initialized) {
305
- pImpl->tts.reset();
306
- // Clear any stored callbacks to allow them to be freed
307
- {
308
- std::lock_guard<std::mutex> lock(pImpl->streamMutex);
309
- pImpl->activeStreamCallbacks.clear();
310
- }
311
- pImpl->initialized = false;
312
- pImpl->modelDir.clear();
313
- LOGI("TTS: Resources released");
314
- }
315
- }
316
-
317
- bool TtsWrapper::saveToWavFile(
318
- const std::vector<float>& samples,
319
- int32_t sampleRate,
320
- const std::string& filePath
321
- ) {
322
- if (samples.empty()) {
323
- LOGE("TTS: Cannot save empty audio samples");
324
- return false;
325
- }
326
-
327
- if (sampleRate <= 0) {
328
- LOGE("TTS: Invalid sample rate: %d", sampleRate);
329
- return false;
330
- }
331
-
332
- try {
333
- std::ofstream outfile(filePath, std::ios::binary);
334
- if (!outfile) {
335
- LOGE("TTS: Failed to open output file: %s", filePath.c_str());
336
- return false;
337
- }
338
-
339
- // WAV file header
340
- const int32_t numChannels = 1; // Mono
341
- const int32_t bitsPerSample = 16; // 16-bit PCM
342
- const int32_t byteRate = sampleRate * numChannels * bitsPerSample / 8;
343
- const int32_t blockAlign = numChannels * bitsPerSample / 8;
344
- const int32_t dataSize = static_cast<int32_t>(samples.size()) * bitsPerSample / 8;
345
- const int32_t chunkSize = 36 + dataSize;
346
-
347
- // RIFF header
348
- outfile.write("RIFF", 4);
349
- outfile.write(reinterpret_cast<const char*>(&chunkSize), 4);
350
- outfile.write("WAVE", 4);
351
-
352
- // fmt subchunk
353
- outfile.write("fmt ", 4);
354
- const int32_t subchunk1Size = 16; // PCM
355
- outfile.write(reinterpret_cast<const char*>(&subchunk1Size), 4);
356
- const int16_t audioFormat = 1; // PCM
357
- outfile.write(reinterpret_cast<const char*>(&audioFormat), 2);
358
- const int16_t numChannelsInt16 = static_cast<int16_t>(numChannels);
359
- outfile.write(reinterpret_cast<const char*>(&numChannelsInt16), 2);
360
- outfile.write(reinterpret_cast<const char*>(&sampleRate), 4);
361
- outfile.write(reinterpret_cast<const char*>(&byteRate), 4);
362
- const int16_t blockAlignInt16 = static_cast<int16_t>(blockAlign);
363
- outfile.write(reinterpret_cast<const char*>(&blockAlignInt16), 2);
364
- const int16_t bitsPerSampleInt16 = static_cast<int16_t>(bitsPerSample);
365
- outfile.write(reinterpret_cast<const char*>(&bitsPerSampleInt16), 2);
366
-
367
- // data subchunk
368
- outfile.write("data", 4);
369
- outfile.write(reinterpret_cast<const char*>(&dataSize), 4);
370
-
371
- // Convert float samples to int16 PCM and write
372
- for (float sample : samples) {
373
- float clamped = std::max(-1.0f, std::min(1.0f, sample));
374
- int16_t intSample = static_cast<int16_t>(clamped * 32767.0f);
375
- outfile.write(reinterpret_cast<const char*>(&intSample), sizeof(int16_t));
376
- }
377
-
378
- outfile.close();
379
- LOGI("TTS: Successfully saved %zu samples to %s", samples.size(), filePath.c_str());
380
- return true;
381
- } catch (const std::exception& e) {
382
- LOGE("TTS: Exception while saving WAV file: %s", e.what());
383
- return false;
384
- }
385
- }
386
-
387
- } // namespace sherpaonnx
@@ -1,147 +0,0 @@
1
- #ifndef SHERPA_ONNX_TTS_WRAPPER_H
2
- #define SHERPA_ONNX_TTS_WRAPPER_H
3
-
4
- #include "sherpa-onnx-common.h"
5
- #include <cstdint>
6
- #include <functional>
7
- #include <memory>
8
- #include <optional>
9
- #include <string>
10
- #include <vector>
11
-
12
- namespace sherpaonnx {
13
-
14
- /**
15
- * Result of TTS initialization.
16
- */
17
- struct TtsInitializeResult {
18
- bool success;
19
- std::vector<DetectedModel> detectedModels; // List of detected models with type and path
20
- };
21
-
22
- /**
23
- * Wrapper class for sherpa-onnx OfflineTts.
24
- * This provides a C++ interface for Text-to-Speech functionality.
25
- */
26
- class TtsWrapper {
27
- public:
28
- TtsWrapper();
29
- ~TtsWrapper();
30
-
31
- /**
32
- * Initialize TTS with model directory.
33
- * @param modelDir Path to the model directory
34
- * @param modelType Model type ('vits', 'matcha', 'kokoro', 'kitten', 'zipvoice', 'auto')
35
- * @param numThreads Number of threads for inference (default: 2)
36
- * @param debug Enable debug logging (default: false)
37
- * @return TtsInitializeResult with success status and list of detected usable models
38
- */
39
- TtsInitializeResult initialize(
40
- const std::string& modelDir,
41
- const std::string& modelType = "auto",
42
- int32_t numThreads = 2,
43
- bool debug = false,
44
- std::optional<float> noiseScale = std::nullopt,
45
- std::optional<float> noiseScaleW = std::nullopt,
46
- std::optional<float> lengthScale = std::nullopt
47
- );
48
-
49
- /**
50
- * Audio generation result.
51
- */
52
- struct AudioResult {
53
- std::vector<float> samples; // Audio samples in range [-1.0, 1.0]
54
- int32_t sampleRate; // Sample rate in Hz
55
- };
56
-
57
- using StreamId = uint64_t;
58
-
59
- using TtsStreamCallback = std::function<int32_t(
60
- const float *samples,
61
- int32_t numSamples,
62
- float progress
63
- )>;
64
-
65
- /**
66
- * Generate speech from text.
67
- * @param text Text to convert to speech
68
- * @param sid Speaker ID for multi-speaker models (default: 0)
69
- * @param speed Speech speed multiplier (default: 1.0)
70
- * @return AudioResult with samples and sample rate
71
- */
72
- AudioResult generate(
73
- const std::string& text,
74
- int32_t sid = 0,
75
- float speed = 1.0f
76
- );
77
-
78
- /**
79
- * Generate speech with streaming callback.
80
- * @param text Text to convert to speech
81
- * @param sid Speaker ID for multi-speaker models (default: 0)
82
- * @param speed Speech speed multiplier (default: 1.0)
83
- * @param callback Callback invoked with partial audio samples
84
- * @return true if generation started, false on error
85
- */
86
- bool generateStream(
87
- const std::string& text,
88
- int32_t sid,
89
- float speed,
90
- StreamId streamId,
91
- const TtsStreamCallback& callback
92
- );
93
-
94
- /**
95
- * Cancel a streaming TTS callback and release its resources by ID.
96
- */
97
- void cancelStream(StreamId streamId);
98
-
99
- /**
100
- * Mark a stream as completed and release its resources by ID.
101
- */
102
- void endStream(StreamId streamId);
103
-
104
- /**
105
- * Save audio samples to a WAV file.
106
- * @param samples Audio samples vector
107
- * @param sampleRate Sample rate in Hz
108
- * @param filePath Output file path
109
- * @return true if successful, false otherwise
110
- */
111
- static bool saveToWavFile(
112
- const std::vector<float>& samples,
113
- int32_t sampleRate,
114
- const std::string& filePath
115
- );
116
-
117
- /**
118
- * Get the sample rate of the initialized TTS model.
119
- * @return Sample rate in Hz
120
- */
121
- int32_t getSampleRate() const;
122
-
123
- /**
124
- * Get the number of speakers/voices available in the model.
125
- * @return Number of speakers (0 or 1 for single-speaker models)
126
- */
127
- int32_t getNumSpeakers() const;
128
-
129
- /**
130
- * Check if the TTS is initialized.
131
- * @return true if initialized, false otherwise
132
- */
133
- bool isInitialized() const;
134
-
135
- /**
136
- * Release resources.
137
- */
138
- void release();
139
-
140
- private:
141
- class Impl;
142
- std::unique_ptr<Impl> pImpl;
143
- };
144
-
145
- } // namespace sherpaonnx
146
-
147
- #endif // SHERPA_ONNX_TTS_WRAPPER_H