react-native-executorch 0.5.15-rc1 → 0.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +42 -36
- package/android/CMakeLists.txt +13 -25
- package/android/build.gradle +2 -3
- package/android/libs/classes.jar +0 -0
- package/android/src/main/cpp/CMakeLists.txt +2 -1
- package/common/rnexecutorch/RnExecutorchInstaller.cpp +18 -0
- package/common/rnexecutorch/TokenizerModule.cpp +3 -3
- package/common/rnexecutorch/data_processing/Numerical.cpp +31 -23
- package/common/rnexecutorch/data_processing/Numerical.h +6 -1
- package/common/rnexecutorch/data_processing/dsp.cpp +0 -46
- package/common/rnexecutorch/host_objects/JsiConversions.h +16 -0
- package/common/rnexecutorch/host_objects/ModelHostObject.h +26 -11
- package/common/rnexecutorch/jsi/OwningArrayBuffer.h +19 -2
- package/common/rnexecutorch/metaprogramming/TypeConcepts.h +0 -20
- package/common/rnexecutorch/models/BaseModel.cpp +12 -11
- package/common/rnexecutorch/models/BaseModel.h +18 -10
- package/common/rnexecutorch/models/embeddings/BaseEmbeddings.cpp +3 -11
- package/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp +0 -1
- package/common/rnexecutorch/models/image_segmentation/ImageSegmentation.cpp +6 -12
- package/common/rnexecutorch/models/llm/LLM.cpp +25 -8
- package/common/rnexecutorch/models/llm/LLM.h +4 -4
- package/common/rnexecutorch/models/ocr/CTCLabelConverter.h +1 -1
- package/common/rnexecutorch/models/ocr/utils/RecognitionHandlerUtils.cpp +7 -4
- package/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp +8 -13
- package/common/rnexecutorch/models/speech_to_text/SpeechToText.h +1 -3
- package/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp +12 -19
- package/common/rnexecutorch/models/speech_to_text/asr/ASR.h +4 -5
- package/common/rnexecutorch/models/text_to_image/Constants.h +9 -0
- package/common/rnexecutorch/models/text_to_image/Decoder.cpp +32 -0
- package/common/rnexecutorch/models/text_to_image/Decoder.h +24 -0
- package/common/rnexecutorch/models/text_to_image/Encoder.cpp +44 -0
- package/common/rnexecutorch/models/text_to_image/Encoder.h +32 -0
- package/common/rnexecutorch/models/text_to_image/Scheduler.cpp +152 -0
- package/common/rnexecutorch/models/text_to_image/Scheduler.h +41 -0
- package/common/rnexecutorch/models/text_to_image/TextToImage.cpp +141 -0
- package/common/rnexecutorch/models/text_to_image/TextToImage.h +64 -0
- package/common/rnexecutorch/models/text_to_image/UNet.cpp +38 -0
- package/common/rnexecutorch/models/text_to_image/UNet.h +28 -0
- package/common/rnexecutorch/models/voice_activity_detection/Constants.h +27 -0
- package/common/rnexecutorch/models/voice_activity_detection/Types.h +12 -0
- package/common/rnexecutorch/models/voice_activity_detection/Utils.cpp +15 -0
- package/common/rnexecutorch/models/voice_activity_detection/Utils.h +13 -0
- package/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.cpp +160 -0
- package/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.h +36 -0
- package/common/rnexecutorch/tests/CMakeLists.txt +30 -0
- package/common/rnexecutorch/tests/NumericalTest.cpp +110 -0
- package/common/rnexecutorch/tests/README.md +30 -13
- package/common/rnexecutorch/threads/GlobalThreadPool.h +4 -0
- package/common/runner/arange_util.cpp +44 -0
- package/common/runner/arange_util.h +37 -0
- package/common/runner/constants.h +28 -0
- package/common/runner/io_manager.h +240 -0
- package/common/runner/irunner.h +87 -16
- package/common/runner/kernel_includes.h +23 -0
- package/common/runner/runner.cpp +151 -66
- package/common/runner/runner.h +39 -22
- package/common/runner/sampler.cpp +8 -1
- package/common/runner/sampler.h +4 -2
- package/common/runner/stats.h +1 -4
- package/common/runner/text_decoder_runner.cpp +26 -12
- package/common/runner/text_decoder_runner.h +52 -31
- package/common/runner/text_prefiller.cpp +46 -12
- package/common/runner/text_prefiller.h +38 -4
- package/common/runner/text_token_generator.h +51 -26
- package/common/runner/util.h +53 -8
- package/ios/RnExecutorch.xcodeproj/project.pbxproj +0 -23
- package/lib/module/Error.js +1 -0
- package/lib/module/Error.js.map +1 -1
- package/lib/module/constants/directories.js +1 -1
- package/lib/module/constants/directories.js.map +1 -1
- package/lib/module/constants/modelUrls.js +32 -1
- package/lib/module/constants/modelUrls.js.map +1 -1
- package/lib/module/constants/ocr/models.js +7 -7
- package/lib/module/constants/ocr/models.js.map +1 -1
- package/lib/module/constants/ocr/symbols.js +3 -2
- package/lib/module/constants/ocr/symbols.js.map +1 -1
- package/lib/module/controllers/LLMController.js +10 -1
- package/lib/module/controllers/LLMController.js.map +1 -1
- package/lib/module/controllers/OCRController.js +3 -3
- package/lib/module/controllers/OCRController.js.map +1 -1
- package/lib/module/controllers/VerticalOCRController.js +2 -2
- package/lib/module/controllers/VerticalOCRController.js.map +1 -1
- package/lib/module/hooks/computer_vision/useOCR.js +3 -3
- package/lib/module/hooks/computer_vision/useOCR.js.map +1 -1
- package/lib/module/hooks/{useNonStaticModule.js → computer_vision/useTextToImage.js} +21 -16
- package/lib/module/hooks/computer_vision/useTextToImage.js.map +1 -0
- package/lib/module/hooks/computer_vision/useVerticalOCR.js +3 -3
- package/lib/module/hooks/computer_vision/useVerticalOCR.js.map +1 -1
- package/lib/module/hooks/natural_language_processing/useLLM.js +3 -3
- package/lib/module/hooks/natural_language_processing/useLLM.js.map +1 -1
- package/lib/module/hooks/natural_language_processing/useTokenizer.js +5 -5
- package/lib/module/hooks/natural_language_processing/useTokenizer.js.map +1 -1
- package/lib/module/hooks/natural_language_processing/useVAD.js +13 -0
- package/lib/module/hooks/natural_language_processing/useVAD.js.map +1 -0
- package/lib/module/index.js +7 -2
- package/lib/module/index.js.map +1 -1
- package/lib/module/modules/computer_vision/OCRModule.js +2 -2
- package/lib/module/modules/computer_vision/OCRModule.js.map +1 -1
- package/lib/module/modules/computer_vision/TextToImageModule.js +48 -0
- package/lib/module/modules/computer_vision/TextToImageModule.js.map +1 -0
- package/lib/module/modules/computer_vision/VerticalOCRModule.js +2 -2
- package/lib/module/modules/computer_vision/VerticalOCRModule.js.map +1 -1
- package/lib/module/modules/natural_language_processing/SpeechToTextModule.js +7 -4
- package/lib/module/modules/natural_language_processing/SpeechToTextModule.js.map +1 -1
- package/lib/module/modules/natural_language_processing/VADModule.js +19 -0
- package/lib/module/modules/natural_language_processing/VADModule.js.map +1 -0
- package/lib/module/types/llm.js.map +1 -1
- package/lib/module/types/vad.js +2 -0
- package/lib/module/types/vad.js.map +1 -0
- package/lib/module/utils/ResourceFetcher.js +2 -1
- package/lib/module/utils/ResourceFetcher.js.map +1 -1
- package/lib/module/utils/ResourceFetcherUtils.js +6 -6
- package/lib/module/utils/ResourceFetcherUtils.js.map +1 -1
- package/lib/typescript/Error.d.ts +1 -0
- package/lib/typescript/Error.d.ts.map +1 -1
- package/lib/typescript/constants/modelUrls.d.ts +23 -0
- package/lib/typescript/constants/modelUrls.d.ts.map +1 -1
- package/lib/typescript/constants/ocr/symbols.d.ts +1 -1
- package/lib/typescript/constants/ocr/symbols.d.ts.map +1 -1
- package/lib/typescript/controllers/LLMController.d.ts.map +1 -1
- package/lib/typescript/controllers/OCRController.d.ts +1 -1
- package/lib/typescript/controllers/OCRController.d.ts.map +1 -1
- package/lib/typescript/controllers/VerticalOCRController.d.ts +1 -1
- package/lib/typescript/controllers/VerticalOCRController.d.ts.map +1 -1
- package/lib/typescript/hooks/computer_vision/useOCR.d.ts +1 -1
- package/lib/typescript/hooks/computer_vision/useOCR.d.ts.map +1 -1
- package/lib/typescript/hooks/computer_vision/useTextToImage.d.ts +22 -0
- package/lib/typescript/hooks/computer_vision/useTextToImage.d.ts.map +1 -0
- package/lib/typescript/hooks/computer_vision/useVerticalOCR.d.ts +1 -1
- package/lib/typescript/hooks/computer_vision/useVerticalOCR.d.ts.map +1 -1
- package/lib/typescript/hooks/natural_language_processing/useLLM.d.ts.map +1 -1
- package/lib/typescript/hooks/natural_language_processing/useSpeechToText.d.ts +2 -2
- package/lib/typescript/hooks/natural_language_processing/useVAD.d.ts +16 -0
- package/lib/typescript/hooks/natural_language_processing/useVAD.d.ts.map +1 -0
- package/lib/typescript/index.d.ts +8 -1
- package/lib/typescript/index.d.ts.map +1 -1
- package/lib/typescript/modules/computer_vision/OCRModule.d.ts +1 -1
- package/lib/typescript/modules/computer_vision/OCRModule.d.ts.map +1 -1
- package/lib/typescript/modules/computer_vision/TextToImageModule.d.ts +16 -0
- package/lib/typescript/modules/computer_vision/TextToImageModule.d.ts.map +1 -0
- package/lib/typescript/modules/computer_vision/VerticalOCRModule.d.ts +1 -1
- package/lib/typescript/modules/computer_vision/VerticalOCRModule.d.ts.map +1 -1
- package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts +3 -2
- package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts.map +1 -1
- package/lib/typescript/modules/natural_language_processing/VADModule.d.ts +10 -0
- package/lib/typescript/modules/natural_language_processing/VADModule.d.ts.map +1 -0
- package/lib/typescript/types/llm.d.ts +2 -0
- package/lib/typescript/types/llm.d.ts.map +1 -1
- package/lib/typescript/types/vad.d.ts +5 -0
- package/lib/typescript/types/vad.d.ts.map +1 -0
- package/lib/typescript/utils/ResourceFetcher.d.ts +29 -0
- package/lib/typescript/utils/ResourceFetcher.d.ts.map +1 -1
- package/lib/typescript/utils/ResourceFetcherUtils.d.ts +2 -2
- package/lib/typescript/utils/ResourceFetcherUtils.d.ts.map +1 -1
- package/package.json +11 -8
- package/react-native-executorch.podspec +9 -9
- package/src/Error.ts +1 -0
- package/src/constants/directories.ts +1 -1
- package/src/constants/modelUrls.ts +36 -1
- package/src/constants/ocr/models.ts +7 -7
- package/src/constants/ocr/symbols.ts +3 -2
- package/src/controllers/LLMController.ts +12 -1
- package/src/controllers/OCRController.ts +3 -3
- package/src/controllers/VerticalOCRController.ts +2 -2
- package/src/hooks/computer_vision/useOCR.ts +4 -5
- package/src/hooks/computer_vision/useTextToImage.ts +92 -0
- package/src/hooks/computer_vision/useVerticalOCR.ts +4 -5
- package/src/hooks/natural_language_processing/useLLM.ts +3 -4
- package/src/hooks/natural_language_processing/useTokenizer.ts +5 -5
- package/src/hooks/natural_language_processing/useVAD.ts +15 -0
- package/src/index.ts +20 -1
- package/src/modules/computer_vision/OCRModule.ts +2 -2
- package/src/modules/computer_vision/TextToImageModule.ts +93 -0
- package/src/modules/computer_vision/VerticalOCRModule.ts +2 -2
- package/src/modules/natural_language_processing/SpeechToTextModule.ts +8 -4
- package/src/modules/natural_language_processing/VADModule.ts +27 -0
- package/src/types/llm.ts +2 -0
- package/src/types/vad.ts +4 -0
- package/src/utils/ResourceFetcher.ts +2 -1
- package/src/utils/ResourceFetcherUtils.ts +8 -8
- package/third-party/android/libs/cpuinfo/arm64-v8a/libcpuinfo.so +0 -0
- package/third-party/android/libs/executorch/arm64-v8a/libexecutorch.so +0 -0
- package/third-party/android/libs/executorch/x86_64/libexecutorch.so +0 -0
- package/third-party/android/libs/pthreadpool/arm64-v8a/libpthreadpool.so +0 -0
- package/third-party/include/c10/macros/Export.h +0 -78
- package/third-party/include/c10/macros/Macros.h +1 -520
- package/third-party/include/c10/util/BFloat16-inl.h +1 -339
- package/third-party/include/c10/util/BFloat16.h +1 -122
- package/third-party/include/c10/util/Half-inl.h +1 -347
- package/third-party/include/c10/util/Half.h +6 -419
- package/third-party/include/c10/util/TypeSafeSignMath.h +1 -133
- package/third-party/include/c10/util/bit_cast.h +1 -43
- package/third-party/include/c10/util/complex.h +1 -568
- package/third-party/include/c10/util/floating_point_utils.h +1 -33
- package/third-party/include/c10/util/irange.h +1 -1
- package/third-party/include/c10/util/llvmMathExtras.h +866 -0
- package/third-party/include/c10/util/safe_numerics.h +97 -0
- package/third-party/include/executorch/ExecuTorchError.h +6 -7
- package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLM.h +12 -0
- package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMConfig.h +56 -0
- package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMError.h +16 -0
- package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMMultimodalRunner.h +227 -0
- package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMTextRunner.h +97 -0
- package/third-party/include/executorch/ExecuTorchLLM/module.modulemap +4 -0
- package/third-party/include/executorch/ExecuTorchLog.h +1 -0
- package/third-party/include/executorch/ExecuTorchModule.h +177 -4
- package/third-party/include/executorch/ExecuTorchTensor.h +3 -4
- package/third-party/include/executorch/ExecuTorchValue.h +1 -7
- package/third-party/include/executorch/extension/module/module.h +139 -8
- package/third-party/include/executorch/extension/tensor/tensor.h +1 -0
- package/third-party/include/executorch/extension/tensor/tensor_ptr.h +88 -26
- package/third-party/include/executorch/extension/threadpool/threadpool.h +4 -1
- package/third-party/include/executorch/runtime/backend/backend_init_context.h +6 -0
- package/third-party/include/executorch/runtime/backend/interface.h +1 -1
- package/third-party/include/executorch/runtime/core/error.h +76 -49
- package/third-party/include/executorch/runtime/core/exec_aten/util/scalar_type_util.h +18 -4
- package/third-party/include/executorch/runtime/core/memory_allocator.h +12 -2
- package/third-party/include/executorch/runtime/core/named_data_map.h +1 -11
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/macros/Export.h +0 -78
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/macros/Macros.h +1 -520
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/BFloat16-inl.h +1 -339
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/BFloat16.h +1 -122
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/Half-inl.h +1 -347
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/Half.h +6 -419
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/TypeSafeSignMath.h +1 -133
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/bit_cast.h +1 -43
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/complex.h +1 -568
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/floating_point_utils.h +1 -33
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/irange.h +1 -1
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/llvmMathExtras.h +866 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/safe_numerics.h +97 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/macros/Export.h +66 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/macros/Macros.h +553 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/BFloat16.h +477 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/Half.h +781 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/TypeSafeSignMath.h +141 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/bit_cast.h +49 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/complex.h +593 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/floating_point_utils.h +38 -0
- package/third-party/include/executorch/runtime/core/tensor_layout.h +1 -1
- package/third-party/include/executorch/runtime/executor/merged_data_map.h +142 -0
- package/third-party/include/executorch/runtime/executor/method.h +21 -8
- package/third-party/include/executorch/runtime/executor/method_meta.h +20 -2
- package/third-party/include/executorch/runtime/executor/program.h +0 -10
- package/third-party/include/executorch/runtime/kernel/operator_registry.h +1 -1
- package/third-party/include/executorch/runtime/platform/compiler.h +2 -0
- package/third-party/include/executorch/schema/extended_header.h +10 -1
- package/third-party/include/torch/headeronly/macros/Export.h +66 -0
- package/third-party/include/torch/headeronly/macros/Macros.h +553 -0
- package/third-party/include/torch/headeronly/util/BFloat16.h +477 -0
- package/third-party/include/torch/headeronly/util/Half.h +781 -0
- package/third-party/include/torch/headeronly/util/TypeSafeSignMath.h +141 -0
- package/third-party/include/torch/headeronly/util/bit_cast.h +49 -0
- package/third-party/include/torch/headeronly/util/complex.h +593 -0
- package/third-party/include/torch/headeronly/util/floating_point_utils.h +38 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/ExecutorchLib +0 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/Info.plist +0 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/ExecutorchLib +0 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/Info.plist +0 -0
- package/common/rnexecutorch/tests/run_all_tests.sh +0 -14
- package/common/rnexecutorch/tests/run_test.sh +0 -18
- package/ios/RnExecutorch/utils/Conversions.h +0 -14
- package/ios/RnExecutorch/utils/ETError.h +0 -26
- package/ios/RnExecutorch/utils/ImageProcessor.h +0 -15
- package/ios/RnExecutorch/utils/ImageProcessor.mm +0 -147
- package/ios/RnExecutorch/utils/Numerical.h +0 -3
- package/ios/RnExecutorch/utils/Numerical.mm +0 -18
- package/ios/RnExecutorch/utils/ScalarType.h +0 -14
- package/ios/RnExecutorch/utils/ScalarType.mm +0 -21
- package/lib/module/hooks/useNonStaticModule.js.map +0 -1
- package/lib/typescript/hooks/useNonStaticModule.d.ts +0 -21
- package/lib/typescript/hooks/useNonStaticModule.d.ts.map +0 -1
- package/src/hooks/useNonStaticModule.ts +0 -74
- package/third-party/include/executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h +0 -181
- package/third-party/include/executorch/extension/kernel_util/meta_programming.h +0 -108
- package/third-party/include/executorch/extension/kernel_util/type_list.h +0 -137
- package/third-party/include/executorch/extension/threadpool/threadpool_guard.h +0 -35
package/common/rnexecutorch/models/BaseModel.cpp

@@ -8,13 +8,14 @@ namespace rnexecutorch::models {
 
 using namespace facebook;
 using namespace executorch::extension;
+using ::executorch::extension::module::Module;
 using ::executorch::runtime::Error;
 
 BaseModel::BaseModel(const std::string &modelSource,
-                     std::shared_ptr<react::CallInvoker> callInvoker)
+                     std::shared_ptr<react::CallInvoker> callInvoker,
+                     Module::LoadMode loadMode)
     : callInvoker(callInvoker),
-      module_(std::make_unique<Module>(
-          modelSource, Module::LoadMode::MmapUseMlockIgnoreErrors)) {
+      module_(std::make_unique<Module>(modelSource, loadMode)) {
   Error loadError = module_->load();
   if (loadError != Error::Ok) {
     throw std::runtime_error("Failed to load model: Error " +

@@ -29,7 +30,7 @@ BaseModel::BaseModel(const std::string &modelSource,
 }
 
 std::vector<int32_t> BaseModel::getInputShape(std::string method_name,
-                                              int32_t index) {
+                                              int32_t index) const {
   if (!module_) {
     throw std::runtime_error("Model not loaded: Cannot get input shape");
   }

@@ -55,7 +56,7 @@ std::vector<int32_t> BaseModel::getInputShape(std::string method_name,
 }
 
 std::vector<std::vector<int32_t>>
-BaseModel::getAllInputShapes(std::string methodName) {
+BaseModel::getAllInputShapes(std::string methodName) const {
   if (!module_) {
     throw std::runtime_error("Model not loaded: Cannot get all input shapes");
   }

@@ -87,7 +88,7 @@ BaseModel::getAllInputShapes(std::string methodName) {
 /// to JS. It is not meant to be used within C++. If you want to call forward
 /// from C++ on a BaseModel, please use BaseModel::forward.
 std::vector<JSTensorViewOut>
-BaseModel::forwardJS(std::vector<JSTensorViewIn> tensorViewVec) {
+BaseModel::forwardJS(std::vector<JSTensorViewIn> tensorViewVec) const {
   if (!module_) {
     throw std::runtime_error("Model not loaded: Cannot perform forward pass");
   }

@@ -126,8 +127,8 @@ BaseModel::forwardJS(std::vector<JSTensorViewIn> tensorViewVec) {
     auto &outputTensor = outputs[i].toTensor();
     std::vector<int32_t> sizes = getTensorShape(outputTensor);
     size_t bufferSize = outputTensor.numel() * outputTensor.element_size();
-    auto buffer = std::make_shared<OwningArrayBuffer>(
-
+    auto buffer = std::make_shared<OwningArrayBuffer>(
+        outputTensor.const_data_ptr(), bufferSize);
     auto jsTensor = JSTensorViewOut(sizes, outputTensor.scalar_type(), buffer);
     output.emplace_back(jsTensor);
   }

@@ -135,7 +136,7 @@ BaseModel::forwardJS(std::vector<JSTensorViewIn> tensorViewVec) {
 }
 
 Result<executorch::runtime::MethodMeta>
-BaseModel::getMethodMeta(const std::string &methodName) {
+BaseModel::getMethodMeta(const std::string &methodName) const {
   if (!module_) {
     throw std::runtime_error("Model not loaded: Cannot get method meta!");
   }

@@ -160,7 +161,7 @@ BaseModel::forward(const std::vector<EValue> &input_evalues) const {
 
 Result<std::vector<EValue>>
 BaseModel::execute(const std::string &methodName,
-                   const std::vector<EValue> &input_value) {
+                   const std::vector<EValue> &input_value) const {
   if (!module_) {
     throw std::runtime_error("Model not loaded, cannot run execute.");
   }

@@ -174,7 +175,7 @@ std::size_t BaseModel::getMemoryLowerBound() const noexcept {
 void BaseModel::unload() noexcept { module_.reset(nullptr); }
 
 std::vector<int32_t>
-BaseModel::getTensorShape(const executorch::aten::Tensor &tensor) {
+BaseModel::getTensorShape(const executorch::aten::Tensor &tensor) const {
   auto sizes = tensor.sizes();
   return std::vector<int32_t>(sizes.begin(), sizes.end());
 }
package/common/rnexecutorch/models/BaseModel.h

@@ -13,26 +13,32 @@
 namespace rnexecutorch {
 namespace models {
 using namespace facebook;
+using executorch::extension::module::Module;
 using executorch::runtime::EValue;
 using executorch::runtime::Result;
+
 class BaseModel {
 public:
-  BaseModel(const std::string &modelSource,
-            std::shared_ptr<react::CallInvoker> callInvoker);
+  BaseModel(
+      const std::string &modelSource,
+      std::shared_ptr<react::CallInvoker> callInvoker,
+      Module::LoadMode loadMode = Module::LoadMode::MmapUseMlockIgnoreErrors);
   std::size_t getMemoryLowerBound() const noexcept;
   void unload() noexcept;
-  std::vector<int32_t> getInputShape(std::string method_name, int32_t index);
+  std::vector<int32_t> getInputShape(std::string method_name,
+                                     int32_t index) const;
   std::vector<std::vector<int32_t>>
-  getAllInputShapes(std::string methodName = "forward");
+  getAllInputShapes(std::string methodName = "forward") const;
   std::vector<JSTensorViewOut>
-  forwardJS(std::vector<JSTensorViewIn> tensorViewVec);
+  forwardJS(std::vector<JSTensorViewIn> tensorViewVec) const;
   Result<std::vector<EValue>> forward(const EValue &input_value) const;
   Result<std::vector<EValue>>
   forward(const std::vector<EValue> &input_value) const;
-  Result<std::vector<EValue>> execute(const std::string &methodName,
-                                      const std::vector<EValue> &input_value);
+  Result<std::vector<EValue>>
+  execute(const std::string &methodName,
+          const std::vector<EValue> &input_value) const;
   Result<executorch::runtime::MethodMeta>
-  getMethodMeta(const std::string &methodName);
+  getMethodMeta(const std::string &methodName) const;
 
 protected:
   // If possible, models should not use the JS runtime to keep JSI internals

@@ -42,9 +48,11 @@ protected:
   std::shared_ptr<react::CallInvoker> callInvoker;
   std::unique_ptr<executorch::extension::Module> module_;
 
-private:
   std::size_t memorySizeLowerBound{0};
-  std::vector<int32_t> getTensorShape(const executorch::aten::Tensor &tensor);
+
+private:
+  std::vector<int32_t>
+  getTensorShape(const executorch::aten::Tensor &tensor) const;
 };
 } // namespace models
 
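The BaseModel constructor now accepts an optional Module::LoadMode (defaulting to the previous MmapUseMlockIgnoreErrors behaviour), and its accessors are const-qualified. As an illustration only, a hypothetical subclass (not part of the package) could opt into loading the whole program from file, the same mode LLM.cpp passes further down:

```cpp
// Hypothetical subclass sketch: shows the new optional third constructor
// argument of BaseModel. LoadMode::File mirrors what LLM.cpp now passes;
// omitting the argument keeps the previous mmap behaviour.
#include <memory>
#include <string>

#include <ReactCommon/CallInvoker.h>
#include <rnexecutorch/models/BaseModel.h>

namespace rnexecutorch::models {

class FileLoadedModel : public BaseModel {
public:
  FileLoadedModel(const std::string &modelSource,
                  std::shared_ptr<facebook::react::CallInvoker> callInvoker)
      // Read the whole .pte program into memory instead of mmap-ing it.
      : BaseModel(modelSource, callInvoker, Module::LoadMode::File) {}
};

} // namespace rnexecutorch::models
```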
package/common/rnexecutorch/models/embeddings/BaseEmbeddings.cpp

@@ -11,17 +11,9 @@ BaseEmbeddings::BaseEmbeddings(const std::string &modelSource,
 std::shared_ptr<OwningArrayBuffer>
 BaseEmbeddings::postprocess(const Result<std::vector<EValue>> &forwardResult) {
   auto forwardResultTensor = forwardResult->at(0).toTensor();
-  auto
-
-
-  std::span<float> modelOutput(static_cast<float *>(dataPtr), outputNumel);
-
-  auto createBuffer = [](const auto &data, size_t size) {
-    auto buffer = std::make_shared<OwningArrayBuffer>(size);
-    std::memcpy(buffer->data(), data, size);
-    return buffer;
-  };
-  return createBuffer(modelOutput.data(), modelOutput.size_bytes());
+  auto buffer = std::make_shared<OwningArrayBuffer>(
+      forwardResultTensor.const_data_ptr(), forwardResultTensor.nbytes());
+  return buffer;
 }
 
 } // namespace rnexecutorch::models::embeddings
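The same clean-up recurs in BaseModel::forwardJS, ImageSegmentation and SpeechToText below: the allocate-then-memcpy sequence is replaced by OwningArrayBuffer constructors that copy directly from a raw pointer or a vector (OwningArrayBuffer.h, +19 -2 in the file list). A rough, self-contained sketch of what such copying constructors can look like; the actual header may differ:

```cpp
// Sketch only: mirrors the constructor shapes implied by the call sites in
// this diff (size-only, raw pointer + byte count, whole-vector copy). The
// real rnexecutorch/jsi/OwningArrayBuffer.h may differ.
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

class OwningArrayBufferSketch {
public:
  // Allocate `size` zero-initialized bytes (previous behaviour).
  explicit OwningArrayBufferSketch(size_t size) : storage_(size) {}

  // Copy `size` bytes from an existing buffer in one step.
  OwningArrayBufferSketch(const void *data, size_t size) : storage_(size) {
    std::memcpy(storage_.data(), data, size);
  }

  // Copy a whole vector, as SpeechToText::encode/decode now do.
  template <typename T>
  explicit OwningArrayBufferSketch(const std::vector<T> &vec)
      : OwningArrayBufferSketch(vec.data(), vec.size() * sizeof(T)) {}

  uint8_t *data() { return storage_.data(); }
  size_t size() const { return storage_.size(); }

private:
  std::vector<uint8_t> storage_;
};
```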
package/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp

@@ -48,7 +48,6 @@ TextEmbeddings::generate(const std::string input) {
       attnMaskShape, preprocessed.attentionMask.data(), ScalarType::Long);
 
   auto forwardResult = BaseModel::forward({tokenIds, attnMask});
-
   if (!forwardResult.ok()) {
     throw std::runtime_error(
         "Function forward in TextEmbeddings failed with error code: " +
package/common/rnexecutorch/models/image_segmentation/ImageSegmentation.cpp

@@ -62,11 +62,9 @@ std::shared_ptr<jsi::Object> ImageSegmentation::postprocess(
   std::vector<std::shared_ptr<OwningArrayBuffer>> resultClasses;
   resultClasses.reserve(numClasses);
   for (std::size_t cl = 0; cl < numClasses; ++cl) {
-    auto classBuffer =
-        std::make_shared<OwningArrayBuffer>(numModelPixels * sizeof(float));
+    auto classBuffer = std::make_shared<OwningArrayBuffer>(
+        &resultData[cl * numModelPixels], numModelPixels * sizeof(float));
     resultClasses.push_back(classBuffer);
-    std::memcpy(classBuffer->data(), &resultData[cl * numModelPixels],
-                numModelPixels * sizeof(float));
   }
 
   // Apply softmax per each pixel across all classes

@@ -112,18 +110,14 @@ std::shared_ptr<jsi::Object> ImageSegmentation::postprocess(
     cv::Mat argmaxMat(modelImageSize, CV_32SC1, argmax->data());
     cv::resize(argmaxMat, argmaxMat, originalSize, 0, 0,
                cv::InterpolationFlags::INTER_NEAREST);
-    argmax = std::make_shared<OwningArrayBuffer>(
-        originalSize.area() * sizeof(int32_t));
-    std::memcpy(argmax->data(), argmaxMat.data,
-                originalSize.area() * sizeof(int32_t));
+    argmax = std::make_shared<OwningArrayBuffer>(
+        argmaxMat.data, originalSize.area() * sizeof(int32_t));
 
     for (auto &[label, arrayBuffer] : *buffersToReturn) {
       cv::Mat classMat(modelImageSize, CV_32FC1, arrayBuffer->data());
       cv::resize(classMat, classMat, originalSize);
-      arrayBuffer = std::make_shared<OwningArrayBuffer>(
-          originalSize.area() * sizeof(float));
-      std::memcpy(arrayBuffer->data(), classMat.data,
-                  originalSize.area() * sizeof(float));
+      arrayBuffer = std::make_shared<OwningArrayBuffer>(
+          classMat.data, originalSize.area() * sizeof(float));
     }
   }
   return populateDictionary(argmax, buffersToReturn);
package/common/rnexecutorch/models/llm/LLM.cpp

@@ -1,30 +1,33 @@
 #include "LLM.h"
 
-#include <atomic>
 #include <executorch/extension/tensor/tensor.h>
 #include <filesystem>
 #include <rnexecutorch/threads/GlobalThreadPool.h>
 
 namespace rnexecutorch::models::llm {
+namespace llm = ::executorch::extension::llm;
+namespace fs = std::filesystem;
 using namespace facebook;
 using executorch::extension::TensorPtr;
+using executorch::extension::module::Module;
 using executorch::runtime::Error;
 
 LLM::LLM(const std::string &modelSource, const std::string &tokenizerSource,
          std::shared_ptr<react::CallInvoker> callInvoker)
-    :
-
-
+    : BaseModel(modelSource, callInvoker, Module::LoadMode::File),
+      runner(
+          std::make_unique<example::Runner>(module_.get(), tokenizerSource)) {
   auto loadResult = runner->load();
   if (loadResult != Error::Ok) {
     throw std::runtime_error("Failed to load LLM runner, error code: " +
                              std::to_string(static_cast<int>(loadResult)));
   }
-
-
-
+
+  memorySizeLowerBound = fs::file_size(fs::path(modelSource)) +
+                         fs::file_size(fs::path(tokenizerSource));
 }
 
+// TODO: add a way to manipulate the generation config with params
 void LLM::generate(std::string input, std::shared_ptr<jsi::Function> callback) {
   if (!runner || !runner->is_loaded()) {
     throw std::runtime_error("Runner is not loaded");

@@ -37,7 +40,8 @@ void LLM::generate(std::string input, std::shared_ptr<jsi::Function> callback) {
     });
   };
 
-  auto
+  auto config = llm::GenerationConfig{.echo = false, .warming = false};
+  auto error = runner->generate(input, config, nativeCallback, {});
   if (error != executorch::runtime::Error::Ok) {
     throw std::runtime_error("Failed to generate text, error code: " +
                              std::to_string(static_cast<int>(error)));

@@ -76,6 +80,19 @@ void LLM::setTimeInterval(size_t timeInterval) {
   runner->set_time_interval(timeInterval);
 }
 
+void LLM::setTemperature(float temperature) {
+  if (!runner || !runner->is_loaded()) {
+    throw std::runtime_error("Can't configure a model that's not loaded!");
+  }
+  runner->set_temperature(temperature);
+};
+
+void LLM::setTopp(float topp) {
+  if (!runner || !runner->is_loaded()) {
+    throw std::runtime_error("Can't configure a model that's not loaded!");
+  }
+  runner->set_topp(topp);
+}
 void LLM::unload() noexcept { runner.reset(nullptr); }
 
 } // namespace rnexecutorch::models::llm
package/common/rnexecutorch/models/llm/LLM.h

@@ -3,16 +3,16 @@
 #include <memory>
 #include <string>
 
-#include "rnexecutorch/metaprogramming/ConstructorHelpers.h"
 #include <ReactCommon/CallInvoker.h>
 #include <jsi/jsi.h>
+#include <rnexecutorch/models/BaseModel.h>
 #include <runner/runner.h>
 
 namespace rnexecutorch {
 namespace models::llm {
 using namespace facebook;
 
-class LLM {
+class LLM : public BaseModel {
 public:
   explicit LLM(const std::string &modelSource,
                const std::string &tokenizerSource,

@@ -24,12 +24,12 @@ public:
   size_t getGeneratedTokenCount() const noexcept;
   size_t getMemoryLowerBound() const noexcept;
   void setCountInterval(size_t countInterval);
+  void setTemperature(float temperature);
+  void setTopp(float topp);
   void setTimeInterval(size_t timeInterval);
 
 private:
-  size_t memorySizeLowerBound;
   std::unique_ptr<example::Runner> runner;
-  std::shared_ptr<react::CallInvoker> callInvoker;
 };
 } // namespace models::llm
 
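The new setTemperature/setTopp methods expose sampling parameters of the bundled example::Runner (cf. package/common/runner/sampler.cpp, +8 -1 in the file list). For context only, a standalone sketch of what temperature scaling and top-p (nucleus) sampling do; this is not the runner's actual sampler:

```cpp
// Illustrative temperature + top-p sampler, independent of the package.
// Assumes temperature > 0 and 0 < topp <= 1.
#include <algorithm>
#include <cmath>
#include <numeric>
#include <random>
#include <vector>

int sampleTopP(const std::vector<float> &logits, float temperature, float topp,
               std::mt19937 &rng) {
  // Temperature: divide logits before softmax; higher => flatter distribution.
  std::vector<float> probs(logits.size());
  const float maxLogit = *std::max_element(logits.begin(), logits.end());
  float sum = 0.f;
  for (size_t i = 0; i < logits.size(); ++i) {
    probs[i] = std::exp((logits[i] - maxLogit) / temperature);
    sum += probs[i];
  }
  for (auto &p : probs)
    p /= sum;

  // Top-p: keep the smallest set of most-likely tokens whose cumulative
  // probability reaches `topp`, then sample only from that set.
  std::vector<size_t> idx(probs.size());
  std::iota(idx.begin(), idx.end(), 0);
  std::sort(idx.begin(), idx.end(),
            [&](size_t a, size_t b) { return probs[a] > probs[b]; });
  float cum = 0.f;
  size_t cutoff = 0;
  while (cutoff < idx.size() && cum < topp)
    cum += probs[idx[cutoff++]];

  std::uniform_real_distribution<float> dist(0.f, cum);
  float r = dist(rng), acc = 0.f;
  for (size_t k = 0; k < cutoff; ++k) {
    acc += probs[idx[k]];
    if (r <= acc)
      return static_cast<int>(idx[k]);
  }
  return static_cast<int>(idx[cutoff - 1]);
}
```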
package/common/rnexecutorch/models/ocr/utils/RecognitionHandlerUtils.cpp

@@ -60,16 +60,19 @@ cv::Mat cropImage(types::DetectorBBox box, cv::Mat &image,
   cv::warpAffine(image, rotatedImage, rotationMatrix, image.size(),
                  cv::INTER_LINEAR);
 
-
+  constexpr int32_t rows = 4;
+  constexpr int32_t cols = 2;
+  cv::Mat rectMat(rows, cols, CV_32FC2);
 #pragma unroll
-  for (int32_t i = 0; i <
+  for (int32_t i = 0; i < rows; ++i) {
     rectMat.at<cv::Vec2f>(i, 0) = cv::Vec2f(rectPoints[i].x, rectPoints[i].y);
   }
   cv::transform(rectMat, rectMat, rotationMatrix);
 
-
+  constexpr size_t transformedPointsSize = 4;
+  std::vector<cv::Point2f> transformedPoints(transformedPointsSize);
 #pragma unroll
-  for (std::size_t i = 0; i <
+  for (std::size_t i = 0; i < transformedPointsSize; ++i) {
     cv::Vec2f point = rectMat.at<cv::Vec2f>(i, 0);
     transformedPoints[i] = cv::Point2f(point[0], point[1]);
   }
package/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp

@@ -23,17 +23,22 @@ SpeechToText::SpeechToText(const std::string &encoderSource,
       processor(std::make_unique<OnlineASRProcessor>(this->asr.get())),
       isStreaming(false), readyToProcess(false) {}
 
+void SpeechToText::unload() noexcept {
+  this->encoder->unload();
+  this->decoder->unload();
+}
+
 std::shared_ptr<OwningArrayBuffer>
 SpeechToText::encode(std::span<float> waveform) const {
   std::vector<float> encoderOutput = this->asr->encode(waveform);
-  return
+  return std::make_shared<OwningArrayBuffer>(encoderOutput);
 }
 
 std::shared_ptr<OwningArrayBuffer>
 SpeechToText::decode(std::span<int32_t> tokens,
                      std::span<float> encoderOutput) const {
   std::vector<float> decoderOutput = this->asr->decode(tokens, encoderOutput);
-  return
+  return std::make_shared<OwningArrayBuffer>(decoderOutput);
 }
 
 std::vector<char> SpeechToText::transcribe(std::span<float> waveform,

@@ -61,17 +66,7 @@ std::vector<char> SpeechToText::transcribe(std::span<float> waveform,
 
 size_t SpeechToText::getMemoryLowerBound() const noexcept {
   return this->encoder->getMemoryLowerBound() +
-         this->decoder->getMemoryLowerBound() +
-         this->tokenizer->getMemoryLowerBound();
-}
-
-std::shared_ptr<OwningArrayBuffer>
-SpeechToText::makeOwningBuffer(std::span<const float> vectorView) const {
-  auto owningArrayBuffer =
-      std::make_shared<OwningArrayBuffer>(vectorView.size_bytes());
-  std::memcpy(owningArrayBuffer->data(), vectorView.data(),
-              vectorView.size_bytes());
-  return owningArrayBuffer;
+         this->decoder->getMemoryLowerBound();
 }
 
 void SpeechToText::stream(std::shared_ptr<jsi::Function> callback,
package/common/rnexecutorch/models/speech_to_text/SpeechToText.h

@@ -16,6 +16,7 @@ public:
                const std::string &tokenizerSource,
                std::shared_ptr<react::CallInvoker> callInvoker);
 
+  void unload() noexcept;
   std::shared_ptr<OwningArrayBuffer> encode(std::span<float> waveform) const;
   std::shared_ptr<OwningArrayBuffer>
   decode(std::span<int32_t> tokens, std::span<float> encoderOutput) const;

@@ -37,9 +38,6 @@ private:
   std::unique_ptr<TokenizerModule> tokenizer;
   std::unique_ptr<asr::ASR> asr;
 
-  std::shared_ptr<OwningArrayBuffer>
-  makeOwningBuffer(std::span<const float> vectorView) const;
-
   // Stream
   std::unique_ptr<stream::OnlineASRProcessor> processor;
   bool isStreaming;
package/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp

@@ -4,7 +4,6 @@
 #include "ASR.h"
 #include "executorch/extension/tensor/tensor_ptr.h"
 #include "rnexecutorch/data_processing/Numerical.h"
-#include "rnexecutorch/data_processing/dsp.h"
 #include "rnexecutorch/data_processing/gzip.h"
 
 namespace rnexecutorch::models::speech_to_text::asr {

@@ -37,8 +36,7 @@ ASR::getInitialSequence(const DecodingOptions &options) const {
   return seq;
 }
 
-GenerationResult ASR::generate(std::span<const float> waveform,
-                               float temperature,
+GenerationResult ASR::generate(std::span<float> waveform, float temperature,
                                const DecodingOptions &options) const {
   std::vector<float> encoderOutput = this->encode(waveform);
 

@@ -94,7 +92,7 @@ float ASR::getCompressionRatio(const std::string &text) const {
 }
 
 std::vector<Segment>
-ASR::generateWithFallback(std::span<const float> waveform,
+ASR::generateWithFallback(std::span<float> waveform,
                           const DecodingOptions &options) const {
   std::vector<float> temperatures = {0.0f, 0.2f, 0.4f, 0.6f, 0.8f, 1.0f};
   std::vector<int32_t> bestTokens;

@@ -209,7 +207,7 @@ ASR::estimateWordLevelTimestampsLinear(std::span<const int32_t> tokens,
   return wordObjs;
 }
 
-std::vector<Segment> ASR::transcribe(std::span<const float> waveform,
+std::vector<Segment> ASR::transcribe(std::span<float> waveform,
                                      const DecodingOptions &options) const {
   int32_t seek = 0;
   std::vector<Segment> results;

@@ -218,7 +216,7 @@ std::vector<Segment> ASR::transcribe(std::span<const float> waveform,
     int32_t start = seek * ASR::kSamplingRate;
     const auto end = std::min<int32_t>(
         (seek + ASR::kChunkSize) * ASR::kSamplingRate, waveform.size());
-
+    auto chunk = waveform.subspan(start, end - start);
 
     if (std::cmp_less(chunk.size(), ASR::kMinChunkSamples)) {
       break;

@@ -246,19 +244,12 @@ std::vector<Segment> ASR::transcribe(std::span<const float> waveform,
   return results;
 }
 
-std::vector<float> ASR::encode(std::span<const float> waveform) const {
-
-  constexpr int32_t stftHopLength = 160;
-  constexpr int32_t innerDim = 256;
-
-  std::vector<float> preprocessedData =
-      dsp::stftFromWaveform(waveform, fftWindowSize, stftHopLength);
-  const auto numFrames =
-      static_cast<int32_t>(preprocessedData.size()) / innerDim;
-  std::vector<int32_t> inputShape = {numFrames, innerDim};
+std::vector<float> ASR::encode(std::span<float> waveform) const {
+  auto inputShape = {static_cast<int32_t>(waveform.size())};
 
   const auto modelInputTensor = executorch::extension::make_tensor_ptr(
-      std::move(inputShape),
+      std::move(inputShape), waveform.data(),
+      executorch::runtime::etensor::ScalarType::Float);
   const auto encoderResult = this->encoder->forward(modelInputTensor);
 
   if (!encoderResult.ok()) {

@@ -268,7 +259,7 @@ std::vector<float> ASR::encode(std::span<const float> waveform) const {
   }
 
   const auto decoderOutputTensor = encoderResult.get().at(0).toTensor();
-  const
+  const auto outputNumel = decoderOutputTensor.numel();
 
   const float *const dataPtr = decoderOutputTensor.const_data_ptr<float>();
   return {dataPtr, dataPtr + outputNumel};

@@ -277,8 +268,10 @@ std::vector<float> ASR::encode(std::span<const float> waveform) const {
 std::vector<float> ASR::decode(std::span<int32_t> tokens,
                                std::span<float> encoderOutput) const {
   std::vector<int32_t> tokenShape = {1, static_cast<int32_t>(tokens.size())};
+  auto tokensLong = std::vector<int64_t>(tokens.begin(), tokens.end());
+
   auto tokenTensor = executorch::extension::make_tensor_ptr(
-
+      tokenShape, tokensLong.data(), ScalarType::Long);
 
   const auto encoderOutputSize = static_cast<int32_t>(encoderOutput.size());
   std::vector<int32_t> encShape = {1, ASR::kNumFrames,
package/common/rnexecutorch/models/speech_to_text/asr/ASR.h

@@ -14,9 +14,9 @@ public:
       const models::BaseModel *decoder,
       const TokenizerModule *tokenizer);
   std::vector<types::Segment>
-  transcribe(std::span<const float> waveform,
+  transcribe(std::span<float> waveform,
              const types::DecodingOptions &options) const;
-  std::vector<float> encode(std::span<const float> waveform) const;
+  std::vector<float> encode(std::span<float> waveform) const;
   std::vector<float> decode(std::span<int32_t> tokens,
                             std::span<float> encoderOutput) const;
 

@@ -44,11 +44,10 @@ private:
 
   std::vector<int32_t>
   getInitialSequence(const types::DecodingOptions &options) const;
-  types::GenerationResult generate(std::span<const float> waveform,
-                                   float temperature,
+  types::GenerationResult generate(std::span<float> waveform, float temperature,
                                    const types::DecodingOptions &options) const;
   std::vector<types::Segment>
-  generateWithFallback(std::span<const float> waveform,
+  generateWithFallback(std::span<float> waveform,
                        const types::DecodingOptions &options) const;
   std::vector<types::Segment>
   calculateWordLevelTimestamps(std::span<const int32_t> tokens,
package/common/rnexecutorch/models/text_to_image/Decoder.cpp

@@ -0,0 +1,32 @@
+#include "Decoder.h"
+
+#include <cmath>
+
+#include <executorch/extension/tensor/tensor_ptr_maker.h>
+
+namespace rnexecutorch::models::text_to_image {
+
+using namespace executorch::extension;
+
+Decoder::Decoder(const std::string &modelSource,
+                 std::shared_ptr<react::CallInvoker> callInvoker)
+    : BaseModel(modelSource, callInvoker) {}
+
+std::vector<float> Decoder::generate(std::vector<float> &input) const {
+  std::vector<int32_t> inputShape = {1, numChannels, latentImageSize,
+                                     latentImageSize};
+  auto inputTensor =
+      make_tensor_ptr(inputShape, input.data(), ScalarType::Float);
+
+  auto forwardResult = BaseModel::forward(inputTensor);
+  if (!forwardResult.ok()) {
+    throw std::runtime_error(
+        "Function forward in decoder failed with error code: " +
+        std::to_string(static_cast<uint32_t>(forwardResult.error())));
+  }
+
+  auto forwardResultTensor = forwardResult->at(0).toTensor();
+  const auto *dataPtr = forwardResultTensor.const_data_ptr<float>();
+  return {dataPtr, dataPtr + forwardResultTensor.numel()};
+}
+} // namespace rnexecutorch::models::text_to_image
package/common/rnexecutorch/models/text_to_image/Decoder.h

@@ -0,0 +1,24 @@
+#pragma once
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include <ReactCommon/CallInvoker.h>
+
+#include <rnexecutorch/models/BaseModel.h>
+
+namespace rnexecutorch::models::text_to_image {
+
+class Decoder final : public BaseModel {
+public:
+  explicit Decoder(const std::string &modelSource,
+                   std::shared_ptr<react::CallInvoker> callInvoker);
+  std::vector<float> generate(std::vector<float> &input) const;
+
+  int32_t latentImageSize;
+
+private:
+  static constexpr int32_t numChannels = 4;
+};
+} // namespace rnexecutorch::models::text_to_image
package/common/rnexecutorch/models/text_to_image/Encoder.cpp

@@ -0,0 +1,44 @@
+#include "Encoder.h"
+
+#include <cmath>
+#include <random>
+#include <span>
+
+#include <rnexecutorch/models/text_to_image/Constants.h>
+
+namespace rnexecutorch::models::text_to_image {
+
+Encoder::Encoder(const std::string &tokenizerSource,
+                 const std::string &encoderSource,
+                 std::shared_ptr<react::CallInvoker> callInvoker)
+    : callInvoker(callInvoker),
+      encoder(std::make_unique<embeddings::TextEmbeddings>(
+          encoderSource, tokenizerSource, callInvoker)) {}
+
+std::vector<float> Encoder::generate(std::string input) {
+  std::shared_ptr<OwningArrayBuffer> embeddingsText = encoder->generate(input);
+  std::shared_ptr<OwningArrayBuffer> embeddingsUncond =
+      encoder->generate(std::string(constants::kBosToken));
+
+  assert(embeddingsText->size() == embeddingsUncond->size());
+  size_t embeddingsSize = embeddingsText->size() / sizeof(float);
+  auto *embeddingsTextPtr = reinterpret_cast<float *>(embeddingsText->data());
+  auto *embeddingsUncondPtr =
+      reinterpret_cast<float *>(embeddingsUncond->data());
+
+  std::vector<float> embeddingsConcat;
+  embeddingsConcat.reserve(embeddingsSize * 2);
+  embeddingsConcat.insert(embeddingsConcat.end(), embeddingsUncondPtr,
+                          embeddingsUncondPtr + embeddingsSize);
+  embeddingsConcat.insert(embeddingsConcat.end(), embeddingsTextPtr,
+                          embeddingsTextPtr + embeddingsSize);
+  return embeddingsConcat;
+}
+
+size_t Encoder::getMemoryLowerBound() const noexcept {
+  return encoder->getMemoryLowerBound();
+}
+
+void Encoder::unload() noexcept { encoder->unload(); }
+
+} // namespace rnexecutorch::models::text_to_image
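Encoder::generate returns the unconditional ([BOS]-only) embeddings concatenated with the prompt embeddings, the usual batch layout for classifier-free guidance in the new text-to-image pipeline (TextToImage.cpp, UNet.cpp, Scheduler.cpp in the file list). A generic sketch of how the two corresponding UNet noise predictions are typically blended; the package's TextToImage.cpp may differ in details such as how the guidance scale is chosen:

```cpp
// Generic classifier-free guidance blend: the UNet runs on a batch of two
// (unconditional first, conditional second, matching Encoder::generate's
// concatenation order) and the two noise predictions are mixed with a
// guidance scale. Illustration only, not the package's TextToImage code.
#include <cstddef>
#include <vector>

std::vector<float> applyGuidance(const std::vector<float> &noiseUncond,
                                 const std::vector<float> &noiseText,
                                 float guidanceScale) {
  std::vector<float> out(noiseUncond.size());
  for (size_t i = 0; i < out.size(); ++i) {
    // guidanceScale = 1 reproduces the conditional prediction; larger values
    // push the result further toward the prompt-conditioned direction.
    out[i] = noiseUncond[i] + guidanceScale * (noiseText[i] - noiseUncond[i]);
  }
  return out;
}
```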