react-native-executorch 0.9.0-nightly-0e95b89-20260525 → 0.9.1

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
Files changed (96)
  1. package/android/libs/classes.jar +0 -0
  2. package/common/rnexecutorch/host_objects/JsiConversions.h +43 -0
  3. package/common/rnexecutorch/models/llm/LLM.cpp +55 -42
  4. package/common/rnexecutorch/models/llm/LLM.h +4 -3
  5. package/common/rnexecutorch/models/llm/Types.h +23 -0
  6. package/common/runner/base_llm_runner.cpp +10 -3
  7. package/common/runner/base_llm_runner.h +1 -0
  8. package/common/runner/constants.h +15 -1
  9. package/common/runner/encoders/audio_encoder.cpp +111 -0
  10. package/common/runner/encoders/audio_encoder.h +40 -0
  11. package/common/runner/encoders/vision_encoder.cpp +0 -1
  12. package/common/runner/irunner.h +5 -0
  13. package/common/runner/multimodal_decoder_runner.h +50 -1
  14. package/common/runner/multimodal_input.h +16 -1
  15. package/common/runner/multimodal_prefiller.cpp +374 -64
  16. package/common/runner/multimodal_prefiller.h +57 -6
  17. package/common/runner/multimodal_runner.cpp +19 -12
  18. package/common/runner/multimodal_runner.h +1 -1
  19. package/common/runner/sampler.cpp +111 -35
  20. package/common/runner/sampler.h +13 -5
  21. package/common/runner/text_decoder_runner.cpp +1 -4
  22. package/common/runner/text_decoder_runner.h +3 -2
  23. package/common/runner/text_prefiller.cpp +8 -8
  24. package/common/runner/text_prefiller.h +8 -1
  25. package/common/runner/text_runner.cpp +35 -9
  26. package/common/runner/text_token_generator.h +2 -3
  27. package/common/runner/util.h +0 -1
  28. package/lib/module/constants/llmDefaults.js +1 -1
  29. package/lib/module/constants/llmDefaults.js.map +1 -1
  30. package/lib/module/constants/modelRegistry.js +33 -2
  31. package/lib/module/constants/modelRegistry.js.map +1 -1
  32. package/lib/module/constants/modelUrls.js +43 -6
  33. package/lib/module/constants/modelUrls.js.map +1 -1
  34. package/lib/module/controllers/LLMController.js +69 -20
  35. package/lib/module/controllers/LLMController.js.map +1 -1
  36. package/lib/module/hooks/natural_language_processing/useLLM.js +1 -5
  37. package/lib/module/hooks/natural_language_processing/useLLM.js.map +1 -1
  38. package/lib/module/modules/natural_language_processing/LLMModule.js +12 -7
  39. package/lib/module/modules/natural_language_processing/LLMModule.js.map +1 -1
  40. package/lib/module/types/llm.js +11 -0
  41. package/lib/module/types/llm.js.map +1 -1
  42. package/lib/typescript/constants/llmDefaults.d.ts +1 -1
  43. package/lib/typescript/constants/llmDefaults.d.ts.map +1 -1
  44. package/lib/typescript/constants/modelRegistry.d.ts +28 -1
  45. package/lib/typescript/constants/modelRegistry.d.ts.map +1 -1
  46. package/lib/typescript/constants/modelUrls.d.ts +40 -12
  47. package/lib/typescript/constants/modelUrls.d.ts.map +1 -1
  48. package/lib/typescript/controllers/LLMController.d.ts +7 -9
  49. package/lib/typescript/controllers/LLMController.d.ts.map +1 -1
  50. package/lib/typescript/modules/natural_language_processing/LLMModule.d.ts +6 -3
  51. package/lib/typescript/modules/natural_language_processing/LLMModule.d.ts.map +1 -1
  52. package/lib/typescript/types/llm.d.ts +63 -36
  53. package/lib/typescript/types/llm.d.ts.map +1 -1
  54. package/package.json +1 -1
  55. package/react-native-executorch.podspec +6 -0
  56. package/src/constants/llmDefaults.ts +1 -1
  57. package/src/constants/modelRegistry.ts +34 -2
  58. package/src/constants/modelUrls.ts +47 -6
  59. package/src/controllers/LLMController.ts +89 -40
  60. package/src/hooks/natural_language_processing/useLLM.ts +5 -6
  61. package/src/modules/natural_language_processing/LLMModule.ts +19 -8
  62. package/src/types/llm.ts +64 -34
  63. package/third-party/android/libs/executorch/arm64-v8a/libexecutorch.so +0 -0
  64. package/third-party/android/libs/executorch/x86_64/libexecutorch.so +0 -0
  65. package/third-party/include/executorch/ExecuTorch.h +2 -0
  66. package/third-party/include/executorch/ExecuTorchModule.h +46 -0
  67. package/third-party/include/executorch/extension/data_loader/buffer_data_loader.h +4 -3
  68. package/third-party/include/executorch/extension/data_loader/mman.h +46 -0
  69. package/third-party/include/executorch/extension/data_loader/mmap_data_loader.h +4 -0
  70. package/third-party/include/executorch/extension/data_loader/shared_ptr_data_loader.h +7 -3
  71. package/third-party/include/executorch/extension/module/module.h +47 -8
  72. package/third-party/include/executorch/extension/tensor/tensor_ptr.h +17 -5
  73. package/third-party/include/executorch/kernels/optimized/Functions.h +12 -0
  74. package/third-party/include/executorch/kernels/optimized/NativeFunctions.h +4 -0
  75. package/third-party/include/executorch/kernels/portable/Functions.h +18 -0
  76. package/third-party/include/executorch/kernels/portable/NativeFunctions.h +6 -0
  77. package/third-party/include/executorch/runtime/backend/backend_options_map.h +37 -0
  78. package/third-party/include/executorch/runtime/core/array_ref.h +3 -1
  79. package/third-party/include/executorch/runtime/core/error.h +1 -0
  80. package/third-party/include/executorch/runtime/core/evalue.h +256 -9
  81. package/third-party/include/executorch/runtime/core/exec_aten/exec_aten.h +24 -0
  82. package/third-party/include/executorch/runtime/core/hierarchical_allocator.h +9 -6
  83. package/third-party/include/executorch/runtime/core/portable_type/device.h +3 -4
  84. package/third-party/include/executorch/runtime/core/portable_type/tensor_impl.h +31 -1
  85. package/third-party/include/executorch/runtime/executor/method.h +9 -3
  86. package/third-party/include/executorch/runtime/executor/method_meta.h +14 -0
  87. package/third-party/include/executorch/runtime/executor/platform_memory_allocator.h +12 -2
  88. package/third-party/include/executorch/runtime/executor/program.h +3 -1
  89. package/third-party/include/executorch/runtime/executor/tensor_parser.h +5 -1
  90. package/third-party/include/executorch/runtime/kernel/operator_registry.h +9 -0
  91. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/ExecutorchLib +0 -0
  92. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/Info.plist +0 -0
  93. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/mlx.metallib +0 -0
  94. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/ExecutorchLib +0 -0
  95. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/Info.plist +0 -0
  96. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/mlx.metallib +0 -0
Binary file
@@ -4,6 +4,7 @@
4
4
  #include <cstdint>
5
5
  #include <set>
6
6
  #include <span>
7
+ #include <string>
7
8
  #include <type_traits>
8
9
  #include <unordered_map>
9
10
  #include <variant>
@@ -17,6 +18,7 @@
17
18
 
18
19
  #include <rnexecutorch/metaprogramming/TypeConcepts.h>
19
20
  #include <rnexecutorch/models/instance_segmentation/Types.h>
21
+ #include <rnexecutorch/models/llm/Types.h>
20
22
  #include <rnexecutorch/models/object_detection/Constants.h>
21
23
  #include <rnexecutorch/models/object_detection/Types.h>
22
24
  #include <rnexecutorch/models/ocr/Types.h>
@@ -223,6 +225,22 @@ inline std::vector<float> getValue<std::vector<float>>(const jsi::Value &val,
223
225
  return getArrayAsVector<float>(val, runtime);
224
226
  }
225
227
 
228
+ template <>
229
+ inline std::vector<std::vector<float>>
230
+ getValue<std::vector<std::vector<float>>>(const jsi::Value &val,
231
+ jsi::Runtime &runtime) {
232
+ jsi::Array array = val.asObject(runtime).asArray(runtime);
233
+ const size_t length = array.size(runtime);
234
+ std::vector<std::vector<float>> result;
235
+ result.reserve(length);
236
+ for (size_t i = 0; i < length; ++i) {
237
+ jsi::Value element = array.getValueAtIndex(runtime, i);
238
+ auto span = getTypedArrayAsSpan<float>(element, runtime);
239
+ result.emplace_back(span.begin(), span.end());
240
+ }
241
+ return result;
242
+ }
243
+
226
244
  template <>
227
245
  inline std::vector<int64_t>
228
246
  getValue<std::vector<int64_t>>(const jsi::Value &val, jsi::Runtime &runtime) {
@@ -302,6 +320,31 @@ getValue<std::span<uint64_t>>(const jsi::Value &val, jsi::Runtime &runtime) {
302
320
  return getTypedArrayAsSpan<uint64_t>(val, runtime);
303
321
  }
304
322
 
323
+ template <>
324
+ inline models::llm::MultimodalInputs
325
+ getValue<models::llm::MultimodalInputs>(const jsi::Value &val,
326
+ jsi::Runtime &runtime) {
327
+ models::llm::MultimodalInputs multimodalInputs;
328
+ jsi::Object obj = val.asObject(runtime);
329
+
330
+ jsi::Value v = obj.getProperty(runtime, "imageToken");
331
+ if (!v.isUndefined() && !v.isNull()) {
332
+ auto &images = multimodalInputs.images.emplace();
333
+ images.token = getValue<std::string>(v, runtime);
334
+ v = obj.getProperty(runtime, "imagePaths");
335
+ images.paths = getValue<std::vector<std::string>>(v, runtime);
336
+ }
337
+ v = obj.getProperty(runtime, "audioToken");
338
+ if (!v.isUndefined() && !v.isNull()) {
339
+ auto &audios = multimodalInputs.audios.emplace();
340
+ audios.token = getValue<std::string>(v, runtime);
341
+ v = obj.getProperty(runtime, "audioWaveforms");
342
+ audios.waveforms = getValue<std::vector<std::vector<float>>>(v, runtime);
343
+ }
344
+
345
+ return multimodalInputs;
346
+ }
347
+
305
348
  // Conversion from C++ types to jsi --------------------------------------------
306
349
 
307
350
  // Implementation functions might return any type, but in a promise we can only
@@ -1,11 +1,12 @@
1
1
  #include "LLM.h"
2
+ #include "rnexecutorch/models/llm/Types.h"
2
3
 
3
4
  #include <executorch/extension/tensor/tensor.h>
4
5
  #include <filesystem>
5
6
  #include <map>
6
7
  #include <rnexecutorch/Error.h>
7
- #include <rnexecutorch/Log.h>
8
8
  #include <rnexecutorch/threads/GlobalThreadPool.h>
9
+ #include <runner/encoders/audio_encoder.h>
9
10
  #include <runner/encoders/vision_encoder.h>
10
11
  #include <runner/multimodal_runner.h>
11
12
  #include <runner/text_runner.h>
@@ -21,7 +22,6 @@ LLM::LLM(const std::string &modelSource, const std::string &tokenizerSource,
21
22
  std::vector<std::string> capabilities,
22
23
  std::shared_ptr<react::CallInvoker> callInvoker)
23
24
  : BaseModel(modelSource, callInvoker, Module::LoadMode::Mmap) {
24
-
25
25
  if (capabilities.empty()) {
26
26
  runner_ =
27
27
  std::make_unique<llm::TextRunner>(std::move(module_), tokenizerSource);
@@ -31,6 +31,9 @@ LLM::LLM(const std::string &modelSource, const std::string &tokenizerSource,
31
31
  if (cap == "vision") {
32
32
  encoders[llm::MultimodalType::Image] =
33
33
  std::make_unique<llm::VisionEncoder>(*module_);
34
+ } else if (cap == "audio") {
35
+ encoders[llm::MultimodalType::Audio] =
36
+ std::make_unique<llm::AudioEncoder>(*module_);
34
37
  }
35
38
  }
36
39
  runner_ = std::make_unique<llm::MultimodalRunner>(
@@ -75,62 +78,73 @@ std::string LLM::generate(std::string input,
75
78
  }
76
79
 
77
80
  std::string LLM::generateMultimodal(std::string prompt,
78
- std::vector<std::string> imagePaths,
79
- std::string imageToken,
80
- std::shared_ptr<jsi::Function> callback) {
81
+ std::shared_ptr<jsi::Function> callback,
82
+ MultimodalInputs mutlimodalInputs) {
81
83
  if (!runner_ || !runner_->is_loaded()) {
82
84
  throw RnExecutorchError(RnExecutorchErrorCode::ModuleNotLoaded,
83
85
  "Runner is not loaded");
84
86
  }
85
87
  if (!runner_->is_multimodal()) {
86
- throw RnExecutorchError(
87
- RnExecutorchErrorCode::InvalidUserInput,
88
- "This model does not support multimodal input. Use generate(prompt, "
89
- "callback) for text-only generation.");
88
+ throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput,
89
+ "This model does not support multimodal input.");
90
90
  }
91
- if (imageToken.empty()) {
91
+ if (!mutlimodalInputs.images.has_value() &&
92
+ !mutlimodalInputs.audios.has_value()) {
92
93
  throw RnExecutorchError(
93
94
  RnExecutorchErrorCode::InvalidUserInput,
94
- "imageToken must not be empty. Pass the model's image token (e.g. "
95
- "from tokenizer_config.json).");
95
+ "At least one of imageToken/audioToken must be non-empty");
96
96
  }
97
97
 
98
- const size_t kImageTokenLen = imageToken.size();
99
-
98
+ // Scan the prompt once, splitting at the earliest placeholder at each step
99
+ // so that image/audio placeholders can be freely interleaved in the prompt.
100
100
  std::vector<llm::MultimodalInput> inputs;
101
- size_t imageIdx = 0;
102
- size_t searchPos = 0;
103
-
104
- while (true) {
105
- size_t found = prompt.find(imageToken, searchPos);
106
- if (found == std::string::npos) {
107
- if (searchPos < prompt.size()) {
108
- inputs.push_back(llm::make_text_input(prompt.substr(searchPos)));
109
- }
101
+ size_t imageIdx = 0, audioIdx = 0, pos = 0;
102
+ while (pos < prompt.size()) {
103
+ size_t imgAt = mutlimodalInputs.images.has_value()
104
+ ? prompt.find(mutlimodalInputs.images.value().token, pos)
105
+ : std::string::npos;
106
+ size_t audAt = mutlimodalInputs.audios.has_value()
107
+ ? prompt.find(mutlimodalInputs.audios.value().token, pos)
108
+ : std::string::npos;
109
+ if (imgAt == std::string::npos && audAt == std::string::npos) {
110
+ inputs.push_back(llm::make_text_input(prompt.substr(pos)));
110
111
  break;
111
112
  }
112
- // Text segment before this placeholder
113
- if (found > searchPos) {
114
- inputs.push_back(
115
- llm::make_text_input(prompt.substr(searchPos, found - searchPos)));
113
+ const bool imageFirst = imgAt != std::string::npos &&
114
+ (audAt == std::string::npos || imgAt < audAt);
115
+ size_t at = imageFirst ? imgAt : audAt;
116
+ if (at > pos) {
117
+ inputs.push_back(llm::make_text_input(prompt.substr(pos, at - pos)));
116
118
  }
117
- // Image at this position
118
- if (imageIdx >= imagePaths.size()) {
119
- throw RnExecutorchError(
120
- RnExecutorchErrorCode::InvalidUserInput,
121
- "More '" + imageToken +
122
- "' placeholders in prompt than image paths provided");
119
+ if (imageFirst) {
120
+ auto &images = mutlimodalInputs.images.value();
121
+ if (imageIdx >= images.paths.size()) {
122
+ throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput,
123
+ "More '" + images.token +
124
+ "' placeholders than image paths");
125
+ }
126
+ inputs.push_back(llm::make_image_input(images.paths[imageIdx++]));
127
+ pos = at + images.token.size();
128
+ } else {
129
+ auto &audios = mutlimodalInputs.audios.value();
130
+ if (audioIdx >= audios.waveforms.size()) {
131
+ throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput,
132
+ "More '" + audios.token +
133
+ "' placeholders than audio waveforms");
134
+ }
135
+ inputs.push_back(
136
+ llm::make_audio_input(std::move(audios.waveforms[audioIdx++])));
137
+ pos = at + audios.token.size();
123
138
  }
124
- inputs.push_back(llm::make_image_input(imagePaths[imageIdx++]));
125
- searchPos = found + kImageTokenLen;
126
139
  }
127
-
128
- if (imageIdx < imagePaths.size()) {
129
- throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput,
130
- "More image paths provided than '" + imageToken +
131
- "' placeholders in prompt");
140
+ if ((mutlimodalInputs.images.has_value() &&
141
+ imageIdx < mutlimodalInputs.images.value().paths.size()) ||
142
+ (mutlimodalInputs.audios.has_value() &&
143
+ audioIdx < mutlimodalInputs.audios.value().waveforms.size())) {
144
+ throw RnExecutorchError(
145
+ RnExecutorchErrorCode::InvalidUserInput,
146
+ "More image/audio paths provided than placeholders in prompt");
132
147
  }
133
-
134
148
  if (inputs.empty()) {
135
149
  throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput,
136
150
  "No inputs to generate from");
@@ -150,7 +164,6 @@ std::string LLM::generateMultimodal(std::string prompt,
150
164
  if (error != Error::Ok) {
151
165
  throw RnExecutorchError(error, "Failed to generate multimodal response");
152
166
  }
153
-
154
167
  return output;
155
168
  }
156
169
 
@@ -7,6 +7,7 @@
7
7
  #include <ReactCommon/CallInvoker.h>
8
8
  #include <jsi/jsi.h>
9
9
  #include <rnexecutorch/models/BaseModel.h>
10
+ #include <rnexecutorch/models/llm/Types.h>
10
11
  #include <runner/base_llm_runner.h>
11
12
 
12
13
  namespace rnexecutorch {
@@ -22,10 +23,10 @@ public:
22
23
 
23
24
  std::string generate(std::string prompt,
24
25
  std::shared_ptr<jsi::Function> callback);
26
+
25
27
  std::string generateMultimodal(std::string prompt,
26
- std::vector<std::string> imagePaths,
27
- std::string imageToken,
28
- std::shared_ptr<jsi::Function> callback);
28
+ std::shared_ptr<jsi::Function> callback,
29
+ MultimodalInputs mutlimodalInputs = {});
29
30
 
30
31
  void interrupt();
31
32
  void reset();
@@ -0,0 +1,23 @@
1
+ #pragma once
2
+
3
+ #include <optional>
4
+ #include <string>
5
+ #include <vector>
6
+
7
namespace rnexecutorch::models::llm {

// Image attachments for a multimodal prompt. Every occurrence of `token` in
// the prompt consumes the next entry of `paths`, in order.
struct ImageInputs {
  std::vector<std::string> paths{}; // image file paths, in prompt order
  std::string token{};              // placeholder text marking an image slot
};

// Audio attachments for a multimodal prompt. Every occurrence of `token` in
// the prompt consumes the next waveform, in order.
struct AudioInputs {
  std::vector<std::vector<float>> waveforms{}; // one sample buffer per clip
  std::string token{}; // placeholder text marking an audio slot
};

// Optional per-modality attachments. An empty optional means the modality is
// not used by the request.
struct MultimodalInputs {
  std::optional<ImageInputs> images{};
  std::optional<AudioInputs> audios{};
};

} // namespace rnexecutorch::models::llm
@@ -56,11 +56,16 @@ Error BaseLLMRunner::load() {
56
56
  ? static_cast<int32_t>(metadata_.at(kMaxContextLen))
57
57
  : static_cast<int32_t>(metadata_.at(kMaxSeqLen));
58
58
  }
59
- if (config_.max_new_tokens < 0)
60
- config_.max_new_tokens =
61
- std::min(config_.max_seq_len, config_.max_context_length);
62
59
  config_.enable_dynamic_shape =
63
60
  static_cast<bool>(metadata_.at(kEnableDynamicShape));
61
+ if (config_.max_new_tokens < 0) {
62
+ // For dynamic-shape PTEs, max_seq_len is the per-call decoder chunk
63
+ // size, not the generation budget — use max_context_length instead.
64
+ const int32_t seq_cap = config_.enable_dynamic_shape
65
+ ? config_.max_context_length
66
+ : config_.max_seq_len;
67
+ config_.max_new_tokens = std::min(seq_cap, config_.max_context_length);
68
+ }
64
69
  config_.enable_kv_cache = static_cast<bool>(metadata_.at(kUseKVCache));
65
70
 
66
71
  eos_ids_ = std::make_unique<std::unordered_set<uint64_t>>();
@@ -149,6 +154,8 @@ void BaseLLMRunner::set_repetition_penalty(float repetition_penalty) noexcept {
149
154
  config_.repetition_penalty = repetition_penalty;
150
155
  }
151
156
 
157
+ void BaseLLMRunner::set_topk(int32_t topk) noexcept { config_.topk = topk; }
158
+
152
159
  void BaseLLMRunner::set_count_interval(size_t count_interval) {
153
160
  config_.output_token_batch_size = count_interval;
154
161
  }
@@ -55,6 +55,7 @@ public:
55
55
  void set_topp(float topp) noexcept;
56
56
  void set_min_p(float min_p) noexcept;
57
57
  void set_repetition_penalty(float repetition_penalty) noexcept;
58
+ void set_topk(int32_t topk) noexcept;
58
59
  void set_count_interval(size_t count_interval);
59
60
  void set_time_interval(size_t time_interval);
60
61
 
@@ -23,8 +23,22 @@ inline constexpr auto kVisionEncoderMethod = "vision_encoder";
23
23
  inline constexpr auto kAudioEncoderMethod = "audio_encoder";
24
24
  inline constexpr auto kTokenEmbeddingMethod = "token_embedding";
25
25
  inline constexpr auto kTextModelMethod = "text_decoder";
26
-
27
26
  inline constexpr auto numOfAddedBoSTokens = 0;
28
27
  inline constexpr auto numOfAddedEoSTokens = 0;
29
28
 
29
// ---- Gemma4 metadata keys --------------------------------------------------

// PLE models only: metadata key for the token id that marks image placeholder
// slots in input_ids. Running token_embedding on this id yields the per-layer
// PLE signal for image positions; the inputs_embeds rows produced for those
// positions are discarded and replaced by the vision encoder output.
inline constexpr const char *kImagePlaceholderId = "image_placeholder_id";

// Metadata key that is true iff the model exposes a per-layer-embedding (PLE)
// signal alongside inputs_embeds (Gemma4-style).
//  - true:  `token_embedding.execute()` returns (inputs_embeds, ple_tok), and
//           the runner must pass ple_tok on to text_decoder.
//  - false or absent: token_embedding returns inputs_embeds alone.
// Text-only PTEs that ship a single `forward` method omit the key entirely;
// it only matters for multimodal PTEs with a separate `token_embedding` method.
inline constexpr const char *kHasPLE = "has_ple";
43
+
30
44
  } // namespace executorch::extension::llm
@@ -0,0 +1,111 @@
1
+ // common/runner/encoders/audio_encoder.cpp
2
+ #include "audio_encoder.h"
3
+
4
+ #include <rnexecutorch/Error.h>
5
+ #include <runner/constants.h>
6
+
7
+ #include <executorch/extension/tensor/tensor.h>
8
+
9
+ #include <cmath>
10
+ #include <cstdint>
11
+ #include <cstring>
12
+ #include <string>
13
+ #include <vector>
14
+
15
+ namespace executorch::extension::llm {
16
+
17
+ using ::executorch::aten::SizesType;
18
+ using ::executorch::runtime::Error;
19
+ using ::executorch::runtime::EValue;
20
+ using ::executorch::runtime::Result;
21
+
22
+ namespace {
23
+ constexpr int32_t kSamplingRate = 16e3;
24
+ constexpr int32_t kMaxLengthSeconds = 30;
25
+ constexpr int32_t kSamplesPerBlock = 7680;
26
+ constexpr int64_t kAudioBlockKMin = 1;
27
+ constexpr int64_t kAudioBlockKMax =
28
+ kSamplingRate * kMaxLengthSeconds / kSamplesPerBlock;
29
+ } // namespace
30
+
31
+ AudioEncoder::AudioEncoder(::executorch::extension::Module &module)
32
+ : module_(&module) {}
33
+
34
+ Error AudioEncoder::load() {
35
+ if (is_loaded()) {
36
+ return Error::Ok;
37
+ }
38
+ auto method_names_result = module_->method_names();
39
+ if (!method_names_result.ok()) {
40
+ return method_names_result.error();
41
+ }
42
+ if (method_names_result->count(kAudioEncoderMethod) == 0) {
43
+ throw rnexecutorch::RnExecutorchError(
44
+ rnexecutorch::RnExecutorchErrorCode::InvalidConfig,
45
+ "Model does not support audio: 'audio_encoder' method not found. "
46
+ "Check that the .pte file matches the declared capabilities.");
47
+ }
48
+ return module_->load_method(kAudioEncoderMethod);
49
+ }
50
+
51
+ bool AudioEncoder::is_loaded() const noexcept {
52
+ return module_->is_method_loaded(kAudioEncoderMethod);
53
+ }
54
+
55
+ int32_t AudioEncoder::encoderTokenCount() const noexcept {
56
+ return last_token_count_;
57
+ }
58
+
59
+ Result<EValue> AudioEncoder::encode(const MultimodalInput &input) {
60
+ if (!is_loaded()) {
61
+ return Error::InvalidState;
62
+ }
63
+ if (!input.is_audio()) {
64
+ return Error::InvalidArgument;
65
+ }
66
+
67
+ const auto &wav = input.get_audio();
68
+ ET_CHECK_OR_RETURN_ERROR(!wav.samples.empty(), InvalidArgument,
69
+ "AudioEncoder: empty waveform");
70
+
71
+ const int64_t n_valid = static_cast<int64_t>(wav.samples.size());
72
+ const int64_t k_blocks = (n_valid + kSamplesPerBlock - 1) / kSamplesPerBlock;
73
+ ET_CHECK_OR_RETURN_ERROR(
74
+ k_blocks >= kAudioBlockKMin && k_blocks <= kAudioBlockKMax,
75
+ InvalidArgument,
76
+ "AudioEncoder: waveform of %lld samples needs k_blocks=%lld.",
77
+ static_cast<long long>(n_valid), static_cast<long long>(k_blocks));
78
+ const int64_t n_padded = k_blocks * kSamplesPerBlock;
79
+
80
+ // Own the padded waveform for the lifetime of this call; from_blob below
81
+ // borrows without copying. The current export takes
82
+ // forward(waveform[1, 7680*k] fp32, num_blocks: int64 scalar)
83
+ // — input 1 is a rank-0 Long telling the encoder how many of the K_MAX
84
+ // blocks contain real PCM. Passing a 2-d mask here trips "Attempted to
85
+ // change tensor rank: old=0, new=2".
86
+ padded_wav_.assign(static_cast<size_t>(n_padded), 0.0f);
87
+ std::memcpy(padded_wav_.data(), wav.samples.data(),
88
+ static_cast<size_t>(n_valid) * sizeof(float));
89
+
90
+ valid_samples_scalar_ = n_valid;
91
+
92
+ auto wav_tensor = ::executorch::extension::from_blob(
93
+ padded_wav_.data(), {1, static_cast<SizesType>(n_padded)},
94
+ ::executorch::aten::ScalarType::Float);
95
+
96
+ auto num_blocks_tensor = ::executorch::extension::from_blob(
97
+ &valid_samples_scalar_, {}, ::executorch::aten::ScalarType::Long);
98
+
99
+ std::vector<EValue> args = {EValue(*wav_tensor), EValue(*num_blocks_tensor)};
100
+ auto exec_result = ET_UNWRAP(module_->execute(kAudioEncoderMethod, args));
101
+ ET_CHECK_OR_RETURN_ERROR(!exec_result.empty(), InvalidState,
102
+ "audio_encoder returned no outputs");
103
+ auto audio_tensor = exec_result[0].toTensor();
104
+ ET_CHECK_OR_RETURN_ERROR(audio_tensor.dim() == 3, InvalidState,
105
+ "audio_encoder output rank=%zd, expected 3",
106
+ audio_tensor.dim());
107
+ last_token_count_ = static_cast<int32_t>(audio_tensor.size(1));
108
+ return exec_result[0];
109
+ }
110
+
111
+ } // namespace executorch::extension::llm
@@ -0,0 +1,40 @@
1
+ // common/runner/encoders/audio_encoder.h
2
+ #pragma once
3
+
4
+ #include "iencoder.h"
5
+ #include <executorch/extension/module/module.h>
6
+ #include <executorch/runtime/core/evalue.h>
7
+ #include <runner/multimodal_input.h>
8
+
9
+ #include <cstdint>
10
+ #include <vector>
11
+
12
+ namespace executorch::extension::llm {
13
+
14
+ // Runs the Gemma4 `audio_encoder` PTE method.
15
+ //
16
+ // Contract mirrors SpeechToText (Whisper): JS hands in fp32 mono 16 kHz PCM
17
+ // via `MultimodalInput::get_audio()`; the PTE owns the log-mel frontend so
18
+ // this class just wraps the samples in a `[1, N_samples]` Float tensor and
19
+ // executes. Resampling and WAV/MP3 decoding are the caller's responsibility
20
+ // (e.g. react-native-audio-api).
21
+ class AudioEncoder : public IEncoder {
22
+ public:
23
+ explicit AudioEncoder(::executorch::extension::Module &module);
24
+
25
+ ::executorch::runtime::Error load() override;
26
+ bool is_loaded() const noexcept override;
27
+ ::executorch::runtime::Result<::executorch::runtime::EValue>
28
+ encode(const MultimodalInput &input) override;
29
+ // Number of audio embedding tokens produced per encode() call. 0 until first
30
+ // encode, since Gemma4's audio_encoder has a dynamic T dim.
31
+ int32_t encoderTokenCount() const noexcept override;
32
+
33
+ private:
34
+ ::executorch::extension::Module *module_;
35
+ int32_t last_token_count_ = 0;
36
+ std::vector<float> padded_wav_;
37
+ int64_t valid_samples_scalar_ = 0;
38
+ };
39
+
40
+ } // namespace executorch::extension::llm
@@ -2,7 +2,6 @@
2
2
  #include "vision_encoder.h"
3
3
 
4
4
  #include <rnexecutorch/Error.h>
5
- #include <rnexecutorch/Log.h>
6
5
  #include <rnexecutorch/data_processing/ImageProcessing.h>
7
6
  #include <runner/constants.h>
8
7
 
@@ -73,6 +73,11 @@ struct GenerationConfig {
73
73
  size_t output_token_batch_size = 10;
74
74
  size_t batch_time_interval_ms = 120;
75
75
 
76
+ // Top-k sampling – keep only the k highest-logit tokens before softmax.
77
+ // 0 (default) disables top-k filtering. Stacks with topp: temperature ->
78
+ // top-k -> top-p -> softmax -> multinomial.
79
+ int32_t topk = 0;
80
+
76
81
  // Enable dynamic input shapes (if implemented) or not
77
82
  // Impacts the prefill phase and causes TextPrefiller to pass all the tokens
78
83
  // at once if set to true.
@@ -14,19 +14,50 @@
14
14
  #include "text_decoder_runner.h"
15
15
 
16
16
  namespace executorch::extension::llm {
17
+ // Supports two PTE contracts, selected per-call from the kHasPLE metadata
18
+ // key (mirrors how kEnableDynamicShape etc. are read — queried on demand,
19
+ // not cached in a member). Callers that need it multiple times in a hot
20
+ // path should snapshot into a local.
21
+ //
22
+ // * Legacy (has_ple == false):
23
+ // token_embedding(ids) -> inputs_embeds
24
+ // text_decoder(inputs_embeds, input_pos)
25
+ //
26
+ // * Gemma-style PLE (has_ple == true):
27
+ // token_embedding(ids) -> (inputs_embeds, ple_tok)
28
+ // text_decoder(inputs_embeds, ple_tok, input_pos)
29
+ // ple_tok carries Gemma4's per-layer PLE signal keyed on input_ids. It's
30
+ // computed once in token_embedding and threaded through every decoder call
31
+ // so PLE fires at every position (including multimodal placeholder slots).
17
32
  class MultimodalDecoderRunner : public TextDecoderRunner {
18
33
  public:
19
34
  explicit MultimodalDecoderRunner(Module &module, IOManager *io_manager,
20
35
  const GenerationConfig &config)
21
36
  : TextDecoderRunner(module, io_manager, config) {}
22
37
 
38
+ bool has_ple() const {
39
+ auto r = module_->get(kHasPLE);
40
+ if (r.error() != ::executorch::runtime::Error::Ok) {
41
+ return false;
42
+ }
43
+ return r->toScalar().to<bool>();
44
+ }
45
+
23
46
  inline ::executorch::runtime::Result<::executorch::aten::Tensor>
24
47
  step(TensorPtr &tokens, int64_t start_pos) override {
25
48
  auto embed_result = module_->execute(kTokenEmbeddingMethod, tokens);
26
49
  if (!embed_result.ok()) {
27
50
  return embed_result.error();
28
51
  }
29
- return decode((*embed_result)[0], start_pos);
52
+ auto &embed_outputs = *embed_result;
53
+ if (has_ple()) {
54
+ ET_CHECK_MSG(embed_outputs.size() == 2,
55
+ "Expected 2 outputs (inputs_embeds, ple_tok) from "
56
+ "token_embedding, got %zu",
57
+ embed_outputs.size());
58
+ return decode(embed_outputs[0], embed_outputs[1], start_pos);
59
+ }
60
+ return decode(embed_outputs[0], start_pos);
30
61
  }
31
62
 
32
63
  inline ::executorch::runtime::Result<::executorch::aten::Tensor>
@@ -46,6 +77,24 @@ public:
46
77
  return outputs[0].toTensor();
47
78
  }
48
79
 
80
+ inline ::executorch::runtime::Result<::executorch::aten::Tensor>
81
+ decode(const ::executorch::runtime::EValue &embeddings,
82
+ const ::executorch::runtime::EValue &ple_tok, int64_t start_pos) {
83
+ auto start_pos_tensor = ::executorch::extension::from_blob(
84
+ &start_pos, {1}, ::executorch::aten::ScalarType::Long);
85
+ auto outputs_result = module_->execute(
86
+ kTextModelMethod, {embeddings, ple_tok, start_pos_tensor});
87
+ if (!outputs_result.ok()) {
88
+ return outputs_result.error();
89
+ }
90
+ auto &outputs = *outputs_result;
91
+ ET_CHECK_MSG(outputs.size() == 1,
92
+ "Expected 1 output from text_decoder, got %zu",
93
+ outputs.size());
94
+ ET_CHECK_MSG(outputs[0].isTensor(), "text_decoder output is not a tensor");
95
+ return outputs[0].toTensor();
96
+ }
97
+
49
98
  inline ::executorch::runtime::Error load() override {
50
99
  if (is_method_loaded()) {
51
100
  return ::executorch::runtime::Error::Ok;
@@ -20,6 +20,10 @@ struct ImagePath {
20
20
  std::string path;
21
21
  };
22
22
 
23
// Raw audio payload carried by a MultimodalInput. The audio encoder documents
// the expected format as fp32 mono 16 kHz PCM.
struct AudioWaveform {
  std::vector<float> samples{};
};
26
+
23
27
  class MultimodalInput {
24
28
  public:
25
29
  explicit MultimodalInput(std::string text) : data_(std::move(text)) {}
@@ -27,6 +31,7 @@ public:
27
31
  : data_(std::move(tokens)) {}
28
32
  explicit MultimodalInput(ImagePath image_path)
29
33
  : data_(std::move(image_path)) {}
34
+ explicit MultimodalInput(AudioWaveform audio) : data_(std::move(audio)) {}
30
35
 
31
36
  MultimodalInput(const MultimodalInput &) = default;
32
37
  MultimodalInput &operator=(const MultimodalInput &) = default;
@@ -42,6 +47,9 @@ public:
42
47
  bool is_image() const noexcept {
43
48
  return std::holds_alternative<ImagePath>(data_);
44
49
  }
50
+ bool is_audio() const noexcept {
51
+ return std::holds_alternative<AudioWaveform>(data_);
52
+ }
45
53
 
46
54
  const std::string &get_text() const & { return std::get<std::string>(data_); }
47
55
  const std::vector<uint64_t> &get_tokens() const & {
@@ -50,9 +58,13 @@ public:
50
58
  const std::string &get_image_path() const & {
51
59
  return std::get<ImagePath>(data_).path;
52
60
  }
61
+ const AudioWaveform &get_audio() const & {
62
+ return std::get<AudioWaveform>(data_);
63
+ }
53
64
 
54
65
  private:
55
- std::variant<std::string, std::vector<uint64_t>, ImagePath> data_;
66
+ std::variant<std::string, std::vector<uint64_t>, ImagePath, AudioWaveform>
67
+ data_;
56
68
  };
57
69
 
58
70
  inline MultimodalInput make_text_input(const std::string &text) noexcept {
@@ -64,5 +76,8 @@ inline MultimodalInput make_text_input(std::string &&text) noexcept {
64
76
  inline MultimodalInput make_image_input(std::string path) noexcept {
65
77
  return MultimodalInput(ImagePath{std::move(path)});
66
78
  }
79
+ inline MultimodalInput make_audio_input(std::vector<float> samples) noexcept {
80
+ return MultimodalInput(AudioWaveform{std::move(samples)});
81
+ }
67
82
 
68
83
  } // namespace executorch::extension::llm