react-native-executorch 0.5.15 → 0.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +42 -36
- package/android/CMakeLists.txt +13 -25
- package/android/build.gradle +2 -3
- package/android/libs/classes.jar +0 -0
- package/android/src/main/cpp/CMakeLists.txt +2 -1
- package/common/rnexecutorch/RnExecutorchInstaller.cpp +18 -0
- package/common/rnexecutorch/TokenizerModule.cpp +3 -3
- package/common/rnexecutorch/data_processing/Numerical.cpp +31 -23
- package/common/rnexecutorch/data_processing/Numerical.h +6 -1
- package/common/rnexecutorch/data_processing/dsp.cpp +0 -46
- package/common/rnexecutorch/host_objects/JsiConversions.h +16 -0
- package/common/rnexecutorch/host_objects/ModelHostObject.h +26 -11
- package/common/rnexecutorch/jsi/OwningArrayBuffer.h +19 -2
- package/common/rnexecutorch/metaprogramming/TypeConcepts.h +0 -20
- package/common/rnexecutorch/models/BaseModel.cpp +12 -11
- package/common/rnexecutorch/models/BaseModel.h +18 -10
- package/common/rnexecutorch/models/embeddings/BaseEmbeddings.cpp +3 -11
- package/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp +0 -1
- package/common/rnexecutorch/models/image_segmentation/ImageSegmentation.cpp +6 -12
- package/common/rnexecutorch/models/llm/LLM.cpp +25 -8
- package/common/rnexecutorch/models/llm/LLM.h +4 -4
- package/common/rnexecutorch/models/ocr/CTCLabelConverter.h +1 -1
- package/common/rnexecutorch/models/ocr/utils/RecognitionHandlerUtils.cpp +7 -4
- package/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp +8 -13
- package/common/rnexecutorch/models/speech_to_text/SpeechToText.h +1 -3
- package/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp +12 -19
- package/common/rnexecutorch/models/speech_to_text/asr/ASR.h +4 -5
- package/common/rnexecutorch/models/text_to_image/Constants.h +9 -0
- package/common/rnexecutorch/models/text_to_image/Decoder.cpp +32 -0
- package/common/rnexecutorch/models/text_to_image/Decoder.h +24 -0
- package/common/rnexecutorch/models/text_to_image/Encoder.cpp +44 -0
- package/common/rnexecutorch/models/text_to_image/Encoder.h +32 -0
- package/common/rnexecutorch/models/text_to_image/Scheduler.cpp +152 -0
- package/common/rnexecutorch/models/text_to_image/Scheduler.h +41 -0
- package/common/rnexecutorch/models/text_to_image/TextToImage.cpp +141 -0
- package/common/rnexecutorch/models/text_to_image/TextToImage.h +64 -0
- package/common/rnexecutorch/models/text_to_image/UNet.cpp +38 -0
- package/common/rnexecutorch/models/text_to_image/UNet.h +28 -0
- package/common/rnexecutorch/models/voice_activity_detection/Constants.h +27 -0
- package/common/rnexecutorch/models/voice_activity_detection/Types.h +12 -0
- package/common/rnexecutorch/models/voice_activity_detection/Utils.cpp +15 -0
- package/common/rnexecutorch/models/voice_activity_detection/Utils.h +13 -0
- package/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.cpp +160 -0
- package/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.h +36 -0
- package/common/rnexecutorch/tests/CMakeLists.txt +30 -0
- package/common/rnexecutorch/tests/NumericalTest.cpp +110 -0
- package/common/rnexecutorch/tests/README.md +30 -13
- package/common/rnexecutorch/threads/GlobalThreadPool.h +4 -0
- package/common/runner/arange_util.cpp +44 -0
- package/common/runner/arange_util.h +37 -0
- package/common/runner/constants.h +28 -0
- package/common/runner/io_manager.h +240 -0
- package/common/runner/irunner.h +87 -16
- package/common/runner/kernel_includes.h +23 -0
- package/common/runner/runner.cpp +151 -66
- package/common/runner/runner.h +39 -22
- package/common/runner/sampler.cpp +8 -1
- package/common/runner/sampler.h +4 -2
- package/common/runner/stats.h +1 -4
- package/common/runner/text_decoder_runner.cpp +26 -12
- package/common/runner/text_decoder_runner.h +52 -31
- package/common/runner/text_prefiller.cpp +46 -12
- package/common/runner/text_prefiller.h +38 -4
- package/common/runner/text_token_generator.h +51 -26
- package/common/runner/util.h +53 -8
- package/ios/RnExecutorch.xcodeproj/project.pbxproj +0 -23
- package/lib/module/Error.js +1 -0
- package/lib/module/Error.js.map +1 -1
- package/lib/module/constants/directories.js +1 -1
- package/lib/module/constants/directories.js.map +1 -1
- package/lib/module/constants/modelUrls.js +32 -1
- package/lib/module/constants/modelUrls.js.map +1 -1
- package/lib/module/constants/ocr/models.js +7 -7
- package/lib/module/constants/ocr/models.js.map +1 -1
- package/lib/module/constants/ocr/symbols.js +3 -2
- package/lib/module/constants/ocr/symbols.js.map +1 -1
- package/lib/module/controllers/LLMController.js +10 -1
- package/lib/module/controllers/LLMController.js.map +1 -1
- package/lib/module/controllers/OCRController.js +3 -3
- package/lib/module/controllers/OCRController.js.map +1 -1
- package/lib/module/controllers/VerticalOCRController.js +2 -2
- package/lib/module/controllers/VerticalOCRController.js.map +1 -1
- package/lib/module/hooks/computer_vision/useOCR.js +3 -3
- package/lib/module/hooks/computer_vision/useOCR.js.map +1 -1
- package/lib/module/hooks/{useNonStaticModule.js → computer_vision/useTextToImage.js} +21 -16
- package/lib/module/hooks/computer_vision/useTextToImage.js.map +1 -0
- package/lib/module/hooks/computer_vision/useVerticalOCR.js +3 -3
- package/lib/module/hooks/computer_vision/useVerticalOCR.js.map +1 -1
- package/lib/module/hooks/natural_language_processing/useLLM.js +3 -3
- package/lib/module/hooks/natural_language_processing/useLLM.js.map +1 -1
- package/lib/module/hooks/natural_language_processing/useTokenizer.js +5 -5
- package/lib/module/hooks/natural_language_processing/useTokenizer.js.map +1 -1
- package/lib/module/hooks/natural_language_processing/useVAD.js +13 -0
- package/lib/module/hooks/natural_language_processing/useVAD.js.map +1 -0
- package/lib/module/index.js +7 -2
- package/lib/module/index.js.map +1 -1
- package/lib/module/modules/computer_vision/OCRModule.js +2 -2
- package/lib/module/modules/computer_vision/OCRModule.js.map +1 -1
- package/lib/module/modules/computer_vision/TextToImageModule.js +48 -0
- package/lib/module/modules/computer_vision/TextToImageModule.js.map +1 -0
- package/lib/module/modules/computer_vision/VerticalOCRModule.js +2 -2
- package/lib/module/modules/computer_vision/VerticalOCRModule.js.map +1 -1
- package/lib/module/modules/natural_language_processing/SpeechToTextModule.js +7 -4
- package/lib/module/modules/natural_language_processing/SpeechToTextModule.js.map +1 -1
- package/lib/module/modules/natural_language_processing/VADModule.js +19 -0
- package/lib/module/modules/natural_language_processing/VADModule.js.map +1 -0
- package/lib/module/types/llm.js.map +1 -1
- package/lib/module/types/vad.js +2 -0
- package/lib/module/types/vad.js.map +1 -0
- package/lib/module/utils/ResourceFetcher.js +2 -1
- package/lib/module/utils/ResourceFetcher.js.map +1 -1
- package/lib/module/utils/ResourceFetcherUtils.js +6 -6
- package/lib/module/utils/ResourceFetcherUtils.js.map +1 -1
- package/lib/typescript/Error.d.ts +1 -0
- package/lib/typescript/Error.d.ts.map +1 -1
- package/lib/typescript/constants/modelUrls.d.ts +23 -0
- package/lib/typescript/constants/modelUrls.d.ts.map +1 -1
- package/lib/typescript/constants/ocr/symbols.d.ts +1 -1
- package/lib/typescript/constants/ocr/symbols.d.ts.map +1 -1
- package/lib/typescript/controllers/LLMController.d.ts.map +1 -1
- package/lib/typescript/controllers/OCRController.d.ts +1 -1
- package/lib/typescript/controllers/OCRController.d.ts.map +1 -1
- package/lib/typescript/controllers/VerticalOCRController.d.ts +1 -1
- package/lib/typescript/controllers/VerticalOCRController.d.ts.map +1 -1
- package/lib/typescript/hooks/computer_vision/useOCR.d.ts +1 -1
- package/lib/typescript/hooks/computer_vision/useOCR.d.ts.map +1 -1
- package/lib/typescript/hooks/computer_vision/useTextToImage.d.ts +22 -0
- package/lib/typescript/hooks/computer_vision/useTextToImage.d.ts.map +1 -0
- package/lib/typescript/hooks/computer_vision/useVerticalOCR.d.ts +1 -1
- package/lib/typescript/hooks/computer_vision/useVerticalOCR.d.ts.map +1 -1
- package/lib/typescript/hooks/natural_language_processing/useLLM.d.ts.map +1 -1
- package/lib/typescript/hooks/natural_language_processing/useSpeechToText.d.ts +2 -2
- package/lib/typescript/hooks/natural_language_processing/useVAD.d.ts +16 -0
- package/lib/typescript/hooks/natural_language_processing/useVAD.d.ts.map +1 -0
- package/lib/typescript/index.d.ts +8 -1
- package/lib/typescript/index.d.ts.map +1 -1
- package/lib/typescript/modules/computer_vision/OCRModule.d.ts +1 -1
- package/lib/typescript/modules/computer_vision/OCRModule.d.ts.map +1 -1
- package/lib/typescript/modules/computer_vision/TextToImageModule.d.ts +16 -0
- package/lib/typescript/modules/computer_vision/TextToImageModule.d.ts.map +1 -0
- package/lib/typescript/modules/computer_vision/VerticalOCRModule.d.ts +1 -1
- package/lib/typescript/modules/computer_vision/VerticalOCRModule.d.ts.map +1 -1
- package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts +3 -2
- package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts.map +1 -1
- package/lib/typescript/modules/natural_language_processing/VADModule.d.ts +10 -0
- package/lib/typescript/modules/natural_language_processing/VADModule.d.ts.map +1 -0
- package/lib/typescript/types/llm.d.ts +2 -0
- package/lib/typescript/types/llm.d.ts.map +1 -1
- package/lib/typescript/types/vad.d.ts +5 -0
- package/lib/typescript/types/vad.d.ts.map +1 -0
- package/lib/typescript/utils/ResourceFetcher.d.ts +29 -0
- package/lib/typescript/utils/ResourceFetcher.d.ts.map +1 -1
- package/lib/typescript/utils/ResourceFetcherUtils.d.ts +2 -2
- package/lib/typescript/utils/ResourceFetcherUtils.d.ts.map +1 -1
- package/package.json +11 -8
- package/react-native-executorch.podspec +9 -9
- package/src/Error.ts +1 -0
- package/src/constants/directories.ts +1 -1
- package/src/constants/modelUrls.ts +36 -1
- package/src/constants/ocr/models.ts +7 -7
- package/src/constants/ocr/symbols.ts +3 -2
- package/src/controllers/LLMController.ts +12 -1
- package/src/controllers/OCRController.ts +3 -3
- package/src/controllers/VerticalOCRController.ts +2 -2
- package/src/hooks/computer_vision/useOCR.ts +4 -5
- package/src/hooks/computer_vision/useTextToImage.ts +92 -0
- package/src/hooks/computer_vision/useVerticalOCR.ts +4 -5
- package/src/hooks/natural_language_processing/useLLM.ts +3 -4
- package/src/hooks/natural_language_processing/useTokenizer.ts +5 -5
- package/src/hooks/natural_language_processing/useVAD.ts +15 -0
- package/src/index.ts +20 -1
- package/src/modules/computer_vision/OCRModule.ts +2 -2
- package/src/modules/computer_vision/TextToImageModule.ts +93 -0
- package/src/modules/computer_vision/VerticalOCRModule.ts +2 -2
- package/src/modules/natural_language_processing/SpeechToTextModule.ts +8 -4
- package/src/modules/natural_language_processing/VADModule.ts +27 -0
- package/src/types/llm.ts +2 -0
- package/src/types/vad.ts +4 -0
- package/src/utils/ResourceFetcher.ts +2 -1
- package/src/utils/ResourceFetcherUtils.ts +8 -8
- package/third-party/android/libs/cpuinfo/arm64-v8a/libcpuinfo.so +0 -0
- package/third-party/android/libs/executorch/arm64-v8a/libexecutorch.so +0 -0
- package/third-party/android/libs/executorch/x86_64/libexecutorch.so +0 -0
- package/third-party/android/libs/pthreadpool/arm64-v8a/libpthreadpool.so +0 -0
- package/third-party/include/c10/macros/Export.h +0 -78
- package/third-party/include/c10/macros/Macros.h +1 -520
- package/third-party/include/c10/util/BFloat16-inl.h +1 -339
- package/third-party/include/c10/util/BFloat16.h +1 -122
- package/third-party/include/c10/util/Half-inl.h +1 -347
- package/third-party/include/c10/util/Half.h +6 -419
- package/third-party/include/c10/util/TypeSafeSignMath.h +1 -133
- package/third-party/include/c10/util/bit_cast.h +1 -43
- package/third-party/include/c10/util/complex.h +1 -568
- package/third-party/include/c10/util/floating_point_utils.h +1 -33
- package/third-party/include/c10/util/irange.h +1 -1
- package/third-party/include/c10/util/llvmMathExtras.h +866 -0
- package/third-party/include/c10/util/safe_numerics.h +97 -0
- package/third-party/include/executorch/ExecuTorchError.h +6 -7
- package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLM.h +12 -0
- package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMConfig.h +56 -0
- package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMError.h +16 -0
- package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMMultimodalRunner.h +227 -0
- package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMTextRunner.h +97 -0
- package/third-party/include/executorch/ExecuTorchLLM/module.modulemap +4 -0
- package/third-party/include/executorch/ExecuTorchLog.h +1 -0
- package/third-party/include/executorch/ExecuTorchModule.h +177 -4
- package/third-party/include/executorch/ExecuTorchTensor.h +3 -4
- package/third-party/include/executorch/ExecuTorchValue.h +1 -7
- package/third-party/include/executorch/extension/module/module.h +139 -8
- package/third-party/include/executorch/extension/tensor/tensor.h +1 -0
- package/third-party/include/executorch/extension/tensor/tensor_ptr.h +88 -26
- package/third-party/include/executorch/extension/threadpool/threadpool.h +4 -1
- package/third-party/include/executorch/runtime/backend/backend_init_context.h +6 -0
- package/third-party/include/executorch/runtime/backend/interface.h +1 -1
- package/third-party/include/executorch/runtime/core/error.h +76 -49
- package/third-party/include/executorch/runtime/core/exec_aten/util/scalar_type_util.h +18 -4
- package/third-party/include/executorch/runtime/core/memory_allocator.h +12 -2
- package/third-party/include/executorch/runtime/core/named_data_map.h +1 -11
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/macros/Export.h +0 -78
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/macros/Macros.h +1 -520
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/BFloat16-inl.h +1 -339
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/BFloat16.h +1 -122
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/Half-inl.h +1 -347
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/Half.h +6 -419
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/TypeSafeSignMath.h +1 -133
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/bit_cast.h +1 -43
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/complex.h +1 -568
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/floating_point_utils.h +1 -33
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/irange.h +1 -1
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/llvmMathExtras.h +866 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/safe_numerics.h +97 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/macros/Export.h +66 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/macros/Macros.h +553 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/BFloat16.h +477 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/Half.h +781 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/TypeSafeSignMath.h +141 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/bit_cast.h +49 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/complex.h +593 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/floating_point_utils.h +38 -0
- package/third-party/include/executorch/runtime/core/tensor_layout.h +1 -1
- package/third-party/include/executorch/runtime/executor/merged_data_map.h +142 -0
- package/third-party/include/executorch/runtime/executor/method.h +21 -8
- package/third-party/include/executorch/runtime/executor/method_meta.h +20 -2
- package/third-party/include/executorch/runtime/executor/program.h +0 -10
- package/third-party/include/executorch/runtime/kernel/operator_registry.h +1 -1
- package/third-party/include/executorch/runtime/platform/compiler.h +2 -0
- package/third-party/include/executorch/schema/extended_header.h +10 -1
- package/third-party/include/torch/headeronly/macros/Export.h +66 -0
- package/third-party/include/torch/headeronly/macros/Macros.h +553 -0
- package/third-party/include/torch/headeronly/util/BFloat16.h +477 -0
- package/third-party/include/torch/headeronly/util/Half.h +781 -0
- package/third-party/include/torch/headeronly/util/TypeSafeSignMath.h +141 -0
- package/third-party/include/torch/headeronly/util/bit_cast.h +49 -0
- package/third-party/include/torch/headeronly/util/complex.h +593 -0
- package/third-party/include/torch/headeronly/util/floating_point_utils.h +38 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/ExecutorchLib +0 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/Info.plist +0 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/ExecutorchLib +0 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/Info.plist +0 -0
- package/common/rnexecutorch/tests/run_all_tests.sh +0 -14
- package/common/rnexecutorch/tests/run_test.sh +0 -18
- package/ios/RnExecutorch/utils/Conversions.h +0 -14
- package/ios/RnExecutorch/utils/ETError.h +0 -26
- package/ios/RnExecutorch/utils/ImageProcessor.h +0 -15
- package/ios/RnExecutorch/utils/ImageProcessor.mm +0 -147
- package/ios/RnExecutorch/utils/Numerical.h +0 -3
- package/ios/RnExecutorch/utils/Numerical.mm +0 -18
- package/ios/RnExecutorch/utils/ScalarType.h +0 -14
- package/ios/RnExecutorch/utils/ScalarType.mm +0 -21
- package/lib/module/hooks/useNonStaticModule.js.map +0 -1
- package/lib/typescript/hooks/useNonStaticModule.d.ts +0 -21
- package/lib/typescript/hooks/useNonStaticModule.d.ts.map +0 -1
- package/src/hooks/useNonStaticModule.ts +0 -74
- package/third-party/include/executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h +0 -181
- package/third-party/include/executorch/extension/kernel_util/meta_programming.h +0 -108
- package/third-party/include/executorch/extension/kernel_util/type_list.h +0 -137
- package/third-party/include/executorch/extension/threadpool/threadpool_guard.h +0 -35
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
|
|
3
|
+
#include <limits>
|
|
4
|
+
#include <torch/headeronly/macros/Macros.h>
|
|
5
|
+
#include <type_traits>
|
|
6
|
+
|
|
7
|
+
C10_CLANG_DIAGNOSTIC_PUSH()
|
|
8
|
+
#if C10_CLANG_HAS_WARNING("-Wstring-conversion")
|
|
9
|
+
C10_CLANG_DIAGNOSTIC_IGNORE("-Wstring-conversion")
|
|
10
|
+
#endif
|
|
11
|
+
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
|
|
12
|
+
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
|
|
13
|
+
#endif
|
|
14
|
+
|
|
15
|
+
namespace c10 {

/// Returns false: a value of an unsigned type can never be negative.
template <typename T>
inline constexpr bool is_negative(const T & /*x*/,
                                  std::true_type /*is_unsigned*/) {
  return false;
}

/// Returns true if the (signed) value x is less than T(0).
template <typename T>
inline constexpr bool is_negative(const T &x, std::false_type /*is_unsigned*/) {
  return x < T(0);
}

/// Returns true if x < 0.
/// Dispatches on std::is_unsigned<T>: unsigned types always yield false,
/// every other type is compared against T(0).
/// NOTE: Will fail on an unsigned custom type (std::is_unsigned is false for
/// it, so it is compared against T(0)). For the most part it's possible to
/// fix this if the custom type has a constexpr constructor.
/// However, notably, c10::Half does not :-(
template <typename T> inline constexpr bool is_negative(const T &x) {
  return is_negative(x, std::is_unsigned<T>());
}

/// Returns the sign of an unsigned value x as 0 or 1.
template <typename T>
inline constexpr int signum(const T &x, std::true_type /*is_unsigned*/) {
  return T(0) < x;
}

/// Returns the sign of a signed value x as -1, 0 or 1.
template <typename T>
inline constexpr int signum(const T &x, std::false_type /*is_unsigned*/) {
  // Branch-free: (x > 0) - (x < 0).
  return (T(0) < x) - (x < T(0));
}

/// Returns the sign of x as -1, 0 or 1 (0 or 1 for unsigned types).
/// NOTE: Will fail on an unsigned custom type (see is_negative).
/// For the most part it's possible to fix this if
/// the custom type has a constexpr constructor.
/// However, notably, c10::Half does not :-(
template <typename T> inline constexpr int signum(const T &x) {
  return signum(x, std::is_unsigned<T>());
}

/// Returns true if exactly one of a and b is negative, i.e. their signs (as
/// classified by is_negative) differ. Zero counts as non-negative, so
/// signs_differ(0, -1) is true and signs_differ(1, 2) is false.
template <typename T, typename U>
inline constexpr bool signs_differ(const T &a, const U &b) {
  return is_negative(a) != is_negative(b);
}

// Suppress the sign-compare warning when compiling with GCC, as the latter
// does not account for the short-circuit rule before raising the warning,
// see https://godbolt.org/z/Tr3Msnz99
#ifdef __GNUC__
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wsign-compare"
#endif

/// Returns true if x is greater than the greatest value of the type Limit.
/// The comparison is only performed when T has more value digits than Limit
/// (std::numeric_limits<>::digits); otherwise x is assumed to fit and false
/// is returned.
/// NOTE(review): because the guard compares `digits`, a floating-point T with
/// fewer mantissa digits than Limit (e.g. float vs int32_t: 24 vs 31) always
/// yields false, even for values such as 1e20f. Callers converting from
/// floating point should confirm they do not rely on this check.
template <typename Limit, typename T>
inline constexpr bool greater_than_max(const T &x) {
  constexpr bool can_overflow =
      std::numeric_limits<T>::digits > std::numeric_limits<Limit>::digits;
  // Parenthesized max() guards against a function-like `max` macro.
  return can_overflow && x > (std::numeric_limits<Limit>::max)();
}

#ifdef __GNUC__
#pragma GCC diagnostic pop
#endif

/// Returns true if x < lowest(Limit). Both types are signed, so a standard
/// comparison is used.
template <typename Limit, typename T>
inline constexpr bool less_than_lowest(const T &x,
                                       std::false_type /*limit_is_unsigned*/,
                                       std::false_type /*x_is_unsigned*/) {
  return x < std::numeric_limits<Limit>::lowest();
}

/// Returns false: the limit is signed and therefore includes negative
/// values, but x cannot be negative because it is unsigned.
template <typename Limit, typename T>
inline constexpr bool less_than_lowest(const T & /*x*/,
                                       std::false_type /*limit_is_unsigned*/,
                                       std::true_type /*x_is_unsigned*/) {
  return false;
}

/// Returns true if x < 0, where 0 is constructed from T.
/// Limit is unsigned, so its lowest value is zero.
template <typename Limit, typename T>
inline constexpr bool less_than_lowest(const T &x,
                                       std::true_type /*limit_is_unsigned*/,
                                       std::false_type /*x_is_unsigned*/) {
  return x < T(0);
}

/// Returns false since both types are unsigned.
template <typename Limit, typename T>
inline constexpr bool less_than_lowest(const T & /*x*/,
                                       std::true_type /*limit_is_unsigned*/,
                                       std::true_type /*x_is_unsigned*/) {
  return false;
}

/// Returns true if x is less than the lowest value of the type Limit.
/// Dispatches on the signedness of Limit and T.
/// NOTE: Will fail on an unsigned custom type.
/// For the most part it's possible to fix this if
/// the custom type has a constexpr constructor.
/// However, notably, c10::Half does not :-(
template <typename Limit, typename T>
inline constexpr bool less_than_lowest(const T &x) {
  return less_than_lowest<Limit>(x, std::is_unsigned<Limit>(),
                                 std::is_unsigned<T>());
}

} // namespace c10
|
|
132
|
+
|
|
133
|
+
C10_CLANG_DIAGNOSTIC_POP()
|
|
134
|
+
|
|
135
|
+
namespace torch::headeronly {
|
|
136
|
+
using c10::greater_than_max;
|
|
137
|
+
using c10::is_negative;
|
|
138
|
+
using c10::less_than_lowest;
|
|
139
|
+
using c10::signs_differ;
|
|
140
|
+
using c10::signum;
|
|
141
|
+
} // namespace torch::headeronly
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
|
|
3
|
+
#include <cstring>
|
|
4
|
+
#include <type_traits>
|
|
5
|
+
|
|
6
|
+
#include <torch/headeronly/macros/Macros.h>
|
|
7
|
+
|
|
8
|
+
// Decide whether std::bit_cast can be used.
// The __cpp_lib_bit_cast feature-test macro is only guaranteed to be defined
// once <version> (or <bit>) has been included. Neither is included above, so
// include <version> first. Without it, the outcome of the check below would
// depend on which headers the including translation unit happened to include
// earlier, and different TUs could then see different bit_cast definitions.
#if __has_include(<version>)
#include <version>
#endif

#if __has_include(<bit>) && (defined(__cpp_lib_bit_cast) && __cpp_lib_bit_cast >= 201806L)
#include <bit>
#define C10_HAVE_STD_BIT_CAST 1
#else
#define C10_HAVE_STD_BIT_CAST 0
#endif // __has_include(<bit>) && (defined(__cpp_lib_bit_cast) &&
       // __cpp_lib_bit_cast >= 201806L)
|
|
16
|
+
namespace torch::headeronly {
|
|
17
|
+
|
|
18
|
+
#if C10_HAVE_STD_BIT_CAST
|
|
19
|
+
using std::bit_cast;
|
|
20
|
+
#else
|
|
21
|
+
// Implementations of std::bit_cast() from C++ 20.
|
|
22
|
+
//
|
|
23
|
+
// This is a less sketchy version of reinterpret_cast.
|
|
24
|
+
//
|
|
25
|
+
// See https://en.cppreference.com/w/cpp/numeric/bit_cast for more
|
|
26
|
+
// information as well as the source of our implementations.
|
|
27
|
+
template <class To, class From>
|
|
28
|
+
C10_HOST_DEVICE std::enable_if_t<sizeof(To) == sizeof(From) &&
|
|
29
|
+
std::is_trivially_copyable_v<From> &&
|
|
30
|
+
std::is_trivially_copyable_v<To>,
|
|
31
|
+
To>
|
|
32
|
+
// constexpr support needs compiler magic
|
|
33
|
+
bit_cast(const From &src) noexcept {
|
|
34
|
+
static_assert(std::is_trivially_constructible_v<To>,
|
|
35
|
+
"This implementation additionally requires "
|
|
36
|
+
"destination type to be trivially constructible");
|
|
37
|
+
|
|
38
|
+
To dst;
|
|
39
|
+
std::memcpy(&dst, &src, sizeof(To));
|
|
40
|
+
return dst;
|
|
41
|
+
}
|
|
42
|
+
#endif // C10_HAVE_STD_BIT_CAST
|
|
43
|
+
#undef C10_HAVE_STD_BIT_CAST
|
|
44
|
+
|
|
45
|
+
} // namespace torch::headeronly
|
|
46
|
+
|
|
47
|
+
namespace c10 {
|
|
48
|
+
using torch::headeronly::bit_cast;
|
|
49
|
+
} // namespace c10
|