react-native-sherpa-onnx 0.3.9 → 0.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70) hide show
  1. package/README.md +17 -4
  2. package/SherpaOnnx.podspec +1 -0
  3. package/android/prebuilt-download.gradle +67 -27
  4. package/android/prebuilt-versions.gradle +1 -1
  5. package/android/src/main/assets/model_licenses/speech-enhancement-models-license-status.csv +7 -0
  6. package/android/src/main/cpp/CMakeLists.txt +3 -0
  7. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-enhancement-wrapper.cpp +68 -0
  8. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-enhancement-wrapper.h +17 -0
  9. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-enhancement.cpp +119 -0
  10. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect.h +31 -0
  11. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-validate-enhancement.cpp +68 -0
  12. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-validate-enhancement.h +30 -0
  13. package/android/src/main/cpp/jni/module/sherpa-onnx-module-jni.cpp +21 -0
  14. package/android/src/main/java/com/sherpaonnx/SherpaOnnxAssetHelper.kt +6 -0
  15. package/android/src/main/java/com/sherpaonnx/SherpaOnnxEnhancementHelper.kt +377 -0
  16. package/android/src/main/java/com/sherpaonnx/SherpaOnnxModule.kt +106 -0
  17. package/android/src/main/java/com/sherpaonnx/SherpaOnnxTtsHelper.kt +66 -13
  18. package/ios/Resources/model_licenses/speech-enhancement-models-license-status.csv +7 -0
  19. package/ios/SherpaOnnx+Assets.mm +5 -0
  20. package/ios/SherpaOnnx+Enhancement.mm +435 -0
  21. package/ios/enhancement/sherpa-onnx-enhancement-wrapper.h +85 -0
  22. package/ios/enhancement/sherpa-onnx-enhancement-wrapper.mm +218 -0
  23. package/ios/model_detect/sherpa-onnx-model-detect-enhancement.mm +92 -0
  24. package/ios/model_detect/sherpa-onnx-model-detect.h +23 -0
  25. package/ios/model_detect/sherpa-onnx-validate-enhancement.h +30 -0
  26. package/ios/model_detect/sherpa-onnx-validate-enhancement.mm +69 -0
  27. package/lib/module/NativeSherpaOnnx.js.map +1 -1
  28. package/lib/module/download/localModels.js +2 -3
  29. package/lib/module/download/localModels.js.map +1 -1
  30. package/lib/module/download/paths.js +2 -1
  31. package/lib/module/download/paths.js.map +1 -1
  32. package/lib/module/enhancement/index.js +63 -48
  33. package/lib/module/enhancement/index.js.map +1 -1
  34. package/lib/module/enhancement/streaming.js +60 -0
  35. package/lib/module/enhancement/streaming.js.map +1 -0
  36. package/lib/module/enhancement/streamingTypes.js +4 -0
  37. package/lib/module/enhancement/streamingTypes.js.map +1 -0
  38. package/lib/module/enhancement/types.js +4 -0
  39. package/lib/module/enhancement/types.js.map +1 -0
  40. package/lib/module/licenses.js +9 -3
  41. package/lib/module/licenses.js.map +1 -1
  42. package/lib/typescript/src/NativeSherpaOnnx.d.ts +45 -0
  43. package/lib/typescript/src/NativeSherpaOnnx.d.ts.map +1 -1
  44. package/lib/typescript/src/download/localModels.d.ts.map +1 -1
  45. package/lib/typescript/src/download/paths.d.ts +2 -1
  46. package/lib/typescript/src/download/paths.d.ts.map +1 -1
  47. package/lib/typescript/src/enhancement/index.d.ts +9 -46
  48. package/lib/typescript/src/enhancement/index.d.ts.map +1 -1
  49. package/lib/typescript/src/enhancement/streaming.d.ts +6 -0
  50. package/lib/typescript/src/enhancement/streaming.d.ts.map +1 -0
  51. package/lib/typescript/src/enhancement/streamingTypes.d.ts +12 -0
  52. package/lib/typescript/src/enhancement/streamingTypes.d.ts.map +1 -0
  53. package/lib/typescript/src/enhancement/types.d.ts +31 -0
  54. package/lib/typescript/src/enhancement/types.d.ts.map +1 -0
  55. package/lib/typescript/src/licenses.d.ts.map +1 -1
  56. package/package.json +1 -1
  57. package/scripts/ci/check-model-csvs.sh +27 -2
  58. package/scripts/ci/collect_all_sherpa_model_streams.sh +3 -1
  59. package/scripts/ci/collect_one_sherpa_release_stream.sh +3 -1
  60. package/scripts/ci/sherpa_speech_enhancement_model_release_streams.json +13 -0
  61. package/scripts/ci/update_model_license_csv.sh +1 -1
  62. package/src/NativeSherpaOnnx.ts +71 -0
  63. package/src/download/localModels.ts +1 -3
  64. package/src/download/paths.ts +2 -1
  65. package/src/enhancement/index.ts +120 -58
  66. package/src/enhancement/streaming.ts +105 -0
  67. package/src/enhancement/streamingTypes.ts +14 -0
  68. package/src/enhancement/types.ts +36 -0
  69. package/src/licenses.ts +13 -2
  70. package/third_party/sherpa-onnx-prebuilt/ANDROID_RELEASE_TAG +1 -1
package/README.md CHANGED
@@ -92,6 +92,7 @@ Full step-by-step: [Download manager – Setup (iOS & Android)](docs/download-ma
92
92
  - [Supported Model Types](#supported-model-types)
93
93
  - [Speech-to-Text (STT) Models](#speech-to-text-stt-models)
94
94
  - [Text-to-Speech (TTS) Models](#text-to-speech-tts-models)
95
+ - [Speech Enhancement Models](#speech-enhancement-models)
95
96
  - [Documentation](#documentation)
96
97
  - [Requirements](#requirements)
97
98
  - [Breaking changes (upgrading to 0.3.0)](#breaking-changes-upgrading-to-030)
@@ -108,8 +109,8 @@ Full step-by-step: [Download manager – Setup (iOS & Android)](docs/download-ma
108
109
 
109
110
  | Platform | Version |
110
111
  |----------|---------|
111
- | Android | 1.12.31 |
112
- | iOS | 1.12.31 |
112
+ | Android | 1.12.34 |
113
+ | iOS | 1.12.34 |
113
114
 
114
115
  ## Feature Support
115
116
 
@@ -126,7 +127,7 @@ Full step-by-step: [Download manager – Setup (iOS & Android)](docs/download-ma
126
127
  | Model quantization | ✅ **Supported** | [Model setup](./docs/model-setup.md) | Automatic detection and preference for quantized (int8) models. |
127
128
  | Flexible model loading | ✅ **Supported** | [Model setup](./docs/model-setup.md) | Asset models, file system models, or auto-detection. |
128
129
  | TypeScript | ✅ **Supported** | — | Full type definitions included. |
129
- | Speech Enhancement | Not yet supported | [Enhancement](./docs/enhancement.md) | Scheduled for release 0.4.0 |
130
+ | Speech Enhancement | ✅ **Supported** | [Speech Enhancement](./docs/speech-enhancement.md) | API and initialization covered in docs. |
130
131
  | Speaker Diarization | ❌ Not yet supported | [Diarization](./docs/diarization.md) | Scheduled for release 0.5.0 |
131
132
  | Source Separation | ❌ Not yet supported | [Separation](./docs/separation.md) | Scheduled for release 0.6.0 |
132
133
  | VAD (Voice Activity Detection) | ❌ Not yet supported | [VAD](./docs/vad.md) | Scheduled for release 0.7.0 |
@@ -184,6 +185,18 @@ For **real-time (streaming) recognition** from a microphone or audio stream, use
184
185
 
185
186
  For **streaming TTS** (incremental generation, low latency), use `createStreamingTTS()` with supported model types. See [Streaming Text-to-Speech](./docs/tts-streaming.md).
186
187
 
188
+ ### Speech Enhancement Models
189
+
190
+ Speech enhancement improves noisy or degraded speech using ONNX models from the sherpa-onnx **speech-enhancement-models** release. Detection looks for **`.onnx`** filenames containing **`gtcrn`** or **`dpdfnet`** (case-insensitive). With **`'auto'`**, **GTCRN** is preferred when both are present in the same folder.
191
+
192
+ | Model Type | `modelType` Value | Description | Download Links |
193
+ | ------------ | ----------------- | --------------------------------------------------------------------------- | -------------------------------------------------------------------------------- |
194
+ | **Auto Detect** | `'auto'` | Picks **GTCRN** if a matching `.onnx` exists, otherwise **DPDFNet** if found. | n/a |
195
+ | **GTCRN** | `'gtcrn'` | Lightweight speech enhancement (e.g. `gtcrn_simple.onnx`). | [Download](https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models) |
196
+ | **DPDFNet** | `'dpdfnet'` | Deep speech enhancement variants (e.g. `dpdfnet2.onnx`, `dpdfnet4.onnx`, `dpdfnet8.onnx`, `dpdfnet_baseline.onnx`, `dpdfnet2_48khz_hr.onnx`). | [Download](https://github.com/k2-fsa/sherpa-onnx/releases/tag/speech-enhancement-models) |
197
+
198
+ APIs, batch vs online processing, and initialization are covered in [Speech Enhancement](./docs/speech-enhancement.md).
199
+
187
200
  ## Documentation
188
201
 
189
202
  - [Known issues](./docs/KNOWN_ISSUES.md) – SDK-facing notes (e.g. Pocket TTS cloning / cross-platform behavior)
@@ -195,7 +208,7 @@ For **streaming TTS** (incremental generation, low latency), use `createStreamin
195
208
  - [Execution provider support (QNN, NNAPI, XNNPACK, Core ML)](./docs/execution-providers.md) – Checking and using acceleration backends
196
209
  - [Voice Activity Detection (VAD)](./docs/vad.md)
197
210
  - [Speaker Diarization](./docs/diarization.md)
198
- - [Speech Enhancement](./docs/enhancement.md)
211
+ - [Speech Enhancement](./docs/speech-enhancement.md)
199
212
  - [Source Separation](./docs/separation.md)
200
213
  - [Model Setup](./docs/model-setup.md) – Bundled assets, Play Asset Delivery (PAD), model discovery APIs, and troubleshooting
201
214
  - [Model Download Manager](./docs/download-manager.md)
@@ -94,6 +94,7 @@ Pod::Spec.new do |s|
94
94
  "\"#{pod_root}/ios/model_detect\"",
95
95
  "\"#{pod_root}/ios/stt\"",
96
96
  "\"#{pod_root}/ios/tts\"",
97
+ "\"#{pod_root}/ios/enhancement\"",
97
98
  "\"#{pod_root}/ios/online_stt\"",
98
99
  "\"#{device_headers}\"",
99
100
  "\"#{simulator_headers}\""
@@ -19,10 +19,12 @@ def requiredFfmpegSoFiles = [
19
19
  'libavcodec.so', 'libavformat.so', 'libavutil.so', 'libswresample.so', 'libavfilter.so', 'libshine.so'
20
20
  ]
21
21
  def requiredSherpaOnnxSoFiles = [
22
- 'libsherpa-onnx-jni.so', 'libsherpa-onnx-c-api.so', 'libsherpa-onnx-cxx-api.so', 'libonnxruntime.so'
22
+ 'libsherpa-onnx-jni.so', 'libsherpa-onnx-c-api.so', 'libsherpa-onnx-cxx-api.so'
23
23
  ]
24
24
  def requiredLibarchiveSoFiles = ['libarchive.so', 'libzstd.so']
25
25
  def requiredOnnxruntimeJniSoFiles = ['libonnxruntime4j_jni.so']
26
+ /** Both from the same onnxruntime release (AAR or third_party); required for Java ORT + symbol consistency. */
27
+ def requiredOnnxruntimeBundleSoFiles = ['libonnxruntime4j_jni.so', 'libonnxruntime.so']
26
28
 
27
29
  def jniLibsDir = file("${project.projectDir}/src/main/jniLibs")
28
30
  def sherpaOnnxClassesDir = file("${project.buildDir}/sherpa-onnx-classes")
@@ -79,11 +81,12 @@ def hasLibarchiveHeaders = {
79
81
  return new File(project.projectDir, "src/main/cpp/include/libarchive/archive.h").exists()
80
82
  }
81
83
 
82
- def hasAllOnnxruntimeJniLibs = {
84
+ /** Both ORT .so present per ABI (sherpa-onnx AAR no longer ships libonnxruntime.so). */
85
+ def hasAllOnnxruntimeBundleLibs = {
83
86
  for (abi in requiredAbis) {
84
87
  def dir = new File(jniLibsDir, abi)
85
88
  if (!dir.exists()) return false
86
- for (soName in requiredOnnxruntimeJniSoFiles) {
89
+ for (soName in requiredOnnxruntimeBundleSoFiles) {
87
90
  if (!new File(dir, soName).exists()) return false
88
91
  }
89
92
  }
@@ -163,14 +166,11 @@ project.tasks.register("downloadNativeLibsIfNeeded") {
163
166
  def libarchiveHdrsOk = hasLibarchiveHeaders()
164
167
  def libarchiveNeedsUpdate = !libarchiveLibsOk || !libarchiveHdrsOk || storedLibarchiveVersion == null || storedLibarchiveVersion != currentLibarchiveVersion
165
168
 
166
- def ortJniOk = hasAllOnnxruntimeJniLibs()
167
-
168
169
  println "[prebuilt] Resolution order: (1) THIRD_PARTY (2) LOCAL_SDK (3) MAVEN_AAR (4) GITHUB_RELEASE — see docs/PREBUILT_RESOLUTION.md"
169
170
 
170
171
  def sherpaResolved = false
171
172
  def ffmpegResolved = false
172
173
  def libarchiveResolved = false
173
- def ortJniResolved = ortJniOk
174
174
 
175
175
  // =====================================================================
176
176
  // sherpa-onnx: JNI (.so) + C headers
@@ -374,13 +374,55 @@ project.tasks.register("downloadNativeLibsIfNeeded") {
374
374
  }
375
375
 
376
376
  // =====================================================================
377
- // onnxruntime: JNI bridge only (libonnxruntime4j_jni.so)
378
- // libonnxruntime.so comes from sherpa-onnx prebuilts; not extracted here.
377
+ // onnxruntime: libonnxruntime.so + libonnxruntime4j_jni.so (same ORT build)
378
+ //
379
+ // Java OrtEnvironment loads libonnxruntime4j_jni.so, which dlopen's libonnxruntime.so.
380
+ // Sherpa-onnx AAR does not ship libonnxruntime.so; both come from the onnxruntime AAR
381
+ // or a full third_party/onnxruntime_prebuilt/android bundle.
379
382
  // =====================================================================
380
- if (!ortJniOk) {
381
- // Stage 1: THIRD_PARTY
382
- def tpJni = hasAllLibsUnder(thirdPartyOrtDir, requiredOnnxruntimeJniSoFiles)
383
- if (tpJni) {
383
+ def ortMatchedPairCopied = false
384
+ if (hasAllLibsUnder(thirdPartyOrtDir, requiredOnnxruntimeBundleSoFiles)) {
385
+ requiredAbis.each { abi ->
386
+ copy {
387
+ from new File(thirdPartyOrtDir, "jni/${abi}")
388
+ include 'libonnxruntime4j_jni.so', 'libonnxruntime.so'
389
+ into new File(jniLibsDir, abi)
390
+ }
391
+ }
392
+ ortMatchedPairCopied = true
393
+ println "[onnxruntime] libonnxruntime.so + libonnxruntime4j_jni.so (matched pair) .... THIRD_PARTY ${thirdPartyOrtDir}"
394
+ }
395
+ if (!ortMatchedPairCopied) {
396
+ try {
397
+ def aarFiles = project.configurations.onnxruntimeAar.files
398
+ if (!aarFiles.isEmpty()) {
399
+ downloadDir.mkdirs()
400
+ def aar = aarFiles.iterator().next()
401
+ def aarExtractDir = new File(downloadDir, "onnxruntime-aar-extract")
402
+ if (aarExtractDir.exists()) aarExtractDir.deleteDir()
403
+ aarExtractDir.mkdirs()
404
+ copy { from zipTree(aar); into aarExtractDir }
405
+ requiredAbis.each { abi ->
406
+ def aarJniDir = new File(aarExtractDir, "jni/${abi}")
407
+ if (aarJniDir.exists()) {
408
+ copy {
409
+ from aarJniDir
410
+ include 'libonnxruntime4j_jni.so', 'libonnxruntime.so'
411
+ into new File(jniLibsDir, abi)
412
+ }
413
+ }
414
+ }
415
+ ortMatchedPairCopied = true
416
+ println "[onnxruntime] libonnxruntime.so + libonnxruntime4j_jni.so (matched pair) .... MAVEN_AAR ${aar.name}"
417
+ }
418
+ } catch (Exception e) {
419
+ println "[onnxruntime] MAVEN_AAR matched-pair copy failed: ${e.message}"
420
+ }
421
+ }
422
+ if (!ortMatchedPairCopied && !hasAllOnnxruntimeBundleLibs()) {
423
+ // Legacy: third_party with JNI only (incomplete bundle)
424
+ def tpJniOnly = hasAllLibsUnder(thirdPartyOrtDir, requiredOnnxruntimeJniSoFiles)
425
+ if (tpJniOnly) {
384
426
  requiredAbis.each { abi ->
385
427
  copy {
386
428
  from new File(thirdPartyOrtDir, "jni/${abi}")
@@ -388,18 +430,16 @@ project.tasks.register("downloadNativeLibsIfNeeded") {
388
430
  into new File(jniLibsDir, abi)
389
431
  }
390
432
  }
391
- ortJniResolved = true
392
- println "[onnxruntime] libonnxruntime4j_jni.so .......... THIRD_PARTY ${thirdPartyOrtDir}"
433
+ println "[onnxruntime] libonnxruntime4j_jni.so only .... THIRD_PARTY ${thirdPartyOrtDir}"
434
+ println "[onnxruntime] WARN: no libonnxruntime.so in ort third_party — add full bundle or resolve onnxruntime AAR."
393
435
  }
394
-
395
- // Stage 3: MAVEN_AAR
396
- if (!ortJniResolved) {
436
+ if (!hasAllOnnxruntimeBundleLibs()) {
397
437
  try {
398
438
  def aarFiles = project.configurations.onnxruntimeAar.files
399
439
  if (!aarFiles.isEmpty()) {
400
440
  downloadDir.mkdirs()
401
441
  def aar = aarFiles.iterator().next()
402
- def aarExtractDir = new File(downloadDir, "onnxruntime-aar-extract")
442
+ def aarExtractDir = new File(downloadDir, "onnxruntime-aar-extract-jni-only")
403
443
  if (aarExtractDir.exists()) aarExtractDir.deleteDir()
404
444
  aarExtractDir.mkdirs()
405
445
  copy { from zipTree(aar); into aarExtractDir }
@@ -413,18 +453,18 @@ project.tasks.register("downloadNativeLibsIfNeeded") {
413
453
  }
414
454
  }
415
455
  }
416
- ortJniResolved = true
417
- println "[onnxruntime] libonnxruntime4j_jni.so .......... MAVEN_AAR ${aar.name}"
418
- println "[onnxruntime] install: per ABI --> ${jniLibsHuman}/<abi>/ (only JNI bridge; libonnxruntime.so from sherpa prebuilts)"
456
+ println "[onnxruntime] libonnxruntime4j_jni.so only .... MAVEN_AAR ${aar.name}"
457
+ println "[onnxruntime] WARN: AAR jni/<abi> missing libonnxruntime.so use a complete onnxruntime AAR."
419
458
  } else {
420
- println "[onnxruntime] MAVEN_AAR: onnxruntimeAar empty — libonnxruntime4j_jni.so still missing"
459
+ println "[onnxruntime] MAVEN_AAR: onnxruntimeAar empty — ORT native libs still missing"
421
460
  }
422
461
  } catch (Exception e) {
423
- println "[onnxruntime] MAVEN_AAR failed: ${e.message}"
462
+ println "[onnxruntime] MAVEN_AAR jni-only failed: ${e.message}"
424
463
  }
425
464
  }
426
- } else {
427
- println "[onnxruntime] libonnxruntime4j_jni.so (per ABI) .... LOCAL_SDK"
465
+ }
466
+ if (!hasAllOnnxruntimeBundleLibs()) {
467
+ println "[onnxruntime] WARN: libonnxruntime.so and/or libonnxruntime4j_jni.so missing after resolution. checkJniLibs will fail; use com.xdcobra.sherpa:onnxruntime or full third_party bundle."
428
468
  }
429
469
 
430
470
  // =====================================================================
@@ -598,10 +638,10 @@ project.tasks.register("checkJniLibs") {
598
638
  }
599
639
  }
600
640
  }
601
- requiredOnnxruntimeJniSoFiles.each { soName ->
641
+ requiredOnnxruntimeBundleSoFiles.each { soName ->
602
642
  def soFile = new File(dir, soName)
603
643
  if (!soFile.exists()) {
604
- throw new RuntimeException("Missing onnxruntime JNI bridge '${soName}' for ABI ${abi}. Ensure Maven com.xdcobra.sherpa:onnxruntime is available.")
644
+ throw new RuntimeException("Missing onnxruntime native library '${soName}' for ABI ${abi}. Ensure Maven com.xdcobra.sherpa:onnxruntime (libonnxruntime.so + libonnxruntime4j_jni.so) resolves.")
605
645
  }
606
646
  }
607
647
  }
@@ -45,7 +45,7 @@ println "[react-native-sherpa-onnx] libarchive version (extracted/used): ${libar
45
45
  def ortVersion = System.getenv('ORT_VERSION')
46
46
  if (!ortVersion) {
47
47
  def v = readVersionFromTagFile(new File(moduleRoot, 'third_party/onnxruntime_prebuilt/ANDROID_RELEASE_TAG'), 'ort-android-qnn-v')
48
- ortVersion = v ?: (project.hasProperty('ortVersion') ? project.ortVersion : '1.24.2-qnn2.43.1.260218')
48
+ ortVersion = v ?: (project.hasProperty('ortVersion') ? project.ortVersion : '1.24.4-qnn2.43.1.260218-1')
49
49
  }
50
50
  project.ext.ortVersion = ortVersion
51
51
  println "[react-native-sherpa-onnx] onnxruntime version (extracted/used): ${ortVersion}"
@@ -0,0 +1,7 @@
1
+ asset_name,license_type,commercial_use,confidence,detection_source,license_file
2
+ dpdfnet2.onnx,apache-2.0,yes,high,manual,https://huggingface.co/Ceva-IP/DPDFNet/tree/main/onnx
3
+ dpdfnet2_48khz_hr.onnx,apache-2.0,yes,high,manual,https://huggingface.co/Ceva-IP/DPDFNet/tree/main/onnx
4
+ dpdfnet4.onnx,apache-2.0,yes,high,manual,https://huggingface.co/Ceva-IP/DPDFNet/tree/main/onnx
5
+ dpdfnet8.onnx,apache-2.0,yes,high,manual,https://huggingface.co/Ceva-IP/DPDFNet/tree/main/onnx
6
+ dpdfnet_baseline.onnx,apache-2.0,yes,high,manual,https://huggingface.co/Ceva-IP/DPDFNet/tree/main/onnx
7
+ gtcrn_simple.onnx,mit,yes,high,manual,https://github.com/Xiaobin-Rong/gtcrn/tree/main
@@ -83,11 +83,14 @@ set(SOURCES
83
83
  jni/model_detect/sherpa-onnx-model-detect-helper.cpp
84
84
  jni/model_detect/sherpa-onnx-model-detect-stt.cpp
85
85
  jni/model_detect/sherpa-onnx-model-detect-tts.cpp
86
+ jni/model_detect/sherpa-onnx-model-detect-enhancement.cpp
86
87
  jni/model_detect/sherpa-onnx-validate-stt.cpp
87
88
  jni/model_detect/sherpa-onnx-validate-tts.cpp
89
+ jni/model_detect/sherpa-onnx-validate-enhancement.cpp
88
90
  jni/model_detect/sherpa-onnx-detect-jni-common.cpp
89
91
  jni/model_detect/sherpa-onnx-stt-wrapper.cpp
90
92
  jni/model_detect/sherpa-onnx-tts-wrapper.cpp
93
+ jni/model_detect/sherpa-onnx-enhancement-wrapper.cpp
91
94
  jni/audio/sherpa-onnx-audio-convert-jni.cpp
92
95
  crypto/sha256.cpp
93
96
  )
@@ -0,0 +1,68 @@
1
+ #include "sherpa-onnx-enhancement-wrapper.h"
2
+
3
+ #include "sherpa-onnx-detect-jni-common.h"
4
+
5
+ namespace sherpaonnx {
6
+ namespace {
7
+
8
+ const char* EnhancementModelKindToString(EnhancementModelKind k) {
9
+ switch (k) {
10
+ case EnhancementModelKind::kGtcrn:
11
+ return "gtcrn";
12
+ case EnhancementModelKind::kDpdfNet:
13
+ return "dpdfnet";
14
+ default:
15
+ return "unknown";
16
+ }
17
+ }
18
+
19
+ } // namespace
20
+
21
+ jobject EnhancementDetectResultToJava(
22
+ JNIEnv* env,
23
+ const EnhancementDetectResult& result
24
+ ) {
25
+ jclass mapClass = env->FindClass("java/util/HashMap");
26
+ if (!mapClass) return nullptr;
27
+ jmethodID mapInit = env->GetMethodID(mapClass, "<init>", "()V");
28
+ jmethodID mapPut =
29
+ env->GetMethodID(mapClass, "put",
30
+ "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");
31
+ if (!mapInit || !mapPut) {
32
+ env->DeleteLocalRef(mapClass);
33
+ return nullptr;
34
+ }
35
+ jobject map = env->NewObject(mapClass, mapInit);
36
+ env->DeleteLocalRef(mapClass);
37
+ if (!map) return nullptr;
38
+
39
+ PutBoolean(env, map, mapPut, "success", result.ok);
40
+ PutString(env, map, mapPut, "error", result.error);
41
+ PutString(env, map, mapPut, "modelType",
42
+ EnhancementModelKindToString(result.selectedKind));
43
+
44
+ jobject detectedList = BuildDetectedModelsList(env, result.detectedModels);
45
+ if (detectedList) {
46
+ jstring keyDetected = env->NewStringUTF("detectedModels");
47
+ env->CallObjectMethod(map, mapPut, keyDetected, detectedList);
48
+ env->DeleteLocalRef(keyDetected);
49
+ env->DeleteLocalRef(detectedList);
50
+ }
51
+
52
+ jclass hashMapClass = env->FindClass("java/util/HashMap");
53
+ if (hashMapClass) {
54
+ jobject pathsMap = env->NewObject(hashMapClass, mapInit);
55
+ env->DeleteLocalRef(hashMapClass);
56
+ if (pathsMap) {
57
+ PutString(env, pathsMap, mapPut, "model", result.paths.model);
58
+ jstring keyPaths = env->NewStringUTF("paths");
59
+ env->CallObjectMethod(map, mapPut, keyPaths, pathsMap);
60
+ env->DeleteLocalRef(keyPaths);
61
+ env->DeleteLocalRef(pathsMap);
62
+ }
63
+ }
64
+
65
+ return map;
66
+ }
67
+
68
+ } // namespace sherpaonnx
@@ -0,0 +1,17 @@
1
+ #ifndef SHERPA_ONNX_ENHANCEMENT_WRAPPER_H
2
+ #define SHERPA_ONNX_ENHANCEMENT_WRAPPER_H
3
+
4
+ #include <jni.h>
5
+
6
+ #include "sherpa-onnx-model-detect.h"
7
+
8
+ namespace sherpaonnx {
9
+
10
+ jobject EnhancementDetectResultToJava(
11
+ JNIEnv* env,
12
+ const EnhancementDetectResult& result
13
+ );
14
+
15
+ } // namespace sherpaonnx
16
+
17
+ #endif // SHERPA_ONNX_ENHANCEMENT_WRAPPER_H
@@ -0,0 +1,119 @@
1
+ #include "sherpa-onnx-model-detect.h"
2
+ #include "sherpa-onnx-model-detect-helper.h"
3
+ #include "sherpa-onnx-validate-enhancement.h"
4
+
5
+ #include <optional>
6
+ #include <string>
7
+ #include <vector>
8
+
9
+ namespace {
10
+
11
+ using namespace sherpaonnx::model_detect;
12
+
13
+ sherpaonnx::EnhancementModelKind ParseEnhancementModelType(const std::string& modelType) {
14
+ if (modelType == "gtcrn") return sherpaonnx::EnhancementModelKind::kGtcrn;
15
+ if (modelType == "dpdfnet") return sherpaonnx::EnhancementModelKind::kDpdfNet;
16
+ return sherpaonnx::EnhancementModelKind::kUnknown;
17
+ }
18
+
19
+ sherpaonnx::EnhancementDetectResult DetectEnhancementModelFromFiles(
20
+ const std::vector<FileEntry>& files,
21
+ const std::string& modelDir,
22
+ const std::string& modelType
23
+ ) {
24
+ sherpaonnx::EnhancementDetectResult result;
25
+
26
+ const std::string gtcrnModel =
27
+ FindOnnxByAnyToken(files, {"gtcrn"}, std::nullopt);
28
+ const std::string dpdfnetModel =
29
+ FindOnnxByAnyToken(files, {"dpdfnet"}, std::nullopt);
30
+
31
+ if (!gtcrnModel.empty()) {
32
+ result.detectedModels.push_back({"gtcrn", modelDir});
33
+ }
34
+ if (!dpdfnetModel.empty()) {
35
+ result.detectedModels.push_back({"dpdfnet", modelDir});
36
+ }
37
+
38
+ sherpaonnx::EnhancementModelKind selected = sherpaonnx::EnhancementModelKind::kUnknown;
39
+ if (modelType == "auto" || modelType.empty()) {
40
+ if (!gtcrnModel.empty()) {
41
+ selected = sherpaonnx::EnhancementModelKind::kGtcrn;
42
+ } else if (!dpdfnetModel.empty()) {
43
+ selected = sherpaonnx::EnhancementModelKind::kDpdfNet;
44
+ }
45
+ } else {
46
+ selected = ParseEnhancementModelType(modelType);
47
+ if (selected == sherpaonnx::EnhancementModelKind::kUnknown) {
48
+ result.error = "Enhancement: unknown model type: " + modelType;
49
+ return result;
50
+ }
51
+ }
52
+
53
+ switch (selected) {
54
+ case sherpaonnx::EnhancementModelKind::kGtcrn:
55
+ result.paths.model = gtcrnModel;
56
+ break;
57
+ case sherpaonnx::EnhancementModelKind::kDpdfNet:
58
+ result.paths.model = dpdfnetModel;
59
+ break;
60
+ default:
61
+ result.error = "Enhancement: no compatible model type detected in " +
62
+ modelDir;
63
+ return result;
64
+ }
65
+
66
+ auto validation =
67
+ sherpaonnx::ValidateEnhancementPaths(selected, result.paths, modelDir);
68
+ if (!validation.ok) {
69
+ result.error = validation.error;
70
+ return result;
71
+ }
72
+
73
+ result.selectedKind = selected;
74
+ result.ok = true;
75
+ return result;
76
+ }
77
+
78
+ } // namespace
79
+
80
+ namespace sherpaonnx {
81
+
82
+ using namespace model_detect;
83
+
84
+ EnhancementDetectResult DetectEnhancementModel(
85
+ const std::string& modelDir,
86
+ const std::string& modelType
87
+ ) {
88
+ EnhancementDetectResult result;
89
+
90
+ if (modelDir.empty()) {
91
+ result.error = "Enhancement: model directory is empty";
92
+ return result;
93
+ }
94
+ if (!FileExists(modelDir) || !IsDirectory(modelDir)) {
95
+ result.error =
96
+ "Enhancement: model directory does not exist or is not a directory: " +
97
+ modelDir;
98
+ return result;
99
+ }
100
+
101
+ const std::vector<model_detect::FileEntry> files = ListFilesRecursive(modelDir, 4);
102
+ return DetectEnhancementModelFromFiles(files, modelDir, modelType);
103
+ }
104
+
105
+ // Test-only: used by host-side model_detect_test; not used in production.
106
+ EnhancementDetectResult DetectEnhancementModelFromFileList(
107
+ const std::vector<model_detect::FileEntry>& files,
108
+ const std::string& modelDir,
109
+ const std::string& modelType
110
+ ) {
111
+ EnhancementDetectResult result;
112
+ if (modelDir.empty()) {
113
+ result.error = "Enhancement: model directory is empty";
114
+ return result;
115
+ }
116
+ return DetectEnhancementModelFromFiles(files, modelDir, modelType);
117
+ }
118
+
119
+ } // namespace sherpaonnx
@@ -43,6 +43,12 @@ enum class TtsModelKind {
43
43
  kSupertonic
44
44
  };
45
45
 
46
+ enum class EnhancementModelKind {  // Speech-enhancement model families handled by detection.
47
+ kUnknown,  // Nothing selected, or the requested type string was not recognized.
48
+ kGtcrn,  // GTCRN model; reported as "gtcrn".
49
+ kDpdfNet  // DPDFNet model; reported as "dpdfnet".
50
+ };
51
+
46
52
  struct SttModelPaths {
47
53
  std::string encoder;
48
54
  std::string decoder;
@@ -174,6 +180,10 @@ struct TtsModelPaths {
174
180
  std::string voiceStyle;
175
181
  };
176
182
 
183
+ struct EnhancementModelPaths {
184
+ std::string model;
185
+ };
186
+
177
187
  struct SttDetectResult {
178
188
  bool ok = false;
179
189
  std::string error;
@@ -195,6 +205,14 @@ struct TtsDetectResult {
195
205
  std::vector<std::string> lexiconLanguageCandidates;
196
206
  };
197
207
 
208
+ struct EnhancementDetectResult {  // Outcome of speech-enhancement model detection.
209
+ bool ok = false;  // Set to true only after a model was selected and its paths validated.
210
+ std::string error;  // Failure description; not written on the success path.
211
+ std::vector<DetectedModel> detectedModels;  // One entry per model family found; filled before selection, so it can be non-empty on failure.
212
+ EnhancementModelKind selectedKind = EnhancementModelKind::kUnknown;  // Chosen family; remains kUnknown unless detection succeeds.
213
+ EnhancementModelPaths paths;  // Model file chosen for the selected family.
214
+ };
215
+
198
216
  SttDetectResult DetectSttModel(
199
217
  const std::string& modelDir,
200
218
  const std::optional<bool>& preferInt8,
@@ -228,6 +246,19 @@ TtsDetectResult DetectTtsModelFromFileList(
228
246
  const std::string& modelType = "auto"
229
247
  );
230
248
 
249
+ EnhancementDetectResult DetectEnhancementModel(
250
+ const std::string& modelDir,
251
+ const std::string& modelType
252
+ );
253
+
254
+ /** Test-only: Like DetectEnhancementModel but takes a pre-built file list; no filesystem access.
255
+ * Only used by the host-side C++ test suite (test/cpp/model_detect_test.cpp). */
256
+ EnhancementDetectResult DetectEnhancementModelFromFileList(
257
+ const std::vector<model_detect::FileEntry>& files,
258
+ const std::string& modelDir,
259
+ const std::string& modelType = "auto"
260
+ );
261
+
231
262
  } // namespace sherpaonnx
232
263
 
233
264
  #endif // SHERPA_ONNX_MODEL_DETECT_H
@@ -0,0 +1,68 @@
1
+ #include "sherpa-onnx-validate-enhancement.h"
2
+
3
+ #include <cstddef>
4
+
5
+ namespace sherpaonnx {
6
+ namespace {
7
+
8
+ static const EnhancementFieldRequirement kGenericReqs[] = {  // Requirement table used for both GTCRN and DPDFNet.
9
+ {"model", &EnhancementModelPaths::model, true},  // The single model path is mandatory.
10
+ };
11
+
12
+ static const EnhancementFieldRequirement* GetRequirements(
13
+ EnhancementModelKind kind,
14
+ size_t& count
15
+ ) {
16
+ switch (kind) {
17
+ case EnhancementModelKind::kGtcrn:
18
+ case EnhancementModelKind::kDpdfNet:
19
+ count = std::size(kGenericReqs);
20
+ return kGenericReqs;
21
+ default:
22
+ count = 0;
23
+ return nullptr;
24
+ }
25
+ }
26
+
27
+ static const char* EnhancementKindToName(EnhancementModelKind kind) {
28
+ switch (kind) {
29
+ case EnhancementModelKind::kGtcrn:
30
+ return "GTCRN";
31
+ case EnhancementModelKind::kDpdfNet:
32
+ return "DPDFNet";
33
+ default:
34
+ return "Unknown";
35
+ }
36
+ }
37
+
38
+ } // namespace
39
+
40
+ EnhancementValidationResult ValidateEnhancementPaths(
41
+ EnhancementModelKind kind,
42
+ const EnhancementModelPaths& paths,
43
+ const std::string& modelDir
44
+ ) {
45
+ EnhancementValidationResult result;
46
+ size_t count = 0;
47
+ const auto* reqs = GetRequirements(kind, count);
48
+ if (!reqs) return result;
49
+
50
+ for (size_t i = 0; i < count; ++i) {
51
+ if (reqs[i].required && (paths.*(reqs[i].field)).empty()) {
52
+ result.missingRequired.push_back(reqs[i].fieldName);
53
+ }
54
+ }
55
+
56
+ if (!result.missingRequired.empty()) {
57
+ result.ok = false;
58
+ result.error = std::string("Enhancement ") + EnhancementKindToName(kind) +
59
+ ": missing required files in " + modelDir + ": ";
60
+ for (size_t i = 0; i < result.missingRequired.size(); ++i) {
61
+ if (i > 0) result.error += ", ";
62
+ result.error += result.missingRequired[i];
63
+ }
64
+ }
65
+ return result;
66
+ }
67
+
68
+ } // namespace sherpaonnx
@@ -0,0 +1,30 @@
1
+ #ifndef SHERPA_ONNX_VALIDATE_ENHANCEMENT_H
2
+ #define SHERPA_ONNX_VALIDATE_ENHANCEMENT_H
3
+
4
+ #include "sherpa-onnx-model-detect.h"
5
+ #include <string>
6
+ #include <vector>
7
+
8
+ namespace sherpaonnx {
9
+
10
+ struct EnhancementFieldRequirement {  // One row of a per-kind path requirement table.
11
+ const char* fieldName;  // Name reported when the field is missing.
12
+ std::string EnhancementModelPaths::* field;  // Pointer-to-member selecting the path to check.
13
+ bool required;  // When true, an empty path counts as missing.
14
+ };
15
+
16
+ struct EnhancementValidationResult {  // Result of validating enhancement model paths.
17
+ bool ok = true;  // Cleared when at least one required field is missing.
18
+ std::vector<std::string> missingRequired;  // Names of required fields whose path was empty.
19
+ std::string error;  // Combined message; only written when something is missing.
20
+ };
21
+
22
+ EnhancementValidationResult ValidateEnhancementPaths(
23
+ EnhancementModelKind kind,
24
+ const EnhancementModelPaths& paths,
25
+ const std::string& modelDir
26
+ );
27
+
28
+ } // namespace sherpaonnx
29
+
30
+ #endif // SHERPA_ONNX_VALIDATE_ENHANCEMENT_H
@@ -20,6 +20,7 @@
20
20
  #include "sherpa-onnx-model-detect.h"
21
21
  #include "sherpa-onnx-stt-wrapper.h"
22
22
  #include "sherpa-onnx-tts-wrapper.h"
23
+ #include "sherpa-onnx-enhancement-wrapper.h"
23
24
 
24
25
  extern "C" {
25
26
 
@@ -187,4 +188,24 @@ Java_com_sherpaonnx_SherpaOnnxModule_nativeDetectTtsModel(
187
188
  return sherpaonnx::TtsDetectResultToJava(env, result);
188
189
  }
189
190
 
191
+ // Detect enhancement model in directory. Returns HashMap with success, error, detectedModels, modelType, paths.
192
+ JNIEXPORT jobject JNICALL
193
+ Java_com_sherpaonnx_SherpaOnnxModule_nativeDetectEnhancementModel(
194
+ JNIEnv* env,
195
+ jobject /* this */,
196
+ jstring j_model_dir,
197
+ jstring j_model_type) {
198
+ const char* model_dir_c = env->GetStringUTFChars(j_model_dir, nullptr);
199
+ const char* model_type_c =
200
+ j_model_type ? env->GetStringUTFChars(j_model_type, nullptr) : nullptr;
201
+ std::string model_dir(model_dir_c ? model_dir_c : "");
202
+ std::string model_type(model_type_c ? model_type_c : "auto");
203
+ env->ReleaseStringUTFChars(j_model_dir, model_dir_c);
204
+ if (model_type_c) env->ReleaseStringUTFChars(j_model_type, model_type_c);
205
+
206
+ sherpaonnx::EnhancementDetectResult result =
207
+ sherpaonnx::DetectEnhancementModel(model_dir, model_type);
208
+ return sherpaonnx::EnhancementDetectResultToJava(env, result);
209
+ }
210
+
190
211
  } // extern "C"