@smake/eigen 1.1.0 → 1.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (431)
  1. package/README.md +1 -1
  2. package/eigen/Eigen/AccelerateSupport +52 -0
  3. package/eigen/Eigen/Cholesky +18 -20
  4. package/eigen/Eigen/CholmodSupport +28 -28
  5. package/eigen/Eigen/Core +187 -120
  6. package/eigen/Eigen/Eigenvalues +16 -13
  7. package/eigen/Eigen/Geometry +18 -18
  8. package/eigen/Eigen/Householder +9 -7
  9. package/eigen/Eigen/IterativeLinearSolvers +8 -4
  10. package/eigen/Eigen/Jacobi +14 -13
  11. package/eigen/Eigen/KLUSupport +23 -21
  12. package/eigen/Eigen/LU +15 -16
  13. package/eigen/Eigen/MetisSupport +12 -12
  14. package/eigen/Eigen/OrderingMethods +54 -51
  15. package/eigen/Eigen/PaStiXSupport +23 -21
  16. package/eigen/Eigen/PardisoSupport +17 -14
  17. package/eigen/Eigen/QR +18 -20
  18. package/eigen/Eigen/QtAlignedMalloc +5 -12
  19. package/eigen/Eigen/SPQRSupport +21 -14
  20. package/eigen/Eigen/SVD +23 -17
  21. package/eigen/Eigen/Sparse +1 -2
  22. package/eigen/Eigen/SparseCholesky +18 -15
  23. package/eigen/Eigen/SparseCore +18 -17
  24. package/eigen/Eigen/SparseLU +9 -9
  25. package/eigen/Eigen/SparseQR +16 -14
  26. package/eigen/Eigen/StdDeque +5 -2
  27. package/eigen/Eigen/StdList +5 -2
  28. package/eigen/Eigen/StdVector +5 -2
  29. package/eigen/Eigen/SuperLUSupport +30 -24
  30. package/eigen/Eigen/ThreadPool +80 -0
  31. package/eigen/Eigen/UmfPackSupport +19 -17
  32. package/eigen/Eigen/Version +14 -0
  33. package/eigen/Eigen/src/AccelerateSupport/AccelerateSupport.h +423 -0
  34. package/eigen/Eigen/src/AccelerateSupport/InternalHeaderCheck.h +3 -0
  35. package/eigen/Eigen/src/Cholesky/InternalHeaderCheck.h +3 -0
  36. package/eigen/Eigen/src/Cholesky/LDLT.h +366 -405
  37. package/eigen/Eigen/src/Cholesky/LLT.h +323 -367
  38. package/eigen/Eigen/src/Cholesky/LLT_LAPACKE.h +81 -56
  39. package/eigen/Eigen/src/CholmodSupport/CholmodSupport.h +585 -529
  40. package/eigen/Eigen/src/CholmodSupport/InternalHeaderCheck.h +3 -0
  41. package/eigen/Eigen/src/Core/ArithmeticSequence.h +143 -317
  42. package/eigen/Eigen/src/Core/Array.h +329 -370
  43. package/eigen/Eigen/src/Core/ArrayBase.h +190 -203
  44. package/eigen/Eigen/src/Core/ArrayWrapper.h +126 -170
  45. package/eigen/Eigen/src/Core/Assign.h +30 -40
  46. package/eigen/Eigen/src/Core/AssignEvaluator.h +651 -604
  47. package/eigen/Eigen/src/Core/Assign_MKL.h +125 -120
  48. package/eigen/Eigen/src/Core/BandMatrix.h +267 -282
  49. package/eigen/Eigen/src/Core/Block.h +371 -390
  50. package/eigen/Eigen/src/Core/CommaInitializer.h +85 -100
  51. package/eigen/Eigen/src/Core/ConditionEstimator.h +51 -53
  52. package/eigen/Eigen/src/Core/CoreEvaluators.h +1214 -937
  53. package/eigen/Eigen/src/Core/CoreIterators.h +72 -63
  54. package/eigen/Eigen/src/Core/CwiseBinaryOp.h +112 -129
  55. package/eigen/Eigen/src/Core/CwiseNullaryOp.h +676 -702
  56. package/eigen/Eigen/src/Core/CwiseTernaryOp.h +77 -103
  57. package/eigen/Eigen/src/Core/CwiseUnaryOp.h +55 -67
  58. package/eigen/Eigen/src/Core/CwiseUnaryView.h +127 -92
  59. package/eigen/Eigen/src/Core/DenseBase.h +630 -658
  60. package/eigen/Eigen/src/Core/DenseCoeffsBase.h +511 -628
  61. package/eigen/Eigen/src/Core/DenseStorage.h +511 -590
  62. package/eigen/Eigen/src/Core/DeviceWrapper.h +153 -0
  63. package/eigen/Eigen/src/Core/Diagonal.h +168 -207
  64. package/eigen/Eigen/src/Core/DiagonalMatrix.h +346 -317
  65. package/eigen/Eigen/src/Core/DiagonalProduct.h +12 -10
  66. package/eigen/Eigen/src/Core/Dot.h +167 -217
  67. package/eigen/Eigen/src/Core/EigenBase.h +74 -85
  68. package/eigen/Eigen/src/Core/Fill.h +138 -0
  69. package/eigen/Eigen/src/Core/FindCoeff.h +464 -0
  70. package/eigen/Eigen/src/Core/ForceAlignedAccess.h +90 -113
  71. package/eigen/Eigen/src/Core/Fuzzy.h +82 -105
  72. package/eigen/Eigen/src/Core/GeneralProduct.h +315 -261
  73. package/eigen/Eigen/src/Core/GenericPacketMath.h +1182 -520
  74. package/eigen/Eigen/src/Core/GlobalFunctions.h +193 -157
  75. package/eigen/Eigen/src/Core/IO.h +131 -156
  76. package/eigen/Eigen/src/Core/IndexedView.h +209 -125
  77. package/eigen/Eigen/src/Core/InnerProduct.h +260 -0
  78. package/eigen/Eigen/src/Core/InternalHeaderCheck.h +3 -0
  79. package/eigen/Eigen/src/Core/Inverse.h +50 -59
  80. package/eigen/Eigen/src/Core/Map.h +123 -141
  81. package/eigen/Eigen/src/Core/MapBase.h +255 -282
  82. package/eigen/Eigen/src/Core/MathFunctions.h +1247 -1201
  83. package/eigen/Eigen/src/Core/MathFunctionsImpl.h +162 -99
  84. package/eigen/Eigen/src/Core/Matrix.h +463 -494
  85. package/eigen/Eigen/src/Core/MatrixBase.h +468 -470
  86. package/eigen/Eigen/src/Core/NestByValue.h +58 -52
  87. package/eigen/Eigen/src/Core/NoAlias.h +79 -86
  88. package/eigen/Eigen/src/Core/NumTraits.h +206 -206
  89. package/eigen/Eigen/src/Core/PartialReduxEvaluator.h +163 -142
  90. package/eigen/Eigen/src/Core/PermutationMatrix.h +461 -511
  91. package/eigen/Eigen/src/Core/PlainObjectBase.h +858 -972
  92. package/eigen/Eigen/src/Core/Product.h +246 -130
  93. package/eigen/Eigen/src/Core/ProductEvaluators.h +779 -671
  94. package/eigen/Eigen/src/Core/Random.h +153 -164
  95. package/eigen/Eigen/src/Core/RandomImpl.h +262 -0
  96. package/eigen/Eigen/src/Core/RealView.h +250 -0
  97. package/eigen/Eigen/src/Core/Redux.h +334 -314
  98. package/eigen/Eigen/src/Core/Ref.h +259 -257
  99. package/eigen/Eigen/src/Core/Replicate.h +92 -104
  100. package/eigen/Eigen/src/Core/Reshaped.h +215 -271
  101. package/eigen/Eigen/src/Core/ReturnByValue.h +47 -55
  102. package/eigen/Eigen/src/Core/Reverse.h +133 -148
  103. package/eigen/Eigen/src/Core/Select.h +68 -140
  104. package/eigen/Eigen/src/Core/SelfAdjointView.h +254 -290
  105. package/eigen/Eigen/src/Core/SelfCwiseBinaryOp.h +23 -20
  106. package/eigen/Eigen/src/Core/SkewSymmetricMatrix3.h +382 -0
  107. package/eigen/Eigen/src/Core/Solve.h +88 -102
  108. package/eigen/Eigen/src/Core/SolveTriangular.h +126 -124
  109. package/eigen/Eigen/src/Core/SolverBase.h +132 -133
  110. package/eigen/Eigen/src/Core/StableNorm.h +113 -147
  111. package/eigen/Eigen/src/Core/StlIterators.h +404 -248
  112. package/eigen/Eigen/src/Core/Stride.h +90 -92
  113. package/eigen/Eigen/src/Core/Swap.h +70 -39
  114. package/eigen/Eigen/src/Core/Transpose.h +258 -295
  115. package/eigen/Eigen/src/Core/Transpositions.h +270 -333
  116. package/eigen/Eigen/src/Core/TriangularMatrix.h +642 -743
  117. package/eigen/Eigen/src/Core/VectorBlock.h +59 -72
  118. package/eigen/Eigen/src/Core/VectorwiseOp.h +653 -704
  119. package/eigen/Eigen/src/Core/Visitor.h +464 -308
  120. package/eigen/Eigen/src/Core/arch/AVX/Complex.h +380 -187
  121. package/eigen/Eigen/src/Core/arch/AVX/MathFunctions.h +65 -163
  122. package/eigen/Eigen/src/Core/arch/AVX/PacketMath.h +2145 -638
  123. package/eigen/Eigen/src/Core/arch/AVX/Reductions.h +353 -0
  124. package/eigen/Eigen/src/Core/arch/AVX/TypeCasting.h +253 -60
  125. package/eigen/Eigen/src/Core/arch/AVX512/Complex.h +278 -228
  126. package/eigen/Eigen/src/Core/arch/AVX512/GemmKernel.h +1245 -0
  127. package/eigen/Eigen/src/Core/arch/AVX512/MathFunctions.h +48 -269
  128. package/eigen/Eigen/src/Core/arch/AVX512/MathFunctionsFP16.h +75 -0
  129. package/eigen/Eigen/src/Core/arch/AVX512/PacketMath.h +1597 -754
  130. package/eigen/Eigen/src/Core/arch/AVX512/PacketMathFP16.h +1413 -0
  131. package/eigen/Eigen/src/Core/arch/AVX512/Reductions.h +297 -0
  132. package/eigen/Eigen/src/Core/arch/AVX512/TrsmKernel.h +1167 -0
  133. package/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc +1219 -0
  134. package/eigen/Eigen/src/Core/arch/AVX512/TypeCasting.h +229 -41
  135. package/eigen/Eigen/src/Core/arch/AVX512/TypeCastingFP16.h +130 -0
  136. package/eigen/Eigen/src/Core/arch/AltiVec/Complex.h +420 -184
  137. package/eigen/Eigen/src/Core/arch/AltiVec/MathFunctions.h +40 -49
  138. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProduct.h +2962 -2213
  139. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h +196 -212
  140. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h +713 -441
  141. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h +742 -0
  142. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.inc +2818 -0
  143. package/eigen/Eigen/src/Core/arch/AltiVec/PacketMath.h +2380 -1362
  144. package/eigen/Eigen/src/Core/arch/AltiVec/TypeCasting.h +153 -0
  145. package/eigen/Eigen/src/Core/arch/Default/BFloat16.h +390 -224
  146. package/eigen/Eigen/src/Core/arch/Default/ConjHelper.h +78 -67
  147. package/eigen/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +1784 -799
  148. package/eigen/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h +167 -50
  149. package/eigen/Eigen/src/Core/arch/Default/Half.h +528 -379
  150. package/eigen/Eigen/src/Core/arch/Default/Settings.h +10 -12
  151. package/eigen/Eigen/src/Core/arch/GPU/Complex.h +244 -0
  152. package/eigen/Eigen/src/Core/arch/GPU/MathFunctions.h +41 -40
  153. package/eigen/Eigen/src/Core/arch/GPU/PacketMath.h +550 -523
  154. package/eigen/Eigen/src/Core/arch/GPU/Tuple.h +268 -0
  155. package/eigen/Eigen/src/Core/arch/GPU/TypeCasting.h +27 -30
  156. package/eigen/Eigen/src/Core/arch/HIP/hcc/math_constants.h +8 -8
  157. package/eigen/Eigen/src/Core/arch/HVX/PacketMath.h +1088 -0
  158. package/eigen/Eigen/src/Core/arch/LSX/Complex.h +520 -0
  159. package/eigen/Eigen/src/Core/arch/LSX/GeneralBlockPanelKernel.h +23 -0
  160. package/eigen/Eigen/src/Core/arch/LSX/MathFunctions.h +43 -0
  161. package/eigen/Eigen/src/Core/arch/LSX/PacketMath.h +2866 -0
  162. package/eigen/Eigen/src/Core/arch/LSX/TypeCasting.h +526 -0
  163. package/eigen/Eigen/src/Core/arch/MSA/Complex.h +54 -82
  164. package/eigen/Eigen/src/Core/arch/MSA/MathFunctions.h +84 -92
  165. package/eigen/Eigen/src/Core/arch/MSA/PacketMath.h +51 -47
  166. package/eigen/Eigen/src/Core/arch/NEON/Complex.h +454 -306
  167. package/eigen/Eigen/src/Core/arch/NEON/GeneralBlockPanelKernel.h +175 -115
  168. package/eigen/Eigen/src/Core/arch/NEON/MathFunctions.h +23 -30
  169. package/eigen/Eigen/src/Core/arch/NEON/PacketMath.h +4366 -2857
  170. package/eigen/Eigen/src/Core/arch/NEON/TypeCasting.h +616 -393
  171. package/eigen/Eigen/src/Core/arch/NEON/UnaryFunctors.h +57 -0
  172. package/eigen/Eigen/src/Core/arch/SSE/Complex.h +350 -198
  173. package/eigen/Eigen/src/Core/arch/SSE/MathFunctions.h +38 -149
  174. package/eigen/Eigen/src/Core/arch/SSE/PacketMath.h +1791 -912
  175. package/eigen/Eigen/src/Core/arch/SSE/Reductions.h +324 -0
  176. package/eigen/Eigen/src/Core/arch/SSE/TypeCasting.h +128 -40
  177. package/eigen/Eigen/src/Core/arch/SVE/MathFunctions.h +10 -6
  178. package/eigen/Eigen/src/Core/arch/SVE/PacketMath.h +156 -234
  179. package/eigen/Eigen/src/Core/arch/SVE/TypeCasting.h +6 -3
  180. package/eigen/Eigen/src/Core/arch/SYCL/InteropHeaders.h +27 -32
  181. package/eigen/Eigen/src/Core/arch/SYCL/MathFunctions.h +119 -117
  182. package/eigen/Eigen/src/Core/arch/SYCL/PacketMath.h +325 -419
  183. package/eigen/Eigen/src/Core/arch/SYCL/TypeCasting.h +15 -17
  184. package/eigen/Eigen/src/Core/arch/ZVector/Complex.h +325 -181
  185. package/eigen/Eigen/src/Core/arch/ZVector/MathFunctions.h +94 -83
  186. package/eigen/Eigen/src/Core/arch/ZVector/PacketMath.h +811 -458
  187. package/eigen/Eigen/src/Core/functors/AssignmentFunctors.h +121 -124
  188. package/eigen/Eigen/src/Core/functors/BinaryFunctors.h +576 -370
  189. package/eigen/Eigen/src/Core/functors/NullaryFunctors.h +194 -109
  190. package/eigen/Eigen/src/Core/functors/StlFunctors.h +95 -112
  191. package/eigen/Eigen/src/Core/functors/TernaryFunctors.h +34 -7
  192. package/eigen/Eigen/src/Core/functors/UnaryFunctors.h +1038 -749
  193. package/eigen/Eigen/src/Core/products/GeneralBlockPanelKernel.h +1883 -1375
  194. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrix.h +312 -370
  195. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h +189 -176
  196. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h +84 -81
  197. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h +154 -73
  198. package/eigen/Eigen/src/Core/products/GeneralMatrixVector.h +292 -337
  199. package/eigen/Eigen/src/Core/products/GeneralMatrixVector_BLAS.h +80 -77
  200. package/eigen/Eigen/src/Core/products/Parallelizer.h +207 -105
  201. package/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix.h +327 -388
  202. package/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix_BLAS.h +206 -224
  203. package/eigen/Eigen/src/Core/products/SelfadjointMatrixVector.h +138 -147
  204. package/eigen/Eigen/src/Core/products/SelfadjointMatrixVector_BLAS.h +58 -61
  205. package/eigen/Eigen/src/Core/products/SelfadjointProduct.h +71 -71
  206. package/eigen/Eigen/src/Core/products/SelfadjointRank2Update.h +48 -47
  207. package/eigen/Eigen/src/Core/products/TriangularMatrixMatrix.h +294 -369
  208. package/eigen/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h +246 -238
  209. package/eigen/Eigen/src/Core/products/TriangularMatrixVector.h +244 -247
  210. package/eigen/Eigen/src/Core/products/TriangularMatrixVector_BLAS.h +212 -192
  211. package/eigen/Eigen/src/Core/products/TriangularSolverMatrix.h +328 -277
  212. package/eigen/Eigen/src/Core/products/TriangularSolverMatrix_BLAS.h +108 -109
  213. package/eigen/Eigen/src/Core/products/TriangularSolverVector.h +68 -94
  214. package/eigen/Eigen/src/Core/util/Assert.h +158 -0
  215. package/eigen/Eigen/src/Core/util/BlasUtil.h +342 -303
  216. package/eigen/Eigen/src/Core/util/ConfigureVectorization.h +348 -317
  217. package/eigen/Eigen/src/Core/util/Constants.h +297 -262
  218. package/eigen/Eigen/src/Core/util/DisableStupidWarnings.h +130 -90
  219. package/eigen/Eigen/src/Core/util/EmulateArray.h +270 -0
  220. package/eigen/Eigen/src/Core/util/ForwardDeclarations.h +449 -247
  221. package/eigen/Eigen/src/Core/util/GpuHipCudaDefines.inc +101 -0
  222. package/eigen/Eigen/src/Core/util/GpuHipCudaUndefines.inc +45 -0
  223. package/eigen/Eigen/src/Core/util/IndexedViewHelper.h +417 -116
  224. package/eigen/Eigen/src/Core/util/IntegralConstant.h +211 -204
  225. package/eigen/Eigen/src/Core/util/MKL_support.h +39 -37
  226. package/eigen/Eigen/src/Core/util/Macros.h +655 -773
  227. package/eigen/Eigen/src/Core/util/MaxSizeVector.h +139 -0
  228. package/eigen/Eigen/src/Core/util/Memory.h +970 -748
  229. package/eigen/Eigen/src/Core/util/Meta.h +581 -633
  230. package/eigen/Eigen/src/Core/util/MoreMeta.h +638 -0
  231. package/eigen/Eigen/src/Core/util/ReenableStupidWarnings.h +32 -19
  232. package/eigen/Eigen/src/Core/util/ReshapedHelper.h +17 -17
  233. package/eigen/Eigen/src/Core/util/Serializer.h +209 -0
  234. package/eigen/Eigen/src/Core/util/StaticAssert.h +50 -166
  235. package/eigen/Eigen/src/Core/util/SymbolicIndex.h +377 -225
  236. package/eigen/Eigen/src/Core/util/XprHelper.h +784 -547
  237. package/eigen/Eigen/src/Eigenvalues/ComplexEigenSolver.h +246 -277
  238. package/eigen/Eigen/src/Eigenvalues/ComplexSchur.h +299 -319
  239. package/eigen/Eigen/src/Eigenvalues/ComplexSchur_LAPACKE.h +52 -48
  240. package/eigen/Eigen/src/Eigenvalues/EigenSolver.h +413 -456
  241. package/eigen/Eigen/src/Eigenvalues/GeneralizedEigenSolver.h +309 -325
  242. package/eigen/Eigen/src/Eigenvalues/GeneralizedSelfAdjointEigenSolver.h +157 -171
  243. package/eigen/Eigen/src/Eigenvalues/HessenbergDecomposition.h +292 -310
  244. package/eigen/Eigen/src/Eigenvalues/InternalHeaderCheck.h +3 -0
  245. package/eigen/Eigen/src/Eigenvalues/MatrixBaseEigenvalues.h +89 -105
  246. package/eigen/Eigen/src/Eigenvalues/RealQZ.h +537 -607
  247. package/eigen/Eigen/src/Eigenvalues/RealSchur.h +342 -381
  248. package/eigen/Eigen/src/Eigenvalues/RealSchur_LAPACKE.h +41 -35
  249. package/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver.h +541 -595
  250. package/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver_LAPACKE.h +47 -44
  251. package/eigen/Eigen/src/Eigenvalues/Tridiagonalization.h +430 -462
  252. package/eigen/Eigen/src/Geometry/AlignedBox.h +226 -227
  253. package/eigen/Eigen/src/Geometry/AngleAxis.h +131 -133
  254. package/eigen/Eigen/src/Geometry/EulerAngles.h +163 -74
  255. package/eigen/Eigen/src/Geometry/Homogeneous.h +285 -333
  256. package/eigen/Eigen/src/Geometry/Hyperplane.h +151 -160
  257. package/eigen/Eigen/src/Geometry/InternalHeaderCheck.h +3 -0
  258. package/eigen/Eigen/src/Geometry/OrthoMethods.h +168 -146
  259. package/eigen/Eigen/src/Geometry/ParametrizedLine.h +127 -127
  260. package/eigen/Eigen/src/Geometry/Quaternion.h +566 -506
  261. package/eigen/Eigen/src/Geometry/Rotation2D.h +107 -105
  262. package/eigen/Eigen/src/Geometry/RotationBase.h +148 -145
  263. package/eigen/Eigen/src/Geometry/Scaling.h +113 -106
  264. package/eigen/Eigen/src/Geometry/Transform.h +858 -936
  265. package/eigen/Eigen/src/Geometry/Translation.h +94 -92
  266. package/eigen/Eigen/src/Geometry/Umeyama.h +79 -84
  267. package/eigen/Eigen/src/Geometry/arch/Geometry_SIMD.h +90 -104
  268. package/eigen/Eigen/src/Householder/BlockHouseholder.h +51 -46
  269. package/eigen/Eigen/src/Householder/Householder.h +102 -124
  270. package/eigen/Eigen/src/Householder/HouseholderSequence.h +412 -453
  271. package/eigen/Eigen/src/Householder/InternalHeaderCheck.h +3 -0
  272. package/eigen/Eigen/src/IterativeLinearSolvers/BasicPreconditioners.h +149 -162
  273. package/eigen/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h +124 -119
  274. package/eigen/Eigen/src/IterativeLinearSolvers/ConjugateGradient.h +92 -104
  275. package/eigen/Eigen/src/IterativeLinearSolvers/IncompleteCholesky.h +251 -243
  276. package/eigen/Eigen/src/IterativeLinearSolvers/IncompleteLUT.h +224 -228
  277. package/eigen/Eigen/src/IterativeLinearSolvers/InternalHeaderCheck.h +3 -0
  278. package/eigen/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h +178 -227
  279. package/eigen/Eigen/src/IterativeLinearSolvers/LeastSquareConjugateGradient.h +79 -84
  280. package/eigen/Eigen/src/IterativeLinearSolvers/SolveWithGuess.h +54 -60
  281. package/eigen/Eigen/src/Jacobi/InternalHeaderCheck.h +3 -0
  282. package/eigen/Eigen/src/Jacobi/Jacobi.h +252 -308
  283. package/eigen/Eigen/src/KLUSupport/InternalHeaderCheck.h +3 -0
  284. package/eigen/Eigen/src/KLUSupport/KLUSupport.h +208 -227
  285. package/eigen/Eigen/src/LU/Determinant.h +50 -69
  286. package/eigen/Eigen/src/LU/FullPivLU.h +545 -596
  287. package/eigen/Eigen/src/LU/InternalHeaderCheck.h +3 -0
  288. package/eigen/Eigen/src/LU/InverseImpl.h +206 -285
  289. package/eigen/Eigen/src/LU/PartialPivLU.h +390 -428
  290. package/eigen/Eigen/src/LU/PartialPivLU_LAPACKE.h +54 -40
  291. package/eigen/Eigen/src/LU/arch/InverseSize4.h +72 -70
  292. package/eigen/Eigen/src/MetisSupport/InternalHeaderCheck.h +3 -0
  293. package/eigen/Eigen/src/MetisSupport/MetisSupport.h +81 -93
  294. package/eigen/Eigen/src/OrderingMethods/Amd.h +243 -265
  295. package/eigen/Eigen/src/OrderingMethods/Eigen_Colamd.h +831 -1004
  296. package/eigen/Eigen/src/OrderingMethods/InternalHeaderCheck.h +3 -0
  297. package/eigen/Eigen/src/OrderingMethods/Ordering.h +112 -119
  298. package/eigen/Eigen/src/PaStiXSupport/InternalHeaderCheck.h +3 -0
  299. package/eigen/Eigen/src/PaStiXSupport/PaStiXSupport.h +524 -570
  300. package/eigen/Eigen/src/PardisoSupport/InternalHeaderCheck.h +3 -0
  301. package/eigen/Eigen/src/PardisoSupport/PardisoSupport.h +385 -430
  302. package/eigen/Eigen/src/QR/ColPivHouseholderQR.h +479 -479
  303. package/eigen/Eigen/src/QR/ColPivHouseholderQR_LAPACKE.h +120 -56
  304. package/eigen/Eigen/src/QR/CompleteOrthogonalDecomposition.h +166 -153
  305. package/eigen/Eigen/src/QR/FullPivHouseholderQR.h +495 -475
  306. package/eigen/Eigen/src/QR/HouseholderQR.h +394 -285
  307. package/eigen/Eigen/src/QR/HouseholderQR_LAPACKE.h +32 -23
  308. package/eigen/Eigen/src/QR/InternalHeaderCheck.h +3 -0
  309. package/eigen/Eigen/src/SPQRSupport/InternalHeaderCheck.h +3 -0
  310. package/eigen/Eigen/src/SPQRSupport/SuiteSparseQRSupport.h +244 -264
  311. package/eigen/Eigen/src/SVD/BDCSVD.h +817 -713
  312. package/eigen/Eigen/src/SVD/BDCSVD_LAPACKE.h +174 -0
  313. package/eigen/Eigen/src/SVD/InternalHeaderCheck.h +3 -0
  314. package/eigen/Eigen/src/SVD/JacobiSVD.h +577 -543
  315. package/eigen/Eigen/src/SVD/JacobiSVD_LAPACKE.h +85 -49
  316. package/eigen/Eigen/src/SVD/SVDBase.h +242 -182
  317. package/eigen/Eigen/src/SVD/UpperBidiagonalization.h +200 -235
  318. package/eigen/Eigen/src/SparseCholesky/InternalHeaderCheck.h +3 -0
  319. package/eigen/Eigen/src/SparseCholesky/SimplicialCholesky.h +765 -594
  320. package/eigen/Eigen/src/SparseCholesky/SimplicialCholesky_impl.h +308 -94
  321. package/eigen/Eigen/src/SparseCore/AmbiVector.h +202 -251
  322. package/eigen/Eigen/src/SparseCore/CompressedStorage.h +184 -252
  323. package/eigen/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h +134 -178
  324. package/eigen/Eigen/src/SparseCore/InternalHeaderCheck.h +3 -0
  325. package/eigen/Eigen/src/SparseCore/SparseAssign.h +149 -140
  326. package/eigen/Eigen/src/SparseCore/SparseBlock.h +403 -440
  327. package/eigen/Eigen/src/SparseCore/SparseColEtree.h +100 -112
  328. package/eigen/Eigen/src/SparseCore/SparseCompressedBase.h +525 -303
  329. package/eigen/Eigen/src/SparseCore/SparseCwiseBinaryOp.h +555 -339
  330. package/eigen/Eigen/src/SparseCore/SparseCwiseUnaryOp.h +100 -108
  331. package/eigen/Eigen/src/SparseCore/SparseDenseProduct.h +169 -197
  332. package/eigen/Eigen/src/SparseCore/SparseDiagonalProduct.h +71 -71
  333. package/eigen/Eigen/src/SparseCore/SparseDot.h +49 -47
  334. package/eigen/Eigen/src/SparseCore/SparseFuzzy.h +13 -11
  335. package/eigen/Eigen/src/SparseCore/SparseMap.h +243 -253
  336. package/eigen/Eigen/src/SparseCore/SparseMatrix.h +1603 -1245
  337. package/eigen/Eigen/src/SparseCore/SparseMatrixBase.h +403 -350
  338. package/eigen/Eigen/src/SparseCore/SparsePermutation.h +186 -115
  339. package/eigen/Eigen/src/SparseCore/SparseProduct.h +94 -97
  340. package/eigen/Eigen/src/SparseCore/SparseRedux.h +22 -24
  341. package/eigen/Eigen/src/SparseCore/SparseRef.h +268 -295
  342. package/eigen/Eigen/src/SparseCore/SparseSelfAdjointView.h +370 -416
  343. package/eigen/Eigen/src/SparseCore/SparseSolverBase.h +78 -87
  344. package/eigen/Eigen/src/SparseCore/SparseSparseProductWithPruning.h +81 -95
  345. package/eigen/Eigen/src/SparseCore/SparseTranspose.h +62 -71
  346. package/eigen/Eigen/src/SparseCore/SparseTriangularView.h +132 -144
  347. package/eigen/Eigen/src/SparseCore/SparseUtil.h +138 -115
  348. package/eigen/Eigen/src/SparseCore/SparseVector.h +426 -372
  349. package/eigen/Eigen/src/SparseCore/SparseView.h +164 -193
  350. package/eigen/Eigen/src/SparseCore/TriangularSolver.h +129 -170
  351. package/eigen/Eigen/src/SparseLU/InternalHeaderCheck.h +3 -0
  352. package/eigen/Eigen/src/SparseLU/SparseLU.h +756 -710
  353. package/eigen/Eigen/src/SparseLU/SparseLUImpl.h +61 -48
  354. package/eigen/Eigen/src/SparseLU/SparseLU_Memory.h +102 -118
  355. package/eigen/Eigen/src/SparseLU/SparseLU_Structs.h +38 -35
  356. package/eigen/Eigen/src/SparseLU/SparseLU_SupernodalMatrix.h +245 -301
  357. package/eigen/Eigen/src/SparseLU/SparseLU_Utils.h +44 -49
  358. package/eigen/Eigen/src/SparseLU/SparseLU_column_bmod.h +104 -108
  359. package/eigen/Eigen/src/SparseLU/SparseLU_column_dfs.h +89 -100
  360. package/eigen/Eigen/src/SparseLU/SparseLU_copy_to_ucol.h +57 -58
  361. package/eigen/Eigen/src/SparseLU/SparseLU_heap_relax_snode.h +43 -55
  362. package/eigen/Eigen/src/SparseLU/SparseLU_kernel_bmod.h +74 -71
  363. package/eigen/Eigen/src/SparseLU/SparseLU_panel_bmod.h +124 -132
  364. package/eigen/Eigen/src/SparseLU/SparseLU_panel_dfs.h +136 -159
  365. package/eigen/Eigen/src/SparseLU/SparseLU_pivotL.h +51 -52
  366. package/eigen/Eigen/src/SparseLU/SparseLU_pruneL.h +67 -73
  367. package/eigen/Eigen/src/SparseLU/SparseLU_relax_snode.h +24 -26
  368. package/eigen/Eigen/src/SparseQR/InternalHeaderCheck.h +3 -0
  369. package/eigen/Eigen/src/SparseQR/SparseQR.h +450 -502
  370. package/eigen/Eigen/src/StlSupport/StdDeque.h +28 -93
  371. package/eigen/Eigen/src/StlSupport/StdList.h +28 -84
  372. package/eigen/Eigen/src/StlSupport/StdVector.h +28 -108
  373. package/eigen/Eigen/src/StlSupport/details.h +48 -50
  374. package/eigen/Eigen/src/SuperLUSupport/InternalHeaderCheck.h +3 -0
  375. package/eigen/Eigen/src/SuperLUSupport/SuperLUSupport.h +634 -730
  376. package/eigen/Eigen/src/ThreadPool/Barrier.h +70 -0
  377. package/eigen/Eigen/src/ThreadPool/CoreThreadPoolDevice.h +336 -0
  378. package/eigen/Eigen/src/ThreadPool/EventCount.h +241 -0
  379. package/eigen/Eigen/src/ThreadPool/ForkJoin.h +140 -0
  380. package/eigen/Eigen/src/ThreadPool/InternalHeaderCheck.h +4 -0
  381. package/eigen/Eigen/src/ThreadPool/NonBlockingThreadPool.h +587 -0
  382. package/eigen/Eigen/src/ThreadPool/RunQueue.h +230 -0
  383. package/eigen/Eigen/src/ThreadPool/ThreadCancel.h +21 -0
  384. package/eigen/Eigen/src/ThreadPool/ThreadEnvironment.h +43 -0
  385. package/eigen/Eigen/src/ThreadPool/ThreadLocal.h +289 -0
  386. package/eigen/Eigen/src/ThreadPool/ThreadPoolInterface.h +50 -0
  387. package/eigen/Eigen/src/ThreadPool/ThreadYield.h +16 -0
  388. package/eigen/Eigen/src/UmfPackSupport/InternalHeaderCheck.h +3 -0
  389. package/eigen/Eigen/src/UmfPackSupport/UmfPackSupport.h +428 -464
  390. package/eigen/Eigen/src/misc/Image.h +41 -43
  391. package/eigen/Eigen/src/misc/InternalHeaderCheck.h +3 -0
  392. package/eigen/Eigen/src/misc/Kernel.h +39 -41
  393. package/eigen/Eigen/src/misc/RealSvd2x2.h +19 -21
  394. package/eigen/Eigen/src/misc/blas.h +83 -426
  395. package/eigen/Eigen/src/misc/lapacke.h +9972 -16179
  396. package/eigen/Eigen/src/misc/lapacke_helpers.h +163 -0
  397. package/eigen/Eigen/src/misc/lapacke_mangling.h +4 -5
  398. package/eigen/Eigen/src/plugins/ArrayCwiseBinaryOps.inc +344 -0
  399. package/eigen/Eigen/src/plugins/ArrayCwiseUnaryOps.inc +544 -0
  400. package/eigen/Eigen/src/plugins/{BlockMethods.h → BlockMethods.inc} +434 -506
  401. package/eigen/Eigen/src/plugins/CommonCwiseBinaryOps.inc +116 -0
  402. package/eigen/Eigen/src/plugins/{CommonCwiseUnaryOps.h → CommonCwiseUnaryOps.inc} +58 -68
  403. package/eigen/Eigen/src/plugins/IndexedViewMethods.inc +192 -0
  404. package/eigen/Eigen/src/plugins/InternalHeaderCheck.inc +3 -0
  405. package/eigen/Eigen/src/plugins/MatrixCwiseBinaryOps.inc +331 -0
  406. package/eigen/Eigen/src/plugins/MatrixCwiseUnaryOps.inc +118 -0
  407. package/eigen/Eigen/src/plugins/ReshapedMethods.inc +133 -0
  408. package/package.json +1 -1
  409. package/eigen/COPYING.APACHE +0 -203
  410. package/eigen/COPYING.BSD +0 -26
  411. package/eigen/COPYING.GPL +0 -674
  412. package/eigen/COPYING.LGPL +0 -502
  413. package/eigen/COPYING.MINPACK +0 -51
  414. package/eigen/COPYING.MPL2 +0 -373
  415. package/eigen/COPYING.README +0 -18
  416. package/eigen/Eigen/src/Core/BooleanRedux.h +0 -162
  417. package/eigen/Eigen/src/Core/arch/CUDA/Complex.h +0 -258
  418. package/eigen/Eigen/src/Core/arch/Default/TypeCasting.h +0 -120
  419. package/eigen/Eigen/src/Core/arch/SYCL/SyclMemoryModel.h +0 -694
  420. package/eigen/Eigen/src/Core/util/NonMPL2.h +0 -3
  421. package/eigen/Eigen/src/SparseCore/MappedSparseMatrix.h +0 -67
  422. package/eigen/Eigen/src/SparseLU/SparseLU_gemm_kernel.h +0 -280
  423. package/eigen/Eigen/src/misc/lapack.h +0 -152
  424. package/eigen/Eigen/src/plugins/ArrayCwiseBinaryOps.h +0 -358
  425. package/eigen/Eigen/src/plugins/ArrayCwiseUnaryOps.h +0 -696
  426. package/eigen/Eigen/src/plugins/CommonCwiseBinaryOps.h +0 -115
  427. package/eigen/Eigen/src/plugins/IndexedViewMethods.h +0 -262
  428. package/eigen/Eigen/src/plugins/MatrixCwiseBinaryOps.h +0 -152
  429. package/eigen/Eigen/src/plugins/MatrixCwiseUnaryOps.h +0 -95
  430. package/eigen/Eigen/src/plugins/ReshapedMethods.h +0 -149
  431. package/eigen/README.md +0 -5
@@ -0,0 +1,1219 @@
1
+ // This file is part of Eigen, a lightweight C++ template library
2
+ // for linear algebra.
3
+ //
4
+ // Copyright (C) 2022 Intel Corporation
5
+ //
6
+ // This Source Code Form is subject to the terms of the Mozilla
7
+ // Public License v. 2.0. If a copy of the MPL was not distributed
8
+ // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
+
10
+ #ifndef EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H
11
+ #define EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H
12
+
13
+ template <bool isARowMajor = true>
14
+ EIGEN_ALWAYS_INLINE int64_t idA(int64_t i, int64_t j, int64_t LDA) {
15
+ EIGEN_IF_CONSTEXPR(isARowMajor) return i * LDA + j;
16
+ else return i + j * LDA;
17
+ }
18
+
19
+ /**
20
+ * This namespace contains various classes used to generate compile-time unrolls which are
21
+ * used throughout the trsm/gemm kernels. The unrolls are characterized as for-loops (1-D), nested
22
+ * for-loops (2-D), or triple nested for-loops (3-D). Unrolls are generated using template recursion
23
+ *
24
+ * Example, the 2-D for-loop is unrolled recursively by first flattening to a 1-D loop.
25
+ *
26
+ * for(startI = 0; startI < endI; startI++) for(startC = 0; startC < endI*endJ; startC++)
27
+ * for(startJ = 0; startJ < endJ; startJ++) ----> startI = (startC)/(endJ)
28
+ * func(startI,startJ) startJ = (startC)%(endJ)
29
+ * func(...)
30
+ *
31
+ * The 1-D loop can be unrolled recursively by using enable_if and defining an auxiliary function
32
+ * with a template parameter used as a counter.
33
+ *
34
+ * template <endI, endJ, counter>
35
+ * std::enable_if_t<(counter <= 0)> <---- tail case.
36
+ * aux_func {}
37
+ *
38
+ * template <endI, endJ, counter>
39
+ * std::enable_if_t<(counter > 0)> <---- actual for-loop
40
+ * aux_func {
41
+ * startC = endI*endJ - counter
42
+ * startI = (startC)/(endJ)
43
+ * startJ = (startC)%(endJ)
44
+ * func(startI, startJ)
45
+ * aux_func<endI, endJ, counter-1>()
46
+ * }
47
+ *
48
+ * Note: Additional wrapper functions are provided for aux_func which hides the counter template
49
+ * parameter since counter usually depends on endI, endJ, etc...
50
+ *
51
+ * Conventions:
52
+ * 1) endX: specifies the terminal value for the for-loop, (ex: for(startX = 0; startX < endX; startX++))
53
+ *
54
+ * 2) rem, remM, remK template parameters are used for deciding whether to use masked operations for
55
+ * handling remaining tails (when sizes are not multiples of PacketSize or EIGEN_AVX_MAX_NUM_ROW)
56
+ */
57
+ namespace unrolls {
58
+
59
// Returns an integer bitmask with the low `m` bits set, suitable for use as an
// AVX512 k-mask when loading/storing a partial packet of N lanes (N = packet
// size: 16, 8 or 4). Assumes 0 <= m <= N; unsupported N yields an empty mask.
template <int64_t N>
EIGEN_ALWAYS_INLINE auto remMask(int64_t m) {
  EIGEN_IF_CONSTEXPR(N == 16) { return 0xFFFF >> (16 - m); }
  else EIGEN_IF_CONSTEXPR(N == 8) {
    return 0xFF >> (8 - m);
  }
  else EIGEN_IF_CONSTEXPR(N == 4) {
    return 0x0F >> (4 - m);
  }
  return 0;
}
70
+
71
// In-place transpose of 8x8 element blocks held in 8 full-width AVX512
// registers. Specialized for Packet16f (float, two 8x8 half-blocks per zmm)
// and Packet8d (double, one 8x8 block).
template <typename Packet>
EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet, 8> &kernel);
73
+
74
// fp32 specialization: each zmm holds 16 floats, i.e. two 8-float halves.
// The sequence below transposes the 8x8 float sub-blocks without crossing
// zmm registers (standard unpack/permute/blend transpose network).
template <>
EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet16f, 8> &kernel) {
  // Stage 1: interleave adjacent row pairs at 32-bit granularity.
  __m512 T0 = _mm512_unpacklo_ps(kernel.packet[0], kernel.packet[1]);
  __m512 T1 = _mm512_unpackhi_ps(kernel.packet[0], kernel.packet[1]);
  __m512 T2 = _mm512_unpacklo_ps(kernel.packet[2], kernel.packet[3]);
  __m512 T3 = _mm512_unpackhi_ps(kernel.packet[2], kernel.packet[3]);
  __m512 T4 = _mm512_unpacklo_ps(kernel.packet[4], kernel.packet[5]);
  __m512 T5 = _mm512_unpackhi_ps(kernel.packet[4], kernel.packet[5]);
  __m512 T6 = _mm512_unpacklo_ps(kernel.packet[6], kernel.packet[7]);
  __m512 T7 = _mm512_unpackhi_ps(kernel.packet[6], kernel.packet[7]);

  // Stage 2: combine the pair-interleaved rows at 64-bit granularity
  // (viewed as doubles purely for the wider unpack; no value conversion).
  kernel.packet[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2)));
  kernel.packet[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2)));
  kernel.packet[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3)));
  kernel.packet[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3)));
  kernel.packet[4] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6)));
  kernel.packet[5] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6)));
  kernel.packet[6] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7)));
  kernel.packet[7] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7)));

  // Stage 3: exchange 128-bit lanes between register pairs (0<->4, 1<->5,
  // 2<->6, 3<->7) via a within-register permute (0x4E swaps lane pairs)
  // followed by a blend that keeps alternating 128-bit lanes (mask 0xF0F0).
  T0 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[4]), 0x4E));
  T0 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[0], T0);
  T4 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[0]), 0x4E));
  T4 = _mm512_mask_blend_ps(0xF0F0, T4, kernel.packet[4]);
  T1 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[5]), 0x4E));
  T1 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[1], T1);
  T5 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[1]), 0x4E));
  T5 = _mm512_mask_blend_ps(0xF0F0, T5, kernel.packet[5]);
  T2 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[6]), 0x4E));
  T2 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[2], T2);
  T6 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[2]), 0x4E));
  T6 = _mm512_mask_blend_ps(0xF0F0, T6, kernel.packet[6]);
  T3 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[7]), 0x4E));
  T3 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[3], T3);
  T7 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[3]), 0x4E));
  T7 = _mm512_mask_blend_ps(0xF0F0, T7, kernel.packet[7]);

  // Write the transposed rows back into the caller's block.
  kernel.packet[0] = T0;
  kernel.packet[1] = T1;
  kernel.packet[2] = T2;
  kernel.packet[3] = T3;
  kernel.packet[4] = T4;
  kernel.packet[5] = T5;
  kernel.packet[6] = T6;
  kernel.packet[7] = T7;
}
120
+
121
// fp64 specialization: one zmm holds exactly 8 doubles, so a PacketBlock of
// 8 registers is precisely one 8x8 block and Eigen's generic packet
// transpose does the whole job.
template <>
EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet8d, 8> &kernel) {
  ptranspose(kernel);
}
125
+
126
+ /***
127
+ * Unrolls for transposed C stores
128
+ */
129
// Compile-time unrolls used to transpose accumulator registers and store the
// result (transposed C) back to memory, accumulating into existing C values.
template <typename Scalar>
class trans {
 public:
  // Full-width register (zmm) and half-width view (ymm for fp32; for fp64 the
  // "half" is still the full register since one zmm already holds 8 doubles).
  using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
  using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecFullDouble>::type;
  static constexpr int64_t PacketSize = packet_traits<Scalar>::size;

  /***********************************
   * Auxiliary Functions for:
   *  - storeC
   ***********************************
   */

  /**
   * aux_storeC
   *
   * 1-D unroll
   *      for(startN = 0; startN < endN; startN++)
   *
   * (endN <= PacketSize) is required to handle the fp32 case, see comments in transStoreC
   *
   **/
  template <int64_t endN, int64_t counter, int64_t unrollN, int64_t packetIndexOffset, bool remM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && endN <= PacketSize)> aux_storeC(
      Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t remM_ = 0) {
    // Template recursion: startN walks forward as `counter` counts down.
    constexpr int64_t counterReverse = endN - counter;
    constexpr int64_t startN = counterReverse;

    EIGEN_IF_CONSTEXPR(startN < EIGEN_AVX_MAX_NUM_ROW) {
      // Accumulate into C: C[:, startN] += half-register startN of the block.
      // remM selects masked loads/stores for a partial column (remM_ rows).
      EIGEN_IF_CONSTEXPR(remM) {
        pstoreu<Scalar>(
            C_arr + LDC * startN,
            padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN, remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
                 preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN / PacketSize) * startN]),
                 remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
            remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
      }
      else {
        pstoreu<Scalar>(C_arr + LDC * startN,
                        padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN),
                             preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN / PacketSize) * startN])));
      }
    }
    else {  // This block is only needed for fp32 case
      // Reinterpret as __m512 for _mm512_shuffle_f32x4
      vecFullFloat zmm2vecFullFloat = preinterpret<vecFullFloat>(
          zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)]);
      // Swap lower and upper half of avx register.
      zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)] =
          preinterpret<vec>(_mm512_shuffle_f32x4(zmm2vecFullFloat, zmm2vecFullFloat, 0b01001110));

      EIGEN_IF_CONSTEXPR(remM) {
        pstoreu<Scalar>(
            C_arr + LDC * startN,
            padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN, remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
                 preinterpret<vecHalf>(
                     zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)])),
            remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
      }
      else {
        pstoreu<Scalar>(
            C_arr + LDC * startN,
            padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN),
                 preinterpret<vecHalf>(
                     zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)])));
      }
    }
    aux_storeC<endN, counter - 1, unrollN, packetIndexOffset, remM>(C_arr, LDC, zmm, remM_);
  }

  // Recursion terminator (counter exhausted or endN > PacketSize): no-op.
  template <int64_t endN, int64_t counter, int64_t unrollN, int64_t packetIndexOffset, bool remM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0 && endN <= PacketSize)> aux_storeC(
      Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t remM_ = 0) {
    EIGEN_UNUSED_VARIABLE(C_arr);
    EIGEN_UNUSED_VARIABLE(LDC);
    EIGEN_UNUSED_VARIABLE(zmm);
    EIGEN_UNUSED_VARIABLE(remM_);
  }

  // Wrapper hiding the recursion counter (counter starts at endN).
  template <int64_t endN, int64_t unrollN, int64_t packetIndexOffset, bool remM>
  static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC,
                                         PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                         int64_t remM_ = 0) {
    aux_storeC<endN, endN, unrollN, packetIndexOffset, remM>(C_arr, LDC, zmm, remM_);
  }

  /**
   * Transposes LxunrollN row major block of matrices stored `EIGEN_AVX_MAX_NUM_ACC` zmm registers to
   * "unrollN"xL ymm registers to be stored col-major into C.
   *
   * For 8x48, the 8x48 block (row-major) is stored in zmm as follows:
   *
   * ```
   * row0: zmm0 zmm1 zmm2
   * row1: zmm3 zmm4 zmm5
   *   .
   *   .
   * row7: zmm21 zmm22 zmm23
   *
   * For 8x32, the 8x32 block (row-major) is stored in zmm as follows:
   *
   * row0: zmm0 zmm1
   * row1: zmm2 zmm3
   *   .
   *   .
   * row7: zmm14 zmm15
   * ```
   *
   * In general we will have {1,2,3} groups of avx registers each of size
   * `EIGEN_AVX_MAX_NUM_ROW`. packetIndexOffset is used to select which "block" of
   * avx registers are being transposed.
   */
  template <int64_t unrollN, int64_t packetIndexOffset>
  static EIGEN_ALWAYS_INLINE void transpose(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted
    // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller.
    constexpr int64_t zmmStride = unrollN / PacketSize;
    // Gather the 8 registers forming one column group, transpose, scatter back.
    PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> r;
    r.packet[0] = zmm.packet[packetIndexOffset + zmmStride * 0];
    r.packet[1] = zmm.packet[packetIndexOffset + zmmStride * 1];
    r.packet[2] = zmm.packet[packetIndexOffset + zmmStride * 2];
    r.packet[3] = zmm.packet[packetIndexOffset + zmmStride * 3];
    r.packet[4] = zmm.packet[packetIndexOffset + zmmStride * 4];
    r.packet[5] = zmm.packet[packetIndexOffset + zmmStride * 5];
    r.packet[6] = zmm.packet[packetIndexOffset + zmmStride * 6];
    r.packet[7] = zmm.packet[packetIndexOffset + zmmStride * 7];
    trans8x8blocks(r);
    zmm.packet[packetIndexOffset + zmmStride * 0] = r.packet[0];
    zmm.packet[packetIndexOffset + zmmStride * 1] = r.packet[1];
    zmm.packet[packetIndexOffset + zmmStride * 2] = r.packet[2];
    zmm.packet[packetIndexOffset + zmmStride * 3] = r.packet[3];
    zmm.packet[packetIndexOffset + zmmStride * 4] = r.packet[4];
    zmm.packet[packetIndexOffset + zmmStride * 5] = r.packet[5];
    zmm.packet[packetIndexOffset + zmmStride * 6] = r.packet[6];
    zmm.packet[packetIndexOffset + zmmStride * 7] = r.packet[7];
  }
};
266
+
267
+ /**
268
+ * Unrolls for copyBToRowMajor
269
+ *
270
+ * Idea:
271
+ * 1) Load a block of right-hand sides to registers (using loadB).
272
+ * 2) Convert the block from column-major to row-major (transposeLxL)
273
+ * 3) Store the blocks from register either to a temp array (toTemp == true), or back to B (toTemp == false).
274
+ *
275
+ * We use at most EIGEN_AVX_MAX_NUM_ACC avx registers to store the blocks of B. The remaining registers are
276
+ * used as temps for transposing.
277
+ *
278
+ * Blocks will be of size Lx{U1,U2,U3}. packetIndexOffset is used to index between these subblocks
279
+ * For fp32, PacketSize = 2*EIGEN_AVX_MAX_NUM_ROW, so we reinterpret packets as packets half the size (zmm -> ymm).
280
+ */
281
// Compile-time unrolls for copyBToRowMajor: load a block of B, transpose it
// in registers (col-major -> row-major), and store it either to a temp
// buffer (toTemp == true) or back into B (toTemp == false).
template <typename Scalar>
class transB {
 public:
  // Full-width (zmm) and half-width (ymm for fp32) register types; for fp64
  // the half-width type is the full register.
  using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
  using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecFullDouble>::type;
  static constexpr int64_t PacketSize = packet_traits<Scalar>::size;

  /***********************************
   * Auxiliary Functions for:
   *  - loadB
   *  - storeB
   *  - loadBBlock
   *  - storeBBlock
   ***********************************
   */

  /**
   * aux_loadB
   *
   * 1-D unroll
   *      for(startN = 0; startN < endN; startN++)
   **/
  template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM, int64_t remN_>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB(
      Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
      int64_t remM_ = 0) {
    constexpr int64_t counterReverse = endN - counter;
    constexpr int64_t startN = counterReverse;

    // remM (runtime remM_ rows) and remN_ (compile-time column tail) select
    // masked loads; otherwise a full half-register load is used.
    EIGEN_IF_CONSTEXPR(remM) {
      ymm.packet[packetIndexOffset + startN] =
          ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB], remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
    }
    else {
      EIGEN_IF_CONSTEXPR(remN_ == 0) {
        ymm.packet[packetIndexOffset + startN] = ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB]);
      }
      else ymm.packet[packetIndexOffset + startN] =
          ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB], remMask<EIGEN_AVX_MAX_NUM_ROW>(remN_));
    }

    aux_loadB<endN, counter - 1, packetIndexOffset, remM, remN_>(B_arr, LDB, ymm, remM_);
  }

  // Recursion terminator: no-op.
  template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM, int64_t remN_>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB(
      Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
      int64_t remM_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_arr);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(ymm);
    EIGEN_UNUSED_VARIABLE(remM_);
  }

  /**
   * aux_storeB
   *
   * 1-D unroll
   *      for(startN = 0; startN < endN; startN++)
   **/
  template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remK, bool remM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeB(
      Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t rem_ = 0) {
    constexpr int64_t counterReverse = endN - counter;
    constexpr int64_t startN = counterReverse;

    // Either tail flag (rows or depth) forces a masked store of rem_ lanes.
    EIGEN_IF_CONSTEXPR(remK || remM) {
      pstoreu<Scalar>(&B_arr[startN * LDB], ymm.packet[packetIndexOffset + startN],
                      remMask<EIGEN_AVX_MAX_NUM_ROW>(rem_));
    }
    else {
      pstoreu<Scalar>(&B_arr[startN * LDB], ymm.packet[packetIndexOffset + startN]);
    }

    aux_storeB<endN, counter - 1, packetIndexOffset, remK, remM>(B_arr, LDB, ymm, rem_);
  }

  // Recursion terminator: no-op.
  template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remK, bool remM>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeB(
      Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_arr);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(ymm);
    EIGEN_UNUSED_VARIABLE(rem_);
  }

  /**
   * aux_loadBBlock
   *
   * 1-D unroll
   *      for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW)
   **/
  template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remN_>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadBBlock(
      Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
      PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
    constexpr int64_t counterReverse = endN - counter;
    constexpr int64_t startN = counterReverse;
    // This (non-toTemp) path reloads from the temp buffer; see loadBBlock.
    transB::template loadB<EIGEN_AVX_MAX_NUM_ROW, startN, false, (toTemp ? 0 : remN_)>(&B_temp[startN], LDB_, ymm);
    aux_loadBBlock<endN, counter - EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, remN_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
  }

  // Recursion terminator: no-op.
  template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remN_>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadBBlock(
      Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
      PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_arr);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(B_temp);
    EIGEN_UNUSED_VARIABLE(LDB_);
    EIGEN_UNUSED_VARIABLE(ymm);
    EIGEN_UNUSED_VARIABLE(remM_);
  }

  /**
   * aux_storeBBlock
   *
   * 1-D unroll
   *      for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW)
   **/
  template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remK_>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeBBlock(
      Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
      PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
    constexpr int64_t counterReverse = endN - counter;
    constexpr int64_t startN = counterReverse;

    // Destination is either the temp buffer (with depth tail remK_) or B
    // itself (with row tail remM/remM_).
    EIGEN_IF_CONSTEXPR(toTemp) {
      transB::template storeB<EIGEN_AVX_MAX_NUM_ROW, startN, remK_ != 0, false>(&B_temp[startN], LDB_, ymm, remK_);
    }
    else {
      transB::template storeB<std::min(EIGEN_AVX_MAX_NUM_ROW, endN), startN, false, remM>(&B_arr[0 + startN * LDB], LDB,
                                                                                          ymm, remM_);
    }
    aux_storeBBlock<endN, counter - EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, remK_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
  }

  // Recursion terminator: no-op.
  template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remK_>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeBBlock(
      Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
      PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_arr);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(B_temp);
    EIGEN_UNUSED_VARIABLE(LDB_);
    EIGEN_UNUSED_VARIABLE(ymm);
    EIGEN_UNUSED_VARIABLE(remM_);
  }

  /********************************************************
   * Wrappers for aux_XXXX to hide counter parameter
   ********************************************************/

  template <int64_t endN, int64_t packetIndexOffset, bool remM, int64_t remN_>
  static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_arr, int64_t LDB,
                                        PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                        int64_t remM_ = 0) {
    aux_loadB<endN, endN, packetIndexOffset, remM, remN_>(B_arr, LDB, ymm, remM_);
  }

  template <int64_t endN, int64_t packetIndexOffset, bool remK, bool remM>
  static EIGEN_ALWAYS_INLINE void storeB(Scalar *B_arr, int64_t LDB,
                                         PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                         int64_t rem_ = 0) {
    aux_storeB<endN, endN, packetIndexOffset, remK, remM>(B_arr, LDB, ymm, rem_);
  }

  // Loads from B directly when copying to temp; otherwise reloads the
  // previously-written temp buffer via the aux unroll.
  template <int64_t unrollN, bool toTemp, bool remM, int64_t remN_ = 0>
  static EIGEN_ALWAYS_INLINE void loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
                                             PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                             int64_t remM_ = 0) {
    EIGEN_IF_CONSTEXPR(toTemp) { transB::template loadB<unrollN, 0, remM, 0>(&B_arr[0], LDB, ymm, remM_); }
    else {
      aux_loadBBlock<unrollN, unrollN, toTemp, remM, remN_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
    }
  }

  template <int64_t unrollN, bool toTemp, bool remM, int64_t remK_>
  static EIGEN_ALWAYS_INLINE void storeBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
                                              PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                              int64_t remM_ = 0) {
    aux_storeBBlock<unrollN, unrollN, toTemp, remM, remK_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
  }

  // In-register LxL (8x8) transpose of one group of half-width registers.
  template <int64_t packetIndexOffset>
  static EIGEN_ALWAYS_INLINE void transposeLxL(PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm) {
    // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted
    // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller.
    PacketBlock<vecHalf, EIGEN_AVX_MAX_NUM_ROW> r;
    r.packet[0] = ymm.packet[packetIndexOffset + 0];
    r.packet[1] = ymm.packet[packetIndexOffset + 1];
    r.packet[2] = ymm.packet[packetIndexOffset + 2];
    r.packet[3] = ymm.packet[packetIndexOffset + 3];
    r.packet[4] = ymm.packet[packetIndexOffset + 4];
    r.packet[5] = ymm.packet[packetIndexOffset + 5];
    r.packet[6] = ymm.packet[packetIndexOffset + 6];
    r.packet[7] = ymm.packet[packetIndexOffset + 7];
    ptranspose(r);
    ymm.packet[packetIndexOffset + 0] = r.packet[0];
    ymm.packet[packetIndexOffset + 1] = r.packet[1];
    ymm.packet[packetIndexOffset + 2] = r.packet[2];
    ymm.packet[packetIndexOffset + 3] = r.packet[3];
    ymm.packet[packetIndexOffset + 4] = r.packet[4];
    ymm.packet[packetIndexOffset + 5] = r.packet[5];
    ymm.packet[packetIndexOffset + 6] = r.packet[6];
    ymm.packet[packetIndexOffset + 7] = r.packet[7];
  }

  // Dispatch on the compile-time block width unrollN: load, transpose in
  // groups of EIGEN_AVX_MAX_NUM_ROW, and store.
  template <int64_t unrollN, bool toTemp, bool remM>
  static EIGEN_ALWAYS_INLINE void transB_kernel(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
                                                PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
                                                int64_t remM_ = 0) {
    constexpr int64_t U3 = PacketSize * 3;
    constexpr int64_t U2 = PacketSize * 2;
    constexpr int64_t U1 = PacketSize * 1;
    /**
     * Unrolls needed for each case:
     *  - AVX512 fp32 48 32 16 8 4 2 1
     *  - AVX512 fp64 24 16 8  4 2 1
     *
     * For fp32 L and U1 are 1:2 so for U3/U2 cases the loads/stores need to be split up.
     */
    EIGEN_IF_CONSTEXPR(unrollN == U3) {
      // load LxU3 B col major, transpose LxU3 row major
      constexpr int64_t maxUBlock = std::min(3 * EIGEN_AVX_MAX_NUM_ROW, U3);
      transB::template loadBBlock<maxUBlock, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
      transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
      transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
      transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
      transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);

      // fp32: maxUBlock (24) < U3 (48), so process the second half separately.
      EIGEN_IF_CONSTEXPR(maxUBlock < U3) {
        transB::template loadBBlock<maxUBlock, toTemp, remM>(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_,
                                                             ymm, remM_);
        transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
        transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
        transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
        transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_,
                                                                 ymm, remM_);
      }
    }
    else EIGEN_IF_CONSTEXPR(unrollN == U2) {
      // load LxU2 B col major, transpose LxU2 row major
      constexpr int64_t maxUBlock = std::min(3 * EIGEN_AVX_MAX_NUM_ROW, U2);
      transB::template loadBBlock<maxUBlock, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
      transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
      transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
      EIGEN_IF_CONSTEXPR(maxUBlock < U2) transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
      transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);

      EIGEN_IF_CONSTEXPR(maxUBlock < U2) {
        transB::template loadBBlock<EIGEN_AVX_MAX_NUM_ROW, toTemp, remM>(&B_arr[maxUBlock * LDB], LDB,
                                                                         &B_temp[maxUBlock], LDB_, ymm, remM_);
        transB::template transposeLxL<0>(ymm);
        transB::template storeBBlock<EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, 0>(&B_arr[maxUBlock * LDB], LDB,
                                                                             &B_temp[maxUBlock], LDB_, ymm, remM_);
      }
    }
    else EIGEN_IF_CONSTEXPR(unrollN == U1) {
      // load LxU1 B col major, transpose LxU1 row major
      transB::template loadBBlock<U1, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
      transB::template transposeLxL<0>(ymm);
      EIGEN_IF_CONSTEXPR(EIGEN_AVX_MAX_NUM_ROW < U1) { transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); }
      transB::template storeBBlock<U1, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
    }
    else EIGEN_IF_CONSTEXPR(unrollN == 8 && U1 > 8) {
      // load Lx8 B col major, transpose Lx8 row major
      transB::template loadBBlock<8, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
      transB::template transposeLxL<0>(ymm);
      transB::template storeBBlock<8, toTemp, remM, 8>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
    }
    else EIGEN_IF_CONSTEXPR(unrollN == 4 && U1 > 4) {
      // load Lx4 B col major, transpose Lx4 row major
      transB::template loadBBlock<4, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
      transB::template transposeLxL<0>(ymm);
      transB::template storeBBlock<4, toTemp, remM, 4>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
    }
    else EIGEN_IF_CONSTEXPR(unrollN == 2) {
      // load Lx2 B col major, transpose Lx2 row major
      transB::template loadBBlock<2, toTemp, remM, 2>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
      transB::template transposeLxL<0>(ymm);
      transB::template storeBBlock<2, toTemp, remM, 2>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
    }
    else EIGEN_IF_CONSTEXPR(unrollN == 1) {
      // load Lx1 B col major, transpose Lx1 row major
      transB::template loadBBlock<1, toTemp, remM, 1>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
      transB::template transposeLxL<0>(ymm);
      transB::template storeBBlock<1, toTemp, remM, 1>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
    }
  }
};
572
+
573
+ /**
574
+ * Unrolls for triSolveKernel
575
+ *
576
+ * Idea:
577
+ * 1) Load a block of right-hand sides to registers in RHSInPacket (using loadRHS).
578
+ * 2) Do triangular solve with RHSInPacket and a small block of A (triangular matrix)
579
+ * stored in AInPacket (using triSolveMicroKernel).
580
+ * 3) Store final results (in avx registers) back into memory (using storeRHS).
581
+ *
582
+ * RHSInPacket uses at most EIGEN_AVX_MAX_NUM_ACC avx registers and AInPacket uses at most
583
+ * EIGEN_AVX_MAX_NUM_ROW registers.
584
+ */
585
+ template <typename Scalar>
586
+ class trsm {
587
+ public:
588
+ using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
589
+ static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
590
+
591
+ /***********************************
592
+ * Auxiliary Functions for:
593
+ * - loadRHS
594
+ * - storeRHS
595
+ * - divRHSByDiag
596
+ * - updateRHS
597
+ * - triSolveMicroKernel
598
+ ************************************/
599
+ /**
600
+ * aux_loadRHS
601
+ *
602
+ * 2-D unroll
603
+ * for(startM = 0; startM < endM; startM++)
604
+ * for(startK = 0; startK < endK; startK++)
605
+ **/
606
+ template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
607
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadRHS(
608
+ Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
609
+ constexpr int64_t counterReverse = endM * endK - counter;
610
+ constexpr int64_t startM = counterReverse / (endK);
611
+ constexpr int64_t startK = counterReverse % endK;
612
+
613
+ constexpr int64_t packetIndex = startM * endK + startK;
614
+ constexpr int64_t startM_ = isFWDSolve ? startM : -startM;
615
+ const int64_t rhsIndex = (startK * PacketSize) + startM_ * LDB;
616
+ EIGEN_IF_CONSTEXPR(krem) {
617
+ RHSInPacket.packet[packetIndex] = ploadu<vec>(&B_arr[rhsIndex], remMask<PacketSize>(rem));
618
+ }
619
+ else {
620
+ RHSInPacket.packet[packetIndex] = ploadu<vec>(&B_arr[rhsIndex]);
621
+ }
622
+ aux_loadRHS<isFWDSolve, endM, endK, counter - 1, krem>(B_arr, LDB, RHSInPacket, rem);
623
+ }
624
+
625
+ template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
626
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadRHS(
627
+ Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
628
+ EIGEN_UNUSED_VARIABLE(B_arr);
629
+ EIGEN_UNUSED_VARIABLE(LDB);
630
+ EIGEN_UNUSED_VARIABLE(RHSInPacket);
631
+ EIGEN_UNUSED_VARIABLE(rem);
632
+ }
633
+
634
+ /**
635
+ * aux_storeRHS
636
+ *
637
+ * 2-D unroll
638
+ * for(startM = 0; startM < endM; startM++)
639
+ * for(startK = 0; startK < endK; startK++)
640
+ **/
641
+ template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
642
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeRHS(
643
+ Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
644
+ constexpr int64_t counterReverse = endM * endK - counter;
645
+ constexpr int64_t startM = counterReverse / (endK);
646
+ constexpr int64_t startK = counterReverse % endK;
647
+
648
+ constexpr int64_t packetIndex = startM * endK + startK;
649
+ constexpr int64_t startM_ = isFWDSolve ? startM : -startM;
650
+ const int64_t rhsIndex = (startK * PacketSize) + startM_ * LDB;
651
+ EIGEN_IF_CONSTEXPR(krem) {
652
+ pstoreu<Scalar>(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex], remMask<PacketSize>(rem));
653
+ }
654
+ else {
655
+ pstoreu<Scalar>(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex]);
656
+ }
657
+ aux_storeRHS<isFWDSolve, endM, endK, counter - 1, krem>(B_arr, LDB, RHSInPacket, rem);
658
+ }
659
+
660
+ template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
661
+ static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeRHS(
662
+ Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
663
+ EIGEN_UNUSED_VARIABLE(B_arr);
664
+ EIGEN_UNUSED_VARIABLE(LDB);
665
+ EIGEN_UNUSED_VARIABLE(RHSInPacket);
666
+ EIGEN_UNUSED_VARIABLE(rem);
667
+ }
668
+
669
/**
 * aux_divRHSByDiag
 *
 * Multiplies each RHS packet of row currM by AInPacket.packet[currM], which
 * holds the broadcast reciprocal of the corresponding diagonal element of A
 * (filled in by aux_updateRHS), i.e. it performs the division step of the
 * triangular solve.
 *
 * currM may be -1; (currM >= 0) in the enable_if checks for this so the
 * compiler never instantiates an out-of-bounds access.
 *
 * 1-D compile-time unroll equivalent to:
 *   for(startK = 0; startK < endK; startK++)
 **/
template <int64_t currM, int64_t endK, int64_t counter>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && currM >= 0)> aux_divRHSByDiag(
    PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
  // Recover startK from the recursion counter.
  constexpr int64_t counterReverse = endK - counter;
  constexpr int64_t startK = counterReverse;

  constexpr int64_t packetIndex = currM * endK + startK;
  RHSInPacket.packet[packetIndex] = pmul(AInPacket.packet[currM], RHSInPacket.packet[packetIndex]);
  aux_divRHSByDiag<currM, endK, counter - 1>(RHSInPacket, AInPacket);
}
687
+
688
// Recursion terminator for aux_divRHSByDiag; also absorbs the currM == -1
// instantiation that would otherwise index out of bounds.
template <int64_t currM, int64_t endK, int64_t counter>
static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0 && currM >= 0)> aux_divRHSByDiag(
    PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
  EIGEN_UNUSED_VARIABLE(RHSInPacket);
  EIGEN_UNUSED_VARIABLE(AInPacket);
}
694
+
695
/**
 * aux_updateRHS
 *
 * Substitution update: subtracts the contribution of the already-solved RHS
 * row (currentM - 1) from the remaining rows, and interleaves the broadcasts
 * of the next column of A needed by the following iteration.
 *
 * 2-D compile-time unroll equivalent to:
 *   for(startM = initM; startM < endM; startM++)
 *     for(startK = 0; startK < endK; startK++)
 **/
template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t initM, int64_t endM, int64_t endK,
          int64_t counter, int64_t currentM>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateRHS(
    Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
    PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
  // Recover the (startM, startK) loop indices from the recursion counter.
  constexpr int64_t counterReverse = (endM - initM) * endK - counter;
  constexpr int64_t startM = initM + counterReverse / (endK);
  constexpr int64_t startK = counterReverse % endK;

  // For each row of A, first update all corresponding RHS:
  // RHS[startM] -= A[startM, currentM-1] * RHS[currentM-1]
  constexpr int64_t packetIndex = startM * endK + startK;
  EIGEN_IF_CONSTEXPR(currentM > 0) {
    RHSInPacket.packet[packetIndex] =
        pnmadd(AInPacket.packet[startM], RHSInPacket.packet[(currentM - 1) * endK + startK],
               RHSInPacket.packet[packetIndex]);
  }

  EIGEN_IF_CONSTEXPR(startK == endK - 1) {
    // Once all RHS for previous row of A is updated, we broadcast the next element in the column A_{i, currentM}.
    EIGEN_IF_CONSTEXPR(startM == currentM && !isUnitDiag) {
      // If diagonal is not unit, we broadcast reciprocals of diagonals AInPacket.packet[currentM].
      // This will be used in divRHSByDiag.
      // For backward substitution (isFWDSolve == false) the indices are negated.
      EIGEN_IF_CONSTEXPR(isFWDSolve)
      AInPacket.packet[currentM] = pset1<vec>(Scalar(1) / A_arr[idA<isARowMajor>(currentM, currentM, LDA)]);
      else AInPacket.packet[currentM] = pset1<vec>(Scalar(1) / A_arr[idA<isARowMajor>(-currentM, -currentM, LDA)]);
    }
    else {
      // Broadcast next off-diagonal element of A.
      EIGEN_IF_CONSTEXPR(isFWDSolve)
      AInPacket.packet[startM] = pset1<vec>(A_arr[idA<isARowMajor>(startM, currentM, LDA)]);
      else AInPacket.packet[startM] = pset1<vec>(A_arr[idA<isARowMajor>(-startM, -currentM, LDA)]);
    }
  }

  aux_updateRHS<isARowMajor, isFWDSolve, isUnitDiag, initM, endM, endK, counter - 1, currentM>(
      A_arr, LDA, RHSInPacket, AInPacket);
}
739
+
740
// Recursion terminator for aux_updateRHS (counter <= 0): update complete.
template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t initM, int64_t endM, int64_t endK,
          int64_t counter, int64_t currentM>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateRHS(
    Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
    PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
  EIGEN_UNUSED_VARIABLE(A_arr);
  EIGEN_UNUSED_VARIABLE(LDA);
  EIGEN_UNUSED_VARIABLE(RHSInPacket);
  EIGEN_UNUSED_VARIABLE(AInPacket);
}
750
+
751
/**
 * aux_triSolveMicroKernel
 *
 * Core triangular-solve recursion: for each row startM, divide its RHS by the
 * (previously broadcast) diagonal reciprocal, then propagate the solved row
 * into the remaining rows via updateRHS.
 *
 * 1-D compile-time unroll equivalent to:
 *   for(startM = 0; startM < endM; startM++)
 **/
template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t counter, int64_t numK>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_triSolveMicroKernel(
    Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
    PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
  constexpr int64_t counterReverse = endM - counter;
  constexpr int64_t startM = counterReverse;

  constexpr int64_t currentM = startM;
  // Divides the right-hand side in row startM, by the diagonal value of A
  // broadcasted to AInPacket.packet[startM-1] in the previous iteration.
  //
  // Without "if constexpr" the compiler instantiates the case <-1, numK>;
  // this is handled with enable_if (inside divRHSByDiag) to prevent
  // out-of-bound warnings from the compiler.
  EIGEN_IF_CONSTEXPR(!isUnitDiag && startM > 0)
  trsm::template divRHSByDiag<startM - 1, numK>(RHSInPacket, AInPacket);

  // After division, the rhs corresponding to subsequent rows of A can be partially updated.
  // We also broadcast the reciprocal of the next diagonal to AInPacket.packet[currentM] (if needed)
  // to be used in the next iteration.
  trsm::template updateRHS<isARowMajor, isFWDSolve, isUnitDiag, startM, endM, numK, currentM>(A_arr, LDA, RHSInPacket,
                                                                                              AInPacket);

  // Handle division for the RHS corresponding to the final row of A.
  EIGEN_IF_CONSTEXPR(!isUnitDiag && startM == endM - 1)
  trsm::template divRHSByDiag<startM, numK>(RHSInPacket, AInPacket);

  aux_triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, endM, counter - 1, numK>(A_arr, LDA, RHSInPacket,
                                                                                       AInPacket);
}
787
+
788
// Recursion terminator for aux_triSolveMicroKernel (counter <= 0): solve complete.
template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t counter, int64_t numK>
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_triSolveMicroKernel(
    Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
    PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
  EIGEN_UNUSED_VARIABLE(A_arr);
  EIGEN_UNUSED_VARIABLE(LDA);
  EIGEN_UNUSED_VARIABLE(RHSInPacket);
  EIGEN_UNUSED_VARIABLE(AInPacket);
}
797
+
798
+ /********************************************************
799
+ * Wrappers for aux_XXXX to hide counter parameter
800
+ ********************************************************/
801
+
802
+ /**
803
+ * Load endMxendK block of B to RHSInPacket
804
+ * Masked loads are used for cases where endK is not a multiple of PacketSize
805
+ */
806
+ template <bool isFWDSolve, int64_t endM, int64_t endK, bool krem = false>
807
+ static EIGEN_ALWAYS_INLINE void loadRHS(Scalar *B_arr, int64_t LDB,
808
+ PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
809
+ aux_loadRHS<isFWDSolve, endM, endK, endM * endK, krem>(B_arr, LDB, RHSInPacket, rem);
810
+ }
811
+
812
/**
 * Store the endM x endK block held in RHSInPacket back to B.
 * Masked stores are used for cases where endK is not a multiple of PacketSize.
 */
template <bool isFWDSolve, int64_t endM, int64_t endK, bool krem = false>
static EIGEN_ALWAYS_INLINE void storeRHS(Scalar *B_arr, int64_t LDB,
                                         PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
  aux_storeRHS<isFWDSolve, endM, endK, endM * endK, krem>(B_arr, LDB, RHSInPacket, rem);
}
821
+
822
+ /**
823
+ * Only used if Triangular matrix has non-unit diagonal values
824
+ */
825
+ template <int64_t currM, int64_t endK>
826
+ static EIGEN_ALWAYS_INLINE void divRHSByDiag(PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
827
+ PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
828
+ aux_divRHSByDiag<currM, endK, endK>(RHSInPacket, AInPacket);
829
+ }
830
+
831
+ /**
832
+ * Update right-hand sides (stored in avx registers)
833
+ * Traversing along the column A_{i,currentM}, where currentM <= i <= endM, and broadcasting each value to AInPacket.
834
+ **/
835
+ template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t startM, int64_t endM, int64_t endK,
836
+ int64_t currentM>
837
+ static EIGEN_ALWAYS_INLINE void updateRHS(Scalar *A_arr, int64_t LDA,
838
+ PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
839
+ PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
840
+ aux_updateRHS<isARowMajor, isFWDSolve, isUnitDiag, startM, endM, endK, (endM - startM) * endK, currentM>(
841
+ A_arr, LDA, RHSInPacket, AInPacket);
842
+ }
843
+
844
+ /**
845
+ * endM: dimension of A. 1 <= endM <= EIGEN_AVX_MAX_NUM_ROW
846
+ * numK: number of avx registers to use for each row of B (ex fp32: 48 rhs => 3 avx reg used). 1 <= endK <= 3.
847
+ * isFWDSolve: true => forward substitution, false => backwards substitution
848
+ * isUnitDiag: true => triangular matrix has unit diagonal.
849
+ */
850
+ template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t numK>
851
+ static EIGEN_ALWAYS_INLINE void triSolveMicroKernel(Scalar *A_arr, int64_t LDA,
852
+ PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
853
+ PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
854
+ static_assert(numK >= 1 && numK <= 3, "numK out of range");
855
+ aux_triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, endM, endM, numK>(A_arr, LDA, RHSInPacket, AInPacket);
856
+ }
857
+ };
858
+
859
/**
 * Compile-time unrolls for the gemm micro kernel used by the TRSM implementation.
 *
 * isAdd: true => C += A*B, false => C -= A*B
 */
template <typename Scalar, bool isAdd>
class gemm {
 public:
  // Full-width packet type selected by Scalar (float or double).
  using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
  static constexpr int64_t PacketSize = packet_traits<Scalar>::size;

  /***********************************
   * Auxiliary Functions for:
   *  - setzero
   *  - updateC
   *  - storeC
   *  - startLoadB
   *  - triSolveMicroKernel
   ************************************/

  /**
   * aux_setzero
   *
   * Zeroes the endM x endN accumulator packets in zmm.
   *
   * 2-D compile-time unroll equivalent to:
   *   for(startM = 0; startM < endM; startM++)
   *     for(startN = 0; startN < endN; startN++)
   **/
  template <int64_t endM, int64_t endN, int64_t counter>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_setzero(
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    // Recover (startM, startN) from the recursion counter.
    constexpr int64_t counterReverse = endM * endN - counter;
    constexpr int64_t startM = counterReverse / (endN);
    constexpr int64_t startN = counterReverse % endN;

    zmm.packet[startN * endM + startM] = pzero(zmm.packet[startN * endM + startM]);
    aux_setzero<endM, endN, counter - 1>(zmm);
  }

  // Recursion terminator for aux_setzero.
  template <int64_t endM, int64_t endN, int64_t counter>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_setzero(
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    EIGEN_UNUSED_VARIABLE(zmm);
  }

  /**
   * aux_updateC
   *
   * Adds the current contents of C (loaded from memory) into the accumulators.
   * Masked loads/adds are used when rem is set (partial trailing packet of rem_ lanes).
   *
   * 2-D compile-time unroll equivalent to:
   *   for(startM = 0; startM < endM; startM++)
   *     for(startN = 0; startN < endN; startN++)
   **/
  template <int64_t endM, int64_t endN, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateC(
      Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    constexpr int64_t counterReverse = endM * endN - counter;
    constexpr int64_t startM = counterReverse / (endN);
    constexpr int64_t startN = counterReverse % endN;

    EIGEN_IF_CONSTEXPR(rem)
    zmm.packet[startN * endM + startM] =
        padd(ploadu<vec>(&C_arr[(startN)*LDC + startM * PacketSize], remMask<PacketSize>(rem_)),
             zmm.packet[startN * endM + startM], remMask<PacketSize>(rem_));
    else zmm.packet[startN * endM + startM] =
        padd(ploadu<vec>(&C_arr[(startN)*LDC + startM * PacketSize]), zmm.packet[startN * endM + startM]);
    aux_updateC<endM, endN, counter - 1, rem>(C_arr, LDC, zmm, rem_);
  }

  // Recursion terminator for aux_updateC.
  template <int64_t endM, int64_t endN, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateC(
      Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(C_arr);
    EIGEN_UNUSED_VARIABLE(LDC);
    EIGEN_UNUSED_VARIABLE(zmm);
    EIGEN_UNUSED_VARIABLE(rem_);
  }

  /**
   * aux_storeC
   *
   * Writes the endM x endN accumulator packets back to C.
   * Masked stores are used when rem is set.
   *
   * 2-D compile-time unroll equivalent to:
   *   for(startM = 0; startM < endM; startM++)
   *     for(startN = 0; startN < endN; startN++)
   **/
  template <int64_t endM, int64_t endN, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeC(
      Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    constexpr int64_t counterReverse = endM * endN - counter;
    constexpr int64_t startM = counterReverse / (endN);
    constexpr int64_t startN = counterReverse % endN;

    EIGEN_IF_CONSTEXPR(rem)
    pstoreu<Scalar>(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM],
                    remMask<PacketSize>(rem_));
    else pstoreu<Scalar>(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM]);
    aux_storeC<endM, endN, counter - 1, rem>(C_arr, LDC, zmm, rem_);
  }

  // Recursion terminator for aux_storeC.
  template <int64_t endM, int64_t endN, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeC(
      Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(C_arr);
    EIGEN_UNUSED_VARIABLE(LDC);
    EIGEN_UNUSED_VARIABLE(zmm);
    EIGEN_UNUSED_VARIABLE(rem_);
  }

  /**
   * aux_startLoadB
   *
   * Pre-loads the first endL packets of B into the registers that follow the
   * accumulators (zmm.packet[unrollM * unrollN + ...]).
   *
   * 1-D compile-time unroll equivalent to:
   *   for(startL = 0; startL < endL; startL++)
   **/
  template <int64_t unrollM, int64_t unrollN, int64_t endL, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startLoadB(
      Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    constexpr int64_t counterReverse = endL - counter;
    constexpr int64_t startL = counterReverse;

    EIGEN_IF_CONSTEXPR(rem)
    zmm.packet[unrollM * unrollN + startL] =
        ploadu<vec>(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize], remMask<PacketSize>(rem_));
    else zmm.packet[unrollM * unrollN + startL] =
        ploadu<vec>(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize]);

    aux_startLoadB<unrollM, unrollN, endL, counter - 1, rem>(B_t, LDB, zmm, rem_);
  }

  // Recursion terminator for aux_startLoadB.
  template <int64_t unrollM, int64_t unrollN, int64_t endL, int64_t counter, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startLoadB(
      Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_t);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(zmm);
    EIGEN_UNUSED_VARIABLE(rem_);
  }

  /**
   * aux_startBCastA
   *
   * Broadcasts the first endB elements of the current column of A into the
   * registers that follow the B-load registers
   * (zmm.packet[unrollM * unrollN + numLoad + ...]).
   *
   * 1-D compile-time unroll equivalent to:
   *   for(startB = 0; startB < endB; startB++)
   **/
  template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t counter, int64_t numLoad>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startBCastA(
      Scalar *A_t, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    constexpr int64_t counterReverse = endB - counter;
    constexpr int64_t startB = counterReverse;

    zmm.packet[unrollM * unrollN + numLoad + startB] = pload1<vec>(&A_t[idA<isARowMajor>(startB, 0, LDA)]);

    aux_startBCastA<isARowMajor, unrollM, unrollN, endB, counter - 1, numLoad>(A_t, LDA, zmm);
  }

  // Recursion terminator for aux_startBCastA.
  template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t counter, int64_t numLoad>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startBCastA(
      Scalar *A_t, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    EIGEN_UNUSED_VARIABLE(A_t);
    EIGEN_UNUSED_VARIABLE(LDA);
    EIGEN_UNUSED_VARIABLE(zmm);
  }

  /**
   * aux_loadB
   * currK: current K
   *
   * Loads the next row of B (for iteration numLoad/endM + currK) into the
   * rotating set of numLoad B registers; skipped once that row would fall
   * past the unrollK depth.
   *
   * 1-D compile-time unroll equivalent to:
   *   for(startM = 0; startM < endM; startM++)
   **/
  template <int64_t endM, int64_t counter, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad,
            int64_t numBCast, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB(
      Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    // NOTE: compile-time condition evaluated at runtime; the branch is trivially
    // foldable by the compiler.
    if ((numLoad / endM + currK < unrollK)) {
      constexpr int64_t counterReverse = endM - counter;
      constexpr int64_t startM = counterReverse;

      EIGEN_IF_CONSTEXPR(rem) {
        zmm.packet[endM * unrollN + (startM + currK * endM) % numLoad] =
            ploadu<vec>(&B_t[(numLoad / endM + currK) * LDB + startM * PacketSize], remMask<PacketSize>(rem_));
      }
      else {
        zmm.packet[endM * unrollN + (startM + currK * endM) % numLoad] =
            ploadu<vec>(&B_t[(numLoad / endM + currK) * LDB + startM * PacketSize]);
      }

      aux_loadB<endM, counter - 1, unrollN, currK, unrollK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
    }
  }

  // Recursion terminator for aux_loadB.
  template <int64_t endM, int64_t counter, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad,
            int64_t numBCast, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB(
      Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_t);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(zmm);
    EIGEN_UNUSED_VARIABLE(rem_);
  }

  /**
   * aux_microKernel
   *
   * One fully-unrolled gemm micro kernel step: interleaves FMAs into the
   * endM x endN accumulators with broadcasts of A and rotating loads of B.
   *
   * 3-D compile-time unroll equivalent to:
   *   for(startM = 0; startM < endM; startM++)
   *     for(startN = 0; startN < endN; startN++)
   *       for(startK = 0; startK < endK; startK++)
   **/
  template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t counter, int64_t numLoad,
            int64_t numBCast, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_microKernel(
      Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
      int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    // Recover (startK, startN, startM) from the recursion counter.
    constexpr int64_t counterReverse = endM * endN * endK - counter;
    constexpr int startK = counterReverse / (endM * endN);
    constexpr int startN = (counterReverse / (endM)) % endN;
    constexpr int startM = counterReverse % endM;

    // Very first iteration: prime the B-load and A-broadcast registers.
    EIGEN_IF_CONSTEXPR(startK == 0 && startM == 0 && startN == 0) {
      gemm::template startLoadB<endM, endN, numLoad, rem>(B_t, LDB, zmm, rem_);
      gemm::template startBCastA<isARowMajor, endM, endN, numBCast, numLoad>(A_t, LDA, zmm);
    }

    {
      // Interleave FMA and Bcast
      EIGEN_IF_CONSTEXPR(isAdd) {
        zmm.packet[startN * endM + startM] =
            pmadd(zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast],
                  zmm.packet[endM * endN + (startM + startK * endM) % numLoad], zmm.packet[startN * endM + startM]);
      }
      else {
        zmm.packet[startN * endM + startM] =
            pnmadd(zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast],
                   zmm.packet[endM * endN + (startM + startK * endM) % numLoad], zmm.packet[startN * endM + startM]);
      }
      // Bcast: refill the broadcast register just consumed, if more of A remains.
      EIGEN_IF_CONSTEXPR(startM == endM - 1 && (numBCast + startN + startK * endN < endK * endN)) {
        zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast] = pload1<vec>(&A_t[idA<isARowMajor>(
            (numBCast + startN + startK * endN) % endN, (numBCast + startN + startK * endN) / endN, LDA)]);
      }
    }

    // We have updated all accumulators, time to load next set of B's
    EIGEN_IF_CONSTEXPR((startN == endN - 1) && (startM == endM - 1)) {
      gemm::template loadB<endM, endN, startK, endK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
    }
    aux_microKernel<isARowMajor, endM, endN, endK, counter - 1, numLoad, numBCast, rem>(B_t, A_t, LDB, LDA, zmm, rem_);
  }

  // Recursion terminator for aux_microKernel.
  template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t counter, int64_t numLoad,
            int64_t numBCast, bool rem>
  static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_microKernel(
      Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
      int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(B_t);
    EIGEN_UNUSED_VARIABLE(A_t);
    EIGEN_UNUSED_VARIABLE(LDB);
    EIGEN_UNUSED_VARIABLE(LDA);
    EIGEN_UNUSED_VARIABLE(zmm);
    EIGEN_UNUSED_VARIABLE(rem_);
  }

  /********************************************************
   * Wrappers for aux_XXXX to hide counter parameter
   ********************************************************/

  // Zero the endM x endN accumulator packets.
  template <int64_t endM, int64_t endN>
  static EIGEN_ALWAYS_INLINE void setzero(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    aux_setzero<endM, endN, endM * endN>(zmm);
  }

  /**
   * Add the current contents of C into the accumulators.
   * Ideally the compiler folds these into vaddp{s,d} with an embedded memory load.
   */
  template <int64_t endM, int64_t endN, bool rem = false>
  static EIGEN_ALWAYS_INLINE void updateC(Scalar *C_arr, int64_t LDC,
                                          PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                          int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    aux_updateC<endM, endN, endM * endN, rem>(C_arr, LDC, zmm, rem_);
  }

  // Store the accumulators back to C (masked when rem is set).
  template <int64_t endM, int64_t endN, bool rem = false>
  static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC,
                                         PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                         int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    aux_storeC<endM, endN, endM * endN, rem>(C_arr, LDC, zmm, rem_);
  }

  /**
   * Use numLoad registers for loading B at start of microKernel
   */
  template <int64_t unrollM, int64_t unrollN, int64_t endL, bool rem>
  static EIGEN_ALWAYS_INLINE void startLoadB(Scalar *B_t, int64_t LDB,
                                             PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                             int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    aux_startLoadB<unrollM, unrollN, endL, endL, rem>(B_t, LDB, zmm, rem_);
  }

  /**
   * Use numBCast registers for broadcasting A at start of microKernel
   */
  template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t numLoad>
  static EIGEN_ALWAYS_INLINE void startBCastA(Scalar *A_t, int64_t LDA,
                                              PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
    aux_startBCastA<isARowMajor, unrollM, unrollN, endB, endB, numLoad>(A_t, LDA, zmm);
  }

  /**
   * Loads next set of B into vector registers between each K unroll.
   */
  template <int64_t endM, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad, int64_t numBCast, bool rem>
  static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_t, int64_t LDB,
                                        PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                        int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    aux_loadB<endM, endM, unrollN, currK, unrollK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
  }

  /**
   * Generates a microkernel for gemm (row-major) with unrolls {1,2,4,8}x{U1,U2,U3}
   * to compute C += A*B or C -= A*B (selected by the isAdd class parameter).
   * A matrix can be row/col-major. B matrix is assumed row-major.
   *
   * isARowMajor: is A row major
   * endM: Number registers per row
   * endN: Number of rows
   * endK: Loop unroll for K.
   * numLoad: Number of registers for loading B.
   * numBCast: Number of registers for broadcasting A.
   *
   * Ex: endM=3, endN=8, endK=4, numLoad=6, numBCast=2:
   * 8 x (3*PacketSize) unroll (24 accumulators), k unrolled 4 times,
   * 6 registers for loading B, 2 for broadcasting A.
   * (The upstream comment's template-argument example was garbled; the
   *  parameter meanings above are authoritative.)
   *
   * Note: Ideally the microkernel should not have any register spilling.
   * The avx instruction counts should be:
   *  - endK*endN vbroadcasts{s,d}
   *  - endK*endM vmovup{s,d}
   *  - endK*endN*endM FMAs
   *
   * From testing, there are no register spills with clang. There are register spills with GNU, which
   * causes a performance hit.
   */
  template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t numLoad, int64_t numBCast,
            bool rem = false>
  static EIGEN_ALWAYS_INLINE void microKernel(Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA,
                                              PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
                                              int64_t rem_ = 0) {
    EIGEN_UNUSED_VARIABLE(rem_);
    aux_microKernel<isARowMajor, endM, endN, endK, endM * endN * endK, numLoad, numBCast, rem>(B_t, A_t, LDB, LDA, zmm,
                                                                                              rem_);
  }
};
1217
+ } // namespace unrolls
1218
+
1219
+ #endif // EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H