@smake/eigen 1.1.0 → 1.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (431)
  1. package/README.md +1 -1
  2. package/eigen/Eigen/AccelerateSupport +52 -0
  3. package/eigen/Eigen/Cholesky +18 -20
  4. package/eigen/Eigen/CholmodSupport +28 -28
  5. package/eigen/Eigen/Core +187 -120
  6. package/eigen/Eigen/Eigenvalues +16 -13
  7. package/eigen/Eigen/Geometry +18 -18
  8. package/eigen/Eigen/Householder +9 -7
  9. package/eigen/Eigen/IterativeLinearSolvers +8 -4
  10. package/eigen/Eigen/Jacobi +14 -13
  11. package/eigen/Eigen/KLUSupport +23 -21
  12. package/eigen/Eigen/LU +15 -16
  13. package/eigen/Eigen/MetisSupport +12 -12
  14. package/eigen/Eigen/OrderingMethods +54 -51
  15. package/eigen/Eigen/PaStiXSupport +23 -21
  16. package/eigen/Eigen/PardisoSupport +17 -14
  17. package/eigen/Eigen/QR +18 -20
  18. package/eigen/Eigen/QtAlignedMalloc +5 -12
  19. package/eigen/Eigen/SPQRSupport +21 -14
  20. package/eigen/Eigen/SVD +23 -17
  21. package/eigen/Eigen/Sparse +1 -2
  22. package/eigen/Eigen/SparseCholesky +18 -15
  23. package/eigen/Eigen/SparseCore +18 -17
  24. package/eigen/Eigen/SparseLU +9 -9
  25. package/eigen/Eigen/SparseQR +16 -14
  26. package/eigen/Eigen/StdDeque +5 -2
  27. package/eigen/Eigen/StdList +5 -2
  28. package/eigen/Eigen/StdVector +5 -2
  29. package/eigen/Eigen/SuperLUSupport +30 -24
  30. package/eigen/Eigen/ThreadPool +80 -0
  31. package/eigen/Eigen/UmfPackSupport +19 -17
  32. package/eigen/Eigen/Version +14 -0
  33. package/eigen/Eigen/src/AccelerateSupport/AccelerateSupport.h +423 -0
  34. package/eigen/Eigen/src/AccelerateSupport/InternalHeaderCheck.h +3 -0
  35. package/eigen/Eigen/src/Cholesky/InternalHeaderCheck.h +3 -0
  36. package/eigen/Eigen/src/Cholesky/LDLT.h +366 -405
  37. package/eigen/Eigen/src/Cholesky/LLT.h +323 -367
  38. package/eigen/Eigen/src/Cholesky/LLT_LAPACKE.h +81 -56
  39. package/eigen/Eigen/src/CholmodSupport/CholmodSupport.h +585 -529
  40. package/eigen/Eigen/src/CholmodSupport/InternalHeaderCheck.h +3 -0
  41. package/eigen/Eigen/src/Core/ArithmeticSequence.h +143 -317
  42. package/eigen/Eigen/src/Core/Array.h +329 -370
  43. package/eigen/Eigen/src/Core/ArrayBase.h +190 -203
  44. package/eigen/Eigen/src/Core/ArrayWrapper.h +126 -170
  45. package/eigen/Eigen/src/Core/Assign.h +30 -40
  46. package/eigen/Eigen/src/Core/AssignEvaluator.h +651 -604
  47. package/eigen/Eigen/src/Core/Assign_MKL.h +125 -120
  48. package/eigen/Eigen/src/Core/BandMatrix.h +267 -282
  49. package/eigen/Eigen/src/Core/Block.h +371 -390
  50. package/eigen/Eigen/src/Core/CommaInitializer.h +85 -100
  51. package/eigen/Eigen/src/Core/ConditionEstimator.h +51 -53
  52. package/eigen/Eigen/src/Core/CoreEvaluators.h +1214 -937
  53. package/eigen/Eigen/src/Core/CoreIterators.h +72 -63
  54. package/eigen/Eigen/src/Core/CwiseBinaryOp.h +112 -129
  55. package/eigen/Eigen/src/Core/CwiseNullaryOp.h +676 -702
  56. package/eigen/Eigen/src/Core/CwiseTernaryOp.h +77 -103
  57. package/eigen/Eigen/src/Core/CwiseUnaryOp.h +55 -67
  58. package/eigen/Eigen/src/Core/CwiseUnaryView.h +127 -92
  59. package/eigen/Eigen/src/Core/DenseBase.h +630 -658
  60. package/eigen/Eigen/src/Core/DenseCoeffsBase.h +511 -628
  61. package/eigen/Eigen/src/Core/DenseStorage.h +511 -590
  62. package/eigen/Eigen/src/Core/DeviceWrapper.h +153 -0
  63. package/eigen/Eigen/src/Core/Diagonal.h +168 -207
  64. package/eigen/Eigen/src/Core/DiagonalMatrix.h +346 -317
  65. package/eigen/Eigen/src/Core/DiagonalProduct.h +12 -10
  66. package/eigen/Eigen/src/Core/Dot.h +167 -217
  67. package/eigen/Eigen/src/Core/EigenBase.h +74 -85
  68. package/eigen/Eigen/src/Core/Fill.h +138 -0
  69. package/eigen/Eigen/src/Core/FindCoeff.h +464 -0
  70. package/eigen/Eigen/src/Core/ForceAlignedAccess.h +90 -113
  71. package/eigen/Eigen/src/Core/Fuzzy.h +82 -105
  72. package/eigen/Eigen/src/Core/GeneralProduct.h +315 -261
  73. package/eigen/Eigen/src/Core/GenericPacketMath.h +1182 -520
  74. package/eigen/Eigen/src/Core/GlobalFunctions.h +193 -157
  75. package/eigen/Eigen/src/Core/IO.h +131 -156
  76. package/eigen/Eigen/src/Core/IndexedView.h +209 -125
  77. package/eigen/Eigen/src/Core/InnerProduct.h +260 -0
  78. package/eigen/Eigen/src/Core/InternalHeaderCheck.h +3 -0
  79. package/eigen/Eigen/src/Core/Inverse.h +50 -59
  80. package/eigen/Eigen/src/Core/Map.h +123 -141
  81. package/eigen/Eigen/src/Core/MapBase.h +255 -282
  82. package/eigen/Eigen/src/Core/MathFunctions.h +1247 -1201
  83. package/eigen/Eigen/src/Core/MathFunctionsImpl.h +162 -99
  84. package/eigen/Eigen/src/Core/Matrix.h +463 -494
  85. package/eigen/Eigen/src/Core/MatrixBase.h +468 -470
  86. package/eigen/Eigen/src/Core/NestByValue.h +58 -52
  87. package/eigen/Eigen/src/Core/NoAlias.h +79 -86
  88. package/eigen/Eigen/src/Core/NumTraits.h +206 -206
  89. package/eigen/Eigen/src/Core/PartialReduxEvaluator.h +163 -142
  90. package/eigen/Eigen/src/Core/PermutationMatrix.h +461 -511
  91. package/eigen/Eigen/src/Core/PlainObjectBase.h +858 -972
  92. package/eigen/Eigen/src/Core/Product.h +246 -130
  93. package/eigen/Eigen/src/Core/ProductEvaluators.h +779 -671
  94. package/eigen/Eigen/src/Core/Random.h +153 -164
  95. package/eigen/Eigen/src/Core/RandomImpl.h +262 -0
  96. package/eigen/Eigen/src/Core/RealView.h +250 -0
  97. package/eigen/Eigen/src/Core/Redux.h +334 -314
  98. package/eigen/Eigen/src/Core/Ref.h +259 -257
  99. package/eigen/Eigen/src/Core/Replicate.h +92 -104
  100. package/eigen/Eigen/src/Core/Reshaped.h +215 -271
  101. package/eigen/Eigen/src/Core/ReturnByValue.h +47 -55
  102. package/eigen/Eigen/src/Core/Reverse.h +133 -148
  103. package/eigen/Eigen/src/Core/Select.h +68 -140
  104. package/eigen/Eigen/src/Core/SelfAdjointView.h +254 -290
  105. package/eigen/Eigen/src/Core/SelfCwiseBinaryOp.h +23 -20
  106. package/eigen/Eigen/src/Core/SkewSymmetricMatrix3.h +382 -0
  107. package/eigen/Eigen/src/Core/Solve.h +88 -102
  108. package/eigen/Eigen/src/Core/SolveTriangular.h +126 -124
  109. package/eigen/Eigen/src/Core/SolverBase.h +132 -133
  110. package/eigen/Eigen/src/Core/StableNorm.h +113 -147
  111. package/eigen/Eigen/src/Core/StlIterators.h +404 -248
  112. package/eigen/Eigen/src/Core/Stride.h +90 -92
  113. package/eigen/Eigen/src/Core/Swap.h +70 -39
  114. package/eigen/Eigen/src/Core/Transpose.h +258 -295
  115. package/eigen/Eigen/src/Core/Transpositions.h +270 -333
  116. package/eigen/Eigen/src/Core/TriangularMatrix.h +642 -743
  117. package/eigen/Eigen/src/Core/VectorBlock.h +59 -72
  118. package/eigen/Eigen/src/Core/VectorwiseOp.h +653 -704
  119. package/eigen/Eigen/src/Core/Visitor.h +464 -308
  120. package/eigen/Eigen/src/Core/arch/AVX/Complex.h +380 -187
  121. package/eigen/Eigen/src/Core/arch/AVX/MathFunctions.h +65 -163
  122. package/eigen/Eigen/src/Core/arch/AVX/PacketMath.h +2145 -638
  123. package/eigen/Eigen/src/Core/arch/AVX/Reductions.h +353 -0
  124. package/eigen/Eigen/src/Core/arch/AVX/TypeCasting.h +253 -60
  125. package/eigen/Eigen/src/Core/arch/AVX512/Complex.h +278 -228
  126. package/eigen/Eigen/src/Core/arch/AVX512/GemmKernel.h +1245 -0
  127. package/eigen/Eigen/src/Core/arch/AVX512/MathFunctions.h +48 -269
  128. package/eigen/Eigen/src/Core/arch/AVX512/MathFunctionsFP16.h +75 -0
  129. package/eigen/Eigen/src/Core/arch/AVX512/PacketMath.h +1597 -754
  130. package/eigen/Eigen/src/Core/arch/AVX512/PacketMathFP16.h +1413 -0
  131. package/eigen/Eigen/src/Core/arch/AVX512/Reductions.h +297 -0
  132. package/eigen/Eigen/src/Core/arch/AVX512/TrsmKernel.h +1167 -0
  133. package/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc +1219 -0
  134. package/eigen/Eigen/src/Core/arch/AVX512/TypeCasting.h +229 -41
  135. package/eigen/Eigen/src/Core/arch/AVX512/TypeCastingFP16.h +130 -0
  136. package/eigen/Eigen/src/Core/arch/AltiVec/Complex.h +420 -184
  137. package/eigen/Eigen/src/Core/arch/AltiVec/MathFunctions.h +40 -49
  138. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProduct.h +2962 -2213
  139. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h +196 -212
  140. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h +713 -441
  141. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h +742 -0
  142. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.inc +2818 -0
  143. package/eigen/Eigen/src/Core/arch/AltiVec/PacketMath.h +2380 -1362
  144. package/eigen/Eigen/src/Core/arch/AltiVec/TypeCasting.h +153 -0
  145. package/eigen/Eigen/src/Core/arch/Default/BFloat16.h +390 -224
  146. package/eigen/Eigen/src/Core/arch/Default/ConjHelper.h +78 -67
  147. package/eigen/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +1784 -799
  148. package/eigen/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h +167 -50
  149. package/eigen/Eigen/src/Core/arch/Default/Half.h +528 -379
  150. package/eigen/Eigen/src/Core/arch/Default/Settings.h +10 -12
  151. package/eigen/Eigen/src/Core/arch/GPU/Complex.h +244 -0
  152. package/eigen/Eigen/src/Core/arch/GPU/MathFunctions.h +41 -40
  153. package/eigen/Eigen/src/Core/arch/GPU/PacketMath.h +550 -523
  154. package/eigen/Eigen/src/Core/arch/GPU/Tuple.h +268 -0
  155. package/eigen/Eigen/src/Core/arch/GPU/TypeCasting.h +27 -30
  156. package/eigen/Eigen/src/Core/arch/HIP/hcc/math_constants.h +8 -8
  157. package/eigen/Eigen/src/Core/arch/HVX/PacketMath.h +1088 -0
  158. package/eigen/Eigen/src/Core/arch/LSX/Complex.h +520 -0
  159. package/eigen/Eigen/src/Core/arch/LSX/GeneralBlockPanelKernel.h +23 -0
  160. package/eigen/Eigen/src/Core/arch/LSX/MathFunctions.h +43 -0
  161. package/eigen/Eigen/src/Core/arch/LSX/PacketMath.h +2866 -0
  162. package/eigen/Eigen/src/Core/arch/LSX/TypeCasting.h +526 -0
  163. package/eigen/Eigen/src/Core/arch/MSA/Complex.h +54 -82
  164. package/eigen/Eigen/src/Core/arch/MSA/MathFunctions.h +84 -92
  165. package/eigen/Eigen/src/Core/arch/MSA/PacketMath.h +51 -47
  166. package/eigen/Eigen/src/Core/arch/NEON/Complex.h +454 -306
  167. package/eigen/Eigen/src/Core/arch/NEON/GeneralBlockPanelKernel.h +175 -115
  168. package/eigen/Eigen/src/Core/arch/NEON/MathFunctions.h +23 -30
  169. package/eigen/Eigen/src/Core/arch/NEON/PacketMath.h +4366 -2857
  170. package/eigen/Eigen/src/Core/arch/NEON/TypeCasting.h +616 -393
  171. package/eigen/Eigen/src/Core/arch/NEON/UnaryFunctors.h +57 -0
  172. package/eigen/Eigen/src/Core/arch/SSE/Complex.h +350 -198
  173. package/eigen/Eigen/src/Core/arch/SSE/MathFunctions.h +38 -149
  174. package/eigen/Eigen/src/Core/arch/SSE/PacketMath.h +1791 -912
  175. package/eigen/Eigen/src/Core/arch/SSE/Reductions.h +324 -0
  176. package/eigen/Eigen/src/Core/arch/SSE/TypeCasting.h +128 -40
  177. package/eigen/Eigen/src/Core/arch/SVE/MathFunctions.h +10 -6
  178. package/eigen/Eigen/src/Core/arch/SVE/PacketMath.h +156 -234
  179. package/eigen/Eigen/src/Core/arch/SVE/TypeCasting.h +6 -3
  180. package/eigen/Eigen/src/Core/arch/SYCL/InteropHeaders.h +27 -32
  181. package/eigen/Eigen/src/Core/arch/SYCL/MathFunctions.h +119 -117
  182. package/eigen/Eigen/src/Core/arch/SYCL/PacketMath.h +325 -419
  183. package/eigen/Eigen/src/Core/arch/SYCL/TypeCasting.h +15 -17
  184. package/eigen/Eigen/src/Core/arch/ZVector/Complex.h +325 -181
  185. package/eigen/Eigen/src/Core/arch/ZVector/MathFunctions.h +94 -83
  186. package/eigen/Eigen/src/Core/arch/ZVector/PacketMath.h +811 -458
  187. package/eigen/Eigen/src/Core/functors/AssignmentFunctors.h +121 -124
  188. package/eigen/Eigen/src/Core/functors/BinaryFunctors.h +576 -370
  189. package/eigen/Eigen/src/Core/functors/NullaryFunctors.h +194 -109
  190. package/eigen/Eigen/src/Core/functors/StlFunctors.h +95 -112
  191. package/eigen/Eigen/src/Core/functors/TernaryFunctors.h +34 -7
  192. package/eigen/Eigen/src/Core/functors/UnaryFunctors.h +1038 -749
  193. package/eigen/Eigen/src/Core/products/GeneralBlockPanelKernel.h +1883 -1375
  194. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrix.h +312 -370
  195. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h +189 -176
  196. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h +84 -81
  197. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h +154 -73
  198. package/eigen/Eigen/src/Core/products/GeneralMatrixVector.h +292 -337
  199. package/eigen/Eigen/src/Core/products/GeneralMatrixVector_BLAS.h +80 -77
  200. package/eigen/Eigen/src/Core/products/Parallelizer.h +207 -105
  201. package/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix.h +327 -388
  202. package/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix_BLAS.h +206 -224
  203. package/eigen/Eigen/src/Core/products/SelfadjointMatrixVector.h +138 -147
  204. package/eigen/Eigen/src/Core/products/SelfadjointMatrixVector_BLAS.h +58 -61
  205. package/eigen/Eigen/src/Core/products/SelfadjointProduct.h +71 -71
  206. package/eigen/Eigen/src/Core/products/SelfadjointRank2Update.h +48 -47
  207. package/eigen/Eigen/src/Core/products/TriangularMatrixMatrix.h +294 -369
  208. package/eigen/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h +246 -238
  209. package/eigen/Eigen/src/Core/products/TriangularMatrixVector.h +244 -247
  210. package/eigen/Eigen/src/Core/products/TriangularMatrixVector_BLAS.h +212 -192
  211. package/eigen/Eigen/src/Core/products/TriangularSolverMatrix.h +328 -277
  212. package/eigen/Eigen/src/Core/products/TriangularSolverMatrix_BLAS.h +108 -109
  213. package/eigen/Eigen/src/Core/products/TriangularSolverVector.h +68 -94
  214. package/eigen/Eigen/src/Core/util/Assert.h +158 -0
  215. package/eigen/Eigen/src/Core/util/BlasUtil.h +342 -303
  216. package/eigen/Eigen/src/Core/util/ConfigureVectorization.h +348 -317
  217. package/eigen/Eigen/src/Core/util/Constants.h +297 -262
  218. package/eigen/Eigen/src/Core/util/DisableStupidWarnings.h +130 -90
  219. package/eigen/Eigen/src/Core/util/EmulateArray.h +270 -0
  220. package/eigen/Eigen/src/Core/util/ForwardDeclarations.h +449 -247
  221. package/eigen/Eigen/src/Core/util/GpuHipCudaDefines.inc +101 -0
  222. package/eigen/Eigen/src/Core/util/GpuHipCudaUndefines.inc +45 -0
  223. package/eigen/Eigen/src/Core/util/IndexedViewHelper.h +417 -116
  224. package/eigen/Eigen/src/Core/util/IntegralConstant.h +211 -204
  225. package/eigen/Eigen/src/Core/util/MKL_support.h +39 -37
  226. package/eigen/Eigen/src/Core/util/Macros.h +655 -773
  227. package/eigen/Eigen/src/Core/util/MaxSizeVector.h +139 -0
  228. package/eigen/Eigen/src/Core/util/Memory.h +970 -748
  229. package/eigen/Eigen/src/Core/util/Meta.h +581 -633
  230. package/eigen/Eigen/src/Core/util/MoreMeta.h +638 -0
  231. package/eigen/Eigen/src/Core/util/ReenableStupidWarnings.h +32 -19
  232. package/eigen/Eigen/src/Core/util/ReshapedHelper.h +17 -17
  233. package/eigen/Eigen/src/Core/util/Serializer.h +209 -0
  234. package/eigen/Eigen/src/Core/util/StaticAssert.h +50 -166
  235. package/eigen/Eigen/src/Core/util/SymbolicIndex.h +377 -225
  236. package/eigen/Eigen/src/Core/util/XprHelper.h +784 -547
  237. package/eigen/Eigen/src/Eigenvalues/ComplexEigenSolver.h +246 -277
  238. package/eigen/Eigen/src/Eigenvalues/ComplexSchur.h +299 -319
  239. package/eigen/Eigen/src/Eigenvalues/ComplexSchur_LAPACKE.h +52 -48
  240. package/eigen/Eigen/src/Eigenvalues/EigenSolver.h +413 -456
  241. package/eigen/Eigen/src/Eigenvalues/GeneralizedEigenSolver.h +309 -325
  242. package/eigen/Eigen/src/Eigenvalues/GeneralizedSelfAdjointEigenSolver.h +157 -171
  243. package/eigen/Eigen/src/Eigenvalues/HessenbergDecomposition.h +292 -310
  244. package/eigen/Eigen/src/Eigenvalues/InternalHeaderCheck.h +3 -0
  245. package/eigen/Eigen/src/Eigenvalues/MatrixBaseEigenvalues.h +89 -105
  246. package/eigen/Eigen/src/Eigenvalues/RealQZ.h +537 -607
  247. package/eigen/Eigen/src/Eigenvalues/RealSchur.h +342 -381
  248. package/eigen/Eigen/src/Eigenvalues/RealSchur_LAPACKE.h +41 -35
  249. package/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver.h +541 -595
  250. package/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver_LAPACKE.h +47 -44
  251. package/eigen/Eigen/src/Eigenvalues/Tridiagonalization.h +430 -462
  252. package/eigen/Eigen/src/Geometry/AlignedBox.h +226 -227
  253. package/eigen/Eigen/src/Geometry/AngleAxis.h +131 -133
  254. package/eigen/Eigen/src/Geometry/EulerAngles.h +163 -74
  255. package/eigen/Eigen/src/Geometry/Homogeneous.h +285 -333
  256. package/eigen/Eigen/src/Geometry/Hyperplane.h +151 -160
  257. package/eigen/Eigen/src/Geometry/InternalHeaderCheck.h +3 -0
  258. package/eigen/Eigen/src/Geometry/OrthoMethods.h +168 -146
  259. package/eigen/Eigen/src/Geometry/ParametrizedLine.h +127 -127
  260. package/eigen/Eigen/src/Geometry/Quaternion.h +566 -506
  261. package/eigen/Eigen/src/Geometry/Rotation2D.h +107 -105
  262. package/eigen/Eigen/src/Geometry/RotationBase.h +148 -145
  263. package/eigen/Eigen/src/Geometry/Scaling.h +113 -106
  264. package/eigen/Eigen/src/Geometry/Transform.h +858 -936
  265. package/eigen/Eigen/src/Geometry/Translation.h +94 -92
  266. package/eigen/Eigen/src/Geometry/Umeyama.h +79 -84
  267. package/eigen/Eigen/src/Geometry/arch/Geometry_SIMD.h +90 -104
  268. package/eigen/Eigen/src/Householder/BlockHouseholder.h +51 -46
  269. package/eigen/Eigen/src/Householder/Householder.h +102 -124
  270. package/eigen/Eigen/src/Householder/HouseholderSequence.h +412 -453
  271. package/eigen/Eigen/src/Householder/InternalHeaderCheck.h +3 -0
  272. package/eigen/Eigen/src/IterativeLinearSolvers/BasicPreconditioners.h +149 -162
  273. package/eigen/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h +124 -119
  274. package/eigen/Eigen/src/IterativeLinearSolvers/ConjugateGradient.h +92 -104
  275. package/eigen/Eigen/src/IterativeLinearSolvers/IncompleteCholesky.h +251 -243
  276. package/eigen/Eigen/src/IterativeLinearSolvers/IncompleteLUT.h +224 -228
  277. package/eigen/Eigen/src/IterativeLinearSolvers/InternalHeaderCheck.h +3 -0
  278. package/eigen/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h +178 -227
  279. package/eigen/Eigen/src/IterativeLinearSolvers/LeastSquareConjugateGradient.h +79 -84
  280. package/eigen/Eigen/src/IterativeLinearSolvers/SolveWithGuess.h +54 -60
  281. package/eigen/Eigen/src/Jacobi/InternalHeaderCheck.h +3 -0
  282. package/eigen/Eigen/src/Jacobi/Jacobi.h +252 -308
  283. package/eigen/Eigen/src/KLUSupport/InternalHeaderCheck.h +3 -0
  284. package/eigen/Eigen/src/KLUSupport/KLUSupport.h +208 -227
  285. package/eigen/Eigen/src/LU/Determinant.h +50 -69
  286. package/eigen/Eigen/src/LU/FullPivLU.h +545 -596
  287. package/eigen/Eigen/src/LU/InternalHeaderCheck.h +3 -0
  288. package/eigen/Eigen/src/LU/InverseImpl.h +206 -285
  289. package/eigen/Eigen/src/LU/PartialPivLU.h +390 -428
  290. package/eigen/Eigen/src/LU/PartialPivLU_LAPACKE.h +54 -40
  291. package/eigen/Eigen/src/LU/arch/InverseSize4.h +72 -70
  292. package/eigen/Eigen/src/MetisSupport/InternalHeaderCheck.h +3 -0
  293. package/eigen/Eigen/src/MetisSupport/MetisSupport.h +81 -93
  294. package/eigen/Eigen/src/OrderingMethods/Amd.h +243 -265
  295. package/eigen/Eigen/src/OrderingMethods/Eigen_Colamd.h +831 -1004
  296. package/eigen/Eigen/src/OrderingMethods/InternalHeaderCheck.h +3 -0
  297. package/eigen/Eigen/src/OrderingMethods/Ordering.h +112 -119
  298. package/eigen/Eigen/src/PaStiXSupport/InternalHeaderCheck.h +3 -0
  299. package/eigen/Eigen/src/PaStiXSupport/PaStiXSupport.h +524 -570
  300. package/eigen/Eigen/src/PardisoSupport/InternalHeaderCheck.h +3 -0
  301. package/eigen/Eigen/src/PardisoSupport/PardisoSupport.h +385 -430
  302. package/eigen/Eigen/src/QR/ColPivHouseholderQR.h +479 -479
  303. package/eigen/Eigen/src/QR/ColPivHouseholderQR_LAPACKE.h +120 -56
  304. package/eigen/Eigen/src/QR/CompleteOrthogonalDecomposition.h +166 -153
  305. package/eigen/Eigen/src/QR/FullPivHouseholderQR.h +495 -475
  306. package/eigen/Eigen/src/QR/HouseholderQR.h +394 -285
  307. package/eigen/Eigen/src/QR/HouseholderQR_LAPACKE.h +32 -23
  308. package/eigen/Eigen/src/QR/InternalHeaderCheck.h +3 -0
  309. package/eigen/Eigen/src/SPQRSupport/InternalHeaderCheck.h +3 -0
  310. package/eigen/Eigen/src/SPQRSupport/SuiteSparseQRSupport.h +244 -264
  311. package/eigen/Eigen/src/SVD/BDCSVD.h +817 -713
  312. package/eigen/Eigen/src/SVD/BDCSVD_LAPACKE.h +174 -0
  313. package/eigen/Eigen/src/SVD/InternalHeaderCheck.h +3 -0
  314. package/eigen/Eigen/src/SVD/JacobiSVD.h +577 -543
  315. package/eigen/Eigen/src/SVD/JacobiSVD_LAPACKE.h +85 -49
  316. package/eigen/Eigen/src/SVD/SVDBase.h +242 -182
  317. package/eigen/Eigen/src/SVD/UpperBidiagonalization.h +200 -235
  318. package/eigen/Eigen/src/SparseCholesky/InternalHeaderCheck.h +3 -0
  319. package/eigen/Eigen/src/SparseCholesky/SimplicialCholesky.h +765 -594
  320. package/eigen/Eigen/src/SparseCholesky/SimplicialCholesky_impl.h +308 -94
  321. package/eigen/Eigen/src/SparseCore/AmbiVector.h +202 -251
  322. package/eigen/Eigen/src/SparseCore/CompressedStorage.h +184 -252
  323. package/eigen/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h +134 -178
  324. package/eigen/Eigen/src/SparseCore/InternalHeaderCheck.h +3 -0
  325. package/eigen/Eigen/src/SparseCore/SparseAssign.h +149 -140
  326. package/eigen/Eigen/src/SparseCore/SparseBlock.h +403 -440
  327. package/eigen/Eigen/src/SparseCore/SparseColEtree.h +100 -112
  328. package/eigen/Eigen/src/SparseCore/SparseCompressedBase.h +525 -303
  329. package/eigen/Eigen/src/SparseCore/SparseCwiseBinaryOp.h +555 -339
  330. package/eigen/Eigen/src/SparseCore/SparseCwiseUnaryOp.h +100 -108
  331. package/eigen/Eigen/src/SparseCore/SparseDenseProduct.h +169 -197
  332. package/eigen/Eigen/src/SparseCore/SparseDiagonalProduct.h +71 -71
  333. package/eigen/Eigen/src/SparseCore/SparseDot.h +49 -47
  334. package/eigen/Eigen/src/SparseCore/SparseFuzzy.h +13 -11
  335. package/eigen/Eigen/src/SparseCore/SparseMap.h +243 -253
  336. package/eigen/Eigen/src/SparseCore/SparseMatrix.h +1603 -1245
  337. package/eigen/Eigen/src/SparseCore/SparseMatrixBase.h +403 -350
  338. package/eigen/Eigen/src/SparseCore/SparsePermutation.h +186 -115
  339. package/eigen/Eigen/src/SparseCore/SparseProduct.h +94 -97
  340. package/eigen/Eigen/src/SparseCore/SparseRedux.h +22 -24
  341. package/eigen/Eigen/src/SparseCore/SparseRef.h +268 -295
  342. package/eigen/Eigen/src/SparseCore/SparseSelfAdjointView.h +370 -416
  343. package/eigen/Eigen/src/SparseCore/SparseSolverBase.h +78 -87
  344. package/eigen/Eigen/src/SparseCore/SparseSparseProductWithPruning.h +81 -95
  345. package/eigen/Eigen/src/SparseCore/SparseTranspose.h +62 -71
  346. package/eigen/Eigen/src/SparseCore/SparseTriangularView.h +132 -144
  347. package/eigen/Eigen/src/SparseCore/SparseUtil.h +138 -115
  348. package/eigen/Eigen/src/SparseCore/SparseVector.h +426 -372
  349. package/eigen/Eigen/src/SparseCore/SparseView.h +164 -193
  350. package/eigen/Eigen/src/SparseCore/TriangularSolver.h +129 -170
  351. package/eigen/Eigen/src/SparseLU/InternalHeaderCheck.h +3 -0
  352. package/eigen/Eigen/src/SparseLU/SparseLU.h +756 -710
  353. package/eigen/Eigen/src/SparseLU/SparseLUImpl.h +61 -48
  354. package/eigen/Eigen/src/SparseLU/SparseLU_Memory.h +102 -118
  355. package/eigen/Eigen/src/SparseLU/SparseLU_Structs.h +38 -35
  356. package/eigen/Eigen/src/SparseLU/SparseLU_SupernodalMatrix.h +245 -301
  357. package/eigen/Eigen/src/SparseLU/SparseLU_Utils.h +44 -49
  358. package/eigen/Eigen/src/SparseLU/SparseLU_column_bmod.h +104 -108
  359. package/eigen/Eigen/src/SparseLU/SparseLU_column_dfs.h +89 -100
  360. package/eigen/Eigen/src/SparseLU/SparseLU_copy_to_ucol.h +57 -58
  361. package/eigen/Eigen/src/SparseLU/SparseLU_heap_relax_snode.h +43 -55
  362. package/eigen/Eigen/src/SparseLU/SparseLU_kernel_bmod.h +74 -71
  363. package/eigen/Eigen/src/SparseLU/SparseLU_panel_bmod.h +124 -132
  364. package/eigen/Eigen/src/SparseLU/SparseLU_panel_dfs.h +136 -159
  365. package/eigen/Eigen/src/SparseLU/SparseLU_pivotL.h +51 -52
  366. package/eigen/Eigen/src/SparseLU/SparseLU_pruneL.h +67 -73
  367. package/eigen/Eigen/src/SparseLU/SparseLU_relax_snode.h +24 -26
  368. package/eigen/Eigen/src/SparseQR/InternalHeaderCheck.h +3 -0
  369. package/eigen/Eigen/src/SparseQR/SparseQR.h +450 -502
  370. package/eigen/Eigen/src/StlSupport/StdDeque.h +28 -93
  371. package/eigen/Eigen/src/StlSupport/StdList.h +28 -84
  372. package/eigen/Eigen/src/StlSupport/StdVector.h +28 -108
  373. package/eigen/Eigen/src/StlSupport/details.h +48 -50
  374. package/eigen/Eigen/src/SuperLUSupport/InternalHeaderCheck.h +3 -0
  375. package/eigen/Eigen/src/SuperLUSupport/SuperLUSupport.h +634 -730
  376. package/eigen/Eigen/src/ThreadPool/Barrier.h +70 -0
  377. package/eigen/Eigen/src/ThreadPool/CoreThreadPoolDevice.h +336 -0
  378. package/eigen/Eigen/src/ThreadPool/EventCount.h +241 -0
  379. package/eigen/Eigen/src/ThreadPool/ForkJoin.h +140 -0
  380. package/eigen/Eigen/src/ThreadPool/InternalHeaderCheck.h +4 -0
  381. package/eigen/Eigen/src/ThreadPool/NonBlockingThreadPool.h +587 -0
  382. package/eigen/Eigen/src/ThreadPool/RunQueue.h +230 -0
  383. package/eigen/Eigen/src/ThreadPool/ThreadCancel.h +21 -0
  384. package/eigen/Eigen/src/ThreadPool/ThreadEnvironment.h +43 -0
  385. package/eigen/Eigen/src/ThreadPool/ThreadLocal.h +289 -0
  386. package/eigen/Eigen/src/ThreadPool/ThreadPoolInterface.h +50 -0
  387. package/eigen/Eigen/src/ThreadPool/ThreadYield.h +16 -0
  388. package/eigen/Eigen/src/UmfPackSupport/InternalHeaderCheck.h +3 -0
  389. package/eigen/Eigen/src/UmfPackSupport/UmfPackSupport.h +428 -464
  390. package/eigen/Eigen/src/misc/Image.h +41 -43
  391. package/eigen/Eigen/src/misc/InternalHeaderCheck.h +3 -0
  392. package/eigen/Eigen/src/misc/Kernel.h +39 -41
  393. package/eigen/Eigen/src/misc/RealSvd2x2.h +19 -21
  394. package/eigen/Eigen/src/misc/blas.h +83 -426
  395. package/eigen/Eigen/src/misc/lapacke.h +9972 -16179
  396. package/eigen/Eigen/src/misc/lapacke_helpers.h +163 -0
  397. package/eigen/Eigen/src/misc/lapacke_mangling.h +4 -5
  398. package/eigen/Eigen/src/plugins/ArrayCwiseBinaryOps.inc +344 -0
  399. package/eigen/Eigen/src/plugins/ArrayCwiseUnaryOps.inc +544 -0
  400. package/eigen/Eigen/src/plugins/{BlockMethods.h → BlockMethods.inc} +434 -506
  401. package/eigen/Eigen/src/plugins/CommonCwiseBinaryOps.inc +116 -0
  402. package/eigen/Eigen/src/plugins/{CommonCwiseUnaryOps.h → CommonCwiseUnaryOps.inc} +58 -68
  403. package/eigen/Eigen/src/plugins/IndexedViewMethods.inc +192 -0
  404. package/eigen/Eigen/src/plugins/InternalHeaderCheck.inc +3 -0
  405. package/eigen/Eigen/src/plugins/MatrixCwiseBinaryOps.inc +331 -0
  406. package/eigen/Eigen/src/plugins/MatrixCwiseUnaryOps.inc +118 -0
  407. package/eigen/Eigen/src/plugins/ReshapedMethods.inc +133 -0
  408. package/package.json +1 -1
  409. package/eigen/COPYING.APACHE +0 -203
  410. package/eigen/COPYING.BSD +0 -26
  411. package/eigen/COPYING.GPL +0 -674
  412. package/eigen/COPYING.LGPL +0 -502
  413. package/eigen/COPYING.MINPACK +0 -51
  414. package/eigen/COPYING.MPL2 +0 -373
  415. package/eigen/COPYING.README +0 -18
  416. package/eigen/Eigen/src/Core/BooleanRedux.h +0 -162
  417. package/eigen/Eigen/src/Core/arch/CUDA/Complex.h +0 -258
  418. package/eigen/Eigen/src/Core/arch/Default/TypeCasting.h +0 -120
  419. package/eigen/Eigen/src/Core/arch/SYCL/SyclMemoryModel.h +0 -694
  420. package/eigen/Eigen/src/Core/util/NonMPL2.h +0 -3
  421. package/eigen/Eigen/src/SparseCore/MappedSparseMatrix.h +0 -67
  422. package/eigen/Eigen/src/SparseLU/SparseLU_gemm_kernel.h +0 -280
  423. package/eigen/Eigen/src/misc/lapack.h +0 -152
  424. package/eigen/Eigen/src/plugins/ArrayCwiseBinaryOps.h +0 -358
  425. package/eigen/Eigen/src/plugins/ArrayCwiseUnaryOps.h +0 -696
  426. package/eigen/Eigen/src/plugins/CommonCwiseBinaryOps.h +0 -115
  427. package/eigen/Eigen/src/plugins/IndexedViewMethods.h +0 -262
  428. package/eigen/Eigen/src/plugins/MatrixCwiseBinaryOps.h +0 -152
  429. package/eigen/Eigen/src/plugins/MatrixCwiseUnaryOps.h +0 -95
  430. package/eigen/Eigen/src/plugins/ReshapedMethods.h +0 -149
  431. package/eigen/README.md +0 -5
@@ -0,0 +1,1245 @@
1
+ // This file is part of Eigen, a lightweight C++ template library
2
+ // for linear algebra.
3
+ //
4
+ // Copyright (C) 2022 Intel Corporation
5
+ //
6
+ // This Source Code Form is subject to the terms of the Mozilla
7
+ // Public License v. 2.0. If a copy of the MPL was not distributed
8
+ // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
+
10
+ #ifndef EIGEN_CORE_ARCH_AVX512_GEMM_KERNEL_H
11
+ #define EIGEN_CORE_ARCH_AVX512_GEMM_KERNEL_H
12
+
13
+ #if EIGEN_COMP_MSVC
14
+ #include <intrin.h>
15
+ #else
16
+ #include <x86intrin.h>
17
+ #endif
18
+ #include <immintrin.h>
19
+ #include <type_traits>
20
+
21
+ // IWYU pragma: private
22
+ #include "../../InternalHeaderCheck.h"
23
+
24
+ #if !defined(EIGEN_USE_AVX512_GEMM_KERNELS)
25
+ #define EIGEN_USE_AVX512_GEMM_KERNELS 1
26
+ #endif
27
+
28
+ #define SECOND_FETCH (32)
29
+ #if (EIGEN_COMP_GNUC_STRICT != 0) && !defined(EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS)
30
+ // Use less registers to load A elements to workaround compiler spills. Loose a
31
+ // bit of performance (less than ~2%).
32
+ #define EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS
33
+ #endif
34
+
35
+ namespace Eigen {
36
+ namespace internal {
37
+
38
+ template <typename Scalar, bool is_unit_inc>
39
+ class gemm_class {
40
+ using vec = typename packet_traits<Scalar>::type;
41
+ using vec_ymm = typename unpacket_traits<vec>::half;
42
+ using vec_xmm = typename unpacket_traits<vec_ymm>::half;
43
+ using umask_t = typename unpacket_traits<vec>::mask_t;
44
+
45
+ static constexpr bool is_f32 = sizeof(Scalar) == sizeof(float);
46
+ static constexpr bool is_f64 = sizeof(Scalar) == sizeof(double);
47
+
48
+ #ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS
49
+ static constexpr bool use_less_a_regs = !is_unit_inc;
50
+ #else
51
+ static constexpr bool use_less_a_regs = true;
52
+ #endif
53
+ #ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_B_REGS
54
+ static constexpr bool use_less_b_regs = !is_unit_inc;
55
+ #else
56
+ static constexpr bool use_less_b_regs = true;
57
+ #endif
58
+
59
+ static constexpr int a_regs[] = {0, 1, 2, use_less_a_regs ? 0 : 3, use_less_a_regs ? 1 : 4, use_less_a_regs ? 2 : 5};
60
+ static constexpr int b_regs[] = {6, use_less_b_regs ? 6 : 7};
61
+ static constexpr int c_regs[] = {
62
+ 8, 16, 24, 9, 17, 25, 10, 18, 26, 11, 19, 27, 12, 20, 28, 13, 21, 29, 14, 22, 30, 15, 23, 31,
63
+ };
64
+
65
+ static constexpr int alpha_load_reg = 0;
66
+ static constexpr int c_load_regs[] = {1, 2, 6};
67
+
68
+ static constexpr int a_shift = 128;
69
+ static constexpr int b_shift = 128;
70
+
71
+ static constexpr int nelems_in_cache_line = is_f32 ? 16 : 8;
72
+ static constexpr int a_prefetch_size = nelems_in_cache_line * 2;
73
+ static constexpr int b_prefetch_size = nelems_in_cache_line * 8;
74
+
75
+ vec zmm[32];
76
+ umask_t mask;
77
+
78
+ // gemm arguments.
79
+ Index m;
80
+ const Index n, k, ldc;
81
+ const Index inc;
82
+ const Scalar *alpha;
83
+
84
+ const Scalar *a, *b;
85
+ Scalar *c;
86
+
87
+ const bool is_alpha1;
88
+ const bool is_beta0;
89
+
90
+ const Index a_stride, b_stride;
91
+ const Index a_off, b_off;
92
+
93
+ EIGEN_ALWAYS_INLINE void prefetch_a(const Scalar *a_addr) {
94
+ _mm_prefetch((char *)(a_prefetch_size + a_addr - a_shift), _MM_HINT_T0);
95
+ }
96
+
97
+ EIGEN_ALWAYS_INLINE void prefetch_b(const Scalar *b_addr) {
98
+ _mm_prefetch((char *)(b_prefetch_size + b_addr - b_shift), _MM_HINT_T0);
99
+ }
100
+
101
+ EIGEN_ALWAYS_INLINE void prefetch_x(const Scalar *x_addr) { _mm_prefetch((char *)(x_addr - a_shift), _MM_HINT_T2); }
102
+
103
+ EIGEN_ALWAYS_INLINE void prefetch_c(const Scalar *c_addr) {
104
+ #if defined(__PRFCHW__) && __PRFCHW__ == 1
105
+ _m_prefetchw((void *)c_addr);
106
+ #else
107
+ _mm_prefetch((char *)c_addr, _MM_HINT_T0);
108
+ #endif
109
+ }
110
+
111
+ template <int nelems>
112
+ EIGEN_ALWAYS_INLINE void a_load(vec &a_reg, const Scalar *a_addr) {
113
+ switch (nelems * sizeof(*a_addr) * 8) {
114
+ default:
115
+ case 512 * 3:
116
+ a_reg = ploadu<vec>(a_addr);
117
+ break;
118
+ case 512 * 2:
119
+ a_reg = ploadu<vec>(a_addr);
120
+ break;
121
+ case 512 * 1:
122
+ a_reg = ploadu<vec>(a_addr);
123
+ break;
124
+ case 256 * 1:
125
+ a_reg = preinterpret<vec>(_mm512_broadcast_f64x4(ploadu<Packet4d>(reinterpret_cast<const double *>(a_addr))));
126
+ break;
127
+ case 128 * 1:
128
+ a_reg = preinterpret<vec>(_mm512_broadcast_f32x4(ploadu<Packet4f>(reinterpret_cast<const float *>(a_addr))));
129
+ break;
130
+ case 64 * 1:
131
+ a_reg = preinterpret<vec>(pload1<Packet8d>(reinterpret_cast<const double *>(a_addr)));
132
+ break;
133
+ case 32 * 1:
134
+ a_reg = pload1<vec>(a_addr);
135
+ break;
136
+ }
137
+ }
138
+
139
+ EIGEN_ALWAYS_INLINE void b_load(vec &b_reg, const Scalar *b_addr) { b_reg = pload1<vec>(b_addr); }
140
+
141
+ template <int nelems>
142
+ EIGEN_ALWAYS_INLINE void c_store(Scalar *mem, vec &src) {
143
+ if (is_unit_inc) {
144
+ switch (nelems * sizeof(*mem) * 8) {
145
+ default:
146
+ case 512 * 3:
147
+ pstoreu(mem, src);
148
+ break;
149
+ case 512 * 2:
150
+ pstoreu(mem, src);
151
+ break;
152
+ case 512 * 1:
153
+ pstoreu(mem, src);
154
+ break;
155
+ case 256 * 1:
156
+ pstoreu(mem, preinterpret<vec_ymm>(src));
157
+ break;
158
+ case 128 * 1:
159
+ pstoreu(mem, preinterpret<vec_xmm>(src));
160
+ break;
161
+ case 64 * 1:
162
+ pstorel(mem, preinterpret<vec_xmm>(src));
163
+ break;
164
+ case 32 * 1:
165
+ pstores(mem, preinterpret<vec_xmm>(src));
166
+ break;
167
+ }
168
+ } else {
169
+ switch (nelems * sizeof(*mem) * 8) {
170
+ default:
171
+ case 512 * 3:
172
+ pscatter(mem, src, inc);
173
+ break;
174
+ case 512 * 2:
175
+ pscatter(mem, src, inc);
176
+ break;
177
+ case 512 * 1:
178
+ pscatter(mem, src, inc);
179
+ break;
180
+ case 256 * 1:
181
+ pscatter(mem, src, inc, mask);
182
+ break;
183
+ case 128 * 1:
184
+ pscatter(mem, src, inc, mask);
185
+ break;
186
+ case 64 * 1:
187
+ pscatter(mem, src, inc, mask);
188
+ break;
189
+ case 32 * 1:
190
+ pscatter(mem, src, inc, mask);
191
+ break;
192
+ }
193
+ }
194
+ }
195
+
196
+ template <int nelems>
197
+ EIGEN_ALWAYS_INLINE void vaddm(vec &dst, const Scalar *mem, vec &src, vec &reg) {
198
+ if (is_unit_inc) {
199
+ switch (nelems * sizeof(*mem) * 8) {
200
+ default:
201
+ case 512 * 3:
202
+ dst = padd(src, ploadu<vec>(mem));
203
+ break;
204
+ case 512 * 2:
205
+ dst = padd(src, ploadu<vec>(mem));
206
+ break;
207
+ case 512 * 1:
208
+ dst = padd(src, ploadu<vec>(mem));
209
+ break;
210
+ case 256 * 1:
211
+ dst = preinterpret<vec>(padd(preinterpret<vec_ymm>(src), ploadu<vec_ymm>(mem)));
212
+ break;
213
+ case 128 * 1:
214
+ dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), ploadu<vec_xmm>(mem)));
215
+ break;
216
+ case 64 * 1:
217
+ dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem)));
218
+ break;
219
+ case 32 * 1:
220
+ dst = preinterpret<vec>(padds(preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem)));
221
+ break;
222
+ }
223
+ } else {
224
+ // Zero out scratch register
225
+ reg = pzero(reg);
226
+
227
+ switch (nelems * sizeof(*mem) * 8) {
228
+ default:
229
+ case 512 * 3:
230
+ reg = pgather<Scalar, vec>(mem, inc);
231
+ dst = padd(src, reg);
232
+ break;
233
+ case 512 * 2:
234
+ reg = pgather<Scalar, vec>(mem, inc);
235
+ dst = padd(src, reg);
236
+ break;
237
+ case 512 * 1:
238
+ reg = pgather<Scalar, vec>(mem, inc);
239
+ dst = padd(src, reg);
240
+ break;
241
+ case 256 * 1:
242
+ reg = preinterpret<vec>(pgather<Scalar, vec_ymm>(mem, inc));
243
+ dst = preinterpret<vec>(padd(preinterpret<vec_ymm>(src), preinterpret<vec_ymm>(reg)));
244
+ break;
245
+ case 128 * 1:
246
+ reg = preinterpret<vec>(pgather<Scalar, vec_xmm>(mem, inc));
247
+ dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), preinterpret<vec_xmm>(reg)));
248
+ break;
249
+ case 64 * 1:
250
+ if (is_f32) {
251
+ reg = pgather(reg, mem, inc, mask);
252
+ dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), preinterpret<vec_xmm>(reg)));
253
+ } else {
254
+ dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem)));
255
+ }
256
+ break;
257
+ case 32 * 1:
258
+ dst = preinterpret<vec>(padds(preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem)));
259
+ break;
260
+ }
261
+ }
262
+ }
263
+
264
+ EIGEN_STRONG_INLINE void vfmadd(vec &dst, const vec &src1, const vec &src2) {
265
+ dst = pmadd(src1, src2, dst);
266
+
267
+ #if (EIGEN_COMP_GNUC != 0) || (EIGEN_COMP_CLANG != 0)
268
+ // Workaround register spills for gcc and clang
269
+ __asm__("#" : [dst] "+v"(dst) : [src1] "%v"(src1), [src2] "v"(src2));
270
+ #endif
271
+ }
272
+
273
+ template <int nelems>
274
+ EIGEN_ALWAYS_INLINE void vfmaddm(vec &dst, const Scalar *mem, vec &src, vec &scale, vec &reg) {
275
+ if (is_unit_inc) {
276
+ switch (nelems * sizeof(*mem) * 8) {
277
+ default:
278
+ case 512 * 3:
279
+ dst = pmadd(scale, src, ploadu<vec>(mem));
280
+ break;
281
+ case 512 * 2:
282
+ dst = pmadd(scale, src, ploadu<vec>(mem));
283
+ break;
284
+ case 512 * 1:
285
+ dst = pmadd(scale, src, ploadu<vec>(mem));
286
+ break;
287
+ case 256 * 1:
288
+ dst =
289
+ preinterpret<vec>(pmadd(preinterpret<vec_ymm>(scale), preinterpret<vec_ymm>(src), ploadu<vec_ymm>(mem)));
290
+ break;
291
+ case 128 * 1:
292
+ dst =
293
+ preinterpret<vec>(pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploadu<vec_xmm>(mem)));
294
+ break;
295
+ case 64 * 1:
296
+ dst =
297
+ preinterpret<vec>(pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem)));
298
+ break;
299
+ case 32 * 1:
300
+ dst =
301
+ preinterpret<vec>(pmadds(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem)));
302
+ break;
303
+ }
304
+ } else {
305
+ // Zero out scratch register
306
+ reg = pzero(reg);
307
+
308
+ switch (nelems * sizeof(*mem) * 8) {
309
+ default:
310
+ case 512 * 3:
311
+ reg = pgather<Scalar, vec>(mem, inc);
312
+ dst = pmadd(scale, src, reg);
313
+ break;
314
+ case 512 * 2:
315
+ reg = pgather<Scalar, vec>(mem, inc);
316
+ dst = pmadd(scale, src, reg);
317
+ break;
318
+ case 512 * 1:
319
+ reg = pgather<Scalar, vec>(mem, inc);
320
+ dst = pmadd(scale, src, reg);
321
+ break;
322
+ case 256 * 1:
323
+ reg = preinterpret<vec>(pgather<Scalar, vec_ymm>(mem, inc));
324
+ dst = preinterpret<vec>(
325
+ pmadd(preinterpret<vec_ymm>(scale), preinterpret<vec_ymm>(src), preinterpret<vec_ymm>(reg)));
326
+ break;
327
+ case 128 * 1:
328
+ reg = preinterpret<vec>(pgather<Scalar, vec_xmm>(mem, inc));
329
+ dst = preinterpret<vec>(
330
+ pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), preinterpret<vec_xmm>(reg)));
331
+ break;
332
+ case 64 * 1:
333
+ if (is_f32) {
334
+ reg = pgather(reg, mem, inc, mask);
335
+ dst = preinterpret<vec>(
336
+ pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), preinterpret<vec_xmm>(reg)));
337
+ } else {
338
+ dst = preinterpret<vec>(
339
+ pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem)));
340
+ }
341
+ break;
342
+ case 32 * 1:
343
+ dst =
344
+ preinterpret<vec>(pmadds(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem)));
345
+ break;
346
+ }
347
+ }
348
+ }
349
+
350
+ template <int j, int endX, int i, int endY, int nelems>
351
+ EIGEN_ALWAYS_INLINE std::enable_if_t<(j > endX) || (i > endY)> a_loads(const Scalar *ao) {
352
+ EIGEN_UNUSED_VARIABLE(ao);
353
+ }
354
+
355
+ template <int j, int endX, int i, int endY, int nelems>
356
+ EIGEN_ALWAYS_INLINE std::enable_if_t<(j <= endX) && (i <= endY)> a_loads(const Scalar *ao) {
357
+ if (j < endX) {
358
+ if (i < endY) {
359
+ auto &a_reg = zmm[a_regs[i + (j % 2) * 3]];
360
+ const Scalar *a_addr = ao + nelems * j + nelems_in_cache_line * i - a_shift;
361
+ a_load<nelems>(a_reg, a_addr);
362
+
363
+ a_loads<j, endX, i + 1, endY, nelems>(ao);
364
+ } else {
365
+ a_loads<j + 1, endX, 0, endY, nelems>(ao);
366
+ }
367
+ }
368
+ }
369
+
370
+ template <int un, int max_b_unroll, int i, int um_vecs, int a_unroll, int b_unroll>
371
+ EIGEN_ALWAYS_INLINE std::enable_if_t<(un > max_b_unroll) || (i > um_vecs)> prefetch_cs(const Scalar *co1,
372
+ const Scalar *co2) {
373
+ EIGEN_UNUSED_VARIABLE(co1);
374
+ EIGEN_UNUSED_VARIABLE(co2);
375
+ }
376
+
377
+ /* C prefetch loop structure.
378
+ * for (int un = 0; un < 8; un++) {
379
+ * if (b_unroll >= un + 1) {
380
+ * if (un == 4) co2 = co1 + 4 * ldc;
381
+ *
382
+ * for (int i = 0; i < um_vecs; i++) {
383
+ * Scalar *co = (un + 1 <= 4) ? co1 : co2;
384
+ * auto co_off = (un % 4) * ldc + a_unroll - 1 + i * nelems_in_cache_line * sizeof *co;
385
+ * prefetch_c(co + co_off);
386
+ * }
387
+ * }
388
+ * }
389
+ */
390
+
391
+ template <int un, int max_b_unroll, int i, int um_vecs, int a_unroll, int b_unroll>
392
+ EIGEN_ALWAYS_INLINE std::enable_if_t<(un <= max_b_unroll) && (i <= um_vecs)> prefetch_cs(Scalar *&co1, Scalar *&co2) {
393
+ if (un < max_b_unroll) {
394
+ if (b_unroll >= un + 1) {
395
+ if (un == 4 && i == 0) co2 = co1 + 4 * ldc;
396
+
397
+ if (i < um_vecs) {
398
+ Scalar *co = (un + 1 <= 4) ? co1 : co2;
399
+ auto co_off = (un % 4) * ldc + a_unroll - 1 + i * nelems_in_cache_line * sizeof *co;
400
+ prefetch_c(co + co_off);
401
+
402
+ prefetch_cs<un, max_b_unroll, i + 1, um_vecs, a_unroll, b_unroll>(co1, co2);
403
+ } else {
404
+ prefetch_cs<un + 1, max_b_unroll, 0, um_vecs, a_unroll, b_unroll>(co1, co2);
405
+ }
406
+
407
+ } else {
408
+ prefetch_cs<un + 1, max_b_unroll, 0, um_vecs, a_unroll, b_unroll>(co1, co2);
409
+ }
410
+ }
411
+ }
412
+
413
+ // load_c
414
+ template <int i, int um_vecs, int idx, int nelems>
415
+ EIGEN_ALWAYS_INLINE std::enable_if_t<(i > um_vecs)> scale_load_c(const Scalar *cox, vec &alpha_reg) {
416
+ EIGEN_UNUSED_VARIABLE(cox);
417
+ EIGEN_UNUSED_VARIABLE(alpha_reg);
418
+ }
419
+
420
+ template <int i, int um_vecs, int idx, int nelems>
421
+ EIGEN_ALWAYS_INLINE std::enable_if_t<(i <= um_vecs)> scale_load_c(const Scalar *cox, vec &alpha_reg) {
422
+ if (i < um_vecs) {
423
+ auto &c_reg = zmm[c_regs[i + idx * 3]];
424
+ auto &c_load_reg = zmm[c_load_regs[i % 3]];
425
+ auto c_mem = cox;
426
+ if (is_unit_inc)
427
+ c_mem += i * nelems_in_cache_line;
428
+ else
429
+ c_mem += i * nelems_in_cache_line * inc;
430
+
431
+ if (!is_beta0 && is_alpha1)
432
+ vaddm<nelems>(c_reg, c_mem, c_reg, c_load_reg);
433
+ else if (!is_beta0 && !is_alpha1)
434
+ vfmaddm<nelems>(c_reg, c_mem, c_reg, alpha_reg, c_load_reg);
435
+ else if (is_beta0 && !is_alpha1)
436
+ c_reg = pmul(alpha_reg, c_reg);
437
+
438
+ scale_load_c<i + 1, um_vecs, idx, nelems>(cox, alpha_reg);
439
+ }
440
+ }
441
+
442
+ // store_c
443
+ template <int i, int um_vecs, int idx, int nelems>
444
+ EIGEN_ALWAYS_INLINE std::enable_if_t<(i > um_vecs)> write_c(Scalar *cox) {
445
+ EIGEN_UNUSED_VARIABLE(cox);
446
+ }
447
+
448
+ template <int i, int um_vecs, int idx, int nelems>
449
+ EIGEN_ALWAYS_INLINE std::enable_if_t<(i <= um_vecs)> write_c(Scalar *cox) {
450
+ if (i < um_vecs) {
451
+ auto &c_reg = zmm[c_regs[i + idx * 3]];
452
+ auto c_mem = cox;
453
+ if (is_unit_inc)
454
+ c_mem += i * nelems_in_cache_line;
455
+ else
456
+ c_mem += i * nelems_in_cache_line * inc;
457
+
458
+ c_store<nelems>(c_mem, c_reg);
459
+ c_reg = pzero(c_reg);
460
+
461
+ write_c<i + 1, um_vecs, idx, nelems>(cox);
462
+ }
463
+ }
464
+
465
+ /* C update loop structure.
466
+ * co2 = co1 + ldc;
467
+ *
468
+ * auto &alpha_reg = zmm[alpha_load_reg];
469
+ * if (!is_alpha1) alpha_reg = pload1<vec>(alpha);
470
+ *
471
+ * int idx = 0;
472
+ * for (pow = 1; pow <= 8; pow <<= 1) {
473
+ *
474
+ * if (b_unroll >= pow) {
475
+ * for (count = 1; count < (pow + 1) / 2 + 1; count++) {
476
+ * if (pow >= 4) co2 += ldc;
477
+ *
478
+ * const Scalar *cox = (idx == 0) ? co1 : co2;
479
+ *
480
+ * const int um_vecs = numext::div_ceil(a_unroll, nelems_in_cache_line);
481
+ * scale_load_c<0, um_vecs, idx, a_unroll>(cox, alpha_reg);
482
+ * write_c<0, um_vecs, idx, a_unroll>(cox);
483
+ *
484
+ * idx++;
485
+ * }
486
+ * }
487
+ * }
488
+ *
489
+ * if (b_unroll == 1)
490
+ * co1 += ldc;
491
+ * else
492
+ * co1 = co2 + ldc;
493
+ */
494
+
495
+ template <int pow, int a_unroll, int idx>
496
+ EIGEN_ALWAYS_INLINE void c_update_1count(Scalar *&cox) {
497
+ if (pow >= 4) cox += ldc;
498
+
499
+ const int um_vecs = numext::div_ceil(a_unroll, nelems_in_cache_line);
500
+ auto &alpha_reg = zmm[alpha_load_reg];
501
+
502
+ scale_load_c<0, um_vecs, idx, a_unroll>(cox, alpha_reg);
503
+ write_c<0, um_vecs, idx, a_unroll>(cox);
504
+ }
505
+
506
+ template <int pow, int a_unroll>
507
+ EIGEN_ALWAYS_INLINE void c_update_1pow(Scalar *&co1, Scalar *&co2) {
508
+ constexpr int idx = pow / 2;
509
+ Scalar *&cox = idx == 0 ? co1 : co2;
510
+
511
+ constexpr int max_count = (pow + 1) / 2;
512
+ static_assert(max_count <= 4, "Unsupported max_count.");
513
+
514
+ if (1 <= max_count) c_update_1count<pow, a_unroll, idx + 0>(cox);
515
+ if (2 <= max_count) c_update_1count<pow, a_unroll, idx + 1>(cox);
516
+ if (3 <= max_count) c_update_1count<pow, a_unroll, idx + 2>(cox);
517
+ if (4 <= max_count) c_update_1count<pow, a_unroll, idx + 3>(cox);
518
+ }
519
+
520
+ template <int max_b_unroll, int a_unroll, int b_unroll>
521
+ EIGEN_ALWAYS_INLINE void c_update(Scalar *&co1, Scalar *&co2) {
522
+ auto &alpha_reg = zmm[alpha_load_reg];
523
+
524
+ co2 = co1 + ldc;
525
+ if (!is_alpha1) alpha_reg = pload1<vec>(alpha);
526
+ if (!is_unit_inc && a_unroll < nelems_in_cache_line) mask = static_cast<umask_t>((1ull << a_unroll) - 1);
527
+
528
+ static_assert(max_b_unroll <= 8, "Unsupported max_b_unroll");
529
+
530
+ if (1 <= max_b_unroll && 1 <= b_unroll) c_update_1pow<1, a_unroll>(co1, co2);
531
+ if (2 <= max_b_unroll && 2 <= b_unroll) c_update_1pow<2, a_unroll>(co1, co2);
532
+ if (4 <= max_b_unroll && 4 <= b_unroll) c_update_1pow<4, a_unroll>(co1, co2);
533
+ if (8 <= max_b_unroll && 8 <= b_unroll) c_update_1pow<8, a_unroll>(co1, co2);
534
+
535
+ if (b_unroll == 1)
536
+ co1 += ldc;
537
+ else
538
+ co1 = co2 + ldc;
539
+ }
540
+
541
+ // compute
542
+ template <int um, int um_vecs, int idx, int uk, bool fetch_x, bool ktail>
543
+ EIGEN_ALWAYS_INLINE std::enable_if_t<(um > um_vecs)> compute(const Scalar *ao, const Scalar *bo, int &fetchA_idx,
544
+ int &fetchB_idx, vec &b_reg) {
545
+ EIGEN_UNUSED_VARIABLE(ao);
546
+ EIGEN_UNUSED_VARIABLE(bo);
547
+ EIGEN_UNUSED_VARIABLE(fetchA_idx);
548
+ EIGEN_UNUSED_VARIABLE(fetchB_idx);
549
+ EIGEN_UNUSED_VARIABLE(b_reg);
550
+ }
551
+
552
+ template <int um, int um_vecs, int idx, int uk, bool fetch_x, bool ktail>
553
+ EIGEN_ALWAYS_INLINE std::enable_if_t<(um <= um_vecs)> compute(const Scalar *ao, const Scalar *bo, int &fetchA_idx,
554
+ int &fetchB_idx, vec &b_reg) {
555
+ if (um < um_vecs) {
556
+ auto &c_reg = zmm[c_regs[um + idx * 3]];
557
+ auto &a_reg = zmm[a_regs[um + (uk % 2) * 3]];
558
+
559
+ vfmadd(c_reg, a_reg, b_reg);
560
+
561
+ if (!fetch_x && um == 0 &&
562
+ (((idx == 0 || idx == 6) && (uk % 2 == 0 || is_f64 || ktail)) ||
563
+ (idx == 3 && (uk % 2 == 1 || is_f64 || ktail)))) {
564
+ prefetch_a(ao + nelems_in_cache_line * fetchA_idx);
565
+ fetchA_idx++;
566
+ }
567
+
568
+ if (um == 0 && idx == 1 && (uk % 2 == 0 || is_f64 || ktail)) {
569
+ prefetch_b(bo + nelems_in_cache_line * fetchB_idx);
570
+ fetchB_idx++;
571
+ }
572
+
573
+ compute<um + 1, um_vecs, idx, uk, fetch_x, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg);
574
+ }
575
+ }
576
+
577
+ // load_a
578
+ template <int um, int um_vecs, int uk, int nelems, bool ktail>
579
+ EIGEN_ALWAYS_INLINE std::enable_if_t<(um > um_vecs)> load_a(const Scalar *ao) {
580
+ EIGEN_UNUSED_VARIABLE(ao);
581
+ }
582
+
583
+ template <int um, int um_vecs, int uk, int nelems, bool ktail>
584
+ EIGEN_ALWAYS_INLINE std::enable_if_t<(um <= um_vecs)> load_a(const Scalar *ao) {
585
+ if (um < um_vecs) {
586
+ auto &a_reg = zmm[a_regs[um + (uk % 2) * 3]];
587
+ const Scalar *a_addr = ao + nelems * (1 + !ktail * !use_less_a_regs + uk) + nelems_in_cache_line * um - a_shift;
588
+ a_load<nelems>(a_reg, a_addr);
589
+
590
+ load_a<um + 1, um_vecs, uk, nelems, ktail>(ao);
591
+ }
592
+ }
593
+ template <int uk, int pow, int count, int um_vecs, int b_unroll, bool ktail, bool fetch_x, bool c_fetch>
594
+ EIGEN_ALWAYS_INLINE std::enable_if_t<(count > (pow + 1) / 2)> innerkernel_1pow(const Scalar *&aa,
595
+ const Scalar *const &ao,
596
+ const Scalar *const &bo, Scalar *&co2,
597
+ int &fetchA_idx, int &fetchB_idx) {
598
+ EIGEN_UNUSED_VARIABLE(aa);
599
+ EIGEN_UNUSED_VARIABLE(ao);
600
+ EIGEN_UNUSED_VARIABLE(bo);
601
+ EIGEN_UNUSED_VARIABLE(co2);
602
+ EIGEN_UNUSED_VARIABLE(fetchA_idx);
603
+ EIGEN_UNUSED_VARIABLE(fetchB_idx);
604
+ }
605
+
606
+ template <int uk, int pow, int count, int um_vecs, int b_unroll, bool ktail, bool fetch_x, bool c_fetch>
607
+ EIGEN_ALWAYS_INLINE std::enable_if_t<(count <= (pow + 1) / 2)> innerkernel_1pow(const Scalar *&aa,
608
+ const Scalar *const &ao,
609
+ const Scalar *const &bo, Scalar *&co2,
610
+ int &fetchA_idx, int &fetchB_idx) {
611
+ const int idx = (pow / 2) + count;
612
+
613
+ if (count < (pow + 1) / 2) {
614
+ auto &b_reg = zmm[b_regs[idx % 2]];
615
+
616
+ if (fetch_x && uk == 3 && idx == 0) prefetch_x(aa);
617
+ if (fetch_x && uk == 3 && idx == 4) aa += 8;
618
+
619
+ if (b_unroll >= pow) {
620
+ compute<0, um_vecs, idx, uk, fetch_x, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg);
621
+
622
+ const Scalar *b_addr = bo + b_unroll * uk + idx + 1 + (b_unroll > 1) * !use_less_b_regs - b_shift;
623
+ b_load(b_reg, b_addr);
624
+ }
625
+
626
+ // Go to the next count.
627
+ innerkernel_1pow<uk, pow, count + 1, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx,
628
+ fetchB_idx);
629
+
630
+ } else {
631
+ // Maybe prefetch C data after count-loop.
632
+ if (pow == 2 && c_fetch) {
633
+ if (uk % 3 == 0 && uk > 0) {
634
+ co2 += ldc;
635
+ } else {
636
+ prefetch_c(co2 + (uk % 3) * nelems_in_cache_line);
637
+ }
638
+ }
639
+ }
640
+ }
641
+
642
+ template <int uk, int max_b_unroll, int a_unroll, int b_unroll, bool ktail, bool fetch_x, bool c_fetch,
643
+ bool no_a_preload = false>
644
+ EIGEN_ALWAYS_INLINE void innerkernel_1uk(const Scalar *&aa, const Scalar *const &ao, const Scalar *const &bo,
645
+ Scalar *&co2, int &fetchA_idx, int &fetchB_idx) {
646
+ const int um_vecs = numext::div_ceil(a_unroll, nelems_in_cache_line);
647
+
648
+ if (max_b_unroll >= 1)
649
+ innerkernel_1pow<uk, 1, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
650
+ if (max_b_unroll >= 2)
651
+ innerkernel_1pow<uk, 2, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
652
+ if (max_b_unroll >= 4)
653
+ innerkernel_1pow<uk, 4, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
654
+ if (max_b_unroll >= 8)
655
+ innerkernel_1pow<uk, 8, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
656
+
657
+ // Load A after pow-loop. Skip this at the end to prevent running over the buffer
658
+ if (!no_a_preload) load_a<0, um_vecs, uk, a_unroll, ktail>(ao);
659
+ }
660
+
661
+ /* Inner kernel loop structure.
662
+ * for (int uk = 0; uk < kfactor; uk++) {
663
+ * int idx = 0;
664
+ *
665
+ * for (pow = 1; pow < max_b_unroll << 1; pow <<= 1) {
666
+ * for (int count = 0; count < (pow + 1) / 2; count++) {
667
+ * auto &b_reg = zmm[b_regs[idx % 2]];
668
+ *
669
+ * if (fetch_x && uk == 3 && idx == 0) prefetch_x(aa);
670
+ * if (fetch_x && uk == 3 && idx == 4) aa += 8;
671
+ *
672
+ * if (b_unroll >= pow) {
673
+ * compute<0, um_vecs, idx, uk, fetchx, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg);
674
+ *
675
+ * const Scalar *b_addr = bo + b_unroll * uk + idx + 1 + (b_unroll > 1) - b_shift ;
676
+ * b_load(b_reg, b_addr);
677
+ * }
678
+ * idx++;
679
+ * }
680
+ *
681
+ * Maybe prefetch C data.
682
+ * if (pow == 2 && c_fetch) {
683
+ * if (uk % 3 == 0 && uk > 0) {
684
+ * co2 += ldc;
685
+ * } else {
686
+ * prefetch_c(co2 + (uk % 3) * nelems_in_cache_line);
687
+ * }
688
+ * }
689
+ * }
690
+ *
691
+ * Load A.
692
+ * load_a<0, um_vecs, uk, ktail, a_unroll>(ao);
693
+ * }
694
+ *
695
+ * Advance A/B pointers after uk-loop.
696
+ * ao += a_unroll * kfactor;
697
+ * bo += b_unroll * kfactor;
698
+ */
699
+
700
+ template <int a_unroll, int b_unroll, int k_factor, int max_b_unroll, int max_k_factor, bool c_fetch,
701
+ bool no_a_preload = false>
702
+ EIGEN_ALWAYS_INLINE void innerkernel(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co2) {
703
+ int fetchA_idx = 0;
704
+ int fetchB_idx = 0;
705
+
706
+ const bool fetch_x = k_factor == max_k_factor;
707
+ const bool ktail = k_factor == 1;
708
+
709
+ static_assert(k_factor <= 4 && k_factor > 0, "innerkernel maximum k_factor supported is 4");
710
+ static_assert(no_a_preload == false || (no_a_preload == true && k_factor == 1),
711
+ "skipping a preload only allowed when k unroll is 1");
712
+
713
+ if (k_factor > 0)
714
+ innerkernel_1uk<0, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(
715
+ aa, ao, bo, co2, fetchA_idx, fetchB_idx);
716
+ if (k_factor > 1)
717
+ innerkernel_1uk<1, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(
718
+ aa, ao, bo, co2, fetchA_idx, fetchB_idx);
719
+ if (k_factor > 2)
720
+ innerkernel_1uk<2, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(
721
+ aa, ao, bo, co2, fetchA_idx, fetchB_idx);
722
+ if (k_factor > 3)
723
+ innerkernel_1uk<3, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(
724
+ aa, ao, bo, co2, fetchA_idx, fetchB_idx);
725
+
726
+ // Advance A/B pointers after uk-loop.
727
+ ao += a_unroll * k_factor;
728
+ bo += b_unroll * k_factor;
729
+ }
730
+
731
+ template <int a_unroll, int b_unroll, int max_b_unroll>
732
+ EIGEN_ALWAYS_INLINE void kloop(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) {
733
+ const int um_vecs = numext::div_ceil(a_unroll, nelems_in_cache_line);
734
+ if (!use_less_a_regs && k > 1)
735
+ a_loads<0, 2, 0, um_vecs, a_unroll>(ao);
736
+ else
737
+ a_loads<0, 1, 0, um_vecs, a_unroll>(ao);
738
+
739
+ b_load(zmm[b_regs[0]], bo - b_shift + 0);
740
+ if (!use_less_b_regs) b_load(zmm[b_regs[1]], bo - b_shift + 1);
741
+
742
+ #ifndef SECOND_FETCH
743
+ prefetch_cs<0, max_b_unroll, 0, um_vecs, a_unroll, b_unroll>(co1, co2);
744
+ #endif // SECOND_FETCH
745
+
746
+ // Unrolling k-loop by a factor of 4.
747
+ const int max_k_factor = 4;
748
+ Index kRem = k % max_k_factor;
749
+ Index k_ = k - kRem;
750
+ if (k_ >= max_k_factor) {
751
+ k_ -= max_k_factor;
752
+ kRem += max_k_factor;
753
+ }
754
+ Index loop_count = k_ / max_k_factor;
755
+
756
+ if (loop_count > 0) {
757
+ #ifdef SECOND_FETCH
758
+ loop_count -= SECOND_FETCH;
759
+ #endif
760
+ while (loop_count > 0) {
761
+ innerkernel<a_unroll, b_unroll, max_k_factor, max_b_unroll, max_k_factor, 0>(aa, ao, bo, co2);
762
+ loop_count--;
763
+ }
764
+ #ifdef SECOND_FETCH
765
+ co2 = co1 + nelems_in_cache_line - 1;
766
+
767
+ loop_count += b_unroll;
768
+ while (loop_count > 0) {
769
+ innerkernel<a_unroll, b_unroll, max_k_factor, max_b_unroll, max_k_factor, 1>(aa, ao, bo, co2);
770
+ loop_count--;
771
+ }
772
+
773
+ loop_count += SECOND_FETCH - b_unroll;
774
+ while (loop_count > 0) {
775
+ innerkernel<a_unroll, b_unroll, max_k_factor, max_b_unroll, max_k_factor, 0>(aa, ao, bo, co2);
776
+ loop_count--;
777
+ }
778
+ #endif
779
+ }
780
+
781
+ // k-loop remainder handling.
782
+ loop_count = kRem;
783
+ while (loop_count > 1) {
784
+ innerkernel<a_unroll, b_unroll, 1, max_b_unroll, max_k_factor, 0>(aa, ao, bo, co2);
785
+ loop_count--;
786
+ }
787
+ if (loop_count > 0) {
788
+ innerkernel<a_unroll, b_unroll, 1, max_b_unroll, max_k_factor, 0, true>(aa, ao, bo, co2);
789
+ }
790
+
791
+ // Update C matrix.
792
+ c_update<max_b_unroll, a_unroll, b_unroll>(co1, co2);
793
+ }
794
+
795
+ template <int a_unroll, int b_unroll, int max_b_unroll>
796
+ EIGEN_ALWAYS_INLINE void nloop(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) {
797
+ // Set A matrix pointer.
798
+ ao = a + a_off * a_unroll;
799
+
800
+ // Set B matrix pointer if needed.
801
+ bo += b_unroll * b_off;
802
+
803
+ kloop<a_unroll, b_unroll, max_b_unroll>(aa, ao, bo, co1, co2);
804
+
805
+ // Advance B matrix pointer if needed.
806
+ bo += b_unroll * (b_stride - k - b_off);
807
+
808
+ // Advance prefetch A pointer.
809
+ aa += 16;
810
+ }
811
+
812
+ template <int a_unroll, int max_a_unroll, int max_b_unroll>
813
+ EIGEN_ALWAYS_INLINE void mloop(const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) {
814
+ // Set prefetch A pointers.
815
+ const Scalar *aa = a + a_unroll * a_stride;
816
+
817
+ // Set C matrix pointers.
818
+ co1 = c;
819
+ if (a_unroll >= max_a_unroll) co2 = c + 2 * ldc;
820
+ if (is_unit_inc)
821
+ c += a_unroll;
822
+ else
823
+ c += a_unroll * inc;
824
+
825
+ // Set B matrix pointer.
826
+ bo = b;
827
+
828
+ // Main n-loop.
829
+ for (Index i = n / max_b_unroll; i > 0; i--) nloop<a_unroll, max_b_unroll, max_b_unroll>(aa, ao, bo, co1, co2);
830
+
831
+ // n-remainders.
832
+ if (n & 4 && max_b_unroll > 4) nloop<a_unroll, 4, max_b_unroll>(aa, ao, bo, co1, co2);
833
+ #if 0
834
+ if (n & 2 && max_b_unroll > 2) nloop<a_unroll, 2, max_b_unroll>(aa, ao, bo, co1, co2);
835
+ if (n & 1 && max_b_unroll > 1) nloop<a_unroll, 1, max_b_unroll>(aa, ao, bo, co1, co2);
836
+ #else
837
+ // Copy kernels don't support tails of n = 2 for single/double precision.
838
+ // Loop over ones.
839
+ int n_rem = 2 * ((n & 2) != 0) + 1 * ((n & 1) != 0);
840
+ while (n_rem > 0) {
841
+ nloop<a_unroll, 1, max_b_unroll>(aa, ao, bo, co1, co2);
842
+ n_rem--;
843
+ }
844
+ #endif
845
+
846
+ // Advance A matrix pointer.
847
+ a = ao + a_unroll * (a_stride - k - a_off);
848
+ }
849
+
850
+ public:
851
+ // Compute kernel unrolling C matrix by max_a_unroll x max_b_unroll.
852
+ template <int max_a_unroll, int max_b_unroll>
853
+ EIGEN_ALWAYS_INLINE void compute_kern() {
854
+ a -= -a_shift;
855
+ b -= -b_shift;
856
+
857
+ const Scalar *ao = nullptr;
858
+ const Scalar *bo = nullptr;
859
+ Scalar *co1 = nullptr;
860
+ Scalar *co2 = nullptr;
861
+
862
+ // Main m-loop.
863
+ for (; m >= max_a_unroll; m -= max_a_unroll) mloop<max_a_unroll, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
864
+
865
+ // m-remainders.
866
+ if (m & 32 && max_a_unroll > 32) mloop<32, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
867
+ if (m & 16 && max_a_unroll > 16) mloop<16, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
868
+ if (m & 8 && max_a_unroll > 8) mloop<8, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
869
+ if (m & 4 && max_a_unroll > 4) mloop<4, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
870
+ if (m & 2 && max_a_unroll > 2 && is_f64) mloop<2, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
871
+ if (m & 1 && max_a_unroll > 1 && is_f64) mloop<1, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
872
+
873
+ // Copy kernels don't support tails of m = 2 for single precision.
874
+ // Loop over ones.
875
+ if (is_f32) {
876
+ int m_rem = 2 * ((m & 2) != 0) + 1 * ((m & 1) != 0);
877
+ while (m_rem > 0) {
878
+ mloop<1, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
879
+ m_rem--;
880
+ }
881
+ }
882
+ }
883
+
884
  // Captures the GEMM operands and geometry by value and clears the
  // accumulator registers. No computation happens here; the caller invokes
  // compute_kern() afterwards (see gemm_kern_avx512 below).
  //
  //   m_, n_, k_           : problem dimensions
  //   ldc_, inc_           : leading dimension and element increment of C
  //   alpha_               : pointer to the scale factor applied to A*B
  //   a_, b_, c_           : packed A panel, packed B panel, output C
  //   is_alpha1_           : fast-path flag for alpha == 1
  //   is_beta0_            : flag for beta == 0 (presumably: C overwritten
  //                          rather than accumulated — confirm in mloop)
  //   a_stride_, b_stride_ : panel strides of the packed A/B blocks
  //   a_off_, b_off_       : starting offsets into the packed A/B blocks
  gemm_class(Index m_, Index n_, Index k_, Index ldc_, Index inc_, const Scalar *alpha_, const Scalar *a_,
             const Scalar *b_, Scalar *c_, bool is_alpha1_, bool is_beta0_, Index a_stride_, Index b_stride_,
             Index a_off_, Index b_off_)
      : m(m_),
        n(n_),
        k(k_),
        ldc(ldc_),
        inc(inc_),
        alpha(alpha_),
        a(a_),
        b(b_),
        c(c_),
        is_alpha1(is_alpha1_),
        is_beta0(is_beta0_),
        a_stride(a_stride_),
        b_stride(b_stride_),
        a_off(a_off_),
        b_off(b_off_) {
    // Zero out all accumulation registers.
    // Only zmm[8]..zmm[31] are cleared; zmm[0..7] are presumably operand /
    // scratch slots that do not accumulate — confirm against the kernel body
    // (outside this view). Deliberately unrolled so each pzero maps to a
    // distinct register without relying on the compiler to unroll a loop.
    zmm[8] = pzero(zmm[8]);
    zmm[9] = pzero(zmm[9]);
    zmm[10] = pzero(zmm[10]);
    zmm[11] = pzero(zmm[11]);
    zmm[12] = pzero(zmm[12]);
    zmm[13] = pzero(zmm[13]);
    zmm[14] = pzero(zmm[14]);
    zmm[15] = pzero(zmm[15]);
    zmm[16] = pzero(zmm[16]);
    zmm[17] = pzero(zmm[17]);
    zmm[18] = pzero(zmm[18]);
    zmm[19] = pzero(zmm[19]);
    zmm[20] = pzero(zmm[20]);
    zmm[21] = pzero(zmm[21]);
    zmm[22] = pzero(zmm[22]);
    zmm[23] = pzero(zmm[23]);
    zmm[24] = pzero(zmm[24]);
    zmm[25] = pzero(zmm[25]);
    zmm[26] = pzero(zmm[26]);
    zmm[27] = pzero(zmm[27]);
    zmm[28] = pzero(zmm[28]);
    zmm[29] = pzero(zmm[29]);
    zmm[30] = pzero(zmm[30]);
    zmm[31] = pzero(zmm[31]);
  }
928
+ };
929
+
930
+ // Compute kernel with max unroll support of:
931
+ // Single precision:
932
+ // max_a_unroll: 48, 32, 16, 8, 4, 2, 1
933
+ // max_b_unroll: 8, 4, 2, 1
934
+ // Double precision:
935
+ // max_a_unroll: 24, 16, 8, 4, 2, 1
936
+ // max_b_unroll: 8, 4, 2, 1
937
+ template <typename Scalar, int max_a_unroll, int max_b_unroll, bool is_alpha1, bool is_beta0, bool is_unit_inc>
938
+ EIGEN_DONT_INLINE void gemm_kern_avx512(Index m, Index n, Index k, Scalar *alpha, const Scalar *a, const Scalar *b,
939
+ Scalar *c, Index ldc, Index inc = 1, Index a_stride = -1, Index b_stride = -1,
940
+ Index a_off = 0, Index b_off = 0) {
941
+ if (a_stride == -1) a_stride = k;
942
+ if (b_stride == -1) b_stride = k;
943
+
944
+ gemm_class<Scalar, is_unit_inc> g(m, n, k, ldc, inc, alpha, a, b, c, is_alpha1, is_beta0, a_stride, b_stride, a_off,
945
+ b_off);
946
+ g.template compute_kern<max_a_unroll, max_b_unroll>();
947
+ }
948
+
949
+ // Template specializations of GEBP kernels with nr = 8.
950
+ #if EIGEN_USE_AVX512_GEMM_KERNELS
951
+ template <bool ConjLhs_, bool ConjRhs_, int PacketSize_>
952
+ class gebp_traits<float, float, ConjLhs_, ConjRhs_, Architecture::Target, PacketSize_>
953
+ : public gebp_traits<float, float, ConjLhs_, ConjRhs_, Architecture::Generic, PacketSize_> {
954
+ using Base = gebp_traits<float, float, ConjLhs_, ConjRhs_, Architecture::Generic, PacketSize_>;
955
+
956
+ public:
957
+ enum { nr = Base::Vectorizable ? 8 : 4 };
958
+ };
959
+
960
+ template <bool ConjLhs_, bool ConjRhs_, int PacketSize_>
961
+ class gebp_traits<double, double, ConjLhs_, ConjRhs_, Architecture::Target, PacketSize_>
962
+ : public gebp_traits<double, double, ConjLhs_, ConjRhs_, Architecture::Generic, PacketSize_> {
963
+ using Base = gebp_traits<double, double, ConjLhs_, ConjRhs_, Architecture::Generic, PacketSize_>;
964
+
965
+ public:
966
+ enum { nr = Base::Vectorizable ? 8 : 4 };
967
+ };
968
+
969
// Column-major RHS packing for the nr == 8 kernels: copies 8-column panels of
// the RHS into the contiguous blockB layout (transposed via ptranspose in the
// out-of-line definition below so the 8 values of one k index sit together).
template <typename Scalar, typename Index, typename DataMapper, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<Scalar, Index, DataMapper, 8, ColMajor, Conjugate, PanelMode> {
  typedef typename packet_traits<Scalar>::type Packet;
  typedef typename DataMapper::LinearMapper LinearMapper;
  enum { PacketSize = packet_traits<Scalar>::size };
  // stride/offset are only meaningful in PanelMode (see the assert in the
  // definition); defined out-of-line below.
  EIGEN_DONT_INLINE void operator()(Scalar *blockB, const DataMapper &rhs, Index depth, Index cols, Index stride = 0,
                                    Index offset = 0);
};
977
+
978
// Packs a column-major RHS into blockB for the nr == 8 kernel: 8-column
// panels first (vector-transposed per PacketSize rows of k), then 4-column
// panels, then leftover columns one at a time. `count` tracks the write
// position in blockB throughout; in PanelMode it also skips the panel
// head/tail regions delimited by offset and stride.
template <typename Scalar, typename Index, typename DataMapper, bool Conjugate, bool PanelMode>
EIGEN_DONT_INLINE void gemm_pack_rhs<Scalar, Index, DataMapper, 8, ColMajor, Conjugate, PanelMode>::operator()(
    Scalar *blockB, const DataMapper &rhs, Index depth, Index cols, Index stride, Index offset) {
  constexpr int nr = 8;
  EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK RHS COLMAJOR");
  EIGEN_UNUSED_VARIABLE(stride);
  EIGEN_UNUSED_VARIABLE(offset);
  eigen_assert(((!PanelMode) && stride == 0 && offset == 0) || (PanelMode && stride >= depth && offset <= stride));
  // Conjugation is a no-op unless Scalar is complex and Conjugate is set.
  conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
  Index packet_cols8 = nr >= 8 ? (cols / 8) * 8 : 0;
  Index packet_cols4 = nr >= 4 ? (cols / 4) * 4 : 0;
  Index count = 0;
  // Portion of depth divisible by PacketSize, handled with vector transposes.
  const Index peeled_k = (depth / PacketSize) * PacketSize;
  if (nr >= 8) {
    for (Index j2 = 0; j2 < packet_cols8; j2 += 8) {
      // skip what we have before
      if (PanelMode) count += 8 * offset;
      // One linear mapper per column of the 8-wide panel.
      const LinearMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
      const LinearMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
      const LinearMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
      const LinearMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
      const LinearMapper dm4 = rhs.getLinearMapper(0, j2 + 4);
      const LinearMapper dm5 = rhs.getLinearMapper(0, j2 + 5);
      const LinearMapper dm6 = rhs.getLinearMapper(0, j2 + 6);
      const LinearMapper dm7 = rhs.getLinearMapper(0, j2 + 7);
      Index k = 0;
      if ((PacketSize % 8) == 0)  // TODO enable vectorized transposition for PacketSize==4
      {
        for (; k < peeled_k; k += PacketSize) {
          PacketBlock<Packet, (PacketSize % 8) == 0 ? 8 : PacketSize> kernel;

          kernel.packet[0] = dm0.template loadPacket<Packet>(k);
          kernel.packet[1] = dm1.template loadPacket<Packet>(k);
          kernel.packet[2] = dm2.template loadPacket<Packet>(k);
          kernel.packet[3] = dm3.template loadPacket<Packet>(k);
          kernel.packet[4] = dm4.template loadPacket<Packet>(k);
          kernel.packet[5] = dm5.template loadPacket<Packet>(k);
          kernel.packet[6] = dm6.template loadPacket<Packet>(k);
          kernel.packet[7] = dm7.template loadPacket<Packet>(k);

          ptranspose(kernel);

          // The "% PacketSize" indices are identity on this path (the guard
          // ensures PacketSize >= 8); they only keep the constant indices in
          // range for instantiations where PacketSize < 8, which never
          // execute this branch.
          pstoreu(blockB + count + 0 * PacketSize, cj.pconj(kernel.packet[0]));
          pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel.packet[1 % PacketSize]));
          pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel.packet[2 % PacketSize]));
          pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel.packet[3 % PacketSize]));
          pstoreu(blockB + count + 4 * PacketSize, cj.pconj(kernel.packet[4 % PacketSize]));
          pstoreu(blockB + count + 5 * PacketSize, cj.pconj(kernel.packet[5 % PacketSize]));
          pstoreu(blockB + count + 6 * PacketSize, cj.pconj(kernel.packet[6 % PacketSize]));
          pstoreu(blockB + count + 7 * PacketSize, cj.pconj(kernel.packet[7 % PacketSize]));
          count += 8 * PacketSize;
        }
      }
      // Scalar tail of the k dimension (or whole depth when not vectorized).
      for (; k < depth; k++) {
        blockB[count + 0] = cj(dm0(k));
        blockB[count + 1] = cj(dm1(k));
        blockB[count + 2] = cj(dm2(k));
        blockB[count + 3] = cj(dm3(k));
        blockB[count + 4] = cj(dm4(k));
        blockB[count + 5] = cj(dm5(k));
        blockB[count + 6] = cj(dm6(k));
        blockB[count + 7] = cj(dm7(k));
        count += 8;
      }
      // skip what we have after
      if (PanelMode) count += 8 * (stride - offset - depth);
    }
  }

  if (nr >= 4) {
    for (Index j2 = packet_cols8; j2 < packet_cols4; j2 += 4) {
      // skip what we have before
      if (PanelMode) count += 4 * offset;
      const LinearMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
      const LinearMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
      const LinearMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
      const LinearMapper dm3 = rhs.getLinearMapper(0, j2 + 3);

      Index k = 0;
      if ((PacketSize % 4) == 0)  // TODO enable vectorized transposition for PacketSize==2 ??
      {
        for (; k < peeled_k; k += PacketSize) {
          // As above, "% PacketSize" keeps indices in range for the
          // never-executed small-packet instantiations.
          PacketBlock<Packet, (PacketSize % 4) == 0 ? 4 : PacketSize> kernel;
          kernel.packet[0] = dm0.template loadPacket<Packet>(k);
          kernel.packet[1 % PacketSize] = dm1.template loadPacket<Packet>(k);
          kernel.packet[2 % PacketSize] = dm2.template loadPacket<Packet>(k);
          kernel.packet[3 % PacketSize] = dm3.template loadPacket<Packet>(k);
          ptranspose(kernel);
          pstoreu(blockB + count + 0 * PacketSize, cj.pconj(kernel.packet[0]));
          pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel.packet[1 % PacketSize]));
          pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel.packet[2 % PacketSize]));
          pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel.packet[3 % PacketSize]));
          count += 4 * PacketSize;
        }
      }
      for (; k < depth; k++) {
        blockB[count + 0] = cj(dm0(k));
        blockB[count + 1] = cj(dm1(k));
        blockB[count + 2] = cj(dm2(k));
        blockB[count + 3] = cj(dm3(k));
        count += 4;
      }
      // skip what we have after
      if (PanelMode) count += 4 * (stride - offset - depth);
    }
  }

  // copy the remaining columns one at a time (nr==1)
  for (Index j2 = packet_cols4; j2 < cols; ++j2) {
    if (PanelMode) count += offset;
    const LinearMapper dm0 = rhs.getLinearMapper(0, j2);
    for (Index k = 0; k < depth; k++) {
      blockB[count] = cj(dm0(k));
      count += 1;
    }
    if (PanelMode) count += (stride - offset - depth);
  }
}
1096
+
1097
// Row-major RHS packing for the nr == 8 kernels. Here the 8 (or 4) values of
// a panel row are already contiguous in memory, so packing is a plain
// (possibly conjugated) copy — no transpose needed. Full, half, and quarter
// packets are tried in turn so the widest available load matching the panel
// width is used.
template <typename Scalar, typename Index, typename DataMapper, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<Scalar, Index, DataMapper, 8, RowMajor, Conjugate, PanelMode> {
  typedef typename packet_traits<Scalar>::type Packet;
  typedef typename unpacket_traits<Packet>::half HalfPacket;
  typedef typename unpacket_traits<typename unpacket_traits<Packet>::half>::half QuarterPacket;
  typedef typename DataMapper::LinearMapper LinearMapper;
  enum {
    PacketSize = packet_traits<Scalar>::size,
    HalfPacketSize = unpacket_traits<HalfPacket>::size,
    QuarterPacketSize = unpacket_traits<QuarterPacket>::size
  };
  // stride/offset are only meaningful in PanelMode (see the assert below).
  EIGEN_DONT_INLINE void operator()(Scalar *blockB, const DataMapper &rhs, Index depth, Index cols, Index stride = 0,
                                    Index offset = 0) {
    constexpr int nr = 8;
    EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK RHS ROWMAJOR");
    EIGEN_UNUSED_VARIABLE(stride);
    EIGEN_UNUSED_VARIABLE(offset);
    eigen_assert(((!PanelMode) && stride == 0 && offset == 0) || (PanelMode && stride >= depth && offset <= stride));
    // True when the half/quarter packet types are genuinely narrower (i.e.
    // distinct) rather than aliases of the full packet.
    const bool HasHalf = (int)HalfPacketSize < (int)PacketSize;
    const bool HasQuarter = (int)QuarterPacketSize < (int)HalfPacketSize;
    // Conjugation is a no-op unless Scalar is complex and Conjugate is set.
    conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
    Index packet_cols8 = nr >= 8 ? (cols / 8) * 8 : 0;
    Index packet_cols4 = nr >= 4 ? (cols / 4) * 4 : 0;
    Index count = 0;

    if (nr >= 8) {
      for (Index j2 = 0; j2 < packet_cols8; j2 += 8) {
        // skip what we have before
        if (PanelMode) count += 8 * offset;
        for (Index k = 0; k < depth; k++) {
          // Pick the widest load that covers exactly 8 elements.
          if (PacketSize == 8) {
            // Packet A = ploadu<Packet>(&rhs.data()[k*rhs.stride() + j2]);
            Packet A = rhs.template loadPacket<Packet>(k, j2);
            pstoreu(blockB + count, cj.pconj(A));
          } else if (HasHalf && HalfPacketSize == 8) {
            HalfPacket A = rhs.template loadPacket<HalfPacket>(k, j2);
            pstoreu(blockB + count, cj.pconj(A));
          } else if (HasQuarter && QuarterPacketSize == 8) {
            QuarterPacket A = rhs.template loadPacket<QuarterPacket>(k, j2);
            pstoreu(blockB + count, cj.pconj(A));
          } else if (PacketSize == 4) {
            // Two 4-wide loads cover the 8-column row.
            // Packet A = ploadu<Packet>(&rhs.data()[k*rhs.stride() + j2]);
            // Packet B = ploadu<Packet>(&rhs.data()[k*rhs.stride() + j2 + PacketSize]);
            Packet A = rhs.template loadPacket<Packet>(k, j2);
            Packet B = rhs.template loadPacket<Packet>(k, j2 + PacketSize);
            pstoreu(blockB + count, cj.pconj(A));
            pstoreu(blockB + count + PacketSize, cj.pconj(B));
          } else {
            // Scalar fallback.
            // const Scalar* b0 = &rhs.data()[k*rhs.stride() + j2];
            const LinearMapper dm0 = rhs.getLinearMapper(k, j2);
            blockB[count + 0] = cj(dm0(0));
            blockB[count + 1] = cj(dm0(1));
            blockB[count + 2] = cj(dm0(2));
            blockB[count + 3] = cj(dm0(3));
            blockB[count + 4] = cj(dm0(4));
            blockB[count + 5] = cj(dm0(5));
            blockB[count + 6] = cj(dm0(6));
            blockB[count + 7] = cj(dm0(7));
          }
          count += 8;
        }
        // skip what we have after
        if (PanelMode) count += 8 * (stride - offset - depth);
      }
    }

    if (nr >= 4) {
      for (Index j2 = packet_cols8; j2 < packet_cols4; j2 += 4) {
        // skip what we have before
        if (PanelMode) count += 4 * offset;
        for (Index k = 0; k < depth; k++) {
          // Same widest-fit selection for the 4-column panels; each taken
          // branch advances count by exactly 4.
          if (PacketSize == 4) {
            Packet A = rhs.template loadPacket<Packet>(k, j2);
            pstoreu(blockB + count, cj.pconj(A));
            count += PacketSize;
          } else if (HasHalf && HalfPacketSize == 4) {
            HalfPacket A = rhs.template loadPacket<HalfPacket>(k, j2);
            pstoreu(blockB + count, cj.pconj(A));
            count += HalfPacketSize;
          } else if (HasQuarter && QuarterPacketSize == 4) {
            QuarterPacket A = rhs.template loadPacket<QuarterPacket>(k, j2);
            pstoreu(blockB + count, cj.pconj(A));
            count += QuarterPacketSize;
          } else {
            const LinearMapper dm0 = rhs.getLinearMapper(k, j2);
            blockB[count + 0] = cj(dm0(0));
            blockB[count + 1] = cj(dm0(1));
            blockB[count + 2] = cj(dm0(2));
            blockB[count + 3] = cj(dm0(3));
            count += 4;
          }
        }
        // skip what we have after
        if (PanelMode) count += 4 * (stride - offset - depth);
      }
    }
    // copy the remaining columns one at a time (nr==1)
    for (Index j2 = packet_cols4; j2 < cols; ++j2) {
      if (PanelMode) count += offset;
      for (Index k = 0; k < depth; k++) {
        blockB[count] = cj(rhs(k, j2));
        count += 1;
      }
      if (PanelMode) count += stride - offset - depth;
    }
  }
};
1204
+
1205
// nr == 8 GEBP kernel entry point: forwards packed (blockA, blockB) panels to
// the AVX512 compute kernel. Defined out-of-line below.
// NOTE(review): ConjugateLhs/ConjugateRhs are not used by the definition —
// presumably this specialization is only instantiated for real scalars;
// confirm against the product dispatch.
template <typename Scalar, typename Index, typename DataMapper, int mr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<Scalar, Scalar, Index, DataMapper, mr, 8, ConjugateLhs, ConjugateRhs> {
  EIGEN_ALWAYS_INLINE void operator()(const DataMapper &res, const Scalar *blockA, const Scalar *blockB, Index rows,
                                      Index depth, Index cols, Scalar alpha, Index strideA = -1, Index strideB = -1,
                                      Index offsetA = 0, Index offsetB = 0);
};
1211
+
1212
+ template <typename Scalar, typename Index, typename DataMapper, int mr, bool ConjugateLhs, bool ConjugateRhs>
1213
+ EIGEN_ALWAYS_INLINE void gebp_kernel<Scalar, Scalar, Index, DataMapper, mr, 8, ConjugateLhs, ConjugateRhs>::operator()(
1214
+ const DataMapper &res, const Scalar *blockA, const Scalar *blockB, Index rows, Index depth, Index cols,
1215
+ Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
1216
+ if (res.incr() == 1) {
1217
+ if (alpha == 1) {
1218
+ gemm_kern_avx512<Scalar, mr, 8, true, false, true>(rows, cols, depth, &alpha, blockA, blockB,
1219
+ (Scalar *)res.data(), res.stride(), res.incr(), strideA,
1220
+ strideB, offsetA, offsetB);
1221
+ } else {
1222
+ gemm_kern_avx512<Scalar, mr, 8, false, false, true>(rows, cols, depth, &alpha, blockA, blockB,
1223
+ (Scalar *)res.data(), res.stride(), res.incr(), strideA,
1224
+ strideB, offsetA, offsetB);
1225
+ }
1226
+ } else {
1227
+ if (alpha == 1) {
1228
+ gemm_kern_avx512<Scalar, mr, 8, true, false, false>(rows, cols, depth, &alpha, blockA, blockB,
1229
+ (Scalar *)res.data(), res.stride(), res.incr(), strideA,
1230
+ strideB, offsetA, offsetB);
1231
+ } else {
1232
+ gemm_kern_avx512<Scalar, mr, 8, false, false, false>(rows, cols, depth, &alpha, blockA, blockB,
1233
+ (Scalar *)res.data(), res.stride(), res.incr(), strideA,
1234
+ strideB, offsetA, offsetB);
1235
+ }
1236
+ }
1237
+ }
1238
+ #endif // EIGEN_USE_AVX512_GEMM_KERNELS
1239
+
1240
+ } // namespace internal
1241
+ } // namespace Eigen
1242
+
1243
+ #undef SECOND_FETCH
1244
+
1245
+ #endif // EIGEN_CORE_ARCH_AVX512_GEMM_KERNEL_H