react-native-executorch 0.5.3 → 0.5.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/cpp/CMakeLists.txt +2 -1
- package/common/rnexecutorch/data_processing/Numerical.cpp +27 -19
- package/common/rnexecutorch/data_processing/Numerical.h +53 -4
- package/common/rnexecutorch/data_processing/dsp.cpp +1 -1
- package/common/rnexecutorch/data_processing/dsp.h +1 -1
- package/common/rnexecutorch/data_processing/gzip.cpp +47 -0
- package/common/rnexecutorch/data_processing/gzip.h +7 -0
- package/common/rnexecutorch/host_objects/ModelHostObject.h +24 -0
- package/common/rnexecutorch/metaprogramming/TypeConcepts.h +21 -1
- package/common/rnexecutorch/models/BaseModel.cpp +3 -2
- package/common/rnexecutorch/models/BaseModel.h +3 -2
- package/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp +103 -39
- package/common/rnexecutorch/models/speech_to_text/SpeechToText.h +39 -21
- package/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp +310 -0
- package/common/rnexecutorch/models/speech_to_text/asr/ASR.h +62 -0
- package/common/rnexecutorch/models/speech_to_text/stream/HypothesisBuffer.cpp +82 -0
- package/common/rnexecutorch/models/speech_to_text/stream/HypothesisBuffer.h +25 -0
- package/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.cpp +99 -0
- package/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.h +33 -0
- package/common/rnexecutorch/models/speech_to_text/types/DecodingOptions.h +15 -0
- package/common/rnexecutorch/models/speech_to_text/types/GenerationResult.h +12 -0
- package/common/rnexecutorch/models/speech_to_text/types/ProcessResult.h +12 -0
- package/common/rnexecutorch/models/speech_to_text/types/Segment.h +14 -0
- package/common/rnexecutorch/models/speech_to_text/types/Word.h +13 -0
- package/lib/module/modules/natural_language_processing/SpeechToTextModule.js +75 -53
- package/lib/module/modules/natural_language_processing/SpeechToTextModule.js.map +1 -1
- package/lib/typescript/hooks/natural_language_processing/useSpeechToText.d.ts +5 -5
- package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts +7 -12
- package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts.map +1 -1
- package/lib/typescript/types/stt.d.ts +0 -9
- package/lib/typescript/types/stt.d.ts.map +1 -1
- package/package.json +1 -1
- package/react-native-executorch.podspec +2 -0
- package/src/modules/natural_language_processing/SpeechToTextModule.ts +118 -54
- package/src/types/stt.ts +0 -12
- package/common/rnexecutorch/models/EncoderDecoderBase.cpp +0 -21
- package/common/rnexecutorch/models/EncoderDecoderBase.h +0 -31
- package/common/rnexecutorch/models/speech_to_text/SpeechToTextStrategy.h +0 -27
- package/common/rnexecutorch/models/speech_to_text/WhisperStrategy.cpp +0 -50
- package/common/rnexecutorch/models/speech_to_text/WhisperStrategy.h +0 -25
- package/lib/Error.js +0 -53
- package/lib/ThreadPool.d.ts +0 -10
- package/lib/ThreadPool.js +0 -28
- package/lib/common/Logger.d.ts +0 -8
- package/lib/common/Logger.js +0 -19
- package/lib/constants/directories.js +0 -2
- package/lib/constants/llmDefaults.d.ts +0 -6
- package/lib/constants/llmDefaults.js +0 -16
- package/lib/constants/modelUrls.d.ts +0 -223
- package/lib/constants/modelUrls.js +0 -322
- package/lib/constants/ocr/models.d.ts +0 -882
- package/lib/constants/ocr/models.js +0 -182
- package/lib/constants/ocr/symbols.js +0 -139
- package/lib/constants/sttDefaults.d.ts +0 -28
- package/lib/constants/sttDefaults.js +0 -68
- package/lib/controllers/LLMController.d.ts +0 -47
- package/lib/controllers/LLMController.js +0 -213
- package/lib/controllers/OCRController.js +0 -67
- package/lib/controllers/SpeechToTextController.d.ts +0 -56
- package/lib/controllers/SpeechToTextController.js +0 -349
- package/lib/controllers/VerticalOCRController.js +0 -70
- package/lib/hooks/computer_vision/useClassification.d.ts +0 -15
- package/lib/hooks/computer_vision/useClassification.js +0 -7
- package/lib/hooks/computer_vision/useImageEmbeddings.d.ts +0 -15
- package/lib/hooks/computer_vision/useImageEmbeddings.js +0 -7
- package/lib/hooks/computer_vision/useImageSegmentation.d.ts +0 -38
- package/lib/hooks/computer_vision/useImageSegmentation.js +0 -7
- package/lib/hooks/computer_vision/useOCR.d.ts +0 -20
- package/lib/hooks/computer_vision/useOCR.js +0 -41
- package/lib/hooks/computer_vision/useObjectDetection.d.ts +0 -15
- package/lib/hooks/computer_vision/useObjectDetection.js +0 -7
- package/lib/hooks/computer_vision/useStyleTransfer.d.ts +0 -15
- package/lib/hooks/computer_vision/useStyleTransfer.js +0 -7
- package/lib/hooks/computer_vision/useVerticalOCR.d.ts +0 -21
- package/lib/hooks/computer_vision/useVerticalOCR.js +0 -43
- package/lib/hooks/general/useExecutorchModule.d.ts +0 -13
- package/lib/hooks/general/useExecutorchModule.js +0 -7
- package/lib/hooks/natural_language_processing/useLLM.d.ts +0 -10
- package/lib/hooks/natural_language_processing/useLLM.js +0 -78
- package/lib/hooks/natural_language_processing/useSpeechToText.d.ts +0 -27
- package/lib/hooks/natural_language_processing/useSpeechToText.js +0 -49
- package/lib/hooks/natural_language_processing/useTextEmbeddings.d.ts +0 -16
- package/lib/hooks/natural_language_processing/useTextEmbeddings.js +0 -7
- package/lib/hooks/natural_language_processing/useTokenizer.d.ts +0 -17
- package/lib/hooks/natural_language_processing/useTokenizer.js +0 -52
- package/lib/hooks/useModule.js +0 -45
- package/lib/hooks/useNonStaticModule.d.ts +0 -20
- package/lib/hooks/useNonStaticModule.js +0 -49
- package/lib/index.d.ts +0 -48
- package/lib/index.js +0 -58
- package/lib/module/utils/SpeechToTextModule/ASR.js +0 -191
- package/lib/module/utils/SpeechToTextModule/ASR.js.map +0 -1
- package/lib/module/utils/SpeechToTextModule/OnlineProcessor.js +0 -73
- package/lib/module/utils/SpeechToTextModule/OnlineProcessor.js.map +0 -1
- package/lib/module/utils/SpeechToTextModule/hypothesisBuffer.js +0 -56
- package/lib/module/utils/SpeechToTextModule/hypothesisBuffer.js.map +0 -1
- package/lib/module/utils/stt.js +0 -22
- package/lib/module/utils/stt.js.map +0 -1
- package/lib/modules/BaseModule.js +0 -25
- package/lib/modules/BaseNonStaticModule.js +0 -14
- package/lib/modules/computer_vision/ClassificationModule.d.ts +0 -8
- package/lib/modules/computer_vision/ClassificationModule.js +0 -17
- package/lib/modules/computer_vision/ImageEmbeddingsModule.d.ts +0 -8
- package/lib/modules/computer_vision/ImageEmbeddingsModule.js +0 -17
- package/lib/modules/computer_vision/ImageSegmentationModule.d.ts +0 -11
- package/lib/modules/computer_vision/ImageSegmentationModule.js +0 -27
- package/lib/modules/computer_vision/OCRModule.d.ts +0 -14
- package/lib/modules/computer_vision/OCRModule.js +0 -17
- package/lib/modules/computer_vision/ObjectDetectionModule.d.ts +0 -9
- package/lib/modules/computer_vision/ObjectDetectionModule.js +0 -17
- package/lib/modules/computer_vision/StyleTransferModule.d.ts +0 -8
- package/lib/modules/computer_vision/StyleTransferModule.js +0 -17
- package/lib/modules/computer_vision/VerticalOCRModule.d.ts +0 -14
- package/lib/modules/computer_vision/VerticalOCRModule.js +0 -19
- package/lib/modules/general/ExecutorchModule.d.ts +0 -7
- package/lib/modules/general/ExecutorchModule.js +0 -14
- package/lib/modules/natural_language_processing/LLMModule.d.ts +0 -28
- package/lib/modules/natural_language_processing/LLMModule.js +0 -45
- package/lib/modules/natural_language_processing/SpeechToTextModule.d.ts +0 -24
- package/lib/modules/natural_language_processing/SpeechToTextModule.js +0 -36
- package/lib/modules/natural_language_processing/TextEmbeddingsModule.d.ts +0 -9
- package/lib/modules/natural_language_processing/TextEmbeddingsModule.js +0 -21
- package/lib/modules/natural_language_processing/TokenizerModule.d.ts +0 -12
- package/lib/modules/natural_language_processing/TokenizerModule.js +0 -30
- package/lib/native/NativeETInstaller.js +0 -2
- package/lib/native/NativeOCR.js +0 -2
- package/lib/native/NativeVerticalOCR.js +0 -2
- package/lib/native/RnExecutorchModules.d.ts +0 -7
- package/lib/native/RnExecutorchModules.js +0 -18
- package/lib/tsconfig.tsbuildinfo +0 -1
- package/lib/types/common.d.ts +0 -32
- package/lib/types/common.js +0 -25
- package/lib/types/imageSegmentation.js +0 -26
- package/lib/types/llm.d.ts +0 -46
- package/lib/types/llm.js +0 -9
- package/lib/types/objectDetection.js +0 -94
- package/lib/types/ocr.js +0 -1
- package/lib/types/stt.d.ts +0 -94
- package/lib/types/stt.js +0 -85
- package/lib/typescript/utils/SpeechToTextModule/ASR.d.ts +0 -27
- package/lib/typescript/utils/SpeechToTextModule/ASR.d.ts.map +0 -1
- package/lib/typescript/utils/SpeechToTextModule/OnlineProcessor.d.ts +0 -23
- package/lib/typescript/utils/SpeechToTextModule/OnlineProcessor.d.ts.map +0 -1
- package/lib/typescript/utils/SpeechToTextModule/hypothesisBuffer.d.ts +0 -13
- package/lib/typescript/utils/SpeechToTextModule/hypothesisBuffer.d.ts.map +0 -1
- package/lib/typescript/utils/stt.d.ts +0 -2
- package/lib/typescript/utils/stt.d.ts.map +0 -1
- package/lib/utils/ResourceFetcher.d.ts +0 -24
- package/lib/utils/ResourceFetcher.js +0 -305
- package/lib/utils/ResourceFetcherUtils.d.ts +0 -54
- package/lib/utils/ResourceFetcherUtils.js +0 -127
- package/lib/utils/llm.d.ts +0 -6
- package/lib/utils/llm.js +0 -72
- package/lib/utils/stt.js +0 -21
- package/src/utils/SpeechToTextModule/ASR.ts +0 -303
- package/src/utils/SpeechToTextModule/OnlineProcessor.ts +0 -87
- package/src/utils/SpeechToTextModule/hypothesisBuffer.ts +0 -79
- package/src/utils/stt.ts +0 -28
package/common/rnexecutorch/data_processing/Numerical.cpp

@@ -9,7 +9,7 @@
 #include <string>
 
 namespace rnexecutorch::numerical {
-void softmax(std::vector<float> &v) {
+void softmax(std::span<float> v) {
   float max = *std::max_element(v.begin(), v.end());
 
   float sum = 0.0f;
@@ -22,32 +22,40 @@ void softmax(std::vector<float> &v) {
   }
 }
 
-void
-
-
-    sum += val * val;
+void softmaxWithTemperature(std::span<float> input, float temperature) {
+  if (input.empty()) {
+    return;
   }
 
-  if (
-
+  if (temperature <= 0.0F) {
+    throw std::invalid_argument(
+        "Temperature must be greater than 0 for softmax with temperature.");
   }
 
-
-
-
+  const auto maxElement = *std::ranges::max_element(input);
+
+  for (auto &value : input) {
+    value = std::exp((value - maxElement) / temperature);
   }
-}
 
-
-
-
-
+  const auto sum = std::reduce(input.begin(), input.end());
+
+  // sum is at least 1 since exp(max - max) == exp(0) == 1
+  for (auto &value : input) {
+    value /= sum;
   }
+}
 
-
-
-
-
+void normalize(std::span<float> input) {
+  const auto sumOfSquares =
+      std::inner_product(input.begin(), input.end(), input.begin(), 0.0F);
+
+  constexpr auto kEpsilon = 1.0e-15F;
+
+  const auto norm = std::sqrt(sumOfSquares) + kEpsilon;
+
+  for (auto &value : input) {
+    value /= norm;
   }
 }
 
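For context, here is a minimal usage sketch of the new span-based API above. The include path and the standalone main harness are assumptions for illustration, not part of the package:

#include <cstdio>
#include <span>
#include <vector>

#include "rnexecutorch/data_processing/Numerical.h" // assumed include path

int main() {
  std::vector<float> logits{2.0F, 1.0F, 0.1F};

  // Plain softmax operates in-place through a span view of the vector.
  std::vector<float> probs = logits;
  rnexecutorch::numerical::softmax(probs);

  // Temperature < 1 sharpens the distribution, > 1 flattens it, <= 0 throws.
  std::vector<float> sharp = logits;
  rnexecutorch::numerical::softmaxWithTemperature(sharp, 0.5F);

  for (size_t i = 0; i < logits.size(); ++i) {
    std::printf("%zu: softmax=%.3f  T=0.5 softmax=%.3f\n", i, probs[i],
                sharp[i]);
  }
  return 0;
}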
package/common/rnexecutorch/data_processing/Numerical.h

@@ -4,10 +4,59 @@
 #include <vector>
 
 namespace rnexecutorch::numerical {
-
-
-
-
+
+/**
+ * @brief Applies the softmax function in-place to a sequence of numbers.
+ *
+ * @param input A mutable span of floating-point numbers. After the function
+ * returns, `input` contains the softmax probabilities.
+ */
+void softmax(std::span<float> input);
+
+/**
+ * @brief Applies the softmax function with temperature scaling in-place to a
+ * sequence of numbers.
+ *
+ * The temperature parameter controls the "sharpness" of the resulting
+ * probability distribution. A temperature of 1.0 means no scaling, while lower
+ * values make the distribution sharper (more peaked), and higher values make it
+ * softer (more uniform).
+ *
+ * @param input A mutable span of floating-point numbers. After the function
+ * returns, `input` contains the softmax probabilities.
+ * @param temperature A positive float value used to scale the logits before
+ * applying softmax. Must be greater than 0.
+ */
+void softmaxWithTemperature(std::span<float> input, float temperature);
+
+/**
+ * @brief Normalizes the elements of the given float span in-place using the
+ * L2 norm method.
+ *
+ * This function scales the input vector such that its L2 norm (Euclidean norm)
+ * becomes 1. If the norm is zero, the result is a zero vector with the same
+ * size as the input.
+ *
+ * @param input A mutable span of floating-point values representing the data to
+ * be normalized.
+ */
+void normalize(std::span<float> input);
+
+/**
+ * @brief Computes mean pooling across the modelOutput adjusted by an attention
+ * mask.
+ *
+ * This function aggregates the `modelOutput` span by sections defined by
+ * `attnMask`, computing the mean of sections influenced by the mask. The result
+ * is a vector where each element is the mean of a segment from the original
+ * data.
+ *
+ * @param modelOutput A span of floating-point numbers representing the model
+ * output.
+ * @param attnMask A span of integers where each integer is a weight
+ * corresponding to the elements in `modelOutput`.
+ * @return A std::vector<float> containing the computed mean values of segments.
+ */
 std::vector<float> meanPooling(std::span<const float> modelOutput,
                                std::span<const int64_t> attnMask);
 /**
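The meanPooling documentation above describes masked mean pooling over token embeddings. As an illustration only (this is not the package's implementation; the row-major [tokens x hiddenSize] layout, the hiddenSize parameter, and treating the mask as binary are assumptions), the computation it describes typically looks like this:

#include <cstdint>
#include <span>
#include <vector>

// Illustrative sketch of attention-mask mean pooling; NOT the package's code.
std::vector<float> meanPoolingSketch(std::span<const float> modelOutput,
                                     std::span<const int64_t> attnMask,
                                     size_t hiddenSize) {
  std::vector<float> pooled(hiddenSize, 0.0F);
  float maskSum = 0.0F;
  for (size_t token = 0; token < attnMask.size(); ++token) {
    if (attnMask[token] == 0) {
      continue; // padding positions do not contribute
    }
    maskSum += 1.0F;
    for (size_t j = 0; j < hiddenSize; ++j) {
      pooled[j] += modelOutput[token * hiddenSize + j];
    }
  }
  if (maskSum > 0.0F) {
    for (float &value : pooled) {
      value /= maskSum; // mean over the unmasked tokens
    }
  }
  return pooled;
}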
package/common/rnexecutorch/data_processing/dsp.cpp

@@ -18,7 +18,7 @@ std::vector<float> hannWindow(size_t size) {
   return window;
 }
 
-std::vector<float> stftFromWaveform(std::span<float> waveform,
+std::vector<float> stftFromWaveform(std::span<const float> waveform,
                                     size_t fftWindowSize, size_t hopSize) {
   // Initialize FFT
   FFT fft(fftWindowSize);
package/common/rnexecutorch/data_processing/dsp.h

@@ -6,7 +6,7 @@
 namespace rnexecutorch::dsp {
 
 std::vector<float> hannWindow(size_t size);
-std::vector<float> stftFromWaveform(std::span<float> waveform,
+std::vector<float> stftFromWaveform(std::span<const float> waveform,
                                     size_t fftWindowSize, size_t hopSize);
 
 } // namespace rnexecutorch::dsp
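The only change in dsp is widening stftFromWaveform to take std::span<const float>. A small self-contained sketch (the stand-in function below is illustrative, not the package API) of what that buys callers:

#include <span>
#include <vector>

// Stand-in with the new-style parameter; a std::span<float> version could not
// be called with a const buffer.
static float firstSample(std::span<const float> waveform) {
  return waveform.empty() ? 0.0F : waveform.front();
}

int main() {
  const std::vector<float> waveform{0.1F, 0.2F, 0.3F};
  // std::span<const float> binds to const data, so read-only callers no longer
  // need to copy or const_cast before computing the STFT.
  return static_cast<int>(firstSample(waveform));
}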
package/common/rnexecutorch/data_processing/gzip.cpp (new file)

@@ -0,0 +1,47 @@
+#include <vector>
+#include <zlib.h>
+
+#include "gzip.h"
+
+namespace rnexecutorch::gzip {
+
+namespace {
+constexpr int32_t kGzipWrapper = 16;     // gzip header/trailer
+constexpr int32_t kMemLevel = 8;         // memory level
+constexpr size_t kChunkSize = 16 * 1024; // 16 KiB stream buffer
+} // namespace
+
+size_t deflateSize(const std::string &input) {
+  z_stream strm{};
+  if (::deflateInit2(&strm, Z_DEFAULT_COMPRESSION, Z_DEFLATED,
+                     MAX_WBITS + kGzipWrapper, kMemLevel,
+                     Z_DEFAULT_STRATEGY) != Z_OK) {
+    throw std::runtime_error("deflateInit2 failed");
+  }
+
+  size_t outSize = 0;
+
+  strm.next_in = reinterpret_cast<z_const Bytef *>(
+      const_cast<z_const char *>(input.data()));
+  strm.avail_in = static_cast<uInt>(input.size());
+
+  std::vector<unsigned char> buf(kChunkSize);
+  int ret;
+  do {
+    strm.next_out = buf.data();
+    strm.avail_out = static_cast<uInt>(buf.size());
+
+    ret = ::deflate(&strm, strm.avail_in ? Z_NO_FLUSH : Z_FINISH);
+    if (ret == Z_STREAM_ERROR) {
+      ::deflateEnd(&strm);
+      throw std::runtime_error("deflate stream error");
+    }
+
+    outSize += buf.size() - strm.avail_out;
+  } while (ret != Z_STREAM_END);
+
+  ::deflateEnd(&strm);
+  return outSize;
+}
+
+} // namespace rnexecutorch::gzip
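deflateSize streams the input through a fixed 16 KiB buffer and only accumulates the output byte count, so it reports the gzip-compressed size without keeping the compressed payload in memory. A hedged usage sketch (assuming gzip.h declares rnexecutorch::gzip::deflateSize and that the include path mirrors the file layout above):

#include <cstdio>
#include <string>

#include "rnexecutorch/data_processing/gzip.h" // assumed include path

int main() {
  const std::string text(1024, 'a'); // highly repetitive, compresses well
  const size_t compressedBytes = rnexecutorch::gzip::deflateSize(text);
  std::printf("raw=%zu bytes, gzip=%zu bytes\n", text.size(), compressedBytes);
  return 0;
}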
package/common/rnexecutorch/host_objects/ModelHostObject.h

@@ -62,6 +62,30 @@ public:
                                      "decode"));
   }
 
+  if constexpr (meta::HasTranscribe<Model>) {
+    addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
+                                     promiseHostFunction<&Model::transcribe>,
+                                     "transcribe"));
+  }
+
+  if constexpr (meta::HasStream<Model>) {
+    addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
+                                     promiseHostFunction<&Model::stream>,
+                                     "stream"));
+  }
+
+  if constexpr (meta::HasStreamInsert<Model>) {
+    addFunctions(JSI_EXPORT_FUNCTION(
+        ModelHostObject<Model>, promiseHostFunction<&Model::streamInsert>,
+        "streamInsert"));
+  }
+
+  if constexpr (meta::HasStreamStop<Model>) {
+    addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
+                                     promiseHostFunction<&Model::streamStop>,
+                                     "streamStop"));
+  }
+
   if constexpr (meta::SameAs<Model, TokenizerModule>) {
     addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
                                      promiseHostFunction<&Model::encode>,
package/common/rnexecutorch/metaprogramming/TypeConcepts.h

@@ -26,6 +26,26 @@ concept HasDecode = requires(T t) {
   { &T::decode };
 };
 
+template <typename T>
+concept HasTranscribe = requires(T t) {
+  { &T::transcribe };
+};
+
+template <typename T>
+concept HasStream = requires(T t) {
+  { &T::stream };
+};
+
+template <typename T>
+concept HasStreamInsert = requires(T t) {
+  { &T::streamInsert };
+};
+
+template <typename T>
+concept HasStreamStop = requires(T t) {
+  { &T::streamStop };
+};
+
 template <typename T>
 concept IsNumeric = std::is_arithmetic_v<T>;
 
@@ -34,4 +54,4 @@ concept ProvidesMemoryLowerBound = requires(T t) {
   { &T::getMemoryLowerBound };
 };
 
-} // namespace rnexecutorch::meta
+} // namespace rnexecutorch::meta
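These concepts only check that a member with the given name exists; together with the if constexpr blocks added to ModelHostObject.h above, they let the host object register a JSI binding only for model types that actually provide the method. A self-contained toy (not package code) showing the same detection pattern:

#include <iostream>
#include <string>

template <typename T>
concept HasTranscribe = requires(T t) {
  { &T::transcribe };
};

struct SpeechModel {
  std::string transcribe() { return "hello"; }
};
struct EmbeddingModel {};

template <typename Model> void registerBindings() {
  if constexpr (HasTranscribe<Model>) {
    std::cout << "registering transcribe\n"; // would call addFunctions(...)
  } else {
    std::cout << "no transcribe on this model\n";
  }
}

int main() {
  registerBindings<SpeechModel>();    // prints: registering transcribe
  registerBindings<EmbeddingModel>(); // prints: no transcribe on this model
}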
package/common/rnexecutorch/models/BaseModel.cpp

@@ -142,7 +142,8 @@ BaseModel::getMethodMeta(const std::string &methodName) {
   return module_->method_meta(methodName);
 }
 
-Result<std::vector<EValue>> BaseModel::forward(const EValue &input_evalue) {
+Result<std::vector<EValue>>
+BaseModel::forward(const EValue &input_evalue) const {
   if (!module_) {
     throw std::runtime_error("Model not loaded: Cannot perform forward pass");
   }
@@ -150,7 +151,7 @@ Result<std::vector<EValue>> BaseModel::forward(const EValue &input_evalue) {
 }
 
 Result<std::vector<EValue>>
-BaseModel::forward(const std::vector<EValue> &input_evalues) {
+BaseModel::forward(const std::vector<EValue> &input_evalues) const {
   if (!module_) {
     throw std::runtime_error("Model not loaded: Cannot perform forward pass");
   }
package/common/rnexecutorch/models/BaseModel.h

@@ -26,8 +26,9 @@ public:
   getAllInputShapes(std::string methodName = "forward");
   std::vector<JSTensorViewOut>
   forwardJS(std::vector<JSTensorViewIn> tensorViewVec);
-  Result<std::vector<EValue>> forward(const EValue &input_value);
-  Result<std::vector<EValue>>
+  Result<std::vector<EValue>> forward(const EValue &input_value) const;
+  Result<std::vector<EValue>>
+  forward(const std::vector<EValue> &input_value) const;
   Result<std::vector<EValue>> execute(const std::string &methodName,
                                       const std::vector<EValue> &input_value);
   Result<executorch::runtime::MethodMeta>
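Marking both forward overloads const is what allows inference to be driven through a const BaseModel pointer, for example from SpeechToText's own const methods such as encode and decode. A toy stand-in (not the package's BaseModel) showing the effect:

#include <vector>

struct Model {
  // const-qualified, mirroring the BaseModel::forward change above
  std::vector<float> forward(const std::vector<float> &input) const {
    return input; // placeholder for a real forward pass
  }
};

// Compiles only because forward() is const; a non-const forward() could not
// be called through a const reference.
float firstOutput(const Model &model) {
  return model.forward({1.0F, 2.0F}).front();
}

int main() { return static_cast<int>(firstOutput(Model{})); }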
package/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp

@@ -1,64 +1,128 @@
-#include <
-
-#include
+#include <thread>
+
+#include "SpeechToText.h"
 
 namespace rnexecutorch::models::speech_to_text {
 
 using namespace ::executorch::extension;
+using namespace asr;
+using namespace types;
+using namespace stream;
 
-SpeechToText::SpeechToText(const std::string &
-                           const std::string &
-                           const std::string &
+SpeechToText::SpeechToText(const std::string &encoderSource,
+                           const std::string &decoderSource,
+                           const std::string &tokenizerSource,
                            std::shared_ptr<react::CallInvoker> callInvoker)
-    :
-
-
+    : callInvoker(std::move(callInvoker)),
+      encoder(std::make_unique<BaseModel>(encoderSource, this->callInvoker)),
+      decoder(std::make_unique<BaseModel>(decoderSource, this->callInvoker)),
+      tokenizer(std::make_unique<TokenizerModule>(tokenizerSource,
+                                                  this->callInvoker)),
+      asr(std::make_unique<ASR>(this->encoder.get(), this->decoder.get(),
+                                this->tokenizer.get())),
+      processor(std::make_unique<OnlineASRProcessor>(this->asr.get())),
+      isStreaming(false), readyToProcess(false) {}
+
+std::shared_ptr<OwningArrayBuffer>
+SpeechToText::encode(std::span<float> waveform) const {
+  std::vector<float> encoderOutput = this->asr->encode(waveform);
+  return this->makeOwningBuffer(encoderOutput);
 }
 
-
-
-
-
-
-                          ". Only 'whisper' is supported.");
-}
+std::shared_ptr<OwningArrayBuffer>
+SpeechToText::decode(std::span<int32_t> tokens,
+                     std::span<float> encoderOutput) const {
+  std::vector<float> decoderOutput = this->asr->decode(tokens, encoderOutput);
+  return this->makeOwningBuffer(decoderOutput);
 }
 
-
-
+std::string SpeechToText::transcribe(std::span<float> waveform,
+                                     std::string languageOption) const {
+  std::vector<Segment> segments =
+      this->asr->transcribe(waveform, DecodingOptions(languageOption));
+  std::string transcription;
+
+  size_t transcriptionLength = 0;
+  for (auto &segment : segments) {
+    for (auto &word : segment.words) {
+      transcriptionLength += word.content.size();
+    }
+  }
+  transcription.reserve(transcriptionLength);
 
-
-
-
-
-        std::to_string(static_cast<int>(result.error())));
+  for (auto &segment : segments) {
+    for (auto &word : segment.words) {
+      transcription += word.content;
+    }
   }
+  return transcription;
+}
 
-
+size_t SpeechToText::getMemoryLowerBound() const noexcept {
+  return this->encoder->getMemoryLowerBound() +
+         this->decoder->getMemoryLowerBound() +
+         this->tokenizer->getMemoryLowerBound();
 }
 
 std::shared_ptr<OwningArrayBuffer>
-SpeechToText::
-
-
-
+SpeechToText::makeOwningBuffer(std::span<const float> vectorView) const {
+  auto owningArrayBuffer =
+      std::make_shared<OwningArrayBuffer>(vectorView.size_bytes());
+  std::memcpy(owningArrayBuffer->data(), vectorView.data(),
+              vectorView.size_bytes());
+  return owningArrayBuffer;
+}
+
+void SpeechToText::stream(std::shared_ptr<jsi::Function> callback,
+                          std::string languageOption) {
+  if (this->isStreaming) {
+    throw std::runtime_error("Streaming is already in progress");
+  }
+
+  auto nativeCallback = [this, callback](const std::string &committed,
+                                         const std::string &nonCommitted,
+                                         bool isDone) {
+    this->callInvoker->invokeAsync(
+        [callback, committed, nonCommitted, isDone](jsi::Runtime &rt) {
+          callback->call(rt, jsi::String::createFromUtf8(rt, committed),
+                         jsi::String::createFromUtf8(rt, nonCommitted),
+                         jsi::Value(isDone));
+        });
+  };
+
+  this->resetStreamState();
+
+  this->isStreaming = true;
+  while (this->isStreaming) {
+    if (!this->readyToProcess ||
+        this->processor->audioBuffer.size() < SpeechToText::kMinAudioSamples) {
+      std::this_thread::sleep_for(std::chrono::milliseconds(100));
+      continue;
+    }
+    ProcessResult res =
+        this->processor->processIter(DecodingOptions(languageOption));
+    nativeCallback(res.committed, res.nonCommitted, false);
+    this->readyToProcess = false;
   }
 
-
+  std::string committed = this->processor->finish();
+  nativeCallback(committed, "", true);
+}
 
-
-  const auto decoderResult =
-      decoder_->execute(decoderMethod, {prevTokensTensor, encoderOutput});
+void SpeechToText::streamStop() { this->isStreaming = false; }
 
-
-
-
-        std::to_string(static_cast<int>(decoderResult.error())));
+void SpeechToText::streamInsert(std::span<float> waveform) {
+  if (!this->isStreaming) {
+    throw std::runtime_error("Streaming is not started");
   }
+  this->processor->insertAudioChunk(waveform);
+  this->readyToProcess = true;
+}
 
-
-
-
+void SpeechToText::resetStreamState() {
+  this->isStreaming = false;
+  this->readyToProcess = false;
+  this->processor = std::make_unique<OnlineASRProcessor>(this->asr.get());
 }
 
 } // namespace rnexecutorch::models::speech_to_text
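The streaming implementation above is a simple producer/consumer loop: stream() polls until streamInsert() has buffered at least kMinAudioSamples and flagged readyToProcess, then emits committed and non-committed text through the JS callback, and streamStop() ends the loop before the final finish() flush. A simplified, self-contained sketch of that call sequence (FakeStreamer and its atomics are stand-ins, not the package's SpeechToText):

#include <atomic>
#include <chrono>
#include <thread>
#include <vector>

struct FakeStreamer {
  std::atomic<bool> isStreaming{false};
  std::atomic<bool> readyToProcess{false};

  void stream() { // runs on a worker thread, like the promise-backed JSI call
    isStreaming = true;
    while (isStreaming) {
      if (!readyToProcess) {
        std::this_thread::sleep_for(std::chrono::milliseconds(100));
        continue;
      }
      // processIter() + callback(committed, nonCommitted, false) go here
      readyToProcess = false;
    }
    // finish() + callback(committed, "", true) go here
  }

  void streamInsert(const std::vector<float> & /*waveform*/) {
    readyToProcess = true; // insertAudioChunk() would buffer the audio first
  }

  void streamStop() { isStreaming = false; }
};

int main() {
  FakeStreamer streamer;
  std::thread worker([&] { streamer.stream(); });
  streamer.streamInsert(std::vector<float>(16000, 0.0F)); // ~1 s at 16 kHz
  std::this_thread::sleep_for(std::chrono::milliseconds(300));
  streamer.streamStop();
  worker.join();
  return 0;
}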
package/common/rnexecutorch/models/speech_to_text/SpeechToText.h

@@ -1,38 +1,56 @@
 #pragma once
 
-#include "
-#include "executorch/runtime/core/evalue.h"
-#include <cstdint>
-#include <memory>
-#include <span>
-#include <string>
-#include <vector>
-
-#include "rnexecutorch/metaprogramming/ConstructorHelpers.h"
-#include <rnexecutorch/models/EncoderDecoderBase.h>
-#include <rnexecutorch/models/speech_to_text/SpeechToTextStrategy.h>
+#include "rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.h"
 
 namespace rnexecutorch {
+
 namespace models::speech_to_text {
-
+
+class SpeechToText {
 public:
-  explicit SpeechToText(const std::string &
-                        const std::string &
-                        const std::string &
+  explicit SpeechToText(const std::string &encoderSource,
+                        const std::string &decoderSource,
+                        const std::string &tokenizerSource,
                         std::shared_ptr<react::CallInvoker> callInvoker);
-
-  std::shared_ptr<OwningArrayBuffer>
+
+  std::shared_ptr<OwningArrayBuffer> encode(std::span<float> waveform) const;
+  std::shared_ptr<OwningArrayBuffer>
+  decode(std::span<int32_t> tokens, std::span<float> encoderOutput) const;
+  std::string transcribe(std::span<float> waveform,
+                         std::string languageOption) const;
+
+  size_t getMemoryLowerBound() const noexcept;
+
+  // Stream
+  void stream(std::shared_ptr<jsi::Function> callback,
+              std::string languageOption);
+  void streamStop();
+  void streamInsert(std::span<float> waveform);
 
 private:
-
-
-  std::unique_ptr<
+  std::unique_ptr<BaseModel> encoder;
+  std::unique_ptr<BaseModel> decoder;
+  std::unique_ptr<TokenizerModule> tokenizer;
+  std::unique_ptr<asr::ASR> asr;
 
-
+  std::shared_ptr<OwningArrayBuffer>
+  makeOwningBuffer(std::span<const float> vectorView) const;
+
+  // Stream
+  std::shared_ptr<react::CallInvoker> callInvoker;
+  std::unique_ptr<stream::OnlineASRProcessor> processor;
+  bool isStreaming;
+  bool readyToProcess;
+
+  constexpr static int32_t kMinAudioSamples = 16000; // 1 second
+
+  void resetStreamState();
 };
+
 } // namespace models::speech_to_text
 
 REGISTER_CONSTRUCTOR(models::speech_to_text::SpeechToText, std::string,
                      std::string, std::string,
                      std::shared_ptr<react::CallInvoker>);
+
 } // namespace rnexecutorch