react-native-executorch 0.8.2 → 0.8.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. package/android/src/main/java/com/swmansion/rnexecutorch/ETInstallerUnavailable.kt +27 -0
  2. package/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt +12 -1
  3. package/common/rnexecutorch/host_objects/ModelHostObject.h +12 -1
  4. package/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp +6 -0
  5. package/common/rnexecutorch/models/embeddings/text/TextEmbeddings.h +3 -0
  6. package/common/rnexecutorch/models/llm/LLM.cpp +31 -3
  7. package/common/rnexecutorch/models/llm/LLM.h +2 -0
  8. package/common/rnexecutorch/models/text_to_image/TextToImage.cpp +2 -0
  9. package/common/rnexecutorch/models/text_to_image/TextToImage.h +2 -0
  10. package/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.cpp +6 -0
  11. package/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.h +5 -0
  12. package/common/runner/base_llm_runner.cpp +8 -6
  13. package/common/runner/base_llm_runner.h +8 -4
  14. package/common/runner/encoders/vision_encoder.cpp +12 -4
  15. package/common/runner/irunner.h +15 -0
  16. package/common/runner/multimodal_decoder_runner.h +3 -2
  17. package/common/runner/multimodal_runner.cpp +4 -16
  18. package/common/runner/multimodal_runner.h +0 -4
  19. package/common/runner/sampler.cpp +32 -13
  20. package/common/runner/sampler.h +59 -1
  21. package/common/runner/text_decoder_runner.cpp +31 -3
  22. package/common/runner/text_decoder_runner.h +13 -46
  23. package/common/runner/text_runner.cpp +3 -26
  24. package/common/runner/text_runner.h +0 -4
  25. package/common/runner/text_token_generator.h +20 -18
  26. package/lib/module/constants/modelUrls.js +53 -10
  27. package/lib/module/constants/modelUrls.js.map +1 -1
  28. package/lib/module/controllers/LLMController.js +75 -22
  29. package/lib/module/controllers/LLMController.js.map +1 -1
  30. package/lib/module/hooks/natural_language_processing/useLLM.js +1 -0
  31. package/lib/module/hooks/natural_language_processing/useLLM.js.map +1 -1
  32. package/lib/module/index.js +11 -0
  33. package/lib/module/index.js.map +1 -1
  34. package/lib/module/modules/natural_language_processing/LLMModule.js +1 -1
  35. package/lib/module/types/llm.js +4 -1
  36. package/lib/module/types/llm.js.map +1 -1
  37. package/lib/typescript/constants/modelUrls.d.ts +126 -0
  38. package/lib/typescript/constants/modelUrls.d.ts.map +1 -1
  39. package/lib/typescript/controllers/LLMController.d.ts +3 -1
  40. package/lib/typescript/controllers/LLMController.d.ts.map +1 -1
  41. package/lib/typescript/index.d.ts +7 -0
  42. package/lib/typescript/index.d.ts.map +1 -1
  43. package/lib/typescript/modules/natural_language_processing/LLMModule.d.ts +1 -1
  44. package/lib/typescript/types/llm.d.ts +21 -1
  45. package/lib/typescript/types/llm.d.ts.map +1 -1
  46. package/package.json +1 -1
  47. package/src/constants/modelUrls.ts +45 -2
  48. package/src/controllers/LLMController.ts +84 -25
  49. package/src/hooks/natural_language_processing/useLLM.ts +1 -0
  50. package/src/index.ts +11 -0
  51. package/src/modules/natural_language_processing/LLMModule.ts +1 -1
  52. package/src/types/llm.ts +21 -1
package/android/src/main/java/com/swmansion/rnexecutorch/ETInstallerUnavailable.kt
@@ -0,0 +1,27 @@
+ package com.swmansion.rnexecutorch
+
+ import com.facebook.react.bridge.ReactApplicationContext
+ import com.facebook.react.bridge.ReactMethod
+ import com.facebook.react.common.annotations.FrameworkAPI
+ import com.facebook.react.module.annotations.ReactModule
+
+ /**
+  * Fallback TurboModule returned when native ExecuTorch libraries cannot be
+  * loaded (e.g. 32-bit Android devices where only arm64-v8a binaries are
+  * shipped). Extends the same spec as ETInstaller so JS sees a real linked
+  * module, but install() returns false to signal unavailability.
+  */
+ @OptIn(FrameworkAPI::class)
+ @ReactModule(name = ETInstallerUnavailable.NAME)
+ class ETInstallerUnavailable(
+   reactContext: ReactApplicationContext,
+ ) : NativeETInstallerSpec(reactContext) {
+   companion object {
+     const val NAME = NativeETInstallerSpec.NAME
+   }
+
+   @ReactMethod(isBlockingSynchronousMethod = true)
+   override fun install(): Boolean {
+     return false
+   }
+ }
package/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt
@@ -15,7 +15,18 @@ class RnExecutorchPackage : TurboReactPackage() {
    reactContext: ReactApplicationContext,
  ): NativeModule? =
    if (name == ETInstaller.NAME) {
-     ETInstaller(reactContext)
+     try {
+       ETInstaller(reactContext)
+     } catch (e: RuntimeException) {
+       if (e.cause is UnsatisfiedLinkError) {
+         // Native library not available (e.g. 32-bit device without arm64-v8a .so).
+         // Return a fallback module whose install() returns false so JS can
+         // distinguish "unsupported ABI" from "package not linked."
+         ETInstallerUnavailable(reactContext)
+       } else {
+         throw e
+       }
+     }
    } else {
      null
    }
package/common/rnexecutorch/host_objects/ModelHostObject.h
@@ -140,6 +140,15 @@ public:
                                     synchronousHostFunction<&Model::setTopp>,
                                     "setTopp"));

+   addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
+                                    synchronousHostFunction<&Model::setMinP>,
+                                    "setMinP"));
+
+   addFunctions(JSI_EXPORT_FUNCTION(
+       ModelHostObject<Model>,
+       synchronousHostFunction<&Model::setRepetitionPenalty>,
+       "setRepetitionPenalty"));
+
    addFunctions(JSI_EXPORT_FUNCTION(
        ModelHostObject<Model>,
        synchronousHostFunction<&Model::getMaxContextLength>,

@@ -375,7 +384,9 @@ public:
        // We need to dispatch a thread if we want the function to be
        // asynchronous. In this thread all accesses to jsi::Runtime need to
        // be done via the callInvoker.
-       threads::GlobalThreadPool::detach([this, promise,
+       threads::GlobalThreadPool::detach([model = this->model,
+                                          callInvoker = this->callInvoker,
+                                          promise,
                                           argsConverted =
                                               std::move(argsConverted)]() {
          try {
package/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp
@@ -35,8 +35,14 @@ TokenIdsWithAttentionMask TextEmbeddings::preprocess(const std::string &input) {
    return {.inputIds = inputIds64, .attentionMask = attentionMask};
  }

+ void TextEmbeddings::unload() noexcept {
+   std::scoped_lock lock(inference_mutex_);
+   BaseModel::unload();
+ }
+
  std::shared_ptr<OwningArrayBuffer>
  TextEmbeddings::generate(const std::string input) {
+   std::scoped_lock lock(inference_mutex_);
    auto preprocessed = preprocess(input);

    std::vector<int32_t> tokenIdsShape = {
package/common/rnexecutorch/models/embeddings/text/TextEmbeddings.h
@@ -1,6 +1,7 @@
  #pragma once

  #include "rnexecutorch/metaprogramming/ConstructorHelpers.h"
+ #include <mutex>
  #include <rnexecutorch/TokenizerModule.h>
  #include <rnexecutorch/models/embeddings/BaseEmbeddings.h>

@@ -20,8 +21,10 @@ public:
    [[nodiscard(
        "Registered non-void function")]] std::shared_ptr<OwningArrayBuffer>
    generate(const std::string input);
+   void unload() noexcept;

  private:
+   mutable std::mutex inference_mutex_;
    std::vector<std::vector<int32_t>> inputShapes;
    TokenIdsWithAttentionMask preprocess(const std::string &input);
    std::unique_ptr<TokenizerModule> tokenizer;
package/common/rnexecutorch/models/llm/LLM.cpp
@@ -20,7 +20,7 @@ using executorch::runtime::Error;
  LLM::LLM(const std::string &modelSource, const std::string &tokenizerSource,
           std::vector<std::string> capabilities,
           std::shared_ptr<react::CallInvoker> callInvoker)
-     : BaseModel(modelSource, callInvoker, Module::LoadMode::File) {
+     : BaseModel(modelSource, callInvoker, Module::LoadMode::Mmap) {

    if (capabilities.empty()) {
      runner_ =

@@ -42,8 +42,12 @@ LLM::LLM(const std::string &modelSource, const std::string &tokenizerSource,
      throw RnExecutorchError(loadResult, "Failed to load LLM runner");
    }

-   memorySizeLowerBound = fs::file_size(fs::path(modelSource)) +
-                          fs::file_size(fs::path(tokenizerSource));
+   // I am purposefully not adding file size of the model here. The reason is
+   // that Hermes would crash the app if we try to alloc too much memory here.
+   // Also, given we're using mmap, the true memory consumption of a model is not
+   // really equal to the size of the model. The size of the tokenizer file is a
+   // hint to the GC that this object might be worth getting rid of.
+   memorySizeLowerBound = fs::file_size(fs::path(tokenizerSource));
  }

  std::string LLM::generate(std::string input,

@@ -246,6 +250,30 @@ void LLM::setTopp(float topp) {
    runner_->set_topp(topp);
  }

+ void LLM::setMinP(float minP) {
+   if (!runner_ || !runner_->is_loaded()) {
+     throw RnExecutorchError(RnExecutorchErrorCode::ModuleNotLoaded,
+                             "Can't configure a model that's not loaded");
+   }
+   if (minP < 0.0f || minP > 1.0f) {
+     throw RnExecutorchError(RnExecutorchErrorCode::InvalidConfig,
+                             "Min-p must be between 0.0 and 1.0");
+   }
+   runner_->set_min_p(minP);
+ }
+
+ void LLM::setRepetitionPenalty(float repetitionPenalty) {
+   if (!runner_ || !runner_->is_loaded()) {
+     throw RnExecutorchError(RnExecutorchErrorCode::ModuleNotLoaded,
+                             "Can't configure a model that's not loaded");
+   }
+   if (repetitionPenalty < 0.0f) {
+     throw RnExecutorchError(RnExecutorchErrorCode::InvalidConfig,
+                             "Repetition penalty must be non-negative");
+   }
+   runner_->set_repetition_penalty(repetitionPenalty);
+ }
+
  int32_t LLM::getMaxContextLength() const {
    if (!runner_ || !runner_->is_loaded()) {
      throw RnExecutorchError(
package/common/rnexecutorch/models/llm/LLM.h
@@ -38,6 +38,8 @@ public:
    void setCountInterval(size_t countInterval);
    void setTemperature(float temperature);
    void setTopp(float topp);
+   void setMinP(float minP);
+   void setRepetitionPenalty(float repetitionPenalty);
    void setTimeInterval(size_t timeInterval);
    int32_t getMaxContextLength() const;

package/common/rnexecutorch/models/text_to_image/TextToImage.cpp
@@ -58,6 +58,7 @@ std::shared_ptr<OwningArrayBuffer>
  TextToImage::generate(std::string input, int32_t imageSize,
                        size_t numInferenceSteps, int32_t seed,
                        std::shared_ptr<jsi::Function> callback) {
+   std::scoped_lock lock(inference_mutex_);
    setImageSize(imageSize);
    setSeed(seed);

@@ -137,6 +138,7 @@ size_t TextToImage::getMemoryLowerBound() const noexcept {
  }

  void TextToImage::unload() noexcept {
+   std::scoped_lock lock(inference_mutex_);
    encoder->unload();
    unet->unload();
    decoder->unload();
package/common/rnexecutorch/models/text_to_image/TextToImage.h
@@ -1,6 +1,7 @@
  #pragma once

  #include <memory>
+ #include <mutex>
  #include <string>
  #include <vector>

@@ -49,6 +50,7 @@ private:
    static constexpr float guidanceScale = 7.5f;
    static constexpr float latentsScale = 0.18215f;
    bool interrupted = false;
+   mutable std::mutex inference_mutex_;

    std::shared_ptr<react::CallInvoker> callInvoker;
    std::unique_ptr<Scheduler> scheduler;
package/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.cpp
@@ -54,8 +54,14 @@ VoiceActivityDetection::preprocess(std::span<float> waveform) const {
    return frameBuffer;
  }

+ void VoiceActivityDetection::unload() noexcept {
+   std::scoped_lock lock(inference_mutex_);
+   BaseModel::unload();
+ }
+
  std::vector<types::Segment>
  VoiceActivityDetection::generate(std::span<float> waveform) const {
+   std::scoped_lock lock(inference_mutex_);

    auto windowedInput = preprocess(waveform);
    auto [chunksNumber, remainder] = std::div(
package/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.h
@@ -5,6 +5,7 @@
  #include <executorch/extension/tensor/tensor.h>
  #include <executorch/extension/tensor/tensor_ptr.h>
  #include <executorch/runtime/core/evalue.h>
+ #include <mutex>
  #include <span>

  #include "rnexecutorch/metaprogramming/ConstructorHelpers.h"

@@ -23,7 +24,11 @@ public:
    [[nodiscard("Registered non-void function")]] std::vector<types::Segment>
    generate(std::span<float> waveform) const;

+   void unload() noexcept;
+
  private:
+   mutable std::mutex inference_mutex_;
+
    std::vector<std::array<float, constants::kPaddedWindowSize>>
    preprocess(std::span<float> waveform) const;
    std::vector<types::Segment> postprocess(const std::vector<float> &scores,
package/common/runner/base_llm_runner.cpp
@@ -139,20 +139,22 @@ int32_t BaseLLMRunner::get_max_context_length() const {

  void BaseLLMRunner::set_temperature(float temperature) noexcept {
    config_.temperature = temperature;
-   set_temperature_impl(temperature);
  }

- void BaseLLMRunner::set_topp(float topp) noexcept {
-   config_.topp = topp;
-   set_topp_impl(topp);
+ void BaseLLMRunner::set_topp(float topp) noexcept { config_.topp = topp; }
+
+ void BaseLLMRunner::set_min_p(float min_p) noexcept { config_.min_p = min_p; }
+
+ void BaseLLMRunner::set_repetition_penalty(float repetition_penalty) noexcept {
+   config_.repetition_penalty = repetition_penalty;
  }

  void BaseLLMRunner::set_count_interval(size_t count_interval) {
-   set_count_interval_impl(count_interval);
+   config_.output_token_batch_size = count_interval;
  }

  void BaseLLMRunner::set_time_interval(size_t time_interval) {
-   set_time_interval_impl(time_interval);
+   config_.batch_time_interval_ms = time_interval;
  }

  int32_t BaseLLMRunner::resolve_max_new_tokens(int32_t num_prompt_tokens,
package/common/runner/base_llm_runner.h
@@ -53,6 +53,8 @@ public:

    void set_temperature(float temperature) noexcept;
    void set_topp(float topp) noexcept;
+   void set_min_p(float min_p) noexcept;
+   void set_repetition_penalty(float repetition_penalty) noexcept;
    void set_count_interval(size_t count_interval);
    void set_time_interval(size_t time_interval);

@@ -65,10 +67,12 @@ public:
  protected:
    virtual ::executorch::runtime::Error load_subcomponents() = 0;
    virtual void stop_impl() = 0;
-   virtual void set_temperature_impl(float temperature) = 0;
-   virtual void set_topp_impl(float topp) = 0;
-   virtual void set_count_interval_impl(size_t count_interval) = 0;
-   virtual void set_time_interval_impl(size_t time_interval) = 0;
+   // Sampling values and token-batching intervals live entirely in `config_`.
+   // The TextDecoderRunner / TextTokenGenerator shared by both TextRunner and
+   // MultimodalRunner are constructed with a const reference to `config_`
+   // and read those fields on every iteration, so writes via the public
+   // set_* methods on BaseLLMRunner take effect immediately with no virtual
+   // dispatch needed.

    int32_t resolve_max_new_tokens(int32_t num_prompt_tokens, int32_t max_seq_len,
                                   int32_t max_context_len,
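Note: the const-reference pattern the comment above describes can be illustrated with a minimal standalone sketch (hypothetical names, not the package's actual classes). The generator is constructed once against the runner's config object, so a later write through a setter is visible on the generator's next read with no virtual dispatch:

#include <cstdio>

// Hypothetical stand-ins for GenerationConfig and TextTokenGenerator.
struct Config {
  float temperature = 0.7f;
};

class Generator {
public:
  explicit Generator(const Config &config) : config_(config) {}
  // Reads the shared config on every step, as the real generator does
  // on every decode iteration.
  void step() const { std::printf("temperature = %.2f\n", config_.temperature); }

private:
  const Config &config_; // const reference into the runner's config_
};

int main() {
  Config config;
  Generator generator(config); // constructed once, before any setter runs
  generator.step();            // temperature = 0.70
  config.temperature = 0.2f;   // what a public set_temperature() would do
  generator.step();            // temperature = 0.20 -- the write is visible
  return 0;
}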
package/common/runner/encoders/vision_encoder.cpp
@@ -77,15 +77,23 @@ Result<VisionEncoder::ImageShape> VisionEncoder::getInputShape() const {
  std::vector<float>
  VisionEncoder::preprocessImage(const std::string &path,
                                 const ImageShape &targetShape) const {
-   cv::Mat mat = rnexecutorch::image_processing::readImage(path);
-   cv::resize(mat, mat, cv::Size(targetShape.width, targetShape.height));
-   cv::cvtColor(mat, mat, cv::COLOR_BGR2RGB);
+   // The bundled vision-encoder PTEs (e.g. LFM2.5-VL) bake rescale + normalize
+   // into the exported graph, so we hand raw 0-255 float pixel values to the
+   // module. Adding rescale / normalize here would double-apply the transform
+   // and destroy the input distribution. We reuse `resizePadded` for the
+   // aspect-ratio-preserving letterbox (it picks the pad colour from the
+   // source image corners, which blends better than a flat gray), then
+   // convert BGR->RGB and repack the raw pixels into CHW float.
+   cv::Mat src = rnexecutorch::image_processing::readImage(path);
+   cv::Mat canvas = rnexecutorch::image_processing::resizePadded(
+       src, cv::Size(targetShape.width, targetShape.height));
+   cv::cvtColor(canvas, canvas, cv::COLOR_BGR2RGB);

    const int32_t pixelCount = targetShape.height * targetShape.width;
    std::vector<float> chw(targetShape.channels * pixelCount);
    for (int32_t i = 0; i < pixelCount; ++i) {
      cv::Vec3b px =
-         mat.at<cv::Vec3b>(i / targetShape.width, i % targetShape.width);
+         canvas.at<cv::Vec3b>(i / targetShape.width, i % targetShape.width);
      for (int32_t c = 0; c < targetShape.channels; ++c) {
        chw[c * pixelCount + i] = static_cast<float>(px[c]);
      }
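For context, letterboxing means scaling the image to fit the target while preserving aspect ratio, then padding the remainder. A self-contained OpenCV sketch of the general technique follows; it pads with flat gray via cv::copyMakeBorder rather than the package's resizePadded helper (which samples the pad colour from the source corners), and keeps raw 0-255 values under the same assumption that rescale/normalize are baked into the exported graph:

#include <algorithm>
#include <vector>
#include <opencv2/opencv.hpp>

// Letterbox `src` into a width x height canvas and repack HWC uint8 -> CHW
// float, keeping raw 0-255 values (no host-side rescale/normalize).
std::vector<float> letterboxToCHW(const cv::Mat &src, int width, int height) {
  // Scale so the image fits inside the target while preserving aspect ratio.
  const double scale = std::min(width / static_cast<double>(src.cols),
                                height / static_cast<double>(src.rows));
  cv::Mat resized;
  cv::resize(src, resized, cv::Size(), scale, scale);

  // Pad the remainder with flat gray (resizePadded instead samples the
  // source corners for a better blend).
  const int padX = width - resized.cols;
  const int padY = height - resized.rows;
  cv::Mat canvas;
  cv::copyMakeBorder(resized, canvas, padY / 2, padY - padY / 2, padX / 2,
                     padX - padX / 2, cv::BORDER_CONSTANT,
                     cv::Scalar(127, 127, 127));
  cv::cvtColor(canvas, canvas, cv::COLOR_BGR2RGB);

  // HWC uint8 -> CHW float, raw 0-255 values.
  const int pixelCount = width * height;
  std::vector<float> chw(3 * pixelCount);
  for (int i = 0; i < pixelCount; ++i) {
    const cv::Vec3b px = canvas.at<cv::Vec3b>(i / width, i % width);
    for (int c = 0; c < 3; ++c) {
      chw[c * pixelCount + i] = static_cast<float>(px[c]);
    }
  }
  return chw;
}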
package/common/runner/irunner.h
@@ -58,6 +58,21 @@ struct GenerationConfig {
    // = more deterministic, higher = more diverse generations.
    float topp = -1.F;

+   // Minimum probability threshold: tokens with prob < min_p * max_prob are
+   // excluded. 0.0 disables min_p filtering.
+   float min_p = 0.0f;
+
+   // Multiplicative penalty applied to logits of recently generated tokens.
+   // Values > 1.0 discourage repetition. 1.0 disables the penalty.
+   float repetition_penalty = 1.0f;
+
+   // Token-batching parameters for the streaming token callback. The
+   // generator flushes a batch when either `output_token_batch_size` tokens
+   // have accumulated or `batch_time_interval_ms` milliseconds have elapsed
+   // since the last flush, whichever comes first.
+   size_t output_token_batch_size = 10;
+   size_t batch_time_interval_ms = 120;
+
    // Enable dynamic input shapes (if implemented) or not
    // Impacts the prefill phase and causes TextPrefiller to pass all the tokens
    // at once if set to true.
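The "count or time, whichever comes first" flush rule these two fields describe can be sketched as follows (hypothetical names; the real logic lives in TextTokenGenerator in text_token_generator.h):

#include <chrono>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

struct BatchingConfig {
  std::size_t output_token_batch_size = 10;
  std::size_t batch_time_interval_ms = 120;
};

// Collects decoded token pieces and flushes them to a callback once the
// count threshold OR the time threshold is crossed, whichever comes first.
template <typename Callback>
class TokenBatcher {
public:
  TokenBatcher(const BatchingConfig &config, Callback flush)
      : config_(config), flush_(std::move(flush)),
        last_flush_(std::chrono::steady_clock::now()) {}

  void push(const std::string &token_piece) {
    pending_.push_back(token_piece);
    const auto now = std::chrono::steady_clock::now();
    const auto elapsed_ms =
        std::chrono::duration_cast<std::chrono::milliseconds>(now - last_flush_)
            .count();
    if (pending_.size() >= config_.output_token_batch_size ||
        elapsed_ms >= static_cast<long long>(config_.batch_time_interval_ms)) {
      flush_(pending_); // hand the whole batch to the streaming callback
      pending_.clear();
      last_flush_ = now;
    }
  }

private:
  BatchingConfig config_;
  Callback flush_;
  std::vector<std::string> pending_;
  std::chrono::steady_clock::time_point last_flush_;
};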
package/common/runner/multimodal_decoder_runner.h
@@ -16,8 +16,9 @@
  namespace executorch::extension::llm {
  class MultimodalDecoderRunner : public TextDecoderRunner {
  public:
-   explicit MultimodalDecoderRunner(Module &module, IOManager *io_manager)
-       : TextDecoderRunner(module, io_manager) {}
+   explicit MultimodalDecoderRunner(Module &module, IOManager *io_manager,
+                                    const GenerationConfig &config)
+       : TextDecoderRunner(module, io_manager, config) {}

    inline ::executorch::runtime::Result<::executorch::aten::Tensor>
    step(TensorPtr &tokens, int64_t start_pos) override {
package/common/runner/multimodal_runner.cpp
@@ -47,8 +47,8 @@ Error MultimodalRunner::load_subcomponents() {

    Stats *stats_ptr = &stats_;

-   mm_decoder_runner_ =
-       std::make_unique<MultimodalDecoderRunner>(*module_, io_manager_.get());
+   mm_decoder_runner_ = std::make_unique<MultimodalDecoderRunner>(
+       *module_, io_manager_.get(), config_);
    IEncoder *image_encoder = nullptr;
    auto enc_it = encoders_.find(MultimodalType::Image);
    if (enc_it != encoders_.end()) {

@@ -58,7 +58,7 @@ Error MultimodalRunner::load_subcomponents() {
        *module_, *mm_decoder_runner_, *tokenizer_, image_encoder);
    mm_token_generator_ = std::make_unique<TextTokenGenerator>(
        tokenizer_.get(), mm_decoder_runner_.get(), /*use_kv_cache=*/true,
-       std::move(eos_ids_), stats_ptr);
+       std::move(eos_ids_), stats_ptr, config_);

    ET_CHECK_OK_OR_RETURN_ERROR(mm_prefiller_->load());
    ET_CHECK_OK_OR_RETURN_ERROR(mm_token_generator_->load());

@@ -106,7 +106,7 @@ Error MultimodalRunner::generate_internal(
    auto generate_result = mm_token_generator_->generate(
        seed_tokens, pos_,
        static_cast<uint64_t>(std::max(0, resolved_max_new - 1)),
-       config_.temperature, config_.topp, wrapped_callback);
+       wrapped_callback);

    if (!generate_result.ok())
      return generate_result.error();

@@ -125,16 +125,4 @@ void MultimodalRunner::stop_impl() {
    }
  }

- void MultimodalRunner::set_count_interval_impl(size_t count_interval) {
-   if (mm_token_generator_) {
-     mm_token_generator_->set_count_interval(count_interval);
-   }
- }
-
- void MultimodalRunner::set_time_interval_impl(size_t time_interval) {
-   if (mm_token_generator_) {
-     mm_token_generator_->set_time_interval(time_interval);
-   }
- }
-
  } // namespace executorch::extension::llm
package/common/runner/multimodal_runner.h
@@ -30,10 +30,6 @@ public:
  protected:
    ::executorch::runtime::Error load_subcomponents() override;
    void stop_impl() override;
-   void set_temperature_impl(float) override {}
-   void set_topp_impl(float) override {}
-   void set_count_interval_impl(size_t count_interval) override;
-   void set_time_interval_impl(size_t time_interval) override;

  private:
    std::map<MultimodalType, std::unique_ptr<IEncoder>> encoders_;
package/common/runner/sampler.cpp
@@ -35,6 +35,7 @@
  #include "sampler.h"
  #include <algorithm>
  #include <ctime>
+ #include <vector>

  namespace executorch {
  namespace extension {

@@ -119,16 +120,16 @@ int32_t Sampler::sample_topp(T *probabilities, float coin) {
    return probindex[last_idx].index; // in case of rounding errors
  }

- Sampler::Sampler(int vocab_size, float temperature, float topp,
-                  unsigned long long rng_seed)
+ Sampler::Sampler(int32_t vocab_size, float temperature, float topp,
+                  unsigned long long rng_seed, float min_p,
+                  float repetition_penalty)
      : vocab_size_(vocab_size),
        inv_temperature_((temperature != 0.0f) ? (1.0f / temperature) : 0.0f),
-       topp_(topp), rng_state_(rng_seed) {}
+       topp_(topp), min_p_(min_p), repetition_penalty_(repetition_penalty),
+       rng_state_(rng_seed) {}

  Sampler::Sampler(int vocab_size, float temperature, float topp)
-     : vocab_size_(vocab_size),
-       inv_temperature_((temperature != 0.0f) ? (1.0f / temperature) : 0.0f),
-       topp_(topp), rng_state_(std::time(nullptr)) {}
+     : Sampler(vocab_size, temperature, topp, std::time(nullptr), 0.0f, 1.0f) {}

  template <typename T> static void softmax(T *x, int size) {
    // find max value (for numerical stability)

@@ -162,22 +163,25 @@ static float random_f32(unsigned long long *state) { // random float32 in [0,1)
    return (random_u32(state) >> 8) / 16777216.0f;
  }

- template <typename T> int32_t Sampler::sample(T *logits) {
+ template <typename T>
+ int32_t Sampler::sample(T *logits, const std::vector<uint64_t> &recent_tokens) {
    // sample the token given the logits and some hyperparameters
    int next;
    if (inv_temperature_ == 0.0f) {
      // greedy argmax sampling: take the token with the highest probability
      next = sample_argmax(logits);
    } else {
-     // apply the temperature to the logits
-     for (int q = 0; q < vocab_size_; q++) {
-       logits[q] *= inv_temperature_;
-     }
-     // apply softmax to the logits to get the probabilities for next token
+     // 1. apply repetition penalty to raw logits (pre-softmax)
+     apply_repetition_penalty(logits, vocab_size_, recent_tokens);
+     // 2. apply the temperature to the logits
+     apply_temperature(logits, vocab_size_);
+     // 3. apply softmax to the logits to get the probabilities for next token
      softmax(logits, vocab_size_);
+     // 4. apply min_p truncation
+     apply_min_p(logits, vocab_size_);
      // flip a (float) coin (this is our source of entropy for sampling)
      float coin = random_f32(&rng_state_);
-     // we sample from this distribution to get the next token
+     // 5. we sample from this distribution to get the next token
      if (topp_ <= 0 || topp_ >= 1) {
        // simply sample from the predicted probability distribution
        next = sample_mult(logits, coin);

@@ -189,6 +193,10 @@ template <typename T> int32_t Sampler::sample(T *logits) {
    return next;
  }

+ template <typename T> int32_t Sampler::sample(T *logits) {
+   return sample(logits, {});
+ }
+
  template int32_t Sampler::sample<float>(float *logits);
  template int32_t Sampler::sample<uint16_t>(uint16_t *logits);
  template int32_t

@@ -196,6 +204,17 @@ Sampler::sample<executorch::aten::Half>(executorch::aten::Half *logits);
  template int32_t
  Sampler::sample<executorch::aten::BFloat16>(executorch::aten::BFloat16 *logits);

+ template int32_t Sampler::sample<float>(float *logits,
+                                         const std::vector<uint64_t> &);
+ template int32_t Sampler::sample<uint16_t>(uint16_t *logits,
+                                            const std::vector<uint64_t> &);
+ template int32_t
+ Sampler::sample<executorch::aten::Half>(executorch::aten::Half *logits,
+                                         const std::vector<uint64_t> &);
+ template int32_t
+ Sampler::sample<executorch::aten::BFloat16>(executorch::aten::BFloat16 *logits,
+                                             const std::vector<uint64_t> &);
+
  } // namespace llm
  } // namespace extension
  } // namespace executorch
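To see why the penalty in step 1 is applied to raw logits and splits on sign, consider a toy case: dividing a positive logit by the penalty and multiplying a negative one by it both push the value down, so a recently emitted token loses probability mass after softmax either way. A standalone sketch of the same rule:

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Toy demo of the pre-softmax repetition penalty: positive logits are
// divided by the penalty, negative logits multiplied, so the penalized
// token's post-softmax probability drops regardless of the logit's sign.
int main() {
  std::vector<float> logits = {2.0f, 0.5f, -1.0f};
  const float penalty = 1.3f;
  const std::vector<uint64_t> recent = {0, 2}; // tokens 0 and 2 were emitted

  for (uint64_t id : recent) {
    float &v = logits[id];
    v = (v > 0.0f) ? v / penalty : v * penalty;
  }
  // logits are now roughly {1.538, 0.5, -1.3}

  // Softmax to compare resulting probabilities.
  float sum = 0.0f;
  std::vector<float> probs(logits.size());
  for (std::size_t i = 0; i < logits.size(); ++i) {
    probs[i] = std::exp(logits[i]);
    sum += probs[i];
  }
  for (float &p : probs) {
    p /= sum;
  }
  for (std::size_t i = 0; i < probs.size(); ++i) {
    std::printf("token %zu: p = %.3f\n", i, probs[i]);
  }
  return 0;
}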
package/common/runner/sampler.h
@@ -8,12 +8,15 @@

  #pragma once

+ #include <algorithm>
  #include <cctype>
  #include <cmath>
  #include <cstdio>
  #include <cstdlib>
  #include <cstring>
  #include <memory>
+ #include <utility>
+ #include <vector>
  #ifdef USE_ATEN_LIB
  #include <torch/torch.h>
  #endif

@@ -36,22 +39,77 @@ template <typename T> struct ProbIndex {
  class Sampler {
  public:
    Sampler(int32_t vocab_size, float temperature, float topp,
-           unsigned long long rng_seed);
+           unsigned long long rng_seed, float min_p = 0.0f,
+           float repetition_penalty = 1.0f);

    Sampler(int32_t vocab_size, float temperature, float topp);

    template <typename T> int32_t sample(T *logits);

+   template <typename T>
+   int32_t sample(T *logits, const std::vector<uint64_t> &recent_tokens);
+
  private:
    template <typename T> int32_t sample_topp(T *probabilities, float coin);
    template <typename T> int32_t sample_mult(T *probabilities, float coin);
    template <typename T> int32_t sample_argmax(T *probabilities);

+   template <typename T>
+   inline void apply_temperature(T *logits, int32_t vocab_size) {
+     for (std::size_t i = 0; std::cmp_less(i, vocab_size); ++i) {
+       logits[i] =
+           static_cast<T>(static_cast<float>(logits[i]) * inv_temperature_);
+     }
+   }
+
+   template <typename T>
+   inline void
+   apply_repetition_penalty(T *logits, int32_t vocab_size,
+                            const std::vector<uint64_t> &recent_tokens) {
+     if (repetition_penalty_ == 1.0f || recent_tokens.empty())
+       return;
+     for (uint64_t id : recent_tokens) {
+       if (!std::cmp_less(id, vocab_size)) {
+         continue;
+       }
+       T &val = logits[id];
+       if (val > T(0)) {
+         val = static_cast<T>(static_cast<float>(val) / repetition_penalty_);
+       } else {
+         val = static_cast<T>(static_cast<float>(val) * repetition_penalty_);
+       }
+     }
+   }
+
+   template <typename T>
+   inline void apply_min_p(T *probabilities, int32_t vocab_size) {
+     if (min_p_ <= 0.0f) {
+       return;
+     }
+     T max_prob = *std::max_element(probabilities, probabilities + vocab_size);
+     T threshold = static_cast<T>(min_p_ * static_cast<float>(max_prob));
+     T sum = T(0);
+     for (std::size_t i = 0; std::cmp_less(i, vocab_size); ++i) {
+       if (probabilities[i] < threshold) {
+         probabilities[i] = T(0);
+       } else {
+         sum += probabilities[i];
+       }
+     }
+     if (sum > T(0)) {
+       for (std::size_t i = 0; std::cmp_less(i, vocab_size); ++i) {
+         probabilities[i] /= sum;
+       }
+     }
+   }
+
  private:
    int32_t vocab_size_;
    // reciprocal of temperature, or 0 if temperature == 0.
    float inv_temperature_;
    float topp_;
+   float min_p_;
+   float repetition_penalty_;
    unsigned long long rng_state_;
  };

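A worked example of the min-p cutoff implemented by apply_min_p above: with min_p = 0.1 and a top probability of 0.60, the threshold is 0.1 * 0.60 = 0.06, so probabilities below it are zeroed and the survivors renormalized. A standalone sketch mirroring the member function:

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

// Standalone mirror of apply_min_p: zero out probabilities below
// min_p * max_prob, then renormalize the survivors.
void minPFilter(std::vector<float> &probs, float min_p) {
  if (min_p <= 0.0f) {
    return; // 0.0 disables min_p filtering, matching GenerationConfig
  }
  const float max_prob = *std::max_element(probs.begin(), probs.end());
  const float threshold = min_p * max_prob;
  float sum = 0.0f;
  for (float &p : probs) {
    if (p < threshold) {
      p = 0.0f;
    } else {
      sum += p;
    }
  }
  if (sum > 0.0f) {
    for (float &p : probs) {
      p /= sum;
    }
  }
}

int main() {
  std::vector<float> probs = {0.60f, 0.25f, 0.10f, 0.04f, 0.01f};
  minPFilter(probs, 0.1f); // threshold = 0.1 * 0.60 = 0.06
  for (std::size_t i = 0; i < probs.size(); ++i) {
    std::printf("p[%zu] = %.4f\n", i, probs[i]); // 0.04 and 0.01 are zeroed
  }
  return 0;
}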
package/common/runner/text_decoder_runner.cpp
@@ -10,6 +10,7 @@

  #include "text_decoder_runner.h"
  #include "arange_util.h"
+ #include "irunner.h"
  #include "stats.h"

  #include <ctime>

@@ -22,9 +23,8 @@ namespace llm {
  // and a ~5% improvement on Galaxy S22 by switching to
  // FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
  TextDecoderRunner::TextDecoderRunner(Module &module, IOManager *io_manager,
-                                      float temperature, float topp)
-     : module_(&module), io_manager_(io_manager), temperature_(temperature),
-       topp_(topp) {}
+                                      const GenerationConfig &config)
+     : module_(&module), io_manager_(io_manager), config_(config) {}

  // This function is functional, meaning it shouldn't modify any state of the
  // input. It should be safe to call multiple times with the same inputs. The

@@ -82,6 +82,34 @@ TextDecoderRunner::step(TensorPtr &tokens, int64_t start_pos) {
    }
  }

+ int32_t TextDecoderRunner::logits_to_token(
+     const executorch::aten::Tensor &logits_tensor,
+     const std::vector<uint64_t> &recent_tokens) {
+   int32_t result = 0;
+
+   struct {
+     [[noreturn]] void fail(torch::executor::Error) {
+       ET_CHECK_MSG(false, "Unsupported dtype in logits_to_token");
+     }
+   } ctx;
+
+   ET_SWITCH_FOUR_TYPES(
+       Float, Half, BFloat16, UInt16, logits_tensor.scalar_type(), ctx,
+       "logits_to_token", CTYPE, [&]() {
+         auto *logits = logits_tensor.mutable_data_ptr<CTYPE>();
+         ssize_t vocab_size = logits_tensor.size(logits_tensor.dim() - 1);
+         if (logits_tensor.dim() == 3) {
+           auto num_tokens = logits_tensor.size(1);
+           logits += (num_tokens - 1) * vocab_size;
+         }
+         Sampler sampler(vocab_size, config_.temperature, config_.topp,
+                         static_cast<unsigned long long>(std::time(nullptr)),
+                         config_.min_p, config_.repetition_penalty);
+         result = sampler.sample(logits, recent_tokens);
+       });
+   return result;
+ }
+
  } // namespace llm
  } // namespace extension
  } // namespace executorch