react-native-executorch 0.5.15 → 0.6.0-nightly-897eae9-20251213

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (277) hide show
  1. package/README.md +42 -36
  2. package/android/CMakeLists.txt +13 -25
  3. package/android/build.gradle +2 -3
  4. package/android/libs/classes.jar +0 -0
  5. package/android/src/main/cpp/CMakeLists.txt +2 -1
  6. package/common/rnexecutorch/RnExecutorchInstaller.cpp +18 -0
  7. package/common/rnexecutorch/TokenizerModule.cpp +3 -3
  8. package/common/rnexecutorch/data_processing/Numerical.cpp +31 -23
  9. package/common/rnexecutorch/data_processing/Numerical.h +6 -1
  10. package/common/rnexecutorch/data_processing/dsp.cpp +0 -46
  11. package/common/rnexecutorch/host_objects/JsiConversions.h +16 -0
  12. package/common/rnexecutorch/host_objects/ModelHostObject.h +26 -11
  13. package/common/rnexecutorch/jsi/OwningArrayBuffer.h +19 -2
  14. package/common/rnexecutorch/metaprogramming/TypeConcepts.h +0 -20
  15. package/common/rnexecutorch/models/BaseModel.cpp +12 -11
  16. package/common/rnexecutorch/models/BaseModel.h +18 -10
  17. package/common/rnexecutorch/models/embeddings/BaseEmbeddings.cpp +3 -11
  18. package/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp +0 -1
  19. package/common/rnexecutorch/models/image_segmentation/ImageSegmentation.cpp +6 -12
  20. package/common/rnexecutorch/models/llm/LLM.cpp +25 -8
  21. package/common/rnexecutorch/models/llm/LLM.h +4 -4
  22. package/common/rnexecutorch/models/ocr/CTCLabelConverter.h +1 -1
  23. package/common/rnexecutorch/models/ocr/utils/RecognitionHandlerUtils.cpp +7 -4
  24. package/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp +8 -13
  25. package/common/rnexecutorch/models/speech_to_text/SpeechToText.h +1 -3
  26. package/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp +12 -19
  27. package/common/rnexecutorch/models/speech_to_text/asr/ASR.h +4 -5
  28. package/common/rnexecutorch/models/text_to_image/Constants.h +9 -0
  29. package/common/rnexecutorch/models/text_to_image/Decoder.cpp +32 -0
  30. package/common/rnexecutorch/models/text_to_image/Decoder.h +24 -0
  31. package/common/rnexecutorch/models/text_to_image/Encoder.cpp +44 -0
  32. package/common/rnexecutorch/models/text_to_image/Encoder.h +32 -0
  33. package/common/rnexecutorch/models/text_to_image/Scheduler.cpp +152 -0
  34. package/common/rnexecutorch/models/text_to_image/Scheduler.h +41 -0
  35. package/common/rnexecutorch/models/text_to_image/TextToImage.cpp +141 -0
  36. package/common/rnexecutorch/models/text_to_image/TextToImage.h +64 -0
  37. package/common/rnexecutorch/models/text_to_image/UNet.cpp +38 -0
  38. package/common/rnexecutorch/models/text_to_image/UNet.h +28 -0
  39. package/common/rnexecutorch/models/voice_activity_detection/Constants.h +27 -0
  40. package/common/rnexecutorch/models/voice_activity_detection/Types.h +12 -0
  41. package/common/rnexecutorch/models/voice_activity_detection/Utils.cpp +15 -0
  42. package/common/rnexecutorch/models/voice_activity_detection/Utils.h +13 -0
  43. package/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.cpp +160 -0
  44. package/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.h +36 -0
  45. package/common/rnexecutorch/tests/CMakeLists.txt +30 -0
  46. package/common/rnexecutorch/tests/NumericalTest.cpp +110 -0
  47. package/common/rnexecutorch/tests/README.md +30 -13
  48. package/common/rnexecutorch/threads/GlobalThreadPool.h +4 -0
  49. package/common/runner/arange_util.cpp +44 -0
  50. package/common/runner/arange_util.h +37 -0
  51. package/common/runner/constants.h +28 -0
  52. package/common/runner/io_manager.h +240 -0
  53. package/common/runner/irunner.h +87 -16
  54. package/common/runner/kernel_includes.h +23 -0
  55. package/common/runner/runner.cpp +151 -66
  56. package/common/runner/runner.h +39 -22
  57. package/common/runner/sampler.cpp +8 -1
  58. package/common/runner/sampler.h +4 -2
  59. package/common/runner/stats.h +1 -4
  60. package/common/runner/text_decoder_runner.cpp +26 -12
  61. package/common/runner/text_decoder_runner.h +52 -31
  62. package/common/runner/text_prefiller.cpp +46 -12
  63. package/common/runner/text_prefiller.h +38 -4
  64. package/common/runner/text_token_generator.h +51 -26
  65. package/common/runner/util.h +53 -8
  66. package/ios/RnExecutorch.xcodeproj/project.pbxproj +0 -23
  67. package/lib/module/Error.js +1 -0
  68. package/lib/module/Error.js.map +1 -1
  69. package/lib/module/constants/directories.js +1 -1
  70. package/lib/module/constants/directories.js.map +1 -1
  71. package/lib/module/constants/modelUrls.js +32 -1
  72. package/lib/module/constants/modelUrls.js.map +1 -1
  73. package/lib/module/constants/ocr/models.js +7 -7
  74. package/lib/module/constants/ocr/models.js.map +1 -1
  75. package/lib/module/constants/ocr/symbols.js +3 -2
  76. package/lib/module/constants/ocr/symbols.js.map +1 -1
  77. package/lib/module/controllers/LLMController.js +10 -1
  78. package/lib/module/controllers/LLMController.js.map +1 -1
  79. package/lib/module/controllers/OCRController.js +3 -3
  80. package/lib/module/controllers/OCRController.js.map +1 -1
  81. package/lib/module/controllers/VerticalOCRController.js +2 -2
  82. package/lib/module/controllers/VerticalOCRController.js.map +1 -1
  83. package/lib/module/hooks/computer_vision/useOCR.js +3 -3
  84. package/lib/module/hooks/computer_vision/useOCR.js.map +1 -1
  85. package/lib/module/hooks/{useNonStaticModule.js → computer_vision/useTextToImage.js} +21 -16
  86. package/lib/module/hooks/computer_vision/useTextToImage.js.map +1 -0
  87. package/lib/module/hooks/computer_vision/useVerticalOCR.js +3 -3
  88. package/lib/module/hooks/computer_vision/useVerticalOCR.js.map +1 -1
  89. package/lib/module/hooks/natural_language_processing/useLLM.js +3 -3
  90. package/lib/module/hooks/natural_language_processing/useLLM.js.map +1 -1
  91. package/lib/module/hooks/natural_language_processing/useTokenizer.js +5 -5
  92. package/lib/module/hooks/natural_language_processing/useTokenizer.js.map +1 -1
  93. package/lib/module/hooks/natural_language_processing/useVAD.js +13 -0
  94. package/lib/module/hooks/natural_language_processing/useVAD.js.map +1 -0
  95. package/lib/module/index.js +7 -2
  96. package/lib/module/index.js.map +1 -1
  97. package/lib/module/modules/computer_vision/OCRModule.js +2 -2
  98. package/lib/module/modules/computer_vision/OCRModule.js.map +1 -1
  99. package/lib/module/modules/computer_vision/TextToImageModule.js +48 -0
  100. package/lib/module/modules/computer_vision/TextToImageModule.js.map +1 -0
  101. package/lib/module/modules/computer_vision/VerticalOCRModule.js +2 -2
  102. package/lib/module/modules/computer_vision/VerticalOCRModule.js.map +1 -1
  103. package/lib/module/modules/natural_language_processing/SpeechToTextModule.js +7 -4
  104. package/lib/module/modules/natural_language_processing/SpeechToTextModule.js.map +1 -1
  105. package/lib/module/modules/natural_language_processing/VADModule.js +19 -0
  106. package/lib/module/modules/natural_language_processing/VADModule.js.map +1 -0
  107. package/lib/module/types/llm.js.map +1 -1
  108. package/lib/module/types/vad.js +2 -0
  109. package/lib/module/types/vad.js.map +1 -0
  110. package/lib/module/utils/ResourceFetcher.js +2 -1
  111. package/lib/module/utils/ResourceFetcher.js.map +1 -1
  112. package/lib/module/utils/ResourceFetcherUtils.js +6 -6
  113. package/lib/module/utils/ResourceFetcherUtils.js.map +1 -1
  114. package/lib/typescript/Error.d.ts +1 -0
  115. package/lib/typescript/Error.d.ts.map +1 -1
  116. package/lib/typescript/constants/modelUrls.d.ts +23 -0
  117. package/lib/typescript/constants/modelUrls.d.ts.map +1 -1
  118. package/lib/typescript/constants/ocr/symbols.d.ts +1 -1
  119. package/lib/typescript/constants/ocr/symbols.d.ts.map +1 -1
  120. package/lib/typescript/controllers/LLMController.d.ts.map +1 -1
  121. package/lib/typescript/controllers/OCRController.d.ts +1 -1
  122. package/lib/typescript/controllers/OCRController.d.ts.map +1 -1
  123. package/lib/typescript/controllers/VerticalOCRController.d.ts +1 -1
  124. package/lib/typescript/controllers/VerticalOCRController.d.ts.map +1 -1
  125. package/lib/typescript/hooks/computer_vision/useOCR.d.ts +1 -1
  126. package/lib/typescript/hooks/computer_vision/useOCR.d.ts.map +1 -1
  127. package/lib/typescript/hooks/computer_vision/useTextToImage.d.ts +22 -0
  128. package/lib/typescript/hooks/computer_vision/useTextToImage.d.ts.map +1 -0
  129. package/lib/typescript/hooks/computer_vision/useVerticalOCR.d.ts +1 -1
  130. package/lib/typescript/hooks/computer_vision/useVerticalOCR.d.ts.map +1 -1
  131. package/lib/typescript/hooks/natural_language_processing/useLLM.d.ts.map +1 -1
  132. package/lib/typescript/hooks/natural_language_processing/useSpeechToText.d.ts +2 -2
  133. package/lib/typescript/hooks/natural_language_processing/useVAD.d.ts +16 -0
  134. package/lib/typescript/hooks/natural_language_processing/useVAD.d.ts.map +1 -0
  135. package/lib/typescript/index.d.ts +8 -1
  136. package/lib/typescript/index.d.ts.map +1 -1
  137. package/lib/typescript/modules/computer_vision/OCRModule.d.ts +1 -1
  138. package/lib/typescript/modules/computer_vision/OCRModule.d.ts.map +1 -1
  139. package/lib/typescript/modules/computer_vision/TextToImageModule.d.ts +16 -0
  140. package/lib/typescript/modules/computer_vision/TextToImageModule.d.ts.map +1 -0
  141. package/lib/typescript/modules/computer_vision/VerticalOCRModule.d.ts +1 -1
  142. package/lib/typescript/modules/computer_vision/VerticalOCRModule.d.ts.map +1 -1
  143. package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts +3 -2
  144. package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts.map +1 -1
  145. package/lib/typescript/modules/natural_language_processing/VADModule.d.ts +10 -0
  146. package/lib/typescript/modules/natural_language_processing/VADModule.d.ts.map +1 -0
  147. package/lib/typescript/types/llm.d.ts +2 -0
  148. package/lib/typescript/types/llm.d.ts.map +1 -1
  149. package/lib/typescript/types/vad.d.ts +5 -0
  150. package/lib/typescript/types/vad.d.ts.map +1 -0
  151. package/lib/typescript/utils/ResourceFetcher.d.ts +29 -0
  152. package/lib/typescript/utils/ResourceFetcher.d.ts.map +1 -1
  153. package/lib/typescript/utils/ResourceFetcherUtils.d.ts +2 -2
  154. package/lib/typescript/utils/ResourceFetcherUtils.d.ts.map +1 -1
  155. package/package.json +11 -8
  156. package/react-native-executorch.podspec +9 -9
  157. package/src/Error.ts +1 -0
  158. package/src/constants/directories.ts +1 -1
  159. package/src/constants/modelUrls.ts +36 -1
  160. package/src/constants/ocr/models.ts +7 -7
  161. package/src/constants/ocr/symbols.ts +3 -2
  162. package/src/controllers/LLMController.ts +12 -1
  163. package/src/controllers/OCRController.ts +3 -3
  164. package/src/controllers/VerticalOCRController.ts +2 -2
  165. package/src/hooks/computer_vision/useOCR.ts +4 -5
  166. package/src/hooks/computer_vision/useTextToImage.ts +92 -0
  167. package/src/hooks/computer_vision/useVerticalOCR.ts +4 -5
  168. package/src/hooks/natural_language_processing/useLLM.ts +3 -4
  169. package/src/hooks/natural_language_processing/useTokenizer.ts +5 -5
  170. package/src/hooks/natural_language_processing/useVAD.ts +15 -0
  171. package/src/index.ts +20 -1
  172. package/src/modules/computer_vision/OCRModule.ts +2 -2
  173. package/src/modules/computer_vision/TextToImageModule.ts +93 -0
  174. package/src/modules/computer_vision/VerticalOCRModule.ts +2 -2
  175. package/src/modules/natural_language_processing/SpeechToTextModule.ts +8 -4
  176. package/src/modules/natural_language_processing/VADModule.ts +27 -0
  177. package/src/types/llm.ts +2 -0
  178. package/src/types/vad.ts +4 -0
  179. package/src/utils/ResourceFetcher.ts +2 -1
  180. package/src/utils/ResourceFetcherUtils.ts +8 -8
  181. package/third-party/android/libs/cpuinfo/arm64-v8a/libcpuinfo.so +0 -0
  182. package/third-party/android/libs/executorch/arm64-v8a/libexecutorch.so +0 -0
  183. package/third-party/android/libs/executorch/x86_64/libexecutorch.so +0 -0
  184. package/third-party/android/libs/pthreadpool/arm64-v8a/libpthreadpool.so +0 -0
  185. package/third-party/include/c10/macros/Export.h +0 -78
  186. package/third-party/include/c10/macros/Macros.h +1 -520
  187. package/third-party/include/c10/util/BFloat16-inl.h +1 -339
  188. package/third-party/include/c10/util/BFloat16.h +1 -122
  189. package/third-party/include/c10/util/Half-inl.h +1 -347
  190. package/third-party/include/c10/util/Half.h +6 -419
  191. package/third-party/include/c10/util/TypeSafeSignMath.h +1 -133
  192. package/third-party/include/c10/util/bit_cast.h +1 -43
  193. package/third-party/include/c10/util/complex.h +1 -568
  194. package/third-party/include/c10/util/floating_point_utils.h +1 -33
  195. package/third-party/include/c10/util/irange.h +1 -1
  196. package/third-party/include/c10/util/llvmMathExtras.h +866 -0
  197. package/third-party/include/c10/util/safe_numerics.h +97 -0
  198. package/third-party/include/executorch/ExecuTorchError.h +6 -7
  199. package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLM.h +12 -0
  200. package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMConfig.h +56 -0
  201. package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMError.h +16 -0
  202. package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMMultimodalRunner.h +227 -0
  203. package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMTextRunner.h +97 -0
  204. package/third-party/include/executorch/ExecuTorchLLM/module.modulemap +4 -0
  205. package/third-party/include/executorch/ExecuTorchLog.h +1 -0
  206. package/third-party/include/executorch/ExecuTorchModule.h +177 -4
  207. package/third-party/include/executorch/ExecuTorchTensor.h +3 -4
  208. package/third-party/include/executorch/ExecuTorchValue.h +1 -7
  209. package/third-party/include/executorch/extension/module/module.h +139 -8
  210. package/third-party/include/executorch/extension/tensor/tensor.h +1 -0
  211. package/third-party/include/executorch/extension/tensor/tensor_ptr.h +88 -26
  212. package/third-party/include/executorch/extension/threadpool/threadpool.h +4 -1
  213. package/third-party/include/executorch/runtime/backend/backend_init_context.h +6 -0
  214. package/third-party/include/executorch/runtime/backend/interface.h +1 -1
  215. package/third-party/include/executorch/runtime/core/error.h +76 -49
  216. package/third-party/include/executorch/runtime/core/exec_aten/util/scalar_type_util.h +18 -4
  217. package/third-party/include/executorch/runtime/core/memory_allocator.h +12 -2
  218. package/third-party/include/executorch/runtime/core/named_data_map.h +1 -11
  219. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/macros/Export.h +0 -78
  220. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/macros/Macros.h +1 -520
  221. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/BFloat16-inl.h +1 -339
  222. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/BFloat16.h +1 -122
  223. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/Half-inl.h +1 -347
  224. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/Half.h +6 -419
  225. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/TypeSafeSignMath.h +1 -133
  226. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/bit_cast.h +1 -43
  227. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/complex.h +1 -568
  228. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/floating_point_utils.h +1 -33
  229. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/irange.h +1 -1
  230. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/llvmMathExtras.h +866 -0
  231. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/safe_numerics.h +97 -0
  232. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/macros/Export.h +66 -0
  233. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/macros/Macros.h +553 -0
  234. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/BFloat16.h +477 -0
  235. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/Half.h +781 -0
  236. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/TypeSafeSignMath.h +141 -0
  237. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/bit_cast.h +49 -0
  238. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/complex.h +593 -0
  239. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/floating_point_utils.h +38 -0
  240. package/third-party/include/executorch/runtime/core/tensor_layout.h +1 -1
  241. package/third-party/include/executorch/runtime/executor/merged_data_map.h +142 -0
  242. package/third-party/include/executorch/runtime/executor/method.h +21 -8
  243. package/third-party/include/executorch/runtime/executor/method_meta.h +20 -2
  244. package/third-party/include/executorch/runtime/executor/program.h +0 -10
  245. package/third-party/include/executorch/runtime/kernel/operator_registry.h +1 -1
  246. package/third-party/include/executorch/runtime/platform/compiler.h +2 -0
  247. package/third-party/include/executorch/schema/extended_header.h +10 -1
  248. package/third-party/include/torch/headeronly/macros/Export.h +66 -0
  249. package/third-party/include/torch/headeronly/macros/Macros.h +553 -0
  250. package/third-party/include/torch/headeronly/util/BFloat16.h +477 -0
  251. package/third-party/include/torch/headeronly/util/Half.h +781 -0
  252. package/third-party/include/torch/headeronly/util/TypeSafeSignMath.h +141 -0
  253. package/third-party/include/torch/headeronly/util/bit_cast.h +49 -0
  254. package/third-party/include/torch/headeronly/util/complex.h +593 -0
  255. package/third-party/include/torch/headeronly/util/floating_point_utils.h +38 -0
  256. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/ExecutorchLib +0 -0
  257. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/Info.plist +0 -0
  258. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/ExecutorchLib +0 -0
  259. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/Info.plist +0 -0
  260. package/common/rnexecutorch/tests/run_all_tests.sh +0 -14
  261. package/common/rnexecutorch/tests/run_test.sh +0 -18
  262. package/ios/RnExecutorch/utils/Conversions.h +0 -14
  263. package/ios/RnExecutorch/utils/ETError.h +0 -26
  264. package/ios/RnExecutorch/utils/ImageProcessor.h +0 -15
  265. package/ios/RnExecutorch/utils/ImageProcessor.mm +0 -147
  266. package/ios/RnExecutorch/utils/Numerical.h +0 -3
  267. package/ios/RnExecutorch/utils/Numerical.mm +0 -18
  268. package/ios/RnExecutorch/utils/ScalarType.h +0 -14
  269. package/ios/RnExecutorch/utils/ScalarType.mm +0 -21
  270. package/lib/module/hooks/useNonStaticModule.js.map +0 -1
  271. package/lib/typescript/hooks/useNonStaticModule.d.ts +0 -21
  272. package/lib/typescript/hooks/useNonStaticModule.d.ts.map +0 -1
  273. package/src/hooks/useNonStaticModule.ts +0 -74
  274. package/third-party/include/executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h +0 -181
  275. package/third-party/include/executorch/extension/kernel_util/meta_programming.h +0 -108
  276. package/third-party/include/executorch/extension/kernel_util/type_list.h +0 -137
  277. package/third-party/include/executorch/extension/threadpool/threadpool_guard.h +0 -35
@@ -0,0 +1,240 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the BSD-style license found in the
6
+ * LICENSE file in the root directory of this source tree.
7
+ */
8
+
9
+ #pragma once
10
+
11
+ #include <executorch/extension/module/module.h>
12
+ #include <executorch/extension/tensor/tensor.h>
13
+
14
+ namespace executorch {
15
+ namespace extension {
16
+ namespace llm {
17
+
18
+ /**
19
+ * @brief Base class for managing input/output operations for LLM inference.
20
+ *
21
+ * IOManager provides an interface for handling the input preparation and
22
+ * output processing for both prefill and decode phases of LLM inference.
23
+ * Derived classes must implement the virtual methods to provide specific IO
24
+ * management functionality.
25
+ */
26
+ class IOManager {
27
+ public:
28
+ /**
29
+ * @brief Construct an IOManager bound to a Module.
30
+ *
31
+ * @param module The Module used for querying method metadata and execution.
32
+ */
33
+ explicit IOManager(ET_MODULE_NAMESPACE::Module &module) : module_(module) {}
34
+
35
+ /**
36
+ * @brief Virtual destructor to allow proper cleanup in derived classes.
37
+ */
38
+ virtual ~IOManager() = default;
39
+
40
+ /**
41
+ * @brief Load the IO manager with method metadata for prefill and
42
+ * decode operations.
43
+ *
44
+ * @param prefill_method The prefill method to initialize with.
45
+ * @param decode_method The decode method to initialize with.
46
+ */
47
+ ET_NODISCARD virtual runtime::Error load(const std::string &prefill_method,
48
+ const std::string &decode_method) {
49
+ (void)prefill_method;
50
+ (void)decode_method;
51
+ return runtime::Error::Ok;
52
+ }
53
+
54
+ /**
55
+ * @brief Load the IO manager using the default method names.
56
+ *
57
+ * Uses "forward" for both prefill and decode.
58
+ *
59
+ * @return Error code.
60
+ */
61
+ ET_NODISCARD runtime::Error load() { return load("forward", "forward"); }
62
+
63
+ /**
64
+ * @brief Reset the IO manager state.
65
+ *
66
+ * @param prefill_method The prefill method to reset with.
67
+ * @param decode_method The decode method to reset with.
68
+ */
69
+ ET_NODISCARD virtual runtime::Error reset(const std::string &prefill_method,
70
+ const std::string &decode_method) {
71
+ (void)prefill_method;
72
+ (void)decode_method;
73
+ return runtime::Error::Ok;
74
+ }
75
+
76
+ /**
77
+ * @brief Reset the IO manager state using the default method names.
78
+ *
79
+ * Uses "forward" for both prefill and decode.
80
+ *
81
+ * @return Error code.
82
+ */
83
+ ET_NODISCARD runtime::Error reset() { return reset("forward", "forward"); }
84
+
85
+ /**
86
+ * @brief Prepare inputs for the prefill phase of LLM inference.
87
+ *
88
+ * @param input The input tensor containing token IDs.
89
+ * @param start_pos The tensor containing the starting position of the current
90
+ * input within the context.
91
+ * @param prefill_method The prefill method to prepare inputs for.
92
+ * @return Result wrapping the std::vector<runtime::EValue> of prepared inputs
93
+ * for the prefill method.
94
+ */
95
+ virtual runtime::Result<std::vector<runtime::EValue>>
96
+ prepare_prefill(const TensorPtr &input, const TensorPtr &start_pos,
97
+ const std::string &prefill_method) {
98
+ auto method_meta = module_.method_meta(prefill_method);
99
+ if (!method_meta.ok()) {
100
+ return method_meta.error();
101
+ }
102
+ if (method_meta->num_inputs() != 2) {
103
+ ET_LOG(Error,
104
+ "Expected 2 inputs for prefill method, got %zu. Likely the model "
105
+ "takes the caches or mask as an argument which this IOManager "
106
+ "does not support.",
107
+ method_meta->num_inputs());
108
+ return runtime::Error::InvalidState;
109
+ }
110
+ // Cpu IO Manager supports dynamic shapes for prefill, so no work to be done
111
+ // here.
112
+ return std::vector<runtime::EValue>{input, start_pos};
113
+ }
114
+
115
+ /**
116
+ * @brief Prepare inputs for the prefill phase using the default method name.
117
+ *
118
+ * Uses "forward" as the prefill method.
119
+ *
120
+ * @param input The input tensor containing token IDs.
121
+ * @param start_pos The tensor containing the starting position.
122
+ * @return Vector of prepared inputs for the prefill method.
123
+ */
124
+ runtime::Result<std::vector<runtime::EValue>>
125
+ prepare_prefill(const TensorPtr &input, const TensorPtr &start_pos) {
126
+ return prepare_prefill(input, start_pos, "forward");
127
+ }
128
+
129
+ /**
130
+ * @brief Prepare inputs for the decode phase of LLM inference.
131
+ *
132
+ * @param input The input tensor containing token IDs.
133
+ * @param start_pos The tensor containing the starting position of the current
134
+ * input within the context.
135
+ * @param decode_method The decode method to prepare inputs for.
136
+ * @return Result wrapping the std::vector<runtime::EValue> of prepared inputs
137
+ * for the decode method.
138
+ */
139
+ virtual runtime::Result<std::vector<runtime::EValue>>
140
+ prepare_decode(const TensorPtr &input, const TensorPtr &start_pos,
141
+ const std::string &decode_method) {
142
+ auto method_meta = module_.method_meta(decode_method);
143
+ if (!method_meta.ok()) {
144
+ return method_meta.error();
145
+ }
146
+ if (method_meta->num_inputs() != 2) {
147
+ ET_LOG(Error,
148
+ "Expected 2 inputs for decode method, got %zu. Likely the model "
149
+ "takes the caches or mask as an argument which this IOManager "
150
+ "does not support.",
151
+ method_meta->num_inputs());
152
+ return runtime::Error::InvalidState;
153
+ }
154
+ // This IOManager passes the decode inputs through as-is, so no work to be done
155
+ // here.
156
+ return std::vector<runtime::EValue>{input, start_pos};
157
+ }
158
+
159
+ /**
160
+ * @brief Prepare inputs for the decode phase using the default method name.
161
+ *
162
+ * Uses "forward" as the decode method.
163
+ *
164
+ * @param input The input tensor containing token IDs.
165
+ * @param start_pos The tensor containing the starting position.
166
+ * @return Vector of prepared inputs for the decode method.
167
+ */
168
+ runtime::Result<std::vector<runtime::EValue>>
169
+ prepare_decode(const TensorPtr &input, const TensorPtr &start_pos) {
170
+ return prepare_decode(input, start_pos, "forward");
171
+ }
172
+
173
+ /**
174
+ * @brief Process and update internal state with outputs from the prefill
175
+ * phase.
176
+ *
177
+ * @param model_outputs Vector of outputs from the prefill method execution.
178
+ * @param prefill_method The prefill method to update with outputs.
179
+ */
180
+ ET_NODISCARD virtual runtime::Error
181
+ update_prefill(const std::vector<runtime::EValue> &model_outputs,
182
+ const std::string &prefill_method) {
183
+ (void)model_outputs;
184
+ (void)prefill_method;
185
+ // No post inference work to do.
186
+ return runtime::Error::Ok;
187
+ }
188
+
189
+ /**
190
+ * @brief Process outputs from the prefill phase using the default method.
191
+ *
192
+ * Uses "forward" as the prefill method.
193
+ *
194
+ * @param model_outputs Vector of outputs from the prefill execution.
195
+ * @return Error code.
196
+ */
197
+ ET_NODISCARD runtime::Error
198
+ update_prefill(const std::vector<runtime::EValue> &model_outputs) {
199
+ return update_prefill(model_outputs, "forward");
200
+ }
201
+
202
+ /**
203
+ * @brief Process and update internal state with outputs from the decode
204
+ * phase.
205
+ *
206
+ * @param model_outputs Vector of outputs from the decode method execution.
207
+ * @param decode_method The decode method to update with outputs.
208
+ */
209
+ ET_NODISCARD virtual runtime::Error
210
+ update_decode(const std::vector<runtime::EValue> &model_outputs,
211
+ const std::string &decode_method) {
212
+ (void)model_outputs;
213
+ (void)decode_method;
214
+ // No post inference work to do.
215
+ return runtime::Error::Ok;
216
+ }
217
+
218
+ /**
219
+ * @brief Process outputs from the decode phase using the default method.
220
+ *
221
+ * Uses "forward" as the decode method.
222
+ *
223
+ * @param model_outputs Vector of outputs from the decode execution.
224
+ * @return Error code.
225
+ */
226
+ ET_NODISCARD runtime::Error
227
+ update_decode(const std::vector<runtime::EValue> &model_outputs) {
228
+ return update_decode(model_outputs, "forward");
229
+ }
230
+
231
+ private:
232
+ /**
233
+ * @brief Reference to the Module used for method metadata and execution.
234
+ */
235
+ ET_MODULE_NAMESPACE::Module &module_;
236
+ };
237
+
238
+ } // namespace llm
239
+ } // namespace extension
240
+ } // namespace executorch
@@ -6,41 +6,112 @@
6
6
  * LICENSE file in the root directory of this source tree.
7
7
  */
8
8
 
9
- // An interface for LLM runners. Developers can create their own runner that
10
- // implements their own load and generation logic to run the model.
9
+ // Interface for text generation runners.
11
10
 
12
11
  #pragma once
13
12
 
13
+ #include "stats.h"
14
+
15
+ #include <cstdint>
14
16
  #include <functional>
17
+ #include <memory>
15
18
  #include <string>
16
19
 
17
- #include "stats.h"
18
- #include <executorch/extension/module/module.h>
20
+ #include <executorch/runtime/core/error.h>
19
21
 
20
22
  namespace executorch {
21
23
  namespace extension {
22
24
  namespace llm {
23
25
 
24
- class ET_EXPERIMENTAL IRunner {
26
+ // Configuration struct for generation parameters
27
+ struct GenerationConfig {
28
+ // Whether to echo the input prompt in the output
29
+ bool echo = false;
30
+
31
+ // Whether this is a warmup run (affects perf benchmarking)
32
+ bool warming = false;
33
+
34
+ // Maximum number of new tokens to generate
35
+ // If the max_context_len metadata that's serialized in the .pte file exists,
36
+ // then the number of prompt tokens + max_new_tokens won't exceed
37
+ // max_context_len. If this field is -1, it means we will rely on
38
+ // max_context_len metadata and seq_len value.
39
+ int32_t max_new_tokens = -1;
40
+
41
+ // Maximum number of total tokens
42
+ // If the .pte file contains the max_context_len metadata, it will override
43
+ // this value if it's smaller. If this field is -1, we will use the
44
+ // max_context_len metadata directly.
45
+ int32_t max_seq_len = -1;
46
+
47
+ // Maximum context length
48
+ // If the .pte file contains the max_context_len metadata, it will override
49
+ // this value if it's smaller. If this field is -1, we will use the
50
+ // max_context_len metadata directly.
51
+ int32_t max_context_length = -1;
52
+
53
+ // Temperature for sampling (higher = more random)
54
+ float temperature = -1.F;
55
+
56
+ // Top-p (nucleus sampling) – limits next token selection to the smallest set
57
+ // whose cumulative probability exceeds topp. Range: 0.0 to 1.0. Lower values
58
+ // = more deterministic, higher = more diverse generations.
59
+ float topp = -1.F;
60
+
61
+ // Enable dynamic input shapes (if implemented) or not
62
+ // Impacts the prefill phase and causes TextPrefiller to pass all the tokens
63
+ // at once if set to true.
64
+ bool enable_dynamic_shape = true;
65
+
66
+ // Use KV_CACHE implementation (if implemented) or not
67
+ bool enable_kv_cache = true;
68
+ };
69
+
70
+ // Base interface for LLM runners
71
+ class IRunner {
25
72
  public:
26
73
  virtual ~IRunner() = default;
27
74
 
28
- // Checks if the model is loaded.
75
+ /**
76
+ * Check if the runner is loaded and ready for inference.
77
+ *
78
+ * @return true if the runner is loaded, false otherwise
79
+ */
29
80
  virtual bool is_loaded() const = 0;
30
81
 
31
- // Load the model and tokenizer.
32
- virtual ::executorch::runtime::Error load() = 0;
82
+ /**
83
+ * Load the model and prepare for inference.
84
+ *
85
+ * @return Error::Ok if successful, an error otherwise
86
+ */
87
+ virtual runtime::Error load() = 0;
33
88
 
34
- // Generate the output tokens.
35
- virtual ::executorch::runtime::Error
36
- generate(const std::string &prompt,
37
- std::function<void(const std::string &)> token_callback = {},
38
- std::function<void(const ::executorch::extension::llm::Stats &)>
39
- stats_callback = {},
40
- bool echo = true, bool warming = false) = 0;
89
+ /**
90
+ * Generate text based on the provided prompt and generation config.
91
+ *
92
+ * @param prompt The input prompt to generate from
93
+ * @param config Generation configuration parameters
94
+ * @param token_callback Callback function called for each generated token
95
+ * @param stats_callback Callback function for generation statistics
96
+ * @return Error::Ok if successful, an error otherwise
97
+ */
98
+ virtual runtime::Error
99
+ generate(const std::string &prompt, const GenerationConfig &config,
100
+ std::function<void(const std::string &)> token_callback,
101
+ std::function<void(const Stats &)> stats_callback) = 0;
41
102
 
42
- // Stop the generation.
103
+ /**
104
+ * Stop the generation process.
105
+ */
43
106
  virtual void stop() = 0;
107
+
108
+ /**
109
+ * Force remove prefilled tokens and reset KV cache start position
110
+ *
111
+ * This method removes the prefilled tokens from the KV cache and resets the
112
+ * start position to 0.
113
+ */
114
+ virtual void reset() = 0;
44
115
  };
45
116
 
46
117
  } // namespace llm
@@ -0,0 +1,23 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the BSD-style license found in the
6
+ * LICENSE file in the root directory of this source tree.
7
+ */
8
+
9
+ /**
10
+ * @file
11
+ *
12
+ * Common includes used by all kernel implementations.
13
+ */
14
+
15
+ #pragma once
16
+
17
+ // This list should be very conservative since most kernel .cpp files will
18
+ // include these and depend on their transitive deps. Only add a header if 99%
19
+ // of kernels would have included it anyway.
20
+ #include <executorch/runtime/core/exec_aten/exec_aten.h> // IWYU pragma: export
21
+ #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h> // IWYU pragma: export
22
+ #include <executorch/runtime/core/exec_aten/util/tensor_util.h> // IWYU pragma: export
23
+ #include <executorch/runtime/kernel/kernel_runtime_context.h> // IWYU pragma: export