@smake/eigen 1.0.2 → 1.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (435)
  1. package/README.md +1 -1
  2. package/eigen/Eigen/AccelerateSupport +52 -0
  3. package/eigen/Eigen/Cholesky +18 -21
  4. package/eigen/Eigen/CholmodSupport +28 -28
  5. package/eigen/Eigen/Core +235 -326
  6. package/eigen/Eigen/Eigenvalues +16 -14
  7. package/eigen/Eigen/Geometry +21 -24
  8. package/eigen/Eigen/Householder +9 -8
  9. package/eigen/Eigen/IterativeLinearSolvers +8 -4
  10. package/eigen/Eigen/Jacobi +14 -14
  11. package/eigen/Eigen/KLUSupport +43 -0
  12. package/eigen/Eigen/LU +16 -20
  13. package/eigen/Eigen/MetisSupport +12 -12
  14. package/eigen/Eigen/OrderingMethods +54 -54
  15. package/eigen/Eigen/PaStiXSupport +23 -20
  16. package/eigen/Eigen/PardisoSupport +17 -14
  17. package/eigen/Eigen/QR +18 -21
  18. package/eigen/Eigen/QtAlignedMalloc +5 -13
  19. package/eigen/Eigen/SPQRSupport +21 -14
  20. package/eigen/Eigen/SVD +23 -18
  21. package/eigen/Eigen/Sparse +1 -4
  22. package/eigen/Eigen/SparseCholesky +18 -23
  23. package/eigen/Eigen/SparseCore +18 -17
  24. package/eigen/Eigen/SparseLU +12 -8
  25. package/eigen/Eigen/SparseQR +16 -14
  26. package/eigen/Eigen/StdDeque +5 -2
  27. package/eigen/Eigen/StdList +5 -2
  28. package/eigen/Eigen/StdVector +5 -2
  29. package/eigen/Eigen/SuperLUSupport +30 -24
  30. package/eigen/Eigen/ThreadPool +80 -0
  31. package/eigen/Eigen/UmfPackSupport +19 -17
  32. package/eigen/Eigen/Version +14 -0
  33. package/eigen/Eigen/src/AccelerateSupport/AccelerateSupport.h +423 -0
  34. package/eigen/Eigen/src/AccelerateSupport/InternalHeaderCheck.h +3 -0
  35. package/eigen/Eigen/src/Cholesky/InternalHeaderCheck.h +3 -0
  36. package/eigen/Eigen/src/Cholesky/LDLT.h +377 -401
  37. package/eigen/Eigen/src/Cholesky/LLT.h +332 -360
  38. package/eigen/Eigen/src/Cholesky/LLT_LAPACKE.h +81 -56
  39. package/eigen/Eigen/src/CholmodSupport/CholmodSupport.h +620 -521
  40. package/eigen/Eigen/src/CholmodSupport/InternalHeaderCheck.h +3 -0
  41. package/eigen/Eigen/src/Core/ArithmeticSequence.h +239 -0
  42. package/eigen/Eigen/src/Core/Array.h +341 -294
  43. package/eigen/Eigen/src/Core/ArrayBase.h +190 -203
  44. package/eigen/Eigen/src/Core/ArrayWrapper.h +127 -171
  45. package/eigen/Eigen/src/Core/Assign.h +30 -40
  46. package/eigen/Eigen/src/Core/AssignEvaluator.h +711 -589
  47. package/eigen/Eigen/src/Core/Assign_MKL.h +130 -125
  48. package/eigen/Eigen/src/Core/BandMatrix.h +268 -283
  49. package/eigen/Eigen/src/Core/Block.h +375 -398
  50. package/eigen/Eigen/src/Core/CommaInitializer.h +86 -97
  51. package/eigen/Eigen/src/Core/ConditionEstimator.h +51 -53
  52. package/eigen/Eigen/src/Core/CoreEvaluators.h +1356 -1026
  53. package/eigen/Eigen/src/Core/CoreIterators.h +73 -59
  54. package/eigen/Eigen/src/Core/CwiseBinaryOp.h +114 -132
  55. package/eigen/Eigen/src/Core/CwiseNullaryOp.h +726 -617
  56. package/eigen/Eigen/src/Core/CwiseTernaryOp.h +77 -103
  57. package/eigen/Eigen/src/Core/CwiseUnaryOp.h +56 -68
  58. package/eigen/Eigen/src/Core/CwiseUnaryView.h +132 -95
  59. package/eigen/Eigen/src/Core/DenseBase.h +632 -571
  60. package/eigen/Eigen/src/Core/DenseCoeffsBase.h +511 -624
  61. package/eigen/Eigen/src/Core/DenseStorage.h +512 -509
  62. package/eigen/Eigen/src/Core/DeviceWrapper.h +153 -0
  63. package/eigen/Eigen/src/Core/Diagonal.h +169 -210
  64. package/eigen/Eigen/src/Core/DiagonalMatrix.h +351 -274
  65. package/eigen/Eigen/src/Core/DiagonalProduct.h +12 -10
  66. package/eigen/Eigen/src/Core/Dot.h +172 -222
  67. package/eigen/Eigen/src/Core/EigenBase.h +75 -85
  68. package/eigen/Eigen/src/Core/Fill.h +138 -0
  69. package/eigen/Eigen/src/Core/FindCoeff.h +464 -0
  70. package/eigen/Eigen/src/Core/ForceAlignedAccess.h +90 -109
  71. package/eigen/Eigen/src/Core/Fuzzy.h +82 -105
  72. package/eigen/Eigen/src/Core/GeneralProduct.h +327 -263
  73. package/eigen/Eigen/src/Core/GenericPacketMath.h +1472 -360
  74. package/eigen/Eigen/src/Core/GlobalFunctions.h +194 -151
  75. package/eigen/Eigen/src/Core/IO.h +147 -139
  76. package/eigen/Eigen/src/Core/IndexedView.h +321 -0
  77. package/eigen/Eigen/src/Core/InnerProduct.h +260 -0
  78. package/eigen/Eigen/src/Core/InternalHeaderCheck.h +3 -0
  79. package/eigen/Eigen/src/Core/Inverse.h +56 -66
  80. package/eigen/Eigen/src/Core/Map.h +124 -142
  81. package/eigen/Eigen/src/Core/MapBase.h +256 -281
  82. package/eigen/Eigen/src/Core/MathFunctions.h +1620 -938
  83. package/eigen/Eigen/src/Core/MathFunctionsImpl.h +233 -71
  84. package/eigen/Eigen/src/Core/Matrix.h +491 -416
  85. package/eigen/Eigen/src/Core/MatrixBase.h +468 -453
  86. package/eigen/Eigen/src/Core/NestByValue.h +66 -85
  87. package/eigen/Eigen/src/Core/NoAlias.h +79 -85
  88. package/eigen/Eigen/src/Core/NumTraits.h +235 -148
  89. package/eigen/Eigen/src/Core/PartialReduxEvaluator.h +253 -0
  90. package/eigen/Eigen/src/Core/PermutationMatrix.h +461 -511
  91. package/eigen/Eigen/src/Core/PlainObjectBase.h +871 -894
  92. package/eigen/Eigen/src/Core/Product.h +260 -139
  93. package/eigen/Eigen/src/Core/ProductEvaluators.h +863 -714
  94. package/eigen/Eigen/src/Core/Random.h +161 -136
  95. package/eigen/Eigen/src/Core/RandomImpl.h +262 -0
  96. package/eigen/Eigen/src/Core/RealView.h +250 -0
  97. package/eigen/Eigen/src/Core/Redux.h +366 -336
  98. package/eigen/Eigen/src/Core/Ref.h +308 -209
  99. package/eigen/Eigen/src/Core/Replicate.h +94 -106
  100. package/eigen/Eigen/src/Core/Reshaped.h +398 -0
  101. package/eigen/Eigen/src/Core/ReturnByValue.h +49 -55
  102. package/eigen/Eigen/src/Core/Reverse.h +136 -145
  103. package/eigen/Eigen/src/Core/Select.h +70 -140
  104. package/eigen/Eigen/src/Core/SelfAdjointView.h +262 -285
  105. package/eigen/Eigen/src/Core/SelfCwiseBinaryOp.h +23 -20
  106. package/eigen/Eigen/src/Core/SkewSymmetricMatrix3.h +382 -0
  107. package/eigen/Eigen/src/Core/Solve.h +97 -111
  108. package/eigen/Eigen/src/Core/SolveTriangular.h +131 -129
  109. package/eigen/Eigen/src/Core/SolverBase.h +138 -101
  110. package/eigen/Eigen/src/Core/StableNorm.h +156 -160
  111. package/eigen/Eigen/src/Core/StlIterators.h +619 -0
  112. package/eigen/Eigen/src/Core/Stride.h +91 -88
  113. package/eigen/Eigen/src/Core/Swap.h +70 -38
  114. package/eigen/Eigen/src/Core/Transpose.h +295 -273
  115. package/eigen/Eigen/src/Core/Transpositions.h +272 -317
  116. package/eigen/Eigen/src/Core/TriangularMatrix.h +670 -755
  117. package/eigen/Eigen/src/Core/VectorBlock.h +59 -72
  118. package/eigen/Eigen/src/Core/VectorwiseOp.h +668 -630
  119. package/eigen/Eigen/src/Core/Visitor.h +480 -216
  120. package/eigen/Eigen/src/Core/arch/AVX/Complex.h +407 -293
  121. package/eigen/Eigen/src/Core/arch/AVX/MathFunctions.h +79 -388
  122. package/eigen/Eigen/src/Core/arch/AVX/PacketMath.h +2935 -491
  123. package/eigen/Eigen/src/Core/arch/AVX/Reductions.h +353 -0
  124. package/eigen/Eigen/src/Core/arch/AVX/TypeCasting.h +279 -22
  125. package/eigen/Eigen/src/Core/arch/AVX512/Complex.h +472 -0
  126. package/eigen/Eigen/src/Core/arch/AVX512/GemmKernel.h +1245 -0
  127. package/eigen/Eigen/src/Core/arch/AVX512/MathFunctions.h +85 -333
  128. package/eigen/Eigen/src/Core/arch/AVX512/MathFunctionsFP16.h +75 -0
  129. package/eigen/Eigen/src/Core/arch/AVX512/PacketMath.h +2490 -649
  130. package/eigen/Eigen/src/Core/arch/AVX512/PacketMathFP16.h +1413 -0
  131. package/eigen/Eigen/src/Core/arch/AVX512/Reductions.h +297 -0
  132. package/eigen/Eigen/src/Core/arch/AVX512/TrsmKernel.h +1167 -0
  133. package/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc +1219 -0
  134. package/eigen/Eigen/src/Core/arch/AVX512/TypeCasting.h +277 -0
  135. package/eigen/Eigen/src/Core/arch/AVX512/TypeCastingFP16.h +130 -0
  136. package/eigen/Eigen/src/Core/arch/AltiVec/Complex.h +521 -298
  137. package/eigen/Eigen/src/Core/arch/AltiVec/MathFunctions.h +39 -280
  138. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProduct.h +3686 -0
  139. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h +205 -0
  140. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h +901 -0
  141. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h +742 -0
  142. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.inc +2818 -0
  143. package/eigen/Eigen/src/Core/arch/AltiVec/PacketMath.h +3391 -723
  144. package/eigen/Eigen/src/Core/arch/AltiVec/TypeCasting.h +153 -0
  145. package/eigen/Eigen/src/Core/arch/Default/BFloat16.h +866 -0
  146. package/eigen/Eigen/src/Core/arch/Default/ConjHelper.h +113 -14
  147. package/eigen/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +2634 -0
  148. package/eigen/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h +227 -0
  149. package/eigen/Eigen/src/Core/arch/Default/Half.h +1091 -0
  150. package/eigen/Eigen/src/Core/arch/Default/Settings.h +11 -13
  151. package/eigen/Eigen/src/Core/arch/GPU/Complex.h +244 -0
  152. package/eigen/Eigen/src/Core/arch/GPU/MathFunctions.h +104 -0
  153. package/eigen/Eigen/src/Core/arch/GPU/PacketMath.h +1712 -0
  154. package/eigen/Eigen/src/Core/arch/GPU/Tuple.h +268 -0
  155. package/eigen/Eigen/src/Core/arch/GPU/TypeCasting.h +77 -0
  156. package/eigen/Eigen/src/Core/arch/HIP/hcc/math_constants.h +23 -0
  157. package/eigen/Eigen/src/Core/arch/HVX/PacketMath.h +1088 -0
  158. package/eigen/Eigen/src/Core/arch/LSX/Complex.h +520 -0
  159. package/eigen/Eigen/src/Core/arch/LSX/GeneralBlockPanelKernel.h +23 -0
  160. package/eigen/Eigen/src/Core/arch/LSX/MathFunctions.h +43 -0
  161. package/eigen/Eigen/src/Core/arch/LSX/PacketMath.h +2866 -0
  162. package/eigen/Eigen/src/Core/arch/LSX/TypeCasting.h +526 -0
  163. package/eigen/Eigen/src/Core/arch/MSA/Complex.h +620 -0
  164. package/eigen/Eigen/src/Core/arch/MSA/MathFunctions.h +379 -0
  165. package/eigen/Eigen/src/Core/arch/MSA/PacketMath.h +1237 -0
  166. package/eigen/Eigen/src/Core/arch/NEON/Complex.h +531 -289
  167. package/eigen/Eigen/src/Core/arch/NEON/GeneralBlockPanelKernel.h +243 -0
  168. package/eigen/Eigen/src/Core/arch/NEON/MathFunctions.h +50 -73
  169. package/eigen/Eigen/src/Core/arch/NEON/PacketMath.h +5915 -579
  170. package/eigen/Eigen/src/Core/arch/NEON/TypeCasting.h +1642 -0
  171. package/eigen/Eigen/src/Core/arch/NEON/UnaryFunctors.h +57 -0
  172. package/eigen/Eigen/src/Core/arch/SSE/Complex.h +366 -334
  173. package/eigen/Eigen/src/Core/arch/SSE/MathFunctions.h +40 -514
  174. package/eigen/Eigen/src/Core/arch/SSE/PacketMath.h +2164 -675
  175. package/eigen/Eigen/src/Core/arch/SSE/Reductions.h +324 -0
  176. package/eigen/Eigen/src/Core/arch/SSE/TypeCasting.h +188 -35
  177. package/eigen/Eigen/src/Core/arch/SVE/MathFunctions.h +48 -0
  178. package/eigen/Eigen/src/Core/arch/SVE/PacketMath.h +674 -0
  179. package/eigen/Eigen/src/Core/arch/SVE/TypeCasting.h +52 -0
  180. package/eigen/Eigen/src/Core/arch/SYCL/InteropHeaders.h +227 -0
  181. package/eigen/Eigen/src/Core/arch/SYCL/MathFunctions.h +303 -0
  182. package/eigen/Eigen/src/Core/arch/SYCL/PacketMath.h +576 -0
  183. package/eigen/Eigen/src/Core/arch/SYCL/TypeCasting.h +83 -0
  184. package/eigen/Eigen/src/Core/arch/ZVector/Complex.h +434 -261
  185. package/eigen/Eigen/src/Core/arch/ZVector/MathFunctions.h +160 -53
  186. package/eigen/Eigen/src/Core/arch/ZVector/PacketMath.h +1073 -605
  187. package/eigen/Eigen/src/Core/functors/AssignmentFunctors.h +123 -117
  188. package/eigen/Eigen/src/Core/functors/BinaryFunctors.h +594 -322
  189. package/eigen/Eigen/src/Core/functors/NullaryFunctors.h +204 -118
  190. package/eigen/Eigen/src/Core/functors/StlFunctors.h +110 -97
  191. package/eigen/Eigen/src/Core/functors/TernaryFunctors.h +34 -7
  192. package/eigen/Eigen/src/Core/functors/UnaryFunctors.h +1158 -530
  193. package/eigen/Eigen/src/Core/products/GeneralBlockPanelKernel.h +2329 -1333
  194. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrix.h +328 -364
  195. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h +191 -178
  196. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h +85 -82
  197. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h +154 -73
  198. package/eigen/Eigen/src/Core/products/GeneralMatrixVector.h +396 -542
  199. package/eigen/Eigen/src/Core/products/GeneralMatrixVector_BLAS.h +80 -77
  200. package/eigen/Eigen/src/Core/products/Parallelizer.h +208 -92
  201. package/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix.h +331 -375
  202. package/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix_BLAS.h +206 -224
  203. package/eigen/Eigen/src/Core/products/SelfadjointMatrixVector.h +139 -146
  204. package/eigen/Eigen/src/Core/products/SelfadjointMatrixVector_BLAS.h +58 -61
  205. package/eigen/Eigen/src/Core/products/SelfadjointProduct.h +71 -71
  206. package/eigen/Eigen/src/Core/products/SelfadjointRank2Update.h +48 -46
  207. package/eigen/Eigen/src/Core/products/TriangularMatrixMatrix.h +294 -369
  208. package/eigen/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h +246 -238
  209. package/eigen/Eigen/src/Core/products/TriangularMatrixVector.h +244 -247
  210. package/eigen/Eigen/src/Core/products/TriangularMatrixVector_BLAS.h +212 -192
  211. package/eigen/Eigen/src/Core/products/TriangularSolverMatrix.h +328 -275
  212. package/eigen/Eigen/src/Core/products/TriangularSolverMatrix_BLAS.h +108 -109
  213. package/eigen/Eigen/src/Core/products/TriangularSolverVector.h +70 -93
  214. package/eigen/Eigen/src/Core/util/Assert.h +158 -0
  215. package/eigen/Eigen/src/Core/util/BlasUtil.h +413 -290
  216. package/eigen/Eigen/src/Core/util/ConfigureVectorization.h +543 -0
  217. package/eigen/Eigen/src/Core/util/Constants.h +314 -263
  218. package/eigen/Eigen/src/Core/util/DisableStupidWarnings.h +130 -78
  219. package/eigen/Eigen/src/Core/util/EmulateArray.h +270 -0
  220. package/eigen/Eigen/src/Core/util/ForwardDeclarations.h +450 -224
  221. package/eigen/Eigen/src/Core/util/GpuHipCudaDefines.inc +101 -0
  222. package/eigen/Eigen/src/Core/util/GpuHipCudaUndefines.inc +45 -0
  223. package/eigen/Eigen/src/Core/util/IndexedViewHelper.h +487 -0
  224. package/eigen/Eigen/src/Core/util/IntegralConstant.h +279 -0
  225. package/eigen/Eigen/src/Core/util/MKL_support.h +39 -30
  226. package/eigen/Eigen/src/Core/util/Macros.h +939 -646
  227. package/eigen/Eigen/src/Core/util/MaxSizeVector.h +139 -0
  228. package/eigen/Eigen/src/Core/util/Memory.h +1042 -650
  229. package/eigen/Eigen/src/Core/util/Meta.h +618 -426
  230. package/eigen/Eigen/src/Core/util/MoreMeta.h +638 -0
  231. package/eigen/Eigen/src/Core/util/ReenableStupidWarnings.h +32 -19
  232. package/eigen/Eigen/src/Core/util/ReshapedHelper.h +51 -0
  233. package/eigen/Eigen/src/Core/util/Serializer.h +209 -0
  234. package/eigen/Eigen/src/Core/util/StaticAssert.h +51 -164
  235. package/eigen/Eigen/src/Core/util/SymbolicIndex.h +445 -0
  236. package/eigen/Eigen/src/Core/util/XprHelper.h +793 -538
  237. package/eigen/Eigen/src/Eigenvalues/ComplexEigenSolver.h +246 -277
  238. package/eigen/Eigen/src/Eigenvalues/ComplexSchur.h +299 -319
  239. package/eigen/Eigen/src/Eigenvalues/ComplexSchur_LAPACKE.h +52 -48
  240. package/eigen/Eigen/src/Eigenvalues/EigenSolver.h +413 -456
  241. package/eigen/Eigen/src/Eigenvalues/GeneralizedEigenSolver.h +309 -325
  242. package/eigen/Eigen/src/Eigenvalues/GeneralizedSelfAdjointEigenSolver.h +157 -171
  243. package/eigen/Eigen/src/Eigenvalues/HessenbergDecomposition.h +292 -310
  244. package/eigen/Eigen/src/Eigenvalues/InternalHeaderCheck.h +3 -0
  245. package/eigen/Eigen/src/Eigenvalues/MatrixBaseEigenvalues.h +91 -107
  246. package/eigen/Eigen/src/Eigenvalues/RealQZ.h +539 -606
  247. package/eigen/Eigen/src/Eigenvalues/RealSchur.h +348 -382
  248. package/eigen/Eigen/src/Eigenvalues/RealSchur_LAPACKE.h +41 -35
  249. package/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver.h +579 -600
  250. package/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver_LAPACKE.h +47 -44
  251. package/eigen/Eigen/src/Eigenvalues/Tridiagonalization.h +434 -461
  252. package/eigen/Eigen/src/Geometry/AlignedBox.h +307 -214
  253. package/eigen/Eigen/src/Geometry/AngleAxis.h +135 -137
  254. package/eigen/Eigen/src/Geometry/EulerAngles.h +163 -74
  255. package/eigen/Eigen/src/Geometry/Homogeneous.h +289 -333
  256. package/eigen/Eigen/src/Geometry/Hyperplane.h +152 -161
  257. package/eigen/Eigen/src/Geometry/InternalHeaderCheck.h +3 -0
  258. package/eigen/Eigen/src/Geometry/OrthoMethods.h +168 -145
  259. package/eigen/Eigen/src/Geometry/ParametrizedLine.h +141 -104
  260. package/eigen/Eigen/src/Geometry/Quaternion.h +595 -497
  261. package/eigen/Eigen/src/Geometry/Rotation2D.h +110 -108
  262. package/eigen/Eigen/src/Geometry/RotationBase.h +148 -145
  263. package/eigen/Eigen/src/Geometry/Scaling.h +115 -90
  264. package/eigen/Eigen/src/Geometry/Transform.h +896 -953
  265. package/eigen/Eigen/src/Geometry/Translation.h +100 -98
  266. package/eigen/Eigen/src/Geometry/Umeyama.h +79 -84
  267. package/eigen/Eigen/src/Geometry/arch/Geometry_SIMD.h +154 -0
  268. package/eigen/Eigen/src/Householder/BlockHouseholder.h +54 -42
  269. package/eigen/Eigen/src/Householder/Householder.h +104 -122
  270. package/eigen/Eigen/src/Householder/HouseholderSequence.h +416 -382
  271. package/eigen/Eigen/src/Householder/InternalHeaderCheck.h +3 -0
  272. package/eigen/Eigen/src/IterativeLinearSolvers/BasicPreconditioners.h +153 -166
  273. package/eigen/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h +127 -138
  274. package/eigen/Eigen/src/IterativeLinearSolvers/ConjugateGradient.h +95 -124
  275. package/eigen/Eigen/src/IterativeLinearSolvers/IncompleteCholesky.h +269 -267
  276. package/eigen/Eigen/src/IterativeLinearSolvers/IncompleteLUT.h +246 -259
  277. package/eigen/Eigen/src/IterativeLinearSolvers/InternalHeaderCheck.h +3 -0
  278. package/eigen/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h +218 -217
  279. package/eigen/Eigen/src/IterativeLinearSolvers/LeastSquareConjugateGradient.h +80 -103
  280. package/eigen/Eigen/src/IterativeLinearSolvers/SolveWithGuess.h +59 -63
  281. package/eigen/Eigen/src/Jacobi/InternalHeaderCheck.h +3 -0
  282. package/eigen/Eigen/src/Jacobi/Jacobi.h +256 -291
  283. package/eigen/Eigen/src/KLUSupport/InternalHeaderCheck.h +3 -0
  284. package/eigen/Eigen/src/KLUSupport/KLUSupport.h +339 -0
  285. package/eigen/Eigen/src/LU/Determinant.h +60 -63
  286. package/eigen/Eigen/src/LU/FullPivLU.h +561 -626
  287. package/eigen/Eigen/src/LU/InternalHeaderCheck.h +3 -0
  288. package/eigen/Eigen/src/LU/InverseImpl.h +213 -275
  289. package/eigen/Eigen/src/LU/PartialPivLU.h +407 -435
  290. package/eigen/Eigen/src/LU/PartialPivLU_LAPACKE.h +54 -40
  291. package/eigen/Eigen/src/LU/arch/InverseSize4.h +353 -0
  292. package/eigen/Eigen/src/MetisSupport/InternalHeaderCheck.h +3 -0
  293. package/eigen/Eigen/src/MetisSupport/MetisSupport.h +81 -93
  294. package/eigen/Eigen/src/OrderingMethods/Amd.h +250 -282
  295. package/eigen/Eigen/src/OrderingMethods/Eigen_Colamd.h +950 -1103
  296. package/eigen/Eigen/src/OrderingMethods/InternalHeaderCheck.h +3 -0
  297. package/eigen/Eigen/src/OrderingMethods/Ordering.h +111 -122
  298. package/eigen/Eigen/src/PaStiXSupport/InternalHeaderCheck.h +3 -0
  299. package/eigen/Eigen/src/PaStiXSupport/PaStiXSupport.h +524 -570
  300. package/eigen/Eigen/src/PardisoSupport/InternalHeaderCheck.h +3 -0
  301. package/eigen/Eigen/src/PardisoSupport/PardisoSupport.h +385 -429
  302. package/eigen/Eigen/src/QR/ColPivHouseholderQR.h +494 -473
  303. package/eigen/Eigen/src/QR/ColPivHouseholderQR_LAPACKE.h +120 -56
  304. package/eigen/Eigen/src/QR/CompleteOrthogonalDecomposition.h +223 -137
  305. package/eigen/Eigen/src/QR/FullPivHouseholderQR.h +517 -460
  306. package/eigen/Eigen/src/QR/HouseholderQR.h +412 -278
  307. package/eigen/Eigen/src/QR/HouseholderQR_LAPACKE.h +32 -23
  308. package/eigen/Eigen/src/QR/InternalHeaderCheck.h +3 -0
  309. package/eigen/Eigen/src/SPQRSupport/InternalHeaderCheck.h +3 -0
  310. package/eigen/Eigen/src/SPQRSupport/SuiteSparseQRSupport.h +263 -261
  311. package/eigen/Eigen/src/SVD/BDCSVD.h +872 -679
  312. package/eigen/Eigen/src/SVD/BDCSVD_LAPACKE.h +174 -0
  313. package/eigen/Eigen/src/SVD/InternalHeaderCheck.h +3 -0
  314. package/eigen/Eigen/src/SVD/JacobiSVD.h +585 -543
  315. package/eigen/Eigen/src/SVD/JacobiSVD_LAPACKE.h +85 -49
  316. package/eigen/Eigen/src/SVD/SVDBase.h +281 -160
  317. package/eigen/Eigen/src/SVD/UpperBidiagonalization.h +202 -237
  318. package/eigen/Eigen/src/SparseCholesky/InternalHeaderCheck.h +3 -0
  319. package/eigen/Eigen/src/SparseCholesky/SimplicialCholesky.h +769 -590
  320. package/eigen/Eigen/src/SparseCholesky/SimplicialCholesky_impl.h +318 -129
  321. package/eigen/Eigen/src/SparseCore/AmbiVector.h +202 -251
  322. package/eigen/Eigen/src/SparseCore/CompressedStorage.h +184 -236
  323. package/eigen/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h +140 -184
  324. package/eigen/Eigen/src/SparseCore/InternalHeaderCheck.h +3 -0
  325. package/eigen/Eigen/src/SparseCore/SparseAssign.h +174 -111
  326. package/eigen/Eigen/src/SparseCore/SparseBlock.h +408 -477
  327. package/eigen/Eigen/src/SparseCore/SparseColEtree.h +100 -112
  328. package/eigen/Eigen/src/SparseCore/SparseCompressedBase.h +531 -280
  329. package/eigen/Eigen/src/SparseCore/SparseCwiseBinaryOp.h +559 -347
  330. package/eigen/Eigen/src/SparseCore/SparseCwiseUnaryOp.h +100 -108
  331. package/eigen/Eigen/src/SparseCore/SparseDenseProduct.h +185 -191
  332. package/eigen/Eigen/src/SparseCore/SparseDiagonalProduct.h +71 -71
  333. package/eigen/Eigen/src/SparseCore/SparseDot.h +49 -47
  334. package/eigen/Eigen/src/SparseCore/SparseFuzzy.h +13 -11
  335. package/eigen/Eigen/src/SparseCore/SparseMap.h +243 -253
  336. package/eigen/Eigen/src/SparseCore/SparseMatrix.h +1614 -1142
  337. package/eigen/Eigen/src/SparseCore/SparseMatrixBase.h +403 -357
  338. package/eigen/Eigen/src/SparseCore/SparsePermutation.h +186 -115
  339. package/eigen/Eigen/src/SparseCore/SparseProduct.h +100 -91
  340. package/eigen/Eigen/src/SparseCore/SparseRedux.h +22 -24
  341. package/eigen/Eigen/src/SparseCore/SparseRef.h +268 -295
  342. package/eigen/Eigen/src/SparseCore/SparseSelfAdjointView.h +371 -414
  343. package/eigen/Eigen/src/SparseCore/SparseSolverBase.h +78 -87
  344. package/eigen/Eigen/src/SparseCore/SparseSparseProductWithPruning.h +81 -95
  345. package/eigen/Eigen/src/SparseCore/SparseTranspose.h +62 -71
  346. package/eigen/Eigen/src/SparseCore/SparseTriangularView.h +132 -144
  347. package/eigen/Eigen/src/SparseCore/SparseUtil.h +146 -115
  348. package/eigen/Eigen/src/SparseCore/SparseVector.h +426 -372
  349. package/eigen/Eigen/src/SparseCore/SparseView.h +164 -193
  350. package/eigen/Eigen/src/SparseCore/TriangularSolver.h +129 -170
  351. package/eigen/Eigen/src/SparseLU/InternalHeaderCheck.h +3 -0
  352. package/eigen/Eigen/src/SparseLU/SparseLU.h +814 -618
  353. package/eigen/Eigen/src/SparseLU/SparseLUImpl.h +61 -48
  354. package/eigen/Eigen/src/SparseLU/SparseLU_Memory.h +102 -118
  355. package/eigen/Eigen/src/SparseLU/SparseLU_Structs.h +38 -35
  356. package/eigen/Eigen/src/SparseLU/SparseLU_SupernodalMatrix.h +273 -255
  357. package/eigen/Eigen/src/SparseLU/SparseLU_Utils.h +44 -49
  358. package/eigen/Eigen/src/SparseLU/SparseLU_column_bmod.h +104 -108
  359. package/eigen/Eigen/src/SparseLU/SparseLU_column_dfs.h +90 -101
  360. package/eigen/Eigen/src/SparseLU/SparseLU_copy_to_ucol.h +57 -58
  361. package/eigen/Eigen/src/SparseLU/SparseLU_heap_relax_snode.h +43 -55
  362. package/eigen/Eigen/src/SparseLU/SparseLU_kernel_bmod.h +74 -71
  363. package/eigen/Eigen/src/SparseLU/SparseLU_panel_bmod.h +125 -133
  364. package/eigen/Eigen/src/SparseLU/SparseLU_panel_dfs.h +136 -159
  365. package/eigen/Eigen/src/SparseLU/SparseLU_pivotL.h +51 -52
  366. package/eigen/Eigen/src/SparseLU/SparseLU_pruneL.h +67 -73
  367. package/eigen/Eigen/src/SparseLU/SparseLU_relax_snode.h +24 -26
  368. package/eigen/Eigen/src/SparseQR/InternalHeaderCheck.h +3 -0
  369. package/eigen/Eigen/src/SparseQR/SparseQR.h +451 -490
  370. package/eigen/Eigen/src/StlSupport/StdDeque.h +28 -105
  371. package/eigen/Eigen/src/StlSupport/StdList.h +28 -84
  372. package/eigen/Eigen/src/StlSupport/StdVector.h +28 -108
  373. package/eigen/Eigen/src/StlSupport/details.h +48 -50
  374. package/eigen/Eigen/src/SuperLUSupport/InternalHeaderCheck.h +3 -0
  375. package/eigen/Eigen/src/SuperLUSupport/SuperLUSupport.h +634 -732
  376. package/eigen/Eigen/src/ThreadPool/Barrier.h +70 -0
  377. package/eigen/Eigen/src/ThreadPool/CoreThreadPoolDevice.h +336 -0
  378. package/eigen/Eigen/src/ThreadPool/EventCount.h +241 -0
  379. package/eigen/Eigen/src/ThreadPool/ForkJoin.h +140 -0
  380. package/eigen/Eigen/src/ThreadPool/InternalHeaderCheck.h +4 -0
  381. package/eigen/Eigen/src/ThreadPool/NonBlockingThreadPool.h +587 -0
  382. package/eigen/Eigen/src/ThreadPool/RunQueue.h +230 -0
  383. package/eigen/Eigen/src/ThreadPool/ThreadCancel.h +21 -0
  384. package/eigen/Eigen/src/ThreadPool/ThreadEnvironment.h +43 -0
  385. package/eigen/Eigen/src/ThreadPool/ThreadLocal.h +289 -0
  386. package/eigen/Eigen/src/ThreadPool/ThreadPoolInterface.h +50 -0
  387. package/eigen/Eigen/src/ThreadPool/ThreadYield.h +16 -0
  388. package/eigen/Eigen/src/UmfPackSupport/InternalHeaderCheck.h +3 -0
  389. package/eigen/Eigen/src/UmfPackSupport/UmfPackSupport.h +480 -380
  390. package/eigen/Eigen/src/misc/Image.h +41 -43
  391. package/eigen/Eigen/src/misc/InternalHeaderCheck.h +3 -0
  392. package/eigen/Eigen/src/misc/Kernel.h +39 -41
  393. package/eigen/Eigen/src/misc/RealSvd2x2.h +19 -21
  394. package/eigen/Eigen/src/misc/blas.h +83 -426
  395. package/eigen/Eigen/src/misc/lapacke.h +9976 -16182
  396. package/eigen/Eigen/src/misc/lapacke_helpers.h +163 -0
  397. package/eigen/Eigen/src/misc/lapacke_mangling.h +4 -5
  398. package/eigen/Eigen/src/plugins/ArrayCwiseBinaryOps.inc +344 -0
  399. package/eigen/Eigen/src/plugins/ArrayCwiseUnaryOps.inc +544 -0
  400. package/eigen/Eigen/src/plugins/BlockMethods.inc +1370 -0
  401. package/eigen/Eigen/src/plugins/CommonCwiseBinaryOps.inc +116 -0
  402. package/eigen/Eigen/src/plugins/CommonCwiseUnaryOps.inc +167 -0
  403. package/eigen/Eigen/src/plugins/IndexedViewMethods.inc +192 -0
  404. package/eigen/Eigen/src/plugins/InternalHeaderCheck.inc +3 -0
  405. package/eigen/Eigen/src/plugins/MatrixCwiseBinaryOps.inc +331 -0
  406. package/eigen/Eigen/src/plugins/MatrixCwiseUnaryOps.inc +118 -0
  407. package/eigen/Eigen/src/plugins/ReshapedMethods.inc +133 -0
  408. package/lib/LibEigen.d.ts +4 -0
  409. package/lib/LibEigen.js +14 -0
  410. package/lib/index.d.ts +1 -1
  411. package/lib/index.js +7 -3
  412. package/package.json +2 -10
  413. package/eigen/Eigen/CMakeLists.txt +0 -19
  414. package/eigen/Eigen/src/Core/BooleanRedux.h +0 -164
  415. package/eigen/Eigen/src/Core/arch/CUDA/Complex.h +0 -103
  416. package/eigen/Eigen/src/Core/arch/CUDA/Half.h +0 -675
  417. package/eigen/Eigen/src/Core/arch/CUDA/MathFunctions.h +0 -91
  418. package/eigen/Eigen/src/Core/arch/CUDA/PacketMath.h +0 -333
  419. package/eigen/Eigen/src/Core/arch/CUDA/PacketMathHalf.h +0 -1124
  420. package/eigen/Eigen/src/Core/arch/CUDA/TypeCasting.h +0 -212
  421. package/eigen/Eigen/src/Core/util/NonMPL2.h +0 -3
  422. package/eigen/Eigen/src/Geometry/arch/Geometry_SSE.h +0 -161
  423. package/eigen/Eigen/src/LU/arch/Inverse_SSE.h +0 -338
  424. package/eigen/Eigen/src/SparseCore/MappedSparseMatrix.h +0 -67
  425. package/eigen/Eigen/src/SparseLU/SparseLU_gemm_kernel.h +0 -280
  426. package/eigen/Eigen/src/misc/lapack.h +0 -152
  427. package/eigen/Eigen/src/plugins/ArrayCwiseBinaryOps.h +0 -332
  428. package/eigen/Eigen/src/plugins/ArrayCwiseUnaryOps.h +0 -552
  429. package/eigen/Eigen/src/plugins/BlockMethods.h +0 -1058
  430. package/eigen/Eigen/src/plugins/CommonCwiseBinaryOps.h +0 -115
  431. package/eigen/Eigen/src/plugins/CommonCwiseUnaryOps.h +0 -163
  432. package/eigen/Eigen/src/plugins/MatrixCwiseBinaryOps.h +0 -152
  433. package/eigen/Eigen/src/plugins/MatrixCwiseUnaryOps.h +0 -85
  434. package/lib/eigen.d.ts +0 -2
  435. package/lib/eigen.js +0 -15
@@ -0,0 +1,1245 @@
1
+ // This file is part of Eigen, a lightweight C++ template library
2
+ // for linear algebra.
3
+ //
4
+ // Copyright (C) 2022 Intel Corporation
5
+ //
6
+ // This Source Code Form is subject to the terms of the Mozilla
7
+ // Public License v. 2.0. If a copy of the MPL was not distributed
8
+ // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
+
10
+ #ifndef EIGEN_CORE_ARCH_AVX512_GEMM_KERNEL_H
11
+ #define EIGEN_CORE_ARCH_AVX512_GEMM_KERNEL_H
12
+
13
+ #if EIGEN_COMP_MSVC
14
+ #include <intrin.h>
15
+ #else
16
+ #include <x86intrin.h>
17
+ #endif
18
+ #include <immintrin.h>
19
+ #include <type_traits>
20
+
21
+ // IWYU pragma: private
22
+ #include "../../InternalHeaderCheck.h"
23
+
24
+ #if !defined(EIGEN_USE_AVX512_GEMM_KERNELS)
25
+ #define EIGEN_USE_AVX512_GEMM_KERNELS 1
26
+ #endif
27
+
28
+ #define SECOND_FETCH (32)
29
+ #if (EIGEN_COMP_GNUC_STRICT != 0) && !defined(EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS)
30
+ // Use less registers to load A elements to work around compiler spills. Lose a
31
+ // bit of performance (less than ~2%).
32
+ #define EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS
33
+ #endif
34
+
35
+ namespace Eigen {
36
+ namespace internal {
37
+
38
+ template <typename Scalar, bool is_unit_inc>
39
+ class gemm_class {
40
+ using vec = typename packet_traits<Scalar>::type;
41
+ using vec_ymm = typename unpacket_traits<vec>::half;
42
+ using vec_xmm = typename unpacket_traits<vec_ymm>::half;
43
+ using umask_t = typename unpacket_traits<vec>::mask_t;
44
+
45
+ static constexpr bool is_f32 = sizeof(Scalar) == sizeof(float);
46
+ static constexpr bool is_f64 = sizeof(Scalar) == sizeof(double);
47
+
48
+ #ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS
49
+ static constexpr bool use_less_a_regs = !is_unit_inc;
50
+ #else
51
+ static constexpr bool use_less_a_regs = true;
52
+ #endif
53
+ #ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_B_REGS
54
+ static constexpr bool use_less_b_regs = !is_unit_inc;
55
+ #else
56
+ static constexpr bool use_less_b_regs = true;
57
+ #endif
58
+
59
+ static constexpr int a_regs[] = {0, 1, 2, use_less_a_regs ? 0 : 3, use_less_a_regs ? 1 : 4, use_less_a_regs ? 2 : 5};
60
+ static constexpr int b_regs[] = {6, use_less_b_regs ? 6 : 7};
61
+ static constexpr int c_regs[] = {
62
+ 8, 16, 24, 9, 17, 25, 10, 18, 26, 11, 19, 27, 12, 20, 28, 13, 21, 29, 14, 22, 30, 15, 23, 31,
63
+ };
64
+
65
+ static constexpr int alpha_load_reg = 0;
66
+ static constexpr int c_load_regs[] = {1, 2, 6};
67
+
68
+ static constexpr int a_shift = 128;
69
+ static constexpr int b_shift = 128;
70
+
71
+ static constexpr int nelems_in_cache_line = is_f32 ? 16 : 8;
72
+ static constexpr int a_prefetch_size = nelems_in_cache_line * 2;
73
+ static constexpr int b_prefetch_size = nelems_in_cache_line * 8;
74
+
75
+ vec zmm[32];
76
+ umask_t mask;
77
+
78
+ // gemm arguments.
79
+ Index m;
80
+ const Index n, k, ldc;
81
+ const Index inc;
82
+ const Scalar *alpha;
83
+
84
+ const Scalar *a, *b;
85
+ Scalar *c;
86
+
87
+ const bool is_alpha1;
88
+ const bool is_beta0;
89
+
90
+ const Index a_stride, b_stride;
91
+ const Index a_off, b_off;
92
+
93
+ EIGEN_ALWAYS_INLINE void prefetch_a(const Scalar *a_addr) {
94
+ _mm_prefetch((char *)(a_prefetch_size + a_addr - a_shift), _MM_HINT_T0);
95
+ }
96
+
97
+ EIGEN_ALWAYS_INLINE void prefetch_b(const Scalar *b_addr) {
98
+ _mm_prefetch((char *)(b_prefetch_size + b_addr - b_shift), _MM_HINT_T0);
99
+ }
100
+
101
+ EIGEN_ALWAYS_INLINE void prefetch_x(const Scalar *x_addr) { _mm_prefetch((char *)(x_addr - a_shift), _MM_HINT_T2); }
102
+
103
+ EIGEN_ALWAYS_INLINE void prefetch_c(const Scalar *c_addr) {
104
+ #if defined(__PRFCHW__) && __PRFCHW__ == 1
105
+ _m_prefetchw((void *)c_addr);
106
+ #else
107
+ _mm_prefetch((char *)c_addr, _MM_HINT_T0);
108
+ #endif
109
+ }
110
+
111
+ template <int nelems>
112
+ EIGEN_ALWAYS_INLINE void a_load(vec &a_reg, const Scalar *a_addr) {
113
+ switch (nelems * sizeof(*a_addr) * 8) {
114
+ default:
115
+ case 512 * 3:
116
+ a_reg = ploadu<vec>(a_addr);
117
+ break;
118
+ case 512 * 2:
119
+ a_reg = ploadu<vec>(a_addr);
120
+ break;
121
+ case 512 * 1:
122
+ a_reg = ploadu<vec>(a_addr);
123
+ break;
124
+ case 256 * 1:
125
+ a_reg = preinterpret<vec>(_mm512_broadcast_f64x4(ploadu<Packet4d>(reinterpret_cast<const double *>(a_addr))));
126
+ break;
127
+ case 128 * 1:
128
+ a_reg = preinterpret<vec>(_mm512_broadcast_f32x4(ploadu<Packet4f>(reinterpret_cast<const float *>(a_addr))));
129
+ break;
130
+ case 64 * 1:
131
+ a_reg = preinterpret<vec>(pload1<Packet8d>(reinterpret_cast<const double *>(a_addr)));
132
+ break;
133
+ case 32 * 1:
134
+ a_reg = pload1<vec>(a_addr);
135
+ break;
136
+ }
137
+ }
138
+
139
+ EIGEN_ALWAYS_INLINE void b_load(vec &b_reg, const Scalar *b_addr) { b_reg = pload1<vec>(b_addr); }
140
+
141
+ template <int nelems>
142
+ EIGEN_ALWAYS_INLINE void c_store(Scalar *mem, vec &src) {
143
+ if (is_unit_inc) {
144
+ switch (nelems * sizeof(*mem) * 8) {
145
+ default:
146
+ case 512 * 3:
147
+ pstoreu(mem, src);
148
+ break;
149
+ case 512 * 2:
150
+ pstoreu(mem, src);
151
+ break;
152
+ case 512 * 1:
153
+ pstoreu(mem, src);
154
+ break;
155
+ case 256 * 1:
156
+ pstoreu(mem, preinterpret<vec_ymm>(src));
157
+ break;
158
+ case 128 * 1:
159
+ pstoreu(mem, preinterpret<vec_xmm>(src));
160
+ break;
161
+ case 64 * 1:
162
+ pstorel(mem, preinterpret<vec_xmm>(src));
163
+ break;
164
+ case 32 * 1:
165
+ pstores(mem, preinterpret<vec_xmm>(src));
166
+ break;
167
+ }
168
+ } else {
169
+ switch (nelems * sizeof(*mem) * 8) {
170
+ default:
171
+ case 512 * 3:
172
+ pscatter(mem, src, inc);
173
+ break;
174
+ case 512 * 2:
175
+ pscatter(mem, src, inc);
176
+ break;
177
+ case 512 * 1:
178
+ pscatter(mem, src, inc);
179
+ break;
180
+ case 256 * 1:
181
+ pscatter(mem, src, inc, mask);
182
+ break;
183
+ case 128 * 1:
184
+ pscatter(mem, src, inc, mask);
185
+ break;
186
+ case 64 * 1:
187
+ pscatter(mem, src, inc, mask);
188
+ break;
189
+ case 32 * 1:
190
+ pscatter(mem, src, inc, mask);
191
+ break;
192
+ }
193
+ }
194
+ }
195
+
196
+ template <int nelems>
197
+ EIGEN_ALWAYS_INLINE void vaddm(vec &dst, const Scalar *mem, vec &src, vec &reg) {
198
+ if (is_unit_inc) {
199
+ switch (nelems * sizeof(*mem) * 8) {
200
+ default:
201
+ case 512 * 3:
202
+ dst = padd(src, ploadu<vec>(mem));
203
+ break;
204
+ case 512 * 2:
205
+ dst = padd(src, ploadu<vec>(mem));
206
+ break;
207
+ case 512 * 1:
208
+ dst = padd(src, ploadu<vec>(mem));
209
+ break;
210
+ case 256 * 1:
211
+ dst = preinterpret<vec>(padd(preinterpret<vec_ymm>(src), ploadu<vec_ymm>(mem)));
212
+ break;
213
+ case 128 * 1:
214
+ dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), ploadu<vec_xmm>(mem)));
215
+ break;
216
+ case 64 * 1:
217
+ dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem)));
218
+ break;
219
+ case 32 * 1:
220
+ dst = preinterpret<vec>(padds(preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem)));
221
+ break;
222
+ }
223
+ } else {
224
+ // Zero out scratch register
225
+ reg = pzero(reg);
226
+
227
+ switch (nelems * sizeof(*mem) * 8) {
228
+ default:
229
+ case 512 * 3:
230
+ reg = pgather<Scalar, vec>(mem, inc);
231
+ dst = padd(src, reg);
232
+ break;
233
+ case 512 * 2:
234
+ reg = pgather<Scalar, vec>(mem, inc);
235
+ dst = padd(src, reg);
236
+ break;
237
+ case 512 * 1:
238
+ reg = pgather<Scalar, vec>(mem, inc);
239
+ dst = padd(src, reg);
240
+ break;
241
+ case 256 * 1:
242
+ reg = preinterpret<vec>(pgather<Scalar, vec_ymm>(mem, inc));
243
+ dst = preinterpret<vec>(padd(preinterpret<vec_ymm>(src), preinterpret<vec_ymm>(reg)));
244
+ break;
245
+ case 128 * 1:
246
+ reg = preinterpret<vec>(pgather<Scalar, vec_xmm>(mem, inc));
247
+ dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), preinterpret<vec_xmm>(reg)));
248
+ break;
249
+ case 64 * 1:
250
+ if (is_f32) {
251
+ reg = pgather(reg, mem, inc, mask);
252
+ dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), preinterpret<vec_xmm>(reg)));
253
+ } else {
254
+ dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem)));
255
+ }
256
+ break;
257
+ case 32 * 1:
258
+ dst = preinterpret<vec>(padds(preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem)));
259
+ break;
260
+ }
261
+ }
262
+ }
263
+
264
+ EIGEN_STRONG_INLINE void vfmadd(vec &dst, const vec &src1, const vec &src2) {
265
+ dst = pmadd(src1, src2, dst);
266
+
267
+ #if (EIGEN_COMP_GNUC != 0) || (EIGEN_COMP_CLANG != 0)
268
+ // Workaround register spills for gcc and clang
269
+ __asm__("#" : [dst] "+v"(dst) : [src1] "%v"(src1), [src2] "v"(src2));
270
+ #endif
271
+ }
272
+
273
+ template <int nelems>
274
+ EIGEN_ALWAYS_INLINE void vfmaddm(vec &dst, const Scalar *mem, vec &src, vec &scale, vec &reg) {
275
+ if (is_unit_inc) {
276
+ switch (nelems * sizeof(*mem) * 8) {
277
+ default:
278
+ case 512 * 3:
279
+ dst = pmadd(scale, src, ploadu<vec>(mem));
280
+ break;
281
+ case 512 * 2:
282
+ dst = pmadd(scale, src, ploadu<vec>(mem));
283
+ break;
284
+ case 512 * 1:
285
+ dst = pmadd(scale, src, ploadu<vec>(mem));
286
+ break;
287
+ case 256 * 1:
288
+ dst =
289
+ preinterpret<vec>(pmadd(preinterpret<vec_ymm>(scale), preinterpret<vec_ymm>(src), ploadu<vec_ymm>(mem)));
290
+ break;
291
+ case 128 * 1:
292
+ dst =
293
+ preinterpret<vec>(pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploadu<vec_xmm>(mem)));
294
+ break;
295
+ case 64 * 1:
296
+ dst =
297
+ preinterpret<vec>(pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem)));
298
+ break;
299
+ case 32 * 1:
300
+ dst =
301
+ preinterpret<vec>(pmadds(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem)));
302
+ break;
303
+ }
304
+ } else {
305
+ // Zero out scratch register
306
+ reg = pzero(reg);
307
+
308
+ switch (nelems * sizeof(*mem) * 8) {
309
+ default:
310
+ case 512 * 3:
311
+ reg = pgather<Scalar, vec>(mem, inc);
312
+ dst = pmadd(scale, src, reg);
313
+ break;
314
+ case 512 * 2:
315
+ reg = pgather<Scalar, vec>(mem, inc);
316
+ dst = pmadd(scale, src, reg);
317
+ break;
318
+ case 512 * 1:
319
+ reg = pgather<Scalar, vec>(mem, inc);
320
+ dst = pmadd(scale, src, reg);
321
+ break;
322
+ case 256 * 1:
323
+ reg = preinterpret<vec>(pgather<Scalar, vec_ymm>(mem, inc));
324
+ dst = preinterpret<vec>(
325
+ pmadd(preinterpret<vec_ymm>(scale), preinterpret<vec_ymm>(src), preinterpret<vec_ymm>(reg)));
326
+ break;
327
+ case 128 * 1:
328
+ reg = preinterpret<vec>(pgather<Scalar, vec_xmm>(mem, inc));
329
+ dst = preinterpret<vec>(
330
+ pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), preinterpret<vec_xmm>(reg)));
331
+ break;
332
+ case 64 * 1:
333
+ if (is_f32) {
334
+ reg = pgather(reg, mem, inc, mask);
335
+ dst = preinterpret<vec>(
336
+ pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), preinterpret<vec_xmm>(reg)));
337
+ } else {
338
+ dst = preinterpret<vec>(
339
+ pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem)));
340
+ }
341
+ break;
342
+ case 32 * 1:
343
+ dst =
344
+ preinterpret<vec>(pmadds(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem)));
345
+ break;
346
+ }
347
+ }
348
+ }
349
+
350
+ template <int j, int endX, int i, int endY, int nelems>
351
+ EIGEN_ALWAYS_INLINE std::enable_if_t<(j > endX) || (i > endY)> a_loads(const Scalar *ao) {
352
+ EIGEN_UNUSED_VARIABLE(ao);
353
+ }
354
+
355
+ template <int j, int endX, int i, int endY, int nelems>
356
+ EIGEN_ALWAYS_INLINE std::enable_if_t<(j <= endX) && (i <= endY)> a_loads(const Scalar *ao) {
357
+ if (j < endX) {
358
+ if (i < endY) {
359
+ auto &a_reg = zmm[a_regs[i + (j % 2) * 3]];
360
+ const Scalar *a_addr = ao + nelems * j + nelems_in_cache_line * i - a_shift;
361
+ a_load<nelems>(a_reg, a_addr);
362
+
363
+ a_loads<j, endX, i + 1, endY, nelems>(ao);
364
+ } else {
365
+ a_loads<j + 1, endX, 0, endY, nelems>(ao);
366
+ }
367
+ }
368
+ }
369
+
370
+ template <int un, int max_b_unroll, int i, int um_vecs, int a_unroll, int b_unroll>
371
+ EIGEN_ALWAYS_INLINE std::enable_if_t<(un > max_b_unroll) || (i > um_vecs)> prefetch_cs(const Scalar *co1,
372
+ const Scalar *co2) {
373
+ EIGEN_UNUSED_VARIABLE(co1);
374
+ EIGEN_UNUSED_VARIABLE(co2);
375
+ }
376
+
377
+ /* C prefetch loop structure.
378
+ * for (int un = 0; un < 8; un++) {
379
+ * if (b_unroll >= un + 1) {
380
+ * if (un == 4) co2 = co1 + 4 * ldc;
381
+ *
382
+ * for (int i = 0; i < um_vecs; i++) {
383
+ * Scalar *co = (un + 1 <= 4) ? co1 : co2;
384
+ * auto co_off = (un % 4) * ldc + a_unroll - 1 + i * nelems_in_cache_line * sizeof *co;
385
+ * prefetch_c(co + co_off);
386
+ * }
387
+ * }
388
+ * }
389
+ */
390
+
391
+ template <int un, int max_b_unroll, int i, int um_vecs, int a_unroll, int b_unroll>
392
+ EIGEN_ALWAYS_INLINE std::enable_if_t<(un <= max_b_unroll) && (i <= um_vecs)> prefetch_cs(Scalar *&co1, Scalar *&co2) {
393
+ if (un < max_b_unroll) {
394
+ if (b_unroll >= un + 1) {
395
+ if (un == 4 && i == 0) co2 = co1 + 4 * ldc;
396
+
397
+ if (i < um_vecs) {
398
+ Scalar *co = (un + 1 <= 4) ? co1 : co2;
399
+ auto co_off = (un % 4) * ldc + a_unroll - 1 + i * nelems_in_cache_line * sizeof *co;
400
+ prefetch_c(co + co_off);
401
+
402
+ prefetch_cs<un, max_b_unroll, i + 1, um_vecs, a_unroll, b_unroll>(co1, co2);
403
+ } else {
404
+ prefetch_cs<un + 1, max_b_unroll, 0, um_vecs, a_unroll, b_unroll>(co1, co2);
405
+ }
406
+
407
+ } else {
408
+ prefetch_cs<un + 1, max_b_unroll, 0, um_vecs, a_unroll, b_unroll>(co1, co2);
409
+ }
410
+ }
411
+ }
412
+
413
+ // load_c
414
+ template <int i, int um_vecs, int idx, int nelems>
415
+ EIGEN_ALWAYS_INLINE std::enable_if_t<(i > um_vecs)> scale_load_c(const Scalar *cox, vec &alpha_reg) {
416
+ EIGEN_UNUSED_VARIABLE(cox);
417
+ EIGEN_UNUSED_VARIABLE(alpha_reg);
418
+ }
419
+
420
+ template <int i, int um_vecs, int idx, int nelems>
421
+ EIGEN_ALWAYS_INLINE std::enable_if_t<(i <= um_vecs)> scale_load_c(const Scalar *cox, vec &alpha_reg) {
422
+ if (i < um_vecs) {
423
+ auto &c_reg = zmm[c_regs[i + idx * 3]];
424
+ auto &c_load_reg = zmm[c_load_regs[i % 3]];
425
+ auto c_mem = cox;
426
+ if (is_unit_inc)
427
+ c_mem += i * nelems_in_cache_line;
428
+ else
429
+ c_mem += i * nelems_in_cache_line * inc;
430
+
431
+ if (!is_beta0 && is_alpha1)
432
+ vaddm<nelems>(c_reg, c_mem, c_reg, c_load_reg);
433
+ else if (!is_beta0 && !is_alpha1)
434
+ vfmaddm<nelems>(c_reg, c_mem, c_reg, alpha_reg, c_load_reg);
435
+ else if (is_beta0 && !is_alpha1)
436
+ c_reg = pmul(alpha_reg, c_reg);
437
+
438
+ scale_load_c<i + 1, um_vecs, idx, nelems>(cox, alpha_reg);
439
+ }
440
+ }
441
+
442
+ // store_c
443
+ template <int i, int um_vecs, int idx, int nelems>
444
+ EIGEN_ALWAYS_INLINE std::enable_if_t<(i > um_vecs)> write_c(Scalar *cox) {
445
+ EIGEN_UNUSED_VARIABLE(cox);
446
+ }
447
+
448
+ template <int i, int um_vecs, int idx, int nelems>
449
+ EIGEN_ALWAYS_INLINE std::enable_if_t<(i <= um_vecs)> write_c(Scalar *cox) {
450
+ if (i < um_vecs) {
451
+ auto &c_reg = zmm[c_regs[i + idx * 3]];
452
+ auto c_mem = cox;
453
+ if (is_unit_inc)
454
+ c_mem += i * nelems_in_cache_line;
455
+ else
456
+ c_mem += i * nelems_in_cache_line * inc;
457
+
458
+ c_store<nelems>(c_mem, c_reg);
459
+ c_reg = pzero(c_reg);
460
+
461
+ write_c<i + 1, um_vecs, idx, nelems>(cox);
462
+ }
463
+ }
464
+
465
+ /* C update loop structure.
466
+ * co2 = co1 + ldc;
467
+ *
468
+ * auto &alpha_reg = zmm[alpha_load_reg];
469
+ * if (!is_alpha1) alpha_reg = pload1<vec>(alpha);
470
+ *
471
+ * int idx = 0;
472
+ * for (pow = 1; pow <= 8; pow <<= 1) {
473
+ *
474
+ * if (b_unroll >= pow) {
475
+ * for (count = 1; count < (pow + 1) / 2 + 1; count++) {
476
+ * if (pow >= 4) co2 += ldc;
477
+ *
478
+ * const Scalar *cox = (idx == 0) ? co1 : co2;
479
+ *
480
+ * const int um_vecs = numext::div_ceil(a_unroll, nelems_in_cache_line);
481
+ * scale_load_c<0, um_vecs, idx, a_unroll>(cox, alpha_reg);
482
+ * write_c<0, um_vecs, idx, a_unroll>(cox);
483
+ *
484
+ * idx++;
485
+ * }
486
+ * }
487
+ * }
488
+ *
489
+ * if (b_unroll == 1)
490
+ * co1 += ldc;
491
+ * else
492
+ * co1 = co2 + ldc;
493
+ */
494
+
495
+ template <int pow, int a_unroll, int idx>
496
+ EIGEN_ALWAYS_INLINE void c_update_1count(Scalar *&cox) {
497
+ if (pow >= 4) cox += ldc;
498
+
499
+ const int um_vecs = numext::div_ceil(a_unroll, nelems_in_cache_line);
500
+ auto &alpha_reg = zmm[alpha_load_reg];
501
+
502
+ scale_load_c<0, um_vecs, idx, a_unroll>(cox, alpha_reg);
503
+ write_c<0, um_vecs, idx, a_unroll>(cox);
504
+ }
505
+
506
+ template <int pow, int a_unroll>
507
+ EIGEN_ALWAYS_INLINE void c_update_1pow(Scalar *&co1, Scalar *&co2) {
508
+ constexpr int idx = pow / 2;
509
+ Scalar *&cox = idx == 0 ? co1 : co2;
510
+
511
+ constexpr int max_count = (pow + 1) / 2;
512
+ static_assert(max_count <= 4, "Unsupported max_count.");
513
+
514
+ if (1 <= max_count) c_update_1count<pow, a_unroll, idx + 0>(cox);
515
+ if (2 <= max_count) c_update_1count<pow, a_unroll, idx + 1>(cox);
516
+ if (3 <= max_count) c_update_1count<pow, a_unroll, idx + 2>(cox);
517
+ if (4 <= max_count) c_update_1count<pow, a_unroll, idx + 3>(cox);
518
+ }
519
+
520
+ template <int max_b_unroll, int a_unroll, int b_unroll>
521
+ EIGEN_ALWAYS_INLINE void c_update(Scalar *&co1, Scalar *&co2) {
522
+ auto &alpha_reg = zmm[alpha_load_reg];
523
+
524
+ co2 = co1 + ldc;
525
+ if (!is_alpha1) alpha_reg = pload1<vec>(alpha);
526
+ if (!is_unit_inc && a_unroll < nelems_in_cache_line) mask = static_cast<umask_t>((1ull << a_unroll) - 1);
527
+
528
+ static_assert(max_b_unroll <= 8, "Unsupported max_b_unroll");
529
+
530
+ if (1 <= max_b_unroll && 1 <= b_unroll) c_update_1pow<1, a_unroll>(co1, co2);
531
+ if (2 <= max_b_unroll && 2 <= b_unroll) c_update_1pow<2, a_unroll>(co1, co2);
532
+ if (4 <= max_b_unroll && 4 <= b_unroll) c_update_1pow<4, a_unroll>(co1, co2);
533
+ if (8 <= max_b_unroll && 8 <= b_unroll) c_update_1pow<8, a_unroll>(co1, co2);
534
+
535
+ if (b_unroll == 1)
536
+ co1 += ldc;
537
+ else
538
+ co1 = co2 + ldc;
539
+ }
540
+
541
+ // compute
542
+ template <int um, int um_vecs, int idx, int uk, bool fetch_x, bool ktail>
543
+ EIGEN_ALWAYS_INLINE std::enable_if_t<(um > um_vecs)> compute(const Scalar *ao, const Scalar *bo, int &fetchA_idx,
544
+ int &fetchB_idx, vec &b_reg) {
545
+ EIGEN_UNUSED_VARIABLE(ao);
546
+ EIGEN_UNUSED_VARIABLE(bo);
547
+ EIGEN_UNUSED_VARIABLE(fetchA_idx);
548
+ EIGEN_UNUSED_VARIABLE(fetchB_idx);
549
+ EIGEN_UNUSED_VARIABLE(b_reg);
550
+ }
551
+
552
+ template <int um, int um_vecs, int idx, int uk, bool fetch_x, bool ktail>
553
+ EIGEN_ALWAYS_INLINE std::enable_if_t<(um <= um_vecs)> compute(const Scalar *ao, const Scalar *bo, int &fetchA_idx,
554
+ int &fetchB_idx, vec &b_reg) {
555
+ if (um < um_vecs) {
556
+ auto &c_reg = zmm[c_regs[um + idx * 3]];
557
+ auto &a_reg = zmm[a_regs[um + (uk % 2) * 3]];
558
+
559
+ vfmadd(c_reg, a_reg, b_reg);
560
+
561
+ if (!fetch_x && um == 0 &&
562
+ (((idx == 0 || idx == 6) && (uk % 2 == 0 || is_f64 || ktail)) ||
563
+ (idx == 3 && (uk % 2 == 1 || is_f64 || ktail)))) {
564
+ prefetch_a(ao + nelems_in_cache_line * fetchA_idx);
565
+ fetchA_idx++;
566
+ }
567
+
568
+ if (um == 0 && idx == 1 && (uk % 2 == 0 || is_f64 || ktail)) {
569
+ prefetch_b(bo + nelems_in_cache_line * fetchB_idx);
570
+ fetchB_idx++;
571
+ }
572
+
573
+ compute<um + 1, um_vecs, idx, uk, fetch_x, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg);
574
+ }
575
+ }
576
+
577
+ // load_a
578
+ template <int um, int um_vecs, int uk, int nelems, bool ktail>
579
+ EIGEN_ALWAYS_INLINE std::enable_if_t<(um > um_vecs)> load_a(const Scalar *ao) {
580
+ EIGEN_UNUSED_VARIABLE(ao);
581
+ }
582
+
583
+ template <int um, int um_vecs, int uk, int nelems, bool ktail>
584
+ EIGEN_ALWAYS_INLINE std::enable_if_t<(um <= um_vecs)> load_a(const Scalar *ao) {
585
+ if (um < um_vecs) {
586
+ auto &a_reg = zmm[a_regs[um + (uk % 2) * 3]];
587
+ const Scalar *a_addr = ao + nelems * (1 + !ktail * !use_less_a_regs + uk) + nelems_in_cache_line * um - a_shift;
588
+ a_load<nelems>(a_reg, a_addr);
589
+
590
+ load_a<um + 1, um_vecs, uk, nelems, ktail>(ao);
591
+ }
592
+ }
593
+ template <int uk, int pow, int count, int um_vecs, int b_unroll, bool ktail, bool fetch_x, bool c_fetch>
594
+ EIGEN_ALWAYS_INLINE std::enable_if_t<(count > (pow + 1) / 2)> innerkernel_1pow(const Scalar *&aa,
595
+ const Scalar *const &ao,
596
+ const Scalar *const &bo, Scalar *&co2,
597
+ int &fetchA_idx, int &fetchB_idx) {
598
+ EIGEN_UNUSED_VARIABLE(aa);
599
+ EIGEN_UNUSED_VARIABLE(ao);
600
+ EIGEN_UNUSED_VARIABLE(bo);
601
+ EIGEN_UNUSED_VARIABLE(co2);
602
+ EIGEN_UNUSED_VARIABLE(fetchA_idx);
603
+ EIGEN_UNUSED_VARIABLE(fetchB_idx);
604
+ }
605
+
606
+ template <int uk, int pow, int count, int um_vecs, int b_unroll, bool ktail, bool fetch_x, bool c_fetch>
607
+ EIGEN_ALWAYS_INLINE std::enable_if_t<(count <= (pow + 1) / 2)> innerkernel_1pow(const Scalar *&aa,
608
+ const Scalar *const &ao,
609
+ const Scalar *const &bo, Scalar *&co2,
610
+ int &fetchA_idx, int &fetchB_idx) {
611
+ const int idx = (pow / 2) + count;
612
+
613
+ if (count < (pow + 1) / 2) {
614
+ auto &b_reg = zmm[b_regs[idx % 2]];
615
+
616
+ if (fetch_x && uk == 3 && idx == 0) prefetch_x(aa);
617
+ if (fetch_x && uk == 3 && idx == 4) aa += 8;
618
+
619
+ if (b_unroll >= pow) {
620
+ compute<0, um_vecs, idx, uk, fetch_x, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg);
621
+
622
+ const Scalar *b_addr = bo + b_unroll * uk + idx + 1 + (b_unroll > 1) * !use_less_b_regs - b_shift;
623
+ b_load(b_reg, b_addr);
624
+ }
625
+
626
+ // Go to the next count.
627
+ innerkernel_1pow<uk, pow, count + 1, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx,
628
+ fetchB_idx);
629
+
630
+ } else {
631
+ // Maybe prefetch C data after count-loop.
632
+ if (pow == 2 && c_fetch) {
633
+ if (uk % 3 == 0 && uk > 0) {
634
+ co2 += ldc;
635
+ } else {
636
+ prefetch_c(co2 + (uk % 3) * nelems_in_cache_line);
637
+ }
638
+ }
639
+ }
640
+ }
641
+
642
+ template <int uk, int max_b_unroll, int a_unroll, int b_unroll, bool ktail, bool fetch_x, bool c_fetch,
643
+ bool no_a_preload = false>
644
+ EIGEN_ALWAYS_INLINE void innerkernel_1uk(const Scalar *&aa, const Scalar *const &ao, const Scalar *const &bo,
645
+ Scalar *&co2, int &fetchA_idx, int &fetchB_idx) {
646
+ const int um_vecs = numext::div_ceil(a_unroll, nelems_in_cache_line);
647
+
648
+ if (max_b_unroll >= 1)
649
+ innerkernel_1pow<uk, 1, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
650
+ if (max_b_unroll >= 2)
651
+ innerkernel_1pow<uk, 2, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
652
+ if (max_b_unroll >= 4)
653
+ innerkernel_1pow<uk, 4, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
654
+ if (max_b_unroll >= 8)
655
+ innerkernel_1pow<uk, 8, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
656
+
657
+ // Load A after pow-loop. Skip this at the end to prevent running over the buffer
658
+ if (!no_a_preload) load_a<0, um_vecs, uk, a_unroll, ktail>(ao);
659
+ }
660
+
661
+ /* Inner kernel loop structure.
662
+ * for (int uk = 0; uk < kfactor; uk++) {
663
+ * int idx = 0;
664
+ *
665
+ * for (pow = 1; pow < max_b_unroll << 1; pow <<= 1) {
666
+ * for (int count = 0; count < (pow + 1) / 2; count++) {
667
+ * auto &b_reg = zmm[b_regs[idx % 2]];
668
+ *
669
+ * if (fetch_x && uk == 3 && idx == 0) prefetch_x(aa);
670
+ * if (fetch_x && uk == 3 && idx == 4) aa += 8;
671
+ *
672
+ * if (b_unroll >= pow) {
673
+ * compute<0, um_vecs, idx, uk, fetch_x, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg);
674
+ *
675
+ * const Scalar *b_addr = bo + b_unroll * uk + idx + 1 + (b_unroll > 1) - b_shift ;
676
+ * b_load(b_reg, b_addr);
677
+ * }
678
+ * idx++;
679
+ * }
680
+ *
681
+ * Maybe prefetch C data.
682
+ * if (pow == 2 && c_fetch) {
683
+ * if (uk % 3 == 0 && uk > 0) {
684
+ * co2 += ldc;
685
+ * } else {
686
+ * prefetch_c(co2 + (uk % 3) * nelems_in_cache_line);
687
+ * }
688
+ * }
689
+ * }
690
+ *
691
+ * Load A.
692
+ * load_a<0, um_vecs, uk, ktail, a_unroll>(ao);
693
+ * }
694
+ *
695
+ * Advance A/B pointers after uk-loop.
696
+ * ao += a_unroll * kfactor;
697
+ * bo += b_unroll * kfactor;
698
+ */
699
+
700
+ template <int a_unroll, int b_unroll, int k_factor, int max_b_unroll, int max_k_factor, bool c_fetch,
701
+ bool no_a_preload = false>
702
+ EIGEN_ALWAYS_INLINE void innerkernel(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co2) {
703
+ int fetchA_idx = 0;
704
+ int fetchB_idx = 0;
705
+
706
+ const bool fetch_x = k_factor == max_k_factor;
707
+ const bool ktail = k_factor == 1;
708
+
709
+ static_assert(k_factor <= 4 && k_factor > 0, "innerkernel maximum k_factor supported is 4");
710
+ static_assert(no_a_preload == false || (no_a_preload == true && k_factor == 1),
711
+ "skipping a preload only allowed when k unroll is 1");
712
+
713
+ if (k_factor > 0)
714
+ innerkernel_1uk<0, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(
715
+ aa, ao, bo, co2, fetchA_idx, fetchB_idx);
716
+ if (k_factor > 1)
717
+ innerkernel_1uk<1, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(
718
+ aa, ao, bo, co2, fetchA_idx, fetchB_idx);
719
+ if (k_factor > 2)
720
+ innerkernel_1uk<2, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(
721
+ aa, ao, bo, co2, fetchA_idx, fetchB_idx);
722
+ if (k_factor > 3)
723
+ innerkernel_1uk<3, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(
724
+ aa, ao, bo, co2, fetchA_idx, fetchB_idx);
725
+
726
+ // Advance A/B pointers after uk-loop.
727
+ ao += a_unroll * k_factor;
728
+ bo += b_unroll * k_factor;
729
+ }
730
+
731
+ template <int a_unroll, int b_unroll, int max_b_unroll>
732
+ EIGEN_ALWAYS_INLINE void kloop(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) {
733
+ const int um_vecs = numext::div_ceil(a_unroll, nelems_in_cache_line);
734
+ if (!use_less_a_regs && k > 1)
735
+ a_loads<0, 2, 0, um_vecs, a_unroll>(ao);
736
+ else
737
+ a_loads<0, 1, 0, um_vecs, a_unroll>(ao);
738
+
739
+ b_load(zmm[b_regs[0]], bo - b_shift + 0);
740
+ if (!use_less_b_regs) b_load(zmm[b_regs[1]], bo - b_shift + 1);
741
+
742
+ #ifndef SECOND_FETCH
743
+ prefetch_cs<0, max_b_unroll, 0, um_vecs, a_unroll, b_unroll>(co1, co2);
744
+ #endif // SECOND_FETCH
745
+
746
+ // Unrolling k-loop by a factor of 4.
747
+ const int max_k_factor = 4;
748
+ Index kRem = k % max_k_factor;
749
+ Index k_ = k - kRem;
750
+ if (k_ >= max_k_factor) {
751
+ k_ -= max_k_factor;
752
+ kRem += max_k_factor;
753
+ }
754
+ Index loop_count = k_ / max_k_factor;
755
+
756
+ if (loop_count > 0) {
757
+ #ifdef SECOND_FETCH
758
+ loop_count -= SECOND_FETCH;
759
+ #endif
760
+ while (loop_count > 0) {
761
+ innerkernel<a_unroll, b_unroll, max_k_factor, max_b_unroll, max_k_factor, 0>(aa, ao, bo, co2);
762
+ loop_count--;
763
+ }
764
+ #ifdef SECOND_FETCH
765
+ co2 = co1 + nelems_in_cache_line - 1;
766
+
767
+ loop_count += b_unroll;
768
+ while (loop_count > 0) {
769
+ innerkernel<a_unroll, b_unroll, max_k_factor, max_b_unroll, max_k_factor, 1>(aa, ao, bo, co2);
770
+ loop_count--;
771
+ }
772
+
773
+ loop_count += SECOND_FETCH - b_unroll;
774
+ while (loop_count > 0) {
775
+ innerkernel<a_unroll, b_unroll, max_k_factor, max_b_unroll, max_k_factor, 0>(aa, ao, bo, co2);
776
+ loop_count--;
777
+ }
778
+ #endif
779
+ }
780
+
781
+ // k-loop remainder handling.
782
+ loop_count = kRem;
783
+ while (loop_count > 1) {
784
+ innerkernel<a_unroll, b_unroll, 1, max_b_unroll, max_k_factor, 0>(aa, ao, bo, co2);
785
+ loop_count--;
786
+ }
787
+ if (loop_count > 0) {
788
+ innerkernel<a_unroll, b_unroll, 1, max_b_unroll, max_k_factor, 0, true>(aa, ao, bo, co2);
789
+ }
790
+
791
+ // Update C matrix.
792
+ c_update<max_b_unroll, a_unroll, b_unroll>(co1, co2);
793
+ }
794
+
795
+ template <int a_unroll, int b_unroll, int max_b_unroll>
796
+ EIGEN_ALWAYS_INLINE void nloop(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) {
797
+ // Set A matrix pointer.
798
+ ao = a + a_off * a_unroll;
799
+
800
+ // Set B matrix pointer if needed.
801
+ bo += b_unroll * b_off;
802
+
803
+ kloop<a_unroll, b_unroll, max_b_unroll>(aa, ao, bo, co1, co2);
804
+
805
+ // Advance B matrix pointer if needed.
806
+ bo += b_unroll * (b_stride - k - b_off);
807
+
808
+ // Advance prefetch A pointer.
809
+ aa += 16;
810
+ }
811
+
812
+ template <int a_unroll, int max_a_unroll, int max_b_unroll>
813
+ EIGEN_ALWAYS_INLINE void mloop(const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) {
814
+ // Set prefetch A pointers.
815
+ const Scalar *aa = a + a_unroll * a_stride;
816
+
817
+ // Set C matrix pointers.
818
+ co1 = c;
819
+ if (a_unroll >= max_a_unroll) co2 = c + 2 * ldc;
820
+ if (is_unit_inc)
821
+ c += a_unroll;
822
+ else
823
+ c += a_unroll * inc;
824
+
825
+ // Set B matrix pointer.
826
+ bo = b;
827
+
828
+ // Main n-loop.
829
+ for (Index i = n / max_b_unroll; i > 0; i--) nloop<a_unroll, max_b_unroll, max_b_unroll>(aa, ao, bo, co1, co2);
830
+
831
+ // n-remainders.
832
+ if (n & 4 && max_b_unroll > 4) nloop<a_unroll, 4, max_b_unroll>(aa, ao, bo, co1, co2);
833
+ #if 0
834
+ if (n & 2 && max_b_unroll > 2) nloop<a_unroll, 2, max_b_unroll>(aa, ao, bo, co1, co2);
835
+ if (n & 1 && max_b_unroll > 1) nloop<a_unroll, 1, max_b_unroll>(aa, ao, bo, co1, co2);
836
+ #else
837
+ // Copy kernels don't support tails of n = 2 for single/double precision.
838
+ // Loop over ones.
839
+ int n_rem = 2 * ((n & 2) != 0) + 1 * ((n & 1) != 0);
840
+ while (n_rem > 0) {
841
+ nloop<a_unroll, 1, max_b_unroll>(aa, ao, bo, co1, co2);
842
+ n_rem--;
843
+ }
844
+ #endif
845
+
846
+ // Advance A matrix pointer.
847
+ a = ao + a_unroll * (a_stride - k - a_off);
848
+ }
849
+
850
+ public:
851
+ // Compute kernel unrolling C matrix by max_a_unroll x max_b_unroll.
852
+ template <int max_a_unroll, int max_b_unroll>
853
+ EIGEN_ALWAYS_INLINE void compute_kern() {
854
+ a -= -a_shift;
855
+ b -= -b_shift;
856
+
857
+ const Scalar *ao = nullptr;
858
+ const Scalar *bo = nullptr;
859
+ Scalar *co1 = nullptr;
860
+ Scalar *co2 = nullptr;
861
+
862
+ // Main m-loop.
863
+ for (; m >= max_a_unroll; m -= max_a_unroll) mloop<max_a_unroll, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
864
+
865
+ // m-remainders.
866
+ if (m & 32 && max_a_unroll > 32) mloop<32, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
867
+ if (m & 16 && max_a_unroll > 16) mloop<16, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
868
+ if (m & 8 && max_a_unroll > 8) mloop<8, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
869
+ if (m & 4 && max_a_unroll > 4) mloop<4, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
870
+ if (m & 2 && max_a_unroll > 2 && is_f64) mloop<2, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
871
+ if (m & 1 && max_a_unroll > 1 && is_f64) mloop<1, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
872
+
873
+ // Copy kernels don't support tails of m = 2 for single precision.
874
+ // Loop over ones.
875
+ if (is_f32) {
876
+ int m_rem = 2 * ((m & 2) != 0) + 1 * ((m & 1) != 0);
877
+ while (m_rem > 0) {
878
+ mloop<1, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
879
+ m_rem--;
880
+ }
881
+ }
882
+ }
883
+
884
+ gemm_class(Index m_, Index n_, Index k_, Index ldc_, Index inc_, const Scalar *alpha_, const Scalar *a_,
885
+ const Scalar *b_, Scalar *c_, bool is_alpha1_, bool is_beta0_, Index a_stride_, Index b_stride_,
886
+ Index a_off_, Index b_off_)
887
+ : m(m_),
888
+ n(n_),
889
+ k(k_),
890
+ ldc(ldc_),
891
+ inc(inc_),
892
+ alpha(alpha_),
893
+ a(a_),
894
+ b(b_),
895
+ c(c_),
896
+ is_alpha1(is_alpha1_),
897
+ is_beta0(is_beta0_),
898
+ a_stride(a_stride_),
899
+ b_stride(b_stride_),
900
+ a_off(a_off_),
901
+ b_off(b_off_) {
902
+ // Zero out all accumulation registers.
903
+ zmm[8] = pzero(zmm[8]);
904
+ zmm[9] = pzero(zmm[9]);
905
+ zmm[10] = pzero(zmm[10]);
906
+ zmm[11] = pzero(zmm[11]);
907
+ zmm[12] = pzero(zmm[12]);
908
+ zmm[13] = pzero(zmm[13]);
909
+ zmm[14] = pzero(zmm[14]);
910
+ zmm[15] = pzero(zmm[15]);
911
+ zmm[16] = pzero(zmm[16]);
912
+ zmm[17] = pzero(zmm[17]);
913
+ zmm[18] = pzero(zmm[18]);
914
+ zmm[19] = pzero(zmm[19]);
915
+ zmm[20] = pzero(zmm[20]);
916
+ zmm[21] = pzero(zmm[21]);
917
+ zmm[22] = pzero(zmm[22]);
918
+ zmm[23] = pzero(zmm[23]);
919
+ zmm[24] = pzero(zmm[24]);
920
+ zmm[25] = pzero(zmm[25]);
921
+ zmm[26] = pzero(zmm[26]);
922
+ zmm[27] = pzero(zmm[27]);
923
+ zmm[28] = pzero(zmm[28]);
924
+ zmm[29] = pzero(zmm[29]);
925
+ zmm[30] = pzero(zmm[30]);
926
+ zmm[31] = pzero(zmm[31]);
927
+ }
928
+ };
929
+
930
+ // Compute kernel with max unroll support of:
931
+ // Single precision:
932
+ // max_a_unroll: 48, 32, 16, 8, 4, 2, 1
933
+ // max_b_unroll: 8, 4, 2, 1
934
+ // Double precision:
935
+ // max_a_unroll: 24, 16, 8, 4, 2, 1
936
+ // max_b_unroll: 8, 4, 2, 1
937
+ template <typename Scalar, int max_a_unroll, int max_b_unroll, bool is_alpha1, bool is_beta0, bool is_unit_inc>
938
+ EIGEN_DONT_INLINE void gemm_kern_avx512(Index m, Index n, Index k, Scalar *alpha, const Scalar *a, const Scalar *b,
939
+ Scalar *c, Index ldc, Index inc = 1, Index a_stride = -1, Index b_stride = -1,
940
+ Index a_off = 0, Index b_off = 0) {
941
+ if (a_stride == -1) a_stride = k;
942
+ if (b_stride == -1) b_stride = k;
943
+
944
+ gemm_class<Scalar, is_unit_inc> g(m, n, k, ldc, inc, alpha, a, b, c, is_alpha1, is_beta0, a_stride, b_stride, a_off,
945
+ b_off);
946
+ g.template compute_kern<max_a_unroll, max_b_unroll>();
947
+ }
948
+
949
+ // Template specializations of GEBP kernels with nr = 8.
950
+ #if EIGEN_USE_AVX512_GEMM_KERNELS
951
+ template <bool ConjLhs_, bool ConjRhs_, int PacketSize_>
952
+ class gebp_traits<float, float, ConjLhs_, ConjRhs_, Architecture::Target, PacketSize_>
953
+ : public gebp_traits<float, float, ConjLhs_, ConjRhs_, Architecture::Generic, PacketSize_> {
954
+ using Base = gebp_traits<float, float, ConjLhs_, ConjRhs_, Architecture::Generic, PacketSize_>;
955
+
956
+ public:
957
+ enum { nr = Base::Vectorizable ? 8 : 4 };
958
+ };
959
+
960
+ template <bool ConjLhs_, bool ConjRhs_, int PacketSize_>
961
+ class gebp_traits<double, double, ConjLhs_, ConjRhs_, Architecture::Target, PacketSize_>
962
+ : public gebp_traits<double, double, ConjLhs_, ConjRhs_, Architecture::Generic, PacketSize_> {
963
+ using Base = gebp_traits<double, double, ConjLhs_, ConjRhs_, Architecture::Generic, PacketSize_>;
964
+
965
+ public:
966
+ enum { nr = Base::Vectorizable ? 8 : 4 };
967
+ };
968
+
969
+ template <typename Scalar, typename Index, typename DataMapper, bool Conjugate, bool PanelMode>
970
+ struct gemm_pack_rhs<Scalar, Index, DataMapper, 8, ColMajor, Conjugate, PanelMode> {
971
+ typedef typename packet_traits<Scalar>::type Packet;
972
+ typedef typename DataMapper::LinearMapper LinearMapper;
973
+ enum { PacketSize = packet_traits<Scalar>::size };
974
+ EIGEN_DONT_INLINE void operator()(Scalar *blockB, const DataMapper &rhs, Index depth, Index cols, Index stride = 0,
975
+ Index offset = 0);
976
+ };
977
+
978
// Pack a depth x cols panel of a column-major rhs into blockB for the nr = 8
// GEBP kernels. Columns are packed in groups of 8, then 4, then singly;
// within a group, elements are interleaved row-by-row so the compute kernel
// can stream them. In PanelMode, `offset` rows are skipped before and
// `stride - offset - depth` after each column group.
template <typename Scalar, typename Index, typename DataMapper, bool Conjugate, bool PanelMode>
EIGEN_DONT_INLINE void gemm_pack_rhs<Scalar, Index, DataMapper, 8, ColMajor, Conjugate, PanelMode>::operator()(
    Scalar *blockB, const DataMapper &rhs, Index depth, Index cols, Index stride, Index offset) {
  constexpr int nr = 8;
  EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK RHS COLMAJOR");
  EIGEN_UNUSED_VARIABLE(stride);
  EIGEN_UNUSED_VARIABLE(offset);
  eigen_assert(((!PanelMode) && stride == 0 && offset == 0) || (PanelMode && stride >= depth && offset <= stride));
  // cj conjugates elements only for complex scalars with Conjugate set;
  // otherwise it is the identity.
  conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
  // Column counts rounded down to multiples of 8 and 4 respectively.
  Index packet_cols8 = nr >= 8 ? (cols / 8) * 8 : 0;
  Index packet_cols4 = nr >= 4 ? (cols / 4) * 4 : 0;
  Index count = 0;  // write cursor into blockB
  const Index peeled_k = (depth / PacketSize) * PacketSize;
  if (nr >= 8) {
    // Groups of 8 columns.
    for (Index j2 = 0; j2 < packet_cols8; j2 += 8) {
      // skip what we have before
      if (PanelMode) count += 8 * offset;
      const LinearMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
      const LinearMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
      const LinearMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
      const LinearMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
      const LinearMapper dm4 = rhs.getLinearMapper(0, j2 + 4);
      const LinearMapper dm5 = rhs.getLinearMapper(0, j2 + 5);
      const LinearMapper dm6 = rhs.getLinearMapper(0, j2 + 6);
      const LinearMapper dm7 = rhs.getLinearMapper(0, j2 + 7);
      Index k = 0;
      if ((PacketSize % 8) == 0)  // TODO enable vectorized transposition for PacketSize==4
      {
        // Vectorized path: load a PacketSize-row slice of each of the 8
        // columns, transpose in registers, and store interleaved.
        for (; k < peeled_k; k += PacketSize) {
          PacketBlock<Packet, (PacketSize % 8) == 0 ? 8 : PacketSize> kernel;

          kernel.packet[0] = dm0.template loadPacket<Packet>(k);
          kernel.packet[1] = dm1.template loadPacket<Packet>(k);
          kernel.packet[2] = dm2.template loadPacket<Packet>(k);
          kernel.packet[3] = dm3.template loadPacket<Packet>(k);
          kernel.packet[4] = dm4.template loadPacket<Packet>(k);
          kernel.packet[5] = dm5.template loadPacket<Packet>(k);
          kernel.packet[6] = dm6.template loadPacket<Packet>(k);
          kernel.packet[7] = dm7.template loadPacket<Packet>(k);

          ptranspose(kernel);

          // The `% PacketSize` indices are no-ops on this path (PacketSize
          // is a multiple of 8) and exist to keep the indices in range for
          // other instantiations of the PacketBlock size expression.
          pstoreu(blockB + count + 0 * PacketSize, cj.pconj(kernel.packet[0]));
          pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel.packet[1 % PacketSize]));
          pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel.packet[2 % PacketSize]));
          pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel.packet[3 % PacketSize]));
          pstoreu(blockB + count + 4 * PacketSize, cj.pconj(kernel.packet[4 % PacketSize]));
          pstoreu(blockB + count + 5 * PacketSize, cj.pconj(kernel.packet[5 % PacketSize]));
          pstoreu(blockB + count + 6 * PacketSize, cj.pconj(kernel.packet[6 % PacketSize]));
          pstoreu(blockB + count + 7 * PacketSize, cj.pconj(kernel.packet[7 % PacketSize]));
          count += 8 * PacketSize;
        }
      }
      // Scalar tail over the remaining (depth - peeled_k) rows.
      for (; k < depth; k++) {
        blockB[count + 0] = cj(dm0(k));
        blockB[count + 1] = cj(dm1(k));
        blockB[count + 2] = cj(dm2(k));
        blockB[count + 3] = cj(dm3(k));
        blockB[count + 4] = cj(dm4(k));
        blockB[count + 5] = cj(dm5(k));
        blockB[count + 6] = cj(dm6(k));
        blockB[count + 7] = cj(dm7(k));
        count += 8;
      }
      // skip what we have after
      if (PanelMode) count += 8 * (stride - offset - depth);
    }
  }

  if (nr >= 4) {
    // Groups of 4 columns (same scheme as above, narrower).
    for (Index j2 = packet_cols8; j2 < packet_cols4; j2 += 4) {
      // skip what we have before
      if (PanelMode) count += 4 * offset;
      const LinearMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
      const LinearMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
      const LinearMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
      const LinearMapper dm3 = rhs.getLinearMapper(0, j2 + 3);

      Index k = 0;
      if ((PacketSize % 4) == 0)  // TODO enable vectorized transposition for PacketSize==2 ??
      {
        for (; k < peeled_k; k += PacketSize) {
          PacketBlock<Packet, (PacketSize % 4) == 0 ? 4 : PacketSize> kernel;
          kernel.packet[0] = dm0.template loadPacket<Packet>(k);
          kernel.packet[1 % PacketSize] = dm1.template loadPacket<Packet>(k);
          kernel.packet[2 % PacketSize] = dm2.template loadPacket<Packet>(k);
          kernel.packet[3 % PacketSize] = dm3.template loadPacket<Packet>(k);
          ptranspose(kernel);
          pstoreu(blockB + count + 0 * PacketSize, cj.pconj(kernel.packet[0]));
          pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel.packet[1 % PacketSize]));
          pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel.packet[2 % PacketSize]));
          pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel.packet[3 % PacketSize]));
          count += 4 * PacketSize;
        }
      }
      for (; k < depth; k++) {
        blockB[count + 0] = cj(dm0(k));
        blockB[count + 1] = cj(dm1(k));
        blockB[count + 2] = cj(dm2(k));
        blockB[count + 3] = cj(dm3(k));
        count += 4;
      }
      // skip what we have after
      if (PanelMode) count += 4 * (stride - offset - depth);
    }
  }

  // copy the remaining columns one at a time (nr==1)
  for (Index j2 = packet_cols4; j2 < cols; ++j2) {
    if (PanelMode) count += offset;
    const LinearMapper dm0 = rhs.getLinearMapper(0, j2);
    for (Index k = 0; k < depth; k++) {
      blockB[count] = cj(dm0(k));
      count += 1;
    }
    if (PanelMode) count += (stride - offset - depth);
  }
}
1096
+
1097
// RHS packing for the nr = 8 kernels, row-major input. Because the rhs is
// row-major, a group of 8 (or 4) consecutive columns in one row is already
// contiguous, so packing is a straight (possibly conjugated) copy — no
// transpose needed. Half/quarter packets cover scalar types whose full
// packet is wider than the column group.
template <typename Scalar, typename Index, typename DataMapper, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<Scalar, Index, DataMapper, 8, RowMajor, Conjugate, PanelMode> {
  typedef typename packet_traits<Scalar>::type Packet;
  typedef typename unpacket_traits<Packet>::half HalfPacket;
  typedef typename unpacket_traits<typename unpacket_traits<Packet>::half>::half QuarterPacket;
  typedef typename DataMapper::LinearMapper LinearMapper;
  enum {
    PacketSize = packet_traits<Scalar>::size,
    HalfPacketSize = unpacket_traits<HalfPacket>::size,
    QuarterPacketSize = unpacket_traits<QuarterPacket>::size
  };
  // Packs `depth` x `cols` of rhs into blockB; stride/offset are only
  // meaningful in PanelMode (offset rows skipped before, stride-offset-depth
  // after, per column group).
  EIGEN_DONT_INLINE void operator()(Scalar *blockB, const DataMapper &rhs, Index depth, Index cols, Index stride = 0,
                                    Index offset = 0) {
    constexpr int nr = 8;
    EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK RHS ROWMAJOR");
    EIGEN_UNUSED_VARIABLE(stride);
    EIGEN_UNUSED_VARIABLE(offset);
    eigen_assert(((!PanelMode) && stride == 0 && offset == 0) || (PanelMode && stride >= depth && offset <= stride));
    // Whether narrower packet types actually exist for this scalar.
    const bool HasHalf = (int)HalfPacketSize < (int)PacketSize;
    const bool HasQuarter = (int)QuarterPacketSize < (int)HalfPacketSize;
    conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
    Index packet_cols8 = nr >= 8 ? (cols / 8) * 8 : 0;
    Index packet_cols4 = nr >= 4 ? (cols / 4) * 4 : 0;
    Index count = 0;  // write cursor into blockB

    if (nr >= 8) {
      // Groups of 8 columns: pick the widest load that covers exactly 8
      // elements; otherwise fall back to element-wise copies.
      for (Index j2 = 0; j2 < packet_cols8; j2 += 8) {
        // skip what we have before
        if (PanelMode) count += 8 * offset;
        for (Index k = 0; k < depth; k++) {
          if (PacketSize == 8) {
            // Packet A = ploadu<Packet>(&rhs.data()[k*rhs.stride() + j2]);
            Packet A = rhs.template loadPacket<Packet>(k, j2);
            pstoreu(blockB + count, cj.pconj(A));
          } else if (HasHalf && HalfPacketSize == 8) {
            HalfPacket A = rhs.template loadPacket<HalfPacket>(k, j2);
            pstoreu(blockB + count, cj.pconj(A));
          } else if (HasQuarter && QuarterPacketSize == 8) {
            QuarterPacket A = rhs.template loadPacket<QuarterPacket>(k, j2);
            pstoreu(blockB + count, cj.pconj(A));
          } else if (PacketSize == 4) {
            // Two 4-wide loads cover the 8-column group.
            // Packet A = ploadu<Packet>(&rhs.data()[k*rhs.stride() + j2]);
            // Packet B = ploadu<Packet>(&rhs.data()[k*rhs.stride() + j2 + PacketSize]);
            Packet A = rhs.template loadPacket<Packet>(k, j2);
            Packet B = rhs.template loadPacket<Packet>(k, j2 + PacketSize);
            pstoreu(blockB + count, cj.pconj(A));
            pstoreu(blockB + count + PacketSize, cj.pconj(B));
          } else {
            // Scalar fallback.
            // const Scalar* b0 = &rhs.data()[k*rhs.stride() + j2];
            const LinearMapper dm0 = rhs.getLinearMapper(k, j2);
            blockB[count + 0] = cj(dm0(0));
            blockB[count + 1] = cj(dm0(1));
            blockB[count + 2] = cj(dm0(2));
            blockB[count + 3] = cj(dm0(3));
            blockB[count + 4] = cj(dm0(4));
            blockB[count + 5] = cj(dm0(5));
            blockB[count + 6] = cj(dm0(6));
            blockB[count + 7] = cj(dm0(7));
          }
          count += 8;
        }
        // skip what we have after
        if (PanelMode) count += 8 * (stride - offset - depth);
      }
    }

    if (nr >= 4) {
      // Groups of 4 columns (same scheme, narrower loads).
      for (Index j2 = packet_cols8; j2 < packet_cols4; j2 += 4) {
        // skip what we have before
        if (PanelMode) count += 4 * offset;
        for (Index k = 0; k < depth; k++) {
          if (PacketSize == 4) {
            Packet A = rhs.template loadPacket<Packet>(k, j2);
            pstoreu(blockB + count, cj.pconj(A));
            count += PacketSize;
          } else if (HasHalf && HalfPacketSize == 4) {
            HalfPacket A = rhs.template loadPacket<HalfPacket>(k, j2);
            pstoreu(blockB + count, cj.pconj(A));
            count += HalfPacketSize;
          } else if (HasQuarter && QuarterPacketSize == 4) {
            QuarterPacket A = rhs.template loadPacket<QuarterPacket>(k, j2);
            pstoreu(blockB + count, cj.pconj(A));
            count += QuarterPacketSize;
          } else {
            const LinearMapper dm0 = rhs.getLinearMapper(k, j2);
            blockB[count + 0] = cj(dm0(0));
            blockB[count + 1] = cj(dm0(1));
            blockB[count + 2] = cj(dm0(2));
            blockB[count + 3] = cj(dm0(3));
            count += 4;
          }
        }
        // skip what we have after
        if (PanelMode) count += 4 * (stride - offset - depth);
      }
    }
    // copy the remaining columns one at a time (nr==1)
    for (Index j2 = packet_cols4; j2 < cols; ++j2) {
      if (PanelMode) count += offset;
      for (Index k = 0; k < depth; k++) {
        blockB[count] = cj(rhs(k, j2));
        count += 1;
      }
      if (PanelMode) count += stride - offset - depth;
    }
  }
};
1204
+
1205
// GEBP kernel specialization for nr == 8 (real scalars only: lhs and rhs
// share the same Scalar type). Declaration only; the out-of-line operator()
// below dispatches into gemm_kern_avx512.
template <typename Scalar, typename Index, typename DataMapper, int mr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<Scalar, Scalar, Index, DataMapper, mr, 8, ConjugateLhs, ConjugateRhs> {
  EIGEN_ALWAYS_INLINE void operator()(const DataMapper &res, const Scalar *blockA, const Scalar *blockB, Index rows,
                                      Index depth, Index cols, Scalar alpha, Index strideA = -1, Index strideB = -1,
                                      Index offsetA = 0, Index offsetB = 0);
};
1211
+
1212
+ template <typename Scalar, typename Index, typename DataMapper, int mr, bool ConjugateLhs, bool ConjugateRhs>
1213
+ EIGEN_ALWAYS_INLINE void gebp_kernel<Scalar, Scalar, Index, DataMapper, mr, 8, ConjugateLhs, ConjugateRhs>::operator()(
1214
+ const DataMapper &res, const Scalar *blockA, const Scalar *blockB, Index rows, Index depth, Index cols,
1215
+ Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
1216
+ if (res.incr() == 1) {
1217
+ if (alpha == 1) {
1218
+ gemm_kern_avx512<Scalar, mr, 8, true, false, true>(rows, cols, depth, &alpha, blockA, blockB,
1219
+ (Scalar *)res.data(), res.stride(), res.incr(), strideA,
1220
+ strideB, offsetA, offsetB);
1221
+ } else {
1222
+ gemm_kern_avx512<Scalar, mr, 8, false, false, true>(rows, cols, depth, &alpha, blockA, blockB,
1223
+ (Scalar *)res.data(), res.stride(), res.incr(), strideA,
1224
+ strideB, offsetA, offsetB);
1225
+ }
1226
+ } else {
1227
+ if (alpha == 1) {
1228
+ gemm_kern_avx512<Scalar, mr, 8, true, false, false>(rows, cols, depth, &alpha, blockA, blockB,
1229
+ (Scalar *)res.data(), res.stride(), res.incr(), strideA,
1230
+ strideB, offsetA, offsetB);
1231
+ } else {
1232
+ gemm_kern_avx512<Scalar, mr, 8, false, false, false>(rows, cols, depth, &alpha, blockA, blockB,
1233
+ (Scalar *)res.data(), res.stride(), res.incr(), strideA,
1234
+ strideB, offsetA, offsetB);
1235
+ }
1236
+ }
1237
+ }
1238
+ #endif // EIGEN_USE_AVX512_GEMM_KERNELS
1239
+
1240
+ } // namespace internal
1241
+ } // namespace Eigen
1242
+
1243
+ #undef SECOND_FETCH
1244
+
1245
+ #endif // EIGEN_CORE_ARCH_AVX512_GEMM_KERNEL_H