react-native-sherpa-onnx 0.3.9 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. package/README.md +17 -4
  2. package/SherpaOnnx.podspec +1 -0
  3. package/android/prebuilt-download.gradle +67 -27
  4. package/android/prebuilt-versions.gradle +1 -1
  5. package/android/src/main/assets/model_licenses/speech-enhancement-models-license-status.csv +7 -0
  6. package/android/src/main/cpp/CMakeLists.txt +3 -0
  7. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-enhancement-wrapper.cpp +68 -0
  8. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-enhancement-wrapper.h +17 -0
  9. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-enhancement.cpp +119 -0
  10. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect.h +31 -0
  11. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-validate-enhancement.cpp +68 -0
  12. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-validate-enhancement.h +30 -0
  13. package/android/src/main/cpp/jni/module/sherpa-onnx-module-jni.cpp +21 -0
  14. package/android/src/main/java/com/sherpaonnx/SherpaOnnxAssetHelper.kt +6 -0
  15. package/android/src/main/java/com/sherpaonnx/SherpaOnnxEnhancementHelper.kt +377 -0
  16. package/android/src/main/java/com/sherpaonnx/SherpaOnnxModule.kt +106 -0
  17. package/ios/Resources/model_licenses/speech-enhancement-models-license-status.csv +7 -0
  18. package/ios/SherpaOnnx+Assets.mm +5 -0
  19. package/ios/SherpaOnnx+Enhancement.mm +435 -0
  20. package/ios/enhancement/sherpa-onnx-enhancement-wrapper.h +85 -0
  21. package/ios/enhancement/sherpa-onnx-enhancement-wrapper.mm +218 -0
  22. package/ios/model_detect/sherpa-onnx-model-detect-enhancement.mm +92 -0
  23. package/ios/model_detect/sherpa-onnx-model-detect.h +23 -0
  24. package/ios/model_detect/sherpa-onnx-validate-enhancement.h +30 -0
  25. package/ios/model_detect/sherpa-onnx-validate-enhancement.mm +69 -0
  26. package/lib/module/NativeSherpaOnnx.js.map +1 -1
  27. package/lib/module/download/localModels.js +2 -3
  28. package/lib/module/download/localModels.js.map +1 -1
  29. package/lib/module/download/paths.js +2 -1
  30. package/lib/module/download/paths.js.map +1 -1
  31. package/lib/module/enhancement/index.js +63 -48
  32. package/lib/module/enhancement/index.js.map +1 -1
  33. package/lib/module/enhancement/streaming.js +60 -0
  34. package/lib/module/enhancement/streaming.js.map +1 -0
  35. package/lib/module/enhancement/streamingTypes.js +4 -0
  36. package/lib/module/enhancement/streamingTypes.js.map +1 -0
  37. package/lib/module/enhancement/types.js +4 -0
  38. package/lib/module/enhancement/types.js.map +1 -0
  39. package/lib/module/licenses.js +9 -3
  40. package/lib/module/licenses.js.map +1 -1
  41. package/lib/typescript/src/NativeSherpaOnnx.d.ts +45 -0
  42. package/lib/typescript/src/NativeSherpaOnnx.d.ts.map +1 -1
  43. package/lib/typescript/src/download/localModels.d.ts.map +1 -1
  44. package/lib/typescript/src/download/paths.d.ts +2 -1
  45. package/lib/typescript/src/download/paths.d.ts.map +1 -1
  46. package/lib/typescript/src/enhancement/index.d.ts +9 -46
  47. package/lib/typescript/src/enhancement/index.d.ts.map +1 -1
  48. package/lib/typescript/src/enhancement/streaming.d.ts +6 -0
  49. package/lib/typescript/src/enhancement/streaming.d.ts.map +1 -0
  50. package/lib/typescript/src/enhancement/streamingTypes.d.ts +12 -0
  51. package/lib/typescript/src/enhancement/streamingTypes.d.ts.map +1 -0
  52. package/lib/typescript/src/enhancement/types.d.ts +31 -0
  53. package/lib/typescript/src/enhancement/types.d.ts.map +1 -0
  54. package/lib/typescript/src/licenses.d.ts.map +1 -1
  55. package/package.json +1 -1
  56. package/scripts/ci/check-model-csvs.sh +27 -2
  57. package/scripts/ci/collect_all_sherpa_model_streams.sh +3 -1
  58. package/scripts/ci/collect_one_sherpa_release_stream.sh +3 -1
  59. package/scripts/ci/sherpa_speech_enhancement_model_release_streams.json +13 -0
  60. package/scripts/ci/update_model_license_csv.sh +1 -1
  61. package/src/NativeSherpaOnnx.ts +71 -0
  62. package/src/download/localModels.ts +1 -3
  63. package/src/download/paths.ts +2 -1
  64. package/src/enhancement/index.ts +120 -58
  65. package/src/enhancement/streaming.ts +105 -0
  66. package/src/enhancement/streamingTypes.ts +14 -0
  67. package/src/enhancement/types.ts +36 -0
  68. package/src/licenses.ts +13 -2
  69. package/third_party/sherpa-onnx-prebuilt/ANDROID_RELEASE_TAG +1 -1
@@ -382,13 +382,19 @@ internal class SherpaOnnxAssetHelper(
382
382
  "mms",
383
383
  "tts"
384
384
  )
385
+ val enhancementHints = listOf(
386
+ "gtcrn",
387
+ "dpdfnet"
388
+ )
385
389
 
386
390
  val isStt = sttHints.any { name.contains(it) }
387
391
  val isTts = ttsHints.any { name.contains(it) }
392
+ val isEnhancement = enhancementHints.any { name.contains(it) }
388
393
 
389
394
  return when {
390
395
  isStt && !isTts -> "stt"
391
396
  isTts && !isStt -> "tts"
397
+ isEnhancement && !isStt && !isTts -> "enhancement"
392
398
  else -> "unknown"
393
399
  }
394
400
  }
@@ -0,0 +1,377 @@
1
+ package com.sherpaonnx
2
+
3
+ import android.net.Uri
4
+ import android.util.Log
5
+ import com.facebook.react.bridge.Arguments
6
+ import com.facebook.react.bridge.Promise
7
+ import com.facebook.react.bridge.ReadableArray
8
+ import com.facebook.react.bridge.ReactApplicationContext
9
+ import com.facebook.react.bridge.WritableMap
10
+ import com.k2fsa.sherpa.onnx.DenoisedAudio
11
+ import com.k2fsa.sherpa.onnx.OfflineSpeechDenoiser
12
+ import com.k2fsa.sherpa.onnx.OfflineSpeechDenoiserConfig
13
+ import com.k2fsa.sherpa.onnx.OfflineSpeechDenoiserDpdfNetModelConfig
14
+ import com.k2fsa.sherpa.onnx.OfflineSpeechDenoiserGtcrnModelConfig
15
+ import com.k2fsa.sherpa.onnx.OfflineSpeechDenoiserModelConfig
16
+ import com.k2fsa.sherpa.onnx.OnlineSpeechDenoiser
17
+ import com.k2fsa.sherpa.onnx.OnlineSpeechDenoiserConfig
18
+ import com.k2fsa.sherpa.onnx.WaveReader
19
+ import java.io.File
20
+ import java.util.concurrent.ConcurrentHashMap
21
+
22
/**
 * React Native bridge for sherpa-onnx speech enhancement (denoising).
 *
 * Offline denoisers ([OfflineSpeechDenoiser]) and streaming denoisers ([OnlineSpeechDenoiser]) are
 * kept in two separate registries keyed by the caller-supplied `instanceId`. Re-initializing an id
 * releases the denoiser previously stored under it before the new one is created.
 *
 * @param context supplies the cache directory and content resolver used for `content://` inputs.
 * @param nativeDetectEnhancementModel model-detection callback. It receives the model directory and
 *   the requested model type (`"auto"` when none is given) and returns a result map (`success`,
 *   `error`, `detectedModels`, `modelType`, `paths`), or null.
 */
internal class SherpaOnnxEnhancementHelper(
    private val context: ReactApplicationContext,
    private val nativeDetectEnhancementModel: (modelDir: String, modelType: String) -> HashMap<String, Any>?
) {
    /** Slot holding one offline denoiser; [release] frees it and empties the slot. */
    private class EnhancementInstance {
        @Volatile
        var denoiser: OfflineSpeechDenoiser? = null

        fun release() {
            // Read the volatile once so the object we release is the one we clear.
            val current = denoiser
            denoiser = null
            current?.release()
        }
    }

    /** Slot holding one streaming denoiser; [release] frees it and empties the slot. */
    private class OnlineEnhancementInstance {
        @Volatile
        var denoiser: OnlineSpeechDenoiser? = null

        fun release() {
            val current = denoiser
            denoiser = null
            current?.release()
        }
    }

    private val instances = ConcurrentHashMap<String, EnhancementInstance>()
    private val onlineInstances = ConcurrentHashMap<String, OnlineEnhancementInstance>()

    /** Releases every offline and streaming denoiser and empties both registries. */
    fun shutdown() {
        instances.values.forEach { it.release() }
        instances.clear()
        onlineInstances.values.forEach { it.release() }
        onlineInstances.clear()
    }

    /** Returns the path stored under [key], or "" when the key is absent. */
    private fun path(map: Map<String, String>, key: String): String = map[key].orEmpty()

    /** Converts denoised audio into a JS map `{ samples: number[], sampleRate: number }`. */
    private fun toEnhancedAudioMap(audio: DenoisedAudio): WritableMap {
        val samples = Arguments.createArray()
        for (sample in audio.samples) {
            samples.pushDouble(sample.toDouble())
        }
        val out = Arguments.createMap()
        out.putArray("samples", samples)
        out.putInt("sampleRate", audio.sampleRate)
        return out
    }

    /** Copies a JS number array into a [FloatArray] (doubles narrowed to float). */
    private fun readableArrayToFloatArray(samples: ReadableArray): FloatArray =
        FloatArray(samples.size()) { i -> samples.getDouble(i).toFloat() }

    /**
     * Converts the detector's `detectedModels` list into a JS array of `{ type, modelDir }` maps.
     * Non-map entries are skipped; missing or non-string fields become "".
     */
    private fun toDetectedModelsArray(detectedModels: Any?) = Arguments.createArray().apply {
        (detectedModels as? List<*>)
            ?.filterIsInstance<Map<*, *>>()
            ?.forEach { modelMap ->
                val entry = Arguments.createMap()
                entry.putString("type", modelMap["type"] as? String ?: "")
                entry.putString("modelDir", modelMap["modelDir"] as? String ?: "")
                pushMap(entry)
            }
    }

    /** Extracts the detector's `paths` map, stringifying keys and mapping non-string values to "". */
    private fun pathsOf(result: Map<String, Any>): Map<String, String> =
        (result["paths"] as? Map<*, *>)
            ?.entries
            ?.associate { (key, value) -> key.toString() to (value as? String).orEmpty() }
            .orEmpty()

    /**
     * If [path] is a `content://` URI, copies it into a new temp file in the cache directory.
     *
     * @return the path to read from, plus the temp file the caller must delete (null when [path]
     *   was not a content URI and no copy was made).
     * @throws IllegalStateException when the content resolver cannot open the URI.
     */
    private fun copyContentUriToTemp(path: String, prefix: String): Pair<String, File?> {
        if (!path.startsWith("content://")) return Pair(path, null)
        val uri = Uri.parse(path)
        val tmp = File(context.cacheDir, "${prefix}_${System.nanoTime()}.wav")
        try {
            context.contentResolver.openInputStream(uri)?.use { input ->
                tmp.outputStream().use { output -> input.copyTo(output) }
            } ?: throw IllegalStateException("File is not readable: $path")
        } catch (t: Throwable) {
            // The caller never receives `tmp` when we throw, so remove any partial copy here.
            tmp.delete()
            throw t
        }
        return Pair(tmp.absolutePath, tmp)
    }

    /**
     * Runs model detection and returns the result when it reports `success == true`; otherwise
     * rejects [promise] with [errorCode] and the detector's error (or a generic message) and
     * returns null.
     */
    private fun detectOrReject(
        modelDir: String,
        modelType: String?,
        errorCode: String,
        promise: Promise
    ): HashMap<String, Any>? {
        val result = nativeDetectEnhancementModel(modelDir, modelType ?: "auto")
        if (result == null || result["success"] as? Boolean != true) {
            val reason = result?.get("error") as? String ?: "Failed to detect enhancement model"
            promise.reject(errorCode, reason)
            return null
        }
        return result
    }

    /**
     * Builds the sherpa-onnx model config for [modelType] ("gtcrn" or "dpdfnet"), or returns null
     * for any other type. Unset options default to 1 thread, the "cpu" provider and debug off.
     * Shared by the offline and streaming initializers.
     */
    private fun buildModelConfig(
        modelType: String,
        paths: Map<String, String>,
        numThreads: Double?,
        provider: String?,
        debug: Boolean?
    ): OfflineSpeechDenoiserModelConfig? {
        val threads = numThreads?.toInt() ?: 1
        val resolvedProvider = provider ?: "cpu"
        val debugEnabled = debug ?: false
        val modelPath = path(paths, "model")
        return when (modelType) {
            "gtcrn" -> OfflineSpeechDenoiserModelConfig(
                gtcrn = OfflineSpeechDenoiserGtcrnModelConfig(model = modelPath),
                numThreads = threads,
                provider = resolvedProvider,
                debug = debugEnabled
            )
            "dpdfnet" -> OfflineSpeechDenoiserModelConfig(
                dpdfnet = OfflineSpeechDenoiserDpdfNetModelConfig(model = modelPath),
                numThreads = threads,
                provider = resolvedProvider,
                debug = debugEnabled
            )
            else -> null
        }
    }

    /** Returns the offline denoiser for [instanceId], or rejects [promise] and returns null. */
    private fun offlineDenoiserOrReject(instanceId: String, promise: Promise): OfflineSpeechDenoiser? =
        instances[instanceId]?.denoiser ?: run {
            promise.reject("ENHANCEMENT_ERROR", "Enhancement instance not found: $instanceId")
            null
        }

    /** Returns the streaming denoiser for [instanceId], or rejects [promise] and returns null. */
    private fun onlineDenoiserOrReject(instanceId: String, promise: Promise): OnlineSpeechDenoiser? =
        onlineInstances[instanceId]?.denoiser ?: run {
            promise.reject("ONLINE_ENHANCEMENT_ERROR", "Online enhancement instance not found: $instanceId")
            null
        }

    /**
     * Resolves with `{ success, detectedModels, modelType?, error? }` describing the enhancement
     * model in [modelDir]. `error` is only included when detection failed and reported a message.
     * Rejects with `DETECT_ERROR` when detection returns null or throws.
     */
    fun detectEnhancementModel(
        modelDir: String,
        modelType: String?,
        promise: Promise
    ) {
        try {
            val result = nativeDetectEnhancementModel(modelDir, modelType ?: "auto")
            if (result == null) {
                promise.reject("DETECT_ERROR", "Enhancement model detection returned null")
                return
            }
            val success = result["success"] as? Boolean ?: false
            val modelTypeStr = result["modelType"] as? String

            val resultMap = Arguments.createMap()
            resultMap.putBoolean("success", success)
            resultMap.putArray("detectedModels", toDetectedModelsArray(result["detectedModels"]))
            if (modelTypeStr != null) {
                resultMap.putString("modelType", modelTypeStr)
            }
            if (!success) {
                val error = result["error"] as? String
                if (!error.isNullOrBlank()) {
                    resultMap.putString("error", error)
                }
            }
            promise.resolve(resultMap)
        } catch (e: Exception) {
            Log.e(TAG, "Enhancement detection failed", e)
            promise.reject("DETECT_ERROR", "Enhancement model detection failed: ${e.message}", e)
        }
    }

    /**
     * Detects the model in [modelDir] and (re)creates the offline denoiser stored under [instanceId].
     * Resolves with `{ success, detectedModels, modelType, sampleRate }`; rejects with
     * `ENHANCEMENT_INIT_ERROR` on detection failure, unsupported model type, or construction error.
     */
    fun initializeEnhancement(
        instanceId: String,
        modelDir: String,
        modelType: String?,
        numThreads: Double?,
        provider: String?,
        debug: Boolean?,
        promise: Promise
    ) {
        try {
            val result = detectOrReject(modelDir, modelType, "ENHANCEMENT_INIT_ERROR", promise) ?: return
            val modelTypeStr = result["modelType"] as? String ?: "gtcrn"
            val modelConfig = buildModelConfig(modelTypeStr, pathsOf(result), numThreads, provider, debug)
            if (modelConfig == null) {
                promise.reject("ENHANCEMENT_INIT_ERROR", "Unsupported enhancement model type: $modelTypeStr")
                return
            }

            // Free the previous denoiser (if any) before loading a new model into this slot.
            val inst = instances.getOrPut(instanceId) { EnhancementInstance() }
            inst.release()
            val denoiser = OfflineSpeechDenoiser(
                config = OfflineSpeechDenoiserConfig(model = modelConfig)
            )
            inst.denoiser = denoiser

            val out = Arguments.createMap()
            out.putBoolean("success", true)
            out.putArray("detectedModels", toDetectedModelsArray(result["detectedModels"]))
            out.putString("modelType", modelTypeStr)
            out.putInt("sampleRate", denoiser.sampleRate)
            promise.resolve(out)
        } catch (e: Exception) {
            Log.e(TAG, "Failed to initialize enhancement", e)
            promise.reject("ENHANCEMENT_INIT_ERROR", "Failed to initialize enhancement: ${e.message}", e)
        }
    }

    /** Denoises [samples] (at [sampleRate] Hz) with the offline denoiser under [instanceId]. */
    fun enhanceSamples(
        instanceId: String,
        samples: ReadableArray,
        sampleRate: Double,
        promise: Promise
    ) {
        val denoiser = offlineDenoiserOrReject(instanceId, promise) ?: return
        try {
            val audio = denoiser.run(readableArrayToFloatArray(samples), sampleRate.toInt())
            promise.resolve(toEnhancedAudioMap(audio))
        } catch (e: Exception) {
            promise.reject("ENHANCEMENT_ERROR", "Failed to enhance samples: ${e.message}", e)
        }
    }

    /**
     * Denoises the WAV file at [inputPath] (a filesystem path or `content://` URI).
     *
     * When [outputPath] is non-blank the result is also written there. A failed write rejects with
     * `ENHANCEMENT_ERROR` instead of reporting success. Any temp copy of a content URI is always
     * deleted.
     */
    fun enhanceFile(
        instanceId: String,
        inputPath: String,
        outputPath: String?,
        promise: Promise
    ) {
        val denoiser = offlineDenoiserOrReject(instanceId, promise) ?: return

        var tmpInput: File? = null
        try {
            val (resolvedInputPath, tmp) = copyContentUriToTemp(inputPath, "enhancement_in")
            tmpInput = tmp
            val wave = WaveReader.readWave(resolvedInputPath)
            val audio = denoiser.run(wave.samples, wave.sampleRate)
            // DenoisedAudio.save reports failure via its Boolean result rather than throwing.
            if (!outputPath.isNullOrBlank() && !audio.save(outputPath)) {
                promise.reject("ENHANCEMENT_ERROR", "Failed to save enhanced audio to: $outputPath")
                return
            }
            promise.resolve(toEnhancedAudioMap(audio))
        } catch (e: Exception) {
            promise.reject("ENHANCEMENT_ERROR", "Failed to enhance file: ${e.message}", e)
        } finally {
            tmpInput?.delete()
        }
    }

    /** Resolves with the sample rate of the offline denoiser under [instanceId]. */
    fun getSampleRate(instanceId: String, promise: Promise) {
        val denoiser = offlineDenoiserOrReject(instanceId, promise) ?: return
        promise.resolve(denoiser.sampleRate)
    }

    /** Removes and releases the offline denoiser under [instanceId]; unknown ids are a no-op. */
    fun unloadEnhancement(instanceId: String, promise: Promise) {
        instances.remove(instanceId)?.release()
        promise.resolve(null)
    }

    /**
     * Detects the model in [modelDir] and (re)creates the streaming denoiser under [instanceId].
     * Resolves with `{ success, sampleRate, frameShiftInSamples }`; rejects with
     * `ONLINE_ENHANCEMENT_INIT_ERROR` on detection failure, unsupported model type, or construction error.
     */
    fun initializeOnlineEnhancement(
        instanceId: String,
        modelDir: String,
        modelType: String?,
        numThreads: Double?,
        provider: String?,
        debug: Boolean?,
        promise: Promise
    ) {
        try {
            val result = detectOrReject(modelDir, modelType, "ONLINE_ENHANCEMENT_INIT_ERROR", promise) ?: return
            val modelTypeStr = result["modelType"] as? String ?: "gtcrn"
            val modelConfig = buildModelConfig(modelTypeStr, pathsOf(result), numThreads, provider, debug)
            if (modelConfig == null) {
                promise.reject("ONLINE_ENHANCEMENT_INIT_ERROR", "Unsupported enhancement model type: $modelTypeStr")
                return
            }

            // The streaming denoiser reuses the offline model config type.
            val inst = onlineInstances.getOrPut(instanceId) { OnlineEnhancementInstance() }
            inst.release()
            val denoiser = OnlineSpeechDenoiser(
                config = OnlineSpeechDenoiserConfig(model = modelConfig)
            )
            inst.denoiser = denoiser

            val out = Arguments.createMap()
            out.putBoolean("success", true)
            out.putInt("sampleRate", denoiser.sampleRate)
            out.putInt("frameShiftInSamples", denoiser.frameShiftInSamples)
            promise.resolve(out)
        } catch (e: Exception) {
            Log.e(TAG, "Failed to initialize online enhancement", e)
            promise.reject("ONLINE_ENHANCEMENT_INIT_ERROR", "Failed to initialize online enhancement: ${e.message}", e)
        }
    }

    /** Feeds a chunk of [samples] to the streaming denoiser and resolves with the audio it returns. */
    fun feedSamples(
        instanceId: String,
        samples: ReadableArray,
        sampleRate: Double,
        promise: Promise
    ) {
        val denoiser = onlineDenoiserOrReject(instanceId, promise) ?: return
        try {
            val audio = denoiser.run(readableArrayToFloatArray(samples), sampleRate.toInt())
            promise.resolve(toEnhancedAudioMap(audio))
        } catch (e: Exception) {
            promise.reject("ONLINE_ENHANCEMENT_ERROR", "Failed to feed enhancement samples: ${e.message}", e)
        }
    }

    /** Flushes the streaming denoiser under [instanceId] and resolves with the audio it returns. */
    fun flushOnline(instanceId: String, promise: Promise) {
        val denoiser = onlineDenoiserOrReject(instanceId, promise) ?: return
        try {
            promise.resolve(toEnhancedAudioMap(denoiser.flush()))
        } catch (e: Exception) {
            promise.reject("ONLINE_ENHANCEMENT_ERROR", "Failed to flush online enhancement: ${e.message}", e)
        }
    }

    /** Calls `reset()` on the streaming denoiser under [instanceId] and resolves with null. */
    fun resetOnline(instanceId: String, promise: Promise) {
        val denoiser = onlineDenoiserOrReject(instanceId, promise) ?: return
        try {
            denoiser.reset()
            promise.resolve(null)
        } catch (e: Exception) {
            promise.reject("ONLINE_ENHANCEMENT_ERROR", "Failed to reset online enhancement: ${e.message}", e)
        }
    }

    /** Removes and releases the streaming denoiser under [instanceId]; unknown ids are a no-op. */
    fun unloadOnline(instanceId: String, promise: Promise) {
        onlineInstances.remove(instanceId)?.release()
        promise.resolve(null)
    }

    private companion object {
        const val TAG = "SherpaOnnxEnhancement"
    }
}
@@ -56,6 +56,10 @@ class SherpaOnnxModule(reactContext: ReactApplicationContext) :
56
56
  { instanceId, requestId, message -> emitTtsStreamError(instanceId, requestId, message) },
57
57
  { instanceId, requestId, cancelled -> emitTtsStreamEnd(instanceId, requestId, cancelled) }
58
58
  )
59
/** Speech-enhancement bridge; model detection is routed to the JNI [Companion.nativeDetectEnhancementModel]. */
private val enhancementHelper = SherpaOnnxEnhancementHelper(
    context = reactApplicationContext,
    nativeDetectEnhancementModel = { dir, type -> Companion.nativeDetectEnhancementModel(dir, type) }
)
59
63
  private val archiveHelper = SherpaOnnxArchiveHelper()
60
64
  private var pcmCapture: SherpaOnnxPcmCapture? = null
61
65
 
@@ -69,6 +73,7 @@ class SherpaOnnxModule(reactContext: ReactApplicationContext) :
69
73
  pcmCapture = null
70
74
  onlineSttHelper.shutdown()
71
75
  ttsHelper.shutdown()
76
+ enhancementHelper.shutdown()
72
77
  }
73
78
 
74
79
  /**
@@ -1059,6 +1064,103 @@ class SherpaOnnxModule(reactContext: ReactApplicationContext) :
1059
1064
  ttsHelper.unloadTts(instanceId, promise)
1060
1065
  }
1061
1066
 
1067
+ // ==================== Speech Enhancement Methods ====================
1068
+
1069
/** Detects the speech-enhancement model in [modelDir]; forwarded to [enhancementHelper]. */
override fun detectEnhancementModel(
    modelDir: String,
    modelType: String?,
    promise: Promise
) = enhancementHelper.detectEnhancementModel(
    modelDir = modelDir,
    modelType = modelType,
    promise = promise
)
1076
+
1077
/** Creates (or re-creates) the offline denoiser registered as [instanceId]; forwarded to [enhancementHelper]. */
override fun initializeEnhancement(
    instanceId: String,
    modelDir: String,
    modelType: String?,
    numThreads: Double?,
    provider: String?,
    debug: Boolean?,
    promise: Promise
) = enhancementHelper.initializeEnhancement(
    instanceId = instanceId,
    modelDir = modelDir,
    modelType = modelType,
    numThreads = numThreads,
    provider = provider,
    debug = debug,
    promise = promise
)
1096
+
1097
/** Denoises the audio file at [inputPath], optionally saving to [outputPath]; forwarded to [enhancementHelper]. */
override fun enhanceFile(
    instanceId: String,
    inputPath: String,
    outputPath: String?,
    promise: Promise
) = enhancementHelper.enhanceFile(
    instanceId = instanceId,
    inputPath = inputPath,
    outputPath = outputPath,
    promise = promise
)
1105
+
1106
/** Denoises an in-memory sample buffer; forwarded to [enhancementHelper]. */
override fun enhanceSamples(
    instanceId: String,
    samples: ReadableArray,
    sampleRate: Double,
    promise: Promise
) = enhancementHelper.enhanceSamples(
    instanceId = instanceId,
    samples = samples,
    sampleRate = sampleRate,
    promise = promise
)
1114
+
1115
/** Resolves with the offline denoiser's sample rate; forwarded to [enhancementHelper]'s `getSampleRate`. */
override fun getEnhancementSampleRate(instanceId: String, promise: Promise) =
    enhancementHelper.getSampleRate(instanceId = instanceId, promise = promise)
1118
+
1119
/** Releases the offline denoiser registered as [instanceId]; forwarded to [enhancementHelper]. */
override fun unloadEnhancement(instanceId: String, promise: Promise) =
    enhancementHelper.unloadEnhancement(instanceId = instanceId, promise = promise)
1122
+
1123
/** Creates (or re-creates) the streaming denoiser registered as [instanceId]; forwarded to [enhancementHelper]. */
override fun initializeOnlineEnhancement(
    instanceId: String,
    modelDir: String,
    modelType: String?,
    numThreads: Double?,
    provider: String?,
    debug: Boolean?,
    promise: Promise
) = enhancementHelper.initializeOnlineEnhancement(
    instanceId = instanceId,
    modelDir = modelDir,
    modelType = modelType,
    numThreads = numThreads,
    provider = provider,
    debug = debug,
    promise = promise
)
1142
+
1143
/** Feeds a chunk of samples to the streaming denoiser; forwarded to [enhancementHelper]'s `feedSamples`. */
override fun feedEnhancementSamples(
    instanceId: String,
    samples: ReadableArray,
    sampleRate: Double,
    promise: Promise
) = enhancementHelper.feedSamples(
    instanceId = instanceId,
    samples = samples,
    sampleRate = sampleRate,
    promise = promise
)
1151
+
1152
/** Flushes the streaming denoiser; forwarded to [enhancementHelper]'s `flushOnline`. */
override fun flushOnlineEnhancement(instanceId: String, promise: Promise) =
    enhancementHelper.flushOnline(instanceId = instanceId, promise = promise)
1155
+
1156
/** Resets the streaming denoiser; forwarded to [enhancementHelper]'s `resetOnline`. */
override fun resetOnlineEnhancement(instanceId: String, promise: Promise) =
    enhancementHelper.resetOnline(instanceId = instanceId, promise = promise)
1159
+
1160
/** Releases the streaming denoiser registered as [instanceId]; forwarded to [enhancementHelper]'s `unloadOnline`. */
override fun unloadOnlineEnhancement(instanceId: String, promise: Promise) =
    enhancementHelper.unloadOnline(instanceId = instanceId, promise = promise)
1163
+
1062
1164
  /**
1063
1165
  * Save TTS audio samples to a WAV file.
1064
1166
  */
@@ -1256,6 +1358,10 @@ class SherpaOnnxModule(reactContext: ReactApplicationContext) :
1256
1358
  @JvmStatic
1257
1359
  private external fun nativeDetectTtsModel(modelDir: String, modelType: String): HashMap<String, Any>?
1258
1360
 
1361
/**
 * JNI model detection for speech enhancement.
 *
 * Returns a HashMap with `success`, `error`, `detectedModels`, `modelType` and `paths` entries,
 * or null.
 */
@JvmStatic private external fun nativeDetectEnhancementModel(modelDir: String, modelType: String): HashMap<String, Any>?
1364
+
1259
1365
  /** Convert arbitrary audio file to requested format (e.g. "mp3", "flac", "wav").
1260
1366
  * outputSampleRateHz: for MP3 use 32000/44100/48000, 0 = default 44100. Ignored for WAV/FLAC.
1261
1367
  * Returns empty string on success, or an error message otherwise. Requires FFmpeg prebuilts when called on Android.
@@ -0,0 +1,7 @@
1
+ asset_name,license_type,commercial_use,confidence,detection_source,license_file
2
+ dpdfnet2.onnx,apache-2.0,yes,high,manual,https://huggingface.co/Ceva-IP/DPDFNet/tree/main/onnx
3
+ dpdfnet2_48khz_hr.onnx,apache-2.0,yes,high,manual,https://huggingface.co/Ceva-IP/DPDFNet/tree/main/onnx
4
+ dpdfnet4.onnx,apache-2.0,yes,high,manual,https://huggingface.co/Ceva-IP/DPDFNet/tree/main/onnx
5
+ dpdfnet8.onnx,apache-2.0,yes,high,manual,https://huggingface.co/Ceva-IP/DPDFNet/tree/main/onnx
6
+ dpdfnet_baseline.onnx,apache-2.0,yes,high,manual,https://huggingface.co/Ceva-IP/DPDFNet/tree/main/onnx
7
+ gtcrn_simple.onnx,mit,yes,high,manual,https://github.com/Xiaobin-Rong/gtcrn/tree/main
@@ -319,6 +319,11 @@ static void collectModelFolderNames(NSFileManager *fileManager, NSString *path,
319
319
  return @"tts";
320
320
  }
321
321
 
322
+ BOOL isEnhancement = [name containsString:@"gtcrn"] || [name containsString:@"dpdfnet"];
323
+ if (isEnhancement) {
324
+ return @"enhancement";
325
+ }
326
+
322
327
  return @"unknown";
323
328
  }
324
329