react-native-executorch 0.9.0 → 0.9.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/libs/classes.jar +0 -0
- package/common/rnexecutorch/host_objects/JsiConversions.h +43 -0
- package/common/rnexecutorch/models/llm/LLM.cpp +55 -42
- package/common/rnexecutorch/models/llm/LLM.h +4 -3
- package/common/rnexecutorch/models/llm/Types.h +23 -0
- package/common/runner/base_llm_runner.cpp +10 -3
- package/common/runner/base_llm_runner.h +1 -0
- package/common/runner/constants.h +15 -1
- package/common/runner/encoders/audio_encoder.cpp +111 -0
- package/common/runner/encoders/audio_encoder.h +40 -0
- package/common/runner/encoders/vision_encoder.cpp +0 -1
- package/common/runner/irunner.h +5 -0
- package/common/runner/multimodal_decoder_runner.h +50 -1
- package/common/runner/multimodal_input.h +16 -1
- package/common/runner/multimodal_prefiller.cpp +374 -64
- package/common/runner/multimodal_prefiller.h +57 -6
- package/common/runner/multimodal_runner.cpp +19 -12
- package/common/runner/multimodal_runner.h +1 -1
- package/common/runner/sampler.cpp +111 -35
- package/common/runner/sampler.h +13 -5
- package/common/runner/text_decoder_runner.cpp +1 -4
- package/common/runner/text_decoder_runner.h +3 -2
- package/common/runner/text_prefiller.cpp +8 -8
- package/common/runner/text_prefiller.h +8 -1
- package/common/runner/text_runner.cpp +35 -9
- package/common/runner/text_token_generator.h +2 -3
- package/common/runner/util.h +0 -1
- package/lib/module/constants/llmDefaults.js +1 -1
- package/lib/module/constants/llmDefaults.js.map +1 -1
- package/lib/module/constants/modelRegistry.js +33 -2
- package/lib/module/constants/modelRegistry.js.map +1 -1
- package/lib/module/constants/modelUrls.js +43 -6
- package/lib/module/constants/modelUrls.js.map +1 -1
- package/lib/module/controllers/LLMController.js +69 -20
- package/lib/module/controllers/LLMController.js.map +1 -1
- package/lib/module/hooks/natural_language_processing/useLLM.js +1 -5
- package/lib/module/hooks/natural_language_processing/useLLM.js.map +1 -1
- package/lib/module/modules/natural_language_processing/LLMModule.js +12 -7
- package/lib/module/modules/natural_language_processing/LLMModule.js.map +1 -1
- package/lib/module/types/llm.js +11 -0
- package/lib/module/types/llm.js.map +1 -1
- package/lib/typescript/constants/llmDefaults.d.ts +1 -1
- package/lib/typescript/constants/llmDefaults.d.ts.map +1 -1
- package/lib/typescript/constants/modelRegistry.d.ts +28 -1
- package/lib/typescript/constants/modelRegistry.d.ts.map +1 -1
- package/lib/typescript/constants/modelUrls.d.ts +40 -12
- package/lib/typescript/constants/modelUrls.d.ts.map +1 -1
- package/lib/typescript/controllers/LLMController.d.ts +7 -9
- package/lib/typescript/controllers/LLMController.d.ts.map +1 -1
- package/lib/typescript/modules/natural_language_processing/LLMModule.d.ts +6 -3
- package/lib/typescript/modules/natural_language_processing/LLMModule.d.ts.map +1 -1
- package/lib/typescript/types/llm.d.ts +63 -36
- package/lib/typescript/types/llm.d.ts.map +1 -1
- package/package.json +1 -1
- package/react-native-executorch.podspec +6 -0
- package/src/constants/llmDefaults.ts +1 -1
- package/src/constants/modelRegistry.ts +34 -2
- package/src/constants/modelUrls.ts +47 -6
- package/src/controllers/LLMController.ts +89 -40
- package/src/hooks/natural_language_processing/useLLM.ts +5 -6
- package/src/modules/natural_language_processing/LLMModule.ts +19 -8
- package/src/types/llm.ts +64 -34
- package/third-party/android/libs/executorch/arm64-v8a/libexecutorch.so +0 -0
- package/third-party/android/libs/executorch/x86_64/libexecutorch.so +0 -0
- package/third-party/include/executorch/ExecuTorch.h +2 -0
- package/third-party/include/executorch/ExecuTorchModule.h +46 -0
- package/third-party/include/executorch/extension/data_loader/buffer_data_loader.h +4 -3
- package/third-party/include/executorch/extension/data_loader/mman.h +46 -0
- package/third-party/include/executorch/extension/data_loader/mmap_data_loader.h +4 -0
- package/third-party/include/executorch/extension/data_loader/shared_ptr_data_loader.h +7 -3
- package/third-party/include/executorch/extension/module/module.h +47 -8
- package/third-party/include/executorch/extension/tensor/tensor_ptr.h +17 -5
- package/third-party/include/executorch/kernels/optimized/Functions.h +12 -0
- package/third-party/include/executorch/kernels/optimized/NativeFunctions.h +4 -0
- package/third-party/include/executorch/kernels/portable/Functions.h +18 -0
- package/third-party/include/executorch/kernels/portable/NativeFunctions.h +6 -0
- package/third-party/include/executorch/runtime/backend/backend_options_map.h +37 -0
- package/third-party/include/executorch/runtime/core/array_ref.h +3 -1
- package/third-party/include/executorch/runtime/core/error.h +1 -0
- package/third-party/include/executorch/runtime/core/evalue.h +256 -9
- package/third-party/include/executorch/runtime/core/exec_aten/exec_aten.h +24 -0
- package/third-party/include/executorch/runtime/core/hierarchical_allocator.h +9 -6
- package/third-party/include/executorch/runtime/core/portable_type/device.h +3 -4
- package/third-party/include/executorch/runtime/core/portable_type/tensor_impl.h +31 -1
- package/third-party/include/executorch/runtime/executor/method.h +9 -3
- package/third-party/include/executorch/runtime/executor/method_meta.h +14 -0
- package/third-party/include/executorch/runtime/executor/platform_memory_allocator.h +12 -2
- package/third-party/include/executorch/runtime/executor/program.h +3 -1
- package/third-party/include/executorch/runtime/executor/tensor_parser.h +5 -1
- package/third-party/include/executorch/runtime/kernel/operator_registry.h +9 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/ExecutorchLib +0 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/Info.plist +0 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/mlx.metallib +0 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/ExecutorchLib +0 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/Info.plist +0 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/mlx.metallib +0 -0
|
@@ -35,6 +35,7 @@
|
|
|
35
35
|
#include "sampler.h"
|
|
36
36
|
#include <algorithm>
|
|
37
37
|
#include <ctime>
|
|
38
|
+
#include <limits>
|
|
38
39
|
#include <vector>
|
|
39
40
|
|
|
40
41
|
namespace executorch {
|
|
@@ -46,7 +47,7 @@ template <typename T> int32_t Sampler::sample_argmax(T *probabilities) {
|
|
|
46
47
|
// return the index that has the highest probability
|
|
47
48
|
int max_i = 0;
|
|
48
49
|
T max_p = probabilities[0];
|
|
49
|
-
for (
|
|
50
|
+
for (size_t i = 1; i < vocab_size_; i++) {
|
|
50
51
|
if (probabilities[i] > max_p) {
|
|
51
52
|
max_i = i;
|
|
52
53
|
max_p = probabilities[i];
|
|
@@ -60,7 +61,7 @@ int32_t Sampler::sample_mult(T *probabilities, float coin) {
|
|
|
60
61
|
// sample index from probabilities (they must sum to 1!)
|
|
61
62
|
// coin is a random number in [0, 1), usually from random_f32()
|
|
62
63
|
T cdf = 0.0;
|
|
63
|
-
for (
|
|
64
|
+
for (size_t i = 0; i < vocab_size_; i++) {
|
|
64
65
|
cdf += probabilities[i];
|
|
65
66
|
if (coin < cdf) {
|
|
66
67
|
return i;
|
|
@@ -84,7 +85,7 @@ int32_t Sampler::sample_topp(T *probabilities, float coin) {
|
|
|
84
85
|
std::make_unique<ProbIndex<T>[]>(vocab_size_);
|
|
85
86
|
|
|
86
87
|
const float cutoff = (1.0f - topp_) / (n - 1);
|
|
87
|
-
for (
|
|
88
|
+
for (size_t i = 0; i < n; i++) {
|
|
88
89
|
if (probabilities[i] >= cutoff) {
|
|
89
90
|
probindex[n0].index = i;
|
|
90
91
|
probindex[n0].prob = probabilities[i];
|
|
@@ -92,61 +93,138 @@ int32_t Sampler::sample_topp(T *probabilities, float coin) {
|
|
|
92
93
|
}
|
|
93
94
|
}
|
|
94
95
|
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
96
|
+
std::sort(probindex.get(), probindex.get() + n0,
|
|
97
|
+
[](const ProbIndex<T> &a, const ProbIndex<T> &b) {
|
|
98
|
+
return a.prob > b.prob;
|
|
99
|
+
});
|
|
99
100
|
|
|
100
101
|
// truncate the list where cumulative probability exceeds topp
|
|
101
102
|
T cumulative_prob = 0;
|
|
102
|
-
int last_idx = n0 - 1;
|
|
103
|
-
for (
|
|
103
|
+
int last_idx = n0 - 1;
|
|
104
|
+
for (size_t i = 0; i < n0; i++) {
|
|
104
105
|
cumulative_prob += probindex[i].prob;
|
|
105
|
-
if (cumulative_prob > topp_) {
|
|
106
|
+
if (static_cast<float>(cumulative_prob) > topp_) {
|
|
106
107
|
last_idx = i;
|
|
107
|
-
break;
|
|
108
|
+
break;
|
|
108
109
|
}
|
|
109
110
|
}
|
|
110
111
|
|
|
111
112
|
// sample from the truncated list
|
|
112
|
-
|
|
113
|
+
float r = coin * static_cast<float>(cumulative_prob);
|
|
113
114
|
T cdf = 0;
|
|
114
|
-
for (
|
|
115
|
+
for (size_t i = 0; i <= last_idx; i++) {
|
|
115
116
|
cdf += probindex[i].prob;
|
|
116
|
-
if (r < cdf) {
|
|
117
|
+
if (r < static_cast<float>(cdf)) {
|
|
117
118
|
return probindex[i].index;
|
|
118
119
|
}
|
|
119
120
|
}
|
|
120
|
-
return probindex[last_idx].index;
|
|
121
|
+
return probindex[last_idx].index;
|
|
121
122
|
}
|
|
122
123
|
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
124
|
+
// Mask logits outside the top-k by rank to the lowest finite value of T (stand-in for -inf). Ties at the k-th boundary
|
|
125
|
+
// are kept (matches HuggingFace TopKLogitsWarper).
|
|
126
|
+
template <typename T> void Sampler::mask_topk(T *logits) {
|
|
127
|
+
if (topk_ <= 0 || topk_ >= vocab_size_) {
|
|
128
|
+
return;
|
|
129
|
+
}
|
|
130
|
+
// Partial-select the (topk_-th largest) threshold using nth_element on a
|
|
131
|
+
// copy of logits; O(n) average.
|
|
132
|
+
std::vector<T> scratch(logits, logits + vocab_size_);
|
|
133
|
+
std::nth_element(scratch.begin(), scratch.begin() + (topk_ - 1),
|
|
134
|
+
scratch.end(), std::greater<T>());
|
|
135
|
+
const T threshold = scratch[topk_ - 1];
|
|
136
|
+
constexpr T neg_inf = std::numeric_limits<T>::lowest();
|
|
137
|
+
for (size_t i = 0; i < vocab_size_; i++) {
|
|
138
|
+
if (logits[i] < threshold) {
|
|
139
|
+
logits[i] = neg_inf;
|
|
140
|
+
}
|
|
141
|
+
}
|
|
142
|
+
}
|
|
143
|
+
|
|
144
|
+
// Mask logits whose softmax-prob falls outside the top-p nucleus to the lowest finite value of T (stand-in for -inf).
|
|
145
|
+
// Keeps the token that crosses the threshold (HuggingFace convention).
|
|
146
|
+
template <typename T> void Sampler::mask_topp(T *logits) {
|
|
147
|
+
if (topp_ <= 0.0f || topp_ >= 1.0f) {
|
|
148
|
+
return;
|
|
149
|
+
}
|
|
150
|
+
// Softmax into a scratch probs[] (do not mutate logits yet).
|
|
151
|
+
T max_val = logits[0];
|
|
152
|
+
for (size_t i = 1; i < vocab_size_; i++) {
|
|
153
|
+
if (logits[i] > max_val) {
|
|
154
|
+
max_val = logits[i];
|
|
155
|
+
}
|
|
156
|
+
}
|
|
157
|
+
std::unique_ptr<ProbIndex<T>[]> probindex =
|
|
158
|
+
std::make_unique<ProbIndex<T>[]>(vocab_size_);
|
|
159
|
+
T sum = 0;
|
|
160
|
+
for (size_t i = 0; i < vocab_size_; i++) {
|
|
161
|
+
T e = static_cast<T>(std::expf(static_cast<float>(logits[i] - max_val)));
|
|
162
|
+
probindex[i].prob = e;
|
|
163
|
+
probindex[i].index = i;
|
|
164
|
+
sum += e;
|
|
165
|
+
}
|
|
166
|
+
if (sum <= T(0)) {
|
|
167
|
+
return;
|
|
168
|
+
}
|
|
169
|
+
for (size_t i = 0; i < vocab_size_; i++) {
|
|
170
|
+
probindex[i].prob /= sum;
|
|
171
|
+
}
|
|
172
|
+
std::sort(probindex.get(), probindex.get() + vocab_size_,
|
|
173
|
+
[](const ProbIndex<T> &a, const ProbIndex<T> &b) {
|
|
174
|
+
return a.prob > b.prob;
|
|
175
|
+
});
|
|
176
|
+
|
|
177
|
+
// Find the smallest prefix whose cumulative probability >= topp_.
|
|
178
|
+
T cumulative = 0;
|
|
179
|
+
int last_idx = vocab_size_ - 1;
|
|
180
|
+
for (size_t i = 0; i < vocab_size_; i++) {
|
|
181
|
+
cumulative += probindex[i].prob;
|
|
182
|
+
if (static_cast<float>(cumulative) >= topp_) {
|
|
183
|
+
last_idx = i;
|
|
184
|
+
break;
|
|
185
|
+
}
|
|
186
|
+
}
|
|
187
|
+
// Mark kept indices, then -inf the rest.
|
|
188
|
+
std::vector<bool> keep(vocab_size_, false);
|
|
189
|
+
for (size_t i = 0; i <= last_idx; i++) {
|
|
190
|
+
keep[probindex[i].index] = true;
|
|
191
|
+
}
|
|
192
|
+
constexpr T neg_inf = std::numeric_limits<T>::lowest();
|
|
193
|
+
for (size_t i = 0; i < vocab_size_; i++) {
|
|
194
|
+
if (!keep[i]) {
|
|
195
|
+
logits[i] = neg_inf;
|
|
196
|
+
}
|
|
197
|
+
}
|
|
198
|
+
}
|
|
199
|
+
|
|
200
|
+
Sampler::Sampler(int32_t vocab_size, GenerationConfig config,
|
|
201
|
+
unsigned long long rng_seed)
|
|
126
202
|
: vocab_size_(vocab_size),
|
|
127
|
-
inv_temperature_(
|
|
128
|
-
|
|
203
|
+
inv_temperature_(
|
|
204
|
+
(config.temperature != 0.0f) ? (1.0f / config.temperature) : 0.0f),
|
|
205
|
+
topp_(config.topp), min_p_(config.min_p),
|
|
206
|
+
repetition_penalty_(config.repetition_penalty), topk_(config.topk),
|
|
129
207
|
rng_state_(rng_seed) {}
|
|
130
208
|
|
|
131
|
-
Sampler::Sampler(
|
|
132
|
-
: Sampler(vocab_size,
|
|
209
|
+
Sampler::Sampler(int32_t vocab_size, GenerationConfig config)
|
|
210
|
+
: Sampler(vocab_size, config, std::time(nullptr)) {}
|
|
133
211
|
|
|
134
212
|
template <typename T> static void softmax(T *x, int size) {
|
|
135
213
|
// find max value (for numerical stability)
|
|
136
214
|
T max_val = x[0];
|
|
137
|
-
for (
|
|
215
|
+
for (size_t i = 1; i < size; i++) {
|
|
138
216
|
if (x[i] > max_val) {
|
|
139
217
|
max_val = x[i];
|
|
140
218
|
}
|
|
141
219
|
}
|
|
142
220
|
// exp and sum
|
|
143
221
|
T sum = 0;
|
|
144
|
-
for (
|
|
222
|
+
for (size_t i = 0; i < size; i++) {
|
|
145
223
|
x[i] = expf(x[i] - max_val);
|
|
146
224
|
sum += x[i];
|
|
147
225
|
}
|
|
148
226
|
// normalize
|
|
149
|
-
for (
|
|
227
|
+
for (size_t i = 0; i < size; i++) {
|
|
150
228
|
x[i] /= sum;
|
|
151
229
|
}
|
|
152
230
|
}
|
|
@@ -175,20 +253,18 @@ int32_t Sampler::sample(T *logits, const std::vector<uint64_t> &recent_tokens) {
|
|
|
175
253
|
apply_repetition_penalty(logits, vocab_size_, recent_tokens);
|
|
176
254
|
// 2. apply the temperature to the logits
|
|
177
255
|
apply_temperature(logits, vocab_size_);
|
|
178
|
-
// 3.
|
|
256
|
+
// 3. mask out logits outside top-k by rank (pre-softmax, becomes 0 mass)
|
|
257
|
+
mask_topk(logits);
|
|
258
|
+
// 4. mask out logits outside the top-p nucleus by cumulative probability (pre-softmax)
|
|
259
|
+
mask_topp(logits);
|
|
260
|
+
// 5. apply softmax to the logits to get the probabilities for next token
|
|
179
261
|
softmax(logits, vocab_size_);
|
|
180
|
-
//
|
|
262
|
+
// 6. apply min_p truncation
|
|
181
263
|
apply_min_p(logits, vocab_size_);
|
|
182
264
|
// flip a (float) coin (this is our source of entropy for sampling)
|
|
183
265
|
float coin = random_f32(&rng_state_);
|
|
184
|
-
//
|
|
185
|
-
|
|
186
|
-
// simply sample from the predicted probability distribution
|
|
187
|
-
next = sample_mult(logits, coin);
|
|
188
|
-
} else {
|
|
189
|
-
// top-p (nucleus) sampling, clamping the least likely tokens to zero
|
|
190
|
-
next = sample_topp(logits, coin);
|
|
191
|
-
}
|
|
266
|
+
// 7. we sample from this distribution to get the next token
|
|
267
|
+
next = sample_mult(logits, coin);
|
|
192
268
|
}
|
|
193
269
|
return next;
|
|
194
270
|
}
|
package/common/runner/sampler.h
CHANGED
|
@@ -8,6 +8,7 @@
|
|
|
8
8
|
|
|
9
9
|
#pragma once
|
|
10
10
|
|
|
11
|
+
#include "runner/irunner.h"
|
|
11
12
|
#include <algorithm>
|
|
12
13
|
#include <cctype>
|
|
13
14
|
#include <cmath>
|
|
@@ -28,6 +29,7 @@ namespace executorch {
|
|
|
28
29
|
namespace extension {
|
|
29
30
|
namespace llm {
|
|
30
31
|
// A simple llama2 sampler.
|
|
32
|
+
struct GenerationConfig;
|
|
31
33
|
|
|
32
34
|
inline constexpr auto kTopp = 0.9f;
|
|
33
35
|
|
|
@@ -38,11 +40,13 @@ template <typename T> struct ProbIndex {
|
|
|
38
40
|
|
|
39
41
|
class Sampler {
|
|
40
42
|
public:
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
Sampler(int32_t vocab_size,
|
|
43
|
+
// topk <= 0 disables top-k filtering. topp <= 0 || topp >= 1 disables top-p.
|
|
44
|
+
// Pipeline when temperature != 0: repetition penalty -> temperature -> top-k mask -> top-p mask
|
|
45
|
+
// -> softmax -> min_p -> multinomial. Note: topk == 1 with temperature != 0 collapses
|
|
46
|
+
// to greedy; pass topk = 0 to keep full-vocab temperature sampling.
|
|
47
|
+
Sampler(int32_t vocab_size, GenerationConfig config,
|
|
48
|
+
unsigned long long rng_seed);
|
|
49
|
+
Sampler(int32_t vocab_size, GenerationConfig config);
|
|
46
50
|
|
|
47
51
|
template <typename T> int32_t sample(T *logits);
|
|
48
52
|
|
|
@@ -53,6 +57,9 @@ private:
|
|
|
53
57
|
template <typename T> int32_t sample_topp(T *probabilities, float coin);
|
|
54
58
|
template <typename T> int32_t sample_mult(T *probabilities, float coin);
|
|
55
59
|
template <typename T> int32_t sample_argmax(T *probabilities);
|
|
60
|
+
// In-place logit warpers: set excluded indices to the lowest finite value of T (stand-in for -inf).
|
|
61
|
+
template <typename T> void mask_topk(T *logits);
|
|
62
|
+
template <typename T> void mask_topp(T *logits);
|
|
56
63
|
|
|
57
64
|
template <typename T>
|
|
58
65
|
inline void apply_temperature(T *logits, int32_t vocab_size) {
|
|
@@ -110,6 +117,7 @@ private:
|
|
|
110
117
|
float topp_;
|
|
111
118
|
float min_p_;
|
|
112
119
|
float repetition_penalty_;
|
|
120
|
+
int32_t topk_;
|
|
113
121
|
unsigned long long rng_state_;
|
|
114
122
|
};
|
|
115
123
|
|
|
@@ -31,7 +31,6 @@ TextDecoderRunner::TextDecoderRunner(Module &module, IOManager *io_manager,
|
|
|
31
31
|
// outer loop (call site) is responsible for managing state.
|
|
32
32
|
::executorch::runtime::Result<executorch::aten::Tensor>
|
|
33
33
|
TextDecoderRunner::step(TensorPtr &tokens, int64_t start_pos) {
|
|
34
|
-
// ET_LOG(Info, "Input token %" PRIu64, input_token);
|
|
35
34
|
auto method_meta_result = module_->method_meta("forward");
|
|
36
35
|
if (!method_meta_result.ok()) {
|
|
37
36
|
return method_meta_result.error();
|
|
@@ -102,9 +101,7 @@ int32_t TextDecoderRunner::logits_to_token(
|
|
|
102
101
|
auto num_tokens = logits_tensor.size(1);
|
|
103
102
|
logits += (num_tokens - 1) * vocab_size;
|
|
104
103
|
}
|
|
105
|
-
Sampler sampler(vocab_size, config_
|
|
106
|
-
static_cast<unsigned long long>(std::time(nullptr)),
|
|
107
|
-
config_.min_p, config_.repetition_penalty);
|
|
104
|
+
Sampler sampler(vocab_size, config_);
|
|
108
105
|
result = sampler.sample(logits, recent_tokens);
|
|
109
106
|
});
|
|
110
107
|
return result;
|
|
@@ -10,6 +10,7 @@
|
|
|
10
10
|
|
|
11
11
|
#pragma once
|
|
12
12
|
|
|
13
|
+
#include "constants.h"
|
|
13
14
|
#include "io_manager.h"
|
|
14
15
|
#include "sampler.h"
|
|
15
16
|
|
|
@@ -40,8 +41,8 @@ public:
|
|
|
40
41
|
step(TensorPtr &input, int64_t start_pos);
|
|
41
42
|
|
|
42
43
|
/**
|
|
43
|
-
* Load the Module for text decode purpose.
|
|
44
|
-
*
|
|
44
|
+
* Load the Module for text decode purpose. Loads the dynamic-shape `forward`
|
|
45
|
+
* method used for both prefill and decode.
|
|
45
46
|
*/
|
|
46
47
|
virtual ::executorch::runtime::Error load() {
|
|
47
48
|
return module_->load_method("forward");
|
|
@@ -18,10 +18,11 @@ namespace llm {
|
|
|
18
18
|
|
|
19
19
|
TextPrefiller::TextPrefiller(TextDecoderRunner *text_decoder_runner,
|
|
20
20
|
bool use_kv_cache, bool enable_parallel_prefill,
|
|
21
|
-
int64_t max_seq_len)
|
|
21
|
+
int64_t max_seq_len, int32_t prefill_chunk_size)
|
|
22
22
|
: text_decoder_runner_(text_decoder_runner), use_kv_cache_(use_kv_cache),
|
|
23
23
|
enable_parallel_prefill_(enable_parallel_prefill),
|
|
24
|
-
max_seq_len_(max_seq_len > 0 ? max_seq_len : 128)
|
|
24
|
+
max_seq_len_(max_seq_len > 0 ? max_seq_len : 128),
|
|
25
|
+
prefill_chunk_size_(prefill_chunk_size) {}
|
|
25
26
|
|
|
26
27
|
::executorch::runtime::Result<uint64_t>
|
|
27
28
|
TextPrefiller::prefill(std::vector<uint64_t> &prompt_tokens,
|
|
@@ -31,17 +32,17 @@ TextPrefiller::prefill(std::vector<uint64_t> &prompt_tokens,
|
|
|
31
32
|
ET_CHECK_OK_OR_RETURN_ERROR(text_decoder_runner_->load());
|
|
32
33
|
}
|
|
33
34
|
|
|
34
|
-
// Check if we need to chunk the prompt tokens
|
|
35
35
|
int32_t num_prompt_tokens = prompt_tokens.size();
|
|
36
|
+
int32_t chunk_size =
|
|
37
|
+
prefill_chunk_size_ > 0 ? prefill_chunk_size_ : max_seq_len_;
|
|
36
38
|
|
|
37
|
-
|
|
38
|
-
if (num_prompt_tokens > max_seq_len_) {
|
|
39
|
+
if (num_prompt_tokens > chunk_size) {
|
|
39
40
|
uint64_t cur_token = 0;
|
|
40
41
|
int num_tokens_to_process = 0;
|
|
41
42
|
|
|
42
43
|
while (num_tokens_to_process < num_prompt_tokens) {
|
|
43
|
-
auto num_tokens_to_prefill_with =
|
|
44
|
-
num_prompt_tokens - num_tokens_to_process,
|
|
44
|
+
auto num_tokens_to_prefill_with =
|
|
45
|
+
std::min<int>(num_prompt_tokens - num_tokens_to_process, chunk_size);
|
|
45
46
|
|
|
46
47
|
std::vector<uint64_t> prompt_tokens_to_process(
|
|
47
48
|
num_tokens_to_prefill_with);
|
|
@@ -75,7 +76,6 @@ TextPrefiller::prefill_chunk(std::vector<uint64_t> &prompt_tokens,
|
|
|
75
76
|
// store the token
|
|
76
77
|
uint64_t cur_token;
|
|
77
78
|
if (enable_parallel_prefill_ || !use_kv_cache_) {
|
|
78
|
-
// initialize tensor wrappers
|
|
79
79
|
auto tokens = from_blob(prompt_tokens.data(), {1, num_prompt_tokens},
|
|
80
80
|
executorch::aten::ScalarType::Long);
|
|
81
81
|
|
|
@@ -19,8 +19,14 @@ namespace llm {
|
|
|
19
19
|
|
|
20
20
|
class TextPrefiller {
|
|
21
21
|
public:
|
|
22
|
+
// prefill_chunk_size: when > 0, the prompt is always processed in steps of
|
|
23
|
+
// this size (see prefill()). Set to the model's forward sequence-length cap
|
|
24
|
+
// for the MLX backend (its forward is exported with a sliding-window bound
|
|
25
|
+
// and one-shot prefill spikes Metal memory). Other backends (XNNPACK/CoreML)
|
|
26
|
+
// pass 0 → original one-shot behavior.
|
|
22
27
|
TextPrefiller(TextDecoderRunner *text_decoder_runner, bool use_kv_cache,
|
|
23
|
-
bool enable_parallel_prefill, int64_t max_seq_len = 128
|
|
28
|
+
bool enable_parallel_prefill, int64_t max_seq_len = 128,
|
|
29
|
+
int32_t prefill_chunk_size = 0);
|
|
24
30
|
|
|
25
31
|
virtual ~TextPrefiller() = default;
|
|
26
32
|
/**
|
|
@@ -70,6 +76,7 @@ private:
|
|
|
70
76
|
bool use_kv_cache_;
|
|
71
77
|
bool enable_parallel_prefill_;
|
|
72
78
|
int64_t max_seq_len_;
|
|
79
|
+
int32_t prefill_chunk_size_;
|
|
73
80
|
};
|
|
74
81
|
|
|
75
82
|
} // namespace llm
|
|
@@ -26,11 +26,24 @@ Error TextRunner::load_subcomponents() {
|
|
|
26
26
|
|
|
27
27
|
Stats *stats_ptr = &stats_;
|
|
28
28
|
|
|
29
|
-
text_decoder_runner_ =
|
|
30
|
-
*module_, io_manager_.get(), config_);
|
|
29
|
+
text_decoder_runner_ =
|
|
30
|
+
std::make_unique<TextDecoderRunner>(*module_, io_manager_.get(), config_);
|
|
31
|
+
|
|
32
|
+
int32_t prefill_chunk_size = 0;
|
|
33
|
+
auto fwd_meta = module_->method_meta("forward");
|
|
34
|
+
if (fwd_meta.ok() && fwd_meta->uses_backend("MLXBackend")) {
|
|
35
|
+
auto input_meta = fwd_meta->input_tensor_meta(0);
|
|
36
|
+
if (input_meta.ok()) {
|
|
37
|
+
auto sizes = input_meta->sizes();
|
|
38
|
+
if (sizes.size() >= 2 && sizes[sizes.size() - 1] > 0) {
|
|
39
|
+
prefill_chunk_size = sizes[sizes.size() - 1];
|
|
40
|
+
}
|
|
41
|
+
}
|
|
42
|
+
}
|
|
43
|
+
|
|
31
44
|
text_prefiller_ = std::make_unique<TextPrefiller>(
|
|
32
45
|
text_decoder_runner_.get(), config_.enable_kv_cache,
|
|
33
|
-
config_.enable_dynamic_shape, config_.max_seq_len);
|
|
46
|
+
config_.enable_dynamic_shape, config_.max_seq_len, prefill_chunk_size);
|
|
34
47
|
text_token_generator_ = std::make_unique<TextTokenGenerator>(
|
|
35
48
|
tokenizer_.get(), text_decoder_runner_.get(), config_.enable_kv_cache,
|
|
36
49
|
std::move(eos_ids_), stats_ptr, config_);
|
|
@@ -65,6 +78,10 @@ Error TextRunner::generate_internal(
|
|
|
65
78
|
|
|
66
79
|
stats_.inference_start_ms = time_in_ms();
|
|
67
80
|
|
|
81
|
+
// Multi-turn: JS re-renders the full chat history each call, so reset KV
|
|
82
|
+
// position to 0 and re-prefill from scratch.
|
|
83
|
+
pos_ = 0;
|
|
84
|
+
|
|
68
85
|
int64_t context_len_left =
|
|
69
86
|
static_cast<int64_t>(config_.max_context_length) - pos_;
|
|
70
87
|
|
|
@@ -79,16 +96,25 @@ Error TextRunner::generate_internal(
|
|
|
79
96
|
std::vector<uint64_t> prompt_tokens = encodeResult.get();
|
|
80
97
|
int num_prompt_tokens = prompt_tokens.size();
|
|
81
98
|
|
|
99
|
+
// For dynamic-shape PTEs (e.g. Gemma4 MLX/Vulkan), get_max_seq_len is the
|
|
100
|
+
// per-call decoder chunk size (e.g. the sliding window) and the real
|
|
101
|
+
// generation budget lives in get_max_context_len. Static-shape PTEs set both
|
|
102
|
+
// equal, so this collapses to the old behavior. Without this the budget is
|
|
103
|
+
// computed from the small chunk size, so max_new_tokens can resolve to ~0 and
|
|
104
|
+
// generation ends immediately after prefill.
|
|
105
|
+
const int32_t seq_cap = config_.enable_dynamic_shape
|
|
106
|
+
? config_.max_context_length
|
|
107
|
+
: config_.max_seq_len;
|
|
108
|
+
|
|
82
109
|
ET_CHECK_OR_RETURN_ERROR(num_prompt_tokens >= 1, InvalidArgument,
|
|
83
110
|
"Expected at least 1 prompt token");
|
|
84
|
-
ET_CHECK_OR_RETURN_ERROR(num_prompt_tokens <
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
num_prompt_tokens, config_.max_seq_len);
|
|
111
|
+
ET_CHECK_OR_RETURN_ERROR(num_prompt_tokens < seq_cap, InvalidArgument,
|
|
112
|
+
"num_prompt_tokens %d >= seq cap %" PRId32,
|
|
113
|
+
num_prompt_tokens, seq_cap);
|
|
88
114
|
|
|
89
115
|
int32_t max_new_tokens = resolve_max_new_tokens(
|
|
90
|
-
num_prompt_tokens,
|
|
91
|
-
|
|
116
|
+
num_prompt_tokens, seq_cap, static_cast<int32_t>(context_len_left),
|
|
117
|
+
config_.max_new_tokens);
|
|
92
118
|
|
|
93
119
|
ET_CHECK_OR_RETURN_ERROR(max_new_tokens > 0, InvalidArgument,
|
|
94
120
|
"Max new tokens %d is <= 0", max_new_tokens);
|
|
@@ -100,8 +100,8 @@ public:
|
|
|
100
100
|
prev_token = cur_token;
|
|
101
101
|
|
|
102
102
|
stats_->on_sampling_begin();
|
|
103
|
-
cur_token =
|
|
104
|
-
|
|
103
|
+
cur_token = text_decoder_runner_->logits_to_token(logits_tensor,
|
|
104
|
+
generated_tokens);
|
|
105
105
|
stats_->on_sampling_end();
|
|
106
106
|
|
|
107
107
|
pos++;
|
|
@@ -152,7 +152,6 @@ public:
|
|
|
152
152
|
if (should_stop_) {
|
|
153
153
|
break;
|
|
154
154
|
}
|
|
155
|
-
|
|
156
155
|
// data-dependent terminating condition: we have n_eos_ number of EOS
|
|
157
156
|
if (eos_ids_->find(cur_token) != eos_ids_->end()) {
|
|
158
157
|
printf("\n");
|
package/common/runner/util.h
CHANGED
|
@@ -6,7 +6,7 @@ import { SlidingWindowContextStrategy } from '../utils/llms/context_strategy';
|
|
|
6
6
|
* Default system prompt used to guide the behavior of Large Language Models (LLMs).
|
|
7
7
|
* @category Utilities - LLM
|
|
8
8
|
*/
|
|
9
|
-
export const DEFAULT_SYSTEM_PROMPT = "You are a knowledgeable, efficient, and direct AI assistant. Provide concise answers, focusing on the key information needed. Offer suggestions tactfully when appropriate to improve outcomes. Engage in productive collaboration with the user. Don't return too much text.";
|
|
9
|
+
export const DEFAULT_SYSTEM_PROMPT = "You are a knowledgeable, efficient, and direct AI assistant. Provide concise answers, focusing on the key information needed. Offer suggestions tactfully when appropriate to improve outcomes. Engage in productive collaboration with the user. Don't return too much text. If provided with audio samples treat it with at most importance";
|
|
10
10
|
|
|
11
11
|
/**
|
|
12
12
|
* Generates a default structured output prompt based on the provided JSON schema.
|
|
@@ -1 +1 @@
|
|
|
1
|
-
{"version":3,"names":["SlidingWindowContextStrategy","DEFAULT_SYSTEM_PROMPT","DEFAULT_STRUCTURED_OUTPUT_PROMPT","structuredOutputSchema","DEFAULT_MESSAGE_HISTORY","DEFAULT_CONTEXT_BUFFER_TOKENS","DEFAULT_CHAT_CONFIG","systemPrompt","initialMessageHistory","contextStrategy"],"sourceRoot":"../../../src","sources":["constants/llmDefaults.ts"],"mappings":";;AACA,SAASA,4BAA4B,QAAQ,gCAAgC;;AAE7E;AACA;AACA;AACA;AACA,OAAO,MAAMC,qBAAqB,GAChC,+
|
|
1
|
+
{"version":3,"names":["SlidingWindowContextStrategy","DEFAULT_SYSTEM_PROMPT","DEFAULT_STRUCTURED_OUTPUT_PROMPT","structuredOutputSchema","DEFAULT_MESSAGE_HISTORY","DEFAULT_CONTEXT_BUFFER_TOKENS","DEFAULT_CHAT_CONFIG","systemPrompt","initialMessageHistory","contextStrategy"],"sourceRoot":"../../../src","sources":["constants/llmDefaults.ts"],"mappings":";;AACA,SAASA,4BAA4B,QAAQ,gCAAgC;;AAE7E;AACA;AACA;AACA;AACA,OAAO,MAAMC,qBAAqB,GAChC,+UAA+U;;AAEjV;AACA;AACA;AACA;AACA;AACA;AACA,OAAO,MAAMC,gCAAgC,GAC3CC,sBAA8B,IAC3B;AACL;AACA;AACA;AACA;AACA;AACA,EAAEA,sBAAsB;AACxB,CAAC;;AAED;AACA;AACA;AACA;AACA,OAAO,MAAMC,uBAAkC,GAAG,EAAE;;AAEpD;AACA;AACA;AACA;AACA,OAAO,MAAMC,6BAA6B,GAAG,GAAG;;AAEhD;AACA;AACA;AACA;AACA,OAAO,MAAMC,mBAA+B,GAAG;EAC7CC,YAAY,EAAEN,qBAAqB;EACnCO,qBAAqB,EAAEJ,uBAAuB;EAC9CK,eAAe,EAAE,IAAIT,4BAA4B,CAC/CK,6BACF;AACF,CAAC","ignoreList":[]}
|
|
@@ -26,7 +26,7 @@ import { RnExecutorchErrorCode } from '../errors/ErrorCodes';
|
|
|
26
26
|
|
|
27
27
|
// Accessors are functions; calling with no opts returns the platform default.
|
|
28
28
|
|
|
29
|
-
const BACKEND_ORDER = ['xnnpack', 'coreml', 'vulkan', 'qnn'];
|
|
29
|
+
const BACKEND_ORDER = ['xnnpack', 'coreml', 'mlx', 'vulkan', 'qnn'];
|
|
30
30
|
function firstBackend(variants) {
|
|
31
31
|
for (const b of BACKEND_ORDER) {
|
|
32
32
|
if (variants[b]) return b;
|
|
@@ -107,6 +107,32 @@ function tts(c) {
|
|
|
107
107
|
// Per-backend variant maps for models that ship more than one backend.
|
|
108
108
|
// ─────────────────────────────────────────────────────────────────────────────
|
|
109
109
|
|
|
110
|
+
const GEMMA4_E2B_VARIANTS = {
|
|
111
|
+
mlx: {
|
|
112
|
+
base: {
|
|
113
|
+
modelName: 'gemma4-e2b',
|
|
114
|
+
modelSource: M.GEMMA4_E2B_MLX_MODEL,
|
|
115
|
+
tokenizerSource: M.GEMMA4_E2B_TOKENIZER,
|
|
116
|
+
tokenizerConfigSource: M.GEMMA4_E2B_TOKENIZER_CONFIG
|
|
117
|
+
}
|
|
118
|
+
},
|
|
119
|
+
xnnpack: {
|
|
120
|
+
base: {
|
|
121
|
+
modelName: 'gemma4-e2b',
|
|
122
|
+
modelSource: M.GEMMA4_E2B_XNNPACK_MODEL,
|
|
123
|
+
tokenizerSource: M.GEMMA4_E2B_TOKENIZER,
|
|
124
|
+
tokenizerConfigSource: M.GEMMA4_E2B_TOKENIZER_CONFIG
|
|
125
|
+
}
|
|
126
|
+
},
|
|
127
|
+
vulkan: {
|
|
128
|
+
base: {
|
|
129
|
+
modelName: 'gemma4-e2b',
|
|
130
|
+
modelSource: M.GEMMA4_E2B_VULKAN_MODEL,
|
|
131
|
+
tokenizerSource: M.GEMMA4_E2B_TOKENIZER,
|
|
132
|
+
tokenizerConfigSource: M.GEMMA4_E2B_TOKENIZER_CONFIG
|
|
133
|
+
}
|
|
134
|
+
}
|
|
135
|
+
};
|
|
110
136
|
const EFFICIENTNET_V2_S_VARIANTS = {
|
|
111
137
|
xnnpack: {
|
|
112
138
|
base: {
|
|
@@ -331,10 +357,15 @@ export const models = {
|
|
|
331
357
|
lfm2_5_350m: pair(M.LFM2_5_350M, M.LFM2_5_350M_QUANTIZED),
|
|
332
358
|
lfm2_5_1_2b_instruct: pair(M.LFM2_5_1_2B_INSTRUCT, M.LFM2_5_1_2B_INSTRUCT_QUANTIZED),
|
|
333
359
|
bielik_v3_0_1_5b: pair(M.BIELIK_V3_0_1_5B, M.BIELIK_V3_0_1_5B_QUANTIZED),
|
|
360
|
+
gemma4_e2b: variant(GEMMA4_E2B_VARIANTS, {
|
|
361
|
+
ios: 'mlx',
|
|
362
|
+
android: 'vulkan'
|
|
363
|
+
}),
|
|
334
364
|
// Multimodal LLMs — same hook/module as plain LLMs, listed here so users
|
|
335
365
|
// pick a model by capability ("LLM") rather than by modality.
|
|
336
366
|
lfm2_5_vl_1_6b: base(M.LFM2_5_VL_1_6B_QUANTIZED),
|
|
337
|
-
lfm2_5_vl_450m: base(M.LFM2_5_VL_450M_QUANTIZED)
|
|
367
|
+
lfm2_5_vl_450m: base(M.LFM2_5_VL_450M_QUANTIZED),
|
|
368
|
+
gemma4_e2b_multimodal: base(M.GEMMA4_E2B_MM)
|
|
338
369
|
},
|
|
339
370
|
classification: {
|
|
340
371
|
efficientnet_v2_s: variant(EFFICIENTNET_V2_S_VARIANTS)
|