tomoto 0.2.2 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (369) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +10 -0
  3. data/README.md +8 -10
  4. data/ext/tomoto/ct.cpp +11 -11
  5. data/ext/tomoto/dmr.cpp +14 -13
  6. data/ext/tomoto/dt.cpp +14 -14
  7. data/ext/tomoto/extconf.rb +7 -5
  8. data/ext/tomoto/gdmr.cpp +7 -7
  9. data/ext/tomoto/hdp.cpp +9 -9
  10. data/ext/tomoto/hlda.cpp +13 -13
  11. data/ext/tomoto/hpa.cpp +5 -5
  12. data/ext/tomoto/lda.cpp +42 -39
  13. data/ext/tomoto/llda.cpp +6 -6
  14. data/ext/tomoto/mglda.cpp +15 -15
  15. data/ext/tomoto/pa.cpp +6 -6
  16. data/ext/tomoto/plda.cpp +6 -6
  17. data/ext/tomoto/slda.cpp +8 -8
  18. data/ext/tomoto/{ext.cpp → tomoto.cpp} +8 -8
  19. data/ext/tomoto/utils.h +16 -70
  20. data/lib/tomoto/version.rb +1 -1
  21. data/lib/tomoto.rb +5 -1
  22. data/vendor/EigenRand/EigenRand/Core.h +10 -10
  23. data/vendor/EigenRand/EigenRand/Dists/Basic.h +208 -9
  24. data/vendor/EigenRand/EigenRand/Dists/Discrete.h +52 -31
  25. data/vendor/EigenRand/EigenRand/Dists/GammaPoisson.h +9 -8
  26. data/vendor/EigenRand/EigenRand/Dists/NormalExp.h +28 -21
  27. data/vendor/EigenRand/EigenRand/EigenRand +11 -6
  28. data/vendor/EigenRand/EigenRand/Macro.h +13 -7
  29. data/vendor/EigenRand/EigenRand/MorePacketMath.h +348 -740
  30. data/vendor/EigenRand/EigenRand/MvDists/Multinomial.h +5 -3
  31. data/vendor/EigenRand/EigenRand/MvDists/MvNormal.h +9 -3
  32. data/vendor/EigenRand/EigenRand/PacketFilter.h +11 -253
  33. data/vendor/EigenRand/EigenRand/PacketRandomEngine.h +21 -47
  34. data/vendor/EigenRand/EigenRand/RandUtils.h +50 -344
  35. data/vendor/EigenRand/EigenRand/arch/AVX/MorePacketMath.h +619 -0
  36. data/vendor/EigenRand/EigenRand/arch/AVX/PacketFilter.h +149 -0
  37. data/vendor/EigenRand/EigenRand/arch/AVX/RandUtils.h +228 -0
  38. data/vendor/EigenRand/EigenRand/arch/NEON/MorePacketMath.h +473 -0
  39. data/vendor/EigenRand/EigenRand/arch/NEON/PacketFilter.h +142 -0
  40. data/vendor/EigenRand/EigenRand/arch/NEON/RandUtils.h +126 -0
  41. data/vendor/EigenRand/EigenRand/arch/SSE/MorePacketMath.h +501 -0
  42. data/vendor/EigenRand/EigenRand/arch/SSE/PacketFilter.h +133 -0
  43. data/vendor/EigenRand/EigenRand/arch/SSE/RandUtils.h +120 -0
  44. data/vendor/EigenRand/EigenRand/doc.h +24 -12
  45. data/vendor/EigenRand/README.md +57 -4
  46. data/vendor/eigen/COPYING.APACHE +203 -0
  47. data/vendor/eigen/COPYING.BSD +1 -1
  48. data/vendor/eigen/COPYING.MINPACK +51 -52
  49. data/vendor/eigen/Eigen/Cholesky +0 -1
  50. data/vendor/eigen/Eigen/Core +112 -265
  51. data/vendor/eigen/Eigen/Eigenvalues +2 -3
  52. data/vendor/eigen/Eigen/Geometry +5 -8
  53. data/vendor/eigen/Eigen/Householder +0 -1
  54. data/vendor/eigen/Eigen/Jacobi +0 -1
  55. data/vendor/eigen/Eigen/KLUSupport +41 -0
  56. data/vendor/eigen/Eigen/LU +2 -5
  57. data/vendor/eigen/Eigen/OrderingMethods +0 -3
  58. data/vendor/eigen/Eigen/PaStiXSupport +1 -0
  59. data/vendor/eigen/Eigen/PardisoSupport +0 -0
  60. data/vendor/eigen/Eigen/QR +2 -3
  61. data/vendor/eigen/Eigen/QtAlignedMalloc +0 -1
  62. data/vendor/eigen/Eigen/SVD +0 -1
  63. data/vendor/eigen/Eigen/Sparse +0 -2
  64. data/vendor/eigen/Eigen/SparseCholesky +0 -8
  65. data/vendor/eigen/Eigen/SparseLU +4 -0
  66. data/vendor/eigen/Eigen/SparseQR +0 -1
  67. data/vendor/eigen/Eigen/src/Cholesky/LDLT.h +42 -27
  68. data/vendor/eigen/Eigen/src/Cholesky/LLT.h +39 -23
  69. data/vendor/eigen/Eigen/src/CholmodSupport/CholmodSupport.h +90 -47
  70. data/vendor/eigen/Eigen/src/Core/ArithmeticSequence.h +413 -0
  71. data/vendor/eigen/Eigen/src/Core/Array.h +99 -11
  72. data/vendor/eigen/Eigen/src/Core/ArrayBase.h +3 -3
  73. data/vendor/eigen/Eigen/src/Core/ArrayWrapper.h +21 -21
  74. data/vendor/eigen/Eigen/src/Core/Assign.h +1 -1
  75. data/vendor/eigen/Eigen/src/Core/AssignEvaluator.h +125 -50
  76. data/vendor/eigen/Eigen/src/Core/Assign_MKL.h +10 -10
  77. data/vendor/eigen/Eigen/src/Core/BandMatrix.h +16 -16
  78. data/vendor/eigen/Eigen/src/Core/Block.h +56 -60
  79. data/vendor/eigen/Eigen/src/Core/BooleanRedux.h +29 -31
  80. data/vendor/eigen/Eigen/src/Core/CommaInitializer.h +7 -3
  81. data/vendor/eigen/Eigen/src/Core/CoreEvaluators.h +325 -272
  82. data/vendor/eigen/Eigen/src/Core/CoreIterators.h +5 -0
  83. data/vendor/eigen/Eigen/src/Core/CwiseBinaryOp.h +21 -22
  84. data/vendor/eigen/Eigen/src/Core/CwiseNullaryOp.h +153 -18
  85. data/vendor/eigen/Eigen/src/Core/CwiseUnaryOp.h +6 -6
  86. data/vendor/eigen/Eigen/src/Core/CwiseUnaryView.h +14 -10
  87. data/vendor/eigen/Eigen/src/Core/DenseBase.h +132 -42
  88. data/vendor/eigen/Eigen/src/Core/DenseCoeffsBase.h +25 -21
  89. data/vendor/eigen/Eigen/src/Core/DenseStorage.h +153 -71
  90. data/vendor/eigen/Eigen/src/Core/Diagonal.h +21 -23
  91. data/vendor/eigen/Eigen/src/Core/DiagonalMatrix.h +50 -2
  92. data/vendor/eigen/Eigen/src/Core/DiagonalProduct.h +1 -1
  93. data/vendor/eigen/Eigen/src/Core/Dot.h +10 -10
  94. data/vendor/eigen/Eigen/src/Core/EigenBase.h +10 -9
  95. data/vendor/eigen/Eigen/src/Core/ForceAlignedAccess.h +8 -4
  96. data/vendor/eigen/Eigen/src/Core/Fuzzy.h +3 -3
  97. data/vendor/eigen/Eigen/src/Core/GeneralProduct.h +20 -10
  98. data/vendor/eigen/Eigen/src/Core/GenericPacketMath.h +599 -152
  99. data/vendor/eigen/Eigen/src/Core/GlobalFunctions.h +40 -33
  100. data/vendor/eigen/Eigen/src/Core/IO.h +40 -7
  101. data/vendor/eigen/Eigen/src/Core/IndexedView.h +237 -0
  102. data/vendor/eigen/Eigen/src/Core/Inverse.h +9 -10
  103. data/vendor/eigen/Eigen/src/Core/Map.h +7 -7
  104. data/vendor/eigen/Eigen/src/Core/MapBase.h +10 -3
  105. data/vendor/eigen/Eigen/src/Core/MathFunctions.h +767 -125
  106. data/vendor/eigen/Eigen/src/Core/MathFunctionsImpl.h +118 -19
  107. data/vendor/eigen/Eigen/src/Core/Matrix.h +131 -25
  108. data/vendor/eigen/Eigen/src/Core/MatrixBase.h +21 -3
  109. data/vendor/eigen/Eigen/src/Core/NestByValue.h +25 -50
  110. data/vendor/eigen/Eigen/src/Core/NoAlias.h +4 -3
  111. data/vendor/eigen/Eigen/src/Core/NumTraits.h +107 -20
  112. data/vendor/eigen/Eigen/src/Core/PartialReduxEvaluator.h +232 -0
  113. data/vendor/eigen/Eigen/src/Core/PermutationMatrix.h +3 -31
  114. data/vendor/eigen/Eigen/src/Core/PlainObjectBase.h +152 -59
  115. data/vendor/eigen/Eigen/src/Core/Product.h +30 -25
  116. data/vendor/eigen/Eigen/src/Core/ProductEvaluators.h +192 -125
  117. data/vendor/eigen/Eigen/src/Core/Random.h +37 -1
  118. data/vendor/eigen/Eigen/src/Core/Redux.h +180 -170
  119. data/vendor/eigen/Eigen/src/Core/Ref.h +121 -23
  120. data/vendor/eigen/Eigen/src/Core/Replicate.h +8 -8
  121. data/vendor/eigen/Eigen/src/Core/Reshaped.h +454 -0
  122. data/vendor/eigen/Eigen/src/Core/ReturnByValue.h +7 -5
  123. data/vendor/eigen/Eigen/src/Core/Reverse.h +18 -12
  124. data/vendor/eigen/Eigen/src/Core/Select.h +8 -6
  125. data/vendor/eigen/Eigen/src/Core/SelfAdjointView.h +33 -20
  126. data/vendor/eigen/Eigen/src/Core/Solve.h +14 -14
  127. data/vendor/eigen/Eigen/src/Core/SolveTriangular.h +16 -16
  128. data/vendor/eigen/Eigen/src/Core/SolverBase.h +41 -3
  129. data/vendor/eigen/Eigen/src/Core/StableNorm.h +100 -70
  130. data/vendor/eigen/Eigen/src/Core/StlIterators.h +463 -0
  131. data/vendor/eigen/Eigen/src/Core/Stride.h +9 -4
  132. data/vendor/eigen/Eigen/src/Core/Swap.h +5 -4
  133. data/vendor/eigen/Eigen/src/Core/Transpose.h +88 -27
  134. data/vendor/eigen/Eigen/src/Core/Transpositions.h +26 -47
  135. data/vendor/eigen/Eigen/src/Core/TriangularMatrix.h +93 -75
  136. data/vendor/eigen/Eigen/src/Core/VectorBlock.h +5 -5
  137. data/vendor/eigen/Eigen/src/Core/VectorwiseOp.h +159 -70
  138. data/vendor/eigen/Eigen/src/Core/Visitor.h +137 -29
  139. data/vendor/eigen/Eigen/src/Core/arch/AVX/Complex.h +50 -129
  140. data/vendor/eigen/Eigen/src/Core/arch/AVX/MathFunctions.h +126 -337
  141. data/vendor/eigen/Eigen/src/Core/arch/AVX/PacketMath.h +1092 -155
  142. data/vendor/eigen/Eigen/src/Core/arch/AVX/TypeCasting.h +65 -1
  143. data/vendor/eigen/Eigen/src/Core/arch/AVX512/Complex.h +422 -0
  144. data/vendor/eigen/Eigen/src/Core/arch/AVX512/MathFunctions.h +207 -236
  145. data/vendor/eigen/Eigen/src/Core/arch/AVX512/PacketMath.h +1482 -495
  146. data/vendor/eigen/Eigen/src/Core/arch/AVX512/TypeCasting.h +89 -0
  147. data/vendor/eigen/Eigen/src/Core/arch/AltiVec/Complex.h +152 -165
  148. data/vendor/eigen/Eigen/src/Core/arch/AltiVec/MathFunctions.h +19 -251
  149. data/vendor/eigen/Eigen/src/Core/arch/AltiVec/MatrixProduct.h +2937 -0
  150. data/vendor/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h +221 -0
  151. data/vendor/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h +629 -0
  152. data/vendor/eigen/Eigen/src/Core/arch/AltiVec/PacketMath.h +2042 -392
  153. data/vendor/eigen/Eigen/src/Core/arch/CUDA/Complex.h +235 -80
  154. data/vendor/eigen/Eigen/src/Core/arch/Default/BFloat16.h +700 -0
  155. data/vendor/eigen/Eigen/src/Core/arch/Default/ConjHelper.h +102 -14
  156. data/vendor/eigen/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +1649 -0
  157. data/vendor/eigen/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h +110 -0
  158. data/vendor/eigen/Eigen/src/Core/arch/Default/Half.h +942 -0
  159. data/vendor/eigen/Eigen/src/Core/arch/Default/Settings.h +1 -1
  160. data/vendor/eigen/Eigen/src/Core/arch/Default/TypeCasting.h +120 -0
  161. data/vendor/eigen/Eigen/src/Core/arch/{CUDA → GPU}/MathFunctions.h +16 -4
  162. data/vendor/eigen/Eigen/src/Core/arch/GPU/PacketMath.h +1685 -0
  163. data/vendor/eigen/Eigen/src/Core/arch/GPU/TypeCasting.h +80 -0
  164. data/vendor/eigen/Eigen/src/Core/arch/HIP/hcc/math_constants.h +23 -0
  165. data/vendor/eigen/Eigen/src/Core/arch/MSA/Complex.h +648 -0
  166. data/vendor/eigen/Eigen/src/Core/arch/MSA/MathFunctions.h +387 -0
  167. data/vendor/eigen/Eigen/src/Core/arch/MSA/PacketMath.h +1233 -0
  168. data/vendor/eigen/Eigen/src/Core/arch/NEON/Complex.h +313 -219
  169. data/vendor/eigen/Eigen/src/Core/arch/NEON/GeneralBlockPanelKernel.h +183 -0
  170. data/vendor/eigen/Eigen/src/Core/arch/NEON/MathFunctions.h +54 -70
  171. data/vendor/eigen/Eigen/src/Core/arch/NEON/PacketMath.h +4376 -549
  172. data/vendor/eigen/Eigen/src/Core/arch/NEON/TypeCasting.h +1419 -0
  173. data/vendor/eigen/Eigen/src/Core/arch/SSE/Complex.h +59 -179
  174. data/vendor/eigen/Eigen/src/Core/arch/SSE/MathFunctions.h +65 -428
  175. data/vendor/eigen/Eigen/src/Core/arch/SSE/PacketMath.h +893 -283
  176. data/vendor/eigen/Eigen/src/Core/arch/SSE/TypeCasting.h +65 -0
  177. data/vendor/eigen/Eigen/src/Core/arch/SVE/MathFunctions.h +44 -0
  178. data/vendor/eigen/Eigen/src/Core/arch/SVE/PacketMath.h +752 -0
  179. data/vendor/eigen/Eigen/src/Core/arch/SVE/TypeCasting.h +49 -0
  180. data/vendor/eigen/Eigen/src/Core/arch/SYCL/InteropHeaders.h +232 -0
  181. data/vendor/eigen/Eigen/src/Core/arch/SYCL/MathFunctions.h +301 -0
  182. data/vendor/eigen/Eigen/src/Core/arch/SYCL/PacketMath.h +670 -0
  183. data/vendor/eigen/Eigen/src/Core/arch/SYCL/SyclMemoryModel.h +694 -0
  184. data/vendor/eigen/Eigen/src/Core/arch/SYCL/TypeCasting.h +85 -0
  185. data/vendor/eigen/Eigen/src/Core/arch/ZVector/Complex.h +212 -183
  186. data/vendor/eigen/Eigen/src/Core/arch/ZVector/MathFunctions.h +101 -5
  187. data/vendor/eigen/Eigen/src/Core/arch/ZVector/PacketMath.h +510 -395
  188. data/vendor/eigen/Eigen/src/Core/functors/AssignmentFunctors.h +11 -2
  189. data/vendor/eigen/Eigen/src/Core/functors/BinaryFunctors.h +112 -46
  190. data/vendor/eigen/Eigen/src/Core/functors/NullaryFunctors.h +31 -30
  191. data/vendor/eigen/Eigen/src/Core/functors/StlFunctors.h +32 -2
  192. data/vendor/eigen/Eigen/src/Core/functors/UnaryFunctors.h +355 -16
  193. data/vendor/eigen/Eigen/src/Core/products/GeneralBlockPanelKernel.h +1075 -586
  194. data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrix.h +49 -24
  195. data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h +41 -35
  196. data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h +6 -6
  197. data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h +4 -2
  198. data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixVector.h +382 -483
  199. data/vendor/eigen/Eigen/src/Core/products/Parallelizer.h +22 -5
  200. data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix.h +53 -30
  201. data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix_BLAS.h +16 -8
  202. data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixVector.h +8 -6
  203. data/vendor/eigen/Eigen/src/Core/products/SelfadjointProduct.h +4 -4
  204. data/vendor/eigen/Eigen/src/Core/products/SelfadjointRank2Update.h +5 -4
  205. data/vendor/eigen/Eigen/src/Core/products/TriangularMatrixMatrix.h +33 -27
  206. data/vendor/eigen/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h +14 -12
  207. data/vendor/eigen/Eigen/src/Core/products/TriangularSolverMatrix.h +36 -34
  208. data/vendor/eigen/Eigen/src/Core/products/TriangularSolverMatrix_BLAS.h +8 -4
  209. data/vendor/eigen/Eigen/src/Core/products/TriangularSolverVector.h +13 -10
  210. data/vendor/eigen/Eigen/src/Core/util/BlasUtil.h +304 -119
  211. data/vendor/eigen/Eigen/src/Core/util/ConfigureVectorization.h +512 -0
  212. data/vendor/eigen/Eigen/src/Core/util/Constants.h +25 -9
  213. data/vendor/eigen/Eigen/src/Core/util/DisableStupidWarnings.h +26 -3
  214. data/vendor/eigen/Eigen/src/Core/util/ForwardDeclarations.h +29 -9
  215. data/vendor/eigen/Eigen/src/Core/util/IndexedViewHelper.h +186 -0
  216. data/vendor/eigen/Eigen/src/Core/util/IntegralConstant.h +272 -0
  217. data/vendor/eigen/Eigen/src/Core/util/MKL_support.h +8 -1
  218. data/vendor/eigen/Eigen/src/Core/util/Macros.h +709 -246
  219. data/vendor/eigen/Eigen/src/Core/util/Memory.h +222 -52
  220. data/vendor/eigen/Eigen/src/Core/util/Meta.h +355 -77
  221. data/vendor/eigen/Eigen/src/Core/util/ReenableStupidWarnings.h +5 -1
  222. data/vendor/eigen/Eigen/src/Core/util/ReshapedHelper.h +51 -0
  223. data/vendor/eigen/Eigen/src/Core/util/StaticAssert.h +8 -5
  224. data/vendor/eigen/Eigen/src/Core/util/SymbolicIndex.h +293 -0
  225. data/vendor/eigen/Eigen/src/Core/util/XprHelper.h +65 -30
  226. data/vendor/eigen/Eigen/src/Eigenvalues/ComplexEigenSolver.h +1 -1
  227. data/vendor/eigen/Eigen/src/Eigenvalues/ComplexSchur.h +7 -4
  228. data/vendor/eigen/Eigen/src/Eigenvalues/EigenSolver.h +2 -2
  229. data/vendor/eigen/Eigen/src/Eigenvalues/GeneralizedSelfAdjointEigenSolver.h +1 -1
  230. data/vendor/eigen/Eigen/src/Eigenvalues/HessenbergDecomposition.h +2 -2
  231. data/vendor/eigen/Eigen/src/Eigenvalues/MatrixBaseEigenvalues.h +2 -2
  232. data/vendor/eigen/Eigen/src/Eigenvalues/RealQZ.h +9 -6
  233. data/vendor/eigen/Eigen/src/Eigenvalues/RealSchur.h +21 -9
  234. data/vendor/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver.h +77 -43
  235. data/vendor/eigen/Eigen/src/Eigenvalues/Tridiagonalization.h +20 -15
  236. data/vendor/eigen/Eigen/src/Geometry/AlignedBox.h +99 -5
  237. data/vendor/eigen/Eigen/src/Geometry/AngleAxis.h +4 -4
  238. data/vendor/eigen/Eigen/src/Geometry/EulerAngles.h +3 -3
  239. data/vendor/eigen/Eigen/src/Geometry/Homogeneous.h +15 -11
  240. data/vendor/eigen/Eigen/src/Geometry/Hyperplane.h +1 -1
  241. data/vendor/eigen/Eigen/src/Geometry/OrthoMethods.h +3 -2
  242. data/vendor/eigen/Eigen/src/Geometry/ParametrizedLine.h +39 -2
  243. data/vendor/eigen/Eigen/src/Geometry/Quaternion.h +70 -14
  244. data/vendor/eigen/Eigen/src/Geometry/Rotation2D.h +3 -3
  245. data/vendor/eigen/Eigen/src/Geometry/Scaling.h +23 -5
  246. data/vendor/eigen/Eigen/src/Geometry/Transform.h +88 -67
  247. data/vendor/eigen/Eigen/src/Geometry/Translation.h +6 -12
  248. data/vendor/eigen/Eigen/src/Geometry/Umeyama.h +1 -1
  249. data/vendor/eigen/Eigen/src/Geometry/arch/Geometry_SIMD.h +168 -0
  250. data/vendor/eigen/Eigen/src/Householder/BlockHouseholder.h +9 -2
  251. data/vendor/eigen/Eigen/src/Householder/Householder.h +8 -4
  252. data/vendor/eigen/Eigen/src/Householder/HouseholderSequence.h +123 -48
  253. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/BasicPreconditioners.h +15 -15
  254. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h +7 -23
  255. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/ConjugateGradient.h +5 -22
  256. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/IncompleteCholesky.h +41 -47
  257. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/IncompleteLUT.h +51 -60
  258. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h +70 -20
  259. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/LeastSquareConjugateGradient.h +2 -20
  260. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/SolveWithGuess.h +11 -9
  261. data/vendor/eigen/Eigen/src/Jacobi/Jacobi.h +31 -10
  262. data/vendor/eigen/Eigen/src/KLUSupport/KLUSupport.h +358 -0
  263. data/vendor/eigen/Eigen/src/LU/Determinant.h +35 -19
  264. data/vendor/eigen/Eigen/src/LU/FullPivLU.h +29 -43
  265. data/vendor/eigen/Eigen/src/LU/InverseImpl.h +25 -8
  266. data/vendor/eigen/Eigen/src/LU/PartialPivLU.h +71 -58
  267. data/vendor/eigen/Eigen/src/LU/arch/InverseSize4.h +351 -0
  268. data/vendor/eigen/Eigen/src/OrderingMethods/Amd.h +7 -17
  269. data/vendor/eigen/Eigen/src/OrderingMethods/Eigen_Colamd.h +297 -277
  270. data/vendor/eigen/Eigen/src/OrderingMethods/Ordering.h +6 -10
  271. data/vendor/eigen/Eigen/src/PaStiXSupport/PaStiXSupport.h +1 -1
  272. data/vendor/eigen/Eigen/src/PardisoSupport/PardisoSupport.h +9 -7
  273. data/vendor/eigen/Eigen/src/QR/ColPivHouseholderQR.h +41 -20
  274. data/vendor/eigen/Eigen/src/QR/CompleteOrthogonalDecomposition.h +100 -27
  275. data/vendor/eigen/Eigen/src/QR/FullPivHouseholderQR.h +59 -22
  276. data/vendor/eigen/Eigen/src/QR/HouseholderQR.h +48 -23
  277. data/vendor/eigen/Eigen/src/SPQRSupport/SuiteSparseQRSupport.h +25 -3
  278. data/vendor/eigen/Eigen/src/SVD/BDCSVD.h +183 -63
  279. data/vendor/eigen/Eigen/src/SVD/JacobiSVD.h +22 -14
  280. data/vendor/eigen/Eigen/src/SVD/SVDBase.h +83 -22
  281. data/vendor/eigen/Eigen/src/SVD/UpperBidiagonalization.h +3 -3
  282. data/vendor/eigen/Eigen/src/SparseCholesky/SimplicialCholesky.h +17 -9
  283. data/vendor/eigen/Eigen/src/SparseCholesky/SimplicialCholesky_impl.h +12 -37
  284. data/vendor/eigen/Eigen/src/SparseCore/AmbiVector.h +3 -2
  285. data/vendor/eigen/Eigen/src/SparseCore/CompressedStorage.h +16 -0
  286. data/vendor/eigen/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h +6 -6
  287. data/vendor/eigen/Eigen/src/SparseCore/SparseAssign.h +81 -27
  288. data/vendor/eigen/Eigen/src/SparseCore/SparseBlock.h +25 -57
  289. data/vendor/eigen/Eigen/src/SparseCore/SparseCompressedBase.h +40 -11
  290. data/vendor/eigen/Eigen/src/SparseCore/SparseCwiseBinaryOp.h +11 -15
  291. data/vendor/eigen/Eigen/src/SparseCore/SparseCwiseUnaryOp.h +4 -2
  292. data/vendor/eigen/Eigen/src/SparseCore/SparseDenseProduct.h +30 -8
  293. data/vendor/eigen/Eigen/src/SparseCore/SparseMatrix.h +126 -11
  294. data/vendor/eigen/Eigen/src/SparseCore/SparseMatrixBase.h +5 -12
  295. data/vendor/eigen/Eigen/src/SparseCore/SparseProduct.h +13 -1
  296. data/vendor/eigen/Eigen/src/SparseCore/SparseRef.h +7 -7
  297. data/vendor/eigen/Eigen/src/SparseCore/SparseSelfAdjointView.h +5 -2
  298. data/vendor/eigen/Eigen/src/SparseCore/SparseUtil.h +8 -0
  299. data/vendor/eigen/Eigen/src/SparseCore/SparseVector.h +1 -1
  300. data/vendor/eigen/Eigen/src/SparseCore/SparseView.h +1 -0
  301. data/vendor/eigen/Eigen/src/SparseLU/SparseLU.h +162 -12
  302. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_Memory.h +1 -1
  303. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_SupernodalMatrix.h +76 -2
  304. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_column_dfs.h +2 -2
  305. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_gemm_kernel.h +1 -1
  306. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_panel_bmod.h +1 -1
  307. data/vendor/eigen/Eigen/src/SparseQR/SparseQR.h +19 -6
  308. data/vendor/eigen/Eigen/src/StlSupport/StdDeque.h +2 -12
  309. data/vendor/eigen/Eigen/src/StlSupport/StdList.h +2 -2
  310. data/vendor/eigen/Eigen/src/StlSupport/StdVector.h +2 -2
  311. data/vendor/eigen/Eigen/src/SuperLUSupport/SuperLUSupport.h +6 -8
  312. data/vendor/eigen/Eigen/src/UmfPackSupport/UmfPackSupport.h +175 -39
  313. data/vendor/eigen/Eigen/src/misc/lapacke.h +5 -4
  314. data/vendor/eigen/Eigen/src/plugins/ArrayCwiseBinaryOps.h +28 -2
  315. data/vendor/eigen/Eigen/src/plugins/ArrayCwiseUnaryOps.h +155 -11
  316. data/vendor/eigen/Eigen/src/plugins/BlockMethods.h +626 -242
  317. data/vendor/eigen/Eigen/src/plugins/CommonCwiseUnaryOps.h +14 -0
  318. data/vendor/eigen/Eigen/src/plugins/IndexedViewMethods.h +262 -0
  319. data/vendor/eigen/Eigen/src/plugins/MatrixCwiseBinaryOps.h +4 -4
  320. data/vendor/eigen/Eigen/src/plugins/MatrixCwiseUnaryOps.h +10 -0
  321. data/vendor/eigen/Eigen/src/plugins/ReshapedMethods.h +149 -0
  322. data/vendor/eigen/README.md +2 -0
  323. data/vendor/eigen/bench/btl/README +1 -1
  324. data/vendor/eigen/bench/tensors/README +6 -7
  325. data/vendor/eigen/ci/README.md +56 -0
  326. data/vendor/eigen/demos/mix_eigen_and_c/README +1 -1
  327. data/vendor/eigen/unsupported/Eigen/CXX11/src/Tensor/README.md +213 -158
  328. data/vendor/eigen/unsupported/README.txt +1 -1
  329. data/vendor/tomotopy/README.kr.rst +78 -0
  330. data/vendor/tomotopy/README.rst +75 -0
  331. data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +2 -2
  332. data/vendor/tomotopy/src/Labeling/Phraser.hpp +4 -4
  333. data/vendor/tomotopy/src/TopicModel/CTModel.hpp +7 -3
  334. data/vendor/tomotopy/src/TopicModel/DMRModel.hpp +7 -3
  335. data/vendor/tomotopy/src/TopicModel/DTModel.hpp +6 -3
  336. data/vendor/tomotopy/src/TopicModel/GDMRModel.hpp +2 -2
  337. data/vendor/tomotopy/src/TopicModel/HDP.h +1 -0
  338. data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +57 -6
  339. data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +6 -3
  340. data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +3 -2
  341. data/vendor/tomotopy/src/TopicModel/LDA.h +3 -3
  342. data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +5 -5
  343. data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +50 -19
  344. data/vendor/tomotopy/src/TopicModel/LLDAModel.hpp +6 -2
  345. data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +3 -2
  346. data/vendor/tomotopy/src/TopicModel/PAModel.hpp +1 -1
  347. data/vendor/tomotopy/src/TopicModel/PLDAModel.hpp +6 -2
  348. data/vendor/tomotopy/src/TopicModel/PT.h +3 -1
  349. data/vendor/tomotopy/src/TopicModel/PTModel.hpp +36 -3
  350. data/vendor/tomotopy/src/TopicModel/SLDAModel.hpp +6 -3
  351. data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +55 -26
  352. data/vendor/tomotopy/src/Utils/AliasMethod.hpp +5 -4
  353. data/vendor/tomotopy/src/Utils/Dictionary.h +2 -2
  354. data/vendor/tomotopy/src/Utils/EigenAddonOps.hpp +36 -1
  355. data/vendor/tomotopy/src/Utils/MultiNormalDistribution.hpp +1 -1
  356. data/vendor/tomotopy/src/Utils/TruncMultiNormal.hpp +1 -1
  357. data/vendor/tomotopy/src/Utils/exception.h +6 -0
  358. data/vendor/tomotopy/src/Utils/math.h +2 -2
  359. data/vendor/tomotopy/src/Utils/sample.hpp +14 -12
  360. data/vendor/tomotopy/src/Utils/serializer.hpp +30 -5
  361. data/vendor/tomotopy/src/Utils/sse_gamma.h +0 -3
  362. metadata +64 -18
  363. data/vendor/eigen/Eigen/CMakeLists.txt +0 -19
  364. data/vendor/eigen/Eigen/src/Core/arch/CUDA/Half.h +0 -674
  365. data/vendor/eigen/Eigen/src/Core/arch/CUDA/PacketMath.h +0 -333
  366. data/vendor/eigen/Eigen/src/Core/arch/CUDA/PacketMathHalf.h +0 -1124
  367. data/vendor/eigen/Eigen/src/Core/arch/CUDA/TypeCasting.h +0 -212
  368. data/vendor/eigen/Eigen/src/Geometry/arch/Geometry_SSE.h +0 -161
  369. data/vendor/eigen/Eigen/src/LU/arch/Inverse_SSE.h +0 -338
@@ -19,10 +19,10 @@ namespace internal {
19
19
  #endif
20
20
 
21
21
  #ifndef EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS
22
- #define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS (2*sizeof(void*))
22
+ #define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 32
23
23
  #endif
24
24
 
25
- #ifdef __FMA__
25
+ #ifdef EIGEN_VECTORIZE_FMA
26
26
  #ifndef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
27
27
  #define EIGEN_HAS_SINGLE_INSTRUCTION_MADD
28
28
  #endif
@@ -31,6 +31,8 @@ namespace internal {
31
31
  typedef __m512 Packet16f;
32
32
  typedef __m512i Packet16i;
33
33
  typedef __m512d Packet8d;
34
+ typedef eigen_packet_wrapper<__m256i, 1> Packet16h;
35
+ typedef eigen_packet_wrapper<__m256i, 2> Packet16bf;
34
36
 
35
37
  template <>
36
38
  struct is_arithmetic<__m512> {
@@ -45,6 +47,51 @@ struct is_arithmetic<__m512d> {
45
47
  enum { value = true };
46
48
  };
47
49
 
50
+ template<> struct is_arithmetic<Packet16h> { enum { value = true }; };
51
+
52
+ template <>
53
+ struct packet_traits<half> : default_packet_traits {
54
+ typedef Packet16h type;
55
+ // There is no half-size packet for Packet16h.
56
+ typedef Packet16h half;
57
+ enum {
58
+ Vectorizable = 1,
59
+ AlignedOnScalar = 1,
60
+ size = 16,
61
+ HasHalfPacket = 1,
62
+
63
+ HasCmp = 1,
64
+ HasAdd = 1,
65
+ HasSub = 1,
66
+ HasMul = 1,
67
+ HasDiv = 1,
68
+ HasNegate = 1,
69
+ HasAbs = 1,
70
+ HasAbs2 = 0,
71
+ HasMin = 1,
72
+ HasMax = 1,
73
+ HasConj = 1,
74
+ HasSetLinear = 0,
75
+ HasLog = 1,
76
+ HasLog1p = 1,
77
+ HasExpm1 = 1,
78
+ HasExp = 1,
79
+ HasSqrt = 1,
80
+ HasRsqrt = 1,
81
+ HasSin = EIGEN_FAST_MATH,
82
+ HasCos = EIGEN_FAST_MATH,
83
+ HasTanh = EIGEN_FAST_MATH,
84
+ HasErf = EIGEN_FAST_MATH,
85
+ HasBlend = 0,
86
+ HasRound = 1,
87
+ HasFloor = 1,
88
+ HasCeil = 1,
89
+ HasRint = 1,
90
+ HasBessel = 1,
91
+ HasNdtri = 1
92
+ };
93
+ };
94
+
48
95
  template<> struct packet_traits<float> : default_packet_traits
49
96
  {
50
97
  typedef Packet16f type;
@@ -54,15 +101,32 @@ template<> struct packet_traits<float> : default_packet_traits
54
101
  AlignedOnScalar = 1,
55
102
  size = 16,
56
103
  HasHalfPacket = 1,
57
- #if EIGEN_GNUC_AT_LEAST(5, 3)
58
- #ifdef EIGEN_VECTORIZE_AVX512DQ
104
+
105
+ HasAbs = 1,
106
+ HasMin = 1,
107
+ HasMax = 1,
108
+ HasConj = 1,
109
+ HasBlend = 0,
110
+ HasSin = EIGEN_FAST_MATH,
111
+ HasCos = EIGEN_FAST_MATH,
112
+ #if EIGEN_GNUC_AT_LEAST(5, 3) || (!EIGEN_COMP_GNUC_STRICT)
59
113
  HasLog = 1,
60
- #endif
114
+ HasLog1p = 1,
115
+ HasExpm1 = 1,
116
+ HasNdtri = 1,
117
+ HasBessel = 1,
61
118
  HasExp = 1,
62
- HasSqrt = 1,
63
- HasRsqrt = 1,
119
+ HasSqrt = EIGEN_FAST_MATH,
120
+ HasRsqrt = EIGEN_FAST_MATH,
121
+ HasTanh = EIGEN_FAST_MATH,
122
+ HasErf = EIGEN_FAST_MATH,
64
123
  #endif
65
- HasDiv = 1
124
+ HasCmp = 1,
125
+ HasDiv = 1,
126
+ HasRound = 1,
127
+ HasFloor = 1,
128
+ HasCeil = 1,
129
+ HasRint = 1
66
130
  };
67
131
  };
68
132
  template<> struct packet_traits<double> : default_packet_traits
@@ -74,11 +138,18 @@ template<> struct packet_traits<double> : default_packet_traits
74
138
  AlignedOnScalar = 1,
75
139
  size = 8,
76
140
  HasHalfPacket = 1,
77
- #if EIGEN_GNUC_AT_LEAST(5, 3)
78
- HasSqrt = 1,
141
+ #if EIGEN_GNUC_AT_LEAST(5, 3) || (!EIGEN_COMP_GNUC_STRICT)
142
+ HasLog = 1,
143
+ HasExp = 1,
144
+ HasSqrt = EIGEN_FAST_MATH,
79
145
  HasRsqrt = EIGEN_FAST_MATH,
80
146
  #endif
81
- HasDiv = 1
147
+ HasCmp = 1,
148
+ HasDiv = 1,
149
+ HasRound = 1,
150
+ HasFloor = 1,
151
+ HasCeil = 1,
152
+ HasRint = 1
82
153
  };
83
154
  };
84
155
 
@@ -98,19 +169,28 @@ template <>
98
169
  struct unpacket_traits<Packet16f> {
99
170
  typedef float type;
100
171
  typedef Packet8f half;
101
- enum { size = 16, alignment=Aligned64 };
172
+ typedef Packet16i integer_packet;
173
+ typedef uint16_t mask_t;
174
+ enum { size = 16, alignment=Aligned64, vectorizable=true, masked_load_available=true, masked_store_available=true };
102
175
  };
103
176
  template <>
104
177
  struct unpacket_traits<Packet8d> {
105
178
  typedef double type;
106
179
  typedef Packet4d half;
107
- enum { size = 8, alignment=Aligned64 };
180
+ enum { size = 8, alignment=Aligned64, vectorizable=true, masked_load_available=false, masked_store_available=false };
108
181
  };
109
182
  template <>
110
183
  struct unpacket_traits<Packet16i> {
111
184
  typedef int type;
112
185
  typedef Packet8i half;
113
- enum { size = 16, alignment=Aligned64 };
186
+ enum { size = 16, alignment=Aligned64, vectorizable=false, masked_load_available=false, masked_store_available=false };
187
+ };
188
+
189
+ template<>
190
+ struct unpacket_traits<Packet16h> {
191
+ typedef Eigen::half type;
192
+ typedef Packet8h half;
193
+ enum {size=16, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false};
114
194
  };
115
195
 
116
196
  template <>
@@ -126,13 +206,40 @@ EIGEN_STRONG_INLINE Packet16i pset1<Packet16i>(const int& from) {
126
206
  return _mm512_set1_epi32(from);
127
207
  }
128
208
 
209
+ template <>
210
+ EIGEN_STRONG_INLINE Packet16f pset1frombits<Packet16f>(unsigned int from) {
211
+ return _mm512_castsi512_ps(_mm512_set1_epi32(from));
212
+ }
213
+
214
+ template <>
215
+ EIGEN_STRONG_INLINE Packet8d pset1frombits<Packet8d>(const numext::uint64_t from) {
216
+ return _mm512_castsi512_pd(_mm512_set1_epi64(from));
217
+ }
218
+
219
+ template<> EIGEN_STRONG_INLINE Packet16f pzero(const Packet16f& /*a*/) { return _mm512_setzero_ps(); }
220
+ template<> EIGEN_STRONG_INLINE Packet8d pzero(const Packet8d& /*a*/) { return _mm512_setzero_pd(); }
221
+ template<> EIGEN_STRONG_INLINE Packet16i pzero(const Packet16i& /*a*/) { return _mm512_setzero_si512(); }
222
+
223
+ template<> EIGEN_STRONG_INLINE Packet16f peven_mask(const Packet16f& /*a*/) {
224
+ return _mm512_castsi512_ps(_mm512_set_epi32(0, -1, 0, -1, 0, -1, 0, -1,
225
+ 0, -1, 0, -1, 0, -1, 0, -1));
226
+ }
227
+ template<> EIGEN_STRONG_INLINE Packet16i peven_mask(const Packet16i& /*a*/) {
228
+ return _mm512_set_epi32(0, -1, 0, -1, 0, -1, 0, -1,
229
+ 0, -1, 0, -1, 0, -1, 0, -1);
230
+ }
231
+ template<> EIGEN_STRONG_INLINE Packet8d peven_mask(const Packet8d& /*a*/) {
232
+ return _mm512_castsi512_pd(_mm512_set_epi32(0, 0, -1, -1, 0, 0, -1, -1,
233
+ 0, 0, -1, -1, 0, 0, -1, -1));
234
+ }
235
+
129
236
  template <>
130
237
  EIGEN_STRONG_INLINE Packet16f pload1<Packet16f>(const float* from) {
131
238
  return _mm512_broadcastss_ps(_mm_load_ps1(from));
132
239
  }
133
240
  template <>
134
241
  EIGEN_STRONG_INLINE Packet8d pload1<Packet8d>(const double* from) {
135
- return _mm512_broadcastsd_pd(_mm_load_pd1(from));
242
+ return _mm512_set1_pd(*from);
136
243
  }
137
244
 
138
245
  template <>
@@ -158,6 +265,11 @@ EIGEN_STRONG_INLINE Packet8d padd<Packet8d>(const Packet8d& a,
158
265
  const Packet8d& b) {
159
266
  return _mm512_add_pd(a, b);
160
267
  }
268
+ template <>
269
+ EIGEN_STRONG_INLINE Packet16i padd<Packet16i>(const Packet16i& a,
270
+ const Packet16i& b) {
271
+ return _mm512_add_epi32(a, b);
272
+ }
161
273
 
162
274
  template <>
163
275
  EIGEN_STRONG_INLINE Packet16f psub<Packet16f>(const Packet16f& a,
@@ -169,6 +281,11 @@ EIGEN_STRONG_INLINE Packet8d psub<Packet8d>(const Packet8d& a,
169
281
  const Packet8d& b) {
170
282
  return _mm512_sub_pd(a, b);
171
283
  }
284
+ template <>
285
+ EIGEN_STRONG_INLINE Packet16i psub<Packet16i>(const Packet16i& a,
286
+ const Packet16i& b) {
287
+ return _mm512_sub_epi32(a, b);
288
+ }
172
289
 
173
290
  template <>
174
291
  EIGEN_STRONG_INLINE Packet16f pnegate(const Packet16f& a) {
@@ -202,6 +319,11 @@ EIGEN_STRONG_INLINE Packet8d pmul<Packet8d>(const Packet8d& a,
202
319
  const Packet8d& b) {
203
320
  return _mm512_mul_pd(a, b);
204
321
  }
322
+ template <>
323
+ EIGEN_STRONG_INLINE Packet16i pmul<Packet16i>(const Packet16i& a,
324
+ const Packet16i& b) {
325
+ return _mm512_mullo_epi32(a, b);
326
+ }
205
327
 
206
328
  template <>
207
329
  EIGEN_STRONG_INLINE Packet16f pdiv<Packet16f>(const Packet16f& a,
@@ -214,7 +336,7 @@ EIGEN_STRONG_INLINE Packet8d pdiv<Packet8d>(const Packet8d& a,
214
336
  return _mm512_div_pd(a, b);
215
337
  }
216
338
 
217
- #ifdef __FMA__
339
+ #ifdef EIGEN_VECTORIZE_FMA
218
340
  template <>
219
341
  EIGEN_STRONG_INLINE Packet16f pmadd(const Packet16f& a, const Packet16f& b,
220
342
  const Packet16f& c) {
@@ -227,52 +349,217 @@ EIGEN_STRONG_INLINE Packet8d pmadd(const Packet8d& a, const Packet8d& b,
227
349
  }
228
350
  #endif
229
351
 
352
+ template <>
353
+ EIGEN_DEVICE_FUNC inline Packet16f pselect(const Packet16f& mask,
354
+ const Packet16f& a,
355
+ const Packet16f& b) {
356
+ __mmask16 mask16 = _mm512_cmp_epi32_mask(
357
+ _mm512_castps_si512(mask), _mm512_setzero_epi32(), _MM_CMPINT_EQ);
358
+ return _mm512_mask_blend_ps(mask16, a, b);
359
+ }
360
+
361
+ template <>
362
+ EIGEN_DEVICE_FUNC inline Packet8d pselect(const Packet8d& mask,
363
+ const Packet8d& a,
364
+ const Packet8d& b) {
365
+ __mmask8 mask8 = _mm512_cmp_epi64_mask(_mm512_castpd_si512(mask),
366
+ _mm512_setzero_epi32(), _MM_CMPINT_EQ);
367
+ return _mm512_mask_blend_pd(mask8, a, b);
368
+ }
369
+
230
370
  template <>
231
371
  EIGEN_STRONG_INLINE Packet16f pmin<Packet16f>(const Packet16f& a,
232
372
  const Packet16f& b) {
233
- return _mm512_min_ps(a, b);
373
+ // Arguments are reversed to match NaN propagation behavior of std::min.
374
+ return _mm512_min_ps(b, a);
234
375
  }
235
376
  template <>
236
377
  EIGEN_STRONG_INLINE Packet8d pmin<Packet8d>(const Packet8d& a,
237
378
  const Packet8d& b) {
238
- return _mm512_min_pd(a, b);
379
+ // Arguments are reversed to match NaN propagation behavior of std::min.
380
+ return _mm512_min_pd(b, a);
239
381
  }
240
382
 
241
383
  template <>
242
384
  EIGEN_STRONG_INLINE Packet16f pmax<Packet16f>(const Packet16f& a,
243
385
  const Packet16f& b) {
244
- return _mm512_max_ps(a, b);
386
+ // Arguments are reversed to match NaN propagation behavior of std::max.
387
+ return _mm512_max_ps(b, a);
245
388
  }
246
389
  template <>
247
390
  EIGEN_STRONG_INLINE Packet8d pmax<Packet8d>(const Packet8d& a,
248
391
  const Packet8d& b) {
249
- return _mm512_max_pd(a, b);
392
+ // Arguments are reversed to match NaN propagation behavior of std::max.
393
+ return _mm512_max_pd(b, a);
250
394
  }
251
395
 
252
- template <>
253
- EIGEN_STRONG_INLINE Packet16f pand<Packet16f>(const Packet16f& a,
254
- const Packet16f& b) {
396
+ // Add specializations for min/max with prescribed NaN progation.
397
+ template<>
398
+ EIGEN_STRONG_INLINE Packet16f pmin<PropagateNumbers, Packet16f>(const Packet16f& a, const Packet16f& b) {
399
+ return pminmax_propagate_numbers(a, b, pmin<Packet16f>);
400
+ }
401
+ template<>
402
+ EIGEN_STRONG_INLINE Packet8d pmin<PropagateNumbers, Packet8d>(const Packet8d& a, const Packet8d& b) {
403
+ return pminmax_propagate_numbers(a, b, pmin<Packet8d>);
404
+ }
405
+ template<>
406
+ EIGEN_STRONG_INLINE Packet16f pmax<PropagateNumbers, Packet16f>(const Packet16f& a, const Packet16f& b) {
407
+ return pminmax_propagate_numbers(a, b, pmax<Packet16f>);
408
+ }
409
+ template<>
410
+ EIGEN_STRONG_INLINE Packet8d pmax<PropagateNumbers, Packet8d>(const Packet8d& a, const Packet8d& b) {
411
+ return pminmax_propagate_numbers(a, b, pmax<Packet8d>);
412
+ }
413
+ template<>
414
+ EIGEN_STRONG_INLINE Packet16f pmin<PropagateNaN, Packet16f>(const Packet16f& a, const Packet16f& b) {
415
+ return pminmax_propagate_nan(a, b, pmin<Packet16f>);
416
+ }
417
+ template<>
418
+ EIGEN_STRONG_INLINE Packet8d pmin<PropagateNaN, Packet8d>(const Packet8d& a, const Packet8d& b) {
419
+ return pminmax_propagate_nan(a, b, pmin<Packet8d>);
420
+ }
421
+ template<>
422
+ EIGEN_STRONG_INLINE Packet16f pmax<PropagateNaN, Packet16f>(const Packet16f& a, const Packet16f& b) {
423
+ return pminmax_propagate_nan(a, b, pmax<Packet16f>);
424
+ }
425
+ template<>
426
+ EIGEN_STRONG_INLINE Packet8d pmax<PropagateNaN, Packet8d>(const Packet8d& a, const Packet8d& b) {
427
+ return pminmax_propagate_nan(a, b, pmax<Packet8d>);
428
+ }
429
+
430
+
255
431
  #ifdef EIGEN_VECTORIZE_AVX512DQ
256
- return _mm512_and_ps(a, b);
432
+ template<int I_> EIGEN_STRONG_INLINE Packet8f extract256(Packet16f x) { return _mm512_extractf32x8_ps(x,I_); }
433
+ template<int I_> EIGEN_STRONG_INLINE Packet2d extract128(Packet8d x) { return _mm512_extractf64x2_pd(x,I_); }
434
+ EIGEN_STRONG_INLINE Packet16f cat256(Packet8f a, Packet8f b) { return _mm512_insertf32x8(_mm512_castps256_ps512(a),b,1); }
257
435
  #else
258
- Packet16f res = _mm512_undefined_ps();
259
- Packet4f lane0_a = _mm512_extractf32x4_ps(a, 0);
260
- Packet4f lane0_b = _mm512_extractf32x4_ps(b, 0);
261
- res = _mm512_insertf32x4(res, _mm_and_ps(lane0_a, lane0_b), 0);
436
+ // AVX512F does not define _mm512_extractf32x8_ps to extract _m256 from _m512
437
+ template<int I_> EIGEN_STRONG_INLINE Packet8f extract256(Packet16f x) {
438
+ return _mm256_castsi256_ps(_mm512_extracti64x4_epi64( _mm512_castps_si512(x),I_));
439
+ }
262
440
 
263
- Packet4f lane1_a = _mm512_extractf32x4_ps(a, 1);
264
- Packet4f lane1_b = _mm512_extractf32x4_ps(b, 1);
265
- res = _mm512_insertf32x4(res, _mm_and_ps(lane1_a, lane1_b), 1);
441
+ // AVX512F does not define _mm512_extractf64x2_pd to extract _m128 from _m512
442
+ template<int I_> EIGEN_STRONG_INLINE Packet2d extract128(Packet8d x) {
443
+ return _mm_castsi128_pd(_mm512_extracti32x4_epi32( _mm512_castpd_si512(x),I_));
444
+ }
445
+
446
+ EIGEN_STRONG_INLINE Packet16f cat256(Packet8f a, Packet8f b) {
447
+ return _mm512_castsi512_ps(_mm512_inserti64x4(_mm512_castsi256_si512(_mm256_castps_si256(a)),
448
+ _mm256_castps_si256(b),1));
449
+ }
450
+ #endif
451
+
452
+ // Helper function for bit packing snippet of low precision comparison.
453
+ // It packs the flags from 32x16 to 16x16.
454
+ EIGEN_STRONG_INLINE __m256i Pack32To16(Packet16f rf) {
455
+ // Split data into small pieces and handle with AVX instructions
456
+ // to guarantee internal order of vector.
457
+ // Operation:
458
+ // dst[15:0] := Saturate16(rf[31:0])
459
+ // dst[31:16] := Saturate16(rf[63:32])
460
+ // ...
461
+ // dst[255:240] := Saturate16(rf[255:224])
462
+ __m256i lo = _mm256_castps_si256(extract256<0>(rf));
463
+ __m256i hi = _mm256_castps_si256(extract256<1>(rf));
464
+ __m128i result_lo = _mm_packs_epi32(_mm256_extractf128_si256(lo, 0),
465
+ _mm256_extractf128_si256(lo, 1));
466
+ __m128i result_hi = _mm_packs_epi32(_mm256_extractf128_si256(hi, 0),
467
+ _mm256_extractf128_si256(hi, 1));
468
+ return _mm256_insertf128_si256(_mm256_castsi128_si256(result_lo), result_hi, 1);
469
+ }
470
+
471
+ template <>
472
+ EIGEN_STRONG_INLINE Packet16f pcmp_eq(const Packet16f& a, const Packet16f& b) {
473
+ __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_EQ_OQ);
474
+ return _mm512_castsi512_ps(
475
+ _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu));
476
+ }
477
+ template<> EIGEN_STRONG_INLINE Packet16f pcmp_le(const Packet16f& a, const Packet16f& b) {
478
+ __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_LE_OQ);
479
+ return _mm512_castsi512_ps(
480
+ _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu));
481
+ }
482
+
483
+ template<> EIGEN_STRONG_INLINE Packet16f pcmp_lt(const Packet16f& a, const Packet16f& b) {
484
+ __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_LT_OQ);
485
+ return _mm512_castsi512_ps(
486
+ _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu));
487
+ }
488
+
489
+ template<> EIGEN_STRONG_INLINE Packet16f pcmp_lt_or_nan(const Packet16f& a, const Packet16f& b) {
490
+ __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_NGE_UQ);
491
+ return _mm512_castsi512_ps(
492
+ _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu));
493
+ }
494
+
495
+ template<> EIGEN_STRONG_INLINE Packet16i pcmp_eq(const Packet16i& a, const Packet16i& b) {
496
+ __mmask16 mask = _mm512_cmp_epi32_mask(a, b, _CMP_EQ_OQ);
497
+ return _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu);
498
+ }
266
499
 
267
- Packet4f lane2_a = _mm512_extractf32x4_ps(a, 2);
268
- Packet4f lane2_b = _mm512_extractf32x4_ps(b, 2);
269
- res = _mm512_insertf32x4(res, _mm_and_ps(lane2_a, lane2_b), 2);
270
500
 
271
- Packet4f lane3_a = _mm512_extractf32x4_ps(a, 3);
272
- Packet4f lane3_b = _mm512_extractf32x4_ps(b, 3);
273
- res = _mm512_insertf32x4(res, _mm_and_ps(lane3_a, lane3_b), 3);
501
+ template <>
502
+ EIGEN_STRONG_INLINE Packet8d pcmp_eq(const Packet8d& a, const Packet8d& b) {
503
+ __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_EQ_OQ);
504
+ return _mm512_castsi512_pd(
505
+ _mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu));
506
+ }
507
+ template <>
508
+ EIGEN_STRONG_INLINE Packet8d pcmp_le(const Packet8d& a, const Packet8d& b) {
509
+ __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_LE_OQ);
510
+ return _mm512_castsi512_pd(
511
+ _mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu));
512
+ }
513
+ template <>
514
+ EIGEN_STRONG_INLINE Packet8d pcmp_lt(const Packet8d& a, const Packet8d& b) {
515
+ __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_LT_OQ);
516
+ return _mm512_castsi512_pd(
517
+ _mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu));
518
+ }
519
+ template <>
520
+ EIGEN_STRONG_INLINE Packet8d pcmp_lt_or_nan(const Packet8d& a, const Packet8d& b) {
521
+ __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_NGE_UQ);
522
+ return _mm512_castsi512_pd(
523
+ _mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu));
524
+ }
274
525
 
275
- return res;
526
+ template<> EIGEN_STRONG_INLINE Packet16f print<Packet16f>(const Packet16f& a) { return _mm512_roundscale_ps(a, _MM_FROUND_CUR_DIRECTION); }
527
+ template<> EIGEN_STRONG_INLINE Packet8d print<Packet8d>(const Packet8d& a) { return _mm512_roundscale_pd(a, _MM_FROUND_CUR_DIRECTION); }
528
+
529
+ template<> EIGEN_STRONG_INLINE Packet16f pceil<Packet16f>(const Packet16f& a) { return _mm512_roundscale_ps(a, _MM_FROUND_TO_POS_INF); }
530
+ template<> EIGEN_STRONG_INLINE Packet8d pceil<Packet8d>(const Packet8d& a) { return _mm512_roundscale_pd(a, _MM_FROUND_TO_POS_INF); }
531
+
532
+ template<> EIGEN_STRONG_INLINE Packet16f pfloor<Packet16f>(const Packet16f& a) { return _mm512_roundscale_ps(a, _MM_FROUND_TO_NEG_INF); }
533
+ template<> EIGEN_STRONG_INLINE Packet8d pfloor<Packet8d>(const Packet8d& a) { return _mm512_roundscale_pd(a, _MM_FROUND_TO_NEG_INF); }
534
+
535
+ template <>
536
+ EIGEN_STRONG_INLINE Packet16i ptrue<Packet16i>(const Packet16i& /*a*/) {
537
+ return _mm512_set1_epi32(0xffffffffu);
538
+ }
539
+
540
+ template <>
541
+ EIGEN_STRONG_INLINE Packet16f ptrue<Packet16f>(const Packet16f& a) {
542
+ return _mm512_castsi512_ps(ptrue<Packet16i>(_mm512_castps_si512(a)));
543
+ }
544
+
545
+ template <>
546
+ EIGEN_STRONG_INLINE Packet8d ptrue<Packet8d>(const Packet8d& a) {
547
+ return _mm512_castsi512_pd(ptrue<Packet16i>(_mm512_castpd_si512(a)));
548
+ }
549
+
550
+ template <>
551
+ EIGEN_STRONG_INLINE Packet16i pand<Packet16i>(const Packet16i& a,
552
+ const Packet16i& b) {
553
+ return _mm512_and_si512(a,b);
554
+ }
555
+
556
+ template <>
557
+ EIGEN_STRONG_INLINE Packet16f pand<Packet16f>(const Packet16f& a,
558
+ const Packet16f& b) {
559
+ #ifdef EIGEN_VECTORIZE_AVX512DQ
560
+ return _mm512_and_ps(a, b);
561
+ #else
562
+ return _mm512_castsi512_ps(pand(_mm512_castps_si512(a),_mm512_castps_si512(b)));
276
563
  #endif
277
564
  }
278
565
  template <>
@@ -288,35 +575,21 @@ EIGEN_STRONG_INLINE Packet8d pand<Packet8d>(const Packet8d& a,
288
575
 
289
576
  Packet4d lane1_a = _mm512_extractf64x4_pd(a, 1);
290
577
  Packet4d lane1_b = _mm512_extractf64x4_pd(b, 1);
291
- res = _mm512_insertf64x4(res, _mm256_and_pd(lane1_a, lane1_b), 1);
292
-
293
- return res;
578
+ return _mm512_insertf64x4(res, _mm256_and_pd(lane1_a, lane1_b), 1);
294
579
  #endif
295
580
  }
581
+
582
+ template <>
583
+ EIGEN_STRONG_INLINE Packet16i por<Packet16i>(const Packet16i& a, const Packet16i& b) {
584
+ return _mm512_or_si512(a, b);
585
+ }
586
+
296
587
  template <>
297
- EIGEN_STRONG_INLINE Packet16f por<Packet16f>(const Packet16f& a,
298
- const Packet16f& b) {
588
+ EIGEN_STRONG_INLINE Packet16f por<Packet16f>(const Packet16f& a, const Packet16f& b) {
299
589
  #ifdef EIGEN_VECTORIZE_AVX512DQ
300
590
  return _mm512_or_ps(a, b);
301
591
  #else
302
- Packet16f res = _mm512_undefined_ps();
303
- Packet4f lane0_a = _mm512_extractf32x4_ps(a, 0);
304
- Packet4f lane0_b = _mm512_extractf32x4_ps(b, 0);
305
- res = _mm512_insertf32x4(res, _mm_or_ps(lane0_a, lane0_b), 0);
306
-
307
- Packet4f lane1_a = _mm512_extractf32x4_ps(a, 1);
308
- Packet4f lane1_b = _mm512_extractf32x4_ps(b, 1);
309
- res = _mm512_insertf32x4(res, _mm_or_ps(lane1_a, lane1_b), 1);
310
-
311
- Packet4f lane2_a = _mm512_extractf32x4_ps(a, 2);
312
- Packet4f lane2_b = _mm512_extractf32x4_ps(b, 2);
313
- res = _mm512_insertf32x4(res, _mm_or_ps(lane2_a, lane2_b), 2);
314
-
315
- Packet4f lane3_a = _mm512_extractf32x4_ps(a, 3);
316
- Packet4f lane3_b = _mm512_extractf32x4_ps(b, 3);
317
- res = _mm512_insertf32x4(res, _mm_or_ps(lane3_a, lane3_b), 3);
318
-
319
- return res;
592
+ return _mm512_castsi512_ps(por(_mm512_castps_si512(a),_mm512_castps_si512(b)));
320
593
  #endif
321
594
  }
322
595
 
@@ -326,107 +599,80 @@ EIGEN_STRONG_INLINE Packet8d por<Packet8d>(const Packet8d& a,
326
599
  #ifdef EIGEN_VECTORIZE_AVX512DQ
327
600
  return _mm512_or_pd(a, b);
328
601
  #else
329
- Packet8d res = _mm512_undefined_pd();
330
- Packet4d lane0_a = _mm512_extractf64x4_pd(a, 0);
331
- Packet4d lane0_b = _mm512_extractf64x4_pd(b, 0);
332
- res = _mm512_insertf64x4(res, _mm256_or_pd(lane0_a, lane0_b), 0);
333
-
334
- Packet4d lane1_a = _mm512_extractf64x4_pd(a, 1);
335
- Packet4d lane1_b = _mm512_extractf64x4_pd(b, 1);
336
- res = _mm512_insertf64x4(res, _mm256_or_pd(lane1_a, lane1_b), 1);
337
-
338
- return res;
602
+ return _mm512_castsi512_pd(por(_mm512_castpd_si512(a),_mm512_castpd_si512(b)));
339
603
  #endif
340
604
  }
341
605
 
342
606
  template <>
343
- EIGEN_STRONG_INLINE Packet16f pxor<Packet16f>(const Packet16f& a,
344
- const Packet16f& b) {
607
+ EIGEN_STRONG_INLINE Packet16i pxor<Packet16i>(const Packet16i& a, const Packet16i& b) {
608
+ return _mm512_xor_si512(a, b);
609
+ }
610
+
611
+ template <>
612
+ EIGEN_STRONG_INLINE Packet16f pxor<Packet16f>(const Packet16f& a, const Packet16f& b) {
345
613
  #ifdef EIGEN_VECTORIZE_AVX512DQ
346
614
  return _mm512_xor_ps(a, b);
347
615
  #else
348
- Packet16f res = _mm512_undefined_ps();
349
- Packet4f lane0_a = _mm512_extractf32x4_ps(a, 0);
350
- Packet4f lane0_b = _mm512_extractf32x4_ps(b, 0);
351
- res = _mm512_insertf32x4(res, _mm_xor_ps(lane0_a, lane0_b), 0);
352
-
353
- Packet4f lane1_a = _mm512_extractf32x4_ps(a, 1);
354
- Packet4f lane1_b = _mm512_extractf32x4_ps(b, 1);
355
- res = _mm512_insertf32x4(res, _mm_xor_ps(lane1_a, lane1_b), 1);
356
-
357
- Packet4f lane2_a = _mm512_extractf32x4_ps(a, 2);
358
- Packet4f lane2_b = _mm512_extractf32x4_ps(b, 2);
359
- res = _mm512_insertf32x4(res, _mm_xor_ps(lane2_a, lane2_b), 2);
360
-
361
- Packet4f lane3_a = _mm512_extractf32x4_ps(a, 3);
362
- Packet4f lane3_b = _mm512_extractf32x4_ps(b, 3);
363
- res = _mm512_insertf32x4(res, _mm_xor_ps(lane3_a, lane3_b), 3);
364
-
365
- return res;
616
+ return _mm512_castsi512_ps(pxor(_mm512_castps_si512(a),_mm512_castps_si512(b)));
366
617
  #endif
367
618
  }
619
+
368
620
  template <>
369
- EIGEN_STRONG_INLINE Packet8d pxor<Packet8d>(const Packet8d& a,
370
- const Packet8d& b) {
621
+ EIGEN_STRONG_INLINE Packet8d pxor<Packet8d>(const Packet8d& a, const Packet8d& b) {
371
622
  #ifdef EIGEN_VECTORIZE_AVX512DQ
372
623
  return _mm512_xor_pd(a, b);
373
624
  #else
374
- Packet8d res = _mm512_undefined_pd();
375
- Packet4d lane0_a = _mm512_extractf64x4_pd(a, 0);
376
- Packet4d lane0_b = _mm512_extractf64x4_pd(b, 0);
377
- res = _mm512_insertf64x4(res, _mm256_xor_pd(lane0_a, lane0_b), 0);
378
-
379
- Packet4d lane1_a = _mm512_extractf64x4_pd(a, 1);
380
- Packet4d lane1_b = _mm512_extractf64x4_pd(b, 1);
381
- res = _mm512_insertf64x4(res, _mm256_xor_pd(lane1_a, lane1_b), 1);
382
-
383
- return res;
625
+ return _mm512_castsi512_pd(pxor(_mm512_castpd_si512(a),_mm512_castpd_si512(b)));
384
626
  #endif
385
627
  }
386
628
 
387
629
  template <>
388
- EIGEN_STRONG_INLINE Packet16f pandnot<Packet16f>(const Packet16f& a,
389
- const Packet16f& b) {
630
+ EIGEN_STRONG_INLINE Packet16i pandnot<Packet16i>(const Packet16i& a, const Packet16i& b) {
631
+ return _mm512_andnot_si512(b, a);
632
+ }
633
+
634
+ template <>
635
+ EIGEN_STRONG_INLINE Packet16f pandnot<Packet16f>(const Packet16f& a, const Packet16f& b) {
390
636
  #ifdef EIGEN_VECTORIZE_AVX512DQ
391
- return _mm512_andnot_ps(a, b);
637
+ return _mm512_andnot_ps(b, a);
392
638
  #else
393
- Packet16f res = _mm512_undefined_ps();
394
- Packet4f lane0_a = _mm512_extractf32x4_ps(a, 0);
395
- Packet4f lane0_b = _mm512_extractf32x4_ps(b, 0);
396
- res = _mm512_insertf32x4(res, _mm_andnot_ps(lane0_a, lane0_b), 0);
397
-
398
- Packet4f lane1_a = _mm512_extractf32x4_ps(a, 1);
399
- Packet4f lane1_b = _mm512_extractf32x4_ps(b, 1);
400
- res = _mm512_insertf32x4(res, _mm_andnot_ps(lane1_a, lane1_b), 1);
401
-
402
- Packet4f lane2_a = _mm512_extractf32x4_ps(a, 2);
403
- Packet4f lane2_b = _mm512_extractf32x4_ps(b, 2);
404
- res = _mm512_insertf32x4(res, _mm_andnot_ps(lane2_a, lane2_b), 2);
405
-
406
- Packet4f lane3_a = _mm512_extractf32x4_ps(a, 3);
407
- Packet4f lane3_b = _mm512_extractf32x4_ps(b, 3);
408
- res = _mm512_insertf32x4(res, _mm_andnot_ps(lane3_a, lane3_b), 3);
409
-
410
- return res;
639
+ return _mm512_castsi512_ps(pandnot(_mm512_castps_si512(a),_mm512_castps_si512(b)));
411
640
  #endif
412
641
  }
413
642
  template <>
414
- EIGEN_STRONG_INLINE Packet8d pandnot<Packet8d>(const Packet8d& a,
415
- const Packet8d& b) {
643
+ EIGEN_STRONG_INLINE Packet8d pandnot<Packet8d>(const Packet8d& a,const Packet8d& b) {
416
644
  #ifdef EIGEN_VECTORIZE_AVX512DQ
417
- return _mm512_andnot_pd(a, b);
645
+ return _mm512_andnot_pd(b, a);
418
646
  #else
419
- Packet8d res = _mm512_undefined_pd();
420
- Packet4d lane0_a = _mm512_extractf64x4_pd(a, 0);
421
- Packet4d lane0_b = _mm512_extractf64x4_pd(b, 0);
422
- res = _mm512_insertf64x4(res, _mm256_andnot_pd(lane0_a, lane0_b), 0);
647
+ return _mm512_castsi512_pd(pandnot(_mm512_castpd_si512(a),_mm512_castpd_si512(b)));
648
+ #endif
649
+ }
423
650
 
424
- Packet4d lane1_a = _mm512_extractf64x4_pd(a, 1);
425
- Packet4d lane1_b = _mm512_extractf64x4_pd(b, 1);
426
- res = _mm512_insertf64x4(res, _mm256_andnot_pd(lane1_a, lane1_b), 1);
651
+ template<> EIGEN_STRONG_INLINE Packet16f pround<Packet16f>(const Packet16f& a)
652
+ {
653
+ // Work-around for default std::round rounding mode.
654
+ const Packet16f mask = pset1frombits<Packet16f>(static_cast<numext::uint32_t>(0x80000000u));
655
+ const Packet16f prev0dot5 = pset1frombits<Packet16f>(static_cast<numext::uint32_t>(0x3EFFFFFFu));
656
+ return _mm512_roundscale_ps(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO);
657
+ }
658
+ template<> EIGEN_STRONG_INLINE Packet8d pround<Packet8d>(const Packet8d& a)
659
+ {
660
+ // Work-around for default std::round rounding mode.
661
+ const Packet8d mask = pset1frombits<Packet8d>(static_cast<numext::uint64_t>(0x8000000000000000ull));
662
+ const Packet8d prev0dot5 = pset1frombits<Packet8d>(static_cast<numext::uint64_t>(0x3FDFFFFFFFFFFFFFull));
663
+ return _mm512_roundscale_pd(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO);
664
+ }
427
665
 
428
- return res;
429
- #endif
666
+ template<int N> EIGEN_STRONG_INLINE Packet16i parithmetic_shift_right(Packet16i a) {
667
+ return _mm512_srai_epi32(a, N);
668
+ }
669
+
670
+ template<int N> EIGEN_STRONG_INLINE Packet16i plogical_shift_right(Packet16i a) {
671
+ return _mm512_srli_epi32(a, N);
672
+ }
673
+
674
+ template<int N> EIGEN_STRONG_INLINE Packet16i plogical_shift_left(Packet16i a) {
675
+ return _mm512_slli_epi32(a, N);
430
676
  }
431
677
 
432
678
  template <>
@@ -457,79 +703,65 @@ EIGEN_STRONG_INLINE Packet16i ploadu<Packet16i>(const int* from) {
457
703
  reinterpret_cast<const __m512i*>(from));
458
704
  }
459
705
 
706
+ template <>
707
+ EIGEN_STRONG_INLINE Packet16f ploadu<Packet16f>(const float* from, uint16_t umask) {
708
+ __mmask16 mask = static_cast<__mmask16>(umask);
709
+ EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_maskz_loadu_ps(mask, from);
710
+ }
711
+
460
712
  // Loads 8 floats from memory a returns the packet
461
713
  // {a0, a0 a1, a1, a2, a2, a3, a3, a4, a4, a5, a5, a6, a6, a7, a7}
462
714
  template <>
463
715
  EIGEN_STRONG_INLINE Packet16f ploaddup<Packet16f>(const float* from) {
464
- Packet8f lane0 = _mm256_broadcast_ps((const __m128*)(const void*)from);
465
- // mimic an "inplace" permutation of the lower 128bits using a blend
466
- lane0 = _mm256_blend_ps(
467
- lane0, _mm256_castps128_ps256(_mm_permute_ps(
468
- _mm256_castps256_ps128(lane0), _MM_SHUFFLE(1, 0, 1, 0))),
469
- 15);
470
- // then we can perform a consistent permutation on the global register to get
471
- // everything in shape:
472
- lane0 = _mm256_permute_ps(lane0, _MM_SHUFFLE(3, 3, 2, 2));
473
-
474
- Packet8f lane1 = _mm256_broadcast_ps((const __m128*)(const void*)(from + 4));
475
- // mimic an "inplace" permutation of the lower 128bits using a blend
476
- lane1 = _mm256_blend_ps(
477
- lane1, _mm256_castps128_ps256(_mm_permute_ps(
478
- _mm256_castps256_ps128(lane1), _MM_SHUFFLE(1, 0, 1, 0))),
479
- 15);
480
- // then we can perform a consistent permutation on the global register to get
481
- // everything in shape:
482
- lane1 = _mm256_permute_ps(lane1, _MM_SHUFFLE(3, 3, 2, 2));
716
+ // an unaligned load is required here as there is no requirement
717
+ // on the alignment of input pointer 'from'
718
+ __m256i low_half = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(from));
719
+ __m512 even_elements = _mm512_castsi512_ps(_mm512_cvtepu32_epi64(low_half));
720
+ __m512 pairs = _mm512_permute_ps(even_elements, _MM_SHUFFLE(2, 2, 0, 0));
721
+ return pairs;
722
+ }
483
723
 
484
724
  #ifdef EIGEN_VECTORIZE_AVX512DQ
485
- Packet16f res = _mm512_undefined_ps();
486
- return _mm512_insertf32x8(res, lane0, 0);
487
- return _mm512_insertf32x8(res, lane1, 1);
488
- return res;
489
- #else
490
- Packet16f res = _mm512_undefined_ps();
491
- res = _mm512_insertf32x4(res, _mm256_extractf128_ps(lane0, 0), 0);
492
- res = _mm512_insertf32x4(res, _mm256_extractf128_ps(lane0, 1), 1);
493
- res = _mm512_insertf32x4(res, _mm256_extractf128_ps(lane1, 0), 2);
494
- res = _mm512_insertf32x4(res, _mm256_extractf128_ps(lane1, 1), 3);
495
- return res;
496
- #endif
497
- }
725
+ // FIXME: this does not look optimal, better load a Packet4d and shuffle...
498
726
  // Loads 4 doubles from memory a returns the packet {a0, a0 a1, a1, a2, a2, a3,
499
727
  // a3}
500
728
  template <>
501
729
  EIGEN_STRONG_INLINE Packet8d ploaddup<Packet8d>(const double* from) {
502
- Packet4d lane0 = _mm256_broadcast_pd((const __m128d*)(const void*)from);
503
- lane0 = _mm256_permute_pd(lane0, 3 << 2);
504
-
505
- Packet4d lane1 = _mm256_broadcast_pd((const __m128d*)(const void*)(from + 2));
506
- lane1 = _mm256_permute_pd(lane1, 3 << 2);
507
-
508
- Packet8d res = _mm512_undefined_pd();
509
- res = _mm512_insertf64x4(res, lane0, 0);
510
- return _mm512_insertf64x4(res, lane1, 1);
730
+ __m512d x = _mm512_setzero_pd();
731
+ x = _mm512_insertf64x2(x, _mm_loaddup_pd(&from[0]), 0);
732
+ x = _mm512_insertf64x2(x, _mm_loaddup_pd(&from[1]), 1);
733
+ x = _mm512_insertf64x2(x, _mm_loaddup_pd(&from[2]), 2);
734
+ x = _mm512_insertf64x2(x, _mm_loaddup_pd(&from[3]), 3);
735
+ return x;
511
736
  }
737
+ #else
738
+ template <>
739
+ EIGEN_STRONG_INLINE Packet8d ploaddup<Packet8d>(const double* from) {
740
+ __m512d x = _mm512_setzero_pd();
741
+ x = _mm512_mask_broadcastsd_pd(x, 0x3<<0, _mm_load_sd(from+0));
742
+ x = _mm512_mask_broadcastsd_pd(x, 0x3<<2, _mm_load_sd(from+1));
743
+ x = _mm512_mask_broadcastsd_pd(x, 0x3<<4, _mm_load_sd(from+2));
744
+ x = _mm512_mask_broadcastsd_pd(x, 0x3<<6, _mm_load_sd(from+3));
745
+ return x;
746
+ }
747
+ #endif
512
748
 
513
749
  // Loads 4 floats from memory a returns the packet
514
750
  // {a0, a0 a0, a0, a1, a1, a1, a1, a2, a2, a2, a2, a3, a3, a3, a3}
515
751
  template <>
516
752
  EIGEN_STRONG_INLINE Packet16f ploadquad<Packet16f>(const float* from) {
517
- Packet16f tmp = _mm512_undefined_ps();
518
- tmp = _mm512_insertf32x4(tmp, _mm_load_ps1(from), 0);
519
- tmp = _mm512_insertf32x4(tmp, _mm_load_ps1(from + 1), 1);
520
- tmp = _mm512_insertf32x4(tmp, _mm_load_ps1(from + 2), 2);
521
- tmp = _mm512_insertf32x4(tmp, _mm_load_ps1(from + 3), 3);
522
- return tmp;
753
+ Packet16f tmp = _mm512_castps128_ps512(ploadu<Packet4f>(from));
754
+ const Packet16i scatter_mask = _mm512_set_epi32(3,3,3,3, 2,2,2,2, 1,1,1,1, 0,0,0,0);
755
+ return _mm512_permutexvar_ps(scatter_mask, tmp);
523
756
  }
757
+
524
758
  // Loads 2 doubles from memory a returns the packet
525
759
  // {a0, a0 a0, a0, a1, a1, a1, a1}
526
760
  template <>
527
761
  EIGEN_STRONG_INLINE Packet8d ploadquad<Packet8d>(const double* from) {
528
- Packet8d tmp = _mm512_undefined_pd();
529
- Packet2d tmp0 = _mm_load_pd1(from);
530
- Packet2d tmp1 = _mm_load_pd1(from + 1);
531
- Packet4d lane0 = _mm256_broadcastsd_pd(tmp0);
532
- Packet4d lane1 = _mm256_broadcastsd_pd(tmp1);
762
+ __m256d lane0 = _mm256_set1_pd(*from);
763
+ __m256d lane1 = _mm256_set1_pd(*(from+1));
764
+ __m512d tmp = _mm512_undefined_pd();
533
765
  tmp = _mm512_insertf64x4(tmp, lane0, 0);
534
766
  return _mm512_insertf64x4(tmp, lane1, 1);
535
767
  }
@@ -561,11 +793,16 @@ EIGEN_STRONG_INLINE void pstoreu<int>(int* to, const Packet16i& from) {
561
793
  EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_si512(
562
794
  reinterpret_cast<__m512i*>(to), from);
563
795
  }
796
+ template <>
797
+ EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const Packet16f& from, uint16_t umask) {
798
+ __mmask16 mask = static_cast<__mmask16>(umask);
799
+ EIGEN_DEBUG_UNALIGNED_STORE return _mm512_mask_storeu_ps(to, mask, from);
800
+ }
564
801
 
565
802
  template <>
566
803
  EIGEN_DEVICE_FUNC inline Packet16f pgather<float, Packet16f>(const float* from,
567
804
  Index stride) {
568
- Packet16i stride_vector = _mm512_set1_epi32(stride);
805
+ Packet16i stride_vector = _mm512_set1_epi32(convert_index<int>(stride));
569
806
  Packet16i stride_multiplier =
570
807
  _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
571
808
  Packet16i indices = _mm512_mullo_epi32(stride_vector, stride_multiplier);
@@ -575,7 +812,7 @@ EIGEN_DEVICE_FUNC inline Packet16f pgather<float, Packet16f>(const float* from,
575
812
  template <>
576
813
  EIGEN_DEVICE_FUNC inline Packet8d pgather<double, Packet8d>(const double* from,
577
814
  Index stride) {
578
- Packet8i stride_vector = _mm256_set1_epi32(stride);
815
+ Packet8i stride_vector = _mm256_set1_epi32(convert_index<int>(stride));
579
816
  Packet8i stride_multiplier = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
580
817
  Packet8i indices = _mm256_mullo_epi32(stride_vector, stride_multiplier);
581
818
 
@@ -586,7 +823,7 @@ template <>
586
823
  EIGEN_DEVICE_FUNC inline void pscatter<float, Packet16f>(float* to,
587
824
  const Packet16f& from,
588
825
  Index stride) {
589
- Packet16i stride_vector = _mm512_set1_epi32(stride);
826
+ Packet16i stride_vector = _mm512_set1_epi32(convert_index<int>(stride));
590
827
  Packet16i stride_multiplier =
591
828
  _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
592
829
  Packet16i indices = _mm512_mullo_epi32(stride_vector, stride_multiplier);
@@ -596,7 +833,7 @@ template <>
596
833
  EIGEN_DEVICE_FUNC inline void pscatter<double, Packet8d>(double* to,
597
834
  const Packet8d& from,
598
835
  Index stride) {
599
- Packet8i stride_vector = _mm256_set1_epi32(stride);
836
+ Packet8i stride_vector = _mm256_set1_epi32(convert_index<int>(stride));
600
837
  Packet8i stride_multiplier = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
601
838
  Packet8i indices = _mm256_mullo_epi32(stride_vector, stride_multiplier);
602
839
  _mm512_i32scatter_pd(to, indices, from, 8);
@@ -657,11 +894,64 @@ EIGEN_STRONG_INLINE Packet8d pabs(const Packet8d& a) {
657
894
  _mm512_set1_epi64(0x7fffffffffffffff)));
658
895
  }
659
896
 
897
+ template<>
898
+ EIGEN_STRONG_INLINE Packet16f pfrexp<Packet16f>(const Packet16f& a, Packet16f& exponent){
899
+ return pfrexp_generic(a, exponent);
900
+ }
901
+
902
+ // Extract exponent without existence of Packet8l.
903
+ template<>
904
+ EIGEN_STRONG_INLINE
905
+ Packet8d pfrexp_generic_get_biased_exponent(const Packet8d& a) {
906
+ const Packet8d cst_exp_mask = pset1frombits<Packet8d>(static_cast<uint64_t>(0x7ff0000000000000ull));
907
+ #ifdef EIGEN_VECTORIZE_AVX512DQ
908
+ return _mm512_cvtepi64_pd(_mm512_srli_epi64(_mm512_castpd_si512(pand(a, cst_exp_mask)), 52));
909
+ #else
910
+ return _mm512_cvtepi32_pd(_mm512_cvtepi64_epi32(_mm512_srli_epi64(_mm512_castpd_si512(pand(a, cst_exp_mask)), 52)));
911
+ #endif
912
+ }
913
+
914
+ template<>
915
+ EIGEN_STRONG_INLINE Packet8d pfrexp<Packet8d>(const Packet8d& a, Packet8d& exponent) {
916
+ return pfrexp_generic(a, exponent);
917
+ }
918
+
919
+ template<> EIGEN_STRONG_INLINE Packet16f pldexp<Packet16f>(const Packet16f& a, const Packet16f& exponent) {
920
+ return pldexp_generic(a, exponent);
921
+ }
922
+
923
+ template<> EIGEN_STRONG_INLINE Packet8d pldexp<Packet8d>(const Packet8d& a, const Packet8d& exponent) {
924
+ // Clamp exponent to [-2099, 2099]
925
+ const Packet8d max_exponent = pset1<Packet8d>(2099.0);
926
+ const Packet8i e = _mm512_cvtpd_epi32(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent));
927
+
928
+ // Split 2^e into four factors and multiply.
929
+ const Packet8i bias = pset1<Packet8i>(1023);
930
+ Packet8i b = parithmetic_shift_right<2>(e); // floor(e/4)
931
+
932
+ // 2^b
933
+ const Packet8i permute_idx = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
934
+ Packet8i hi = _mm256_permutevar8x32_epi32(padd(b, bias), permute_idx);
935
+ Packet8i lo = _mm256_slli_epi64(hi, 52);
936
+ hi = _mm256_slli_epi64(_mm256_srli_epi64(hi, 32), 52);
937
+ Packet8d c = _mm512_castsi512_pd(_mm512_inserti64x4(_mm512_castsi256_si512(lo), hi, 1));
938
+ Packet8d out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b)
939
+
940
+ // 2^(e - 3b)
941
+ b = psub(psub(psub(e, b), b), b); // e - 3b
942
+ hi = _mm256_permutevar8x32_epi32(padd(b, bias), permute_idx);
943
+ lo = _mm256_slli_epi64(hi, 52);
944
+ hi = _mm256_slli_epi64(_mm256_srli_epi64(hi, 32), 52);
945
+ c = _mm512_castsi512_pd(_mm512_inserti64x4(_mm512_castsi256_si512(lo), hi, 1));
946
+ out = pmul(out, c); // a * 2^e
947
+ return out;
948
+ }
949
+
660
950
  #ifdef EIGEN_VECTORIZE_AVX512DQ
661
951
  // AVX512F does not define _mm512_extractf32x8_ps to extract _m256 from _m512
662
952
  #define EIGEN_EXTRACT_8f_FROM_16f(INPUT, OUTPUT) \
663
- __m256 OUTPUT##_0 = _mm512_extractf32x8_ps(INPUT, 0) __m256 OUTPUT##_1 = \
664
- _mm512_extractf32x8_ps(INPUT, 1)
953
+ __m256 OUTPUT##_0 = _mm512_extractf32x8_ps(INPUT, 0); \
954
+ __m256 OUTPUT##_1 = _mm512_extractf32x8_ps(INPUT, 1)
665
955
  #else
666
956
  #define EIGEN_EXTRACT_8f_FROM_16f(INPUT, OUTPUT) \
667
957
  __m256 OUTPUT##_0 = _mm256_insertf128_ps( \
@@ -674,258 +964,64 @@ EIGEN_STRONG_INLINE Packet8d pabs(const Packet8d& a) {
674
964
 
675
965
  #ifdef EIGEN_VECTORIZE_AVX512DQ
676
966
  #define EIGEN_INSERT_8f_INTO_16f(OUTPUT, INPUTA, INPUTB) \
677
- OUTPUT = _mm512_insertf32x8(OUTPUT, INPUTA, 0); \
678
- OUTPUT = _mm512_insertf32x8(OUTPUT, INPUTB, 1);
967
+ OUTPUT = _mm512_insertf32x8(_mm512_castps256_ps512(INPUTA), INPUTB, 1);
679
968
  #else
680
969
  #define EIGEN_INSERT_8f_INTO_16f(OUTPUT, INPUTA, INPUTB) \
970
+ OUTPUT = _mm512_undefined_ps(); \
681
971
  OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTA, 0), 0); \
682
972
  OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTA, 1), 1); \
683
973
  OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTB, 0), 2); \
684
974
  OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTB, 1), 3);
685
975
  #endif
686
- template<> EIGEN_STRONG_INLINE Packet16f preduxp<Packet16f>(const Packet16f*
687
- vecs)
688
- {
689
- EIGEN_EXTRACT_8f_FROM_16f(vecs[0], vecs0);
690
- EIGEN_EXTRACT_8f_FROM_16f(vecs[1], vecs1);
691
- EIGEN_EXTRACT_8f_FROM_16f(vecs[2], vecs2);
692
- EIGEN_EXTRACT_8f_FROM_16f(vecs[3], vecs3);
693
- EIGEN_EXTRACT_8f_FROM_16f(vecs[4], vecs4);
694
- EIGEN_EXTRACT_8f_FROM_16f(vecs[5], vecs5);
695
- EIGEN_EXTRACT_8f_FROM_16f(vecs[6], vecs6);
696
- EIGEN_EXTRACT_8f_FROM_16f(vecs[7], vecs7);
697
- EIGEN_EXTRACT_8f_FROM_16f(vecs[8], vecs8);
698
- EIGEN_EXTRACT_8f_FROM_16f(vecs[9], vecs9);
699
- EIGEN_EXTRACT_8f_FROM_16f(vecs[10], vecs10);
700
- EIGEN_EXTRACT_8f_FROM_16f(vecs[11], vecs11);
701
- EIGEN_EXTRACT_8f_FROM_16f(vecs[12], vecs12);
702
- EIGEN_EXTRACT_8f_FROM_16f(vecs[13], vecs13);
703
- EIGEN_EXTRACT_8f_FROM_16f(vecs[14], vecs14);
704
- EIGEN_EXTRACT_8f_FROM_16f(vecs[15], vecs15);
705
-
706
- __m256 hsum1 = _mm256_hadd_ps(vecs0_0, vecs1_0);
707
- __m256 hsum2 = _mm256_hadd_ps(vecs2_0, vecs3_0);
708
- __m256 hsum3 = _mm256_hadd_ps(vecs4_0, vecs5_0);
709
- __m256 hsum4 = _mm256_hadd_ps(vecs6_0, vecs7_0);
710
-
711
- __m256 hsum5 = _mm256_hadd_ps(hsum1, hsum1);
712
- __m256 hsum6 = _mm256_hadd_ps(hsum2, hsum2);
713
- __m256 hsum7 = _mm256_hadd_ps(hsum3, hsum3);
714
- __m256 hsum8 = _mm256_hadd_ps(hsum4, hsum4);
715
-
716
- __m256 perm1 = _mm256_permute2f128_ps(hsum5, hsum5, 0x23);
717
- __m256 perm2 = _mm256_permute2f128_ps(hsum6, hsum6, 0x23);
718
- __m256 perm3 = _mm256_permute2f128_ps(hsum7, hsum7, 0x23);
719
- __m256 perm4 = _mm256_permute2f128_ps(hsum8, hsum8, 0x23);
720
-
721
- __m256 sum1 = _mm256_add_ps(perm1, hsum5);
722
- __m256 sum2 = _mm256_add_ps(perm2, hsum6);
723
- __m256 sum3 = _mm256_add_ps(perm3, hsum7);
724
- __m256 sum4 = _mm256_add_ps(perm4, hsum8);
725
-
726
- __m256 blend1 = _mm256_blend_ps(sum1, sum2, 0xcc);
727
- __m256 blend2 = _mm256_blend_ps(sum3, sum4, 0xcc);
728
-
729
- __m256 final = _mm256_blend_ps(blend1, blend2, 0xf0);
730
-
731
- hsum1 = _mm256_hadd_ps(vecs0_1, vecs1_1);
732
- hsum2 = _mm256_hadd_ps(vecs2_1, vecs3_1);
733
- hsum3 = _mm256_hadd_ps(vecs4_1, vecs5_1);
734
- hsum4 = _mm256_hadd_ps(vecs6_1, vecs7_1);
735
-
736
- hsum5 = _mm256_hadd_ps(hsum1, hsum1);
737
- hsum6 = _mm256_hadd_ps(hsum2, hsum2);
738
- hsum7 = _mm256_hadd_ps(hsum3, hsum3);
739
- hsum8 = _mm256_hadd_ps(hsum4, hsum4);
740
-
741
- perm1 = _mm256_permute2f128_ps(hsum5, hsum5, 0x23);
742
- perm2 = _mm256_permute2f128_ps(hsum6, hsum6, 0x23);
743
- perm3 = _mm256_permute2f128_ps(hsum7, hsum7, 0x23);
744
- perm4 = _mm256_permute2f128_ps(hsum8, hsum8, 0x23);
745
-
746
- sum1 = _mm256_add_ps(perm1, hsum5);
747
- sum2 = _mm256_add_ps(perm2, hsum6);
748
- sum3 = _mm256_add_ps(perm3, hsum7);
749
- sum4 = _mm256_add_ps(perm4, hsum8);
750
-
751
- blend1 = _mm256_blend_ps(sum1, sum2, 0xcc);
752
- blend2 = _mm256_blend_ps(sum3, sum4, 0xcc);
753
-
754
- final = padd(final, _mm256_blend_ps(blend1, blend2, 0xf0));
755
-
756
- hsum1 = _mm256_hadd_ps(vecs8_0, vecs9_0);
757
- hsum2 = _mm256_hadd_ps(vecs10_0, vecs11_0);
758
- hsum3 = _mm256_hadd_ps(vecs12_0, vecs13_0);
759
- hsum4 = _mm256_hadd_ps(vecs14_0, vecs15_0);
760
-
761
- hsum5 = _mm256_hadd_ps(hsum1, hsum1);
762
- hsum6 = _mm256_hadd_ps(hsum2, hsum2);
763
- hsum7 = _mm256_hadd_ps(hsum3, hsum3);
764
- hsum8 = _mm256_hadd_ps(hsum4, hsum4);
765
-
766
- perm1 = _mm256_permute2f128_ps(hsum5, hsum5, 0x23);
767
- perm2 = _mm256_permute2f128_ps(hsum6, hsum6, 0x23);
768
- perm3 = _mm256_permute2f128_ps(hsum7, hsum7, 0x23);
769
- perm4 = _mm256_permute2f128_ps(hsum8, hsum8, 0x23);
770
-
771
- sum1 = _mm256_add_ps(perm1, hsum5);
772
- sum2 = _mm256_add_ps(perm2, hsum6);
773
- sum3 = _mm256_add_ps(perm3, hsum7);
774
- sum4 = _mm256_add_ps(perm4, hsum8);
775
-
776
- blend1 = _mm256_blend_ps(sum1, sum2, 0xcc);
777
- blend2 = _mm256_blend_ps(sum3, sum4, 0xcc);
778
-
779
- __m256 final_1 = _mm256_blend_ps(blend1, blend2, 0xf0);
780
-
781
- hsum1 = _mm256_hadd_ps(vecs8_1, vecs9_1);
782
- hsum2 = _mm256_hadd_ps(vecs10_1, vecs11_1);
783
- hsum3 = _mm256_hadd_ps(vecs12_1, vecs13_1);
784
- hsum4 = _mm256_hadd_ps(vecs14_1, vecs15_1);
785
-
786
- hsum5 = _mm256_hadd_ps(hsum1, hsum1);
787
- hsum6 = _mm256_hadd_ps(hsum2, hsum2);
788
- hsum7 = _mm256_hadd_ps(hsum3, hsum3);
789
- hsum8 = _mm256_hadd_ps(hsum4, hsum4);
790
-
791
- perm1 = _mm256_permute2f128_ps(hsum5, hsum5, 0x23);
792
- perm2 = _mm256_permute2f128_ps(hsum6, hsum6, 0x23);
793
- perm3 = _mm256_permute2f128_ps(hsum7, hsum7, 0x23);
794
- perm4 = _mm256_permute2f128_ps(hsum8, hsum8, 0x23);
795
-
796
- sum1 = _mm256_add_ps(perm1, hsum5);
797
- sum2 = _mm256_add_ps(perm2, hsum6);
798
- sum3 = _mm256_add_ps(perm3, hsum7);
799
- sum4 = _mm256_add_ps(perm4, hsum8);
800
-
801
- blend1 = _mm256_blend_ps(sum1, sum2, 0xcc);
802
- blend2 = _mm256_blend_ps(sum3, sum4, 0xcc);
803
-
804
- final_1 = padd(final_1, _mm256_blend_ps(blend1, blend2, 0xf0));
805
-
806
- __m512 final_output;
807
-
808
- EIGEN_INSERT_8f_INTO_16f(final_output, final, final_1);
809
- return final_output;
810
- }
811
-
812
- template<> EIGEN_STRONG_INLINE Packet8d preduxp<Packet8d>(const Packet8d* vecs)
813
- {
814
- Packet4d vecs0_0 = _mm512_extractf64x4_pd(vecs[0], 0);
815
- Packet4d vecs0_1 = _mm512_extractf64x4_pd(vecs[0], 1);
816
-
817
- Packet4d vecs1_0 = _mm512_extractf64x4_pd(vecs[1], 0);
818
- Packet4d vecs1_1 = _mm512_extractf64x4_pd(vecs[1], 1);
819
-
820
- Packet4d vecs2_0 = _mm512_extractf64x4_pd(vecs[2], 0);
821
- Packet4d vecs2_1 = _mm512_extractf64x4_pd(vecs[2], 1);
822
-
823
- Packet4d vecs3_0 = _mm512_extractf64x4_pd(vecs[3], 0);
824
- Packet4d vecs3_1 = _mm512_extractf64x4_pd(vecs[3], 1);
825
-
826
- Packet4d vecs4_0 = _mm512_extractf64x4_pd(vecs[4], 0);
827
- Packet4d vecs4_1 = _mm512_extractf64x4_pd(vecs[4], 1);
828
-
829
- Packet4d vecs5_0 = _mm512_extractf64x4_pd(vecs[5], 0);
830
- Packet4d vecs5_1 = _mm512_extractf64x4_pd(vecs[5], 1);
831
-
832
- Packet4d vecs6_0 = _mm512_extractf64x4_pd(vecs[6], 0);
833
- Packet4d vecs6_1 = _mm512_extractf64x4_pd(vecs[6], 1);
834
-
835
- Packet4d vecs7_0 = _mm512_extractf64x4_pd(vecs[7], 0);
836
- Packet4d vecs7_1 = _mm512_extractf64x4_pd(vecs[7], 1);
837
-
838
- Packet4d tmp0, tmp1;
839
-
840
- tmp0 = _mm256_hadd_pd(vecs0_0, vecs1_0);
841
- tmp0 = _mm256_add_pd(tmp0, _mm256_permute2f128_pd(tmp0, tmp0, 1));
842
-
843
- tmp1 = _mm256_hadd_pd(vecs2_0, vecs3_0);
844
- tmp1 = _mm256_add_pd(tmp1, _mm256_permute2f128_pd(tmp1, tmp1, 1));
845
-
846
- __m256d final_0 = _mm256_blend_pd(tmp0, tmp1, 0xC);
847
-
848
- tmp0 = _mm256_hadd_pd(vecs0_1, vecs1_1);
849
- tmp0 = _mm256_add_pd(tmp0, _mm256_permute2f128_pd(tmp0, tmp0, 1));
850
-
851
- tmp1 = _mm256_hadd_pd(vecs2_1, vecs3_1);
852
- tmp1 = _mm256_add_pd(tmp1, _mm256_permute2f128_pd(tmp1, tmp1, 1));
853
-
854
- final_0 = padd(final_0, _mm256_blend_pd(tmp0, tmp1, 0xC));
855
-
856
- tmp0 = _mm256_hadd_pd(vecs4_0, vecs5_0);
857
- tmp0 = _mm256_add_pd(tmp0, _mm256_permute2f128_pd(tmp0, tmp0, 1));
858
-
859
- tmp1 = _mm256_hadd_pd(vecs6_0, vecs7_0);
860
- tmp1 = _mm256_add_pd(tmp1, _mm256_permute2f128_pd(tmp1, tmp1, 1));
861
-
862
- __m256d final_1 = _mm256_blend_pd(tmp0, tmp1, 0xC);
863
-
864
- tmp0 = _mm256_hadd_pd(vecs4_1, vecs5_1);
865
- tmp0 = _mm256_add_pd(tmp0, _mm256_permute2f128_pd(tmp0, tmp0, 1));
866
-
867
- tmp1 = _mm256_hadd_pd(vecs6_1, vecs7_1);
868
- tmp1 = _mm256_add_pd(tmp1, _mm256_permute2f128_pd(tmp1, tmp1, 1));
869
-
870
- final_1 = padd(final_1, _mm256_blend_pd(tmp0, tmp1, 0xC));
871
-
872
- __m512d final_output = _mm512_insertf64x4(final_output, final_0, 0);
873
-
874
- return _mm512_insertf64x4(final_output, final_1, 1);
875
- }
876
976
 
877
977
  template <>
878
978
  EIGEN_STRONG_INLINE float predux<Packet16f>(const Packet16f& a) {
879
- //#ifdef EIGEN_VECTORIZE_AVX512DQ
880
- #if 0
881
- Packet8f lane0 = _mm512_extractf32x8_ps(a, 0);
882
- Packet8f lane1 = _mm512_extractf32x8_ps(a, 1);
883
- Packet8f sum = padd(lane0, lane1);
884
- Packet8f tmp0 = _mm256_hadd_ps(sum, _mm256_permute2f128_ps(a, a, 1));
885
- tmp0 = _mm256_hadd_ps(tmp0, tmp0);
886
- return pfirst(_mm256_hadd_ps(tmp0, tmp0));
979
+ #ifdef EIGEN_VECTORIZE_AVX512DQ
980
+ __m256 lane0 = _mm512_extractf32x8_ps(a, 0);
981
+ __m256 lane1 = _mm512_extractf32x8_ps(a, 1);
982
+ Packet8f x = _mm256_add_ps(lane0, lane1);
983
+ return predux<Packet8f>(x);
887
984
  #else
888
- Packet4f lane0 = _mm512_extractf32x4_ps(a, 0);
889
- Packet4f lane1 = _mm512_extractf32x4_ps(a, 1);
890
- Packet4f lane2 = _mm512_extractf32x4_ps(a, 2);
891
- Packet4f lane3 = _mm512_extractf32x4_ps(a, 3);
892
- Packet4f sum = padd(padd(lane0, lane1), padd(lane2, lane3));
985
+ __m128 lane0 = _mm512_extractf32x4_ps(a, 0);
986
+ __m128 lane1 = _mm512_extractf32x4_ps(a, 1);
987
+ __m128 lane2 = _mm512_extractf32x4_ps(a, 2);
988
+ __m128 lane3 = _mm512_extractf32x4_ps(a, 3);
989
+ __m128 sum = _mm_add_ps(_mm_add_ps(lane0, lane1), _mm_add_ps(lane2, lane3));
893
990
  sum = _mm_hadd_ps(sum, sum);
894
991
  sum = _mm_hadd_ps(sum, _mm_permute_ps(sum, 1));
895
- return pfirst(sum);
992
+ return _mm_cvtss_f32(sum);
896
993
  #endif
897
994
  }
898
995
  template <>
899
996
  EIGEN_STRONG_INLINE double predux<Packet8d>(const Packet8d& a) {
900
- Packet4d lane0 = _mm512_extractf64x4_pd(a, 0);
901
- Packet4d lane1 = _mm512_extractf64x4_pd(a, 1);
902
- Packet4d sum = padd(lane0, lane1);
903
- Packet4d tmp0 = _mm256_hadd_pd(sum, _mm256_permute2f128_pd(sum, sum, 1));
904
- return pfirst(_mm256_hadd_pd(tmp0, tmp0));
997
+ __m256d lane0 = _mm512_extractf64x4_pd(a, 0);
998
+ __m256d lane1 = _mm512_extractf64x4_pd(a, 1);
999
+ __m256d sum = _mm256_add_pd(lane0, lane1);
1000
+ __m256d tmp0 = _mm256_hadd_pd(sum, _mm256_permute2f128_pd(sum, sum, 1));
1001
+ return _mm_cvtsd_f64(_mm256_castpd256_pd128(_mm256_hadd_pd(tmp0, tmp0)));
905
1002
  }
906
1003
 
907
1004
  template <>
908
- EIGEN_STRONG_INLINE Packet8f predux_downto4<Packet16f>(const Packet16f& a) {
1005
+ EIGEN_STRONG_INLINE Packet8f predux_half_dowto4<Packet16f>(const Packet16f& a) {
909
1006
  #ifdef EIGEN_VECTORIZE_AVX512DQ
910
- Packet8f lane0 = _mm512_extractf32x8_ps(a, 0);
911
- Packet8f lane1 = _mm512_extractf32x8_ps(a, 1);
912
- return padd(lane0, lane1);
1007
+ __m256 lane0 = _mm512_extractf32x8_ps(a, 0);
1008
+ __m256 lane1 = _mm512_extractf32x8_ps(a, 1);
1009
+ return _mm256_add_ps(lane0, lane1);
913
1010
  #else
914
- Packet4f lane0 = _mm512_extractf32x4_ps(a, 0);
915
- Packet4f lane1 = _mm512_extractf32x4_ps(a, 1);
916
- Packet4f lane2 = _mm512_extractf32x4_ps(a, 2);
917
- Packet4f lane3 = _mm512_extractf32x4_ps(a, 3);
918
- Packet4f sum0 = padd(lane0, lane2);
919
- Packet4f sum1 = padd(lane1, lane3);
1011
+ __m128 lane0 = _mm512_extractf32x4_ps(a, 0);
1012
+ __m128 lane1 = _mm512_extractf32x4_ps(a, 1);
1013
+ __m128 lane2 = _mm512_extractf32x4_ps(a, 2);
1014
+ __m128 lane3 = _mm512_extractf32x4_ps(a, 3);
1015
+ __m128 sum0 = _mm_add_ps(lane0, lane2);
1016
+ __m128 sum1 = _mm_add_ps(lane1, lane3);
920
1017
  return _mm256_insertf128_ps(_mm256_castps128_ps256(sum0), sum1, 1);
921
1018
  #endif
922
1019
  }
923
1020
  template <>
924
- EIGEN_STRONG_INLINE Packet4d predux_downto4<Packet8d>(const Packet8d& a) {
925
- Packet4d lane0 = _mm512_extractf64x4_pd(a, 0);
926
- Packet4d lane1 = _mm512_extractf64x4_pd(a, 1);
927
- Packet4d res = padd(lane0, lane1);
928
- return res;
1021
+ EIGEN_STRONG_INLINE Packet4d predux_half_dowto4<Packet8d>(const Packet8d& a) {
1022
+ __m256d lane0 = _mm512_extractf64x4_pd(a, 0);
1023
+ __m256d lane1 = _mm512_extractf64x4_pd(a, 1);
1024
+ return _mm256_add_pd(lane0, lane1);
929
1025
  }
930
1026
 
931
1027
  template <>
@@ -939,108 +1035,70 @@ EIGEN_STRONG_INLINE float predux_mul<Packet16f>(const Packet16f& a) {
939
1035
  res = pmul(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 3, 2)));
940
1036
  return pfirst(pmul(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 0, 1))));
941
1037
  #else
942
- Packet4f lane0 = _mm512_extractf32x4_ps(a, 0);
943
- Packet4f lane1 = _mm512_extractf32x4_ps(a, 1);
944
- Packet4f lane2 = _mm512_extractf32x4_ps(a, 2);
945
- Packet4f lane3 = _mm512_extractf32x4_ps(a, 3);
946
- Packet4f res = pmul(pmul(lane0, lane1), pmul(lane2, lane3));
1038
+ __m128 lane0 = _mm512_extractf32x4_ps(a, 0);
1039
+ __m128 lane1 = _mm512_extractf32x4_ps(a, 1);
1040
+ __m128 lane2 = _mm512_extractf32x4_ps(a, 2);
1041
+ __m128 lane3 = _mm512_extractf32x4_ps(a, 3);
1042
+ __m128 res = pmul(pmul(lane0, lane1), pmul(lane2, lane3));
947
1043
  res = pmul(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 3, 2)));
948
1044
  return pfirst(pmul(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 0, 1))));
949
1045
  #endif
950
1046
  }
951
1047
  template <>
952
1048
  EIGEN_STRONG_INLINE double predux_mul<Packet8d>(const Packet8d& a) {
953
- Packet4d lane0 = _mm512_extractf64x4_pd(a, 0);
954
- Packet4d lane1 = _mm512_extractf64x4_pd(a, 1);
955
- Packet4d res = pmul(lane0, lane1);
1049
+ __m256d lane0 = _mm512_extractf64x4_pd(a, 0);
1050
+ __m256d lane1 = _mm512_extractf64x4_pd(a, 1);
1051
+ __m256d res = pmul(lane0, lane1);
956
1052
  res = pmul(res, _mm256_permute2f128_pd(res, res, 1));
957
1053
  return pfirst(pmul(res, _mm256_shuffle_pd(res, res, 1)));
958
1054
  }
959
1055
 
960
1056
  template <>
961
1057
  EIGEN_STRONG_INLINE float predux_min<Packet16f>(const Packet16f& a) {
962
- Packet4f lane0 = _mm512_extractf32x4_ps(a, 0);
963
- Packet4f lane1 = _mm512_extractf32x4_ps(a, 1);
964
- Packet4f lane2 = _mm512_extractf32x4_ps(a, 2);
965
- Packet4f lane3 = _mm512_extractf32x4_ps(a, 3);
966
- Packet4f res = _mm_min_ps(_mm_min_ps(lane0, lane1), _mm_min_ps(lane2, lane3));
1058
+ __m128 lane0 = _mm512_extractf32x4_ps(a, 0);
1059
+ __m128 lane1 = _mm512_extractf32x4_ps(a, 1);
1060
+ __m128 lane2 = _mm512_extractf32x4_ps(a, 2);
1061
+ __m128 lane3 = _mm512_extractf32x4_ps(a, 3);
1062
+ __m128 res = _mm_min_ps(_mm_min_ps(lane0, lane1), _mm_min_ps(lane2, lane3));
967
1063
  res = _mm_min_ps(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 3, 2)));
968
1064
  return pfirst(_mm_min_ps(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 0, 1))));
969
1065
  }
970
1066
  template <>
971
1067
  EIGEN_STRONG_INLINE double predux_min<Packet8d>(const Packet8d& a) {
972
- Packet4d lane0 = _mm512_extractf64x4_pd(a, 0);
973
- Packet4d lane1 = _mm512_extractf64x4_pd(a, 1);
974
- Packet4d res = _mm256_min_pd(lane0, lane1);
1068
+ __m256d lane0 = _mm512_extractf64x4_pd(a, 0);
1069
+ __m256d lane1 = _mm512_extractf64x4_pd(a, 1);
1070
+ __m256d res = _mm256_min_pd(lane0, lane1);
975
1071
  res = _mm256_min_pd(res, _mm256_permute2f128_pd(res, res, 1));
976
1072
  return pfirst(_mm256_min_pd(res, _mm256_shuffle_pd(res, res, 1)));
977
1073
  }
978
1074
 
979
1075
  template <>
980
1076
  EIGEN_STRONG_INLINE float predux_max<Packet16f>(const Packet16f& a) {
981
- Packet4f lane0 = _mm512_extractf32x4_ps(a, 0);
982
- Packet4f lane1 = _mm512_extractf32x4_ps(a, 1);
983
- Packet4f lane2 = _mm512_extractf32x4_ps(a, 2);
984
- Packet4f lane3 = _mm512_extractf32x4_ps(a, 3);
985
- Packet4f res = _mm_max_ps(_mm_max_ps(lane0, lane1), _mm_max_ps(lane2, lane3));
1077
+ __m128 lane0 = _mm512_extractf32x4_ps(a, 0);
1078
+ __m128 lane1 = _mm512_extractf32x4_ps(a, 1);
1079
+ __m128 lane2 = _mm512_extractf32x4_ps(a, 2);
1080
+ __m128 lane3 = _mm512_extractf32x4_ps(a, 3);
1081
+ __m128 res = _mm_max_ps(_mm_max_ps(lane0, lane1), _mm_max_ps(lane2, lane3));
986
1082
  res = _mm_max_ps(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 3, 2)));
987
1083
  return pfirst(_mm_max_ps(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 0, 1))));
988
1084
  }
1085
+
989
1086
  template <>
990
1087
  EIGEN_STRONG_INLINE double predux_max<Packet8d>(const Packet8d& a) {
991
- Packet4d lane0 = _mm512_extractf64x4_pd(a, 0);
992
- Packet4d lane1 = _mm512_extractf64x4_pd(a, 1);
993
- Packet4d res = _mm256_max_pd(lane0, lane1);
1088
+ __m256d lane0 = _mm512_extractf64x4_pd(a, 0);
1089
+ __m256d lane1 = _mm512_extractf64x4_pd(a, 1);
1090
+ __m256d res = _mm256_max_pd(lane0, lane1);
994
1091
  res = _mm256_max_pd(res, _mm256_permute2f128_pd(res, res, 1));
995
1092
  return pfirst(_mm256_max_pd(res, _mm256_shuffle_pd(res, res, 1)));
996
1093
  }
997
1094
 
998
- template <int Offset>
999
- struct palign_impl<Offset, Packet16f> {
1000
- static EIGEN_STRONG_INLINE void run(Packet16f& first,
1001
- const Packet16f& second) {
1002
- if (Offset != 0) {
1003
- __m512i first_idx = _mm512_set_epi32(
1004
- Offset + 15, Offset + 14, Offset + 13, Offset + 12, Offset + 11,
1005
- Offset + 10, Offset + 9, Offset + 8, Offset + 7, Offset + 6,
1006
- Offset + 5, Offset + 4, Offset + 3, Offset + 2, Offset + 1, Offset);
1007
-
1008
- __m512i second_idx =
1009
- _mm512_set_epi32(Offset - 1, Offset - 2, Offset - 3, Offset - 4,
1010
- Offset - 5, Offset - 6, Offset - 7, Offset - 8,
1011
- Offset - 9, Offset - 10, Offset - 11, Offset - 12,
1012
- Offset - 13, Offset - 14, Offset - 15, Offset - 16);
1013
-
1014
- unsigned short mask = 0xFFFF;
1015
- mask <<= (16 - Offset);
1016
-
1017
- first = _mm512_permutexvar_ps(first_idx, first);
1018
- Packet16f tmp = _mm512_permutexvar_ps(second_idx, second);
1019
- first = _mm512_mask_blend_ps(mask, first, tmp);
1020
- }
1021
- }
1022
- };
1023
- template <int Offset>
1024
- struct palign_impl<Offset, Packet8d> {
1025
- static EIGEN_STRONG_INLINE void run(Packet8d& first, const Packet8d& second) {
1026
- if (Offset != 0) {
1027
- __m512i first_idx = _mm512_set_epi32(
1028
- 0, Offset + 7, 0, Offset + 6, 0, Offset + 5, 0, Offset + 4, 0,
1029
- Offset + 3, 0, Offset + 2, 0, Offset + 1, 0, Offset);
1030
-
1031
- __m512i second_idx = _mm512_set_epi32(
1032
- 0, Offset - 1, 0, Offset - 2, 0, Offset - 3, 0, Offset - 4, 0,
1033
- Offset - 5, 0, Offset - 6, 0, Offset - 7, 0, Offset - 8);
1034
-
1035
- unsigned char mask = 0xFF;
1036
- mask <<= (8 - Offset);
1037
-
1038
- first = _mm512_permutexvar_pd(first_idx, first);
1039
- Packet8d tmp = _mm512_permutexvar_pd(second_idx, second);
1040
- first = _mm512_mask_blend_pd(mask, first, tmp);
1041
- }
1042
- }
1043
- };
1095
+ template<> EIGEN_STRONG_INLINE bool predux_any(const Packet16f& x)
1096
+ {
1097
+ Packet16i xi = _mm512_castps_si512(x);
1098
+ __mmask16 tmp = _mm512_test_epi32_mask(xi,xi);
1099
+ return !_mm512_kortestz(tmp,tmp);
1100
+ }
1101
+
1044
1102
 
1045
1103
 
1046
1104
  #define PACK_OUTPUT(OUTPUT, INPUT, INDEX, STRIDE) \
@@ -1302,11 +1360,940 @@ EIGEN_STRONG_INLINE Packet16f pblend(const Selector<16>& /*ifPacket*/,
1302
1360
  return Packet16f();
1303
1361
  }
1304
1362
  template <>
1305
- EIGEN_STRONG_INLINE Packet8d pblend(const Selector<8>& /*ifPacket*/,
1306
- const Packet8d& /*thenPacket*/,
1307
- const Packet8d& /*elsePacket*/) {
1308
- assert(false && "To be implemented");
1309
- return Packet8d();
1363
+ EIGEN_STRONG_INLINE Packet8d pblend(const Selector<8>& ifPacket,
1364
+ const Packet8d& thenPacket,
1365
+ const Packet8d& elsePacket) {
1366
+ __mmask8 m = (ifPacket.select[0] )
1367
+ | (ifPacket.select[1]<<1)
1368
+ | (ifPacket.select[2]<<2)
1369
+ | (ifPacket.select[3]<<3)
1370
+ | (ifPacket.select[4]<<4)
1371
+ | (ifPacket.select[5]<<5)
1372
+ | (ifPacket.select[6]<<6)
1373
+ | (ifPacket.select[7]<<7);
1374
+ return _mm512_mask_blend_pd(m, elsePacket, thenPacket);
1375
+ }
1376
+
1377
+ // Packet math for Eigen::half
1378
+ template<> EIGEN_STRONG_INLINE Packet16h pset1<Packet16h>(const Eigen::half& from) {
1379
+ return _mm256_set1_epi16(from.x);
1380
+ }
1381
+
1382
+ template<> EIGEN_STRONG_INLINE Eigen::half pfirst<Packet16h>(const Packet16h& from) {
1383
+ return half_impl::raw_uint16_to_half(static_cast<unsigned short>(_mm256_extract_epi16(from, 0)));
1384
+ }
1385
+
1386
+ template<> EIGEN_STRONG_INLINE Packet16h pload<Packet16h>(const Eigen::half* from) {
1387
+ return _mm256_load_si256(reinterpret_cast<const __m256i*>(from));
1388
+ }
1389
+
1390
+ template<> EIGEN_STRONG_INLINE Packet16h ploadu<Packet16h>(const Eigen::half* from) {
1391
+ return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(from));
1392
+ }
1393
+
1394
+ template<> EIGEN_STRONG_INLINE void pstore<half>(Eigen::half* to, const Packet16h& from) {
1395
+ // (void*) -> workaround clang warning:
1396
+ // cast from 'Eigen::half *' to '__m256i *' increases required alignment from 2 to 32
1397
+ _mm256_store_si256((__m256i*)(void*)to, from);
1398
+ }
1399
+
1400
+ template<> EIGEN_STRONG_INLINE void pstoreu<half>(Eigen::half* to, const Packet16h& from) {
1401
+ // (void*) -> workaround clang warning:
1402
+ // cast from 'Eigen::half *' to '__m256i *' increases required alignment from 2 to 32
1403
+ _mm256_storeu_si256((__m256i*)(void*)to, from);
1404
+ }
1405
+
1406
+ template<> EIGEN_STRONG_INLINE Packet16h
1407
+ ploaddup<Packet16h>(const Eigen::half* from) {
1408
+ unsigned short a = from[0].x;
1409
+ unsigned short b = from[1].x;
1410
+ unsigned short c = from[2].x;
1411
+ unsigned short d = from[3].x;
1412
+ unsigned short e = from[4].x;
1413
+ unsigned short f = from[5].x;
1414
+ unsigned short g = from[6].x;
1415
+ unsigned short h = from[7].x;
1416
+ return _mm256_set_epi16(h, h, g, g, f, f, e, e, d, d, c, c, b, b, a, a);
1417
+ }
1418
+
1419
+ template<> EIGEN_STRONG_INLINE Packet16h
1420
+ ploadquad(const Eigen::half* from) {
1421
+ unsigned short a = from[0].x;
1422
+ unsigned short b = from[1].x;
1423
+ unsigned short c = from[2].x;
1424
+ unsigned short d = from[3].x;
1425
+ return _mm256_set_epi16(d, d, d, d, c, c, c, c, b, b, b, b, a, a, a, a);
1426
+ }
1427
+
1428
+ EIGEN_STRONG_INLINE Packet16f half2float(const Packet16h& a) {
1429
+ #ifdef EIGEN_HAS_FP16_C
1430
+ return _mm512_cvtph_ps(a);
1431
+ #else
1432
+ EIGEN_ALIGN64 half aux[16];
1433
+ pstore(aux, a);
1434
+ float f0(aux[0]);
1435
+ float f1(aux[1]);
1436
+ float f2(aux[2]);
1437
+ float f3(aux[3]);
1438
+ float f4(aux[4]);
1439
+ float f5(aux[5]);
1440
+ float f6(aux[6]);
1441
+ float f7(aux[7]);
1442
+ float f8(aux[8]);
1443
+ float f9(aux[9]);
1444
+ float fa(aux[10]);
1445
+ float fb(aux[11]);
1446
+ float fc(aux[12]);
1447
+ float fd(aux[13]);
1448
+ float fe(aux[14]);
1449
+ float ff(aux[15]);
1450
+
1451
+ return _mm512_set_ps(
1452
+ ff, fe, fd, fc, fb, fa, f9, f8, f7, f6, f5, f4, f3, f2, f1, f0);
1453
+ #endif
1454
+ }
1455
+
1456
+ EIGEN_STRONG_INLINE Packet16h float2half(const Packet16f& a) {
1457
+ #ifdef EIGEN_HAS_FP16_C
1458
+ return _mm512_cvtps_ph(a, _MM_FROUND_TO_NEAREST_INT|_MM_FROUND_NO_EXC);
1459
+ #else
1460
+ EIGEN_ALIGN64 float aux[16];
1461
+ pstore(aux, a);
1462
+ half h0(aux[0]);
1463
+ half h1(aux[1]);
1464
+ half h2(aux[2]);
1465
+ half h3(aux[3]);
1466
+ half h4(aux[4]);
1467
+ half h5(aux[5]);
1468
+ half h6(aux[6]);
1469
+ half h7(aux[7]);
1470
+ half h8(aux[8]);
1471
+ half h9(aux[9]);
1472
+ half ha(aux[10]);
1473
+ half hb(aux[11]);
1474
+ half hc(aux[12]);
1475
+ half hd(aux[13]);
1476
+ half he(aux[14]);
1477
+ half hf(aux[15]);
1478
+
1479
+ return _mm256_set_epi16(
1480
+ hf.x, he.x, hd.x, hc.x, hb.x, ha.x, h9.x, h8.x,
1481
+ h7.x, h6.x, h5.x, h4.x, h3.x, h2.x, h1.x, h0.x);
1482
+ #endif
1483
+ }
1484
+
1485
+ template<> EIGEN_STRONG_INLINE Packet16h ptrue(const Packet16h& a) {
1486
+ return ptrue(Packet8i(a));
1487
+ }
1488
+
1489
+ template <>
1490
+ EIGEN_STRONG_INLINE Packet16h pabs(const Packet16h& a) {
1491
+ const __m256i sign_mask = _mm256_set1_epi16(static_cast<numext::uint16_t>(0x8000));
1492
+ return _mm256_andnot_si256(sign_mask, a);
1493
+ }
1494
+
1495
+ template <>
1496
+ EIGEN_STRONG_INLINE Packet16h pmin<Packet16h>(const Packet16h& a,
1497
+ const Packet16h& b) {
1498
+ return float2half(pmin<Packet16f>(half2float(a), half2float(b)));
1499
+ }
1500
+
1501
+ template <>
1502
+ EIGEN_STRONG_INLINE Packet16h pmax<Packet16h>(const Packet16h& a,
1503
+ const Packet16h& b) {
1504
+ return float2half(pmax<Packet16f>(half2float(a), half2float(b)));
1505
+ }
1506
+
1507
+ template <>
1508
+ EIGEN_STRONG_INLINE Packet16h plset<Packet16h>(const half& a) {
1509
+ return float2half(plset<Packet16f>(static_cast<float>(a)));
1510
+ }
1511
+
1512
+ template<> EIGEN_STRONG_INLINE Packet16h por(const Packet16h& a,const Packet16h& b) {
1513
+ // in some cases Packet8i is a wrapper around __m256i, so we need to
1514
+ // cast to Packet8i to call the correct overload.
1515
+ return por(Packet8i(a),Packet8i(b));
1516
+ }
1517
+ template<> EIGEN_STRONG_INLINE Packet16h pxor(const Packet16h& a,const Packet16h& b) {
1518
+ return pxor(Packet8i(a),Packet8i(b));
1519
+ }
1520
+ template<> EIGEN_STRONG_INLINE Packet16h pand(const Packet16h& a,const Packet16h& b) {
1521
+ return pand(Packet8i(a),Packet8i(b));
1522
+ }
1523
+ template<> EIGEN_STRONG_INLINE Packet16h pandnot(const Packet16h& a,const Packet16h& b) {
1524
+ return pandnot(Packet8i(a),Packet8i(b));
1525
+ }
1526
+
1527
+ template<> EIGEN_STRONG_INLINE Packet16h pselect(const Packet16h& mask, const Packet16h& a, const Packet16h& b) {
1528
+ return _mm256_blendv_epi8(b, a, mask);
1529
+ }
1530
+
1531
+ template<> EIGEN_STRONG_INLINE Packet16h pround<Packet16h>(const Packet16h& a) {
1532
+ return float2half(pround<Packet16f>(half2float(a)));
1533
+ }
1534
+
1535
+ template<> EIGEN_STRONG_INLINE Packet16h print<Packet16h>(const Packet16h& a) {
1536
+ return float2half(print<Packet16f>(half2float(a)));
1537
+ }
1538
+
1539
+ template<> EIGEN_STRONG_INLINE Packet16h pceil<Packet16h>(const Packet16h& a) {
1540
+ return float2half(pceil<Packet16f>(half2float(a)));
1541
+ }
1542
+
1543
+ template<> EIGEN_STRONG_INLINE Packet16h pfloor<Packet16h>(const Packet16h& a) {
1544
+ return float2half(pfloor<Packet16f>(half2float(a)));
1545
+ }
1546
+
1547
+ template<> EIGEN_STRONG_INLINE Packet16h pcmp_eq(const Packet16h& a,const Packet16h& b) {
1548
+ Packet16f af = half2float(a);
1549
+ Packet16f bf = half2float(b);
1550
+ return Pack32To16(pcmp_eq(af, bf));
1551
+ }
1552
+
1553
+ template<> EIGEN_STRONG_INLINE Packet16h pcmp_le(const Packet16h& a,const Packet16h& b) {
1554
+ return Pack32To16(pcmp_le(half2float(a), half2float(b)));
1555
+ }
1556
+
1557
+ template<> EIGEN_STRONG_INLINE Packet16h pcmp_lt(const Packet16h& a,const Packet16h& b) {
1558
+ return Pack32To16(pcmp_lt(half2float(a), half2float(b)));
1559
+ }
1560
+
1561
+ template<> EIGEN_STRONG_INLINE Packet16h pcmp_lt_or_nan(const Packet16h& a,const Packet16h& b) {
1562
+ return Pack32To16(pcmp_lt_or_nan(half2float(a), half2float(b)));
1563
+ }
1564
+
1565
+ template<> EIGEN_STRONG_INLINE Packet16h pconj(const Packet16h& a) { return a; }
1566
+
1567
+ template<> EIGEN_STRONG_INLINE Packet16h pnegate(const Packet16h& a) {
1568
+ Packet16h sign_mask = _mm256_set1_epi16(static_cast<unsigned short>(0x8000));
1569
+ return _mm256_xor_si256(a, sign_mask);
1570
+ }
1571
+
1572
+ template<> EIGEN_STRONG_INLINE Packet16h padd<Packet16h>(const Packet16h& a, const Packet16h& b) {
1573
+ Packet16f af = half2float(a);
1574
+ Packet16f bf = half2float(b);
1575
+ Packet16f rf = padd(af, bf);
1576
+ return float2half(rf);
1577
+ }
1578
+
1579
+ template<> EIGEN_STRONG_INLINE Packet16h psub<Packet16h>(const Packet16h& a, const Packet16h& b) {
1580
+ Packet16f af = half2float(a);
1581
+ Packet16f bf = half2float(b);
1582
+ Packet16f rf = psub(af, bf);
1583
+ return float2half(rf);
1584
+ }
1585
+
1586
+ template<> EIGEN_STRONG_INLINE Packet16h pmul<Packet16h>(const Packet16h& a, const Packet16h& b) {
1587
+ Packet16f af = half2float(a);
1588
+ Packet16f bf = half2float(b);
1589
+ Packet16f rf = pmul(af, bf);
1590
+ return float2half(rf);
1591
+ }
1592
+
1593
+ template<> EIGEN_STRONG_INLINE Packet16h pdiv<Packet16h>(const Packet16h& a, const Packet16h& b) {
1594
+ Packet16f af = half2float(a);
1595
+ Packet16f bf = half2float(b);
1596
+ Packet16f rf = pdiv(af, bf);
1597
+ return float2half(rf);
1598
+ }
1599
+
1600
+ template<> EIGEN_STRONG_INLINE half predux<Packet16h>(const Packet16h& from) {
1601
+ Packet16f from_float = half2float(from);
1602
+ return half(predux(from_float));
1603
+ }
1604
+
1605
+ template <>
1606
+ EIGEN_STRONG_INLINE Packet8h predux_half_dowto4<Packet16h>(const Packet16h& a) {
1607
+ Packet8h lane0 = _mm256_extractf128_si256(a, 0);
1608
+ Packet8h lane1 = _mm256_extractf128_si256(a, 1);
1609
+ return padd<Packet8h>(lane0, lane1);
1610
+ }
1611
+
1612
+ template<> EIGEN_STRONG_INLINE Eigen::half predux_max<Packet16h>(const Packet16h& a) {
1613
+ Packet16f af = half2float(a);
1614
+ float reduced = predux_max<Packet16f>(af);
1615
+ return Eigen::half(reduced);
1616
+ }
1617
+
1618
+ template<> EIGEN_STRONG_INLINE Eigen::half predux_min<Packet16h>(const Packet16h& a) {
1619
+ Packet16f af = half2float(a);
1620
+ float reduced = predux_min<Packet16f>(af);
1621
+ return Eigen::half(reduced);
1622
+ }
1623
+
1624
+ template<> EIGEN_STRONG_INLINE half predux_mul<Packet16h>(const Packet16h& from) {
1625
+ Packet16f from_float = half2float(from);
1626
+ return half(predux_mul(from_float));
1627
+ }
1628
+
1629
+ template<> EIGEN_STRONG_INLINE Packet16h preverse(const Packet16h& a)
1630
+ {
1631
+ __m128i m = _mm_setr_epi8(14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1);
1632
+ return _mm256_insertf128_si256(
1633
+ _mm256_castsi128_si256(_mm_shuffle_epi8(_mm256_extractf128_si256(a,1),m)),
1634
+ _mm_shuffle_epi8(_mm256_extractf128_si256(a,0),m), 1);
1635
+ }
1636
+
1637
+ template<> EIGEN_STRONG_INLINE Packet16h pgather<Eigen::half, Packet16h>(const Eigen::half* from, Index stride)
1638
+ {
1639
+ return _mm256_set_epi16(
1640
+ from[15*stride].x, from[14*stride].x, from[13*stride].x, from[12*stride].x,
1641
+ from[11*stride].x, from[10*stride].x, from[9*stride].x, from[8*stride].x,
1642
+ from[7*stride].x, from[6*stride].x, from[5*stride].x, from[4*stride].x,
1643
+ from[3*stride].x, from[2*stride].x, from[1*stride].x, from[0*stride].x);
1644
+ }
1645
+
1646
+ template<> EIGEN_STRONG_INLINE void pscatter<half, Packet16h>(half* to, const Packet16h& from, Index stride)
1647
+ {
1648
+ EIGEN_ALIGN64 half aux[16];
1649
+ pstore(aux, from);
1650
+ to[stride*0] = aux[0];
1651
+ to[stride*1] = aux[1];
1652
+ to[stride*2] = aux[2];
1653
+ to[stride*3] = aux[3];
1654
+ to[stride*4] = aux[4];
1655
+ to[stride*5] = aux[5];
1656
+ to[stride*6] = aux[6];
1657
+ to[stride*7] = aux[7];
1658
+ to[stride*8] = aux[8];
1659
+ to[stride*9] = aux[9];
1660
+ to[stride*10] = aux[10];
1661
+ to[stride*11] = aux[11];
1662
+ to[stride*12] = aux[12];
1663
+ to[stride*13] = aux[13];
1664
+ to[stride*14] = aux[14];
1665
+ to[stride*15] = aux[15];
1666
+ }
1667
+
1668
+ EIGEN_STRONG_INLINE void
1669
+ ptranspose(PacketBlock<Packet16h,16>& kernel) {
1670
+ __m256i a = kernel.packet[0];
1671
+ __m256i b = kernel.packet[1];
1672
+ __m256i c = kernel.packet[2];
1673
+ __m256i d = kernel.packet[3];
1674
+ __m256i e = kernel.packet[4];
1675
+ __m256i f = kernel.packet[5];
1676
+ __m256i g = kernel.packet[6];
1677
+ __m256i h = kernel.packet[7];
1678
+ __m256i i = kernel.packet[8];
1679
+ __m256i j = kernel.packet[9];
1680
+ __m256i k = kernel.packet[10];
1681
+ __m256i l = kernel.packet[11];
1682
+ __m256i m = kernel.packet[12];
1683
+ __m256i n = kernel.packet[13];
1684
+ __m256i o = kernel.packet[14];
1685
+ __m256i p = kernel.packet[15];
1686
+
1687
+ __m256i ab_07 = _mm256_unpacklo_epi16(a, b);
1688
+ __m256i cd_07 = _mm256_unpacklo_epi16(c, d);
1689
+ __m256i ef_07 = _mm256_unpacklo_epi16(e, f);
1690
+ __m256i gh_07 = _mm256_unpacklo_epi16(g, h);
1691
+ __m256i ij_07 = _mm256_unpacklo_epi16(i, j);
1692
+ __m256i kl_07 = _mm256_unpacklo_epi16(k, l);
1693
+ __m256i mn_07 = _mm256_unpacklo_epi16(m, n);
1694
+ __m256i op_07 = _mm256_unpacklo_epi16(o, p);
1695
+
1696
+ __m256i ab_8f = _mm256_unpackhi_epi16(a, b);
1697
+ __m256i cd_8f = _mm256_unpackhi_epi16(c, d);
1698
+ __m256i ef_8f = _mm256_unpackhi_epi16(e, f);
1699
+ __m256i gh_8f = _mm256_unpackhi_epi16(g, h);
1700
+ __m256i ij_8f = _mm256_unpackhi_epi16(i, j);
1701
+ __m256i kl_8f = _mm256_unpackhi_epi16(k, l);
1702
+ __m256i mn_8f = _mm256_unpackhi_epi16(m, n);
1703
+ __m256i op_8f = _mm256_unpackhi_epi16(o, p);
1704
+
1705
+ __m256i abcd_03 = _mm256_unpacklo_epi32(ab_07, cd_07);
1706
+ __m256i abcd_47 = _mm256_unpackhi_epi32(ab_07, cd_07);
1707
+ __m256i efgh_03 = _mm256_unpacklo_epi32(ef_07, gh_07);
1708
+ __m256i efgh_47 = _mm256_unpackhi_epi32(ef_07, gh_07);
1709
+ __m256i ijkl_03 = _mm256_unpacklo_epi32(ij_07, kl_07);
1710
+ __m256i ijkl_47 = _mm256_unpackhi_epi32(ij_07, kl_07);
1711
+ __m256i mnop_03 = _mm256_unpacklo_epi32(mn_07, op_07);
1712
+ __m256i mnop_47 = _mm256_unpackhi_epi32(mn_07, op_07);
1713
+
1714
+ __m256i abcd_8b = _mm256_unpacklo_epi32(ab_8f, cd_8f);
1715
+ __m256i abcd_cf = _mm256_unpackhi_epi32(ab_8f, cd_8f);
1716
+ __m256i efgh_8b = _mm256_unpacklo_epi32(ef_8f, gh_8f);
1717
+ __m256i efgh_cf = _mm256_unpackhi_epi32(ef_8f, gh_8f);
1718
+ __m256i ijkl_8b = _mm256_unpacklo_epi32(ij_8f, kl_8f);
1719
+ __m256i ijkl_cf = _mm256_unpackhi_epi32(ij_8f, kl_8f);
1720
+ __m256i mnop_8b = _mm256_unpacklo_epi32(mn_8f, op_8f);
1721
+ __m256i mnop_cf = _mm256_unpackhi_epi32(mn_8f, op_8f);
1722
+
1723
+ __m256i abcdefgh_01 = _mm256_unpacklo_epi64(abcd_03, efgh_03);
1724
+ __m256i abcdefgh_23 = _mm256_unpackhi_epi64(abcd_03, efgh_03);
1725
+ __m256i ijklmnop_01 = _mm256_unpacklo_epi64(ijkl_03, mnop_03);
1726
+ __m256i ijklmnop_23 = _mm256_unpackhi_epi64(ijkl_03, mnop_03);
1727
+ __m256i abcdefgh_45 = _mm256_unpacklo_epi64(abcd_47, efgh_47);
1728
+ __m256i abcdefgh_67 = _mm256_unpackhi_epi64(abcd_47, efgh_47);
1729
+ __m256i ijklmnop_45 = _mm256_unpacklo_epi64(ijkl_47, mnop_47);
1730
+ __m256i ijklmnop_67 = _mm256_unpackhi_epi64(ijkl_47, mnop_47);
1731
+ __m256i abcdefgh_89 = _mm256_unpacklo_epi64(abcd_8b, efgh_8b);
1732
+ __m256i abcdefgh_ab = _mm256_unpackhi_epi64(abcd_8b, efgh_8b);
1733
+ __m256i ijklmnop_89 = _mm256_unpacklo_epi64(ijkl_8b, mnop_8b);
1734
+ __m256i ijklmnop_ab = _mm256_unpackhi_epi64(ijkl_8b, mnop_8b);
1735
+ __m256i abcdefgh_cd = _mm256_unpacklo_epi64(abcd_cf, efgh_cf);
1736
+ __m256i abcdefgh_ef = _mm256_unpackhi_epi64(abcd_cf, efgh_cf);
1737
+ __m256i ijklmnop_cd = _mm256_unpacklo_epi64(ijkl_cf, mnop_cf);
1738
+ __m256i ijklmnop_ef = _mm256_unpackhi_epi64(ijkl_cf, mnop_cf);
1739
+
1740
+ // NOTE: no unpacklo/hi instr in this case, so using permute instr.
1741
+ __m256i a_p_0 = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x20);
1742
+ __m256i a_p_1 = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x20);
1743
+ __m256i a_p_2 = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x20);
1744
+ __m256i a_p_3 = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x20);
1745
+ __m256i a_p_4 = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x20);
1746
+ __m256i a_p_5 = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x20);
1747
+ __m256i a_p_6 = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x20);
1748
+ __m256i a_p_7 = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x20);
1749
+ __m256i a_p_8 = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x31);
1750
+ __m256i a_p_9 = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x31);
1751
+ __m256i a_p_a = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x31);
1752
+ __m256i a_p_b = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x31);
1753
+ __m256i a_p_c = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x31);
1754
+ __m256i a_p_d = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x31);
1755
+ __m256i a_p_e = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x31);
1756
+ __m256i a_p_f = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x31);
1757
+
1758
+ kernel.packet[0] = a_p_0;
1759
+ kernel.packet[1] = a_p_1;
1760
+ kernel.packet[2] = a_p_2;
1761
+ kernel.packet[3] = a_p_3;
1762
+ kernel.packet[4] = a_p_4;
1763
+ kernel.packet[5] = a_p_5;
1764
+ kernel.packet[6] = a_p_6;
1765
+ kernel.packet[7] = a_p_7;
1766
+ kernel.packet[8] = a_p_8;
1767
+ kernel.packet[9] = a_p_9;
1768
+ kernel.packet[10] = a_p_a;
1769
+ kernel.packet[11] = a_p_b;
1770
+ kernel.packet[12] = a_p_c;
1771
+ kernel.packet[13] = a_p_d;
1772
+ kernel.packet[14] = a_p_e;
1773
+ kernel.packet[15] = a_p_f;
1774
+ }
1775
+
1776
+ EIGEN_STRONG_INLINE void
1777
+ ptranspose(PacketBlock<Packet16h,8>& kernel) {
1778
+ EIGEN_ALIGN64 half in[8][16];
1779
+ pstore<half>(in[0], kernel.packet[0]);
1780
+ pstore<half>(in[1], kernel.packet[1]);
1781
+ pstore<half>(in[2], kernel.packet[2]);
1782
+ pstore<half>(in[3], kernel.packet[3]);
1783
+ pstore<half>(in[4], kernel.packet[4]);
1784
+ pstore<half>(in[5], kernel.packet[5]);
1785
+ pstore<half>(in[6], kernel.packet[6]);
1786
+ pstore<half>(in[7], kernel.packet[7]);
1787
+
1788
+ EIGEN_ALIGN64 half out[8][16];
1789
+
1790
+ for (int i = 0; i < 8; ++i) {
1791
+ for (int j = 0; j < 8; ++j) {
1792
+ out[i][j] = in[j][2*i];
1793
+ }
1794
+ for (int j = 0; j < 8; ++j) {
1795
+ out[i][j+8] = in[j][2*i+1];
1796
+ }
1797
+ }
1798
+
1799
+ kernel.packet[0] = pload<Packet16h>(out[0]);
1800
+ kernel.packet[1] = pload<Packet16h>(out[1]);
1801
+ kernel.packet[2] = pload<Packet16h>(out[2]);
1802
+ kernel.packet[3] = pload<Packet16h>(out[3]);
1803
+ kernel.packet[4] = pload<Packet16h>(out[4]);
1804
+ kernel.packet[5] = pload<Packet16h>(out[5]);
1805
+ kernel.packet[6] = pload<Packet16h>(out[6]);
1806
+ kernel.packet[7] = pload<Packet16h>(out[7]);
1807
+ }
1808
+
1809
+ EIGEN_STRONG_INLINE void
1810
+ ptranspose(PacketBlock<Packet16h,4>& kernel) {
1811
+ EIGEN_ALIGN64 half in[4][16];
1812
+ pstore<half>(in[0], kernel.packet[0]);
1813
+ pstore<half>(in[1], kernel.packet[1]);
1814
+ pstore<half>(in[2], kernel.packet[2]);
1815
+ pstore<half>(in[3], kernel.packet[3]);
1816
+
1817
+ EIGEN_ALIGN64 half out[4][16];
1818
+
1819
+ for (int i = 0; i < 4; ++i) {
1820
+ for (int j = 0; j < 4; ++j) {
1821
+ out[i][j] = in[j][4*i];
1822
+ }
1823
+ for (int j = 0; j < 4; ++j) {
1824
+ out[i][j+4] = in[j][4*i+1];
1825
+ }
1826
+ for (int j = 0; j < 4; ++j) {
1827
+ out[i][j+8] = in[j][4*i+2];
1828
+ }
1829
+ for (int j = 0; j < 4; ++j) {
1830
+ out[i][j+12] = in[j][4*i+3];
1831
+ }
1832
+ }
1833
+
1834
+ kernel.packet[0] = pload<Packet16h>(out[0]);
1835
+ kernel.packet[1] = pload<Packet16h>(out[1]);
1836
+ kernel.packet[2] = pload<Packet16h>(out[2]);
1837
+ kernel.packet[3] = pload<Packet16h>(out[3]);
1838
+ }
1839
+
1840
+ template <> struct is_arithmetic<Packet16bf> { enum { value = true }; };
1841
+
1842
+ template <>
1843
+ struct packet_traits<bfloat16> : default_packet_traits {
1844
+ typedef Packet16bf type;
1845
+ typedef Packet8bf half;
1846
+ enum {
1847
+ Vectorizable = 1,
1848
+ AlignedOnScalar = 1,
1849
+ size = 16,
1850
+ HasHalfPacket = 1,
1851
+ HasBlend = 0,
1852
+ HasInsert = 1,
1853
+ HasSin = EIGEN_FAST_MATH,
1854
+ HasCos = EIGEN_FAST_MATH,
1855
+ #if EIGEN_GNUC_AT_LEAST(5, 3) || (!EIGEN_COMP_GNUC_STRICT)
1856
+ #ifdef EIGEN_VECTORIZE_AVX512DQ
1857
+ HasLog = 1, // Currently fails test with bad accuracy.
1858
+ HasLog1p = 1,
1859
+ HasExpm1 = 1,
1860
+ HasNdtri = 1,
1861
+ HasBessel = 1,
1862
+ #endif
1863
+ HasExp = 1,
1864
+ HasSqrt = EIGEN_FAST_MATH,
1865
+ HasRsqrt = EIGEN_FAST_MATH,
1866
+ HasTanh = EIGEN_FAST_MATH,
1867
+ HasErf = EIGEN_FAST_MATH,
1868
+ #endif
1869
+ HasCmp = 1,
1870
+ HasDiv = 1
1871
+ };
1872
+ };
1873
+
1874
+ template <>
1875
+ struct unpacket_traits<Packet16bf>
1876
+ {
1877
+ typedef bfloat16 type;
1878
+ enum {size=16, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false};
1879
+ typedef Packet8bf half;
1880
+ };
1881
+
1882
+ template <>
1883
+ EIGEN_STRONG_INLINE Packet16bf pset1<Packet16bf>(const bfloat16& from) {
1884
+ return _mm256_set1_epi16(from.value);
1885
+ }
1886
+
1887
+ template <>
1888
+ EIGEN_STRONG_INLINE bfloat16 pfirst<Packet16bf>(const Packet16bf& from) {
1889
+ bfloat16 t;
1890
+ t.value = static_cast<unsigned short>(_mm256_extract_epi16(from, 0));
1891
+ return t;
1892
+ }
1893
+
1894
+ template <>
1895
+ EIGEN_STRONG_INLINE Packet16bf pload<Packet16bf>(const bfloat16* from) {
1896
+ return _mm256_load_si256(reinterpret_cast<const __m256i*>(from));
1897
+ }
1898
+
1899
+ template <>
1900
+ EIGEN_STRONG_INLINE Packet16bf ploadu<Packet16bf>(const bfloat16* from) {
1901
+ return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(from));
1902
+ }
1903
+
1904
+ template <>
1905
+ EIGEN_STRONG_INLINE void pstore<bfloat16>(bfloat16* to,
1906
+ const Packet16bf& from) {
1907
+ _mm256_store_si256(reinterpret_cast<__m256i*>(to), from);
1908
+ }
1909
+
1910
+ template <>
1911
+ EIGEN_STRONG_INLINE void pstoreu<bfloat16>(bfloat16* to,
1912
+ const Packet16bf& from) {
1913
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from);
1914
+ }
1915
+
1916
+ template<> EIGEN_STRONG_INLINE Packet16bf
1917
+ ploaddup<Packet16bf>(const bfloat16* from) {
1918
+ Packet16bf r;
1919
+ unsigned short a = from[0].value;
1920
+ unsigned short b = from[1].value;
1921
+ unsigned short c = from[2].value;
1922
+ unsigned short d = from[3].value;
1923
+ unsigned short e = from[4].value;
1924
+ unsigned short f = from[5].value;
1925
+ unsigned short g = from[6].value;
1926
+ unsigned short h = from[7].value;
1927
+ return _mm256_set_epi16(h, h, g, g, f, f, e, e, d, d, c, c, b, b, a, a);
1928
+ }
1929
+
1930
+ template<> EIGEN_STRONG_INLINE Packet16bf
1931
+ ploadquad(const bfloat16* from) {
1932
+ Packet16bf r;
1933
+ unsigned short a = from[0].value;
1934
+ unsigned short b = from[1].value;
1935
+ unsigned short c = from[2].value;
1936
+ unsigned short d = from[3].value;
1937
+ return _mm256_set_epi16(d, d, d, d, c, c, c, c, b, b, b, b, a, a, a, a);
1938
+ }
1939
+
1940
+ EIGEN_STRONG_INLINE Packet16f Bf16ToF32(const Packet16bf& a) {
1941
+ return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16));
1942
+ }
1943
+
1944
+ // Convert float to bfloat16 according to round-to-nearest-even/denormals algorithm.
1945
+ EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) {
1946
+ Packet16bf r;
1947
+
1948
+ #if defined(EIGEN_VECTORIZE_AVX512BF16) && EIGEN_GNUC_AT_LEAST(10, 1)
1949
+ // Since GCC 10.1 supports avx512bf16 and C style explicit cast
1950
+ // (C++ static_cast is not supported yet), do converion via intrinsic
1951
+ // and register path for performance.
1952
+ r = (__m256i)(_mm512_cvtneps_pbh(a));
1953
+
1954
+ #else
1955
+ __m512i t;
1956
+ __m512i input = _mm512_castps_si512(a);
1957
+ __m512i nan = _mm512_set1_epi32(0x7fc0);
1958
+
1959
+ // uint32_t lsb = (input >> 16) & 1;
1960
+ t = _mm512_and_si512(_mm512_srli_epi32(input, 16), _mm512_set1_epi32(1));
1961
+ // uint32_t rounding_bias = 0x7fff + lsb;
1962
+ t = _mm512_add_epi32(t, _mm512_set1_epi32(0x7fff));
1963
+ // input += rounding_bias;
1964
+ t = _mm512_add_epi32(t, input);
1965
+ // input = input >> 16;
1966
+ t = _mm512_srli_epi32(t, 16);
1967
+
1968
+ // Check NaN before converting back to bf16
1969
+ __mmask16 mask = _mm512_cmp_ps_mask(a, a, _CMP_ORD_Q);
1970
+
1971
+ t = _mm512_mask_blend_epi32(mask, nan, t);
1972
+ // output.value = static_cast<uint16_t>(input);
1973
+ r = _mm512_cvtepi32_epi16(t);
1974
+ #endif // EIGEN_VECTORIZE_AVX512BF16
1975
+
1976
+ return r;
1977
+ }
1978
+
1979
+ template <>
1980
+ EIGEN_STRONG_INLINE Packet16bf ptrue(const Packet16bf& a) {
1981
+ return ptrue<Packet8i>(a);
1982
+ }
1983
+
1984
+ template <>
1985
+ EIGEN_STRONG_INLINE Packet16bf por(const Packet16bf& a, const Packet16bf& b) {
1986
+ return por<Packet8i>(a, b);
1987
+ }
1988
+
1989
+ template <>
1990
+ EIGEN_STRONG_INLINE Packet16bf pxor(const Packet16bf& a, const Packet16bf& b) {
1991
+ return pxor<Packet8i>(a, b);
1992
+ }
1993
+
1994
+ template <>
1995
+ EIGEN_STRONG_INLINE Packet16bf pand(const Packet16bf& a, const Packet16bf& b) {
1996
+ return pand<Packet8i>(a, b);
1997
+ }
1998
+
1999
+ template <>
2000
+ EIGEN_STRONG_INLINE Packet16bf pandnot(const Packet16bf& a,
2001
+ const Packet16bf& b) {
2002
+ return pandnot<Packet8i>(a, b);
2003
+ }
2004
+
2005
+ template <>
2006
+ EIGEN_STRONG_INLINE Packet16bf pselect(const Packet16bf& mask,
2007
+ const Packet16bf& a,
2008
+ const Packet16bf& b) {
2009
+ // Input mask is expected to be all 0/1, handle it with 8-bit
2010
+ // intrinsic for performance.
2011
+ return _mm256_blendv_epi8(b, a, mask);
2012
+ }
2013
+
2014
+ template<> EIGEN_STRONG_INLINE Packet16bf pround<Packet16bf>(const Packet16bf& a)
2015
+ {
2016
+ return F32ToBf16(pround<Packet16f>(Bf16ToF32(a)));
2017
+ }
2018
+
2019
+ template<> EIGEN_STRONG_INLINE Packet16bf print<Packet16bf>(const Packet16bf& a) {
2020
+ return F32ToBf16(print<Packet16f>(Bf16ToF32(a)));
2021
+ }
2022
+
2023
+ template<> EIGEN_STRONG_INLINE Packet16bf pceil<Packet16bf>(const Packet16bf& a) {
2024
+ return F32ToBf16(pceil<Packet16f>(Bf16ToF32(a)));
2025
+ }
2026
+
2027
+ template<> EIGEN_STRONG_INLINE Packet16bf pfloor<Packet16bf>(const Packet16bf& a) {
2028
+ return F32ToBf16(pfloor<Packet16f>(Bf16ToF32(a)));
2029
+ }
2030
+
2031
+ template <>
2032
+ EIGEN_STRONG_INLINE Packet16bf pcmp_eq(const Packet16bf& a,
2033
+ const Packet16bf& b) {
2034
+ return Pack32To16(pcmp_eq(Bf16ToF32(a), Bf16ToF32(b)));
2035
+ }
2036
+
2037
+ template <>
2038
+ EIGEN_STRONG_INLINE Packet16bf pcmp_le(const Packet16bf& a,
2039
+ const Packet16bf& b) {
2040
+ return Pack32To16(pcmp_le(Bf16ToF32(a), Bf16ToF32(b)));
2041
+ }
2042
+
2043
+ template <>
2044
+ EIGEN_STRONG_INLINE Packet16bf pcmp_lt(const Packet16bf& a,
2045
+ const Packet16bf& b) {
2046
+ return Pack32To16(pcmp_lt(Bf16ToF32(a), Bf16ToF32(b)));
2047
+ }
2048
+
2049
+ template <>
2050
+ EIGEN_STRONG_INLINE Packet16bf pcmp_lt_or_nan(const Packet16bf& a,
2051
+ const Packet16bf& b) {
2052
+ return Pack32To16(pcmp_lt_or_nan(Bf16ToF32(a), Bf16ToF32(b)));
2053
+ }
2054
+
2055
+ template <>
2056
+ EIGEN_STRONG_INLINE Packet16bf pnegate(const Packet16bf& a) {
2057
+ Packet16bf sign_mask = _mm256_set1_epi16(static_cast<unsigned short>(0x8000));
2058
+ return _mm256_xor_si256(a, sign_mask);
2059
+ }
2060
+
2061
+ template <>
2062
+ EIGEN_STRONG_INLINE Packet16bf pconj(const Packet16bf& a) {
2063
+ return a;
2064
+ }
2065
+
2066
+ template <>
2067
+ EIGEN_STRONG_INLINE Packet16bf pabs(const Packet16bf& a) {
2068
+ const __m256i sign_mask = _mm256_set1_epi16(static_cast<numext::uint16_t>(0x8000));
2069
+ return _mm256_andnot_si256(sign_mask, a);
2070
+ }
2071
+
2072
+ template <>
2073
+ EIGEN_STRONG_INLINE Packet16bf padd<Packet16bf>(const Packet16bf& a,
2074
+ const Packet16bf& b) {
2075
+ return F32ToBf16(padd<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
2076
+ }
2077
+
2078
+ template <>
2079
+ EIGEN_STRONG_INLINE Packet16bf psub<Packet16bf>(const Packet16bf& a,
2080
+ const Packet16bf& b) {
2081
+ return F32ToBf16(psub<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
2082
+ }
2083
+
2084
+ template <>
2085
+ EIGEN_STRONG_INLINE Packet16bf pmul<Packet16bf>(const Packet16bf& a,
2086
+ const Packet16bf& b) {
2087
+ return F32ToBf16(pmul<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
2088
+ }
2089
+
2090
+ template <>
2091
+ EIGEN_STRONG_INLINE Packet16bf pdiv<Packet16bf>(const Packet16bf& a,
2092
+ const Packet16bf& b) {
2093
+ return F32ToBf16(pdiv<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
2094
+ }
2095
+
2096
+ template <>
2097
+ EIGEN_STRONG_INLINE Packet16bf pmin<Packet16bf>(const Packet16bf& a,
2098
+ const Packet16bf& b) {
2099
+ return F32ToBf16(pmin<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
2100
+ }
2101
+
2102
+ template <>
2103
+ EIGEN_STRONG_INLINE Packet16bf pmax<Packet16bf>(const Packet16bf& a,
2104
+ const Packet16bf& b) {
2105
+ return F32ToBf16(pmax<Packet16f>(Bf16ToF32(a), Bf16ToF32(b)));
2106
+ }
2107
+
2108
+ template <>
2109
+ EIGEN_STRONG_INLINE Packet16bf plset<Packet16bf>(const bfloat16& a) {
2110
+ return F32ToBf16(plset<Packet16f>(static_cast<float>(a)));
2111
+ }
2112
+
2113
+ template <>
2114
+ EIGEN_STRONG_INLINE Packet8bf predux_half_dowto4<Packet16bf>(const Packet16bf& a) {
2115
+ Packet8bf lane0 = _mm256_extractf128_si256(a, 0);
2116
+ Packet8bf lane1 = _mm256_extractf128_si256(a, 1);
2117
+ return padd<Packet8bf>(lane0, lane1);
2118
+ }
2119
+
2120
+ template <>
2121
+ EIGEN_STRONG_INLINE bfloat16 predux<Packet16bf>(const Packet16bf& p) {
2122
+ return static_cast<bfloat16>(predux<Packet16f>(Bf16ToF32(p)));
2123
+ }
2124
+
2125
+ template <>
2126
+ EIGEN_STRONG_INLINE bfloat16 predux_mul<Packet16bf>(const Packet16bf& from) {
2127
+ return static_cast<bfloat16>(predux_mul<Packet16f>(Bf16ToF32(from)));
2128
+ }
2129
+
2130
+ template <>
2131
+ EIGEN_STRONG_INLINE bfloat16 predux_min<Packet16bf>(const Packet16bf& from) {
2132
+ return static_cast<bfloat16>(predux_min<Packet16f>(Bf16ToF32(from)));
2133
+ }
2134
+
2135
+ template <>
2136
+ EIGEN_STRONG_INLINE bfloat16 predux_max<Packet16bf>(const Packet16bf& from) {
2137
+ return static_cast<bfloat16>(predux_max<Packet16f>(Bf16ToF32(from)));
2138
+ }
2139
+
2140
+ template <>
2141
+ EIGEN_STRONG_INLINE Packet16bf preverse(const Packet16bf& a) {
2142
+ __m256i m = _mm256_setr_epi8(14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1,
2143
+ 14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1);
2144
+
2145
+ Packet16bf res;
2146
+ // Swap hi and lo first because shuffle is in 128-bit lanes.
2147
+ res = _mm256_permute2x128_si256(a, a, 1);
2148
+ // Shuffle 8-bit values in src within 2*128-bit lanes.
2149
+ return _mm256_shuffle_epi8(res, m);
2150
+ }
2151
+
2152
+ template <>
2153
+ EIGEN_STRONG_INLINE Packet16bf pgather<bfloat16, Packet16bf>(const bfloat16* from,
2154
+ Index stride) {
2155
+ return _mm256_set_epi16(
2156
+ from[15*stride].value, from[14*stride].value, from[13*stride].value, from[12*stride].value,
2157
+ from[11*stride].value, from[10*stride].value, from[9*stride].value, from[8*stride].value,
2158
+ from[7*stride].value, from[6*stride].value, from[5*stride].value, from[4*stride].value,
2159
+ from[3*stride].value, from[2*stride].value, from[1*stride].value, from[0*stride].value);
2160
+ }
2161
+
2162
+ template <>
2163
+ EIGEN_STRONG_INLINE void pscatter<bfloat16, Packet16bf>(bfloat16* to,
2164
+ const Packet16bf& from,
2165
+ Index stride) {
2166
+ EIGEN_ALIGN64 bfloat16 aux[16];
2167
+ pstore(aux, from);
2168
+ to[stride*0] = aux[0];
2169
+ to[stride*1] = aux[1];
2170
+ to[stride*2] = aux[2];
2171
+ to[stride*3] = aux[3];
2172
+ to[stride*4] = aux[4];
2173
+ to[stride*5] = aux[5];
2174
+ to[stride*6] = aux[6];
2175
+ to[stride*7] = aux[7];
2176
+ to[stride*8] = aux[8];
2177
+ to[stride*9] = aux[9];
2178
+ to[stride*10] = aux[10];
2179
+ to[stride*11] = aux[11];
2180
+ to[stride*12] = aux[12];
2181
+ to[stride*13] = aux[13];
2182
+ to[stride*14] = aux[14];
2183
+ to[stride*15] = aux[15];
2184
+ }
2185
+
2186
+ EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,16>& kernel) {
2187
+ __m256i a = kernel.packet[0];
2188
+ __m256i b = kernel.packet[1];
2189
+ __m256i c = kernel.packet[2];
2190
+ __m256i d = kernel.packet[3];
2191
+ __m256i e = kernel.packet[4];
2192
+ __m256i f = kernel.packet[5];
2193
+ __m256i g = kernel.packet[6];
2194
+ __m256i h = kernel.packet[7];
2195
+ __m256i i = kernel.packet[8];
2196
+ __m256i j = kernel.packet[9];
2197
+ __m256i k = kernel.packet[10];
2198
+ __m256i l = kernel.packet[11];
2199
+ __m256i m = kernel.packet[12];
2200
+ __m256i n = kernel.packet[13];
2201
+ __m256i o = kernel.packet[14];
2202
+ __m256i p = kernel.packet[15];
2203
+
2204
+ __m256i ab_07 = _mm256_unpacklo_epi16(a, b);
2205
+ __m256i cd_07 = _mm256_unpacklo_epi16(c, d);
2206
+ __m256i ef_07 = _mm256_unpacklo_epi16(e, f);
2207
+ __m256i gh_07 = _mm256_unpacklo_epi16(g, h);
2208
+ __m256i ij_07 = _mm256_unpacklo_epi16(i, j);
2209
+ __m256i kl_07 = _mm256_unpacklo_epi16(k, l);
2210
+ __m256i mn_07 = _mm256_unpacklo_epi16(m, n);
2211
+ __m256i op_07 = _mm256_unpacklo_epi16(o, p);
2212
+
2213
+ __m256i ab_8f = _mm256_unpackhi_epi16(a, b);
2214
+ __m256i cd_8f = _mm256_unpackhi_epi16(c, d);
2215
+ __m256i ef_8f = _mm256_unpackhi_epi16(e, f);
2216
+ __m256i gh_8f = _mm256_unpackhi_epi16(g, h);
2217
+ __m256i ij_8f = _mm256_unpackhi_epi16(i, j);
2218
+ __m256i kl_8f = _mm256_unpackhi_epi16(k, l);
2219
+ __m256i mn_8f = _mm256_unpackhi_epi16(m, n);
2220
+ __m256i op_8f = _mm256_unpackhi_epi16(o, p);
2221
+
2222
+ __m256i abcd_03 = _mm256_unpacklo_epi32(ab_07, cd_07);
2223
+ __m256i abcd_47 = _mm256_unpackhi_epi32(ab_07, cd_07);
2224
+ __m256i efgh_03 = _mm256_unpacklo_epi32(ef_07, gh_07);
2225
+ __m256i efgh_47 = _mm256_unpackhi_epi32(ef_07, gh_07);
2226
+ __m256i ijkl_03 = _mm256_unpacklo_epi32(ij_07, kl_07);
2227
+ __m256i ijkl_47 = _mm256_unpackhi_epi32(ij_07, kl_07);
2228
+ __m256i mnop_03 = _mm256_unpacklo_epi32(mn_07, op_07);
2229
+ __m256i mnop_47 = _mm256_unpackhi_epi32(mn_07, op_07);
2230
+
2231
+ __m256i abcd_8b = _mm256_unpacklo_epi32(ab_8f, cd_8f);
2232
+ __m256i abcd_cf = _mm256_unpackhi_epi32(ab_8f, cd_8f);
2233
+ __m256i efgh_8b = _mm256_unpacklo_epi32(ef_8f, gh_8f);
2234
+ __m256i efgh_cf = _mm256_unpackhi_epi32(ef_8f, gh_8f);
2235
+ __m256i ijkl_8b = _mm256_unpacklo_epi32(ij_8f, kl_8f);
2236
+ __m256i ijkl_cf = _mm256_unpackhi_epi32(ij_8f, kl_8f);
2237
+ __m256i mnop_8b = _mm256_unpacklo_epi32(mn_8f, op_8f);
2238
+ __m256i mnop_cf = _mm256_unpackhi_epi32(mn_8f, op_8f);
2239
+
2240
+ __m256i abcdefgh_01 = _mm256_unpacklo_epi64(abcd_03, efgh_03);
2241
+ __m256i abcdefgh_23 = _mm256_unpackhi_epi64(abcd_03, efgh_03);
2242
+ __m256i ijklmnop_01 = _mm256_unpacklo_epi64(ijkl_03, mnop_03);
2243
+ __m256i ijklmnop_23 = _mm256_unpackhi_epi64(ijkl_03, mnop_03);
2244
+ __m256i abcdefgh_45 = _mm256_unpacklo_epi64(abcd_47, efgh_47);
2245
+ __m256i abcdefgh_67 = _mm256_unpackhi_epi64(abcd_47, efgh_47);
2246
+ __m256i ijklmnop_45 = _mm256_unpacklo_epi64(ijkl_47, mnop_47);
2247
+ __m256i ijklmnop_67 = _mm256_unpackhi_epi64(ijkl_47, mnop_47);
2248
+ __m256i abcdefgh_89 = _mm256_unpacklo_epi64(abcd_8b, efgh_8b);
2249
+ __m256i abcdefgh_ab = _mm256_unpackhi_epi64(abcd_8b, efgh_8b);
2250
+ __m256i ijklmnop_89 = _mm256_unpacklo_epi64(ijkl_8b, mnop_8b);
2251
+ __m256i ijklmnop_ab = _mm256_unpackhi_epi64(ijkl_8b, mnop_8b);
2252
+ __m256i abcdefgh_cd = _mm256_unpacklo_epi64(abcd_cf, efgh_cf);
2253
+ __m256i abcdefgh_ef = _mm256_unpackhi_epi64(abcd_cf, efgh_cf);
2254
+ __m256i ijklmnop_cd = _mm256_unpacklo_epi64(ijkl_cf, mnop_cf);
2255
+ __m256i ijklmnop_ef = _mm256_unpackhi_epi64(ijkl_cf, mnop_cf);
2256
+
2257
+ // NOTE: no unpacklo/hi instr in this case, so using permute instr.
2258
+ kernel.packet[0] = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x20);
2259
+ kernel.packet[1] = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x20);
2260
+ kernel.packet[2] = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x20);
2261
+ kernel.packet[3] = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x20);
2262
+ kernel.packet[4] = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x20);
2263
+ kernel.packet[5] = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x20);
2264
+ kernel.packet[6] = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x20);
2265
+ kernel.packet[7] = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x20);
2266
+ kernel.packet[8] = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x31);
2267
+ kernel.packet[9] = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x31);
2268
+ kernel.packet[10] = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x31);
2269
+ kernel.packet[11] = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x31);
2270
+ kernel.packet[12] = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x31);
2271
+ kernel.packet[13] = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x31);
2272
+ kernel.packet[14] = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x31);
2273
+ kernel.packet[15] = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x31);
2274
+ }
2275
+
2276
+ EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf,4>& kernel) {
2277
+ __m256i a = kernel.packet[0];
2278
+ __m256i b = kernel.packet[1];
2279
+ __m256i c = kernel.packet[2];
2280
+ __m256i d = kernel.packet[3];
2281
+
2282
+ __m256i ab_07 = _mm256_unpacklo_epi16(a, b);
2283
+ __m256i cd_07 = _mm256_unpacklo_epi16(c, d);
2284
+ __m256i ab_8f = _mm256_unpackhi_epi16(a, b);
2285
+ __m256i cd_8f = _mm256_unpackhi_epi16(c, d);
2286
+
2287
+ __m256i abcd_03 = _mm256_unpacklo_epi32(ab_07, cd_07);
2288
+ __m256i abcd_47 = _mm256_unpackhi_epi32(ab_07, cd_07);
2289
+ __m256i abcd_8b = _mm256_unpacklo_epi32(ab_8f, cd_8f);
2290
+ __m256i abcd_cf = _mm256_unpackhi_epi32(ab_8f, cd_8f);
2291
+
2292
+ // NOTE: no unpacklo/hi instr in this case, so using permute instr.
2293
+ kernel.packet[0] = _mm256_permute2x128_si256(abcd_03, abcd_47, 0x20);
2294
+ kernel.packet[1] = _mm256_permute2x128_si256(abcd_8b, abcd_cf, 0x20);
2295
+ kernel.packet[2] = _mm256_permute2x128_si256(abcd_03, abcd_47, 0x31);
2296
+ kernel.packet[3] = _mm256_permute2x128_si256(abcd_8b, abcd_cf, 0x31);
1310
2297
  }
1311
2298
 
1312
2299
  } // end namespace internal