react-native-sherpa-onnx 0.2.0 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175) hide show
  1. package/README.md +232 -236
  2. package/SherpaOnnx.podspec +68 -64
  3. package/android/build.gradle +182 -192
  4. package/android/codegen.gradle +57 -0
  5. package/android/prebuilt-download.gradle +428 -0
  6. package/android/prebuilt-versions.gradle +43 -0
  7. package/android/proguard-rules.pro +10 -0
  8. package/android/src/main/assets/testModels/add_mul_add.onnx +28 -0
  9. package/android/src/main/assets/testModels/nnapi_internal_uint8_support.onnx +0 -0
  10. package/android/src/main/assets/testModels/qnn_multi_ctx_embed.onnx +0 -0
  11. package/android/src/main/cpp/CMakeLists.txt +166 -129
  12. package/android/src/main/cpp/CMakePresets.json +54 -0
  13. package/android/src/main/cpp/crypto/sha256.cpp +174 -0
  14. package/android/src/main/cpp/crypto/sha256.h +16 -0
  15. package/android/src/main/cpp/jni/archive/sherpa-onnx-archive-helper.cpp +404 -0
  16. package/android/src/main/cpp/jni/archive/sherpa-onnx-archive-helper.h +56 -0
  17. package/android/src/main/cpp/jni/archive/sherpa-onnx-archive-jni.cpp +181 -0
  18. package/android/src/main/cpp/jni/audio/sherpa-onnx-audio-convert-jni.cpp +888 -0
  19. package/{ios → android/src/main/cpp/jni/model_detect}/sherpa-onnx-common.h +18 -18
  20. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-detect-jni-common.cpp +86 -0
  21. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-detect-jni-common.h +20 -0
  22. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-helper.cpp +423 -0
  23. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-helper.h +55 -0
  24. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-stt.cpp +399 -0
  25. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-tts.cpp +238 -0
  26. package/{ios → android/src/main/cpp/jni/model_detect}/sherpa-onnx-model-detect.h +122 -89
  27. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-stt-wrapper.cpp +99 -0
  28. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-stt-wrapper.h +16 -0
  29. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-tts-wrapper.cpp +78 -0
  30. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-tts-wrapper.h +16 -0
  31. package/android/src/main/cpp/jni/module/sherpa-onnx-module-jni.cpp +190 -0
  32. package/android/src/main/cpp/jni/tts/sherpa-onnx-tts-zipvoice-jni.cpp +301 -0
  33. package/android/src/main/java/com/sherpaonnx/SherpaOnnxArchiveHelper.kt +94 -0
  34. package/android/src/main/java/com/sherpaonnx/{SherpaOnnxCoreHelper.kt → SherpaOnnxAssetHelper.kt} +350 -236
  35. package/android/src/main/java/com/sherpaonnx/SherpaOnnxModule.kt +791 -483
  36. package/android/src/main/java/com/sherpaonnx/SherpaOnnxSttHelper.kt +699 -109
  37. package/android/src/main/java/com/sherpaonnx/SherpaOnnxTtsHelper.kt +1123 -668
  38. package/android/src/main/java/com/sherpaonnx/ZipvoiceTtsWrapper.kt +187 -0
  39. package/ios/SherpaOnnx+Assets.h +11 -0
  40. package/ios/SherpaOnnx+Assets.mm +325 -0
  41. package/ios/SherpaOnnx+STT.mm +455 -118
  42. package/ios/SherpaOnnx+TTS.mm +1101 -712
  43. package/ios/SherpaOnnx.h +17 -6
  44. package/ios/SherpaOnnx.mm +206 -311
  45. package/ios/SherpaOnnx.xcconfig +19 -19
  46. package/ios/SherpaOnnxCoreMLHelper.swift +24 -0
  47. package/ios/archive/sherpa-onnx-archive-helper.h +21 -0
  48. package/ios/archive/sherpa-onnx-archive-helper.mm +296 -0
  49. package/ios/libarchive_darwin_config.h +153 -0
  50. package/{android/src/main/cpp/jni → ios/model_detect}/sherpa-onnx-common.h +18 -18
  51. package/ios/model_detect/sherpa-onnx-model-detect-helper.h +49 -0
  52. package/ios/model_detect/sherpa-onnx-model-detect-helper.mm +210 -0
  53. package/ios/model_detect/sherpa-onnx-model-detect-stt.mm +344 -0
  54. package/ios/model_detect/sherpa-onnx-model-detect-tts.mm +201 -0
  55. package/{android/src/main/cpp/jni → ios/model_detect}/sherpa-onnx-model-detect.h +117 -89
  56. package/ios/scripts/patch-libarchive-includes.sh +61 -0
  57. package/ios/scripts/setup-ios-libarchive.sh +98 -0
  58. package/ios/stt/sherpa-onnx-stt-wrapper.h +129 -0
  59. package/ios/stt/sherpa-onnx-stt-wrapper.mm +523 -0
  60. package/ios/{sherpa-onnx-tts-wrapper.h → tts/sherpa-onnx-tts-wrapper.h} +90 -85
  61. package/ios/{sherpa-onnx-tts-wrapper.mm → tts/sherpa-onnx-tts-wrapper.mm} +376 -345
  62. package/lib/module/NativeSherpaOnnx.js +3 -0
  63. package/lib/module/NativeSherpaOnnx.js.map +1 -1
  64. package/lib/module/audio/index.js +22 -0
  65. package/lib/module/audio/index.js.map +1 -0
  66. package/lib/module/diarization/index.js +1 -1
  67. package/lib/module/diarization/index.js.map +1 -1
  68. package/lib/module/download/ModelDownloadManager.js +918 -0
  69. package/lib/module/download/ModelDownloadManager.js.map +1 -0
  70. package/lib/module/download/extractTarBz2.js +53 -0
  71. package/lib/module/download/extractTarBz2.js.map +1 -0
  72. package/lib/module/download/index.js +6 -0
  73. package/lib/module/download/index.js.map +1 -0
  74. package/lib/module/download/validation.js +178 -0
  75. package/lib/module/download/validation.js.map +1 -0
  76. package/lib/module/enhancement/index.js +1 -1
  77. package/lib/module/enhancement/index.js.map +1 -1
  78. package/lib/module/index.js +41 -3
  79. package/lib/module/index.js.map +1 -1
  80. package/lib/module/separation/index.js +1 -1
  81. package/lib/module/separation/index.js.map +1 -1
  82. package/lib/module/stt/index.js +127 -60
  83. package/lib/module/stt/index.js.map +1 -1
  84. package/lib/module/stt/sttModelLanguages.js +512 -0
  85. package/lib/module/stt/sttModelLanguages.js.map +1 -0
  86. package/lib/module/stt/types.js +53 -1
  87. package/lib/module/stt/types.js.map +1 -1
  88. package/lib/module/tts/index.js +216 -289
  89. package/lib/module/tts/index.js.map +1 -1
  90. package/lib/module/tts/types.js +86 -1
  91. package/lib/module/tts/types.js.map +1 -1
  92. package/lib/module/types.js.map +1 -1
  93. package/lib/module/utils.js +86 -73
  94. package/lib/module/utils.js.map +1 -1
  95. package/lib/module/vad/index.js +1 -1
  96. package/lib/module/vad/index.js.map +1 -1
  97. package/lib/typescript/src/NativeSherpaOnnx.d.ts +192 -38
  98. package/lib/typescript/src/NativeSherpaOnnx.d.ts.map +1 -1
  99. package/lib/typescript/src/audio/index.d.ts +13 -0
  100. package/lib/typescript/src/audio/index.d.ts.map +1 -0
  101. package/lib/typescript/src/diarization/index.d.ts +3 -2
  102. package/lib/typescript/src/diarization/index.d.ts.map +1 -1
  103. package/lib/typescript/src/download/ModelDownloadManager.d.ts +108 -0
  104. package/lib/typescript/src/download/ModelDownloadManager.d.ts.map +1 -0
  105. package/lib/typescript/src/download/extractTarBz2.d.ts +14 -0
  106. package/lib/typescript/src/download/extractTarBz2.d.ts.map +1 -0
  107. package/lib/typescript/src/download/index.d.ts +7 -0
  108. package/lib/typescript/src/download/index.d.ts.map +1 -0
  109. package/lib/typescript/src/download/validation.d.ts +57 -0
  110. package/lib/typescript/src/download/validation.d.ts.map +1 -0
  111. package/lib/typescript/src/enhancement/index.d.ts +3 -2
  112. package/lib/typescript/src/enhancement/index.d.ts.map +1 -1
  113. package/lib/typescript/src/index.d.ts +26 -2
  114. package/lib/typescript/src/index.d.ts.map +1 -1
  115. package/lib/typescript/src/separation/index.d.ts +3 -2
  116. package/lib/typescript/src/separation/index.d.ts.map +1 -1
  117. package/lib/typescript/src/stt/index.d.ts +31 -43
  118. package/lib/typescript/src/stt/index.d.ts.map +1 -1
  119. package/lib/typescript/src/stt/sttModelLanguages.d.ts +52 -0
  120. package/lib/typescript/src/stt/sttModelLanguages.d.ts.map +1 -0
  121. package/lib/typescript/src/stt/types.d.ts +196 -9
  122. package/lib/typescript/src/stt/types.d.ts.map +1 -1
  123. package/lib/typescript/src/tts/index.d.ts +25 -211
  124. package/lib/typescript/src/tts/index.d.ts.map +1 -1
  125. package/lib/typescript/src/tts/types.d.ts +148 -25
  126. package/lib/typescript/src/tts/types.d.ts.map +1 -1
  127. package/lib/typescript/src/types.d.ts +0 -32
  128. package/lib/typescript/src/types.d.ts.map +1 -1
  129. package/lib/typescript/src/utils.d.ts +28 -13
  130. package/lib/typescript/src/utils.d.ts.map +1 -1
  131. package/lib/typescript/src/vad/index.d.ts +3 -2
  132. package/lib/typescript/src/vad/index.d.ts.map +1 -1
  133. package/package.json +250 -222
  134. package/scripts/check-qnn-support.sh +78 -0
  135. package/scripts/setup-ios-framework.sh +379 -282
  136. package/src/NativeSherpaOnnx.ts +474 -251
  137. package/src/audio/index.ts +32 -0
  138. package/src/diarization/index.ts +4 -2
  139. package/src/download/ModelDownloadManager.ts +1325 -0
  140. package/src/download/extractTarBz2.ts +78 -0
  141. package/src/download/index.ts +43 -0
  142. package/src/download/validation.ts +279 -0
  143. package/src/enhancement/index.ts +4 -2
  144. package/src/index.tsx +78 -27
  145. package/src/separation/index.ts +4 -2
  146. package/src/stt/index.ts +249 -89
  147. package/src/stt/sttModelLanguages.ts +237 -0
  148. package/src/stt/types.ts +263 -9
  149. package/src/tts/index.ts +470 -458
  150. package/src/tts/types.ts +373 -218
  151. package/src/types.ts +0 -44
  152. package/src/utils.ts +145 -131
  153. package/src/vad/index.ts +4 -2
  154. package/third_party/ffmpeg_prebuilt/ANDROID_RELEASE_TAG +1 -0
  155. package/third_party/libarchive_prebuilt/ANDROID_RELEASE_TAG +1 -0
  156. package/third_party/libarchive_prebuilt/IOS_RELEASE_TAG +1 -0
  157. package/third_party/sherpa-onnx-prebuilt/ANDROID_RELEASE_TAG +1 -0
  158. package/third_party/sherpa-onnx-prebuilt/IOS_RELEASE_TAG +1 -0
  159. package/android/src/main/cpp/include/sherpa-onnx/c-api/c-api.h +0 -1918
  160. package/android/src/main/cpp/include/sherpa-onnx/c-api/cxx-api.h +0 -841
  161. package/android/src/main/cpp/jni/sherpa-onnx-model-detect.cpp +0 -541
  162. package/android/src/main/cpp/jni/sherpa-onnx-stt-jni.cpp +0 -336
  163. package/android/src/main/cpp/jni/sherpa-onnx-stt-wrapper.cpp +0 -222
  164. package/android/src/main/cpp/jni/sherpa-onnx-stt-wrapper.h +0 -68
  165. package/android/src/main/cpp/jni/sherpa-onnx-tts-jni.cpp +0 -823
  166. package/android/src/main/cpp/jni/sherpa-onnx-tts-wrapper.cpp +0 -387
  167. package/android/src/main/cpp/jni/sherpa-onnx-tts-wrapper.h +0 -147
  168. package/ios/Frameworks/sherpa_onnx.xcframework.zip +0 -0
  169. package/ios/include/sherpa-onnx/c-api/c-api.h +0 -1918
  170. package/ios/include/sherpa-onnx/c-api/cxx-api.h +0 -841
  171. package/ios/sherpa-onnx-model-detect.mm +0 -441
  172. package/ios/sherpa-onnx-stt-wrapper.h +0 -48
  173. package/ios/sherpa-onnx-stt-wrapper.mm +0 -201
  174. package/scripts/copy-headers.js +0 -184
  175. package/scripts/setup-assets.js +0 -323
@@ -1,668 +1,1123 @@
1
- package com.sherpaonnx
2
-
3
- import android.content.Intent
4
- import android.media.AudioAttributes
5
- import android.media.AudioFormat
6
- import android.media.AudioManager
7
- import android.media.AudioTrack
8
- import android.net.Uri
9
- import android.os.Build
10
- import android.provider.DocumentsContract
11
- import androidx.core.content.FileProvider
12
- import com.facebook.react.bridge.Arguments
13
- import com.facebook.react.bridge.Promise
14
- import com.facebook.react.bridge.ReadableArray
15
- import com.facebook.react.bridge.ReactApplicationContext
16
- import java.io.File
17
- import java.io.FileOutputStream
18
- import java.io.InputStream
19
- import java.io.OutputStream
20
- import java.util.concurrent.atomic.AtomicBoolean
21
-
22
- internal class SherpaOnnxTtsHelper(
23
- private val context: ReactApplicationContext,
24
- private val native: NativeTtsBridge,
25
- private val emitChunk: (FloatArray, Int, Float, Boolean) -> Unit,
26
- private val emitError: (String) -> Unit,
27
- private val emitEnd: (Boolean) -> Unit
28
- ) {
29
- interface NativeTtsBridge {
30
- fun nativeTtsInitialize(
31
- modelDir: String,
32
- modelType: String,
33
- numThreads: Int,
34
- debug: Boolean,
35
- noiseScale: Double,
36
- noiseScaleW: Double,
37
- lengthScale: Double
38
- ): HashMap<String, Any>?
39
-
40
- fun nativeTtsGenerate(text: String, sid: Int, speed: Float): HashMap<String, Any>?
41
-
42
- fun nativeTtsGenerateWithTimestamps(text: String, sid: Int, speed: Float): HashMap<String, Any>?
43
-
44
- fun nativeTtsGenerateStream(text: String, sid: Int, speed: Float): Boolean
45
-
46
- fun nativeTtsCancelStream()
47
-
48
- fun nativeTtsGetSampleRate(): Int
49
-
50
- fun nativeTtsGetNumSpeakers(): Int
51
-
52
- fun nativeTtsRelease()
53
-
54
- fun nativeTtsSaveToWavFile(samples: FloatArray, sampleRate: Int, filePath: String): Boolean
55
- }
56
-
57
- private data class TtsInitState(
58
- val modelDir: String,
59
- val modelType: String,
60
- val numThreads: Int,
61
- val debug: Boolean,
62
- val noiseScale: Double?,
63
- val noiseScaleW: Double?,
64
- val lengthScale: Double?
65
- )
66
-
67
- private val ttsStreamRunning = AtomicBoolean(false)
68
- private val ttsStreamCancelled = AtomicBoolean(false)
69
- private var ttsStreamThread: Thread? = null
70
- private var ttsPcmTrack: AudioTrack? = null
71
- private var ttsInitState: TtsInitState? = null
72
-
73
- fun initializeTts(
74
- modelDir: String,
75
- modelType: String,
76
- numThreads: Double,
77
- debug: Boolean,
78
- noiseScale: Double?,
79
- noiseScaleW: Double?,
80
- lengthScale: Double?,
81
- promise: Promise
82
- ) {
83
- try {
84
- val result = native.nativeTtsInitialize(
85
- modelDir,
86
- modelType,
87
- numThreads.toInt(),
88
- debug,
89
- noiseScale ?: Double.NaN,
90
- noiseScaleW ?: Double.NaN,
91
- lengthScale ?: Double.NaN
92
- )
93
-
94
- if (result == null) {
95
- promise.reject("TTS_INIT_ERROR", "Failed to initialize TTS: native call returned null")
96
- return
97
- }
98
-
99
- val success = result["success"] as? Boolean ?: false
100
-
101
- if (success) {
102
- val detectedModels = result["detectedModels"] as? ArrayList<*>
103
- val modelsArray = Arguments.createArray()
104
-
105
- detectedModels?.forEach { modelObj ->
106
- if (modelObj is HashMap<*, *>) {
107
- val modelMap = Arguments.createMap()
108
- modelMap.putString("type", modelObj["type"] as? String ?: "")
109
- modelMap.putString("modelDir", modelObj["modelDir"] as? String ?: "")
110
- modelsArray.pushMap(modelMap)
111
- }
112
- }
113
-
114
- val resultMap = Arguments.createMap()
115
- resultMap.putBoolean("success", true)
116
- resultMap.putArray("detectedModels", modelsArray)
117
- ttsInitState = TtsInitState(
118
- modelDir,
119
- modelType,
120
- numThreads.toInt(),
121
- debug,
122
- noiseScale?.takeUnless { it.isNaN() },
123
- noiseScaleW?.takeUnless { it.isNaN() },
124
- lengthScale?.takeUnless { it.isNaN() }
125
- )
126
- promise.resolve(resultMap)
127
- } else {
128
- promise.reject("TTS_INIT_ERROR", "Failed to initialize TTS")
129
- }
130
- } catch (e: Exception) {
131
- promise.reject("TTS_INIT_ERROR", "Failed to initialize TTS", e)
132
- }
133
- }
134
-
135
- fun updateTtsParams(
136
- noiseScale: Double?,
137
- noiseScaleW: Double?,
138
- lengthScale: Double?,
139
- promise: Promise
140
- ) {
141
- if (ttsStreamRunning.get()) {
142
- promise.reject("TTS_UPDATE_ERROR", "Cannot update params while streaming")
143
- return
144
- }
145
-
146
- val state = ttsInitState
147
- if (state == null) {
148
- promise.reject("TTS_UPDATE_ERROR", "TTS not initialized")
149
- return
150
- }
151
-
152
- val nextNoiseScale = when {
153
- noiseScale == null -> null
154
- noiseScale.isNaN() -> state.noiseScale
155
- else -> noiseScale
156
- }
157
- val nextNoiseScaleW = when {
158
- noiseScaleW == null -> null
159
- noiseScaleW.isNaN() -> state.noiseScaleW
160
- else -> noiseScaleW
161
- }
162
- val nextLengthScale = when {
163
- lengthScale == null -> null
164
- lengthScale.isNaN() -> state.lengthScale
165
- else -> lengthScale
166
- }
167
-
168
- try {
169
- val result = native.nativeTtsInitialize(
170
- state.modelDir,
171
- state.modelType,
172
- state.numThreads,
173
- state.debug,
174
- nextNoiseScale ?: Double.NaN,
175
- nextNoiseScaleW ?: Double.NaN,
176
- nextLengthScale ?: Double.NaN
177
- )
178
-
179
- if (result == null) {
180
- promise.reject("TTS_UPDATE_ERROR", "Failed to update TTS params: native call returned null")
181
- return
182
- }
183
-
184
- val success = result["success"] as? Boolean ?: false
185
- if (!success) {
186
- promise.reject("TTS_UPDATE_ERROR", "Failed to update TTS params")
187
- return
188
- }
189
-
190
- val detectedModels = result["detectedModels"] as? ArrayList<*>
191
- val modelsArray = Arguments.createArray()
192
- detectedModels?.forEach { modelObj ->
193
- if (modelObj is HashMap<*, *>) {
194
- val modelMap = Arguments.createMap()
195
- modelMap.putString("type", modelObj["type"] as? String ?: "")
196
- modelMap.putString("modelDir", modelObj["modelDir"] as? String ?: "")
197
- modelsArray.pushMap(modelMap)
198
- }
199
- }
200
-
201
- val resultMap = Arguments.createMap()
202
- resultMap.putBoolean("success", true)
203
- resultMap.putArray("detectedModels", modelsArray)
204
- ttsInitState = TtsInitState(
205
- state.modelDir,
206
- state.modelType,
207
- state.numThreads,
208
- state.debug,
209
- nextNoiseScale,
210
- nextNoiseScaleW,
211
- nextLengthScale
212
- )
213
- promise.resolve(resultMap)
214
- } catch (e: Exception) {
215
- promise.reject("TTS_UPDATE_ERROR", "Failed to update TTS params", e)
216
- }
217
- }
218
-
219
- fun generateTts(text: String, sid: Double, speed: Double, promise: Promise) {
220
- try {
221
- val result = native.nativeTtsGenerate(text, sid.toInt(), speed.toFloat())
222
- if (result != null) {
223
- val map = Arguments.createMap()
224
-
225
- @Suppress("UNCHECKED_CAST")
226
- val samples = result["samples"] as? FloatArray
227
- val sampleRate = result["sampleRate"] as? Int
228
-
229
- if (samples != null && sampleRate != null) {
230
- val samplesArray = Arguments.createArray()
231
- for (sample in samples) {
232
- samplesArray.pushDouble(sample.toDouble())
233
- }
234
-
235
- map.putArray("samples", samplesArray)
236
- map.putInt("sampleRate", sampleRate)
237
- promise.resolve(map)
238
- } else {
239
- promise.reject("TTS_GENERATE_ERROR", "Invalid result format from native code")
240
- }
241
- } else {
242
- promise.reject("TTS_GENERATE_ERROR", "Failed to generate speech")
243
- }
244
- } catch (e: Exception) {
245
- promise.reject("TTS_GENERATE_ERROR", "Failed to generate speech", e)
246
- }
247
- }
248
-
249
- fun generateTtsWithTimestamps(text: String, sid: Double, speed: Double, promise: Promise) {
250
- try {
251
- val result = native.nativeTtsGenerateWithTimestamps(text, sid.toInt(), speed.toFloat())
252
- if (result != null) {
253
- val map = Arguments.createMap()
254
-
255
- @Suppress("UNCHECKED_CAST")
256
- val samples = result["samples"] as? FloatArray
257
- val sampleRate = result["sampleRate"] as? Int
258
- val subtitles = result["subtitles"] as? ArrayList<*>
259
- val estimated = result["estimated"] as? Boolean ?: true
260
-
261
- if (samples != null && sampleRate != null) {
262
- val samplesArray = Arguments.createArray()
263
- for (sample in samples) {
264
- samplesArray.pushDouble(sample.toDouble())
265
- }
266
-
267
- val subtitlesArray = Arguments.createArray()
268
- subtitles?.forEach { item ->
269
- if (item is HashMap<*, *>) {
270
- val subtitleMap = Arguments.createMap()
271
- subtitleMap.putString("text", item["text"] as? String ?: "")
272
- subtitleMap.putDouble("start", (item["start"] as? Number)?.toDouble() ?: 0.0)
273
- subtitleMap.putDouble("end", (item["end"] as? Number)?.toDouble() ?: 0.0)
274
- subtitlesArray.pushMap(subtitleMap)
275
- }
276
- }
277
-
278
- map.putArray("samples", samplesArray)
279
- map.putInt("sampleRate", sampleRate)
280
- map.putArray("subtitles", subtitlesArray)
281
- map.putBoolean("estimated", estimated)
282
- promise.resolve(map)
283
- } else {
284
- promise.reject("TTS_GENERATE_ERROR", "Invalid result format from native code")
285
- }
286
- } else {
287
- promise.reject("TTS_GENERATE_ERROR", "Failed to generate speech")
288
- }
289
- } catch (e: Exception) {
290
- promise.reject("TTS_GENERATE_ERROR", "Failed to generate speech", e)
291
- }
292
- }
293
-
294
- fun generateTtsStream(text: String, sid: Double, speed: Double, promise: Promise) {
295
- if (ttsStreamRunning.get()) {
296
- promise.reject("TTS_STREAM_ERROR", "TTS streaming already in progress")
297
- return
298
- }
299
-
300
- ttsStreamCancelled.set(false)
301
- ttsStreamRunning.set(true)
302
-
303
- ttsStreamThread = Thread {
304
- try {
305
- val success = native.nativeTtsGenerateStream(text, sid.toInt(), speed.toFloat())
306
- if (!success && !ttsStreamCancelled.get()) {
307
- emitError("TTS streaming generation failed")
308
- }
309
- } catch (e: Exception) {
310
- emitError("TTS streaming failed: ${e.message}")
311
- } finally {
312
- emitEnd(ttsStreamCancelled.get())
313
- ttsStreamRunning.set(false)
314
- }
315
- }
316
-
317
- ttsStreamThread?.start()
318
- promise.resolve(null)
319
- }
320
-
321
- fun cancelTtsStream(promise: Promise) {
322
- ttsStreamCancelled.set(true)
323
- try {
324
- native.nativeTtsCancelStream()
325
- ttsStreamThread?.interrupt()
326
- } catch (e: Exception) {
327
- promise.reject("TTS_STREAM_ERROR", "Failed to cancel TTS stream", e)
328
- return
329
- }
330
- promise.resolve(null)
331
- }
332
-
333
- fun startTtsPcmPlayer(sampleRate: Double, channels: Double, promise: Promise) {
334
- try {
335
- if (Build.VERSION.SDK_INT < Build.VERSION_CODES.LOLLIPOP) {
336
- promise.reject("TTS_PCM_ERROR", "PCM playback requires API 21+")
337
- return
338
- }
339
-
340
- if (channels.toInt() != 1) {
341
- promise.reject("TTS_PCM_ERROR", "PCM playback supports mono only")
342
- return
343
- }
344
-
345
- stopPcmPlayerInternal()
346
-
347
- val channelConfig = AudioFormat.CHANNEL_OUT_MONO
348
-
349
- val audioFormat = AudioFormat.Builder()
350
- .setSampleRate(sampleRate.toInt())
351
- .setChannelMask(channelConfig)
352
- .setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
353
- .build()
354
-
355
- val minBufferSize = AudioTrack.getMinBufferSize(
356
- sampleRate.toInt(),
357
- channelConfig,
358
- AudioFormat.ENCODING_PCM_FLOAT
359
- )
360
-
361
- if (minBufferSize == AudioTrack.ERROR || minBufferSize == AudioTrack.ERROR_BAD_VALUE) {
362
- promise.reject("TTS_PCM_ERROR", "Invalid buffer size for PCM player")
363
- return
364
- }
365
-
366
- val attributes = AudioAttributes.Builder()
367
- .setUsage(AudioAttributes.USAGE_MEDIA)
368
- .setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
369
- .build()
370
-
371
- ttsPcmTrack = AudioTrack(
372
- attributes,
373
- audioFormat,
374
- minBufferSize,
375
- AudioTrack.MODE_STREAM,
376
- AudioManager.AUDIO_SESSION_ID_GENERATE
377
- )
378
-
379
- ttsPcmTrack?.play()
380
- promise.resolve(null)
381
- } catch (e: Exception) {
382
- promise.reject("TTS_PCM_ERROR", "Failed to start PCM player", e)
383
- }
384
- }
385
-
386
- fun writeTtsPcmChunk(samples: ReadableArray, promise: Promise) {
387
- val track = ttsPcmTrack
388
- if (track == null) {
389
- promise.reject("TTS_PCM_ERROR", "PCM player not initialized")
390
- return
391
- }
392
-
393
- try {
394
- val buffer = FloatArray(samples.size())
395
- for (i in 0 until samples.size()) {
396
- buffer[i] = samples.getDouble(i).toFloat()
397
- }
398
-
399
- val written = track.write(buffer, 0, buffer.size, AudioTrack.WRITE_BLOCKING)
400
- if (written < 0) {
401
- promise.reject("TTS_PCM_ERROR", "PCM write failed: $written")
402
- return
403
- }
404
-
405
- promise.resolve(null)
406
- } catch (e: Exception) {
407
- promise.reject("TTS_PCM_ERROR", "Failed to write PCM chunk", e)
408
- }
409
- }
410
-
411
- fun stopTtsPcmPlayer(promise: Promise) {
412
- try {
413
- stopPcmPlayerInternal()
414
- promise.resolve(null)
415
- } catch (e: Exception) {
416
- promise.reject("TTS_PCM_ERROR", "Failed to stop PCM player", e)
417
- }
418
- }
419
-
420
- fun getTtsSampleRate(promise: Promise) {
421
- try {
422
- val sampleRate = native.nativeTtsGetSampleRate()
423
- promise.resolve(sampleRate.toDouble())
424
- } catch (e: Exception) {
425
- promise.reject("TTS_ERROR", "Failed to get sample rate", e)
426
- }
427
- }
428
-
429
- fun getTtsNumSpeakers(promise: Promise) {
430
- try {
431
- val numSpeakers = native.nativeTtsGetNumSpeakers()
432
- promise.resolve(numSpeakers.toDouble())
433
- } catch (e: Exception) {
434
- promise.reject("TTS_ERROR", "Failed to get number of speakers", e)
435
- }
436
- }
437
-
438
- fun unloadTts(promise: Promise) {
439
- try {
440
- stopPcmPlayerInternal()
441
- native.nativeTtsRelease()
442
- ttsInitState = null
443
- promise.resolve(null)
444
- } catch (e: Exception) {
445
- promise.reject("TTS_RELEASE_ERROR", "Failed to release TTS resources", e)
446
- }
447
- }
448
-
449
- fun saveTtsAudioToFile(
450
- samples: ReadableArray,
451
- sampleRate: Double,
452
- filePath: String,
453
- promise: Promise
454
- ) {
455
- try {
456
- val samplesArray = FloatArray(samples.size())
457
- for (i in 0 until samples.size()) {
458
- samplesArray[i] = samples.getDouble(i).toFloat()
459
- }
460
-
461
- val success = native.nativeTtsSaveToWavFile(samplesArray, sampleRate.toInt(), filePath)
462
- if (success) {
463
- promise.resolve(filePath)
464
- } else {
465
- promise.reject("TTS_SAVE_ERROR", "Failed to save audio to file")
466
- }
467
- } catch (e: Exception) {
468
- promise.reject("TTS_SAVE_ERROR", "Failed to save audio to file", e)
469
- }
470
- }
471
-
472
- fun saveTtsAudioToContentUri(
473
- samples: ReadableArray,
474
- sampleRate: Double,
475
- directoryUri: String,
476
- filename: String,
477
- promise: Promise
478
- ) {
479
- try {
480
- val samplesArray = FloatArray(samples.size())
481
- for (i in 0 until samples.size()) {
482
- samplesArray[i] = samples.getDouble(i).toFloat()
483
- }
484
-
485
- val resolver = context.contentResolver
486
- val dirUri = Uri.parse(directoryUri)
487
- val fileUri = createDocumentInDirectory(resolver, dirUri, filename, "audio/wav")
488
-
489
- resolver.openOutputStream(fileUri, "w")?.use { outputStream ->
490
- writeWavToStream(samplesArray, sampleRate.toInt(), outputStream)
491
- } ?: throw IllegalStateException("Failed to open output stream for URI: $fileUri")
492
-
493
- promise.resolve(fileUri.toString())
494
- } catch (e: Exception) {
495
- promise.reject("TTS_SAVE_ERROR", "Failed to save audio to content URI", e)
496
- }
497
- }
498
-
499
- fun saveTtsTextToContentUri(
500
- text: String,
501
- directoryUri: String,
502
- filename: String,
503
- mimeType: String,
504
- promise: Promise
505
- ) {
506
- try {
507
- val resolver = context.contentResolver
508
- val dirUri = Uri.parse(directoryUri)
509
- val fileUri = createDocumentInDirectory(resolver, dirUri, filename, mimeType)
510
-
511
- resolver.openOutputStream(fileUri, "w")?.use { outputStream ->
512
- outputStream.write(text.toByteArray(Charsets.UTF_8))
513
- } ?: throw IllegalStateException("Failed to open output stream for URI: $fileUri")
514
-
515
- promise.resolve(fileUri.toString())
516
- } catch (e: Exception) {
517
- promise.reject("TTS_SAVE_ERROR", "Failed to save text to content URI", e)
518
- }
519
- }
520
-
521
- fun copyTtsContentUriToCache(fileUri: String, filename: String, promise: Promise) {
522
- try {
523
- val resolver = context.contentResolver
524
- val uri = Uri.parse(fileUri)
525
- val cacheFile = File(context.cacheDir, filename)
526
-
527
- resolver.openInputStream(uri)?.use { inputStream ->
528
- FileOutputStream(cacheFile).use { outputStream ->
529
- copyStream(inputStream, outputStream)
530
- }
531
- } ?: throw IllegalStateException("Failed to open input stream for URI: $fileUri")
532
-
533
- promise.resolve(cacheFile.absolutePath)
534
- } catch (e: Exception) {
535
- promise.reject("TTS_SAVE_ERROR", "Failed to copy audio to cache", e)
536
- }
537
- }
538
-
539
- fun shareTtsAudio(fileUri: String, mimeType: String, promise: Promise) {
540
- try {
541
- val uri = if (fileUri.startsWith("content://")) {
542
- Uri.parse(fileUri)
543
- } else {
544
- val path = if (fileUri.startsWith("file://")) {
545
- try {
546
- Uri.parse(fileUri).path ?: fileUri.replaceFirst("file://", "")
547
- } catch (e: Exception) {
548
- fileUri.replaceFirst("file://", "")
549
- }
550
- } else {
551
- fileUri
552
- }
553
-
554
- val file = File(path)
555
- val authority = context.packageName + ".fileprovider"
556
- FileProvider.getUriForFile(context, authority, file)
557
- }
558
-
559
- val shareIntent = Intent(Intent.ACTION_SEND).apply {
560
- type = mimeType
561
- putExtra(Intent.EXTRA_STREAM, uri)
562
- addFlags(Intent.FLAG_GRANT_READ_URI_PERMISSION)
563
- }
564
-
565
- val chooser = Intent.createChooser(shareIntent, "Share audio")
566
- chooser.addFlags(Intent.FLAG_ACTIVITY_NEW_TASK)
567
- context.startActivity(chooser)
568
- promise.resolve(null)
569
- } catch (e: Exception) {
570
- promise.reject("TTS_SHARE_ERROR", "Failed to share audio", e)
571
- }
572
- }
573
-
574
- fun emitTtsStreamChunk(samples: FloatArray, sampleRate: Int, progress: Float, isFinal: Boolean) {
575
- emitChunk(samples, sampleRate, progress, isFinal)
576
- }
577
-
578
- fun emitTtsStreamError(message: String) {
579
- emitError(message)
580
- }
581
-
582
- fun emitTtsStreamEnd(cancelled: Boolean) {
583
- emitEnd(cancelled)
584
- }
585
-
586
- private fun stopPcmPlayerInternal() {
587
- ttsPcmTrack?.apply {
588
- try {
589
- stop()
590
- } catch (_: IllegalStateException) {
591
- }
592
- flush()
593
- release()
594
- }
595
- ttsPcmTrack = null
596
- }
597
-
598
- private fun createDocumentInDirectory(
599
- resolver: android.content.ContentResolver,
600
- directoryUri: Uri,
601
- filename: String,
602
- mimeType: String
603
- ): Uri {
604
- return if (DocumentsContract.isTreeUri(directoryUri)) {
605
- val documentId = DocumentsContract.getTreeDocumentId(directoryUri)
606
- val dirDocUri = DocumentsContract.buildDocumentUriUsingTree(directoryUri, documentId)
607
- DocumentsContract.createDocument(resolver, dirDocUri, mimeType, filename)
608
- ?: throw IllegalStateException("Failed to create document in tree URI")
609
- } else {
610
- DocumentsContract.createDocument(resolver, directoryUri, mimeType, filename)
611
- ?: throw IllegalStateException("Failed to create document in directory URI")
612
- }
613
- }
614
-
615
- private fun writeWavToStream(samples: FloatArray, sampleRate: Int, outputStream: OutputStream) {
616
- val numChannels = 1
617
- val bitsPerSample = 16
618
- val byteRate = sampleRate * numChannels * bitsPerSample / 8
619
- val blockAlign = numChannels * bitsPerSample / 8
620
- val dataSize = samples.size * 2
621
- val chunkSize = 36 + dataSize
622
-
623
- outputStream.write("RIFF".toByteArray(Charsets.US_ASCII))
624
- writeIntLE(outputStream, chunkSize)
625
- outputStream.write("WAVE".toByteArray(Charsets.US_ASCII))
626
- outputStream.write("fmt ".toByteArray(Charsets.US_ASCII))
627
- writeIntLE(outputStream, 16)
628
- writeShortLE(outputStream, 1)
629
- writeShortLE(outputStream, numChannels.toShort())
630
- writeIntLE(outputStream, sampleRate)
631
- writeIntLE(outputStream, byteRate)
632
- writeShortLE(outputStream, blockAlign.toShort())
633
- writeShortLE(outputStream, bitsPerSample.toShort())
634
- outputStream.write("data".toByteArray(Charsets.US_ASCII))
635
- writeIntLE(outputStream, dataSize)
636
-
637
- for (sample in samples) {
638
- val clamped = sample.coerceIn(-1.0f, 1.0f)
639
- val intSample = (clamped * 32767.0f).toInt()
640
- writeShortLE(outputStream, intSample.toShort())
641
- }
642
-
643
- outputStream.flush()
644
- }
645
-
646
- private fun writeIntLE(outputStream: OutputStream, value: Int) {
647
- outputStream.write(value and 0xFF)
648
- outputStream.write((value shr 8) and 0xFF)
649
- outputStream.write((value shr 16) and 0xFF)
650
- outputStream.write((value shr 24) and 0xFF)
651
- }
652
-
653
- private fun writeShortLE(outputStream: OutputStream, value: Short) {
654
- val intValue = value.toInt()
655
- outputStream.write(intValue and 0xFF)
656
- outputStream.write((intValue shr 8) and 0xFF)
657
- }
658
-
659
- private fun copyStream(inputStream: InputStream, outputStream: OutputStream) {
660
- val buffer = ByteArray(8192)
661
- var bytes = inputStream.read(buffer)
662
- while (bytes >= 0) {
663
- outputStream.write(buffer, 0, bytes)
664
- bytes = inputStream.read(buffer)
665
- }
666
- outputStream.flush()
667
- }
668
- }
1
+ package com.sherpaonnx
2
+
3
+ import android.app.ActivityManager
4
+ import android.content.Context
5
+ import android.content.Intent
6
+ import android.media.AudioAttributes
7
+ import android.media.AudioFormat
8
+ import android.media.AudioManager
9
+ import android.media.AudioTrack
10
+ import android.net.Uri
11
+ import android.os.Build
12
+ import android.os.Handler
13
+ import android.os.Looper
14
+ import android.provider.DocumentsContract
15
+ import android.util.Log
16
+ import androidx.core.content.FileProvider
17
+ import com.facebook.react.bridge.Arguments
18
+ import com.facebook.react.bridge.Promise
19
+ import com.facebook.react.bridge.ReadableArray
20
+ import com.facebook.react.bridge.ReadableMap
21
+ import com.facebook.react.bridge.ReactApplicationContext
22
+ import com.facebook.react.bridge.WritableMap
23
+ import com.k2fsa.sherpa.onnx.GeneratedAudio
24
+ import com.k2fsa.sherpa.onnx.GenerationConfig
25
+ import com.k2fsa.sherpa.onnx.OfflineTts
26
+ import com.k2fsa.sherpa.onnx.OfflineTtsConfig
27
+ import com.k2fsa.sherpa.onnx.OfflineTtsModelConfig
28
+ import com.k2fsa.sherpa.onnx.OfflineTtsPocketModelConfig
29
+ import com.k2fsa.sherpa.onnx.OfflineTtsVitsModelConfig
30
+ import com.k2fsa.sherpa.onnx.OfflineTtsMatchaModelConfig
31
+ import com.k2fsa.sherpa.onnx.OfflineTtsKokoroModelConfig
32
+ import com.k2fsa.sherpa.onnx.OfflineTtsKittenModelConfig
33
+ import java.io.File
34
+ import java.io.FileOutputStream
35
+ import java.io.InputStream
36
+ import java.io.OutputStream
37
+ import java.util.concurrent.ConcurrentHashMap
38
+ import java.util.concurrent.Executors
39
+ import java.util.concurrent.atomic.AtomicBoolean
40
+
41
+ internal class SherpaOnnxTtsHelper(
42
+ private val context: ReactApplicationContext,
43
+ private val detectTtsModel: (modelDir: String, modelType: String) -> HashMap<String, Any>?,
44
+ private val emitChunk: (String, FloatArray, Int, Float, Boolean) -> Unit,
45
+ private val emitError: (String, String) -> Unit,
46
+ private val emitEnd: (String, Boolean) -> Unit
47
+ ) {
48
+
49
+ private data class TtsInitState(
50
+ val modelDir: String,
51
+ val modelType: String,
52
+ val numThreads: Int,
53
+ val debug: Boolean,
54
+ val noiseScale: Double?,
55
+ val noiseScaleW: Double?,
56
+ val lengthScale: Double?,
57
+ val ruleFsts: String?,
58
+ val ruleFars: String?,
59
+ val maxNumSentences: Int?,
60
+ val silenceScale: Double?,
61
+ val provider: String?
62
+ )
63
+
64
+ private data class TtsEngineInstance(
65
+ @Volatile var tts: OfflineTts? = null,
66
+ @Volatile var zipvoiceTts: ZipvoiceTtsWrapper? = null,
67
+ var ttsInitState: TtsInitState? = null,
68
+ val ttsStreamRunning: AtomicBoolean = AtomicBoolean(false),
69
+ val ttsStreamCancelled: AtomicBoolean = AtomicBoolean(false),
70
+ var ttsStreamThread: Thread? = null,
71
+ var ttsPcmTrack: AudioTrack? = null
72
+ ) {
73
+ private val lock = Any()
74
+
75
+ fun hasEngine(): Boolean = synchronized(lock) { tts != null || zipvoiceTts != null }
76
+ val isZipvoice: Boolean get() = synchronized(lock) { zipvoiceTts != null }
77
+ fun releaseEngines() {
78
+ synchronized(lock) {
79
+ tts?.release()
80
+ tts = null
81
+ zipvoiceTts?.release()
82
+ zipvoiceTts = null
83
+ ttsInitState = null
84
+ }
85
+ }
86
+ fun stopPcmPlayer() {
87
+ synchronized(lock) {
88
+ ttsPcmTrack?.apply {
89
+ try { stop() } catch (_: IllegalStateException) {}
90
+ flush()
91
+ release()
92
+ }
93
+ ttsPcmTrack = null
94
+ }
95
+ }
96
+ }
97
+
98
+ private val instances = ConcurrentHashMap<String, TtsEngineInstance>()
99
+
100
+ private fun getInstance(instanceId: String): TtsEngineInstance? = instances[instanceId]
101
+
102
+ /** Run promise resolve/reject on the UI thread so React state updates run on the main thread. */
103
+ private val mainHandler = Handler(Looper.getMainLooper())
104
+ private fun resolveOnUiThread(promise: Promise, result: WritableMap) {
105
+ mainHandler.post { promise.resolve(result) }
106
+ }
107
+ private fun rejectOnUiThread(promise: Promise, code: String, message: String, throwable: Throwable? = null) {
108
+ mainHandler.post {
109
+ if (throwable != null) promise.reject(code, message, throwable) else promise.reject(code, message)
110
+ }
111
+ }
112
+
113
+ /** Single-thread executor for TTS init so the RN bridge thread is not blocked (avoids Inspector/dev WebSocket races in debug builds). */
114
+ private val ttsInitExecutor = Executors.newSingleThreadExecutor()
115
+
116
+ /**
117
+ * Shuts down the TTS init executor and releases all engine instances.
118
+ * Call from the native module's onCatalystInstanceDestroy() to avoid leaking the executor thread.
119
+ */
120
+ fun shutdown() {
121
+ try {
122
+ ttsInitExecutor.shutdown()
123
+ if (!ttsInitExecutor.awaitTermination(3, java.util.concurrent.TimeUnit.SECONDS)) {
124
+ ttsInitExecutor.shutdownNow()
125
+ }
126
+ } catch (e: InterruptedException) {
127
+ Thread.currentThread().interrupt()
128
+ ttsInitExecutor.shutdownNow()
129
+ }
130
+ instances.values.forEach { inst ->
131
+ inst.releaseEngines()
132
+ inst.stopPcmPlayer()
133
+ }
134
+ instances.clear()
135
+ }
136
+
137
+ fun initializeTts(
138
+ instanceId: String,
139
+ modelDir: String,
140
+ modelType: String,
141
+ numThreads: Double,
142
+ debug: Boolean,
143
+ noiseScale: Double?,
144
+ noiseScaleW: Double?,
145
+ lengthScale: Double?,
146
+ ruleFsts: String?,
147
+ ruleFars: String?,
148
+ maxNumSentences: Double?,
149
+ silenceScale: Double?,
150
+ provider: String?,
151
+ promise: Promise
152
+ ) {
153
+ ttsInitExecutor.execute init@{
154
+ try {
155
+ val result = detectTtsModel(modelDir, modelType)
156
+ if (result == null) {
157
+ Log.e("SherpaOnnxTts", "TTS_INIT_ERROR: Failed to detect TTS model: native call returned null")
158
+ rejectOnUiThread(promise, "TTS_INIT_ERROR", "Failed to detect TTS model: native call returned null")
159
+ return@init
160
+ }
161
+ val success = result["success"] as? Boolean ?: false
162
+ if (!success) {
163
+ val reason = result["error"] as? String
164
+ Log.e("SherpaOnnxTts", "TTS_INIT_ERROR: ${reason ?: "Failed to detect TTS model"}")
165
+ rejectOnUiThread(promise, "TTS_INIT_ERROR", reason ?: "Failed to detect TTS model")
166
+ return@init
167
+ }
168
+ val paths = (result["paths"] as? Map<*, *>)?.mapValues { (_, v) -> (v as? String).orEmpty() }?.mapKeys { it.key.toString() } ?: emptyMap()
169
+ val modelTypeStr = result["modelType"] as? String ?: "vits"
170
+ val detectedModels = result["detectedModels"] as? ArrayList<*>
171
+
172
+ val inst = instances.getOrPut(instanceId) { TtsEngineInstance() }
173
+ inst.stopPcmPlayer()
174
+ inst.releaseEngines()
175
+
176
+ val sampleRate: Int
177
+ val numSpeakers: Int
178
+
179
+ if (modelTypeStr == "zipvoice") {
180
+ val vocoderPath = path(paths, "vocoder")
181
+ if (vocoderPath.isBlank()) {
182
+ val msg = "Zipvoice distill models (encoder+decoder only, no vocoder) are not supported. Use the full Zipvoice model that includes vocos_24khz.onnx (or similar vocoder file)."
183
+ Log.e("SherpaOnnxTts", "TTS_INIT_ERROR: $msg")
184
+ rejectOnUiThread(promise, "TTS_INIT_ERROR", msg)
185
+ return@init
186
+ }
187
+ val am = context.applicationContext.getSystemService(Context.ACTIVITY_SERVICE) as? ActivityManager
188
+ if (am != null) {
189
+ val memInfo = ActivityManager.MemoryInfo()
190
+ am.getMemoryInfo(memInfo)
191
+ val availMb = memInfo.availMem / (1024 * 1024)
192
+ if (memInfo.availMem < 800L * 1024 * 1024) {
193
+ val msg = "Not enough free memory to load the Zipvoice model (available: ${availMb} MB). Close other apps to free memory or use a smaller Zipvoice model that includes all required components (encoder, decoder, and vocoder)."
194
+ Log.e("SherpaOnnxTts", "TTS_INIT_ERROR: $msg")
195
+ rejectOnUiThread(promise, "TTS_INIT_ERROR", msg)
196
+ return@init
197
+ }
198
+ }
199
+ // Hint GC before heavy allocation to reduce memory pressure; zipvoice always uses 1 thread to limit peak RAM.
200
+ System.gc()
201
+ if (am != null) {
202
+ val memInfoBefore = ActivityManager.MemoryInfo()
203
+ am.getMemoryInfo(memInfoBefore)
204
+ Log.i("SherpaOnnxTts", "Zipvoice init: availMem=${memInfoBefore.availMem / (1024 * 1024)} MB (before load)")
205
+ }
206
+ val zipvoiceNumThreads = 1
207
+ val wrapper = ZipvoiceTtsWrapper.create(
208
+ tokens = path(paths, "tokens"),
209
+ encoder = path(paths, "encoder"),
210
+ decoder = path(paths, "decoder"),
211
+ vocoder = vocoderPath,
212
+ dataDir = path(paths, "dataDir"),
213
+ lexicon = path(paths, "lexicon"),
214
+ numThreads = zipvoiceNumThreads,
215
+ debug = debug,
216
+ ruleFsts = ruleFsts?.takeIf { it.isNotBlank() } ?: "",
217
+ ruleFars = ruleFars?.takeIf { it.isNotBlank() } ?: "",
218
+ maxNumSentences = maxNumSentences?.toInt()?.coerceAtLeast(1) ?: 1,
219
+ silenceScale = silenceScale?.toFloat()?.coerceIn(0f, 10f) ?: 0.2f,
220
+ provider = provider?.takeIf { it.isNotBlank() } ?: "cpu"
221
+ )
222
+ if (am != null) {
223
+ val memInfo = ActivityManager.MemoryInfo()
224
+ am.getMemoryInfo(memInfo)
225
+ Log.i("SherpaOnnxTts", "Zipvoice init: availMem=${memInfo.availMem / (1024 * 1024)} MB (after load)")
226
+ }
227
+ if (wrapper == null) {
228
+ Log.e("SherpaOnnxTts", "TTS_INIT_ERROR: Failed to create Zipvoice TTS engine via C-API. Check logcat for details.")
229
+ rejectOnUiThread(promise, "TTS_INIT_ERROR", "Failed to create Zipvoice TTS engine via C-API. Check logcat for details.")
230
+ return@init
231
+ }
232
+ inst.zipvoiceTts = wrapper
233
+ sampleRate = wrapper.sampleRate()
234
+ numSpeakers = wrapper.numSpeakers()
235
+ } else {
236
+ val config = buildTtsConfig(
237
+ paths, modelTypeStr, numThreads.toInt(), debug,
238
+ noiseScale, noiseScaleW, lengthScale,
239
+ ruleFsts, ruleFars, maxNumSentences?.toInt(), silenceScale,
240
+ provider
241
+ )
242
+ inst.tts = OfflineTts(config = config)
243
+ sampleRate = inst.tts!!.sampleRate()
244
+ numSpeakers = inst.tts!!.numSpeakers()
245
+ }
246
+
247
+ Log.i("SherpaOnnxTts", "initializeTts: instanceId=$instanceId, engine=${if (inst.isZipvoice) "zipvoice-c-api" else "kotlin-api"}, sampleRate=$sampleRate, numSpeakers=$numSpeakers")
248
+
249
+ val modelsArray = Arguments.createArray()
250
+ detectedModels?.forEach { modelObj ->
251
+ if (modelObj is HashMap<*, *>) {
252
+ val modelMap = Arguments.createMap()
253
+ modelMap.putString("type", modelObj["type"] as? String ?: "")
254
+ modelMap.putString("modelDir", modelObj["modelDir"] as? String ?: "")
255
+ modelsArray.pushMap(modelMap)
256
+ }
257
+ }
258
+
259
+ inst.ttsInitState = TtsInitState(
260
+ modelDir,
261
+ modelType,
262
+ numThreads.toInt(),
263
+ debug,
264
+ noiseScale?.takeUnless { it.isNaN() },
265
+ noiseScaleW?.takeUnless { it.isNaN() },
266
+ lengthScale?.takeUnless { it.isNaN() },
267
+ ruleFsts?.takeIf { it.isNotBlank() },
268
+ ruleFars?.takeIf { it.isNotBlank() },
269
+ maxNumSentences?.toInt()?.takeIf { it > 0 },
270
+ silenceScale?.takeUnless { it.isNaN() },
271
+ provider?.takeIf { it.isNotBlank() }
272
+ )
273
+
274
+ val resultMap = Arguments.createMap()
275
+ resultMap.putBoolean("success", true)
276
+ resultMap.putArray("detectedModels", modelsArray)
277
+ resultMap.putInt("sampleRate", sampleRate)
278
+ resultMap.putInt("numSpeakers", numSpeakers)
279
+ resolveOnUiThread(promise, resultMap)
280
+ } catch (e: Exception) {
281
+ Log.e("SherpaOnnxTts", "TTS_INIT_ERROR: Failed to initialize TTS: ${e.message}", e)
282
+ rejectOnUiThread(promise, "TTS_INIT_ERROR", "Failed to initialize TTS: ${e.message}", e)
283
+ }
284
+ }
285
+ }
286
+
287
+ fun updateTtsParams(
288
+ instanceId: String,
289
+ noiseScale: Double?,
290
+ noiseScaleW: Double?,
291
+ lengthScale: Double?,
292
+ promise: Promise
293
+ ) {
294
+ val inst = getInstance(instanceId) ?: run {
295
+ Log.e("SherpaOnnxTts", "TTS_UPDATE_ERROR: TTS instance not found: $instanceId")
296
+ promise.reject("TTS_UPDATE_ERROR", "TTS instance not found: $instanceId")
297
+ return
298
+ }
299
+ if (inst.ttsStreamRunning.get()) {
300
+ Log.e("SherpaOnnxTts", "TTS_UPDATE_ERROR: Cannot update params while streaming")
301
+ promise.reject("TTS_UPDATE_ERROR", "Cannot update params while streaming")
302
+ return
303
+ }
304
+ val state = inst.ttsInitState ?: run {
305
+ Log.e("SherpaOnnxTts", "TTS_UPDATE_ERROR: TTS not initialized")
306
+ promise.reject("TTS_UPDATE_ERROR", "TTS not initialized")
307
+ return
308
+ }
309
+
310
+ if (inst.isZipvoice) {
311
+ initializeTts(
312
+ instanceId,
313
+ state.modelDir, state.modelType, state.numThreads.toDouble(), state.debug,
314
+ noiseScale, noiseScaleW, lengthScale,
315
+ state.ruleFsts, state.ruleFars, state.maxNumSentences?.toDouble(), state.silenceScale,
316
+ state.provider,
317
+ promise
318
+ )
319
+ return
320
+ }
321
+
322
+ val nextNoiseScale = when {
323
+ noiseScale == null -> null
324
+ noiseScale.isNaN() -> state.noiseScale
325
+ else -> noiseScale
326
+ }
327
+ val nextNoiseScaleW = when {
328
+ noiseScaleW == null -> null
329
+ noiseScaleW.isNaN() -> state.noiseScaleW
330
+ else -> noiseScaleW
331
+ }
332
+ val nextLengthScale = when {
333
+ lengthScale == null -> null
334
+ lengthScale.isNaN() -> state.lengthScale
335
+ else -> lengthScale
336
+ }
337
+ try {
338
+ val result = detectTtsModel(state.modelDir, state.modelType)
339
+ if (result == null || result["success"] as? Boolean != true) {
340
+ Log.e("SherpaOnnxTts", "TTS_UPDATE_ERROR: Failed to re-detect TTS model")
341
+ promise.reject("TTS_UPDATE_ERROR", "Failed to re-detect TTS model")
342
+ return
343
+ }
344
+ val paths = (result["paths"] as? Map<*, *>)?.mapValues { (_, v) -> (v as? String).orEmpty() }?.mapKeys { it.key.toString() } ?: emptyMap()
345
+ val modelTypeStr = result["modelType"] as? String ?: state.modelType
346
+ val detectedModels = result["detectedModels"] as? ArrayList<*>
347
+
348
+ inst.tts?.release()
349
+ inst.tts = null
350
+ val config = buildTtsConfig(
351
+ paths, modelTypeStr, state.numThreads, state.debug,
352
+ nextNoiseScale, nextNoiseScaleW, nextLengthScale,
353
+ state.ruleFsts, state.ruleFars, state.maxNumSentences, state.silenceScale,
354
+ state.provider
355
+ )
356
+ inst.tts = OfflineTts(config = config)
357
+ val ttsInstance = inst.tts!!
358
+
359
+ val modelsArray = Arguments.createArray()
360
+ detectedModels?.forEach { modelObj ->
361
+ if (modelObj is HashMap<*, *>) {
362
+ val modelMap = Arguments.createMap()
363
+ modelMap.putString("type", modelObj["type"] as? String ?: "")
364
+ modelMap.putString("modelDir", modelObj["modelDir"] as? String ?: "")
365
+ modelsArray.pushMap(modelMap)
366
+ }
367
+ }
368
+
369
+ inst.ttsInitState = state.copy(
370
+ noiseScale = nextNoiseScale,
371
+ noiseScaleW = nextNoiseScaleW,
372
+ lengthScale = nextLengthScale
373
+ )
374
+
375
+ val resultMap = Arguments.createMap()
376
+ resultMap.putBoolean("success", true)
377
+ resultMap.putArray("detectedModels", modelsArray)
378
+ resultMap.putInt("sampleRate", ttsInstance.sampleRate())
379
+ resultMap.putInt("numSpeakers", ttsInstance.numSpeakers())
380
+ promise.resolve(resultMap)
381
+ } catch (e: Exception) {
382
+ Log.e("SherpaOnnxTts", "TTS_UPDATE_ERROR: Failed to update TTS params", e)
383
+ promise.reject("TTS_UPDATE_ERROR", "Failed to update TTS params", e)
384
+ }
385
+ }
386
+
387
+ fun generateTts(instanceId: String, text: String, options: ReadableMap?, promise: Promise) {
388
+ try {
389
+ val inst = getInstance(instanceId) ?: run {
390
+ Log.e("SherpaOnnxTts", "TTS_GENERATE_ERROR: TTS instance not found: $instanceId")
391
+ promise.reject("TTS_GENERATE_ERROR", "TTS instance not found: $instanceId")
392
+ return
393
+ }
394
+ if (!inst.hasEngine()) {
395
+ Log.e("SherpaOnnxTts", "TTS_GENERATE_ERROR: TTS not initialized")
396
+ promise.reject("TTS_GENERATE_ERROR", "TTS not initialized")
397
+ return
398
+ }
399
+ val sid = getSid(options)
400
+ val speed = getSpeed(options)
401
+ val audio = when {
402
+ hasReferenceOptions(options) && inst.isZipvoice -> {
403
+ val refAudio = options?.getArray("referenceAudio")
404
+ ?: run {
405
+ Log.e("SherpaOnnxTts", "TTS_GENERATE_ERROR: referenceAudio required for Zipvoice voice cloning")
406
+ promise.reject("TTS_GENERATE_ERROR", "referenceAudio required for Zipvoice voice cloning")
407
+ return
408
+ }
409
+ val promptSr = if (options.hasKey("referenceSampleRate")) options.getDouble("referenceSampleRate").toInt() else 0
410
+ val promptText = options.getString("referenceText").orEmpty()
411
+ val numSteps = if (options.hasKey("numSteps")) options.getDouble("numSteps").toInt() else 20
412
+ val samples = FloatArray(refAudio.size()) { i -> refAudio.getDouble(i).toFloat() }
413
+ inst.zipvoiceTts!!.generateWithZipvoice(text, promptText, samples, promptSr, speed, numSteps)
414
+ }
415
+ hasReferenceOptions(options) && inst.tts != null -> {
416
+ val config = parseGenerationConfig(options) ?: GenerationConfig(speed = speed, sid = sid)
417
+ inst.tts!!.generateWithConfig(text, config)
418
+ }
419
+ else -> dispatchGenerate(inst, text, sid, speed)
420
+ ?: run {
421
+ Log.e("SherpaOnnxTts", "TTS_GENERATE_ERROR: TTS not initialized")
422
+ promise.reject("TTS_GENERATE_ERROR", "TTS not initialized")
423
+ return
424
+ }
425
+ }
426
+ val map = Arguments.createMap()
427
+ val samplesArray = Arguments.createArray()
428
+ for (sample in audio.samples) {
429
+ samplesArray.pushDouble(sample.toDouble())
430
+ }
431
+ map.putArray("samples", samplesArray)
432
+ map.putInt("sampleRate", audio.sampleRate)
433
+ promise.resolve(map)
434
+ } catch (e: Exception) {
435
+ Log.e("SherpaOnnxTts", "generateTts error: ${e.message}", e)
436
+ promise.reject("TTS_GENERATE_ERROR", e.message ?: "Failed to generate speech", e)
437
+ }
438
+ }
439
+
440
+ fun generateTtsWithTimestamps(instanceId: String, text: String, options: ReadableMap?, promise: Promise) {
441
+ try {
442
+ val inst = getInstance(instanceId) ?: run {
443
+ Log.e("SherpaOnnxTts", "TTS_GENERATE_ERROR: TTS instance not found: $instanceId")
444
+ promise.reject("TTS_GENERATE_ERROR", "TTS instance not found: $instanceId")
445
+ return
446
+ }
447
+ if (!inst.hasEngine()) {
448
+ Log.e("SherpaOnnxTts", "TTS_GENERATE_ERROR: TTS not initialized")
449
+ promise.reject("TTS_GENERATE_ERROR", "TTS not initialized")
450
+ return
451
+ }
452
+ val sid = getSid(options)
453
+ val speed = getSpeed(options)
454
+ val audio = when {
455
+ hasReferenceOptions(options) && inst.isZipvoice -> {
456
+ val refAudio = options?.getArray("referenceAudio")
457
+ ?: run {
458
+ Log.e("SherpaOnnxTts", "TTS_GENERATE_ERROR: referenceAudio required for Zipvoice voice cloning")
459
+ promise.reject("TTS_GENERATE_ERROR", "referenceAudio required for Zipvoice voice cloning")
460
+ return
461
+ }
462
+ val promptSr = if (options.hasKey("referenceSampleRate")) options.getDouble("referenceSampleRate").toInt() else 0
463
+ val promptText = options.getString("referenceText").orEmpty()
464
+ val numSteps = if (options.hasKey("numSteps")) options.getDouble("numSteps").toInt() else 20
465
+ val samples = FloatArray(refAudio.size()) { i -> refAudio.getDouble(i).toFloat() }
466
+ inst.zipvoiceTts!!.generateWithZipvoice(text, promptText, samples, promptSr, speed, numSteps)
467
+ }
468
+ hasReferenceOptions(options) && inst.tts != null -> {
469
+ val config = parseGenerationConfig(options) ?: GenerationConfig(speed = speed, sid = sid)
470
+ inst.tts!!.generateWithConfig(text, config)
471
+ }
472
+ else -> dispatchGenerate(inst, text, sid, speed)
473
+ ?: run {
474
+ Log.e("SherpaOnnxTts", "TTS_GENERATE_ERROR: TTS not initialized")
475
+ promise.reject("TTS_GENERATE_ERROR", "TTS not initialized")
476
+ return
477
+ }
478
+ }
479
+ val map = Arguments.createMap()
480
+ val samplesArray = Arguments.createArray()
481
+ for (sample in audio.samples) {
482
+ samplesArray.pushDouble(sample.toDouble())
483
+ }
484
+ map.putArray("samples", samplesArray)
485
+ map.putInt("sampleRate", audio.sampleRate)
486
+ val subtitlesArray = Arguments.createArray()
487
+ if (audio.samples.isNotEmpty() && audio.sampleRate > 0) {
488
+ val durationSec = audio.samples.size.toDouble() / audio.sampleRate
489
+ val subtitleMap = Arguments.createMap()
490
+ subtitleMap.putString("text", text)
491
+ subtitleMap.putDouble("start", 0.0)
492
+ subtitleMap.putDouble("end", durationSec)
493
+ subtitlesArray.pushMap(subtitleMap)
494
+ }
495
+ map.putArray("subtitles", subtitlesArray)
496
+ map.putBoolean("estimated", true)
497
+ promise.resolve(map)
498
+ } catch (e: Exception) {
499
+ Log.e("SherpaOnnxTts", "TTS_GENERATE_ERROR: ${e.message ?: "Failed to generate speech"}", e)
500
+ promise.reject("TTS_GENERATE_ERROR", e.message ?: "Failed to generate speech", e)
501
+ }
502
+ }
503
+
504
+ fun generateTtsStream(instanceId: String, text: String, options: ReadableMap?, promise: Promise) {
505
+ val inst = getInstance(instanceId) ?: run {
506
+ Log.e("SherpaOnnxTts", "TTS_STREAM_ERROR: TTS instance not found: $instanceId")
507
+ promise.reject("TTS_STREAM_ERROR", "TTS instance not found: $instanceId")
508
+ return
509
+ }
510
+ if (inst.ttsStreamRunning.get()) {
511
+ Log.e("SherpaOnnxTts", "TTS_STREAM_ERROR: TTS streaming already in progress")
512
+ promise.reject("TTS_STREAM_ERROR", "TTS streaming already in progress")
513
+ return
514
+ }
515
+ if (!inst.hasEngine()) {
516
+ Log.e("SherpaOnnxTts", "TTS_STREAM_ERROR: TTS not initialized")
517
+ promise.reject("TTS_STREAM_ERROR", "TTS not initialized")
518
+ return
519
+ }
520
+ if (hasReferenceOptions(options) && inst.isZipvoice) {
521
+ Log.e("SherpaOnnxTts", "TTS_STREAM_ERROR: Streaming with reference audio not supported for Zipvoice")
522
+ promise.reject("TTS_STREAM_ERROR", "Streaming with reference audio not supported for Zipvoice")
523
+ return
524
+ }
525
+ val sid = getSid(options)
526
+ val speed = getSpeed(options)
527
+ inst.ttsStreamCancelled.set(false)
528
+ inst.ttsStreamRunning.set(true)
529
+ inst.ttsStreamThread = Thread {
530
+ try {
531
+ val sampleRate = dispatchSampleRate(inst)
532
+ when {
533
+ hasReferenceOptions(options) && inst.tts != null -> {
534
+ val config = parseGenerationConfig(options) ?: GenerationConfig(speed = speed, sid = sid)
535
+ inst.tts!!.generateWithConfigAndCallback(text, config) { chunk ->
536
+ if (inst.ttsStreamCancelled.get()) return@generateWithConfigAndCallback 0
537
+ emitChunk(instanceId, chunk, sampleRate, 0f, false)
538
+ chunk.size
539
+ }
540
+ }
541
+ inst.zipvoiceTts != null -> {
542
+ inst.zipvoiceTts!!.generateWithCallback(text, sid, speed) { chunk ->
543
+ if (inst.ttsStreamCancelled.get()) return@generateWithCallback 0
544
+ emitChunk(instanceId, chunk, sampleRate, 0f, false)
545
+ chunk.size
546
+ }
547
+ }
548
+ else -> {
549
+ inst.tts!!.generateWithCallback(text, sid, speed) { chunk ->
550
+ if (inst.ttsStreamCancelled.get()) return@generateWithCallback 0
551
+ emitChunk(instanceId, chunk, sampleRate, 0f, false)
552
+ chunk.size
553
+ }
554
+ }
555
+ }
556
+ if (!inst.ttsStreamCancelled.get()) {
557
+ emitChunk(instanceId, FloatArray(0), sampleRate, 1f, true)
558
+ }
559
+ } catch (e: Exception) {
560
+ if (!inst.ttsStreamCancelled.get()) {
561
+ emitError(instanceId, "TTS streaming failed: ${e.message}")
562
+ }
563
+ } finally {
564
+ emitEnd(instanceId, inst.ttsStreamCancelled.get())
565
+ inst.ttsStreamRunning.set(false)
566
+ }
567
+ }
568
+ inst.ttsStreamThread?.start()
569
+ promise.resolve(null)
570
+ }
571
+
572
+ fun cancelTtsStream(instanceId: String, promise: Promise) {
573
+ val inst = getInstance(instanceId)
574
+ if (inst != null) {
575
+ inst.ttsStreamCancelled.set(true)
576
+ inst.ttsStreamThread?.interrupt()
577
+ }
578
+ promise.resolve(null)
579
+ }
580
+
581
+ fun startTtsPcmPlayer(instanceId: String, sampleRate: Double, channels: Double, promise: Promise) {
582
+ val inst = getInstance(instanceId) ?: run {
583
+ Log.e("SherpaOnnxTts", "TTS_PCM_ERROR: TTS instance not found: $instanceId")
584
+ promise.reject("TTS_PCM_ERROR", "TTS instance not found: $instanceId")
585
+ return
586
+ }
587
+ try {
588
+ if (Build.VERSION.SDK_INT < Build.VERSION_CODES.LOLLIPOP) {
589
+ Log.e("SherpaOnnxTts", "TTS_PCM_ERROR: PCM playback requires API 21+")
590
+ promise.reject("TTS_PCM_ERROR", "PCM playback requires API 21+")
591
+ return
592
+ }
593
+ if (channels.toInt() != 1) {
594
+ Log.e("SherpaOnnxTts", "TTS_PCM_ERROR: PCM playback supports mono only")
595
+ promise.reject("TTS_PCM_ERROR", "PCM playback supports mono only")
596
+ return
597
+ }
598
+ inst.stopPcmPlayer()
599
+ val channelConfig = AudioFormat.CHANNEL_OUT_MONO
600
+ val audioFormat = AudioFormat.Builder()
601
+ .setSampleRate(sampleRate.toInt())
602
+ .setChannelMask(channelConfig)
603
+ .setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
604
+ .build()
605
+ val minBufferSize = AudioTrack.getMinBufferSize(sampleRate.toInt(), channelConfig, AudioFormat.ENCODING_PCM_FLOAT)
606
+ if (minBufferSize == AudioTrack.ERROR || minBufferSize == AudioTrack.ERROR_BAD_VALUE) {
607
+ Log.e("SherpaOnnxTts", "TTS_PCM_ERROR: Invalid buffer size for PCM player")
608
+ promise.reject("TTS_PCM_ERROR", "Invalid buffer size for PCM player")
609
+ return
610
+ }
611
+ val attributes = AudioAttributes.Builder()
612
+ .setUsage(AudioAttributes.USAGE_MEDIA)
613
+ .setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
614
+ .build()
615
+ inst.ttsPcmTrack = AudioTrack(attributes, audioFormat, minBufferSize, AudioTrack.MODE_STREAM, AudioManager.AUDIO_SESSION_ID_GENERATE)
616
+ inst.ttsPcmTrack?.play()
617
+ promise.resolve(null)
618
+ } catch (e: Exception) {
619
+ Log.e("SherpaOnnxTts", "TTS_PCM_ERROR: Failed to start PCM player", e)
620
+ promise.reject("TTS_PCM_ERROR", "Failed to start PCM player", e)
621
+ }
622
+ }
623
+
624
+ fun writeTtsPcmChunk(instanceId: String, samples: ReadableArray, promise: Promise) {
625
+ val inst = getInstance(instanceId) ?: run {
626
+ Log.e("SherpaOnnxTts", "TTS_PCM_ERROR: TTS instance not found: $instanceId")
627
+ promise.reject("TTS_PCM_ERROR", "TTS instance not found: $instanceId")
628
+ return
629
+ }
630
+ val track = inst.ttsPcmTrack ?: run {
631
+ Log.e("SherpaOnnxTts", "TTS_PCM_ERROR: PCM player not initialized")
632
+ promise.reject("TTS_PCM_ERROR", "PCM player not initialized")
633
+ return
634
+ }
635
+ try {
636
+ val buffer = FloatArray(samples.size())
637
+ for (i in 0 until samples.size()) {
638
+ buffer[i] = samples.getDouble(i).toFloat()
639
+ }
640
+ val written = track.write(buffer, 0, buffer.size, AudioTrack.WRITE_BLOCKING)
641
+ if (written < 0) {
642
+ Log.e("SherpaOnnxTts", "TTS_PCM_ERROR: PCM write failed: $written")
643
+ promise.reject("TTS_PCM_ERROR", "PCM write failed: $written")
644
+ return
645
+ }
646
+ promise.resolve(null)
647
+ } catch (e: Exception) {
648
+ Log.e("SherpaOnnxTts", "TTS_PCM_ERROR: Failed to write PCM chunk", e)
649
+ promise.reject("TTS_PCM_ERROR", "Failed to write PCM chunk", e)
650
+ }
651
+ }
652
+
653
+ fun stopTtsPcmPlayer(instanceId: String, promise: Promise) {
654
+ try {
655
+ getInstance(instanceId)?.stopPcmPlayer()
656
+ promise.resolve(null)
657
+ } catch (e: Exception) {
658
+ Log.e("SherpaOnnxTts", "TTS_PCM_ERROR: Failed to stop PCM player", e)
659
+ promise.reject("TTS_PCM_ERROR", "Failed to stop PCM player", e)
660
+ }
661
+ }
662
+
663
+ fun getTtsSampleRate(instanceId: String, promise: Promise) {
664
+ try {
665
+ val inst = getInstance(instanceId) ?: run {
666
+ Log.e("SherpaOnnxTts", "TTS_ERROR: TTS instance not found: $instanceId")
667
+ promise.reject("TTS_ERROR", "TTS instance not found: $instanceId")
668
+ return
669
+ }
670
+ if (!inst.hasEngine()) {
671
+ Log.e("SherpaOnnxTts", "TTS_ERROR: TTS not initialized")
672
+ promise.reject("TTS_ERROR", "TTS not initialized")
673
+ return
674
+ }
675
+ promise.resolve(dispatchSampleRate(inst).toDouble())
676
+ } catch (e: Exception) {
677
+ Log.e("SherpaOnnxTts", "TTS_ERROR: Failed to get sample rate", e)
678
+ promise.reject("TTS_ERROR", "Failed to get sample rate", e)
679
+ }
680
+ }
681
+
682
+ fun getTtsNumSpeakers(instanceId: String, promise: Promise) {
683
+ try {
684
+ val inst = getInstance(instanceId) ?: run {
685
+ Log.e("SherpaOnnxTts", "TTS_ERROR: TTS instance not found: $instanceId")
686
+ promise.reject("TTS_ERROR", "TTS instance not found: $instanceId")
687
+ return
688
+ }
689
+ if (!inst.hasEngine()) {
690
+ Log.e("SherpaOnnxTts", "TTS_ERROR: TTS not initialized")
691
+ promise.reject("TTS_ERROR", "TTS not initialized")
692
+ return
693
+ }
694
+ promise.resolve(dispatchNumSpeakers(inst).toDouble())
695
+ } catch (e: Exception) {
696
+ Log.e("SherpaOnnxTts", "TTS_ERROR: Failed to get number of speakers", e)
697
+ promise.reject("TTS_ERROR", "Failed to get number of speakers", e)
698
+ }
699
+ }
700
+
701
+ fun unloadTts(instanceId: String, promise: Promise) {
702
+ try {
703
+ val inst = instances.remove(instanceId)
704
+ if (inst != null) {
705
+ inst.stopPcmPlayer()
706
+ inst.releaseEngines()
707
+ }
708
+ promise.resolve(null)
709
+ } catch (e: Exception) {
710
+ Log.e("SherpaOnnxTts", "TTS_RELEASE_ERROR: Failed to release TTS resources", e)
711
+ promise.reject("TTS_RELEASE_ERROR", "Failed to release TTS resources", e)
712
+ }
713
+ }
714
+
715
+ fun saveTtsAudioToFile(
716
+ samples: ReadableArray,
717
+ sampleRate: Double,
718
+ filePath: String,
719
+ promise: Promise
720
+ ) {
721
+ try {
722
+ val samplesArray = FloatArray(samples.size())
723
+ for (i in 0 until samples.size()) {
724
+ samplesArray[i] = samples.getDouble(i).toFloat()
725
+ }
726
+ val success = GeneratedAudio(samplesArray, sampleRate.toInt()).save(filePath)
727
+ if (success) {
728
+ promise.resolve(filePath)
729
+ } else {
730
+ Log.e("SherpaOnnxTts", "TTS_SAVE_ERROR: Failed to save audio to file")
731
+ promise.reject("TTS_SAVE_ERROR", "Failed to save audio to file")
732
+ }
733
+ } catch (e: Exception) {
734
+ Log.e("SherpaOnnxTts", "TTS_SAVE_ERROR: Failed to save audio to file", e)
735
+ promise.reject("TTS_SAVE_ERROR", "Failed to save audio to file", e)
736
+ }
737
+ }
738
+
739
+ fun saveTtsAudioToContentUri(
740
+ samples: ReadableArray,
741
+ sampleRate: Double,
742
+ directoryUri: String,
743
+ filename: String,
744
+ promise: Promise
745
+ ) {
746
+ try {
747
+ val samplesArray = FloatArray(samples.size())
748
+ for (i in 0 until samples.size()) {
749
+ samplesArray[i] = samples.getDouble(i).toFloat()
750
+ }
751
+ val resolver = context.contentResolver
752
+ val dirUri = Uri.parse(directoryUri)
753
+ val fileUri = createDocumentInDirectory(resolver, dirUri, filename, "audio/wav")
754
+ resolver.openOutputStream(fileUri, "w")?.use { outputStream ->
755
+ writeWavToStream(samplesArray, sampleRate.toInt(), outputStream)
756
+ } ?: throw IllegalStateException("Failed to open output stream for URI: $fileUri")
757
+ promise.resolve(fileUri.toString())
758
+ } catch (e: Exception) {
759
+ Log.e("SherpaOnnxTts", "TTS_SAVE_ERROR: Failed to save audio to content URI", e)
760
+ promise.reject("TTS_SAVE_ERROR", "Failed to save audio to content URI", e)
761
+ }
762
+ }
763
+
764
/**
 * Write [text] (UTF-8) as a new document named [filename] with [mimeType] inside the SAF
 * directory [directoryUri] and resolve [promise] with the new document's URI string.
 *
 * If the document gets created but writing it fails, the document is deleted (best effort) so no
 * empty or partial file is left behind; the promise is then rejected with `TTS_SAVE_ERROR`.
 */
fun saveTtsTextToContentUri(
  text: String,
  directoryUri: String,
  filename: String,
  mimeType: String,
  promise: Promise
) {
  // Non-null only between document creation and a successful write; if still set in the catch
  // block, the document is an orphan that should be removed.
  var orphanUri: Uri? = null
  try {
    val resolver = context.contentResolver
    val dirUri = Uri.parse(directoryUri)
    val fileUri = createDocumentInDirectory(resolver, dirUri, filename, mimeType)
    orphanUri = fileUri
    resolver.openOutputStream(fileUri, "w")?.use { outputStream ->
      outputStream.write(text.toByteArray(Charsets.UTF_8))
    } ?: throw IllegalStateException("Failed to open output stream for URI: $fileUri")
    orphanUri = null
    promise.resolve(fileUri.toString())
  } catch (e: Exception) {
    orphanUri?.let { uri ->
      runCatching { DocumentsContract.deleteDocument(context.contentResolver, uri) }
    }
    Log.e("SherpaOnnxTts", "TTS_SAVE_ERROR: Failed to save text to content URI", e)
    promise.reject("TTS_SAVE_ERROR", "Failed to save text to content URI", e)
  }
}
784
+
785
/**
 * Copy the content at [fileUri] into the app cache directory as [filename] and resolve [promise]
 * with the cached file's absolute path. An existing cache file with that name is overwritten.
 *
 * [filename] comes from JS and must resolve to a location inside the cache directory. Names that
 * escape it (for example via `..`) are rejected with `TTS_SAVE_ERROR` rather than written
 * elsewhere in app storage.
 */
fun copyTtsContentUriToCache(fileUri: String, filename: String, promise: Promise) {
  try {
    val resolver = context.contentResolver
    val uri = Uri.parse(fileUri)
    val cacheDir = context.cacheDir
    val cacheFile = File(cacheDir, filename)
    // Path-traversal guard: compare canonical paths so `..` segments and symlinks are resolved.
    require(cacheFile.canonicalPath.startsWith(cacheDir.canonicalPath + File.separator)) {
      "Invalid cache filename: $filename"
    }
    resolver.openInputStream(uri)?.use { inputStream ->
      FileOutputStream(cacheFile).use { outputStream ->
        copyStream(inputStream, outputStream)
      }
    } ?: throw IllegalStateException("Failed to open input stream for URI: $fileUri")
    promise.resolve(cacheFile.absolutePath)
  } catch (e: Exception) {
    Log.e("SherpaOnnxTts", "TTS_SAVE_ERROR: Failed to copy audio to cache", e)
    promise.reject("TTS_SAVE_ERROR", "Failed to copy audio to cache", e)
  }
}
801
+
802
/**
 * Open the system share sheet for an audio file and resolve [promise] with null once the chooser
 * has been launched.
 *
 * [fileUri] may be a `content://` URI (shared as-is), a `file://` URI, or a bare filesystem path.
 * The last two are exposed through the app's `<packageName>.fileprovider` FileProvider. Any
 * failure rejects with `TTS_SHARE_ERROR`.
 */
fun shareTtsAudio(fileUri: String, mimeType: String, promise: Promise) {
  try {
    val streamUri: Uri
    if (fileUri.startsWith("content://")) {
      streamUri = Uri.parse(fileUri)
    } else {
      val localPath = when {
        fileUri.startsWith("file://") -> try {
          Uri.parse(fileUri).path ?: fileUri.replaceFirst("file://", "")
        } catch (_: Exception) {
          fileUri.replaceFirst("file://", "")
        }
        else -> fileUri
      }
      streamUri = FileProvider.getUriForFile(
        context,
        "${context.packageName}.fileprovider",
        File(localPath)
      )
    }

    val sendIntent = Intent(Intent.ACTION_SEND)
    sendIntent.type = mimeType
    sendIntent.putExtra(Intent.EXTRA_STREAM, streamUri)
    sendIntent.addFlags(Intent.FLAG_GRANT_READ_URI_PERMISSION)

    // Launched from a non-Activity context, so the chooser needs its own task.
    val chooserIntent = Intent.createChooser(sendIntent, "Share audio")
    chooserIntent.addFlags(Intent.FLAG_ACTIVITY_NEW_TASK)
    context.startActivity(chooserIntent)
    promise.resolve(null)
  } catch (e: Exception) {
    Log.e("SherpaOnnxTts", "TTS_SHARE_ERROR: Failed to share audio", e)
    promise.reject("TTS_SHARE_ERROR", "Failed to share audio", e)
  }
}
834
+
835
+ // -- Dual-engine dispatch helpers --
836
+
837
/**
 * True if [options] carry voice-cloning reference data: a non-empty `referenceAudio` array or a
 * non-empty `referenceText` string. Returns false for null options.
 *
 * Each key is checked with `hasKey` before it is read. Native-backed ReadableMaps throw
 * `NoSuchKeyException` from `getArray`/`getString` for absent keys, so plain `{ sid, speed }`
 * options would otherwise fail.
 */
private fun hasReferenceOptions(options: ReadableMap?): Boolean {
  if (options == null) return false
  val refAudio = if (options.hasKey("referenceAudio")) options.getArray("referenceAudio") else null
  val refText = if (options.hasKey("referenceText")) options.getString("referenceText") else null
  return (refAudio != null && refAudio.size() > 0) || !refText.isNullOrEmpty()
}
844
+
845
/**
 * Speaker id from `options.sid` (a JS number truncated to Int).
 *
 * Defaults to 0 when [options] is null, the key is absent, or its value is an explicit null.
 * `getDouble` cannot read a null value.
 */
private fun getSid(options: ReadableMap?): Int =
  if (options != null && options.hasKey("sid") && !options.isNull("sid")) {
    options.getDouble("sid").toInt()
  } else {
    0
  }
848
+
849
/**
 * Speech speed factor from `options.speed` (a JS number narrowed to Float).
 *
 * Defaults to 1.0 when [options] is null, the key is absent, or its value is an explicit null.
 * `getDouble` cannot read a null value.
 */
private fun getSpeed(options: ReadableMap?): Float =
  if (options != null && options.hasKey("speed") && !options.isNull("speed")) {
    options.getDouble("speed").toFloat()
  } else {
    1.0f
  }
851
+
852
/**
 * Build a Kotlin [GenerationConfig] from JS generation options.
 *
 * Returns null only when [options] is null. Otherwise every field is populated. When a key is
 * absent or null these defaults apply:
 * - silenceScale: 0.2
 * - numSteps: 5
 * - referenceSampleRate: 0
 * - referenceAudio, referenceText, extra: null
 *
 * `sid` and `speed` come from [getSid]/[getSpeed]. JS numbers arrive as Double and are narrowed to
 * Float/Int. `extra` values are read with `getString`, so they are expected to be strings; a null
 * string value becomes "".
 */
private fun parseGenerationConfig(options: ReadableMap?): GenerationConfig? {
  if (options == null) return null

  // Guard every read. ReadableMap getters can throw for absent keys (NoSuchKeyException on native
  // maps), and getDouble cannot read an explicit null.
  fun present(key: String): Boolean = options.hasKey(key) && !options.isNull(key)

  val refAudio = if (present("referenceAudio")) options.getArray("referenceAudio") else null
  val refSampleRate =
    if (present("referenceSampleRate")) options.getDouble("referenceSampleRate").toInt() else 0
  val refText = if (present("referenceText")) options.getString("referenceText") else null
  val silenceScale =
    if (present("silenceScale")) options.getDouble("silenceScale").toFloat() else 0.2f
  val speed = getSpeed(options)
  val sid = getSid(options)
  val numSteps = if (present("numSteps")) options.getDouble("numSteps").toInt() else 5
  val extraMap = if (present("extra")) {
    options.getMap("extra")?.let { map ->
      val keys = map.keySetIterator()
      buildMap<String, String> {
        while (keys.hasNextKey()) {
          val key = keys.nextKey()
          put(key, map.getString(key).orEmpty())
        }
      }
    }
  } else {
    null
  }
  val refAudioFloat = refAudio?.let { arr ->
    FloatArray(arr.size()) { i -> arr.getDouble(i).toFloat() }
  }
  return GenerationConfig(
    silenceScale = silenceScale,
    speed = speed,
    sid = sid,
    referenceAudio = refAudioFloat,
    referenceSampleRate = refSampleRate,
    referenceText = refText,
    numSteps = numSteps,
    extra = extraMap
  )
}
885
+
886
/**
 * Run synthesis on the engine loaded in [inst], preferring the zipvoice engine when present.
 * Returns null when neither engine is loaded.
 */
private fun dispatchGenerate(inst: TtsEngineInstance, text: String, sid: Int, speed: Float): GeneratedAudio? {
  val zipvoice = inst.zipvoiceTts
  if (zipvoice != null) {
    return zipvoice.generate(text, sid, speed)
  }
  val offline = inst.tts ?: return null
  return offline.generate(text, sid, speed)
}
892
+
893
/** Output sample rate of the active engine (zipvoice first), or 0 when no engine is loaded. */
private fun dispatchSampleRate(inst: TtsEngineInstance): Int =
  inst.zipvoiceTts?.sampleRate() ?: inst.tts?.sampleRate() ?: 0
897
+
898
/** Speaker count reported by the active engine (zipvoice first), or 0 when no engine is loaded. */
private fun dispatchNumSpeakers(inst: TtsEngineInstance): Int =
  inst.zipvoiceTts?.numSpeakers() ?: inst.tts?.numSpeakers() ?: 0
902
+
903
/** The path stored under [key] in [paths], or "" when the key is absent. */
private fun path(paths: Map<String, String>, key: String): String = paths.getOrElse(key) { "" }
904
+
905
/**
 * Build the Kotlin-API [OfflineTtsConfig] for a non-zipvoice model.
 *
 * @param paths model file paths keyed by role ("ttsModel", "tokens", "lexicon", "dataDir",
 *   "voices", "acousticModel", "vocoder", and the pocket keys); a missing key becomes "".
 * @param modelType one of "vits", "matcha", "kokoro", "kitten", "pocket". Any other value is
 *   auto-detected from [paths]: matcha if "acousticModel" is set, else kokoro if "voices" is set,
 *   else vits.
 *
 * Defaults:
 * - noiseScale 0.667, noiseScaleW 0.8, lengthScale 1.0 when null.
 * - provider "cpu" when null or blank; ruleFsts/ruleFars "" when null or blank.
 * - maxNumSentences 1 when null; values below 1 are raised to 1.
 * - silenceScale 0.2 when null; otherwise clamped to [0, 10].
 *
 * @throws IllegalStateException for "zipvoice", which is served by ZipvoiceTtsWrapper.
 */
private fun buildTtsConfig(
  paths: Map<String, String>,
  modelType: String,
  numThreads: Int,
  debug: Boolean,
  noiseScale: Double?,
  noiseScaleW: Double?,
  lengthScale: Double?,
  ruleFsts: String?,
  ruleFars: String?,
  maxNumSentences: Int?,
  silenceScale: Double?,
  provider: String?
): OfflineTtsConfig {
  val ns = noiseScale?.toFloat() ?: 0.667f
  val nsw = noiseScaleW?.toFloat() ?: 0.8f
  val ls = lengthScale?.toFloat() ?: 1.0f
  val prov = provider?.takeIf { it.isNotBlank() } ?: "cpu"

  // One builder per family. The explicit branches and the auto-detect fallback share these, so
  // the two paths cannot drift apart.
  fun vitsConfig() = OfflineTtsModelConfig(
    vits = OfflineTtsVitsModelConfig(
      model = path(paths, "ttsModel"),
      lexicon = path(paths, "lexicon"),
      tokens = path(paths, "tokens"),
      dataDir = path(paths, "dataDir"),
      noiseScale = ns,
      noiseScaleW = nsw,
      lengthScale = ls
    ),
    numThreads = numThreads,
    debug = debug,
    provider = prov
  )

  fun matchaConfig() = OfflineTtsModelConfig(
    matcha = OfflineTtsMatchaModelConfig(
      acousticModel = path(paths, "acousticModel"),
      vocoder = path(paths, "vocoder"),
      lexicon = path(paths, "lexicon"),
      tokens = path(paths, "tokens"),
      dataDir = path(paths, "dataDir"),
      noiseScale = ns,
      lengthScale = ls
    ),
    numThreads = numThreads,
    debug = debug,
    provider = prov
  )

  fun kokoroConfig() = OfflineTtsModelConfig(
    kokoro = OfflineTtsKokoroModelConfig(
      model = path(paths, "ttsModel"),
      voices = path(paths, "voices"),
      tokens = path(paths, "tokens"),
      dataDir = path(paths, "dataDir"),
      lexicon = path(paths, "lexicon"),
      lengthScale = ls
    ),
    numThreads = numThreads,
    debug = debug,
    provider = prov
  )

  val modelConfig = when (modelType) {
    "vits" -> vitsConfig()
    "matcha" -> matchaConfig()
    "kokoro" -> kokoroConfig()
    "kitten" -> OfflineTtsModelConfig(
      kitten = OfflineTtsKittenModelConfig(
        model = path(paths, "ttsModel"),
        voices = path(paths, "voices"),
        tokens = path(paths, "tokens"),
        dataDir = path(paths, "dataDir"),
        lengthScale = ls
      ),
      numThreads = numThreads,
      debug = debug,
      provider = prov
    )
    "pocket" -> OfflineTtsModelConfig(
      pocket = OfflineTtsPocketModelConfig(
        lmFlow = path(paths, "lmFlow"),
        lmMain = path(paths, "lmMain"),
        encoder = path(paths, "encoder"),
        decoder = path(paths, "decoder"),
        textConditioner = path(paths, "textConditioner"),
        vocabJson = path(paths, "vocabJson"),
        tokenScoresJson = path(paths, "tokenScoresJson")
      ),
      numThreads = numThreads,
      debug = debug,
      provider = prov
    )
    // Zipvoice is handled by ZipvoiceTtsWrapper (C-API), not OfflineTts (Kotlin API). Callers are
    // expected to branch on "zipvoice" before calling this function.
    "zipvoice" -> throw IllegalStateException(
      "buildTtsConfig should not be called for zipvoice models. Use ZipvoiceTtsWrapper instead."
    )
    // Unknown or empty model type: infer the family from which files were supplied.
    else -> when {
      path(paths, "acousticModel").isNotEmpty() -> matchaConfig()
      path(paths, "voices").isNotEmpty() -> kokoroConfig()
      else -> vitsConfig()
    }
  }
  return OfflineTtsConfig(
    model = modelConfig,
    ruleFsts = ruleFsts?.takeIf { it.isNotBlank() } ?: "",
    ruleFars = ruleFars?.takeIf { it.isNotBlank() } ?: "",
    maxNumSentences = maxNumSentences?.coerceAtLeast(1) ?: 1,
    silenceScale = silenceScale?.toFloat()?.coerceIn(0f, 10f) ?: 0.2f
  )
}
1055
+
1056
/**
 * Create a new document named [filename] with [mimeType] inside [directoryUri] and return its URI.
 *
 * Accepted forms of [directoryUri]:
 * - a bare SAF tree URI (`.../tree/<root>`): the document is created in the tree root;
 * - a tree-based document URI (`.../tree/<root>/document/<dir>`): the document is created in
 *   that sub-directory;
 * - a plain directory document URI: passed to createDocument directly.
 *
 * @throws IllegalStateException if the provider returns null from createDocument.
 */
private fun createDocumentInDirectory(
  resolver: android.content.ContentResolver,
  directoryUri: Uri,
  filename: String,
  mimeType: String
): Uri {
  return if (DocumentsContract.isTreeUri(directoryUri)) {
    // Prefer the document id embedded in the URI (a sub-directory within the tree). A bare tree
    // URI has none, so getDocumentId throws and we fall back to the tree root's id.
    val documentId = runCatching { DocumentsContract.getDocumentId(directoryUri) }
      .getOrElse { DocumentsContract.getTreeDocumentId(directoryUri) }
    val dirDocUri = DocumentsContract.buildDocumentUriUsingTree(directoryUri, documentId)
    DocumentsContract.createDocument(resolver, dirDocUri, mimeType, filename)
      ?: throw IllegalStateException("Failed to create document in tree URI")
  } else {
    DocumentsContract.createDocument(resolver, directoryUri, mimeType, filename)
      ?: throw IllegalStateException("Failed to create document in directory URI")
  }
}
1072
+
1073
/**
 * Write [samples] to [outputStream] as a canonical 44-byte-header RIFF/WAVE file: PCM, mono,
 * 16-bit little-endian, at [sampleRate] Hz. The stream is flushed but not closed.
 *
 * Samples are clamped to [-1, 1] and scaled by 32767, so -1.0 maps to -32767 and NaN maps to 0.
 *
 * The whole file is assembled in one buffer and written with a single call. The target is often an
 * unbuffered ContentResolver/file stream, where per-byte writes cost a syscall each.
 */
private fun writeWavToStream(samples: FloatArray, sampleRate: Int, outputStream: OutputStream) {
  val numChannels = 1
  val bitsPerSample = 16
  val byteRate = sampleRate * numChannels * bitsPerSample / 8
  val blockAlign = numChannels * bitsPerSample / 8
  val dataSize = samples.size * 2
  val chunkSize = 36 + dataSize
  val headerSize = 44

  val buffer = java.nio.ByteBuffer.allocate(headerSize + dataSize)
    .order(java.nio.ByteOrder.LITTLE_ENDIAN)
  buffer.put("RIFF".toByteArray(Charsets.US_ASCII))
  buffer.putInt(chunkSize)
  buffer.put("WAVE".toByteArray(Charsets.US_ASCII))
  buffer.put("fmt ".toByteArray(Charsets.US_ASCII))
  buffer.putInt(16) // fmt chunk size for PCM
  buffer.putShort(1.toShort()) // audio format 1 = PCM
  buffer.putShort(numChannels.toShort())
  buffer.putInt(sampleRate)
  buffer.putInt(byteRate)
  buffer.putShort(blockAlign.toShort())
  buffer.putShort(bitsPerSample.toShort())
  buffer.put("data".toByteArray(Charsets.US_ASCII))
  buffer.putInt(dataSize)
  for (sample in samples) {
    val clamped = sample.coerceIn(-1.0f, 1.0f)
    buffer.putShort((clamped * 32767.0f).toInt().toShort())
  }
  outputStream.write(buffer.array(), 0, buffer.position())
  outputStream.flush()
}
1100
+
1101
/** Emit [value] to [outputStream] as four bytes, least-significant byte first. */
private fun writeIntLE(outputStream: OutputStream, value: Int) {
  for (shift in 0..24 step 8) {
    outputStream.write((value ushr shift) and 0xFF)
  }
}
1107
+
1108
/** Emit [value] to [outputStream] as two bytes, low byte first. */
private fun writeShortLE(outputStream: OutputStream, value: Short) {
  val bits = value.toInt()
  val low = bits and 0xFF
  val high = (bits ushr 8) and 0xFF
  outputStream.write(low)
  outputStream.write(high)
}
1113
+
1114
/**
 * Copy every remaining byte of [inputStream] into [outputStream] using an 8 KiB buffer, then
 * flush [outputStream]. Neither stream is closed.
 */
private fun copyStream(inputStream: InputStream, outputStream: OutputStream) {
  inputStream.copyTo(outputStream, bufferSize = 8192)
  outputStream.flush()
}
1123
+ }