react-native-executorch 0.5.15 → 0.6.0-nightly-897eae9-20251213
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +42 -36
- package/android/CMakeLists.txt +13 -25
- package/android/build.gradle +2 -3
- package/android/libs/classes.jar +0 -0
- package/android/src/main/cpp/CMakeLists.txt +2 -1
- package/common/rnexecutorch/RnExecutorchInstaller.cpp +18 -0
- package/common/rnexecutorch/TokenizerModule.cpp +3 -3
- package/common/rnexecutorch/data_processing/Numerical.cpp +31 -23
- package/common/rnexecutorch/data_processing/Numerical.h +6 -1
- package/common/rnexecutorch/data_processing/dsp.cpp +0 -46
- package/common/rnexecutorch/host_objects/JsiConversions.h +16 -0
- package/common/rnexecutorch/host_objects/ModelHostObject.h +26 -11
- package/common/rnexecutorch/jsi/OwningArrayBuffer.h +19 -2
- package/common/rnexecutorch/metaprogramming/TypeConcepts.h +0 -20
- package/common/rnexecutorch/models/BaseModel.cpp +12 -11
- package/common/rnexecutorch/models/BaseModel.h +18 -10
- package/common/rnexecutorch/models/embeddings/BaseEmbeddings.cpp +3 -11
- package/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp +0 -1
- package/common/rnexecutorch/models/image_segmentation/ImageSegmentation.cpp +6 -12
- package/common/rnexecutorch/models/llm/LLM.cpp +25 -8
- package/common/rnexecutorch/models/llm/LLM.h +4 -4
- package/common/rnexecutorch/models/ocr/CTCLabelConverter.h +1 -1
- package/common/rnexecutorch/models/ocr/utils/RecognitionHandlerUtils.cpp +7 -4
- package/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp +8 -13
- package/common/rnexecutorch/models/speech_to_text/SpeechToText.h +1 -3
- package/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp +12 -19
- package/common/rnexecutorch/models/speech_to_text/asr/ASR.h +4 -5
- package/common/rnexecutorch/models/text_to_image/Constants.h +9 -0
- package/common/rnexecutorch/models/text_to_image/Decoder.cpp +32 -0
- package/common/rnexecutorch/models/text_to_image/Decoder.h +24 -0
- package/common/rnexecutorch/models/text_to_image/Encoder.cpp +44 -0
- package/common/rnexecutorch/models/text_to_image/Encoder.h +32 -0
- package/common/rnexecutorch/models/text_to_image/Scheduler.cpp +152 -0
- package/common/rnexecutorch/models/text_to_image/Scheduler.h +41 -0
- package/common/rnexecutorch/models/text_to_image/TextToImage.cpp +141 -0
- package/common/rnexecutorch/models/text_to_image/TextToImage.h +64 -0
- package/common/rnexecutorch/models/text_to_image/UNet.cpp +38 -0
- package/common/rnexecutorch/models/text_to_image/UNet.h +28 -0
- package/common/rnexecutorch/models/voice_activity_detection/Constants.h +27 -0
- package/common/rnexecutorch/models/voice_activity_detection/Types.h +12 -0
- package/common/rnexecutorch/models/voice_activity_detection/Utils.cpp +15 -0
- package/common/rnexecutorch/models/voice_activity_detection/Utils.h +13 -0
- package/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.cpp +160 -0
- package/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.h +36 -0
- package/common/rnexecutorch/tests/CMakeLists.txt +30 -0
- package/common/rnexecutorch/tests/NumericalTest.cpp +110 -0
- package/common/rnexecutorch/tests/README.md +30 -13
- package/common/rnexecutorch/threads/GlobalThreadPool.h +4 -0
- package/common/runner/arange_util.cpp +44 -0
- package/common/runner/arange_util.h +37 -0
- package/common/runner/constants.h +28 -0
- package/common/runner/io_manager.h +240 -0
- package/common/runner/irunner.h +87 -16
- package/common/runner/kernel_includes.h +23 -0
- package/common/runner/runner.cpp +151 -66
- package/common/runner/runner.h +39 -22
- package/common/runner/sampler.cpp +8 -1
- package/common/runner/sampler.h +4 -2
- package/common/runner/stats.h +1 -4
- package/common/runner/text_decoder_runner.cpp +26 -12
- package/common/runner/text_decoder_runner.h +52 -31
- package/common/runner/text_prefiller.cpp +46 -12
- package/common/runner/text_prefiller.h +38 -4
- package/common/runner/text_token_generator.h +51 -26
- package/common/runner/util.h +53 -8
- package/ios/RnExecutorch.xcodeproj/project.pbxproj +0 -23
- package/lib/module/Error.js +1 -0
- package/lib/module/Error.js.map +1 -1
- package/lib/module/constants/directories.js +1 -1
- package/lib/module/constants/directories.js.map +1 -1
- package/lib/module/constants/modelUrls.js +32 -1
- package/lib/module/constants/modelUrls.js.map +1 -1
- package/lib/module/constants/ocr/models.js +7 -7
- package/lib/module/constants/ocr/models.js.map +1 -1
- package/lib/module/constants/ocr/symbols.js +3 -2
- package/lib/module/constants/ocr/symbols.js.map +1 -1
- package/lib/module/controllers/LLMController.js +10 -1
- package/lib/module/controllers/LLMController.js.map +1 -1
- package/lib/module/controllers/OCRController.js +3 -3
- package/lib/module/controllers/OCRController.js.map +1 -1
- package/lib/module/controllers/VerticalOCRController.js +2 -2
- package/lib/module/controllers/VerticalOCRController.js.map +1 -1
- package/lib/module/hooks/computer_vision/useOCR.js +3 -3
- package/lib/module/hooks/computer_vision/useOCR.js.map +1 -1
- package/lib/module/hooks/{useNonStaticModule.js → computer_vision/useTextToImage.js} +21 -16
- package/lib/module/hooks/computer_vision/useTextToImage.js.map +1 -0
- package/lib/module/hooks/computer_vision/useVerticalOCR.js +3 -3
- package/lib/module/hooks/computer_vision/useVerticalOCR.js.map +1 -1
- package/lib/module/hooks/natural_language_processing/useLLM.js +3 -3
- package/lib/module/hooks/natural_language_processing/useLLM.js.map +1 -1
- package/lib/module/hooks/natural_language_processing/useTokenizer.js +5 -5
- package/lib/module/hooks/natural_language_processing/useTokenizer.js.map +1 -1
- package/lib/module/hooks/natural_language_processing/useVAD.js +13 -0
- package/lib/module/hooks/natural_language_processing/useVAD.js.map +1 -0
- package/lib/module/index.js +7 -2
- package/lib/module/index.js.map +1 -1
- package/lib/module/modules/computer_vision/OCRModule.js +2 -2
- package/lib/module/modules/computer_vision/OCRModule.js.map +1 -1
- package/lib/module/modules/computer_vision/TextToImageModule.js +48 -0
- package/lib/module/modules/computer_vision/TextToImageModule.js.map +1 -0
- package/lib/module/modules/computer_vision/VerticalOCRModule.js +2 -2
- package/lib/module/modules/computer_vision/VerticalOCRModule.js.map +1 -1
- package/lib/module/modules/natural_language_processing/SpeechToTextModule.js +7 -4
- package/lib/module/modules/natural_language_processing/SpeechToTextModule.js.map +1 -1
- package/lib/module/modules/natural_language_processing/VADModule.js +19 -0
- package/lib/module/modules/natural_language_processing/VADModule.js.map +1 -0
- package/lib/module/types/llm.js.map +1 -1
- package/lib/module/types/vad.js +2 -0
- package/lib/module/types/vad.js.map +1 -0
- package/lib/module/utils/ResourceFetcher.js +2 -1
- package/lib/module/utils/ResourceFetcher.js.map +1 -1
- package/lib/module/utils/ResourceFetcherUtils.js +6 -6
- package/lib/module/utils/ResourceFetcherUtils.js.map +1 -1
- package/lib/typescript/Error.d.ts +1 -0
- package/lib/typescript/Error.d.ts.map +1 -1
- package/lib/typescript/constants/modelUrls.d.ts +23 -0
- package/lib/typescript/constants/modelUrls.d.ts.map +1 -1
- package/lib/typescript/constants/ocr/symbols.d.ts +1 -1
- package/lib/typescript/constants/ocr/symbols.d.ts.map +1 -1
- package/lib/typescript/controllers/LLMController.d.ts.map +1 -1
- package/lib/typescript/controllers/OCRController.d.ts +1 -1
- package/lib/typescript/controllers/OCRController.d.ts.map +1 -1
- package/lib/typescript/controllers/VerticalOCRController.d.ts +1 -1
- package/lib/typescript/controllers/VerticalOCRController.d.ts.map +1 -1
- package/lib/typescript/hooks/computer_vision/useOCR.d.ts +1 -1
- package/lib/typescript/hooks/computer_vision/useOCR.d.ts.map +1 -1
- package/lib/typescript/hooks/computer_vision/useTextToImage.d.ts +22 -0
- package/lib/typescript/hooks/computer_vision/useTextToImage.d.ts.map +1 -0
- package/lib/typescript/hooks/computer_vision/useVerticalOCR.d.ts +1 -1
- package/lib/typescript/hooks/computer_vision/useVerticalOCR.d.ts.map +1 -1
- package/lib/typescript/hooks/natural_language_processing/useLLM.d.ts.map +1 -1
- package/lib/typescript/hooks/natural_language_processing/useSpeechToText.d.ts +2 -2
- package/lib/typescript/hooks/natural_language_processing/useVAD.d.ts +16 -0
- package/lib/typescript/hooks/natural_language_processing/useVAD.d.ts.map +1 -0
- package/lib/typescript/index.d.ts +8 -1
- package/lib/typescript/index.d.ts.map +1 -1
- package/lib/typescript/modules/computer_vision/OCRModule.d.ts +1 -1
- package/lib/typescript/modules/computer_vision/OCRModule.d.ts.map +1 -1
- package/lib/typescript/modules/computer_vision/TextToImageModule.d.ts +16 -0
- package/lib/typescript/modules/computer_vision/TextToImageModule.d.ts.map +1 -0
- package/lib/typescript/modules/computer_vision/VerticalOCRModule.d.ts +1 -1
- package/lib/typescript/modules/computer_vision/VerticalOCRModule.d.ts.map +1 -1
- package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts +3 -2
- package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts.map +1 -1
- package/lib/typescript/modules/natural_language_processing/VADModule.d.ts +10 -0
- package/lib/typescript/modules/natural_language_processing/VADModule.d.ts.map +1 -0
- package/lib/typescript/types/llm.d.ts +2 -0
- package/lib/typescript/types/llm.d.ts.map +1 -1
- package/lib/typescript/types/vad.d.ts +5 -0
- package/lib/typescript/types/vad.d.ts.map +1 -0
- package/lib/typescript/utils/ResourceFetcher.d.ts +29 -0
- package/lib/typescript/utils/ResourceFetcher.d.ts.map +1 -1
- package/lib/typescript/utils/ResourceFetcherUtils.d.ts +2 -2
- package/lib/typescript/utils/ResourceFetcherUtils.d.ts.map +1 -1
- package/package.json +11 -8
- package/react-native-executorch.podspec +9 -9
- package/src/Error.ts +1 -0
- package/src/constants/directories.ts +1 -1
- package/src/constants/modelUrls.ts +36 -1
- package/src/constants/ocr/models.ts +7 -7
- package/src/constants/ocr/symbols.ts +3 -2
- package/src/controllers/LLMController.ts +12 -1
- package/src/controllers/OCRController.ts +3 -3
- package/src/controllers/VerticalOCRController.ts +2 -2
- package/src/hooks/computer_vision/useOCR.ts +4 -5
- package/src/hooks/computer_vision/useTextToImage.ts +92 -0
- package/src/hooks/computer_vision/useVerticalOCR.ts +4 -5
- package/src/hooks/natural_language_processing/useLLM.ts +3 -4
- package/src/hooks/natural_language_processing/useTokenizer.ts +5 -5
- package/src/hooks/natural_language_processing/useVAD.ts +15 -0
- package/src/index.ts +20 -1
- package/src/modules/computer_vision/OCRModule.ts +2 -2
- package/src/modules/computer_vision/TextToImageModule.ts +93 -0
- package/src/modules/computer_vision/VerticalOCRModule.ts +2 -2
- package/src/modules/natural_language_processing/SpeechToTextModule.ts +8 -4
- package/src/modules/natural_language_processing/VADModule.ts +27 -0
- package/src/types/llm.ts +2 -0
- package/src/types/vad.ts +4 -0
- package/src/utils/ResourceFetcher.ts +2 -1
- package/src/utils/ResourceFetcherUtils.ts +8 -8
- package/third-party/android/libs/cpuinfo/arm64-v8a/libcpuinfo.so +0 -0
- package/third-party/android/libs/executorch/arm64-v8a/libexecutorch.so +0 -0
- package/third-party/android/libs/executorch/x86_64/libexecutorch.so +0 -0
- package/third-party/android/libs/pthreadpool/arm64-v8a/libpthreadpool.so +0 -0
- package/third-party/include/c10/macros/Export.h +0 -78
- package/third-party/include/c10/macros/Macros.h +1 -520
- package/third-party/include/c10/util/BFloat16-inl.h +1 -339
- package/third-party/include/c10/util/BFloat16.h +1 -122
- package/third-party/include/c10/util/Half-inl.h +1 -347
- package/third-party/include/c10/util/Half.h +6 -419
- package/third-party/include/c10/util/TypeSafeSignMath.h +1 -133
- package/third-party/include/c10/util/bit_cast.h +1 -43
- package/third-party/include/c10/util/complex.h +1 -568
- package/third-party/include/c10/util/floating_point_utils.h +1 -33
- package/third-party/include/c10/util/irange.h +1 -1
- package/third-party/include/c10/util/llvmMathExtras.h +866 -0
- package/third-party/include/c10/util/safe_numerics.h +97 -0
- package/third-party/include/executorch/ExecuTorchError.h +6 -7
- package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLM.h +12 -0
- package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMConfig.h +56 -0
- package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMError.h +16 -0
- package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMMultimodalRunner.h +227 -0
- package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMTextRunner.h +97 -0
- package/third-party/include/executorch/ExecuTorchLLM/module.modulemap +4 -0
- package/third-party/include/executorch/ExecuTorchLog.h +1 -0
- package/third-party/include/executorch/ExecuTorchModule.h +177 -4
- package/third-party/include/executorch/ExecuTorchTensor.h +3 -4
- package/third-party/include/executorch/ExecuTorchValue.h +1 -7
- package/third-party/include/executorch/extension/module/module.h +139 -8
- package/third-party/include/executorch/extension/tensor/tensor.h +1 -0
- package/third-party/include/executorch/extension/tensor/tensor_ptr.h +88 -26
- package/third-party/include/executorch/extension/threadpool/threadpool.h +4 -1
- package/third-party/include/executorch/runtime/backend/backend_init_context.h +6 -0
- package/third-party/include/executorch/runtime/backend/interface.h +1 -1
- package/third-party/include/executorch/runtime/core/error.h +76 -49
- package/third-party/include/executorch/runtime/core/exec_aten/util/scalar_type_util.h +18 -4
- package/third-party/include/executorch/runtime/core/memory_allocator.h +12 -2
- package/third-party/include/executorch/runtime/core/named_data_map.h +1 -11
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/macros/Export.h +0 -78
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/macros/Macros.h +1 -520
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/BFloat16-inl.h +1 -339
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/BFloat16.h +1 -122
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/Half-inl.h +1 -347
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/Half.h +6 -419
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/TypeSafeSignMath.h +1 -133
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/bit_cast.h +1 -43
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/complex.h +1 -568
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/floating_point_utils.h +1 -33
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/irange.h +1 -1
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/llvmMathExtras.h +866 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/safe_numerics.h +97 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/macros/Export.h +66 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/macros/Macros.h +553 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/BFloat16.h +477 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/Half.h +781 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/TypeSafeSignMath.h +141 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/bit_cast.h +49 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/complex.h +593 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/floating_point_utils.h +38 -0
- package/third-party/include/executorch/runtime/core/tensor_layout.h +1 -1
- package/third-party/include/executorch/runtime/executor/merged_data_map.h +142 -0
- package/third-party/include/executorch/runtime/executor/method.h +21 -8
- package/third-party/include/executorch/runtime/executor/method_meta.h +20 -2
- package/third-party/include/executorch/runtime/executor/program.h +0 -10
- package/third-party/include/executorch/runtime/kernel/operator_registry.h +1 -1
- package/third-party/include/executorch/runtime/platform/compiler.h +2 -0
- package/third-party/include/executorch/schema/extended_header.h +10 -1
- package/third-party/include/torch/headeronly/macros/Export.h +66 -0
- package/third-party/include/torch/headeronly/macros/Macros.h +553 -0
- package/third-party/include/torch/headeronly/util/BFloat16.h +477 -0
- package/third-party/include/torch/headeronly/util/Half.h +781 -0
- package/third-party/include/torch/headeronly/util/TypeSafeSignMath.h +141 -0
- package/third-party/include/torch/headeronly/util/bit_cast.h +49 -0
- package/third-party/include/torch/headeronly/util/complex.h +593 -0
- package/third-party/include/torch/headeronly/util/floating_point_utils.h +38 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/ExecutorchLib +0 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/Info.plist +0 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/ExecutorchLib +0 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/Info.plist +0 -0
- package/common/rnexecutorch/tests/run_all_tests.sh +0 -14
- package/common/rnexecutorch/tests/run_test.sh +0 -18
- package/ios/RnExecutorch/utils/Conversions.h +0 -14
- package/ios/RnExecutorch/utils/ETError.h +0 -26
- package/ios/RnExecutorch/utils/ImageProcessor.h +0 -15
- package/ios/RnExecutorch/utils/ImageProcessor.mm +0 -147
- package/ios/RnExecutorch/utils/Numerical.h +0 -3
- package/ios/RnExecutorch/utils/Numerical.mm +0 -18
- package/ios/RnExecutorch/utils/ScalarType.h +0 -14
- package/ios/RnExecutorch/utils/ScalarType.mm +0 -21
- package/lib/module/hooks/useNonStaticModule.js.map +0 -1
- package/lib/typescript/hooks/useNonStaticModule.d.ts +0 -21
- package/lib/typescript/hooks/useNonStaticModule.d.ts.map +0 -1
- package/src/hooks/useNonStaticModule.ts +0 -74
- package/third-party/include/executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h +0 -181
- package/third-party/include/executorch/extension/kernel_util/meta_programming.h +0 -108
- package/third-party/include/executorch/extension/kernel_util/type_list.h +0 -137
- package/third-party/include/executorch/extension/threadpool/threadpool_guard.h +0 -35
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
#include "VoiceActivityDetection.h"
#include "rnexecutorch/data_processing/dsp.h"
#include "rnexecutorch/models/voice_activity_detection/Utils.h"

#include <algorithm>
#include <array>
#include <cstdlib>
#include <functional>
#include <numeric>
#include <optional>
#include <span>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
|
|
10
|
+
|
|
11
|
+
namespace rnexecutorch::models::voice_activity_detection {
|
|
12
|
+
using namespace constants;
|
|
13
|
+
namespace ranges = std::ranges;
|
|
14
|
+
using executorch::aten::Tensor;
|
|
15
|
+
using executorch::extension::TensorPtr;
|
|
16
|
+
|
|
17
|
+
/// Constructs the voice activity detection model.
///
/// @param modelSource  model source, forwarded unchanged to BaseModel.
/// @param callInvoker  React Native call invoker, forwarded to BaseModel.
VoiceActivityDetection::VoiceActivityDetection(
    const std::string &modelSource,
    std::shared_ptr<react::CallInvoker> callInvoker)
    // callInvoker is a by-value sink parameter: move it on rather than copy it,
    // which would cost an extra atomic reference-count increment/decrement.
    : BaseModel(modelSource, std::move(callInvoker)) {}
|
|
21
|
+
|
|
22
|
+
std::vector<std::array<float, kPaddedWindowSize>>
|
|
23
|
+
VoiceActivityDetection::preprocess(std::span<float> waveform) const {
|
|
24
|
+
auto kHammingWindowArray = dsp::hannWindow(kWindowSize);
|
|
25
|
+
|
|
26
|
+
const size_t numFrames = (waveform.size() - kWindowSize) / kHopLength;
|
|
27
|
+
|
|
28
|
+
std::vector<std::array<float, kPaddedWindowSize>> frameBuffer(
|
|
29
|
+
numFrames, std::array<float, kPaddedWindowSize>{});
|
|
30
|
+
|
|
31
|
+
constexpr size_t totalPadding = kPaddedWindowSize - kWindowSize;
|
|
32
|
+
constexpr size_t leftPadding = totalPadding / 2;
|
|
33
|
+
for (size_t i = 0; i < numFrames; i++) {
|
|
34
|
+
|
|
35
|
+
auto windowView = waveform.subspan(i * kHopLength, kWindowSize);
|
|
36
|
+
ranges::copy(windowView, frameBuffer[i].begin() + leftPadding);
|
|
37
|
+
auto frameView =
|
|
38
|
+
std::span{frameBuffer[i].data() + leftPadding, kWindowSize};
|
|
39
|
+
const float sum = std::reduce(frameView.begin(), frameView.end(), 0.0f);
|
|
40
|
+
const float mean = sum / kWindowSize;
|
|
41
|
+
ranges::transform(frameView, frameView.begin(),
|
|
42
|
+
[mean](float value) { return value - mean; });
|
|
43
|
+
|
|
44
|
+
// apply pre-emphasis filter
|
|
45
|
+
for (auto j = frameView.size() - 1; j > 0; --j) {
|
|
46
|
+
frameView[j] -= kPreemphasisCoeff * frameView[j - 1];
|
|
47
|
+
}
|
|
48
|
+
// apply hamming window to reduce spectral leakage
|
|
49
|
+
ranges::transform(frameView, kHammingWindowArray, frameView.begin(),
|
|
50
|
+
std::multiplies{});
|
|
51
|
+
}
|
|
52
|
+
return frameBuffer;
|
|
53
|
+
}
|
|
54
|
+
|
|
55
|
+
std::vector<types::Segment>
|
|
56
|
+
VoiceActivityDetection::generate(std::span<float> waveform) const {
|
|
57
|
+
|
|
58
|
+
auto windowedInput = preprocess(waveform);
|
|
59
|
+
auto [chunksNumber, remainder] = std::div(
|
|
60
|
+
static_cast<int>(windowedInput.size()), static_cast<int>(kModelInputMax));
|
|
61
|
+
std::vector<float> scores(windowedInput.size());
|
|
62
|
+
auto lastChunkSize = remainder;
|
|
63
|
+
if (remainder < kModelInputMin) {
|
|
64
|
+
auto paddingSize = kModelInputMin - remainder;
|
|
65
|
+
lastChunkSize = kModelInputMin;
|
|
66
|
+
windowedInput.insert(windowedInput.end(), paddingSize,
|
|
67
|
+
std::array<float, kPaddedWindowSize>{});
|
|
68
|
+
}
|
|
69
|
+
TensorPtr inputTensor;
|
|
70
|
+
size_t startIdx = 0;
|
|
71
|
+
|
|
72
|
+
for (size_t i = 0; i < chunksNumber; i++) {
|
|
73
|
+
std::span<std::array<float, kPaddedWindowSize>> chunk(
|
|
74
|
+
windowedInput.data() + kModelInputMax * i, kModelInputMax);
|
|
75
|
+
inputTensor = executorch::extension::from_blob(
|
|
76
|
+
chunk.data(), {kModelInputMax, kPaddedWindowSize},
|
|
77
|
+
executorch::aten::ScalarType::Float);
|
|
78
|
+
auto forwardResult = BaseModel::forward(inputTensor);
|
|
79
|
+
if (!forwardResult.ok()) {
|
|
80
|
+
throw std::runtime_error(
|
|
81
|
+
"Failed to forward, error: " +
|
|
82
|
+
std::to_string(static_cast<uint32_t>(forwardResult.error())));
|
|
83
|
+
}
|
|
84
|
+
auto tensor = forwardResult->at(0).toTensor();
|
|
85
|
+
startIdx = utils::getNonSpeechClassProbabilites(
|
|
86
|
+
tensor, tensor.size(2), tensor.size(1), scores, startIdx);
|
|
87
|
+
}
|
|
88
|
+
|
|
89
|
+
std::span<std::array<float, kPaddedWindowSize>> lastChunk(
|
|
90
|
+
windowedInput.data() + kModelInputMax * chunksNumber, lastChunkSize);
|
|
91
|
+
inputTensor = executorch::extension::from_blob(
|
|
92
|
+
lastChunk.data(), {lastChunkSize, kPaddedWindowSize},
|
|
93
|
+
executorch::aten::ScalarType::Float);
|
|
94
|
+
auto forwardResult = BaseModel::forward(inputTensor);
|
|
95
|
+
if (!forwardResult.ok()) {
|
|
96
|
+
throw std::runtime_error(
|
|
97
|
+
"Failed to forward, error: " +
|
|
98
|
+
std::to_string(static_cast<uint32_t>(forwardResult.error())));
|
|
99
|
+
}
|
|
100
|
+
auto tensor = forwardResult->at(0).toTensor();
|
|
101
|
+
startIdx = utils::getNonSpeechClassProbabilites(tensor, tensor.size(2),
|
|
102
|
+
remainder, scores, startIdx);
|
|
103
|
+
return postprocess(scores, kSpeechThreshold);
|
|
104
|
+
}
|
|
105
|
+
|
|
106
|
+
std::vector<types::Segment>
|
|
107
|
+
VoiceActivityDetection::postprocess(const std::vector<float> &scores,
|
|
108
|
+
float threshold) const {
|
|
109
|
+
bool triggered = false;
|
|
110
|
+
std::vector<types::Segment> speechSegments{};
|
|
111
|
+
ssize_t startSegment = -1;
|
|
112
|
+
ssize_t endSegment = -1;
|
|
113
|
+
ssize_t potentialStart = -1;
|
|
114
|
+
ssize_t potentialEnd = -1;
|
|
115
|
+
float score;
|
|
116
|
+
for (size_t i = 0; i < scores.size(); i++) {
|
|
117
|
+
score = 1 - scores[i];
|
|
118
|
+
if (!triggered) {
|
|
119
|
+
if (score >= threshold) {
|
|
120
|
+
if (potentialStart == -1) {
|
|
121
|
+
potentialStart = i;
|
|
122
|
+
} else if (i - potentialStart >= kMinSpeechDuration) {
|
|
123
|
+
triggered = true;
|
|
124
|
+
startSegment = potentialStart;
|
|
125
|
+
potentialStart = -1;
|
|
126
|
+
}
|
|
127
|
+
} else { // score < threshold
|
|
128
|
+
potentialStart = -1;
|
|
129
|
+
}
|
|
130
|
+
} else { // triggered
|
|
131
|
+
if (score < threshold) {
|
|
132
|
+
if (potentialEnd == -1) {
|
|
133
|
+
potentialEnd = i;
|
|
134
|
+
} else if (i - potentialEnd >= kMinSilenceDuration) {
|
|
135
|
+
triggered = false;
|
|
136
|
+
endSegment = potentialEnd;
|
|
137
|
+
speechSegments.emplace_back(startSegment, endSegment);
|
|
138
|
+
potentialEnd = -1;
|
|
139
|
+
}
|
|
140
|
+
} else {
|
|
141
|
+
potentialEnd = -1;
|
|
142
|
+
}
|
|
143
|
+
}
|
|
144
|
+
}
|
|
145
|
+
if (triggered) {
|
|
146
|
+
endSegment = scores.size();
|
|
147
|
+
speechSegments.emplace_back(startSegment, endSegment);
|
|
148
|
+
}
|
|
149
|
+
|
|
150
|
+
for (auto &[start, end] : speechSegments) {
|
|
151
|
+
// std::max(start-kSpeedchPad, 0) might be underflow that is why we use ?
|
|
152
|
+
// operator.
|
|
153
|
+
start = (start > kSpeechPad ? start - kSpeechPad : 0) * kHopLength;
|
|
154
|
+
end = std::min(end + kSpeechPad, scores.size()) * kHopLength;
|
|
155
|
+
}
|
|
156
|
+
|
|
157
|
+
return speechSegments;
|
|
158
|
+
}
|
|
159
|
+
|
|
160
|
+
} // namespace rnexecutorch::models::voice_activity_detection
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
|
|
3
|
+
#include <cstddef>
|
|
4
|
+
#include <executorch/extension/module/module.h>
|
|
5
|
+
#include <executorch/extension/tensor/tensor.h>
|
|
6
|
+
#include <executorch/extension/tensor/tensor_ptr.h>
|
|
7
|
+
#include <executorch/runtime/core/evalue.h>
|
|
8
|
+
#include <span>
|
|
9
|
+
|
|
10
|
+
#include "rnexecutorch/metaprogramming/ConstructorHelpers.h"
|
|
11
|
+
#include "rnexecutorch/models/BaseModel.h"
|
|
12
|
+
#include "rnexecutorch/models/voice_activity_detection/Constants.h"
|
|
13
|
+
#include "rnexecutorch/models/voice_activity_detection/Types.h"
|
|
14
|
+
|
|
15
|
+
namespace rnexecutorch {
|
|
16
|
+
namespace models::voice_activity_detection {
|
|
17
|
+
using executorch::extension::TensorPtr;
|
|
18
|
+
using executorch::runtime::EValue;
|
|
19
|
+
class VoiceActivityDetection : public BaseModel {
|
|
20
|
+
public:
|
|
21
|
+
VoiceActivityDetection(const std::string &modelSource,
|
|
22
|
+
std::shared_ptr<react::CallInvoker> callInvoker);
|
|
23
|
+
[[nodiscard("Registered non-void function")]] std::vector<types::Segment>
|
|
24
|
+
generate(std::span<float> waveform) const;
|
|
25
|
+
|
|
26
|
+
private:
|
|
27
|
+
std::vector<std::array<float, constants::kPaddedWindowSize>>
|
|
28
|
+
preprocess(std::span<float> waveform) const;
|
|
29
|
+
std::vector<types::Segment> postprocess(const std::vector<float> &scores,
|
|
30
|
+
float threshold) const;
|
|
31
|
+
};
|
|
32
|
+
} // namespace models::voice_activity_detection
|
|
33
|
+
|
|
34
|
+
REGISTER_CONSTRUCTOR(models::voice_activity_detection::VoiceActivityDetection,
|
|
35
|
+
std::string, std::shared_ptr<react::CallInvoker>);
|
|
36
|
+
} // namespace rnexecutorch
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
# CMAKE_CXX_STANDARD 20 is only recognised by CMake 3.12 and newer.
cmake_minimum_required(VERSION 3.12)
project(RNExecutorchTests)

# C++ standard
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED TRUE)

# googletest subdirectory
# Located via a path relative to the top-level source directory
add_subdirectory(${CMAKE_SOURCE_DIR}/../../../../../third-party/googletest ${PROJECT_BINARY_DIR}/googletest)

# Directories to include
include_directories(${CMAKE_SOURCE_DIR}/../data_processing)
include_directories(${CMAKE_SOURCE_DIR}/..)

# Source files
set(SOURCE_FILES ${CMAKE_SOURCE_DIR}/../data_processing/Numerical.cpp)

# Executables for the tests
add_executable(NumericalTests NumericalTest.cpp ${SOURCE_FILES})
add_executable(LogTests LogTest.cpp)

# Libraries linking
target_link_libraries(NumericalTests gtest gtest_main)
target_link_libraries(LogTests gtest gtest_main)

# Testing functionalities
enable_testing()
add_test(NAME NumericalTests COMMAND NumericalTests)
add_test(NAME LogTests COMMAND LogTests)
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
#include "../data_processing/Numerical.h"
|
|
2
|
+
#include <gtest/gtest.h>
|
|
3
|
+
#include <limits>
|
|
4
|
+
#include <span>
|
|
5
|
+
#include <stdexcept>
|
|
6
|
+
#include <vector>
|
|
7
|
+
|
|
8
|
+
namespace rnexecutorch::numerical {
|
|
9
|
+
|
|
10
|
+
// Helper function to check if two float vectors are approximately equal
|
|
11
|
+
void expect_vectors_eq(const std::vector<float> &vector1,
|
|
12
|
+
const std::vector<float> &vector2, float atol = 1.0e-6F) {
|
|
13
|
+
ASSERT_EQ(vector1.size(), vector2.size());
|
|
14
|
+
for (size_t i = 0; i < vector1.size(); i++) {
|
|
15
|
+
EXPECT_NEAR(vector1[i], vector2[i], atol);
|
|
16
|
+
}
|
|
17
|
+
}
|
|
18
|
+
|
|
19
|
+
TEST(SoftmaxTests, SoftmaxBasic) {
|
|
20
|
+
std::vector<float> input = {1.0F, 2.0F, 3.0F};
|
|
21
|
+
softmax(input);
|
|
22
|
+
const std::vector<float> expected = {0.09003057F, 0.24472847F, 0.66524095F};
|
|
23
|
+
expect_vectors_eq(input, expected);
|
|
24
|
+
}
|
|
25
|
+
|
|
26
|
+
TEST(SoftmaxTests, SoftmaxWithBigValues) {
|
|
27
|
+
std::vector<float> input = {100000.0F, 100000.0F, 100000.0F};
|
|
28
|
+
softmax(input);
|
|
29
|
+
const std::vector<float> expected = {0.3333333F, 0.3333333F, 0.3333333F};
|
|
30
|
+
expect_vectors_eq(input, expected);
|
|
31
|
+
}
|
|
32
|
+
|
|
33
|
+
TEST(SoftmaxTests, SoftmaxOfEmptyVector) {
|
|
34
|
+
std::vector<float> emptyVector{};
|
|
35
|
+
EXPECT_NO_THROW(softmax(emptyVector));
|
|
36
|
+
}
|
|
37
|
+
|
|
38
|
+
TEST(NormalizeTests, NormalizeBasic) {
|
|
39
|
+
std::vector<float> input = {1.0F, 2.0F, 3.0F};
|
|
40
|
+
normalize(input);
|
|
41
|
+
const auto normOfInput = std::sqrtf(14.0F);
|
|
42
|
+
const std::vector<float> expected = {1.0F / normOfInput, 2.0F / normOfInput,
|
|
43
|
+
3.0F / normOfInput};
|
|
44
|
+
expect_vectors_eq(input, expected);
|
|
45
|
+
}
|
|
46
|
+
|
|
47
|
+
TEST(NormalizeTests, NormalizationOfExtremelySmallValues) {
|
|
48
|
+
constexpr auto epsilon = std::numeric_limits<float>::epsilon();
|
|
49
|
+
std::vector<float> input(3, epsilon);
|
|
50
|
+
const auto normOfInput = std::sqrtf(3.0F);
|
|
51
|
+
const std::vector<float> expected(3, 1.0F / normOfInput);
|
|
52
|
+
normalize(input);
|
|
53
|
+
expect_vectors_eq(input, expected);
|
|
54
|
+
}
|
|
55
|
+
|
|
56
|
+
TEST(NormalizeTests, NormalizationOfZeroVector) {
|
|
57
|
+
std::vector<float> zeroVector(3, 0.0F);
|
|
58
|
+
EXPECT_NO_THROW(normalize(zeroVector));
|
|
59
|
+
}
|
|
60
|
+
|
|
61
|
+
TEST(NormalizeTests, NormalizationOfEmptyVector) {
|
|
62
|
+
std::vector<float> emptyVector{};
|
|
63
|
+
EXPECT_NO_THROW(normalize(emptyVector));
|
|
64
|
+
}
|
|
65
|
+
|
|
66
|
+
TEST(MeanPoolingTests, MeanPoolingBasic) {
  // Six values with a three-entry mask. The expected output {2, 3} equals
  // {(1 + 3) / 2, (2 + 4) / 2}, i.e. the data is read as 3 rows x 2 columns
  // and only the two unmasked rows are averaged.
  const std::vector<float> tokenEmbeddings = {1.0F, 2.0F, 3.0F,
                                              4.0F, 5.0F, 6.0F};
  const std::vector<int64_t> mask = {1, 1, 0};

  const auto pooled = meanPooling(std::span<const float>(tokenEmbeddings),
                                  std::span<const int64_t>(mask));

  const std::vector<float> expected = {2.0F, 3.0F};
  expect_vectors_eq(pooled, expected);
}
|
|
78
|
+
|
|
79
|
+
TEST(MeanPoolingTests, MeanPoolingWithZeroAttentionMask) {
  // Every mask entry is zero; the pooled result is expected to be all zeros
  // rather than a division-by-zero artifact.
  const std::vector<float> tokenEmbeddings = {1.0F, 2.0F, 3.0F,
                                              4.0F, 5.0F, 6.0F};
  const std::vector<int64_t> mask = {0, 0, 0};

  const auto pooled = meanPooling(std::span<const float>(tokenEmbeddings),
                                  std::span<const int64_t>(mask));

  const std::vector<float> expected(2, 0.0F);
  expect_vectors_eq(pooled, expected);
}
|
|
91
|
+
|
|
92
|
+
TEST(MeanPoolingTests, InvalidDimensionSize) {
  // 4 output values vs. a 3-entry mask: presumably rejected because 4 is not
  // a multiple of 3. The test requires std::invalid_argument.
  const std::vector<float> outputs = {1.0F, 2.0F, 3.0F, 4.0F};
  const std::vector<int64_t> mask = {1, 1, 1};

  EXPECT_THROW({ meanPooling(outputs, mask); }, std::invalid_argument);
}
|
|
100
|
+
|
|
101
|
+
TEST(MeanPoolingTests, EmptyAttentionMask) {
  // A non-empty model output paired with an empty mask must be rejected.
  const std::vector<float> outputs = {1.0F, 2.0F, 3.0F, 4.0F};
  const std::vector<int64_t> mask;

  EXPECT_THROW({ meanPooling(outputs, mask); }, std::invalid_argument);
}
|
|
109
|
+
|
|
110
|
+
} // namespace rnexecutorch::numerical
|
|
@@ -2,19 +2,36 @@
|
|
|
2
2
|
This guide provides information on how functions are tested, how to install all the needed dependencies, and how to run the tests.
|
|
3
3
|
|
|
4
4
|
### Used Tools
|
|
5
|
-
To test the native code we use [`googletest`](https://github.com/google/googletest). It
|
|
5
|
+
To test the native code we use [`googletest`](https://github.com/google/googletest). It's a flexible tool for creating unit tests.
|
|
6
6
|
|
|
7
7
|
### Installation
|
|
8
|
-
The
|
|
9
|
-
*
|
|
10
|
-
`git clone git@github.com:google/googletest.git && cd googletest && git switch --detach v1.17.0`
|
|
11
|
-
* Build library files:
|
|
12
|
-
* `mkdir build && cd build`
|
|
13
|
-
* `cmake ..`
|
|
14
|
-
* `make`
|
|
15
|
-
* Add `/usr/local/include` and `/usr/local/lib` to your path if not already there.
|
|
8
|
+
googletest is already included in the repository as a git submodule at `react-native-executorch/third-party/googletest`. First, fetch it locally by running the following from the root directory of the project:
|
|
9
|
+
* `git submodule update --init --recursive third-party/googletest`
|
|
16
10
|
|
|
17
|
-
###
|
|
18
|
-
To run tests
|
|
19
|
-
* `
|
|
20
|
-
|
|
11
|
+
### Build Test Files
|
|
12
|
+
To build the tests, navigate to the tests directory:
|
|
13
|
+
* `cd packages/react-native-executorch/common/rnexecutorch/tests`
|
|
14
|
+
and then run:
|
|
15
|
+
* `mkdir build && cd build`
|
|
16
|
+
* `cmake ..`
|
|
17
|
+
* `make`
|
|
18
|
+
|
|
19
|
+
### Run Tests
|
|
20
|
+
To run the tests, use the following command in `packages/react-native-executorch/common/rnexecutorch/tests/build`:
|
|
21
|
+
* `ctest --verbose`
|
|
22
|
+
|
|
23
|
+
Every time you update the source code, you need to recompile the tests using `cmake .. && make`.
|
|
24
|
+
|
|
25
|
+
### How to add a new test
|
|
26
|
+
To add a new test, you need to:
|
|
27
|
+
* Place a `*.cpp` file containing googletest tests in this directory.
|
|
28
|
+
* In `CMakeLists.txt`, add an executable for the new test file and link it with googletest, e.g.:
|
|
29
|
+
```
|
|
30
|
+
set(SOURCE_FILES ${CMAKE_SOURCE_DIR}/../data_processing/Numerical.cpp)
|
|
31
|
+
add_executable(NumericalTests tests/NumericalTest.cpp ${SOURCE_FILES})
|
|
32
|
+
target_link_libraries(NumericalTests gtest gtest_main)
|
|
33
|
+
```
|
|
34
|
+
* Register the executable as a test so that `ctest` runs it, e.g.:
|
|
35
|
+
```
|
|
36
|
+
add_test(NAME NumericalTests COMMAND NumericalTests)
|
|
37
|
+
```
|
|
@@ -4,6 +4,7 @@
|
|
|
4
4
|
#include <executorch/extension/threadpool/cpuinfo_utils.h>
|
|
5
5
|
#include <memory>
|
|
6
6
|
#include <mutex>
|
|
7
|
+
#include <opencv2/opencv.hpp>
|
|
7
8
|
#include <optional>
|
|
8
9
|
#include <rnexecutorch/Log.h>
|
|
9
10
|
#include <rnexecutorch/threads/HighPerformanceThreadPool.h>
|
|
@@ -38,6 +39,9 @@ public:
|
|
|
38
39
|
numThreads, "threads");
|
|
39
40
|
instance = std::make_unique<HighPerformanceThreadPool>(numThreads.value(),
|
|
40
41
|
config);
|
|
42
|
+
// Disable OpenCV's internal threading to prevent it from overriding our
|
|
43
|
+
// thread pool configuration, which would cause degraded performance
|
|
44
|
+
cv::setNumThreads(0);
|
|
41
45
|
});
|
|
42
46
|
}
|
|
43
47
|
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
* All rights reserved.
|
|
4
|
+
*
|
|
5
|
+
* This source code is licensed under the BSD-style license found in the
|
|
6
|
+
* LICENSE file in the root directory of this source tree.
|
|
7
|
+
*/
|
|
8
|
+
|
|
9
|
+
#include "arange_util.h"

#include <cmath>
#include <limits>
|
|
10
|
+
|
|
11
|
+
namespace torch::executor::native {
|
|
12
|
+
#define ET_ARANGE_IMPL(ctx, start, numel, step, out, op_name) \
|
|
13
|
+
ET_SWITCH_REALHBF16_TYPES(out.scalar_type(), ctx, op_name, CTYPE, [&]() { \
|
|
14
|
+
auto out_data = out.mutable_data_ptr<CTYPE>(); \
|
|
15
|
+
for (executorch::aten::SizesType i = 0; i < numel; ++i) { \
|
|
16
|
+
out_data[i] = static_cast<CTYPE>(start + i * step); \
|
|
17
|
+
} \
|
|
18
|
+
})
|
|
19
|
+
|
|
20
|
+
executorch::aten::SizesType compute_arange_out_size(double start, double end,
|
|
21
|
+
double step) {
|
|
22
|
+
executorch::aten::SizesType numel =
|
|
23
|
+
static_cast<executorch::aten::SizesType>(std::ceil((end - start) / step));
|
|
24
|
+
|
|
25
|
+
ET_CHECK_MSG(numel >= 0,
|
|
26
|
+
"numel should be non-negative, but got (%" PRId64
|
|
27
|
+
"). start (%f), end (%f), step (%f)",
|
|
28
|
+
static_cast<int64_t>(numel), start, end, step);
|
|
29
|
+
return numel;
|
|
30
|
+
}
|
|
31
|
+
|
|
32
|
+
void arange_out_impl(KernelRuntimeContext &ctx, double start, double end,
|
|
33
|
+
double step, Tensor &out) {
|
|
34
|
+
(void)ctx;
|
|
35
|
+
executorch::aten::SizesType numel = compute_arange_out_size(start, end, step);
|
|
36
|
+
ET_ARANGE_IMPL(ctx, start, numel, step, out, "arange.start_out");
|
|
37
|
+
}
|
|
38
|
+
|
|
39
|
+
void arange_out_impl(KernelRuntimeContext &ctx, double end, Tensor &out) {
|
|
40
|
+
(void)ctx;
|
|
41
|
+
ET_ARANGE_IMPL(ctx, 0.0, end, 1.0, out, "arange.out");
|
|
42
|
+
}
|
|
43
|
+
|
|
44
|
+
} // namespace torch::executor::native
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
* All rights reserved.
|
|
4
|
+
*
|
|
5
|
+
* This source code is licensed under the BSD-style license found in the
|
|
6
|
+
* LICENSE file in the root directory of this source tree.
|
|
7
|
+
*/
|
|
8
|
+
|
|
9
|
+
#pragma once
|
|
10
|
+
|
|
11
|
+
#include "kernel_includes.h"
|
|
12
|
+
|
|
13
|
+
namespace torch::executor::native {
|
|
14
|
+
|
|
15
|
+
executorch::aten::SizesType compute_arange_out_size(double start, double end,
|
|
16
|
+
double step);
|
|
17
|
+
|
|
18
|
+
inline executorch::aten::SizesType compute_arange_out_size(double end) {
|
|
19
|
+
return compute_arange_out_size(0.0, end, 1.0);
|
|
20
|
+
}
|
|
21
|
+
|
|
22
|
+
void arange_out_impl(KernelRuntimeContext &ctx, double start, double end,
|
|
23
|
+
double step, Tensor &out);
|
|
24
|
+
|
|
25
|
+
void arange_out_impl(KernelRuntimeContext &ctx, double end, Tensor &out);
|
|
26
|
+
|
|
27
|
+
inline void arange_out_impl(double start, double end, double step,
|
|
28
|
+
Tensor &out) {
|
|
29
|
+
KernelRuntimeContext ctx;
|
|
30
|
+
arange_out_impl(ctx, start, end, step, out);
|
|
31
|
+
}
|
|
32
|
+
|
|
33
|
+
inline void arange_out_impl(double end, Tensor &out) {
|
|
34
|
+
KernelRuntimeContext ctx;
|
|
35
|
+
arange_out_impl(ctx, 0.0, end, 1.0, out);
|
|
36
|
+
}
|
|
37
|
+
} // namespace torch::executor::native
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
+
* All rights reserved.
|
|
4
|
+
*
|
|
5
|
+
* This source code is licensed under the BSD-style license found in the
|
|
6
|
+
* LICENSE file in the root directory of this source tree.
|
|
7
|
+
*/
|
|
8
|
+
#pragma once
|
|
9
|
+
// constants for LLM runtime
|
|
10
|
+
namespace executorch::extension::llm {

// Names of runtime metadata entries for LLM models.
inline constexpr const char *kEnableDynamicShape = "enable_dynamic_shape";
inline constexpr const char *kBosId = "get_bos_id";
inline constexpr const char *kEosIds = "get_eos_ids";
inline constexpr const char *kMaxSeqLen = "get_max_seq_len";
inline constexpr const char *kMaxContextLen = "get_max_context_len";
inline constexpr const char *kVocabSize = "get_vocab_size";
inline constexpr const char *kUseKVCache = "use_kv_cache";
inline constexpr const char *kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";

// Conventional method names for the parts of a multimodal model.
inline constexpr const char *kVisionEncoderMethod = "vision_encoder";
inline constexpr const char *kAudioEncoderMethod = "audio_encoder";
inline constexpr const char *kTokenEmbeddingMethod = "token_embedding";
inline constexpr const char *kTextModelMethod = "text_decoder";

} // namespace executorch::extension::llm
|