react-native-executorch 0.5.15 → 0.6.0-nightly-897eae9-20251213
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +42 -36
- package/android/CMakeLists.txt +13 -25
- package/android/build.gradle +2 -3
- package/android/libs/classes.jar +0 -0
- package/android/src/main/cpp/CMakeLists.txt +2 -1
- package/common/rnexecutorch/RnExecutorchInstaller.cpp +18 -0
- package/common/rnexecutorch/TokenizerModule.cpp +3 -3
- package/common/rnexecutorch/data_processing/Numerical.cpp +31 -23
- package/common/rnexecutorch/data_processing/Numerical.h +6 -1
- package/common/rnexecutorch/data_processing/dsp.cpp +0 -46
- package/common/rnexecutorch/host_objects/JsiConversions.h +16 -0
- package/common/rnexecutorch/host_objects/ModelHostObject.h +26 -11
- package/common/rnexecutorch/jsi/OwningArrayBuffer.h +19 -2
- package/common/rnexecutorch/metaprogramming/TypeConcepts.h +0 -20
- package/common/rnexecutorch/models/BaseModel.cpp +12 -11
- package/common/rnexecutorch/models/BaseModel.h +18 -10
- package/common/rnexecutorch/models/embeddings/BaseEmbeddings.cpp +3 -11
- package/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp +0 -1
- package/common/rnexecutorch/models/image_segmentation/ImageSegmentation.cpp +6 -12
- package/common/rnexecutorch/models/llm/LLM.cpp +25 -8
- package/common/rnexecutorch/models/llm/LLM.h +4 -4
- package/common/rnexecutorch/models/ocr/CTCLabelConverter.h +1 -1
- package/common/rnexecutorch/models/ocr/utils/RecognitionHandlerUtils.cpp +7 -4
- package/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp +8 -13
- package/common/rnexecutorch/models/speech_to_text/SpeechToText.h +1 -3
- package/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp +12 -19
- package/common/rnexecutorch/models/speech_to_text/asr/ASR.h +4 -5
- package/common/rnexecutorch/models/text_to_image/Constants.h +9 -0
- package/common/rnexecutorch/models/text_to_image/Decoder.cpp +32 -0
- package/common/rnexecutorch/models/text_to_image/Decoder.h +24 -0
- package/common/rnexecutorch/models/text_to_image/Encoder.cpp +44 -0
- package/common/rnexecutorch/models/text_to_image/Encoder.h +32 -0
- package/common/rnexecutorch/models/text_to_image/Scheduler.cpp +152 -0
- package/common/rnexecutorch/models/text_to_image/Scheduler.h +41 -0
- package/common/rnexecutorch/models/text_to_image/TextToImage.cpp +141 -0
- package/common/rnexecutorch/models/text_to_image/TextToImage.h +64 -0
- package/common/rnexecutorch/models/text_to_image/UNet.cpp +38 -0
- package/common/rnexecutorch/models/text_to_image/UNet.h +28 -0
- package/common/rnexecutorch/models/voice_activity_detection/Constants.h +27 -0
- package/common/rnexecutorch/models/voice_activity_detection/Types.h +12 -0
- package/common/rnexecutorch/models/voice_activity_detection/Utils.cpp +15 -0
- package/common/rnexecutorch/models/voice_activity_detection/Utils.h +13 -0
- package/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.cpp +160 -0
- package/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.h +36 -0
- package/common/rnexecutorch/tests/CMakeLists.txt +30 -0
- package/common/rnexecutorch/tests/NumericalTest.cpp +110 -0
- package/common/rnexecutorch/tests/README.md +30 -13
- package/common/rnexecutorch/threads/GlobalThreadPool.h +4 -0
- package/common/runner/arange_util.cpp +44 -0
- package/common/runner/arange_util.h +37 -0
- package/common/runner/constants.h +28 -0
- package/common/runner/io_manager.h +240 -0
- package/common/runner/irunner.h +87 -16
- package/common/runner/kernel_includes.h +23 -0
- package/common/runner/runner.cpp +151 -66
- package/common/runner/runner.h +39 -22
- package/common/runner/sampler.cpp +8 -1
- package/common/runner/sampler.h +4 -2
- package/common/runner/stats.h +1 -4
- package/common/runner/text_decoder_runner.cpp +26 -12
- package/common/runner/text_decoder_runner.h +52 -31
- package/common/runner/text_prefiller.cpp +46 -12
- package/common/runner/text_prefiller.h +38 -4
- package/common/runner/text_token_generator.h +51 -26
- package/common/runner/util.h +53 -8
- package/ios/RnExecutorch.xcodeproj/project.pbxproj +0 -23
- package/lib/module/Error.js +1 -0
- package/lib/module/Error.js.map +1 -1
- package/lib/module/constants/directories.js +1 -1
- package/lib/module/constants/directories.js.map +1 -1
- package/lib/module/constants/modelUrls.js +32 -1
- package/lib/module/constants/modelUrls.js.map +1 -1
- package/lib/module/constants/ocr/models.js +7 -7
- package/lib/module/constants/ocr/models.js.map +1 -1
- package/lib/module/constants/ocr/symbols.js +3 -2
- package/lib/module/constants/ocr/symbols.js.map +1 -1
- package/lib/module/controllers/LLMController.js +10 -1
- package/lib/module/controllers/LLMController.js.map +1 -1
- package/lib/module/controllers/OCRController.js +3 -3
- package/lib/module/controllers/OCRController.js.map +1 -1
- package/lib/module/controllers/VerticalOCRController.js +2 -2
- package/lib/module/controllers/VerticalOCRController.js.map +1 -1
- package/lib/module/hooks/computer_vision/useOCR.js +3 -3
- package/lib/module/hooks/computer_vision/useOCR.js.map +1 -1
- package/lib/module/hooks/{useNonStaticModule.js → computer_vision/useTextToImage.js} +21 -16
- package/lib/module/hooks/computer_vision/useTextToImage.js.map +1 -0
- package/lib/module/hooks/computer_vision/useVerticalOCR.js +3 -3
- package/lib/module/hooks/computer_vision/useVerticalOCR.js.map +1 -1
- package/lib/module/hooks/natural_language_processing/useLLM.js +3 -3
- package/lib/module/hooks/natural_language_processing/useLLM.js.map +1 -1
- package/lib/module/hooks/natural_language_processing/useTokenizer.js +5 -5
- package/lib/module/hooks/natural_language_processing/useTokenizer.js.map +1 -1
- package/lib/module/hooks/natural_language_processing/useVAD.js +13 -0
- package/lib/module/hooks/natural_language_processing/useVAD.js.map +1 -0
- package/lib/module/index.js +7 -2
- package/lib/module/index.js.map +1 -1
- package/lib/module/modules/computer_vision/OCRModule.js +2 -2
- package/lib/module/modules/computer_vision/OCRModule.js.map +1 -1
- package/lib/module/modules/computer_vision/TextToImageModule.js +48 -0
- package/lib/module/modules/computer_vision/TextToImageModule.js.map +1 -0
- package/lib/module/modules/computer_vision/VerticalOCRModule.js +2 -2
- package/lib/module/modules/computer_vision/VerticalOCRModule.js.map +1 -1
- package/lib/module/modules/natural_language_processing/SpeechToTextModule.js +7 -4
- package/lib/module/modules/natural_language_processing/SpeechToTextModule.js.map +1 -1
- package/lib/module/modules/natural_language_processing/VADModule.js +19 -0
- package/lib/module/modules/natural_language_processing/VADModule.js.map +1 -0
- package/lib/module/types/llm.js.map +1 -1
- package/lib/module/types/vad.js +2 -0
- package/lib/module/types/vad.js.map +1 -0
- package/lib/module/utils/ResourceFetcher.js +2 -1
- package/lib/module/utils/ResourceFetcher.js.map +1 -1
- package/lib/module/utils/ResourceFetcherUtils.js +6 -6
- package/lib/module/utils/ResourceFetcherUtils.js.map +1 -1
- package/lib/typescript/Error.d.ts +1 -0
- package/lib/typescript/Error.d.ts.map +1 -1
- package/lib/typescript/constants/modelUrls.d.ts +23 -0
- package/lib/typescript/constants/modelUrls.d.ts.map +1 -1
- package/lib/typescript/constants/ocr/symbols.d.ts +1 -1
- package/lib/typescript/constants/ocr/symbols.d.ts.map +1 -1
- package/lib/typescript/controllers/LLMController.d.ts.map +1 -1
- package/lib/typescript/controllers/OCRController.d.ts +1 -1
- package/lib/typescript/controllers/OCRController.d.ts.map +1 -1
- package/lib/typescript/controllers/VerticalOCRController.d.ts +1 -1
- package/lib/typescript/controllers/VerticalOCRController.d.ts.map +1 -1
- package/lib/typescript/hooks/computer_vision/useOCR.d.ts +1 -1
- package/lib/typescript/hooks/computer_vision/useOCR.d.ts.map +1 -1
- package/lib/typescript/hooks/computer_vision/useTextToImage.d.ts +22 -0
- package/lib/typescript/hooks/computer_vision/useTextToImage.d.ts.map +1 -0
- package/lib/typescript/hooks/computer_vision/useVerticalOCR.d.ts +1 -1
- package/lib/typescript/hooks/computer_vision/useVerticalOCR.d.ts.map +1 -1
- package/lib/typescript/hooks/natural_language_processing/useLLM.d.ts.map +1 -1
- package/lib/typescript/hooks/natural_language_processing/useSpeechToText.d.ts +2 -2
- package/lib/typescript/hooks/natural_language_processing/useVAD.d.ts +16 -0
- package/lib/typescript/hooks/natural_language_processing/useVAD.d.ts.map +1 -0
- package/lib/typescript/index.d.ts +8 -1
- package/lib/typescript/index.d.ts.map +1 -1
- package/lib/typescript/modules/computer_vision/OCRModule.d.ts +1 -1
- package/lib/typescript/modules/computer_vision/OCRModule.d.ts.map +1 -1
- package/lib/typescript/modules/computer_vision/TextToImageModule.d.ts +16 -0
- package/lib/typescript/modules/computer_vision/TextToImageModule.d.ts.map +1 -0
- package/lib/typescript/modules/computer_vision/VerticalOCRModule.d.ts +1 -1
- package/lib/typescript/modules/computer_vision/VerticalOCRModule.d.ts.map +1 -1
- package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts +3 -2
- package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts.map +1 -1
- package/lib/typescript/modules/natural_language_processing/VADModule.d.ts +10 -0
- package/lib/typescript/modules/natural_language_processing/VADModule.d.ts.map +1 -0
- package/lib/typescript/types/llm.d.ts +2 -0
- package/lib/typescript/types/llm.d.ts.map +1 -1
- package/lib/typescript/types/vad.d.ts +5 -0
- package/lib/typescript/types/vad.d.ts.map +1 -0
- package/lib/typescript/utils/ResourceFetcher.d.ts +29 -0
- package/lib/typescript/utils/ResourceFetcher.d.ts.map +1 -1
- package/lib/typescript/utils/ResourceFetcherUtils.d.ts +2 -2
- package/lib/typescript/utils/ResourceFetcherUtils.d.ts.map +1 -1
- package/package.json +11 -8
- package/react-native-executorch.podspec +9 -9
- package/src/Error.ts +1 -0
- package/src/constants/directories.ts +1 -1
- package/src/constants/modelUrls.ts +36 -1
- package/src/constants/ocr/models.ts +7 -7
- package/src/constants/ocr/symbols.ts +3 -2
- package/src/controllers/LLMController.ts +12 -1
- package/src/controllers/OCRController.ts +3 -3
- package/src/controllers/VerticalOCRController.ts +2 -2
- package/src/hooks/computer_vision/useOCR.ts +4 -5
- package/src/hooks/computer_vision/useTextToImage.ts +92 -0
- package/src/hooks/computer_vision/useVerticalOCR.ts +4 -5
- package/src/hooks/natural_language_processing/useLLM.ts +3 -4
- package/src/hooks/natural_language_processing/useTokenizer.ts +5 -5
- package/src/hooks/natural_language_processing/useVAD.ts +15 -0
- package/src/index.ts +20 -1
- package/src/modules/computer_vision/OCRModule.ts +2 -2
- package/src/modules/computer_vision/TextToImageModule.ts +93 -0
- package/src/modules/computer_vision/VerticalOCRModule.ts +2 -2
- package/src/modules/natural_language_processing/SpeechToTextModule.ts +8 -4
- package/src/modules/natural_language_processing/VADModule.ts +27 -0
- package/src/types/llm.ts +2 -0
- package/src/types/vad.ts +4 -0
- package/src/utils/ResourceFetcher.ts +2 -1
- package/src/utils/ResourceFetcherUtils.ts +8 -8
- package/third-party/android/libs/cpuinfo/arm64-v8a/libcpuinfo.so +0 -0
- package/third-party/android/libs/executorch/arm64-v8a/libexecutorch.so +0 -0
- package/third-party/android/libs/executorch/x86_64/libexecutorch.so +0 -0
- package/third-party/android/libs/pthreadpool/arm64-v8a/libpthreadpool.so +0 -0
- package/third-party/include/c10/macros/Export.h +0 -78
- package/third-party/include/c10/macros/Macros.h +1 -520
- package/third-party/include/c10/util/BFloat16-inl.h +1 -339
- package/third-party/include/c10/util/BFloat16.h +1 -122
- package/third-party/include/c10/util/Half-inl.h +1 -347
- package/third-party/include/c10/util/Half.h +6 -419
- package/third-party/include/c10/util/TypeSafeSignMath.h +1 -133
- package/third-party/include/c10/util/bit_cast.h +1 -43
- package/third-party/include/c10/util/complex.h +1 -568
- package/third-party/include/c10/util/floating_point_utils.h +1 -33
- package/third-party/include/c10/util/irange.h +1 -1
- package/third-party/include/c10/util/llvmMathExtras.h +866 -0
- package/third-party/include/c10/util/safe_numerics.h +97 -0
- package/third-party/include/executorch/ExecuTorchError.h +6 -7
- package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLM.h +12 -0
- package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMConfig.h +56 -0
- package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMError.h +16 -0
- package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMMultimodalRunner.h +227 -0
- package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMTextRunner.h +97 -0
- package/third-party/include/executorch/ExecuTorchLLM/module.modulemap +4 -0
- package/third-party/include/executorch/ExecuTorchLog.h +1 -0
- package/third-party/include/executorch/ExecuTorchModule.h +177 -4
- package/third-party/include/executorch/ExecuTorchTensor.h +3 -4
- package/third-party/include/executorch/ExecuTorchValue.h +1 -7
- package/third-party/include/executorch/extension/module/module.h +139 -8
- package/third-party/include/executorch/extension/tensor/tensor.h +1 -0
- package/third-party/include/executorch/extension/tensor/tensor_ptr.h +88 -26
- package/third-party/include/executorch/extension/threadpool/threadpool.h +4 -1
- package/third-party/include/executorch/runtime/backend/backend_init_context.h +6 -0
- package/third-party/include/executorch/runtime/backend/interface.h +1 -1
- package/third-party/include/executorch/runtime/core/error.h +76 -49
- package/third-party/include/executorch/runtime/core/exec_aten/util/scalar_type_util.h +18 -4
- package/third-party/include/executorch/runtime/core/memory_allocator.h +12 -2
- package/third-party/include/executorch/runtime/core/named_data_map.h +1 -11
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/macros/Export.h +0 -78
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/macros/Macros.h +1 -520
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/BFloat16-inl.h +1 -339
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/BFloat16.h +1 -122
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/Half-inl.h +1 -347
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/Half.h +6 -419
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/TypeSafeSignMath.h +1 -133
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/bit_cast.h +1 -43
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/complex.h +1 -568
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/floating_point_utils.h +1 -33
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/irange.h +1 -1
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/llvmMathExtras.h +866 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/safe_numerics.h +97 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/macros/Export.h +66 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/macros/Macros.h +553 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/BFloat16.h +477 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/Half.h +781 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/TypeSafeSignMath.h +141 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/bit_cast.h +49 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/complex.h +593 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/floating_point_utils.h +38 -0
- package/third-party/include/executorch/runtime/core/tensor_layout.h +1 -1
- package/third-party/include/executorch/runtime/executor/merged_data_map.h +142 -0
- package/third-party/include/executorch/runtime/executor/method.h +21 -8
- package/third-party/include/executorch/runtime/executor/method_meta.h +20 -2
- package/third-party/include/executorch/runtime/executor/program.h +0 -10
- package/third-party/include/executorch/runtime/kernel/operator_registry.h +1 -1
- package/third-party/include/executorch/runtime/platform/compiler.h +2 -0
- package/third-party/include/executorch/schema/extended_header.h +10 -1
- package/third-party/include/torch/headeronly/macros/Export.h +66 -0
- package/third-party/include/torch/headeronly/macros/Macros.h +553 -0
- package/third-party/include/torch/headeronly/util/BFloat16.h +477 -0
- package/third-party/include/torch/headeronly/util/Half.h +781 -0
- package/third-party/include/torch/headeronly/util/TypeSafeSignMath.h +141 -0
- package/third-party/include/torch/headeronly/util/bit_cast.h +49 -0
- package/third-party/include/torch/headeronly/util/complex.h +593 -0
- package/third-party/include/torch/headeronly/util/floating_point_utils.h +38 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/ExecutorchLib +0 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/Info.plist +0 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/ExecutorchLib +0 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/Info.plist +0 -0
- package/common/rnexecutorch/tests/run_all_tests.sh +0 -14
- package/common/rnexecutorch/tests/run_test.sh +0 -18
- package/ios/RnExecutorch/utils/Conversions.h +0 -14
- package/ios/RnExecutorch/utils/ETError.h +0 -26
- package/ios/RnExecutorch/utils/ImageProcessor.h +0 -15
- package/ios/RnExecutorch/utils/ImageProcessor.mm +0 -147
- package/ios/RnExecutorch/utils/Numerical.h +0 -3
- package/ios/RnExecutorch/utils/Numerical.mm +0 -18
- package/ios/RnExecutorch/utils/ScalarType.h +0 -14
- package/ios/RnExecutorch/utils/ScalarType.mm +0 -21
- package/lib/module/hooks/useNonStaticModule.js.map +0 -1
- package/lib/typescript/hooks/useNonStaticModule.d.ts +0 -21
- package/lib/typescript/hooks/useNonStaticModule.d.ts.map +0 -1
- package/src/hooks/useNonStaticModule.ts +0 -74
- package/third-party/include/executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h +0 -181
- package/third-party/include/executorch/extension/kernel_util/meta_programming.h +0 -108
- package/third-party/include/executorch/extension/kernel_util/type_list.h +0 -137
- package/third-party/include/executorch/extension/threadpool/threadpool_guard.h +0 -35
package/common/runner/runner.cpp
CHANGED
|
@@ -4,6 +4,7 @@
|
|
|
4
4
|
*
|
|
5
5
|
* This source code is licensed under the BSD-style license found in the
|
|
6
6
|
* LICENSE file in the root directory of this source tree.
|
|
7
|
+
* @lint-ignore-every CLANGTIDY facebook-hte-Deprecated
|
|
7
8
|
*/
|
|
8
9
|
|
|
9
10
|
// A simple llama2 runner that includes preprocessing and post processing logic.
|
|
@@ -21,8 +22,6 @@ using ::executorch::extension::Module;
|
|
|
21
22
|
using ::executorch::runtime::Error;
|
|
22
23
|
using ::executorch::runtime::Result;
|
|
23
24
|
|
|
24
|
-
namespace llm = ::executorch::extension::llm;
|
|
25
|
-
|
|
26
25
|
std::string loadBytesFromFile(const std::string &path) {
|
|
27
26
|
std::ifstream fs(path, std::ios::in | std::ios::binary);
|
|
28
27
|
if (fs.fail()) {
|
|
@@ -39,7 +38,6 @@ std::string loadBytesFromFile(const std::string &path) {
|
|
|
39
38
|
|
|
40
39
|
namespace {
|
|
41
40
|
static constexpr auto kEnableDynamicShape = "enable_dynamic_shape";
|
|
42
|
-
static constexpr auto kBosId = "get_bos_id";
|
|
43
41
|
static constexpr auto kEosIds = "get_eos_ids";
|
|
44
42
|
static constexpr auto kMaxSeqLen = "get_max_seq_len";
|
|
45
43
|
static constexpr auto kMaxContextLen = "get_max_context_len";
|
|
@@ -48,29 +46,16 @@ static constexpr auto kUseKVCache = "use_kv_cache";
|
|
|
48
46
|
static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
|
|
49
47
|
} // namespace
|
|
50
48
|
|
|
51
|
-
Runner::Runner(
|
|
52
|
-
const
|
|
53
|
-
|
|
54
|
-
// NOTE: we observed ~2x loading performance increase on iPhone 15
|
|
55
|
-
// and a ~5% improvement on Galaxy S22 by switching to
|
|
56
|
-
// FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
|
|
57
|
-
: temperature_(temperature), tokenizer_path_(tokenizer_path),
|
|
49
|
+
Runner::Runner(Module *module, const std::string &tokenizer_path,
|
|
50
|
+
const llm::GenerationConfig &config)
|
|
51
|
+
: config_(config), module_(module), tokenizer_path_(tokenizer_path),
|
|
58
52
|
metadata_({
|
|
59
53
|
{kEnableDynamicShape, false},
|
|
60
54
|
{kMaxSeqLen, 128},
|
|
61
55
|
{kMaxContextLen, 128},
|
|
62
56
|
{kUseKVCache, true},
|
|
63
57
|
{kUseSDPAWithKVCache, false},
|
|
64
|
-
}) {
|
|
65
|
-
if (data_path.has_value()) {
|
|
66
|
-
module_ = std::make_unique<Module>(model_path, data_path.value(),
|
|
67
|
-
Module::LoadMode::File);
|
|
68
|
-
} else {
|
|
69
|
-
module_ = std::make_unique<Module>(model_path, Module::LoadMode::File);
|
|
70
|
-
}
|
|
71
|
-
ET_LOG(Info, "Creating LLaMa runner: model_path=%s, tokenizer_path=%s",
|
|
72
|
-
model_path.c_str(), tokenizer_path.c_str());
|
|
73
|
-
}
|
|
58
|
+
}) {}
|
|
74
59
|
|
|
75
60
|
bool Runner::is_loaded() const {
|
|
76
61
|
return module_->is_loaded() && tokenizer_ && text_decoder_runner_ &&
|
|
@@ -81,9 +66,10 @@ Error Runner::load() {
|
|
|
81
66
|
if (is_loaded()) {
|
|
82
67
|
return Error::Ok;
|
|
83
68
|
}
|
|
69
|
+
|
|
84
70
|
ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));
|
|
85
|
-
// load tokenizer.
|
|
86
71
|
|
|
72
|
+
// Load tokenizer.
|
|
87
73
|
auto blob = loadBytesFromFile(tokenizer_path_);
|
|
88
74
|
tokenizer_ = tokenizers::Tokenizer::FromBlobJSON(blob);
|
|
89
75
|
|
|
@@ -92,9 +78,9 @@ Error Runner::load() {
|
|
|
92
78
|
auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>();
|
|
93
79
|
metadata_[kVocabSize] = tokenizer_->GetVocabSize();
|
|
94
80
|
|
|
81
|
+
// Load model metadata
|
|
95
82
|
const auto method_names =
|
|
96
83
|
ET_UNWRAP(module_->method_names(), "Failed reading method names");
|
|
97
|
-
|
|
98
84
|
for (auto &pair : metadata_) {
|
|
99
85
|
const auto &method_name = pair.first;
|
|
100
86
|
auto &value = pair.second;
|
|
@@ -103,11 +89,13 @@ Error Runner::load() {
|
|
|
103
89
|
.toScalar()
|
|
104
90
|
.to<decltype(metadata_)::mapped_type>();
|
|
105
91
|
} else {
|
|
106
|
-
ET_LOG(Info, "
|
|
92
|
+
ET_LOG(Info, "Method %s not found, using the default value %" PRId64,
|
|
107
93
|
method_name.c_str(), value);
|
|
108
94
|
}
|
|
109
95
|
ET_LOG(Info, "Metadata: %s = %" PRId64, method_name.c_str(), value);
|
|
110
96
|
}
|
|
97
|
+
|
|
98
|
+
// Load EOS token ids
|
|
111
99
|
if (method_names.count(kEosIds)) {
|
|
112
100
|
eos_ids->clear();
|
|
113
101
|
for (const auto &eos_id : ET_UNWRAP(module_->execute(kEosIds))) {
|
|
@@ -116,15 +104,34 @@ Error Runner::load() {
|
|
|
116
104
|
ET_LOG(Info, "eos_id = %" PRId64, value);
|
|
117
105
|
}
|
|
118
106
|
}
|
|
107
|
+
|
|
108
|
+
// Determine missing config values
|
|
109
|
+
// If user does not directly specify configuration parameters such as
|
|
110
|
+
// max_seq_len (i.e. leaves them as default values), they are determined by
|
|
111
|
+
// reading the exported model's methods.
|
|
112
|
+
if (config_.max_seq_len < 0)
|
|
113
|
+
config_.max_seq_len = static_cast<int32_t>(metadata_.at(kMaxSeqLen));
|
|
114
|
+
if (config_.max_context_length < 0)
|
|
115
|
+
config_.max_context_length =
|
|
116
|
+
static_cast<int32_t>(metadata_.at(kMaxContextLen));
|
|
117
|
+
if (config_.max_new_tokens < 0)
|
|
118
|
+
config_.max_new_tokens =
|
|
119
|
+
std::min(config_.max_seq_len, config_.max_context_length);
|
|
120
|
+
if (config_.enable_dynamic_shape)
|
|
121
|
+
config_.enable_dynamic_shape =
|
|
122
|
+
static_cast<bool>(metadata_.at(kEnableDynamicShape));
|
|
123
|
+
if (config_.enable_kv_cache)
|
|
124
|
+
config_.enable_kv_cache = static_cast<bool>(metadata_.at(kUseKVCache));
|
|
125
|
+
|
|
126
|
+
io_manager_ = std::make_unique<llm::IOManager>(*module_);
|
|
119
127
|
text_decoder_runner_ = std::make_unique<llm::TextDecoderRunner>(
|
|
120
|
-
module_.get(),
|
|
121
|
-
temperature_);
|
|
128
|
+
module_, io_manager_.get(), config_.temperature, config_.topp);
|
|
122
129
|
text_prefiller_ = std::make_unique<llm::TextPrefiller>(
|
|
123
|
-
text_decoder_runner_.get(),
|
|
124
|
-
|
|
130
|
+
text_decoder_runner_.get(), config_.enable_kv_cache,
|
|
131
|
+
config_.enable_dynamic_shape, config_.max_seq_len);
|
|
125
132
|
|
|
126
133
|
text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
|
|
127
|
-
tokenizer_.get(), text_decoder_runner_.get(),
|
|
134
|
+
tokenizer_.get(), text_decoder_runner_.get(), config_.enable_kv_cache,
|
|
128
135
|
std::move(eos_ids), &stats_);
|
|
129
136
|
|
|
130
137
|
return Error::Ok;
|
|
@@ -139,9 +146,9 @@ Error Runner::load() {
|
|
|
139
146
|
}
|
|
140
147
|
|
|
141
148
|
Error Runner::generate(const std::string &prompt,
|
|
149
|
+
const llm::GenerationConfig &generation_config,
|
|
142
150
|
std::function<void(const std::string &)> token_callback,
|
|
143
|
-
std::function<void(const llm::Stats &)> stats_callback
|
|
144
|
-
bool echo, bool warmup) {
|
|
151
|
+
std::function<void(const llm::Stats &)> stats_callback) {
|
|
145
152
|
// Prepare the inputs.
|
|
146
153
|
// Use ones-initialized inputs.
|
|
147
154
|
ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
|
|
@@ -151,17 +158,18 @@ Error Runner::generate(const std::string &prompt,
|
|
|
151
158
|
stats_.model_load_end_ms = llm::time_in_ms();
|
|
152
159
|
}
|
|
153
160
|
|
|
154
|
-
if (
|
|
161
|
+
if (generation_config.warming) {
|
|
155
162
|
ET_LOG(Info, "Doing a warmup run...");
|
|
156
163
|
}
|
|
157
164
|
|
|
158
|
-
RUNNER_ET_LOG(
|
|
165
|
+
RUNNER_ET_LOG(generation_config.warming,
|
|
166
|
+
"RSS after loading model: %f MiB (0 if unsupported)",
|
|
159
167
|
llm::get_rss_bytes() / 1024.0 / 1024.0);
|
|
160
168
|
|
|
161
169
|
// Wrap the token_callback with print function
|
|
162
170
|
std::function<void(const std::string &)> wrapped_callback =
|
|
163
|
-
[token_callback,
|
|
164
|
-
if (!
|
|
171
|
+
[token_callback, &generation_config](const std::string &piece) {
|
|
172
|
+
if (!generation_config.warming) {
|
|
165
173
|
llm::safe_printf(piece.c_str());
|
|
166
174
|
fflush(stdout);
|
|
167
175
|
}
|
|
@@ -175,10 +183,23 @@ Error Runner::generate(const std::string &prompt,
|
|
|
175
183
|
stats_.inference_start_ms = llm::time_in_ms();
|
|
176
184
|
shouldStop_ = false;
|
|
177
185
|
|
|
178
|
-
//
|
|
179
|
-
int32_t
|
|
180
|
-
|
|
181
|
-
|
|
186
|
+
// Override main config fields with given generation config if specified
|
|
187
|
+
int32_t max_seq_len = generation_config.max_seq_len >= 0
|
|
188
|
+
? generation_config.max_seq_len
|
|
189
|
+
: config_.max_seq_len;
|
|
190
|
+
int32_t max_context_length = generation_config.max_context_length >= 0
|
|
191
|
+
? generation_config.max_context_length
|
|
192
|
+
: config_.max_context_length;
|
|
193
|
+
int32_t new_tokens_limit = generation_config.max_new_tokens >= 0
|
|
194
|
+
? generation_config.max_new_tokens
|
|
195
|
+
: config_.max_new_tokens;
|
|
196
|
+
float temperature = generation_config.temperature >= 0.F
|
|
197
|
+
? generation_config.temperature
|
|
198
|
+
: config_.temperature;
|
|
199
|
+
float topp =
|
|
200
|
+
generation_config.topp >= 0.F ? generation_config.topp : config_.topp;
|
|
201
|
+
|
|
202
|
+
int64_t context_len_left = static_cast<int64_t>(max_context_length) - pos_;
|
|
182
203
|
|
|
183
204
|
std::vector<int32_t> prompt_tokens = tokenizer_->Encode(prompt);
|
|
184
205
|
std::vector<uint64_t> prompt_tokens_uint64(prompt_tokens.begin(),
|
|
@@ -187,30 +208,38 @@ Error Runner::generate(const std::string &prompt,
|
|
|
187
208
|
// encode the (string) prompt into tokens sequence
|
|
188
209
|
int num_prompt_tokens = prompt_tokens.size();
|
|
189
210
|
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
211
|
+
ET_CHECK_OR_RETURN_ERROR(num_prompt_tokens >= 1, InvalidArgument,
|
|
212
|
+
"Expected at least 1 prompt token");
|
|
213
|
+
ET_CHECK_OR_RETURN_ERROR(num_prompt_tokens < max_seq_len, InvalidArgument,
|
|
214
|
+
"num_prompt_tokens %d >= max_context_len %" PRId32
|
|
215
|
+
", Max seq length exceeded - please increase max "
|
|
216
|
+
"seq len value in your export script",
|
|
217
|
+
num_prompt_tokens, max_seq_len);
|
|
218
|
+
|
|
219
|
+
// Determine max_new_tokens using the GenerationConfig's resolve method,
|
|
220
|
+
// then subtract pos_ for max_new_tokens.
|
|
221
|
+
int32_t max_new_tokens = resolve_max_new_tokens(
|
|
222
|
+
num_prompt_tokens, max_seq_len, static_cast<int32_t>(context_len_left),
|
|
223
|
+
new_tokens_limit);
|
|
224
|
+
|
|
225
|
+
ET_LOG(Info,
|
|
226
|
+
"Max new tokens resolved: %d, given pos_ %" PRId64
|
|
227
|
+
", num_prompt_tokens %zu, max_context_len %" PRId64,
|
|
228
|
+
max_new_tokens, pos_, prompt_tokens.size(),
|
|
229
|
+
static_cast<int64_t>(max_context_length));
|
|
230
|
+
ET_CHECK_OR_RETURN_ERROR(max_new_tokens > 0, InvalidArgument,
|
|
231
|
+
"Max new tokens %d is less than or equal to 0",
|
|
232
|
+
max_new_tokens);
|
|
203
233
|
|
|
204
234
|
// Prefill first
|
|
205
235
|
// Here feed all tokens to the model and get the next predicted token
|
|
206
236
|
// after the prompt. After that we will enter generate loop.
|
|
207
237
|
|
|
208
238
|
// print prompts
|
|
209
|
-
if (echo) {
|
|
239
|
+
if (generation_config.echo) {
|
|
210
240
|
wrapped_callback(prompt);
|
|
211
241
|
}
|
|
212
|
-
|
|
213
|
-
auto prefill_res = text_prefiller_->prefill(prompt_tokens_uint64, pos);
|
|
242
|
+
auto prefill_res = text_prefiller_->prefill(prompt_tokens_uint64, pos_);
|
|
214
243
|
stats_.first_token_ms = llm::time_in_ms();
|
|
215
244
|
stats_.prompt_eval_end_ms = llm::time_in_ms();
|
|
216
245
|
ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
|
|
@@ -219,30 +248,36 @@ Error Runner::generate(const std::string &prompt,
|
|
|
219
248
|
// print the first token from prefill. No prev_token so use cur_token for it.
|
|
220
249
|
const std::string cur_decoded =
|
|
221
250
|
tokenizer_->Decode(std::vector<int32_t>{static_cast<int32_t>(cur_token)});
|
|
222
|
-
RUNNER_ET_LOG(
|
|
251
|
+
RUNNER_ET_LOG(generation_config.warming,
|
|
252
|
+
"RSS after prompt prefill: %f MiB (0 if unsupported)",
|
|
223
253
|
llm::get_rss_bytes() / 1024.0 / 1024.0);
|
|
224
254
|
|
|
225
255
|
// start the main loop
|
|
226
256
|
prompt_tokens_uint64.push_back(cur_token);
|
|
227
257
|
int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate(
|
|
228
|
-
prompt_tokens_uint64,
|
|
258
|
+
prompt_tokens_uint64, pos_, max_new_tokens - 1, temperature, topp,
|
|
259
|
+
wrapped_callback));
|
|
260
|
+
|
|
261
|
+
pos_ += num_generated_tokens;
|
|
229
262
|
|
|
230
263
|
stats_.inference_end_ms = llm::time_in_ms();
|
|
231
|
-
if (!
|
|
264
|
+
if (!generation_config.warming) {
|
|
232
265
|
printf("\n");
|
|
233
266
|
}
|
|
234
267
|
RUNNER_ET_LOG(
|
|
235
|
-
|
|
268
|
+
generation_config.warming,
|
|
269
|
+
"RSS after finishing text generation: %f MiB (0 if unsupported)",
|
|
236
270
|
llm::get_rss_bytes() / 1024.0 / 1024.0);
|
|
237
271
|
|
|
238
|
-
if (
|
|
239
|
-
RUNNER_ET_LOG(
|
|
272
|
+
if (num_generated_tokens == max_new_tokens) {
|
|
273
|
+
RUNNER_ET_LOG(generation_config.warming, "Max new tokens %i reached!",
|
|
274
|
+
max_new_tokens);
|
|
240
275
|
}
|
|
241
276
|
|
|
242
277
|
stats_.num_prompt_tokens = num_prompt_tokens;
|
|
243
278
|
stats_.num_generated_tokens = num_generated_tokens;
|
|
244
279
|
|
|
245
|
-
if (
|
|
280
|
+
if (generation_config.warming) {
|
|
246
281
|
ET_LOG(Info, "Warmup run finished!");
|
|
247
282
|
} else {
|
|
248
283
|
// Do not print report during warmup
|
|
@@ -256,12 +291,17 @@ Error Runner::generate(const std::string &prompt,
|
|
|
256
291
|
}
|
|
257
292
|
|
|
258
293
|
Error Runner::warmup(const std::string &prompt) {
|
|
259
|
-
|
|
294
|
+
// Create a GenerationConfig for warmup
|
|
295
|
+
llm::GenerationConfig config{.echo = false, .warming = true};
|
|
296
|
+
|
|
297
|
+
// Call generate with the warmup config
|
|
298
|
+
Error err = generate(prompt, config,
|
|
260
299
|
/*token_callback=*/nullptr,
|
|
261
|
-
/*stats_callbak=*/nullptr
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
300
|
+
/*stats_callbak=*/nullptr);
|
|
301
|
+
|
|
302
|
+
// Reset stats after warmup
|
|
303
|
+
reset();
|
|
304
|
+
|
|
265
305
|
return err;
|
|
266
306
|
}
|
|
267
307
|
|
|
@@ -273,6 +313,11 @@ void Runner::stop() {
|
|
|
273
313
|
}
|
|
274
314
|
}
|
|
275
315
|
|
|
316
|
+
void Runner::reset() {
|
|
317
|
+
stats_.reset();
|
|
318
|
+
pos_ = 0;
|
|
319
|
+
}
|
|
320
|
+
|
|
276
321
|
void Runner::set_count_interval(size_t count_interval) {
|
|
277
322
|
text_token_generator_->set_count_interval(count_interval);
|
|
278
323
|
}
|
|
@@ -281,4 +326,44 @@ void Runner::set_time_interval(size_t time_interval) {
|
|
|
281
326
|
text_token_generator_->set_time_interval(time_interval);
|
|
282
327
|
}
|
|
283
328
|
|
|
329
|
+
void Runner::set_temperature(float temperature) noexcept {
|
|
330
|
+
config_.temperature = temperature;
|
|
331
|
+
if (text_decoder_runner_) {
|
|
332
|
+
text_decoder_runner_->set_temperature(temperature);
|
|
333
|
+
}
|
|
334
|
+
}
|
|
335
|
+
|
|
336
|
+
void Runner::set_topp(float topp) noexcept {
|
|
337
|
+
config_.topp = topp;
|
|
338
|
+
if (text_decoder_runner_) {
|
|
339
|
+
text_decoder_runner_->set_topp(topp);
|
|
340
|
+
}
|
|
341
|
+
}
|
|
342
|
+
|
|
343
|
+
int32_t Runner::resolve_max_new_tokens(int32_t num_prompt_tokens,
|
|
344
|
+
int32_t max_seq_len,
|
|
345
|
+
int32_t max_context_len,
|
|
346
|
+
int32_t max_new_tokens) const {
|
|
347
|
+
int32_t result;
|
|
348
|
+
|
|
349
|
+
if (max_seq_len == -1 && max_new_tokens == -1) {
|
|
350
|
+
// Both are -1, use max context len minus prompt tokens
|
|
351
|
+
result = max_context_len - num_prompt_tokens;
|
|
352
|
+
} else if (max_seq_len == -1 && max_new_tokens != -1) {
|
|
353
|
+
// Only max_new_tokens is specified
|
|
354
|
+
result = std::min(max_new_tokens, max_context_len - num_prompt_tokens);
|
|
355
|
+
} else if (max_seq_len != -1 && max_new_tokens == -1) {
|
|
356
|
+
// Only seq_len is specified
|
|
357
|
+
result = std::min(max_seq_len, max_context_len) - num_prompt_tokens;
|
|
358
|
+
} else {
|
|
359
|
+
// Both are specified
|
|
360
|
+
result =
|
|
361
|
+
std::min(std::min(max_seq_len, max_context_len) - num_prompt_tokens,
|
|
362
|
+
max_new_tokens);
|
|
363
|
+
}
|
|
364
|
+
|
|
365
|
+
// Ensure result is not negative
|
|
366
|
+
return std::max(0, result);
|
|
367
|
+
}
|
|
368
|
+
|
|
284
369
|
} // namespace example
|
package/common/runner/runner.h
CHANGED
|
@@ -27,42 +27,59 @@
|
|
|
27
27
|
|
|
28
28
|
namespace example {
|
|
29
29
|
|
|
30
|
-
|
|
30
|
+
namespace llm = ::executorch::extension::llm;
|
|
31
|
+
|
|
32
|
+
class Runner : public llm::IRunner {
|
|
31
33
|
public:
|
|
32
|
-
explicit Runner(
|
|
34
|
+
explicit Runner(::executorch::extension::Module *module,
|
|
33
35
|
const std::string &tokenizer_path,
|
|
34
|
-
const
|
|
35
|
-
|
|
36
|
+
const llm::GenerationConfig &config = {
|
|
37
|
+
.temperature = 0.8F, .topp = 0.9F}); // The main config
|
|
36
38
|
|
|
37
|
-
bool is_loaded() const;
|
|
38
|
-
::executorch::runtime::Error load();
|
|
39
|
-
::executorch::runtime::Error
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
39
|
+
bool is_loaded() const override;
|
|
40
|
+
::executorch::runtime::Error load() override;
|
|
41
|
+
::executorch::runtime::Error generate(
|
|
42
|
+
const std::string &prompt,
|
|
43
|
+
const llm::GenerationConfig &generation_config =
|
|
44
|
+
{}, // An extra config which temporarily overrides previous model
|
|
45
|
+
// settings
|
|
46
|
+
std::function<void(const std::string &)> token_callback = {},
|
|
47
|
+
std::function<void(const llm::Stats &)> stats_callback = {}) override;
|
|
45
48
|
::executorch::runtime::Error warmup(const std::string &prompt);
|
|
46
49
|
void set_count_interval(size_t count_interval);
|
|
47
50
|
void set_time_interval(size_t time_interval);
|
|
48
|
-
void
|
|
51
|
+
void set_temperature(float temperature) noexcept;
|
|
52
|
+
void set_topp(float topp) noexcept;
|
|
53
|
+
|
|
54
|
+
void stop() override;
|
|
55
|
+
void reset() override;
|
|
49
56
|
|
|
50
|
-
|
|
57
|
+
llm::Stats stats_;
|
|
51
58
|
|
|
52
59
|
private:
|
|
53
|
-
|
|
60
|
+
// Helper functions
|
|
61
|
+
int32_t resolve_max_new_tokens(int32_t num_prompt_tokens, int32_t max_seq_len,
|
|
62
|
+
int32_t max_context_len,
|
|
63
|
+
int32_t max_new_tokens = -1) const;
|
|
64
|
+
|
|
65
|
+
// Main config
|
|
66
|
+
llm::GenerationConfig config_;
|
|
67
|
+
|
|
68
|
+
// Flow control
|
|
54
69
|
bool shouldStop_{false};
|
|
70
|
+
int64_t pos_ = 0; // The position in KV cache of the input, starting from 0.
|
|
71
|
+
|
|
72
|
+
// Main model
|
|
73
|
+
::executorch::extension::Module *module_;
|
|
55
74
|
|
|
56
|
-
//
|
|
57
|
-
std::unique_ptr<::executorch::extension::Module> module_;
|
|
75
|
+
// Subcomponents
|
|
58
76
|
std::string tokenizer_path_;
|
|
59
77
|
std::unique_ptr<tokenizers::Tokenizer> tokenizer_;
|
|
60
78
|
std::unordered_map<std::string, int64_t> metadata_;
|
|
61
|
-
std::unique_ptr
|
|
62
|
-
|
|
63
|
-
std::unique_ptr
|
|
64
|
-
std::unique_ptr
|
|
65
|
-
text_token_generator_;
|
|
79
|
+
std::unique_ptr<llm::IOManager> io_manager_;
|
|
80
|
+
std::unique_ptr<llm::TextDecoderRunner> text_decoder_runner_;
|
|
81
|
+
std::unique_ptr<llm::TextPrefiller> text_prefiller_;
|
|
82
|
+
std::unique_ptr<llm::TextTokenGenerator> text_token_generator_;
|
|
66
83
|
};
|
|
67
84
|
|
|
68
85
|
} // namespace example
|
|
@@ -34,6 +34,7 @@
|
|
|
34
34
|
|
|
35
35
|
#include "sampler.h"
|
|
36
36
|
#include <algorithm>
|
|
37
|
+
#include <ctime>
|
|
37
38
|
|
|
38
39
|
namespace executorch {
|
|
39
40
|
namespace extension {
|
|
@@ -121,9 +122,14 @@ int32_t Sampler::sample_topp(T *probabilities, float coin) {
|
|
|
121
122
|
Sampler::Sampler(int vocab_size, float temperature, float topp,
|
|
122
123
|
unsigned long long rng_seed)
|
|
123
124
|
: vocab_size_(vocab_size),
|
|
124
|
-
inv_temperature_(
|
|
125
|
+
inv_temperature_((temperature != 0.0f) ? (1.0f / temperature) : 0.0f),
|
|
125
126
|
topp_(topp), rng_state_(rng_seed) {}
|
|
126
127
|
|
|
128
|
+
Sampler::Sampler(int vocab_size, float temperature, float topp)
|
|
129
|
+
: vocab_size_(vocab_size),
|
|
130
|
+
inv_temperature_((temperature != 0.0f) ? (1.0f / temperature) : 0.0f),
|
|
131
|
+
topp_(topp), rng_state_(std::time(nullptr)) {}
|
|
132
|
+
|
|
127
133
|
template <typename T> static void softmax(T *x, int size) {
|
|
128
134
|
// find max value (for numerical stability)
|
|
129
135
|
T max_val = x[0];
|
|
@@ -184,6 +190,7 @@ template <typename T> int32_t Sampler::sample(T *logits) {
|
|
|
184
190
|
}
|
|
185
191
|
|
|
186
192
|
template int32_t Sampler::sample<float>(float *logits);
|
|
193
|
+
template int32_t Sampler::sample<uint16_t>(uint16_t *logits);
|
|
187
194
|
template int32_t
|
|
188
195
|
Sampler::sample<executorch::aten::Half>(executorch::aten::Half *logits);
|
|
189
196
|
template int32_t
|
package/common/runner/sampler.h
CHANGED
|
@@ -26,16 +26,18 @@ namespace extension {
|
|
|
26
26
|
namespace llm {
|
|
27
27
|
// A simple llama2 sampler.
|
|
28
28
|
|
|
29
|
-
template <typename T> struct
|
|
29
|
+
template <typename T> struct ProbIndex {
|
|
30
30
|
T prob;
|
|
31
31
|
int32_t index;
|
|
32
32
|
}; // struct used when sorting probabilities during top-p sampling
|
|
33
33
|
|
|
34
|
-
class
|
|
34
|
+
class Sampler {
|
|
35
35
|
public:
|
|
36
36
|
Sampler(int32_t vocab_size, float temperature, float topp,
|
|
37
37
|
unsigned long long rng_seed);
|
|
38
38
|
|
|
39
|
+
Sampler(int32_t vocab_size, float temperature, float topp);
|
|
40
|
+
|
|
39
41
|
template <typename T> int32_t sample(T *logits);
|
|
40
42
|
|
|
41
43
|
private:
|
package/common/runner/stats.h
CHANGED
|
@@ -18,7 +18,7 @@ namespace executorch {
|
|
|
18
18
|
namespace extension {
|
|
19
19
|
namespace llm {
|
|
20
20
|
|
|
21
|
-
struct
|
|
21
|
+
struct Stats {
|
|
22
22
|
// Scaling factor for timestamps - in this case, we use ms.
|
|
23
23
|
const long SCALING_FACTOR_UNITS_PER_SECOND = 1000;
|
|
24
24
|
// Time stamps for the different stages of the execution
|
|
@@ -82,8 +82,6 @@ private:
|
|
|
82
82
|
long aggregate_sampling_timer_start_timestamp = 0;
|
|
83
83
|
};
|
|
84
84
|
|
|
85
|
-
static constexpr auto kTopp = 0.9f;
|
|
86
|
-
|
|
87
85
|
inline std::string stats_to_json_string(const Stats &stats) {
|
|
88
86
|
std::stringstream ss;
|
|
89
87
|
ss << "{\"prompt_tokens\":" << stats.num_prompt_tokens << ","
|
|
@@ -157,7 +155,6 @@ namespace executorch {
|
|
|
157
155
|
namespace llm {
|
|
158
156
|
// TODO(T197294990): Remove these deprecated aliases once all users have moved
|
|
159
157
|
// to the new `::executorch` namespaces.
|
|
160
|
-
using ::executorch::extension::llm::kTopp;
|
|
161
158
|
using ::executorch::extension::llm::print_report;
|
|
162
159
|
using ::executorch::extension::llm::Stats;
|
|
163
160
|
} // namespace llm
|
|
@@ -9,11 +9,11 @@
|
|
|
9
9
|
// Given inputs, run a text decoder and return logits.
|
|
10
10
|
|
|
11
11
|
#include "text_decoder_runner.h"
|
|
12
|
+
#include "arange_util.h"
|
|
13
|
+
#include "stats.h"
|
|
12
14
|
|
|
13
15
|
#include <ctime>
|
|
14
16
|
|
|
15
|
-
#include "stats.h"
|
|
16
|
-
|
|
17
17
|
namespace executorch {
|
|
18
18
|
namespace extension {
|
|
19
19
|
namespace llm {
|
|
@@ -21,23 +21,37 @@ namespace llm {
|
|
|
21
21
|
// NOTE: we observed ~2x loading performance increase on iPhone 15
|
|
22
22
|
// and a ~5% improvement on Galaxy S22 by switching to
|
|
23
23
|
// FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
|
|
24
|
-
TextDecoderRunner::TextDecoderRunner(Module *module,
|
|
25
|
-
|
|
26
|
-
: module_(module),
|
|
27
|
-
|
|
28
|
-
vocab_size, temperature, kTopp,
|
|
29
|
-
static_cast<unsigned long long>(std::time(nullptr)))),
|
|
30
|
-
use_kv_cache_(use_kv_cache) {}
|
|
24
|
+
TextDecoderRunner::TextDecoderRunner(Module *module, IOManager *io_manager,
|
|
25
|
+
float temperature, float topp)
|
|
26
|
+
: module_(module), io_manager_(io_manager), temperature_(temperature),
|
|
27
|
+
topp_(topp) {}
|
|
31
28
|
|
|
32
29
|
// This function is functional, meaning it shouldn't modify any state of the
|
|
33
30
|
// input. It should be safe to call multiple times with the same inputs. The
|
|
34
31
|
// outer loop (call site) is responsible for managing state.
|
|
35
32
|
::executorch::runtime::Result<executorch::aten::Tensor>
|
|
36
|
-
TextDecoderRunner::step(TensorPtr &tokens,
|
|
33
|
+
TextDecoderRunner::step(TensorPtr &tokens, int64_t start_pos) {
|
|
37
34
|
// ET_LOG(Info, "Input token %" PRIu64, input_token);
|
|
38
|
-
|
|
39
|
-
|
|
35
|
+
auto method_meta = ET_UNWRAP(module_->method_meta("forward"));
|
|
36
|
+
// If only 1 input, we are not using kv cache
|
|
37
|
+
bool use_kv_cache = method_meta.num_inputs() > 1;
|
|
38
|
+
|
|
39
|
+
std::vector<int64_t> cache_positions;
|
|
40
|
+
|
|
41
|
+
if (use_kv_cache) {
|
|
42
|
+
auto start_pos_tensor = ET_UNWRAP(populate_start_pos_or_cache_position(
|
|
43
|
+
module_, start_pos, cache_positions, tokens->numel(), "forward"));
|
|
44
|
+
|
|
45
|
+
std::vector<runtime::EValue> inputs;
|
|
46
|
+
auto inputs_res = io_manager_->prepare_decode(tokens, start_pos_tensor);
|
|
47
|
+
ET_CHECK_OK_OR_RETURN_ERROR(inputs_res.error());
|
|
48
|
+
inputs = inputs_res.get();
|
|
49
|
+
auto outputs_res = module_->forward(inputs);
|
|
40
50
|
ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
|
|
51
|
+
|
|
52
|
+
auto update_err = io_manager_->update_decode(outputs_res.get());
|
|
53
|
+
ET_CHECK_OK_OR_RETURN_ERROR(update_err);
|
|
54
|
+
|
|
41
55
|
ET_CHECK_MSG(outputs_res.get().size() == 1,
|
|
42
56
|
"More then one output returned from executing LLM.");
|
|
43
57
|
ET_CHECK_MSG(outputs_res.get()[0].isTensor(),
|