react-native-sherpa-onnx 0.3.9 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +17 -4
- package/SherpaOnnx.podspec +1 -0
- package/android/prebuilt-download.gradle +67 -27
- package/android/prebuilt-versions.gradle +1 -1
- package/android/src/main/assets/model_licenses/speech-enhancement-models-license-status.csv +7 -0
- package/android/src/main/cpp/CMakeLists.txt +3 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-enhancement-wrapper.cpp +68 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-enhancement-wrapper.h +17 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-enhancement.cpp +119 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect.h +31 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-validate-enhancement.cpp +68 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-validate-enhancement.h +30 -0
- package/android/src/main/cpp/jni/module/sherpa-onnx-module-jni.cpp +21 -0
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxAssetHelper.kt +6 -0
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxEnhancementHelper.kt +377 -0
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxModule.kt +106 -0
- package/ios/Resources/model_licenses/speech-enhancement-models-license-status.csv +7 -0
- package/ios/SherpaOnnx+Assets.mm +5 -0
- package/ios/SherpaOnnx+Enhancement.mm +435 -0
- package/ios/enhancement/sherpa-onnx-enhancement-wrapper.h +85 -0
- package/ios/enhancement/sherpa-onnx-enhancement-wrapper.mm +218 -0
- package/ios/model_detect/sherpa-onnx-model-detect-enhancement.mm +92 -0
- package/ios/model_detect/sherpa-onnx-model-detect.h +23 -0
- package/ios/model_detect/sherpa-onnx-validate-enhancement.h +30 -0
- package/ios/model_detect/sherpa-onnx-validate-enhancement.mm +69 -0
- package/lib/module/NativeSherpaOnnx.js.map +1 -1
- package/lib/module/download/localModels.js +2 -3
- package/lib/module/download/localModels.js.map +1 -1
- package/lib/module/download/paths.js +2 -1
- package/lib/module/download/paths.js.map +1 -1
- package/lib/module/enhancement/index.js +63 -48
- package/lib/module/enhancement/index.js.map +1 -1
- package/lib/module/enhancement/streaming.js +60 -0
- package/lib/module/enhancement/streaming.js.map +1 -0
- package/lib/module/enhancement/streamingTypes.js +4 -0
- package/lib/module/enhancement/streamingTypes.js.map +1 -0
- package/lib/module/enhancement/types.js +4 -0
- package/lib/module/enhancement/types.js.map +1 -0
- package/lib/module/licenses.js +9 -3
- package/lib/module/licenses.js.map +1 -1
- package/lib/typescript/src/NativeSherpaOnnx.d.ts +45 -0
- package/lib/typescript/src/NativeSherpaOnnx.d.ts.map +1 -1
- package/lib/typescript/src/download/localModels.d.ts.map +1 -1
- package/lib/typescript/src/download/paths.d.ts +2 -1
- package/lib/typescript/src/download/paths.d.ts.map +1 -1
- package/lib/typescript/src/enhancement/index.d.ts +9 -46
- package/lib/typescript/src/enhancement/index.d.ts.map +1 -1
- package/lib/typescript/src/enhancement/streaming.d.ts +6 -0
- package/lib/typescript/src/enhancement/streaming.d.ts.map +1 -0
- package/lib/typescript/src/enhancement/streamingTypes.d.ts +12 -0
- package/lib/typescript/src/enhancement/streamingTypes.d.ts.map +1 -0
- package/lib/typescript/src/enhancement/types.d.ts +31 -0
- package/lib/typescript/src/enhancement/types.d.ts.map +1 -0
- package/lib/typescript/src/licenses.d.ts.map +1 -1
- package/package.json +1 -1
- package/scripts/ci/check-model-csvs.sh +27 -2
- package/scripts/ci/collect_all_sherpa_model_streams.sh +3 -1
- package/scripts/ci/collect_one_sherpa_release_stream.sh +3 -1
- package/scripts/ci/sherpa_speech_enhancement_model_release_streams.json +13 -0
- package/scripts/ci/update_model_license_csv.sh +1 -1
- package/src/NativeSherpaOnnx.ts +71 -0
- package/src/download/localModels.ts +1 -3
- package/src/download/paths.ts +2 -1
- package/src/enhancement/index.ts +120 -58
- package/src/enhancement/streaming.ts +105 -0
- package/src/enhancement/streamingTypes.ts +14 -0
- package/src/enhancement/types.ts +36 -0
- package/src/licenses.ts +13 -2
- package/third_party/sherpa-onnx-prebuilt/ANDROID_RELEASE_TAG +1 -1
|
@@ -382,13 +382,19 @@ internal class SherpaOnnxAssetHelper(
|
|
|
382
382
|
"mms",
|
|
383
383
|
"tts"
|
|
384
384
|
)
|
|
385
|
+
val enhancementHints = listOf(
|
|
386
|
+
"gtcrn",
|
|
387
|
+
"dpdfnet"
|
|
388
|
+
)
|
|
385
389
|
|
|
386
390
|
val isStt = sttHints.any { name.contains(it) }
|
|
387
391
|
val isTts = ttsHints.any { name.contains(it) }
|
|
392
|
+
val isEnhancement = enhancementHints.any { name.contains(it) }
|
|
388
393
|
|
|
389
394
|
return when {
|
|
390
395
|
isStt && !isTts -> "stt"
|
|
391
396
|
isTts && !isStt -> "tts"
|
|
397
|
+
isEnhancement && !isStt && !isTts -> "enhancement"
|
|
392
398
|
else -> "unknown"
|
|
393
399
|
}
|
|
394
400
|
}
|
|
@@ -0,0 +1,377 @@
|
|
|
1
|
+
package com.sherpaonnx
|
|
2
|
+
|
|
3
|
+
import android.net.Uri
|
|
4
|
+
import android.util.Log
|
|
5
|
+
import com.facebook.react.bridge.Arguments
|
|
6
|
+
import com.facebook.react.bridge.Promise
|
|
7
|
+
import com.facebook.react.bridge.ReadableArray
|
|
8
|
+
import com.facebook.react.bridge.ReactApplicationContext
|
|
9
|
+
import com.facebook.react.bridge.WritableMap
|
|
10
|
+
import com.k2fsa.sherpa.onnx.DenoisedAudio
|
|
11
|
+
import com.k2fsa.sherpa.onnx.OfflineSpeechDenoiser
|
|
12
|
+
import com.k2fsa.sherpa.onnx.OfflineSpeechDenoiserConfig
|
|
13
|
+
import com.k2fsa.sherpa.onnx.OfflineSpeechDenoiserDpdfNetModelConfig
|
|
14
|
+
import com.k2fsa.sherpa.onnx.OfflineSpeechDenoiserGtcrnModelConfig
|
|
15
|
+
import com.k2fsa.sherpa.onnx.OfflineSpeechDenoiserModelConfig
|
|
16
|
+
import com.k2fsa.sherpa.onnx.OnlineSpeechDenoiser
|
|
17
|
+
import com.k2fsa.sherpa.onnx.OnlineSpeechDenoiserConfig
|
|
18
|
+
import com.k2fsa.sherpa.onnx.WaveReader
|
|
19
|
+
import java.io.File
|
|
20
|
+
import java.util.concurrent.ConcurrentHashMap
|
|
21
|
+
|
|
22
|
+
/**
 * Android backend for the speech-enhancement (denoising) API of the React Native module.
 *
 * Keeps two independent registries keyed by the caller-supplied `instanceId`:
 *  - [instances]: [OfflineSpeechDenoiser]s used for whole-buffer and whole-file enhancement;
 *  - [onlineInstances]: [OnlineSpeechDenoiser]s used for chunked (streaming) enhancement.
 *
 * Every public method reports its outcome only through the supplied [Promise].
 *
 * @param context supplies the cache directory and content resolver used for `content://` inputs.
 * @param nativeDetectEnhancementModel native model detector, called with `(modelDir, modelType)`.
 *   The keys read from its result are `success`, `error`, `detectedModels`, `modelType` and
 *   `paths`; a `null` result is treated as a detection failure.
 */
internal class SherpaOnnxEnhancementHelper(
    private val context: ReactApplicationContext,
    private val nativeDetectEnhancementModel: (modelDir: String, modelType: String) -> HashMap<String, Any>?
) {
    /**
     * Registry slot for an offline denoiser.
     *
     * A plain class (not a data class): it owns a mutable native handle, so value-based
     * equals/hashCode/copy would be misleading.
     */
    private class EnhancementInstance(
        @Volatile var denoiser: OfflineSpeechDenoiser? = null
    ) {
        /** Clears the slot first, then releases the native denoiser it held (if any). */
        fun release() {
            val previous = denoiser
            denoiser = null
            previous?.release()
        }
    }

    /** Registry slot for an online (streaming) denoiser; same ownership rules as [EnhancementInstance]. */
    private class OnlineEnhancementInstance(
        @Volatile var denoiser: OnlineSpeechDenoiser? = null
    ) {
        /** Clears the slot first, then releases the native denoiser it held (if any). */
        fun release() {
            val previous = denoiser
            denoiser = null
            previous?.release()
        }
    }

    /** Offline denoisers keyed by instance id. */
    private val instances = ConcurrentHashMap<String, EnhancementInstance>()

    /** Online (streaming) denoisers keyed by instance id. */
    private val onlineInstances = ConcurrentHashMap<String, OnlineEnhancementInstance>()

    /** Releases every offline and online denoiser and empties both registries. */
    fun shutdown() {
        instances.values.forEach { it.release() }
        instances.clear()
        onlineInstances.values.forEach { it.release() }
        onlineInstances.clear()
    }

    /** Returns `map[key]`, or an empty string when the key is absent. */
    private fun path(map: Map<String, String>, key: String): String = map[key].orEmpty()

    /**
     * Converts the untyped `paths` entry of a detection result into a `String -> String` map.
     * Non-string values become empty strings; a missing or non-map entry yields an empty map.
     */
    private fun extractPaths(result: Map<String, Any>): Map<String, String> =
        (result["paths"] as? Map<*, *>)
            ?.mapValues { (_, v) -> (v as? String).orEmpty() }
            ?.mapKeys { it.key.toString() }
            ?: emptyMap()

    /**
     * Builds the model config shared by the offline and online denoisers.
     *
     * Defaults: 1 thread, provider `"cpu"`, debug off.
     *
     * @return the config, or `null` when [modelType] is neither `"gtcrn"` nor `"dpdfnet"`.
     */
    private fun buildModelConfig(
        modelType: String,
        paths: Map<String, String>,
        numThreads: Double?,
        provider: String?,
        debug: Boolean?
    ): OfflineSpeechDenoiserModelConfig? {
        val threads = numThreads?.toInt() ?: 1
        val resolvedProvider = provider ?: "cpu"
        val resolvedDebug = debug ?: false
        val modelPath = path(paths, "model")
        return when (modelType) {
            "gtcrn" -> OfflineSpeechDenoiserModelConfig(
                gtcrn = OfflineSpeechDenoiserGtcrnModelConfig(model = modelPath),
                numThreads = threads,
                provider = resolvedProvider,
                debug = resolvedDebug
            )
            "dpdfnet" -> OfflineSpeechDenoiserModelConfig(
                dpdfnet = OfflineSpeechDenoiserDpdfNetModelConfig(model = modelPath),
                numThreads = threads,
                provider = resolvedProvider,
                debug = resolvedDebug
            )
            else -> null
        }
    }

    /**
     * Converts the `detectedModels` entry of a detection result into a JS array of
     * `{ type, modelDir }` maps. Entries that are not maps are skipped; missing string
     * fields become `""`. A missing or non-list value yields an empty array.
     */
    private fun toDetectedModelsArray(detected: Any?) = Arguments.createArray().also { array ->
        (detected as? List<*>)
            ?.filterIsInstance<Map<*, *>>()
            ?.forEach { model ->
                val entry = Arguments.createMap()
                entry.putString("type", model["type"] as? String ?: "")
                entry.putString("modelDir", model["modelDir"] as? String ?: "")
                array.pushMap(entry)
            }
    }

    /** Converts a [DenoisedAudio] into `{ samples: number[], sampleRate: number }` for JS. */
    private fun toEnhancedAudioMap(audio: DenoisedAudio): WritableMap {
        val samples = Arguments.createArray()
        audio.samples.forEach { sample -> samples.pushDouble(sample.toDouble()) }
        return Arguments.createMap().apply {
            putArray("samples", samples)
            putInt("sampleRate", audio.sampleRate)
        }
    }

    /** Copies a JS number array into a [FloatArray] (each element narrowed from double). */
    private fun readableArrayToFloatArray(samples: ReadableArray): FloatArray =
        FloatArray(samples.size()) { i -> samples.getDouble(i).toFloat() }

    /**
     * Makes [path] readable as a plain file.
     *
     * Non-`content://` paths are returned unchanged with a `null` temp file. `content://` URIs
     * are copied into a new file in the app cache directory, and the caller must delete the
     * returned temp file.
     *
     * @throws IllegalStateException when the content resolver returns no stream for the URI.
     */
    private fun copyContentUriToTemp(path: String, prefix: String): Pair<String, File?> {
        if (!path.startsWith("content://")) return Pair(path, null)
        val uri = Uri.parse(path)
        val tmp = File(context.cacheDir, "${prefix}_${System.nanoTime()}.wav")
        try {
            val input = context.contentResolver.openInputStream(uri)
                ?: throw IllegalStateException("File is not readable: $path")
            input.use { stream -> tmp.outputStream().use { output -> stream.copyTo(output) } }
        } catch (e: Exception) {
            // The caller only receives `tmp` on success, so a partial copy must be removed here.
            tmp.delete()
            throw e
        }
        return Pair(tmp.absolutePath, tmp)
    }

    /** Returns the offline denoiser for [instanceId], or rejects [promise] and returns `null`. */
    private fun offlineDenoiserOrReject(instanceId: String, promise: Promise): OfflineSpeechDenoiser? =
        instances[instanceId]?.denoiser ?: run {
            promise.reject("ENHANCEMENT_ERROR", "Enhancement instance not found: $instanceId")
            null
        }

    /** Returns the online denoiser for [instanceId], or rejects [promise] and returns `null`. */
    private fun onlineDenoiserOrReject(instanceId: String, promise: Promise): OnlineSpeechDenoiser? =
        onlineInstances[instanceId]?.denoiser ?: run {
            promise.reject("ONLINE_ENHANCEMENT_ERROR", "Online enhancement instance not found: $instanceId")
            null
        }

    /**
     * Runs native model detection and resolves with
     * `{ success, detectedModels, modelType?, error? }`.
     *
     * `error` is only included when detection failed and the native side supplied a
     * non-blank message. Rejects with `DETECT_ERROR` when detection returns `null` or throws.
     */
    fun detectEnhancementModel(
        modelDir: String,
        modelType: String?,
        promise: Promise
    ) {
        try {
            val result = nativeDetectEnhancementModel(modelDir, modelType ?: "auto")
            if (result == null) {
                promise.reject("DETECT_ERROR", "Enhancement model detection returned null")
                return
            }
            val success = result["success"] as? Boolean ?: false

            val resultMap = Arguments.createMap()
            resultMap.putBoolean("success", success)
            resultMap.putArray("detectedModels", toDetectedModelsArray(result["detectedModels"]))
            (result["modelType"] as? String)?.let { resultMap.putString("modelType", it) }
            if (!success) {
                val error = result["error"] as? String
                if (!error.isNullOrBlank()) {
                    resultMap.putString("error", error)
                }
            }
            promise.resolve(resultMap)
        } catch (e: Exception) {
            Log.e(TAG, "Enhancement detection failed", e)
            promise.reject("DETECT_ERROR", "Enhancement model detection failed: ${e.message}", e)
        }
    }

    /**
     * Detects the model in [modelDir] and (re)creates the offline denoiser for [instanceId].
     *
     * Any denoiser already registered under [instanceId] is released before the new one is
     * constructed. If construction then fails, the slot stays empty.
     *
     * Resolves with `{ success: true, detectedModels, modelType, sampleRate }`. Rejects with
     * `ENHANCEMENT_INIT_ERROR` on detection failure, an unsupported model type, or any exception.
     */
    fun initializeEnhancement(
        instanceId: String,
        modelDir: String,
        modelType: String?,
        numThreads: Double?,
        provider: String?,
        debug: Boolean?,
        promise: Promise
    ) {
        try {
            val result = nativeDetectEnhancementModel(modelDir, modelType ?: "auto")
            if (result == null || result["success"] as? Boolean != true) {
                val reason = result?.get("error") as? String ?: "Failed to detect enhancement model"
                promise.reject("ENHANCEMENT_INIT_ERROR", reason)
                return
            }
            val modelTypeStr = result["modelType"] as? String ?: "gtcrn"
            val modelConfig = buildModelConfig(modelTypeStr, extractPaths(result), numThreads, provider, debug)
            if (modelConfig == null) {
                promise.reject("ENHANCEMENT_INIT_ERROR", "Unsupported enhancement model type: $modelTypeStr")
                return
            }

            val inst = instances.getOrPut(instanceId) { EnhancementInstance() }
            inst.release()
            val denoiser = OfflineSpeechDenoiser(
                config = OfflineSpeechDenoiserConfig(model = modelConfig)
            )
            inst.denoiser = denoiser

            val out = Arguments.createMap()
            out.putBoolean("success", true)
            out.putArray("detectedModels", toDetectedModelsArray(result["detectedModels"]))
            out.putString("modelType", modelTypeStr)
            out.putInt("sampleRate", denoiser.sampleRate)
            promise.resolve(out)
        } catch (e: Exception) {
            Log.e(TAG, "Failed to initialize enhancement", e)
            promise.reject("ENHANCEMENT_INIT_ERROR", "Failed to initialize enhancement: ${e.message}", e)
        }
    }

    /** Denoises an in-memory sample buffer and resolves with `{ samples, sampleRate }`. */
    fun enhanceSamples(
        instanceId: String,
        samples: ReadableArray,
        sampleRate: Double,
        promise: Promise
    ) {
        val denoiser = offlineDenoiserOrReject(instanceId, promise) ?: return
        try {
            val audio = denoiser.run(readableArrayToFloatArray(samples), sampleRate.toInt())
            promise.resolve(toEnhancedAudioMap(audio))
        } catch (e: Exception) {
            promise.reject("ENHANCEMENT_ERROR", "Failed to enhance samples: ${e.message}", e)
        }
    }

    /**
     * Denoises a WAV file (plain path or `content://` URI) and resolves with `{ samples, sampleRate }`.
     *
     * When [outputPath] is non-blank the enhanced audio is also written there. A failed write
     * rejects the promise instead of reporting success. Any temp copy of a `content://` input
     * is deleted afterwards.
     */
    fun enhanceFile(
        instanceId: String,
        inputPath: String,
        outputPath: String?,
        promise: Promise
    ) {
        val denoiser = offlineDenoiserOrReject(instanceId, promise) ?: return

        var tmpInput: File? = null
        try {
            val (resolvedInputPath, tmp) = copyContentUriToTemp(inputPath, "enhancement_in")
            tmpInput = tmp
            val wave = WaveReader.readWave(resolvedInputPath)
            val audio = denoiser.run(wave.samples, wave.sampleRate)
            if (!outputPath.isNullOrBlank()) {
                // save() reports failure through its return value rather than by throwing.
                check(audio.save(outputPath)) { "Could not write enhanced audio to $outputPath" }
            }
            promise.resolve(toEnhancedAudioMap(audio))
        } catch (e: Exception) {
            promise.reject("ENHANCEMENT_ERROR", "Failed to enhance file: ${e.message}", e)
        } finally {
            tmpInput?.delete()
        }
    }

    /** Resolves with the sample rate of the offline denoiser registered under [instanceId]. */
    fun getSampleRate(instanceId: String, promise: Promise) {
        val denoiser = offlineDenoiserOrReject(instanceId, promise) ?: return
        promise.resolve(denoiser.sampleRate)
    }

    /** Removes and releases the offline denoiser for [instanceId]; resolves even if none exists. */
    fun unloadEnhancement(instanceId: String, promise: Promise) {
        instances.remove(instanceId)?.release()
        promise.resolve(null)
    }

    /**
     * Detects the model in [modelDir] and (re)creates the online denoiser for [instanceId].
     *
     * Any denoiser already registered under [instanceId] is released before the new one is
     * constructed. Resolves with `{ success: true, sampleRate, frameShiftInSamples }`. Rejects
     * with `ONLINE_ENHANCEMENT_INIT_ERROR` on detection failure, an unsupported model type, or
     * any exception.
     */
    fun initializeOnlineEnhancement(
        instanceId: String,
        modelDir: String,
        modelType: String?,
        numThreads: Double?,
        provider: String?,
        debug: Boolean?,
        promise: Promise
    ) {
        try {
            val result = nativeDetectEnhancementModel(modelDir, modelType ?: "auto")
            if (result == null || result["success"] as? Boolean != true) {
                val reason = result?.get("error") as? String ?: "Failed to detect enhancement model"
                promise.reject("ONLINE_ENHANCEMENT_INIT_ERROR", reason)
                return
            }
            val modelTypeStr = result["modelType"] as? String ?: "gtcrn"
            val modelConfig = buildModelConfig(modelTypeStr, extractPaths(result), numThreads, provider, debug)
            if (modelConfig == null) {
                promise.reject("ONLINE_ENHANCEMENT_INIT_ERROR", "Unsupported enhancement model type: $modelTypeStr")
                return
            }

            val inst = onlineInstances.getOrPut(instanceId) { OnlineEnhancementInstance() }
            inst.release()
            val denoiser = OnlineSpeechDenoiser(
                config = OnlineSpeechDenoiserConfig(model = modelConfig)
            )
            inst.denoiser = denoiser

            val out = Arguments.createMap()
            out.putBoolean("success", true)
            out.putInt("sampleRate", denoiser.sampleRate)
            out.putInt("frameShiftInSamples", denoiser.frameShiftInSamples)
            promise.resolve(out)
        } catch (e: Exception) {
            Log.e(TAG, "Failed to initialize online enhancement", e)
            promise.reject("ONLINE_ENHANCEMENT_INIT_ERROR", "Failed to initialize online enhancement: ${e.message}", e)
        }
    }

    /** Feeds one chunk to the online denoiser and resolves with whatever audio it returns. */
    fun feedSamples(
        instanceId: String,
        samples: ReadableArray,
        sampleRate: Double,
        promise: Promise
    ) {
        val denoiser = onlineDenoiserOrReject(instanceId, promise) ?: return
        try {
            val audio = denoiser.run(readableArrayToFloatArray(samples), sampleRate.toInt())
            promise.resolve(toEnhancedAudioMap(audio))
        } catch (e: Exception) {
            promise.reject("ONLINE_ENHANCEMENT_ERROR", "Failed to feed enhancement samples: ${e.message}", e)
        }
    }

    /** Flushes the online denoiser and resolves with the audio returned by `flush()`. */
    fun flushOnline(instanceId: String, promise: Promise) {
        val denoiser = onlineDenoiserOrReject(instanceId, promise) ?: return
        try {
            promise.resolve(toEnhancedAudioMap(denoiser.flush()))
        } catch (e: Exception) {
            promise.reject("ONLINE_ENHANCEMENT_ERROR", "Failed to flush online enhancement: ${e.message}", e)
        }
    }

    /** Calls `reset()` on the online denoiser for [instanceId] and resolves with `null`. */
    fun resetOnline(instanceId: String, promise: Promise) {
        val denoiser = onlineDenoiserOrReject(instanceId, promise) ?: return
        try {
            denoiser.reset()
            promise.resolve(null)
        } catch (e: Exception) {
            promise.reject("ONLINE_ENHANCEMENT_ERROR", "Failed to reset online enhancement: ${e.message}", e)
        }
    }

    /** Removes and releases the online denoiser for [instanceId]; resolves even if none exists. */
    fun unloadOnline(instanceId: String, promise: Promise) {
        onlineInstances.remove(instanceId)?.release()
        promise.resolve(null)
    }

    private companion object {
        /** Logcat tag used by this helper. */
        const val TAG = "SherpaOnnxEnhancement"
    }
}
|
|
@@ -56,6 +56,10 @@ class SherpaOnnxModule(reactContext: ReactApplicationContext) :
|
|
|
56
56
|
{ instanceId, requestId, message -> emitTtsStreamError(instanceId, requestId, message) },
|
|
57
57
|
{ instanceId, requestId, cancelled -> emitTtsStreamEnd(instanceId, requestId, cancelled) }
|
|
58
58
|
)
|
|
59
|
+
/** Speech-enhancement backend; model detection is routed to the JNI detector in the companion. */
private val enhancementHelper = SherpaOnnxEnhancementHelper(reactApplicationContext) { dir, type ->
    Companion.nativeDetectEnhancementModel(dir, type)
}
|
|
59
63
|
private val archiveHelper = SherpaOnnxArchiveHelper()
|
|
60
64
|
private var pcmCapture: SherpaOnnxPcmCapture? = null
|
|
61
65
|
|
|
@@ -69,6 +73,7 @@ class SherpaOnnxModule(reactContext: ReactApplicationContext) :
|
|
|
69
73
|
pcmCapture = null
|
|
70
74
|
onlineSttHelper.shutdown()
|
|
71
75
|
ttsHelper.shutdown()
|
|
76
|
+
enhancementHelper.shutdown()
|
|
72
77
|
}
|
|
73
78
|
|
|
74
79
|
/**
|
|
@@ -1059,6 +1064,103 @@ class SherpaOnnxModule(reactContext: ReactApplicationContext) :
|
|
|
1059
1064
|
ttsHelper.unloadTts(instanceId, promise)
|
|
1060
1065
|
}
|
|
1061
1066
|
|
|
1067
|
+
// ==================== Speech Enhancement Methods ====================
// Each override is a one-line bridge: all work, including settling the promise,
// is done by enhancementHelper.

/** Detects the enhancement model in [modelDir] (optionally forcing [modelType]). */
override fun detectEnhancementModel(
    modelDir: String,
    modelType: String?,
    promise: Promise
) = enhancementHelper.detectEnhancementModel(modelDir = modelDir, modelType = modelType, promise = promise)

/** Creates (or recreates) the offline denoiser registered under [instanceId]. */
override fun initializeEnhancement(
    instanceId: String,
    modelDir: String,
    modelType: String?,
    numThreads: Double?,
    provider: String?,
    debug: Boolean?,
    promise: Promise
) = enhancementHelper.initializeEnhancement(
    instanceId = instanceId,
    modelDir = modelDir,
    modelType = modelType,
    numThreads = numThreads,
    provider = provider,
    debug = debug,
    promise = promise
)

/** Denoises the file at [inputPath], optionally writing the result to [outputPath]. */
override fun enhanceFile(
    instanceId: String,
    inputPath: String,
    outputPath: String?,
    promise: Promise
) = enhancementHelper.enhanceFile(
    instanceId = instanceId,
    inputPath = inputPath,
    outputPath = outputPath,
    promise = promise
)

/** Denoises an in-memory sample buffer. */
override fun enhanceSamples(
    instanceId: String,
    samples: ReadableArray,
    sampleRate: Double,
    promise: Promise
) = enhancementHelper.enhanceSamples(
    instanceId = instanceId,
    samples = samples,
    sampleRate = sampleRate,
    promise = promise
)

/** Reports the sample rate of the offline denoiser for [instanceId]. */
override fun getEnhancementSampleRate(instanceId: String, promise: Promise) =
    enhancementHelper.getSampleRate(instanceId = instanceId, promise = promise)

/** Releases the offline denoiser for [instanceId]. */
override fun unloadEnhancement(instanceId: String, promise: Promise) =
    enhancementHelper.unloadEnhancement(instanceId = instanceId, promise = promise)

/** Creates (or recreates) the online (streaming) denoiser registered under [instanceId]. */
override fun initializeOnlineEnhancement(
    instanceId: String,
    modelDir: String,
    modelType: String?,
    numThreads: Double?,
    provider: String?,
    debug: Boolean?,
    promise: Promise
) = enhancementHelper.initializeOnlineEnhancement(
    instanceId = instanceId,
    modelDir = modelDir,
    modelType = modelType,
    numThreads = numThreads,
    provider = provider,
    debug = debug,
    promise = promise
)

/** Feeds one chunk of samples to the online denoiser. */
override fun feedEnhancementSamples(
    instanceId: String,
    samples: ReadableArray,
    sampleRate: Double,
    promise: Promise
) = enhancementHelper.feedSamples(
    instanceId = instanceId,
    samples = samples,
    sampleRate = sampleRate,
    promise = promise
)

/** Flushes the online denoiser for [instanceId]. */
override fun flushOnlineEnhancement(instanceId: String, promise: Promise) =
    enhancementHelper.flushOnline(instanceId = instanceId, promise = promise)

/** Resets the online denoiser for [instanceId]. */
override fun resetOnlineEnhancement(instanceId: String, promise: Promise) =
    enhancementHelper.resetOnline(instanceId = instanceId, promise = promise)

/** Releases the online denoiser for [instanceId]. */
override fun unloadOnlineEnhancement(instanceId: String, promise: Promise) =
    enhancementHelper.unloadOnline(instanceId = instanceId, promise = promise)
|
|
1163
|
+
|
|
1062
1164
|
/**
|
|
1063
1165
|
* Save TTS audio samples to a WAV file.
|
|
1064
1166
|
*/
|
|
@@ -1256,6 +1358,10 @@ class SherpaOnnxModule(reactContext: ReactApplicationContext) :
|
|
|
1256
1358
|
@JvmStatic
|
|
1257
1359
|
private external fun nativeDetectTtsModel(modelDir: String, modelType: String): HashMap<String, Any>?
|
|
1258
1360
|
|
|
1361
|
+
/**
 * Native model detection for speech enhancement.
 *
 * Returns a HashMap with the keys success, error, detectedModels, modelType and paths,
 * or null.
 */
@JvmStatic private external fun nativeDetectEnhancementModel(
    modelDir: String,
    modelType: String
): HashMap<String, Any>?
|
|
1364
|
+
|
|
1259
1365
|
/** Convert arbitrary audio file to requested format (e.g. "mp3", "flac", "wav").
|
|
1260
1366
|
* outputSampleRateHz: for MP3 use 32000/44100/48000, 0 = default 44100. Ignored for WAV/FLAC.
|
|
1261
1367
|
* Returns empty string on success, or an error message otherwise. Requires FFmpeg prebuilts when called on Android.
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
asset_name,license_type,commercial_use,confidence,detection_source,license_file
|
|
2
|
+
dpdfnet2.onnx,apache-2.0,yes,high,manual,https://huggingface.co/Ceva-IP/DPDFNet/tree/main/onnx
|
|
3
|
+
dpdfnet2_48khz_hr.onnx,apache-2.0,yes,high,manual,https://huggingface.co/Ceva-IP/DPDFNet/tree/main/onnx
|
|
4
|
+
dpdfnet4.onnx,apache-2.0,yes,high,manual,https://huggingface.co/Ceva-IP/DPDFNet/tree/main/onnx
|
|
5
|
+
dpdfnet8.onnx,apache-2.0,yes,high,manual,https://huggingface.co/Ceva-IP/DPDFNet/tree/main/onnx
|
|
6
|
+
dpdfnet_baseline.onnx,apache-2.0,yes,high,manual,https://huggingface.co/Ceva-IP/DPDFNet/tree/main/onnx
|
|
7
|
+
gtcrn_simple.onnx,mit,yes,high,manual,https://github.com/Xiaobin-Rong/gtcrn/tree/main
|
package/ios/SherpaOnnx+Assets.mm
CHANGED
|
@@ -319,6 +319,11 @@ static void collectModelFolderNames(NSFileManager *fileManager, NSString *path,
|
|
|
319
319
|
return @"tts";
|
|
320
320
|
}
|
|
321
321
|
|
|
322
|
+
BOOL isEnhancement = [name containsString:@"gtcrn"] || [name containsString:@"dpdfnet"];
|
|
323
|
+
if (isEnhancement) {
|
|
324
|
+
return @"enhancement";
|
|
325
|
+
}
|
|
326
|
+
|
|
322
327
|
return @"unknown";
|
|
323
328
|
}
|
|
324
329
|
|