react-native-sherpa-onnx 0.3.8 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134) hide show
  1. package/README.md +20 -5
  2. package/SherpaOnnx.podspec +5 -1
  3. package/android/prebuilt-download.gradle +89 -49
  4. package/android/prebuilt-versions.gradle +1 -1
  5. package/android/src/main/assets/model_licenses/asr-models-license-status.csv +1 -0
  6. package/android/src/main/assets/model_licenses/speech-enhancement-models-license-status.csv +7 -0
  7. package/android/src/main/cpp/CMakeLists.txt +3 -0
  8. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-enhancement-wrapper.cpp +68 -0
  9. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-enhancement-wrapper.h +17 -0
  10. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-enhancement.cpp +119 -0
  11. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-helper.cpp +23 -0
  12. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-helper.h +9 -0
  13. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-stt.cpp +51 -8
  14. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect.h +41 -0
  15. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-stt-wrapper.cpp +5 -0
  16. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-validate-enhancement.cpp +68 -0
  17. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-validate-enhancement.h +30 -0
  18. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-validate-stt.cpp +11 -0
  19. package/android/src/main/cpp/jni/module/sherpa-onnx-module-jni.cpp +21 -0
  20. package/android/src/main/java/com/sherpaonnx/SherpaOnnxArchiveHelper.kt +110 -35
  21. package/android/src/main/java/com/sherpaonnx/SherpaOnnxAssetHelper.kt +6 -0
  22. package/android/src/main/java/com/sherpaonnx/SherpaOnnxEnhancementHelper.kt +377 -0
  23. package/android/src/main/java/com/sherpaonnx/SherpaOnnxExtractionNotificationHelper.kt +102 -0
  24. package/android/src/main/java/com/sherpaonnx/SherpaOnnxModule.kt +198 -18
  25. package/android/src/main/java/com/sherpaonnx/SherpaOnnxSttHelper.kt +22 -0
  26. package/ios/Resources/model_licenses/asr-models-license-status.csv +1 -0
  27. package/ios/Resources/model_licenses/speech-enhancement-models-license-status.csv +7 -0
  28. package/ios/SherpaOnnx+Assets.mm +5 -0
  29. package/ios/SherpaOnnx+Enhancement.mm +435 -0
  30. package/ios/SherpaOnnx+STT.mm +13 -1
  31. package/ios/SherpaOnnx.mm +87 -17
  32. package/ios/enhancement/sherpa-onnx-enhancement-wrapper.h +85 -0
  33. package/ios/enhancement/sherpa-onnx-enhancement-wrapper.mm +218 -0
  34. package/ios/model_detect/sherpa-onnx-model-detect-enhancement.mm +92 -0
  35. package/ios/model_detect/sherpa-onnx-model-detect-helper.h +5 -0
  36. package/ios/model_detect/sherpa-onnx-model-detect-helper.mm +23 -0
  37. package/ios/model_detect/sherpa-onnx-model-detect-stt.mm +51 -7
  38. package/ios/model_detect/sherpa-onnx-model-detect.h +33 -0
  39. package/ios/model_detect/sherpa-onnx-validate-enhancement.h +30 -0
  40. package/ios/model_detect/sherpa-onnx-validate-enhancement.mm +69 -0
  41. package/ios/model_detect/sherpa-onnx-validate-stt.mm +11 -0
  42. package/ios/stt/sherpa-onnx-stt-wrapper.h +11 -1
  43. package/ios/stt/sherpa-onnx-stt-wrapper.mm +30 -2
  44. package/ios/tts/sherpa-onnx-tts-wrapper.mm +16 -0
  45. package/lib/module/NativeSherpaOnnx.js.map +1 -1
  46. package/lib/module/download/localModels.js +2 -3
  47. package/lib/module/download/localModels.js.map +1 -1
  48. package/lib/module/download/paths.js +2 -1
  49. package/lib/module/download/paths.js.map +1 -1
  50. package/lib/module/download/postDownloadProcessing.js +17 -4
  51. package/lib/module/download/postDownloadProcessing.js.map +1 -1
  52. package/lib/module/enhancement/index.js +63 -48
  53. package/lib/module/enhancement/index.js.map +1 -1
  54. package/lib/module/enhancement/streaming.js +60 -0
  55. package/lib/module/enhancement/streaming.js.map +1 -0
  56. package/lib/module/enhancement/streamingTypes.js +4 -0
  57. package/lib/module/enhancement/streamingTypes.js.map +1 -0
  58. package/lib/module/enhancement/types.js +4 -0
  59. package/lib/module/enhancement/types.js.map +1 -0
  60. package/lib/module/extraction/extractTarBz2.js +2 -2
  61. package/lib/module/extraction/extractTarBz2.js.map +1 -1
  62. package/lib/module/extraction/extractTarZst.js +2 -2
  63. package/lib/module/extraction/extractTarZst.js.map +1 -1
  64. package/lib/module/extraction/index.js +10 -5
  65. package/lib/module/extraction/index.js.map +1 -1
  66. package/lib/module/licenses.js +9 -3
  67. package/lib/module/licenses.js.map +1 -1
  68. package/lib/module/stt/index.js +4 -2
  69. package/lib/module/stt/index.js.map +1 -1
  70. package/lib/module/stt/streaming.js +2 -1
  71. package/lib/module/stt/streaming.js.map +1 -1
  72. package/lib/module/stt/types.js +3 -1
  73. package/lib/module/stt/types.js.map +1 -1
  74. package/lib/module/tts/index.js +4 -2
  75. package/lib/module/tts/index.js.map +1 -1
  76. package/lib/module/tts/streaming.js +3 -1
  77. package/lib/module/tts/streaming.js.map +1 -1
  78. package/lib/typescript/src/NativeSherpaOnnx.d.ts +70 -9
  79. package/lib/typescript/src/NativeSherpaOnnx.d.ts.map +1 -1
  80. package/lib/typescript/src/download/localModels.d.ts.map +1 -1
  81. package/lib/typescript/src/download/paths.d.ts +2 -1
  82. package/lib/typescript/src/download/paths.d.ts.map +1 -1
  83. package/lib/typescript/src/download/postDownloadProcessing.d.ts +9 -0
  84. package/lib/typescript/src/download/postDownloadProcessing.d.ts.map +1 -1
  85. package/lib/typescript/src/enhancement/index.d.ts +9 -46
  86. package/lib/typescript/src/enhancement/index.d.ts.map +1 -1
  87. package/lib/typescript/src/enhancement/streaming.d.ts +6 -0
  88. package/lib/typescript/src/enhancement/streaming.d.ts.map +1 -0
  89. package/lib/typescript/src/enhancement/streamingTypes.d.ts +12 -0
  90. package/lib/typescript/src/enhancement/streamingTypes.d.ts.map +1 -0
  91. package/lib/typescript/src/enhancement/types.d.ts +31 -0
  92. package/lib/typescript/src/enhancement/types.d.ts.map +1 -0
  93. package/lib/typescript/src/extraction/extractTarBz2.d.ts +2 -1
  94. package/lib/typescript/src/extraction/extractTarBz2.d.ts.map +1 -1
  95. package/lib/typescript/src/extraction/extractTarZst.d.ts +2 -1
  96. package/lib/typescript/src/extraction/extractTarZst.d.ts.map +1 -1
  97. package/lib/typescript/src/extraction/index.d.ts +1 -1
  98. package/lib/typescript/src/extraction/index.d.ts.map +1 -1
  99. package/lib/typescript/src/extraction/types.d.ts +12 -0
  100. package/lib/typescript/src/extraction/types.d.ts.map +1 -1
  101. package/lib/typescript/src/licenses.d.ts.map +1 -1
  102. package/lib/typescript/src/stt/index.d.ts +1 -1
  103. package/lib/typescript/src/stt/index.d.ts.map +1 -1
  104. package/lib/typescript/src/stt/streaming.d.ts.map +1 -1
  105. package/lib/typescript/src/stt/types.d.ts +16 -1
  106. package/lib/typescript/src/stt/types.d.ts.map +1 -1
  107. package/lib/typescript/src/tts/index.d.ts.map +1 -1
  108. package/lib/typescript/src/tts/streaming.d.ts.map +1 -1
  109. package/package.json +1 -1
  110. package/scripts/ci/check-model-csvs.sh +27 -2
  111. package/scripts/ci/collect_all_sherpa_model_streams.sh +3 -1
  112. package/scripts/ci/collect_one_sherpa_release_stream.sh +3 -1
  113. package/scripts/ci/sherpa_speech_enhancement_model_release_streams.json +13 -0
  114. package/scripts/ci/update_model_license_csv.sh +17 -17
  115. package/src/NativeSherpaOnnx.ts +108 -10
  116. package/src/download/localModels.ts +1 -3
  117. package/src/download/paths.ts +2 -1
  118. package/src/download/postDownloadProcessing.ts +24 -1
  119. package/src/enhancement/index.ts +120 -58
  120. package/src/enhancement/streaming.ts +105 -0
  121. package/src/enhancement/streamingTypes.ts +14 -0
  122. package/src/enhancement/types.ts +36 -0
  123. package/src/extraction/extractTarBz2.ts +7 -2
  124. package/src/extraction/extractTarZst.ts +7 -2
  125. package/src/extraction/index.ts +29 -6
  126. package/src/extraction/types.ts +16 -0
  127. package/src/licenses.ts +13 -2
  128. package/src/stt/index.ts +8 -7
  129. package/src/stt/streaming.ts +7 -1
  130. package/src/stt/types.ts +18 -0
  131. package/src/tts/index.ts +7 -7
  132. package/src/tts/streaming.ts +6 -3
  133. package/third_party/sherpa-onnx-prebuilt/ANDROID_RELEASE_TAG +1 -1
  134. package/third_party/sherpa-onnx-prebuilt/IOS_RELEASE_TAG +1 -1
@@ -54,7 +54,8 @@ class SherpaOnnxArchiveHelper {
54
54
  targetPath: String,
55
55
  force: Boolean,
56
56
  promise: Promise,
57
- onProgress: (bytes: Long, totalBytes: Long, percent: Double) -> Unit
57
+ onProgress: (bytes: Long, totalBytes: Long, percent: Double) -> Unit,
58
+ extractionNotification: SherpaOnnxExtractionNotificationHelper? = null,
58
59
  ) {
59
60
  val promiseSettled = AtomicBoolean(false)
60
61
  fun resolveOnce(success: Boolean, reason: String? = null) {
@@ -70,26 +71,28 @@ class SherpaOnnxArchiveHelper {
70
71
  val cancelFlag = AtomicBoolean(false)
71
72
  cancelFlags[sourcePath] = cancelFlag
72
73
 
73
- // Create a progress callback object that JNI can call
74
- val progressCallback = object : Any() {
75
- fun invoke(bytesExtracted: Long, totalBytes: Long, percent: Double) {
76
- onProgress(bytesExtracted, totalBytes, percent)
77
- }
78
- }
79
-
80
74
  // Run extraction on a background thread so the React Native bridge thread is not blocked.
81
75
  // The thread pool allows multiple extractions in parallel.
82
76
  extractExecutor.execute {
77
+ val notif = extractionNotification
83
78
  try {
84
79
  // Check per-path cancel flag before starting the native extraction.
85
80
  if (cancelFlag.get()) {
86
81
  resolveOnce(false, "Cancelled")
87
82
  return@execute
88
83
  }
89
- nativeExtractTarBz2(sourcePath, targetPath, force, progressCallback, promise)
84
+ notif?.start()
85
+ val wrappedCallback = object : Any() {
86
+ fun invoke(bytesExtracted: Long, totalBytes: Long, percent: Double) {
87
+ onProgress(bytesExtracted, totalBytes, percent)
88
+ notif?.updateProgress(percent)
89
+ }
90
+ }
91
+ nativeExtractTarBz2(sourcePath, targetPath, force, wrappedCallback, promise)
90
92
  } catch (e: Exception) {
91
93
  resolveOnce(false, "Archive extraction error: ${e.message}")
92
94
  } finally {
95
+ notif?.finish()
93
96
  cancelFlags.remove(sourcePath)
94
97
  }
95
98
  }
@@ -104,7 +107,8 @@ class SherpaOnnxArchiveHelper {
104
107
  targetPath: String,
105
108
  force: Boolean,
106
109
  promise: Promise,
107
- onProgress: (bytes: Long, totalBytes: Long, percent: Double) -> Unit
110
+ onProgress: (bytes: Long, totalBytes: Long, percent: Double) -> Unit,
111
+ extractionNotification: SherpaOnnxExtractionNotificationHelper? = null,
108
112
  ) {
109
113
  val promiseSettled = AtomicBoolean(false)
110
114
  fun resolveOnce(success: Boolean, reason: String? = null) {
@@ -119,22 +123,26 @@ class SherpaOnnxArchiveHelper {
119
123
  val cancelFlag = AtomicBoolean(false)
120
124
  cancelFlags[sourcePath] = cancelFlag
121
125
 
122
- val progressCallback = object : Any() {
123
- fun invoke(bytesExtracted: Long, totalBytes: Long, percent: Double) {
124
- onProgress(bytesExtracted, totalBytes, percent)
125
- }
126
- }
127
126
  extractExecutor.execute {
127
+ val notif = extractionNotification
128
128
  try {
129
129
  // Check per-path cancel flag before starting the native extraction.
130
130
  if (cancelFlag.get()) {
131
131
  resolveOnce(false, "Cancelled")
132
132
  return@execute
133
133
  }
134
- nativeExtractTarZst(sourcePath, targetPath, force, progressCallback, promise)
134
+ notif?.start()
135
+ val wrappedCallback = object : Any() {
136
+ fun invoke(bytesExtracted: Long, totalBytes: Long, percent: Double) {
137
+ onProgress(bytesExtracted, totalBytes, percent)
138
+ notif?.updateProgress(percent)
139
+ }
140
+ }
141
+ nativeExtractTarZst(sourcePath, targetPath, force, wrappedCallback, promise)
135
142
  } catch (e: Exception) {
136
143
  resolveOnce(false, "Archive extraction error: ${e.message}")
137
144
  } finally {
145
+ notif?.finish()
138
146
  cancelFlags.remove(sourcePath)
139
147
  }
140
148
  }
@@ -144,47 +152,106 @@ class SherpaOnnxArchiveHelper {
144
152
  }
145
153
  }
146
154
 
155
+ /**
156
+ * Which JNI stream entry to use for APK asset extraction.
157
+ *
158
+ * Both paths invoke libarchive’s `ExtractFromStream`, which **auto-detects** compression
159
+ * (`.tar.zst` vs `.tar.bz2`, etc.); `nativeExtractTarBz2FromStream` forwards to the same
160
+ * native implementation as zst. Keeping distinct JNI symbols preserves a clear API and avoids
161
+ * the impression that bz2 assets are mistakenly wired only to a “zst” method.
162
+ */
163
+ private enum class AssetTarStreamKind {
164
+ ZST,
165
+ BZ2,
166
+ }
167
+
147
168
  fun extractTarZstFromAsset(
148
169
  context: Context,
149
170
  assetPath: String,
150
171
  targetPath: String,
151
172
  force: Boolean,
152
173
  promise: Promise,
153
- onProgress: (bytes: Long, totalBytes: Long, percent: Double) -> Unit
174
+ onProgress: (bytes: Long, totalBytes: Long, percent: Double) -> Unit,
175
+ extractionNotification: SherpaOnnxExtractionNotificationHelper? = null,
176
+ ) {
177
+ extractTarArchiveFromAsset(
178
+ context,
179
+ assetPath,
180
+ targetPath,
181
+ force,
182
+ promise,
183
+ onProgress,
184
+ extractionNotification,
185
+ AssetTarStreamKind.ZST,
186
+ )
187
+ }
188
+
189
+ fun extractTarBz2FromAsset(
190
+ context: Context,
191
+ assetPath: String,
192
+ targetPath: String,
193
+ force: Boolean,
194
+ promise: Promise,
195
+ onProgress: (bytes: Long, totalBytes: Long, percent: Double) -> Unit,
196
+ extractionNotification: SherpaOnnxExtractionNotificationHelper? = null,
197
+ ) {
198
+ extractTarArchiveFromAsset(
199
+ context,
200
+ assetPath,
201
+ targetPath,
202
+ force,
203
+ promise,
204
+ onProgress,
205
+ extractionNotification,
206
+ AssetTarStreamKind.BZ2,
207
+ )
208
+ }
209
+
210
+ private fun extractTarArchiveFromAsset(
211
+ context: Context,
212
+ assetPath: String,
213
+ targetPath: String,
214
+ force: Boolean,
215
+ promise: Promise,
216
+ onProgress: (bytes: Long, totalBytes: Long, percent: Double) -> Unit,
217
+ extractionNotification: SherpaOnnxExtractionNotificationHelper? = null,
218
+ kind: AssetTarStreamKind,
154
219
  ) {
155
220
  if (BuildConfig.DEBUG) {
156
- Log.i("SherpaOnnx", "extractTarZstFromAsset assetPath=$assetPath targetPath=$targetPath")
157
- }
158
- val progressCallback = object : Any() {
159
- fun invoke(bytesExtracted: Long, totalBytes: Long, percent: Double) {
160
- onProgress(bytesExtracted, totalBytes, percent)
161
- }
221
+ Log.i(
222
+ "SherpaOnnx",
223
+ "extractTar${if (kind == AssetTarStreamKind.ZST) "Zst" else "Bz2"}FromAsset assetPath=$assetPath targetPath=$targetPath",
224
+ )
162
225
  }
163
226
  extractExecutor.execute {
227
+ val notif = extractionNotification
164
228
  try {
229
+ notif?.start()
230
+ val progressCallback = object : Any() {
231
+ fun invoke(bytesExtracted: Long, totalBytes: Long, percent: Double) {
232
+ onProgress(bytesExtracted, totalBytes, percent)
233
+ notif?.updateProgress(percent)
234
+ }
235
+ }
165
236
  context.assets.open(assetPath).use { stream ->
166
- nativeExtractTarZstFromStream(stream, targetPath, force, progressCallback, promise)
237
+ when (kind) {
238
+ AssetTarStreamKind.ZST ->
239
+ nativeExtractTarZstFromStream(stream, targetPath, force, progressCallback, promise)
240
+ AssetTarStreamKind.BZ2 ->
241
+ nativeExtractTarBz2FromStream(stream, targetPath, force, progressCallback, promise)
242
+ }
167
243
  }
168
244
  } catch (e: Exception) {
169
245
  val result = Arguments.createMap()
170
246
  result.putBoolean("success", false)
171
247
  result.putString("reason", e.message ?: "Failed to open asset")
172
248
  promise.resolve(result)
249
+ } finally {
250
+ notif?.finish()
173
251
  }
174
252
  }
175
253
  }
176
254
 
177
- fun extractTarBz2FromAsset(
178
- context: Context,
179
- assetPath: String,
180
- targetPath: String,
181
- force: Boolean,
182
- promise: Promise,
183
- onProgress: (bytes: Long, totalBytes: Long, percent: Double) -> Unit
184
- ) {
185
- extractTarZstFromAsset(context, assetPath, targetPath, force, promise, onProgress)
186
- }
187
-
188
255
  fun computeFileSha256(filePath: String, promise: Promise) {
189
256
  nativeComputeFileSha256(filePath, promise)
190
257
  }
@@ -214,6 +281,14 @@ class SherpaOnnxArchiveHelper {
214
281
  promise: Promise
215
282
  )
216
283
 
284
+ private external fun nativeExtractTarBz2FromStream(
285
+ inputStream: java.io.InputStream,
286
+ targetPath: String,
287
+ force: Boolean,
288
+ progressCallback: Any?,
289
+ promise: Promise
290
+ )
291
+
217
292
  private external fun nativeCancelExtract()
218
293
 
219
294
  private external fun nativeComputeFileSha256(
@@ -382,13 +382,19 @@ internal class SherpaOnnxAssetHelper(
382
382
  "mms",
383
383
  "tts"
384
384
  )
385
+ val enhancementHints = listOf(
386
+ "gtcrn",
387
+ "dpdfnet"
388
+ )
385
389
 
386
390
  val isStt = sttHints.any { name.contains(it) }
387
391
  val isTts = ttsHints.any { name.contains(it) }
392
+ val isEnhancement = enhancementHints.any { name.contains(it) }
388
393
 
389
394
  return when {
390
395
  isStt && !isTts -> "stt"
391
396
  isTts && !isStt -> "tts"
397
+ isEnhancement && !isStt && !isTts -> "enhancement"
392
398
  else -> "unknown"
393
399
  }
394
400
  }
@@ -0,0 +1,377 @@
1
+ package com.sherpaonnx
2
+
3
+ import android.net.Uri
4
+ import android.util.Log
5
+ import com.facebook.react.bridge.Arguments
6
+ import com.facebook.react.bridge.Promise
7
+ import com.facebook.react.bridge.ReadableArray
8
+ import com.facebook.react.bridge.ReactApplicationContext
9
+ import com.facebook.react.bridge.WritableMap
10
+ import com.k2fsa.sherpa.onnx.DenoisedAudio
11
+ import com.k2fsa.sherpa.onnx.OfflineSpeechDenoiser
12
+ import com.k2fsa.sherpa.onnx.OfflineSpeechDenoiserConfig
13
+ import com.k2fsa.sherpa.onnx.OfflineSpeechDenoiserDpdfNetModelConfig
14
+ import com.k2fsa.sherpa.onnx.OfflineSpeechDenoiserGtcrnModelConfig
15
+ import com.k2fsa.sherpa.onnx.OfflineSpeechDenoiserModelConfig
16
+ import com.k2fsa.sherpa.onnx.OnlineSpeechDenoiser
17
+ import com.k2fsa.sherpa.onnx.OnlineSpeechDenoiserConfig
18
+ import com.k2fsa.sherpa.onnx.WaveReader
19
+ import java.io.File
20
+ import java.util.concurrent.ConcurrentHashMap
21
+
22
+ internal class SherpaOnnxEnhancementHelper(
23
+ private val context: ReactApplicationContext,
24
+ private val nativeDetectEnhancementModel: (modelDir: String, modelType: String) -> HashMap<String, Any>?
25
+ ) {
26
+ private data class EnhancementInstance(
27
+ @Volatile var denoiser: OfflineSpeechDenoiser? = null
28
+ ) {
29
+ fun release() {
30
+ denoiser?.release()
31
+ denoiser = null
32
+ }
33
+ }
34
+
35
+ private data class OnlineEnhancementInstance(
36
+ @Volatile var denoiser: OnlineSpeechDenoiser? = null
37
+ ) {
38
+ fun release() {
39
+ denoiser?.release()
40
+ denoiser = null
41
+ }
42
+ }
43
+
44
+ private val instances = ConcurrentHashMap<String, EnhancementInstance>()
45
+ private val onlineInstances = ConcurrentHashMap<String, OnlineEnhancementInstance>()
46
+
47
+ fun shutdown() {
48
+ instances.values.forEach { it.release() }
49
+ instances.clear()
50
+ onlineInstances.values.forEach { it.release() }
51
+ onlineInstances.clear()
52
+ }
53
+
54
+ private fun path(map: Map<String, String>, key: String): String = map[key].orEmpty()
55
+
56
+ private fun toEnhancedAudioMap(audio: DenoisedAudio): WritableMap {
57
+ val samples = Arguments.createArray()
58
+ for (sample in audio.samples) {
59
+ samples.pushDouble(sample.toDouble())
60
+ }
61
+ val out = Arguments.createMap()
62
+ out.putArray("samples", samples)
63
+ out.putInt("sampleRate", audio.sampleRate)
64
+ return out
65
+ }
66
+
67
+ private fun readableArrayToFloatArray(samples: ReadableArray): FloatArray {
68
+ val out = FloatArray(samples.size())
69
+ for (i in 0 until samples.size()) {
70
+ out[i] = samples.getDouble(i).toFloat()
71
+ }
72
+ return out
73
+ }
74
+
75
+ private fun copyContentUriToTemp(path: String, prefix: String): Pair<String, File?> {
76
+ if (!path.startsWith("content://")) return Pair(path, null)
77
+ val uri = Uri.parse(path)
78
+ val tmp = File(context.cacheDir, "${prefix}_${System.nanoTime()}.wav")
79
+ context.contentResolver.openInputStream(uri)?.use { input ->
80
+ tmp.outputStream().use { output -> input.copyTo(output) }
81
+ } ?: throw IllegalStateException("File is not readable: $path")
82
+ return Pair(tmp.absolutePath, tmp)
83
+ }
84
+
85
+ fun detectEnhancementModel(
86
+ modelDir: String,
87
+ modelType: String?,
88
+ promise: Promise
89
+ ) {
90
+ try {
91
+ val result = nativeDetectEnhancementModel(modelDir, modelType ?: "auto")
92
+ if (result == null) {
93
+ promise.reject("DETECT_ERROR", "Enhancement model detection returned null")
94
+ return
95
+ }
96
+ val success = result["success"] as? Boolean ?: false
97
+ val detectedModels = result["detectedModels"] as? ArrayList<*>
98
+ ?: arrayListOf<HashMap<String, String>>()
99
+ val modelTypeStr = result["modelType"] as? String
100
+
101
+ val resultMap = Arguments.createMap()
102
+ resultMap.putBoolean("success", success)
103
+ val modelsArray = Arguments.createArray()
104
+ for (model in detectedModels) {
105
+ val modelMap = model as? HashMap<*, *>
106
+ if (modelMap != null) {
107
+ val entry = Arguments.createMap()
108
+ entry.putString("type", modelMap["type"] as? String ?: "")
109
+ entry.putString("modelDir", modelMap["modelDir"] as? String ?: "")
110
+ modelsArray.pushMap(entry)
111
+ }
112
+ }
113
+ resultMap.putArray("detectedModels", modelsArray)
114
+ if (modelTypeStr != null) {
115
+ resultMap.putString("modelType", modelTypeStr)
116
+ }
117
+ if (!success) {
118
+ val error = result["error"] as? String
119
+ if (!error.isNullOrBlank()) {
120
+ resultMap.putString("error", error)
121
+ }
122
+ }
123
+ promise.resolve(resultMap)
124
+ } catch (e: Exception) {
125
+ Log.e("SherpaOnnxEnhancement", "Enhancement detection failed", e)
126
+ promise.reject("DETECT_ERROR", "Enhancement model detection failed: ${e.message}", e)
127
+ }
128
+ }
129
+
130
+ fun initializeEnhancement(
131
+ instanceId: String,
132
+ modelDir: String,
133
+ modelType: String?,
134
+ numThreads: Double?,
135
+ provider: String?,
136
+ debug: Boolean?,
137
+ promise: Promise
138
+ ) {
139
+ try {
140
+ val result = nativeDetectEnhancementModel(modelDir, modelType ?: "auto")
141
+ if (result == null || result["success"] as? Boolean != true) {
142
+ val reason = result?.get("error") as? String ?: "Failed to detect enhancement model"
143
+ promise.reject("ENHANCEMENT_INIT_ERROR", reason)
144
+ return
145
+ }
146
+ val modelTypeStr = result["modelType"] as? String ?: "gtcrn"
147
+ val paths = (result["paths"] as? Map<*, *>)
148
+ ?.mapValues { (_, v) -> (v as? String).orEmpty() }
149
+ ?.mapKeys { it.key.toString() }
150
+ ?: emptyMap()
151
+
152
+ val offlineModelConfig = when (modelTypeStr) {
153
+ "gtcrn" -> OfflineSpeechDenoiserModelConfig(
154
+ gtcrn = OfflineSpeechDenoiserGtcrnModelConfig(model = path(paths, "model")),
155
+ numThreads = numThreads?.toInt() ?: 1,
156
+ provider = provider ?: "cpu",
157
+ debug = debug ?: false
158
+ )
159
+ "dpdfnet" -> OfflineSpeechDenoiserModelConfig(
160
+ dpdfnet = OfflineSpeechDenoiserDpdfNetModelConfig(model = path(paths, "model")),
161
+ numThreads = numThreads?.toInt() ?: 1,
162
+ provider = provider ?: "cpu",
163
+ debug = debug ?: false
164
+ )
165
+ else -> {
166
+ promise.reject("ENHANCEMENT_INIT_ERROR", "Unsupported enhancement model type: $modelTypeStr")
167
+ return
168
+ }
169
+ }
170
+
171
+ val inst = instances.getOrPut(instanceId) { EnhancementInstance() }
172
+ inst.release()
173
+ val denoiser = OfflineSpeechDenoiser(
174
+ config = OfflineSpeechDenoiserConfig(model = offlineModelConfig)
175
+ )
176
+ inst.denoiser = denoiser
177
+
178
+ val modelsArray = Arguments.createArray()
179
+ val detectedModels = result["detectedModels"] as? ArrayList<*>
180
+ detectedModels?.forEach { modelObj ->
181
+ if (modelObj is HashMap<*, *>) {
182
+ val modelMap = Arguments.createMap()
183
+ modelMap.putString("type", modelObj["type"] as? String ?: "")
184
+ modelMap.putString("modelDir", modelObj["modelDir"] as? String ?: "")
185
+ modelsArray.pushMap(modelMap)
186
+ }
187
+ }
188
+
189
+ val out = Arguments.createMap()
190
+ out.putBoolean("success", true)
191
+ out.putArray("detectedModels", modelsArray)
192
+ out.putString("modelType", modelTypeStr)
193
+ out.putInt("sampleRate", denoiser.sampleRate)
194
+ promise.resolve(out)
195
+ } catch (e: Exception) {
196
+ Log.e("SherpaOnnxEnhancement", "Failed to initialize enhancement", e)
197
+ promise.reject("ENHANCEMENT_INIT_ERROR", "Failed to initialize enhancement: ${e.message}", e)
198
+ }
199
+ }
200
+
201
+ fun enhanceSamples(
202
+ instanceId: String,
203
+ samples: ReadableArray,
204
+ sampleRate: Double,
205
+ promise: Promise
206
+ ) {
207
+ val inst = instances[instanceId]
208
+ val denoiser = inst?.denoiser
209
+ if (denoiser == null) {
210
+ promise.reject("ENHANCEMENT_ERROR", "Enhancement instance not found: $instanceId")
211
+ return
212
+ }
213
+ try {
214
+ val audio = denoiser.run(readableArrayToFloatArray(samples), sampleRate.toInt())
215
+ promise.resolve(toEnhancedAudioMap(audio))
216
+ } catch (e: Exception) {
217
+ promise.reject("ENHANCEMENT_ERROR", "Failed to enhance samples: ${e.message}", e)
218
+ }
219
+ }
220
+
221
+ fun enhanceFile(
222
+ instanceId: String,
223
+ inputPath: String,
224
+ outputPath: String?,
225
+ promise: Promise
226
+ ) {
227
+ val inst = instances[instanceId]
228
+ val denoiser = inst?.denoiser
229
+ if (denoiser == null) {
230
+ promise.reject("ENHANCEMENT_ERROR", "Enhancement instance not found: $instanceId")
231
+ return
232
+ }
233
+
234
+ var tmpInput: File? = null
235
+ try {
236
+ val (resolvedInputPath, tmp) = copyContentUriToTemp(inputPath, "enhancement_in")
237
+ tmpInput = tmp
238
+ val wave = WaveReader.readWave(resolvedInputPath)
239
+ val audio = denoiser.run(wave.samples, wave.sampleRate)
240
+ if (!outputPath.isNullOrBlank()) {
241
+ audio.save(outputPath)
242
+ }
243
+ promise.resolve(toEnhancedAudioMap(audio))
244
+ } catch (e: Exception) {
245
+ promise.reject("ENHANCEMENT_ERROR", "Failed to enhance file: ${e.message}", e)
246
+ } finally {
247
+ tmpInput?.delete()
248
+ }
249
+ }
250
+
251
+ fun getSampleRate(instanceId: String, promise: Promise) {
252
+ val inst = instances[instanceId]
253
+ val denoiser = inst?.denoiser
254
+ if (denoiser == null) {
255
+ promise.reject("ENHANCEMENT_ERROR", "Enhancement instance not found: $instanceId")
256
+ return
257
+ }
258
+ promise.resolve(denoiser.sampleRate)
259
+ }
260
+
261
+ fun unloadEnhancement(instanceId: String, promise: Promise) {
262
+ instances.remove(instanceId)?.release()
263
+ promise.resolve(null)
264
+ }
265
+
266
+ fun initializeOnlineEnhancement(
267
+ instanceId: String,
268
+ modelDir: String,
269
+ modelType: String?,
270
+ numThreads: Double?,
271
+ provider: String?,
272
+ debug: Boolean?,
273
+ promise: Promise
274
+ ) {
275
+ try {
276
+ val result = nativeDetectEnhancementModel(modelDir, modelType ?: "auto")
277
+ if (result == null || result["success"] as? Boolean != true) {
278
+ val reason = result?.get("error") as? String ?: "Failed to detect enhancement model"
279
+ promise.reject("ONLINE_ENHANCEMENT_INIT_ERROR", reason)
280
+ return
281
+ }
282
+ val modelTypeStr = result["modelType"] as? String ?: "gtcrn"
283
+ val paths = (result["paths"] as? Map<*, *>)
284
+ ?.mapValues { (_, v) -> (v as? String).orEmpty() }
285
+ ?.mapKeys { it.key.toString() }
286
+ ?: emptyMap()
287
+
288
+ val offlineModelConfig = when (modelTypeStr) {
289
+ "gtcrn" -> OfflineSpeechDenoiserModelConfig(
290
+ gtcrn = OfflineSpeechDenoiserGtcrnModelConfig(model = path(paths, "model")),
291
+ numThreads = numThreads?.toInt() ?: 1,
292
+ provider = provider ?: "cpu",
293
+ debug = debug ?: false
294
+ )
295
+ "dpdfnet" -> OfflineSpeechDenoiserModelConfig(
296
+ dpdfnet = OfflineSpeechDenoiserDpdfNetModelConfig(model = path(paths, "model")),
297
+ numThreads = numThreads?.toInt() ?: 1,
298
+ provider = provider ?: "cpu",
299
+ debug = debug ?: false
300
+ )
301
+ else -> {
302
+ promise.reject("ONLINE_ENHANCEMENT_INIT_ERROR", "Unsupported enhancement model type: $modelTypeStr")
303
+ return
304
+ }
305
+ }
306
+
307
+ val inst = onlineInstances.getOrPut(instanceId) { OnlineEnhancementInstance() }
308
+ inst.release()
309
+ val denoiser = OnlineSpeechDenoiser(
310
+ config = OnlineSpeechDenoiserConfig(model = offlineModelConfig)
311
+ )
312
+ inst.denoiser = denoiser
313
+
314
+ val out = Arguments.createMap()
315
+ out.putBoolean("success", true)
316
+ out.putInt("sampleRate", denoiser.sampleRate)
317
+ out.putInt("frameShiftInSamples", denoiser.frameShiftInSamples)
318
+ promise.resolve(out)
319
+ } catch (e: Exception) {
320
+ promise.reject("ONLINE_ENHANCEMENT_INIT_ERROR", "Failed to initialize online enhancement: ${e.message}", e)
321
+ }
322
+ }
323
+
324
+ fun feedSamples(
325
+ instanceId: String,
326
+ samples: ReadableArray,
327
+ sampleRate: Double,
328
+ promise: Promise
329
+ ) {
330
+ val inst = onlineInstances[instanceId]
331
+ val denoiser = inst?.denoiser
332
+ if (denoiser == null) {
333
+ promise.reject("ONLINE_ENHANCEMENT_ERROR", "Online enhancement instance not found: $instanceId")
334
+ return
335
+ }
336
+ try {
337
+ val audio = denoiser.run(readableArrayToFloatArray(samples), sampleRate.toInt())
338
+ promise.resolve(toEnhancedAudioMap(audio))
339
+ } catch (e: Exception) {
340
+ promise.reject("ONLINE_ENHANCEMENT_ERROR", "Failed to feed enhancement samples: ${e.message}", e)
341
+ }
342
+ }
343
+
344
+ fun flushOnline(instanceId: String, promise: Promise) {
345
+ val inst = onlineInstances[instanceId]
346
+ val denoiser = inst?.denoiser
347
+ if (denoiser == null) {
348
+ promise.reject("ONLINE_ENHANCEMENT_ERROR", "Online enhancement instance not found: $instanceId")
349
+ return
350
+ }
351
+ try {
352
+ promise.resolve(toEnhancedAudioMap(denoiser.flush()))
353
+ } catch (e: Exception) {
354
+ promise.reject("ONLINE_ENHANCEMENT_ERROR", "Failed to flush online enhancement: ${e.message}", e)
355
+ }
356
+ }
357
+
358
+ fun resetOnline(instanceId: String, promise: Promise) {
359
+ val inst = onlineInstances[instanceId]
360
+ val denoiser = inst?.denoiser
361
+ if (denoiser == null) {
362
+ promise.reject("ONLINE_ENHANCEMENT_ERROR", "Online enhancement instance not found: $instanceId")
363
+ return
364
+ }
365
+ try {
366
+ denoiser.reset()
367
+ promise.resolve(null)
368
+ } catch (e: Exception) {
369
+ promise.reject("ONLINE_ENHANCEMENT_ERROR", "Failed to reset online enhancement: ${e.message}", e)
370
+ }
371
+ }
372
+
373
+ fun unloadOnline(instanceId: String, promise: Promise) {
374
+ onlineInstances.remove(instanceId)?.release()
375
+ promise.resolve(null)
376
+ }
377
+ }