@smake/eigen 1.1.0 → 1.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (431)
  1. package/README.md +1 -1
  2. package/eigen/Eigen/AccelerateSupport +52 -0
  3. package/eigen/Eigen/Cholesky +18 -20
  4. package/eigen/Eigen/CholmodSupport +28 -28
  5. package/eigen/Eigen/Core +187 -120
  6. package/eigen/Eigen/Eigenvalues +16 -13
  7. package/eigen/Eigen/Geometry +18 -18
  8. package/eigen/Eigen/Householder +9 -7
  9. package/eigen/Eigen/IterativeLinearSolvers +8 -4
  10. package/eigen/Eigen/Jacobi +14 -13
  11. package/eigen/Eigen/KLUSupport +23 -21
  12. package/eigen/Eigen/LU +15 -16
  13. package/eigen/Eigen/MetisSupport +12 -12
  14. package/eigen/Eigen/OrderingMethods +54 -51
  15. package/eigen/Eigen/PaStiXSupport +23 -21
  16. package/eigen/Eigen/PardisoSupport +17 -14
  17. package/eigen/Eigen/QR +18 -20
  18. package/eigen/Eigen/QtAlignedMalloc +5 -12
  19. package/eigen/Eigen/SPQRSupport +21 -14
  20. package/eigen/Eigen/SVD +23 -17
  21. package/eigen/Eigen/Sparse +1 -2
  22. package/eigen/Eigen/SparseCholesky +18 -15
  23. package/eigen/Eigen/SparseCore +18 -17
  24. package/eigen/Eigen/SparseLU +9 -9
  25. package/eigen/Eigen/SparseQR +16 -14
  26. package/eigen/Eigen/StdDeque +5 -2
  27. package/eigen/Eigen/StdList +5 -2
  28. package/eigen/Eigen/StdVector +5 -2
  29. package/eigen/Eigen/SuperLUSupport +30 -24
  30. package/eigen/Eigen/ThreadPool +80 -0
  31. package/eigen/Eigen/UmfPackSupport +19 -17
  32. package/eigen/Eigen/Version +14 -0
  33. package/eigen/Eigen/src/AccelerateSupport/AccelerateSupport.h +423 -0
  34. package/eigen/Eigen/src/AccelerateSupport/InternalHeaderCheck.h +3 -0
  35. package/eigen/Eigen/src/Cholesky/InternalHeaderCheck.h +3 -0
  36. package/eigen/Eigen/src/Cholesky/LDLT.h +366 -405
  37. package/eigen/Eigen/src/Cholesky/LLT.h +323 -367
  38. package/eigen/Eigen/src/Cholesky/LLT_LAPACKE.h +81 -56
  39. package/eigen/Eigen/src/CholmodSupport/CholmodSupport.h +585 -529
  40. package/eigen/Eigen/src/CholmodSupport/InternalHeaderCheck.h +3 -0
  41. package/eigen/Eigen/src/Core/ArithmeticSequence.h +143 -317
  42. package/eigen/Eigen/src/Core/Array.h +329 -370
  43. package/eigen/Eigen/src/Core/ArrayBase.h +190 -203
  44. package/eigen/Eigen/src/Core/ArrayWrapper.h +126 -170
  45. package/eigen/Eigen/src/Core/Assign.h +30 -40
  46. package/eigen/Eigen/src/Core/AssignEvaluator.h +651 -604
  47. package/eigen/Eigen/src/Core/Assign_MKL.h +125 -120
  48. package/eigen/Eigen/src/Core/BandMatrix.h +267 -282
  49. package/eigen/Eigen/src/Core/Block.h +371 -390
  50. package/eigen/Eigen/src/Core/CommaInitializer.h +85 -100
  51. package/eigen/Eigen/src/Core/ConditionEstimator.h +51 -53
  52. package/eigen/Eigen/src/Core/CoreEvaluators.h +1214 -937
  53. package/eigen/Eigen/src/Core/CoreIterators.h +72 -63
  54. package/eigen/Eigen/src/Core/CwiseBinaryOp.h +112 -129
  55. package/eigen/Eigen/src/Core/CwiseNullaryOp.h +676 -702
  56. package/eigen/Eigen/src/Core/CwiseTernaryOp.h +77 -103
  57. package/eigen/Eigen/src/Core/CwiseUnaryOp.h +55 -67
  58. package/eigen/Eigen/src/Core/CwiseUnaryView.h +127 -92
  59. package/eigen/Eigen/src/Core/DenseBase.h +630 -658
  60. package/eigen/Eigen/src/Core/DenseCoeffsBase.h +511 -628
  61. package/eigen/Eigen/src/Core/DenseStorage.h +511 -590
  62. package/eigen/Eigen/src/Core/DeviceWrapper.h +153 -0
  63. package/eigen/Eigen/src/Core/Diagonal.h +168 -207
  64. package/eigen/Eigen/src/Core/DiagonalMatrix.h +346 -317
  65. package/eigen/Eigen/src/Core/DiagonalProduct.h +12 -10
  66. package/eigen/Eigen/src/Core/Dot.h +167 -217
  67. package/eigen/Eigen/src/Core/EigenBase.h +74 -85
  68. package/eigen/Eigen/src/Core/Fill.h +138 -0
  69. package/eigen/Eigen/src/Core/FindCoeff.h +464 -0
  70. package/eigen/Eigen/src/Core/ForceAlignedAccess.h +90 -113
  71. package/eigen/Eigen/src/Core/Fuzzy.h +82 -105
  72. package/eigen/Eigen/src/Core/GeneralProduct.h +315 -261
  73. package/eigen/Eigen/src/Core/GenericPacketMath.h +1182 -520
  74. package/eigen/Eigen/src/Core/GlobalFunctions.h +193 -157
  75. package/eigen/Eigen/src/Core/IO.h +131 -156
  76. package/eigen/Eigen/src/Core/IndexedView.h +209 -125
  77. package/eigen/Eigen/src/Core/InnerProduct.h +260 -0
  78. package/eigen/Eigen/src/Core/InternalHeaderCheck.h +3 -0
  79. package/eigen/Eigen/src/Core/Inverse.h +50 -59
  80. package/eigen/Eigen/src/Core/Map.h +123 -141
  81. package/eigen/Eigen/src/Core/MapBase.h +255 -282
  82. package/eigen/Eigen/src/Core/MathFunctions.h +1247 -1201
  83. package/eigen/Eigen/src/Core/MathFunctionsImpl.h +162 -99
  84. package/eigen/Eigen/src/Core/Matrix.h +463 -494
  85. package/eigen/Eigen/src/Core/MatrixBase.h +468 -470
  86. package/eigen/Eigen/src/Core/NestByValue.h +58 -52
  87. package/eigen/Eigen/src/Core/NoAlias.h +79 -86
  88. package/eigen/Eigen/src/Core/NumTraits.h +206 -206
  89. package/eigen/Eigen/src/Core/PartialReduxEvaluator.h +163 -142
  90. package/eigen/Eigen/src/Core/PermutationMatrix.h +461 -511
  91. package/eigen/Eigen/src/Core/PlainObjectBase.h +858 -972
  92. package/eigen/Eigen/src/Core/Product.h +246 -130
  93. package/eigen/Eigen/src/Core/ProductEvaluators.h +779 -671
  94. package/eigen/Eigen/src/Core/Random.h +153 -164
  95. package/eigen/Eigen/src/Core/RandomImpl.h +262 -0
  96. package/eigen/Eigen/src/Core/RealView.h +250 -0
  97. package/eigen/Eigen/src/Core/Redux.h +334 -314
  98. package/eigen/Eigen/src/Core/Ref.h +259 -257
  99. package/eigen/Eigen/src/Core/Replicate.h +92 -104
  100. package/eigen/Eigen/src/Core/Reshaped.h +215 -271
  101. package/eigen/Eigen/src/Core/ReturnByValue.h +47 -55
  102. package/eigen/Eigen/src/Core/Reverse.h +133 -148
  103. package/eigen/Eigen/src/Core/Select.h +68 -140
  104. package/eigen/Eigen/src/Core/SelfAdjointView.h +254 -290
  105. package/eigen/Eigen/src/Core/SelfCwiseBinaryOp.h +23 -20
  106. package/eigen/Eigen/src/Core/SkewSymmetricMatrix3.h +382 -0
  107. package/eigen/Eigen/src/Core/Solve.h +88 -102
  108. package/eigen/Eigen/src/Core/SolveTriangular.h +126 -124
  109. package/eigen/Eigen/src/Core/SolverBase.h +132 -133
  110. package/eigen/Eigen/src/Core/StableNorm.h +113 -147
  111. package/eigen/Eigen/src/Core/StlIterators.h +404 -248
  112. package/eigen/Eigen/src/Core/Stride.h +90 -92
  113. package/eigen/Eigen/src/Core/Swap.h +70 -39
  114. package/eigen/Eigen/src/Core/Transpose.h +258 -295
  115. package/eigen/Eigen/src/Core/Transpositions.h +270 -333
  116. package/eigen/Eigen/src/Core/TriangularMatrix.h +642 -743
  117. package/eigen/Eigen/src/Core/VectorBlock.h +59 -72
  118. package/eigen/Eigen/src/Core/VectorwiseOp.h +653 -704
  119. package/eigen/Eigen/src/Core/Visitor.h +464 -308
  120. package/eigen/Eigen/src/Core/arch/AVX/Complex.h +380 -187
  121. package/eigen/Eigen/src/Core/arch/AVX/MathFunctions.h +65 -163
  122. package/eigen/Eigen/src/Core/arch/AVX/PacketMath.h +2145 -638
  123. package/eigen/Eigen/src/Core/arch/AVX/Reductions.h +353 -0
  124. package/eigen/Eigen/src/Core/arch/AVX/TypeCasting.h +253 -60
  125. package/eigen/Eigen/src/Core/arch/AVX512/Complex.h +278 -228
  126. package/eigen/Eigen/src/Core/arch/AVX512/GemmKernel.h +1245 -0
  127. package/eigen/Eigen/src/Core/arch/AVX512/MathFunctions.h +48 -269
  128. package/eigen/Eigen/src/Core/arch/AVX512/MathFunctionsFP16.h +75 -0
  129. package/eigen/Eigen/src/Core/arch/AVX512/PacketMath.h +1597 -754
  130. package/eigen/Eigen/src/Core/arch/AVX512/PacketMathFP16.h +1413 -0
  131. package/eigen/Eigen/src/Core/arch/AVX512/Reductions.h +297 -0
  132. package/eigen/Eigen/src/Core/arch/AVX512/TrsmKernel.h +1167 -0
  133. package/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc +1219 -0
  134. package/eigen/Eigen/src/Core/arch/AVX512/TypeCasting.h +229 -41
  135. package/eigen/Eigen/src/Core/arch/AVX512/TypeCastingFP16.h +130 -0
  136. package/eigen/Eigen/src/Core/arch/AltiVec/Complex.h +420 -184
  137. package/eigen/Eigen/src/Core/arch/AltiVec/MathFunctions.h +40 -49
  138. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProduct.h +2962 -2213
  139. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h +196 -212
  140. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h +713 -441
  141. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h +742 -0
  142. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.inc +2818 -0
  143. package/eigen/Eigen/src/Core/arch/AltiVec/PacketMath.h +2380 -1362
  144. package/eigen/Eigen/src/Core/arch/AltiVec/TypeCasting.h +153 -0
  145. package/eigen/Eigen/src/Core/arch/Default/BFloat16.h +390 -224
  146. package/eigen/Eigen/src/Core/arch/Default/ConjHelper.h +78 -67
  147. package/eigen/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +1784 -799
  148. package/eigen/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h +167 -50
  149. package/eigen/Eigen/src/Core/arch/Default/Half.h +528 -379
  150. package/eigen/Eigen/src/Core/arch/Default/Settings.h +10 -12
  151. package/eigen/Eigen/src/Core/arch/GPU/Complex.h +244 -0
  152. package/eigen/Eigen/src/Core/arch/GPU/MathFunctions.h +41 -40
  153. package/eigen/Eigen/src/Core/arch/GPU/PacketMath.h +550 -523
  154. package/eigen/Eigen/src/Core/arch/GPU/Tuple.h +268 -0
  155. package/eigen/Eigen/src/Core/arch/GPU/TypeCasting.h +27 -30
  156. package/eigen/Eigen/src/Core/arch/HIP/hcc/math_constants.h +8 -8
  157. package/eigen/Eigen/src/Core/arch/HVX/PacketMath.h +1088 -0
  158. package/eigen/Eigen/src/Core/arch/LSX/Complex.h +520 -0
  159. package/eigen/Eigen/src/Core/arch/LSX/GeneralBlockPanelKernel.h +23 -0
  160. package/eigen/Eigen/src/Core/arch/LSX/MathFunctions.h +43 -0
  161. package/eigen/Eigen/src/Core/arch/LSX/PacketMath.h +2866 -0
  162. package/eigen/Eigen/src/Core/arch/LSX/TypeCasting.h +526 -0
  163. package/eigen/Eigen/src/Core/arch/MSA/Complex.h +54 -82
  164. package/eigen/Eigen/src/Core/arch/MSA/MathFunctions.h +84 -92
  165. package/eigen/Eigen/src/Core/arch/MSA/PacketMath.h +51 -47
  166. package/eigen/Eigen/src/Core/arch/NEON/Complex.h +454 -306
  167. package/eigen/Eigen/src/Core/arch/NEON/GeneralBlockPanelKernel.h +175 -115
  168. package/eigen/Eigen/src/Core/arch/NEON/MathFunctions.h +23 -30
  169. package/eigen/Eigen/src/Core/arch/NEON/PacketMath.h +4366 -2857
  170. package/eigen/Eigen/src/Core/arch/NEON/TypeCasting.h +616 -393
  171. package/eigen/Eigen/src/Core/arch/NEON/UnaryFunctors.h +57 -0
  172. package/eigen/Eigen/src/Core/arch/SSE/Complex.h +350 -198
  173. package/eigen/Eigen/src/Core/arch/SSE/MathFunctions.h +38 -149
  174. package/eigen/Eigen/src/Core/arch/SSE/PacketMath.h +1791 -912
  175. package/eigen/Eigen/src/Core/arch/SSE/Reductions.h +324 -0
  176. package/eigen/Eigen/src/Core/arch/SSE/TypeCasting.h +128 -40
  177. package/eigen/Eigen/src/Core/arch/SVE/MathFunctions.h +10 -6
  178. package/eigen/Eigen/src/Core/arch/SVE/PacketMath.h +156 -234
  179. package/eigen/Eigen/src/Core/arch/SVE/TypeCasting.h +6 -3
  180. package/eigen/Eigen/src/Core/arch/SYCL/InteropHeaders.h +27 -32
  181. package/eigen/Eigen/src/Core/arch/SYCL/MathFunctions.h +119 -117
  182. package/eigen/Eigen/src/Core/arch/SYCL/PacketMath.h +325 -419
  183. package/eigen/Eigen/src/Core/arch/SYCL/TypeCasting.h +15 -17
  184. package/eigen/Eigen/src/Core/arch/ZVector/Complex.h +325 -181
  185. package/eigen/Eigen/src/Core/arch/ZVector/MathFunctions.h +94 -83
  186. package/eigen/Eigen/src/Core/arch/ZVector/PacketMath.h +811 -458
  187. package/eigen/Eigen/src/Core/functors/AssignmentFunctors.h +121 -124
  188. package/eigen/Eigen/src/Core/functors/BinaryFunctors.h +576 -370
  189. package/eigen/Eigen/src/Core/functors/NullaryFunctors.h +194 -109
  190. package/eigen/Eigen/src/Core/functors/StlFunctors.h +95 -112
  191. package/eigen/Eigen/src/Core/functors/TernaryFunctors.h +34 -7
  192. package/eigen/Eigen/src/Core/functors/UnaryFunctors.h +1038 -749
  193. package/eigen/Eigen/src/Core/products/GeneralBlockPanelKernel.h +1883 -1375
  194. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrix.h +312 -370
  195. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h +189 -176
  196. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h +84 -81
  197. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h +154 -73
  198. package/eigen/Eigen/src/Core/products/GeneralMatrixVector.h +292 -337
  199. package/eigen/Eigen/src/Core/products/GeneralMatrixVector_BLAS.h +80 -77
  200. package/eigen/Eigen/src/Core/products/Parallelizer.h +207 -105
  201. package/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix.h +327 -388
  202. package/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix_BLAS.h +206 -224
  203. package/eigen/Eigen/src/Core/products/SelfadjointMatrixVector.h +138 -147
  204. package/eigen/Eigen/src/Core/products/SelfadjointMatrixVector_BLAS.h +58 -61
  205. package/eigen/Eigen/src/Core/products/SelfadjointProduct.h +71 -71
  206. package/eigen/Eigen/src/Core/products/SelfadjointRank2Update.h +48 -47
  207. package/eigen/Eigen/src/Core/products/TriangularMatrixMatrix.h +294 -369
  208. package/eigen/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h +246 -238
  209. package/eigen/Eigen/src/Core/products/TriangularMatrixVector.h +244 -247
  210. package/eigen/Eigen/src/Core/products/TriangularMatrixVector_BLAS.h +212 -192
  211. package/eigen/Eigen/src/Core/products/TriangularSolverMatrix.h +328 -277
  212. package/eigen/Eigen/src/Core/products/TriangularSolverMatrix_BLAS.h +108 -109
  213. package/eigen/Eigen/src/Core/products/TriangularSolverVector.h +68 -94
  214. package/eigen/Eigen/src/Core/util/Assert.h +158 -0
  215. package/eigen/Eigen/src/Core/util/BlasUtil.h +342 -303
  216. package/eigen/Eigen/src/Core/util/ConfigureVectorization.h +348 -317
  217. package/eigen/Eigen/src/Core/util/Constants.h +297 -262
  218. package/eigen/Eigen/src/Core/util/DisableStupidWarnings.h +130 -90
  219. package/eigen/Eigen/src/Core/util/EmulateArray.h +270 -0
  220. package/eigen/Eigen/src/Core/util/ForwardDeclarations.h +449 -247
  221. package/eigen/Eigen/src/Core/util/GpuHipCudaDefines.inc +101 -0
  222. package/eigen/Eigen/src/Core/util/GpuHipCudaUndefines.inc +45 -0
  223. package/eigen/Eigen/src/Core/util/IndexedViewHelper.h +417 -116
  224. package/eigen/Eigen/src/Core/util/IntegralConstant.h +211 -204
  225. package/eigen/Eigen/src/Core/util/MKL_support.h +39 -37
  226. package/eigen/Eigen/src/Core/util/Macros.h +655 -773
  227. package/eigen/Eigen/src/Core/util/MaxSizeVector.h +139 -0
  228. package/eigen/Eigen/src/Core/util/Memory.h +970 -748
  229. package/eigen/Eigen/src/Core/util/Meta.h +581 -633
  230. package/eigen/Eigen/src/Core/util/MoreMeta.h +638 -0
  231. package/eigen/Eigen/src/Core/util/ReenableStupidWarnings.h +32 -19
  232. package/eigen/Eigen/src/Core/util/ReshapedHelper.h +17 -17
  233. package/eigen/Eigen/src/Core/util/Serializer.h +209 -0
  234. package/eigen/Eigen/src/Core/util/StaticAssert.h +50 -166
  235. package/eigen/Eigen/src/Core/util/SymbolicIndex.h +377 -225
  236. package/eigen/Eigen/src/Core/util/XprHelper.h +784 -547
  237. package/eigen/Eigen/src/Eigenvalues/ComplexEigenSolver.h +246 -277
  238. package/eigen/Eigen/src/Eigenvalues/ComplexSchur.h +299 -319
  239. package/eigen/Eigen/src/Eigenvalues/ComplexSchur_LAPACKE.h +52 -48
  240. package/eigen/Eigen/src/Eigenvalues/EigenSolver.h +413 -456
  241. package/eigen/Eigen/src/Eigenvalues/GeneralizedEigenSolver.h +309 -325
  242. package/eigen/Eigen/src/Eigenvalues/GeneralizedSelfAdjointEigenSolver.h +157 -171
  243. package/eigen/Eigen/src/Eigenvalues/HessenbergDecomposition.h +292 -310
  244. package/eigen/Eigen/src/Eigenvalues/InternalHeaderCheck.h +3 -0
  245. package/eigen/Eigen/src/Eigenvalues/MatrixBaseEigenvalues.h +89 -105
  246. package/eigen/Eigen/src/Eigenvalues/RealQZ.h +537 -607
  247. package/eigen/Eigen/src/Eigenvalues/RealSchur.h +342 -381
  248. package/eigen/Eigen/src/Eigenvalues/RealSchur_LAPACKE.h +41 -35
  249. package/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver.h +541 -595
  250. package/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver_LAPACKE.h +47 -44
  251. package/eigen/Eigen/src/Eigenvalues/Tridiagonalization.h +430 -462
  252. package/eigen/Eigen/src/Geometry/AlignedBox.h +226 -227
  253. package/eigen/Eigen/src/Geometry/AngleAxis.h +131 -133
  254. package/eigen/Eigen/src/Geometry/EulerAngles.h +163 -74
  255. package/eigen/Eigen/src/Geometry/Homogeneous.h +285 -333
  256. package/eigen/Eigen/src/Geometry/Hyperplane.h +151 -160
  257. package/eigen/Eigen/src/Geometry/InternalHeaderCheck.h +3 -0
  258. package/eigen/Eigen/src/Geometry/OrthoMethods.h +168 -146
  259. package/eigen/Eigen/src/Geometry/ParametrizedLine.h +127 -127
  260. package/eigen/Eigen/src/Geometry/Quaternion.h +566 -506
  261. package/eigen/Eigen/src/Geometry/Rotation2D.h +107 -105
  262. package/eigen/Eigen/src/Geometry/RotationBase.h +148 -145
  263. package/eigen/Eigen/src/Geometry/Scaling.h +113 -106
  264. package/eigen/Eigen/src/Geometry/Transform.h +858 -936
  265. package/eigen/Eigen/src/Geometry/Translation.h +94 -92
  266. package/eigen/Eigen/src/Geometry/Umeyama.h +79 -84
  267. package/eigen/Eigen/src/Geometry/arch/Geometry_SIMD.h +90 -104
  268. package/eigen/Eigen/src/Householder/BlockHouseholder.h +51 -46
  269. package/eigen/Eigen/src/Householder/Householder.h +102 -124
  270. package/eigen/Eigen/src/Householder/HouseholderSequence.h +412 -453
  271. package/eigen/Eigen/src/Householder/InternalHeaderCheck.h +3 -0
  272. package/eigen/Eigen/src/IterativeLinearSolvers/BasicPreconditioners.h +149 -162
  273. package/eigen/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h +124 -119
  274. package/eigen/Eigen/src/IterativeLinearSolvers/ConjugateGradient.h +92 -104
  275. package/eigen/Eigen/src/IterativeLinearSolvers/IncompleteCholesky.h +251 -243
  276. package/eigen/Eigen/src/IterativeLinearSolvers/IncompleteLUT.h +224 -228
  277. package/eigen/Eigen/src/IterativeLinearSolvers/InternalHeaderCheck.h +3 -0
  278. package/eigen/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h +178 -227
  279. package/eigen/Eigen/src/IterativeLinearSolvers/LeastSquareConjugateGradient.h +79 -84
  280. package/eigen/Eigen/src/IterativeLinearSolvers/SolveWithGuess.h +54 -60
  281. package/eigen/Eigen/src/Jacobi/InternalHeaderCheck.h +3 -0
  282. package/eigen/Eigen/src/Jacobi/Jacobi.h +252 -308
  283. package/eigen/Eigen/src/KLUSupport/InternalHeaderCheck.h +3 -0
  284. package/eigen/Eigen/src/KLUSupport/KLUSupport.h +208 -227
  285. package/eigen/Eigen/src/LU/Determinant.h +50 -69
  286. package/eigen/Eigen/src/LU/FullPivLU.h +545 -596
  287. package/eigen/Eigen/src/LU/InternalHeaderCheck.h +3 -0
  288. package/eigen/Eigen/src/LU/InverseImpl.h +206 -285
  289. package/eigen/Eigen/src/LU/PartialPivLU.h +390 -428
  290. package/eigen/Eigen/src/LU/PartialPivLU_LAPACKE.h +54 -40
  291. package/eigen/Eigen/src/LU/arch/InverseSize4.h +72 -70
  292. package/eigen/Eigen/src/MetisSupport/InternalHeaderCheck.h +3 -0
  293. package/eigen/Eigen/src/MetisSupport/MetisSupport.h +81 -93
  294. package/eigen/Eigen/src/OrderingMethods/Amd.h +243 -265
  295. package/eigen/Eigen/src/OrderingMethods/Eigen_Colamd.h +831 -1004
  296. package/eigen/Eigen/src/OrderingMethods/InternalHeaderCheck.h +3 -0
  297. package/eigen/Eigen/src/OrderingMethods/Ordering.h +112 -119
  298. package/eigen/Eigen/src/PaStiXSupport/InternalHeaderCheck.h +3 -0
  299. package/eigen/Eigen/src/PaStiXSupport/PaStiXSupport.h +524 -570
  300. package/eigen/Eigen/src/PardisoSupport/InternalHeaderCheck.h +3 -0
  301. package/eigen/Eigen/src/PardisoSupport/PardisoSupport.h +385 -430
  302. package/eigen/Eigen/src/QR/ColPivHouseholderQR.h +479 -479
  303. package/eigen/Eigen/src/QR/ColPivHouseholderQR_LAPACKE.h +120 -56
  304. package/eigen/Eigen/src/QR/CompleteOrthogonalDecomposition.h +166 -153
  305. package/eigen/Eigen/src/QR/FullPivHouseholderQR.h +495 -475
  306. package/eigen/Eigen/src/QR/HouseholderQR.h +394 -285
  307. package/eigen/Eigen/src/QR/HouseholderQR_LAPACKE.h +32 -23
  308. package/eigen/Eigen/src/QR/InternalHeaderCheck.h +3 -0
  309. package/eigen/Eigen/src/SPQRSupport/InternalHeaderCheck.h +3 -0
  310. package/eigen/Eigen/src/SPQRSupport/SuiteSparseQRSupport.h +244 -264
  311. package/eigen/Eigen/src/SVD/BDCSVD.h +817 -713
  312. package/eigen/Eigen/src/SVD/BDCSVD_LAPACKE.h +174 -0
  313. package/eigen/Eigen/src/SVD/InternalHeaderCheck.h +3 -0
  314. package/eigen/Eigen/src/SVD/JacobiSVD.h +577 -543
  315. package/eigen/Eigen/src/SVD/JacobiSVD_LAPACKE.h +85 -49
  316. package/eigen/Eigen/src/SVD/SVDBase.h +242 -182
  317. package/eigen/Eigen/src/SVD/UpperBidiagonalization.h +200 -235
  318. package/eigen/Eigen/src/SparseCholesky/InternalHeaderCheck.h +3 -0
  319. package/eigen/Eigen/src/SparseCholesky/SimplicialCholesky.h +765 -594
  320. package/eigen/Eigen/src/SparseCholesky/SimplicialCholesky_impl.h +308 -94
  321. package/eigen/Eigen/src/SparseCore/AmbiVector.h +202 -251
  322. package/eigen/Eigen/src/SparseCore/CompressedStorage.h +184 -252
  323. package/eigen/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h +134 -178
  324. package/eigen/Eigen/src/SparseCore/InternalHeaderCheck.h +3 -0
  325. package/eigen/Eigen/src/SparseCore/SparseAssign.h +149 -140
  326. package/eigen/Eigen/src/SparseCore/SparseBlock.h +403 -440
  327. package/eigen/Eigen/src/SparseCore/SparseColEtree.h +100 -112
  328. package/eigen/Eigen/src/SparseCore/SparseCompressedBase.h +525 -303
  329. package/eigen/Eigen/src/SparseCore/SparseCwiseBinaryOp.h +555 -339
  330. package/eigen/Eigen/src/SparseCore/SparseCwiseUnaryOp.h +100 -108
  331. package/eigen/Eigen/src/SparseCore/SparseDenseProduct.h +169 -197
  332. package/eigen/Eigen/src/SparseCore/SparseDiagonalProduct.h +71 -71
  333. package/eigen/Eigen/src/SparseCore/SparseDot.h +49 -47
  334. package/eigen/Eigen/src/SparseCore/SparseFuzzy.h +13 -11
  335. package/eigen/Eigen/src/SparseCore/SparseMap.h +243 -253
  336. package/eigen/Eigen/src/SparseCore/SparseMatrix.h +1603 -1245
  337. package/eigen/Eigen/src/SparseCore/SparseMatrixBase.h +403 -350
  338. package/eigen/Eigen/src/SparseCore/SparsePermutation.h +186 -115
  339. package/eigen/Eigen/src/SparseCore/SparseProduct.h +94 -97
  340. package/eigen/Eigen/src/SparseCore/SparseRedux.h +22 -24
  341. package/eigen/Eigen/src/SparseCore/SparseRef.h +268 -295
  342. package/eigen/Eigen/src/SparseCore/SparseSelfAdjointView.h +370 -416
  343. package/eigen/Eigen/src/SparseCore/SparseSolverBase.h +78 -87
  344. package/eigen/Eigen/src/SparseCore/SparseSparseProductWithPruning.h +81 -95
  345. package/eigen/Eigen/src/SparseCore/SparseTranspose.h +62 -71
  346. package/eigen/Eigen/src/SparseCore/SparseTriangularView.h +132 -144
  347. package/eigen/Eigen/src/SparseCore/SparseUtil.h +138 -115
  348. package/eigen/Eigen/src/SparseCore/SparseVector.h +426 -372
  349. package/eigen/Eigen/src/SparseCore/SparseView.h +164 -193
  350. package/eigen/Eigen/src/SparseCore/TriangularSolver.h +129 -170
  351. package/eigen/Eigen/src/SparseLU/InternalHeaderCheck.h +3 -0
  352. package/eigen/Eigen/src/SparseLU/SparseLU.h +756 -710
  353. package/eigen/Eigen/src/SparseLU/SparseLUImpl.h +61 -48
  354. package/eigen/Eigen/src/SparseLU/SparseLU_Memory.h +102 -118
  355. package/eigen/Eigen/src/SparseLU/SparseLU_Structs.h +38 -35
  356. package/eigen/Eigen/src/SparseLU/SparseLU_SupernodalMatrix.h +245 -301
  357. package/eigen/Eigen/src/SparseLU/SparseLU_Utils.h +44 -49
  358. package/eigen/Eigen/src/SparseLU/SparseLU_column_bmod.h +104 -108
  359. package/eigen/Eigen/src/SparseLU/SparseLU_column_dfs.h +89 -100
  360. package/eigen/Eigen/src/SparseLU/SparseLU_copy_to_ucol.h +57 -58
  361. package/eigen/Eigen/src/SparseLU/SparseLU_heap_relax_snode.h +43 -55
  362. package/eigen/Eigen/src/SparseLU/SparseLU_kernel_bmod.h +74 -71
  363. package/eigen/Eigen/src/SparseLU/SparseLU_panel_bmod.h +124 -132
  364. package/eigen/Eigen/src/SparseLU/SparseLU_panel_dfs.h +136 -159
  365. package/eigen/Eigen/src/SparseLU/SparseLU_pivotL.h +51 -52
  366. package/eigen/Eigen/src/SparseLU/SparseLU_pruneL.h +67 -73
  367. package/eigen/Eigen/src/SparseLU/SparseLU_relax_snode.h +24 -26
  368. package/eigen/Eigen/src/SparseQR/InternalHeaderCheck.h +3 -0
  369. package/eigen/Eigen/src/SparseQR/SparseQR.h +450 -502
  370. package/eigen/Eigen/src/StlSupport/StdDeque.h +28 -93
  371. package/eigen/Eigen/src/StlSupport/StdList.h +28 -84
  372. package/eigen/Eigen/src/StlSupport/StdVector.h +28 -108
  373. package/eigen/Eigen/src/StlSupport/details.h +48 -50
  374. package/eigen/Eigen/src/SuperLUSupport/InternalHeaderCheck.h +3 -0
  375. package/eigen/Eigen/src/SuperLUSupport/SuperLUSupport.h +634 -730
  376. package/eigen/Eigen/src/ThreadPool/Barrier.h +70 -0
  377. package/eigen/Eigen/src/ThreadPool/CoreThreadPoolDevice.h +336 -0
  378. package/eigen/Eigen/src/ThreadPool/EventCount.h +241 -0
  379. package/eigen/Eigen/src/ThreadPool/ForkJoin.h +140 -0
  380. package/eigen/Eigen/src/ThreadPool/InternalHeaderCheck.h +4 -0
  381. package/eigen/Eigen/src/ThreadPool/NonBlockingThreadPool.h +587 -0
  382. package/eigen/Eigen/src/ThreadPool/RunQueue.h +230 -0
  383. package/eigen/Eigen/src/ThreadPool/ThreadCancel.h +21 -0
  384. package/eigen/Eigen/src/ThreadPool/ThreadEnvironment.h +43 -0
  385. package/eigen/Eigen/src/ThreadPool/ThreadLocal.h +289 -0
  386. package/eigen/Eigen/src/ThreadPool/ThreadPoolInterface.h +50 -0
  387. package/eigen/Eigen/src/ThreadPool/ThreadYield.h +16 -0
  388. package/eigen/Eigen/src/UmfPackSupport/InternalHeaderCheck.h +3 -0
  389. package/eigen/Eigen/src/UmfPackSupport/UmfPackSupport.h +428 -464
  390. package/eigen/Eigen/src/misc/Image.h +41 -43
  391. package/eigen/Eigen/src/misc/InternalHeaderCheck.h +3 -0
  392. package/eigen/Eigen/src/misc/Kernel.h +39 -41
  393. package/eigen/Eigen/src/misc/RealSvd2x2.h +19 -21
  394. package/eigen/Eigen/src/misc/blas.h +83 -426
  395. package/eigen/Eigen/src/misc/lapacke.h +9972 -16179
  396. package/eigen/Eigen/src/misc/lapacke_helpers.h +163 -0
  397. package/eigen/Eigen/src/misc/lapacke_mangling.h +4 -5
  398. package/eigen/Eigen/src/plugins/ArrayCwiseBinaryOps.inc +344 -0
  399. package/eigen/Eigen/src/plugins/ArrayCwiseUnaryOps.inc +544 -0
  400. package/eigen/Eigen/src/plugins/{BlockMethods.h → BlockMethods.inc} +434 -506
  401. package/eigen/Eigen/src/plugins/CommonCwiseBinaryOps.inc +116 -0
  402. package/eigen/Eigen/src/plugins/{CommonCwiseUnaryOps.h → CommonCwiseUnaryOps.inc} +58 -68
  403. package/eigen/Eigen/src/plugins/IndexedViewMethods.inc +192 -0
  404. package/eigen/Eigen/src/plugins/InternalHeaderCheck.inc +3 -0
  405. package/eigen/Eigen/src/plugins/MatrixCwiseBinaryOps.inc +331 -0
  406. package/eigen/Eigen/src/plugins/MatrixCwiseUnaryOps.inc +118 -0
  407. package/eigen/Eigen/src/plugins/ReshapedMethods.inc +133 -0
  408. package/package.json +1 -1
  409. package/eigen/COPYING.APACHE +0 -203
  410. package/eigen/COPYING.BSD +0 -26
  411. package/eigen/COPYING.GPL +0 -674
  412. package/eigen/COPYING.LGPL +0 -502
  413. package/eigen/COPYING.MINPACK +0 -51
  414. package/eigen/COPYING.MPL2 +0 -373
  415. package/eigen/COPYING.README +0 -18
  416. package/eigen/Eigen/src/Core/BooleanRedux.h +0 -162
  417. package/eigen/Eigen/src/Core/arch/CUDA/Complex.h +0 -258
  418. package/eigen/Eigen/src/Core/arch/Default/TypeCasting.h +0 -120
  419. package/eigen/Eigen/src/Core/arch/SYCL/SyclMemoryModel.h +0 -694
  420. package/eigen/Eigen/src/Core/util/NonMPL2.h +0 -3
  421. package/eigen/Eigen/src/SparseCore/MappedSparseMatrix.h +0 -67
  422. package/eigen/Eigen/src/SparseLU/SparseLU_gemm_kernel.h +0 -280
  423. package/eigen/Eigen/src/misc/lapack.h +0 -152
  424. package/eigen/Eigen/src/plugins/ArrayCwiseBinaryOps.h +0 -358
  425. package/eigen/Eigen/src/plugins/ArrayCwiseUnaryOps.h +0 -696
  426. package/eigen/Eigen/src/plugins/CommonCwiseBinaryOps.h +0 -115
  427. package/eigen/Eigen/src/plugins/IndexedViewMethods.h +0 -262
  428. package/eigen/Eigen/src/plugins/MatrixCwiseBinaryOps.h +0 -152
  429. package/eigen/Eigen/src/plugins/MatrixCwiseUnaryOps.h +0 -95
  430. package/eigen/Eigen/src/plugins/ReshapedMethods.h +0 -149
  431. package/eigen/README.md +0 -5
@@ -24,7 +24,6 @@
24
24
  // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25
25
  // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
26
 
27
-
28
27
  // Standard 16-bit float type, mostly useful for GPUs. Defines a new
29
28
  // type Eigen::half (inheriting either from CUDA's or HIP's __half struct) with
30
29
  // operator overloads such that it behaves basically as an arithmetic
@@ -32,29 +31,30 @@
32
31
  // in fp32 for CPUs, except for simple parameter conversions, I/O
33
32
  // to disk and the likes), but fast on GPUs.
34
33
 
35
-
36
34
  #ifndef EIGEN_HALF_H
37
35
  #define EIGEN_HALF_H
38
36
 
39
- #include <sstream>
37
+ // IWYU pragma: private
38
+ #include "../../InternalHeaderCheck.h"
40
39
 
41
- #if defined(EIGEN_HAS_GPU_FP16) || defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
42
40
  // When compiling with GPU support, the "__half_raw" base class as well as
43
41
  // some other routines are defined in the GPU compiler header files
44
42
  // (cuda_fp16.h, hip_fp16.h), and they are not tagged constexpr
45
43
  // As a consequence, we get compile failures when compiling Eigen with
46
44
  // GPU support. Hence the need to disable EIGEN_CONSTEXPR when building
47
- // Eigen with GPU support
48
- #pragma push_macro("EIGEN_CONSTEXPR")
49
- #undef EIGEN_CONSTEXPR
50
- #define EIGEN_CONSTEXPR
45
+ // Eigen with GPU support.
46
+ // Any functions that require `numext::bit_cast` may also not be constexpr,
47
+ // including any native types when setting via raw bit values.
48
+ #if defined(EIGEN_HAS_GPU_FP16) || defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) || defined(EIGEN_HAS_BUILTIN_FLOAT16)
49
+ #define _EIGEN_MAYBE_CONSTEXPR
50
+ #else
51
+ #define _EIGEN_MAYBE_CONSTEXPR constexpr
51
52
  #endif
52
53
 
53
- #define F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, METHOD) \
54
- template <> \
55
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_UNUSED \
56
- PACKET_F16 METHOD<PACKET_F16>(const PACKET_F16& _x) { \
57
- return float2half(METHOD<PACKET_F>(half2float(_x))); \
54
+ #define F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, METHOD) \
55
+ template <> \
56
+ EIGEN_UNUSED EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC PACKET_F16 METHOD<PACKET_F16>(const PACKET_F16& _x) { \
57
+ return float2half(METHOD<PACKET_F>(half2float(_x))); \
58
58
  }
59
59
 
60
60
  namespace Eigen {
@@ -83,8 +83,10 @@ namespace half_impl {
83
83
  // Making the host side compile phase of hipcc use the same Eigen::half impl, as the gcc compile, resolves
84
84
  // this error, and hence the following convoluted #if condition
85
85
  #if !defined(EIGEN_HAS_GPU_FP16) || !defined(EIGEN_GPU_COMPILE_PHASE)
86
+
86
87
  // Make our own __half_raw definition that is similar to CUDA's.
87
88
  struct __half_raw {
89
+ struct construct_from_rep_tag {};
88
90
  #if (defined(EIGEN_HAS_GPU_FP16) && !defined(EIGEN_GPU_COMPILE_PHASE))
89
91
  // Eigen::half can be used as the datatype for shared memory declarations (in Eigen and TF)
90
92
  // The element type for shared memory cannot have non-trivial constructors
@@ -93,54 +95,62 @@ struct __half_raw {
93
95
  // hence the need for this
94
96
  EIGEN_DEVICE_FUNC __half_raw() {}
95
97
  #else
96
- EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw() : x(0) {}
98
+ EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR __half_raw() : x(0) {}
97
99
  #endif
100
+
98
101
  #if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
99
- explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw(numext::uint16_t raw) : x(numext::bit_cast<__fp16>(raw)) {
100
- }
102
+ explicit EIGEN_DEVICE_FUNC __half_raw(numext::uint16_t raw) : x(numext::bit_cast<__fp16>(raw)) {}
103
+ EIGEN_DEVICE_FUNC constexpr __half_raw(construct_from_rep_tag, __fp16 rep) : x{rep} {}
101
104
  __fp16 x;
105
+ #elif defined(EIGEN_HAS_BUILTIN_FLOAT16)
106
+ explicit EIGEN_DEVICE_FUNC __half_raw(numext::uint16_t raw) : x(numext::bit_cast<_Float16>(raw)) {}
107
+ EIGEN_DEVICE_FUNC constexpr __half_raw(construct_from_rep_tag, _Float16 rep) : x{rep} {}
108
+ _Float16 x;
102
109
  #else
103
- explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw(numext::uint16_t raw) : x(raw) {}
110
+ explicit EIGEN_DEVICE_FUNC constexpr __half_raw(numext::uint16_t raw) : x(raw) {}
111
+ EIGEN_DEVICE_FUNC constexpr __half_raw(construct_from_rep_tag, numext::uint16_t rep) : x{rep} {}
104
112
  numext::uint16_t x;
105
113
  #endif
106
114
  };
107
115
 
108
116
  #elif defined(EIGEN_HAS_HIP_FP16)
109
- // Nothing to do here
110
- // HIP fp16 header file has a definition for __half_raw
117
+ // HIP GPU compile phase: nothing to do here.
118
+ // HIP fp16 header file has a definition for __half_raw
111
119
  #elif defined(EIGEN_HAS_CUDA_FP16)
112
- #if EIGEN_CUDA_SDK_VER < 90000
113
- // In CUDA < 9.0, __half is the equivalent of CUDA 9's __half_raw
114
- typedef __half __half_raw;
115
- #endif // defined(EIGEN_HAS_CUDA_FP16)
120
+
121
+ // CUDA GPU compile phase.
122
+ #if EIGEN_CUDA_SDK_VER < 90000
123
+ // In CUDA < 9.0, __half is the equivalent of CUDA 9's __half_raw
124
+ typedef __half __half_raw;
125
+ #endif // defined(EIGEN_HAS_CUDA_FP16)
126
+
116
127
  #elif defined(SYCL_DEVICE_ONLY)
117
- typedef cl::sycl::half __half_raw;
128
+ typedef cl::sycl::half __half_raw;
118
129
  #endif
119
130
 
120
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw raw_uint16_to_half(numext::uint16_t x);
131
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR __half_raw raw_uint16_to_half(numext::uint16_t x);
121
132
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half_raw float_to_half_rtne(float ff);
122
133
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float half_to_float(__half_raw h);
123
134
 
124
135
  struct half_base : public __half_raw {
125
- EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half_base() {}
126
- EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half_base(const __half_raw& h) : __half_raw(h) {}
136
+ EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half_base() {}
137
+ EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half_base(const __half_raw& h) : __half_raw(h) {}
127
138
 
128
139
  #if defined(EIGEN_HAS_GPU_FP16)
129
- #if defined(EIGEN_HAS_HIP_FP16)
130
- EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half_base(const __half& h) { x = __half_as_ushort(h); }
131
- #elif defined(EIGEN_HAS_CUDA_FP16)
132
- #if EIGEN_CUDA_SDK_VER >= 90000
133
- EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half_base(const __half& h) : __half_raw(*(__half_raw*)&h) {}
134
- #endif
135
- #endif
140
+ #if defined(EIGEN_HAS_HIP_FP16)
141
+ EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half_base(const __half& h) { x = __half_as_ushort(h); }
142
+ #elif defined(EIGEN_HAS_CUDA_FP16)
143
+ #if EIGEN_CUDA_SDK_VER >= 90000
144
+ EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half_base(const __half& h) : __half_raw(*(__half_raw*)&h) {}
145
+ #endif
146
+ #endif
136
147
  #endif
137
148
  };
138
149
 
139
- } // namespace half_impl
150
+ } // namespace half_impl
140
151
 
141
152
  // Class definition.
142
153
  struct half : public half_impl::half_base {
143
-
144
154
  // Writing this out as separate #if-else blocks to make the code easier to follow
145
155
  // The same applies to most #if-else blocks in this file
146
156
  #if !defined(EIGEN_HAS_GPU_FP16) || !defined(EIGEN_GPU_COMPILE_PHASE)
@@ -152,44 +162,50 @@ struct half : public half_impl::half_base {
152
162
  // Nothing to do here
153
163
  // HIP fp16 header file has a definition for __half_raw
154
164
  #elif defined(EIGEN_HAS_CUDA_FP16)
155
- // Note that EIGEN_CUDA_SDK_VER is set to 0 even when compiling with HIP, so
156
- // (EIGEN_CUDA_SDK_VER < 90000) is true even for HIP! So keeping this within
157
- // #if defined(EIGEN_HAS_CUDA_FP16) is needed
158
- #if defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER < 90000
159
- typedef half_impl::__half_raw __half_raw;
160
- #endif
165
+ // Note that EIGEN_CUDA_SDK_VER is set to 0 even when compiling with HIP, so
166
+ // (EIGEN_CUDA_SDK_VER < 90000) is true even for HIP! So keeping this within
167
+ // #if defined(EIGEN_HAS_CUDA_FP16) is needed
168
+ #if defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER < 90000
169
+ typedef half_impl::__half_raw __half_raw;
170
+ #endif
161
171
  #endif
162
172
 
163
- EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half() {}
173
+ EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half() {}
164
174
 
165
- EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half(const __half_raw& h) : half_impl::half_base(h) {}
175
+ EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half(const __half_raw& h) : half_impl::half_base(h) {}
166
176
 
167
177
  #if defined(EIGEN_HAS_GPU_FP16)
168
- #if defined(EIGEN_HAS_HIP_FP16)
169
- EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half(const __half& h) : half_impl::half_base(h) {}
170
- #elif defined(EIGEN_HAS_CUDA_FP16)
171
- #if defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER >= 90000
172
- EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half(const __half& h) : half_impl::half_base(h) {}
173
- #endif
174
- #endif
178
+ #if defined(EIGEN_HAS_HIP_FP16)
179
+ EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half(const __half& h) : half_impl::half_base(h) {}
180
+ #elif defined(EIGEN_HAS_CUDA_FP16)
181
+ #if defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER >= 90000
182
+ EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half(const __half& h) : half_impl::half_base(h) {}
183
+ #endif
184
+ #endif
175
185
  #endif
176
186
 
187
+ #if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
188
+ explicit EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half(__fp16 b)
189
+ : half(__half_raw(__half_raw::construct_from_rep_tag(), b)) {}
190
+ #elif defined(EIGEN_HAS_BUILTIN_FLOAT16)
191
+ explicit EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half(_Float16 b)
192
+ : half(__half_raw(__half_raw::construct_from_rep_tag(), b)) {}
193
+ #endif
177
194
 
178
- explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half(bool b)
195
+ explicit EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half(bool b)
179
196
  : half_impl::half_base(half_impl::raw_uint16_to_half(b ? 0x3c00 : 0)) {}
180
- template<class T>
197
+ template <class T>
181
198
  explicit EIGEN_DEVICE_FUNC half(T val)
182
199
  : half_impl::half_base(half_impl::float_to_half_rtne(static_cast<float>(val))) {}
183
- explicit EIGEN_DEVICE_FUNC half(float f)
184
- : half_impl::half_base(half_impl::float_to_half_rtne(f)) {}
200
+ explicit EIGEN_DEVICE_FUNC half(float f) : half_impl::half_base(half_impl::float_to_half_rtne(f)) {}
185
201
 
186
202
  // Following the convention of numpy, converting between complex and
187
203
  // float will lead to loss of imag value.
188
- template<typename RealScalar>
204
+ template <typename RealScalar>
189
205
  explicit EIGEN_DEVICE_FUNC half(std::complex<RealScalar> c)
190
206
  : half_impl::half_base(half_impl::float_to_half_rtne(static_cast<float>(c.real()))) {}
191
207
 
192
- EIGEN_DEVICE_FUNC operator float() const { // NOLINT: Allow implicit conversion to float, because it is lossless.
208
+ EIGEN_DEVICE_FUNC operator float() const { // NOLINT: Allow implicit conversion to float, because it is lossless.
193
209
  return half_impl::half_to_float(*this);
194
210
  }
195
211
 
@@ -202,69 +218,131 @@ struct half : public half_impl::half_base {
202
218
  #endif
203
219
  };
204
220
 
205
- } // end namespace Eigen
206
-
207
- namespace std {
208
- template<>
209
- struct numeric_limits<Eigen::half> {
210
- static const bool is_specialized = true;
211
- static const bool is_signed = true;
212
- static const bool is_integer = false;
213
- static const bool is_exact = false;
214
- static const bool has_infinity = true;
215
- static const bool has_quiet_NaN = true;
216
- static const bool has_signaling_NaN = true;
217
- static const float_denorm_style has_denorm = denorm_present;
218
- static const bool has_denorm_loss = false;
219
- static const std::float_round_style round_style = std::round_to_nearest;
220
- static const bool is_iec559 = false;
221
- static const bool is_bounded = false;
222
- static const bool is_modulo = false;
223
- static const int digits = 11;
224
- static const int digits10 = 3; // according to http://half.sourceforge.net/structstd_1_1numeric__limits_3_01half__float_1_1half_01_4.html
225
- static const int max_digits10 = 5; // according to http://half.sourceforge.net/structstd_1_1numeric__limits_3_01half__float_1_1half_01_4.html
226
- static const int radix = 2;
227
- static const int min_exponent = -13;
228
- static const int min_exponent10 = -4;
229
- static const int max_exponent = 16;
230
- static const int max_exponent10 = 4;
231
- static const bool traps = true;
232
- static const bool tinyness_before = false;
233
-
234
- static Eigen::half (min)() { return Eigen::half_impl::raw_uint16_to_half(0x400); }
235
- static Eigen::half lowest() { return Eigen::half_impl::raw_uint16_to_half(0xfbff); }
236
- static Eigen::half (max)() { return Eigen::half_impl::raw_uint16_to_half(0x7bff); }
237
- static Eigen::half epsilon() { return Eigen::half_impl::raw_uint16_to_half(0x0800); }
238
- static Eigen::half round_error() { return Eigen::half(0.5); }
239
- static Eigen::half infinity() { return Eigen::half_impl::raw_uint16_to_half(0x7c00); }
240
- static Eigen::half quiet_NaN() { return Eigen::half_impl::raw_uint16_to_half(0x7e00); }
241
- static Eigen::half signaling_NaN() { return Eigen::half_impl::raw_uint16_to_half(0x7d00); }
242
- static Eigen::half denorm_min() { return Eigen::half_impl::raw_uint16_to_half(0x1); }
221
+ // TODO(majnemer): Get rid of this once we can rely on C++17 inline variables do
222
+ // solve the ODR issue.
223
+ namespace half_impl {
224
+ template <typename = void>
225
+ struct numeric_limits_half_impl {
226
+ static constexpr const bool is_specialized = true;
227
+ static constexpr const bool is_signed = true;
228
+ static constexpr const bool is_integer = false;
229
+ static constexpr const bool is_exact = false;
230
+ static constexpr const bool has_infinity = true;
231
+ static constexpr const bool has_quiet_NaN = true;
232
+ static constexpr const bool has_signaling_NaN = true;
233
+ EIGEN_DIAGNOSTICS(push)
234
+ EIGEN_DISABLE_DEPRECATED_WARNING
235
+ static constexpr const std::float_denorm_style has_denorm = std::denorm_present;
236
+ static constexpr const bool has_denorm_loss = false;
237
+ EIGEN_DIAGNOSTICS(pop)
238
+ static constexpr const std::float_round_style round_style = std::round_to_nearest;
239
+ static constexpr const bool is_iec559 = true;
240
+ // The C++ standard defines this as "true if the set of values representable
241
+ // by the type is finite." Half has finite precision.
242
+ static constexpr const bool is_bounded = true;
243
+ static constexpr const bool is_modulo = false;
244
+ static constexpr const int digits = 11;
245
+ static constexpr const int digits10 =
246
+ 3; // according to http://half.sourceforge.net/structstd_1_1numeric__limits_3_01half__float_1_1half_01_4.html
247
+ static constexpr const int max_digits10 =
248
+ 5; // according to http://half.sourceforge.net/structstd_1_1numeric__limits_3_01half__float_1_1half_01_4.html
249
+ static constexpr const int radix = std::numeric_limits<float>::radix;
250
+ static constexpr const int min_exponent = -13;
251
+ static constexpr const int min_exponent10 = -4;
252
+ static constexpr const int max_exponent = 16;
253
+ static constexpr const int max_exponent10 = 4;
254
+ static constexpr const bool traps = std::numeric_limits<float>::traps;
255
+ // IEEE754: "The implementer shall choose how tininess is detected, but shall
256
+ // detect tininess in the same way for all operations in radix two"
257
+ static constexpr const bool tinyness_before = std::numeric_limits<float>::tinyness_before;
258
+
259
+ static _EIGEN_MAYBE_CONSTEXPR Eigen::half(min)() { return Eigen::half_impl::raw_uint16_to_half(0x0400); }
260
+ static _EIGEN_MAYBE_CONSTEXPR Eigen::half lowest() { return Eigen::half_impl::raw_uint16_to_half(0xfbff); }
261
+ static _EIGEN_MAYBE_CONSTEXPR Eigen::half(max)() { return Eigen::half_impl::raw_uint16_to_half(0x7bff); }
262
+ static _EIGEN_MAYBE_CONSTEXPR Eigen::half epsilon() { return Eigen::half_impl::raw_uint16_to_half(0x1400); }
263
+ static _EIGEN_MAYBE_CONSTEXPR Eigen::half round_error() { return Eigen::half_impl::raw_uint16_to_half(0x3800); }
264
+ static _EIGEN_MAYBE_CONSTEXPR Eigen::half infinity() { return Eigen::half_impl::raw_uint16_to_half(0x7c00); }
265
+ static _EIGEN_MAYBE_CONSTEXPR Eigen::half quiet_NaN() { return Eigen::half_impl::raw_uint16_to_half(0x7e00); }
266
+ static _EIGEN_MAYBE_CONSTEXPR Eigen::half signaling_NaN() { return Eigen::half_impl::raw_uint16_to_half(0x7d00); }
267
+ static _EIGEN_MAYBE_CONSTEXPR Eigen::half denorm_min() { return Eigen::half_impl::raw_uint16_to_half(0x0001); }
243
268
  };
244
269
 
270
+ template <typename T>
271
+ constexpr const bool numeric_limits_half_impl<T>::is_specialized;
272
+ template <typename T>
273
+ constexpr const bool numeric_limits_half_impl<T>::is_signed;
274
+ template <typename T>
275
+ constexpr const bool numeric_limits_half_impl<T>::is_integer;
276
+ template <typename T>
277
+ constexpr const bool numeric_limits_half_impl<T>::is_exact;
278
+ template <typename T>
279
+ constexpr const bool numeric_limits_half_impl<T>::has_infinity;
280
+ template <typename T>
281
+ constexpr const bool numeric_limits_half_impl<T>::has_quiet_NaN;
282
+ template <typename T>
283
+ constexpr const bool numeric_limits_half_impl<T>::has_signaling_NaN;
284
+ EIGEN_DIAGNOSTICS(push)
285
+ EIGEN_DISABLE_DEPRECATED_WARNING
286
+ template <typename T>
287
+ constexpr const std::float_denorm_style numeric_limits_half_impl<T>::has_denorm;
288
+ template <typename T>
289
+ constexpr const bool numeric_limits_half_impl<T>::has_denorm_loss;
290
+ EIGEN_DIAGNOSTICS(pop)
291
+ template <typename T>
292
+ constexpr const std::float_round_style numeric_limits_half_impl<T>::round_style;
293
+ template <typename T>
294
+ constexpr const bool numeric_limits_half_impl<T>::is_iec559;
295
+ template <typename T>
296
+ constexpr const bool numeric_limits_half_impl<T>::is_bounded;
297
+ template <typename T>
298
+ constexpr const bool numeric_limits_half_impl<T>::is_modulo;
299
+ template <typename T>
300
+ constexpr const int numeric_limits_half_impl<T>::digits;
301
+ template <typename T>
302
+ constexpr const int numeric_limits_half_impl<T>::digits10;
303
+ template <typename T>
304
+ constexpr const int numeric_limits_half_impl<T>::max_digits10;
305
+ template <typename T>
306
+ constexpr const int numeric_limits_half_impl<T>::radix;
307
+ template <typename T>
308
+ constexpr const int numeric_limits_half_impl<T>::min_exponent;
309
+ template <typename T>
310
+ constexpr const int numeric_limits_half_impl<T>::min_exponent10;
311
+ template <typename T>
312
+ constexpr const int numeric_limits_half_impl<T>::max_exponent;
313
+ template <typename T>
314
+ constexpr const int numeric_limits_half_impl<T>::max_exponent10;
315
+ template <typename T>
316
+ constexpr const bool numeric_limits_half_impl<T>::traps;
317
+ template <typename T>
318
+ constexpr const bool numeric_limits_half_impl<T>::tinyness_before;
319
+ } // end namespace half_impl
320
+ } // end namespace Eigen
321
+
322
+ namespace std {
245
323
  // If std::numeric_limits<T> is specialized, should also specialize
246
324
  // std::numeric_limits<const T>, std::numeric_limits<volatile T>, and
247
325
  // std::numeric_limits<const volatile T>
248
326
  // https://stackoverflow.com/a/16519653/
249
- template<>
250
- struct numeric_limits<const Eigen::half> : numeric_limits<Eigen::half> {};
251
- template<>
252
- struct numeric_limits<volatile Eigen::half> : numeric_limits<Eigen::half> {};
253
- template<>
254
- struct numeric_limits<const volatile Eigen::half> : numeric_limits<Eigen::half> {};
255
- } // end namespace std
327
+ template <>
328
+ class numeric_limits<Eigen::half> : public Eigen::half_impl::numeric_limits_half_impl<> {};
329
+ template <>
330
+ class numeric_limits<const Eigen::half> : public numeric_limits<Eigen::half> {};
331
+ template <>
332
+ class numeric_limits<volatile Eigen::half> : public numeric_limits<Eigen::half> {};
333
+ template <>
334
+ class numeric_limits<const volatile Eigen::half> : public numeric_limits<Eigen::half> {};
335
+ } // end namespace std
256
336
 
257
337
  namespace Eigen {
258
338
 
259
339
  namespace half_impl {
260
340
 
261
- #if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && \
262
- EIGEN_CUDA_ARCH >= 530) || \
341
+ #if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530) || \
263
342
  (defined(EIGEN_HAS_HIP_FP16) && defined(HIP_DEVICE_COMPILE))
264
- // Note: We deliberatly do *not* define this to 1 even if we have Arm's native
265
- // fp16 type since GPU halfs are rather different from native CPU halfs.
266
- // TODO: Rename to something like EIGEN_HAS_NATIVE_GPU_FP16
267
- #define EIGEN_HAS_NATIVE_FP16
343
+ // Note: We deliberately do *not* define this to 1 even if we have Arm's native
344
+ // fp16 type since GPU half types are rather different from native CPU half types.
345
+ #define EIGEN_HAS_NATIVE_GPU_FP16
268
346
  #endif
269
347
 
270
348
  // Intrinsics for native fp16 support. Note that on current hardware,
@@ -272,21 +350,17 @@ namespace half_impl {
272
350
  // versions to get the ALU speed increased), but you do save the
273
351
  // conversion steps back and forth.
274
352
 
275
- #if defined(EIGEN_HAS_NATIVE_FP16)
276
- EIGEN_STRONG_INLINE __device__ half operator + (const half& a, const half& b) {
353
+ #if defined(EIGEN_HAS_NATIVE_GPU_FP16)
354
+ EIGEN_STRONG_INLINE __device__ half operator+(const half& a, const half& b) {
277
355
  #if defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER >= 90000
278
356
  return __hadd(::__half(a), ::__half(b));
279
357
  #else
280
358
  return __hadd(a, b);
281
359
  #endif
282
360
  }
283
- EIGEN_STRONG_INLINE __device__ half operator * (const half& a, const half& b) {
284
- return __hmul(a, b);
285
- }
286
- EIGEN_STRONG_INLINE __device__ half operator - (const half& a, const half& b) {
287
- return __hsub(a, b);
288
- }
289
- EIGEN_STRONG_INLINE __device__ half operator / (const half& a, const half& b) {
361
+ EIGEN_STRONG_INLINE __device__ half operator*(const half& a, const half& b) { return __hmul(a, b); }
362
+ EIGEN_STRONG_INLINE __device__ half operator-(const half& a, const half& b) { return __hsub(a, b); }
363
+ EIGEN_STRONG_INLINE __device__ half operator/(const half& a, const half& b) {
290
364
  #if defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER >= 90000
291
365
  return __hdiv(a, b);
292
366
  #else
@@ -295,173 +369,194 @@ EIGEN_STRONG_INLINE __device__ half operator / (const half& a, const half& b) {
295
369
  return __float2half(num / denom);
296
370
  #endif
297
371
  }
298
- EIGEN_STRONG_INLINE __device__ half operator - (const half& a) {
299
- return __hneg(a);
300
- }
301
- EIGEN_STRONG_INLINE __device__ half& operator += (half& a, const half& b) {
372
+ EIGEN_STRONG_INLINE __device__ half operator-(const half& a) { return __hneg(a); }
373
+ EIGEN_STRONG_INLINE __device__ half& operator+=(half& a, const half& b) {
302
374
  a = a + b;
303
375
  return a;
304
376
  }
305
- EIGEN_STRONG_INLINE __device__ half& operator *= (half& a, const half& b) {
377
+ EIGEN_STRONG_INLINE __device__ half& operator*=(half& a, const half& b) {
306
378
  a = a * b;
307
379
  return a;
308
380
  }
309
- EIGEN_STRONG_INLINE __device__ half& operator -= (half& a, const half& b) {
381
+ EIGEN_STRONG_INLINE __device__ half& operator-=(half& a, const half& b) {
310
382
  a = a - b;
311
383
  return a;
312
384
  }
313
- EIGEN_STRONG_INLINE __device__ half& operator /= (half& a, const half& b) {
385
+ EIGEN_STRONG_INLINE __device__ half& operator/=(half& a, const half& b) {
314
386
  a = a / b;
315
387
  return a;
316
388
  }
317
- EIGEN_STRONG_INLINE __device__ bool operator == (const half& a, const half& b) {
318
- return __heq(a, b);
319
- }
320
- EIGEN_STRONG_INLINE __device__ bool operator != (const half& a, const half& b) {
321
- return __hne(a, b);
322
- }
323
- EIGEN_STRONG_INLINE __device__ bool operator < (const half& a, const half& b) {
324
- return __hlt(a, b);
325
- }
326
- EIGEN_STRONG_INLINE __device__ bool operator <= (const half& a, const half& b) {
327
- return __hle(a, b);
328
- }
329
- EIGEN_STRONG_INLINE __device__ bool operator > (const half& a, const half& b) {
330
- return __hgt(a, b);
331
- }
332
- EIGEN_STRONG_INLINE __device__ bool operator >= (const half& a, const half& b) {
333
- return __hge(a, b);
334
- }
335
- #endif
336
-
337
- #if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
338
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator + (const half& a, const half& b) {
339
- return half(vaddh_f16(a.x, b.x));
340
- }
341
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator * (const half& a, const half& b) {
342
- return half(vmulh_f16(a.x, b.x));
343
- }
344
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator - (const half& a, const half& b) {
345
- return half(vsubh_f16(a.x, b.x));
346
- }
347
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator / (const half& a, const half& b) {
348
- return half(vdivh_f16(a.x, b.x));
349
- }
350
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator - (const half& a) {
351
- return half(vnegh_f16(a.x));
352
- }
353
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator += (half& a, const half& b) {
389
+ EIGEN_STRONG_INLINE __device__ bool operator==(const half& a, const half& b) { return __heq(a, b); }
390
+ EIGEN_STRONG_INLINE __device__ bool operator!=(const half& a, const half& b) { return __hne(a, b); }
391
+ EIGEN_STRONG_INLINE __device__ bool operator<(const half& a, const half& b) { return __hlt(a, b); }
392
+ EIGEN_STRONG_INLINE __device__ bool operator<=(const half& a, const half& b) { return __hle(a, b); }
393
+ EIGEN_STRONG_INLINE __device__ bool operator>(const half& a, const half& b) { return __hgt(a, b); }
394
+ EIGEN_STRONG_INLINE __device__ bool operator>=(const half& a, const half& b) { return __hge(a, b); }
395
+
396
+ #endif // EIGEN_HAS_NATIVE_GPU_FP16
397
+
398
+ #if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) && !defined(EIGEN_GPU_COMPILE_PHASE)
399
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator+(const half& a, const half& b) { return half(vaddh_f16(a.x, b.x)); }
400
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator*(const half& a, const half& b) { return half(vmulh_f16(a.x, b.x)); }
401
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator-(const half& a, const half& b) { return half(vsubh_f16(a.x, b.x)); }
402
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator/(const half& a, const half& b) { return half(vdivh_f16(a.x, b.x)); }
403
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator-(const half& a) { return half(vnegh_f16(a.x)); }
404
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator+=(half& a, const half& b) {
354
405
  a = half(vaddh_f16(a.x, b.x));
355
406
  return a;
356
407
  }
357
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator *= (half& a, const half& b) {
408
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator*=(half& a, const half& b) {
358
409
  a = half(vmulh_f16(a.x, b.x));
359
410
  return a;
360
411
  }
361
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator -= (half& a, const half& b) {
412
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator-=(half& a, const half& b) {
362
413
  a = half(vsubh_f16(a.x, b.x));
363
414
  return a;
364
415
  }
365
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator /= (half& a, const half& b) {
416
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator/=(half& a, const half& b) {
366
417
  a = half(vdivh_f16(a.x, b.x));
367
418
  return a;
368
419
  }
369
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator == (const half& a, const half& b) {
370
- return vceqh_f16(a.x, b.x);
371
- }
372
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator != (const half& a, const half& b) {
373
- return !vceqh_f16(a.x, b.x);
374
- }
375
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator < (const half& a, const half& b) {
376
- return vclth_f16(a.x, b.x);
420
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator==(const half& a, const half& b) { return vceqh_f16(a.x, b.x); }
421
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator!=(const half& a, const half& b) { return !vceqh_f16(a.x, b.x); }
422
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<(const half& a, const half& b) { return vclth_f16(a.x, b.x); }
423
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<=(const half& a, const half& b) { return vcleh_f16(a.x, b.x); }
424
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>(const half& a, const half& b) { return vcgth_f16(a.x, b.x); }
425
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>=(const half& a, const half& b) { return vcgeh_f16(a.x, b.x); }
426
+
427
+ #elif defined(EIGEN_HAS_BUILTIN_FLOAT16) && !defined(EIGEN_GPU_COMPILE_PHASE)
428
+
429
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator+(const half& a, const half& b) { return half(a.x + b.x); }
430
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator*(const half& a, const half& b) { return half(a.x * b.x); }
431
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator-(const half& a, const half& b) { return half(a.x - b.x); }
432
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator/(const half& a, const half& b) { return half(a.x / b.x); }
433
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator-(const half& a) { return half(-a.x); }
434
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator+=(half& a, const half& b) {
435
+ a = a + b;
436
+ return a;
377
437
  }
378
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator <= (const half& a, const half& b) {
379
- return vcleh_f16(a.x, b.x);
438
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator*=(half& a, const half& b) {
439
+ a = a * b;
440
+ return a;
380
441
  }
381
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator > (const half& a, const half& b) {
382
- return vcgth_f16(a.x, b.x);
442
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator-=(half& a, const half& b) {
443
+ a = a - b;
444
+ return a;
383
445
  }
384
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator >= (const half& a, const half& b) {
385
- return vcgeh_f16(a.x, b.x);
446
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator/=(half& a, const half& b) {
447
+ a = a / b;
448
+ return a;
386
449
  }
450
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator==(const half& a, const half& b) { return a.x == b.x; }
451
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator!=(const half& a, const half& b) { return a.x != b.x; }
452
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<(const half& a, const half& b) { return a.x < b.x; }
453
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<=(const half& a, const half& b) { return a.x <= b.x; }
454
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>(const half& a, const half& b) { return a.x > b.x; }
455
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>=(const half& a, const half& b) { return a.x >= b.x; }
456
+
387
457
  // We need to distinguish ‘clang as the CUDA compiler’ from ‘clang as the host compiler,
388
458
  // invoked by NVCC’ (e.g. on MacOS). The former needs to see both host and device implementation
389
459
  // of the functions, while the latter can only deal with one of them.
390
- #elif !defined(EIGEN_HAS_NATIVE_FP16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC) // Emulate support for half floats
460
+ #elif !defined(EIGEN_HAS_NATIVE_GPU_FP16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC) // Emulate support for half floats
391
461
 
392
- #if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
462
+ #if EIGEN_COMP_CLANG && defined(EIGEN_GPUCC)
393
463
  // We need to provide emulated *host-side* FP16 operators for clang.
394
464
  #pragma push_macro("EIGEN_DEVICE_FUNC")
395
465
  #undef EIGEN_DEVICE_FUNC
396
- #if defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_HAS_NATIVE_FP16)
466
+ #if defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_HAS_NATIVE_GPU_FP16)
397
467
  #define EIGEN_DEVICE_FUNC __host__
398
- #else // both host and device need emulated ops.
468
+ #else // both host and device need emulated ops.
399
469
  #define EIGEN_DEVICE_FUNC __host__ __device__
400
470
  #endif
401
471
  #endif
402
472
 
403
473
  // Definitions for CPUs and older HIP+CUDA, mostly working through conversion
404
474
  // to/from fp32.
405
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator + (const half& a, const half& b) {
406
- return half(float(a) + float(b));
407
- }
408
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator * (const half& a, const half& b) {
409
- return half(float(a) * float(b));
410
- }
411
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator - (const half& a, const half& b) {
412
- return half(float(a) - float(b));
413
- }
414
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator / (const half& a, const half& b) {
415
- return half(float(a) / float(b));
416
- }
417
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator - (const half& a) {
475
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator+(const half& a, const half& b) { return half(float(a) + float(b)); }
476
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator*(const half& a, const half& b) { return half(float(a) * float(b)); }
477
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator-(const half& a, const half& b) { return half(float(a) - float(b)); }
478
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator/(const half& a, const half& b) { return half(float(a) / float(b)); }
479
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator-(const half& a) {
418
480
  half result;
419
481
  result.x = a.x ^ 0x8000;
420
482
  return result;
421
483
  }
422
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator += (half& a, const half& b) {
484
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator+=(half& a, const half& b) {
423
485
  a = half(float(a) + float(b));
424
486
  return a;
425
487
  }
426
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator *= (half& a, const half& b) {
488
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator*=(half& a, const half& b) {
427
489
  a = half(float(a) * float(b));
428
490
  return a;
429
491
  }
430
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator -= (half& a, const half& b) {
492
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator-=(half& a, const half& b) {
431
493
  a = half(float(a) - float(b));
432
494
  return a;
433
495
  }
434
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator /= (half& a, const half& b) {
496
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator/=(half& a, const half& b) {
435
497
  a = half(float(a) / float(b));
436
498
  return a;
437
499
  }
438
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator == (const half& a, const half& b) {
439
- return numext::equal_strict(float(a),float(b));
440
- }
441
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator != (const half& a, const half& b) {
442
- return numext::not_equal_strict(float(a), float(b));
500
+
501
+ // Non-negative floating point numbers have a monotonic mapping to non-negative integers.
502
+ // This property allows floating point numbers to be reinterpreted as integers for comparisons, which is useful if there
503
+ // is no native floating point comparison operator. Floating point signedness is handled by the sign-magnitude
504
+ // representation, whereas integers typically use two's complement. Converting the bit pattern from sign-magnitude to
505
+ // two's complement allows the transformed bit patterns be compared as signed integers. All edge cases (+/-0 and +/-
506
+ // infinity) are handled automatically, except NaN.
507
+ //
508
+ // fp16 uses 1 sign bit, 5 exponent bits, and 10 mantissa bits. The bit pattern conveys NaN when all the exponent
509
+ // bits (5) are set, and at least one mantissa bit is set. The sign bit is irrelevant for determining NaN. To check for
510
+ // NaN, clear the sign bit and check if the integral representation is greater than 01111100000000. To test
511
+ // for non-NaN, clear the sign bit and check if the integeral representation is less than or equal to 01111100000000.
512
+
513
+ // convert sign-magnitude representation to two's complement
514
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC int16_t mapToSigned(uint16_t a) {
515
+ constexpr uint16_t kAbsMask = (1 << 15) - 1;
516
+ // If the sign bit is set, clear the sign bit and return the (integer) negation. Otherwise, return the input.
517
+ return (a >> 15) ? -(a & kAbsMask) : a;
518
+ }
519
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool isOrdered(const half& a, const half& b) {
520
+ constexpr uint16_t kInf = ((1 << 5) - 1) << 10;
521
+ constexpr uint16_t kAbsMask = (1 << 15) - 1;
522
+ return numext::maxi(a.x & kAbsMask, b.x & kAbsMask) <= kInf;
523
+ }
524
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator==(const half& a, const half& b) {
525
+ bool result = mapToSigned(a.x) == mapToSigned(b.x);
526
+ result &= isOrdered(a, b);
527
+ return result;
443
528
  }
444
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator < (const half& a, const half& b) {
445
- return float(a) < float(b);
529
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator!=(const half& a, const half& b) { return !(a == b); }
530
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<(const half& a, const half& b) {
531
+ bool result = mapToSigned(a.x) < mapToSigned(b.x);
532
+ result &= isOrdered(a, b);
533
+ return result;
446
534
  }
447
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator <= (const half& a, const half& b) {
448
- return float(a) <= float(b);
535
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<=(const half& a, const half& b) {
536
+ bool result = mapToSigned(a.x) <= mapToSigned(b.x);
537
+ result &= isOrdered(a, b);
538
+ return result;
449
539
  }
450
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator > (const half& a, const half& b) {
451
- return float(a) > float(b);
540
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>(const half& a, const half& b) {
541
+ bool result = mapToSigned(a.x) > mapToSigned(b.x);
542
+ result &= isOrdered(a, b);
543
+ return result;
452
544
  }
453
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator >= (const half& a, const half& b) {
454
- return float(a) >= float(b);
545
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>=(const half& a, const half& b) {
546
+ bool result = mapToSigned(a.x) >= mapToSigned(b.x);
547
+ result &= isOrdered(a, b);
548
+ return result;
455
549
  }
456
550
 
457
- #if defined(__clang__) && defined(__CUDA__)
551
+ #if EIGEN_COMP_CLANG && defined(EIGEN_GPUCC)
458
552
  #pragma pop_macro("EIGEN_DEVICE_FUNC")
459
553
  #endif
554
+
460
555
  #endif // Emulate support for half floats
461
556
 
462
557
  // Division by an index. Do it in full float precision to avoid accuracy
463
558
  // issues in converting the denominator to half.
464
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator / (const half& a, Index b) {
559
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator/(const half& a, Index b) {
465
560
  return half(static_cast<float>(a) / static_cast<float>(b));
466
561
  }
467
562
 
@@ -492,7 +587,7 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator--(half& a, int) {
492
587
  // these in hardware. If we need more performance on older/other CPUs, they are
493
588
  // also possible to vectorize directly.
494
589
 
495
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw raw_uint16_to_half(numext::uint16_t x) {
590
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR __half_raw raw_uint16_to_half(numext::uint16_t x) {
496
591
  // We cannot simply do a "return __half_raw(x)" here, because __half_raw is union type
497
592
  // in the hip_fp16 header file, and that will trigger a compile error
498
593
  // On the other hand, having anything but a return statement also triggers a compile error
@@ -500,8 +595,8 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw raw_uint16_to_h
500
595
  // Fortunately, since we need to disable EIGEN_CONSTEXPR for GPU anyway, we can get out
501
596
  // of this catch22 by having separate bodies for GPU / non GPU
502
597
  #if defined(EIGEN_HAS_GPU_FP16)
503
- __half_raw h;
504
- h.x = x;
598
+ __half_raw h;
599
+ h.x = x;
505
600
  return h;
506
601
  #else
507
602
  return __half_raw(x);
@@ -514,6 +609,8 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC numext::uint16_t raw_half_as_uint16(const
514
609
  // For SYCL, cl::sycl::half is _Float16, so cast directly.
515
610
  #if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
516
611
  return numext::bit_cast<numext::uint16_t>(h.x);
612
+ #elif defined(EIGEN_HAS_BUILTIN_FLOAT16)
613
+ return numext::bit_cast<numext::uint16_t>(h.x);
517
614
  #elif defined(SYCL_DEVICE_ONLY)
518
615
  return numext::bit_cast<numext::uint16_t>(h);
519
616
  #else
@@ -521,67 +618,72 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC numext::uint16_t raw_half_as_uint16(const
521
618
  #endif
522
619
  }
523
620
 
524
- union float32_bits {
525
- unsigned int u;
526
- float f;
527
- };
528
-
529
621
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half_raw float_to_half_rtne(float ff) {
530
622
  #if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
531
- (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
623
+ (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
532
624
  __half tmp_ff = __float2half(ff);
533
625
  return *(__half_raw*)&tmp_ff;
534
626
 
535
- #elif defined(EIGEN_HAS_FP16_C)
627
+ #elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
536
628
  __half_raw h;
537
- h.x = _cvtss_sh(ff, 0);
629
+ h.x = static_cast<__fp16>(ff);
538
630
  return h;
539
631
 
540
- #elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
632
+ #elif defined(EIGEN_HAS_BUILTIN_FLOAT16)
541
633
  __half_raw h;
542
- h.x = static_cast<__fp16>(ff);
634
+ h.x = static_cast<_Float16>(ff);
543
635
  return h;
544
636
 
637
+ #elif defined(EIGEN_HAS_FP16_C)
638
+ __half_raw h;
639
+ #if EIGEN_COMP_MSVC
640
+ // MSVC does not have scalar instructions.
641
+ h.x = _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(ff), 0), 0);
545
642
  #else
546
- float32_bits f; f.f = ff;
643
+ h.x = _cvtss_sh(ff, 0);
644
+ #endif
645
+ return h;
547
646
 
548
- const float32_bits f32infty = { 255 << 23 };
549
- const float32_bits f16max = { (127 + 16) << 23 };
550
- const float32_bits denorm_magic = { ((127 - 15) + (23 - 10) + 1) << 23 };
551
- unsigned int sign_mask = 0x80000000u;
647
+ #else
648
+ uint32_t f_bits = Eigen::numext::bit_cast<uint32_t>(ff);
649
+ const uint32_t f32infty_bits = {255 << 23};
650
+ const uint32_t f16max_bits = {(127 + 16) << 23};
651
+ const uint32_t denorm_magic_bits = {((127 - 15) + (23 - 10) + 1) << 23};
652
+ const uint32_t sign_mask = 0x80000000u;
552
653
  __half_raw o;
553
- o.x = static_cast<numext::uint16_t>(0x0u);
654
+ o.x = static_cast<uint16_t>(0x0u);
554
655
 
555
- unsigned int sign = f.u & sign_mask;
556
- f.u ^= sign;
656
+ const uint32_t sign = f_bits & sign_mask;
657
+ f_bits ^= sign;
557
658
 
558
659
  // NOTE all the integer compares in this function can be safely
559
660
  // compiled into signed compares since all operands are below
560
661
  // 0x80000000. Important if you want fast straight SSE2 code
561
662
  // (since there's no unsigned PCMPGTD).
562
663
 
563
- if (f.u >= f16max.u) { // result is Inf or NaN (all exponent bits set)
564
- o.x = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf
565
- } else { // (De)normalized number or zero
566
- if (f.u < (113 << 23)) { // resulting FP16 is subnormal or zero
664
+ if (f_bits >= f16max_bits) { // result is Inf or NaN (all exponent bits set)
665
+ o.x = (f_bits > f32infty_bits) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf
666
+ } else { // (De)normalized number or zero
667
+ if (f_bits < (113 << 23)) { // resulting FP16 is subnormal or zero
567
668
  // use a magic value to align our 10 mantissa bits at the bottom of
568
669
  // the float. as long as FP addition is round-to-nearest-even this
569
670
  // just works.
570
- f.f += denorm_magic.f;
671
+ f_bits = Eigen::numext::bit_cast<uint32_t>(Eigen::numext::bit_cast<float>(f_bits) +
672
+ Eigen::numext::bit_cast<float>(denorm_magic_bits));
571
673
 
572
674
  // and one integer subtract of the bias later, we have our final float!
573
- o.x = static_cast<numext::uint16_t>(f.u - denorm_magic.u);
675
+ o.x = static_cast<numext::uint16_t>(f_bits - denorm_magic_bits);
574
676
  } else {
575
- unsigned int mant_odd = (f.u >> 13) & 1; // resulting mantissa is odd
677
+ const uint32_t mant_odd = (f_bits >> 13) & 1; // resulting mantissa is odd
576
678
 
577
679
  // update exponent, rounding bias part 1
578
680
  // Equivalent to `f.u += ((unsigned int)(15 - 127) << 23) + 0xfff`, but
579
681
  // without arithmetic overflow.
580
- f.u += 0xc8000fffU;
682
+ f_bits += 0xc8000fffU;
581
683
  // rounding bias part 2
582
- f.u += mant_odd;
684
+ f_bits += mant_odd;
583
685
  // take the bits!
584
- o.x = static_cast<numext::uint16_t>(f.u >> 13);
686
+ o.x = static_cast<numext::uint16_t>(f_bits >> 13);
585
687
  }
586
688
  }
587
689
 
@@ -592,60 +694,73 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half_raw float_to_half_rtne(float ff) {
592
694
 
593
695
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float half_to_float(__half_raw h) {
594
696
  #if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
595
- (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
697
+ (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
596
698
  return __half2float(h);
699
+ #elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) || defined(EIGEN_HAS_BUILTIN_FLOAT16)
700
+ return static_cast<float>(h.x);
597
701
  #elif defined(EIGEN_HAS_FP16_C)
702
+ #if EIGEN_COMP_MSVC
703
+ // MSVC does not have scalar instructions.
704
+ return _mm_cvtss_f32(_mm_cvtph_ps(_mm_set1_epi16(h.x)));
705
+ #else
598
706
  return _cvtsh_ss(h.x);
599
- #elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
600
- return static_cast<float>(h.x);
707
+ #endif
601
708
  #else
602
- const float32_bits magic = { 113 << 23 };
603
- const unsigned int shifted_exp = 0x7c00 << 13; // exponent mask after shift
604
- float32_bits o;
605
-
606
- o.u = (h.x & 0x7fff) << 13; // exponent/mantissa bits
607
- unsigned int exp = shifted_exp & o.u; // just the exponent
608
- o.u += (127 - 15) << 23; // exponent adjust
709
+ const float magic = Eigen::numext::bit_cast<float>(static_cast<uint32_t>(113 << 23));
710
+ const uint32_t shifted_exp = 0x7c00 << 13; // exponent mask after shift
711
+ uint32_t o_bits = (h.x & 0x7fff) << 13; // exponent/mantissa bits
712
+ const uint32_t exp = shifted_exp & o_bits; // just the exponent
713
+ o_bits += (127 - 15) << 23; // exponent adjust
609
714
 
610
715
  // handle exponent special cases
611
- if (exp == shifted_exp) { // Inf/NaN?
612
- o.u += (128 - 16) << 23; // extra exp adjust
613
- } else if (exp == 0) { // Zero/Denormal?
614
- o.u += 1 << 23; // extra exp adjust
615
- o.f -= magic.f; // renormalize
716
+ if (exp == shifted_exp) { // Inf/NaN?
717
+ o_bits += (128 - 16) << 23; // extra exp adjust
718
+ } else if (exp == 0) { // Zero/Denormal?
719
+ o_bits += 1 << 23; // extra exp adjust
720
+ // renormalize
721
+ o_bits = Eigen::numext::bit_cast<uint32_t>(Eigen::numext::bit_cast<float>(o_bits) - magic);
616
722
  }
617
723
 
618
- o.u |= (h.x & 0x8000) << 16; // sign bit
619
- return o.f;
724
+ o_bits |= (h.x & 0x8000) << 16; // sign bit
725
+ return Eigen::numext::bit_cast<float>(o_bits);
620
726
  #endif
621
727
  }
622
728
 
623
729
  // --- standard functions ---
624
730
 
625
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isinf)(const half& a) {
626
- #ifdef EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC
731
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isinf)(const half& a) {
732
+ #if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) || defined(EIGEN_HAS_BUILTIN_FLOAT16)
627
733
  return (numext::bit_cast<numext::uint16_t>(a.x) & 0x7fff) == 0x7c00;
628
734
  #else
629
735
  return (a.x & 0x7fff) == 0x7c00;
630
736
  #endif
631
737
  }
632
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isnan)(const half& a) {
738
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isnan)(const half& a) {
633
739
  #if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530) || \
634
- (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
740
+ (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
635
741
  return __hisnan(a);
636
- #elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
742
+ #elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) || defined(EIGEN_HAS_BUILTIN_FLOAT16)
637
743
  return (numext::bit_cast<numext::uint16_t>(a.x) & 0x7fff) > 0x7c00;
638
744
  #else
639
745
  return (a.x & 0x7fff) > 0x7c00;
640
746
  #endif
641
747
  }
642
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isfinite)(const half& a) {
643
- return !(isinf EIGEN_NOT_A_MACRO (a)) && !(isnan EIGEN_NOT_A_MACRO (a));
748
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isfinite)(const half& a) {
749
+ #if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) || defined(EIGEN_HAS_BUILTIN_FLOAT16)
750
+ return (numext::bit_cast<numext::uint16_t>(a.x) & 0x7fff) < 0x7c00;
751
+ #else
752
+ return (a.x & 0x7fff) < 0x7c00;
753
+ #endif
644
754
  }
645
755
 
646
756
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half abs(const half& a) {
647
757
  #if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
648
758
  return half(vabsh_f16(a.x));
759
+ #elif defined(EIGEN_HAS_BUILTIN_FLOAT16)
760
+ half result;
761
+ result.x =
762
+ numext::bit_cast<_Float16>(static_cast<numext::uint16_t>(numext::bit_cast<numext::uint16_t>(a.x) & 0x7FFF));
763
+ return result;
649
764
  #else
650
765
  half result;
651
766
  result.x = a.x & 0x7FFF;
@@ -654,65 +769,61 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half abs(const half& a) {
654
769
  }
655
770
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half exp(const half& a) {
656
771
  #if (EIGEN_CUDA_SDK_VER >= 80000 && defined EIGEN_CUDA_ARCH && EIGEN_CUDA_ARCH >= 530) || \
657
- defined(EIGEN_HIP_DEVICE_COMPILE)
772
+ defined(EIGEN_HIP_DEVICE_COMPILE)
658
773
  return half(hexp(a));
659
774
  #else
660
- return half(::expf(float(a)));
775
+ return half(::expf(float(a)));
661
776
  #endif
662
777
  }
663
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half expm1(const half& a) {
664
- return half(numext::expm1(float(a)));
778
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half exp2(const half& a) {
779
+ #if (EIGEN_CUDA_SDK_VER >= 80000 && defined EIGEN_CUDA_ARCH && EIGEN_CUDA_ARCH >= 530) || \
780
+ defined(EIGEN_HIP_DEVICE_COMPILE)
781
+ return half(hexp2(a));
782
+ #else
783
+ return half(::exp2f(float(a)));
784
+ #endif
665
785
  }
786
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half expm1(const half& a) { return half(numext::expm1(float(a))); }
666
787
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half log(const half& a) {
667
- #if (defined(EIGEN_HAS_CUDA_FP16) && EIGEN_CUDA_SDK_VER >= 80000 && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530) || \
668
- (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
669
- return half(::hlog(a));
788
+ #if (defined(EIGEN_HAS_CUDA_FP16) && EIGEN_CUDA_SDK_VER >= 80000 && defined(EIGEN_CUDA_ARCH) && \
789
+ EIGEN_CUDA_ARCH >= 530) || \
790
+ (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
791
+ return half(hlog(a));
670
792
  #else
671
793
  return half(::logf(float(a)));
672
794
  #endif
673
795
  }
674
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half log1p(const half& a) {
675
- return half(numext::log1p(float(a)));
676
- }
677
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half log10(const half& a) {
678
- return half(::log10f(float(a)));
679
- }
796
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half log1p(const half& a) { return half(numext::log1p(float(a))); }
797
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half log10(const half& a) { return half(::log10f(float(a))); }
680
798
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half log2(const half& a) {
681
799
  return half(static_cast<float>(EIGEN_LOG2E) * ::logf(float(a)));
682
800
  }
683
801
 
684
802
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half sqrt(const half& a) {
685
803
  #if (EIGEN_CUDA_SDK_VER >= 80000 && defined EIGEN_CUDA_ARCH && EIGEN_CUDA_ARCH >= 530) || \
686
- defined(EIGEN_HIP_DEVICE_COMPILE)
804
+ defined(EIGEN_HIP_DEVICE_COMPILE)
687
805
  return half(hsqrt(a));
688
806
  #else
689
- return half(::sqrtf(float(a)));
807
+ return half(::sqrtf(float(a)));
690
808
  #endif
691
809
  }
692
810
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half pow(const half& a, const half& b) {
693
811
  return half(::powf(float(a), float(b)));
694
812
  }
695
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half sin(const half& a) {
696
- return half(::sinf(float(a)));
697
- }
698
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half cos(const half& a) {
699
- return half(::cosf(float(a)));
700
- }
701
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half tan(const half& a) {
702
- return half(::tanf(float(a)));
703
- }
704
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half tanh(const half& a) {
705
- return half(::tanhf(float(a)));
706
- }
707
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half asin(const half& a) {
708
- return half(::asinf(float(a)));
709
- }
710
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half acos(const half& a) {
711
- return half(::acosf(float(a)));
712
- }
813
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half atan2(const half& a, const half& b) {
814
+ return half(::atan2f(float(a), float(b)));
815
+ }
816
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half sin(const half& a) { return half(::sinf(float(a))); }
817
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half cos(const half& a) { return half(::cosf(float(a))); }
818
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half tan(const half& a) { return half(::tanf(float(a))); }
819
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half tanh(const half& a) { return half(::tanhf(float(a))); }
820
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half asin(const half& a) { return half(::asinf(float(a))); }
821
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half acos(const half& a) { return half(::acosf(float(a))); }
822
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half atan(const half& a) { return half(::atanf(float(a))); }
823
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half atanh(const half& a) { return half(::atanhf(float(a))); }
713
824
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half floor(const half& a) {
714
825
  #if (EIGEN_CUDA_SDK_VER >= 80000 && defined EIGEN_CUDA_ARCH && EIGEN_CUDA_ARCH >= 300) || \
715
- defined(EIGEN_HIP_DEVICE_COMPILE)
826
+ defined(EIGEN_HIP_DEVICE_COMPILE)
716
827
  return half(hfloor(a));
717
828
  #else
718
829
  return half(::floorf(float(a)));
@@ -720,109 +831,97 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half floor(const half& a) {
720
831
  }
721
832
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half ceil(const half& a) {
722
833
  #if (EIGEN_CUDA_SDK_VER >= 80000 && defined EIGEN_CUDA_ARCH && EIGEN_CUDA_ARCH >= 300) || \
723
- defined(EIGEN_HIP_DEVICE_COMPILE)
834
+ defined(EIGEN_HIP_DEVICE_COMPILE)
724
835
  return half(hceil(a));
725
836
  #else
726
837
  return half(::ceilf(float(a)));
727
838
  #endif
728
839
  }
729
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half rint(const half& a) {
730
- return half(::rintf(float(a)));
731
- }
732
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half round(const half& a) {
733
- return half(::roundf(float(a)));
734
- }
840
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half rint(const half& a) { return half(::rintf(float(a))); }
841
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half round(const half& a) { return half(::roundf(float(a))); }
842
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half trunc(const half& a) { return half(::truncf(float(a))); }
735
843
  EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half fmod(const half& a, const half& b) {
736
844
  return half(::fmodf(float(a), float(b)));
737
845
  }
738
846
 
739
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half (min)(const half& a, const half& b) {
740
- #if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530) || \
741
- (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
742
- return __hlt(b, a) ? b : a;
743
- #else
744
- const float f1 = static_cast<float>(a);
745
- const float f2 = static_cast<float>(b);
746
- return f2 < f1 ? b : a;
747
- #endif
748
- }
749
- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half (max)(const half& a, const half& b) {
750
- #if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530) || \
751
- (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
752
- return __hlt(a, b) ? b : a;
847
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half(min)(const half& a, const half& b) { return b < a ? b : a; }
848
+
849
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half(max)(const half& a, const half& b) { return a < b ? b : a; }
850
+
851
+ EIGEN_DEVICE_FUNC inline half fma(const half& a, const half& b, const half& c) {
852
+ #if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
853
+ return half(vfmah_f16(c.x, a.x, b.x));
854
+ #elif defined(EIGEN_VECTORIZE_AVX512FP16)
855
+ // Reduces to vfmadd213sh.
856
+ return half(_mm_cvtsh_h(_mm_fmadd_ph(_mm_set_sh(a.x), _mm_set_sh(b.x), _mm_set_sh(c.x))));
753
857
  #else
754
- const float f1 = static_cast<float>(a);
755
- const float f2 = static_cast<float>(b);
756
- return f1 < f2 ? b : a;
858
+ // Emulate FMA via float.
859
+ return half(numext::fma(static_cast<float>(a), static_cast<float>(b), static_cast<float>(c)));
757
860
  #endif
758
861
  }
759
862
 
760
863
  #ifndef EIGEN_NO_IO
761
- EIGEN_ALWAYS_INLINE std::ostream& operator << (std::ostream& os, const half& v) {
864
+ EIGEN_ALWAYS_INLINE std::ostream& operator<<(std::ostream& os, const half& v) {
762
865
  os << static_cast<float>(v);
763
866
  return os;
764
867
  }
765
868
  #endif
766
869
 
767
- } // end namespace half_impl
870
+ } // end namespace half_impl
768
871
 
769
872
  // import Eigen::half_impl::half into Eigen namespace
770
873
  // using half_impl::half;
771
874
 
772
875
  namespace internal {
773
876
 
774
- template<>
775
- struct random_default_impl<half, false, false>
776
- {
777
- static inline half run(const half& x, const half& y)
778
- {
779
- return x + (y-x) * half(float(std::rand()) / float(RAND_MAX));
877
+ template <>
878
+ struct is_arithmetic<half> {
879
+ enum { value = true };
880
+ };
881
+
882
+ template <>
883
+ struct random_impl<half> {
884
+ enum : int { MantissaBits = 10 };
885
+ using Impl = random_impl<float>;
886
+ static EIGEN_DEVICE_FUNC inline half run(const half& x, const half& y) {
887
+ float result = Impl::run(x, y, MantissaBits);
888
+ return half(result);
780
889
  }
781
- static inline half run()
782
- {
783
- return run(half(-1.f), half(1.f));
890
+ static EIGEN_DEVICE_FUNC inline half run() {
891
+ float result = Impl::run(MantissaBits);
892
+ return half(result);
784
893
  }
785
894
  };
786
895
 
787
- template<> struct is_arithmetic<half> { enum { value = true }; };
896
+ } // end namespace internal
788
897
 
789
- } // end namespace internal
790
-
791
- template<> struct NumTraits<Eigen::half>
792
- : GenericNumTraits<Eigen::half>
793
- {
794
- enum {
795
- IsSigned = true,
796
- IsInteger = false,
797
- IsComplex = false,
798
- RequireInitialization = false
799
- };
898
+ template <>
899
+ struct NumTraits<Eigen::half> : GenericNumTraits<Eigen::half> {
900
+ enum { IsSigned = true, IsInteger = false, IsComplex = false, RequireInitialization = false };
800
901
 
801
- EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half epsilon() {
902
+ EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half epsilon() {
802
903
  return half_impl::raw_uint16_to_half(0x0800);
803
904
  }
804
- EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half dummy_precision() {
805
- return half_impl::raw_uint16_to_half(0x211f); // Eigen::half(1e-2f);
905
+ EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half dummy_precision() {
906
+ return half_impl::raw_uint16_to_half(0x211f); // Eigen::half(1e-2f);
806
907
  }
807
- EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half highest() {
908
+ EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half highest() {
808
909
  return half_impl::raw_uint16_to_half(0x7bff);
809
910
  }
810
- EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half lowest() {
911
+ EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half lowest() {
811
912
  return half_impl::raw_uint16_to_half(0xfbff);
812
913
  }
813
- EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half infinity() {
914
+ EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half infinity() {
814
915
  return half_impl::raw_uint16_to_half(0x7c00);
815
916
  }
816
- EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half quiet_NaN() {
917
+ EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half quiet_NaN() {
817
918
  return half_impl::raw_uint16_to_half(0x7e00);
818
919
  }
819
920
  };
820
921
 
821
- } // end namespace Eigen
922
+ } // end namespace Eigen
822
923
 
823
- #if defined(EIGEN_HAS_GPU_FP16) || defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
824
- #pragma pop_macro("EIGEN_CONSTEXPR")
825
- #endif
924
+ #undef _EIGEN_MAYBE_CONSTEXPR
826
925
 
827
926
  namespace Eigen {
828
927
  namespace numext {
@@ -856,6 +955,12 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast<uint16_t, Eigen::half>(c
856
955
  return Eigen::half_impl::raw_half_as_uint16(src);
857
956
  }
858
957
 
958
+ // Specialize multiply-add to match packet operations and reduce conversions to/from float.
959
+ template<>
960
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half madd<Eigen::half>(const Eigen::half& x, const Eigen::half& y, const Eigen::half& z) {
961
+ return Eigen::half(static_cast<float>(x) * static_cast<float>(y) + static_cast<float>(z));
962
+ }
963
+
859
964
  } // namespace numext
860
965
  } // namespace Eigen
861
966
 
@@ -870,63 +975,65 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast<uint16_t, Eigen::half>(c
870
975
  // with native support for __half and __nv_bfloat16
871
976
  //
872
977
  // Note that the following are __device__ - only functions.
873
// Warp-shuffle overloads for Eigen::half (__device__-only). They are declared
// when compiling CUDA (the host pass, or device arch >= 300) or HIP.
#if (defined(EIGEN_CUDACC) && (!defined(EIGEN_CUDA_ARCH) || EIGEN_CUDA_ARCH >= 300)) || defined(EIGEN_HIPCC)

#if defined(EIGEN_HAS_CUDA_FP16) && EIGEN_CUDA_SDK_VER >= 90000

// CUDA SDK >= 9.0: the *_sync shuffles accept __half directly. The value is
// converted to __half, shuffled, and converted back.
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_sync(unsigned mask, Eigen::half var, int srcLane,
                                                       int width = warpSize) {
  const __half native = var;
  return static_cast<Eigen::half>(__shfl_sync(mask, native, srcLane, width));
}

__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_up_sync(unsigned mask, Eigen::half var, unsigned int delta,
                                                          int width = warpSize) {
  const __half native = var;
  return static_cast<Eigen::half>(__shfl_up_sync(mask, native, delta, width));
}

__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_down_sync(unsigned mask, Eigen::half var, unsigned int delta,
                                                            int width = warpSize) {
  const __half native = var;
  return static_cast<Eigen::half>(__shfl_down_sync(mask, native, delta, width));
}

__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_xor_sync(unsigned mask, Eigen::half var, int laneMask,
                                                           int width = warpSize) {
  const __half native = var;
  return static_cast<Eigen::half>(__shfl_xor_sync(mask, native, laneMask, width));
}

#else  // HIP or CUDA SDK < 9.0

// The legacy shuffles take int. The 16-bit pattern is widened to int,
// shuffled, then truncated back to 16 bits and reinterpreted as half.
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl(Eigen::half var, int srcLane, int width = warpSize) {
  const int bits = static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
  const int shuffled = __shfl(bits, srcLane, width);
  return Eigen::numext::bit_cast<Eigen::half>(static_cast<Eigen::numext::uint16_t>(shuffled));
}

__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_up(Eigen::half var, unsigned int delta, int width = warpSize) {
  const int bits = static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
  const int shuffled = __shfl_up(bits, delta, width);
  return Eigen::numext::bit_cast<Eigen::half>(static_cast<Eigen::numext::uint16_t>(shuffled));
}

__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_down(Eigen::half var, unsigned int delta, int width = warpSize) {
  const int bits = static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
  const int shuffled = __shfl_down(bits, delta, width);
  return Eigen::numext::bit_cast<Eigen::half>(static_cast<Eigen::numext::uint16_t>(shuffled));
}

__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_xor(Eigen::half var, int laneMask, int width = warpSize) {
  const int bits = static_cast<int>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(var));
  const int shuffled = __shfl_xor(bits, laneMask, width);
  return Eigen::numext::bit_cast<Eigen::half>(static_cast<Eigen::numext::uint16_t>(shuffled));
}

#endif  // HIP vs CUDA
#endif  // __shfl*
922
1030
 
923
1031
  // ldg() has an overload for __half_raw, but we also need one for Eigen::half.
924
- #if (defined(EIGEN_CUDACC) && (!defined(EIGEN_CUDA_ARCH) || EIGEN_CUDA_ARCH >= 350)) \
925
- || defined(EIGEN_HIPCC)
1032
+ #if (defined(EIGEN_CUDACC) && (!defined(EIGEN_CUDA_ARCH) || EIGEN_CUDA_ARCH >= 350)) || defined(EIGEN_HIPCC)
926
1033
  EIGEN_STRONG_INLINE __device__ Eigen::half __ldg(const Eigen::half* ptr) {
927
1034
  return Eigen::half_impl::raw_uint16_to_half(__ldg(reinterpret_cast<const Eigen::numext::uint16_t*>(ptr)));
928
1035
  }
929
- #endif // __ldg
1036
+ #endif // __ldg
930
1037
 
931
1038
  #if EIGEN_HAS_STD_HASH
932
1039
  namespace std {
@@ -936,7 +1043,49 @@ struct hash<Eigen::half> {
936
1043
  return static_cast<std::size_t>(Eigen::numext::bit_cast<Eigen::numext::uint16_t>(a));
937
1044
  }
938
1045
  };
939
- } // end namespace std
1046
+ } // end namespace std
1047
+ #endif
1048
+
1049
+ namespace Eigen {
1050
+ namespace internal {
1051
+
1052
+ template <>
1053
+ struct cast_impl<float, half> {
1054
+ EIGEN_DEVICE_FUNC static inline half run(const float& a) {
1055
+ #if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
1056
+ (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
1057
+ return __float2half(a);
1058
+ #else
1059
+ return half(a);
1060
+ #endif
1061
+ }
1062
+ };
1063
+
1064
+ template <>
1065
+ struct cast_impl<int, half> {
1066
+ EIGEN_DEVICE_FUNC static inline half run(const int& a) {
1067
+ #if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
1068
+ (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
1069
+ return __float2half(static_cast<float>(a));
1070
+ #else
1071
+ return half(static_cast<float>(a));
1072
+ #endif
1073
+ }
1074
+ };
1075
+
1076
+ template <>
1077
+ struct cast_impl<half, float> {
1078
+ EIGEN_DEVICE_FUNC static inline float run(const half& a) {
1079
+ #if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
1080
+ (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
1081
+ return __half2float(a);
1082
+ #else
1083
+ return static_cast<float>(a);
940
1084
  #endif
1085
+ }
1086
+ };
1087
+
1088
+ } // namespace internal
1089
+ } // namespace Eigen
941
1090
 
942
- #endif // EIGEN_HALF_H
1091
+ #endif // EIGEN_HALF_H