react-native-executorch 0.9.0 → 0.9.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. package/android/libs/classes.jar +0 -0
  2. package/common/rnexecutorch/host_objects/JsiConversions.h +43 -0
  3. package/common/rnexecutorch/models/llm/LLM.cpp +55 -42
  4. package/common/rnexecutorch/models/llm/LLM.h +4 -3
  5. package/common/rnexecutorch/models/llm/Types.h +23 -0
  6. package/common/runner/base_llm_runner.cpp +10 -3
  7. package/common/runner/base_llm_runner.h +1 -0
  8. package/common/runner/constants.h +15 -1
  9. package/common/runner/encoders/audio_encoder.cpp +111 -0
  10. package/common/runner/encoders/audio_encoder.h +40 -0
  11. package/common/runner/encoders/vision_encoder.cpp +0 -1
  12. package/common/runner/irunner.h +5 -0
  13. package/common/runner/multimodal_decoder_runner.h +50 -1
  14. package/common/runner/multimodal_input.h +16 -1
  15. package/common/runner/multimodal_prefiller.cpp +374 -64
  16. package/common/runner/multimodal_prefiller.h +57 -6
  17. package/common/runner/multimodal_runner.cpp +19 -12
  18. package/common/runner/multimodal_runner.h +1 -1
  19. package/common/runner/sampler.cpp +111 -35
  20. package/common/runner/sampler.h +13 -5
  21. package/common/runner/text_decoder_runner.cpp +1 -4
  22. package/common/runner/text_decoder_runner.h +3 -2
  23. package/common/runner/text_prefiller.cpp +8 -8
  24. package/common/runner/text_prefiller.h +8 -1
  25. package/common/runner/text_runner.cpp +35 -9
  26. package/common/runner/text_token_generator.h +2 -3
  27. package/common/runner/util.h +0 -1
  28. package/lib/module/constants/llmDefaults.js +1 -1
  29. package/lib/module/constants/llmDefaults.js.map +1 -1
  30. package/lib/module/constants/modelRegistry.js +33 -2
  31. package/lib/module/constants/modelRegistry.js.map +1 -1
  32. package/lib/module/constants/modelUrls.js +43 -6
  33. package/lib/module/constants/modelUrls.js.map +1 -1
  34. package/lib/module/controllers/LLMController.js +69 -20
  35. package/lib/module/controllers/LLMController.js.map +1 -1
  36. package/lib/module/hooks/natural_language_processing/useLLM.js +1 -5
  37. package/lib/module/hooks/natural_language_processing/useLLM.js.map +1 -1
  38. package/lib/module/modules/natural_language_processing/LLMModule.js +12 -7
  39. package/lib/module/modules/natural_language_processing/LLMModule.js.map +1 -1
  40. package/lib/module/types/llm.js +11 -0
  41. package/lib/module/types/llm.js.map +1 -1
  42. package/lib/typescript/constants/llmDefaults.d.ts +1 -1
  43. package/lib/typescript/constants/llmDefaults.d.ts.map +1 -1
  44. package/lib/typescript/constants/modelRegistry.d.ts +28 -1
  45. package/lib/typescript/constants/modelRegistry.d.ts.map +1 -1
  46. package/lib/typescript/constants/modelUrls.d.ts +40 -12
  47. package/lib/typescript/constants/modelUrls.d.ts.map +1 -1
  48. package/lib/typescript/controllers/LLMController.d.ts +7 -9
  49. package/lib/typescript/controllers/LLMController.d.ts.map +1 -1
  50. package/lib/typescript/modules/natural_language_processing/LLMModule.d.ts +6 -3
  51. package/lib/typescript/modules/natural_language_processing/LLMModule.d.ts.map +1 -1
  52. package/lib/typescript/types/llm.d.ts +63 -36
  53. package/lib/typescript/types/llm.d.ts.map +1 -1
  54. package/package.json +1 -1
  55. package/react-native-executorch.podspec +6 -0
  56. package/src/constants/llmDefaults.ts +1 -1
  57. package/src/constants/modelRegistry.ts +34 -2
  58. package/src/constants/modelUrls.ts +47 -6
  59. package/src/controllers/LLMController.ts +89 -40
  60. package/src/hooks/natural_language_processing/useLLM.ts +5 -6
  61. package/src/modules/natural_language_processing/LLMModule.ts +19 -8
  62. package/src/types/llm.ts +64 -34
  63. package/third-party/android/libs/executorch/arm64-v8a/libexecutorch.so +0 -0
  64. package/third-party/android/libs/executorch/x86_64/libexecutorch.so +0 -0
  65. package/third-party/include/executorch/ExecuTorch.h +2 -0
  66. package/third-party/include/executorch/ExecuTorchModule.h +46 -0
  67. package/third-party/include/executorch/extension/data_loader/buffer_data_loader.h +4 -3
  68. package/third-party/include/executorch/extension/data_loader/mman.h +46 -0
  69. package/third-party/include/executorch/extension/data_loader/mmap_data_loader.h +4 -0
  70. package/third-party/include/executorch/extension/data_loader/shared_ptr_data_loader.h +7 -3
  71. package/third-party/include/executorch/extension/module/module.h +47 -8
  72. package/third-party/include/executorch/extension/tensor/tensor_ptr.h +17 -5
  73. package/third-party/include/executorch/kernels/optimized/Functions.h +12 -0
  74. package/third-party/include/executorch/kernels/optimized/NativeFunctions.h +4 -0
  75. package/third-party/include/executorch/kernels/portable/Functions.h +18 -0
  76. package/third-party/include/executorch/kernels/portable/NativeFunctions.h +6 -0
  77. package/third-party/include/executorch/runtime/backend/backend_options_map.h +37 -0
  78. package/third-party/include/executorch/runtime/core/array_ref.h +3 -1
  79. package/third-party/include/executorch/runtime/core/error.h +1 -0
  80. package/third-party/include/executorch/runtime/core/evalue.h +256 -9
  81. package/third-party/include/executorch/runtime/core/exec_aten/exec_aten.h +24 -0
  82. package/third-party/include/executorch/runtime/core/hierarchical_allocator.h +9 -6
  83. package/third-party/include/executorch/runtime/core/portable_type/device.h +3 -4
  84. package/third-party/include/executorch/runtime/core/portable_type/tensor_impl.h +31 -1
  85. package/third-party/include/executorch/runtime/executor/method.h +9 -3
  86. package/third-party/include/executorch/runtime/executor/method_meta.h +14 -0
  87. package/third-party/include/executorch/runtime/executor/platform_memory_allocator.h +12 -2
  88. package/third-party/include/executorch/runtime/executor/program.h +3 -1
  89. package/third-party/include/executorch/runtime/executor/tensor_parser.h +5 -1
  90. package/third-party/include/executorch/runtime/kernel/operator_registry.h +9 -0
  91. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/ExecutorchLib +0 -0
  92. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/Info.plist +0 -0
  93. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/mlx.metallib +0 -0
  94. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/ExecutorchLib +0 -0
  95. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/Info.plist +0 -0
  96. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/mlx.metallib +0 -0
@@ -35,6 +35,7 @@
35
35
  #include "sampler.h"
36
36
  #include <algorithm>
37
37
  #include <ctime>
38
+ #include <limits>
38
39
  #include <vector>
39
40
 
40
41
  namespace executorch {
@@ -46,7 +47,7 @@ template <typename T> int32_t Sampler::sample_argmax(T *probabilities) {
46
47
  // return the index that has the highest probability
47
48
  int max_i = 0;
48
49
  T max_p = probabilities[0];
49
- for (int i = 1; i < vocab_size_; i++) {
50
+ for (size_t i = 1; i < vocab_size_; i++) {
50
51
  if (probabilities[i] > max_p) {
51
52
  max_i = i;
52
53
  max_p = probabilities[i];
@@ -60,7 +61,7 @@ int32_t Sampler::sample_mult(T *probabilities, float coin) {
60
61
  // sample index from probabilities (they must sum to 1!)
61
62
  // coin is a random number in [0, 1), usually from random_f32()
62
63
  T cdf = 0.0;
63
- for (int i = 0; i < vocab_size_; i++) {
64
+ for (size_t i = 0; i < vocab_size_; i++) {
64
65
  cdf += probabilities[i];
65
66
  if (coin < cdf) {
66
67
  return i;
@@ -84,7 +85,7 @@ int32_t Sampler::sample_topp(T *probabilities, float coin) {
84
85
  std::make_unique<ProbIndex<T>[]>(vocab_size_);
85
86
 
86
87
  const float cutoff = (1.0f - topp_) / (n - 1);
87
- for (int i = 0; i < n; i++) {
88
+ for (size_t i = 0; i < n; i++) {
88
89
  if (probabilities[i] >= cutoff) {
89
90
  probindex[n0].index = i;
90
91
  probindex[n0].prob = probabilities[i];
@@ -92,61 +93,138 @@ int32_t Sampler::sample_topp(T *probabilities, float coin) {
92
93
  }
93
94
  }
94
95
 
95
- auto compare = [](const ProbIndex<T> &a, const ProbIndex<T> &b) {
96
- return a.prob > b.prob;
97
- };
98
- std::sort(probindex.get(), probindex.get() + n0, compare);
96
+ std::sort(probindex.get(), probindex.get() + n0,
97
+ [](const ProbIndex<T> &a, const ProbIndex<T> &b) {
98
+ return a.prob > b.prob;
99
+ });
99
100
 
100
101
  // truncate the list where cumulative probability exceeds topp
101
102
  T cumulative_prob = 0;
102
- int last_idx = n0 - 1; // in case of rounding errors consider all elements
103
- for (int i = 0; i < n0; i++) {
103
+ int last_idx = n0 - 1;
104
+ for (size_t i = 0; i < n0; i++) {
104
105
  cumulative_prob += probindex[i].prob;
105
- if (cumulative_prob > topp_) {
106
+ if (static_cast<float>(cumulative_prob) > topp_) {
106
107
  last_idx = i;
107
- break; // we've exceeded topp by including last_idx
108
+ break;
108
109
  }
109
110
  }
110
111
 
111
112
  // sample from the truncated list
112
- const T &r = coin * cumulative_prob;
113
+ float r = coin * static_cast<float>(cumulative_prob);
113
114
  T cdf = 0;
114
- for (int i = 0; i <= last_idx; i++) {
115
+ for (size_t i = 0; i <= last_idx; i++) {
115
116
  cdf += probindex[i].prob;
116
- if (r < cdf) {
117
+ if (r < static_cast<float>(cdf)) {
117
118
  return probindex[i].index;
118
119
  }
119
120
  }
120
- return probindex[last_idx].index; // in case of rounding errors
121
+ return probindex[last_idx].index;
121
122
  }
122
123
 
123
- Sampler::Sampler(int32_t vocab_size, float temperature, float topp,
124
- unsigned long long rng_seed, float min_p,
125
- float repetition_penalty)
124
+ // Mask logits outside the top-k by rank to -inf. Ties at the k-th boundary
125
+ // are kept (matches HuggingFace TopKLogitsWarper).
126
+ template <typename T> void Sampler::mask_topk(T *logits) {
127
+ if (topk_ <= 0 || topk_ >= vocab_size_) {
128
+ return;
129
+ }
130
+ // Partial-select the (topk_-th largest) threshold using nth_element on a
131
+ // copy of logits; O(n) average.
132
+ std::vector<T> scratch(logits, logits + vocab_size_);
133
+ std::nth_element(scratch.begin(), scratch.begin() + (topk_ - 1),
134
+ scratch.end(), std::greater<T>());
135
+ const T threshold = scratch[topk_ - 1];
136
+ constexpr T neg_inf = std::numeric_limits<T>::lowest();
137
+ for (size_t i = 0; i < vocab_size_; i++) {
138
+ if (logits[i] < threshold) {
139
+ logits[i] = neg_inf;
140
+ }
141
+ }
142
+ }
143
+
144
+ // Mask logits whose softmax-prob falls outside the top-p nucleus to -inf.
145
+ // Keeps the token that crosses the threshold (HuggingFace convention).
146
+ template <typename T> void Sampler::mask_topp(T *logits) {
147
+ if (topp_ <= 0.0f || topp_ >= 1.0f) {
148
+ return;
149
+ }
150
+ // Softmax into a scratch probs[] (do not mutate logits yet).
151
+ T max_val = logits[0];
152
+ for (size_t i = 1; i < vocab_size_; i++) {
153
+ if (logits[i] > max_val) {
154
+ max_val = logits[i];
155
+ }
156
+ }
157
+ std::unique_ptr<ProbIndex<T>[]> probindex =
158
+ std::make_unique<ProbIndex<T>[]>(vocab_size_);
159
+ T sum = 0;
160
+ for (size_t i = 0; i < vocab_size_; i++) {
161
+ T e = static_cast<T>(std::expf(static_cast<float>(logits[i] - max_val)));
162
+ probindex[i].prob = e;
163
+ probindex[i].index = i;
164
+ sum += e;
165
+ }
166
+ if (sum <= T(0)) {
167
+ return;
168
+ }
169
+ for (size_t i = 0; i < vocab_size_; i++) {
170
+ probindex[i].prob /= sum;
171
+ }
172
+ std::sort(probindex.get(), probindex.get() + vocab_size_,
173
+ [](const ProbIndex<T> &a, const ProbIndex<T> &b) {
174
+ return a.prob > b.prob;
175
+ });
176
+
177
+ // Find the smallest prefix whose cumulative probability >= topp_.
178
+ T cumulative = 0;
179
+ int last_idx = vocab_size_ - 1;
180
+ for (size_t i = 0; i < vocab_size_; i++) {
181
+ cumulative += probindex[i].prob;
182
+ if (static_cast<float>(cumulative) >= topp_) {
183
+ last_idx = i;
184
+ break;
185
+ }
186
+ }
187
+ // Mark kept indices, then -inf the rest.
188
+ std::vector<bool> keep(vocab_size_, false);
189
+ for (size_t i = 0; i <= last_idx; i++) {
190
+ keep[probindex[i].index] = true;
191
+ }
192
+ constexpr T neg_inf = std::numeric_limits<T>::lowest();
193
+ for (size_t i = 0; i < vocab_size_; i++) {
194
+ if (!keep[i]) {
195
+ logits[i] = neg_inf;
196
+ }
197
+ }
198
+ }
199
+
200
+ Sampler::Sampler(int32_t vocab_size, GenerationConfig config,
201
+ unsigned long long rng_seed)
126
202
  : vocab_size_(vocab_size),
127
- inv_temperature_((temperature != 0.0f) ? (1.0f / temperature) : 0.0f),
128
- topp_(topp), min_p_(min_p), repetition_penalty_(repetition_penalty),
203
+ inv_temperature_(
204
+ (config.temperature != 0.0f) ? (1.0f / config.temperature) : 0.0f),
205
+ topp_(config.topp), min_p_(config.min_p),
206
+ repetition_penalty_(config.repetition_penalty), topk_(config.topk),
129
207
  rng_state_(rng_seed) {}
130
208
 
131
- Sampler::Sampler(int vocab_size, float temperature, float topp)
132
- : Sampler(vocab_size, temperature, topp, std::time(nullptr), 0.0f, 1.0f) {}
209
+ Sampler::Sampler(int32_t vocab_size, GenerationConfig config)
210
+ : Sampler(vocab_size, config, std::time(nullptr)) {}
133
211
 
134
212
  template <typename T> static void softmax(T *x, int size) {
135
213
  // find max value (for numerical stability)
136
214
  T max_val = x[0];
137
- for (int i = 1; i < size; i++) {
215
+ for (size_t i = 1; i < size; i++) {
138
216
  if (x[i] > max_val) {
139
217
  max_val = x[i];
140
218
  }
141
219
  }
142
220
  // exp and sum
143
221
  T sum = 0;
144
- for (int i = 0; i < size; i++) {
222
+ for (size_t i = 0; i < size; i++) {
145
223
  x[i] = expf(x[i] - max_val);
146
224
  sum += x[i];
147
225
  }
148
226
  // normalize
149
- for (int i = 0; i < size; i++) {
227
+ for (size_t i = 0; i < size; i++) {
150
228
  x[i] /= sum;
151
229
  }
152
230
  }
@@ -175,20 +253,18 @@ int32_t Sampler::sample(T *logits, const std::vector<uint64_t> &recent_tokens) {
175
253
  apply_repetition_penalty(logits, vocab_size_, recent_tokens);
176
254
  // 2. apply the temperature to the logits
177
255
  apply_temperature(logits, vocab_size_);
178
- // 3. apply softmax to the logits to get the probabilities for next token
256
+ // 3. mask out logits outside top-k by rank (pre-softmax, becomes 0 mass)
257
+ mask_topk(logits);
258
+ // 4. mask out logits outside top-p by rank (pre-softmax)
259
+ mask_topp(logits);
260
+ // 5. apply softmax to the logits to get the probabilities for next token
179
261
  softmax(logits, vocab_size_);
180
- // 4. apply min_p truncation
262
+ // 6. apply min_p truncation
181
263
  apply_min_p(logits, vocab_size_);
182
264
  // flip a (float) coin (this is our source of entropy for sampling)
183
265
  float coin = random_f32(&rng_state_);
184
- // 5. we sample from this distribution to get the next token
185
- if (topp_ <= 0 || topp_ >= 1) {
186
- // simply sample from the predicted probability distribution
187
- next = sample_mult(logits, coin);
188
- } else {
189
- // top-p (nucleus) sampling, clamping the least likely tokens to zero
190
- next = sample_topp(logits, coin);
191
- }
266
+ // 7. we sample from this distribution to get the next token
267
+ next = sample_mult(logits, coin);
192
268
  }
193
269
  return next;
194
270
  }
@@ -8,6 +8,7 @@
8
8
 
9
9
  #pragma once
10
10
 
11
+ #include "runner/irunner.h"
11
12
  #include <algorithm>
12
13
  #include <cctype>
13
14
  #include <cmath>
@@ -28,6 +29,7 @@ namespace executorch {
28
29
  namespace extension {
29
30
  namespace llm {
30
31
  // A simple llama2 sampler.
32
+ struct GenerationConfig;
31
33
 
32
34
  inline constexpr auto kTopp = 0.9f;
33
35
 
@@ -38,11 +40,13 @@ template <typename T> struct ProbIndex {
38
40
 
39
41
  class Sampler {
40
42
  public:
41
- Sampler(int32_t vocab_size, float temperature, float topp,
42
- unsigned long long rng_seed, float min_p = 0.0f,
43
- float repetition_penalty = 1.0f);
44
-
45
- Sampler(int32_t vocab_size, float temperature, float topp);
43
+ // topk <= 0 disables top-k filtering. topp <= 0 || topp >= 1 disables top-p.
44
+ // Pipeline when temperature != 0: temperature -> top-k mask -> top-p mask
45
+ // -> softmax -> multinomial. Note: topk == 1 with temperature != 0 collapses
46
+ // to greedy; pass topk = 0 to keep full-vocab temperature sampling.
47
+ Sampler(int32_t vocab_size, GenerationConfig config,
48
+ unsigned long long rng_seed);
49
+ Sampler(int32_t vocab_size, GenerationConfig config);
46
50
 
47
51
  template <typename T> int32_t sample(T *logits);
48
52
 
@@ -53,6 +57,9 @@ private:
53
57
  template <typename T> int32_t sample_topp(T *probabilities, float coin);
54
58
  template <typename T> int32_t sample_mult(T *probabilities, float coin);
55
59
  template <typename T> int32_t sample_argmax(T *probabilities);
60
+ // In-place logit warpers: set excluded indices to -inf.
61
+ template <typename T> void mask_topk(T *logits);
62
+ template <typename T> void mask_topp(T *logits);
56
63
 
57
64
  template <typename T>
58
65
  inline void apply_temperature(T *logits, int32_t vocab_size) {
@@ -110,6 +117,7 @@ private:
110
117
  float topp_;
111
118
  float min_p_;
112
119
  float repetition_penalty_;
120
+ int32_t topk_;
113
121
  unsigned long long rng_state_;
114
122
  };
115
123
 
@@ -31,7 +31,6 @@ TextDecoderRunner::TextDecoderRunner(Module &module, IOManager *io_manager,
31
31
  // outer loop (call site) is responsible for managing state.
32
32
  ::executorch::runtime::Result<executorch::aten::Tensor>
33
33
  TextDecoderRunner::step(TensorPtr &tokens, int64_t start_pos) {
34
- // ET_LOG(Info, "Input token %" PRIu64, input_token);
35
34
  auto method_meta_result = module_->method_meta("forward");
36
35
  if (!method_meta_result.ok()) {
37
36
  return method_meta_result.error();
@@ -102,9 +101,7 @@ int32_t TextDecoderRunner::logits_to_token(
102
101
  auto num_tokens = logits_tensor.size(1);
103
102
  logits += (num_tokens - 1) * vocab_size;
104
103
  }
105
- Sampler sampler(vocab_size, config_.temperature, config_.topp,
106
- static_cast<unsigned long long>(std::time(nullptr)),
107
- config_.min_p, config_.repetition_penalty);
104
+ Sampler sampler(vocab_size, config_);
108
105
  result = sampler.sample(logits, recent_tokens);
109
106
  });
110
107
  return result;
@@ -10,6 +10,7 @@
10
10
 
11
11
  #pragma once
12
12
 
13
+ #include "constants.h"
13
14
  #include "io_manager.h"
14
15
  #include "sampler.h"
15
16
 
@@ -40,8 +41,8 @@ public:
40
41
  step(TensorPtr &input, int64_t start_pos);
41
42
 
42
43
  /**
43
- * Load the Module for text decode purpose.
44
- * @return The error code.
44
+ * Load the Module for text decode purpose. Loads the dynamic-shape `forward`
45
+ * method used for both prefill and decode.
45
46
  */
46
47
  virtual ::executorch::runtime::Error load() {
47
48
  return module_->load_method("forward");
@@ -18,10 +18,11 @@ namespace llm {
18
18
 
19
19
  TextPrefiller::TextPrefiller(TextDecoderRunner *text_decoder_runner,
20
20
  bool use_kv_cache, bool enable_parallel_prefill,
21
- int64_t max_seq_len)
21
+ int64_t max_seq_len, int32_t prefill_chunk_size)
22
22
  : text_decoder_runner_(text_decoder_runner), use_kv_cache_(use_kv_cache),
23
23
  enable_parallel_prefill_(enable_parallel_prefill),
24
- max_seq_len_(max_seq_len > 0 ? max_seq_len : 128) {}
24
+ max_seq_len_(max_seq_len > 0 ? max_seq_len : 128),
25
+ prefill_chunk_size_(prefill_chunk_size) {}
25
26
 
26
27
  ::executorch::runtime::Result<uint64_t>
27
28
  TextPrefiller::prefill(std::vector<uint64_t> &prompt_tokens,
@@ -31,17 +32,17 @@ TextPrefiller::prefill(std::vector<uint64_t> &prompt_tokens,
31
32
  ET_CHECK_OK_OR_RETURN_ERROR(text_decoder_runner_->load());
32
33
  }
33
34
 
34
- // Check if we need to chunk the prompt tokens
35
35
  int32_t num_prompt_tokens = prompt_tokens.size();
36
+ int32_t chunk_size =
37
+ prefill_chunk_size_ > 0 ? prefill_chunk_size_ : max_seq_len_;
36
38
 
37
- // If prompt tokens exceed max_seq_len_, we need to chunk them
38
- if (num_prompt_tokens > max_seq_len_) {
39
+ if (num_prompt_tokens > chunk_size) {
39
40
  uint64_t cur_token = 0;
40
41
  int num_tokens_to_process = 0;
41
42
 
42
43
  while (num_tokens_to_process < num_prompt_tokens) {
43
- auto num_tokens_to_prefill_with = std::min<int>(
44
- num_prompt_tokens - num_tokens_to_process, max_seq_len_);
44
+ auto num_tokens_to_prefill_with =
45
+ std::min<int>(num_prompt_tokens - num_tokens_to_process, chunk_size);
45
46
 
46
47
  std::vector<uint64_t> prompt_tokens_to_process(
47
48
  num_tokens_to_prefill_with);
@@ -75,7 +76,6 @@ TextPrefiller::prefill_chunk(std::vector<uint64_t> &prompt_tokens,
75
76
  // store the token
76
77
  uint64_t cur_token;
77
78
  if (enable_parallel_prefill_ || !use_kv_cache_) {
78
- // initialize tensor wrappers
79
79
  auto tokens = from_blob(prompt_tokens.data(), {1, num_prompt_tokens},
80
80
  executorch::aten::ScalarType::Long);
81
81
 
@@ -19,8 +19,14 @@ namespace llm {
19
19
 
20
20
  class TextPrefiller {
21
21
  public:
22
+ // prefill_chunk_size: when > 0, the prompt is always processed in steps of
23
+ // this size (see prefill()). Set to the model's forward sequence-length cap
24
+ // for the MLX backend (its forward is exported with a sliding-window bound
25
+ // and one-shot prefill spikes Metal memory). Other backends (XNNPACK/CoreML)
26
+ // pass 0 → original one-shot behavior.
22
27
  TextPrefiller(TextDecoderRunner *text_decoder_runner, bool use_kv_cache,
23
- bool enable_parallel_prefill, int64_t max_seq_len = 128);
28
+ bool enable_parallel_prefill, int64_t max_seq_len = 128,
29
+ int32_t prefill_chunk_size = 0);
24
30
 
25
31
  virtual ~TextPrefiller() = default;
26
32
  /**
@@ -70,6 +76,7 @@ private:
70
76
  bool use_kv_cache_;
71
77
  bool enable_parallel_prefill_;
72
78
  int64_t max_seq_len_;
79
+ int32_t prefill_chunk_size_;
73
80
  };
74
81
 
75
82
  } // namespace llm
@@ -26,11 +26,24 @@ Error TextRunner::load_subcomponents() {
26
26
 
27
27
  Stats *stats_ptr = &stats_;
28
28
 
29
- text_decoder_runner_ = std::make_unique<TextDecoderRunner>(
30
- *module_, io_manager_.get(), config_);
29
+ text_decoder_runner_ =
30
+ std::make_unique<TextDecoderRunner>(*module_, io_manager_.get(), config_);
31
+
32
+ int32_t prefill_chunk_size = 0;
33
+ auto fwd_meta = module_->method_meta("forward");
34
+ if (fwd_meta.ok() && fwd_meta->uses_backend("MLXBackend")) {
35
+ auto input_meta = fwd_meta->input_tensor_meta(0);
36
+ if (input_meta.ok()) {
37
+ auto sizes = input_meta->sizes();
38
+ if (sizes.size() >= 2 && sizes[sizes.size() - 1] > 0) {
39
+ prefill_chunk_size = sizes[sizes.size() - 1];
40
+ }
41
+ }
42
+ }
43
+
31
44
  text_prefiller_ = std::make_unique<TextPrefiller>(
32
45
  text_decoder_runner_.get(), config_.enable_kv_cache,
33
- config_.enable_dynamic_shape, config_.max_seq_len);
46
+ config_.enable_dynamic_shape, config_.max_seq_len, prefill_chunk_size);
34
47
  text_token_generator_ = std::make_unique<TextTokenGenerator>(
35
48
  tokenizer_.get(), text_decoder_runner_.get(), config_.enable_kv_cache,
36
49
  std::move(eos_ids_), stats_ptr, config_);
@@ -65,6 +78,10 @@ Error TextRunner::generate_internal(
65
78
 
66
79
  stats_.inference_start_ms = time_in_ms();
67
80
 
81
+ // Multi-turn: JS re-renders the full chat history each call, so reset KV
82
+ // position to 0 and re-prefill from scratch.
83
+ pos_ = 0;
84
+
68
85
  int64_t context_len_left =
69
86
  static_cast<int64_t>(config_.max_context_length) - pos_;
70
87
 
@@ -79,16 +96,25 @@ Error TextRunner::generate_internal(
79
96
  std::vector<uint64_t> prompt_tokens = encodeResult.get();
80
97
  int num_prompt_tokens = prompt_tokens.size();
81
98
 
99
+ // For dynamic-shape PTEs (e.g. Gemma4 MLX/Vulkan), get_max_seq_len is the
100
+ // per-call decoder chunk size (e.g. the sliding window) and the real
101
+ // generation budget lives in get_max_context_len. Static-shape PTEs set both
102
+ // equal, so this collapses to the old behavior. Without this the budget is
103
+ // computed from the small chunk size, so max_new_tokens can resolve to ~0 and
104
+ // generation ends immediately after prefill.
105
+ const int32_t seq_cap = config_.enable_dynamic_shape
106
+ ? config_.max_context_length
107
+ : config_.max_seq_len;
108
+
82
109
  ET_CHECK_OR_RETURN_ERROR(num_prompt_tokens >= 1, InvalidArgument,
83
110
  "Expected at least 1 prompt token");
84
- ET_CHECK_OR_RETURN_ERROR(num_prompt_tokens < config_.max_seq_len,
85
- InvalidArgument,
86
- "num_prompt_tokens %d >= max_seq_len %" PRId32,
87
- num_prompt_tokens, config_.max_seq_len);
111
+ ET_CHECK_OR_RETURN_ERROR(num_prompt_tokens < seq_cap, InvalidArgument,
112
+ "num_prompt_tokens %d >= seq cap %" PRId32,
113
+ num_prompt_tokens, seq_cap);
88
114
 
89
115
  int32_t max_new_tokens = resolve_max_new_tokens(
90
- num_prompt_tokens, config_.max_seq_len,
91
- static_cast<int32_t>(context_len_left), config_.max_new_tokens);
116
+ num_prompt_tokens, seq_cap, static_cast<int32_t>(context_len_left),
117
+ config_.max_new_tokens);
92
118
 
93
119
  ET_CHECK_OR_RETURN_ERROR(max_new_tokens > 0, InvalidArgument,
94
120
  "Max new tokens %d is <= 0", max_new_tokens);
@@ -100,8 +100,8 @@ public:
100
100
  prev_token = cur_token;
101
101
 
102
102
  stats_->on_sampling_begin();
103
- cur_token =
104
- text_decoder_runner_->logits_to_token(logits_tensor, generated_tokens);
103
+ cur_token = text_decoder_runner_->logits_to_token(logits_tensor,
104
+ generated_tokens);
105
105
  stats_->on_sampling_end();
106
106
 
107
107
  pos++;
@@ -152,7 +152,6 @@ public:
152
152
  if (should_stop_) {
153
153
  break;
154
154
  }
155
-
156
155
  // data-dependent terminating condition: we have n_eos_ number of EOS
157
156
  if (eos_ids_->find(cur_token) != eos_ids_->end()) {
158
157
  printf("\n");
@@ -8,7 +8,6 @@
8
8
 
9
9
  #pragma once
10
10
  #include "constants.h"
11
- #include "text_prefiller.h"
12
11
  #include <cctype>
13
12
  #include <executorch/extension/module/module.h>
14
13
  #include <executorch/extension/tensor/tensor.h>
@@ -6,7 +6,7 @@ import { SlidingWindowContextStrategy } from '../utils/llms/context_strategy';
6
6
  * Default system prompt used to guide the behavior of Large Language Models (LLMs).
7
7
  * @category Utilities - LLM
8
8
  */
9
- export const DEFAULT_SYSTEM_PROMPT = "You are a knowledgeable, efficient, and direct AI assistant. Provide concise answers, focusing on the key information needed. Offer suggestions tactfully when appropriate to improve outcomes. Engage in productive collaboration with the user. Don't return too much text.";
9
+ export const DEFAULT_SYSTEM_PROMPT = "You are a knowledgeable, efficient, and direct AI assistant. Provide concise answers, focusing on the key information needed. Offer suggestions tactfully when appropriate to improve outcomes. Engage in productive collaboration with the user. Don't return too much text. If provided with audio samples treat it with at most importance";
10
10
 
11
11
  /**
12
12
  * Generates a default structured output prompt based on the provided JSON schema.
@@ -1 +1 @@
1
- {"version":3,"names":["SlidingWindowContextStrategy","DEFAULT_SYSTEM_PROMPT","DEFAULT_STRUCTURED_OUTPUT_PROMPT","structuredOutputSchema","DEFAULT_MESSAGE_HISTORY","DEFAULT_CONTEXT_BUFFER_TOKENS","DEFAULT_CHAT_CONFIG","systemPrompt","initialMessageHistory","contextStrategy"],"sourceRoot":"../../../src","sources":["constants/llmDefaults.ts"],"mappings":";;AACA,SAASA,4BAA4B,QAAQ,gCAAgC;;AAE7E;AACA;AACA;AACA;AACA,OAAO,MAAMC,qBAAqB,GAChC,+QAA+Q;;AAEjR;AACA;AACA;AACA;AACA;AACA;AACA,OAAO,MAAMC,gCAAgC,GAC3CC,sBAA8B,IAC3B;AACL;AACA;AACA;AACA;AACA;AACA,EAAEA,sBAAsB;AACxB,CAAC;;AAED;AACA;AACA;AACA;AACA,OAAO,MAAMC,uBAAkC,GAAG,EAAE;;AAEpD;AACA;AACA;AACA;AACA,OAAO,MAAMC,6BAA6B,GAAG,GAAG;;AAEhD;AACA;AACA;AACA;AACA,OAAO,MAAMC,mBAA+B,GAAG;EAC7CC,YAAY,EAAEN,qBAAqB;EACnCO,qBAAqB,EAAEJ,uBAAuB;EAC9CK,eAAe,EAAE,IAAIT,4BAA4B,CAC/CK,6BACF;AACF,CAAC","ignoreList":[]}
1
+ {"version":3,"names":["SlidingWindowContextStrategy","DEFAULT_SYSTEM_PROMPT","DEFAULT_STRUCTURED_OUTPUT_PROMPT","structuredOutputSchema","DEFAULT_MESSAGE_HISTORY","DEFAULT_CONTEXT_BUFFER_TOKENS","DEFAULT_CHAT_CONFIG","systemPrompt","initialMessageHistory","contextStrategy"],"sourceRoot":"../../../src","sources":["constants/llmDefaults.ts"],"mappings":";;AACA,SAASA,4BAA4B,QAAQ,gCAAgC;;AAE7E;AACA;AACA;AACA;AACA,OAAO,MAAMC,qBAAqB,GAChC,+UAA+U;;AAEjV;AACA;AACA;AACA;AACA;AACA;AACA,OAAO,MAAMC,gCAAgC,GAC3CC,sBAA8B,IAC3B;AACL;AACA;AACA;AACA;AACA;AACA,EAAEA,sBAAsB;AACxB,CAAC;;AAED;AACA;AACA;AACA;AACA,OAAO,MAAMC,uBAAkC,GAAG,EAAE;;AAEpD;AACA;AACA;AACA;AACA,OAAO,MAAMC,6BAA6B,GAAG,GAAG;;AAEhD;AACA;AACA;AACA;AACA,OAAO,MAAMC,mBAA+B,GAAG;EAC7CC,YAAY,EAAEN,qBAAqB;EACnCO,qBAAqB,EAAEJ,uBAAuB;EAC9CK,eAAe,EAAE,IAAIT,4BAA4B,CAC/CK,6BACF;AACF,CAAC","ignoreList":[]}
@@ -26,7 +26,7 @@ import { RnExecutorchErrorCode } from '../errors/ErrorCodes';
26
26
 
27
27
  // Accessors are functions; calling with no opts returns the platform default.
28
28
 
29
- const BACKEND_ORDER = ['xnnpack', 'coreml', 'vulkan', 'qnn'];
29
+ const BACKEND_ORDER = ['xnnpack', 'coreml', 'mlx', 'vulkan', 'qnn'];
30
30
  function firstBackend(variants) {
31
31
  for (const b of BACKEND_ORDER) {
32
32
  if (variants[b]) return b;
@@ -107,6 +107,32 @@ function tts(c) {
107
107
  // Per-backend variant maps for models that ship more than one backend.
108
108
  // ─────────────────────────────────────────────────────────────────────────────
109
109
 
110
+ const GEMMA4_E2B_VARIANTS = {
111
+ mlx: {
112
+ base: {
113
+ modelName: 'gemma4-e2b',
114
+ modelSource: M.GEMMA4_E2B_MLX_MODEL,
115
+ tokenizerSource: M.GEMMA4_E2B_TOKENIZER,
116
+ tokenizerConfigSource: M.GEMMA4_E2B_TOKENIZER_CONFIG
117
+ }
118
+ },
119
+ xnnpack: {
120
+ base: {
121
+ modelName: 'gemma4-e2b',
122
+ modelSource: M.GEMMA4_E2B_XNNPACK_MODEL,
123
+ tokenizerSource: M.GEMMA4_E2B_TOKENIZER,
124
+ tokenizerConfigSource: M.GEMMA4_E2B_TOKENIZER_CONFIG
125
+ }
126
+ },
127
+ vulkan: {
128
+ base: {
129
+ modelName: 'gemma4-e2b',
130
+ modelSource: M.GEMMA4_E2B_VULKAN_MODEL,
131
+ tokenizerSource: M.GEMMA4_E2B_TOKENIZER,
132
+ tokenizerConfigSource: M.GEMMA4_E2B_TOKENIZER_CONFIG
133
+ }
134
+ }
135
+ };
110
136
  const EFFICIENTNET_V2_S_VARIANTS = {
111
137
  xnnpack: {
112
138
  base: {
@@ -331,10 +357,15 @@ export const models = {
331
357
  lfm2_5_350m: pair(M.LFM2_5_350M, M.LFM2_5_350M_QUANTIZED),
332
358
  lfm2_5_1_2b_instruct: pair(M.LFM2_5_1_2B_INSTRUCT, M.LFM2_5_1_2B_INSTRUCT_QUANTIZED),
333
359
  bielik_v3_0_1_5b: pair(M.BIELIK_V3_0_1_5B, M.BIELIK_V3_0_1_5B_QUANTIZED),
360
+ gemma4_e2b: variant(GEMMA4_E2B_VARIANTS, {
361
+ ios: 'mlx',
362
+ android: 'vulkan'
363
+ }),
334
364
  // Multimodal LLMs — same hook/module as plain LLMs, listed here so users
335
365
  // pick a model by capability ("LLM") rather than by modality.
336
366
  lfm2_5_vl_1_6b: base(M.LFM2_5_VL_1_6B_QUANTIZED),
337
- lfm2_5_vl_450m: base(M.LFM2_5_VL_450M_QUANTIZED)
367
+ lfm2_5_vl_450m: base(M.LFM2_5_VL_450M_QUANTIZED),
368
+ gemma4_e2b_multimodal: base(M.GEMMA4_E2B_MM)
338
369
  },
339
370
  classification: {
340
371
  efficientnet_v2_s: variant(EFFICIENTNET_V2_S_VARIANTS)