react-native-sherpa-onnx 0.3.0 → 0.3.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +21 -7
- package/SherpaOnnx.podspec +1 -1
- package/android/build.gradle +35 -26
- package/android/prebuilt-download.gradle +27 -14
- package/android/src/main/cpp/CMakeLists.txt +51 -17
- package/android/src/main/cpp/jni/archive/sherpa-onnx-archive-helper.cpp +14 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-helper.cpp +16 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-helper.h +3 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-stt.cpp +19 -2
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect.h +2 -1
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-stt-wrapper.cpp +1 -0
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxModule.kt +114 -8
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxOnlineSttHelper.kt +535 -0
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxTtsHelper.kt +10 -10
- package/ios/SherpaOnnx+OnlineSTT.mm +365 -0
- package/ios/SherpaOnnx+TTS.mm +35 -9
- package/ios/SherpaOnnx.mm +6 -0
- package/ios/model_detect/sherpa-onnx-model-detect-helper.h +3 -0
- package/ios/model_detect/sherpa-onnx-model-detect-helper.mm +16 -0
- package/ios/model_detect/sherpa-onnx-model-detect-stt.mm +19 -2
- package/ios/model_detect/sherpa-onnx-model-detect.h +2 -1
- package/ios/online_stt/sherpa-onnx-online-stt-wrapper.h +85 -0
- package/ios/online_stt/sherpa-onnx-online-stt-wrapper.mm +270 -0
- package/lib/module/NativeSherpaOnnx.js.map +1 -1
- package/lib/module/index.js +2 -2
- package/lib/module/stt/index.js +4 -0
- package/lib/module/stt/index.js.map +1 -1
- package/lib/module/stt/streaming.js +257 -0
- package/lib/module/stt/streaming.js.map +1 -0
- package/lib/module/stt/streamingTypes.js +38 -0
- package/lib/module/stt/streamingTypes.js.map +1 -0
- package/lib/module/tts/index.js +4 -43
- package/lib/module/tts/index.js.map +1 -1
- package/lib/module/tts/streaming.js +220 -0
- package/lib/module/tts/streaming.js.map +1 -0
- package/lib/module/tts/streamingTypes.js +4 -0
- package/lib/module/tts/streamingTypes.js.map +1 -0
- package/lib/module/tts/types.js +8 -1
- package/lib/module/tts/types.js.map +1 -1
- package/lib/typescript/src/NativeSherpaOnnx.d.ts +66 -1
- package/lib/typescript/src/NativeSherpaOnnx.d.ts.map +1 -1
- package/lib/typescript/src/stt/index.d.ts +3 -0
- package/lib/typescript/src/stt/index.d.ts.map +1 -1
- package/lib/typescript/src/stt/streaming.d.ts +42 -0
- package/lib/typescript/src/stt/streaming.d.ts.map +1 -0
- package/lib/typescript/src/stt/streamingTypes.d.ts +122 -0
- package/lib/typescript/src/stt/streamingTypes.d.ts.map +1 -0
- package/lib/typescript/src/tts/index.d.ts +3 -1
- package/lib/typescript/src/tts/index.d.ts.map +1 -1
- package/lib/typescript/src/tts/streaming.d.ts +24 -0
- package/lib/typescript/src/tts/streaming.d.ts.map +1 -0
- package/lib/typescript/src/tts/streamingTypes.d.ts +27 -0
- package/lib/typescript/src/tts/streamingTypes.d.ts.map +1 -0
- package/lib/typescript/src/tts/types.d.ts +19 -6
- package/lib/typescript/src/tts/types.d.ts.map +1 -1
- package/package.json +1 -2
- package/src/NativeSherpaOnnx.ts +95 -0
- package/src/index.tsx +2 -2
- package/src/stt/index.ts +17 -0
- package/src/stt/streaming.ts +361 -0
- package/src/stt/streamingTypes.ts +151 -0
- package/src/tts/index.ts +6 -66
- package/src/tts/streaming.ts +336 -0
- package/src/tts/streamingTypes.ts +54 -0
- package/src/tts/types.ts +20 -10
- package/android/codegen.gradle +0 -57
|
@@ -0,0 +1,535 @@
|
|
|
1
|
+
package com.sherpaonnx
|
|
2
|
+
|
|
3
|
+
import android.content.Context
|
|
4
|
+
import android.net.Uri
|
|
5
|
+
import android.util.Log
|
|
6
|
+
import com.facebook.react.bridge.Arguments
|
|
7
|
+
import com.facebook.react.bridge.Promise
|
|
8
|
+
import com.facebook.react.bridge.ReadableArray
|
|
9
|
+
import com.facebook.react.bridge.WritableMap
|
|
10
|
+
import com.k2fsa.sherpa.onnx.EndpointConfig
|
|
11
|
+
import com.k2fsa.sherpa.onnx.EndpointRule
|
|
12
|
+
import com.k2fsa.sherpa.onnx.FeatureConfig
|
|
13
|
+
import com.k2fsa.sherpa.onnx.OnlineModelConfig
|
|
14
|
+
import com.k2fsa.sherpa.onnx.OnlineNeMoCtcModelConfig
|
|
15
|
+
import com.k2fsa.sherpa.onnx.OnlineParaformerModelConfig
|
|
16
|
+
import com.k2fsa.sherpa.onnx.OnlineRecognizer
|
|
17
|
+
import com.k2fsa.sherpa.onnx.OnlineRecognizerConfig
|
|
18
|
+
import com.k2fsa.sherpa.onnx.OnlineRecognizerResult
|
|
19
|
+
import com.k2fsa.sherpa.onnx.OnlineStream
|
|
20
|
+
import com.k2fsa.sherpa.onnx.OnlineToneCtcModelConfig
|
|
21
|
+
import com.k2fsa.sherpa.onnx.OnlineTransducerModelConfig
|
|
22
|
+
import com.k2fsa.sherpa.onnx.OnlineZipformer2CtcModelConfig
|
|
23
|
+
import java.io.File
|
|
24
|
+
import java.util.concurrent.ConcurrentHashMap
|
|
25
|
+
|
|
26
|
+
/**
|
|
27
|
+
* Helper for streaming (online) STT using sherpa-onnx OnlineRecognizer + OnlineStream.
|
|
28
|
+
* Manages recognizer instances and streams; resolves model paths by scanning the model directory.
|
|
29
|
+
*/
|
|
30
|
+
internal class SherpaOnnxOnlineSttHelper(
|
|
31
|
+
private val context: Context,
|
|
32
|
+
private val logTag: String
|
|
33
|
+
) {
|
|
34
|
+
|
|
35
|
+
/**
 * One loaded online recognizer together with the streams created from it.
 *
 * @property recognizer the recognizer that owns every stream in [streams].
 * @property config the configuration the recognizer was constructed with.
 * @property streams open streams keyed by stream id. Defaults to a concurrent map so that it
 *   matches the `ConcurrentHashMap` registries that hold the instances themselves.
 */
private data class OnlineSttInstance(
    val recognizer: OnlineRecognizer,
    val config: OnlineRecognizerConfig,
    val streams: MutableMap<String, OnlineStream> = ConcurrentHashMap(),
)
|
|
40
|
+
|
|
41
|
+
/** Loaded recognizer instances keyed by instance id. */
private val instances: ConcurrentHashMap<String, OnlineSttInstance> = ConcurrentHashMap()

/** Maps a stream id to the id of the instance it was registered under. */
private val streamToInstance: ConcurrentHashMap<String, String> = ConcurrentHashMap()
|
|
43
|
+
|
|
44
|
+
/** Returns the instance registered under [instanceId], or `null` when there is none. */
private fun getInstance(instanceId: String): OnlineSttInstance? {
    return instances[instanceId]
}
|
|
45
|
+
|
|
46
|
+
/**
 * Looks up a stream by id.
 *
 * @return the owning instance paired with the stream, or `null` when the stream id is unknown,
 *   its instance is no longer registered, or the instance does not hold that stream.
 */
private fun getStream(streamId: String): Pair<OnlineSttInstance, OnlineStream>? {
    val owner = streamToInstance[streamId]?.let { ownerId -> instances[ownerId] } ?: return null
    return owner.streams[streamId]?.let { found -> owner to found }
}
|
|
52
|
+
|
|
53
|
+
/**
 * Turns a `content://` URI into a plain file path by copying its bytes into the cache directory.
 *
 * Paths that do not start with `content://` are returned unchanged.
 *
 * @param path a file path or a `content://` URI string.
 * @param cacheFilePrefix prefix of the cache file name; `System.nanoTime()` is appended to it.
 * @return [path] itself, or the absolute path of the cache copy.
 * @throws IllegalStateException if the content resolver cannot open the URI.
 */
private fun resolveContentUriToFile(path: String, cacheFilePrefix: String): String {
    if (!path.startsWith("content://")) return path
    val uri = Uri.parse(path)
    val cacheFile = File(context.cacheDir, "${cacheFilePrefix}_${System.nanoTime()}")
    try {
        val input = context.contentResolver.openInputStream(uri)
            ?: throw IllegalStateException("File is not readable (content URI could not be opened): $path")
        input.use { source ->
            cacheFile.outputStream().use { sink -> source.copyTo(sink) }
        }
    } catch (e: Exception) {
        // Do not leave a truncated copy behind when opening or copying fails.
        cacheFile.delete()
        throw e
    }
    return cacheFile.absolutePath
}
|
|
62
|
+
|
|
63
|
+
/**
 * Resolves every entry of a comma-separated path list through [resolveContentUriToFile].
 *
 * Entries are trimmed and empty ones are dropped. Each remaining entry gets its own cache
 * prefix, `<cacheFilePrefix>_<index>`. A blank input is returned unchanged.
 *
 * @return the resolved paths joined with `,`.
 */
private fun resolveFilePaths(pathsString: String, cacheFilePrefix: String): String {
    if (pathsString.isBlank()) return pathsString
    val entries = pathsString.split(',')
        .map { raw -> raw.trim() }
        .filter { entry -> entry.isNotEmpty() }
    val resolved = ArrayList<String>(entries.size)
    for ((position, entry) in entries.withIndex()) {
        resolved.add(resolveContentUriToFile(entry, "${cacheFilePrefix}_$position"))
    }
    return resolved.joinToString(",")
}
|
|
69
|
+
|
|
70
|
+
/**
 * Scans [modelDir] for the files required by the given online model type.
 *
 * Only regular files directly inside the directory are considered. Candidates are sorted by
 * file name before matching, because `File.listFiles()` does not guarantee any order; when
 * several files share a prefix (for example `encoder.onnx` and `encoder.int8.onnx`) the
 * lexicographically first one is chosen on every run.
 *
 * @param modelDir directory containing the model files.
 * @param modelType one of `transducer`, `paraformer`, `zipformer2_ctc`, `nemo_ctc`, `tone_ctc`.
 * @return a map with keys `encoder`, `decoder`, `joiner`, `tokens` (transducer),
 *   `encoder`, `decoder`, `tokens` (paraformer) or `model`, `tokens` (ctc types).
 * @throws IllegalArgumentException if the directory is missing, the model type is unsupported,
 *   or a required model file or `tokens.txt` is not present.
 */
private fun scanOnlineModelPaths(modelDir: String, modelType: String): Map<String, String> {
    val dir = File(modelDir)
    require(dir.isDirectory) { "Model directory does not exist or is not a directory: $modelDir" }

    // Sort by name so that the "first" match below is deterministic.
    val files = dir.listFiles()?.filter { it.isFile }?.sortedBy { it.name }.orEmpty()

    // Absolute path of the first file matching any prefix (tried in order), or "" when none does.
    fun firstFile(vararg prefixes: String, suffix: String = ".onnx"): String =
        prefixes.firstNotNullOfOrNull { prefix ->
            files.firstOrNull { it.name.startsWith(prefix) && it.name.endsWith(suffix) }?.absolutePath
        }.orEmpty()

    val tokensPath = files.firstOrNull { it.name == "tokens.txt" }?.absolutePath.orEmpty()

    val paths = when (modelType) {
        "transducer" -> mapOf(
            "encoder" to firstFile("encoder"),
            "decoder" to firstFile("decoder"),
            "joiner" to firstFile("joiner"),
            "tokens" to tokensPath,
        )
        "paraformer" -> mapOf(
            "encoder" to firstFile("encoder"),
            "decoder" to firstFile("decoder"),
            "tokens" to tokensPath,
        )
        "zipformer2_ctc", "nemo_ctc", "tone_ctc" -> mapOf(
            "model" to firstFile("model"),
            "tokens" to tokensPath,
        )
        else -> throw IllegalArgumentException("Unsupported online STT model type: $modelType. Use: transducer, paraformer, zipformer2_ctc, nemo_ctc, tone_ctc")
    }

    fun anyMissing(vararg keys: String): Boolean = keys.any { paths[it].isNullOrEmpty() }

    when (modelType) {
        "transducer" -> require(!anyMissing("encoder", "decoder", "joiner")) {
            "Transducer model requires encoder, decoder, and joiner .onnx files in $modelDir"
        }
        "paraformer" -> require(!anyMissing("encoder", "decoder")) {
            "Paraformer model requires encoder and decoder .onnx files in $modelDir"
        }
        else -> require(!anyMissing("model")) {
            "$modelType model requires model.onnx (or model*.onnx) in $modelDir"
        }
    }
    // Fail here with a clear message instead of handing an empty tokens path to the recognizer.
    require(tokensPath.isNotEmpty()) { "$modelType model requires tokens.txt in $modelDir" }

    return paths
}
|
|
122
|
+
|
|
123
|
+
/**
 * Builds the [OnlineRecognizerConfig] for a model directory.
 *
 * Model files are located with [scanOnlineModelPaths]. Every nullable argument falls back to
 * the default written next to its use below. `ruleFsts`, `ruleFars` and `hotwordsFile` may be
 * `content://` URIs; they are copied to cache files first. Resolving them is best-effort: on
 * failure the option is left empty and a warning is logged, and initialization continues.
 *
 * @param ruleFsts comma-separated list of paths or content URIs, or `null`.
 * @param ruleFars comma-separated list of paths or content URIs, or `null`.
 * @param hotwordsFile a single path or content URI, or `null`.
 * @throws IllegalArgumentException if the model directory or model type is invalid
 *   (propagated from [scanOnlineModelPaths]).
 */
private fun buildOnlineRecognizerConfig(
    modelDir: String,
    modelType: String,
    enableEndpoint: Boolean,
    decodingMethod: String,
    maxActivePaths: Int,
    hotwordsFile: String?,
    hotwordsScore: Float?,
    numThreads: Int?,
    provider: String?,
    ruleFsts: String?,
    ruleFars: String?,
    blankPenalty: Float?,
    debug: Boolean?,
    rule1MustContainNonSilence: Boolean?,
    rule1MinTrailingSilence: Float?,
    rule1MinUtteranceLength: Float?,
    rule2MustContainNonSilence: Boolean?,
    rule2MinTrailingSilence: Float?,
    rule2MinUtteranceLength: Float?,
    rule3MustContainNonSilence: Boolean?,
    rule3MinTrailingSilence: Float?,
    rule3MinUtteranceLength: Float?
): OnlineRecognizerConfig {
    val paths = scanOnlineModelPaths(modelDir, modelType)

    // Settings shared by every model type.
    val tokensPath = paths["tokens"].orEmpty()
    val threadCount = numThreads ?: 1
    val debugEnabled = debug ?: false
    val providerName = provider ?: "cpu"

    val endpointConfig = EndpointConfig(
        rule1 = EndpointRule(
            mustContainNonSilence = rule1MustContainNonSilence ?: false,
            minTrailingSilence = rule1MinTrailingSilence ?: 2.4f,
            minUtteranceLength = rule1MinUtteranceLength ?: 0f
        ),
        rule2 = EndpointRule(
            mustContainNonSilence = rule2MustContainNonSilence ?: true,
            minTrailingSilence = rule2MinTrailingSilence ?: 1.4f,
            minUtteranceLength = rule2MinUtteranceLength ?: 0f
        ),
        rule3 = EndpointRule(
            mustContainNonSilence = rule3MustContainNonSilence ?: false,
            minTrailingSilence = rule3MinTrailingSilence ?: 0f,
            minUtteranceLength = rule3MinUtteranceLength ?: 20f
        )
    )

    val modelConfig = when (modelType) {
        "transducer" -> OnlineModelConfig(
            transducer = OnlineTransducerModelConfig(
                encoder = paths["encoder"].orEmpty(),
                decoder = paths["decoder"].orEmpty(),
                joiner = paths["joiner"].orEmpty()
            ),
            tokens = tokensPath,
            numThreads = threadCount,
            debug = debugEnabled,
            provider = providerName,
            // NOTE(review): every transducer is labelled "zipformer" here; confirm against the
            // sherpa-onnx docs that this hint is correct for zipformer2/other transducer exports.
            modelType = "zipformer"
        )
        "paraformer" -> OnlineModelConfig(
            paraformer = OnlineParaformerModelConfig(
                encoder = paths["encoder"].orEmpty(),
                decoder = paths["decoder"].orEmpty()
            ),
            tokens = tokensPath,
            numThreads = threadCount,
            debug = debugEnabled,
            provider = providerName,
            modelType = "paraformer"
        )
        "zipformer2_ctc" -> OnlineModelConfig(
            zipformer2Ctc = OnlineZipformer2CtcModelConfig(model = paths["model"].orEmpty()),
            tokens = tokensPath,
            numThreads = threadCount,
            debug = debugEnabled,
            provider = providerName,
            modelType = "zipformer2"
        )
        "nemo_ctc" -> OnlineModelConfig(
            neMoCtc = OnlineNeMoCtcModelConfig(model = paths["model"].orEmpty()),
            tokens = tokensPath,
            numThreads = threadCount,
            debug = debugEnabled,
            provider = providerName
        )
        "tone_ctc" -> OnlineModelConfig(
            toneCtc = OnlineToneCtcModelConfig(model = paths["model"].orEmpty()),
            tokens = tokensPath,
            numThreads = threadCount,
            debug = debugEnabled,
            provider = providerName
        )
        else -> throw IllegalArgumentException("Unsupported online model type: $modelType")
    }

    // Best-effort resolution: a failure disables the option but is no longer silent.
    fun resolveOrEmpty(label: String, resolve: () -> String): String =
        try {
            resolve()
        } catch (e: Exception) {
            Log.w(logTag, "Online STT: could not resolve $label, ignoring it: ${e.message}", e)
            ""
        }

    val resolvedRuleFsts = resolveOrEmpty("ruleFsts") {
        resolveFilePaths(ruleFsts.orEmpty().trim(), "online_stt_rule_fst")
    }
    val resolvedRuleFars = resolveOrEmpty("ruleFars") {
        resolveFilePaths(ruleFars.orEmpty().trim(), "online_stt_rule_far")
    }
    val trimmedHotwordsFile = hotwordsFile?.trim().orEmpty()
    val resolvedHotwordsFile =
        if (trimmedHotwordsFile.isEmpty()) {
            ""
        } else {
            resolveOrEmpty("hotwordsFile") {
                resolveContentUriToFile(trimmedHotwordsFile, "online_stt_hotwords")
            }
        }

    return OnlineRecognizerConfig(
        featConfig = FeatureConfig(sampleRate = 16000, featureDim = 80, dither = 0f),
        modelConfig = modelConfig,
        endpointConfig = endpointConfig,
        enableEndpoint = enableEndpoint,
        decodingMethod = decodingMethod,
        maxActivePaths = maxActivePaths,
        hotwordsFile = resolvedHotwordsFile,
        hotwordsScore = hotwordsScore ?: 1.5f,
        ruleFsts = resolvedRuleFsts,
        ruleFars = resolvedRuleFars,
        blankPenalty = blankPenalty ?: 0f
    )
}
|
|
249
|
+
|
|
250
|
+
/**
 * Creates an online recognizer for [modelDir] and registers it under [instanceId].
 *
 * Numeric options arrive as `Double?` and are narrowed to `Float`/`Int` before the config is
 * built. If an instance is already registered under the same id, it is replaced and its streams
 * and recognizer are released instead of being dropped without release.
 *
 * Resolves [promise] with `{ success: true }`; rejects it with code `INIT_ERROR` on failure.
 */
fun initializeOnlineStt(
    instanceId: String,
    modelDir: String,
    modelType: String,
    enableEndpoint: Boolean,
    decodingMethod: String,
    maxActivePaths: Int,
    hotwordsFile: String?,
    hotwordsScore: Double?,
    numThreads: Double?,
    provider: String?,
    ruleFsts: String?,
    ruleFars: String?,
    blankPenalty: Double?,
    debug: Boolean?,
    rule1MustContainNonSilence: Boolean?,
    rule1MinTrailingSilence: Double?,
    rule1MinUtteranceLength: Double?,
    rule2MustContainNonSilence: Boolean?,
    rule2MinTrailingSilence: Double?,
    rule2MinUtteranceLength: Double?,
    rule3MustContainNonSilence: Boolean?,
    rule3MinTrailingSilence: Double?,
    rule3MinUtteranceLength: Double?,
    promise: Promise
) {
    try {
        val config = buildOnlineRecognizerConfig(
            modelDir = modelDir,
            modelType = modelType,
            enableEndpoint = enableEndpoint,
            decodingMethod = decodingMethod,
            maxActivePaths = maxActivePaths,
            hotwordsFile = hotwordsFile,
            hotwordsScore = hotwordsScore?.toFloat(),
            numThreads = numThreads?.toInt(),
            provider = provider,
            ruleFsts = ruleFsts,
            ruleFars = ruleFars,
            blankPenalty = blankPenalty?.toFloat(),
            debug = debug,
            rule1MustContainNonSilence = rule1MustContainNonSilence,
            rule1MinTrailingSilence = rule1MinTrailingSilence?.toFloat(),
            rule1MinUtteranceLength = rule1MinUtteranceLength?.toFloat(),
            rule2MustContainNonSilence = rule2MustContainNonSilence,
            rule2MinTrailingSilence = rule2MinTrailingSilence?.toFloat(),
            rule2MinUtteranceLength = rule2MinUtteranceLength?.toFloat(),
            rule3MustContainNonSilence = rule3MustContainNonSilence,
            rule3MinTrailingSilence = rule3MinTrailingSilence?.toFloat(),
            rule3MinUtteranceLength = rule3MinUtteranceLength?.toFloat()
        )
        val recognizer = OnlineRecognizer(assetManager = null, config = config)
        // put() hands back whatever was registered before, so a replaced instance can be cleaned up.
        val replaced = instances.put(instanceId, OnlineSttInstance(recognizer = recognizer, config = config))
        if (replaced != null) releaseReplacedOnlineSttInstance(instanceId, replaced)
        promise.resolve(Arguments.createMap().apply { putBoolean("success", true) })
    } catch (e: Exception) {
        Log.e(logTag, "initializeOnlineStt failed: ${e.message}", e)
        promise.reject("INIT_ERROR", "Online STT init failed: ${e.message}", e)
    }
}

/** Releases the streams and recognizer of an instance that was replaced by a re-initialization. */
private fun releaseReplacedOnlineSttInstance(instanceId: String, replaced: OnlineSttInstance) {
    try {
        replaced.streams.keys.toList().forEach { streamToInstance.remove(it) }
        replaced.streams.values.forEach { it.release() }
        replaced.streams.clear()
        replaced.recognizer.release()
    } catch (e: Exception) {
        // Cleanup of the old instance must not fail the initialization that just succeeded.
        Log.w(logTag, "initializeOnlineStt: failed to release replaced instance $instanceId: ${e.message}", e)
    }
}
|
|
309
|
+
|
|
310
|
+
/**
 * Creates a stream on the recognizer of [instanceId] and registers it under [streamId].
 *
 * [hotwords] is trimmed and passed to `createStream`; `null` becomes an empty string.
 * Resolves [promise] with `null`. Rejects it with code `STREAM_ERROR` when the instance is
 * unknown, when [streamId] is already in use, or when stream creation throws.
 */
fun createSttStream(instanceId: String, streamId: String, hotwords: String?, promise: Promise) {
    try {
        val inst = getInstance(instanceId)
            ?: run {
                promise.reject("STREAM_ERROR", "Online STT instance not found: $instanceId")
                return
            }
        // streamToInstance is keyed by stream id alone, so an id that is live on ANY instance
        // must be rejected; otherwise its mapping would be overwritten and that stream orphaned.
        if (inst.streams.containsKey(streamId) || getStream(streamId) != null) {
            promise.reject("STREAM_ERROR", "Stream already exists: $streamId")
            return
        }
        val stream = inst.recognizer.createStream(hotwords = hotwords?.trim().orEmpty())
        inst.streams[streamId] = stream
        streamToInstance[streamId] = instanceId
        promise.resolve(null)
    } catch (e: Exception) {
        Log.e(logTag, "createSttStream failed: ${e.message}", e)
        promise.reject("STREAM_ERROR", "Create stream failed: ${e.message}", e)
    }
}
|
|
330
|
+
|
|
331
|
+
/** Copies [arr] into a new `FloatArray`, reading each element as a double and narrowing it. */
private fun readableArrayToFloatArray(arr: ReadableArray): FloatArray {
    val converted = FloatArray(arr.size())
    for (index in converted.indices) {
        converted[index] = arr.getDouble(index).toFloat()
    }
    return converted
}
|
|
333
|
+
|
|
334
|
+
/**
 * Feeds [samples] at [sampleRate] into the stream registered under [streamId].
 *
 * Resolves [promise] with `null`; rejects it with code `STREAM_ERROR` when the stream is
 * unknown or the call throws.
 */
fun acceptSttWaveform(streamId: String, samples: ReadableArray, sampleRate: Int, promise: Promise) {
    try {
        val entry = getStream(streamId)
        if (entry == null) {
            promise.reject("STREAM_ERROR", "Stream not found: $streamId")
            return
        }
        entry.second.acceptWaveform(readableArrayToFloatArray(samples), sampleRate)
        promise.resolve(null)
    } catch (error: Exception) {
        Log.e(logTag, "acceptSttWaveform failed: ${error.message}", error)
        promise.reject("STREAM_ERROR", "acceptSttWaveform failed: ${error.message}", error)
    }
}
|
|
349
|
+
|
|
350
|
+
/**
 * Calls `inputFinished()` on the stream registered under [streamId].
 *
 * Resolves [promise] with `null`; rejects it with code `STREAM_ERROR` when the stream is
 * unknown or the call throws.
 */
fun sttStreamInputFinished(streamId: String, promise: Promise) {
    try {
        val entry = getStream(streamId)
        if (entry == null) {
            promise.reject("STREAM_ERROR", "Stream not found: $streamId")
            return
        }
        entry.second.inputFinished()
        promise.resolve(null)
    } catch (error: Exception) {
        Log.e(logTag, "sttStreamInputFinished failed: ${error.message}", error)
        promise.reject("STREAM_ERROR", "sttStreamInputFinished failed: ${error.message}", error)
    }
}
|
|
364
|
+
|
|
365
|
+
/**
 * Runs one `decode` call of the owning recognizer on the stream registered under [streamId].
 *
 * Resolves [promise] with `null`; rejects it with code `STREAM_ERROR` when the stream is
 * unknown or the call throws.
 */
fun decodeSttStream(streamId: String, promise: Promise) {
    try {
        val entry = getStream(streamId)
        if (entry == null) {
            promise.reject("STREAM_ERROR", "Stream not found: $streamId")
            return
        }
        entry.first.recognizer.decode(entry.second)
        promise.resolve(null)
    } catch (error: Exception) {
        Log.e(logTag, "decodeSttStream failed: ${error.message}", error)
        promise.reject("STREAM_ERROR", "decodeSttStream failed: ${error.message}", error)
    }
}
|
|
379
|
+
|
|
380
|
+
/**
 * Resolves [promise] with the recognizer's `isReady` value for the stream under [streamId].
 *
 * Rejects it with code `STREAM_ERROR` when the stream is unknown or the call throws.
 */
fun isSttStreamReady(streamId: String, promise: Promise) {
    try {
        val entry = getStream(streamId)
        if (entry == null) {
            promise.reject("STREAM_ERROR", "Stream not found: $streamId")
            return
        }
        promise.resolve(entry.first.recognizer.isReady(entry.second))
    } catch (error: Exception) {
        Log.e(logTag, "isSttStreamReady failed: ${error.message}", error)
        promise.reject("STREAM_ERROR", "isSttStreamReady failed: ${error.message}", error)
    }
}
|
|
394
|
+
|
|
395
|
+
/**
 * Converts a recognizer result into a bridge map with keys `text` (string),
 * `tokens` (array of strings) and `timestamps` (array of doubles).
 */
private fun resultToWritableMap(result: OnlineRecognizerResult): WritableMap {
    val tokenList = Arguments.createArray()
    result.tokens.forEach { token -> tokenList.pushString(token) }

    val timestampList = Arguments.createArray()
    result.timestamps.forEach { stamp -> timestampList.pushDouble(stamp.toDouble()) }

    return Arguments.createMap().apply {
        putString("text", result.text)
        putArray("tokens", tokenList)
        putArray("timestamps", timestampList)
    }
}
|
|
406
|
+
|
|
407
|
+
/**
 * Resolves [promise] with the current recognition result of the stream under [streamId],
 * converted by [resultToWritableMap].
 *
 * Rejects it with code `STREAM_ERROR` when the stream is unknown or the call throws.
 */
fun getSttStreamResult(streamId: String, promise: Promise) {
    try {
        val entry = getStream(streamId)
        if (entry == null) {
            promise.reject("STREAM_ERROR", "Stream not found: $streamId")
            return
        }
        val current = entry.first.recognizer.getResult(entry.second)
        promise.resolve(resultToWritableMap(current))
    } catch (error: Exception) {
        Log.e(logTag, "getSttStreamResult failed: ${error.message}", error)
        promise.reject("STREAM_ERROR", "getSttStreamResult failed: ${error.message}", error)
    }
}
|
|
421
|
+
|
|
422
|
+
/**
 * Resolves [promise] with the recognizer's `isEndpoint` value for the stream under [streamId].
 *
 * Rejects it with code `STREAM_ERROR` when the stream is unknown or the call throws.
 */
fun isSttStreamEndpoint(streamId: String, promise: Promise) {
    try {
        val entry = getStream(streamId)
        if (entry == null) {
            promise.reject("STREAM_ERROR", "Stream not found: $streamId")
            return
        }
        promise.resolve(entry.first.recognizer.isEndpoint(entry.second))
    } catch (error: Exception) {
        Log.e(logTag, "isSttStreamEndpoint failed: ${error.message}", error)
        promise.reject("STREAM_ERROR", "isSttStreamEndpoint failed: ${error.message}", error)
    }
}
|
|
436
|
+
|
|
437
|
+
/**
 * Calls the owning recognizer's `reset` on the stream registered under [streamId].
 *
 * Resolves [promise] with `null`; rejects it with code `STREAM_ERROR` when the stream is
 * unknown or the call throws.
 */
fun resetSttStream(streamId: String, promise: Promise) {
    try {
        val entry = getStream(streamId)
        if (entry == null) {
            promise.reject("STREAM_ERROR", "Stream not found: $streamId")
            return
        }
        entry.first.recognizer.reset(entry.second)
        promise.resolve(null)
    } catch (error: Exception) {
        Log.e(logTag, "resetSttStream failed: ${error.message}", error)
        promise.reject("STREAM_ERROR", "resetSttStream failed: ${error.message}", error)
    }
}
|
|
451
|
+
|
|
452
|
+
/**
 * Unregisters the stream under [streamId] and releases it.
 *
 * An unknown stream id, or one whose instance is no longer registered, is not an error:
 * [promise] is resolved with `null` in every non-throwing case. It is rejected with code
 * `STREAM_ERROR` only when releasing throws.
 */
fun releaseSttStream(streamId: String, promise: Promise) {
    try {
        // The id mapping is dropped first; the owning instance may already be gone.
        val owner = streamToInstance.remove(streamId)?.let { ownerId -> instances[ownerId] }
        owner?.streams?.remove(streamId)?.release()
        promise.resolve(null)
    } catch (error: Exception) {
        Log.e(logTag, "releaseSttStream failed: ${error.message}", error)
        promise.reject("STREAM_ERROR", "releaseSttStream failed: ${error.message}", error)
    }
}
|
|
469
|
+
|
|
470
|
+
/**
 * Unregisters the instance under [instanceId], releasing all of its streams and then its
 * recognizer.
 *
 * An unknown instance id is not an error. Resolves [promise] with `null`; rejects it with
 * code `RELEASE_ERROR` when releasing throws.
 */
fun unloadOnlineStt(instanceId: String, promise: Promise) {
    try {
        val removed = instances.remove(instanceId)
        if (removed != null) {
            val ownedStreamIds = removed.streams.keys.toList()
            for (stream in removed.streams.values) {
                stream.release()
            }
            removed.streams.clear()
            for (id in ownedStreamIds) {
                streamToInstance.remove(id)
            }
            removed.recognizer.release()
        }
        promise.resolve(null)
    } catch (error: Exception) {
        Log.e(logTag, "unloadOnlineStt failed: ${error.message}", error)
        promise.reject("RELEASE_ERROR", "unloadOnlineStt failed: ${error.message}", error)
    }
}
|
|
487
|
+
|
|
488
|
+
/**
 * One-call convenience for a chunk of audio: accepts the waveform, decodes for as long as the
 * recognizer reports the stream ready, then reads the result and the endpoint flag.
 *
 * Resolves [promise] with the result map (`text`, `tokens`, `timestamps`) plus a boolean
 * `isEndpoint`; rejects it with code `STREAM_ERROR` when the stream is unknown or a call throws.
 */
fun processSttAudioChunk(
    streamId: String,
    samples: ReadableArray,
    sampleRate: Int,
    promise: Promise
) {
    try {
        val entry = getStream(streamId)
        if (entry == null) {
            promise.reject("STREAM_ERROR", "Stream not found: $streamId")
            return
        }
        val recognizer = entry.first.recognizer
        val stream = entry.second

        stream.acceptWaveform(readableArrayToFloatArray(samples), sampleRate)
        while (recognizer.isReady(stream)) {
            recognizer.decode(stream)
        }

        val current = recognizer.getResult(stream)
        val endpointReached = recognizer.isEndpoint(stream)
        val payload = resultToWritableMap(current)
        payload.putBoolean("isEndpoint", endpointReached)
        promise.resolve(payload)
    } catch (error: Exception) {
        Log.e(logTag, "processSttAudioChunk failed: ${error.message}", error)
        promise.reject("STREAM_ERROR", "processSttAudioChunk failed: ${error.message}", error)
    }
}
|
|
518
|
+
|
|
519
|
+
/**
 * Releases every registered instance with its streams, then clears the stream registry.
 * Intended to be called from Module.onCatalystInstanceDestroy.
 *
 * A failure while releasing one instance is logged as a warning and does not stop the
 * remaining instances from being released.
 */
fun shutdown() {
    for (instanceId in instances.keys.toList()) {
        try {
            val removed = instances.remove(instanceId) ?: continue
            val ownedStreamIds = removed.streams.keys.toList()
            for (stream in removed.streams.values) {
                stream.release()
            }
            removed.streams.clear()
            for (id in ownedStreamIds) {
                streamToInstance.remove(id)
            }
            removed.recognizer.release()
        } catch (error: Exception) {
            Log.w(logTag, "shutdown: failed to release instance $instanceId: ${error.message}")
        }
    }
    streamToInstance.clear()
}
|
|
535
|
+
}
|
|
@@ -41,9 +41,9 @@ import java.util.concurrent.atomic.AtomicBoolean
|
|
|
41
41
|
internal class SherpaOnnxTtsHelper(
|
|
42
42
|
private val context: ReactApplicationContext,
|
|
43
43
|
private val detectTtsModel: (modelDir: String, modelType: String) -> HashMap<String, Any>?,
|
|
44
|
-
private val emitChunk: (String, FloatArray, Int, Float, Boolean) -> Unit,
|
|
45
|
-
private val emitError: (String, String) -> Unit,
|
|
46
|
-
private val emitEnd: (String, Boolean) -> Unit
|
|
44
|
+
private val emitChunk: (String, String, FloatArray, Int, Float, Boolean) -> Unit,
|
|
45
|
+
private val emitError: (String, String, String) -> Unit,
|
|
46
|
+
private val emitEnd: (String, String, Boolean) -> Unit
|
|
47
47
|
) {
|
|
48
48
|
|
|
49
49
|
private data class TtsInitState(
|
|
@@ -501,7 +501,7 @@ internal class SherpaOnnxTtsHelper(
|
|
|
501
501
|
}
|
|
502
502
|
}
|
|
503
503
|
|
|
504
|
-
fun generateTtsStream(instanceId: String, text: String, options: ReadableMap?, promise: Promise) {
|
|
504
|
+
fun generateTtsStream(instanceId: String, requestId: String, text: String, options: ReadableMap?, promise: Promise) {
|
|
505
505
|
val inst = getInstance(instanceId) ?: run {
|
|
506
506
|
Log.e("SherpaOnnxTts", "TTS_STREAM_ERROR: TTS instance not found: $instanceId")
|
|
507
507
|
promise.reject("TTS_STREAM_ERROR", "TTS instance not found: $instanceId")
|
|
@@ -534,34 +534,34 @@ internal class SherpaOnnxTtsHelper(
|
|
|
534
534
|
val config = parseGenerationConfig(options) ?: GenerationConfig(speed = speed, sid = sid)
|
|
535
535
|
inst.tts!!.generateWithConfigAndCallback(text, config) { chunk ->
|
|
536
536
|
if (inst.ttsStreamCancelled.get()) return@generateWithConfigAndCallback 0
|
|
537
|
-
emitChunk(instanceId, chunk, sampleRate, 0f, false)
|
|
537
|
+
emitChunk(instanceId, requestId, chunk, sampleRate, 0f, false)
|
|
538
538
|
chunk.size
|
|
539
539
|
}
|
|
540
540
|
}
|
|
541
541
|
inst.zipvoiceTts != null -> {
|
|
542
542
|
inst.zipvoiceTts!!.generateWithCallback(text, sid, speed) { chunk ->
|
|
543
543
|
if (inst.ttsStreamCancelled.get()) return@generateWithCallback 0
|
|
544
|
-
emitChunk(instanceId, chunk, sampleRate, 0f, false)
|
|
544
|
+
emitChunk(instanceId, requestId, chunk, sampleRate, 0f, false)
|
|
545
545
|
chunk.size
|
|
546
546
|
}
|
|
547
547
|
}
|
|
548
548
|
else -> {
|
|
549
549
|
inst.tts!!.generateWithCallback(text, sid, speed) { chunk ->
|
|
550
550
|
if (inst.ttsStreamCancelled.get()) return@generateWithCallback 0
|
|
551
|
-
emitChunk(instanceId, chunk, sampleRate, 0f, false)
|
|
551
|
+
emitChunk(instanceId, requestId, chunk, sampleRate, 0f, false)
|
|
552
552
|
chunk.size
|
|
553
553
|
}
|
|
554
554
|
}
|
|
555
555
|
}
|
|
556
556
|
if (!inst.ttsStreamCancelled.get()) {
|
|
557
|
-
emitChunk(instanceId, FloatArray(0), sampleRate, 1f, true)
|
|
557
|
+
emitChunk(instanceId, requestId, FloatArray(0), sampleRate, 1f, true)
|
|
558
558
|
}
|
|
559
559
|
} catch (e: Exception) {
|
|
560
560
|
if (!inst.ttsStreamCancelled.get()) {
|
|
561
|
-
emitError(instanceId, "TTS streaming failed: ${e.message}")
|
|
561
|
+
emitError(instanceId, requestId, "TTS streaming failed: ${e.message}")
|
|
562
562
|
}
|
|
563
563
|
} finally {
|
|
564
|
-
emitEnd(instanceId, inst.ttsStreamCancelled.get())
|
|
564
|
+
emitEnd(instanceId, requestId, inst.ttsStreamCancelled.get())
|
|
565
565
|
inst.ttsStreamRunning.set(false)
|
|
566
566
|
}
|
|
567
567
|
}
|