react-native-sherpa-onnx 0.3.0 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66) hide show
  1. package/README.md +21 -7
  2. package/SherpaOnnx.podspec +1 -1
  3. package/android/build.gradle +35 -26
  4. package/android/prebuilt-download.gradle +27 -14
  5. package/android/src/main/cpp/CMakeLists.txt +51 -17
  6. package/android/src/main/cpp/jni/archive/sherpa-onnx-archive-helper.cpp +14 -0
  7. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-helper.cpp +16 -0
  8. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-helper.h +3 -0
  9. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-stt.cpp +19 -2
  10. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect.h +2 -1
  11. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-stt-wrapper.cpp +1 -0
  12. package/android/src/main/java/com/sherpaonnx/SherpaOnnxModule.kt +114 -8
  13. package/android/src/main/java/com/sherpaonnx/SherpaOnnxOnlineSttHelper.kt +535 -0
  14. package/android/src/main/java/com/sherpaonnx/SherpaOnnxTtsHelper.kt +10 -10
  15. package/ios/SherpaOnnx+OnlineSTT.mm +365 -0
  16. package/ios/SherpaOnnx+TTS.mm +35 -9
  17. package/ios/SherpaOnnx.mm +6 -0
  18. package/ios/model_detect/sherpa-onnx-model-detect-helper.h +3 -0
  19. package/ios/model_detect/sherpa-onnx-model-detect-helper.mm +16 -0
  20. package/ios/model_detect/sherpa-onnx-model-detect-stt.mm +19 -2
  21. package/ios/model_detect/sherpa-onnx-model-detect.h +2 -1
  22. package/ios/online_stt/sherpa-onnx-online-stt-wrapper.h +85 -0
  23. package/ios/online_stt/sherpa-onnx-online-stt-wrapper.mm +270 -0
  24. package/lib/module/NativeSherpaOnnx.js.map +1 -1
  25. package/lib/module/index.js +2 -2
  26. package/lib/module/stt/index.js +4 -0
  27. package/lib/module/stt/index.js.map +1 -1
  28. package/lib/module/stt/streaming.js +257 -0
  29. package/lib/module/stt/streaming.js.map +1 -0
  30. package/lib/module/stt/streamingTypes.js +38 -0
  31. package/lib/module/stt/streamingTypes.js.map +1 -0
  32. package/lib/module/tts/index.js +4 -43
  33. package/lib/module/tts/index.js.map +1 -1
  34. package/lib/module/tts/streaming.js +220 -0
  35. package/lib/module/tts/streaming.js.map +1 -0
  36. package/lib/module/tts/streamingTypes.js +4 -0
  37. package/lib/module/tts/streamingTypes.js.map +1 -0
  38. package/lib/module/tts/types.js +8 -1
  39. package/lib/module/tts/types.js.map +1 -1
  40. package/lib/typescript/src/NativeSherpaOnnx.d.ts +66 -1
  41. package/lib/typescript/src/NativeSherpaOnnx.d.ts.map +1 -1
  42. package/lib/typescript/src/stt/index.d.ts +3 -0
  43. package/lib/typescript/src/stt/index.d.ts.map +1 -1
  44. package/lib/typescript/src/stt/streaming.d.ts +42 -0
  45. package/lib/typescript/src/stt/streaming.d.ts.map +1 -0
  46. package/lib/typescript/src/stt/streamingTypes.d.ts +122 -0
  47. package/lib/typescript/src/stt/streamingTypes.d.ts.map +1 -0
  48. package/lib/typescript/src/tts/index.d.ts +3 -1
  49. package/lib/typescript/src/tts/index.d.ts.map +1 -1
  50. package/lib/typescript/src/tts/streaming.d.ts +24 -0
  51. package/lib/typescript/src/tts/streaming.d.ts.map +1 -0
  52. package/lib/typescript/src/tts/streamingTypes.d.ts +27 -0
  53. package/lib/typescript/src/tts/streamingTypes.d.ts.map +1 -0
  54. package/lib/typescript/src/tts/types.d.ts +19 -6
  55. package/lib/typescript/src/tts/types.d.ts.map +1 -1
  56. package/package.json +1 -2
  57. package/src/NativeSherpaOnnx.ts +95 -0
  58. package/src/index.tsx +2 -2
  59. package/src/stt/index.ts +17 -0
  60. package/src/stt/streaming.ts +361 -0
  61. package/src/stt/streamingTypes.ts +151 -0
  62. package/src/tts/index.ts +6 -66
  63. package/src/tts/streaming.ts +336 -0
  64. package/src/tts/streamingTypes.ts +54 -0
  65. package/src/tts/types.ts +20 -10
  66. package/android/codegen.gradle +0 -57
@@ -0,0 +1,535 @@
1
+ package com.sherpaonnx
2
+
3
+ import android.content.Context
4
+ import android.net.Uri
5
+ import android.util.Log
6
+ import com.facebook.react.bridge.Arguments
7
+ import com.facebook.react.bridge.Promise
8
+ import com.facebook.react.bridge.ReadableArray
9
+ import com.facebook.react.bridge.WritableMap
10
+ import com.k2fsa.sherpa.onnx.EndpointConfig
11
+ import com.k2fsa.sherpa.onnx.EndpointRule
12
+ import com.k2fsa.sherpa.onnx.FeatureConfig
13
+ import com.k2fsa.sherpa.onnx.OnlineModelConfig
14
+ import com.k2fsa.sherpa.onnx.OnlineNeMoCtcModelConfig
15
+ import com.k2fsa.sherpa.onnx.OnlineParaformerModelConfig
16
+ import com.k2fsa.sherpa.onnx.OnlineRecognizer
17
+ import com.k2fsa.sherpa.onnx.OnlineRecognizerConfig
18
+ import com.k2fsa.sherpa.onnx.OnlineRecognizerResult
19
+ import com.k2fsa.sherpa.onnx.OnlineStream
20
+ import com.k2fsa.sherpa.onnx.OnlineToneCtcModelConfig
21
+ import com.k2fsa.sherpa.onnx.OnlineTransducerModelConfig
22
+ import com.k2fsa.sherpa.onnx.OnlineZipformer2CtcModelConfig
23
+ import java.io.File
24
+ import java.util.concurrent.ConcurrentHashMap
25
+
26
+ /**
27
+ * Helper for streaming (online) STT using sherpa-onnx OnlineRecognizer + OnlineStream.
28
+ * Manages recognizer instances and streams; resolves model paths by scanning the model directory.
29
+ */
30
+ internal class SherpaOnnxOnlineSttHelper(
31
+ private val context: Context,
32
+ private val logTag: String
33
+ ) {
34
+
35
/**
 * One loaded online recognizer together with the configuration it was built from
 * and the streams created on it, keyed by stream id.
 */
private data class OnlineSttInstance(
    val recognizer: OnlineRecognizer,
    val config: OnlineRecognizerConfig,
    val streams: MutableMap<String, OnlineStream> = LinkedHashMap()
)
40
+
41
+ private val instances = ConcurrentHashMap<String, OnlineSttInstance>()
42
+ private val streamToInstance = ConcurrentHashMap<String, String>()
43
+
44
/** Looks up a registered recognizer instance by id; returns null when the id is unknown. */
private fun getInstance(instanceId: String): OnlineSttInstance? {
    return instances[instanceId]
}
45
+
46
/**
 * Resolves a stream id to its owning instance and the stream object.
 *
 * Returns null when the id has no owner mapping, the owning instance is no longer
 * registered, or the instance does not hold a stream under that id.
 */
private fun getStream(streamId: String): Pair<OnlineSttInstance, OnlineStream>? {
    val owner = streamToInstance[streamId]?.let { ownerId -> instances[ownerId] } ?: return null
    return owner.streams[streamId]?.let { stream -> owner to stream }
}
52
+
53
/**
 * Makes a path usable as a plain file path.
 *
 * Paths that do not start with `content://` are returned unchanged. Content URIs are
 * copied into a uniquely named file in the app cache directory and that file's
 * absolute path is returned.
 *
 * @param path file path or `content://` URI.
 * @param cacheFilePrefix prefix for the name of the cache copy.
 * @return [path] itself, or the absolute path of the cache copy.
 * @throws IllegalStateException when the content URI cannot be opened.
 */
private fun resolveContentUriToFile(path: String, cacheFilePrefix: String): String {
    if (!path.startsWith("content://")) return path
    val cacheFile = File(context.cacheDir, "${cacheFilePrefix}_${System.nanoTime()}")
    try {
        val input = context.contentResolver.openInputStream(Uri.parse(path))
            ?: throw IllegalStateException("File is not readable (content URI could not be opened): $path")
        input.use { source ->
            cacheFile.outputStream().use { sink -> source.copyTo(sink) }
        }
    } catch (e: Exception) {
        // Do not leave a truncated copy behind in the cache directory when the copy fails.
        cacheFile.delete()
        throw e
    }
    return cacheFile.absolutePath
}
62
+
63
/**
 * Resolves every entry of a comma-separated path list via [resolveContentUriToFile].
 *
 * Blank input is returned as-is. Entries are trimmed, empty entries are dropped, and
 * each remaining entry gets a cache prefix suffixed with its position among the kept
 * entries. The resolved paths are joined back with commas.
 */
private fun resolveFilePaths(pathsString: String, cacheFilePrefix: String): String {
    if (pathsString.isBlank()) return pathsString
    val resolved = mutableListOf<String>()
    for (rawEntry in pathsString.split(',')) {
        val entry = rawEntry.trim()
        if (entry.isEmpty()) continue
        // resolved.size is the index of this entry among the non-empty ones.
        resolved += resolveContentUriToFile(entry, "${cacheFilePrefix}_${resolved.size}")
    }
    return resolved.joinToString(",")
}
69
+
70
/**
 * Scans a model directory for the files needed by the given online model type.
 *
 * Only regular files directly inside [modelDir] are considered. Model files are
 * matched by name prefix plus the `.onnx` suffix; when several files match, the one
 * whose name sorts first is used so the choice does not depend on directory listing
 * order. The tokens file must be named exactly `tokens.txt`.
 *
 * @param modelDir directory containing the model files.
 * @param modelType one of `transducer`, `paraformer`, `zipformer2_ctc`, `nemo_ctc`, `tone_ctc`.
 * @return map with keys `encoder`, `decoder`, `joiner`, `tokens` (transducer),
 *   `encoder`, `decoder`, `tokens` (paraformer) or `model`, `tokens` (CTC types);
 *   every value is a non-empty absolute path.
 * @throws IllegalArgumentException when the directory is missing, the model type is
 *   unsupported, or a required model file or `tokens.txt` is not found.
 */
private fun scanOnlineModelPaths(modelDir: String, modelType: String): Map<String, String> {
    val dir = File(modelDir)
    require(dir.isDirectory) { "Model directory does not exist or is not a directory: $modelDir" }

    // listFiles() gives no ordering guarantee; sort so repeated scans pick the same file.
    val files = dir.listFiles()?.filter { it.isFile }?.sortedBy { it.name }.orEmpty()

    fun firstFile(vararg prefixes: String, suffix: String = ".onnx"): String =
        prefixes.firstNotNullOfOrNull { prefix ->
            files.firstOrNull { it.name.startsWith(prefix) && it.name.endsWith(suffix) }?.absolutePath
        }.orEmpty()

    val tokensPath = files.firstOrNull { it.name == "tokens.txt" }?.absolutePath.orEmpty()

    val paths = when (modelType) {
        "transducer" -> {
            val found = mapOf(
                "encoder" to firstFile("encoder"),
                "decoder" to firstFile("decoder"),
                "joiner" to firstFile("joiner"),
                "tokens" to tokensPath
            )
            require(listOf("encoder", "decoder", "joiner").none { found[it].isNullOrEmpty() }) {
                "Transducer model requires encoder, decoder, and joiner .onnx files in $modelDir"
            }
            found
        }
        "paraformer" -> {
            val found = mapOf(
                "encoder" to firstFile("encoder"),
                "decoder" to firstFile("decoder"),
                "tokens" to tokensPath
            )
            require(listOf("encoder", "decoder").none { found[it].isNullOrEmpty() }) {
                "Paraformer model requires encoder and decoder .onnx files in $modelDir"
            }
            found
        }
        "zipformer2_ctc", "nemo_ctc", "tone_ctc" -> {
            val found = mapOf(
                "model" to firstFile("model"),
                "tokens" to tokensPath
            )
            require(!found["model"].isNullOrEmpty()) {
                "$modelType model requires model.onnx (or model*.onnx) in $modelDir"
            }
            found
        }
        else -> throw IllegalArgumentException("Unsupported online STT model type: $modelType. Use: transducer, paraformer, zipformer2_ctc, nemo_ctc, tone_ctc")
    }

    // Report a missing tokens file here instead of handing an empty path to the recognizer.
    require(tokensPath.isNotEmpty()) { "$modelType model requires tokens.txt in $modelDir" }
    return paths
}
122
+
123
/**
 * Builds the [OnlineRecognizerConfig] for a model directory.
 *
 * Model file paths come from [scanOnlineModelPaths]. Nullable options fall back to the
 * defaults written out below (1 thread, "cpu" provider, debug off, hotwords score 1.5,
 * blank penalty 0, and the three endpoint rules). `ruleFsts`, `ruleFars` and
 * `hotwordsFile` may be `content://` URIs; they are copied to cache files first.
 * Resolving those optional files is best-effort: on failure the option is dropped
 * (empty string) and a warning is logged so the omission is visible.
 *
 * @throws IllegalArgumentException when the model type is unsupported or required
 *   model files are missing (propagated from [scanOnlineModelPaths]).
 */
private fun buildOnlineRecognizerConfig(
    modelDir: String,
    modelType: String,
    enableEndpoint: Boolean,
    decodingMethod: String,
    maxActivePaths: Int,
    hotwordsFile: String?,
    hotwordsScore: Float?,
    numThreads: Int?,
    provider: String?,
    ruleFsts: String?,
    ruleFars: String?,
    blankPenalty: Float?,
    debug: Boolean?,
    rule1MustContainNonSilence: Boolean?,
    rule1MinTrailingSilence: Float?,
    rule1MinUtteranceLength: Float?,
    rule2MustContainNonSilence: Boolean?,
    rule2MinTrailingSilence: Float?,
    rule2MinUtteranceLength: Float?,
    rule3MustContainNonSilence: Boolean?,
    rule3MinTrailingSilence: Float?,
    rule3MinUtteranceLength: Float?
): OnlineRecognizerConfig {
    val paths = scanOnlineModelPaths(modelDir, modelType)

    val endpointConfig = EndpointConfig(
        rule1 = EndpointRule(
            mustContainNonSilence = rule1MustContainNonSilence ?: false,
            minTrailingSilence = rule1MinTrailingSilence ?: 2.4f,
            minUtteranceLength = rule1MinUtteranceLength ?: 0f
        ),
        rule2 = EndpointRule(
            mustContainNonSilence = rule2MustContainNonSilence ?: true,
            minTrailingSilence = rule2MinTrailingSilence ?: 1.4f,
            minUtteranceLength = rule2MinUtteranceLength ?: 0f
        ),
        rule3 = EndpointRule(
            mustContainNonSilence = rule3MustContainNonSilence ?: false,
            minTrailingSilence = rule3MinTrailingSilence ?: 0f,
            minUtteranceLength = rule3MinUtteranceLength ?: 20f
        )
    )

    // Settings shared by every model type.
    val tokens = paths["tokens"].orEmpty()
    val threadCount = numThreads ?: 1
    val debugEnabled = debug ?: false
    val executionProvider = provider ?: "cpu"

    val modelConfig = when (modelType) {
        "transducer" -> OnlineModelConfig(
            transducer = OnlineTransducerModelConfig(
                encoder = paths["encoder"].orEmpty(),
                decoder = paths["decoder"].orEmpty(),
                joiner = paths["joiner"].orEmpty()
            ),
            tokens = tokens,
            numThreads = threadCount,
            debug = debugEnabled,
            provider = executionProvider,
            // NOTE(review): every transducer is labelled "zipformer"; confirm against the
            // sherpa-onnx docs that this also works for zipformer2/conformer/lstm exports.
            modelType = "zipformer"
        )
        "paraformer" -> OnlineModelConfig(
            paraformer = OnlineParaformerModelConfig(
                encoder = paths["encoder"].orEmpty(),
                decoder = paths["decoder"].orEmpty()
            ),
            tokens = tokens,
            numThreads = threadCount,
            debug = debugEnabled,
            provider = executionProvider,
            modelType = "paraformer"
        )
        "zipformer2_ctc" -> OnlineModelConfig(
            zipformer2Ctc = OnlineZipformer2CtcModelConfig(model = paths["model"].orEmpty()),
            tokens = tokens,
            numThreads = threadCount,
            debug = debugEnabled,
            provider = executionProvider,
            modelType = "zipformer2"
        )
        "nemo_ctc" -> OnlineModelConfig(
            neMoCtc = OnlineNeMoCtcModelConfig(model = paths["model"].orEmpty()),
            tokens = tokens,
            numThreads = threadCount,
            debug = debugEnabled,
            provider = executionProvider
        )
        "tone_ctc" -> OnlineModelConfig(
            toneCtc = OnlineToneCtcModelConfig(model = paths["model"].orEmpty()),
            tokens = tokens,
            numThreads = threadCount,
            debug = debugEnabled,
            provider = executionProvider
        )
        else -> throw IllegalArgumentException("Unsupported online model type: $modelType")
    }

    // Optional files stay best-effort, but a dropped option is logged instead of vanishing silently.
    fun resolveOrEmpty(option: String, resolve: () -> String): String =
        try {
            resolve()
        } catch (e: Exception) {
            Log.w(logTag, "buildOnlineRecognizerConfig: could not resolve $option, ignoring it: ${e.message}", e)
            ""
        }

    val resolvedRuleFsts = resolveOrEmpty("ruleFsts") {
        resolveFilePaths(ruleFsts.orEmpty().trim(), "online_stt_rule_fst")
    }
    val resolvedRuleFars = resolveOrEmpty("ruleFars") {
        resolveFilePaths(ruleFars.orEmpty().trim(), "online_stt_rule_far")
    }
    val trimmedHotwordsFile = hotwordsFile?.trim().orEmpty()
    val resolvedHotwordsFile =
        if (trimmedHotwordsFile.isEmpty()) {
            ""
        } else {
            resolveOrEmpty("hotwordsFile") {
                resolveContentUriToFile(trimmedHotwordsFile, "online_stt_hotwords")
            }
        }

    return OnlineRecognizerConfig(
        featConfig = FeatureConfig(sampleRate = 16000, featureDim = 80, dither = 0f),
        modelConfig = modelConfig,
        endpointConfig = endpointConfig,
        enableEndpoint = enableEndpoint,
        decodingMethod = decodingMethod,
        maxActivePaths = maxActivePaths,
        hotwordsFile = resolvedHotwordsFile,
        hotwordsScore = hotwordsScore ?: 1.5f,
        ruleFsts = resolvedRuleFsts,
        ruleFars = resolvedRuleFars,
        blankPenalty = blankPenalty ?: 0f
    )
}
249
+
250
/**
 * Creates an online recognizer for [modelDir] and registers it under [instanceId].
 *
 * Numeric options arrive as `Double?` and are narrowed to the `Float?`/`Int?` values
 * the config builder expects. If an instance is already registered under the same id
 * it is replaced, and the replaced instance's streams and recognizer are released so
 * they are not left dangling.
 *
 * Resolves [promise] with `{ success: true }`, or rejects it with `INIT_ERROR`.
 */
fun initializeOnlineStt(
    instanceId: String,
    modelDir: String,
    modelType: String,
    enableEndpoint: Boolean,
    decodingMethod: String,
    maxActivePaths: Int,
    hotwordsFile: String?,
    hotwordsScore: Double?,
    numThreads: Double?,
    provider: String?,
    ruleFsts: String?,
    ruleFars: String?,
    blankPenalty: Double?,
    debug: Boolean?,
    rule1MustContainNonSilence: Boolean?,
    rule1MinTrailingSilence: Double?,
    rule1MinUtteranceLength: Double?,
    rule2MustContainNonSilence: Boolean?,
    rule2MinTrailingSilence: Double?,
    rule2MinUtteranceLength: Double?,
    rule3MustContainNonSilence: Boolean?,
    rule3MinTrailingSilence: Double?,
    rule3MinUtteranceLength: Double?,
    promise: Promise
) {
    try {
        val config = buildOnlineRecognizerConfig(
            modelDir = modelDir,
            modelType = modelType,
            enableEndpoint = enableEndpoint,
            decodingMethod = decodingMethod,
            maxActivePaths = maxActivePaths,
            hotwordsFile = hotwordsFile,
            hotwordsScore = hotwordsScore?.toFloat(),
            numThreads = numThreads?.toInt(),
            provider = provider,
            ruleFsts = ruleFsts,
            ruleFars = ruleFars,
            blankPenalty = blankPenalty?.toFloat(),
            debug = debug,
            rule1MustContainNonSilence = rule1MustContainNonSilence,
            rule1MinTrailingSilence = rule1MinTrailingSilence?.toFloat(),
            rule1MinUtteranceLength = rule1MinUtteranceLength?.toFloat(),
            rule2MustContainNonSilence = rule2MustContainNonSilence,
            rule2MinTrailingSilence = rule2MinTrailingSilence?.toFloat(),
            rule2MinUtteranceLength = rule2MinUtteranceLength?.toFloat(),
            rule3MustContainNonSilence = rule3MustContainNonSilence,
            rule3MinTrailingSilence = rule3MinTrailingSilence?.toFloat(),
            rule3MinUtteranceLength = rule3MinUtteranceLength?.toFloat()
        )
        val recognizer = OnlineRecognizer(assetManager = null, config = config)

        // put() hands back whatever was registered under this id before, if anything.
        val previous = instances.put(instanceId, OnlineSttInstance(recognizer = recognizer, config = config))
        if (previous != null) {
            try {
                // Drop the old stream ids first so they cannot resolve against the new instance.
                previous.streams.keys.toList().forEach { streamToInstance.remove(it) }
                previous.streams.values.forEach { it.release() }
                previous.streams.clear()
                previous.recognizer.release()
            } catch (e: Exception) {
                // The new instance is already usable; a cleanup failure must not fail the init.
                Log.w(logTag, "initializeOnlineStt: failed to release replaced instance $instanceId: ${e.message}", e)
            }
        }

        promise.resolve(Arguments.createMap().apply { putBoolean("success", true) })
    } catch (e: Exception) {
        Log.e(logTag, "initializeOnlineStt failed: ${e.message}", e)
        promise.reject("INIT_ERROR", "Online STT init failed: ${e.message}", e)
    }
}
309
+
310
/**
 * Creates a recognition stream on an existing instance and registers it under [streamId].
 *
 * Stream ids are looked up through a single helper-wide map, so an id that is still
 * live on any instance is rejected, not only one on the target instance. Without this
 * the second registration would re-point the id and leave the first stream unreachable.
 *
 * Resolves [promise] with null on success; rejects with `STREAM_ERROR` when the
 * instance is unknown, the stream id is already in use, or stream creation fails.
 *
 * @param hotwords optional hotwords for this stream; trimmed, null is treated as empty.
 */
fun createSttStream(instanceId: String, streamId: String, hotwords: String?, promise: Promise) {
    try {
        val inst = getInstance(instanceId)
        if (inst == null) {
            promise.reject("STREAM_ERROR", "Online STT instance not found: $instanceId")
            return
        }
        // A mapping only counts as live when its owning instance still holds the stream;
        // leftover entries for instances that are gone do not block the id.
        val liveElsewhere = streamToInstance[streamId]
            ?.let { ownerId -> instances[ownerId] }
            ?.streams
            ?.containsKey(streamId) == true
        if (inst.streams.containsKey(streamId) || liveElsewhere) {
            promise.reject("STREAM_ERROR", "Stream already exists: $streamId")
            return
        }
        val stream = inst.recognizer.createStream(hotwords = hotwords?.trim().orEmpty())
        inst.streams[streamId] = stream
        streamToInstance[streamId] = instanceId
        promise.resolve(null)
    } catch (e: Exception) {
        Log.e(logTag, "createSttStream failed: ${e.message}", e)
        promise.reject("STREAM_ERROR", "Create stream failed: ${e.message}", e)
    }
}
330
+
331
/** Copies a bridge array of numbers into a [FloatArray], narrowing each element from Double. */
private fun readableArrayToFloatArray(arr: ReadableArray): FloatArray {
    val out = FloatArray(arr.size())
    for (index in out.indices) {
        out[index] = arr.getDouble(index).toFloat()
    }
    return out
}
333
+
334
/**
 * Feeds audio samples into a stream.
 *
 * Resolves [promise] with null; rejects with `STREAM_ERROR` when the stream is unknown
 * or the call fails.
 */
fun acceptSttWaveform(streamId: String, samples: ReadableArray, sampleRate: Int, promise: Promise) {
    try {
        val target = getStream(streamId)?.second
        if (target == null) {
            promise.reject("STREAM_ERROR", "Stream not found: $streamId")
            return
        }
        target.acceptWaveform(readableArrayToFloatArray(samples), sampleRate)
        promise.resolve(null)
    } catch (e: Exception) {
        Log.e(logTag, "acceptSttWaveform failed: ${e.message}", e)
        promise.reject("STREAM_ERROR", "acceptSttWaveform failed: ${e.message}", e)
    }
}
349
+
350
/**
 * Signals on the stream that no more input will follow.
 *
 * Resolves [promise] with null; rejects with `STREAM_ERROR` when the stream is unknown
 * or the call fails.
 */
fun sttStreamInputFinished(streamId: String, promise: Promise) {
    try {
        val target = getStream(streamId)?.second
        if (target == null) {
            promise.reject("STREAM_ERROR", "Stream not found: $streamId")
            return
        }
        target.inputFinished()
        promise.resolve(null)
    } catch (e: Exception) {
        Log.e(logTag, "sttStreamInputFinished failed: ${e.message}", e)
        promise.reject("STREAM_ERROR", "sttStreamInputFinished failed: ${e.message}", e)
    }
}
364
+
365
/**
 * Runs one decode call for the stream on its owning recognizer.
 *
 * Resolves [promise] with null; rejects with `STREAM_ERROR` when the stream is unknown
 * or decoding fails.
 */
fun decodeSttStream(streamId: String, promise: Promise) {
    try {
        val entry = getStream(streamId)
        if (entry == null) {
            promise.reject("STREAM_ERROR", "Stream not found: $streamId")
            return
        }
        entry.first.recognizer.decode(entry.second)
        promise.resolve(null)
    } catch (e: Exception) {
        Log.e(logTag, "decodeSttStream failed: ${e.message}", e)
        promise.reject("STREAM_ERROR", "decodeSttStream failed: ${e.message}", e)
    }
}
379
+
380
/**
 * Asks the owning recognizer whether the stream is ready and resolves [promise] with
 * that boolean.
 *
 * Rejects with `STREAM_ERROR` when the stream is unknown or the call fails.
 */
fun isSttStreamReady(streamId: String, promise: Promise) {
    try {
        val entry = getStream(streamId)
        if (entry == null) {
            promise.reject("STREAM_ERROR", "Stream not found: $streamId")
            return
        }
        promise.resolve(entry.first.recognizer.isReady(entry.second))
    } catch (e: Exception) {
        Log.e(logTag, "isSttStreamReady failed: ${e.message}", e)
        promise.reject("STREAM_ERROR", "isSttStreamReady failed: ${e.message}", e)
    }
}
394
+
395
/**
 * Converts a recognizer result into a bridge map with the keys `text` (string),
 * `tokens` (string array) and `timestamps` (number array).
 */
private fun resultToWritableMap(result: OnlineRecognizerResult): WritableMap {
    val payload = Arguments.createMap()
    payload.putString("text", result.text)

    val tokenList = Arguments.createArray()
    result.tokens.forEach { token -> tokenList.pushString(token) }
    payload.putArray("tokens", tokenList)

    val timestampList = Arguments.createArray()
    result.timestamps.forEach { stamp -> timestampList.pushDouble(stamp.toDouble()) }
    payload.putArray("timestamps", timestampList)

    return payload
}
406
+
407
/**
 * Fetches the current recognition result for the stream and resolves [promise] with
 * it as a map (see [resultToWritableMap]).
 *
 * Rejects with `STREAM_ERROR` when the stream is unknown or the call fails.
 */
fun getSttStreamResult(streamId: String, promise: Promise) {
    try {
        val entry = getStream(streamId)
        if (entry == null) {
            promise.reject("STREAM_ERROR", "Stream not found: $streamId")
            return
        }
        val current = entry.first.recognizer.getResult(entry.second)
        promise.resolve(resultToWritableMap(current))
    } catch (e: Exception) {
        Log.e(logTag, "getSttStreamResult failed: ${e.message}", e)
        promise.reject("STREAM_ERROR", "getSttStreamResult failed: ${e.message}", e)
    }
}
421
+
422
/**
 * Asks the owning recognizer whether the stream has reached an endpoint and resolves
 * [promise] with that boolean.
 *
 * Rejects with `STREAM_ERROR` when the stream is unknown or the call fails.
 */
fun isSttStreamEndpoint(streamId: String, promise: Promise) {
    try {
        val entry = getStream(streamId)
        if (entry == null) {
            promise.reject("STREAM_ERROR", "Stream not found: $streamId")
            return
        }
        promise.resolve(entry.first.recognizer.isEndpoint(entry.second))
    } catch (e: Exception) {
        Log.e(logTag, "isSttStreamEndpoint failed: ${e.message}", e)
        promise.reject("STREAM_ERROR", "isSttStreamEndpoint failed: ${e.message}", e)
    }
}
436
+
437
/**
 * Resets the stream through its owning recognizer.
 *
 * Resolves [promise] with null; rejects with `STREAM_ERROR` when the stream is unknown
 * or the call fails.
 */
fun resetSttStream(streamId: String, promise: Promise) {
    try {
        val entry = getStream(streamId)
        if (entry == null) {
            promise.reject("STREAM_ERROR", "Stream not found: $streamId")
            return
        }
        entry.first.recognizer.reset(entry.second)
        promise.resolve(null)
    } catch (e: Exception) {
        Log.e(logTag, "resetSttStream failed: ${e.message}", e)
        promise.reject("STREAM_ERROR", "resetSttStream failed: ${e.message}", e)
    }
}
451
+
452
/**
 * Unregisters a stream and releases it.
 *
 * An unknown stream id, or one whose owning instance is already gone, is treated as
 * nothing to do: [promise] still resolves with null. Rejects with `STREAM_ERROR` only
 * when the release itself throws.
 */
fun releaseSttStream(streamId: String, promise: Promise) {
    try {
        // The owner mapping is removed first, whether or not the instance still exists.
        val owner = streamToInstance.remove(streamId)?.let { ownerId -> instances[ownerId] }
        if (owner != null) {
            owner.streams.remove(streamId)?.release()
        }
        promise.resolve(null)
    } catch (e: Exception) {
        Log.e(logTag, "releaseSttStream failed: ${e.message}", e)
        promise.reject("STREAM_ERROR", "releaseSttStream failed: ${e.message}", e)
    }
}
469
+
470
/**
 * Unregisters an instance, releases all of its streams, removes their id mappings and
 * releases the recognizer.
 *
 * An unknown instance id is a no-op. Resolves [promise] with null; rejects with
 * `RELEASE_ERROR` when a release call throws.
 */
fun unloadOnlineStt(instanceId: String, promise: Promise) {
    try {
        val removed = instances.remove(instanceId)
        if (removed != null) {
            // Snapshot the ids before the stream map is cleared.
            val ownedIds = removed.streams.keys.toList()
            for (stream in removed.streams.values) {
                stream.release()
            }
            removed.streams.clear()
            for (id in ownedIds) {
                streamToInstance.remove(id)
            }
            removed.recognizer.release()
        }
        promise.resolve(null)
    } catch (e: Exception) {
        Log.e(logTag, "unloadOnlineStt failed: ${e.message}", e)
        promise.reject("RELEASE_ERROR", "unloadOnlineStt failed: ${e.message}", e)
    }
}
487
+
488
/**
 * One-call convenience for a chunk of audio: feeds the samples to the stream, decodes
 * for as long as the recognizer reports the stream ready, then resolves [promise] with
 * the current result map plus an `isEndpoint` flag.
 *
 * Rejects with `STREAM_ERROR` when the stream is unknown or any step fails.
 */
fun processSttAudioChunk(
    streamId: String,
    samples: ReadableArray,
    sampleRate: Int,
    promise: Promise
) {
    try {
        val entry = getStream(streamId)
        if (entry == null) {
            promise.reject("STREAM_ERROR", "Stream not found: $streamId")
            return
        }
        val recognizer = entry.first.recognizer
        val stream = entry.second

        stream.acceptWaveform(readableArrayToFloatArray(samples), sampleRate)
        while (recognizer.isReady(stream)) {
            recognizer.decode(stream)
        }

        val current = recognizer.getResult(stream)
        val endpointReached = recognizer.isEndpoint(stream)
        val payload = resultToWritableMap(current).also { it.putBoolean("isEndpoint", endpointReached) }
        promise.resolve(payload)
    } catch (e: Exception) {
        Log.e(logTag, "processSttAudioChunk failed: ${e.message}", e)
        promise.reject("STREAM_ERROR", "processSttAudioChunk failed: ${e.message}", e)
    }
}
518
+
519
/**
 * Releases every registered instance and its streams; intended to be called from
 * Module.onCatalystInstanceDestroy.
 *
 * A failure while releasing one instance is logged as a warning and does not stop the
 * remaining instances from being released. The stream-to-instance map is cleared at the end.
 */
fun shutdown() {
    for (instanceId in instances.keys.toList()) {
        try {
            val removed = instances.remove(instanceId) ?: continue
            // Snapshot the ids before the stream map is cleared.
            val ownedIds = removed.streams.keys.toList()
            for (stream in removed.streams.values) {
                stream.release()
            }
            removed.streams.clear()
            for (id in ownedIds) {
                streamToInstance.remove(id)
            }
            removed.recognizer.release()
        } catch (e: Exception) {
            Log.w(logTag, "shutdown: failed to release instance $instanceId: ${e.message}")
        }
    }
    streamToInstance.clear()
}
535
+ }
@@ -41,9 +41,9 @@ import java.util.concurrent.atomic.AtomicBoolean
41
41
  internal class SherpaOnnxTtsHelper(
42
42
  private val context: ReactApplicationContext,
43
43
  private val detectTtsModel: (modelDir: String, modelType: String) -> HashMap<String, Any>?,
44
- private val emitChunk: (String, FloatArray, Int, Float, Boolean) -> Unit,
45
- private val emitError: (String, String) -> Unit,
46
- private val emitEnd: (String, Boolean) -> Unit
44
+ private val emitChunk: (String, String, FloatArray, Int, Float, Boolean) -> Unit,
45
+ private val emitError: (String, String, String) -> Unit,
46
+ private val emitEnd: (String, String, Boolean) -> Unit
47
47
  ) {
48
48
 
49
49
  private data class TtsInitState(
@@ -501,7 +501,7 @@ internal class SherpaOnnxTtsHelper(
501
501
  }
502
502
  }
503
503
 
504
- fun generateTtsStream(instanceId: String, text: String, options: ReadableMap?, promise: Promise) {
504
+ fun generateTtsStream(instanceId: String, requestId: String, text: String, options: ReadableMap?, promise: Promise) {
505
505
  val inst = getInstance(instanceId) ?: run {
506
506
  Log.e("SherpaOnnxTts", "TTS_STREAM_ERROR: TTS instance not found: $instanceId")
507
507
  promise.reject("TTS_STREAM_ERROR", "TTS instance not found: $instanceId")
@@ -534,34 +534,34 @@ internal class SherpaOnnxTtsHelper(
534
534
  val config = parseGenerationConfig(options) ?: GenerationConfig(speed = speed, sid = sid)
535
535
  inst.tts!!.generateWithConfigAndCallback(text, config) { chunk ->
536
536
  if (inst.ttsStreamCancelled.get()) return@generateWithConfigAndCallback 0
537
- emitChunk(instanceId, chunk, sampleRate, 0f, false)
537
+ emitChunk(instanceId, requestId, chunk, sampleRate, 0f, false)
538
538
  chunk.size
539
539
  }
540
540
  }
541
541
  inst.zipvoiceTts != null -> {
542
542
  inst.zipvoiceTts!!.generateWithCallback(text, sid, speed) { chunk ->
543
543
  if (inst.ttsStreamCancelled.get()) return@generateWithCallback 0
544
- emitChunk(instanceId, chunk, sampleRate, 0f, false)
544
+ emitChunk(instanceId, requestId, chunk, sampleRate, 0f, false)
545
545
  chunk.size
546
546
  }
547
547
  }
548
548
  else -> {
549
549
  inst.tts!!.generateWithCallback(text, sid, speed) { chunk ->
550
550
  if (inst.ttsStreamCancelled.get()) return@generateWithCallback 0
551
- emitChunk(instanceId, chunk, sampleRate, 0f, false)
551
+ emitChunk(instanceId, requestId, chunk, sampleRate, 0f, false)
552
552
  chunk.size
553
553
  }
554
554
  }
555
555
  }
556
556
  if (!inst.ttsStreamCancelled.get()) {
557
- emitChunk(instanceId, FloatArray(0), sampleRate, 1f, true)
557
+ emitChunk(instanceId, requestId, FloatArray(0), sampleRate, 1f, true)
558
558
  }
559
559
  } catch (e: Exception) {
560
560
  if (!inst.ttsStreamCancelled.get()) {
561
- emitError(instanceId, "TTS streaming failed: ${e.message}")
561
+ emitError(instanceId, requestId, "TTS streaming failed: ${e.message}")
562
562
  }
563
563
  } finally {
564
- emitEnd(instanceId, inst.ttsStreamCancelled.get())
564
+ emitEnd(instanceId, requestId, inst.ttsStreamCancelled.get())
565
565
  inst.ttsStreamRunning.set(false)
566
566
  }
567
567
  }