@smake/eigen 1.0.2 → 1.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (435)
  1. package/README.md +1 -1
  2. package/eigen/Eigen/AccelerateSupport +52 -0
  3. package/eigen/Eigen/Cholesky +18 -21
  4. package/eigen/Eigen/CholmodSupport +28 -28
  5. package/eigen/Eigen/Core +235 -326
  6. package/eigen/Eigen/Eigenvalues +16 -14
  7. package/eigen/Eigen/Geometry +21 -24
  8. package/eigen/Eigen/Householder +9 -8
  9. package/eigen/Eigen/IterativeLinearSolvers +8 -4
  10. package/eigen/Eigen/Jacobi +14 -14
  11. package/eigen/Eigen/KLUSupport +43 -0
  12. package/eigen/Eigen/LU +16 -20
  13. package/eigen/Eigen/MetisSupport +12 -12
  14. package/eigen/Eigen/OrderingMethods +54 -54
  15. package/eigen/Eigen/PaStiXSupport +23 -20
  16. package/eigen/Eigen/PardisoSupport +17 -14
  17. package/eigen/Eigen/QR +18 -21
  18. package/eigen/Eigen/QtAlignedMalloc +5 -13
  19. package/eigen/Eigen/SPQRSupport +21 -14
  20. package/eigen/Eigen/SVD +23 -18
  21. package/eigen/Eigen/Sparse +1 -4
  22. package/eigen/Eigen/SparseCholesky +18 -23
  23. package/eigen/Eigen/SparseCore +18 -17
  24. package/eigen/Eigen/SparseLU +12 -8
  25. package/eigen/Eigen/SparseQR +16 -14
  26. package/eigen/Eigen/StdDeque +5 -2
  27. package/eigen/Eigen/StdList +5 -2
  28. package/eigen/Eigen/StdVector +5 -2
  29. package/eigen/Eigen/SuperLUSupport +30 -24
  30. package/eigen/Eigen/ThreadPool +80 -0
  31. package/eigen/Eigen/UmfPackSupport +19 -17
  32. package/eigen/Eigen/Version +14 -0
  33. package/eigen/Eigen/src/AccelerateSupport/AccelerateSupport.h +423 -0
  34. package/eigen/Eigen/src/AccelerateSupport/InternalHeaderCheck.h +3 -0
  35. package/eigen/Eigen/src/Cholesky/InternalHeaderCheck.h +3 -0
  36. package/eigen/Eigen/src/Cholesky/LDLT.h +377 -401
  37. package/eigen/Eigen/src/Cholesky/LLT.h +332 -360
  38. package/eigen/Eigen/src/Cholesky/LLT_LAPACKE.h +81 -56
  39. package/eigen/Eigen/src/CholmodSupport/CholmodSupport.h +620 -521
  40. package/eigen/Eigen/src/CholmodSupport/InternalHeaderCheck.h +3 -0
  41. package/eigen/Eigen/src/Core/ArithmeticSequence.h +239 -0
  42. package/eigen/Eigen/src/Core/Array.h +341 -294
  43. package/eigen/Eigen/src/Core/ArrayBase.h +190 -203
  44. package/eigen/Eigen/src/Core/ArrayWrapper.h +127 -171
  45. package/eigen/Eigen/src/Core/Assign.h +30 -40
  46. package/eigen/Eigen/src/Core/AssignEvaluator.h +711 -589
  47. package/eigen/Eigen/src/Core/Assign_MKL.h +130 -125
  48. package/eigen/Eigen/src/Core/BandMatrix.h +268 -283
  49. package/eigen/Eigen/src/Core/Block.h +375 -398
  50. package/eigen/Eigen/src/Core/CommaInitializer.h +86 -97
  51. package/eigen/Eigen/src/Core/ConditionEstimator.h +51 -53
  52. package/eigen/Eigen/src/Core/CoreEvaluators.h +1356 -1026
  53. package/eigen/Eigen/src/Core/CoreIterators.h +73 -59
  54. package/eigen/Eigen/src/Core/CwiseBinaryOp.h +114 -132
  55. package/eigen/Eigen/src/Core/CwiseNullaryOp.h +726 -617
  56. package/eigen/Eigen/src/Core/CwiseTernaryOp.h +77 -103
  57. package/eigen/Eigen/src/Core/CwiseUnaryOp.h +56 -68
  58. package/eigen/Eigen/src/Core/CwiseUnaryView.h +132 -95
  59. package/eigen/Eigen/src/Core/DenseBase.h +632 -571
  60. package/eigen/Eigen/src/Core/DenseCoeffsBase.h +511 -624
  61. package/eigen/Eigen/src/Core/DenseStorage.h +512 -509
  62. package/eigen/Eigen/src/Core/DeviceWrapper.h +153 -0
  63. package/eigen/Eigen/src/Core/Diagonal.h +169 -210
  64. package/eigen/Eigen/src/Core/DiagonalMatrix.h +351 -274
  65. package/eigen/Eigen/src/Core/DiagonalProduct.h +12 -10
  66. package/eigen/Eigen/src/Core/Dot.h +172 -222
  67. package/eigen/Eigen/src/Core/EigenBase.h +75 -85
  68. package/eigen/Eigen/src/Core/Fill.h +138 -0
  69. package/eigen/Eigen/src/Core/FindCoeff.h +464 -0
  70. package/eigen/Eigen/src/Core/ForceAlignedAccess.h +90 -109
  71. package/eigen/Eigen/src/Core/Fuzzy.h +82 -105
  72. package/eigen/Eigen/src/Core/GeneralProduct.h +327 -263
  73. package/eigen/Eigen/src/Core/GenericPacketMath.h +1472 -360
  74. package/eigen/Eigen/src/Core/GlobalFunctions.h +194 -151
  75. package/eigen/Eigen/src/Core/IO.h +147 -139
  76. package/eigen/Eigen/src/Core/IndexedView.h +321 -0
  77. package/eigen/Eigen/src/Core/InnerProduct.h +260 -0
  78. package/eigen/Eigen/src/Core/InternalHeaderCheck.h +3 -0
  79. package/eigen/Eigen/src/Core/Inverse.h +56 -66
  80. package/eigen/Eigen/src/Core/Map.h +124 -142
  81. package/eigen/Eigen/src/Core/MapBase.h +256 -281
  82. package/eigen/Eigen/src/Core/MathFunctions.h +1620 -938
  83. package/eigen/Eigen/src/Core/MathFunctionsImpl.h +233 -71
  84. package/eigen/Eigen/src/Core/Matrix.h +491 -416
  85. package/eigen/Eigen/src/Core/MatrixBase.h +468 -453
  86. package/eigen/Eigen/src/Core/NestByValue.h +66 -85
  87. package/eigen/Eigen/src/Core/NoAlias.h +79 -85
  88. package/eigen/Eigen/src/Core/NumTraits.h +235 -148
  89. package/eigen/Eigen/src/Core/PartialReduxEvaluator.h +253 -0
  90. package/eigen/Eigen/src/Core/PermutationMatrix.h +461 -511
  91. package/eigen/Eigen/src/Core/PlainObjectBase.h +871 -894
  92. package/eigen/Eigen/src/Core/Product.h +260 -139
  93. package/eigen/Eigen/src/Core/ProductEvaluators.h +863 -714
  94. package/eigen/Eigen/src/Core/Random.h +161 -136
  95. package/eigen/Eigen/src/Core/RandomImpl.h +262 -0
  96. package/eigen/Eigen/src/Core/RealView.h +250 -0
  97. package/eigen/Eigen/src/Core/Redux.h +366 -336
  98. package/eigen/Eigen/src/Core/Ref.h +308 -209
  99. package/eigen/Eigen/src/Core/Replicate.h +94 -106
  100. package/eigen/Eigen/src/Core/Reshaped.h +398 -0
  101. package/eigen/Eigen/src/Core/ReturnByValue.h +49 -55
  102. package/eigen/Eigen/src/Core/Reverse.h +136 -145
  103. package/eigen/Eigen/src/Core/Select.h +70 -140
  104. package/eigen/Eigen/src/Core/SelfAdjointView.h +262 -285
  105. package/eigen/Eigen/src/Core/SelfCwiseBinaryOp.h +23 -20
  106. package/eigen/Eigen/src/Core/SkewSymmetricMatrix3.h +382 -0
  107. package/eigen/Eigen/src/Core/Solve.h +97 -111
  108. package/eigen/Eigen/src/Core/SolveTriangular.h +131 -129
  109. package/eigen/Eigen/src/Core/SolverBase.h +138 -101
  110. package/eigen/Eigen/src/Core/StableNorm.h +156 -160
  111. package/eigen/Eigen/src/Core/StlIterators.h +619 -0
  112. package/eigen/Eigen/src/Core/Stride.h +91 -88
  113. package/eigen/Eigen/src/Core/Swap.h +70 -38
  114. package/eigen/Eigen/src/Core/Transpose.h +295 -273
  115. package/eigen/Eigen/src/Core/Transpositions.h +272 -317
  116. package/eigen/Eigen/src/Core/TriangularMatrix.h +670 -755
  117. package/eigen/Eigen/src/Core/VectorBlock.h +59 -72
  118. package/eigen/Eigen/src/Core/VectorwiseOp.h +668 -630
  119. package/eigen/Eigen/src/Core/Visitor.h +480 -216
  120. package/eigen/Eigen/src/Core/arch/AVX/Complex.h +407 -293
  121. package/eigen/Eigen/src/Core/arch/AVX/MathFunctions.h +79 -388
  122. package/eigen/Eigen/src/Core/arch/AVX/PacketMath.h +2935 -491
  123. package/eigen/Eigen/src/Core/arch/AVX/Reductions.h +353 -0
  124. package/eigen/Eigen/src/Core/arch/AVX/TypeCasting.h +279 -22
  125. package/eigen/Eigen/src/Core/arch/AVX512/Complex.h +472 -0
  126. package/eigen/Eigen/src/Core/arch/AVX512/GemmKernel.h +1245 -0
  127. package/eigen/Eigen/src/Core/arch/AVX512/MathFunctions.h +85 -333
  128. package/eigen/Eigen/src/Core/arch/AVX512/MathFunctionsFP16.h +75 -0
  129. package/eigen/Eigen/src/Core/arch/AVX512/PacketMath.h +2490 -649
  130. package/eigen/Eigen/src/Core/arch/AVX512/PacketMathFP16.h +1413 -0
  131. package/eigen/Eigen/src/Core/arch/AVX512/Reductions.h +297 -0
  132. package/eigen/Eigen/src/Core/arch/AVX512/TrsmKernel.h +1167 -0
  133. package/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc +1219 -0
  134. package/eigen/Eigen/src/Core/arch/AVX512/TypeCasting.h +277 -0
  135. package/eigen/Eigen/src/Core/arch/AVX512/TypeCastingFP16.h +130 -0
  136. package/eigen/Eigen/src/Core/arch/AltiVec/Complex.h +521 -298
  137. package/eigen/Eigen/src/Core/arch/AltiVec/MathFunctions.h +39 -280
  138. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProduct.h +3686 -0
  139. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h +205 -0
  140. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h +901 -0
  141. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h +742 -0
  142. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.inc +2818 -0
  143. package/eigen/Eigen/src/Core/arch/AltiVec/PacketMath.h +3391 -723
  144. package/eigen/Eigen/src/Core/arch/AltiVec/TypeCasting.h +153 -0
  145. package/eigen/Eigen/src/Core/arch/Default/BFloat16.h +866 -0
  146. package/eigen/Eigen/src/Core/arch/Default/ConjHelper.h +113 -14
  147. package/eigen/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +2634 -0
  148. package/eigen/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h +227 -0
  149. package/eigen/Eigen/src/Core/arch/Default/Half.h +1091 -0
  150. package/eigen/Eigen/src/Core/arch/Default/Settings.h +11 -13
  151. package/eigen/Eigen/src/Core/arch/GPU/Complex.h +244 -0
  152. package/eigen/Eigen/src/Core/arch/GPU/MathFunctions.h +104 -0
  153. package/eigen/Eigen/src/Core/arch/GPU/PacketMath.h +1712 -0
  154. package/eigen/Eigen/src/Core/arch/GPU/Tuple.h +268 -0
  155. package/eigen/Eigen/src/Core/arch/GPU/TypeCasting.h +77 -0
  156. package/eigen/Eigen/src/Core/arch/HIP/hcc/math_constants.h +23 -0
  157. package/eigen/Eigen/src/Core/arch/HVX/PacketMath.h +1088 -0
  158. package/eigen/Eigen/src/Core/arch/LSX/Complex.h +520 -0
  159. package/eigen/Eigen/src/Core/arch/LSX/GeneralBlockPanelKernel.h +23 -0
  160. package/eigen/Eigen/src/Core/arch/LSX/MathFunctions.h +43 -0
  161. package/eigen/Eigen/src/Core/arch/LSX/PacketMath.h +2866 -0
  162. package/eigen/Eigen/src/Core/arch/LSX/TypeCasting.h +526 -0
  163. package/eigen/Eigen/src/Core/arch/MSA/Complex.h +620 -0
  164. package/eigen/Eigen/src/Core/arch/MSA/MathFunctions.h +379 -0
  165. package/eigen/Eigen/src/Core/arch/MSA/PacketMath.h +1237 -0
  166. package/eigen/Eigen/src/Core/arch/NEON/Complex.h +531 -289
  167. package/eigen/Eigen/src/Core/arch/NEON/GeneralBlockPanelKernel.h +243 -0
  168. package/eigen/Eigen/src/Core/arch/NEON/MathFunctions.h +50 -73
  169. package/eigen/Eigen/src/Core/arch/NEON/PacketMath.h +5915 -579
  170. package/eigen/Eigen/src/Core/arch/NEON/TypeCasting.h +1642 -0
  171. package/eigen/Eigen/src/Core/arch/NEON/UnaryFunctors.h +57 -0
  172. package/eigen/Eigen/src/Core/arch/SSE/Complex.h +366 -334
  173. package/eigen/Eigen/src/Core/arch/SSE/MathFunctions.h +40 -514
  174. package/eigen/Eigen/src/Core/arch/SSE/PacketMath.h +2164 -675
  175. package/eigen/Eigen/src/Core/arch/SSE/Reductions.h +324 -0
  176. package/eigen/Eigen/src/Core/arch/SSE/TypeCasting.h +188 -35
  177. package/eigen/Eigen/src/Core/arch/SVE/MathFunctions.h +48 -0
  178. package/eigen/Eigen/src/Core/arch/SVE/PacketMath.h +674 -0
  179. package/eigen/Eigen/src/Core/arch/SVE/TypeCasting.h +52 -0
  180. package/eigen/Eigen/src/Core/arch/SYCL/InteropHeaders.h +227 -0
  181. package/eigen/Eigen/src/Core/arch/SYCL/MathFunctions.h +303 -0
  182. package/eigen/Eigen/src/Core/arch/SYCL/PacketMath.h +576 -0
  183. package/eigen/Eigen/src/Core/arch/SYCL/TypeCasting.h +83 -0
  184. package/eigen/Eigen/src/Core/arch/ZVector/Complex.h +434 -261
  185. package/eigen/Eigen/src/Core/arch/ZVector/MathFunctions.h +160 -53
  186. package/eigen/Eigen/src/Core/arch/ZVector/PacketMath.h +1073 -605
  187. package/eigen/Eigen/src/Core/functors/AssignmentFunctors.h +123 -117
  188. package/eigen/Eigen/src/Core/functors/BinaryFunctors.h +594 -322
  189. package/eigen/Eigen/src/Core/functors/NullaryFunctors.h +204 -118
  190. package/eigen/Eigen/src/Core/functors/StlFunctors.h +110 -97
  191. package/eigen/Eigen/src/Core/functors/TernaryFunctors.h +34 -7
  192. package/eigen/Eigen/src/Core/functors/UnaryFunctors.h +1158 -530
  193. package/eigen/Eigen/src/Core/products/GeneralBlockPanelKernel.h +2329 -1333
  194. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrix.h +328 -364
  195. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h +191 -178
  196. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h +85 -82
  197. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h +154 -73
  198. package/eigen/Eigen/src/Core/products/GeneralMatrixVector.h +396 -542
  199. package/eigen/Eigen/src/Core/products/GeneralMatrixVector_BLAS.h +80 -77
  200. package/eigen/Eigen/src/Core/products/Parallelizer.h +208 -92
  201. package/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix.h +331 -375
  202. package/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix_BLAS.h +206 -224
  203. package/eigen/Eigen/src/Core/products/SelfadjointMatrixVector.h +139 -146
  204. package/eigen/Eigen/src/Core/products/SelfadjointMatrixVector_BLAS.h +58 -61
  205. package/eigen/Eigen/src/Core/products/SelfadjointProduct.h +71 -71
  206. package/eigen/Eigen/src/Core/products/SelfadjointRank2Update.h +48 -46
  207. package/eigen/Eigen/src/Core/products/TriangularMatrixMatrix.h +294 -369
  208. package/eigen/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h +246 -238
  209. package/eigen/Eigen/src/Core/products/TriangularMatrixVector.h +244 -247
  210. package/eigen/Eigen/src/Core/products/TriangularMatrixVector_BLAS.h +212 -192
  211. package/eigen/Eigen/src/Core/products/TriangularSolverMatrix.h +328 -275
  212. package/eigen/Eigen/src/Core/products/TriangularSolverMatrix_BLAS.h +108 -109
  213. package/eigen/Eigen/src/Core/products/TriangularSolverVector.h +70 -93
  214. package/eigen/Eigen/src/Core/util/Assert.h +158 -0
  215. package/eigen/Eigen/src/Core/util/BlasUtil.h +413 -290
  216. package/eigen/Eigen/src/Core/util/ConfigureVectorization.h +543 -0
  217. package/eigen/Eigen/src/Core/util/Constants.h +314 -263
  218. package/eigen/Eigen/src/Core/util/DisableStupidWarnings.h +130 -78
  219. package/eigen/Eigen/src/Core/util/EmulateArray.h +270 -0
  220. package/eigen/Eigen/src/Core/util/ForwardDeclarations.h +450 -224
  221. package/eigen/Eigen/src/Core/util/GpuHipCudaDefines.inc +101 -0
  222. package/eigen/Eigen/src/Core/util/GpuHipCudaUndefines.inc +45 -0
  223. package/eigen/Eigen/src/Core/util/IndexedViewHelper.h +487 -0
  224. package/eigen/Eigen/src/Core/util/IntegralConstant.h +279 -0
  225. package/eigen/Eigen/src/Core/util/MKL_support.h +39 -30
  226. package/eigen/Eigen/src/Core/util/Macros.h +939 -646
  227. package/eigen/Eigen/src/Core/util/MaxSizeVector.h +139 -0
  228. package/eigen/Eigen/src/Core/util/Memory.h +1042 -650
  229. package/eigen/Eigen/src/Core/util/Meta.h +618 -426
  230. package/eigen/Eigen/src/Core/util/MoreMeta.h +638 -0
  231. package/eigen/Eigen/src/Core/util/ReenableStupidWarnings.h +32 -19
  232. package/eigen/Eigen/src/Core/util/ReshapedHelper.h +51 -0
  233. package/eigen/Eigen/src/Core/util/Serializer.h +209 -0
  234. package/eigen/Eigen/src/Core/util/StaticAssert.h +51 -164
  235. package/eigen/Eigen/src/Core/util/SymbolicIndex.h +445 -0
  236. package/eigen/Eigen/src/Core/util/XprHelper.h +793 -538
  237. package/eigen/Eigen/src/Eigenvalues/ComplexEigenSolver.h +246 -277
  238. package/eigen/Eigen/src/Eigenvalues/ComplexSchur.h +299 -319
  239. package/eigen/Eigen/src/Eigenvalues/ComplexSchur_LAPACKE.h +52 -48
  240. package/eigen/Eigen/src/Eigenvalues/EigenSolver.h +413 -456
  241. package/eigen/Eigen/src/Eigenvalues/GeneralizedEigenSolver.h +309 -325
  242. package/eigen/Eigen/src/Eigenvalues/GeneralizedSelfAdjointEigenSolver.h +157 -171
  243. package/eigen/Eigen/src/Eigenvalues/HessenbergDecomposition.h +292 -310
  244. package/eigen/Eigen/src/Eigenvalues/InternalHeaderCheck.h +3 -0
  245. package/eigen/Eigen/src/Eigenvalues/MatrixBaseEigenvalues.h +91 -107
  246. package/eigen/Eigen/src/Eigenvalues/RealQZ.h +539 -606
  247. package/eigen/Eigen/src/Eigenvalues/RealSchur.h +348 -382
  248. package/eigen/Eigen/src/Eigenvalues/RealSchur_LAPACKE.h +41 -35
  249. package/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver.h +579 -600
  250. package/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver_LAPACKE.h +47 -44
  251. package/eigen/Eigen/src/Eigenvalues/Tridiagonalization.h +434 -461
  252. package/eigen/Eigen/src/Geometry/AlignedBox.h +307 -214
  253. package/eigen/Eigen/src/Geometry/AngleAxis.h +135 -137
  254. package/eigen/Eigen/src/Geometry/EulerAngles.h +163 -74
  255. package/eigen/Eigen/src/Geometry/Homogeneous.h +289 -333
  256. package/eigen/Eigen/src/Geometry/Hyperplane.h +152 -161
  257. package/eigen/Eigen/src/Geometry/InternalHeaderCheck.h +3 -0
  258. package/eigen/Eigen/src/Geometry/OrthoMethods.h +168 -145
  259. package/eigen/Eigen/src/Geometry/ParametrizedLine.h +141 -104
  260. package/eigen/Eigen/src/Geometry/Quaternion.h +595 -497
  261. package/eigen/Eigen/src/Geometry/Rotation2D.h +110 -108
  262. package/eigen/Eigen/src/Geometry/RotationBase.h +148 -145
  263. package/eigen/Eigen/src/Geometry/Scaling.h +115 -90
  264. package/eigen/Eigen/src/Geometry/Transform.h +896 -953
  265. package/eigen/Eigen/src/Geometry/Translation.h +100 -98
  266. package/eigen/Eigen/src/Geometry/Umeyama.h +79 -84
  267. package/eigen/Eigen/src/Geometry/arch/Geometry_SIMD.h +154 -0
  268. package/eigen/Eigen/src/Householder/BlockHouseholder.h +54 -42
  269. package/eigen/Eigen/src/Householder/Householder.h +104 -122
  270. package/eigen/Eigen/src/Householder/HouseholderSequence.h +416 -382
  271. package/eigen/Eigen/src/Householder/InternalHeaderCheck.h +3 -0
  272. package/eigen/Eigen/src/IterativeLinearSolvers/BasicPreconditioners.h +153 -166
  273. package/eigen/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h +127 -138
  274. package/eigen/Eigen/src/IterativeLinearSolvers/ConjugateGradient.h +95 -124
  275. package/eigen/Eigen/src/IterativeLinearSolvers/IncompleteCholesky.h +269 -267
  276. package/eigen/Eigen/src/IterativeLinearSolvers/IncompleteLUT.h +246 -259
  277. package/eigen/Eigen/src/IterativeLinearSolvers/InternalHeaderCheck.h +3 -0
  278. package/eigen/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h +218 -217
  279. package/eigen/Eigen/src/IterativeLinearSolvers/LeastSquareConjugateGradient.h +80 -103
  280. package/eigen/Eigen/src/IterativeLinearSolvers/SolveWithGuess.h +59 -63
  281. package/eigen/Eigen/src/Jacobi/InternalHeaderCheck.h +3 -0
  282. package/eigen/Eigen/src/Jacobi/Jacobi.h +256 -291
  283. package/eigen/Eigen/src/KLUSupport/InternalHeaderCheck.h +3 -0
  284. package/eigen/Eigen/src/KLUSupport/KLUSupport.h +339 -0
  285. package/eigen/Eigen/src/LU/Determinant.h +60 -63
  286. package/eigen/Eigen/src/LU/FullPivLU.h +561 -626
  287. package/eigen/Eigen/src/LU/InternalHeaderCheck.h +3 -0
  288. package/eigen/Eigen/src/LU/InverseImpl.h +213 -275
  289. package/eigen/Eigen/src/LU/PartialPivLU.h +407 -435
  290. package/eigen/Eigen/src/LU/PartialPivLU_LAPACKE.h +54 -40
  291. package/eigen/Eigen/src/LU/arch/InverseSize4.h +353 -0
  292. package/eigen/Eigen/src/MetisSupport/InternalHeaderCheck.h +3 -0
  293. package/eigen/Eigen/src/MetisSupport/MetisSupport.h +81 -93
  294. package/eigen/Eigen/src/OrderingMethods/Amd.h +250 -282
  295. package/eigen/Eigen/src/OrderingMethods/Eigen_Colamd.h +950 -1103
  296. package/eigen/Eigen/src/OrderingMethods/InternalHeaderCheck.h +3 -0
  297. package/eigen/Eigen/src/OrderingMethods/Ordering.h +111 -122
  298. package/eigen/Eigen/src/PaStiXSupport/InternalHeaderCheck.h +3 -0
  299. package/eigen/Eigen/src/PaStiXSupport/PaStiXSupport.h +524 -570
  300. package/eigen/Eigen/src/PardisoSupport/InternalHeaderCheck.h +3 -0
  301. package/eigen/Eigen/src/PardisoSupport/PardisoSupport.h +385 -429
  302. package/eigen/Eigen/src/QR/ColPivHouseholderQR.h +494 -473
  303. package/eigen/Eigen/src/QR/ColPivHouseholderQR_LAPACKE.h +120 -56
  304. package/eigen/Eigen/src/QR/CompleteOrthogonalDecomposition.h +223 -137
  305. package/eigen/Eigen/src/QR/FullPivHouseholderQR.h +517 -460
  306. package/eigen/Eigen/src/QR/HouseholderQR.h +412 -278
  307. package/eigen/Eigen/src/QR/HouseholderQR_LAPACKE.h +32 -23
  308. package/eigen/Eigen/src/QR/InternalHeaderCheck.h +3 -0
  309. package/eigen/Eigen/src/SPQRSupport/InternalHeaderCheck.h +3 -0
  310. package/eigen/Eigen/src/SPQRSupport/SuiteSparseQRSupport.h +263 -261
  311. package/eigen/Eigen/src/SVD/BDCSVD.h +872 -679
  312. package/eigen/Eigen/src/SVD/BDCSVD_LAPACKE.h +174 -0
  313. package/eigen/Eigen/src/SVD/InternalHeaderCheck.h +3 -0
  314. package/eigen/Eigen/src/SVD/JacobiSVD.h +585 -543
  315. package/eigen/Eigen/src/SVD/JacobiSVD_LAPACKE.h +85 -49
  316. package/eigen/Eigen/src/SVD/SVDBase.h +281 -160
  317. package/eigen/Eigen/src/SVD/UpperBidiagonalization.h +202 -237
  318. package/eigen/Eigen/src/SparseCholesky/InternalHeaderCheck.h +3 -0
  319. package/eigen/Eigen/src/SparseCholesky/SimplicialCholesky.h +769 -590
  320. package/eigen/Eigen/src/SparseCholesky/SimplicialCholesky_impl.h +318 -129
  321. package/eigen/Eigen/src/SparseCore/AmbiVector.h +202 -251
  322. package/eigen/Eigen/src/SparseCore/CompressedStorage.h +184 -236
  323. package/eigen/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h +140 -184
  324. package/eigen/Eigen/src/SparseCore/InternalHeaderCheck.h +3 -0
  325. package/eigen/Eigen/src/SparseCore/SparseAssign.h +174 -111
  326. package/eigen/Eigen/src/SparseCore/SparseBlock.h +408 -477
  327. package/eigen/Eigen/src/SparseCore/SparseColEtree.h +100 -112
  328. package/eigen/Eigen/src/SparseCore/SparseCompressedBase.h +531 -280
  329. package/eigen/Eigen/src/SparseCore/SparseCwiseBinaryOp.h +559 -347
  330. package/eigen/Eigen/src/SparseCore/SparseCwiseUnaryOp.h +100 -108
  331. package/eigen/Eigen/src/SparseCore/SparseDenseProduct.h +185 -191
  332. package/eigen/Eigen/src/SparseCore/SparseDiagonalProduct.h +71 -71
  333. package/eigen/Eigen/src/SparseCore/SparseDot.h +49 -47
  334. package/eigen/Eigen/src/SparseCore/SparseFuzzy.h +13 -11
  335. package/eigen/Eigen/src/SparseCore/SparseMap.h +243 -253
  336. package/eigen/Eigen/src/SparseCore/SparseMatrix.h +1614 -1142
  337. package/eigen/Eigen/src/SparseCore/SparseMatrixBase.h +403 -357
  338. package/eigen/Eigen/src/SparseCore/SparsePermutation.h +186 -115
  339. package/eigen/Eigen/src/SparseCore/SparseProduct.h +100 -91
  340. package/eigen/Eigen/src/SparseCore/SparseRedux.h +22 -24
  341. package/eigen/Eigen/src/SparseCore/SparseRef.h +268 -295
  342. package/eigen/Eigen/src/SparseCore/SparseSelfAdjointView.h +371 -414
  343. package/eigen/Eigen/src/SparseCore/SparseSolverBase.h +78 -87
  344. package/eigen/Eigen/src/SparseCore/SparseSparseProductWithPruning.h +81 -95
  345. package/eigen/Eigen/src/SparseCore/SparseTranspose.h +62 -71
  346. package/eigen/Eigen/src/SparseCore/SparseTriangularView.h +132 -144
  347. package/eigen/Eigen/src/SparseCore/SparseUtil.h +146 -115
  348. package/eigen/Eigen/src/SparseCore/SparseVector.h +426 -372
  349. package/eigen/Eigen/src/SparseCore/SparseView.h +164 -193
  350. package/eigen/Eigen/src/SparseCore/TriangularSolver.h +129 -170
  351. package/eigen/Eigen/src/SparseLU/InternalHeaderCheck.h +3 -0
  352. package/eigen/Eigen/src/SparseLU/SparseLU.h +814 -618
  353. package/eigen/Eigen/src/SparseLU/SparseLUImpl.h +61 -48
  354. package/eigen/Eigen/src/SparseLU/SparseLU_Memory.h +102 -118
  355. package/eigen/Eigen/src/SparseLU/SparseLU_Structs.h +38 -35
  356. package/eigen/Eigen/src/SparseLU/SparseLU_SupernodalMatrix.h +273 -255
  357. package/eigen/Eigen/src/SparseLU/SparseLU_Utils.h +44 -49
  358. package/eigen/Eigen/src/SparseLU/SparseLU_column_bmod.h +104 -108
  359. package/eigen/Eigen/src/SparseLU/SparseLU_column_dfs.h +90 -101
  360. package/eigen/Eigen/src/SparseLU/SparseLU_copy_to_ucol.h +57 -58
  361. package/eigen/Eigen/src/SparseLU/SparseLU_heap_relax_snode.h +43 -55
  362. package/eigen/Eigen/src/SparseLU/SparseLU_kernel_bmod.h +74 -71
  363. package/eigen/Eigen/src/SparseLU/SparseLU_panel_bmod.h +125 -133
  364. package/eigen/Eigen/src/SparseLU/SparseLU_panel_dfs.h +136 -159
  365. package/eigen/Eigen/src/SparseLU/SparseLU_pivotL.h +51 -52
  366. package/eigen/Eigen/src/SparseLU/SparseLU_pruneL.h +67 -73
  367. package/eigen/Eigen/src/SparseLU/SparseLU_relax_snode.h +24 -26
  368. package/eigen/Eigen/src/SparseQR/InternalHeaderCheck.h +3 -0
  369. package/eigen/Eigen/src/SparseQR/SparseQR.h +451 -490
  370. package/eigen/Eigen/src/StlSupport/StdDeque.h +28 -105
  371. package/eigen/Eigen/src/StlSupport/StdList.h +28 -84
  372. package/eigen/Eigen/src/StlSupport/StdVector.h +28 -108
  373. package/eigen/Eigen/src/StlSupport/details.h +48 -50
  374. package/eigen/Eigen/src/SuperLUSupport/InternalHeaderCheck.h +3 -0
  375. package/eigen/Eigen/src/SuperLUSupport/SuperLUSupport.h +634 -732
  376. package/eigen/Eigen/src/ThreadPool/Barrier.h +70 -0
  377. package/eigen/Eigen/src/ThreadPool/CoreThreadPoolDevice.h +336 -0
  378. package/eigen/Eigen/src/ThreadPool/EventCount.h +241 -0
  379. package/eigen/Eigen/src/ThreadPool/ForkJoin.h +140 -0
  380. package/eigen/Eigen/src/ThreadPool/InternalHeaderCheck.h +4 -0
  381. package/eigen/Eigen/src/ThreadPool/NonBlockingThreadPool.h +587 -0
  382. package/eigen/Eigen/src/ThreadPool/RunQueue.h +230 -0
  383. package/eigen/Eigen/src/ThreadPool/ThreadCancel.h +21 -0
  384. package/eigen/Eigen/src/ThreadPool/ThreadEnvironment.h +43 -0
  385. package/eigen/Eigen/src/ThreadPool/ThreadLocal.h +289 -0
  386. package/eigen/Eigen/src/ThreadPool/ThreadPoolInterface.h +50 -0
  387. package/eigen/Eigen/src/ThreadPool/ThreadYield.h +16 -0
  388. package/eigen/Eigen/src/UmfPackSupport/InternalHeaderCheck.h +3 -0
  389. package/eigen/Eigen/src/UmfPackSupport/UmfPackSupport.h +480 -380
  390. package/eigen/Eigen/src/misc/Image.h +41 -43
  391. package/eigen/Eigen/src/misc/InternalHeaderCheck.h +3 -0
  392. package/eigen/Eigen/src/misc/Kernel.h +39 -41
  393. package/eigen/Eigen/src/misc/RealSvd2x2.h +19 -21
  394. package/eigen/Eigen/src/misc/blas.h +83 -426
  395. package/eigen/Eigen/src/misc/lapacke.h +9976 -16182
  396. package/eigen/Eigen/src/misc/lapacke_helpers.h +163 -0
  397. package/eigen/Eigen/src/misc/lapacke_mangling.h +4 -5
  398. package/eigen/Eigen/src/plugins/ArrayCwiseBinaryOps.inc +344 -0
  399. package/eigen/Eigen/src/plugins/ArrayCwiseUnaryOps.inc +544 -0
  400. package/eigen/Eigen/src/plugins/BlockMethods.inc +1370 -0
  401. package/eigen/Eigen/src/plugins/CommonCwiseBinaryOps.inc +116 -0
  402. package/eigen/Eigen/src/plugins/CommonCwiseUnaryOps.inc +167 -0
  403. package/eigen/Eigen/src/plugins/IndexedViewMethods.inc +192 -0
  404. package/eigen/Eigen/src/plugins/InternalHeaderCheck.inc +3 -0
  405. package/eigen/Eigen/src/plugins/MatrixCwiseBinaryOps.inc +331 -0
  406. package/eigen/Eigen/src/plugins/MatrixCwiseUnaryOps.inc +118 -0
  407. package/eigen/Eigen/src/plugins/ReshapedMethods.inc +133 -0
  408. package/lib/LibEigen.d.ts +4 -0
  409. package/lib/LibEigen.js +14 -0
  410. package/lib/index.d.ts +1 -1
  411. package/lib/index.js +7 -3
  412. package/package.json +2 -10
  413. package/eigen/Eigen/CMakeLists.txt +0 -19
  414. package/eigen/Eigen/src/Core/BooleanRedux.h +0 -164
  415. package/eigen/Eigen/src/Core/arch/CUDA/Complex.h +0 -103
  416. package/eigen/Eigen/src/Core/arch/CUDA/Half.h +0 -675
  417. package/eigen/Eigen/src/Core/arch/CUDA/MathFunctions.h +0 -91
  418. package/eigen/Eigen/src/Core/arch/CUDA/PacketMath.h +0 -333
  419. package/eigen/Eigen/src/Core/arch/CUDA/PacketMathHalf.h +0 -1124
  420. package/eigen/Eigen/src/Core/arch/CUDA/TypeCasting.h +0 -212
  421. package/eigen/Eigen/src/Core/util/NonMPL2.h +0 -3
  422. package/eigen/Eigen/src/Geometry/arch/Geometry_SSE.h +0 -161
  423. package/eigen/Eigen/src/LU/arch/Inverse_SSE.h +0 -338
  424. package/eigen/Eigen/src/SparseCore/MappedSparseMatrix.h +0 -67
  425. package/eigen/Eigen/src/SparseLU/SparseLU_gemm_kernel.h +0 -280
  426. package/eigen/Eigen/src/misc/lapack.h +0 -152
  427. package/eigen/Eigen/src/plugins/ArrayCwiseBinaryOps.h +0 -332
  428. package/eigen/Eigen/src/plugins/ArrayCwiseUnaryOps.h +0 -552
  429. package/eigen/Eigen/src/plugins/BlockMethods.h +0 -1058
  430. package/eigen/Eigen/src/plugins/CommonCwiseBinaryOps.h +0 -115
  431. package/eigen/Eigen/src/plugins/CommonCwiseUnaryOps.h +0 -163
  432. package/eigen/Eigen/src/plugins/MatrixCwiseBinaryOps.h +0 -152
  433. package/eigen/Eigen/src/plugins/MatrixCwiseUnaryOps.h +0 -85
  434. package/lib/eigen.d.ts +0 -2
  435. package/lib/eigen.js +0 -15
@@ -0,0 +1,3686 @@
1
+ // This file is part of Eigen, a lightweight C++ template library
2
+ // for linear algebra.
3
+ //
4
+ // Copyright (C) 2020 Everton Constantino (everton.constantino@ibm.com)
5
+ // Copyright (C) 2021 Chip Kerchner (chip.kerchner@ibm.com)
6
+ //
7
+ // This Source Code Form is subject to the terms of the Mozilla
8
+ // Public License v. 2.0. If a copy of the MPL was not distributed
9
+ // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
10
+
11
+ #ifndef EIGEN_MATRIX_PRODUCT_ALTIVEC_H
12
+ #define EIGEN_MATRIX_PRODUCT_ALTIVEC_H
13
+
14
+ #ifndef EIGEN_ALTIVEC_USE_CUSTOM_PACK
15
+ #define EIGEN_ALTIVEC_USE_CUSTOM_PACK 1
16
+ #endif
17
+
18
+ #if !defined(EIGEN_ALTIVEC_DISABLE_MMA)
19
+ #define EIGEN_ALTIVEC_DISABLE_MMA 0
20
+ #endif
21
+
22
+ // Check for MMA builtin support.
23
+ #if !EIGEN_ALTIVEC_DISABLE_MMA && defined(__has_builtin)
24
+ #if __has_builtin(__builtin_mma_assemble_acc)
25
+ #define EIGEN_ALTIVEC_MMA_SUPPORT
26
+ #endif
27
+ #endif
28
+
29
+ // Check if and how we should actually use MMA if supported.
30
+ #if defined(EIGEN_ALTIVEC_MMA_SUPPORT)
31
+
32
+ #if !defined(EIGEN_ALTIVEC_ENABLE_MMA_DYNAMIC_DISPATCH)
33
+ #define EIGEN_ALTIVEC_ENABLE_MMA_DYNAMIC_DISPATCH 0
34
+ #endif
35
+
36
+ // Check if we want to enable dynamic dispatch. Not supported by LLVM.
37
+ #if EIGEN_ALTIVEC_ENABLE_MMA_DYNAMIC_DISPATCH && !EIGEN_COMP_LLVM
38
+ #define EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH 1
39
+ // Otherwise, use MMA by default if available.
40
+ #elif defined(__MMA__)
41
+ #define EIGEN_ALTIVEC_MMA_ONLY 1
42
+ #endif
43
+
44
+ #endif // EIGEN_ALTIVEC_MMA_SUPPORT
45
+
46
+ #include "MatrixProductCommon.h"
47
+
48
+ #if defined(EIGEN_ALTIVEC_MMA_ONLY) || defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH)
49
+ #include "MatrixProductMMA.h"
50
+ #endif
51
+
52
+ // IWYU pragma: private
53
+ #include "../../InternalHeaderCheck.h"
54
+
55
+ namespace Eigen {
56
+
57
+ namespace internal {
58
+
59
/**************************
 * Constants and typedefs *
 **************************/
// Per-scalar traits describing the vector register shape used by these kernels:
// `vectortype` is the packet type, `type` a 4-packet accumulator block,
// `rhstype` the packet shape rhs data is consumed in, and `vectorsize` the
// number of scalars per packet.
template <typename Scalar>
struct quad_traits {
  typedef typename packet_traits<Scalar>::type vectortype;
  typedef PacketBlock<vectortype, 4> type;
  typedef vectortype rhstype;
  enum { vectorsize = packet_traits<Scalar>::size, size = 4, rows = 4 };
};

// double: 2 lanes per Packet2d, and the rhs is consumed as a 2-packet block.
template <>
struct quad_traits<double> {
  typedef Packet2d vectortype;
  typedef PacketBlock<vectortype, 4> type;
  typedef PacketBlock<Packet2d, 2> rhstype;
  enum { vectorsize = packet_traits<double>::size, size = 2, rows = 4 };
};

// bfloat16: 8 lanes per Packet8bf.
template <>
struct quad_traits<bfloat16> {
  typedef Packet8bf vectortype;
  typedef PacketBlock<vectortype, 4> type;
  typedef vectortype rhstype;
  enum { vectorsize = packet_traits<bfloat16>::size, size = 8, rows = 4 };
};
85
+
86
// MatrixProduct decomposes complex data into a separate real vector and an imaginary vector; this
// turned out to be faster than Eigen's usual layout of real/imaginary pairs interleaved in a single
// vector. The byte-permute masks below convert between Eigen's interleaved layout and
// MatrixProduct's split layout.

// Select the real parts (even 32-bit lanes) of two interleaved complex vectors.
const static Packet16uc p16uc_GETREAL32 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};

// Select the imaginary parts (odd 32-bit lanes) of two interleaved complex vectors.
const static Packet16uc p16uc_GETIMAG32 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};

// Variants with the middle 64-bit halves swapped; used by the transposed path in dhs_cblock below.
const static Packet16uc p16uc_GETREAL32b = {0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 24, 25, 26, 27};

const static Packet16uc p16uc_GETIMAG32b = {4, 5, 6, 7, 20, 21, 22, 23, 12, 13, 14, 15, 28, 29, 30, 31};
97
+
98
/*********************************************
 * Single precision real and complex packing *
 *********************************************/

/**
 * Symm packing packs symmetric (self-adjoint) blocks: as expected, the packing leaves the diagonal
 * real, and whatever lies below it is copied from the corresponding element above the diagonal and
 * conjugated. There is no PanelMode available for symm packing.
 *
 * Packing in general is supposed to leave the lhs block and the rhs block easy for gemm to read
 * with its respective rank-update instructions. The float32/64 versions differ because at this
 * moment the size of the accumulator is fixed at 512 bits, so you cannot have a 4x4 accumulator of
 * 64-bit elements.
 *
 * As mentioned earlier, MatrixProduct splits complex numbers into a real vector and an imaginary
 * vector, and packing has to take that into account: at the moment we pack the real part first and
 * then the imaginary part. This is the main reason packing for complex is broken down into several
 * different parts, and also the reason why we end up having separate float32/64 and complex
 * float32/64 versions.
 **/
116
// Returns element (i, j) of a self-adjoint matrix of which only one triangle is physically stored:
// above-diagonal requests (i < j) read the mirrored element (j, i) and conjugate it, below-diagonal
// requests read (i, j) directly, and the diagonal's imaginary part is forced to zero.
template <typename Scalar, int StorageOrder>
EIGEN_ALWAYS_INLINE std::complex<Scalar> getAdjointVal(
    Index i, Index j, const_blas_data_mapper<std::complex<Scalar>, Index, StorageOrder>& dt) {
  std::complex<Scalar> v;
  if (i < j) {
    // Mirror across the diagonal and conjugate.
    v.real(dt(j, i).real());
    v.imag(-dt(j, i).imag());
  } else if (i > j) {
    v.real(dt(i, j).real());
    v.imag(dt(i, j).imag());
  } else {
    // Diagonal of a self-adjoint matrix is real by definition.
    v.real(dt(i, j).real());
    v.imag((Scalar)0.0);
  }
  return v;
}
132
+
133
// Packs a depth x cols rhs panel of a self-adjoint complex matrix into blockB, split into a real
// panel and an imaginary panel (the imaginary panel starts vectorDelta scalars after the real one
// within each column stripe). N scales the stripe width (1 for float, 2 for double — see the
// specializations below). k2 is the row offset of this panel within the full matrix.
template <typename Scalar, int StorageOrder, int N>
EIGEN_STRONG_INLINE void symm_pack_complex_rhs_helper(std::complex<Scalar>* blockB, const std::complex<Scalar>* _rhs,
                                                      Index rhsStride, Index rows, Index cols, Index k2) {
  const Index depth = k2 + rows;
  const_blas_data_mapper<std::complex<Scalar>, Index, StorageOrder> rhs(_rhs, rhsStride);
  const Index vectorSize = N * quad_traits<Scalar>::vectorsize;
  const Index vectorDelta = vectorSize * rows;
  // The real/imaginary panels are written as plain scalars.
  Scalar* blockBf = reinterpret_cast<Scalar*>(blockB);

  // rir/rii track the write positions of the real and imaginary panels respectively.
  Index rir = 0, rii, j = 0;
  for (; j + vectorSize <= cols; j += vectorSize) {
    rii = rir + vectorDelta;

    for (Index i = k2; i < depth; i++) {
      for (Index k = 0; k < vectorSize; k++) {
        std::complex<Scalar> v = getAdjointVal<Scalar, StorageOrder>(i, j + k, rhs);

        blockBf[rir + k] = v.real();
        blockBf[rii + k] = v.imag();
      }
      rir += vectorSize;
      rii += vectorSize;
    }

    // Skip over the imaginary panel just written so the next stripe starts after it.
    rir += vectorDelta;
  }

  // Leftover columns, one scalar column at a time.
  for (; j < cols; j++) {
    rii = rir + rows;

    for (Index i = k2; i < depth; i++) {
      std::complex<Scalar> v = getAdjointVal<Scalar, StorageOrder>(i, j, rhs);

      blockBf[rir] = v.real();
      blockBf[rii] = v.imag();

      rir += 1;
      rii += 1;
    }

    rir += rows;
  }
}
176
+
177
// Packs a rows x cols lhs panel of a self-adjoint complex matrix into blockA, split into a real
// panel followed by an imaginary panel per row stripe (mirror of symm_pack_complex_rhs_helper,
// iterating over rows instead of columns).
template <typename Scalar, int StorageOrder>
EIGEN_STRONG_INLINE void symm_pack_complex_lhs_helper(std::complex<Scalar>* blockA, const std::complex<Scalar>* _lhs,
                                                      Index lhsStride, Index cols, Index rows) {
  const Index depth = cols;
  const_blas_data_mapper<std::complex<Scalar>, Index, StorageOrder> lhs(_lhs, lhsStride);
  const Index vectorSize = quad_traits<Scalar>::vectorsize;
  const Index vectorDelta = vectorSize * depth;
  // The real/imaginary panels are written as plain scalars.
  Scalar* blockAf = reinterpret_cast<Scalar*>(blockA);

  // rir/rii track the write positions of the real and imaginary panels respectively.
  Index rir = 0, rii, j = 0;
  for (; j + vectorSize <= rows; j += vectorSize) {
    rii = rir + vectorDelta;

    for (Index i = 0; i < depth; i++) {
      for (Index k = 0; k < vectorSize; k++) {
        // Note the swapped indices vs the rhs helper: rows are the packed dimension here.
        std::complex<Scalar> v = getAdjointVal<Scalar, StorageOrder>(j + k, i, lhs);

        blockAf[rir + k] = v.real();
        blockAf[rii + k] = v.imag();
      }
      rir += vectorSize;
      rii += vectorSize;
    }

    // Skip over the imaginary panel just written.
    rir += vectorDelta;
  }

  // Remaining rows (fewer than a full stripe), packed scalar by scalar.
  if (j < rows) {
    rii = rir + ((rows - j) * depth);

    for (Index i = 0; i < depth; i++) {
      Index k = j;
      for (; k < rows; k++) {
        std::complex<Scalar> v = getAdjointVal<Scalar, StorageOrder>(k, i, lhs);

        blockAf[rir] = v.real();
        blockAf[rii] = v.imag();

        rir += 1;
        rii += 1;
      }
    }
  }
}
221
+
222
+ template <typename Scalar, int StorageOrder, int N>
223
+ EIGEN_STRONG_INLINE void symm_pack_rhs_helper(Scalar* blockB, const Scalar* _rhs, Index rhsStride, Index rows,
224
+ Index cols, Index k2) {
225
+ const Index depth = k2 + rows;
226
+ const_blas_data_mapper<Scalar, Index, StorageOrder> rhs(_rhs, rhsStride);
227
+ const Index vectorSize = quad_traits<Scalar>::vectorsize;
228
+
229
+ Index ri = 0, j = 0;
230
+ for (; j + N * vectorSize <= cols; j += N * vectorSize) {
231
+ Index i = k2;
232
+ for (; i < depth; i++) {
233
+ for (Index k = 0; k < N * vectorSize; k++) {
234
+ if (i <= j + k)
235
+ blockB[ri + k] = rhs(j + k, i);
236
+ else
237
+ blockB[ri + k] = rhs(i, j + k);
238
+ }
239
+ ri += N * vectorSize;
240
+ }
241
+ }
242
+
243
+ for (; j < cols; j++) {
244
+ for (Index i = k2; i < depth; i++) {
245
+ if (j <= i)
246
+ blockB[ri] = rhs(i, j);
247
+ else
248
+ blockB[ri] = rhs(j, i);
249
+ ri += 1;
250
+ }
251
+ }
252
+ }
253
+
254
+ template <typename Scalar, int StorageOrder>
255
+ EIGEN_STRONG_INLINE void symm_pack_lhs_helper(Scalar* blockA, const Scalar* _lhs, Index lhsStride, Index cols,
256
+ Index rows) {
257
+ const Index depth = cols;
258
+ const_blas_data_mapper<Scalar, Index, StorageOrder> lhs(_lhs, lhsStride);
259
+ const Index vectorSize = quad_traits<Scalar>::vectorsize;
260
+
261
+ Index ri = 0, j = 0;
262
+ for (; j + vectorSize <= rows; j += vectorSize) {
263
+ Index i = 0;
264
+
265
+ for (; i < depth; i++) {
266
+ for (Index k = 0; k < vectorSize; k++) {
267
+ if (i <= j + k)
268
+ blockA[ri + k] = lhs(j + k, i);
269
+ else
270
+ blockA[ri + k] = lhs(i, j + k);
271
+ }
272
+ ri += vectorSize;
273
+ }
274
+ }
275
+
276
+ if (j < rows) {
277
+ for (Index i = 0; i < depth; i++) {
278
+ Index k = j;
279
+ for (; k < rows; k++) {
280
+ if (i <= k)
281
+ blockA[ri] = lhs(k, i);
282
+ else
283
+ blockA[ri] = lhs(i, k);
284
+ ri += 1;
285
+ }
286
+ }
287
+ }
288
+ }
289
+
290
// symm_pack dispatch for std::complex<float>: routes to the shared complex helpers above
// with N = 1 (one packet per stripe).
template <typename Index, int nr, int StorageOrder>
struct symm_pack_rhs<std::complex<float>, Index, nr, StorageOrder> {
  void operator()(std::complex<float>* blockB, const std::complex<float>* _rhs, Index rhsStride, Index rows, Index cols,
                  Index k2) {
    symm_pack_complex_rhs_helper<float, StorageOrder, 1>(blockB, _rhs, rhsStride, rows, cols, k2);
  }
};

template <typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
struct symm_pack_lhs<std::complex<float>, Index, Pack1, Pack2_dummy, StorageOrder> {
  void operator()(std::complex<float>* blockA, const std::complex<float>* _lhs, Index lhsStride, Index cols,
                  Index rows) {
    symm_pack_complex_lhs_helper<float, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
  }
};
305
+
306
// *********** symm_pack std::complex<float64> ***********

// symm_pack dispatch for std::complex<double>: N = 2 (two Packet2d per stripe) so the stripe
// width matches the single-precision layout.
template <typename Index, int nr, int StorageOrder>
struct symm_pack_rhs<std::complex<double>, Index, nr, StorageOrder> {
  void operator()(std::complex<double>* blockB, const std::complex<double>* _rhs, Index rhsStride, Index rows,
                  Index cols, Index k2) {
    symm_pack_complex_rhs_helper<double, StorageOrder, 2>(blockB, _rhs, rhsStride, rows, cols, k2);
  }
};

template <typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
struct symm_pack_lhs<std::complex<double>, Index, Pack1, Pack2_dummy, StorageOrder> {
  void operator()(std::complex<double>* blockA, const std::complex<double>* _lhs, Index lhsStride, Index cols,
                  Index rows) {
    symm_pack_complex_lhs_helper<double, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
  }
};
323
+
324
// *********** symm_pack float32 ***********
// symm_pack dispatch for float: routes to the real helpers above with N = 1.
template <typename Index, int nr, int StorageOrder>
struct symm_pack_rhs<float, Index, nr, StorageOrder> {
  void operator()(float* blockB, const float* _rhs, Index rhsStride, Index rows, Index cols, Index k2) {
    symm_pack_rhs_helper<float, StorageOrder, 1>(blockB, _rhs, rhsStride, rows, cols, k2);
  }
};

template <typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
struct symm_pack_lhs<float, Index, Pack1, Pack2_dummy, StorageOrder> {
  void operator()(float* blockA, const float* _lhs, Index lhsStride, Index cols, Index rows) {
    symm_pack_lhs_helper<float, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
  }
};
338
+
339
// *********** symm_pack float64 ***********
// symm_pack dispatch for double: N = 2 (two Packet2d per rhs stripe).
template <typename Index, int nr, int StorageOrder>
struct symm_pack_rhs<double, Index, nr, StorageOrder> {
  void operator()(double* blockB, const double* _rhs, Index rhsStride, Index rows, Index cols, Index k2) {
    symm_pack_rhs_helper<double, StorageOrder, 2>(blockB, _rhs, rhsStride, rows, cols, k2);
  }
};

template <typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
struct symm_pack_lhs<double, Index, Pack1, Pack2_dummy, StorageOrder> {
  void operator()(double* blockA, const double* _lhs, Index lhsStride, Index cols, Index rows) {
    symm_pack_lhs_helper<double, StorageOrder>(blockA, _lhs, lhsStride, cols, rows);
  }
};
353
+
354
/**
 * PanelMode
 * Packing might be called several times before the block is multiplied by gebp_kernel; this happens
 * because on special occasions the caller fills only part of the block with other parts of the
 * matrix. Two variables control how PanelMode should behave: offset and stride. The idea is that
 * these variables represent whatever the real offset and stride will be in the future, and that is
 * what you should obey. The process is to pack as you normally would, but to leave the start of
 * each part at the correct offset, and likewise the end, respecting the real stride the block will
 * have. Gebp is aware of both the block's stride and offset and behaves accordingly.
 **/
364
+
365
+ template <typename Scalar, typename Packet, int N>
366
+ EIGEN_ALWAYS_INLINE void storeBlock(Scalar* to, PacketBlock<Packet, N>& block) {
367
+ const Index size = 16 / sizeof(Scalar);
368
+ pstore<Scalar>(to + (0 * size), block.packet[0]);
369
+ pstore<Scalar>(to + (1 * size), block.packet[1]);
370
+ if (N > 2) {
371
+ pstore<Scalar>(to + (2 * size), block.packet[2]);
372
+ }
373
+ if (N > 3) {
374
+ pstore<Scalar>(to + (3 * size), block.packet[3]);
375
+ }
376
+ }
377
+
378
// General template for lhs & rhs complex packing.
//
// The packed block is laid out as a panel of real parts followed (vectorDelta scalars later,
// per stripe) by a panel of imaginary parts; gemm consumes the two panels separately.
// Conjugate negates the imaginary panel while packing. PanelMode packs with the
// caller-provided stride/offset (see the PanelMode note above). UseLhs selects whether rows
// (lhs) or columns (rhs) are the packed dimension.
template <typename Scalar, typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate,
          bool PanelMode, bool UseLhs>
struct dhs_cpack {
  // Extracts the lanes selected by `permute` (real or imaginary parts) from eight complex
  // packets into four real packets. The transpose path additionally rearranges the result
  // via 64-bit merges/permutes.
  template <bool transpose>
  EIGEN_ALWAYS_INLINE void dhs_cblock(PacketBlock<PacketC, 8>& cblock, PacketBlock<Packet, 4>& block,
                                      Packet16uc permute) {
    if (transpose) {
      block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, permute);
      block.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, permute);
      block.packet[2] = vec_perm(cblock.packet[4].v, cblock.packet[5].v, permute);
      block.packet[3] = vec_perm(cblock.packet[6].v, cblock.packet[7].v, permute);

      // NOTE(review): t0..t3 are declared Packet4f, so this transpose path assumes
      // Packet == Packet4f (single precision) — confirm before reusing for other types.
      Packet4f t0, t1, t2, t3;
#ifdef EIGEN_VECTORIZE_VSX
      t0 = reinterpret_cast<Packet>(
          vec_mergeh(reinterpret_cast<Packet2ul>(block.packet[0]), reinterpret_cast<Packet2ul>(block.packet[1])));
      t1 = reinterpret_cast<Packet>(
          vec_mergel(reinterpret_cast<Packet2ul>(block.packet[0]), reinterpret_cast<Packet2ul>(block.packet[1])));
      t2 = reinterpret_cast<Packet>(
          vec_mergeh(reinterpret_cast<Packet2ul>(block.packet[2]), reinterpret_cast<Packet2ul>(block.packet[3])));
      t3 = reinterpret_cast<Packet>(
          vec_mergel(reinterpret_cast<Packet2ul>(block.packet[2]), reinterpret_cast<Packet2ul>(block.packet[3])));
#else
      // Pure AltiVec fallback: emulate the 64-bit merges with byte permutes.
      t0 = reinterpret_cast<Packet>(vec_perm(block.packet[0], block.packet[1], p16uc_TRANSPOSE64_HI));
      t1 = reinterpret_cast<Packet>(vec_perm(block.packet[0], block.packet[1], p16uc_TRANSPOSE64_LO));
      t2 = reinterpret_cast<Packet>(vec_perm(block.packet[2], block.packet[3], p16uc_TRANSPOSE64_HI));
      t3 = reinterpret_cast<Packet>(vec_perm(block.packet[2], block.packet[3], p16uc_TRANSPOSE64_LO));
#endif

      block.packet[0] = t0;
      block.packet[1] = t1;
      block.packet[2] = t2;
      block.packet[3] = t3;
    } else {
      block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, permute);
      block.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, permute);
      block.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, permute);
      block.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, permute);
    }
  }

  // Vectorized copy of a full stripe: loads complex data, splits it into real and imaginary
  // packets, optionally conjugates, and stores both panels. Advances i/rir/rii in place.
  EIGEN_ALWAYS_INLINE void dhs_ccopy(Scalar* blockAt, const DataMapper& lhs2, Index& i, Index& rir, Index& rii,
                                     Index depth, const Index vectorSize) {
    PacketBlock<Packet, 4> blockr, blocki;
    PacketBlock<PacketC, 8> cblock;

    for (; i + vectorSize <= depth; i += vectorSize) {
      if (UseLhs) {
        bload<DataMapper, PacketC, 2, StorageOrder, true, 4>(cblock, lhs2, 0, i);
      } else {
        bload<DataMapper, PacketC, 2, StorageOrder, true, 4>(cblock, lhs2, i, 0);
      }

      // The "wrong" storage order for this side needs the transposing extraction.
      if (((StorageOrder == RowMajor) && UseLhs) || (((StorageOrder == ColMajor) && !UseLhs))) {
        dhs_cblock<true>(cblock, blockr, p16uc_GETREAL32b);
        dhs_cblock<true>(cblock, blocki, p16uc_GETIMAG32b);
      } else {
        dhs_cblock<false>(cblock, blockr, p16uc_GETREAL32);
        dhs_cblock<false>(cblock, blocki, p16uc_GETIMAG32);
      }

      if (Conjugate) {
        blocki.packet[0] = -blocki.packet[0];
        blocki.packet[1] = -blocki.packet[1];
        blocki.packet[2] = -blocki.packet[2];
        blocki.packet[3] = -blocki.packet[3];
      }

      storeBlock<Scalar, Packet, 4>(blockAt + rir, blockr);
      storeBlock<Scalar, Packet, 4>(blockAt + rii, blocki);

      rir += 4 * vectorSize;
      rii += 4 * vectorSize;
    }
  }

  // Packs the block: full vector-wide stripes first (vectorized), then a scalar stripe
  // remainder, then (depending on UseLhs) the leftover rows/columns.
  EIGEN_STRONG_INLINE void operator()(std::complex<Scalar>* blockA, const DataMapper& lhs, Index depth, Index rows,
                                      Index stride, Index offset) {
    const Index vectorSize = quad_traits<Scalar>::vectorsize;
    // Distance (in scalars) between the real panel and the imaginary panel of a stripe.
    const Index vectorDelta = vectorSize * ((PanelMode) ? stride : depth);
    Index rir = ((PanelMode) ? (vectorSize * offset) : 0), rii;
    Scalar* blockAt = reinterpret_cast<Scalar*>(blockA);
    Index j = 0;

    for (; j + vectorSize <= rows; j += vectorSize) {
      const DataMapper lhs2 = UseLhs ? lhs.getSubMapper(j, 0) : lhs.getSubMapper(0, j);
      Index i = 0;

      rii = rir + vectorDelta;

      dhs_ccopy(blockAt, lhs2, i, rir, rii, depth, vectorSize);

      // Scalar remainder of the stripe (depth not a multiple of vectorSize).
      for (; i < depth; i++) {
        PacketBlock<Packet, 1> blockr, blocki;
        PacketBlock<PacketC, 2> cblock;

        if (((StorageOrder == ColMajor) && UseLhs) || (((StorageOrder == RowMajor) && !UseLhs))) {
          // Contiguous along the packed dimension: direct packet loads.
          if (UseLhs) {
            cblock.packet[0] = lhs2.template loadPacket<PacketC>(0, i);
            cblock.packet[1] = lhs2.template loadPacket<PacketC>(2, i);
          } else {
            cblock.packet[0] = lhs2.template loadPacket<PacketC>(i, 0);
            cblock.packet[1] = lhs2.template loadPacket<PacketC>(i, 2);
          }
        } else {
          // Strided: gather two complex values per packet.
          if (UseLhs) {
            cblock.packet[0] = pload2(lhs2(0, i), lhs2(1, i));
            cblock.packet[1] = pload2(lhs2(2, i), lhs2(3, i));
          } else {
            cblock.packet[0] = pload2(lhs2(i, 0), lhs2(i, 1));
            cblock.packet[1] = pload2(lhs2(i, 2), lhs2(i, 3));
          }
        }

        blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL32);
        blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG32);

        if (Conjugate) {
          blocki.packet[0] = -blocki.packet[0];
        }

        pstore<Scalar>(blockAt + rir, blockr.packet[0]);
        pstore<Scalar>(blockAt + rii, blocki.packet[0]);

        rir += vectorSize;
        rii += vectorSize;
      }

      // Jump past the imaginary panel (and panel padding in PanelMode).
      rir += ((PanelMode) ? (vectorSize * (2 * stride - depth)) : vectorDelta);
    }

    if (!UseLhs) {
      // rhs leftover: remaining columns one at a time.
      if (PanelMode) rir -= (offset * (vectorSize - 1));

      for (; j < rows; j++) {
        const DataMapper lhs2 = lhs.getSubMapper(0, j);
        rii = rir + ((PanelMode) ? stride : depth);

        for (Index i = 0; i < depth; i++) {
          blockAt[rir] = lhs2(i, 0).real();

          if (Conjugate)
            blockAt[rii] = -lhs2(i, 0).imag();
          else
            blockAt[rii] = lhs2(i, 0).imag();

          rir += 1;
          rii += 1;
        }

        rir += ((PanelMode) ? (2 * stride - depth) : depth);
      }
    } else {
      // lhs leftover: remaining rows packed depth-major, scalar by scalar.
      if (j < rows) {
        if (PanelMode) rir += (offset * (rows - j - vectorSize));
        rii = rir + (((PanelMode) ? stride : depth) * (rows - j));

        for (Index i = 0; i < depth; i++) {
          Index k = j;
          for (; k < rows; k++) {
            blockAt[rir] = lhs(k, i).real();

            if (Conjugate)
              blockAt[rii] = -lhs(k, i).imag();
            else
              blockAt[rii] = lhs(k, i).imag();

            rir += 1;
            rii += 1;
          }
        }
      }
    }
  }
};
554
+
555
// General template for lhs & rhs packing.
//
// Packs real (non-complex) data into gemm's block layout: full vector-wide stripes are copied
// with packet loads/stores (transposing when the source storage order does not match the packed
// layout), then the scalar remainders. UseLhs selects whether rows (lhs) or columns (rhs) form
// the packed dimension; PanelMode honors caller-provided stride/offset (see the PanelMode note
// above).
template <typename Scalar, typename DataMapper, typename Packet, int StorageOrder, bool PanelMode, bool UseLhs>
struct dhs_pack {
  // Copies `n` 4-packet groups per iteration while at least n*vectorSize depth remains;
  // advances i and ri in place. Called with n = 4, 2, 1 to drain the depth.
  template <Index n>
  EIGEN_ALWAYS_INLINE void dhs_copy(Scalar* blockA, const DataMapper& lhs2, Index& i, Index& ri, Index depth,
                                    const Index vectorSize) {
    PacketBlock<Packet, 4> block[n];

    for (; i + n * vectorSize <= depth; i += n * vectorSize) {
      for (Index k = 0; k < n; k++) {
        if (UseLhs) {
          bload<DataMapper, Packet, 4, StorageOrder, false, 4>(block[k], lhs2, 0, i + k * vectorSize);
        } else {
          bload<DataMapper, Packet, 4, StorageOrder, false, 4>(block[k], lhs2, i + k * vectorSize, 0);
        }
      }

      // Transpose when the source order does not match the packed layout.
      if (((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs)) {
        for (Index k = 0; k < n; k++) {
          ptranspose(block[k]);
        }
      }

      for (Index k = 0; k < n; k++) {
        storeBlock<Scalar, Packet, 4>(blockA + ri + k * 4 * vectorSize, block[k]);
      }

      ri += n * 4 * vectorSize;
    }
  }

  EIGEN_STRONG_INLINE void operator()(Scalar* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride,
                                      Index offset) {
    const Index vectorSize = quad_traits<Scalar>::vectorsize;
    Index ri = 0, j = 0;

    for (; j + vectorSize <= rows; j += vectorSize) {
      const DataMapper lhs2 = UseLhs ? lhs.getSubMapper(j, 0) : lhs.getSubMapper(0, j);
      Index i = 0;

      if (PanelMode) ri += vectorSize * offset;

      // Vectorized copy in decreasing group sizes, then a scalar remainder below.
      dhs_copy<4>(blockA, lhs2, i, ri, depth, vectorSize);
      dhs_copy<2>(blockA, lhs2, i, ri, depth, vectorSize);
      dhs_copy<1>(blockA, lhs2, i, ri, depth, vectorSize);

      for (; i < depth; i++) {
        if (((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs)) {
          // Strided source: copy the four lanes one scalar at a time.
          if (UseLhs) {
            blockA[ri + 0] = lhs2(0, i);
            blockA[ri + 1] = lhs2(1, i);
            blockA[ri + 2] = lhs2(2, i);
            blockA[ri + 3] = lhs2(3, i);
          } else {
            blockA[ri + 0] = lhs2(i, 0);
            blockA[ri + 1] = lhs2(i, 1);
            blockA[ri + 2] = lhs2(i, 2);
            blockA[ri + 3] = lhs2(i, 3);
          }
        } else {
          // Contiguous source: one packet load/store.
          Packet lhsV;
          if (UseLhs) {
            lhsV = lhs2.template loadPacket<Packet>(0, i);
          } else {
            lhsV = lhs2.template loadPacket<Packet>(i, 0);
          }
          pstore<Scalar>(blockA + ri, lhsV);
        }

        ri += vectorSize;
      }

      if (PanelMode) ri += vectorSize * (stride - offset - depth);
    }

    if (!UseLhs) {
      // rhs leftover: remaining columns one at a time.
      if (PanelMode) ri += offset;

      for (; j < rows; j++) {
        const DataMapper lhs2 = lhs.getSubMapper(0, j);
        for (Index i = 0; i < depth; i++) {
          blockA[ri] = lhs2(i, 0);
          ri += 1;
        }

        if (PanelMode) ri += stride - depth;
      }
    } else {
      // lhs leftover: remaining rows packed depth-major, scalar by scalar.
      if (j < rows) {
        if (PanelMode) ri += offset * (rows - j);

        for (Index i = 0; i < depth; i++) {
          Index k = j;
          for (; k < rows; k++) {
            blockA[ri] = lhs(k, i);
            ri += 1;
          }
        }
      }
    }
  }
};
657
+
658
// General template for lhs packing, float64 specialization.
//
// Same scheme as the generic dhs_pack, but with 2-lane Packet2d stripes (2x2 transposes)
// instead of 4-lane ones.
template <typename DataMapper, int StorageOrder, bool PanelMode>
struct dhs_pack<double, DataMapper, Packet2d, StorageOrder, PanelMode, true> {
  // Copies `n` 2-packet groups per iteration while enough depth remains; advances i and ri.
  template <Index n>
  EIGEN_ALWAYS_INLINE void dhs_copy(double* blockA, const DataMapper& lhs2, Index& i, Index& ri, Index depth,
                                    const Index vectorSize) {
    PacketBlock<Packet2d, 2> block[n];

    for (; i + n * vectorSize <= depth; i += n * vectorSize) {
      for (Index k = 0; k < n; k++) {
        if (StorageOrder == RowMajor) {
          // One packet per source row; transposed below.
          block[k].packet[0] = lhs2.template loadPacket<Packet2d>(0, i + k * vectorSize);
          block[k].packet[1] = lhs2.template loadPacket<Packet2d>(1, i + k * vectorSize);
        } else {
          // Column-major: two adjacent columns of the 2-row stripe.
          block[k].packet[0] = lhs2.template loadPacket<Packet2d>(0, i + k * vectorSize + 0);
          block[k].packet[1] = lhs2.template loadPacket<Packet2d>(0, i + k * vectorSize + 1);
        }
      }

      if (StorageOrder == RowMajor) {
        for (Index k = 0; k < n; k++) {
          ptranspose(block[k]);
        }
      }

      for (Index k = 0; k < n; k++) {
        storeBlock<double, Packet2d, 2>(blockA + ri + k * 2 * vectorSize, block[k]);
      }

      ri += n * 2 * vectorSize;
    }
  }

  EIGEN_STRONG_INLINE void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride,
                                      Index offset) {
    const Index vectorSize = quad_traits<double>::vectorsize;
    Index ri = 0, j = 0;

    for (; j + vectorSize <= rows; j += vectorSize) {
      const DataMapper lhs2 = lhs.getSubMapper(j, 0);
      Index i = 0;

      if (PanelMode) ri += vectorSize * offset;

      // Vectorized copy in decreasing group sizes, then a scalar remainder below.
      dhs_copy<4>(blockA, lhs2, i, ri, depth, vectorSize);
      dhs_copy<2>(blockA, lhs2, i, ri, depth, vectorSize);
      dhs_copy<1>(blockA, lhs2, i, ri, depth, vectorSize);

      for (; i < depth; i++) {
        if (StorageOrder == RowMajor) {
          blockA[ri + 0] = lhs2(0, i);
          blockA[ri + 1] = lhs2(1, i);
        } else {
          Packet2d lhsV = lhs2.template loadPacket<Packet2d>(0, i);
          pstore<double>(blockA + ri, lhsV);
        }

        ri += vectorSize;
      }

      if (PanelMode) ri += vectorSize * (stride - offset - depth);
    }

    // Leftover rows (fewer than a stripe), packed depth-major, scalar by scalar.
    if (j < rows) {
      if (PanelMode) ri += offset * (rows - j);

      for (Index i = 0; i < depth; i++) {
        Index k = j;
        for (; k < rows; k++) {
          blockA[ri] = lhs(k, i);
          ri += 1;
        }
      }
    }
  }
};
734
+
735
// General template for rhs packing, float64 specialization.
//
// The rhs is packed four columns at a time (2 * vectorSize), interleaving two Packet2d column
// pairs per depth step so gemm can consume the quad_traits<double>::rhstype layout.
template <typename DataMapper, int StorageOrder, bool PanelMode>
struct dhs_pack<double, DataMapper, Packet2d, StorageOrder, PanelMode, false> {
  // Copies `n` groups of 4 columns x vectorSize depth per iteration; advances i and ri.
  template <Index n>
  EIGEN_ALWAYS_INLINE void dhs_copy(double* blockB, const DataMapper& rhs2, Index& i, Index& ri, Index depth,
                                    const Index vectorSize) {
    // block1/block2: column-pair loads for the ColMajor path; block3: row loads for RowMajor.
    PacketBlock<Packet2d, 2> block1[n], block2[n];
    PacketBlock<Packet2d, 4> block3[n];

    for (; i + n * vectorSize <= depth; i += n * vectorSize) {
      for (Index k = 0; k < n; k++) {
        if (StorageOrder == ColMajor) {
          block1[k].packet[0] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize, 0);
          block1[k].packet[1] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize, 1);
          block2[k].packet[0] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize, 2);
          block2[k].packet[1] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize, 3);
        } else {
          block3[k].packet[0] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize + 0, 0);  //[a1 a2]
          block3[k].packet[1] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize + 0, 2);  //[a3 a4]
          block3[k].packet[2] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize + 1, 0);  //[b1 b2]
          block3[k].packet[3] = rhs2.template loadPacket<Packet2d>(i + k * vectorSize + 1, 2);  //[b3 b4]
        }
      }

      if (StorageOrder == ColMajor) {
        for (Index k = 0; k < n; k++) {
          ptranspose(block1[k]);
          ptranspose(block2[k]);
        }
      }

      for (Index k = 0; k < n; k++) {
        if (StorageOrder == ColMajor) {
          // Interleave the two transposed column pairs: cols 0-1, cols 2-3, per depth step.
          pstore<double>(blockB + ri + k * 4 * vectorSize, block1[k].packet[0]);
          pstore<double>(blockB + ri + k * 4 * vectorSize + 2, block2[k].packet[0]);
          pstore<double>(blockB + ri + k * 4 * vectorSize + 4, block1[k].packet[1]);
          pstore<double>(blockB + ri + k * 4 * vectorSize + 6, block2[k].packet[1]);
        } else {
          storeBlock<double, Packet2d, 4>(blockB + ri + k * 4 * vectorSize, block3[k]);
        }
      }

      ri += n * 4 * vectorSize;
    }
  }

  EIGEN_STRONG_INLINE void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride,
                                      Index offset) {
    const Index vectorSize = quad_traits<double>::vectorsize;
    Index ri = 0, j = 0;

    for (; j + 2 * vectorSize <= cols; j += 2 * vectorSize) {
      const DataMapper rhs2 = rhs.getSubMapper(0, j);
      Index i = 0;

      if (PanelMode) ri += offset * (2 * vectorSize);

      // Vectorized copy in decreasing group sizes, then a scalar remainder below.
      dhs_copy<4>(blockB, rhs2, i, ri, depth, vectorSize);
      dhs_copy<2>(blockB, rhs2, i, ri, depth, vectorSize);
      dhs_copy<1>(blockB, rhs2, i, ri, depth, vectorSize);

      for (; i < depth; i++) {
        if (StorageOrder == ColMajor) {
          blockB[ri + 0] = rhs2(i, 0);
          blockB[ri + 1] = rhs2(i, 1);

          ri += vectorSize;

          blockB[ri + 0] = rhs2(i, 2);
          blockB[ri + 1] = rhs2(i, 3);
        } else {
          Packet2d rhsV = rhs2.template loadPacket<Packet2d>(i, 0);
          pstore<double>(blockB + ri, rhsV);

          ri += vectorSize;

          rhsV = rhs2.template loadPacket<Packet2d>(i, 2);
          pstore<double>(blockB + ri, rhsV);
        }
        ri += vectorSize;
      }

      if (PanelMode) ri += (2 * vectorSize) * (stride - offset - depth);
    }

    // Leftover columns, one at a time.
    if (PanelMode) ri += offset;

    for (; j < cols; j++) {
      const DataMapper rhs2 = rhs.getSubMapper(0, j);
      for (Index i = 0; i < depth; i++) {
        blockB[ri] = rhs2(i, 0);
        ri += 1;
      }

      if (PanelMode) ri += stride - depth;
    }
  }
};
833
+
834
+ // General template for lhs packing, bfloat16 specialization.
835
+ template <typename DataMapper, int StorageOrder, bool PanelMode>
836
+ struct dhs_pack<bfloat16, DataMapper, Packet8bf, StorageOrder, PanelMode, true> {
837
+ EIGEN_STRONG_INLINE void operator()(bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride,
838
+ Index offset) {
839
+ const Index vectorSize = quad_traits<bfloat16>::vectorsize;
840
+ Index ri = 0, j = 0;
841
+
842
+ for (; j + 2 * vectorSize <= rows; j += 2 * vectorSize) {
843
+ const DataMapper lhs2 = lhs.getSubMapper(j, 0);
844
+ Index i = 0;
845
+
846
+ if (PanelMode) ri += 2 * vectorSize * offset;
847
+
848
+ if (StorageOrder == ColMajor) {
849
+ for (; i + 2 <= depth; i += 2) {
850
+ PacketBlock<Packet8bf, 4> block;
851
+
852
+ block.packet[0] = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 0);
853
+ block.packet[1] = lhs2.template loadPacket<Packet8bf>(1 * vectorSize, i + 0);
854
+ block.packet[2] = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 1);
855
+ block.packet[3] = lhs2.template loadPacket<Packet8bf>(1 * vectorSize, i + 1);
856
+
857
+ Packet8bf t0, t1;
858
+ t0 = vec_mergeh(block.packet[0].m_val, block.packet[2].m_val);
859
+ t1 = vec_mergel(block.packet[0].m_val, block.packet[2].m_val);
860
+ block.packet[2] = vec_mergeh(block.packet[1].m_val, block.packet[3].m_val);
861
+ block.packet[3] = vec_mergel(block.packet[1].m_val, block.packet[3].m_val);
862
+ block.packet[0] = t0;
863
+ block.packet[1] = t1;
864
+
865
+ storeBlock<bfloat16, Packet8bf, 4>(blockA + ri, block);
866
+
867
+ ri += 2 * 2 * vectorSize;
868
+ }
869
+ if (depth & 1) {
870
+ PacketBlock<Packet8bf, 2> block;
871
+
872
+ block.packet[0] = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 0);
873
+ block.packet[1] = lhs2.template loadPacket<Packet8bf>(1 * vectorSize, i + 0);
874
+
875
+ storeBlock<bfloat16, Packet8bf, 2>(blockA + ri, block);
876
+
877
+ ri += 2 * vectorSize;
878
+ }
879
+ } else {
880
+ for (; i + vectorSize <= depth; i += vectorSize) {
881
+ PacketBlock<Packet8bf, 8> block1, block2;
882
+
883
+ bload<DataMapper, Packet8bf, 8, StorageOrder, false, 8>(block1, lhs2, 0 * vectorSize, i);
884
+ bload<DataMapper, Packet8bf, 8, StorageOrder, false, 8>(block2, lhs2, 1 * vectorSize, i);
885
+
886
+ Packet4ui v1[8], v2[8];
887
+
888
+ v1[0] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[0].m_val),
889
+ reinterpret_cast<Packet4ui>(block1.packet[1].m_val));
890
+ v1[1] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[0].m_val),
891
+ reinterpret_cast<Packet4ui>(block1.packet[1].m_val));
892
+ v1[2] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[2].m_val),
893
+ reinterpret_cast<Packet4ui>(block1.packet[3].m_val));
894
+ v1[3] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[2].m_val),
895
+ reinterpret_cast<Packet4ui>(block1.packet[3].m_val));
896
+ v1[4] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[4].m_val),
897
+ reinterpret_cast<Packet4ui>(block1.packet[5].m_val));
898
+ v1[5] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[4].m_val),
899
+ reinterpret_cast<Packet4ui>(block1.packet[5].m_val));
900
+ v1[6] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[6].m_val),
901
+ reinterpret_cast<Packet4ui>(block1.packet[7].m_val));
902
+ v1[7] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[6].m_val),
903
+ reinterpret_cast<Packet4ui>(block1.packet[7].m_val));
904
+ v2[0] = vec_mergeh(reinterpret_cast<Packet4ui>(block2.packet[0].m_val),
905
+ reinterpret_cast<Packet4ui>(block2.packet[1].m_val));
906
+ v2[1] = vec_mergel(reinterpret_cast<Packet4ui>(block2.packet[0].m_val),
907
+ reinterpret_cast<Packet4ui>(block2.packet[1].m_val));
908
+ v2[2] = vec_mergeh(reinterpret_cast<Packet4ui>(block2.packet[2].m_val),
909
+ reinterpret_cast<Packet4ui>(block2.packet[3].m_val));
910
+ v2[3] = vec_mergel(reinterpret_cast<Packet4ui>(block2.packet[2].m_val),
911
+ reinterpret_cast<Packet4ui>(block2.packet[3].m_val));
912
+ v2[4] = vec_mergeh(reinterpret_cast<Packet4ui>(block2.packet[4].m_val),
913
+ reinterpret_cast<Packet4ui>(block2.packet[5].m_val));
914
+ v2[5] = vec_mergel(reinterpret_cast<Packet4ui>(block2.packet[4].m_val),
915
+ reinterpret_cast<Packet4ui>(block2.packet[5].m_val));
916
+ v2[6] = vec_mergeh(reinterpret_cast<Packet4ui>(block2.packet[6].m_val),
917
+ reinterpret_cast<Packet4ui>(block2.packet[7].m_val));
918
+ v2[7] = vec_mergel(reinterpret_cast<Packet4ui>(block2.packet[6].m_val),
919
+ reinterpret_cast<Packet4ui>(block2.packet[7].m_val));
920
+
921
+ #ifdef EIGEN_VECTORIZE_VSX
922
+ block1.packet[0] = reinterpret_cast<Packet8us>(
923
+ vec_mergeh(reinterpret_cast<Packet2ul>(v1[0]), reinterpret_cast<Packet2ul>(v1[2])));
924
+ block1.packet[2] = reinterpret_cast<Packet8us>(
925
+ vec_mergel(reinterpret_cast<Packet2ul>(v1[0]), reinterpret_cast<Packet2ul>(v1[2])));
926
+ block1.packet[4] = reinterpret_cast<Packet8us>(
927
+ vec_mergeh(reinterpret_cast<Packet2ul>(v1[1]), reinterpret_cast<Packet2ul>(v1[3])));
928
+ block1.packet[6] = reinterpret_cast<Packet8us>(
929
+ vec_mergel(reinterpret_cast<Packet2ul>(v1[1]), reinterpret_cast<Packet2ul>(v1[3])));
930
+ block1.packet[1] = reinterpret_cast<Packet8us>(
931
+ vec_mergeh(reinterpret_cast<Packet2ul>(v1[4]), reinterpret_cast<Packet2ul>(v1[6])));
932
+ block1.packet[3] = reinterpret_cast<Packet8us>(
933
+ vec_mergel(reinterpret_cast<Packet2ul>(v1[4]), reinterpret_cast<Packet2ul>(v1[6])));
934
+ block1.packet[5] = reinterpret_cast<Packet8us>(
935
+ vec_mergeh(reinterpret_cast<Packet2ul>(v1[5]), reinterpret_cast<Packet2ul>(v1[7])));
936
+ block1.packet[7] = reinterpret_cast<Packet8us>(
937
+ vec_mergel(reinterpret_cast<Packet2ul>(v1[5]), reinterpret_cast<Packet2ul>(v1[7])));
938
+ block2.packet[0] = reinterpret_cast<Packet8us>(
939
+ vec_mergeh(reinterpret_cast<Packet2ul>(v2[0]), reinterpret_cast<Packet2ul>(v2[2])));
940
+ block2.packet[2] = reinterpret_cast<Packet8us>(
941
+ vec_mergel(reinterpret_cast<Packet2ul>(v2[0]), reinterpret_cast<Packet2ul>(v2[2])));
942
+ block2.packet[4] = reinterpret_cast<Packet8us>(
943
+ vec_mergeh(reinterpret_cast<Packet2ul>(v2[1]), reinterpret_cast<Packet2ul>(v2[3])));
944
+ block2.packet[6] = reinterpret_cast<Packet8us>(
945
+ vec_mergel(reinterpret_cast<Packet2ul>(v2[1]), reinterpret_cast<Packet2ul>(v2[3])));
946
+ block2.packet[1] = reinterpret_cast<Packet8us>(
947
+ vec_mergeh(reinterpret_cast<Packet2ul>(v2[4]), reinterpret_cast<Packet2ul>(v2[6])));
948
+ block2.packet[3] = reinterpret_cast<Packet8us>(
949
+ vec_mergel(reinterpret_cast<Packet2ul>(v2[4]), reinterpret_cast<Packet2ul>(v2[6])));
950
+ block2.packet[5] = reinterpret_cast<Packet8us>(
951
+ vec_mergeh(reinterpret_cast<Packet2ul>(v2[5]), reinterpret_cast<Packet2ul>(v2[7])));
952
+ block2.packet[7] = reinterpret_cast<Packet8us>(
953
+ vec_mergel(reinterpret_cast<Packet2ul>(v2[5]), reinterpret_cast<Packet2ul>(v2[7])));
954
+ #else
955
+ block1.packet[0] = reinterpret_cast<Packet8us>(vec_perm(v1[0], v1[2], p16uc_TRANSPOSE64_HI));
956
+ block1.packet[2] = reinterpret_cast<Packet8us>(vec_perm(v1[0], v1[2], p16uc_TRANSPOSE64_LO));
957
+ block1.packet[4] = reinterpret_cast<Packet8us>(vec_perm(v1[1], v1[3], p16uc_TRANSPOSE64_HI));
958
+ block1.packet[6] = reinterpret_cast<Packet8us>(vec_perm(v1[1], v1[3], p16uc_TRANSPOSE64_LO));
959
+ block1.packet[1] = reinterpret_cast<Packet8us>(vec_perm(v1[4], v1[6], p16uc_TRANSPOSE64_HI));
960
+ block1.packet[3] = reinterpret_cast<Packet8us>(vec_perm(v1[4], v1[6], p16uc_TRANSPOSE64_LO));
961
+ block1.packet[5] = reinterpret_cast<Packet8us>(vec_perm(v1[5], v1[7], p16uc_TRANSPOSE64_HI));
962
+ block1.packet[7] = reinterpret_cast<Packet8us>(vec_perm(v1[5], v1[7], p16uc_TRANSPOSE64_LO));
963
+ block2.packet[0] = reinterpret_cast<Packet8us>(vec_perm(v2[0], v2[2], p16uc_TRANSPOSE64_HI));
964
+ block2.packet[2] = reinterpret_cast<Packet8us>(vec_perm(v2[0], v2[2], p16uc_TRANSPOSE64_LO));
965
+ block2.packet[4] = reinterpret_cast<Packet8us>(vec_perm(v2[1], v2[3], p16uc_TRANSPOSE64_HI));
966
+ block2.packet[6] = reinterpret_cast<Packet8us>(vec_perm(v2[1], v2[3], p16uc_TRANSPOSE64_LO));
967
+ block2.packet[1] = reinterpret_cast<Packet8us>(vec_perm(v2[4], v2[6], p16uc_TRANSPOSE64_HI));
968
+ block2.packet[3] = reinterpret_cast<Packet8us>(vec_perm(v2[4], v2[6], p16uc_TRANSPOSE64_LO));
969
+ block2.packet[5] = reinterpret_cast<Packet8us>(vec_perm(v2[5], v2[7], p16uc_TRANSPOSE64_HI));
970
+ block2.packet[7] = reinterpret_cast<Packet8us>(vec_perm(v2[5], v2[7], p16uc_TRANSPOSE64_LO));
971
+ #endif
972
+
973
+ for (Index M = 0; M < 8; M += 2) {
974
+ pstore<bfloat16>(blockA + ri + (0 * vectorSize) + (2 * vectorSize * M), block1.packet[M + 0]);
975
+ pstore<bfloat16>(blockA + ri + (1 * vectorSize) + (2 * vectorSize * M), block1.packet[M + 1]);
976
+ pstore<bfloat16>(blockA + ri + (2 * vectorSize) + (2 * vectorSize * M), block2.packet[M + 0]);
977
+ pstore<bfloat16>(blockA + ri + (3 * vectorSize) + (2 * vectorSize * M), block2.packet[M + 1]);
978
+ }
979
+
980
+ ri += 2 * vectorSize * vectorSize;
981
+ }
982
+ for (; i + 2 <= depth; i += 2) {
983
+ for (Index M = 0; M < 2 * vectorSize; M++) {
984
+ blockA[ri + (M * 2) + 0] = lhs2(M, i + 0);
985
+ blockA[ri + (M * 2) + 1] = lhs2(M, i + 1);
986
+ }
987
+
988
+ ri += 2 * 2 * vectorSize;
989
+ }
990
+ if (depth & 1) {
991
+ for (Index M = 0; M < 2 * vectorSize; M++) {
992
+ blockA[ri + M] = lhs2(M, i);
993
+ }
994
+ ri += 2 * vectorSize;
995
+ }
996
+ }
997
+
998
+ if (PanelMode) ri += 2 * vectorSize * (stride - offset - depth);
999
+ }
1000
+ for (; j + vectorSize <= rows; j += vectorSize) {
1001
+ const DataMapper lhs2 = lhs.getSubMapper(j, 0);
1002
+ Index i = 0;
1003
+
1004
+ if (PanelMode) ri += vectorSize * offset;
1005
+
1006
+ if (StorageOrder == ColMajor) {
1007
+ for (; i + 2 <= depth; i += 2) {
1008
+ PacketBlock<Packet8bf, 2> block;
1009
+
1010
+ block.packet[0] = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 0);
1011
+ block.packet[1] = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 1);
1012
+
1013
+ Packet8bf t0;
1014
+ t0 = vec_mergeh(block.packet[0].m_val, block.packet[1].m_val);
1015
+ block.packet[1] = vec_mergel(block.packet[0].m_val, block.packet[1].m_val);
1016
+ block.packet[0] = t0;
1017
+
1018
+ storeBlock<bfloat16, Packet8bf, 2>(blockA + ri, block);
1019
+
1020
+ ri += 2 * vectorSize;
1021
+ }
1022
+ if (depth & 1) {
1023
+ Packet8bf lhsV = lhs2.template loadPacket<Packet8bf>(0 * vectorSize, i + 0);
1024
+ pstore<bfloat16>(blockA + ri, lhsV);
1025
+
1026
+ ri += vectorSize;
1027
+ }
1028
+ } else {
1029
+ for (; i + vectorSize <= depth; i += vectorSize) {
1030
+ PacketBlock<Packet8bf, 8> block1;
1031
+
1032
+ bload<DataMapper, Packet8bf, 8, StorageOrder, false, 8>(block1, lhs2, 0 * vectorSize, i);
1033
+
1034
+ Packet4ui v1[8];
1035
+
1036
+ // This is transposing and interleaving data
1037
+ v1[0] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[0].m_val),
1038
+ reinterpret_cast<Packet4ui>(block1.packet[1].m_val));
1039
+ v1[1] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[0].m_val),
1040
+ reinterpret_cast<Packet4ui>(block1.packet[1].m_val));
1041
+ v1[2] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[2].m_val),
1042
+ reinterpret_cast<Packet4ui>(block1.packet[3].m_val));
1043
+ v1[3] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[2].m_val),
1044
+ reinterpret_cast<Packet4ui>(block1.packet[3].m_val));
1045
+ v1[4] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[4].m_val),
1046
+ reinterpret_cast<Packet4ui>(block1.packet[5].m_val));
1047
+ v1[5] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[4].m_val),
1048
+ reinterpret_cast<Packet4ui>(block1.packet[5].m_val));
1049
+ v1[6] = vec_mergeh(reinterpret_cast<Packet4ui>(block1.packet[6].m_val),
1050
+ reinterpret_cast<Packet4ui>(block1.packet[7].m_val));
1051
+ v1[7] = vec_mergel(reinterpret_cast<Packet4ui>(block1.packet[6].m_val),
1052
+ reinterpret_cast<Packet4ui>(block1.packet[7].m_val));
1053
+
1054
+ #ifdef EIGEN_VECTORIZE_VSX
1055
+ block1.packet[0] = reinterpret_cast<Packet8us>(
1056
+ vec_mergeh(reinterpret_cast<Packet2ul>(v1[0]), reinterpret_cast<Packet2ul>(v1[2])));
1057
+ block1.packet[2] = reinterpret_cast<Packet8us>(
1058
+ vec_mergel(reinterpret_cast<Packet2ul>(v1[0]), reinterpret_cast<Packet2ul>(v1[2])));
1059
+ block1.packet[4] = reinterpret_cast<Packet8us>(
1060
+ vec_mergeh(reinterpret_cast<Packet2ul>(v1[1]), reinterpret_cast<Packet2ul>(v1[3])));
1061
+ block1.packet[6] = reinterpret_cast<Packet8us>(
1062
+ vec_mergel(reinterpret_cast<Packet2ul>(v1[1]), reinterpret_cast<Packet2ul>(v1[3])));
1063
+ block1.packet[1] = reinterpret_cast<Packet8us>(
1064
+ vec_mergeh(reinterpret_cast<Packet2ul>(v1[4]), reinterpret_cast<Packet2ul>(v1[6])));
1065
+ block1.packet[3] = reinterpret_cast<Packet8us>(
1066
+ vec_mergel(reinterpret_cast<Packet2ul>(v1[4]), reinterpret_cast<Packet2ul>(v1[6])));
1067
+ block1.packet[5] = reinterpret_cast<Packet8us>(
1068
+ vec_mergeh(reinterpret_cast<Packet2ul>(v1[5]), reinterpret_cast<Packet2ul>(v1[7])));
1069
+ block1.packet[7] = reinterpret_cast<Packet8us>(
1070
+ vec_mergel(reinterpret_cast<Packet2ul>(v1[5]), reinterpret_cast<Packet2ul>(v1[7])));
1071
+ #else
1072
+ block1.packet[0] = reinterpret_cast<Packet8us>(vec_perm(v1[0], v1[2], p16uc_TRANSPOSE64_HI));
1073
+ block1.packet[2] = reinterpret_cast<Packet8us>(vec_perm(v1[0], v1[2], p16uc_TRANSPOSE64_LO));
1074
+ block1.packet[4] = reinterpret_cast<Packet8us>(vec_perm(v1[1], v1[3], p16uc_TRANSPOSE64_HI));
1075
+ block1.packet[6] = reinterpret_cast<Packet8us>(vec_perm(v1[1], v1[3], p16uc_TRANSPOSE64_LO));
1076
+ block1.packet[1] = reinterpret_cast<Packet8us>(vec_perm(v1[4], v1[6], p16uc_TRANSPOSE64_HI));
1077
+ block1.packet[3] = reinterpret_cast<Packet8us>(vec_perm(v1[4], v1[6], p16uc_TRANSPOSE64_LO));
1078
+ block1.packet[5] = reinterpret_cast<Packet8us>(vec_perm(v1[5], v1[7], p16uc_TRANSPOSE64_HI));
1079
+ block1.packet[7] = reinterpret_cast<Packet8us>(vec_perm(v1[5], v1[7], p16uc_TRANSPOSE64_LO));
1080
+ #endif
1081
+
1082
+ for (Index M = 0; M < 8; M++) {
1083
+ pstore<bfloat16>(blockA + ri + (vectorSize * M), block1.packet[M]);
1084
+ }
1085
+
1086
+ ri += vectorSize * vectorSize;
1087
+ }
1088
+ for (; i + 2 <= depth; i += 2) {
1089
+ for (Index M = 0; M < vectorSize; M++) {
1090
+ blockA[ri + (M * 2) + 0] = lhs2(M, i + 0);
1091
+ blockA[ri + (M * 2) + 1] = lhs2(M, i + 1);
1092
+ }
1093
+
1094
+ ri += 2 * vectorSize;
1095
+ }
1096
+ if (depth & 1) {
1097
+ for (Index M = 0; M < vectorSize; M++) {
1098
+ blockA[ri + M] = lhs2(M, i);
1099
+ }
1100
+
1101
+ ri += vectorSize;
1102
+ }
1103
+ }
1104
+
1105
+ if (PanelMode) ri += vectorSize * (stride - offset - depth);
1106
+ }
1107
+ if (j + 4 <= rows) {
1108
+ const DataMapper lhs2 = lhs.getSubMapper(j, 0);
1109
+ Index i = 0;
1110
+
1111
+ if (PanelMode) ri += 4 * offset;
1112
+
1113
+ for (; i + 2 <= depth; i += 2) {
1114
+ if (StorageOrder == ColMajor) {
1115
+ PacketBlock<Packet8bf, 2> block;
1116
+
1117
+ block.packet[0] = lhs2.template loadPacketPartial<Packet8bf>(0, i + 0, 4);
1118
+ block.packet[1] = lhs2.template loadPacketPartial<Packet8bf>(0, i + 1, 4);
1119
+
1120
+ block.packet[0] = vec_mergeh(block.packet[0].m_val, block.packet[1].m_val);
1121
+
1122
+ pstore<bfloat16>(blockA + ri, block.packet[0]);
1123
+ } else {
1124
+ blockA[ri + 0] = lhs2(0, i + 0);
1125
+ blockA[ri + 1] = lhs2(0, i + 1);
1126
+ blockA[ri + 2] = lhs2(1, i + 0);
1127
+ blockA[ri + 3] = lhs2(1, i + 1);
1128
+ blockA[ri + 4] = lhs2(2, i + 0);
1129
+ blockA[ri + 5] = lhs2(2, i + 1);
1130
+ blockA[ri + 6] = lhs2(3, i + 0);
1131
+ blockA[ri + 7] = lhs2(3, i + 1);
1132
+ }
1133
+
1134
+ ri += 2 * 4;
1135
+ }
1136
+ if (depth & 1) {
1137
+ if (StorageOrder == ColMajor) {
1138
+ Packet8bf lhsV = lhs2.template loadPacketPartial<Packet8bf>(0, i + 0, 4);
1139
+
1140
+ pstore_partial<bfloat16>(blockA + ri, lhsV, 4);
1141
+ } else {
1142
+ blockA[ri + 0] = lhs2(0, i);
1143
+ blockA[ri + 1] = lhs2(1, i);
1144
+ blockA[ri + 2] = lhs2(2, i);
1145
+ blockA[ri + 3] = lhs2(3, i);
1146
+ }
1147
+
1148
+ ri += 4;
1149
+ }
1150
+
1151
+ if (PanelMode) ri += 4 * (stride - offset - depth);
1152
+ j += 4;
1153
+ }
1154
+
1155
+ if (j < rows) {
1156
+ if (PanelMode) ri += offset * (rows - j);
1157
+
1158
+ Index i = 0;
1159
+ for (; i + 2 <= depth; i += 2) {
1160
+ Index k = j;
1161
+ for (; k < rows; k++) {
1162
+ blockA[ri + 0] = lhs(k, i + 0);
1163
+ blockA[ri + 1] = lhs(k, i + 1);
1164
+ ri += 2;
1165
+ }
1166
+ }
1167
+ if (depth & 1) {
1168
+ for (; j < rows; j++) {
1169
+ blockA[ri] = lhs(j, i);
1170
+ ri += 1;
1171
+ }
1172
+ }
1173
+ }
1174
+ }
1175
+ };
1176
+
1177
+ // General template for rhs packing, bfloat16 specialization.
1178
+ template <typename DataMapper, int StorageOrder, bool PanelMode>
1179
+ struct dhs_pack<bfloat16, DataMapper, Packet8bf, StorageOrder, PanelMode, false> {
1180
+ EIGEN_STRONG_INLINE void operator()(bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride,
1181
+ Index offset) {
1182
+ const Index vectorSize = quad_traits<bfloat16>::vectorsize;
1183
+ Index ri = 0, j = 0;
1184
+
1185
+ for (; j + 4 <= cols; j += 4) {
1186
+ const DataMapper rhs2 = rhs.getSubMapper(0, j);
1187
+ Index i = 0;
1188
+
1189
+ if (PanelMode) ri += 4 * offset;
1190
+
1191
+ for (; i + vectorSize <= depth; i += vectorSize) {
1192
+ if (StorageOrder == ColMajor) {
1193
+ PacketBlock<Packet8bf, 4> block;
1194
+
1195
+ bload<DataMapper, Packet8bf, 4, StorageOrder, false, 4>(block, rhs2, i, 0);
1196
+
1197
+ Packet4ui t0, t1, t2, t3;
1198
+
1199
+ t0 = vec_mergeh(reinterpret_cast<Packet4ui>(block.packet[0].m_val),
1200
+ reinterpret_cast<Packet4ui>(block.packet[1].m_val));
1201
+ t1 = vec_mergel(reinterpret_cast<Packet4ui>(block.packet[0].m_val),
1202
+ reinterpret_cast<Packet4ui>(block.packet[1].m_val));
1203
+ t2 = vec_mergeh(reinterpret_cast<Packet4ui>(block.packet[2].m_val),
1204
+ reinterpret_cast<Packet4ui>(block.packet[3].m_val));
1205
+ t3 = vec_mergel(reinterpret_cast<Packet4ui>(block.packet[2].m_val),
1206
+ reinterpret_cast<Packet4ui>(block.packet[3].m_val));
1207
+
1208
+ #ifdef EIGEN_VECTORIZE_VSX
1209
+ block.packet[0] =
1210
+ reinterpret_cast<Packet8us>(vec_mergeh(reinterpret_cast<Packet2ul>(t0), reinterpret_cast<Packet2ul>(t2)));
1211
+ block.packet[1] =
1212
+ reinterpret_cast<Packet8us>(vec_mergel(reinterpret_cast<Packet2ul>(t0), reinterpret_cast<Packet2ul>(t2)));
1213
+ block.packet[2] =
1214
+ reinterpret_cast<Packet8us>(vec_mergeh(reinterpret_cast<Packet2ul>(t1), reinterpret_cast<Packet2ul>(t3)));
1215
+ block.packet[3] =
1216
+ reinterpret_cast<Packet8us>(vec_mergel(reinterpret_cast<Packet2ul>(t1), reinterpret_cast<Packet2ul>(t3)));
1217
+ #else
1218
+ block.packet[0] = reinterpret_cast<Packet8us>(vec_perm(t0, t2, p16uc_TRANSPOSE64_HI));
1219
+ block.packet[1] = reinterpret_cast<Packet8us>(vec_perm(t0, t2, p16uc_TRANSPOSE64_LO));
1220
+ block.packet[2] = reinterpret_cast<Packet8us>(vec_perm(t1, t3, p16uc_TRANSPOSE64_HI));
1221
+ block.packet[3] = reinterpret_cast<Packet8us>(vec_perm(t1, t3, p16uc_TRANSPOSE64_LO));
1222
+ #endif
1223
+
1224
+ storeBlock<bfloat16, Packet8bf, 4>(blockB + ri, block);
1225
+ } else {
1226
+ PacketBlock<Packet8bf, 8> block;
1227
+
1228
+ for (int M = 0; M < 8; M++) {
1229
+ block.packet[M] = rhs2.template loadPacketPartial<Packet8bf>(i + M, 0, 4);
1230
+ }
1231
+
1232
+ block.packet[0] = vec_mergeh(block.packet[0].m_val, block.packet[1].m_val);
1233
+ block.packet[1] = vec_mergeh(block.packet[2].m_val, block.packet[3].m_val);
1234
+ block.packet[2] = vec_mergeh(block.packet[4].m_val, block.packet[5].m_val);
1235
+ block.packet[3] = vec_mergeh(block.packet[6].m_val, block.packet[7].m_val);
1236
+
1237
+ const Index size = 16 / sizeof(bfloat16);
1238
+
1239
+ for (int M = 0; M < 4; M++) {
1240
+ pstore<bfloat16>(blockB + ri + (M * size), block.packet[M]);
1241
+ }
1242
+ }
1243
+
1244
+ ri += 4 * vectorSize;
1245
+ }
1246
+ for (; i + 2 <= depth; i += 2) {
1247
+ if (StorageOrder == ColMajor) {
1248
+ blockB[ri + 0] = rhs2(i + 0, 0);
1249
+ blockB[ri + 1] = rhs2(i + 1, 0);
1250
+ blockB[ri + 2] = rhs2(i + 0, 1);
1251
+ blockB[ri + 3] = rhs2(i + 1, 1);
1252
+ blockB[ri + 4] = rhs2(i + 0, 2);
1253
+ blockB[ri + 5] = rhs2(i + 1, 2);
1254
+ blockB[ri + 6] = rhs2(i + 0, 3);
1255
+ blockB[ri + 7] = rhs2(i + 1, 3);
1256
+ } else {
1257
+ PacketBlock<Packet8bf, 2> block;
1258
+
1259
+ for (int M = 0; M < 2; M++) {
1260
+ block.packet[M] = rhs2.template loadPacketPartial<Packet8bf>(i + M, 0, 4);
1261
+ }
1262
+
1263
+ block.packet[0] = vec_mergeh(block.packet[0].m_val, block.packet[1].m_val);
1264
+
1265
+ pstore<bfloat16>(blockB + ri, block.packet[0]);
1266
+ }
1267
+
1268
+ ri += 4 * 2;
1269
+ }
1270
+ if (depth & 1) {
1271
+ blockB[ri + 0] = rhs2(i, 0);
1272
+ blockB[ri + 1] = rhs2(i, 1);
1273
+ blockB[ri + 2] = rhs2(i, 2);
1274
+ blockB[ri + 3] = rhs2(i, 3);
1275
+
1276
+ ri += 4;
1277
+ }
1278
+
1279
+ if (PanelMode) ri += 4 * (stride - offset - depth);
1280
+ }
1281
+
1282
+ if (j < cols) {
1283
+ if (PanelMode) ri += offset * (cols - j);
1284
+
1285
+ Index i = 0;
1286
+ for (; i + 2 <= depth; i += 2) {
1287
+ Index k = j;
1288
+ for (; k < cols; k++) {
1289
+ blockB[ri + 0] = rhs(i + 0, k);
1290
+ blockB[ri + 1] = rhs(i + 1, k);
1291
+ ri += 2;
1292
+ }
1293
+ }
1294
+ if (depth & 1) {
1295
+ for (; j < cols; j++) {
1296
+ blockB[ri] = rhs(i, j);
1297
+ ri += 1;
1298
+ }
1299
+ }
1300
+ }
1301
+ }
1302
+ };
1303
+
1304
+ // General template for lhs complex packing, float64 specialization.
1305
+ template <typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode>
1306
+ struct dhs_cpack<double, DataMapper, Packet, PacketC, StorageOrder, Conjugate, PanelMode, true> {
1307
+ EIGEN_ALWAYS_INLINE void dhs_ccopy(double* blockAt, const DataMapper& lhs2, Index& i, Index& rir, Index& rii,
1308
+ Index depth, const Index vectorSize) {
1309
+ PacketBlock<Packet, 2> blockr, blocki;
1310
+ PacketBlock<PacketC, 4> cblock;
1311
+
1312
+ for (; i + vectorSize <= depth; i += vectorSize) {
1313
+ if (StorageOrder == ColMajor) {
1314
+ cblock.packet[0] = lhs2.template loadPacket<PacketC>(0, i + 0); //[a1 a1i]
1315
+ cblock.packet[1] = lhs2.template loadPacket<PacketC>(0, i + 1); //[b1 b1i]
1316
+
1317
+ cblock.packet[2] = lhs2.template loadPacket<PacketC>(1, i + 0); //[a2 a2i]
1318
+ cblock.packet[3] = lhs2.template loadPacket<PacketC>(1, i + 1); //[b2 b2i]
1319
+
1320
+ blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[2].v); //[a1 a2]
1321
+ blockr.packet[1] = vec_mergeh(cblock.packet[1].v, cblock.packet[3].v); //[b1 b2]
1322
+
1323
+ blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[2].v);
1324
+ blocki.packet[1] = vec_mergel(cblock.packet[1].v, cblock.packet[3].v);
1325
+ } else {
1326
+ cblock.packet[0] = lhs2.template loadPacket<PacketC>(0, i); //[a1 a1i]
1327
+ cblock.packet[1] = lhs2.template loadPacket<PacketC>(1, i); //[a2 a2i]
1328
+
1329
+ cblock.packet[2] = lhs2.template loadPacket<PacketC>(0, i + 1); //[b1 b1i]
1330
+ cblock.packet[3] = lhs2.template loadPacket<PacketC>(1, i + 1); //[b2 b2i
1331
+
1332
+ blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[1].v); //[a1 a2]
1333
+ blockr.packet[1] = vec_mergeh(cblock.packet[2].v, cblock.packet[3].v); //[b1 b2]
1334
+
1335
+ blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[1].v);
1336
+ blocki.packet[1] = vec_mergel(cblock.packet[2].v, cblock.packet[3].v);
1337
+ }
1338
+
1339
+ if (Conjugate) {
1340
+ blocki.packet[0] = -blocki.packet[0];
1341
+ blocki.packet[1] = -blocki.packet[1];
1342
+ }
1343
+
1344
+ storeBlock<double, Packet, 2>(blockAt + rir, blockr);
1345
+ storeBlock<double, Packet, 2>(blockAt + rii, blocki);
1346
+
1347
+ rir += 2 * vectorSize;
1348
+ rii += 2 * vectorSize;
1349
+ }
1350
+ }
1351
+
1352
+ EIGEN_STRONG_INLINE void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows,
1353
+ Index stride, Index offset) {
1354
+ const Index vectorSize = quad_traits<double>::vectorsize;
1355
+ const Index vectorDelta = vectorSize * ((PanelMode) ? stride : depth);
1356
+ Index rir = ((PanelMode) ? (vectorSize * offset) : 0), rii;
1357
+ double* blockAt = reinterpret_cast<double*>(blockA);
1358
+ Index j = 0;
1359
+
1360
+ for (; j + vectorSize <= rows; j += vectorSize) {
1361
+ const DataMapper lhs2 = lhs.getSubMapper(j, 0);
1362
+ Index i = 0;
1363
+
1364
+ rii = rir + vectorDelta;
1365
+
1366
+ dhs_ccopy(blockAt, lhs2, i, rir, rii, depth, vectorSize);
1367
+
1368
+ for (; i < depth; i++) {
1369
+ PacketBlock<Packet, 1> blockr, blocki;
1370
+ PacketBlock<PacketC, 2> cblock;
1371
+
1372
+ cblock.packet[0] = lhs2.template loadPacket<PacketC>(0, i);
1373
+ cblock.packet[1] = lhs2.template loadPacket<PacketC>(1, i);
1374
+
1375
+ blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[1].v);
1376
+ blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[1].v);
1377
+
1378
+ if (Conjugate) {
1379
+ blocki.packet[0] = -blocki.packet[0];
1380
+ }
1381
+
1382
+ pstore<double>(blockAt + rir, blockr.packet[0]);
1383
+ pstore<double>(blockAt + rii, blocki.packet[0]);
1384
+
1385
+ rir += vectorSize;
1386
+ rii += vectorSize;
1387
+ }
1388
+
1389
+ rir += ((PanelMode) ? (vectorSize * (2 * stride - depth)) : vectorDelta);
1390
+ }
1391
+
1392
+ if (j < rows) {
1393
+ if (PanelMode) rir += (offset * (rows - j - vectorSize));
1394
+ rii = rir + (((PanelMode) ? stride : depth) * (rows - j));
1395
+
1396
+ for (Index i = 0; i < depth; i++) {
1397
+ Index k = j;
1398
+ for (; k < rows; k++) {
1399
+ blockAt[rir] = lhs(k, i).real();
1400
+
1401
+ if (Conjugate)
1402
+ blockAt[rii] = -lhs(k, i).imag();
1403
+ else
1404
+ blockAt[rii] = lhs(k, i).imag();
1405
+
1406
+ rir += 1;
1407
+ rii += 1;
1408
+ }
1409
+ }
1410
+ }
1411
+ }
1412
+ };
1413
+
1414
+ // General template for rhs complex packing, float64 specialization.
1415
+ template <typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode>
1416
+ struct dhs_cpack<double, DataMapper, Packet, PacketC, StorageOrder, Conjugate, PanelMode, false> {
1417
+ EIGEN_ALWAYS_INLINE void dhs_ccopy(double* blockBt, const DataMapper& rhs2, Index& i, Index& rir, Index& rii,
1418
+ Index depth, const Index vectorSize) {
1419
+ for (; i < depth; i++) {
1420
+ PacketBlock<PacketC, 4> cblock;
1421
+ PacketBlock<Packet, 2> blockr, blocki;
1422
+
1423
+ bload<DataMapper, PacketC, 2, ColMajor, false, 4>(cblock, rhs2, i, 0);
1424
+
1425
+ blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[1].v);
1426
+ blockr.packet[1] = vec_mergeh(cblock.packet[2].v, cblock.packet[3].v);
1427
+
1428
+ blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[1].v);
1429
+ blocki.packet[1] = vec_mergel(cblock.packet[2].v, cblock.packet[3].v);
1430
+
1431
+ if (Conjugate) {
1432
+ blocki.packet[0] = -blocki.packet[0];
1433
+ blocki.packet[1] = -blocki.packet[1];
1434
+ }
1435
+
1436
+ storeBlock<double, Packet, 2>(blockBt + rir, blockr);
1437
+ storeBlock<double, Packet, 2>(blockBt + rii, blocki);
1438
+
1439
+ rir += 2 * vectorSize;
1440
+ rii += 2 * vectorSize;
1441
+ }
1442
+ }
1443
+
1444
+ EIGEN_STRONG_INLINE void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols,
1445
+ Index stride, Index offset) {
1446
+ const Index vectorSize = quad_traits<double>::vectorsize;
1447
+ const Index vectorDelta = 2 * vectorSize * ((PanelMode) ? stride : depth);
1448
+ Index rir = ((PanelMode) ? (2 * vectorSize * offset) : 0), rii;
1449
+ double* blockBt = reinterpret_cast<double*>(blockB);
1450
+ Index j = 0;
1451
+
1452
+ for (; j + 2 * vectorSize <= cols; j += 2 * vectorSize) {
1453
+ const DataMapper rhs2 = rhs.getSubMapper(0, j);
1454
+ Index i = 0;
1455
+
1456
+ rii = rir + vectorDelta;
1457
+
1458
+ dhs_ccopy(blockBt, rhs2, i, rir, rii, depth, vectorSize);
1459
+
1460
+ rir += ((PanelMode) ? (2 * vectorSize * (2 * stride - depth)) : vectorDelta);
1461
+ }
1462
+
1463
+ if (PanelMode) rir -= (offset * (2 * vectorSize - 1));
1464
+
1465
+ for (; j < cols; j++) {
1466
+ const DataMapper rhs2 = rhs.getSubMapper(0, j);
1467
+ rii = rir + ((PanelMode) ? stride : depth);
1468
+
1469
+ for (Index i = 0; i < depth; i++) {
1470
+ blockBt[rir] = rhs2(i, 0).real();
1471
+
1472
+ if (Conjugate)
1473
+ blockBt[rii] = -rhs2(i, 0).imag();
1474
+ else
1475
+ blockBt[rii] = rhs2(i, 0).imag();
1476
+
1477
+ rir += 1;
1478
+ rii += 1;
1479
+ }
1480
+
1481
+ rir += ((PanelMode) ? (2 * stride - depth) : depth);
1482
+ }
1483
+ }
1484
+ };
1485
+
1486
+ /**************
1487
+ * GEMM utils *
1488
+ **************/
1489
+
1490
+ // 512-bits rank1-update of acc. It can either positive or negative accumulate (useful for complex gemm).
1491
+ template <typename Packet, bool NegativeAccumulate, int N>
1492
+ EIGEN_ALWAYS_INLINE void pger_common(PacketBlock<Packet, N>* acc, const Packet& lhsV, const Packet* rhsV) {
1493
+ if (NegativeAccumulate) {
1494
+ for (int M = 0; M < N; M++) {
1495
+ acc->packet[M] = vec_nmsub(lhsV, rhsV[M], acc->packet[M]);
1496
+ }
1497
+ } else {
1498
+ for (int M = 0; M < N; M++) {
1499
+ acc->packet[M] = vec_madd(lhsV, rhsV[M], acc->packet[M]);
1500
+ }
1501
+ }
1502
+ }
1503
+
1504
+ template <int N, typename Scalar, typename Packet, bool NegativeAccumulate>
1505
+ EIGEN_ALWAYS_INLINE void pger(PacketBlock<Packet, N>* acc, const Scalar* lhs, const Packet* rhsV) {
1506
+ Packet lhsV = pload<Packet>(lhs);
1507
+
1508
+ pger_common<Packet, NegativeAccumulate, N>(acc, lhsV, rhsV);
1509
+ }
1510
+
1511
+ // 512-bits rank1-update of complex acc. It takes decoupled accumulators as entries. It also takes cares of mixed types
1512
+ // real * complex and complex * real.
1513
+ template <int N, typename Packet, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
1514
+ EIGEN_ALWAYS_INLINE void pgerc_common(PacketBlock<Packet, N>* accReal, PacketBlock<Packet, N>* accImag,
1515
+ const Packet& lhsV, Packet& lhsVi, const Packet* rhsV, const Packet* rhsVi) {
1516
+ pger_common<Packet, false, N>(accReal, lhsV, rhsV);
1517
+ if (LhsIsReal) {
1518
+ pger_common<Packet, ConjugateRhs, N>(accImag, lhsV, rhsVi);
1519
+ EIGEN_UNUSED_VARIABLE(lhsVi);
1520
+ } else {
1521
+ if (!RhsIsReal) {
1522
+ pger_common<Packet, ConjugateLhs == ConjugateRhs, N>(accReal, lhsVi, rhsVi);
1523
+ pger_common<Packet, ConjugateRhs, N>(accImag, lhsV, rhsVi);
1524
+ } else {
1525
+ EIGEN_UNUSED_VARIABLE(rhsVi);
1526
+ }
1527
+ pger_common<Packet, ConjugateLhs, N>(accImag, lhsVi, rhsV);
1528
+ }
1529
+ }
1530
+
1531
+ template <int N, typename Scalar, typename Packet, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
1532
+ EIGEN_ALWAYS_INLINE void pgerc(PacketBlock<Packet, N>* accReal, PacketBlock<Packet, N>* accImag, const Scalar* lhs_ptr,
1533
+ const Scalar* lhs_ptr_imag, const Packet* rhsV, const Packet* rhsVi) {
1534
+ Packet lhsV = ploadLhs<Packet>(lhs_ptr);
1535
+ Packet lhsVi;
1536
+ if (!LhsIsReal)
1537
+ lhsVi = ploadLhs<Packet>(lhs_ptr_imag);
1538
+ else
1539
+ EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
1540
+
1541
+ pgerc_common<N, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(accReal, accImag, lhsV, lhsVi, rhsV, rhsVi);
1542
+ }
1543
+
1544
+ template <typename Packet>
1545
+ EIGEN_ALWAYS_INLINE Packet ploadLhs(const __UNPACK_TYPE__(Packet) * lhs) {
1546
+ return ploadu<Packet>(lhs);
1547
+ }
1548
+
1549
+ // Zero the accumulator on PacketBlock.
1550
+ template <typename Packet, int N>
1551
+ EIGEN_ALWAYS_INLINE void bsetzero(PacketBlock<Packet, N>& acc) {
1552
+ for (int M = 0; M < N; M++) {
1553
+ acc.packet[M] = pset1<Packet>((__UNPACK_TYPE__(Packet))0);
1554
+ }
1555
+ }
1556
+
1557
+ template <typename Packet, int N>
1558
+ EIGEN_ALWAYS_INLINE void bscalec_common(PacketBlock<Packet, N>& acc, PacketBlock<Packet, N>& accZ,
1559
+ const Packet& pAlpha) {
1560
+ for (int M = 0; M < N; M++) {
1561
+ acc.packet[M] = vec_mul(accZ.packet[M], pAlpha);
1562
+ }
1563
+ }
1564
+
1565
+ template <typename Packet, int N>
1566
+ EIGEN_ALWAYS_INLINE void band(PacketBlock<Packet, N>& acc, const Packet& pMask) {
1567
+ for (int M = 0; M < N; M++) {
1568
+ acc.packet[M] = pand<Packet>(acc.packet[M], pMask);
1569
+ }
1570
+ }
1571
+
1572
+ // Complex version of PacketBlock scaling.
1573
+ template <typename Packet, int N, bool mask>
1574
+ EIGEN_ALWAYS_INLINE void bscalec(PacketBlock<Packet, N>& aReal, PacketBlock<Packet, N>& aImag, const Packet& bReal,
1575
+ const Packet& bImag, PacketBlock<Packet, N>& cReal, PacketBlock<Packet, N>& cImag,
1576
+ const Packet& pMask) {
1577
+ if (mask && (sizeof(__UNPACK_TYPE__(Packet)) == sizeof(float))) {
1578
+ band<Packet, N>(aReal, pMask);
1579
+ band<Packet, N>(aImag, pMask);
1580
+ } else {
1581
+ EIGEN_UNUSED_VARIABLE(pMask);
1582
+ }
1583
+
1584
+ bscalec_common<Packet, N>(cReal, aReal, bReal);
1585
+
1586
+ bscalec_common<Packet, N>(cImag, aImag, bReal);
1587
+
1588
+ pger_common<Packet, true, N>(&cReal, bImag, aImag.packet);
1589
+
1590
+ pger_common<Packet, false, N>(&cImag, bImag, aReal.packet);
1591
+ }
1592
+
1593
+ // Load a PacketBlock, the N parameters make tuning gemm easier so we can add more accumulators as needed.
1594
+ //
1595
+ // full = operate (load) on the entire PacketBlock or only half
1596
+ template <typename DataMapper, typename Packet, const Index accCols, int StorageOrder, bool Complex, int N, bool full>
1597
+ EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet, N*(Complex ? 2 : 1)>& acc, const DataMapper& res, Index row,
1598
+ Index col) {
1599
+ if (StorageOrder == RowMajor) {
1600
+ for (int M = 0; M < N; M++) {
1601
+ acc.packet[M] = res.template loadPacket<Packet>(row + M, col);
1602
+ }
1603
+ if (Complex) {
1604
+ for (int M = 0; M < N; M++) {
1605
+ acc.packet[M + N] = res.template loadPacket<Packet>(row + M, col + accCols);
1606
+ }
1607
+ }
1608
+ } else {
1609
+ for (int M = 0; M < N; M++) {
1610
+ acc.packet[M] = res.template loadPacket<Packet>(row, col + M);
1611
+ }
1612
+ if (Complex && full) {
1613
+ for (int M = 0; M < N; M++) {
1614
+ acc.packet[M + N] = res.template loadPacket<Packet>(row + accCols, col + M);
1615
+ }
1616
+ }
1617
+ }
1618
+ }
1619
+
1620
+ template <typename DataMapper, typename Packet, int N>
1621
+ EIGEN_ALWAYS_INLINE void bstore(PacketBlock<Packet, N>& acc, const DataMapper& res, Index row) {
1622
+ for (int M = 0; M < N; M++) {
1623
+ res.template storePacket<Packet>(row, M, acc.packet[M]);
1624
+ }
1625
+ }
1626
+
1627
+ #ifdef USE_PARTIAL_PACKETS
1628
+ template <typename DataMapper, typename Packet, const Index accCols, bool Complex, Index N, bool full>
1629
+ EIGEN_ALWAYS_INLINE void bload_partial(PacketBlock<Packet, N*(Complex ? 2 : 1)>& acc, const DataMapper& res, Index row,
1630
+ Index elements) {
1631
+ for (Index M = 0; M < N; M++) {
1632
+ acc.packet[M] = res.template loadPacketPartial<Packet>(row, M, elements);
1633
+ }
1634
+ if (Complex && full) {
1635
+ for (Index M = 0; M < N; M++) {
1636
+ acc.packet[M + N] = res.template loadPacketPartial<Packet>(row + accCols, M, elements);
1637
+ }
1638
+ }
1639
+ }
1640
+
1641
+ template <typename DataMapper, typename Packet, Index N>
1642
+ EIGEN_ALWAYS_INLINE void bstore_partial(PacketBlock<Packet, N>& acc, const DataMapper& res, Index row, Index elements) {
1643
+ for (Index M = 0; M < N; M++) {
1644
+ res.template storePacketPartial<Packet>(row, M, acc.packet[M], elements);
1645
+ }
1646
+ }
1647
+ #endif
1648
+
1649
+ #ifdef _ARCH_PWR10
1650
+ #define USE_P10_AND_PVIPR2_0 (EIGEN_COMP_LLVM || (__GNUC__ >= 11))
1651
+ #else
1652
+ #define USE_P10_AND_PVIPR2_0 0
1653
+ #endif
1654
+
1655
+ #if !USE_P10_AND_PVIPR2_0
1656
+ const static Packet4i mask4[4] = {{0, 0, 0, 0}, {-1, 0, 0, 0}, {-1, -1, 0, 0}, {-1, -1, -1, 0}};
1657
+ #endif
1658
+
1659
+ template <typename Packet>
1660
+ EIGEN_ALWAYS_INLINE Packet bmask(const Index remaining_rows) {
1661
+ #if USE_P10_AND_PVIPR2_0
1662
+ #ifdef _BIG_ENDIAN
1663
+ return Packet(vec_reve(vec_genwm((1 << remaining_rows) - 1)));
1664
+ #else
1665
+ return Packet(vec_genwm((1 << remaining_rows) - 1));
1666
+ #endif
1667
+ #else
1668
+ return Packet(mask4[remaining_rows]);
1669
+ #endif
1670
+ }
1671
+
1672
+ template <>
1673
+ EIGEN_ALWAYS_INLINE Packet2d bmask<Packet2d>(const Index remaining_rows) {
1674
+ #if USE_P10_AND_PVIPR2_0
1675
+ Packet2d mask2 = Packet2d(vec_gendm(remaining_rows));
1676
+ #ifdef _BIG_ENDIAN
1677
+ return preverse(mask2);
1678
+ #else
1679
+ return mask2;
1680
+ #endif
1681
+ #else
1682
+ Packet2l ret = {-remaining_rows, 0};
1683
+ return Packet2d(ret);
1684
+ #endif
1685
+ }
1686
+
1687
+ template <typename Packet, int N>
1688
+ EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet, N>& acc, PacketBlock<Packet, N>& accZ, const Packet& pAlpha) {
1689
+ for (int M = 0; M < N; M++) {
1690
+ acc.packet[M] = pmadd<Packet>(pAlpha, accZ.packet[M], acc.packet[M]);
1691
+ }
1692
+ }
1693
+
1694
+ // Scale the PacketBlock vectors by alpha.
1695
+ template <typename Packet, int N, bool mask>
1696
+ EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet, N>& acc, PacketBlock<Packet, N>& accZ, const Packet& pAlpha,
1697
+ const Packet& pMask) {
1698
+ if (mask) {
1699
+ band<Packet, N>(accZ, pMask);
1700
+ } else {
1701
+ EIGEN_UNUSED_VARIABLE(pMask);
1702
+ }
1703
+
1704
+ bscale<Packet, N>(acc, accZ, pAlpha);
1705
+ }
1706
+
1707
+ template <typename Packet, int N, bool real>
1708
+ EIGEN_ALWAYS_INLINE void pbroadcastN(const __UNPACK_TYPE__(Packet) * ap0, const __UNPACK_TYPE__(Packet) * ap1,
1709
+ const __UNPACK_TYPE__(Packet) * ap2, Packet& a0, Packet& a1, Packet& a2,
1710
+ Packet& a3) {
1711
+ a0 = pset1<Packet>(ap0[0]);
1712
+ if (N == 4) {
1713
+ a1 = pset1<Packet>(ap0[1]);
1714
+ a2 = pset1<Packet>(ap0[2]);
1715
+ a3 = pset1<Packet>(ap0[3]);
1716
+ EIGEN_UNUSED_VARIABLE(ap1);
1717
+ EIGEN_UNUSED_VARIABLE(ap2);
1718
+ } else {
1719
+ if (N > 1) {
1720
+ a1 = pset1<Packet>(ap1[0]);
1721
+ } else {
1722
+ EIGEN_UNUSED_VARIABLE(a1);
1723
+ EIGEN_UNUSED_VARIABLE(ap1);
1724
+ }
1725
+ if (N > 2) {
1726
+ a2 = pset1<Packet>(ap2[0]);
1727
+ } else {
1728
+ EIGEN_UNUSED_VARIABLE(a2);
1729
+ EIGEN_UNUSED_VARIABLE(ap2);
1730
+ }
1731
+ }
1732
+ }
1733
+
1734
+ template <>
1735
+ EIGEN_ALWAYS_INLINE void pbroadcastN<Packet4f, 4, true>(const float* ap0, const float*, const float*, Packet4f& a0,
1736
+ Packet4f& a1, Packet4f& a2, Packet4f& a3) {
1737
+ pbroadcast4<Packet4f>(ap0, a0, a1, a2, a3);
1738
+ }
1739
+
1740
+ template <>
1741
+ EIGEN_ALWAYS_INLINE void pbroadcastN<Packet4f, 4, false>(const float* ap0, const float* ap1, const float* ap2,
1742
+ Packet4f& a0, Packet4f& a1, Packet4f& a2, Packet4f& a3) {
1743
+ pbroadcastN<Packet4f, 4, true>(ap0, ap1, ap2, a0, a1, a2, a3);
1744
+ }
1745
+
1746
+ template <>
1747
+ EIGEN_ALWAYS_INLINE void pbroadcastN<Packet2d, 4, false>(const double* ap0, const double*, const double*, Packet2d& a0,
1748
+ Packet2d& a1, Packet2d& a2, Packet2d& a3) {
1749
+ a1 = pload<Packet2d>(ap0);
1750
+ a3 = pload<Packet2d>(ap0 + 2);
1751
+ a0 = vec_splat(a1, 0);
1752
+ a1 = vec_splat(a1, 1);
1753
+ a2 = vec_splat(a3, 0);
1754
+ a3 = vec_splat(a3, 1);
1755
+ }
1756
+
1757
+ // Grab two decouples real/imaginary PacketBlocks and return two coupled (real/imaginary pairs) PacketBlocks.
1758
+ template <typename Packet, typename Packetc, int N, bool full>
1759
+ EIGEN_ALWAYS_INLINE void bcouple_common(PacketBlock<Packet, N>& taccReal, PacketBlock<Packet, N>& taccImag,
1760
+ PacketBlock<Packetc, N>& acc1, PacketBlock<Packetc, N>& acc2) {
1761
+ for (int M = 0; M < N; M++) {
1762
+ acc1.packet[M].v = vec_mergeh(taccReal.packet[M], taccImag.packet[M]);
1763
+ }
1764
+
1765
+ if (full) {
1766
+ for (int M = 0; M < N; M++) {
1767
+ acc2.packet[M].v = vec_mergel(taccReal.packet[M], taccImag.packet[M]);
1768
+ }
1769
+ }
1770
+ }
1771
+
1772
+ template <typename Packet, typename Packetc, int N, bool full>
1773
+ EIGEN_ALWAYS_INLINE void bcouple(PacketBlock<Packet, N>& taccReal, PacketBlock<Packet, N>& taccImag,
1774
+ PacketBlock<Packetc, N * 2>& tRes, PacketBlock<Packetc, N>& acc1,
1775
+ PacketBlock<Packetc, N>& acc2) {
1776
+ bcouple_common<Packet, Packetc, N, full>(taccReal, taccImag, acc1, acc2);
1777
+
1778
+ for (int M = 0; M < N; M++) {
1779
+ acc1.packet[M] = padd<Packetc>(tRes.packet[M], acc1.packet[M]);
1780
+ }
1781
+
1782
+ if (full) {
1783
+ for (int M = 0; M < N; M++) {
1784
+ acc2.packet[M] = padd<Packetc>(tRes.packet[M + N], acc2.packet[M]);
1785
+ }
1786
+ }
1787
+ }
1788
+
1789
+ // PEEL loop factor.
1790
+ #define PEEL 7
1791
+ #define PEEL_ROW 7
1792
+
1793
+ #define MICRO_UNROLL(func) func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)
1794
+
1795
+ #define MICRO_NORMAL_ROWS accRows == quad_traits<Scalar>::rows || accRows == 1
1796
+
1797
+ #define MICRO_NEW_ROWS ((MICRO_NORMAL_ROWS) ? accRows : 1)
1798
+
1799
+ #define MICRO_RHS(ptr, N) rhs_##ptr##N
1800
+
1801
+ #define MICRO_ZERO_PEEL(peel) \
1802
+ if ((PEEL_ROW > peel) && (peel != 0)) { \
1803
+ bsetzero<Packet, accRows>(accZero##peel); \
1804
+ } else { \
1805
+ EIGEN_UNUSED_VARIABLE(accZero##peel); \
1806
+ }
1807
+
1808
+ #define MICRO_ADD(ptr, N) \
1809
+ if (MICRO_NORMAL_ROWS) { \
1810
+ MICRO_RHS(ptr, 0) += (accRows * N); \
1811
+ } else { \
1812
+ MICRO_RHS(ptr, 0) += N; \
1813
+ MICRO_RHS(ptr, 1) += N; \
1814
+ if (accRows == 3) { \
1815
+ MICRO_RHS(ptr, 2) += N; \
1816
+ } \
1817
+ }
1818
+
1819
+ #define MICRO_ADD_ROWS(N) MICRO_ADD(ptr, N)
1820
+
1821
+ #define MICRO_BROADCAST1(peel, ptr, rhsV, real) \
1822
+ if (MICRO_NORMAL_ROWS) { \
1823
+ pbroadcastN<Packet, accRows, real>(MICRO_RHS(ptr, 0) + (accRows * peel), MICRO_RHS(ptr, 0), MICRO_RHS(ptr, 0), \
1824
+ rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
1825
+ } else { \
1826
+ pbroadcastN<Packet, accRows, real>(MICRO_RHS(ptr, 0) + peel, MICRO_RHS(ptr, 1) + peel, MICRO_RHS(ptr, 2) + peel, \
1827
+ rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
1828
+ }
1829
+
1830
+ #define MICRO_BROADCAST(peel) MICRO_BROADCAST1(peel, ptr, rhsV, true)
1831
+
1832
+ #define MICRO_BROADCAST_EXTRA1(ptr, rhsV, real) \
1833
+ pbroadcastN<Packet, accRows, real>(MICRO_RHS(ptr, 0), MICRO_RHS(ptr, 1), MICRO_RHS(ptr, 2), rhsV[0], rhsV[1], \
1834
+ rhsV[2], rhsV[3]);
1835
+
1836
+ #define MICRO_BROADCAST_EXTRA \
1837
+ Packet rhsV[4]; \
1838
+ MICRO_BROADCAST_EXTRA1(ptr, rhsV, true) \
1839
+ MICRO_ADD_ROWS(1)
1840
+
1841
+ #define MICRO_SRC2(ptr, N, M) \
1842
+ if (MICRO_NORMAL_ROWS) { \
1843
+ EIGEN_UNUSED_VARIABLE(strideB); \
1844
+ EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr, 1)); \
1845
+ EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr, 2)); \
1846
+ } else { \
1847
+ MICRO_RHS(ptr, 1) = rhs_base + N + M; \
1848
+ if (accRows == 3) { \
1849
+ MICRO_RHS(ptr, 2) = rhs_base + N * 2 + M; \
1850
+ } else { \
1851
+ EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr, 2)); \
1852
+ } \
1853
+ }
1854
+
1855
+ #define MICRO_SRC2_PTR MICRO_SRC2(ptr, strideB, 0)
1856
+
1857
+ #define MICRO_ZERO_PEEL_ROW MICRO_UNROLL(MICRO_ZERO_PEEL)
1858
+
1859
+ #define MICRO_WORK_PEEL(peel) \
1860
+ if (PEEL_ROW > peel) { \
1861
+ MICRO_BROADCAST(peel) \
1862
+ pger<accRows, Scalar, Packet, false>(&accZero##peel, lhs_ptr + (remaining_rows * peel), rhsV##peel); \
1863
+ } else { \
1864
+ EIGEN_UNUSED_VARIABLE(rhsV##peel); \
1865
+ }
1866
+
1867
+ #define MICRO_WORK_PEEL_ROW \
1868
+ Packet rhsV0[4], rhsV1[4], rhsV2[4], rhsV3[4], rhsV4[4], rhsV5[4], rhsV6[4], rhsV7[4]; \
1869
+ MICRO_UNROLL(MICRO_WORK_PEEL) \
1870
+ lhs_ptr += (remaining_rows * PEEL_ROW); \
1871
+ MICRO_ADD_ROWS(PEEL_ROW)
1872
+
1873
+ #define MICRO_ADD_PEEL(peel, sum) \
1874
+ if (PEEL_ROW > peel) { \
1875
+ for (Index i = 0; i < accRows; i++) { \
1876
+ accZero##sum.packet[i] += accZero##peel.packet[i]; \
1877
+ } \
1878
+ }
1879
+
1880
+ #define MICRO_ADD_PEEL_ROW \
1881
+ MICRO_ADD_PEEL(4, 0) \
1882
+ MICRO_ADD_PEEL(5, 1) \
1883
+ MICRO_ADD_PEEL(6, 2) MICRO_ADD_PEEL(7, 3) MICRO_ADD_PEEL(2, 0) MICRO_ADD_PEEL(3, 1) MICRO_ADD_PEEL(1, 0)
1884
+
1885
+ #define MICRO_PREFETCHN1(ptr, N) \
1886
+ EIGEN_POWER_PREFETCH(MICRO_RHS(ptr, 0)); \
1887
+ if (N == 2 || N == 3) { \
1888
+ EIGEN_POWER_PREFETCH(MICRO_RHS(ptr, 1)); \
1889
+ if (N == 3) { \
1890
+ EIGEN_POWER_PREFETCH(MICRO_RHS(ptr, 2)); \
1891
+ } \
1892
+ }
1893
+
1894
+ #define MICRO_PREFETCHN(N) MICRO_PREFETCHN1(ptr, N)
1895
+
1896
+ #define MICRO_COMPLEX_PREFETCHN(N) \
1897
+ MICRO_PREFETCHN1(ptr_real, N); \
1898
+ if (!RhsIsReal) { \
1899
+ MICRO_PREFETCHN1(ptr_imag, N); \
1900
+ }
1901
+
1902
+ template <typename Scalar, typename Packet, const Index accRows, const Index remaining_rows>
1903
+ EIGEN_ALWAYS_INLINE void MICRO_EXTRA_ROW(const Scalar*& lhs_ptr, const Scalar*& rhs_ptr0, const Scalar*& rhs_ptr1,
1904
+ const Scalar*& rhs_ptr2, PacketBlock<Packet, accRows>& accZero) {
1905
+ MICRO_BROADCAST_EXTRA
1906
+ pger<accRows, Scalar, Packet, false>(&accZero, lhs_ptr, rhsV);
1907
+ lhs_ptr += remaining_rows;
1908
+ }
1909
+
1910
+ template <typename Scalar, typename Packet, typename DataMapper, const Index accRows, const Index accCols,
1911
+ const Index remaining_rows>
1912
+ EIGEN_ALWAYS_INLINE void gemm_unrolled_row_iteration(const DataMapper& res, const Scalar* lhs_base,
1913
+ const Scalar* rhs_base, Index depth, Index strideA, Index offsetA,
1914
+ Index strideB, Index row, Index rows, const Packet& pAlpha,
1915
+ const Packet& pMask) {
1916
+ const Scalar *rhs_ptr0 = rhs_base, *rhs_ptr1 = NULL, *rhs_ptr2 = NULL;
1917
+ const Scalar* lhs_ptr = lhs_base + row * strideA + remaining_rows * offsetA;
1918
+ PacketBlock<Packet, accRows> accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7, acc;
1919
+
1920
+ MICRO_SRC2_PTR
1921
+ bsetzero<Packet, accRows>(accZero0);
1922
+
1923
+ Index remaining_depth = depth & -quad_traits<Scalar>::rows;
1924
+ Index k = 0;
1925
+ if (remaining_depth >= PEEL_ROW) {
1926
+ MICRO_ZERO_PEEL_ROW
1927
+ do {
1928
+ MICRO_PREFETCHN(accRows)
1929
+ EIGEN_POWER_PREFETCH(lhs_ptr);
1930
+ MICRO_WORK_PEEL_ROW
1931
+ } while ((k += PEEL_ROW) + PEEL_ROW <= remaining_depth);
1932
+ MICRO_ADD_PEEL_ROW
1933
+ }
1934
+ for (; k < depth; k++) {
1935
+ MICRO_EXTRA_ROW<Scalar, Packet, accRows, remaining_rows>(lhs_ptr, rhs_ptr0, rhs_ptr1, rhs_ptr2, accZero0);
1936
+ }
1937
+
1938
+ #ifdef USE_PARTIAL_PACKETS
1939
+ EIGEN_UNUSED_VARIABLE(rows);
1940
+ EIGEN_UNUSED_VARIABLE(pMask);
1941
+ bload_partial<DataMapper, Packet, 0, false, accRows>(acc, res, row, remaining_rows);
1942
+ bscale<Packet, accRows>(acc, accZero0, pAlpha);
1943
+ bstore_partial<DataMapper, Packet, accRows>(acc, res, row, remaining_rows);
1944
+ #else
1945
+ bload<DataMapper, Packet, 0, ColMajor, false, accRows>(acc, res, row, 0);
1946
+ if ((accRows == 1) || (rows >= accCols)) {
1947
+ bscale<Packet, accRows, true>(acc, accZero0, pAlpha, pMask);
1948
+ bstore<DataMapper, Packet, accRows>(acc, res, row);
1949
+ } else {
1950
+ bscale<Packet, accRows, false>(acc, accZero0, pAlpha, pMask);
1951
+ for (Index j = 0; j < accRows; j++) {
1952
+ for (Index i = 0; i < remaining_rows; i++) {
1953
+ res(row + i, j) = acc.packet[j][i];
1954
+ }
1955
+ }
1956
+ }
1957
+ #endif
1958
+ }
1959
+
1960
+ #define MICRO_EXTRA(MICRO_EXTRA_UNROLL, value, is_col) \
1961
+ switch (value) { \
1962
+ default: \
1963
+ MICRO_EXTRA_UNROLL(1) \
1964
+ break; \
1965
+ case 2: \
1966
+ if (is_col || (sizeof(Scalar) == sizeof(float))) { \
1967
+ MICRO_EXTRA_UNROLL(2) \
1968
+ } \
1969
+ break; \
1970
+ case 3: \
1971
+ if (is_col || (sizeof(Scalar) == sizeof(float))) { \
1972
+ MICRO_EXTRA_UNROLL(3) \
1973
+ } \
1974
+ break; \
1975
+ }
1976
+
1977
+ #define MICRO_EXTRA_ROWS(N) \
1978
+ gemm_unrolled_row_iteration<Scalar, Packet, DataMapper, accRows, accCols, N>( \
1979
+ res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, pAlpha, pMask);
1980
+
1981
+ template <typename Scalar, typename Packet, typename DataMapper, const Index accRows, const Index accCols>
1982
+ EIGEN_ALWAYS_INLINE void gemm_extra_row(const DataMapper& res, const Scalar* lhs_base, const Scalar* rhs_base,
1983
+ Index depth, Index strideA, Index offsetA, Index strideB, Index row, Index rows,
1984
+ Index remaining_rows, const Packet& pAlpha, const Packet& pMask) {
1985
+ MICRO_EXTRA(MICRO_EXTRA_ROWS, remaining_rows, false)
1986
+ }
1987
+
1988
+ #define MICRO_UNROLL_WORK(func, func2, peel) \
1989
+ MICRO_UNROLL(func2); \
1990
+ func(0, peel) func(1, peel) func(2, peel) func(3, peel) func(4, peel) func(5, peel) func(6, peel) func(7, peel)
1991
+
1992
+ #define MICRO_WORK_ONE(iter, peel) \
1993
+ if (unroll_factor > iter) { \
1994
+ pger_common<Packet, false, accRows>(&accZero##iter, lhsV##iter, rhsV##peel); \
1995
+ }
1996
+
1997
+ #define MICRO_TYPE_PEEL4(func, func2, peel) \
1998
+ if (PEEL > peel) { \
1999
+ Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
2000
+ MICRO_BROADCAST(peel) \
2001
+ MICRO_UNROLL_WORK(func, func2, peel) \
2002
+ } else { \
2003
+ EIGEN_UNUSED_VARIABLE(rhsV##peel); \
2004
+ }
2005
+
2006
+ #define MICRO_UNROLL_TYPE_PEEL(M, func, func1, func2) \
2007
+ Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M], rhsV4[M], rhsV5[M], rhsV6[M], rhsV7[M]; \
2008
+ func(func1, func2, 0) func(func1, func2, 1) func(func1, func2, 2) func(func1, func2, 3) func(func1, func2, 4) \
2009
+ func(func1, func2, 5) func(func1, func2, 6) func(func1, func2, 7)
2010
+
2011
+ #define MICRO_UNROLL_TYPE_ONE(M, func, func1, func2) \
2012
+ Packet rhsV0[M]; \
2013
+ func(func1, func2, 0)
2014
+
2015
+ #define MICRO_UNROLL_TYPE(MICRO_TYPE, size) \
2016
+ MICRO_TYPE(4, MICRO_TYPE_PEEL4, MICRO_WORK_ONE, MICRO_LOAD_ONE) \
2017
+ MICRO_ADD_ROWS(size)
2018
+
2019
+ #define MICRO_ONE_PEEL4 MICRO_UNROLL_TYPE(MICRO_UNROLL_TYPE_PEEL, PEEL)
2020
+
2021
+ #define MICRO_ONE4 MICRO_UNROLL_TYPE(MICRO_UNROLL_TYPE_ONE, 1)
2022
+
2023
+ #define MICRO_DST_PTR_ONE(iter) \
2024
+ if (unroll_factor > iter) { \
2025
+ bsetzero<Packet, accRows>(accZero##iter); \
2026
+ } else { \
2027
+ EIGEN_UNUSED_VARIABLE(accZero##iter); \
2028
+ }
2029
+
2030
+ #define MICRO_DST_PTR MICRO_UNROLL(MICRO_DST_PTR_ONE)
2031
+
2032
+ #define MICRO_SRC_PTR MICRO_UNROLL(MICRO_SRC_PTR_ONE)
2033
+
2034
+ #define MICRO_PREFETCH MICRO_UNROLL(MICRO_PREFETCH_ONE)
2035
+
2036
+ #ifdef USE_PARTIAL_PACKETS
2037
+ #define MICRO_STORE_ONE(iter) \
2038
+ if (unroll_factor > iter) { \
2039
+ if (MICRO_NORMAL_PARTIAL(iter)) { \
2040
+ bload<DataMapper, Packet, 0, ColMajor, false, accRows>(acc, res, row + iter * accCols, 0); \
2041
+ bscale<Packet, accRows>(acc, accZero##iter, pAlpha); \
2042
+ bstore<DataMapper, Packet, accRows>(acc, res, row + iter * accCols); \
2043
+ } else { \
2044
+ bload_partial<DataMapper, Packet, 0, false, accRows>(acc, res, row + iter * accCols, accCols2); \
2045
+ bscale<Packet, accRows>(acc, accZero##iter, pAlpha); \
2046
+ bstore_partial<DataMapper, Packet, accRows>(acc, res, row + iter * accCols, accCols2); \
2047
+ } \
2048
+ }
2049
+ #else
2050
+ #define MICRO_STORE_ONE(iter) \
2051
+ if (unroll_factor > iter) { \
2052
+ bload<DataMapper, Packet, 0, ColMajor, false, accRows>(acc, res, row + iter * accCols, 0); \
2053
+ bscale<Packet, accRows, !(MICRO_NORMAL(iter))>(acc, accZero##iter, pAlpha, pMask); \
2054
+ bstore<DataMapper, Packet, accRows>(acc, res, row + iter * accCols); \
2055
+ }
2056
+ #endif
2057
+
2058
+ #define MICRO_STORE MICRO_UNROLL(MICRO_STORE_ONE)
2059
+
2060
+ #ifdef USE_PARTIAL_PACKETS
2061
+ template <int unroll_factor, typename Scalar, typename Packet, typename DataMapper, const Index accRows,
2062
+ const Index accCols, bool full>
2063
+ #else
2064
+ template <int unroll_factor, typename Scalar, typename Packet, typename DataMapper, const Index accRows,
2065
+ const Index accCols, const Index accCols2>
2066
+ #endif
2067
+ EIGEN_ALWAYS_INLINE void gemm_unrolled_iteration(const DataMapper& res, const Scalar* lhs_base, const Scalar* rhs_base,
2068
+ Index depth, Index strideA, Index offsetA, Index strideB, Index& row,
2069
+ const Packet& pAlpha,
2070
+ #ifdef USE_PARTIAL_PACKETS
2071
+ Index accCols2
2072
+ #else
2073
+ const Packet& pMask
2074
+ #endif
2075
+ ) {
2076
+ const Scalar *rhs_ptr0 = rhs_base, *rhs_ptr1 = NULL, *rhs_ptr2 = NULL;
2077
+ const Scalar *lhs_ptr0 = NULL, *lhs_ptr1 = NULL, *lhs_ptr2 = NULL, *lhs_ptr3 = NULL, *lhs_ptr4 = NULL,
2078
+ *lhs_ptr5 = NULL, *lhs_ptr6 = NULL, *lhs_ptr7 = NULL;
2079
+ PacketBlock<Packet, accRows> accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;
2080
+ PacketBlock<Packet, accRows> acc;
2081
+
2082
+ MICRO_SRC2_PTR
2083
+ MICRO_SRC_PTR
2084
+ MICRO_DST_PTR
2085
+
2086
+ Index k = 0;
2087
+ for (; k + PEEL <= depth; k += PEEL) {
2088
+ MICRO_PREFETCHN(accRows)
2089
+ MICRO_PREFETCH
2090
+ MICRO_ONE_PEEL4
2091
+ }
2092
+ for (; k < depth; k++) {
2093
+ MICRO_ONE4
2094
+ }
2095
+ MICRO_STORE
2096
+
2097
+ MICRO_UPDATE
2098
+ }
2099
+
2100
+ #ifdef USE_PARTIAL_PACKETS
2101
+ #define MICRO_UNROLL_ITER2(N, M) \
2102
+ gemm_unrolled_iteration<N + ((M) ? 1 : 0), Scalar, Packet, DataMapper, accRows, accCols, !M>( \
2103
+ res3, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, pAlpha, M ? remaining_rows : accCols); \
2104
+ if (M) return;
2105
+ #else
2106
+ #define MICRO_UNROLL_ITER2(N, M) \
2107
+ gemm_unrolled_iteration<N + ((M) ? 1 : 0), Scalar, Packet, DataMapper, accRows, accCols, M ? M : accCols>( \
2108
+ res3, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, pAlpha, pMask); \
2109
+ if (M) return;
2110
+ #endif
2111
+
2112
+ template <typename Scalar, typename Packet, typename DataMapper, const Index accRows, const Index accCols>
2113
+ EIGEN_ALWAYS_INLINE void gemm_cols(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index depth,
2114
+ Index strideA, Index offsetA, Index strideB, Index offsetB, Index col, Index rows,
2115
+ Index remaining_rows, const Packet& pAlpha, const Packet& pMask) {
2116
+ const DataMapper res3 = res.getSubMapper(0, col);
2117
+
2118
+ const Scalar* rhs_base = blockB + col * strideB + MICRO_NEW_ROWS * offsetB;
2119
+ const Scalar* lhs_base = blockA + accCols * offsetA;
2120
+ Index row = 0;
2121
+
2122
+ #define MAX_UNROLL 7
2123
+ while (row + MAX_UNROLL * accCols <= rows) {
2124
+ MICRO_UNROLL_ITER2(MAX_UNROLL, 0);
2125
+ }
2126
+ switch ((rows - row) / accCols) {
2127
+ #if MAX_UNROLL > 7
2128
+ case 7:
2129
+ MICRO_UNROLL_ITER(MICRO_UNROLL_ITER2, 7)
2130
+ break;
2131
+ #endif
2132
+ #if MAX_UNROLL > 6
2133
+ case 6:
2134
+ MICRO_UNROLL_ITER(MICRO_UNROLL_ITER2, 6)
2135
+ break;
2136
+ #endif
2137
+ #if MAX_UNROLL > 5
2138
+ case 5:
2139
+ MICRO_UNROLL_ITER(MICRO_UNROLL_ITER2, 5)
2140
+ break;
2141
+ #endif
2142
+ #if MAX_UNROLL > 4
2143
+ case 4:
2144
+ MICRO_UNROLL_ITER(MICRO_UNROLL_ITER2, 4)
2145
+ break;
2146
+ #endif
2147
+ #if MAX_UNROLL > 3
2148
+ case 3:
2149
+ MICRO_UNROLL_ITER(MICRO_UNROLL_ITER2, 3)
2150
+ break;
2151
+ #endif
2152
+ #if MAX_UNROLL > 2
2153
+ case 2:
2154
+ MICRO_UNROLL_ITER(MICRO_UNROLL_ITER2, 2)
2155
+ break;
2156
+ #endif
2157
+ #if MAX_UNROLL > 1
2158
+ case 1:
2159
+ MICRO_UNROLL_ITER(MICRO_UNROLL_ITER2, 1)
2160
+ break;
2161
+ #endif
2162
+ default:
2163
+ break;
2164
+ }
2165
+ #undef MAX_UNROLL
2166
+
2167
+ if (remaining_rows > 0) {
2168
+ gemm_extra_row<Scalar, Packet, DataMapper, accRows, accCols>(res3, blockA, rhs_base, depth, strideA, offsetA,
2169
+ strideB, row, rows, remaining_rows, pAlpha, pMask);
2170
+ }
2171
+ }
2172
+
2173
+ #define MICRO_EXTRA_COLS(N) \
2174
+ gemm_cols<Scalar, Packet, DataMapper, N, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, \
2175
+ col, rows, remaining_rows, pAlpha, pMask);
2176
+
2177
+ template <typename Scalar, typename Packet, typename DataMapper, const Index accCols>
2178
+ EIGEN_ALWAYS_INLINE void gemm_extra_cols(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index depth,
2179
+ Index strideA, Index offsetA, Index strideB, Index offsetB, Index col,
2180
+ Index rows, Index cols, Index remaining_rows, const Packet& pAlpha,
2181
+ const Packet& pMask) {
2182
+ MICRO_EXTRA(MICRO_EXTRA_COLS, cols - col, true)
2183
+ }
2184
+
2185
+ /****************
2186
+ * GEMM kernels *
2187
+ * **************/
2188
+ template <typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows,
2189
+ const Index accCols>
2190
+ EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows,
2191
+ Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA,
2192
+ Index offsetB) {
2193
+ const Index remaining_rows = rows % accCols;
2194
+
2195
+ if (strideA == -1) strideA = depth;
2196
+ if (strideB == -1) strideB = depth;
2197
+
2198
+ const Packet pAlpha = pset1<Packet>(alpha);
2199
+ const Packet pMask = bmask<Packet>(remaining_rows);
2200
+
2201
+ Index col = 0;
2202
+ for (; col + accRows <= cols; col += accRows) {
2203
+ gemm_cols<Scalar, Packet, DataMapper, accRows, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB,
2204
+ offsetB, col, rows, remaining_rows, pAlpha, pMask);
2205
+ }
2206
+
2207
+ if (col != cols) {
2208
+ gemm_extra_cols<Scalar, Packet, DataMapper, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB,
2209
+ col, rows, cols, remaining_rows, pAlpha, pMask);
2210
+ }
2211
+ }
2212
+
2213
+ #define accColsC (accCols / 2)
2214
+ #define advanceRows ((LhsIsReal) ? 1 : 2)
2215
+ #define advanceCols ((RhsIsReal) ? 1 : 2)
2216
+
2217
+ // PEEL_COMPLEX loop factor.
2218
+ #define PEEL_COMPLEX 3
2219
+ #define PEEL_COMPLEX_ROW 3
2220
+
2221
+ #define MICRO_COMPLEX_UNROLL(func) func(0) func(1) func(2) func(3)
2222
+
2223
+ #define MICRO_COMPLEX_ZERO_PEEL(peel) \
2224
+ if ((PEEL_COMPLEX_ROW > peel) && (peel != 0)) { \
2225
+ bsetzero<Packet, accRows>(accReal##peel); \
2226
+ bsetzero<Packet, accRows>(accImag##peel); \
2227
+ } else { \
2228
+ EIGEN_UNUSED_VARIABLE(accReal##peel); \
2229
+ EIGEN_UNUSED_VARIABLE(accImag##peel); \
2230
+ }
2231
+
2232
+ #define MICRO_COMPLEX_ADD_ROWS(N, used) \
2233
+ MICRO_ADD(ptr_real, N) \
2234
+ if (!RhsIsReal) { \
2235
+ MICRO_ADD(ptr_imag, N) \
2236
+ } else if (used) { \
2237
+ EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag, 0)); \
2238
+ EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag, 1)); \
2239
+ EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag, 2)); \
2240
+ }
2241
+
2242
+ #define MICRO_COMPLEX_BROADCAST(peel) \
2243
+ MICRO_BROADCAST1(peel, ptr_real, rhsV, false) \
2244
+ if (!RhsIsReal) { \
2245
+ MICRO_BROADCAST1(peel, ptr_imag, rhsVi, false) \
2246
+ } else { \
2247
+ EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
2248
+ }
2249
+
2250
+ #define MICRO_COMPLEX_BROADCAST_EXTRA \
2251
+ Packet rhsV[4], rhsVi[4]; \
2252
+ MICRO_BROADCAST_EXTRA1(ptr_real, rhsV, false) \
2253
+ if (!RhsIsReal) { \
2254
+ MICRO_BROADCAST_EXTRA1(ptr_imag, rhsVi, false) \
2255
+ } else { \
2256
+ EIGEN_UNUSED_VARIABLE(rhsVi); \
2257
+ } \
2258
+ MICRO_COMPLEX_ADD_ROWS(1, true)
2259
+
2260
+ #define MICRO_COMPLEX_SRC2_PTR \
2261
+ MICRO_SRC2(ptr_real, strideB* advanceCols, 0) \
2262
+ if (!RhsIsReal) { \
2263
+ MICRO_RHS(ptr_imag, 0) = rhs_base + MICRO_NEW_ROWS * strideB; \
2264
+ MICRO_SRC2(ptr_imag, strideB* advanceCols, strideB) \
2265
+ } else { \
2266
+ EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag, 0)); \
2267
+ EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag, 1)); \
2268
+ EIGEN_UNUSED_VARIABLE(MICRO_RHS(ptr_imag, 2)); \
2269
+ }
2270
+
2271
+ #define MICRO_COMPLEX_ZERO_PEEL_ROW MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_ZERO_PEEL)
2272
+
2273
+ #define MICRO_COMPLEX_WORK_PEEL(peel) \
2274
+ if (PEEL_COMPLEX_ROW > peel) { \
2275
+ MICRO_COMPLEX_BROADCAST(peel) \
2276
+ pgerc<accRows, Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>( \
2277
+ &accReal##peel, &accImag##peel, lhs_ptr_real + (remaining_rows * peel), \
2278
+ lhs_ptr_imag + (remaining_rows * peel), rhsV##peel, rhsVi##peel); \
2279
+ } else { \
2280
+ EIGEN_UNUSED_VARIABLE(rhsV##peel); \
2281
+ EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
2282
+ }
2283
+
2284
+ #define MICRO_COMPLEX_ADD_COLS(size) \
2285
+ lhs_ptr_real += (remaining_rows * size); \
2286
+ if (!LhsIsReal) \
2287
+ lhs_ptr_imag += (remaining_rows * size); \
2288
+ else \
2289
+ EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
2290
+
2291
+ #define MICRO_COMPLEX_WORK_PEEL_ROW \
2292
+ Packet rhsV0[4], rhsV1[4], rhsV2[4], rhsV3[4]; \
2293
+ Packet rhsVi0[4], rhsVi1[4], rhsVi2[4], rhsVi3[4]; \
2294
+ MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_WORK_PEEL) \
2295
+ MICRO_COMPLEX_ADD_COLS(PEEL_COMPLEX_ROW) \
2296
+ MICRO_COMPLEX_ADD_ROWS(PEEL_COMPLEX_ROW, false)
2297
+
2298
+ #define MICRO_COMPLEX_ADD_PEEL(peel, sum) \
2299
+ if (PEEL_COMPLEX_ROW > peel) { \
2300
+ for (Index i = 0; i < accRows; i++) { \
2301
+ accReal##sum.packet[i] += accReal##peel.packet[i]; \
2302
+ accImag##sum.packet[i] += accImag##peel.packet[i]; \
2303
+ } \
2304
+ }
2305
+
2306
+ #define MICRO_COMPLEX_ADD_PEEL_ROW \
2307
+ MICRO_COMPLEX_ADD_PEEL(2, 0) MICRO_COMPLEX_ADD_PEEL(3, 1) MICRO_COMPLEX_ADD_PEEL(1, 0)
2308
+
2309
+ template <typename Scalar, typename Packet, const Index accRows, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal,
2310
+ bool RhsIsReal, const Index remaining_rows>
2311
+ EIGEN_ALWAYS_INLINE void MICRO_COMPLEX_EXTRA_ROW(const Scalar*& lhs_ptr_real, const Scalar*& lhs_ptr_imag,
2312
+ const Scalar*& rhs_ptr_real0, const Scalar*& rhs_ptr_real1,
2313
+ const Scalar*& rhs_ptr_real2, const Scalar*& rhs_ptr_imag0,
2314
+ const Scalar*& rhs_ptr_imag1, const Scalar*& rhs_ptr_imag2,
2315
+ PacketBlock<Packet, accRows>& accReal,
2316
+ PacketBlock<Packet, accRows>& accImag) {
2317
+ MICRO_COMPLEX_BROADCAST_EXTRA
2318
+ pgerc<accRows, Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real,
2319
+ lhs_ptr_imag, rhsV, rhsVi);
2320
+ MICRO_COMPLEX_ADD_COLS(1)
2321
+ }
2322
+
2323
+ template <typename Scalar, typename Packet, typename Packetc, typename DataMapper, const Index accRows,
2324
+ const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal,
2325
+ const Index remaining_rows>
2326
+ EIGEN_ALWAYS_INLINE void gemm_unrolled_complex_row_iteration(const DataMapper& res, const Scalar* lhs_base,
2327
+ const Scalar* rhs_base, Index depth, Index strideA,
2328
+ Index offsetA, Index strideB, Index row, Index rows,
2329
+ const Packet& pAlphaReal, const Packet& pAlphaImag,
2330
+ const Packet& pMask) {
2331
+ const Scalar *rhs_ptr_real0 = rhs_base, *rhs_ptr_real1 = NULL, *rhs_ptr_real2 = NULL;
2332
+ const Scalar *rhs_ptr_imag0 = NULL, *rhs_ptr_imag1 = NULL, *rhs_ptr_imag2 = NULL;
2333
+ const Scalar* lhs_ptr_real = lhs_base + advanceRows * row * strideA + remaining_rows * offsetA;
2334
+ const Scalar* lhs_ptr_imag = NULL;
2335
+ if (!LhsIsReal)
2336
+ lhs_ptr_imag = lhs_ptr_real + remaining_rows * strideA;
2337
+ else
2338
+ EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
2339
+ PacketBlock<Packet, accRows> accReal0, accImag0, accReal1, accImag1, accReal2, accImag2, accReal3, accImag3;
2340
+ PacketBlock<Packet, accRows> taccReal, taccImag;
2341
+ PacketBlock<Packetc, accRows> acc0, acc1;
2342
+ PacketBlock<Packetc, accRows * 2> tRes;
2343
+
2344
+ MICRO_COMPLEX_SRC2_PTR
2345
+
2346
+ bsetzero<Packet, accRows>(accReal0);
2347
+ bsetzero<Packet, accRows>(accImag0);
2348
+
2349
+ Index remaining_depth = depth & -quad_traits<Scalar>::rows;
2350
+ Index k = 0;
2351
+ if (remaining_depth >= PEEL_COMPLEX_ROW) {
2352
+ MICRO_COMPLEX_ZERO_PEEL_ROW
2353
+ do {
2354
+ MICRO_COMPLEX_PREFETCHN(accRows)
2355
+ EIGEN_POWER_PREFETCH(lhs_ptr_real);
2356
+ if (!LhsIsReal) {
2357
+ EIGEN_POWER_PREFETCH(lhs_ptr_imag);
2358
+ }
2359
+ MICRO_COMPLEX_WORK_PEEL_ROW
2360
+ } while ((k += PEEL_COMPLEX_ROW) + PEEL_COMPLEX_ROW <= remaining_depth);
2361
+ MICRO_COMPLEX_ADD_PEEL_ROW
2362
+ }
2363
+ for (; k < depth; k++) {
2364
+ MICRO_COMPLEX_EXTRA_ROW<Scalar, Packet, accRows, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal, remaining_rows>(
2365
+ lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real0, rhs_ptr_real1, rhs_ptr_real2, rhs_ptr_imag0, rhs_ptr_imag1,
2366
+ rhs_ptr_imag2, accReal0, accImag0);
2367
+ }
2368
+
2369
+ constexpr bool full = (remaining_rows > accColsC);
2370
+ bload<DataMapper, Packetc, accColsC, ColMajor, true, accRows, full>(tRes, res, row, 0);
2371
+ if ((accRows == 1) || (rows >= accCols)) {
2372
+ bscalec<Packet, accRows, true>(accReal0, accImag0, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask);
2373
+ bcouple<Packet, Packetc, accRows, full>(taccReal, taccImag, tRes, acc0, acc1);
2374
+ bstore<DataMapper, Packetc, accRows>(acc0, res, row + 0);
2375
+ if (full) {
2376
+ bstore<DataMapper, Packetc, accRows>(acc1, res, row + accColsC);
2377
+ }
2378
+ } else {
2379
+ bscalec<Packet, accRows, false>(accReal0, accImag0, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask);
2380
+ bcouple<Packet, Packetc, accRows, full>(taccReal, taccImag, tRes, acc0, acc1);
2381
+
2382
+ if ((sizeof(Scalar) == sizeof(float)) && (remaining_rows == 1)) {
2383
+ for (Index j = 0; j < accRows; j++) {
2384
+ res(row + 0, j) = pfirst<Packetc>(acc0.packet[j]);
2385
+ }
2386
+ } else {
2387
+ bstore<DataMapper, Packetc, accRows>(acc0, res, row + 0);
2388
+ if (full) {
2389
+ for (Index j = 0; j < accRows; j++) {
2390
+ res(row + accColsC, j) = pfirst<Packetc>(acc1.packet[j]);
2391
+ }
2392
+ }
2393
+ }
2394
+ }
2395
+ }
2396
+
2397
// Expansion used by MICRO_EXTRA: runs the left-over-row kernel with the row
// remainder N baked in as a compile-time template argument.
#define MICRO_COMPLEX_EXTRA_ROWS(N)                                                                        \
  gemm_unrolled_complex_row_iteration<Scalar, Packet, Packetc, DataMapper, accRows, accCols, ConjugateLhs, \
                                      ConjugateRhs, LhsIsReal, RhsIsReal, N>(                              \
      res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, pAlphaReal, pAlphaImag, pMask);
2402
// Handles the final (rows % accCols) rows of a complex GEMM panel.
// MICRO_EXTRA dispatches on the runtime value of remaining_rows so that the
// actual kernel call sees it as a compile-time constant (see
// MICRO_COMPLEX_EXTRA_ROWS above).
template <typename Scalar, typename Packet, typename Packetc, typename DataMapper, const Index accRows,
          const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void gemm_complex_extra_row(const DataMapper& res, const Scalar* lhs_base, const Scalar* rhs_base,
                                                Index depth, Index strideA, Index offsetA, Index strideB, Index row,
                                                Index rows, Index remaining_rows, const Packet& pAlphaReal,
                                                const Packet& pAlphaImag, const Packet& pMask) {
  MICRO_EXTRA(MICRO_COMPLEX_EXTRA_ROWS, remaining_rows, false)
}
2410
+
2411
// Loads the LHS packets (func2) then applies the multiply-accumulate work
// function to each of the (up to 4) unrolled LHS row groups.
#define MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \
  MICRO_COMPLEX_UNROLL(func2);                       \
  func(0, peel) func(1, peel) func(2, peel) func(3, peel)

// One complex rank-1 update (real/imag accumulators) for unrolled LHS group
// 'iter' against the RHS packets of peel step 'peel'. Guarded so unused
// iterations compile away.
#define MICRO_COMPLEX_WORK_ONE4(iter, peel)                                                \
  if (unroll_factor > iter) {                                                              \
    pgerc_common<accRows, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(       \
        &accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \
  }

// One peel step: broadcast the RHS values for step 'peel', then do the work.
// Steps beyond the PEEL_COMPLEX depth only mark their RHS arrays as unused.
#define MICRO_COMPLEX_TYPE_PEEL4(func, func2, peel)  \
  if (PEEL_COMPLEX > peel) {                         \
    Packet lhsV0, lhsV1, lhsV2, lhsV3;               \
    Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3;           \
    MICRO_COMPLEX_BROADCAST(peel)                    \
    MICRO_COMPLEX_UNROLL_WORK(func, func2, peel)     \
  } else {                                           \
    EIGEN_UNUSED_VARIABLE(rhsV##peel);               \
    EIGEN_UNUSED_VARIABLE(rhsVi##peel);              \
  }

// Declares RHS packet storage for 4 peel steps and expands each step.
#define MICRO_COMPLEX_UNROLL_TYPE_PEEL(M, func, func1, func2) \
  Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M];              \
  Packet rhsVi0[M], rhsVi1[M], rhsVi2[M], rhsVi3[M];          \
  func(func1, func2, 0) func(func1, func2, 1) func(func1, func2, 2) func(func1, func2, 3)

// Single-step (non-peeled) variant of the above.
#define MICRO_COMPLEX_UNROLL_TYPE_ONE(M, func, func1, func2) \
  Packet rhsV0[M], rhsVi0[M];                                \
  func(func1, func2, 0)

// Runs one unroll pattern and then advances the RHS pointers by 'size' steps.
#define MICRO_COMPLEX_UNROLL_TYPE(MICRO_COMPLEX_TYPE, size)                                             \
  MICRO_COMPLEX_TYPE(4, MICRO_COMPLEX_TYPE_PEEL4, MICRO_COMPLEX_WORK_ONE4, MICRO_COMPLEX_LOAD_ONE)      \
  MICRO_COMPLEX_ADD_ROWS(size, false)

// Fully peeled inner-loop body (PEEL_COMPLEX depth steps per invocation).
#define MICRO_COMPLEX_ONE_PEEL4 MICRO_COMPLEX_UNROLL_TYPE(MICRO_COMPLEX_UNROLL_TYPE_PEEL, PEEL_COMPLEX)

// Single depth-step body for the loop remainder.
#define MICRO_COMPLEX_ONE4 MICRO_COMPLEX_UNROLL_TYPE(MICRO_COMPLEX_UNROLL_TYPE_ONE, 1)
2448
+
2449
// Zeroes the real/imag accumulators for unrolled group 'iter'; groups beyond
// unroll_factor are only marked unused so the compiler stays quiet.
#define MICRO_COMPLEX_DST_PTR_ONE(iter)       \
  if (unroll_factor > iter) {                 \
    bsetzero<Packet, accRows>(accReal##iter); \
    bsetzero<Packet, accRows>(accImag##iter); \
  } else {                                    \
    EIGEN_UNUSED_VARIABLE(accReal##iter);     \
    EIGEN_UNUSED_VARIABLE(accImag##iter);     \
  }

// Per-group expansions over all unrolled LHS groups.
#define MICRO_COMPLEX_DST_PTR MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_DST_PTR_ONE)

#define MICRO_COMPLEX_SRC_PTR MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_SRC_PTR_ONE)

#define MICRO_COMPLEX_PREFETCH MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_PREFETCH_ONE)

// Scales accumulator group 'iter' by alpha (real/imag), couples real+imag
// into complex packets, and stores back into the result mapper. 'full'
// selects whether the second half (acc1) is stored as well.
#define MICRO_COMPLEX_STORE_ONE(iter)                                                                                \
  if (unroll_factor > iter) {                                                                                        \
    constexpr bool full = ((MICRO_NORMAL(iter)) || (accCols2 > accColsC));                                           \
    bload<DataMapper, Packetc, accColsC, ColMajor, true, accRows, full>(tRes, res, row + iter * accCols, 0);         \
    bscalec<Packet, accRows, !(MICRO_NORMAL(iter))>(accReal##iter, accImag##iter, pAlphaReal, pAlphaImag, taccReal,  \
                                                    taccImag, pMask);                                                \
    bcouple<Packet, Packetc, accRows, full>(taccReal, taccImag, tRes, acc0, acc1);                                   \
    bstore<DataMapper, Packetc, accRows>(acc0, res, row + iter * accCols + 0);                                       \
    if (full) {                                                                                                      \
      bstore<DataMapper, Packetc, accRows>(acc1, res, row + iter * accCols + accColsC);                              \
    }                                                                                                                \
  }

#define MICRO_COMPLEX_STORE MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_STORE_ONE)
2478
+
2479
// Core complex GEMM micro-kernel: computes an (unroll_factor * accCols) x
// accRows block of the result. The depth loop is peeled by PEEL_COMPLEX with
// a scalar-step remainder loop; accumulators are scaled by alpha and stored
// at the end. 'row' is advanced by MICRO_COMPLEX_UPDATE (defined elsewhere).
template <int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename DataMapper,
          const Index accRows, const Index accCols, const Index accCols2, bool ConjugateLhs, bool ConjugateRhs,
          bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_iteration(const DataMapper& res, const Scalar* lhs_base,
                                                         const Scalar* rhs_base, Index depth, Index strideA,
                                                         Index offsetA, Index strideB, Index& row,
                                                         const Packet& pAlphaReal, const Packet& pAlphaImag,
                                                         const Packet& pMask) {
  // RHS pointers (real parts packed first; imaginary parts offset by stride).
  const Scalar *rhs_ptr_real0 = rhs_base, *rhs_ptr_real1 = NULL, *rhs_ptr_real2 = NULL;
  const Scalar *rhs_ptr_imag0 = NULL, *rhs_ptr_imag1 = NULL, *rhs_ptr_imag2 = NULL;
  // Distance from a real element to its imaginary counterpart in the packed LHS.
  const Index imag_delta = accCols * strideA;
  const Index imag_delta2 = accCols2 * strideA;
  const Scalar *lhs_ptr_real0 = NULL, *lhs_ptr_real1 = NULL;
  const Scalar *lhs_ptr_real2 = NULL, *lhs_ptr_real3 = NULL;
  PacketBlock<Packet, accRows> accReal0, accImag0, accReal1, accImag1;
  PacketBlock<Packet, accRows> accReal2, accImag2, accReal3, accImag3;
  PacketBlock<Packet, accRows> taccReal, taccImag;  // alpha-scaled temporaries
  PacketBlock<Packetc, accRows> acc0, acc1;         // coupled complex results
  PacketBlock<Packetc, accRows * 2> tRes;           // previous result values

  MICRO_COMPLEX_SRC2_PTR
  MICRO_COMPLEX_SRC_PTR
  MICRO_COMPLEX_DST_PTR

  Index k = 0;
  // Peeled main loop over the shared (depth) dimension.
  for (; k + PEEL_COMPLEX <= depth; k += PEEL_COMPLEX) {
    MICRO_COMPLEX_PREFETCHN(accRows)
    MICRO_COMPLEX_PREFETCH
    MICRO_COMPLEX_ONE_PEEL4
  }
  // Remainder: one depth step at a time.
  for (; k < depth; k++) {
    MICRO_COMPLEX_ONE4
  }
  MICRO_COMPLEX_STORE

  MICRO_COMPLEX_UPDATE
}
2516
+
2517
// Invokes the unrolled iteration. When M != 0 this is a masked (partial-row)
// tail call with N+1 groups and M active columns, and the caller returns
// afterwards; M == 0 is the full-width steady-state call.
#define MICRO_COMPLEX_UNROLL_ITER2(N, M)                                                                  \
  gemm_complex_unrolled_iteration<N + (M ? 1 : 0), Scalar, Packet, Packetc, DataMapper, accRows, accCols, \
                                  M ? M : accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(     \
      res3, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, pAlphaReal, pAlphaImag, pMask);    \
  if (M) return;
2522
+
2523
// Processes one accRows-wide column panel of the complex GEMM: walks the rows
// in blocks of MAX_COMPLEX_UNROLL * accCols, finishes partial unrolls via the
// switch, then delegates any sub-accCols row remainder to
// gemm_complex_extra_row.
template <typename Scalar, typename Packet, typename Packetc, typename DataMapper, const Index accRows,
          const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void gemm_complex_cols(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
                                           Index depth, Index strideA, Index offsetA, Index strideB, Index offsetB,
                                           Index col, Index rows, Index remaining_rows, const Packet& pAlphaReal,
                                           const Packet& pAlphaImag, const Packet& pMask) {
  const DataMapper res3 = res.getSubMapper(0, col);

  const Scalar* rhs_base = blockB + advanceCols * col * strideB + MICRO_NEW_ROWS * offsetB;
  const Scalar* lhs_base = blockA + accCols * offsetA;
  Index row = 0;

#define MAX_COMPLEX_UNROLL 4
  // Full-width unrolled iterations.
  while (row + MAX_COMPLEX_UNROLL * accCols <= rows) {
    MICRO_COMPLEX_UNROLL_ITER2(MAX_COMPLEX_UNROLL, 0);
  }
  // Remaining whole accCols groups (fewer than MAX_COMPLEX_UNROLL of them).
  switch ((rows - row) / accCols) {
#if MAX_COMPLEX_UNROLL > 4
    case 4:
      MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_UNROLL_ITER2, 4)
      break;
#endif
#if MAX_COMPLEX_UNROLL > 3
    case 3:
      MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_UNROLL_ITER2, 3)
      break;
#endif
#if MAX_COMPLEX_UNROLL > 2
    case 2:
      MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_UNROLL_ITER2, 2)
      break;
#endif
#if MAX_COMPLEX_UNROLL > 1
    case 1:
      MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_UNROLL_ITER2, 1)
      break;
#endif
    default:
      break;
  }
#undef MAX_COMPLEX_UNROLL

  // Left-over rows (rows % accCols).
  if (remaining_rows > 0) {
    gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal,
                           RhsIsReal>(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, rows,
                                      remaining_rows, pAlphaReal, pAlphaImag, pMask);
  }
}
2571
+
2572
// Expansion used by MICRO_EXTRA: runs a column panel with the column
// remainder N as the compile-time accRows.
#define MICRO_COMPLEX_EXTRA_COLS(N)                                                                       \
  gemm_complex_cols<Scalar, Packet, Packetc, DataMapper, N, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, \
                    RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows,   \
                               remaining_rows, pAlphaReal, pAlphaImag, pMask);

// Handles the final (cols % accRows) columns of the complex GEMM by
// dispatching on (cols - col) to a compile-time panel width.
template <typename Scalar, typename Packet, typename Packetc, typename DataMapper, const Index accCols,
          bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void gemm_complex_extra_cols(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
                                                 Index depth, Index strideA, Index offsetA, Index strideB,
                                                 Index offsetB, Index col, Index rows, Index cols, Index remaining_rows,
                                                 const Packet& pAlphaReal, const Packet& pAlphaImag,
                                                 const Packet& pMask) {
  MICRO_EXTRA(MICRO_COMPLEX_EXTRA_COLS, cols - col, true)
}
2586
+
2587
// Top-level driver for the complex GEMM on packed operands: res += alpha * A * B.
// Iterates over full accRows-wide column panels, then handles the column
// remainder. strideA/strideB of -1 mean "use depth" (contiguous packing).
template <typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Packet, typename Packetc,
          typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs,
          bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc,
                                      Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB,
                                      Index offsetA, Index offsetB) {
  const Index remaining_rows = rows % accCols;

  if (strideA == -1) strideA = depth;
  if (strideB == -1) strideB = depth;

  // Broadcast alpha's real/imag parts; pMask masks off lanes past the row remainder.
  const Packet pAlphaReal = pset1<Packet>(alpha.real());
  const Packet pAlphaImag = pset1<Packet>(alpha.imag());
  const Packet pMask = bmask<Packet>(remaining_rows);

  // The packed complex blocks are reinterpreted as arrays of the real scalar type.
  const Scalar* blockA = (Scalar*)blockAc;
  const Scalar* blockB = (Scalar*)blockBc;

  Index col = 0;
  for (; col + accRows <= cols; col += accRows) {
    gemm_complex_cols<Scalar, Packet, Packetc, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal,
                      RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows,
                                 remaining_rows, pAlphaReal, pAlphaImag, pMask);
  }

  // Column remainder (cols % accRows).
  if (col != cols) {
    gemm_complex_extra_cols<Scalar, Packet, Packetc, DataMapper, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal,
                            RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols,
                                       remaining_rows, pAlphaReal, pAlphaImag, pMask);
  }
}
2618
+
2619
// End of the complex GEMM section: release its helper macros.
#undef accColsC
#undef advanceCols
#undef advanceRows

// Returns whether Power10 MMA (matrix-multiply assist) instructions may be
// used: unconditionally under EIGEN_ALTIVEC_MMA_ONLY, via runtime CPU
// detection when dynamic dispatch is compiled in, otherwise never.
EIGEN_ALWAYS_INLINE bool supportsMMA() {
#if defined(EIGEN_ALTIVEC_MMA_ONLY)
  return true;
#elif defined(EIGEN_ALTIVEC_MMA_DYNAMIC_DISPATCH) && defined(__BUILTIN_CPU_SUPPORTS__)
  return __builtin_cpu_supports("arch_3_1") && __builtin_cpu_supports("mma");
#else
  return false;  // No dynamic dispatch for LLVM or older GCC
#endif
}
2632
+
2633
// Returns acc * pAlpha + *result (fused multiply-add with the existing
// result values loaded from memory).
EIGEN_ALWAYS_INLINE Packet4f loadAndMultiplyF32(Packet4f acc, const Packet4f pAlpha, float* result) {
  Packet4f result_block = ploadu<Packet4f>(result);
  return pmadd(acc, pAlpha, result_block);
}

// Stores one float4 result column and advances the pointer by one column
// (rows floats). With lhsExtraRows only extra_rows lanes are written.
template <bool lhsExtraRows>
EIGEN_ALWAYS_INLINE void storeF32(float*& result, Packet4f result_block, Index rows, Index extra_rows) {
  if (lhsExtraRows) {
    pstoreu_partial(result, result_block, extra_rows);
  } else {
    pstoreu(result, result_block);
  }
  result += rows;
}
2647
+
2648
// Scales 4 accumulator columns by alpha, adds the existing result values and
// stores them back. With rhsExtraCols only extra_cols columns are written.
// In the full-width path all loads are performed before any store (via
// result_block[] and the saved result2 pointer) to keep loads and stores apart.
template <bool rhsExtraCols, bool lhsExtraRows>
EIGEN_ALWAYS_INLINE void storeResults(Packet4f (&acc)[4], Index rows, const Packet4f pAlpha, float* result,
                                      Index extra_cols, Index extra_rows) {
  Index x = 0;
  if (rhsExtraCols) {
    do {
      Packet4f result_block = loadAndMultiplyF32(acc[x], pAlpha, result);
      storeF32<lhsExtraRows>(result, result_block, rows, extra_rows);
    } while (++x < extra_cols);
  } else {
    Packet4f result_block[4];
    float* result2 = result;
    do {
      result_block[x] = loadAndMultiplyF32(acc[x], pAlpha, result);
      result += rows;
    } while (++x < 4);
    x = 0;
    do {
      storeF32<lhsExtraRows>(result2, result_block[x], rows, extra_rows);
    } while (++x < 4);
  }
}
2670
+
2671
// Widens the high 4 bfloat16 lanes to float by interleaving with zeros:
// a bfloat16 is the upper 16 bits of the equivalent float, so placing the
// payload in the high half-word of each 32-bit lane yields the float value.
// The merge-operand order flips with endianness.
EIGEN_ALWAYS_INLINE Packet4f oneConvertBF16Hi(Packet8us data) {
  Packet8us z = pset1<Packet8us>(0);
#ifdef _BIG_ENDIAN
  return reinterpret_cast<Packet4f>(vec_mergeh(data, z));
#else
  return reinterpret_cast<Packet4f>(vec_mergeh(z, data));
#endif
}

// Same as oneConvertBF16Hi but for the low 4 bfloat16 lanes.
EIGEN_ALWAYS_INLINE Packet4f oneConvertBF16Lo(Packet8us data) {
  Packet8us z = pset1<Packet8us>(0);
#ifdef _BIG_ENDIAN
  return reinterpret_cast<Packet4f>(vec_mergel(data, z));
#else
  return reinterpret_cast<Packet4f>(vec_mergel(z, data));
#endif
}
2688
+
2689
// Converts and stores up to 8 floats from bfloat16 packet M of 'block'.
// For N < 4 a partial store of 'extra' lanes is performed; otherwise the
// high (and, when N >= 8, low) halves of the packet are stored in full.
template <Index N, Index M>
EIGEN_ALWAYS_INLINE void storeConvertTwoBF16(float* to, PacketBlock<Packet8bf, (N + 7) / 8>& block, Index extra = 0) {
  if (N < 4) {
    pstoreu_partial(to + 0, oneConvertBF16Hi(block.packet[0].m_val), extra);
  } else if (N >= (M * 8 + 4)) {
    pstoreu(to + 0, oneConvertBF16Hi(block.packet[M].m_val));
    if (N >= 8) {
      pstoreu(to + 4, oneConvertBF16Lo(block.packet[M].m_val));
    }
  }
}

// Converts a block of N bfloat16 values to float and stores them at 'to',
// 8 values (one Packet8bf) at a time.
template <Index N>
EIGEN_ALWAYS_INLINE void storeConvertBlockBF16(float* to, PacketBlock<Packet8bf, (N + 7) / 8>& block, Index extra) {
  storeConvertTwoBF16<N, 0>(to + 0, block, extra);
  if (N >= 16) {
    storeConvertTwoBF16<N, 1>(to + 8, block);
  }
  if (N >= 32) {
    storeConvertTwoBF16<N, 2>(to + 16, block);
    storeConvertTwoBF16<N, 3>(to + 24, block);
  }
}
2712
+
2713
// Loads 8 bfloat16 values starting at element 'delta' of 'src', using a
// strided gather when the result storage has a non-unit increment.
template <bool non_unit_stride, Index delta>
EIGEN_ALWAYS_INLINE Packet8bf loadBF16fromResult(bfloat16* src, Index resInc) {
  if (non_unit_stride) {
    return pgather<bfloat16, Packet8bf>(src + delta * resInc, resInc);
  } else {
    return ploadu<Packet8bf>(src + delta);
  }
}
2721
+
2722
// Byte-permute masks that interleave bfloat16 half-words with zeros so each
// 16-bit value lands in the high half of a 32-bit lane (bf16 -> f32 widening).
// _1.._4 select consecutive half-word pairs; _5.._8 duplicate a single
// half-word per lane for the odd-column (broadcast) path.
static Packet16uc p16uc_MERGE16_32_1 = {0, 1, 16, 17, 2, 3, 18, 19, 0, 1, 16, 17, 2, 3, 18, 19};
static Packet16uc p16uc_MERGE16_32_2 = {4, 5, 20, 21, 6, 7, 22, 23, 4, 5, 20, 21, 6, 7, 22, 23};
static Packet16uc p16uc_MERGE16_32_3 = {8, 9, 24, 25, 10, 11, 26, 27, 8, 9, 24, 25, 10, 11, 26, 27};
static Packet16uc p16uc_MERGE16_32_4 = {12, 13, 28, 29, 14, 15, 30, 31, 12, 13, 28, 29, 14, 15, 30, 31};

static Packet16uc p16uc_MERGE16_32_5 = {0, 1, 16, 17, 16, 17, 16, 17, 0, 1, 16, 17, 16, 17, 16, 17};
static Packet16uc p16uc_MERGE16_32_6 = {2, 3, 18, 19, 18, 19, 18, 19, 2, 3, 18, 19, 18, 19, 18, 19};
static Packet16uc p16uc_MERGE16_32_7 = {4, 5, 20, 21, 20, 21, 20, 21, 4, 5, 20, 21, 20, 21, 20, 21};
static Packet16uc p16uc_MERGE16_32_8 = {6, 7, 22, 23, 22, 23, 22, 23, 6, 7, 22, 23, 22, 23, 22, 23};

// Applies one of the masks above, merging 'data' with zeros; operand order
// flips with endianness (the zero vector supplies the low half-words).
EIGEN_ALWAYS_INLINE Packet4f oneConvertBF16Perm(Packet8us data, Packet16uc mask) {
  Packet8us z = pset1<Packet8us>(0);
#ifdef _BIG_ENDIAN
  return reinterpret_cast<Packet4f>(vec_perm(data, z, mask));
#else
  return reinterpret_cast<Packet4f>(vec_perm(z, data, mask));
#endif
}
2740
+
2741
// Converts 'size' bfloat16 packets to float, expanding each packet into 4
// float4 groups via the merge masks ('odd' selects the duplicate-lane masks
// for an unpaired trailing column). With lhsExtraRows, lanes beyond
// extra_rows are zero-filled so later accumulation sees clean padding.
template <bool lhsExtraRows, bool odd, Index size>
EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32DupOne(float* result, Index rows, const bfloat16* src,
                                                            Index extra_rows) {
  Packet4f dup[4 * 4];
  Packet8bf data[4];

  for (Index i = 0; i < size; i++) {
    data[i] = ploadu<Packet8bf>(src + rows * i);
  }

  // Expand each loaded packet into 4 widened float4 groups.
  for (Index i = 0, j = 0; i < size; i++, j += 4) {
    dup[j + 0] = oneConvertBF16Perm(data[i].m_val, odd ? p16uc_MERGE16_32_5 : p16uc_MERGE16_32_1);
    dup[j + 1] = oneConvertBF16Perm(data[i].m_val, odd ? p16uc_MERGE16_32_6 : p16uc_MERGE16_32_2);
    dup[j + 2] = oneConvertBF16Perm(data[i].m_val, odd ? p16uc_MERGE16_32_7 : p16uc_MERGE16_32_3);
    dup[j + 3] = oneConvertBF16Perm(data[i].m_val, odd ? p16uc_MERGE16_32_8 : p16uc_MERGE16_32_4);
  }

  for (Index j = 0; j < 4 * size; j += 4) {
    if (lhsExtraRows) {
      // Write the valid rows, then pad the remainder of the 4-row group with zeros.
      Packet4f z = pset1<Packet4f>(float(0));
      Index i = 0;
      do {
        pstoreu(result + (j + i) * 4, dup[j + i]);
      } while (++i < extra_rows);
      do {
        pstoreu(result + (j + i) * 4, z);
      } while (++i < 4);
    } else {
      for (Index i = 0; i < 4; i++) {
        pstoreu(result + (j + i) * 4, dup[j + i]);
      }
    }
  }
}
2775
+
2776
// Converts a bfloat16 LHS slice to float with duplication, walking the
// columns in groups of 8 (4 packets), then 2 (1 packet), then handling an
// odd final column with the duplicate-lane path (reading from src - delta).
template <bool lhsExtraRows>
EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32Dup(float* result, Index cols, Index rows, const bfloat16* src,
                                                         Index delta, Index extra_rows) {
  Index col = 0;
  src += delta * 2;
  for (; col + 4 * 2 <= cols; col += 4 * 2, result += 4 * 4 * 4, src += 4 * rows) {
    convertArrayPointerBF16toF32DupOne<lhsExtraRows, false, 4>(result, rows, src, extra_rows);
  }
  for (; col + 2 <= cols; col += 2, result += 4 * 4, src += rows) {
    convertArrayPointerBF16toF32DupOne<lhsExtraRows, false, 1>(result, rows, src, extra_rows);
  }
  if (cols & 1) {
    // Odd trailing column: re-read from the unshifted position.
    convertArrayPointerBF16toF32DupOne<lhsExtraRows, true, 1>(result, rows, src - delta, extra_rows);
  }
}
2791
+
2792
// Converts 'size' bfloat16 values at a time from 'src' into float at
// result[i], advancing i and src. Loops only for the largest size (32);
// smaller sizes perform at most one step so the caller's cascade of calls
// (32, 16, 8, 4, 1) covers any remainder exactly once each.
template <const Index size, bool non_unit_stride>
EIGEN_ALWAYS_INLINE void convertPointerBF16toF32(Index& i, float* result, Index rows, bfloat16*& src, Index resInc) {
  constexpr Index extra = ((size < 4) ? 4 : size);
  while (i + size <= rows) {
    PacketBlock<Packet8bf, (size + 7) / 8> r32;
    r32.packet[0] = loadBF16fromResult<non_unit_stride, 0>(src, resInc);
    if (size >= 16) {
      r32.packet[1] = loadBF16fromResult<non_unit_stride, 8>(src, resInc);
    }
    if (size >= 32) {
      r32.packet[2] = loadBF16fromResult<non_unit_stride, 16>(src, resInc);
      r32.packet[3] = loadBF16fromResult<non_unit_stride, 24>(src, resInc);
    }
    storeConvertBlockBF16<size>(result + i, r32, rows & 3);
    i += extra;
    src += extra * resInc;
    if (size != 32) break;
  }
}
2811
+
2812
// Converts a cols x rows bfloat16 array (column-major, element stride resInc)
// into a dense float array, column by column. The 32/16/8/4/1 cascade
// consumes each column's rows in decreasing chunk sizes.
template <bool non_unit_stride>
EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32(float* result, Index cols, Index rows, bfloat16* src,
                                                      Index resInc) {
  for (Index col = 0; col < cols; col++, src += (rows * resInc), result += rows) {
    Index i = 0;
    bfloat16* src2 = src;
    convertPointerBF16toF32<32, non_unit_stride>(i, result, rows, src2, resInc);
    convertPointerBF16toF32<16, non_unit_stride>(i, result, rows, src2, resInc);
    convertPointerBF16toF32<8, non_unit_stride>(i, result, rows, src2, resInc);
    convertPointerBF16toF32<4, non_unit_stride>(i, result, rows, src2, resInc);
    convertPointerBF16toF32<1, non_unit_stride>(i, result, rows, src2, resInc);
  }
}
2825
+
2826
// Sets every packet in the num_acc x size accumulator array to zero.
template <Index num_acc, Index size = 4>
EIGEN_ALWAYS_INLINE void zeroAccumulators(Packet4f (&acc)[num_acc][size]) {
  Packet4f z = pset1<Packet4f>(float(0));

  for (Index k = 0; k < num_acc; k++) {
    for (Index j = 0; j < size; j++) {
      acc[k][j] = z;
    }
  }
}
2836
+
2837
// 4x4 in-register transpose of each accumulator group using merge-high/low
// on the lane bit patterns (reinterpreted as integers to avoid FP ops).
// NOTE(review): "tranpose" is a typo for "transpose"; the name is kept
// because callers in this file use it.
template <Index num_acc>
EIGEN_ALWAYS_INLINE void tranposeResults(Packet4f (&acc)[num_acc][4]) {
  for (Index i = 0; i < num_acc; i++) {
    Packet4ui t0, t1, t2, t3;
    t0 = vec_mergeh(reinterpret_cast<Packet4ui>(acc[i][0]), reinterpret_cast<Packet4ui>(acc[i][2]));
    t1 = vec_mergel(reinterpret_cast<Packet4ui>(acc[i][0]), reinterpret_cast<Packet4ui>(acc[i][2]));
    t2 = vec_mergeh(reinterpret_cast<Packet4ui>(acc[i][1]), reinterpret_cast<Packet4ui>(acc[i][3]));
    t3 = vec_mergel(reinterpret_cast<Packet4ui>(acc[i][1]), reinterpret_cast<Packet4ui>(acc[i][3]));
    acc[i][0] = reinterpret_cast<Packet4f>(vec_mergeh(t0, t2));
    acc[i][1] = reinterpret_cast<Packet4f>(vec_mergel(t0, t2));
    acc[i][2] = reinterpret_cast<Packet4f>(vec_mergeh(t1, t3));
    acc[i][3] = reinterpret_cast<Packet4f>(vec_mergel(t1, t3));
  }
}

// Pairwise reduction: sums adjacent accumulator pairs (j, j+1) into the
// first half of the array (index i), combining the split partial sums
// produced by processing two depth elements per KLoop step.
template <Index num_acc>
EIGEN_ALWAYS_INLINE void addResults(Packet4f (&acc)[num_acc][4]) {
  for (Index i = 0, j = 0; j < num_acc; i++, j += 2) {
    for (Index x = 0, y = 0; x < 2; x++, y += 2) {
      for (Index w = 0, z = 0; w < 2; w++, z += 2) {
        acc[i][y + w] = acc[j + x][z + 0] + acc[j + x][z + 1];
      }
    }
  }
}
2862
+
2863
// Finalizes a tile: transposes and pairwise-reduces the accumulators, then
// stores each 4-column group (full-width groups first, then the partial
// group when rhsExtraCols).
template <Index num_acc, bool rhsExtraCols, bool lhsExtraRows, Index num_rhs>
EIGEN_ALWAYS_INLINE void outputResultsVSX(Packet4f (&acc)[num_acc][4], Index rows, const Packet4f pAlpha, float* result,
                                          const Index extra_cols, Index extra_rows) {
  tranposeResults<num_acc>(acc);
  addResults<num_acc>(acc);

  constexpr Index real_rhs = ((num_rhs / 2) - (rhsExtraCols ? 1 : 0));
  Index k = 0;
  for (Index i = 0; i < real_rhs; i++, result += 4 * rows, k++) {
    storeResults<false, lhsExtraRows>(acc[k], rows, pAlpha, result, extra_cols, extra_rows);
  }
  if (rhsExtraCols) {
    storeResults<rhsExtraCols, lhsExtraRows>(acc[k], rows, pAlpha, result, extra_cols, extra_rows);
  }
}
2878
+
2879
// Loads two float4 RHS packets for column group 'i'. When 'zero' (odd depth
// remainder) only one packet is loaded and its lanes are split/padded with
// zeros via merge-low/high so the second depth element contributes nothing.
template <bool zero>
EIGEN_ALWAYS_INLINE void loadTwoRhsFloat32(const float* block, Index strideB, Index i, Packet4f& dhs0, Packet4f& dhs1) {
  dhs0 = ploadu<Packet4f>(block + strideB * i + 0);
  if (zero) {
    Packet4f dhs2 = pset1<Packet4f>(float(0));
    dhs1 = vec_mergel(dhs0, dhs2);
    dhs0 = vec_mergeh(dhs0, dhs2);
  } else {
    dhs1 = ploadu<Packet4f>(block + strideB * i + 4);
  }
}
2890
+
2891
// One depth step (two depth elements, or one zero-padded element when
// 'zero') of the bfloat16 GEMM inner loop: loads the RHS column packets and
// 4 LHS packets, then rank-1 updates every accumulator.
template <Index num_acc, bool zero, bool rhsExtraCols, Index num_rhs>
EIGEN_ALWAYS_INLINE void KLoop(const float* indexA, const float* indexB, Packet4f (&acc)[num_acc][4], Index strideB,
                               Index k, Index offsetB, Index extra_cols) {
  constexpr Index num_lhs = 4;
  Packet4f lhs[num_lhs], rhs[num_rhs];

  // Full column groups, then the partial group (offset adjusted by offsetB).
  constexpr Index real_rhs = (num_rhs - (rhsExtraCols ? 2 : 0));
  for (Index i = 0; i < real_rhs; i += 2) {
    loadTwoRhsFloat32<zero>(indexB + k * 4, strideB, i, rhs[i + 0], rhs[i + 1]);
  }
  if (rhsExtraCols) {
    loadTwoRhsFloat32<zero>(indexB + k * extra_cols - offsetB, strideB, real_rhs, rhs[real_rhs + 0], rhs[real_rhs + 1]);
  }

  indexA += 2 * k * 4;
  for (Index j = 0; j < num_lhs; j++) {
    lhs[j] = ploadu<Packet4f>(indexA + j * 4);
  }

  for (Index j = 0; j < num_rhs; j++) {
    for (Index i = 0; i < num_lhs; i++) {
      acc[j][i] = pmadd(rhs[j], lhs[i], acc[j][i]);
    }
  }
}
2916
+
2917
+ template <const Index num_acc, bool rhsExtraCols, bool lhsExtraRows>
2918
+ EIGEN_ALWAYS_INLINE void colVSXLoopBodyIter(Index depth, Index rows, const Packet4f pAlpha, const float* indexA,
2919
+ const float* indexB, Index strideB, Index offsetB, float* result,
2920
+ const Index extra_cols, const Index extra_rows) {
2921
+ constexpr Index num_rhs = num_acc;
2922
+
2923
+ Packet4f acc[num_acc][4];
2924
+
2925
+ zeroAccumulators<num_acc>(acc);
2926
+
2927
+ Index k;
2928
+ for (k = 0; k + 2 <= depth; k += 2) {
2929
+ KLoop<num_acc, false, rhsExtraCols, num_rhs>(indexA, indexB, acc, strideB, k, offsetB, extra_cols);
2930
+ }
2931
+ if (depth & 1) {
2932
+ KLoop<num_acc, true, rhsExtraCols, num_rhs>(indexA, indexB, acc, strideB, k, offsetB, extra_cols);
2933
+ }
2934
+
2935
+ outputResultsVSX<num_acc, rhsExtraCols, lhsExtraRows, num_rhs>(acc, rows, pAlpha, result, extra_cols, extra_rows);
2936
+ }
2937
+
2938
// No more than 4 (uses 2X the accumulators or 8X the number of VSX registers)
#define MAX_BFLOAT16_ACC_VSX 4

// Walks the columns in steps of num_acc * 4, invoking the tile kernel with
// doubled accumulators (two partial sums per tile, reduced in addResults).
// Only the full-width instantiation (num_acc == MAX, no extra cols) loops;
// remainder instantiations run a single iteration. 'col' is advanced by
// reference for the caller.
template <const Index num_acc, bool rhsExtraCols, bool lhsExtraRows>
void colVSXLoopBody(Index& col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const float* indexA,
                    const float* indexB, Index strideB, Index offsetB, float* result) {
  constexpr Index step = (num_acc * 4);  // each accumulator has 4 elements
  const Index extra_cols = (rhsExtraCols) ? (cols & 3) : 0;
  const Index extra_rows = (lhsExtraRows) ? (rows & 3) : 0;
  constexpr bool multiIters = !rhsExtraCols && (num_acc == MAX_BFLOAT16_ACC_VSX);

  do {
    colVSXLoopBodyIter<num_acc * 2, rhsExtraCols, lhsExtraRows>(depth, rows, pAlpha, indexA, indexB, strideB, offsetB,
                                                                result, extra_cols, extra_rows);

    indexB += strideB * (num_acc * 2);
    result += rows * step;
  } while (multiIters && (step <= cols - (col += step)));
}
2957
+
2958
+ template <const Index num_acc, bool rhsExtraCols, bool lhsExtraRows>
2959
+ EIGEN_ALWAYS_INLINE void colVSXLoopBodyExtraN(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha,
2960
+ const float* indexA, const float* blockB, Index strideB, Index offsetB,
2961
+ float* result) {
2962
+ if (MAX_BFLOAT16_ACC_VSX > num_acc) {
2963
+ colVSXLoopBody<num_acc + (rhsExtraCols ? 1 : 0), rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA,
2964
+ blockB, strideB, offsetB, result);
2965
+ }
2966
+ }
2967
+
2968
+ template <bool rhsExtraCols, bool lhsExtraRows>
2969
+ void colVSXLoopBodyExtra(Index col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const float* indexA,
2970
+ const float* blockB, Index strideB, Index offsetB, float* result) {
2971
+ switch ((cols - col) >> 2) {
2972
+ case 3:
2973
+ colVSXLoopBodyExtraN<3, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB,
2974
+ offsetB, result);
2975
+ break;
2976
+ case 2:
2977
+ colVSXLoopBodyExtraN<2, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB,
2978
+ offsetB, result);
2979
+ break;
2980
+ case 1:
2981
+ colVSXLoopBodyExtraN<1, rhsExtraCols, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB,
2982
+ offsetB, result);
2983
+ break;
2984
+ default:
2985
+ if (rhsExtraCols) {
2986
+ colVSXLoopBody<1, true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA, blockB, strideB, offsetB, result);
2987
+ }
2988
+ break;
2989
+ }
2990
+ }
2991
+
2992
// Processes a 'size'-row LHS strip of the bfloat16 GEMM: for each 4-row
// group, converts/duplicates the bfloat16 LHS into the float scratch buffer
// (indexA2), then runs the full-width column loop followed by the column
// remainder (with or without a partial group depending on cols & 3).
template <Index size, bool lhsExtraRows = false>
EIGEN_ALWAYS_INLINE void colVSXLoops(Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA,
                                     const float* indexA2, const float* blockB2, Index strideA, Index strideB,
                                     Index offsetB, float* result2) {
  Index delta_rows = 2 * (lhsExtraRows ? (rows & 3) : size);
  for (Index row = 0; row < size; row += 4) {
    convertArrayPointerBF16toF32Dup<lhsExtraRows>(const_cast<float*>(indexA2), strideA, delta_rows, indexA, row,
                                                  rows & 3);

    const float* blockB = blockB2;
    float* result = result2 + row;

    Index col = 0;
    if (cols >= (MAX_BFLOAT16_ACC_VSX * 4)) {
      colVSXLoopBody<MAX_BFLOAT16_ACC_VSX, false, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA2, blockB,
                                                                strideB, 0, result);
      // 'col' was advanced by the body; move the RHS/result pointers past the
      // columns already done.
      blockB += (strideB >> 1) * col;
      result += rows * col;
    }
    if (cols & 3) {
      colVSXLoopBodyExtra<true, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA2, blockB, strideB, offsetB,
                                              result);
    } else {
      colVSXLoopBodyExtra<false, lhsExtraRows>(col, depth, cols, rows, pAlpha, indexA2, blockB, strideB, 0, result);
    }
  }
}
3019
+
3020
// Runs colVSXLoops for one LHS block size if the packed LHS contains such a
// block (size 16 blocks repeat; smaller sizes appear at most once, flagged
// by the corresponding bit of 'rows'). Advances 'row' and the packed LHS
// pointer past the consumed block.
template <Index size>
EIGEN_ALWAYS_INLINE void calcVSXColLoops(const bfloat16*& indexA, const float* indexA2, Index& row, Index depth,
                                         Index cols, Index rows, const Packet4f pAlpha, const float* indexB,
                                         Index strideA, Index strideB, Index offsetA, Index offsetB, Index bigSuffix,
                                         float* result) {
  if ((size == 16) || (rows & size)) {
    indexA += size * offsetA;
    colVSXLoops<size>(depth, cols, rows, pAlpha, indexA, indexA2, indexB, strideA, strideB, offsetB, result + row);
    row += size;
    indexA += bigSuffix * size / 16;
  }
}
3032
+
3033
// Converts 'size' bfloat16 values at a time from the mapper into float at
// result[i], advancing i. Mirrors convertPointerBF16toF32 but reads through
// a LinearMapper; only the size-32 instantiation loops.
template <const Index size, typename DataMapper>
EIGEN_ALWAYS_INLINE void convertBF16toF32(Index& i, float* result, Index rows, const DataMapper& src) {
  constexpr Index extra = ((size < 4) ? 4 : size);
  while (i + size <= rows) {
    PacketBlock<Packet8bf, (size + 7) / 8> r32;
    r32.packet[0] = src.template loadPacket<Packet8bf>(i + 0);
    if (size >= 16) {
      r32.packet[1] = src.template loadPacket<Packet8bf>(i + 8);
    }
    if (size >= 32) {
      r32.packet[2] = src.template loadPacket<Packet8bf>(i + 16);
      r32.packet[3] = src.template loadPacket<Packet8bf>(i + 24);
    }
    storeConvertBlockBF16<size>(result + i, r32, rows & 3);
    i += extra;
    if (size != 32) break;
  }
}
3051
+
3052
// Converts the whole cols x rows bfloat16 result matrix (read through the
// mapper) into the dense float scratch array, column by column, using the
// 32/16/8/4/1 chunk cascade.
template <typename DataMapper>
EIGEN_ALWAYS_INLINE void convertArrayBF16toF32(float* result, Index cols, Index rows, const DataMapper& src) {
  typedef typename DataMapper::LinearMapper LinearMapper;
  for (Index j = 0; j < cols; j++, result += rows) {
    const LinearMapper src2 = src.getLinearMapper(0, j);
    Index i = 0;
    convertBF16toF32<32, LinearMapper>(i, result, rows, src2);
    convertBF16toF32<16, LinearMapper>(i, result, rows, src2);
    convertBF16toF32<8, LinearMapper>(i, result, rows, src2);
    convertBF16toF32<4, LinearMapper>(i, result, rows, src2);
    convertBF16toF32<1, LinearMapper>(i, result, rows, src2);
  }
}
3065
+
3066
// Rounds 8 consecutive floats back down to one bfloat16 packet.
EIGEN_ALWAYS_INLINE Packet8bf convertF32toBF16VSX(const float* res) {
  return F32ToBf16Both(ploadu<Packet4f>(res + 0), ploadu<Packet4f>(res + 4));
}
3069
+
3070
// Converts 'size' float result columns back to bfloat16 and writes them
// through the mapper, 8 rows per step, with a partial store for the final
// (rows & 7) rows.
template <typename DataMapper, const Index size>
EIGEN_ALWAYS_INLINE void convertArrayF32toBF16ColVSX(float* result, Index col, Index rows, const DataMapper& res) {
  const DataMapper res2 = res.getSubMapper(0, col);
  Index row;
  float* result2 = result + col * rows;
  for (row = 0; row + 8 <= rows; row += 8, result2 += 8) {
    // get and save block
    PacketBlock<Packet8bf, size> block;
    for (Index j = 0; j < size; j++) {
      block.packet[j] = convertF32toBF16VSX(result2 + j * rows);
    }
    res2.template storePacketBlock<Packet8bf, size>(row, 0, block);
  }
  // extra rows
  if (row < rows) {
    for (Index j = 0; j < size; j++) {
      Packet8bf fp16 = convertF32toBF16VSX(result2 + j * rows);
      res2.template storePacketPartial<Packet8bf>(row, j, fp16, rows & 7);
    }
  }
}
3091
+
3092
// Converts the whole float scratch result back to bfloat16 in the output
// mapper: 4 columns at a time, then a 1/2/3-column remainder.
template <typename DataMapper>
EIGEN_ALWAYS_INLINE void convertArrayF32toBF16VSX(float* result, Index cols, Index rows, const DataMapper& res) {
  Index col;
  for (col = 0; col + 4 <= cols; col += 4) {
    convertArrayF32toBF16ColVSX<DataMapper, 4>(result, col, rows, res);
  }
  // extra cols
  switch (cols - col) {
    case 1:
      convertArrayF32toBF16ColVSX<DataMapper, 1>(result, col, rows, res);
      break;
    case 2:
      convertArrayF32toBF16ColVSX<DataMapper, 2>(result, col, rows, res);
      break;
    case 3:
      convertArrayF32toBF16ColVSX<DataMapper, 3>(result, col, rows, res);
      break;
  }
}
3111
+
3112
// bfloat16 GEMM fallback for VSX (no MMA): widens both operands and the
// existing result to float on the stack, runs the float kernels in LHS row
// blocks of 16/8/4 plus a sub-4 remainder, and narrows the result back to
// bfloat16. strideA/strideB of -1 mean "use depth".
template <typename DataMapper>
void gemmbfloat16(const DataMapper& res, const bfloat16* indexA, const bfloat16* indexB, Index rows, Index depth,
                  Index cols, bfloat16 alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
  float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
  const Packet4f pAlpha = pset1<Packet4f>(falpha);

  if (strideA == -1) strideA = depth;
  if (strideB == -1) strideB = depth;

  // Stack scratch: float copies of the result, the RHS, and a rolling
  // (depth-rounded-to-even x 4, duplicated) LHS strip.
  ei_declare_aligned_stack_constructed_variable(float, result, cols* rows, 0);
  ei_declare_aligned_stack_constructed_variable(float, indexB2, strideB* cols, 0);
  ei_declare_aligned_stack_constructed_variable(float, indexA2, ((strideA + 1) & -2) * 4 * 2, 0);

  convertArrayBF16toF32<DataMapper>(result, cols, rows, res);
  // NOTE(review): this overload of convertArrayPointerBF16toF32 (4 args, no
  // template parameter) is declared elsewhere in the file.
  convertArrayPointerBF16toF32(indexB2, cols, strideB, const_cast<bfloat16*>(indexB));

  Index bigSuffix = 2 * 8 * (strideA - offsetA);
  float* indexBF32 = indexB2 + 4 * offsetB;
  offsetB *= 3;
  strideB *= 2;

  Index row = 0;
  // LHS (8x16) block
  while (row + 16 <= rows) {
    calcVSXColLoops<16>(indexA, indexA2, row, depth, cols, rows, pAlpha, indexBF32, strideA, strideB, offsetA, offsetB,
                        bigSuffix, result);
  }
  // LHS (8x8) block
  calcVSXColLoops<8>(indexA, indexA2, row, depth, cols, rows, pAlpha, indexBF32, strideA, strideB, offsetA, offsetB,
                     bigSuffix, result);
  // LHS (8x4) block
  calcVSXColLoops<4>(indexA, indexA2, row, depth, cols, rows, pAlpha, indexBF32, strideA, strideB, offsetA, offsetB,
                     bigSuffix, result);
  // extra rows
  if (rows & 3) {
    // This index is the beginning of remaining block.
    colVSXLoops<4, true>(depth, cols, rows, pAlpha, indexA, indexA2, indexBF32, strideA, strideB, offsetB,
                         result + row);
  }

  // Convert back to bfloat16
  convertArrayF32toBF16VSX<DataMapper>(result, cols, rows, res);
}
3155
+
3156
+ #undef MAX_BFLOAT16_ACC_VSX
3157
+
3158
+ #include "MatrixVectorProduct.inc"
3159
+
3160
+ /************************************
3161
+ * ppc64le template specializations *
3162
+ * **********************************/
3163
+ template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3164
+ struct gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> {
3165
+ void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0, Index offset = 0);
3166
+ };
3167
+
3168
+ template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3169
+ void gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>::operator()(
3170
+ double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) {
3171
+ dhs_pack<double, DataMapper, Packet2d, ColMajor, PanelMode, true> pack;
3172
+ pack(blockA, lhs, depth, rows, stride, offset);
3173
+ }
3174
+
3175
+ template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3176
+ struct gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> {
3177
+ void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0, Index offset = 0);
3178
+ };
3179
+
3180
+ template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3181
+ void gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>::operator()(
3182
+ double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) {
3183
+ dhs_pack<double, DataMapper, Packet2d, RowMajor, PanelMode, true> pack;
3184
+ pack(blockA, lhs, depth, rows, stride, offset);
3185
+ }
3186
+
3187
+ #if EIGEN_ALTIVEC_USE_CUSTOM_PACK
3188
+ template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3189
+ struct gemm_pack_rhs<double, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> {
3190
+ void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0, Index offset = 0);
3191
+ };
3192
+
3193
+ template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3194
+ void gemm_pack_rhs<double, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>::operator()(
3195
+ double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
3196
+ dhs_pack<double, DataMapper, Packet2d, ColMajor, PanelMode, false> pack;
3197
+ pack(blockB, rhs, depth, cols, stride, offset);
3198
+ }
3199
+
3200
+ template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3201
+ struct gemm_pack_rhs<double, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> {
3202
+ void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0, Index offset = 0);
3203
+ };
3204
+
3205
+ template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3206
+ void gemm_pack_rhs<double, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>::operator()(
3207
+ double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
3208
+ dhs_pack<double, DataMapper, Packet2d, RowMajor, PanelMode, false> pack;
3209
+ pack(blockB, rhs, depth, cols, stride, offset);
3210
+ }
3211
+
3212
+ template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3213
+ struct gemm_pack_rhs<bfloat16, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> {
3214
+ void operator()(bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0, Index offset = 0);
3215
+ };
3216
+
3217
+ template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3218
+ void gemm_pack_rhs<bfloat16, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>::operator()(
3219
+ bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
3220
+ dhs_pack<bfloat16, DataMapper, Packet8bf, ColMajor, PanelMode, false> pack;
3221
+ pack(blockB, rhs, depth, cols, stride, offset);
3222
+ }
3223
+
3224
+ template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3225
+ struct gemm_pack_rhs<bfloat16, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> {
3226
+ void operator()(bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0, Index offset = 0);
3227
+ };
3228
+
3229
+ template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3230
+ void gemm_pack_rhs<bfloat16, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>::operator()(
3231
+ bfloat16* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
3232
+ dhs_pack<bfloat16, DataMapper, Packet8bf, RowMajor, PanelMode, false> pack;
3233
+ pack(blockB, rhs, depth, cols, stride, offset);
3234
+ }
3235
+ #endif
3236
+
3237
+ template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3238
+ struct gemm_pack_lhs<bfloat16, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> {
3239
+ void operator()(bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0, Index offset = 0);
3240
+ };
3241
+
3242
+ template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3243
+ void gemm_pack_lhs<bfloat16, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>::operator()(
3244
+ bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) {
3245
+ dhs_pack<bfloat16, DataMapper, Packet8bf, ColMajor, PanelMode, true> pack;
3246
+ pack(blockA, lhs, depth, rows, stride, offset);
3247
+ }
3248
+
3249
+ template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3250
+ struct gemm_pack_lhs<bfloat16, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> {
3251
+ void operator()(bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0, Index offset = 0);
3252
+ };
3253
+
3254
+ template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3255
+ void gemm_pack_lhs<bfloat16, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>::operator()(
3256
+ bfloat16* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) {
3257
+ dhs_pack<bfloat16, DataMapper, Packet8bf, RowMajor, PanelMode, true> pack;
3258
+ pack(blockA, lhs, depth, rows, stride, offset);
3259
+ }
3260
+
3261
+ template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3262
+ struct gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> {
3263
+ void operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0, Index offset = 0);
3264
+ };
3265
+
3266
+ template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3267
+ void gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode>::operator()(
3268
+ float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) {
3269
+ dhs_pack<float, DataMapper, Packet4f, RowMajor, PanelMode, true> pack;
3270
+ pack(blockA, lhs, depth, rows, stride, offset);
3271
+ }
3272
+
3273
+ template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3274
+ struct gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> {
3275
+ void operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0, Index offset = 0);
3276
+ };
3277
+
3278
+ template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3279
+ void gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode>::operator()(
3280
+ float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) {
3281
+ dhs_pack<float, DataMapper, Packet4f, ColMajor, PanelMode, true> pack;
3282
+ pack(blockA, lhs, depth, rows, stride, offset);
3283
+ }
3284
+
3285
+ template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3286
+ struct gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> {
3287
+ void operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0,
3288
+ Index offset = 0);
3289
+ };
3290
+
3291
+ template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3292
+ void gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate,
3293
+ PanelMode>::operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows,
3294
+ Index stride, Index offset) {
3295
+ dhs_cpack<float, DataMapper, Packet4f, Packet2cf, RowMajor, Conjugate, PanelMode, true> pack;
3296
+ pack(blockA, lhs, depth, rows, stride, offset);
3297
+ }
3298
+
3299
+ template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3300
+ struct gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> {
3301
+ void operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0,
3302
+ Index offset = 0);
3303
+ };
3304
+
3305
+ template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3306
+ void gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate,
3307
+ PanelMode>::operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows,
3308
+ Index stride, Index offset) {
3309
+ dhs_cpack<float, DataMapper, Packet4f, Packet2cf, ColMajor, Conjugate, PanelMode, true> pack;
3310
+ pack(blockA, lhs, depth, rows, stride, offset);
3311
+ }
3312
+
3313
+ #if EIGEN_ALTIVEC_USE_CUSTOM_PACK
3314
+ template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3315
+ struct gemm_pack_rhs<float, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> {
3316
+ void operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0, Index offset = 0);
3317
+ };
3318
+
3319
+ template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3320
+ void gemm_pack_rhs<float, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>::operator()(
3321
+ float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
3322
+ dhs_pack<float, DataMapper, Packet4f, ColMajor, PanelMode, false> pack;
3323
+ pack(blockB, rhs, depth, cols, stride, offset);
3324
+ }
3325
+
3326
+ template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3327
+ struct gemm_pack_rhs<float, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> {
3328
+ void operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0, Index offset = 0);
3329
+ };
3330
+
3331
+ template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3332
+ void gemm_pack_rhs<float, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>::operator()(
3333
+ float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
3334
+ dhs_pack<float, DataMapper, Packet4f, RowMajor, PanelMode, false> pack;
3335
+ pack(blockB, rhs, depth, cols, stride, offset);
3336
+ }
3337
+ #endif
3338
+
3339
+ template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3340
+ struct gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> {
3341
+ void operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0,
3342
+ Index offset = 0);
3343
+ };
3344
+
3345
+ template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3346
+ void gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>::operator()(
3347
+ std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
3348
+ dhs_cpack<float, DataMapper, Packet4f, Packet2cf, ColMajor, Conjugate, PanelMode, false> pack;
3349
+ pack(blockB, rhs, depth, cols, stride, offset);
3350
+ }
3351
+
3352
+ template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3353
+ struct gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> {
3354
+ void operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0,
3355
+ Index offset = 0);
3356
+ };
3357
+
3358
+ template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3359
+ void gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>::operator()(
3360
+ std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
3361
+ dhs_cpack<float, DataMapper, Packet4f, Packet2cf, RowMajor, Conjugate, PanelMode, false> pack;
3362
+ pack(blockB, rhs, depth, cols, stride, offset);
3363
+ }
3364
+
3365
+ template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3366
+ struct gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> {
3367
+ void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0,
3368
+ Index offset = 0);
3369
+ };
3370
+
3371
+ template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3372
+ void gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate,
3373
+ PanelMode>::operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows,
3374
+ Index stride, Index offset) {
3375
+ dhs_cpack<double, DataMapper, Packet2d, Packet1cd, RowMajor, Conjugate, PanelMode, true> pack;
3376
+ pack(blockA, lhs, depth, rows, stride, offset);
3377
+ }
3378
+
3379
+ template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3380
+ struct gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> {
3381
+ void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride = 0,
3382
+ Index offset = 0);
3383
+ };
3384
+
3385
+ template <typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode>
3386
+ void gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate,
3387
+ PanelMode>::operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows,
3388
+ Index stride, Index offset) {
3389
+ dhs_cpack<double, DataMapper, Packet2d, Packet1cd, ColMajor, Conjugate, PanelMode, true> pack;
3390
+ pack(blockA, lhs, depth, rows, stride, offset);
3391
+ }
3392
+
3393
+ template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3394
+ struct gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> {
3395
+ void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0,
3396
+ Index offset = 0);
3397
+ };
3398
+
3399
+ template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3400
+ void gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode>::operator()(
3401
+ std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
3402
+ dhs_cpack<double, DataMapper, Packet2d, Packet1cd, ColMajor, Conjugate, PanelMode, false> pack;
3403
+ pack(blockB, rhs, depth, cols, stride, offset);
3404
+ }
3405
+
3406
+ template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3407
+ struct gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> {
3408
+ void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride = 0,
3409
+ Index offset = 0);
3410
+ };
3411
+
3412
+ template <typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode>
3413
+ void gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>::operator()(
3414
+ std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) {
3415
+ dhs_cpack<double, DataMapper, Packet2d, Packet1cd, RowMajor, Conjugate, PanelMode, false> pack;
3416
+ pack(blockB, rhs, depth, cols, stride, offset);
3417
+ }
3418
+
3419
+ // ********* gebp specializations *********
3420
+ template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3421
+ struct gebp_kernel<float, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
3422
+ typedef typename quad_traits<float>::vectortype Packet;
3423
+ typedef typename quad_traits<float>::rhstype RhsPacket;
3424
+
3425
+ void operator()(const DataMapper& res, const float* blockA, const float* blockB, Index rows, Index depth, Index cols,
3426
+ float alpha, Index strideA = -1, Index strideB = -1, Index offsetA = 0, Index offsetB = 0);
3427
+ };
3428
+
3429
+ template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3430
+ void gebp_kernel<float, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>::operator()(
3431
+ const DataMapper& res, const float* blockA, const float* blockB, Index rows, Index depth, Index cols, float alpha,
3432
+ Index strideA, Index strideB, Index offsetA, Index offsetB) {
3433
+ const Index accRows = quad_traits<float>::rows;
3434
+ const Index accCols = quad_traits<float>::size;
3435
+ static void (*gemm_function)(const DataMapper&, const float*, const float*, Index, Index, Index, float, Index, Index,
3436
+ Index, Index) =
3437
+ #ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
3438
+ (supportsMMA()) ? &Eigen::internal::gemmMMA<float, Packet, RhsPacket, DataMapper, accRows, accCols> :
3439
+ #endif
3440
+ &Eigen::internal::gemm<float, Packet, RhsPacket, DataMapper, accRows, accCols>;
3441
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
3442
+ }
3443
+
3444
+ template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3445
+ struct gebp_kernel<std::complex<float>, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
3446
+ typedef Packet4f Packet;
3447
+ typedef Packet2cf Packetc;
3448
+ typedef Packet4f RhsPacket;
3449
+
3450
+ void operator()(const DataMapper& res, const std::complex<float>* blockA, const std::complex<float>* blockB,
3451
+ Index rows, Index depth, Index cols, std::complex<float> alpha, Index strideA = -1,
3452
+ Index strideB = -1, Index offsetA = 0, Index offsetB = 0);
3453
+ };
3454
+
3455
+ template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3456
+ void gebp_kernel<std::complex<float>, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs,
3457
+ ConjugateRhs>::operator()(const DataMapper& res, const std::complex<float>* blockA,
3458
+ const std::complex<float>* blockB, Index rows, Index depth, Index cols,
3459
+ std::complex<float> alpha, Index strideA, Index strideB, Index offsetA,
3460
+ Index offsetB) {
3461
+ const Index accRows = quad_traits<float>::rows;
3462
+ const Index accCols = quad_traits<float>::size;
3463
+ static void (*gemm_function)(const DataMapper&, const std::complex<float>*, const std::complex<float>*, Index, Index,
3464
+ Index, std::complex<float>, Index, Index, Index, Index) =
3465
+ #ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
3466
+ (supportsMMA()) ? &Eigen::internal::gemm_complexMMA<std::complex<float>, std::complex<float>, std::complex<float>,
3467
+ float, Packet, Packetc, RhsPacket, DataMapper, accRows,
3468
+ accCols, ConjugateLhs, ConjugateRhs, false, false>
3469
+ :
3470
+ #endif
3471
+ &Eigen::internal::gemm_complex<std::complex<float>, std::complex<float>, std::complex<float>,
3472
+ float, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols,
3473
+ ConjugateLhs, ConjugateRhs, false, false>;
3474
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
3475
+ }
3476
+
3477
+ template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3478
+ struct gebp_kernel<float, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
3479
+ typedef Packet4f Packet;
3480
+ typedef Packet2cf Packetc;
3481
+ typedef Packet4f RhsPacket;
3482
+
3483
+ void operator()(const DataMapper& res, const float* blockA, const std::complex<float>* blockB, Index rows,
3484
+ Index depth, Index cols, std::complex<float> alpha, Index strideA = -1, Index strideB = -1,
3485
+ Index offsetA = 0, Index offsetB = 0);
3486
+ };
3487
+
3488
+ template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3489
+ void gebp_kernel<float, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>::operator()(
3490
+ const DataMapper& res, const float* blockA, const std::complex<float>* blockB, Index rows, Index depth, Index cols,
3491
+ std::complex<float> alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
3492
+ const Index accRows = quad_traits<float>::rows;
3493
+ const Index accCols = quad_traits<float>::size;
3494
+ static void (*gemm_function)(const DataMapper&, const float*, const std::complex<float>*, Index, Index, Index,
3495
+ std::complex<float>, Index, Index, Index, Index) =
3496
+ #ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
3497
+ (supportsMMA()) ? &Eigen::internal::gemm_complexMMA<float, std::complex<float>, std::complex<float>, float,
3498
+ Packet, Packetc, RhsPacket, DataMapper, accRows, accCols,
3499
+ ConjugateLhs, ConjugateRhs, true, false>
3500
+ :
3501
+ #endif
3502
+ &Eigen::internal::gemm_complex<float, std::complex<float>, std::complex<float>, float, Packet,
3503
+ Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs,
3504
+ ConjugateRhs, true, false>;
3505
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
3506
+ }
3507
+
3508
+ template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3509
+ struct gebp_kernel<std::complex<float>, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
3510
+ typedef Packet4f Packet;
3511
+ typedef Packet2cf Packetc;
3512
+ typedef Packet4f RhsPacket;
3513
+
3514
+ void operator()(const DataMapper& res, const std::complex<float>* blockA, const float* blockB, Index rows,
3515
+ Index depth, Index cols, std::complex<float> alpha, Index strideA = -1, Index strideB = -1,
3516
+ Index offsetA = 0, Index offsetB = 0);
3517
+ };
3518
+
3519
+ template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3520
+ void gebp_kernel<std::complex<float>, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>::operator()(
3521
+ const DataMapper& res, const std::complex<float>* blockA, const float* blockB, Index rows, Index depth, Index cols,
3522
+ std::complex<float> alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
3523
+ const Index accRows = quad_traits<float>::rows;
3524
+ const Index accCols = quad_traits<float>::size;
3525
+ static void (*gemm_function)(const DataMapper&, const std::complex<float>*, const float*, Index, Index, Index,
3526
+ std::complex<float>, Index, Index, Index, Index) =
3527
+ #ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
3528
+ (supportsMMA()) ? &Eigen::internal::gemm_complexMMA<std::complex<float>, float, std::complex<float>, float,
3529
+ Packet, Packetc, RhsPacket, DataMapper, accRows, accCols,
3530
+ ConjugateLhs, ConjugateRhs, false, true>
3531
+ :
3532
+ #endif
3533
+ &Eigen::internal::gemm_complex<std::complex<float>, float, std::complex<float>, float, Packet,
3534
+ Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs,
3535
+ ConjugateRhs, false, true>;
3536
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
3537
+ }
3538
+
3539
+ template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3540
+ struct gebp_kernel<double, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
3541
+ typedef typename quad_traits<double>::vectortype Packet;
3542
+ typedef typename quad_traits<double>::rhstype RhsPacket;
3543
+
3544
+ void operator()(const DataMapper& res, const double* blockA, const double* blockB, Index rows, Index depth,
3545
+ Index cols, double alpha, Index strideA = -1, Index strideB = -1, Index offsetA = 0,
3546
+ Index offsetB = 0);
3547
+ };
3548
+
3549
+ template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3550
+ void gebp_kernel<double, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>::operator()(
3551
+ const DataMapper& res, const double* blockA, const double* blockB, Index rows, Index depth, Index cols,
3552
+ double alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
3553
+ const Index accRows = quad_traits<double>::rows;
3554
+ const Index accCols = quad_traits<double>::size;
3555
+ static void (*gemm_function)(const DataMapper&, const double*, const double*, Index, Index, Index, double, Index,
3556
+ Index, Index, Index) =
3557
+ #ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
3558
+ (supportsMMA()) ? &Eigen::internal::gemmMMA<double, Packet, RhsPacket, DataMapper, accRows, accCols> :
3559
+ #endif
3560
+ &Eigen::internal::gemm<double, Packet, RhsPacket, DataMapper, accRows, accCols>;
3561
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
3562
+ }
3563
+
3564
+ template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3565
+ struct gebp_kernel<std::complex<double>, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
3566
+ typedef quad_traits<double>::vectortype Packet;
3567
+ typedef Packet1cd Packetc;
3568
+ typedef quad_traits<double>::rhstype RhsPacket;
3569
+
3570
+ void operator()(const DataMapper& res, const std::complex<double>* blockA, const std::complex<double>* blockB,
3571
+ Index rows, Index depth, Index cols, std::complex<double> alpha, Index strideA = -1,
3572
+ Index strideB = -1, Index offsetA = 0, Index offsetB = 0);
3573
+ };
3574
+
3575
+ template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3576
+ void gebp_kernel<std::complex<double>, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs,
3577
+ ConjugateRhs>::operator()(const DataMapper& res, const std::complex<double>* blockA,
3578
+ const std::complex<double>* blockB, Index rows, Index depth, Index cols,
3579
+ std::complex<double> alpha, Index strideA, Index strideB, Index offsetA,
3580
+ Index offsetB) {
3581
+ const Index accRows = quad_traits<double>::rows;
3582
+ const Index accCols = quad_traits<double>::size;
3583
+ static void (*gemm_function)(const DataMapper&, const std::complex<double>*, const std::complex<double>*, Index,
3584
+ Index, Index, std::complex<double>, Index, Index, Index, Index) =
3585
+ #ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
3586
+ (supportsMMA())
3587
+ ? &Eigen::internal::gemm_complexMMA<std::complex<double>, std::complex<double>, std::complex<double>, double,
3588
+ Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs,
3589
+ ConjugateRhs, false, false>
3590
+ :
3591
+ #endif
3592
+ &Eigen::internal::gemm_complex<std::complex<double>, std::complex<double>, std::complex<double>, double,
3593
+ Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs,
3594
+ ConjugateRhs, false, false>;
3595
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
3596
+ }
3597
+
3598
+ template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3599
+ struct gebp_kernel<std::complex<double>, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
3600
+ typedef quad_traits<double>::vectortype Packet;
3601
+ typedef Packet1cd Packetc;
3602
+ typedef quad_traits<double>::rhstype RhsPacket;
3603
+
3604
+ void operator()(const DataMapper& res, const std::complex<double>* blockA, const double* blockB, Index rows,
3605
+ Index depth, Index cols, std::complex<double> alpha, Index strideA = -1, Index strideB = -1,
3606
+ Index offsetA = 0, Index offsetB = 0);
3607
+ };
3608
+
3609
+ template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3610
+ void gebp_kernel<std::complex<double>, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>::operator()(
3611
+ const DataMapper& res, const std::complex<double>* blockA, const double* blockB, Index rows, Index depth,
3612
+ Index cols, std::complex<double> alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
3613
+ const Index accRows = quad_traits<double>::rows;
3614
+ const Index accCols = quad_traits<double>::size;
3615
+ static void (*gemm_function)(const DataMapper&, const std::complex<double>*, const double*, Index, Index, Index,
3616
+ std::complex<double>, Index, Index, Index, Index) =
3617
+ #ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
3618
+ (supportsMMA()) ? &Eigen::internal::gemm_complexMMA<std::complex<double>, double, std::complex<double>, double,
3619
+ Packet, Packetc, RhsPacket, DataMapper, accRows, accCols,
3620
+ ConjugateLhs, ConjugateRhs, false, true>
3621
+ :
3622
+ #endif
3623
+ &Eigen::internal::gemm_complex<std::complex<double>, double, std::complex<double>, double, Packet,
3624
+ Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs,
3625
+ ConjugateRhs, false, true>;
3626
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
3627
+ }
3628
+
3629
// GEBP kernel specialization for a real LHS times a complex RHS
// (double * std::complex<double>) on Power/AltiVec.  The heavy lifting is
// delegated to gemm_complex / gemm_complexMMA in the out-of-line operator().
template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<double, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
  typedef quad_traits<double>::vectortype Packet;  // real SIMD packet for double
  typedef Packet1cd Packetc;                       // complex packet (one complex<double> per vector)
  typedef quad_traits<double>::rhstype RhsPacket;  // packet type used on the RHS side

  // Accumulates alpha * blockA * blockB into res, where blockA/blockB are
  // packed panels produced by the corresponding gemm_pack_* routines.
  // strideA/strideB default to -1 and offsetA/offsetB to 0, meaning the
  // packing's natural layout is used.
  void operator()(const DataMapper& res, const double* blockA, const std::complex<double>* blockB, Index rows,
                  Index depth, Index cols, std::complex<double> alpha, Index strideA = -1, Index strideB = -1,
                  Index offsetA = 0, Index offsetB = 0);
};
3639
+
3640
+ template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3641
+ void gebp_kernel<double, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>::operator()(
3642
+ const DataMapper& res, const double* blockA, const std::complex<double>* blockB, Index rows, Index depth,
3643
+ Index cols, std::complex<double> alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
3644
+ const Index accRows = quad_traits<double>::rows;
3645
+ const Index accCols = quad_traits<double>::size;
3646
+ static void (*gemm_function)(const DataMapper&, const double*, const std::complex<double>*, Index, Index, Index,
3647
+ std::complex<double>, Index, Index, Index, Index) =
3648
+ #ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
3649
+ (supportsMMA()) ? &Eigen::internal::gemm_complexMMA<double, std::complex<double>, std::complex<double>, double,
3650
+ Packet, Packetc, RhsPacket, DataMapper, accRows, accCols,
3651
+ ConjugateLhs, ConjugateRhs, true, false>
3652
+ :
3653
+ #endif
3654
+ &Eigen::internal::gemm_complex<double, std::complex<double>, std::complex<double>, double, Packet,
3655
+ Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs,
3656
+ ConjugateRhs, true, false>;
3657
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
3658
+ }
3659
+
3660
// GEBP kernel specialization for bfloat16 operands on Power/AltiVec.
// The heavy lifting is delegated to gemmbfloat16 / gemmMMAbfloat16 in the
// out-of-line operator().
template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<bfloat16, bfloat16, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> {
  typedef typename quad_traits<bfloat16>::vectortype Packet;  // SIMD packet for bfloat16
  typedef typename quad_traits<bfloat16>::rhstype RhsPacket;  // packet type used on the RHS side

  // Accumulates alpha * blockA * blockB into res, where blockA/blockB are
  // packed panels produced by the corresponding gemm_pack_* routines.
  // strideA/strideB default to -1 and offsetA/offsetB to 0, meaning the
  // packing's natural layout is used.
  void operator()(const DataMapper& res, const bfloat16* blockA, const bfloat16* blockB, Index rows, Index depth,
                  Index cols, bfloat16 alpha, Index strideA = -1, Index strideB = -1, Index offsetA = 0,
                  Index offsetB = 0);
};
3669
+
3670
+ template <typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs>
3671
+ void gebp_kernel<bfloat16, bfloat16, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs>::operator()(
3672
+ const DataMapper& res, const bfloat16* blockA, const bfloat16* blockB, Index rows, Index depth, Index cols,
3673
+ bfloat16 alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
3674
+ static void (*gemm_function)(const DataMapper&, const bfloat16*, const bfloat16*, Index, Index, Index, bfloat16,
3675
+ Index, Index, Index, Index) =
3676
+ #ifdef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
3677
+ (supportsMMA()) ? &Eigen::internal::gemmMMAbfloat16<DataMapper> :
3678
+ #endif
3679
+ &Eigen::internal::gemmbfloat16<DataMapper>;
3680
+ gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB);
3681
+ }
3682
+ } // end namespace internal
3683
+
3684
+ } // end namespace Eigen
3685
+
3686
+ #endif // EIGEN_MATRIX_PRODUCT_ALTIVEC_H