react-native-executorch 0.5.3 → 0.5.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. package/android/src/main/cpp/CMakeLists.txt +2 -1
  2. package/common/rnexecutorch/data_processing/Numerical.cpp +27 -19
  3. package/common/rnexecutorch/data_processing/Numerical.h +53 -4
  4. package/common/rnexecutorch/data_processing/dsp.cpp +1 -1
  5. package/common/rnexecutorch/data_processing/dsp.h +1 -1
  6. package/common/rnexecutorch/data_processing/gzip.cpp +47 -0
  7. package/common/rnexecutorch/data_processing/gzip.h +7 -0
  8. package/common/rnexecutorch/host_objects/ModelHostObject.h +24 -0
  9. package/common/rnexecutorch/metaprogramming/TypeConcepts.h +21 -1
  10. package/common/rnexecutorch/models/BaseModel.cpp +3 -2
  11. package/common/rnexecutorch/models/BaseModel.h +3 -2
  12. package/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp +100 -39
  13. package/common/rnexecutorch/models/speech_to_text/SpeechToText.h +43 -21
  14. package/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp +307 -0
  15. package/common/rnexecutorch/models/speech_to_text/asr/ASR.h +61 -0
  16. package/common/rnexecutorch/models/speech_to_text/stream/HypothesisBuffer.cpp +80 -0
  17. package/common/rnexecutorch/models/speech_to_text/stream/HypothesisBuffer.h +27 -0
  18. package/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.cpp +96 -0
  19. package/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.h +36 -0
  20. package/common/rnexecutorch/models/speech_to_text/types/DecodingOptions.h +15 -0
  21. package/common/rnexecutorch/models/speech_to_text/types/GenerationResult.h +12 -0
  22. package/common/rnexecutorch/models/speech_to_text/types/ProcessResult.h +12 -0
  23. package/common/rnexecutorch/models/speech_to_text/types/Segment.h +14 -0
  24. package/common/rnexecutorch/models/speech_to_text/types/Word.h +13 -0
  25. package/lib/module/modules/natural_language_processing/SpeechToTextModule.js +75 -53
  26. package/lib/module/modules/natural_language_processing/SpeechToTextModule.js.map +1 -1
  27. package/lib/typescript/hooks/natural_language_processing/useSpeechToText.d.ts +5 -5
  28. package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts +7 -12
  29. package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts.map +1 -1
  30. package/lib/typescript/types/stt.d.ts +0 -9
  31. package/lib/typescript/types/stt.d.ts.map +1 -1
  32. package/package.json +1 -1
  33. package/react-native-executorch.podspec +2 -0
  34. package/src/modules/natural_language_processing/SpeechToTextModule.ts +118 -54
  35. package/src/types/stt.ts +0 -12
  36. package/common/rnexecutorch/models/EncoderDecoderBase.cpp +0 -21
  37. package/common/rnexecutorch/models/EncoderDecoderBase.h +0 -31
  38. package/common/rnexecutorch/models/speech_to_text/SpeechToTextStrategy.h +0 -27
  39. package/common/rnexecutorch/models/speech_to_text/WhisperStrategy.cpp +0 -50
  40. package/common/rnexecutorch/models/speech_to_text/WhisperStrategy.h +0 -25
  41. package/lib/module/utils/SpeechToTextModule/ASR.js +0 -191
  42. package/lib/module/utils/SpeechToTextModule/ASR.js.map +0 -1
  43. package/lib/module/utils/SpeechToTextModule/OnlineProcessor.js +0 -73
  44. package/lib/module/utils/SpeechToTextModule/OnlineProcessor.js.map +0 -1
  45. package/lib/module/utils/SpeechToTextModule/hypothesisBuffer.js +0 -56
  46. package/lib/module/utils/SpeechToTextModule/hypothesisBuffer.js.map +0 -1
  47. package/lib/module/utils/stt.js +0 -22
  48. package/lib/module/utils/stt.js.map +0 -1
  49. package/lib/typescript/utils/SpeechToTextModule/ASR.d.ts +0 -27
  50. package/lib/typescript/utils/SpeechToTextModule/ASR.d.ts.map +0 -1
  51. package/lib/typescript/utils/SpeechToTextModule/OnlineProcessor.d.ts +0 -23
  52. package/lib/typescript/utils/SpeechToTextModule/OnlineProcessor.d.ts.map +0 -1
  53. package/lib/typescript/utils/SpeechToTextModule/hypothesisBuffer.d.ts +0 -13
  54. package/lib/typescript/utils/SpeechToTextModule/hypothesisBuffer.d.ts.map +0 -1
  55. package/lib/typescript/utils/stt.d.ts +0 -2
  56. package/lib/typescript/utils/stt.d.ts.map +0 -1
  57. package/src/utils/SpeechToTextModule/ASR.ts +0 -303
  58. package/src/utils/SpeechToTextModule/OnlineProcessor.ts +0 -87
  59. package/src/utils/SpeechToTextModule/hypothesisBuffer.ts +0 -79
  60. package/src/utils/stt.ts +0 -28

package/android/src/main/cpp/CMakeLists.txt
@@ -95,4 +95,5 @@ target_link_libraries(
   ${OPENCV_THIRD_PARTY_LIBS}
   executorch
   ${EXECUTORCH_LIBS}
-)
+  z
+)

package/common/rnexecutorch/data_processing/Numerical.cpp
@@ -9,7 +9,7 @@
 #include <string>
 
 namespace rnexecutorch::numerical {
-void softmax(std::vector<float> &v) {
+void softmax(std::span<float> v) {
   float max = *std::max_element(v.begin(), v.end());
 
   float sum = 0.0f;
@@ -22,32 +22,40 @@ void softmax(std::vector<float> &v) {
   }
 }
 
-void normalize(std::span<float> span) {
-  auto sum = 0.0f;
-  for (const auto &val : span) {
-    sum += val * val;
+void softmaxWithTemperature(std::span<float> input, float temperature) {
+  if (input.empty()) {
+    return;
   }
 
-  if (isClose(sum, 0.0f)) {
-    return;
+  if (temperature <= 0.0F) {
+    throw std::invalid_argument(
+        "Temperature must be greater than 0 for softmax with temperature.");
   }
 
-  float norm = std::sqrt(sum);
-  for (auto &val : span) {
-    val /= norm;
+  const auto maxElement = *std::ranges::max_element(input);
+
+  for (auto &value : input) {
+    value = std::exp((value - maxElement) / temperature);
   }
-}
 
-void normalize(std::vector<float> &v) {
-  float sum = 0.0f;
-  for (float &x : v) {
-    sum += x * x;
+  const auto sum = std::reduce(input.begin(), input.end());
+
+  // sum is at least 1 since exp(max - max) == exp(0) == 1
+  for (auto &value : input) {
+    value /= sum;
   }
+}
 
-  float norm =
-      std::max(std::sqrt(sum), 1e-9f); // Solely for preventing division by 0
-  for (float &x : v) {
-    x /= norm;
+void normalize(std::span<float> input) {
+  const auto sumOfSquares =
+      std::inner_product(input.begin(), input.end(), input.begin(), 0.0F);
+
+  constexpr auto kEpsilon = 1.0e-15F;
+
+  const auto norm = std::sqrt(sumOfSquares) + kEpsilon;
+
+  for (auto &value : input) {
+    value /= norm;
   }
 }
 
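
Worked example (illustrative, not part of the diff): for logits {0, ln 4}, softmaxWithTemperature with temperature 1.0 behaves like plain softmax and yields {0.2, 0.8}; temperature 2.0 halves the logit gap and softens the result to {1/3, 2/3}.

  std::vector<float> logits{0.0F, std::log(4.0F)};
  rnexecutorch::numerical::softmaxWithTemperature(logits, 1.0F);
  // logits ≈ {0.2F, 0.8F}

  std::vector<float> logits2{0.0F, std::log(4.0F)};
  rnexecutorch::numerical::softmaxWithTemperature(logits2, 2.0F);
  // logits2 ≈ {0.333F, 0.667F}: higher temperature, flatter distribution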

package/common/rnexecutorch/data_processing/Numerical.h
@@ -4,10 +4,59 @@
 #include <vector>
 
 namespace rnexecutorch::numerical {
-void softmax(std::vector<float> &v);
-void normalize(std::span<float> span);
-void normalize(std::vector<float> &v);
-void normalize(std::span<float> span);
+
+/**
+ * @brief Applies the softmax function in-place to a sequence of numbers.
+ *
+ * @param input A mutable span of floating-point numbers. After the function
+ * returns, `input` contains the softmax probabilities.
+ */
+void softmax(std::span<float> input);
+
+/**
+ * @brief Applies the softmax function with temperature scaling in-place to a
+ * sequence of numbers.
+ *
+ * The temperature parameter controls the "sharpness" of the resulting
+ * probability distribution. A temperature of 1.0 means no scaling, while lower
+ * values make the distribution sharper (more peaked), and higher values make it
+ * softer (more uniform).
+ *
+ * @param input A mutable span of floating-point numbers. After the function
+ * returns, `input` contains the softmax probabilities.
+ * @param temperature A positive float value used to scale the logits before
+ * applying softmax. Must be greater than 0.
+ */
+void softmaxWithTemperature(std::span<float> input, float temperature);
+
+/**
+ * @brief Normalizes the elements of the given float span in-place using the
+ * L2 norm method.
+ *
+ * This function scales the input vector such that its L2 norm (Euclidean norm)
+ * becomes 1. If the norm is zero, the result is a zero vector with the same
+ * size as the input.
+ *
+ * @param input A mutable span of floating-point values representing the data to
+ * be normalized.
+ */
+void normalize(std::span<float> input);
+
+/**
+ * @brief Computes mean pooling across the modelOutput adjusted by an attention
+ * mask.
+ *
+ * This function aggregates the `modelOutput` span by sections defined by
+ * `attnMask`, computing the mean of sections influenced by the mask. The result
+ * is a vector where each element is the mean of a segment from the original
+ * data.
+ *
+ * @param modelOutput A span of floating-point numbers representing the model
+ * output.
+ * @param attnMask A span of integers where each integer is a weight
+ * corresponding to the elements in `modelOutput`.
+ * @return A std::vector<float> containing the computed mean values of segments.
+ */
 std::vector<float> meanPooling(std::span<const float> modelOutput,
                                std::span<const int64_t> attnMask);
 /**
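
Worked example for the documented normalize() contract (illustrative, not from the package): the L2 norm of {3, 4} is sqrt(9 + 16) = 5, so normalizing yields {0.6, 0.8}; for an all-zero input the kEpsilon guard keeps the output a zero vector instead of dividing by zero.

  std::vector<float> v{3.0F, 4.0F};
  rnexecutorch::numerical::normalize(v);
  // v ≈ {0.6F, 0.8F}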

package/common/rnexecutorch/data_processing/dsp.cpp
@@ -18,7 +18,7 @@ std::vector<float> hannWindow(size_t size) {
   return window;
 }
 
-std::vector<float> stftFromWaveform(std::span<float> waveform,
+std::vector<float> stftFromWaveform(std::span<const float> waveform,
                                     size_t fftWindowSize, size_t hopSize) {
   // Initialize FFT
   FFT fft(fftWindowSize);

package/common/rnexecutorch/data_processing/dsp.h
@@ -6,7 +6,7 @@
 namespace rnexecutorch::dsp {
 
 std::vector<float> hannWindow(size_t size);
-std::vector<float> stftFromWaveform(std::span<float> waveform,
+std::vector<float> stftFromWaveform(std::span<const float> waveform,
                                     size_t fftWindowSize, size_t hopSize);
 
 } // namespace rnexecutorch::dsp

package/common/rnexecutorch/data_processing/gzip.cpp (new file)
@@ -0,0 +1,47 @@
+#include <vector>
+#include <zlib.h>
+
+#include "gzip.h"
+
+namespace rnexecutorch::gzip {
+
+namespace {
+constexpr int32_t kGzipWrapper = 16;     // gzip header/trailer
+constexpr int32_t kMemLevel = 8;         // memory level
+constexpr size_t kChunkSize = 16 * 1024; // 16 KiB stream buffer
+} // namespace
+
+size_t deflateSize(const std::string &input) {
+  z_stream strm{};
+  if (::deflateInit2(&strm, Z_DEFAULT_COMPRESSION, Z_DEFLATED,
+                     MAX_WBITS + kGzipWrapper, kMemLevel,
+                     Z_DEFAULT_STRATEGY) != Z_OK) {
+    throw std::runtime_error("deflateInit2 failed");
+  }
+
+  size_t outSize = 0;
+
+  strm.next_in = reinterpret_cast<z_const Bytef *>(
+      const_cast<z_const char *>(input.data()));
+  strm.avail_in = static_cast<uInt>(input.size());
+
+  std::vector<unsigned char> buf(kChunkSize);
+  int ret;
+  do {
+    strm.next_out = buf.data();
+    strm.avail_out = static_cast<uInt>(buf.size());
+
+    ret = ::deflate(&strm, strm.avail_in ? Z_NO_FLUSH : Z_FINISH);
+    if (ret == Z_STREAM_ERROR) {
+      ::deflateEnd(&strm);
+      throw std::runtime_error("deflate stream error");
+    }
+
+    outSize += buf.size() - strm.avail_out;
+  } while (ret != Z_STREAM_END);
+
+  ::deflateEnd(&strm);
+  return outSize;
+}
+
+} // namespace rnexecutorch::gzip

package/common/rnexecutorch/data_processing/gzip.h (new file)
@@ -0,0 +1,7 @@
+#pragma once
+
+namespace rnexecutorch::gzip {
+
+size_t deflateSize(const std::string &input);
+
+} // namespace rnexecutorch::gzip
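
Usage sketch for the new helper (a hypothetical call site; zlib must be linked, which is what the new `z` entry in CMakeLists.txt provides): deflateSize() streams the input through zlib's deflate in 16 KiB chunks and returns only the gzip-compressed byte count, never retaining the compressed bytes.

  std::string text(100000, 'a'); // highly repetitive input
  size_t compressedBytes = rnexecutorch::gzip::deflateSize(text);
  // compressedBytes is tiny relative to 100000: a run of one byte deflates well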

package/common/rnexecutorch/host_objects/ModelHostObject.h
@@ -62,6 +62,30 @@ public:
                                        "decode"));
     }
 
+    if constexpr (meta::HasTranscribe<Model>) {
+      addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
+                                       promiseHostFunction<&Model::transcribe>,
+                                       "transcribe"));
+    }
+
+    if constexpr (meta::HasStream<Model>) {
+      addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
+                                       promiseHostFunction<&Model::stream>,
+                                       "stream"));
+    }
+
+    if constexpr (meta::HasStreamInsert<Model>) {
+      addFunctions(JSI_EXPORT_FUNCTION(
+          ModelHostObject<Model>, promiseHostFunction<&Model::streamInsert>,
+          "streamInsert"));
+    }
+
+    if constexpr (meta::HasStreamStop<Model>) {
+      addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
+                                       promiseHostFunction<&Model::streamStop>,
+                                       "streamStop"));
+    }
+
     if constexpr (meta::SameAs<Model, TokenizerModule>) {
       addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
                                        promiseHostFunction<&Model::encode>,

package/common/rnexecutorch/metaprogramming/TypeConcepts.h
@@ -26,6 +26,26 @@ concept HasDecode = requires(T t) {
   { &T::decode };
 };
 
+template <typename T>
+concept HasTranscribe = requires(T t) {
+  { &T::transcribe };
+};
+
+template <typename T>
+concept HasStream = requires(T t) {
+  { &T::stream };
+};
+
+template <typename T>
+concept HasStreamInsert = requires(T t) {
+  { &T::streamInsert };
+};
+
+template <typename T>
+concept HasStreamStop = requires(T t) {
+  { &T::streamStop };
+};
+
 template <typename T>
 concept IsNumeric = std::is_arithmetic_v<T>;
 
@@ -34,4 +54,4 @@ concept ProvidesMemoryLowerBound = requires(T t) {
   { &T::getMemoryLowerBound };
 };
 
-} // namespace rnexecutorch::meta
+} // namespace rnexecutorch::meta
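
These concepts only check that a member with the given name exists (taking &T::transcribe is ill-formed otherwise), which is what lets ModelHostObject gate each JSI export behind if constexpr. A standalone sketch of the mechanism, with hypothetical toy types:

  template <typename T>
  concept HasTranscribe = requires(T t) {
    { &T::transcribe };
  };

  struct Whisperish {
    void transcribe(); // any member named transcribe satisfies the concept
  };
  struct Classifier {};

  static_assert(HasTranscribe<Whisperish>);
  static_assert(!HasTranscribe<Classifier>);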

package/common/rnexecutorch/models/BaseModel.cpp
@@ -142,7 +142,8 @@ BaseModel::getMethodMeta(const std::string &methodName) {
   return module_->method_meta(methodName);
 }
 
-Result<std::vector<EValue>> BaseModel::forward(const EValue &input_evalue) {
+Result<std::vector<EValue>>
+BaseModel::forward(const EValue &input_evalue) const {
   if (!module_) {
     throw std::runtime_error("Model not loaded: Cannot perform forward pass");
   }
@@ -150,7 +151,7 @@ Result<std::vector<EValue>> BaseModel::forward(const EValue &input_evalue) {
 }
 
 Result<std::vector<EValue>>
-BaseModel::forward(const std::vector<EValue> &input_evalues) {
+BaseModel::forward(const std::vector<EValue> &input_evalues) const {
   if (!module_) {
     throw std::runtime_error("Model not loaded: Cannot perform forward pass");
   }

package/common/rnexecutorch/models/BaseModel.h
@@ -26,8 +26,9 @@ public:
   getAllInputShapes(std::string methodName = "forward");
   std::vector<JSTensorViewOut>
   forwardJS(std::vector<JSTensorViewIn> tensorViewVec);
-  Result<std::vector<EValue>> forward(const EValue &input_value);
-  Result<std::vector<EValue>> forward(const std::vector<EValue> &input_value);
+  Result<std::vector<EValue>> forward(const EValue &input_value) const;
+  Result<std::vector<EValue>>
+  forward(const std::vector<EValue> &input_value) const;
   Result<std::vector<EValue>> execute(const std::string &methodName,
                                       const std::vector<EValue> &input_value);
   Result<executorch::runtime::MethodMeta>

package/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp
@@ -1,64 +1,125 @@
-#include <rnexecutorch/models/speech_to_text/SpeechToText.h>
-#include <rnexecutorch/models/speech_to_text/WhisperStrategy.h>
-#include <stdexcept>
+#include <thread>
+
+#include "SpeechToText.h"
 
 namespace rnexecutorch::models::speech_to_text {
 
 using namespace ::executorch::extension;
 
-SpeechToText::SpeechToText(const std::string &encoderPath,
-                           const std::string &decoderPath,
-                           const std::string &modelName,
+SpeechToText::SpeechToText(const std::string &encoderSource,
+                           const std::string &decoderSource,
+                           const std::string &tokenizerSource,
                            std::shared_ptr<react::CallInvoker> callInvoker)
-    : EncoderDecoderBase(encoderPath, decoderPath, callInvoker),
-      modelName(modelName) {
-  initializeStrategy();
+    : callInvoker(std::move(callInvoker)),
+      encoder(std::make_unique<BaseModel>(encoderSource, this->callInvoker)),
+      decoder(std::make_unique<BaseModel>(decoderSource, this->callInvoker)),
+      tokenizer(std::make_unique<TokenizerModule>(tokenizerSource,
+                                                  this->callInvoker)),
+      asr(std::make_unique<ASR>(this->encoder.get(), this->decoder.get(),
+                                this->tokenizer.get())),
+      processor(std::make_unique<OnlineASRProcessor>(this->asr.get())),
+      isStreaming(false), readyToProcess(false) {}
+
+std::shared_ptr<OwningArrayBuffer>
+SpeechToText::encode(std::span<float> waveform) const {
+  std::vector<float> encoderOutput = this->asr->encode(waveform);
+  return this->makeOwningBuffer(encoderOutput);
 }
 
-void SpeechToText::initializeStrategy() {
-  if (modelName == "whisper") {
-    strategy = std::make_unique<WhisperStrategy>();
-  } else {
-    throw std::runtime_error("Unsupported STT model: " + modelName +
-                             ". Only 'whisper' is supported.");
-  }
+std::shared_ptr<OwningArrayBuffer>
+SpeechToText::decode(std::span<int32_t> tokens,
+                     std::span<float> encoderOutput) const {
+  std::vector<float> decoderOutput = this->asr->decode(tokens, encoderOutput);
+  return this->makeOwningBuffer(decoderOutput);
 }
 
-void SpeechToText::encode(std::span<float> waveform) {
-  const auto modelInputTensor = strategy->prepareAudioInput(waveform);
+std::string SpeechToText::transcribe(std::span<float> waveform,
+                                     std::string languageOption) const {
+  std::vector<Segment> segments =
+      this->asr->transcribe(waveform, DecodingOptions(languageOption));
+  std::string transcription;
+
+  size_t transcriptionLength = 0;
+  for (auto &segment : segments) {
+    for (auto &word : segment.words) {
+      transcriptionLength += word.content.size();
+    }
+  }
+  transcription.reserve(transcriptionLength);
 
-  const auto result = encoder_->forward(modelInputTensor);
-  if (!result.ok()) {
-    throw std::runtime_error(
-        "Forward pass failed during encoding, error code: " +
-        std::to_string(static_cast<int>(result.error())));
+  for (auto &segment : segments) {
+    for (auto &word : segment.words) {
+      transcription += word.content;
+    }
   }
+  return transcription;
+}
 
-  encoderOutput = result.get().at(0);
+size_t SpeechToText::getMemoryLowerBound() const noexcept {
+  return this->encoder->getMemoryLowerBound() +
+         this->decoder->getMemoryLowerBound() +
+         this->tokenizer->getMemoryLowerBound();
 }
 
 std::shared_ptr<OwningArrayBuffer>
-SpeechToText::decode(std::vector<int64_t> prevTokens) {
-  if (encoderOutput.isNone()) {
-    throw std::runtime_error("Empty encodings on decode call, make sure to "
-                             "call encode() prior to decode()!");
+SpeechToText::makeOwningBuffer(std::span<const float> vectorView) const {
  auto owningArrayBuffer =
+      std::make_shared<OwningArrayBuffer>(vectorView.size_bytes());
+  std::memcpy(owningArrayBuffer->data(), vectorView.data(),
+              vectorView.size_bytes());
+  return owningArrayBuffer;
+}
+
+void SpeechToText::stream(std::shared_ptr<jsi::Function> callback,
+                          std::string languageOption) {
+  if (this->isStreaming) {
+    throw std::runtime_error("Streaming is already in progress");
+  }
+
+  auto nativeCallback = [this, callback](const std::string &committed,
+                                         const std::string &nonCommitted,
+                                         bool isDone) {
+    this->callInvoker->invokeAsync(
+        [callback, committed, nonCommitted, isDone](jsi::Runtime &rt) {
+          callback->call(rt, jsi::String::createFromUtf8(rt, committed),
+                         jsi::String::createFromUtf8(rt, nonCommitted),
+                         jsi::Value(isDone));
+        });
+  };
+
+  this->resetStreamState();
+
+  this->isStreaming = true;
+  while (this->isStreaming) {
+    if (!this->readyToProcess ||
+        this->processor->audioBuffer.size() < SpeechToText::kMinAudioSamples) {
+      std::this_thread::sleep_for(std::chrono::milliseconds(100));
+      continue;
+    }
+    ProcessResult res =
+        this->processor->processIter(DecodingOptions(languageOption));
+    nativeCallback(res.committed, res.nonCommitted, false);
+    this->readyToProcess = false;
   }
 
-  const auto prevTokensTensor = strategy->prepareTokenInput(prevTokens);
+  std::string committed = this->processor->finish();
+  nativeCallback(committed, "", true);
+}
 
-  const auto decoderMethod = strategy->getDecoderMethod();
-  const auto decoderResult =
-      decoder_->execute(decoderMethod, {prevTokensTensor, encoderOutput});
+void SpeechToText::streamStop() { this->isStreaming = false; }
 
-  if (!decoderResult.ok()) {
-    throw std::runtime_error(
-        "Forward pass failed during decoding, error code: " +
-        std::to_string(static_cast<int>(decoderResult.error())));
+void SpeechToText::streamInsert(std::span<float> waveform) {
+  if (!this->isStreaming) {
+    throw std::runtime_error("Streaming is not started");
   }
+  this->processor->insertAudioChunk(waveform);
+  this->readyToProcess = true;
+}
 
-  const auto decoderOutputTensor = decoderResult.get().at(0).toTensor();
-  const auto innerDim = decoderOutputTensor.size(1);
-  return strategy->extractOutputToken(decoderOutputTensor);
+void SpeechToText::resetStreamState() {
+  this->isStreaming = false;
+  this->readyToProcess = false;
+  this->processor = std::make_unique<OnlineASRProcessor>(this->asr.get());
 }
 
 } // namespace rnexecutorch::models::speech_to_text
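
The stream() loop above polls readyToProcess and waits for at least kMinAudioSamples (16000 samples, i.e. one second at 16 kHz) before each processIter() call, so it is meant to run off the JS thread while streamInsert() feeds it audio. A hedged sketch of the native-side lifecycle (names like stt, cb, and nextAudioChunk are illustrative, not package API):

  std::thread worker([&] {
    stt.stream(cb, "en"); // blocks: process loop, then finish() flushes the tail
  });

  // once streaming has started, keep appending 16 kHz float samples:
  std::vector<float> chunk = nextAudioChunk();
  stt.streamInsert(chunk); // buffers audio and sets readyToProcess

  stt.streamStop(); // flips isStreaming; the worker loop exits
  worker.join();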

package/common/rnexecutorch/models/speech_to_text/SpeechToText.h
@@ -1,38 +1,60 @@
 #pragma once
 
-#include "ReactCommon/CallInvoker.h"
-#include "executorch/runtime/core/evalue.h"
-#include <cstdint>
-#include <memory>
-#include <span>
-#include <string>
-#include <vector>
-
-#include "rnexecutorch/metaprogramming/ConstructorHelpers.h"
-#include <rnexecutorch/models/EncoderDecoderBase.h>
-#include <rnexecutorch/models/speech_to_text/SpeechToTextStrategy.h>
+#include "rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.h"
 
 namespace rnexecutorch {
+
 namespace models::speech_to_text {
-class SpeechToText : public EncoderDecoderBase {
+
+using namespace asr;
+using namespace types;
+using namespace stream;
+
+class SpeechToText {
 public:
-  explicit SpeechToText(const std::string &encoderPath,
-                        const std::string &decoderPath,
-                        const std::string &modelName,
+  explicit SpeechToText(const std::string &encoderSource,
+                        const std::string &decoderSource,
+                        const std::string &tokenizerSource,
                         std::shared_ptr<react::CallInvoker> callInvoker);
-  void encode(std::span<float> waveform);
-  std::shared_ptr<OwningArrayBuffer> decode(std::vector<int64_t> prevTokens);
+
+  std::shared_ptr<OwningArrayBuffer> encode(std::span<float> waveform) const;
+  std::shared_ptr<OwningArrayBuffer>
+  decode(std::span<int32_t> tokens, std::span<float> encoderOutput) const;
+  std::string transcribe(std::span<float> waveform,
+                         std::string languageOption) const;
+
+  size_t getMemoryLowerBound() const noexcept;
+
+  // Stream
+  void stream(std::shared_ptr<jsi::Function> callback,
+              std::string languageOption);
+  void streamStop();
+  void streamInsert(std::span<float> waveform);
 
 private:
-  const std::string modelName;
-  executorch::runtime::EValue encoderOutput;
-  std::unique_ptr<SpeechToTextStrategy> strategy;
+  std::unique_ptr<BaseModel> encoder;
+  std::unique_ptr<BaseModel> decoder;
+  std::unique_ptr<TokenizerModule> tokenizer;
+  std::unique_ptr<ASR> asr;
 
-  void initializeStrategy();
+  std::shared_ptr<OwningArrayBuffer>
+  makeOwningBuffer(std::span<const float> vectorView) const;
+
+  // Stream
+  std::shared_ptr<react::CallInvoker> callInvoker;
+  std::unique_ptr<OnlineASRProcessor> processor;
+  bool isStreaming;
+  bool readyToProcess;
+
+  constexpr static int32_t kMinAudioSamples = 16000; // 1 second
+
+  void resetStreamState();
 };
+
 } // namespace models::speech_to_text
 
 REGISTER_CONSTRUCTOR(models::speech_to_text::SpeechToText, std::string,
                      std::string, std::string,
                      std::shared_ptr<react::CallInvoker>);
+
 } // namespace rnexecutorch