react-native-executorch 0.5.3 → 0.5.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. package/android/src/main/cpp/CMakeLists.txt +2 -1
  2. package/common/rnexecutorch/data_processing/Numerical.cpp +27 -19
  3. package/common/rnexecutorch/data_processing/Numerical.h +53 -4
  4. package/common/rnexecutorch/data_processing/dsp.cpp +1 -1
  5. package/common/rnexecutorch/data_processing/dsp.h +1 -1
  6. package/common/rnexecutorch/data_processing/gzip.cpp +47 -0
  7. package/common/rnexecutorch/data_processing/gzip.h +7 -0
  8. package/common/rnexecutorch/host_objects/ModelHostObject.h +24 -0
  9. package/common/rnexecutorch/metaprogramming/TypeConcepts.h +21 -1
  10. package/common/rnexecutorch/models/BaseModel.cpp +3 -2
  11. package/common/rnexecutorch/models/BaseModel.h +3 -2
  12. package/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp +103 -39
  13. package/common/rnexecutorch/models/speech_to_text/SpeechToText.h +39 -21
  14. package/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp +310 -0
  15. package/common/rnexecutorch/models/speech_to_text/asr/ASR.h +62 -0
  16. package/common/rnexecutorch/models/speech_to_text/stream/HypothesisBuffer.cpp +82 -0
  17. package/common/rnexecutorch/models/speech_to_text/stream/HypothesisBuffer.h +25 -0
  18. package/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.cpp +99 -0
  19. package/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.h +33 -0
  20. package/common/rnexecutorch/models/speech_to_text/types/DecodingOptions.h +15 -0
  21. package/common/rnexecutorch/models/speech_to_text/types/GenerationResult.h +12 -0
  22. package/common/rnexecutorch/models/speech_to_text/types/ProcessResult.h +12 -0
  23. package/common/rnexecutorch/models/speech_to_text/types/Segment.h +14 -0
  24. package/common/rnexecutorch/models/speech_to_text/types/Word.h +13 -0
  25. package/lib/module/modules/natural_language_processing/SpeechToTextModule.js +75 -53
  26. package/lib/module/modules/natural_language_processing/SpeechToTextModule.js.map +1 -1
  27. package/lib/typescript/hooks/natural_language_processing/useSpeechToText.d.ts +5 -5
  28. package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts +7 -12
  29. package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts.map +1 -1
  30. package/lib/typescript/types/stt.d.ts +0 -9
  31. package/lib/typescript/types/stt.d.ts.map +1 -1
  32. package/package.json +1 -1
  33. package/react-native-executorch.podspec +2 -0
  34. package/src/modules/natural_language_processing/SpeechToTextModule.ts +118 -54
  35. package/src/types/stt.ts +0 -12
  36. package/common/rnexecutorch/models/EncoderDecoderBase.cpp +0 -21
  37. package/common/rnexecutorch/models/EncoderDecoderBase.h +0 -31
  38. package/common/rnexecutorch/models/speech_to_text/SpeechToTextStrategy.h +0 -27
  39. package/common/rnexecutorch/models/speech_to_text/WhisperStrategy.cpp +0 -50
  40. package/common/rnexecutorch/models/speech_to_text/WhisperStrategy.h +0 -25
  41. package/lib/Error.js +0 -53
  42. package/lib/ThreadPool.d.ts +0 -10
  43. package/lib/ThreadPool.js +0 -28
  44. package/lib/common/Logger.d.ts +0 -8
  45. package/lib/common/Logger.js +0 -19
  46. package/lib/constants/directories.js +0 -2
  47. package/lib/constants/llmDefaults.d.ts +0 -6
  48. package/lib/constants/llmDefaults.js +0 -16
  49. package/lib/constants/modelUrls.d.ts +0 -223
  50. package/lib/constants/modelUrls.js +0 -322
  51. package/lib/constants/ocr/models.d.ts +0 -882
  52. package/lib/constants/ocr/models.js +0 -182
  53. package/lib/constants/ocr/symbols.js +0 -139
  54. package/lib/constants/sttDefaults.d.ts +0 -28
  55. package/lib/constants/sttDefaults.js +0 -68
  56. package/lib/controllers/LLMController.d.ts +0 -47
  57. package/lib/controllers/LLMController.js +0 -213
  58. package/lib/controllers/OCRController.js +0 -67
  59. package/lib/controllers/SpeechToTextController.d.ts +0 -56
  60. package/lib/controllers/SpeechToTextController.js +0 -349
  61. package/lib/controllers/VerticalOCRController.js +0 -70
  62. package/lib/hooks/computer_vision/useClassification.d.ts +0 -15
  63. package/lib/hooks/computer_vision/useClassification.js +0 -7
  64. package/lib/hooks/computer_vision/useImageEmbeddings.d.ts +0 -15
  65. package/lib/hooks/computer_vision/useImageEmbeddings.js +0 -7
  66. package/lib/hooks/computer_vision/useImageSegmentation.d.ts +0 -38
  67. package/lib/hooks/computer_vision/useImageSegmentation.js +0 -7
  68. package/lib/hooks/computer_vision/useOCR.d.ts +0 -20
  69. package/lib/hooks/computer_vision/useOCR.js +0 -41
  70. package/lib/hooks/computer_vision/useObjectDetection.d.ts +0 -15
  71. package/lib/hooks/computer_vision/useObjectDetection.js +0 -7
  72. package/lib/hooks/computer_vision/useStyleTransfer.d.ts +0 -15
  73. package/lib/hooks/computer_vision/useStyleTransfer.js +0 -7
  74. package/lib/hooks/computer_vision/useVerticalOCR.d.ts +0 -21
  75. package/lib/hooks/computer_vision/useVerticalOCR.js +0 -43
  76. package/lib/hooks/general/useExecutorchModule.d.ts +0 -13
  77. package/lib/hooks/general/useExecutorchModule.js +0 -7
  78. package/lib/hooks/natural_language_processing/useLLM.d.ts +0 -10
  79. package/lib/hooks/natural_language_processing/useLLM.js +0 -78
  80. package/lib/hooks/natural_language_processing/useSpeechToText.d.ts +0 -27
  81. package/lib/hooks/natural_language_processing/useSpeechToText.js +0 -49
  82. package/lib/hooks/natural_language_processing/useTextEmbeddings.d.ts +0 -16
  83. package/lib/hooks/natural_language_processing/useTextEmbeddings.js +0 -7
  84. package/lib/hooks/natural_language_processing/useTokenizer.d.ts +0 -17
  85. package/lib/hooks/natural_language_processing/useTokenizer.js +0 -52
  86. package/lib/hooks/useModule.js +0 -45
  87. package/lib/hooks/useNonStaticModule.d.ts +0 -20
  88. package/lib/hooks/useNonStaticModule.js +0 -49
  89. package/lib/index.d.ts +0 -48
  90. package/lib/index.js +0 -58
  91. package/lib/module/utils/SpeechToTextModule/ASR.js +0 -191
  92. package/lib/module/utils/SpeechToTextModule/ASR.js.map +0 -1
  93. package/lib/module/utils/SpeechToTextModule/OnlineProcessor.js +0 -73
  94. package/lib/module/utils/SpeechToTextModule/OnlineProcessor.js.map +0 -1
  95. package/lib/module/utils/SpeechToTextModule/hypothesisBuffer.js +0 -56
  96. package/lib/module/utils/SpeechToTextModule/hypothesisBuffer.js.map +0 -1
  97. package/lib/module/utils/stt.js +0 -22
  98. package/lib/module/utils/stt.js.map +0 -1
  99. package/lib/modules/BaseModule.js +0 -25
  100. package/lib/modules/BaseNonStaticModule.js +0 -14
  101. package/lib/modules/computer_vision/ClassificationModule.d.ts +0 -8
  102. package/lib/modules/computer_vision/ClassificationModule.js +0 -17
  103. package/lib/modules/computer_vision/ImageEmbeddingsModule.d.ts +0 -8
  104. package/lib/modules/computer_vision/ImageEmbeddingsModule.js +0 -17
  105. package/lib/modules/computer_vision/ImageSegmentationModule.d.ts +0 -11
  106. package/lib/modules/computer_vision/ImageSegmentationModule.js +0 -27
  107. package/lib/modules/computer_vision/OCRModule.d.ts +0 -14
  108. package/lib/modules/computer_vision/OCRModule.js +0 -17
  109. package/lib/modules/computer_vision/ObjectDetectionModule.d.ts +0 -9
  110. package/lib/modules/computer_vision/ObjectDetectionModule.js +0 -17
  111. package/lib/modules/computer_vision/StyleTransferModule.d.ts +0 -8
  112. package/lib/modules/computer_vision/StyleTransferModule.js +0 -17
  113. package/lib/modules/computer_vision/VerticalOCRModule.d.ts +0 -14
  114. package/lib/modules/computer_vision/VerticalOCRModule.js +0 -19
  115. package/lib/modules/general/ExecutorchModule.d.ts +0 -7
  116. package/lib/modules/general/ExecutorchModule.js +0 -14
  117. package/lib/modules/natural_language_processing/LLMModule.d.ts +0 -28
  118. package/lib/modules/natural_language_processing/LLMModule.js +0 -45
  119. package/lib/modules/natural_language_processing/SpeechToTextModule.d.ts +0 -24
  120. package/lib/modules/natural_language_processing/SpeechToTextModule.js +0 -36
  121. package/lib/modules/natural_language_processing/TextEmbeddingsModule.d.ts +0 -9
  122. package/lib/modules/natural_language_processing/TextEmbeddingsModule.js +0 -21
  123. package/lib/modules/natural_language_processing/TokenizerModule.d.ts +0 -12
  124. package/lib/modules/natural_language_processing/TokenizerModule.js +0 -30
  125. package/lib/native/NativeETInstaller.js +0 -2
  126. package/lib/native/NativeOCR.js +0 -2
  127. package/lib/native/NativeVerticalOCR.js +0 -2
  128. package/lib/native/RnExecutorchModules.d.ts +0 -7
  129. package/lib/native/RnExecutorchModules.js +0 -18
  130. package/lib/tsconfig.tsbuildinfo +0 -1
  131. package/lib/types/common.d.ts +0 -32
  132. package/lib/types/common.js +0 -25
  133. package/lib/types/imageSegmentation.js +0 -26
  134. package/lib/types/llm.d.ts +0 -46
  135. package/lib/types/llm.js +0 -9
  136. package/lib/types/objectDetection.js +0 -94
  137. package/lib/types/ocr.js +0 -1
  138. package/lib/types/stt.d.ts +0 -94
  139. package/lib/types/stt.js +0 -85
  140. package/lib/typescript/utils/SpeechToTextModule/ASR.d.ts +0 -27
  141. package/lib/typescript/utils/SpeechToTextModule/ASR.d.ts.map +0 -1
  142. package/lib/typescript/utils/SpeechToTextModule/OnlineProcessor.d.ts +0 -23
  143. package/lib/typescript/utils/SpeechToTextModule/OnlineProcessor.d.ts.map +0 -1
  144. package/lib/typescript/utils/SpeechToTextModule/hypothesisBuffer.d.ts +0 -13
  145. package/lib/typescript/utils/SpeechToTextModule/hypothesisBuffer.d.ts.map +0 -1
  146. package/lib/typescript/utils/stt.d.ts +0 -2
  147. package/lib/typescript/utils/stt.d.ts.map +0 -1
  148. package/lib/utils/ResourceFetcher.d.ts +0 -24
  149. package/lib/utils/ResourceFetcher.js +0 -305
  150. package/lib/utils/ResourceFetcherUtils.d.ts +0 -54
  151. package/lib/utils/ResourceFetcherUtils.js +0 -127
  152. package/lib/utils/llm.d.ts +0 -6
  153. package/lib/utils/llm.js +0 -72
  154. package/lib/utils/stt.js +0 -21
  155. package/src/utils/SpeechToTextModule/ASR.ts +0 -303
  156. package/src/utils/SpeechToTextModule/OnlineProcessor.ts +0 -87
  157. package/src/utils/SpeechToTextModule/hypothesisBuffer.ts +0 -79
  158. package/src/utils/stt.ts +0 -28
package/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp
@@ -0,0 +1,310 @@
+#include <random>
+#include <sstream>
+
+#include "ASR.h"
+#include "executorch/extension/tensor/tensor_ptr.h"
+#include "rnexecutorch/data_processing/Numerical.h"
+#include "rnexecutorch/data_processing/dsp.h"
+#include "rnexecutorch/data_processing/gzip.h"
+
+namespace rnexecutorch::models::speech_to_text::asr {
+
+using namespace types;
+
+ASR::ASR(const models::BaseModel *encoder, const models::BaseModel *decoder,
+         const TokenizerModule *tokenizer)
+    : encoder(encoder), decoder(decoder), tokenizer(tokenizer),
+      startOfTranscriptionToken(
+          this->tokenizer->tokenToId("<|startoftranscript|>")),
+      endOfTranscriptionToken(this->tokenizer->tokenToId("<|endoftext|>")),
+      timestampBeginToken(this->tokenizer->tokenToId("<|0.00|>")) {}
+
+std::vector<int32_t>
+ASR::getInitialSequence(const DecodingOptions &options) const {
+  std::vector<int32_t> seq;
+  seq.push_back(this->startOfTranscriptionToken);
+
+  if (options.language.has_value()) {
+    int32_t langToken =
+        this->tokenizer->tokenToId("<|" + options.language.value() + "|>");
+    int32_t taskToken = this->tokenizer->tokenToId("<|transcribe|>");
+    seq.push_back(langToken);
+    seq.push_back(taskToken);
+  }
+
+  seq.push_back(this->timestampBeginToken);
+
+  return seq;
+}
+
+GenerationResult ASR::generate(std::span<const float> waveform,
+                               float temperature,
+                               const DecodingOptions &options) const {
+  std::vector<float> encoderOutput = this->encode(waveform);
+
+  std::vector<int32_t> sequenceIds = this->getInitialSequence(options);
+  const size_t initialSequenceLenght = sequenceIds.size();
+  std::vector<float> scores;
+
+  while (std::cmp_less_equal(sequenceIds.size(), ASR::kMaxDecodeLength)) {
+    std::vector<float> logits = this->decode(sequenceIds, encoderOutput);
+
+    // intentionally comparing float to float
+    // temperatures are predefined, so this is safe
+    if (temperature == 0.0f) {
+      numerical::softmax(logits);
+    } else {
+      numerical::softmaxWithTemperature(logits, temperature);
+    }
+
+    const std::vector<float> &probs = logits;
+
+    int32_t nextId;
+    float nextProb;
+
+    // intentionally comparing float to float
+    // temperatures are predefined, so this is safe
+    if (temperature == 0.0f) {
+      auto maxIt = std::ranges::max_element(probs);
+      nextId = static_cast<int32_t>(std::distance(probs.begin(), maxIt));
+      nextProb = *maxIt;
+    } else {
+      std::discrete_distribution<> dist(probs.begin(), probs.end());
+      std::mt19937 gen((std::random_device{}()));
+      nextId = dist(gen);
+      nextProb = probs[nextId];
+    }
+
+    sequenceIds.push_back(nextId);
+    scores.push_back(nextProb);
+
+    if (nextId == this->endOfTranscriptionToken) {
+      break;
+    }
+  }
+
+  return {.tokens = std::vector<int32_t>(
+              sequenceIds.cbegin() + initialSequenceLenght, sequenceIds.cend()),
+          .scores = scores};
+}
+
+float ASR::getCompressionRatio(const std::string &text) const {
+  size_t compressedSize = gzip::deflateSize(text);
+  return static_cast<float>(text.size()) / static_cast<float>(compressedSize);
+}
+
+std::vector<Segment>
+ASR::generateWithFallback(std::span<const float> waveform,
+                          const DecodingOptions &options) const {
+  std::vector<float> temperatures = {0.0f, 0.2f, 0.4f, 0.6f, 0.8f, 1.0f};
+  std::vector<int32_t> bestTokens;
+
+  for (auto t : temperatures) {
+    auto [tokens, scores] = this->generate(waveform, t, options);
+
+    const float cumLogProb = std::transform_reduce(
+        scores.begin(), scores.end(), 0.0f, std::plus<>(),
+        [](float s) { return std::log(std::max(s, 1e-9f)); });
+
+    const float avgLogProb = cumLogProb / static_cast<float>(tokens.size() + 1);
+    const std::string text = this->tokenizer->decode(tokens, true);
+    const float compressionRatio = this->getCompressionRatio(text);
+
+    if (avgLogProb >= -1.0f && compressionRatio < 2.4f) {
+      bestTokens = std::move(tokens);
+      break;
+    }
+  }
+
+  return this->calculateWordLevelTimestamps(bestTokens, waveform);
+}
+
+std::vector<Segment>
+ASR::calculateWordLevelTimestamps(std::span<const int32_t> generatedTokens,
+                                  const std::span<const float> waveform) const {
+  const size_t generatedTokensSize = generatedTokens.size();
+  if (generatedTokensSize < 2 ||
+      generatedTokens[generatedTokensSize - 1] !=
+          this->endOfTranscriptionToken ||
+      generatedTokens[generatedTokensSize - 2] < this->timestampBeginToken) {
+    return {};
+  }
+  std::vector<Segment> segments;
+  std::vector<int32_t> tokens;
+  int32_t prevTimestamp = this->timestampBeginToken;
+
+  for (size_t i = 0; i < generatedTokensSize; i++) {
+    if (generatedTokens[i] < this->timestampBeginToken) {
+      tokens.push_back(generatedTokens[i]);
+    }
+    if (i > 0 && generatedTokens[i - 1] >= this->timestampBeginToken &&
+        generatedTokens[i] >= this->timestampBeginToken) {
+      const int32_t start = prevTimestamp;
+      const int32_t end = generatedTokens[i - 1];
+      auto words = this->estimateWordLevelTimestampsLinear(tokens, start, end);
+      if (words.size()) {
+        segments.emplace_back(std::move(words), 0.0);
+      }
+      tokens.clear();
+      prevTimestamp = generatedTokens[i];
+    }
+  }
+
+  const int32_t start = prevTimestamp;
+  const int32_t end = generatedTokens[generatedTokensSize - 2];
+  auto words = this->estimateWordLevelTimestampsLinear(tokens, start, end);
+
+  if (words.size()) {
+    segments.emplace_back(std::move(words), 0.0);
+  }
+
+  float scalingFactor =
+      static_cast<float>(waveform.size()) /
+      (ASR::kSamplingRate * (end - this->timestampBeginToken) *
+       ASR::kTimePrecision);
+  if (scalingFactor < 1.0f) {
+    for (auto &seg : segments) {
+      for (auto &w : seg.words) {
+        w.start *= scalingFactor;
+        w.end *= scalingFactor;
+      }
+    }
+  }
+
+  return segments;
+}
+
+std::vector<Word>
+ASR::estimateWordLevelTimestampsLinear(std::span<const int32_t> tokens,
+                                       int32_t start, int32_t end) const {
+  const std::vector<int32_t> tokensVec(tokens.begin(), tokens.end());
+  const std::string segmentText = this->tokenizer->decode(tokensVec, true);
+  std::istringstream iss(segmentText);
+  std::vector<std::string> wordsStr;
+  std::string word;
+  while (iss >> word) {
+    wordsStr.emplace_back(" ");
+    wordsStr.back().append(word);
+  }
+
+  size_t numChars = 0;
+  for (const auto &w : wordsStr) {
+    numChars += w.size();
+  }
+  const float duration = (end - start) * ASR::kTimePrecision;
+  const float timePerChar = duration / std::max<float>(1, numChars);
+  const float startOffset = (start - timestampBeginToken) * ASR::kTimePrecision;
+
+  std::vector<Word> wordObjs;
+  wordObjs.reserve(wordsStr.size());
+  int32_t prevCharCount = 0;
+  for (auto &w : wordsStr) {
+    const auto wSize = static_cast<int32_t>(w.size());
+    const float wStart = startOffset + prevCharCount * timePerChar;
+    const float wEnd = wStart + timePerChar * wSize;
+    prevCharCount += wSize;
+    wordObjs.emplace_back(std::move(w), wStart, wEnd);
+  }
+
+  return wordObjs;
+}
+
+std::vector<Segment> ASR::transcribe(std::span<const float> waveform,
+                                     const DecodingOptions &options) const {
+  int32_t seek = 0;
+  std::vector<Segment> results;
+
+  while (std::cmp_less(seek * ASR::kSamplingRate, waveform.size())) {
+    int32_t start = seek * ASR::kSamplingRate;
+    const auto end = std::min<int32_t>(
+        (seek + ASR::kChunkSize) * ASR::kSamplingRate, waveform.size());
+    std::span<const float> chunk = waveform.subspan(start, end - start);
+
+    if (std::cmp_less(chunk.size(), ASR::kMinChunkSamples)) {
+      break;
+    }
+
+    std::vector<Segment> segments = this->generateWithFallback(chunk, options);
+
+    if (segments.empty()) {
+      seek += ASR::kChunkSize;
+      continue;
+    }
+
+    for (auto &seg : segments) {
+      for (auto &w : seg.words) {
+        w.start += seek;
+        w.end += seek;
+      }
+    }
+
+    seek = static_cast<int32_t>(segments.back().words.back().end);
+    results.insert(results.end(), std::make_move_iterator(segments.begin()),
+                   std::make_move_iterator(segments.end()));
+  }
+
+  return results;
+}
+
+std::vector<float> ASR::encode(std::span<const float> waveform) const {
+  constexpr int32_t fftWindowSize = 512;
+  constexpr int32_t stftHopLength = 160;
+  constexpr int32_t innerDim = 256;
+
+  std::vector<float> preprocessedData =
+      dsp::stftFromWaveform(waveform, fftWindowSize, stftHopLength);
+  const auto numFrames =
+      static_cast<int32_t>(preprocessedData.size()) / innerDim;
+  std::vector<int32_t> inputShape = {numFrames, innerDim};
+
+  const auto modelInputTensor = executorch::extension::make_tensor_ptr(
+      std::move(inputShape), std::move(preprocessedData));
+  const auto encoderResult = this->encoder->forward(modelInputTensor);
+
+  if (!encoderResult.ok()) {
+    throw std::runtime_error(
+        "Forward pass failed during encoding, error code: " +
+        std::to_string(static_cast<int32_t>(encoderResult.error())));
+  }
+
+  const auto decoderOutputTensor = encoderResult.get().at(0).toTensor();
+  const int32_t outputNumel = decoderOutputTensor.numel();
+
+  const float *const dataPtr = decoderOutputTensor.const_data_ptr<float>();
+  return {dataPtr, dataPtr + outputNumel};
+}
+
+std::vector<float> ASR::decode(std::span<int32_t> tokens,
+                               std::span<float> encoderOutput) const {
+  std::vector<int32_t> tokenShape = {1, static_cast<int32_t>(tokens.size())};
+  auto tokenTensor = executorch::extension::make_tensor_ptr(
+      std::move(tokenShape), tokens.data(), ScalarType::Int);
+
+  const auto encoderOutputSize = static_cast<int32_t>(encoderOutput.size());
+  std::vector<int32_t> encShape = {1, ASR::kNumFrames,
+                                   encoderOutputSize / ASR::kNumFrames};
+  auto encoderTensor = executorch::extension::make_tensor_ptr(
+      std::move(encShape), encoderOutput.data(), ScalarType::Float);
+
+  const auto decoderResult =
+      this->decoder->forward({tokenTensor, encoderTensor});
+
+  if (!decoderResult.ok()) {
+    throw std::runtime_error(
+        "Forward pass failed during decoding, error code: " +
+        std::to_string(static_cast<int32_t>(decoderResult.error())));
+  }
+
+  const auto logitsTensor = decoderResult.get().at(0).toTensor();
+  const int32_t outputNumel = logitsTensor.numel();
+
+  const size_t innerDim = logitsTensor.size(1);
+  const size_t dictSize = logitsTensor.size(2);
+
+  const float *const dataPtr =
+      logitsTensor.const_data_ptr<float>() + (innerDim - 1) * dictSize;
+
+  return {dataPtr, dataPtr + outputNumel / innerDim};
+}
+
+} // namespace rnexecutorch::models::speech_to_text::asr
package/common/rnexecutorch/models/speech_to_text/asr/ASR.h
@@ -0,0 +1,62 @@
+#pragma once
+
+#include "rnexecutorch/TokenizerModule.h"
+#include "rnexecutorch/models/BaseModel.h"
+#include "rnexecutorch/models/speech_to_text/types/DecodingOptions.h"
+#include "rnexecutorch/models/speech_to_text/types/GenerationResult.h"
+#include "rnexecutorch/models/speech_to_text/types/Segment.h"
+
+namespace rnexecutorch::models::speech_to_text::asr {
+
+class ASR {
+public:
+  explicit ASR(const models::BaseModel *encoder,
+               const models::BaseModel *decoder,
+               const TokenizerModule *tokenizer);
+  std::vector<types::Segment>
+  transcribe(std::span<const float> waveform,
+             const types::DecodingOptions &options) const;
+  std::vector<float> encode(std::span<const float> waveform) const;
+  std::vector<float> decode(std::span<int32_t> tokens,
+                            std::span<float> encoderOutput) const;
+
+private:
+  const models::BaseModel *encoder;
+  const models::BaseModel *decoder;
+  const TokenizerModule *tokenizer;
+
+  int32_t startOfTranscriptionToken;
+  int32_t endOfTranscriptionToken;
+  int32_t timestampBeginToken;
+
+  // Time precision used by Whisper timestamps: each token spans 0.02 seconds
+  constexpr static float kTimePrecision = 0.02f;
+  // The maximum number of tokens the decoder can generate per chunk
+  constexpr static int32_t kMaxDecodeLength = 128;
+  // Maximum duration of each audio chunk to process (in seconds)
+  constexpr static int32_t kChunkSize = 30;
+  // Sampling rate expected by Whisper and the model's audio pipeline (16 kHz)
+  constexpr static int32_t kSamplingRate = 16000;
+  // Minimum allowed chunk length before processing (in audio samples)
+  constexpr static int32_t kMinChunkSamples = 1 * 16000;
+  // Number of mel frames output by the encoder (derived from input spectrogram)
+  constexpr static int32_t kNumFrames = 1500;
+
+  std::vector<int32_t>
+  getInitialSequence(const types::DecodingOptions &options) const;
+  types::GenerationResult generate(std::span<const float> waveform,
+                                   float temperature,
+                                   const types::DecodingOptions &options) const;
+  std::vector<types::Segment>
+  generateWithFallback(std::span<const float> waveform,
+                       const types::DecodingOptions &options) const;
+  std::vector<types::Segment>
+  calculateWordLevelTimestamps(std::span<const int32_t> tokens,
+                               std::span<const float> waveform) const;
+  std::vector<types::Word>
+  estimateWordLevelTimestampsLinear(std::span<const int32_t> tokens,
+                                    int32_t start, int32_t end) const;
+  float getCompressionRatio(const std::string &text) const;
+};
+
+} // namespace rnexecutorch::models::speech_to_text::asr
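
For orientation, a minimal sketch of how this ASR API could be driven, assuming the caller already owns loaded encoder and decoder BaseModel instances and a TokenizerModule; everything named below that is not declared in the header above (encoderPtr, decoderPtr, tokenizerPtr, loadPcm16kMono) is hypothetical:

    #include "rnexecutorch/models/speech_to_text/asr/ASR.h"

    using namespace rnexecutorch;
    using namespace rnexecutorch::models::speech_to_text;

    // encoderPtr, decoderPtr and tokenizerPtr are assumed to be valid,
    // already-loaded instances owned by the caller.
    asr::ASR whisper(encoderPtr, decoderPtr, tokenizerPtr);

    // 16 kHz mono float PCM; loadPcm16kMono() is a placeholder for the real source.
    std::vector<float> waveform = loadPcm16kMono();

    // An empty language string leaves the language/task tokens out of the
    // initial decoder sequence (see DecodingOptions and getInitialSequence).
    types::DecodingOptions options("en");
    std::vector<types::Segment> segments = whisper.transcribe(waveform, options);

transcribe() walks the waveform in 30-second windows, retries decoding at increasing temperatures when the average log-probability or gzip compression ratio looks degenerate, and returns segments with word-level timestamps.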
package/common/rnexecutorch/models/speech_to_text/stream/HypothesisBuffer.cpp
@@ -0,0 +1,82 @@
+#include "HypothesisBuffer.h"
+
+namespace rnexecutorch::models::speech_to_text::stream {
+
+using namespace types;
+
+void HypothesisBuffer::insert(std::span<const Word> newWords, float offset) {
+  this->fresh.clear();
+  for (const auto &word : newWords) {
+    const float newStart = word.start + offset;
+    if (newStart > lastCommittedTime - 0.5f) {
+      this->fresh.emplace_back(word.content, newStart, word.end + offset);
+    }
+  }
+
+  if (!this->fresh.empty() && !this->committedInBuffer.empty()) {
+    const float a = this->fresh.front().start;
+    if (std::fabs(a - lastCommittedTime) < 1.0f) {
+      const size_t cn = this->committedInBuffer.size();
+      const size_t nn = this->fresh.size();
+      const std::size_t maxCheck = std::min<std::size_t>({cn, nn, 5});
+      for (size_t i = 1; i <= maxCheck; i++) {
+        std::string c;
+        for (auto it = this->committedInBuffer.cend() - i;
+             it != this->committedInBuffer.cend(); ++it) {
+          if (!c.empty()) {
+            c += ' ';
+          }
+          c += it->content;
+        }
+
+        std::string tail;
+        auto it = this->fresh.cbegin();
+        for (size_t k = 0; k < i; k++, it++) {
+          if (!tail.empty()) {
+            tail += ' ';
+          }
+          tail += it->content;
+        }
+
+        if (c == tail) {
+          this->fresh.erase(this->fresh.begin(), this->fresh.begin() + i);
+          break;
+        }
+      }
+    }
+  }
+}
+
+std::deque<Word> HypothesisBuffer::flush() {
+  std::deque<Word> commit;
+
+  while (!this->fresh.empty() && !this->buffer.empty()) {
+    if (this->fresh.front().content != this->buffer.front().content) {
+      break;
+    }
+    commit.push_back(this->fresh.front());
+    this->buffer.pop_front();
+    this->fresh.pop_front();
+  }
+
+  if (!commit.empty()) {
+    lastCommittedTime = commit.back().end;
+  }
+
+  this->buffer = std::move(this->fresh);
+  this->fresh.clear();
+  this->committedInBuffer.insert(this->committedInBuffer.end(), commit.begin(),
+                                 commit.end());
+  return commit;
+}
+
+void HypothesisBuffer::popCommitted(float time) {
+  while (!this->committedInBuffer.empty() &&
+         this->committedInBuffer.front().end <= time) {
+    this->committedInBuffer.pop_front();
+  }
+}
+
+std::deque<Word> HypothesisBuffer::complete() const { return this->buffer; }
+
+} // namespace rnexecutorch::models::speech_to_text::stream
package/common/rnexecutorch/models/speech_to_text/stream/HypothesisBuffer.h
@@ -0,0 +1,25 @@
+#pragma once
+
+#include <deque>
+#include <span>
+
+#include "rnexecutorch/models/speech_to_text/types/Word.h"
+
+namespace rnexecutorch::models::speech_to_text::stream {
+
+class HypothesisBuffer {
+public:
+  void insert(std::span<const types::Word> newWords, float offset);
+  std::deque<types::Word> flush();
+  void popCommitted(float time);
+  std::deque<types::Word> complete() const;
+
+private:
+  float lastCommittedTime = 0.0f;
+
+  std::deque<types::Word> committedInBuffer;
+  std::deque<types::Word> buffer;
+  std::deque<types::Word> fresh;
+};
+
+} // namespace rnexecutorch::models::speech_to_text::stream
package/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.cpp
@@ -0,0 +1,99 @@
+#include <numeric>
+
+#include "OnlineASRProcessor.h"
+
+namespace rnexecutorch::models::speech_to_text::stream {
+
+using namespace asr;
+using namespace types;
+
+OnlineASRProcessor::OnlineASRProcessor(const ASR *asr) : asr(asr) {}
+
+void OnlineASRProcessor::insertAudioChunk(std::span<const float> audio) {
+  audioBuffer.insert(audioBuffer.end(), audio.begin(), audio.end());
+}
+
+ProcessResult OnlineASRProcessor::processIter(const DecodingOptions &options) {
+  std::vector<Segment> res = asr->transcribe(audioBuffer, options);
+
+  std::vector<Word> tsw;
+  for (const auto &segment : res) {
+    for (const auto &word : segment.words) {
+      tsw.push_back(word);
+    }
+  }
+
+  this->hypothesisBuffer.insert(tsw, this->bufferTimeOffset);
+  std::deque<Word> flushed = this->hypothesisBuffer.flush();
+  this->committed.insert(this->committed.end(), flushed.begin(), flushed.end());
+
+  constexpr int32_t chunkThresholdSec = 15;
+  if (static_cast<float>(audioBuffer.size()) /
+          OnlineASRProcessor::kSamplingRate >
+      chunkThresholdSec) {
+    chunkCompletedSegment(res);
+  }
+
+  std::deque<Word> nonCommittedWords = this->hypothesisBuffer.complete();
+  return {this->toFlush(flushed), this->toFlush(nonCommittedWords)};
+}
+
+void OnlineASRProcessor::chunkCompletedSegment(std::span<const Segment> res) {
+  if (this->committed.empty())
+    return;
+
+  std::vector<float> ends(res.size());
+  std::ranges::transform(res, ends.begin(), [](const Segment &seg) {
+    return seg.words.back().end;
+  });
+
+  const float t = this->committed.back().end;
+
+  if (ends.size() > 1) {
+    float e = ends[ends.size() - 2] + this->bufferTimeOffset;
+    while (ends.size() > 2 && e > t) {
+      ends.pop_back();
+      e = ends[ends.size() - 2] + this->bufferTimeOffset;
+    }
+    if (e <= t) {
+      chunkAt(e);
+    }
+  }
+}
+
+void OnlineASRProcessor::chunkAt(float time) {
+  this->hypothesisBuffer.popCommitted(time);
+
+  const float cutSeconds = time - this->bufferTimeOffset;
+  auto startIndex =
+      static_cast<size_t>(cutSeconds * OnlineASRProcessor::kSamplingRate);
+
+  if (startIndex < audioBuffer.size()) {
+    audioBuffer.erase(audioBuffer.begin(), audioBuffer.begin() + startIndex);
+  } else {
+    audioBuffer.clear();
+  }
+
+  this->bufferTimeOffset = time;
+}
+
+std::string OnlineASRProcessor::finish() {
+  const std::deque<Word> buffer = this->hypothesisBuffer.complete();
+  std::string committedText = this->toFlush(buffer);
+  this->bufferTimeOffset += static_cast<float>(audioBuffer.size()) /
+                            OnlineASRProcessor::kSamplingRate;
+  return committedText;
+}
+
+std::string OnlineASRProcessor::toFlush(const std::deque<Word> &words) const {
+  std::string text;
+  text.reserve(std::accumulate(
+      words.cbegin(), words.cend(), 0,
+      [](size_t sum, const Word &w) { return sum + w.content.size(); }));
+  for (const auto &word : words) {
+    text.append(word.content);
+  }
+  return text;
+}
+
+} // namespace rnexecutorch::models::speech_to_text::stream
package/common/rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.h
@@ -0,0 +1,33 @@
+#pragma once
+
+#include "rnexecutorch/models/speech_to_text/asr/ASR.h"
+#include "rnexecutorch/models/speech_to_text/stream/HypothesisBuffer.h"
+#include "rnexecutorch/models/speech_to_text/types/ProcessResult.h"
+
+namespace rnexecutorch::models::speech_to_text::stream {
+
+class OnlineASRProcessor {
+public:
+  explicit OnlineASRProcessor(const asr::ASR *asr);
+
+  void insertAudioChunk(std::span<const float> audio);
+  types::ProcessResult processIter(const types::DecodingOptions &options);
+  std::string finish();
+
+  std::vector<float> audioBuffer;
+
+private:
+  const asr::ASR *asr;
+  constexpr static int32_t kSamplingRate = 16000;
+
+  HypothesisBuffer hypothesisBuffer;
+  float bufferTimeOffset = 0.0f;
+  std::vector<types::Word> committed;
+
+  void chunkCompletedSegment(std::span<const types::Segment> res);
+  void chunkAt(float time);
+
+  std::string toFlush(const std::deque<types::Word> &words) const;
+};
+
+} // namespace rnexecutorch::models::speech_to_text::stream
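
A hedged sketch of the streaming path built on this class, assuming the whisper instance from the earlier sketch and an audio source delivering 16 kHz mono float chunks; hasAudio() and nextChunk() are placeholders, not part of the package:

    #include "rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.h"

    using namespace rnexecutorch::models::speech_to_text;

    stream::OnlineASRProcessor processor(&whisper);
    types::DecodingOptions options("en");

    while (hasAudio()) {                       // placeholder for the real capture loop
      std::vector<float> chunk = nextChunk();  // 16 kHz mono float PCM
      processor.insertAudioChunk(chunk);
      types::ProcessResult result = processor.processIter(options);
      // result.committed: text on which two consecutive hypotheses agreed (stable).
      // result.nonCommitted: trailing hypothesis that may still change.
    }
    std::string tail = processor.finish();     // flush whatever never got committed

Each processIter() call re-transcribes the growing audio buffer, commits the prefix confirmed by HypothesisBuffer::flush, and, once the buffer exceeds roughly 15 seconds, trims it at the end of a completed segment via chunkAt.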
package/common/rnexecutorch/models/speech_to_text/types/DecodingOptions.h
@@ -0,0 +1,15 @@
+#pragma once
+
+#include <optional>
+#include <string>
+
+namespace rnexecutorch::models::speech_to_text::types {
+
+struct DecodingOptions {
+  explicit DecodingOptions(const std::string &language)
+      : language(language.empty() ? std::nullopt : std::optional(language)) {}
+
+  std::optional<std::string> language;
+};
+
+} // namespace rnexecutorch::models::speech_to_text::types
package/common/rnexecutorch/models/speech_to_text/types/GenerationResult.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include <vector>
+
+namespace rnexecutorch::models::speech_to_text::types {
+
+struct GenerationResult {
+  std::vector<int32_t> tokens;
+  std::vector<float> scores;
+};
+
+} // namespace rnexecutorch::models::speech_to_text::types
package/common/rnexecutorch/models/speech_to_text/types/ProcessResult.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include <string>
+
+namespace rnexecutorch::models::speech_to_text::types {
+
+struct ProcessResult {
+  std::string committed;
+  std::string nonCommitted;
+};
+
+} // namespace rnexecutorch::models::speech_to_text::types
package/common/rnexecutorch/models/speech_to_text/types/Segment.h
@@ -0,0 +1,14 @@
+#pragma once
+
+#include <vector>
+
+#include "Word.h"
+
+namespace rnexecutorch::models::speech_to_text::types {
+
+struct Segment {
+  std::vector<Word> words;
+  float noSpeechProbability;
+};
+
+} // namespace rnexecutorch::models::speech_to_text::types
package/common/rnexecutorch/models/speech_to_text/types/Word.h
@@ -0,0 +1,13 @@
+#pragma once
+
+#include <string>
+
+namespace rnexecutorch::models::speech_to_text::types {
+
+struct Word {
+  std::string content;
+  float start;
+  float end;
+};
+
+} // namespace rnexecutorch::models::speech_to_text::types
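
Segment and Word are what ASR::transcribe ultimately returns; a small consumption sketch (segments as in the earlier sketch, timestamps in seconds from the start of the audio):

    #include <cstdio>

    for (const auto &segment : segments) {
      for (const auto &word : segment.words) {
        // word.content keeps the leading space added by
        // ASR::estimateWordLevelTimestampsLinear.
        std::printf("[%.2f, %.2f]%s\n", word.start, word.end, word.content.c_str());
      }
    }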