@smake/eigen 1.0.2 → 1.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (435) hide show
  1. package/README.md +1 -1
  2. package/eigen/Eigen/AccelerateSupport +52 -0
  3. package/eigen/Eigen/Cholesky +18 -21
  4. package/eigen/Eigen/CholmodSupport +28 -28
  5. package/eigen/Eigen/Core +235 -326
  6. package/eigen/Eigen/Eigenvalues +16 -14
  7. package/eigen/Eigen/Geometry +21 -24
  8. package/eigen/Eigen/Householder +9 -8
  9. package/eigen/Eigen/IterativeLinearSolvers +8 -4
  10. package/eigen/Eigen/Jacobi +14 -14
  11. package/eigen/Eigen/KLUSupport +43 -0
  12. package/eigen/Eigen/LU +16 -20
  13. package/eigen/Eigen/MetisSupport +12 -12
  14. package/eigen/Eigen/OrderingMethods +54 -54
  15. package/eigen/Eigen/PaStiXSupport +23 -20
  16. package/eigen/Eigen/PardisoSupport +17 -14
  17. package/eigen/Eigen/QR +18 -21
  18. package/eigen/Eigen/QtAlignedMalloc +5 -13
  19. package/eigen/Eigen/SPQRSupport +21 -14
  20. package/eigen/Eigen/SVD +23 -18
  21. package/eigen/Eigen/Sparse +1 -4
  22. package/eigen/Eigen/SparseCholesky +18 -23
  23. package/eigen/Eigen/SparseCore +18 -17
  24. package/eigen/Eigen/SparseLU +12 -8
  25. package/eigen/Eigen/SparseQR +16 -14
  26. package/eigen/Eigen/StdDeque +5 -2
  27. package/eigen/Eigen/StdList +5 -2
  28. package/eigen/Eigen/StdVector +5 -2
  29. package/eigen/Eigen/SuperLUSupport +30 -24
  30. package/eigen/Eigen/ThreadPool +80 -0
  31. package/eigen/Eigen/UmfPackSupport +19 -17
  32. package/eigen/Eigen/Version +14 -0
  33. package/eigen/Eigen/src/AccelerateSupport/AccelerateSupport.h +423 -0
  34. package/eigen/Eigen/src/AccelerateSupport/InternalHeaderCheck.h +3 -0
  35. package/eigen/Eigen/src/Cholesky/InternalHeaderCheck.h +3 -0
  36. package/eigen/Eigen/src/Cholesky/LDLT.h +377 -401
  37. package/eigen/Eigen/src/Cholesky/LLT.h +332 -360
  38. package/eigen/Eigen/src/Cholesky/LLT_LAPACKE.h +81 -56
  39. package/eigen/Eigen/src/CholmodSupport/CholmodSupport.h +620 -521
  40. package/eigen/Eigen/src/CholmodSupport/InternalHeaderCheck.h +3 -0
  41. package/eigen/Eigen/src/Core/ArithmeticSequence.h +239 -0
  42. package/eigen/Eigen/src/Core/Array.h +341 -294
  43. package/eigen/Eigen/src/Core/ArrayBase.h +190 -203
  44. package/eigen/Eigen/src/Core/ArrayWrapper.h +127 -171
  45. package/eigen/Eigen/src/Core/Assign.h +30 -40
  46. package/eigen/Eigen/src/Core/AssignEvaluator.h +711 -589
  47. package/eigen/Eigen/src/Core/Assign_MKL.h +130 -125
  48. package/eigen/Eigen/src/Core/BandMatrix.h +268 -283
  49. package/eigen/Eigen/src/Core/Block.h +375 -398
  50. package/eigen/Eigen/src/Core/CommaInitializer.h +86 -97
  51. package/eigen/Eigen/src/Core/ConditionEstimator.h +51 -53
  52. package/eigen/Eigen/src/Core/CoreEvaluators.h +1356 -1026
  53. package/eigen/Eigen/src/Core/CoreIterators.h +73 -59
  54. package/eigen/Eigen/src/Core/CwiseBinaryOp.h +114 -132
  55. package/eigen/Eigen/src/Core/CwiseNullaryOp.h +726 -617
  56. package/eigen/Eigen/src/Core/CwiseTernaryOp.h +77 -103
  57. package/eigen/Eigen/src/Core/CwiseUnaryOp.h +56 -68
  58. package/eigen/Eigen/src/Core/CwiseUnaryView.h +132 -95
  59. package/eigen/Eigen/src/Core/DenseBase.h +632 -571
  60. package/eigen/Eigen/src/Core/DenseCoeffsBase.h +511 -624
  61. package/eigen/Eigen/src/Core/DenseStorage.h +512 -509
  62. package/eigen/Eigen/src/Core/DeviceWrapper.h +153 -0
  63. package/eigen/Eigen/src/Core/Diagonal.h +169 -210
  64. package/eigen/Eigen/src/Core/DiagonalMatrix.h +351 -274
  65. package/eigen/Eigen/src/Core/DiagonalProduct.h +12 -10
  66. package/eigen/Eigen/src/Core/Dot.h +172 -222
  67. package/eigen/Eigen/src/Core/EigenBase.h +75 -85
  68. package/eigen/Eigen/src/Core/Fill.h +138 -0
  69. package/eigen/Eigen/src/Core/FindCoeff.h +464 -0
  70. package/eigen/Eigen/src/Core/ForceAlignedAccess.h +90 -109
  71. package/eigen/Eigen/src/Core/Fuzzy.h +82 -105
  72. package/eigen/Eigen/src/Core/GeneralProduct.h +327 -263
  73. package/eigen/Eigen/src/Core/GenericPacketMath.h +1472 -360
  74. package/eigen/Eigen/src/Core/GlobalFunctions.h +194 -151
  75. package/eigen/Eigen/src/Core/IO.h +147 -139
  76. package/eigen/Eigen/src/Core/IndexedView.h +321 -0
  77. package/eigen/Eigen/src/Core/InnerProduct.h +260 -0
  78. package/eigen/Eigen/src/Core/InternalHeaderCheck.h +3 -0
  79. package/eigen/Eigen/src/Core/Inverse.h +56 -66
  80. package/eigen/Eigen/src/Core/Map.h +124 -142
  81. package/eigen/Eigen/src/Core/MapBase.h +256 -281
  82. package/eigen/Eigen/src/Core/MathFunctions.h +1620 -938
  83. package/eigen/Eigen/src/Core/MathFunctionsImpl.h +233 -71
  84. package/eigen/Eigen/src/Core/Matrix.h +491 -416
  85. package/eigen/Eigen/src/Core/MatrixBase.h +468 -453
  86. package/eigen/Eigen/src/Core/NestByValue.h +66 -85
  87. package/eigen/Eigen/src/Core/NoAlias.h +79 -85
  88. package/eigen/Eigen/src/Core/NumTraits.h +235 -148
  89. package/eigen/Eigen/src/Core/PartialReduxEvaluator.h +253 -0
  90. package/eigen/Eigen/src/Core/PermutationMatrix.h +461 -511
  91. package/eigen/Eigen/src/Core/PlainObjectBase.h +871 -894
  92. package/eigen/Eigen/src/Core/Product.h +260 -139
  93. package/eigen/Eigen/src/Core/ProductEvaluators.h +863 -714
  94. package/eigen/Eigen/src/Core/Random.h +161 -136
  95. package/eigen/Eigen/src/Core/RandomImpl.h +262 -0
  96. package/eigen/Eigen/src/Core/RealView.h +250 -0
  97. package/eigen/Eigen/src/Core/Redux.h +366 -336
  98. package/eigen/Eigen/src/Core/Ref.h +308 -209
  99. package/eigen/Eigen/src/Core/Replicate.h +94 -106
  100. package/eigen/Eigen/src/Core/Reshaped.h +398 -0
  101. package/eigen/Eigen/src/Core/ReturnByValue.h +49 -55
  102. package/eigen/Eigen/src/Core/Reverse.h +136 -145
  103. package/eigen/Eigen/src/Core/Select.h +70 -140
  104. package/eigen/Eigen/src/Core/SelfAdjointView.h +262 -285
  105. package/eigen/Eigen/src/Core/SelfCwiseBinaryOp.h +23 -20
  106. package/eigen/Eigen/src/Core/SkewSymmetricMatrix3.h +382 -0
  107. package/eigen/Eigen/src/Core/Solve.h +97 -111
  108. package/eigen/Eigen/src/Core/SolveTriangular.h +131 -129
  109. package/eigen/Eigen/src/Core/SolverBase.h +138 -101
  110. package/eigen/Eigen/src/Core/StableNorm.h +156 -160
  111. package/eigen/Eigen/src/Core/StlIterators.h +619 -0
  112. package/eigen/Eigen/src/Core/Stride.h +91 -88
  113. package/eigen/Eigen/src/Core/Swap.h +70 -38
  114. package/eigen/Eigen/src/Core/Transpose.h +295 -273
  115. package/eigen/Eigen/src/Core/Transpositions.h +272 -317
  116. package/eigen/Eigen/src/Core/TriangularMatrix.h +670 -755
  117. package/eigen/Eigen/src/Core/VectorBlock.h +59 -72
  118. package/eigen/Eigen/src/Core/VectorwiseOp.h +668 -630
  119. package/eigen/Eigen/src/Core/Visitor.h +480 -216
  120. package/eigen/Eigen/src/Core/arch/AVX/Complex.h +407 -293
  121. package/eigen/Eigen/src/Core/arch/AVX/MathFunctions.h +79 -388
  122. package/eigen/Eigen/src/Core/arch/AVX/PacketMath.h +2935 -491
  123. package/eigen/Eigen/src/Core/arch/AVX/Reductions.h +353 -0
  124. package/eigen/Eigen/src/Core/arch/AVX/TypeCasting.h +279 -22
  125. package/eigen/Eigen/src/Core/arch/AVX512/Complex.h +472 -0
  126. package/eigen/Eigen/src/Core/arch/AVX512/GemmKernel.h +1245 -0
  127. package/eigen/Eigen/src/Core/arch/AVX512/MathFunctions.h +85 -333
  128. package/eigen/Eigen/src/Core/arch/AVX512/MathFunctionsFP16.h +75 -0
  129. package/eigen/Eigen/src/Core/arch/AVX512/PacketMath.h +2490 -649
  130. package/eigen/Eigen/src/Core/arch/AVX512/PacketMathFP16.h +1413 -0
  131. package/eigen/Eigen/src/Core/arch/AVX512/Reductions.h +297 -0
  132. package/eigen/Eigen/src/Core/arch/AVX512/TrsmKernel.h +1167 -0
  133. package/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc +1219 -0
  134. package/eigen/Eigen/src/Core/arch/AVX512/TypeCasting.h +277 -0
  135. package/eigen/Eigen/src/Core/arch/AVX512/TypeCastingFP16.h +130 -0
  136. package/eigen/Eigen/src/Core/arch/AltiVec/Complex.h +521 -298
  137. package/eigen/Eigen/src/Core/arch/AltiVec/MathFunctions.h +39 -280
  138. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProduct.h +3686 -0
  139. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h +205 -0
  140. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h +901 -0
  141. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h +742 -0
  142. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.inc +2818 -0
  143. package/eigen/Eigen/src/Core/arch/AltiVec/PacketMath.h +3391 -723
  144. package/eigen/Eigen/src/Core/arch/AltiVec/TypeCasting.h +153 -0
  145. package/eigen/Eigen/src/Core/arch/Default/BFloat16.h +866 -0
  146. package/eigen/Eigen/src/Core/arch/Default/ConjHelper.h +113 -14
  147. package/eigen/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +2634 -0
  148. package/eigen/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h +227 -0
  149. package/eigen/Eigen/src/Core/arch/Default/Half.h +1091 -0
  150. package/eigen/Eigen/src/Core/arch/Default/Settings.h +11 -13
  151. package/eigen/Eigen/src/Core/arch/GPU/Complex.h +244 -0
  152. package/eigen/Eigen/src/Core/arch/GPU/MathFunctions.h +104 -0
  153. package/eigen/Eigen/src/Core/arch/GPU/PacketMath.h +1712 -0
  154. package/eigen/Eigen/src/Core/arch/GPU/Tuple.h +268 -0
  155. package/eigen/Eigen/src/Core/arch/GPU/TypeCasting.h +77 -0
  156. package/eigen/Eigen/src/Core/arch/HIP/hcc/math_constants.h +23 -0
  157. package/eigen/Eigen/src/Core/arch/HVX/PacketMath.h +1088 -0
  158. package/eigen/Eigen/src/Core/arch/LSX/Complex.h +520 -0
  159. package/eigen/Eigen/src/Core/arch/LSX/GeneralBlockPanelKernel.h +23 -0
  160. package/eigen/Eigen/src/Core/arch/LSX/MathFunctions.h +43 -0
  161. package/eigen/Eigen/src/Core/arch/LSX/PacketMath.h +2866 -0
  162. package/eigen/Eigen/src/Core/arch/LSX/TypeCasting.h +526 -0
  163. package/eigen/Eigen/src/Core/arch/MSA/Complex.h +620 -0
  164. package/eigen/Eigen/src/Core/arch/MSA/MathFunctions.h +379 -0
  165. package/eigen/Eigen/src/Core/arch/MSA/PacketMath.h +1237 -0
  166. package/eigen/Eigen/src/Core/arch/NEON/Complex.h +531 -289
  167. package/eigen/Eigen/src/Core/arch/NEON/GeneralBlockPanelKernel.h +243 -0
  168. package/eigen/Eigen/src/Core/arch/NEON/MathFunctions.h +50 -73
  169. package/eigen/Eigen/src/Core/arch/NEON/PacketMath.h +5915 -579
  170. package/eigen/Eigen/src/Core/arch/NEON/TypeCasting.h +1642 -0
  171. package/eigen/Eigen/src/Core/arch/NEON/UnaryFunctors.h +57 -0
  172. package/eigen/Eigen/src/Core/arch/SSE/Complex.h +366 -334
  173. package/eigen/Eigen/src/Core/arch/SSE/MathFunctions.h +40 -514
  174. package/eigen/Eigen/src/Core/arch/SSE/PacketMath.h +2164 -675
  175. package/eigen/Eigen/src/Core/arch/SSE/Reductions.h +324 -0
  176. package/eigen/Eigen/src/Core/arch/SSE/TypeCasting.h +188 -35
  177. package/eigen/Eigen/src/Core/arch/SVE/MathFunctions.h +48 -0
  178. package/eigen/Eigen/src/Core/arch/SVE/PacketMath.h +674 -0
  179. package/eigen/Eigen/src/Core/arch/SVE/TypeCasting.h +52 -0
  180. package/eigen/Eigen/src/Core/arch/SYCL/InteropHeaders.h +227 -0
  181. package/eigen/Eigen/src/Core/arch/SYCL/MathFunctions.h +303 -0
  182. package/eigen/Eigen/src/Core/arch/SYCL/PacketMath.h +576 -0
  183. package/eigen/Eigen/src/Core/arch/SYCL/TypeCasting.h +83 -0
  184. package/eigen/Eigen/src/Core/arch/ZVector/Complex.h +434 -261
  185. package/eigen/Eigen/src/Core/arch/ZVector/MathFunctions.h +160 -53
  186. package/eigen/Eigen/src/Core/arch/ZVector/PacketMath.h +1073 -605
  187. package/eigen/Eigen/src/Core/functors/AssignmentFunctors.h +123 -117
  188. package/eigen/Eigen/src/Core/functors/BinaryFunctors.h +594 -322
  189. package/eigen/Eigen/src/Core/functors/NullaryFunctors.h +204 -118
  190. package/eigen/Eigen/src/Core/functors/StlFunctors.h +110 -97
  191. package/eigen/Eigen/src/Core/functors/TernaryFunctors.h +34 -7
  192. package/eigen/Eigen/src/Core/functors/UnaryFunctors.h +1158 -530
  193. package/eigen/Eigen/src/Core/products/GeneralBlockPanelKernel.h +2329 -1333
  194. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrix.h +328 -364
  195. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h +191 -178
  196. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h +85 -82
  197. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h +154 -73
  198. package/eigen/Eigen/src/Core/products/GeneralMatrixVector.h +396 -542
  199. package/eigen/Eigen/src/Core/products/GeneralMatrixVector_BLAS.h +80 -77
  200. package/eigen/Eigen/src/Core/products/Parallelizer.h +208 -92
  201. package/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix.h +331 -375
  202. package/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix_BLAS.h +206 -224
  203. package/eigen/Eigen/src/Core/products/SelfadjointMatrixVector.h +139 -146
  204. package/eigen/Eigen/src/Core/products/SelfadjointMatrixVector_BLAS.h +58 -61
  205. package/eigen/Eigen/src/Core/products/SelfadjointProduct.h +71 -71
  206. package/eigen/Eigen/src/Core/products/SelfadjointRank2Update.h +48 -46
  207. package/eigen/Eigen/src/Core/products/TriangularMatrixMatrix.h +294 -369
  208. package/eigen/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h +246 -238
  209. package/eigen/Eigen/src/Core/products/TriangularMatrixVector.h +244 -247
  210. package/eigen/Eigen/src/Core/products/TriangularMatrixVector_BLAS.h +212 -192
  211. package/eigen/Eigen/src/Core/products/TriangularSolverMatrix.h +328 -275
  212. package/eigen/Eigen/src/Core/products/TriangularSolverMatrix_BLAS.h +108 -109
  213. package/eigen/Eigen/src/Core/products/TriangularSolverVector.h +70 -93
  214. package/eigen/Eigen/src/Core/util/Assert.h +158 -0
  215. package/eigen/Eigen/src/Core/util/BlasUtil.h +413 -290
  216. package/eigen/Eigen/src/Core/util/ConfigureVectorization.h +543 -0
  217. package/eigen/Eigen/src/Core/util/Constants.h +314 -263
  218. package/eigen/Eigen/src/Core/util/DisableStupidWarnings.h +130 -78
  219. package/eigen/Eigen/src/Core/util/EmulateArray.h +270 -0
  220. package/eigen/Eigen/src/Core/util/ForwardDeclarations.h +450 -224
  221. package/eigen/Eigen/src/Core/util/GpuHipCudaDefines.inc +101 -0
  222. package/eigen/Eigen/src/Core/util/GpuHipCudaUndefines.inc +45 -0
  223. package/eigen/Eigen/src/Core/util/IndexedViewHelper.h +487 -0
  224. package/eigen/Eigen/src/Core/util/IntegralConstant.h +279 -0
  225. package/eigen/Eigen/src/Core/util/MKL_support.h +39 -30
  226. package/eigen/Eigen/src/Core/util/Macros.h +939 -646
  227. package/eigen/Eigen/src/Core/util/MaxSizeVector.h +139 -0
  228. package/eigen/Eigen/src/Core/util/Memory.h +1042 -650
  229. package/eigen/Eigen/src/Core/util/Meta.h +618 -426
  230. package/eigen/Eigen/src/Core/util/MoreMeta.h +638 -0
  231. package/eigen/Eigen/src/Core/util/ReenableStupidWarnings.h +32 -19
  232. package/eigen/Eigen/src/Core/util/ReshapedHelper.h +51 -0
  233. package/eigen/Eigen/src/Core/util/Serializer.h +209 -0
  234. package/eigen/Eigen/src/Core/util/StaticAssert.h +51 -164
  235. package/eigen/Eigen/src/Core/util/SymbolicIndex.h +445 -0
  236. package/eigen/Eigen/src/Core/util/XprHelper.h +793 -538
  237. package/eigen/Eigen/src/Eigenvalues/ComplexEigenSolver.h +246 -277
  238. package/eigen/Eigen/src/Eigenvalues/ComplexSchur.h +299 -319
  239. package/eigen/Eigen/src/Eigenvalues/ComplexSchur_LAPACKE.h +52 -48
  240. package/eigen/Eigen/src/Eigenvalues/EigenSolver.h +413 -456
  241. package/eigen/Eigen/src/Eigenvalues/GeneralizedEigenSolver.h +309 -325
  242. package/eigen/Eigen/src/Eigenvalues/GeneralizedSelfAdjointEigenSolver.h +157 -171
  243. package/eigen/Eigen/src/Eigenvalues/HessenbergDecomposition.h +292 -310
  244. package/eigen/Eigen/src/Eigenvalues/InternalHeaderCheck.h +3 -0
  245. package/eigen/Eigen/src/Eigenvalues/MatrixBaseEigenvalues.h +91 -107
  246. package/eigen/Eigen/src/Eigenvalues/RealQZ.h +539 -606
  247. package/eigen/Eigen/src/Eigenvalues/RealSchur.h +348 -382
  248. package/eigen/Eigen/src/Eigenvalues/RealSchur_LAPACKE.h +41 -35
  249. package/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver.h +579 -600
  250. package/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver_LAPACKE.h +47 -44
  251. package/eigen/Eigen/src/Eigenvalues/Tridiagonalization.h +434 -461
  252. package/eigen/Eigen/src/Geometry/AlignedBox.h +307 -214
  253. package/eigen/Eigen/src/Geometry/AngleAxis.h +135 -137
  254. package/eigen/Eigen/src/Geometry/EulerAngles.h +163 -74
  255. package/eigen/Eigen/src/Geometry/Homogeneous.h +289 -333
  256. package/eigen/Eigen/src/Geometry/Hyperplane.h +152 -161
  257. package/eigen/Eigen/src/Geometry/InternalHeaderCheck.h +3 -0
  258. package/eigen/Eigen/src/Geometry/OrthoMethods.h +168 -145
  259. package/eigen/Eigen/src/Geometry/ParametrizedLine.h +141 -104
  260. package/eigen/Eigen/src/Geometry/Quaternion.h +595 -497
  261. package/eigen/Eigen/src/Geometry/Rotation2D.h +110 -108
  262. package/eigen/Eigen/src/Geometry/RotationBase.h +148 -145
  263. package/eigen/Eigen/src/Geometry/Scaling.h +115 -90
  264. package/eigen/Eigen/src/Geometry/Transform.h +896 -953
  265. package/eigen/Eigen/src/Geometry/Translation.h +100 -98
  266. package/eigen/Eigen/src/Geometry/Umeyama.h +79 -84
  267. package/eigen/Eigen/src/Geometry/arch/Geometry_SIMD.h +154 -0
  268. package/eigen/Eigen/src/Householder/BlockHouseholder.h +54 -42
  269. package/eigen/Eigen/src/Householder/Householder.h +104 -122
  270. package/eigen/Eigen/src/Householder/HouseholderSequence.h +416 -382
  271. package/eigen/Eigen/src/Householder/InternalHeaderCheck.h +3 -0
  272. package/eigen/Eigen/src/IterativeLinearSolvers/BasicPreconditioners.h +153 -166
  273. package/eigen/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h +127 -138
  274. package/eigen/Eigen/src/IterativeLinearSolvers/ConjugateGradient.h +95 -124
  275. package/eigen/Eigen/src/IterativeLinearSolvers/IncompleteCholesky.h +269 -267
  276. package/eigen/Eigen/src/IterativeLinearSolvers/IncompleteLUT.h +246 -259
  277. package/eigen/Eigen/src/IterativeLinearSolvers/InternalHeaderCheck.h +3 -0
  278. package/eigen/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h +218 -217
  279. package/eigen/Eigen/src/IterativeLinearSolvers/LeastSquareConjugateGradient.h +80 -103
  280. package/eigen/Eigen/src/IterativeLinearSolvers/SolveWithGuess.h +59 -63
  281. package/eigen/Eigen/src/Jacobi/InternalHeaderCheck.h +3 -0
  282. package/eigen/Eigen/src/Jacobi/Jacobi.h +256 -291
  283. package/eigen/Eigen/src/KLUSupport/InternalHeaderCheck.h +3 -0
  284. package/eigen/Eigen/src/KLUSupport/KLUSupport.h +339 -0
  285. package/eigen/Eigen/src/LU/Determinant.h +60 -63
  286. package/eigen/Eigen/src/LU/FullPivLU.h +561 -626
  287. package/eigen/Eigen/src/LU/InternalHeaderCheck.h +3 -0
  288. package/eigen/Eigen/src/LU/InverseImpl.h +213 -275
  289. package/eigen/Eigen/src/LU/PartialPivLU.h +407 -435
  290. package/eigen/Eigen/src/LU/PartialPivLU_LAPACKE.h +54 -40
  291. package/eigen/Eigen/src/LU/arch/InverseSize4.h +353 -0
  292. package/eigen/Eigen/src/MetisSupport/InternalHeaderCheck.h +3 -0
  293. package/eigen/Eigen/src/MetisSupport/MetisSupport.h +81 -93
  294. package/eigen/Eigen/src/OrderingMethods/Amd.h +250 -282
  295. package/eigen/Eigen/src/OrderingMethods/Eigen_Colamd.h +950 -1103
  296. package/eigen/Eigen/src/OrderingMethods/InternalHeaderCheck.h +3 -0
  297. package/eigen/Eigen/src/OrderingMethods/Ordering.h +111 -122
  298. package/eigen/Eigen/src/PaStiXSupport/InternalHeaderCheck.h +3 -0
  299. package/eigen/Eigen/src/PaStiXSupport/PaStiXSupport.h +524 -570
  300. package/eigen/Eigen/src/PardisoSupport/InternalHeaderCheck.h +3 -0
  301. package/eigen/Eigen/src/PardisoSupport/PardisoSupport.h +385 -429
  302. package/eigen/Eigen/src/QR/ColPivHouseholderQR.h +494 -473
  303. package/eigen/Eigen/src/QR/ColPivHouseholderQR_LAPACKE.h +120 -56
  304. package/eigen/Eigen/src/QR/CompleteOrthogonalDecomposition.h +223 -137
  305. package/eigen/Eigen/src/QR/FullPivHouseholderQR.h +517 -460
  306. package/eigen/Eigen/src/QR/HouseholderQR.h +412 -278
  307. package/eigen/Eigen/src/QR/HouseholderQR_LAPACKE.h +32 -23
  308. package/eigen/Eigen/src/QR/InternalHeaderCheck.h +3 -0
  309. package/eigen/Eigen/src/SPQRSupport/InternalHeaderCheck.h +3 -0
  310. package/eigen/Eigen/src/SPQRSupport/SuiteSparseQRSupport.h +263 -261
  311. package/eigen/Eigen/src/SVD/BDCSVD.h +872 -679
  312. package/eigen/Eigen/src/SVD/BDCSVD_LAPACKE.h +174 -0
  313. package/eigen/Eigen/src/SVD/InternalHeaderCheck.h +3 -0
  314. package/eigen/Eigen/src/SVD/JacobiSVD.h +585 -543
  315. package/eigen/Eigen/src/SVD/JacobiSVD_LAPACKE.h +85 -49
  316. package/eigen/Eigen/src/SVD/SVDBase.h +281 -160
  317. package/eigen/Eigen/src/SVD/UpperBidiagonalization.h +202 -237
  318. package/eigen/Eigen/src/SparseCholesky/InternalHeaderCheck.h +3 -0
  319. package/eigen/Eigen/src/SparseCholesky/SimplicialCholesky.h +769 -590
  320. package/eigen/Eigen/src/SparseCholesky/SimplicialCholesky_impl.h +318 -129
  321. package/eigen/Eigen/src/SparseCore/AmbiVector.h +202 -251
  322. package/eigen/Eigen/src/SparseCore/CompressedStorage.h +184 -236
  323. package/eigen/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h +140 -184
  324. package/eigen/Eigen/src/SparseCore/InternalHeaderCheck.h +3 -0
  325. package/eigen/Eigen/src/SparseCore/SparseAssign.h +174 -111
  326. package/eigen/Eigen/src/SparseCore/SparseBlock.h +408 -477
  327. package/eigen/Eigen/src/SparseCore/SparseColEtree.h +100 -112
  328. package/eigen/Eigen/src/SparseCore/SparseCompressedBase.h +531 -280
  329. package/eigen/Eigen/src/SparseCore/SparseCwiseBinaryOp.h +559 -347
  330. package/eigen/Eigen/src/SparseCore/SparseCwiseUnaryOp.h +100 -108
  331. package/eigen/Eigen/src/SparseCore/SparseDenseProduct.h +185 -191
  332. package/eigen/Eigen/src/SparseCore/SparseDiagonalProduct.h +71 -71
  333. package/eigen/Eigen/src/SparseCore/SparseDot.h +49 -47
  334. package/eigen/Eigen/src/SparseCore/SparseFuzzy.h +13 -11
  335. package/eigen/Eigen/src/SparseCore/SparseMap.h +243 -253
  336. package/eigen/Eigen/src/SparseCore/SparseMatrix.h +1614 -1142
  337. package/eigen/Eigen/src/SparseCore/SparseMatrixBase.h +403 -357
  338. package/eigen/Eigen/src/SparseCore/SparsePermutation.h +186 -115
  339. package/eigen/Eigen/src/SparseCore/SparseProduct.h +100 -91
  340. package/eigen/Eigen/src/SparseCore/SparseRedux.h +22 -24
  341. package/eigen/Eigen/src/SparseCore/SparseRef.h +268 -295
  342. package/eigen/Eigen/src/SparseCore/SparseSelfAdjointView.h +371 -414
  343. package/eigen/Eigen/src/SparseCore/SparseSolverBase.h +78 -87
  344. package/eigen/Eigen/src/SparseCore/SparseSparseProductWithPruning.h +81 -95
  345. package/eigen/Eigen/src/SparseCore/SparseTranspose.h +62 -71
  346. package/eigen/Eigen/src/SparseCore/SparseTriangularView.h +132 -144
  347. package/eigen/Eigen/src/SparseCore/SparseUtil.h +146 -115
  348. package/eigen/Eigen/src/SparseCore/SparseVector.h +426 -372
  349. package/eigen/Eigen/src/SparseCore/SparseView.h +164 -193
  350. package/eigen/Eigen/src/SparseCore/TriangularSolver.h +129 -170
  351. package/eigen/Eigen/src/SparseLU/InternalHeaderCheck.h +3 -0
  352. package/eigen/Eigen/src/SparseLU/SparseLU.h +814 -618
  353. package/eigen/Eigen/src/SparseLU/SparseLUImpl.h +61 -48
  354. package/eigen/Eigen/src/SparseLU/SparseLU_Memory.h +102 -118
  355. package/eigen/Eigen/src/SparseLU/SparseLU_Structs.h +38 -35
  356. package/eigen/Eigen/src/SparseLU/SparseLU_SupernodalMatrix.h +273 -255
  357. package/eigen/Eigen/src/SparseLU/SparseLU_Utils.h +44 -49
  358. package/eigen/Eigen/src/SparseLU/SparseLU_column_bmod.h +104 -108
  359. package/eigen/Eigen/src/SparseLU/SparseLU_column_dfs.h +90 -101
  360. package/eigen/Eigen/src/SparseLU/SparseLU_copy_to_ucol.h +57 -58
  361. package/eigen/Eigen/src/SparseLU/SparseLU_heap_relax_snode.h +43 -55
  362. package/eigen/Eigen/src/SparseLU/SparseLU_kernel_bmod.h +74 -71
  363. package/eigen/Eigen/src/SparseLU/SparseLU_panel_bmod.h +125 -133
  364. package/eigen/Eigen/src/SparseLU/SparseLU_panel_dfs.h +136 -159
  365. package/eigen/Eigen/src/SparseLU/SparseLU_pivotL.h +51 -52
  366. package/eigen/Eigen/src/SparseLU/SparseLU_pruneL.h +67 -73
  367. package/eigen/Eigen/src/SparseLU/SparseLU_relax_snode.h +24 -26
  368. package/eigen/Eigen/src/SparseQR/InternalHeaderCheck.h +3 -0
  369. package/eigen/Eigen/src/SparseQR/SparseQR.h +451 -490
  370. package/eigen/Eigen/src/StlSupport/StdDeque.h +28 -105
  371. package/eigen/Eigen/src/StlSupport/StdList.h +28 -84
  372. package/eigen/Eigen/src/StlSupport/StdVector.h +28 -108
  373. package/eigen/Eigen/src/StlSupport/details.h +48 -50
  374. package/eigen/Eigen/src/SuperLUSupport/InternalHeaderCheck.h +3 -0
  375. package/eigen/Eigen/src/SuperLUSupport/SuperLUSupport.h +634 -732
  376. package/eigen/Eigen/src/ThreadPool/Barrier.h +70 -0
  377. package/eigen/Eigen/src/ThreadPool/CoreThreadPoolDevice.h +336 -0
  378. package/eigen/Eigen/src/ThreadPool/EventCount.h +241 -0
  379. package/eigen/Eigen/src/ThreadPool/ForkJoin.h +140 -0
  380. package/eigen/Eigen/src/ThreadPool/InternalHeaderCheck.h +4 -0
  381. package/eigen/Eigen/src/ThreadPool/NonBlockingThreadPool.h +587 -0
  382. package/eigen/Eigen/src/ThreadPool/RunQueue.h +230 -0
  383. package/eigen/Eigen/src/ThreadPool/ThreadCancel.h +21 -0
  384. package/eigen/Eigen/src/ThreadPool/ThreadEnvironment.h +43 -0
  385. package/eigen/Eigen/src/ThreadPool/ThreadLocal.h +289 -0
  386. package/eigen/Eigen/src/ThreadPool/ThreadPoolInterface.h +50 -0
  387. package/eigen/Eigen/src/ThreadPool/ThreadYield.h +16 -0
  388. package/eigen/Eigen/src/UmfPackSupport/InternalHeaderCheck.h +3 -0
  389. package/eigen/Eigen/src/UmfPackSupport/UmfPackSupport.h +480 -380
  390. package/eigen/Eigen/src/misc/Image.h +41 -43
  391. package/eigen/Eigen/src/misc/InternalHeaderCheck.h +3 -0
  392. package/eigen/Eigen/src/misc/Kernel.h +39 -41
  393. package/eigen/Eigen/src/misc/RealSvd2x2.h +19 -21
  394. package/eigen/Eigen/src/misc/blas.h +83 -426
  395. package/eigen/Eigen/src/misc/lapacke.h +9976 -16182
  396. package/eigen/Eigen/src/misc/lapacke_helpers.h +163 -0
  397. package/eigen/Eigen/src/misc/lapacke_mangling.h +4 -5
  398. package/eigen/Eigen/src/plugins/ArrayCwiseBinaryOps.inc +344 -0
  399. package/eigen/Eigen/src/plugins/ArrayCwiseUnaryOps.inc +544 -0
  400. package/eigen/Eigen/src/plugins/BlockMethods.inc +1370 -0
  401. package/eigen/Eigen/src/plugins/CommonCwiseBinaryOps.inc +116 -0
  402. package/eigen/Eigen/src/plugins/CommonCwiseUnaryOps.inc +167 -0
  403. package/eigen/Eigen/src/plugins/IndexedViewMethods.inc +192 -0
  404. package/eigen/Eigen/src/plugins/InternalHeaderCheck.inc +3 -0
  405. package/eigen/Eigen/src/plugins/MatrixCwiseBinaryOps.inc +331 -0
  406. package/eigen/Eigen/src/plugins/MatrixCwiseUnaryOps.inc +118 -0
  407. package/eigen/Eigen/src/plugins/ReshapedMethods.inc +133 -0
  408. package/lib/LibEigen.d.ts +4 -0
  409. package/lib/LibEigen.js +14 -0
  410. package/lib/index.d.ts +1 -1
  411. package/lib/index.js +7 -3
  412. package/package.json +2 -10
  413. package/eigen/Eigen/CMakeLists.txt +0 -19
  414. package/eigen/Eigen/src/Core/BooleanRedux.h +0 -164
  415. package/eigen/Eigen/src/Core/arch/CUDA/Complex.h +0 -103
  416. package/eigen/Eigen/src/Core/arch/CUDA/Half.h +0 -675
  417. package/eigen/Eigen/src/Core/arch/CUDA/MathFunctions.h +0 -91
  418. package/eigen/Eigen/src/Core/arch/CUDA/PacketMath.h +0 -333
  419. package/eigen/Eigen/src/Core/arch/CUDA/PacketMathHalf.h +0 -1124
  420. package/eigen/Eigen/src/Core/arch/CUDA/TypeCasting.h +0 -212
  421. package/eigen/Eigen/src/Core/util/NonMPL2.h +0 -3
  422. package/eigen/Eigen/src/Geometry/arch/Geometry_SSE.h +0 -161
  423. package/eigen/Eigen/src/LU/arch/Inverse_SSE.h +0 -338
  424. package/eigen/Eigen/src/SparseCore/MappedSparseMatrix.h +0 -67
  425. package/eigen/Eigen/src/SparseLU/SparseLU_gemm_kernel.h +0 -280
  426. package/eigen/Eigen/src/misc/lapack.h +0 -152
  427. package/eigen/Eigen/src/plugins/ArrayCwiseBinaryOps.h +0 -332
  428. package/eigen/Eigen/src/plugins/ArrayCwiseUnaryOps.h +0 -552
  429. package/eigen/Eigen/src/plugins/BlockMethods.h +0 -1058
  430. package/eigen/Eigen/src/plugins/CommonCwiseBinaryOps.h +0 -115
  431. package/eigen/Eigen/src/plugins/CommonCwiseUnaryOps.h +0 -163
  432. package/eigen/Eigen/src/plugins/MatrixCwiseBinaryOps.h +0 -152
  433. package/eigen/Eigen/src/plugins/MatrixCwiseUnaryOps.h +0 -85
  434. package/lib/eigen.d.ts +0 -2
  435. package/lib/eigen.js +0 -15
@@ -10,6 +10,9 @@
10
10
  #ifndef EIGEN_PACKET_MATH_AVX512_H
11
11
  #define EIGEN_PACKET_MATH_AVX512_H
12
12
 
13
+ // IWYU pragma: private
14
+ #include "../../InternalHeaderCheck.h"
15
+
13
16
  namespace Eigen {
14
17
 
15
18
  namespace internal {
@@ -31,6 +34,15 @@ namespace internal {
31
34
  typedef __m512 Packet16f;
32
35
  typedef __m512i Packet16i;
33
36
  typedef __m512d Packet8d;
37
+ typedef eigen_packet_wrapper<__m512i, 1> Packet8l;
38
+ #ifndef EIGEN_VECTORIZE_AVX512FP16
39
+ typedef eigen_packet_wrapper<__m256i, 1> Packet16h;
40
+ #endif
41
+ typedef eigen_packet_wrapper<__m256i, 2> Packet16bf;
42
+
43
+ typedef eigen_packet_wrapper<__m512i, 6> Packet32s;
44
+ typedef eigen_packet_wrapper<__m256i, 6> Packet16s;
45
+ typedef eigen_packet_wrapper<__m128i, 6> Packet8s;
34
46
 
35
47
  template <>
36
48
  struct is_arithmetic<__m512> {
@@ -44,75 +56,237 @@ template <>
44
56
  struct is_arithmetic<__m512d> {
45
57
  enum { value = true };
46
58
  };
59
+ template <>
60
+ struct is_arithmetic<Packet8l> {
61
+ enum { value = true };
62
+ };
63
+
64
+ #ifndef EIGEN_VECTORIZE_AVX512FP16
65
+ template <>
66
+ struct is_arithmetic<Packet16h> {
67
+ enum { value = true };
68
+ };
69
+
70
+ template <>
71
+ struct packet_traits<half> : default_packet_traits {
72
+ typedef Packet16h type;
73
+ // There is no half-size packet for Packet16h.
74
+ typedef Packet16h half;
75
+ enum {
76
+ Vectorizable = 1,
77
+ AlignedOnScalar = 1,
78
+ size = 16,
79
+
80
+ HasCmp = 1,
81
+ HasAdd = 1,
82
+ HasSub = 1,
83
+ HasMul = 1,
84
+ HasDiv = 1,
85
+ HasNegate = 1,
86
+ HasAbs = 1,
87
+ HasAbs2 = 0,
88
+ HasMin = 1,
89
+ HasMax = 1,
90
+ HasConj = 1,
91
+ HasSetLinear = 0,
92
+ HasSqrt = 1,
93
+ HasRsqrt = 1,
94
+ HasLog = 1,
95
+ HasLog1p = 1,
96
+ HasExp = 1,
97
+ HasExpm1 = 1,
98
+ HasBessel = 1,
99
+ HasNdtri = 1,
100
+ HasSin = EIGEN_FAST_MATH,
101
+ HasCos = EIGEN_FAST_MATH,
102
+ HasTanh = EIGEN_FAST_MATH,
103
+ HasErf = EIGEN_FAST_MATH,
104
+ HasBlend = 0
105
+ };
106
+ };
107
+ #endif
47
108
 
48
- template<> struct packet_traits<float> : default_packet_traits
49
- {
109
+ template <>
110
+ struct packet_traits<float> : default_packet_traits {
50
111
  typedef Packet16f type;
51
112
  typedef Packet8f half;
52
113
  enum {
53
114
  Vectorizable = 1,
54
115
  AlignedOnScalar = 1,
55
116
  size = 16,
56
- HasHalfPacket = 1,
57
- HasBlend = 0,
58
- #if EIGEN_GNUC_AT_LEAST(5, 3) || (!EIGEN_COMP_GNUC_STRICT)
59
- #ifdef EIGEN_VECTORIZE_AVX512DQ
117
+
118
+ HasAbs = 1,
119
+ HasMin = 1,
120
+ HasMax = 1,
121
+ HasConj = 1,
122
+ HasBlend = 1,
123
+ HasSin = EIGEN_FAST_MATH,
124
+ HasCos = EIGEN_FAST_MATH,
125
+ HasACos = 1,
126
+ HasASin = 1,
127
+ HasATan = 1,
128
+ HasATanh = 1,
129
+ HasSqrt = 1,
130
+ HasRsqrt = 1,
131
+ HasCbrt = 1,
60
132
  HasLog = 1,
61
- #endif
133
+ HasLog1p = 1,
134
+ HasExpm1 = 1,
135
+ HasNdtri = 1,
136
+ HasBessel = 1,
62
137
  HasExp = 1,
63
- HasSqrt = EIGEN_FAST_MATH,
64
- HasRsqrt = EIGEN_FAST_MATH,
65
- #endif
138
+ HasPow = 1,
139
+ HasReciprocal = EIGEN_FAST_MATH,
140
+ HasTanh = EIGEN_FAST_MATH,
141
+ HasErf = EIGEN_FAST_MATH,
142
+ HasErfc = EIGEN_FAST_MATH,
143
+ HasCmp = 1,
66
144
  HasDiv = 1
67
145
  };
68
- };
69
- template<> struct packet_traits<double> : default_packet_traits
70
- {
146
+ };
147
+ template <>
148
+ struct packet_traits<double> : default_packet_traits {
71
149
  typedef Packet8d type;
72
150
  typedef Packet4d half;
73
151
  enum {
74
152
  Vectorizable = 1,
75
153
  AlignedOnScalar = 1,
76
154
  size = 8,
77
- HasHalfPacket = 1,
78
- #if EIGEN_GNUC_AT_LEAST(5, 3) || (!EIGEN_COMP_GNUC_STRICT)
79
- HasSqrt = EIGEN_FAST_MATH,
80
- HasRsqrt = EIGEN_FAST_MATH,
81
- #endif
155
+ HasBlend = 1,
156
+ HasSqrt = 1,
157
+ HasRsqrt = 1,
158
+ HasCbrt = 1,
159
+ HasSin = EIGEN_FAST_MATH,
160
+ HasCos = EIGEN_FAST_MATH,
161
+ HasLog = 1,
162
+ HasExp = 1,
163
+ HasPow = 1,
164
+ HasATan = 1,
165
+ HasTanh = EIGEN_FAST_MATH,
166
+ HasErf = EIGEN_FAST_MATH,
167
+ HasErfc = EIGEN_FAST_MATH,
168
+ HasATanh = 1,
169
+ HasCmp = 1,
82
170
  HasDiv = 1
83
171
  };
84
172
  };
85
173
 
86
- /* TODO Implement AVX512 for integers
87
- template<> struct packet_traits<int> : default_packet_traits
88
- {
174
+ template <>
175
+ struct packet_traits<int> : default_packet_traits {
89
176
  typedef Packet16i type;
90
- enum {
91
- Vectorizable = 1,
92
- AlignedOnScalar = 1,
93
- size=8
94
- };
177
+ typedef Packet8i half;
178
+ enum { Vectorizable = 1, AlignedOnScalar = 1, HasBlend = 0, HasCmp = 1, HasDiv = 1, size = 16 };
179
+ };
180
+
181
+ template <>
182
+ struct packet_traits<int64_t> : default_packet_traits {
183
+ typedef Packet8l type;
184
+ typedef Packet4l half;
185
+ enum { Vectorizable = 1, AlignedOnScalar = 1, HasCmp = 1, size = 8 };
95
186
  };
96
- */
97
187
 
98
188
  template <>
99
189
  struct unpacket_traits<Packet16f> {
100
190
  typedef float type;
101
191
  typedef Packet8f half;
102
192
  typedef Packet16i integer_packet;
103
- enum { size = 16, alignment=Aligned64 };
193
+ typedef uint16_t mask_t;
194
+ enum {
195
+ size = 16,
196
+ alignment = Aligned64,
197
+ vectorizable = true,
198
+ masked_load_available = true,
199
+ masked_store_available = true,
200
+ masked_fpops_available = true
201
+ };
104
202
  };
105
203
  template <>
106
204
  struct unpacket_traits<Packet8d> {
107
205
  typedef double type;
108
206
  typedef Packet4d half;
109
- enum { size = 8, alignment=Aligned64 };
207
+ typedef Packet8l integer_packet;
208
+ typedef uint8_t mask_t;
209
+ enum {
210
+ size = 8,
211
+ alignment = Aligned64,
212
+ vectorizable = true,
213
+ masked_load_available = true,
214
+ masked_store_available = true,
215
+ masked_fpops_available = true
216
+ };
110
217
  };
111
218
  template <>
112
219
  struct unpacket_traits<Packet16i> {
113
220
  typedef int type;
114
221
  typedef Packet8i half;
115
- enum { size = 16, alignment=Aligned64 };
222
+ enum {
223
+ size = 16,
224
+ alignment = Aligned64,
225
+ vectorizable = true,
226
+ masked_load_available = false,
227
+ masked_store_available = false
228
+ };
229
+ };
230
+
231
+ template <>
232
+ struct unpacket_traits<Packet8l> {
233
+ typedef int64_t type;
234
+ typedef Packet4l half;
235
+ enum {
236
+ size = 8,
237
+ alignment = Aligned64,
238
+ vectorizable = true,
239
+ masked_load_available = false,
240
+ masked_store_available = false
241
+ };
242
+ };
243
+
244
+ #ifndef EIGEN_VECTORIZE_AVX512FP16
245
+ template <>
246
+ struct unpacket_traits<Packet16h> {
247
+ typedef Eigen::half type;
248
+ typedef Packet8h half;
249
+ enum {
250
+ size = 16,
251
+ alignment = Aligned32,
252
+ vectorizable = true,
253
+ masked_load_available = false,
254
+ masked_store_available = false
255
+ };
256
+ };
257
+ #endif
258
+
259
+ template <>
260
+ struct unpacket_traits<Packet32s> {
261
+ typedef numext::int16_t type;
262
+ typedef Packet16s half;
263
+ enum {
264
+ size = 32,
265
+ alignment = Aligned64,
266
+ vectorizable = false,
267
+ };
268
+ };
269
+
270
+ template <>
271
+ struct unpacket_traits<Packet16s> {
272
+ typedef numext::int16_t type;
273
+ typedef Packet8s half;
274
+ enum {
275
+ size = 16,
276
+ alignment = Aligned32,
277
+ vectorizable = false,
278
+ };
279
+ };
280
+
281
+ template <>
282
+ struct unpacket_traits<Packet8s> {
283
+ typedef numext::int16_t type;
284
+ typedef Packet8s half;
285
+ enum {
286
+ size = 8,
287
+ alignment = Aligned16,
288
+ vectorizable = false,
289
+ };
116
290
  };
117
291
 
118
292
  template <>
@@ -127,68 +301,166 @@ template <>
127
301
  EIGEN_STRONG_INLINE Packet16i pset1<Packet16i>(const int& from) {
128
302
  return _mm512_set1_epi32(from);
129
303
  }
304
+ template <>
305
+ EIGEN_STRONG_INLINE Packet8l pset1<Packet8l>(const int64_t& from) {
306
+ return _mm512_set1_epi64(from);
307
+ }
308
+
309
+ template <>
310
+ EIGEN_STRONG_INLINE Packet16f pset1frombits<Packet16f>(unsigned int from) {
311
+ return _mm512_castsi512_ps(_mm512_set1_epi32(from));
312
+ }
313
+
314
+ template <>
315
+ EIGEN_STRONG_INLINE Packet8d pset1frombits<Packet8d>(const numext::uint64_t from) {
316
+ return _mm512_castsi512_pd(_mm512_set1_epi64(from));
317
+ }
318
+
319
+ template <>
320
+ EIGEN_STRONG_INLINE Packet16f pzero(const Packet16f& /*a*/) {
321
+ return _mm512_setzero_ps();
322
+ }
323
+ template <>
324
+ EIGEN_STRONG_INLINE Packet8d pzero(const Packet8d& /*a*/) {
325
+ return _mm512_setzero_pd();
326
+ }
327
+ template <>
328
+ EIGEN_STRONG_INLINE Packet16i pzero(const Packet16i& /*a*/) {
329
+ return _mm512_setzero_si512();
330
+ }
331
+
332
+ template <>
333
+ EIGEN_STRONG_INLINE Packet8l pzero(const Packet8l& /*a*/) {
334
+ return _mm512_setzero_si512();
335
+ }
336
+
337
+ template <>
338
+ EIGEN_STRONG_INLINE Packet16f peven_mask(const Packet16f& /*a*/) {
339
+ return _mm512_castsi512_ps(_mm512_set_epi32(0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1));
340
+ }
341
+ template <>
342
+ EIGEN_STRONG_INLINE Packet16i peven_mask(const Packet16i& /*a*/) {
343
+ return _mm512_set_epi32(0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1, 0, -1);
344
+ }
345
+ template <>
346
+ EIGEN_STRONG_INLINE Packet8d peven_mask(const Packet8d& /*a*/) {
347
+ return _mm512_castsi512_pd(_mm512_set_epi32(0, 0, -1, -1, 0, 0, -1, -1, 0, 0, -1, -1, 0, 0, -1, -1));
348
+ }
349
+ template <>
350
+ EIGEN_STRONG_INLINE Packet8l peven_mask(const Packet8l& /*a*/) {
351
+ return _mm512_set_epi32(0, 0, -1, -1, 0, 0, -1, -1, 0, 0, -1, -1, 0, 0, -1, -1);
352
+ }
130
353
 
131
354
  template <>
132
355
  EIGEN_STRONG_INLINE Packet16f pload1<Packet16f>(const float* from) {
356
+ #if (EIGEN_COMP_GNUC != 0) || (EIGEN_COMP_CLANG != 0)
357
+ // Inline asm here helps reduce some register spilling in TRSM kernels.
358
+ // See note in unrolls::gemm::microKernel in TrsmKernel.h
359
+ Packet16f ret;
360
+ __asm__("vbroadcastss %[mem], %[dst]" : [dst] "=v"(ret) : [mem] "m"(*from));
361
+ return ret;
362
+ #else
133
363
  return _mm512_broadcastss_ps(_mm_load_ps1(from));
364
+ #endif
134
365
  }
135
366
  template <>
136
367
  EIGEN_STRONG_INLINE Packet8d pload1<Packet8d>(const double* from) {
368
+ #if (EIGEN_COMP_GNUC != 0) || (EIGEN_COMP_CLANG != 0)
369
+ Packet8d ret;
370
+ __asm__("vbroadcastsd %[mem], %[dst]" : [dst] "=v"(ret) : [mem] "m"(*from));
371
+ return ret;
372
+ #else
137
373
  return _mm512_set1_pd(*from);
374
+ #endif
138
375
  }
139
376
 
140
377
  template <>
141
378
  EIGEN_STRONG_INLINE Packet16f plset<Packet16f>(const float& a) {
142
- return _mm512_add_ps(
143
- _mm512_set1_ps(a),
144
- _mm512_set_ps(15.0f, 14.0f, 13.0f, 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 6.0f, 5.0f,
145
- 4.0f, 3.0f, 2.0f, 1.0f, 0.0f));
379
+ return _mm512_add_ps(_mm512_set1_ps(a), _mm512_set_ps(15.0f, 14.0f, 13.0f, 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
380
+ 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f, 0.0f));
146
381
  }
147
382
  template <>
148
383
  EIGEN_STRONG_INLINE Packet8d plset<Packet8d>(const double& a) {
149
- return _mm512_add_pd(_mm512_set1_pd(a),
150
- _mm512_set_pd(7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0));
384
+ return _mm512_add_pd(_mm512_set1_pd(a), _mm512_set_pd(7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0));
385
+ }
386
+ template <>
387
+ EIGEN_STRONG_INLINE Packet16i plset<Packet16i>(const int& a) {
388
+ return _mm512_add_epi32(_mm512_set1_epi32(a), _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0));
389
+ }
390
+ template <>
391
+ EIGEN_STRONG_INLINE Packet8l plset<Packet8l>(const int64_t& a) {
392
+ return _mm512_add_epi64(_mm512_set1_epi64(a), _mm512_set_epi64(7, 6, 5, 4, 3, 2, 1, 0));
151
393
  }
152
394
 
153
395
  template <>
154
- EIGEN_STRONG_INLINE Packet16f padd<Packet16f>(const Packet16f& a,
155
- const Packet16f& b) {
396
+ EIGEN_STRONG_INLINE Packet16f padd<Packet16f>(const Packet16f& a, const Packet16f& b) {
156
397
  return _mm512_add_ps(a, b);
157
398
  }
158
399
  template <>
159
- EIGEN_STRONG_INLINE Packet8d padd<Packet8d>(const Packet8d& a,
160
- const Packet8d& b) {
400
+ EIGEN_STRONG_INLINE Packet8d padd<Packet8d>(const Packet8d& a, const Packet8d& b) {
161
401
  return _mm512_add_pd(a, b);
162
402
  }
163
403
  template <>
164
- EIGEN_STRONG_INLINE Packet16i padd<Packet16i>(const Packet16i& a,
165
- const Packet16i& b) {
404
+ EIGEN_STRONG_INLINE Packet16i padd<Packet16i>(const Packet16i& a, const Packet16i& b) {
166
405
  return _mm512_add_epi32(a, b);
167
406
  }
407
+ template <>
408
+ EIGEN_STRONG_INLINE Packet8l padd<Packet8l>(const Packet8l& a, const Packet8l& b) {
409
+ return _mm512_add_epi64(a, b);
410
+ }
411
+
412
+ template <>
413
+ EIGEN_STRONG_INLINE Packet16f padd<Packet16f>(const Packet16f& a, const Packet16f& b, uint16_t umask) {
414
+ __mmask16 mask = static_cast<__mmask16>(umask);
415
+ return _mm512_maskz_add_ps(mask, a, b);
416
+ }
417
+ template <>
418
+ EIGEN_STRONG_INLINE Packet8d padd<Packet8d>(const Packet8d& a, const Packet8d& b, uint8_t umask) {
419
+ __mmask8 mask = static_cast<__mmask8>(umask);
420
+ return _mm512_maskz_add_pd(mask, a, b);
421
+ }
168
422
 
169
423
  template <>
170
- EIGEN_STRONG_INLINE Packet16f psub<Packet16f>(const Packet16f& a,
171
- const Packet16f& b) {
424
+ EIGEN_STRONG_INLINE Packet16f psub<Packet16f>(const Packet16f& a, const Packet16f& b) {
172
425
  return _mm512_sub_ps(a, b);
173
426
  }
174
427
  template <>
175
- EIGEN_STRONG_INLINE Packet8d psub<Packet8d>(const Packet8d& a,
176
- const Packet8d& b) {
428
+ EIGEN_STRONG_INLINE Packet8d psub<Packet8d>(const Packet8d& a, const Packet8d& b) {
177
429
  return _mm512_sub_pd(a, b);
178
430
  }
179
431
  template <>
180
- EIGEN_STRONG_INLINE Packet16i psub<Packet16i>(const Packet16i& a,
181
- const Packet16i& b) {
432
+ EIGEN_STRONG_INLINE Packet16i psub<Packet16i>(const Packet16i& a, const Packet16i& b) {
182
433
  return _mm512_sub_epi32(a, b);
183
434
  }
435
+ template <>
436
+ EIGEN_STRONG_INLINE Packet8l psub<Packet8l>(const Packet8l& a, const Packet8l& b) {
437
+ return _mm512_sub_epi64(a, b);
438
+ }
184
439
 
185
440
  template <>
186
441
  EIGEN_STRONG_INLINE Packet16f pnegate(const Packet16f& a) {
187
- return _mm512_sub_ps(_mm512_set1_ps(0.0), a);
442
+ // NOTE: MSVC seems to struggle with _mm512_set1_epi32, leading to random results.
443
+ // The intel docs give it a relatively high latency as well, so we're probably
444
+ // better off with using _mm512_set_epi32 directly anyways.
445
+ const __m512i mask =
446
+ _mm512_set_epi32(0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000,
447
+ 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000);
448
+ return _mm512_castsi512_ps(_mm512_xor_epi32(_mm512_castps_si512(a), mask));
188
449
  }
189
450
  template <>
190
451
  EIGEN_STRONG_INLINE Packet8d pnegate(const Packet8d& a) {
191
- return _mm512_sub_pd(_mm512_set1_pd(0.0), a);
452
+ const __m512i mask =
453
+ _mm512_set_epi64(0x8000000000000000ULL, 0x8000000000000000ULL, 0x8000000000000000ULL, 0x8000000000000000ULL,
454
+ 0x8000000000000000ULL, 0x8000000000000000ULL, 0x8000000000000000ULL, 0x8000000000000000ULL);
455
+ return _mm512_castsi512_pd(_mm512_xor_epi64(_mm512_castpd_si512(a), mask));
456
+ }
457
+ template <>
458
+ EIGEN_STRONG_INLINE Packet16i pnegate(const Packet16i& a) {
459
+ return _mm512_sub_epi32(_mm512_setzero_si512(), a);
460
+ }
461
+ template <>
462
+ EIGEN_STRONG_INLINE Packet8l pnegate(const Packet8l& a) {
463
+ return _mm512_sub_epi64(_mm512_setzero_si512(), a);
192
464
  }
193
465
 
194
466
  template <>
@@ -203,91 +475,217 @@ template <>
203
475
  EIGEN_STRONG_INLINE Packet16i pconj(const Packet16i& a) {
204
476
  return a;
205
477
  }
478
+ template <>
479
+ EIGEN_STRONG_INLINE Packet8l pconj(const Packet8l& a) {
480
+ return a;
481
+ }
206
482
 
207
483
  template <>
208
- EIGEN_STRONG_INLINE Packet16f pmul<Packet16f>(const Packet16f& a,
209
- const Packet16f& b) {
484
+ EIGEN_STRONG_INLINE Packet16f pmul<Packet16f>(const Packet16f& a, const Packet16f& b) {
210
485
  return _mm512_mul_ps(a, b);
211
486
  }
212
487
  template <>
213
- EIGEN_STRONG_INLINE Packet8d pmul<Packet8d>(const Packet8d& a,
214
- const Packet8d& b) {
488
+ EIGEN_STRONG_INLINE Packet8d pmul<Packet8d>(const Packet8d& a, const Packet8d& b) {
215
489
  return _mm512_mul_pd(a, b);
216
490
  }
217
491
  template <>
218
- EIGEN_STRONG_INLINE Packet16i pmul<Packet16i>(const Packet16i& a,
219
- const Packet16i& b) {
220
- return _mm512_mul_epi32(a, b);
492
+ EIGEN_STRONG_INLINE Packet16i pmul<Packet16i>(const Packet16i& a, const Packet16i& b) {
493
+ return _mm512_mullo_epi32(a, b);
494
+ }
495
+ template <>
496
+ EIGEN_STRONG_INLINE Packet8l pmul<Packet8l>(const Packet8l& a, const Packet8l& b) {
497
+ #ifdef EIGEN_VECTORIZE_AVX512DQ
498
+ return _mm512_mullo_epi64(a, b);
499
+ #else
500
+ return _mm512_mullox_epi64(a, b);
501
+ #endif
221
502
  }
222
503
 
223
504
  template <>
224
- EIGEN_STRONG_INLINE Packet16f pdiv<Packet16f>(const Packet16f& a,
225
- const Packet16f& b) {
505
+ EIGEN_STRONG_INLINE Packet16f pdiv<Packet16f>(const Packet16f& a, const Packet16f& b) {
226
506
  return _mm512_div_ps(a, b);
227
507
  }
508
+
228
509
  template <>
229
- EIGEN_STRONG_INLINE Packet8d pdiv<Packet8d>(const Packet8d& a,
230
- const Packet8d& b) {
510
+ EIGEN_STRONG_INLINE Packet8d pdiv<Packet8d>(const Packet8d& a, const Packet8d& b) {
231
511
  return _mm512_div_pd(a, b);
232
512
  }
233
513
 
514
+ template <>
515
+ EIGEN_STRONG_INLINE Packet16i pdiv<Packet16i>(const Packet16i& a, const Packet16i& b) {
516
+ Packet8i q_lo = pdiv<Packet8i>(_mm512_extracti64x4_epi64(a, 0), _mm512_extracti64x4_epi64(b, 0));
517
+ Packet8i q_hi = pdiv<Packet8i>(_mm512_extracti64x4_epi64(a, 1), _mm512_extracti64x4_epi64(b, 1));
518
+ return _mm512_inserti64x4(_mm512_castsi256_si512(q_lo), q_hi, 1);
519
+ }
520
+
234
521
  #ifdef EIGEN_VECTORIZE_FMA
235
522
  template <>
236
- EIGEN_STRONG_INLINE Packet16f pmadd(const Packet16f& a, const Packet16f& b,
237
- const Packet16f& c) {
523
+ EIGEN_STRONG_INLINE Packet16f pmadd(const Packet16f& a, const Packet16f& b, const Packet16f& c) {
238
524
  return _mm512_fmadd_ps(a, b, c);
239
525
  }
240
526
  template <>
241
- EIGEN_STRONG_INLINE Packet8d pmadd(const Packet8d& a, const Packet8d& b,
242
- const Packet8d& c) {
527
+ EIGEN_STRONG_INLINE Packet8d pmadd(const Packet8d& a, const Packet8d& b, const Packet8d& c) {
243
528
  return _mm512_fmadd_pd(a, b, c);
244
529
  }
530
+
531
+ template <>
532
+ EIGEN_STRONG_INLINE Packet16f pmsub(const Packet16f& a, const Packet16f& b, const Packet16f& c) {
533
+ return _mm512_fmsub_ps(a, b, c);
534
+ }
535
+ template <>
536
+ EIGEN_STRONG_INLINE Packet8d pmsub(const Packet8d& a, const Packet8d& b, const Packet8d& c) {
537
+ return _mm512_fmsub_pd(a, b, c);
538
+ }
539
+
540
+ template <>
541
+ EIGEN_STRONG_INLINE Packet16f pnmadd(const Packet16f& a, const Packet16f& b, const Packet16f& c) {
542
+ return _mm512_fnmadd_ps(a, b, c);
543
+ }
544
+ template <>
545
+ EIGEN_STRONG_INLINE Packet8d pnmadd(const Packet8d& a, const Packet8d& b, const Packet8d& c) {
546
+ return _mm512_fnmadd_pd(a, b, c);
547
+ }
548
+
549
+ template <>
550
+ EIGEN_STRONG_INLINE Packet16f pnmsub(const Packet16f& a, const Packet16f& b, const Packet16f& c) {
551
+ return _mm512_fnmsub_ps(a, b, c);
552
+ }
553
+ template <>
554
+ EIGEN_STRONG_INLINE Packet8d pnmsub(const Packet8d& a, const Packet8d& b, const Packet8d& c) {
555
+ return _mm512_fnmsub_pd(a, b, c);
556
+ }
245
557
  #endif
246
558
 
247
559
  template <>
248
- EIGEN_STRONG_INLINE Packet16f pmin<Packet16f>(const Packet16f& a,
249
- const Packet16f& b) {
560
+ EIGEN_DEVICE_FUNC inline Packet16f pselect(const Packet16f& mask, const Packet16f& a, const Packet16f& b) {
561
+ __mmask16 mask16 = _mm512_cmpeq_epi32_mask(_mm512_castps_si512(mask), _mm512_setzero_epi32());
562
+ return _mm512_mask_blend_ps(mask16, a, b);
563
+ }
564
+
565
+ template <>
566
+ EIGEN_DEVICE_FUNC inline Packet16i pselect(const Packet16i& mask, const Packet16i& a, const Packet16i& b) {
567
+ __mmask16 mask16 = _mm512_cmpeq_epi32_mask(mask, _mm512_setzero_epi32());
568
+ return _mm512_mask_blend_epi32(mask16, a, b);
569
+ }
570
+
571
+ template <>
572
+ EIGEN_DEVICE_FUNC inline Packet8l pselect(const Packet8l& mask, const Packet8l& a, const Packet8l& b) {
573
+ __mmask8 mask8 = _mm512_cmpeq_epi64_mask(mask, _mm512_setzero_si512());
574
+ return _mm512_mask_blend_epi64(mask8, a, b);
575
+ }
576
+
577
+ template <>
578
+ EIGEN_DEVICE_FUNC inline Packet8d pselect(const Packet8d& mask, const Packet8d& a, const Packet8d& b) {
579
+ __mmask8 mask8 = _mm512_cmp_epi64_mask(_mm512_castpd_si512(mask), _mm512_setzero_epi32(), _MM_CMPINT_EQ);
580
+ return _mm512_mask_blend_pd(mask8, a, b);
581
+ }
582
+
583
+ template <>
584
+ EIGEN_STRONG_INLINE Packet16f pmin<Packet16f>(const Packet16f& a, const Packet16f& b) {
250
585
  // Arguments are reversed to match NaN propagation behavior of std::min.
251
586
  return _mm512_min_ps(b, a);
252
587
  }
253
588
  template <>
254
- EIGEN_STRONG_INLINE Packet8d pmin<Packet8d>(const Packet8d& a,
255
- const Packet8d& b) {
589
+ EIGEN_STRONG_INLINE Packet8d pmin<Packet8d>(const Packet8d& a, const Packet8d& b) {
256
590
  // Arguments are reversed to match NaN propagation behavior of std::min.
257
591
  return _mm512_min_pd(b, a);
258
592
  }
593
+ template <>
594
+ EIGEN_STRONG_INLINE Packet16i pmin<Packet16i>(const Packet16i& a, const Packet16i& b) {
595
+ return _mm512_min_epi32(b, a);
596
+ }
597
+ template <>
598
+ EIGEN_STRONG_INLINE Packet8l pmin<Packet8l>(const Packet8l& a, const Packet8l& b) {
599
+ return _mm512_min_epi64(b, a);
600
+ }
259
601
 
260
602
  template <>
261
- EIGEN_STRONG_INLINE Packet16f pmax<Packet16f>(const Packet16f& a,
262
- const Packet16f& b) {
603
+ EIGEN_STRONG_INLINE Packet16f pmax<Packet16f>(const Packet16f& a, const Packet16f& b) {
263
604
  // Arguments are reversed to match NaN propagation behavior of std::max.
264
605
  return _mm512_max_ps(b, a);
265
606
  }
266
607
  template <>
267
- EIGEN_STRONG_INLINE Packet8d pmax<Packet8d>(const Packet8d& a,
268
- const Packet8d& b) {
608
+ EIGEN_STRONG_INLINE Packet8d pmax<Packet8d>(const Packet8d& a, const Packet8d& b) {
269
609
  // Arguments are reversed to match NaN propagation behavior of std::max.
270
610
  return _mm512_max_pd(b, a);
271
611
  }
612
+ template <>
613
+ EIGEN_STRONG_INLINE Packet16i pmax<Packet16i>(const Packet16i& a, const Packet16i& b) {
614
+ return _mm512_max_epi32(b, a);
615
+ }
616
+ template <>
617
+ EIGEN_STRONG_INLINE Packet8l pmax<Packet8l>(const Packet8l& a, const Packet8l& b) {
618
+ return _mm512_max_epi64(b, a);
619
+ }
620
+
621
+ // Add specializations for min/max with prescribed NaN propagation.
622
+ template <>
623
+ EIGEN_STRONG_INLINE Packet16f pmin<PropagateNumbers, Packet16f>(const Packet16f& a, const Packet16f& b) {
624
+ return pminmax_propagate_numbers(a, b, pmin<Packet16f>);
625
+ }
626
+ template <>
627
+ EIGEN_STRONG_INLINE Packet8d pmin<PropagateNumbers, Packet8d>(const Packet8d& a, const Packet8d& b) {
628
+ return pminmax_propagate_numbers(a, b, pmin<Packet8d>);
629
+ }
630
+ template <>
631
+ EIGEN_STRONG_INLINE Packet16f pmax<PropagateNumbers, Packet16f>(const Packet16f& a, const Packet16f& b) {
632
+ return pminmax_propagate_numbers(a, b, pmax<Packet16f>);
633
+ }
634
+ template <>
635
+ EIGEN_STRONG_INLINE Packet8d pmax<PropagateNumbers, Packet8d>(const Packet8d& a, const Packet8d& b) {
636
+ return pminmax_propagate_numbers(a, b, pmax<Packet8d>);
637
+ }
638
+ template <>
639
+ EIGEN_STRONG_INLINE Packet16f pmin<PropagateNaN, Packet16f>(const Packet16f& a, const Packet16f& b) {
640
+ return pminmax_propagate_nan(a, b, pmin<Packet16f>);
641
+ }
642
+ template <>
643
+ EIGEN_STRONG_INLINE Packet8d pmin<PropagateNaN, Packet8d>(const Packet8d& a, const Packet8d& b) {
644
+ return pminmax_propagate_nan(a, b, pmin<Packet8d>);
645
+ }
646
+ template <>
647
+ EIGEN_STRONG_INLINE Packet16f pmax<PropagateNaN, Packet16f>(const Packet16f& a, const Packet16f& b) {
648
+ return pminmax_propagate_nan(a, b, pmax<Packet16f>);
649
+ }
650
+ template <>
651
+ EIGEN_STRONG_INLINE Packet8d pmax<PropagateNaN, Packet8d>(const Packet8d& a, const Packet8d& b) {
652
+ return pminmax_propagate_nan(a, b, pmax<Packet8d>);
653
+ }
272
654
 
273
655
  #ifdef EIGEN_VECTORIZE_AVX512DQ
274
- template<int I_> EIGEN_STRONG_INLINE Packet8f extract256(Packet16f x) { return _mm512_extractf32x8_ps(x,I_); }
275
- template<int I_> EIGEN_STRONG_INLINE Packet2d extract128(Packet8d x) { return _mm512_extractf64x2_pd(x,I_); }
276
- EIGEN_STRONG_INLINE Packet16f cat256(Packet8f a, Packet8f b) { return _mm512_insertf32x8(_mm512_castps256_ps512(a),b,1); }
656
+ template <int I_>
657
+ EIGEN_STRONG_INLINE Packet8f extract256(Packet16f x) {
658
+ return _mm512_extractf32x8_ps(x, I_);
659
+ }
660
+ template <int I_>
661
+ EIGEN_STRONG_INLINE Packet2d extract128(Packet8d x) {
662
+ return _mm512_extractf64x2_pd(x, I_);
663
+ }
664
+ EIGEN_STRONG_INLINE Packet16f cat256(Packet8f a, Packet8f b) {
665
+ return _mm512_insertf32x8(_mm512_castps256_ps512(a), b, 1);
666
+ }
667
+ EIGEN_STRONG_INLINE Packet16i cat256i(Packet8i a, Packet8i b) {
668
+ return _mm512_inserti32x8(_mm512_castsi256_si512(a), b, 1);
669
+ }
277
670
  #else
278
671
  // AVX512F does not define _mm512_extractf32x8_ps to extract _m256 from _m512
279
- template<int I_> EIGEN_STRONG_INLINE Packet8f extract256(Packet16f x) {
280
- return _mm256_castsi256_ps(_mm512_extracti64x4_epi64( _mm512_castps_si512(x),I_));
672
+ template <int I_>
673
+ EIGEN_STRONG_INLINE Packet8f extract256(Packet16f x) {
674
+ return _mm256_castsi256_ps(_mm512_extracti64x4_epi64(_mm512_castps_si512(x), I_));
281
675
  }
282
676
 
283
677
  // AVX512F does not define _mm512_extractf64x2_pd to extract _m128 from _m512
284
- template<int I_> EIGEN_STRONG_INLINE Packet2d extract128(Packet8d x) {
285
- return _mm_castsi128_pd(_mm512_extracti32x4_epi32( _mm512_castpd_si512(x),I_));
678
+ template <int I_>
679
+ EIGEN_STRONG_INLINE Packet2d extract128(Packet8d x) {
680
+ return _mm_castsi128_pd(_mm512_extracti32x4_epi32(_mm512_castpd_si512(x), I_));
286
681
  }
287
682
 
288
683
  EIGEN_STRONG_INLINE Packet16f cat256(Packet8f a, Packet8f b) {
289
- return _mm512_castsi512_ps(_mm512_inserti64x4(_mm512_castsi256_si512(_mm256_castps_si256(a)),
290
- _mm256_castps_si256(b),1));
684
+ return _mm512_castsi512_ps(
685
+ _mm512_inserti64x4(_mm512_castsi256_si512(_mm256_castps_si256(a)), _mm256_castps_si256(b), 1));
686
+ }
687
+ EIGEN_STRONG_INLINE Packet16i cat256i(Packet8i a, Packet8i b) {
688
+ return _mm512_inserti64x4(_mm512_castsi256_si512(a), b, 1);
291
689
  }
292
690
  #endif
293
691
 
@@ -303,126 +701,311 @@ EIGEN_STRONG_INLINE __m256i Pack32To16(Packet16f rf) {
303
701
  // dst[255:240] := Saturate16(rf[255:224])
304
702
  __m256i lo = _mm256_castps_si256(extract256<0>(rf));
305
703
  __m256i hi = _mm256_castps_si256(extract256<1>(rf));
306
- __m128i result_lo = _mm_packs_epi32(_mm256_extractf128_si256(lo, 0),
307
- _mm256_extractf128_si256(lo, 1));
308
- __m128i result_hi = _mm_packs_epi32(_mm256_extractf128_si256(hi, 0),
309
- _mm256_extractf128_si256(hi, 1));
704
+ __m128i result_lo = _mm_packs_epi32(_mm256_extractf128_si256(lo, 0), _mm256_extractf128_si256(lo, 1));
705
+ __m128i result_hi = _mm_packs_epi32(_mm256_extractf128_si256(hi, 0), _mm256_extractf128_si256(hi, 1));
310
706
  return _mm256_insertf128_si256(_mm256_castsi128_si256(result_lo), result_hi, 1);
311
707
  }
312
708
 
313
709
  template <>
314
- EIGEN_STRONG_INLINE Packet16i pand<Packet16i>(const Packet16i& a,
315
- const Packet16i& b) {
316
- return _mm512_and_si512(a,b);
710
+ EIGEN_STRONG_INLINE Packet16f pisnan(const Packet16f& a) {
711
+ __mmask16 mask = _mm512_cmp_ps_mask(a, a, _CMP_UNORD_Q);
712
+ return _mm512_castsi512_ps(_mm512_maskz_set1_epi32(mask, int32_t(-1)));
317
713
  }
318
714
 
319
715
  template <>
320
- EIGEN_STRONG_INLINE Packet16f pand<Packet16f>(const Packet16f& a,
321
- const Packet16f& b) {
322
- #ifdef EIGEN_VECTORIZE_AVX512DQ
323
- return _mm512_and_ps(a, b);
324
- #else
325
- return _mm512_castsi512_ps(pand(_mm512_castps_si512(a),_mm512_castps_si512(b)));
326
- #endif
716
+ EIGEN_STRONG_INLINE Packet16f pcmp_eq(const Packet16f& a, const Packet16f& b) {
717
+ __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_EQ_OQ);
718
+ return _mm512_castsi512_ps(_mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, int32_t(-1)));
327
719
  }
328
720
  template <>
329
- EIGEN_STRONG_INLINE Packet8d pand<Packet8d>(const Packet8d& a,
330
- const Packet8d& b) {
331
- #ifdef EIGEN_VECTORIZE_AVX512DQ
332
- return _mm512_and_pd(a, b);
333
- #else
334
- Packet8d res = _mm512_undefined_pd();
335
- Packet4d lane0_a = _mm512_extractf64x4_pd(a, 0);
336
- Packet4d lane0_b = _mm512_extractf64x4_pd(b, 0);
337
- res = _mm512_insertf64x4(res, _mm256_and_pd(lane0_a, lane0_b), 0);
338
-
339
- Packet4d lane1_a = _mm512_extractf64x4_pd(a, 1);
340
- Packet4d lane1_b = _mm512_extractf64x4_pd(b, 1);
341
- return _mm512_insertf64x4(res, _mm256_and_pd(lane1_a, lane1_b), 1);
342
- #endif
721
+ EIGEN_STRONG_INLINE Packet16f pcmp_le(const Packet16f& a, const Packet16f& b) {
722
+ __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_LE_OQ);
723
+ return _mm512_castsi512_ps(_mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, int32_t(-1)));
343
724
  }
344
725
 
345
726
  template <>
346
- EIGEN_STRONG_INLINE Packet16i por<Packet16i>(const Packet16i& a, const Packet16i& b) {
347
- return _mm512_or_si512(a, b);
727
+ EIGEN_STRONG_INLINE Packet16f pcmp_lt(const Packet16f& a, const Packet16f& b) {
728
+ __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_LT_OQ);
729
+ return _mm512_castsi512_ps(_mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, int32_t(-1)));
348
730
  }
349
731
 
350
732
  template <>
351
- EIGEN_STRONG_INLINE Packet16f por<Packet16f>(const Packet16f& a, const Packet16f& b) {
352
- #ifdef EIGEN_VECTORIZE_AVX512DQ
353
- return _mm512_or_ps(a, b);
354
- #else
355
- return _mm512_castsi512_ps(por(_mm512_castps_si512(a),_mm512_castps_si512(b)));
356
- #endif
733
+ EIGEN_STRONG_INLINE Packet16f pcmp_lt_or_nan(const Packet16f& a, const Packet16f& b) {
734
+ __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_NGE_UQ);
735
+ return _mm512_castsi512_ps(_mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, int32_t(-1)));
357
736
  }
358
737
 
359
738
  template <>
360
- EIGEN_STRONG_INLINE Packet8d por<Packet8d>(const Packet8d& a,
361
- const Packet8d& b) {
362
- #ifdef EIGEN_VECTORIZE_AVX512DQ
363
- return _mm512_or_pd(a, b);
364
- #else
365
- return _mm512_castsi512_pd(por(_mm512_castpd_si512(a),_mm512_castpd_si512(b)));
366
- #endif
739
+ EIGEN_STRONG_INLINE Packet16i pcmp_eq(const Packet16i& a, const Packet16i& b) {
740
+ __mmask16 mask = _mm512_cmp_epi32_mask(a, b, _MM_CMPINT_EQ);
741
+ return _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, int32_t(-1));
367
742
  }
368
-
369
743
  template <>
370
- EIGEN_STRONG_INLINE Packet16i pxor<Packet16i>(const Packet16i& a, const Packet16i& b) {
371
- return _mm512_xor_si512(a, b);
744
+ EIGEN_STRONG_INLINE Packet16i pcmp_le(const Packet16i& a, const Packet16i& b) {
745
+ __mmask16 mask = _mm512_cmp_epi32_mask(a, b, _MM_CMPINT_LE);
746
+ return _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, int32_t(-1));
372
747
  }
373
-
374
748
  template <>
375
- EIGEN_STRONG_INLINE Packet16f pxor<Packet16f>(const Packet16f& a, const Packet16f& b) {
376
- #ifdef EIGEN_VECTORIZE_AVX512DQ
377
- return _mm512_xor_ps(a, b);
378
- #else
379
- return _mm512_castsi512_ps(pxor(_mm512_castps_si512(a),_mm512_castps_si512(b)));
380
- #endif
749
+ EIGEN_STRONG_INLINE Packet16i pcmp_lt(const Packet16i& a, const Packet16i& b) {
750
+ __mmask16 mask = _mm512_cmp_epi32_mask(a, b, _MM_CMPINT_LT);
751
+ return _mm512_mask_set1_epi32(_mm512_setzero_epi32(), mask, int32_t(-1));
381
752
  }
382
753
 
383
754
  template <>
384
- EIGEN_STRONG_INLINE Packet8d pxor<Packet8d>(const Packet8d& a, const Packet8d& b) {
385
- #ifdef EIGEN_VECTORIZE_AVX512DQ
386
- return _mm512_xor_pd(a, b);
387
- #else
388
- return _mm512_castsi512_pd(pxor(_mm512_castpd_si512(a),_mm512_castpd_si512(b)));
389
- #endif
755
+ EIGEN_STRONG_INLINE Packet8l pcmp_eq(const Packet8l& a, const Packet8l& b) {
756
+ __mmask8 mask = _mm512_cmp_epi64_mask(a, b, _MM_CMPINT_EQ);
757
+ return _mm512_mask_set1_epi64(_mm512_setzero_si512(), mask, int64_t(-1));
390
758
  }
391
-
392
759
  template <>
393
- EIGEN_STRONG_INLINE Packet16i pandnot<Packet16i>(const Packet16i& a, const Packet16i& b) {
394
- return _mm512_andnot_si512(b, a);
760
+ EIGEN_STRONG_INLINE Packet8l pcmp_le(const Packet8l& a, const Packet8l& b) {
761
+ __mmask8 mask = _mm512_cmp_epi64_mask(a, b, _MM_CMPINT_LE);
762
+ return _mm512_mask_set1_epi64(_mm512_setzero_si512(), mask, int64_t(-1));
763
+ }
764
+ template <>
765
+ EIGEN_STRONG_INLINE Packet8l pcmp_lt(const Packet8l& a, const Packet8l& b) {
766
+ __mmask8 mask = _mm512_cmp_epi64_mask(a, b, _MM_CMPINT_LT);
767
+ return _mm512_mask_set1_epi64(_mm512_setzero_si512(), mask, int64_t(-1));
395
768
  }
396
769
 
397
770
  template <>
398
- EIGEN_STRONG_INLINE Packet16f pandnot<Packet16f>(const Packet16f& a, const Packet16f& b) {
399
- #ifdef EIGEN_VECTORIZE_AVX512DQ
400
- return _mm512_andnot_ps(b, a);
401
- #else
402
- return _mm512_castsi512_ps(pandnot(_mm512_castps_si512(a),_mm512_castps_si512(b)));
403
- #endif
771
+ EIGEN_STRONG_INLINE Packet8d pcmp_eq(const Packet8d& a, const Packet8d& b) {
772
+ __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_EQ_OQ);
773
+ return _mm512_castsi512_pd(_mm512_mask_set1_epi64(_mm512_setzero_epi32(), mask, 0xffffffffffffffffu));
404
774
  }
405
775
  template <>
406
- EIGEN_STRONG_INLINE Packet8d pandnot<Packet8d>(const Packet8d& a,const Packet8d& b) {
407
- #ifdef EIGEN_VECTORIZE_AVX512DQ
408
- return _mm512_andnot_pd(b, a);
409
- #else
410
- return _mm512_castsi512_pd(pandnot(_mm512_castpd_si512(a),_mm512_castpd_si512(b)));
411
- #endif
776
+ EIGEN_STRONG_INLINE Packet8d pcmp_le(const Packet8d& a, const Packet8d& b) {
777
+ __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_LE_OQ);
778
+ return _mm512_castsi512_pd(_mm512_mask_set1_epi64(_mm512_setzero_epi32(), mask, 0xffffffffffffffffu));
412
779
  }
413
-
414
- template<int N> EIGEN_STRONG_INLINE Packet16i parithmetic_shift_right(Packet16i a) {
415
- return _mm512_srai_epi32(a, N);
780
+ template <>
781
+ EIGEN_STRONG_INLINE Packet8d pcmp_lt(const Packet8d& a, const Packet8d& b) {
782
+ __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_LT_OQ);
783
+ return _mm512_castsi512_pd(_mm512_mask_set1_epi64(_mm512_setzero_epi32(), mask, 0xffffffffffffffffu));
784
+ }
785
+ template <>
786
+ EIGEN_STRONG_INLINE Packet8d pcmp_lt_or_nan(const Packet8d& a, const Packet8d& b) {
787
+ __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_NGE_UQ);
788
+ return _mm512_castsi512_pd(_mm512_mask_set1_epi64(_mm512_setzero_epi32(), mask, 0xffffffffffffffffu));
789
+ }
790
+
791
+ template <>
792
+ EIGEN_STRONG_INLINE Packet16f print<Packet16f>(const Packet16f& a) {
793
+ return _mm512_roundscale_ps(a, _MM_FROUND_CUR_DIRECTION);
794
+ }
795
+ template <>
796
+ EIGEN_STRONG_INLINE Packet8d print<Packet8d>(const Packet8d& a) {
797
+ return _mm512_roundscale_pd(a, _MM_FROUND_CUR_DIRECTION);
798
+ }
799
+
800
+ template <>
801
+ EIGEN_STRONG_INLINE Packet16f pceil<Packet16f>(const Packet16f& a) {
802
+ return _mm512_roundscale_ps(a, _MM_FROUND_TO_POS_INF);
803
+ }
804
+ template <>
805
+ EIGEN_STRONG_INLINE Packet8d pceil<Packet8d>(const Packet8d& a) {
806
+ return _mm512_roundscale_pd(a, _MM_FROUND_TO_POS_INF);
807
+ }
808
+
809
+ template <>
810
+ EIGEN_STRONG_INLINE Packet16f pfloor<Packet16f>(const Packet16f& a) {
811
+ return _mm512_roundscale_ps(a, _MM_FROUND_TO_NEG_INF);
812
+ }
813
+ template <>
814
+ EIGEN_STRONG_INLINE Packet8d pfloor<Packet8d>(const Packet8d& a) {
815
+ return _mm512_roundscale_pd(a, _MM_FROUND_TO_NEG_INF);
816
+ }
817
+
818
+ template <>
819
+ EIGEN_STRONG_INLINE Packet16f ptrunc<Packet16f>(const Packet16f& a) {
820
+ return _mm512_roundscale_ps(a, _MM_FROUND_TO_ZERO);
821
+ }
822
+ template <>
823
+ EIGEN_STRONG_INLINE Packet8d ptrunc<Packet8d>(const Packet8d& a) {
824
+ return _mm512_roundscale_pd(a, _MM_FROUND_TO_ZERO);
825
+ }
826
+
827
+ template <>
828
+ EIGEN_STRONG_INLINE Packet16i ptrue<Packet16i>(const Packet16i& /*a*/) {
829
+ return _mm512_set1_epi32(int32_t(-1));
830
+ }
831
+
832
+ template <>
833
+ EIGEN_STRONG_INLINE Packet8l ptrue<Packet8l>(const Packet8l& /*a*/) {
834
+ return _mm512_set1_epi64(int64_t(-1));
835
+ }
836
+
837
+ template <>
838
+ EIGEN_STRONG_INLINE Packet16f ptrue<Packet16f>(const Packet16f& a) {
839
+ return _mm512_castsi512_ps(ptrue<Packet16i>(_mm512_castps_si512(a)));
840
+ }
841
+
842
+ template <>
843
+ EIGEN_STRONG_INLINE Packet8d ptrue<Packet8d>(const Packet8d& a) {
844
+ return _mm512_castsi512_pd(ptrue<Packet16i>(_mm512_castpd_si512(a)));
845
+ }
846
+
847
+ template <>
848
+ EIGEN_STRONG_INLINE Packet16i pand<Packet16i>(const Packet16i& a, const Packet16i& b) {
849
+ return _mm512_and_si512(a, b);
850
+ }
851
+
852
+ template <>
853
+ EIGEN_STRONG_INLINE Packet8l pand<Packet8l>(const Packet8l& a, const Packet8l& b) {
854
+ return _mm512_and_si512(a, b);
855
+ }
856
+
857
+ template <>
858
+ EIGEN_STRONG_INLINE Packet16f pand<Packet16f>(const Packet16f& a, const Packet16f& b) {
859
+ #ifdef EIGEN_VECTORIZE_AVX512DQ
860
+ return _mm512_and_ps(a, b);
861
+ #else
862
+ return _mm512_castsi512_ps(pand(_mm512_castps_si512(a), _mm512_castps_si512(b)));
863
+ #endif
864
+ }
865
+ template <>
866
+ EIGEN_STRONG_INLINE Packet8d pand<Packet8d>(const Packet8d& a, const Packet8d& b) {
867
+ #ifdef EIGEN_VECTORIZE_AVX512DQ
868
+ return _mm512_and_pd(a, b);
869
+ #else
870
+ Packet8d res = _mm512_undefined_pd();
871
+ Packet4d lane0_a = _mm512_extractf64x4_pd(a, 0);
872
+ Packet4d lane0_b = _mm512_extractf64x4_pd(b, 0);
873
+ res = _mm512_insertf64x4(res, _mm256_and_pd(lane0_a, lane0_b), 0);
874
+
875
+ Packet4d lane1_a = _mm512_extractf64x4_pd(a, 1);
876
+ Packet4d lane1_b = _mm512_extractf64x4_pd(b, 1);
877
+ return _mm512_insertf64x4(res, _mm256_and_pd(lane1_a, lane1_b), 1);
878
+ #endif
879
+ }
880
+
881
+ template <>
882
+ EIGEN_STRONG_INLINE Packet16i por<Packet16i>(const Packet16i& a, const Packet16i& b) {
883
+ return _mm512_or_si512(a, b);
884
+ }
885
+
886
+ template <>
887
+ EIGEN_STRONG_INLINE Packet8l por<Packet8l>(const Packet8l& a, const Packet8l& b) {
888
+ return _mm512_or_si512(a, b);
889
+ }
890
+
891
+ template <>
892
+ EIGEN_STRONG_INLINE Packet16f por<Packet16f>(const Packet16f& a, const Packet16f& b) {
893
+ #ifdef EIGEN_VECTORIZE_AVX512DQ
894
+ return _mm512_or_ps(a, b);
895
+ #else
896
+ return _mm512_castsi512_ps(por(_mm512_castps_si512(a), _mm512_castps_si512(b)));
897
+ #endif
898
+ }
899
+
900
+ template <>
901
+ EIGEN_STRONG_INLINE Packet8d por<Packet8d>(const Packet8d& a, const Packet8d& b) {
902
+ #ifdef EIGEN_VECTORIZE_AVX512DQ
903
+ return _mm512_or_pd(a, b);
904
+ #else
905
+ return _mm512_castsi512_pd(por(_mm512_castpd_si512(a), _mm512_castpd_si512(b)));
906
+ #endif
907
+ }
908
+
909
+ template <>
910
+ EIGEN_STRONG_INLINE Packet16i pxor<Packet16i>(const Packet16i& a, const Packet16i& b) {
911
+ return _mm512_xor_si512(a, b);
912
+ }
913
+
914
+ template <>
915
+ EIGEN_STRONG_INLINE Packet8l pxor<Packet8l>(const Packet8l& a, const Packet8l& b) {
916
+ return _mm512_xor_si512(a, b);
917
+ }
918
+
919
+ template <>
920
+ EIGEN_STRONG_INLINE Packet16f pxor<Packet16f>(const Packet16f& a, const Packet16f& b) {
921
+ #ifdef EIGEN_VECTORIZE_AVX512DQ
922
+ return _mm512_xor_ps(a, b);
923
+ #else
924
+ return _mm512_castsi512_ps(pxor(_mm512_castps_si512(a), _mm512_castps_si512(b)));
925
+ #endif
926
+ }
927
+
928
+ template <>
929
+ EIGEN_STRONG_INLINE Packet8d pxor<Packet8d>(const Packet8d& a, const Packet8d& b) {
930
+ #ifdef EIGEN_VECTORIZE_AVX512DQ
931
+ return _mm512_xor_pd(a, b);
932
+ #else
933
+ return _mm512_castsi512_pd(pxor(_mm512_castpd_si512(a), _mm512_castpd_si512(b)));
934
+ #endif
935
+ }
936
+
937
+ template <>
938
+ EIGEN_STRONG_INLINE Packet16i pandnot<Packet16i>(const Packet16i& a, const Packet16i& b) {
939
+ return _mm512_andnot_si512(b, a);
940
+ }
941
+
942
+ template <>
943
+ EIGEN_STRONG_INLINE Packet8l pandnot<Packet8l>(const Packet8l& a, const Packet8l& b) {
944
+ return _mm512_andnot_si512(b, a);
945
+ }
946
+
947
+ template <>
948
+ EIGEN_STRONG_INLINE Packet16f pandnot<Packet16f>(const Packet16f& a, const Packet16f& b) {
949
+ #ifdef EIGEN_VECTORIZE_AVX512DQ
950
+ return _mm512_andnot_ps(b, a);
951
+ #else
952
+ return _mm512_castsi512_ps(pandnot(_mm512_castps_si512(a), _mm512_castps_si512(b)));
953
+ #endif
954
+ }
955
+ template <>
956
+ EIGEN_STRONG_INLINE Packet8d pandnot<Packet8d>(const Packet8d& a, const Packet8d& b) {
957
+ #ifdef EIGEN_VECTORIZE_AVX512DQ
958
+ return _mm512_andnot_pd(b, a);
959
+ #else
960
+ return _mm512_castsi512_pd(pandnot(_mm512_castpd_si512(a), _mm512_castpd_si512(b)));
961
+ #endif
962
+ }
963
+
964
+ template <>
965
+ EIGEN_STRONG_INLINE Packet16f pround<Packet16f>(const Packet16f& a) {
966
+ // Work-around for default std::round rounding mode.
967
+ const Packet16f mask = pset1frombits<Packet16f>(static_cast<numext::uint32_t>(0x80000000u));
968
+ const Packet16f prev0dot5 = pset1frombits<Packet16f>(static_cast<numext::uint32_t>(0x3EFFFFFFu));
969
+ return _mm512_roundscale_ps(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO);
970
+ }
971
+ template <>
972
+ EIGEN_STRONG_INLINE Packet8d pround<Packet8d>(const Packet8d& a) {
973
+ // Work-around for default std::round rounding mode.
974
+ const Packet8d mask = pset1frombits<Packet8d>(static_cast<numext::uint64_t>(0x8000000000000000ull));
975
+ const Packet8d prev0dot5 = pset1frombits<Packet8d>(static_cast<numext::uint64_t>(0x3FDFFFFFFFFFFFFFull));
976
+ return _mm512_roundscale_pd(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO);
977
+ }
978
+
979
+ template <int N>
980
+ EIGEN_STRONG_INLINE Packet16i parithmetic_shift_right(Packet16i a) {
981
+ return _mm512_srai_epi32(a, N);
416
982
  }
417
983
 
418
- template<int N> EIGEN_STRONG_INLINE Packet16i plogical_shift_right(Packet16i a) {
984
+ template <int N>
985
+ EIGEN_STRONG_INLINE Packet16i plogical_shift_right(Packet16i a) {
419
986
  return _mm512_srli_epi32(a, N);
420
987
  }
421
988
 
422
- template<int N> EIGEN_STRONG_INLINE Packet16i plogical_shift_left(Packet16i a) {
989
+ template <int N>
990
+ EIGEN_STRONG_INLINE Packet16i plogical_shift_left(Packet16i a) {
423
991
  return _mm512_slli_epi32(a, N);
424
992
  }
425
993
 
994
+ template <int N>
995
+ EIGEN_STRONG_INLINE Packet8l parithmetic_shift_right(Packet8l a) {
996
+ return _mm512_srai_epi64(a, N);
997
+ }
998
+
999
+ template <int N>
1000
+ EIGEN_STRONG_INLINE Packet8l plogical_shift_right(Packet8l a) {
1001
+ return _mm512_srli_epi64(a, N);
1002
+ }
1003
+
1004
+ template <int N>
1005
+ EIGEN_STRONG_INLINE Packet8l plogical_shift_left(Packet8l a) {
1006
+ return _mm512_slli_epi64(a, N);
1007
+ }
1008
+
426
1009
  template <>
427
1010
  EIGEN_STRONG_INLINE Packet16f pload<Packet16f>(const float* from) {
428
1011
  EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_ps(from);
@@ -433,8 +1016,11 @@ EIGEN_STRONG_INLINE Packet8d pload<Packet8d>(const double* from) {
433
1016
  }
434
1017
  template <>
435
1018
  EIGEN_STRONG_INLINE Packet16i pload<Packet16i>(const int* from) {
436
- EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_si512(
437
- reinterpret_cast<const __m512i*>(from));
1019
+ EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_epi64(from);
1020
+ }
1021
+ template <>
1022
+ EIGEN_STRONG_INLINE Packet8l pload<Packet8l>(const int64_t* from) {
1023
+ EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_epi64(from);
438
1024
  }
439
1025
 
440
1026
  template <>
@@ -447,8 +1033,22 @@ EIGEN_STRONG_INLINE Packet8d ploadu<Packet8d>(const double* from) {
447
1033
  }
448
1034
  template <>
449
1035
  EIGEN_STRONG_INLINE Packet16i ploadu<Packet16i>(const int* from) {
450
- EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_si512(
451
- reinterpret_cast<const __m512i*>(from));
1036
+ EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_epi32(from);
1037
+ }
1038
+ template <>
1039
+ EIGEN_STRONG_INLINE Packet8l ploadu<Packet8l>(const int64_t* from) {
1040
+ EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_epi64(from);
1041
+ }
1042
+
1043
+ template <>
1044
+ EIGEN_STRONG_INLINE Packet16f ploadu<Packet16f>(const float* from, uint16_t umask) {
1045
+ __mmask16 mask = static_cast<__mmask16>(umask);
1046
+ EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_maskz_loadu_ps(mask, from);
1047
+ }
1048
+ template <>
1049
+ EIGEN_STRONG_INLINE Packet8d ploadu<Packet8d>(const double* from, uint8_t umask) {
1050
+ __mmask8 mask = static_cast<__mmask8>(umask);
1051
+ EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_maskz_loadu_pd(mask, from);
452
1052
  }
453
1053
 
454
1054
  // Loads 8 floats from memory a returns the packet
@@ -457,43 +1057,46 @@ template <>
457
1057
  EIGEN_STRONG_INLINE Packet16f ploaddup<Packet16f>(const float* from) {
458
1058
  // an unaligned load is required here as there is no requirement
459
1059
  // on the alignment of input pointer 'from'
460
- __m256i low_half = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(from));
1060
+ __m256i low_half = _mm256_castps_si256(_mm256_loadu_ps(from));
461
1061
  __m512 even_elements = _mm512_castsi512_ps(_mm512_cvtepu32_epi64(low_half));
462
1062
  __m512 pairs = _mm512_permute_ps(even_elements, _MM_SHUFFLE(2, 2, 0, 0));
463
1063
  return pairs;
464
1064
  }
465
1065
 
466
- #ifdef EIGEN_VECTORIZE_AVX512DQ
467
- // FIXME: this does not look optimal, better load a Packet4d and shuffle...
468
- // Loads 4 doubles from memory a returns the packet {a0, a0 a1, a1, a2, a2, a3,
1066
+ // Loads 4 doubles from memory a returns the packet {a0, a0, a1, a1, a2, a2, a3,
469
1067
  // a3}
470
1068
  template <>
471
1069
  EIGEN_STRONG_INLINE Packet8d ploaddup<Packet8d>(const double* from) {
472
- __m512d x = _mm512_setzero_pd();
473
- x = _mm512_insertf64x2(x, _mm_loaddup_pd(&from[0]), 0);
474
- x = _mm512_insertf64x2(x, _mm_loaddup_pd(&from[1]), 1);
475
- x = _mm512_insertf64x2(x, _mm_loaddup_pd(&from[2]), 2);
476
- x = _mm512_insertf64x2(x, _mm_loaddup_pd(&from[3]), 3);
477
- return x;
1070
+ Packet8d tmp = _mm512_castpd256_pd512(ploadu<Packet4d>(from));
1071
+ const Packet8l scatter_mask = _mm512_set_epi64(3, 3, 2, 2, 1, 1, 0, 0);
1072
+ return _mm512_permutexvar_pd(scatter_mask, tmp);
478
1073
  }
479
- #else
1074
+
1075
+ // Loads 4 int64_t from memory a returns the packet {a0, a0, a1, a1, a2, a2, a3,
1076
+ // a3}
480
1077
  template <>
481
- EIGEN_STRONG_INLINE Packet8d ploaddup<Packet8d>(const double* from) {
482
- __m512d x = _mm512_setzero_pd();
483
- x = _mm512_mask_broadcastsd_pd(x, 0x3<<0, _mm_load_sd(from+0));
484
- x = _mm512_mask_broadcastsd_pd(x, 0x3<<2, _mm_load_sd(from+1));
485
- x = _mm512_mask_broadcastsd_pd(x, 0x3<<4, _mm_load_sd(from+2));
486
- x = _mm512_mask_broadcastsd_pd(x, 0x3<<6, _mm_load_sd(from+3));
487
- return x;
1078
+ EIGEN_STRONG_INLINE Packet8l ploaddup<Packet8l>(const int64_t* from) {
1079
+ Packet8l tmp = _mm512_castsi256_si512(ploadu<Packet4l>(from));
1080
+ const Packet8l scatter_mask = _mm512_set_epi64(3, 3, 2, 2, 1, 1, 0, 0);
1081
+ return _mm512_permutexvar_epi64(scatter_mask, tmp);
1082
+ }
1083
+
1084
+ // Loads 8 integers from memory and returns the packet
1085
+ // {a0, a0 a1, a1, a2, a2, a3, a3, a4, a4, a5, a5, a6, a6, a7, a7}
1086
+ template <>
1087
+ EIGEN_STRONG_INLINE Packet16i ploaddup<Packet16i>(const int* from) {
1088
+ __m256i low_half = _mm256_load_si256(reinterpret_cast<const __m256i*>(from));
1089
+ __m512 even_elements = _mm512_castsi512_ps(_mm512_cvtepu32_epi64(low_half));
1090
+ __m512 pairs = _mm512_permute_ps(even_elements, _MM_SHUFFLE(2, 2, 0, 0));
1091
+ return _mm512_castps_si512(pairs);
488
1092
  }
489
- #endif
490
1093
 
491
1094
  // Loads 4 floats from memory a returns the packet
492
1095
  // {a0, a0 a0, a0, a1, a1, a1, a1, a2, a2, a2, a2, a3, a3, a3, a3}
493
1096
  template <>
494
1097
  EIGEN_STRONG_INLINE Packet16f ploadquad<Packet16f>(const float* from) {
495
1098
  Packet16f tmp = _mm512_castps128_ps512(ploadu<Packet4f>(from));
496
- const Packet16i scatter_mask = _mm512_set_epi32(3,3,3,3, 2,2,2,2, 1,1,1,1, 0,0,0,0);
1099
+ const Packet16i scatter_mask = _mm512_set_epi32(3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0);
497
1100
  return _mm512_permutexvar_ps(scatter_mask, tmp);
498
1101
  }
499
1102
 
@@ -502,12 +1105,32 @@ EIGEN_STRONG_INLINE Packet16f ploadquad<Packet16f>(const float* from) {
502
1105
  template <>
503
1106
  EIGEN_STRONG_INLINE Packet8d ploadquad<Packet8d>(const double* from) {
504
1107
  __m256d lane0 = _mm256_set1_pd(*from);
505
- __m256d lane1 = _mm256_set1_pd(*(from+1));
1108
+ __m256d lane1 = _mm256_set1_pd(*(from + 1));
506
1109
  __m512d tmp = _mm512_undefined_pd();
507
1110
  tmp = _mm512_insertf64x4(tmp, lane0, 0);
508
1111
  return _mm512_insertf64x4(tmp, lane1, 1);
509
1112
  }
510
1113
 
1114
+ // Loads 2 int64_t from memory a returns the packet
1115
+ // {a0, a0 a0, a0, a1, a1, a1, a1}
1116
+ template <>
1117
+ EIGEN_STRONG_INLINE Packet8l ploadquad<Packet8l>(const int64_t* from) {
1118
+ __m256i lane0 = _mm256_set1_epi64x(*from);
1119
+ __m256i lane1 = _mm256_set1_epi64x(*(from + 1));
1120
+ __m512i tmp = _mm512_undefined_epi32();
1121
+ tmp = _mm512_inserti64x4(tmp, lane0, 0);
1122
+ return _mm512_inserti64x4(tmp, lane1, 1);
1123
+ }
1124
+
1125
+ // Loads 4 integers from memory and returns the packet
1126
+ // {a0, a0 a0, a0, a1, a1, a1, a1, a2, a2, a2, a2, a3, a3, a3, a3}
1127
+ template <>
1128
+ EIGEN_STRONG_INLINE Packet16i ploadquad<Packet16i>(const int* from) {
1129
+ Packet16i tmp = _mm512_castsi128_si512(ploadu<Packet4i>(from));
1130
+ const Packet16i scatter_mask = _mm512_set_epi32(3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0);
1131
+ return _mm512_permutexvar_epi32(scatter_mask, tmp);
1132
+ }
1133
+
511
1134
  template <>
512
1135
  EIGEN_STRONG_INLINE void pstore<float>(float* to, const Packet16f& from) {
513
1136
  EIGEN_DEBUG_ALIGNED_STORE _mm512_store_ps(to, from);
@@ -518,8 +1141,11 @@ EIGEN_STRONG_INLINE void pstore<double>(double* to, const Packet8d& from) {
518
1141
  }
519
1142
  template <>
520
1143
  EIGEN_STRONG_INLINE void pstore<int>(int* to, const Packet16i& from) {
521
- EIGEN_DEBUG_ALIGNED_STORE _mm512_storeu_si512(reinterpret_cast<__m512i*>(to),
522
- from);
1144
+ EIGEN_DEBUG_ALIGNED_STORE _mm512_store_epi32(to, from);
1145
+ }
1146
+ template <>
1147
+ EIGEN_STRONG_INLINE void pstore<int64_t>(int64_t* to, const Packet8l& from) {
1148
+ EIGEN_DEBUG_ALIGNED_STORE _mm512_store_epi64(to, from);
523
1149
  }
524
1150
 
525
1151
  template <>
@@ -532,49 +1158,128 @@ EIGEN_STRONG_INLINE void pstoreu<double>(double* to, const Packet8d& from) {
532
1158
  }
533
1159
  template <>
534
1160
  EIGEN_STRONG_INLINE void pstoreu<int>(int* to, const Packet16i& from) {
535
- EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_si512(
536
- reinterpret_cast<__m512i*>(to), from);
1161
+ EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_epi32(to, from);
1162
+ }
1163
+ template <>
1164
+ EIGEN_STRONG_INLINE void pstoreu<int64_t>(int64_t* to, const Packet8l& from) {
1165
+ EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_epi64(to, from);
1166
+ }
1167
+ template <>
1168
+ EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet16f& from, uint16_t umask) {
1169
+ __mmask16 mask = static_cast<__mmask16>(umask);
1170
+ EIGEN_DEBUG_UNALIGNED_STORE return _mm512_mask_storeu_ps(to, mask, from);
1171
+ }
1172
+ template <>
1173
+ EIGEN_STRONG_INLINE void pstoreu<double>(double* to, const Packet8d& from, uint8_t umask) {
1174
+ __mmask8 mask = static_cast<__mmask8>(umask);
1175
+ EIGEN_DEBUG_UNALIGNED_STORE return _mm512_mask_storeu_pd(to, mask, from);
1176
+ }
1177
+
1178
+ template <typename Scalar, typename Packet>
1179
+ EIGEN_DEVICE_FUNC inline Packet pgather(const Packet& src, const Scalar* from, Index stride,
1180
+ typename unpacket_traits<Packet>::mask_t umask);
1181
+ template <>
1182
+ EIGEN_DEVICE_FUNC inline Packet16f pgather<float, Packet16f>(const Packet16f& src, const float* from, Index stride,
1183
+ uint16_t umask) {
1184
+ Packet16i stride_vector = _mm512_set1_epi32(convert_index<int>(stride));
1185
+ Packet16i stride_multiplier = _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
1186
+ Packet16i indices = _mm512_mullo_epi32(stride_vector, stride_multiplier);
1187
+ __mmask16 mask = static_cast<__mmask16>(umask);
1188
+
1189
+ return _mm512_mask_i32gather_ps(src, mask, indices, from, 4);
1190
+ }
1191
+ template <>
1192
+ EIGEN_DEVICE_FUNC inline Packet8d pgather<double, Packet8d>(const Packet8d& src, const double* from, Index stride,
1193
+ uint8_t umask) {
1194
+ Packet8i stride_vector = _mm256_set1_epi32(convert_index<int>(stride));
1195
+ Packet8i stride_multiplier = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
1196
+ Packet8i indices = _mm256_mullo_epi32(stride_vector, stride_multiplier);
1197
+ __mmask8 mask = static_cast<__mmask8>(umask);
1198
+
1199
+ return _mm512_mask_i32gather_pd(src, mask, indices, from, 8);
537
1200
  }
538
1201
 
539
1202
  template <>
540
- EIGEN_DEVICE_FUNC inline Packet16f pgather<float, Packet16f>(const float* from,
541
- Index stride) {
1203
+ EIGEN_DEVICE_FUNC inline Packet16f pgather<float, Packet16f>(const float* from, Index stride) {
542
1204
  Packet16i stride_vector = _mm512_set1_epi32(convert_index<int>(stride));
543
- Packet16i stride_multiplier =
544
- _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
1205
+ Packet16i stride_multiplier = _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
545
1206
  Packet16i indices = _mm512_mullo_epi32(stride_vector, stride_multiplier);
546
1207
 
547
1208
  return _mm512_i32gather_ps(indices, from, 4);
548
1209
  }
549
1210
  template <>
550
- EIGEN_DEVICE_FUNC inline Packet8d pgather<double, Packet8d>(const double* from,
551
- Index stride) {
1211
+ EIGEN_DEVICE_FUNC inline Packet8d pgather<double, Packet8d>(const double* from, Index stride) {
552
1212
  Packet8i stride_vector = _mm256_set1_epi32(convert_index<int>(stride));
553
1213
  Packet8i stride_multiplier = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
554
1214
  Packet8i indices = _mm256_mullo_epi32(stride_vector, stride_multiplier);
555
1215
 
556
1216
  return _mm512_i32gather_pd(indices, from, 8);
557
1217
  }
1218
+ template <>
1219
+ EIGEN_DEVICE_FUNC inline Packet8l pgather<int64_t, Packet8l>(const int64_t* from, Index stride) {
1220
+ Packet8i stride_vector = _mm256_set1_epi32(convert_index<int>(stride));
1221
+ Packet8i stride_multiplier = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
1222
+ Packet8i indices = _mm256_mullo_epi32(stride_vector, stride_multiplier);
1223
+
1224
+ return _mm512_i32gather_epi64(indices, from, 8);
1225
+ }
1226
+ template <>
1227
+ EIGEN_DEVICE_FUNC inline Packet16i pgather<int, Packet16i>(const int* from, Index stride) {
1228
+ Packet16i stride_vector = _mm512_set1_epi32(convert_index<int>(stride));
1229
+ Packet16i stride_multiplier = _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
1230
+ Packet16i indices = _mm512_mullo_epi32(stride_vector, stride_multiplier);
1231
+ return _mm512_i32gather_epi32(indices, from, 4);
1232
+ }
558
1233
 
1234
+ template <typename Scalar, typename Packet>
1235
+ EIGEN_DEVICE_FUNC inline void pscatter(Scalar* to, const Packet& from, Index stride,
1236
+ typename unpacket_traits<Packet>::mask_t umask);
1237
+ template <>
1238
+ EIGEN_DEVICE_FUNC inline void pscatter<float, Packet16f>(float* to, const Packet16f& from, Index stride,
1239
+ uint16_t umask) {
1240
+ Packet16i stride_vector = _mm512_set1_epi32(convert_index<int>(stride));
1241
+ Packet16i stride_multiplier = _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
1242
+ Packet16i indices = _mm512_mullo_epi32(stride_vector, stride_multiplier);
1243
+ __mmask16 mask = static_cast<__mmask16>(umask);
1244
+ _mm512_mask_i32scatter_ps(to, mask, indices, from, 4);
1245
+ }
1246
+ template <>
1247
+ EIGEN_DEVICE_FUNC inline void pscatter<double, Packet8d>(double* to, const Packet8d& from, Index stride,
1248
+ uint8_t umask) {
1249
+ Packet8i stride_vector = _mm256_set1_epi32(convert_index<int>(stride));
1250
+ Packet8i stride_multiplier = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
1251
+ Packet8i indices = _mm256_mullo_epi32(stride_vector, stride_multiplier);
1252
+ __mmask8 mask = static_cast<__mmask8>(umask);
1253
+ _mm512_mask_i32scatter_pd(to, mask, indices, from, 8);
1254
+ }
559
1255
  template <>
560
- EIGEN_DEVICE_FUNC inline void pscatter<float, Packet16f>(float* to,
561
- const Packet16f& from,
562
- Index stride) {
1256
+ EIGEN_DEVICE_FUNC inline void pscatter<float, Packet16f>(float* to, const Packet16f& from, Index stride) {
563
1257
  Packet16i stride_vector = _mm512_set1_epi32(convert_index<int>(stride));
564
- Packet16i stride_multiplier =
565
- _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
1258
+ Packet16i stride_multiplier = _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
566
1259
  Packet16i indices = _mm512_mullo_epi32(stride_vector, stride_multiplier);
567
1260
  _mm512_i32scatter_ps(to, indices, from, 4);
568
1261
  }
569
1262
  template <>
570
- EIGEN_DEVICE_FUNC inline void pscatter<double, Packet8d>(double* to,
571
- const Packet8d& from,
572
- Index stride) {
1263
+ EIGEN_DEVICE_FUNC inline void pscatter<double, Packet8d>(double* to, const Packet8d& from, Index stride) {
573
1264
  Packet8i stride_vector = _mm256_set1_epi32(convert_index<int>(stride));
574
1265
  Packet8i stride_multiplier = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
575
1266
  Packet8i indices = _mm256_mullo_epi32(stride_vector, stride_multiplier);
576
1267
  _mm512_i32scatter_pd(to, indices, from, 8);
577
1268
  }
1269
+ template <>
1270
+ EIGEN_DEVICE_FUNC inline void pscatter<int64_t, Packet8l>(int64_t* to, const Packet8l& from, Index stride) {
1271
+ Packet8i stride_vector = _mm256_set1_epi32(convert_index<int>(stride));
1272
+ Packet8i stride_multiplier = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
1273
+ Packet8i indices = _mm256_mullo_epi32(stride_vector, stride_multiplier);
1274
+ _mm512_i32scatter_epi64(to, indices, from, 8);
1275
+ }
1276
+ template <>
1277
+ EIGEN_DEVICE_FUNC inline void pscatter<int, Packet16i>(int* to, const Packet16i& from, Index stride) {
1278
+ Packet16i stride_vector = _mm512_set1_epi32(convert_index<int>(stride));
1279
+ Packet16i stride_multiplier = _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
1280
+ Packet16i indices = _mm512_mullo_epi32(stride_vector, stride_multiplier);
1281
+ _mm512_i32scatter_epi32(to, indices, from, 4);
1282
+ }
578
1283
 
579
1284
  template <>
580
1285
  EIGEN_STRONG_INLINE void pstore1<Packet16f>(float* to, const float& a) {
@@ -591,64 +1296,190 @@ EIGEN_STRONG_INLINE void pstore1<Packet16i>(int* to, const int& a) {
591
1296
  Packet16i pa = pset1<Packet16i>(a);
592
1297
  pstore(to, pa);
593
1298
  }
1299
+ template <>
1300
+ EIGEN_STRONG_INLINE void pstore1<Packet8l>(int64_t* to, const int64_t& a) {
1301
+ Packet8l pa = pset1<Packet8l>(a);
1302
+ pstore(to, pa);
1303
+ }
594
1304
 
595
- template<> EIGEN_STRONG_INLINE void prefetch<float>(const float* addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); }
596
- template<> EIGEN_STRONG_INLINE void prefetch<double>(const double* addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); }
597
- template<> EIGEN_STRONG_INLINE void prefetch<int>(const int* addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); }
1305
+ template <>
1306
+ EIGEN_STRONG_INLINE void prefetch<float>(const float* addr) {
1307
+ _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0);
1308
+ }
1309
+ template <>
1310
+ EIGEN_STRONG_INLINE void prefetch<double>(const double* addr) {
1311
+ _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0);
1312
+ }
1313
+ template <>
1314
+ EIGEN_STRONG_INLINE void prefetch<int>(const int* addr) {
1315
+ _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0);
1316
+ }
598
1317
 
599
1318
  template <>
600
1319
  EIGEN_STRONG_INLINE float pfirst<Packet16f>(const Packet16f& a) {
601
- return _mm_cvtss_f32(_mm512_extractf32x4_ps(a, 0));
1320
+ return _mm512_cvtss_f32(a);
602
1321
  }
603
1322
  template <>
604
1323
  EIGEN_STRONG_INLINE double pfirst<Packet8d>(const Packet8d& a) {
605
- return _mm_cvtsd_f64(_mm256_extractf128_pd(_mm512_extractf64x4_pd(a, 0), 0));
1324
+ return _mm512_cvtsd_f64(a);
1325
+ }
1326
+ template <>
1327
+ EIGEN_STRONG_INLINE int64_t pfirst<Packet8l>(const Packet8l& a) {
1328
+ int64_t x = _mm_extract_epi64_0(_mm512_extracti32x4_epi32(a, 0));
1329
+ return x;
606
1330
  }
607
1331
  template <>
608
1332
  EIGEN_STRONG_INLINE int pfirst<Packet16i>(const Packet16i& a) {
609
- return _mm_extract_epi32(_mm512_extracti32x4_epi32(a, 0), 0);
1333
+ #if EIGEN_GNUC_STRICT_LESS_THAN(11, 0, 0)
1334
+ return _mm_cvtsi128_si32(_mm512_castsi512_si128(a));
1335
+ #else
1336
+ return _mm512_cvtsi512_si32(a);
1337
+ #endif
610
1338
  }
611
1339
 
612
- template<> EIGEN_STRONG_INLINE Packet16f preverse(const Packet16f& a)
613
- {
1340
+ template <>
1341
+ EIGEN_STRONG_INLINE Packet16f preverse(const Packet16f& a) {
614
1342
  return _mm512_permutexvar_ps(_mm512_set_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), a);
615
1343
  }
616
1344
 
617
- template<> EIGEN_STRONG_INLINE Packet8d preverse(const Packet8d& a)
618
- {
1345
+ template <>
1346
+ EIGEN_STRONG_INLINE Packet8d preverse(const Packet8d& a) {
619
1347
  return _mm512_permutexvar_pd(_mm512_set_epi32(0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7), a);
620
1348
  }
621
1349
 
622
- template<> EIGEN_STRONG_INLINE Packet16f pabs(const Packet16f& a)
623
- {
1350
+ template <>
1351
+ EIGEN_STRONG_INLINE Packet16i preverse(const Packet16i& a) {
1352
+ return _mm512_permutexvar_epi32(_mm512_set_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), a);
1353
+ }
1354
+
1355
+ template <>
1356
+ EIGEN_STRONG_INLINE Packet8l preverse(const Packet8l& a) {
1357
+ return _mm512_permutexvar_epi64(_mm512_set_epi64(0, 1, 2, 3, 4, 5, 6, 7), a);
1358
+ }
1359
+
1360
+ template <>
1361
+ EIGEN_STRONG_INLINE Packet16f pabs(const Packet16f& a) {
624
1362
  // _mm512_abs_ps intrinsic not found, so hack around it
625
1363
  return _mm512_castsi512_ps(_mm512_and_si512(_mm512_castps_si512(a), _mm512_set1_epi32(0x7fffffff)));
626
1364
  }
627
1365
  template <>
628
1366
  EIGEN_STRONG_INLINE Packet8d pabs(const Packet8d& a) {
629
1367
  // _mm512_abs_ps intrinsic not found, so hack around it
630
- return _mm512_castsi512_pd(_mm512_and_si512(_mm512_castpd_si512(a),
631
- _mm512_set1_epi64(0x7fffffffffffffff)));
1368
+ return _mm512_castsi512_pd(_mm512_and_si512(_mm512_castpd_si512(a), _mm512_set1_epi64(0x7fffffffffffffff)));
1369
+ }
1370
+ template <>
1371
+ EIGEN_STRONG_INLINE Packet16i pabs(const Packet16i& a) {
1372
+ return _mm512_abs_epi32(a);
1373
+ }
1374
+ template <>
1375
+ EIGEN_STRONG_INLINE Packet8l pabs(const Packet8l& a) {
1376
+ return _mm512_abs_epi64(a);
632
1377
  }
633
1378
 
634
- #ifdef EIGEN_VECTORIZE_AVX512DQ
635
- // AVX512F does not define _mm512_extractf32x8_ps to extract _m256 from _m512
636
- #define EIGEN_EXTRACT_8f_FROM_16f(INPUT, OUTPUT) \
637
- __m256 OUTPUT##_0 = _mm512_extractf32x8_ps(INPUT, 0); \
1379
+ #ifndef EIGEN_VECTORIZE_AVX512FP16
1380
+ template <>
1381
+ EIGEN_STRONG_INLINE Packet16h psignbit(const Packet16h& a) {
1382
+ return _mm256_srai_epi16(a, 15);
1383
+ }
1384
+ #endif // EIGEN_VECTORIZE_AVX512FP16
1385
+
1386
+ template <>
1387
+ EIGEN_STRONG_INLINE Packet16bf psignbit(const Packet16bf& a) {
1388
+ return _mm256_srai_epi16(a, 15);
1389
+ }
1390
+ template <>
1391
+ EIGEN_STRONG_INLINE Packet16f psignbit(const Packet16f& a) {
1392
+ return _mm512_castsi512_ps(_mm512_srai_epi32(_mm512_castps_si512(a), 31));
1393
+ }
1394
+ template <>
1395
+ EIGEN_STRONG_INLINE Packet8d psignbit(const Packet8d& a) {
1396
+ return _mm512_castsi512_pd(_mm512_srai_epi64(_mm512_castpd_si512(a), 63));
1397
+ }
1398
+
1399
+ template <>
1400
+ EIGEN_STRONG_INLINE Packet16f pfrexp<Packet16f>(const Packet16f& a, Packet16f& exponent) {
1401
+ return pfrexp_generic(a, exponent);
1402
+ }
1403
+
1404
+ // Extract exponent without existence of Packet8l.
1405
+ template <>
1406
+ EIGEN_STRONG_INLINE Packet8d pfrexp_generic_get_biased_exponent(const Packet8d& a) {
1407
+ const Packet8d cst_exp_mask = pset1frombits<Packet8d>(static_cast<uint64_t>(0x7ff0000000000000ull));
1408
+ #ifdef EIGEN_VECTORIZE_AVX512DQ
1409
+ return _mm512_cvtepi64_pd(_mm512_srli_epi64(_mm512_castpd_si512(pand(a, cst_exp_mask)), 52));
1410
+ #else
1411
+ return _mm512_cvtepi32_pd(_mm512_cvtepi64_epi32(_mm512_srli_epi64(_mm512_castpd_si512(pand(a, cst_exp_mask)), 52)));
1412
+ #endif
1413
+ }
1414
+
1415
+ template <>
1416
+ EIGEN_STRONG_INLINE Packet8d pfrexp<Packet8d>(const Packet8d& a, Packet8d& exponent) {
1417
+ return pfrexp_generic(a, exponent);
1418
+ }
1419
+
1420
+ template <>
1421
+ EIGEN_STRONG_INLINE Packet16f pldexp<Packet16f>(const Packet16f& a, const Packet16f& exponent) {
1422
+ return pldexp_generic(a, exponent);
1423
+ }
1424
+
1425
+ template <>
1426
+ EIGEN_STRONG_INLINE Packet8d pldexp<Packet8d>(const Packet8d& a, const Packet8d& exponent) {
1427
+ // Clamp exponent to [-2099, 2099]
1428
+ const Packet8d max_exponent = pset1<Packet8d>(2099.0);
1429
+ const Packet8i e = _mm512_cvtpd_epi32(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent));
1430
+
1431
+ // Split 2^e into four factors and multiply.
1432
+ const Packet8i bias = pset1<Packet8i>(1023);
1433
+ Packet8i b = parithmetic_shift_right<2>(e); // floor(e/4)
1434
+
1435
+ // 2^b
1436
+ const Packet8i permute_idx = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
1437
+ Packet8i hi = _mm256_permutevar8x32_epi32(padd(b, bias), permute_idx);
1438
+ Packet8i lo = _mm256_slli_epi64(hi, 52);
1439
+ hi = _mm256_slli_epi64(_mm256_srli_epi64(hi, 32), 52);
1440
+ Packet8d c = _mm512_castsi512_pd(_mm512_inserti64x4(_mm512_castsi256_si512(lo), hi, 1));
1441
+ Packet8d out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b)
1442
+
1443
+ // 2^(e - 3b)
1444
+ b = psub(psub(psub(e, b), b), b); // e - 3b
1445
+ hi = _mm256_permutevar8x32_epi32(padd(b, bias), permute_idx);
1446
+ lo = _mm256_slli_epi64(hi, 52);
1447
+ hi = _mm256_slli_epi64(_mm256_srli_epi64(hi, 32), 52);
1448
+ c = _mm512_castsi512_pd(_mm512_inserti64x4(_mm512_castsi256_si512(lo), hi, 1));
1449
+ out = pmul(out, c); // a * 2^e
1450
+ return out;
1451
+ }
1452
+
1453
+ #ifdef EIGEN_VECTORIZE_AVX512DQ
1454
+ // AVX512F does not define _mm512_extractf32x8_ps to extract _m256 from _m512
1455
+ #define EIGEN_EXTRACT_8f_FROM_16f(INPUT, OUTPUT) \
1456
+ __m256 OUTPUT##_0 = _mm512_extractf32x8_ps(INPUT, 0); \
638
1457
  __m256 OUTPUT##_1 = _mm512_extractf32x8_ps(INPUT, 1)
1458
+
1459
+ // AVX512F does not define _mm512_extracti32x8_epi32 to extract _m256i from _m512i
1460
+ #define EIGEN_EXTRACT_8i_FROM_16i(INPUT, OUTPUT) \
1461
+ __m256i OUTPUT##_0 = _mm512_extracti32x8_epi32(INPUT, 0); \
1462
+ __m256i OUTPUT##_1 = _mm512_extracti32x8_epi32(INPUT, 1)
639
1463
  #else
640
- #define EIGEN_EXTRACT_8f_FROM_16f(INPUT, OUTPUT) \
641
- __m256 OUTPUT##_0 = _mm256_insertf128_ps( \
642
- _mm256_castps128_ps256(_mm512_extractf32x4_ps(INPUT, 0)), \
643
- _mm512_extractf32x4_ps(INPUT, 1), 1); \
644
- __m256 OUTPUT##_1 = _mm256_insertf128_ps( \
645
- _mm256_castps128_ps256(_mm512_extractf32x4_ps(INPUT, 2)), \
646
- _mm512_extractf32x4_ps(INPUT, 3), 1);
1464
+ #define EIGEN_EXTRACT_8f_FROM_16f(INPUT, OUTPUT) \
1465
+ __m256 OUTPUT##_0 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm512_extractf32x4_ps(INPUT, 0)), \
1466
+ _mm512_extractf32x4_ps(INPUT, 1), 1); \
1467
+ __m256 OUTPUT##_1 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm512_extractf32x4_ps(INPUT, 2)), \
1468
+ _mm512_extractf32x4_ps(INPUT, 3), 1)
1469
+
1470
+ #define EIGEN_EXTRACT_8i_FROM_16i(INPUT, OUTPUT) \
1471
+ __m256i OUTPUT##_0 = _mm256_insertf128_si256(_mm256_castsi128_si256(_mm512_extracti32x4_epi32(INPUT, 0)), \
1472
+ _mm512_extracti32x4_epi32(INPUT, 1), 1); \
1473
+ __m256i OUTPUT##_1 = _mm256_insertf128_si256(_mm256_castsi128_si256(_mm512_extracti32x4_epi32(INPUT, 2)), \
1474
+ _mm512_extracti32x4_epi32(INPUT, 3), 1)
647
1475
  #endif
648
1476
 
649
1477
  #ifdef EIGEN_VECTORIZE_AVX512DQ
650
1478
  #define EIGEN_INSERT_8f_INTO_16f(OUTPUT, INPUTA, INPUTB) \
651
1479
  OUTPUT = _mm512_insertf32x8(_mm512_castps256_ps512(INPUTA), INPUTB, 1);
1480
+
1481
+ #define EIGEN_INSERT_8i_INTO_16i(OUTPUT, INPUTA, INPUTB) \
1482
+ OUTPUT = _mm512_inserti32x8(_mm512_castsi256_si512(INPUTA), INPUTB, 1);
652
1483
  #else
653
1484
  #define EIGEN_INSERT_8f_INTO_16f(OUTPUT, INPUTA, INPUTB) \
654
1485
  OUTPUT = _mm512_undefined_ps(); \
@@ -656,318 +1487,60 @@ EIGEN_STRONG_INLINE Packet8d pabs(const Packet8d& a) {
656
1487
  OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTA, 1), 1); \
657
1488
  OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTB, 0), 2); \
658
1489
  OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTB, 1), 3);
1490
+
1491
+ #define EIGEN_INSERT_8i_INTO_16i(OUTPUT, INPUTA, INPUTB) \
1492
+ OUTPUT = _mm512_undefined_epi32(); \
1493
+ OUTPUT = _mm512_inserti32x4(OUTPUT, _mm256_extractf128_si256(INPUTA, 0), 0); \
1494
+ OUTPUT = _mm512_inserti32x4(OUTPUT, _mm256_extractf128_si256(INPUTA, 1), 1); \
1495
+ OUTPUT = _mm512_inserti32x4(OUTPUT, _mm256_extractf128_si256(INPUTB, 0), 2); \
1496
+ OUTPUT = _mm512_inserti32x4(OUTPUT, _mm256_extractf128_si256(INPUTB, 1), 3);
659
1497
  #endif
660
1498
 
661
1499
  template <>
662
- EIGEN_STRONG_INLINE float predux<Packet16f>(const Packet16f& a) {
1500
+ EIGEN_STRONG_INLINE Packet8f predux_half_dowto4<Packet16f>(const Packet16f& a) {
663
1501
  #ifdef EIGEN_VECTORIZE_AVX512DQ
664
1502
  __m256 lane0 = _mm512_extractf32x8_ps(a, 0);
665
1503
  __m256 lane1 = _mm512_extractf32x8_ps(a, 1);
666
- Packet8f x = _mm256_add_ps(lane0, lane1);
667
- return predux<Packet8f>(x);
1504
+ return _mm256_add_ps(lane0, lane1);
668
1505
  #else
669
1506
  __m128 lane0 = _mm512_extractf32x4_ps(a, 0);
670
1507
  __m128 lane1 = _mm512_extractf32x4_ps(a, 1);
671
1508
  __m128 lane2 = _mm512_extractf32x4_ps(a, 2);
672
1509
  __m128 lane3 = _mm512_extractf32x4_ps(a, 3);
673
- __m128 sum = _mm_add_ps(_mm_add_ps(lane0, lane1), _mm_add_ps(lane2, lane3));
674
- sum = _mm_hadd_ps(sum, sum);
675
- sum = _mm_hadd_ps(sum, _mm_permute_ps(sum, 1));
676
- return _mm_cvtss_f32(sum);
1510
+ __m128 sum0 = _mm_add_ps(lane0, lane2);
1511
+ __m128 sum1 = _mm_add_ps(lane1, lane3);
1512
+ return _mm256_insertf128_ps(_mm256_castps128_ps256(sum0), sum1, 1);
677
1513
  #endif
678
1514
  }
679
1515
  template <>
680
- EIGEN_STRONG_INLINE double predux<Packet8d>(const Packet8d& a) {
1516
+ EIGEN_STRONG_INLINE Packet4d predux_half_dowto4<Packet8d>(const Packet8d& a) {
681
1517
  __m256d lane0 = _mm512_extractf64x4_pd(a, 0);
682
1518
  __m256d lane1 = _mm512_extractf64x4_pd(a, 1);
683
- __m256d sum = _mm256_add_pd(lane0, lane1);
684
- __m256d tmp0 = _mm256_hadd_pd(sum, _mm256_permute2f128_pd(sum, sum, 1));
685
- return _mm_cvtsd_f64(_mm256_castpd256_pd128(_mm256_hadd_pd(tmp0, tmp0)));
1519
+ return _mm256_add_pd(lane0, lane1);
686
1520
  }
687
-
688
1521
  template <>
689
- EIGEN_STRONG_INLINE Packet8f predux_downto4<Packet16f>(const Packet16f& a) {
1522
+ EIGEN_STRONG_INLINE Packet8i predux_half_dowto4<Packet16i>(const Packet16i& a) {
690
1523
  #ifdef EIGEN_VECTORIZE_AVX512DQ
691
- Packet8f lane0 = _mm512_extractf32x8_ps(a, 0);
692
- Packet8f lane1 = _mm512_extractf32x8_ps(a, 1);
693
- return padd(lane0, lane1);
694
- #else
695
- Packet4f lane0 = _mm512_extractf32x4_ps(a, 0);
696
- Packet4f lane1 = _mm512_extractf32x4_ps(a, 1);
697
- Packet4f lane2 = _mm512_extractf32x4_ps(a, 2);
698
- Packet4f lane3 = _mm512_extractf32x4_ps(a, 3);
699
- Packet4f sum0 = padd(lane0, lane2);
700
- Packet4f sum1 = padd(lane1, lane3);
701
- return _mm256_insertf128_ps(_mm256_castps128_ps256(sum0), sum1, 1);
702
- #endif
703
- }
704
- template <>
705
- EIGEN_STRONG_INLINE Packet4d predux_downto4<Packet8d>(const Packet8d& a) {
706
- Packet4d lane0 = _mm512_extractf64x4_pd(a, 0);
707
- Packet4d lane1 = _mm512_extractf64x4_pd(a, 1);
708
- Packet4d res = padd(lane0, lane1);
709
- return res;
710
- }
711
-
712
- template <>
713
- EIGEN_STRONG_INLINE float predux_mul<Packet16f>(const Packet16f& a) {
714
- //#ifdef EIGEN_VECTORIZE_AVX512DQ
715
- #if 0
716
- Packet8f lane0 = _mm512_extractf32x8_ps(a, 0);
717
- Packet8f lane1 = _mm512_extractf32x8_ps(a, 1);
718
- Packet8f res = pmul(lane0, lane1);
719
- res = pmul(res, _mm256_permute2f128_ps(res, res, 1));
720
- res = pmul(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 3, 2)));
721
- return pfirst(pmul(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 0, 1))));
1524
+ __m256i lane0 = _mm512_extracti32x8_epi32(a, 0);
1525
+ __m256i lane1 = _mm512_extracti32x8_epi32(a, 1);
1526
+ return _mm256_add_epi32(lane0, lane1);
722
1527
  #else
723
- __m128 lane0 = _mm512_extractf32x4_ps(a, 0);
724
- __m128 lane1 = _mm512_extractf32x4_ps(a, 1);
725
- __m128 lane2 = _mm512_extractf32x4_ps(a, 2);
726
- __m128 lane3 = _mm512_extractf32x4_ps(a, 3);
727
- __m128 res = pmul(pmul(lane0, lane1), pmul(lane2, lane3));
728
- res = pmul(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 3, 2)));
729
- return pfirst(pmul(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 0, 1))));
1528
+ __m128i lane0 = _mm512_extracti32x4_epi32(a, 0);
1529
+ __m128i lane1 = _mm512_extracti32x4_epi32(a, 1);
1530
+ __m128i lane2 = _mm512_extracti32x4_epi32(a, 2);
1531
+ __m128i lane3 = _mm512_extracti32x4_epi32(a, 3);
1532
+ __m128i sum0 = _mm_add_epi32(lane0, lane2);
1533
+ __m128i sum1 = _mm_add_epi32(lane1, lane3);
1534
+ return _mm256_inserti128_si256(_mm256_castsi128_si256(sum0), sum1, 1);
730
1535
  #endif
731
1536
  }
732
- template <>
733
- EIGEN_STRONG_INLINE double predux_mul<Packet8d>(const Packet8d& a) {
734
- __m256d lane0 = _mm512_extractf64x4_pd(a, 0);
735
- __m256d lane1 = _mm512_extractf64x4_pd(a, 1);
736
- __m256d res = pmul(lane0, lane1);
737
- res = pmul(res, _mm256_permute2f128_pd(res, res, 1));
738
- return pfirst(pmul(res, _mm256_shuffle_pd(res, res, 1)));
739
- }
740
-
741
- template <>
742
- EIGEN_STRONG_INLINE float predux_min<Packet16f>(const Packet16f& a) {
743
- __m128 lane0 = _mm512_extractf32x4_ps(a, 0);
744
- __m128 lane1 = _mm512_extractf32x4_ps(a, 1);
745
- __m128 lane2 = _mm512_extractf32x4_ps(a, 2);
746
- __m128 lane3 = _mm512_extractf32x4_ps(a, 3);
747
- __m128 res = _mm_min_ps(_mm_min_ps(lane0, lane1), _mm_min_ps(lane2, lane3));
748
- res = _mm_min_ps(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 3, 2)));
749
- return pfirst(_mm_min_ps(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 0, 1))));
750
- }
751
- template <>
752
- EIGEN_STRONG_INLINE double predux_min<Packet8d>(const Packet8d& a) {
753
- __m256d lane0 = _mm512_extractf64x4_pd(a, 0);
754
- __m256d lane1 = _mm512_extractf64x4_pd(a, 1);
755
- __m256d res = _mm256_min_pd(lane0, lane1);
756
- res = _mm256_min_pd(res, _mm256_permute2f128_pd(res, res, 1));
757
- return pfirst(_mm256_min_pd(res, _mm256_shuffle_pd(res, res, 1)));
758
- }
759
-
760
- template <>
761
- EIGEN_STRONG_INLINE float predux_max<Packet16f>(const Packet16f& a) {
762
- __m128 lane0 = _mm512_extractf32x4_ps(a, 0);
763
- __m128 lane1 = _mm512_extractf32x4_ps(a, 1);
764
- __m128 lane2 = _mm512_extractf32x4_ps(a, 2);
765
- __m128 lane3 = _mm512_extractf32x4_ps(a, 3);
766
- __m128 res = _mm_max_ps(_mm_max_ps(lane0, lane1), _mm_max_ps(lane2, lane3));
767
- res = _mm_max_ps(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 3, 2)));
768
- return pfirst(_mm_max_ps(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 0, 1))));
769
- }
770
1537
 
771
1538
  template <>
772
- EIGEN_STRONG_INLINE double predux_max<Packet8d>(const Packet8d& a) {
773
- __m256d lane0 = _mm512_extractf64x4_pd(a, 0);
774
- __m256d lane1 = _mm512_extractf64x4_pd(a, 1);
775
- __m256d res = _mm256_max_pd(lane0, lane1);
776
- res = _mm256_max_pd(res, _mm256_permute2f128_pd(res, res, 1));
777
- return pfirst(_mm256_max_pd(res, _mm256_shuffle_pd(res, res, 1)));
778
- }
779
-
780
- template<> EIGEN_STRONG_INLINE Packet16f preduxp<Packet16f>(const Packet16f* vecs)
781
- {
782
- EIGEN_EXTRACT_8f_FROM_16f(vecs[0], vecs0);
783
- EIGEN_EXTRACT_8f_FROM_16f(vecs[1], vecs1);
784
- EIGEN_EXTRACT_8f_FROM_16f(vecs[2], vecs2);
785
- EIGEN_EXTRACT_8f_FROM_16f(vecs[3], vecs3);
786
- EIGEN_EXTRACT_8f_FROM_16f(vecs[4], vecs4);
787
- EIGEN_EXTRACT_8f_FROM_16f(vecs[5], vecs5);
788
- EIGEN_EXTRACT_8f_FROM_16f(vecs[6], vecs6);
789
- EIGEN_EXTRACT_8f_FROM_16f(vecs[7], vecs7);
790
- EIGEN_EXTRACT_8f_FROM_16f(vecs[8], vecs8);
791
- EIGEN_EXTRACT_8f_FROM_16f(vecs[9], vecs9);
792
- EIGEN_EXTRACT_8f_FROM_16f(vecs[10], vecs10);
793
- EIGEN_EXTRACT_8f_FROM_16f(vecs[11], vecs11);
794
- EIGEN_EXTRACT_8f_FROM_16f(vecs[12], vecs12);
795
- EIGEN_EXTRACT_8f_FROM_16f(vecs[13], vecs13);
796
- EIGEN_EXTRACT_8f_FROM_16f(vecs[14], vecs14);
797
- EIGEN_EXTRACT_8f_FROM_16f(vecs[15], vecs15);
798
-
799
- __m256 hsum1 = _mm256_hadd_ps(vecs0_0, vecs1_0);
800
- __m256 hsum2 = _mm256_hadd_ps(vecs2_0, vecs3_0);
801
- __m256 hsum3 = _mm256_hadd_ps(vecs4_0, vecs5_0);
802
- __m256 hsum4 = _mm256_hadd_ps(vecs6_0, vecs7_0);
803
-
804
- __m256 hsum5 = _mm256_hadd_ps(hsum1, hsum1);
805
- __m256 hsum6 = _mm256_hadd_ps(hsum2, hsum2);
806
- __m256 hsum7 = _mm256_hadd_ps(hsum3, hsum3);
807
- __m256 hsum8 = _mm256_hadd_ps(hsum4, hsum4);
808
-
809
- __m256 perm1 = _mm256_permute2f128_ps(hsum5, hsum5, 0x23);
810
- __m256 perm2 = _mm256_permute2f128_ps(hsum6, hsum6, 0x23);
811
- __m256 perm3 = _mm256_permute2f128_ps(hsum7, hsum7, 0x23);
812
- __m256 perm4 = _mm256_permute2f128_ps(hsum8, hsum8, 0x23);
813
-
814
- __m256 sum1 = _mm256_add_ps(perm1, hsum5);
815
- __m256 sum2 = _mm256_add_ps(perm2, hsum6);
816
- __m256 sum3 = _mm256_add_ps(perm3, hsum7);
817
- __m256 sum4 = _mm256_add_ps(perm4, hsum8);
818
-
819
- __m256 blend1 = _mm256_blend_ps(sum1, sum2, 0xcc);
820
- __m256 blend2 = _mm256_blend_ps(sum3, sum4, 0xcc);
821
-
822
- __m256 final = _mm256_blend_ps(blend1, blend2, 0xf0);
823
-
824
- hsum1 = _mm256_hadd_ps(vecs0_1, vecs1_1);
825
- hsum2 = _mm256_hadd_ps(vecs2_1, vecs3_1);
826
- hsum3 = _mm256_hadd_ps(vecs4_1, vecs5_1);
827
- hsum4 = _mm256_hadd_ps(vecs6_1, vecs7_1);
828
-
829
- hsum5 = _mm256_hadd_ps(hsum1, hsum1);
830
- hsum6 = _mm256_hadd_ps(hsum2, hsum2);
831
- hsum7 = _mm256_hadd_ps(hsum3, hsum3);
832
- hsum8 = _mm256_hadd_ps(hsum4, hsum4);
833
-
834
- perm1 = _mm256_permute2f128_ps(hsum5, hsum5, 0x23);
835
- perm2 = _mm256_permute2f128_ps(hsum6, hsum6, 0x23);
836
- perm3 = _mm256_permute2f128_ps(hsum7, hsum7, 0x23);
837
- perm4 = _mm256_permute2f128_ps(hsum8, hsum8, 0x23);
838
-
839
- sum1 = _mm256_add_ps(perm1, hsum5);
840
- sum2 = _mm256_add_ps(perm2, hsum6);
841
- sum3 = _mm256_add_ps(perm3, hsum7);
842
- sum4 = _mm256_add_ps(perm4, hsum8);
843
-
844
- blend1 = _mm256_blend_ps(sum1, sum2, 0xcc);
845
- blend2 = _mm256_blend_ps(sum3, sum4, 0xcc);
846
-
847
- final = padd(final, _mm256_blend_ps(blend1, blend2, 0xf0));
848
-
849
- hsum1 = _mm256_hadd_ps(vecs8_0, vecs9_0);
850
- hsum2 = _mm256_hadd_ps(vecs10_0, vecs11_0);
851
- hsum3 = _mm256_hadd_ps(vecs12_0, vecs13_0);
852
- hsum4 = _mm256_hadd_ps(vecs14_0, vecs15_0);
853
-
854
- hsum5 = _mm256_hadd_ps(hsum1, hsum1);
855
- hsum6 = _mm256_hadd_ps(hsum2, hsum2);
856
- hsum7 = _mm256_hadd_ps(hsum3, hsum3);
857
- hsum8 = _mm256_hadd_ps(hsum4, hsum4);
858
-
859
- perm1 = _mm256_permute2f128_ps(hsum5, hsum5, 0x23);
860
- perm2 = _mm256_permute2f128_ps(hsum6, hsum6, 0x23);
861
- perm3 = _mm256_permute2f128_ps(hsum7, hsum7, 0x23);
862
- perm4 = _mm256_permute2f128_ps(hsum8, hsum8, 0x23);
863
-
864
- sum1 = _mm256_add_ps(perm1, hsum5);
865
- sum2 = _mm256_add_ps(perm2, hsum6);
866
- sum3 = _mm256_add_ps(perm3, hsum7);
867
- sum4 = _mm256_add_ps(perm4, hsum8);
868
-
869
- blend1 = _mm256_blend_ps(sum1, sum2, 0xcc);
870
- blend2 = _mm256_blend_ps(sum3, sum4, 0xcc);
871
-
872
- __m256 final_1 = _mm256_blend_ps(blend1, blend2, 0xf0);
873
-
874
- hsum1 = _mm256_hadd_ps(vecs8_1, vecs9_1);
875
- hsum2 = _mm256_hadd_ps(vecs10_1, vecs11_1);
876
- hsum3 = _mm256_hadd_ps(vecs12_1, vecs13_1);
877
- hsum4 = _mm256_hadd_ps(vecs14_1, vecs15_1);
878
-
879
- hsum5 = _mm256_hadd_ps(hsum1, hsum1);
880
- hsum6 = _mm256_hadd_ps(hsum2, hsum2);
881
- hsum7 = _mm256_hadd_ps(hsum3, hsum3);
882
- hsum8 = _mm256_hadd_ps(hsum4, hsum4);
883
-
884
- perm1 = _mm256_permute2f128_ps(hsum5, hsum5, 0x23);
885
- perm2 = _mm256_permute2f128_ps(hsum6, hsum6, 0x23);
886
- perm3 = _mm256_permute2f128_ps(hsum7, hsum7, 0x23);
887
- perm4 = _mm256_permute2f128_ps(hsum8, hsum8, 0x23);
888
-
889
- sum1 = _mm256_add_ps(perm1, hsum5);
890
- sum2 = _mm256_add_ps(perm2, hsum6);
891
- sum3 = _mm256_add_ps(perm3, hsum7);
892
- sum4 = _mm256_add_ps(perm4, hsum8);
893
-
894
- blend1 = _mm256_blend_ps(sum1, sum2, 0xcc);
895
- blend2 = _mm256_blend_ps(sum3, sum4, 0xcc);
896
-
897
- final_1 = padd(final_1, _mm256_blend_ps(blend1, blend2, 0xf0));
898
-
899
- __m512 final_output;
900
-
901
- EIGEN_INSERT_8f_INTO_16f(final_output, final, final_1);
902
- return final_output;
903
- }
904
-
905
- template<> EIGEN_STRONG_INLINE Packet8d preduxp<Packet8d>(const Packet8d* vecs)
906
- {
907
- Packet4d vecs0_0 = _mm512_extractf64x4_pd(vecs[0], 0);
908
- Packet4d vecs0_1 = _mm512_extractf64x4_pd(vecs[0], 1);
909
-
910
- Packet4d vecs1_0 = _mm512_extractf64x4_pd(vecs[1], 0);
911
- Packet4d vecs1_1 = _mm512_extractf64x4_pd(vecs[1], 1);
912
-
913
- Packet4d vecs2_0 = _mm512_extractf64x4_pd(vecs[2], 0);
914
- Packet4d vecs2_1 = _mm512_extractf64x4_pd(vecs[2], 1);
915
-
916
- Packet4d vecs3_0 = _mm512_extractf64x4_pd(vecs[3], 0);
917
- Packet4d vecs3_1 = _mm512_extractf64x4_pd(vecs[3], 1);
918
-
919
- Packet4d vecs4_0 = _mm512_extractf64x4_pd(vecs[4], 0);
920
- Packet4d vecs4_1 = _mm512_extractf64x4_pd(vecs[4], 1);
921
-
922
- Packet4d vecs5_0 = _mm512_extractf64x4_pd(vecs[5], 0);
923
- Packet4d vecs5_1 = _mm512_extractf64x4_pd(vecs[5], 1);
924
-
925
- Packet4d vecs6_0 = _mm512_extractf64x4_pd(vecs[6], 0);
926
- Packet4d vecs6_1 = _mm512_extractf64x4_pd(vecs[6], 1);
927
-
928
- Packet4d vecs7_0 = _mm512_extractf64x4_pd(vecs[7], 0);
929
- Packet4d vecs7_1 = _mm512_extractf64x4_pd(vecs[7], 1);
930
-
931
- Packet4d tmp0, tmp1;
932
-
933
- tmp0 = _mm256_hadd_pd(vecs0_0, vecs1_0);
934
- tmp0 = _mm256_add_pd(tmp0, _mm256_permute2f128_pd(tmp0, tmp0, 1));
935
-
936
- tmp1 = _mm256_hadd_pd(vecs2_0, vecs3_0);
937
- tmp1 = _mm256_add_pd(tmp1, _mm256_permute2f128_pd(tmp1, tmp1, 1));
938
-
939
- __m256d final_0 = _mm256_blend_pd(tmp0, tmp1, 0xC);
940
-
941
- tmp0 = _mm256_hadd_pd(vecs0_1, vecs1_1);
942
- tmp0 = _mm256_add_pd(tmp0, _mm256_permute2f128_pd(tmp0, tmp0, 1));
943
-
944
- tmp1 = _mm256_hadd_pd(vecs2_1, vecs3_1);
945
- tmp1 = _mm256_add_pd(tmp1, _mm256_permute2f128_pd(tmp1, tmp1, 1));
946
-
947
- final_0 = padd(final_0, _mm256_blend_pd(tmp0, tmp1, 0xC));
948
-
949
- tmp0 = _mm256_hadd_pd(vecs4_0, vecs5_0);
950
- tmp0 = _mm256_add_pd(tmp0, _mm256_permute2f128_pd(tmp0, tmp0, 1));
951
-
952
- tmp1 = _mm256_hadd_pd(vecs6_0, vecs7_0);
953
- tmp1 = _mm256_add_pd(tmp1, _mm256_permute2f128_pd(tmp1, tmp1, 1));
954
-
955
- __m256d final_1 = _mm256_blend_pd(tmp0, tmp1, 0xC);
956
-
957
- tmp0 = _mm256_hadd_pd(vecs4_1, vecs5_1);
958
- tmp0 = _mm256_add_pd(tmp0, _mm256_permute2f128_pd(tmp0, tmp0, 1));
959
-
960
- tmp1 = _mm256_hadd_pd(vecs6_1, vecs7_1);
961
- tmp1 = _mm256_add_pd(tmp1, _mm256_permute2f128_pd(tmp1, tmp1, 1));
962
-
963
- final_1 = padd(final_1, _mm256_blend_pd(tmp0, tmp1, 0xC));
964
-
965
- __m512d final_output = _mm512_insertf64x4(final_output, final_0, 0);
966
-
967
- return _mm512_insertf64x4(final_output, final_1, 1);
1539
+ EIGEN_STRONG_INLINE Packet4l predux_half_dowto4<Packet8l>(const Packet8l& a) {
1540
+ __m256i lane0 = _mm512_extracti64x4_epi64(a, 0);
1541
+ __m256i lane1 = _mm512_extracti64x4_epi64(a, 1);
1542
+ return _mm256_add_epi64(lane0, lane1);
968
1543
  }
969
-
970
-
971
1544
 
972
1545
  #define PACK_OUTPUT(OUTPUT, INPUT, INDEX, STRIDE) \
973
1546
  EIGEN_INSERT_8f_INTO_16f(OUTPUT[INDEX], INPUT[INDEX], INPUT[INDEX + STRIDE]);
@@ -1083,9 +1656,46 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16f, 16>& kernel) {
1083
1656
  PACK_OUTPUT(kernel.packet, tmp.packet, 14, 16);
1084
1657
  PACK_OUTPUT(kernel.packet, tmp.packet, 15, 16);
1085
1658
  }
1086
- #define PACK_OUTPUT_2(OUTPUT, INPUT, INDEX, STRIDE) \
1087
- EIGEN_INSERT_8f_INTO_16f(OUTPUT[INDEX], INPUT[2 * INDEX], \
1088
- INPUT[2 * INDEX + STRIDE]);
1659
+ #define PACK_OUTPUT_2(OUTPUT, INPUT, INDEX, STRIDE) \
1660
+ EIGEN_INSERT_8f_INTO_16f(OUTPUT[INDEX], INPUT[2 * INDEX], INPUT[2 * INDEX + STRIDE]);
1661
+
1662
+ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16f, 8>& kernel) {
1663
+ __m512 T0 = _mm512_unpacklo_ps(kernel.packet[0], kernel.packet[1]);
1664
+ __m512 T1 = _mm512_unpackhi_ps(kernel.packet[0], kernel.packet[1]);
1665
+ __m512 T2 = _mm512_unpacklo_ps(kernel.packet[2], kernel.packet[3]);
1666
+ __m512 T3 = _mm512_unpackhi_ps(kernel.packet[2], kernel.packet[3]);
1667
+ __m512 T4 = _mm512_unpacklo_ps(kernel.packet[4], kernel.packet[5]);
1668
+ __m512 T5 = _mm512_unpackhi_ps(kernel.packet[4], kernel.packet[5]);
1669
+ __m512 T6 = _mm512_unpacklo_ps(kernel.packet[6], kernel.packet[7]);
1670
+ __m512 T7 = _mm512_unpackhi_ps(kernel.packet[6], kernel.packet[7]);
1671
+
1672
+ kernel.packet[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2)));
1673
+ kernel.packet[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2)));
1674
+ kernel.packet[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3)));
1675
+ kernel.packet[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3)));
1676
+ kernel.packet[4] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6)));
1677
+ kernel.packet[5] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6)));
1678
+ kernel.packet[6] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7)));
1679
+ kernel.packet[7] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7)));
1680
+
1681
+ T0 = _mm512_shuffle_f32x4(kernel.packet[0], kernel.packet[4], 0x44);
1682
+ T1 = _mm512_shuffle_f32x4(kernel.packet[0], kernel.packet[4], 0xee);
1683
+ T2 = _mm512_shuffle_f32x4(kernel.packet[1], kernel.packet[5], 0x44);
1684
+ T3 = _mm512_shuffle_f32x4(kernel.packet[1], kernel.packet[5], 0xee);
1685
+ T4 = _mm512_shuffle_f32x4(kernel.packet[2], kernel.packet[6], 0x44);
1686
+ T5 = _mm512_shuffle_f32x4(kernel.packet[2], kernel.packet[6], 0xee);
1687
+ T6 = _mm512_shuffle_f32x4(kernel.packet[3], kernel.packet[7], 0x44);
1688
+ T7 = _mm512_shuffle_f32x4(kernel.packet[3], kernel.packet[7], 0xee);
1689
+
1690
+ kernel.packet[0] = _mm512_shuffle_f32x4(T0, T2, 0x88);
1691
+ kernel.packet[2] = _mm512_shuffle_f32x4(T0, T2, 0xdd);
1692
+ kernel.packet[1] = _mm512_shuffle_f32x4(T4, T6, 0x88);
1693
+ kernel.packet[3] = _mm512_shuffle_f32x4(T4, T6, 0xdd);
1694
+ kernel.packet[4] = _mm512_shuffle_f32x4(T1, T3, 0x88);
1695
+ kernel.packet[6] = _mm512_shuffle_f32x4(T1, T3, 0xdd);
1696
+ kernel.packet[5] = _mm512_shuffle_f32x4(T5, T7, 0x88);
1697
+ kernel.packet[7] = _mm512_shuffle_f32x4(T5, T7, 0xdd);
1698
+ }
1089
1699
 
1090
1700
  EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16f, 4>& kernel) {
1091
1701
  __m512 T0 = _mm512_unpacklo_ps(kernel.packet[0], kernel.packet[1]);
@@ -1127,8 +1737,11 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16f, 4>& kernel) {
1127
1737
 
1128
1738
  #define PACK_OUTPUT_D(OUTPUT, INPUT, INDEX, STRIDE) \
1129
1739
  OUTPUT[INDEX] = _mm512_insertf64x4(OUTPUT[INDEX], INPUT[(2 * INDEX)], 0); \
1130
- OUTPUT[INDEX] = \
1131
- _mm512_insertf64x4(OUTPUT[INDEX], INPUT[(2 * INDEX) + STRIDE], 1);
1740
+ OUTPUT[INDEX] = _mm512_insertf64x4(OUTPUT[INDEX], INPUT[(2 * INDEX) + STRIDE], 1);
1741
+
1742
+ #define PACK_OUTPUT_L(OUTPUT, INPUT, INDEX, STRIDE) \
1743
+ OUTPUT[INDEX] = _mm512_inserti64x4(OUTPUT[INDEX], INPUT[(2 * INDEX)], 0); \
1744
+ OUTPUT[INDEX] = _mm512_inserti64x4(OUTPUT[INDEX], INPUT[(2 * INDEX) + STRIDE], 1);
1132
1745
 
1133
1746
  EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8d, 4>& kernel) {
1134
1747
  __m512d T0 = _mm512_shuffle_pd(kernel.packet[0], kernel.packet[1], 0);
@@ -1138,23 +1751,15 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8d, 4>& kernel) {
1138
1751
 
1139
1752
  PacketBlock<Packet4d, 8> tmp;
1140
1753
 
1141
- tmp.packet[0] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 0),
1142
- _mm512_extractf64x4_pd(T2, 0), 0x20);
1143
- tmp.packet[1] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 0),
1144
- _mm512_extractf64x4_pd(T3, 0), 0x20);
1145
- tmp.packet[2] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 0),
1146
- _mm512_extractf64x4_pd(T2, 0), 0x31);
1147
- tmp.packet[3] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 0),
1148
- _mm512_extractf64x4_pd(T3, 0), 0x31);
1149
-
1150
- tmp.packet[4] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 1),
1151
- _mm512_extractf64x4_pd(T2, 1), 0x20);
1152
- tmp.packet[5] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 1),
1153
- _mm512_extractf64x4_pd(T3, 1), 0x20);
1154
- tmp.packet[6] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 1),
1155
- _mm512_extractf64x4_pd(T2, 1), 0x31);
1156
- tmp.packet[7] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 1),
1157
- _mm512_extractf64x4_pd(T3, 1), 0x31);
1754
+ tmp.packet[0] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 0), _mm512_extractf64x4_pd(T2, 0), 0x20);
1755
+ tmp.packet[1] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 0), _mm512_extractf64x4_pd(T3, 0), 0x20);
1756
+ tmp.packet[2] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 0), _mm512_extractf64x4_pd(T2, 0), 0x31);
1757
+ tmp.packet[3] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 0), _mm512_extractf64x4_pd(T3, 0), 0x31);
1758
+
1759
+ tmp.packet[4] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 1), _mm512_extractf64x4_pd(T2, 1), 0x20);
1760
+ tmp.packet[5] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 1), _mm512_extractf64x4_pd(T3, 1), 0x20);
1761
+ tmp.packet[6] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 1), _mm512_extractf64x4_pd(T2, 1), 0x31);
1762
+ tmp.packet[7] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 1), _mm512_extractf64x4_pd(T3, 1), 0x31);
1158
1763
 
1159
1764
  PACK_OUTPUT_D(kernel.packet, tmp.packet, 0, 1);
1160
1765
  PACK_OUTPUT_D(kernel.packet, tmp.packet, 1, 1);
@@ -1172,134 +1777,1370 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8d, 8>& kernel) {
1172
1777
  __m512d T6 = _mm512_unpacklo_pd(kernel.packet[6], kernel.packet[7]);
1173
1778
  __m512d T7 = _mm512_unpackhi_pd(kernel.packet[6], kernel.packet[7]);
1174
1779
 
1175
- PacketBlock<Packet4d, 16> tmp;
1176
-
1177
- tmp.packet[0] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 0),
1178
- _mm512_extractf64x4_pd(T2, 0), 0x20);
1179
- tmp.packet[1] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 0),
1180
- _mm512_extractf64x4_pd(T3, 0), 0x20);
1181
- tmp.packet[2] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 0),
1182
- _mm512_extractf64x4_pd(T2, 0), 0x31);
1183
- tmp.packet[3] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 0),
1184
- _mm512_extractf64x4_pd(T3, 0), 0x31);
1185
-
1186
- tmp.packet[4] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 1),
1187
- _mm512_extractf64x4_pd(T2, 1), 0x20);
1188
- tmp.packet[5] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 1),
1189
- _mm512_extractf64x4_pd(T3, 1), 0x20);
1190
- tmp.packet[6] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 1),
1191
- _mm512_extractf64x4_pd(T2, 1), 0x31);
1192
- tmp.packet[7] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 1),
1193
- _mm512_extractf64x4_pd(T3, 1), 0x31);
1194
-
1195
- tmp.packet[8] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T4, 0),
1196
- _mm512_extractf64x4_pd(T6, 0), 0x20);
1197
- tmp.packet[9] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T5, 0),
1198
- _mm512_extractf64x4_pd(T7, 0), 0x20);
1199
- tmp.packet[10] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T4, 0),
1200
- _mm512_extractf64x4_pd(T6, 0), 0x31);
1201
- tmp.packet[11] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T5, 0),
1202
- _mm512_extractf64x4_pd(T7, 0), 0x31);
1203
-
1204
- tmp.packet[12] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T4, 1),
1205
- _mm512_extractf64x4_pd(T6, 1), 0x20);
1206
- tmp.packet[13] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T5, 1),
1207
- _mm512_extractf64x4_pd(T7, 1), 0x20);
1208
- tmp.packet[14] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T4, 1),
1209
- _mm512_extractf64x4_pd(T6, 1), 0x31);
1210
- tmp.packet[15] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T5, 1),
1211
- _mm512_extractf64x4_pd(T7, 1), 0x31);
1212
-
1213
- PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 0, 8);
1214
- PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 1, 8);
1215
- PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 2, 8);
1216
- PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 3, 8);
1217
-
1218
- PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 4, 8);
1219
- PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 5, 8);
1220
- PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 6, 8);
1221
- PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 7, 8);
1222
- }
1223
- template <>
1224
- EIGEN_STRONG_INLINE Packet16f pblend(const Selector<16>& /*ifPacket*/,
1225
- const Packet16f& /*thenPacket*/,
1226
- const Packet16f& /*elsePacket*/) {
1227
- assert(false && "To be implemented");
1228
- return Packet16f();
1229
- }
1230
- template <>
1231
- EIGEN_STRONG_INLINE Packet8d pblend(const Selector<8>& ifPacket,
1232
- const Packet8d& thenPacket,
1780
+ kernel.packet[0] = _mm512_permutex_pd(T2, 0x4E);
1781
+ kernel.packet[0] = _mm512_mask_blend_pd(0xCC, T0, kernel.packet[0]);
1782
+ kernel.packet[2] = _mm512_permutex_pd(T0, 0x4E);
1783
+ kernel.packet[2] = _mm512_mask_blend_pd(0xCC, kernel.packet[2], T2);
1784
+ kernel.packet[1] = _mm512_permutex_pd(T3, 0x4E);
1785
+ kernel.packet[1] = _mm512_mask_blend_pd(0xCC, T1, kernel.packet[1]);
1786
+ kernel.packet[3] = _mm512_permutex_pd(T1, 0x4E);
1787
+ kernel.packet[3] = _mm512_mask_blend_pd(0xCC, kernel.packet[3], T3);
1788
+ kernel.packet[4] = _mm512_permutex_pd(T6, 0x4E);
1789
+ kernel.packet[4] = _mm512_mask_blend_pd(0xCC, T4, kernel.packet[4]);
1790
+ kernel.packet[6] = _mm512_permutex_pd(T4, 0x4E);
1791
+ kernel.packet[6] = _mm512_mask_blend_pd(0xCC, kernel.packet[6], T6);
1792
+ kernel.packet[5] = _mm512_permutex_pd(T7, 0x4E);
1793
+ kernel.packet[5] = _mm512_mask_blend_pd(0xCC, T5, kernel.packet[5]);
1794
+ kernel.packet[7] = _mm512_permutex_pd(T5, 0x4E);
1795
+ kernel.packet[7] = _mm512_mask_blend_pd(0xCC, kernel.packet[7], T7);
1796
+
1797
+ T0 = _mm512_shuffle_f64x2(kernel.packet[4], kernel.packet[4], 0x4E);
1798
+ T0 = _mm512_mask_blend_pd(0xF0, kernel.packet[0], T0);
1799
+ T4 = _mm512_shuffle_f64x2(kernel.packet[0], kernel.packet[0], 0x4E);
1800
+ T4 = _mm512_mask_blend_pd(0xF0, T4, kernel.packet[4]);
1801
+ T1 = _mm512_shuffle_f64x2(kernel.packet[5], kernel.packet[5], 0x4E);
1802
+ T1 = _mm512_mask_blend_pd(0xF0, kernel.packet[1], T1);
1803
+ T5 = _mm512_shuffle_f64x2(kernel.packet[1], kernel.packet[1], 0x4E);
1804
+ T5 = _mm512_mask_blend_pd(0xF0, T5, kernel.packet[5]);
1805
+ T2 = _mm512_shuffle_f64x2(kernel.packet[6], kernel.packet[6], 0x4E);
1806
+ T2 = _mm512_mask_blend_pd(0xF0, kernel.packet[2], T2);
1807
+ T6 = _mm512_shuffle_f64x2(kernel.packet[2], kernel.packet[2], 0x4E);
1808
+ T6 = _mm512_mask_blend_pd(0xF0, T6, kernel.packet[6]);
1809
+ T3 = _mm512_shuffle_f64x2(kernel.packet[7], kernel.packet[7], 0x4E);
1810
+ T3 = _mm512_mask_blend_pd(0xF0, kernel.packet[3], T3);
1811
+ T7 = _mm512_shuffle_f64x2(kernel.packet[3], kernel.packet[3], 0x4E);
1812
+ T7 = _mm512_mask_blend_pd(0xF0, T7, kernel.packet[7]);
1813
+
1814
+ kernel.packet[0] = T0;
1815
+ kernel.packet[1] = T1;
1816
+ kernel.packet[2] = T2;
1817
+ kernel.packet[3] = T3;
1818
+ kernel.packet[4] = T4;
1819
+ kernel.packet[5] = T5;
1820
+ kernel.packet[6] = T6;
1821
+ kernel.packet[7] = T7;
1822
+ }
1823
+
1824
+ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8l, 4>& kernel) {
1825
+ __m512i T0 = _mm512_castpd_si512(
1826
+ _mm512_shuffle_pd(_mm512_castsi512_pd(kernel.packet[0]), _mm512_castsi512_pd(kernel.packet[1]), 0));
1827
+ __m512i T1 = _mm512_castpd_si512(
1828
+ _mm512_shuffle_pd(_mm512_castsi512_pd(kernel.packet[0]), _mm512_castsi512_pd(kernel.packet[1]), 0xff));
1829
+ __m512i T2 = _mm512_castpd_si512(
1830
+ _mm512_shuffle_pd(_mm512_castsi512_pd(kernel.packet[2]), _mm512_castsi512_pd(kernel.packet[3]), 0));
1831
+ __m512i T3 = _mm512_castpd_si512(
1832
+ _mm512_shuffle_pd(_mm512_castsi512_pd(kernel.packet[2]), _mm512_castsi512_pd(kernel.packet[3]), 0xff));
1833
+
1834
+ PacketBlock<Packet4l, 8> tmp;
1835
+
1836
+ tmp.packet[0] = _mm256_permute2x128_si256(_mm512_extracti64x4_epi64(T0, 0), _mm512_extracti64x4_epi64(T2, 0), 0x20);
1837
+ tmp.packet[1] = _mm256_permute2x128_si256(_mm512_extracti64x4_epi64(T1, 0), _mm512_extracti64x4_epi64(T3, 0), 0x20);
1838
+ tmp.packet[2] = _mm256_permute2x128_si256(_mm512_extracti64x4_epi64(T0, 0), _mm512_extracti64x4_epi64(T2, 0), 0x31);
1839
+ tmp.packet[3] = _mm256_permute2x128_si256(_mm512_extracti64x4_epi64(T1, 0), _mm512_extracti64x4_epi64(T3, 0), 0x31);
1840
+
1841
+ tmp.packet[4] = _mm256_permute2x128_si256(_mm512_extracti64x4_epi64(T0, 1), _mm512_extracti64x4_epi64(T2, 1), 0x20);
1842
+ tmp.packet[5] = _mm256_permute2x128_si256(_mm512_extracti64x4_epi64(T1, 1), _mm512_extracti64x4_epi64(T3, 1), 0x20);
1843
+ tmp.packet[6] = _mm256_permute2x128_si256(_mm512_extracti64x4_epi64(T0, 1), _mm512_extracti64x4_epi64(T2, 1), 0x31);
1844
+ tmp.packet[7] = _mm256_permute2x128_si256(_mm512_extracti64x4_epi64(T1, 1), _mm512_extracti64x4_epi64(T3, 1), 0x31);
1845
+
1846
+ PACK_OUTPUT_L(kernel.packet, tmp.packet, 0, 1);
1847
+ PACK_OUTPUT_L(kernel.packet, tmp.packet, 1, 1);
1848
+ PACK_OUTPUT_L(kernel.packet, tmp.packet, 2, 1);
1849
+ PACK_OUTPUT_L(kernel.packet, tmp.packet, 3, 1);
1850
+ }
1851
+
1852
+ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8l, 8>& kernel) {
1853
+ __m512i T0 = _mm512_unpacklo_epi64(kernel.packet[0], kernel.packet[1]);
1854
+ __m512i T1 = _mm512_unpackhi_epi64(kernel.packet[0], kernel.packet[1]);
1855
+ __m512i T2 = _mm512_unpacklo_epi64(kernel.packet[2], kernel.packet[3]);
1856
+ __m512i T3 = _mm512_unpackhi_epi64(kernel.packet[2], kernel.packet[3]);
1857
+ __m512i T4 = _mm512_unpacklo_epi64(kernel.packet[4], kernel.packet[5]);
1858
+ __m512i T5 = _mm512_unpackhi_epi64(kernel.packet[4], kernel.packet[5]);
1859
+ __m512i T6 = _mm512_unpacklo_epi64(kernel.packet[6], kernel.packet[7]);
1860
+ __m512i T7 = _mm512_unpackhi_epi64(kernel.packet[6], kernel.packet[7]);
1861
+
1862
+ kernel.packet[0] = _mm512_permutex_epi64(T2, 0x4E);
1863
+ kernel.packet[0] = _mm512_mask_blend_epi64(0xCC, T0, kernel.packet[0]);
1864
+ kernel.packet[2] = _mm512_permutex_epi64(T0, 0x4E);
1865
+ kernel.packet[2] = _mm512_mask_blend_epi64(0xCC, kernel.packet[2], T2);
1866
+ kernel.packet[1] = _mm512_permutex_epi64(T3, 0x4E);
1867
+ kernel.packet[1] = _mm512_mask_blend_epi64(0xCC, T1, kernel.packet[1]);
1868
+ kernel.packet[3] = _mm512_permutex_epi64(T1, 0x4E);
1869
+ kernel.packet[3] = _mm512_mask_blend_epi64(0xCC, kernel.packet[3], T3);
1870
+ kernel.packet[4] = _mm512_permutex_epi64(T6, 0x4E);
1871
+ kernel.packet[4] = _mm512_mask_blend_epi64(0xCC, T4, kernel.packet[4]);
1872
+ kernel.packet[6] = _mm512_permutex_epi64(T4, 0x4E);
1873
+ kernel.packet[6] = _mm512_mask_blend_epi64(0xCC, kernel.packet[6], T6);
1874
+ kernel.packet[5] = _mm512_permutex_epi64(T7, 0x4E);
1875
+ kernel.packet[5] = _mm512_mask_blend_epi64(0xCC, T5, kernel.packet[5]);
1876
+ kernel.packet[7] = _mm512_permutex_epi64(T5, 0x4E);
1877
+ kernel.packet[7] = _mm512_mask_blend_epi64(0xCC, kernel.packet[7], T7);
1878
+
1879
+ T0 = _mm512_shuffle_i64x2(kernel.packet[4], kernel.packet[4], 0x4E);
1880
+ T0 = _mm512_mask_blend_epi64(0xF0, kernel.packet[0], T0);
1881
+ T4 = _mm512_shuffle_i64x2(kernel.packet[0], kernel.packet[0], 0x4E);
1882
+ T4 = _mm512_mask_blend_epi64(0xF0, T4, kernel.packet[4]);
1883
+ T1 = _mm512_shuffle_i64x2(kernel.packet[5], kernel.packet[5], 0x4E);
1884
+ T1 = _mm512_mask_blend_epi64(0xF0, kernel.packet[1], T1);
1885
+ T5 = _mm512_shuffle_i64x2(kernel.packet[1], kernel.packet[1], 0x4E);
1886
+ T5 = _mm512_mask_blend_epi64(0xF0, T5, kernel.packet[5]);
1887
+ T2 = _mm512_shuffle_i64x2(kernel.packet[6], kernel.packet[6], 0x4E);
1888
+ T2 = _mm512_mask_blend_epi64(0xF0, kernel.packet[2], T2);
1889
+ T6 = _mm512_shuffle_i64x2(kernel.packet[2], kernel.packet[2], 0x4E);
1890
+ T6 = _mm512_mask_blend_epi64(0xF0, T6, kernel.packet[6]);
1891
+ T3 = _mm512_shuffle_i64x2(kernel.packet[7], kernel.packet[7], 0x4E);
1892
+ T3 = _mm512_mask_blend_epi64(0xF0, kernel.packet[3], T3);
1893
+ T7 = _mm512_shuffle_i64x2(kernel.packet[3], kernel.packet[3], 0x4E);
1894
+ T7 = _mm512_mask_blend_epi64(0xF0, T7, kernel.packet[7]);
1895
+
1896
+ kernel.packet[0] = T0;
1897
+ kernel.packet[1] = T1;
1898
+ kernel.packet[2] = T2;
1899
+ kernel.packet[3] = T3;
1900
+ kernel.packet[4] = T4;
1901
+ kernel.packet[5] = T5;
1902
+ kernel.packet[6] = T6;
1903
+ kernel.packet[7] = T7;
1904
+ }
1905
+
1906
// Helper macros for the Packet16i transposes.
// PACK_OUTPUT_I32:   OUTPUT[INDEX] <- INPUT[INDEX] and INPUT[INDEX + STRIDE] (two 8 x int32 halves,
//                    presumably low then high -- EIGEN_INSERT_8i_INTO_16i is defined elsewhere).
// PACK_OUTPUT_I32_2: same, but reads INPUT[2 * INDEX] and INPUT[2 * INDEX + STRIDE].
// INDEX/STRIDE are parenthesized so expression arguments expand correctly, and the statement
// macros carry no trailing ';' so each use site supplies exactly one (no stray empty statements).
#define PACK_OUTPUT_I32(OUTPUT, INPUT, INDEX, STRIDE) \
  EIGEN_INSERT_8i_INTO_16i(OUTPUT[(INDEX)], INPUT[(INDEX)], INPUT[(INDEX) + (STRIDE)])

#define PACK_OUTPUT_I32_2(OUTPUT, INPUT, INDEX, STRIDE) \
  EIGEN_INSERT_8i_INTO_16i(OUTPUT[(INDEX)], INPUT[2 * (INDEX)], INPUT[2 * (INDEX) + (STRIDE)])

// _mm512_shuffle_ps applied to int32 packets: reinterpret as float, shuffle with immediate M, cast back.
#define SHUFFLE_EPI32(A, B, M) _mm512_castps_si512(_mm512_shuffle_ps(_mm512_castsi512_ps(A), _mm512_castsi512_ps(B), M))
1913
+
1914
// In-register transpose of a 16x16 block of int32 (16 packets x 16 lanes): on return,
// kernel.packet[i] holds what was column i of the input block.
//  1. unpacklo/hi_epi32 interleave adjacent packet pairs.
//  2. SHUFFLE_EPI32 combines those so that, in each 128-bit lane k, S(4r + c) holds
//     rows 4r..4r+3 of column 4k + c.
//  3. Each S is split into 256-bit halves and permute2f128 joins rows 0-7 (tmp[0..15])
//     or rows 8-15 (tmp[16..31]) of a single column.
//  4. PACK_OUTPUT_I32 concatenates tmp[i] and tmp[i + 16] into output packet i.
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16i, 16>& kernel) {
  // Stage 1: interleave 32-bit elements of packet pairs (0,1), (2,3), ..., (14,15).
  __m512i T0 = _mm512_unpacklo_epi32(kernel.packet[0], kernel.packet[1]);
  __m512i T1 = _mm512_unpackhi_epi32(kernel.packet[0], kernel.packet[1]);
  __m512i T2 = _mm512_unpacklo_epi32(kernel.packet[2], kernel.packet[3]);
  __m512i T3 = _mm512_unpackhi_epi32(kernel.packet[2], kernel.packet[3]);
  __m512i T4 = _mm512_unpacklo_epi32(kernel.packet[4], kernel.packet[5]);
  __m512i T5 = _mm512_unpackhi_epi32(kernel.packet[4], kernel.packet[5]);
  __m512i T6 = _mm512_unpacklo_epi32(kernel.packet[6], kernel.packet[7]);
  __m512i T7 = _mm512_unpackhi_epi32(kernel.packet[6], kernel.packet[7]);
  __m512i T8 = _mm512_unpacklo_epi32(kernel.packet[8], kernel.packet[9]);
  __m512i T9 = _mm512_unpackhi_epi32(kernel.packet[8], kernel.packet[9]);
  __m512i T10 = _mm512_unpacklo_epi32(kernel.packet[10], kernel.packet[11]);
  __m512i T11 = _mm512_unpackhi_epi32(kernel.packet[10], kernel.packet[11]);
  __m512i T12 = _mm512_unpacklo_epi32(kernel.packet[12], kernel.packet[13]);
  __m512i T13 = _mm512_unpackhi_epi32(kernel.packet[12], kernel.packet[13]);
  __m512i T14 = _mm512_unpacklo_epi32(kernel.packet[14], kernel.packet[15]);
  __m512i T15 = _mm512_unpackhi_epi32(kernel.packet[14], kernel.packet[15]);
  // Stage 2: gather 4 rows of one column per 128-bit lane.
  __m512i S0 = SHUFFLE_EPI32(T0, T2, _MM_SHUFFLE(1, 0, 1, 0));
  __m512i S1 = SHUFFLE_EPI32(T0, T2, _MM_SHUFFLE(3, 2, 3, 2));
  __m512i S2 = SHUFFLE_EPI32(T1, T3, _MM_SHUFFLE(1, 0, 1, 0));
  __m512i S3 = SHUFFLE_EPI32(T1, T3, _MM_SHUFFLE(3, 2, 3, 2));
  __m512i S4 = SHUFFLE_EPI32(T4, T6, _MM_SHUFFLE(1, 0, 1, 0));
  __m512i S5 = SHUFFLE_EPI32(T4, T6, _MM_SHUFFLE(3, 2, 3, 2));
  __m512i S6 = SHUFFLE_EPI32(T5, T7, _MM_SHUFFLE(1, 0, 1, 0));
  __m512i S7 = SHUFFLE_EPI32(T5, T7, _MM_SHUFFLE(3, 2, 3, 2));
  __m512i S8 = SHUFFLE_EPI32(T8, T10, _MM_SHUFFLE(1, 0, 1, 0));
  __m512i S9 = SHUFFLE_EPI32(T8, T10, _MM_SHUFFLE(3, 2, 3, 2));
  __m512i S10 = SHUFFLE_EPI32(T9, T11, _MM_SHUFFLE(1, 0, 1, 0));
  __m512i S11 = SHUFFLE_EPI32(T9, T11, _MM_SHUFFLE(3, 2, 3, 2));
  __m512i S12 = SHUFFLE_EPI32(T12, T14, _MM_SHUFFLE(1, 0, 1, 0));
  __m512i S13 = SHUFFLE_EPI32(T12, T14, _MM_SHUFFLE(3, 2, 3, 2));
  __m512i S14 = SHUFFLE_EPI32(T13, T15, _MM_SHUFFLE(1, 0, 1, 0));
  __m512i S15 = SHUFFLE_EPI32(T13, T15, _MM_SHUFFLE(3, 2, 3, 2));

  // Stage 3a: EIGEN_EXTRACT_8i_FROM_16i(X, X) declares X_0 / X_1 -- assumed to be the low / high
  // 256-bit halves of X (the macro is defined elsewhere; TODO confirm).
  EIGEN_EXTRACT_8i_FROM_16i(S0, S0);
  EIGEN_EXTRACT_8i_FROM_16i(S1, S1);
  EIGEN_EXTRACT_8i_FROM_16i(S2, S2);
  EIGEN_EXTRACT_8i_FROM_16i(S3, S3);
  EIGEN_EXTRACT_8i_FROM_16i(S4, S4);
  EIGEN_EXTRACT_8i_FROM_16i(S5, S5);
  EIGEN_EXTRACT_8i_FROM_16i(S6, S6);
  EIGEN_EXTRACT_8i_FROM_16i(S7, S7);
  EIGEN_EXTRACT_8i_FROM_16i(S8, S8);
  EIGEN_EXTRACT_8i_FROM_16i(S9, S9);
  EIGEN_EXTRACT_8i_FROM_16i(S10, S10);
  EIGEN_EXTRACT_8i_FROM_16i(S11, S11);
  EIGEN_EXTRACT_8i_FROM_16i(S12, S12);
  EIGEN_EXTRACT_8i_FROM_16i(S13, S13);
  EIGEN_EXTRACT_8i_FROM_16i(S14, S14);
  EIGEN_EXTRACT_8i_FROM_16i(S15, S15);

  // tmp[c] (c < 16) will hold rows 0-7 of column c; tmp[16 + c] rows 8-15 of column c.
  PacketBlock<Packet8i, 32> tmp;

  // Stage 3b: 0x20 joins the two low 128-bit lanes, 0x31 the two high 128-bit lanes.
  tmp.packet[0] = _mm256_permute2f128_si256(S0_0, S4_0, 0x20);
  tmp.packet[1] = _mm256_permute2f128_si256(S1_0, S5_0, 0x20);
  tmp.packet[2] = _mm256_permute2f128_si256(S2_0, S6_0, 0x20);
  tmp.packet[3] = _mm256_permute2f128_si256(S3_0, S7_0, 0x20);
  tmp.packet[4] = _mm256_permute2f128_si256(S0_0, S4_0, 0x31);
  tmp.packet[5] = _mm256_permute2f128_si256(S1_0, S5_0, 0x31);
  tmp.packet[6] = _mm256_permute2f128_si256(S2_0, S6_0, 0x31);
  tmp.packet[7] = _mm256_permute2f128_si256(S3_0, S7_0, 0x31);

  tmp.packet[8] = _mm256_permute2f128_si256(S0_1, S4_1, 0x20);
  tmp.packet[9] = _mm256_permute2f128_si256(S1_1, S5_1, 0x20);
  tmp.packet[10] = _mm256_permute2f128_si256(S2_1, S6_1, 0x20);
  tmp.packet[11] = _mm256_permute2f128_si256(S3_1, S7_1, 0x20);
  tmp.packet[12] = _mm256_permute2f128_si256(S0_1, S4_1, 0x31);
  tmp.packet[13] = _mm256_permute2f128_si256(S1_1, S5_1, 0x31);
  tmp.packet[14] = _mm256_permute2f128_si256(S2_1, S6_1, 0x31);
  tmp.packet[15] = _mm256_permute2f128_si256(S3_1, S7_1, 0x31);

  // Second set of _m256 outputs (rows 8-15)
  tmp.packet[16] = _mm256_permute2f128_si256(S8_0, S12_0, 0x20);
  tmp.packet[17] = _mm256_permute2f128_si256(S9_0, S13_0, 0x20);
  tmp.packet[18] = _mm256_permute2f128_si256(S10_0, S14_0, 0x20);
  tmp.packet[19] = _mm256_permute2f128_si256(S11_0, S15_0, 0x20);
  tmp.packet[20] = _mm256_permute2f128_si256(S8_0, S12_0, 0x31);
  tmp.packet[21] = _mm256_permute2f128_si256(S9_0, S13_0, 0x31);
  tmp.packet[22] = _mm256_permute2f128_si256(S10_0, S14_0, 0x31);
  tmp.packet[23] = _mm256_permute2f128_si256(S11_0, S15_0, 0x31);

  tmp.packet[24] = _mm256_permute2f128_si256(S8_1, S12_1, 0x20);
  tmp.packet[25] = _mm256_permute2f128_si256(S9_1, S13_1, 0x20);
  tmp.packet[26] = _mm256_permute2f128_si256(S10_1, S14_1, 0x20);
  tmp.packet[27] = _mm256_permute2f128_si256(S11_1, S15_1, 0x20);
  tmp.packet[28] = _mm256_permute2f128_si256(S8_1, S12_1, 0x31);
  tmp.packet[29] = _mm256_permute2f128_si256(S9_1, S13_1, 0x31);
  tmp.packet[30] = _mm256_permute2f128_si256(S10_1, S14_1, 0x31);
  tmp.packet[31] = _mm256_permute2f128_si256(S11_1, S15_1, 0x31);

  // Stage 4: pack them into the output; packet i <- {tmp[i], tmp[i + 16]} (full column i).
  PACK_OUTPUT_I32(kernel.packet, tmp.packet, 0, 16);
  PACK_OUTPUT_I32(kernel.packet, tmp.packet, 1, 16);
  PACK_OUTPUT_I32(kernel.packet, tmp.packet, 2, 16);
  PACK_OUTPUT_I32(kernel.packet, tmp.packet, 3, 16);

  PACK_OUTPUT_I32(kernel.packet, tmp.packet, 4, 16);
  PACK_OUTPUT_I32(kernel.packet, tmp.packet, 5, 16);
  PACK_OUTPUT_I32(kernel.packet, tmp.packet, 6, 16);
  PACK_OUTPUT_I32(kernel.packet, tmp.packet, 7, 16);

  PACK_OUTPUT_I32(kernel.packet, tmp.packet, 8, 16);
  PACK_OUTPUT_I32(kernel.packet, tmp.packet, 9, 16);
  PACK_OUTPUT_I32(kernel.packet, tmp.packet, 10, 16);
  PACK_OUTPUT_I32(kernel.packet, tmp.packet, 11, 16);

  PACK_OUTPUT_I32(kernel.packet, tmp.packet, 12, 16);
  PACK_OUTPUT_I32(kernel.packet, tmp.packet, 13, 16);
  PACK_OUTPUT_I32(kernel.packet, tmp.packet, 14, 16);
  PACK_OUTPUT_I32(kernel.packet, tmp.packet, 15, 16);
}
2025
+
2026
// Transpose of a 4x16 int32 block (4 packets x 16 lanes). On return, kernel.packet[i] holds
// columns 4i, 4i+1, 4i+2, 4i+3 of the input, each as the 4 values from rows 0..3.
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet16i, 4>& kernel) {
  // Interleave 32-bit elements of packets (0,1) and (2,3).
  __m512i T0 = _mm512_unpacklo_epi32(kernel.packet[0], kernel.packet[1]);
  __m512i T1 = _mm512_unpackhi_epi32(kernel.packet[0], kernel.packet[1]);
  __m512i T2 = _mm512_unpacklo_epi32(kernel.packet[2], kernel.packet[3]);
  __m512i T3 = _mm512_unpackhi_epi32(kernel.packet[2], kernel.packet[3]);

  // In 128-bit lane k, S(c) now holds rows 0..3 of column 4k + c.
  __m512i S0 = SHUFFLE_EPI32(T0, T2, _MM_SHUFFLE(1, 0, 1, 0));
  __m512i S1 = SHUFFLE_EPI32(T0, T2, _MM_SHUFFLE(3, 2, 3, 2));
  __m512i S2 = SHUFFLE_EPI32(T1, T3, _MM_SHUFFLE(1, 0, 1, 0));
  __m512i S3 = SHUFFLE_EPI32(T1, T3, _MM_SHUFFLE(3, 2, 3, 2));

  // Declares S*_0 / S*_1 -- assumed low / high 256-bit halves (macro defined elsewhere).
  EIGEN_EXTRACT_8i_FROM_16i(S0, S0);
  EIGEN_EXTRACT_8i_FROM_16i(S1, S1);
  EIGEN_EXTRACT_8i_FROM_16i(S2, S2);
  EIGEN_EXTRACT_8i_FROM_16i(S3, S3);

  // tmp[2i] / tmp[2i + 1] hold columns (4i, 4i+1) / (4i+2, 4i+3).
  PacketBlock<Packet8i, 8> tmp;

  tmp.packet[0] = _mm256_permute2f128_si256(S0_0, S1_0, 0x20);
  tmp.packet[1] = _mm256_permute2f128_si256(S2_0, S3_0, 0x20);
  tmp.packet[2] = _mm256_permute2f128_si256(S0_0, S1_0, 0x31);
  tmp.packet[3] = _mm256_permute2f128_si256(S2_0, S3_0, 0x31);

  tmp.packet[4] = _mm256_permute2f128_si256(S0_1, S1_1, 0x20);
  tmp.packet[5] = _mm256_permute2f128_si256(S2_1, S3_1, 0x20);
  tmp.packet[6] = _mm256_permute2f128_si256(S0_1, S1_1, 0x31);
  tmp.packet[7] = _mm256_permute2f128_si256(S2_1, S3_1, 0x31);

  // Output packet i <- {tmp[2i], tmp[2i + 1]}.
  PACK_OUTPUT_I32_2(kernel.packet, tmp.packet, 0, 1);
  PACK_OUTPUT_I32_2(kernel.packet, tmp.packet, 1, 1);
  PACK_OUTPUT_I32_2(kernel.packet, tmp.packet, 2, 1);
  PACK_OUTPUT_I32_2(kernel.packet, tmp.packet, 3, 1);
}
2059
+
2060
+ template <size_t N>
2061
+ EIGEN_STRONG_INLINE int avx512_blend_mask(const Selector<N>& ifPacket) {
2062
+ alignas(__m128i) uint8_t aux[sizeof(__m128i)];
2063
+ for (size_t i = 0; i < N; i++) aux[i] = static_cast<uint8_t>(ifPacket.select[i]);
2064
+ __m128i paux = _mm_sub_epi8(_mm_setzero_si128(), _mm_load_si128(reinterpret_cast<const __m128i*>(aux)));
2065
+ return _mm_movemask_epi8(paux);
2066
+ }
2067
+
2068
+ template <>
2069
+ EIGEN_STRONG_INLINE Packet16f pblend(const Selector<16>& ifPacket, const Packet16f& thenPacket,
2070
+ const Packet16f& elsePacket) {
2071
+ __mmask16 m = avx512_blend_mask(ifPacket);
2072
+ return _mm512_mask_blend_ps(m, elsePacket, thenPacket);
2073
+ }
2074
+ template <>
2075
+ EIGEN_STRONG_INLINE Packet8d pblend(const Selector<8>& ifPacket, const Packet8d& thenPacket,
1233
2076
  const Packet8d& elsePacket) {
1234
- __mmask8 m = (ifPacket.select[0] )
1235
- | (ifPacket.select[1]<<1)
1236
- | (ifPacket.select[2]<<2)
1237
- | (ifPacket.select[3]<<3)
1238
- | (ifPacket.select[4]<<4)
1239
- | (ifPacket.select[5]<<5)
1240
- | (ifPacket.select[6]<<6)
1241
- | (ifPacket.select[7]<<7);
2077
+ __mmask8 m = avx512_blend_mask(ifPacket);
1242
2078
  return _mm512_mask_blend_pd(m, elsePacket, thenPacket);
1243
2079
  }
1244
2080
 
1245
- template<> EIGEN_STRONG_INLINE Packet16i pcast<Packet16f, Packet16i>(const Packet16f& a) {
1246
- return _mm512_cvttps_epi32(a);
2081
+ // Packet math for Eigen::half
2082
+ #ifndef EIGEN_VECTORIZE_AVX512FP16
2083
+ template <>
2084
+ EIGEN_STRONG_INLINE Packet16h pset1<Packet16h>(const Eigen::half& from) {
2085
+ return _mm256_set1_epi16(from.x);
2086
+ }
2087
+
2088
+ template <>
2089
+ EIGEN_STRONG_INLINE Eigen::half pfirst<Packet16h>(const Packet16h& from) {
2090
+ return half_impl::raw_uint16_to_half(static_cast<unsigned short>(_mm256_extract_epi16(from, 0)));
2091
+ }
2092
+
2093
+ template <>
2094
+ EIGEN_STRONG_INLINE Packet16h pload<Packet16h>(const Eigen::half* from) {
2095
+ return _mm256_load_si256(reinterpret_cast<const __m256i*>(from));
2096
+ }
2097
+
2098
+ template <>
2099
+ EIGEN_STRONG_INLINE Packet16h ploadu<Packet16h>(const Eigen::half* from) {
2100
+ return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(from));
2101
+ }
2102
+
2103
+ template <>
2104
+ EIGEN_STRONG_INLINE void pstore<half>(Eigen::half* to, const Packet16h& from) {
2105
+ // (void*) -> workaround clang warning:
2106
+ // cast from 'Eigen::half *' to '__m256i *' increases required alignment from 2 to 32
2107
+ EIGEN_DEBUG_ALIGNED_STORE
2108
+ _mm256_store_si256((__m256i*)(void*)to, from);
2109
+ }
2110
+
2111
+ template <>
2112
+ EIGEN_STRONG_INLINE void pstoreu<half>(Eigen::half* to, const Packet16h& from) {
2113
+ // (void*) -> workaround clang warning:
2114
+ // cast from 'Eigen::half *' to '__m256i *' increases required alignment from 2 to 32
2115
+ EIGEN_DEBUG_UNALIGNED_STORE
2116
+ _mm256_storeu_si256((__m256i*)(void*)to, from);
2117
+ }
2118
+
2119
+ template <>
2120
+ EIGEN_STRONG_INLINE Packet16h ploaddup<Packet16h>(const Eigen::half* from) {
2121
+ unsigned short a = from[0].x;
2122
+ unsigned short b = from[1].x;
2123
+ unsigned short c = from[2].x;
2124
+ unsigned short d = from[3].x;
2125
+ unsigned short e = from[4].x;
2126
+ unsigned short f = from[5].x;
2127
+ unsigned short g = from[6].x;
2128
+ unsigned short h = from[7].x;
2129
+ return _mm256_set_epi16(h, h, g, g, f, f, e, e, d, d, c, c, b, b, a, a);
2130
+ }
2131
+
2132
+ template <>
2133
+ EIGEN_STRONG_INLINE Packet16h ploadquad(const Eigen::half* from) {
2134
+ unsigned short a = from[0].x;
2135
+ unsigned short b = from[1].x;
2136
+ unsigned short c = from[2].x;
2137
+ unsigned short d = from[3].x;
2138
+ return _mm256_set_epi16(d, d, d, d, c, c, c, c, b, b, b, b, a, a, a, a);
2139
+ }
2140
+
2141
+ EIGEN_STRONG_INLINE Packet16f half2float(const Packet16h& a) { return _mm512_cvtph_ps(a); }
2142
+
2143
+ EIGEN_STRONG_INLINE Packet16h float2half(const Packet16f& a) {
2144
+ return _mm512_cvtps_ph(a, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
2145
+ }
2146
+
2147
+ template <>
2148
+ EIGEN_STRONG_INLINE Packet16h ptrue(const Packet16h& a) {
2149
+ return Packet16h(ptrue(Packet8i(a)));
1247
2150
  }
1248
2151
 
1249
- template<> EIGEN_STRONG_INLINE Packet16f pcast<Packet16i, Packet16f>(const Packet16i& a) {
1250
- return _mm512_cvtepi32_ps(a);
2152
+ template <>
2153
+ EIGEN_STRONG_INLINE Packet16h pabs(const Packet16h& a) {
2154
+ const __m256i sign_mask = _mm256_set1_epi16(static_cast<numext::uint16_t>(0x8000));
2155
+ return _mm256_andnot_si256(sign_mask, a);
1251
2156
  }
1252
2157
 
1253
- template <int Offset>
1254
- struct palign_impl<Offset, Packet16f> {
1255
- static EIGEN_STRONG_INLINE void run(Packet16f& first,
1256
- const Packet16f& second) {
1257
- if (Offset != 0) {
1258
- __m512i first_idx = _mm512_set_epi32(
1259
- Offset + 15, Offset + 14, Offset + 13, Offset + 12, Offset + 11,
1260
- Offset + 10, Offset + 9, Offset + 8, Offset + 7, Offset + 6,
1261
- Offset + 5, Offset + 4, Offset + 3, Offset + 2, Offset + 1, Offset);
2158
+ template <>
2159
+ EIGEN_STRONG_INLINE Packet16h pmin<Packet16h>(const Packet16h& a, const Packet16h& b) {
2160
+ return float2half(pmin<Packet16f>(half2float(a), half2float(b)));
2161
+ }
1262
2162
 
1263
- __m512i second_idx =
1264
- _mm512_set_epi32(Offset - 1, Offset - 2, Offset - 3, Offset - 4,
1265
- Offset - 5, Offset - 6, Offset - 7, Offset - 8,
1266
- Offset - 9, Offset - 10, Offset - 11, Offset - 12,
1267
- Offset - 13, Offset - 14, Offset - 15, Offset - 16);
2163
+ template <>
2164
+ EIGEN_STRONG_INLINE Packet16h pmax<Packet16h>(const Packet16h& a, const Packet16h& b) {
2165
+ return float2half(pmax<Packet16f>(half2float(a), half2float(b)));
2166
+ }
1268
2167
 
1269
- unsigned short mask = 0xFFFF;
1270
- mask <<= (16 - Offset);
2168
+ template <>
2169
+ EIGEN_STRONG_INLINE Packet16h plset<Packet16h>(const half& a) {
2170
+ return float2half(plset<Packet16f>(static_cast<float>(a)));
2171
+ }
1271
2172
 
1272
- first = _mm512_permutexvar_ps(first_idx, first);
1273
- Packet16f tmp = _mm512_permutexvar_ps(second_idx, second);
1274
- first = _mm512_mask_blend_ps(mask, first, tmp);
2173
+ template <>
2174
+ EIGEN_STRONG_INLINE Packet16h por(const Packet16h& a, const Packet16h& b) {
2175
+ // in some cases Packet8i is a wrapper around __m256i, so we need to
2176
+ // cast to Packet8i to call the correct overload.
2177
+ return Packet16h(por(Packet8i(a), Packet8i(b)));
2178
+ }
2179
+ template <>
2180
+ EIGEN_STRONG_INLINE Packet16h pxor(const Packet16h& a, const Packet16h& b) {
2181
+ return Packet16h(pxor(Packet8i(a), Packet8i(b)));
2182
+ }
2183
+ template <>
2184
+ EIGEN_STRONG_INLINE Packet16h pand(const Packet16h& a, const Packet16h& b) {
2185
+ return Packet16h(pand(Packet8i(a), Packet8i(b)));
2186
+ }
2187
+ template <>
2188
+ EIGEN_STRONG_INLINE Packet16h pandnot(const Packet16h& a, const Packet16h& b) {
2189
+ return Packet16h(pandnot(Packet8i(a), Packet8i(b)));
2190
+ }
2191
+
2192
+ template <>
2193
+ EIGEN_STRONG_INLINE Packet16h pselect(const Packet16h& mask, const Packet16h& a, const Packet16h& b) {
2194
+ return _mm256_blendv_epi8(b, a, mask);
2195
+ }
2196
+
2197
+ template <>
2198
+ EIGEN_STRONG_INLINE Packet16h pround<Packet16h>(const Packet16h& a) {
2199
+ return float2half(pround<Packet16f>(half2float(a)));
2200
+ }
2201
+
2202
+ template <>
2203
+ EIGEN_STRONG_INLINE Packet16h print<Packet16h>(const Packet16h& a) {
2204
+ return float2half(print<Packet16f>(half2float(a)));
2205
+ }
2206
+
2207
+ template <>
2208
+ EIGEN_STRONG_INLINE Packet16h pceil<Packet16h>(const Packet16h& a) {
2209
+ return float2half(pceil<Packet16f>(half2float(a)));
2210
+ }
2211
+
2212
+ template <>
2213
+ EIGEN_STRONG_INLINE Packet16h pfloor<Packet16h>(const Packet16h& a) {
2214
+ return float2half(pfloor<Packet16f>(half2float(a)));
2215
+ }
2216
+
2217
+ template <>
2218
+ EIGEN_STRONG_INLINE Packet16h ptrunc<Packet16h>(const Packet16h& a) {
2219
+ return float2half(ptrunc<Packet16f>(half2float(a)));
2220
+ }
2221
+
2222
+ template <>
2223
+ EIGEN_STRONG_INLINE Packet16h pcmp_eq(const Packet16h& a, const Packet16h& b) {
2224
+ Packet16f af = half2float(a);
2225
+ Packet16f bf = half2float(b);
2226
+ return Pack32To16(pcmp_eq(af, bf));
2227
+ }
2228
+
2229
+ template <>
2230
+ EIGEN_STRONG_INLINE Packet16h pcmp_le(const Packet16h& a, const Packet16h& b) {
2231
+ return Pack32To16(pcmp_le(half2float(a), half2float(b)));
2232
+ }
2233
+
2234
+ template <>
2235
+ EIGEN_STRONG_INLINE Packet16h pcmp_lt(const Packet16h& a, const Packet16h& b) {
2236
+ return Pack32To16(pcmp_lt(half2float(a), half2float(b)));
2237
+ }
2238
+
2239
+ template <>
2240
+ EIGEN_STRONG_INLINE Packet16h pcmp_lt_or_nan(const Packet16h& a, const Packet16h& b) {
2241
+ return Pack32To16(pcmp_lt_or_nan(half2float(a), half2float(b)));
2242
+ }
2243
+
2244
+ template <>
2245
+ EIGEN_STRONG_INLINE Packet16h pconj(const Packet16h& a) {
2246
+ return a;
2247
+ }
2248
+
2249
+ template <>
2250
+ EIGEN_STRONG_INLINE Packet16h pnegate(const Packet16h& a) {
2251
+ Packet16h sign_mask = _mm256_set1_epi16(static_cast<unsigned short>(0x8000));
2252
+ return _mm256_xor_si256(a, sign_mask);
2253
+ }
2254
+
2255
+ template <>
2256
+ EIGEN_STRONG_INLINE Packet16h padd<Packet16h>(const Packet16h& a, const Packet16h& b) {
2257
+ Packet16f af = half2float(a);
2258
+ Packet16f bf = half2float(b);
2259
+ Packet16f rf = padd(af, bf);
2260
+ return float2half(rf);
2261
+ }
2262
+
2263
+ template <>
2264
+ EIGEN_STRONG_INLINE Packet16h psub<Packet16h>(const Packet16h& a, const Packet16h& b) {
2265
+ Packet16f af = half2float(a);
2266
+ Packet16f bf = half2float(b);
2267
+ Packet16f rf = psub(af, bf);
2268
+ return float2half(rf);
2269
+ }
2270
+
2271
+ template <>
2272
+ EIGEN_STRONG_INLINE Packet16h pmul<Packet16h>(const Packet16h& a, const Packet16h& b) {
2273
+ Packet16f af = half2float(a);
2274
+ Packet16f bf = half2float(b);
2275
+ Packet16f rf = pmul(af, bf);
2276
+ return float2half(rf);
2277
+ }
2278
+
2279
+ template <>
2280
+ EIGEN_STRONG_INLINE Packet16h pdiv<Packet16h>(const Packet16h& a, const Packet16h& b) {
2281
+ Packet16f af = half2float(a);
2282
+ Packet16f bf = half2float(b);
2283
+ Packet16f rf = pdiv(af, bf);
2284
+ return float2half(rf);
2285
+ }
2286
+
2287
+ template <>
2288
+ EIGEN_STRONG_INLINE Packet16h pmadd<Packet16h>(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
2289
+ return float2half(pmadd(half2float(a), half2float(b), half2float(c)));
2290
+ }
2291
+
2292
+ template <>
2293
+ EIGEN_STRONG_INLINE Packet16h pmsub<Packet16h>(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
2294
+ return float2half(pmsub(half2float(a), half2float(b), half2float(c)));
2295
+ }
2296
+
2297
+ template <>
2298
+ EIGEN_STRONG_INLINE Packet16h pnmadd<Packet16h>(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
2299
+ return float2half(pnmadd(half2float(a), half2float(b), half2float(c)));
2300
+ }
2301
+
2302
+ template <>
2303
+ EIGEN_STRONG_INLINE Packet16h pnmsub<Packet16h>(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
2304
+ return float2half(pnmsub(half2float(a), half2float(b), half2float(c)));
2305
+ }
2306
+
2307
+ template <>
2308
+ EIGEN_STRONG_INLINE Packet8h predux_half_dowto4<Packet16h>(const Packet16h& a) {
2309
+ Packet8h lane0 = _mm256_extractf128_si256(a, 0);
2310
+ Packet8h lane1 = _mm256_extractf128_si256(a, 1);
2311
+ return padd<Packet8h>(lane0, lane1);
2312
+ }
2313
+
2314
+ template <>
2315
+ EIGEN_STRONG_INLINE Packet16h preverse(const Packet16h& a) {
2316
+ __m128i m = _mm_setr_epi8(14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1);
2317
+ return _mm256_insertf128_si256(_mm256_castsi128_si256(_mm_shuffle_epi8(_mm256_extractf128_si256(a, 1), m)),
2318
+ _mm_shuffle_epi8(_mm256_extractf128_si256(a, 0), m), 1);
2319
+ }
2320
+
2321
+ template <>
2322
+ EIGEN_STRONG_INLINE Packet16h pgather<Eigen::half, Packet16h>(const Eigen::half* from, Index stride) {
2323
+ return _mm256_set_epi16(from[15 * stride].x, from[14 * stride].x, from[13 * stride].x, from[12 * stride].x,
2324
+ from[11 * stride].x, from[10 * stride].x, from[9 * stride].x, from[8 * stride].x,
2325
+ from[7 * stride].x, from[6 * stride].x, from[5 * stride].x, from[4 * stride].x,
2326
+ from[3 * stride].x, from[2 * stride].x, from[1 * stride].x, from[0 * stride].x);
2327
+ }
2328
+
2329
+ template <>
2330
+ EIGEN_STRONG_INLINE void pscatter<half, Packet16h>(half* to, const Packet16h& from, Index stride) {
2331
+ EIGEN_ALIGN64 half aux[16];
2332
+ pstore(aux, from);
2333
+ to[stride * 0] = aux[0];
2334
+ to[stride * 1] = aux[1];
2335
+ to[stride * 2] = aux[2];
2336
+ to[stride * 3] = aux[3];
2337
+ to[stride * 4] = aux[4];
2338
+ to[stride * 5] = aux[5];
2339
+ to[stride * 6] = aux[6];
2340
+ to[stride * 7] = aux[7];
2341
+ to[stride * 8] = aux[8];
2342
+ to[stride * 9] = aux[9];
2343
+ to[stride * 10] = aux[10];
2344
+ to[stride * 11] = aux[11];
2345
+ to[stride * 12] = aux[12];
2346
+ to[stride * 13] = aux[13];
2347
+ to[stride * 14] = aux[14];
2348
+ to[stride * 15] = aux[15];
2349
+ }
2350
+
2351
// In-register transpose of a 16x16 block of halves (16 packets x 16 lanes of 16 bits).
// Inputs are named a..p (rows 0..15). Successive 16-, 32- and 64-bit unpacks gather, in each
// 128-bit lane, rows a-h (or i-p) of one column: lane 0 holds column c, lane 1 column c + 8.
// A final permute2x128 joins the a-h and i-p halves, so kernel.packet[c] ends up as column c.
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16h, 16>& kernel) {
  __m256i a = kernel.packet[0];
  __m256i b = kernel.packet[1];
  __m256i c = kernel.packet[2];
  __m256i d = kernel.packet[3];
  __m256i e = kernel.packet[4];
  __m256i f = kernel.packet[5];
  __m256i g = kernel.packet[6];
  __m256i h = kernel.packet[7];
  __m256i i = kernel.packet[8];
  __m256i j = kernel.packet[9];
  __m256i k = kernel.packet[10];
  __m256i l = kernel.packet[11];
  __m256i m = kernel.packet[12];
  __m256i n = kernel.packet[13];
  __m256i o = kernel.packet[14];
  __m256i p = kernel.packet[15];

  // 16-bit interleave of row pairs (low four elements of each 128-bit lane).
  __m256i ab_07 = _mm256_unpacklo_epi16(a, b);
  __m256i cd_07 = _mm256_unpacklo_epi16(c, d);
  __m256i ef_07 = _mm256_unpacklo_epi16(e, f);
  __m256i gh_07 = _mm256_unpacklo_epi16(g, h);
  __m256i ij_07 = _mm256_unpacklo_epi16(i, j);
  __m256i kl_07 = _mm256_unpacklo_epi16(k, l);
  __m256i mn_07 = _mm256_unpacklo_epi16(m, n);
  __m256i op_07 = _mm256_unpacklo_epi16(o, p);

  // 16-bit interleave of row pairs (high four elements of each 128-bit lane).
  __m256i ab_8f = _mm256_unpackhi_epi16(a, b);
  __m256i cd_8f = _mm256_unpackhi_epi16(c, d);
  __m256i ef_8f = _mm256_unpackhi_epi16(e, f);
  __m256i gh_8f = _mm256_unpackhi_epi16(g, h);
  __m256i ij_8f = _mm256_unpackhi_epi16(i, j);
  __m256i kl_8f = _mm256_unpackhi_epi16(k, l);
  __m256i mn_8f = _mm256_unpackhi_epi16(m, n);
  __m256i op_8f = _mm256_unpackhi_epi16(o, p);

  // 32-bit interleave: groups of 4 rows per column.
  __m256i abcd_03 = _mm256_unpacklo_epi32(ab_07, cd_07);
  __m256i abcd_47 = _mm256_unpackhi_epi32(ab_07, cd_07);
  __m256i efgh_03 = _mm256_unpacklo_epi32(ef_07, gh_07);
  __m256i efgh_47 = _mm256_unpackhi_epi32(ef_07, gh_07);
  __m256i ijkl_03 = _mm256_unpacklo_epi32(ij_07, kl_07);
  __m256i ijkl_47 = _mm256_unpackhi_epi32(ij_07, kl_07);
  __m256i mnop_03 = _mm256_unpacklo_epi32(mn_07, op_07);
  __m256i mnop_47 = _mm256_unpackhi_epi32(mn_07, op_07);

  __m256i abcd_8b = _mm256_unpacklo_epi32(ab_8f, cd_8f);
  __m256i abcd_cf = _mm256_unpackhi_epi32(ab_8f, cd_8f);
  __m256i efgh_8b = _mm256_unpacklo_epi32(ef_8f, gh_8f);
  __m256i efgh_cf = _mm256_unpackhi_epi32(ef_8f, gh_8f);
  __m256i ijkl_8b = _mm256_unpacklo_epi32(ij_8f, kl_8f);
  __m256i ijkl_cf = _mm256_unpackhi_epi32(ij_8f, kl_8f);
  __m256i mnop_8b = _mm256_unpacklo_epi32(mn_8f, op_8f);
  __m256i mnop_cf = _mm256_unpackhi_epi32(mn_8f, op_8f);

  // 64-bit interleave: each 128-bit lane now holds 8 rows (a-h or i-p) of one column.
  __m256i abcdefgh_01 = _mm256_unpacklo_epi64(abcd_03, efgh_03);
  __m256i abcdefgh_23 = _mm256_unpackhi_epi64(abcd_03, efgh_03);
  __m256i ijklmnop_01 = _mm256_unpacklo_epi64(ijkl_03, mnop_03);
  __m256i ijklmnop_23 = _mm256_unpackhi_epi64(ijkl_03, mnop_03);
  __m256i abcdefgh_45 = _mm256_unpacklo_epi64(abcd_47, efgh_47);
  __m256i abcdefgh_67 = _mm256_unpackhi_epi64(abcd_47, efgh_47);
  __m256i ijklmnop_45 = _mm256_unpacklo_epi64(ijkl_47, mnop_47);
  __m256i ijklmnop_67 = _mm256_unpackhi_epi64(ijkl_47, mnop_47);
  __m256i abcdefgh_89 = _mm256_unpacklo_epi64(abcd_8b, efgh_8b);
  __m256i abcdefgh_ab = _mm256_unpackhi_epi64(abcd_8b, efgh_8b);
  __m256i ijklmnop_89 = _mm256_unpacklo_epi64(ijkl_8b, mnop_8b);
  __m256i ijklmnop_ab = _mm256_unpackhi_epi64(ijkl_8b, mnop_8b);
  __m256i abcdefgh_cd = _mm256_unpacklo_epi64(abcd_cf, efgh_cf);
  __m256i abcdefgh_ef = _mm256_unpackhi_epi64(abcd_cf, efgh_cf);
  __m256i ijklmnop_cd = _mm256_unpacklo_epi64(ijkl_cf, mnop_cf);
  __m256i ijklmnop_ef = _mm256_unpackhi_epi64(ijkl_cf, mnop_cf);

  // NOTE: no unpacklo/hi instr in this case, so using permute instr.
  // 0x20 joins the low 128-bit lanes (columns 0-7), 0x31 the high lanes (columns 8-15).
  __m256i a_p_0 = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x20);
  __m256i a_p_1 = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x20);
  __m256i a_p_2 = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x20);
  __m256i a_p_3 = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x20);
  __m256i a_p_4 = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x20);
  __m256i a_p_5 = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x20);
  __m256i a_p_6 = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x20);
  __m256i a_p_7 = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x20);
  __m256i a_p_8 = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x31);
  __m256i a_p_9 = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x31);
  __m256i a_p_a = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x31);
  __m256i a_p_b = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x31);
  __m256i a_p_c = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x31);
  __m256i a_p_d = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x31);
  __m256i a_p_e = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x31);
  __m256i a_p_f = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x31);

  kernel.packet[0] = a_p_0;
  kernel.packet[1] = a_p_1;
  kernel.packet[2] = a_p_2;
  kernel.packet[3] = a_p_3;
  kernel.packet[4] = a_p_4;
  kernel.packet[5] = a_p_5;
  kernel.packet[6] = a_p_6;
  kernel.packet[7] = a_p_7;
  kernel.packet[8] = a_p_8;
  kernel.packet[9] = a_p_9;
  kernel.packet[10] = a_p_a;
  kernel.packet[11] = a_p_b;
  kernel.packet[12] = a_p_c;
  kernel.packet[13] = a_p_d;
  kernel.packet[14] = a_p_e;
  kernel.packet[15] = a_p_f;
}
+ }
2457
+
2458
+ EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16h, 8>& kernel) {
2459
+ EIGEN_ALIGN64 half in[8][16];
2460
+ pstore<half>(in[0], kernel.packet[0]);
2461
+ pstore<half>(in[1], kernel.packet[1]);
2462
+ pstore<half>(in[2], kernel.packet[2]);
2463
+ pstore<half>(in[3], kernel.packet[3]);
2464
+ pstore<half>(in[4], kernel.packet[4]);
2465
+ pstore<half>(in[5], kernel.packet[5]);
2466
+ pstore<half>(in[6], kernel.packet[6]);
2467
+ pstore<half>(in[7], kernel.packet[7]);
2468
+
2469
+ EIGEN_ALIGN64 half out[8][16];
2470
+
2471
+ for (int i = 0; i < 8; ++i) {
2472
+ for (int j = 0; j < 8; ++j) {
2473
+ out[i][j] = in[j][2 * i];
2474
+ }
2475
+ for (int j = 0; j < 8; ++j) {
2476
+ out[i][j + 8] = in[j][2 * i + 1];
1275
2477
  }
1276
2478
  }
1277
- };
1278
- template <int Offset>
1279
- struct palign_impl<Offset, Packet8d> {
1280
- static EIGEN_STRONG_INLINE void run(Packet8d& first, const Packet8d& second) {
1281
- if (Offset != 0) {
1282
- __m512i first_idx = _mm512_set_epi32(
1283
- 0, Offset + 7, 0, Offset + 6, 0, Offset + 5, 0, Offset + 4, 0,
1284
- Offset + 3, 0, Offset + 2, 0, Offset + 1, 0, Offset);
1285
-
1286
- __m512i second_idx = _mm512_set_epi32(
1287
- 0, Offset - 1, 0, Offset - 2, 0, Offset - 3, 0, Offset - 4, 0,
1288
- Offset - 5, 0, Offset - 6, 0, Offset - 7, 0, Offset - 8);
1289
-
1290
- unsigned char mask = 0xFF;
1291
- mask <<= (8 - Offset);
1292
-
1293
- first = _mm512_permutexvar_pd(first_idx, first);
1294
- Packet8d tmp = _mm512_permutexvar_pd(second_idx, second);
1295
- first = _mm512_mask_blend_pd(mask, first, tmp);
2479
+
2480
+ kernel.packet[0] = pload<Packet16h>(out[0]);
2481
+ kernel.packet[1] = pload<Packet16h>(out[1]);
2482
+ kernel.packet[2] = pload<Packet16h>(out[2]);
2483
+ kernel.packet[3] = pload<Packet16h>(out[3]);
2484
+ kernel.packet[4] = pload<Packet16h>(out[4]);
2485
+ kernel.packet[5] = pload<Packet16h>(out[5]);
2486
+ kernel.packet[6] = pload<Packet16h>(out[6]);
2487
+ kernel.packet[7] = pload<Packet16h>(out[7]);
2488
+ }
2489
+
2490
+ EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16h, 4>& kernel) {
2491
+ EIGEN_ALIGN64 half in[4][16];
2492
+ pstore<half>(in[0], kernel.packet[0]);
2493
+ pstore<half>(in[1], kernel.packet[1]);
2494
+ pstore<half>(in[2], kernel.packet[2]);
2495
+ pstore<half>(in[3], kernel.packet[3]);
2496
+
2497
+ EIGEN_ALIGN64 half out[4][16];
2498
+
2499
+ for (int i = 0; i < 4; ++i) {
2500
+ for (int j = 0; j < 4; ++j) {
2501
+ out[i][j] = in[j][4 * i];
2502
+ }
2503
+ for (int j = 0; j < 4; ++j) {
2504
+ out[i][j + 4] = in[j][4 * i + 1];
2505
+ }
2506
+ for (int j = 0; j < 4; ++j) {
2507
+ out[i][j + 8] = in[j][4 * i + 2];
2508
+ }
2509
+ for (int j = 0; j < 4; ++j) {
2510
+ out[i][j + 12] = in[j][4 * i + 3];
1296
2511
  }
1297
2512
  }
2513
+
2514
+ kernel.packet[0] = pload<Packet16h>(out[0]);
2515
+ kernel.packet[1] = pload<Packet16h>(out[1]);
2516
+ kernel.packet[2] = pload<Packet16h>(out[2]);
2517
+ kernel.packet[3] = pload<Packet16h>(out[3]);
2518
+ }
2519
+
2520
+ #endif // EIGEN_VECTORIZE_AVX512FP16
2521
+
2522
// Marks Packet16bf as an arithmetic packet type.
template <>
struct is_arithmetic<Packet16bf> {
  enum { value = true };
};

// Vectorization traits for bfloat16: 16 lanes per full packet, Packet8bf as the half packet.
// Trig, tanh and erf are enabled only under EIGEN_FAST_MATH; log/log1p/expm1/ndtri/bessel only
// when EIGEN_VECTORIZE_AVX512DQ is defined.
template <>
struct packet_traits<bfloat16> : default_packet_traits {
  typedef Packet16bf type;
  typedef Packet8bf half;
  enum {
    Vectorizable = 1,
    AlignedOnScalar = 1,
    size = 16,
    HasBlend = 0,
    HasInsert = 1,
    HasSin = EIGEN_FAST_MATH,
    HasCos = EIGEN_FAST_MATH,
    HasSqrt = 1,
    HasRsqrt = 1,
#ifdef EIGEN_VECTORIZE_AVX512DQ
    HasLog = 1,  // NOTE(review): originally annotated "Currently fails test with bad accuracy." yet enabled; confirm.
    HasLog1p = 1,
    HasExpm1 = 1,
    HasNdtri = 1,
    HasBessel = 1,
#endif
    HasExp = 1,
    HasTanh = EIGEN_FAST_MATH,
    HasErf = EIGEN_FAST_MATH,
    HasCmp = 1,
    HasDiv = 1
  };
};

// Packet16bf holds 16 bfloat16 values, requires 32-byte alignment, and has no masked load/store.
template <>
struct unpacket_traits<Packet16bf> {
  typedef bfloat16 type;
  enum {
    size = 16,
    alignment = Aligned32,
    vectorizable = true,
    masked_load_available = false,
    masked_store_available = false
  };
  typedef Packet8bf half;
};
1299
2568
 
2569
+ template <>
2570
+ EIGEN_STRONG_INLINE Packet16bf pset1<Packet16bf>(const bfloat16& from) {
2571
+ return _mm256_set1_epi16(from.value);
2572
+ }
2573
+
2574
+ template <>
2575
+ EIGEN_STRONG_INLINE bfloat16 pfirst<Packet16bf>(const Packet16bf& from) {
2576
+ bfloat16 t;
2577
+ t.value = static_cast<unsigned short>(_mm256_extract_epi16(from, 0));
2578
+ return t;
2579
+ }
2580
+
2581
+ template <>
2582
+ EIGEN_STRONG_INLINE Packet16bf pload<Packet16bf>(const bfloat16* from) {
2583
+ return _mm256_load_si256(reinterpret_cast<const __m256i*>(from));
2584
+ }
2585
+
2586
+ template <>
2587
+ EIGEN_STRONG_INLINE Packet16bf ploadu<Packet16bf>(const bfloat16* from) {
2588
+ return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(from));
2589
+ }
2590
+
2591
+ template <>
2592
+ EIGEN_STRONG_INLINE void pstore<bfloat16>(bfloat16* to, const Packet16bf& from) {
2593
+ EIGEN_DEBUG_ALIGNED_STORE
2594
+ _mm256_store_si256(reinterpret_cast<__m256i*>(to), from);
2595
+ }
2596
+
2597
+ template <>
2598
+ EIGEN_STRONG_INLINE void pstoreu<bfloat16>(bfloat16* to, const Packet16bf& from) {
2599
+ EIGEN_DEBUG_UNALIGNED_STORE
2600
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from);
2601
+ }
2602
+
2603
+ template <>
2604
+ EIGEN_STRONG_INLINE Packet16bf ploaddup<Packet16bf>(const bfloat16* from) {
2605
+ unsigned short a = from[0].value;
2606
+ unsigned short b = from[1].value;
2607
+ unsigned short c = from[2].value;
2608
+ unsigned short d = from[3].value;
2609
+ unsigned short e = from[4].value;
2610
+ unsigned short f = from[5].value;
2611
+ unsigned short g = from[6].value;
2612
+ unsigned short h = from[7].value;
2613
+ return _mm256_set_epi16(h, h, g, g, f, f, e, e, d, d, c, c, b, b, a, a);
2614
+ }
2615
+
2616
+ template <>
2617
+ EIGEN_STRONG_INLINE Packet16bf ploadquad(const bfloat16* from) {
2618
+ unsigned short a = from[0].value;
2619
+ unsigned short b = from[1].value;
2620
+ unsigned short c = from[2].value;
2621
+ unsigned short d = from[3].value;
2622
+ return _mm256_set_epi16(d, d, d, d, c, c, c, c, b, b, b, b, a, a, a, a);
2623
+ }
2624
+
2625
+ EIGEN_STRONG_INLINE Packet16f Bf16ToF32(const Packet16bf& a) {
2626
+ return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16));
2627
+ }
2628
+
2629
+ // Convert float to bfloat16 according to round-to-nearest-even/denormals algorithm.
2630
+ EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) {
2631
+ Packet16bf r;
2632
+
2633
+ #if defined(EIGEN_VECTORIZE_AVX512BF16) && EIGEN_GNUC_STRICT_AT_LEAST(10, 1, 0)
2634
+ // Since GCC 10.1 supports avx512bf16 and C style explicit cast
2635
+ // (C++ static_cast is not supported yet), do conversion via intrinsic
2636
+ // and register path for performance.
2637
+ r = (__m256i)(_mm512_cvtneps_pbh(a));
2638
+
2639
+ #else
2640
+ __m512i t;
2641
+ __m512i input = _mm512_castps_si512(a);
2642
+ __m512i nan = _mm512_set1_epi32(0x7fc0);
2643
+
2644
+ // uint32_t lsb = (input >> 16) & 1;
2645
+ t = _mm512_and_si512(_mm512_srli_epi32(input, 16), _mm512_set1_epi32(1));
2646
+ // uint32_t rounding_bias = 0x7fff + lsb;
2647
+ t = _mm512_add_epi32(t, _mm512_set1_epi32(0x7fff));
2648
+ // input += rounding_bias;
2649
+ t = _mm512_add_epi32(t, input);
2650
+ // input = input >> 16;
2651
+ t = _mm512_srli_epi32(t, 16);
2652
+
2653
+ // Check NaN before converting back to bf16
2654
+ __mmask16 mask = _mm512_cmp_ps_mask(a, a, _CMP_ORD_Q);
2655
+
2656
+ t = _mm512_mask_blend_epi32(mask, nan, t);
2657
+ // output.value = static_cast<uint16_t>(input);
2658
+ r = _mm512_cvtepi32_epi16(t);
2659
+ #endif // EIGEN_VECTORIZE_AVX512BF16
2660
+
2661
+ return r;
2662
+ }
2663
+
2664
+ template <>
2665
+ EIGEN_STRONG_INLINE Packet16bf ptrue(const Packet16bf& a) {
2666
+ return Packet16bf(ptrue<Packet8i>(Packet8i(a)));
2667
+ }
2668
+
2669
+ template <>
2670
+ EIGEN_STRONG_INLINE Packet16bf por(const Packet16bf& a, const Packet16bf& b) {
2671
+ return Packet16bf(por<Packet8i>(Packet8i(a), Packet8i(b)));
2672
+ }
2673
+
2674
+ template <>
2675
+ EIGEN_STRONG_INLINE Packet16bf pxor(const Packet16bf& a, const Packet16bf& b) {
2676
+ return Packet16bf(pxor<Packet8i>(Packet8i(a), Packet8i(b)));
2677
+ }
2678
+
2679
+ template <>
2680
+ EIGEN_STRONG_INLINE Packet16bf pand(const Packet16bf& a, const Packet16bf& b) {
2681
+ return Packet16bf(pand<Packet8i>(Packet8i(a), Packet8i(b)));
2682
+ }
2683
+
2684
+ template <>
2685
+ EIGEN_STRONG_INLINE Packet16bf pandnot(const Packet16bf& a, const Packet16bf& b) {
2686
+ return Packet16bf(pandnot<Packet8i>(Packet8i(a), Packet8i(b)));
2687
+ }
2688
+
2689
+ template <>
2690
+ EIGEN_STRONG_INLINE Packet16bf pselect(const Packet16bf& mask, const Packet16bf& a, const Packet16bf& b) {
2691
+ // Input mask is expected to be all 0/1, handle it with 8-bit
2692
+ // intrinsic for performance.
2693
+ return _mm256_blendv_epi8(b, a, mask);
2694
+ }
2695
+
2696
+ template <>
2697
+ EIGEN_STRONG_INLINE Packet16bf pround<Packet16bf>(const Packet16bf& a) {
2698
+ return F32ToBf16(pround<Packet16f>(Bf16ToF32(a)));
2699
+ }
2700
+
2701
+ template <>
2702
+ EIGEN_STRONG_INLINE Packet16bf print<Packet16bf>(const Packet16bf& a) {
2703
+ return F32ToBf16(print<Packet16f>(Bf16ToF32(a)));
2704
+ }
2705
+
2706
+ template <>
2707
+ EIGEN_STRONG_INLINE Packet16bf pceil<Packet16bf>(const Packet16bf& a) {
2708
+ return F32ToBf16(pceil<Packet16f>(Bf16ToF32(a)));
2709
+ }
2710
+
2711
+ template <>
2712
+ EIGEN_STRONG_INLINE Packet16bf pfloor<Packet16bf>(const Packet16bf& a) {
2713
+ return F32ToBf16(pfloor<Packet16f>(Bf16ToF32(a)));
2714
+ }
2715
+
2716
+ template <>
2717
+ EIGEN_STRONG_INLINE Packet16bf ptrunc<Packet16bf>(const Packet16bf& a) {
2718
+ return F32ToBf16(ptrunc<Packet16f>(Bf16ToF32(a)));
2719
+ }
2720
+
2721
+ template <>
2722
+ EIGEN_STRONG_INLINE Packet16bf pcmp_eq(const Packet16bf& a, const Packet16bf& b) {
2723
+ return Pack32To16(pcmp_eq(Bf16ToF32(a), Bf16ToF32(b)));
2724
+ }
2725
+
2726
+ template <>
2727
+ EIGEN_STRONG_INLINE Packet16bf pcmp_le(const Packet16bf& a, const Packet16bf& b) {
2728
+ return Pack32To16(pcmp_le(Bf16ToF32(a), Bf16ToF32(b)));
2729
+ }
2730
+
2731
+ template <>
2732
+ EIGEN_STRONG_INLINE Packet16bf pcmp_lt(const Packet16bf& a, const Packet16bf& b) {
2733
+ return Pack32To16(pcmp_lt(Bf16ToF32(a), Bf16ToF32(b)));
2734
+ }
2735
+
2736
+ template <>
2737
+ EIGEN_STRONG_INLINE Packet16bf pcmp_lt_or_nan(const Packet16bf& a, const Packet16bf& b) {
2738
+ return Pack32To16(pcmp_lt_or_nan(Bf16ToF32(a), Bf16ToF32(b)));
2739
+ }
2740
+
2741
+ template <>
2742
+ EIGEN_STRONG_INLINE Packet16bf pnegate(const Packet16bf& a) {
2743
+ Packet16bf sign_mask = _mm256_set1_epi16(static_cast<unsigned short>(0x8000));
2744
+ return _mm256_xor_si256(a, sign_mask);
2745
+ }
2746
+
2747
+ template <>
2748
+ EIGEN_STRONG_INLINE Packet16bf pconj(const Packet16bf& a) {
2749
+ return a;
2750
+ }
2751
+
2752
+ template <>
2753
+ EIGEN_STRONG_INLINE Packet16bf pabs(const Packet16bf& a) {
2754
+ const __m256i sign_mask = _mm256_set1_epi16(static_cast<numext::uint16_t>(0x8000));
2755
+ return _mm256_andnot_si256(sign_mask, a);
2756
+ }
2757
+
2758
+ template <>
2759
+ EIGEN_STRONG_INLINE Packet16bf padd<Packet16bf>(const Packet16bf& a, const Packet16bf& b) {
2760
+ return F32ToBf16(padd<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
2761
+ }
2762
+
2763
+ template <>
2764
+ EIGEN_STRONG_INLINE Packet16bf psub<Packet16bf>(const Packet16bf& a, const Packet16bf& b) {
2765
+ return F32ToBf16(psub<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
2766
+ }
2767
+
2768
+ template <>
2769
+ EIGEN_STRONG_INLINE Packet16bf pmul<Packet16bf>(const Packet16bf& a, const Packet16bf& b) {
2770
+ return F32ToBf16(pmul(Bf16ToF32(a), Bf16ToF32(b)));
2771
+ }
2772
+
2773
+ template <>
2774
+ EIGEN_STRONG_INLINE Packet16bf pmadd<Packet16bf>(const Packet16bf& a, const Packet16bf& b, const Packet16bf& c) {
2775
+ return F32ToBf16(pmadd(Bf16ToF32(a), Bf16ToF32(b), Bf16ToF32(c)));
2776
+ }
2777
+
2778
+ template <>
2779
+ EIGEN_STRONG_INLINE Packet16bf pmsub<Packet16bf>(const Packet16bf& a, const Packet16bf& b, const Packet16bf& c) {
2780
+ return F32ToBf16(pmsub(Bf16ToF32(a), Bf16ToF32(b), Bf16ToF32(c)));
2781
+ }
2782
+
2783
+ template <>
2784
+ EIGEN_STRONG_INLINE Packet16bf pnmadd<Packet16bf>(const Packet16bf& a, const Packet16bf& b, const Packet16bf& c) {
2785
+ return F32ToBf16(pnmadd(Bf16ToF32(a), Bf16ToF32(b), Bf16ToF32(c)));
2786
+ }
2787
+
2788
+ template <>
2789
+ EIGEN_STRONG_INLINE Packet16bf pnmsub<Packet16bf>(const Packet16bf& a, const Packet16bf& b, const Packet16bf& c) {
2790
+ return F32ToBf16(pnmsub(Bf16ToF32(a), Bf16ToF32(b), Bf16ToF32(c)));
2791
+ }
2792
+
2793
+ template <>
2794
+ EIGEN_STRONG_INLINE Packet16bf pdiv<Packet16bf>(const Packet16bf& a, const Packet16bf& b) {
2795
+ return F32ToBf16(pdiv<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
2796
+ }
2797
+
2798
+ template <>
2799
+ EIGEN_STRONG_INLINE Packet16bf pmin<Packet16bf>(const Packet16bf& a, const Packet16bf& b) {
2800
+ return F32ToBf16(pmin<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
2801
+ }
2802
+
2803
+ template <>
2804
+ EIGEN_STRONG_INLINE Packet16bf pmax<Packet16bf>(const Packet16bf& a, const Packet16bf& b) {
2805
+ return F32ToBf16(pmax<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
2806
+ }
2807
+
2808
+ template <>
2809
+ EIGEN_STRONG_INLINE Packet16bf plset<Packet16bf>(const bfloat16& a) {
2810
+ return F32ToBf16(plset<Packet16f>(static_cast<float>(a)));
2811
+ }
2812
+
2813
+ template <>
2814
+ EIGEN_STRONG_INLINE Packet8bf predux_half_dowto4<Packet16bf>(const Packet16bf& a) {
2815
+ Packet8bf lane0 = _mm256_extractf128_si256(a, 0);
2816
+ Packet8bf lane1 = _mm256_extractf128_si256(a, 1);
2817
+ return padd<Packet8bf>(lane0, lane1);
2818
+ }
2819
+
2820
+ template <>
2821
+ EIGEN_STRONG_INLINE Packet16bf preverse(const Packet16bf& a) {
2822
+ __m256i m = _mm256_setr_epi8(14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1, 14, 15, 12, 13, 10, 11, 8, 9, 6, 7,
2823
+ 4, 5, 2, 3, 0, 1);
2824
+
2825
+ Packet16bf res;
2826
+ // Swap hi and lo first because shuffle is in 128-bit lanes.
2827
+ res = _mm256_permute2x128_si256(a, a, 1);
2828
+ // Shuffle 8-bit values in src within 2*128-bit lanes.
2829
+ return _mm256_shuffle_epi8(res, m);
2830
+ }
2831
+
2832
+ template <>
2833
+ EIGEN_STRONG_INLINE Packet16bf pgather<bfloat16, Packet16bf>(const bfloat16* from, Index stride) {
2834
+ return _mm256_set_epi16(
2835
+ from[15 * stride].value, from[14 * stride].value, from[13 * stride].value, from[12 * stride].value,
2836
+ from[11 * stride].value, from[10 * stride].value, from[9 * stride].value, from[8 * stride].value,
2837
+ from[7 * stride].value, from[6 * stride].value, from[5 * stride].value, from[4 * stride].value,
2838
+ from[3 * stride].value, from[2 * stride].value, from[1 * stride].value, from[0 * stride].value);
2839
+ }
2840
+
2841
+ template <>
2842
+ EIGEN_STRONG_INLINE void pscatter<bfloat16, Packet16bf>(bfloat16* to, const Packet16bf& from, Index stride) {
2843
+ EIGEN_ALIGN64 bfloat16 aux[16];
2844
+ pstore(aux, from);
2845
+ to[stride * 0] = aux[0];
2846
+ to[stride * 1] = aux[1];
2847
+ to[stride * 2] = aux[2];
2848
+ to[stride * 3] = aux[3];
2849
+ to[stride * 4] = aux[4];
2850
+ to[stride * 5] = aux[5];
2851
+ to[stride * 6] = aux[6];
2852
+ to[stride * 7] = aux[7];
2853
+ to[stride * 8] = aux[8];
2854
+ to[stride * 9] = aux[9];
2855
+ to[stride * 10] = aux[10];
2856
+ to[stride * 11] = aux[11];
2857
+ to[stride * 12] = aux[12];
2858
+ to[stride * 13] = aux[13];
2859
+ to[stride * 14] = aux[14];
2860
+ to[stride * 15] = aux[15];
2861
+ }
2862
+
2863
+ EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf, 16>& kernel) {
2864
+ __m256i a = kernel.packet[0];
2865
+ __m256i b = kernel.packet[1];
2866
+ __m256i c = kernel.packet[2];
2867
+ __m256i d = kernel.packet[3];
2868
+ __m256i e = kernel.packet[4];
2869
+ __m256i f = kernel.packet[5];
2870
+ __m256i g = kernel.packet[6];
2871
+ __m256i h = kernel.packet[7];
2872
+ __m256i i = kernel.packet[8];
2873
+ __m256i j = kernel.packet[9];
2874
+ __m256i k = kernel.packet[10];
2875
+ __m256i l = kernel.packet[11];
2876
+ __m256i m = kernel.packet[12];
2877
+ __m256i n = kernel.packet[13];
2878
+ __m256i o = kernel.packet[14];
2879
+ __m256i p = kernel.packet[15];
2880
+
2881
+ __m256i ab_07 = _mm256_unpacklo_epi16(a, b);
2882
+ __m256i cd_07 = _mm256_unpacklo_epi16(c, d);
2883
+ __m256i ef_07 = _mm256_unpacklo_epi16(e, f);
2884
+ __m256i gh_07 = _mm256_unpacklo_epi16(g, h);
2885
+ __m256i ij_07 = _mm256_unpacklo_epi16(i, j);
2886
+ __m256i kl_07 = _mm256_unpacklo_epi16(k, l);
2887
+ __m256i mn_07 = _mm256_unpacklo_epi16(m, n);
2888
+ __m256i op_07 = _mm256_unpacklo_epi16(o, p);
2889
+
2890
+ __m256i ab_8f = _mm256_unpackhi_epi16(a, b);
2891
+ __m256i cd_8f = _mm256_unpackhi_epi16(c, d);
2892
+ __m256i ef_8f = _mm256_unpackhi_epi16(e, f);
2893
+ __m256i gh_8f = _mm256_unpackhi_epi16(g, h);
2894
+ __m256i ij_8f = _mm256_unpackhi_epi16(i, j);
2895
+ __m256i kl_8f = _mm256_unpackhi_epi16(k, l);
2896
+ __m256i mn_8f = _mm256_unpackhi_epi16(m, n);
2897
+ __m256i op_8f = _mm256_unpackhi_epi16(o, p);
2898
+
2899
+ __m256i abcd_03 = _mm256_unpacklo_epi32(ab_07, cd_07);
2900
+ __m256i abcd_47 = _mm256_unpackhi_epi32(ab_07, cd_07);
2901
+ __m256i efgh_03 = _mm256_unpacklo_epi32(ef_07, gh_07);
2902
+ __m256i efgh_47 = _mm256_unpackhi_epi32(ef_07, gh_07);
2903
+ __m256i ijkl_03 = _mm256_unpacklo_epi32(ij_07, kl_07);
2904
+ __m256i ijkl_47 = _mm256_unpackhi_epi32(ij_07, kl_07);
2905
+ __m256i mnop_03 = _mm256_unpacklo_epi32(mn_07, op_07);
2906
+ __m256i mnop_47 = _mm256_unpackhi_epi32(mn_07, op_07);
2907
+
2908
+ __m256i abcd_8b = _mm256_unpacklo_epi32(ab_8f, cd_8f);
2909
+ __m256i abcd_cf = _mm256_unpackhi_epi32(ab_8f, cd_8f);
2910
+ __m256i efgh_8b = _mm256_unpacklo_epi32(ef_8f, gh_8f);
2911
+ __m256i efgh_cf = _mm256_unpackhi_epi32(ef_8f, gh_8f);
2912
+ __m256i ijkl_8b = _mm256_unpacklo_epi32(ij_8f, kl_8f);
2913
+ __m256i ijkl_cf = _mm256_unpackhi_epi32(ij_8f, kl_8f);
2914
+ __m256i mnop_8b = _mm256_unpacklo_epi32(mn_8f, op_8f);
2915
+ __m256i mnop_cf = _mm256_unpackhi_epi32(mn_8f, op_8f);
2916
+
2917
+ __m256i abcdefgh_01 = _mm256_unpacklo_epi64(abcd_03, efgh_03);
2918
+ __m256i abcdefgh_23 = _mm256_unpackhi_epi64(abcd_03, efgh_03);
2919
+ __m256i ijklmnop_01 = _mm256_unpacklo_epi64(ijkl_03, mnop_03);
2920
+ __m256i ijklmnop_23 = _mm256_unpackhi_epi64(ijkl_03, mnop_03);
2921
+ __m256i abcdefgh_45 = _mm256_unpacklo_epi64(abcd_47, efgh_47);
2922
+ __m256i abcdefgh_67 = _mm256_unpackhi_epi64(abcd_47, efgh_47);
2923
+ __m256i ijklmnop_45 = _mm256_unpacklo_epi64(ijkl_47, mnop_47);
2924
+ __m256i ijklmnop_67 = _mm256_unpackhi_epi64(ijkl_47, mnop_47);
2925
+ __m256i abcdefgh_89 = _mm256_unpacklo_epi64(abcd_8b, efgh_8b);
2926
+ __m256i abcdefgh_ab = _mm256_unpackhi_epi64(abcd_8b, efgh_8b);
2927
+ __m256i ijklmnop_89 = _mm256_unpacklo_epi64(ijkl_8b, mnop_8b);
2928
+ __m256i ijklmnop_ab = _mm256_unpackhi_epi64(ijkl_8b, mnop_8b);
2929
+ __m256i abcdefgh_cd = _mm256_unpacklo_epi64(abcd_cf, efgh_cf);
2930
+ __m256i abcdefgh_ef = _mm256_unpackhi_epi64(abcd_cf, efgh_cf);
2931
+ __m256i ijklmnop_cd = _mm256_unpacklo_epi64(ijkl_cf, mnop_cf);
2932
+ __m256i ijklmnop_ef = _mm256_unpackhi_epi64(ijkl_cf, mnop_cf);
2933
+
2934
+ // NOTE: no unpacklo/hi instr in this case, so using permute instr.
2935
+ kernel.packet[0] = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x20);
2936
+ kernel.packet[1] = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x20);
2937
+ kernel.packet[2] = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x20);
2938
+ kernel.packet[3] = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x20);
2939
+ kernel.packet[4] = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x20);
2940
+ kernel.packet[5] = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x20);
2941
+ kernel.packet[6] = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x20);
2942
+ kernel.packet[7] = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x20);
2943
+ kernel.packet[8] = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x31);
2944
+ kernel.packet[9] = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x31);
2945
+ kernel.packet[10] = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x31);
2946
+ kernel.packet[11] = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x31);
2947
+ kernel.packet[12] = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x31);
2948
+ kernel.packet[13] = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x31);
2949
+ kernel.packet[14] = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x31);
2950
+ kernel.packet[15] = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x31);
2951
+ }
2952
+
2953
+ EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf, 4>& kernel) {
2954
+ __m256i a = kernel.packet[0];
2955
+ __m256i b = kernel.packet[1];
2956
+ __m256i c = kernel.packet[2];
2957
+ __m256i d = kernel.packet[3];
2958
+
2959
+ __m256i ab_07 = _mm256_unpacklo_epi16(a, b);
2960
+ __m256i cd_07 = _mm256_unpacklo_epi16(c, d);
2961
+ __m256i ab_8f = _mm256_unpackhi_epi16(a, b);
2962
+ __m256i cd_8f = _mm256_unpackhi_epi16(c, d);
2963
+
2964
+ __m256i abcd_03 = _mm256_unpacklo_epi32(ab_07, cd_07);
2965
+ __m256i abcd_47 = _mm256_unpackhi_epi32(ab_07, cd_07);
2966
+ __m256i abcd_8b = _mm256_unpacklo_epi32(ab_8f, cd_8f);
2967
+ __m256i abcd_cf = _mm256_unpackhi_epi32(ab_8f, cd_8f);
2968
+
2969
+ // NOTE: no unpacklo/hi instr in this case, so using permute instr.
2970
+ kernel.packet[0] = _mm256_permute2x128_si256(abcd_03, abcd_47, 0x20);
2971
+ kernel.packet[1] = _mm256_permute2x128_si256(abcd_8b, abcd_cf, 0x20);
2972
+ kernel.packet[2] = _mm256_permute2x128_si256(abcd_03, abcd_47, 0x31);
2973
+ kernel.packet[3] = _mm256_permute2x128_si256(abcd_8b, abcd_cf, 0x31);
2974
+ }
2975
+
2976
+ // Minimal implementation of 16-bit int packets for use in pfrexp, pldexp.
2977
+
2978
+ template <>
2979
+ EIGEN_STRONG_INLINE Packet32s pset1<Packet32s>(const numext::int16_t& x) {
2980
+ return _mm512_set1_epi16(x);
2981
+ }
2982
+
2983
+ template <>
2984
+ EIGEN_STRONG_INLINE Packet16s pset1<Packet16s>(const numext::int16_t& x) {
2985
+ return _mm256_set1_epi16(x);
2986
+ }
2987
+
2988
+ template <>
2989
+ EIGEN_STRONG_INLINE Packet8s pset1<Packet8s>(const numext::int16_t& x) {
2990
+ return _mm_set1_epi16(x);
2991
+ }
2992
+
2993
+ template <>
2994
+ EIGEN_STRONG_INLINE void pstore<numext::int16_t, Packet32s>(numext::int16_t* out, const Packet32s& x) {
2995
+ EIGEN_DEBUG_ALIGNED_STORE
2996
+ _mm512_store_epi32(out, x);
2997
+ }
2998
+
2999
+ template <>
3000
+ EIGEN_STRONG_INLINE void pstore<numext::int16_t, Packet16s>(numext::int16_t* out, const Packet16s& x) {
3001
+ EIGEN_DEBUG_ALIGNED_STORE
3002
+ #if defined(EIGEN_VECTORIZE_AVX512F) && defined(EIGEN_VECTORIZE_AVX512VL)
3003
+ _mm256_store_epi32(out, x);
3004
+ #else
3005
+ _mm256_store_si256(reinterpret_cast<__m256i*>(out), x);
3006
+ #endif
3007
+ }
3008
+
3009
+ template <>
3010
+ EIGEN_STRONG_INLINE void pstore<numext::int16_t, Packet8s>(numext::int16_t* out, const Packet8s& x) {
3011
+ EIGEN_DEBUG_ALIGNED_STORE
3012
+ #if defined(EIGEN_VECTORIZE_AVX512F) && defined(EIGEN_VECTORIZE_AVX512VL)
3013
+ _mm256_store_epi32(out, x);
3014
+ #else
3015
+ _mm_store_si128(reinterpret_cast<__m128i*>(out), x);
3016
+ #endif
3017
+ }
3018
+
3019
+ template <>
3020
+ EIGEN_STRONG_INLINE void pstoreu<numext::int16_t, Packet32s>(numext::int16_t* out, const Packet32s& x) {
3021
+ EIGEN_DEBUG_UNALIGNED_STORE
3022
+ _mm512_storeu_epi32(out, x);
3023
+ }
3024
+
3025
+ template <>
3026
+ EIGEN_STRONG_INLINE void pstoreu<numext::int16_t, Packet16s>(numext::int16_t* out, const Packet16s& x) {
3027
+ EIGEN_DEBUG_UNALIGNED_STORE
3028
+ _mm256_storeu_epi32(out, x);
3029
+ }
3030
+
3031
+ template <>
3032
+ EIGEN_STRONG_INLINE void pstoreu<numext::int16_t, Packet8s>(numext::int16_t* out, const Packet8s& x) {
3033
+ EIGEN_DEBUG_UNALIGNED_STORE
3034
+ _mm_storeu_epi32(out, x);
3035
+ }
3036
+
3037
+ template <>
3038
+ EIGEN_STRONG_INLINE Packet32s padd(const Packet32s& a, const Packet32s& b) {
3039
+ return _mm512_add_epi16(a, b);
3040
+ }
3041
+
3042
+ template <>
3043
+ EIGEN_STRONG_INLINE Packet16s padd(const Packet16s& a, const Packet16s& b) {
3044
+ return _mm256_add_epi16(a, b);
3045
+ }
3046
+
3047
+ template <>
3048
+ EIGEN_STRONG_INLINE Packet8s padd(const Packet8s& a, const Packet8s& b) {
3049
+ return _mm_add_epi16(a, b);
3050
+ }
3051
+
3052
+ template <>
3053
+ EIGEN_STRONG_INLINE Packet32s psub(const Packet32s& a, const Packet32s& b) {
3054
+ return _mm512_sub_epi16(a, b);
3055
+ }
3056
+
3057
+ template <>
3058
+ EIGEN_STRONG_INLINE Packet16s psub(const Packet16s& a, const Packet16s& b) {
3059
+ return _mm256_sub_epi16(a, b);
3060
+ }
3061
+
3062
+ template <>
3063
+ EIGEN_STRONG_INLINE Packet8s psub(const Packet8s& a, const Packet8s& b) {
3064
+ return _mm_sub_epi16(a, b);
3065
+ }
3066
+
3067
+ template <>
3068
+ EIGEN_STRONG_INLINE Packet32s pmul(const Packet32s& a, const Packet32s& b) {
3069
+ return _mm512_mullo_epi16(a, b);
3070
+ }
3071
+
3072
+ template <>
3073
+ EIGEN_STRONG_INLINE Packet16s pmul(const Packet16s& a, const Packet16s& b) {
3074
+ return _mm256_mullo_epi16(a, b);
3075
+ }
3076
+
3077
+ template <>
3078
+ EIGEN_STRONG_INLINE Packet8s pmul(const Packet8s& a, const Packet8s& b) {
3079
+ return _mm_mullo_epi16(a, b);
3080
+ }
3081
+
3082
+ template <>
3083
+ EIGEN_STRONG_INLINE Packet32s pnegate(const Packet32s& a) {
3084
+ return _mm512_sub_epi16(_mm512_setzero_si512(), a);
3085
+ }
3086
+
3087
+ template <>
3088
+ EIGEN_STRONG_INLINE Packet16s pnegate(const Packet16s& a) {
3089
+ return _mm256_sub_epi16(_mm256_setzero_si256(), a);
3090
+ }
3091
+
3092
+ template <>
3093
+ EIGEN_STRONG_INLINE Packet8s pnegate(const Packet8s& a) {
3094
+ return _mm_sub_epi16(_mm_setzero_si128(), a);
3095
+ }
3096
+
3097
+ template <int N>
3098
+ EIGEN_STRONG_INLINE Packet32s parithmetic_shift_right(Packet32s a) {
3099
+ return _mm512_srai_epi16(a, N);
3100
+ }
3101
+
3102
+ template <int N>
3103
+ EIGEN_STRONG_INLINE Packet16s parithmetic_shift_right(Packet16s a) {
3104
+ return _mm256_srai_epi16(a, N);
3105
+ }
3106
+
3107
+ template <int N>
3108
+ EIGEN_STRONG_INLINE Packet8s parithmetic_shift_right(Packet8s a) {
3109
+ return _mm_srai_epi16(a, N);
3110
+ }
3111
+
3112
+ template <int N>
3113
+ EIGEN_STRONG_INLINE Packet32s plogical_shift_left(Packet32s a) {
3114
+ return _mm512_slli_epi16(a, N);
3115
+ }
3116
+
3117
+ template <int N>
3118
+ EIGEN_STRONG_INLINE Packet16s plogical_shift_left(Packet16s a) {
3119
+ return _mm256_slli_epi16(a, N);
3120
+ }
3121
+
3122
+ template <int N>
3123
+ EIGEN_STRONG_INLINE Packet8s plogical_shift_left(Packet8s a) {
3124
+ return _mm_slli_epi16(a, N);
3125
+ }
3126
+
3127
+ template <int N>
3128
+ EIGEN_STRONG_INLINE Packet32s plogical_shift_right(Packet32s a) {
3129
+ return _mm512_srli_epi16(a, N);
3130
+ }
3131
+
3132
+ template <int N>
3133
+ EIGEN_STRONG_INLINE Packet16s plogical_shift_right(Packet16s a) {
3134
+ return _mm256_srli_epi16(a, N);
3135
+ }
3136
+
3137
+ template <int N>
3138
+ EIGEN_STRONG_INLINE Packet8s plogical_shift_right(Packet8s a) {
3139
+ return _mm_srli_epi16(a, N);
3140
+ }
1300
3141
 
1301
- } // end namespace internal
3142
+ } // end namespace internal
1302
3143
 
1303
- } // end namespace Eigen
3144
+ } // end namespace Eigen
1304
3145
 
1305
- #endif // EIGEN_PACKET_MATH_AVX512_H
3146
+ #endif // EIGEN_PACKET_MATH_AVX512_H