react-native-sherpa-onnx 0.3.9 → 0.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. package/README.md +17 -4
  2. package/SherpaOnnx.podspec +1 -0
  3. package/android/prebuilt-download.gradle +67 -27
  4. package/android/prebuilt-versions.gradle +1 -1
  5. package/android/src/main/assets/model_licenses/speech-enhancement-models-license-status.csv +7 -0
  6. package/android/src/main/cpp/CMakeLists.txt +3 -0
  7. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-enhancement-wrapper.cpp +68 -0
  8. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-enhancement-wrapper.h +17 -0
  9. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-enhancement.cpp +119 -0
  10. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect.h +31 -0
  11. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-validate-enhancement.cpp +68 -0
  12. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-validate-enhancement.h +30 -0
  13. package/android/src/main/cpp/jni/module/sherpa-onnx-module-jni.cpp +21 -0
  14. package/android/src/main/java/com/sherpaonnx/SherpaOnnxAssetHelper.kt +6 -0
  15. package/android/src/main/java/com/sherpaonnx/SherpaOnnxEnhancementHelper.kt +377 -0
  16. package/android/src/main/java/com/sherpaonnx/SherpaOnnxModule.kt +106 -0
  17. package/android/src/main/java/com/sherpaonnx/SherpaOnnxTtsHelper.kt +66 -13
  18. package/ios/Resources/model_licenses/speech-enhancement-models-license-status.csv +7 -0
  19. package/ios/SherpaOnnx+Assets.mm +5 -0
  20. package/ios/SherpaOnnx+Enhancement.mm +435 -0
  21. package/ios/enhancement/sherpa-onnx-enhancement-wrapper.h +85 -0
  22. package/ios/enhancement/sherpa-onnx-enhancement-wrapper.mm +218 -0
  23. package/ios/model_detect/sherpa-onnx-model-detect-enhancement.mm +92 -0
  24. package/ios/model_detect/sherpa-onnx-model-detect.h +23 -0
  25. package/ios/model_detect/sherpa-onnx-validate-enhancement.h +30 -0
  26. package/ios/model_detect/sherpa-onnx-validate-enhancement.mm +69 -0
  27. package/lib/module/NativeSherpaOnnx.js.map +1 -1
  28. package/lib/module/download/localModels.js +2 -3
  29. package/lib/module/download/localModels.js.map +1 -1
  30. package/lib/module/download/paths.js +2 -1
  31. package/lib/module/download/paths.js.map +1 -1
  32. package/lib/module/enhancement/index.js +63 -48
  33. package/lib/module/enhancement/index.js.map +1 -1
  34. package/lib/module/enhancement/streaming.js +60 -0
  35. package/lib/module/enhancement/streaming.js.map +1 -0
  36. package/lib/module/enhancement/streamingTypes.js +4 -0
  37. package/lib/module/enhancement/streamingTypes.js.map +1 -0
  38. package/lib/module/enhancement/types.js +4 -0
  39. package/lib/module/enhancement/types.js.map +1 -0
  40. package/lib/module/licenses.js +9 -3
  41. package/lib/module/licenses.js.map +1 -1
  42. package/lib/typescript/src/NativeSherpaOnnx.d.ts +45 -0
  43. package/lib/typescript/src/NativeSherpaOnnx.d.ts.map +1 -1
  44. package/lib/typescript/src/download/localModels.d.ts.map +1 -1
  45. package/lib/typescript/src/download/paths.d.ts +2 -1
  46. package/lib/typescript/src/download/paths.d.ts.map +1 -1
  47. package/lib/typescript/src/enhancement/index.d.ts +9 -46
  48. package/lib/typescript/src/enhancement/index.d.ts.map +1 -1
  49. package/lib/typescript/src/enhancement/streaming.d.ts +6 -0
  50. package/lib/typescript/src/enhancement/streaming.d.ts.map +1 -0
  51. package/lib/typescript/src/enhancement/streamingTypes.d.ts +12 -0
  52. package/lib/typescript/src/enhancement/streamingTypes.d.ts.map +1 -0
  53. package/lib/typescript/src/enhancement/types.d.ts +31 -0
  54. package/lib/typescript/src/enhancement/types.d.ts.map +1 -0
  55. package/lib/typescript/src/licenses.d.ts.map +1 -1
  56. package/package.json +1 -1
  57. package/scripts/ci/check-model-csvs.sh +27 -2
  58. package/scripts/ci/collect_all_sherpa_model_streams.sh +3 -1
  59. package/scripts/ci/collect_one_sherpa_release_stream.sh +3 -1
  60. package/scripts/ci/sherpa_speech_enhancement_model_release_streams.json +13 -0
  61. package/scripts/ci/update_model_license_csv.sh +1 -1
  62. package/src/NativeSherpaOnnx.ts +71 -0
  63. package/src/download/localModels.ts +1 -3
  64. package/src/download/paths.ts +2 -1
  65. package/src/enhancement/index.ts +120 -58
  66. package/src/enhancement/streaming.ts +105 -0
  67. package/src/enhancement/streamingTypes.ts +14 -0
  68. package/src/enhancement/types.ts +36 -0
  69. package/src/licenses.ts +13 -2
  70. package/third_party/sherpa-onnx-prebuilt/ANDROID_RELEASE_TAG +1 -1
@@ -382,13 +382,19 @@ internal class SherpaOnnxAssetHelper(
382
382
  "mms",
383
383
  "tts"
384
384
  )
385
+ val enhancementHints = listOf(
386
+ "gtcrn",
387
+ "dpdfnet"
388
+ )
385
389
 
386
390
  val isStt = sttHints.any { name.contains(it) }
387
391
  val isTts = ttsHints.any { name.contains(it) }
392
+ val isEnhancement = enhancementHints.any { name.contains(it) }
388
393
 
389
394
  return when {
390
395
  isStt && !isTts -> "stt"
391
396
  isTts && !isStt -> "tts"
397
+ isEnhancement && !isStt && !isTts -> "enhancement"
392
398
  else -> "unknown"
393
399
  }
394
400
  }
@@ -0,0 +1,377 @@
1
+ package com.sherpaonnx
2
+
3
+ import android.net.Uri
4
+ import android.util.Log
5
+ import com.facebook.react.bridge.Arguments
6
+ import com.facebook.react.bridge.Promise
7
+ import com.facebook.react.bridge.ReadableArray
8
+ import com.facebook.react.bridge.ReactApplicationContext
9
+ import com.facebook.react.bridge.WritableMap
10
+ import com.k2fsa.sherpa.onnx.DenoisedAudio
11
+ import com.k2fsa.sherpa.onnx.OfflineSpeechDenoiser
12
+ import com.k2fsa.sherpa.onnx.OfflineSpeechDenoiserConfig
13
+ import com.k2fsa.sherpa.onnx.OfflineSpeechDenoiserDpdfNetModelConfig
14
+ import com.k2fsa.sherpa.onnx.OfflineSpeechDenoiserGtcrnModelConfig
15
+ import com.k2fsa.sherpa.onnx.OfflineSpeechDenoiserModelConfig
16
+ import com.k2fsa.sherpa.onnx.OnlineSpeechDenoiser
17
+ import com.k2fsa.sherpa.onnx.OnlineSpeechDenoiserConfig
18
+ import com.k2fsa.sherpa.onnx.WaveReader
19
+ import java.io.File
20
+ import java.util.concurrent.ConcurrentHashMap
21
+
22
+ internal class SherpaOnnxEnhancementHelper(
23
+ private val context: ReactApplicationContext,
24
+ private val nativeDetectEnhancementModel: (modelDir: String, modelType: String) -> HashMap<String, Any>?
25
+ ) {
26
+ private data class EnhancementInstance(
27
+ @Volatile var denoiser: OfflineSpeechDenoiser? = null
28
+ ) {
29
+ fun release() {
30
+ denoiser?.release()
31
+ denoiser = null
32
+ }
33
+ }
34
+
35
+ private data class OnlineEnhancementInstance(
36
+ @Volatile var denoiser: OnlineSpeechDenoiser? = null
37
+ ) {
38
+ fun release() {
39
+ denoiser?.release()
40
+ denoiser = null
41
+ }
42
+ }
43
+
44
+ private val instances = ConcurrentHashMap<String, EnhancementInstance>()
45
+ private val onlineInstances = ConcurrentHashMap<String, OnlineEnhancementInstance>()
46
+
47
+ fun shutdown() {
48
+ instances.values.forEach { it.release() }
49
+ instances.clear()
50
+ onlineInstances.values.forEach { it.release() }
51
+ onlineInstances.clear()
52
+ }
53
+
54
+ private fun path(map: Map<String, String>, key: String): String = map[key].orEmpty()
55
+
56
+ private fun toEnhancedAudioMap(audio: DenoisedAudio): WritableMap {
57
+ val samples = Arguments.createArray()
58
+ for (sample in audio.samples) {
59
+ samples.pushDouble(sample.toDouble())
60
+ }
61
+ val out = Arguments.createMap()
62
+ out.putArray("samples", samples)
63
+ out.putInt("sampleRate", audio.sampleRate)
64
+ return out
65
+ }
66
+
67
+ private fun readableArrayToFloatArray(samples: ReadableArray): FloatArray {
68
+ val out = FloatArray(samples.size())
69
+ for (i in 0 until samples.size()) {
70
+ out[i] = samples.getDouble(i).toFloat()
71
+ }
72
+ return out
73
+ }
74
+
75
+ private fun copyContentUriToTemp(path: String, prefix: String): Pair<String, File?> {
76
+ if (!path.startsWith("content://")) return Pair(path, null)
77
+ val uri = Uri.parse(path)
78
+ val tmp = File(context.cacheDir, "${prefix}_${System.nanoTime()}.wav")
79
+ context.contentResolver.openInputStream(uri)?.use { input ->
80
+ tmp.outputStream().use { output -> input.copyTo(output) }
81
+ } ?: throw IllegalStateException("File is not readable: $path")
82
+ return Pair(tmp.absolutePath, tmp)
83
+ }
84
+
85
+ fun detectEnhancementModel(
86
+ modelDir: String,
87
+ modelType: String?,
88
+ promise: Promise
89
+ ) {
90
+ try {
91
+ val result = nativeDetectEnhancementModel(modelDir, modelType ?: "auto")
92
+ if (result == null) {
93
+ promise.reject("DETECT_ERROR", "Enhancement model detection returned null")
94
+ return
95
+ }
96
+ val success = result["success"] as? Boolean ?: false
97
+ val detectedModels = result["detectedModels"] as? ArrayList<*>
98
+ ?: arrayListOf<HashMap<String, String>>()
99
+ val modelTypeStr = result["modelType"] as? String
100
+
101
+ val resultMap = Arguments.createMap()
102
+ resultMap.putBoolean("success", success)
103
+ val modelsArray = Arguments.createArray()
104
+ for (model in detectedModels) {
105
+ val modelMap = model as? HashMap<*, *>
106
+ if (modelMap != null) {
107
+ val entry = Arguments.createMap()
108
+ entry.putString("type", modelMap["type"] as? String ?: "")
109
+ entry.putString("modelDir", modelMap["modelDir"] as? String ?: "")
110
+ modelsArray.pushMap(entry)
111
+ }
112
+ }
113
+ resultMap.putArray("detectedModels", modelsArray)
114
+ if (modelTypeStr != null) {
115
+ resultMap.putString("modelType", modelTypeStr)
116
+ }
117
+ if (!success) {
118
+ val error = result["error"] as? String
119
+ if (!error.isNullOrBlank()) {
120
+ resultMap.putString("error", error)
121
+ }
122
+ }
123
+ promise.resolve(resultMap)
124
+ } catch (e: Exception) {
125
+ Log.e("SherpaOnnxEnhancement", "Enhancement detection failed", e)
126
+ promise.reject("DETECT_ERROR", "Enhancement model detection failed: ${e.message}", e)
127
+ }
128
+ }
129
+
130
+ fun initializeEnhancement(
131
+ instanceId: String,
132
+ modelDir: String,
133
+ modelType: String?,
134
+ numThreads: Double?,
135
+ provider: String?,
136
+ debug: Boolean?,
137
+ promise: Promise
138
+ ) {
139
+ try {
140
+ val result = nativeDetectEnhancementModel(modelDir, modelType ?: "auto")
141
+ if (result == null || result["success"] as? Boolean != true) {
142
+ val reason = result?.get("error") as? String ?: "Failed to detect enhancement model"
143
+ promise.reject("ENHANCEMENT_INIT_ERROR", reason)
144
+ return
145
+ }
146
+ val modelTypeStr = result["modelType"] as? String ?: "gtcrn"
147
+ val paths = (result["paths"] as? Map<*, *>)
148
+ ?.mapValues { (_, v) -> (v as? String).orEmpty() }
149
+ ?.mapKeys { it.key.toString() }
150
+ ?: emptyMap()
151
+
152
+ val offlineModelConfig = when (modelTypeStr) {
153
+ "gtcrn" -> OfflineSpeechDenoiserModelConfig(
154
+ gtcrn = OfflineSpeechDenoiserGtcrnModelConfig(model = path(paths, "model")),
155
+ numThreads = numThreads?.toInt() ?: 1,
156
+ provider = provider ?: "cpu",
157
+ debug = debug ?: false
158
+ )
159
+ "dpdfnet" -> OfflineSpeechDenoiserModelConfig(
160
+ dpdfnet = OfflineSpeechDenoiserDpdfNetModelConfig(model = path(paths, "model")),
161
+ numThreads = numThreads?.toInt() ?: 1,
162
+ provider = provider ?: "cpu",
163
+ debug = debug ?: false
164
+ )
165
+ else -> {
166
+ promise.reject("ENHANCEMENT_INIT_ERROR", "Unsupported enhancement model type: $modelTypeStr")
167
+ return
168
+ }
169
+ }
170
+
171
+ val inst = instances.getOrPut(instanceId) { EnhancementInstance() }
172
+ inst.release()
173
+ val denoiser = OfflineSpeechDenoiser(
174
+ config = OfflineSpeechDenoiserConfig(model = offlineModelConfig)
175
+ )
176
+ inst.denoiser = denoiser
177
+
178
+ val modelsArray = Arguments.createArray()
179
+ val detectedModels = result["detectedModels"] as? ArrayList<*>
180
+ detectedModels?.forEach { modelObj ->
181
+ if (modelObj is HashMap<*, *>) {
182
+ val modelMap = Arguments.createMap()
183
+ modelMap.putString("type", modelObj["type"] as? String ?: "")
184
+ modelMap.putString("modelDir", modelObj["modelDir"] as? String ?: "")
185
+ modelsArray.pushMap(modelMap)
186
+ }
187
+ }
188
+
189
+ val out = Arguments.createMap()
190
+ out.putBoolean("success", true)
191
+ out.putArray("detectedModels", modelsArray)
192
+ out.putString("modelType", modelTypeStr)
193
+ out.putInt("sampleRate", denoiser.sampleRate)
194
+ promise.resolve(out)
195
+ } catch (e: Exception) {
196
+ Log.e("SherpaOnnxEnhancement", "Failed to initialize enhancement", e)
197
+ promise.reject("ENHANCEMENT_INIT_ERROR", "Failed to initialize enhancement: ${e.message}", e)
198
+ }
199
+ }
200
+
201
+ fun enhanceSamples(
202
+ instanceId: String,
203
+ samples: ReadableArray,
204
+ sampleRate: Double,
205
+ promise: Promise
206
+ ) {
207
+ val inst = instances[instanceId]
208
+ val denoiser = inst?.denoiser
209
+ if (denoiser == null) {
210
+ promise.reject("ENHANCEMENT_ERROR", "Enhancement instance not found: $instanceId")
211
+ return
212
+ }
213
+ try {
214
+ val audio = denoiser.run(readableArrayToFloatArray(samples), sampleRate.toInt())
215
+ promise.resolve(toEnhancedAudioMap(audio))
216
+ } catch (e: Exception) {
217
+ promise.reject("ENHANCEMENT_ERROR", "Failed to enhance samples: ${e.message}", e)
218
+ }
219
+ }
220
+
221
+ fun enhanceFile(
222
+ instanceId: String,
223
+ inputPath: String,
224
+ outputPath: String?,
225
+ promise: Promise
226
+ ) {
227
+ val inst = instances[instanceId]
228
+ val denoiser = inst?.denoiser
229
+ if (denoiser == null) {
230
+ promise.reject("ENHANCEMENT_ERROR", "Enhancement instance not found: $instanceId")
231
+ return
232
+ }
233
+
234
+ var tmpInput: File? = null
235
+ try {
236
+ val (resolvedInputPath, tmp) = copyContentUriToTemp(inputPath, "enhancement_in")
237
+ tmpInput = tmp
238
+ val wave = WaveReader.readWave(resolvedInputPath)
239
+ val audio = denoiser.run(wave.samples, wave.sampleRate)
240
+ if (!outputPath.isNullOrBlank()) {
241
+ audio.save(outputPath)
242
+ }
243
+ promise.resolve(toEnhancedAudioMap(audio))
244
+ } catch (e: Exception) {
245
+ promise.reject("ENHANCEMENT_ERROR", "Failed to enhance file: ${e.message}", e)
246
+ } finally {
247
+ tmpInput?.delete()
248
+ }
249
+ }
250
+
251
+ fun getSampleRate(instanceId: String, promise: Promise) {
252
+ val inst = instances[instanceId]
253
+ val denoiser = inst?.denoiser
254
+ if (denoiser == null) {
255
+ promise.reject("ENHANCEMENT_ERROR", "Enhancement instance not found: $instanceId")
256
+ return
257
+ }
258
+ promise.resolve(denoiser.sampleRate)
259
+ }
260
+
261
+ fun unloadEnhancement(instanceId: String, promise: Promise) {
262
+ instances.remove(instanceId)?.release()
263
+ promise.resolve(null)
264
+ }
265
+
266
+ fun initializeOnlineEnhancement(
267
+ instanceId: String,
268
+ modelDir: String,
269
+ modelType: String?,
270
+ numThreads: Double?,
271
+ provider: String?,
272
+ debug: Boolean?,
273
+ promise: Promise
274
+ ) {
275
+ try {
276
+ val result = nativeDetectEnhancementModel(modelDir, modelType ?: "auto")
277
+ if (result == null || result["success"] as? Boolean != true) {
278
+ val reason = result?.get("error") as? String ?: "Failed to detect enhancement model"
279
+ promise.reject("ONLINE_ENHANCEMENT_INIT_ERROR", reason)
280
+ return
281
+ }
282
+ val modelTypeStr = result["modelType"] as? String ?: "gtcrn"
283
+ val paths = (result["paths"] as? Map<*, *>)
284
+ ?.mapValues { (_, v) -> (v as? String).orEmpty() }
285
+ ?.mapKeys { it.key.toString() }
286
+ ?: emptyMap()
287
+
288
+ val offlineModelConfig = when (modelTypeStr) {
289
+ "gtcrn" -> OfflineSpeechDenoiserModelConfig(
290
+ gtcrn = OfflineSpeechDenoiserGtcrnModelConfig(model = path(paths, "model")),
291
+ numThreads = numThreads?.toInt() ?: 1,
292
+ provider = provider ?: "cpu",
293
+ debug = debug ?: false
294
+ )
295
+ "dpdfnet" -> OfflineSpeechDenoiserModelConfig(
296
+ dpdfnet = OfflineSpeechDenoiserDpdfNetModelConfig(model = path(paths, "model")),
297
+ numThreads = numThreads?.toInt() ?: 1,
298
+ provider = provider ?: "cpu",
299
+ debug = debug ?: false
300
+ )
301
+ else -> {
302
+ promise.reject("ONLINE_ENHANCEMENT_INIT_ERROR", "Unsupported enhancement model type: $modelTypeStr")
303
+ return
304
+ }
305
+ }
306
+
307
+ val inst = onlineInstances.getOrPut(instanceId) { OnlineEnhancementInstance() }
308
+ inst.release()
309
+ val denoiser = OnlineSpeechDenoiser(
310
+ config = OnlineSpeechDenoiserConfig(model = offlineModelConfig)
311
+ )
312
+ inst.denoiser = denoiser
313
+
314
+ val out = Arguments.createMap()
315
+ out.putBoolean("success", true)
316
+ out.putInt("sampleRate", denoiser.sampleRate)
317
+ out.putInt("frameShiftInSamples", denoiser.frameShiftInSamples)
318
+ promise.resolve(out)
319
+ } catch (e: Exception) {
320
+ promise.reject("ONLINE_ENHANCEMENT_INIT_ERROR", "Failed to initialize online enhancement: ${e.message}", e)
321
+ }
322
+ }
323
+
324
+ fun feedSamples(
325
+ instanceId: String,
326
+ samples: ReadableArray,
327
+ sampleRate: Double,
328
+ promise: Promise
329
+ ) {
330
+ val inst = onlineInstances[instanceId]
331
+ val denoiser = inst?.denoiser
332
+ if (denoiser == null) {
333
+ promise.reject("ONLINE_ENHANCEMENT_ERROR", "Online enhancement instance not found: $instanceId")
334
+ return
335
+ }
336
+ try {
337
+ val audio = denoiser.run(readableArrayToFloatArray(samples), sampleRate.toInt())
338
+ promise.resolve(toEnhancedAudioMap(audio))
339
+ } catch (e: Exception) {
340
+ promise.reject("ONLINE_ENHANCEMENT_ERROR", "Failed to feed enhancement samples: ${e.message}", e)
341
+ }
342
+ }
343
+
344
+ fun flushOnline(instanceId: String, promise: Promise) {
345
+ val inst = onlineInstances[instanceId]
346
+ val denoiser = inst?.denoiser
347
+ if (denoiser == null) {
348
+ promise.reject("ONLINE_ENHANCEMENT_ERROR", "Online enhancement instance not found: $instanceId")
349
+ return
350
+ }
351
+ try {
352
+ promise.resolve(toEnhancedAudioMap(denoiser.flush()))
353
+ } catch (e: Exception) {
354
+ promise.reject("ONLINE_ENHANCEMENT_ERROR", "Failed to flush online enhancement: ${e.message}", e)
355
+ }
356
+ }
357
+
358
+ fun resetOnline(instanceId: String, promise: Promise) {
359
+ val inst = onlineInstances[instanceId]
360
+ val denoiser = inst?.denoiser
361
+ if (denoiser == null) {
362
+ promise.reject("ONLINE_ENHANCEMENT_ERROR", "Online enhancement instance not found: $instanceId")
363
+ return
364
+ }
365
+ try {
366
+ denoiser.reset()
367
+ promise.resolve(null)
368
+ } catch (e: Exception) {
369
+ promise.reject("ONLINE_ENHANCEMENT_ERROR", "Failed to reset online enhancement: ${e.message}", e)
370
+ }
371
+ }
372
+
373
+ fun unloadOnline(instanceId: String, promise: Promise) {
374
+ onlineInstances.remove(instanceId)?.release()
375
+ promise.resolve(null)
376
+ }
377
+ }
@@ -56,6 +56,10 @@ class SherpaOnnxModule(reactContext: ReactApplicationContext) :
56
56
  { instanceId, requestId, message -> emitTtsStreamError(instanceId, requestId, message) },
57
57
  { instanceId, requestId, cancelled -> emitTtsStreamEnd(instanceId, requestId, cancelled) }
58
58
  )
59
+ private val enhancementHelper = SherpaOnnxEnhancementHelper(
60
+ reactApplicationContext,
61
+ { modelDir, modelType -> Companion.nativeDetectEnhancementModel(modelDir, modelType) }
62
+ )
59
63
  private val archiveHelper = SherpaOnnxArchiveHelper()
60
64
  private var pcmCapture: SherpaOnnxPcmCapture? = null
61
65
 
@@ -69,6 +73,7 @@ class SherpaOnnxModule(reactContext: ReactApplicationContext) :
69
73
  pcmCapture = null
70
74
  onlineSttHelper.shutdown()
71
75
  ttsHelper.shutdown()
76
+ enhancementHelper.shutdown()
72
77
  }
73
78
 
74
79
  /**
@@ -1059,6 +1064,103 @@ class SherpaOnnxModule(reactContext: ReactApplicationContext) :
1059
1064
  ttsHelper.unloadTts(instanceId, promise)
1060
1065
  }
1061
1066
 
1067
+ // ==================== Speech Enhancement Methods ====================
1068
+
1069
+ override fun detectEnhancementModel(
1070
+ modelDir: String,
1071
+ modelType: String?,
1072
+ promise: Promise
1073
+ ) {
1074
+ enhancementHelper.detectEnhancementModel(modelDir, modelType, promise)
1075
+ }
1076
+
1077
+ override fun initializeEnhancement(
1078
+ instanceId: String,
1079
+ modelDir: String,
1080
+ modelType: String?,
1081
+ numThreads: Double?,
1082
+ provider: String?,
1083
+ debug: Boolean?,
1084
+ promise: Promise
1085
+ ) {
1086
+ enhancementHelper.initializeEnhancement(
1087
+ instanceId,
1088
+ modelDir,
1089
+ modelType,
1090
+ numThreads,
1091
+ provider,
1092
+ debug,
1093
+ promise
1094
+ )
1095
+ }
1096
+
1097
+ override fun enhanceFile(
1098
+ instanceId: String,
1099
+ inputPath: String,
1100
+ outputPath: String?,
1101
+ promise: Promise
1102
+ ) {
1103
+ enhancementHelper.enhanceFile(instanceId, inputPath, outputPath, promise)
1104
+ }
1105
+
1106
+ override fun enhanceSamples(
1107
+ instanceId: String,
1108
+ samples: ReadableArray,
1109
+ sampleRate: Double,
1110
+ promise: Promise
1111
+ ) {
1112
+ enhancementHelper.enhanceSamples(instanceId, samples, sampleRate, promise)
1113
+ }
1114
+
1115
+ override fun getEnhancementSampleRate(instanceId: String, promise: Promise) {
1116
+ enhancementHelper.getSampleRate(instanceId, promise)
1117
+ }
1118
+
1119
+ override fun unloadEnhancement(instanceId: String, promise: Promise) {
1120
+ enhancementHelper.unloadEnhancement(instanceId, promise)
1121
+ }
1122
+
1123
+ override fun initializeOnlineEnhancement(
1124
+ instanceId: String,
1125
+ modelDir: String,
1126
+ modelType: String?,
1127
+ numThreads: Double?,
1128
+ provider: String?,
1129
+ debug: Boolean?,
1130
+ promise: Promise
1131
+ ) {
1132
+ enhancementHelper.initializeOnlineEnhancement(
1133
+ instanceId,
1134
+ modelDir,
1135
+ modelType,
1136
+ numThreads,
1137
+ provider,
1138
+ debug,
1139
+ promise
1140
+ )
1141
+ }
1142
+
1143
+ override fun feedEnhancementSamples(
1144
+ instanceId: String,
1145
+ samples: ReadableArray,
1146
+ sampleRate: Double,
1147
+ promise: Promise
1148
+ ) {
1149
+ enhancementHelper.feedSamples(instanceId, samples, sampleRate, promise)
1150
+ }
1151
+
1152
+ override fun flushOnlineEnhancement(instanceId: String, promise: Promise) {
1153
+ enhancementHelper.flushOnline(instanceId, promise)
1154
+ }
1155
+
1156
+ override fun resetOnlineEnhancement(instanceId: String, promise: Promise) {
1157
+ enhancementHelper.resetOnline(instanceId, promise)
1158
+ }
1159
+
1160
+ override fun unloadOnlineEnhancement(instanceId: String, promise: Promise) {
1161
+ enhancementHelper.unloadOnline(instanceId, promise)
1162
+ }
1163
+
1062
1164
  /**
1063
1165
  * Save TTS audio samples to a WAV file.
1064
1166
  */
@@ -1256,6 +1358,10 @@ class SherpaOnnxModule(reactContext: ReactApplicationContext) :
1256
1358
  @JvmStatic
1257
1359
  private external fun nativeDetectTtsModel(modelDir: String, modelType: String): HashMap<String, Any>?
1258
1360
 
1361
+ /** Model detection for speech enhancement: returns HashMap with success, error, detectedModels, modelType, paths. */
1362
+ @JvmStatic
1363
+ private external fun nativeDetectEnhancementModel(modelDir: String, modelType: String): HashMap<String, Any>?
1364
+
1259
1365
  /** Convert arbitrary audio file to requested format (e.g. "mp3", "flac", "wav").
1260
1366
  * outputSampleRateHz: for MP3 use 32000/44100/48000, 0 = default 44100. Ignored for WAV/FLAC.
1261
1367
  * Returns empty string on success, or an error message otherwise. Requires FFmpeg prebuilts when called on Android.
@@ -111,6 +111,46 @@ internal class SherpaOnnxTtsHelper(
111
111
  }
112
112
  }
113
113
 
114
+ /**
115
+ * libsherpa-onnx-jni looks up `invoke([F)Ljava/lang/Integer` (see sherpa-onnx `offline-tts.cc` CallCallback).
116
+ * Kotlin `Function1<*, Int>` compiles to `invoke([F)I`, so GetMethodID fails and JNI aborts.
117
+ * Using [java.lang.Integer] as the type parameter yields the boxed JVM signature the JNI expects.
118
+ * The cast is only for the Kotlin API (`generateWithCallback` still declares `Function1<FloatArray, Int>`).
119
+ */
120
+ /** Box for JNI: must be real [java.lang.Integer], not Kotlin [Int] (primitive `invoke([F)I` breaks sherpa JNI). */
121
+ @Suppress("DEPRECATION")
122
+ private fun boxForTtsJni(n: Int): java.lang.Integer = java.lang.Integer(n)
123
+
124
+ @Suppress("UNCHECKED_CAST")
125
+ private fun ttsChunkCallbackForJni(
126
+ sentenceChunkSizes: MutableList<Int>
127
+ ): kotlin.Function1<FloatArray, Int> {
128
+ val boxed =
129
+ object : kotlin.jvm.functions.Function1<FloatArray, java.lang.Integer> {
130
+ override fun invoke(chunk: FloatArray): java.lang.Integer {
131
+ sentenceChunkSizes.add(chunk.size)
132
+ return boxForTtsJni(chunk.size)
133
+ }
134
+ }
135
+ return boxed as kotlin.Function1<FloatArray, Int>
136
+ }
137
+
138
+ @Suppress("UNCHECKED_CAST")
139
+ private fun ttsStreamChunkCallbackForJni(
140
+ cancelled: AtomicBoolean,
141
+ onChunk: (FloatArray) -> Unit
142
+ ): kotlin.Function1<FloatArray, Int> {
143
+ val boxed =
144
+ object : kotlin.jvm.functions.Function1<FloatArray, java.lang.Integer> {
145
+ override fun invoke(chunk: FloatArray): java.lang.Integer {
146
+ if (cancelled.get()) return boxForTtsJni(0)
147
+ onChunk(chunk)
148
+ return boxForTtsJni(chunk.size)
149
+ }
150
+ }
151
+ return boxed as kotlin.Function1<FloatArray, Int>
152
+ }
153
+
114
154
  /** Single-thread executor for TTS init so the RN bridge thread is not blocked (avoids Inspector/dev WebSocket races in debug builds). */
115
155
  private val ttsInitExecutor = Executors.newSingleThreadExecutor()
116
156
 
@@ -453,6 +493,7 @@ internal class SherpaOnnxTtsHelper(
453
493
  }
454
494
  val sid = getSid(options)
455
495
  val speed = getSpeed(options)
496
+ val sentenceChunkSizes = mutableListOf<Int>()
456
497
  val audio = when {
457
498
  hasReferenceAudio(options) && (inst.isZipvoice || inst.isPocket) -> {
458
499
  if (inst.isZipvoice) {
@@ -467,7 +508,11 @@ internal class SherpaOnnxTtsHelper(
467
508
  }
468
509
  }
469
510
  val config = parseGenerationConfig(options) ?: GenerationConfig(speed = speed, sid = sid)
470
- inst.tts!!.generateWithConfig(text, config)
511
+ inst.tts!!.generateWithConfigAndCallback(
512
+ text,
513
+ config,
514
+ ttsChunkCallbackForJni(sentenceChunkSizes)
515
+ )
471
516
  }
472
517
  hasReferenceAudio(options) -> {
473
518
  Log.e("SherpaOnnxTts", "TTS_GENERATE_ERROR: Reference audio is not supported for this TTS model type")
@@ -485,12 +530,15 @@ internal class SherpaOnnxTtsHelper(
485
530
  )
486
531
  return
487
532
  }
488
- else -> dispatchGenerate(inst, text, sid, speed)
489
- ?: run {
533
+ else -> {
534
+ val tts = inst.tts
535
+ if (tts == null) {
490
536
  Log.e("SherpaOnnxTts", "TTS_GENERATE_ERROR: TTS not initialized")
491
537
  promise.reject("TTS_GENERATE_ERROR", "TTS not initialized")
492
538
  return
493
539
  }
540
+ tts.generateWithCallback(text, sid, speed, ttsChunkCallbackForJni(sentenceChunkSizes))
541
+ }
494
542
  }
495
543
  val map = Arguments.createMap()
496
544
  val samplesArray = Arguments.createArray()
@@ -564,18 +612,23 @@ internal class SherpaOnnxTtsHelper(
564
612
  when {
565
613
  hasReferenceAudio(options) && inst.isPocket -> {
566
614
  val config = parseGenerationConfig(options) ?: GenerationConfig(speed = speed, sid = sid)
567
- inst.tts!!.generateWithConfigAndCallback(text, config) { chunk ->
568
- if (inst.ttsStreamCancelled.get()) return@generateWithConfigAndCallback 0
569
- emitChunk(instanceId, requestId, chunk, sampleRate, 0f, false)
570
- chunk.size
571
- }
615
+ inst.tts!!.generateWithConfigAndCallback(
616
+ text,
617
+ config,
618
+ ttsStreamChunkCallbackForJni(inst.ttsStreamCancelled) { chunk ->
619
+ emitChunk(instanceId, requestId, chunk, sampleRate, 0f, false)
620
+ }
621
+ )
572
622
  }
573
623
  else -> {
574
- inst.tts!!.generateWithCallback(text, sid, speed) { chunk ->
575
- if (inst.ttsStreamCancelled.get()) return@generateWithCallback 0
576
- emitChunk(instanceId, requestId, chunk, sampleRate, 0f, false)
577
- chunk.size
578
- }
624
+ inst.tts!!.generateWithCallback(
625
+ text,
626
+ sid,
627
+ speed,
628
+ ttsStreamChunkCallbackForJni(inst.ttsStreamCancelled) { chunk ->
629
+ emitChunk(instanceId, requestId, chunk, sampleRate, 0f, false)
630
+ }
631
+ )
579
632
  }
580
633
  }
581
634
  if (!inst.ttsStreamCancelled.get()) {
@@ -0,0 +1,7 @@
1
+ asset_name,license_type,commercial_use,confidence,detection_source,license_file
2
+ dpdfnet2.onnx,apache-2.0,yes,high,manual,https://huggingface.co/Ceva-IP/DPDFNet/tree/main/onnx
3
+ dpdfnet2_48khz_hr.onnx,apache-2.0,yes,high,manual,https://huggingface.co/Ceva-IP/DPDFNet/tree/main/onnx
4
+ dpdfnet4.onnx,apache-2.0,yes,high,manual,https://huggingface.co/Ceva-IP/DPDFNet/tree/main/onnx
5
+ dpdfnet8.onnx,apache-2.0,yes,high,manual,https://huggingface.co/Ceva-IP/DPDFNet/tree/main/onnx
6
+ dpdfnet_baseline.onnx,apache-2.0,yes,high,manual,https://huggingface.co/Ceva-IP/DPDFNet/tree/main/onnx
7
+ gtcrn_simple.onnx,mit,yes,high,manual,https://github.com/Xiaobin-Rong/gtcrn/tree/main
@@ -319,6 +319,11 @@ static void collectModelFolderNames(NSFileManager *fileManager, NSString *path,
319
319
  return @"tts";
320
320
  }
321
321
 
322
+ BOOL isEnhancement = [name containsString:@"gtcrn"] || [name containsString:@"dpdfnet"];
323
+ if (isEnhancement) {
324
+ return @"enhancement";
325
+ }
326
+
322
327
  return @"unknown";
323
328
  }
324
329