@smake/eigen 1.1.0 → 1.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (431)
  1. package/README.md +1 -1
  2. package/eigen/Eigen/AccelerateSupport +52 -0
  3. package/eigen/Eigen/Cholesky +18 -20
  4. package/eigen/Eigen/CholmodSupport +28 -28
  5. package/eigen/Eigen/Core +187 -120
  6. package/eigen/Eigen/Eigenvalues +16 -13
  7. package/eigen/Eigen/Geometry +18 -18
  8. package/eigen/Eigen/Householder +9 -7
  9. package/eigen/Eigen/IterativeLinearSolvers +8 -4
  10. package/eigen/Eigen/Jacobi +14 -13
  11. package/eigen/Eigen/KLUSupport +23 -21
  12. package/eigen/Eigen/LU +15 -16
  13. package/eigen/Eigen/MetisSupport +12 -12
  14. package/eigen/Eigen/OrderingMethods +54 -51
  15. package/eigen/Eigen/PaStiXSupport +23 -21
  16. package/eigen/Eigen/PardisoSupport +17 -14
  17. package/eigen/Eigen/QR +18 -20
  18. package/eigen/Eigen/QtAlignedMalloc +5 -12
  19. package/eigen/Eigen/SPQRSupport +21 -14
  20. package/eigen/Eigen/SVD +23 -17
  21. package/eigen/Eigen/Sparse +1 -2
  22. package/eigen/Eigen/SparseCholesky +18 -15
  23. package/eigen/Eigen/SparseCore +18 -17
  24. package/eigen/Eigen/SparseLU +9 -9
  25. package/eigen/Eigen/SparseQR +16 -14
  26. package/eigen/Eigen/StdDeque +5 -2
  27. package/eigen/Eigen/StdList +5 -2
  28. package/eigen/Eigen/StdVector +5 -2
  29. package/eigen/Eigen/SuperLUSupport +30 -24
  30. package/eigen/Eigen/ThreadPool +80 -0
  31. package/eigen/Eigen/UmfPackSupport +19 -17
  32. package/eigen/Eigen/Version +14 -0
  33. package/eigen/Eigen/src/AccelerateSupport/AccelerateSupport.h +423 -0
  34. package/eigen/Eigen/src/AccelerateSupport/InternalHeaderCheck.h +3 -0
  35. package/eigen/Eigen/src/Cholesky/InternalHeaderCheck.h +3 -0
  36. package/eigen/Eigen/src/Cholesky/LDLT.h +366 -405
  37. package/eigen/Eigen/src/Cholesky/LLT.h +323 -367
  38. package/eigen/Eigen/src/Cholesky/LLT_LAPACKE.h +81 -56
  39. package/eigen/Eigen/src/CholmodSupport/CholmodSupport.h +585 -529
  40. package/eigen/Eigen/src/CholmodSupport/InternalHeaderCheck.h +3 -0
  41. package/eigen/Eigen/src/Core/ArithmeticSequence.h +143 -317
  42. package/eigen/Eigen/src/Core/Array.h +329 -370
  43. package/eigen/Eigen/src/Core/ArrayBase.h +190 -203
  44. package/eigen/Eigen/src/Core/ArrayWrapper.h +126 -170
  45. package/eigen/Eigen/src/Core/Assign.h +30 -40
  46. package/eigen/Eigen/src/Core/AssignEvaluator.h +651 -604
  47. package/eigen/Eigen/src/Core/Assign_MKL.h +125 -120
  48. package/eigen/Eigen/src/Core/BandMatrix.h +267 -282
  49. package/eigen/Eigen/src/Core/Block.h +371 -390
  50. package/eigen/Eigen/src/Core/CommaInitializer.h +85 -100
  51. package/eigen/Eigen/src/Core/ConditionEstimator.h +51 -53
  52. package/eigen/Eigen/src/Core/CoreEvaluators.h +1214 -937
  53. package/eigen/Eigen/src/Core/CoreIterators.h +72 -63
  54. package/eigen/Eigen/src/Core/CwiseBinaryOp.h +112 -129
  55. package/eigen/Eigen/src/Core/CwiseNullaryOp.h +676 -702
  56. package/eigen/Eigen/src/Core/CwiseTernaryOp.h +77 -103
  57. package/eigen/Eigen/src/Core/CwiseUnaryOp.h +55 -67
  58. package/eigen/Eigen/src/Core/CwiseUnaryView.h +127 -92
  59. package/eigen/Eigen/src/Core/DenseBase.h +630 -658
  60. package/eigen/Eigen/src/Core/DenseCoeffsBase.h +511 -628
  61. package/eigen/Eigen/src/Core/DenseStorage.h +511 -590
  62. package/eigen/Eigen/src/Core/DeviceWrapper.h +153 -0
  63. package/eigen/Eigen/src/Core/Diagonal.h +168 -207
  64. package/eigen/Eigen/src/Core/DiagonalMatrix.h +346 -317
  65. package/eigen/Eigen/src/Core/DiagonalProduct.h +12 -10
  66. package/eigen/Eigen/src/Core/Dot.h +167 -217
  67. package/eigen/Eigen/src/Core/EigenBase.h +74 -85
  68. package/eigen/Eigen/src/Core/Fill.h +138 -0
  69. package/eigen/Eigen/src/Core/FindCoeff.h +464 -0
  70. package/eigen/Eigen/src/Core/ForceAlignedAccess.h +90 -113
  71. package/eigen/Eigen/src/Core/Fuzzy.h +82 -105
  72. package/eigen/Eigen/src/Core/GeneralProduct.h +315 -261
  73. package/eigen/Eigen/src/Core/GenericPacketMath.h +1182 -520
  74. package/eigen/Eigen/src/Core/GlobalFunctions.h +193 -157
  75. package/eigen/Eigen/src/Core/IO.h +131 -156
  76. package/eigen/Eigen/src/Core/IndexedView.h +209 -125
  77. package/eigen/Eigen/src/Core/InnerProduct.h +260 -0
  78. package/eigen/Eigen/src/Core/InternalHeaderCheck.h +3 -0
  79. package/eigen/Eigen/src/Core/Inverse.h +50 -59
  80. package/eigen/Eigen/src/Core/Map.h +123 -141
  81. package/eigen/Eigen/src/Core/MapBase.h +255 -282
  82. package/eigen/Eigen/src/Core/MathFunctions.h +1247 -1201
  83. package/eigen/Eigen/src/Core/MathFunctionsImpl.h +162 -99
  84. package/eigen/Eigen/src/Core/Matrix.h +463 -494
  85. package/eigen/Eigen/src/Core/MatrixBase.h +468 -470
  86. package/eigen/Eigen/src/Core/NestByValue.h +58 -52
  87. package/eigen/Eigen/src/Core/NoAlias.h +79 -86
  88. package/eigen/Eigen/src/Core/NumTraits.h +206 -206
  89. package/eigen/Eigen/src/Core/PartialReduxEvaluator.h +163 -142
  90. package/eigen/Eigen/src/Core/PermutationMatrix.h +461 -511
  91. package/eigen/Eigen/src/Core/PlainObjectBase.h +858 -972
  92. package/eigen/Eigen/src/Core/Product.h +246 -130
  93. package/eigen/Eigen/src/Core/ProductEvaluators.h +779 -671
  94. package/eigen/Eigen/src/Core/Random.h +153 -164
  95. package/eigen/Eigen/src/Core/RandomImpl.h +262 -0
  96. package/eigen/Eigen/src/Core/RealView.h +250 -0
  97. package/eigen/Eigen/src/Core/Redux.h +334 -314
  98. package/eigen/Eigen/src/Core/Ref.h +259 -257
  99. package/eigen/Eigen/src/Core/Replicate.h +92 -104
  100. package/eigen/Eigen/src/Core/Reshaped.h +215 -271
  101. package/eigen/Eigen/src/Core/ReturnByValue.h +47 -55
  102. package/eigen/Eigen/src/Core/Reverse.h +133 -148
  103. package/eigen/Eigen/src/Core/Select.h +68 -140
  104. package/eigen/Eigen/src/Core/SelfAdjointView.h +254 -290
  105. package/eigen/Eigen/src/Core/SelfCwiseBinaryOp.h +23 -20
  106. package/eigen/Eigen/src/Core/SkewSymmetricMatrix3.h +382 -0
  107. package/eigen/Eigen/src/Core/Solve.h +88 -102
  108. package/eigen/Eigen/src/Core/SolveTriangular.h +126 -124
  109. package/eigen/Eigen/src/Core/SolverBase.h +132 -133
  110. package/eigen/Eigen/src/Core/StableNorm.h +113 -147
  111. package/eigen/Eigen/src/Core/StlIterators.h +404 -248
  112. package/eigen/Eigen/src/Core/Stride.h +90 -92
  113. package/eigen/Eigen/src/Core/Swap.h +70 -39
  114. package/eigen/Eigen/src/Core/Transpose.h +258 -295
  115. package/eigen/Eigen/src/Core/Transpositions.h +270 -333
  116. package/eigen/Eigen/src/Core/TriangularMatrix.h +642 -743
  117. package/eigen/Eigen/src/Core/VectorBlock.h +59 -72
  118. package/eigen/Eigen/src/Core/VectorwiseOp.h +653 -704
  119. package/eigen/Eigen/src/Core/Visitor.h +464 -308
  120. package/eigen/Eigen/src/Core/arch/AVX/Complex.h +380 -187
  121. package/eigen/Eigen/src/Core/arch/AVX/MathFunctions.h +65 -163
  122. package/eigen/Eigen/src/Core/arch/AVX/PacketMath.h +2145 -638
  123. package/eigen/Eigen/src/Core/arch/AVX/Reductions.h +353 -0
  124. package/eigen/Eigen/src/Core/arch/AVX/TypeCasting.h +253 -60
  125. package/eigen/Eigen/src/Core/arch/AVX512/Complex.h +278 -228
  126. package/eigen/Eigen/src/Core/arch/AVX512/GemmKernel.h +1245 -0
  127. package/eigen/Eigen/src/Core/arch/AVX512/MathFunctions.h +48 -269
  128. package/eigen/Eigen/src/Core/arch/AVX512/MathFunctionsFP16.h +75 -0
  129. package/eigen/Eigen/src/Core/arch/AVX512/PacketMath.h +1597 -754
  130. package/eigen/Eigen/src/Core/arch/AVX512/PacketMathFP16.h +1413 -0
  131. package/eigen/Eigen/src/Core/arch/AVX512/Reductions.h +297 -0
  132. package/eigen/Eigen/src/Core/arch/AVX512/TrsmKernel.h +1167 -0
  133. package/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc +1219 -0
  134. package/eigen/Eigen/src/Core/arch/AVX512/TypeCasting.h +229 -41
  135. package/eigen/Eigen/src/Core/arch/AVX512/TypeCastingFP16.h +130 -0
  136. package/eigen/Eigen/src/Core/arch/AltiVec/Complex.h +420 -184
  137. package/eigen/Eigen/src/Core/arch/AltiVec/MathFunctions.h +40 -49
  138. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProduct.h +2962 -2213
  139. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h +196 -212
  140. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h +713 -441
  141. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h +742 -0
  142. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.inc +2818 -0
  143. package/eigen/Eigen/src/Core/arch/AltiVec/PacketMath.h +2380 -1362
  144. package/eigen/Eigen/src/Core/arch/AltiVec/TypeCasting.h +153 -0
  145. package/eigen/Eigen/src/Core/arch/Default/BFloat16.h +390 -224
  146. package/eigen/Eigen/src/Core/arch/Default/ConjHelper.h +78 -67
  147. package/eigen/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +1784 -799
  148. package/eigen/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h +167 -50
  149. package/eigen/Eigen/src/Core/arch/Default/Half.h +528 -379
  150. package/eigen/Eigen/src/Core/arch/Default/Settings.h +10 -12
  151. package/eigen/Eigen/src/Core/arch/GPU/Complex.h +244 -0
  152. package/eigen/Eigen/src/Core/arch/GPU/MathFunctions.h +41 -40
  153. package/eigen/Eigen/src/Core/arch/GPU/PacketMath.h +550 -523
  154. package/eigen/Eigen/src/Core/arch/GPU/Tuple.h +268 -0
  155. package/eigen/Eigen/src/Core/arch/GPU/TypeCasting.h +27 -30
  156. package/eigen/Eigen/src/Core/arch/HIP/hcc/math_constants.h +8 -8
  157. package/eigen/Eigen/src/Core/arch/HVX/PacketMath.h +1088 -0
  158. package/eigen/Eigen/src/Core/arch/LSX/Complex.h +520 -0
  159. package/eigen/Eigen/src/Core/arch/LSX/GeneralBlockPanelKernel.h +23 -0
  160. package/eigen/Eigen/src/Core/arch/LSX/MathFunctions.h +43 -0
  161. package/eigen/Eigen/src/Core/arch/LSX/PacketMath.h +2866 -0
  162. package/eigen/Eigen/src/Core/arch/LSX/TypeCasting.h +526 -0
  163. package/eigen/Eigen/src/Core/arch/MSA/Complex.h +54 -82
  164. package/eigen/Eigen/src/Core/arch/MSA/MathFunctions.h +84 -92
  165. package/eigen/Eigen/src/Core/arch/MSA/PacketMath.h +51 -47
  166. package/eigen/Eigen/src/Core/arch/NEON/Complex.h +454 -306
  167. package/eigen/Eigen/src/Core/arch/NEON/GeneralBlockPanelKernel.h +175 -115
  168. package/eigen/Eigen/src/Core/arch/NEON/MathFunctions.h +23 -30
  169. package/eigen/Eigen/src/Core/arch/NEON/PacketMath.h +4366 -2857
  170. package/eigen/Eigen/src/Core/arch/NEON/TypeCasting.h +616 -393
  171. package/eigen/Eigen/src/Core/arch/NEON/UnaryFunctors.h +57 -0
  172. package/eigen/Eigen/src/Core/arch/SSE/Complex.h +350 -198
  173. package/eigen/Eigen/src/Core/arch/SSE/MathFunctions.h +38 -149
  174. package/eigen/Eigen/src/Core/arch/SSE/PacketMath.h +1791 -912
  175. package/eigen/Eigen/src/Core/arch/SSE/Reductions.h +324 -0
  176. package/eigen/Eigen/src/Core/arch/SSE/TypeCasting.h +128 -40
  177. package/eigen/Eigen/src/Core/arch/SVE/MathFunctions.h +10 -6
  178. package/eigen/Eigen/src/Core/arch/SVE/PacketMath.h +156 -234
  179. package/eigen/Eigen/src/Core/arch/SVE/TypeCasting.h +6 -3
  180. package/eigen/Eigen/src/Core/arch/SYCL/InteropHeaders.h +27 -32
  181. package/eigen/Eigen/src/Core/arch/SYCL/MathFunctions.h +119 -117
  182. package/eigen/Eigen/src/Core/arch/SYCL/PacketMath.h +325 -419
  183. package/eigen/Eigen/src/Core/arch/SYCL/TypeCasting.h +15 -17
  184. package/eigen/Eigen/src/Core/arch/ZVector/Complex.h +325 -181
  185. package/eigen/Eigen/src/Core/arch/ZVector/MathFunctions.h +94 -83
  186. package/eigen/Eigen/src/Core/arch/ZVector/PacketMath.h +811 -458
  187. package/eigen/Eigen/src/Core/functors/AssignmentFunctors.h +121 -124
  188. package/eigen/Eigen/src/Core/functors/BinaryFunctors.h +576 -370
  189. package/eigen/Eigen/src/Core/functors/NullaryFunctors.h +194 -109
  190. package/eigen/Eigen/src/Core/functors/StlFunctors.h +95 -112
  191. package/eigen/Eigen/src/Core/functors/TernaryFunctors.h +34 -7
  192. package/eigen/Eigen/src/Core/functors/UnaryFunctors.h +1038 -749
  193. package/eigen/Eigen/src/Core/products/GeneralBlockPanelKernel.h +1883 -1375
  194. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrix.h +312 -370
  195. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h +189 -176
  196. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h +84 -81
  197. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h +154 -73
  198. package/eigen/Eigen/src/Core/products/GeneralMatrixVector.h +292 -337
  199. package/eigen/Eigen/src/Core/products/GeneralMatrixVector_BLAS.h +80 -77
  200. package/eigen/Eigen/src/Core/products/Parallelizer.h +207 -105
  201. package/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix.h +327 -388
  202. package/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix_BLAS.h +206 -224
  203. package/eigen/Eigen/src/Core/products/SelfadjointMatrixVector.h +138 -147
  204. package/eigen/Eigen/src/Core/products/SelfadjointMatrixVector_BLAS.h +58 -61
  205. package/eigen/Eigen/src/Core/products/SelfadjointProduct.h +71 -71
  206. package/eigen/Eigen/src/Core/products/SelfadjointRank2Update.h +48 -47
  207. package/eigen/Eigen/src/Core/products/TriangularMatrixMatrix.h +294 -369
  208. package/eigen/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h +246 -238
  209. package/eigen/Eigen/src/Core/products/TriangularMatrixVector.h +244 -247
  210. package/eigen/Eigen/src/Core/products/TriangularMatrixVector_BLAS.h +212 -192
  211. package/eigen/Eigen/src/Core/products/TriangularSolverMatrix.h +328 -277
  212. package/eigen/Eigen/src/Core/products/TriangularSolverMatrix_BLAS.h +108 -109
  213. package/eigen/Eigen/src/Core/products/TriangularSolverVector.h +68 -94
  214. package/eigen/Eigen/src/Core/util/Assert.h +158 -0
  215. package/eigen/Eigen/src/Core/util/BlasUtil.h +342 -303
  216. package/eigen/Eigen/src/Core/util/ConfigureVectorization.h +348 -317
  217. package/eigen/Eigen/src/Core/util/Constants.h +297 -262
  218. package/eigen/Eigen/src/Core/util/DisableStupidWarnings.h +130 -90
  219. package/eigen/Eigen/src/Core/util/EmulateArray.h +270 -0
  220. package/eigen/Eigen/src/Core/util/ForwardDeclarations.h +449 -247
  221. package/eigen/Eigen/src/Core/util/GpuHipCudaDefines.inc +101 -0
  222. package/eigen/Eigen/src/Core/util/GpuHipCudaUndefines.inc +45 -0
  223. package/eigen/Eigen/src/Core/util/IndexedViewHelper.h +417 -116
  224. package/eigen/Eigen/src/Core/util/IntegralConstant.h +211 -204
  225. package/eigen/Eigen/src/Core/util/MKL_support.h +39 -37
  226. package/eigen/Eigen/src/Core/util/Macros.h +655 -773
  227. package/eigen/Eigen/src/Core/util/MaxSizeVector.h +139 -0
  228. package/eigen/Eigen/src/Core/util/Memory.h +970 -748
  229. package/eigen/Eigen/src/Core/util/Meta.h +581 -633
  230. package/eigen/Eigen/src/Core/util/MoreMeta.h +638 -0
  231. package/eigen/Eigen/src/Core/util/ReenableStupidWarnings.h +32 -19
  232. package/eigen/Eigen/src/Core/util/ReshapedHelper.h +17 -17
  233. package/eigen/Eigen/src/Core/util/Serializer.h +209 -0
  234. package/eigen/Eigen/src/Core/util/StaticAssert.h +50 -166
  235. package/eigen/Eigen/src/Core/util/SymbolicIndex.h +377 -225
  236. package/eigen/Eigen/src/Core/util/XprHelper.h +784 -547
  237. package/eigen/Eigen/src/Eigenvalues/ComplexEigenSolver.h +246 -277
  238. package/eigen/Eigen/src/Eigenvalues/ComplexSchur.h +299 -319
  239. package/eigen/Eigen/src/Eigenvalues/ComplexSchur_LAPACKE.h +52 -48
  240. package/eigen/Eigen/src/Eigenvalues/EigenSolver.h +413 -456
  241. package/eigen/Eigen/src/Eigenvalues/GeneralizedEigenSolver.h +309 -325
  242. package/eigen/Eigen/src/Eigenvalues/GeneralizedSelfAdjointEigenSolver.h +157 -171
  243. package/eigen/Eigen/src/Eigenvalues/HessenbergDecomposition.h +292 -310
  244. package/eigen/Eigen/src/Eigenvalues/InternalHeaderCheck.h +3 -0
  245. package/eigen/Eigen/src/Eigenvalues/MatrixBaseEigenvalues.h +89 -105
  246. package/eigen/Eigen/src/Eigenvalues/RealQZ.h +537 -607
  247. package/eigen/Eigen/src/Eigenvalues/RealSchur.h +342 -381
  248. package/eigen/Eigen/src/Eigenvalues/RealSchur_LAPACKE.h +41 -35
  249. package/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver.h +541 -595
  250. package/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver_LAPACKE.h +47 -44
  251. package/eigen/Eigen/src/Eigenvalues/Tridiagonalization.h +430 -462
  252. package/eigen/Eigen/src/Geometry/AlignedBox.h +226 -227
  253. package/eigen/Eigen/src/Geometry/AngleAxis.h +131 -133
  254. package/eigen/Eigen/src/Geometry/EulerAngles.h +163 -74
  255. package/eigen/Eigen/src/Geometry/Homogeneous.h +285 -333
  256. package/eigen/Eigen/src/Geometry/Hyperplane.h +151 -160
  257. package/eigen/Eigen/src/Geometry/InternalHeaderCheck.h +3 -0
  258. package/eigen/Eigen/src/Geometry/OrthoMethods.h +168 -146
  259. package/eigen/Eigen/src/Geometry/ParametrizedLine.h +127 -127
  260. package/eigen/Eigen/src/Geometry/Quaternion.h +566 -506
  261. package/eigen/Eigen/src/Geometry/Rotation2D.h +107 -105
  262. package/eigen/Eigen/src/Geometry/RotationBase.h +148 -145
  263. package/eigen/Eigen/src/Geometry/Scaling.h +113 -106
  264. package/eigen/Eigen/src/Geometry/Transform.h +858 -936
  265. package/eigen/Eigen/src/Geometry/Translation.h +94 -92
  266. package/eigen/Eigen/src/Geometry/Umeyama.h +79 -84
  267. package/eigen/Eigen/src/Geometry/arch/Geometry_SIMD.h +90 -104
  268. package/eigen/Eigen/src/Householder/BlockHouseholder.h +51 -46
  269. package/eigen/Eigen/src/Householder/Householder.h +102 -124
  270. package/eigen/Eigen/src/Householder/HouseholderSequence.h +412 -453
  271. package/eigen/Eigen/src/Householder/InternalHeaderCheck.h +3 -0
  272. package/eigen/Eigen/src/IterativeLinearSolvers/BasicPreconditioners.h +149 -162
  273. package/eigen/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h +124 -119
  274. package/eigen/Eigen/src/IterativeLinearSolvers/ConjugateGradient.h +92 -104
  275. package/eigen/Eigen/src/IterativeLinearSolvers/IncompleteCholesky.h +251 -243
  276. package/eigen/Eigen/src/IterativeLinearSolvers/IncompleteLUT.h +224 -228
  277. package/eigen/Eigen/src/IterativeLinearSolvers/InternalHeaderCheck.h +3 -0
  278. package/eigen/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h +178 -227
  279. package/eigen/Eigen/src/IterativeLinearSolvers/LeastSquareConjugateGradient.h +79 -84
  280. package/eigen/Eigen/src/IterativeLinearSolvers/SolveWithGuess.h +54 -60
  281. package/eigen/Eigen/src/Jacobi/InternalHeaderCheck.h +3 -0
  282. package/eigen/Eigen/src/Jacobi/Jacobi.h +252 -308
  283. package/eigen/Eigen/src/KLUSupport/InternalHeaderCheck.h +3 -0
  284. package/eigen/Eigen/src/KLUSupport/KLUSupport.h +208 -227
  285. package/eigen/Eigen/src/LU/Determinant.h +50 -69
  286. package/eigen/Eigen/src/LU/FullPivLU.h +545 -596
  287. package/eigen/Eigen/src/LU/InternalHeaderCheck.h +3 -0
  288. package/eigen/Eigen/src/LU/InverseImpl.h +206 -285
  289. package/eigen/Eigen/src/LU/PartialPivLU.h +390 -428
  290. package/eigen/Eigen/src/LU/PartialPivLU_LAPACKE.h +54 -40
  291. package/eigen/Eigen/src/LU/arch/InverseSize4.h +72 -70
  292. package/eigen/Eigen/src/MetisSupport/InternalHeaderCheck.h +3 -0
  293. package/eigen/Eigen/src/MetisSupport/MetisSupport.h +81 -93
  294. package/eigen/Eigen/src/OrderingMethods/Amd.h +243 -265
  295. package/eigen/Eigen/src/OrderingMethods/Eigen_Colamd.h +831 -1004
  296. package/eigen/Eigen/src/OrderingMethods/InternalHeaderCheck.h +3 -0
  297. package/eigen/Eigen/src/OrderingMethods/Ordering.h +112 -119
  298. package/eigen/Eigen/src/PaStiXSupport/InternalHeaderCheck.h +3 -0
  299. package/eigen/Eigen/src/PaStiXSupport/PaStiXSupport.h +524 -570
  300. package/eigen/Eigen/src/PardisoSupport/InternalHeaderCheck.h +3 -0
  301. package/eigen/Eigen/src/PardisoSupport/PardisoSupport.h +385 -430
  302. package/eigen/Eigen/src/QR/ColPivHouseholderQR.h +479 -479
  303. package/eigen/Eigen/src/QR/ColPivHouseholderQR_LAPACKE.h +120 -56
  304. package/eigen/Eigen/src/QR/CompleteOrthogonalDecomposition.h +166 -153
  305. package/eigen/Eigen/src/QR/FullPivHouseholderQR.h +495 -475
  306. package/eigen/Eigen/src/QR/HouseholderQR.h +394 -285
  307. package/eigen/Eigen/src/QR/HouseholderQR_LAPACKE.h +32 -23
  308. package/eigen/Eigen/src/QR/InternalHeaderCheck.h +3 -0
  309. package/eigen/Eigen/src/SPQRSupport/InternalHeaderCheck.h +3 -0
  310. package/eigen/Eigen/src/SPQRSupport/SuiteSparseQRSupport.h +244 -264
  311. package/eigen/Eigen/src/SVD/BDCSVD.h +817 -713
  312. package/eigen/Eigen/src/SVD/BDCSVD_LAPACKE.h +174 -0
  313. package/eigen/Eigen/src/SVD/InternalHeaderCheck.h +3 -0
  314. package/eigen/Eigen/src/SVD/JacobiSVD.h +577 -543
  315. package/eigen/Eigen/src/SVD/JacobiSVD_LAPACKE.h +85 -49
  316. package/eigen/Eigen/src/SVD/SVDBase.h +242 -182
  317. package/eigen/Eigen/src/SVD/UpperBidiagonalization.h +200 -235
  318. package/eigen/Eigen/src/SparseCholesky/InternalHeaderCheck.h +3 -0
  319. package/eigen/Eigen/src/SparseCholesky/SimplicialCholesky.h +765 -594
  320. package/eigen/Eigen/src/SparseCholesky/SimplicialCholesky_impl.h +308 -94
  321. package/eigen/Eigen/src/SparseCore/AmbiVector.h +202 -251
  322. package/eigen/Eigen/src/SparseCore/CompressedStorage.h +184 -252
  323. package/eigen/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h +134 -178
  324. package/eigen/Eigen/src/SparseCore/InternalHeaderCheck.h +3 -0
  325. package/eigen/Eigen/src/SparseCore/SparseAssign.h +149 -140
  326. package/eigen/Eigen/src/SparseCore/SparseBlock.h +403 -440
  327. package/eigen/Eigen/src/SparseCore/SparseColEtree.h +100 -112
  328. package/eigen/Eigen/src/SparseCore/SparseCompressedBase.h +525 -303
  329. package/eigen/Eigen/src/SparseCore/SparseCwiseBinaryOp.h +555 -339
  330. package/eigen/Eigen/src/SparseCore/SparseCwiseUnaryOp.h +100 -108
  331. package/eigen/Eigen/src/SparseCore/SparseDenseProduct.h +169 -197
  332. package/eigen/Eigen/src/SparseCore/SparseDiagonalProduct.h +71 -71
  333. package/eigen/Eigen/src/SparseCore/SparseDot.h +49 -47
  334. package/eigen/Eigen/src/SparseCore/SparseFuzzy.h +13 -11
  335. package/eigen/Eigen/src/SparseCore/SparseMap.h +243 -253
  336. package/eigen/Eigen/src/SparseCore/SparseMatrix.h +1603 -1245
  337. package/eigen/Eigen/src/SparseCore/SparseMatrixBase.h +403 -350
  338. package/eigen/Eigen/src/SparseCore/SparsePermutation.h +186 -115
  339. package/eigen/Eigen/src/SparseCore/SparseProduct.h +94 -97
  340. package/eigen/Eigen/src/SparseCore/SparseRedux.h +22 -24
  341. package/eigen/Eigen/src/SparseCore/SparseRef.h +268 -295
  342. package/eigen/Eigen/src/SparseCore/SparseSelfAdjointView.h +370 -416
  343. package/eigen/Eigen/src/SparseCore/SparseSolverBase.h +78 -87
  344. package/eigen/Eigen/src/SparseCore/SparseSparseProductWithPruning.h +81 -95
  345. package/eigen/Eigen/src/SparseCore/SparseTranspose.h +62 -71
  346. package/eigen/Eigen/src/SparseCore/SparseTriangularView.h +132 -144
  347. package/eigen/Eigen/src/SparseCore/SparseUtil.h +138 -115
  348. package/eigen/Eigen/src/SparseCore/SparseVector.h +426 -372
  349. package/eigen/Eigen/src/SparseCore/SparseView.h +164 -193
  350. package/eigen/Eigen/src/SparseCore/TriangularSolver.h +129 -170
  351. package/eigen/Eigen/src/SparseLU/InternalHeaderCheck.h +3 -0
  352. package/eigen/Eigen/src/SparseLU/SparseLU.h +756 -710
  353. package/eigen/Eigen/src/SparseLU/SparseLUImpl.h +61 -48
  354. package/eigen/Eigen/src/SparseLU/SparseLU_Memory.h +102 -118
  355. package/eigen/Eigen/src/SparseLU/SparseLU_Structs.h +38 -35
  356. package/eigen/Eigen/src/SparseLU/SparseLU_SupernodalMatrix.h +245 -301
  357. package/eigen/Eigen/src/SparseLU/SparseLU_Utils.h +44 -49
  358. package/eigen/Eigen/src/SparseLU/SparseLU_column_bmod.h +104 -108
  359. package/eigen/Eigen/src/SparseLU/SparseLU_column_dfs.h +89 -100
  360. package/eigen/Eigen/src/SparseLU/SparseLU_copy_to_ucol.h +57 -58
  361. package/eigen/Eigen/src/SparseLU/SparseLU_heap_relax_snode.h +43 -55
  362. package/eigen/Eigen/src/SparseLU/SparseLU_kernel_bmod.h +74 -71
  363. package/eigen/Eigen/src/SparseLU/SparseLU_panel_bmod.h +124 -132
  364. package/eigen/Eigen/src/SparseLU/SparseLU_panel_dfs.h +136 -159
  365. package/eigen/Eigen/src/SparseLU/SparseLU_pivotL.h +51 -52
  366. package/eigen/Eigen/src/SparseLU/SparseLU_pruneL.h +67 -73
  367. package/eigen/Eigen/src/SparseLU/SparseLU_relax_snode.h +24 -26
  368. package/eigen/Eigen/src/SparseQR/InternalHeaderCheck.h +3 -0
  369. package/eigen/Eigen/src/SparseQR/SparseQR.h +450 -502
  370. package/eigen/Eigen/src/StlSupport/StdDeque.h +28 -93
  371. package/eigen/Eigen/src/StlSupport/StdList.h +28 -84
  372. package/eigen/Eigen/src/StlSupport/StdVector.h +28 -108
  373. package/eigen/Eigen/src/StlSupport/details.h +48 -50
  374. package/eigen/Eigen/src/SuperLUSupport/InternalHeaderCheck.h +3 -0
  375. package/eigen/Eigen/src/SuperLUSupport/SuperLUSupport.h +634 -730
  376. package/eigen/Eigen/src/ThreadPool/Barrier.h +70 -0
  377. package/eigen/Eigen/src/ThreadPool/CoreThreadPoolDevice.h +336 -0
  378. package/eigen/Eigen/src/ThreadPool/EventCount.h +241 -0
  379. package/eigen/Eigen/src/ThreadPool/ForkJoin.h +140 -0
  380. package/eigen/Eigen/src/ThreadPool/InternalHeaderCheck.h +4 -0
  381. package/eigen/Eigen/src/ThreadPool/NonBlockingThreadPool.h +587 -0
  382. package/eigen/Eigen/src/ThreadPool/RunQueue.h +230 -0
  383. package/eigen/Eigen/src/ThreadPool/ThreadCancel.h +21 -0
  384. package/eigen/Eigen/src/ThreadPool/ThreadEnvironment.h +43 -0
  385. package/eigen/Eigen/src/ThreadPool/ThreadLocal.h +289 -0
  386. package/eigen/Eigen/src/ThreadPool/ThreadPoolInterface.h +50 -0
  387. package/eigen/Eigen/src/ThreadPool/ThreadYield.h +16 -0
  388. package/eigen/Eigen/src/UmfPackSupport/InternalHeaderCheck.h +3 -0
  389. package/eigen/Eigen/src/UmfPackSupport/UmfPackSupport.h +428 -464
  390. package/eigen/Eigen/src/misc/Image.h +41 -43
  391. package/eigen/Eigen/src/misc/InternalHeaderCheck.h +3 -0
  392. package/eigen/Eigen/src/misc/Kernel.h +39 -41
  393. package/eigen/Eigen/src/misc/RealSvd2x2.h +19 -21
  394. package/eigen/Eigen/src/misc/blas.h +83 -426
  395. package/eigen/Eigen/src/misc/lapacke.h +9972 -16179
  396. package/eigen/Eigen/src/misc/lapacke_helpers.h +163 -0
  397. package/eigen/Eigen/src/misc/lapacke_mangling.h +4 -5
  398. package/eigen/Eigen/src/plugins/ArrayCwiseBinaryOps.inc +344 -0
  399. package/eigen/Eigen/src/plugins/ArrayCwiseUnaryOps.inc +544 -0
  400. package/eigen/Eigen/src/plugins/{BlockMethods.h → BlockMethods.inc} +434 -506
  401. package/eigen/Eigen/src/plugins/CommonCwiseBinaryOps.inc +116 -0
  402. package/eigen/Eigen/src/plugins/{CommonCwiseUnaryOps.h → CommonCwiseUnaryOps.inc} +58 -68
  403. package/eigen/Eigen/src/plugins/IndexedViewMethods.inc +192 -0
  404. package/eigen/Eigen/src/plugins/InternalHeaderCheck.inc +3 -0
  405. package/eigen/Eigen/src/plugins/MatrixCwiseBinaryOps.inc +331 -0
  406. package/eigen/Eigen/src/plugins/MatrixCwiseUnaryOps.inc +118 -0
  407. package/eigen/Eigen/src/plugins/ReshapedMethods.inc +133 -0
  408. package/package.json +1 -1
  409. package/eigen/COPYING.APACHE +0 -203
  410. package/eigen/COPYING.BSD +0 -26
  411. package/eigen/COPYING.GPL +0 -674
  412. package/eigen/COPYING.LGPL +0 -502
  413. package/eigen/COPYING.MINPACK +0 -51
  414. package/eigen/COPYING.MPL2 +0 -373
  415. package/eigen/COPYING.README +0 -18
  416. package/eigen/Eigen/src/Core/BooleanRedux.h +0 -162
  417. package/eigen/Eigen/src/Core/arch/CUDA/Complex.h +0 -258
  418. package/eigen/Eigen/src/Core/arch/Default/TypeCasting.h +0 -120
  419. package/eigen/Eigen/src/Core/arch/SYCL/SyclMemoryModel.h +0 -694
  420. package/eigen/Eigen/src/Core/util/NonMPL2.h +0 -3
  421. package/eigen/Eigen/src/SparseCore/MappedSparseMatrix.h +0 -67
  422. package/eigen/Eigen/src/SparseLU/SparseLU_gemm_kernel.h +0 -280
  423. package/eigen/Eigen/src/misc/lapack.h +0 -152
  424. package/eigen/Eigen/src/plugins/ArrayCwiseBinaryOps.h +0 -358
  425. package/eigen/Eigen/src/plugins/ArrayCwiseUnaryOps.h +0 -696
  426. package/eigen/Eigen/src/plugins/CommonCwiseBinaryOps.h +0 -115
  427. package/eigen/Eigen/src/plugins/IndexedViewMethods.h +0 -262
  428. package/eigen/Eigen/src/plugins/MatrixCwiseBinaryOps.h +0 -152
  429. package/eigen/Eigen/src/plugins/MatrixCwiseUnaryOps.h +0 -95
  430. package/eigen/Eigen/src/plugins/ReshapedMethods.h +0 -149
  431. package/eigen/README.md +0 -5
@@ -0,0 +1,1167 @@
1
+ // This file is part of Eigen, a lightweight C++ template library
2
+ // for linear algebra.
3
+ //
4
+ // Copyright (C) 2022 Intel Corporation
5
+ //
6
+ // This Source Code Form is subject to the terms of the Mozilla
7
+ // Public License v. 2.0. If a copy of the MPL was not distributed
8
+ // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
+
10
+ #ifndef EIGEN_CORE_ARCH_AVX512_TRSM_KERNEL_H
11
+ #define EIGEN_CORE_ARCH_AVX512_TRSM_KERNEL_H
12
+
13
+ // IWYU pragma: private
14
+ #include "../../InternalHeaderCheck.h"
15
+
16
+ #if !defined(EIGEN_USE_AVX512_TRSM_KERNELS)
17
+ #define EIGEN_USE_AVX512_TRSM_KERNELS 1
18
+ #endif
19
+
20
+ // TRSM kernels currently unconditionally rely on malloc with AVX512.
21
+ // Disable them if malloc is explicitly disabled at compile-time.
22
+ #ifdef EIGEN_NO_MALLOC
23
+ #undef EIGEN_USE_AVX512_TRSM_KERNELS
24
+ #define EIGEN_USE_AVX512_TRSM_KERNELS 0
25
+ #endif
26
+
27
+ #if EIGEN_USE_AVX512_TRSM_KERNELS
28
+ #if !defined(EIGEN_USE_AVX512_TRSM_R_KERNELS)
29
+ #define EIGEN_USE_AVX512_TRSM_R_KERNELS 1
30
+ #endif
31
+ #if !defined(EIGEN_USE_AVX512_TRSM_L_KERNELS)
32
+ #define EIGEN_USE_AVX512_TRSM_L_KERNELS 1
33
+ #endif
34
+ #else // EIGEN_USE_AVX512_TRSM_KERNELS == 0
35
+ #define EIGEN_USE_AVX512_TRSM_R_KERNELS 0
36
+ #define EIGEN_USE_AVX512_TRSM_L_KERNELS 0
37
+ #endif
38
+
39
+ // Need this for some std::min calls.
40
+ #ifdef min
41
+ #undef min
42
+ #endif
43
+
44
+ namespace Eigen {
45
+ namespace internal {
46
+
47
+ #define EIGEN_AVX_MAX_NUM_ACC (int64_t(24))
48
+ #define EIGEN_AVX_MAX_NUM_ROW (int64_t(8)) // Denoted L in code.
49
+ #define EIGEN_AVX_MAX_K_UNROL (int64_t(4))
50
+ #define EIGEN_AVX_B_LOAD_SETS (int64_t(2))
51
+ #define EIGEN_AVX_MAX_A_BCAST (int64_t(2))
52
+ typedef Packet16f vecFullFloat;
53
+ typedef Packet8d vecFullDouble;
54
+ typedef Packet8f vecHalfFloat;
55
+ typedef Packet4d vecHalfDouble;
56
+
57
+ // Compile-time unrolls are implemented here.
58
+ // Note: this depends on macros and typedefs above.
59
+ #include "TrsmUnrolls.inc"
60
+
61
+ #if (EIGEN_USE_AVX512_TRSM_KERNELS) && (EIGEN_COMP_CLANG != 0)
62
+ /**
63
+ * For smaller problem sizes, and certain compilers, using the optimized kernels trsmKernelL/R directly
64
+ * is faster than the packed versions in TriangularSolverMatrix.h.
65
+ *
66
+ * The current heuristic is based on having all arrays used in the largest gemm-update
67
+ * in triSolve fit in roughly L2Cap (percentage) of the L2 cache. These cutoffs are a bit conservative and could be
68
+ * larger for some trsm cases.
69
+ * The formula:
70
+ *
71
+ * (L*M + M*N + L*N)*sizeof(Scalar) < L2Cache*L2Cap
72
+ *
73
+ * L = number of rows to solve at a time
74
+ * N = number of rhs
75
+ * M = Dimension of triangular matrix
76
+ *
77
+ */
78
+ #if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS)
79
+ #define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS 1
80
+ #endif
81
+
82
+ #if EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS
83
+
84
+ #if EIGEN_USE_AVX512_TRSM_R_KERNELS
85
+ #if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS)
86
+ #define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 1
87
+ #endif // !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS)
88
+ #endif
89
+
90
+ #if EIGEN_USE_AVX512_TRSM_L_KERNELS
91
+ #if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS)
92
+ #define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS 1
93
+ #endif
94
+ #endif // EIGEN_USE_AVX512_TRSM_L_KERNELS
95
+
96
+ #else // EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS == 0
97
+ #define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 0
98
+ #define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS 0
99
+ #endif // EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS
100
+
101
+ template <typename Scalar>
102
+ int64_t avx512_trsm_cutoff(int64_t L2Size, int64_t N, double L2Cap) {
103
+ const int64_t U3 = 3 * packet_traits<Scalar>::size;
104
+ const int64_t MaxNb = 5 * U3;
105
+ int64_t Nb = std::min(MaxNb, N);
106
+ double cutoff_d =
107
+ (((L2Size * L2Cap) / (sizeof(Scalar))) - (EIGEN_AVX_MAX_NUM_ROW)*Nb) / ((EIGEN_AVX_MAX_NUM_ROW) + Nb);
108
+ int64_t cutoff_l = static_cast<int64_t>(cutoff_d);
109
+ return (cutoff_l / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW;
110
+ }
111
+ #else // !(EIGEN_USE_AVX512_TRSM_KERNELS) || !(EIGEN_COMP_CLANG != 0)
112
+ #define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS 0
113
+ #define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 0
114
+ #define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS 0
115
+ #endif
116
+
117
+ /**
118
+ * Used by gemmKernel for the case A/B row-major and C col-major.
119
+ */
120
+ template <typename Scalar, typename vec, int64_t unrollM, int64_t unrollN, bool remM, bool remN>
121
+ EIGEN_ALWAYS_INLINE void transStoreC(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, Scalar *C_arr,
122
+ int64_t LDC, int64_t remM_ = 0, int64_t remN_ = 0) {
123
+ EIGEN_UNUSED_VARIABLE(remN_);
124
+ EIGEN_UNUSED_VARIABLE(remM_);
125
+ using urolls = unrolls::trans<Scalar>;
126
+
127
+ constexpr int64_t U3 = urolls::PacketSize * 3;
128
+ constexpr int64_t U2 = urolls::PacketSize * 2;
129
+ constexpr int64_t U1 = urolls::PacketSize * 1;
130
+
131
+ static_assert(unrollN == U1 || unrollN == U2 || unrollN == U3, "unrollN should be a multiple of PacketSize");
132
+ static_assert(unrollM == EIGEN_AVX_MAX_NUM_ROW, "unrollM should be equal to EIGEN_AVX_MAX_NUM_ROW");
133
+
134
+ urolls::template transpose<unrollN, 0>(zmm);
135
+ EIGEN_IF_CONSTEXPR(unrollN > U2) urolls::template transpose<unrollN, 2>(zmm);
136
+ EIGEN_IF_CONSTEXPR(unrollN > U1) urolls::template transpose<unrollN, 1>(zmm);
137
+
138
+ static_assert((remN && unrollN == U1) || !remN, "When handling N remainder set unrollN=U1");
139
+ EIGEN_IF_CONSTEXPR(!remN) {
140
+ urolls::template storeC<std::min(unrollN, U1), unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
141
+ EIGEN_IF_CONSTEXPR(unrollN > U1) {
142
+ constexpr int64_t unrollN_ = std::min(unrollN - U1, U1);
143
+ urolls::template storeC<unrollN_, unrollN, 1, remM>(C_arr + U1 * LDC, LDC, zmm, remM_);
144
+ }
145
+ EIGEN_IF_CONSTEXPR(unrollN > U2) {
146
+ constexpr int64_t unrollN_ = std::min(unrollN - U2, U1);
147
+ urolls::template storeC<unrollN_, unrollN, 2, remM>(C_arr + U2 * LDC, LDC, zmm, remM_);
148
+ }
149
+ }
150
+ else {
151
+ EIGEN_IF_CONSTEXPR((std::is_same<Scalar, float>::value)) {
152
+ // Note: without "if constexpr" this section of code will also be
153
+ // parsed by the compiler so each of the storeC will still be instantiated.
154
+ // We use enable_if in aux_storeC to set it to an empty function for
155
+ // these cases.
156
+ if (remN_ == 15)
157
+ urolls::template storeC<15, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
158
+ else if (remN_ == 14)
159
+ urolls::template storeC<14, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
160
+ else if (remN_ == 13)
161
+ urolls::template storeC<13, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
162
+ else if (remN_ == 12)
163
+ urolls::template storeC<12, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
164
+ else if (remN_ == 11)
165
+ urolls::template storeC<11, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
166
+ else if (remN_ == 10)
167
+ urolls::template storeC<10, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
168
+ else if (remN_ == 9)
169
+ urolls::template storeC<9, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
170
+ else if (remN_ == 8)
171
+ urolls::template storeC<8, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
172
+ else if (remN_ == 7)
173
+ urolls::template storeC<7, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
174
+ else if (remN_ == 6)
175
+ urolls::template storeC<6, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
176
+ else if (remN_ == 5)
177
+ urolls::template storeC<5, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
178
+ else if (remN_ == 4)
179
+ urolls::template storeC<4, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
180
+ else if (remN_ == 3)
181
+ urolls::template storeC<3, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
182
+ else if (remN_ == 2)
183
+ urolls::template storeC<2, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
184
+ else if (remN_ == 1)
185
+ urolls::template storeC<1, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
186
+ }
187
+ else {
188
+ if (remN_ == 7)
189
+ urolls::template storeC<7, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
190
+ else if (remN_ == 6)
191
+ urolls::template storeC<6, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
192
+ else if (remN_ == 5)
193
+ urolls::template storeC<5, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
194
+ else if (remN_ == 4)
195
+ urolls::template storeC<4, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
196
+ else if (remN_ == 3)
197
+ urolls::template storeC<3, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
198
+ else if (remN_ == 2)
199
+ urolls::template storeC<2, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
200
+ else if (remN_ == 1)
201
+ urolls::template storeC<1, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
202
+ }
203
+ }
204
+ }
205
+
206
+ /**
207
+ * GEMM like operation for trsm panel updates.
208
+ * Computes: C -= A*B
209
+ * K must be multiple of 4.
210
+ *
211
+ * Unrolls used are {1,2,4,8}x{U1,U2,U3};
212
+ * For good performance we want K to be large with M/N relatively small, but also large enough
213
+ * to use the {8,U3} unroll block.
214
+ *
215
+ * isARowMajor: is A_arr row-major?
216
+ * isCRowMajor: is C_arr row-major? (B_arr is assumed to be row-major).
217
+ * isAdd: C += A*B or C -= A*B (used by trsm)
218
+ * handleKRem: Handle arbitrary K? This is not needed for trsm.
219
+ */
220
// Tiled GEMM-like update used by the trsm panel updates (see the block
// comment above). C (+/-)= A*B, with isAdd choosing the sign. N is tiled
// by {U3, U2, U1, remainder} packet-widths and M by {8, 4, 2, 1} rows;
// K is unrolled by EIGEN_AVX_MAX_K_UNROL, with the leftover K iterations
// handled only when handleKRem is set (trsm guarantees K % 4 == 0).
template <typename Scalar, bool isARowMajor, bool isCRowMajor, bool isAdd, bool handleKRem>
void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr, int64_t M, int64_t N, int64_t K, int64_t LDA, int64_t LDB,
                int64_t LDC) {
  using urolls = unrolls::gemm<Scalar, isAdd>;
  // Column-tile widths: three, two, or one packet per accumulator row.
  constexpr int64_t U3 = urolls::PacketSize * 3;
  constexpr int64_t U2 = urolls::PacketSize * 2;
  constexpr int64_t U1 = urolls::PacketSize * 1;
  using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
  // Round N/M/K down to the largest multiple each main tile handles; the
  // leftovers fall through to the remainder blocks below.
  int64_t N_ = (N / U3) * U3;
  int64_t M_ = (M / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW;
  int64_t K_ = (K / EIGEN_AVX_MAX_K_UNROL) * EIGEN_AVX_MAX_K_UNROL;
  int64_t j = 0;
  // ---- Main N loop: full U3-wide column tiles. ----
  for (; j < N_; j += U3) {
    constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS * 3;
    int64_t i = 0;
    // Full 8-row M tiles.
    for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
      Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)], *B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<3, EIGEN_AVX_MAX_NUM_ROW>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 3, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
                                     EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        // Advance A by k-step; stride depends on A's storage order.
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        // Leftover K iterations, one k-step at a time.
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 3, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_B_LOAD_SETS * 3,
                                       EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        // Row-major C: accumulate into C then store directly.
        urolls::template updateC<3, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
        urolls::template storeC<3, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
      }
      else {
        // Col-major C: transpose accumulators in-register, then store.
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, false, false>(zmm, &C_arr[i + j * LDC], LDC);
      }
    }
    // M remainder: 4 rows.
    if (M - i >= 4) {  // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise
      Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar *B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<3, 4>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 3, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_B_LOAD_SETS * 3,
                                     EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 3, 4, 1, EIGEN_AVX_B_LOAD_SETS * 3, EIGEN_AVX_MAX_A_BCAST>(
              B_t, A_t, LDB, LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<3, 4>(&C_arr[i * LDC + j], LDC, zmm);
        urolls::template storeC<3, 4>(&C_arr[i * LDC + j], LDC, zmm);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, true, false>(zmm, &C_arr[i + j * LDC], LDC, 4);
      }
      i += 4;
    }
    // M remainder: 2 rows.
    if (M - i >= 2) {
      Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar *B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<3, 2>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 3, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_B_LOAD_SETS * 3,
                                     EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 3, 2, 1, EIGEN_AVX_B_LOAD_SETS * 3, EIGEN_AVX_MAX_A_BCAST>(
              B_t, A_t, LDB, LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<3, 2>(&C_arr[i * LDC + j], LDC, zmm);
        urolls::template storeC<3, 2>(&C_arr[i * LDC + j], LDC, zmm);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, true, false>(zmm, &C_arr[i + j * LDC], LDC, 2);
      }
      i += 2;
    }
    // M remainder: final single row. Only one A broadcast per step here.
    if (M - i > 0) {
      Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar *B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<3, 1>(zmm);
      {
        for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
          urolls::template microKernel<isARowMajor, 3, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_B_LOAD_SETS * 3, 1>(
              B_t, A_t, LDB, LDA, zmm);
          B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
          else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
        }
        EIGEN_IF_CONSTEXPR(handleKRem) {
          for (int64_t k = K_; k < K; k++) {
            urolls::template microKernel<isARowMajor, 3, 1, 1, EIGEN_AVX_B_LOAD_SETS * 3, 1>(B_t, A_t, LDB, LDA, zmm);
            B_t += LDB;
            EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
            else A_t += LDA;
          }
        }
        EIGEN_IF_CONSTEXPR(isCRowMajor) {
          urolls::template updateC<3, 1>(&C_arr[i * LDC + j], LDC, zmm);
          urolls::template storeC<3, 1>(&C_arr[i * LDC + j], LDC, zmm);
        }
        else {
          transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, true, false>(zmm, &C_arr[i + j * LDC], LDC, 1);
        }
      }
    }
  }
  // ---- N remainder: one U2-wide column tile. ----
  if (N - j >= U2) {
    constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS * 2;
    int64_t i = 0;
    // Full 8-row M tiles.
    for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
      Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)], *B_t = &B_arr[0 * LDB + j];
      // NOTE(review): B_t is re-assigned the same value here; looks redundant — confirm.
      EIGEN_IF_CONSTEXPR(isCRowMajor) B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<2, EIGEN_AVX_MAX_NUM_ROW>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 2, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
                                     EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 2, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_MAX_B_LOAD,
                                       EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<2, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
        urolls::template storeC<2, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, false, false>(zmm, &C_arr[i + j * LDC], LDC);
      }
    }
    // M remainder: 4 rows.
    if (M - i >= 4) {  // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise
      Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar *B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<2, 4>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 2, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
                                     EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 2, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
                                                                                                          LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<2, 4>(&C_arr[i * LDC + j], LDC, zmm);
        urolls::template storeC<2, 4>(&C_arr[i * LDC + j], LDC, zmm);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, true, false>(zmm, &C_arr[i + j * LDC], LDC, 4);
      }
      i += 4;
    }
    // M remainder: 2 rows.
    if (M - i >= 2) {
      Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar *B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<2, 2>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 2, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
                                     EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 2, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
                                                                                                          LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<2, 2>(&C_arr[i * LDC + j], LDC, zmm);
        urolls::template storeC<2, 2>(&C_arr[i * LDC + j], LDC, zmm);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, true, false>(zmm, &C_arr[i + j * LDC], LDC, 2);
      }
      i += 2;
    }
    // M remainder: final single row.
    if (M - i > 0) {
      Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar *B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<2, 1>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 2, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1>(B_t, A_t, LDB,
                                                                                                        LDA, zmm);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 2, 1, 1, EIGEN_AVX_MAX_B_LOAD, 1>(B_t, A_t, LDB, LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<2, 1>(&C_arr[i * LDC + j], LDC, zmm);
        urolls::template storeC<2, 1>(&C_arr[i * LDC + j], LDC, zmm);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, true, false>(zmm, &C_arr[i + j * LDC], LDC, 1);
      }
    }
    j += U2;
  }
  // ---- N remainder: one full U1-wide column tile. ----
  if (N - j >= U1) {
    constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS * 1;
    int64_t i = 0;
    // Full 8-row M tiles.
    for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
      Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)], *B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<1, EIGEN_AVX_MAX_NUM_ROW>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
                                     EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_B_LOAD_SETS * 1,
                                       EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<1, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
        urolls::template storeC<1, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, false, false>(zmm, &C_arr[i + j * LDC], LDC);
      }
    }
    // M remainder: 4 rows.
    if (M - i >= 4) {  // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise
      Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar *B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<1, 4>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 1, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
                                     EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 1, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
                                                                                                          LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<1, 4>(&C_arr[i * LDC + j], LDC, zmm);
        urolls::template storeC<1, 4>(&C_arr[i * LDC + j], LDC, zmm);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, false>(zmm, &C_arr[i + j * LDC], LDC, 4);
      }
      i += 4;
    }
    // M remainder: 2 rows.
    if (M - i >= 2) {
      Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar *B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<1, 2>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 1, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
                                     EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 1, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
                                                                                                          LDA, zmm);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<1, 2>(&C_arr[i * LDC + j], LDC, zmm);
        urolls::template storeC<1, 2>(&C_arr[i * LDC + j], LDC, zmm);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, false>(zmm, &C_arr[i + j * LDC], LDC, 2);
      }
      i += 2;
    }
    // M remainder: final single row.
    if (M - i > 0) {
      Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar *B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<1, 1>(zmm);
      {
        for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
          urolls::template microKernel<isARowMajor, 1, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1>(B_t, A_t, LDB,
                                                                                                          LDA, zmm);
          B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
          else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
        }
        EIGEN_IF_CONSTEXPR(handleKRem) {
          for (int64_t k = K_; k < K; k++) {
            urolls::template microKernel<isARowMajor, 1, 1, 1, EIGEN_AVX_B_LOAD_SETS * 1, 1>(B_t, A_t, LDB, LDA, zmm);
            B_t += LDB;
            EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
            else A_t += LDA;
          }
        }
        EIGEN_IF_CONSTEXPR(isCRowMajor) {
          urolls::template updateC<1, 1>(&C_arr[i * LDC + j], LDC, zmm);
          urolls::template storeC<1, 1>(&C_arr[i * LDC + j], LDC, zmm);
        }
        else {
          transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, false>(zmm, &C_arr[i + j * LDC], LDC, 1);
        }
      }
    }
    j += U1;
  }
  // ---- N remainder: fewer than U1 columns left; microKernel/updateC/storeC
  // take the runtime remainder (N - j) via their trailing rem argument. ----
  if (N - j > 0) {
    constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS * 1;
    int64_t i = 0;
    // Full 8-row M tiles.
    for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
      Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar *B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<1, EIGEN_AVX_MAX_NUM_ROW>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
                                     EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_MAX_B_LOAD,
                                       EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<1, EIGEN_AVX_MAX_NUM_ROW, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
        urolls::template storeC<1, EIGEN_AVX_MAX_NUM_ROW, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, false, true>(zmm, &C_arr[i + j * LDC], LDC, 0, N - j);
      }
    }
    // M remainder: 4 rows.
    if (M - i >= 4) {  // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise
      Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar *B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<1, 4>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 1, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
                                     EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 1, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST, true>(
              B_t, A_t, LDB, LDA, zmm, N - j);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<1, 4, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
        urolls::template storeC<1, 4, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, true>(zmm, &C_arr[i + j * LDC], LDC, 4, N - j);
      }
      i += 4;
    }
    // M remainder: 2 rows.
    if (M - i >= 2) {
      Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar *B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<1, 2>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 1, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
                                     EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 1, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST, true>(
              B_t, A_t, LDB, LDA, zmm, N - j);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<1, 2, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
        urolls::template storeC<1, 2, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, true>(zmm, &C_arr[i + j * LDC], LDC, 2, N - j);
      }
      i += 2;
    }
    // M remainder: final single row.
    if (M - i > 0) {
      Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
      Scalar *B_t = &B_arr[0 * LDB + j];
      PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
      urolls::template setzero<1, 1>(zmm);
      for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
        urolls::template microKernel<isARowMajor, 1, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1, true>(
            B_t, A_t, LDB, LDA, zmm, N - j);
        B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
        EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
        else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
      }
      EIGEN_IF_CONSTEXPR(handleKRem) {
        for (int64_t k = K_; k < K; k++) {
          urolls::template microKernel<isARowMajor, 1, 1, 1, EIGEN_AVX_MAX_B_LOAD, 1, true>(B_t, A_t, LDB, LDA, zmm,
                                                                                            N - j);
          B_t += LDB;
          EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
          else A_t += LDA;
        }
      }
      EIGEN_IF_CONSTEXPR(isCRowMajor) {
        urolls::template updateC<1, 1, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
        urolls::template storeC<1, 1, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
      }
      else {
        transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, true>(zmm, &C_arr[i + j * LDC], LDC, 1, N - j);
      }
    }
  }
}
722
+
723
+ /**
724
+ * Triangular solve kernel with A on left with K number of rhs. dim(A) = unrollM
725
+ *
726
+ * unrollM: dimension of A matrix (triangular matrix). unrollM should be <= EIGEN_AVX_MAX_NUM_ROW
727
+ * isFWDSolve: is forward solve?
728
+ * isUnitDiag: is the diagonal of A all ones?
729
+ * The B matrix (RHS) is assumed to be row-major
730
+ */
731
+ template <typename Scalar, typename vec, int64_t unrollM, bool isARowMajor, bool isFWDSolve, bool isUnitDiag>
732
+ EIGEN_ALWAYS_INLINE void triSolveKernel(Scalar *A_arr, Scalar *B_arr, int64_t K, int64_t LDA, int64_t LDB) {
733
+ static_assert(unrollM <= EIGEN_AVX_MAX_NUM_ROW, "unrollM should be equal to EIGEN_AVX_MAX_NUM_ROW");
734
+ using urolls = unrolls::trsm<Scalar>;
735
+ constexpr int64_t U3 = urolls::PacketSize * 3;
736
+ constexpr int64_t U2 = urolls::PacketSize * 2;
737
+ constexpr int64_t U1 = urolls::PacketSize * 1;
738
+
739
+ PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> RHSInPacket;
740
+ PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> AInPacket;
741
+
742
+ int64_t k = 0;
743
+ while (K - k >= U3) {
744
+ urolls::template loadRHS<isFWDSolve, unrollM, 3>(B_arr + k, LDB, RHSInPacket);
745
+ urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 3>(A_arr, LDA, RHSInPacket,
746
+ AInPacket);
747
+ urolls::template storeRHS<isFWDSolve, unrollM, 3>(B_arr + k, LDB, RHSInPacket);
748
+ k += U3;
749
+ }
750
+ if (K - k >= U2) {
751
+ urolls::template loadRHS<isFWDSolve, unrollM, 2>(B_arr + k, LDB, RHSInPacket);
752
+ urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 2>(A_arr, LDA, RHSInPacket,
753
+ AInPacket);
754
+ urolls::template storeRHS<isFWDSolve, unrollM, 2>(B_arr + k, LDB, RHSInPacket);
755
+ k += U2;
756
+ }
757
+ if (K - k >= U1) {
758
+ urolls::template loadRHS<isFWDSolve, unrollM, 1>(B_arr + k, LDB, RHSInPacket);
759
+ urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 1>(A_arr, LDA, RHSInPacket,
760
+ AInPacket);
761
+ urolls::template storeRHS<isFWDSolve, unrollM, 1>(B_arr + k, LDB, RHSInPacket);
762
+ k += U1;
763
+ }
764
+ if (K - k > 0) {
765
+ // Handle remaining number of RHS
766
+ urolls::template loadRHS<isFWDSolve, unrollM, 1, true>(B_arr + k, LDB, RHSInPacket, K - k);
767
+ urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 1>(A_arr, LDA, RHSInPacket,
768
+ AInPacket);
769
+ urolls::template storeRHS<isFWDSolve, unrollM, 1, true>(B_arr + k, LDB, RHSInPacket, K - k);
770
+ }
771
+ }
772
+
773
+ /**
774
+ * Triangular solve routine with A on left and dimension of at most L with K number of rhs. This is essentially
775
+ * a wrapper for triSolveMicrokernel for M = {1,2,3,4,5,6,7,8}.
776
+ *
777
+ * isFWDSolve: is forward solve?
778
+ * isUnitDiag: is the diagonal of A all ones?
779
+ * The B matrix (RHS) is assumed to be row-major
780
+ */
781
+ template <typename Scalar, bool isARowMajor, bool isFWDSolve, bool isUnitDiag>
782
+ void triSolveKernelLxK(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t K, int64_t LDA, int64_t LDB) {
783
+ // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted
784
+ // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller.
785
+ using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
786
+ if (M == 8)
787
+ triSolveKernel<Scalar, vec, 8, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
788
+ else if (M == 7)
789
+ triSolveKernel<Scalar, vec, 7, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
790
+ else if (M == 6)
791
+ triSolveKernel<Scalar, vec, 6, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
792
+ else if (M == 5)
793
+ triSolveKernel<Scalar, vec, 5, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
794
+ else if (M == 4)
795
+ triSolveKernel<Scalar, vec, 4, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
796
+ else if (M == 3)
797
+ triSolveKernel<Scalar, vec, 3, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
798
+ else if (M == 2)
799
+ triSolveKernel<Scalar, vec, 2, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
800
+ else if (M == 1)
801
+ triSolveKernel<Scalar, vec, 1, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
802
+ return;
803
+ }
804
+
805
/**
 * This routine is used to copy B to/from a temporary array (row-major) for cases where B is column-major.
 *
 * toTemp: true => copy to temporary array, false => copy from temporary array
 * remM: true = need to handle remainder values for M (M < EIGEN_AVX_MAX_NUM_ROW)
 *
 */
template <typename Scalar, bool toTemp = true, bool remM = false>
EIGEN_ALWAYS_INLINE void copyBToRowMajor(Scalar *B_arr, int64_t LDB, int64_t K, Scalar *B_temp, int64_t LDB_,
                                         int64_t remM_ = 0) {
  EIGEN_UNUSED_VARIABLE(remM_);
  using urolls = unrolls::transB<Scalar>;
  // For float the transpose kernel works on half-width packets (vecHalfFloat);
  // for double the full-width packet type is used.
  using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecFullDouble>::type;
  PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> ymm;
  // Column chunk sizes: three, two, and one packet widths.
  constexpr int64_t U3 = urolls::PacketSize * 3;
  constexpr int64_t U2 = urolls::PacketSize * 2;
  constexpr int64_t U1 = urolls::PacketSize * 1;
  // Main loop covers whole multiples of U3 columns; the tails below handle the
  // remainder in progressively smaller chunks (U2, U1, 8, 4, 2, 1).
  int64_t K_ = K / U3 * U3;
  int64_t k = 0;

  for (; k < K_; k += U3) {
    urolls::template transB_kernel<U3, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
    B_temp += U3;
  }
  if (K - k >= U2) {
    urolls::template transB_kernel<U2, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
    B_temp += U2;
    k += U2;
  }
  if (K - k >= U1) {
    urolls::template transB_kernel<U1, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
    B_temp += U1;
    k += U1;
  }
  EIGEN_IF_CONSTEXPR(U1 > 8) {
    // Note: without "if constexpr" this section of code will also be
    // parsed by the compiler so there is an additional check in {load/store}BBlock
    // to make sure the counter is not non-negative.
    if (K - k >= 8) {
      urolls::template transB_kernel<8, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
      B_temp += 8;
      k += 8;
    }
  }
  EIGEN_IF_CONSTEXPR(U1 > 4) {
    // Note: without "if constexpr" this section of code will also be
    // parsed by the compiler so there is an additional check in {load/store}BBlock
    // to make sure the counter is not non-negative.
    if (K - k >= 4) {
      urolls::template transB_kernel<4, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
      B_temp += 4;
      k += 4;
    }
  }
  if (K - k >= 2) {
    urolls::template transB_kernel<2, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
    B_temp += 2;
    k += 2;
  }
  if (K - k >= 1) {
    urolls::template transB_kernel<1, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
    B_temp += 1;
    k += 1;
  }
}
870
+
871
/**
 * Main triangular solve driver
 *
 * Triangular solve with A on the left.
 * Scalar: Scalar precision, only float/double is supported.
 * isARowMajor: is A row-major?
 * isBRowMajor: is B row-major?
 * isFWDSolve: is this forward solve or backward (true => forward)?
 * isUnitDiag: is diagonal of A unit or nonunit (true => A has unit diagonal)?
 *
 * M: dimension of A
 * numRHS: number of right hand sides (coincides with K dimension for gemm updates)
 *
 * Here are the mapping between the different TRSM cases (col-major) and triSolve:
 *
 * LLN (left , lower, A non-transposed) :: isARowMajor=false, isBRowMajor=false, isFWDSolve=true
 * LUT (left , upper, A transposed)     :: isARowMajor=true,  isBRowMajor=false, isFWDSolve=true
 * LUN (left , upper, A non-transposed) :: isARowMajor=false, isBRowMajor=false, isFWDSolve=false
 * LLT (left , lower, A transposed)     :: isARowMajor=true,  isBRowMajor=false, isFWDSolve=false
 * RUN (right, upper, A non-transposed) :: isARowMajor=true,  isBRowMajor=true,  isFWDSolve=true
 * RLT (right, lower, A transposed)     :: isARowMajor=false, isBRowMajor=true,  isFWDSolve=true
 * RUT (right, upper, A transposed)     :: isARowMajor=false, isBRowMajor=true,  isFWDSolve=false
 * RLN (right, lower, A non-transposed) :: isARowMajor=true,  isBRowMajor=true,  isFWDSolve=false
 *
 * Note: For RXX cases M,numRHS should be swapped.
 *
 */
template <typename Scalar, bool isARowMajor = true, bool isBRowMajor = true, bool isFWDSolve = true,
          bool isUnitDiag = false>
void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t LDA, int64_t LDB) {
  constexpr int64_t psize = packet_traits<Scalar>::size;
  /**
   * The values for kB, numM were determined experimentally.
   * kB: Number of RHS we process at a time.
   * numM: number of rows of B we will store in a temporary array (see below.) This should be a multiple of L.
   *
   * kB was determined by initially setting kB = numRHS and benchmarking triSolve (TRSM-RUN case)
   * performance with M=numRHS.
   * It was observed that performance started to drop around M=numRHS=240. This is likely machine dependent.
   *
   * numM was chosen "arbitrarily". It should be relatively small so B_temp is not too large, but it should be
   * large enough to allow GEMM updates to have larger "K"s (see below.) No benchmarking has been done so far to
   * determine optimal values for numM.
   */
  constexpr int64_t kB = (3 * psize) * 5;  // 5*U3
  constexpr int64_t numM = 8 * EIGEN_AVX_MAX_NUM_ROW;

  int64_t sizeBTemp = 0;
  Scalar *B_temp = NULL;
  EIGEN_IF_CONSTEXPR(!isBRowMajor) {
    /**
     * If B is col-major, we copy it to a fixed-size temporary array of size at most ~numM*kB and
     * transpose it to row-major. Call the solve routine, and copy+transpose it back to the original array.
     * The updated row-major copy of B is reused in the GEMM updates.
     */
    sizeBTemp = (((std::min(kB, numRHS) + psize - 1) / psize + 4) * psize) * numM;
  }

  EIGEN_IF_CONSTEXPR(!isBRowMajor) B_temp = (Scalar *)handmade_aligned_malloc(sizeof(Scalar) * sizeBTemp, 64);

  // Panel loop over the right-hand sides, kB columns at a time.
  for (int64_t k = 0; k < numRHS; k += kB) {
    // bK: number of RHS columns in this panel (kB, except possibly the last panel).
    int64_t bK = numRHS - k > kB ? kB : numRHS - k;
    // M_: portion of M covered by whole L=EIGEN_AVX_MAX_NUM_ROW blocks;
    // gemmOff: row offset of the first B_temp row not yet applied in a GEMM update.
    int64_t M_ = (M / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW, gemmOff = 0;

    // bK rounded up to next multiple of L=EIGEN_AVX_MAX_NUM_ROW. When B_temp is used, we solve for bkL RHS
    // instead of bK RHS in triSolveKernelLxK.
    int64_t bkL = ((bK + (EIGEN_AVX_MAX_NUM_ROW - 1)) / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW;
    const int64_t numScalarPerCache = 64 / sizeof(Scalar);
    // Leading dimension of B_temp, will be a multiple of the cache line size.
    int64_t LDT = ((bkL + (numScalarPerCache - 1)) / numScalarPerCache) * numScalarPerCache;
    int64_t offsetBTemp = 0;
    for (int64_t i = 0; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
      EIGEN_IF_CONSTEXPR(!isBRowMajor) {
        // For backward solve the indices/offsets mirror from the end of A/B/B_temp.
        int64_t indA_i = isFWDSolve ? i : M - 1 - i;
        int64_t indB_i = isFWDSolve ? i : M - (i + EIGEN_AVX_MAX_NUM_ROW);
        int64_t offB_1 = isFWDSolve ? offsetBTemp : sizeBTemp - EIGEN_AVX_MAX_NUM_ROW * LDT - offsetBTemp;
        int64_t offB_2 = isFWDSolve ? offsetBTemp : sizeBTemp - LDT - offsetBTemp;
        // Copy values from B to B_temp.
        copyBToRowMajor<Scalar, true, false>(B_arr + indB_i + k * LDB, LDB, bK, B_temp + offB_1, LDT);
        // Triangular solve with a small block of A and long horizontal blocks of B (or B_temp if B col-major)
        triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(
            &A_arr[idA<isARowMajor>(indA_i, indA_i, LDA)], B_temp + offB_2, EIGEN_AVX_MAX_NUM_ROW, bkL, LDA, LDT);
        // Copy values from B_temp back to B. B_temp will be reused in gemm call below.
        copyBToRowMajor<Scalar, false, false>(B_arr + indB_i + k * LDB, LDB, bK, B_temp + offB_1, LDT);

        offsetBTemp += EIGEN_AVX_MAX_NUM_ROW * LDT;
      }
      else {
        // B is already row-major: solve directly in place.
        int64_t ind = isFWDSolve ? i : M - 1 - i;
        triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(
            &A_arr[idA<isARowMajor>(ind, ind, LDA)], B_arr + k + ind * LDB, EIGEN_AVX_MAX_NUM_ROW, bK, LDA, LDB);
      }
      if (i + EIGEN_AVX_MAX_NUM_ROW < M_) {
        /**
         * For the GEMM updates, we want "K" (K=i+8 in this case) to be large as soon as possible
         * to reuse the accumulators in GEMM as much as possible. So we only update 8xbK blocks of
         * B as follows:
         *
         *    A             B
         *  __
         * |__|__          |__|
         * |__|__|__       |__|
         * |__|__|__|__    |__|
         * |********|__|   |**|
         */
        EIGEN_IF_CONSTEXPR(isBRowMajor) {
          int64_t indA_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW);
          int64_t indA_j = isFWDSolve ? 0 : M - (i + EIGEN_AVX_MAX_NUM_ROW);
          int64_t indB_i = isFWDSolve ? 0 : M - (i + EIGEN_AVX_MAX_NUM_ROW);
          int64_t indB_i2 = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW);
          gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
              &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_arr + k + indB_i * LDB, B_arr + k + indB_i2 * LDB,
              EIGEN_AVX_MAX_NUM_ROW, bK, i + EIGEN_AVX_MAX_NUM_ROW, LDA, LDB, LDB);
        }
        else {
          if (offsetBTemp + EIGEN_AVX_MAX_NUM_ROW * LDT > sizeBTemp) {
            /**
             * Similar idea as mentioned above, but here we are limited by the number of updated values of B
             * that can be stored (row-major) in B_temp.
             *
             * If there is not enough space to store the next batch of 8xbK of B in B_temp, we call GEMM
             * update and partially update the remaining old values of B which depends on the new values
             * of B stored in B_temp. These values are then no longer needed and can be overwritten.
             */
            int64_t indA_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : 0;
            int64_t indA_j = isFWDSolve ? gemmOff : M - (i + EIGEN_AVX_MAX_NUM_ROW);
            int64_t indB_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : 0;
            int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
            gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
                &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_temp + offB_1, B_arr + indB_i + (k)*LDB,
                M - (i + EIGEN_AVX_MAX_NUM_ROW), bK, i + EIGEN_AVX_MAX_NUM_ROW - gemmOff, LDA, LDT, LDB);
            // B_temp has been fully consumed; start refilling it from the beginning.
            offsetBTemp = 0;
            gemmOff = i + EIGEN_AVX_MAX_NUM_ROW;
          } else {
            /**
             * If there is enough space in B_temp, we only update the next 8xbK values of B.
             */
            int64_t indA_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW);
            int64_t indA_j = isFWDSolve ? gemmOff : M - (i + EIGEN_AVX_MAX_NUM_ROW);
            int64_t indB_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW);
            int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
            gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
                &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_temp + offB_1, B_arr + indB_i + (k)*LDB,
                EIGEN_AVX_MAX_NUM_ROW, bK, i + EIGEN_AVX_MAX_NUM_ROW - gemmOff, LDA, LDT, LDB);
          }
        }
      }
    }
    // Handle M remainder.. (bM = M % EIGEN_AVX_MAX_NUM_ROW rows left over)
    int64_t bM = M - M_;
    if (bM > 0) {
      if (M_ > 0) {
        // GEMM-update the remainder rows of B with all previously solved rows.
        EIGEN_IF_CONSTEXPR(isBRowMajor) {
          int64_t indA_i = isFWDSolve ? M_ : 0;
          int64_t indA_j = isFWDSolve ? 0 : bM;
          int64_t indB_i = isFWDSolve ? 0 : bM;
          int64_t indB_i2 = isFWDSolve ? M_ : 0;
          gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
              &A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_arr + k + indB_i * LDB, B_arr + k + indB_i2 * LDB, bM,
              bK, M_, LDA, LDB, LDB);
        }
        else {
          int64_t indA_i = isFWDSolve ? M_ : 0;
          int64_t indA_j = isFWDSolve ? gemmOff : bM;
          int64_t indB_i = isFWDSolve ? M_ : 0;
          int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
          gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(&A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)],
                                                                     B_temp + offB_1, B_arr + indB_i + (k)*LDB, bM, bK,
                                                                     M_ - gemmOff, LDA, LDT, LDB);
        }
      }
      // Solve the final bM x bM triangle (bM < EIGEN_AVX_MAX_NUM_ROW).
      EIGEN_IF_CONSTEXPR(!isBRowMajor) {
        int64_t indA_i = isFWDSolve ? M_ : M - 1 - M_;
        int64_t indB_i = isFWDSolve ? M_ : 0;
        int64_t offB_1 = isFWDSolve ? 0 : (bM - 1) * bkL;
        copyBToRowMajor<Scalar, true, true>(B_arr + indB_i + k * LDB, LDB, bK, B_temp, bkL, bM);
        triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(&A_arr[idA<isARowMajor>(indA_i, indA_i, LDA)],
                                                                       B_temp + offB_1, bM, bkL, LDA, bkL);
        copyBToRowMajor<Scalar, false, true>(B_arr + indB_i + k * LDB, LDB, bK, B_temp, bkL, bM);
      }
      else {
        int64_t ind = isFWDSolve ? M_ : M - 1 - M_;
        triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(&A_arr[idA<isARowMajor>(ind, ind, LDA)],
                                                                       B_arr + k + ind * LDB, bM, bK, LDA, LDB);
      }
    }
  }

  // Release the temporary row-major copy of B (only allocated in the col-major-B case).
  EIGEN_IF_CONSTEXPR(!isBRowMajor) handmade_aligned_free(B_temp);
}
1061
+
1062
// Template specializations of trsmKernelL/R for float/double and inner strides of 1.
#if (EIGEN_USE_AVX512_TRSM_KERNELS)
#if (EIGEN_USE_AVX512_TRSM_R_KERNELS)
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride,
          bool Specialized>
struct trsmKernelR;

// Right-side TRSM (X*A = B) specializations for float and double with unit inner stride,
// non-conjugated A.
template <typename Index, int Mode, int TriStorageOrder>
struct trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1, true> {
  static void kernel(Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
                     Index otherStride);
};

template <typename Index, int Mode, int TriStorageOrder>
struct trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1, true> {
  static void kernel(Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
                     Index otherStride);
};

template <typename Index, int Mode, int TriStorageOrder>
EIGEN_DONT_INLINE void trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
    Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
    Index otherStride) {
  // otherIncr is fixed to 1 by this specialization.
  EIGEN_UNUSED_VARIABLE(otherIncr);
#ifdef EIGEN_RUNTIME_NO_MALLOC
  // When runtime allocation is disallowed, defer to the generic (Specialized=false) kernel.
  if (!is_malloc_allowed()) {
    trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
        size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
    return;
  }
#endif
  // Map the right-side case onto triSolve (see the RUN/RLT/RUT/RLN table in triSolve's
  // documentation): the storage order of A selects isARowMajor, B is treated as row-major,
  // and the Lower/UnitDiag mode bits select forward-solve and unit-diagonal handling.
  // _tri is const in the signature but never modified by triSolve's read-only use of A.
  triSolve<float, TriStorageOrder != RowMajor, true, (Mode & Lower) != Lower, (Mode & UnitDiag) != 0>(
      const_cast<float *>(_tri), _other, size, otherSize, triStride, otherStride);
}

template <typename Index, int Mode, int TriStorageOrder>
EIGEN_DONT_INLINE void trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
    Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
    Index otherStride) {
  // otherIncr is fixed to 1 by this specialization.
  EIGEN_UNUSED_VARIABLE(otherIncr);
#ifdef EIGEN_RUNTIME_NO_MALLOC
  // When runtime allocation is disallowed, defer to the generic (Specialized=false) kernel.
  if (!is_malloc_allowed()) {
    trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
        size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
    return;
  }
#endif
  // Same mapping as the float specialization above.
  triSolve<double, TriStorageOrder != RowMajor, true, (Mode & Lower) != Lower, (Mode & UnitDiag) != 0>(
      const_cast<double *>(_tri), _other, size, otherSize, triStride, otherStride);
}
#endif  // (EIGEN_USE_AVX512_TRSM_R_KERNELS)
1113
+
1114
// These trsm kernels require temporary memory allocation
#if (EIGEN_USE_AVX512_TRSM_L_KERNELS)
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride,
          bool Specialized = true>
struct trsmKernelL;

// Left-side TRSM (A*X = B) specializations for float and double with unit inner stride,
// non-conjugated A. B is column-major here, so triSolve copies it through a temporary
// row-major buffer (hence the allocation noted above).
template <typename Index, int Mode, int TriStorageOrder>
struct trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1, true> {
  static void kernel(Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
                     Index otherStride);
};

template <typename Index, int Mode, int TriStorageOrder>
struct trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1, true> {
  static void kernel(Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
                     Index otherStride);
};

template <typename Index, int Mode, int TriStorageOrder>
EIGEN_DONT_INLINE void trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
    Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
    Index otherStride) {
  // otherIncr is fixed to 1 by this specialization.
  EIGEN_UNUSED_VARIABLE(otherIncr);
#ifdef EIGEN_RUNTIME_NO_MALLOC
  // This path allocates a temporary buffer inside triSolve; fall back to the generic
  // (Specialized=false) kernel when runtime allocation is disallowed.
  if (!is_malloc_allowed()) {
    trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
        size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
    return;
  }
#endif
  // Map the left-side case onto triSolve (see the LLN/LUT/LUN/LLT table in triSolve's
  // documentation). Note the predicates are inverted relative to trsmKernelR:
  // row-major A storage and a Lower mode bit select forward solve here.
  triSolve<float, TriStorageOrder == RowMajor, false, (Mode & Lower) == Lower, (Mode & UnitDiag) != 0>(
      const_cast<float *>(_tri), _other, size, otherSize, triStride, otherStride);
}

template <typename Index, int Mode, int TriStorageOrder>
EIGEN_DONT_INLINE void trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
    Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
    Index otherStride) {
  // otherIncr is fixed to 1 by this specialization.
  EIGEN_UNUSED_VARIABLE(otherIncr);
#ifdef EIGEN_RUNTIME_NO_MALLOC
  // This path allocates a temporary buffer inside triSolve; fall back to the generic
  // (Specialized=false) kernel when runtime allocation is disallowed.
  if (!is_malloc_allowed()) {
    trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
        size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
    return;
  }
#endif
  // Same mapping as the float specialization above.
  triSolve<double, TriStorageOrder == RowMajor, false, (Mode & Lower) == Lower, (Mode & UnitDiag) != 0>(
      const_cast<double *>(_tri), _other, size, otherSize, triStride, otherStride);
}
#endif  // EIGEN_USE_AVX512_TRSM_L_KERNELS
1164
+ #endif // EIGEN_USE_AVX512_TRSM_KERNELS
1165
+ } // namespace internal
1166
+ } // namespace Eigen
1167
+ #endif // EIGEN_CORE_ARCH_AVX512_TRSM_KERNEL_H