react-native-executorch 0.9.0 → 0.9.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. package/android/libs/classes.jar +0 -0
  2. package/common/rnexecutorch/host_objects/JsiConversions.h +43 -0
  3. package/common/rnexecutorch/models/llm/LLM.cpp +55 -42
  4. package/common/rnexecutorch/models/llm/LLM.h +4 -3
  5. package/common/rnexecutorch/models/llm/Types.h +23 -0
  6. package/common/runner/base_llm_runner.cpp +10 -3
  7. package/common/runner/base_llm_runner.h +1 -0
  8. package/common/runner/constants.h +15 -1
  9. package/common/runner/encoders/audio_encoder.cpp +111 -0
  10. package/common/runner/encoders/audio_encoder.h +40 -0
  11. package/common/runner/encoders/vision_encoder.cpp +0 -1
  12. package/common/runner/irunner.h +5 -0
  13. package/common/runner/multimodal_decoder_runner.h +50 -1
  14. package/common/runner/multimodal_input.h +16 -1
  15. package/common/runner/multimodal_prefiller.cpp +374 -64
  16. package/common/runner/multimodal_prefiller.h +57 -6
  17. package/common/runner/multimodal_runner.cpp +19 -12
  18. package/common/runner/multimodal_runner.h +1 -1
  19. package/common/runner/sampler.cpp +111 -35
  20. package/common/runner/sampler.h +13 -5
  21. package/common/runner/text_decoder_runner.cpp +1 -4
  22. package/common/runner/text_decoder_runner.h +3 -2
  23. package/common/runner/text_prefiller.cpp +8 -8
  24. package/common/runner/text_prefiller.h +8 -1
  25. package/common/runner/text_runner.cpp +35 -9
  26. package/common/runner/text_token_generator.h +2 -3
  27. package/common/runner/util.h +0 -1
  28. package/lib/module/constants/llmDefaults.js +1 -1
  29. package/lib/module/constants/llmDefaults.js.map +1 -1
  30. package/lib/module/constants/modelRegistry.js +33 -2
  31. package/lib/module/constants/modelRegistry.js.map +1 -1
  32. package/lib/module/constants/modelUrls.js +43 -6
  33. package/lib/module/constants/modelUrls.js.map +1 -1
  34. package/lib/module/controllers/LLMController.js +69 -20
  35. package/lib/module/controllers/LLMController.js.map +1 -1
  36. package/lib/module/hooks/natural_language_processing/useLLM.js +1 -5
  37. package/lib/module/hooks/natural_language_processing/useLLM.js.map +1 -1
  38. package/lib/module/modules/natural_language_processing/LLMModule.js +12 -7
  39. package/lib/module/modules/natural_language_processing/LLMModule.js.map +1 -1
  40. package/lib/module/types/llm.js +11 -0
  41. package/lib/module/types/llm.js.map +1 -1
  42. package/lib/typescript/constants/llmDefaults.d.ts +1 -1
  43. package/lib/typescript/constants/llmDefaults.d.ts.map +1 -1
  44. package/lib/typescript/constants/modelRegistry.d.ts +28 -1
  45. package/lib/typescript/constants/modelRegistry.d.ts.map +1 -1
  46. package/lib/typescript/constants/modelUrls.d.ts +40 -12
  47. package/lib/typescript/constants/modelUrls.d.ts.map +1 -1
  48. package/lib/typescript/controllers/LLMController.d.ts +7 -9
  49. package/lib/typescript/controllers/LLMController.d.ts.map +1 -1
  50. package/lib/typescript/modules/natural_language_processing/LLMModule.d.ts +6 -3
  51. package/lib/typescript/modules/natural_language_processing/LLMModule.d.ts.map +1 -1
  52. package/lib/typescript/types/llm.d.ts +63 -36
  53. package/lib/typescript/types/llm.d.ts.map +1 -1
  54. package/package.json +1 -1
  55. package/react-native-executorch.podspec +6 -0
  56. package/src/constants/llmDefaults.ts +1 -1
  57. package/src/constants/modelRegistry.ts +34 -2
  58. package/src/constants/modelUrls.ts +47 -6
  59. package/src/controllers/LLMController.ts +89 -40
  60. package/src/hooks/natural_language_processing/useLLM.ts +5 -6
  61. package/src/modules/natural_language_processing/LLMModule.ts +19 -8
  62. package/src/types/llm.ts +64 -34
  63. package/third-party/android/libs/executorch/arm64-v8a/libexecutorch.so +0 -0
  64. package/third-party/android/libs/executorch/x86_64/libexecutorch.so +0 -0
  65. package/third-party/include/executorch/ExecuTorch.h +2 -0
  66. package/third-party/include/executorch/ExecuTorchModule.h +46 -0
  67. package/third-party/include/executorch/extension/data_loader/buffer_data_loader.h +4 -3
  68. package/third-party/include/executorch/extension/data_loader/mman.h +46 -0
  69. package/third-party/include/executorch/extension/data_loader/mmap_data_loader.h +4 -0
  70. package/third-party/include/executorch/extension/data_loader/shared_ptr_data_loader.h +7 -3
  71. package/third-party/include/executorch/extension/module/module.h +47 -8
  72. package/third-party/include/executorch/extension/tensor/tensor_ptr.h +17 -5
  73. package/third-party/include/executorch/kernels/optimized/Functions.h +12 -0
  74. package/third-party/include/executorch/kernels/optimized/NativeFunctions.h +4 -0
  75. package/third-party/include/executorch/kernels/portable/Functions.h +18 -0
  76. package/third-party/include/executorch/kernels/portable/NativeFunctions.h +6 -0
  77. package/third-party/include/executorch/runtime/backend/backend_options_map.h +37 -0
  78. package/third-party/include/executorch/runtime/core/array_ref.h +3 -1
  79. package/third-party/include/executorch/runtime/core/error.h +1 -0
  80. package/third-party/include/executorch/runtime/core/evalue.h +256 -9
  81. package/third-party/include/executorch/runtime/core/exec_aten/exec_aten.h +24 -0
  82. package/third-party/include/executorch/runtime/core/hierarchical_allocator.h +9 -6
  83. package/third-party/include/executorch/runtime/core/portable_type/device.h +3 -4
  84. package/third-party/include/executorch/runtime/core/portable_type/tensor_impl.h +31 -1
  85. package/third-party/include/executorch/runtime/executor/method.h +9 -3
  86. package/third-party/include/executorch/runtime/executor/method_meta.h +14 -0
  87. package/third-party/include/executorch/runtime/executor/platform_memory_allocator.h +12 -2
  88. package/third-party/include/executorch/runtime/executor/program.h +3 -1
  89. package/third-party/include/executorch/runtime/executor/tensor_parser.h +5 -1
  90. package/third-party/include/executorch/runtime/kernel/operator_registry.h +9 -0
  91. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/ExecutorchLib +0 -0
  92. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/Info.plist +0 -0
  93. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/mlx.metallib +0 -0
  94. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/ExecutorchLib +0 -0
  95. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/Info.plist +0 -0
  96. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/mlx.metallib +0 -0
@@ -20,8 +20,12 @@
20
20
  namespace torch {
21
21
  namespace executor {
22
22
  namespace native {
23
+ torch::executor::Tensor & _adaptive_avg_pool2d_out(const torch::executor::Tensor & self, torch::executor::ArrayRef<int64_t> output_size, torch::executor::Tensor & out);
24
+ torch::executor::Tensor & _adaptive_avg_pool2d_out(torch::executor::KernelRuntimeContext & context, const torch::executor::Tensor & self, torch::executor::ArrayRef<int64_t> output_size, torch::executor::Tensor & out);
23
25
  torch::executor::Tensor & _cdist_forward_out(const torch::executor::Tensor & x1, const torch::executor::Tensor & x2, double p, torch::executor::optional<int64_t> compute_mode, torch::executor::Tensor & out);
24
26
  torch::executor::Tensor & _cdist_forward_out(torch::executor::KernelRuntimeContext & context, const torch::executor::Tensor & x1, const torch::executor::Tensor & x2, double p, torch::executor::optional<int64_t> compute_mode, torch::executor::Tensor & out);
27
+ torch::executor::Tensor & _conj_physical_out(const torch::executor::Tensor & self, torch::executor::Tensor & out);
28
+ torch::executor::Tensor & _conj_physical_out(torch::executor::KernelRuntimeContext & context, const torch::executor::Tensor & self, torch::executor::Tensor & out);
25
29
  torch::executor::Tensor & log_softmax_out(const torch::executor::Tensor & self, int64_t dim, bool half_to_float, torch::executor::Tensor & out);
26
30
  torch::executor::Tensor & log_softmax_out(torch::executor::KernelRuntimeContext & context, const torch::executor::Tensor & self, int64_t dim, bool half_to_float, torch::executor::Tensor & out);
27
31
  ::std::tuple<torch::executor::Tensor &,torch::executor::Tensor &,torch::executor::Tensor &> _native_batch_norm_legit_out(const torch::executor::Tensor & input, const torch::executor::optional<torch::executor::Tensor> & weight, const torch::executor::optional<torch::executor::Tensor> & bias, torch::executor::Tensor & running_mean, torch::executor::Tensor & running_var, bool training, double momentum, double eps, torch::executor::Tensor & out, torch::executor::Tensor & save_mean, torch::executor::Tensor & save_invstd);
@@ -412,6 +416,8 @@ torch::executor::Tensor & upsample_nearest2d_vec_out(const torch::executor::Tens
412
416
  torch::executor::Tensor & upsample_nearest2d_vec_out(torch::executor::KernelRuntimeContext & context, const torch::executor::Tensor & input, torch::executor::optional<torch::executor::ArrayRef<int64_t>> output_size, torch::executor::optional<torch::executor::ArrayRef<double>> scale_factors, torch::executor::Tensor & out);
413
417
  torch::executor::Tensor & var_correction_out(const torch::executor::Tensor & self, torch::executor::optional<torch::executor::ArrayRef<int64_t>> dim, const torch::executor::optional<torch::executor::Scalar> & correction, bool keepdim, torch::executor::Tensor & out);
414
418
  torch::executor::Tensor & var_correction_out(torch::executor::KernelRuntimeContext & context, const torch::executor::Tensor & self, torch::executor::optional<torch::executor::ArrayRef<int64_t>> dim, const torch::executor::optional<torch::executor::Scalar> & correction, bool keepdim, torch::executor::Tensor & out);
419
+ ::std::tuple<torch::executor::Tensor &,torch::executor::Tensor &> var_mean_correction_out(const torch::executor::Tensor & self, torch::executor::optional<torch::executor::ArrayRef<int64_t>> dim, const torch::executor::optional<torch::executor::Scalar> & correction, bool keepdim, torch::executor::Tensor & out0, torch::executor::Tensor & out1);
420
+ ::std::tuple<torch::executor::Tensor &,torch::executor::Tensor &> var_mean_correction_out(torch::executor::KernelRuntimeContext & context, const torch::executor::Tensor & self, torch::executor::optional<torch::executor::ArrayRef<int64_t>> dim, const torch::executor::optional<torch::executor::Scalar> & correction, bool keepdim, torch::executor::Tensor & out0, torch::executor::Tensor & out1);
415
421
  torch::executor::Tensor & var_out(const torch::executor::Tensor & self, torch::executor::optional<torch::executor::ArrayRef<int64_t>> dim, bool unbiased, bool keepdim, torch::executor::Tensor & out);
416
422
  torch::executor::Tensor & var_out(torch::executor::KernelRuntimeContext & context, const torch::executor::Tensor & self, torch::executor::optional<torch::executor::ArrayRef<int64_t>> dim, bool unbiased, bool keepdim, torch::executor::Tensor & out);
417
423
  torch::executor::Tensor & view_as_real_copy_out(const torch::executor::Tensor & self, torch::executor::Tensor & out);
@@ -11,6 +11,7 @@
11
11
  #include <executorch/runtime/backend/options.h>
12
12
  #include <executorch/runtime/core/error.h>
13
13
  #include <executorch/runtime/core/span.h>
14
+ #include <executorch/runtime/platform/assert.h>
14
15
 
15
16
  #include <cstring>
16
17
 
@@ -168,6 +169,42 @@ public:
168
169
  */
169
170
  size_t size() const { return size_; }
170
171
 
172
+ /**
173
+ * Non-owning view of a single (backend_id, options) entry, returned by
174
+ * entry_at(). The pointer / span are valid until the map is mutated or
175
+ * destroyed.
176
+ */
177
+ struct EntryView {
178
+ const char *backend_id = nullptr;
179
+ Span<const BackendOption> options;
180
+ };
181
+
182
+ /**
183
+ * Returns the (backend_id, options) entry at the given index for
184
+ * enumeration over the map's contents.
185
+ *
186
+ * @param index The entry index. Must be < size(); behavior is undefined
187
+ * otherwise. Use this together with size() to walk every entry.
188
+ * @return EntryView referencing the entry's backend_id and options. The
189
+ * view is valid until the next mutation of, or destruction of, this
190
+ * map.
191
+ *
192
+ * Example:
193
+ * @code
194
+ * for (size_t i = 0; i < map.size(); ++i) {
195
+ * const auto entry = map.entry_at(i);
196
+ * // use entry.backend_id and entry.options ...
197
+ * }
198
+ * @endcode
199
+ */
200
+ EntryView entry_at(size_t index) const {
201
+ ET_DCHECK_MSG(index < size_, "entry_at index %zu out of bounds (size=%zu)",
202
+ index, size_);
203
+ return EntryView{entries_[index].backend_id,
204
+ Span<const BackendOption>(entries_[index].options.data(),
205
+ entries_[index].options.size())};
206
+ }
207
+
171
208
  private:
172
209
  static constexpr size_t kMaxBackends = 8;
173
210
  static constexpr size_t kMaxBackendIdLength = 64;
@@ -30,6 +30,7 @@
30
30
  #include <cstdint>
31
31
 
32
32
  #include <c10/util/irange.h>
33
+ #include <c10/util/safe_numerics.h>
33
34
  #include <executorch/runtime/platform/assert.h>
34
35
 
35
36
  namespace executorch {
@@ -146,7 +147,8 @@ public:
146
147
  /// slice(n, m) - Take M elements of the array starting at element N
147
148
  ArrayRef<T> slice(size_t N, size_t M) const {
148
149
  // cant slice longer then the array
149
- ET_CHECK(N + M <= size());
150
+ size_t end = 0;
151
+ ET_CHECK(!c10::add_overflows(N, M, &end) && end <= size());
150
152
  return ArrayRef<T>(data() + N, M);
151
153
  }
152
154
 
@@ -152,6 +152,7 @@ constexpr const char *to_string(const Error error) {
152
152
  case Error::RegistrationAlreadyRegistered:
153
153
  return "Error::RegistrationAlreadyRegistered";
154
154
  }
155
+ return "Error::Unknown";
155
156
  }
156
157
 
157
158
  } // namespace runtime
@@ -8,6 +8,7 @@
8
8
 
9
9
  #pragma once
10
10
  #include <executorch/runtime/core/exec_aten/exec_aten.h>
11
+ #include <executorch/runtime/core/result.h>
11
12
  #include <executorch/runtime/core/tag.h>
12
13
  #include <executorch/runtime/platform/assert.h>
13
14
 
@@ -67,6 +68,29 @@ public:
67
68
  */
68
69
  executorch::aten::ArrayRef<T> get() const;
69
70
 
71
+ /**
72
+ * Result-returning counterpart of get(). Validates each wrapped EValue's
73
+ * tag before materializing; returns Error::InvalidType if any element's
74
+ * tag does not match T and Error::InvalidState if any element pointer is
75
+ * null. Use this when materializing lists from untrusted .pte data so that
76
+ * a malformed program cannot force a process abort inside to<T>() /
77
+ * ET_CHECK.
78
+ */
79
+ Result<executorch::aten::ArrayRef<T>> tryGet() const;
80
+
81
+ /**
82
+ * Destroys the unwrapped elements without re-dereferencing wrapped_vals_.
83
+ * This is safe to call during EValue destruction because it does not
84
+ * dereference wrapped_vals_, which may point to EValues mutated by
85
+ * MoveCall instructions.
86
+ */
87
+ void destroy_elements() {
88
+ for (typename executorch::aten::ArrayRef<T>::size_type i = 0;
89
+ i < wrapped_vals_.size(); i++) {
90
+ unwrapped_vals_[i].~T();
91
+ }
92
+ }
93
+
70
94
  private:
71
95
  static EValue **checkWrappedVals(EValue **wrapped_vals, int size) {
72
96
  ET_CHECK_MSG(wrapped_vals != nullptr, "wrapped_vals cannot be null");
@@ -89,6 +113,10 @@ template <>
89
113
  executorch::aten::ArrayRef<std::optional<executorch::aten::Tensor>>
90
114
  BoxedEvalueList<std::optional<executorch::aten::Tensor>>::get() const;
91
115
 
116
+ template <>
117
+ Result<executorch::aten::ArrayRef<std::optional<executorch::aten::Tensor>>>
118
+ BoxedEvalueList<std::optional<executorch::aten::Tensor>>::tryGet() const;
119
+
92
120
  // Aggregate typing system similar to IValue only slimmed down with less
93
121
  // functionality, no dependencies on atomic, and fewer supported types to better
94
122
  // suit embedded systems (ie no intrusive ptr)
@@ -165,6 +193,13 @@ struct EValue {
165
193
  return payload.copyable_union.as_int;
166
194
  }
167
195
 
196
+ Result<int64_t> tryToInt() const {
197
+ if (!isInt()) {
198
+ return Error::InvalidType;
199
+ }
200
+ return payload.copyable_union.as_int;
201
+ }
202
+
168
203
  /****** Double Type ******/
169
204
  /*implicit*/ EValue(double d) : tag(Tag::Double) {
170
205
  payload.copyable_union.as_double = d;
@@ -177,6 +212,13 @@ struct EValue {
177
212
  return payload.copyable_union.as_double;
178
213
  }
179
214
 
215
+ Result<double> tryToDouble() const {
216
+ if (!isDouble()) {
217
+ return Error::InvalidType;
218
+ }
219
+ return payload.copyable_union.as_double;
220
+ }
221
+
180
222
  /****** Bool Type ******/
181
223
  /*implicit*/ EValue(bool b) : tag(Tag::Bool) {
182
224
  payload.copyable_union.as_bool = b;
@@ -189,6 +231,13 @@ struct EValue {
189
231
  return payload.copyable_union.as_bool;
190
232
  }
191
233
 
234
+ Result<bool> tryToBool() const {
235
+ if (!isBool()) {
236
+ return Error::InvalidType;
237
+ }
238
+ return payload.copyable_union.as_bool;
239
+ }
240
+
192
241
  /****** Scalar Type ******/
193
242
  /// Construct an EValue using the implicit value of a Scalar.
194
243
  /*implicit*/ EValue(executorch::aten::Scalar s) {
@@ -224,6 +273,19 @@ struct EValue {
224
273
  }
225
274
  }
226
275
 
276
+ Result<executorch::aten::Scalar> tryToScalar() const {
277
+ if (isDouble()) {
278
+ return executorch::aten::Scalar(payload.copyable_union.as_double);
279
+ }
280
+ if (isInt()) {
281
+ return executorch::aten::Scalar(payload.copyable_union.as_int);
282
+ }
283
+ if (isBool()) {
284
+ return executorch::aten::Scalar(payload.copyable_union.as_bool);
285
+ }
286
+ return Error::InvalidType;
287
+ }
288
+
227
289
  /****** Tensor Type ******/
228
290
  /*implicit*/ EValue(executorch::aten::Tensor t) : tag(Tag::Tensor) {
229
291
  // When built in aten mode, at::Tensor has a non trivial constructor
@@ -270,6 +332,16 @@ struct EValue {
270
332
  return payload.as_tensor;
271
333
  }
272
334
 
335
+ // Returns a copy of the Tensor handle (one intrusive_ptr refcount bump in
336
+ // ATen mode; free in lean mode). Unlike toTensor()'s const& / & overloads,
337
+ // tryToTensor() cannot return a reference — Result<T> wraps by value.
338
+ Result<executorch::aten::Tensor> tryToTensor() const {
339
+ if (!isTensor()) {
340
+ return Error::InvalidType;
341
+ }
342
+ return payload.as_tensor;
343
+ }
344
+
273
345
  /****** String Type ******/
274
346
  /*implicit*/ EValue(executorch::aten::ArrayRef<char> *s) : tag(Tag::String) {
275
347
  ET_CHECK_MSG(s != nullptr, "ArrayRef<char> pointer cannot be null");
@@ -286,6 +358,17 @@ struct EValue {
286
358
  payload.copyable_union.as_string_ptr->size());
287
359
  }
288
360
 
361
+ Result<std::string_view> tryToString() const {
362
+ if (!isString()) {
363
+ return Error::InvalidType;
364
+ }
365
+ if (payload.copyable_union.as_string_ptr == nullptr) {
366
+ return Error::InvalidState;
367
+ }
368
+ return std::string_view(payload.copyable_union.as_string_ptr->data(),
369
+ payload.copyable_union.as_string_ptr->size());
370
+ }
371
+
289
372
  /****** Int List Type ******/
290
373
  /*implicit*/ EValue(BoxedEvalueList<int64_t> *i) : tag(Tag::ListInt) {
291
374
  ET_CHECK_MSG(i != nullptr,
@@ -302,6 +385,16 @@ struct EValue {
302
385
  return (payload.copyable_union.as_int_list_ptr)->get();
303
386
  }
304
387
 
388
+ Result<executorch::aten::ArrayRef<int64_t>> tryToIntList() const {
389
+ if (!isIntList()) {
390
+ return Error::InvalidType;
391
+ }
392
+ if (payload.copyable_union.as_int_list_ptr == nullptr) {
393
+ return Error::InvalidState;
394
+ }
395
+ return (payload.copyable_union.as_int_list_ptr)->tryGet();
396
+ }
397
+
305
398
  /****** Bool List Type ******/
306
399
  /*implicit*/ EValue(executorch::aten::ArrayRef<bool> *b)
307
400
  : tag(Tag::ListBool) {
@@ -318,6 +411,16 @@ struct EValue {
318
411
  return *(payload.copyable_union.as_bool_list_ptr);
319
412
  }
320
413
 
414
+ Result<executorch::aten::ArrayRef<bool>> tryToBoolList() const {
415
+ if (!isBoolList()) {
416
+ return Error::InvalidType;
417
+ }
418
+ if (payload.copyable_union.as_bool_list_ptr == nullptr) {
419
+ return Error::InvalidState;
420
+ }
421
+ return *(payload.copyable_union.as_bool_list_ptr);
422
+ }
423
+
321
424
  /****** Double List Type ******/
322
425
  /*implicit*/ EValue(executorch::aten::ArrayRef<double> *d)
323
426
  : tag(Tag::ListDouble) {
@@ -334,6 +437,16 @@ struct EValue {
334
437
  return *(payload.copyable_union.as_double_list_ptr);
335
438
  }
336
439
 
440
+ Result<executorch::aten::ArrayRef<double>> tryToDoubleList() const {
441
+ if (!isDoubleList()) {
442
+ return Error::InvalidType;
443
+ }
444
+ if (payload.copyable_union.as_double_list_ptr == nullptr) {
445
+ return Error::InvalidState;
446
+ }
447
+ return *(payload.copyable_union.as_double_list_ptr);
448
+ }
449
+
337
450
  /****** Tensor List Type ******/
338
451
  /*implicit*/ EValue(BoxedEvalueList<executorch::aten::Tensor> *t)
339
452
  : tag(Tag::ListTensor) {
@@ -351,6 +464,17 @@ struct EValue {
351
464
  return payload.copyable_union.as_tensor_list_ptr->get();
352
465
  }
353
466
 
467
+ Result<executorch::aten::ArrayRef<executorch::aten::Tensor>>
468
+ tryToTensorList() const {
469
+ if (!isTensorList()) {
470
+ return Error::InvalidType;
471
+ }
472
+ if (payload.copyable_union.as_tensor_list_ptr == nullptr) {
473
+ return Error::InvalidState;
474
+ }
475
+ return payload.copyable_union.as_tensor_list_ptr->tryGet();
476
+ }
477
+
354
478
  /****** List Optional Tensor Type ******/
355
479
  /*implicit*/ EValue(
356
480
  BoxedEvalueList<std::optional<executorch::aten::Tensor>> *t)
@@ -371,6 +495,17 @@ struct EValue {
371
495
  return payload.copyable_union.as_list_optional_tensor_ptr->get();
372
496
  }
373
497
 
498
+ Result<executorch::aten::ArrayRef<std::optional<executorch::aten::Tensor>>>
499
+ tryToListOptionalTensor() const {
500
+ if (!isListOptionalTensor()) {
501
+ return Error::InvalidType;
502
+ }
503
+ if (payload.copyable_union.as_list_optional_tensor_ptr == nullptr) {
504
+ return Error::InvalidState;
505
+ }
506
+ return payload.copyable_union.as_list_optional_tensor_ptr->tryGet();
507
+ }
508
+
374
509
  /****** ScalarType Type ******/
375
510
  executorch::aten::ScalarType toScalarType() const {
376
511
  ET_CHECK_MSG(isInt(), "EValue is not a ScalarType.");
@@ -378,6 +513,14 @@ struct EValue {
378
513
  payload.copyable_union.as_int);
379
514
  }
380
515
 
516
+ Result<executorch::aten::ScalarType> tryToScalarType() const {
517
+ if (!isInt()) {
518
+ return Error::InvalidType;
519
+ }
520
+ return static_cast<executorch::aten::ScalarType>(
521
+ payload.copyable_union.as_int);
522
+ }
523
+
381
524
  /****** MemoryFormat Type ******/
382
525
  executorch::aten::MemoryFormat toMemoryFormat() const {
383
526
  ET_CHECK_MSG(isInt(), "EValue is not a MemoryFormat.");
@@ -385,12 +528,27 @@ struct EValue {
385
528
  payload.copyable_union.as_int);
386
529
  }
387
530
 
531
+ Result<executorch::aten::MemoryFormat> tryToMemoryFormat() const {
532
+ if (!isInt()) {
533
+ return Error::InvalidType;
534
+ }
535
+ return static_cast<executorch::aten::MemoryFormat>(
536
+ payload.copyable_union.as_int);
537
+ }
538
+
388
539
  /****** Layout Type ******/
389
540
  executorch::aten::Layout toLayout() const {
390
541
  ET_CHECK_MSG(isInt(), "EValue is not a Layout.");
391
542
  return static_cast<executorch::aten::Layout>(payload.copyable_union.as_int);
392
543
  }
393
544
 
545
+ Result<executorch::aten::Layout> tryToLayout() const {
546
+ if (!isInt()) {
547
+ return Error::InvalidType;
548
+ }
549
+ return static_cast<executorch::aten::Layout>(payload.copyable_union.as_int);
550
+ }
551
+
394
552
  /****** Device Type ******/
395
553
  executorch::aten::Device toDevice() const {
396
554
  ET_CHECK_MSG(isInt(), "EValue is not a Device.");
@@ -399,12 +557,29 @@ struct EValue {
399
557
  -1);
400
558
  }
401
559
 
560
+ Result<executorch::aten::Device> tryToDevice() const {
561
+ if (!isInt()) {
562
+ return Error::InvalidType;
563
+ }
564
+ return executorch::aten::Device(static_cast<executorch::aten::DeviceType>(
565
+ payload.copyable_union.as_int),
566
+ -1);
567
+ }
568
+
402
569
  template <typename T> T to() &&;
403
570
  template <typename T>
404
571
  typename internal::evalue_to_const_ref_overload_return<T>::type to() const &;
405
572
  template <typename T>
406
573
  typename internal::evalue_to_ref_overload_return<T>::type to() &;
407
574
 
575
+ /**
576
+ * Result-returning equivalent of `to<T>()`. Tag mismatch returns
577
+ * `Error::InvalidType`; a null list/string payload returns
578
+ * `Error::InvalidState`. Specializations are defined below via
579
+ * `EVALUE_DEFINE_TRY_TO`.
580
+ */
581
+ template <typename T> Result<T> tryTo() const;
582
+
408
583
  /**
409
584
  * Converts the EValue to an optional object that can represent both T and
410
585
  * an uninitialized state.
@@ -416,6 +591,22 @@ struct EValue {
416
591
  return this->to<T>();
417
592
  }
418
593
 
594
+ /**
595
+ * Result-returning equivalent of `toOptional<T>()`. None maps to an empty
596
+ * optional; any other tag that doesn't match T propagates `tryTo<T>()`'s
597
+ * error (`Error::InvalidType`).
598
+ */
599
+ template <typename T> inline Result<std::optional<T>> tryToOptional() const {
600
+ if (this->isNone()) {
601
+ return std::optional<T>(std::nullopt);
602
+ }
603
+ auto r = this->tryTo<T>();
604
+ if (!r.ok()) {
605
+ return r.error();
606
+ }
607
+ return std::optional<T>(std::move(r.get()));
608
+ }
609
+
419
610
  private:
420
611
  // Pre cond: the payload value has had its destructor called
421
612
  void clearToNone() noexcept {
@@ -446,17 +637,10 @@ private:
446
637
  payload.as_tensor.~Tensor();
447
638
  } else if (isTensorList() &&
448
639
  payload.copyable_union.as_tensor_list_ptr != nullptr) {
449
- // for (auto& tensor : toTensorList()) {
450
- for (auto &tensor : payload.copyable_union.as_tensor_list_ptr->get()) {
451
- tensor.~Tensor();
452
- }
640
+ payload.copyable_union.as_tensor_list_ptr->destroy_elements();
453
641
  } else if (isListOptionalTensor() &&
454
642
  payload.copyable_union.as_list_optional_tensor_ptr != nullptr) {
455
- // for (auto& optional_tensor : toListOptionalTensor()) {
456
- for (auto &optional_tensor :
457
- payload.copyable_union.as_list_optional_tensor_ptr->get()) {
458
- optional_tensor.~optional();
459
- }
643
+ payload.copyable_union.as_list_optional_tensor_ptr->destroy_elements();
460
644
  }
461
645
  }
462
646
 
@@ -532,6 +716,53 @@ EVALUE_DEFINE_TO(
532
716
  toListOptionalTensor)
533
717
  #undef EVALUE_DEFINE_TO
534
718
 
719
+ #define EVALUE_DEFINE_TRY_TO(T, method_name) \
720
+ template <> inline Result<T> EValue::tryTo<T>() const { \
721
+ return this->method_name(); \
722
+ }
723
+
724
+ EVALUE_DEFINE_TRY_TO(executorch::aten::Scalar, tryToScalar)
725
+ EVALUE_DEFINE_TRY_TO(int64_t, tryToInt)
726
+ EVALUE_DEFINE_TRY_TO(bool, tryToBool)
727
+ EVALUE_DEFINE_TRY_TO(double, tryToDouble)
728
+ EVALUE_DEFINE_TRY_TO(std::string_view, tryToString)
729
+ EVALUE_DEFINE_TRY_TO(executorch::aten::ScalarType, tryToScalarType)
730
+ EVALUE_DEFINE_TRY_TO(executorch::aten::MemoryFormat, tryToMemoryFormat)
731
+ EVALUE_DEFINE_TRY_TO(executorch::aten::Layout, tryToLayout)
732
+ EVALUE_DEFINE_TRY_TO(executorch::aten::Device, tryToDevice)
733
+ // Tensor and Optional Tensor
734
+ EVALUE_DEFINE_TRY_TO(executorch::aten::Tensor, tryToTensor)
735
+ EVALUE_DEFINE_TRY_TO(std::optional<executorch::aten::Tensor>,
736
+ tryToOptional<executorch::aten::Tensor>)
737
+
738
+ // IntList and Optional IntList
739
+ EVALUE_DEFINE_TRY_TO(executorch::aten::ArrayRef<int64_t>, tryToIntList)
740
+ EVALUE_DEFINE_TRY_TO(std::optional<executorch::aten::ArrayRef<int64_t>>,
741
+ tryToOptional<executorch::aten::ArrayRef<int64_t>>)
742
+
743
+ // DoubleList and Optional DoubleList
744
+ EVALUE_DEFINE_TRY_TO(executorch::aten::ArrayRef<double>, tryToDoubleList)
745
+ EVALUE_DEFINE_TRY_TO(std::optional<executorch::aten::ArrayRef<double>>,
746
+ tryToOptional<executorch::aten::ArrayRef<double>>)
747
+
748
+ // BoolList and Optional BoolList
749
+ EVALUE_DEFINE_TRY_TO(executorch::aten::ArrayRef<bool>, tryToBoolList)
750
+ EVALUE_DEFINE_TRY_TO(std::optional<executorch::aten::ArrayRef<bool>>,
751
+ tryToOptional<executorch::aten::ArrayRef<bool>>)
752
+
753
+ // TensorList and Optional TensorList
754
+ EVALUE_DEFINE_TRY_TO(executorch::aten::ArrayRef<executorch::aten::Tensor>,
755
+ tryToTensorList)
756
+ EVALUE_DEFINE_TRY_TO(
757
+ std::optional<executorch::aten::ArrayRef<executorch::aten::Tensor>>,
758
+ tryToOptional<executorch::aten::ArrayRef<executorch::aten::Tensor>>)
759
+
760
+ // List of Optional Tensor
761
+ EVALUE_DEFINE_TRY_TO(
762
+ executorch::aten::ArrayRef<std::optional<executorch::aten::Tensor>>,
763
+ tryToListOptionalTensor)
764
+ #undef EVALUE_DEFINE_TRY_TO
765
+
535
766
  template <typename T>
536
767
  executorch::aten::ArrayRef<T> BoxedEvalueList<T>::get() const {
537
768
  for (typename executorch::aten::ArrayRef<T>::size_type i = 0;
@@ -542,6 +773,22 @@ executorch::aten::ArrayRef<T> BoxedEvalueList<T>::get() const {
542
773
  return executorch::aten::ArrayRef<T>{unwrapped_vals_, wrapped_vals_.size()};
543
774
  }
544
775
 
776
+ template <typename T>
777
+ Result<executorch::aten::ArrayRef<T>> BoxedEvalueList<T>::tryGet() const {
778
+ for (typename executorch::aten::ArrayRef<T>::size_type i = 0;
779
+ i < wrapped_vals_.size(); i++) {
780
+ if (wrapped_vals_[i] == nullptr) {
781
+ return Error::InvalidState;
782
+ }
783
+ auto r = wrapped_vals_[i]->template tryTo<T>();
784
+ if (!r.ok()) {
785
+ return r.error();
786
+ }
787
+ unwrapped_vals_[i] = std::move(r.get());
788
+ }
789
+ return executorch::aten::ArrayRef<T>{unwrapped_vals_, wrapped_vals_.size()};
790
+ }
791
+
545
792
  } // namespace runtime
546
793
  } // namespace executorch
547
794
 
@@ -8,7 +8,10 @@
8
8
 
9
9
  #pragma once
10
10
 
11
+ #include <executorch/runtime/core/error.h> // @manual
12
+ #include <executorch/runtime/core/result.h> // @manual
11
13
  #include <executorch/runtime/core/tensor_shape_dynamism.h> // @manual
14
+ #include <executorch/runtime/platform/assert.h> // @manual
12
15
  #include <executorch/runtime/platform/compiler.h>
13
16
  #ifdef USE_ATEN_LIB
14
17
  #include <ATen/Tensor.h> // @manual
@@ -28,6 +31,7 @@
28
31
  #include <c10/util/quint2x4.h> // @manual
29
32
  #include <c10/util/quint4x2.h> // @manual
30
33
  #include <c10/util/quint8.h> // @manual
34
+ #include <c10/util/safe_numerics.h> // @manual
31
35
  #include <c10/util/string_view.h> // @manual
32
36
  #include <torch/torch.h>
33
37
  #else // use executor
@@ -107,6 +111,25 @@ inline ssize_t compute_numel(const SizesType *sizes, ssize_t dim) {
107
111
  c10::multiply_integers(c10::ArrayRef<SizesType>(sizes, dim)));
108
112
  }
109
113
 
114
+ inline ::executorch::runtime::Result<ssize_t> safe_numel(const SizesType *sizes,
115
+ ssize_t dim) {
116
+ ET_CHECK_OR_RETURN_ERROR(dim == 0 || sizes != nullptr, InvalidArgument,
117
+ "Sizes must be provided for non-scalar tensors");
118
+ ssize_t numel = 1;
119
+ for (ssize_t i = 0; i < dim; i++) {
120
+ ET_CHECK_OR_RETURN_ERROR(
121
+ sizes[i] >= 0, InvalidArgument,
122
+ "Size must be non-negative, got %zd at dimension %zd",
123
+ static_cast<ssize_t>(sizes[i]), i);
124
+ ssize_t next_numel;
125
+ ET_CHECK_OR_RETURN_ERROR(
126
+ !c10::mul_overflows(numel, static_cast<ssize_t>(sizes[i]), &next_numel),
127
+ InvalidArgument, "Overflow computing numel at dimension %zd", i);
128
+ numel = next_numel;
129
+ }
130
+ return numel;
131
+ }
132
+
110
133
  #undef ET_PRI_TENSOR_SIZE
111
134
  #define ET_PRI_TENSOR_SIZE PRId64
112
135
 
@@ -153,6 +176,7 @@ using OptionalArrayRef =
153
176
  using OptionalIntArrayRef = OptionalArrayRef<int64_t>;
154
177
 
155
178
  using torch::executor::compute_numel;
179
+ using torch::executor::safe_numel;
156
180
 
157
181
  #endif // Use ExecuTorch types
158
182
 
@@ -9,6 +9,7 @@
9
9
  #pragma once
10
10
 
11
11
  #include <c10/util/irange.h>
12
+ #include <c10/util/safe_numerics.h>
12
13
 
13
14
  #include <executorch/runtime/core/memory_allocator.h>
14
15
  #include <executorch/runtime/core/result.h>
@@ -56,17 +57,19 @@ public:
56
57
  size_t offset_bytes,
57
58
  size_t size_bytes) {
58
59
  // Check for integer overflow in offset_bytes + size_bytes.
59
- ET_CHECK_OR_RETURN_ERROR(size_bytes <= SIZE_MAX - offset_bytes,
60
- InvalidArgument,
61
- "Integer overflow in offset_bytes (%" ET_PRIsize_t
62
- ") + size_bytes (%" ET_PRIsize_t ")",
63
- offset_bytes, size_bytes);
60
+ size_t end_bytes = 0;
61
+ ET_CHECK_OR_RETURN_ERROR(
62
+ !c10::add_overflows(offset_bytes, size_bytes, &end_bytes),
63
+ InvalidArgument,
64
+ "Integer overflow in offset_bytes (%" ET_PRIsize_t
65
+ ") + size_bytes (%" ET_PRIsize_t ")",
66
+ offset_bytes, size_bytes);
64
67
  ET_CHECK_OR_RETURN_ERROR(memory_id < buffers_.size(), InvalidArgument,
65
68
  "id %" PRIu32 " >= %" ET_PRIsize_t, memory_id,
66
69
  buffers_.size());
67
70
  Span<uint8_t> buffer = buffers_[memory_id];
68
71
  ET_CHECK_OR_RETURN_ERROR(
69
- offset_bytes + size_bytes <= buffer.size(), MemoryAllocationFailed,
72
+ end_bytes <= buffer.size(), MemoryAllocationFailed,
70
73
  "offset_bytes (%" ET_PRIsize_t ") + size_bytes (%" ET_PRIsize_t
71
74
  ") >= allocator size (%" ET_PRIsize_t ") "
72
75
  "for memory_id %" PRIu32,
@@ -26,7 +26,6 @@ enum class DeviceType : int8_t {
26
26
  constexpr size_t kNumDeviceTypes = 2;
27
27
 
28
28
  /// An index representing a specific device; e.g. GPU 0 vs GPU 1.
29
- /// -1 means the default/unspecified device for that type.
30
29
  using DeviceIndex = int8_t;
31
30
 
32
31
  /**
@@ -41,7 +40,7 @@ struct Device final {
41
40
 
42
41
  /// Constructs a new `Device` from a `DeviceType` and an optional device
43
42
  /// index.
44
- /* implicit */ Device(DeviceType type, DeviceIndex index = -1)
43
+ /* implicit */ Device(DeviceType type, DeviceIndex index = 0)
45
44
  : type_(type), index_(index) {}
46
45
 
47
46
  /// Returns the type of device the tensor data resides on.
@@ -50,7 +49,7 @@ struct Device final {
50
49
  /// Returns true if the device is of CPU type.
51
50
  bool is_cpu() const noexcept { return type_ == DeviceType::CPU; }
52
51
 
53
- /// Returns the device index, or -1 if default/unspecified.
52
+ /// Returns the device index.
54
53
  DeviceIndex index() const noexcept { return index_; }
55
54
 
56
55
  bool operator==(const Device &other) const noexcept {
@@ -63,7 +62,7 @@ struct Device final {
63
62
 
64
63
  private:
65
64
  DeviceType type_;
66
- DeviceIndex index_ = -1;
65
+ DeviceIndex index_ = 0;
67
66
  };
68
67
 
69
68
  } // namespace etensor