react-native-sherpa-onnx 0.2.0 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175) hide show
  1. package/README.md +232 -236
  2. package/SherpaOnnx.podspec +68 -64
  3. package/android/build.gradle +182 -192
  4. package/android/codegen.gradle +57 -0
  5. package/android/prebuilt-download.gradle +428 -0
  6. package/android/prebuilt-versions.gradle +43 -0
  7. package/android/proguard-rules.pro +10 -0
  8. package/android/src/main/assets/testModels/add_mul_add.onnx +28 -0
  9. package/android/src/main/assets/testModels/nnapi_internal_uint8_support.onnx +0 -0
  10. package/android/src/main/assets/testModels/qnn_multi_ctx_embed.onnx +0 -0
  11. package/android/src/main/cpp/CMakeLists.txt +166 -129
  12. package/android/src/main/cpp/CMakePresets.json +54 -0
  13. package/android/src/main/cpp/crypto/sha256.cpp +174 -0
  14. package/android/src/main/cpp/crypto/sha256.h +16 -0
  15. package/android/src/main/cpp/jni/archive/sherpa-onnx-archive-helper.cpp +404 -0
  16. package/android/src/main/cpp/jni/archive/sherpa-onnx-archive-helper.h +56 -0
  17. package/android/src/main/cpp/jni/archive/sherpa-onnx-archive-jni.cpp +181 -0
  18. package/android/src/main/cpp/jni/audio/sherpa-onnx-audio-convert-jni.cpp +888 -0
  19. package/{ios → android/src/main/cpp/jni/model_detect}/sherpa-onnx-common.h +18 -18
  20. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-detect-jni-common.cpp +86 -0
  21. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-detect-jni-common.h +20 -0
  22. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-helper.cpp +423 -0
  23. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-helper.h +55 -0
  24. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-stt.cpp +399 -0
  25. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-tts.cpp +238 -0
  26. package/{ios → android/src/main/cpp/jni/model_detect}/sherpa-onnx-model-detect.h +122 -89
  27. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-stt-wrapper.cpp +99 -0
  28. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-stt-wrapper.h +16 -0
  29. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-tts-wrapper.cpp +78 -0
  30. package/android/src/main/cpp/jni/model_detect/sherpa-onnx-tts-wrapper.h +16 -0
  31. package/android/src/main/cpp/jni/module/sherpa-onnx-module-jni.cpp +190 -0
  32. package/android/src/main/cpp/jni/tts/sherpa-onnx-tts-zipvoice-jni.cpp +301 -0
  33. package/android/src/main/java/com/sherpaonnx/SherpaOnnxArchiveHelper.kt +94 -0
  34. package/android/src/main/java/com/sherpaonnx/{SherpaOnnxCoreHelper.kt → SherpaOnnxAssetHelper.kt} +350 -236
  35. package/android/src/main/java/com/sherpaonnx/SherpaOnnxModule.kt +791 -483
  36. package/android/src/main/java/com/sherpaonnx/SherpaOnnxSttHelper.kt +699 -109
  37. package/android/src/main/java/com/sherpaonnx/SherpaOnnxTtsHelper.kt +1123 -668
  38. package/android/src/main/java/com/sherpaonnx/ZipvoiceTtsWrapper.kt +187 -0
  39. package/ios/SherpaOnnx+Assets.h +11 -0
  40. package/ios/SherpaOnnx+Assets.mm +325 -0
  41. package/ios/SherpaOnnx+STT.mm +455 -118
  42. package/ios/SherpaOnnx+TTS.mm +1101 -712
  43. package/ios/SherpaOnnx.h +17 -6
  44. package/ios/SherpaOnnx.mm +206 -311
  45. package/ios/SherpaOnnx.xcconfig +19 -19
  46. package/ios/SherpaOnnxCoreMLHelper.swift +24 -0
  47. package/ios/archive/sherpa-onnx-archive-helper.h +21 -0
  48. package/ios/archive/sherpa-onnx-archive-helper.mm +296 -0
  49. package/ios/libarchive_darwin_config.h +153 -0
  50. package/{android/src/main/cpp/jni → ios/model_detect}/sherpa-onnx-common.h +18 -18
  51. package/ios/model_detect/sherpa-onnx-model-detect-helper.h +49 -0
  52. package/ios/model_detect/sherpa-onnx-model-detect-helper.mm +210 -0
  53. package/ios/model_detect/sherpa-onnx-model-detect-stt.mm +344 -0
  54. package/ios/model_detect/sherpa-onnx-model-detect-tts.mm +201 -0
  55. package/{android/src/main/cpp/jni → ios/model_detect}/sherpa-onnx-model-detect.h +117 -89
  56. package/ios/scripts/patch-libarchive-includes.sh +61 -0
  57. package/ios/scripts/setup-ios-libarchive.sh +98 -0
  58. package/ios/stt/sherpa-onnx-stt-wrapper.h +129 -0
  59. package/ios/stt/sherpa-onnx-stt-wrapper.mm +523 -0
  60. package/ios/{sherpa-onnx-tts-wrapper.h → tts/sherpa-onnx-tts-wrapper.h} +90 -85
  61. package/ios/{sherpa-onnx-tts-wrapper.mm → tts/sherpa-onnx-tts-wrapper.mm} +376 -345
  62. package/lib/module/NativeSherpaOnnx.js +3 -0
  63. package/lib/module/NativeSherpaOnnx.js.map +1 -1
  64. package/lib/module/audio/index.js +22 -0
  65. package/lib/module/audio/index.js.map +1 -0
  66. package/lib/module/diarization/index.js +1 -1
  67. package/lib/module/diarization/index.js.map +1 -1
  68. package/lib/module/download/ModelDownloadManager.js +918 -0
  69. package/lib/module/download/ModelDownloadManager.js.map +1 -0
  70. package/lib/module/download/extractTarBz2.js +53 -0
  71. package/lib/module/download/extractTarBz2.js.map +1 -0
  72. package/lib/module/download/index.js +6 -0
  73. package/lib/module/download/index.js.map +1 -0
  74. package/lib/module/download/validation.js +178 -0
  75. package/lib/module/download/validation.js.map +1 -0
  76. package/lib/module/enhancement/index.js +1 -1
  77. package/lib/module/enhancement/index.js.map +1 -1
  78. package/lib/module/index.js +41 -3
  79. package/lib/module/index.js.map +1 -1
  80. package/lib/module/separation/index.js +1 -1
  81. package/lib/module/separation/index.js.map +1 -1
  82. package/lib/module/stt/index.js +127 -60
  83. package/lib/module/stt/index.js.map +1 -1
  84. package/lib/module/stt/sttModelLanguages.js +512 -0
  85. package/lib/module/stt/sttModelLanguages.js.map +1 -0
  86. package/lib/module/stt/types.js +53 -1
  87. package/lib/module/stt/types.js.map +1 -1
  88. package/lib/module/tts/index.js +216 -289
  89. package/lib/module/tts/index.js.map +1 -1
  90. package/lib/module/tts/types.js +86 -1
  91. package/lib/module/tts/types.js.map +1 -1
  92. package/lib/module/types.js.map +1 -1
  93. package/lib/module/utils.js +86 -73
  94. package/lib/module/utils.js.map +1 -1
  95. package/lib/module/vad/index.js +1 -1
  96. package/lib/module/vad/index.js.map +1 -1
  97. package/lib/typescript/src/NativeSherpaOnnx.d.ts +192 -38
  98. package/lib/typescript/src/NativeSherpaOnnx.d.ts.map +1 -1
  99. package/lib/typescript/src/audio/index.d.ts +13 -0
  100. package/lib/typescript/src/audio/index.d.ts.map +1 -0
  101. package/lib/typescript/src/diarization/index.d.ts +3 -2
  102. package/lib/typescript/src/diarization/index.d.ts.map +1 -1
  103. package/lib/typescript/src/download/ModelDownloadManager.d.ts +108 -0
  104. package/lib/typescript/src/download/ModelDownloadManager.d.ts.map +1 -0
  105. package/lib/typescript/src/download/extractTarBz2.d.ts +14 -0
  106. package/lib/typescript/src/download/extractTarBz2.d.ts.map +1 -0
  107. package/lib/typescript/src/download/index.d.ts +7 -0
  108. package/lib/typescript/src/download/index.d.ts.map +1 -0
  109. package/lib/typescript/src/download/validation.d.ts +57 -0
  110. package/lib/typescript/src/download/validation.d.ts.map +1 -0
  111. package/lib/typescript/src/enhancement/index.d.ts +3 -2
  112. package/lib/typescript/src/enhancement/index.d.ts.map +1 -1
  113. package/lib/typescript/src/index.d.ts +26 -2
  114. package/lib/typescript/src/index.d.ts.map +1 -1
  115. package/lib/typescript/src/separation/index.d.ts +3 -2
  116. package/lib/typescript/src/separation/index.d.ts.map +1 -1
  117. package/lib/typescript/src/stt/index.d.ts +31 -43
  118. package/lib/typescript/src/stt/index.d.ts.map +1 -1
  119. package/lib/typescript/src/stt/sttModelLanguages.d.ts +52 -0
  120. package/lib/typescript/src/stt/sttModelLanguages.d.ts.map +1 -0
  121. package/lib/typescript/src/stt/types.d.ts +196 -9
  122. package/lib/typescript/src/stt/types.d.ts.map +1 -1
  123. package/lib/typescript/src/tts/index.d.ts +25 -211
  124. package/lib/typescript/src/tts/index.d.ts.map +1 -1
  125. package/lib/typescript/src/tts/types.d.ts +148 -25
  126. package/lib/typescript/src/tts/types.d.ts.map +1 -1
  127. package/lib/typescript/src/types.d.ts +0 -32
  128. package/lib/typescript/src/types.d.ts.map +1 -1
  129. package/lib/typescript/src/utils.d.ts +28 -13
  130. package/lib/typescript/src/utils.d.ts.map +1 -1
  131. package/lib/typescript/src/vad/index.d.ts +3 -2
  132. package/lib/typescript/src/vad/index.d.ts.map +1 -1
  133. package/package.json +250 -222
  134. package/scripts/check-qnn-support.sh +78 -0
  135. package/scripts/setup-ios-framework.sh +379 -282
  136. package/src/NativeSherpaOnnx.ts +474 -251
  137. package/src/audio/index.ts +32 -0
  138. package/src/diarization/index.ts +4 -2
  139. package/src/download/ModelDownloadManager.ts +1325 -0
  140. package/src/download/extractTarBz2.ts +78 -0
  141. package/src/download/index.ts +43 -0
  142. package/src/download/validation.ts +279 -0
  143. package/src/enhancement/index.ts +4 -2
  144. package/src/index.tsx +78 -27
  145. package/src/separation/index.ts +4 -2
  146. package/src/stt/index.ts +249 -89
  147. package/src/stt/sttModelLanguages.ts +237 -0
  148. package/src/stt/types.ts +263 -9
  149. package/src/tts/index.ts +470 -458
  150. package/src/tts/types.ts +373 -218
  151. package/src/types.ts +0 -44
  152. package/src/utils.ts +145 -131
  153. package/src/vad/index.ts +4 -2
  154. package/third_party/ffmpeg_prebuilt/ANDROID_RELEASE_TAG +1 -0
  155. package/third_party/libarchive_prebuilt/ANDROID_RELEASE_TAG +1 -0
  156. package/third_party/libarchive_prebuilt/IOS_RELEASE_TAG +1 -0
  157. package/third_party/sherpa-onnx-prebuilt/ANDROID_RELEASE_TAG +1 -0
  158. package/third_party/sherpa-onnx-prebuilt/IOS_RELEASE_TAG +1 -0
  159. package/android/src/main/cpp/include/sherpa-onnx/c-api/c-api.h +0 -1918
  160. package/android/src/main/cpp/include/sherpa-onnx/c-api/cxx-api.h +0 -841
  161. package/android/src/main/cpp/jni/sherpa-onnx-model-detect.cpp +0 -541
  162. package/android/src/main/cpp/jni/sherpa-onnx-stt-jni.cpp +0 -336
  163. package/android/src/main/cpp/jni/sherpa-onnx-stt-wrapper.cpp +0 -222
  164. package/android/src/main/cpp/jni/sherpa-onnx-stt-wrapper.h +0 -68
  165. package/android/src/main/cpp/jni/sherpa-onnx-tts-jni.cpp +0 -823
  166. package/android/src/main/cpp/jni/sherpa-onnx-tts-wrapper.cpp +0 -387
  167. package/android/src/main/cpp/jni/sherpa-onnx-tts-wrapper.h +0 -147
  168. package/ios/Frameworks/sherpa_onnx.xcframework.zip +0 -0
  169. package/ios/include/sherpa-onnx/c-api/c-api.h +0 -1918
  170. package/ios/include/sherpa-onnx/c-api/cxx-api.h +0 -841
  171. package/ios/sherpa-onnx-model-detect.mm +0 -441
  172. package/ios/sherpa-onnx-stt-wrapper.h +0 -48
  173. package/ios/sherpa-onnx-stt-wrapper.mm +0 -201
  174. package/scripts/copy-headers.js +0 -184
  175. package/scripts/setup-assets.js +0 -323
@@ -0,0 +1,129 @@
1
+ #ifndef SHERPA_ONNX_STT_WRAPPER_H
2
+ #define SHERPA_ONNX_STT_WRAPPER_H
3
+
4
+ #include "sherpa-onnx-common.h"
5
+ #include <cstdint>
6
+ #include <memory>
7
+ #include <optional>
8
+ #include <string>
9
+ #include <vector>
10
+
11
+ namespace sherpaonnx {
12
+
13
+ /**
14
+ * Result of STT initialization.
15
+ */
16
+ struct SttInitializeResult {
17
+ bool success;
18
+ std::vector<DetectedModel> detectedModels; // List of detected models with type and path
19
+ /** When success is false, optional error message (e.g. HOTWORDS_NOT_SUPPORTED). */
20
+ std::string error;
21
+ /** Loaded model type (e.g. "whisper", "transducer") for JS modelType in init result. */
22
+ std::string modelType;
23
+ /** Decoding method actually applied (e.g. "greedy_search", "modified_beam_search"). Set when success is true. */
24
+ std::string decodingMethod;
25
+ };
26
+
27
/**
 * Complete output of one recognition call. Field set mirrors the JS-side
 * SttRecognitionResult type.
 */
struct SttRecognitionResult {
    /** Recognized text. */
    std::string text;
    /** Tokens reported by the recognizer. */
    std::vector<std::string> tokens;
    /** Timestamps reported by the recognizer (presumably one per token, in seconds — confirm against sherpa-onnx). */
    std::vector<float> timestamps;
    /** Language reported by the recognizer; may be empty. */
    std::string lang;
    /** Emotion reported by the recognizer; may be empty. */
    std::string emotion;
    /** Audio event reported by the recognizer; may be empty. */
    std::string event;
    /** Durations reported by the recognizer (presumably one per token — confirm against sherpa-onnx). */
    std::vector<float> durations;
};
39
+
40
/**
 * Options accepted by SttWrapper::setConfig(). Only the fields that may be
 * changed at runtime are listed; every field is optional.
 * NOTE(review): an unset field presumably leaves the current value untouched —
 * confirm against the setConfig() implementation.
 */
struct SttRuntimeConfigOptions {
    /** Decoding method name (e.g. "greedy_search", "modified_beam_search"). */
    std::optional<std::string> decoding_method;
    /** Maximum number of active paths. */
    std::optional<int32_t> max_active_paths;
    /** Path to a hotwords file. */
    std::optional<std::string> hotwords_file;
    /** Hotwords score. */
    std::optional<float> hotwords_score;
    /** Blank penalty. */
    std::optional<float> blank_penalty;
    /** Rule FSTs setting. */
    std::optional<std::string> rule_fsts;
    /** Rule FARs setting. */
    std::optional<std::string> rule_fars;
};
52
+
53
/**
 * Model-specific options: Whisper (iOS: language, task, tail_paddings only).
 * All fields are optional.
 */
struct SttWhisperOptions {
    /** Whisper language setting. */
    std::optional<std::string> language;
    /** Whisper task setting. */
    std::optional<std::string> task;
    /** Whisper tail paddings setting. */
    std::optional<int32_t> tail_paddings;
};
59
+
60
/**
 * Model-specific options: SenseVoice. All fields are optional.
 */
struct SttSenseVoiceOptions {
    /** SenseVoice language setting. */
    std::optional<std::string> language;
    /** SenseVoice use_itn flag. */
    std::optional<bool> use_itn;
};
65
+
66
/**
 * Model-specific options: Canary. All fields are optional.
 */
struct SttCanaryOptions {
    /** Canary src_lang setting. */
    std::optional<std::string> src_lang;
    /** Canary tgt_lang setting. */
    std::optional<std::string> tgt_lang;
    /** Canary use_pnc flag. */
    std::optional<bool> use_pnc;
};
72
+
73
/**
 * Model-specific options: FunASR Nano. All fields are optional.
 */
struct SttFunAsrNanoOptions {
    /** System prompt. */
    std::optional<std::string> system_prompt;
    /** User prompt. */
    std::optional<std::string> user_prompt;
    /** Maximum number of new tokens. */
    std::optional<int32_t> max_new_tokens;
    /** Sampling temperature. */
    std::optional<float> temperature;
    /** top_p setting. */
    std::optional<float> top_p;
    /** Seed. */
    std::optional<int32_t> seed;
    /** Language setting. */
    std::optional<std::string> language;
    /** itn flag. */
    std::optional<bool> itn;
    /** Hotwords string. */
    std::optional<std::string> hotwords;
};
85
+
86
+ /**
87
+ * Wrapper class for sherpa-onnx OfflineRecognizer (STT).
88
+ */
89
+ class SttWrapper {
90
+ public:
91
+ SttWrapper();
92
+ ~SttWrapper();
93
+
94
+ SttInitializeResult initialize(
95
+ const std::string& modelDir,
96
+ const std::optional<bool>& preferInt8 = std::nullopt,
97
+ const std::optional<std::string>& modelType = std::nullopt,
98
+ bool debug = false,
99
+ const std::optional<std::string>& hotwordsFile = std::nullopt,
100
+ const std::optional<float>& hotwordsScore = std::nullopt,
101
+ const std::optional<int32_t>& numThreads = std::nullopt,
102
+ const std::optional<std::string>& provider = std::nullopt,
103
+ const std::optional<std::string>& ruleFsts = std::nullopt,
104
+ const std::optional<std::string>& ruleFars = std::nullopt,
105
+ const std::optional<float>& dither = std::nullopt,
106
+ const SttWhisperOptions* whisperOpts = nullptr,
107
+ const SttSenseVoiceOptions* senseVoiceOpts = nullptr,
108
+ const SttCanaryOptions* canaryOpts = nullptr,
109
+ const SttFunAsrNanoOptions* funasrNanoOpts = nullptr
110
+ );
111
+
112
+ SttRecognitionResult transcribeFile(const std::string& filePath);
113
+
114
+ SttRecognitionResult transcribeSamples(const std::vector<float>& samples, int32_t sampleRate);
115
+
116
+ void setConfig(const SttRuntimeConfigOptions& options);
117
+
118
+ bool isInitialized() const;
119
+
120
+ void release();
121
+
122
+ private:
123
+ class Impl;
124
+ std::unique_ptr<Impl> pImpl;
125
+ };
126
+
127
+ } // namespace sherpaonnx
128
+
129
+ #endif // SHERPA_ONNX_STT_WRAPPER_H
@@ -0,0 +1,523 @@
1
+ /**
2
+ * sherpa-onnx-stt-wrapper.mm
3
+ *
4
+ * Purpose: Wraps the sherpa-onnx C++ OfflineRecognizer for iOS. Builds config from SttModelPaths,
5
+ * creates/destroys recognizer and streams, runs recognition and returns results. Used by SherpaOnnx+STT.mm.
6
+ */
7
+
8
+ #include "sherpa-onnx-stt-wrapper.h"
9
+ #include "sherpa-onnx-model-detect.h"
10
+ #include <algorithm>
11
+ #include <cctype>
12
+ #include <cstring>
13
+ #include <fstream>
14
+ #include <optional>
15
+ #include <sstream>
16
+ #include <cstdint>
17
+ #include <limits>
18
+
19
+ // iOS logging
20
+ #ifdef __APPLE__
21
+ #include <Foundation/Foundation.h>
22
+ #include <cstdio>
23
+ #define LOGI(fmt, ...) NSLog(@"SttWrapper: " fmt, ##__VA_ARGS__)
24
+ #define LOGE(fmt, ...) NSLog(@"SttWrapper ERROR: " fmt, ##__VA_ARGS__)
25
+ #else
26
+ #define LOGI(...)
27
+ #define LOGE(...)
28
+ #endif
29
+
30
+ // Use C++17 filesystem (podspec enforces C++17)
31
+ #include <filesystem>
32
+ namespace fs = std::filesystem;
33
+
34
+ // sherpa-onnx headers - use C++ API (RAII wrapper around C API)
35
+ #include "sherpa-onnx/c-api/cxx-api.h"
36
+
37
+ namespace sherpaonnx {
38
+
39
+ // Hotwords are supported for transducer and NeMo transducer (sherpa-onnx; NeMo: #3077).
40
+ static bool SupportsHotwords(sherpaonnx::SttModelKind kind) {
41
+ return kind == sherpaonnx::SttModelKind::kTransducer || kind == sherpaonnx::SttModelKind::kNemoTransducer;
42
+ }
43
+
44
+ // Returns error message if hotwords file is invalid, else empty optional.
45
+ static std::optional<std::string> ValidateHotwordsFile(const std::string& filePath) {
46
+ if (filePath.empty()) return std::nullopt;
47
+ try {
48
+ if (!fs::exists(filePath)) return "Hotwords file does not exist: " + filePath;
49
+ if (!fs::is_regular_file(filePath)) return "Hotwords path is not a file: " + filePath;
50
+ std::ifstream f(filePath, std::ios::binary);
51
+ if (!f) return "Hotwords file is not readable: " + filePath;
52
+ std::string content((std::istreambuf_iterator<char>(f)), std::istreambuf_iterator<char>());
53
+ f.close();
54
+ if (content.find('\0') != std::string::npos) return "Hotwords file contains null bytes (not a valid text file).";
55
+ NSCharacterSet *letterSet = [NSCharacterSet letterCharacterSet];
56
+ int validLines = 0;
57
+ std::istringstream stream(content);
58
+ std::string line;
59
+ while (std::getline(stream, line, '\n')) {
60
+ if (!line.empty() && line.back() == '\r') line.pop_back();
61
+ size_t start = 0;
62
+ while (start < line.size() && (line[start] == ' ' || line[start] == '\t')) start++;
63
+ size_t end = line.size();
64
+ while (end > start && (line[end - 1] == ' ' || line[end - 1] == '\t')) end--;
65
+ if (start >= end) continue;
66
+ line = line.substr(start, end - start);
67
+ std::string hotwordPart;
68
+ size_t spaceColon = line.rfind(" :");
69
+ if (spaceColon != std::string::npos) {
70
+ std::string scoreStr = line.substr(spaceColon + 2);
71
+ try {
72
+ (void)std::stof(scoreStr);
73
+ } catch (...) {
74
+ return "Invalid hotword line (score must be a number after ' :'): " + line.substr(0, std::min(line.size(), size_t(60))) + "…";
75
+ }
76
+ size_t hStart = 0, hEnd = spaceColon;
77
+ while (hStart < hEnd && (line[hStart] == ' ' || line[hStart] == '\t')) hStart++;
78
+ while (hEnd > hStart && (line[hEnd - 1] == ' ' || line[hEnd - 1] == '\t')) hEnd--;
79
+ hotwordPart = line.substr(hStart, hEnd - hStart);
80
+ } else {
81
+ size_t tabPos = line.find('\t');
82
+ if (tabPos != std::string::npos) {
83
+ std::string afterTab = line.substr(tabPos + 1);
84
+ try {
85
+ (void)std::stof(afterTab);
86
+ return "This file looks like a sentencepiece .vocab file (token<TAB>score). Use a hotwords file instead: one word or phrase per line, optional ' :score' at end.";
87
+ } catch (...) {}
88
+ }
89
+ hotwordPart = line;
90
+ }
91
+ if (hotwordPart.empty()) return "Invalid hotword line (empty hotword): " + line.substr(0, std::min(line.size(), size_t(60))) + "…";
92
+ @autoreleasepool {
93
+ NSString *hotwordNS = [NSString stringWithUTF8String:hotwordPart.c_str()];
94
+ if (!hotwordNS) return "Invalid hotword line (invalid UTF-8): " + line.substr(0, std::min(line.size(), size_t(60))) + "…";
95
+ if ([hotwordNS rangeOfCharacterFromSet:letterSet].location == NSNotFound)
96
+ return "Invalid hotword line (must contain at least one letter): " + line.substr(0, std::min(line.size(), size_t(60))) + "…";
97
+ }
98
+ validLines++;
99
+ }
100
+ if (validLines == 0) return "Hotwords file has no valid lines (one hotword or phrase per line, UTF-8 text).";
101
+ return std::nullopt;
102
+ } catch (const std::exception& e) {
103
+ return std::string("Failed to read hotwords file: ") + e.what();
104
+ }
105
+ }
106
+
107
+ // PIMPL pattern implementation
108
+ class SttWrapper::Impl {
109
+ public:
110
+ bool initialized = false;
111
+ std::string modelDir;
112
+ sherpaonnx::SttModelKind currentModelKind = sherpaonnx::SttModelKind::kUnknown;
113
+ std::optional<sherpa_onnx::cxx::OfflineRecognizer> recognizer;
114
+ std::optional<sherpa_onnx::cxx::OfflineRecognizerConfig> lastConfig;
115
+ };
116
+
117
+ SttWrapper::SttWrapper() : pImpl(std::make_unique<Impl>()) {
118
+ LOGI("SttWrapper created");
119
+ }
120
+
121
+ SttWrapper::~SttWrapper() {
122
+ release();
123
+ LOGI("SttWrapper destroyed");
124
+ }
125
+
126
+ SttInitializeResult SttWrapper::initialize(
127
+ const std::string& modelDir,
128
+ const std::optional<bool>& preferInt8,
129
+ const std::optional<std::string>& modelType,
130
+ bool debug,
131
+ const std::optional<std::string>& hotwordsFile,
132
+ const std::optional<float>& hotwordsScore,
133
+ const std::optional<int32_t>& numThreads,
134
+ const std::optional<std::string>& provider,
135
+ const std::optional<std::string>& ruleFsts,
136
+ const std::optional<std::string>& ruleFars,
137
+ const std::optional<float>& dither,
138
+ const SttWhisperOptions* whisperOpts,
139
+ const SttSenseVoiceOptions* senseVoiceOpts,
140
+ const SttCanaryOptions* canaryOpts,
141
+ const SttFunAsrNanoOptions* funasrNanoOpts
142
+ ) {
143
+ SttInitializeResult result;
144
+ result.success = false;
145
+
146
+ if (pImpl->initialized) {
147
+ release();
148
+ }
149
+
150
+ if (modelDir.empty()) {
151
+ result.error = "Model directory is empty";
152
+ LOGE("%s", result.error.c_str());
153
+ return result;
154
+ }
155
+
156
+ try {
157
+ sherpa_onnx::cxx::OfflineRecognizerConfig config;
158
+ config.feat_config.sample_rate = 16000;
159
+ config.feat_config.feature_dim = 80;
160
+
161
+ auto detect = DetectSttModel(modelDir, preferInt8, modelType, debug);
162
+ if (!detect.ok) {
163
+ result.error = detect.error;
164
+ LOGE("%s", result.error.c_str());
165
+ return result;
166
+ }
167
+
168
+ switch (detect.selectedKind) {
169
+ case SttModelKind::kTransducer:
170
+ case SttModelKind::kNemoTransducer:
171
+ config.model_config.transducer.encoder = detect.paths.encoder;
172
+ config.model_config.transducer.decoder = detect.paths.decoder;
173
+ config.model_config.transducer.joiner = detect.paths.joiner;
174
+ break;
175
+ case SttModelKind::kParaformer:
176
+ config.model_config.paraformer.model = detect.paths.paraformerModel;
177
+ break;
178
+ case SttModelKind::kNemoCtc:
179
+ config.model_config.nemo_ctc.model = detect.paths.ctcModel;
180
+ break;
181
+ case SttModelKind::kWenetCtc:
182
+ config.model_config.wenet_ctc.model = detect.paths.ctcModel;
183
+ break;
184
+ case SttModelKind::kSenseVoice:
185
+ config.model_config.sense_voice.model = detect.paths.ctcModel;
186
+ break;
187
+ case SttModelKind::kZipformerCtc:
188
+ config.model_config.zipformer_ctc.model = detect.paths.ctcModel;
189
+ break;
190
+ case SttModelKind::kWhisper:
191
+ config.model_config.whisper.encoder = detect.paths.whisperEncoder;
192
+ config.model_config.whisper.decoder = detect.paths.whisperDecoder;
193
+ break;
194
+ case SttModelKind::kFunAsrNano:
195
+ config.model_config.funasr_nano.encoder_adaptor = detect.paths.funasrEncoderAdaptor;
196
+ config.model_config.funasr_nano.llm = detect.paths.funasrLLM;
197
+ config.model_config.funasr_nano.embedding = detect.paths.funasrEmbedding;
198
+ config.model_config.funasr_nano.tokenizer = detect.paths.funasrTokenizer;
199
+ break;
200
+ case SttModelKind::kFireRedAsr:
201
+ config.model_config.fire_red_asr.encoder = detect.paths.fireRedEncoder;
202
+ config.model_config.fire_red_asr.decoder = detect.paths.fireRedDecoder;
203
+ break;
204
+ case SttModelKind::kMoonshine:
205
+ config.model_config.moonshine.preprocessor = detect.paths.moonshinePreprocessor;
206
+ config.model_config.moonshine.encoder = detect.paths.moonshineEncoder;
207
+ config.model_config.moonshine.uncached_decoder = detect.paths.moonshineUncachedDecoder;
208
+ config.model_config.moonshine.cached_decoder = detect.paths.moonshineCachedDecoder;
209
+ break;
210
+ case SttModelKind::kDolphin:
211
+ config.model_config.dolphin.model = detect.paths.dolphinModel;
212
+ break;
213
+ case SttModelKind::kCanary:
214
+ config.model_config.canary.encoder = detect.paths.canaryEncoder;
215
+ config.model_config.canary.decoder = detect.paths.canaryDecoder;
216
+ break;
217
+ case SttModelKind::kOmnilingual:
218
+ config.model_config.omnilingual.model = detect.paths.omnilingualModel;
219
+ break;
220
+ case SttModelKind::kMedAsr:
221
+ config.model_config.medasr.model = detect.paths.medasrModel;
222
+ break;
223
+ case SttModelKind::kTeleSpeechCtc:
224
+ config.model_config.telespeech_ctc = detect.paths.telespeechCtcModel;
225
+ break;
226
+ case SttModelKind::kUnknown:
227
+ default:
228
+ result.error = "No compatible model type detected in " + modelDir;
229
+ LOGE("%s", result.error.c_str());
230
+ return result;
231
+ }
232
+
233
+ if (!detect.paths.tokens.empty()) {
234
+ config.model_config.tokens = detect.paths.tokens;
235
+ }
236
+
237
+ // Apply model-specific options (only for the loaded model type).
238
+ switch (detect.selectedKind) {
239
+ case SttModelKind::kWhisper:
240
+ if (whisperOpts) {
241
+ if (whisperOpts->language.has_value())
242
+ config.model_config.whisper.language = *whisperOpts->language;
243
+ if (whisperOpts->task.has_value())
244
+ config.model_config.whisper.task = *whisperOpts->task;
245
+ if (whisperOpts->tail_paddings.has_value())
246
+ config.model_config.whisper.tail_paddings = *whisperOpts->tail_paddings;
247
+ }
248
+ break;
249
+ case SttModelKind::kSenseVoice:
250
+ if (senseVoiceOpts) {
251
+ if (senseVoiceOpts->language.has_value())
252
+ config.model_config.sense_voice.language = *senseVoiceOpts->language;
253
+ if (senseVoiceOpts->use_itn.has_value())
254
+ config.model_config.sense_voice.use_itn = *senseVoiceOpts->use_itn;
255
+ }
256
+ break;
257
+ case SttModelKind::kCanary:
258
+ if (canaryOpts) {
259
+ if (canaryOpts->src_lang.has_value())
260
+ config.model_config.canary.src_lang = *canaryOpts->src_lang;
261
+ if (canaryOpts->tgt_lang.has_value())
262
+ config.model_config.canary.tgt_lang = *canaryOpts->tgt_lang;
263
+ if (canaryOpts->use_pnc.has_value())
264
+ config.model_config.canary.use_pnc = *canaryOpts->use_pnc;
265
+ }
266
+ break;
267
+ case SttModelKind::kFunAsrNano:
268
+ if (funasrNanoOpts) {
269
+ if (funasrNanoOpts->system_prompt.has_value())
270
+ config.model_config.funasr_nano.system_prompt = *funasrNanoOpts->system_prompt;
271
+ if (funasrNanoOpts->user_prompt.has_value())
272
+ config.model_config.funasr_nano.user_prompt = *funasrNanoOpts->user_prompt;
273
+ if (funasrNanoOpts->max_new_tokens.has_value())
274
+ config.model_config.funasr_nano.max_new_tokens = *funasrNanoOpts->max_new_tokens;
275
+ if (funasrNanoOpts->temperature.has_value())
276
+ config.model_config.funasr_nano.temperature = *funasrNanoOpts->temperature;
277
+ if (funasrNanoOpts->top_p.has_value())
278
+ config.model_config.funasr_nano.top_p = *funasrNanoOpts->top_p;
279
+ if (funasrNanoOpts->seed.has_value())
280
+ config.model_config.funasr_nano.seed = *funasrNanoOpts->seed;
281
+ if (funasrNanoOpts->language.has_value())
282
+ config.model_config.funasr_nano.language = *funasrNanoOpts->language;
283
+ if (funasrNanoOpts->itn.has_value())
284
+ config.model_config.funasr_nano.itn = *funasrNanoOpts->itn;
285
+ if (funasrNanoOpts->hotwords.has_value())
286
+ config.model_config.funasr_nano.hotwords = *funasrNanoOpts->hotwords;
287
+ }
288
+ break;
289
+ default:
290
+ break;
291
+ }
292
+
293
+ if (hotwordsFile.has_value() && !hotwordsFile->empty()) {
294
+ if (!SupportsHotwords(detect.selectedKind)) {
295
+ result.success = false;
296
+ result.error = "HOTWORDS_NOT_SUPPORTED: Hotwords are only supported for transducer models (transducer, nemo_transducer). Current model type is not transducer.";
297
+ LOGE("%s", result.error.c_str());
298
+ return result;
299
+ }
300
+ auto validateErr = ValidateHotwordsFile(*hotwordsFile);
301
+ if (validateErr.has_value()) {
302
+ result.success = false;
303
+ result.error = "INVALID_HOTWORDS_FILE: " + *validateErr;
304
+ LOGE("%s", result.error.c_str());
305
+ return result;
306
+ }
307
+ }
308
+
309
+ config.decoding_method = "greedy_search";
310
+ config.model_config.num_threads = numThreads.value_or(1);
311
+ config.model_config.provider = provider.value_or("cpu");
312
+ if (hotwordsFile.has_value() && !hotwordsFile->empty()) {
313
+ config.hotwords_file = *hotwordsFile;
314
+ config.decoding_method = "modified_beam_search";
315
+ config.max_active_paths = std::max(4, config.max_active_paths);
316
+ }
317
+ if (hotwordsScore.has_value()) {
318
+ config.hotwords_score = *hotwordsScore;
319
+ }
320
+ if (ruleFsts.has_value() && !ruleFsts->empty()) {
321
+ config.rule_fsts = *ruleFsts;
322
+ }
323
+ if (ruleFars.has_value() && !ruleFars->empty()) {
324
+ config.rule_fars = *ruleFars;
325
+ }
326
+ (void)dither; // FeatureConfig in bundled cxx-api.h has no dither; reserve for future use
327
+
328
+ bool isWhisperModel = !config.model_config.whisper.encoder.empty() && !config.model_config.whisper.decoder.empty();
329
+ if (isWhisperModel) {
330
+ LOGI("Initializing Whisper model with encoder: %s, decoder: %s", config.model_config.whisper.encoder.c_str(), config.model_config.whisper.decoder.c_str());
331
+ } else {
332
+ LOGI("Initializing non-Whisper model");
333
+ }
334
+ try {
335
+ pImpl->recognizer = sherpa_onnx::cxx::OfflineRecognizer::Create(config);
336
+ } catch (const std::exception& e) {
337
+ LOGE("Failed to create recognizer: %s", e.what());
338
+ result.success = false;
339
+ result.error = std::string("INIT_ERROR: ") + e.what();
340
+ return result;
341
+ } catch (...) {
342
+ LOGE("Unknown exception during recognizer creation");
343
+ result.success = false;
344
+ result.error = "INIT_ERROR: Unknown exception during recognizer creation";
345
+ return result;
346
+ }
347
+
348
+ pImpl->lastConfig = config;
349
+ pImpl->modelDir = modelDir;
350
+ pImpl->currentModelKind = detect.selectedKind;
351
+ pImpl->initialized = true;
352
+
353
+ result.success = true;
354
+ result.detectedModels = detect.detectedModels;
355
+ result.modelType = detect.detectedModels.empty() ? "" : detect.detectedModels[0].type;
356
+ result.decodingMethod = config.decoding_method;
357
+ return result;
358
+ } catch (const std::exception& e) {
359
+ result.error = std::string("Exception during initialization: ") + e.what();
360
+ LOGE("%s", result.error.c_str());
361
+ return result;
362
+ } catch (...) {
363
+ result.error = "Unknown exception during initialization";
364
+ LOGE("%s", result.error.c_str());
365
+ return result;
366
+ }
367
+ }
368
+
369
+ namespace {
370
+ SttRecognitionResult offlineResultToSttResult(const sherpa_onnx::cxx::OfflineRecognizerResult& r) {
371
+ SttRecognitionResult out;
372
+ out.text = r.text;
373
+ out.tokens = r.tokens;
374
+ out.timestamps = r.timestamps;
375
+ out.lang = r.lang;
376
+ out.emotion = r.emotion;
377
+ out.event = r.event;
378
+ out.durations = r.durations;
379
+ return out;
380
+ }
381
+ } // namespace
382
+
383
+ SttRecognitionResult SttWrapper::transcribeFile(const std::string& filePath) {
384
+ if (!pImpl->initialized || !pImpl->recognizer.has_value()) {
385
+ LOGE("Not initialized. Call initialize() first.");
386
+ throw std::runtime_error("STT not initialized. Call initialize() first.");
387
+ }
388
+
389
+ auto fileExists = [](const std::string& path) -> bool {
390
+ return fs::exists(path);
391
+ };
392
+
393
+ LOGI("Transcribe: file=%s", filePath.c_str());
394
+ if (!fileExists(filePath)) {
395
+ LOGE("Audio file not found: %s", filePath.c_str());
396
+ throw std::runtime_error(std::string("Audio file not found: ") + filePath);
397
+ }
398
+
399
+ sherpa_onnx::cxx::Wave wave;
400
+ try {
401
+ wave = sherpa_onnx::cxx::ReadWave(filePath);
402
+ } catch (const std::exception& e) {
403
+ LOGE("Transcribe: ReadWave failed: %s", e.what());
404
+ throw;
405
+ } catch (...) {
406
+ LOGE("Transcribe: ReadWave failed (unknown exception)");
407
+ throw std::runtime_error(std::string("Failed to read audio file: ") + filePath);
408
+ }
409
+
410
+ if (wave.samples.empty()) {
411
+ LOGE("Audio file is empty or failed to read: %s", filePath.c_str());
412
+ throw std::runtime_error(std::string("Audio file is empty or could not be read: ") + filePath);
413
+ }
414
+
415
+ try {
416
+ auto stream = pImpl->recognizer.value().CreateStream();
417
+
418
+ // Ensure safe conversions: AcceptWaveform expects 32-bit ints
419
+ if (wave.samples.size() > static_cast<size_t>(std::numeric_limits<int32_t>::max())) {
420
+ LOGE("Audio too large: sample count %zu exceeds int32_t max", wave.samples.size());
421
+ throw std::runtime_error("Audio too large to process");
422
+ }
423
+
424
+ int32_t sample_rate = 0;
425
+ if (wave.sample_rate > static_cast<uint32_t>(std::numeric_limits<int32_t>::max())) {
426
+ LOGE("Sample rate too large: %u", wave.sample_rate);
427
+ throw std::runtime_error("Unsupported sample rate");
428
+ } else {
429
+ sample_rate = static_cast<int32_t>(wave.sample_rate);
430
+ }
431
+
432
+ int32_t n_samples = static_cast<int32_t>(wave.samples.size());
433
+
434
+ stream.AcceptWaveform(sample_rate, wave.samples.data(), n_samples);
435
+ pImpl->recognizer.value().Decode(&stream);
436
+ auto result = pImpl->recognizer.value().GetResult(&stream);
437
+ return offlineResultToSttResult(result);
438
+ } catch (const std::exception& e) {
439
+ LOGE("Transcribe: recognition failed: %s", e.what());
440
+ throw;
441
+ } catch (...) {
442
+ LOGE("Transcribe: recognition failed (unknown exception)");
443
+ throw std::runtime_error(
444
+ "Recognition failed. Ensure the model supports offline decoding and audio is 16 kHz mono WAV."
445
+ );
446
+ }
447
+ }
448
+
449
+ SttRecognitionResult SttWrapper::transcribeSamples(const std::vector<float>& samples, int32_t sampleRate) {
450
+ if (!pImpl->initialized || !pImpl->recognizer.has_value()) {
451
+ LOGE("Not initialized. Call initialize() first.");
452
+ throw std::runtime_error("STT not initialized. Call initialize() first.");
453
+ }
454
+ if (samples.empty()) {
455
+ SttRecognitionResult empty;
456
+ return empty;
457
+ }
458
+ if (samples.size() > static_cast<size_t>(std::numeric_limits<int32_t>::max())) {
459
+ LOGE("Samples too large: %zu", samples.size());
460
+ throw std::runtime_error("Samples array too large to process");
461
+ }
462
+ try {
463
+ auto stream = pImpl->recognizer.value().CreateStream();
464
+ int32_t n = static_cast<int32_t>(samples.size());
465
+ stream.AcceptWaveform(sampleRate, samples.data(), n);
466
+ pImpl->recognizer.value().Decode(&stream);
467
+ auto result = pImpl->recognizer.value().GetResult(&stream);
468
+ return offlineResultToSttResult(result);
469
+ } catch (const std::exception& e) {
470
+ LOGE("TranscribeSamples: recognition failed: %s", e.what());
471
+ throw;
472
+ } catch (...) {
473
+ LOGE("TranscribeSamples: recognition failed (unknown exception)");
474
+ throw std::runtime_error("Recognition failed.");
475
+ }
476
+ }
477
+
478
+ void SttWrapper::setConfig(const SttRuntimeConfigOptions& options) {
479
+ if (!pImpl->initialized || !pImpl->recognizer.has_value() || !pImpl->lastConfig.has_value()) {
480
+ LOGE("Not initialized or no stored config.");
481
+ throw std::runtime_error("STT not initialized. Call initialize() first.");
482
+ }
483
+ auto& config = pImpl->lastConfig.value();
484
+ if (options.hotwords_file.has_value() && !options.hotwords_file->empty()) {
485
+ if (!SupportsHotwords(pImpl->currentModelKind)) {
486
+ LOGE("Hotwords are only supported for transducer models.");
487
+ throw std::runtime_error("HOTWORDS_NOT_SUPPORTED: Hotwords are only supported for transducer models (transducer, nemo_transducer). Current model type is not transducer.");
488
+ }
489
+ auto validateErr = ValidateHotwordsFile(*options.hotwords_file);
490
+ if (validateErr.has_value()) {
491
+ LOGE("%s", validateErr->c_str());
492
+ throw std::runtime_error("INVALID_HOTWORDS_FILE: " + *validateErr);
493
+ }
494
+ }
495
+ if (options.decoding_method.has_value()) config.decoding_method = *options.decoding_method;
496
+ if (options.max_active_paths.has_value()) config.max_active_paths = *options.max_active_paths;
497
+ if (options.hotwords_file.has_value()) config.hotwords_file = *options.hotwords_file;
498
+ if (options.hotwords_score.has_value()) config.hotwords_score = *options.hotwords_score;
499
+ if (options.blank_penalty.has_value()) config.blank_penalty = *options.blank_penalty;
500
+ if (options.rule_fsts.has_value()) config.rule_fsts = *options.rule_fsts;
501
+ if (options.rule_fars.has_value()) config.rule_fars = *options.rule_fars;
502
+ if (!config.hotwords_file.empty()) {
503
+ config.decoding_method = "modified_beam_search";
504
+ config.max_active_paths = std::max(4, config.max_active_paths);
505
+ }
506
+ pImpl->recognizer.value().SetConfig(config);
507
+ }
508
+
509
+ bool SttWrapper::isInitialized() const {
510
+ return pImpl->initialized;
511
+ }
512
+
513
+ void SttWrapper::release() {
514
+ if (pImpl->initialized) {
515
+ pImpl->recognizer.reset();
516
+ pImpl->lastConfig.reset();
517
+ pImpl->initialized = false;
518
+ pImpl->modelDir.clear();
519
+ pImpl->currentModelKind = sherpaonnx::SttModelKind::kUnknown;
520
+ }
521
+ }
522
+
523
+ } // namespace sherpaonnx