react-native-executorch 0.5.15 → 0.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (277) hide show
  1. package/README.md +42 -36
  2. package/android/CMakeLists.txt +13 -25
  3. package/android/build.gradle +2 -3
  4. package/android/libs/classes.jar +0 -0
  5. package/android/src/main/cpp/CMakeLists.txt +2 -1
  6. package/common/rnexecutorch/RnExecutorchInstaller.cpp +18 -0
  7. package/common/rnexecutorch/TokenizerModule.cpp +3 -3
  8. package/common/rnexecutorch/data_processing/Numerical.cpp +31 -23
  9. package/common/rnexecutorch/data_processing/Numerical.h +6 -1
  10. package/common/rnexecutorch/data_processing/dsp.cpp +0 -46
  11. package/common/rnexecutorch/host_objects/JsiConversions.h +16 -0
  12. package/common/rnexecutorch/host_objects/ModelHostObject.h +26 -11
  13. package/common/rnexecutorch/jsi/OwningArrayBuffer.h +19 -2
  14. package/common/rnexecutorch/metaprogramming/TypeConcepts.h +0 -20
  15. package/common/rnexecutorch/models/BaseModel.cpp +12 -11
  16. package/common/rnexecutorch/models/BaseModel.h +18 -10
  17. package/common/rnexecutorch/models/embeddings/BaseEmbeddings.cpp +3 -11
  18. package/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp +0 -1
  19. package/common/rnexecutorch/models/image_segmentation/ImageSegmentation.cpp +6 -12
  20. package/common/rnexecutorch/models/llm/LLM.cpp +25 -8
  21. package/common/rnexecutorch/models/llm/LLM.h +4 -4
  22. package/common/rnexecutorch/models/ocr/CTCLabelConverter.h +1 -1
  23. package/common/rnexecutorch/models/ocr/utils/RecognitionHandlerUtils.cpp +7 -4
  24. package/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp +8 -13
  25. package/common/rnexecutorch/models/speech_to_text/SpeechToText.h +1 -3
  26. package/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp +12 -19
  27. package/common/rnexecutorch/models/speech_to_text/asr/ASR.h +4 -5
  28. package/common/rnexecutorch/models/text_to_image/Constants.h +9 -0
  29. package/common/rnexecutorch/models/text_to_image/Decoder.cpp +32 -0
  30. package/common/rnexecutorch/models/text_to_image/Decoder.h +24 -0
  31. package/common/rnexecutorch/models/text_to_image/Encoder.cpp +44 -0
  32. package/common/rnexecutorch/models/text_to_image/Encoder.h +32 -0
  33. package/common/rnexecutorch/models/text_to_image/Scheduler.cpp +152 -0
  34. package/common/rnexecutorch/models/text_to_image/Scheduler.h +41 -0
  35. package/common/rnexecutorch/models/text_to_image/TextToImage.cpp +141 -0
  36. package/common/rnexecutorch/models/text_to_image/TextToImage.h +64 -0
  37. package/common/rnexecutorch/models/text_to_image/UNet.cpp +38 -0
  38. package/common/rnexecutorch/models/text_to_image/UNet.h +28 -0
  39. package/common/rnexecutorch/models/voice_activity_detection/Constants.h +27 -0
  40. package/common/rnexecutorch/models/voice_activity_detection/Types.h +12 -0
  41. package/common/rnexecutorch/models/voice_activity_detection/Utils.cpp +15 -0
  42. package/common/rnexecutorch/models/voice_activity_detection/Utils.h +13 -0
  43. package/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.cpp +160 -0
  44. package/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.h +36 -0
  45. package/common/rnexecutorch/tests/CMakeLists.txt +30 -0
  46. package/common/rnexecutorch/tests/NumericalTest.cpp +110 -0
  47. package/common/rnexecutorch/tests/README.md +30 -13
  48. package/common/rnexecutorch/threads/GlobalThreadPool.h +4 -0
  49. package/common/runner/arange_util.cpp +44 -0
  50. package/common/runner/arange_util.h +37 -0
  51. package/common/runner/constants.h +28 -0
  52. package/common/runner/io_manager.h +240 -0
  53. package/common/runner/irunner.h +87 -16
  54. package/common/runner/kernel_includes.h +23 -0
  55. package/common/runner/runner.cpp +151 -66
  56. package/common/runner/runner.h +39 -22
  57. package/common/runner/sampler.cpp +8 -1
  58. package/common/runner/sampler.h +4 -2
  59. package/common/runner/stats.h +1 -4
  60. package/common/runner/text_decoder_runner.cpp +26 -12
  61. package/common/runner/text_decoder_runner.h +52 -31
  62. package/common/runner/text_prefiller.cpp +46 -12
  63. package/common/runner/text_prefiller.h +38 -4
  64. package/common/runner/text_token_generator.h +51 -26
  65. package/common/runner/util.h +53 -8
  66. package/ios/RnExecutorch.xcodeproj/project.pbxproj +0 -23
  67. package/lib/module/Error.js +1 -0
  68. package/lib/module/Error.js.map +1 -1
  69. package/lib/module/constants/directories.js +1 -1
  70. package/lib/module/constants/directories.js.map +1 -1
  71. package/lib/module/constants/modelUrls.js +32 -1
  72. package/lib/module/constants/modelUrls.js.map +1 -1
  73. package/lib/module/constants/ocr/models.js +7 -7
  74. package/lib/module/constants/ocr/models.js.map +1 -1
  75. package/lib/module/constants/ocr/symbols.js +3 -2
  76. package/lib/module/constants/ocr/symbols.js.map +1 -1
  77. package/lib/module/controllers/LLMController.js +10 -1
  78. package/lib/module/controllers/LLMController.js.map +1 -1
  79. package/lib/module/controllers/OCRController.js +3 -3
  80. package/lib/module/controllers/OCRController.js.map +1 -1
  81. package/lib/module/controllers/VerticalOCRController.js +2 -2
  82. package/lib/module/controllers/VerticalOCRController.js.map +1 -1
  83. package/lib/module/hooks/computer_vision/useOCR.js +3 -3
  84. package/lib/module/hooks/computer_vision/useOCR.js.map +1 -1
  85. package/lib/module/hooks/{useNonStaticModule.js → computer_vision/useTextToImage.js} +21 -16
  86. package/lib/module/hooks/computer_vision/useTextToImage.js.map +1 -0
  87. package/lib/module/hooks/computer_vision/useVerticalOCR.js +3 -3
  88. package/lib/module/hooks/computer_vision/useVerticalOCR.js.map +1 -1
  89. package/lib/module/hooks/natural_language_processing/useLLM.js +3 -3
  90. package/lib/module/hooks/natural_language_processing/useLLM.js.map +1 -1
  91. package/lib/module/hooks/natural_language_processing/useTokenizer.js +5 -5
  92. package/lib/module/hooks/natural_language_processing/useTokenizer.js.map +1 -1
  93. package/lib/module/hooks/natural_language_processing/useVAD.js +13 -0
  94. package/lib/module/hooks/natural_language_processing/useVAD.js.map +1 -0
  95. package/lib/module/index.js +7 -2
  96. package/lib/module/index.js.map +1 -1
  97. package/lib/module/modules/computer_vision/OCRModule.js +2 -2
  98. package/lib/module/modules/computer_vision/OCRModule.js.map +1 -1
  99. package/lib/module/modules/computer_vision/TextToImageModule.js +48 -0
  100. package/lib/module/modules/computer_vision/TextToImageModule.js.map +1 -0
  101. package/lib/module/modules/computer_vision/VerticalOCRModule.js +2 -2
  102. package/lib/module/modules/computer_vision/VerticalOCRModule.js.map +1 -1
  103. package/lib/module/modules/natural_language_processing/SpeechToTextModule.js +7 -4
  104. package/lib/module/modules/natural_language_processing/SpeechToTextModule.js.map +1 -1
  105. package/lib/module/modules/natural_language_processing/VADModule.js +19 -0
  106. package/lib/module/modules/natural_language_processing/VADModule.js.map +1 -0
  107. package/lib/module/types/llm.js.map +1 -1
  108. package/lib/module/types/vad.js +2 -0
  109. package/lib/module/types/vad.js.map +1 -0
  110. package/lib/module/utils/ResourceFetcher.js +2 -1
  111. package/lib/module/utils/ResourceFetcher.js.map +1 -1
  112. package/lib/module/utils/ResourceFetcherUtils.js +6 -6
  113. package/lib/module/utils/ResourceFetcherUtils.js.map +1 -1
  114. package/lib/typescript/Error.d.ts +1 -0
  115. package/lib/typescript/Error.d.ts.map +1 -1
  116. package/lib/typescript/constants/modelUrls.d.ts +23 -0
  117. package/lib/typescript/constants/modelUrls.d.ts.map +1 -1
  118. package/lib/typescript/constants/ocr/symbols.d.ts +1 -1
  119. package/lib/typescript/constants/ocr/symbols.d.ts.map +1 -1
  120. package/lib/typescript/controllers/LLMController.d.ts.map +1 -1
  121. package/lib/typescript/controllers/OCRController.d.ts +1 -1
  122. package/lib/typescript/controllers/OCRController.d.ts.map +1 -1
  123. package/lib/typescript/controllers/VerticalOCRController.d.ts +1 -1
  124. package/lib/typescript/controllers/VerticalOCRController.d.ts.map +1 -1
  125. package/lib/typescript/hooks/computer_vision/useOCR.d.ts +1 -1
  126. package/lib/typescript/hooks/computer_vision/useOCR.d.ts.map +1 -1
  127. package/lib/typescript/hooks/computer_vision/useTextToImage.d.ts +22 -0
  128. package/lib/typescript/hooks/computer_vision/useTextToImage.d.ts.map +1 -0
  129. package/lib/typescript/hooks/computer_vision/useVerticalOCR.d.ts +1 -1
  130. package/lib/typescript/hooks/computer_vision/useVerticalOCR.d.ts.map +1 -1
  131. package/lib/typescript/hooks/natural_language_processing/useLLM.d.ts.map +1 -1
  132. package/lib/typescript/hooks/natural_language_processing/useSpeechToText.d.ts +2 -2
  133. package/lib/typescript/hooks/natural_language_processing/useVAD.d.ts +16 -0
  134. package/lib/typescript/hooks/natural_language_processing/useVAD.d.ts.map +1 -0
  135. package/lib/typescript/index.d.ts +8 -1
  136. package/lib/typescript/index.d.ts.map +1 -1
  137. package/lib/typescript/modules/computer_vision/OCRModule.d.ts +1 -1
  138. package/lib/typescript/modules/computer_vision/OCRModule.d.ts.map +1 -1
  139. package/lib/typescript/modules/computer_vision/TextToImageModule.d.ts +16 -0
  140. package/lib/typescript/modules/computer_vision/TextToImageModule.d.ts.map +1 -0
  141. package/lib/typescript/modules/computer_vision/VerticalOCRModule.d.ts +1 -1
  142. package/lib/typescript/modules/computer_vision/VerticalOCRModule.d.ts.map +1 -1
  143. package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts +3 -2
  144. package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts.map +1 -1
  145. package/lib/typescript/modules/natural_language_processing/VADModule.d.ts +10 -0
  146. package/lib/typescript/modules/natural_language_processing/VADModule.d.ts.map +1 -0
  147. package/lib/typescript/types/llm.d.ts +2 -0
  148. package/lib/typescript/types/llm.d.ts.map +1 -1
  149. package/lib/typescript/types/vad.d.ts +5 -0
  150. package/lib/typescript/types/vad.d.ts.map +1 -0
  151. package/lib/typescript/utils/ResourceFetcher.d.ts +29 -0
  152. package/lib/typescript/utils/ResourceFetcher.d.ts.map +1 -1
  153. package/lib/typescript/utils/ResourceFetcherUtils.d.ts +2 -2
  154. package/lib/typescript/utils/ResourceFetcherUtils.d.ts.map +1 -1
  155. package/package.json +11 -8
  156. package/react-native-executorch.podspec +9 -9
  157. package/src/Error.ts +1 -0
  158. package/src/constants/directories.ts +1 -1
  159. package/src/constants/modelUrls.ts +36 -1
  160. package/src/constants/ocr/models.ts +7 -7
  161. package/src/constants/ocr/symbols.ts +3 -2
  162. package/src/controllers/LLMController.ts +12 -1
  163. package/src/controllers/OCRController.ts +3 -3
  164. package/src/controllers/VerticalOCRController.ts +2 -2
  165. package/src/hooks/computer_vision/useOCR.ts +4 -5
  166. package/src/hooks/computer_vision/useTextToImage.ts +92 -0
  167. package/src/hooks/computer_vision/useVerticalOCR.ts +4 -5
  168. package/src/hooks/natural_language_processing/useLLM.ts +3 -4
  169. package/src/hooks/natural_language_processing/useTokenizer.ts +5 -5
  170. package/src/hooks/natural_language_processing/useVAD.ts +15 -0
  171. package/src/index.ts +20 -1
  172. package/src/modules/computer_vision/OCRModule.ts +2 -2
  173. package/src/modules/computer_vision/TextToImageModule.ts +93 -0
  174. package/src/modules/computer_vision/VerticalOCRModule.ts +2 -2
  175. package/src/modules/natural_language_processing/SpeechToTextModule.ts +8 -4
  176. package/src/modules/natural_language_processing/VADModule.ts +27 -0
  177. package/src/types/llm.ts +2 -0
  178. package/src/types/vad.ts +4 -0
  179. package/src/utils/ResourceFetcher.ts +2 -1
  180. package/src/utils/ResourceFetcherUtils.ts +8 -8
  181. package/third-party/android/libs/cpuinfo/arm64-v8a/libcpuinfo.so +0 -0
  182. package/third-party/android/libs/executorch/arm64-v8a/libexecutorch.so +0 -0
  183. package/third-party/android/libs/executorch/x86_64/libexecutorch.so +0 -0
  184. package/third-party/android/libs/pthreadpool/arm64-v8a/libpthreadpool.so +0 -0
  185. package/third-party/include/c10/macros/Export.h +0 -78
  186. package/third-party/include/c10/macros/Macros.h +1 -520
  187. package/third-party/include/c10/util/BFloat16-inl.h +1 -339
  188. package/third-party/include/c10/util/BFloat16.h +1 -122
  189. package/third-party/include/c10/util/Half-inl.h +1 -347
  190. package/third-party/include/c10/util/Half.h +6 -419
  191. package/third-party/include/c10/util/TypeSafeSignMath.h +1 -133
  192. package/third-party/include/c10/util/bit_cast.h +1 -43
  193. package/third-party/include/c10/util/complex.h +1 -568
  194. package/third-party/include/c10/util/floating_point_utils.h +1 -33
  195. package/third-party/include/c10/util/irange.h +1 -1
  196. package/third-party/include/c10/util/llvmMathExtras.h +866 -0
  197. package/third-party/include/c10/util/safe_numerics.h +97 -0
  198. package/third-party/include/executorch/ExecuTorchError.h +6 -7
  199. package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLM.h +12 -0
  200. package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMConfig.h +56 -0
  201. package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMError.h +16 -0
  202. package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMMultimodalRunner.h +227 -0
  203. package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMTextRunner.h +97 -0
  204. package/third-party/include/executorch/ExecuTorchLLM/module.modulemap +4 -0
  205. package/third-party/include/executorch/ExecuTorchLog.h +1 -0
  206. package/third-party/include/executorch/ExecuTorchModule.h +177 -4
  207. package/third-party/include/executorch/ExecuTorchTensor.h +3 -4
  208. package/third-party/include/executorch/ExecuTorchValue.h +1 -7
  209. package/third-party/include/executorch/extension/module/module.h +139 -8
  210. package/third-party/include/executorch/extension/tensor/tensor.h +1 -0
  211. package/third-party/include/executorch/extension/tensor/tensor_ptr.h +88 -26
  212. package/third-party/include/executorch/extension/threadpool/threadpool.h +4 -1
  213. package/third-party/include/executorch/runtime/backend/backend_init_context.h +6 -0
  214. package/third-party/include/executorch/runtime/backend/interface.h +1 -1
  215. package/third-party/include/executorch/runtime/core/error.h +76 -49
  216. package/third-party/include/executorch/runtime/core/exec_aten/util/scalar_type_util.h +18 -4
  217. package/third-party/include/executorch/runtime/core/memory_allocator.h +12 -2
  218. package/third-party/include/executorch/runtime/core/named_data_map.h +1 -11
  219. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/macros/Export.h +0 -78
  220. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/macros/Macros.h +1 -520
  221. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/BFloat16-inl.h +1 -339
  222. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/BFloat16.h +1 -122
  223. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/Half-inl.h +1 -347
  224. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/Half.h +6 -419
  225. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/TypeSafeSignMath.h +1 -133
  226. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/bit_cast.h +1 -43
  227. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/complex.h +1 -568
  228. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/floating_point_utils.h +1 -33
  229. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/irange.h +1 -1
  230. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/llvmMathExtras.h +866 -0
  231. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/safe_numerics.h +97 -0
  232. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/macros/Export.h +66 -0
  233. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/macros/Macros.h +553 -0
  234. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/BFloat16.h +477 -0
  235. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/Half.h +781 -0
  236. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/TypeSafeSignMath.h +141 -0
  237. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/bit_cast.h +49 -0
  238. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/complex.h +593 -0
  239. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/floating_point_utils.h +38 -0
  240. package/third-party/include/executorch/runtime/core/tensor_layout.h +1 -1
  241. package/third-party/include/executorch/runtime/executor/merged_data_map.h +142 -0
  242. package/third-party/include/executorch/runtime/executor/method.h +21 -8
  243. package/third-party/include/executorch/runtime/executor/method_meta.h +20 -2
  244. package/third-party/include/executorch/runtime/executor/program.h +0 -10
  245. package/third-party/include/executorch/runtime/kernel/operator_registry.h +1 -1
  246. package/third-party/include/executorch/runtime/platform/compiler.h +2 -0
  247. package/third-party/include/executorch/schema/extended_header.h +10 -1
  248. package/third-party/include/torch/headeronly/macros/Export.h +66 -0
  249. package/third-party/include/torch/headeronly/macros/Macros.h +553 -0
  250. package/third-party/include/torch/headeronly/util/BFloat16.h +477 -0
  251. package/third-party/include/torch/headeronly/util/Half.h +781 -0
  252. package/third-party/include/torch/headeronly/util/TypeSafeSignMath.h +141 -0
  253. package/third-party/include/torch/headeronly/util/bit_cast.h +49 -0
  254. package/third-party/include/torch/headeronly/util/complex.h +593 -0
  255. package/third-party/include/torch/headeronly/util/floating_point_utils.h +38 -0
  256. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/ExecutorchLib +0 -0
  257. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/Info.plist +0 -0
  258. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/ExecutorchLib +0 -0
  259. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/Info.plist +0 -0
  260. package/common/rnexecutorch/tests/run_all_tests.sh +0 -14
  261. package/common/rnexecutorch/tests/run_test.sh +0 -18
  262. package/ios/RnExecutorch/utils/Conversions.h +0 -14
  263. package/ios/RnExecutorch/utils/ETError.h +0 -26
  264. package/ios/RnExecutorch/utils/ImageProcessor.h +0 -15
  265. package/ios/RnExecutorch/utils/ImageProcessor.mm +0 -147
  266. package/ios/RnExecutorch/utils/Numerical.h +0 -3
  267. package/ios/RnExecutorch/utils/Numerical.mm +0 -18
  268. package/ios/RnExecutorch/utils/ScalarType.h +0 -14
  269. package/ios/RnExecutorch/utils/ScalarType.mm +0 -21
  270. package/lib/module/hooks/useNonStaticModule.js.map +0 -1
  271. package/lib/typescript/hooks/useNonStaticModule.d.ts +0 -21
  272. package/lib/typescript/hooks/useNonStaticModule.d.ts.map +0 -1
  273. package/src/hooks/useNonStaticModule.ts +0 -74
  274. package/third-party/include/executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h +0 -181
  275. package/third-party/include/executorch/extension/kernel_util/meta_programming.h +0 -108
  276. package/third-party/include/executorch/extension/kernel_util/type_list.h +0 -137
  277. package/third-party/include/executorch/extension/threadpool/threadpool_guard.h +0 -35
@@ -1,339 +1 @@
1
- #pragma once
2
-
3
- #include <c10/macros/Macros.h>
4
- #include <c10/util/bit_cast.h>
5
-
6
- #include <limits>
7
-
8
- C10_CLANG_DIAGNOSTIC_PUSH()
9
- #if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
10
- C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
11
- #endif
12
-
13
- #if defined(CL_SYCL_LANGUAGE_VERSION)
14
- #include <CL/sycl.hpp> // for SYCL 1.2.1
15
- #elif defined(SYCL_LANGUAGE_VERSION)
16
- #include <sycl/sycl.hpp> // for SYCL 2020
17
- #endif
18
-
19
- namespace c10 {
20
-
21
- /// Constructors
22
- inline C10_HOST_DEVICE BFloat16::BFloat16(float value)
23
- :
24
- #if defined(__CUDACC__) && !defined(USE_ROCM) && defined(__CUDA_ARCH__) && \
25
- __CUDA_ARCH__ >= 800
26
- x(__bfloat16_as_ushort(__float2bfloat16(value)))
27
- #elif defined(__SYCL_DEVICE_ONLY__) && \
28
- defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
29
- x(c10::bit_cast<uint16_t>(sycl::ext::oneapi::bfloat16(value)))
30
- #else
31
- // RNE by default
32
- x(detail::round_to_nearest_even(value))
33
- #endif
34
- {
35
- }
36
-
37
- /// Implicit conversions
38
- inline C10_HOST_DEVICE BFloat16::operator float() const {
39
- #if defined(__CUDACC__) && !defined(USE_ROCM)
40
- return __bfloat162float(*reinterpret_cast<const __nv_bfloat16 *>(&x));
41
- #elif defined(__SYCL_DEVICE_ONLY__) && \
42
- defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
43
- return float(*reinterpret_cast<const sycl::ext::oneapi::bfloat16 *>(&x));
44
- #else
45
- return detail::f32_from_bits(x);
46
- #endif
47
- }
48
-
49
- #if defined(__CUDACC__) && !defined(USE_ROCM)
50
- inline C10_HOST_DEVICE BFloat16::BFloat16(const __nv_bfloat16 &value) {
51
- x = *reinterpret_cast<const unsigned short *>(&value);
52
- }
53
- inline C10_HOST_DEVICE BFloat16::operator __nv_bfloat16() const {
54
- return *reinterpret_cast<const __nv_bfloat16 *>(&x);
55
- }
56
- #endif
57
-
58
- #if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
59
- inline C10_HOST_DEVICE
60
- BFloat16::BFloat16(const sycl::ext::oneapi::bfloat16 &value) {
61
- x = *reinterpret_cast<const unsigned short *>(&value);
62
- }
63
- inline C10_HOST_DEVICE BFloat16::operator sycl::ext::oneapi::bfloat16() const {
64
- return *reinterpret_cast<const sycl::ext::oneapi::bfloat16 *>(&x);
65
- }
66
- #endif
67
-
68
- // CUDA intrinsics
69
-
70
- #if defined(__CUDACC__) || defined(__HIPCC__)
71
- inline C10_DEVICE BFloat16 __ldg(const BFloat16 *ptr) {
72
- #if !defined(USE_ROCM) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
73
- return __ldg(reinterpret_cast<const __nv_bfloat16 *>(ptr));
74
- #else
75
- return *ptr;
76
- #endif
77
- }
78
- #endif
79
-
80
- /// Arithmetic
81
-
82
- inline C10_HOST_DEVICE BFloat16 operator+(const BFloat16 &a,
83
- const BFloat16 &b) {
84
- return static_cast<float>(a) + static_cast<float>(b);
85
- }
86
-
87
- inline C10_HOST_DEVICE BFloat16 operator-(const BFloat16 &a,
88
- const BFloat16 &b) {
89
- return static_cast<float>(a) - static_cast<float>(b);
90
- }
91
-
92
- inline C10_HOST_DEVICE BFloat16 operator*(const BFloat16 &a,
93
- const BFloat16 &b) {
94
- return static_cast<float>(a) * static_cast<float>(b);
95
- }
96
-
97
- inline C10_HOST_DEVICE BFloat16 operator/(const BFloat16 &a, const BFloat16 &b)
98
- __ubsan_ignore_float_divide_by_zero__ {
99
- return static_cast<float>(a) / static_cast<float>(b);
100
- }
101
-
102
- inline C10_HOST_DEVICE BFloat16 operator-(const BFloat16 &a) {
103
- return -static_cast<float>(a);
104
- }
105
-
106
- inline C10_HOST_DEVICE BFloat16 &operator+=(BFloat16 &a, const BFloat16 &b) {
107
- a = a + b;
108
- return a;
109
- }
110
-
111
- inline C10_HOST_DEVICE BFloat16 &operator-=(BFloat16 &a, const BFloat16 &b) {
112
- a = a - b;
113
- return a;
114
- }
115
-
116
- inline C10_HOST_DEVICE BFloat16 &operator*=(BFloat16 &a, const BFloat16 &b) {
117
- a = a * b;
118
- return a;
119
- }
120
-
121
- inline C10_HOST_DEVICE BFloat16 &operator/=(BFloat16 &a, const BFloat16 &b) {
122
- a = a / b;
123
- return a;
124
- }
125
-
126
- inline C10_HOST_DEVICE BFloat16 &operator|(BFloat16 &a, const BFloat16 &b) {
127
- a.x = a.x | b.x;
128
- return a;
129
- }
130
-
131
- inline C10_HOST_DEVICE BFloat16 &operator^(BFloat16 &a, const BFloat16 &b) {
132
- a.x = a.x ^ b.x;
133
- return a;
134
- }
135
-
136
- inline C10_HOST_DEVICE BFloat16 &operator&(BFloat16 &a, const BFloat16 &b) {
137
- a.x = a.x & b.x;
138
- return a;
139
- }
140
-
141
- /// Arithmetic with floats
142
-
143
- inline C10_HOST_DEVICE float operator+(BFloat16 a, float b) {
144
- return static_cast<float>(a) + b;
145
- }
146
- inline C10_HOST_DEVICE float operator-(BFloat16 a, float b) {
147
- return static_cast<float>(a) - b;
148
- }
149
- inline C10_HOST_DEVICE float operator*(BFloat16 a, float b) {
150
- return static_cast<float>(a) * b;
151
- }
152
- inline C10_HOST_DEVICE float operator/(BFloat16 a, float b) {
153
- return static_cast<float>(a) / b;
154
- }
155
-
156
- inline C10_HOST_DEVICE float operator+(float a, BFloat16 b) {
157
- return a + static_cast<float>(b);
158
- }
159
- inline C10_HOST_DEVICE float operator-(float a, BFloat16 b) {
160
- return a - static_cast<float>(b);
161
- }
162
- inline C10_HOST_DEVICE float operator*(float a, BFloat16 b) {
163
- return a * static_cast<float>(b);
164
- }
165
- inline C10_HOST_DEVICE float operator/(float a, BFloat16 b) {
166
- return a / static_cast<float>(b);
167
- }
168
-
169
- inline C10_HOST_DEVICE float &operator+=(float &a, const BFloat16 &b) {
170
- return a += static_cast<float>(b);
171
- }
172
- inline C10_HOST_DEVICE float &operator-=(float &a, const BFloat16 &b) {
173
- return a -= static_cast<float>(b);
174
- }
175
- inline C10_HOST_DEVICE float &operator*=(float &a, const BFloat16 &b) {
176
- return a *= static_cast<float>(b);
177
- }
178
- inline C10_HOST_DEVICE float &operator/=(float &a, const BFloat16 &b) {
179
- return a /= static_cast<float>(b);
180
- }
181
-
182
- /// Arithmetic with doubles
183
-
184
- inline C10_HOST_DEVICE double operator+(BFloat16 a, double b) {
185
- return static_cast<double>(a) + b;
186
- }
187
- inline C10_HOST_DEVICE double operator-(BFloat16 a, double b) {
188
- return static_cast<double>(a) - b;
189
- }
190
- inline C10_HOST_DEVICE double operator*(BFloat16 a, double b) {
191
- return static_cast<double>(a) * b;
192
- }
193
- inline C10_HOST_DEVICE double operator/(BFloat16 a, double b) {
194
- return static_cast<double>(a) / b;
195
- }
196
-
197
- inline C10_HOST_DEVICE double operator+(double a, BFloat16 b) {
198
- return a + static_cast<double>(b);
199
- }
200
- inline C10_HOST_DEVICE double operator-(double a, BFloat16 b) {
201
- return a - static_cast<double>(b);
202
- }
203
- inline C10_HOST_DEVICE double operator*(double a, BFloat16 b) {
204
- return a * static_cast<double>(b);
205
- }
206
- inline C10_HOST_DEVICE double operator/(double a, BFloat16 b) {
207
- return a / static_cast<double>(b);
208
- }
209
-
210
- /// Arithmetic with ints
211
-
212
- inline C10_HOST_DEVICE BFloat16 operator+(BFloat16 a, int b) {
213
- return a + static_cast<BFloat16>(b);
214
- }
215
- inline C10_HOST_DEVICE BFloat16 operator-(BFloat16 a, int b) {
216
- return a - static_cast<BFloat16>(b);
217
- }
218
- inline C10_HOST_DEVICE BFloat16 operator*(BFloat16 a, int b) {
219
- return a * static_cast<BFloat16>(b);
220
- }
221
- inline C10_HOST_DEVICE BFloat16 operator/(BFloat16 a, int b) {
222
- return a / static_cast<BFloat16>(b);
223
- }
224
-
225
- inline C10_HOST_DEVICE BFloat16 operator+(int a, BFloat16 b) {
226
- return static_cast<BFloat16>(a) + b;
227
- }
228
- inline C10_HOST_DEVICE BFloat16 operator-(int a, BFloat16 b) {
229
- return static_cast<BFloat16>(a) - b;
230
- }
231
- inline C10_HOST_DEVICE BFloat16 operator*(int a, BFloat16 b) {
232
- return static_cast<BFloat16>(a) * b;
233
- }
234
- inline C10_HOST_DEVICE BFloat16 operator/(int a, BFloat16 b) {
235
- return static_cast<BFloat16>(a) / b;
236
- }
237
-
238
- //// Arithmetic with int64_t
239
-
240
- inline C10_HOST_DEVICE BFloat16 operator+(BFloat16 a, int64_t b) {
241
- return a + static_cast<BFloat16>(b);
242
- }
243
- inline C10_HOST_DEVICE BFloat16 operator-(BFloat16 a, int64_t b) {
244
- return a - static_cast<BFloat16>(b);
245
- }
246
- inline C10_HOST_DEVICE BFloat16 operator*(BFloat16 a, int64_t b) {
247
- return a * static_cast<BFloat16>(b);
248
- }
249
- inline C10_HOST_DEVICE BFloat16 operator/(BFloat16 a, int64_t b) {
250
- return a / static_cast<BFloat16>(b);
251
- }
252
-
253
- inline C10_HOST_DEVICE BFloat16 operator+(int64_t a, BFloat16 b) {
254
- return static_cast<BFloat16>(a) + b;
255
- }
256
- inline C10_HOST_DEVICE BFloat16 operator-(int64_t a, BFloat16 b) {
257
- return static_cast<BFloat16>(a) - b;
258
- }
259
- inline C10_HOST_DEVICE BFloat16 operator*(int64_t a, BFloat16 b) {
260
- return static_cast<BFloat16>(a) * b;
261
- }
262
- inline C10_HOST_DEVICE BFloat16 operator/(int64_t a, BFloat16 b) {
263
- return static_cast<BFloat16>(a) / b;
264
- }
265
-
266
- // Overloading < and > operators, because std::max and std::min use them.
267
-
268
- inline C10_HOST_DEVICE bool operator>(BFloat16 &lhs, BFloat16 &rhs) {
269
- return float(lhs) > float(rhs);
270
- }
271
-
272
- inline C10_HOST_DEVICE bool operator<(BFloat16 &lhs, BFloat16 &rhs) {
273
- return float(lhs) < float(rhs);
274
- }
275
-
276
- } // namespace c10
277
-
278
- namespace std {
279
-
280
- template <> class numeric_limits<c10::BFloat16> {
281
- public:
282
- static constexpr bool is_signed = true;
283
- static constexpr bool is_specialized = true;
284
- static constexpr bool is_integer = false;
285
- static constexpr bool is_exact = false;
286
- static constexpr bool has_infinity = true;
287
- static constexpr bool has_quiet_NaN = true;
288
- static constexpr bool has_signaling_NaN = true;
289
- static constexpr auto has_denorm = numeric_limits<float>::has_denorm;
290
- static constexpr auto has_denorm_loss =
291
- numeric_limits<float>::has_denorm_loss;
292
- static constexpr auto round_style = numeric_limits<float>::round_style;
293
- static constexpr bool is_iec559 = false;
294
- static constexpr bool is_bounded = true;
295
- static constexpr bool is_modulo = false;
296
- static constexpr int digits = 8;
297
- static constexpr int digits10 = 2;
298
- static constexpr int max_digits10 = 4;
299
- static constexpr int radix = 2;
300
- static constexpr int min_exponent = -125;
301
- static constexpr int min_exponent10 = -37;
302
- static constexpr int max_exponent = 128;
303
- static constexpr int max_exponent10 = 38;
304
- static constexpr auto traps = numeric_limits<float>::traps;
305
- static constexpr auto tinyness_before =
306
- numeric_limits<float>::tinyness_before;
307
-
308
- static constexpr c10::BFloat16 min() {
309
- return c10::BFloat16(0x0080, c10::BFloat16::from_bits());
310
- }
311
- static constexpr c10::BFloat16 lowest() {
312
- return c10::BFloat16(0xFF7F, c10::BFloat16::from_bits());
313
- }
314
- static constexpr c10::BFloat16 max() {
315
- return c10::BFloat16(0x7F7F, c10::BFloat16::from_bits());
316
- }
317
- static constexpr c10::BFloat16 epsilon() {
318
- return c10::BFloat16(0x3C00, c10::BFloat16::from_bits());
319
- }
320
- static constexpr c10::BFloat16 round_error() {
321
- return c10::BFloat16(0x3F00, c10::BFloat16::from_bits());
322
- }
323
- static constexpr c10::BFloat16 infinity() {
324
- return c10::BFloat16(0x7F80, c10::BFloat16::from_bits());
325
- }
326
- static constexpr c10::BFloat16 quiet_NaN() {
327
- return c10::BFloat16(0x7FC0, c10::BFloat16::from_bits());
328
- }
329
- static constexpr c10::BFloat16 signaling_NaN() {
330
- return c10::BFloat16(0x7F80, c10::BFloat16::from_bits());
331
- }
332
- static constexpr c10::BFloat16 denorm_min() {
333
- return c10::BFloat16(0x0001, c10::BFloat16::from_bits());
334
- }
335
- };
336
-
337
- } // namespace std
338
-
339
- C10_CLANG_DIAGNOSTIC_POP()
1
+ #include <torch/headeronly/util/BFloat16.h>
@@ -1,122 +1 @@
1
- #pragma once
2
-
3
- // Defines the bloat16 type (brain floating-point). This representation uses
4
- // 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa.
5
-
6
- #include <c10/macros/Macros.h>
7
- #include <cmath>
8
- #include <cstdint>
9
- #include <cstring>
10
- #include <iosfwd>
11
- #include <ostream>
12
-
13
- #if defined(__CUDACC__) && !defined(USE_ROCM)
14
- #include <cuda_bf16.h>
15
- #endif
16
-
17
- #if defined(CL_SYCL_LANGUAGE_VERSION)
18
- #include <CL/sycl.hpp> // for SYCL 1.2.1
19
- #elif defined(SYCL_LANGUAGE_VERSION)
20
- #include <sycl/sycl.hpp> // for SYCL 2020
21
- #endif
22
-
23
- namespace c10 {
24
-
25
- namespace detail {
26
- inline C10_HOST_DEVICE float f32_from_bits(uint16_t src) {
27
- float res = 0;
28
- uint32_t tmp = src;
29
- tmp <<= 16;
30
-
31
- #if defined(USE_ROCM) && defined(__HIPCC__)
32
- float *tempRes;
33
-
34
- // We should be using memcpy in order to respect the strict aliasing rule
35
- // but it fails in the HIP environment.
36
- tempRes = reinterpret_cast<float *>(&tmp);
37
- res = *tempRes;
38
- #else
39
- std::memcpy(&res, &tmp, sizeof(tmp));
40
- #endif
41
-
42
- return res;
43
- }
44
-
45
- inline C10_HOST_DEVICE uint16_t bits_from_f32(float src) {
46
- uint32_t res = 0;
47
-
48
- #if defined(USE_ROCM) && defined(__HIPCC__)
49
- // We should be using memcpy in order to respect the strict aliasing rule
50
- // but it fails in the HIP environment.
51
- uint32_t *tempRes = reinterpret_cast<uint32_t *>(&src);
52
- res = *tempRes;
53
- #else
54
- std::memcpy(&res, &src, sizeof(res));
55
- #endif
56
-
57
- return res >> 16;
58
- }
59
-
60
- inline C10_HOST_DEVICE uint16_t round_to_nearest_even(float src) {
61
- #if defined(USE_ROCM) && defined(__HIPCC__)
62
- if (src != src) {
63
- #elif defined(_MSC_VER)
64
- if (isnan(src)) {
65
- #else
66
- if (std::isnan(src)) {
67
- #endif
68
- return UINT16_C(0x7FC0);
69
- } else {
70
- // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
71
- union {
72
- uint32_t U32; // NOLINT(facebook-hte-BadMemberName)
73
- float F32; // NOLINT(facebook-hte-BadMemberName)
74
- };
75
-
76
- F32 = src;
77
- uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
78
- return static_cast<uint16_t>((U32 + rounding_bias) >> 16);
79
- }
80
- }
81
- } // namespace detail
82
-
83
- struct alignas(2) BFloat16 {
84
- uint16_t x;
85
-
86
- // HIP wants __host__ __device__ tag, CUDA does not
87
- #if defined(USE_ROCM) && defined(__HIPCC__)
88
- C10_HOST_DEVICE BFloat16() = default;
89
- #else
90
- BFloat16() = default;
91
- #endif
92
-
93
- struct from_bits_t {};
94
- static constexpr C10_HOST_DEVICE from_bits_t from_bits() {
95
- return from_bits_t();
96
- }
97
-
98
- constexpr C10_HOST_DEVICE BFloat16(unsigned short bits, from_bits_t)
99
- : x(bits) {}
100
- /* implicit */ inline C10_HOST_DEVICE BFloat16(float value);
101
- inline C10_HOST_DEVICE operator float() const;
102
-
103
- #if defined(__CUDACC__) && !defined(USE_ROCM)
104
- inline C10_HOST_DEVICE BFloat16(const __nv_bfloat16 &value);
105
- explicit inline C10_HOST_DEVICE operator __nv_bfloat16() const;
106
- #endif
107
-
108
- #if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
109
- inline C10_HOST_DEVICE BFloat16(const sycl::ext::oneapi::bfloat16 &value);
110
- explicit inline C10_HOST_DEVICE operator sycl::ext::oneapi::bfloat16() const;
111
- #endif
112
- };
113
-
114
- C10_API inline std::ostream &operator<<(std::ostream &out,
115
- const BFloat16 &value) {
116
- out << (float)value;
117
- return out;
118
- }
119
-
120
- } // namespace c10
121
-
122
- #include <c10/util/BFloat16-inl.h> // IWYU pragma: keep
1
+ #include <torch/headeronly/util/BFloat16.h>