react-native-sherpa-onnx 0.3.8 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +20 -5
- package/SherpaOnnx.podspec +5 -1
- package/android/prebuilt-download.gradle +89 -49
- package/android/prebuilt-versions.gradle +1 -1
- package/android/src/main/assets/model_licenses/asr-models-license-status.csv +1 -0
- package/android/src/main/assets/model_licenses/speech-enhancement-models-license-status.csv +7 -0
- package/android/src/main/cpp/CMakeLists.txt +3 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-enhancement-wrapper.cpp +68 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-enhancement-wrapper.h +17 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-enhancement.cpp +119 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-helper.cpp +23 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-helper.h +9 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-stt.cpp +51 -8
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect.h +41 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-stt-wrapper.cpp +5 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-validate-enhancement.cpp +68 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-validate-enhancement.h +30 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-validate-stt.cpp +11 -0
- package/android/src/main/cpp/jni/module/sherpa-onnx-module-jni.cpp +21 -0
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxArchiveHelper.kt +110 -35
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxAssetHelper.kt +6 -0
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxEnhancementHelper.kt +377 -0
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxExtractionNotificationHelper.kt +102 -0
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxModule.kt +198 -18
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxSttHelper.kt +22 -0
- package/ios/Resources/model_licenses/asr-models-license-status.csv +1 -0
- package/ios/Resources/model_licenses/speech-enhancement-models-license-status.csv +7 -0
- package/ios/SherpaOnnx+Assets.mm +5 -0
- package/ios/SherpaOnnx+Enhancement.mm +435 -0
- package/ios/SherpaOnnx+STT.mm +13 -1
- package/ios/SherpaOnnx.mm +87 -17
- package/ios/enhancement/sherpa-onnx-enhancement-wrapper.h +85 -0
- package/ios/enhancement/sherpa-onnx-enhancement-wrapper.mm +218 -0
- package/ios/model_detect/sherpa-onnx-model-detect-enhancement.mm +92 -0
- package/ios/model_detect/sherpa-onnx-model-detect-helper.h +5 -0
- package/ios/model_detect/sherpa-onnx-model-detect-helper.mm +23 -0
- package/ios/model_detect/sherpa-onnx-model-detect-stt.mm +51 -7
- package/ios/model_detect/sherpa-onnx-model-detect.h +33 -0
- package/ios/model_detect/sherpa-onnx-validate-enhancement.h +30 -0
- package/ios/model_detect/sherpa-onnx-validate-enhancement.mm +69 -0
- package/ios/model_detect/sherpa-onnx-validate-stt.mm +11 -0
- package/ios/stt/sherpa-onnx-stt-wrapper.h +11 -1
- package/ios/stt/sherpa-onnx-stt-wrapper.mm +30 -2
- package/ios/tts/sherpa-onnx-tts-wrapper.mm +16 -0
- package/lib/module/NativeSherpaOnnx.js.map +1 -1
- package/lib/module/download/localModels.js +2 -3
- package/lib/module/download/localModels.js.map +1 -1
- package/lib/module/download/paths.js +2 -1
- package/lib/module/download/paths.js.map +1 -1
- package/lib/module/download/postDownloadProcessing.js +17 -4
- package/lib/module/download/postDownloadProcessing.js.map +1 -1
- package/lib/module/enhancement/index.js +63 -48
- package/lib/module/enhancement/index.js.map +1 -1
- package/lib/module/enhancement/streaming.js +60 -0
- package/lib/module/enhancement/streaming.js.map +1 -0
- package/lib/module/enhancement/streamingTypes.js +4 -0
- package/lib/module/enhancement/streamingTypes.js.map +1 -0
- package/lib/module/enhancement/types.js +4 -0
- package/lib/module/enhancement/types.js.map +1 -0
- package/lib/module/extraction/extractTarBz2.js +2 -2
- package/lib/module/extraction/extractTarBz2.js.map +1 -1
- package/lib/module/extraction/extractTarZst.js +2 -2
- package/lib/module/extraction/extractTarZst.js.map +1 -1
- package/lib/module/extraction/index.js +10 -5
- package/lib/module/extraction/index.js.map +1 -1
- package/lib/module/licenses.js +9 -3
- package/lib/module/licenses.js.map +1 -1
- package/lib/module/stt/index.js +4 -2
- package/lib/module/stt/index.js.map +1 -1
- package/lib/module/stt/streaming.js +2 -1
- package/lib/module/stt/streaming.js.map +1 -1
- package/lib/module/stt/types.js +3 -1
- package/lib/module/stt/types.js.map +1 -1
- package/lib/module/tts/index.js +4 -2
- package/lib/module/tts/index.js.map +1 -1
- package/lib/module/tts/streaming.js +3 -1
- package/lib/module/tts/streaming.js.map +1 -1
- package/lib/typescript/src/NativeSherpaOnnx.d.ts +70 -9
- package/lib/typescript/src/NativeSherpaOnnx.d.ts.map +1 -1
- package/lib/typescript/src/download/localModels.d.ts.map +1 -1
- package/lib/typescript/src/download/paths.d.ts +2 -1
- package/lib/typescript/src/download/paths.d.ts.map +1 -1
- package/lib/typescript/src/download/postDownloadProcessing.d.ts +9 -0
- package/lib/typescript/src/download/postDownloadProcessing.d.ts.map +1 -1
- package/lib/typescript/src/enhancement/index.d.ts +9 -46
- package/lib/typescript/src/enhancement/index.d.ts.map +1 -1
- package/lib/typescript/src/enhancement/streaming.d.ts +6 -0
- package/lib/typescript/src/enhancement/streaming.d.ts.map +1 -0
- package/lib/typescript/src/enhancement/streamingTypes.d.ts +12 -0
- package/lib/typescript/src/enhancement/streamingTypes.d.ts.map +1 -0
- package/lib/typescript/src/enhancement/types.d.ts +31 -0
- package/lib/typescript/src/enhancement/types.d.ts.map +1 -0
- package/lib/typescript/src/extraction/extractTarBz2.d.ts +2 -1
- package/lib/typescript/src/extraction/extractTarBz2.d.ts.map +1 -1
- package/lib/typescript/src/extraction/extractTarZst.d.ts +2 -1
- package/lib/typescript/src/extraction/extractTarZst.d.ts.map +1 -1
- package/lib/typescript/src/extraction/index.d.ts +1 -1
- package/lib/typescript/src/extraction/index.d.ts.map +1 -1
- package/lib/typescript/src/extraction/types.d.ts +12 -0
- package/lib/typescript/src/extraction/types.d.ts.map +1 -1
- package/lib/typescript/src/licenses.d.ts.map +1 -1
- package/lib/typescript/src/stt/index.d.ts +1 -1
- package/lib/typescript/src/stt/index.d.ts.map +1 -1
- package/lib/typescript/src/stt/streaming.d.ts.map +1 -1
- package/lib/typescript/src/stt/types.d.ts +16 -1
- package/lib/typescript/src/stt/types.d.ts.map +1 -1
- package/lib/typescript/src/tts/index.d.ts.map +1 -1
- package/lib/typescript/src/tts/streaming.d.ts.map +1 -1
- package/package.json +1 -1
- package/scripts/ci/check-model-csvs.sh +27 -2
- package/scripts/ci/collect_all_sherpa_model_streams.sh +3 -1
- package/scripts/ci/collect_one_sherpa_release_stream.sh +3 -1
- package/scripts/ci/sherpa_speech_enhancement_model_release_streams.json +13 -0
- package/scripts/ci/update_model_license_csv.sh +17 -17
- package/src/NativeSherpaOnnx.ts +108 -10
- package/src/download/localModels.ts +1 -3
- package/src/download/paths.ts +2 -1
- package/src/download/postDownloadProcessing.ts +24 -1
- package/src/enhancement/index.ts +120 -58
- package/src/enhancement/streaming.ts +105 -0
- package/src/enhancement/streamingTypes.ts +14 -0
- package/src/enhancement/types.ts +36 -0
- package/src/extraction/extractTarBz2.ts +7 -2
- package/src/extraction/extractTarZst.ts +7 -2
- package/src/extraction/index.ts +29 -6
- package/src/extraction/types.ts +16 -0
- package/src/licenses.ts +13 -2
- package/src/stt/index.ts +8 -7
- package/src/stt/streaming.ts +7 -1
- package/src/stt/types.ts +18 -0
- package/src/tts/index.ts +7 -7
- package/src/tts/streaming.ts +6 -3
- package/third_party/sherpa-onnx-prebuilt/ANDROID_RELEASE_TAG +1 -1
- package/third_party/sherpa-onnx-prebuilt/IOS_RELEASE_TAG +1 -1
|
@@ -54,7 +54,8 @@ class SherpaOnnxArchiveHelper {
|
|
|
54
54
|
targetPath: String,
|
|
55
55
|
force: Boolean,
|
|
56
56
|
promise: Promise,
|
|
57
|
-
onProgress: (bytes: Long, totalBytes: Long, percent: Double) -> Unit
|
|
57
|
+
onProgress: (bytes: Long, totalBytes: Long, percent: Double) -> Unit,
|
|
58
|
+
extractionNotification: SherpaOnnxExtractionNotificationHelper? = null,
|
|
58
59
|
) {
|
|
59
60
|
val promiseSettled = AtomicBoolean(false)
|
|
60
61
|
fun resolveOnce(success: Boolean, reason: String? = null) {
|
|
@@ -70,26 +71,28 @@ class SherpaOnnxArchiveHelper {
|
|
|
70
71
|
val cancelFlag = AtomicBoolean(false)
|
|
71
72
|
cancelFlags[sourcePath] = cancelFlag
|
|
72
73
|
|
|
73
|
-
// Create a progress callback object that JNI can call
|
|
74
|
-
val progressCallback = object : Any() {
|
|
75
|
-
fun invoke(bytesExtracted: Long, totalBytes: Long, percent: Double) {
|
|
76
|
-
onProgress(bytesExtracted, totalBytes, percent)
|
|
77
|
-
}
|
|
78
|
-
}
|
|
79
|
-
|
|
80
74
|
// Run extraction on a background thread so the React Native bridge thread is not blocked.
|
|
81
75
|
// The thread pool allows multiple extractions in parallel.
|
|
82
76
|
extractExecutor.execute {
|
|
77
|
+
val notif = extractionNotification
|
|
83
78
|
try {
|
|
84
79
|
// Check per-path cancel flag before starting the native extraction.
|
|
85
80
|
if (cancelFlag.get()) {
|
|
86
81
|
resolveOnce(false, "Cancelled")
|
|
87
82
|
return@execute
|
|
88
83
|
}
|
|
89
|
-
|
|
84
|
+
notif?.start()
|
|
85
|
+
val wrappedCallback = object : Any() {
|
|
86
|
+
fun invoke(bytesExtracted: Long, totalBytes: Long, percent: Double) {
|
|
87
|
+
onProgress(bytesExtracted, totalBytes, percent)
|
|
88
|
+
notif?.updateProgress(percent)
|
|
89
|
+
}
|
|
90
|
+
}
|
|
91
|
+
nativeExtractTarBz2(sourcePath, targetPath, force, wrappedCallback, promise)
|
|
90
92
|
} catch (e: Exception) {
|
|
91
93
|
resolveOnce(false, "Archive extraction error: ${e.message}")
|
|
92
94
|
} finally {
|
|
95
|
+
notif?.finish()
|
|
93
96
|
cancelFlags.remove(sourcePath)
|
|
94
97
|
}
|
|
95
98
|
}
|
|
@@ -104,7 +107,8 @@ class SherpaOnnxArchiveHelper {
|
|
|
104
107
|
targetPath: String,
|
|
105
108
|
force: Boolean,
|
|
106
109
|
promise: Promise,
|
|
107
|
-
onProgress: (bytes: Long, totalBytes: Long, percent: Double) -> Unit
|
|
110
|
+
onProgress: (bytes: Long, totalBytes: Long, percent: Double) -> Unit,
|
|
111
|
+
extractionNotification: SherpaOnnxExtractionNotificationHelper? = null,
|
|
108
112
|
) {
|
|
109
113
|
val promiseSettled = AtomicBoolean(false)
|
|
110
114
|
fun resolveOnce(success: Boolean, reason: String? = null) {
|
|
@@ -119,22 +123,26 @@ class SherpaOnnxArchiveHelper {
|
|
|
119
123
|
val cancelFlag = AtomicBoolean(false)
|
|
120
124
|
cancelFlags[sourcePath] = cancelFlag
|
|
121
125
|
|
|
122
|
-
val progressCallback = object : Any() {
|
|
123
|
-
fun invoke(bytesExtracted: Long, totalBytes: Long, percent: Double) {
|
|
124
|
-
onProgress(bytesExtracted, totalBytes, percent)
|
|
125
|
-
}
|
|
126
|
-
}
|
|
127
126
|
extractExecutor.execute {
|
|
127
|
+
val notif = extractionNotification
|
|
128
128
|
try {
|
|
129
129
|
// Check per-path cancel flag before starting the native extraction.
|
|
130
130
|
if (cancelFlag.get()) {
|
|
131
131
|
resolveOnce(false, "Cancelled")
|
|
132
132
|
return@execute
|
|
133
133
|
}
|
|
134
|
-
|
|
134
|
+
notif?.start()
|
|
135
|
+
val wrappedCallback = object : Any() {
|
|
136
|
+
fun invoke(bytesExtracted: Long, totalBytes: Long, percent: Double) {
|
|
137
|
+
onProgress(bytesExtracted, totalBytes, percent)
|
|
138
|
+
notif?.updateProgress(percent)
|
|
139
|
+
}
|
|
140
|
+
}
|
|
141
|
+
nativeExtractTarZst(sourcePath, targetPath, force, wrappedCallback, promise)
|
|
135
142
|
} catch (e: Exception) {
|
|
136
143
|
resolveOnce(false, "Archive extraction error: ${e.message}")
|
|
137
144
|
} finally {
|
|
145
|
+
notif?.finish()
|
|
138
146
|
cancelFlags.remove(sourcePath)
|
|
139
147
|
}
|
|
140
148
|
}
|
|
@@ -144,47 +152,106 @@ class SherpaOnnxArchiveHelper {
|
|
|
144
152
|
}
|
|
145
153
|
}
|
|
146
154
|
|
|
155
|
+
/**
|
|
156
|
+
* Which JNI stream entry to use for APK asset extraction.
|
|
157
|
+
*
|
|
158
|
+
* Both paths invoke libarchive’s `ExtractFromStream`, which **auto-detects** compression
|
|
159
|
+
* (`.tar.zst` vs `.tar.bz2`, etc.); `nativeExtractTarBz2FromStream` forwards to the same
|
|
160
|
+
* native implementation as zst. Keeping distinct JNI symbols preserves a clear API and avoids
|
|
161
|
+
* the impression that bz2 assets are mistakenly wired only to a “zst” method.
|
|
162
|
+
*/
|
|
163
|
+
private enum class AssetTarStreamKind {
|
|
164
|
+
ZST,
|
|
165
|
+
BZ2,
|
|
166
|
+
}
|
|
167
|
+
|
|
147
168
|
fun extractTarZstFromAsset(
|
|
148
169
|
context: Context,
|
|
149
170
|
assetPath: String,
|
|
150
171
|
targetPath: String,
|
|
151
172
|
force: Boolean,
|
|
152
173
|
promise: Promise,
|
|
153
|
-
onProgress: (bytes: Long, totalBytes: Long, percent: Double) -> Unit
|
|
174
|
+
onProgress: (bytes: Long, totalBytes: Long, percent: Double) -> Unit,
|
|
175
|
+
extractionNotification: SherpaOnnxExtractionNotificationHelper? = null,
|
|
176
|
+
) {
|
|
177
|
+
extractTarArchiveFromAsset(
|
|
178
|
+
context,
|
|
179
|
+
assetPath,
|
|
180
|
+
targetPath,
|
|
181
|
+
force,
|
|
182
|
+
promise,
|
|
183
|
+
onProgress,
|
|
184
|
+
extractionNotification,
|
|
185
|
+
AssetTarStreamKind.ZST,
|
|
186
|
+
)
|
|
187
|
+
}
|
|
188
|
+
|
|
189
|
+
fun extractTarBz2FromAsset(
|
|
190
|
+
context: Context,
|
|
191
|
+
assetPath: String,
|
|
192
|
+
targetPath: String,
|
|
193
|
+
force: Boolean,
|
|
194
|
+
promise: Promise,
|
|
195
|
+
onProgress: (bytes: Long, totalBytes: Long, percent: Double) -> Unit,
|
|
196
|
+
extractionNotification: SherpaOnnxExtractionNotificationHelper? = null,
|
|
197
|
+
) {
|
|
198
|
+
extractTarArchiveFromAsset(
|
|
199
|
+
context,
|
|
200
|
+
assetPath,
|
|
201
|
+
targetPath,
|
|
202
|
+
force,
|
|
203
|
+
promise,
|
|
204
|
+
onProgress,
|
|
205
|
+
extractionNotification,
|
|
206
|
+
AssetTarStreamKind.BZ2,
|
|
207
|
+
)
|
|
208
|
+
}
|
|
209
|
+
|
|
210
|
+
private fun extractTarArchiveFromAsset(
|
|
211
|
+
context: Context,
|
|
212
|
+
assetPath: String,
|
|
213
|
+
targetPath: String,
|
|
214
|
+
force: Boolean,
|
|
215
|
+
promise: Promise,
|
|
216
|
+
onProgress: (bytes: Long, totalBytes: Long, percent: Double) -> Unit,
|
|
217
|
+
extractionNotification: SherpaOnnxExtractionNotificationHelper? = null,
|
|
218
|
+
kind: AssetTarStreamKind,
|
|
154
219
|
) {
|
|
155
220
|
if (BuildConfig.DEBUG) {
|
|
156
|
-
Log.i(
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
onProgress(bytesExtracted, totalBytes, percent)
|
|
161
|
-
}
|
|
221
|
+
Log.i(
|
|
222
|
+
"SherpaOnnx",
|
|
223
|
+
"extractTar${if (kind == AssetTarStreamKind.ZST) "Zst" else "Bz2"}FromAsset assetPath=$assetPath targetPath=$targetPath",
|
|
224
|
+
)
|
|
162
225
|
}
|
|
163
226
|
extractExecutor.execute {
|
|
227
|
+
val notif = extractionNotification
|
|
164
228
|
try {
|
|
229
|
+
notif?.start()
|
|
230
|
+
val progressCallback = object : Any() {
|
|
231
|
+
fun invoke(bytesExtracted: Long, totalBytes: Long, percent: Double) {
|
|
232
|
+
onProgress(bytesExtracted, totalBytes, percent)
|
|
233
|
+
notif?.updateProgress(percent)
|
|
234
|
+
}
|
|
235
|
+
}
|
|
165
236
|
context.assets.open(assetPath).use { stream ->
|
|
166
|
-
|
|
237
|
+
when (kind) {
|
|
238
|
+
AssetTarStreamKind.ZST ->
|
|
239
|
+
nativeExtractTarZstFromStream(stream, targetPath, force, progressCallback, promise)
|
|
240
|
+
AssetTarStreamKind.BZ2 ->
|
|
241
|
+
nativeExtractTarBz2FromStream(stream, targetPath, force, progressCallback, promise)
|
|
242
|
+
}
|
|
167
243
|
}
|
|
168
244
|
} catch (e: Exception) {
|
|
169
245
|
val result = Arguments.createMap()
|
|
170
246
|
result.putBoolean("success", false)
|
|
171
247
|
result.putString("reason", e.message ?: "Failed to open asset")
|
|
172
248
|
promise.resolve(result)
|
|
249
|
+
} finally {
|
|
250
|
+
notif?.finish()
|
|
173
251
|
}
|
|
174
252
|
}
|
|
175
253
|
}
|
|
176
254
|
|
|
177
|
-
fun extractTarBz2FromAsset(
|
|
178
|
-
context: Context,
|
|
179
|
-
assetPath: String,
|
|
180
|
-
targetPath: String,
|
|
181
|
-
force: Boolean,
|
|
182
|
-
promise: Promise,
|
|
183
|
-
onProgress: (bytes: Long, totalBytes: Long, percent: Double) -> Unit
|
|
184
|
-
) {
|
|
185
|
-
extractTarZstFromAsset(context, assetPath, targetPath, force, promise, onProgress)
|
|
186
|
-
}
|
|
187
|
-
|
|
188
255
|
fun computeFileSha256(filePath: String, promise: Promise) {
|
|
189
256
|
nativeComputeFileSha256(filePath, promise)
|
|
190
257
|
}
|
|
@@ -214,6 +281,14 @@ class SherpaOnnxArchiveHelper {
|
|
|
214
281
|
promise: Promise
|
|
215
282
|
)
|
|
216
283
|
|
|
284
|
+
private external fun nativeExtractTarBz2FromStream(
|
|
285
|
+
inputStream: java.io.InputStream,
|
|
286
|
+
targetPath: String,
|
|
287
|
+
force: Boolean,
|
|
288
|
+
progressCallback: Any?,
|
|
289
|
+
promise: Promise
|
|
290
|
+
)
|
|
291
|
+
|
|
217
292
|
private external fun nativeCancelExtract()
|
|
218
293
|
|
|
219
294
|
private external fun nativeComputeFileSha256(
|
|
@@ -382,13 +382,19 @@ internal class SherpaOnnxAssetHelper(
|
|
|
382
382
|
"mms",
|
|
383
383
|
"tts"
|
|
384
384
|
)
|
|
385
|
+
val enhancementHints = listOf(
|
|
386
|
+
"gtcrn",
|
|
387
|
+
"dpdfnet"
|
|
388
|
+
)
|
|
385
389
|
|
|
386
390
|
val isStt = sttHints.any { name.contains(it) }
|
|
387
391
|
val isTts = ttsHints.any { name.contains(it) }
|
|
392
|
+
val isEnhancement = enhancementHints.any { name.contains(it) }
|
|
388
393
|
|
|
389
394
|
return when {
|
|
390
395
|
isStt && !isTts -> "stt"
|
|
391
396
|
isTts && !isStt -> "tts"
|
|
397
|
+
isEnhancement && !isStt && !isTts -> "enhancement"
|
|
392
398
|
else -> "unknown"
|
|
393
399
|
}
|
|
394
400
|
}
|
|
@@ -0,0 +1,377 @@
|
|
|
1
|
+
package com.sherpaonnx
|
|
2
|
+
|
|
3
|
+
import android.net.Uri
|
|
4
|
+
import android.util.Log
|
|
5
|
+
import com.facebook.react.bridge.Arguments
|
|
6
|
+
import com.facebook.react.bridge.Promise
|
|
7
|
+
import com.facebook.react.bridge.ReadableArray
|
|
8
|
+
import com.facebook.react.bridge.ReactApplicationContext
|
|
9
|
+
import com.facebook.react.bridge.WritableMap
|
|
10
|
+
import com.k2fsa.sherpa.onnx.DenoisedAudio
|
|
11
|
+
import com.k2fsa.sherpa.onnx.OfflineSpeechDenoiser
|
|
12
|
+
import com.k2fsa.sherpa.onnx.OfflineSpeechDenoiserConfig
|
|
13
|
+
import com.k2fsa.sherpa.onnx.OfflineSpeechDenoiserDpdfNetModelConfig
|
|
14
|
+
import com.k2fsa.sherpa.onnx.OfflineSpeechDenoiserGtcrnModelConfig
|
|
15
|
+
import com.k2fsa.sherpa.onnx.OfflineSpeechDenoiserModelConfig
|
|
16
|
+
import com.k2fsa.sherpa.onnx.OnlineSpeechDenoiser
|
|
17
|
+
import com.k2fsa.sherpa.onnx.OnlineSpeechDenoiserConfig
|
|
18
|
+
import com.k2fsa.sherpa.onnx.WaveReader
|
|
19
|
+
import java.io.File
|
|
20
|
+
import java.util.concurrent.ConcurrentHashMap
|
|
21
|
+
|
|
22
|
+
internal class SherpaOnnxEnhancementHelper(
|
|
23
|
+
private val context: ReactApplicationContext,
|
|
24
|
+
private val nativeDetectEnhancementModel: (modelDir: String, modelType: String) -> HashMap<String, Any>?
|
|
25
|
+
) {
|
|
26
|
+
private data class EnhancementInstance(
|
|
27
|
+
@Volatile var denoiser: OfflineSpeechDenoiser? = null
|
|
28
|
+
) {
|
|
29
|
+
fun release() {
|
|
30
|
+
denoiser?.release()
|
|
31
|
+
denoiser = null
|
|
32
|
+
}
|
|
33
|
+
}
|
|
34
|
+
|
|
35
|
+
private data class OnlineEnhancementInstance(
|
|
36
|
+
@Volatile var denoiser: OnlineSpeechDenoiser? = null
|
|
37
|
+
) {
|
|
38
|
+
fun release() {
|
|
39
|
+
denoiser?.release()
|
|
40
|
+
denoiser = null
|
|
41
|
+
}
|
|
42
|
+
}
|
|
43
|
+
|
|
44
|
+
private val instances = ConcurrentHashMap<String, EnhancementInstance>()
|
|
45
|
+
private val onlineInstances = ConcurrentHashMap<String, OnlineEnhancementInstance>()
|
|
46
|
+
|
|
47
|
+
fun shutdown() {
|
|
48
|
+
instances.values.forEach { it.release() }
|
|
49
|
+
instances.clear()
|
|
50
|
+
onlineInstances.values.forEach { it.release() }
|
|
51
|
+
onlineInstances.clear()
|
|
52
|
+
}
|
|
53
|
+
|
|
54
|
+
private fun path(map: Map<String, String>, key: String): String = map[key].orEmpty()
|
|
55
|
+
|
|
56
|
+
private fun toEnhancedAudioMap(audio: DenoisedAudio): WritableMap {
|
|
57
|
+
val samples = Arguments.createArray()
|
|
58
|
+
for (sample in audio.samples) {
|
|
59
|
+
samples.pushDouble(sample.toDouble())
|
|
60
|
+
}
|
|
61
|
+
val out = Arguments.createMap()
|
|
62
|
+
out.putArray("samples", samples)
|
|
63
|
+
out.putInt("sampleRate", audio.sampleRate)
|
|
64
|
+
return out
|
|
65
|
+
}
|
|
66
|
+
|
|
67
|
+
private fun readableArrayToFloatArray(samples: ReadableArray): FloatArray {
|
|
68
|
+
val out = FloatArray(samples.size())
|
|
69
|
+
for (i in 0 until samples.size()) {
|
|
70
|
+
out[i] = samples.getDouble(i).toFloat()
|
|
71
|
+
}
|
|
72
|
+
return out
|
|
73
|
+
}
|
|
74
|
+
|
|
75
|
+
private fun copyContentUriToTemp(path: String, prefix: String): Pair<String, File?> {
|
|
76
|
+
if (!path.startsWith("content://")) return Pair(path, null)
|
|
77
|
+
val uri = Uri.parse(path)
|
|
78
|
+
val tmp = File(context.cacheDir, "${prefix}_${System.nanoTime()}.wav")
|
|
79
|
+
context.contentResolver.openInputStream(uri)?.use { input ->
|
|
80
|
+
tmp.outputStream().use { output -> input.copyTo(output) }
|
|
81
|
+
} ?: throw IllegalStateException("File is not readable: $path")
|
|
82
|
+
return Pair(tmp.absolutePath, tmp)
|
|
83
|
+
}
|
|
84
|
+
|
|
85
|
+
fun detectEnhancementModel(
|
|
86
|
+
modelDir: String,
|
|
87
|
+
modelType: String?,
|
|
88
|
+
promise: Promise
|
|
89
|
+
) {
|
|
90
|
+
try {
|
|
91
|
+
val result = nativeDetectEnhancementModel(modelDir, modelType ?: "auto")
|
|
92
|
+
if (result == null) {
|
|
93
|
+
promise.reject("DETECT_ERROR", "Enhancement model detection returned null")
|
|
94
|
+
return
|
|
95
|
+
}
|
|
96
|
+
val success = result["success"] as? Boolean ?: false
|
|
97
|
+
val detectedModels = result["detectedModels"] as? ArrayList<*>
|
|
98
|
+
?: arrayListOf<HashMap<String, String>>()
|
|
99
|
+
val modelTypeStr = result["modelType"] as? String
|
|
100
|
+
|
|
101
|
+
val resultMap = Arguments.createMap()
|
|
102
|
+
resultMap.putBoolean("success", success)
|
|
103
|
+
val modelsArray = Arguments.createArray()
|
|
104
|
+
for (model in detectedModels) {
|
|
105
|
+
val modelMap = model as? HashMap<*, *>
|
|
106
|
+
if (modelMap != null) {
|
|
107
|
+
val entry = Arguments.createMap()
|
|
108
|
+
entry.putString("type", modelMap["type"] as? String ?: "")
|
|
109
|
+
entry.putString("modelDir", modelMap["modelDir"] as? String ?: "")
|
|
110
|
+
modelsArray.pushMap(entry)
|
|
111
|
+
}
|
|
112
|
+
}
|
|
113
|
+
resultMap.putArray("detectedModels", modelsArray)
|
|
114
|
+
if (modelTypeStr != null) {
|
|
115
|
+
resultMap.putString("modelType", modelTypeStr)
|
|
116
|
+
}
|
|
117
|
+
if (!success) {
|
|
118
|
+
val error = result["error"] as? String
|
|
119
|
+
if (!error.isNullOrBlank()) {
|
|
120
|
+
resultMap.putString("error", error)
|
|
121
|
+
}
|
|
122
|
+
}
|
|
123
|
+
promise.resolve(resultMap)
|
|
124
|
+
} catch (e: Exception) {
|
|
125
|
+
Log.e("SherpaOnnxEnhancement", "Enhancement detection failed", e)
|
|
126
|
+
promise.reject("DETECT_ERROR", "Enhancement model detection failed: ${e.message}", e)
|
|
127
|
+
}
|
|
128
|
+
}
|
|
129
|
+
|
|
130
|
+
fun initializeEnhancement(
|
|
131
|
+
instanceId: String,
|
|
132
|
+
modelDir: String,
|
|
133
|
+
modelType: String?,
|
|
134
|
+
numThreads: Double?,
|
|
135
|
+
provider: String?,
|
|
136
|
+
debug: Boolean?,
|
|
137
|
+
promise: Promise
|
|
138
|
+
) {
|
|
139
|
+
try {
|
|
140
|
+
val result = nativeDetectEnhancementModel(modelDir, modelType ?: "auto")
|
|
141
|
+
if (result == null || result["success"] as? Boolean != true) {
|
|
142
|
+
val reason = result?.get("error") as? String ?: "Failed to detect enhancement model"
|
|
143
|
+
promise.reject("ENHANCEMENT_INIT_ERROR", reason)
|
|
144
|
+
return
|
|
145
|
+
}
|
|
146
|
+
val modelTypeStr = result["modelType"] as? String ?: "gtcrn"
|
|
147
|
+
val paths = (result["paths"] as? Map<*, *>)
|
|
148
|
+
?.mapValues { (_, v) -> (v as? String).orEmpty() }
|
|
149
|
+
?.mapKeys { it.key.toString() }
|
|
150
|
+
?: emptyMap()
|
|
151
|
+
|
|
152
|
+
val offlineModelConfig = when (modelTypeStr) {
|
|
153
|
+
"gtcrn" -> OfflineSpeechDenoiserModelConfig(
|
|
154
|
+
gtcrn = OfflineSpeechDenoiserGtcrnModelConfig(model = path(paths, "model")),
|
|
155
|
+
numThreads = numThreads?.toInt() ?: 1,
|
|
156
|
+
provider = provider ?: "cpu",
|
|
157
|
+
debug = debug ?: false
|
|
158
|
+
)
|
|
159
|
+
"dpdfnet" -> OfflineSpeechDenoiserModelConfig(
|
|
160
|
+
dpdfnet = OfflineSpeechDenoiserDpdfNetModelConfig(model = path(paths, "model")),
|
|
161
|
+
numThreads = numThreads?.toInt() ?: 1,
|
|
162
|
+
provider = provider ?: "cpu",
|
|
163
|
+
debug = debug ?: false
|
|
164
|
+
)
|
|
165
|
+
else -> {
|
|
166
|
+
promise.reject("ENHANCEMENT_INIT_ERROR", "Unsupported enhancement model type: $modelTypeStr")
|
|
167
|
+
return
|
|
168
|
+
}
|
|
169
|
+
}
|
|
170
|
+
|
|
171
|
+
val inst = instances.getOrPut(instanceId) { EnhancementInstance() }
|
|
172
|
+
inst.release()
|
|
173
|
+
val denoiser = OfflineSpeechDenoiser(
|
|
174
|
+
config = OfflineSpeechDenoiserConfig(model = offlineModelConfig)
|
|
175
|
+
)
|
|
176
|
+
inst.denoiser = denoiser
|
|
177
|
+
|
|
178
|
+
val modelsArray = Arguments.createArray()
|
|
179
|
+
val detectedModels = result["detectedModels"] as? ArrayList<*>
|
|
180
|
+
detectedModels?.forEach { modelObj ->
|
|
181
|
+
if (modelObj is HashMap<*, *>) {
|
|
182
|
+
val modelMap = Arguments.createMap()
|
|
183
|
+
modelMap.putString("type", modelObj["type"] as? String ?: "")
|
|
184
|
+
modelMap.putString("modelDir", modelObj["modelDir"] as? String ?: "")
|
|
185
|
+
modelsArray.pushMap(modelMap)
|
|
186
|
+
}
|
|
187
|
+
}
|
|
188
|
+
|
|
189
|
+
val out = Arguments.createMap()
|
|
190
|
+
out.putBoolean("success", true)
|
|
191
|
+
out.putArray("detectedModels", modelsArray)
|
|
192
|
+
out.putString("modelType", modelTypeStr)
|
|
193
|
+
out.putInt("sampleRate", denoiser.sampleRate)
|
|
194
|
+
promise.resolve(out)
|
|
195
|
+
} catch (e: Exception) {
|
|
196
|
+
Log.e("SherpaOnnxEnhancement", "Failed to initialize enhancement", e)
|
|
197
|
+
promise.reject("ENHANCEMENT_INIT_ERROR", "Failed to initialize enhancement: ${e.message}", e)
|
|
198
|
+
}
|
|
199
|
+
}
|
|
200
|
+
|
|
201
|
+
fun enhanceSamples(
|
|
202
|
+
instanceId: String,
|
|
203
|
+
samples: ReadableArray,
|
|
204
|
+
sampleRate: Double,
|
|
205
|
+
promise: Promise
|
|
206
|
+
) {
|
|
207
|
+
val inst = instances[instanceId]
|
|
208
|
+
val denoiser = inst?.denoiser
|
|
209
|
+
if (denoiser == null) {
|
|
210
|
+
promise.reject("ENHANCEMENT_ERROR", "Enhancement instance not found: $instanceId")
|
|
211
|
+
return
|
|
212
|
+
}
|
|
213
|
+
try {
|
|
214
|
+
val audio = denoiser.run(readableArrayToFloatArray(samples), sampleRate.toInt())
|
|
215
|
+
promise.resolve(toEnhancedAudioMap(audio))
|
|
216
|
+
} catch (e: Exception) {
|
|
217
|
+
promise.reject("ENHANCEMENT_ERROR", "Failed to enhance samples: ${e.message}", e)
|
|
218
|
+
}
|
|
219
|
+
}
|
|
220
|
+
|
|
221
|
+
fun enhanceFile(
|
|
222
|
+
instanceId: String,
|
|
223
|
+
inputPath: String,
|
|
224
|
+
outputPath: String?,
|
|
225
|
+
promise: Promise
|
|
226
|
+
) {
|
|
227
|
+
val inst = instances[instanceId]
|
|
228
|
+
val denoiser = inst?.denoiser
|
|
229
|
+
if (denoiser == null) {
|
|
230
|
+
promise.reject("ENHANCEMENT_ERROR", "Enhancement instance not found: $instanceId")
|
|
231
|
+
return
|
|
232
|
+
}
|
|
233
|
+
|
|
234
|
+
var tmpInput: File? = null
|
|
235
|
+
try {
|
|
236
|
+
val (resolvedInputPath, tmp) = copyContentUriToTemp(inputPath, "enhancement_in")
|
|
237
|
+
tmpInput = tmp
|
|
238
|
+
val wave = WaveReader.readWave(resolvedInputPath)
|
|
239
|
+
val audio = denoiser.run(wave.samples, wave.sampleRate)
|
|
240
|
+
if (!outputPath.isNullOrBlank()) {
|
|
241
|
+
audio.save(outputPath)
|
|
242
|
+
}
|
|
243
|
+
promise.resolve(toEnhancedAudioMap(audio))
|
|
244
|
+
} catch (e: Exception) {
|
|
245
|
+
promise.reject("ENHANCEMENT_ERROR", "Failed to enhance file: ${e.message}", e)
|
|
246
|
+
} finally {
|
|
247
|
+
tmpInput?.delete()
|
|
248
|
+
}
|
|
249
|
+
}
|
|
250
|
+
|
|
251
|
+
fun getSampleRate(instanceId: String, promise: Promise) {
|
|
252
|
+
val inst = instances[instanceId]
|
|
253
|
+
val denoiser = inst?.denoiser
|
|
254
|
+
if (denoiser == null) {
|
|
255
|
+
promise.reject("ENHANCEMENT_ERROR", "Enhancement instance not found: $instanceId")
|
|
256
|
+
return
|
|
257
|
+
}
|
|
258
|
+
promise.resolve(denoiser.sampleRate)
|
|
259
|
+
}
|
|
260
|
+
|
|
261
|
+
fun unloadEnhancement(instanceId: String, promise: Promise) {
|
|
262
|
+
instances.remove(instanceId)?.release()
|
|
263
|
+
promise.resolve(null)
|
|
264
|
+
}
|
|
265
|
+
|
|
266
|
+
fun initializeOnlineEnhancement(
|
|
267
|
+
instanceId: String,
|
|
268
|
+
modelDir: String,
|
|
269
|
+
modelType: String?,
|
|
270
|
+
numThreads: Double?,
|
|
271
|
+
provider: String?,
|
|
272
|
+
debug: Boolean?,
|
|
273
|
+
promise: Promise
|
|
274
|
+
) {
|
|
275
|
+
try {
|
|
276
|
+
val result = nativeDetectEnhancementModel(modelDir, modelType ?: "auto")
|
|
277
|
+
if (result == null || result["success"] as? Boolean != true) {
|
|
278
|
+
val reason = result?.get("error") as? String ?: "Failed to detect enhancement model"
|
|
279
|
+
promise.reject("ONLINE_ENHANCEMENT_INIT_ERROR", reason)
|
|
280
|
+
return
|
|
281
|
+
}
|
|
282
|
+
val modelTypeStr = result["modelType"] as? String ?: "gtcrn"
|
|
283
|
+
val paths = (result["paths"] as? Map<*, *>)
|
|
284
|
+
?.mapValues { (_, v) -> (v as? String).orEmpty() }
|
|
285
|
+
?.mapKeys { it.key.toString() }
|
|
286
|
+
?: emptyMap()
|
|
287
|
+
|
|
288
|
+
val offlineModelConfig = when (modelTypeStr) {
|
|
289
|
+
"gtcrn" -> OfflineSpeechDenoiserModelConfig(
|
|
290
|
+
gtcrn = OfflineSpeechDenoiserGtcrnModelConfig(model = path(paths, "model")),
|
|
291
|
+
numThreads = numThreads?.toInt() ?: 1,
|
|
292
|
+
provider = provider ?: "cpu",
|
|
293
|
+
debug = debug ?: false
|
|
294
|
+
)
|
|
295
|
+
"dpdfnet" -> OfflineSpeechDenoiserModelConfig(
|
|
296
|
+
dpdfnet = OfflineSpeechDenoiserDpdfNetModelConfig(model = path(paths, "model")),
|
|
297
|
+
numThreads = numThreads?.toInt() ?: 1,
|
|
298
|
+
provider = provider ?: "cpu",
|
|
299
|
+
debug = debug ?: false
|
|
300
|
+
)
|
|
301
|
+
else -> {
|
|
302
|
+
promise.reject("ONLINE_ENHANCEMENT_INIT_ERROR", "Unsupported enhancement model type: $modelTypeStr")
|
|
303
|
+
return
|
|
304
|
+
}
|
|
305
|
+
}
|
|
306
|
+
|
|
307
|
+
val inst = onlineInstances.getOrPut(instanceId) { OnlineEnhancementInstance() }
|
|
308
|
+
inst.release()
|
|
309
|
+
val denoiser = OnlineSpeechDenoiser(
|
|
310
|
+
config = OnlineSpeechDenoiserConfig(model = offlineModelConfig)
|
|
311
|
+
)
|
|
312
|
+
inst.denoiser = denoiser
|
|
313
|
+
|
|
314
|
+
val out = Arguments.createMap()
|
|
315
|
+
out.putBoolean("success", true)
|
|
316
|
+
out.putInt("sampleRate", denoiser.sampleRate)
|
|
317
|
+
out.putInt("frameShiftInSamples", denoiser.frameShiftInSamples)
|
|
318
|
+
promise.resolve(out)
|
|
319
|
+
} catch (e: Exception) {
|
|
320
|
+
promise.reject("ONLINE_ENHANCEMENT_INIT_ERROR", "Failed to initialize online enhancement: ${e.message}", e)
|
|
321
|
+
}
|
|
322
|
+
}
|
|
323
|
+
|
|
324
|
+
fun feedSamples(
|
|
325
|
+
instanceId: String,
|
|
326
|
+
samples: ReadableArray,
|
|
327
|
+
sampleRate: Double,
|
|
328
|
+
promise: Promise
|
|
329
|
+
) {
|
|
330
|
+
val inst = onlineInstances[instanceId]
|
|
331
|
+
val denoiser = inst?.denoiser
|
|
332
|
+
if (denoiser == null) {
|
|
333
|
+
promise.reject("ONLINE_ENHANCEMENT_ERROR", "Online enhancement instance not found: $instanceId")
|
|
334
|
+
return
|
|
335
|
+
}
|
|
336
|
+
try {
|
|
337
|
+
val audio = denoiser.run(readableArrayToFloatArray(samples), sampleRate.toInt())
|
|
338
|
+
promise.resolve(toEnhancedAudioMap(audio))
|
|
339
|
+
} catch (e: Exception) {
|
|
340
|
+
promise.reject("ONLINE_ENHANCEMENT_ERROR", "Failed to feed enhancement samples: ${e.message}", e)
|
|
341
|
+
}
|
|
342
|
+
}
|
|
343
|
+
|
|
344
|
+
fun flushOnline(instanceId: String, promise: Promise) {
|
|
345
|
+
val inst = onlineInstances[instanceId]
|
|
346
|
+
val denoiser = inst?.denoiser
|
|
347
|
+
if (denoiser == null) {
|
|
348
|
+
promise.reject("ONLINE_ENHANCEMENT_ERROR", "Online enhancement instance not found: $instanceId")
|
|
349
|
+
return
|
|
350
|
+
}
|
|
351
|
+
try {
|
|
352
|
+
promise.resolve(toEnhancedAudioMap(denoiser.flush()))
|
|
353
|
+
} catch (e: Exception) {
|
|
354
|
+
promise.reject("ONLINE_ENHANCEMENT_ERROR", "Failed to flush online enhancement: ${e.message}", e)
|
|
355
|
+
}
|
|
356
|
+
}
|
|
357
|
+
|
|
358
|
+
fun resetOnline(instanceId: String, promise: Promise) {
|
|
359
|
+
val inst = onlineInstances[instanceId]
|
|
360
|
+
val denoiser = inst?.denoiser
|
|
361
|
+
if (denoiser == null) {
|
|
362
|
+
promise.reject("ONLINE_ENHANCEMENT_ERROR", "Online enhancement instance not found: $instanceId")
|
|
363
|
+
return
|
|
364
|
+
}
|
|
365
|
+
try {
|
|
366
|
+
denoiser.reset()
|
|
367
|
+
promise.resolve(null)
|
|
368
|
+
} catch (e: Exception) {
|
|
369
|
+
promise.reject("ONLINE_ENHANCEMENT_ERROR", "Failed to reset online enhancement: ${e.message}", e)
|
|
370
|
+
}
|
|
371
|
+
}
|
|
372
|
+
|
|
373
|
+
fun unloadOnline(instanceId: String, promise: Promise) {
|
|
374
|
+
onlineInstances.remove(instanceId)?.release()
|
|
375
|
+
promise.resolve(null)
|
|
376
|
+
}
|
|
377
|
+
}
|