react-native-executorch 0.5.15 → 0.6.0-nightly-897eae9-20251213

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (277)
  1. package/README.md +42 -36
  2. package/android/CMakeLists.txt +13 -25
  3. package/android/build.gradle +2 -3
  4. package/android/libs/classes.jar +0 -0
  5. package/android/src/main/cpp/CMakeLists.txt +2 -1
  6. package/common/rnexecutorch/RnExecutorchInstaller.cpp +18 -0
  7. package/common/rnexecutorch/TokenizerModule.cpp +3 -3
  8. package/common/rnexecutorch/data_processing/Numerical.cpp +31 -23
  9. package/common/rnexecutorch/data_processing/Numerical.h +6 -1
  10. package/common/rnexecutorch/data_processing/dsp.cpp +0 -46
  11. package/common/rnexecutorch/host_objects/JsiConversions.h +16 -0
  12. package/common/rnexecutorch/host_objects/ModelHostObject.h +26 -11
  13. package/common/rnexecutorch/jsi/OwningArrayBuffer.h +19 -2
  14. package/common/rnexecutorch/metaprogramming/TypeConcepts.h +0 -20
  15. package/common/rnexecutorch/models/BaseModel.cpp +12 -11
  16. package/common/rnexecutorch/models/BaseModel.h +18 -10
  17. package/common/rnexecutorch/models/embeddings/BaseEmbeddings.cpp +3 -11
  18. package/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp +0 -1
  19. package/common/rnexecutorch/models/image_segmentation/ImageSegmentation.cpp +6 -12
  20. package/common/rnexecutorch/models/llm/LLM.cpp +25 -8
  21. package/common/rnexecutorch/models/llm/LLM.h +4 -4
  22. package/common/rnexecutorch/models/ocr/CTCLabelConverter.h +1 -1
  23. package/common/rnexecutorch/models/ocr/utils/RecognitionHandlerUtils.cpp +7 -4
  24. package/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp +8 -13
  25. package/common/rnexecutorch/models/speech_to_text/SpeechToText.h +1 -3
  26. package/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp +12 -19
  27. package/common/rnexecutorch/models/speech_to_text/asr/ASR.h +4 -5
  28. package/common/rnexecutorch/models/text_to_image/Constants.h +9 -0
  29. package/common/rnexecutorch/models/text_to_image/Decoder.cpp +32 -0
  30. package/common/rnexecutorch/models/text_to_image/Decoder.h +24 -0
  31. package/common/rnexecutorch/models/text_to_image/Encoder.cpp +44 -0
  32. package/common/rnexecutorch/models/text_to_image/Encoder.h +32 -0
  33. package/common/rnexecutorch/models/text_to_image/Scheduler.cpp +152 -0
  34. package/common/rnexecutorch/models/text_to_image/Scheduler.h +41 -0
  35. package/common/rnexecutorch/models/text_to_image/TextToImage.cpp +141 -0
  36. package/common/rnexecutorch/models/text_to_image/TextToImage.h +64 -0
  37. package/common/rnexecutorch/models/text_to_image/UNet.cpp +38 -0
  38. package/common/rnexecutorch/models/text_to_image/UNet.h +28 -0
  39. package/common/rnexecutorch/models/voice_activity_detection/Constants.h +27 -0
  40. package/common/rnexecutorch/models/voice_activity_detection/Types.h +12 -0
  41. package/common/rnexecutorch/models/voice_activity_detection/Utils.cpp +15 -0
  42. package/common/rnexecutorch/models/voice_activity_detection/Utils.h +13 -0
  43. package/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.cpp +160 -0
  44. package/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.h +36 -0
  45. package/common/rnexecutorch/tests/CMakeLists.txt +30 -0
  46. package/common/rnexecutorch/tests/NumericalTest.cpp +110 -0
  47. package/common/rnexecutorch/tests/README.md +30 -13
  48. package/common/rnexecutorch/threads/GlobalThreadPool.h +4 -0
  49. package/common/runner/arange_util.cpp +44 -0
  50. package/common/runner/arange_util.h +37 -0
  51. package/common/runner/constants.h +28 -0
  52. package/common/runner/io_manager.h +240 -0
  53. package/common/runner/irunner.h +87 -16
  54. package/common/runner/kernel_includes.h +23 -0
  55. package/common/runner/runner.cpp +151 -66
  56. package/common/runner/runner.h +39 -22
  57. package/common/runner/sampler.cpp +8 -1
  58. package/common/runner/sampler.h +4 -2
  59. package/common/runner/stats.h +1 -4
  60. package/common/runner/text_decoder_runner.cpp +26 -12
  61. package/common/runner/text_decoder_runner.h +52 -31
  62. package/common/runner/text_prefiller.cpp +46 -12
  63. package/common/runner/text_prefiller.h +38 -4
  64. package/common/runner/text_token_generator.h +51 -26
  65. package/common/runner/util.h +53 -8
  66. package/ios/RnExecutorch.xcodeproj/project.pbxproj +0 -23
  67. package/lib/module/Error.js +1 -0
  68. package/lib/module/Error.js.map +1 -1
  69. package/lib/module/constants/directories.js +1 -1
  70. package/lib/module/constants/directories.js.map +1 -1
  71. package/lib/module/constants/modelUrls.js +32 -1
  72. package/lib/module/constants/modelUrls.js.map +1 -1
  73. package/lib/module/constants/ocr/models.js +7 -7
  74. package/lib/module/constants/ocr/models.js.map +1 -1
  75. package/lib/module/constants/ocr/symbols.js +3 -2
  76. package/lib/module/constants/ocr/symbols.js.map +1 -1
  77. package/lib/module/controllers/LLMController.js +10 -1
  78. package/lib/module/controllers/LLMController.js.map +1 -1
  79. package/lib/module/controllers/OCRController.js +3 -3
  80. package/lib/module/controllers/OCRController.js.map +1 -1
  81. package/lib/module/controllers/VerticalOCRController.js +2 -2
  82. package/lib/module/controllers/VerticalOCRController.js.map +1 -1
  83. package/lib/module/hooks/computer_vision/useOCR.js +3 -3
  84. package/lib/module/hooks/computer_vision/useOCR.js.map +1 -1
  85. package/lib/module/hooks/{useNonStaticModule.js → computer_vision/useTextToImage.js} +21 -16
  86. package/lib/module/hooks/computer_vision/useTextToImage.js.map +1 -0
  87. package/lib/module/hooks/computer_vision/useVerticalOCR.js +3 -3
  88. package/lib/module/hooks/computer_vision/useVerticalOCR.js.map +1 -1
  89. package/lib/module/hooks/natural_language_processing/useLLM.js +3 -3
  90. package/lib/module/hooks/natural_language_processing/useLLM.js.map +1 -1
  91. package/lib/module/hooks/natural_language_processing/useTokenizer.js +5 -5
  92. package/lib/module/hooks/natural_language_processing/useTokenizer.js.map +1 -1
  93. package/lib/module/hooks/natural_language_processing/useVAD.js +13 -0
  94. package/lib/module/hooks/natural_language_processing/useVAD.js.map +1 -0
  95. package/lib/module/index.js +7 -2
  96. package/lib/module/index.js.map +1 -1
  97. package/lib/module/modules/computer_vision/OCRModule.js +2 -2
  98. package/lib/module/modules/computer_vision/OCRModule.js.map +1 -1
  99. package/lib/module/modules/computer_vision/TextToImageModule.js +48 -0
  100. package/lib/module/modules/computer_vision/TextToImageModule.js.map +1 -0
  101. package/lib/module/modules/computer_vision/VerticalOCRModule.js +2 -2
  102. package/lib/module/modules/computer_vision/VerticalOCRModule.js.map +1 -1
  103. package/lib/module/modules/natural_language_processing/SpeechToTextModule.js +7 -4
  104. package/lib/module/modules/natural_language_processing/SpeechToTextModule.js.map +1 -1
  105. package/lib/module/modules/natural_language_processing/VADModule.js +19 -0
  106. package/lib/module/modules/natural_language_processing/VADModule.js.map +1 -0
  107. package/lib/module/types/llm.js.map +1 -1
  108. package/lib/module/types/vad.js +2 -0
  109. package/lib/module/types/vad.js.map +1 -0
  110. package/lib/module/utils/ResourceFetcher.js +2 -1
  111. package/lib/module/utils/ResourceFetcher.js.map +1 -1
  112. package/lib/module/utils/ResourceFetcherUtils.js +6 -6
  113. package/lib/module/utils/ResourceFetcherUtils.js.map +1 -1
  114. package/lib/typescript/Error.d.ts +1 -0
  115. package/lib/typescript/Error.d.ts.map +1 -1
  116. package/lib/typescript/constants/modelUrls.d.ts +23 -0
  117. package/lib/typescript/constants/modelUrls.d.ts.map +1 -1
  118. package/lib/typescript/constants/ocr/symbols.d.ts +1 -1
  119. package/lib/typescript/constants/ocr/symbols.d.ts.map +1 -1
  120. package/lib/typescript/controllers/LLMController.d.ts.map +1 -1
  121. package/lib/typescript/controllers/OCRController.d.ts +1 -1
  122. package/lib/typescript/controllers/OCRController.d.ts.map +1 -1
  123. package/lib/typescript/controllers/VerticalOCRController.d.ts +1 -1
  124. package/lib/typescript/controllers/VerticalOCRController.d.ts.map +1 -1
  125. package/lib/typescript/hooks/computer_vision/useOCR.d.ts +1 -1
  126. package/lib/typescript/hooks/computer_vision/useOCR.d.ts.map +1 -1
  127. package/lib/typescript/hooks/computer_vision/useTextToImage.d.ts +22 -0
  128. package/lib/typescript/hooks/computer_vision/useTextToImage.d.ts.map +1 -0
  129. package/lib/typescript/hooks/computer_vision/useVerticalOCR.d.ts +1 -1
  130. package/lib/typescript/hooks/computer_vision/useVerticalOCR.d.ts.map +1 -1
  131. package/lib/typescript/hooks/natural_language_processing/useLLM.d.ts.map +1 -1
  132. package/lib/typescript/hooks/natural_language_processing/useSpeechToText.d.ts +2 -2
  133. package/lib/typescript/hooks/natural_language_processing/useVAD.d.ts +16 -0
  134. package/lib/typescript/hooks/natural_language_processing/useVAD.d.ts.map +1 -0
  135. package/lib/typescript/index.d.ts +8 -1
  136. package/lib/typescript/index.d.ts.map +1 -1
  137. package/lib/typescript/modules/computer_vision/OCRModule.d.ts +1 -1
  138. package/lib/typescript/modules/computer_vision/OCRModule.d.ts.map +1 -1
  139. package/lib/typescript/modules/computer_vision/TextToImageModule.d.ts +16 -0
  140. package/lib/typescript/modules/computer_vision/TextToImageModule.d.ts.map +1 -0
  141. package/lib/typescript/modules/computer_vision/VerticalOCRModule.d.ts +1 -1
  142. package/lib/typescript/modules/computer_vision/VerticalOCRModule.d.ts.map +1 -1
  143. package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts +3 -2
  144. package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts.map +1 -1
  145. package/lib/typescript/modules/natural_language_processing/VADModule.d.ts +10 -0
  146. package/lib/typescript/modules/natural_language_processing/VADModule.d.ts.map +1 -0
  147. package/lib/typescript/types/llm.d.ts +2 -0
  148. package/lib/typescript/types/llm.d.ts.map +1 -1
  149. package/lib/typescript/types/vad.d.ts +5 -0
  150. package/lib/typescript/types/vad.d.ts.map +1 -0
  151. package/lib/typescript/utils/ResourceFetcher.d.ts +29 -0
  152. package/lib/typescript/utils/ResourceFetcher.d.ts.map +1 -1
  153. package/lib/typescript/utils/ResourceFetcherUtils.d.ts +2 -2
  154. package/lib/typescript/utils/ResourceFetcherUtils.d.ts.map +1 -1
  155. package/package.json +11 -8
  156. package/react-native-executorch.podspec +9 -9
  157. package/src/Error.ts +1 -0
  158. package/src/constants/directories.ts +1 -1
  159. package/src/constants/modelUrls.ts +36 -1
  160. package/src/constants/ocr/models.ts +7 -7
  161. package/src/constants/ocr/symbols.ts +3 -2
  162. package/src/controllers/LLMController.ts +12 -1
  163. package/src/controllers/OCRController.ts +3 -3
  164. package/src/controllers/VerticalOCRController.ts +2 -2
  165. package/src/hooks/computer_vision/useOCR.ts +4 -5
  166. package/src/hooks/computer_vision/useTextToImage.ts +92 -0
  167. package/src/hooks/computer_vision/useVerticalOCR.ts +4 -5
  168. package/src/hooks/natural_language_processing/useLLM.ts +3 -4
  169. package/src/hooks/natural_language_processing/useTokenizer.ts +5 -5
  170. package/src/hooks/natural_language_processing/useVAD.ts +15 -0
  171. package/src/index.ts +20 -1
  172. package/src/modules/computer_vision/OCRModule.ts +2 -2
  173. package/src/modules/computer_vision/TextToImageModule.ts +93 -0
  174. package/src/modules/computer_vision/VerticalOCRModule.ts +2 -2
  175. package/src/modules/natural_language_processing/SpeechToTextModule.ts +8 -4
  176. package/src/modules/natural_language_processing/VADModule.ts +27 -0
  177. package/src/types/llm.ts +2 -0
  178. package/src/types/vad.ts +4 -0
  179. package/src/utils/ResourceFetcher.ts +2 -1
  180. package/src/utils/ResourceFetcherUtils.ts +8 -8
  181. package/third-party/android/libs/cpuinfo/arm64-v8a/libcpuinfo.so +0 -0
  182. package/third-party/android/libs/executorch/arm64-v8a/libexecutorch.so +0 -0
  183. package/third-party/android/libs/executorch/x86_64/libexecutorch.so +0 -0
  184. package/third-party/android/libs/pthreadpool/arm64-v8a/libpthreadpool.so +0 -0
  185. package/third-party/include/c10/macros/Export.h +0 -78
  186. package/third-party/include/c10/macros/Macros.h +1 -520
  187. package/third-party/include/c10/util/BFloat16-inl.h +1 -339
  188. package/third-party/include/c10/util/BFloat16.h +1 -122
  189. package/third-party/include/c10/util/Half-inl.h +1 -347
  190. package/third-party/include/c10/util/Half.h +6 -419
  191. package/third-party/include/c10/util/TypeSafeSignMath.h +1 -133
  192. package/third-party/include/c10/util/bit_cast.h +1 -43
  193. package/third-party/include/c10/util/complex.h +1 -568
  194. package/third-party/include/c10/util/floating_point_utils.h +1 -33
  195. package/third-party/include/c10/util/irange.h +1 -1
  196. package/third-party/include/c10/util/llvmMathExtras.h +866 -0
  197. package/third-party/include/c10/util/safe_numerics.h +97 -0
  198. package/third-party/include/executorch/ExecuTorchError.h +6 -7
  199. package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLM.h +12 -0
  200. package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMConfig.h +56 -0
  201. package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMError.h +16 -0
  202. package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMMultimodalRunner.h +227 -0
  203. package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMTextRunner.h +97 -0
  204. package/third-party/include/executorch/ExecuTorchLLM/module.modulemap +4 -0
  205. package/third-party/include/executorch/ExecuTorchLog.h +1 -0
  206. package/third-party/include/executorch/ExecuTorchModule.h +177 -4
  207. package/third-party/include/executorch/ExecuTorchTensor.h +3 -4
  208. package/third-party/include/executorch/ExecuTorchValue.h +1 -7
  209. package/third-party/include/executorch/extension/module/module.h +139 -8
  210. package/third-party/include/executorch/extension/tensor/tensor.h +1 -0
  211. package/third-party/include/executorch/extension/tensor/tensor_ptr.h +88 -26
  212. package/third-party/include/executorch/extension/threadpool/threadpool.h +4 -1
  213. package/third-party/include/executorch/runtime/backend/backend_init_context.h +6 -0
  214. package/third-party/include/executorch/runtime/backend/interface.h +1 -1
  215. package/third-party/include/executorch/runtime/core/error.h +76 -49
  216. package/third-party/include/executorch/runtime/core/exec_aten/util/scalar_type_util.h +18 -4
  217. package/third-party/include/executorch/runtime/core/memory_allocator.h +12 -2
  218. package/third-party/include/executorch/runtime/core/named_data_map.h +1 -11
  219. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/macros/Export.h +0 -78
  220. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/macros/Macros.h +1 -520
  221. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/BFloat16-inl.h +1 -339
  222. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/BFloat16.h +1 -122
  223. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/Half-inl.h +1 -347
  224. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/Half.h +6 -419
  225. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/TypeSafeSignMath.h +1 -133
  226. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/bit_cast.h +1 -43
  227. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/complex.h +1 -568
  228. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/floating_point_utils.h +1 -33
  229. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/irange.h +1 -1
  230. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/llvmMathExtras.h +866 -0
  231. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/safe_numerics.h +97 -0
  232. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/macros/Export.h +66 -0
  233. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/macros/Macros.h +553 -0
  234. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/BFloat16.h +477 -0
  235. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/Half.h +781 -0
  236. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/TypeSafeSignMath.h +141 -0
  237. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/bit_cast.h +49 -0
  238. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/complex.h +593 -0
  239. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/floating_point_utils.h +38 -0
  240. package/third-party/include/executorch/runtime/core/tensor_layout.h +1 -1
  241. package/third-party/include/executorch/runtime/executor/merged_data_map.h +142 -0
  242. package/third-party/include/executorch/runtime/executor/method.h +21 -8
  243. package/third-party/include/executorch/runtime/executor/method_meta.h +20 -2
  244. package/third-party/include/executorch/runtime/executor/program.h +0 -10
  245. package/third-party/include/executorch/runtime/kernel/operator_registry.h +1 -1
  246. package/third-party/include/executorch/runtime/platform/compiler.h +2 -0
  247. package/third-party/include/executorch/schema/extended_header.h +10 -1
  248. package/third-party/include/torch/headeronly/macros/Export.h +66 -0
  249. package/third-party/include/torch/headeronly/macros/Macros.h +553 -0
  250. package/third-party/include/torch/headeronly/util/BFloat16.h +477 -0
  251. package/third-party/include/torch/headeronly/util/Half.h +781 -0
  252. package/third-party/include/torch/headeronly/util/TypeSafeSignMath.h +141 -0
  253. package/third-party/include/torch/headeronly/util/bit_cast.h +49 -0
  254. package/third-party/include/torch/headeronly/util/complex.h +593 -0
  255. package/third-party/include/torch/headeronly/util/floating_point_utils.h +38 -0
  256. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/ExecutorchLib +0 -0
  257. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/Info.plist +0 -0
  258. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/ExecutorchLib +0 -0
  259. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/Info.plist +0 -0
  260. package/common/rnexecutorch/tests/run_all_tests.sh +0 -14
  261. package/common/rnexecutorch/tests/run_test.sh +0 -18
  262. package/ios/RnExecutorch/utils/Conversions.h +0 -14
  263. package/ios/RnExecutorch/utils/ETError.h +0 -26
  264. package/ios/RnExecutorch/utils/ImageProcessor.h +0 -15
  265. package/ios/RnExecutorch/utils/ImageProcessor.mm +0 -147
  266. package/ios/RnExecutorch/utils/Numerical.h +0 -3
  267. package/ios/RnExecutorch/utils/Numerical.mm +0 -18
  268. package/ios/RnExecutorch/utils/ScalarType.h +0 -14
  269. package/ios/RnExecutorch/utils/ScalarType.mm +0 -21
  270. package/lib/module/hooks/useNonStaticModule.js.map +0 -1
  271. package/lib/typescript/hooks/useNonStaticModule.d.ts +0 -21
  272. package/lib/typescript/hooks/useNonStaticModule.d.ts.map +0 -1
  273. package/src/hooks/useNonStaticModule.ts +0 -74
  274. package/third-party/include/executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h +0 -181
  275. package/third-party/include/executorch/extension/kernel_util/meta_programming.h +0 -108
  276. package/third-party/include/executorch/extension/kernel_util/type_list.h +0 -137
  277. package/third-party/include/executorch/extension/threadpool/threadpool_guard.h +0 -35
@@ -4,6 +4,7 @@
  *
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
+ * @lint-ignore-every CLANGTIDY facebook-hte-Deprecated
  */
 
 // A simple llama2 runner that includes preprocessing and post processing logic.
@@ -21,8 +22,6 @@ using ::executorch::extension::Module;
 using ::executorch::runtime::Error;
 using ::executorch::runtime::Result;
 
-namespace llm = ::executorch::extension::llm;
-
 std::string loadBytesFromFile(const std::string &path) {
   std::ifstream fs(path, std::ios::in | std::ios::binary);
   if (fs.fail()) {
@@ -39,7 +38,6 @@ std::string loadBytesFromFile(const std::string &path) {
 
 namespace {
 static constexpr auto kEnableDynamicShape = "enable_dynamic_shape";
-static constexpr auto kBosId = "get_bos_id";
 static constexpr auto kEosIds = "get_eos_ids";
 static constexpr auto kMaxSeqLen = "get_max_seq_len";
 static constexpr auto kMaxContextLen = "get_max_context_len";
@@ -48,29 +46,16 @@ static constexpr auto kUseKVCache = "use_kv_cache";
 static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
 } // namespace
 
-Runner::Runner(const std::string &model_path, const std::string &tokenizer_path,
-               const float temperature,
-               std::optional<const std::string> data_path)
-    // NOTE: we observed ~2x loading performance increase on iPhone 15
-    // and a ~5% improvement on Galaxy S22 by switching to
-    // FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
-    : temperature_(temperature), tokenizer_path_(tokenizer_path),
+Runner::Runner(Module *module, const std::string &tokenizer_path,
+               const llm::GenerationConfig &config)
+    : config_(config), module_(module), tokenizer_path_(tokenizer_path),
       metadata_({
          {kEnableDynamicShape, false},
          {kMaxSeqLen, 128},
          {kMaxContextLen, 128},
          {kUseKVCache, true},
          {kUseSDPAWithKVCache, false},
-      }) {
-  if (data_path.has_value()) {
-    module_ = std::make_unique<Module>(model_path, data_path.value(),
-                                       Module::LoadMode::File);
-  } else {
-    module_ = std::make_unique<Module>(model_path, Module::LoadMode::File);
-  }
-  ET_LOG(Info, "Creating LLaMa runner: model_path=%s, tokenizer_path=%s",
-         model_path.c_str(), tokenizer_path.c_str());
-}
+      }) {}
 
 bool Runner::is_loaded() const {
   return module_->is_loaded() && tokenizer_ && text_decoder_runner_ &&
@@ -81,9 +66,10 @@ Error Runner::load() {
   if (is_loaded()) {
     return Error::Ok;
   }
+
   ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));
-  // load tokenizer.
 
+  // Load tokenizer.
   auto blob = loadBytesFromFile(tokenizer_path_);
   tokenizer_ = tokenizers::Tokenizer::FromBlobJSON(blob);
 
@@ -92,9 +78,9 @@ Error Runner::load() {
   auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>();
   metadata_[kVocabSize] = tokenizer_->GetVocabSize();
 
+  // Load model metadata
   const auto method_names =
       ET_UNWRAP(module_->method_names(), "Failed reading method names");
-
   for (auto &pair : metadata_) {
     const auto &method_name = pair.first;
     auto &value = pair.second;
@@ -103,11 +89,13 @@ Error Runner::load() {
                   .toScalar()
                   .to<decltype(metadata_)::mapped_type>();
     } else {
-      ET_LOG(Info, "Methond %s not found, using the default value %" PRId64,
+      ET_LOG(Info, "Method %s not found, using the default value %" PRId64,
             method_name.c_str(), value);
     }
     ET_LOG(Info, "Metadata: %s = %" PRId64, method_name.c_str(), value);
   }
+
+  // Load EOS token ids
   if (method_names.count(kEosIds)) {
     eos_ids->clear();
     for (const auto &eos_id : ET_UNWRAP(module_->execute(kEosIds))) {
@@ -116,15 +104,34 @@ Error Runner::load() {
       ET_LOG(Info, "eos_id = %" PRId64, value);
     }
   }
+
+  // Determine missing config values
+  // If user does not directly specify configuration parameters such as
+  // max_seq_len (i.e. leaves them as default values), they are determined by
+  // reading the exported model's methods.
+  if (config_.max_seq_len < 0)
+    config_.max_seq_len = static_cast<int32_t>(metadata_.at(kMaxSeqLen));
+  if (config_.max_context_length < 0)
+    config_.max_context_length =
+        static_cast<int32_t>(metadata_.at(kMaxContextLen));
+  if (config_.max_new_tokens < 0)
+    config_.max_new_tokens =
+        std::min(config_.max_seq_len, config_.max_context_length);
+  if (config_.enable_dynamic_shape)
+    config_.enable_dynamic_shape =
+        static_cast<bool>(metadata_.at(kEnableDynamicShape));
+  if (config_.enable_kv_cache)
+    config_.enable_kv_cache = static_cast<bool>(metadata_.at(kUseKVCache));
+
+  io_manager_ = std::make_unique<llm::IOManager>(*module_);
   text_decoder_runner_ = std::make_unique<llm::TextDecoderRunner>(
-      module_.get(), metadata_.at(kUseKVCache), metadata_.at(kVocabSize),
-      temperature_);
+      module_, io_manager_.get(), config_.temperature, config_.topp);
   text_prefiller_ = std::make_unique<llm::TextPrefiller>(
-      text_decoder_runner_.get(), metadata_.at(kUseKVCache),
-      metadata_.at(kEnableDynamicShape));
+      text_decoder_runner_.get(), config_.enable_kv_cache,
+      config_.enable_dynamic_shape, config_.max_seq_len);
 
   text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
-      tokenizer_.get(), text_decoder_runner_.get(), metadata_.at(kUseKVCache),
+      tokenizer_.get(), text_decoder_runner_.get(), config_.enable_kv_cache,
       std::move(eos_ids), &stats_);
 
   return Error::Ok;
@@ -139,9 +146,9 @@ Error Runner::load() {
 }
 
 Error Runner::generate(const std::string &prompt,
+                       const llm::GenerationConfig &generation_config,
                        std::function<void(const std::string &)> token_callback,
-                       std::function<void(const llm::Stats &)> stats_callback,
-                       bool echo, bool warmup) {
+                       std::function<void(const llm::Stats &)> stats_callback) {
   // Prepare the inputs.
   // Use ones-initialized inputs.
   ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
@@ -151,17 +158,18 @@ Error Runner::generate(const std::string &prompt,
     stats_.model_load_end_ms = llm::time_in_ms();
   }
 
-  if (warmup) {
+  if (generation_config.warming) {
     ET_LOG(Info, "Doing a warmup run...");
   }
 
-  RUNNER_ET_LOG(warmup, "RSS after loading model: %f MiB (0 if unsupported)",
+  RUNNER_ET_LOG(generation_config.warming,
+                "RSS after loading model: %f MiB (0 if unsupported)",
                 llm::get_rss_bytes() / 1024.0 / 1024.0);
 
   // Wrap the token_callback with print function
   std::function<void(const std::string &)> wrapped_callback =
-      [token_callback, warmup](const std::string &piece) {
-        if (!warmup) {
+      [token_callback, &generation_config](const std::string &piece) {
+        if (!generation_config.warming) {
           llm::safe_printf(piece.c_str());
           fflush(stdout);
         }
@@ -175,10 +183,23 @@ Error Runner::generate(const std::string &prompt,
   stats_.inference_start_ms = llm::time_in_ms();
   shouldStop_ = false;
 
-  // Set the sequence length to the max seq length if not provided
-  int32_t seq_len = (seq_len > 0 && seq_len <= metadata_.at(kMaxSeqLen))
-                        ? seq_len
-                        : metadata_.at(kMaxSeqLen);
+  // Override main config fields with given generation config if specified
+  int32_t max_seq_len = generation_config.max_seq_len >= 0
+                            ? generation_config.max_seq_len
+                            : config_.max_seq_len;
+  int32_t max_context_length = generation_config.max_context_length >= 0
+                                   ? generation_config.max_context_length
+                                   : config_.max_context_length;
+  int32_t new_tokens_limit = generation_config.max_new_tokens >= 0
+                                 ? generation_config.max_new_tokens
+                                 : config_.max_new_tokens;
+  float temperature = generation_config.temperature >= 0.F
+                          ? generation_config.temperature
+                          : config_.temperature;
+  float topp =
+      generation_config.topp >= 0.F ? generation_config.topp : config_.topp;
+
+  int64_t context_len_left = static_cast<int64_t>(max_context_length) - pos_;
 
   std::vector<int32_t> prompt_tokens = tokenizer_->Encode(prompt);
   std::vector<uint64_t> prompt_tokens_uint64(prompt_tokens.begin(),
@@ -187,30 +208,38 @@ Error Runner::generate(const std::string &prompt,
   // encode the (string) prompt into tokens sequence
   int num_prompt_tokens = prompt_tokens.size();
 
-  if (num_prompt_tokens < 1) {
-    ET_LOG(Error,
-           "num_prompt_tokens %d < 1, expected at least 1 token to be passed "
-           "to generate()!",
-           num_prompt_tokens);
-    return Error::InvalidArgument;
-  } else if (num_prompt_tokens >= seq_len) {
-    ET_LOG(Error,
-           "num_prompt_tokens %d >= seq_len %d, Sequence length exceeded - "
-           "please increase the seq_len value passed to generate()!",
-           num_prompt_tokens, seq_len);
-    return Error::InvalidArgument;
-  }
+  ET_CHECK_OR_RETURN_ERROR(num_prompt_tokens >= 1, InvalidArgument,
+                           "Expected at least 1 prompt token");
+  ET_CHECK_OR_RETURN_ERROR(num_prompt_tokens < max_seq_len, InvalidArgument,
+                           "num_prompt_tokens %d >= max_context_len %" PRId32
+                           ", Max seq length exceeded - please increase max "
+                           "seq len value in your export script",
+                           num_prompt_tokens, max_seq_len);
+
+  // Determine max_new_tokens using the GenerationConfig's resolve method,
+  // then subtract pos_ for max_new_tokens.
+  int32_t max_new_tokens = resolve_max_new_tokens(
+      num_prompt_tokens, max_seq_len, static_cast<int32_t>(context_len_left),
+      new_tokens_limit);
+
+  ET_LOG(Info,
+         "Max new tokens resolved: %d, given pos_ %" PRId64
+         ", num_prompt_tokens %zu, max_context_len %" PRId64,
+         max_new_tokens, pos_, prompt_tokens.size(),
+         static_cast<int64_t>(max_context_length));
+  ET_CHECK_OR_RETURN_ERROR(max_new_tokens > 0, InvalidArgument,
+                           "Max new tokens %d is less than or equal to 0",
+                           max_new_tokens);
 
   // Prefill first
   // Here feed all tokens to the model and get the next predicted token
   // after the prompt. After that we will enter generate loop.
 
   // print prompts
-  if (echo) {
+  if (generation_config.echo) {
     wrapped_callback(prompt);
   }
-  int64_t pos = 0;
-  auto prefill_res = text_prefiller_->prefill(prompt_tokens_uint64, pos);
+  auto prefill_res = text_prefiller_->prefill(prompt_tokens_uint64, pos_);
   stats_.first_token_ms = llm::time_in_ms();
   stats_.prompt_eval_end_ms = llm::time_in_ms();
   ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
@@ -219,30 +248,36 @@ Error Runner::generate(const std::string &prompt,
219
248
  // print the first token from prefill. No prev_token so use cur_token for it.
220
249
  const std::string cur_decoded =
221
250
  tokenizer_->Decode(std::vector<int32_t>{static_cast<int32_t>(cur_token)});
222
- RUNNER_ET_LOG(warmup, "RSS after prompt prefill: %f MiB (0 if unsupported)",
251
+ RUNNER_ET_LOG(generation_config.warming,
252
+ "RSS after prompt prefill: %f MiB (0 if unsupported)",
223
253
  llm::get_rss_bytes() / 1024.0 / 1024.0);
224
254
 
225
255
  // start the main loop
226
256
  prompt_tokens_uint64.push_back(cur_token);
227
257
  int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate(
228
- prompt_tokens_uint64, num_prompt_tokens, seq_len, wrapped_callback));
258
+ prompt_tokens_uint64, pos_, max_new_tokens - 1, temperature, topp,
259
+ wrapped_callback));
260
+
261
+ pos_ += num_generated_tokens;
229
262
 
230
263
  stats_.inference_end_ms = llm::time_in_ms();
231
- if (!warmup) {
264
+ if (!generation_config.warming) {
232
265
  printf("\n");
233
266
  }
234
267
  RUNNER_ET_LOG(
235
- warmup, "RSS after finishing text generation: %f MiB (0 if unsupported)",
268
+ generation_config.warming,
269
+ "RSS after finishing text generation: %f MiB (0 if unsupported)",
236
270
  llm::get_rss_bytes() / 1024.0 / 1024.0);
237
271
 
238
- if (num_prompt_tokens + num_generated_tokens == seq_len) {
239
- RUNNER_ET_LOG(warmup, "Sequence length (%i tokens) reached!", seq_len);
272
+ if (num_generated_tokens == max_new_tokens) {
273
+ RUNNER_ET_LOG(generation_config.warming, "Max new tokens %i reached!",
274
+ max_new_tokens);
240
275
  }
241
276
 
242
277
  stats_.num_prompt_tokens = num_prompt_tokens;
243
278
  stats_.num_generated_tokens = num_generated_tokens;
244
279
 
245
- if (warmup) {
280
+ if (generation_config.warming) {
246
281
  ET_LOG(Info, "Warmup run finished!");
247
282
  } else {
248
283
  // Do not print report during warmup
@@ -256,12 +291,17 @@ Error Runner::generate(const std::string &prompt,
256
291
  }
257
292
 
258
293
  Error Runner::warmup(const std::string &prompt) {
259
- Error err = generate(prompt,
294
+ // Create a GenerationConfig for warmup
295
+ llm::GenerationConfig config{.echo = false, .warming = true};
296
+
297
+ // Call generate with the warmup config
298
+ Error err = generate(prompt, config,
260
299
  /*token_callback=*/nullptr,
261
- /*stats_callbak=*/nullptr,
262
- /*echo=*/false,
263
- /*warmup=*/true);
264
- stats_.reset();
300
+ /*stats_callbak=*/nullptr);
301
+
302
+ // Reset stats after warmup
303
+ reset();
304
+
265
305
  return err;
266
306
  }
267
307
 
@@ -273,6 +313,11 @@ void Runner::stop() {
273
313
  }
274
314
  }
275
315
 
316
+ void Runner::reset() {
317
+ stats_.reset();
318
+ pos_ = 0;
319
+ }
320
+
276
321
  void Runner::set_count_interval(size_t count_interval) {
277
322
  text_token_generator_->set_count_interval(count_interval);
278
323
  }
@@ -281,4 +326,44 @@ void Runner::set_time_interval(size_t time_interval) {
281
326
  text_token_generator_->set_time_interval(time_interval);
282
327
  }
283
328
 
329
+ void Runner::set_temperature(float temperature) noexcept {
330
+ config_.temperature = temperature;
331
+ if (text_decoder_runner_) {
332
+ text_decoder_runner_->set_temperature(temperature);
333
+ }
334
+ }
335
+
336
+ void Runner::set_topp(float topp) noexcept {
337
+ config_.topp = topp;
338
+ if (text_decoder_runner_) {
339
+ text_decoder_runner_->set_topp(topp);
340
+ }
341
+ }
342
+
343
+ int32_t Runner::resolve_max_new_tokens(int32_t num_prompt_tokens,
344
+ int32_t max_seq_len,
345
+ int32_t max_context_len,
346
+ int32_t max_new_tokens) const {
347
+ int32_t result;
348
+
349
+ if (max_seq_len == -1 && max_new_tokens == -1) {
350
+ // Both are -1, use max context len minus prompt tokens
351
+ result = max_context_len - num_prompt_tokens;
352
+ } else if (max_seq_len == -1 && max_new_tokens != -1) {
353
+ // Only max_new_tokens is specified
354
+ result = std::min(max_new_tokens, max_context_len - num_prompt_tokens);
355
+ } else if (max_seq_len != -1 && max_new_tokens == -1) {
356
+ // Only seq_len is specified
357
+ result = std::min(max_seq_len, max_context_len) - num_prompt_tokens;
358
+ } else {
359
+ // Both are specified
360
+ result =
361
+ std::min(std::min(max_seq_len, max_context_len) - num_prompt_tokens,
362
+ max_new_tokens);
363
+ }
364
+
365
+ // Ensure result is not negative
366
+ return std::max(0, result);
367
+ }
368
+
284
369
  } // namespace example
@@ -27,42 +27,59 @@
27
27
 
28
28
  namespace example {
29
29
 
30
- class ET_EXPERIMENTAL Runner : public executorch::extension::llm::IRunner {
30
+ namespace llm = ::executorch::extension::llm;
31
+
32
+ class Runner : public llm::IRunner {
31
33
  public:
32
- explicit Runner(const std::string &model_path,
34
+ explicit Runner(::executorch::extension::Module *module,
33
35
  const std::string &tokenizer_path,
34
- const float temperature = 0.8f,
35
- std::optional<const std::string> data_path = std::nullopt);
36
+ const llm::GenerationConfig &config = {
37
+ .temperature = 0.8F, .topp = 0.9F}); // The main config
36
38
 
37
- bool is_loaded() const;
38
- ::executorch::runtime::Error load();
39
- ::executorch::runtime::Error
40
- generate(const std::string &prompt,
41
- std::function<void(const std::string &)> token_callback = {},
42
- std::function<void(const ::executorch::extension::llm::Stats &)>
43
- stats_callback = {},
44
- bool echo = true, bool warming = false);
39
+ bool is_loaded() const override;
40
+ ::executorch::runtime::Error load() override;
41
+ ::executorch::runtime::Error generate(
42
+ const std::string &prompt,
43
+ const llm::GenerationConfig &generation_config =
44
+ {}, // An extra config which temporarily overrides previous model
45
+ // settings
46
+ std::function<void(const std::string &)> token_callback = {},
47
+ std::function<void(const llm::Stats &)> stats_callback = {}) override;
45
48
  ::executorch::runtime::Error warmup(const std::string &prompt);
46
49
  void set_count_interval(size_t count_interval);
47
50
  void set_time_interval(size_t time_interval);
48
- void stop();
51
+ void set_temperature(float temperature) noexcept;
52
+ void set_topp(float topp) noexcept;
53
+
54
+ void stop() override;
55
+ void reset() override;
49
56
 
50
- ::executorch::extension::llm::Stats stats_;
57
+ llm::Stats stats_;
51
58
 
52
59
  private:
53
- float temperature_;
60
+ // Helper functions
61
+ int32_t resolve_max_new_tokens(int32_t num_prompt_tokens, int32_t max_seq_len,
62
+ int32_t max_context_len,
63
+ int32_t max_new_tokens = -1) const;
64
+
65
+ // Main config
66
+ llm::GenerationConfig config_;
67
+
68
+ // Flow control
54
69
  bool shouldStop_{false};
70
+ int64_t pos_ = 0; // The position in KV cache of the input, starting from 0.
71
+
72
+ // Main model
73
+ ::executorch::extension::Module *module_;
55
74
 
56
- // model
57
- std::unique_ptr<::executorch::extension::Module> module_;
75
+ // Subcomponents
58
76
  std::string tokenizer_path_;
59
77
  std::unique_ptr<tokenizers::Tokenizer> tokenizer_;
60
78
  std::unordered_map<std::string, int64_t> metadata_;
61
- std::unique_ptr<::executorch::extension::llm::TextDecoderRunner>
62
- text_decoder_runner_;
63
- std::unique_ptr<::executorch::extension::llm::TextPrefiller> text_prefiller_;
64
- std::unique_ptr<::executorch::extension::llm::TextTokenGenerator>
65
- text_token_generator_;
79
+ std::unique_ptr<llm::IOManager> io_manager_;
80
+ std::unique_ptr<llm::TextDecoderRunner> text_decoder_runner_;
81
+ std::unique_ptr<llm::TextPrefiller> text_prefiller_;
82
+ std::unique_ptr<llm::TextTokenGenerator> text_token_generator_;
66
83
  };
67
84
 
68
85
  } // namespace example
@@ -34,6 +34,7 @@
34
34
 
35
35
  #include "sampler.h"
36
36
  #include <algorithm>
37
+ #include <ctime>
37
38
 
38
39
  namespace executorch {
39
40
  namespace extension {
@@ -121,9 +122,14 @@ int32_t Sampler::sample_topp(T *probabilities, float coin) {
121
122
  Sampler::Sampler(int vocab_size, float temperature, float topp,
122
123
  unsigned long long rng_seed)
123
124
  : vocab_size_(vocab_size),
124
- inv_temperature_(static_cast<bool>(temperature) ? 1.0f / temperature : 0),
125
+ inv_temperature_((temperature != 0.0f) ? (1.0f / temperature) : 0.0f),
125
126
  topp_(topp), rng_state_(rng_seed) {}
126
127
 
128
+ Sampler::Sampler(int vocab_size, float temperature, float topp)
129
+ : vocab_size_(vocab_size),
130
+ inv_temperature_((temperature != 0.0f) ? (1.0f / temperature) : 0.0f),
131
+ topp_(topp), rng_state_(std::time(nullptr)) {}
132
+
127
133
  template <typename T> static void softmax(T *x, int size) {
128
134
  // find max value (for numerical stability)
129
135
  T max_val = x[0];
@@ -184,6 +190,7 @@ template <typename T> int32_t Sampler::sample(T *logits) {
184
190
  }
185
191
 
186
192
  template int32_t Sampler::sample<float>(float *logits);
193
+ template int32_t Sampler::sample<uint16_t>(uint16_t *logits);
187
194
  template int32_t
188
195
  Sampler::sample<executorch::aten::Half>(executorch::aten::Half *logits);
189
196
  template int32_t
@@ -26,16 +26,18 @@ namespace extension {
26
26
  namespace llm {
27
27
  // A simple llama2 sampler.
28
28
 
29
- template <typename T> struct ET_EXPERIMENTAL ProbIndex {
29
+ template <typename T> struct ProbIndex {
30
30
  T prob;
31
31
  int32_t index;
32
32
  }; // struct used when sorting probabilities during top-p sampling
33
33
 
34
- class ET_EXPERIMENTAL Sampler {
34
+ class Sampler {
35
35
  public:
36
36
  Sampler(int32_t vocab_size, float temperature, float topp,
37
37
  unsigned long long rng_seed);
38
38
 
39
+ Sampler(int32_t vocab_size, float temperature, float topp);
40
+
39
41
  template <typename T> int32_t sample(T *logits);
40
42
 
41
43
  private:
@@ -18,7 +18,7 @@ namespace executorch {
18
18
  namespace extension {
19
19
  namespace llm {
20
20
 
21
- struct ET_EXPERIMENTAL Stats {
21
+ struct Stats {
22
22
  // Scaling factor for timestamps - in this case, we use ms.
23
23
  const long SCALING_FACTOR_UNITS_PER_SECOND = 1000;
24
24
  // Time stamps for the different stages of the execution
@@ -82,8 +82,6 @@ private:
82
82
  long aggregate_sampling_timer_start_timestamp = 0;
83
83
  };
84
84
 
85
- static constexpr auto kTopp = 0.9f;
86
-
87
85
  inline std::string stats_to_json_string(const Stats &stats) {
88
86
  std::stringstream ss;
89
87
  ss << "{\"prompt_tokens\":" << stats.num_prompt_tokens << ","
@@ -157,7 +155,6 @@ namespace executorch {
157
155
  namespace llm {
158
156
  // TODO(T197294990): Remove these deprecated aliases once all users have moved
159
157
  // to the new `::executorch` namespaces.
160
- using ::executorch::extension::llm::kTopp;
161
158
  using ::executorch::extension::llm::print_report;
162
159
  using ::executorch::extension::llm::Stats;
163
160
  } // namespace llm
@@ -9,11 +9,11 @@
9
9
  // Given inputs, run a text decoder and return logits.
10
10
 
11
11
  #include "text_decoder_runner.h"
12
+ #include "arange_util.h"
13
+ #include "stats.h"
12
14
 
13
15
  #include <ctime>
14
16
 
15
- #include "stats.h"
16
-
17
17
  namespace executorch {
18
18
  namespace extension {
19
19
  namespace llm {
@@ -21,23 +21,37 @@ namespace llm {
21
21
  // NOTE: we observed ~2x loading performance increase on iPhone 15
22
22
  // and a ~5% improvement on Galaxy S22 by switching to
23
23
  // FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
24
- TextDecoderRunner::TextDecoderRunner(Module *module, bool use_kv_cache,
25
- int32_t vocab_size, float temperature)
26
- : module_(module),
27
- sampler_(std::make_unique<Sampler>(
28
- vocab_size, temperature, kTopp,
29
- static_cast<unsigned long long>(std::time(nullptr)))),
30
- use_kv_cache_(use_kv_cache) {}
24
+ TextDecoderRunner::TextDecoderRunner(Module *module, IOManager *io_manager,
25
+ float temperature, float topp)
26
+ : module_(module), io_manager_(io_manager), temperature_(temperature),
27
+ topp_(topp) {}
31
28
 
32
29
  // This function is functional, meaning it shouldn't modify any state of the
33
30
  // input. It should be safe to call multiple times with the same inputs. The
34
31
  // outer loop (call site) is responsible for managing state.
35
32
  ::executorch::runtime::Result<executorch::aten::Tensor>
36
- TextDecoderRunner::step(TensorPtr &tokens, TensorPtr &start_pos) {
33
+ TextDecoderRunner::step(TensorPtr &tokens, int64_t start_pos) {
37
34
  // ET_LOG(Info, "Input token %" PRIu64, input_token);
38
- if (use_kv_cache_) {
39
- auto outputs_res = module_->forward({tokens, start_pos});
35
+ auto method_meta = ET_UNWRAP(module_->method_meta("forward"));
36
+ // If only 1 input, we are not using kv cache
37
+ bool use_kv_cache = method_meta.num_inputs() > 1;
38
+
39
+ std::vector<int64_t> cache_positions;
40
+
41
+ if (use_kv_cache) {
42
+ auto start_pos_tensor = ET_UNWRAP(populate_start_pos_or_cache_position(
43
+ module_, start_pos, cache_positions, tokens->numel(), "forward"));
44
+
45
+ std::vector<runtime::EValue> inputs;
46
+ auto inputs_res = io_manager_->prepare_decode(tokens, start_pos_tensor);
47
+ ET_CHECK_OK_OR_RETURN_ERROR(inputs_res.error());
48
+ inputs = inputs_res.get();
49
+ auto outputs_res = module_->forward(inputs);
40
50
  ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
51
+
52
+ auto update_err = io_manager_->update_decode(outputs_res.get());
53
+ ET_CHECK_OK_OR_RETURN_ERROR(update_err);
54
+
41
55
  ET_CHECK_MSG(outputs_res.get().size() == 1,
42
56
  "More then one output returned from executing LLM.");
43
57
  ET_CHECK_MSG(outputs_res.get()[0].isTensor(),