react-native-sherpa-onnx 0.4.0 → 0.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95)
  1. package/README.md +3 -0
  2. package/android/src/main/assets/model_licenses/alignment-models-license-status.csv +5 -0
  3. package/android/src/main/cpp/CMakeLists.txt +3 -0
  4. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-alignment-wrapper.cpp +66 -0
  5. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-alignment-wrapper.h +17 -0
  6. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-alignment.cpp +108 -0
  7. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect.h +30 -0
  8. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-validate-alignment.cpp +66 -0
  9. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-validate-alignment.h +30 -0
  10. package/android/src/main/cpp/jni/module/sherpa-onnx-module-jni.cpp +21 -0
  11. package/android/src/main/java/com/sherpaonnx/SherpaOnnxAlignmentHelper.kt +555 -0
  12. package/android/src/main/java/com/sherpaonnx/SherpaOnnxModule.kt +76 -0
  13. package/android/src/main/java/com/sherpaonnx/SherpaOnnxTextSegmenter.kt +330 -0
  14. package/android/src/main/java/com/sherpaonnx/SherpaOnnxTtsHelper.kt +180 -23
  15. package/ios/Resources/model_licenses/alignment-models-license-status.csv +5 -0
  16. package/ios/SherpaOnnx+Alignment.mm +704 -0
  17. package/ios/SherpaOnnx+STT.mm +6 -0
  18. package/ios/SherpaOnnx+TTS.mm +624 -50
  19. package/ios/model_detect/sherpa-onnx-model-detect-alignment.mm +108 -0
  20. package/ios/model_detect/sherpa-onnx-model-detect.h +31 -0
  21. package/ios/model_detect/sherpa-onnx-validate-alignment.h +30 -0
  22. package/ios/model_detect/sherpa-onnx-validate-alignment.mm +66 -0
  23. package/ios/stt/sherpa-onnx-stt-wrapper.h +3 -1
  24. package/ios/stt/sherpa-onnx-stt-wrapper.mm +6 -0
  25. package/lib/module/NativeSherpaOnnx.js.map +1 -1
  26. package/lib/module/alignment/index.js +27 -0
  27. package/lib/module/alignment/index.js.map +1 -0
  28. package/lib/module/alignment/types.js +2 -0
  29. package/lib/module/alignment/types.js.map +1 -0
  30. package/lib/module/alignment/vocab.js +40 -0
  31. package/lib/module/alignment/vocab.js.map +1 -0
  32. package/lib/module/download/paths.js +9 -1
  33. package/lib/module/download/paths.js.map +1 -1
  34. package/lib/module/download/registry.js +17 -1
  35. package/lib/module/download/registry.js.map +1 -1
  36. package/lib/module/download/types.js +1 -0
  37. package/lib/module/download/types.js.map +1 -1
  38. package/lib/module/index.js +6 -4
  39. package/lib/module/index.js.map +1 -1
  40. package/lib/module/licenses.js +8 -2
  41. package/lib/module/licenses.js.map +1 -1
  42. package/lib/module/stt/types.js.map +1 -1
  43. package/lib/module/tts/index.js +68 -2
  44. package/lib/module/tts/index.js.map +1 -1
  45. package/lib/module/tts/subtitles.js +400 -0
  46. package/lib/module/tts/subtitles.js.map +1 -0
  47. package/lib/module/tts/tempAudio.js +17 -0
  48. package/lib/module/tts/tempAudio.js.map +1 -0
  49. package/lib/module/tts/types.js.map +1 -1
  50. package/lib/typescript/src/NativeSherpaOnnx.d.ts +34 -3
  51. package/lib/typescript/src/NativeSherpaOnnx.d.ts.map +1 -1
  52. package/lib/typescript/src/alignment/index.d.ts +8 -0
  53. package/lib/typescript/src/alignment/index.d.ts.map +1 -0
  54. package/lib/typescript/src/alignment/types.d.ts +23 -0
  55. package/lib/typescript/src/alignment/types.d.ts.map +1 -0
  56. package/lib/typescript/src/alignment/vocab.d.ts +5 -0
  57. package/lib/typescript/src/alignment/vocab.d.ts.map +1 -0
  58. package/lib/typescript/src/download/paths.d.ts +5 -2
  59. package/lib/typescript/src/download/paths.d.ts.map +1 -1
  60. package/lib/typescript/src/download/registry.d.ts.map +1 -1
  61. package/lib/typescript/src/download/types.d.ts +2 -1
  62. package/lib/typescript/src/download/types.d.ts.map +1 -1
  63. package/lib/typescript/src/index.d.ts +1 -0
  64. package/lib/typescript/src/index.d.ts.map +1 -1
  65. package/lib/typescript/src/licenses.d.ts.map +1 -1
  66. package/lib/typescript/src/stt/types.d.ts +5 -2
  67. package/lib/typescript/src/stt/types.d.ts.map +1 -1
  68. package/lib/typescript/src/tts/index.d.ts +2 -1
  69. package/lib/typescript/src/tts/index.d.ts.map +1 -1
  70. package/lib/typescript/src/tts/subtitles.d.ts +24 -0
  71. package/lib/typescript/src/tts/subtitles.d.ts.map +1 -0
  72. package/lib/typescript/src/tts/tempAudio.d.ts +3 -0
  73. package/lib/typescript/src/tts/tempAudio.d.ts.map +1 -0
  74. package/lib/typescript/src/tts/types.d.ts +68 -2
  75. package/lib/typescript/src/tts/types.d.ts.map +1 -1
  76. package/package.json +6 -1
  77. package/scripts/alignment-models/README.md +90 -0
  78. package/scripts/alignment-models/build_and_upload.js +724 -0
  79. package/scripts/alignment-models/sources.csv +5 -0
  80. package/scripts/alignment-models/sync_alignment_license_status.js +123 -0
  81. package/src/NativeSherpaOnnx.ts +35 -3
  82. package/src/alignment/index.ts +41 -0
  83. package/src/alignment/types.ts +22 -0
  84. package/src/alignment/vocab.ts +38 -0
  85. package/src/download/paths.ts +18 -5
  86. package/src/download/registry.ts +23 -3
  87. package/src/download/types.ts +1 -0
  88. package/src/index.tsx +6 -4
  89. package/src/licenses.ts +12 -1
  90. package/src/stt/types.ts +5 -2
  91. package/src/tts/index.ts +110 -3
  92. package/src/tts/subtitles.ts +611 -0
  93. package/src/tts/tempAudio.ts +31 -0
  94. package/src/tts/types.ts +79 -2
  95. package/third_party/sherpa-onnx-prebuilt/IOS_RELEASE_TAG +1 -1
@@ -0,0 +1,555 @@
1
+ package com.sherpaonnx
2
+
3
+ import ai.onnxruntime.OnnxTensor
4
+ import ai.onnxruntime.OrtEnvironment
5
+ import ai.onnxruntime.OrtSession
6
+ import ai.onnxruntime.TensorInfo
7
+ import android.net.Uri
8
+ import android.util.Log
9
+ import com.facebook.react.bridge.Arguments
10
+ import com.facebook.react.bridge.Promise
11
+ import com.facebook.react.bridge.ReactApplicationContext
12
+ import com.facebook.react.bridge.WritableArray
13
+ import com.facebook.react.bridge.WritableMap
14
+ import com.k2fsa.sherpa.onnx.WaveReader
15
+ import org.json.JSONObject
16
+ import java.io.File
17
+ import java.io.FileOutputStream
18
+ import java.nio.FloatBuffer
19
+ import java.util.Locale
20
+ import java.util.concurrent.Executors
21
+ import kotlin.math.exp
22
+ import kotlin.math.floor
23
+ import kotlin.math.ln
24
+ import kotlin.math.max
25
+ import kotlin.math.min
26
+ import kotlin.math.sqrt
27
+
28
+ internal class SherpaOnnxAlignmentHelper(
29
+ private val context: ReactApplicationContext
30
+ ) {
31
+ private val executor = Executors.newSingleThreadExecutor()
32
+
33
/**
 * One aligned unit (a single character token or a whole word) together with its time span.
 *
 * [start] and [end] are frame indices multiplied by 0.02 by the code that builds these items
 * (presumably seconds — confirm against the model's frame rate).
 */
private data class AlignmentItem(val text: String, val start: Double, val end: Double)
38
+
39
/**
 * Blank-interleaved CTC state sequence.
 *
 * Deliberately a plain class rather than a `data class`: both properties are arrays, so the
 * generated `equals`/`hashCode` would compare array identity and silently misbehave. Nothing
 * in this file relies on `copy`, destructuring or value equality of this type.
 *
 * @property ids vocabulary id emitted by each state.
 * @property tokenIndices for each state, the index of the transcript token it belongs to,
 *   or -1 when the state does not belong to a token.
 */
private class ExpandedTarget(
    val ids: IntArray,
    val tokenIndices: IntArray,
)
43
+
44
/**
 * Stops the alignment worker by calling `shutdownNow()` on the executor.
 *
 * NOTE(review): per the `ExecutorService` contract, tasks that have not started are dropped,
 * so promises of still-queued alignments are presumably never settled — confirm this is intended.
 */
fun shutdown() {
    val worker = executor
    worker.shutdownNow()
}
47
+
48
/**
 * Queues a CTC forced alignment of [text] against the audio at [audioPath] on this helper's
 * executor and settles [promise] with the outcome.
 *
 * On success the promise resolves with a map holding two arrays, `words` and `chars`; each
 * entry carries `text`, `start` and `end`. Every failure rejects with code `ALIGNMENT_ERROR`.
 *
 * @param modelPath path of the ONNX model used for inference; must not be blank.
 * @param audioPath file path or `content://` URI of the WAV input; must not be blank.
 * @param text transcript to align; must not be blank.
 * @param vocabJson JSON object mapping token strings to integer ids.
 * @param promise React Native promise that receives the result or the error.
 */
fun runCTCForcedAlignment(
    modelPath: String,
    audioPath: String,
    text: String,
    vocabJson: String,
    promise: Promise,
) {
    try {
        executor.execute { alignAndSettle(modelPath, audioPath, text, vocabJson, promise) }
    } catch (e: java.util.concurrent.RejectedExecutionException) {
        // execute() throws once shutdown() has run; without this the exception would escape to
        // the caller and the promise would never be settled.
        promise.reject("ALIGNMENT_ERROR", "Alignment helper has been shut down", e)
    }
}

/**
 * Worker-thread body of [runCTCForcedAlignment]: validates the arguments, decodes and
 * preprocesses the audio, runs the model, backtracks the CTC path and settles [promise].
 * A temp copy created for a `content://` input is deleted in all cases.
 */
private fun alignAndSettle(
    modelPath: String,
    audioPath: String,
    text: String,
    vocabJson: String,
    promise: Promise,
) {
    var cleanupPath: String? = null
    try {
        if (modelPath.isBlank()) {
            promise.reject("ALIGNMENT_ERROR", "modelPath is required")
            return
        }
        if (audioPath.isBlank()) {
            promise.reject("ALIGNMENT_ERROR", "audioPath is required")
            return
        }
        if (text.isBlank()) {
            promise.reject("ALIGNMENT_ERROR", "text is required")
            return
        }

        val resolvedAudio = resolveAudioPath(audioPath)
        cleanupPath = resolvedAudio.second

        val file = File(resolvedAudio.first)
        if (!file.exists() || file.length() <= 0L) {
            promise.reject("ALIGNMENT_ERROR", "Audio file does not exist or is empty: ${resolvedAudio.first}")
            return
        }

        val vocab = parseVocab(vocabJson)
        // Fall back to ids 0 / 4 when the vocabulary does not name the blank or word boundary.
        val blankId = vocab["<pad>"] ?: 0
        val wordBoundaryId = vocab["|"] ?: 4

        val tokenTexts = buildTokenTexts(text, vocab, wordBoundaryId)
        if (tokenTexts.isEmpty()) {
            promise.reject("ALIGNMENT_ERROR", "Transcript has no alignable tokens for the provided vocabulary")
            return
        }

        val tokenIds = IntArray(tokenTexts.size) { index ->
            vocab[tokenTexts[index]] ?: blankId
        }

        val wave = WaveReader.readWave(resolvedAudio.first)
        val rawSamples = wave.samples ?: FloatArray(0)
        if (rawSamples.isEmpty()) {
            promise.reject("ALIGNMENT_ERROR", "Could not decode WAV samples from: ${resolvedAudio.first}")
            return
        }

        val mono16k = if (wave.sampleRate == 16000) {
            rawSamples
        } else {
            resampleLinear(rawSamples, wave.sampleRate, 16000)
        }

        val normalized = normalizeAudio(mono16k)
        if (normalized.isEmpty()) {
            promise.reject("ALIGNMENT_ERROR", "Audio is empty after preprocessing")
            return
        }

        val logits = runInference(modelPath, normalized)
        if (logits.isEmpty() || logits[0].isEmpty()) {
            promise.reject("ALIGNMENT_ERROR", "Model inference returned empty logits")
            return
        }

        val expanded = buildExpandedTarget(tokenIds, blankId)
        val path = ctcBacktrack(logits, expanded.ids, blankId)

        val framesByToken = collectFramesByToken(path, expanded, tokenIds.size, blankId)
        val charItems = buildCharItems(tokenTexts, framesByToken)
        val wordItems = buildWordItems(tokenTexts, charItems)

        val result = Arguments.createMap()
        result.putArray("words", toWritableArray(wordItems))
        result.putArray("chars", toWritableArray(charItems))
        promise.resolve(result)
    } catch (e: Exception) {
        Log.e("SherpaOnnxAlignment", "ALIGNMENT_ERROR: ${e.message}", e)
        promise.reject("ALIGNMENT_ERROR", e.message ?: "CTC alignment failed", e)
    } finally {
        cleanupPath?.let { tempPath ->
            try {
                File(tempPath).delete()
            } catch (_: Exception) {
                // ignore cleanup errors
            }
        }
    }
}

/** Groups the frame indices of [path] by the transcript token their state belongs to; blank states are skipped. */
private fun collectFramesByToken(
    path: IntArray,
    expanded: ExpandedTarget,
    tokenCount: Int,
    blankId: Int,
): Array<MutableList<Int>> {
    val framesByToken = Array(tokenCount) { mutableListOf<Int>() }
    for (t in path.indices) {
        val state = path[t]
        if (state < 0 || state >= expanded.tokenIndices.size) {
            continue
        }
        val tokenIndex = expanded.tokenIndices[state]
        if (tokenIndex in 0 until tokenCount && expanded.ids[state] != blankId) {
            framesByToken[tokenIndex].add(t)
        }
    }
    return framesByToken
}

/**
 * Builds one [AlignmentItem] per non-boundary token. A token without frames gets a zero-length
 * span at the end of the latest token that did have frames.
 *
 * NOTE(review): [frameDuration] defaults to the 0.02 factor the original code hard-coded;
 * presumably the model's frame stride in seconds — confirm per model.
 */
private fun buildCharItems(
    tokenTexts: List<String>,
    framesByToken: Array<MutableList<Int>>,
    frameDuration: Double = 0.02,
): List<AlignmentItem> {
    val charItems = mutableListOf<AlignmentItem>()
    var fallbackEndFrame = 0
    for (i in tokenTexts.indices) {
        val token = tokenTexts[i]
        if (token == "|") {
            continue
        }

        val frames = framesByToken[i]
        val startFrame: Int
        val endFrameExclusive: Int
        if (frames.isNotEmpty()) {
            startFrame = frames.first()
            endFrameExclusive = frames.last() + 1
            fallbackEndFrame = max(fallbackEndFrame, endFrameExclusive)
        } else {
            startFrame = fallbackEndFrame
            endFrameExclusive = fallbackEndFrame
        }

        val start = startFrame * frameDuration
        val end = max(start, endFrameExclusive * frameDuration)
        charItems.add(AlignmentItem(token, start, end))
    }
    return charItems
}

/** Merges consecutive [charItems] into words, starting a new word at every `|` token of [tokenTexts]. */
private fun buildWordItems(
    tokenTexts: List<String>,
    charItems: List<AlignmentItem>,
): List<AlignmentItem> {
    val wordItems = mutableListOf<AlignmentItem>()
    val currentWord = StringBuilder()
    var wordStart = 0.0
    var wordEnd = 0.0
    var charCursor = 0

    for (token in tokenTexts) {
        if (token == "|") {
            if (currentWord.isNotEmpty()) {
                wordItems.add(AlignmentItem(currentWord.toString(), wordStart, wordEnd))
                currentWord.clear()
            }
            continue
        }

        // charItems holds one entry per non-boundary token, in transcript order.
        val charItem = charItems.getOrNull(charCursor)
        charCursor += 1
        if (charItem == null) {
            continue
        }

        if (currentWord.isEmpty()) {
            wordStart = charItem.start
            wordEnd = charItem.end
        } else {
            wordEnd = max(wordEnd, charItem.end)
        }
        currentWord.append(charItem.text)
    }

    if (currentWord.isNotEmpty()) {
        wordItems.add(AlignmentItem(currentWord.toString(), wordStart, wordEnd))
    }
    return wordItems
}
213
+
214
/**
 * Turns [audioPath] into a path that can be opened as a regular file.
 *
 * Anything that does not start with `content://` is returned unchanged with a `null` cleanup
 * path. A `content://` URI is copied into a temp `.wav` file in the app cache directory; that
 * file's path is returned both as the usable path and as the path the caller must delete.
 *
 * @return pair of (readable path, path to delete afterwards or `null`).
 * @throws IllegalStateException if the content resolver returns no stream for the URI.
 */
private fun resolveAudioPath(audioPath: String): Pair<String, String?> {
    if (!audioPath.startsWith("content://")) {
        return Pair(audioPath, null)
    }

    val uri = Uri.parse(audioPath)
    val tempFile = File.createTempFile("alignment_input_", ".wav", context.cacheDir)
    try {
        val input = context.contentResolver.openInputStream(uri)
            ?: throw IllegalStateException("Could not open content URI: $audioPath")
        input.use { source ->
            FileOutputStream(tempFile).use { sink ->
                source.copyTo(sink)
            }
        }
    } catch (e: Exception) {
        // The caller only learns the cleanup path from the return value, so a failure here
        // would otherwise leave the temp file behind in the cache directory.
        tempFile.delete()
        throw e
    }

    return Pair(tempFile.absolutePath, tempFile.absolutePath)
}
229
+
230
/**
 * Parses [vocabJson], a JSON object of token → id, into an insertion-ordered map.
 *
 * Blank keys are skipped, as are entries for which `optInt` yields the `Int.MIN_VALUE`
 * sentinel (used here to mean "no integer value").
 *
 * @throws IllegalArgumentException if no usable entry remains.
 */
private fun parseVocab(vocabJson: String): Map<String, Int> {
    val json = JSONObject(vocabJson)
    val vocab = linkedMapOf<String, Int>()
    for (token in json.keys()) {
        if (token.isBlank()) continue
        val id = json.optInt(token, Int.MIN_VALUE)
        if (id == Int.MIN_VALUE) continue
        vocab[token] = id
    }
    require(vocab.isNotEmpty()) { "Vocabulary JSON is empty" }
    return vocab
}
249
+
250
/**
 * Splits [text] into per-character tokens that exist in [vocab].
 *
 * The text is upper-cased with the US locale; the characters ’, ` and ´ are mapped to a plain
 * apostrophe; characters missing from [vocab] are dropped. Each whitespace run between tokens
 * becomes a single `|`, and no `|` is kept at either end. If [vocab] does not map `|` to
 * [wordBoundaryId], all `|` tokens are removed from the result.
 */
private fun buildTokenTexts(
    text: String,
    vocab: Map<String, Int>,
    wordBoundaryId: Int,
): List<String> {
    val boundary = "|"
    val tokens = ArrayList<String>()

    for (raw in text.uppercase(Locale.US)) {
        if (raw.isWhitespace()) {
            val needsBoundary = tokens.isNotEmpty() && tokens.last() != boundary
            if (needsBoundary) tokens += boundary
        } else {
            val symbol = if (raw == '’' || raw == '`' || raw == '´') "'" else raw.toString()
            if (symbol in vocab) tokens += symbol
        }
    }

    val trimmed = tokens.dropWhile { it == boundary }.dropLastWhile { it == boundary }
    return if (vocab[boundary] == wordBoundaryId) trimmed else trimmed.filter { it != boundary }
}
289
+
290
/**
 * Resamples [input] from [sourceSampleRate] to [targetSampleRate] by linear interpolation.
 *
 * Returns an empty array for empty input or a non-positive rate, and [input] itself when both
 * rates are equal. Otherwise the output has floor(size * target / source) samples, at least one.
 */
private fun resampleLinear(
    input: FloatArray,
    sourceSampleRate: Int,
    targetSampleRate: Int,
): FloatArray {
    val unusable = input.isEmpty() || sourceSampleRate <= 0 || targetSampleRate <= 0
    return when {
        unusable -> FloatArray(0)
        sourceSampleRate == targetSampleRate -> input
        else -> {
            val step = sourceSampleRate.toDouble() / targetSampleRate.toDouble()
            val lastIdx = input.lastIndex
            val scaledLength = floor(input.size.toDouble() * targetSampleRate / sourceSampleRate).toInt()
            FloatArray(max(1, scaledLength)) { i ->
                val position = i * step
                val lower = floor(position).toInt()
                // Both neighbours are clamped to the last sample so the tail never reads out of range.
                val a = input[min(lower, lastIdx)]
                val b = input[min(lower + 1, lastIdx)]
                (a + (b - a) * (position - lower)).toFloat()
            }
        }
    }
}
318
+
319
/**
 * Returns a zero-mean, unit-variance copy of [input]; an empty [input] is returned as is.
 *
 * The variance is floored at 1e-12 before the square root so constant (e.g. silent) audio
 * does not divide by zero.
 */
private fun normalizeAudio(input: FloatArray): FloatArray {
    if (input.isEmpty()) return input

    val count = input.size
    val mean = input.sumOf { it.toDouble() } / count
    val squaredDeviation = input.sumOf { sample ->
        val centered = sample - mean
        centered * centered
    }
    val std = sqrt(max(squaredDeviation / count, 1e-12))

    return FloatArray(count) { i -> ((input[i] - mean) / std).toFloat() }
}
343
+
344
/**
 * Runs the ONNX model at [modelPath] on [samples] (fed as a `[1, samples.size]` float tensor
 * to the model's first input) and returns log-softmaxed output, one row per frame.
 *
 * The frame and vocabulary sizes are read from the output shape: dims 1 and 2 for rank >= 3,
 * dims 0 and 1 for rank 2, otherwise a single frame. Both are clamped so that
 * frames * vocab never exceeds the number of values actually in the buffer.
 *
 * @return empty array when the output tensor holds no values.
 * @throws IllegalStateException if the model has no input or its first output is not a tensor.
 */
private fun runInference(modelPath: String, samples: FloatArray): Array<FloatArray> {
    val env = OrtEnvironment.getEnvironment()

    return OrtSession.SessionOptions().use { sessionOptions ->
        env.createSession(modelPath, sessionOptions).use { session ->
            val inputName = session.inputNames.firstOrNull()
                ?: throw IllegalStateException("Alignment model has no input")
            val batchShape = longArrayOf(1L, samples.size.toLong())

            OnnxTensor.createTensor(env, FloatBuffer.wrap(samples), batchShape).use { inputTensor ->
                session.run(mapOf(inputName to inputTensor)).use { result ->
                    val outputTensor = result.get(0) as? OnnxTensor
                        ?: throw IllegalStateException("Alignment model output is not a tensor")
                    val info = outputTensor.info as? TensorInfo
                        ?: throw IllegalStateException("Alignment tensor info missing")
                    val shape = info.shape

                    val buffer = outputTensor.floatBuffer
                    buffer.rewind()
                    val valueCount = buffer.remaining()

                    if (valueCount <= 0) {
                        emptyArray<FloatArray>()
                    } else {
                        val flat = FloatArray(valueCount)
                        buffer.get(flat)

                        val dims = when {
                            shape.size >= 3 -> shape[1].toInt() to shape[2].toInt()
                            shape.size == 2 -> shape[0].toInt() to shape[1].toInt()
                            else -> 1 to valueCount
                        }
                        val frames = max(1, dims.first)
                        val vocabSize = max(1, dims.second)

                        val safeFrames = max(1, min(frames, valueCount))
                        val safeVocab = max(1, min(vocabSize, valueCount / safeFrames))
                        logSoftmax(flat, safeFrames, safeVocab)
                    }
                }
            }
        }
    }
}
399
+
400
/**
 * Applies a row-wise log-softmax to [logitsFlat], read as [frames] consecutive rows of
 * [vocabSize] values each.
 *
 * The row maximum is subtracted before exponentiating, and the exponential sum is floored at
 * 1e-12 before taking the logarithm.
 *
 * @return a `frames` x `vocabSize` array of log-probabilities.
 */
private fun logSoftmax(
    logitsFlat: FloatArray,
    frames: Int,
    vocabSize: Int,
): Array<FloatArray> =
    Array(frames) { frame ->
        val base = frame * vocabSize
        val rowEnd = base + vocabSize

        var peak = Float.NEGATIVE_INFINITY
        for (k in base until rowEnd) {
            if (logitsFlat[k] > peak) peak = logitsFlat[k]
        }

        var expTotal = 0.0
        for (k in base until rowEnd) {
            expTotal += exp((logitsFlat[k] - peak).toDouble())
        }
        val logDenom = peak + ln(max(expTotal, 1e-12))

        FloatArray(vocabSize) { v -> (logitsFlat[base + v] - logDenom).toFloat() }
    }
430
+
431
/**
 * Interleaves [tokenIds] with [blankId] into a sequence of 2 * n + 1 states:
 * blank, token 0, blank, token 1, ..., blank.
 *
 * In the returned target, `tokenIndices` holds the source token index at odd positions and
 * -1 at the blank (even) positions.
 */
private fun buildExpandedTarget(tokenIds: IntArray, blankId: Int): ExpandedTarget {
    val stateCount = 2 * tokenIds.size + 1
    val ids = IntArray(stateCount) { s -> if (s % 2 == 1) tokenIds[s / 2] else blankId }
    val tokenIndices = IntArray(stateCount) { s -> if (s % 2 == 1) s / 2 else -1 }
    return ExpandedTarget(ids, tokenIndices)
}
449
+
450
/**
 * Finds the best-scoring monotonic path through [expandedTarget] for [logProbs] and returns,
 * for every time step, the index of the state occupied at that step.
 *
 * Allowed moves per step: stay, advance by one state, or advance by two when the destination
 * is not [blankId] and differs from the state two positions back. The path may start in
 * state 0 or 1 and is taken to end in the better of the last two states.
 *
 * @param logProbs per-frame log-probabilities indexed by vocabulary id.
 * @param expandedTarget vocabulary id of each state.
 * @return array of length `logProbs.size`; empty when either input is empty.
 */
private fun ctcBacktrack(
    logProbs: Array<FloatArray>,
    expandedTarget: IntArray,
    blankId: Int,
): IntArray {
    val frameCount = logProbs.size
    val stateCount = expandedTarget.size
    if (frameCount == 0 || stateCount == 0) {
        return IntArray(0)
    }

    val unreachable = -1.0e30f

    // True when state s may be entered directly from s - 2 (skipping the state in between).
    fun skipAllowed(s: Int): Boolean =
        s > 1 && expandedTarget[s] != blankId && expandedTarget[s] != expandedTarget[s - 2]

    val trellis = Array(frameCount) { FloatArray(stateCount) { unreachable } }
    for (s in 0 until min(2, stateCount)) {
        trellis[0][s] = safeLogProb(logProbs[0], expandedTarget[s])
    }

    // Forward pass: best cumulative score of reaching each state at each frame.
    for (t in 1 until frameCount) {
        val before = trellis[t - 1]
        val current = trellis[t]
        for (s in 0 until stateCount) {
            var incoming = before[s]
            if (s > 0) {
                incoming = max(incoming, before[s - 1])
            }
            if (skipAllowed(s)) {
                incoming = max(incoming, before[s - 2])
            }
            current[s] = if (incoming <= unreachable / 2) {
                unreachable
            } else {
                incoming + safeLogProb(logProbs[t], expandedTarget[s])
            }
        }
    }

    val lastRow = trellis[frameCount - 1]
    var cursor = if (stateCount > 1 && lastRow[stateCount - 2] > lastRow[stateCount - 1]) {
        stateCount - 2
    } else {
        stateCount - 1
    }

    // Backward pass: at each earlier frame pick the predecessor with the strictly best score,
    // preferring "stay", then "advance by one", then "skip" on ties.
    val path = IntArray(frameCount)
    path[frameCount - 1] = cursor
    for (t in frameCount - 2 downTo 0) {
        val scores = trellis[t]
        var pick = cursor
        var pickScore = scores[cursor]

        if (cursor > 0 && scores[cursor - 1] > pickScore) {
            pickScore = scores[cursor - 1]
            pick = cursor - 1
        }
        if (skipAllowed(cursor) && scores[cursor - 2] > pickScore) {
            pick = cursor - 2
        }

        cursor = pick
        path[t] = cursor
    }

    return path
}
536
+
537
/** Returns `row[tokenId]`, or -1.0e30 when [tokenId] is outside the bounds of [row]. */
private fun safeLogProb(row: FloatArray, tokenId: Int): Float =
    row.getOrElse(tokenId) { -1.0e30f }
543
+
544
/** Converts [items] into a React Native array of maps with the keys `text`, `start` and `end`. */
private fun toWritableArray(items: List<AlignmentItem>): WritableArray {
    val out = Arguments.createArray()
    items.forEach { item ->
        val entry: WritableMap = Arguments.createMap()
        with(entry) {
            putString("text", item.text)
            putDouble("start", item.start)
            putDouble("end", item.end)
        }
        out.pushMap(entry)
    }
    return out
}
555
+ }
@@ -56,6 +56,7 @@ class SherpaOnnxModule(reactContext: ReactApplicationContext) :
56
56
  { instanceId, requestId, message -> emitTtsStreamError(instanceId, requestId, message) },
57
57
  { instanceId, requestId, cancelled -> emitTtsStreamEnd(instanceId, requestId, cancelled) }
58
58
  )
59
+ private val alignmentHelper = SherpaOnnxAlignmentHelper(reactApplicationContext)
59
60
  private val enhancementHelper = SherpaOnnxEnhancementHelper(
60
61
  reactApplicationContext,
61
62
  { modelDir, modelType -> Companion.nativeDetectEnhancementModel(modelDir, modelType) }
@@ -73,6 +74,7 @@ class SherpaOnnxModule(reactContext: ReactApplicationContext) :
73
74
  pcmCapture = null
74
75
  onlineSttHelper.shutdown()
75
76
  ttsHelper.shutdown()
77
+ alignmentHelper.shutdown()
76
78
  enhancementHelper.shutdown()
77
79
  }
78
80
 
@@ -899,6 +901,7 @@ class SherpaOnnxModule(reactContext: ReactApplicationContext) :
899
901
  val detectedModels = result["detectedModels"] as? ArrayList<*>
900
902
  ?: arrayListOf<HashMap<String, String>>()
901
903
  val modelTypeStr = result["modelType"] as? String
904
+ val paths = result["paths"] as? HashMap<*, *>
902
905
 
903
906
  val resultMap = Arguments.createMap()
904
907
  resultMap.putBoolean("success", success)
@@ -916,6 +919,12 @@ class SherpaOnnxModule(reactContext: ReactApplicationContext) :
916
919
  if (modelTypeStr != null) {
917
920
  resultMap.putString("modelType", modelTypeStr)
918
921
  }
922
+ val modelPath = paths?.get("model") as? String
923
+ if (!modelPath.isNullOrBlank()) {
924
+ val pathsMap = Arguments.createMap()
925
+ pathsMap.putString("model", modelPath)
926
+ resultMap.putMap("paths", pathsMap)
927
+ }
919
928
  if (!success) {
920
929
  val error = result["error"] as? String
921
930
  if (!error.isNullOrBlank()) {
@@ -964,6 +973,16 @@ class SherpaOnnxModule(reactContext: ReactApplicationContext) :
964
973
  ttsHelper.generateTtsWithTimestamps(instanceId, text, options, promise)
965
974
  }
966
975
 
976
/**
 * Bridge entry point for CTC forced alignment; forwards every argument unchanged to the
 * alignment helper, which settles [promise].
 */
override fun runCTCForcedAlignment(
    modelPath: String,
    audioPath: String,
    text: String,
    vocabJson: String,
    promise: Promise,
) {
    alignmentHelper.runCTCForcedAlignment(
        modelPath = modelPath,
        audioPath = audioPath,
        text = text,
        vocabJson = vocabJson,
        promise = promise,
    )
}
985
+
967
986
  /**
968
987
  * Generate speech in streaming mode (emits chunk events).
969
988
  */
@@ -1074,6 +1093,59 @@ class SherpaOnnxModule(reactContext: ReactApplicationContext) :
1074
1093
  enhancementHelper.detectEnhancementModel(modelDir, modelType, promise)
1075
1094
  }
1076
1095
 
1096
/**
 * Detects an alignment model in [modelDir] through the native detector and resolves [promise]
 * with a map containing `success`, `detectedModels` (entries with `type` and `modelDir`) and,
 * when available, `modelType`, `paths.model` and — on failure — `error`.
 *
 * A null [modelType] is passed to the native side as `"auto"`. A null native result or any
 * exception rejects the promise with code `DETECT_ERROR`.
 */
override fun detectAlignmentModel(
    modelDir: String,
    modelType: String?,
    promise: Promise
) {
    try {
        val detection = Companion.nativeDetectAlignmentModel(modelDir, modelType ?: "auto")
        if (detection == null) {
            android.util.Log.e(NAME, "DETECT_ERROR: Alignment model detection returned null")
            promise.reject("DETECT_ERROR", "Alignment model detection returned null")
            return
        }

        val succeeded = detection["success"] as? Boolean ?: false

        // Entries that are not maps are skipped; missing fields become empty strings.
        val modelsArray = Arguments.createArray()
        (detection["detectedModels"] as? ArrayList<*>)
            ?.filterIsInstance<HashMap<*, *>>()
            ?.forEach { model ->
                val entry = Arguments.createMap()
                entry.putString("type", model["type"] as? String ?: "")
                entry.putString("modelDir", model["modelDir"] as? String ?: "")
                modelsArray.pushMap(entry)
            }

        val resultMap = Arguments.createMap()
        resultMap.putBoolean("success", succeeded)
        resultMap.putArray("detectedModels", modelsArray)

        (detection["modelType"] as? String)?.let { detectedType ->
            resultMap.putString("modelType", detectedType)
        }

        val alignmentModelPath = (detection["paths"] as? HashMap<*, *>)?.get("model") as? String
        if (!alignmentModelPath.isNullOrBlank()) {
            val pathsMap = Arguments.createMap()
            pathsMap.putString("model", alignmentModelPath)
            resultMap.putMap("paths", pathsMap)
        }

        if (!succeeded) {
            val error = detection["error"] as? String
            if (!error.isNullOrBlank()) {
                resultMap.putString("error", error)
            }
        }

        promise.resolve(resultMap)
    } catch (e: Exception) {
        android.util.Log.e(NAME, "DETECT_ERROR: Alignment model detection failed: ${e.message}", e)
        promise.reject("DETECT_ERROR", "Alignment model detection failed: ${e.message}", e)
    }
}
1148
+
1077
1149
  override fun initializeEnhancement(
1078
1150
  instanceId: String,
1079
1151
  modelDir: String,
@@ -1362,6 +1434,10 @@ class SherpaOnnxModule(reactContext: ReactApplicationContext) :
1362
1434
  @JvmStatic
1363
1435
  private external fun nativeDetectEnhancementModel(modelDir: String, modelType: String): HashMap<String, Any>?
1364
1436
 
1437
+ /** Model detection for subtitles/alignment: returns a HashMap with success, error, detectedModels, modelType and paths. The result is nullable; detectAlignmentModel rejects with DETECT_ERROR when it is null. */
1438
+ @JvmStatic
1439
+ private external fun nativeDetectAlignmentModel(modelDir: String, modelType: String): HashMap<String, Any>?
1440
+
1365
1441
  /** Convert arbitrary audio file to requested format (e.g. "mp3", "flac", "wav").
1366
1442
  * outputSampleRateHz: for MP3 use 32000/44100/48000, 0 = default 44100. Ignored for WAV/FLAC.
1367
1443
  * Returns empty string on success, or an error message otherwise. Requires FFmpeg prebuilts when called on Android.