react-native-executorch 0.5.15 → 0.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +42 -36
- package/android/CMakeLists.txt +13 -25
- package/android/build.gradle +2 -3
- package/android/libs/classes.jar +0 -0
- package/android/src/main/cpp/CMakeLists.txt +2 -1
- package/common/rnexecutorch/RnExecutorchInstaller.cpp +18 -0
- package/common/rnexecutorch/TokenizerModule.cpp +3 -3
- package/common/rnexecutorch/data_processing/Numerical.cpp +31 -23
- package/common/rnexecutorch/data_processing/Numerical.h +6 -1
- package/common/rnexecutorch/data_processing/dsp.cpp +0 -46
- package/common/rnexecutorch/host_objects/JsiConversions.h +16 -0
- package/common/rnexecutorch/host_objects/ModelHostObject.h +26 -11
- package/common/rnexecutorch/jsi/OwningArrayBuffer.h +19 -2
- package/common/rnexecutorch/metaprogramming/TypeConcepts.h +0 -20
- package/common/rnexecutorch/models/BaseModel.cpp +12 -11
- package/common/rnexecutorch/models/BaseModel.h +18 -10
- package/common/rnexecutorch/models/embeddings/BaseEmbeddings.cpp +3 -11
- package/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp +0 -1
- package/common/rnexecutorch/models/image_segmentation/ImageSegmentation.cpp +6 -12
- package/common/rnexecutorch/models/llm/LLM.cpp +25 -8
- package/common/rnexecutorch/models/llm/LLM.h +4 -4
- package/common/rnexecutorch/models/ocr/CTCLabelConverter.h +1 -1
- package/common/rnexecutorch/models/ocr/utils/RecognitionHandlerUtils.cpp +7 -4
- package/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp +8 -13
- package/common/rnexecutorch/models/speech_to_text/SpeechToText.h +1 -3
- package/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp +12 -19
- package/common/rnexecutorch/models/speech_to_text/asr/ASR.h +4 -5
- package/common/rnexecutorch/models/text_to_image/Constants.h +9 -0
- package/common/rnexecutorch/models/text_to_image/Decoder.cpp +32 -0
- package/common/rnexecutorch/models/text_to_image/Decoder.h +24 -0
- package/common/rnexecutorch/models/text_to_image/Encoder.cpp +44 -0
- package/common/rnexecutorch/models/text_to_image/Encoder.h +32 -0
- package/common/rnexecutorch/models/text_to_image/Scheduler.cpp +152 -0
- package/common/rnexecutorch/models/text_to_image/Scheduler.h +41 -0
- package/common/rnexecutorch/models/text_to_image/TextToImage.cpp +141 -0
- package/common/rnexecutorch/models/text_to_image/TextToImage.h +64 -0
- package/common/rnexecutorch/models/text_to_image/UNet.cpp +38 -0
- package/common/rnexecutorch/models/text_to_image/UNet.h +28 -0
- package/common/rnexecutorch/models/voice_activity_detection/Constants.h +27 -0
- package/common/rnexecutorch/models/voice_activity_detection/Types.h +12 -0
- package/common/rnexecutorch/models/voice_activity_detection/Utils.cpp +15 -0
- package/common/rnexecutorch/models/voice_activity_detection/Utils.h +13 -0
- package/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.cpp +160 -0
- package/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.h +36 -0
- package/common/rnexecutorch/tests/CMakeLists.txt +30 -0
- package/common/rnexecutorch/tests/NumericalTest.cpp +110 -0
- package/common/rnexecutorch/tests/README.md +30 -13
- package/common/rnexecutorch/threads/GlobalThreadPool.h +4 -0
- package/common/runner/arange_util.cpp +44 -0
- package/common/runner/arange_util.h +37 -0
- package/common/runner/constants.h +28 -0
- package/common/runner/io_manager.h +240 -0
- package/common/runner/irunner.h +87 -16
- package/common/runner/kernel_includes.h +23 -0
- package/common/runner/runner.cpp +151 -66
- package/common/runner/runner.h +39 -22
- package/common/runner/sampler.cpp +8 -1
- package/common/runner/sampler.h +4 -2
- package/common/runner/stats.h +1 -4
- package/common/runner/text_decoder_runner.cpp +26 -12
- package/common/runner/text_decoder_runner.h +52 -31
- package/common/runner/text_prefiller.cpp +46 -12
- package/common/runner/text_prefiller.h +38 -4
- package/common/runner/text_token_generator.h +51 -26
- package/common/runner/util.h +53 -8
- package/ios/RnExecutorch.xcodeproj/project.pbxproj +0 -23
- package/lib/module/Error.js +1 -0
- package/lib/module/Error.js.map +1 -1
- package/lib/module/constants/directories.js +1 -1
- package/lib/module/constants/directories.js.map +1 -1
- package/lib/module/constants/modelUrls.js +32 -1
- package/lib/module/constants/modelUrls.js.map +1 -1
- package/lib/module/constants/ocr/models.js +7 -7
- package/lib/module/constants/ocr/models.js.map +1 -1
- package/lib/module/constants/ocr/symbols.js +3 -2
- package/lib/module/constants/ocr/symbols.js.map +1 -1
- package/lib/module/controllers/LLMController.js +10 -1
- package/lib/module/controllers/LLMController.js.map +1 -1
- package/lib/module/controllers/OCRController.js +3 -3
- package/lib/module/controllers/OCRController.js.map +1 -1
- package/lib/module/controllers/VerticalOCRController.js +2 -2
- package/lib/module/controllers/VerticalOCRController.js.map +1 -1
- package/lib/module/hooks/computer_vision/useOCR.js +3 -3
- package/lib/module/hooks/computer_vision/useOCR.js.map +1 -1
- package/lib/module/hooks/{useNonStaticModule.js → computer_vision/useTextToImage.js} +21 -16
- package/lib/module/hooks/computer_vision/useTextToImage.js.map +1 -0
- package/lib/module/hooks/computer_vision/useVerticalOCR.js +3 -3
- package/lib/module/hooks/computer_vision/useVerticalOCR.js.map +1 -1
- package/lib/module/hooks/natural_language_processing/useLLM.js +3 -3
- package/lib/module/hooks/natural_language_processing/useLLM.js.map +1 -1
- package/lib/module/hooks/natural_language_processing/useTokenizer.js +5 -5
- package/lib/module/hooks/natural_language_processing/useTokenizer.js.map +1 -1
- package/lib/module/hooks/natural_language_processing/useVAD.js +13 -0
- package/lib/module/hooks/natural_language_processing/useVAD.js.map +1 -0
- package/lib/module/index.js +7 -2
- package/lib/module/index.js.map +1 -1
- package/lib/module/modules/computer_vision/OCRModule.js +2 -2
- package/lib/module/modules/computer_vision/OCRModule.js.map +1 -1
- package/lib/module/modules/computer_vision/TextToImageModule.js +48 -0
- package/lib/module/modules/computer_vision/TextToImageModule.js.map +1 -0
- package/lib/module/modules/computer_vision/VerticalOCRModule.js +2 -2
- package/lib/module/modules/computer_vision/VerticalOCRModule.js.map +1 -1
- package/lib/module/modules/natural_language_processing/SpeechToTextModule.js +7 -4
- package/lib/module/modules/natural_language_processing/SpeechToTextModule.js.map +1 -1
- package/lib/module/modules/natural_language_processing/VADModule.js +19 -0
- package/lib/module/modules/natural_language_processing/VADModule.js.map +1 -0
- package/lib/module/types/llm.js.map +1 -1
- package/lib/module/types/vad.js +2 -0
- package/lib/module/types/vad.js.map +1 -0
- package/lib/module/utils/ResourceFetcher.js +2 -1
- package/lib/module/utils/ResourceFetcher.js.map +1 -1
- package/lib/module/utils/ResourceFetcherUtils.js +6 -6
- package/lib/module/utils/ResourceFetcherUtils.js.map +1 -1
- package/lib/typescript/Error.d.ts +1 -0
- package/lib/typescript/Error.d.ts.map +1 -1
- package/lib/typescript/constants/modelUrls.d.ts +23 -0
- package/lib/typescript/constants/modelUrls.d.ts.map +1 -1
- package/lib/typescript/constants/ocr/symbols.d.ts +1 -1
- package/lib/typescript/constants/ocr/symbols.d.ts.map +1 -1
- package/lib/typescript/controllers/LLMController.d.ts.map +1 -1
- package/lib/typescript/controllers/OCRController.d.ts +1 -1
- package/lib/typescript/controllers/OCRController.d.ts.map +1 -1
- package/lib/typescript/controllers/VerticalOCRController.d.ts +1 -1
- package/lib/typescript/controllers/VerticalOCRController.d.ts.map +1 -1
- package/lib/typescript/hooks/computer_vision/useOCR.d.ts +1 -1
- package/lib/typescript/hooks/computer_vision/useOCR.d.ts.map +1 -1
- package/lib/typescript/hooks/computer_vision/useTextToImage.d.ts +22 -0
- package/lib/typescript/hooks/computer_vision/useTextToImage.d.ts.map +1 -0
- package/lib/typescript/hooks/computer_vision/useVerticalOCR.d.ts +1 -1
- package/lib/typescript/hooks/computer_vision/useVerticalOCR.d.ts.map +1 -1
- package/lib/typescript/hooks/natural_language_processing/useLLM.d.ts.map +1 -1
- package/lib/typescript/hooks/natural_language_processing/useSpeechToText.d.ts +2 -2
- package/lib/typescript/hooks/natural_language_processing/useVAD.d.ts +16 -0
- package/lib/typescript/hooks/natural_language_processing/useVAD.d.ts.map +1 -0
- package/lib/typescript/index.d.ts +8 -1
- package/lib/typescript/index.d.ts.map +1 -1
- package/lib/typescript/modules/computer_vision/OCRModule.d.ts +1 -1
- package/lib/typescript/modules/computer_vision/OCRModule.d.ts.map +1 -1
- package/lib/typescript/modules/computer_vision/TextToImageModule.d.ts +16 -0
- package/lib/typescript/modules/computer_vision/TextToImageModule.d.ts.map +1 -0
- package/lib/typescript/modules/computer_vision/VerticalOCRModule.d.ts +1 -1
- package/lib/typescript/modules/computer_vision/VerticalOCRModule.d.ts.map +1 -1
- package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts +3 -2
- package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts.map +1 -1
- package/lib/typescript/modules/natural_language_processing/VADModule.d.ts +10 -0
- package/lib/typescript/modules/natural_language_processing/VADModule.d.ts.map +1 -0
- package/lib/typescript/types/llm.d.ts +2 -0
- package/lib/typescript/types/llm.d.ts.map +1 -1
- package/lib/typescript/types/vad.d.ts +5 -0
- package/lib/typescript/types/vad.d.ts.map +1 -0
- package/lib/typescript/utils/ResourceFetcher.d.ts +29 -0
- package/lib/typescript/utils/ResourceFetcher.d.ts.map +1 -1
- package/lib/typescript/utils/ResourceFetcherUtils.d.ts +2 -2
- package/lib/typescript/utils/ResourceFetcherUtils.d.ts.map +1 -1
- package/package.json +11 -8
- package/react-native-executorch.podspec +9 -9
- package/src/Error.ts +1 -0
- package/src/constants/directories.ts +1 -1
- package/src/constants/modelUrls.ts +36 -1
- package/src/constants/ocr/models.ts +7 -7
- package/src/constants/ocr/symbols.ts +3 -2
- package/src/controllers/LLMController.ts +12 -1
- package/src/controllers/OCRController.ts +3 -3
- package/src/controllers/VerticalOCRController.ts +2 -2
- package/src/hooks/computer_vision/useOCR.ts +4 -5
- package/src/hooks/computer_vision/useTextToImage.ts +92 -0
- package/src/hooks/computer_vision/useVerticalOCR.ts +4 -5
- package/src/hooks/natural_language_processing/useLLM.ts +3 -4
- package/src/hooks/natural_language_processing/useTokenizer.ts +5 -5
- package/src/hooks/natural_language_processing/useVAD.ts +15 -0
- package/src/index.ts +20 -1
- package/src/modules/computer_vision/OCRModule.ts +2 -2
- package/src/modules/computer_vision/TextToImageModule.ts +93 -0
- package/src/modules/computer_vision/VerticalOCRModule.ts +2 -2
- package/src/modules/natural_language_processing/SpeechToTextModule.ts +8 -4
- package/src/modules/natural_language_processing/VADModule.ts +27 -0
- package/src/types/llm.ts +2 -0
- package/src/types/vad.ts +4 -0
- package/src/utils/ResourceFetcher.ts +2 -1
- package/src/utils/ResourceFetcherUtils.ts +8 -8
- package/third-party/android/libs/cpuinfo/arm64-v8a/libcpuinfo.so +0 -0
- package/third-party/android/libs/executorch/arm64-v8a/libexecutorch.so +0 -0
- package/third-party/android/libs/executorch/x86_64/libexecutorch.so +0 -0
- package/third-party/android/libs/pthreadpool/arm64-v8a/libpthreadpool.so +0 -0
- package/third-party/include/c10/macros/Export.h +0 -78
- package/third-party/include/c10/macros/Macros.h +1 -520
- package/third-party/include/c10/util/BFloat16-inl.h +1 -339
- package/third-party/include/c10/util/BFloat16.h +1 -122
- package/third-party/include/c10/util/Half-inl.h +1 -347
- package/third-party/include/c10/util/Half.h +6 -419
- package/third-party/include/c10/util/TypeSafeSignMath.h +1 -133
- package/third-party/include/c10/util/bit_cast.h +1 -43
- package/third-party/include/c10/util/complex.h +1 -568
- package/third-party/include/c10/util/floating_point_utils.h +1 -33
- package/third-party/include/c10/util/irange.h +1 -1
- package/third-party/include/c10/util/llvmMathExtras.h +866 -0
- package/third-party/include/c10/util/safe_numerics.h +97 -0
- package/third-party/include/executorch/ExecuTorchError.h +6 -7
- package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLM.h +12 -0
- package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMConfig.h +56 -0
- package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMError.h +16 -0
- package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMMultimodalRunner.h +227 -0
- package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMTextRunner.h +97 -0
- package/third-party/include/executorch/ExecuTorchLLM/module.modulemap +4 -0
- package/third-party/include/executorch/ExecuTorchLog.h +1 -0
- package/third-party/include/executorch/ExecuTorchModule.h +177 -4
- package/third-party/include/executorch/ExecuTorchTensor.h +3 -4
- package/third-party/include/executorch/ExecuTorchValue.h +1 -7
- package/third-party/include/executorch/extension/module/module.h +139 -8
- package/third-party/include/executorch/extension/tensor/tensor.h +1 -0
- package/third-party/include/executorch/extension/tensor/tensor_ptr.h +88 -26
- package/third-party/include/executorch/extension/threadpool/threadpool.h +4 -1
- package/third-party/include/executorch/runtime/backend/backend_init_context.h +6 -0
- package/third-party/include/executorch/runtime/backend/interface.h +1 -1
- package/third-party/include/executorch/runtime/core/error.h +76 -49
- package/third-party/include/executorch/runtime/core/exec_aten/util/scalar_type_util.h +18 -4
- package/third-party/include/executorch/runtime/core/memory_allocator.h +12 -2
- package/third-party/include/executorch/runtime/core/named_data_map.h +1 -11
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/macros/Export.h +0 -78
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/macros/Macros.h +1 -520
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/BFloat16-inl.h +1 -339
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/BFloat16.h +1 -122
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/Half-inl.h +1 -347
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/Half.h +6 -419
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/TypeSafeSignMath.h +1 -133
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/bit_cast.h +1 -43
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/complex.h +1 -568
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/floating_point_utils.h +1 -33
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/irange.h +1 -1
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/llvmMathExtras.h +866 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/safe_numerics.h +97 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/macros/Export.h +66 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/macros/Macros.h +553 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/BFloat16.h +477 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/Half.h +781 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/TypeSafeSignMath.h +141 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/bit_cast.h +49 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/complex.h +593 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/floating_point_utils.h +38 -0
- package/third-party/include/executorch/runtime/core/tensor_layout.h +1 -1
- package/third-party/include/executorch/runtime/executor/merged_data_map.h +142 -0
- package/third-party/include/executorch/runtime/executor/method.h +21 -8
- package/third-party/include/executorch/runtime/executor/method_meta.h +20 -2
- package/third-party/include/executorch/runtime/executor/program.h +0 -10
- package/third-party/include/executorch/runtime/kernel/operator_registry.h +1 -1
- package/third-party/include/executorch/runtime/platform/compiler.h +2 -0
- package/third-party/include/executorch/schema/extended_header.h +10 -1
- package/third-party/include/torch/headeronly/macros/Export.h +66 -0
- package/third-party/include/torch/headeronly/macros/Macros.h +553 -0
- package/third-party/include/torch/headeronly/util/BFloat16.h +477 -0
- package/third-party/include/torch/headeronly/util/Half.h +781 -0
- package/third-party/include/torch/headeronly/util/TypeSafeSignMath.h +141 -0
- package/third-party/include/torch/headeronly/util/bit_cast.h +49 -0
- package/third-party/include/torch/headeronly/util/complex.h +593 -0
- package/third-party/include/torch/headeronly/util/floating_point_utils.h +38 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/ExecutorchLib +0 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/Info.plist +0 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/ExecutorchLib +0 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/Info.plist +0 -0
- package/common/rnexecutorch/tests/run_all_tests.sh +0 -14
- package/common/rnexecutorch/tests/run_test.sh +0 -18
- package/ios/RnExecutorch/utils/Conversions.h +0 -14
- package/ios/RnExecutorch/utils/ETError.h +0 -26
- package/ios/RnExecutorch/utils/ImageProcessor.h +0 -15
- package/ios/RnExecutorch/utils/ImageProcessor.mm +0 -147
- package/ios/RnExecutorch/utils/Numerical.h +0 -3
- package/ios/RnExecutorch/utils/Numerical.mm +0 -18
- package/ios/RnExecutorch/utils/ScalarType.h +0 -14
- package/ios/RnExecutorch/utils/ScalarType.mm +0 -21
- package/lib/module/hooks/useNonStaticModule.js.map +0 -1
- package/lib/typescript/hooks/useNonStaticModule.d.ts +0 -21
- package/lib/typescript/hooks/useNonStaticModule.d.ts.map +0 -1
- package/src/hooks/useNonStaticModule.ts +0 -74
- package/third-party/include/executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h +0 -181
- package/third-party/include/executorch/extension/kernel_util/meta_programming.h +0 -108
- package/third-party/include/executorch/extension/kernel_util/type_list.h +0 -137
- package/third-party/include/executorch/extension/threadpool/threadpool_guard.h +0 -35
|
@@ -0,0 +1,240 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
* All rights reserved.
|
|
4
|
+
*
|
|
5
|
+
* This source code is licensed under the BSD-style license found in the
|
|
6
|
+
* LICENSE file in the root directory of this source tree.
|
|
7
|
+
*/
|
|
8
|
+
|
|
9
|
+
#pragma once
|
|
10
|
+
|
|
11
|
+
#include <executorch/extension/module/module.h>
|
|
12
|
+
#include <executorch/extension/tensor/tensor.h>
|
|
13
|
+
|
|
14
|
+
namespace executorch {
|
|
15
|
+
namespace extension {
|
|
16
|
+
namespace llm {
|
|
17
|
+
|
|
18
|
+
/**
|
|
19
|
+
* @brief Base class for managing input/output operations for LLM inference.
|
|
20
|
+
*
|
|
21
|
+
* IOManager provides an interface for handling the input preparation and
|
|
22
|
+
* output processing for both prefill and decode phases of LLM inference.
|
|
23
|
+
* Derived classes must implement the virtual methods to provide specific IO
|
|
24
|
+
* management functionality.
|
|
25
|
+
*/
|
|
26
|
+
class IOManager {
|
|
27
|
+
public:
|
|
28
|
+
/**
|
|
29
|
+
* @brief Construct an IOManager bound to a Module.
|
|
30
|
+
*
|
|
31
|
+
* @param module The Module used for querying method metadata and execution.
|
|
32
|
+
*/
|
|
33
|
+
explicit IOManager(ET_MODULE_NAMESPACE::Module &module) : module_(module) {}
|
|
34
|
+
|
|
35
|
+
/**
|
|
36
|
+
* @brief Virtual destructor to allow proper cleanup in derived classes.
|
|
37
|
+
*/
|
|
38
|
+
virtual ~IOManager() = default;
|
|
39
|
+
|
|
40
|
+
/**
|
|
41
|
+
* @brief Load the IO manager with method metadata for prefill and
|
|
42
|
+
* decode operations.
|
|
43
|
+
*
|
|
44
|
+
* @param prefill_method The prefill method to initialize with.
|
|
45
|
+
* @param decode_method The decode method to initialize with.
|
|
46
|
+
*/
|
|
47
|
+
ET_NODISCARD virtual runtime::Error load(const std::string &prefill_method,
|
|
48
|
+
const std::string &decode_method) {
|
|
49
|
+
(void)prefill_method;
|
|
50
|
+
(void)decode_method;
|
|
51
|
+
return runtime::Error::Ok;
|
|
52
|
+
}
|
|
53
|
+
|
|
54
|
+
/**
|
|
55
|
+
* @brief Load the IO manager using the default method names.
|
|
56
|
+
*
|
|
57
|
+
* Uses "forward" for both prefill and decode.
|
|
58
|
+
*
|
|
59
|
+
* @return Error code.
|
|
60
|
+
*/
|
|
61
|
+
ET_NODISCARD runtime::Error load() { return load("forward", "forward"); }
|
|
62
|
+
|
|
63
|
+
/**
|
|
64
|
+
* @brief Reset the IO manager state.
|
|
65
|
+
*
|
|
66
|
+
* @param prefill_method The prefill method to reset with.
|
|
67
|
+
* @param decode_method The decode method to reset with.
|
|
68
|
+
*/
|
|
69
|
+
ET_NODISCARD virtual runtime::Error reset(const std::string &prefill_method,
|
|
70
|
+
const std::string &decode_method) {
|
|
71
|
+
(void)prefill_method;
|
|
72
|
+
(void)decode_method;
|
|
73
|
+
return runtime::Error::Ok;
|
|
74
|
+
}
|
|
75
|
+
|
|
76
|
+
/**
|
|
77
|
+
* @brief Reset the IO manager state using the default method names.
|
|
78
|
+
*
|
|
79
|
+
* Uses "forward" for both prefill and decode.
|
|
80
|
+
*
|
|
81
|
+
* @return Error code.
|
|
82
|
+
*/
|
|
83
|
+
ET_NODISCARD runtime::Error reset() { return reset("forward", "forward"); }
|
|
84
|
+
|
|
85
|
+
/**
|
|
86
|
+
* @brief Prepare inputs for the prefill phase of LLM inference.
|
|
87
|
+
*
|
|
88
|
+
* @param input The input tensor containing token IDs.
|
|
89
|
+
* @param start_pos The tensor containing the starting position of the current
|
|
90
|
+
* input within the context.
|
|
91
|
+
* @param prefill_method The prefill method to prepare inputs for.
|
|
92
|
+
* @return std::vector<runtime::EValue> Vector of prepared inputs
|
|
93
|
+
* for the prefill method.
|
|
94
|
+
*/
|
|
95
|
+
virtual runtime::Result<std::vector<runtime::EValue>>
|
|
96
|
+
prepare_prefill(const TensorPtr &input, const TensorPtr &start_pos,
|
|
97
|
+
const std::string &prefill_method) {
|
|
98
|
+
auto method_meta = module_.method_meta(prefill_method);
|
|
99
|
+
if (!method_meta.ok()) {
|
|
100
|
+
return method_meta.error();
|
|
101
|
+
}
|
|
102
|
+
if (method_meta->num_inputs() != 2) {
|
|
103
|
+
ET_LOG(Error,
|
|
104
|
+
"Expected 2 inputs for prefill method, got %zu. Likely the model "
|
|
105
|
+
"takes the caches or mask as an argument which this IOManager "
|
|
106
|
+
"does not support.",
|
|
107
|
+
method_meta->num_inputs());
|
|
108
|
+
return runtime::Error::InvalidState;
|
|
109
|
+
}
|
|
110
|
+
// Cpu IO Manager supports dynamic shapes for prefill, so no work to be done
|
|
111
|
+
// here.
|
|
112
|
+
return std::vector<runtime::EValue>{input, start_pos};
|
|
113
|
+
}
|
|
114
|
+
|
|
115
|
+
/**
|
|
116
|
+
* @brief Prepare inputs for the prefill phase using the default method name.
|
|
117
|
+
*
|
|
118
|
+
* Uses "forward" as the prefill method.
|
|
119
|
+
*
|
|
120
|
+
* @param input The input tensor containing token IDs.
|
|
121
|
+
* @param start_pos The tensor containing the starting position.
|
|
122
|
+
* @return Vector of prepared inputs for the prefill method.
|
|
123
|
+
*/
|
|
124
|
+
runtime::Result<std::vector<runtime::EValue>>
|
|
125
|
+
prepare_prefill(const TensorPtr &input, const TensorPtr &start_pos) {
|
|
126
|
+
return prepare_prefill(input, start_pos, "forward");
|
|
127
|
+
}
|
|
128
|
+
|
|
129
|
+
/**
|
|
130
|
+
* @brief Prepare inputs for the decode phase of LLM inference.
|
|
131
|
+
*
|
|
132
|
+
* @param input The input tensor containing token IDs.
|
|
133
|
+
* @param start_pos The tensor containing the starting position of the current
|
|
134
|
+
* input within the context.
|
|
135
|
+
* @param decode_method The decode method to prepare inputs for.
|
|
136
|
+
* @return std::vector<runtime::EValue> Vector of prepared inputs
|
|
137
|
+
* for the decode method.
|
|
138
|
+
*/
|
|
139
|
+
virtual runtime::Result<std::vector<runtime::EValue>>
|
|
140
|
+
prepare_decode(const TensorPtr &input, const TensorPtr &start_pos,
|
|
141
|
+
const std::string &decode_method) {
|
|
142
|
+
auto method_meta = module_.method_meta(decode_method);
|
|
143
|
+
if (!method_meta.ok()) {
|
|
144
|
+
return method_meta.error();
|
|
145
|
+
}
|
|
146
|
+
if (method_meta->num_inputs() != 2) {
|
|
147
|
+
ET_LOG(Error,
|
|
148
|
+
"Expected 2 inputs for decode method, got %zu. Likely the model "
|
|
149
|
+
"takes the caches or mask as an argument which this IOManager "
|
|
150
|
+
"does not support.",
|
|
151
|
+
method_meta->num_inputs());
|
|
152
|
+
return runtime::Error::InvalidState;
|
|
153
|
+
}
|
|
154
|
+
// Cpu IO Manager supports dynamic shapes for prefill, so no work to be done
|
|
155
|
+
// here.
|
|
156
|
+
return std::vector<runtime::EValue>{input, start_pos};
|
|
157
|
+
}
|
|
158
|
+
|
|
159
|
+
/**
|
|
160
|
+
* @brief Prepare inputs for the decode phase using the default method name.
|
|
161
|
+
*
|
|
162
|
+
* Uses "forward" as the decode method.
|
|
163
|
+
*
|
|
164
|
+
* @param input The input tensor containing token IDs.
|
|
165
|
+
* @param start_pos The tensor containing the starting position.
|
|
166
|
+
* @return Vector of prepared inputs for the decode method.
|
|
167
|
+
*/
|
|
168
|
+
runtime::Result<std::vector<runtime::EValue>>
|
|
169
|
+
prepare_decode(const TensorPtr &input, const TensorPtr &start_pos) {
|
|
170
|
+
return prepare_decode(input, start_pos, "forward");
|
|
171
|
+
}
|
|
172
|
+
|
|
173
|
+
/**
|
|
174
|
+
* @brief Process and update internal state with outputs from the prefill
|
|
175
|
+
* phase.
|
|
176
|
+
*
|
|
177
|
+
* @param prefill_method The prefill method to update with outputs.
|
|
178
|
+
* @param model_outputs Vector of outputs from the prefill method execution.
|
|
179
|
+
*/
|
|
180
|
+
ET_NODISCARD virtual runtime::Error
|
|
181
|
+
update_prefill(const std::vector<runtime::EValue> &model_outputs,
|
|
182
|
+
const std::string &prefill_method) {
|
|
183
|
+
(void)model_outputs;
|
|
184
|
+
(void)prefill_method;
|
|
185
|
+
// No post inference work to do.
|
|
186
|
+
return runtime::Error::Ok;
|
|
187
|
+
}
|
|
188
|
+
|
|
189
|
+
/**
|
|
190
|
+
* @brief Process outputs from the prefill phase using the default method.
|
|
191
|
+
*
|
|
192
|
+
* Uses "forward" as the prefill method.
|
|
193
|
+
*
|
|
194
|
+
* @param model_outputs Vector of outputs from the prefill execution.
|
|
195
|
+
* @return Error code.
|
|
196
|
+
*/
|
|
197
|
+
ET_NODISCARD runtime::Error
|
|
198
|
+
update_prefill(const std::vector<runtime::EValue> &model_outputs) {
|
|
199
|
+
return update_prefill(model_outputs, "forward");
|
|
200
|
+
}
|
|
201
|
+
|
|
202
|
+
/**
|
|
203
|
+
* @brief Process and update internal state with outputs from the decode
|
|
204
|
+
* phase.
|
|
205
|
+
*
|
|
206
|
+
* @param decode_method The decode method to update with outputs.
|
|
207
|
+
* @param model_outputs Vector of outputs from the decode method execution.
|
|
208
|
+
*/
|
|
209
|
+
ET_NODISCARD virtual runtime::Error
|
|
210
|
+
update_decode(const std::vector<runtime::EValue> &model_outputs,
|
|
211
|
+
const std::string &decode_method) {
|
|
212
|
+
(void)model_outputs;
|
|
213
|
+
(void)decode_method;
|
|
214
|
+
// No post inference work to do.
|
|
215
|
+
return runtime::Error::Ok;
|
|
216
|
+
}
|
|
217
|
+
|
|
218
|
+
/**
|
|
219
|
+
* @brief Process outputs from the decode phase using the default method.
|
|
220
|
+
*
|
|
221
|
+
* Uses "forward" as the decode method.
|
|
222
|
+
*
|
|
223
|
+
* @param model_outputs Vector of outputs from the decode execution.
|
|
224
|
+
* @return Error code.
|
|
225
|
+
*/
|
|
226
|
+
ET_NODISCARD runtime::Error
|
|
227
|
+
update_decode(const std::vector<runtime::EValue> &model_outputs) {
|
|
228
|
+
return update_decode(model_outputs, "forward");
|
|
229
|
+
}
|
|
230
|
+
|
|
231
|
+
private:
|
|
232
|
+
/**
|
|
233
|
+
* @brief Reference to the Module used for method metadata and execution.
|
|
234
|
+
*/
|
|
235
|
+
ET_MODULE_NAMESPACE::Module &module_;
|
|
236
|
+
};
|
|
237
|
+
|
|
238
|
+
} // namespace llm
|
|
239
|
+
} // namespace extension
|
|
240
|
+
} // namespace executorch
|
package/common/runner/irunner.h
CHANGED
|
@@ -6,41 +6,112 @@
|
|
|
6
6
|
* LICENSE file in the root directory of this source tree.
|
|
7
7
|
*/
|
|
8
8
|
|
|
9
|
-
//
|
|
10
|
-
// implements their own load and generation logic to run the model.
|
|
9
|
+
// Interface for text generation runners.
|
|
11
10
|
|
|
12
11
|
#pragma once
|
|
13
12
|
|
|
13
|
+
#include "stats.h"
|
|
14
|
+
|
|
15
|
+
#include <cstdint>
|
|
14
16
|
#include <functional>
|
|
17
|
+
#include <memory>
|
|
15
18
|
#include <string>
|
|
16
19
|
|
|
17
|
-
#include
|
|
18
|
-
#include <executorch/extension/module/module.h>
|
|
20
|
+
#include <executorch/runtime/core/error.h>
|
|
19
21
|
|
|
20
22
|
namespace executorch {
|
|
21
23
|
namespace extension {
|
|
22
24
|
namespace llm {
|
|
23
25
|
|
|
24
|
-
|
|
26
|
+
/// Tunable parameters for a single text-generation request.
///
/// This is an aggregate: fields may be set by brace/aggregate initialization
/// in declaration order, so the member order below is significant.
struct GenerationConfig {
  /// When true, the input prompt is echoed back in the output.
  bool echo{false};

  /// Marks a warmup run (affects perf benchmarking).
  bool warming{false};

  /// Cap on the number of newly generated tokens.
  /// When the .pte file carries max_context_len metadata, prompt tokens plus
  /// max_new_tokens are kept within max_context_len. The value -1 means
  /// "unset": rely on the max_context_len metadata and the seq_len value.
  int32_t max_new_tokens{-1};

  /// Cap on the total number of tokens.
  /// A smaller max_context_len from the .pte metadata overrides this value;
  /// -1 means "unset": use the max_context_len metadata directly.
  int32_t max_seq_len{-1};

  /// Maximum context length.
  /// A smaller max_context_len from the .pte metadata overrides this value;
  /// -1 means "unset": use the max_context_len metadata directly.
  int32_t max_context_length{-1};

  /// Sampling temperature; larger values produce more random output.
  /// NOTE(review): -1 presumably means "use the runner/model default" —
  /// confirm against the runner implementation.
  float temperature{-1.0f};

  /// Nucleus (top-p) sampling threshold in [0.0, 1.0].
  /// The next token is drawn from the smallest set of candidates whose
  /// cumulative probability exceeds this value. Lower values are more
  /// deterministic; higher values are more diverse.
  /// NOTE(review): -1 presumably means "use the runner/model default" —
  /// confirm against the runner implementation.
  float topp{-1.0f};

  /// Enables dynamic input shapes, where implemented.
  /// For prefill, true makes TextPrefiller pass all prompt tokens at once.
  bool enable_dynamic_shape{true};

  /// Enables the KV-cache implementation, where implemented.
  bool enable_kv_cache{true};
};
|
|
69
|
+
|
|
70
|
+
// Base interface for LLM runners
|
|
71
|
+
class IRunner {
|
|
25
72
|
public:
|
|
26
73
|
virtual ~IRunner() = default;
|
|
27
74
|
|
|
28
|
-
|
|
75
|
+
/**
|
|
76
|
+
* Check if the runner is loaded and ready for inference.
|
|
77
|
+
*
|
|
78
|
+
* @return true if the runner is loaded, false otherwise
|
|
79
|
+
*/
|
|
29
80
|
virtual bool is_loaded() const = 0;
|
|
30
81
|
|
|
31
|
-
|
|
32
|
-
|
|
82
|
+
/**
|
|
83
|
+
* Load the model and prepare for inference.
|
|
84
|
+
*
|
|
85
|
+
* @return Error::Ok if successful, an error otherwise
|
|
86
|
+
*/
|
|
87
|
+
virtual runtime::Error load() = 0;
|
|
33
88
|
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
89
|
+
/**
|
|
90
|
+
* Generate text based on the provided prompt and generation config.
|
|
91
|
+
*
|
|
92
|
+
* @param prompt The input prompt to generate from
|
|
93
|
+
* @param config Generation configuration parameters
|
|
94
|
+
* @param token_callback Callback function called for each generated token
|
|
95
|
+
* @param stats_callback Callback function for generation statistics
|
|
96
|
+
* @return Error::Ok if successful, an error otherwise
|
|
97
|
+
*/
|
|
98
|
+
virtual runtime::Error
|
|
99
|
+
generate(const std::string &prompt, const GenerationConfig &config,
|
|
100
|
+
std::function<void(const std::string &)> token_callback,
|
|
101
|
+
std::function<void(const Stats &)> stats_callback) = 0;
|
|
41
102
|
|
|
42
|
-
|
|
103
|
+
/**
|
|
104
|
+
* Stop the generation process.
|
|
105
|
+
*/
|
|
43
106
|
virtual void stop() = 0;
|
|
107
|
+
|
|
108
|
+
/**
|
|
109
|
+
* Force remove prefilled tokens and reset KV cache start position
|
|
110
|
+
*
|
|
111
|
+
* This method removes the prefilled tokens from the KV cache and resets the
|
|
112
|
+
* start position to 0.
|
|
113
|
+
*/
|
|
114
|
+
virtual void reset() = 0;
|
|
44
115
|
};
|
|
45
116
|
|
|
46
117
|
} // namespace llm
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
* All rights reserved.
|
|
4
|
+
*
|
|
5
|
+
* This source code is licensed under the BSD-style license found in the
|
|
6
|
+
* LICENSE file in the root directory of this source tree.
|
|
7
|
+
*/
|
|
8
|
+
|
|
9
|
+
/**
|
|
10
|
+
* @file
|
|
11
|
+
*
|
|
12
|
+
* Common includes used by all kernel implementations.
|
|
13
|
+
*/
|
|
14
|
+
|
|
15
|
+
#pragma once
|
|
16
|
+
|
|
17
|
+
// This list should be very conservative since most kernel .cpp files will
|
|
18
|
+
// include these and depend on their transitive deps. Only add a header if 99%
|
|
19
|
+
// of kernels would have included it anyway.
|
|
20
|
+
#include <executorch/runtime/core/exec_aten/exec_aten.h> // IWYU pragma: export
|
|
21
|
+
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h> // IWYU pragma: export
|
|
22
|
+
#include <executorch/runtime/core/exec_aten/util/tensor_util.h> // IWYU pragma: export
|
|
23
|
+
#include <executorch/runtime/kernel/kernel_runtime_context.h> // IWYU pragma: export
|