react-native-executorch 0.5.3 → 0.5.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60) hide show
  1. package/android/src/main/cpp/CMakeLists.txt +2 -1
  2. package/common/rnexecutorch/data_processing/Numerical.cpp +27 -19
  3. package/common/rnexecutorch/data_processing/Numerical.h +53 -4
  4. package/common/rnexecutorch/data_processing/dsp.cpp +1 -1
  5. package/common/rnexecutorch/data_processing/dsp.h +1 -1
  6. package/common/rnexecutorch/data_processing/gzip.cpp +47 -0
  7. package/common/rnexecutorch/data_processing/gzip.h +7 -0
  8. package/common/rnexecutorch/host_objects/ModelHostObject.h +24 -0
  9. package/common/rnexecutorch/metaprogramming/TypeConcepts.h +21 -1
  10. package/common/rnexecutorch/models/BaseModel.cpp +3 -2
  11. package/common/rnexecutorch/models/BaseModel.h +3 -2
  12. package/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp +100 -39
  13. package/common/rnexecutorch/models/speech_to_text/SpeechToText.h +43 -21
  14. package/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp +307 -0
  15. package/common/rnexecutorch/models/speech_to_text/asr/ASR.h +61 -0
  16. package/common/rnexecutorch/models/speech_to_text/stream/HypothesisBuffer.cpp +80 -0
  17. package/common/rnexecutorch/models/speech_to_text/stream/HypothesisBuffer.h +27 -0
  18. package/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.cpp +96 -0
  19. package/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.h +36 -0
  20. package/common/rnexecutorch/models/speech_to_text/types/DecodingOptions.h +15 -0
  21. package/common/rnexecutorch/models/speech_to_text/types/GenerationResult.h +12 -0
  22. package/common/rnexecutorch/models/speech_to_text/types/ProcessResult.h +12 -0
  23. package/common/rnexecutorch/models/speech_to_text/types/Segment.h +14 -0
  24. package/common/rnexecutorch/models/speech_to_text/types/Word.h +13 -0
  25. package/lib/module/modules/natural_language_processing/SpeechToTextModule.js +75 -53
  26. package/lib/module/modules/natural_language_processing/SpeechToTextModule.js.map +1 -1
  27. package/lib/typescript/hooks/natural_language_processing/useSpeechToText.d.ts +5 -5
  28. package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts +7 -12
  29. package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts.map +1 -1
  30. package/lib/typescript/types/stt.d.ts +0 -9
  31. package/lib/typescript/types/stt.d.ts.map +1 -1
  32. package/package.json +1 -1
  33. package/react-native-executorch.podspec +2 -0
  34. package/src/modules/natural_language_processing/SpeechToTextModule.ts +118 -54
  35. package/src/types/stt.ts +0 -12
  36. package/common/rnexecutorch/models/EncoderDecoderBase.cpp +0 -21
  37. package/common/rnexecutorch/models/EncoderDecoderBase.h +0 -31
  38. package/common/rnexecutorch/models/speech_to_text/SpeechToTextStrategy.h +0 -27
  39. package/common/rnexecutorch/models/speech_to_text/WhisperStrategy.cpp +0 -50
  40. package/common/rnexecutorch/models/speech_to_text/WhisperStrategy.h +0 -25
  41. package/lib/module/utils/SpeechToTextModule/ASR.js +0 -191
  42. package/lib/module/utils/SpeechToTextModule/ASR.js.map +0 -1
  43. package/lib/module/utils/SpeechToTextModule/OnlineProcessor.js +0 -73
  44. package/lib/module/utils/SpeechToTextModule/OnlineProcessor.js.map +0 -1
  45. package/lib/module/utils/SpeechToTextModule/hypothesisBuffer.js +0 -56
  46. package/lib/module/utils/SpeechToTextModule/hypothesisBuffer.js.map +0 -1
  47. package/lib/module/utils/stt.js +0 -22
  48. package/lib/module/utils/stt.js.map +0 -1
  49. package/lib/typescript/utils/SpeechToTextModule/ASR.d.ts +0 -27
  50. package/lib/typescript/utils/SpeechToTextModule/ASR.d.ts.map +0 -1
  51. package/lib/typescript/utils/SpeechToTextModule/OnlineProcessor.d.ts +0 -23
  52. package/lib/typescript/utils/SpeechToTextModule/OnlineProcessor.d.ts.map +0 -1
  53. package/lib/typescript/utils/SpeechToTextModule/hypothesisBuffer.d.ts +0 -13
  54. package/lib/typescript/utils/SpeechToTextModule/hypothesisBuffer.d.ts.map +0 -1
  55. package/lib/typescript/utils/stt.d.ts +0 -2
  56. package/lib/typescript/utils/stt.d.ts.map +0 -1
  57. package/src/utils/SpeechToTextModule/ASR.ts +0 -303
  58. package/src/utils/SpeechToTextModule/OnlineProcessor.ts +0 -87
  59. package/src/utils/SpeechToTextModule/hypothesisBuffer.ts +0 -79
  60. package/src/utils/stt.ts +0 -28
@@ -0,0 +1,307 @@
1
+ #include <random>
2
+
3
+ #include "ASR.h"
4
+ #include "executorch/extension/tensor/tensor_ptr.h"
5
+ #include "rnexecutorch/data_processing/Numerical.h"
6
+ #include "rnexecutorch/data_processing/dsp.h"
7
+ #include "rnexecutorch/data_processing/gzip.h"
8
+
9
+ namespace rnexecutorch::models::speech_to_text::asr {
10
+
11
// Wires up the (non-owning) encoder/decoder/tokenizer and resolves the
// Whisper special-token ids once at construction so decoding does not
// repeatedly query the tokenizer.
ASR::ASR(const models::BaseModel *encoder, const models::BaseModel *decoder,
         const TokenizerModule *tokenizer)
    : encoder(encoder), decoder(decoder), tokenizer(tokenizer),
      startOfTranscriptionToken(
          this->tokenizer->tokenToId("<|startoftranscript|>")),
      endOfTranscriptionToken(this->tokenizer->tokenToId("<|endoftext|>")),
      timestampBeginToken(this->tokenizer->tokenToId("<|0.00|>")) {}
18
+
19
+ std::vector<int32_t>
20
+ ASR::getInitialSequence(const DecodingOptions &options) const {
21
+ std::vector<int32_t> seq;
22
+ seq.push_back(this->startOfTranscriptionToken);
23
+
24
+ if (options.language.has_value()) {
25
+ int32_t langToken =
26
+ this->tokenizer->tokenToId("<|" + options.language.value() + "|>");
27
+ int32_t taskToken = this->tokenizer->tokenToId("<|transcribe|>");
28
+ seq.push_back(langToken);
29
+ seq.push_back(taskToken);
30
+ }
31
+
32
+ seq.push_back(this->timestampBeginToken);
33
+
34
+ return seq;
35
+ }
36
+
37
+ GenerationResult ASR::generate(std::span<const float> waveform,
38
+ float temperature,
39
+ const DecodingOptions &options) const {
40
+ std::vector<float> encoderOutput = this->encode(waveform);
41
+
42
+ std::vector<int32_t> sequenceIds = this->getInitialSequence(options);
43
+ const size_t initialSequenceLenght = sequenceIds.size();
44
+ std::vector<float> scores;
45
+
46
+ while (std::cmp_less_equal(sequenceIds.size(), ASR::kMaxDecodeLength)) {
47
+ std::vector<float> logits = this->decode(sequenceIds, encoderOutput);
48
+
49
+ // intentionally comparing float to float
50
+ // temperatures are predefined, so this is safe
51
+ if (temperature == 0.0f) {
52
+ numerical::softmax(logits);
53
+ } else {
54
+ numerical::softmaxWithTemperature(logits, temperature);
55
+ }
56
+
57
+ const std::vector<float> &probs = logits;
58
+
59
+ int32_t nextId;
60
+ float nextProb;
61
+
62
+ // intentionally comparing float to float
63
+ // temperatures are predefined, so this is safe
64
+ if (temperature == 0.0f) {
65
+ auto maxIt = std::ranges::max_element(probs);
66
+ nextId = static_cast<int32_t>(std::distance(probs.begin(), maxIt));
67
+ nextProb = *maxIt;
68
+ } else {
69
+ std::discrete_distribution<> dist(probs.begin(), probs.end());
70
+ std::mt19937 gen((std::random_device{}()));
71
+ nextId = dist(gen);
72
+ nextProb = probs[nextId];
73
+ }
74
+
75
+ sequenceIds.push_back(nextId);
76
+ scores.push_back(nextProb);
77
+
78
+ if (nextId == this->endOfTranscriptionToken) {
79
+ break;
80
+ }
81
+ }
82
+
83
+ return {.tokens = std::vector<int32_t>(
84
+ sequenceIds.cbegin() + initialSequenceLenght, sequenceIds.cend()),
85
+ .scores = scores};
86
+ }
87
+
88
/// Ratio of raw text size to its deflate-compressed size. Highly repetitive
/// (degenerate) transcriptions compress well and therefore score high;
/// generateWithFallback rejects results whose ratio reaches 2.4.
float ASR::getCompressionRatio(const std::string &text) const {
  size_t compressedSize = gzip::deflateSize(text);
  // NOTE(review): assumes deflateSize never returns 0 (deflate output always
  // carries header bytes) — verify against gzip::deflateSize.
  return static_cast<float>(text.size()) / static_cast<float>(compressedSize);
}
92
+
93
+ std::vector<Segment>
94
+ ASR::generateWithFallback(std::span<const float> waveform,
95
+ const DecodingOptions &options) const {
96
+ std::vector<float> temperatures = {0.0f, 0.2f, 0.4f, 0.6f, 0.8f, 1.0f};
97
+ std::vector<int32_t> bestTokens;
98
+
99
+ for (auto t : temperatures) {
100
+ auto [tokens, scores] = this->generate(waveform, t, options);
101
+
102
+ const float cumLogProb = std::transform_reduce(
103
+ scores.begin(), scores.end(), 0.0f, std::plus<>(),
104
+ [](float s) { return std::log(std::max(s, 1e-9f)); });
105
+
106
+ const float avgLogProb = cumLogProb / static_cast<float>(tokens.size() + 1);
107
+ const std::string text = this->tokenizer->decode(tokens, true);
108
+ const float compressionRatio = this->getCompressionRatio(text);
109
+
110
+ if (avgLogProb >= -1.0f && compressionRatio < 2.4f) {
111
+ bestTokens = std::move(tokens);
112
+ break;
113
+ }
114
+ }
115
+
116
+ return this->calculateWordLevelTimestamps(bestTokens, waveform);
117
+ }
118
+
119
/// Converts a generated token stream into segments of word-level timestamps.
/// Tokens >= timestampBeginToken are Whisper timestamp tokens; each run of
/// text tokens between two consecutive timestamp tokens forms one segment
/// whose words receive linearly interpolated times.
std::vector<Segment>
ASR::calculateWordLevelTimestamps(std::span<const int32_t> generatedTokens,
                                  const std::span<const float> waveform) const {
  const size_t generatedTokensSize = generatedTokens.size();
  // A usable stream must end with <|endoftext|> preceded by a closing
  // timestamp token; anything else yields no segments.
  if (generatedTokensSize < 2 ||
      generatedTokens[generatedTokensSize - 1] !=
          this->endOfTranscriptionToken ||
      generatedTokens[generatedTokensSize - 2] < this->timestampBeginToken) {
    return {};
  }
  std::vector<Segment> segments;
  std::vector<int32_t> tokens;
  int32_t prevTimestamp = this->timestampBeginToken;

  for (size_t i = 0; i < generatedTokensSize; i++) {
    // Collect plain text tokens for the current segment.
    if (generatedTokens[i] < this->timestampBeginToken) {
      tokens.push_back(generatedTokens[i]);
    }
    // A pair of consecutive timestamp tokens closes the current segment.
    if (i > 0 && generatedTokens[i - 1] >= this->timestampBeginToken &&
        generatedTokens[i] >= this->timestampBeginToken) {
      const int32_t start = prevTimestamp;
      const int32_t end = generatedTokens[i - 1];
      auto words = this->estimateWordLevelTimestampsLinear(tokens, start, end);
      if (words.size()) {
        // NOTE(review): noSpeechProbability is hard-coded to 0.0 here.
        segments.emplace_back(std::move(words), 0.0);
      }
      tokens.clear();
      prevTimestamp = generatedTokens[i];
    }
  }

  // Flush the trailing segment closed by the final timestamp token.
  const int32_t start = prevTimestamp;
  const int32_t end = generatedTokens[generatedTokensSize - 2];
  auto words = this->estimateWordLevelTimestampsLinear(tokens, start, end);

  if (words.size()) {
    segments.emplace_back(std::move(words), 0.0);
  }

  // If the audio chunk is shorter than the span claimed by the timestamps,
  // compress all word times proportionally so they fit inside the chunk.
  float scalingFactor =
      static_cast<float>(waveform.size()) /
      (ASR::kSamplingRate * (end - this->timestampBeginToken) *
       ASR::kTimePrecision);
  if (scalingFactor < 1.0f) {
    for (auto &seg : segments) {
      for (auto &w : seg.words) {
        w.start *= scalingFactor;
        w.end *= scalingFactor;
      }
    }
  }

  return segments;
}
173
+
174
+ std::vector<Word>
175
+ ASR::estimateWordLevelTimestampsLinear(std::span<const int32_t> tokens,
176
+ int32_t start, int32_t end) const {
177
+ const std::vector<int32_t> tokensVec(tokens.begin(), tokens.end());
178
+ const std::string segmentText = this->tokenizer->decode(tokensVec, true);
179
+ std::istringstream iss(segmentText);
180
+ std::vector<std::string> wordsStr;
181
+ std::string word;
182
+ while (iss >> word) {
183
+ wordsStr.emplace_back(" ");
184
+ wordsStr.back().append(word);
185
+ }
186
+
187
+ size_t numChars = 0;
188
+ for (const auto &w : wordsStr) {
189
+ numChars += w.size();
190
+ }
191
+ const float duration = (end - start) * ASR::kTimePrecision;
192
+ const float timePerChar = duration / std::max<float>(1, numChars);
193
+ const float startOffset = (start - timestampBeginToken) * ASR::kTimePrecision;
194
+
195
+ std::vector<Word> wordObjs;
196
+ wordObjs.reserve(wordsStr.size());
197
+ int32_t prevCharCount = 0;
198
+ for (auto &w : wordsStr) {
199
+ const auto wSize = static_cast<int32_t>(w.size());
200
+ const float wStart = startOffset + prevCharCount * timePerChar;
201
+ const float wEnd = wStart + timePerChar * wSize;
202
+ prevCharCount += wSize;
203
+ wordObjs.emplace_back(std::move(w), wStart, wEnd);
204
+ }
205
+
206
+ return wordObjs;
207
+ }
208
+
209
/// Transcribes a full waveform by sliding a window of up to kChunkSize
/// seconds over it. `seek` tracks the current position in whole seconds;
/// after each chunk it jumps to the end time of the last transcribed word
/// so trailing audio is re-decoded with context.
std::vector<Segment> ASR::transcribe(std::span<const float> waveform,
                                     const DecodingOptions &options) const {
  int32_t seek = 0;
  std::vector<Segment> results;

  while (std::cmp_less(seek * ASR::kSamplingRate, waveform.size())) {
    int32_t start = seek * ASR::kSamplingRate;
    const auto end = std::min<int32_t>(
        (seek + ASR::kChunkSize) * ASR::kSamplingRate, waveform.size());
    std::span<const float> chunk = waveform.subspan(start, end - start);

    // Too little audio remaining for a reliable transcription.
    if (std::cmp_less(chunk.size(), ASR::kMinChunkSamples)) {
      break;
    }

    std::vector<Segment> segments = this->generateWithFallback(chunk, options);

    // Nothing decoded in this window; skip ahead a full chunk.
    if (segments.empty()) {
      seek += ASR::kChunkSize;
      continue;
    }

    // Word times are chunk-relative; shift them by the chunk start (seconds).
    for (auto &seg : segments) {
      for (auto &w : seg.words) {
        w.start += seek;
        w.end += seek;
      }
    }

    // Continue from the end of the last recognized word (truncated to whole
    // seconds, so a little audio is intentionally re-processed).
    seek = static_cast<int32_t>(segments.back().words.back().end);
    results.insert(results.end(), std::make_move_iterator(segments.begin()),
                   std::make_move_iterator(segments.end()));
  }

  return results;
}
245
+
246
/// Runs the encoder on a waveform: STFT preprocessing into
/// (numFrames, innerDim) frames, then a single forward pass.
/// @return The encoder output flattened into a single float vector.
/// @throws std::runtime_error when the forward pass fails.
std::vector<float> ASR::encode(std::span<const float> waveform) const {
  constexpr int32_t fftWindowSize = 512;
  constexpr int32_t stftHopLength = 160;
  // Spectrogram frame width expected by the encoder model.
  constexpr int32_t innerDim = 256;

  std::vector<float> preprocessedData =
      dsp::stftFromWaveform(waveform, fftWindowSize, stftHopLength);
  const auto numFrames =
      static_cast<int32_t>(preprocessedData.size()) / innerDim;
  std::vector<int32_t> inputShape = {numFrames, innerDim};

  const auto modelInputTensor = executorch::extension::make_tensor_ptr(
      std::move(inputShape), std::move(preprocessedData));
  const auto encoderResult = this->encoder->forward(modelInputTensor);

  if (!encoderResult.ok()) {
    throw std::runtime_error(
        "Forward pass failed during encoding, error code: " +
        std::to_string(static_cast<int32_t>(encoderResult.error())));
  }

  // Copy the first output tensor into an owned vector; the tensor's storage
  // belongs to the runtime and must not be referenced after return.
  const auto decoderOutputTensor = encoderResult.get().at(0).toTensor();
  const int32_t outputNumel = decoderOutputTensor.numel();

  const float *const dataPtr = decoderOutputTensor.const_data_ptr<float>();
  return {dataPtr, dataPtr + outputNumel};
}
273
+
274
/// Runs one decoder step over the current token sequence and the encoder
/// output, returning only the logits of the LAST sequence position (the
/// distribution over the next token).
/// @throws std::runtime_error when the forward pass fails.
std::vector<float> ASR::decode(std::span<int32_t> tokens,
                               std::span<float> encoderOutput) const {
  // Token ids as a (1, seqLen) int tensor; the tensor views `tokens` memory,
  // so `tokens` must stay alive for the duration of the forward call.
  std::vector<int32_t> tokenShape = {1, static_cast<int32_t>(tokens.size())};
  auto tokenTensor = executorch::extension::make_tensor_ptr(
      std::move(tokenShape), tokens.data(), ScalarType::Int);

  // Encoder output viewed as (1, kNumFrames, hiddenDim).
  const auto encoderOutputSize = static_cast<int32_t>(encoderOutput.size());
  std::vector<int32_t> encShape = {1, ASR::kNumFrames,
                                   encoderOutputSize / ASR::kNumFrames};
  auto encoderTensor = executorch::extension::make_tensor_ptr(
      std::move(encShape), encoderOutput.data(), ScalarType::Float);

  const auto decoderResult =
      this->decoder->forward({tokenTensor, encoderTensor});

  if (!decoderResult.ok()) {
    throw std::runtime_error(
        "Forward pass failed during decoding, error code: " +
        std::to_string(static_cast<int32_t>(decoderResult.error())));
  }

  const auto logitsTensor = decoderResult.get().at(0).toTensor();
  const int32_t outputNumel = logitsTensor.numel();

  // Logits shape is (batch=1, seqLen, dictSize); skip to the final position.
  const size_t innerDim = logitsTensor.size(1);
  const size_t dictSize = logitsTensor.size(2);

  const float *const dataPtr =
      logitsTensor.const_data_ptr<float>() + (innerDim - 1) * dictSize;

  // outputNumel / innerDim == dictSize for batch size 1.
  return {dataPtr, dataPtr + outputNumel / innerDim};
}
306
+
307
+ } // namespace rnexecutorch::models::speech_to_text::asr
@@ -0,0 +1,61 @@
1
+ #pragma once
2
+
3
+ #include "rnexecutorch/TokenizerModule.h"
4
+ #include "rnexecutorch/models/BaseModel.h"
5
+ #include "rnexecutorch/models/speech_to_text/types/DecodingOptions.h"
6
+ #include "rnexecutorch/models/speech_to_text/types/GenerationResult.h"
7
+ #include "rnexecutorch/models/speech_to_text/types/Segment.h"
8
+
9
+ namespace rnexecutorch::models::speech_to_text::asr {
10
+
11
+ using namespace types;
12
+
13
/// Whisper-style encoder/decoder ASR pipeline. Non-owning: the encoder,
/// decoder and tokenizer must outlive this object.
class ASR {
public:
  explicit ASR(const models::BaseModel *encoder,
               const models::BaseModel *decoder,
               const TokenizerModule *tokenizer);
  /// Transcribes a whole waveform into word-level timestamped segments.
  std::vector<Segment> transcribe(std::span<const float> waveform,
                                  const DecodingOptions &options) const;
  /// Runs the encoder; returns the flattened encoder hidden states.
  std::vector<float> encode(std::span<const float> waveform) const;
  /// Runs one decoder step; returns logits of the last sequence position.
  std::vector<float> decode(std::span<int32_t> tokens,
                            std::span<float> encoderOutput) const;

private:
  // Non-owning pointers supplied at construction.
  const models::BaseModel *encoder;
  const models::BaseModel *decoder;
  const TokenizerModule *tokenizer;

  // Whisper special-token ids, resolved from the tokenizer at construction.
  int32_t startOfTranscriptionToken;
  int32_t endOfTranscriptionToken;
  int32_t timestampBeginToken;

  // Time precision used by Whisper timestamps: each token spans 0.02 seconds
  constexpr static float kTimePrecision = 0.02f;
  // The maximum number of tokens the decoder can generate per chunk
  constexpr static int32_t kMaxDecodeLength = 128;
  // Maximum duration of each audio chunk to process (in seconds)
  constexpr static int32_t kChunkSize = 30;
  // Sampling rate expected by Whisper and the model's audio pipeline (16 kHz)
  constexpr static int32_t kSamplingRate = 16000;
  // Minimum allowed chunk length before processing (in audio samples)
  constexpr static int32_t kMinChunkSamples = 1 * 16000;
  // Number of mel frames output by the encoder (derived from input spectrogram)
  constexpr static int32_t kNumFrames = 1500;

  /// Builds the initial decoder prompt tokens for the given options.
  std::vector<int32_t> getInitialSequence(const DecodingOptions &options) const;
  /// Decodes one chunk at a fixed temperature.
  GenerationResult generate(std::span<const float> waveform, float temperature,
                            const DecodingOptions &options) const;
  /// Retries generation at increasing temperatures until quality gates pass.
  std::vector<Segment>
  generateWithFallback(std::span<const float> waveform,
                       const DecodingOptions &options) const;
  /// Splits tokens at Whisper timestamp tokens into word-level segments.
  std::vector<Segment>
  calculateWordLevelTimestamps(std::span<const int32_t> tokens,
                               std::span<const float> waveform) const;
  /// Spreads segment text linearly over the [start, end] timestamp interval.
  std::vector<Word>
  estimateWordLevelTimestampsLinear(std::span<const int32_t> tokens,
                                    int32_t start, int32_t end) const;
  /// text.size() / deflate(text).size(); high values flag degenerate output.
  float getCompressionRatio(const std::string &text) const;
};
60
+
61
+ } // namespace rnexecutorch::models::speech_to_text::asr
@@ -0,0 +1,80 @@
1
+ #include "HypothesisBuffer.h"
2
+
3
+ namespace rnexecutorch::models::speech_to_text::stream {
4
+
5
/// Inserts a new hypothesis (words shifted by `offset` seconds), keeping only
/// words that start after the last committed time, then strips an n-gram
/// (up to 5 words) at the head of the new words that duplicates the tail of
/// the already-committed words.
void HypothesisBuffer::insert(std::span<const Word> newWords, float offset) {
  this->fresh.clear();
  // Keep only words starting after (lastCommittedTime - 0.5s); the 0.5s
  // slack tolerates small timestamp jitter between iterations.
  for (const auto &word : newWords) {
    const float newStart = word.start + offset;
    if (newStart > lastCommittedTime - 0.5f) {
      this->fresh.emplace_back(word.content, newStart, word.end + offset);
    }
  }

  if (!this->fresh.empty() && !this->committedInBuffer.empty()) {
    const float a = this->fresh.front().start;
    // Only deduplicate when the new hypothesis starts near the committed tail.
    if (std::fabs(a - lastCommittedTime) < 1.0f) {
      const size_t cn = this->committedInBuffer.size();
      const size_t nn = this->fresh.size();
      const std::size_t maxCheck = std::min<std::size_t>({cn, nn, 5});
      // Try n-gram lengths 1..maxCheck: compare the last i committed words
      // against the first i fresh words; on a match drop the duplicate head.
      for (size_t i = 1; i <= maxCheck; i++) {
        // Space-joined last i committed words.
        std::string c;
        for (auto it = this->committedInBuffer.cend() - i;
             it != this->committedInBuffer.cend(); ++it) {
          if (!c.empty()) {
            c += ' ';
          }
          c += it->content;
        }

        // Space-joined first i fresh words.
        std::string tail;
        auto it = this->fresh.cbegin();
        for (size_t k = 0; k < i; k++, it++) {
          if (!tail.empty()) {
            tail += ' ';
          }
          tail += it->content;
        }

        if (c == tail) {
          this->fresh.erase(this->fresh.begin(), this->fresh.begin() + i);
          break;
        }
      }
    }
  }
}
47
+
48
+ std::deque<Word> HypothesisBuffer::flush() {
49
+ std::deque<Word> commit;
50
+
51
+ while (!this->fresh.empty() && !this->buffer.empty()) {
52
+ if (this->fresh.front().content != this->buffer.front().content) {
53
+ break;
54
+ }
55
+ commit.push_back(this->fresh.front());
56
+ this->buffer.pop_front();
57
+ this->fresh.pop_front();
58
+ }
59
+
60
+ if (!commit.empty()) {
61
+ lastCommittedTime = commit.back().end;
62
+ }
63
+
64
+ this->buffer = std::move(this->fresh);
65
+ this->fresh.clear();
66
+ this->committedInBuffer.insert(this->committedInBuffer.end(), commit.begin(),
67
+ commit.end());
68
+ return commit;
69
+ }
70
+
71
+ void HypothesisBuffer::popCommitted(float time) {
72
+ while (!this->committedInBuffer.empty() &&
73
+ this->committedInBuffer.front().end <= time) {
74
+ this->committedInBuffer.pop_front();
75
+ }
76
+ }
77
+
78
+ std::deque<Word> HypothesisBuffer::complete() const { return this->buffer; }
79
+
80
+ } // namespace rnexecutorch::models::speech_to_text::stream
@@ -0,0 +1,27 @@
1
+ #pragma once
2
+
3
+ #include <deque>
4
+ #include <span>
5
+
6
+ #include "rnexecutorch/models/speech_to_text/types/Word.h"
7
+
8
+ namespace rnexecutorch::models::speech_to_text::stream {
9
+
10
+ using namespace types;
11
+
12
/// Tracks agreement between successive streaming ASR hypotheses (the
/// "local agreement" policy): a word is committed once two consecutive
/// hypotheses agree on it.
class HypothesisBuffer {
public:
  /// Adds a new hypothesis (shifted by `offset` seconds), deduplicated
  /// against already committed words.
  void insert(std::span<const Word> newWords, float offset);
  /// Commits and returns the common prefix of the last two hypotheses.
  std::deque<Word> flush();
  /// Drops committed words ending at or before `time`.
  void popCommitted(float time);
  /// Returns the not-yet-committed hypothesis words.
  std::deque<Word> complete() const;

private:
  // End time (seconds) of the most recently committed word.
  float lastCommittedTime = 0.0f;

  std::deque<Word> committedInBuffer; // committed words still inside the audio buffer
  std::deque<Word> buffer;            // previous hypothesis, awaiting confirmation
  std::deque<Word> fresh;             // newest hypothesis
};
26
+
27
+ } // namespace rnexecutorch::models::speech_to_text::stream
@@ -0,0 +1,96 @@
1
+ #include <numeric>
2
+
3
+ #include "OnlineASRProcessor.h"
4
+
5
+ namespace rnexecutorch::models::speech_to_text::stream {
6
+
7
// The processor borrows `asr`; the caller must keep it alive for the
// lifetime of this object.
OnlineASRProcessor::OnlineASRProcessor(const ASR *asr) : asr(asr) {}
8
+
9
+ void OnlineASRProcessor::insertAudioChunk(std::span<const float> audio) {
10
+ audioBuffer.insert(audioBuffer.end(), audio.begin(), audio.end());
11
+ }
12
+
13
/// Runs one streaming iteration: transcribes all buffered audio, feeds the
/// hypothesis into the agreement buffer, commits agreed words, and trims the
/// audio buffer once it grows past 15 seconds.
/// @return Newly committed text and the current tentative (uncommitted) text.
ProcessResult OnlineASRProcessor::processIter(const DecodingOptions &options) {
  std::vector<Segment> res = asr->transcribe(audioBuffer, options);

  // Flatten all segment words into a single hypothesis word list.
  std::vector<Word> tsw;
  for (const auto &segment : res) {
    for (const auto &word : segment.words) {
      tsw.push_back(word);
    }
  }

  this->hypothesisBuffer.insert(tsw, this->bufferTimeOffset);
  std::deque<Word> flushed = this->hypothesisBuffer.flush();
  this->committed.insert(this->committed.end(), flushed.begin(), flushed.end());

  // Trim the audio buffer at a completed-segment boundary once it exceeds
  // 15 seconds, to bound per-iteration transcription cost.
  constexpr int32_t chunkThresholdSec = 15;
  if (static_cast<float>(audioBuffer.size()) /
          OnlineASRProcessor::kSamplingRate >
      chunkThresholdSec) {
    chunkCompletedSegment(res);
  }

  std::deque<Word> nonCommittedWords = this->hypothesisBuffer.complete();
  return {this->toFlush(flushed), this->toFlush(nonCommittedWords)};
}
37
+
38
/// Picks the latest segment boundary at or before the last committed word
/// and cuts the audio buffer there; does nothing when no boundary qualifies.
void OnlineASRProcessor::chunkCompletedSegment(std::span<const Segment> res) {
  if (this->committed.empty())
    return;

  // Buffer-relative end time of each segment's last word.
  // NOTE(review): assumes every segment has at least one word — ASR only
  // creates segments with non-empty word lists; verify if that changes.
  std::vector<float> ends(res.size());
  std::ranges::transform(res, ends.begin(), [](const Segment &seg) {
    return seg.words.back().end;
  });

  // Absolute end time of the last committed word.
  const float t = this->committed.back().end;

  if (ends.size() > 1) {
    // Walk back from the second-to-last boundary until one does not extend
    // past the committed text, keeping at least two candidate boundaries.
    float e = ends[ends.size() - 2] + this->bufferTimeOffset;
    while (ends.size() > 2 && e > t) {
      ends.pop_back();
      e = ends[ends.size() - 2] + this->bufferTimeOffset;
    }
    if (e <= t) {
      chunkAt(e);
    }
  }
}
60
+
61
+ void OnlineASRProcessor::chunkAt(float time) {
62
+ this->hypothesisBuffer.popCommitted(time);
63
+
64
+ const float cutSeconds = time - this->bufferTimeOffset;
65
+ auto startIndex =
66
+ static_cast<size_t>(cutSeconds * OnlineASRProcessor::kSamplingRate);
67
+
68
+ if (startIndex < audioBuffer.size()) {
69
+ audioBuffer.erase(audioBuffer.begin(), audioBuffer.begin() + startIndex);
70
+ } else {
71
+ audioBuffer.clear();
72
+ }
73
+
74
+ this->bufferTimeOffset = time;
75
+ }
76
+
77
/// Finalizes the stream: returns the text of all still-uncommitted words and
/// advances the time offset past the buffered audio.
/// NOTE(review): the audio buffer itself is not cleared here — confirm the
/// caller discards the processor (or its buffer) after finish().
std::string OnlineASRProcessor::finish() {
  const std::deque<Word> buffer = this->hypothesisBuffer.complete();
  std::string committedText = this->toFlush(buffer);
  this->bufferTimeOffset += static_cast<float>(audioBuffer.size()) /
                            OnlineASRProcessor::kSamplingRate;
  return committedText;
}
84
+
85
+ std::string OnlineASRProcessor::toFlush(const std::deque<Word> &words) const {
86
+ std::string text;
87
+ text.reserve(std::accumulate(
88
+ words.cbegin(), words.cend(), 0,
89
+ [](size_t sum, const Word &w) { return sum + w.content.size(); }));
90
+ for (const auto &word : words) {
91
+ text.append(word.content);
92
+ }
93
+ return text;
94
+ }
95
+
96
+ } // namespace rnexecutorch::models::speech_to_text::stream
@@ -0,0 +1,36 @@
1
+ #pragma once
2
+
3
+ #include "rnexecutorch/models/speech_to_text/asr/ASR.h"
4
+ #include "rnexecutorch/models/speech_to_text/stream/HypothesisBuffer.h"
5
+ #include "rnexecutorch/models/speech_to_text/types/ProcessResult.h"
6
+
7
+ namespace rnexecutorch::models::speech_to_text::stream {
8
+
9
+ using namespace asr;
10
+ using namespace types;
11
+
12
/// Streaming wrapper around ASR implementing the whisper_streaming
/// "local agreement" policy. Borrows the ASR instance (non-owning).
class OnlineASRProcessor {
public:
  explicit OnlineASRProcessor(const ASR *asr);

  /// Appends audio samples to the internal buffer.
  void insertAudioChunk(std::span<const float> audio);
  /// Transcribes buffered audio; returns committed and tentative text.
  ProcessResult processIter(const DecodingOptions &options);
  /// Flushes remaining uncommitted text at end of stream.
  std::string finish();

  // Buffered raw audio samples awaiting transcription.
  // NOTE(review): public member — confirm callers are meant to access it
  // directly; consider an accessor otherwise.
  std::vector<float> audioBuffer;

private:
  const ASR *asr; // non-owning
  constexpr static int32_t kSamplingRate = 16000;

  HypothesisBuffer hypothesisBuffer;
  // Stream time (seconds) corresponding to audioBuffer[0].
  float bufferTimeOffset = 0.0f;
  // All words committed so far.
  std::vector<Word> committed;

  /// Cuts the audio buffer at the latest completed-segment boundary.
  void chunkCompletedSegment(std::span<const Segment> res);
  /// Drops audio before `time` and advances bufferTimeOffset.
  void chunkAt(float time);

  /// Concatenates word contents (words carry their leading space).
  std::string toFlush(const std::deque<Word> &words) const;
};
35
+
36
+ } // namespace rnexecutorch::models::speech_to_text::stream
@@ -0,0 +1,15 @@
1
+ #pragma once
2
+
3
+ #include <optional>
4
+ #include <string>
5
+
6
+ namespace rnexecutorch::models::speech_to_text::types {
7
+
8
/// Options controlling Whisper decoding. An empty language string means
/// "not specified" and is stored as std::nullopt.
struct DecodingOptions {
  explicit DecodingOptions(const std::string &language) {
    if (!language.empty()) {
      this->language = language;
    }
  }

  // Target language code (e.g. "en"); nullopt when unspecified.
  std::optional<std::string> language;
};
14
+
15
+ } // namespace rnexecutorch::models::speech_to_text::types
@@ -0,0 +1,12 @@
1
+ #pragma once
2
+
3
+ #include <vector>
4
+
5
+ namespace rnexecutorch::models::speech_to_text::types {
6
+
7
/// Output of a single ASR::generate pass.
struct GenerationResult {
  // Generated token ids, excluding the initial prompt tokens.
  std::vector<int32_t> tokens;
  // Probability of each chosen token, parallel to `tokens`.
  std::vector<float> scores;
};
11
+
12
+ } // namespace rnexecutorch::models::speech_to_text::types
@@ -0,0 +1,12 @@
1
+ #pragma once
2
+
3
+ #include <string>
4
+
5
+ namespace rnexecutorch::models::speech_to_text::types {
6
+
7
/// Result of one streaming iteration (OnlineASRProcessor::processIter).
struct ProcessResult {
  // Text confirmed by two consecutive hypotheses; stable.
  std::string committed;
  // Tentative text that may still change in later iterations.
  std::string nonCommitted;
};
11
+
12
+ } // namespace rnexecutorch::models::speech_to_text::types
@@ -0,0 +1,14 @@
1
+ #pragma once
2
+
3
+ #include <vector>
4
+
5
+ #include "Word.h"
6
+
7
+ namespace rnexecutorch::models::speech_to_text::types {
8
+
9
/// A contiguous run of words delimited by Whisper timestamp tokens.
struct Segment {
  std::vector<Word> words;
  // NOTE(review): currently always set to 0.0 by the ASR pipeline — confirm
  // whether it is intended to be populated later.
  float noSpeechProbability;
};
13
+
14
+ } // namespace rnexecutorch::models::speech_to_text::types
@@ -0,0 +1,13 @@
1
+ #pragma once
2
+
3
+ #include <string>
4
+
5
+ namespace rnexecutorch::models::speech_to_text::types {
6
+
7
/// A single transcribed word with its time span in seconds.
/// The content includes the word's leading space.
struct Word {
  std::string content;
  float start;
  float end;
};
12
+
13
+ } // namespace rnexecutorch::models::speech_to_text::types