@smake/eigen 1.0.2 → 1.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (435)
  1. package/README.md +1 -1
  2. package/eigen/Eigen/AccelerateSupport +52 -0
  3. package/eigen/Eigen/Cholesky +18 -21
  4. package/eigen/Eigen/CholmodSupport +28 -28
  5. package/eigen/Eigen/Core +235 -326
  6. package/eigen/Eigen/Eigenvalues +16 -14
  7. package/eigen/Eigen/Geometry +21 -24
  8. package/eigen/Eigen/Householder +9 -8
  9. package/eigen/Eigen/IterativeLinearSolvers +8 -4
  10. package/eigen/Eigen/Jacobi +14 -14
  11. package/eigen/Eigen/KLUSupport +43 -0
  12. package/eigen/Eigen/LU +16 -20
  13. package/eigen/Eigen/MetisSupport +12 -12
  14. package/eigen/Eigen/OrderingMethods +54 -54
  15. package/eigen/Eigen/PaStiXSupport +23 -20
  16. package/eigen/Eigen/PardisoSupport +17 -14
  17. package/eigen/Eigen/QR +18 -21
  18. package/eigen/Eigen/QtAlignedMalloc +5 -13
  19. package/eigen/Eigen/SPQRSupport +21 -14
  20. package/eigen/Eigen/SVD +23 -18
  21. package/eigen/Eigen/Sparse +1 -4
  22. package/eigen/Eigen/SparseCholesky +18 -23
  23. package/eigen/Eigen/SparseCore +18 -17
  24. package/eigen/Eigen/SparseLU +12 -8
  25. package/eigen/Eigen/SparseQR +16 -14
  26. package/eigen/Eigen/StdDeque +5 -2
  27. package/eigen/Eigen/StdList +5 -2
  28. package/eigen/Eigen/StdVector +5 -2
  29. package/eigen/Eigen/SuperLUSupport +30 -24
  30. package/eigen/Eigen/ThreadPool +80 -0
  31. package/eigen/Eigen/UmfPackSupport +19 -17
  32. package/eigen/Eigen/Version +14 -0
  33. package/eigen/Eigen/src/AccelerateSupport/AccelerateSupport.h +423 -0
  34. package/eigen/Eigen/src/AccelerateSupport/InternalHeaderCheck.h +3 -0
  35. package/eigen/Eigen/src/Cholesky/InternalHeaderCheck.h +3 -0
  36. package/eigen/Eigen/src/Cholesky/LDLT.h +377 -401
  37. package/eigen/Eigen/src/Cholesky/LLT.h +332 -360
  38. package/eigen/Eigen/src/Cholesky/LLT_LAPACKE.h +81 -56
  39. package/eigen/Eigen/src/CholmodSupport/CholmodSupport.h +620 -521
  40. package/eigen/Eigen/src/CholmodSupport/InternalHeaderCheck.h +3 -0
  41. package/eigen/Eigen/src/Core/ArithmeticSequence.h +239 -0
  42. package/eigen/Eigen/src/Core/Array.h +341 -294
  43. package/eigen/Eigen/src/Core/ArrayBase.h +190 -203
  44. package/eigen/Eigen/src/Core/ArrayWrapper.h +127 -171
  45. package/eigen/Eigen/src/Core/Assign.h +30 -40
  46. package/eigen/Eigen/src/Core/AssignEvaluator.h +711 -589
  47. package/eigen/Eigen/src/Core/Assign_MKL.h +130 -125
  48. package/eigen/Eigen/src/Core/BandMatrix.h +268 -283
  49. package/eigen/Eigen/src/Core/Block.h +375 -398
  50. package/eigen/Eigen/src/Core/CommaInitializer.h +86 -97
  51. package/eigen/Eigen/src/Core/ConditionEstimator.h +51 -53
  52. package/eigen/Eigen/src/Core/CoreEvaluators.h +1356 -1026
  53. package/eigen/Eigen/src/Core/CoreIterators.h +73 -59
  54. package/eigen/Eigen/src/Core/CwiseBinaryOp.h +114 -132
  55. package/eigen/Eigen/src/Core/CwiseNullaryOp.h +726 -617
  56. package/eigen/Eigen/src/Core/CwiseTernaryOp.h +77 -103
  57. package/eigen/Eigen/src/Core/CwiseUnaryOp.h +56 -68
  58. package/eigen/Eigen/src/Core/CwiseUnaryView.h +132 -95
  59. package/eigen/Eigen/src/Core/DenseBase.h +632 -571
  60. package/eigen/Eigen/src/Core/DenseCoeffsBase.h +511 -624
  61. package/eigen/Eigen/src/Core/DenseStorage.h +512 -509
  62. package/eigen/Eigen/src/Core/DeviceWrapper.h +153 -0
  63. package/eigen/Eigen/src/Core/Diagonal.h +169 -210
  64. package/eigen/Eigen/src/Core/DiagonalMatrix.h +351 -274
  65. package/eigen/Eigen/src/Core/DiagonalProduct.h +12 -10
  66. package/eigen/Eigen/src/Core/Dot.h +172 -222
  67. package/eigen/Eigen/src/Core/EigenBase.h +75 -85
  68. package/eigen/Eigen/src/Core/Fill.h +138 -0
  69. package/eigen/Eigen/src/Core/FindCoeff.h +464 -0
  70. package/eigen/Eigen/src/Core/ForceAlignedAccess.h +90 -109
  71. package/eigen/Eigen/src/Core/Fuzzy.h +82 -105
  72. package/eigen/Eigen/src/Core/GeneralProduct.h +327 -263
  73. package/eigen/Eigen/src/Core/GenericPacketMath.h +1472 -360
  74. package/eigen/Eigen/src/Core/GlobalFunctions.h +194 -151
  75. package/eigen/Eigen/src/Core/IO.h +147 -139
  76. package/eigen/Eigen/src/Core/IndexedView.h +321 -0
  77. package/eigen/Eigen/src/Core/InnerProduct.h +260 -0
  78. package/eigen/Eigen/src/Core/InternalHeaderCheck.h +3 -0
  79. package/eigen/Eigen/src/Core/Inverse.h +56 -66
  80. package/eigen/Eigen/src/Core/Map.h +124 -142
  81. package/eigen/Eigen/src/Core/MapBase.h +256 -281
  82. package/eigen/Eigen/src/Core/MathFunctions.h +1620 -938
  83. package/eigen/Eigen/src/Core/MathFunctionsImpl.h +233 -71
  84. package/eigen/Eigen/src/Core/Matrix.h +491 -416
  85. package/eigen/Eigen/src/Core/MatrixBase.h +468 -453
  86. package/eigen/Eigen/src/Core/NestByValue.h +66 -85
  87. package/eigen/Eigen/src/Core/NoAlias.h +79 -85
  88. package/eigen/Eigen/src/Core/NumTraits.h +235 -148
  89. package/eigen/Eigen/src/Core/PartialReduxEvaluator.h +253 -0
  90. package/eigen/Eigen/src/Core/PermutationMatrix.h +461 -511
  91. package/eigen/Eigen/src/Core/PlainObjectBase.h +871 -894
  92. package/eigen/Eigen/src/Core/Product.h +260 -139
  93. package/eigen/Eigen/src/Core/ProductEvaluators.h +863 -714
  94. package/eigen/Eigen/src/Core/Random.h +161 -136
  95. package/eigen/Eigen/src/Core/RandomImpl.h +262 -0
  96. package/eigen/Eigen/src/Core/RealView.h +250 -0
  97. package/eigen/Eigen/src/Core/Redux.h +366 -336
  98. package/eigen/Eigen/src/Core/Ref.h +308 -209
  99. package/eigen/Eigen/src/Core/Replicate.h +94 -106
  100. package/eigen/Eigen/src/Core/Reshaped.h +398 -0
  101. package/eigen/Eigen/src/Core/ReturnByValue.h +49 -55
  102. package/eigen/Eigen/src/Core/Reverse.h +136 -145
  103. package/eigen/Eigen/src/Core/Select.h +70 -140
  104. package/eigen/Eigen/src/Core/SelfAdjointView.h +262 -285
  105. package/eigen/Eigen/src/Core/SelfCwiseBinaryOp.h +23 -20
  106. package/eigen/Eigen/src/Core/SkewSymmetricMatrix3.h +382 -0
  107. package/eigen/Eigen/src/Core/Solve.h +97 -111
  108. package/eigen/Eigen/src/Core/SolveTriangular.h +131 -129
  109. package/eigen/Eigen/src/Core/SolverBase.h +138 -101
  110. package/eigen/Eigen/src/Core/StableNorm.h +156 -160
  111. package/eigen/Eigen/src/Core/StlIterators.h +619 -0
  112. package/eigen/Eigen/src/Core/Stride.h +91 -88
  113. package/eigen/Eigen/src/Core/Swap.h +70 -38
  114. package/eigen/Eigen/src/Core/Transpose.h +295 -273
  115. package/eigen/Eigen/src/Core/Transpositions.h +272 -317
  116. package/eigen/Eigen/src/Core/TriangularMatrix.h +670 -755
  117. package/eigen/Eigen/src/Core/VectorBlock.h +59 -72
  118. package/eigen/Eigen/src/Core/VectorwiseOp.h +668 -630
  119. package/eigen/Eigen/src/Core/Visitor.h +480 -216
  120. package/eigen/Eigen/src/Core/arch/AVX/Complex.h +407 -293
  121. package/eigen/Eigen/src/Core/arch/AVX/MathFunctions.h +79 -388
  122. package/eigen/Eigen/src/Core/arch/AVX/PacketMath.h +2935 -491
  123. package/eigen/Eigen/src/Core/arch/AVX/Reductions.h +353 -0
  124. package/eigen/Eigen/src/Core/arch/AVX/TypeCasting.h +279 -22
  125. package/eigen/Eigen/src/Core/arch/AVX512/Complex.h +472 -0
  126. package/eigen/Eigen/src/Core/arch/AVX512/GemmKernel.h +1245 -0
  127. package/eigen/Eigen/src/Core/arch/AVX512/MathFunctions.h +85 -333
  128. package/eigen/Eigen/src/Core/arch/AVX512/MathFunctionsFP16.h +75 -0
  129. package/eigen/Eigen/src/Core/arch/AVX512/PacketMath.h +2490 -649
  130. package/eigen/Eigen/src/Core/arch/AVX512/PacketMathFP16.h +1413 -0
  131. package/eigen/Eigen/src/Core/arch/AVX512/Reductions.h +297 -0
  132. package/eigen/Eigen/src/Core/arch/AVX512/TrsmKernel.h +1167 -0
  133. package/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc +1219 -0
  134. package/eigen/Eigen/src/Core/arch/AVX512/TypeCasting.h +277 -0
  135. package/eigen/Eigen/src/Core/arch/AVX512/TypeCastingFP16.h +130 -0
  136. package/eigen/Eigen/src/Core/arch/AltiVec/Complex.h +521 -298
  137. package/eigen/Eigen/src/Core/arch/AltiVec/MathFunctions.h +39 -280
  138. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProduct.h +3686 -0
  139. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h +205 -0
  140. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h +901 -0
  141. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h +742 -0
  142. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.inc +2818 -0
  143. package/eigen/Eigen/src/Core/arch/AltiVec/PacketMath.h +3391 -723
  144. package/eigen/Eigen/src/Core/arch/AltiVec/TypeCasting.h +153 -0
  145. package/eigen/Eigen/src/Core/arch/Default/BFloat16.h +866 -0
  146. package/eigen/Eigen/src/Core/arch/Default/ConjHelper.h +113 -14
  147. package/eigen/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +2634 -0
  148. package/eigen/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h +227 -0
  149. package/eigen/Eigen/src/Core/arch/Default/Half.h +1091 -0
  150. package/eigen/Eigen/src/Core/arch/Default/Settings.h +11 -13
  151. package/eigen/Eigen/src/Core/arch/GPU/Complex.h +244 -0
  152. package/eigen/Eigen/src/Core/arch/GPU/MathFunctions.h +104 -0
  153. package/eigen/Eigen/src/Core/arch/GPU/PacketMath.h +1712 -0
  154. package/eigen/Eigen/src/Core/arch/GPU/Tuple.h +268 -0
  155. package/eigen/Eigen/src/Core/arch/GPU/TypeCasting.h +77 -0
  156. package/eigen/Eigen/src/Core/arch/HIP/hcc/math_constants.h +23 -0
  157. package/eigen/Eigen/src/Core/arch/HVX/PacketMath.h +1088 -0
  158. package/eigen/Eigen/src/Core/arch/LSX/Complex.h +520 -0
  159. package/eigen/Eigen/src/Core/arch/LSX/GeneralBlockPanelKernel.h +23 -0
  160. package/eigen/Eigen/src/Core/arch/LSX/MathFunctions.h +43 -0
  161. package/eigen/Eigen/src/Core/arch/LSX/PacketMath.h +2866 -0
  162. package/eigen/Eigen/src/Core/arch/LSX/TypeCasting.h +526 -0
  163. package/eigen/Eigen/src/Core/arch/MSA/Complex.h +620 -0
  164. package/eigen/Eigen/src/Core/arch/MSA/MathFunctions.h +379 -0
  165. package/eigen/Eigen/src/Core/arch/MSA/PacketMath.h +1237 -0
  166. package/eigen/Eigen/src/Core/arch/NEON/Complex.h +531 -289
  167. package/eigen/Eigen/src/Core/arch/NEON/GeneralBlockPanelKernel.h +243 -0
  168. package/eigen/Eigen/src/Core/arch/NEON/MathFunctions.h +50 -73
  169. package/eigen/Eigen/src/Core/arch/NEON/PacketMath.h +5915 -579
  170. package/eigen/Eigen/src/Core/arch/NEON/TypeCasting.h +1642 -0
  171. package/eigen/Eigen/src/Core/arch/NEON/UnaryFunctors.h +57 -0
  172. package/eigen/Eigen/src/Core/arch/SSE/Complex.h +366 -334
  173. package/eigen/Eigen/src/Core/arch/SSE/MathFunctions.h +40 -514
  174. package/eigen/Eigen/src/Core/arch/SSE/PacketMath.h +2164 -675
  175. package/eigen/Eigen/src/Core/arch/SSE/Reductions.h +324 -0
  176. package/eigen/Eigen/src/Core/arch/SSE/TypeCasting.h +188 -35
  177. package/eigen/Eigen/src/Core/arch/SVE/MathFunctions.h +48 -0
  178. package/eigen/Eigen/src/Core/arch/SVE/PacketMath.h +674 -0
  179. package/eigen/Eigen/src/Core/arch/SVE/TypeCasting.h +52 -0
  180. package/eigen/Eigen/src/Core/arch/SYCL/InteropHeaders.h +227 -0
  181. package/eigen/Eigen/src/Core/arch/SYCL/MathFunctions.h +303 -0
  182. package/eigen/Eigen/src/Core/arch/SYCL/PacketMath.h +576 -0
  183. package/eigen/Eigen/src/Core/arch/SYCL/TypeCasting.h +83 -0
  184. package/eigen/Eigen/src/Core/arch/ZVector/Complex.h +434 -261
  185. package/eigen/Eigen/src/Core/arch/ZVector/MathFunctions.h +160 -53
  186. package/eigen/Eigen/src/Core/arch/ZVector/PacketMath.h +1073 -605
  187. package/eigen/Eigen/src/Core/functors/AssignmentFunctors.h +123 -117
  188. package/eigen/Eigen/src/Core/functors/BinaryFunctors.h +594 -322
  189. package/eigen/Eigen/src/Core/functors/NullaryFunctors.h +204 -118
  190. package/eigen/Eigen/src/Core/functors/StlFunctors.h +110 -97
  191. package/eigen/Eigen/src/Core/functors/TernaryFunctors.h +34 -7
  192. package/eigen/Eigen/src/Core/functors/UnaryFunctors.h +1158 -530
  193. package/eigen/Eigen/src/Core/products/GeneralBlockPanelKernel.h +2329 -1333
  194. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrix.h +328 -364
  195. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h +191 -178
  196. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h +85 -82
  197. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h +154 -73
  198. package/eigen/Eigen/src/Core/products/GeneralMatrixVector.h +396 -542
  199. package/eigen/Eigen/src/Core/products/GeneralMatrixVector_BLAS.h +80 -77
  200. package/eigen/Eigen/src/Core/products/Parallelizer.h +208 -92
  201. package/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix.h +331 -375
  202. package/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix_BLAS.h +206 -224
  203. package/eigen/Eigen/src/Core/products/SelfadjointMatrixVector.h +139 -146
  204. package/eigen/Eigen/src/Core/products/SelfadjointMatrixVector_BLAS.h +58 -61
  205. package/eigen/Eigen/src/Core/products/SelfadjointProduct.h +71 -71
  206. package/eigen/Eigen/src/Core/products/SelfadjointRank2Update.h +48 -46
  207. package/eigen/Eigen/src/Core/products/TriangularMatrixMatrix.h +294 -369
  208. package/eigen/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h +246 -238
  209. package/eigen/Eigen/src/Core/products/TriangularMatrixVector.h +244 -247
  210. package/eigen/Eigen/src/Core/products/TriangularMatrixVector_BLAS.h +212 -192
  211. package/eigen/Eigen/src/Core/products/TriangularSolverMatrix.h +328 -275
  212. package/eigen/Eigen/src/Core/products/TriangularSolverMatrix_BLAS.h +108 -109
  213. package/eigen/Eigen/src/Core/products/TriangularSolverVector.h +70 -93
  214. package/eigen/Eigen/src/Core/util/Assert.h +158 -0
  215. package/eigen/Eigen/src/Core/util/BlasUtil.h +413 -290
  216. package/eigen/Eigen/src/Core/util/ConfigureVectorization.h +543 -0
  217. package/eigen/Eigen/src/Core/util/Constants.h +314 -263
  218. package/eigen/Eigen/src/Core/util/DisableStupidWarnings.h +130 -78
  219. package/eigen/Eigen/src/Core/util/EmulateArray.h +270 -0
  220. package/eigen/Eigen/src/Core/util/ForwardDeclarations.h +450 -224
  221. package/eigen/Eigen/src/Core/util/GpuHipCudaDefines.inc +101 -0
  222. package/eigen/Eigen/src/Core/util/GpuHipCudaUndefines.inc +45 -0
  223. package/eigen/Eigen/src/Core/util/IndexedViewHelper.h +487 -0
  224. package/eigen/Eigen/src/Core/util/IntegralConstant.h +279 -0
  225. package/eigen/Eigen/src/Core/util/MKL_support.h +39 -30
  226. package/eigen/Eigen/src/Core/util/Macros.h +939 -646
  227. package/eigen/Eigen/src/Core/util/MaxSizeVector.h +139 -0
  228. package/eigen/Eigen/src/Core/util/Memory.h +1042 -650
  229. package/eigen/Eigen/src/Core/util/Meta.h +618 -426
  230. package/eigen/Eigen/src/Core/util/MoreMeta.h +638 -0
  231. package/eigen/Eigen/src/Core/util/ReenableStupidWarnings.h +32 -19
  232. package/eigen/Eigen/src/Core/util/ReshapedHelper.h +51 -0
  233. package/eigen/Eigen/src/Core/util/Serializer.h +209 -0
  234. package/eigen/Eigen/src/Core/util/StaticAssert.h +51 -164
  235. package/eigen/Eigen/src/Core/util/SymbolicIndex.h +445 -0
  236. package/eigen/Eigen/src/Core/util/XprHelper.h +793 -538
  237. package/eigen/Eigen/src/Eigenvalues/ComplexEigenSolver.h +246 -277
  238. package/eigen/Eigen/src/Eigenvalues/ComplexSchur.h +299 -319
  239. package/eigen/Eigen/src/Eigenvalues/ComplexSchur_LAPACKE.h +52 -48
  240. package/eigen/Eigen/src/Eigenvalues/EigenSolver.h +413 -456
  241. package/eigen/Eigen/src/Eigenvalues/GeneralizedEigenSolver.h +309 -325
  242. package/eigen/Eigen/src/Eigenvalues/GeneralizedSelfAdjointEigenSolver.h +157 -171
  243. package/eigen/Eigen/src/Eigenvalues/HessenbergDecomposition.h +292 -310
  244. package/eigen/Eigen/src/Eigenvalues/InternalHeaderCheck.h +3 -0
  245. package/eigen/Eigen/src/Eigenvalues/MatrixBaseEigenvalues.h +91 -107
  246. package/eigen/Eigen/src/Eigenvalues/RealQZ.h +539 -606
  247. package/eigen/Eigen/src/Eigenvalues/RealSchur.h +348 -382
  248. package/eigen/Eigen/src/Eigenvalues/RealSchur_LAPACKE.h +41 -35
  249. package/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver.h +579 -600
  250. package/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver_LAPACKE.h +47 -44
  251. package/eigen/Eigen/src/Eigenvalues/Tridiagonalization.h +434 -461
  252. package/eigen/Eigen/src/Geometry/AlignedBox.h +307 -214
  253. package/eigen/Eigen/src/Geometry/AngleAxis.h +135 -137
  254. package/eigen/Eigen/src/Geometry/EulerAngles.h +163 -74
  255. package/eigen/Eigen/src/Geometry/Homogeneous.h +289 -333
  256. package/eigen/Eigen/src/Geometry/Hyperplane.h +152 -161
  257. package/eigen/Eigen/src/Geometry/InternalHeaderCheck.h +3 -0
  258. package/eigen/Eigen/src/Geometry/OrthoMethods.h +168 -145
  259. package/eigen/Eigen/src/Geometry/ParametrizedLine.h +141 -104
  260. package/eigen/Eigen/src/Geometry/Quaternion.h +595 -497
  261. package/eigen/Eigen/src/Geometry/Rotation2D.h +110 -108
  262. package/eigen/Eigen/src/Geometry/RotationBase.h +148 -145
  263. package/eigen/Eigen/src/Geometry/Scaling.h +115 -90
  264. package/eigen/Eigen/src/Geometry/Transform.h +896 -953
  265. package/eigen/Eigen/src/Geometry/Translation.h +100 -98
  266. package/eigen/Eigen/src/Geometry/Umeyama.h +79 -84
  267. package/eigen/Eigen/src/Geometry/arch/Geometry_SIMD.h +154 -0
  268. package/eigen/Eigen/src/Householder/BlockHouseholder.h +54 -42
  269. package/eigen/Eigen/src/Householder/Householder.h +104 -122
  270. package/eigen/Eigen/src/Householder/HouseholderSequence.h +416 -382
  271. package/eigen/Eigen/src/Householder/InternalHeaderCheck.h +3 -0
  272. package/eigen/Eigen/src/IterativeLinearSolvers/BasicPreconditioners.h +153 -166
  273. package/eigen/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h +127 -138
  274. package/eigen/Eigen/src/IterativeLinearSolvers/ConjugateGradient.h +95 -124
  275. package/eigen/Eigen/src/IterativeLinearSolvers/IncompleteCholesky.h +269 -267
  276. package/eigen/Eigen/src/IterativeLinearSolvers/IncompleteLUT.h +246 -259
  277. package/eigen/Eigen/src/IterativeLinearSolvers/InternalHeaderCheck.h +3 -0
  278. package/eigen/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h +218 -217
  279. package/eigen/Eigen/src/IterativeLinearSolvers/LeastSquareConjugateGradient.h +80 -103
  280. package/eigen/Eigen/src/IterativeLinearSolvers/SolveWithGuess.h +59 -63
  281. package/eigen/Eigen/src/Jacobi/InternalHeaderCheck.h +3 -0
  282. package/eigen/Eigen/src/Jacobi/Jacobi.h +256 -291
  283. package/eigen/Eigen/src/KLUSupport/InternalHeaderCheck.h +3 -0
  284. package/eigen/Eigen/src/KLUSupport/KLUSupport.h +339 -0
  285. package/eigen/Eigen/src/LU/Determinant.h +60 -63
  286. package/eigen/Eigen/src/LU/FullPivLU.h +561 -626
  287. package/eigen/Eigen/src/LU/InternalHeaderCheck.h +3 -0
  288. package/eigen/Eigen/src/LU/InverseImpl.h +213 -275
  289. package/eigen/Eigen/src/LU/PartialPivLU.h +407 -435
  290. package/eigen/Eigen/src/LU/PartialPivLU_LAPACKE.h +54 -40
  291. package/eigen/Eigen/src/LU/arch/InverseSize4.h +353 -0
  292. package/eigen/Eigen/src/MetisSupport/InternalHeaderCheck.h +3 -0
  293. package/eigen/Eigen/src/MetisSupport/MetisSupport.h +81 -93
  294. package/eigen/Eigen/src/OrderingMethods/Amd.h +250 -282
  295. package/eigen/Eigen/src/OrderingMethods/Eigen_Colamd.h +950 -1103
  296. package/eigen/Eigen/src/OrderingMethods/InternalHeaderCheck.h +3 -0
  297. package/eigen/Eigen/src/OrderingMethods/Ordering.h +111 -122
  298. package/eigen/Eigen/src/PaStiXSupport/InternalHeaderCheck.h +3 -0
  299. package/eigen/Eigen/src/PaStiXSupport/PaStiXSupport.h +524 -570
  300. package/eigen/Eigen/src/PardisoSupport/InternalHeaderCheck.h +3 -0
  301. package/eigen/Eigen/src/PardisoSupport/PardisoSupport.h +385 -429
  302. package/eigen/Eigen/src/QR/ColPivHouseholderQR.h +494 -473
  303. package/eigen/Eigen/src/QR/ColPivHouseholderQR_LAPACKE.h +120 -56
  304. package/eigen/Eigen/src/QR/CompleteOrthogonalDecomposition.h +223 -137
  305. package/eigen/Eigen/src/QR/FullPivHouseholderQR.h +517 -460
  306. package/eigen/Eigen/src/QR/HouseholderQR.h +412 -278
  307. package/eigen/Eigen/src/QR/HouseholderQR_LAPACKE.h +32 -23
  308. package/eigen/Eigen/src/QR/InternalHeaderCheck.h +3 -0
  309. package/eigen/Eigen/src/SPQRSupport/InternalHeaderCheck.h +3 -0
  310. package/eigen/Eigen/src/SPQRSupport/SuiteSparseQRSupport.h +263 -261
  311. package/eigen/Eigen/src/SVD/BDCSVD.h +872 -679
  312. package/eigen/Eigen/src/SVD/BDCSVD_LAPACKE.h +174 -0
  313. package/eigen/Eigen/src/SVD/InternalHeaderCheck.h +3 -0
  314. package/eigen/Eigen/src/SVD/JacobiSVD.h +585 -543
  315. package/eigen/Eigen/src/SVD/JacobiSVD_LAPACKE.h +85 -49
  316. package/eigen/Eigen/src/SVD/SVDBase.h +281 -160
  317. package/eigen/Eigen/src/SVD/UpperBidiagonalization.h +202 -237
  318. package/eigen/Eigen/src/SparseCholesky/InternalHeaderCheck.h +3 -0
  319. package/eigen/Eigen/src/SparseCholesky/SimplicialCholesky.h +769 -590
  320. package/eigen/Eigen/src/SparseCholesky/SimplicialCholesky_impl.h +318 -129
  321. package/eigen/Eigen/src/SparseCore/AmbiVector.h +202 -251
  322. package/eigen/Eigen/src/SparseCore/CompressedStorage.h +184 -236
  323. package/eigen/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h +140 -184
  324. package/eigen/Eigen/src/SparseCore/InternalHeaderCheck.h +3 -0
  325. package/eigen/Eigen/src/SparseCore/SparseAssign.h +174 -111
  326. package/eigen/Eigen/src/SparseCore/SparseBlock.h +408 -477
  327. package/eigen/Eigen/src/SparseCore/SparseColEtree.h +100 -112
  328. package/eigen/Eigen/src/SparseCore/SparseCompressedBase.h +531 -280
  329. package/eigen/Eigen/src/SparseCore/SparseCwiseBinaryOp.h +559 -347
  330. package/eigen/Eigen/src/SparseCore/SparseCwiseUnaryOp.h +100 -108
  331. package/eigen/Eigen/src/SparseCore/SparseDenseProduct.h +185 -191
  332. package/eigen/Eigen/src/SparseCore/SparseDiagonalProduct.h +71 -71
  333. package/eigen/Eigen/src/SparseCore/SparseDot.h +49 -47
  334. package/eigen/Eigen/src/SparseCore/SparseFuzzy.h +13 -11
  335. package/eigen/Eigen/src/SparseCore/SparseMap.h +243 -253
  336. package/eigen/Eigen/src/SparseCore/SparseMatrix.h +1614 -1142
  337. package/eigen/Eigen/src/SparseCore/SparseMatrixBase.h +403 -357
  338. package/eigen/Eigen/src/SparseCore/SparsePermutation.h +186 -115
  339. package/eigen/Eigen/src/SparseCore/SparseProduct.h +100 -91
  340. package/eigen/Eigen/src/SparseCore/SparseRedux.h +22 -24
  341. package/eigen/Eigen/src/SparseCore/SparseRef.h +268 -295
  342. package/eigen/Eigen/src/SparseCore/SparseSelfAdjointView.h +371 -414
  343. package/eigen/Eigen/src/SparseCore/SparseSolverBase.h +78 -87
  344. package/eigen/Eigen/src/SparseCore/SparseSparseProductWithPruning.h +81 -95
  345. package/eigen/Eigen/src/SparseCore/SparseTranspose.h +62 -71
  346. package/eigen/Eigen/src/SparseCore/SparseTriangularView.h +132 -144
  347. package/eigen/Eigen/src/SparseCore/SparseUtil.h +146 -115
  348. package/eigen/Eigen/src/SparseCore/SparseVector.h +426 -372
  349. package/eigen/Eigen/src/SparseCore/SparseView.h +164 -193
  350. package/eigen/Eigen/src/SparseCore/TriangularSolver.h +129 -170
  351. package/eigen/Eigen/src/SparseLU/InternalHeaderCheck.h +3 -0
  352. package/eigen/Eigen/src/SparseLU/SparseLU.h +814 -618
  353. package/eigen/Eigen/src/SparseLU/SparseLUImpl.h +61 -48
  354. package/eigen/Eigen/src/SparseLU/SparseLU_Memory.h +102 -118
  355. package/eigen/Eigen/src/SparseLU/SparseLU_Structs.h +38 -35
  356. package/eigen/Eigen/src/SparseLU/SparseLU_SupernodalMatrix.h +273 -255
  357. package/eigen/Eigen/src/SparseLU/SparseLU_Utils.h +44 -49
  358. package/eigen/Eigen/src/SparseLU/SparseLU_column_bmod.h +104 -108
  359. package/eigen/Eigen/src/SparseLU/SparseLU_column_dfs.h +90 -101
  360. package/eigen/Eigen/src/SparseLU/SparseLU_copy_to_ucol.h +57 -58
  361. package/eigen/Eigen/src/SparseLU/SparseLU_heap_relax_snode.h +43 -55
  362. package/eigen/Eigen/src/SparseLU/SparseLU_kernel_bmod.h +74 -71
  363. package/eigen/Eigen/src/SparseLU/SparseLU_panel_bmod.h +125 -133
  364. package/eigen/Eigen/src/SparseLU/SparseLU_panel_dfs.h +136 -159
  365. package/eigen/Eigen/src/SparseLU/SparseLU_pivotL.h +51 -52
  366. package/eigen/Eigen/src/SparseLU/SparseLU_pruneL.h +67 -73
  367. package/eigen/Eigen/src/SparseLU/SparseLU_relax_snode.h +24 -26
  368. package/eigen/Eigen/src/SparseQR/InternalHeaderCheck.h +3 -0
  369. package/eigen/Eigen/src/SparseQR/SparseQR.h +451 -490
  370. package/eigen/Eigen/src/StlSupport/StdDeque.h +28 -105
  371. package/eigen/Eigen/src/StlSupport/StdList.h +28 -84
  372. package/eigen/Eigen/src/StlSupport/StdVector.h +28 -108
  373. package/eigen/Eigen/src/StlSupport/details.h +48 -50
  374. package/eigen/Eigen/src/SuperLUSupport/InternalHeaderCheck.h +3 -0
  375. package/eigen/Eigen/src/SuperLUSupport/SuperLUSupport.h +634 -732
  376. package/eigen/Eigen/src/ThreadPool/Barrier.h +70 -0
  377. package/eigen/Eigen/src/ThreadPool/CoreThreadPoolDevice.h +336 -0
  378. package/eigen/Eigen/src/ThreadPool/EventCount.h +241 -0
  379. package/eigen/Eigen/src/ThreadPool/ForkJoin.h +140 -0
  380. package/eigen/Eigen/src/ThreadPool/InternalHeaderCheck.h +4 -0
  381. package/eigen/Eigen/src/ThreadPool/NonBlockingThreadPool.h +587 -0
  382. package/eigen/Eigen/src/ThreadPool/RunQueue.h +230 -0
  383. package/eigen/Eigen/src/ThreadPool/ThreadCancel.h +21 -0
  384. package/eigen/Eigen/src/ThreadPool/ThreadEnvironment.h +43 -0
  385. package/eigen/Eigen/src/ThreadPool/ThreadLocal.h +289 -0
  386. package/eigen/Eigen/src/ThreadPool/ThreadPoolInterface.h +50 -0
  387. package/eigen/Eigen/src/ThreadPool/ThreadYield.h +16 -0
  388. package/eigen/Eigen/src/UmfPackSupport/InternalHeaderCheck.h +3 -0
  389. package/eigen/Eigen/src/UmfPackSupport/UmfPackSupport.h +480 -380
  390. package/eigen/Eigen/src/misc/Image.h +41 -43
  391. package/eigen/Eigen/src/misc/InternalHeaderCheck.h +3 -0
  392. package/eigen/Eigen/src/misc/Kernel.h +39 -41
  393. package/eigen/Eigen/src/misc/RealSvd2x2.h +19 -21
  394. package/eigen/Eigen/src/misc/blas.h +83 -426
  395. package/eigen/Eigen/src/misc/lapacke.h +9976 -16182
  396. package/eigen/Eigen/src/misc/lapacke_helpers.h +163 -0
  397. package/eigen/Eigen/src/misc/lapacke_mangling.h +4 -5
  398. package/eigen/Eigen/src/plugins/ArrayCwiseBinaryOps.inc +344 -0
  399. package/eigen/Eigen/src/plugins/ArrayCwiseUnaryOps.inc +544 -0
  400. package/eigen/Eigen/src/plugins/BlockMethods.inc +1370 -0
  401. package/eigen/Eigen/src/plugins/CommonCwiseBinaryOps.inc +116 -0
  402. package/eigen/Eigen/src/plugins/CommonCwiseUnaryOps.inc +167 -0
  403. package/eigen/Eigen/src/plugins/IndexedViewMethods.inc +192 -0
  404. package/eigen/Eigen/src/plugins/InternalHeaderCheck.inc +3 -0
  405. package/eigen/Eigen/src/plugins/MatrixCwiseBinaryOps.inc +331 -0
  406. package/eigen/Eigen/src/plugins/MatrixCwiseUnaryOps.inc +118 -0
  407. package/eigen/Eigen/src/plugins/ReshapedMethods.inc +133 -0
  408. package/lib/LibEigen.d.ts +4 -0
  409. package/lib/LibEigen.js +14 -0
  410. package/lib/index.d.ts +1 -1
  411. package/lib/index.js +7 -3
  412. package/package.json +2 -10
  413. package/eigen/Eigen/CMakeLists.txt +0 -19
  414. package/eigen/Eigen/src/Core/BooleanRedux.h +0 -164
  415. package/eigen/Eigen/src/Core/arch/CUDA/Complex.h +0 -103
  416. package/eigen/Eigen/src/Core/arch/CUDA/Half.h +0 -675
  417. package/eigen/Eigen/src/Core/arch/CUDA/MathFunctions.h +0 -91
  418. package/eigen/Eigen/src/Core/arch/CUDA/PacketMath.h +0 -333
  419. package/eigen/Eigen/src/Core/arch/CUDA/PacketMathHalf.h +0 -1124
  420. package/eigen/Eigen/src/Core/arch/CUDA/TypeCasting.h +0 -212
  421. package/eigen/Eigen/src/Core/util/NonMPL2.h +0 -3
  422. package/eigen/Eigen/src/Geometry/arch/Geometry_SSE.h +0 -161
  423. package/eigen/Eigen/src/LU/arch/Inverse_SSE.h +0 -338
  424. package/eigen/Eigen/src/SparseCore/MappedSparseMatrix.h +0 -67
  425. package/eigen/Eigen/src/SparseLU/SparseLU_gemm_kernel.h +0 -280
  426. package/eigen/Eigen/src/misc/lapack.h +0 -152
  427. package/eigen/Eigen/src/plugins/ArrayCwiseBinaryOps.h +0 -332
  428. package/eigen/Eigen/src/plugins/ArrayCwiseUnaryOps.h +0 -552
  429. package/eigen/Eigen/src/plugins/BlockMethods.h +0 -1058
  430. package/eigen/Eigen/src/plugins/CommonCwiseBinaryOps.h +0 -115
  431. package/eigen/Eigen/src/plugins/CommonCwiseUnaryOps.h +0 -163
  432. package/eigen/Eigen/src/plugins/MatrixCwiseBinaryOps.h +0 -152
  433. package/eigen/Eigen/src/plugins/MatrixCwiseUnaryOps.h +0 -85
  434. package/lib/eigen.d.ts +0 -2
  435. package/lib/eigen.js +0 -15
@@ -0,0 +1,1167 @@
1
+ // This file is part of Eigen, a lightweight C++ template library
2
+ // for linear algebra.
3
+ //
4
+ // Copyright (C) 2022 Intel Corporation
5
+ //
6
+ // This Source Code Form is subject to the terms of the Mozilla
7
+ // Public License v. 2.0. If a copy of the MPL was not distributed
8
+ // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
+
10
+ #ifndef EIGEN_CORE_ARCH_AVX512_TRSM_KERNEL_H
11
+ #define EIGEN_CORE_ARCH_AVX512_TRSM_KERNEL_H
12
+
13
+ // IWYU pragma: private
14
+ #include "../../InternalHeaderCheck.h"
15
+
16
// Master switch for the AVX512 TRSM kernels; on by default.
#if !defined(EIGEN_USE_AVX512_TRSM_KERNELS)
#define EIGEN_USE_AVX512_TRSM_KERNELS 1
#endif

// TRSM kernels currently unconditionally rely on malloc with AVX512.
// Disable them if malloc is explicitly disabled at compile-time.
#ifdef EIGEN_NO_MALLOC
#undef EIGEN_USE_AVX512_TRSM_KERNELS
#define EIGEN_USE_AVX512_TRSM_KERNELS 0
#endif

// Per-side switches (right / left triangular solves). They default to on when
// the master switch is on, and are forced off when it is off. A user may have
// pre-defined a per-side switch (e.g. to 1) while the master switch ended up 0
// (for instance because EIGEN_NO_MALLOC is set), so #undef before forcing the
// value to avoid a "macro redefined" diagnostic.
#if EIGEN_USE_AVX512_TRSM_KERNELS
#if !defined(EIGEN_USE_AVX512_TRSM_R_KERNELS)
#define EIGEN_USE_AVX512_TRSM_R_KERNELS 1
#endif
#if !defined(EIGEN_USE_AVX512_TRSM_L_KERNELS)
#define EIGEN_USE_AVX512_TRSM_L_KERNELS 1
#endif
#else  // EIGEN_USE_AVX512_TRSM_KERNELS == 0
#undef EIGEN_USE_AVX512_TRSM_R_KERNELS
#undef EIGEN_USE_AVX512_TRSM_L_KERNELS
#define EIGEN_USE_AVX512_TRSM_R_KERNELS 0
#define EIGEN_USE_AVX512_TRSM_L_KERNELS 0
#endif

// A `min` macro would break the std::min calls in this file, so remove it.
#ifdef min
#undef min
#endif
43
+
44
+ namespace Eigen {
45
+ namespace internal {
46
+
47
+ #define EIGEN_AVX_MAX_NUM_ACC (int64_t(24))
48
+ #define EIGEN_AVX_MAX_NUM_ROW (int64_t(8)) // Denoted L in code.
49
+ #define EIGEN_AVX_MAX_K_UNROL (int64_t(4))
50
+ #define EIGEN_AVX_B_LOAD_SETS (int64_t(2))
51
+ #define EIGEN_AVX_MAX_A_BCAST (int64_t(2))
52
+ typedef Packet16f vecFullFloat;
53
+ typedef Packet8d vecFullDouble;
54
+ typedef Packet8f vecHalfFloat;
55
+ typedef Packet4d vecHalfDouble;
56
+
57
+ // Compile-time unrolls are implemented here.
58
+ // Note: this depends on macros and typedefs above.
59
+ #include "TrsmUnrolls.inc"
60
+
61
#if (EIGEN_USE_AVX512_TRSM_KERNELS) && (EIGEN_COMP_CLANG != 0)
/**
 * For smaller problem sizes, and certain compilers, using the optimized kernels trsmKernelL/R directly
 * is faster than the packed versions in TriangularSolverMatrix.h.
 *
 * The current heuristic is based on having all arrays used in the largest gemm-update
 * in triSolve fit in roughly a fraction L2Cap of the L2 cache. These cutoffs are a bit conservative and could be
 * larger for some trsm cases.
 * The formula:
 *
 *   (L*M + M*N + L*N)*sizeof(Scalar) < L2Cache*L2Cap
 *
 *   L = number of rows to solve at a time
 *   N = number of rhs
 *   M = Dimension of triangular matrix
 *
 */
#if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS)
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS 1
#endif

#if EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS

#if EIGEN_USE_AVX512_TRSM_R_KERNELS
#if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS)
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 1
#endif  // !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS)
#endif

#if EIGEN_USE_AVX512_TRSM_L_KERNELS
#if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS)
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS 1
#endif
#endif  // EIGEN_USE_AVX512_TRSM_L_KERNELS

#else  // EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS == 0
// Force the per-side cutoffs off; #undef first so a user pre-definition does
// not trigger a "macro redefined" diagnostic.
#undef EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS
#undef EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 0
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS 0
#endif  // EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS

/**
 * Solves the formula above for M, with L = EIGEN_AVX_MAX_NUM_ROW and N capped at 5*U3.
 *
 * L2Size: L2 cache size in bytes (divided by sizeof(Scalar) to get an element count).
 * N:      number of right-hand sides.
 * L2Cap:  fraction of the L2 cache the working set may occupy.
 * Returns M rounded toward zero to a multiple of EIGEN_AVX_MAX_NUM_ROW. The result is
 * non-positive when the cache budget cannot even hold the L*N term.
 */
template <typename Scalar>
int64_t avx512_trsm_cutoff(int64_t L2Size, int64_t N, double L2Cap) {
  const int64_t U3 = 3 * packet_traits<Scalar>::size;
  const int64_t MaxNb = 5 * U3;
  int64_t Nb = std::min(MaxNb, N);
  double cutoff_d =
      (((L2Size * L2Cap) / (sizeof(Scalar))) - (EIGEN_AVX_MAX_NUM_ROW)*Nb) / ((EIGEN_AVX_MAX_NUM_ROW) + Nb);
  int64_t cutoff_l = static_cast<int64_t>(cutoff_d);
  return (cutoff_l / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW;
}
#else  // !(EIGEN_USE_AVX512_TRSM_KERNELS) || !(EIGEN_COMP_CLANG != 0)
// No-copy cutoffs are only used with clang; force them off. #undef first so a
// user pre-definition does not trigger a "macro redefined" diagnostic.
#undef EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS
#undef EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS
#undef EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS 0
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 0
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS 0
#endif
116
+
117
+ /**
118
+ * Used by gemmKernel for the case A/B row-major and C col-major.
119
+ */
120
+ template <typename Scalar, typename vec, int64_t unrollM, int64_t unrollN, bool remM, bool remN>
121
+ EIGEN_ALWAYS_INLINE void transStoreC(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, Scalar *C_arr,
122
+ int64_t LDC, int64_t remM_ = 0, int64_t remN_ = 0) {
123
+ EIGEN_UNUSED_VARIABLE(remN_);
124
+ EIGEN_UNUSED_VARIABLE(remM_);
125
+ using urolls = unrolls::trans<Scalar>;
126
+
127
+ constexpr int64_t U3 = urolls::PacketSize * 3;
128
+ constexpr int64_t U2 = urolls::PacketSize * 2;
129
+ constexpr int64_t U1 = urolls::PacketSize * 1;
130
+
131
+ static_assert(unrollN == U1 || unrollN == U2 || unrollN == U3, "unrollN should be a multiple of PacketSize");
132
+ static_assert(unrollM == EIGEN_AVX_MAX_NUM_ROW, "unrollM should be equal to EIGEN_AVX_MAX_NUM_ROW");
133
+
134
+ urolls::template transpose<unrollN, 0>(zmm);
135
+ EIGEN_IF_CONSTEXPR(unrollN > U2) urolls::template transpose<unrollN, 2>(zmm);
136
+ EIGEN_IF_CONSTEXPR(unrollN > U1) urolls::template transpose<unrollN, 1>(zmm);
137
+
138
+ static_assert((remN && unrollN == U1) || !remN, "When handling N remainder set unrollN=U1");
139
+ EIGEN_IF_CONSTEXPR(!remN) {
140
+ urolls::template storeC<std::min(unrollN, U1), unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
141
+ EIGEN_IF_CONSTEXPR(unrollN > U1) {
142
+ constexpr int64_t unrollN_ = std::min(unrollN - U1, U1);
143
+ urolls::template storeC<unrollN_, unrollN, 1, remM>(C_arr + U1 * LDC, LDC, zmm, remM_);
144
+ }
145
+ EIGEN_IF_CONSTEXPR(unrollN > U2) {
146
+ constexpr int64_t unrollN_ = std::min(unrollN - U2, U1);
147
+ urolls::template storeC<unrollN_, unrollN, 2, remM>(C_arr + U2 * LDC, LDC, zmm, remM_);
148
+ }
149
+ }
150
+ else {
151
+ EIGEN_IF_CONSTEXPR((std::is_same<Scalar, float>::value)) {
152
+ // Note: without "if constexpr" this section of code will also be
153
+ // parsed by the compiler so each of the storeC will still be instantiated.
154
+ // We use enable_if in aux_storeC to set it to an empty function for
155
+ // these cases.
156
+ if (remN_ == 15)
157
+ urolls::template storeC<15, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
158
+ else if (remN_ == 14)
159
+ urolls::template storeC<14, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
160
+ else if (remN_ == 13)
161
+ urolls::template storeC<13, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
162
+ else if (remN_ == 12)
163
+ urolls::template storeC<12, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
164
+ else if (remN_ == 11)
165
+ urolls::template storeC<11, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
166
+ else if (remN_ == 10)
167
+ urolls::template storeC<10, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
168
+ else if (remN_ == 9)
169
+ urolls::template storeC<9, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
170
+ else if (remN_ == 8)
171
+ urolls::template storeC<8, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
172
+ else if (remN_ == 7)
173
+ urolls::template storeC<7, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
174
+ else if (remN_ == 6)
175
+ urolls::template storeC<6, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
176
+ else if (remN_ == 5)
177
+ urolls::template storeC<5, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
178
+ else if (remN_ == 4)
179
+ urolls::template storeC<4, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
180
+ else if (remN_ == 3)
181
+ urolls::template storeC<3, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
182
+ else if (remN_ == 2)
183
+ urolls::template storeC<2, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
184
+ else if (remN_ == 1)
185
+ urolls::template storeC<1, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
186
+ }
187
+ else {
188
+ if (remN_ == 7)
189
+ urolls::template storeC<7, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
190
+ else if (remN_ == 6)
191
+ urolls::template storeC<6, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
192
+ else if (remN_ == 5)
193
+ urolls::template storeC<5, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
194
+ else if (remN_ == 4)
195
+ urolls::template storeC<4, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
196
+ else if (remN_ == 3)
197
+ urolls::template storeC<3, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
198
+ else if (remN_ == 2)
199
+ urolls::template storeC<2, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
200
+ else if (remN_ == 1)
201
+ urolls::template storeC<1, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
202
+ }
203
+ }
204
+ }
205
+
206
+ /**
207
+ * GEMM like operation for trsm panel updates.
208
+ * Computes: C -= A*B
209
+ * K must be multiple of 4.
210
+ *
211
+ * Unrolls used are {1,2,4,8}x{U1,U2,U3};
212
+ * For good performance we want K to be large with M/N relatively small, but also large enough
213
+ * to use the {8,U3} unroll block.
214
+ *
215
+ * isARowMajor: is A_arr row-major?
216
+ * isCRowMajor: is C_arr row-major? (B_arr is assumed to be row-major).
217
+ * isAdd: C += A*B or C -= A*B (used by trsm)
218
+ * handleKRem: Handle arbitrary K? This is not needed for trsm.
219
+ */
220
+ template <typename Scalar, bool isARowMajor, bool isCRowMajor, bool isAdd, bool handleKRem>
221
+ void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr, int64_t M, int64_t N, int64_t K, int64_t LDA, int64_t LDB,
222
+ int64_t LDC) {
223
+ using urolls = unrolls::gemm<Scalar, isAdd>;
224
+ constexpr int64_t U3 = urolls::PacketSize * 3;
225
+ constexpr int64_t U2 = urolls::PacketSize * 2;
226
+ constexpr int64_t U1 = urolls::PacketSize * 1;
227
+ using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
228
+ int64_t N_ = (N / U3) * U3;
229
+ int64_t M_ = (M / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW;
230
+ int64_t K_ = (K / EIGEN_AVX_MAX_K_UNROL) * EIGEN_AVX_MAX_K_UNROL;
231
+ int64_t j = 0;
232
+ for (; j < N_; j += U3) {
233
+ constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS * 3;
234
+ int64_t i = 0;
235
+ for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
236
+ Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)], *B_t = &B_arr[0 * LDB + j];
237
+ PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
238
+ urolls::template setzero<3, EIGEN_AVX_MAX_NUM_ROW>(zmm);
239
+ for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
240
+ urolls::template microKernel<isARowMajor, 3, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
241
+ EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
242
+ B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
243
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
244
+ else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
245
+ }
246
+ EIGEN_IF_CONSTEXPR(handleKRem) {
247
+ for (int64_t k = K_; k < K; k++) {
248
+ urolls::template microKernel<isARowMajor, 3, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_B_LOAD_SETS * 3,
249
+ EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
250
+ B_t += LDB;
251
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
252
+ else A_t += LDA;
253
+ }
254
+ }
255
+ EIGEN_IF_CONSTEXPR(isCRowMajor) {
256
+ urolls::template updateC<3, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
257
+ urolls::template storeC<3, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
258
+ }
259
+ else {
260
+ transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, false, false>(zmm, &C_arr[i + j * LDC], LDC);
261
+ }
262
+ }
263
+ if (M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise
264
+ Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
265
+ Scalar *B_t = &B_arr[0 * LDB + j];
266
+ PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
267
+ urolls::template setzero<3, 4>(zmm);
268
+ for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
269
+ urolls::template microKernel<isARowMajor, 3, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_B_LOAD_SETS * 3,
270
+ EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
271
+ B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
272
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
273
+ else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
274
+ }
275
+ EIGEN_IF_CONSTEXPR(handleKRem) {
276
+ for (int64_t k = K_; k < K; k++) {
277
+ urolls::template microKernel<isARowMajor, 3, 4, 1, EIGEN_AVX_B_LOAD_SETS * 3, EIGEN_AVX_MAX_A_BCAST>(
278
+ B_t, A_t, LDB, LDA, zmm);
279
+ B_t += LDB;
280
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
281
+ else A_t += LDA;
282
+ }
283
+ }
284
+ EIGEN_IF_CONSTEXPR(isCRowMajor) {
285
+ urolls::template updateC<3, 4>(&C_arr[i * LDC + j], LDC, zmm);
286
+ urolls::template storeC<3, 4>(&C_arr[i * LDC + j], LDC, zmm);
287
+ }
288
+ else {
289
+ transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, true, false>(zmm, &C_arr[i + j * LDC], LDC, 4);
290
+ }
291
+ i += 4;
292
+ }
293
+ if (M - i >= 2) {
294
+ Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
295
+ Scalar *B_t = &B_arr[0 * LDB + j];
296
+ PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
297
+ urolls::template setzero<3, 2>(zmm);
298
+ for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
299
+ urolls::template microKernel<isARowMajor, 3, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_B_LOAD_SETS * 3,
300
+ EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
301
+ B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
302
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
303
+ else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
304
+ }
305
+ EIGEN_IF_CONSTEXPR(handleKRem) {
306
+ for (int64_t k = K_; k < K; k++) {
307
+ urolls::template microKernel<isARowMajor, 3, 2, 1, EIGEN_AVX_B_LOAD_SETS * 3, EIGEN_AVX_MAX_A_BCAST>(
308
+ B_t, A_t, LDB, LDA, zmm);
309
+ B_t += LDB;
310
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
311
+ else A_t += LDA;
312
+ }
313
+ }
314
+ EIGEN_IF_CONSTEXPR(isCRowMajor) {
315
+ urolls::template updateC<3, 2>(&C_arr[i * LDC + j], LDC, zmm);
316
+ urolls::template storeC<3, 2>(&C_arr[i * LDC + j], LDC, zmm);
317
+ }
318
+ else {
319
+ transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, true, false>(zmm, &C_arr[i + j * LDC], LDC, 2);
320
+ }
321
+ i += 2;
322
+ }
323
+ if (M - i > 0) {
324
+ Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
325
+ Scalar *B_t = &B_arr[0 * LDB + j];
326
+ PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
327
+ urolls::template setzero<3, 1>(zmm);
328
+ {
329
+ for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
330
+ urolls::template microKernel<isARowMajor, 3, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_B_LOAD_SETS * 3, 1>(
331
+ B_t, A_t, LDB, LDA, zmm);
332
+ B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
333
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
334
+ else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
335
+ }
336
+ EIGEN_IF_CONSTEXPR(handleKRem) {
337
+ for (int64_t k = K_; k < K; k++) {
338
+ urolls::template microKernel<isARowMajor, 3, 1, 1, EIGEN_AVX_B_LOAD_SETS * 3, 1>(B_t, A_t, LDB, LDA, zmm);
339
+ B_t += LDB;
340
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
341
+ else A_t += LDA;
342
+ }
343
+ }
344
+ EIGEN_IF_CONSTEXPR(isCRowMajor) {
345
+ urolls::template updateC<3, 1>(&C_arr[i * LDC + j], LDC, zmm);
346
+ urolls::template storeC<3, 1>(&C_arr[i * LDC + j], LDC, zmm);
347
+ }
348
+ else {
349
+ transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, true, false>(zmm, &C_arr[i + j * LDC], LDC, 1);
350
+ }
351
+ }
352
+ }
353
+ }
354
+ if (N - j >= U2) {
355
+ constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS * 2;
356
+ int64_t i = 0;
357
+ for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
358
+ Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)], *B_t = &B_arr[0 * LDB + j];
359
+ EIGEN_IF_CONSTEXPR(isCRowMajor) B_t = &B_arr[0 * LDB + j];
360
+ PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
361
+ urolls::template setzero<2, EIGEN_AVX_MAX_NUM_ROW>(zmm);
362
+ for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
363
+ urolls::template microKernel<isARowMajor, 2, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
364
+ EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
365
+ B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
366
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
367
+ else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
368
+ }
369
+ EIGEN_IF_CONSTEXPR(handleKRem) {
370
+ for (int64_t k = K_; k < K; k++) {
371
+ urolls::template microKernel<isARowMajor, 2, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_MAX_B_LOAD,
372
+ EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
373
+ B_t += LDB;
374
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
375
+ else A_t += LDA;
376
+ }
377
+ }
378
+ EIGEN_IF_CONSTEXPR(isCRowMajor) {
379
+ urolls::template updateC<2, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
380
+ urolls::template storeC<2, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
381
+ }
382
+ else {
383
+ transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, false, false>(zmm, &C_arr[i + j * LDC], LDC);
384
+ }
385
+ }
386
+ if (M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise
387
+ Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
388
+ Scalar *B_t = &B_arr[0 * LDB + j];
389
+ PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
390
+ urolls::template setzero<2, 4>(zmm);
391
+ for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
392
+ urolls::template microKernel<isARowMajor, 2, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
393
+ EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
394
+ B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
395
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
396
+ else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
397
+ }
398
+ EIGEN_IF_CONSTEXPR(handleKRem) {
399
+ for (int64_t k = K_; k < K; k++) {
400
+ urolls::template microKernel<isARowMajor, 2, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
401
+ LDA, zmm);
402
+ B_t += LDB;
403
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
404
+ else A_t += LDA;
405
+ }
406
+ }
407
+ EIGEN_IF_CONSTEXPR(isCRowMajor) {
408
+ urolls::template updateC<2, 4>(&C_arr[i * LDC + j], LDC, zmm);
409
+ urolls::template storeC<2, 4>(&C_arr[i * LDC + j], LDC, zmm);
410
+ }
411
+ else {
412
+ transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, true, false>(zmm, &C_arr[i + j * LDC], LDC, 4);
413
+ }
414
+ i += 4;
415
+ }
416
+ if (M - i >= 2) {
417
+ Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
418
+ Scalar *B_t = &B_arr[0 * LDB + j];
419
+ PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
420
+ urolls::template setzero<2, 2>(zmm);
421
+ for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
422
+ urolls::template microKernel<isARowMajor, 2, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
423
+ EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
424
+ B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
425
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
426
+ else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
427
+ }
428
+ EIGEN_IF_CONSTEXPR(handleKRem) {
429
+ for (int64_t k = K_; k < K; k++) {
430
+ urolls::template microKernel<isARowMajor, 2, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
431
+ LDA, zmm);
432
+ B_t += LDB;
433
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
434
+ else A_t += LDA;
435
+ }
436
+ }
437
+ EIGEN_IF_CONSTEXPR(isCRowMajor) {
438
+ urolls::template updateC<2, 2>(&C_arr[i * LDC + j], LDC, zmm);
439
+ urolls::template storeC<2, 2>(&C_arr[i * LDC + j], LDC, zmm);
440
+ }
441
+ else {
442
+ transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, true, false>(zmm, &C_arr[i + j * LDC], LDC, 2);
443
+ }
444
+ i += 2;
445
+ }
446
+ if (M - i > 0) {
447
+ Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
448
+ Scalar *B_t = &B_arr[0 * LDB + j];
449
+ PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
450
+ urolls::template setzero<2, 1>(zmm);
451
+ for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
452
+ urolls::template microKernel<isARowMajor, 2, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1>(B_t, A_t, LDB,
453
+ LDA, zmm);
454
+ B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
455
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
456
+ else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
457
+ }
458
+ EIGEN_IF_CONSTEXPR(handleKRem) {
459
+ for (int64_t k = K_; k < K; k++) {
460
+ urolls::template microKernel<isARowMajor, 2, 1, 1, EIGEN_AVX_MAX_B_LOAD, 1>(B_t, A_t, LDB, LDA, zmm);
461
+ B_t += LDB;
462
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
463
+ else A_t += LDA;
464
+ }
465
+ }
466
+ EIGEN_IF_CONSTEXPR(isCRowMajor) {
467
+ urolls::template updateC<2, 1>(&C_arr[i * LDC + j], LDC, zmm);
468
+ urolls::template storeC<2, 1>(&C_arr[i * LDC + j], LDC, zmm);
469
+ }
470
+ else {
471
+ transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, true, false>(zmm, &C_arr[i + j * LDC], LDC, 1);
472
+ }
473
+ }
474
+ j += U2;
475
+ }
476
+ if (N - j >= U1) {
477
+ constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS * 1;
478
+ int64_t i = 0;
479
+ for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
480
+ Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)], *B_t = &B_arr[0 * LDB + j];
481
+ PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
482
+ urolls::template setzero<1, EIGEN_AVX_MAX_NUM_ROW>(zmm);
483
+ for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
484
+ urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
485
+ EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
486
+ B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
487
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
488
+ else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
489
+ }
490
+ EIGEN_IF_CONSTEXPR(handleKRem) {
491
+ for (int64_t k = K_; k < K; k++) {
492
+ urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_B_LOAD_SETS * 1,
493
+ EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
494
+ B_t += LDB;
495
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
496
+ else A_t += LDA;
497
+ }
498
+ }
499
+ EIGEN_IF_CONSTEXPR(isCRowMajor) {
500
+ urolls::template updateC<1, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
501
+ urolls::template storeC<1, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
502
+ }
503
+ else {
504
+ transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, false, false>(zmm, &C_arr[i + j * LDC], LDC);
505
+ }
506
+ }
507
+ if (M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise
508
+ Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
509
+ Scalar *B_t = &B_arr[0 * LDB + j];
510
+ PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
511
+ urolls::template setzero<1, 4>(zmm);
512
+ for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
513
+ urolls::template microKernel<isARowMajor, 1, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
514
+ EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
515
+ B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
516
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
517
+ else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
518
+ }
519
+ EIGEN_IF_CONSTEXPR(handleKRem) {
520
+ for (int64_t k = K_; k < K; k++) {
521
+ urolls::template microKernel<isARowMajor, 1, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
522
+ LDA, zmm);
523
+ B_t += LDB;
524
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
525
+ else A_t += LDA;
526
+ }
527
+ }
528
+ EIGEN_IF_CONSTEXPR(isCRowMajor) {
529
+ urolls::template updateC<1, 4>(&C_arr[i * LDC + j], LDC, zmm);
530
+ urolls::template storeC<1, 4>(&C_arr[i * LDC + j], LDC, zmm);
531
+ }
532
+ else {
533
+ transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, false>(zmm, &C_arr[i + j * LDC], LDC, 4);
534
+ }
535
+ i += 4;
536
+ }
537
+ if (M - i >= 2) {
538
+ Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
539
+ Scalar *B_t = &B_arr[0 * LDB + j];
540
+ PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
541
+ urolls::template setzero<1, 2>(zmm);
542
+ for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
543
+ urolls::template microKernel<isARowMajor, 1, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
544
+ EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
545
+ B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
546
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
547
+ else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
548
+ }
549
+ EIGEN_IF_CONSTEXPR(handleKRem) {
550
+ for (int64_t k = K_; k < K; k++) {
551
+ urolls::template microKernel<isARowMajor, 1, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
552
+ LDA, zmm);
553
+ B_t += LDB;
554
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
555
+ else A_t += LDA;
556
+ }
557
+ }
558
+ EIGEN_IF_CONSTEXPR(isCRowMajor) {
559
+ urolls::template updateC<1, 2>(&C_arr[i * LDC + j], LDC, zmm);
560
+ urolls::template storeC<1, 2>(&C_arr[i * LDC + j], LDC, zmm);
561
+ }
562
+ else {
563
+ transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, false>(zmm, &C_arr[i + j * LDC], LDC, 2);
564
+ }
565
+ i += 2;
566
+ }
567
+ if (M - i > 0) {
568
+ Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
569
+ Scalar *B_t = &B_arr[0 * LDB + j];
570
+ PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
571
+ urolls::template setzero<1, 1>(zmm);
572
+ {
573
+ for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
574
+ urolls::template microKernel<isARowMajor, 1, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1>(B_t, A_t, LDB,
575
+ LDA, zmm);
576
+ B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
577
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
578
+ else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
579
+ }
580
+ EIGEN_IF_CONSTEXPR(handleKRem) {
581
+ for (int64_t k = K_; k < K; k++) {
582
+ urolls::template microKernel<isARowMajor, 1, 1, 1, EIGEN_AVX_B_LOAD_SETS * 1, 1>(B_t, A_t, LDB, LDA, zmm);
583
+ B_t += LDB;
584
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
585
+ else A_t += LDA;
586
+ }
587
+ }
588
+ EIGEN_IF_CONSTEXPR(isCRowMajor) {
589
+ urolls::template updateC<1, 1>(&C_arr[i * LDC + j], LDC, zmm);
590
+ urolls::template storeC<1, 1>(&C_arr[i * LDC + j], LDC, zmm);
591
+ }
592
+ else {
593
+ transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, false>(zmm, &C_arr[i + j * LDC], LDC, 1);
594
+ }
595
+ }
596
+ }
597
+ j += U1;
598
+ }
599
+ if (N - j > 0) {
600
+ constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS * 1;
601
+ int64_t i = 0;
602
+ for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
603
+ Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
604
+ Scalar *B_t = &B_arr[0 * LDB + j];
605
+ PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
606
+ urolls::template setzero<1, EIGEN_AVX_MAX_NUM_ROW>(zmm);
607
+ for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
608
+ urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
609
+ EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
610
+ B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
611
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
612
+ else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
613
+ }
614
+ EIGEN_IF_CONSTEXPR(handleKRem) {
615
+ for (int64_t k = K_; k < K; k++) {
616
+ urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_MAX_B_LOAD,
617
+ EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
618
+ B_t += LDB;
619
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
620
+ else A_t += LDA;
621
+ }
622
+ }
623
+ EIGEN_IF_CONSTEXPR(isCRowMajor) {
624
+ urolls::template updateC<1, EIGEN_AVX_MAX_NUM_ROW, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
625
+ urolls::template storeC<1, EIGEN_AVX_MAX_NUM_ROW, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
626
+ }
627
+ else {
628
+ transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, false, true>(zmm, &C_arr[i + j * LDC], LDC, 0, N - j);
629
+ }
630
+ }
631
+ if (M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise
632
+ Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
633
+ Scalar *B_t = &B_arr[0 * LDB + j];
634
+ PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
635
+ urolls::template setzero<1, 4>(zmm);
636
+ for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
637
+ urolls::template microKernel<isARowMajor, 1, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
638
+ EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
639
+ B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
640
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
641
+ else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
642
+ }
643
+ EIGEN_IF_CONSTEXPR(handleKRem) {
644
+ for (int64_t k = K_; k < K; k++) {
645
+ urolls::template microKernel<isARowMajor, 1, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST, true>(
646
+ B_t, A_t, LDB, LDA, zmm, N - j);
647
+ B_t += LDB;
648
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
649
+ else A_t += LDA;
650
+ }
651
+ }
652
+ EIGEN_IF_CONSTEXPR(isCRowMajor) {
653
+ urolls::template updateC<1, 4, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
654
+ urolls::template storeC<1, 4, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
655
+ }
656
+ else {
657
+ transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, true>(zmm, &C_arr[i + j * LDC], LDC, 4, N - j);
658
+ }
659
+ i += 4;
660
+ }
661
+ if (M - i >= 2) {
662
+ Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
663
+ Scalar *B_t = &B_arr[0 * LDB + j];
664
+ PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
665
+ urolls::template setzero<1, 2>(zmm);
666
+ for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
667
+ urolls::template microKernel<isARowMajor, 1, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
668
+ EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
669
+ B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
670
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
671
+ else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
672
+ }
673
+ EIGEN_IF_CONSTEXPR(handleKRem) {
674
+ for (int64_t k = K_; k < K; k++) {
675
+ urolls::template microKernel<isARowMajor, 1, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST, true>(
676
+ B_t, A_t, LDB, LDA, zmm, N - j);
677
+ B_t += LDB;
678
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
679
+ else A_t += LDA;
680
+ }
681
+ }
682
+ EIGEN_IF_CONSTEXPR(isCRowMajor) {
683
+ urolls::template updateC<1, 2, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
684
+ urolls::template storeC<1, 2, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
685
+ }
686
+ else {
687
+ transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, true>(zmm, &C_arr[i + j * LDC], LDC, 2, N - j);
688
+ }
689
+ i += 2;
690
+ }
691
+ if (M - i > 0) {
692
+ Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
693
+ Scalar *B_t = &B_arr[0 * LDB + j];
694
+ PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
695
+ urolls::template setzero<1, 1>(zmm);
696
+ for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
697
+ urolls::template microKernel<isARowMajor, 1, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1, true>(
698
+ B_t, A_t, LDB, LDA, zmm, N - j);
699
+ B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
700
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
701
+ else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
702
+ }
703
+ EIGEN_IF_CONSTEXPR(handleKRem) {
704
+ for (int64_t k = K_; k < K; k++) {
705
+ urolls::template microKernel<isARowMajor, 1, 1, 1, EIGEN_AVX_MAX_B_LOAD, 1, true>(B_t, A_t, LDB, LDA, zmm,
706
+ N - j);
707
+ B_t += LDB;
708
+ EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
709
+ else A_t += LDA;
710
+ }
711
+ }
712
+ EIGEN_IF_CONSTEXPR(isCRowMajor) {
713
+ urolls::template updateC<1, 1, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
714
+ urolls::template storeC<1, 1, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
715
+ }
716
+ else {
717
+ transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, true>(zmm, &C_arr[i + j * LDC], LDC, 1, N - j);
718
+ }
719
+ }
720
+ }
721
+ }
722
+
723
+ /**
724
+ * Triangular solve kernel with A on left with K number of rhs. dim(A) = unrollM
725
+ *
726
+ * unrollM: dimension of A matrix (triangular matrix). unrollM should be <= EIGEN_AVX_MAX_NUM_ROW
727
+ * isFWDSolve: is forward solve?
728
+ * isUnitDiag: is the diagonal of A all ones?
729
+ * The B matrix (RHS) is assumed to be row-major
730
+ */
731
+ template <typename Scalar, typename vec, int64_t unrollM, bool isARowMajor, bool isFWDSolve, bool isUnitDiag>
732
+ EIGEN_ALWAYS_INLINE void triSolveKernel(Scalar *A_arr, Scalar *B_arr, int64_t K, int64_t LDA, int64_t LDB) {
733
+ static_assert(unrollM <= EIGEN_AVX_MAX_NUM_ROW, "unrollM should be equal to EIGEN_AVX_MAX_NUM_ROW");
734
+ using urolls = unrolls::trsm<Scalar>;
735
+ constexpr int64_t U3 = urolls::PacketSize * 3;
736
+ constexpr int64_t U2 = urolls::PacketSize * 2;
737
+ constexpr int64_t U1 = urolls::PacketSize * 1;
738
+
739
+ PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> RHSInPacket;
740
+ PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> AInPacket;
741
+
742
+ int64_t k = 0;
743
+ while (K - k >= U3) {
744
+ urolls::template loadRHS<isFWDSolve, unrollM, 3>(B_arr + k, LDB, RHSInPacket);
745
+ urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 3>(A_arr, LDA, RHSInPacket,
746
+ AInPacket);
747
+ urolls::template storeRHS<isFWDSolve, unrollM, 3>(B_arr + k, LDB, RHSInPacket);
748
+ k += U3;
749
+ }
750
+ if (K - k >= U2) {
751
+ urolls::template loadRHS<isFWDSolve, unrollM, 2>(B_arr + k, LDB, RHSInPacket);
752
+ urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 2>(A_arr, LDA, RHSInPacket,
753
+ AInPacket);
754
+ urolls::template storeRHS<isFWDSolve, unrollM, 2>(B_arr + k, LDB, RHSInPacket);
755
+ k += U2;
756
+ }
757
+ if (K - k >= U1) {
758
+ urolls::template loadRHS<isFWDSolve, unrollM, 1>(B_arr + k, LDB, RHSInPacket);
759
+ urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 1>(A_arr, LDA, RHSInPacket,
760
+ AInPacket);
761
+ urolls::template storeRHS<isFWDSolve, unrollM, 1>(B_arr + k, LDB, RHSInPacket);
762
+ k += U1;
763
+ }
764
+ if (K - k > 0) {
765
+ // Handle remaining number of RHS
766
+ urolls::template loadRHS<isFWDSolve, unrollM, 1, true>(B_arr + k, LDB, RHSInPacket, K - k);
767
+ urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 1>(A_arr, LDA, RHSInPacket,
768
+ AInPacket);
769
+ urolls::template storeRHS<isFWDSolve, unrollM, 1, true>(B_arr + k, LDB, RHSInPacket, K - k);
770
+ }
771
+ }
772
+
773
+ /**
774
+ * Triangular solve routine with A on left and dimension of at most L with K number of rhs. This is essentially
775
+ * a wrapper for triSolveMicrokernel for M = {1,2,3,4,5,6,7,8}.
776
+ *
777
+ * isFWDSolve: is forward solve?
778
+ * isUnitDiag: is the diagonal of A all ones?
779
+ * The B matrix (RHS) is assumed to be row-major
780
+ */
781
+ template <typename Scalar, bool isARowMajor, bool isFWDSolve, bool isUnitDiag>
782
+ void triSolveKernelLxK(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t K, int64_t LDA, int64_t LDB) {
783
+ // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted
784
+ // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller.
785
+ using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
786
+ if (M == 8)
787
+ triSolveKernel<Scalar, vec, 8, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
788
+ else if (M == 7)
789
+ triSolveKernel<Scalar, vec, 7, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
790
+ else if (M == 6)
791
+ triSolveKernel<Scalar, vec, 6, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
792
+ else if (M == 5)
793
+ triSolveKernel<Scalar, vec, 5, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
794
+ else if (M == 4)
795
+ triSolveKernel<Scalar, vec, 4, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
796
+ else if (M == 3)
797
+ triSolveKernel<Scalar, vec, 3, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
798
+ else if (M == 2)
799
+ triSolveKernel<Scalar, vec, 2, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
800
+ else if (M == 1)
801
+ triSolveKernel<Scalar, vec, 1, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
802
+ return;
803
+ }
804
+
805
+ /**
806
+ * This routine is used to copy B to/from a temporary array (row-major) for cases where B is column-major.
807
+ *
808
+ * toTemp: true => copy to temporary array, false => copy from temporary array
809
+ * remM: true = need to handle remainder values for M (M < EIGEN_AVX_MAX_NUM_ROW)
810
+ *
811
+ */
812
+ template <typename Scalar, bool toTemp = true, bool remM = false>
813
+ EIGEN_ALWAYS_INLINE void copyBToRowMajor(Scalar *B_arr, int64_t LDB, int64_t K, Scalar *B_temp, int64_t LDB_,
814
+ int64_t remM_ = 0) {
815
+ EIGEN_UNUSED_VARIABLE(remM_);
816
+ using urolls = unrolls::transB<Scalar>;
817
+ using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecFullDouble>::type;
818
+ PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> ymm;
819
+ constexpr int64_t U3 = urolls::PacketSize * 3;
820
+ constexpr int64_t U2 = urolls::PacketSize * 2;
821
+ constexpr int64_t U1 = urolls::PacketSize * 1;
822
+ int64_t K_ = K / U3 * U3;
823
+ int64_t k = 0;
824
+
825
+ for (; k < K_; k += U3) {
826
+ urolls::template transB_kernel<U3, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
827
+ B_temp += U3;
828
+ }
829
+ if (K - k >= U2) {
830
+ urolls::template transB_kernel<U2, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
831
+ B_temp += U2;
832
+ k += U2;
833
+ }
834
+ if (K - k >= U1) {
835
+ urolls::template transB_kernel<U1, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
836
+ B_temp += U1;
837
+ k += U1;
838
+ }
839
+ EIGEN_IF_CONSTEXPR(U1 > 8) {
840
+ // Note: without "if constexpr" this section of code will also be
841
+ // parsed by the compiler so there is an additional check in {load/store}BBlock
842
+ // to make sure the counter is not non-negative.
843
+ if (K - k >= 8) {
844
+ urolls::template transB_kernel<8, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
845
+ B_temp += 8;
846
+ k += 8;
847
+ }
848
+ }
849
+ EIGEN_IF_CONSTEXPR(U1 > 4) {
850
+ // Note: without "if constexpr" this section of code will also be
851
+ // parsed by the compiler so there is an additional check in {load/store}BBlock
852
+ // to make sure the counter is not non-negative.
853
+ if (K - k >= 4) {
854
+ urolls::template transB_kernel<4, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
855
+ B_temp += 4;
856
+ k += 4;
857
+ }
858
+ }
859
+ if (K - k >= 2) {
860
+ urolls::template transB_kernel<2, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
861
+ B_temp += 2;
862
+ k += 2;
863
+ }
864
+ if (K - k >= 1) {
865
+ urolls::template transB_kernel<1, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
866
+ B_temp += 1;
867
+ k += 1;
868
+ }
869
+ }
870
+
871
+ /**
872
+ * Main triangular solve driver
873
+ *
874
+ * Triangular solve with A on the left.
875
+ * Scalar: Scalar precision, only float/double is supported.
876
+ * isARowMajor: is A row-major?
877
+ * isBRowMajor: is B row-major?
878
+ * isFWDSolve: is this forward solve or backward (true => forward)?
879
+ * isUnitDiag: is diagonal of A unit or nonunit (true => A has unit diagonal)?
880
+ *
881
+ * M: dimension of A
882
+ * numRHS: number of right hand sides (coincides with K dimension for gemm updates)
883
+ *
884
+ * Here are the mapping between the different TRSM cases (col-major) and triSolve:
885
+ *
886
+ * LLN (left , lower, A non-transposed) :: isARowMajor=false, isBRowMajor=false, isFWDSolve=true
887
+ * LUT (left , upper, A transposed) :: isARowMajor=true, isBRowMajor=false, isFWDSolve=true
888
+ * LUN (left , upper, A non-transposed) :: isARowMajor=false, isBRowMajor=false, isFWDSolve=false
889
+ * LLT (left , lower, A transposed) :: isARowMajor=true, isBRowMajor=false, isFWDSolve=false
890
+ * RUN (right, upper, A non-transposed) :: isARowMajor=true, isBRowMajor=true, isFWDSolve=true
891
+ * RLT (right, lower, A transposed) :: isARowMajor=false, isBRowMajor=true, isFWDSolve=true
892
+ * RUT (right, upper, A transposed) :: isARowMajor=false, isBRowMajor=true, isFWDSolve=false
893
+ * RLN (right, lower, A non-transposed) :: isARowMajor=true, isBRowMajor=true, isFWDSolve=false
894
+ *
895
+ * Note: For RXX cases M,numRHS should be swapped.
896
+ *
897
+ */
898
+ template <typename Scalar, bool isARowMajor = true, bool isBRowMajor = true, bool isFWDSolve = true,
899
+ bool isUnitDiag = false>
900
+ void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t LDA, int64_t LDB) {
901
+ constexpr int64_t psize = packet_traits<Scalar>::size;
902
+ /**
903
+ * The values for kB, numM were determined experimentally.
904
+ * kB: Number of RHS we process at a time.
905
+ * numM: number of rows of B we will store in a temporary array (see below.) This should be a multiple of L.
906
+ *
907
+ * kB was determined by initially setting kB = numRHS and benchmarking triSolve (TRSM-RUN case)
908
+ * performance with M=numRHS.
909
+ * It was observed that performance started to drop around M=numRHS=240. This is likely machine dependent.
910
+ *
911
+ * numM was chosen "arbitrarily". It should be relatively small so B_temp is not too large, but it should be
912
+ * large enough to allow GEMM updates to have larger "K"s (see below.) No benchmarking has been done so far to
913
+ * determine optimal values for numM.
914
+ */
915
+ constexpr int64_t kB = (3 * psize) * 5; // 5*U3
916
+ constexpr int64_t numM = 8 * EIGEN_AVX_MAX_NUM_ROW;
917
+
918
+ int64_t sizeBTemp = 0;
919
+ Scalar *B_temp = NULL;
920
+ EIGEN_IF_CONSTEXPR(!isBRowMajor) {
921
+ /**
922
+ * If B is col-major, we copy it to a fixed-size temporary array of size at most ~numM*kB and
923
+ * transpose it to row-major. Call the solve routine, and copy+transpose it back to the original array.
924
+ * The updated row-major copy of B is reused in the GEMM updates.
925
+ */
926
+ sizeBTemp = (((std::min(kB, numRHS) + psize - 1) / psize + 4) * psize) * numM;
927
+ }
928
+
929
+ EIGEN_IF_CONSTEXPR(!isBRowMajor) B_temp = (Scalar *)handmade_aligned_malloc(sizeof(Scalar) * sizeBTemp, 64);
930
+
931
+ for (int64_t k = 0; k < numRHS; k += kB) {
932
+ int64_t bK = numRHS - k > kB ? kB : numRHS - k;
933
+ int64_t M_ = (M / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW, gemmOff = 0;
934
+
935
+ // bK rounded up to next multiple of L=EIGEN_AVX_MAX_NUM_ROW. When B_temp is used, we solve for bkL RHS
936
+ // instead of bK RHS in triSolveKernelLxK.
937
+ int64_t bkL = ((bK + (EIGEN_AVX_MAX_NUM_ROW - 1)) / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW;
938
+ const int64_t numScalarPerCache = 64 / sizeof(Scalar);
939
+ // Leading dimension of B_temp, will be a multiple of the cache line size.
940
+ int64_t LDT = ((bkL + (numScalarPerCache - 1)) / numScalarPerCache) * numScalarPerCache;
941
+ int64_t offsetBTemp = 0;
942
+ for (int64_t i = 0; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
943
+ EIGEN_IF_CONSTEXPR(!isBRowMajor) {
944
+ int64_t indA_i = isFWDSolve ? i : M - 1 - i;
945
+ int64_t indB_i = isFWDSolve ? i : M - (i + EIGEN_AVX_MAX_NUM_ROW);
946
+ int64_t offB_1 = isFWDSolve ? offsetBTemp : sizeBTemp - EIGEN_AVX_MAX_NUM_ROW * LDT - offsetBTemp;
947
+ int64_t offB_2 = isFWDSolve ? offsetBTemp : sizeBTemp - LDT - offsetBTemp;
948
+ // Copy values from B to B_temp.
949
+ copyBToRowMajor<Scalar, true, false>(B_arr + indB_i + k * LDB, LDB, bK, B_temp + offB_1, LDT);
950
+ // Triangular solve with a small block of A and long horizontal blocks of B (or B_temp if B col-major)
951
+ triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(
952
+ &A_arr[idA<isARowMajor>(indA_i, indA_i, LDA)], B_temp + offB_2, EIGEN_AVX_MAX_NUM_ROW, bkL, LDA, LDT);
953
+ // Copy values from B_temp back to B. B_temp will be reused in gemm call below.
954
+ copyBToRowMajor<Scalar, false, false>(B_arr + indB_i + k * LDB, LDB, bK, B_temp + offB_1, LDT);
955
+
956
+ offsetBTemp += EIGEN_AVX_MAX_NUM_ROW * LDT;
957
+ }
958
+ else {
959
+ int64_t ind = isFWDSolve ? i : M - 1 - i;
960
+ triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(
961
+ &A_arr[idA<isARowMajor>(ind, ind, LDA)], B_arr + k + ind * LDB, EIGEN_AVX_MAX_NUM_ROW, bK, LDA, LDB);
962
+ }
963
+ if (i + EIGEN_AVX_MAX_NUM_ROW < M_) {
964
+ /**
965
+ * For the GEMM updates, we want "K" (K=i+8 in this case) to be large as soon as possible
966
+ * to reuse the accumulators in GEMM as much as possible. So we only update 8xbK blocks of
967
+ * B as follows:
968
+ *
969
+ * A B
970
+ * __
971
+ * |__|__ |__|
972
+ * |__|__|__ |__|
973
+ * |__|__|__|__ |__|
974
+ * |********|__| |**|
975
+ */
976
+ EIGEN_IF_CONSTEXPR(isBRowMajor) {
977
+ int64_t indA_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW);
978
+ int64_t indA_j = isFWDSolve ? 0 : M - (i + EIGEN_AVX_MAX_NUM_ROW);
979
+ int64_t indB_i = isFWDSolve ? 0 : M - (i + EIGEN_AVX_MAX_NUM_ROW);
980
+ int64_t indB_i2 = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW);
981
+ gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
982
+ &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_arr + k + indB_i * LDB, B_arr + k + indB_i2 * LDB,
983
+ EIGEN_AVX_MAX_NUM_ROW, bK, i + EIGEN_AVX_MAX_NUM_ROW, LDA, LDB, LDB);
984
+ }
985
+ else {
986
+ if (offsetBTemp + EIGEN_AVX_MAX_NUM_ROW * LDT > sizeBTemp) {
987
+ /**
988
+ * Similar idea as mentioned above, but here we are limited by the number of updated values of B
989
+ * that can be stored (row-major) in B_temp.
990
+ *
991
+ * If there is not enough space to store the next batch of 8xbK of B in B_temp, we call GEMM
992
+ * update and partially update the remaining old values of B which depends on the new values
993
+ * of B stored in B_temp. These values are then no longer needed and can be overwritten.
994
+ */
995
+ int64_t indA_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : 0;
996
+ int64_t indA_j = isFWDSolve ? gemmOff : M - (i + EIGEN_AVX_MAX_NUM_ROW);
997
+ int64_t indB_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : 0;
998
+ int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
999
+ gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
1000
+ &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_temp + offB_1, B_arr + indB_i + (k)*LDB,
1001
+ M - (i + EIGEN_AVX_MAX_NUM_ROW), bK, i + EIGEN_AVX_MAX_NUM_ROW - gemmOff, LDA, LDT, LDB);
1002
+ offsetBTemp = 0;
1003
+ gemmOff = i + EIGEN_AVX_MAX_NUM_ROW;
1004
+ } else {
1005
+ /**
1006
+ * If there is enough space in B_temp, we only update the next 8xbK values of B.
1007
+ */
1008
+ int64_t indA_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW);
1009
+ int64_t indA_j = isFWDSolve ? gemmOff : M - (i + EIGEN_AVX_MAX_NUM_ROW);
1010
+ int64_t indB_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW);
1011
+ int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
1012
+ gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
1013
+ &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_temp + offB_1, B_arr + indB_i + (k)*LDB,
1014
+ EIGEN_AVX_MAX_NUM_ROW, bK, i + EIGEN_AVX_MAX_NUM_ROW - gemmOff, LDA, LDT, LDB);
1015
+ }
1016
+ }
1017
+ }
1018
+ }
1019
+ // Handle M remainder..
1020
+ int64_t bM = M - M_;
1021
+ if (bM > 0) {
1022
+ if (M_ > 0) {
1023
+ EIGEN_IF_CONSTEXPR(isBRowMajor) {
1024
+ int64_t indA_i = isFWDSolve ? M_ : 0;
1025
+ int64_t indA_j = isFWDSolve ? 0 : bM;
1026
+ int64_t indB_i = isFWDSolve ? 0 : bM;
1027
+ int64_t indB_i2 = isFWDSolve ? M_ : 0;
1028
+ gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
1029
+ &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_arr + k + indB_i * LDB, B_arr + k + indB_i2 * LDB, bM,
1030
+ bK, M_, LDA, LDB, LDB);
1031
+ }
1032
+ else {
1033
+ int64_t indA_i = isFWDSolve ? M_ : 0;
1034
+ int64_t indA_j = isFWDSolve ? gemmOff : bM;
1035
+ int64_t indB_i = isFWDSolve ? M_ : 0;
1036
+ int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
1037
+ gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(&A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)],
1038
+ B_temp + offB_1, B_arr + indB_i + (k)*LDB, bM, bK,
1039
+ M_ - gemmOff, LDA, LDT, LDB);
1040
+ }
1041
+ }
1042
+ EIGEN_IF_CONSTEXPR(!isBRowMajor) {
1043
+ int64_t indA_i = isFWDSolve ? M_ : M - 1 - M_;
1044
+ int64_t indB_i = isFWDSolve ? M_ : 0;
1045
+ int64_t offB_1 = isFWDSolve ? 0 : (bM - 1) * bkL;
1046
+ copyBToRowMajor<Scalar, true, true>(B_arr + indB_i + k * LDB, LDB, bK, B_temp, bkL, bM);
1047
+ triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(&A_arr[idA<isARowMajor>(indA_i, indA_i, LDA)],
1048
+ B_temp + offB_1, bM, bkL, LDA, bkL);
1049
+ copyBToRowMajor<Scalar, false, true>(B_arr + indB_i + k * LDB, LDB, bK, B_temp, bkL, bM);
1050
+ }
1051
+ else {
1052
+ int64_t ind = isFWDSolve ? M_ : M - 1 - M_;
1053
+ triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(&A_arr[idA<isARowMajor>(ind, ind, LDA)],
1054
+ B_arr + k + ind * LDB, bM, bK, LDA, LDB);
1055
+ }
1056
+ }
1057
+ }
1058
+
1059
+ EIGEN_IF_CONSTEXPR(!isBRowMajor) handmade_aligned_free(B_temp);
1060
+ }
1061
+
1062
+ // Template specializations of trsmKernelL/R for float/double and inner strides of 1.
1063
+ #if (EIGEN_USE_AVX512_TRSM_KERNELS)
1064
#if (EIGEN_USE_AVX512_TRSM_R_KERNELS)
// Forward declaration of the right-side TRSM kernel template; this section specializes it for
// float/double with Conjugate == false, OtherInnerStride == 1 and Specialized == true.
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride,
          bool Specialized>
struct trsmKernelR;

template <typename Index, int Mode, int TriStorageOrder>
struct trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1, true> {
  static void kernel(Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
                     Index otherStride);
};

template <typename Index, int Mode, int TriStorageOrder>
struct trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1, true> {
  static void kernel(Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
                     Index otherStride);
};

// Single-precision right-side solve: forwards to triSolve with B treated as row-major.
template <typename Index, int Mode, int TriStorageOrder>
EIGEN_DONT_INLINE void trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
    Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
    Index otherStride) {
  EIGEN_UNUSED_VARIABLE(otherIncr);
#ifdef EIGEN_RUNTIME_NO_MALLOC
  if (!is_malloc_allowed()) {
    // Heap use is currently forbidden: fall back to the unspecialized kernel.
    trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
        size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
    return;
  }
#endif
  // A's storage order is flipped relative to TriStorageOrder; an upper-triangular A is solved forward.
  constexpr bool triIsRowMajor = TriStorageOrder != RowMajor;
  constexpr bool forwardSolve = (Mode & Lower) != Lower;
  constexpr bool unitDiagonal = (Mode & UnitDiag) != 0;
  triSolve<float, triIsRowMajor, true, forwardSolve, unitDiagonal>(const_cast<float *>(_tri), _other, size,
                                                                   otherSize, triStride, otherStride);
}

// Double-precision right-side solve: forwards to triSolve with B treated as row-major.
template <typename Index, int Mode, int TriStorageOrder>
EIGEN_DONT_INLINE void trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
    Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
    Index otherStride) {
  EIGEN_UNUSED_VARIABLE(otherIncr);
#ifdef EIGEN_RUNTIME_NO_MALLOC
  if (!is_malloc_allowed()) {
    // Heap use is currently forbidden: fall back to the unspecialized kernel.
    trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
        size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
    return;
  }
#endif
  // A's storage order is flipped relative to TriStorageOrder; an upper-triangular A is solved forward.
  constexpr bool triIsRowMajor = TriStorageOrder != RowMajor;
  constexpr bool forwardSolve = (Mode & Lower) != Lower;
  constexpr bool unitDiagonal = (Mode & UnitDiag) != 0;
  triSolve<double, triIsRowMajor, true, forwardSolve, unitDiagonal>(const_cast<double *>(_tri), _other, size,
                                                                    otherSize, triStride, otherStride);
}
#endif  // (EIGEN_USE_AVX512_TRSM_R_KERNELS)
1113
+
1114
// These trsm kernels require temporary memory allocation (triSolve is invoked with a column-major B,
// which makes it allocate a row-major scratch buffer).
#if (EIGEN_USE_AVX512_TRSM_L_KERNELS)
// Forward declaration of the left-side TRSM kernel template (Specialized defaults to true); this section
// specializes it for float/double with Conjugate == false and OtherInnerStride == 1.
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride,
          bool Specialized = true>
struct trsmKernelL;

template <typename Index, int Mode, int TriStorageOrder>
struct trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1, true> {
  static void kernel(Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
                     Index otherStride);
};

template <typename Index, int Mode, int TriStorageOrder>
struct trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1, true> {
  static void kernel(Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
                     Index otherStride);
};

// Single-precision left-side solve: forwards to triSolve with B treated as column-major.
template <typename Index, int Mode, int TriStorageOrder>
EIGEN_DONT_INLINE void trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
    Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
    Index otherStride) {
  EIGEN_UNUSED_VARIABLE(otherIncr);
#ifdef EIGEN_RUNTIME_NO_MALLOC
  if (!is_malloc_allowed()) {
    // Heap use is currently forbidden: fall back to the unspecialized kernel.
    trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
        size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
    return;
  }
#endif
  // A keeps its storage order; a lower-triangular A is solved forward.
  constexpr bool triIsRowMajor = TriStorageOrder == RowMajor;
  constexpr bool forwardSolve = (Mode & Lower) == Lower;
  constexpr bool unitDiagonal = (Mode & UnitDiag) != 0;
  triSolve<float, triIsRowMajor, false, forwardSolve, unitDiagonal>(const_cast<float *>(_tri), _other, size,
                                                                    otherSize, triStride, otherStride);
}

// Double-precision left-side solve: forwards to triSolve with B treated as column-major.
template <typename Index, int Mode, int TriStorageOrder>
EIGEN_DONT_INLINE void trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
    Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
    Index otherStride) {
  EIGEN_UNUSED_VARIABLE(otherIncr);
#ifdef EIGEN_RUNTIME_NO_MALLOC
  if (!is_malloc_allowed()) {
    // Heap use is currently forbidden: fall back to the unspecialized kernel.
    trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
        size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
    return;
  }
#endif
  // A keeps its storage order; a lower-triangular A is solved forward.
  constexpr bool triIsRowMajor = TriStorageOrder == RowMajor;
  constexpr bool forwardSolve = (Mode & Lower) == Lower;
  constexpr bool unitDiagonal = (Mode & UnitDiag) != 0;
  triSolve<double, triIsRowMajor, false, forwardSolve, unitDiagonal>(const_cast<double *>(_tri), _other, size,
                                                                     otherSize, triStride, otherStride);
}
#endif  // EIGEN_USE_AVX512_TRSM_L_KERNELS
1164
+ #endif // EIGEN_USE_AVX512_TRSM_KERNELS
1165
+ } // namespace internal
1166
+ } // namespace Eigen
1167
+ #endif // EIGEN_CORE_ARCH_AVX512_TRSM_KERNEL_H