react-native-sherpa-onnx 0.4.0 → 0.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +3 -0
- package/android/src/main/assets/model_licenses/alignment-models-license-status.csv +5 -0
- package/android/src/main/cpp/CMakeLists.txt +3 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-alignment-wrapper.cpp +66 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-alignment-wrapper.h +17 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-alignment.cpp +108 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect.h +30 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-validate-alignment.cpp +66 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-validate-alignment.h +30 -0
- package/android/src/main/cpp/jni/module/sherpa-onnx-module-jni.cpp +21 -0
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxAlignmentHelper.kt +555 -0
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxModule.kt +76 -0
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxTextSegmenter.kt +330 -0
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxTtsHelper.kt +180 -23
- package/ios/Resources/model_licenses/alignment-models-license-status.csv +5 -0
- package/ios/SherpaOnnx+Alignment.mm +704 -0
- package/ios/SherpaOnnx+STT.mm +6 -0
- package/ios/SherpaOnnx+TTS.mm +624 -50
- package/ios/model_detect/sherpa-onnx-model-detect-alignment.mm +108 -0
- package/ios/model_detect/sherpa-onnx-model-detect.h +31 -0
- package/ios/model_detect/sherpa-onnx-validate-alignment.h +30 -0
- package/ios/model_detect/sherpa-onnx-validate-alignment.mm +66 -0
- package/ios/stt/sherpa-onnx-stt-wrapper.h +3 -1
- package/ios/stt/sherpa-onnx-stt-wrapper.mm +6 -0
- package/lib/module/NativeSherpaOnnx.js.map +1 -1
- package/lib/module/alignment/index.js +27 -0
- package/lib/module/alignment/index.js.map +1 -0
- package/lib/module/alignment/types.js +2 -0
- package/lib/module/alignment/types.js.map +1 -0
- package/lib/module/alignment/vocab.js +40 -0
- package/lib/module/alignment/vocab.js.map +1 -0
- package/lib/module/download/paths.js +9 -1
- package/lib/module/download/paths.js.map +1 -1
- package/lib/module/download/registry.js +17 -1
- package/lib/module/download/registry.js.map +1 -1
- package/lib/module/download/types.js +1 -0
- package/lib/module/download/types.js.map +1 -1
- package/lib/module/index.js +6 -4
- package/lib/module/index.js.map +1 -1
- package/lib/module/licenses.js +8 -2
- package/lib/module/licenses.js.map +1 -1
- package/lib/module/stt/types.js.map +1 -1
- package/lib/module/tts/index.js +68 -2
- package/lib/module/tts/index.js.map +1 -1
- package/lib/module/tts/subtitles.js +400 -0
- package/lib/module/tts/subtitles.js.map +1 -0
- package/lib/module/tts/tempAudio.js +17 -0
- package/lib/module/tts/tempAudio.js.map +1 -0
- package/lib/module/tts/types.js.map +1 -1
- package/lib/typescript/src/NativeSherpaOnnx.d.ts +34 -3
- package/lib/typescript/src/NativeSherpaOnnx.d.ts.map +1 -1
- package/lib/typescript/src/alignment/index.d.ts +8 -0
- package/lib/typescript/src/alignment/index.d.ts.map +1 -0
- package/lib/typescript/src/alignment/types.d.ts +23 -0
- package/lib/typescript/src/alignment/types.d.ts.map +1 -0
- package/lib/typescript/src/alignment/vocab.d.ts +5 -0
- package/lib/typescript/src/alignment/vocab.d.ts.map +1 -0
- package/lib/typescript/src/download/paths.d.ts +5 -2
- package/lib/typescript/src/download/paths.d.ts.map +1 -1
- package/lib/typescript/src/download/registry.d.ts.map +1 -1
- package/lib/typescript/src/download/types.d.ts +2 -1
- package/lib/typescript/src/download/types.d.ts.map +1 -1
- package/lib/typescript/src/index.d.ts +1 -0
- package/lib/typescript/src/index.d.ts.map +1 -1
- package/lib/typescript/src/licenses.d.ts.map +1 -1
- package/lib/typescript/src/stt/types.d.ts +5 -2
- package/lib/typescript/src/stt/types.d.ts.map +1 -1
- package/lib/typescript/src/tts/index.d.ts +2 -1
- package/lib/typescript/src/tts/index.d.ts.map +1 -1
- package/lib/typescript/src/tts/subtitles.d.ts +24 -0
- package/lib/typescript/src/tts/subtitles.d.ts.map +1 -0
- package/lib/typescript/src/tts/tempAudio.d.ts +3 -0
- package/lib/typescript/src/tts/tempAudio.d.ts.map +1 -0
- package/lib/typescript/src/tts/types.d.ts +68 -2
- package/lib/typescript/src/tts/types.d.ts.map +1 -1
- package/package.json +6 -1
- package/scripts/alignment-models/README.md +90 -0
- package/scripts/alignment-models/build_and_upload.js +724 -0
- package/scripts/alignment-models/sources.csv +5 -0
- package/scripts/alignment-models/sync_alignment_license_status.js +123 -0
- package/src/NativeSherpaOnnx.ts +35 -3
- package/src/alignment/index.ts +41 -0
- package/src/alignment/types.ts +22 -0
- package/src/alignment/vocab.ts +38 -0
- package/src/download/paths.ts +18 -5
- package/src/download/registry.ts +23 -3
- package/src/download/types.ts +1 -0
- package/src/index.tsx +6 -4
- package/src/licenses.ts +12 -1
- package/src/stt/types.ts +5 -2
- package/src/tts/index.ts +110 -3
- package/src/tts/subtitles.ts +611 -0
- package/src/tts/tempAudio.ts +31 -0
- package/src/tts/types.ts +79 -2
- package/third_party/sherpa-onnx-prebuilt/IOS_RELEASE_TAG +1 -1
|
@@ -0,0 +1,555 @@
|
|
|
1
|
+
package com.sherpaonnx
|
|
2
|
+
|
|
3
|
+
import ai.onnxruntime.OnnxTensor
|
|
4
|
+
import ai.onnxruntime.OrtEnvironment
|
|
5
|
+
import ai.onnxruntime.OrtSession
|
|
6
|
+
import ai.onnxruntime.TensorInfo
|
|
7
|
+
import android.net.Uri
|
|
8
|
+
import android.util.Log
|
|
9
|
+
import com.facebook.react.bridge.Arguments
|
|
10
|
+
import com.facebook.react.bridge.Promise
|
|
11
|
+
import com.facebook.react.bridge.ReactApplicationContext
|
|
12
|
+
import com.facebook.react.bridge.WritableArray
|
|
13
|
+
import com.facebook.react.bridge.WritableMap
|
|
14
|
+
import com.k2fsa.sherpa.onnx.WaveReader
|
|
15
|
+
import org.json.JSONObject
|
|
16
|
+
import java.io.File
|
|
17
|
+
import java.io.FileOutputStream
|
|
18
|
+
import java.nio.FloatBuffer
|
|
19
|
+
import java.util.Locale
|
|
20
|
+
import java.util.concurrent.Executors
|
|
21
|
+
import kotlin.math.exp
|
|
22
|
+
import kotlin.math.floor
|
|
23
|
+
import kotlin.math.ln
|
|
24
|
+
import kotlin.math.max
|
|
25
|
+
import kotlin.math.min
|
|
26
|
+
import kotlin.math.sqrt
|
|
27
|
+
|
|
28
|
+
internal class SherpaOnnxAlignmentHelper(
|
|
29
|
+
private val context: ReactApplicationContext
|
|
30
|
+
) {
|
|
31
|
+
private val executor = Executors.newSingleThreadExecutor()
|
|
32
|
+
|
|
33
|
+
/**
 * One aligned unit of the transcript (a single character token or a whole word)
 * together with the time span assigned to it.
 *
 * @property text The token or word text.
 * @property start Start of the span; callers in this file compute it as a frame index times 0.02.
 * @property end End of the span; callers in this file never make it smaller than [start].
 */
private data class AlignmentItem(val text: String, val start: Double, val end: Double)
|
|
38
|
+
|
|
39
|
+
/**
 * Blank-interleaved CTC target sequence.
 *
 * @property ids Vocabulary id for every expanded state.
 * @property tokenIndices For every expanded state, the index of the transcript token it
 *   belongs to, or -1 when the state does not belong to a token.
 *
 * The generated `equals`/`hashCode`/`toString` of a data class compare and print arrays by
 * identity, so they are overridden here to use the array contents instead.
 */
private data class ExpandedTarget(
    val ids: IntArray,
    val tokenIndices: IntArray,
) {
    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        if (other !is ExpandedTarget) return false
        return ids.contentEquals(other.ids) && tokenIndices.contentEquals(other.tokenIndices)
    }

    override fun hashCode(): Int = 31 * ids.contentHashCode() + tokenIndices.contentHashCode()

    override fun toString(): String =
        "ExpandedTarget(ids=${ids.contentToString()}, tokenIndices=${tokenIndices.contentToString()})"
}
|
|
43
|
+
|
|
44
|
+
/**
 * Shuts the alignment worker down by calling `shutdownNow()` on the single-thread executor.
 * Any list of never-started tasks returned by that call is discarded.
 */
fun shutdown() {
    executor.shutdownNow()
}
|
|
47
|
+
|
|
48
|
+
/**
 * Runs CTC forced alignment of [text] against the audio at [audioPath] on the helper's
 * executor and settles [promise] with the outcome.
 *
 * On success the promise resolves with a map holding two arrays, `"words"` and `"chars"`;
 * each entry has `text`, `start` and `end`. Every failure rejects with code
 * `"ALIGNMENT_ERROR"`.
 *
 * @param modelPath Path of the ONNX model handed to [runInference]; must not be blank.
 * @param audioPath File path or `content://` URI of the audio; must not be blank.
 * @param text Transcript to align; must not be blank.
 * @param vocabJson JSON object mapping token strings to integer ids.
 * @param promise Promise that receives the result or the rejection.
 */
fun runCTCForcedAlignment(
    modelPath: String,
    audioPath: String,
    text: String,
    vocabJson: String,
    promise: Promise,
) {
    try {
        executor.execute {
            // Set only when resolveAudioPath copied the input to a temp file that must be removed.
            var cleanupPath: String? = null
            try {
                if (modelPath.isBlank()) {
                    promise.reject("ALIGNMENT_ERROR", "modelPath is required")
                    return@execute
                }
                if (audioPath.isBlank()) {
                    promise.reject("ALIGNMENT_ERROR", "audioPath is required")
                    return@execute
                }
                if (text.isBlank()) {
                    promise.reject("ALIGNMENT_ERROR", "text is required")
                    return@execute
                }

                val resolvedAudio = resolveAudioPath(audioPath)
                cleanupPath = resolvedAudio.second

                val file = File(resolvedAudio.first)
                if (!file.exists() || file.length() <= 0L) {
                    promise.reject("ALIGNMENT_ERROR", "Audio file does not exist or is empty: ${resolvedAudio.first}")
                    return@execute
                }

                val vocab = parseVocab(vocabJson)
                // Fallback ids used when the vocabulary lacks the "<pad>" / "|" entries.
                val blankId = vocab["<pad>"] ?: 0
                val wordBoundaryId = vocab["|"] ?: 4

                val tokenTexts = buildTokenTexts(text, vocab, wordBoundaryId)
                if (tokenTexts.isEmpty()) {
                    promise.reject("ALIGNMENT_ERROR", "Transcript has no alignable tokens for the provided vocabulary")
                    return@execute
                }

                val tokenIds = IntArray(tokenTexts.size) { index ->
                    vocab[tokenTexts[index]] ?: blankId
                }

                val wave = WaveReader.readWave(resolvedAudio.first)
                val rawSamples = wave.samples ?: FloatArray(0)
                if (rawSamples.isEmpty()) {
                    promise.reject("ALIGNMENT_ERROR", "Could not decode WAV samples from: ${resolvedAudio.first}")
                    return@execute
                }

                // The model input is fed at 16000 Hz; anything else is linearly resampled first.
                val mono16k = if (wave.sampleRate == 16000) {
                    rawSamples
                } else {
                    resampleLinear(rawSamples, wave.sampleRate, 16000)
                }

                val normalized = normalizeAudio(mono16k)
                if (normalized.isEmpty()) {
                    promise.reject("ALIGNMENT_ERROR", "Audio is empty after preprocessing")
                    return@execute
                }

                val logits = runInference(modelPath, normalized)
                if (logits.isEmpty() || logits[0].isEmpty()) {
                    promise.reject("ALIGNMENT_ERROR", "Model inference returned empty logits")
                    return@execute
                }

                val expanded = buildExpandedTarget(tokenIds, blankId)
                val path = ctcBacktrack(logits, expanded.ids, blankId)

                val framesByToken = groupFramesByToken(path, expanded, tokenIds.size, blankId)
                val charItems = buildCharItems(tokenTexts, framesByToken)
                val wordItems = buildWordItems(tokenTexts, charItems)

                val result = Arguments.createMap()
                result.putArray("words", toWritableArray(wordItems))
                result.putArray("chars", toWritableArray(charItems))
                promise.resolve(result)
            } catch (e: Exception) {
                Log.e("SherpaOnnxAlignment", "ALIGNMENT_ERROR: ${e.message}", e)
                promise.reject("ALIGNMENT_ERROR", e.message ?: "CTC alignment failed", e)
            } finally {
                cleanupPath?.let { tempPath ->
                    try {
                        File(tempPath).delete()
                    } catch (_: Exception) {
                        // ignore cleanup errors
                    }
                }
            }
        }
    } catch (e: java.util.concurrent.RejectedExecutionException) {
        // execute() throws once shutdown() has been called; settle the promise instead of
        // letting the exception escape and leaving the caller waiting forever.
        Log.e("SherpaOnnxAlignment", "ALIGNMENT_ERROR: alignment executor rejected the task", e)
        promise.reject("ALIGNMENT_ERROR", "Alignment helper has been shut down", e)
    }
}

/** Collects, per transcript token, the frame indices whose best-path state is that token. */
private fun groupFramesByToken(
    path: IntArray,
    expanded: ExpandedTarget,
    tokenCount: Int,
    blankId: Int,
): Array<MutableList<Int>> {
    val framesByToken = Array(tokenCount) { mutableListOf<Int>() }
    for ((frame, state) in path.withIndex()) {
        if (state < 0 || state >= expanded.tokenIndices.size) {
            continue
        }
        val tokenIndex = expanded.tokenIndices[state]
        val tokenId = expanded.ids[state]
        // Blank states (tokenIndex -1, or an id equal to blankId) contribute no frames.
        if (tokenIndex >= 0 && tokenIndex < framesByToken.size && tokenId != blankId) {
            framesByToken[tokenIndex].add(frame)
        }
    }
    return framesByToken
}

/** Turns per-token frame lists into timed items, one per token that is not the "|" separator. */
private fun buildCharItems(
    tokenTexts: List<String>,
    framesByToken: Array<MutableList<Int>>,
): List<AlignmentItem> {
    val charItems = mutableListOf<AlignmentItem>()
    // End frame of the latest token that had frames; tokens without frames collapse onto it.
    var fallbackEndFrame = 0

    for ((index, token) in tokenTexts.withIndex()) {
        if (token == "|") {
            continue
        }

        val frames = framesByToken[index]
        val startFrame: Int
        val endFrameExclusive: Int
        if (frames.isNotEmpty()) {
            startFrame = frames.first()
            endFrameExclusive = frames.last() + 1
            fallbackEndFrame = max(fallbackEndFrame, endFrameExclusive)
        } else {
            startFrame = fallbackEndFrame
            endFrameExclusive = fallbackEndFrame
        }

        // 0.02 per frame is hard-coded — presumably a 20 ms model stride; confirm against the model.
        val start = startFrame * 0.02
        val end = max(start, endFrameExclusive * 0.02)
        charItems.add(AlignmentItem(token, start, end))
    }
    return charItems
}

/** Merges consecutive character items into words, splitting at every "|" token. */
private fun buildWordItems(
    tokenTexts: List<String>,
    charItems: List<AlignmentItem>,
): List<AlignmentItem> {
    val wordItems = mutableListOf<AlignmentItem>()
    val currentWord = StringBuilder()
    var wordStart = 0.0
    var wordEnd = 0.0
    // charItems has no entries for "|" tokens, so it is walked with its own cursor.
    var charCursor = 0

    for (token in tokenTexts) {
        if (token == "|") {
            if (currentWord.isNotEmpty()) {
                wordItems.add(AlignmentItem(currentWord.toString(), wordStart, wordEnd))
                currentWord.clear()
            }
            continue
        }

        val charItem = charItems.getOrNull(charCursor)
        charCursor += 1
        if (charItem == null) {
            continue
        }

        if (currentWord.isEmpty()) {
            wordStart = charItem.start
            wordEnd = charItem.end
        } else {
            wordEnd = max(wordEnd, charItem.end)
        }
        currentWord.append(charItem.text)
    }

    if (currentWord.isNotEmpty()) {
        wordItems.add(AlignmentItem(currentWord.toString(), wordStart, wordEnd))
    }
    return wordItems
}
|
|
213
|
+
|
|
214
|
+
/**
 * Resolves [audioPath] to a path that can be opened as a regular file.
 *
 * Anything that does not start with `content://` is returned unchanged. A `content://` URI
 * is copied into a new temp file in the app cache directory.
 *
 * @return A pair of (path to read, path the caller must delete afterwards or `null`).
 * @throws IllegalStateException when the content URI cannot be opened.
 */
private fun resolveAudioPath(audioPath: String): Pair<String, String?> {
    if (!audioPath.startsWith("content://")) {
        return Pair(audioPath, null)
    }

    val uri = Uri.parse(audioPath)
    val tempFile = File.createTempFile("alignment_input_", ".wav", context.cacheDir)
    try {
        val input = context.contentResolver.openInputStream(uri)
            ?: throw IllegalStateException("Could not open content URI: $audioPath")
        input.use { source ->
            FileOutputStream(tempFile).use { sink ->
                source.copyTo(sink)
            }
        }
    } catch (e: Exception) {
        // The caller only learns the cleanup path on success, so remove the temp file here.
        tempFile.delete()
        throw e
    }

    return Pair(tempFile.absolutePath, tempFile.absolutePath)
}
|
|
229
|
+
|
|
230
|
+
/**
 * Parses a JSON object of token -> id pairs into an insertion-ordered map.
 *
 * Blank keys are skipped, as are entries for which `optInt` falls back to its default.
 *
 * @throws IllegalArgumentException when no usable entry remains.
 */
private fun parseVocab(vocabJson: String): Map<String, Int> {
    val json = JSONObject(vocabJson)
    val vocab = linkedMapOf<String, Int>()

    for (key in json.keys()) {
        if (key.isBlank()) continue
        // Int.MIN_VALUE is the sentinel default meaning "no integer value for this key".
        val id = json.optInt(key, Int.MIN_VALUE)
        if (id == Int.MIN_VALUE) continue
        vocab[key] = id
    }

    require(vocab.isNotEmpty()) { "Vocabulary JSON is empty" }
    return vocab
}
|
|
249
|
+
|
|
250
|
+
/**
 * Splits [text] into single-character tokens that exist in [vocab], using "|" between words.
 *
 * The text is upper-cased with `Locale.US`. Whitespace runs become one "|" separator, the
 * characters ’ ` ´ are mapped to an apostrophe, and characters missing from [vocab] are
 * dropped. Leading and trailing "|" are trimmed. All "|" tokens are removed when [vocab]
 * does not map "|" to [wordBoundaryId].
 */
private fun buildTokenTexts(
    text: String,
    vocab: Map<String, Int>,
    wordBoundaryId: Int,
): List<String> {
    val tokens = mutableListOf<String>()

    text.uppercase(Locale.US).forEach { ch ->
        if (ch.isWhitespace()) {
            // Only emit a separator after a real token, never twice in a row.
            val previous = tokens.lastOrNull()
            if (previous != null && previous != "|") {
                tokens.add("|")
            }
        } else {
            val symbol = if (ch == '’' || ch == '`' || ch == '´') "'" else ch.toString()
            if (symbol in vocab) {
                tokens.add(symbol)
            }
        }
    }

    // Trim separators from both ends; an all-separator list collapses to empty.
    val firstKept = tokens.indexOfFirst { it != "|" }
    val trimmed = if (firstKept < 0) {
        emptyList()
    } else {
        tokens.subList(firstKept, tokens.indexOfLast { it != "|" } + 1).toList()
    }

    val boundaryUsable = vocab.containsKey("|") && vocab["|"] == wordBoundaryId
    return if (boundaryUsable) trimmed else trimmed.filter { it != "|" }
}
|
|
289
|
+
|
|
290
|
+
/**
 * Resamples [input] from [sourceSampleRate] to [targetSampleRate] by linear interpolation.
 *
 * Returns an empty array for empty input or a non-positive rate, and returns [input] itself
 * when both rates are equal. Otherwise the output has
 * `floor(input.size * target / source)` samples, but never fewer than one.
 */
private fun resampleLinear(
    input: FloatArray,
    sourceSampleRate: Int,
    targetSampleRate: Int,
): FloatArray {
    if (input.isEmpty() || sourceSampleRate <= 0 || targetSampleRate <= 0) return FloatArray(0)
    if (sourceSampleRate == targetSampleRate) return input

    val lastIndex = input.lastIndex
    val scaledLength = floor(input.size.toDouble() * targetSampleRate / sourceSampleRate).toInt()
    val outputLength = max(1, scaledLength)
    // Distance in source samples covered by one output sample.
    val step = sourceSampleRate.toDouble() / targetSampleRate.toDouble()

    return FloatArray(outputLength) { i ->
        val position = i * step
        val lower = floor(position).toInt()
        val weight = position - lower
        // Both neighbours are clamped to the last sample so the tail never reads out of bounds.
        val a = input[min(lower, lastIndex)]
        val b = input[min(lower + 1, lastIndex)]
        (a + (b - a) * weight).toFloat()
    }
}
|
|
318
|
+
|
|
319
|
+
/**
 * Returns a copy of [input] shifted to zero mean and scaled by its standard deviation.
 *
 * The variance is floored at 1e-12 before the square root, so constant input maps to zeros
 * instead of dividing by zero. Empty input is returned as-is.
 */
private fun normalizeAudio(input: FloatArray): FloatArray {
    if (input.isEmpty()) return input

    val count = input.size
    val mean = input.sumOf { it.toDouble() } / count
    val squaredDeviation = input.sumOf { sample ->
        val centered = sample - mean
        centered * centered
    }
    val std = sqrt(max(squaredDeviation / count, 1e-12))

    return FloatArray(count) { i -> ((input[i] - mean) / std).toFloat() }
}
|
|
343
|
+
|
|
344
|
+
/**
 * Loads the ONNX model at [modelPath], feeds [samples] as a `[1, samples.size]` float tensor
 * to its first input and returns per-frame log-probabilities of the first output.
 *
 * A new session is created and closed on every call.
 *
 * @return One row per frame, each row holding the log-softmax over the vocabulary axis;
 *   an empty array when the output holds no values.
 * @throws IllegalStateException when the model has no input, or the first output is not a
 *   tensor, has no tensor info, or cannot be read as floats.
 */
private fun runInference(modelPath: String, samples: FloatArray): Array<FloatArray> {
    val env = OrtEnvironment.getEnvironment()

    return OrtSession.SessionOptions().use { sessionOptions ->
        env.createSession(modelPath, sessionOptions).use { session ->
            val inputName = session.inputNames.firstOrNull()
                ?: throw IllegalStateException("Alignment model has no input")

            val inputShape = longArrayOf(1L, samples.size.toLong())
            OnnxTensor.createTensor(env, FloatBuffer.wrap(samples), inputShape).use { inputTensor ->
                session.run(mapOf(inputName to inputTensor)).use { result ->
                    val outputTensor = result.get(0) as? OnnxTensor
                        ?: throw IllegalStateException("Alignment model output is not a tensor")

                    val info = outputTensor.info as? TensorInfo
                        ?: throw IllegalStateException("Alignment tensor info missing")

                    val shape = info.shape
                    // getFloatBuffer() yields null for tensors that are not float-convertible;
                    // fail with a descriptive error instead of a bare NullPointerException.
                    val floatBuffer = outputTensor.floatBuffer
                        ?: throw IllegalStateException("Alignment model output is not a float tensor")
                    floatBuffer.rewind()

                    val totalValues = floatBuffer.remaining()
                    if (totalValues <= 0) {
                        emptyArray<FloatArray>()
                    } else {
                        val logitsFlat = FloatArray(totalValues)
                        floatBuffer.get(logitsFlat)

                        // Rank >= 3 is read as [batch, frames, vocab], rank 2 as [frames, vocab];
                        // anything else is treated as a single frame.
                        val rank = shape.size
                        val declaredFrames = when {
                            rank >= 3 -> shape[1].toInt()
                            rank == 2 -> shape[0].toInt()
                            else -> 1
                        }
                        val declaredVocab = when {
                            rank >= 3 -> shape[2].toInt()
                            rank == 2 -> shape[1].toInt()
                            else -> totalValues
                        }
                        val frames = max(1, declaredFrames)
                        val vocabSize = max(1, declaredVocab)

                        // Clamp both dimensions so frames * vocab never exceeds the data actually read.
                        val safeFrames = max(1, min(frames, totalValues))
                        val safeVocab = max(1, min(vocabSize, totalValues / safeFrames))

                        logSoftmax(logitsFlat, safeFrames, safeVocab)
                    }
                }
            }
        }
    }
}
|
|
399
|
+
|
|
400
|
+
/**
 * Applies a row-wise log-softmax to a flat, row-major `[frames x vocabSize]` buffer.
 *
 * Each row is shifted by its maximum before exponentiation, and the exponential sum is
 * floored at 1e-12 before taking the logarithm.
 *
 * @return [frames] rows of [vocabSize] log-probabilities.
 */
private fun logSoftmax(
    logitsFlat: FloatArray,
    frames: Int,
    vocabSize: Int,
): Array<FloatArray> = Array(frames) { t ->
    val base = t * vocabSize

    var peak = Float.NEGATIVE_INFINITY
    for (v in 0 until vocabSize) {
        val candidate = logitsFlat[base + v]
        if (candidate > peak) peak = candidate
    }

    var expSum = 0.0
    for (v in 0 until vocabSize) {
        expSum += exp((logitsFlat[base + v] - peak).toDouble())
    }
    val logDenom = peak + ln(max(expSum, 1e-12))

    FloatArray(vocabSize) { v -> (logitsFlat[base + v] - logDenom).toFloat() }
}
|
|
430
|
+
|
|
431
|
+
/**
 * Builds the blank-interleaved CTC state sequence `blank, t0, blank, t1, ..., blank`.
 *
 * Odd states carry the token ids and point back at their token index; even states hold
 * [blankId] and a token index of -1.
 */
private fun buildExpandedTarget(tokenIds: IntArray, blankId: Int): ExpandedTarget {
    val stateCount = tokenIds.size * 2 + 1
    val ids = IntArray(stateCount) { state ->
        if (state % 2 == 1) tokenIds[state / 2] else blankId
    }
    val tokenIndices = IntArray(stateCount) { state ->
        if (state % 2 == 1) state / 2 else -1
    }
    return ExpandedTarget(ids, tokenIndices)
}
|
|
449
|
+
|
|
450
|
+
/**
 * Finds the best-scoring monotonic path through the expanded CTC target and returns, for
 * every frame, the index of the expanded state occupied at that frame.
 *
 * The path starts in state 0 or 1. A step may stay, advance by one state, or skip over a
 * blank when the destination is a non-blank that differs from the state two back. The path
 * ends in whichever of the last two states scores higher.
 *
 * @param logProbs Per-frame log-probabilities, one row per frame.
 * @param expandedTarget Vocabulary id of every expanded state.
 * @param blankId Vocabulary id of the blank symbol.
 * @return One state index per frame, or an empty array when there are no frames or states.
 */
private fun ctcBacktrack(
    logProbs: Array<FloatArray>,
    expandedTarget: IntArray,
    blankId: Int,
): IntArray {
    val frameCount = logProbs.size
    val stateCount = expandedTarget.size
    if (frameCount == 0 || stateCount == 0) return IntArray(0)

    // Finite stand-in for "unreachable" so sums never turn into NaN or -Infinity.
    val unreachable = -1.0e30f

    // True when state s may be entered directly from state s - 2.
    fun skipAllowed(s: Int): Boolean =
        s > 1 && expandedTarget[s] != blankId && expandedTarget[s] != expandedTarget[s - 2]

    val scores = Array(frameCount) { FloatArray(stateCount) { unreachable } }
    scores[0][0] = safeLogProb(logProbs[0], expandedTarget[0])
    if (stateCount > 1) {
        scores[0][1] = safeLogProb(logProbs[0], expandedTarget[1])
    }

    // Forward pass: best cumulative score of reaching each state at each frame.
    for (t in 1 until frameCount) {
        val previous = scores[t - 1]
        val current = scores[t]
        for (s in 0 until stateCount) {
            var incoming = previous[s]
            if (s > 0) {
                incoming = max(incoming, previous[s - 1])
            }
            if (skipAllowed(s)) {
                incoming = max(incoming, previous[s - 2])
            }

            current[s] = if (incoming <= unreachable / 2) {
                unreachable
            } else {
                incoming + safeLogProb(logProbs[t], expandedTarget[s])
            }
        }
    }

    val lastRow = scores[frameCount - 1]
    var state = if (stateCount > 1 && lastRow[stateCount - 2] > lastRow[stateCount - 1]) {
        stateCount - 2
    } else {
        stateCount - 1
    }

    val path = IntArray(frameCount)
    path[frameCount - 1] = state

    // Backward pass: at each frame pick the best-scoring allowed predecessor; ties keep the
    // smaller jump because only strictly greater scores replace the current choice.
    for (t in frameCount - 1 downTo 1) {
        val previous = scores[t - 1]
        var chosen = state
        var chosenScore = previous[state]

        if (state > 0 && previous[state - 1] > chosenScore) {
            chosen = state - 1
            chosenScore = previous[state - 1]
        }
        if (skipAllowed(state) && previous[state - 2] > chosenScore) {
            chosen = state - 2
        }

        state = chosen
        path[t - 1] = state
    }

    return path
}
|
|
536
|
+
|
|
537
|
+
/**
 * Returns `row[tokenId]`, or the large negative sentinel -1.0e30 when [tokenId] lies outside
 * the row's index range.
 */
private fun safeLogProb(row: FloatArray, tokenId: Int): Float =
    row.getOrElse(tokenId) { -1.0e30f }
|
|
543
|
+
|
|
544
|
+
/**
 * Converts alignment items into a bridge array of maps with the keys
 * `text`, `start` and `end`, in the same order as [items].
 */
private fun toWritableArray(items: List<AlignmentItem>): WritableArray {
    val array = Arguments.createArray()
    items.forEach { (text, start, end) ->
        val entry: WritableMap = Arguments.createMap().apply {
            putString("text", text)
            putDouble("start", start)
            putDouble("end", end)
        }
        array.pushMap(entry)
    }
    return array
}
|
|
555
|
+
}
|
|
@@ -56,6 +56,7 @@ class SherpaOnnxModule(reactContext: ReactApplicationContext) :
|
|
|
56
56
|
{ instanceId, requestId, message -> emitTtsStreamError(instanceId, requestId, message) },
|
|
57
57
|
{ instanceId, requestId, cancelled -> emitTtsStreamEnd(instanceId, requestId, cancelled) }
|
|
58
58
|
)
|
|
59
|
+
private val alignmentHelper = SherpaOnnxAlignmentHelper(reactApplicationContext)
|
|
59
60
|
private val enhancementHelper = SherpaOnnxEnhancementHelper(
|
|
60
61
|
reactApplicationContext,
|
|
61
62
|
{ modelDir, modelType -> Companion.nativeDetectEnhancementModel(modelDir, modelType) }
|
|
@@ -73,6 +74,7 @@ class SherpaOnnxModule(reactContext: ReactApplicationContext) :
|
|
|
73
74
|
pcmCapture = null
|
|
74
75
|
onlineSttHelper.shutdown()
|
|
75
76
|
ttsHelper.shutdown()
|
|
77
|
+
alignmentHelper.shutdown()
|
|
76
78
|
enhancementHelper.shutdown()
|
|
77
79
|
}
|
|
78
80
|
|
|
@@ -899,6 +901,7 @@ class SherpaOnnxModule(reactContext: ReactApplicationContext) :
|
|
|
899
901
|
val detectedModels = result["detectedModels"] as? ArrayList<*>
|
|
900
902
|
?: arrayListOf<HashMap<String, String>>()
|
|
901
903
|
val modelTypeStr = result["modelType"] as? String
|
|
904
|
+
val paths = result["paths"] as? HashMap<*, *>
|
|
902
905
|
|
|
903
906
|
val resultMap = Arguments.createMap()
|
|
904
907
|
resultMap.putBoolean("success", success)
|
|
@@ -916,6 +919,12 @@ class SherpaOnnxModule(reactContext: ReactApplicationContext) :
|
|
|
916
919
|
if (modelTypeStr != null) {
|
|
917
920
|
resultMap.putString("modelType", modelTypeStr)
|
|
918
921
|
}
|
|
922
|
+
val modelPath = paths?.get("model") as? String
|
|
923
|
+
if (!modelPath.isNullOrBlank()) {
|
|
924
|
+
val pathsMap = Arguments.createMap()
|
|
925
|
+
pathsMap.putString("model", modelPath)
|
|
926
|
+
resultMap.putMap("paths", pathsMap)
|
|
927
|
+
}
|
|
919
928
|
if (!success) {
|
|
920
929
|
val error = result["error"] as? String
|
|
921
930
|
if (!error.isNullOrBlank()) {
|
|
@@ -964,6 +973,16 @@ class SherpaOnnxModule(reactContext: ReactApplicationContext) :
|
|
|
964
973
|
ttsHelper.generateTtsWithTimestamps(instanceId, text, options, promise)
|
|
965
974
|
}
|
|
966
975
|
|
|
976
|
+
/**
 * Forwards a CTC forced-alignment request unchanged to [alignmentHelper], which settles
 * [promise].
 */
override fun runCTCForcedAlignment(
    modelPath: String,
    audioPath: String,
    text: String,
    vocabJson: String,
    promise: Promise,
) {
    alignmentHelper.runCTCForcedAlignment(
        modelPath = modelPath,
        audioPath = audioPath,
        text = text,
        vocabJson = vocabJson,
        promise = promise,
    )
}
|
|
985
|
+
|
|
967
986
|
/**
|
|
968
987
|
* Generate speech in streaming mode (emits chunk events).
|
|
969
988
|
*/
|
|
@@ -1074,6 +1093,59 @@ class SherpaOnnxModule(reactContext: ReactApplicationContext) :
|
|
|
1074
1093
|
enhancementHelper.detectEnhancementModel(modelDir, modelType, promise)
|
|
1075
1094
|
}
|
|
1076
1095
|
|
|
1096
|
+
/**
 * Detects an alignment model in [modelDir] through the native detector and resolves
 * [promise] with a map containing `success`, `detectedModels`, and — when available —
 * `modelType`, `paths.model` and `error`.
 *
 * @param modelDir Directory passed to the native detector.
 * @param modelType Requested model type; `null` is sent as `"auto"`.
 * @param promise Rejected with `"DETECT_ERROR"` when the native call returns null or throws.
 */
override fun detectAlignmentModel(
    modelDir: String,
    modelType: String?,
    promise: Promise
) {
    try {
        val detection = Companion.nativeDetectAlignmentModel(modelDir, modelType ?: "auto")
        if (detection == null) {
            android.util.Log.e(NAME, "DETECT_ERROR: Alignment model detection returned null")
            promise.reject("DETECT_ERROR", "Alignment model detection returned null")
            return
        }

        val succeeded = detection["success"] as? Boolean ?: false
        val detected = detection["detectedModels"] as? ArrayList<*>
            ?: arrayListOf<HashMap<String, String>>()

        val response = Arguments.createMap()
        response.putBoolean("success", succeeded)

        // Entries that are not maps are silently ignored.
        val models = Arguments.createArray()
        detected.filterIsInstance<HashMap<*, *>>().forEach { model ->
            val entry = Arguments.createMap()
            entry.putString("type", model["type"] as? String ?: "")
            entry.putString("modelDir", model["modelDir"] as? String ?: "")
            models.pushMap(entry)
        }
        response.putArray("detectedModels", models)

        (detection["modelType"] as? String)?.let { detectedType ->
            response.putString("modelType", detectedType)
        }

        // Only the "model" path is forwarded, and only when it is non-blank.
        val alignmentModelPath = (detection["paths"] as? HashMap<*, *>)?.get("model") as? String
        if (!alignmentModelPath.isNullOrBlank()) {
            val pathsMap = Arguments.createMap()
            pathsMap.putString("model", alignmentModelPath)
            response.putMap("paths", pathsMap)
        }

        if (!succeeded) {
            val error = detection["error"] as? String
            if (!error.isNullOrBlank()) {
                response.putString("error", error)
            }
        }

        promise.resolve(response)
    } catch (e: Exception) {
        android.util.Log.e(NAME, "DETECT_ERROR: Alignment model detection failed: ${e.message}", e)
        promise.reject("DETECT_ERROR", "Alignment model detection failed: ${e.message}", e)
    }
}
|
|
1148
|
+
|
|
1077
1149
|
override fun initializeEnhancement(
|
|
1078
1150
|
instanceId: String,
|
|
1079
1151
|
modelDir: String,
|
|
@@ -1362,6 +1434,10 @@ class SherpaOnnxModule(reactContext: ReactApplicationContext) :
|
|
|
1362
1434
|
@JvmStatic
|
|
1363
1435
|
private external fun nativeDetectEnhancementModel(modelDir: String, modelType: String): HashMap<String, Any>?
|
|
1364
1436
|
|
|
1437
|
+
/**
 * Native (JNI) model detection for subtitles/alignment.
 *
 * @param modelDir Directory to inspect.
 * @param modelType Requested model type string.
 * @return A map with the keys `success`, `error`, `detectedModels`, `modelType` and `paths`,
 *   or `null`.
 */
@JvmStatic
private external fun nativeDetectAlignmentModel(modelDir: String, modelType: String): HashMap<String, Any>?
|
|
1440
|
+
|
|
1365
1441
|
/** Convert arbitrary audio file to requested format (e.g. "mp3", "flac", "wav").
|
|
1366
1442
|
* outputSampleRateHz: for MP3 use 32000/44100/48000, 0 = default 44100. Ignored for WAV/FLAC.
|
|
1367
1443
|
* Returns empty string on success, or an error message otherwise. Requires FFmpeg prebuilts when called on Android.
|