@smake/eigen 1.0.2 → 1.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (435)
  1. package/README.md +1 -1
  2. package/eigen/Eigen/AccelerateSupport +52 -0
  3. package/eigen/Eigen/Cholesky +18 -21
  4. package/eigen/Eigen/CholmodSupport +28 -28
  5. package/eigen/Eigen/Core +235 -326
  6. package/eigen/Eigen/Eigenvalues +16 -14
  7. package/eigen/Eigen/Geometry +21 -24
  8. package/eigen/Eigen/Householder +9 -8
  9. package/eigen/Eigen/IterativeLinearSolvers +8 -4
  10. package/eigen/Eigen/Jacobi +14 -14
  11. package/eigen/Eigen/KLUSupport +43 -0
  12. package/eigen/Eigen/LU +16 -20
  13. package/eigen/Eigen/MetisSupport +12 -12
  14. package/eigen/Eigen/OrderingMethods +54 -54
  15. package/eigen/Eigen/PaStiXSupport +23 -20
  16. package/eigen/Eigen/PardisoSupport +17 -14
  17. package/eigen/Eigen/QR +18 -21
  18. package/eigen/Eigen/QtAlignedMalloc +5 -13
  19. package/eigen/Eigen/SPQRSupport +21 -14
  20. package/eigen/Eigen/SVD +23 -18
  21. package/eigen/Eigen/Sparse +1 -4
  22. package/eigen/Eigen/SparseCholesky +18 -23
  23. package/eigen/Eigen/SparseCore +18 -17
  24. package/eigen/Eigen/SparseLU +12 -8
  25. package/eigen/Eigen/SparseQR +16 -14
  26. package/eigen/Eigen/StdDeque +5 -2
  27. package/eigen/Eigen/StdList +5 -2
  28. package/eigen/Eigen/StdVector +5 -2
  29. package/eigen/Eigen/SuperLUSupport +30 -24
  30. package/eigen/Eigen/ThreadPool +80 -0
  31. package/eigen/Eigen/UmfPackSupport +19 -17
  32. package/eigen/Eigen/Version +14 -0
  33. package/eigen/Eigen/src/AccelerateSupport/AccelerateSupport.h +423 -0
  34. package/eigen/Eigen/src/AccelerateSupport/InternalHeaderCheck.h +3 -0
  35. package/eigen/Eigen/src/Cholesky/InternalHeaderCheck.h +3 -0
  36. package/eigen/Eigen/src/Cholesky/LDLT.h +377 -401
  37. package/eigen/Eigen/src/Cholesky/LLT.h +332 -360
  38. package/eigen/Eigen/src/Cholesky/LLT_LAPACKE.h +81 -56
  39. package/eigen/Eigen/src/CholmodSupport/CholmodSupport.h +620 -521
  40. package/eigen/Eigen/src/CholmodSupport/InternalHeaderCheck.h +3 -0
  41. package/eigen/Eigen/src/Core/ArithmeticSequence.h +239 -0
  42. package/eigen/Eigen/src/Core/Array.h +341 -294
  43. package/eigen/Eigen/src/Core/ArrayBase.h +190 -203
  44. package/eigen/Eigen/src/Core/ArrayWrapper.h +127 -171
  45. package/eigen/Eigen/src/Core/Assign.h +30 -40
  46. package/eigen/Eigen/src/Core/AssignEvaluator.h +711 -589
  47. package/eigen/Eigen/src/Core/Assign_MKL.h +130 -125
  48. package/eigen/Eigen/src/Core/BandMatrix.h +268 -283
  49. package/eigen/Eigen/src/Core/Block.h +375 -398
  50. package/eigen/Eigen/src/Core/CommaInitializer.h +86 -97
  51. package/eigen/Eigen/src/Core/ConditionEstimator.h +51 -53
  52. package/eigen/Eigen/src/Core/CoreEvaluators.h +1356 -1026
  53. package/eigen/Eigen/src/Core/CoreIterators.h +73 -59
  54. package/eigen/Eigen/src/Core/CwiseBinaryOp.h +114 -132
  55. package/eigen/Eigen/src/Core/CwiseNullaryOp.h +726 -617
  56. package/eigen/Eigen/src/Core/CwiseTernaryOp.h +77 -103
  57. package/eigen/Eigen/src/Core/CwiseUnaryOp.h +56 -68
  58. package/eigen/Eigen/src/Core/CwiseUnaryView.h +132 -95
  59. package/eigen/Eigen/src/Core/DenseBase.h +632 -571
  60. package/eigen/Eigen/src/Core/DenseCoeffsBase.h +511 -624
  61. package/eigen/Eigen/src/Core/DenseStorage.h +512 -509
  62. package/eigen/Eigen/src/Core/DeviceWrapper.h +153 -0
  63. package/eigen/Eigen/src/Core/Diagonal.h +169 -210
  64. package/eigen/Eigen/src/Core/DiagonalMatrix.h +351 -274
  65. package/eigen/Eigen/src/Core/DiagonalProduct.h +12 -10
  66. package/eigen/Eigen/src/Core/Dot.h +172 -222
  67. package/eigen/Eigen/src/Core/EigenBase.h +75 -85
  68. package/eigen/Eigen/src/Core/Fill.h +138 -0
  69. package/eigen/Eigen/src/Core/FindCoeff.h +464 -0
  70. package/eigen/Eigen/src/Core/ForceAlignedAccess.h +90 -109
  71. package/eigen/Eigen/src/Core/Fuzzy.h +82 -105
  72. package/eigen/Eigen/src/Core/GeneralProduct.h +327 -263
  73. package/eigen/Eigen/src/Core/GenericPacketMath.h +1472 -360
  74. package/eigen/Eigen/src/Core/GlobalFunctions.h +194 -151
  75. package/eigen/Eigen/src/Core/IO.h +147 -139
  76. package/eigen/Eigen/src/Core/IndexedView.h +321 -0
  77. package/eigen/Eigen/src/Core/InnerProduct.h +260 -0
  78. package/eigen/Eigen/src/Core/InternalHeaderCheck.h +3 -0
  79. package/eigen/Eigen/src/Core/Inverse.h +56 -66
  80. package/eigen/Eigen/src/Core/Map.h +124 -142
  81. package/eigen/Eigen/src/Core/MapBase.h +256 -281
  82. package/eigen/Eigen/src/Core/MathFunctions.h +1620 -938
  83. package/eigen/Eigen/src/Core/MathFunctionsImpl.h +233 -71
  84. package/eigen/Eigen/src/Core/Matrix.h +491 -416
  85. package/eigen/Eigen/src/Core/MatrixBase.h +468 -453
  86. package/eigen/Eigen/src/Core/NestByValue.h +66 -85
  87. package/eigen/Eigen/src/Core/NoAlias.h +79 -85
  88. package/eigen/Eigen/src/Core/NumTraits.h +235 -148
  89. package/eigen/Eigen/src/Core/PartialReduxEvaluator.h +253 -0
  90. package/eigen/Eigen/src/Core/PermutationMatrix.h +461 -511
  91. package/eigen/Eigen/src/Core/PlainObjectBase.h +871 -894
  92. package/eigen/Eigen/src/Core/Product.h +260 -139
  93. package/eigen/Eigen/src/Core/ProductEvaluators.h +863 -714
  94. package/eigen/Eigen/src/Core/Random.h +161 -136
  95. package/eigen/Eigen/src/Core/RandomImpl.h +262 -0
  96. package/eigen/Eigen/src/Core/RealView.h +250 -0
  97. package/eigen/Eigen/src/Core/Redux.h +366 -336
  98. package/eigen/Eigen/src/Core/Ref.h +308 -209
  99. package/eigen/Eigen/src/Core/Replicate.h +94 -106
  100. package/eigen/Eigen/src/Core/Reshaped.h +398 -0
  101. package/eigen/Eigen/src/Core/ReturnByValue.h +49 -55
  102. package/eigen/Eigen/src/Core/Reverse.h +136 -145
  103. package/eigen/Eigen/src/Core/Select.h +70 -140
  104. package/eigen/Eigen/src/Core/SelfAdjointView.h +262 -285
  105. package/eigen/Eigen/src/Core/SelfCwiseBinaryOp.h +23 -20
  106. package/eigen/Eigen/src/Core/SkewSymmetricMatrix3.h +382 -0
  107. package/eigen/Eigen/src/Core/Solve.h +97 -111
  108. package/eigen/Eigen/src/Core/SolveTriangular.h +131 -129
  109. package/eigen/Eigen/src/Core/SolverBase.h +138 -101
  110. package/eigen/Eigen/src/Core/StableNorm.h +156 -160
  111. package/eigen/Eigen/src/Core/StlIterators.h +619 -0
  112. package/eigen/Eigen/src/Core/Stride.h +91 -88
  113. package/eigen/Eigen/src/Core/Swap.h +70 -38
  114. package/eigen/Eigen/src/Core/Transpose.h +295 -273
  115. package/eigen/Eigen/src/Core/Transpositions.h +272 -317
  116. package/eigen/Eigen/src/Core/TriangularMatrix.h +670 -755
  117. package/eigen/Eigen/src/Core/VectorBlock.h +59 -72
  118. package/eigen/Eigen/src/Core/VectorwiseOp.h +668 -630
  119. package/eigen/Eigen/src/Core/Visitor.h +480 -216
  120. package/eigen/Eigen/src/Core/arch/AVX/Complex.h +407 -293
  121. package/eigen/Eigen/src/Core/arch/AVX/MathFunctions.h +79 -388
  122. package/eigen/Eigen/src/Core/arch/AVX/PacketMath.h +2935 -491
  123. package/eigen/Eigen/src/Core/arch/AVX/Reductions.h +353 -0
  124. package/eigen/Eigen/src/Core/arch/AVX/TypeCasting.h +279 -22
  125. package/eigen/Eigen/src/Core/arch/AVX512/Complex.h +472 -0
  126. package/eigen/Eigen/src/Core/arch/AVX512/GemmKernel.h +1245 -0
  127. package/eigen/Eigen/src/Core/arch/AVX512/MathFunctions.h +85 -333
  128. package/eigen/Eigen/src/Core/arch/AVX512/MathFunctionsFP16.h +75 -0
  129. package/eigen/Eigen/src/Core/arch/AVX512/PacketMath.h +2490 -649
  130. package/eigen/Eigen/src/Core/arch/AVX512/PacketMathFP16.h +1413 -0
  131. package/eigen/Eigen/src/Core/arch/AVX512/Reductions.h +297 -0
  132. package/eigen/Eigen/src/Core/arch/AVX512/TrsmKernel.h +1167 -0
  133. package/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc +1219 -0
  134. package/eigen/Eigen/src/Core/arch/AVX512/TypeCasting.h +277 -0
  135. package/eigen/Eigen/src/Core/arch/AVX512/TypeCastingFP16.h +130 -0
  136. package/eigen/Eigen/src/Core/arch/AltiVec/Complex.h +521 -298
  137. package/eigen/Eigen/src/Core/arch/AltiVec/MathFunctions.h +39 -280
  138. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProduct.h +3686 -0
  139. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h +205 -0
  140. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h +901 -0
  141. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h +742 -0
  142. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.inc +2818 -0
  143. package/eigen/Eigen/src/Core/arch/AltiVec/PacketMath.h +3391 -723
  144. package/eigen/Eigen/src/Core/arch/AltiVec/TypeCasting.h +153 -0
  145. package/eigen/Eigen/src/Core/arch/Default/BFloat16.h +866 -0
  146. package/eigen/Eigen/src/Core/arch/Default/ConjHelper.h +113 -14
  147. package/eigen/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +2634 -0
  148. package/eigen/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h +227 -0
  149. package/eigen/Eigen/src/Core/arch/Default/Half.h +1091 -0
  150. package/eigen/Eigen/src/Core/arch/Default/Settings.h +11 -13
  151. package/eigen/Eigen/src/Core/arch/GPU/Complex.h +244 -0
  152. package/eigen/Eigen/src/Core/arch/GPU/MathFunctions.h +104 -0
  153. package/eigen/Eigen/src/Core/arch/GPU/PacketMath.h +1712 -0
  154. package/eigen/Eigen/src/Core/arch/GPU/Tuple.h +268 -0
  155. package/eigen/Eigen/src/Core/arch/GPU/TypeCasting.h +77 -0
  156. package/eigen/Eigen/src/Core/arch/HIP/hcc/math_constants.h +23 -0
  157. package/eigen/Eigen/src/Core/arch/HVX/PacketMath.h +1088 -0
  158. package/eigen/Eigen/src/Core/arch/LSX/Complex.h +520 -0
  159. package/eigen/Eigen/src/Core/arch/LSX/GeneralBlockPanelKernel.h +23 -0
  160. package/eigen/Eigen/src/Core/arch/LSX/MathFunctions.h +43 -0
  161. package/eigen/Eigen/src/Core/arch/LSX/PacketMath.h +2866 -0
  162. package/eigen/Eigen/src/Core/arch/LSX/TypeCasting.h +526 -0
  163. package/eigen/Eigen/src/Core/arch/MSA/Complex.h +620 -0
  164. package/eigen/Eigen/src/Core/arch/MSA/MathFunctions.h +379 -0
  165. package/eigen/Eigen/src/Core/arch/MSA/PacketMath.h +1237 -0
  166. package/eigen/Eigen/src/Core/arch/NEON/Complex.h +531 -289
  167. package/eigen/Eigen/src/Core/arch/NEON/GeneralBlockPanelKernel.h +243 -0
  168. package/eigen/Eigen/src/Core/arch/NEON/MathFunctions.h +50 -73
  169. package/eigen/Eigen/src/Core/arch/NEON/PacketMath.h +5915 -579
  170. package/eigen/Eigen/src/Core/arch/NEON/TypeCasting.h +1642 -0
  171. package/eigen/Eigen/src/Core/arch/NEON/UnaryFunctors.h +57 -0
  172. package/eigen/Eigen/src/Core/arch/SSE/Complex.h +366 -334
  173. package/eigen/Eigen/src/Core/arch/SSE/MathFunctions.h +40 -514
  174. package/eigen/Eigen/src/Core/arch/SSE/PacketMath.h +2164 -675
  175. package/eigen/Eigen/src/Core/arch/SSE/Reductions.h +324 -0
  176. package/eigen/Eigen/src/Core/arch/SSE/TypeCasting.h +188 -35
  177. package/eigen/Eigen/src/Core/arch/SVE/MathFunctions.h +48 -0
  178. package/eigen/Eigen/src/Core/arch/SVE/PacketMath.h +674 -0
  179. package/eigen/Eigen/src/Core/arch/SVE/TypeCasting.h +52 -0
  180. package/eigen/Eigen/src/Core/arch/SYCL/InteropHeaders.h +227 -0
  181. package/eigen/Eigen/src/Core/arch/SYCL/MathFunctions.h +303 -0
  182. package/eigen/Eigen/src/Core/arch/SYCL/PacketMath.h +576 -0
  183. package/eigen/Eigen/src/Core/arch/SYCL/TypeCasting.h +83 -0
  184. package/eigen/Eigen/src/Core/arch/ZVector/Complex.h +434 -261
  185. package/eigen/Eigen/src/Core/arch/ZVector/MathFunctions.h +160 -53
  186. package/eigen/Eigen/src/Core/arch/ZVector/PacketMath.h +1073 -605
  187. package/eigen/Eigen/src/Core/functors/AssignmentFunctors.h +123 -117
  188. package/eigen/Eigen/src/Core/functors/BinaryFunctors.h +594 -322
  189. package/eigen/Eigen/src/Core/functors/NullaryFunctors.h +204 -118
  190. package/eigen/Eigen/src/Core/functors/StlFunctors.h +110 -97
  191. package/eigen/Eigen/src/Core/functors/TernaryFunctors.h +34 -7
  192. package/eigen/Eigen/src/Core/functors/UnaryFunctors.h +1158 -530
  193. package/eigen/Eigen/src/Core/products/GeneralBlockPanelKernel.h +2329 -1333
  194. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrix.h +328 -364
  195. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h +191 -178
  196. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h +85 -82
  197. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h +154 -73
  198. package/eigen/Eigen/src/Core/products/GeneralMatrixVector.h +396 -542
  199. package/eigen/Eigen/src/Core/products/GeneralMatrixVector_BLAS.h +80 -77
  200. package/eigen/Eigen/src/Core/products/Parallelizer.h +208 -92
  201. package/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix.h +331 -375
  202. package/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix_BLAS.h +206 -224
  203. package/eigen/Eigen/src/Core/products/SelfadjointMatrixVector.h +139 -146
  204. package/eigen/Eigen/src/Core/products/SelfadjointMatrixVector_BLAS.h +58 -61
  205. package/eigen/Eigen/src/Core/products/SelfadjointProduct.h +71 -71
  206. package/eigen/Eigen/src/Core/products/SelfadjointRank2Update.h +48 -46
  207. package/eigen/Eigen/src/Core/products/TriangularMatrixMatrix.h +294 -369
  208. package/eigen/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h +246 -238
  209. package/eigen/Eigen/src/Core/products/TriangularMatrixVector.h +244 -247
  210. package/eigen/Eigen/src/Core/products/TriangularMatrixVector_BLAS.h +212 -192
  211. package/eigen/Eigen/src/Core/products/TriangularSolverMatrix.h +328 -275
  212. package/eigen/Eigen/src/Core/products/TriangularSolverMatrix_BLAS.h +108 -109
  213. package/eigen/Eigen/src/Core/products/TriangularSolverVector.h +70 -93
  214. package/eigen/Eigen/src/Core/util/Assert.h +158 -0
  215. package/eigen/Eigen/src/Core/util/BlasUtil.h +413 -290
  216. package/eigen/Eigen/src/Core/util/ConfigureVectorization.h +543 -0
  217. package/eigen/Eigen/src/Core/util/Constants.h +314 -263
  218. package/eigen/Eigen/src/Core/util/DisableStupidWarnings.h +130 -78
  219. package/eigen/Eigen/src/Core/util/EmulateArray.h +270 -0
  220. package/eigen/Eigen/src/Core/util/ForwardDeclarations.h +450 -224
  221. package/eigen/Eigen/src/Core/util/GpuHipCudaDefines.inc +101 -0
  222. package/eigen/Eigen/src/Core/util/GpuHipCudaUndefines.inc +45 -0
  223. package/eigen/Eigen/src/Core/util/IndexedViewHelper.h +487 -0
  224. package/eigen/Eigen/src/Core/util/IntegralConstant.h +279 -0
  225. package/eigen/Eigen/src/Core/util/MKL_support.h +39 -30
  226. package/eigen/Eigen/src/Core/util/Macros.h +939 -646
  227. package/eigen/Eigen/src/Core/util/MaxSizeVector.h +139 -0
  228. package/eigen/Eigen/src/Core/util/Memory.h +1042 -650
  229. package/eigen/Eigen/src/Core/util/Meta.h +618 -426
  230. package/eigen/Eigen/src/Core/util/MoreMeta.h +638 -0
  231. package/eigen/Eigen/src/Core/util/ReenableStupidWarnings.h +32 -19
  232. package/eigen/Eigen/src/Core/util/ReshapedHelper.h +51 -0
  233. package/eigen/Eigen/src/Core/util/Serializer.h +209 -0
  234. package/eigen/Eigen/src/Core/util/StaticAssert.h +51 -164
  235. package/eigen/Eigen/src/Core/util/SymbolicIndex.h +445 -0
  236. package/eigen/Eigen/src/Core/util/XprHelper.h +793 -538
  237. package/eigen/Eigen/src/Eigenvalues/ComplexEigenSolver.h +246 -277
  238. package/eigen/Eigen/src/Eigenvalues/ComplexSchur.h +299 -319
  239. package/eigen/Eigen/src/Eigenvalues/ComplexSchur_LAPACKE.h +52 -48
  240. package/eigen/Eigen/src/Eigenvalues/EigenSolver.h +413 -456
  241. package/eigen/Eigen/src/Eigenvalues/GeneralizedEigenSolver.h +309 -325
  242. package/eigen/Eigen/src/Eigenvalues/GeneralizedSelfAdjointEigenSolver.h +157 -171
  243. package/eigen/Eigen/src/Eigenvalues/HessenbergDecomposition.h +292 -310
  244. package/eigen/Eigen/src/Eigenvalues/InternalHeaderCheck.h +3 -0
  245. package/eigen/Eigen/src/Eigenvalues/MatrixBaseEigenvalues.h +91 -107
  246. package/eigen/Eigen/src/Eigenvalues/RealQZ.h +539 -606
  247. package/eigen/Eigen/src/Eigenvalues/RealSchur.h +348 -382
  248. package/eigen/Eigen/src/Eigenvalues/RealSchur_LAPACKE.h +41 -35
  249. package/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver.h +579 -600
  250. package/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver_LAPACKE.h +47 -44
  251. package/eigen/Eigen/src/Eigenvalues/Tridiagonalization.h +434 -461
  252. package/eigen/Eigen/src/Geometry/AlignedBox.h +307 -214
  253. package/eigen/Eigen/src/Geometry/AngleAxis.h +135 -137
  254. package/eigen/Eigen/src/Geometry/EulerAngles.h +163 -74
  255. package/eigen/Eigen/src/Geometry/Homogeneous.h +289 -333
  256. package/eigen/Eigen/src/Geometry/Hyperplane.h +152 -161
  257. package/eigen/Eigen/src/Geometry/InternalHeaderCheck.h +3 -0
  258. package/eigen/Eigen/src/Geometry/OrthoMethods.h +168 -145
  259. package/eigen/Eigen/src/Geometry/ParametrizedLine.h +141 -104
  260. package/eigen/Eigen/src/Geometry/Quaternion.h +595 -497
  261. package/eigen/Eigen/src/Geometry/Rotation2D.h +110 -108
  262. package/eigen/Eigen/src/Geometry/RotationBase.h +148 -145
  263. package/eigen/Eigen/src/Geometry/Scaling.h +115 -90
  264. package/eigen/Eigen/src/Geometry/Transform.h +896 -953
  265. package/eigen/Eigen/src/Geometry/Translation.h +100 -98
  266. package/eigen/Eigen/src/Geometry/Umeyama.h +79 -84
  267. package/eigen/Eigen/src/Geometry/arch/Geometry_SIMD.h +154 -0
  268. package/eigen/Eigen/src/Householder/BlockHouseholder.h +54 -42
  269. package/eigen/Eigen/src/Householder/Householder.h +104 -122
  270. package/eigen/Eigen/src/Householder/HouseholderSequence.h +416 -382
  271. package/eigen/Eigen/src/Householder/InternalHeaderCheck.h +3 -0
  272. package/eigen/Eigen/src/IterativeLinearSolvers/BasicPreconditioners.h +153 -166
  273. package/eigen/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h +127 -138
  274. package/eigen/Eigen/src/IterativeLinearSolvers/ConjugateGradient.h +95 -124
  275. package/eigen/Eigen/src/IterativeLinearSolvers/IncompleteCholesky.h +269 -267
  276. package/eigen/Eigen/src/IterativeLinearSolvers/IncompleteLUT.h +246 -259
  277. package/eigen/Eigen/src/IterativeLinearSolvers/InternalHeaderCheck.h +3 -0
  278. package/eigen/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h +218 -217
  279. package/eigen/Eigen/src/IterativeLinearSolvers/LeastSquareConjugateGradient.h +80 -103
  280. package/eigen/Eigen/src/IterativeLinearSolvers/SolveWithGuess.h +59 -63
  281. package/eigen/Eigen/src/Jacobi/InternalHeaderCheck.h +3 -0
  282. package/eigen/Eigen/src/Jacobi/Jacobi.h +256 -291
  283. package/eigen/Eigen/src/KLUSupport/InternalHeaderCheck.h +3 -0
  284. package/eigen/Eigen/src/KLUSupport/KLUSupport.h +339 -0
  285. package/eigen/Eigen/src/LU/Determinant.h +60 -63
  286. package/eigen/Eigen/src/LU/FullPivLU.h +561 -626
  287. package/eigen/Eigen/src/LU/InternalHeaderCheck.h +3 -0
  288. package/eigen/Eigen/src/LU/InverseImpl.h +213 -275
  289. package/eigen/Eigen/src/LU/PartialPivLU.h +407 -435
  290. package/eigen/Eigen/src/LU/PartialPivLU_LAPACKE.h +54 -40
  291. package/eigen/Eigen/src/LU/arch/InverseSize4.h +353 -0
  292. package/eigen/Eigen/src/MetisSupport/InternalHeaderCheck.h +3 -0
  293. package/eigen/Eigen/src/MetisSupport/MetisSupport.h +81 -93
  294. package/eigen/Eigen/src/OrderingMethods/Amd.h +250 -282
  295. package/eigen/Eigen/src/OrderingMethods/Eigen_Colamd.h +950 -1103
  296. package/eigen/Eigen/src/OrderingMethods/InternalHeaderCheck.h +3 -0
  297. package/eigen/Eigen/src/OrderingMethods/Ordering.h +111 -122
  298. package/eigen/Eigen/src/PaStiXSupport/InternalHeaderCheck.h +3 -0
  299. package/eigen/Eigen/src/PaStiXSupport/PaStiXSupport.h +524 -570
  300. package/eigen/Eigen/src/PardisoSupport/InternalHeaderCheck.h +3 -0
  301. package/eigen/Eigen/src/PardisoSupport/PardisoSupport.h +385 -429
  302. package/eigen/Eigen/src/QR/ColPivHouseholderQR.h +494 -473
  303. package/eigen/Eigen/src/QR/ColPivHouseholderQR_LAPACKE.h +120 -56
  304. package/eigen/Eigen/src/QR/CompleteOrthogonalDecomposition.h +223 -137
  305. package/eigen/Eigen/src/QR/FullPivHouseholderQR.h +517 -460
  306. package/eigen/Eigen/src/QR/HouseholderQR.h +412 -278
  307. package/eigen/Eigen/src/QR/HouseholderQR_LAPACKE.h +32 -23
  308. package/eigen/Eigen/src/QR/InternalHeaderCheck.h +3 -0
  309. package/eigen/Eigen/src/SPQRSupport/InternalHeaderCheck.h +3 -0
  310. package/eigen/Eigen/src/SPQRSupport/SuiteSparseQRSupport.h +263 -261
  311. package/eigen/Eigen/src/SVD/BDCSVD.h +872 -679
  312. package/eigen/Eigen/src/SVD/BDCSVD_LAPACKE.h +174 -0
  313. package/eigen/Eigen/src/SVD/InternalHeaderCheck.h +3 -0
  314. package/eigen/Eigen/src/SVD/JacobiSVD.h +585 -543
  315. package/eigen/Eigen/src/SVD/JacobiSVD_LAPACKE.h +85 -49
  316. package/eigen/Eigen/src/SVD/SVDBase.h +281 -160
  317. package/eigen/Eigen/src/SVD/UpperBidiagonalization.h +202 -237
  318. package/eigen/Eigen/src/SparseCholesky/InternalHeaderCheck.h +3 -0
  319. package/eigen/Eigen/src/SparseCholesky/SimplicialCholesky.h +769 -590
  320. package/eigen/Eigen/src/SparseCholesky/SimplicialCholesky_impl.h +318 -129
  321. package/eigen/Eigen/src/SparseCore/AmbiVector.h +202 -251
  322. package/eigen/Eigen/src/SparseCore/CompressedStorage.h +184 -236
  323. package/eigen/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h +140 -184
  324. package/eigen/Eigen/src/SparseCore/InternalHeaderCheck.h +3 -0
  325. package/eigen/Eigen/src/SparseCore/SparseAssign.h +174 -111
  326. package/eigen/Eigen/src/SparseCore/SparseBlock.h +408 -477
  327. package/eigen/Eigen/src/SparseCore/SparseColEtree.h +100 -112
  328. package/eigen/Eigen/src/SparseCore/SparseCompressedBase.h +531 -280
  329. package/eigen/Eigen/src/SparseCore/SparseCwiseBinaryOp.h +559 -347
  330. package/eigen/Eigen/src/SparseCore/SparseCwiseUnaryOp.h +100 -108
  331. package/eigen/Eigen/src/SparseCore/SparseDenseProduct.h +185 -191
  332. package/eigen/Eigen/src/SparseCore/SparseDiagonalProduct.h +71 -71
  333. package/eigen/Eigen/src/SparseCore/SparseDot.h +49 -47
  334. package/eigen/Eigen/src/SparseCore/SparseFuzzy.h +13 -11
  335. package/eigen/Eigen/src/SparseCore/SparseMap.h +243 -253
  336. package/eigen/Eigen/src/SparseCore/SparseMatrix.h +1614 -1142
  337. package/eigen/Eigen/src/SparseCore/SparseMatrixBase.h +403 -357
  338. package/eigen/Eigen/src/SparseCore/SparsePermutation.h +186 -115
  339. package/eigen/Eigen/src/SparseCore/SparseProduct.h +100 -91
  340. package/eigen/Eigen/src/SparseCore/SparseRedux.h +22 -24
  341. package/eigen/Eigen/src/SparseCore/SparseRef.h +268 -295
  342. package/eigen/Eigen/src/SparseCore/SparseSelfAdjointView.h +371 -414
  343. package/eigen/Eigen/src/SparseCore/SparseSolverBase.h +78 -87
  344. package/eigen/Eigen/src/SparseCore/SparseSparseProductWithPruning.h +81 -95
  345. package/eigen/Eigen/src/SparseCore/SparseTranspose.h +62 -71
  346. package/eigen/Eigen/src/SparseCore/SparseTriangularView.h +132 -144
  347. package/eigen/Eigen/src/SparseCore/SparseUtil.h +146 -115
  348. package/eigen/Eigen/src/SparseCore/SparseVector.h +426 -372
  349. package/eigen/Eigen/src/SparseCore/SparseView.h +164 -193
  350. package/eigen/Eigen/src/SparseCore/TriangularSolver.h +129 -170
  351. package/eigen/Eigen/src/SparseLU/InternalHeaderCheck.h +3 -0
  352. package/eigen/Eigen/src/SparseLU/SparseLU.h +814 -618
  353. package/eigen/Eigen/src/SparseLU/SparseLUImpl.h +61 -48
  354. package/eigen/Eigen/src/SparseLU/SparseLU_Memory.h +102 -118
  355. package/eigen/Eigen/src/SparseLU/SparseLU_Structs.h +38 -35
  356. package/eigen/Eigen/src/SparseLU/SparseLU_SupernodalMatrix.h +273 -255
  357. package/eigen/Eigen/src/SparseLU/SparseLU_Utils.h +44 -49
  358. package/eigen/Eigen/src/SparseLU/SparseLU_column_bmod.h +104 -108
  359. package/eigen/Eigen/src/SparseLU/SparseLU_column_dfs.h +90 -101
  360. package/eigen/Eigen/src/SparseLU/SparseLU_copy_to_ucol.h +57 -58
  361. package/eigen/Eigen/src/SparseLU/SparseLU_heap_relax_snode.h +43 -55
  362. package/eigen/Eigen/src/SparseLU/SparseLU_kernel_bmod.h +74 -71
  363. package/eigen/Eigen/src/SparseLU/SparseLU_panel_bmod.h +125 -133
  364. package/eigen/Eigen/src/SparseLU/SparseLU_panel_dfs.h +136 -159
  365. package/eigen/Eigen/src/SparseLU/SparseLU_pivotL.h +51 -52
  366. package/eigen/Eigen/src/SparseLU/SparseLU_pruneL.h +67 -73
  367. package/eigen/Eigen/src/SparseLU/SparseLU_relax_snode.h +24 -26
  368. package/eigen/Eigen/src/SparseQR/InternalHeaderCheck.h +3 -0
  369. package/eigen/Eigen/src/SparseQR/SparseQR.h +451 -490
  370. package/eigen/Eigen/src/StlSupport/StdDeque.h +28 -105
  371. package/eigen/Eigen/src/StlSupport/StdList.h +28 -84
  372. package/eigen/Eigen/src/StlSupport/StdVector.h +28 -108
  373. package/eigen/Eigen/src/StlSupport/details.h +48 -50
  374. package/eigen/Eigen/src/SuperLUSupport/InternalHeaderCheck.h +3 -0
  375. package/eigen/Eigen/src/SuperLUSupport/SuperLUSupport.h +634 -732
  376. package/eigen/Eigen/src/ThreadPool/Barrier.h +70 -0
  377. package/eigen/Eigen/src/ThreadPool/CoreThreadPoolDevice.h +336 -0
  378. package/eigen/Eigen/src/ThreadPool/EventCount.h +241 -0
  379. package/eigen/Eigen/src/ThreadPool/ForkJoin.h +140 -0
  380. package/eigen/Eigen/src/ThreadPool/InternalHeaderCheck.h +4 -0
  381. package/eigen/Eigen/src/ThreadPool/NonBlockingThreadPool.h +587 -0
  382. package/eigen/Eigen/src/ThreadPool/RunQueue.h +230 -0
  383. package/eigen/Eigen/src/ThreadPool/ThreadCancel.h +21 -0
  384. package/eigen/Eigen/src/ThreadPool/ThreadEnvironment.h +43 -0
  385. package/eigen/Eigen/src/ThreadPool/ThreadLocal.h +289 -0
  386. package/eigen/Eigen/src/ThreadPool/ThreadPoolInterface.h +50 -0
  387. package/eigen/Eigen/src/ThreadPool/ThreadYield.h +16 -0
  388. package/eigen/Eigen/src/UmfPackSupport/InternalHeaderCheck.h +3 -0
  389. package/eigen/Eigen/src/UmfPackSupport/UmfPackSupport.h +480 -380
  390. package/eigen/Eigen/src/misc/Image.h +41 -43
  391. package/eigen/Eigen/src/misc/InternalHeaderCheck.h +3 -0
  392. package/eigen/Eigen/src/misc/Kernel.h +39 -41
  393. package/eigen/Eigen/src/misc/RealSvd2x2.h +19 -21
  394. package/eigen/Eigen/src/misc/blas.h +83 -426
  395. package/eigen/Eigen/src/misc/lapacke.h +9976 -16182
  396. package/eigen/Eigen/src/misc/lapacke_helpers.h +163 -0
  397. package/eigen/Eigen/src/misc/lapacke_mangling.h +4 -5
  398. package/eigen/Eigen/src/plugins/ArrayCwiseBinaryOps.inc +344 -0
  399. package/eigen/Eigen/src/plugins/ArrayCwiseUnaryOps.inc +544 -0
  400. package/eigen/Eigen/src/plugins/BlockMethods.inc +1370 -0
  401. package/eigen/Eigen/src/plugins/CommonCwiseBinaryOps.inc +116 -0
  402. package/eigen/Eigen/src/plugins/CommonCwiseUnaryOps.inc +167 -0
  403. package/eigen/Eigen/src/plugins/IndexedViewMethods.inc +192 -0
  404. package/eigen/Eigen/src/plugins/InternalHeaderCheck.inc +3 -0
  405. package/eigen/Eigen/src/plugins/MatrixCwiseBinaryOps.inc +331 -0
  406. package/eigen/Eigen/src/plugins/MatrixCwiseUnaryOps.inc +118 -0
  407. package/eigen/Eigen/src/plugins/ReshapedMethods.inc +133 -0
  408. package/lib/LibEigen.d.ts +4 -0
  409. package/lib/LibEigen.js +14 -0
  410. package/lib/index.d.ts +1 -1
  411. package/lib/index.js +7 -3
  412. package/package.json +2 -10
  413. package/eigen/Eigen/CMakeLists.txt +0 -19
  414. package/eigen/Eigen/src/Core/BooleanRedux.h +0 -164
  415. package/eigen/Eigen/src/Core/arch/CUDA/Complex.h +0 -103
  416. package/eigen/Eigen/src/Core/arch/CUDA/Half.h +0 -675
  417. package/eigen/Eigen/src/Core/arch/CUDA/MathFunctions.h +0 -91
  418. package/eigen/Eigen/src/Core/arch/CUDA/PacketMath.h +0 -333
  419. package/eigen/Eigen/src/Core/arch/CUDA/PacketMathHalf.h +0 -1124
  420. package/eigen/Eigen/src/Core/arch/CUDA/TypeCasting.h +0 -212
  421. package/eigen/Eigen/src/Core/util/NonMPL2.h +0 -3
  422. package/eigen/Eigen/src/Geometry/arch/Geometry_SSE.h +0 -161
  423. package/eigen/Eigen/src/LU/arch/Inverse_SSE.h +0 -338
  424. package/eigen/Eigen/src/SparseCore/MappedSparseMatrix.h +0 -67
  425. package/eigen/Eigen/src/SparseLU/SparseLU_gemm_kernel.h +0 -280
  426. package/eigen/Eigen/src/misc/lapack.h +0 -152
  427. package/eigen/Eigen/src/plugins/ArrayCwiseBinaryOps.h +0 -332
  428. package/eigen/Eigen/src/plugins/ArrayCwiseUnaryOps.h +0 -552
  429. package/eigen/Eigen/src/plugins/BlockMethods.h +0 -1058
  430. package/eigen/Eigen/src/plugins/CommonCwiseBinaryOps.h +0 -115
  431. package/eigen/Eigen/src/plugins/CommonCwiseUnaryOps.h +0 -163
  432. package/eigen/Eigen/src/plugins/MatrixCwiseBinaryOps.h +0 -152
  433. package/eigen/Eigen/src/plugins/MatrixCwiseUnaryOps.h +0 -85
  434. package/lib/eigen.d.ts +0 -2
  435. package/lib/eigen.js +0 -15
@@ -0,0 +1,1219 @@
1
+ // This file is part of Eigen, a lightweight C++ template library
2
+ // for linear algebra.
3
+ //
4
+ // Copyright (C) 2022 Intel Corporation
5
+ //
6
+ // This Source Code Form is subject to the terms of the Mozilla
7
+ // Public License v. 2.0. If a copy of the MPL was not distributed
8
+ // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
+
10
+ #ifndef EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H
11
+ #define EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H
12
+
13
+ template <bool isARowMajor = true>
14
+ EIGEN_ALWAYS_INLINE int64_t idA(int64_t i, int64_t j, int64_t LDA) {
15
+ EIGEN_IF_CONSTEXPR(isARowMajor) return i * LDA + j;
16
+ else return i + j * LDA;
17
+ }
18
+
19
+ /**
20
+ * This namespace contains various classes used to generate compile-time unrolls which are
21
+ * used throughout the trsm/gemm kernels. The unrolls are characterized as for-loops (1-D), nested
22
+ * for-loops (2-D), or triple nested for-loops (3-D). Unrolls are generated using template recursion
23
+ *
24
+ * Example, the 2-D for-loop is unrolled recursively by first flattening to a 1-D loop.
25
+ *
26
+ * for(startI = 0; startI < endI; startI++) for(startC = 0; startC < endI*endJ; startC++)
27
+ * for(startJ = 0; startJ < endJ; startJ++) ----> startI = (startC)/(endJ)
28
+ * func(startI,startJ) startJ = (startC)%(endJ)
29
+ * func(...)
30
+ *
31
+ * The 1-D loop can be unrolled recursively by using enable_if and defining an auxiliary function
32
+ * with a template parameter used as a counter.
33
+ *
34
+ * template <endI, endJ, counter>
35
+ * std::enable_if_t<(counter <= 0)> <---- tail case.
36
+ * aux_func {}
37
+ *
38
+ * template <endI, endJ, counter>
39
+ * std::enable_if_t<(counter > 0)> <---- actual for-loop
40
+ * aux_func {
41
+ * startC = endI*endJ - counter
42
+ * startI = (startC)/(endJ)
43
+ * startJ = (startC)%(endJ)
44
+ * func(startI, startJ)
45
+ * aux_func<endI, endJ, counter-1>()
46
+ * }
47
+ *
48
+ * Note: Additional wrapper functions are provided for aux_func which hides the counter template
49
+ * parameter since counter usually depends on endI, endJ, etc...
50
+ *
51
+ * Conventions:
52
+ * 1) endX: specifies the terminal value for the for-loop, (ex: for(startX = 0; startX < endX; startX++))
53
+ *
54
+ * 2) rem, remM, remK template parameters are used for deciding whether to use masked operations for
55
+ * handling remaining tails (when sizes are not multiples of PacketSize or EIGEN_AVX_MAX_NUM_ROW)
56
+ */
57
+ namespace unrolls {
58
+
59
+ template <int64_t N>
60
+ EIGEN_ALWAYS_INLINE auto remMask(int64_t m) {
61
+ EIGEN_IF_CONSTEXPR(N == 16) { return 0xFFFF >> (16 - m); }
62
+ else EIGEN_IF_CONSTEXPR(N == 8) {
63
+ return 0xFF >> (8 - m);
64
+ }
65
+ else EIGEN_IF_CONSTEXPR(N == 4) {
66
+ return 0x0F >> (4 - m);
67
+ }
68
+ return 0;
69
+ }
70
+
71
+ template <typename Packet>
72
+ EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet, 8> &kernel);
73
+
74
+ template <>
75
+ EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet16f, 8> &kernel) {
76
+ __m512 T0 = _mm512_unpacklo_ps(kernel.packet[0], kernel.packet[1]);
77
+ __m512 T1 = _mm512_unpackhi_ps(kernel.packet[0], kernel.packet[1]);
78
+ __m512 T2 = _mm512_unpacklo_ps(kernel.packet[2], kernel.packet[3]);
79
+ __m512 T3 = _mm512_unpackhi_ps(kernel.packet[2], kernel.packet[3]);
80
+ __m512 T4 = _mm512_unpacklo_ps(kernel.packet[4], kernel.packet[5]);
81
+ __m512 T5 = _mm512_unpackhi_ps(kernel.packet[4], kernel.packet[5]);
82
+ __m512 T6 = _mm512_unpacklo_ps(kernel.packet[6], kernel.packet[7]);
83
+ __m512 T7 = _mm512_unpackhi_ps(kernel.packet[6], kernel.packet[7]);
84
+
85
+ kernel.packet[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2)));
86
+ kernel.packet[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2)));
87
+ kernel.packet[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3)));
88
+ kernel.packet[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3)));
89
+ kernel.packet[4] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6)));
90
+ kernel.packet[5] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6)));
91
+ kernel.packet[6] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7)));
92
+ kernel.packet[7] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7)));
93
+
94
+ T0 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[4]), 0x4E));
95
+ T0 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[0], T0);
96
+ T4 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[0]), 0x4E));
97
+ T4 = _mm512_mask_blend_ps(0xF0F0, T4, kernel.packet[4]);
98
+ T1 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[5]), 0x4E));
99
+ T1 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[1], T1);
100
+ T5 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[1]), 0x4E));
101
+ T5 = _mm512_mask_blend_ps(0xF0F0, T5, kernel.packet[5]);
102
+ T2 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[6]), 0x4E));
103
+ T2 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[2], T2);
104
+ T6 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[2]), 0x4E));
105
+ T6 = _mm512_mask_blend_ps(0xF0F0, T6, kernel.packet[6]);
106
+ T3 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[7]), 0x4E));
107
+ T3 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[3], T3);
108
+ T7 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[3]), 0x4E));
109
+ T7 = _mm512_mask_blend_ps(0xF0F0, T7, kernel.packet[7]);
110
+
111
+ kernel.packet[0] = T0;
112
+ kernel.packet[1] = T1;
113
+ kernel.packet[2] = T2;
114
+ kernel.packet[3] = T3;
115
+ kernel.packet[4] = T4;
116
+ kernel.packet[5] = T5;
117
+ kernel.packet[6] = T6;
118
+ kernel.packet[7] = T7;
119
+ }
120
+
121
+ template <>
122
+ EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet8d, 8> &kernel) {
123
+ ptranspose(kernel);
124
+ }
125
+
126
+ /***
127
+ * Unrolls for transposed C stores
128
+ */
129
+ template <typename Scalar>
130
+ class trans {
131
+ public:
132
+ using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
133
+ using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecFullDouble>::type;
134
+ static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
135
+
136
+ /***********************************
137
+ * Auxiliary Functions for:
138
+ * - storeC
139
+ ***********************************
140
+ */
141
+
142
+ /**
143
+ * aux_storeC
144
+ *
145
+ * 1-D unroll
146
+ * for(startN = 0; startN < endN; startN++)
147
+ *
148
+ * (endN <= PacketSize) is required to handle the fp32 case, see comments in transStoreC
149
+ *
150
+ **/
151
+ template <int64_t endN, int64_t counter, int64_t unrollN, int64_t packetIndexOffset, bool remM>
152
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && endN <= PacketSize)> aux_storeC(
153
+ Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t remM_ = 0) {
154
+ constexpr int64_t counterReverse = endN - counter;
155
+ constexpr int64_t startN = counterReverse;
156
+
157
+ EIGEN_IF_CONSTEXPR(startN < EIGEN_AVX_MAX_NUM_ROW) {
158
+ EIGEN_IF_CONSTEXPR(remM) {
159
+ pstoreu<Scalar>(
160
+ C_arr + LDC * startN,
161
+ padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN, remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
162
+ preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN / PacketSize) * startN]),
163
+ remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
164
+ remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
165
+ }
166
+ else {
167
+ pstoreu<Scalar>(C_arr + LDC * startN,
168
+ padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN),
169
+ preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN / PacketSize) * startN])));
170
+ }
171
+ }
172
+ else { // This block is only needed for fp32 case
173
+ // Reinterpret as __m512 for _mm512_shuffle_f32x4
174
+ vecFullFloat zmm2vecFullFloat = preinterpret<vecFullFloat>(
175
+ zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)]);
176
+ // Swap lower and upper half of avx register.
177
+ zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)] =
178
+ preinterpret<vec>(_mm512_shuffle_f32x4(zmm2vecFullFloat, zmm2vecFullFloat, 0b01001110));
179
+
180
+ EIGEN_IF_CONSTEXPR(remM) {
181
+ pstoreu<Scalar>(
182
+ C_arr + LDC * startN,
183
+ padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN, remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
184
+ preinterpret<vecHalf>(
185
+ zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)])),
186
+ remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
187
+ }
188
+ else {
189
+ pstoreu<Scalar>(
190
+ C_arr + LDC * startN,
191
+ padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN),
192
+ preinterpret<vecHalf>(
193
+ zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)])));
194
+ }
195
+ }
196
+ aux_storeC<endN, counter - 1, unrollN, packetIndexOffset, remM>(C_arr, LDC, zmm, remM_);
197
+ }
198
+
199
+ template <int64_t endN, int64_t counter, int64_t unrollN, int64_t packetIndexOffset, bool remM>
200
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0 && endN <= PacketSize)> aux_storeC(
201
+ Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t remM_ = 0) {
202
+ EIGEN_UNUSED_VARIABLE(C_arr);
203
+ EIGEN_UNUSED_VARIABLE(LDC);
204
+ EIGEN_UNUSED_VARIABLE(zmm);
205
+ EIGEN_UNUSED_VARIABLE(remM_);
206
+ }
207
+
208
+ template <int64_t endN, int64_t unrollN, int64_t packetIndexOffset, bool remM>
209
+ static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC,
210
+ PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
211
+ int64_t remM_ = 0) {
212
+ aux_storeC<endN, endN, unrollN, packetIndexOffset, remM>(C_arr, LDC, zmm, remM_);
213
+ }
214
+
215
+ /**
216
+ * Transposes LxunrollN row major block of matrices stored `EIGEN_AVX_MAX_NUM_ACC` zmm registers to
217
+ * "unrollN"xL ymm registers to be stored col-major into C.
218
+ *
219
+ * For 8x48, the 8x48 block (row-major) is stored in zmm as follows:
220
+ *
221
+ * ```
222
+ * row0: zmm0 zmm1 zmm2
223
+ * row1: zmm3 zmm4 zmm5
224
+ * .
225
+ * .
226
+ * row7: zmm21 zmm22 zmm23
227
+ *
228
+ * For 8x32, the 8x32 block (row-major) is stored in zmm as follows:
229
+ *
230
+ * row0: zmm0 zmm1
231
+ * row1: zmm2 zmm3
232
+ * .
233
+ * .
234
+ * row7: zmm14 zmm15
235
+ * ```
236
+ *
237
+ * In general we will have {1,2,3} groups of avx registers each of size
238
+ * `EIGEN_AVX_MAX_NUM_ROW`. packetIndexOffset is used to select which "block" of
239
+ * avx registers are being transposed.
240
+ */
241
+ template <int64_t unrollN, int64_t packetIndexOffset>
242
+ static EIGEN_ALWAYS_INLINE void transpose(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
243
+ // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted
244
+ // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller.
245
+ constexpr int64_t zmmStride = unrollN / PacketSize;
246
+ PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> r;
247
+ r.packet[0] = zmm.packet[packetIndexOffset + zmmStride * 0];
248
+ r.packet[1] = zmm.packet[packetIndexOffset + zmmStride * 1];
249
+ r.packet[2] = zmm.packet[packetIndexOffset + zmmStride * 2];
250
+ r.packet[3] = zmm.packet[packetIndexOffset + zmmStride * 3];
251
+ r.packet[4] = zmm.packet[packetIndexOffset + zmmStride * 4];
252
+ r.packet[5] = zmm.packet[packetIndexOffset + zmmStride * 5];
253
+ r.packet[6] = zmm.packet[packetIndexOffset + zmmStride * 6];
254
+ r.packet[7] = zmm.packet[packetIndexOffset + zmmStride * 7];
255
+ trans8x8blocks(r);
256
+ zmm.packet[packetIndexOffset + zmmStride * 0] = r.packet[0];
257
+ zmm.packet[packetIndexOffset + zmmStride * 1] = r.packet[1];
258
+ zmm.packet[packetIndexOffset + zmmStride * 2] = r.packet[2];
259
+ zmm.packet[packetIndexOffset + zmmStride * 3] = r.packet[3];
260
+ zmm.packet[packetIndexOffset + zmmStride * 4] = r.packet[4];
261
+ zmm.packet[packetIndexOffset + zmmStride * 5] = r.packet[5];
262
+ zmm.packet[packetIndexOffset + zmmStride * 6] = r.packet[6];
263
+ zmm.packet[packetIndexOffset + zmmStride * 7] = r.packet[7];
264
+ }
265
+ };
266
+
267
+ /**
268
+ * Unrolls for copyBToRowMajor
269
+ *
270
+ * Idea:
271
+ * 1) Load a block of right-hand sides to registers (using loadB).
272
+ * 2) Convert the block from column-major to row-major (transposeLxL)
273
+ * 3) Store the blocks from register either to a temp array (toTemp == true), or back to B (toTemp == false).
274
+ *
275
+ * We use at most EIGEN_AVX_MAX_NUM_ACC avx registers to store the blocks of B. The remaining registers are
276
+ * used as temps for transposing.
277
+ *
278
+ * Blocks will be of size Lx{U1,U2,U3}. packetIndexOffset is used to index between these subblocks
279
+ * For fp32, PacketSize = 2*EIGEN_AVX_MAX_NUM_ROW, so we reinterpret packets as packets half the size (zmm -> ymm).
280
+ */
281
+ template <typename Scalar>
282
+ class transB {
283
+ public:
284
+ using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
285
+ using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecFullDouble>::type;
286
+ static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
287
+
288
+ /***********************************
289
+ * Auxiliary Functions for:
290
+ * - loadB
291
+ * - storeB
292
+ * - loadBBlock
293
+ * - storeBBlock
294
+ ***********************************
295
+ */
296
+
297
+ /**
298
+ * aux_loadB
299
+ *
300
+ * 1-D unroll
301
+ * for(startN = 0; startN < endN; startN++)
302
+ **/
303
+ template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM, int64_t remN_>
304
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB(
305
+ Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
306
+ int64_t remM_ = 0) {
307
+ constexpr int64_t counterReverse = endN - counter;
308
+ constexpr int64_t startN = counterReverse;
309
+
310
+ EIGEN_IF_CONSTEXPR(remM) {
311
+ ymm.packet[packetIndexOffset + startN] =
312
+ ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB], remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
313
+ }
314
+ else {
315
+ EIGEN_IF_CONSTEXPR(remN_ == 0) {
316
+ ymm.packet[packetIndexOffset + startN] = ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB]);
317
+ }
318
+ else ymm.packet[packetIndexOffset + startN] =
319
+ ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB], remMask<EIGEN_AVX_MAX_NUM_ROW>(remN_));
320
+ }
321
+
322
+ aux_loadB<endN, counter - 1, packetIndexOffset, remM, remN_>(B_arr, LDB, ymm, remM_);
323
+ }
324
+
325
+ template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM, int64_t remN_>
326
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB(
327
+ Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
328
+ int64_t remM_ = 0) {
329
+ EIGEN_UNUSED_VARIABLE(B_arr);
330
+ EIGEN_UNUSED_VARIABLE(LDB);
331
+ EIGEN_UNUSED_VARIABLE(ymm);
332
+ EIGEN_UNUSED_VARIABLE(remM_);
333
+ }
334
+
335
+ /**
336
+ * aux_storeB
337
+ *
338
+ * 1-D unroll
339
+ * for(startN = 0; startN < endN; startN++)
340
+ **/
341
+ template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remK, bool remM>
342
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeB(
343
+ Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t rem_ = 0) {
344
+ constexpr int64_t counterReverse = endN - counter;
345
+ constexpr int64_t startN = counterReverse;
346
+
347
+ EIGEN_IF_CONSTEXPR(remK || remM) {
348
+ pstoreu<Scalar>(&B_arr[startN * LDB], ymm.packet[packetIndexOffset + startN],
349
+ remMask<EIGEN_AVX_MAX_NUM_ROW>(rem_));
350
+ }
351
+ else {
352
+ pstoreu<Scalar>(&B_arr[startN * LDB], ymm.packet[packetIndexOffset + startN]);
353
+ }
354
+
355
+ aux_storeB<endN, counter - 1, packetIndexOffset, remK, remM>(B_arr, LDB, ymm, rem_);
356
+ }
357
+
358
+ template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remK, bool remM>
359
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeB(
360
+ Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t rem_ = 0) {
361
+ EIGEN_UNUSED_VARIABLE(B_arr);
362
+ EIGEN_UNUSED_VARIABLE(LDB);
363
+ EIGEN_UNUSED_VARIABLE(ymm);
364
+ EIGEN_UNUSED_VARIABLE(rem_);
365
+ }
366
+
367
+ /**
368
+ * aux_loadBBlock
369
+ *
370
+ * 1-D unroll
371
+ * for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW)
372
+ **/
373
+ template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remN_>
374
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadBBlock(
375
+ Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
376
+ PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
377
+ constexpr int64_t counterReverse = endN - counter;
378
+ constexpr int64_t startN = counterReverse;
379
+ transB::template loadB<EIGEN_AVX_MAX_NUM_ROW, startN, false, (toTemp ? 0 : remN_)>(&B_temp[startN], LDB_, ymm);
380
+ aux_loadBBlock<endN, counter - EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, remN_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
381
+ }
382
+
383
+ template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remN_>
384
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadBBlock(
385
+ Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
386
+ PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
387
+ EIGEN_UNUSED_VARIABLE(B_arr);
388
+ EIGEN_UNUSED_VARIABLE(LDB);
389
+ EIGEN_UNUSED_VARIABLE(B_temp);
390
+ EIGEN_UNUSED_VARIABLE(LDB_);
391
+ EIGEN_UNUSED_VARIABLE(ymm);
392
+ EIGEN_UNUSED_VARIABLE(remM_);
393
+ }
394
+
395
+ /**
396
+ * aux_storeBBlock
397
+ *
398
+ * 1-D unroll
399
+ * for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW)
400
+ **/
401
+ template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remK_>
402
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeBBlock(
403
+ Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
404
+ PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
405
+ constexpr int64_t counterReverse = endN - counter;
406
+ constexpr int64_t startN = counterReverse;
407
+
408
+ EIGEN_IF_CONSTEXPR(toTemp) {
409
+ transB::template storeB<EIGEN_AVX_MAX_NUM_ROW, startN, remK_ != 0, false>(&B_temp[startN], LDB_, ymm, remK_);
410
+ }
411
+ else {
412
+ transB::template storeB<std::min(EIGEN_AVX_MAX_NUM_ROW, endN), startN, false, remM>(&B_arr[0 + startN * LDB], LDB,
413
+ ymm, remM_);
414
+ }
415
+ aux_storeBBlock<endN, counter - EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, remK_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
416
+ }
417
+
418
+ template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remK_>
419
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeBBlock(
420
+ Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
421
+ PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
422
+ EIGEN_UNUSED_VARIABLE(B_arr);
423
+ EIGEN_UNUSED_VARIABLE(LDB);
424
+ EIGEN_UNUSED_VARIABLE(B_temp);
425
+ EIGEN_UNUSED_VARIABLE(LDB_);
426
+ EIGEN_UNUSED_VARIABLE(ymm);
427
+ EIGEN_UNUSED_VARIABLE(remM_);
428
+ }
429
+
430
+ /********************************************************
431
+ * Wrappers for aux_XXXX to hide counter parameter
432
+ ********************************************************/
433
+
434
+ template <int64_t endN, int64_t packetIndexOffset, bool remM, int64_t remN_>
435
+ static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_arr, int64_t LDB,
436
+ PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
437
+ int64_t remM_ = 0) {
438
+ aux_loadB<endN, endN, packetIndexOffset, remM, remN_>(B_arr, LDB, ymm, remM_);
439
+ }
440
+
441
+ template <int64_t endN, int64_t packetIndexOffset, bool remK, bool remM>
442
+ static EIGEN_ALWAYS_INLINE void storeB(Scalar *B_arr, int64_t LDB,
443
+ PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
444
+ int64_t rem_ = 0) {
445
+ aux_storeB<endN, endN, packetIndexOffset, remK, remM>(B_arr, LDB, ymm, rem_);
446
+ }
447
+
448
+ template <int64_t unrollN, bool toTemp, bool remM, int64_t remN_ = 0>
449
+ static EIGEN_ALWAYS_INLINE void loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
450
+ PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
451
+ int64_t remM_ = 0) {
452
+ EIGEN_IF_CONSTEXPR(toTemp) { transB::template loadB<unrollN, 0, remM, 0>(&B_arr[0], LDB, ymm, remM_); }
453
+ else {
454
+ aux_loadBBlock<unrollN, unrollN, toTemp, remM, remN_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
455
+ }
456
+ }
457
+
458
+ template <int64_t unrollN, bool toTemp, bool remM, int64_t remK_>
459
+ static EIGEN_ALWAYS_INLINE void storeBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
460
+ PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
461
+ int64_t remM_ = 0) {
462
+ aux_storeBBlock<unrollN, unrollN, toTemp, remM, remK_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
463
+ }
464
+
465
+ template <int64_t packetIndexOffset>
466
+ static EIGEN_ALWAYS_INLINE void transposeLxL(PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm) {
467
+ // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted
468
+ // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller.
469
+ PacketBlock<vecHalf, EIGEN_AVX_MAX_NUM_ROW> r;
470
+ r.packet[0] = ymm.packet[packetIndexOffset + 0];
471
+ r.packet[1] = ymm.packet[packetIndexOffset + 1];
472
+ r.packet[2] = ymm.packet[packetIndexOffset + 2];
473
+ r.packet[3] = ymm.packet[packetIndexOffset + 3];
474
+ r.packet[4] = ymm.packet[packetIndexOffset + 4];
475
+ r.packet[5] = ymm.packet[packetIndexOffset + 5];
476
+ r.packet[6] = ymm.packet[packetIndexOffset + 6];
477
+ r.packet[7] = ymm.packet[packetIndexOffset + 7];
478
+ ptranspose(r);
479
+ ymm.packet[packetIndexOffset + 0] = r.packet[0];
480
+ ymm.packet[packetIndexOffset + 1] = r.packet[1];
481
+ ymm.packet[packetIndexOffset + 2] = r.packet[2];
482
+ ymm.packet[packetIndexOffset + 3] = r.packet[3];
483
+ ymm.packet[packetIndexOffset + 4] = r.packet[4];
484
+ ymm.packet[packetIndexOffset + 5] = r.packet[5];
485
+ ymm.packet[packetIndexOffset + 6] = r.packet[6];
486
+ ymm.packet[packetIndexOffset + 7] = r.packet[7];
487
+ }
488
+
489
+ template <int64_t unrollN, bool toTemp, bool remM>
490
+ static EIGEN_ALWAYS_INLINE void transB_kernel(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
491
+ PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
492
+ int64_t remM_ = 0) {
493
+ constexpr int64_t U3 = PacketSize * 3;
494
+ constexpr int64_t U2 = PacketSize * 2;
495
+ constexpr int64_t U1 = PacketSize * 1;
496
+ /**
497
+ * Unrolls needed for each case:
498
+ * - AVX512 fp32 48 32 16 8 4 2 1
499
+ * - AVX512 fp64 24 16 8 4 2 1
500
+ *
501
+ * For fp32 L and U1 are 1:2 so for U3/U2 cases the loads/stores need to be split up.
502
+ */
503
+ EIGEN_IF_CONSTEXPR(unrollN == U3) {
504
+ // load LxU3 B col major, transpose LxU3 row major
505
+ constexpr int64_t maxUBlock = std::min(3 * EIGEN_AVX_MAX_NUM_ROW, U3);
506
+ transB::template loadBBlock<maxUBlock, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
507
+ transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
508
+ transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
509
+ transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
510
+ transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
511
+
512
+ EIGEN_IF_CONSTEXPR(maxUBlock < U3) {
513
+ transB::template loadBBlock<maxUBlock, toTemp, remM>(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_,
514
+ ymm, remM_);
515
+ transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
516
+ transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
517
+ transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
518
+ transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_,
519
+ ymm, remM_);
520
+ }
521
+ }
522
+ else EIGEN_IF_CONSTEXPR(unrollN == U2) {
523
+ // load LxU2 B col major, transpose LxU2 row major
524
+ constexpr int64_t maxUBlock = std::min(3 * EIGEN_AVX_MAX_NUM_ROW, U2);
525
+ transB::template loadBBlock<maxUBlock, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
526
+ transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
527
+ transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
528
+ EIGEN_IF_CONSTEXPR(maxUBlock < U2) transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
529
+ transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
530
+
531
+ EIGEN_IF_CONSTEXPR(maxUBlock < U2) {
532
+ transB::template loadBBlock<EIGEN_AVX_MAX_NUM_ROW, toTemp, remM>(&B_arr[maxUBlock * LDB], LDB,
533
+ &B_temp[maxUBlock], LDB_, ymm, remM_);
534
+ transB::template transposeLxL<0>(ymm);
535
+ transB::template storeBBlock<EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, 0>(&B_arr[maxUBlock * LDB], LDB,
536
+ &B_temp[maxUBlock], LDB_, ymm, remM_);
537
+ }
538
+ }
539
+ else EIGEN_IF_CONSTEXPR(unrollN == U1) {
540
+ // load LxU1 B col major, transpose LxU1 row major
541
+ transB::template loadBBlock<U1, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
542
+ transB::template transposeLxL<0>(ymm);
543
+ EIGEN_IF_CONSTEXPR(EIGEN_AVX_MAX_NUM_ROW < U1) { transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); }
544
+ transB::template storeBBlock<U1, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
545
+ }
546
+ else EIGEN_IF_CONSTEXPR(unrollN == 8 && U1 > 8) {
547
+ // load Lx4 B col major, transpose Lx4 row major
548
+ transB::template loadBBlock<8, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
549
+ transB::template transposeLxL<0>(ymm);
550
+ transB::template storeBBlock<8, toTemp, remM, 8>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
551
+ }
552
+ else EIGEN_IF_CONSTEXPR(unrollN == 4 && U1 > 4) {
553
+ // load Lx4 B col major, transpose Lx4 row major
554
+ transB::template loadBBlock<4, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
555
+ transB::template transposeLxL<0>(ymm);
556
+ transB::template storeBBlock<4, toTemp, remM, 4>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
557
+ }
558
+ else EIGEN_IF_CONSTEXPR(unrollN == 2) {
559
+ // load Lx2 B col major, transpose Lx2 row major
560
+ transB::template loadBBlock<2, toTemp, remM, 2>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
561
+ transB::template transposeLxL<0>(ymm);
562
+ transB::template storeBBlock<2, toTemp, remM, 2>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
563
+ }
564
+ else EIGEN_IF_CONSTEXPR(unrollN == 1) {
565
+ // load Lx1 B col major, transpose Lx1 row major
566
+ transB::template loadBBlock<1, toTemp, remM, 1>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
567
+ transB::template transposeLxL<0>(ymm);
568
+ transB::template storeBBlock<1, toTemp, remM, 1>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
569
+ }
570
+ }
571
+ };
572
+
573
+ /**
574
+ * Unrolls for triSolveKernel
575
+ *
576
+ * Idea:
577
+ * 1) Load a block of right-hand sides to registers in RHSInPacket (using loadRHS).
578
+ * 2) Do triangular solve with RHSInPacket and a small block of A (triangular matrix)
579
+ * stored in AInPacket (using triSolveMicroKernel).
580
+ * 3) Store final results (in avx registers) back into memory (using storeRHS).
581
+ *
582
+ * RHSInPacket uses at most EIGEN_AVX_MAX_NUM_ACC avx registers and AInPacket uses at most
583
+ * EIGEN_AVX_MAX_NUM_ROW registers.
584
+ */
585
+ template <typename Scalar>
586
+ class trsm {
587
+ public:
588
+ using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
589
+ static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
590
+
591
+ /***********************************
592
+ * Auxiliary Functions for:
593
+ * - loadRHS
594
+ * - storeRHS
595
+ * - divRHSByDiag
596
+ * - updateRHS
597
+ * - triSolveMicroKernel
598
+ ************************************/
599
+ /**
600
+ * aux_loadRHS
601
+ *
602
+ * 2-D unroll
603
+ * for(startM = 0; startM < endM; startM++)
604
+ * for(startK = 0; startK < endK; startK++)
605
+ **/
606
+ template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
607
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadRHS(
608
+ Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
609
+ constexpr int64_t counterReverse = endM * endK - counter;
610
+ constexpr int64_t startM = counterReverse / (endK);
611
+ constexpr int64_t startK = counterReverse % endK;
612
+
613
+ constexpr int64_t packetIndex = startM * endK + startK;
614
+ constexpr int64_t startM_ = isFWDSolve ? startM : -startM;
615
+ const int64_t rhsIndex = (startK * PacketSize) + startM_ * LDB;
616
+ EIGEN_IF_CONSTEXPR(krem) {
617
+ RHSInPacket.packet[packetIndex] = ploadu<vec>(&B_arr[rhsIndex], remMask<PacketSize>(rem));
618
+ }
619
+ else {
620
+ RHSInPacket.packet[packetIndex] = ploadu<vec>(&B_arr[rhsIndex]);
621
+ }
622
+ aux_loadRHS<isFWDSolve, endM, endK, counter - 1, krem>(B_arr, LDB, RHSInPacket, rem);
623
+ }
624
+
625
+ template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
626
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadRHS(
627
+ Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
628
+ EIGEN_UNUSED_VARIABLE(B_arr);
629
+ EIGEN_UNUSED_VARIABLE(LDB);
630
+ EIGEN_UNUSED_VARIABLE(RHSInPacket);
631
+ EIGEN_UNUSED_VARIABLE(rem);
632
+ }
633
+
634
+ /**
635
+ * aux_storeRHS
636
+ *
637
+ * 2-D unroll
638
+ * for(startM = 0; startM < endM; startM++)
639
+ * for(startK = 0; startK < endK; startK++)
640
+ **/
641
+ template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
642
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeRHS(
643
+ Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
644
+ constexpr int64_t counterReverse = endM * endK - counter;
645
+ constexpr int64_t startM = counterReverse / (endK);
646
+ constexpr int64_t startK = counterReverse % endK;
647
+
648
+ constexpr int64_t packetIndex = startM * endK + startK;
649
+ constexpr int64_t startM_ = isFWDSolve ? startM : -startM;
650
+ const int64_t rhsIndex = (startK * PacketSize) + startM_ * LDB;
651
+ EIGEN_IF_CONSTEXPR(krem) {
652
+ pstoreu<Scalar>(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex], remMask<PacketSize>(rem));
653
+ }
654
+ else {
655
+ pstoreu<Scalar>(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex]);
656
+ }
657
+ aux_storeRHS<isFWDSolve, endM, endK, counter - 1, krem>(B_arr, LDB, RHSInPacket, rem);
658
+ }
659
+
660
+ template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
661
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeRHS(
662
+ Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
663
+ EIGEN_UNUSED_VARIABLE(B_arr);
664
+ EIGEN_UNUSED_VARIABLE(LDB);
665
+ EIGEN_UNUSED_VARIABLE(RHSInPacket);
666
+ EIGEN_UNUSED_VARIABLE(rem);
667
+ }
668
+
669
  /**
   * aux_divRHSByDiag
   *
   * Multiplies the RHS packets of row currM by AInPacket.packet[currM], which holds the
   * broadcast reciprocal of the corresponding diagonal element of A (set up in updateRHS),
   * i.e. performs the division step of the substitution.
   *
   * currM may be -1, (currM >= 0) in enable_if checks for this
   *
   * 1-D unroll
   * for(startK = 0; startK < endK; startK++)
   **/
  template <int64_t currM, int64_t endK, int64_t counter>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && currM >= 0)> aux_divRHSByDiag(
      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    // counter counts down from endK; recover forward index startK.
    constexpr int64_t counterReverse = endK - counter;
    constexpr int64_t startK = counterReverse;

    constexpr int64_t packetIndex = currM * endK + startK;
    RHSInPacket.packet[packetIndex] = pmul(AInPacket.packet[currM], RHSInPacket.packet[packetIndex]);
    aux_divRHSByDiag<currM, endK, counter - 1>(RHSInPacket, AInPacket);
  }

  // Recursion terminator; also absorbs the currM == -1 instantiation (see note above).
  template <int64_t currM, int64_t endK, int64_t counter>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0 && currM >= 0)> aux_divRHSByDiag(
      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    EIGEN_UNUSED_VARIABLE(RHSInPacket);
    EIGEN_UNUSED_VARIABLE(AInPacket);
  }
694
+
695
  /**
   * aux_updateRHS
   *
   * Subtracts the contribution of the already-solved RHS row (currentM - 1) from the
   * remaining RHS rows, and broadcasts the next needed elements of A into AInPacket
   * for use in later iterations.
   *
   * 2-D unroll
   * for(startM = initM; startM < endM; startM++)
   *   for(startK = 0; startK < endK; startK++)
   **/
  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t initM, int64_t endM, int64_t endK,
            int64_t counter, int64_t currentM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateRHS(
      Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    // counter counts down from (endM - initM)*endK; recover forward (startM, startK).
    constexpr int64_t counterReverse = (endM - initM) * endK - counter;
    constexpr int64_t startM = initM + counterReverse / (endK);
    constexpr int64_t startK = counterReverse % endK;

    // For each row of A, first update all corresponding RHS
    constexpr int64_t packetIndex = startM * endK + startK;
    EIGEN_IF_CONSTEXPR(currentM > 0) {
      // rhs[startM] -= A_{startM, currentM-1} * rhs[currentM-1]
      RHSInPacket.packet[packetIndex] =
          pnmadd(AInPacket.packet[startM], RHSInPacket.packet[(currentM - 1) * endK + startK],
                 RHSInPacket.packet[packetIndex]);
    }

    EIGEN_IF_CONSTEXPR(startK == endK - 1) {
      // Once all RHS for previous row of A is updated, we broadcast the next element in the column A_{i, currentM}.
      EIGEN_IF_CONSTEXPR(startM == currentM && !isUnitDiag) {
        // If diagonal is not unit, we broadcast reciprocals of diagonals AinPacket.packet[currentM].
        // This will be used in divRHSByDiag
        EIGEN_IF_CONSTEXPR(isFWDSolve)
        AInPacket.packet[currentM] = pset1<vec>(Scalar(1) / A_arr[idA<isARowMajor>(currentM, currentM, LDA)]);
        else AInPacket.packet[currentM] = pset1<vec>(Scalar(1) / A_arr[idA<isARowMajor>(-currentM, -currentM, LDA)]);
      }
      else {
        // Broadcast next off diagonal element of A (indices negated for backward solve)
        EIGEN_IF_CONSTEXPR(isFWDSolve)
        AInPacket.packet[startM] = pset1<vec>(A_arr[idA<isARowMajor>(startM, currentM, LDA)]);
        else AInPacket.packet[startM] = pset1<vec>(A_arr[idA<isARowMajor>(-startM, -currentM, LDA)]);
      }
    }

    aux_updateRHS<isARowMajor, isFWDSolve, isUnitDiag, initM, endM, endK, counter - 1, currentM>(
        A_arr, LDA, RHSInPacket, AInPacket);
  }

  // Recursion terminator: counter exhausted.
  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t initM, int64_t endM, int64_t endK,
            int64_t counter, int64_t currentM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateRHS(
      Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    EIGEN_UNUSED_VARIABLE(A_arr);
    EIGEN_UNUSED_VARIABLE(LDA);
    EIGEN_UNUSED_VARIABLE(RHSInPacket);
    EIGEN_UNUSED_VARIABLE(AInPacket);
  }
750
+
751
  /**
   * aux_triSolveMicroKernel
   *
   * One substitution step per unrolled row: divide the current row's RHS by the
   * (previously broadcast) diagonal reciprocal, then update all subsequent rows.
   *
   * 1-D unroll
   * for(startM = 0; startM < endM; startM++)
   **/
  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t counter, int64_t numK>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_triSolveMicroKernel(
      Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    // counter counts down from endM; recover forward row index startM.
    constexpr int64_t counterReverse = endM - counter;
    constexpr int64_t startM = counterReverse;

    constexpr int64_t currentM = startM;
    // Divides the right-hand side in row startM, by diagonal value of A
    // broadcasted to AInPacket.packet[startM-1] in the previous iteration.
    //
    // Without "if constexpr" the compiler instantiates the case <-1, numK>
    // this is handled with enable_if to prevent out-of-bound warnings
    // from the compiler
    EIGEN_IF_CONSTEXPR(!isUnitDiag && startM > 0)
    trsm::template divRHSByDiag<startM - 1, numK>(RHSInPacket, AInPacket);

    // After division, the rhs corresponding to subsequent rows of A can be partially updated
    // We also broadcast the reciprocal of the next diagonal to AInPacket.packet[currentM] (if needed)
    // to be used in the next iteration.
    trsm::template updateRHS<isARowMajor, isFWDSolve, isUnitDiag, startM, endM, numK, currentM>(A_arr, LDA, RHSInPacket,
                                                                                               AInPacket);

    // Handle division for the RHS corresponding to the final row of A.
    EIGEN_IF_CONSTEXPR(!isUnitDiag && startM == endM - 1)
    trsm::template divRHSByDiag<startM, numK>(RHSInPacket, AInPacket);

    aux_triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, endM, counter - 1, numK>(A_arr, LDA, RHSInPacket,
                                                                                         AInPacket);
  }

  // Recursion terminator: all endM rows processed.
  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t counter, int64_t numK>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_triSolveMicroKernel(
      Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    EIGEN_UNUSED_VARIABLE(A_arr);
    EIGEN_UNUSED_VARIABLE(LDA);
    EIGEN_UNUSED_VARIABLE(RHSInPacket);
    EIGEN_UNUSED_VARIABLE(AInPacket);
  }
797
+
798
+ /********************************************************
799
+ * Wrappers for aux_XXXX to hide counter parameter
800
+ ********************************************************/
801
+
802
  /**
   * Load endMxendK block of B to RHSInPacket
   * Masked loads (krem = true) are used for cases where the RHS count is not a multiple of PacketSize;
   * rem then gives the number of valid lanes in the final packet.
   */
  template <bool isFWDSolve, int64_t endM, int64_t endK, bool krem = false>
  static EIGEN_ALWAYS_INLINE void loadRHS(Scalar *B_arr, int64_t LDB,
                                          PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
    // Counter starts at endM*endK and counts down inside aux_loadRHS.
    aux_loadRHS<isFWDSolve, endM, endK, endM * endK, krem>(B_arr, LDB, RHSInPacket, rem);
  }
811
+
812
  /**
   * Store endMxendK block of RHSInPacket back to B
   * (comment previously said "Load", copy-pasted from loadRHS; this routine stores.)
   * Masked stores (krem = true) are used for cases where the RHS count is not a multiple of PacketSize.
   */
  template <bool isFWDSolve, int64_t endM, int64_t endK, bool krem = false>
  static EIGEN_ALWAYS_INLINE void storeRHS(Scalar *B_arr, int64_t LDB,
                                           PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
    // Counter starts at endM*endK and counts down inside aux_storeRHS.
    aux_storeRHS<isFWDSolve, endM, endK, endM * endK, krem>(B_arr, LDB, RHSInPacket, rem);
  }
821
+
822
  /**
   * Multiply row currM of the RHS by the broadcast diagonal reciprocal in AInPacket.packet[currM].
   * Only used if Triangular matrix has non-unit diagonal values.
   */
  template <int64_t currM, int64_t endK>
  static EIGEN_ALWAYS_INLINE void divRHSByDiag(PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
                                               PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    // Counter starts at endK and counts down inside aux_divRHSByDiag.
    aux_divRHSByDiag<currM, endK, endK>(RHSInPacket, AInPacket);
  }
830
+
831
  /**
   * Update right-hand sides (stored in avx registers)
   * Traversing along the column A_{i,currentM}, where currentM <= i <= endM, and broadcasting each value to AInPacket.
   **/
  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t startM, int64_t endM, int64_t endK,
            int64_t currentM>
  static EIGEN_ALWAYS_INLINE void updateRHS(Scalar *A_arr, int64_t LDA,
                                            PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
                                            PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    // Counter starts at (endM - startM)*endK and counts down inside aux_updateRHS.
    aux_updateRHS<isARowMajor, isFWDSolve, isUnitDiag, startM, endM, endK, (endM - startM) * endK, currentM>(
        A_arr, LDA, RHSInPacket, AInPacket);
  }
843
+
844
  /**
   * Fully unrolled triangular-solve microkernel on an endM x (numK*PacketSize) RHS block.
   *
   * endM: dimension of A. 1 <= endM <= EIGEN_AVX_MAX_NUM_ROW
   * numK: number of avx registers to use for each row of B (ex fp32: 48 rhs => 3 avx reg used). 1 <= numK <= 3.
   * isFWDSolve: true => forward substitution, false => backwards substitution
   * isUnitDiag: true => triangular matrix has unit diagonal.
   */
  template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t numK>
  static EIGEN_ALWAYS_INLINE void triSolveMicroKernel(Scalar *A_arr, int64_t LDA,
                                                      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
                                                      PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
    static_assert(numK >= 1 && numK <= 3, "numK out of range");
    // Counter starts at endM and counts down inside aux_triSolveMicroKernel.
    aux_triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, endM, endM, numK>(A_arr, LDA, RHSInPacket, AInPacket);
  }
857
+ };
858
+
859
/**
 * Unrolls for gemm kernel
 *
 * isAdd: true => C += A*B, false => C -= A*B
 */
template <typename Scalar, bool isAdd>
class gemm {
 public:
  // Full-width AVX512 packet type matching Scalar.
  using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
  static constexpr int64_t PacketSize = packet_traits<Scalar>::size;

  /***********************************
   * Auxiliary Functions for:
   *   - setzero
   *   - updateC
   *   - storeC
   *   - startLoadB / startBCastA / loadB
   *   - microKernel
   ************************************/

  /**
   * aux_setzero
   *
   * Zeroes the endM x endN accumulator packets in zmm.
   *
   * 2-D unroll
   * for(startM = 0; startM < endM; startM++)
   *   for(startN = 0; startN < endN; startN++)
   **/
  template <int64_t endM, int64_t endN, int64_t counter>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_setzero(
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    // counter counts down from endM*endN; recover forward (startM, startN).
    constexpr int64_t counterReverse = endM * endN - counter;
    constexpr int64_t startM = counterReverse / (endN);
    constexpr int64_t startN = counterReverse % endN;

    zmm.packet[startN * endM + startM] = pzero(zmm.packet[startN * endM + startM]);
    aux_setzero<endM, endN, counter - 1>(zmm);
  }

  // Recursion terminator.
  template <int64_t endM, int64_t endN, int64_t counter>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_setzero(
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    EIGEN_UNUSED_VARIABLE(zmm);
  }

  /**
   * aux_updateC
   *
   * Adds the existing contents of C into the accumulators (zmm += C).
   * Masked loads are used when rem is set (partial final packet).
   *
   * 2-D unroll
   * for(startM = 0; startM < endM; startM++)
   *   for(startN = 0; startN < endN; startN++)
   **/
  template <int64_t endM, int64_t endN, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateC(
      Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    constexpr int64_t counterReverse = endM * endN - counter;
    constexpr int64_t startM = counterReverse / (endN);
    constexpr int64_t startN = counterReverse % endN;

    EIGEN_IF_CONSTEXPR(rem)
    zmm.packet[startN * endM + startM] =
        padd(ploadu<vec>(&C_arr[(startN)*LDC + startM * PacketSize], remMask<PacketSize>(rem_)),
             zmm.packet[startN * endM + startM], remMask<PacketSize>(rem_));
    else zmm.packet[startN * endM + startM] =
        padd(ploadu<vec>(&C_arr[(startN)*LDC + startM * PacketSize]), zmm.packet[startN * endM + startM]);
    aux_updateC<endM, endN, counter - 1, rem>(C_arr, LDC, zmm, rem_);
  }

  // Recursion terminator.
  template <int64_t endM, int64_t endN, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateC(
      Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(C_arr);
    EIGEN_UNUSED_VARIABLE(LDC);
    EIGEN_UNUSED_VARIABLE(zmm);
    EIGEN_UNUSED_VARIABLE(rem_);
  }

  /**
   * aux_storeC
   *
   * Writes the endM x endN accumulator packets back to C.
   * Masked stores are used when rem is set (partial final packet).
   *
   * 2-D unroll
   * for(startM = 0; startM < endM; startM++)
   *   for(startN = 0; startN < endN; startN++)
   **/
  template <int64_t endM, int64_t endN, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeC(
      Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    constexpr int64_t counterReverse = endM * endN - counter;
    constexpr int64_t startM = counterReverse / (endN);
    constexpr int64_t startN = counterReverse % endN;

    EIGEN_IF_CONSTEXPR(rem)
    pstoreu<Scalar>(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM],
                    remMask<PacketSize>(rem_));
    else pstoreu<Scalar>(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM]);
    aux_storeC<endM, endN, counter - 1, rem>(C_arr, LDC, zmm, rem_);
  }

  // Recursion terminator.
  template <int64_t endM, int64_t endN, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeC(
      Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(C_arr);
    EIGEN_UNUSED_VARIABLE(LDC);
    EIGEN_UNUSED_VARIABLE(zmm);
    EIGEN_UNUSED_VARIABLE(rem_);
  }

  /**
   * aux_startLoadB
   *
   * Initial fill of the B-load registers, placed after the endM*endN accumulators in zmm.
   *
   * 1-D unroll
   * for(startL = 0; startL < endL; startL++)
   **/
  template <int64_t unrollM, int64_t unrollN, int64_t endL, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startLoadB(
      Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    constexpr int64_t counterReverse = endL - counter;
    constexpr int64_t startL = counterReverse;

    EIGEN_IF_CONSTEXPR(rem)
    zmm.packet[unrollM * unrollN + startL] =
        ploadu<vec>(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize], remMask<PacketSize>(rem_));
    else zmm.packet[unrollM * unrollN + startL] =
        ploadu<vec>(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize]);

    aux_startLoadB<unrollM, unrollN, endL, counter - 1, rem>(B_t, LDB, zmm, rem_);
  }

  // Recursion terminator.
  template <int64_t unrollM, int64_t unrollN, int64_t endL, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startLoadB(
      Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_t);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(zmm);
    EIGEN_UNUSED_VARIABLE(rem_);
  }

  /**
   * aux_startBCastA
   *
   * Initial broadcast of A elements into the registers following the B-load registers.
   *
   * 1-D unroll
   * for(startB = 0; startB < endB; startB++)
   **/
  template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t counter, int64_t numLoad>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startBCastA(
      Scalar *A_t, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    constexpr int64_t counterReverse = endB - counter;
    constexpr int64_t startB = counterReverse;

    zmm.packet[unrollM * unrollN + numLoad + startB] = pload1<vec>(&A_t[idA<isARowMajor>(startB, 0, LDA)]);

    aux_startBCastA<isARowMajor, unrollM, unrollN, endB, counter - 1, numLoad>(A_t, LDA, zmm);
  }

  // Recursion terminator.
  template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t counter, int64_t numLoad>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startBCastA(
      Scalar *A_t, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    EIGEN_UNUSED_VARIABLE(A_t);
    EIGEN_UNUSED_VARIABLE(LDA);
    EIGEN_UNUSED_VARIABLE(zmm);
  }

  /**
   * aux_loadB
   * currK: current K
   *
   * Refills the rotating B-load registers for the next K step (guarded so we do not
   * read past the unrollK-deep B panel).
   *
   * 1-D unroll
   * for(startM = 0; startM < endM; startM++)
   **/
  template <int64_t endM, int64_t counter, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad,
            int64_t numBCast, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB(
      Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    if ((numLoad / endM + currK < unrollK)) {
      constexpr int64_t counterReverse = endM - counter;
      constexpr int64_t startM = counterReverse;

      EIGEN_IF_CONSTEXPR(rem) {
        zmm.packet[endM * unrollN + (startM + currK * endM) % numLoad] =
            ploadu<vec>(&B_t[(numLoad / endM + currK) * LDB + startM * PacketSize], remMask<PacketSize>(rem_));
      }
      else {
        zmm.packet[endM * unrollN + (startM + currK * endM) % numLoad] =
            ploadu<vec>(&B_t[(numLoad / endM + currK) * LDB + startM * PacketSize]);
      }

      aux_loadB<endM, counter - 1, unrollN, currK, unrollK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
    }
  }

  // Recursion terminator.
  template <int64_t endM, int64_t counter, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad,
            int64_t numBCast, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB(
      Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_t);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(zmm);
    EIGEN_UNUSED_VARIABLE(rem_);
  }

  /**
   * aux_microKernel
   *
   * 3-D unroll
   * for(startM = 0; startM < endM; startM++)
   *   for(startN = 0; startN < endN; startN++)
   *     for(startK = 0; startK < endK; startK++)
   **/
  template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t counter, int64_t numLoad,
            int64_t numBCast, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_microKernel(
      Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
      int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    // counter counts down from endM*endN*endK; recover forward (startK, startN, startM).
    constexpr int64_t counterReverse = endM * endN * endK - counter;
    constexpr int startK = counterReverse / (endM * endN);
    constexpr int startN = (counterReverse / (endM)) % endN;
    constexpr int startM = counterReverse % endM;

    // Very first iteration: prime the B-load and A-broadcast registers.
    EIGEN_IF_CONSTEXPR(startK == 0 && startM == 0 && startN == 0) {
      gemm::template startLoadB<endM, endN, numLoad, rem>(B_t, LDB, zmm, rem_);
      gemm::template startBCastA<isARowMajor, endM, endN, numBCast, numLoad>(A_t, LDA, zmm);
    }

    {
      // Interleave FMA and Bcast
      EIGEN_IF_CONSTEXPR(isAdd) {
        zmm.packet[startN * endM + startM] =
            pmadd(zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast],
                  zmm.packet[endM * endN + (startM + startK * endM) % numLoad], zmm.packet[startN * endM + startM]);
      }
      else {
        zmm.packet[startN * endM + startM] =
            pnmadd(zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast],
                   zmm.packet[endM * endN + (startM + startK * endM) % numLoad], zmm.packet[startN * endM + startM]);
      }
      // Bcast
      EIGEN_IF_CONSTEXPR(startM == endM - 1 && (numBCast + startN + startK * endN < endK * endN)) {
        zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast] = pload1<vec>(&A_t[idA<isARowMajor>(
            (numBCast + startN + startK * endN) % endN, (numBCast + startN + startK * endN) / endN, LDA)]);
      }
    }

    // We have updated all accumulators, time to load next set of B's
    EIGEN_IF_CONSTEXPR((startN == endN - 1) && (startM == endM - 1)) {
      gemm::template loadB<endM, endN, startK, endK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
    }
    aux_microKernel<isARowMajor, endM, endN, endK, counter - 1, numLoad, numBCast, rem>(B_t, A_t, LDB, LDA, zmm, rem_);
  }

  // Recursion terminator.
  template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t counter, int64_t numLoad,
            int64_t numBCast, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_microKernel(
      Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
      int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_t);
    EIGEN_UNUSED_VARIABLE(A_t);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(LDA);
    EIGEN_UNUSED_VARIABLE(zmm);
    EIGEN_UNUSED_VARIABLE(rem_);
  }

  /********************************************************
   * Wrappers for aux_XXXX to hide counter parameter
   ********************************************************/

  template <int64_t endM, int64_t endN>
  static EIGEN_ALWAYS_INLINE void setzero(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    aux_setzero<endM, endN, endM * endN>(zmm);
  }

  /**
   * Ideally the compiler folds these into vaddp{s,d} with an embedded memory load.
   */
  template <int64_t endM, int64_t endN, bool rem = false>
  static EIGEN_ALWAYS_INLINE void updateC(Scalar *C_arr, int64_t LDC,
                                          PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                          int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    aux_updateC<endM, endN, endM * endN, rem>(C_arr, LDC, zmm, rem_);
  }

  // Write the accumulators back to C; masked stores when rem is set.
  template <int64_t endM, int64_t endN, bool rem = false>
  static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC,
                                         PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                         int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    aux_storeC<endM, endN, endM * endN, rem>(C_arr, LDC, zmm, rem_);
  }

  /**
   * Use numLoad registers for loading B at start of microKernel
   */
  template <int64_t unrollM, int64_t unrollN, int64_t endL, bool rem>
  static EIGEN_ALWAYS_INLINE void startLoadB(Scalar *B_t, int64_t LDB,
                                             PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                             int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    aux_startLoadB<unrollM, unrollN, endL, endL, rem>(B_t, LDB, zmm, rem_);
  }

  /**
   * Use numBCast registers for broadcasting A at start of microKernel
   */
  template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t numLoad>
  static EIGEN_ALWAYS_INLINE void startBCastA(Scalar *A_t, int64_t LDA,
                                              PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    aux_startBCastA<isARowMajor, unrollM, unrollN, endB, endB, numLoad>(A_t, LDA, zmm);
  }

  /**
   * Loads next set of B into vector registers between each K unroll.
   */
  template <int64_t endM, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad, int64_t numBCast, bool rem>
  static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_t, int64_t LDB,
                                        PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                        int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    aux_loadB<endM, endM, unrollN, currK, unrollK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
  }

  /**
   * Generates a microkernel for gemm (row-major) with unrolls {1,2,4,8}x{U1,U2,U3} to compute
   * C += A*B or C -= A*B (selected by the class parameter isAdd).
   * A matrix can be row/col-major. B matrix is assumed row-major.
   *
   * isARowMajor: is A row major
   * endM: Number registers per row
   * endN: Number of rows
   * endK: Loop unroll for K.
   * numLoad: Number of registers for loading B.
   * numBCast: Number of registers for broadcasting A.
   *
   * Ex: microkernel<isARowMajor,0,3,0,4,0,4,6,2>: 8x48 unroll (24 accumulators), k unrolled 4 times,
   * 6 register for loading B, 2 for broadcasting A.
   *
   * Note: Ideally the microkernel should not have any register spilling.
   * The avx instruction counts should be:
   *   - endK*endN vbroadcasts{s,d}
   *   - endK*endM vmovup{s,d}
   *   - endK*endN*endM FMAs
   *
   * From testing, there are no register spills with clang. There are register spills with GNU, which
   * causes a performance hit.
   */
  template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t numLoad, int64_t numBCast,
            bool rem = false>
  static EIGEN_ALWAYS_INLINE void microKernel(Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA,
                                              PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                              int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    aux_microKernel<isARowMajor, endM, endN, endK, endM * endN * endK, numLoad, numBCast, rem>(B_t, A_t, LDB, LDA, zmm,
                                                                                              rem_);
  }
};
1217
+ } // namespace unrolls
1218
+
1219
+ #endif // EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H