react-native-executorch 0.9.0 → 0.9.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/libs/classes.jar +0 -0
- package/common/rnexecutorch/host_objects/JsiConversions.h +43 -0
- package/common/rnexecutorch/models/llm/LLM.cpp +55 -42
- package/common/rnexecutorch/models/llm/LLM.h +4 -3
- package/common/rnexecutorch/models/llm/Types.h +23 -0
- package/common/runner/base_llm_runner.cpp +10 -3
- package/common/runner/base_llm_runner.h +1 -0
- package/common/runner/constants.h +15 -1
- package/common/runner/encoders/audio_encoder.cpp +111 -0
- package/common/runner/encoders/audio_encoder.h +40 -0
- package/common/runner/encoders/vision_encoder.cpp +13 -5
- package/common/runner/encoders/vision_encoder.h +15 -2
- package/common/runner/irunner.h +5 -0
- package/common/runner/multimodal_decoder_runner.h +50 -1
- package/common/runner/multimodal_input.h +16 -1
- package/common/runner/multimodal_prefiller.cpp +374 -64
- package/common/runner/multimodal_prefiller.h +57 -6
- package/common/runner/multimodal_runner.cpp +19 -12
- package/common/runner/multimodal_runner.h +1 -1
- package/common/runner/sampler.cpp +126 -39
- package/common/runner/sampler.h +13 -5
- package/common/runner/text_decoder_runner.cpp +1 -4
- package/common/runner/text_decoder_runner.h +3 -2
- package/common/runner/text_prefiller.cpp +8 -8
- package/common/runner/text_prefiller.h +8 -1
- package/common/runner/text_runner.cpp +35 -9
- package/common/runner/text_token_generator.h +2 -3
- package/common/runner/util.h +0 -1
- package/lib/module/constants/llmDefaults.js +1 -1
- package/lib/module/constants/llmDefaults.js.map +1 -1
- package/lib/module/constants/modelRegistry.js +62 -3
- package/lib/module/constants/modelRegistry.js.map +1 -1
- package/lib/module/constants/modelUrls.js +62 -6
- package/lib/module/constants/modelUrls.js.map +1 -1
- package/lib/module/controllers/LLMController.js +69 -20
- package/lib/module/controllers/LLMController.js.map +1 -1
- package/lib/module/hooks/natural_language_processing/useLLM.js +1 -5
- package/lib/module/hooks/natural_language_processing/useLLM.js.map +1 -1
- package/lib/module/modules/computer_vision/PoseEstimationModule.js +13 -1
- package/lib/module/modules/computer_vision/PoseEstimationModule.js.map +1 -1
- package/lib/module/modules/natural_language_processing/LLMModule.js +12 -7
- package/lib/module/modules/natural_language_processing/LLMModule.js.map +1 -1
- package/lib/module/types/llm.js +11 -0
- package/lib/module/types/llm.js.map +1 -1
- package/lib/module/types/poseEstimation.js.map +1 -1
- package/lib/typescript/constants/llmDefaults.d.ts +1 -1
- package/lib/typescript/constants/llmDefaults.d.ts.map +1 -1
- package/lib/typescript/constants/modelRegistry.d.ts +38 -1
- package/lib/typescript/constants/modelRegistry.d.ts.map +1 -1
- package/lib/typescript/constants/modelUrls.d.ts +52 -12
- package/lib/typescript/constants/modelUrls.d.ts.map +1 -1
- package/lib/typescript/controllers/LLMController.d.ts +7 -9
- package/lib/typescript/controllers/LLMController.d.ts.map +1 -1
- package/lib/typescript/modules/computer_vision/PoseEstimationModule.d.ts +6 -0
- package/lib/typescript/modules/computer_vision/PoseEstimationModule.d.ts.map +1 -1
- package/lib/typescript/modules/natural_language_processing/LLMModule.d.ts +6 -3
- package/lib/typescript/modules/natural_language_processing/LLMModule.d.ts.map +1 -1
- package/lib/typescript/types/llm.d.ts +63 -36
- package/lib/typescript/types/llm.d.ts.map +1 -1
- package/lib/typescript/types/poseEstimation.d.ts +3 -0
- package/lib/typescript/types/poseEstimation.d.ts.map +1 -1
- package/package.json +1 -1
- package/react-native-executorch.podspec +6 -0
- package/src/constants/llmDefaults.ts +1 -1
- package/src/constants/modelRegistry.ts +62 -2
- package/src/constants/modelUrls.ts +69 -6
- package/src/controllers/LLMController.ts +89 -40
- package/src/hooks/natural_language_processing/useLLM.ts +5 -6
- package/src/modules/computer_vision/PoseEstimationModule.ts +12 -0
- package/src/modules/natural_language_processing/LLMModule.ts +19 -8
- package/src/types/llm.ts +64 -34
- package/src/types/poseEstimation.ts +10 -4
- package/third-party/android/libs/executorch/arm64-v8a/libexecutorch.so +0 -0
- package/third-party/android/libs/executorch/x86_64/libexecutorch.so +0 -0
- package/third-party/include/executorch/ExecuTorch.h +2 -0
- package/third-party/include/executorch/ExecuTorchModule.h +46 -0
- package/third-party/include/executorch/extension/data_loader/buffer_data_loader.h +4 -3
- package/third-party/include/executorch/extension/data_loader/mman.h +46 -0
- package/third-party/include/executorch/extension/data_loader/mmap_data_loader.h +4 -0
- package/third-party/include/executorch/extension/data_loader/shared_ptr_data_loader.h +7 -3
- package/third-party/include/executorch/extension/module/module.h +47 -8
- package/third-party/include/executorch/extension/tensor/tensor_ptr.h +17 -5
- package/third-party/include/executorch/kernels/optimized/Functions.h +12 -0
- package/third-party/include/executorch/kernels/optimized/NativeFunctions.h +4 -0
- package/third-party/include/executorch/kernels/portable/Functions.h +18 -0
- package/third-party/include/executorch/kernels/portable/NativeFunctions.h +6 -0
- package/third-party/include/executorch/runtime/backend/backend_options_map.h +37 -0
- package/third-party/include/executorch/runtime/core/array_ref.h +3 -1
- package/third-party/include/executorch/runtime/core/error.h +1 -0
- package/third-party/include/executorch/runtime/core/evalue.h +256 -9
- package/third-party/include/executorch/runtime/core/exec_aten/exec_aten.h +24 -0
- package/third-party/include/executorch/runtime/core/hierarchical_allocator.h +9 -6
- package/third-party/include/executorch/runtime/core/portable_type/device.h +3 -4
- package/third-party/include/executorch/runtime/core/portable_type/tensor_impl.h +31 -1
- package/third-party/include/executorch/runtime/executor/method.h +9 -3
- package/third-party/include/executorch/runtime/executor/method_meta.h +14 -0
- package/third-party/include/executorch/runtime/executor/platform_memory_allocator.h +12 -2
- package/third-party/include/executorch/runtime/executor/program.h +3 -1
- package/third-party/include/executorch/runtime/executor/tensor_parser.h +5 -1
- package/third-party/include/executorch/runtime/kernel/operator_registry.h +9 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/ExecutorchLib +0 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/Info.plist +0 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/mlx.metallib +0 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/ExecutorchLib +0 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/Info.plist +0 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/mlx.metallib +0 -0
|
@@ -3,7 +3,6 @@
|
|
|
3
3
|
#include "constants.h"
|
|
4
4
|
#include "util.h"
|
|
5
5
|
#include <rnexecutorch/Error.h>
|
|
6
|
-
#include <rnexecutorch/Log.h>
|
|
7
6
|
|
|
8
7
|
namespace executorch::extension::llm {
|
|
9
8
|
|
|
@@ -54,8 +53,14 @@ Error MultimodalRunner::load_subcomponents() {
|
|
|
54
53
|
if (enc_it != encoders_.end()) {
|
|
55
54
|
image_encoder = enc_it->second.get();
|
|
56
55
|
}
|
|
56
|
+
IEncoder *audio_encoder = nullptr;
|
|
57
|
+
auto aud_it = encoders_.find(MultimodalType::Audio);
|
|
58
|
+
if (aud_it != encoders_.end()) {
|
|
59
|
+
audio_encoder = aud_it->second.get();
|
|
60
|
+
}
|
|
57
61
|
mm_prefiller_ = std::make_unique<MultimodalPrefiller>(
|
|
58
|
-
*module_, *mm_decoder_runner_, *tokenizer_, image_encoder
|
|
62
|
+
*module_, *mm_decoder_runner_, *tokenizer_, metadata_, image_encoder,
|
|
63
|
+
audio_encoder);
|
|
59
64
|
mm_token_generator_ = std::make_unique<TextTokenGenerator>(
|
|
60
65
|
tokenizer_.get(), mm_decoder_runner_.get(), /*use_kv_cache=*/true,
|
|
61
66
|
std::move(eos_ids_), stats_ptr, config_);
|
|
@@ -78,22 +83,24 @@ Error MultimodalRunner::generate_internal(
|
|
|
78
83
|
}
|
|
79
84
|
|
|
80
85
|
stats_.inference_start_ms = time_in_ms();
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
if (!prefill_result.ok())
|
|
86
|
-
return prefill_result.error();
|
|
87
|
-
prefill_next_token = prefill_result.get();
|
|
88
|
-
}
|
|
86
|
+
auto prefill_result = mm_prefiller_->prefill(inputs, pos_);
|
|
87
|
+
if (!prefill_result.ok())
|
|
88
|
+
return prefill_result.error();
|
|
89
|
+
uint64_t prefill_next_token = prefill_result.get();
|
|
89
90
|
|
|
90
91
|
stats_.first_token_ms = time_in_ms();
|
|
91
92
|
stats_.prompt_eval_end_ms = time_in_ms();
|
|
92
93
|
stats_.num_prompt_tokens = pos_;
|
|
93
94
|
|
|
95
|
+
// For dynamic-shape PTEs (Gemma4 iter*), get_max_seq_len is the per-call
|
|
96
|
+
// decoder chunk size (e.g. 128) and the true generation budget lives in
|
|
97
|
+
// get_max_context_len. Mirrors text_runner.cpp:95-97.
|
|
98
|
+
const int32_t seq_cap = config_.enable_dynamic_shape
|
|
99
|
+
? config_.max_context_length
|
|
100
|
+
: config_.max_seq_len;
|
|
94
101
|
int32_t resolved_max_new = resolve_max_new_tokens(
|
|
95
|
-
static_cast<int32_t>(pos_), config_.
|
|
96
|
-
config_.
|
|
102
|
+
static_cast<int32_t>(pos_), seq_cap, config_.max_context_length,
|
|
103
|
+
config_.max_new_tokens);
|
|
97
104
|
|
|
98
105
|
std::vector<uint64_t> seed_tokens = {prefill_next_token};
|
|
99
106
|
auto wrapped_callback = [&](const std::string &piece) {
|
|
@@ -35,6 +35,10 @@
|
|
|
35
35
|
#include "sampler.h"
|
|
36
36
|
#include <algorithm>
|
|
37
37
|
#include <ctime>
|
|
38
|
+
#include <limits>
|
|
39
|
+
#include <ranges>
|
|
40
|
+
#include <span>
|
|
41
|
+
#include <type_traits>
|
|
38
42
|
#include <vector>
|
|
39
43
|
|
|
40
44
|
namespace executorch {
|
|
@@ -46,7 +50,7 @@ template <typename T> int32_t Sampler::sample_argmax(T *probabilities) {
|
|
|
46
50
|
// return the index that has the highest probability
|
|
47
51
|
int max_i = 0;
|
|
48
52
|
T max_p = probabilities[0];
|
|
49
|
-
for (
|
|
53
|
+
for (size_t i = 1; i < vocab_size_; i++) {
|
|
50
54
|
if (probabilities[i] > max_p) {
|
|
51
55
|
max_i = i;
|
|
52
56
|
max_p = probabilities[i];
|
|
@@ -60,7 +64,7 @@ int32_t Sampler::sample_mult(T *probabilities, float coin) {
|
|
|
60
64
|
// sample index from probabilities (they must sum to 1!)
|
|
61
65
|
// coin is a random number in [0, 1), usually from random_f32()
|
|
62
66
|
T cdf = 0.0;
|
|
63
|
-
for (
|
|
67
|
+
for (size_t i = 0; i < vocab_size_; i++) {
|
|
64
68
|
cdf += probabilities[i];
|
|
65
69
|
if (coin < cdf) {
|
|
66
70
|
return i;
|
|
@@ -84,7 +88,7 @@ int32_t Sampler::sample_topp(T *probabilities, float coin) {
|
|
|
84
88
|
std::make_unique<ProbIndex<T>[]>(vocab_size_);
|
|
85
89
|
|
|
86
90
|
const float cutoff = (1.0f - topp_) / (n - 1);
|
|
87
|
-
for (
|
|
91
|
+
for (size_t i = 0; i < n; i++) {
|
|
88
92
|
if (probabilities[i] >= cutoff) {
|
|
89
93
|
probindex[n0].index = i;
|
|
90
94
|
probindex[n0].prob = probabilities[i];
|
|
@@ -92,62 +96,147 @@ int32_t Sampler::sample_topp(T *probabilities, float coin) {
|
|
|
92
96
|
}
|
|
93
97
|
}
|
|
94
98
|
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
+
std::sort(probindex.get(), probindex.get() + n0,
|
|
100
|
+
[](const ProbIndex<T> &a, const ProbIndex<T> &b) {
|
|
101
|
+
return a.prob > b.prob;
|
|
102
|
+
});
|
|
99
103
|
|
|
100
104
|
// truncate the list where cumulative probability exceeds topp
|
|
101
105
|
T cumulative_prob = 0;
|
|
102
|
-
int last_idx = n0 - 1;
|
|
103
|
-
for (
|
|
106
|
+
int last_idx = n0 - 1;
|
|
107
|
+
for (size_t i = 0; i < n0; i++) {
|
|
104
108
|
cumulative_prob += probindex[i].prob;
|
|
105
|
-
if (cumulative_prob > topp_) {
|
|
109
|
+
if (static_cast<float>(cumulative_prob) > topp_) {
|
|
106
110
|
last_idx = i;
|
|
107
|
-
break;
|
|
111
|
+
break;
|
|
108
112
|
}
|
|
109
113
|
}
|
|
110
114
|
|
|
111
115
|
// sample from the truncated list
|
|
112
|
-
|
|
116
|
+
float r = coin * static_cast<float>(cumulative_prob);
|
|
113
117
|
T cdf = 0;
|
|
114
|
-
for (
|
|
118
|
+
for (size_t i = 0; i <= last_idx; i++) {
|
|
115
119
|
cdf += probindex[i].prob;
|
|
116
|
-
if (r < cdf) {
|
|
120
|
+
if (r < static_cast<float>(cdf)) {
|
|
117
121
|
return probindex[i].index;
|
|
118
122
|
}
|
|
119
123
|
}
|
|
120
|
-
return probindex[last_idx].index;
|
|
124
|
+
return probindex[last_idx].index;
|
|
121
125
|
}
|
|
122
126
|
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
127
|
+
// Mask logits outside the top-k by rank to -inf. Ties at the k-th boundary
|
|
128
|
+
// are kept (matches HuggingFace TopKLogitsWarper).
|
|
129
|
+
template <typename T> void Sampler::mask_topk(T *logits) {
|
|
130
|
+
if (topk_ <= 0 || topk_ >= vocab_size_) {
|
|
131
|
+
return;
|
|
132
|
+
}
|
|
133
|
+
// Partial-select the (topk_-th largest) threshold using nth_element on a
|
|
134
|
+
// copy of logits; O(n) average.
|
|
135
|
+
std::vector<T> scratch(logits, logits + vocab_size_);
|
|
136
|
+
std::nth_element(scratch.begin(), scratch.begin() + (topk_ - 1),
|
|
137
|
+
scratch.end(), std::greater<T>());
|
|
138
|
+
const T threshold = scratch[topk_ - 1];
|
|
139
|
+
constexpr T neg_inf = std::numeric_limits<T>::lowest();
|
|
140
|
+
for (size_t i = 0; i < vocab_size_; i++) {
|
|
141
|
+
if (logits[i] < threshold) {
|
|
142
|
+
logits[i] = neg_inf;
|
|
143
|
+
}
|
|
144
|
+
}
|
|
145
|
+
}
|
|
146
|
+
|
|
147
|
+
// Mask logits outside the top-p nucleus to -inf. Approximates the exact
|
|
148
|
+
// sort-based nucleus with a histogram over (logit - max): two O(n) passes, no
|
|
149
|
+
// sort. Binning in logit (not probability) space keeps uniform resolution for
|
|
150
|
+
// peaked and flat distributions alike. kRange=40 spans exp() down to ~4e-18.
|
|
151
|
+
template <typename T> void Sampler::mask_topp(T *logits) {
|
|
152
|
+
if (topp_ <= 0.0f || topp_ >= 1.0f) {
|
|
153
|
+
return;
|
|
154
|
+
}
|
|
155
|
+
constexpr int32_t kBins = 2048;
|
|
156
|
+
// Compute in a type at least as wide as T so converting logits never loses
|
|
157
|
+
// precision: double stays double, everything else (float and the narrow
|
|
158
|
+
// half/bf16/uint16 logit types) widens to float. Accumulating in T directly
|
|
159
|
+
// would be unsafe for bf16, whose mantissa saturates when summing exp()
|
|
160
|
+
// over the full vocab.
|
|
161
|
+
using acc_t = std::conditional_t<std::is_same_v<T, double>, double, float>;
|
|
162
|
+
constexpr acc_t kRange = 40;
|
|
163
|
+
|
|
164
|
+
std::span<const T> logit_span{logits, static_cast<size_t>(vocab_size_)};
|
|
165
|
+
const acc_t max_val =
|
|
166
|
+
static_cast<acc_t>(*std::ranges::max_element(logit_span));
|
|
167
|
+
|
|
168
|
+
std::vector<acc_t> bin_mass(kBins, acc_t(0));
|
|
169
|
+
acc_t total = 0;
|
|
170
|
+
for (size_t i = 0; i < vocab_size_; i++) {
|
|
171
|
+
acc_t d = static_cast<acc_t>(logits[i]) - max_val;
|
|
172
|
+
acc_t e = std::exp(d);
|
|
173
|
+
total += e;
|
|
174
|
+
int32_t bin = static_cast<int32_t>((d + kRange) / kRange * kBins);
|
|
175
|
+
bin = std::clamp(bin, 0, kBins - 1);
|
|
176
|
+
bin_mass[bin] += e;
|
|
177
|
+
}
|
|
178
|
+
if (total <= acc_t(0)) {
|
|
179
|
+
return;
|
|
180
|
+
}
|
|
181
|
+
|
|
182
|
+
// Highest bin downward until the kept mass reaches topp. The crossing bin is
|
|
183
|
+
// kept (HuggingFace "keep the token that crosses" convention).
|
|
184
|
+
const acc_t target = static_cast<acc_t>(topp_) * total;
|
|
185
|
+
acc_t acc = 0;
|
|
186
|
+
int32_t keep_bin = 0;
|
|
187
|
+
for (int32_t bin = kBins - 1; bin >= 0; --bin) {
|
|
188
|
+
acc += bin_mass[bin];
|
|
189
|
+
if (acc >= target) {
|
|
190
|
+
keep_bin = bin;
|
|
191
|
+
break;
|
|
192
|
+
}
|
|
193
|
+
}
|
|
194
|
+
const acc_t d_threshold =
|
|
195
|
+
static_cast<acc_t>(keep_bin) / kBins * kRange - kRange;
|
|
196
|
+
|
|
197
|
+
constexpr T neg_inf = std::numeric_limits<T>::lowest();
|
|
198
|
+
for (size_t i = 0; i < vocab_size_; i++) {
|
|
199
|
+
if (static_cast<acc_t>(logits[i]) - max_val < d_threshold) {
|
|
200
|
+
logits[i] = neg_inf;
|
|
201
|
+
}
|
|
202
|
+
}
|
|
203
|
+
}
|
|
204
|
+
|
|
205
|
+
Sampler::Sampler(int32_t vocab_size, GenerationConfig config,
|
|
206
|
+
unsigned long long rng_seed)
|
|
126
207
|
: vocab_size_(vocab_size),
|
|
127
|
-
inv_temperature_(
|
|
128
|
-
|
|
208
|
+
inv_temperature_(
|
|
209
|
+
(config.temperature != 0.0f) ? (1.0f / config.temperature) : 0.0f),
|
|
210
|
+
topp_(config.topp), min_p_(config.min_p),
|
|
211
|
+
repetition_penalty_(config.repetition_penalty), topk_(config.topk),
|
|
129
212
|
rng_state_(rng_seed) {}
|
|
130
213
|
|
|
131
|
-
Sampler::Sampler(
|
|
132
|
-
: Sampler(vocab_size,
|
|
214
|
+
Sampler::Sampler(int32_t vocab_size, GenerationConfig config)
|
|
215
|
+
: Sampler(vocab_size, config, std::time(nullptr)) {}
|
|
133
216
|
|
|
134
217
|
template <typename T> static void softmax(T *x, int size) {
|
|
135
|
-
//
|
|
218
|
+
// Runs after top-k/top-p masking, which sets rejected logits to lowest().
|
|
219
|
+
// Skip exp() on those: it underflows to 0 anyway and is slow on device.
|
|
220
|
+
constexpr T kMasked = std::numeric_limits<T>::lowest();
|
|
136
221
|
T max_val = x[0];
|
|
137
|
-
for (
|
|
222
|
+
for (size_t i = 1; i < size; i++) {
|
|
138
223
|
if (x[i] > max_val) {
|
|
139
224
|
max_val = x[i];
|
|
140
225
|
}
|
|
141
226
|
}
|
|
142
|
-
// exp and sum
|
|
143
227
|
T sum = 0;
|
|
144
|
-
for (
|
|
228
|
+
for (size_t i = 0; i < size; i++) {
|
|
229
|
+
if (x[i] == kMasked) {
|
|
230
|
+
x[i] = T(0);
|
|
231
|
+
continue;
|
|
232
|
+
}
|
|
145
233
|
x[i] = expf(x[i] - max_val);
|
|
146
234
|
sum += x[i];
|
|
147
235
|
}
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
236
|
+
for (size_t i = 0; i < size; i++) {
|
|
237
|
+
if (x[i] != T(0)) {
|
|
238
|
+
x[i] /= sum;
|
|
239
|
+
}
|
|
151
240
|
}
|
|
152
241
|
}
|
|
153
242
|
|
|
@@ -175,20 +264,18 @@ int32_t Sampler::sample(T *logits, const std::vector<uint64_t> &recent_tokens) {
|
|
|
175
264
|
apply_repetition_penalty(logits, vocab_size_, recent_tokens);
|
|
176
265
|
// 2. apply the temperature to the logits
|
|
177
266
|
apply_temperature(logits, vocab_size_);
|
|
178
|
-
// 3.
|
|
267
|
+
// 3. mask out logits outside top-k by rank (pre-softmax, becomes 0 mass)
|
|
268
|
+
mask_topk(logits);
|
|
269
|
+
// 4. mask out logits outside top-p by rank (pre-softmax)
|
|
270
|
+
mask_topp(logits);
|
|
271
|
+
// 5. apply softmax to the logits to get the probabilities for next token
|
|
179
272
|
softmax(logits, vocab_size_);
|
|
180
|
-
//
|
|
273
|
+
// 6. apply min_p truncation
|
|
181
274
|
apply_min_p(logits, vocab_size_);
|
|
182
275
|
// flip a (float) coin (this is our source of entropy for sampling)
|
|
183
276
|
float coin = random_f32(&rng_state_);
|
|
184
|
-
//
|
|
185
|
-
|
|
186
|
-
// simply sample from the predicted probability distribution
|
|
187
|
-
next = sample_mult(logits, coin);
|
|
188
|
-
} else {
|
|
189
|
-
// top-p (nucleus) sampling, clamping the least likely tokens to zero
|
|
190
|
-
next = sample_topp(logits, coin);
|
|
191
|
-
}
|
|
277
|
+
// 7. we sample from this distribution to get the next token
|
|
278
|
+
next = sample_mult(logits, coin);
|
|
192
279
|
}
|
|
193
280
|
return next;
|
|
194
281
|
}
|
package/common/runner/sampler.h
CHANGED
|
@@ -8,6 +8,7 @@
|
|
|
8
8
|
|
|
9
9
|
#pragma once
|
|
10
10
|
|
|
11
|
+
#include "runner/irunner.h"
|
|
11
12
|
#include <algorithm>
|
|
12
13
|
#include <cctype>
|
|
13
14
|
#include <cmath>
|
|
@@ -28,6 +29,7 @@ namespace executorch {
|
|
|
28
29
|
namespace extension {
|
|
29
30
|
namespace llm {
|
|
30
31
|
// A simple llama2 sampler.
|
|
32
|
+
struct GenerationConfig;
|
|
31
33
|
|
|
32
34
|
inline constexpr auto kTopp = 0.9f;
|
|
33
35
|
|
|
@@ -38,11 +40,13 @@ template <typename T> struct ProbIndex {
|
|
|
38
40
|
|
|
39
41
|
class Sampler {
|
|
40
42
|
public:
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
Sampler(int32_t vocab_size,
|
|
43
|
+
// topk <= 0 disables top-k filtering. topp <= 0 || topp >= 1 disables top-p.
|
|
44
|
+
// Pipeline when temperature != 0: temperature -> top-k mask -> top-p mask
|
|
45
|
+
// -> softmax -> multinomial. Note: topk == 1 with temperature != 0 collapses
|
|
46
|
+
// to greedy; pass topk = 0 to keep full-vocab temperature sampling.
|
|
47
|
+
Sampler(int32_t vocab_size, GenerationConfig config,
|
|
48
|
+
unsigned long long rng_seed);
|
|
49
|
+
Sampler(int32_t vocab_size, GenerationConfig config);
|
|
46
50
|
|
|
47
51
|
template <typename T> int32_t sample(T *logits);
|
|
48
52
|
|
|
@@ -53,6 +57,9 @@ private:
|
|
|
53
57
|
template <typename T> int32_t sample_topp(T *probabilities, float coin);
|
|
54
58
|
template <typename T> int32_t sample_mult(T *probabilities, float coin);
|
|
55
59
|
template <typename T> int32_t sample_argmax(T *probabilities);
|
|
60
|
+
// In-place logit warpers: set excluded indices to -inf.
|
|
61
|
+
template <typename T> void mask_topk(T *logits);
|
|
62
|
+
template <typename T> void mask_topp(T *logits);
|
|
56
63
|
|
|
57
64
|
template <typename T>
|
|
58
65
|
inline void apply_temperature(T *logits, int32_t vocab_size) {
|
|
@@ -110,6 +117,7 @@ private:
|
|
|
110
117
|
float topp_;
|
|
111
118
|
float min_p_;
|
|
112
119
|
float repetition_penalty_;
|
|
120
|
+
int32_t topk_;
|
|
113
121
|
unsigned long long rng_state_;
|
|
114
122
|
};
|
|
115
123
|
|
|
@@ -31,7 +31,6 @@ TextDecoderRunner::TextDecoderRunner(Module &module, IOManager *io_manager,
|
|
|
31
31
|
// outer loop (call site) is responsible for managing state.
|
|
32
32
|
::executorch::runtime::Result<executorch::aten::Tensor>
|
|
33
33
|
TextDecoderRunner::step(TensorPtr &tokens, int64_t start_pos) {
|
|
34
|
-
// ET_LOG(Info, "Input token %" PRIu64, input_token);
|
|
35
34
|
auto method_meta_result = module_->method_meta("forward");
|
|
36
35
|
if (!method_meta_result.ok()) {
|
|
37
36
|
return method_meta_result.error();
|
|
@@ -102,9 +101,7 @@ int32_t TextDecoderRunner::logits_to_token(
|
|
|
102
101
|
auto num_tokens = logits_tensor.size(1);
|
|
103
102
|
logits += (num_tokens - 1) * vocab_size;
|
|
104
103
|
}
|
|
105
|
-
Sampler sampler(vocab_size, config_
|
|
106
|
-
static_cast<unsigned long long>(std::time(nullptr)),
|
|
107
|
-
config_.min_p, config_.repetition_penalty);
|
|
104
|
+
Sampler sampler(vocab_size, config_);
|
|
108
105
|
result = sampler.sample(logits, recent_tokens);
|
|
109
106
|
});
|
|
110
107
|
return result;
|
|
@@ -10,6 +10,7 @@
|
|
|
10
10
|
|
|
11
11
|
#pragma once
|
|
12
12
|
|
|
13
|
+
#include "constants.h"
|
|
13
14
|
#include "io_manager.h"
|
|
14
15
|
#include "sampler.h"
|
|
15
16
|
|
|
@@ -40,8 +41,8 @@ public:
|
|
|
40
41
|
step(TensorPtr &input, int64_t start_pos);
|
|
41
42
|
|
|
42
43
|
/**
|
|
43
|
-
* Load the Module for text decode purpose.
|
|
44
|
-
*
|
|
44
|
+
* Load the Module for text decode purpose. Loads the dynamic-shape `forward`
|
|
45
|
+
* method used for both prefill and decode.
|
|
45
46
|
*/
|
|
46
47
|
virtual ::executorch::runtime::Error load() {
|
|
47
48
|
return module_->load_method("forward");
|
|
@@ -18,10 +18,11 @@ namespace llm {
|
|
|
18
18
|
|
|
19
19
|
TextPrefiller::TextPrefiller(TextDecoderRunner *text_decoder_runner,
|
|
20
20
|
bool use_kv_cache, bool enable_parallel_prefill,
|
|
21
|
-
int64_t max_seq_len)
|
|
21
|
+
int64_t max_seq_len, int32_t prefill_chunk_size)
|
|
22
22
|
: text_decoder_runner_(text_decoder_runner), use_kv_cache_(use_kv_cache),
|
|
23
23
|
enable_parallel_prefill_(enable_parallel_prefill),
|
|
24
|
-
max_seq_len_(max_seq_len > 0 ? max_seq_len : 128)
|
|
24
|
+
max_seq_len_(max_seq_len > 0 ? max_seq_len : 128),
|
|
25
|
+
prefill_chunk_size_(prefill_chunk_size) {}
|
|
25
26
|
|
|
26
27
|
::executorch::runtime::Result<uint64_t>
|
|
27
28
|
TextPrefiller::prefill(std::vector<uint64_t> &prompt_tokens,
|
|
@@ -31,17 +32,17 @@ TextPrefiller::prefill(std::vector<uint64_t> &prompt_tokens,
|
|
|
31
32
|
ET_CHECK_OK_OR_RETURN_ERROR(text_decoder_runner_->load());
|
|
32
33
|
}
|
|
33
34
|
|
|
34
|
-
// Check if we need to chunk the prompt tokens
|
|
35
35
|
int32_t num_prompt_tokens = prompt_tokens.size();
|
|
36
|
+
int32_t chunk_size =
|
|
37
|
+
prefill_chunk_size_ > 0 ? prefill_chunk_size_ : max_seq_len_;
|
|
36
38
|
|
|
37
|
-
|
|
38
|
-
if (num_prompt_tokens > max_seq_len_) {
|
|
39
|
+
if (num_prompt_tokens > chunk_size) {
|
|
39
40
|
uint64_t cur_token = 0;
|
|
40
41
|
int num_tokens_to_process = 0;
|
|
41
42
|
|
|
42
43
|
while (num_tokens_to_process < num_prompt_tokens) {
|
|
43
|
-
auto num_tokens_to_prefill_with =
|
|
44
|
-
num_prompt_tokens - num_tokens_to_process,
|
|
44
|
+
auto num_tokens_to_prefill_with =
|
|
45
|
+
std::min<int>(num_prompt_tokens - num_tokens_to_process, chunk_size);
|
|
45
46
|
|
|
46
47
|
std::vector<uint64_t> prompt_tokens_to_process(
|
|
47
48
|
num_tokens_to_prefill_with);
|
|
@@ -75,7 +76,6 @@ TextPrefiller::prefill_chunk(std::vector<uint64_t> &prompt_tokens,
|
|
|
75
76
|
// store the token
|
|
76
77
|
uint64_t cur_token;
|
|
77
78
|
if (enable_parallel_prefill_ || !use_kv_cache_) {
|
|
78
|
-
// initialize tensor wrappers
|
|
79
79
|
auto tokens = from_blob(prompt_tokens.data(), {1, num_prompt_tokens},
|
|
80
80
|
executorch::aten::ScalarType::Long);
|
|
81
81
|
|
|
@@ -19,8 +19,14 @@ namespace llm {
|
|
|
19
19
|
|
|
20
20
|
class TextPrefiller {
|
|
21
21
|
public:
|
|
22
|
+
// prefill_chunk_size: when > 0, the prompt is always processed in steps of
|
|
23
|
+
// this size (see prefill()). Set to the model's forward sequence-length cap
|
|
24
|
+
// for the MLX backend (its forward is exported with a sliding-window bound
|
|
25
|
+
// and one-shot prefill spikes Metal memory). Other backends (XNNPACK/CoreML)
|
|
26
|
+
// pass 0 → original one-shot behavior.
|
|
22
27
|
TextPrefiller(TextDecoderRunner *text_decoder_runner, bool use_kv_cache,
|
|
23
|
-
bool enable_parallel_prefill, int64_t max_seq_len = 128
|
|
28
|
+
bool enable_parallel_prefill, int64_t max_seq_len = 128,
|
|
29
|
+
int32_t prefill_chunk_size = 0);
|
|
24
30
|
|
|
25
31
|
virtual ~TextPrefiller() = default;
|
|
26
32
|
/**
|
|
@@ -70,6 +76,7 @@ private:
|
|
|
70
76
|
bool use_kv_cache_;
|
|
71
77
|
bool enable_parallel_prefill_;
|
|
72
78
|
int64_t max_seq_len_;
|
|
79
|
+
int32_t prefill_chunk_size_;
|
|
73
80
|
};
|
|
74
81
|
|
|
75
82
|
} // namespace llm
|
|
@@ -26,11 +26,24 @@ Error TextRunner::load_subcomponents() {
|
|
|
26
26
|
|
|
27
27
|
Stats *stats_ptr = &stats_;
|
|
28
28
|
|
|
29
|
-
text_decoder_runner_ =
|
|
30
|
-
*module_, io_manager_.get(), config_);
|
|
29
|
+
text_decoder_runner_ =
|
|
30
|
+
std::make_unique<TextDecoderRunner>(*module_, io_manager_.get(), config_);
|
|
31
|
+
|
|
32
|
+
int32_t prefill_chunk_size = 0;
|
|
33
|
+
auto fwd_meta = module_->method_meta("forward");
|
|
34
|
+
if (fwd_meta.ok() && fwd_meta->uses_backend("MLXBackend")) {
|
|
35
|
+
auto input_meta = fwd_meta->input_tensor_meta(0);
|
|
36
|
+
if (input_meta.ok()) {
|
|
37
|
+
auto sizes = input_meta->sizes();
|
|
38
|
+
if (sizes.size() >= 2 && sizes[sizes.size() - 1] > 0) {
|
|
39
|
+
prefill_chunk_size = sizes[sizes.size() - 1];
|
|
40
|
+
}
|
|
41
|
+
}
|
|
42
|
+
}
|
|
43
|
+
|
|
31
44
|
text_prefiller_ = std::make_unique<TextPrefiller>(
|
|
32
45
|
text_decoder_runner_.get(), config_.enable_kv_cache,
|
|
33
|
-
config_.enable_dynamic_shape, config_.max_seq_len);
|
|
46
|
+
config_.enable_dynamic_shape, config_.max_seq_len, prefill_chunk_size);
|
|
34
47
|
text_token_generator_ = std::make_unique<TextTokenGenerator>(
|
|
35
48
|
tokenizer_.get(), text_decoder_runner_.get(), config_.enable_kv_cache,
|
|
36
49
|
std::move(eos_ids_), stats_ptr, config_);
|
|
@@ -65,6 +78,10 @@ Error TextRunner::generate_internal(
|
|
|
65
78
|
|
|
66
79
|
stats_.inference_start_ms = time_in_ms();
|
|
67
80
|
|
|
81
|
+
// Multi-turn: JS re-renders the full chat history each call, so reset KV
|
|
82
|
+
// position to 0 and re-prefill from scratch.
|
|
83
|
+
pos_ = 0;
|
|
84
|
+
|
|
68
85
|
int64_t context_len_left =
|
|
69
86
|
static_cast<int64_t>(config_.max_context_length) - pos_;
|
|
70
87
|
|
|
@@ -79,16 +96,25 @@ Error TextRunner::generate_internal(
|
|
|
79
96
|
std::vector<uint64_t> prompt_tokens = encodeResult.get();
|
|
80
97
|
int num_prompt_tokens = prompt_tokens.size();
|
|
81
98
|
|
|
99
|
+
// For dynamic-shape PTEs (e.g. Gemma4 MLX/Vulkan), get_max_seq_len is the
|
|
100
|
+
// per-call decoder chunk size (e.g. the sliding window) and the real
|
|
101
|
+
// generation budget lives in get_max_context_len. Static-shape PTEs set both
|
|
102
|
+
// equal, so this collapses to the old behavior. Without this the budget is
|
|
103
|
+
// computed from the small chunk size, so max_new_tokens can resolve to ~0 and
|
|
104
|
+
// generation ends immediately after prefill.
|
|
105
|
+
const int32_t seq_cap = config_.enable_dynamic_shape
|
|
106
|
+
? config_.max_context_length
|
|
107
|
+
: config_.max_seq_len;
|
|
108
|
+
|
|
82
109
|
ET_CHECK_OR_RETURN_ERROR(num_prompt_tokens >= 1, InvalidArgument,
|
|
83
110
|
"Expected at least 1 prompt token");
|
|
84
|
-
ET_CHECK_OR_RETURN_ERROR(num_prompt_tokens <
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
num_prompt_tokens, config_.max_seq_len);
|
|
111
|
+
ET_CHECK_OR_RETURN_ERROR(num_prompt_tokens < seq_cap, InvalidArgument,
|
|
112
|
+
"num_prompt_tokens %d >= seq cap %" PRId32,
|
|
113
|
+
num_prompt_tokens, seq_cap);
|
|
88
114
|
|
|
89
115
|
int32_t max_new_tokens = resolve_max_new_tokens(
|
|
90
|
-
num_prompt_tokens,
|
|
91
|
-
|
|
116
|
+
num_prompt_tokens, seq_cap, static_cast<int32_t>(context_len_left),
|
|
117
|
+
config_.max_new_tokens);
|
|
92
118
|
|
|
93
119
|
ET_CHECK_OR_RETURN_ERROR(max_new_tokens > 0, InvalidArgument,
|
|
94
120
|
"Max new tokens %d is <= 0", max_new_tokens);
|
|
@@ -100,8 +100,8 @@ public:
|
|
|
100
100
|
prev_token = cur_token;
|
|
101
101
|
|
|
102
102
|
stats_->on_sampling_begin();
|
|
103
|
-
cur_token =
|
|
104
|
-
|
|
103
|
+
cur_token = text_decoder_runner_->logits_to_token(logits_tensor,
|
|
104
|
+
generated_tokens);
|
|
105
105
|
stats_->on_sampling_end();
|
|
106
106
|
|
|
107
107
|
pos++;
|
|
@@ -152,7 +152,6 @@ public:
|
|
|
152
152
|
if (should_stop_) {
|
|
153
153
|
break;
|
|
154
154
|
}
|
|
155
|
-
|
|
156
155
|
// data-dependent terminating condition: we have n_eos_ number of EOS
|
|
157
156
|
if (eos_ids_->find(cur_token) != eos_ids_->end()) {
|
|
158
157
|
printf("\n");
|
package/common/runner/util.h
CHANGED
|
@@ -6,7 +6,7 @@ import { SlidingWindowContextStrategy } from '../utils/llms/context_strategy';
|
|
|
6
6
|
* Default system prompt used to guide the behavior of Large Language Models (LLMs).
|
|
7
7
|
* @category Utilities - LLM
|
|
8
8
|
*/
|
|
9
|
-
export const DEFAULT_SYSTEM_PROMPT = "You are a knowledgeable, efficient, and direct AI assistant. Provide concise answers, focusing on the key information needed. Offer suggestions tactfully when appropriate to improve outcomes. Engage in productive collaboration with the user. Don't return too much text.";
|
|
9
|
+
export const DEFAULT_SYSTEM_PROMPT = "You are a knowledgeable, efficient, and direct AI assistant. Provide concise answers, focusing on the key information needed. Offer suggestions tactfully when appropriate to improve outcomes. Engage in productive collaboration with the user. Don't return too much text. If provided with audio samples treat it with at most importance";
|
|
10
10
|
|
|
11
11
|
/**
|
|
12
12
|
* Generates a default structured output prompt based on the provided JSON schema.
|
|
@@ -1 +1 @@
|
|
|
1
|
-
{"version":3,"names":["SlidingWindowContextStrategy","DEFAULT_SYSTEM_PROMPT","DEFAULT_STRUCTURED_OUTPUT_PROMPT","structuredOutputSchema","DEFAULT_MESSAGE_HISTORY","DEFAULT_CONTEXT_BUFFER_TOKENS","DEFAULT_CHAT_CONFIG","systemPrompt","initialMessageHistory","contextStrategy"],"sourceRoot":"../../../src","sources":["constants/llmDefaults.ts"],"mappings":";;AACA,SAASA,4BAA4B,QAAQ,gCAAgC;;AAE7E;AACA;AACA;AACA;AACA,OAAO,MAAMC,qBAAqB,GAChC,+
|
|
1
|
+
{"version":3,"names":["SlidingWindowContextStrategy","DEFAULT_SYSTEM_PROMPT","DEFAULT_STRUCTURED_OUTPUT_PROMPT","structuredOutputSchema","DEFAULT_MESSAGE_HISTORY","DEFAULT_CONTEXT_BUFFER_TOKENS","DEFAULT_CHAT_CONFIG","systemPrompt","initialMessageHistory","contextStrategy"],"sourceRoot":"../../../src","sources":["constants/llmDefaults.ts"],"mappings":";;AACA,SAASA,4BAA4B,QAAQ,gCAAgC;;AAE7E;AACA;AACA;AACA;AACA,OAAO,MAAMC,qBAAqB,GAChC,+UAA+U;;AAEjV;AACA;AACA;AACA;AACA;AACA;AACA,OAAO,MAAMC,gCAAgC,GAC3CC,sBAA8B,IAC3B;AACL;AACA;AACA;AACA;AACA;AACA,EAAEA,sBAAsB;AACxB,CAAC;;AAED;AACA;AACA;AACA;AACA,OAAO,MAAMC,uBAAkC,GAAG,EAAE;;AAEpD;AACA;AACA;AACA;AACA,OAAO,MAAMC,6BAA6B,GAAG,GAAG;;AAEhD;AACA;AACA;AACA;AACA,OAAO,MAAMC,mBAA+B,GAAG;EAC7CC,YAAY,EAAEN,qBAAqB;EACnCO,qBAAqB,EAAEJ,uBAAuB;EAC9CK,eAAe,EAAE,IAAIT,4BAA4B,CAC/CK,6BACF;AACF,CAAC","ignoreList":[]}
|