react-native-sherpa-onnx 0.3.9 → 0.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +17 -4
- package/SherpaOnnx.podspec +1 -0
- package/android/prebuilt-download.gradle +67 -27
- package/android/prebuilt-versions.gradle +1 -1
- package/android/src/main/assets/model_licenses/speech-enhancement-models-license-status.csv +7 -0
- package/android/src/main/cpp/CMakeLists.txt +3 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-enhancement-wrapper.cpp +68 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-enhancement-wrapper.h +17 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-enhancement.cpp +119 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect.h +31 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-validate-enhancement.cpp +68 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-validate-enhancement.h +30 -0
- package/android/src/main/cpp/jni/module/sherpa-onnx-module-jni.cpp +21 -0
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxAssetHelper.kt +6 -0
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxEnhancementHelper.kt +377 -0
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxModule.kt +106 -0
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxTtsHelper.kt +66 -13
- package/ios/Resources/model_licenses/speech-enhancement-models-license-status.csv +7 -0
- package/ios/SherpaOnnx+Assets.mm +5 -0
- package/ios/SherpaOnnx+Enhancement.mm +435 -0
- package/ios/enhancement/sherpa-onnx-enhancement-wrapper.h +85 -0
- package/ios/enhancement/sherpa-onnx-enhancement-wrapper.mm +218 -0
- package/ios/model_detect/sherpa-onnx-model-detect-enhancement.mm +92 -0
- package/ios/model_detect/sherpa-onnx-model-detect.h +23 -0
- package/ios/model_detect/sherpa-onnx-validate-enhancement.h +30 -0
- package/ios/model_detect/sherpa-onnx-validate-enhancement.mm +69 -0
- package/lib/module/NativeSherpaOnnx.js.map +1 -1
- package/lib/module/download/localModels.js +2 -3
- package/lib/module/download/localModels.js.map +1 -1
- package/lib/module/download/paths.js +2 -1
- package/lib/module/download/paths.js.map +1 -1
- package/lib/module/enhancement/index.js +63 -48
- package/lib/module/enhancement/index.js.map +1 -1
- package/lib/module/enhancement/streaming.js +60 -0
- package/lib/module/enhancement/streaming.js.map +1 -0
- package/lib/module/enhancement/streamingTypes.js +4 -0
- package/lib/module/enhancement/streamingTypes.js.map +1 -0
- package/lib/module/enhancement/types.js +4 -0
- package/lib/module/enhancement/types.js.map +1 -0
- package/lib/module/licenses.js +9 -3
- package/lib/module/licenses.js.map +1 -1
- package/lib/typescript/src/NativeSherpaOnnx.d.ts +45 -0
- package/lib/typescript/src/NativeSherpaOnnx.d.ts.map +1 -1
- package/lib/typescript/src/download/localModels.d.ts.map +1 -1
- package/lib/typescript/src/download/paths.d.ts +2 -1
- package/lib/typescript/src/download/paths.d.ts.map +1 -1
- package/lib/typescript/src/enhancement/index.d.ts +9 -46
- package/lib/typescript/src/enhancement/index.d.ts.map +1 -1
- package/lib/typescript/src/enhancement/streaming.d.ts +6 -0
- package/lib/typescript/src/enhancement/streaming.d.ts.map +1 -0
- package/lib/typescript/src/enhancement/streamingTypes.d.ts +12 -0
- package/lib/typescript/src/enhancement/streamingTypes.d.ts.map +1 -0
- package/lib/typescript/src/enhancement/types.d.ts +31 -0
- package/lib/typescript/src/enhancement/types.d.ts.map +1 -0
- package/lib/typescript/src/licenses.d.ts.map +1 -1
- package/package.json +1 -1
- package/scripts/ci/check-model-csvs.sh +27 -2
- package/scripts/ci/collect_all_sherpa_model_streams.sh +3 -1
- package/scripts/ci/collect_one_sherpa_release_stream.sh +3 -1
- package/scripts/ci/sherpa_speech_enhancement_model_release_streams.json +13 -0
- package/scripts/ci/update_model_license_csv.sh +1 -1
- package/src/NativeSherpaOnnx.ts +71 -0
- package/src/download/localModels.ts +1 -3
- package/src/download/paths.ts +2 -1
- package/src/enhancement/index.ts +120 -58
- package/src/enhancement/streaming.ts +105 -0
- package/src/enhancement/streamingTypes.ts +14 -0
- package/src/enhancement/types.ts +36 -0
- package/src/licenses.ts +13 -2
- package/third_party/sherpa-onnx-prebuilt/ANDROID_RELEASE_TAG +1 -1
|
@@ -382,13 +382,19 @@ internal class SherpaOnnxAssetHelper(
|
|
|
382
382
|
"mms",
|
|
383
383
|
"tts"
|
|
384
384
|
)
|
|
385
|
+
val enhancementHints = listOf(
|
|
386
|
+
"gtcrn",
|
|
387
|
+
"dpdfnet"
|
|
388
|
+
)
|
|
385
389
|
|
|
386
390
|
val isStt = sttHints.any { name.contains(it) }
|
|
387
391
|
val isTts = ttsHints.any { name.contains(it) }
|
|
392
|
+
val isEnhancement = enhancementHints.any { name.contains(it) }
|
|
388
393
|
|
|
389
394
|
return when {
|
|
390
395
|
isStt && !isTts -> "stt"
|
|
391
396
|
isTts && !isStt -> "tts"
|
|
397
|
+
isEnhancement && !isStt && !isTts -> "enhancement"
|
|
392
398
|
else -> "unknown"
|
|
393
399
|
}
|
|
394
400
|
}
|
|
@@ -0,0 +1,377 @@
|
|
|
1
|
+
package com.sherpaonnx
|
|
2
|
+
|
|
3
|
+
import android.net.Uri
|
|
4
|
+
import android.util.Log
|
|
5
|
+
import com.facebook.react.bridge.Arguments
|
|
6
|
+
import com.facebook.react.bridge.Promise
|
|
7
|
+
import com.facebook.react.bridge.ReadableArray
|
|
8
|
+
import com.facebook.react.bridge.ReactApplicationContext
|
|
9
|
+
import com.facebook.react.bridge.WritableMap
|
|
10
|
+
import com.k2fsa.sherpa.onnx.DenoisedAudio
|
|
11
|
+
import com.k2fsa.sherpa.onnx.OfflineSpeechDenoiser
|
|
12
|
+
import com.k2fsa.sherpa.onnx.OfflineSpeechDenoiserConfig
|
|
13
|
+
import com.k2fsa.sherpa.onnx.OfflineSpeechDenoiserDpdfNetModelConfig
|
|
14
|
+
import com.k2fsa.sherpa.onnx.OfflineSpeechDenoiserGtcrnModelConfig
|
|
15
|
+
import com.k2fsa.sherpa.onnx.OfflineSpeechDenoiserModelConfig
|
|
16
|
+
import com.k2fsa.sherpa.onnx.OnlineSpeechDenoiser
|
|
17
|
+
import com.k2fsa.sherpa.onnx.OnlineSpeechDenoiserConfig
|
|
18
|
+
import com.k2fsa.sherpa.onnx.WaveReader
|
|
19
|
+
import java.io.File
|
|
20
|
+
import java.util.concurrent.ConcurrentHashMap
|
|
21
|
+
|
|
22
|
+
/**
 * Backs the React Native speech-enhancement (denoising) bridge methods on Android.
 *
 * Two registries are kept, both keyed by caller-supplied instance ids:
 *  - [instances]: offline denoisers ([OfflineSpeechDenoiser]) used for whole-buffer and whole-file
 *    enhancement;
 *  - [onlineInstances]: streaming denoisers ([OnlineSpeechDenoiser]) that are fed chunk by chunk.
 *
 * Model discovery is delegated to [nativeDetectEnhancementModel]. Its result map is read for the keys
 * `success`, `error`, `detectedModels`, `modelType` and `paths`.
 *
 * @param context supplies the cache directory and content resolver used to copy `content://` inputs.
 * @param nativeDetectEnhancementModel detection callback taking (modelDir, modelType). `"auto"` is
 *   passed as modelType when the JS caller did not specify one.
 */
internal class SherpaOnnxEnhancementHelper(
    private val context: ReactApplicationContext,
    private val nativeDetectEnhancementModel: (modelDir: String, modelType: String) -> HashMap<String, Any>?
) {
    /** Registry slot for one offline denoiser; [release] frees the native model and clears the slot. */
    private data class EnhancementInstance(
        @Volatile var denoiser: OfflineSpeechDenoiser? = null
    ) {
        fun release() {
            denoiser?.release()
            denoiser = null
        }
    }

    /** Registry slot for one streaming denoiser; [release] frees the native model and clears the slot. */
    private data class OnlineEnhancementInstance(
        @Volatile var denoiser: OnlineSpeechDenoiser? = null
    ) {
        fun release() {
            denoiser?.release()
            denoiser = null
        }
    }

    private val instances = ConcurrentHashMap<String, EnhancementInstance>()
    private val onlineInstances = ConcurrentHashMap<String, OnlineEnhancementInstance>()

    /** Releases every offline and online denoiser and empties both registries. */
    fun shutdown() {
        instances.values.forEach { it.release() }
        instances.clear()
        onlineInstances.values.forEach { it.release() }
        onlineInstances.clear()
    }

    /** Returns `map[key]`, or an empty string when the key is absent. */
    private fun path(map: Map<String, String>, key: String): String = map[key].orEmpty()

    /** Converts denoiser output into the JS shape `{ samples: number[], sampleRate: number }`. */
    private fun toEnhancedAudioMap(audio: DenoisedAudio): WritableMap {
        val samples = Arguments.createArray()
        for (sample in audio.samples) {
            samples.pushDouble(sample.toDouble())
        }
        val out = Arguments.createMap()
        out.putArray("samples", samples)
        out.putInt("sampleRate", audio.sampleRate)
        return out
    }

    /** Copies a JS number array into a [FloatArray] (each value narrowed from Double to Float). */
    private fun readableArrayToFloatArray(samples: ReadableArray): FloatArray {
        val out = FloatArray(samples.size())
        for (i in 0 until samples.size()) {
            out[i] = samples.getDouble(i).toFloat()
        }
        return out
    }

    /**
     * Makes [path] readable through a plain file path.
     *
     * Non-`content://` paths are returned unchanged with a null temp file. A `content://` URI is copied
     * into a new cache file named `<prefix>_<nanoTime>.wav`, and that file is returned so the caller can
     * delete it when done.
     *
     * @throws IllegalStateException if the content resolver yields no stream for the URI.
     */
    private fun copyContentUriToTemp(path: String, prefix: String): Pair<String, File?> {
        if (!path.startsWith("content://")) return Pair(path, null)
        val uri = Uri.parse(path)
        val tmp = File(context.cacheDir, "${prefix}_${System.nanoTime()}.wav")
        try {
            context.contentResolver.openInputStream(uri)?.use { input ->
                tmp.outputStream().use { output -> input.copyTo(output) }
            } ?: throw IllegalStateException("File is not readable: $path")
        } catch (e: Exception) {
            // The caller only receives `tmp` on success, so a partially written copy must be removed
            // here or it would be orphaned in the cache directory.
            tmp.delete()
            throw e
        }
        return Pair(tmp.absolutePath, tmp)
    }

    /**
     * Converts the `detectedModels` value of a detection result into a JS array of
     * `{ type, modelDir }` maps. Entries that are not maps are skipped; missing fields become "".
     */
    private fun toDetectedModelsArray(detectedModels: Any?) = Arguments.createArray().apply {
        (detectedModels as? List<*>)?.forEach { model ->
            val entry = model as? Map<*, *> ?: return@forEach
            val modelMap = Arguments.createMap()
            modelMap.putString("type", entry["type"] as? String ?: "")
            modelMap.putString("modelDir", entry["modelDir"] as? String ?: "")
            pushMap(modelMap)
        }
    }

    /** Reads the `paths` map of a detection result as String -> String; non-string values become "". */
    private fun parsePaths(result: Map<String, Any>): Map<String, String> =
        (result["paths"] as? Map<*, *>)
            ?.entries
            ?.associate { (key, value) -> key.toString() to (value as? String).orEmpty() }
            ?: emptyMap()

    /**
     * Builds the denoiser model configuration for [modelType] ("gtcrn" or "dpdfnet").
     *
     * The same configuration type is used for both offline and online denoisers. Defaults: 1 thread,
     * provider "cpu", debug off.
     *
     * @return the configuration, or null when [modelType] is not supported.
     */
    private fun buildModelConfig(
        modelType: String,
        paths: Map<String, String>,
        numThreads: Double?,
        provider: String?,
        debug: Boolean?
    ): OfflineSpeechDenoiserModelConfig? {
        val threads = numThreads?.toInt() ?: 1
        val providerName = provider ?: "cpu"
        val debugEnabled = debug ?: false
        val modelPath = path(paths, "model")
        return when (modelType) {
            "gtcrn" -> OfflineSpeechDenoiserModelConfig(
                gtcrn = OfflineSpeechDenoiserGtcrnModelConfig(model = modelPath),
                numThreads = threads,
                provider = providerName,
                debug = debugEnabled
            )
            "dpdfnet" -> OfflineSpeechDenoiserModelConfig(
                dpdfnet = OfflineSpeechDenoiserDpdfNetModelConfig(model = modelPath),
                numThreads = threads,
                provider = providerName,
                debug = debugEnabled
            )
            else -> null
        }
    }

    /** Returns the loaded offline denoiser for [instanceId]; rejects [promise] and returns null if absent. */
    private fun offlineDenoiserOrReject(instanceId: String, promise: Promise): OfflineSpeechDenoiser? {
        val denoiser = instances[instanceId]?.denoiser
        if (denoiser == null) {
            promise.reject("ENHANCEMENT_ERROR", "Enhancement instance not found: $instanceId")
        }
        return denoiser
    }

    /** Returns the loaded streaming denoiser for [instanceId]; rejects [promise] and returns null if absent. */
    private fun onlineDenoiserOrReject(instanceId: String, promise: Promise): OnlineSpeechDenoiser? {
        val denoiser = onlineInstances[instanceId]?.denoiser
        if (denoiser == null) {
            promise.reject("ONLINE_ENHANCEMENT_ERROR", "Online enhancement instance not found: $instanceId")
        }
        return denoiser
    }

    /**
     * Runs model detection without loading anything.
     *
     * Resolves with `{ success, detectedModels, modelType?, error? }`, where `error` is set only when
     * detection failed and the native side supplied a non-blank message. Rejects with `DETECT_ERROR`
     * when detection returns null or throws.
     */
    fun detectEnhancementModel(
        modelDir: String,
        modelType: String?,
        promise: Promise
    ) {
        try {
            val result = nativeDetectEnhancementModel(modelDir, modelType ?: "auto")
            if (result == null) {
                promise.reject("DETECT_ERROR", "Enhancement model detection returned null")
                return
            }
            val success = result["success"] as? Boolean ?: false

            val resultMap = Arguments.createMap()
            resultMap.putBoolean("success", success)
            resultMap.putArray("detectedModels", toDetectedModelsArray(result["detectedModels"]))
            (result["modelType"] as? String)?.let { resultMap.putString("modelType", it) }
            if (!success) {
                val error = result["error"] as? String
                if (!error.isNullOrBlank()) {
                    resultMap.putString("error", error)
                }
            }
            promise.resolve(resultMap)
        } catch (e: Exception) {
            Log.e("SherpaOnnxEnhancement", "Enhancement detection failed", e)
            promise.reject("DETECT_ERROR", "Enhancement model detection failed: ${e.message}", e)
        }
    }

    /**
     * Detects the model in [modelDir] and (re)loads an offline denoiser under [instanceId].
     *
     * Resolves with `{ success: true, detectedModels, modelType, sampleRate }`. Rejects with
     * `ENHANCEMENT_INIT_ERROR` when detection fails, the model type is unsupported, or loading throws.
     */
    fun initializeEnhancement(
        instanceId: String,
        modelDir: String,
        modelType: String?,
        numThreads: Double?,
        provider: String?,
        debug: Boolean?,
        promise: Promise
    ) {
        try {
            val result = nativeDetectEnhancementModel(modelDir, modelType ?: "auto")
            if (result == null || result["success"] as? Boolean != true) {
                val reason = result?.get("error") as? String ?: "Failed to detect enhancement model"
                promise.reject("ENHANCEMENT_INIT_ERROR", reason)
                return
            }
            val modelTypeStr = result["modelType"] as? String ?: "gtcrn"
            val modelConfig = buildModelConfig(modelTypeStr, parsePaths(result), numThreads, provider, debug)
            if (modelConfig == null) {
                promise.reject("ENHANCEMENT_INIT_ERROR", "Unsupported enhancement model type: $modelTypeStr")
                return
            }

            // Re-initialising an existing id frees the previous native model before the new one loads.
            val inst = instances.getOrPut(instanceId) { EnhancementInstance() }
            inst.release()
            val denoiser = OfflineSpeechDenoiser(
                config = OfflineSpeechDenoiserConfig(model = modelConfig)
            )
            inst.denoiser = denoiser

            val out = Arguments.createMap()
            out.putBoolean("success", true)
            out.putArray("detectedModels", toDetectedModelsArray(result["detectedModels"]))
            out.putString("modelType", modelTypeStr)
            out.putInt("sampleRate", denoiser.sampleRate)
            promise.resolve(out)
        } catch (e: Exception) {
            Log.e("SherpaOnnxEnhancement", "Failed to initialize enhancement", e)
            promise.reject("ENHANCEMENT_INIT_ERROR", "Failed to initialize enhancement: ${e.message}", e)
        }
    }

    /** Denoises an in-memory buffer with the offline instance; resolves `{ samples, sampleRate }`. */
    fun enhanceSamples(
        instanceId: String,
        samples: ReadableArray,
        sampleRate: Double,
        promise: Promise
    ) {
        val denoiser = offlineDenoiserOrReject(instanceId, promise) ?: return
        try {
            val audio = denoiser.run(readableArrayToFloatArray(samples), sampleRate.toInt())
            promise.resolve(toEnhancedAudioMap(audio))
        } catch (e: Exception) {
            promise.reject("ENHANCEMENT_ERROR", "Failed to enhance samples: ${e.message}", e)
        }
    }

    /**
     * Denoises a WAV file (plain path or `content://` URI) with the offline instance.
     *
     * When [outputPath] is non-blank the enhanced audio is also written there, and the promise is
     * rejected if that write fails. Resolves with `{ samples, sampleRate }`. Any temporary copy of a
     * `content://` input is deleted afterwards.
     */
    fun enhanceFile(
        instanceId: String,
        inputPath: String,
        outputPath: String?,
        promise: Promise
    ) {
        val denoiser = offlineDenoiserOrReject(instanceId, promise) ?: return

        var tmpInput: File? = null
        try {
            val (resolvedInputPath, tmp) = copyContentUriToTemp(inputPath, "enhancement_in")
            tmpInput = tmp
            val wave = WaveReader.readWave(resolvedInputPath)
            val audio = denoiser.run(wave.samples, wave.sampleRate)
            // DenoisedAudio.save returns a Boolean success flag; resolving without checking it would
            // report success even though no output file was written.
            if (!outputPath.isNullOrBlank() && !audio.save(outputPath)) {
                promise.reject(
                    "ENHANCEMENT_ERROR",
                    "Failed to enhance file: could not write output to $outputPath"
                )
                return
            }
            promise.resolve(toEnhancedAudioMap(audio))
        } catch (e: Exception) {
            promise.reject("ENHANCEMENT_ERROR", "Failed to enhance file: ${e.message}", e)
        } finally {
            tmpInput?.delete()
        }
    }

    /** Resolves the sample rate reported by the loaded offline denoiser. */
    fun getSampleRate(instanceId: String, promise: Promise) {
        val denoiser = offlineDenoiserOrReject(instanceId, promise) ?: return
        promise.resolve(denoiser.sampleRate)
    }

    /** Releases and forgets the offline instance (no-op if absent); always resolves null. */
    fun unloadEnhancement(instanceId: String, promise: Promise) {
        instances.remove(instanceId)?.release()
        promise.resolve(null)
    }

    /**
     * Detects the model in [modelDir] and (re)loads a streaming denoiser under [instanceId].
     *
     * Resolves with `{ success: true, sampleRate, frameShiftInSamples }`. Rejects with
     * `ONLINE_ENHANCEMENT_INIT_ERROR` when detection fails, the model type is unsupported, or loading
     * throws.
     */
    fun initializeOnlineEnhancement(
        instanceId: String,
        modelDir: String,
        modelType: String?,
        numThreads: Double?,
        provider: String?,
        debug: Boolean?,
        promise: Promise
    ) {
        try {
            val result = nativeDetectEnhancementModel(modelDir, modelType ?: "auto")
            if (result == null || result["success"] as? Boolean != true) {
                val reason = result?.get("error") as? String ?: "Failed to detect enhancement model"
                promise.reject("ONLINE_ENHANCEMENT_INIT_ERROR", reason)
                return
            }
            val modelTypeStr = result["modelType"] as? String ?: "gtcrn"
            val modelConfig = buildModelConfig(modelTypeStr, parsePaths(result), numThreads, provider, debug)
            if (modelConfig == null) {
                promise.reject("ONLINE_ENHANCEMENT_INIT_ERROR", "Unsupported enhancement model type: $modelTypeStr")
                return
            }

            // Re-initialising an existing id frees the previous native model before the new one loads.
            val inst = onlineInstances.getOrPut(instanceId) { OnlineEnhancementInstance() }
            inst.release()
            val denoiser = OnlineSpeechDenoiser(
                config = OnlineSpeechDenoiserConfig(model = modelConfig)
            )
            inst.denoiser = denoiser

            val out = Arguments.createMap()
            out.putBoolean("success", true)
            out.putInt("sampleRate", denoiser.sampleRate)
            out.putInt("frameShiftInSamples", denoiser.frameShiftInSamples)
            promise.resolve(out)
        } catch (e: Exception) {
            Log.e("SherpaOnnxEnhancement", "Failed to initialize online enhancement", e)
            promise.reject("ONLINE_ENHANCEMENT_INIT_ERROR", "Failed to initialize online enhancement: ${e.message}", e)
        }
    }

    /** Feeds one chunk to the streaming denoiser; resolves with whatever enhanced audio it returns. */
    fun feedSamples(
        instanceId: String,
        samples: ReadableArray,
        sampleRate: Double,
        promise: Promise
    ) {
        val denoiser = onlineDenoiserOrReject(instanceId, promise) ?: return
        try {
            val audio = denoiser.run(readableArrayToFloatArray(samples), sampleRate.toInt())
            promise.resolve(toEnhancedAudioMap(audio))
        } catch (e: Exception) {
            promise.reject("ONLINE_ENHANCEMENT_ERROR", "Failed to feed enhancement samples: ${e.message}", e)
        }
    }

    /** Flushes the streaming denoiser and resolves with the returned remaining audio. */
    fun flushOnline(instanceId: String, promise: Promise) {
        val denoiser = onlineDenoiserOrReject(instanceId, promise) ?: return
        try {
            promise.resolve(toEnhancedAudioMap(denoiser.flush()))
        } catch (e: Exception) {
            promise.reject("ONLINE_ENHANCEMENT_ERROR", "Failed to flush online enhancement: ${e.message}", e)
        }
    }

    /** Resets the streaming denoiser's state; resolves null. */
    fun resetOnline(instanceId: String, promise: Promise) {
        val denoiser = onlineDenoiserOrReject(instanceId, promise) ?: return
        try {
            denoiser.reset()
            promise.resolve(null)
        } catch (e: Exception) {
            promise.reject("ONLINE_ENHANCEMENT_ERROR", "Failed to reset online enhancement: ${e.message}", e)
        }
    }

    /** Releases and forgets the streaming instance (no-op if absent); always resolves null. */
    fun unloadOnline(instanceId: String, promise: Promise) {
        onlineInstances.remove(instanceId)?.release()
        promise.resolve(null)
    }
}
|
|
@@ -56,6 +56,10 @@ class SherpaOnnxModule(reactContext: ReactApplicationContext) :
|
|
|
56
56
|
{ instanceId, requestId, message -> emitTtsStreamError(instanceId, requestId, message) },
|
|
57
57
|
{ instanceId, requestId, cancelled -> emitTtsStreamEnd(instanceId, requestId, cancelled) }
|
|
58
58
|
)
|
|
59
|
+
/** Speech-enhancement bridge; model detection is routed to the companion's JNI function. */
private val enhancementHelper = SherpaOnnxEnhancementHelper(
    context = reactApplicationContext,
    nativeDetectEnhancementModel = { dir, type -> Companion.nativeDetectEnhancementModel(dir, type) }
)
private val archiveHelper: SherpaOnnxArchiveHelper = SherpaOnnxArchiveHelper()
private var pcmCapture: SherpaOnnxPcmCapture? = null
|
|
61
65
|
|
|
@@ -69,6 +73,7 @@ class SherpaOnnxModule(reactContext: ReactApplicationContext) :
|
|
|
69
73
|
pcmCapture = null
|
|
70
74
|
onlineSttHelper.shutdown()
|
|
71
75
|
ttsHelper.shutdown()
|
|
76
|
+
enhancementHelper.shutdown()
|
|
72
77
|
}
|
|
73
78
|
|
|
74
79
|
/**
|
|
@@ -1059,6 +1064,103 @@ class SherpaOnnxModule(reactContext: ReactApplicationContext) :
|
|
|
1059
1064
|
ttsHelper.unloadTts(instanceId, promise)
|
|
1060
1065
|
}
|
|
1061
1066
|
|
|
1067
|
+
// ==================== Speech Enhancement Methods ====================

/** Bridge entry: reports which enhancement model(s) live under [modelDir]; delegates to the helper. */
override fun detectEnhancementModel(
    modelDir: String,
    modelType: String?,
    promise: Promise
) = enhancementHelper.detectEnhancementModel(
    modelDir = modelDir,
    modelType = modelType,
    promise = promise
)
|
|
1076
|
+
|
|
1077
|
+
/** Bridge entry: loads an offline denoiser under [instanceId]; delegates to the helper. */
override fun initializeEnhancement(
    instanceId: String,
    modelDir: String,
    modelType: String?,
    numThreads: Double?,
    provider: String?,
    debug: Boolean?,
    promise: Promise
) = enhancementHelper.initializeEnhancement(
    instanceId = instanceId,
    modelDir = modelDir,
    modelType = modelType,
    numThreads = numThreads,
    provider = provider,
    debug = debug,
    promise = promise
)
|
|
1096
|
+
|
|
1097
|
+
/** Bridge entry: denoises the audio file at [inputPath], optionally writing to [outputPath]. */
override fun enhanceFile(
    instanceId: String,
    inputPath: String,
    outputPath: String?,
    promise: Promise
) = enhancementHelper.enhanceFile(
    instanceId = instanceId,
    inputPath = inputPath,
    outputPath = outputPath,
    promise = promise
)
|
|
1105
|
+
|
|
1106
|
+
/** Bridge entry: denoises an in-memory sample buffer with the offline instance. */
override fun enhanceSamples(
    instanceId: String,
    samples: ReadableArray,
    sampleRate: Double,
    promise: Promise
) = enhancementHelper.enhanceSamples(
    instanceId = instanceId,
    samples = samples,
    sampleRate = sampleRate,
    promise = promise
)
|
|
1114
|
+
|
|
1115
|
+
/** Bridge entry: resolves the offline denoiser's sample rate. */
override fun getEnhancementSampleRate(instanceId: String, promise: Promise) =
    enhancementHelper.getSampleRate(instanceId = instanceId, promise = promise)
|
|
1118
|
+
|
|
1119
|
+
/** Bridge entry: frees the offline denoiser registered under [instanceId]. */
override fun unloadEnhancement(instanceId: String, promise: Promise) =
    enhancementHelper.unloadEnhancement(instanceId = instanceId, promise = promise)
|
|
1122
|
+
|
|
1123
|
+
/** Bridge entry: loads a streaming denoiser under [instanceId]; delegates to the helper. */
override fun initializeOnlineEnhancement(
    instanceId: String,
    modelDir: String,
    modelType: String?,
    numThreads: Double?,
    provider: String?,
    debug: Boolean?,
    promise: Promise
) = enhancementHelper.initializeOnlineEnhancement(
    instanceId = instanceId,
    modelDir = modelDir,
    modelType = modelType,
    numThreads = numThreads,
    provider = provider,
    debug = debug,
    promise = promise
)
|
|
1142
|
+
|
|
1143
|
+
/** Bridge entry: pushes one chunk of samples through the streaming denoiser. */
override fun feedEnhancementSamples(
    instanceId: String,
    samples: ReadableArray,
    sampleRate: Double,
    promise: Promise
) = enhancementHelper.feedSamples(
    instanceId = instanceId,
    samples = samples,
    sampleRate = sampleRate,
    promise = promise
)
|
|
1151
|
+
|
|
1152
|
+
/** Bridge entry: flushes the streaming denoiser and returns its remaining output. */
override fun flushOnlineEnhancement(instanceId: String, promise: Promise) =
    enhancementHelper.flushOnline(instanceId = instanceId, promise = promise)
|
|
1155
|
+
|
|
1156
|
+
/** Bridge entry: resets the streaming denoiser's internal state. */
override fun resetOnlineEnhancement(instanceId: String, promise: Promise) =
    enhancementHelper.resetOnline(instanceId = instanceId, promise = promise)
|
|
1159
|
+
|
|
1160
|
+
/** Bridge entry: frees the streaming denoiser registered under [instanceId]. */
override fun unloadOnlineEnhancement(instanceId: String, promise: Promise) =
    enhancementHelper.unloadOnline(instanceId = instanceId, promise = promise)
|
|
1163
|
+
|
|
1062
1164
|
/**
|
|
1063
1165
|
* Save TTS audio samples to a WAV file.
|
|
1064
1166
|
*/
|
|
@@ -1256,6 +1358,10 @@ class SherpaOnnxModule(reactContext: ReactApplicationContext) :
|
|
|
1256
1358
|
@JvmStatic private external fun nativeDetectTtsModel(modelDir: String, modelType: String): HashMap<String, Any>?

/**
 * JNI model detection for speech enhancement.
 *
 * The returned HashMap carries `success`, `error`, `detectedModels`, `modelType` and `paths`;
 * the function may also return null.
 */
@JvmStatic private external fun nativeDetectEnhancementModel(modelDir: String, modelType: String): HashMap<String, Any>?
|
|
1364
|
+
|
|
1259
1365
|
/** Convert arbitrary audio file to requested format (e.g. "mp3", "flac", "wav").
|
|
1260
1366
|
* outputSampleRateHz: for MP3 use 32000/44100/48000, 0 = default 44100. Ignored for WAV/FLAC.
|
|
1261
1367
|
* Returns empty string on success, or an error message otherwise. Requires FFmpeg prebuilts when called on Android.
|
|
@@ -111,6 +111,46 @@ internal class SherpaOnnxTtsHelper(
|
|
|
111
111
|
}
|
|
112
112
|
}
|
|
113
113
|
|
|
114
|
+
/**
 * Boxes [n] into a real [java.lang.Integer] for the sherpa-onnx JNI TTS callback.
 *
 * libsherpa-onnx-jni looks the callback up as `invoke` with descriptor `([F)Ljava/lang/Integer;`
 * (see sherpa-onnx `offline-tts.cc`, CallCallback). A plain Kotlin `(FloatArray) -> Int` callback does
 * not expose that exact method, so GetMethodID fails and the JNI aborts. Callback objects therefore
 * implement `Function1<FloatArray, java.lang.Integer>` and return values produced by this helper.
 * Kotlin [Int] must not be used here: it maps to the primitive `I` return type.
 */
@Suppress("DEPRECATION")
private fun boxForTtsJni(n: Int): java.lang.Integer = java.lang.Integer(n)
|
|
123
|
+
|
|
124
|
+
/**
 * Builds the callback used for non-streaming TTS generation.
 *
 * Each chunk's length is appended to [sentenceChunkSizes] and echoed back, boxed, to the native side.
 * The declared type is `Function1<FloatArray, Int>` only to satisfy the Kotlin API; the object actually
 * returns [java.lang.Integer] so the JNI method lookup succeeds.
 */
@Suppress("UNCHECKED_CAST")
private fun ttsChunkCallbackForJni(
    sentenceChunkSizes: MutableList<Int>
): kotlin.Function1<FloatArray, Int> {
    val callback = object : kotlin.jvm.functions.Function1<FloatArray, java.lang.Integer> {
        override fun invoke(chunk: FloatArray): java.lang.Integer {
            val chunkLength = chunk.size
            sentenceChunkSizes.add(chunkLength)
            return boxForTtsJni(chunkLength)
        }
    }
    return callback as kotlin.Function1<FloatArray, Int>
}
|
|
137
|
+
|
|
138
|
+
/**
 * Builds the callback used for streaming TTS generation.
 *
 * While [cancelled] is false, each chunk is forwarded to [onChunk] and its length is returned. Once
 * [cancelled] is set, chunks are dropped and 0 is returned instead (presumably telling sherpa-onnx to
 * stop; confirm against the native CallCallback). The declared `Function1<FloatArray, Int>` type hides
 * the boxed [java.lang.Integer] return that the JNI requires.
 */
@Suppress("UNCHECKED_CAST")
private fun ttsStreamChunkCallbackForJni(
    cancelled: AtomicBoolean,
    onChunk: (FloatArray) -> Unit
): kotlin.Function1<FloatArray, Int> {
    val callback = object : kotlin.jvm.functions.Function1<FloatArray, java.lang.Integer> {
        override fun invoke(chunk: FloatArray): java.lang.Integer {
            val produced = if (cancelled.get()) {
                0
            } else {
                onChunk(chunk)
                chunk.size
            }
            return boxForTtsJni(produced)
        }
    }
    return callback as kotlin.Function1<FloatArray, Int>
}
|
|
153
|
+
|
|
114
154
|
/**
 * Dedicated single background thread for TTS initialisation, so the React Native bridge thread is not
 * blocked (this avoids Inspector / dev WebSocket races in debug builds).
 */
private val ttsInitExecutor: java.util.concurrent.ExecutorService = Executors.newSingleThreadExecutor()
|
|
116
156
|
|
|
@@ -453,6 +493,7 @@ internal class SherpaOnnxTtsHelper(
|
|
|
453
493
|
}
|
|
454
494
|
val sid = getSid(options)
|
|
455
495
|
val speed = getSpeed(options)
|
|
496
|
+
val sentenceChunkSizes = mutableListOf<Int>()
|
|
456
497
|
val audio = when {
|
|
457
498
|
hasReferenceAudio(options) && (inst.isZipvoice || inst.isPocket) -> {
|
|
458
499
|
if (inst.isZipvoice) {
|
|
@@ -467,7 +508,11 @@ internal class SherpaOnnxTtsHelper(
|
|
|
467
508
|
}
|
|
468
509
|
}
|
|
469
510
|
val config = parseGenerationConfig(options) ?: GenerationConfig(speed = speed, sid = sid)
|
|
470
|
-
inst.tts!!.
|
|
511
|
+
inst.tts!!.generateWithConfigAndCallback(
|
|
512
|
+
text,
|
|
513
|
+
config,
|
|
514
|
+
ttsChunkCallbackForJni(sentenceChunkSizes)
|
|
515
|
+
)
|
|
471
516
|
}
|
|
472
517
|
hasReferenceAudio(options) -> {
|
|
473
518
|
Log.e("SherpaOnnxTts", "TTS_GENERATE_ERROR: Reference audio is not supported for this TTS model type")
|
|
@@ -485,12 +530,15 @@ internal class SherpaOnnxTtsHelper(
|
|
|
485
530
|
)
|
|
486
531
|
return
|
|
487
532
|
}
|
|
488
|
-
else ->
|
|
489
|
-
|
|
533
|
+
else -> {
|
|
534
|
+
val tts = inst.tts
|
|
535
|
+
if (tts == null) {
|
|
490
536
|
Log.e("SherpaOnnxTts", "TTS_GENERATE_ERROR: TTS not initialized")
|
|
491
537
|
promise.reject("TTS_GENERATE_ERROR", "TTS not initialized")
|
|
492
538
|
return
|
|
493
539
|
}
|
|
540
|
+
tts.generateWithCallback(text, sid, speed, ttsChunkCallbackForJni(sentenceChunkSizes))
|
|
541
|
+
}
|
|
494
542
|
}
|
|
495
543
|
val map = Arguments.createMap()
|
|
496
544
|
val samplesArray = Arguments.createArray()
|
|
@@ -564,18 +612,23 @@ internal class SherpaOnnxTtsHelper(
|
|
|
564
612
|
when {
|
|
565
613
|
hasReferenceAudio(options) && inst.isPocket -> {
|
|
566
614
|
val config = parseGenerationConfig(options) ?: GenerationConfig(speed = speed, sid = sid)
|
|
567
|
-
inst.tts!!.generateWithConfigAndCallback(
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
chunk
|
|
571
|
-
|
|
615
|
+
inst.tts!!.generateWithConfigAndCallback(
|
|
616
|
+
text,
|
|
617
|
+
config,
|
|
618
|
+
ttsStreamChunkCallbackForJni(inst.ttsStreamCancelled) { chunk ->
|
|
619
|
+
emitChunk(instanceId, requestId, chunk, sampleRate, 0f, false)
|
|
620
|
+
}
|
|
621
|
+
)
|
|
572
622
|
}
|
|
573
623
|
else -> {
|
|
574
|
-
inst.tts!!.generateWithCallback(
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
624
|
+
inst.tts!!.generateWithCallback(
|
|
625
|
+
text,
|
|
626
|
+
sid,
|
|
627
|
+
speed,
|
|
628
|
+
ttsStreamChunkCallbackForJni(inst.ttsStreamCancelled) { chunk ->
|
|
629
|
+
emitChunk(instanceId, requestId, chunk, sampleRate, 0f, false)
|
|
630
|
+
}
|
|
631
|
+
)
|
|
579
632
|
}
|
|
580
633
|
}
|
|
581
634
|
if (!inst.ttsStreamCancelled.get()) {
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
asset_name,license_type,commercial_use,confidence,detection_source,license_file
|
|
2
|
+
dpdfnet2.onnx,apache-2.0,yes,high,manual,https://huggingface.co/Ceva-IP/DPDFNet/tree/main/onnx
|
|
3
|
+
dpdfnet2_48khz_hr.onnx,apache-2.0,yes,high,manual,https://huggingface.co/Ceva-IP/DPDFNet/tree/main/onnx
|
|
4
|
+
dpdfnet4.onnx,apache-2.0,yes,high,manual,https://huggingface.co/Ceva-IP/DPDFNet/tree/main/onnx
|
|
5
|
+
dpdfnet8.onnx,apache-2.0,yes,high,manual,https://huggingface.co/Ceva-IP/DPDFNet/tree/main/onnx
|
|
6
|
+
dpdfnet_baseline.onnx,apache-2.0,yes,high,manual,https://huggingface.co/Ceva-IP/DPDFNet/tree/main/onnx
|
|
7
|
+
gtcrn_simple.onnx,mit,yes,high,manual,https://github.com/Xiaobin-Rong/gtcrn/tree/main
|
package/ios/SherpaOnnx+Assets.mm
CHANGED
|
@@ -319,6 +319,11 @@ static void collectModelFolderNames(NSFileManager *fileManager, NSString *path,
|
|
|
319
319
|
return @"tts";
|
|
320
320
|
}
|
|
321
321
|
|
|
322
|
+
BOOL isEnhancement = [name containsString:@"gtcrn"] || [name containsString:@"dpdfnet"];
|
|
323
|
+
if (isEnhancement) {
|
|
324
|
+
return @"enhancement";
|
|
325
|
+
}
|
|
326
|
+
|
|
322
327
|
return @"unknown";
|
|
323
328
|
}
|
|
324
329
|
|