react-native-executorch 0.5.15 → 0.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (277) hide show
  1. package/README.md +42 -36
  2. package/android/CMakeLists.txt +13 -25
  3. package/android/build.gradle +2 -3
  4. package/android/libs/classes.jar +0 -0
  5. package/android/src/main/cpp/CMakeLists.txt +2 -1
  6. package/common/rnexecutorch/RnExecutorchInstaller.cpp +18 -0
  7. package/common/rnexecutorch/TokenizerModule.cpp +3 -3
  8. package/common/rnexecutorch/data_processing/Numerical.cpp +31 -23
  9. package/common/rnexecutorch/data_processing/Numerical.h +6 -1
  10. package/common/rnexecutorch/data_processing/dsp.cpp +0 -46
  11. package/common/rnexecutorch/host_objects/JsiConversions.h +16 -0
  12. package/common/rnexecutorch/host_objects/ModelHostObject.h +26 -11
  13. package/common/rnexecutorch/jsi/OwningArrayBuffer.h +19 -2
  14. package/common/rnexecutorch/metaprogramming/TypeConcepts.h +0 -20
  15. package/common/rnexecutorch/models/BaseModel.cpp +12 -11
  16. package/common/rnexecutorch/models/BaseModel.h +18 -10
  17. package/common/rnexecutorch/models/embeddings/BaseEmbeddings.cpp +3 -11
  18. package/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp +0 -1
  19. package/common/rnexecutorch/models/image_segmentation/ImageSegmentation.cpp +6 -12
  20. package/common/rnexecutorch/models/llm/LLM.cpp +25 -8
  21. package/common/rnexecutorch/models/llm/LLM.h +4 -4
  22. package/common/rnexecutorch/models/ocr/CTCLabelConverter.h +1 -1
  23. package/common/rnexecutorch/models/ocr/utils/RecognitionHandlerUtils.cpp +7 -4
  24. package/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp +8 -13
  25. package/common/rnexecutorch/models/speech_to_text/SpeechToText.h +1 -3
  26. package/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp +12 -19
  27. package/common/rnexecutorch/models/speech_to_text/asr/ASR.h +4 -5
  28. package/common/rnexecutorch/models/text_to_image/Constants.h +9 -0
  29. package/common/rnexecutorch/models/text_to_image/Decoder.cpp +32 -0
  30. package/common/rnexecutorch/models/text_to_image/Decoder.h +24 -0
  31. package/common/rnexecutorch/models/text_to_image/Encoder.cpp +44 -0
  32. package/common/rnexecutorch/models/text_to_image/Encoder.h +32 -0
  33. package/common/rnexecutorch/models/text_to_image/Scheduler.cpp +152 -0
  34. package/common/rnexecutorch/models/text_to_image/Scheduler.h +41 -0
  35. package/common/rnexecutorch/models/text_to_image/TextToImage.cpp +141 -0
  36. package/common/rnexecutorch/models/text_to_image/TextToImage.h +64 -0
  37. package/common/rnexecutorch/models/text_to_image/UNet.cpp +38 -0
  38. package/common/rnexecutorch/models/text_to_image/UNet.h +28 -0
  39. package/common/rnexecutorch/models/voice_activity_detection/Constants.h +27 -0
  40. package/common/rnexecutorch/models/voice_activity_detection/Types.h +12 -0
  41. package/common/rnexecutorch/models/voice_activity_detection/Utils.cpp +15 -0
  42. package/common/rnexecutorch/models/voice_activity_detection/Utils.h +13 -0
  43. package/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.cpp +160 -0
  44. package/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.h +36 -0
  45. package/common/rnexecutorch/tests/CMakeLists.txt +30 -0
  46. package/common/rnexecutorch/tests/NumericalTest.cpp +110 -0
  47. package/common/rnexecutorch/tests/README.md +30 -13
  48. package/common/rnexecutorch/threads/GlobalThreadPool.h +4 -0
  49. package/common/runner/arange_util.cpp +44 -0
  50. package/common/runner/arange_util.h +37 -0
  51. package/common/runner/constants.h +28 -0
  52. package/common/runner/io_manager.h +240 -0
  53. package/common/runner/irunner.h +87 -16
  54. package/common/runner/kernel_includes.h +23 -0
  55. package/common/runner/runner.cpp +151 -66
  56. package/common/runner/runner.h +39 -22
  57. package/common/runner/sampler.cpp +8 -1
  58. package/common/runner/sampler.h +4 -2
  59. package/common/runner/stats.h +1 -4
  60. package/common/runner/text_decoder_runner.cpp +26 -12
  61. package/common/runner/text_decoder_runner.h +52 -31
  62. package/common/runner/text_prefiller.cpp +46 -12
  63. package/common/runner/text_prefiller.h +38 -4
  64. package/common/runner/text_token_generator.h +51 -26
  65. package/common/runner/util.h +53 -8
  66. package/ios/RnExecutorch.xcodeproj/project.pbxproj +0 -23
  67. package/lib/module/Error.js +1 -0
  68. package/lib/module/Error.js.map +1 -1
  69. package/lib/module/constants/directories.js +1 -1
  70. package/lib/module/constants/directories.js.map +1 -1
  71. package/lib/module/constants/modelUrls.js +32 -1
  72. package/lib/module/constants/modelUrls.js.map +1 -1
  73. package/lib/module/constants/ocr/models.js +7 -7
  74. package/lib/module/constants/ocr/models.js.map +1 -1
  75. package/lib/module/constants/ocr/symbols.js +3 -2
  76. package/lib/module/constants/ocr/symbols.js.map +1 -1
  77. package/lib/module/controllers/LLMController.js +10 -1
  78. package/lib/module/controllers/LLMController.js.map +1 -1
  79. package/lib/module/controllers/OCRController.js +3 -3
  80. package/lib/module/controllers/OCRController.js.map +1 -1
  81. package/lib/module/controllers/VerticalOCRController.js +2 -2
  82. package/lib/module/controllers/VerticalOCRController.js.map +1 -1
  83. package/lib/module/hooks/computer_vision/useOCR.js +3 -3
  84. package/lib/module/hooks/computer_vision/useOCR.js.map +1 -1
  85. package/lib/module/hooks/{useNonStaticModule.js → computer_vision/useTextToImage.js} +21 -16
  86. package/lib/module/hooks/computer_vision/useTextToImage.js.map +1 -0
  87. package/lib/module/hooks/computer_vision/useVerticalOCR.js +3 -3
  88. package/lib/module/hooks/computer_vision/useVerticalOCR.js.map +1 -1
  89. package/lib/module/hooks/natural_language_processing/useLLM.js +3 -3
  90. package/lib/module/hooks/natural_language_processing/useLLM.js.map +1 -1
  91. package/lib/module/hooks/natural_language_processing/useTokenizer.js +5 -5
  92. package/lib/module/hooks/natural_language_processing/useTokenizer.js.map +1 -1
  93. package/lib/module/hooks/natural_language_processing/useVAD.js +13 -0
  94. package/lib/module/hooks/natural_language_processing/useVAD.js.map +1 -0
  95. package/lib/module/index.js +7 -2
  96. package/lib/module/index.js.map +1 -1
  97. package/lib/module/modules/computer_vision/OCRModule.js +2 -2
  98. package/lib/module/modules/computer_vision/OCRModule.js.map +1 -1
  99. package/lib/module/modules/computer_vision/TextToImageModule.js +48 -0
  100. package/lib/module/modules/computer_vision/TextToImageModule.js.map +1 -0
  101. package/lib/module/modules/computer_vision/VerticalOCRModule.js +2 -2
  102. package/lib/module/modules/computer_vision/VerticalOCRModule.js.map +1 -1
  103. package/lib/module/modules/natural_language_processing/SpeechToTextModule.js +7 -4
  104. package/lib/module/modules/natural_language_processing/SpeechToTextModule.js.map +1 -1
  105. package/lib/module/modules/natural_language_processing/VADModule.js +19 -0
  106. package/lib/module/modules/natural_language_processing/VADModule.js.map +1 -0
  107. package/lib/module/types/llm.js.map +1 -1
  108. package/lib/module/types/vad.js +2 -0
  109. package/lib/module/types/vad.js.map +1 -0
  110. package/lib/module/utils/ResourceFetcher.js +2 -1
  111. package/lib/module/utils/ResourceFetcher.js.map +1 -1
  112. package/lib/module/utils/ResourceFetcherUtils.js +6 -6
  113. package/lib/module/utils/ResourceFetcherUtils.js.map +1 -1
  114. package/lib/typescript/Error.d.ts +1 -0
  115. package/lib/typescript/Error.d.ts.map +1 -1
  116. package/lib/typescript/constants/modelUrls.d.ts +23 -0
  117. package/lib/typescript/constants/modelUrls.d.ts.map +1 -1
  118. package/lib/typescript/constants/ocr/symbols.d.ts +1 -1
  119. package/lib/typescript/constants/ocr/symbols.d.ts.map +1 -1
  120. package/lib/typescript/controllers/LLMController.d.ts.map +1 -1
  121. package/lib/typescript/controllers/OCRController.d.ts +1 -1
  122. package/lib/typescript/controllers/OCRController.d.ts.map +1 -1
  123. package/lib/typescript/controllers/VerticalOCRController.d.ts +1 -1
  124. package/lib/typescript/controllers/VerticalOCRController.d.ts.map +1 -1
  125. package/lib/typescript/hooks/computer_vision/useOCR.d.ts +1 -1
  126. package/lib/typescript/hooks/computer_vision/useOCR.d.ts.map +1 -1
  127. package/lib/typescript/hooks/computer_vision/useTextToImage.d.ts +22 -0
  128. package/lib/typescript/hooks/computer_vision/useTextToImage.d.ts.map +1 -0
  129. package/lib/typescript/hooks/computer_vision/useVerticalOCR.d.ts +1 -1
  130. package/lib/typescript/hooks/computer_vision/useVerticalOCR.d.ts.map +1 -1
  131. package/lib/typescript/hooks/natural_language_processing/useLLM.d.ts.map +1 -1
  132. package/lib/typescript/hooks/natural_language_processing/useSpeechToText.d.ts +2 -2
  133. package/lib/typescript/hooks/natural_language_processing/useVAD.d.ts +16 -0
  134. package/lib/typescript/hooks/natural_language_processing/useVAD.d.ts.map +1 -0
  135. package/lib/typescript/index.d.ts +8 -1
  136. package/lib/typescript/index.d.ts.map +1 -1
  137. package/lib/typescript/modules/computer_vision/OCRModule.d.ts +1 -1
  138. package/lib/typescript/modules/computer_vision/OCRModule.d.ts.map +1 -1
  139. package/lib/typescript/modules/computer_vision/TextToImageModule.d.ts +16 -0
  140. package/lib/typescript/modules/computer_vision/TextToImageModule.d.ts.map +1 -0
  141. package/lib/typescript/modules/computer_vision/VerticalOCRModule.d.ts +1 -1
  142. package/lib/typescript/modules/computer_vision/VerticalOCRModule.d.ts.map +1 -1
  143. package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts +3 -2
  144. package/lib/typescript/modules/natural_language_processing/SpeechToTextModule.d.ts.map +1 -1
  145. package/lib/typescript/modules/natural_language_processing/VADModule.d.ts +10 -0
  146. package/lib/typescript/modules/natural_language_processing/VADModule.d.ts.map +1 -0
  147. package/lib/typescript/types/llm.d.ts +2 -0
  148. package/lib/typescript/types/llm.d.ts.map +1 -1
  149. package/lib/typescript/types/vad.d.ts +5 -0
  150. package/lib/typescript/types/vad.d.ts.map +1 -0
  151. package/lib/typescript/utils/ResourceFetcher.d.ts +29 -0
  152. package/lib/typescript/utils/ResourceFetcher.d.ts.map +1 -1
  153. package/lib/typescript/utils/ResourceFetcherUtils.d.ts +2 -2
  154. package/lib/typescript/utils/ResourceFetcherUtils.d.ts.map +1 -1
  155. package/package.json +11 -8
  156. package/react-native-executorch.podspec +9 -9
  157. package/src/Error.ts +1 -0
  158. package/src/constants/directories.ts +1 -1
  159. package/src/constants/modelUrls.ts +36 -1
  160. package/src/constants/ocr/models.ts +7 -7
  161. package/src/constants/ocr/symbols.ts +3 -2
  162. package/src/controllers/LLMController.ts +12 -1
  163. package/src/controllers/OCRController.ts +3 -3
  164. package/src/controllers/VerticalOCRController.ts +2 -2
  165. package/src/hooks/computer_vision/useOCR.ts +4 -5
  166. package/src/hooks/computer_vision/useTextToImage.ts +92 -0
  167. package/src/hooks/computer_vision/useVerticalOCR.ts +4 -5
  168. package/src/hooks/natural_language_processing/useLLM.ts +3 -4
  169. package/src/hooks/natural_language_processing/useTokenizer.ts +5 -5
  170. package/src/hooks/natural_language_processing/useVAD.ts +15 -0
  171. package/src/index.ts +20 -1
  172. package/src/modules/computer_vision/OCRModule.ts +2 -2
  173. package/src/modules/computer_vision/TextToImageModule.ts +93 -0
  174. package/src/modules/computer_vision/VerticalOCRModule.ts +2 -2
  175. package/src/modules/natural_language_processing/SpeechToTextModule.ts +8 -4
  176. package/src/modules/natural_language_processing/VADModule.ts +27 -0
  177. package/src/types/llm.ts +2 -0
  178. package/src/types/vad.ts +4 -0
  179. package/src/utils/ResourceFetcher.ts +2 -1
  180. package/src/utils/ResourceFetcherUtils.ts +8 -8
  181. package/third-party/android/libs/cpuinfo/arm64-v8a/libcpuinfo.so +0 -0
  182. package/third-party/android/libs/executorch/arm64-v8a/libexecutorch.so +0 -0
  183. package/third-party/android/libs/executorch/x86_64/libexecutorch.so +0 -0
  184. package/third-party/android/libs/pthreadpool/arm64-v8a/libpthreadpool.so +0 -0
  185. package/third-party/include/c10/macros/Export.h +0 -78
  186. package/third-party/include/c10/macros/Macros.h +1 -520
  187. package/third-party/include/c10/util/BFloat16-inl.h +1 -339
  188. package/third-party/include/c10/util/BFloat16.h +1 -122
  189. package/third-party/include/c10/util/Half-inl.h +1 -347
  190. package/third-party/include/c10/util/Half.h +6 -419
  191. package/third-party/include/c10/util/TypeSafeSignMath.h +1 -133
  192. package/third-party/include/c10/util/bit_cast.h +1 -43
  193. package/third-party/include/c10/util/complex.h +1 -568
  194. package/third-party/include/c10/util/floating_point_utils.h +1 -33
  195. package/third-party/include/c10/util/irange.h +1 -1
  196. package/third-party/include/c10/util/llvmMathExtras.h +866 -0
  197. package/third-party/include/c10/util/safe_numerics.h +97 -0
  198. package/third-party/include/executorch/ExecuTorchError.h +6 -7
  199. package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLM.h +12 -0
  200. package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMConfig.h +56 -0
  201. package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMError.h +16 -0
  202. package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMMultimodalRunner.h +227 -0
  203. package/third-party/include/executorch/ExecuTorchLLM/ExecuTorchLLMTextRunner.h +97 -0
  204. package/third-party/include/executorch/ExecuTorchLLM/module.modulemap +4 -0
  205. package/third-party/include/executorch/ExecuTorchLog.h +1 -0
  206. package/third-party/include/executorch/ExecuTorchModule.h +177 -4
  207. package/third-party/include/executorch/ExecuTorchTensor.h +3 -4
  208. package/third-party/include/executorch/ExecuTorchValue.h +1 -7
  209. package/third-party/include/executorch/extension/module/module.h +139 -8
  210. package/third-party/include/executorch/extension/tensor/tensor.h +1 -0
  211. package/third-party/include/executorch/extension/tensor/tensor_ptr.h +88 -26
  212. package/third-party/include/executorch/extension/threadpool/threadpool.h +4 -1
  213. package/third-party/include/executorch/runtime/backend/backend_init_context.h +6 -0
  214. package/third-party/include/executorch/runtime/backend/interface.h +1 -1
  215. package/third-party/include/executorch/runtime/core/error.h +76 -49
  216. package/third-party/include/executorch/runtime/core/exec_aten/util/scalar_type_util.h +18 -4
  217. package/third-party/include/executorch/runtime/core/memory_allocator.h +12 -2
  218. package/third-party/include/executorch/runtime/core/named_data_map.h +1 -11
  219. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/macros/Export.h +0 -78
  220. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/macros/Macros.h +1 -520
  221. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/BFloat16-inl.h +1 -339
  222. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/BFloat16.h +1 -122
  223. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/Half-inl.h +1 -347
  224. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/Half.h +6 -419
  225. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/TypeSafeSignMath.h +1 -133
  226. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/bit_cast.h +1 -43
  227. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/complex.h +1 -568
  228. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/floating_point_utils.h +1 -33
  229. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/irange.h +1 -1
  230. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/llvmMathExtras.h +866 -0
  231. package/third-party/include/executorch/runtime/core/portable_type/c10/c10/util/safe_numerics.h +97 -0
  232. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/macros/Export.h +66 -0
  233. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/macros/Macros.h +553 -0
  234. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/BFloat16.h +477 -0
  235. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/Half.h +781 -0
  236. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/TypeSafeSignMath.h +141 -0
  237. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/bit_cast.h +49 -0
  238. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/complex.h +593 -0
  239. package/third-party/include/executorch/runtime/core/portable_type/c10/torch/headeronly/util/floating_point_utils.h +38 -0
  240. package/third-party/include/executorch/runtime/core/tensor_layout.h +1 -1
  241. package/third-party/include/executorch/runtime/executor/merged_data_map.h +142 -0
  242. package/third-party/include/executorch/runtime/executor/method.h +21 -8
  243. package/third-party/include/executorch/runtime/executor/method_meta.h +20 -2
  244. package/third-party/include/executorch/runtime/executor/program.h +0 -10
  245. package/third-party/include/executorch/runtime/kernel/operator_registry.h +1 -1
  246. package/third-party/include/executorch/runtime/platform/compiler.h +2 -0
  247. package/third-party/include/executorch/schema/extended_header.h +10 -1
  248. package/third-party/include/torch/headeronly/macros/Export.h +66 -0
  249. package/third-party/include/torch/headeronly/macros/Macros.h +553 -0
  250. package/third-party/include/torch/headeronly/util/BFloat16.h +477 -0
  251. package/third-party/include/torch/headeronly/util/Half.h +781 -0
  252. package/third-party/include/torch/headeronly/util/TypeSafeSignMath.h +141 -0
  253. package/third-party/include/torch/headeronly/util/bit_cast.h +49 -0
  254. package/third-party/include/torch/headeronly/util/complex.h +593 -0
  255. package/third-party/include/torch/headeronly/util/floating_point_utils.h +38 -0
  256. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/ExecutorchLib +0 -0
  257. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64/ExecutorchLib.framework/Info.plist +0 -0
  258. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/ExecutorchLib +0 -0
  259. package/third-party/ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/Info.plist +0 -0
  260. package/common/rnexecutorch/tests/run_all_tests.sh +0 -14
  261. package/common/rnexecutorch/tests/run_test.sh +0 -18
  262. package/ios/RnExecutorch/utils/Conversions.h +0 -14
  263. package/ios/RnExecutorch/utils/ETError.h +0 -26
  264. package/ios/RnExecutorch/utils/ImageProcessor.h +0 -15
  265. package/ios/RnExecutorch/utils/ImageProcessor.mm +0 -147
  266. package/ios/RnExecutorch/utils/Numerical.h +0 -3
  267. package/ios/RnExecutorch/utils/Numerical.mm +0 -18
  268. package/ios/RnExecutorch/utils/ScalarType.h +0 -14
  269. package/ios/RnExecutorch/utils/ScalarType.mm +0 -21
  270. package/lib/module/hooks/useNonStaticModule.js.map +0 -1
  271. package/lib/typescript/hooks/useNonStaticModule.d.ts +0 -21
  272. package/lib/typescript/hooks/useNonStaticModule.d.ts.map +0 -1
  273. package/src/hooks/useNonStaticModule.ts +0 -74
  274. package/third-party/include/executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h +0 -181
  275. package/third-party/include/executorch/extension/kernel_util/meta_programming.h +0 -108
  276. package/third-party/include/executorch/extension/kernel_util/type_list.h +0 -137
  277. package/third-party/include/executorch/extension/threadpool/threadpool_guard.h +0 -35
@@ -1,347 +1 @@
1
- #pragma once
2
-
3
- #include <c10/macros/Macros.h>
4
- #include <c10/util/bit_cast.h>
5
-
6
- #include <cstring>
7
- #include <limits>
8
-
9
- #ifdef __CUDACC__
10
- #include <cuda_fp16.h>
11
- #endif
12
-
13
- #ifdef __HIPCC__
14
- #include <hip/hip_fp16.h>
15
- #endif
16
-
17
- #if defined(CL_SYCL_LANGUAGE_VERSION)
18
- #include <CL/sycl.hpp> // for SYCL 1.2.1
19
- #elif defined(SYCL_LANGUAGE_VERSION)
20
- #include <sycl/sycl.hpp> // for SYCL 2020
21
- #endif
22
-
23
- #if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \
24
- !defined(__APPLE__)
25
- #include <ATen/cpu/vec/vec_half.h>
26
- #endif
27
-
28
- C10_CLANG_DIAGNOSTIC_PUSH()
29
- #if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
30
- C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
31
- #endif
32
-
33
- namespace c10 {
34
-
35
- #if defined(__aarch64__) && !defined(__CUDACC__)
36
- /// Constructors
37
- inline Half::Half(float16_t value) : x(detail::fp16_to_bits(value)) {}
38
- inline Half::operator float16_t() const { return detail::fp16_from_bits(x); }
39
- #else
40
-
41
- inline C10_HOST_DEVICE Half::Half(float value)
42
- :
43
- #if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
44
- x(__half_as_short(__float2half(value)))
45
- #elif defined(__SYCL_DEVICE_ONLY__)
46
- x(c10::bit_cast<uint16_t>(sycl::half(value)))
47
- #elif (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \
48
- !defined(__APPLE__)
49
- x(at::vec::float2half_scalar(value))
50
- #else
51
- x(detail::fp16_ieee_from_fp32_value(value))
52
- #endif
53
- {
54
- }
55
-
56
- /// Implicit conversions
57
-
58
- inline C10_HOST_DEVICE Half::operator float() const {
59
- #if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
60
- return __half2float(*reinterpret_cast<const __half *>(&x));
61
- #elif defined(__SYCL_DEVICE_ONLY__)
62
- return float(c10::bit_cast<sycl::half>(x));
63
- #elif (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && \
64
- !defined(__APPLE__)
65
- return at::vec::half2float_scalar(x);
66
- #elif defined(__aarch64__) && !defined(__CUDACC__)
67
- return detail::native_fp16_to_fp32_value(x);
68
- #else
69
- return detail::fp16_ieee_to_fp32_value(x);
70
- #endif
71
- }
72
-
73
- #endif /* !defined(__aarch64__) || defined(__CUDACC__) \
74
- */
75
-
76
- #if defined(__CUDACC__) || defined(__HIPCC__)
77
- inline C10_HOST_DEVICE Half::Half(const __half &value) {
78
- x = *reinterpret_cast<const unsigned short *>(&value);
79
- }
80
- inline C10_HOST_DEVICE Half::operator __half() const {
81
- return *reinterpret_cast<const __half *>(&x);
82
- }
83
- #endif
84
-
85
- #ifdef SYCL_LANGUAGE_VERSION
86
- inline C10_HOST_DEVICE Half::Half(const sycl::half &value) {
87
- x = *reinterpret_cast<const unsigned short *>(&value);
88
- }
89
- inline C10_HOST_DEVICE Half::operator sycl::half() const {
90
- return *reinterpret_cast<const sycl::half *>(&x);
91
- }
92
- #endif
93
-
94
- // CUDA intrinsics
95
-
96
- #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 350)) || \
97
- (defined(__clang__) && defined(__CUDA__))
98
- inline __device__ Half __ldg(const Half *ptr) {
99
- return __ldg(reinterpret_cast<const __half *>(ptr));
100
- }
101
- #endif
102
-
103
- /// Arithmetic
104
-
105
- inline C10_HOST_DEVICE Half operator+(const Half &a, const Half &b) {
106
- return static_cast<float>(a) + static_cast<float>(b);
107
- }
108
-
109
- inline C10_HOST_DEVICE Half operator-(const Half &a, const Half &b) {
110
- return static_cast<float>(a) - static_cast<float>(b);
111
- }
112
-
113
- inline C10_HOST_DEVICE Half operator*(const Half &a, const Half &b) {
114
- return static_cast<float>(a) * static_cast<float>(b);
115
- }
116
-
117
- inline C10_HOST_DEVICE Half operator/(const Half &a, const Half &b)
118
- __ubsan_ignore_float_divide_by_zero__ {
119
- return static_cast<float>(a) / static_cast<float>(b);
120
- }
121
-
122
- inline C10_HOST_DEVICE Half operator-(const Half &a) {
123
- #if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || \
124
- defined(__HIP_DEVICE_COMPILE__)
125
- return __hneg(a);
126
- #elif defined(__SYCL_DEVICE_ONLY__)
127
- return -c10::bit_cast<sycl::half>(a);
128
- #else
129
- return -static_cast<float>(a);
130
- #endif
131
- }
132
-
133
- inline C10_HOST_DEVICE Half &operator+=(Half &a, const Half &b) {
134
- a = a + b;
135
- return a;
136
- }
137
-
138
- inline C10_HOST_DEVICE Half &operator-=(Half &a, const Half &b) {
139
- a = a - b;
140
- return a;
141
- }
142
-
143
- inline C10_HOST_DEVICE Half &operator*=(Half &a, const Half &b) {
144
- a = a * b;
145
- return a;
146
- }
147
-
148
- inline C10_HOST_DEVICE Half &operator/=(Half &a, const Half &b) {
149
- a = a / b;
150
- return a;
151
- }
152
-
153
- /// Arithmetic with floats
154
-
155
- inline C10_HOST_DEVICE float operator+(Half a, float b) {
156
- return static_cast<float>(a) + b;
157
- }
158
- inline C10_HOST_DEVICE float operator-(Half a, float b) {
159
- return static_cast<float>(a) - b;
160
- }
161
- inline C10_HOST_DEVICE float operator*(Half a, float b) {
162
- return static_cast<float>(a) * b;
163
- }
164
- inline C10_HOST_DEVICE float
165
- operator/(Half a, float b) __ubsan_ignore_float_divide_by_zero__ {
166
- return static_cast<float>(a) / b;
167
- }
168
-
169
- inline C10_HOST_DEVICE float operator+(float a, Half b) {
170
- return a + static_cast<float>(b);
171
- }
172
- inline C10_HOST_DEVICE float operator-(float a, Half b) {
173
- return a - static_cast<float>(b);
174
- }
175
- inline C10_HOST_DEVICE float operator*(float a, Half b) {
176
- return a * static_cast<float>(b);
177
- }
178
- inline C10_HOST_DEVICE float
179
- operator/(float a, Half b) __ubsan_ignore_float_divide_by_zero__ {
180
- return a / static_cast<float>(b);
181
- }
182
-
183
- inline C10_HOST_DEVICE float &operator+=(float &a, const Half &b) {
184
- return a += static_cast<float>(b);
185
- }
186
- inline C10_HOST_DEVICE float &operator-=(float &a, const Half &b) {
187
- return a -= static_cast<float>(b);
188
- }
189
- inline C10_HOST_DEVICE float &operator*=(float &a, const Half &b) {
190
- return a *= static_cast<float>(b);
191
- }
192
- inline C10_HOST_DEVICE float &operator/=(float &a, const Half &b) {
193
- return a /= static_cast<float>(b);
194
- }
195
-
196
- /// Arithmetic with doubles
197
-
198
- inline C10_HOST_DEVICE double operator+(Half a, double b) {
199
- return static_cast<double>(a) + b;
200
- }
201
- inline C10_HOST_DEVICE double operator-(Half a, double b) {
202
- return static_cast<double>(a) - b;
203
- }
204
- inline C10_HOST_DEVICE double operator*(Half a, double b) {
205
- return static_cast<double>(a) * b;
206
- }
207
- inline C10_HOST_DEVICE double
208
- operator/(Half a, double b) __ubsan_ignore_float_divide_by_zero__ {
209
- return static_cast<double>(a) / b;
210
- }
211
-
212
- inline C10_HOST_DEVICE double operator+(double a, Half b) {
213
- return a + static_cast<double>(b);
214
- }
215
- inline C10_HOST_DEVICE double operator-(double a, Half b) {
216
- return a - static_cast<double>(b);
217
- }
218
- inline C10_HOST_DEVICE double operator*(double a, Half b) {
219
- return a * static_cast<double>(b);
220
- }
221
- inline C10_HOST_DEVICE double
222
- operator/(double a, Half b) __ubsan_ignore_float_divide_by_zero__ {
223
- return a / static_cast<double>(b);
224
- }
225
-
226
- /// Arithmetic with ints
227
-
228
- inline C10_HOST_DEVICE Half operator+(Half a, int b) {
229
- return a + static_cast<Half>(b);
230
- }
231
- inline C10_HOST_DEVICE Half operator-(Half a, int b) {
232
- return a - static_cast<Half>(b);
233
- }
234
- inline C10_HOST_DEVICE Half operator*(Half a, int b) {
235
- return a * static_cast<Half>(b);
236
- }
237
- inline C10_HOST_DEVICE Half operator/(Half a, int b) {
238
- return a / static_cast<Half>(b);
239
- }
240
-
241
- inline C10_HOST_DEVICE Half operator+(int a, Half b) {
242
- return static_cast<Half>(a) + b;
243
- }
244
- inline C10_HOST_DEVICE Half operator-(int a, Half b) {
245
- return static_cast<Half>(a) - b;
246
- }
247
- inline C10_HOST_DEVICE Half operator*(int a, Half b) {
248
- return static_cast<Half>(a) * b;
249
- }
250
- inline C10_HOST_DEVICE Half operator/(int a, Half b) {
251
- return static_cast<Half>(a) / b;
252
- }
253
-
254
- //// Arithmetic with int64_t
255
-
256
- inline C10_HOST_DEVICE Half operator+(Half a, int64_t b) {
257
- return a + static_cast<Half>(b);
258
- }
259
- inline C10_HOST_DEVICE Half operator-(Half a, int64_t b) {
260
- return a - static_cast<Half>(b);
261
- }
262
- inline C10_HOST_DEVICE Half operator*(Half a, int64_t b) {
263
- return a * static_cast<Half>(b);
264
- }
265
- inline C10_HOST_DEVICE Half operator/(Half a, int64_t b) {
266
- return a / static_cast<Half>(b);
267
- }
268
-
269
- inline C10_HOST_DEVICE Half operator+(int64_t a, Half b) {
270
- return static_cast<Half>(a) + b;
271
- }
272
- inline C10_HOST_DEVICE Half operator-(int64_t a, Half b) {
273
- return static_cast<Half>(a) - b;
274
- }
275
- inline C10_HOST_DEVICE Half operator*(int64_t a, Half b) {
276
- return static_cast<Half>(a) * b;
277
- }
278
- inline C10_HOST_DEVICE Half operator/(int64_t a, Half b) {
279
- return static_cast<Half>(a) / b;
280
- }
281
-
282
- /// NOTE: we do not define comparisons directly and instead rely on the implicit
283
- /// conversion from c10::Half to float.
284
-
285
- } // namespace c10
286
-
287
- namespace std {
288
-
289
- template <> class numeric_limits<c10::Half> {
290
- public:
291
- static constexpr bool is_specialized = true;
292
- static constexpr bool is_signed = true;
293
- static constexpr bool is_integer = false;
294
- static constexpr bool is_exact = false;
295
- static constexpr bool has_infinity = true;
296
- static constexpr bool has_quiet_NaN = true;
297
- static constexpr bool has_signaling_NaN = true;
298
- static constexpr auto has_denorm = numeric_limits<float>::has_denorm;
299
- static constexpr auto has_denorm_loss =
300
- numeric_limits<float>::has_denorm_loss;
301
- static constexpr auto round_style = numeric_limits<float>::round_style;
302
- static constexpr bool is_iec559 = true;
303
- static constexpr bool is_bounded = true;
304
- static constexpr bool is_modulo = false;
305
- static constexpr int digits = 11;
306
- static constexpr int digits10 = 3;
307
- static constexpr int max_digits10 = 5;
308
- static constexpr int radix = 2;
309
- static constexpr int min_exponent = -13;
310
- static constexpr int min_exponent10 = -4;
311
- static constexpr int max_exponent = 16;
312
- static constexpr int max_exponent10 = 4;
313
- static constexpr auto traps = numeric_limits<float>::traps;
314
- static constexpr auto tinyness_before =
315
- numeric_limits<float>::tinyness_before;
316
- static constexpr c10::Half min() {
317
- return c10::Half(0x0400, c10::Half::from_bits());
318
- }
319
- static constexpr c10::Half lowest() {
320
- return c10::Half(0xFBFF, c10::Half::from_bits());
321
- }
322
- static constexpr c10::Half max() {
323
- return c10::Half(0x7BFF, c10::Half::from_bits());
324
- }
325
- static constexpr c10::Half epsilon() {
326
- return c10::Half(0x1400, c10::Half::from_bits());
327
- }
328
- static constexpr c10::Half round_error() {
329
- return c10::Half(0x3800, c10::Half::from_bits());
330
- }
331
- static constexpr c10::Half infinity() {
332
- return c10::Half(0x7C00, c10::Half::from_bits());
333
- }
334
- static constexpr c10::Half quiet_NaN() {
335
- return c10::Half(0x7E00, c10::Half::from_bits());
336
- }
337
- static constexpr c10::Half signaling_NaN() {
338
- return c10::Half(0x7D00, c10::Half::from_bits());
339
- }
340
- static constexpr c10::Half denorm_min() {
341
- return c10::Half(0x0001, c10::Half::from_bits());
342
- }
343
- };
344
-
345
- } // namespace std
346
-
347
- C10_CLANG_DIAGNOSTIC_POP()
1
+ #include <torch/headeronly/util/Half.h>