react-native-executorch 0.5.15 → 0.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +42 -36
- package/android/CMakeLists.txt +13 -25
- package/android/build.gradle +2 -3
- package/android/libs/classes.jar +0 -0
- package/android/src/main/cpp/CMakeLists.txt +2 -1
- package/common/rnexecutorch/RnExecutorchInstaller.cpp +18 -0
- package/common/rnexecutorch/TokenizerModule.cpp +3 -3
- package/common/rnexecutorch/data_processing/Numerical.cpp +31 -23
- package/common/rnexecutorch/data_processing/Numerical.h +6 -1
- package/common/rnexecutorch/data_processing/dsp.cpp +0 -46
- package/common/rnexecutorch/host_objects/JsiConversions.h +16 -0
- package/common/rnexecutorch/host_objects/ModelHostObject.h +26 -11
- package/common/rnexecutorch/jsi/OwningArrayBuffer.h +19 -2
- package/common/rnexecutorch/metaprogramming/TypeConcepts.h +0 -20
- package/common/rnexecutorch/models/BaseModel.cpp +12 -11
- package/common/rnexecutorch/models/BaseModel.h +18 -10
- package/common/rnexecutorch/models/embeddings/BaseEmbeddings.cpp +3 -11
- package/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp +0 -1
- package/common/rnexecutorch/models/image_segmentation/ImageSegmentation.cpp +6 -12
- package/common/rnexecutorch/models/llm/LLM.cpp +25 -8
- package/common/rnexecutorch/models/llm/LLM.h +4 -4
- package/common/rnexecutorch/models/ocr/CTCLabelConverter.h +1 -1
- package/common/rnexecutorch/models/ocr/utils/RecognitionHandlerUtils.cpp +7 -4
- package/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp +8 -13
- package/common/rnexecutorch/models/speech_to_text/SpeechToText.h +1 -3
- package/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp +12 -19
- package/common/rnexecutorch/models/speech_to_text/asr/ASR.h +4 -5
- package/common/rnexecutorch/models/text_to_image/Constants.h +9 -0
- package/common/rnexecutorch/models/text_to_image/Decoder.cpp +32 -0
- package/common/rnexecutorch/models/text_to_image/Decoder.h +24 -0
- package/common/rnexecutorch/models/text_to_image/Encoder.cpp +44 -0
- package/common/rnexecutorch/models/text_to_image/Encoder.h +32 -0
- package/common/rnexecutorch/models/text_to_image/Scheduler.cpp +152 -0
- package/common/rnexecutorch/models/text_to_image/Scheduler.h +41 -0
- package/common/rnexecutorch/models/text_to_image/TextToImage.cpp +141 -0
- package/common/rnexecutorch/models/text_to_image/TextToImage.h +64 -0
- package/common/rnexecutorch/models/text_to_image/UNet.cpp +38 -0
- package/common/rnexecutorch/models/text_to_image/UNet.h +28 -0
- package/common/rnexecutorch/models/voice_activity_detection/Constants.h +27 -0
- package/common/rnexecutorch/models/voice_activity_detection/Types.h +12 -0
- package/common/rnexecutorch/models/voice_activity_detection/Utils.cpp +15 -0
- package/common/rnexecutorch/models/voice_activity_detection/Utils.h +13 -0
- package/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.cpp +160 -0
- package/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.h +36 -0
- package/common/rnexecutorch/tests/CMakeLists.txt +30 -0
- package/common/rnexecutorch/tests/NumericalTest.cpp +110 -0
- package/common/rnexecutorch/tests/README.md +30 -13
- package/common/rnexecutorch/threads/GlobalThreadPool.h +4 -0
- package/common/runner/arange_util.cpp +44 -0
- package/common/runner/arange_util.h +37 -0
- package/common/runner/constants.h +28 -0
- package/common/runner/io_manager.h +240 -0
- package/common/runner/irunner.h +87 -16
- package/common/runner/kernel_includes.h +23 -0
- package/common/runner/runner.cpp +151 -66
- package/common/runner/runner.h +39 -22
- package/common/runner/sampler.cpp +8 -1
- package/common/runner/sampler.h +4 -2
- package/common/runner/stats.h +1 -4
- package/common/runner/text_decoder_runner.cpp +26 -12
- package/common/runner/text_decoder_runner.h +52 -31
- package/common/runner/text_prefiller.cpp +46 -12
- package/common/runner/text_prefiller.h +38 -4
- package/common/runner/text_token_generator.h +51 -26
- package/common/runner/util.h +53 -8
- package/ios/RnExecutorch.xcodeproj/project.pbxproj +0 -23
- package/lib/module/Error.js +1 -0
- package/lib/module/Error.js.map +1 -1
- package/lib/module/constants/directories.js +1 -1
- package/lib/module/constants/directories.js.map +1 -1
- package/lib/module/constants/modelUrls.js +32 -1
- package/lib/module/constants/modelUrls.js.map +1 -1
- package/lib/module/constants/ocr/models.js +7 -7
- package/lib/module/constants/ocr/models.js.map +1 -1
- package/lib/module/constants/ocr/symbols.js +3 -2
- package/lib/module/constants/ocr/symbols.js.map +1 -1
- package/lib/module/controllers/LLMController.js +10 -1
- package/lib/module/controllers/LLMController.js.map +1 -1
- package/lib/module/controllers/OCRController.js +3 -3
- package/lib/module/controllers/OCRController.js.map +1 -1
- package/lib/module/controllers/VerticalOCRController.js +2 -2
- package/lib/module/controllers/VerticalOCRController.js.map +1 -1
- package/lib/module/hooks/computer_vision/useOCR.js +3 -3
- package/lib/module/hooks/computer_vision/useOCR.js.map +1 -1
- package/lib/module/hooks/{useNonStaticModule.js → computer_vision/useTextToImage.js} +21 -16
- package/lib/module/hooks/computer_vision/useTextToImage.js.map +1 -0
- package/lib/module/hooks/computer_vision/useVerticalOCR.js +3 -3
- package/lib/module/hooks/computer_vision/useVerticalOCR.js.map +1 -1
- package/lib/module/hooks/natural_language_processing/useLLM.js +3 -3
- package/lib/module/hooks/natural_language_processing/useLLM.js.map +1 -1
- package/lib/module/hooks/natural_language_processing/useTokenizer.js +5 -5
- package/lib/module/hooks/natural_language_processing/useTokenizer.js.map +1 -1
- package/lib/module/hooks/natural_language_processing/useVAD.js +13 -0
- package/lib/module/hooks/natural_language_processing/useVAD.js.map +1 -0
- package/lib/module/index.js +7 -2
- package/lib/module/index.js.map +1 -1
- package/lib/module/modules/computer_vision/OCRModule.js +2 -2
- package/lib/module/modules/computer_vision/OCRModule.js.map +1 -1
- package/lib/module/modules/computer_vision/TextToImageModule.js +48 -0
- package/lib/module/modules/computer_vision/TextToImageModule.js.map +1 -0
- package/lib/module/modules/computer_vision/VerticalOCRModule.js +2 -2
- package/lib/module/modules/computer_vision/VerticalOCRModule.js.map +1 -1
- package/lib/module/modules/natural_language_processing/SpeechToTextModule.js +7 -4
- package/lib/module/modules/natural_language_processing/SpeechToTextModule.js.map +1 -1
- package/lib/module/modules/natural_language_processing/VADModule.js +19 -0
- package/lib/module/modules/natural_language_processing/VADModule.js.map +1 -0
- package/lib/module/types/llm.js.map +1 -1
- package/lib/module/types/vad.js +2 -0
- package/lib/module/types/vad.js.map +1 -0
- package/lib/module/utils/ResourceFetcher.js +2 -1
- package/lib/module/utils/ResourceFetcher.js.map +1 -1
- package/lib/module/utils/ResourceFetcherUtils.js +6 -6
- package/lib/module/utils/ResourceFetcherUtils.js.map +1 -1
- package/lib/typescript/Error.d.ts +1 -0
- package/lib/typescript/Error.d.ts.map +1 -1
- package/lib/typescript/constants/modelUrls.d.ts +23 -0
- package/lib/typescript/constants/modelUrls.d.ts.map +1 -1
- package/lib/typescript/constants/ocr/symbols.d.ts +1 -1
- package/lib/typescript/constants/ocr/symbols.d.ts.map +1 -1
- package/lib/typescript/controllers/LLMController.d.ts.map +1 -1
- package/lib/typescript/controllers/OCRController.d.ts +1 -1
- package/lib/typescript/controllers/OCRController.d.ts.map +1 -1
- package/lib/typescript/controllers/VerticalOCRController.d.ts +1 -1
- package/lib/typescript/controllers/VerticalOCRController.d.ts.map +1 -1
- package/lib/typescript/hooks/computer_vision/useOCR.d.ts +1 -1
- package/lib/typescript/hooks/computer_vision/useOCR.d.ts.map +1 -1
- package/lib/typescript/hooks/computer_vision/useTextToImage.d.ts +22 -0
- package/lib/typescript/hooks/computer_vision/useTextToImage.d.ts.map +1 -0
- package/lib/typescript/hooks/computer_vision/useVerticalOCR.d.ts +1 -1
- package/lib/typescript/hooks/computer_vision/useVerticalOCR.d.ts.map +1 -1
- package/lib/typescript/hooks/natural_language_processing/useLLM.d.ts.map +1 -1
- package/lib/typescript/hooks/natural_language_processing/useSpeechToText.d.ts +2 -2
- package/lib/typescript/hooks/natural_language_processing/useVAD.d.ts +16 -0
- package/lib/typescript/hooks/natural_language_processing/useVAD.d.ts.map +1 -0
- package/lib/typescript/index.d.ts +8 -1
- package/lib/typescript/index.d.ts.map +1 -1
- package/lib/typescript/modules/computer_vision/OCRModule.d.ts +1 -1
- package/lib/typescript/modules/computer_vision/OCRModule.d.ts.map +1 -1
- package/lib/typescript/modules/computer_vision/TextToImageModule.d.ts +16 -0
- package/lib/typescript/modules/computer_vision/TextToImageModule.d.ts.map +1 -0
- package/lib/typescript/modules/computer_vision/VerticalOCRModule.d.ts +1 -1
- package/lib/typescript/modules/computer_vision/VerticalOCRModule.d.ts.map +1 -1
- package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts +3 -2
- package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts.map +1 -1
- package/lib/typescript/modules/natural_language_processing/VADModule.d.ts +10 -0
- package/lib/typescript/modules/natural_language_processing/VADModule.d.ts.map +1 -0
- package/lib/typescript/types/llm.d.ts +2 -0
- package/lib/typescript/types/llm.d.ts.map +1 -1
- package/lib/typescript/types/vad.d.ts +5 -0
- package/lib/typescript/types/vad.d.ts.map +1 -0
- package/lib/typescript/utils/ResourceFetcher.d.ts +29 -0
- package/lib/typescript/utils/ResourceFetcher.d.ts.map +1 -1
- package/lib/typescript/utils/ResourceFetcherUtils.d.ts +2 -2
- package/lib/typescript/utils/ResourceFetcherUtils.d.ts.map +1 -1
- package/package.json +11 -8
- package/react-native-executorch.podspec +9 -9
- package/src/Error.ts +1 -0
- package/src/constants/directories.ts +1 -1
- package/src/constants/modelUrls.ts +36 -1
- package/src/constants/ocr/models.ts +7 -7
- package/src/constants/ocr/symbols.ts +3 -2
- package/src/controllers/LLMController.ts +12 -1
- package/src/controllers/OCRController.ts +3 -3
- package/src/controllers/VerticalOCRController.ts +2 -2
- package/src/hooks/computer_vision/useOCR.ts +4 -5
- package/src/hooks/computer_vision/useTextToImage.ts +92 -0
- package/src/hooks/computer_vision/useVerticalOCR.ts +4 -5
- package/src/hooks/natural_language_processing/useLLM.ts +3 -4
- package/src/hooks/natural_language_processing/useTokenizer.ts +5 -5
- package/src/hooks/natural_language_processing/useVAD.ts +15 -0
- package/src/index.ts +20 -1
- package/src/modules/computer_vision/OCRModule.ts +2 -2
- package/src/modules/computer_vision/TextToImageModule.ts +93 -0
- package/src/modules/computer_vision/VerticalOCRModule.ts +2 -2
- package/src/modules/natural_language_processing/SpeechToTextModule.ts +8 -4
- package/src/modules/natural_language_processing/VADModule.ts +27 -0
- package/src/types/llm.ts +2 -0
- package/src/types/vad.ts +4 -0
- package/src/utils/ResourceFetcher.ts +2 -1
- package/src/utils/ResourceFetcherUtils.ts +8 -8
- package/third-party/android/libs/cpuinfo/arm64-v8a/libcpuinfo.so +0 -0
- package/third-party/android/libs/executorch/arm64-v8a/libexecutorch.so +0 -0
- package/third-party/android/libs/executorch/x86_64/libexecutorch.so +0 -0
- package/third-party/android/libs/pthreadpool/arm64-v8a/libpthreadpool.so +0 -0
- package/third-party/include/c10/macros/Export.h +0 -78
- package/third-party/include/c10/macros/Macros.h +1 -520
- package/third-party/include/c10/util/BFloat16-inl.h +1 -339
- package/third-party/include/c10/util/BFloat16.h +1 -122
- package/third-party/include/c10/util/Half-inl.h +1 -347
- package/third-party/include/c10/util/Half.h +6 -419
- package/third-party/include/c10/util/TypeSafeSignMath.h +1 -133
- package/third-party/include/c10/util/bit_cast.h +1 -43
- package/third-party/include/c10/util/complex.h +1 -568
- package/third-party/include/c10/util/floating_point_utils.h +1 -33
- package/third-party/include/c10/util/irange.h +1 -1
- package/third-party/include/c10/util/llvmMathExtras.h +866 -0
- package/third-party/include/c10/util/safe_numerics.h +97 -0
- package/third-party/include/executorch/ExecuTorchError.h +6 -7
- package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLM.h +12 -0
- package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMConfig.h +56 -0
- package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMError.h +16 -0
- package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMMultimodalRunner.h +227 -0
- package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMTextRunner.h +97 -0
- package/third-party/include/executorch/ExecuTorchLLM/module.modulemap +4 -0
- package/third-party/include/executorch/ExecuTorchLog.h +1 -0
- package/third-party/include/executorch/ExecuTorchModule.h +177 -4
- package/third-party/include/executorch/ExecuTorchTensor.h +3 -4
- package/third-party/include/executorch/ExecuTorchValue.h +1 -7
- package/third-party/include/executorch/extension/module/module.h +139 -8
- package/third-party/include/executorch/extension/tensor/tensor.h +1 -0
- package/third-party/include/executorch/extension/tensor/tensor_ptr.h +88 -26
- package/third-party/include/executorch/extension/threadpool/threadpool.h +4 -1
- package/third-party/include/executorch/runtime/backend/backend_init_context.h +6 -0
- package/third-party/include/executorch/runtime/backend/interface.h +1 -1
- package/third-party/include/executorch/runtime/core/error.h +76 -49
- package/third-party/include/executorch/runtime/core/exec_aten/util/scalar_type_util.h +18 -4
- package/third-party/include/executorch/runtime/core/memory_allocator.h +12 -2
- package/third-party/include/executorch/runtime/core/named_data_map.h +1 -11
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/macros/Export.h +0 -78
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/macros/Macros.h +1 -520
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/BFloat16-inl.h +1 -339
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/BFloat16.h +1 -122
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/Half-inl.h +1 -347
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/Half.h +6 -419
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/TypeSafeSignMath.h +1 -133
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/bit_cast.h +1 -43
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/complex.h +1 -568
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/floating_point_utils.h +1 -33
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/irange.h +1 -1
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/llvmMathExtras.h +866 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/safe_numerics.h +97 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/macros/Export.h +66 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/macros/Macros.h +553 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/BFloat16.h +477 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/Half.h +781 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/TypeSafeSignMath.h +141 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/bit_cast.h +49 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/complex.h +593 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/floating_point_utils.h +38 -0
- package/third-party/include/executorch/runtime/core/tensor_layout.h +1 -1
- package/third-party/include/executorch/runtime/executor/merged_data_map.h +142 -0
- package/third-party/include/executorch/runtime/executor/method.h +21 -8
- package/third-party/include/executorch/runtime/executor/method_meta.h +20 -2
- package/third-party/include/executorch/runtime/executor/program.h +0 -10
- package/third-party/include/executorch/runtime/kernel/operator_registry.h +1 -1
- package/third-party/include/executorch/runtime/platform/compiler.h +2 -0
- package/third-party/include/executorch/schema/extended_header.h +10 -1
- package/third-party/include/torch/headeronly/macros/Export.h +66 -0
- package/third-party/include/torch/headeronly/macros/Macros.h +553 -0
- package/third-party/include/torch/headeronly/util/BFloat16.h +477 -0
- package/third-party/include/torch/headeronly/util/Half.h +781 -0
- package/third-party/include/torch/headeronly/util/TypeSafeSignMath.h +141 -0
- package/third-party/include/torch/headeronly/util/bit_cast.h +49 -0
- package/third-party/include/torch/headeronly/util/complex.h +593 -0
- package/third-party/include/torch/headeronly/util/floating_point_utils.h +38 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/ExecutorchLib +0 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/Info.plist +0 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/ExecutorchLib +0 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/Info.plist +0 -0
- package/common/rnexecutorch/tests/run_all_tests.sh +0 -14
- package/common/rnexecutorch/tests/run_test.sh +0 -18
- package/ios/RnExecutorch/utils/Conversions.h +0 -14
- package/ios/RnExecutorch/utils/ETError.h +0 -26
- package/ios/RnExecutorch/utils/ImageProcessor.h +0 -15
- package/ios/RnExecutorch/utils/ImageProcessor.mm +0 -147
- package/ios/RnExecutorch/utils/Numerical.h +0 -3
- package/ios/RnExecutorch/utils/Numerical.mm +0 -18
- package/ios/RnExecutorch/utils/ScalarType.h +0 -14
- package/ios/RnExecutorch/utils/ScalarType.mm +0 -21
- package/lib/module/hooks/useNonStaticModule.js.map +0 -1
- package/lib/typescript/hooks/useNonStaticModule.d.ts +0 -21
- package/lib/typescript/hooks/useNonStaticModule.d.ts.map +0 -1
- package/src/hooks/useNonStaticModule.ts +0 -74
- package/third-party/include/executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h +0 -181
- package/third-party/include/executorch/extension/kernel_util/meta_programming.h +0 -108
- package/third-party/include/executorch/extension/kernel_util/type_list.h +0 -137
- package/third-party/include/executorch/extension/threadpool/threadpool_guard.h +0 -35
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
|
|
3
|
+
#include <memory>
|
|
4
|
+
#include <string>
|
|
5
|
+
#include <vector>
|
|
6
|
+
|
|
7
|
+
#include <ReactCommon/CallInvoker.h>
|
|
8
|
+
#include <jsi/jsi.h>
|
|
9
|
+
|
|
10
|
+
#include <rnexecutorch/jsi/OwningArrayBuffer.h>
|
|
11
|
+
|
|
12
|
+
#include <rnexecutorch/models/embeddings/text/TextEmbeddings.h>
|
|
13
|
+
|
|
14
|
+
namespace rnexecutorch {
|
|
15
|
+
namespace models::text_to_image {
|
|
16
|
+
using namespace facebook;
|
|
17
|
+
|
|
18
|
+
class Encoder final {
|
|
19
|
+
public:
|
|
20
|
+
explicit Encoder(const std::string &tokenizerSource,
|
|
21
|
+
const std::string &encoderSource,
|
|
22
|
+
std::shared_ptr<react::CallInvoker> callInvoker);
|
|
23
|
+
std::vector<float> generate(std::string input);
|
|
24
|
+
size_t getMemoryLowerBound() const noexcept;
|
|
25
|
+
void unload() noexcept;
|
|
26
|
+
|
|
27
|
+
private:
|
|
28
|
+
std::shared_ptr<react::CallInvoker> callInvoker;
|
|
29
|
+
std::unique_ptr<embeddings::TextEmbeddings> encoder;
|
|
30
|
+
};
|
|
31
|
+
} // namespace models::text_to_image
|
|
32
|
+
} // namespace rnexecutorch
|
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
// The implementation of the PNDMScheduler class from the diffusers library
|
|
2
|
+
// https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_pndm.py
|
|
3
|
+
|
|
4
|
+
#include "Scheduler.h"
|
|
5
|
+
|
|
6
|
+
#include <algorithm>
|
|
7
|
+
#include <cmath>
|
|
8
|
+
|
|
9
|
+
namespace rnexecutorch::models::text_to_image {
|
|
10
|
+
using namespace facebook;
|
|
11
|
+
|
|
12
|
+
/// Precomputes the training noise schedule.
///
/// Betas are spaced linearly between sqrt(betaStart) and sqrt(betaEnd)
/// (inclusive) and then squared. alphas[t] = 1 - betas[t], and
/// alphasCumprod[t] is the running product alphas[0] * ... * alphas[t].
///
/// @param betaStart         beta at the first training timestep.
/// @param betaEnd           beta at the last training timestep.
/// @param numTrainTimesteps number of training timesteps in the schedule.
/// @param stepsOffset       offset added to every inference timestep.
/// @param callInvoker       currently unused by the scheduler.
Scheduler::Scheduler(
    float betaStart, float betaEnd, int32_t numTrainTimesteps,
    int32_t stepsOffset,
    [[maybe_unused]] std::shared_ptr<react::CallInvoker> callInvoker)
    : numTrainTimesteps(numTrainTimesteps), stepsOffset(stepsOffset) {
  const float start = std::sqrt(betaStart);
  const float end = std::sqrt(betaEnd);
  // A one-point linspace is just `start`; without this guard the spacing
  // would divide by zero and turn every beta/alpha into NaN.
  const float step =
      numTrainTimesteps > 1
          ? (end - start) / static_cast<float>(numTrainTimesteps - 1)
          : 0.0f;

  // betas[t] — amount of noise injected at timestep t
  betas.reserve(numTrainTimesteps);
  alphas.reserve(numTrainTimesteps);
  // alphasCumprod[t] — fraction of the signal remaining after t steps
  alphasCumprod.reserve(numTrainTimesteps);

  float runningProduct = 1.0f;
  for (int32_t i = 0; i < numTrainTimesteps; ++i) {
    const float value = start + step * static_cast<float>(i);
    const float beta = value * value;
    betas.push_back(beta);

    const float alpha = 1.0f - beta;
    alphas.push_back(alpha);

    runningProduct *= alpha;
    alphasCumprod.push_back(runningProduct);
  }

  // finalAlphaCumprod — signal at the first training step (highest
  // signal-to-noise ratio); getPrevSample() uses it whenever the previous
  // timestep falls below zero, i.e. at the end of the diffusion process.
  if (!alphasCumprod.empty()) {
    finalAlphaCumprod = alphasCumprod[0];
  }
}
|
|
44
|
+
|
|
45
|
+
void Scheduler::setTimesteps(size_t numInferenceSteps) {
|
|
46
|
+
this->numInferenceSteps = numInferenceSteps;
|
|
47
|
+
ets.clear();
|
|
48
|
+
|
|
49
|
+
if (numInferenceSteps < 2) {
|
|
50
|
+
timesteps = {1};
|
|
51
|
+
return;
|
|
52
|
+
}
|
|
53
|
+
|
|
54
|
+
timesteps.clear();
|
|
55
|
+
timesteps.reserve(numInferenceSteps + 1);
|
|
56
|
+
|
|
57
|
+
float numStepsRatio =
|
|
58
|
+
static_cast<float>(numTrainTimesteps) / numInferenceSteps;
|
|
59
|
+
for (size_t i = 0; i < numInferenceSteps; i++) {
|
|
60
|
+
const auto timestep =
|
|
61
|
+
static_cast<int32_t>(std::round(i * numStepsRatio)) + stepsOffset;
|
|
62
|
+
timesteps.push_back(timestep);
|
|
63
|
+
}
|
|
64
|
+
// Duplicate the timestep to provide enough points for the solver
|
|
65
|
+
timesteps.insert(timesteps.end() - 1, timesteps[numInferenceSteps - 2]);
|
|
66
|
+
std::ranges::reverse(timesteps);
|
|
67
|
+
}
|
|
68
|
+
|
|
69
|
+
std::vector<float> Scheduler::step(const std::vector<float> &sample,
|
|
70
|
+
const std::vector<float> &noise,
|
|
71
|
+
int32_t timestep) {
|
|
72
|
+
if (numInferenceSteps == 0) {
|
|
73
|
+
throw std::runtime_error(
|
|
74
|
+
"Number of inference steps is not set. Call `set_timesteps` first.");
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
size_t noiseSize = noise.size();
|
|
78
|
+
std::vector<float> etsOutput(noiseSize);
|
|
79
|
+
float numStepsRatio =
|
|
80
|
+
static_cast<float>(numTrainTimesteps) / numInferenceSteps;
|
|
81
|
+
float timestepPrev = timestep - numStepsRatio;
|
|
82
|
+
|
|
83
|
+
if (ets.empty()) {
|
|
84
|
+
ets.push_back(noise);
|
|
85
|
+
etsOutput = noise;
|
|
86
|
+
tempFirstSample = sample;
|
|
87
|
+
return getPrevSample(sample, etsOutput, timestep, timestepPrev);
|
|
88
|
+
}
|
|
89
|
+
|
|
90
|
+
// Use the previous sample as the estimate requires at least 2 points
|
|
91
|
+
if (ets.size() == 1 && !tempFirstSample.empty()) {
|
|
92
|
+
for (size_t i = 0; i < noiseSize; i++) {
|
|
93
|
+
etsOutput[i] = (noise[i] + ets[0][i]) / 2;
|
|
94
|
+
}
|
|
95
|
+
auto prevSample = getPrevSample(std::move(tempFirstSample), etsOutput,
|
|
96
|
+
timestep + numStepsRatio, timestep);
|
|
97
|
+
tempFirstSample.clear();
|
|
98
|
+
return prevSample;
|
|
99
|
+
}
|
|
100
|
+
|
|
101
|
+
// Coefficients come from the linear multistep method
|
|
102
|
+
// https://en.wikipedia.org/wiki/Linear_multistep_method
|
|
103
|
+
ets.push_back(noise);
|
|
104
|
+
|
|
105
|
+
if (ets.size() == 2) {
|
|
106
|
+
for (size_t i = 0; i < noiseSize; i++) {
|
|
107
|
+
etsOutput[i] = (ets[1][i] * 3 - ets[0][i]) / 2;
|
|
108
|
+
}
|
|
109
|
+
} else if (ets.size() == 3) {
|
|
110
|
+
for (size_t i = 0; i < noiseSize; i++) {
|
|
111
|
+
etsOutput[i] = ((ets[2][i] * 23 - ets[1][i] * 16) + ets[0][i] * 5) / 12;
|
|
112
|
+
}
|
|
113
|
+
} else {
|
|
114
|
+
ets.assign(ets.end() - 4, ets.end());
|
|
115
|
+
for (size_t i = 0; i < noiseSize; i++) {
|
|
116
|
+
etsOutput[i] =
|
|
117
|
+
(ets[3][i] * 55 - ets[2][i] * 59 + ets[1][i] * 37 - ets[0][i] * 9) /
|
|
118
|
+
24;
|
|
119
|
+
}
|
|
120
|
+
}
|
|
121
|
+
return getPrevSample(sample, etsOutput, timestep, timestepPrev);
|
|
122
|
+
}
|
|
123
|
+
|
|
124
|
+
std::vector<float> Scheduler::getPrevSample(const std::vector<float> &sample,
|
|
125
|
+
const std::vector<float> &noise,
|
|
126
|
+
int32_t timestep,
|
|
127
|
+
int32_t timestepPrev) const {
|
|
128
|
+
const float alpha = alphasCumprod[timestep];
|
|
129
|
+
const float alphaPrev =
|
|
130
|
+
timestepPrev >= 0 ? alphasCumprod[timestepPrev] : finalAlphaCumprod;
|
|
131
|
+
const float beta = 1 - alpha;
|
|
132
|
+
const float betaPrev = 1 - alphaPrev;
|
|
133
|
+
|
|
134
|
+
size_t noiseSize = noise.size();
|
|
135
|
+
const float noiseCoeff =
|
|
136
|
+
(alphaPrev - alpha) /
|
|
137
|
+
(alpha * std::sqrt(betaPrev) + std::sqrt(alpha * beta * alphaPrev));
|
|
138
|
+
const float sampleCoeff = std::sqrt(alphaPrev / alpha);
|
|
139
|
+
|
|
140
|
+
std::vector<float> samplePrev;
|
|
141
|
+
samplePrev.reserve(noiseSize);
|
|
142
|
+
for (size_t i = 0; i < noiseSize; i++) {
|
|
143
|
+
const float noiseTerm =
|
|
144
|
+
(noise[i] * std::sqrt(alpha) + sample[i] * std::sqrt(beta)) *
|
|
145
|
+
noiseCoeff;
|
|
146
|
+
samplePrev.push_back(sample[i] * sampleCoeff - noiseTerm);
|
|
147
|
+
}
|
|
148
|
+
|
|
149
|
+
return samplePrev;
|
|
150
|
+
}
|
|
151
|
+
|
|
152
|
+
} // namespace rnexecutorch::models::text_to_image
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
|
|
3
|
+
#include <memory>
|
|
4
|
+
#include <vector>
|
|
5
|
+
|
|
6
|
+
#include <ReactCommon/CallInvoker.h>
|
|
7
|
+
|
|
8
|
+
namespace rnexecutorch::models::text_to_image {
|
|
9
|
+
|
|
10
|
+
using namespace facebook;
|
|
11
|
+
|
|
12
|
+
class Scheduler final {
|
|
13
|
+
public:
|
|
14
|
+
explicit Scheduler(float betaStart, float betaEnd, int32_t numTrainTimesteps,
|
|
15
|
+
int32_t stepsOfset,
|
|
16
|
+
std::shared_ptr<react::CallInvoker> callInvoker);
|
|
17
|
+
void setTimesteps(size_t numInferenceSteps);
|
|
18
|
+
std::vector<float> step(const std::vector<float> &sample,
|
|
19
|
+
const std::vector<float> &noise, int32_t timestep);
|
|
20
|
+
|
|
21
|
+
std::vector<int32_t> timesteps;
|
|
22
|
+
|
|
23
|
+
private:
|
|
24
|
+
int32_t numTrainTimesteps;
|
|
25
|
+
int32_t stepsOffset;
|
|
26
|
+
|
|
27
|
+
std::vector<float> betas;
|
|
28
|
+
std::vector<float> alphas;
|
|
29
|
+
std::vector<float> alphasCumprod;
|
|
30
|
+
std::vector<float> tempFirstSample;
|
|
31
|
+
std::vector<std::vector<float>> ets;
|
|
32
|
+
float finalAlphaCumprod{1.0f};
|
|
33
|
+
|
|
34
|
+
size_t numInferenceSteps{0};
|
|
35
|
+
|
|
36
|
+
std::vector<float> getPrevSample(const std::vector<float> &sample,
|
|
37
|
+
const std::vector<float> &noise,
|
|
38
|
+
int32_t timestep,
|
|
39
|
+
int32_t prevTimestep) const;
|
|
40
|
+
};
|
|
41
|
+
} // namespace rnexecutorch::models::text_to_image
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
#include "TextToImage.h"
|
|
2
|
+
|
|
3
|
+
#include <cmath>
|
|
4
|
+
#include <random>
|
|
5
|
+
#include <span>
|
|
6
|
+
|
|
7
|
+
#include <executorch/extension/tensor/tensor.h>
|
|
8
|
+
|
|
9
|
+
#include <rnexecutorch/Log.h>
|
|
10
|
+
#include <rnexecutorch/models/text_to_image/Constants.h>
|
|
11
|
+
|
|
12
|
+
namespace rnexecutorch::models::text_to_image {
|
|
13
|
+
|
|
14
|
+
using namespace executorch::extension;
|
|
15
|
+
|
|
16
|
+
/// Assembles the text-to-image pipeline: a PNDM scheduler configured with the
/// scheduler* arguments, plus text-encoder, UNet and decoder stages loaded
/// from their respective sources. Every stage receives the same CallInvoker.
TextToImage::TextToImage(const std::string &tokenizerSource,
                         const std::string &encoderSource,
                         const std::string &unetSource,
                         const std::string &decoderSource,
                         float schedulerBetaStart, float schedulerBetaEnd,
                         int32_t schedulerNumTrainTimesteps,
                         int32_t schedulerStepsOffset,
                         std::shared_ptr<react::CallInvoker> callInvoker)
    : callInvoker(callInvoker),
      scheduler(std::make_unique<Scheduler>(
          schedulerBetaStart, schedulerBetaEnd, schedulerNumTrainTimesteps,
          schedulerStepsOffset, callInvoker)),
      encoder(std::make_unique<Encoder>(tokenizerSource, encoderSource,
                                        callInvoker)),
      unet(std::make_unique<UNet>(unetSource, callInvoker)),
      decoder(std::make_unique<Decoder>(decoderSource, callInvoker)) {
  // All members are fully initialized by the list above.
}
|
|
32
|
+
|
|
33
|
+
/// Sets the output image side length and propagates the derived latent size
/// (imageSize / 8) to the UNet and decoder stages.
///
/// @param imageSize output side length in pixels; must be a positive
///        multiple of 32.
/// @throws std::runtime_error if `imageSize` is not a multiple of 32 or is
///         not positive.
void TextToImage::setImageSize(int32_t imageSize) {
  if (imageSize % 32 != 0) {
    throw std::runtime_error("Image size must be a multiple of 32.");
  }
  // 0 and negative multiples of 32 pass the check above but would produce an
  // empty or meaningless latent tensor.
  if (imageSize <= 0) {
    throw std::runtime_error("Image size must be a positive number.");
  }
  this->imageSize = imageSize;
  constexpr int32_t latentDownsample = 8;
  // Exact integer division: imageSize is a multiple of 32, hence of 8.
  latentImageSize = imageSize / latentDownsample;
  unet->latentImageSize = latentImageSize;
  decoder->latentImageSize = latentImageSize;
}
|
|
43
|
+
|
|
44
|
+
/// Resolves the generation seed in place.
///
/// A non-negative `seed` is treated as user-provided and left unchanged. A
/// negative `seed` is replaced with a value drawn from std::random_device.
/// The draw is converted to int32_t, so the replacement may itself be
/// negative.
void TextToImage::setSeed(int32_t &seed) {
  if (seed < 0) {
    std::random_device entropySource;
    seed = entropySource();
  }
}
|
|
52
|
+
|
|
53
|
+
std::shared_ptr<OwningArrayBuffer>
|
|
54
|
+
TextToImage::generate(std::string input, int32_t imageSize,
|
|
55
|
+
size_t numInferenceSteps, int32_t seed,
|
|
56
|
+
std::shared_ptr<jsi::Function> callback) {
|
|
57
|
+
setImageSize(imageSize);
|
|
58
|
+
setSeed(seed);
|
|
59
|
+
|
|
60
|
+
std::vector<float> embeddings = encoder->generate(input);
|
|
61
|
+
std::vector<int32_t> embeddingsShape = {2, 77, 768};
|
|
62
|
+
auto embeddingsTensor =
|
|
63
|
+
make_tensor_ptr(embeddingsShape, embeddings.data(), ScalarType::Float);
|
|
64
|
+
|
|
65
|
+
int32_t latentsSize = numChannels * latentImageSize * latentImageSize;
|
|
66
|
+
std::vector<float> latents(latentsSize);
|
|
67
|
+
std::mt19937 gen(seed);
|
|
68
|
+
std::normal_distribution<float> dist(0.0, 1.0);
|
|
69
|
+
for (auto &val : latents) {
|
|
70
|
+
val = dist(gen);
|
|
71
|
+
}
|
|
72
|
+
|
|
73
|
+
scheduler->setTimesteps(numInferenceSteps);
|
|
74
|
+
std::vector<int32_t> timesteps = scheduler->timesteps;
|
|
75
|
+
|
|
76
|
+
auto nativeCallback = [this, callback](size_t stepIdx) {
|
|
77
|
+
this->callInvoker->invokeAsync([callback, stepIdx](jsi::Runtime &runtime) {
|
|
78
|
+
callback->call(runtime, jsi::Value(static_cast<int32_t>(stepIdx)));
|
|
79
|
+
});
|
|
80
|
+
};
|
|
81
|
+
for (size_t t = 0; t < numInferenceSteps + 1 && !interrupted; t++) {
|
|
82
|
+
log(LOG_LEVEL::Debug, "Step:", t, "/", numInferenceSteps);
|
|
83
|
+
|
|
84
|
+
std::vector<float> noisePred =
|
|
85
|
+
unet->generate(latents, timesteps[t], embeddingsTensor);
|
|
86
|
+
|
|
87
|
+
size_t noiseSize = noisePred.size() / 2;
|
|
88
|
+
std::span<const float> noisePredSpan{noisePred};
|
|
89
|
+
std::span<const float> noiseUncond = noisePredSpan.subspan(0, noiseSize);
|
|
90
|
+
std::span<const float> noiseText =
|
|
91
|
+
noisePredSpan.subspan(noiseSize, noiseSize);
|
|
92
|
+
std::vector<float> noise(noiseSize);
|
|
93
|
+
for (size_t i = 0; i < noiseSize; i++) {
|
|
94
|
+
noise[i] =
|
|
95
|
+
noiseUncond[i] * (1 - guidanceScale) + noiseText[i] * guidanceScale;
|
|
96
|
+
}
|
|
97
|
+
latents = scheduler->step(latents, noise, timesteps[t]);
|
|
98
|
+
|
|
99
|
+
nativeCallback(t);
|
|
100
|
+
}
|
|
101
|
+
if (interrupted) {
|
|
102
|
+
interrupted = false;
|
|
103
|
+
return std::make_shared<OwningArrayBuffer>(0);
|
|
104
|
+
}
|
|
105
|
+
|
|
106
|
+
for (auto &val : latents) {
|
|
107
|
+
val /= latentsScale;
|
|
108
|
+
}
|
|
109
|
+
|
|
110
|
+
std::vector<float> output = decoder->generate(latents);
|
|
111
|
+
return postprocess(output);
|
|
112
|
+
}
|
|
113
|
+
|
|
114
|
+
std::shared_ptr<OwningArrayBuffer>
|
|
115
|
+
TextToImage::postprocess(const std::vector<float> &output) const {
|
|
116
|
+
// Convert RGB to RGBA
|
|
117
|
+
int32_t imagePixelCount = imageSize * imageSize;
|
|
118
|
+
std::vector<uint8_t> outputRgba(imagePixelCount * 4);
|
|
119
|
+
for (int32_t i = 0; i < imagePixelCount; i++) {
|
|
120
|
+
outputRgba[i * 4 + 0] = output[i * 3 + 0];
|
|
121
|
+
outputRgba[i * 4 + 1] = output[i * 3 + 1];
|
|
122
|
+
outputRgba[i * 4 + 2] = output[i * 3 + 2];
|
|
123
|
+
outputRgba[i * 4 + 3] = 255;
|
|
124
|
+
}
|
|
125
|
+
return std::make_shared<OwningArrayBuffer>(outputRgba);
|
|
126
|
+
}
|
|
127
|
+
|
|
128
|
+
// Requests cancellation of a running generate(): the denoising loop checks
// this flag before each step and returns an empty buffer once it is set.
void TextToImage::interrupt() noexcept {
  interrupted = true;
}
|
|
129
|
+
|
|
130
|
+
size_t TextToImage::getMemoryLowerBound() const noexcept {
|
|
131
|
+
return encoder->getMemoryLowerBound() + unet->getMemoryLowerBound() +
|
|
132
|
+
decoder->getMemoryLowerBound();
|
|
133
|
+
}
|
|
134
|
+
|
|
135
|
+
void TextToImage::unload() noexcept {
|
|
136
|
+
encoder->unload();
|
|
137
|
+
unet->unload();
|
|
138
|
+
decoder->unload();
|
|
139
|
+
}
|
|
140
|
+
|
|
141
|
+
} // namespace rnexecutorch::models::text_to_image
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
|
|
3
|
+
#include <atomic>
#include <memory>
#include <string>
#include <vector>
|
|
6
|
+
|
|
7
|
+
#include <ReactCommon/CallInvoker.h>
|
|
8
|
+
#include <jsi/jsi.h>
|
|
9
|
+
|
|
10
|
+
#include <rnexecutorch/jsi/OwningArrayBuffer.h>
|
|
11
|
+
#include <rnexecutorch/metaprogramming/ConstructorHelpers.h>
|
|
12
|
+
|
|
13
|
+
#include <rnexecutorch/models/text_to_image/Decoder.h>
|
|
14
|
+
#include <rnexecutorch/models/text_to_image/Encoder.h>
|
|
15
|
+
#include <rnexecutorch/models/text_to_image/Scheduler.h>
|
|
16
|
+
#include <rnexecutorch/models/text_to_image/UNet.h>
|
|
17
|
+
|
|
18
|
+
namespace rnexecutorch {
|
|
19
|
+
namespace models::text_to_image {
|
|
20
|
+
using namespace facebook;
|
|
21
|
+
|
|
22
|
+
class TextToImage final {
|
|
23
|
+
public:
|
|
24
|
+
explicit TextToImage(const std::string &tokenizerSource,
|
|
25
|
+
const std::string &encoderSource,
|
|
26
|
+
const std::string &unetSource,
|
|
27
|
+
const std::string &decoderSource,
|
|
28
|
+
float schedulerBetaStart, float schedulerBetaEnd,
|
|
29
|
+
int32_t schedulerNumTrainTimesteps,
|
|
30
|
+
int32_t schedulerStepsOffset,
|
|
31
|
+
std::shared_ptr<react::CallInvoker> callInvoker);
|
|
32
|
+
std::shared_ptr<OwningArrayBuffer>
|
|
33
|
+
generate(std::string input, int32_t imageSize, size_t numInferenceSteps,
|
|
34
|
+
int32_t seed, std::shared_ptr<jsi::Function> callback);
|
|
35
|
+
void interrupt() noexcept;
|
|
36
|
+
size_t getMemoryLowerBound() const noexcept;
|
|
37
|
+
void unload() noexcept;
|
|
38
|
+
|
|
39
|
+
private:
|
|
40
|
+
void setImageSize(int32_t imageSize);
|
|
41
|
+
void setSeed(int32_t &seed);
|
|
42
|
+
std::shared_ptr<OwningArrayBuffer>
|
|
43
|
+
postprocess(const std::vector<float> &output) const;
|
|
44
|
+
|
|
45
|
+
size_t memorySizeLowerBound;
|
|
46
|
+
int32_t imageSize;
|
|
47
|
+
int32_t latentImageSize;
|
|
48
|
+
static constexpr int32_t numChannels = 4;
|
|
49
|
+
static constexpr float guidanceScale = 7.5f;
|
|
50
|
+
static constexpr float latentsScale = 0.18215f;
|
|
51
|
+
bool interrupted = false;
|
|
52
|
+
|
|
53
|
+
std::shared_ptr<react::CallInvoker> callInvoker;
|
|
54
|
+
std::unique_ptr<Scheduler> scheduler;
|
|
55
|
+
std::unique_ptr<Encoder> encoder;
|
|
56
|
+
std::unique_ptr<UNet> unet;
|
|
57
|
+
std::unique_ptr<Decoder> decoder;
|
|
58
|
+
};
|
|
59
|
+
} // namespace models::text_to_image
|
|
60
|
+
|
|
61
|
+
// Registers TextToImage's constructor signature with the ConstructorHelpers
// machinery. The argument types mirror the constructor's parameters in order:
// tokenizer, encoder, UNet and decoder sources; scheduler beta start/end;
// scheduler train timesteps and steps offset; the call invoker.
REGISTER_CONSTRUCTOR(models::text_to_image::TextToImage, std::string,
                     std::string, std::string, std::string, float, float,
                     int32_t, int32_t, std::shared_ptr<react::CallInvoker>);
|
|
64
|
+
} // namespace rnexecutorch
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
#include "UNet.h"
|
|
2
|
+
|
|
3
|
+
namespace rnexecutorch::models::text_to_image {
|
|
4
|
+
|
|
5
|
+
using namespace executorch::extension;
|
|
6
|
+
|
|
7
|
+
UNet::UNet(const std::string &modelSource,
|
|
8
|
+
std::shared_ptr<react::CallInvoker> callInvoker)
|
|
9
|
+
: BaseModel(modelSource, callInvoker) {}
|
|
10
|
+
|
|
11
|
+
std::vector<float> UNet::generate(std::vector<float> &latents, int32_t timestep,
|
|
12
|
+
TensorPtr &embeddingsTensor) const {
|
|
13
|
+
std::vector<float> latentsConcat;
|
|
14
|
+
latentsConcat.reserve(2 * latentImageSize);
|
|
15
|
+
latentsConcat.insert(latentsConcat.end(), latents.begin(), latents.end());
|
|
16
|
+
latentsConcat.insert(latentsConcat.end(), latents.begin(), latents.end());
|
|
17
|
+
|
|
18
|
+
std::vector<int32_t> latentsShape = {2, numChannels, latentImageSize,
|
|
19
|
+
latentImageSize};
|
|
20
|
+
|
|
21
|
+
auto timestepTensor =
|
|
22
|
+
make_tensor_ptr<int64_t>({static_cast<int64_t>(timestep)});
|
|
23
|
+
auto latentsTensor =
|
|
24
|
+
make_tensor_ptr(latentsShape, latentsConcat.data(), ScalarType::Float);
|
|
25
|
+
|
|
26
|
+
auto forwardResult =
|
|
27
|
+
BaseModel::forward({latentsTensor, timestepTensor, embeddingsTensor});
|
|
28
|
+
if (!forwardResult.ok()) {
|
|
29
|
+
throw std::runtime_error(
|
|
30
|
+
"Function forward in UNet failed with error code: " +
|
|
31
|
+
std::to_string(static_cast<uint32_t>(forwardResult.error())));
|
|
32
|
+
}
|
|
33
|
+
|
|
34
|
+
auto forwardResultTensor = forwardResult->at(0).toTensor();
|
|
35
|
+
const auto *dataPtr = forwardResultTensor.const_data_ptr<float>();
|
|
36
|
+
return {dataPtr, dataPtr + forwardResultTensor.numel()};
|
|
37
|
+
}
|
|
38
|
+
} // namespace rnexecutorch::models::text_to_image
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
|
|
3
|
+
#include <memory>
|
|
4
|
+
#include <string>
|
|
5
|
+
#include <vector>
|
|
6
|
+
|
|
7
|
+
#include <executorch/extension/tensor/tensor.h>
|
|
8
|
+
|
|
9
|
+
#include <ReactCommon/CallInvoker.h>
|
|
10
|
+
#include <rnexecutorch/models/BaseModel.h>
|
|
11
|
+
|
|
12
|
+
namespace rnexecutorch::models::text_to_image {
|
|
13
|
+
|
|
14
|
+
using namespace executorch::extension;
|
|
15
|
+
|
|
16
|
+
class UNet final : public BaseModel {
|
|
17
|
+
public:
|
|
18
|
+
explicit UNet(const std::string &modelSource,
|
|
19
|
+
std::shared_ptr<react::CallInvoker> callInvoker);
|
|
20
|
+
std::vector<float> generate(std::vector<float> &latents, int32_t timestep,
|
|
21
|
+
TensorPtr &embeddingsTensor) const;
|
|
22
|
+
|
|
23
|
+
int32_t latentImageSize;
|
|
24
|
+
|
|
25
|
+
private:
|
|
26
|
+
static constexpr int32_t numChannels = 4;
|
|
27
|
+
};
|
|
28
|
+
} // namespace rnexecutorch::models::text_to_image
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
|
|
3
|
+
#include <bit>
|
|
4
|
+
#include <cstddef>
|
|
5
|
+
#include <cstdint>
|
|
6
|
+
namespace rnexecutorch::models::voice_activity_detection::constants {
|
|
7
|
+
|
|
8
|
+
inline constexpr uint32_t kSampleRate = 16000;
|
|
9
|
+
inline constexpr auto kMstoSecond = 0.001f;
|
|
10
|
+
inline constexpr uint32_t kWindowSizeMs = 25;
|
|
11
|
+
inline constexpr uint32_t kHopLengthMs = 10;
|
|
12
|
+
inline constexpr auto kWindowSize =
|
|
13
|
+
static_cast<uint32_t>(kMstoSecond * kWindowSizeMs * kSampleRate); // 400
|
|
14
|
+
inline constexpr auto kHopLength =
|
|
15
|
+
static_cast<uint32_t>(kMstoSecond * kHopLengthMs * kSampleRate); // 160
|
|
16
|
+
inline constexpr auto kPreemphasisCoeff = 0.97f;
|
|
17
|
+
inline constexpr auto kLeftPadding = (kWindowSize - 1) / 2;
|
|
18
|
+
inline constexpr auto kRightPadding = kWindowSize / 2;
|
|
19
|
+
inline constexpr auto kPaddedWindowSize = std::bit_ceil(kWindowSize); // 512
|
|
20
|
+
inline constexpr size_t kModelInputMin = 100;
|
|
21
|
+
inline constexpr size_t kModelInputMax = 1000;
|
|
22
|
+
inline constexpr auto kSpeechThreshold = 0.6f;
|
|
23
|
+
inline constexpr size_t kMinSpeechDuration = 25; // 250 ms
|
|
24
|
+
inline constexpr size_t kMinSilenceDuration = 10; // 100 ms
|
|
25
|
+
inline constexpr size_t kSpeechPad = 3; // 30 ms
|
|
26
|
+
|
|
27
|
+
} // namespace rnexecutorch::models::voice_activity_detection::constants
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
#include "Utils.h"
|
|
2
|
+
|
|
3
|
+
namespace rnexecutorch::models::voice_activity_detection::utils {
|
|
4
|
+
// Copies element 0 of each of `size` consecutive groups of `numClass` floats
// in `tensor` into resultVector[startIdx, startIdx + size).
// Returns the index just past the last written slot. No bounds checks.
size_t getNonSpeechClassProbabilites(const executorch::aten::Tensor &tensor,
                                     size_t numClass, size_t size,
                                     std::vector<float> &resultVector,
                                     size_t startIdx) {
  const float *probabilities = tensor.const_data_ptr<float>();
  size_t row = 0;
  while (row < size) {
    resultVector[startIdx + row] = probabilities[row * numClass];
    ++row;
  }
  return startIdx + size;
}
|
|
14
|
+
|
|
15
|
+
} // namespace rnexecutorch::models::voice_activity_detection::utils
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
|
|
3
|
+
#include <cstddef>
|
|
4
|
+
#include <executorch/extension/tensor/tensor.h>
|
|
5
|
+
#include <vector>
|
|
6
|
+
|
|
7
|
+
namespace rnexecutorch::models::voice_activity_detection::utils {
/// Copies `size` floats from `tensor`, taken at a stride of `numClass`
/// (element 0 of each consecutive group of `numClass` values), into
/// `resultVector[startIdx, startIdx + size)`.
/// @return `startIdx + size`, i.e. the next free write index.
/// NOTE(review): performs no bounds checks; the caller must size both
/// `resultVector` and `tensor`. That element 0 is the non-speech class is
/// implied only by the name; confirm against the model's output layout.
size_t getNonSpeechClassProbabilites(const executorch::aten::Tensor &tensor,
                                     size_t numClass, size_t size,
                                     std::vector<float> &resultVector,
                                     size_t startIdx);

} // namespace rnexecutorch::models::voice_activity_detection::utils
|