react-native-executorch 0.9.0 → 0.9.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. package/android/libs/classes.jar +0 -0
  2. package/common/rnexecutorch/host_objects/JsiConversions.h +43 -0
  3. package/common/rnexecutorch/models/llm/LLM.cpp +55 -42
  4. package/common/rnexecutorch/models/llm/LLM.h +4 -3
  5. package/common/rnexecutorch/models/llm/Types.h +23 -0
  6. package/common/runner/base_llm_runner.cpp +10 -3
  7. package/common/runner/base_llm_runner.h +1 -0
  8. package/common/runner/constants.h +15 -1
  9. package/common/runner/encoders/audio_encoder.cpp +111 -0
  10. package/common/runner/encoders/audio_encoder.h +40 -0
  11. package/common/runner/encoders/vision_encoder.cpp +13 -5
  12. package/common/runner/encoders/vision_encoder.h +15 -2
  13. package/common/runner/irunner.h +5 -0
  14. package/common/runner/multimodal_decoder_runner.h +50 -1
  15. package/common/runner/multimodal_input.h +16 -1
  16. package/common/runner/multimodal_prefiller.cpp +374 -64
  17. package/common/runner/multimodal_prefiller.h +57 -6
  18. package/common/runner/multimodal_runner.cpp +19 -12
  19. package/common/runner/multimodal_runner.h +1 -1
  20. package/common/runner/sampler.cpp +126 -39
  21. package/common/runner/sampler.h +13 -5
  22. package/common/runner/text_decoder_runner.cpp +1 -4
  23. package/common/runner/text_decoder_runner.h +3 -2
  24. package/common/runner/text_prefiller.cpp +8 -8
  25. package/common/runner/text_prefiller.h +8 -1
  26. package/common/runner/text_runner.cpp +35 -9
  27. package/common/runner/text_token_generator.h +2 -3
  28. package/common/runner/util.h +0 -1
  29. package/lib/module/constants/llmDefaults.js +1 -1
  30. package/lib/module/constants/llmDefaults.js.map +1 -1
  31. package/lib/module/constants/modelRegistry.js +62 -3
  32. package/lib/module/constants/modelRegistry.js.map +1 -1
  33. package/lib/module/constants/modelUrls.js +62 -6
  34. package/lib/module/constants/modelUrls.js.map +1 -1
  35. package/lib/module/controllers/LLMController.js +69 -20
  36. package/lib/module/controllers/LLMController.js.map +1 -1
  37. package/lib/module/hooks/natural_language_processing/useLLM.js +1 -5
  38. package/lib/module/hooks/natural_language_processing/useLLM.js.map +1 -1
  39. package/lib/module/modules/computer_vision/PoseEstimationModule.js +13 -1
  40. package/lib/module/modules/computer_vision/PoseEstimationModule.js.map +1 -1
  41. package/lib/module/modules/natural_language_processing/LLMModule.js +12 -7
  42. package/lib/module/modules/natural_language_processing/LLMModule.js.map +1 -1
  43. package/lib/module/types/llm.js +11 -0
  44. package/lib/module/types/llm.js.map +1 -1
  45. package/lib/module/types/poseEstimation.js.map +1 -1
  46. package/lib/typescript/constants/llmDefaults.d.ts +1 -1
  47. package/lib/typescript/constants/llmDefaults.d.ts.map +1 -1
  48. package/lib/typescript/constants/modelRegistry.d.ts +38 -1
  49. package/lib/typescript/constants/modelRegistry.d.ts.map +1 -1
  50. package/lib/typescript/constants/modelUrls.d.ts +52 -12
  51. package/lib/typescript/constants/modelUrls.d.ts.map +1 -1
  52. package/lib/typescript/controllers/LLMController.d.ts +7 -9
  53. package/lib/typescript/controllers/LLMController.d.ts.map +1 -1
  54. package/lib/typescript/modules/computer_vision/PoseEstimationModule.d.ts +6 -0
  55. package/lib/typescript/modules/computer_vision/PoseEstimationModule.d.ts.map +1 -1
  56. package/lib/typescript/modules/natural_language_processing/LLMModule.d.ts +6 -3
  57. package/lib/typescript/modules/natural_language_processing/LLMModule.d.ts.map +1 -1
  58. package/lib/typescript/types/llm.d.ts +63 -36
  59. package/lib/typescript/types/llm.d.ts.map +1 -1
  60. package/lib/typescript/types/poseEstimation.d.ts +3 -0
  61. package/lib/typescript/types/poseEstimation.d.ts.map +1 -1
  62. package/package.json +1 -1
  63. package/react-native-executorch.podspec +6 -0
  64. package/src/constants/llmDefaults.ts +1 -1
  65. package/src/constants/modelRegistry.ts +62 -2
  66. package/src/constants/modelUrls.ts +69 -6
  67. package/src/controllers/LLMController.ts +89 -40
  68. package/src/hooks/natural_language_processing/useLLM.ts +5 -6
  69. package/src/modules/computer_vision/PoseEstimationModule.ts +12 -0
  70. package/src/modules/natural_language_processing/LLMModule.ts +19 -8
  71. package/src/types/llm.ts +64 -34
  72. package/src/types/poseEstimation.ts +10 -4
  73. package/third-party/android/libs/executorch/arm64-v8a/libexecutorch.so +0 -0
  74. package/third-party/android/libs/executorch/x86_64/libexecutorch.so +0 -0
  75. package/third-party/include/executorch/ExecuTorch.h +2 -0
  76. package/third-party/include/executorch/ExecuTorchModule.h +46 -0
  77. package/third-party/include/executorch/extension/data_loader/buffer_data_loader.h +4 -3
  78. package/third-party/include/executorch/extension/data_loader/mman.h +46 -0
  79. package/third-party/include/executorch/extension/data_loader/mmap_data_loader.h +4 -0
  80. package/third-party/include/executorch/extension/data_loader/shared_ptr_data_loader.h +7 -3
  81. package/third-party/include/executorch/extension/module/module.h +47 -8
  82. package/third-party/include/executorch/extension/tensor/tensor_ptr.h +17 -5
  83. package/third-party/include/executorch/kernels/optimized/Functions.h +12 -0
  84. package/third-party/include/executorch/kernels/optimized/NativeFunctions.h +4 -0
  85. package/third-party/include/executorch/kernels/portable/Functions.h +18 -0
  86. package/third-party/include/executorch/kernels/portable/NativeFunctions.h +6 -0
  87. package/third-party/include/executorch/runtime/backend/backend_options_map.h +37 -0
  88. package/third-party/include/executorch/runtime/core/array_ref.h +3 -1
  89. package/third-party/include/executorch/runtime/core/error.h +1 -0
  90. package/third-party/include/executorch/runtime/core/evalue.h +256 -9
  91. package/third-party/include/executorch/runtime/core/exec_aten/exec_aten.h +24 -0
  92. package/third-party/include/executorch/runtime/core/hierarchical_allocator.h +9 -6
  93. package/third-party/include/executorch/runtime/core/portable_type/device.h +3 -4
  94. package/third-party/include/executorch/runtime/core/portable_type/tensor_impl.h +31 -1
  95. package/third-party/include/executorch/runtime/executor/method.h +9 -3
  96. package/third-party/include/executorch/runtime/executor/method_meta.h +14 -0
  97. package/third-party/include/executorch/runtime/executor/platform_memory_allocator.h +12 -2
  98. package/third-party/include/executorch/runtime/executor/program.h +3 -1
  99. package/third-party/include/executorch/runtime/executor/tensor_parser.h +5 -1
  100. package/third-party/include/executorch/runtime/kernel/operator_registry.h +9 -0
  101. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/ExecutorchLib +0 -0
  102. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/Info.plist +0 -0
  103. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/mlx.metallib +0 -0
  104. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/ExecutorchLib +0 -0
  105. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/Info.plist +0 -0
  106. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/mlx.metallib +0 -0
@@ -3,7 +3,6 @@
3
3
  #include "constants.h"
4
4
  #include "util.h"
5
5
  #include <rnexecutorch/Error.h>
6
- #include <rnexecutorch/Log.h>
7
6
 
8
7
  namespace executorch::extension::llm {
9
8
 
@@ -54,8 +53,14 @@ Error MultimodalRunner::load_subcomponents() {
54
53
  if (enc_it != encoders_.end()) {
55
54
  image_encoder = enc_it->second.get();
56
55
  }
56
+ IEncoder *audio_encoder = nullptr;
57
+ auto aud_it = encoders_.find(MultimodalType::Audio);
58
+ if (aud_it != encoders_.end()) {
59
+ audio_encoder = aud_it->second.get();
60
+ }
57
61
  mm_prefiller_ = std::make_unique<MultimodalPrefiller>(
58
- *module_, *mm_decoder_runner_, *tokenizer_, image_encoder);
62
+ *module_, *mm_decoder_runner_, *tokenizer_, metadata_, image_encoder,
63
+ audio_encoder);
59
64
  mm_token_generator_ = std::make_unique<TextTokenGenerator>(
60
65
  tokenizer_.get(), mm_decoder_runner_.get(), /*use_kv_cache=*/true,
61
66
  std::move(eos_ids_), stats_ptr, config_);
@@ -78,22 +83,24 @@ Error MultimodalRunner::generate_internal(
78
83
  }
79
84
 
80
85
  stats_.inference_start_ms = time_in_ms();
81
-
82
- uint64_t prefill_next_token = 0;
83
- for (const auto &input : inputs) {
84
- auto prefill_result = mm_prefiller_->prefill(input, pos_);
85
- if (!prefill_result.ok())
86
- return prefill_result.error();
87
- prefill_next_token = prefill_result.get();
88
- }
86
+ auto prefill_result = mm_prefiller_->prefill(inputs, pos_);
87
+ if (!prefill_result.ok())
88
+ return prefill_result.error();
89
+ uint64_t prefill_next_token = prefill_result.get();
89
90
 
90
91
  stats_.first_token_ms = time_in_ms();
91
92
  stats_.prompt_eval_end_ms = time_in_ms();
92
93
  stats_.num_prompt_tokens = pos_;
93
94
 
95
+ // For dynamic-shape PTEs (Gemma4 iter*), get_max_seq_len is the per-call
96
+ // decoder chunk size (e.g. 128) and the true generation budget lives in
97
+ // get_max_context_len. Mirrors text_runner.cpp:95-97.
98
+ const int32_t seq_cap = config_.enable_dynamic_shape
99
+ ? config_.max_context_length
100
+ : config_.max_seq_len;
94
101
  int32_t resolved_max_new = resolve_max_new_tokens(
95
- static_cast<int32_t>(pos_), config_.max_seq_len,
96
- config_.max_context_length, config_.max_new_tokens);
102
+ static_cast<int32_t>(pos_), seq_cap, config_.max_context_length,
103
+ config_.max_new_tokens);
97
104
 
98
105
  std::vector<uint64_t> seed_tokens = {prefill_next_token};
99
106
  auto wrapped_callback = [&](const std::string &piece) {
@@ -10,7 +10,7 @@
10
10
 
11
11
  namespace executorch::extension::llm {
12
12
 
13
- enum class MultimodalType { Image };
13
+ enum class MultimodalType { Image, Audio };
14
14
 
15
15
  class MultimodalRunner : public BaseLLMRunner {
16
16
  public:
@@ -35,6 +35,10 @@
35
35
  #include "sampler.h"
36
36
  #include <algorithm>
37
37
  #include <ctime>
38
+ #include <limits>
39
+ #include <ranges>
40
+ #include <span>
41
+ #include <type_traits>
38
42
  #include <vector>
39
43
 
40
44
  namespace executorch {
@@ -46,7 +50,7 @@ template <typename T> int32_t Sampler::sample_argmax(T *probabilities) {
46
50
  // return the index that has the highest probability
47
51
  int max_i = 0;
48
52
  T max_p = probabilities[0];
49
- for (int i = 1; i < vocab_size_; i++) {
53
+ for (size_t i = 1; i < vocab_size_; i++) {
50
54
  if (probabilities[i] > max_p) {
51
55
  max_i = i;
52
56
  max_p = probabilities[i];
@@ -60,7 +64,7 @@ int32_t Sampler::sample_mult(T *probabilities, float coin) {
60
64
  // sample index from probabilities (they must sum to 1!)
61
65
  // coin is a random number in [0, 1), usually from random_f32()
62
66
  T cdf = 0.0;
63
- for (int i = 0; i < vocab_size_; i++) {
67
+ for (size_t i = 0; i < vocab_size_; i++) {
64
68
  cdf += probabilities[i];
65
69
  if (coin < cdf) {
66
70
  return i;
@@ -84,7 +88,7 @@ int32_t Sampler::sample_topp(T *probabilities, float coin) {
84
88
  std::make_unique<ProbIndex<T>[]>(vocab_size_);
85
89
 
86
90
  const float cutoff = (1.0f - topp_) / (n - 1);
87
- for (int i = 0; i < n; i++) {
91
+ for (size_t i = 0; i < n; i++) {
88
92
  if (probabilities[i] >= cutoff) {
89
93
  probindex[n0].index = i;
90
94
  probindex[n0].prob = probabilities[i];
@@ -92,62 +96,147 @@ int32_t Sampler::sample_topp(T *probabilities, float coin) {
92
96
  }
93
97
  }
94
98
 
95
- auto compare = [](const ProbIndex<T> &a, const ProbIndex<T> &b) {
96
- return a.prob > b.prob;
97
- };
98
- std::sort(probindex.get(), probindex.get() + n0, compare);
99
+ std::sort(probindex.get(), probindex.get() + n0,
100
+ [](const ProbIndex<T> &a, const ProbIndex<T> &b) {
101
+ return a.prob > b.prob;
102
+ });
99
103
 
100
104
  // truncate the list where cumulative probability exceeds topp
101
105
  T cumulative_prob = 0;
102
- int last_idx = n0 - 1; // in case of rounding errors consider all elements
103
- for (int i = 0; i < n0; i++) {
106
+ int last_idx = n0 - 1;
107
+ for (size_t i = 0; i < n0; i++) {
104
108
  cumulative_prob += probindex[i].prob;
105
- if (cumulative_prob > topp_) {
109
+ if (static_cast<float>(cumulative_prob) > topp_) {
106
110
  last_idx = i;
107
- break; // we've exceeded topp by including last_idx
111
+ break;
108
112
  }
109
113
  }
110
114
 
111
115
  // sample from the truncated list
112
- const T &r = coin * cumulative_prob;
116
+ float r = coin * static_cast<float>(cumulative_prob);
113
117
  T cdf = 0;
114
- for (int i = 0; i <= last_idx; i++) {
118
+ for (size_t i = 0; i <= last_idx; i++) {
115
119
  cdf += probindex[i].prob;
116
- if (r < cdf) {
120
+ if (r < static_cast<float>(cdf)) {
117
121
  return probindex[i].index;
118
122
  }
119
123
  }
120
- return probindex[last_idx].index; // in case of rounding errors
124
+ return probindex[last_idx].index;
121
125
  }
122
126
 
123
- Sampler::Sampler(int32_t vocab_size, float temperature, float topp,
124
- unsigned long long rng_seed, float min_p,
125
- float repetition_penalty)
127
+ // Mask logits outside the top-k by rank to -inf. Ties at the k-th boundary
128
+ // are kept (matches HuggingFace TopKLogitsWarper).
129
+ template <typename T> void Sampler::mask_topk(T *logits) {
130
+ if (topk_ <= 0 || topk_ >= vocab_size_) {
131
+ return;
132
+ }
133
+ // Partial-select the (topk_-th largest) threshold using nth_element on a
134
+ // copy of logits; O(n) average.
135
+ std::vector<T> scratch(logits, logits + vocab_size_);
136
+ std::nth_element(scratch.begin(), scratch.begin() + (topk_ - 1),
137
+ scratch.end(), std::greater<T>());
138
+ const T threshold = scratch[topk_ - 1];
139
+ constexpr T neg_inf = std::numeric_limits<T>::lowest();
140
+ for (size_t i = 0; i < vocab_size_; i++) {
141
+ if (logits[i] < threshold) {
142
+ logits[i] = neg_inf;
143
+ }
144
+ }
145
+ }
146
+
147
+ // Mask logits outside the top-p nucleus to -inf. Approximates the exact
148
+ // sort-based nucleus with a histogram over (logit - max): two O(n) passes, no
149
+ // sort. Binning in logit (not probability) space keeps uniform resolution for
150
+ // peaked and flat distributions alike. kRange=40 spans exp() down to ~4e-18.
151
+ template <typename T> void Sampler::mask_topp(T *logits) {
152
+ if (topp_ <= 0.0f || topp_ >= 1.0f) {
153
+ return;
154
+ }
155
+ constexpr int32_t kBins = 2048;
156
+ // Compute in a type at least as wide as T so converting logits never loses
157
+ // precision: double stays double, everything else (float and the narrow
158
+ // half/bf16/uint16 logit types) widens to float. Accumulating in T directly
159
+ // would be unsafe for bf16, whose mantissa saturates when summing exp()
160
+ // over the full vocab.
161
+ using acc_t = std::conditional_t<std::is_same_v<T, double>, double, float>;
162
+ constexpr acc_t kRange = 40;
163
+
164
+ std::span<const T> logit_span{logits, static_cast<size_t>(vocab_size_)};
165
+ const acc_t max_val =
166
+ static_cast<acc_t>(*std::ranges::max_element(logit_span));
167
+
168
+ std::vector<acc_t> bin_mass(kBins, acc_t(0));
169
+ acc_t total = 0;
170
+ for (size_t i = 0; i < vocab_size_; i++) {
171
+ acc_t d = static_cast<acc_t>(logits[i]) - max_val;
172
+ acc_t e = std::exp(d);
173
+ total += e;
174
+ int32_t bin = static_cast<int32_t>((d + kRange) / kRange * kBins);
175
+ bin = std::clamp(bin, 0, kBins - 1);
176
+ bin_mass[bin] += e;
177
+ }
178
+ if (total <= acc_t(0)) {
179
+ return;
180
+ }
181
+
182
+ // Highest bin downward until the kept mass reaches topp. The crossing bin is
183
+ // kept (HuggingFace "keep the token that crosses" convention).
184
+ const acc_t target = static_cast<acc_t>(topp_) * total;
185
+ acc_t acc = 0;
186
+ int32_t keep_bin = 0;
187
+ for (int32_t bin = kBins - 1; bin >= 0; --bin) {
188
+ acc += bin_mass[bin];
189
+ if (acc >= target) {
190
+ keep_bin = bin;
191
+ break;
192
+ }
193
+ }
194
+ const acc_t d_threshold =
195
+ static_cast<acc_t>(keep_bin) / kBins * kRange - kRange;
196
+
197
+ constexpr T neg_inf = std::numeric_limits<T>::lowest();
198
+ for (size_t i = 0; i < vocab_size_; i++) {
199
+ if (static_cast<acc_t>(logits[i]) - max_val < d_threshold) {
200
+ logits[i] = neg_inf;
201
+ }
202
+ }
203
+ }
204
+
205
+ Sampler::Sampler(int32_t vocab_size, GenerationConfig config,
206
+ unsigned long long rng_seed)
126
207
  : vocab_size_(vocab_size),
127
- inv_temperature_((temperature != 0.0f) ? (1.0f / temperature) : 0.0f),
128
- topp_(topp), min_p_(min_p), repetition_penalty_(repetition_penalty),
208
+ inv_temperature_(
209
+ (config.temperature != 0.0f) ? (1.0f / config.temperature) : 0.0f),
210
+ topp_(config.topp), min_p_(config.min_p),
211
+ repetition_penalty_(config.repetition_penalty), topk_(config.topk),
129
212
  rng_state_(rng_seed) {}
130
213
 
131
- Sampler::Sampler(int vocab_size, float temperature, float topp)
132
- : Sampler(vocab_size, temperature, topp, std::time(nullptr), 0.0f, 1.0f) {}
214
+ Sampler::Sampler(int32_t vocab_size, GenerationConfig config)
215
+ : Sampler(vocab_size, config, std::time(nullptr)) {}
133
216
 
134
217
  template <typename T> static void softmax(T *x, int size) {
135
- // find max value (for numerical stability)
218
+ // Runs after top-k/top-p masking, which sets rejected logits to lowest().
219
+ // Skip exp() on those: it underflows to 0 anyway and is slow on device.
220
+ constexpr T kMasked = std::numeric_limits<T>::lowest();
136
221
  T max_val = x[0];
137
- for (int i = 1; i < size; i++) {
222
+ for (size_t i = 1; i < size; i++) {
138
223
  if (x[i] > max_val) {
139
224
  max_val = x[i];
140
225
  }
141
226
  }
142
- // exp and sum
143
227
  T sum = 0;
144
- for (int i = 0; i < size; i++) {
228
+ for (size_t i = 0; i < size; i++) {
229
+ if (x[i] == kMasked) {
230
+ x[i] = T(0);
231
+ continue;
232
+ }
145
233
  x[i] = expf(x[i] - max_val);
146
234
  sum += x[i];
147
235
  }
148
- // normalize
149
- for (int i = 0; i < size; i++) {
150
- x[i] /= sum;
236
+ for (size_t i = 0; i < size; i++) {
237
+ if (x[i] != T(0)) {
238
+ x[i] /= sum;
239
+ }
151
240
  }
152
241
  }
153
242
 
@@ -175,20 +264,18 @@ int32_t Sampler::sample(T *logits, const std::vector<uint64_t> &recent_tokens) {
175
264
  apply_repetition_penalty(logits, vocab_size_, recent_tokens);
176
265
  // 2. apply the temperature to the logits
177
266
  apply_temperature(logits, vocab_size_);
178
- // 3. apply softmax to the logits to get the probabilities for next token
267
+ // 3. mask out logits outside top-k by rank (pre-softmax, becomes 0 mass)
268
+ mask_topk(logits);
269
+ // 4. mask out logits outside top-p by rank (pre-softmax)
270
+ mask_topp(logits);
271
+ // 5. apply softmax to the logits to get the probabilities for next token
179
272
  softmax(logits, vocab_size_);
180
- // 4. apply min_p truncation
273
+ // 6. apply min_p truncation
181
274
  apply_min_p(logits, vocab_size_);
182
275
  // flip a (float) coin (this is our source of entropy for sampling)
183
276
  float coin = random_f32(&rng_state_);
184
- // 5. we sample from this distribution to get the next token
185
- if (topp_ <= 0 || topp_ >= 1) {
186
- // simply sample from the predicted probability distribution
187
- next = sample_mult(logits, coin);
188
- } else {
189
- // top-p (nucleus) sampling, clamping the least likely tokens to zero
190
- next = sample_topp(logits, coin);
191
- }
277
+ // 7. we sample from this distribution to get the next token
278
+ next = sample_mult(logits, coin);
192
279
  }
193
280
  return next;
194
281
  }
@@ -8,6 +8,7 @@
8
8
 
9
9
  #pragma once
10
10
 
11
+ #include "runner/irunner.h"
11
12
  #include <algorithm>
12
13
  #include <cctype>
13
14
  #include <cmath>
@@ -28,6 +29,7 @@ namespace executorch {
28
29
  namespace extension {
29
30
  namespace llm {
30
31
  // A simple llama2 sampler.
32
+ struct GenerationConfig;
31
33
 
32
34
  inline constexpr auto kTopp = 0.9f;
33
35
 
@@ -38,11 +40,13 @@ template <typename T> struct ProbIndex {
38
40
 
39
41
  class Sampler {
40
42
  public:
41
- Sampler(int32_t vocab_size, float temperature, float topp,
42
- unsigned long long rng_seed, float min_p = 0.0f,
43
- float repetition_penalty = 1.0f);
44
-
45
- Sampler(int32_t vocab_size, float temperature, float topp);
43
+ // topk <= 0 disables top-k filtering. topp <= 0 || topp >= 1 disables top-p.
44
+ // Pipeline when temperature != 0: temperature -> top-k mask -> top-p mask
45
+ // -> softmax -> multinomial. Note: topk == 1 with temperature != 0 collapses
46
+ // to greedy; pass topk = 0 to keep full-vocab temperature sampling.
47
+ Sampler(int32_t vocab_size, GenerationConfig config,
48
+ unsigned long long rng_seed);
49
+ Sampler(int32_t vocab_size, GenerationConfig config);
46
50
 
47
51
  template <typename T> int32_t sample(T *logits);
48
52
 
@@ -53,6 +57,9 @@ private:
53
57
  template <typename T> int32_t sample_topp(T *probabilities, float coin);
54
58
  template <typename T> int32_t sample_mult(T *probabilities, float coin);
55
59
  template <typename T> int32_t sample_argmax(T *probabilities);
60
+ // In-place logit warpers: set excluded indices to -inf.
61
+ template <typename T> void mask_topk(T *logits);
62
+ template <typename T> void mask_topp(T *logits);
56
63
 
57
64
  template <typename T>
58
65
  inline void apply_temperature(T *logits, int32_t vocab_size) {
@@ -110,6 +117,7 @@ private:
110
117
  float topp_;
111
118
  float min_p_;
112
119
  float repetition_penalty_;
120
+ int32_t topk_;
113
121
  unsigned long long rng_state_;
114
122
  };
115
123
 
@@ -31,7 +31,6 @@ TextDecoderRunner::TextDecoderRunner(Module &module, IOManager *io_manager,
31
31
  // outer loop (call site) is responsible for managing state.
32
32
  ::executorch::runtime::Result<executorch::aten::Tensor>
33
33
  TextDecoderRunner::step(TensorPtr &tokens, int64_t start_pos) {
34
- // ET_LOG(Info, "Input token %" PRIu64, input_token);
35
34
  auto method_meta_result = module_->method_meta("forward");
36
35
  if (!method_meta_result.ok()) {
37
36
  return method_meta_result.error();
@@ -102,9 +101,7 @@ int32_t TextDecoderRunner::logits_to_token(
102
101
  auto num_tokens = logits_tensor.size(1);
103
102
  logits += (num_tokens - 1) * vocab_size;
104
103
  }
105
- Sampler sampler(vocab_size, config_.temperature, config_.topp,
106
- static_cast<unsigned long long>(std::time(nullptr)),
107
- config_.min_p, config_.repetition_penalty);
104
+ Sampler sampler(vocab_size, config_);
108
105
  result = sampler.sample(logits, recent_tokens);
109
106
  });
110
107
  return result;
@@ -10,6 +10,7 @@
10
10
 
11
11
  #pragma once
12
12
 
13
+ #include "constants.h"
13
14
  #include "io_manager.h"
14
15
  #include "sampler.h"
15
16
 
@@ -40,8 +41,8 @@ public:
40
41
  step(TensorPtr &input, int64_t start_pos);
41
42
 
42
43
  /**
43
- * Load the Module for text decode purpose.
44
- * @return The error code.
44
+ * Load the Module for text decode purpose. Loads the dynamic-shape `forward`
45
+ * method used for both prefill and decode.
45
46
  */
46
47
  virtual ::executorch::runtime::Error load() {
47
48
  return module_->load_method("forward");
@@ -18,10 +18,11 @@ namespace llm {
18
18
 
19
19
  TextPrefiller::TextPrefiller(TextDecoderRunner *text_decoder_runner,
20
20
  bool use_kv_cache, bool enable_parallel_prefill,
21
- int64_t max_seq_len)
21
+ int64_t max_seq_len, int32_t prefill_chunk_size)
22
22
  : text_decoder_runner_(text_decoder_runner), use_kv_cache_(use_kv_cache),
23
23
  enable_parallel_prefill_(enable_parallel_prefill),
24
- max_seq_len_(max_seq_len > 0 ? max_seq_len : 128) {}
24
+ max_seq_len_(max_seq_len > 0 ? max_seq_len : 128),
25
+ prefill_chunk_size_(prefill_chunk_size) {}
25
26
 
26
27
  ::executorch::runtime::Result<uint64_t>
27
28
  TextPrefiller::prefill(std::vector<uint64_t> &prompt_tokens,
@@ -31,17 +32,17 @@ TextPrefiller::prefill(std::vector<uint64_t> &prompt_tokens,
31
32
  ET_CHECK_OK_OR_RETURN_ERROR(text_decoder_runner_->load());
32
33
  }
33
34
 
34
- // Check if we need to chunk the prompt tokens
35
35
  int32_t num_prompt_tokens = prompt_tokens.size();
36
+ int32_t chunk_size =
37
+ prefill_chunk_size_ > 0 ? prefill_chunk_size_ : max_seq_len_;
36
38
 
37
- // If prompt tokens exceed max_seq_len_, we need to chunk them
38
- if (num_prompt_tokens > max_seq_len_) {
39
+ if (num_prompt_tokens > chunk_size) {
39
40
  uint64_t cur_token = 0;
40
41
  int num_tokens_to_process = 0;
41
42
 
42
43
  while (num_tokens_to_process < num_prompt_tokens) {
43
- auto num_tokens_to_prefill_with = std::min<int>(
44
- num_prompt_tokens - num_tokens_to_process, max_seq_len_);
44
+ auto num_tokens_to_prefill_with =
45
+ std::min<int>(num_prompt_tokens - num_tokens_to_process, chunk_size);
45
46
 
46
47
  std::vector<uint64_t> prompt_tokens_to_process(
47
48
  num_tokens_to_prefill_with);
@@ -75,7 +76,6 @@ TextPrefiller::prefill_chunk(std::vector<uint64_t> &prompt_tokens,
75
76
  // store the token
76
77
  uint64_t cur_token;
77
78
  if (enable_parallel_prefill_ || !use_kv_cache_) {
78
- // initialize tensor wrappers
79
79
  auto tokens = from_blob(prompt_tokens.data(), {1, num_prompt_tokens},
80
80
  executorch::aten::ScalarType::Long);
81
81
 
@@ -19,8 +19,14 @@ namespace llm {
19
19
 
20
20
  class TextPrefiller {
21
21
  public:
22
+ // prefill_chunk_size: when > 0, the prompt is always processed in steps of
23
+ // this size (see prefill()). Set to the model's forward sequence-length cap
24
+ // for the MLX backend (its forward is exported with a sliding-window bound
25
+ // and one-shot prefill spikes Metal memory). Other backends (XNNPACK/CoreML)
26
+ // pass 0 → original one-shot behavior.
22
27
  TextPrefiller(TextDecoderRunner *text_decoder_runner, bool use_kv_cache,
23
- bool enable_parallel_prefill, int64_t max_seq_len = 128);
28
+ bool enable_parallel_prefill, int64_t max_seq_len = 128,
29
+ int32_t prefill_chunk_size = 0);
24
30
 
25
31
  virtual ~TextPrefiller() = default;
26
32
  /**
@@ -70,6 +76,7 @@ private:
70
76
  bool use_kv_cache_;
71
77
  bool enable_parallel_prefill_;
72
78
  int64_t max_seq_len_;
79
+ int32_t prefill_chunk_size_;
73
80
  };
74
81
 
75
82
  } // namespace llm
@@ -26,11 +26,24 @@ Error TextRunner::load_subcomponents() {
26
26
 
27
27
  Stats *stats_ptr = &stats_;
28
28
 
29
- text_decoder_runner_ = std::make_unique<TextDecoderRunner>(
30
- *module_, io_manager_.get(), config_);
29
+ text_decoder_runner_ =
30
+ std::make_unique<TextDecoderRunner>(*module_, io_manager_.get(), config_);
31
+
32
+ int32_t prefill_chunk_size = 0;
33
+ auto fwd_meta = module_->method_meta("forward");
34
+ if (fwd_meta.ok() && fwd_meta->uses_backend("MLXBackend")) {
35
+ auto input_meta = fwd_meta->input_tensor_meta(0);
36
+ if (input_meta.ok()) {
37
+ auto sizes = input_meta->sizes();
38
+ if (sizes.size() >= 2 && sizes[sizes.size() - 1] > 0) {
39
+ prefill_chunk_size = sizes[sizes.size() - 1];
40
+ }
41
+ }
42
+ }
43
+
31
44
  text_prefiller_ = std::make_unique<TextPrefiller>(
32
45
  text_decoder_runner_.get(), config_.enable_kv_cache,
33
- config_.enable_dynamic_shape, config_.max_seq_len);
46
+ config_.enable_dynamic_shape, config_.max_seq_len, prefill_chunk_size);
34
47
  text_token_generator_ = std::make_unique<TextTokenGenerator>(
35
48
  tokenizer_.get(), text_decoder_runner_.get(), config_.enable_kv_cache,
36
49
  std::move(eos_ids_), stats_ptr, config_);
@@ -65,6 +78,10 @@ Error TextRunner::generate_internal(
65
78
 
66
79
  stats_.inference_start_ms = time_in_ms();
67
80
 
81
+ // Multi-turn: JS re-renders the full chat history each call, so reset KV
82
+ // position to 0 and re-prefill from scratch.
83
+ pos_ = 0;
84
+
68
85
  int64_t context_len_left =
69
86
  static_cast<int64_t>(config_.max_context_length) - pos_;
70
87
 
@@ -79,16 +96,25 @@ Error TextRunner::generate_internal(
79
96
  std::vector<uint64_t> prompt_tokens = encodeResult.get();
80
97
  int num_prompt_tokens = prompt_tokens.size();
81
98
 
99
+ // For dynamic-shape PTEs (e.g. Gemma4 MLX/Vulkan), get_max_seq_len is the
100
+ // per-call decoder chunk size (e.g. the sliding window) and the real
101
+ // generation budget lives in get_max_context_len. Static-shape PTEs set both
102
+ // equal, so this collapses to the old behavior. Without this the budget is
103
+ // computed from the small chunk size, so max_new_tokens can resolve to ~0 and
104
+ // generation ends immediately after prefill.
105
+ const int32_t seq_cap = config_.enable_dynamic_shape
106
+ ? config_.max_context_length
107
+ : config_.max_seq_len;
108
+
82
109
  ET_CHECK_OR_RETURN_ERROR(num_prompt_tokens >= 1, InvalidArgument,
83
110
  "Expected at least 1 prompt token");
84
- ET_CHECK_OR_RETURN_ERROR(num_prompt_tokens < config_.max_seq_len,
85
- InvalidArgument,
86
- "num_prompt_tokens %d >= max_seq_len %" PRId32,
87
- num_prompt_tokens, config_.max_seq_len);
111
+ ET_CHECK_OR_RETURN_ERROR(num_prompt_tokens < seq_cap, InvalidArgument,
112
+ "num_prompt_tokens %d >= seq cap %" PRId32,
113
+ num_prompt_tokens, seq_cap);
88
114
 
89
115
  int32_t max_new_tokens = resolve_max_new_tokens(
90
- num_prompt_tokens, config_.max_seq_len,
91
- static_cast<int32_t>(context_len_left), config_.max_new_tokens);
116
+ num_prompt_tokens, seq_cap, static_cast<int32_t>(context_len_left),
117
+ config_.max_new_tokens);
92
118
 
93
119
  ET_CHECK_OR_RETURN_ERROR(max_new_tokens > 0, InvalidArgument,
94
120
  "Max new tokens %d is <= 0", max_new_tokens);
@@ -100,8 +100,8 @@ public:
100
100
  prev_token = cur_token;
101
101
 
102
102
  stats_->on_sampling_begin();
103
- cur_token =
104
- text_decoder_runner_->logits_to_token(logits_tensor, generated_tokens);
103
+ cur_token = text_decoder_runner_->logits_to_token(logits_tensor,
104
+ generated_tokens);
105
105
  stats_->on_sampling_end();
106
106
 
107
107
  pos++;
@@ -152,7 +152,6 @@ public:
152
152
  if (should_stop_) {
153
153
  break;
154
154
  }
155
-
156
155
  // data-dependent terminating condition: we have n_eos_ number of EOS
157
156
  if (eos_ids_->find(cur_token) != eos_ids_->end()) {
158
157
  printf("\n");
@@ -8,7 +8,6 @@
8
8
 
9
9
  #pragma once
10
10
  #include "constants.h"
11
- #include "text_prefiller.h"
12
11
  #include <cctype>
13
12
  #include <executorch/extension/module/module.h>
14
13
  #include <executorch/extension/tensor/tensor.h>
@@ -6,7 +6,7 @@ import { SlidingWindowContextStrategy } from '../utils/llms/context_strategy';
6
6
  * Default system prompt used to guide the behavior of Large Language Models (LLMs).
7
7
  * @category Utilities - LLM
8
8
  */
9
- export const DEFAULT_SYSTEM_PROMPT = "You are a knowledgeable, efficient, and direct AI assistant. Provide concise answers, focusing on the key information needed. Offer suggestions tactfully when appropriate to improve outcomes. Engage in productive collaboration with the user. Don't return too much text.";
9
+ export const DEFAULT_SYSTEM_PROMPT = "You are a knowledgeable, efficient, and direct AI assistant. Provide concise answers, focusing on the key information needed. Offer suggestions tactfully when appropriate to improve outcomes. Engage in productive collaboration with the user. Don't return too much text. If provided with audio samples treat it with at most importance";
10
10
 
11
11
  /**
12
12
  * Generates a default structured output prompt based on the provided JSON schema.
@@ -1 +1 @@
1
- {"version":3,"names":["SlidingWindowContextStrategy","DEFAULT_SYSTEM_PROMPT","DEFAULT_STRUCTURED_OUTPUT_PROMPT","structuredOutputSchema","DEFAULT_MESSAGE_HISTORY","DEFAULT_CONTEXT_BUFFER_TOKENS","DEFAULT_CHAT_CONFIG","systemPrompt","initialMessageHistory","contextStrategy"],"sourceRoot":"../../../src","sources":["constants/llmDefaults.ts"],"mappings":";;AACA,SAASA,4BAA4B,QAAQ,gCAAgC;;AAE7E;AACA;AACA;AACA;AACA,OAAO,MAAMC,qBAAqB,GAChC,+QAA+Q;;AAEjR;AACA;AACA;AACA;AACA;AACA;AACA,OAAO,MAAMC,gCAAgC,GAC3CC,sBAA8B,IAC3B;AACL;AACA;AACA;AACA;AACA;AACA,EAAEA,sBAAsB;AACxB,CAAC;;AAED;AACA;AACA;AACA;AACA,OAAO,MAAMC,uBAAkC,GAAG,EAAE;;AAEpD;AACA;AACA;AACA;AACA,OAAO,MAAMC,6BAA6B,GAAG,GAAG;;AAEhD;AACA;AACA;AACA;AACA,OAAO,MAAMC,mBAA+B,GAAG;EAC7CC,YAAY,EAAEN,qBAAqB;EACnCO,qBAAqB,EAAEJ,uBAAuB;EAC9CK,eAAe,EAAE,IAAIT,4BAA4B,CAC/CK,6BACF;AACF,CAAC","ignoreList":[]}
1
+ {"version":3,"names":["SlidingWindowContextStrategy","DEFAULT_SYSTEM_PROMPT","DEFAULT_STRUCTURED_OUTPUT_PROMPT","structuredOutputSchema","DEFAULT_MESSAGE_HISTORY","DEFAULT_CONTEXT_BUFFER_TOKENS","DEFAULT_CHAT_CONFIG","systemPrompt","initialMessageHistory","contextStrategy"],"sourceRoot":"../../../src","sources":["constants/llmDefaults.ts"],"mappings":";;AACA,SAASA,4BAA4B,QAAQ,gCAAgC;;AAE7E;AACA;AACA;AACA;AACA,OAAO,MAAMC,qBAAqB,GAChC,+UAA+U;;AAEjV;AACA;AACA;AACA;AACA;AACA;AACA,OAAO,MAAMC,gCAAgC,GAC3CC,sBAA8B,IAC3B;AACL;AACA;AACA;AACA;AACA;AACA,EAAEA,sBAAsB;AACxB,CAAC;;AAED;AACA;AACA;AACA;AACA,OAAO,MAAMC,uBAAkC,GAAG,EAAE;;AAEpD;AACA;AACA;AACA;AACA,OAAO,MAAMC,6BAA6B,GAAG,GAAG;;AAEhD;AACA;AACA;AACA;AACA,OAAO,MAAMC,mBAA+B,GAAG;EAC7CC,YAAY,EAAEN,qBAAqB;EACnCO,qBAAqB,EAAEJ,uBAAuB;EAC9CK,eAAe,EAAE,IAAIT,4BAA4B,CAC/CK,6BACF;AACF,CAAC","ignoreList":[]}