react-native-sherpa-onnx 0.2.0 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +232 -236
- package/SherpaOnnx.podspec +68 -64
- package/android/build.gradle +182 -192
- package/android/codegen.gradle +57 -0
- package/android/prebuilt-download.gradle +428 -0
- package/android/prebuilt-versions.gradle +43 -0
- package/android/proguard-rules.pro +10 -0
- package/android/src/main/assets/testModels/add_mul_add.onnx +28 -0
- package/android/src/main/assets/testModels/nnapi_internal_uint8_support.onnx +0 -0
- package/android/src/main/assets/testModels/qnn_multi_ctx_embed.onnx +0 -0
- package/android/src/main/cpp/CMakeLists.txt +166 -129
- package/android/src/main/cpp/CMakePresets.json +54 -0
- package/android/src/main/cpp/crypto/sha256.cpp +174 -0
- package/android/src/main/cpp/crypto/sha256.h +16 -0
- package/android/src/main/cpp/jni/archive/sherpa-onnx-archive-helper.cpp +404 -0
- package/android/src/main/cpp/jni/archive/sherpa-onnx-archive-helper.h +56 -0
- package/android/src/main/cpp/jni/archive/sherpa-onnx-archive-jni.cpp +181 -0
- package/android/src/main/cpp/jni/audio/sherpa-onnx-audio-convert-jni.cpp +888 -0
- package/{ios → android/src/main/cpp/jni/model_detect}/sherpa-onnx-common.h +18 -18
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-detect-jni-common.cpp +86 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-detect-jni-common.h +20 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-helper.cpp +423 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-helper.h +55 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-stt.cpp +399 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-model-detect-tts.cpp +238 -0
- package/{ios → android/src/main/cpp/jni/model_detect}/sherpa-onnx-model-detect.h +122 -89
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-stt-wrapper.cpp +99 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-stt-wrapper.h +16 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-tts-wrapper.cpp +78 -0
- package/android/src/main/cpp/jni/model_detect/sherpa-onnx-tts-wrapper.h +16 -0
- package/android/src/main/cpp/jni/module/sherpa-onnx-module-jni.cpp +190 -0
- package/android/src/main/cpp/jni/tts/sherpa-onnx-tts-zipvoice-jni.cpp +301 -0
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxArchiveHelper.kt +94 -0
- package/android/src/main/java/com/sherpaonnx/{SherpaOnnxCoreHelper.kt → SherpaOnnxAssetHelper.kt} +350 -236
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxModule.kt +791 -483
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxSttHelper.kt +699 -109
- package/android/src/main/java/com/sherpaonnx/SherpaOnnxTtsHelper.kt +1123 -668
- package/android/src/main/java/com/sherpaonnx/ZipvoiceTtsWrapper.kt +187 -0
- package/ios/SherpaOnnx+Assets.h +11 -0
- package/ios/SherpaOnnx+Assets.mm +325 -0
- package/ios/SherpaOnnx+STT.mm +455 -118
- package/ios/SherpaOnnx+TTS.mm +1101 -712
- package/ios/SherpaOnnx.h +17 -6
- package/ios/SherpaOnnx.mm +206 -311
- package/ios/SherpaOnnx.xcconfig +19 -19
- package/ios/SherpaOnnxCoreMLHelper.swift +24 -0
- package/ios/archive/sherpa-onnx-archive-helper.h +21 -0
- package/ios/archive/sherpa-onnx-archive-helper.mm +296 -0
- package/ios/libarchive_darwin_config.h +153 -0
- package/{android/src/main/cpp/jni → ios/model_detect}/sherpa-onnx-common.h +18 -18
- package/ios/model_detect/sherpa-onnx-model-detect-helper.h +49 -0
- package/ios/model_detect/sherpa-onnx-model-detect-helper.mm +210 -0
- package/ios/model_detect/sherpa-onnx-model-detect-stt.mm +344 -0
- package/ios/model_detect/sherpa-onnx-model-detect-tts.mm +201 -0
- package/{android/src/main/cpp/jni → ios/model_detect}/sherpa-onnx-model-detect.h +117 -89
- package/ios/scripts/patch-libarchive-includes.sh +61 -0
- package/ios/scripts/setup-ios-libarchive.sh +98 -0
- package/ios/stt/sherpa-onnx-stt-wrapper.h +129 -0
- package/ios/stt/sherpa-onnx-stt-wrapper.mm +523 -0
- package/ios/{sherpa-onnx-tts-wrapper.h → tts/sherpa-onnx-tts-wrapper.h} +90 -85
- package/ios/{sherpa-onnx-tts-wrapper.mm → tts/sherpa-onnx-tts-wrapper.mm} +376 -345
- package/lib/module/NativeSherpaOnnx.js +3 -0
- package/lib/module/NativeSherpaOnnx.js.map +1 -1
- package/lib/module/audio/index.js +22 -0
- package/lib/module/audio/index.js.map +1 -0
- package/lib/module/diarization/index.js +1 -1
- package/lib/module/diarization/index.js.map +1 -1
- package/lib/module/download/ModelDownloadManager.js +918 -0
- package/lib/module/download/ModelDownloadManager.js.map +1 -0
- package/lib/module/download/extractTarBz2.js +53 -0
- package/lib/module/download/extractTarBz2.js.map +1 -0
- package/lib/module/download/index.js +6 -0
- package/lib/module/download/index.js.map +1 -0
- package/lib/module/download/validation.js +178 -0
- package/lib/module/download/validation.js.map +1 -0
- package/lib/module/enhancement/index.js +1 -1
- package/lib/module/enhancement/index.js.map +1 -1
- package/lib/module/index.js +41 -3
- package/lib/module/index.js.map +1 -1
- package/lib/module/separation/index.js +1 -1
- package/lib/module/separation/index.js.map +1 -1
- package/lib/module/stt/index.js +127 -60
- package/lib/module/stt/index.js.map +1 -1
- package/lib/module/stt/sttModelLanguages.js +512 -0
- package/lib/module/stt/sttModelLanguages.js.map +1 -0
- package/lib/module/stt/types.js +53 -1
- package/lib/module/stt/types.js.map +1 -1
- package/lib/module/tts/index.js +216 -289
- package/lib/module/tts/index.js.map +1 -1
- package/lib/module/tts/types.js +86 -1
- package/lib/module/tts/types.js.map +1 -1
- package/lib/module/types.js.map +1 -1
- package/lib/module/utils.js +86 -73
- package/lib/module/utils.js.map +1 -1
- package/lib/module/vad/index.js +1 -1
- package/lib/module/vad/index.js.map +1 -1
- package/lib/typescript/src/NativeSherpaOnnx.d.ts +192 -38
- package/lib/typescript/src/NativeSherpaOnnx.d.ts.map +1 -1
- package/lib/typescript/src/audio/index.d.ts +13 -0
- package/lib/typescript/src/audio/index.d.ts.map +1 -0
- package/lib/typescript/src/diarization/index.d.ts +3 -2
- package/lib/typescript/src/diarization/index.d.ts.map +1 -1
- package/lib/typescript/src/download/ModelDownloadManager.d.ts +108 -0
- package/lib/typescript/src/download/ModelDownloadManager.d.ts.map +1 -0
- package/lib/typescript/src/download/extractTarBz2.d.ts +14 -0
- package/lib/typescript/src/download/extractTarBz2.d.ts.map +1 -0
- package/lib/typescript/src/download/index.d.ts +7 -0
- package/lib/typescript/src/download/index.d.ts.map +1 -0
- package/lib/typescript/src/download/validation.d.ts +57 -0
- package/lib/typescript/src/download/validation.d.ts.map +1 -0
- package/lib/typescript/src/enhancement/index.d.ts +3 -2
- package/lib/typescript/src/enhancement/index.d.ts.map +1 -1
- package/lib/typescript/src/index.d.ts +26 -2
- package/lib/typescript/src/index.d.ts.map +1 -1
- package/lib/typescript/src/separation/index.d.ts +3 -2
- package/lib/typescript/src/separation/index.d.ts.map +1 -1
- package/lib/typescript/src/stt/index.d.ts +31 -43
- package/lib/typescript/src/stt/index.d.ts.map +1 -1
- package/lib/typescript/src/stt/sttModelLanguages.d.ts +52 -0
- package/lib/typescript/src/stt/sttModelLanguages.d.ts.map +1 -0
- package/lib/typescript/src/stt/types.d.ts +196 -9
- package/lib/typescript/src/stt/types.d.ts.map +1 -1
- package/lib/typescript/src/tts/index.d.ts +25 -211
- package/lib/typescript/src/tts/index.d.ts.map +1 -1
- package/lib/typescript/src/tts/types.d.ts +148 -25
- package/lib/typescript/src/tts/types.d.ts.map +1 -1
- package/lib/typescript/src/types.d.ts +0 -32
- package/lib/typescript/src/types.d.ts.map +1 -1
- package/lib/typescript/src/utils.d.ts +28 -13
- package/lib/typescript/src/utils.d.ts.map +1 -1
- package/lib/typescript/src/vad/index.d.ts +3 -2
- package/lib/typescript/src/vad/index.d.ts.map +1 -1
- package/package.json +250 -222
- package/scripts/check-qnn-support.sh +78 -0
- package/scripts/setup-ios-framework.sh +379 -282
- package/src/NativeSherpaOnnx.ts +474 -251
- package/src/audio/index.ts +32 -0
- package/src/diarization/index.ts +4 -2
- package/src/download/ModelDownloadManager.ts +1325 -0
- package/src/download/extractTarBz2.ts +78 -0
- package/src/download/index.ts +43 -0
- package/src/download/validation.ts +279 -0
- package/src/enhancement/index.ts +4 -2
- package/src/index.tsx +78 -27
- package/src/separation/index.ts +4 -2
- package/src/stt/index.ts +249 -89
- package/src/stt/sttModelLanguages.ts +237 -0
- package/src/stt/types.ts +263 -9
- package/src/tts/index.ts +470 -458
- package/src/tts/types.ts +373 -218
- package/src/types.ts +0 -44
- package/src/utils.ts +145 -131
- package/src/vad/index.ts +4 -2
- package/third_party/ffmpeg_prebuilt/ANDROID_RELEASE_TAG +1 -0
- package/third_party/libarchive_prebuilt/ANDROID_RELEASE_TAG +1 -0
- package/third_party/libarchive_prebuilt/IOS_RELEASE_TAG +1 -0
- package/third_party/sherpa-onnx-prebuilt/ANDROID_RELEASE_TAG +1 -0
- package/third_party/sherpa-onnx-prebuilt/IOS_RELEASE_TAG +1 -0
- package/android/src/main/cpp/include/sherpa-onnx/c-api/c-api.h +0 -1918
- package/android/src/main/cpp/include/sherpa-onnx/c-api/cxx-api.h +0 -841
- package/android/src/main/cpp/jni/sherpa-onnx-model-detect.cpp +0 -541
- package/android/src/main/cpp/jni/sherpa-onnx-stt-jni.cpp +0 -336
- package/android/src/main/cpp/jni/sherpa-onnx-stt-wrapper.cpp +0 -222
- package/android/src/main/cpp/jni/sherpa-onnx-stt-wrapper.h +0 -68
- package/android/src/main/cpp/jni/sherpa-onnx-tts-jni.cpp +0 -823
- package/android/src/main/cpp/jni/sherpa-onnx-tts-wrapper.cpp +0 -387
- package/android/src/main/cpp/jni/sherpa-onnx-tts-wrapper.h +0 -147
- package/ios/Frameworks/sherpa_onnx.xcframework.zip +0 -0
- package/ios/include/sherpa-onnx/c-api/c-api.h +0 -1918
- package/ios/include/sherpa-onnx/c-api/cxx-api.h +0 -841
- package/ios/sherpa-onnx-model-detect.mm +0 -441
- package/ios/sherpa-onnx-stt-wrapper.h +0 -48
- package/ios/sherpa-onnx-stt-wrapper.mm +0 -201
- package/scripts/copy-headers.js +0 -184
- package/scripts/setup-assets.js +0 -323
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
#ifndef SHERPA_ONNX_STT_WRAPPER_H
|
|
2
|
+
#define SHERPA_ONNX_STT_WRAPPER_H
|
|
3
|
+
|
|
4
|
+
#include "sherpa-onnx-common.h"
|
|
5
|
+
#include <cstdint>
|
|
6
|
+
#include <memory>
|
|
7
|
+
#include <optional>
|
|
8
|
+
#include <string>
|
|
9
|
+
#include <vector>
|
|
10
|
+
|
|
11
|
+
namespace sherpaonnx {
|
|
12
|
+
|
|
13
|
+
/**
|
|
14
|
+
* Result of STT initialization.
|
|
15
|
+
*/
|
|
16
|
+
struct SttInitializeResult {
|
|
17
|
+
bool success;
|
|
18
|
+
std::vector<DetectedModel> detectedModels; // List of detected models with type and path
|
|
19
|
+
/** When success is false, optional error message (e.g. HOTWORDS_NOT_SUPPORTED). */
|
|
20
|
+
std::string error;
|
|
21
|
+
/** Loaded model type (e.g. "whisper", "transducer") for JS modelType in init result. */
|
|
22
|
+
std::string modelType;
|
|
23
|
+
/** Decoding method actually applied (e.g. "greedy_search", "modified_beam_search"). Set when success is true. */
|
|
24
|
+
std::string decodingMethod;
|
|
25
|
+
};
|
|
26
|
+
|
|
27
|
+
/**
 * Complete output of a single recognition call.
 * Field set mirrors the JS-side SttRecognitionResult; every member starts out empty.
 */
struct SttRecognitionResult {
    std::string text{};
    std::vector<std::string> tokens{};
    std::vector<float> timestamps{};
    std::string lang{};
    std::string emotion{};
    std::string event{};
    std::vector<float> durations{};
};
|
|
39
|
+
|
|
40
|
+
/**
 * Options accepted by setConfig. Only the mutable recognizer fields are listed;
 * a member left as std::nullopt means "not provided".
 */
struct SttRuntimeConfigOptions {
    std::optional<std::string> decoding_method = std::nullopt;
    std::optional<int32_t> max_active_paths = std::nullopt;
    std::optional<std::string> hotwords_file = std::nullopt;
    std::optional<float> hotwords_score = std::nullopt;
    std::optional<float> blank_penalty = std::nullopt;
    std::optional<std::string> rule_fsts = std::nullopt;
    std::optional<std::string> rule_fars = std::nullopt;
};
|
|
52
|
+
|
|
53
|
+
/**
 * Whisper-specific options.
 * On iOS only language, task and tail_paddings are exposed; unset members are std::nullopt.
 */
struct SttWhisperOptions {
    std::optional<std::string> language = std::nullopt;
    std::optional<std::string> task = std::nullopt;
    std::optional<int32_t> tail_paddings = std::nullopt;
};
|
|
59
|
+
|
|
60
|
+
/**
 * SenseVoice-specific options; unset members are std::nullopt.
 */
struct SttSenseVoiceOptions {
    std::optional<std::string> language = std::nullopt;
    std::optional<bool> use_itn = std::nullopt;
};
|
|
65
|
+
|
|
66
|
+
/**
 * Canary-specific options; unset members are std::nullopt.
 */
struct SttCanaryOptions {
    std::optional<std::string> src_lang = std::nullopt;
    std::optional<std::string> tgt_lang = std::nullopt;
    std::optional<bool> use_pnc = std::nullopt;
};
|
|
72
|
+
|
|
73
|
+
/**
 * FunASR Nano-specific options; unset members are std::nullopt.
 */
struct SttFunAsrNanoOptions {
    std::optional<std::string> system_prompt = std::nullopt;
    std::optional<std::string> user_prompt = std::nullopt;
    std::optional<int32_t> max_new_tokens = std::nullopt;
    std::optional<float> temperature = std::nullopt;
    std::optional<float> top_p = std::nullopt;
    std::optional<int32_t> seed = std::nullopt;
    std::optional<std::string> language = std::nullopt;
    std::optional<bool> itn = std::nullopt;
    std::optional<std::string> hotwords = std::nullopt;
};
|
|
85
|
+
|
|
86
|
+
/**
|
|
87
|
+
* Wrapper class for sherpa-onnx OfflineRecognizer (STT).
|
|
88
|
+
*/
|
|
89
|
+
class SttWrapper {
|
|
90
|
+
public:
|
|
91
|
+
SttWrapper();
|
|
92
|
+
~SttWrapper();
|
|
93
|
+
|
|
94
|
+
SttInitializeResult initialize(
|
|
95
|
+
const std::string& modelDir,
|
|
96
|
+
const std::optional<bool>& preferInt8 = std::nullopt,
|
|
97
|
+
const std::optional<std::string>& modelType = std::nullopt,
|
|
98
|
+
bool debug = false,
|
|
99
|
+
const std::optional<std::string>& hotwordsFile = std::nullopt,
|
|
100
|
+
const std::optional<float>& hotwordsScore = std::nullopt,
|
|
101
|
+
const std::optional<int32_t>& numThreads = std::nullopt,
|
|
102
|
+
const std::optional<std::string>& provider = std::nullopt,
|
|
103
|
+
const std::optional<std::string>& ruleFsts = std::nullopt,
|
|
104
|
+
const std::optional<std::string>& ruleFars = std::nullopt,
|
|
105
|
+
const std::optional<float>& dither = std::nullopt,
|
|
106
|
+
const SttWhisperOptions* whisperOpts = nullptr,
|
|
107
|
+
const SttSenseVoiceOptions* senseVoiceOpts = nullptr,
|
|
108
|
+
const SttCanaryOptions* canaryOpts = nullptr,
|
|
109
|
+
const SttFunAsrNanoOptions* funasrNanoOpts = nullptr
|
|
110
|
+
);
|
|
111
|
+
|
|
112
|
+
SttRecognitionResult transcribeFile(const std::string& filePath);
|
|
113
|
+
|
|
114
|
+
SttRecognitionResult transcribeSamples(const std::vector<float>& samples, int32_t sampleRate);
|
|
115
|
+
|
|
116
|
+
void setConfig(const SttRuntimeConfigOptions& options);
|
|
117
|
+
|
|
118
|
+
bool isInitialized() const;
|
|
119
|
+
|
|
120
|
+
void release();
|
|
121
|
+
|
|
122
|
+
private:
|
|
123
|
+
class Impl;
|
|
124
|
+
std::unique_ptr<Impl> pImpl;
|
|
125
|
+
};
|
|
126
|
+
|
|
127
|
+
} // namespace sherpaonnx
|
|
128
|
+
|
|
129
|
+
#endif // SHERPA_ONNX_STT_WRAPPER_H
|
|
@@ -0,0 +1,523 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* sherpa-onnx-stt-wrapper.mm
|
|
3
|
+
*
|
|
4
|
+
* Purpose: Wraps the sherpa-onnx C++ OfflineRecognizer for iOS. Builds config from SttModelPaths,
|
|
5
|
+
* creates/destroys recognizer and streams, runs recognition and returns results. Used by SherpaOnnx+STT.mm.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
#include "sherpa-onnx-stt-wrapper.h"
|
|
9
|
+
#include "sherpa-onnx-model-detect.h"
|
|
10
|
+
#include <algorithm>
|
|
11
|
+
#include <cctype>
|
|
12
|
+
#include <cstring>
|
|
13
|
+
#include <fstream>
|
|
14
|
+
#include <optional>
|
|
15
|
+
#include <sstream>
|
|
16
|
+
#include <cstdint>
|
|
17
|
+
#include <limits>
|
|
18
|
+
|
|
19
|
+
// Logging helpers. On Apple platforms both macros forward to NSLog with a
// "SttWrapper" prefix; on any other platform they expand to nothing.
#ifdef __APPLE__
#include <Foundation/Foundation.h>
#include <cstdio>
#define LOGI(format, ...) NSLog(@"SttWrapper: " format, ##__VA_ARGS__)
#define LOGE(format, ...) NSLog(@"SttWrapper ERROR: " format, ##__VA_ARGS__)
#else
#define LOGI(...)
#define LOGE(...)
#endif
|
|
29
|
+
|
|
30
|
+
// Use C++17 filesystem (podspec enforces C++17)
|
|
31
|
+
#include <filesystem>
|
|
32
|
+
namespace fs = std::filesystem;
|
|
33
|
+
|
|
34
|
+
// sherpa-onnx headers - use C++ API (RAII wrapper around C API)
|
|
35
|
+
#include "sherpa-onnx/c-api/cxx-api.h"
|
|
36
|
+
|
|
37
|
+
namespace sherpaonnx {
|
|
38
|
+
|
|
39
|
+
// Hotwords are supported for transducer and NeMo transducer (sherpa-onnx; NeMo: #3077).
|
|
40
|
+
static bool SupportsHotwords(sherpaonnx::SttModelKind kind) {
|
|
41
|
+
return kind == sherpaonnx::SttModelKind::kTransducer || kind == sherpaonnx::SttModelKind::kNemoTransducer;
|
|
42
|
+
}
|
|
43
|
+
|
|
44
|
+
// Returns error message if hotwords file is invalid, else empty optional.
|
|
45
|
+
static std::optional<std::string> ValidateHotwordsFile(const std::string& filePath) {
|
|
46
|
+
if (filePath.empty()) return std::nullopt;
|
|
47
|
+
try {
|
|
48
|
+
if (!fs::exists(filePath)) return "Hotwords file does not exist: " + filePath;
|
|
49
|
+
if (!fs::is_regular_file(filePath)) return "Hotwords path is not a file: " + filePath;
|
|
50
|
+
std::ifstream f(filePath, std::ios::binary);
|
|
51
|
+
if (!f) return "Hotwords file is not readable: " + filePath;
|
|
52
|
+
std::string content((std::istreambuf_iterator<char>(f)), std::istreambuf_iterator<char>());
|
|
53
|
+
f.close();
|
|
54
|
+
if (content.find('\0') != std::string::npos) return "Hotwords file contains null bytes (not a valid text file).";
|
|
55
|
+
NSCharacterSet *letterSet = [NSCharacterSet letterCharacterSet];
|
|
56
|
+
int validLines = 0;
|
|
57
|
+
std::istringstream stream(content);
|
|
58
|
+
std::string line;
|
|
59
|
+
while (std::getline(stream, line, '\n')) {
|
|
60
|
+
if (!line.empty() && line.back() == '\r') line.pop_back();
|
|
61
|
+
size_t start = 0;
|
|
62
|
+
while (start < line.size() && (line[start] == ' ' || line[start] == '\t')) start++;
|
|
63
|
+
size_t end = line.size();
|
|
64
|
+
while (end > start && (line[end - 1] == ' ' || line[end - 1] == '\t')) end--;
|
|
65
|
+
if (start >= end) continue;
|
|
66
|
+
line = line.substr(start, end - start);
|
|
67
|
+
std::string hotwordPart;
|
|
68
|
+
size_t spaceColon = line.rfind(" :");
|
|
69
|
+
if (spaceColon != std::string::npos) {
|
|
70
|
+
std::string scoreStr = line.substr(spaceColon + 2);
|
|
71
|
+
try {
|
|
72
|
+
(void)std::stof(scoreStr);
|
|
73
|
+
} catch (...) {
|
|
74
|
+
return "Invalid hotword line (score must be a number after ' :'): " + line.substr(0, std::min(line.size(), size_t(60))) + "…";
|
|
75
|
+
}
|
|
76
|
+
size_t hStart = 0, hEnd = spaceColon;
|
|
77
|
+
while (hStart < hEnd && (line[hStart] == ' ' || line[hStart] == '\t')) hStart++;
|
|
78
|
+
while (hEnd > hStart && (line[hEnd - 1] == ' ' || line[hEnd - 1] == '\t')) hEnd--;
|
|
79
|
+
hotwordPart = line.substr(hStart, hEnd - hStart);
|
|
80
|
+
} else {
|
|
81
|
+
size_t tabPos = line.find('\t');
|
|
82
|
+
if (tabPos != std::string::npos) {
|
|
83
|
+
std::string afterTab = line.substr(tabPos + 1);
|
|
84
|
+
try {
|
|
85
|
+
(void)std::stof(afterTab);
|
|
86
|
+
return "This file looks like a sentencepiece .vocab file (token<TAB>score). Use a hotwords file instead: one word or phrase per line, optional ' :score' at end.";
|
|
87
|
+
} catch (...) {}
|
|
88
|
+
}
|
|
89
|
+
hotwordPart = line;
|
|
90
|
+
}
|
|
91
|
+
if (hotwordPart.empty()) return "Invalid hotword line (empty hotword): " + line.substr(0, std::min(line.size(), size_t(60))) + "…";
|
|
92
|
+
@autoreleasepool {
|
|
93
|
+
NSString *hotwordNS = [NSString stringWithUTF8String:hotwordPart.c_str()];
|
|
94
|
+
if (!hotwordNS) return "Invalid hotword line (invalid UTF-8): " + line.substr(0, std::min(line.size(), size_t(60))) + "…";
|
|
95
|
+
if ([hotwordNS rangeOfCharacterFromSet:letterSet].location == NSNotFound)
|
|
96
|
+
return "Invalid hotword line (must contain at least one letter): " + line.substr(0, std::min(line.size(), size_t(60))) + "…";
|
|
97
|
+
}
|
|
98
|
+
validLines++;
|
|
99
|
+
}
|
|
100
|
+
if (validLines == 0) return "Hotwords file has no valid lines (one hotword or phrase per line, UTF-8 text).";
|
|
101
|
+
return std::nullopt;
|
|
102
|
+
} catch (const std::exception& e) {
|
|
103
|
+
return std::string("Failed to read hotwords file: ") + e.what();
|
|
104
|
+
}
|
|
105
|
+
}
|
|
106
|
+
|
|
107
|
+
// PIMPL pattern implementation
|
|
108
|
+
class SttWrapper::Impl {
|
|
109
|
+
public:
|
|
110
|
+
bool initialized = false;
|
|
111
|
+
std::string modelDir;
|
|
112
|
+
sherpaonnx::SttModelKind currentModelKind = sherpaonnx::SttModelKind::kUnknown;
|
|
113
|
+
std::optional<sherpa_onnx::cxx::OfflineRecognizer> recognizer;
|
|
114
|
+
std::optional<sherpa_onnx::cxx::OfflineRecognizerConfig> lastConfig;
|
|
115
|
+
};
|
|
116
|
+
|
|
117
|
+
SttWrapper::SttWrapper() : pImpl(std::make_unique<Impl>()) {
|
|
118
|
+
LOGI("SttWrapper created");
|
|
119
|
+
}
|
|
120
|
+
|
|
121
|
+
// Calls release() first, then logs; the pimpl is destroyed afterwards by its unique_ptr.
SttWrapper::~SttWrapper() {
    this->release();
    LOGI("SttWrapper destroyed");
}
|
|
125
|
+
|
|
126
|
+
/**
 * Detects the model in `modelDir`, builds the recognizer config and creates
 * the sherpa-onnx OfflineRecognizer.
 *
 * A previously initialized wrapper is released first. Failures never throw:
 * they are reported through `success == false` and `error`.
 *
 * @param modelDir        Directory containing the model files; must not be empty.
 * @param preferInt8      Forwarded to model detection.
 * @param modelType       Forwarded to model detection.
 * @param debug           Forwarded to model detection.
 * @param hotwordsFile    Optional hotwords file. Only accepted for transducer /
 *                        NeMo transducer models; it is validated, and it switches
 *                        decoding to "modified_beam_search" with at least 4 active paths.
 * @param hotwordsScore   Optional hotwords score.
 * @param numThreads      Thread count; defaults to 1.
 * @param provider        Execution provider; defaults to "cpu".
 * @param ruleFsts        Optional rule FSTs (ignored when empty).
 * @param ruleFars        Optional rule FARs (ignored when empty).
 * @param dither          Currently unused (see the note in the body).
 * @param whisperOpts     Applied only when a Whisper model is selected; may be nullptr.
 * @param senseVoiceOpts  Applied only when a SenseVoice model is selected; may be nullptr.
 * @param canaryOpts      Applied only when a Canary model is selected; may be nullptr.
 * @param funasrNanoOpts  Applied only when a FunASR Nano model is selected; may be nullptr.
 * @return Result with `success`, and on success the detected models, the model
 *         type and the decoding method that was applied.
 */
SttInitializeResult SttWrapper::initialize(
    const std::string& modelDir,
    const std::optional<bool>& preferInt8,
    const std::optional<std::string>& modelType,
    bool debug,
    const std::optional<std::string>& hotwordsFile,
    const std::optional<float>& hotwordsScore,
    const std::optional<int32_t>& numThreads,
    const std::optional<std::string>& provider,
    const std::optional<std::string>& ruleFsts,
    const std::optional<std::string>& ruleFars,
    const std::optional<float>& dither,
    const SttWhisperOptions* whisperOpts,
    const SttSenseVoiceOptions* senseVoiceOpts,
    const SttCanaryOptions* canaryOpts,
    const SttFunAsrNanoOptions* funasrNanoOpts
) {
    SttInitializeResult result;
    result.success = false;

    // Re-initialization: drop the previous recognizer first.
    if (pImpl->initialized) {
        release();
    }

    if (modelDir.empty()) {
        result.error = "Model directory is empty";
        LOGE("%s", result.error.c_str());
        return result;
    }

    try {
        sherpa_onnx::cxx::OfflineRecognizerConfig config;
        config.feat_config.sample_rate = 16000;
        config.feat_config.feature_dim = 80;

        auto detect = DetectSttModel(modelDir, preferInt8, modelType, debug);
        if (!detect.ok) {
            result.error = detect.error;
            LOGE("%s", result.error.c_str());
            return result;
        }

        // Copy the detected file paths into the config section of the selected model kind.
        switch (detect.selectedKind) {
            case SttModelKind::kTransducer:
            case SttModelKind::kNemoTransducer:
                config.model_config.transducer.encoder = detect.paths.encoder;
                config.model_config.transducer.decoder = detect.paths.decoder;
                config.model_config.transducer.joiner = detect.paths.joiner;
                break;
            case SttModelKind::kParaformer:
                config.model_config.paraformer.model = detect.paths.paraformerModel;
                break;
            case SttModelKind::kNemoCtc:
                config.model_config.nemo_ctc.model = detect.paths.ctcModel;
                break;
            case SttModelKind::kWenetCtc:
                config.model_config.wenet_ctc.model = detect.paths.ctcModel;
                break;
            case SttModelKind::kSenseVoice:
                config.model_config.sense_voice.model = detect.paths.ctcModel;
                break;
            case SttModelKind::kZipformerCtc:
                config.model_config.zipformer_ctc.model = detect.paths.ctcModel;
                break;
            case SttModelKind::kWhisper:
                config.model_config.whisper.encoder = detect.paths.whisperEncoder;
                config.model_config.whisper.decoder = detect.paths.whisperDecoder;
                break;
            case SttModelKind::kFunAsrNano:
                config.model_config.funasr_nano.encoder_adaptor = detect.paths.funasrEncoderAdaptor;
                config.model_config.funasr_nano.llm = detect.paths.funasrLLM;
                config.model_config.funasr_nano.embedding = detect.paths.funasrEmbedding;
                config.model_config.funasr_nano.tokenizer = detect.paths.funasrTokenizer;
                break;
            case SttModelKind::kFireRedAsr:
                config.model_config.fire_red_asr.encoder = detect.paths.fireRedEncoder;
                config.model_config.fire_red_asr.decoder = detect.paths.fireRedDecoder;
                break;
            case SttModelKind::kMoonshine:
                config.model_config.moonshine.preprocessor = detect.paths.moonshinePreprocessor;
                config.model_config.moonshine.encoder = detect.paths.moonshineEncoder;
                config.model_config.moonshine.uncached_decoder = detect.paths.moonshineUncachedDecoder;
                config.model_config.moonshine.cached_decoder = detect.paths.moonshineCachedDecoder;
                break;
            case SttModelKind::kDolphin:
                config.model_config.dolphin.model = detect.paths.dolphinModel;
                break;
            case SttModelKind::kCanary:
                config.model_config.canary.encoder = detect.paths.canaryEncoder;
                config.model_config.canary.decoder = detect.paths.canaryDecoder;
                break;
            case SttModelKind::kOmnilingual:
                config.model_config.omnilingual.model = detect.paths.omnilingualModel;
                break;
            case SttModelKind::kMedAsr:
                config.model_config.medasr.model = detect.paths.medasrModel;
                break;
            case SttModelKind::kTeleSpeechCtc:
                config.model_config.telespeech_ctc = detect.paths.telespeechCtcModel;
                break;
            case SttModelKind::kUnknown:
            default:
                result.error = "No compatible model type detected in " + modelDir;
                LOGE("%s", result.error.c_str());
                return result;
        }

        if (!detect.paths.tokens.empty()) {
            config.model_config.tokens = detect.paths.tokens;
        }

        // Apply model-specific options (only for the loaded model type).
        switch (detect.selectedKind) {
            case SttModelKind::kWhisper:
                if (whisperOpts) {
                    if (whisperOpts->language.has_value())
                        config.model_config.whisper.language = *whisperOpts->language;
                    if (whisperOpts->task.has_value())
                        config.model_config.whisper.task = *whisperOpts->task;
                    if (whisperOpts->tail_paddings.has_value())
                        config.model_config.whisper.tail_paddings = *whisperOpts->tail_paddings;
                }
                break;
            case SttModelKind::kSenseVoice:
                if (senseVoiceOpts) {
                    if (senseVoiceOpts->language.has_value())
                        config.model_config.sense_voice.language = *senseVoiceOpts->language;
                    if (senseVoiceOpts->use_itn.has_value())
                        config.model_config.sense_voice.use_itn = *senseVoiceOpts->use_itn;
                }
                break;
            case SttModelKind::kCanary:
                if (canaryOpts) {
                    if (canaryOpts->src_lang.has_value())
                        config.model_config.canary.src_lang = *canaryOpts->src_lang;
                    if (canaryOpts->tgt_lang.has_value())
                        config.model_config.canary.tgt_lang = *canaryOpts->tgt_lang;
                    if (canaryOpts->use_pnc.has_value())
                        config.model_config.canary.use_pnc = *canaryOpts->use_pnc;
                }
                break;
            case SttModelKind::kFunAsrNano:
                if (funasrNanoOpts) {
                    if (funasrNanoOpts->system_prompt.has_value())
                        config.model_config.funasr_nano.system_prompt = *funasrNanoOpts->system_prompt;
                    if (funasrNanoOpts->user_prompt.has_value())
                        config.model_config.funasr_nano.user_prompt = *funasrNanoOpts->user_prompt;
                    if (funasrNanoOpts->max_new_tokens.has_value())
                        config.model_config.funasr_nano.max_new_tokens = *funasrNanoOpts->max_new_tokens;
                    if (funasrNanoOpts->temperature.has_value())
                        config.model_config.funasr_nano.temperature = *funasrNanoOpts->temperature;
                    if (funasrNanoOpts->top_p.has_value())
                        config.model_config.funasr_nano.top_p = *funasrNanoOpts->top_p;
                    if (funasrNanoOpts->seed.has_value())
                        config.model_config.funasr_nano.seed = *funasrNanoOpts->seed;
                    if (funasrNanoOpts->language.has_value())
                        config.model_config.funasr_nano.language = *funasrNanoOpts->language;
                    if (funasrNanoOpts->itn.has_value())
                        config.model_config.funasr_nano.itn = *funasrNanoOpts->itn;
                    if (funasrNanoOpts->hotwords.has_value())
                        config.model_config.funasr_nano.hotwords = *funasrNanoOpts->hotwords;
                }
                break;
            default:
                break;
        }

        // Reject hotwords up front when the model cannot use them or the file is malformed.
        if (hotwordsFile.has_value() && !hotwordsFile->empty()) {
            if (!SupportsHotwords(detect.selectedKind)) {
                result.success = false;
                result.error = "HOTWORDS_NOT_SUPPORTED: Hotwords are only supported for transducer models (transducer, nemo_transducer). Current model type is not transducer.";
                LOGE("%s", result.error.c_str());
                return result;
            }
            auto validateErr = ValidateHotwordsFile(*hotwordsFile);
            if (validateErr.has_value()) {
                result.success = false;
                result.error = "INVALID_HOTWORDS_FILE: " + *validateErr;
                LOGE("%s", result.error.c_str());
                return result;
            }
        }

        config.decoding_method = "greedy_search";
        config.model_config.num_threads = numThreads.value_or(1);
        config.model_config.provider = provider.value_or("cpu");
        if (hotwordsFile.has_value() && !hotwordsFile->empty()) {
            // A hotwords file switches decoding to beam search with at least 4 active paths.
            config.hotwords_file = *hotwordsFile;
            config.decoding_method = "modified_beam_search";
            config.max_active_paths = std::max(4, config.max_active_paths);
        }
        if (hotwordsScore.has_value()) {
            config.hotwords_score = *hotwordsScore;
        }
        if (ruleFsts.has_value() && !ruleFsts->empty()) {
            config.rule_fsts = *ruleFsts;
        }
        if (ruleFars.has_value() && !ruleFars->empty()) {
            config.rule_fars = *ruleFars;
        }
        (void)dither;  // FeatureConfig in bundled cxx-api.h has no dither; reserve for future use

        bool isWhisperModel = !config.model_config.whisper.encoder.empty() && !config.model_config.whisper.decoder.empty();
        if (isWhisperModel) {
            LOGI("Initializing Whisper model with encoder: %s, decoder: %s", config.model_config.whisper.encoder.c_str(), config.model_config.whisper.decoder.c_str());
        } else {
            LOGI("Initializing non-Whisper model");
        }
        try {
            pImpl->recognizer = sherpa_onnx::cxx::OfflineRecognizer::Create(config);
        } catch (const std::exception& e) {
            LOGE("Failed to create recognizer: %s", e.what());
            result.success = false;
            result.error = std::string("INIT_ERROR: ") + e.what();
            return result;
        } catch (...) {
            LOGE("Unknown exception during recognizer creation");
            result.success = false;
            result.error = "INIT_ERROR: Unknown exception during recognizer creation";
            return result;
        }

        // Create() wraps the C API, which reports a rejected config or an
        // unloadable model by returning a null handle instead of throwing.
        // Without this check such a failure would be reported as success.
        if (!pImpl->recognizer.has_value() || pImpl->recognizer->Get() == nullptr) {
            pImpl->recognizer.reset();
            result.success = false;
            result.error = "INIT_ERROR: Failed to create recognizer (check model files and configuration)";
            LOGE("%s", result.error.c_str());
            return result;
        }

        pImpl->lastConfig = config;
        pImpl->modelDir = modelDir;
        pImpl->currentModelKind = detect.selectedKind;
        pImpl->initialized = true;

        result.success = true;
        result.detectedModels = detect.detectedModels;
        result.modelType = detect.detectedModels.empty() ? "" : detect.detectedModels[0].type;
        result.decodingMethod = config.decoding_method;
        return result;
    } catch (const std::exception& e) {
        result.error = std::string("Exception during initialization: ") + e.what();
        LOGE("%s", result.error.c_str());
        return result;
    } catch (...) {
        result.error = "Unknown exception during initialization";
        LOGE("%s", result.error.c_str());
        return result;
    }
}
|
|
368
|
+
|
|
369
|
+
namespace {
|
|
370
|
+
SttRecognitionResult offlineResultToSttResult(const sherpa_onnx::cxx::OfflineRecognizerResult& r) {
|
|
371
|
+
SttRecognitionResult out;
|
|
372
|
+
out.text = r.text;
|
|
373
|
+
out.tokens = r.tokens;
|
|
374
|
+
out.timestamps = r.timestamps;
|
|
375
|
+
out.lang = r.lang;
|
|
376
|
+
out.emotion = r.emotion;
|
|
377
|
+
out.event = r.event;
|
|
378
|
+
out.durations = r.durations;
|
|
379
|
+
return out;
|
|
380
|
+
}
|
|
381
|
+
} // namespace
|
|
382
|
+
|
|
383
|
+
// Transcribes the audio file at `filePath` using the current offline recognizer.
//
// Throws std::runtime_error when:
//   - the wrapper is not initialized or holds no recognizer,
//   - the file does not exist,
//   - ReadWave produced no samples,
//   - the sample count or sample rate does not fit the int32_t arguments
//     handed to AcceptWaveform.
// Exceptions derived from std::exception that escape ReadWave or the decoding
// steps are logged and rethrown unchanged; any other exception type is logged
// and replaced by a std::runtime_error.
SttRecognitionResult SttWrapper::transcribeFile(const std::string& filePath) {
    const bool ready = pImpl->initialized && pImpl->recognizer.has_value();
    if (!ready) {
        LOGE("Not initialized. Call initialize() first.");
        throw std::runtime_error("STT not initialized. Call initialize() first.");
    }

    LOGI("Transcribe: file=%s", filePath.c_str());
    if (!fs::exists(filePath)) {
        LOGE("Audio file not found: %s", filePath.c_str());
        throw std::runtime_error(std::string("Audio file not found: ") + filePath);
    }

    // Load the audio; reading errors are logged here before they propagate.
    sherpa_onnx::cxx::Wave audio;
    try {
        audio = sherpa_onnx::cxx::ReadWave(filePath);
    } catch (const std::exception& readError) {
        LOGE("Transcribe: ReadWave failed: %s", readError.what());
        throw;
    } catch (...) {
        LOGE("Transcribe: ReadWave failed (unknown exception)");
        throw std::runtime_error(std::string("Failed to read audio file: ") + filePath);
    }

    if (audio.samples.empty()) {
        LOGE("Audio file is empty or failed to read: %s", filePath.c_str());
        throw std::runtime_error(std::string("Audio file is empty or could not be read: ") + filePath);
    }

    try {
        auto& recognizer = pImpl->recognizer.value();
        auto stream = recognizer.CreateStream();

        // AcceptWaveform takes 32-bit ints, so reject anything that would not
        // survive the narrowing casts below.
        const size_t sampleCount = audio.samples.size();
        if (sampleCount > static_cast<size_t>(std::numeric_limits<int32_t>::max())) {
            LOGE("Audio too large: sample count %zu exceeds int32_t max", sampleCount);
            throw std::runtime_error("Audio too large to process");
        }

        // NOTE(review): the comparison is carried out in uint32_t and the log
        // uses %u; confirm that Wave::sample_rate is actually unsigned.
        if (audio.sample_rate > static_cast<uint32_t>(std::numeric_limits<int32_t>::max())) {
            LOGE("Sample rate too large: %u", audio.sample_rate);
            throw std::runtime_error("Unsupported sample rate");
        }

        const int32_t rate = static_cast<int32_t>(audio.sample_rate);
        const int32_t count = static_cast<int32_t>(sampleCount);

        stream.AcceptWaveform(rate, audio.samples.data(), count);
        recognizer.Decode(&stream);
        return offlineResultToSttResult(recognizer.GetResult(&stream));
    } catch (const std::exception& decodeError) {
        LOGE("Transcribe: recognition failed: %s", decodeError.what());
        throw;
    } catch (...) {
        LOGE("Transcribe: recognition failed (unknown exception)");
        throw std::runtime_error(
            "Recognition failed. Ensure the model supports offline decoding and audio is 16 kHz mono WAV."
        );
    }
}
|
|
448
|
+
|
|
449
|
+
// Transcribes in-memory float samples at `sampleRate` using the current
// offline recognizer.
//
// Returns a default-constructed SttRecognitionResult when `samples` is empty.
// Throws std::runtime_error when the wrapper is not initialized or when the
// sample count exceeds what an int32_t can hold. Exceptions derived from
// std::exception raised while decoding are logged and rethrown unchanged; any
// other exception type is logged and replaced by a std::runtime_error.
SttRecognitionResult SttWrapper::transcribeSamples(const std::vector<float>& samples, int32_t sampleRate) {
    const bool ready = pImpl->initialized && pImpl->recognizer.has_value();
    if (!ready) {
        LOGE("Not initialized. Call initialize() first.");
        throw std::runtime_error("STT not initialized. Call initialize() first.");
    }

    // Nothing to decode: hand back an untouched result object.
    if (samples.empty()) {
        SttRecognitionResult noResult;
        return noResult;
    }

    // The sample count is passed on as int32_t, so it must fit.
    const size_t sampleCount = samples.size();
    if (sampleCount > static_cast<size_t>(std::numeric_limits<int32_t>::max())) {
        LOGE("Samples too large: %zu", sampleCount);
        throw std::runtime_error("Samples array too large to process");
    }

    try {
        auto& recognizer = pImpl->recognizer.value();
        auto stream = recognizer.CreateStream();

        const int32_t count = static_cast<int32_t>(sampleCount);
        stream.AcceptWaveform(sampleRate, samples.data(), count);

        recognizer.Decode(&stream);
        return offlineResultToSttResult(recognizer.GetResult(&stream));
    } catch (const std::exception& decodeError) {
        LOGE("TranscribeSamples: recognition failed: %s", decodeError.what());
        throw;
    } catch (...) {
        LOGE("TranscribeSamples: recognition failed (unknown exception)");
        throw std::runtime_error("Recognition failed.");
    }
}
|
|
477
|
+
|
|
478
|
+
void SttWrapper::setConfig(const SttRuntimeConfigOptions& options) {
|
|
479
|
+
if (!pImpl->initialized || !pImpl->recognizer.has_value() || !pImpl->lastConfig.has_value()) {
|
|
480
|
+
LOGE("Not initialized or no stored config.");
|
|
481
|
+
throw std::runtime_error("STT not initialized. Call initialize() first.");
|
|
482
|
+
}
|
|
483
|
+
auto& config = pImpl->lastConfig.value();
|
|
484
|
+
if (options.hotwords_file.has_value() && !options.hotwords_file->empty()) {
|
|
485
|
+
if (!SupportsHotwords(pImpl->currentModelKind)) {
|
|
486
|
+
LOGE("Hotwords are only supported for transducer models.");
|
|
487
|
+
throw std::runtime_error("HOTWORDS_NOT_SUPPORTED: Hotwords are only supported for transducer models (transducer, nemo_transducer). Current model type is not transducer.");
|
|
488
|
+
}
|
|
489
|
+
auto validateErr = ValidateHotwordsFile(*options.hotwords_file);
|
|
490
|
+
if (validateErr.has_value()) {
|
|
491
|
+
LOGE("%s", validateErr->c_str());
|
|
492
|
+
throw std::runtime_error("INVALID_HOTWORDS_FILE: " + *validateErr);
|
|
493
|
+
}
|
|
494
|
+
}
|
|
495
|
+
if (options.decoding_method.has_value()) config.decoding_method = *options.decoding_method;
|
|
496
|
+
if (options.max_active_paths.has_value()) config.max_active_paths = *options.max_active_paths;
|
|
497
|
+
if (options.hotwords_file.has_value()) config.hotwords_file = *options.hotwords_file;
|
|
498
|
+
if (options.hotwords_score.has_value()) config.hotwords_score = *options.hotwords_score;
|
|
499
|
+
if (options.blank_penalty.has_value()) config.blank_penalty = *options.blank_penalty;
|
|
500
|
+
if (options.rule_fsts.has_value()) config.rule_fsts = *options.rule_fsts;
|
|
501
|
+
if (options.rule_fars.has_value()) config.rule_fars = *options.rule_fars;
|
|
502
|
+
if (!config.hotwords_file.empty()) {
|
|
503
|
+
config.decoding_method = "modified_beam_search";
|
|
504
|
+
config.max_active_paths = std::max(4, config.max_active_paths);
|
|
505
|
+
}
|
|
506
|
+
pImpl->recognizer.value().SetConfig(config);
|
|
507
|
+
}
|
|
508
|
+
|
|
509
|
+
bool SttWrapper::isInitialized() const {
|
|
510
|
+
return pImpl->initialized;
|
|
511
|
+
}
|
|
512
|
+
|
|
513
|
+
void SttWrapper::release() {
|
|
514
|
+
if (pImpl->initialized) {
|
|
515
|
+
pImpl->recognizer.reset();
|
|
516
|
+
pImpl->lastConfig.reset();
|
|
517
|
+
pImpl->initialized = false;
|
|
518
|
+
pImpl->modelDir.clear();
|
|
519
|
+
pImpl->currentModelKind = sherpaonnx::SttModelKind::kUnknown;
|
|
520
|
+
}
|
|
521
|
+
}
|
|
522
|
+
|
|
523
|
+
} // namespace sherpaonnx
|