react-native-executorch 0.5.15 → 0.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +42 -36
- package/android/CMakeLists.txt +13 -25
- package/android/build.gradle +2 -3
- package/android/libs/classes.jar +0 -0
- package/android/src/main/cpp/CMakeLists.txt +2 -1
- package/common/rnexecutorch/RnExecutorchInstaller.cpp +18 -0
- package/common/rnexecutorch/TokenizerModule.cpp +3 -3
- package/common/rnexecutorch/data_processing/Numerical.cpp +31 -23
- package/common/rnexecutorch/data_processing/Numerical.h +6 -1
- package/common/rnexecutorch/data_processing/dsp.cpp +0 -46
- package/common/rnexecutorch/host_objects/JsiConversions.h +16 -0
- package/common/rnexecutorch/host_objects/ModelHostObject.h +26 -11
- package/common/rnexecutorch/jsi/OwningArrayBuffer.h +19 -2
- package/common/rnexecutorch/metaprogramming/TypeConcepts.h +0 -20
- package/common/rnexecutorch/models/BaseModel.cpp +12 -11
- package/common/rnexecutorch/models/BaseModel.h +18 -10
- package/common/rnexecutorch/models/embeddings/BaseEmbeddings.cpp +3 -11
- package/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp +0 -1
- package/common/rnexecutorch/models/image_segmentation/ImageSegmentation.cpp +6 -12
- package/common/rnexecutorch/models/llm/LLM.cpp +25 -8
- package/common/rnexecutorch/models/llm/LLM.h +4 -4
- package/common/rnexecutorch/models/ocr/CTCLabelConverter.h +1 -1
- package/common/rnexecutorch/models/ocr/utils/RecognitionHandlerUtils.cpp +7 -4
- package/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp +8 -13
- package/common/rnexecutorch/models/speech_to_text/SpeechToText.h +1 -3
- package/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp +12 -19
- package/common/rnexecutorch/models/speech_to_text/asr/ASR.h +4 -5
- package/common/rnexecutorch/models/text_to_image/Constants.h +9 -0
- package/common/rnexecutorch/models/text_to_image/Decoder.cpp +32 -0
- package/common/rnexecutorch/models/text_to_image/Decoder.h +24 -0
- package/common/rnexecutorch/models/text_to_image/Encoder.cpp +44 -0
- package/common/rnexecutorch/models/text_to_image/Encoder.h +32 -0
- package/common/rnexecutorch/models/text_to_image/Scheduler.cpp +152 -0
- package/common/rnexecutorch/models/text_to_image/Scheduler.h +41 -0
- package/common/rnexecutorch/models/text_to_image/TextToImage.cpp +141 -0
- package/common/rnexecutorch/models/text_to_image/TextToImage.h +64 -0
- package/common/rnexecutorch/models/text_to_image/UNet.cpp +38 -0
- package/common/rnexecutorch/models/text_to_image/UNet.h +28 -0
- package/common/rnexecutorch/models/voice_activity_detection/Constants.h +27 -0
- package/common/rnexecutorch/models/voice_activity_detection/Types.h +12 -0
- package/common/rnexecutorch/models/voice_activity_detection/Utils.cpp +15 -0
- package/common/rnexecutorch/models/voice_activity_detection/Utils.h +13 -0
- package/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.cpp +160 -0
- package/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.h +36 -0
- package/common/rnexecutorch/tests/CMakeLists.txt +30 -0
- package/common/rnexecutorch/tests/NumericalTest.cpp +110 -0
- package/common/rnexecutorch/tests/README.md +30 -13
- package/common/rnexecutorch/threads/GlobalThreadPool.h +4 -0
- package/common/runner/arange_util.cpp +44 -0
- package/common/runner/arange_util.h +37 -0
- package/common/runner/constants.h +28 -0
- package/common/runner/io_manager.h +240 -0
- package/common/runner/irunner.h +87 -16
- package/common/runner/kernel_includes.h +23 -0
- package/common/runner/runner.cpp +151 -66
- package/common/runner/runner.h +39 -22
- package/common/runner/sampler.cpp +8 -1
- package/common/runner/sampler.h +4 -2
- package/common/runner/stats.h +1 -4
- package/common/runner/text_decoder_runner.cpp +26 -12
- package/common/runner/text_decoder_runner.h +52 -31
- package/common/runner/text_prefiller.cpp +46 -12
- package/common/runner/text_prefiller.h +38 -4
- package/common/runner/text_token_generator.h +51 -26
- package/common/runner/util.h +53 -8
- package/ios/RnExecutorch.xcodeproj/project.pbxproj +0 -23
- package/lib/module/Error.js +1 -0
- package/lib/module/Error.js.map +1 -1
- package/lib/module/constants/directories.js +1 -1
- package/lib/module/constants/directories.js.map +1 -1
- package/lib/module/constants/modelUrls.js +32 -1
- package/lib/module/constants/modelUrls.js.map +1 -1
- package/lib/module/constants/ocr/models.js +7 -7
- package/lib/module/constants/ocr/models.js.map +1 -1
- package/lib/module/constants/ocr/symbols.js +3 -2
- package/lib/module/constants/ocr/symbols.js.map +1 -1
- package/lib/module/controllers/LLMController.js +10 -1
- package/lib/module/controllers/LLMController.js.map +1 -1
- package/lib/module/controllers/OCRController.js +3 -3
- package/lib/module/controllers/OCRController.js.map +1 -1
- package/lib/module/controllers/VerticalOCRController.js +2 -2
- package/lib/module/controllers/VerticalOCRController.js.map +1 -1
- package/lib/module/hooks/computer_vision/useOCR.js +3 -3
- package/lib/module/hooks/computer_vision/useOCR.js.map +1 -1
- package/lib/module/hooks/{useNonStaticModule.js → computer_vision/useTextToImage.js} +21 -16
- package/lib/module/hooks/computer_vision/useTextToImage.js.map +1 -0
- package/lib/module/hooks/computer_vision/useVerticalOCR.js +3 -3
- package/lib/module/hooks/computer_vision/useVerticalOCR.js.map +1 -1
- package/lib/module/hooks/natural_language_processing/useLLM.js +3 -3
- package/lib/module/hooks/natural_language_processing/useLLM.js.map +1 -1
- package/lib/module/hooks/natural_language_processing/useTokenizer.js +5 -5
- package/lib/module/hooks/natural_language_processing/useTokenizer.js.map +1 -1
- package/lib/module/hooks/natural_language_processing/useVAD.js +13 -0
- package/lib/module/hooks/natural_language_processing/useVAD.js.map +1 -0
- package/lib/module/index.js +7 -2
- package/lib/module/index.js.map +1 -1
- package/lib/module/modules/computer_vision/OCRModule.js +2 -2
- package/lib/module/modules/computer_vision/OCRModule.js.map +1 -1
- package/lib/module/modules/computer_vision/TextToImageModule.js +48 -0
- package/lib/module/modules/computer_vision/TextToImageModule.js.map +1 -0
- package/lib/module/modules/computer_vision/VerticalOCRModule.js +2 -2
- package/lib/module/modules/computer_vision/VerticalOCRModule.js.map +1 -1
- package/lib/module/modules/natural_language_processing/SpeechToTextModule.js +7 -4
- package/lib/module/modules/natural_language_processing/SpeechToTextModule.js.map +1 -1
- package/lib/module/modules/natural_language_processing/VADModule.js +19 -0
- package/lib/module/modules/natural_language_processing/VADModule.js.map +1 -0
- package/lib/module/types/llm.js.map +1 -1
- package/lib/module/types/vad.js +2 -0
- package/lib/module/types/vad.js.map +1 -0
- package/lib/module/utils/ResourceFetcher.js +2 -1
- package/lib/module/utils/ResourceFetcher.js.map +1 -1
- package/lib/module/utils/ResourceFetcherUtils.js +6 -6
- package/lib/module/utils/ResourceFetcherUtils.js.map +1 -1
- package/lib/typescript/Error.d.ts +1 -0
- package/lib/typescript/Error.d.ts.map +1 -1
- package/lib/typescript/constants/modelUrls.d.ts +23 -0
- package/lib/typescript/constants/modelUrls.d.ts.map +1 -1
- package/lib/typescript/constants/ocr/symbols.d.ts +1 -1
- package/lib/typescript/constants/ocr/symbols.d.ts.map +1 -1
- package/lib/typescript/controllers/LLMController.d.ts.map +1 -1
- package/lib/typescript/controllers/OCRController.d.ts +1 -1
- package/lib/typescript/controllers/OCRController.d.ts.map +1 -1
- package/lib/typescript/controllers/VerticalOCRController.d.ts +1 -1
- package/lib/typescript/controllers/VerticalOCRController.d.ts.map +1 -1
- package/lib/typescript/hooks/computer_vision/useOCR.d.ts +1 -1
- package/lib/typescript/hooks/computer_vision/useOCR.d.ts.map +1 -1
- package/lib/typescript/hooks/computer_vision/useTextToImage.d.ts +22 -0
- package/lib/typescript/hooks/computer_vision/useTextToImage.d.ts.map +1 -0
- package/lib/typescript/hooks/computer_vision/useVerticalOCR.d.ts +1 -1
- package/lib/typescript/hooks/computer_vision/useVerticalOCR.d.ts.map +1 -1
- package/lib/typescript/hooks/natural_language_processing/useLLM.d.ts.map +1 -1
- package/lib/typescript/hooks/natural_language_processing/useSpeechToText.d.ts +2 -2
- package/lib/typescript/hooks/natural_language_processing/useVAD.d.ts +16 -0
- package/lib/typescript/hooks/natural_language_processing/useVAD.d.ts.map +1 -0
- package/lib/typescript/index.d.ts +8 -1
- package/lib/typescript/index.d.ts.map +1 -1
- package/lib/typescript/modules/computer_vision/OCRModule.d.ts +1 -1
- package/lib/typescript/modules/computer_vision/OCRModule.d.ts.map +1 -1
- package/lib/typescript/modules/computer_vision/TextToImageModule.d.ts +16 -0
- package/lib/typescript/modules/computer_vision/TextToImageModule.d.ts.map +1 -0
- package/lib/typescript/modules/computer_vision/VerticalOCRModule.d.ts +1 -1
- package/lib/typescript/modules/computer_vision/VerticalOCRModule.d.ts.map +1 -1
- package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts +3 -2
- package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts.map +1 -1
- package/lib/typescript/modules/natural_language_processing/VADModule.d.ts +10 -0
- package/lib/typescript/modules/natural_language_processing/VADModule.d.ts.map +1 -0
- package/lib/typescript/types/llm.d.ts +2 -0
- package/lib/typescript/types/llm.d.ts.map +1 -1
- package/lib/typescript/types/vad.d.ts +5 -0
- package/lib/typescript/types/vad.d.ts.map +1 -0
- package/lib/typescript/utils/ResourceFetcher.d.ts +29 -0
- package/lib/typescript/utils/ResourceFetcher.d.ts.map +1 -1
- package/lib/typescript/utils/ResourceFetcherUtils.d.ts +2 -2
- package/lib/typescript/utils/ResourceFetcherUtils.d.ts.map +1 -1
- package/package.json +11 -8
- package/react-native-executorch.podspec +9 -9
- package/src/Error.ts +1 -0
- package/src/constants/directories.ts +1 -1
- package/src/constants/modelUrls.ts +36 -1
- package/src/constants/ocr/models.ts +7 -7
- package/src/constants/ocr/symbols.ts +3 -2
- package/src/controllers/LLMController.ts +12 -1
- package/src/controllers/OCRController.ts +3 -3
- package/src/controllers/VerticalOCRController.ts +2 -2
- package/src/hooks/computer_vision/useOCR.ts +4 -5
- package/src/hooks/computer_vision/useTextToImage.ts +92 -0
- package/src/hooks/computer_vision/useVerticalOCR.ts +4 -5
- package/src/hooks/natural_language_processing/useLLM.ts +3 -4
- package/src/hooks/natural_language_processing/useTokenizer.ts +5 -5
- package/src/hooks/natural_language_processing/useVAD.ts +15 -0
- package/src/index.ts +20 -1
- package/src/modules/computer_vision/OCRModule.ts +2 -2
- package/src/modules/computer_vision/TextToImageModule.ts +93 -0
- package/src/modules/computer_vision/VerticalOCRModule.ts +2 -2
- package/src/modules/natural_language_processing/SpeechToTextModule.ts +8 -4
- package/src/modules/natural_language_processing/VADModule.ts +27 -0
- package/src/types/llm.ts +2 -0
- package/src/types/vad.ts +4 -0
- package/src/utils/ResourceFetcher.ts +2 -1
- package/src/utils/ResourceFetcherUtils.ts +8 -8
- package/third-party/android/libs/cpuinfo/arm64-v8a/libcpuinfo.so +0 -0
- package/third-party/android/libs/executorch/arm64-v8a/libexecutorch.so +0 -0
- package/third-party/android/libs/executorch/x86_64/libexecutorch.so +0 -0
- package/third-party/android/libs/pthreadpool/arm64-v8a/libpthreadpool.so +0 -0
- package/third-party/include/c10/macros/Export.h +0 -78
- package/third-party/include/c10/macros/Macros.h +1 -520
- package/third-party/include/c10/util/BFloat16-inl.h +1 -339
- package/third-party/include/c10/util/BFloat16.h +1 -122
- package/third-party/include/c10/util/Half-inl.h +1 -347
- package/third-party/include/c10/util/Half.h +6 -419
- package/third-party/include/c10/util/TypeSafeSignMath.h +1 -133
- package/third-party/include/c10/util/bit_cast.h +1 -43
- package/third-party/include/c10/util/complex.h +1 -568
- package/third-party/include/c10/util/floating_point_utils.h +1 -33
- package/third-party/include/c10/util/irange.h +1 -1
- package/third-party/include/c10/util/llvmMathExtras.h +866 -0
- package/third-party/include/c10/util/safe_numerics.h +97 -0
- package/third-party/include/executorch/ExecuTorchError.h +6 -7
- package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLM.h +12 -0
- package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMConfig.h +56 -0
- package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMError.h +16 -0
- package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMMultimodalRunner.h +227 -0
- package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMTextRunner.h +97 -0
- package/third-party/include/executorch/ExecuTorchLLM/module.modulemap +4 -0
- package/third-party/include/executorch/ExecuTorchLog.h +1 -0
- package/third-party/include/executorch/ExecuTorchModule.h +177 -4
- package/third-party/include/executorch/ExecuTorchTensor.h +3 -4
- package/third-party/include/executorch/ExecuTorchValue.h +1 -7
- package/third-party/include/executorch/extension/module/module.h +139 -8
- package/third-party/include/executorch/extension/tensor/tensor.h +1 -0
- package/third-party/include/executorch/extension/tensor/tensor_ptr.h +88 -26
- package/third-party/include/executorch/extension/threadpool/threadpool.h +4 -1
- package/third-party/include/executorch/runtime/backend/backend_init_context.h +6 -0
- package/third-party/include/executorch/runtime/backend/interface.h +1 -1
- package/third-party/include/executorch/runtime/core/error.h +76 -49
- package/third-party/include/executorch/runtime/core/exec_aten/util/scalar_type_util.h +18 -4
- package/third-party/include/executorch/runtime/core/memory_allocator.h +12 -2
- package/third-party/include/executorch/runtime/core/named_data_map.h +1 -11
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/macros/Export.h +0 -78
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/macros/Macros.h +1 -520
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/BFloat16-inl.h +1 -339
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/BFloat16.h +1 -122
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/Half-inl.h +1 -347
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/Half.h +6 -419
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/TypeSafeSignMath.h +1 -133
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/bit_cast.h +1 -43
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/complex.h +1 -568
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/floating_point_utils.h +1 -33
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/irange.h +1 -1
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/llvmMathExtras.h +866 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/safe_numerics.h +97 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/macros/Export.h +66 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/macros/Macros.h +553 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/BFloat16.h +477 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/Half.h +781 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/TypeSafeSignMath.h +141 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/bit_cast.h +49 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/complex.h +593 -0
- package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/floating_point_utils.h +38 -0
- package/third-party/include/executorch/runtime/core/tensor_layout.h +1 -1
- package/third-party/include/executorch/runtime/executor/merged_data_map.h +142 -0
- package/third-party/include/executorch/runtime/executor/method.h +21 -8
- package/third-party/include/executorch/runtime/executor/method_meta.h +20 -2
- package/third-party/include/executorch/runtime/executor/program.h +0 -10
- package/third-party/include/executorch/runtime/kernel/operator_registry.h +1 -1
- package/third-party/include/executorch/runtime/platform/compiler.h +2 -0
- package/third-party/include/executorch/schema/extended_header.h +10 -1
- package/third-party/include/torch/headeronly/macros/Export.h +66 -0
- package/third-party/include/torch/headeronly/macros/Macros.h +553 -0
- package/third-party/include/torch/headeronly/util/BFloat16.h +477 -0
- package/third-party/include/torch/headeronly/util/Half.h +781 -0
- package/third-party/include/torch/headeronly/util/TypeSafeSignMath.h +141 -0
- package/third-party/include/torch/headeronly/util/bit_cast.h +49 -0
- package/third-party/include/torch/headeronly/util/complex.h +593 -0
- package/third-party/include/torch/headeronly/util/floating_point_utils.h +38 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/ExecutorchLib +0 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/Info.plist +0 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/ExecutorchLib +0 -0
- package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/Info.plist +0 -0
- package/common/rnexecutorch/tests/run_all_tests.sh +0 -14
- package/common/rnexecutorch/tests/run_test.sh +0 -18
- package/ios/RnExecutorch/utils/Conversions.h +0 -14
- package/ios/RnExecutorch/utils/ETError.h +0 -26
- package/ios/RnExecutorch/utils/ImageProcessor.h +0 -15
- package/ios/RnExecutorch/utils/ImageProcessor.mm +0 -147
- package/ios/RnExecutorch/utils/Numerical.h +0 -3
- package/ios/RnExecutorch/utils/Numerical.mm +0 -18
- package/ios/RnExecutorch/utils/ScalarType.h +0 -14
- package/ios/RnExecutorch/utils/ScalarType.mm +0 -21
- package/lib/module/hooks/useNonStaticModule.js.map +0 -1
- package/lib/typescript/hooks/useNonStaticModule.d.ts +0 -21
- package/lib/typescript/hooks/useNonStaticModule.d.ts.map +0 -1
- package/src/hooks/useNonStaticModule.ts +0 -74
- package/third-party/include/executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h +0 -181
- package/third-party/include/executorch/extension/kernel_util/meta_programming.h +0 -108
- package/third-party/include/executorch/extension/kernel_util/type_list.h +0 -137
- package/third-party/include/executorch/extension/threadpool/threadpool_guard.h +0 -35
|
@@ -0,0 +1,477 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
|
|
3
|
+
// Defines the bfloat16 type (brain floating-point). This representation uses
|
|
4
|
+
// 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa.
|
|
5
|
+
|
|
6
|
+
#include <torch/headeronly/macros/Macros.h>
|
|
7
|
+
#include <torch/headeronly/util/bit_cast.h>
|
|
8
|
+
|
|
9
|
+
#include <cmath>
|
|
10
|
+
#include <cstdint>
|
|
11
|
+
#include <cstring>
|
|
12
|
+
#include <iosfwd>
|
|
13
|
+
#include <ostream>
|
|
14
|
+
|
|
15
|
+
#if defined(__CUDACC__) && !defined(USE_ROCM)
|
|
16
|
+
#include <cuda_bf16.h>
|
|
17
|
+
#endif
|
|
18
|
+
|
|
19
|
+
#if defined(CL_SYCL_LANGUAGE_VERSION)
|
|
20
|
+
#include <CL/sycl.hpp> // for SYCL 1.2.1
|
|
21
|
+
#elif defined(SYCL_LANGUAGE_VERSION)
|
|
22
|
+
#include <sycl/sycl.hpp> // for SYCL 2020
|
|
23
|
+
#endif
|
|
24
|
+
|
|
25
|
+
namespace c10 {
|
|
26
|
+
|
|
27
|
+
struct alignas(2) BFloat16 {
|
|
28
|
+
uint16_t x;
|
|
29
|
+
|
|
30
|
+
// HIP wants __host__ __device__ tag, CUDA does not
|
|
31
|
+
#if defined(USE_ROCM) && defined(__HIPCC__)
|
|
32
|
+
C10_HOST_DEVICE BFloat16() = default;
|
|
33
|
+
#else
|
|
34
|
+
BFloat16() = default;
|
|
35
|
+
#endif
|
|
36
|
+
|
|
37
|
+
struct from_bits_t {};
|
|
38
|
+
static constexpr C10_HOST_DEVICE from_bits_t from_bits() {
|
|
39
|
+
return from_bits_t();
|
|
40
|
+
}
|
|
41
|
+
|
|
42
|
+
constexpr C10_HOST_DEVICE BFloat16(unsigned short bits, from_bits_t)
|
|
43
|
+
: x(bits) {}
|
|
44
|
+
/* implicit */ inline C10_HOST_DEVICE BFloat16(float value);
|
|
45
|
+
inline C10_HOST_DEVICE operator float() const;
|
|
46
|
+
|
|
47
|
+
#if defined(__CUDACC__) && !defined(USE_ROCM)
|
|
48
|
+
inline C10_HOST_DEVICE BFloat16(const __nv_bfloat16 &value);
|
|
49
|
+
explicit inline C10_HOST_DEVICE operator __nv_bfloat16() const;
|
|
50
|
+
#endif
|
|
51
|
+
|
|
52
|
+
#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
|
|
53
|
+
inline C10_HOST_DEVICE BFloat16(const sycl::ext::oneapi::bfloat16 &value);
|
|
54
|
+
explicit inline C10_HOST_DEVICE operator sycl::ext::oneapi::bfloat16() const;
|
|
55
|
+
#endif
|
|
56
|
+
};
|
|
57
|
+
|
|
58
|
+
/// Streams a BFloat16 by writing its float value to `out`.
inline std::ostream &operator<<(std::ostream &out, const BFloat16 &value) {
  return out << static_cast<float>(value);
}
|
|
62
|
+
|
|
63
|
+
namespace detail {
|
|
64
|
+
inline C10_HOST_DEVICE float f32_from_bits(uint16_t src) {
|
|
65
|
+
float res = 0;
|
|
66
|
+
uint32_t tmp = src;
|
|
67
|
+
tmp <<= 16;
|
|
68
|
+
|
|
69
|
+
#if defined(USE_ROCM) && defined(__HIPCC__)
|
|
70
|
+
float *tempRes;
|
|
71
|
+
|
|
72
|
+
// We should be using memcpy in order to respect the strict aliasing rule
|
|
73
|
+
// but it fails in the HIP environment.
|
|
74
|
+
tempRes = reinterpret_cast<float *>(&tmp);
|
|
75
|
+
res = *tempRes;
|
|
76
|
+
#else
|
|
77
|
+
std::memcpy(&res, &tmp, sizeof(tmp));
|
|
78
|
+
#endif
|
|
79
|
+
|
|
80
|
+
return res;
|
|
81
|
+
}
|
|
82
|
+
|
|
83
|
+
inline C10_HOST_DEVICE uint16_t bits_from_f32(float src) {
|
|
84
|
+
uint32_t res = 0;
|
|
85
|
+
|
|
86
|
+
#if defined(USE_ROCM) && defined(__HIPCC__)
|
|
87
|
+
// We should be using memcpy in order to respect the strict aliasing rule
|
|
88
|
+
// but it fails in the HIP environment.
|
|
89
|
+
uint32_t *tempRes = reinterpret_cast<uint32_t *>(&src);
|
|
90
|
+
res = *tempRes;
|
|
91
|
+
#else
|
|
92
|
+
std::memcpy(&res, &src, sizeof(res));
|
|
93
|
+
#endif
|
|
94
|
+
|
|
95
|
+
return res >> 16;
|
|
96
|
+
}
|
|
97
|
+
|
|
98
|
+
inline C10_HOST_DEVICE uint16_t round_to_nearest_even(float src) {
|
|
99
|
+
#if defined(USE_ROCM) && defined(__HIPCC__)
|
|
100
|
+
if (src != src) {
|
|
101
|
+
#elif defined(_MSC_VER)
|
|
102
|
+
if (isnan(src)) {
|
|
103
|
+
#else
|
|
104
|
+
if (std::isnan(src)) {
|
|
105
|
+
#endif
|
|
106
|
+
return UINT16_C(0x7FC0);
|
|
107
|
+
} else {
|
|
108
|
+
const uint32_t U32 = c10::bit_cast<uint32_t>(src);
|
|
109
|
+
uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
|
|
110
|
+
return static_cast<uint16_t>((U32 + rounding_bias) >> 16);
|
|
111
|
+
}
|
|
112
|
+
}
|
|
113
|
+
|
|
114
|
+
} // namespace detail
|
|
115
|
+
|
|
116
|
+
//-------- the following is copied from c10/util/BFloat16-inl.h ---------//
|
|
117
|
+
C10_CLANG_DIAGNOSTIC_PUSH()
|
|
118
|
+
#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
|
|
119
|
+
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
|
|
120
|
+
#endif
|
|
121
|
+
|
|
122
|
+
/// Constructors
|
|
123
|
+
/// Constructs from a float. Uses the CUDA or SYCL hardware conversion when
/// compiling device code that provides one; otherwise rounds to nearest even
/// in software.
inline C10_HOST_DEVICE BFloat16::BFloat16(float value)
#if defined(__CUDACC__) && !defined(USE_ROCM) && defined(__CUDA_ARCH__) &&     \
    __CUDA_ARCH__ >= 800
    : x(__bfloat16_as_ushort(__float2bfloat16(value)))
#elif defined(__SYCL_DEVICE_ONLY__) &&                                         \
    defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
    : x(c10::bit_cast<uint16_t>(sycl::ext::oneapi::bfloat16(value)))
#else
    // RNE by default
    : x(detail::round_to_nearest_even(value))
#endif
{
}
|
|
137
|
+
|
|
138
|
+
/// Implicit conversions
|
|
139
|
+
/// Widens the stored bits to float, via the CUDA or SYCL bfloat16 type when
/// available and via the software bit shift otherwise.
inline C10_HOST_DEVICE BFloat16::operator float() const {
#if defined(__CUDACC__) && !defined(USE_ROCM)
  const auto *as_nv = reinterpret_cast<const __nv_bfloat16 *>(&x);
  return __bfloat162float(*as_nv);
#elif defined(__SYCL_DEVICE_ONLY__) &&                                         \
    defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
  const auto *as_sycl =
      reinterpret_cast<const sycl::ext::oneapi::bfloat16 *>(&x);
  return float(*as_sycl);
#else
  return detail::f32_from_bits(x);
#endif
}
|
|
149
|
+
|
|
150
|
+
#if defined(__CUDACC__) && !defined(USE_ROCM)
|
|
151
|
+
/// Copies the 16-bit representation out of a CUDA bfloat16 value.
inline C10_HOST_DEVICE BFloat16::BFloat16(const __nv_bfloat16 &value) {
  const auto *raw = reinterpret_cast<const unsigned short *>(&value);
  x = *raw;
}
|
|
154
|
+
/// Reinterprets the stored bits as a CUDA bfloat16 value.
inline C10_HOST_DEVICE BFloat16::operator __nv_bfloat16() const {
  const auto *as_nv = reinterpret_cast<const __nv_bfloat16 *>(&x);
  return *as_nv;
}
|
|
157
|
+
#endif
|
|
158
|
+
|
|
159
|
+
#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
|
|
160
|
+
/// Copies the 16-bit representation out of a SYCL oneAPI bfloat16 value.
inline C10_HOST_DEVICE
BFloat16::BFloat16(const sycl::ext::oneapi::bfloat16 &value) {
  const auto *raw = reinterpret_cast<const unsigned short *>(&value);
  x = *raw;
}
|
|
164
|
+
/// Reinterprets the stored bits as a SYCL oneAPI bfloat16 value.
inline C10_HOST_DEVICE BFloat16::operator sycl::ext::oneapi::bfloat16() const {
  const auto *as_sycl =
      reinterpret_cast<const sycl::ext::oneapi::bfloat16 *>(&x);
  return *as_sycl;
}
|
|
167
|
+
#endif
|
|
168
|
+
|
|
169
|
+
// CUDA intrinsics
|
|
170
|
+
|
|
171
|
+
#if defined(__CUDACC__) || defined(__HIPCC__)
|
|
172
|
+
/// Loads a BFloat16 through `ptr`. On CUDA architectures >= 800 this forwards
/// to the `__ldg` overload for `__nv_bfloat16`; elsewhere it is a plain
/// dereference.
inline C10_DEVICE BFloat16 __ldg(const BFloat16 *ptr) {
#if !defined(USE_ROCM) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  const auto *nv_ptr = reinterpret_cast<const __nv_bfloat16 *>(ptr);
  return __ldg(nv_ptr);
#else
  return ptr[0];
#endif
}
|
|
179
|
+
#endif
|
|
180
|
+
|
|
181
|
+
/// Arithmetic
|
|
182
|
+
|
|
183
|
+
/// Adds two BFloat16 values in float precision; the float sum is converted
/// back to BFloat16 on return.
inline C10_HOST_DEVICE BFloat16 operator+(const BFloat16 &a,
                                          const BFloat16 &b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs + rhs;
}
|
|
187
|
+
|
|
188
|
+
/// Subtracts `b` from `a` in float precision; the float difference is
/// converted back to BFloat16 on return.
inline C10_HOST_DEVICE BFloat16 operator-(const BFloat16 &a,
                                          const BFloat16 &b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs - rhs;
}
|
|
192
|
+
|
|
193
|
+
/// Multiplies two BFloat16 values in float precision; the float product is
/// converted back to BFloat16 on return.
inline C10_HOST_DEVICE BFloat16 operator*(const BFloat16 &a,
                                          const BFloat16 &b) {
  const float lhs = static_cast<float>(a);
  const float rhs = static_cast<float>(b);
  return lhs * rhs;
}
|
|
197
|
+
|
|
198
|
+
/// Divides `a` by `b` in float precision; the float quotient is converted
/// back to BFloat16 on return. The function carries the
/// `__ubsan_ignore_float_divide_by_zero__` annotation.
inline C10_HOST_DEVICE BFloat16 operator/(const BFloat16 &a, const BFloat16 &b)
    __ubsan_ignore_float_divide_by_zero__ {
  const float numerator = static_cast<float>(a);
  const float denominator = static_cast<float>(b);
  return numerator / denominator;
}
|
|
202
|
+
|
|
203
|
+
/// Unary minus: negates the float value and converts the result back to
/// BFloat16 on return.
inline C10_HOST_DEVICE BFloat16 operator-(const BFloat16 &a) {
  const float as_float = static_cast<float>(a);
  return -as_float;
}
|
|
206
|
+
|
|
207
|
+
/// Replaces `a` with `a + b` and returns a reference to `a`.
inline C10_HOST_DEVICE BFloat16 &operator+=(BFloat16 &a, const BFloat16 &b) {
  return a = a + b;
}
|
|
211
|
+
|
|
212
|
+
/// Replaces `a` with `a - b` and returns a reference to `a`.
inline C10_HOST_DEVICE BFloat16 &operator-=(BFloat16 &a, const BFloat16 &b) {
  return a = a - b;
}
|
|
216
|
+
|
|
217
|
+
/// Replaces `a` with `a * b` and returns a reference to `a`.
inline C10_HOST_DEVICE BFloat16 &operator*=(BFloat16 &a, const BFloat16 &b) {
  return a = a * b;
}
|
|
221
|
+
|
|
222
|
+
/// Replaces `a` with `a / b` and returns a reference to `a`.
inline C10_HOST_DEVICE BFloat16 &operator/=(BFloat16 &a, const BFloat16 &b) {
  return a = a / b;
}
|
|
226
|
+
|
|
227
|
+
/// Bitwise OR on the raw encodings. Note that this modifies `a` in place and
/// returns a reference to it rather than producing a new value.
inline C10_HOST_DEVICE BFloat16 &operator|(BFloat16 &a, const BFloat16 &b) {
  a.x |= b.x;
  return a;
}
|
|
231
|
+
|
|
232
|
+
/// Bitwise XOR on the raw encodings. Note that this modifies `a` in place and
/// returns a reference to it rather than producing a new value.
inline C10_HOST_DEVICE BFloat16 &operator^(BFloat16 &a, const BFloat16 &b) {
  a.x ^= b.x;
  return a;
}
|
|
236
|
+
|
|
237
|
+
/// Bitwise AND on the raw encodings. Note that this modifies `a` in place and
/// returns a reference to it rather than producing a new value.
inline C10_HOST_DEVICE BFloat16 &operator&(BFloat16 &a, const BFloat16 &b) {
  a.x &= b.x;
  return a;
}
|
|
241
|
+
|
|
242
|
+
/// Arithmetic with floats
|
|
243
|
+
|
|
244
|
+
/// BFloat16 + float, evaluated and returned as float.
inline C10_HOST_DEVICE float operator+(BFloat16 a, float b) {
  const float lhs = static_cast<float>(a);
  return lhs + b;
}
|
|
247
|
+
/// BFloat16 - float, evaluated and returned as float.
inline C10_HOST_DEVICE float operator-(BFloat16 a, float b) {
  const float lhs = static_cast<float>(a);
  return lhs - b;
}
|
|
250
|
+
/// BFloat16 * float, evaluated and returned as float.
inline C10_HOST_DEVICE float operator*(BFloat16 a, float b) {
  const float lhs = static_cast<float>(a);
  return lhs * b;
}
|
|
253
|
+
/// BFloat16 / float, evaluated and returned as float.
inline C10_HOST_DEVICE float operator/(BFloat16 a, float b) {
  const float numerator = static_cast<float>(a);
  return numerator / b;
}
|
|
256
|
+
|
|
257
|
+
/// float + BFloat16, evaluated and returned as float.
inline C10_HOST_DEVICE float operator+(float a, BFloat16 b) {
  const float rhs = static_cast<float>(b);
  return a + rhs;
}
|
|
260
|
+
/// float - BFloat16, evaluated and returned as float.
inline C10_HOST_DEVICE float operator-(float a, BFloat16 b) {
  const float rhs = static_cast<float>(b);
  return a - rhs;
}
|
|
263
|
+
/// float * BFloat16, evaluated and returned as float.
inline C10_HOST_DEVICE float operator*(float a, BFloat16 b) {
  const float rhs = static_cast<float>(b);
  return a * rhs;
}
|
|
266
|
+
/// float / BFloat16, evaluated and returned as float.
inline C10_HOST_DEVICE float operator/(float a, BFloat16 b) {
  const float denominator = static_cast<float>(b);
  return a / denominator;
}
|
|
269
|
+
|
|
270
|
+
/// Adds the float value of `b` to `a` in place and returns `a`.
inline C10_HOST_DEVICE float &operator+=(float &a, const BFloat16 &b) {
  a += static_cast<float>(b);
  return a;
}
|
|
273
|
+
/// Subtracts the float value of `b` from `a` in place and returns `a`.
inline C10_HOST_DEVICE float &operator-=(float &a, const BFloat16 &b) {
  a -= static_cast<float>(b);
  return a;
}
|
|
276
|
+
/// Multiplies `a` by the float value of `b` in place and returns `a`.
inline C10_HOST_DEVICE float &operator*=(float &a, const BFloat16 &b) {
  a *= static_cast<float>(b);
  return a;
}
|
|
279
|
+
/// Divides `a` by the float value of `b` in place and returns `a`.
inline C10_HOST_DEVICE float &operator/=(float &a, const BFloat16 &b) {
  a /= static_cast<float>(b);
  return a;
}
|
|
282
|
+
|
|
283
|
+
/// Arithmetic with doubles
|
|
284
|
+
|
|
285
|
+
/// BFloat16 + double, evaluated and returned as double.
inline C10_HOST_DEVICE double operator+(BFloat16 a, double b) {
  const double lhs = static_cast<double>(a);
  return lhs + b;
}
|
|
288
|
+
/// BFloat16 - double, evaluated and returned as double.
inline C10_HOST_DEVICE double operator-(BFloat16 a, double b) {
  const double lhs = static_cast<double>(a);
  return lhs - b;
}
|
|
291
|
+
/// BFloat16 * double, evaluated and returned as double.
inline C10_HOST_DEVICE double operator*(BFloat16 a, double b) {
  const double lhs = static_cast<double>(a);
  return lhs * b;
}
|
|
294
|
+
/// BFloat16 / double, evaluated and returned as double.
inline C10_HOST_DEVICE double operator/(BFloat16 a, double b) {
  const double numerator = static_cast<double>(a);
  return numerator / b;
}
|
|
297
|
+
|
|
298
|
+
/// double + BFloat16, evaluated and returned as double.
inline C10_HOST_DEVICE double operator+(double a, BFloat16 b) {
  const double rhs = static_cast<double>(b);
  return a + rhs;
}
|
|
301
|
+
/// double - BFloat16, evaluated and returned as double.
inline C10_HOST_DEVICE double operator-(double a, BFloat16 b) {
  const double rhs = static_cast<double>(b);
  return a - rhs;
}
|
|
304
|
+
/// double * BFloat16, evaluated and returned as double.
inline C10_HOST_DEVICE double operator*(double a, BFloat16 b) {
  const double rhs = static_cast<double>(b);
  return a * rhs;
}
|
|
307
|
+
/// double / BFloat16, evaluated and returned as double.
inline C10_HOST_DEVICE double operator/(double a, BFloat16 b) {
  const double denominator = static_cast<double>(b);
  return a / denominator;
}
|
|
310
|
+
|
|
311
|
+
/// Arithmetic with ints
|
|
312
|
+
|
|
313
|
+
/// BFloat16 + int: the int is converted to BFloat16 first, then the
/// BFloat16 addition is applied.
inline C10_HOST_DEVICE BFloat16 operator+(BFloat16 a, int b) {
  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
  const BFloat16 rhs = static_cast<BFloat16>(b);
  return a + rhs;
}
|
|
317
|
+
/// BFloat16 - int: the int is converted to BFloat16 first, then the
/// BFloat16 subtraction is applied.
inline C10_HOST_DEVICE BFloat16 operator-(BFloat16 a, int b) {
  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
  const BFloat16 rhs = static_cast<BFloat16>(b);
  return a - rhs;
}
|
|
321
|
+
/// BFloat16 * int: the int is converted to BFloat16 first, then the
/// BFloat16 multiplication is applied.
inline C10_HOST_DEVICE BFloat16 operator*(BFloat16 a, int b) {
  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
  const BFloat16 rhs = static_cast<BFloat16>(b);
  return a * rhs;
}
|
|
325
|
+
/// BFloat16 / int: the int is converted to BFloat16 first, then the
/// BFloat16 division is applied.
inline C10_HOST_DEVICE BFloat16 operator/(BFloat16 a, int b) {
  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
  const BFloat16 divisor = static_cast<BFloat16>(b);
  return a / divisor;
}
|
|
329
|
+
|
|
330
|
+
/// Adds a BFloat16 to an int. The int is converted to BFloat16 and `+` is
/// applied to the two BFloat16 values; the result is a BFloat16.
inline C10_HOST_DEVICE BFloat16 operator+(int a, BFloat16 b) {
  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
  const BFloat16 lhs = static_cast<BFloat16>(a);
  return lhs + b;
}
|
|
334
|
+
/// Subtracts a BFloat16 from an int. The int is converted to BFloat16 and `-`
/// is applied to the two BFloat16 values; the result is a BFloat16.
inline C10_HOST_DEVICE BFloat16 operator-(int a, BFloat16 b) {
  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
  const BFloat16 lhs = static_cast<BFloat16>(a);
  return lhs - b;
}
|
|
338
|
+
/// Multiplies an int by a BFloat16. The int is converted to BFloat16 and `*`
/// is applied to the two BFloat16 values; the result is a BFloat16.
inline C10_HOST_DEVICE BFloat16 operator*(int a, BFloat16 b) {
  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
  const BFloat16 lhs = static_cast<BFloat16>(a);
  return lhs * b;
}
|
|
342
|
+
/// Divides an int by a BFloat16. The int is converted to BFloat16 and `/` is
/// applied to the two BFloat16 values; the result is a BFloat16.
inline C10_HOST_DEVICE BFloat16 operator/(int a, BFloat16 b) {
  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
  const BFloat16 lhs = static_cast<BFloat16>(a);
  return lhs / b;
}
|
|
346
|
+
|
|
347
|
+
/// Arithmetic with int64_t
|
|
348
|
+
|
|
349
|
+
/// Adds an int64_t to a BFloat16. The integer is converted to BFloat16 and `+`
/// is applied to the two BFloat16 values; the result is a BFloat16.
inline C10_HOST_DEVICE BFloat16 operator+(BFloat16 a, int64_t b) {
  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
  const BFloat16 rhs = static_cast<BFloat16>(b);
  return a + rhs;
}
|
|
353
|
+
/// Subtracts an int64_t from a BFloat16. The integer is converted to BFloat16
/// and `-` is applied to the two BFloat16 values; the result is a BFloat16.
inline C10_HOST_DEVICE BFloat16 operator-(BFloat16 a, int64_t b) {
  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
  const BFloat16 rhs = static_cast<BFloat16>(b);
  return a - rhs;
}
|
|
357
|
+
/// Multiplies a BFloat16 by an int64_t. The integer is converted to BFloat16
/// and `*` is applied to the two BFloat16 values; the result is a BFloat16.
inline C10_HOST_DEVICE BFloat16 operator*(BFloat16 a, int64_t b) {
  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
  const BFloat16 rhs = static_cast<BFloat16>(b);
  return a * rhs;
}
|
|
361
|
+
/// Divides a BFloat16 by an int64_t. The integer is converted to BFloat16 and
/// `/` is applied to the two BFloat16 values; the result is a BFloat16.
inline C10_HOST_DEVICE BFloat16 operator/(BFloat16 a, int64_t b) {
  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
  const BFloat16 rhs = static_cast<BFloat16>(b);
  return a / rhs;
}
|
|
365
|
+
|
|
366
|
+
/// Adds a BFloat16 to an int64_t. The integer is converted to BFloat16 and `+`
/// is applied to the two BFloat16 values; the result is a BFloat16.
inline C10_HOST_DEVICE BFloat16 operator+(int64_t a, BFloat16 b) {
  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
  const BFloat16 lhs = static_cast<BFloat16>(a);
  return lhs + b;
}
|
|
370
|
+
/// Subtracts a BFloat16 from an int64_t. The integer is converted to BFloat16
/// and `-` is applied to the two BFloat16 values; the result is a BFloat16.
inline C10_HOST_DEVICE BFloat16 operator-(int64_t a, BFloat16 b) {
  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
  const BFloat16 lhs = static_cast<BFloat16>(a);
  return lhs - b;
}
|
|
374
|
+
/// Multiplies an int64_t by a BFloat16. The integer is converted to BFloat16
/// and `*` is applied to the two BFloat16 values; the result is a BFloat16.
inline C10_HOST_DEVICE BFloat16 operator*(int64_t a, BFloat16 b) {
  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
  const BFloat16 lhs = static_cast<BFloat16>(a);
  return lhs * b;
}
|
|
378
|
+
/// Divides an int64_t by a BFloat16. The integer is converted to BFloat16 and
/// `/` is applied to the two BFloat16 values; the result is a BFloat16.
inline C10_HOST_DEVICE BFloat16 operator/(int64_t a, BFloat16 b) {
  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
  const BFloat16 lhs = static_cast<BFloat16>(a);
  return lhs / b;
}
|
|
382
|
+
|
|
383
|
+
// Overloading < and > operators, because std::max and std::min use them.
|
|
384
|
+
|
|
385
|
+
/// Greater-than for two BFloat16 values, evaluated on their float conversions.
///
/// Both parameters are non-const lvalue references, so this overload is only
/// viable when both operands are mutable lvalues.
/// NOTE(review): const operands and temporaries cannot bind here; presumably
/// they are compared through an implicit conversion instead -- confirm before
/// changing the signature.
inline C10_HOST_DEVICE bool operator>(BFloat16 &lhs, BFloat16 &rhs) {
  const float left = static_cast<float>(lhs);
  const float right = static_cast<float>(rhs);
  return left > right;
}
|
|
388
|
+
|
|
389
|
+
/// Less-than for two BFloat16 values, evaluated on their float conversions.
///
/// Both parameters are non-const lvalue references, so this overload is only
/// viable when both operands are mutable lvalues.
/// NOTE(review): const operands and temporaries cannot bind here; presumably
/// they are compared through an implicit conversion instead -- confirm before
/// changing the signature.
inline C10_HOST_DEVICE bool operator<(BFloat16 &lhs, BFloat16 &rhs) {
  const float left = static_cast<float>(lhs);
  const float right = static_cast<float>(rhs);
  return left < right;
}
|
|
392
|
+
|
|
393
|
+
C10_CLANG_DIAGNOSTIC_POP()
|
|
394
|
+
} // namespace c10
|
|
395
|
+
|
|
396
|
+
namespace torch::headeronly {
|
|
397
|
+
|
|
398
|
+
namespace detail {
|
|
399
|
+
using c10::detail::bits_from_f32;
|
|
400
|
+
using c10::detail::f32_from_bits;
|
|
401
|
+
using c10::detail::round_to_nearest_even;
|
|
402
|
+
} // namespace detail
|
|
403
|
+
|
|
404
|
+
using c10::BFloat16;
|
|
405
|
+
using c10::operator+;
|
|
406
|
+
using c10::operator-;
|
|
407
|
+
using c10::operator*;
|
|
408
|
+
using c10::operator/;
|
|
409
|
+
using c10::operator+=;
|
|
410
|
+
using c10::operator-=;
|
|
411
|
+
using c10::operator*=;
|
|
412
|
+
using c10::operator/=;
|
|
413
|
+
using c10::operator<;
|
|
414
|
+
using c10::operator>;
|
|
415
|
+
using c10::operator<<;
|
|
416
|
+
} // namespace torch::headeronly
|
|
417
|
+
|
|
418
|
+
namespace std {
|
|
419
|
+
|
|
420
|
+
template <> class numeric_limits<c10::BFloat16> {
|
|
421
|
+
public:
|
|
422
|
+
static constexpr bool is_signed = true;
|
|
423
|
+
static constexpr bool is_specialized = true;
|
|
424
|
+
static constexpr bool is_integer = false;
|
|
425
|
+
static constexpr bool is_exact = false;
|
|
426
|
+
static constexpr bool has_infinity = true;
|
|
427
|
+
static constexpr bool has_quiet_NaN = true;
|
|
428
|
+
static constexpr bool has_signaling_NaN = true;
|
|
429
|
+
static constexpr auto has_denorm = numeric_limits<float>::has_denorm;
|
|
430
|
+
static constexpr auto has_denorm_loss =
|
|
431
|
+
numeric_limits<float>::has_denorm_loss;
|
|
432
|
+
static constexpr auto round_style = numeric_limits<float>::round_style;
|
|
433
|
+
static constexpr bool is_iec559 = false;
|
|
434
|
+
static constexpr bool is_bounded = true;
|
|
435
|
+
static constexpr bool is_modulo = false;
|
|
436
|
+
static constexpr int digits = 8;
|
|
437
|
+
static constexpr int digits10 = 2;
|
|
438
|
+
static constexpr int max_digits10 = 4;
|
|
439
|
+
static constexpr int radix = 2;
|
|
440
|
+
static constexpr int min_exponent = -125;
|
|
441
|
+
static constexpr int min_exponent10 = -37;
|
|
442
|
+
static constexpr int max_exponent = 128;
|
|
443
|
+
static constexpr int max_exponent10 = 38;
|
|
444
|
+
static constexpr auto traps = numeric_limits<float>::traps;
|
|
445
|
+
static constexpr auto tinyness_before =
|
|
446
|
+
numeric_limits<float>::tinyness_before;
|
|
447
|
+
|
|
448
|
+
static constexpr c10::BFloat16 min() {
|
|
449
|
+
return c10::BFloat16(0x0080, c10::BFloat16::from_bits());
|
|
450
|
+
}
|
|
451
|
+
static constexpr c10::BFloat16 lowest() {
|
|
452
|
+
return c10::BFloat16(0xFF7F, c10::BFloat16::from_bits());
|
|
453
|
+
}
|
|
454
|
+
static constexpr c10::BFloat16 max() {
|
|
455
|
+
return c10::BFloat16(0x7F7F, c10::BFloat16::from_bits());
|
|
456
|
+
}
|
|
457
|
+
static constexpr c10::BFloat16 epsilon() {
|
|
458
|
+
return c10::BFloat16(0x3C00, c10::BFloat16::from_bits());
|
|
459
|
+
}
|
|
460
|
+
static constexpr c10::BFloat16 round_error() {
|
|
461
|
+
return c10::BFloat16(0x3F00, c10::BFloat16::from_bits());
|
|
462
|
+
}
|
|
463
|
+
static constexpr c10::BFloat16 infinity() {
|
|
464
|
+
return c10::BFloat16(0x7F80, c10::BFloat16::from_bits());
|
|
465
|
+
}
|
|
466
|
+
static constexpr c10::BFloat16 quiet_NaN() {
|
|
467
|
+
return c10::BFloat16(0x7FC0, c10::BFloat16::from_bits());
|
|
468
|
+
}
|
|
469
|
+
static constexpr c10::BFloat16 signaling_NaN() {
|
|
470
|
+
return c10::BFloat16(0x7F80, c10::BFloat16::from_bits());
|
|
471
|
+
}
|
|
472
|
+
static constexpr c10::BFloat16 denorm_min() {
|
|
473
|
+
return c10::BFloat16(0x0001, c10::BFloat16::from_bits());
|
|
474
|
+
}
|
|
475
|
+
};
|
|
476
|
+
|
|
477
|
+
} // namespace std
|