@smake/eigen 1.0.2 → 1.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (435)
  1. package/README.md +1 -1
  2. package/eigen/Eigen/AccelerateSupport +52 -0
  3. package/eigen/Eigen/Cholesky +18 -21
  4. package/eigen/Eigen/CholmodSupport +28 -28
  5. package/eigen/Eigen/Core +235 -326
  6. package/eigen/Eigen/Eigenvalues +16 -14
  7. package/eigen/Eigen/Geometry +21 -24
  8. package/eigen/Eigen/Householder +9 -8
  9. package/eigen/Eigen/IterativeLinearSolvers +8 -4
  10. package/eigen/Eigen/Jacobi +14 -14
  11. package/eigen/Eigen/KLUSupport +43 -0
  12. package/eigen/Eigen/LU +16 -20
  13. package/eigen/Eigen/MetisSupport +12 -12
  14. package/eigen/Eigen/OrderingMethods +54 -54
  15. package/eigen/Eigen/PaStiXSupport +23 -20
  16. package/eigen/Eigen/PardisoSupport +17 -14
  17. package/eigen/Eigen/QR +18 -21
  18. package/eigen/Eigen/QtAlignedMalloc +5 -13
  19. package/eigen/Eigen/SPQRSupport +21 -14
  20. package/eigen/Eigen/SVD +23 -18
  21. package/eigen/Eigen/Sparse +1 -4
  22. package/eigen/Eigen/SparseCholesky +18 -23
  23. package/eigen/Eigen/SparseCore +18 -17
  24. package/eigen/Eigen/SparseLU +12 -8
  25. package/eigen/Eigen/SparseQR +16 -14
  26. package/eigen/Eigen/StdDeque +5 -2
  27. package/eigen/Eigen/StdList +5 -2
  28. package/eigen/Eigen/StdVector +5 -2
  29. package/eigen/Eigen/SuperLUSupport +30 -24
  30. package/eigen/Eigen/ThreadPool +80 -0
  31. package/eigen/Eigen/UmfPackSupport +19 -17
  32. package/eigen/Eigen/Version +14 -0
  33. package/eigen/Eigen/src/AccelerateSupport/AccelerateSupport.h +423 -0
  34. package/eigen/Eigen/src/AccelerateSupport/InternalHeaderCheck.h +3 -0
  35. package/eigen/Eigen/src/Cholesky/InternalHeaderCheck.h +3 -0
  36. package/eigen/Eigen/src/Cholesky/LDLT.h +377 -401
  37. package/eigen/Eigen/src/Cholesky/LLT.h +332 -360
  38. package/eigen/Eigen/src/Cholesky/LLT_LAPACKE.h +81 -56
  39. package/eigen/Eigen/src/CholmodSupport/CholmodSupport.h +620 -521
  40. package/eigen/Eigen/src/CholmodSupport/InternalHeaderCheck.h +3 -0
  41. package/eigen/Eigen/src/Core/ArithmeticSequence.h +239 -0
  42. package/eigen/Eigen/src/Core/Array.h +341 -294
  43. package/eigen/Eigen/src/Core/ArrayBase.h +190 -203
  44. package/eigen/Eigen/src/Core/ArrayWrapper.h +127 -171
  45. package/eigen/Eigen/src/Core/Assign.h +30 -40
  46. package/eigen/Eigen/src/Core/AssignEvaluator.h +711 -589
  47. package/eigen/Eigen/src/Core/Assign_MKL.h +130 -125
  48. package/eigen/Eigen/src/Core/BandMatrix.h +268 -283
  49. package/eigen/Eigen/src/Core/Block.h +375 -398
  50. package/eigen/Eigen/src/Core/CommaInitializer.h +86 -97
  51. package/eigen/Eigen/src/Core/ConditionEstimator.h +51 -53
  52. package/eigen/Eigen/src/Core/CoreEvaluators.h +1356 -1026
  53. package/eigen/Eigen/src/Core/CoreIterators.h +73 -59
  54. package/eigen/Eigen/src/Core/CwiseBinaryOp.h +114 -132
  55. package/eigen/Eigen/src/Core/CwiseNullaryOp.h +726 -617
  56. package/eigen/Eigen/src/Core/CwiseTernaryOp.h +77 -103
  57. package/eigen/Eigen/src/Core/CwiseUnaryOp.h +56 -68
  58. package/eigen/Eigen/src/Core/CwiseUnaryView.h +132 -95
  59. package/eigen/Eigen/src/Core/DenseBase.h +632 -571
  60. package/eigen/Eigen/src/Core/DenseCoeffsBase.h +511 -624
  61. package/eigen/Eigen/src/Core/DenseStorage.h +512 -509
  62. package/eigen/Eigen/src/Core/DeviceWrapper.h +153 -0
  63. package/eigen/Eigen/src/Core/Diagonal.h +169 -210
  64. package/eigen/Eigen/src/Core/DiagonalMatrix.h +351 -274
  65. package/eigen/Eigen/src/Core/DiagonalProduct.h +12 -10
  66. package/eigen/Eigen/src/Core/Dot.h +172 -222
  67. package/eigen/Eigen/src/Core/EigenBase.h +75 -85
  68. package/eigen/Eigen/src/Core/Fill.h +138 -0
  69. package/eigen/Eigen/src/Core/FindCoeff.h +464 -0
  70. package/eigen/Eigen/src/Core/ForceAlignedAccess.h +90 -109
  71. package/eigen/Eigen/src/Core/Fuzzy.h +82 -105
  72. package/eigen/Eigen/src/Core/GeneralProduct.h +327 -263
  73. package/eigen/Eigen/src/Core/GenericPacketMath.h +1472 -360
  74. package/eigen/Eigen/src/Core/GlobalFunctions.h +194 -151
  75. package/eigen/Eigen/src/Core/IO.h +147 -139
  76. package/eigen/Eigen/src/Core/IndexedView.h +321 -0
  77. package/eigen/Eigen/src/Core/InnerProduct.h +260 -0
  78. package/eigen/Eigen/src/Core/InternalHeaderCheck.h +3 -0
  79. package/eigen/Eigen/src/Core/Inverse.h +56 -66
  80. package/eigen/Eigen/src/Core/Map.h +124 -142
  81. package/eigen/Eigen/src/Core/MapBase.h +256 -281
  82. package/eigen/Eigen/src/Core/MathFunctions.h +1620 -938
  83. package/eigen/Eigen/src/Core/MathFunctionsImpl.h +233 -71
  84. package/eigen/Eigen/src/Core/Matrix.h +491 -416
  85. package/eigen/Eigen/src/Core/MatrixBase.h +468 -453
  86. package/eigen/Eigen/src/Core/NestByValue.h +66 -85
  87. package/eigen/Eigen/src/Core/NoAlias.h +79 -85
  88. package/eigen/Eigen/src/Core/NumTraits.h +235 -148
  89. package/eigen/Eigen/src/Core/PartialReduxEvaluator.h +253 -0
  90. package/eigen/Eigen/src/Core/PermutationMatrix.h +461 -511
  91. package/eigen/Eigen/src/Core/PlainObjectBase.h +871 -894
  92. package/eigen/Eigen/src/Core/Product.h +260 -139
  93. package/eigen/Eigen/src/Core/ProductEvaluators.h +863 -714
  94. package/eigen/Eigen/src/Core/Random.h +161 -136
  95. package/eigen/Eigen/src/Core/RandomImpl.h +262 -0
  96. package/eigen/Eigen/src/Core/RealView.h +250 -0
  97. package/eigen/Eigen/src/Core/Redux.h +366 -336
  98. package/eigen/Eigen/src/Core/Ref.h +308 -209
  99. package/eigen/Eigen/src/Core/Replicate.h +94 -106
  100. package/eigen/Eigen/src/Core/Reshaped.h +398 -0
  101. package/eigen/Eigen/src/Core/ReturnByValue.h +49 -55
  102. package/eigen/Eigen/src/Core/Reverse.h +136 -145
  103. package/eigen/Eigen/src/Core/Select.h +70 -140
  104. package/eigen/Eigen/src/Core/SelfAdjointView.h +262 -285
  105. package/eigen/Eigen/src/Core/SelfCwiseBinaryOp.h +23 -20
  106. package/eigen/Eigen/src/Core/SkewSymmetricMatrix3.h +382 -0
  107. package/eigen/Eigen/src/Core/Solve.h +97 -111
  108. package/eigen/Eigen/src/Core/SolveTriangular.h +131 -129
  109. package/eigen/Eigen/src/Core/SolverBase.h +138 -101
  110. package/eigen/Eigen/src/Core/StableNorm.h +156 -160
  111. package/eigen/Eigen/src/Core/StlIterators.h +619 -0
  112. package/eigen/Eigen/src/Core/Stride.h +91 -88
  113. package/eigen/Eigen/src/Core/Swap.h +70 -38
  114. package/eigen/Eigen/src/Core/Transpose.h +295 -273
  115. package/eigen/Eigen/src/Core/Transpositions.h +272 -317
  116. package/eigen/Eigen/src/Core/TriangularMatrix.h +670 -755
  117. package/eigen/Eigen/src/Core/VectorBlock.h +59 -72
  118. package/eigen/Eigen/src/Core/VectorwiseOp.h +668 -630
  119. package/eigen/Eigen/src/Core/Visitor.h +480 -216
  120. package/eigen/Eigen/src/Core/arch/AVX/Complex.h +407 -293
  121. package/eigen/Eigen/src/Core/arch/AVX/MathFunctions.h +79 -388
  122. package/eigen/Eigen/src/Core/arch/AVX/PacketMath.h +2935 -491
  123. package/eigen/Eigen/src/Core/arch/AVX/Reductions.h +353 -0
  124. package/eigen/Eigen/src/Core/arch/AVX/TypeCasting.h +279 -22
  125. package/eigen/Eigen/src/Core/arch/AVX512/Complex.h +472 -0
  126. package/eigen/Eigen/src/Core/arch/AVX512/GemmKernel.h +1245 -0
  127. package/eigen/Eigen/src/Core/arch/AVX512/MathFunctions.h +85 -333
  128. package/eigen/Eigen/src/Core/arch/AVX512/MathFunctionsFP16.h +75 -0
  129. package/eigen/Eigen/src/Core/arch/AVX512/PacketMath.h +2490 -649
  130. package/eigen/Eigen/src/Core/arch/AVX512/PacketMathFP16.h +1413 -0
  131. package/eigen/Eigen/src/Core/arch/AVX512/Reductions.h +297 -0
  132. package/eigen/Eigen/src/Core/arch/AVX512/TrsmKernel.h +1167 -0
  133. package/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc +1219 -0
  134. package/eigen/Eigen/src/Core/arch/AVX512/TypeCasting.h +277 -0
  135. package/eigen/Eigen/src/Core/arch/AVX512/TypeCastingFP16.h +130 -0
  136. package/eigen/Eigen/src/Core/arch/AltiVec/Complex.h +521 -298
  137. package/eigen/Eigen/src/Core/arch/AltiVec/MathFunctions.h +39 -280
  138. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProduct.h +3686 -0
  139. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h +205 -0
  140. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h +901 -0
  141. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h +742 -0
  142. package/eigen/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.inc +2818 -0
  143. package/eigen/Eigen/src/Core/arch/AltiVec/PacketMath.h +3391 -723
  144. package/eigen/Eigen/src/Core/arch/AltiVec/TypeCasting.h +153 -0
  145. package/eigen/Eigen/src/Core/arch/Default/BFloat16.h +866 -0
  146. package/eigen/Eigen/src/Core/arch/Default/ConjHelper.h +113 -14
  147. package/eigen/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +2634 -0
  148. package/eigen/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h +227 -0
  149. package/eigen/Eigen/src/Core/arch/Default/Half.h +1091 -0
  150. package/eigen/Eigen/src/Core/arch/Default/Settings.h +11 -13
  151. package/eigen/Eigen/src/Core/arch/GPU/Complex.h +244 -0
  152. package/eigen/Eigen/src/Core/arch/GPU/MathFunctions.h +104 -0
  153. package/eigen/Eigen/src/Core/arch/GPU/PacketMath.h +1712 -0
  154. package/eigen/Eigen/src/Core/arch/GPU/Tuple.h +268 -0
  155. package/eigen/Eigen/src/Core/arch/GPU/TypeCasting.h +77 -0
  156. package/eigen/Eigen/src/Core/arch/HIP/hcc/math_constants.h +23 -0
  157. package/eigen/Eigen/src/Core/arch/HVX/PacketMath.h +1088 -0
  158. package/eigen/Eigen/src/Core/arch/LSX/Complex.h +520 -0
  159. package/eigen/Eigen/src/Core/arch/LSX/GeneralBlockPanelKernel.h +23 -0
  160. package/eigen/Eigen/src/Core/arch/LSX/MathFunctions.h +43 -0
  161. package/eigen/Eigen/src/Core/arch/LSX/PacketMath.h +2866 -0
  162. package/eigen/Eigen/src/Core/arch/LSX/TypeCasting.h +526 -0
  163. package/eigen/Eigen/src/Core/arch/MSA/Complex.h +620 -0
  164. package/eigen/Eigen/src/Core/arch/MSA/MathFunctions.h +379 -0
  165. package/eigen/Eigen/src/Core/arch/MSA/PacketMath.h +1237 -0
  166. package/eigen/Eigen/src/Core/arch/NEON/Complex.h +531 -289
  167. package/eigen/Eigen/src/Core/arch/NEON/GeneralBlockPanelKernel.h +243 -0
  168. package/eigen/Eigen/src/Core/arch/NEON/MathFunctions.h +50 -73
  169. package/eigen/Eigen/src/Core/arch/NEON/PacketMath.h +5915 -579
  170. package/eigen/Eigen/src/Core/arch/NEON/TypeCasting.h +1642 -0
  171. package/eigen/Eigen/src/Core/arch/NEON/UnaryFunctors.h +57 -0
  172. package/eigen/Eigen/src/Core/arch/SSE/Complex.h +366 -334
  173. package/eigen/Eigen/src/Core/arch/SSE/MathFunctions.h +40 -514
  174. package/eigen/Eigen/src/Core/arch/SSE/PacketMath.h +2164 -675
  175. package/eigen/Eigen/src/Core/arch/SSE/Reductions.h +324 -0
  176. package/eigen/Eigen/src/Core/arch/SSE/TypeCasting.h +188 -35
  177. package/eigen/Eigen/src/Core/arch/SVE/MathFunctions.h +48 -0
  178. package/eigen/Eigen/src/Core/arch/SVE/PacketMath.h +674 -0
  179. package/eigen/Eigen/src/Core/arch/SVE/TypeCasting.h +52 -0
  180. package/eigen/Eigen/src/Core/arch/SYCL/InteropHeaders.h +227 -0
  181. package/eigen/Eigen/src/Core/arch/SYCL/MathFunctions.h +303 -0
  182. package/eigen/Eigen/src/Core/arch/SYCL/PacketMath.h +576 -0
  183. package/eigen/Eigen/src/Core/arch/SYCL/TypeCasting.h +83 -0
  184. package/eigen/Eigen/src/Core/arch/ZVector/Complex.h +434 -261
  185. package/eigen/Eigen/src/Core/arch/ZVector/MathFunctions.h +160 -53
  186. package/eigen/Eigen/src/Core/arch/ZVector/PacketMath.h +1073 -605
  187. package/eigen/Eigen/src/Core/functors/AssignmentFunctors.h +123 -117
  188. package/eigen/Eigen/src/Core/functors/BinaryFunctors.h +594 -322
  189. package/eigen/Eigen/src/Core/functors/NullaryFunctors.h +204 -118
  190. package/eigen/Eigen/src/Core/functors/StlFunctors.h +110 -97
  191. package/eigen/Eigen/src/Core/functors/TernaryFunctors.h +34 -7
  192. package/eigen/Eigen/src/Core/functors/UnaryFunctors.h +1158 -530
  193. package/eigen/Eigen/src/Core/products/GeneralBlockPanelKernel.h +2329 -1333
  194. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrix.h +328 -364
  195. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h +191 -178
  196. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h +85 -82
  197. package/eigen/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h +154 -73
  198. package/eigen/Eigen/src/Core/products/GeneralMatrixVector.h +396 -542
  199. package/eigen/Eigen/src/Core/products/GeneralMatrixVector_BLAS.h +80 -77
  200. package/eigen/Eigen/src/Core/products/Parallelizer.h +208 -92
  201. package/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix.h +331 -375
  202. package/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix_BLAS.h +206 -224
  203. package/eigen/Eigen/src/Core/products/SelfadjointMatrixVector.h +139 -146
  204. package/eigen/Eigen/src/Core/products/SelfadjointMatrixVector_BLAS.h +58 -61
  205. package/eigen/Eigen/src/Core/products/SelfadjointProduct.h +71 -71
  206. package/eigen/Eigen/src/Core/products/SelfadjointRank2Update.h +48 -46
  207. package/eigen/Eigen/src/Core/products/TriangularMatrixMatrix.h +294 -369
  208. package/eigen/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h +246 -238
  209. package/eigen/Eigen/src/Core/products/TriangularMatrixVector.h +244 -247
  210. package/eigen/Eigen/src/Core/products/TriangularMatrixVector_BLAS.h +212 -192
  211. package/eigen/Eigen/src/Core/products/TriangularSolverMatrix.h +328 -275
  212. package/eigen/Eigen/src/Core/products/TriangularSolverMatrix_BLAS.h +108 -109
  213. package/eigen/Eigen/src/Core/products/TriangularSolverVector.h +70 -93
  214. package/eigen/Eigen/src/Core/util/Assert.h +158 -0
  215. package/eigen/Eigen/src/Core/util/BlasUtil.h +413 -290
  216. package/eigen/Eigen/src/Core/util/ConfigureVectorization.h +543 -0
  217. package/eigen/Eigen/src/Core/util/Constants.h +314 -263
  218. package/eigen/Eigen/src/Core/util/DisableStupidWarnings.h +130 -78
  219. package/eigen/Eigen/src/Core/util/EmulateArray.h +270 -0
  220. package/eigen/Eigen/src/Core/util/ForwardDeclarations.h +450 -224
  221. package/eigen/Eigen/src/Core/util/GpuHipCudaDefines.inc +101 -0
  222. package/eigen/Eigen/src/Core/util/GpuHipCudaUndefines.inc +45 -0
  223. package/eigen/Eigen/src/Core/util/IndexedViewHelper.h +487 -0
  224. package/eigen/Eigen/src/Core/util/IntegralConstant.h +279 -0
  225. package/eigen/Eigen/src/Core/util/MKL_support.h +39 -30
  226. package/eigen/Eigen/src/Core/util/Macros.h +939 -646
  227. package/eigen/Eigen/src/Core/util/MaxSizeVector.h +139 -0
  228. package/eigen/Eigen/src/Core/util/Memory.h +1042 -650
  229. package/eigen/Eigen/src/Core/util/Meta.h +618 -426
  230. package/eigen/Eigen/src/Core/util/MoreMeta.h +638 -0
  231. package/eigen/Eigen/src/Core/util/ReenableStupidWarnings.h +32 -19
  232. package/eigen/Eigen/src/Core/util/ReshapedHelper.h +51 -0
  233. package/eigen/Eigen/src/Core/util/Serializer.h +209 -0
  234. package/eigen/Eigen/src/Core/util/StaticAssert.h +51 -164
  235. package/eigen/Eigen/src/Core/util/SymbolicIndex.h +445 -0
  236. package/eigen/Eigen/src/Core/util/XprHelper.h +793 -538
  237. package/eigen/Eigen/src/Eigenvalues/ComplexEigenSolver.h +246 -277
  238. package/eigen/Eigen/src/Eigenvalues/ComplexSchur.h +299 -319
  239. package/eigen/Eigen/src/Eigenvalues/ComplexSchur_LAPACKE.h +52 -48
  240. package/eigen/Eigen/src/Eigenvalues/EigenSolver.h +413 -456
  241. package/eigen/Eigen/src/Eigenvalues/GeneralizedEigenSolver.h +309 -325
  242. package/eigen/Eigen/src/Eigenvalues/GeneralizedSelfAdjointEigenSolver.h +157 -171
  243. package/eigen/Eigen/src/Eigenvalues/HessenbergDecomposition.h +292 -310
  244. package/eigen/Eigen/src/Eigenvalues/InternalHeaderCheck.h +3 -0
  245. package/eigen/Eigen/src/Eigenvalues/MatrixBaseEigenvalues.h +91 -107
  246. package/eigen/Eigen/src/Eigenvalues/RealQZ.h +539 -606
  247. package/eigen/Eigen/src/Eigenvalues/RealSchur.h +348 -382
  248. package/eigen/Eigen/src/Eigenvalues/RealSchur_LAPACKE.h +41 -35
  249. package/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver.h +579 -600
  250. package/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver_LAPACKE.h +47 -44
  251. package/eigen/Eigen/src/Eigenvalues/Tridiagonalization.h +434 -461
  252. package/eigen/Eigen/src/Geometry/AlignedBox.h +307 -214
  253. package/eigen/Eigen/src/Geometry/AngleAxis.h +135 -137
  254. package/eigen/Eigen/src/Geometry/EulerAngles.h +163 -74
  255. package/eigen/Eigen/src/Geometry/Homogeneous.h +289 -333
  256. package/eigen/Eigen/src/Geometry/Hyperplane.h +152 -161
  257. package/eigen/Eigen/src/Geometry/InternalHeaderCheck.h +3 -0
  258. package/eigen/Eigen/src/Geometry/OrthoMethods.h +168 -145
  259. package/eigen/Eigen/src/Geometry/ParametrizedLine.h +141 -104
  260. package/eigen/Eigen/src/Geometry/Quaternion.h +595 -497
  261. package/eigen/Eigen/src/Geometry/Rotation2D.h +110 -108
  262. package/eigen/Eigen/src/Geometry/RotationBase.h +148 -145
  263. package/eigen/Eigen/src/Geometry/Scaling.h +115 -90
  264. package/eigen/Eigen/src/Geometry/Transform.h +896 -953
  265. package/eigen/Eigen/src/Geometry/Translation.h +100 -98
  266. package/eigen/Eigen/src/Geometry/Umeyama.h +79 -84
  267. package/eigen/Eigen/src/Geometry/arch/Geometry_SIMD.h +154 -0
  268. package/eigen/Eigen/src/Householder/BlockHouseholder.h +54 -42
  269. package/eigen/Eigen/src/Householder/Householder.h +104 -122
  270. package/eigen/Eigen/src/Householder/HouseholderSequence.h +416 -382
  271. package/eigen/Eigen/src/Householder/InternalHeaderCheck.h +3 -0
  272. package/eigen/Eigen/src/IterativeLinearSolvers/BasicPreconditioners.h +153 -166
  273. package/eigen/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h +127 -138
  274. package/eigen/Eigen/src/IterativeLinearSolvers/ConjugateGradient.h +95 -124
  275. package/eigen/Eigen/src/IterativeLinearSolvers/IncompleteCholesky.h +269 -267
  276. package/eigen/Eigen/src/IterativeLinearSolvers/IncompleteLUT.h +246 -259
  277. package/eigen/Eigen/src/IterativeLinearSolvers/InternalHeaderCheck.h +3 -0
  278. package/eigen/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h +218 -217
  279. package/eigen/Eigen/src/IterativeLinearSolvers/LeastSquareConjugateGradient.h +80 -103
  280. package/eigen/Eigen/src/IterativeLinearSolvers/SolveWithGuess.h +59 -63
  281. package/eigen/Eigen/src/Jacobi/InternalHeaderCheck.h +3 -0
  282. package/eigen/Eigen/src/Jacobi/Jacobi.h +256 -291
  283. package/eigen/Eigen/src/KLUSupport/InternalHeaderCheck.h +3 -0
  284. package/eigen/Eigen/src/KLUSupport/KLUSupport.h +339 -0
  285. package/eigen/Eigen/src/LU/Determinant.h +60 -63
  286. package/eigen/Eigen/src/LU/FullPivLU.h +561 -626
  287. package/eigen/Eigen/src/LU/InternalHeaderCheck.h +3 -0
  288. package/eigen/Eigen/src/LU/InverseImpl.h +213 -275
  289. package/eigen/Eigen/src/LU/PartialPivLU.h +407 -435
  290. package/eigen/Eigen/src/LU/PartialPivLU_LAPACKE.h +54 -40
  291. package/eigen/Eigen/src/LU/arch/InverseSize4.h +353 -0
  292. package/eigen/Eigen/src/MetisSupport/InternalHeaderCheck.h +3 -0
  293. package/eigen/Eigen/src/MetisSupport/MetisSupport.h +81 -93
  294. package/eigen/Eigen/src/OrderingMethods/Amd.h +250 -282
  295. package/eigen/Eigen/src/OrderingMethods/Eigen_Colamd.h +950 -1103
  296. package/eigen/Eigen/src/OrderingMethods/InternalHeaderCheck.h +3 -0
  297. package/eigen/Eigen/src/OrderingMethods/Ordering.h +111 -122
  298. package/eigen/Eigen/src/PaStiXSupport/InternalHeaderCheck.h +3 -0
  299. package/eigen/Eigen/src/PaStiXSupport/PaStiXSupport.h +524 -570
  300. package/eigen/Eigen/src/PardisoSupport/InternalHeaderCheck.h +3 -0
  301. package/eigen/Eigen/src/PardisoSupport/PardisoSupport.h +385 -429
  302. package/eigen/Eigen/src/QR/ColPivHouseholderQR.h +494 -473
  303. package/eigen/Eigen/src/QR/ColPivHouseholderQR_LAPACKE.h +120 -56
  304. package/eigen/Eigen/src/QR/CompleteOrthogonalDecomposition.h +223 -137
  305. package/eigen/Eigen/src/QR/FullPivHouseholderQR.h +517 -460
  306. package/eigen/Eigen/src/QR/HouseholderQR.h +412 -278
  307. package/eigen/Eigen/src/QR/HouseholderQR_LAPACKE.h +32 -23
  308. package/eigen/Eigen/src/QR/InternalHeaderCheck.h +3 -0
  309. package/eigen/Eigen/src/SPQRSupport/InternalHeaderCheck.h +3 -0
  310. package/eigen/Eigen/src/SPQRSupport/SuiteSparseQRSupport.h +263 -261
  311. package/eigen/Eigen/src/SVD/BDCSVD.h +872 -679
  312. package/eigen/Eigen/src/SVD/BDCSVD_LAPACKE.h +174 -0
  313. package/eigen/Eigen/src/SVD/InternalHeaderCheck.h +3 -0
  314. package/eigen/Eigen/src/SVD/JacobiSVD.h +585 -543
  315. package/eigen/Eigen/src/SVD/JacobiSVD_LAPACKE.h +85 -49
  316. package/eigen/Eigen/src/SVD/SVDBase.h +281 -160
  317. package/eigen/Eigen/src/SVD/UpperBidiagonalization.h +202 -237
  318. package/eigen/Eigen/src/SparseCholesky/InternalHeaderCheck.h +3 -0
  319. package/eigen/Eigen/src/SparseCholesky/SimplicialCholesky.h +769 -590
  320. package/eigen/Eigen/src/SparseCholesky/SimplicialCholesky_impl.h +318 -129
  321. package/eigen/Eigen/src/SparseCore/AmbiVector.h +202 -251
  322. package/eigen/Eigen/src/SparseCore/CompressedStorage.h +184 -236
  323. package/eigen/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h +140 -184
  324. package/eigen/Eigen/src/SparseCore/InternalHeaderCheck.h +3 -0
  325. package/eigen/Eigen/src/SparseCore/SparseAssign.h +174 -111
  326. package/eigen/Eigen/src/SparseCore/SparseBlock.h +408 -477
  327. package/eigen/Eigen/src/SparseCore/SparseColEtree.h +100 -112
  328. package/eigen/Eigen/src/SparseCore/SparseCompressedBase.h +531 -280
  329. package/eigen/Eigen/src/SparseCore/SparseCwiseBinaryOp.h +559 -347
  330. package/eigen/Eigen/src/SparseCore/SparseCwiseUnaryOp.h +100 -108
  331. package/eigen/Eigen/src/SparseCore/SparseDenseProduct.h +185 -191
  332. package/eigen/Eigen/src/SparseCore/SparseDiagonalProduct.h +71 -71
  333. package/eigen/Eigen/src/SparseCore/SparseDot.h +49 -47
  334. package/eigen/Eigen/src/SparseCore/SparseFuzzy.h +13 -11
  335. package/eigen/Eigen/src/SparseCore/SparseMap.h +243 -253
  336. package/eigen/Eigen/src/SparseCore/SparseMatrix.h +1614 -1142
  337. package/eigen/Eigen/src/SparseCore/SparseMatrixBase.h +403 -357
  338. package/eigen/Eigen/src/SparseCore/SparsePermutation.h +186 -115
  339. package/eigen/Eigen/src/SparseCore/SparseProduct.h +100 -91
  340. package/eigen/Eigen/src/SparseCore/SparseRedux.h +22 -24
  341. package/eigen/Eigen/src/SparseCore/SparseRef.h +268 -295
  342. package/eigen/Eigen/src/SparseCore/SparseSelfAdjointView.h +371 -414
  343. package/eigen/Eigen/src/SparseCore/SparseSolverBase.h +78 -87
  344. package/eigen/Eigen/src/SparseCore/SparseSparseProductWithPruning.h +81 -95
  345. package/eigen/Eigen/src/SparseCore/SparseTranspose.h +62 -71
  346. package/eigen/Eigen/src/SparseCore/SparseTriangularView.h +132 -144
  347. package/eigen/Eigen/src/SparseCore/SparseUtil.h +146 -115
  348. package/eigen/Eigen/src/SparseCore/SparseVector.h +426 -372
  349. package/eigen/Eigen/src/SparseCore/SparseView.h +164 -193
  350. package/eigen/Eigen/src/SparseCore/TriangularSolver.h +129 -170
  351. package/eigen/Eigen/src/SparseLU/InternalHeaderCheck.h +3 -0
  352. package/eigen/Eigen/src/SparseLU/SparseLU.h +814 -618
  353. package/eigen/Eigen/src/SparseLU/SparseLUImpl.h +61 -48
  354. package/eigen/Eigen/src/SparseLU/SparseLU_Memory.h +102 -118
  355. package/eigen/Eigen/src/SparseLU/SparseLU_Structs.h +38 -35
  356. package/eigen/Eigen/src/SparseLU/SparseLU_SupernodalMatrix.h +273 -255
  357. package/eigen/Eigen/src/SparseLU/SparseLU_Utils.h +44 -49
  358. package/eigen/Eigen/src/SparseLU/SparseLU_column_bmod.h +104 -108
  359. package/eigen/Eigen/src/SparseLU/SparseLU_column_dfs.h +90 -101
  360. package/eigen/Eigen/src/SparseLU/SparseLU_copy_to_ucol.h +57 -58
  361. package/eigen/Eigen/src/SparseLU/SparseLU_heap_relax_snode.h +43 -55
  362. package/eigen/Eigen/src/SparseLU/SparseLU_kernel_bmod.h +74 -71
  363. package/eigen/Eigen/src/SparseLU/SparseLU_panel_bmod.h +125 -133
  364. package/eigen/Eigen/src/SparseLU/SparseLU_panel_dfs.h +136 -159
  365. package/eigen/Eigen/src/SparseLU/SparseLU_pivotL.h +51 -52
  366. package/eigen/Eigen/src/SparseLU/SparseLU_pruneL.h +67 -73
  367. package/eigen/Eigen/src/SparseLU/SparseLU_relax_snode.h +24 -26
  368. package/eigen/Eigen/src/SparseQR/InternalHeaderCheck.h +3 -0
  369. package/eigen/Eigen/src/SparseQR/SparseQR.h +451 -490
  370. package/eigen/Eigen/src/StlSupport/StdDeque.h +28 -105
  371. package/eigen/Eigen/src/StlSupport/StdList.h +28 -84
  372. package/eigen/Eigen/src/StlSupport/StdVector.h +28 -108
  373. package/eigen/Eigen/src/StlSupport/details.h +48 -50
  374. package/eigen/Eigen/src/SuperLUSupport/InternalHeaderCheck.h +3 -0
  375. package/eigen/Eigen/src/SuperLUSupport/SuperLUSupport.h +634 -732
  376. package/eigen/Eigen/src/ThreadPool/Barrier.h +70 -0
  377. package/eigen/Eigen/src/ThreadPool/CoreThreadPoolDevice.h +336 -0
  378. package/eigen/Eigen/src/ThreadPool/EventCount.h +241 -0
  379. package/eigen/Eigen/src/ThreadPool/ForkJoin.h +140 -0
  380. package/eigen/Eigen/src/ThreadPool/InternalHeaderCheck.h +4 -0
  381. package/eigen/Eigen/src/ThreadPool/NonBlockingThreadPool.h +587 -0
  382. package/eigen/Eigen/src/ThreadPool/RunQueue.h +230 -0
  383. package/eigen/Eigen/src/ThreadPool/ThreadCancel.h +21 -0
  384. package/eigen/Eigen/src/ThreadPool/ThreadEnvironment.h +43 -0
  385. package/eigen/Eigen/src/ThreadPool/ThreadLocal.h +289 -0
  386. package/eigen/Eigen/src/ThreadPool/ThreadPoolInterface.h +50 -0
  387. package/eigen/Eigen/src/ThreadPool/ThreadYield.h +16 -0
  388. package/eigen/Eigen/src/UmfPackSupport/InternalHeaderCheck.h +3 -0
  389. package/eigen/Eigen/src/UmfPackSupport/UmfPackSupport.h +480 -380
  390. package/eigen/Eigen/src/misc/Image.h +41 -43
  391. package/eigen/Eigen/src/misc/InternalHeaderCheck.h +3 -0
  392. package/eigen/Eigen/src/misc/Kernel.h +39 -41
  393. package/eigen/Eigen/src/misc/RealSvd2x2.h +19 -21
  394. package/eigen/Eigen/src/misc/blas.h +83 -426
  395. package/eigen/Eigen/src/misc/lapacke.h +9976 -16182
  396. package/eigen/Eigen/src/misc/lapacke_helpers.h +163 -0
  397. package/eigen/Eigen/src/misc/lapacke_mangling.h +4 -5
  398. package/eigen/Eigen/src/plugins/ArrayCwiseBinaryOps.inc +344 -0
  399. package/eigen/Eigen/src/plugins/ArrayCwiseUnaryOps.inc +544 -0
  400. package/eigen/Eigen/src/plugins/BlockMethods.inc +1370 -0
  401. package/eigen/Eigen/src/plugins/CommonCwiseBinaryOps.inc +116 -0
  402. package/eigen/Eigen/src/plugins/CommonCwiseUnaryOps.inc +167 -0
  403. package/eigen/Eigen/src/plugins/IndexedViewMethods.inc +192 -0
  404. package/eigen/Eigen/src/plugins/InternalHeaderCheck.inc +3 -0
  405. package/eigen/Eigen/src/plugins/MatrixCwiseBinaryOps.inc +331 -0
  406. package/eigen/Eigen/src/plugins/MatrixCwiseUnaryOps.inc +118 -0
  407. package/eigen/Eigen/src/plugins/ReshapedMethods.inc +133 -0
  408. package/lib/LibEigen.d.ts +4 -0
  409. package/lib/LibEigen.js +14 -0
  410. package/lib/index.d.ts +1 -1
  411. package/lib/index.js +7 -3
  412. package/package.json +2 -10
  413. package/eigen/Eigen/CMakeLists.txt +0 -19
  414. package/eigen/Eigen/src/Core/BooleanRedux.h +0 -164
  415. package/eigen/Eigen/src/Core/arch/CUDA/Complex.h +0 -103
  416. package/eigen/Eigen/src/Core/arch/CUDA/Half.h +0 -675
  417. package/eigen/Eigen/src/Core/arch/CUDA/MathFunctions.h +0 -91
  418. package/eigen/Eigen/src/Core/arch/CUDA/PacketMath.h +0 -333
  419. package/eigen/Eigen/src/Core/arch/CUDA/PacketMathHalf.h +0 -1124
  420. package/eigen/Eigen/src/Core/arch/CUDA/TypeCasting.h +0 -212
  421. package/eigen/Eigen/src/Core/util/NonMPL2.h +0 -3
  422. package/eigen/Eigen/src/Geometry/arch/Geometry_SSE.h +0 -161
  423. package/eigen/Eigen/src/LU/arch/Inverse_SSE.h +0 -338
  424. package/eigen/Eigen/src/SparseCore/MappedSparseMatrix.h +0 -67
  425. package/eigen/Eigen/src/SparseLU/SparseLU_gemm_kernel.h +0 -280
  426. package/eigen/Eigen/src/misc/lapack.h +0 -152
  427. package/eigen/Eigen/src/plugins/ArrayCwiseBinaryOps.h +0 -332
  428. package/eigen/Eigen/src/plugins/ArrayCwiseUnaryOps.h +0 -552
  429. package/eigen/Eigen/src/plugins/BlockMethods.h +0 -1058
  430. package/eigen/Eigen/src/plugins/CommonCwiseBinaryOps.h +0 -115
  431. package/eigen/Eigen/src/plugins/CommonCwiseUnaryOps.h +0 -163
  432. package/eigen/Eigen/src/plugins/MatrixCwiseBinaryOps.h +0 -152
  433. package/eigen/Eigen/src/plugins/MatrixCwiseUnaryOps.h +0 -85
  434. package/lib/eigen.d.ts +0 -2
  435. package/lib/eigen.js +0 -15
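Most of the churn in this release appears to come from upgrading the vendored Eigen sources (note the new package/eigen/Eigen/Version file, item 32). The single largest addition is item 142, package/eigen/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.inc (+2818 lines): new PowerPC matrix-vector product (GEMV) kernels using VSX and, where available, Power10 MMA. The hunk below shows the first part of that file. The kernels implement the column-major update res += alpha * lhs * rhs; nothing changes at the API level. As a minimal sketch (standard Eigen dense API; the matrix sizes are arbitrary), this is the kind of user expression that dispatches to these kernels:

    #include <Eigen/Dense>

    int main() {
      Eigen::MatrixXf A = Eigen::MatrixXf::Random(512, 512);
      Eigen::VectorXf x = Eigen::VectorXf::Random(512);
      Eigen::VectorXf y = Eigen::VectorXf::Zero(512);
      float alpha = 2.0f;
      // A matrix-vector product with no aliasing; Eigen routes this to its
      // internal GEMV kernels, which on PowerPC are the ones added below.
      y.noalias() += alpha * (A * x);
      return 0;
    }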
package/eigen/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.inc (new file, shown truncated)
@@ -0,0 +1,2818 @@
+ // This file is part of Eigen, a lightweight C++ template library
+ // for linear algebra.
+ //
+ // Copyright (C) 2021 Chip Kerchner (chip.kerchner@ibm.com)
+ //
+ // This Source Code Form is subject to the terms of the Mozilla
+ // Public License v. 2.0. If a copy of the MPL was not distributed
+ // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+ #ifndef EIGEN_MATRIX_VECTOR_PRODUCT_ALTIVEC_H
+ #define EIGEN_MATRIX_VECTOR_PRODUCT_ALTIVEC_H
+
+ // IWYU pragma: private
+ #include "../../InternalHeaderCheck.h"
+
+ #if defined(__MMA__) && !EIGEN_ALTIVEC_DISABLE_MMA
+ #if EIGEN_COMP_LLVM || (__GNUC__ > 10 || __GNUC_MINOR__ >= 3)
+ #define USE_GEMV_MMA
+ #endif
+
+ #if !EIGEN_COMP_LLVM && (__GNUC__ < 11)
+ // Only allow one vector_pair in buggy gcc - gcc 10.x has a bug
+ #define GCC_ONE_VECTORPAIR_BUG
+ #endif
+ #endif
+
+ // #define USE_SLOWER_GEMV_MMA // MMA is currently not as fast as VSX in complex double GEMV (revisit when gcc is
+ // improved)
+
+ // #define EIGEN_POWER_USE_GEMV_PREFETCH
+ #ifdef EIGEN_POWER_USE_GEMV_PREFETCH
+ #define EIGEN_POWER_GEMV_PREFETCH(p) prefetch(p)
+ #else
+ #define EIGEN_POWER_GEMV_PREFETCH(p)
+ #endif
+
+ #ifdef __has_builtin
+ #if !__has_builtin(__builtin_vsx_assemble_pair)
+ #define __builtin_vsx_assemble_pair __builtin_mma_assemble_pair
+ #endif
+ #if !__has_builtin(__builtin_vsx_disassemble_pair)
+ #define __builtin_vsx_disassemble_pair __builtin_mma_disassemble_pair
+ #endif
+ #endif
+
+ #if EIGEN_COMP_LLVM
+ #define GEMV_BUILDPAIR_MMA(dst, src1, src2) \
+   __builtin_vsx_assemble_pair(&dst, (__vector unsigned char)src2, (__vector unsigned char)src1)
+ #else
+ #if (__GNUC__ <= 10)
+ #if (__GNUC_MINOR__ > 3)
+ #define GEMV_BUILDPAIR_MMA(dst, src1, src2) \
+   __builtin_vsx_assemble_pair(&dst, (__vector unsigned char)src2, (__vector unsigned char)src1)
+ #else
+ #define GEMV_BUILDPAIR_MMA(dst, src1, src2) \
+   __builtin_vsx_assemble_pair(&dst, (__vector unsigned char)src1, (__vector unsigned char)src2)
+ #endif
+ #else
+ #define GEMV_BUILDPAIR_MMA(dst, src1, src2) \
+   __builtin_vsx_build_pair(&dst, (__vector unsigned char)src1, (__vector unsigned char)src2)
+ #endif
+ #endif
+
+ #define GEMV_IS_COMPLEX_COMPLEX ((sizeof(LhsPacket) == 16) && (sizeof(RhsPacket) == 16))
+ #define GEMV_IS_FLOAT (ResPacketSize == (16 / sizeof(float)))
+ #define GEMV_IS_SCALAR (sizeof(ResPacket) != 16)
+ #define GEMV_IS_COMPLEX_FLOAT (ResPacketSize == (16 / sizeof(std::complex<float>)))
+
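These predicates discriminate instantiations by packet width on the 16-byte VSX registers: 16 / sizeof(float) = 4 lanes for a float packet, 16 / sizeof(std::complex<float>) = 2 lanes for a complex-float packet, and GEMV_IS_SCALAR appears to flag result types that do not fill a 16-byte packet at all and therefore take the scalar path.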
+ /** \internal multiply and add and store results */
+ template <typename ResPacket, typename ResScalar>
+ EIGEN_ALWAYS_INLINE void storeMaddData(ResScalar* res, ResPacket& palpha, ResPacket& data) {
+   pstoreu(res, pmadd(data, palpha, ploadu<ResPacket>(res)));
+ }
+
+ template <typename ResScalar>
+ EIGEN_ALWAYS_INLINE void storeMaddData(ResScalar* res, ResScalar& alpha, ResScalar& data) {
+   *res += (alpha * data);
+ }
+
+ #define GEMV_UNROLL(func, N) func(0, N) func(1, N) func(2, N) func(3, N) func(4, N) func(5, N) func(6, N) func(7, N)
+
+ #define GEMV_UNROLL_HALF(func, N) func(0, 0, 1, N) func(1, 2, 3, N) func(2, 4, 5, N) func(3, 6, 7, N)
+
+ #define GEMV_GETN(N) (((N) * ResPacketSize) >> 2)
+
+ #define GEMV_LOADPACKET_COL(iter) lhs.template load<LhsPacket, LhsAlignment>(i + ((iter) * LhsPacketSize), j)
+
+ #ifdef USE_GEMV_MMA
+ #define GEMV_UNROLL3(func, N, which) \
+   func(0, N, which) func(1, N, which) func(2, N, which) func(3, N, which) func(4, N, which) func(5, N, which) \
+   func(6, N, which) func(7, N, which)
+
+ #define GEMV_UNUSED_VAR(iter, N, which) \
+   if (GEMV_GETN(N) <= iter) { \
+     EIGEN_UNUSED_VARIABLE(which##iter); \
+   }
+
+ #define GEMV_UNUSED_EXTRA_VAR(iter, N, which) \
+   if (N <= iter) { \
+     EIGEN_UNUSED_VARIABLE(which##iter); \
+   }
+
+ #define GEMV_UNUSED_EXTRA(N, which) GEMV_UNROLL3(GEMV_UNUSED_EXTRA_VAR, N, which)
+
+ #define GEMV_UNUSED(N, which) GEMV_UNROLL3(GEMV_UNUSED_VAR, N, which)
+
+ #define GEMV_INIT_MMA(iter, N) \
+   if (GEMV_GETN(N) > iter) { \
+     __builtin_mma_xxsetaccz(&e##iter); \
+   }
+
+ #if EIGEN_COMP_LLVM
+ #define GEMV_LOADPAIR_COL_MMA(iter1, iter2) \
+   GEMV_BUILDPAIR_MMA(b##iter1, GEMV_LOADPACKET_COL(iter2), GEMV_LOADPACKET_COL((iter2) + 1));
+ #else
+ #define GEMV_LOADPAIR_COL_MMA(iter1, iter2) \
+   const LhsScalar& src##iter1 = lhs(i + ((iter1 * 32) / sizeof(LhsScalar)), j); \
+   b##iter1 = *reinterpret_cast<__vector_pair*>(const_cast<LhsScalar*>(&src##iter1));
+ #endif
+
+ #define GEMV_LOAD1A_COL_MMA(iter, N) \
+   if (GEMV_GETN(N) > iter) { \
+     if (GEMV_IS_FLOAT) { \
+       g##iter = GEMV_LOADPACKET_COL(iter); \
+       EIGEN_UNUSED_VARIABLE(b##iter); \
+     } else { \
+       GEMV_LOADPAIR_COL_MMA(iter, iter << 1) \
+       EIGEN_UNUSED_VARIABLE(g##iter); \
+     } \
+   } else { \
+     EIGEN_UNUSED_VARIABLE(b##iter); \
+     EIGEN_UNUSED_VARIABLE(g##iter); \
+   }
+
+ #define GEMV_WORK1A_COL_MMA(iter, N) \
+   if (GEMV_GETN(N) > iter) { \
+     if (GEMV_IS_FLOAT) { \
+       pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&e##iter, a0, g##iter); \
+     } else { \
+       pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&e##iter, b##iter, a0); \
+     } \
+   }
+
+ #define GEMV_LOAD1B_COL_MMA(iter1, iter2, iter3, N) \
+   if (GEMV_GETN(N) > iter1) { \
+     if (GEMV_IS_FLOAT) { \
+       GEMV_LOADPAIR_COL_MMA(iter2, iter2) \
+       EIGEN_UNUSED_VARIABLE(b##iter3); \
+     } else { \
+       GEMV_LOADPAIR_COL_MMA(iter2, iter2 << 1) \
+       GEMV_LOADPAIR_COL_MMA(iter3, iter3 << 1) \
+     } \
+   } else { \
+     EIGEN_UNUSED_VARIABLE(b##iter2); \
+     EIGEN_UNUSED_VARIABLE(b##iter3); \
+   } \
+   EIGEN_UNUSED_VARIABLE(g##iter2); \
+   EIGEN_UNUSED_VARIABLE(g##iter3);
+
+ #define GEMV_WORK1B_COL_MMA(iter1, iter2, iter3, N) \
+   if (GEMV_GETN(N) > iter1) { \
+     if (GEMV_IS_FLOAT) { \
+       LhsPacket h[2]; \
+       __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(h), &b##iter2); \
+       pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&e##iter2, a0, h[0]); \
+       pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&e##iter3, a0, h[1]); \
+     } else { \
+       pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&e##iter2, b##iter2, a0); \
+       pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&e##iter3, b##iter3, a0); \
+     } \
+   }
+
+ #if EIGEN_COMP_LLVM
+ #define GEMV_LOAD_COL_MMA(N) \
+   if (GEMV_GETN(N) > 1) { \
+     GEMV_UNROLL_HALF(GEMV_LOAD1B_COL_MMA, (N >> 1)) \
+   } else { \
+     GEMV_UNROLL(GEMV_LOAD1A_COL_MMA, N) \
+   }
+
+ #define GEMV_WORK_COL_MMA(N) \
+   if (GEMV_GETN(N) > 1) { \
+     GEMV_UNROLL_HALF(GEMV_WORK1B_COL_MMA, (N >> 1)) \
+   } else { \
+     GEMV_UNROLL(GEMV_WORK1A_COL_MMA, N) \
+   }
+ #else
+ #define GEMV_LOAD_COL_MMA(N) GEMV_UNROLL(GEMV_LOAD1A_COL_MMA, N)
+
+ #define GEMV_WORK_COL_MMA(N) GEMV_UNROLL(GEMV_WORK1A_COL_MMA, N)
+ #endif
+
+ #define GEMV_DISASSEMBLE_MMA(iter, N) \
+   if (GEMV_GETN(N) > iter) { \
+     __builtin_mma_disassemble_acc(&result##iter.packet, &e##iter); \
+     if (!GEMV_IS_FLOAT) { \
+       result##iter.packet[0][1] = result##iter.packet[1][0]; \
+       result##iter.packet[2][1] = result##iter.packet[3][0]; \
+     } \
+   }
+
+ #define GEMV_LOADPAIR2_COL_MMA(iter1, iter2) \
+   b##iter1 = *reinterpret_cast<__vector_pair*>(res + i + ((iter2) * ResPacketSize));
+
+ #define GEMV_LOAD2_COL_MMA(iter1, iter2, iter3, N) \
+   if (GEMV_GETN(N) > iter1) { \
+     if (GEMV_IS_FLOAT) { \
+       GEMV_LOADPAIR2_COL_MMA(iter2, iter2); \
+       EIGEN_UNUSED_VARIABLE(b##iter3); \
+     } else { \
+       GEMV_LOADPAIR2_COL_MMA(iter2, iter2 << 1); \
+       GEMV_LOADPAIR2_COL_MMA(iter3, iter3 << 1); \
+     } \
+   } else { \
+     EIGEN_UNUSED_VARIABLE(b##iter2); \
+     EIGEN_UNUSED_VARIABLE(b##iter3); \
+   }
+
+ #if EIGEN_COMP_LLVM
+ #define GEMV_WORKPAIR2_COL_MMA(iter2, iter3, iter4) \
+   ResPacket f##iter2[2]; \
+   __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(f##iter2), &b##iter2); \
+   f##iter2[0] = pmadd(result##iter2.packet[0], palpha, f##iter2[0]); \
+   f##iter2[1] = pmadd(result##iter3.packet[(iter2 == iter3) ? 2 : 0], palpha, f##iter2[1]); \
+   GEMV_BUILDPAIR_MMA(b##iter2, f##iter2[0], f##iter2[1]);
+ #else
+ #define GEMV_WORKPAIR2_COL_MMA(iter2, iter3, iter4) \
+   if (GEMV_IS_FLOAT) { \
+     __asm__("xvmaddasp %0,%x1,%x3\n\txvmaddasp %L0,%x2,%x3" \
+             : "+&d"(b##iter2) \
+             : "wa"(result##iter3.packet[0]), "wa"(result##iter2.packet[0]), "wa"(palpha)); \
+   } else { \
+     __asm__("xvmaddadp %0,%x1,%x3\n\txvmaddadp %L0,%x2,%x3" \
+             : "+&d"(b##iter2) \
+             : "wa"(result##iter2.packet[2]), "wa"(result##iter2.packet[0]), "wa"(palpha)); \
+   }
+ #endif
+
+ #define GEMV_WORK2_COL_MMA(iter1, iter2, iter3, N) \
+   if (GEMV_GETN(N) > iter1) { \
+     if (GEMV_IS_FLOAT) { \
+       GEMV_WORKPAIR2_COL_MMA(iter2, iter3, iter2); \
+     } else { \
+       GEMV_WORKPAIR2_COL_MMA(iter2, iter2, iter2 << 1); \
+       GEMV_WORKPAIR2_COL_MMA(iter3, iter3, iter3 << 1); \
+     } \
+   }
+
+ #define GEMV_STOREPAIR2_COL_MMA(iter1, iter2) \
+   *reinterpret_cast<__vector_pair*>(res + i + ((iter2) * ResPacketSize)) = b##iter1;
+
+ #define GEMV_STORE_COL_MMA(iter, N) \
+   if (GEMV_GETN(N) > iter) { \
+     if (GEMV_IS_FLOAT) { \
+       storeMaddData<ResPacket, ResScalar>(res + i + (iter * ResPacketSize), palpha, result##iter.packet[0]); \
+     } else { \
+       GEMV_LOADPAIR2_COL_MMA(iter, iter << 1) \
+       GEMV_WORKPAIR2_COL_MMA(iter, iter, iter << 1) \
+       GEMV_STOREPAIR2_COL_MMA(iter, iter << 1) \
+     } \
+   }
+
+ #define GEMV_STORE2_COL_MMA(iter1, iter2, iter3, N) \
+   if (GEMV_GETN(N) > iter1) { \
+     if (GEMV_IS_FLOAT) { \
+       GEMV_STOREPAIR2_COL_MMA(iter2, iter2); \
+     } else { \
+       GEMV_STOREPAIR2_COL_MMA(iter2, iter2 << 1) \
+       GEMV_STOREPAIR2_COL_MMA(iter3, iter3 << 1) \
+     } \
+   }
+
+ #define GEMV_PROCESS_COL_ONE_MMA(N) \
+   GEMV_UNROLL(GEMV_INIT_MMA, N) \
+   Index j = j2; \
+   __vector_pair b0, b1, b2, b3, b4, b5, b6, b7; \
+   do { \
+     LhsPacket g0, g1, g2, g3, g4, g5, g6, g7; \
+     RhsPacket a0 = pset1<RhsPacket>(rhs2(j, 0)); \
+     GEMV_UNROLL(GEMV_PREFETCH, N) \
+     GEMV_LOAD_COL_MMA(N) \
+     GEMV_WORK_COL_MMA(N) \
+   } while (++j < jend); \
+   GEMV_UNROLL(GEMV_DISASSEMBLE_MMA, N) \
+   if (GEMV_GETN(N) <= 1) { \
+     GEMV_UNROLL(GEMV_STORE_COL_MMA, N) \
+   } else { \
+     GEMV_UNROLL_HALF(GEMV_LOAD2_COL_MMA, (N >> 1)) \
+     GEMV_UNROLL_HALF(GEMV_WORK2_COL_MMA, (N >> 1)) \
+     GEMV_UNROLL_HALF(GEMV_STORE2_COL_MMA, (N >> 1)) \
+   } \
+   i += (ResPacketSize * N);
+ #endif
+
+ #define GEMV_INIT(iter, N) \
+   if (N > iter) { \
+     c##iter = pset1<ResPacket>(ResScalar(0)); \
+   } else { \
+     EIGEN_UNUSED_VARIABLE(c##iter); \
+   }
+
+ #ifdef EIGEN_POWER_USE_GEMV_PREFETCH
+ #define GEMV_PREFETCH(iter, N) \
+   if (GEMV_GETN(N) > ((iter >> 1) + ((N >> 1) * (iter & 1)))) { \
+     lhs.prefetch(i + (iter * LhsPacketSize) + prefetch_dist, j); \
+   }
+ #else
+ #define GEMV_PREFETCH(iter, N)
+ #endif
+
+ #define GEMV_WORK_COL(iter, N) \
+   if (N > iter) { \
+     c##iter = pcj.pmadd(GEMV_LOADPACKET_COL(iter), a0, c##iter); \
+   }
+
+ #define GEMV_STORE_COL(iter, N) \
+   if (N > iter) { \
+     pstoreu(res + i + (iter * ResPacketSize), \
+             pmadd(c##iter, palpha, ploadu<ResPacket>(res + i + (iter * ResPacketSize)))); \
+   }
+
+ /** \internal main macro for gemv_col - initialize accumulators, multiply and add inputs, and store results */
+ #define GEMV_PROCESS_COL_ONE(N) \
+   GEMV_UNROLL(GEMV_INIT, N) \
+   Index j = j2; \
+   do { \
+     RhsPacket a0 = pset1<RhsPacket>(rhs2(j, 0)); \
+     GEMV_UNROLL(GEMV_PREFETCH, N) \
+     GEMV_UNROLL(GEMV_WORK_COL, N) \
+   } while (++j < jend); \
+   GEMV_UNROLL(GEMV_STORE_COL, N) \
+   i += (ResPacketSize * N);
+
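For N = 1 every `if (N > iter)` guard with iter > 0 is dead, so GEMV_PROCESS_COL_ONE(1) reduces after expansion to roughly the following sketch, with c0 the only live accumulator:

    c0 = pset1<ResPacket>(ResScalar(0));              // GEMV_INIT(0, 1)
    Index j = j2;
    do {
      RhsPacket a0 = pset1<RhsPacket>(rhs2(j, 0));    // broadcast rhs(j, 0)
      c0 = pcj.pmadd(GEMV_LOADPACKET_COL(0), a0, c0); // GEMV_WORK_COL(0, 1)
    } while (++j < jend);
    pstoreu(res + i, pmadd(c0, palpha, ploadu<ResPacket>(res + i)));  // GEMV_STORE_COL(0, 1)
    i += (ResPacketSize * 1);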
+ #ifdef USE_GEMV_MMA
+ #define GEMV_PROCESS_COL(N) GEMV_PROCESS_COL_ONE_MMA(N)
+ #else
+ #define GEMV_PROCESS_COL(N) GEMV_PROCESS_COL_ONE(N)
+ #endif
+
+ /** \internal perform a matrix multiply and accumulate of packet a and packet b */
+ #ifdef USE_GEMV_MMA
+ template <typename LhsPacket, typename RhsPacket, bool accumulate>
+ EIGEN_ALWAYS_INLINE void pger_vecMMA_acc(__vector_quad* acc, const RhsPacket& a, const LhsPacket& b) {
+   if (accumulate) {
+     __builtin_mma_xvf32gerpp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
+   } else {
+     __builtin_mma_xvf32ger(acc, (__vector unsigned char)a, (__vector unsigned char)b);
+   }
+ }
+
+ /** \internal perform a matrix multiply and accumulate of vector_pair a and packet b */
+ template <typename LhsPacket, typename RhsPacket, bool accumulate>
+ EIGEN_ALWAYS_INLINE void pger_vecMMA_acc(__vector_quad* acc, __vector_pair& a, const LhsPacket& b) {
+   if (accumulate) {
+     __builtin_mma_xvf64gerpp(acc, a, (__vector unsigned char)b);
+   } else {
+     __builtin_mma_xvf64ger(acc, a, (__vector unsigned char)b);
+   }
+ }
+ #endif
+
+ template <typename LhsScalar, typename LhsMapper, typename RhsScalar, typename RhsMapper, typename ResScalar>
+ EIGEN_STRONG_INLINE void gemv_col(Index rows, Index cols, const LhsMapper& alhs, const RhsMapper& rhs, ResScalar* res,
+                                   Index resIncr, ResScalar alpha) {
+   typedef gemv_traits<LhsScalar, RhsScalar> Traits;
+
+   typedef typename Traits::LhsPacket LhsPacket;
+   typedef typename Traits::RhsPacket RhsPacket;
+   typedef typename Traits::ResPacket ResPacket;
+
+   EIGEN_UNUSED_VARIABLE(resIncr);
+   eigen_internal_assert(resIncr == 1);
+
+   // The following copy tells the compiler that lhs's attributes are not modified outside this function
+   // This helps GCC to generate proper code.
+   LhsMapper lhs(alhs);
+   RhsMapper rhs2(rhs);
+
+   conj_helper<LhsScalar, RhsScalar, false, false> cj;
+   conj_helper<LhsPacket, RhsPacket, false, false> pcj;
+
+   const Index lhsStride = lhs.stride();
+   // TODO: for padded aligned inputs, we could enable aligned reads
+   enum {
+     LhsAlignment = Unaligned,
+     ResPacketSize = Traits::ResPacketSize,
+     LhsPacketSize = Traits::LhsPacketSize,
+     RhsPacketSize = Traits::RhsPacketSize,
+   };
+
+ #ifndef GCC_ONE_VECTORPAIR_BUG
+   const Index n8 = rows - 8 * ResPacketSize + 1;
+   const Index n4 = rows - 4 * ResPacketSize + 1;
+   const Index n2 = rows - 2 * ResPacketSize + 1;
+ #endif
+   const Index n1 = rows - 1 * ResPacketSize + 1;
+ #ifdef EIGEN_POWER_USE_GEMV_PREFETCH
+   const Index prefetch_dist = 64 * LhsPacketSize;
+ #endif
+
+   // TODO: improve the following heuristic:
+   const Index block_cols = cols < 128 ? cols : (lhsStride * sizeof(LhsScalar) < 16000 ? 16 : 8);
+   ResPacket palpha = pset1<ResPacket>(alpha);
+
+   for (Index j2 = 0; j2 < cols; j2 += block_cols) {
+     Index jend = numext::mini(j2 + block_cols, cols);
+     Index i = 0;
+     ResPacket c0, c1, c2, c3, c4, c5, c6, c7;
+ #ifdef USE_GEMV_MMA
+     __vector_quad e0, e1, e2, e3, e4, e5, e6, e7;
+     PacketBlock<ResPacket, 4> result0, result1, result2, result3, result4, result5, result6, result7;
+     GEMV_UNUSED(8, e)
+     GEMV_UNUSED(8, result)
+     GEMV_UNUSED_EXTRA(1, c)
+ #endif
+ #ifndef GCC_ONE_VECTORPAIR_BUG
+     while (i < n8) {
+       GEMV_PROCESS_COL(8)
+     }
+     if (i < n4) {
+       GEMV_PROCESS_COL(4)
+     }
+     if (i < n2) {
+       GEMV_PROCESS_COL(2)
+     }
+     if (i < n1)
+ #else
+     while (i < n1)
+ #endif
+     {
+       GEMV_PROCESS_COL_ONE(1)
+     }
+     for (; i < rows; ++i) {
+       ResScalar d0(0);
+       Index j = j2;
+       do {
+         d0 += cj.pmul(lhs(i, j), rhs2(j, 0));
+       } while (++j < jend);
+       res[i] += alpha * d0;
+     }
+   }
+ }
+
+ template <bool extraRows>
+ EIGEN_ALWAYS_INLINE void outputVecCol(Packet4f acc, float* result, Packet4f pAlpha, Index extra_rows) {
+   Packet4f d0 = ploadu<Packet4f>(result);
+   d0 = pmadd(acc, pAlpha, d0);
+   if (extraRows) {
+     pstoreu_partial(result, d0, extra_rows);
+   } else {
+     pstoreu(result, d0);
+   }
+ }
+
+ template <Index num_acc, bool extraRows, Index size>
+ EIGEN_ALWAYS_INLINE void outputVecColResults(Packet4f (&acc)[num_acc][size], float* result, Packet4f pAlpha,
+                                              Index extra_rows) {
+   constexpr Index real_acc = (num_acc - (extraRows ? 1 : 0));
+   for (Index k = 0; k < real_acc; k++) {
+     outputVecCol<false>(acc[k][0], result + k * 4, pAlpha, extra_rows);
+   }
+   if (extraRows) {
+     outputVecCol<true>(acc[real_acc][0], result + real_acc * 4, pAlpha, extra_rows);
+   }
+ }
+
+ static Packet16uc p16uc_MERGE16_32_V1 = {0, 1, 16, 17, 0, 1, 16, 17, 0, 1, 16, 17, 0, 1, 16, 17};
+ static Packet16uc p16uc_MERGE16_32_V2 = {2, 3, 18, 19, 2, 3, 18, 19, 2, 3, 18, 19, 2, 3, 18, 19};
+
+ template <Index num_acc, typename LhsMapper, bool zero>
+ EIGEN_ALWAYS_INLINE void loadVecLoopVSX(Index k, LhsMapper& lhs, Packet4f (&a0)[num_acc][2]) {
+   Packet8bf c0 = lhs.template loadPacket<Packet8bf>(k * 4, 0);
+   Packet8bf b1;
+   if (!zero) {
+     b1 = lhs.template loadPacket<Packet8bf>(k * 4, 1);
+
+     a0[k + 0][1] = oneConvertBF16Hi(b1.m_val);
+   }
+   a0[k + 0][0] = oneConvertBF16Hi(c0.m_val);
+
+   if (num_acc > (k + 1)) {
+     a0[k + 1][0] = oneConvertBF16Lo(c0.m_val);
+     if (!zero) {
+       a0[k + 1][1] = oneConvertBF16Lo(b1.m_val);
+     }
+   }
+ }
+
+ template <Index num_acc, bool zero>
+ EIGEN_ALWAYS_INLINE void multVecVSX(Packet4f (&acc)[num_acc][2], Packet4f (&a0)[num_acc][2], Packet4f (&b0)[2]) {
+   for (Index k = 0; k < num_acc; k++) {
+     for (Index i = 0; i < (zero ? 1 : 2); i++) {
+       acc[k][i] = pmadd(b0[i], a0[k][i], acc[k][i]);
+     }
+   }
+ }
+
+ template <typename RhsMapper, bool linear>
+ struct loadColData_impl {
+   // linear == false
+   static EIGEN_ALWAYS_INLINE Packet8bf run(RhsMapper& rhs, Index j) {
+     const Index n = unpacket_traits<Packet8bf>::size;
+     EIGEN_ALIGN16 bfloat16 to[n];
+     LOAD_STORE_UNROLL_16
+     for (Index i = 0; i < n; i++) {
+       to[i] = rhs(j + i, 0);
+     }
+     return pload<Packet8bf>(to);
+   }
+ };
+
+ template <typename RhsMapper>
+ struct loadColData_impl<RhsMapper, true> {
+   // linear == true
+   static EIGEN_ALWAYS_INLINE Packet8bf run(RhsMapper& rhs, Index j) {
+     return rhs.template loadPacket<Packet8bf>(j + 0, 0);
+   }
+ };
+
+ template <typename RhsMapper, bool linear>
+ EIGEN_ALWAYS_INLINE Packet8bf loadColData(RhsMapper& rhs, Index j) {
+   return loadColData_impl<RhsMapper, linear>::run(rhs, j);
+ }
+
+ template <Index num_acc, typename LhsMapper, typename RhsMapper, bool zero, bool linear>
+ EIGEN_ALWAYS_INLINE void vecColLoopVSX(Index j, LhsMapper& lhs, RhsMapper& rhs, Packet4f (&acc)[num_acc][2]) {
+   Packet4f a0[num_acc][2], b0[2];
+   Packet8bf b2 = loadColData<RhsMapper, linear>(rhs, j);
+
+   b0[0] = oneConvertBF16Perm(b2.m_val, p16uc_MERGE16_32_V1);
+   if (!zero) {
+     b0[1] = oneConvertBF16Perm(b2.m_val, p16uc_MERGE16_32_V2);
+   }
+
+   using LhsSubMapper = typename LhsMapper::SubMapper;
+
+   LhsSubMapper lhs2 = lhs.getSubMapper(0, j);
+   for (Index k = 0; k < num_acc; k += 2) {
+     loadVecLoopVSX<num_acc, LhsSubMapper, zero>(k, lhs2, a0);
+   }
+
+   multVecVSX<num_acc, zero>(acc, a0, b0);
+ }
+
+ template <Index num_acc>
+ EIGEN_ALWAYS_INLINE void addResultsVSX(Packet4f (&acc)[num_acc][2]) {
+   for (Index i = 0; i < num_acc; i++) {
+     acc[i][0] = acc[i][0] + acc[i][1];
+   }
+ }
+
+ // Uses 2X the accumulators or 4X the number of VSX registers
+ #define MAX_BFLOAT16_VEC_ACC_VSX 8
+
+ template <const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
+ void colVSXVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha,
+                           float* result) {
+   constexpr Index step = (num_acc * 4);
+   const Index extra_rows = (extraRows) ? (rows & 3) : 0;
+   constexpr bool multiIters = !extraRows && (num_acc == MAX_BFLOAT16_VEC_ACC_VSX);
+
+   do {
+     Packet4f acc[num_acc][2];
+
+     zeroAccumulators<num_acc, 2>(acc);
+
+     using LhsSubMapper = typename LhsMapper::SubMapper;
+
+     LhsSubMapper lhs2 = lhs.getSubMapper(row, 0);
+     for (Index j = 0; j + 2 <= cend; j += 2) {
+       vecColLoopVSX<num_acc, LhsSubMapper, RhsMapper, false, linear>(j, lhs2, rhs, acc);
+     }
+     if (cend & 1) {
+       vecColLoopVSX<num_acc, LhsSubMapper, RhsMapper, true, linear>(cend - 1, lhs2, rhs, acc);
+     }
+
+     addResultsVSX<num_acc>(acc);
+
+     outputVecColResults<num_acc, extraRows, 2>(acc, result, pAlpha, extra_rows);
+
+     result += step;
+   } while (multiIters && (step <= rows - (row += step)));
+ }
+
+ template <const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
+ EIGEN_ALWAYS_INLINE void colVSXVecColLoopBodyExtraN(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs,
+                                                     const Packet4f pAlpha, float* result) {
+   if (MAX_BFLOAT16_VEC_ACC_VSX > num_acc) {
+     colVSXVecColLoopBody<num_acc + (extraRows ? 1 : 0), LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs,
+                                                                                                  rhs, pAlpha, result);
+   }
+ }
+
+ template <typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
+ EIGEN_ALWAYS_INLINE void colVSXVecColLoopBodyExtra(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs,
+                                                    const Packet4f pAlpha, float* result) {
+   switch ((rows - row) >> 2) {
+     case 7:
+       colVSXVecColLoopBodyExtraN<7, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
+       break;
+     case 6:
+       colVSXVecColLoopBodyExtraN<6, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
+       break;
+     case 5:
+       colVSXVecColLoopBodyExtraN<5, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
+       break;
+     case 4:
+       colVSXVecColLoopBodyExtraN<4, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
+       break;
+     case 3:
+       colVSXVecColLoopBodyExtraN<3, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
+       break;
+     case 2:
+       colVSXVecColLoopBodyExtraN<2, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
+       break;
+     case 1:
+       colVSXVecColLoopBodyExtraN<1, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
+       break;
+     default:
+       if (extraRows) {
+         colVSXVecColLoopBody<1, LhsMapper, RhsMapper, true, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
+       }
+       break;
+   }
+ }
+
+ template <typename LhsMapper, typename RhsMapper, bool linear>
+ EIGEN_ALWAYS_INLINE void calcVSXVecColLoops(Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs,
+                                             const Packet4f pAlpha, float* result) {
+   Index row = 0;
+   if (rows >= (MAX_BFLOAT16_VEC_ACC_VSX * 4)) {
+     colVSXVecColLoopBody<MAX_BFLOAT16_VEC_ACC_VSX, LhsMapper, RhsMapper, false, linear>(row, cend, rows, lhs, rhs,
+                                                                                         pAlpha, result);
+     result += row;
+   }
+   if (rows & 3) {
+     colVSXVecColLoopBodyExtra<LhsMapper, RhsMapper, true, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
+   } else {
+     colVSXVecColLoopBodyExtra<LhsMapper, RhsMapper, false, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
+   }
+ }
+
+ template <const Index size, bool inc, Index delta>
+ EIGEN_ALWAYS_INLINE void storeBF16fromResult(bfloat16* dst, Packet8bf data, Index resInc, Index extra = 0) {
+   if (inc) {
+     if (size < 8) {
+       pscatter_partial(dst + delta * resInc, data, resInc, extra);
+     } else {
+       pscatter(dst + delta * resInc, data, resInc);
+     }
+   } else {
+     if (size < 8) {
+       pstoreu_partial(dst + delta, data, extra);
+     } else {
+       pstoreu(dst + delta, data);
+     }
+   }
+ }
+
+ template <const Index size, bool inc = false>
+ EIGEN_ALWAYS_INLINE void convertPointerF32toBF16VSX(Index& i, float* result, Index rows, bfloat16*& dst,
+                                                     Index resInc = 1) {
+   constexpr Index extra = ((size < 8) ? 8 : size);
+   while (i + size <= rows) {
+     PacketBlock<Packet8bf, (size + 7) / 8> r32;
+     r32.packet[0] = convertF32toBF16VSX(result + i + 0);
+     if (size >= 16) {
+       r32.packet[1] = convertF32toBF16VSX(result + i + 8);
+     }
+     if (size >= 32) {
+       r32.packet[2] = convertF32toBF16VSX(result + i + 16);
+       r32.packet[3] = convertF32toBF16VSX(result + i + 24);
+     }
+     storeBF16fromResult<size, inc, 0>(dst, r32.packet[0], resInc, rows & 7);
+     if (size >= 16) {
+       storeBF16fromResult<size, inc, 8>(dst, r32.packet[1], resInc);
+     }
+     if (size >= 32) {
+       storeBF16fromResult<size, inc, 16>(dst, r32.packet[2], resInc);
+       storeBF16fromResult<size, inc, 24>(dst, r32.packet[3], resInc);
+     }
+     i += extra;
+     dst += extra * resInc;
+     if (size != 32) break;
+   }
+ }
+
+ template <bool inc = false>
+ EIGEN_ALWAYS_INLINE void convertArrayPointerF32toBF16VSX(float* result, Index rows, bfloat16* dst, Index resInc = 1) {
+   Index i = 0;
+   convertPointerF32toBF16VSX<32, inc>(i, result, rows, dst, resInc);
+   convertPointerF32toBF16VSX<16, inc>(i, result, rows, dst, resInc);
+   convertPointerF32toBF16VSX<8, inc>(i, result, rows, dst, resInc);
+   convertPointerF32toBF16VSX<1, inc>(i, result, rows, dst, resInc);
+ }
+
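The conversion cascade above consumes the float buffer in chunks of 32, 16, 8, and finally one partial packet. A worked trace: for rows = 45 the size-32 step runs once (i: 0 to 32), the size-16 step is skipped (32 + 16 > 45), the size-8 step runs once (i: 32 to 40), and the size-1 step converts one more packet but stores only the remaining rows & 7 = 5 lanes via pstoreu_partial; only the size-32 instantiation loops, the others break after a single pass.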
+ template <typename RhsMapper, typename LhsMapper, typename = void>
+ struct UseStride : std::false_type {
+   static EIGEN_ALWAYS_INLINE void run(Index j2, Index jend, Index rows, LhsMapper& lhs, RhsMapper& rhs, Packet4f pAlpha,
+                                       float* result) {
+     using RhsSubMapper = typename RhsMapper::SubMapper;
+
+     RhsSubMapper rhs2 = rhs.getSubMapper(j2, 0);
+     calcVSXVecColLoops<LhsMapper, RhsSubMapper, false>(jend - j2, rows, lhs, rhs2, pAlpha, result);
+   }
+ };
+
+ template <typename RhsMapper, typename LhsMapper>
+ struct UseStride<RhsMapper, LhsMapper,
+                  std::enable_if_t<std::is_member_function_pointer<decltype(&RhsMapper::stride)>::value>>
+     : std::true_type {
+   static EIGEN_ALWAYS_INLINE void run(Index j2, Index jend, Index rows, LhsMapper& lhs, RhsMapper& rhs, Packet4f pAlpha,
+                                       float* result) {
+     using RhsSubMapper = typename RhsMapper::SubMapper;
+
+     RhsSubMapper rhs2 = rhs.getSubMapper(j2, 0);
+     if (rhs.stride() == 1) {
+       calcVSXVecColLoops<LhsMapper, RhsSubMapper, true>(jend - j2, rows, lhs, rhs2, pAlpha, result);
+     } else {
+       calcVSXVecColLoops<LhsMapper, RhsSubMapper, false>(jend - j2, rows, lhs, rhs2, pAlpha, result);
+     }
+   }
+ };
+
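UseStride above is a small member-detection idiom: decltype(&RhsMapper::stride) is well-formed only when RhsMapper declares a stride() member, in which case std::enable_if_t<...> yields void and the std::true_type specialization matches the `typename = void` default; otherwise substitution fails silently and the std::false_type primary is chosen. A minimal sketch of the same idiom with hypothetical mapper types:

    struct WithStride { Eigen::Index stride() const { return 1; } };
    struct WithoutStride {};
    // &WithStride::stride is a member function pointer, so the
    // std::true_type specialization is selected.
    static_assert(UseStride<WithStride, WithoutStride>::value, "has stride()");
    // Substitution fails for &WithoutStride::stride; the primary matches.
    static_assert(!UseStride<WithoutStride, WithStride>::value, "no stride()");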
725
+ template <typename LhsMapper, typename RhsMapper>
726
+ void gemv_bfloat16_col(Index rows, Index cols, const LhsMapper& alhs, const RhsMapper& rhs, bfloat16* res,
727
+ Index resIncr, bfloat16 alpha) {
728
+ EIGEN_UNUSED_VARIABLE(resIncr);
729
+ eigen_internal_assert(resIncr == 1);
730
+
731
+ // The following copy tells the compiler that lhs's attributes are not modified outside this function
732
+ // This helps GCC to generate proper code.
733
+ LhsMapper lhs(alhs);
734
+ RhsMapper rhs2(rhs);
735
+
736
+ const Index lhsStride = lhs.stride();
737
+
738
+ // TODO: improve the following heuristic:
739
+ const Index block_cols = cols < 128 ? cols : (lhsStride * sizeof(bfloat16) < 16000 ? 16 : 8);
740
+ float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
741
+ Packet4f pAlpha = pset1<Packet4f>(falpha);
742
+
743
+ ei_declare_aligned_stack_constructed_variable(float, result, rows, 0);
744
+
745
+ convertArrayPointerBF16toF32(result, 1, rows, res);
746
+
747
+ for (Index j2 = 0; j2 < cols; j2 += block_cols) {
748
+ Index jend = numext::mini(j2 + block_cols, cols);
749
+
750
+ using LhsSubMapper = typename LhsMapper::SubMapper;
751
+
752
+ LhsSubMapper lhs2 = lhs.getSubMapper(0, j2);
753
+ UseStride<RhsMapper, LhsSubMapper>::run(j2, jend, rows, lhs2, rhs2, pAlpha, result);
754
+ }
755
+
756
+ convertArrayPointerF32toBF16VSX(result, rows, res);
757
+ }
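
The column kernel widens the bf16 result into a float scratch buffer once, sweeps the matrix in narrow column bands (block_cols, sized by the cache heuristic above), and narrows back to bf16 once at the end. A scalar model of that blocked traversal, under the assumption of a plain col-major float array standing in for the widened data (gemv_col_blocked is an illustrative name):

    #include <algorithm>
    #include <vector>

    void gemv_col_blocked(long rows, long cols, const std::vector<float>& A,  // col-major, rows x cols
                          const std::vector<float>& x, std::vector<float>& y, float alpha) {
      const long block_cols = std::min<long>(cols, 16);
      for (long j2 = 0; j2 < cols; j2 += block_cols) {
        const long jend = std::min(j2 + block_cols, cols);
        for (long j = j2; j < jend; j++)      // stream a narrow band of columns
          for (long i = 0; i < rows; i++)
            y[i] += alpha * A[j * rows + i] * x[j];
      }
    }
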
+
+ template <Index num_acc, Index size>
+ EIGEN_ALWAYS_INLINE void outputVecResults(Packet4f (&acc)[num_acc][size], float* result, Packet4f pAlpha) {
+   constexpr Index extra = num_acc & 3;
+
+   for (Index k = 0; k < num_acc; k += 4) {
+     Packet4f d0 = ploadu<Packet4f>(result + k);
+     d0 = pmadd(acc[k + 0][0], pAlpha, d0);
+
+     if (num_acc > (k + 3)) {
+       pstoreu(result + k, d0);
+     } else {
+       if (extra == 3) {
+         pstoreu_partial(result + k, d0, extra);
+       } else {
+         memcpy((void*)(result + k), (void*)(&d0), sizeof(float) * extra);
+       }
+     }
+   }
+ }
+
+ template <Index num_acc>
+ EIGEN_ALWAYS_INLINE void preduxVecResults2VSX(Packet4f (&acc)[num_acc][2], Index k) {
+   if (num_acc > (k + 1)) {
+     acc[k][1] = vec_mergel(acc[k + 0][0], acc[k + 1][0]);
+     acc[k][0] = vec_mergeh(acc[k + 0][0], acc[k + 1][0]);
+     acc[k][0] = acc[k][0] + acc[k][1];
+     acc[k][0] += vec_sld(acc[k][0], acc[k][0], 8);
+   } else {
+     acc[k][0] += vec_sld(acc[k][0], acc[k][0], 8);
+ #ifdef _BIG_ENDIAN
+     acc[k][0] += vec_sld(acc[k][0], acc[k][0], 12);
+ #else
+     acc[k][0] += vec_sld(acc[k][0], acc[k][0], 4);
+ #endif
+   }
+ }
+
+ template <Index num_acc>
+ EIGEN_ALWAYS_INLINE void preduxVecResultsVSX(Packet4f (&acc)[num_acc][2]) {
+   for (Index k = 0; k < num_acc; k += 4) {
+     preduxVecResults2VSX<num_acc>(acc, k + 0);
+     if (num_acc > (k + 2)) {
+       preduxVecResults2VSX<num_acc>(acc, k + 2);
+ #ifdef EIGEN_VECTORIZE_VSX
+       acc[k + 0][0] = reinterpret_cast<Packet4f>(
+           vec_mergeh(reinterpret_cast<Packet2ul>(acc[k + 0][0]), reinterpret_cast<Packet2ul>(acc[k + 2][0])));
+ #else
+       acc[k + 0][0] = reinterpret_cast<Packet4f>(vec_perm(acc[k + 0][0], acc[k + 2][0], p16uc_TRANSPOSE64_HI));
+ #endif
+     }
+   }
+ }
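
Each accumulator holds four partial sums of one dot product; the vec_mergeh/vec_mergel/vec_sld sequence above is a horizontal reduction that collapses them while packing results from neighbouring accumulators into adjacent lanes. A scalar model of what one accumulator's reduction computes (predux_lane/predux_all are illustrative names):

    // Shift-and-add reduction of one 4-lane accumulator, lane pairing as in
    // the vec_sld sequence: first lanes 2 apart, then adjacent lanes.
    float predux_lane(const float v[4]) {
      float a = v[0] + v[2];  // mirrors the 8-byte vec_sld shift-and-add
      float b = v[1] + v[3];  // mirrors the 4-byte (12-byte on BE) shift-and-add
      return a + b;
    }

    void predux_all(const float acc[4][4], float out[4]) {
      for (int k = 0; k < 4; k++) out[k] = predux_lane(acc[k]);
    }
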
+
+ #ifndef _ARCH_PWR9
+ EIGEN_ALWAYS_INLINE Packet8us loadPacketPartialZero(Packet8us data, Index extra_cols) {
+   Packet16uc shift = pset1<Packet16uc>(8 * 2 * (8 - extra_cols));
+ #ifdef _BIG_ENDIAN
+   return reinterpret_cast<Packet8us>(vec_slo(vec_sro(reinterpret_cast<Packet16uc>(data), shift), shift));
+ #else
+   return reinterpret_cast<Packet8us>(vec_sro(vec_slo(reinterpret_cast<Packet16uc>(data), shift), shift));
+ #endif
+ }
+ #endif
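
The vec_sro/vec_slo pair zeroes the unused tail lanes by shifting them off the end of the 128-bit register and shifting zeros back in. The same idea on a plain integer, as a sketch (mask_low_elems is an illustrative name; valid_u16 is assumed to be in 1..4 so the shift stays in range):

    #include <cstdint>

    std::uint64_t mask_low_elems(std::uint64_t data, int valid_u16) {
      const int shift = 16 * (4 - valid_u16);  // four 16-bit lanes in this model
      return (data << shift) >> shift;         // high lanes fall off, zeros come back
    }
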
+
+ template <Index num_acc, typename LhsMapper, typename RhsMapper, bool extra>
+ EIGEN_ALWAYS_INLINE void multVSXVecLoop(Packet4f (&acc)[num_acc][2], const LhsMapper& lhs, RhsMapper& rhs, Index j,
+                                         Index extra_cols) {
+   Packet4f a0[num_acc][2], b0[2];
+   Packet8bf a1, b1;
+
+   if (extra) {
+     b1 = rhs.template loadPacketPartial<Packet8bf>(j, extra_cols);
+ #ifndef _ARCH_PWR9
+     b1 = loadPacketPartialZero(b1.m_val, extra_cols);
+ #endif
+   } else {
+     b1 = rhs.template loadPacket<Packet8bf>(j);
+   }
+   b0[0] = oneConvertBF16Hi(b1.m_val);
+   b0[1] = oneConvertBF16Lo(b1.m_val);
+
+   const LhsMapper lhs2 = lhs.getSubMapper(0, j);
+   for (Index k = 0; k < num_acc; k++) {
+     if (extra) {
+       a1 = lhs2.template loadPacketPartial<Packet8bf>(k, 0, extra_cols);
+ #ifndef _ARCH_PWR9
+       a1 = loadPacketPartialZero(a1.m_val, extra_cols);
+ #endif
+     } else {
+       a1 = lhs2.template loadPacket<Packet8bf>(k, 0);
+     }
+     a0[k][0] = oneConvertBF16Hi(a1.m_val);
+     a0[k][1] = oneConvertBF16Lo(a1.m_val);
+   }
+
+   multVecVSX<num_acc, false>(acc, a0, b0);
+ }
+
+ template <Index num_acc, typename LhsMapper, typename RhsMapper>
+ EIGEN_ALWAYS_INLINE void vecVSXLoop(Index cols, const LhsMapper& lhs, RhsMapper& rhs, Packet4f (&acc)[num_acc][2],
+                                     Index extra_cols) {
+   Index j = 0;
+   for (; j + 8 <= cols; j += 8) {
+     multVSXVecLoop<num_acc, LhsMapper, RhsMapper, false>(acc, lhs, rhs, j, extra_cols);
+   }
+
+   if (extra_cols) {
+     multVSXVecLoop<num_acc, LhsMapper, RhsMapper, true>(acc, lhs, rhs, j, extra_cols);
+   }
+ }
+
+ template <const Index num_acc, typename LhsMapper, typename RhsMapper>
+ void colVSXVecLoopBody(Index& row, Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha,
+                        float* result) {
+   constexpr bool multiIters = (num_acc == MAX_BFLOAT16_VEC_ACC_VSX);
+   const Index extra_cols = (cols & 7);
+
+   do {
+     Packet4f acc[num_acc][2];
+
+     zeroAccumulators<num_acc, 2>(acc);
+
+     const LhsMapper lhs2 = lhs.getSubMapper(row, 0);
+     vecVSXLoop<num_acc, LhsMapper, RhsMapper>(cols, lhs2, rhs, acc, extra_cols);
+
+     addResultsVSX<num_acc>(acc);
+
+     preduxVecResultsVSX<num_acc>(acc);
+
+     outputVecResults<num_acc, 2>(acc, result, pAlpha);
+
+     result += num_acc;
+   } while (multiIters && (num_acc <= rows - (row += num_acc)));
+ }
+
+ template <const Index num_acc, typename LhsMapper, typename RhsMapper>
+ EIGEN_ALWAYS_INLINE void colVSXVecLoopBodyExtraN(Index& row, Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs,
+                                                  const Packet4f pAlpha, float* result) {
+   if (MAX_BFLOAT16_VEC_ACC_VSX > num_acc) {
+     colVSXVecLoopBody<num_acc, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
+   }
+ }
+
+ template <typename LhsMapper, typename RhsMapper>
+ EIGEN_ALWAYS_INLINE void colVSXVecLoopBodyExtra(Index& row, Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs,
+                                                 const Packet4f pAlpha, float* result) {
+   switch (rows - row) {
+     case 7:
+       colVSXVecLoopBodyExtraN<7, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
+       break;
+     case 6:
+       colVSXVecLoopBodyExtraN<6, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
+       break;
+     case 5:
+       colVSXVecLoopBodyExtraN<5, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
+       break;
+     case 4:
+       colVSXVecLoopBodyExtraN<4, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
+       break;
+     case 3:
+       colVSXVecLoopBodyExtraN<3, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
+       break;
+     case 2:
+       colVSXVecLoopBodyExtraN<2, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
+       break;
+     case 1:
+       colVSXVecLoopBodyExtraN<1, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
+       break;
+   }
+ }
+
+ template <typename LhsMapper, typename RhsMapper>
+ EIGEN_ALWAYS_INLINE void calcVSXVecLoops(Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha,
+                                          float* result) {
+   Index row = 0;
+   if (rows >= MAX_BFLOAT16_VEC_ACC_VSX) {
+     colVSXVecLoopBody<MAX_BFLOAT16_VEC_ACC_VSX, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
+     result += row;
+   }
+   colVSXVecLoopBodyExtra<LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
+ }
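
colVSXVecLoopBodyExtra turns a runtime leftover count into a compile-time template argument via a switch, so each residual row count gets a fully unrolled body. A minimal model of that runtime-to-compile-time dispatch (body/dispatch are illustrative names):

    template <int N>
    void body(float* r) {
      for (int k = 0; k < N; k++) r[k] += 1.0f;  // fully unrollable: N is a constant
    }

    void dispatch(long leftover, float* r) {
      switch (leftover) {  // runtime value selects a fixed-unroll instantiation
        case 3: body<3>(r); break;
        case 2: body<2>(r); break;
        case 1: body<1>(r); break;
      }
    }
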
+
+ template <typename LhsMapper, typename RhsMapper>
+ EIGEN_STRONG_INLINE void gemv_bfloat16_row(Index rows, Index cols, const LhsMapper& alhs, const RhsMapper& rhs,
+                                            bfloat16* res, Index resIncr, bfloat16 alpha) {
+   typedef typename RhsMapper::LinearMapper LinearMapper;
+
+   // The following copy tells the compiler that lhs's attributes are not modified outside this function
+   // This helps GCC to generate proper code.
+   LhsMapper lhs(alhs);
+   LinearMapper rhs2 = rhs.getLinearMapper(0, 0);
+
+   eigen_internal_assert(rhs.stride() == 1);
+
+   float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
+   const Packet4f pAlpha = pset1<Packet4f>(falpha);
+
+   ei_declare_aligned_stack_constructed_variable(float, result, rows, 0);
+   if (resIncr == 1) {
+     convertArrayPointerBF16toF32(result, 1, rows, res);
+   } else {
+     convertArrayPointerBF16toF32<true>(result, 1, rows, res, resIncr);
+   }
+   calcVSXVecLoops<LhsMapper, LinearMapper>(cols, rows, lhs, rhs2, pAlpha, result);
+   if (resIncr == 1) {
+     convertArrayPointerF32toBF16VSX(result, rows, res);
+   } else {
+     convertArrayPointerF32toBF16VSX<true>(result, rows, res, resIncr);
+   }
+ }
+
+ #undef MAX_BFLOAT16_VEC_ACC_VSX
+
+ const Packet16uc p16uc_COMPLEX32_XORFLIP = {0x44, 0x55, 0x66, 0x77, 0x00, 0x11, 0x22, 0x33,
+                                             0xcc, 0xdd, 0xee, 0xff, 0x88, 0x99, 0xaa, 0xbb};
+ const Packet16uc p16uc_COMPLEX64_XORFLIP = {0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff,
+                                             0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77};
+
+ #ifdef _BIG_ENDIAN
+ const Packet16uc p16uc_COMPLEX32_CONJ_XOR = {0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00,
+                                              0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00};
+ const Packet16uc p16uc_COMPLEX64_CONJ_XOR = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+                                              0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
+ const Packet16uc p16uc_COMPLEX32_CONJ_XOR2 = {0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+                                               0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
+ const Packet16uc p16uc_COMPLEX64_CONJ_XOR2 = {0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+                                               0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
+ const Packet16uc p16uc_COMPLEX32_NEGATE = {0x80, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00,
+                                            0x80, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00};
+ const Packet16uc p16uc_COMPLEX64_NEGATE = {0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+                                            0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
+ #else
+ const Packet16uc p16uc_COMPLEX32_CONJ_XOR = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80,
+                                              0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80};
+ const Packet16uc p16uc_COMPLEX64_CONJ_XOR = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+                                              0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80};
+ const Packet16uc p16uc_COMPLEX32_CONJ_XOR2 = {0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x00,
+                                               0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x00};
+ const Packet16uc p16uc_COMPLEX64_CONJ_XOR2 = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80,
+                                               0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
+ const Packet16uc p16uc_COMPLEX32_NEGATE = {0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x80,
+                                            0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x80};
+ const Packet16uc p16uc_COMPLEX64_NEGATE = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80,
+                                            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80};
+ #endif
+
+ #ifdef _BIG_ENDIAN
+ #define COMPLEX_DELTA 0
+ #else
+ #define COMPLEX_DELTA 2
+ #endif
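
The CONJ_XOR/NEGATE tables are sign masks: XORing 0x80 into the top byte of a lane flips that float's sign bit, so conjugation and negation become a single vector XOR against a constant. The scalar version of the same bit trick (conj_by_xor is an illustrative name; std::complex<float> is guaranteed to be laid out as two contiguous floats):

    #include <complex>
    #include <cstdint>
    #include <cstring>

    std::complex<float> conj_by_xor(std::complex<float> z) {
      std::uint32_t bits[2];
      std::memcpy(bits, &z, sizeof(bits));
      bits[1] ^= 0x80000000u;  // flip the sign bit of the imaginary part only
      std::memcpy(&z, bits, sizeof(z));
      return z;
    }
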
+
+ /** \internal packet conjugate (same as pconj but uses the constants in pcplxflipconj for better code generation) */
+ EIGEN_ALWAYS_INLINE Packet2cf pconj2(const Packet2cf& a) {
+   return Packet2cf(pxor(a.v, reinterpret_cast<Packet4f>(p16uc_COMPLEX32_CONJ_XOR)));
+ }
+
+ EIGEN_ALWAYS_INLINE Packet1cd pconj2(const Packet1cd& a) {
+   return Packet1cd(pxor(a.v, reinterpret_cast<Packet2d>(p16uc_COMPLEX64_CONJ_XOR)));
+ }
+
+ /** \internal packet conjugate with the real & imaginary operations inverted */
+ EIGEN_ALWAYS_INLINE Packet2cf pconjinv(const Packet2cf& a) {
+ #ifdef __POWER8_VECTOR__
+   return Packet2cf(Packet4f(vec_neg(Packet2d(a.v))));
+ #else
+   return Packet2cf(pxor(a.v, reinterpret_cast<Packet4f>(p16uc_COMPLEX32_CONJ_XOR2)));
+ #endif
+ }
+
+ EIGEN_ALWAYS_INLINE Packet1cd pconjinv(const Packet1cd& a) {
+   return Packet1cd(pxor(a.v, reinterpret_cast<Packet2d>(p16uc_COMPLEX64_CONJ_XOR2)));
+ }
+
+ #if defined(_ARCH_PWR8) && (!EIGEN_COMP_LLVM || __clang_major__ >= 12)
+ #define PERMXOR_GOOD  // Clang had a bug with vec_permxor and endianness prior to version 12
+ #endif
+
+ /** \internal flip the real & imaginary results and packet conjugate */
+ EIGEN_ALWAYS_INLINE Packet2cf pcplxflipconj(Packet2cf a) {
+ #ifdef PERMXOR_GOOD
+   return Packet2cf(Packet4f(vec_permxor(Packet16uc(a.v), p16uc_COMPLEX32_CONJ_XOR, p16uc_COMPLEX32_XORFLIP)));
+ #else
+   return pcplxflip(pconj2(a));
+ #endif
+ }
+
+ EIGEN_ALWAYS_INLINE Packet1cd pcplxflipconj(Packet1cd a) {
+ #ifdef PERMXOR_GOOD
+   return Packet1cd(Packet2d(vec_permxor(Packet16uc(a.v), p16uc_COMPLEX64_CONJ_XOR, p16uc_COMPLEX64_XORFLIP)));
+ #else
+   return pcplxflip(pconj2(a));
+ #endif
+ }
+
+ /** \internal packet conjugate and flip the real & imaginary results */
+ EIGEN_ALWAYS_INLINE Packet2cf pcplxconjflip(Packet2cf a) {
+ #ifdef PERMXOR_GOOD
+   return Packet2cf(Packet4f(vec_permxor(Packet16uc(a.v), p16uc_COMPLEX32_CONJ_XOR2, p16uc_COMPLEX32_XORFLIP)));
+ #else
+   return pconj2(pcplxflip(a));
+ #endif
+ }
+
+ EIGEN_ALWAYS_INLINE Packet1cd pcplxconjflip(Packet1cd a) {
+ #ifdef PERMXOR_GOOD
+   return Packet1cd(Packet2d(vec_permxor(Packet16uc(a.v), p16uc_COMPLEX64_CONJ_XOR2, p16uc_COMPLEX64_XORFLIP)));
+ #else
+   return pconj2(pcplxflip(a));
+ #endif
+ }
+
+ /** \internal packet negate */
+ EIGEN_ALWAYS_INLINE Packet2cf pnegate2(Packet2cf a) {
+ #ifdef __POWER8_VECTOR__
+   return Packet2cf(vec_neg(a.v));
+ #else
+   return Packet2cf(pxor(a.v, reinterpret_cast<Packet4f>(p16uc_COMPLEX32_NEGATE)));
+ #endif
+ }
+
+ EIGEN_ALWAYS_INLINE Packet1cd pnegate2(Packet1cd a) {
+ #ifdef __POWER8_VECTOR__
+   return Packet1cd(vec_neg(a.v));
+ #else
+   return Packet1cd(pxor(a.v, reinterpret_cast<Packet2d>(p16uc_COMPLEX64_NEGATE)));
+ #endif
+ }
+
+ /** \internal flip the real & imaginary results and negate */
+ EIGEN_ALWAYS_INLINE Packet2cf pcplxflipnegate(Packet2cf a) {
+ #ifdef PERMXOR_GOOD
+   return Packet2cf(Packet4f(vec_permxor(Packet16uc(a.v), p16uc_COMPLEX32_NEGATE, p16uc_COMPLEX32_XORFLIP)));
+ #else
+   return pcplxflip(pnegate2(a));
+ #endif
+ }
+
+ EIGEN_ALWAYS_INLINE Packet1cd pcplxflipnegate(Packet1cd a) {
+ #ifdef PERMXOR_GOOD
+   return Packet1cd(Packet2d(vec_permxor(Packet16uc(a.v), p16uc_COMPLEX64_NEGATE, p16uc_COMPLEX64_XORFLIP)));
+ #else
+   return pcplxflip(pnegate2(a));
+ #endif
+ }
+
+ /** \internal flip the real & imaginary results */
+ EIGEN_ALWAYS_INLINE Packet2cf pcplxflip2(Packet2cf a) {
+   return Packet2cf(Packet4f(vec_perm(Packet16uc(a.v), Packet16uc(a.v), p16uc_COMPLEX32_XORFLIP)));
+ }
+
+ EIGEN_ALWAYS_INLINE Packet1cd pcplxflip2(Packet1cd a) {
+ #ifdef EIGEN_VECTORIZE_VSX
+   return Packet1cd(__builtin_vsx_xxpermdi(a.v, a.v, 2));
+ #else
+   return Packet1cd(Packet2d(vec_perm(Packet16uc(a.v), Packet16uc(a.v), p16uc_COMPLEX64_XORFLIP)));
+ #endif
+ }
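
The family above is four compositions of two primitives, flip and conjugate, which do not commute; the fallback paths make the intended order explicit. Their effect on a single scalar complex value, as a reference model (names mirror the packet functions but are illustrative):

    #include <complex>
    using C = std::complex<float>;

    C flip(C z)      { return C(z.imag(), z.real()); }    // pcplxflip2
    C flipconj(C z)  { return C(-z.imag(), z.real()); }   // pcplxflipconj  = flip(conj(z))
    C conjflip(C z)  { return C(z.imag(), -z.real()); }   // pcplxconjflip  = conj(flip(z))
    C flipneg(C z)   { return C(-z.imag(), -z.real()); }  // pcplxflipnegate = flip(-z)
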
+
+ /** \internal load half a vector with one complex value */
+ EIGEN_ALWAYS_INLINE Packet4f pload_complex_half(std::complex<float>* src) {
+   Packet4f t;
+ #ifdef EIGEN_VECTORIZE_VSX
+   // Load one float64/two float32 (doubleword alignment)
+   __asm__("lxsdx %x0,%y1" : "=wa"(t) : "Z"(*src));
+ #else
+   *reinterpret_cast<std::complex<float>*>(reinterpret_cast<float*>(&t) + COMPLEX_DELTA) = *src;
+ #endif
+   return t;
+ }
+
+ /** \internal load two vectors from the real and imaginary portions of a complex value */
+ template <typename RhsScalar>
+ EIGEN_ALWAYS_INLINE void pload_realimag(RhsScalar* src, Packet4f& r, Packet4f& i) {
+ #ifdef _ARCH_PWR9
+   __asm__("lxvwsx %x0,%y1" : "=wa"(r) : "Z"(*(reinterpret_cast<float*>(src) + 0)));
+   __asm__("lxvwsx %x0,%y1" : "=wa"(i) : "Z"(*(reinterpret_cast<float*>(src) + 1)));
+ #else
+   Packet4f t = pload_complex_half(src);
+   r = vec_splat(t, COMPLEX_DELTA + 0);
+   i = vec_splat(t, COMPLEX_DELTA + 1);
+ #endif
+ }
+
+ template <typename RhsScalar>
+ EIGEN_ALWAYS_INLINE void pload_realimag(RhsScalar* src, Packet2d& r, Packet2d& i) {
+ #ifdef EIGEN_VECTORIZE_VSX
+   __asm__("lxvdsx %x0,%y1" : "=wa"(r) : "Z"(*(reinterpret_cast<double*>(src) + 0)));
+   __asm__("lxvdsx %x0,%y1" : "=wa"(i) : "Z"(*(reinterpret_cast<double*>(src) + 1)));
+ #else
+   Packet2d t = ploadu<Packet2d>(reinterpret_cast<double*>(src));
+   r = vec_splat(t, 0);
+   i = vec_splat(t, 1);
+ #endif
+ }
+
+ #ifndef __POWER8_VECTOR__
+ const Packet16uc p16uc_MERGEE = {0x00, 0x01, 0x02, 0x03, 0x10, 0x11, 0x12, 0x13,
+                                  0x08, 0x09, 0x0A, 0x0B, 0x18, 0x19, 0x1A, 0x1B};
+
+ const Packet16uc p16uc_MERGEO = {0x04, 0x05, 0x06, 0x07, 0x14, 0x15, 0x16, 0x17,
+                                  0x0C, 0x0D, 0x0E, 0x0F, 0x1C, 0x1D, 0x1E, 0x1F};
+ #endif
+
+ /** \internal load two vectors from the interleaved real & imaginary values of src */
+ template <typename RhsScalar>
+ EIGEN_ALWAYS_INLINE void pload_realimag_row(RhsScalar* src, Packet4f& r, Packet4f& i) {
+   Packet4f t = ploadu<Packet4f>(reinterpret_cast<float*>(src));
+ #ifdef __POWER8_VECTOR__
+   r = vec_mergee(t, t);
+   i = vec_mergeo(t, t);
+ #else
+   r = vec_perm(t, t, p16uc_MERGEE);
+   i = vec_perm(t, t, p16uc_MERGEO);
+ #endif
+ }
+
+ template <typename RhsScalar>
+ EIGEN_ALWAYS_INLINE void pload_realimag_row(RhsScalar* src, Packet2d& r, Packet2d& i) {
+   return pload_realimag(src, r, i);
+ }
+
+ /** \internal load and splat a complex value into a vector - column-wise */
+ EIGEN_ALWAYS_INLINE Packet4f pload_realimag_combine(std::complex<float>* src) {
+ #ifdef EIGEN_VECTORIZE_VSX
+   Packet4f ret;
+   __asm__("lxvdsx %x0,%y1" : "=wa"(ret) : "Z"(*(reinterpret_cast<double*>(src) + 0)));
+   return ret;
+ #else
+   return Packet4f(ploaddup<Packet2d>(reinterpret_cast<double*>(src)));
+ #endif
+ }
+
+ EIGEN_ALWAYS_INLINE Packet2d pload_realimag_combine(std::complex<double>* src) { return ploadu<Packet1cd>(src).v; }
+
+ /** \internal load a complex value into a vector - row-wise */
+ EIGEN_ALWAYS_INLINE Packet4f pload_realimag_combine_row(std::complex<float>* src) { return ploadu<Packet2cf>(src).v; }
+
+ EIGEN_ALWAYS_INLINE Packet2d pload_realimag_combine_row(std::complex<double>* src) { return ploadu<Packet1cd>(src).v; }
+
+ /** \internal load a scalar or a vector from a complex location */
+ template <typename ResPacket>
+ EIGEN_ALWAYS_INLINE Packet4f pload_complex(std::complex<float>* src) {
+   if (GEMV_IS_SCALAR) {
+     return pload_complex_half(src);
+   } else {
+     return ploadu<Packet4f>(reinterpret_cast<float*>(src));
+   }
+ }
+
+ template <typename ResPacket>
+ EIGEN_ALWAYS_INLINE Packet2d pload_complex(std::complex<double>* src) {
+   return ploadu<Packet2d>(reinterpret_cast<double*>(src));
+ }
+
+ /** \internal load from a complex vector and convert to a real vector */
+ template <typename ResPacket>
+ EIGEN_ALWAYS_INLINE Packet4f pload_complex(Packet2cf* src) {
+   return src->v;
+ }
+
+ template <typename ResPacket>
+ EIGEN_ALWAYS_INLINE Packet2d pload_complex(Packet1cd* src) {
+   return src->v;
+ }
+
+ /** \internal load a full vector from a complex location - column-wise */
+ EIGEN_ALWAYS_INLINE Packet4f pload_complex_full(std::complex<float>* src) {
+   return Packet4f(ploaddup<Packet2d>(reinterpret_cast<double*>(src)));
+ }
+
+ EIGEN_ALWAYS_INLINE Packet2d pload_complex_full(std::complex<double>* src) { return ploadu<Packet1cd>(src).v; }
+
+ /** \internal load a full vector from a complex location - row-wise */
+ EIGEN_ALWAYS_INLINE Packet4f pload_complex_full_row(std::complex<float>* src) { return ploadu<Packet2cf>(src).v; }
+
+ EIGEN_ALWAYS_INLINE Packet2d pload_complex_full_row(std::complex<double>* src) { return pload_complex_full(src); }
+
+ /** \internal load a vector from a real-only scalar location - column-wise */
+ EIGEN_ALWAYS_INLINE Packet4f pload_real(float* src) { return pset1<Packet4f>(*src); }
+
+ EIGEN_ALWAYS_INLINE Packet2d pload_real(double* src) { return pset1<Packet2d>(*src); }
+
+ EIGEN_ALWAYS_INLINE Packet4f pload_real(Packet4f& src) { return src; }
+
+ EIGEN_ALWAYS_INLINE Packet2d pload_real(Packet2d& src) { return src; }
+
+ /** \internal load a vector from a real-only vector location */
+ EIGEN_ALWAYS_INLINE Packet4f pload_real_full(float* src) {
+   Packet4f ret = ploadu<Packet4f>(src);
+   return vec_mergeh(ret, ret);
+ }
+
+ EIGEN_ALWAYS_INLINE Packet2d pload_real_full(double* src) { return pload_real(src); }
+
+ EIGEN_ALWAYS_INLINE Packet4f pload_real_full(std::complex<float>* src) {
+   return pload_complex_full(src);  // Just for compilation
+ }
+
+ EIGEN_ALWAYS_INLINE Packet2d pload_real_full(std::complex<double>* src) {
+   return pload_complex_full(src);  // Just for compilation
+ }
+
+ /** \internal load a vector from a real-only scalar location - row-wise */
+ template <typename ResPacket>
+ EIGEN_ALWAYS_INLINE Packet4f pload_real_row(float* src) {
+   if (GEMV_IS_SCALAR) {
+     return pload_real_full(src);
+   } else {
+     return ploadu<Packet4f>(src);
+   }
+ }
+
+ template <typename ResPacket>
+ EIGEN_ALWAYS_INLINE Packet2d pload_real_row(double* src) {
+   return pload_real(src);
+ }
+
+ EIGEN_ALWAYS_INLINE Packet2cf padd(Packet2cf& a, std::complex<float>& b) {
+   EIGEN_UNUSED_VARIABLE(b);
+   return a;  // Just for compilation
+ }
+
+ EIGEN_ALWAYS_INLINE Packet1cd padd(Packet1cd& a, std::complex<double>& b) {
+   EIGEN_UNUSED_VARIABLE(b);
+   return a;  // Just for compilation
+ }
+
+ /** \internal set a scalar from a complex location */
+ template <typename Scalar, typename ResScalar>
+ EIGEN_ALWAYS_INLINE Scalar pset1_realimag(ResScalar& alpha, int which, int conj) {
+   return (which) ? ((conj) ? -alpha.real() : alpha.real()) : ((conj) ? -alpha.imag() : alpha.imag());
+ }
+
+ /** \internal set a vector from a complex location */
+ template <typename Scalar, typename ResScalar, typename ResPacket, int which>
+ EIGEN_ALWAYS_INLINE Packet2cf pset1_complex(std::complex<float>& alpha) {
+   Packet2cf ret;
+   ret.v[COMPLEX_DELTA + 0] = pset1_realimag<Scalar, ResScalar>(alpha, (which & 0x01), (which & 0x04));
+   ret.v[COMPLEX_DELTA + 1] = pset1_realimag<Scalar, ResScalar>(alpha, (which & 0x02), (which & 0x08));
+   ret.v[2 - COMPLEX_DELTA] = ret.v[COMPLEX_DELTA + 0];
+   ret.v[3 - COMPLEX_DELTA] = ret.v[COMPLEX_DELTA + 1];
+   return ret;
+ }
+
+ template <typename Scalar, typename ResScalar, typename ResPacket, int which>
+ EIGEN_ALWAYS_INLINE Packet1cd pset1_complex(std::complex<double>& alpha) {
+   Packet1cd ret;
+   ret.v[0] = pset1_realimag<Scalar, ResScalar>(alpha, (which & 0x01), (which & 0x04));
+   ret.v[1] = pset1_realimag<Scalar, ResScalar>(alpha, (which & 0x02), (which & 0x08));
+   return ret;
+ }
+
+ /** \internal zero out a vector for real or complex forms */
+ template <typename Packet>
+ EIGEN_ALWAYS_INLINE Packet pset_zero() {
+   return pset1<Packet>(__UNPACK_TYPE__(Packet)(0));
+ }
+
+ template <>
+ EIGEN_ALWAYS_INLINE Packet2cf pset_zero<Packet2cf>() {
+   return Packet2cf(pset1<Packet4f>(float(0)));
+ }
+
+ template <>
+ EIGEN_ALWAYS_INLINE Packet1cd pset_zero<Packet1cd>() {
+   return Packet1cd(pset1<Packet2d>(double(0)));
+ }
+
+ /** \internal initialize a vector from another vector */
+ template <typename Packet, typename LhsPacket, typename RhsPacket>
+ EIGEN_ALWAYS_INLINE Packet pset_init(Packet& c1) {
+   if (GEMV_IS_COMPLEX_COMPLEX) {
+     EIGEN_UNUSED_VARIABLE(c1);
+     return pset_zero<Packet>();
+   } else {
+     return c1;  // Intentionally left uninitialized
+   }
+ }
+
+ template <typename PResPacket, typename ResPacket, typename ResScalar, typename Scalar>
+ struct alpha_store {
+   alpha_store(ResScalar& alpha) {
+     separate.r = pset1_complex<Scalar, ResScalar, ResPacket, 0x3>(alpha);
+     separate.i = pset1_complex<Scalar, ResScalar, ResPacket, 0x0>(alpha);
+   }
+   struct ri {
+     PResPacket r;
+     PResPacket i;
+   } separate;
+ };
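
alpha_store splits alpha into an all-real splat and an all-imaginary splat because a complex multiply-accumulate then factors into two real fused multiply-adds against the value and a permuted copy of it: (ar + ai*i) * (cr + ci*i) = ar*(cr, ci) + ai*(-ci, cr). A scalar model of that identity (axpy is an illustrative name; the actual kernels fold the signs into the alpha splats and use pcplxflipconj for the permuted operand, so the sign placement differs in detail):

    #include <complex>

    std::complex<float> axpy(std::complex<float> alpha, std::complex<float> c,
                             std::complex<float> acc) {
      const float ar = alpha.real(), ai = alpha.imag();
      const std::complex<float> flipped(-c.imag(), c.real());  // one permute + sign flip
      return std::complex<float>(acc.real() + ar * c.real() + ai * flipped.real(),
                                 acc.imag() + ar * c.imag() + ai * flipped.imag());
    }
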
+
+ /** \internal multiply and add for complex math */
+ template <typename ScalarPacket, typename AlphaData>
+ EIGEN_ALWAYS_INLINE ScalarPacket pmadd_complex(ScalarPacket& c0, ScalarPacket& c2, ScalarPacket& c4, AlphaData& b0) {
+   return pmadd(c2, b0.separate.i.v, pmadd(c0, b0.separate.r.v, c4));
+ }
+
+ /** \internal store and madd for complex math */
+ template <typename Scalar, typename ScalarPacket, typename PResPacket, typename ResPacket, typename ResScalar,
+           typename AlphaData>
+ EIGEN_ALWAYS_INLINE void pstoreu_pmadd_complex(PResPacket& c0, AlphaData& b0, ResScalar* res) {
+   PResPacket c2 = pcplxflipconj(c0);
+   if (GEMV_IS_SCALAR) {
+     ScalarPacket c4 = ploadu<ScalarPacket>(reinterpret_cast<Scalar*>(res));
+     ScalarPacket c3 = pmadd_complex<ScalarPacket, AlphaData>(c0.v, c2.v, c4, b0);
+     pstoreu(reinterpret_cast<Scalar*>(res), c3);
+   } else {
+     ScalarPacket c4 = pload_complex<ResPacket>(res);
+     PResPacket c3 = PResPacket(pmadd_complex<ScalarPacket, AlphaData>(c0.v, c2.v, c4, b0));
+     pstoreu(res, c3);
+   }
+ }
+
+ template <typename ScalarPacket, typename PResPacket, typename ResPacket, typename ResScalar, typename AlphaData,
+           Index ResPacketSize, Index iter2>
+ EIGEN_ALWAYS_INLINE void pstoreu_pmadd_complex(PResPacket& c0, PResPacket& c1, AlphaData& b0, ResScalar* res) {
+   PResPacket c2 = pcplxflipconj(c0);
+   PResPacket c3 = pcplxflipconj(c1);
+ #if !defined(_ARCH_PWR10)
+   ScalarPacket c4 = pload_complex<ResPacket>(res + (iter2 * ResPacketSize));
+   ScalarPacket c5 = pload_complex<ResPacket>(res + ((iter2 + 1) * ResPacketSize));
+   PResPacket c6 = PResPacket(pmadd_complex<ScalarPacket, AlphaData>(c0.v, c2.v, c4, b0));
+   PResPacket c7 = PResPacket(pmadd_complex<ScalarPacket, AlphaData>(c1.v, c3.v, c5, b0));
+   pstoreu(res + (iter2 * ResPacketSize), c6);
+   pstoreu(res + ((iter2 + 1) * ResPacketSize), c7);
+ #else
+   __vector_pair a = *reinterpret_cast<__vector_pair*>(res + (iter2 * ResPacketSize));
+ #if EIGEN_COMP_LLVM
+   PResPacket c6[2];
+   __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(c6), &a);
+   c6[0] = PResPacket(pmadd_complex<ScalarPacket, AlphaData>(c0.v, c2.v, c6[0].v, b0));
+   c6[1] = PResPacket(pmadd_complex<ScalarPacket, AlphaData>(c1.v, c3.v, c6[1].v, b0));
+   GEMV_BUILDPAIR_MMA(a, c6[0].v, c6[1].v);
+ #else
+   if (GEMV_IS_COMPLEX_FLOAT) {
+     __asm__("xvmaddasp %L0,%x1,%x2\n\txvmaddasp %0,%x1,%x3" : "+&d"(a) : "wa"(b0.separate.r.v), "wa"(c0.v), "wa"(c1.v));
+     __asm__("xvmaddasp %L0,%x1,%x2\n\txvmaddasp %0,%x1,%x3" : "+&d"(a) : "wa"(b0.separate.i.v), "wa"(c2.v), "wa"(c3.v));
+   } else {
+     __asm__("xvmaddadp %L0,%x1,%x2\n\txvmaddadp %0,%x1,%x3" : "+&d"(a) : "wa"(b0.separate.r.v), "wa"(c0.v), "wa"(c1.v));
+     __asm__("xvmaddadp %L0,%x1,%x2\n\txvmaddadp %0,%x1,%x3" : "+&d"(a) : "wa"(b0.separate.i.v), "wa"(c2.v), "wa"(c3.v));
+   }
+ #endif
+   *reinterpret_cast<__vector_pair*>(res + (iter2 * ResPacketSize)) = a;
+ #endif
+ }
+
+ /** \internal load lhs packet */
+ template <typename Scalar, typename LhsScalar, typename LhsMapper, typename LhsPacket>
+ EIGEN_ALWAYS_INLINE LhsPacket loadLhsPacket(LhsMapper& lhs, Index i, Index j) {
+   if (sizeof(Scalar) == sizeof(LhsScalar)) {
+     const LhsScalar& src = lhs(i + 0, j);
+     return LhsPacket(pload_real_full(const_cast<LhsScalar*>(&src)));
+   }
+   return lhs.template load<LhsPacket, Unaligned>(i + 0, j);
+ }
+
+ /** \internal madd for complex times complex */
+ template <typename ComplexPacket, typename RealPacket, bool ConjugateLhs, bool ConjugateRhs, bool Negate>
+ EIGEN_ALWAYS_INLINE RealPacket pmadd_complex_complex(RealPacket& a, RealPacket& b, RealPacket& c) {
+   if (ConjugateLhs && ConjugateRhs) {
+     return vec_madd(a, pconj2(ComplexPacket(b)).v, c);
+   } else if (Negate && !ConjugateLhs && ConjugateRhs) {
+     return vec_nmsub(a, b, c);
+   } else {
+     return vec_madd(a, b, c);
+   }
+ }
+
+ /** \internal madd for complex times real */
+ template <typename ComplexPacket, typename RealPacket, bool Conjugate>
+ EIGEN_ALWAYS_INLINE RealPacket pmadd_complex_real(RealPacket& a, RealPacket& b, RealPacket& c) {
+   if (Conjugate) {
+     return vec_madd(a, pconj2(ComplexPacket(b)).v, c);
+   } else {
+     return vec_madd(a, b, c);
+   }
+ }
+
+ template <typename LhsPacket, typename RhsScalar, typename RhsPacket, typename PResPacket, bool ConjugateLhs,
+           bool ConjugateRhs, int StorageOrder>
+ EIGEN_ALWAYS_INLINE void gemv_mult_generic(LhsPacket& a0, RhsScalar* b, PResPacket& c0) {
+   conj_helper<LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs> pcj;
+   RhsPacket b0;
+   if (StorageOrder == ColMajor) {
+     b0 = pset1<RhsPacket>(*b);
+   } else {
+     b0 = ploadu<RhsPacket>(b);
+   }
+   c0 = pcj.pmadd(a0, b0, c0);
+ }
+
+ /** \internal core multiply operation for vectors - complex times complex */
+ template <typename ScalarPacket, typename LhsPacket, typename RhsScalar, typename RhsPacket, typename PResPacket,
+           typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>
+ EIGEN_ALWAYS_INLINE void gemv_mult_complex_complex(LhsPacket& a0, RhsScalar* b, PResPacket& c0, ResPacket& c1) {
+   ScalarPacket br, bi;
+   if (StorageOrder == ColMajor) {
+     pload_realimag<RhsScalar>(b, br, bi);
+   } else {
+     pload_realimag_row<RhsScalar>(b, br, bi);
+   }
+   if (ConjugateLhs && !ConjugateRhs) a0 = pconj2(a0);
+   LhsPacket a1 = pcplxflipconj(a0);
+   ScalarPacket cr = pmadd_complex_complex<LhsPacket, ScalarPacket, ConjugateLhs, ConjugateRhs, false>(a0.v, br, c0.v);
+   ScalarPacket ci = pmadd_complex_complex<LhsPacket, ScalarPacket, ConjugateLhs, ConjugateRhs, true>(a1.v, bi, c1.v);
+   c1 = ResPacket(ci);
+   c0 = PResPacket(cr);
+ }
+
+ /** \internal core multiply operation for vectors - real times complex */
+ template <typename ScalarPacket, typename LhsPacket, typename RhsScalar, typename RhsPacket, typename PResPacket,
+           typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>
+ EIGEN_ALWAYS_INLINE void gemv_mult_real_complex(LhsPacket& a0, RhsScalar* b, PResPacket& c0) {
+   ScalarPacket b0;
+   if (StorageOrder == ColMajor) {
+     b0 = pload_complex_full(b);
+   } else {
+     b0 = pload_complex_full_row(b);
+   }
+   ScalarPacket cri = pmadd_complex_real<PResPacket, ScalarPacket, ConjugateRhs>(a0, b0, c0.v);
+   c0 = PResPacket(cri);
+ }
+
+ /** \internal core multiply operation for vectors - complex times real */
+ template <typename ScalarPacket, typename LhsPacket, typename RhsScalar, typename RhsPacket, typename PResPacket,
+           typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>
+ EIGEN_ALWAYS_INLINE void gemv_mult_complex_real(LhsPacket& a0, RhsScalar* b, PResPacket& c0) {
+   ScalarPacket a1 = pload_complex<ResPacket>(&a0);
+   ScalarPacket b0;
+   if (StorageOrder == ColMajor) {
+     b0 = pload_real(b);
+   } else {
+     b0 = pload_real_row<ResPacket>(b);
+   }
+   ScalarPacket cri = pmadd_complex_real<PResPacket, ScalarPacket, ConjugateLhs>(a1, b0, c0.v);
+   c0 = PResPacket(cri);
+ }
+
+ #define GEMV_MULT_COMPLEX_COMPLEX(LhsType, RhsType, ResType)                                                        \
+   template <typename ScalarPacket, typename LhsPacket, typename RhsScalar, typename RhsPacket, typename PResPacket, \
+             typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>                             \
+   EIGEN_ALWAYS_INLINE void gemv_mult_complex(LhsType& a0, RhsType* b, ResType& c0, ResType& c1) {                   \
+     gemv_mult_complex_complex<ScalarPacket, LhsPacket, RhsScalar, RhsPacket, PResPacket, ResPacket, ConjugateLhs,   \
+                               ConjugateRhs, StorageOrder>(a0, b, c0, c1);                                           \
+   }
+
+ GEMV_MULT_COMPLEX_COMPLEX(Packet2cf, std::complex<float>, Packet2cf)
+ GEMV_MULT_COMPLEX_COMPLEX(Packet1cd, std::complex<double>, Packet1cd)
+
+ #define GEMV_MULT_REAL_COMPLEX(LhsType, RhsType, ResType)                                                           \
+   template <typename ScalarPacket, typename LhsPacket, typename RhsScalar, typename RhsPacket, typename PResPacket, \
+             typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>                             \
+   EIGEN_ALWAYS_INLINE void gemv_mult_complex(LhsType& a0, RhsType* b, ResType& c0, RhsType&) {                      \
+     gemv_mult_real_complex<ScalarPacket, LhsPacket, RhsScalar, RhsPacket, PResPacket, ResPacket, ConjugateLhs,      \
+                            ConjugateRhs, StorageOrder>(a0, b, c0);                                                  \
+   }
+
+ GEMV_MULT_REAL_COMPLEX(float, std::complex<float>, Packet2cf)
+ GEMV_MULT_REAL_COMPLEX(double, std::complex<double>, Packet1cd)
+ GEMV_MULT_REAL_COMPLEX(Packet4f, std::complex<float>, Packet2cf)
+ GEMV_MULT_REAL_COMPLEX(Packet2d, std::complex<double>, Packet1cd)
+
+ #define GEMV_MULT_COMPLEX_REAL(LhsType, RhsType, ResType1, ResType2)                                                \
+   template <typename ScalarPacket, typename LhsPacket, typename RhsScalar, typename RhsPacket, typename PResPacket, \
+             typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>                             \
+   EIGEN_ALWAYS_INLINE void gemv_mult_complex(LhsType& a0, RhsType* b, ResType1& c0, ResType2&) {                    \
+     gemv_mult_complex_real<ScalarPacket, LhsPacket, RhsScalar, RhsPacket, PResPacket, ResPacket, ConjugateLhs,      \
+                            ConjugateRhs, StorageOrder>(a0, b, c0);                                                  \
+   }
+
+ GEMV_MULT_COMPLEX_REAL(Packet2cf, float, Packet2cf, std::complex<float>)
+ GEMV_MULT_COMPLEX_REAL(Packet1cd, double, Packet1cd, std::complex<double>)
+ GEMV_MULT_COMPLEX_REAL(std::complex<float>, float, Packet2cf, std::complex<float>)
+ GEMV_MULT_COMPLEX_REAL(std::complex<double>, double, Packet1cd, std::complex<double>)
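The GEMV_MULT_* macros stamp out families of thin `gemv_mult_complex` overloads: each expansion pins concrete argument types and forwards to the shared template implementation, so overload resolution picks the right complex/real combination at compile time. A minimal model of that stamping pattern (scale/scale_impl are illustrative names):

    #include <cstdio>

    template <typename T>
    void scale_impl(T v, T s) { std::printf("%f\n", double(v * s)); }

    // One macro emits an overload per concrete type, all forwarding to the template.
    #define DEFINE_SCALE(TYPE) \
      inline void scale(TYPE v, TYPE s) { scale_impl<TYPE>(v, s); }

    DEFINE_SCALE(float)
    DEFINE_SCALE(double)
    #undef DEFINE_SCALE
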
+
+ #ifdef USE_GEMV_MMA
+ /** \internal convert packet to real form */
+ template <typename T>
+ EIGEN_ALWAYS_INLINE T convertReal(T a) {
+   return a;
+ }
+
+ EIGEN_ALWAYS_INLINE Packet4f convertReal(Packet2cf a) { return a.v; }
+
+ EIGEN_ALWAYS_INLINE Packet2d convertReal(Packet1cd a) { return a.v; }
+
+ /** \internal convert packet to complex form */
+ template <typename T>
+ EIGEN_ALWAYS_INLINE T convertComplex(T a) {
+   return a;
+ }
+
+ EIGEN_ALWAYS_INLINE Packet2cf convertComplex(Packet4f a) { return Packet2cf(a); }
+
+ EIGEN_ALWAYS_INLINE Packet1cd convertComplex(Packet2d a) { return Packet1cd(a); }
+
+ /** \internal load a vector from a complex location (for the MMA version) */
+ template <typename ScalarPacket, typename LhsPacket, typename SLhsPacket, typename ResPacket>
+ EIGEN_ALWAYS_INLINE void pload_complex_MMA(SLhsPacket& a) {
+   a = SLhsPacket(pload_complex<ResPacket>(&a));
+ }
+
+ template <typename ScalarPacket, typename LhsPacket, typename SLhsPacket, typename ResPacket>
+ EIGEN_ALWAYS_INLINE void pload_complex_MMA(__vector_pair&) {
+   // Pass through
+ }
+
+ /** \internal perform a matrix multiply and accumulate (positive and negative) of packet a and packet b */
+ template <typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
+ EIGEN_ALWAYS_INLINE void pger_vecMMA(__vector_quad* acc, RhsPacket& a, LhsPacket& b) {
+   if (NegativeAccumulate) {
+     __builtin_mma_xvf32gernp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
+   } else {
+     __builtin_mma_xvf32gerpp(acc, (__vector unsigned char)a, (__vector unsigned char)b);
+   }
+ }
+
+ /** \internal perform a matrix multiply and accumulate (positive and negative) of vector_pair a and packet b */
+ template <typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
+ EIGEN_ALWAYS_INLINE void pger_vecMMA(__vector_quad* acc, __vector_pair& a, Packet2d& b) {
+   if (NegativeAccumulate) {
+     __builtin_mma_xvf64gernp(acc, (__vector_pair)a, (__vector unsigned char)b);
+   } else {
+     __builtin_mma_xvf64gerpp(acc, (__vector_pair)a, (__vector unsigned char)b);
+   }
+ }
+
+ template <typename LhsPacket, typename RhsPacket, bool NegativeAccumulate>
+ EIGEN_ALWAYS_INLINE void pger_vecMMA(__vector_quad*, __vector_pair&, Packet4f&) {
+   // Just for compilation
+ }
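
pger_vecMMA wraps POWER10's Matrix-Multiply Assist: each xvf32ger/xvf64ger builtin performs a rank-1 (outer-product) update into a 512-bit `__vector_quad` accumulator. A standalone sketch of that accumulate-then-disassemble flow, assuming a POWER10 toolchain with MMA enabled (outer_product_4x4 is an illustrative name; the builtins shown are the documented GCC/Clang MMA intrinsics):

    #if defined(__MMA__)
    #include <altivec.h>
    #include <cstring>

    // out = a (outer product) b for one 4x4 tile of floats.
    void outer_product_4x4(const float* a, const float* b, float out[4][4]) {
      __vector_quad acc;
      __builtin_mma_xxsetaccz(&acc);  // zero the 512-bit accumulator
      vector float va = vec_xl(0, a);
      vector float vb = vec_xl(0, b);
      // rank-1 update: acc += va (x) vb
      __builtin_mma_xvf32gerpp(&acc, (vector unsigned char)va, (vector unsigned char)vb);
      vector float rows[4];
      __builtin_mma_disassemble_acc(rows, &acc);  // spill the four 128-bit result rows
      std::memcpy(out, rows, sizeof(rows));
    }
    #endif
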
1591
+
1592
+ /** \internal madd for complex times complex (MMA version) */
1593
+ template <typename RealPacket, typename LhsPacket, bool ConjugateLhs, bool ConjugateRhs, bool Negate>
1594
+ EIGEN_ALWAYS_INLINE void pmadd_complex_complex_MMA(LhsPacket& a, RealPacket& b, __vector_quad* c) {
1595
+ if (ConjugateLhs && ConjugateRhs) {
1596
+ RealPacket b2 = pconj2(convertComplex(b)).v;
1597
+ return pger_vecMMA<RealPacket, RealPacket, false>(c, b2, a.v);
1598
+ } else if (Negate && !ConjugateLhs && ConjugateRhs) {
1599
+ return pger_vecMMA<RealPacket, RealPacket, true>(c, b, a.v);
1600
+ } else {
1601
+ return pger_vecMMA<RealPacket, RealPacket, false>(c, b, a.v);
1602
+ }
1603
+ }
1604
+
1605
+ template <typename RealPacket, typename LhsPacket, bool ConjugateLhs, bool ConjugateRhs, bool Negate>
1606
+ EIGEN_ALWAYS_INLINE void pmadd_complex_complex_MMA(__vector_pair& a, RealPacket& b, __vector_quad* c) {
1607
+ if (ConjugateLhs && ConjugateRhs) {
1608
+ RealPacket b2 = pconj2(convertComplex(b)).v;
1609
+ return pger_vecMMA<RealPacket, __vector_pair, false>(c, a, b2);
1610
+ } else if (Negate && !ConjugateLhs && ConjugateRhs) {
1611
+ return pger_vecMMA<RealPacket, __vector_pair, true>(c, a, b);
1612
+ } else {
1613
+ return pger_vecMMA<RealPacket, __vector_pair, false>(c, a, b);
1614
+ }
1615
+ }
1616
+
1617
+ /** \internal madd for complex times real (MMA version) */
1618
+ template <typename RealPacket, typename LhsPacket, bool Conjugate, int StorageOrder>
1619
+ EIGEN_ALWAYS_INLINE void pmadd_complex_real_MMA(LhsPacket& a, RealPacket& b, __vector_quad* c) {
1620
+ RealPacket a2 = convertReal(a);
1621
+ if (Conjugate) {
1622
+ RealPacket b2 = pconj2(convertComplex(b)).v;
1623
+ if (StorageOrder == ColMajor) {
1624
+ return pger_vecMMA<RealPacket, RealPacket, false>(c, b2, a2);
1625
+ } else {
1626
+ return pger_vecMMA<RealPacket, RealPacket, false>(c, a2, b2);
1627
+ }
1628
+ } else {
1629
+ if (StorageOrder == ColMajor) {
1630
+ return pger_vecMMA<RealPacket, RealPacket, false>(c, b, a2);
1631
+ } else {
1632
+ return pger_vecMMA<RealPacket, RealPacket, false>(c, a2, b);
1633
+ }
1634
+ }
1635
+ }
1636
+
1637
+ /** \internal madd for real times complex (MMA version) */
1638
+ template <typename RealPacket, typename LhsPacket, bool Conjugate, int StorageOrder>
1639
+ EIGEN_ALWAYS_INLINE void pmadd_complex_real_MMA(__vector_pair& a, RealPacket& b, __vector_quad* c) {
1640
+ if (Conjugate) {
1641
+ RealPacket b2 = pconj2(convertComplex(b)).v;
1642
+ return pger_vecMMA<RealPacket, __vector_pair, false>(c, a, b2);
1643
+ } else {
1644
+ return pger_vecMMA<RealPacket, __vector_pair, false>(c, a, b);
1645
+ }
1646
+ }
1647
+
1648
+ /** \internal core multiply operation for vectors (MMA version) - complex times complex */
1649
+ template <typename ScalarPacket, typename LhsPacket, typename SLhsPacket, typename RhsScalar, typename ResPacket,
1650
+ bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>
1651
+ EIGEN_ALWAYS_INLINE void gemv_mult_complex_complex_MMA(SLhsPacket& a0, RhsScalar* b, __vector_quad* c0) {
1652
+ ScalarPacket b0;
1653
+ if (StorageOrder == ColMajor) {
1654
+ b0 = pload_realimag_combine(b);
1655
+ } else {
1656
+ b0 = pload_realimag_combine_row(b);
1657
+ }
1658
+ pmadd_complex_complex_MMA<ScalarPacket, LhsPacket, ConjugateLhs, ConjugateRhs, false>(a0, b0, c0);
1659
+ }
1660
+
1661
+ /** \internal core multiply operation for vectors (MMA version) - complex times real */
1662
+ template <typename ScalarPacket, typename LhsPacket, typename SLhsPacket, typename RhsScalar, typename ResPacket,
1663
+ bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>
1664
+ EIGEN_ALWAYS_INLINE void gemv_mult_complex_real_MMA(SLhsPacket& a0, RhsScalar* b, __vector_quad* c0) {
1665
+ pload_complex_MMA<ScalarPacket, LhsPacket, SLhsPacket, ResPacket>(a0);
1666
+ ScalarPacket b0;
1667
+ if (StorageOrder == ColMajor) {
1668
+ b0 = pload_real(b);
1669
+ } else {
1670
+ b0 = pload_real_row<ResPacket>(b);
1671
+ }
1672
+ pmadd_complex_real_MMA<ScalarPacket, LhsPacket, ConjugateLhs, ColMajor>(a0, b0, c0);
1673
+ }
1674
+
1675
+ /** \internal core multiply operation for vectors (MMA version) - real times complex */
1676
+ template <typename ScalarPacket, typename LhsPacket, typename SLhsPacket, typename RhsScalar, typename ResPacket,
1677
+ bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>
1678
+ EIGEN_ALWAYS_INLINE void gemv_mult_real_complex_MMA(SLhsPacket& a0, RhsScalar* b, __vector_quad* c0) {
1679
+ ScalarPacket b0;
1680
+ if (StorageOrder == ColMajor) {
1681
+ b0 = pload_complex_full(b);
1682
+ } else {
1683
+ b0 = pload_complex_full_row(b);
1684
+ }
1685
+ pmadd_complex_real_MMA<ScalarPacket, LhsPacket, ConjugateRhs,
1686
+ (sizeof(RhsScalar) == sizeof(std::complex<float>)) ? StorageOrder : ColMajor>(a0, b0, c0);
1687
+ }
1688
+
1689
+ #define GEMV_MULT_COMPLEX_COMPLEX_MMA(LhsType, RhsType) \
1690
+ template <typename ScalarPacket, typename LhsScalar, typename LhsPacket, typename SLhsPacket, typename RhsScalar, \
1691
+ typename RhsPacket, typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder> \
1692
+ EIGEN_ALWAYS_INLINE void gemv_mult_complex_MMA(LhsType& a0, RhsType* b, __vector_quad* c0) { \
1693
+ gemv_mult_complex_complex_MMA<ScalarPacket, LhsPacket, SLhsPacket, RhsScalar, ResPacket, ConjugateLhs, \
1694
+ ConjugateRhs, StorageOrder>(a0, b, c0); \
1695
+ }
1696
+
1697
+ GEMV_MULT_COMPLEX_COMPLEX_MMA(Packet2cf, std::complex<float>)
1698
+ GEMV_MULT_COMPLEX_COMPLEX_MMA(__vector_pair, std::complex<float>)
1699
+ GEMV_MULT_COMPLEX_COMPLEX_MMA(Packet1cd, std::complex<double>)
1700
+
1701
+ /** \internal core multiply operation for vectors (MMA version) - complex times complex */
1702
+ template <typename ScalarPacket, typename LhsScalar, typename LhsPacket, typename SLhsPacket, typename RhsScalar,
1703
+ typename RhsPacket, typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder>
1704
+ EIGEN_ALWAYS_INLINE void gemv_mult_complex_MMA(__vector_pair& a0, std::complex<double>* b, __vector_quad* c0) {
1705
+ if (sizeof(LhsScalar) == 16) {
1706
+ gemv_mult_complex_complex_MMA<ScalarPacket, LhsPacket, SLhsPacket, RhsScalar, ResPacket, ConjugateLhs, ConjugateRhs,
1707
+ StorageOrder>(a0, b, c0);
1708
+ } else {
1709
+ gemv_mult_real_complex_MMA<ScalarPacket, LhsPacket, SLhsPacket, RhsScalar, ResPacket, ConjugateLhs, ConjugateRhs,
1710
+ StorageOrder>(a0, b, c0);
1711
+ }
1712
+ }
1713
+
1714
+ #define GEMV_MULT_REAL_COMPLEX_MMA(LhsType, RhsType) \
1715
+ template <typename ScalarPacket, typename LhsScalar, typename LhsPacket, typename SLhsPacket, typename RhsScalar, \
1716
+ typename RhsPacket, typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder> \
1717
+ EIGEN_ALWAYS_INLINE void gemv_mult_complex_MMA(LhsType& a0, RhsType* b, __vector_quad* c0) { \
1718
+ gemv_mult_real_complex_MMA<ScalarPacket, LhsPacket, SLhsPacket, RhsScalar, ResPacket, ConjugateLhs, ConjugateRhs, \
1719
+ StorageOrder>(a0, b, c0); \
1720
+ }
1721
+
1722
+ GEMV_MULT_REAL_COMPLEX_MMA(Packet4f, std::complex<float>)
1723
+ GEMV_MULT_REAL_COMPLEX_MMA(Packet2d, std::complex<double>)
1724
+
1725
+ #define GEMV_MULT_COMPLEX_REAL_MMA(LhsType, RhsType) \
1726
+ template <typename ScalarPacket, typename LhsScalar, typename LhsPacket, typename SLhsPacket, typename RhsScalar, \
1727
+ typename RhsPacket, typename ResPacket, bool ConjugateLhs, bool ConjugateRhs, int StorageOrder> \
1728
+ EIGEN_ALWAYS_INLINE void gemv_mult_complex_MMA(LhsType& a0, RhsType* b, __vector_quad* c0) { \
1729
+ gemv_mult_complex_real_MMA<ScalarPacket, LhsPacket, SLhsPacket, RhsScalar, ResPacket, ConjugateLhs, ConjugateRhs, \
1730
+ StorageOrder>(a0, b, c0); \
1731
+ }
1732
+
1733
+ GEMV_MULT_COMPLEX_REAL_MMA(Packet2cf, float)
1734
+ GEMV_MULT_COMPLEX_REAL_MMA(Packet1cd, double)
1735
+ GEMV_MULT_COMPLEX_REAL_MMA(__vector_pair, float)
1736
+ GEMV_MULT_COMPLEX_REAL_MMA(__vector_pair, double)
1737
+
1738
+ /** \internal disassemble MMA accumulator results into packets */
1739
+ template <typename Scalar, typename ScalarPacket, typename LhsPacket, typename RhsPacket, bool ConjugateLhs,
1740
+ bool ConjugateRhs>
1741
+ EIGEN_ALWAYS_INLINE void disassembleResults2(__vector_quad* c0, PacketBlock<ScalarPacket, 4>& result0) {
1742
+ __builtin_mma_disassemble_acc(&result0.packet, c0);
1743
+ if (sizeof(LhsPacket) == 16) {
1744
+ if (sizeof(RhsPacket) == 16) {
1745
+ ScalarPacket tmp0, tmp2;
1746
+ tmp2 = vec_mergeh(result0.packet[2], result0.packet[3]);
1747
+ tmp0 = vec_mergeh(result0.packet[0], result0.packet[1]);
1748
+ result0.packet[3] = vec_mergel(result0.packet[3], result0.packet[2]);
1749
+ result0.packet[1] = vec_mergel(result0.packet[1], result0.packet[0]);
1750
+ result0.packet[2] = tmp2;
1751
+ result0.packet[0] = tmp0;
1752
+
1753
+ if (ConjugateLhs) {
1754
+ result0.packet[0] = pconj2(convertComplex(result0.packet[0])).v;
1755
+ result0.packet[2] = pconj2(convertComplex(result0.packet[2])).v;
1756
+ } else if (ConjugateRhs) {
1757
+ result0.packet[1] = pconj2(convertComplex(result0.packet[1])).v;
1758
+ result0.packet[3] = pconj2(convertComplex(result0.packet[3])).v;
1759
+ } else {
1760
+ result0.packet[1] = pconjinv(convertComplex(result0.packet[1])).v;
1761
+ result0.packet[3] = pconjinv(convertComplex(result0.packet[3])).v;
1762
+ }
1763
+ result0.packet[0] = vec_add(result0.packet[0], result0.packet[1]);
1764
+ result0.packet[2] = vec_add(result0.packet[2], result0.packet[3]);
1765
+ } else {
1766
+ result0.packet[0][1] = result0.packet[1][1];
1767
+ result0.packet[2][1] = result0.packet[3][1];
1768
+ }
1769
+ }
1770
+ }
1771
+
1772
+ template <typename Scalar, typename ScalarPacket, typename LhsPacket, typename RhsPacket, bool ConjugateLhs,
1773
+ bool ConjugateRhs>
1774
+ EIGEN_ALWAYS_INLINE void disassembleResults4(__vector_quad* c0, PacketBlock<ScalarPacket, 4>& result0) {
1775
+ __builtin_mma_disassemble_acc(&result0.packet, c0);
1776
+ if (GEMV_IS_COMPLEX_COMPLEX) {
1777
+ if (ConjugateLhs) {
1778
+ result0.packet[0] = pconj2(convertComplex(result0.packet[0])).v;
1779
+ result0.packet[1] = pcplxflip2(convertComplex(result0.packet[1])).v;
1780
+ } else {
1781
+ if (ConjugateRhs) {
1782
+ result0.packet[1] = pcplxconjflip(convertComplex(result0.packet[1])).v;
1783
+ } else {
1784
+ result0.packet[1] = pcplxflipconj(convertComplex(result0.packet[1])).v;
1785
+ }
1786
+ }
1787
+ result0.packet[0] = vec_add(result0.packet[0], result0.packet[1]);
1788
+ } else if (sizeof(LhsPacket) == sizeof(std::complex<float>)) {
1789
+ if (ConjugateLhs) {
1790
+ result0.packet[0] = pconj2(convertComplex(result0.packet[0])).v;
1791
+ }
1792
+ } else {
1793
+ result0.packet[0] = vec_mergee(result0.packet[0], result0.packet[1]);
1794
+ }
1795
+ }
1796
+
1797
+ template <typename Scalar, typename ScalarPacket, int ResPacketSize, typename LhsPacket, typename RhsPacket,
1798
+ bool ConjugateLhs, bool ConjugateRhs>
1799
+ EIGEN_ALWAYS_INLINE void disassembleResults(__vector_quad* c0, PacketBlock<ScalarPacket, 4>& result0) {
1800
+ if (!GEMV_IS_COMPLEX_FLOAT) {
1801
+ disassembleResults2<Scalar, ScalarPacket, LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs>(c0, result0);
1802
+ } else {
1803
+ disassembleResults4<Scalar, ScalarPacket, LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs>(c0, result0);
1804
+ }
1805
+ }
1806
+ #endif
1807
+
1808
+ #define GEMV_GETN_COMPLEX(N) (((N) * ResPacketSize) >> 1)
1809
+
1810
+ #define GEMV_LOADPACKET_COL_COMPLEX(iter) \
1811
+ loadLhsPacket<Scalar, LhsScalar, LhsMapper, PLhsPacket>(lhs, i + ((iter) * ResPacketSize), j)
1812
+
1813
+ #define GEMV_LOADPACKET_COL_COMPLEX_DATA(iter) convertReal(GEMV_LOADPACKET_COL_COMPLEX(iter))
1814
+
1815
+ #ifdef USE_GEMV_MMA
1816
+ #define GEMV_INIT_COL_COMPLEX_MMA(iter, N) \
1817
+ if (GEMV_GETN_COMPLEX(N) > iter) { \
1818
+ __builtin_mma_xxsetaccz(&e0##iter); \
1819
+ }
1820
+
1821
+ #if EIGEN_COMP_LLVM
1822
+ #define GEMV_LOADPAIR_COL_COMPLEX_MMA(iter1, iter2) \
1823
+ GEMV_BUILDPAIR_MMA(a##iter1, GEMV_LOADPACKET_COL_COMPLEX_DATA(iter2), \
1824
+ GEMV_LOADPACKET_COL_COMPLEX_DATA((iter2) + 1)); \
1825
+ EIGEN_UNUSED_VARIABLE(f##iter1);
1826
+ #else
1827
+ #define GEMV_LOADPAIR_COL_COMPLEX_MMA(iter1, iter2) \
1828
+ if (sizeof(LhsPacket) == 16) { \
1829
+ const LhsScalar& src = lhs(i + ((32 * iter1) / sizeof(LhsScalar)), j); \
1830
+ a##iter1 = *reinterpret_cast<__vector_pair*>(const_cast<LhsScalar*>(&src)); \
1831
+ EIGEN_UNUSED_VARIABLE(f##iter1); \
1832
+ } else { \
1833
+ f##iter1 = lhs.template load<PLhsPacket, Unaligned>(i + ((iter2) * ResPacketSize), j); \
1834
+ GEMV_BUILDPAIR_MMA(a##iter1, vec_splat(convertReal(f##iter1), 0), vec_splat(convertReal(f##iter1), 1)); \
1835
+ }
1836
+ #endif
1837
+
1838
+ #define GEMV_LOAD1_COL_COMPLEX_MMA(iter, N) \
1839
+ if (GEMV_GETN_COMPLEX(N) > iter) { \
1840
+ if (GEMV_IS_COMPLEX_FLOAT) { \
1841
+ f##iter = GEMV_LOADPACKET_COL_COMPLEX(iter); \
1842
+ EIGEN_UNUSED_VARIABLE(a##iter); \
1843
+ } else { \
1844
+ GEMV_LOADPAIR_COL_COMPLEX_MMA(iter, iter << 1) \
1845
+ } \
1846
+ } else { \
1847
+ EIGEN_UNUSED_VARIABLE(a##iter); \
1848
+ EIGEN_UNUSED_VARIABLE(f##iter); \
1849
+ }
1850
+
1851
+ #define GEMV_WORK1_COL_COMPLEX_MMA(iter, N) \
1852
+ if (GEMV_GETN_COMPLEX(N) > iter) { \
1853
+ if (GEMV_IS_COMPLEX_FLOAT) { \
1854
+ gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, PLhsPacket, RhsScalar, RhsPacket, ResPacket, \
1855
+ ConjugateLhs, ConjugateRhs, ColMajor>(f##iter, b, &e0##iter); \
1856
+ } else { \
1857
+ gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, __vector_pair, RhsScalar, RhsPacket, ResPacket, \
1858
+ ConjugateLhs, ConjugateRhs, ColMajor>(a##iter, b, &e0##iter); \
1859
+ } \
1860
+ }
1861
+
1862
+ #define GEMV_LOADPAIR2_COL_COMPLEX_MMA(iter1, iter2) \
1863
+ GEMV_BUILDPAIR_MMA(a##iter1, GEMV_LOADPACKET_COL_COMPLEX_DATA(iter2), GEMV_LOADPACKET_COL_COMPLEX_DATA((iter2) + 1));
1864
+
1865
+ #define GEMV_LOAD2_COL_COMPLEX_MMA(iter1, iter2, iter3, N) \
1866
+ if (GEMV_GETN_COMPLEX(N) > iter1) { \
1867
+ if (GEMV_IS_COMPLEX_FLOAT) { \
1868
+ GEMV_LOADPAIR2_COL_COMPLEX_MMA(iter2, iter2); \
1869
+ EIGEN_UNUSED_VARIABLE(a##iter3) \
1870
+ } else { \
1871
+ GEMV_LOADPAIR2_COL_COMPLEX_MMA(iter2, iter2 << 1); \
1872
+ GEMV_LOADPAIR2_COL_COMPLEX_MMA(iter3, iter3 << 1); \
1873
+ } \
1874
+ } else { \
1875
+ EIGEN_UNUSED_VARIABLE(a##iter2); \
1876
+ EIGEN_UNUSED_VARIABLE(a##iter3); \
1877
+ } \
1878
+ EIGEN_UNUSED_VARIABLE(f##iter2); \
1879
+ EIGEN_UNUSED_VARIABLE(f##iter3);
1880
+
1881
+ #define GEMV_WORK2_COL_COMPLEX_MMA(iter1, iter2, iter3, N) \
1882
+ if (GEMV_GETN_COMPLEX(N) > iter1) { \
1883
+ if (GEMV_IS_COMPLEX_FLOAT) { \
1884
+ PLhsPacket g[2]; \
1885
+ __builtin_vsx_disassemble_pair(reinterpret_cast<void*>(g), &a##iter2); \
1886
+ gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, PLhsPacket, RhsScalar, RhsPacket, ResPacket, \
1887
+ ConjugateLhs, ConjugateRhs, ColMajor>(g[0], b, &e0##iter2); \
1888
+ gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, PLhsPacket, RhsScalar, RhsPacket, ResPacket, \
1889
+ ConjugateLhs, ConjugateRhs, ColMajor>(g[1], b, &e0##iter3); \
1890
+ } else { \
1891
+ gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, __vector_pair, RhsScalar, RhsPacket, ResPacket, \
+ ConjugateLhs, ConjugateRhs, ColMajor>(a##iter2, b, &e0##iter2); \
+ gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, __vector_pair, RhsScalar, RhsPacket, ResPacket, \
+ ConjugateLhs, ConjugateRhs, ColMajor>(a##iter3, b, &e0##iter3); \
+ } \
+ }
+
+ #if EIGEN_COMP_LLVM
+ #define GEMV_LOAD_COL_COMPLEX_MMA(N) \
+ if (GEMV_GETN_COMPLEX(N) > 1) { \
+ GEMV_UNROLL_HALF(GEMV_LOAD2_COL_COMPLEX_MMA, (N >> 1)) \
+ } else { \
+ GEMV_UNROLL(GEMV_LOAD1_COL_COMPLEX_MMA, N) \
+ }
+
+ #define GEMV_WORK_COL_COMPLEX_MMA(N) \
+ if (GEMV_GETN_COMPLEX(N) > 1) { \
+ GEMV_UNROLL_HALF(GEMV_WORK2_COL_COMPLEX_MMA, (N >> 1)) \
+ } else { \
+ GEMV_UNROLL(GEMV_WORK1_COL_COMPLEX_MMA, N) \
+ }
+ #else
+ #define GEMV_LOAD_COL_COMPLEX_MMA(N) GEMV_UNROLL(GEMV_LOAD1_COL_COMPLEX_MMA, N)
+
+ #define GEMV_WORK_COL_COMPLEX_MMA(N) GEMV_UNROLL(GEMV_WORK1_COL_COMPLEX_MMA, N)
+ #endif
+
+ #define GEMV_DISASSEMBLE_COMPLEX_MMA(iter) \
+ disassembleResults<Scalar, ScalarPacket, ResPacketSize, LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs>( \
+ &e0##iter, result0##iter);
+
+ #define GEMV_STORE_COL_COMPLEX_MMA(iter, N) \
+ if (GEMV_GETN_COMPLEX(N) > iter) { \
+ GEMV_DISASSEMBLE_COMPLEX_MMA(iter); \
+ c0##iter = PResPacket(result0##iter.packet[0]); \
+ if (GEMV_IS_COMPLEX_FLOAT) { \
+ pstoreu_pmadd_complex<Scalar, ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData>( \
+ c0##iter, alpha_data, res + i + (iter * ResPacketSize)); \
+ } else { \
+ pstoreu_pmadd_complex<Scalar, ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData>( \
+ c0##iter, alpha_data, res + i + ((iter << 1) * ResPacketSize)); \
+ c0##iter = PResPacket(result0##iter.packet[2]); \
+ pstoreu_pmadd_complex<Scalar, ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData>( \
+ c0##iter, alpha_data, res + i + (((iter << 1) + 1) * ResPacketSize)); \
+ } \
+ }
+
+ #define GEMV_STORE2_COL_COMPLEX_MMA(iter1, iter2, iter3, N) \
+ if (GEMV_GETN_COMPLEX(N) > iter1) { \
+ GEMV_DISASSEMBLE_COMPLEX_MMA(iter2); \
+ GEMV_DISASSEMBLE_COMPLEX_MMA(iter3); \
+ c0##iter2 = PResPacket(result0##iter2.packet[0]); \
+ if (GEMV_IS_COMPLEX_FLOAT) { \
+ c0##iter3 = PResPacket(result0##iter3.packet[0]); \
+ pstoreu_pmadd_complex<ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData, ResPacketSize, iter2>( \
+ c0##iter2, c0##iter3, alpha_data, res + i); \
+ } else { \
+ c0##iter3 = PResPacket(result0##iter2.packet[2]); \
+ pstoreu_pmadd_complex<ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData, ResPacketSize, iter2 << 1>( \
+ c0##iter2, c0##iter3, alpha_data, res + i); \
+ c0##iter2 = PResPacket(result0##iter3.packet[0]); \
+ c0##iter3 = PResPacket(result0##iter3.packet[2]); \
+ pstoreu_pmadd_complex<ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData, ResPacketSize, iter3 << 1>( \
+ c0##iter2, c0##iter3, alpha_data, res + i); \
+ } \
+ }
+
+ #define GEMV_PROCESS_COL_COMPLEX_ONE_MMA(N) \
+ GEMV_UNROLL(GEMV_INIT_COL_COMPLEX_MMA, N) \
+ Index j = j2; \
+ do { \
+ const RhsScalar& b1 = rhs2(j, 0); \
+ RhsScalar* b = const_cast<RhsScalar*>(&b1); \
+ GEMV_UNROLL(GEMV_PREFETCH, N) \
+ GEMV_LOAD_COL_COMPLEX_MMA(N) \
+ GEMV_WORK_COL_COMPLEX_MMA(N) \
+ } while (++j < jend); \
+ if (GEMV_GETN(N) <= 2) { \
+ GEMV_UNROLL(GEMV_STORE_COL_COMPLEX_MMA, N) \
+ } else { \
+ GEMV_UNROLL_HALF(GEMV_STORE2_COL_COMPLEX_MMA, (N >> 1)) \
+ } \
+ i += (ResPacketSize * N);
+ #endif
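For readers following the macro expansion: GEMV_UNROLL(func, N) stamps out func(0, N) through func(7, N), and each func pastes its iteration index onto a register stem with the preprocessor's ## operator, so c0##iter names the distinct variables c00, c01, ... rather than array slots. A minimal, self-contained sketch of the same token-pasting idiom (all names here are illustrative, not from the package):

    #include <cstdio>

    #define UNROLL(func, N) func(0, N) func(1, N) func(2, N) func(3, N)

    // Pastes iter onto the stem "acc", naming distinct variables acc0..acc3.
    #define INIT(iter, N) \
      if (N > iter) {     \
        acc##iter = 0.0f; \
      }
    #define WORK(iter, N)     \
      if (N > iter) {         \
        acc##iter += x[iter]; \
      }

    int main() {
      float x[4] = {1.f, 2.f, 3.f, 4.f};
      float acc0, acc1, acc2, acc3;
      UNROLL(INIT, 4)
      UNROLL(WORK, 4)
      std::printf("%g %g %g %g\n", acc0, acc1, acc2, acc3);
    }

Because N is a literal at every expansion site, the `if (N > iter)` guards fold away at compile time, exactly as in the kernels above.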
+
+ #define GEMV_INIT_COMPLEX(iter, N) \
+ if (N > iter) { \
+ c0##iter = pset_zero<PResPacket>(); \
+ c1##iter = pset_init<ResPacket, LhsPacket, RhsPacket>(c1##iter); \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(c0##iter); \
+ EIGEN_UNUSED_VARIABLE(c1##iter); \
+ }
+
+ #define GEMV_WORK_COL_COMPLEX(iter, N) \
+ if (N > iter) { \
+ f##iter = GEMV_LOADPACKET_COL_COMPLEX(iter); \
+ gemv_mult_complex<ScalarPacket, PLhsPacket, RhsScalar, RhsPacket, PResPacket, ResPacket, ConjugateLhs, \
+ ConjugateRhs, ColMajor>(f##iter, b, c0##iter, c1##iter); \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(f##iter); \
+ }
+
+ #define GEMV_STORE_COL_COMPLEX(iter, N) \
+ if (N > iter) { \
+ if (GEMV_IS_COMPLEX_COMPLEX) { \
+ c0##iter = padd(c0##iter, c1##iter); \
+ } \
+ pstoreu_pmadd_complex<Scalar, ScalarPacket, PResPacket, ResPacket, ResScalar, AlphaData>( \
+ c0##iter, alpha_data, res + i + (iter * ResPacketSize)); \
+ }
+
+ /** \internal main macro for gemv_complex_col - initialize accumulators, multiply and add inputs, and store results */
+ #define GEMV_PROCESS_COL_COMPLEX_ONE(N) \
+ GEMV_UNROLL(GEMV_INIT_COMPLEX, N) \
+ Index j = j2; \
+ do { \
+ const RhsScalar& b1 = rhs2(j, 0); \
+ RhsScalar* b = const_cast<RhsScalar*>(&b1); \
+ GEMV_UNROLL(GEMV_PREFETCH, N) \
+ GEMV_UNROLL(GEMV_WORK_COL_COMPLEX, N) \
+ } while (++j < jend); \
+ GEMV_UNROLL(GEMV_STORE_COL_COMPLEX, N) \
+ i += (ResPacketSize * N);
+
+ #if defined(USE_GEMV_MMA) && (EIGEN_COMP_LLVM || defined(USE_SLOWER_GEMV_MMA))
+ #define USE_GEMV_COL_COMPLEX_MMA
+ #endif
+
+ #ifdef USE_GEMV_COL_COMPLEX_MMA
+ #define GEMV_PROCESS_COL_COMPLEX(N) GEMV_PROCESS_COL_COMPLEX_ONE_MMA(N)
+ #else
+ #if defined(USE_GEMV_MMA) && (__GNUC__ > 10)
+ #define GEMV_PROCESS_COL_COMPLEX(N) \
+ if (sizeof(Scalar) != sizeof(LhsPacket)) { \
+ GEMV_PROCESS_COL_COMPLEX_ONE_MMA(N) \
+ } else { \
+ GEMV_PROCESS_COL_COMPLEX_ONE(N) \
+ }
+ #else
+ #define GEMV_PROCESS_COL_COMPLEX(N) GEMV_PROCESS_COL_COMPLEX_ONE(N)
+ #endif
+ #endif
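Stripped of packets and macros, GEMV_PROCESS_COL_COMPLEX_ONE(N) follows the standard register-blocked column-major GEMV shape: N running sums for N adjacent row packets stay in registers, the inner loop streams the columns of the current block while broadcasting one rhs element per column, and alpha is applied once at store time. A hedged scalar sketch of that shape (illustrative names, no SIMD or conjugation):

    #include <cstddef>

    // One register block of N rows over columns [j2, jend); A is column-major
    // with leading dimension lda. Mirrors GEMV_INIT/WORK/STORE, scalar only.
    template <typename T, int N>
    void col_block_kernel(const T* A, std::size_t lda, const T* x, T* y,
                          std::size_t i, std::size_t j2, std::size_t jend, T alpha) {
      T acc[N] = {};                       // GEMV_INIT_*: zeroed accumulators
      for (std::size_t j = j2; j < jend; ++j) {
        const T b = x[j];                  // broadcast one rhs element
        for (int k = 0; k < N; ++k)        // unrolled via GEMV_UNROLL in the real code
          acc[k] += A[(i + k) + j * lda] * b;
      }
      for (int k = 0; k < N; ++k)          // GEMV_STORE_*: scale once, store once
        y[i + k] += alpha * acc[k];
    }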
+
+ template <typename Scalar, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, bool LhsIsReal,
+ typename RhsScalar, typename RhsMapper, bool ConjugateRhs, bool RhsIsReal, typename ResScalar>
+ EIGEN_STRONG_INLINE void gemv_complex_col(Index rows, Index cols, const LhsMapper& alhs, const RhsMapper& rhs,
+ ResScalar* res, Index resIncr, ResScalar alpha) {
+ typedef gemv_traits<LhsScalar, RhsScalar> Traits;
+
+ typedef typename Traits::LhsPacket LhsPacket;
+ typedef typename Traits::RhsPacket RhsPacket;
+ typedef typename Traits::ResPacket ResPacket;
+
+ typedef typename packet_traits<Scalar>::type ScalarPacket;
+ typedef typename packet_traits<LhsScalar>::type PLhsPacket;
+ typedef typename packet_traits<ResScalar>::type PResPacket;
+ typedef gemv_traits<ResPacket, ResPacket> PTraits;
+
+ EIGEN_UNUSED_VARIABLE(resIncr);
+ eigen_internal_assert(resIncr == 1);
+
+ // The following copy tells the compiler that lhs's attributes are not modified outside this function
+ // This helps GCC to generate proper code.
+ LhsMapper lhs(alhs);
+ RhsMapper rhs2(rhs);
+
+ conj_helper<LhsScalar, RhsScalar, ConjugateLhs, ConjugateRhs> cj;
+
+ const Index lhsStride = lhs.stride();
+ // TODO: for padded aligned inputs, we could enable aligned reads
+ enum {
+ LhsAlignment = Unaligned,
+ ResPacketSize = PTraits::ResPacketSize,
+ LhsPacketSize = PTraits::LhsPacketSize,
+ RhsPacketSize = PTraits::RhsPacketSize,
+ };
+ #ifdef EIGEN_POWER_USE_GEMV_PREFETCH
+ const Index prefetch_dist = 64 * LhsPacketSize;
+ #endif
+
+ #ifndef GCC_ONE_VECTORPAIR_BUG
+ const Index n8 = rows - 8 * ResPacketSize + 1;
+ const Index n4 = rows - 4 * ResPacketSize + 1;
+ const Index n2 = rows - 2 * ResPacketSize + 1;
+ #endif
+ const Index n1 = rows - 1 * ResPacketSize + 1;
+
+ // TODO: improve the following heuristic:
+ const Index block_cols = cols < 128 ? cols : (lhsStride * sizeof(LhsScalar) < 16000 ? 16 : 8);
+
+ typedef alpha_store<PResPacket, ResPacket, ResScalar, Scalar> AlphaData;
+ AlphaData alpha_data(alpha);
+
+ for (Index j2 = 0; j2 < cols; j2 += block_cols) {
+ Index jend = numext::mini(j2 + block_cols, cols);
+ Index i = 0;
+ PResPacket c00, c01, c02, c03, c04, c05, c06, c07;
+ ResPacket c10, c11, c12, c13, c14, c15, c16, c17;
+ PLhsPacket f0, f1, f2, f3, f4, f5, f6, f7;
+ #ifdef USE_GEMV_MMA
+ __vector_quad e00, e01, e02, e03, e04, e05, e06, e07;
+ __vector_pair a0, a1, a2, a3, a4, a5, a6, a7;
+ PacketBlock<ScalarPacket, 4> result00, result01, result02, result03, result04, result05, result06, result07;
+ GEMV_UNUSED(8, e0)
+ GEMV_UNUSED(8, result0)
+ GEMV_UNUSED(8, a)
+ GEMV_UNUSED(8, f)
+ #if !defined(GCC_ONE_VECTORPAIR_BUG) && defined(USE_GEMV_COL_COMPLEX_MMA)
+ if (GEMV_IS_COMPLEX_COMPLEX || !GEMV_IS_COMPLEX_FLOAT)
+ #endif
+ #endif
+ #ifndef GCC_ONE_VECTORPAIR_BUG
+ {
+ while (i < n8) {
+ GEMV_PROCESS_COL_COMPLEX(8)
+ }
+ }
+ while (i < n4) {
+ GEMV_PROCESS_COL_COMPLEX(4)
+ }
+ if (i < n2) {
+ GEMV_PROCESS_COL_COMPLEX(2)
+ }
+ if (i < n1)
+ #else
+ while (i < n1)
+ #endif
+ {
+ GEMV_PROCESS_COL_COMPLEX_ONE(1)
+ }
+ for (; i < rows; ++i) {
+ ResScalar d0(0);
+ Index j = j2;
+ do {
+ d0 += cj.pmul(lhs(i, j), rhs2(j, 0));
+ } while (++j < jend);
+ res[i] += alpha * d0;
+ }
+ }
+ }
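Behaviourally, gemv_complex_col computes res += alpha * A * x for a column-major A, blocking over block_cols columns so a slice of the rhs stays cache-resident while row strips of decreasing width (8, 4, 2, 1 packets, then scalars) sweep down the matrix; the scalar tail loop above is the authoritative statement of the semantics. A hedged reference version, assuming the unit result stride the function asserts:

    #include <cstddef>

    // Reference semantics only (no blocking, packets, or conjugation options):
    // res[i] += alpha * sum_j A(i, j) * x[j], A column-major with stride lda.
    template <typename Scalar>
    void gemv_col_reference(std::size_t rows, std::size_t cols, const Scalar* A,
                            std::size_t lda, const Scalar* x, Scalar* res,
                            Scalar alpha) {
      for (std::size_t i = 0; i < rows; ++i) {
        Scalar d0(0);
        for (std::size_t j = 0; j < cols; ++j) d0 += A[i + j * lda] * x[j];
        res[i] += alpha * d0;
      }
    }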
+
+ template <typename Scalar, int N>
+ struct ScalarBlock {
+ Scalar scalar[N];
+ };
+
+ #ifdef USE_GEMV_MMA
+ static Packet16uc p16uc_ELEMENT_3 = {0x0c, 0x0d, 0x0e, 0x0f, 0x1c, 0x1d, 0x1e, 0x1f,
+ 0x0c, 0x0d, 0x0e, 0x0f, 0x1c, 0x1d, 0x1e, 0x1f};
+
+ /** \internal predux (add elements of a vector) from a MMA accumulator - real results */
+ template <typename ResScalar, typename ResPacket>
+ EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_real(__vector_quad* acc0, __vector_quad* acc1) {
+ PacketBlock<ResPacket, 4> result0, result1;
+ __builtin_mma_disassemble_acc(&result0.packet, acc0);
+ __builtin_mma_disassemble_acc(&result1.packet, acc1);
+ result0.packet[0] = vec_mergeh(result0.packet[0], result1.packet[0]);
+ result0.packet[1] = vec_mergeo(result0.packet[1], result1.packet[1]);
+ result0.packet[2] = vec_mergel(result0.packet[2], result1.packet[2]);
+ result0.packet[3] = vec_perm(result0.packet[3], result1.packet[3], p16uc_ELEMENT_3);
+ result0.packet[0] =
+ vec_add(vec_add(result0.packet[0], result0.packet[2]), vec_add(result0.packet[1], result0.packet[3]));
+ return *reinterpret_cast<ScalarBlock<ResScalar, 2>*>(&result0.packet[0]);
+ }
+
+ template <>
+ EIGEN_ALWAYS_INLINE ScalarBlock<double, 2> predux_real<double, Packet2d>(__vector_quad* acc0, __vector_quad* acc1) {
+ PacketBlock<Packet2d, 4> result0, result1;
+ __builtin_mma_disassemble_acc(&result0.packet, acc0);
+ __builtin_mma_disassemble_acc(&result1.packet, acc1);
+ result0.packet[0] =
+ vec_add(vec_mergeh(result0.packet[0], result1.packet[0]), vec_mergel(result0.packet[1], result1.packet[1]));
+ return *reinterpret_cast<ScalarBlock<double, 2>*>(&result0.packet[0]);
+ }
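An MMA __vector_quad holds a 4x4 tile of partial sums, and __builtin_mma_disassemble_acc unpacks it into four vector registers. Because the row kernel feeds the accumulator rank-1 updates of the form tile[r][c] += rhs[r] * lhs[c], the dot product for a row ends up on the tile diagonal; the mergeh/mergeo/mergel/perm sequence gathers the diagonals of two tiles side by side so a single vec_add pass finishes both reductions. A portable sketch of what the shuffles compute, under that reading of the accumulator layout (plain arrays, purely illustrative):

    #include <cstdio>

    // Each tile accumulates tile[r][c] += rhs[r] * lhs[c] per column step, so a
    // row's dot product is the sum of its tile's diagonal (4 floats per tile).
    static float reduce_tile_diag(const float acc[4][4]) {
      return acc[0][0] + acc[1][1] + acc[2][2] + acc[3][3];
    }

    int main() {
      float a0[4][4] = {{1, 0, 0, 0}, {0, 2, 0, 0}, {0, 0, 3, 0}, {0, 0, 0, 4}};
      float a1[4][4] = {};
      float cc[2] = {reduce_tile_diag(a0), reduce_tile_diag(a1)};  // ScalarBlock-style pair
      std::printf("%g %g\n", cc[0], cc[1]);                        // prints: 10 0
    }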
+
+ /** \internal add complex results together */
+ template <typename LhsPacket, typename RhsPacket, bool ConjugateLhs, bool ConjugateRhs>
+ EIGEN_ALWAYS_INLINE ScalarBlock<std::complex<float>, 2> addComplexResults(PacketBlock<Packet4f, 4>& result0,
+ PacketBlock<Packet4f, 4>& result1) {
+ ScalarBlock<std::complex<float>, 2> cc0;
+ result0.packet[0] = reinterpret_cast<Packet4f>(
+ vec_mergeh(reinterpret_cast<Packet2d>(result0.packet[0]), reinterpret_cast<Packet2d>(result1.packet[0])));
+ result0.packet[2] = reinterpret_cast<Packet4f>(
+ vec_mergel(reinterpret_cast<Packet2d>(result0.packet[2]), reinterpret_cast<Packet2d>(result1.packet[2])));
+ result0.packet[0] = vec_add(result0.packet[0], result0.packet[2]);
+ if (GEMV_IS_COMPLEX_COMPLEX) {
+ result0.packet[1] = reinterpret_cast<Packet4f>(
+ vec_mergeh(reinterpret_cast<Packet2d>(result0.packet[1]), reinterpret_cast<Packet2d>(result1.packet[1])));
+ result0.packet[3] = reinterpret_cast<Packet4f>(
+ vec_mergel(reinterpret_cast<Packet2d>(result0.packet[3]), reinterpret_cast<Packet2d>(result1.packet[3])));
+ result0.packet[1] = vec_add(result0.packet[1], result0.packet[3]);
+ if (ConjugateLhs) {
+ result0.packet[0] = pconj2(convertComplex(result0.packet[0])).v;
+ result0.packet[1] = pcplxflip2(convertComplex(result0.packet[1])).v;
+ } else if (ConjugateRhs) {
+ result0.packet[1] = pcplxconjflip(convertComplex(result0.packet[1])).v;
+ } else {
+ result0.packet[1] = pcplxflipconj(convertComplex(result0.packet[1])).v;
+ }
+ result0.packet[0] = vec_add(result0.packet[0], result0.packet[1]);
+ } else {
+ if (ConjugateLhs && (sizeof(LhsPacket) == sizeof(std::complex<float>))) {
+ result0.packet[0] = pconj2(convertComplex(result0.packet[0])).v;
+ }
+ }
+ cc0.scalar[0].real(result0.packet[0][0]);
+ cc0.scalar[0].imag(result0.packet[0][1]);
+ cc0.scalar[1].real(result0.packet[0][2]);
+ cc0.scalar[1].imag(result0.packet[0][3]);
+ return cc0;
+ }
+
+ template <typename LhsPacket, typename RhsPacket, bool ConjugateLhs, bool ConjugateRhs>
+ EIGEN_ALWAYS_INLINE ScalarBlock<std::complex<double>, 2> addComplexResults(PacketBlock<Packet2d, 4>&,
+ PacketBlock<Packet2d, 4>&) {
+ ScalarBlock<std::complex<double>, 2> cc0;
+ EIGEN_UNUSED_VARIABLE(cc0);
+ return cc0; // Just for compilation
+ }
+
+ /** \internal predux (add elements of a vector) from a MMA accumulator - complex results */
+ template <typename ResScalar, typename ResPacket, typename LhsPacket, typename RhsPacket, bool ConjugateLhs,
+ bool ConjugateRhs>
+ EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_complex(__vector_quad* acc0, __vector_quad* acc1) {
+ PacketBlock<ResPacket, 4> result0, result1;
+ __builtin_mma_disassemble_acc(&result0.packet, acc0);
+ __builtin_mma_disassemble_acc(&result1.packet, acc1);
+ return addComplexResults<LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs>(result0, result1);
+ }
+
+ template <typename ResScalar, typename ResPacket>
+ EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_real(__vector_quad* acc0) {
+ PacketBlock<ResPacket, 4> result0;
+ __builtin_mma_disassemble_acc(&result0.packet, acc0);
+ result0.packet[0] =
+ vec_add(vec_mergeh(result0.packet[0], result0.packet[2]), vec_mergel(result0.packet[1], result0.packet[3]));
+ return *reinterpret_cast<ScalarBlock<ResScalar, 2>*>(&result0.packet[0]);
+ }
+
+ template <typename ResScalar, typename ResPacket, typename LhsPacket, typename RhsPacket, bool ConjugateLhs,
+ bool ConjugateRhs>
+ EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_complex(__vector_quad* acc0) {
+ ScalarBlock<ResScalar, 2> cc0;
+ PacketBlock<ResPacket, 4> result0;
+ __builtin_mma_disassemble_acc(&result0.packet, acc0);
+ if (GEMV_IS_COMPLEX_COMPLEX) {
+ if (ConjugateLhs) {
+ result0.packet[1] = pconjinv(convertComplex(result0.packet[1])).v;
+ result0.packet[3] = pconjinv(convertComplex(result0.packet[3])).v;
+ } else if (ConjugateRhs) {
+ result0.packet[0] = pconj2(convertComplex(result0.packet[0])).v;
+ result0.packet[2] = pconj2(convertComplex(result0.packet[2])).v;
+ } else {
+ result0.packet[1] = pconj2(convertComplex(result0.packet[1])).v;
+ result0.packet[3] = pconj2(convertComplex(result0.packet[3])).v;
+ }
+ result0.packet[0] = vec_add(result0.packet[0], __builtin_vsx_xxpermdi(result0.packet[1], result0.packet[1], 2));
+ result0.packet[2] = vec_add(result0.packet[2], __builtin_vsx_xxpermdi(result0.packet[3], result0.packet[3], 2));
+ } else {
+ result0.packet[0] = __builtin_vsx_xxpermdi(result0.packet[0], result0.packet[1], 1);
+ result0.packet[2] = __builtin_vsx_xxpermdi(result0.packet[2], result0.packet[3], 1);
+ }
+ cc0.scalar[0].real(result0.packet[0][0]);
+ cc0.scalar[0].imag(result0.packet[0][1]);
+ cc0.scalar[1].real(result0.packet[2][0]);
+ cc0.scalar[1].imag(result0.packet[2][1]);
+ return cc0;
+ }
+ #endif
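The ConjugateLhs/ConjugateRhs branches above all reduce to the question of which operand of each product is conjugated before the real and imaginary partial sums are recombined; the pconj2/pcplxflip* helpers implement those sign flips on packed lanes. As a plain-scalar reference for the cases being handled (std::conj only, none of the packet shuffles):

    #include <complex>

    // Scalar model of the conjugation cases folded into the recombination step:
    // out = op(a) * op(b), where op is conj() or identity per template flag.
    template <bool ConjugateLhs, bool ConjugateRhs>
    std::complex<float> cmul(std::complex<float> a, std::complex<float> b) {
      if (ConjugateLhs) a = std::conj(a);
      if (ConjugateRhs) b = std::conj(b);
      return a * b;  // (ar*br - ai*bi) + i(ar*bi + ai*br), signs set by conj
    }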
+
+ template <typename ResScalar, typename ResPacket>
+ EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_real(ResPacket& a, ResPacket& b) {
+ ScalarBlock<ResScalar, 2> cc0;
+ cc0.scalar[0] = predux(a);
+ cc0.scalar[1] = predux(b);
+ return cc0;
+ }
+
+ template <typename ResScalar, typename ResPacket>
+ EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_complex(ResPacket& a, ResPacket& b) {
+ return predux_real<ResScalar, ResPacket>(a, b);
+ }
+
+ #define GEMV_UNROLL_ROW(func, N) func(0, N) func(1, N) func(2, N) func(3, N) func(4, N) func(5, N) func(6, N) func(7, N)
+
+ #define GEMV_UNROLL_ROW_HALF(func, N) func(0, 0, 1, N) func(1, 2, 3, N) func(2, 4, 5, N) func(3, 6, 7, N)
+
+ #define GEMV_LOADPACKET_ROW(iter) lhs.template load<LhsPacket, Unaligned>(i + (iter), j)
+
+ #ifdef USE_GEMV_MMA
+ #define GEMV_UNROLL3_ROW(func, N, which) \
+ func(0, N, which) func(1, N, which) func(2, N, which) func(3, N, which) func(4, N, which) func(5, N, which) \
+ func(6, N, which) func(7, N, which)
+
+ #define GEMV_UNUSED_ROW(N, which) GEMV_UNROLL3_ROW(GEMV_UNUSED_VAR, N, which)
+
+ #define GEMV_INIT_ROW(iter, N) \
+ if (GEMV_GETN(N) > iter) { \
+ __builtin_mma_xxsetaccz(&c##iter); \
+ }
+
+ #define GEMV_LOADPAIR_ROW(iter1, iter2) \
+ GEMV_BUILDPAIR_MMA(b##iter1, GEMV_LOADPACKET_ROW(iter2), GEMV_LOADPACKET_ROW((iter2) + 1));
+
+ #define GEMV_WORK_ROW(iter, N) \
+ if (GEMV_GETN(N) > iter) { \
+ if (GEMV_IS_FLOAT) { \
+ pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&c##iter, a0, GEMV_LOADPACKET_ROW(iter)); \
+ } else { \
+ __vector_pair b##iter; \
+ GEMV_LOADPAIR_ROW(iter, iter << 1) \
+ pger_vecMMA_acc<LhsPacket, RhsPacket, true>(&c##iter, b##iter, a0); \
+ } \
+ }
+
+ #define GEMV_PREDUX2(iter1, iter2, iter3, N) \
+ if (N > iter1) { \
+ if (GEMV_IS_FLOAT) { \
+ cc##iter1 = predux_real<ResScalar, ResPacket>(&c##iter2, &c##iter3); \
+ } else { \
+ cc##iter1 = predux_real<ResScalar, ResPacket>(&c##iter1); \
+ } \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(cc##iter1); \
+ }
+ #else
+ #define GEMV_INIT_ROW(iter, N) \
+ if (N > iter) { \
+ c##iter = pset1<ResPacket>(ResScalar(0)); \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(c##iter); \
+ }
+
+ #define GEMV_WORK_ROW(iter, N) \
+ if (N > iter) { \
+ c##iter = pcj.pmadd(GEMV_LOADPACKET_ROW(iter), a0, c##iter); \
+ }
+
+ #define GEMV_PREDUX2(iter1, iter2, iter3, N) \
+ if (N > iter1) { \
+ cc##iter1 = predux_real<ResScalar, ResPacket>(c##iter2, c##iter3); \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(cc##iter1); \
+ }
+ #endif
+
+ #define GEMV_MULT(iter1, iter2, iter3, N) \
+ if (N > iter1) { \
+ cc##iter1.scalar[0] += cj.pmul(lhs(i + iter2, j), a0); \
+ cc##iter1.scalar[1] += cj.pmul(lhs(i + iter3, j), a0); \
+ }
+
+ #define GEMV_STORE_ROW(iter1, iter2, iter3, N) \
+ if (N > iter1) { \
+ storeMaddData<ResScalar>(res + ((i + iter2) * resIncr), alpha, cc##iter1.scalar[0]); \
+ storeMaddData<ResScalar>(res + ((i + iter3) * resIncr), alpha, cc##iter1.scalar[1]); \
+ }
+
+ /** \internal main macro for gemv_row - initialize accumulators, multiply and add inputs, predux and store results */
+ #define GEMV_PROCESS_ROW(N) \
+ for (; i < n##N; i += N) { \
+ GEMV_UNROLL_ROW(GEMV_INIT_ROW, N) \
+ Index j = 0; \
+ for (; j + LhsPacketSize <= cols; j += LhsPacketSize) { \
+ RhsPacket a0 = rhs2.template load<RhsPacket, Unaligned>(j); \
+ GEMV_UNROLL_ROW(GEMV_WORK_ROW, N) \
+ } \
+ GEMV_UNROLL_ROW_HALF(GEMV_PREDUX2, (N >> 1)) \
+ for (; j < cols; ++j) { \
+ RhsScalar a0 = rhs2(j); \
+ GEMV_UNROLL_ROW_HALF(GEMV_MULT, (N >> 1)) \
+ } \
+ GEMV_UNROLL_ROW_HALF(GEMV_STORE_ROW, (N >> 1)) \
+ }
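The row-major kernel mirrors the column one: each of N rows keeps its own accumulator, the inner loop advances j in packet-sized steps doing fused multiply-adds against the shared rhs packet, and the leftover columns plus the final scaled store are handled per row pair (hence the _HALF unrolls). A hedged scalar rendering of GEMV_PROCESS_ROW's structure (illustrative, row-major A, no packets):

    #include <cstddef>

    // Dot products for N rows at once; the packet loop and scalar tail of the
    // real macro are collapsed into one j loop here.
    template <typename T, int N>
    void row_block_kernel(const T* A, std::size_t lda, const T* x, T* y,
                          std::size_t i, std::size_t cols, T alpha,
                          std::size_t resIncr) {
      T acc[N] = {};                          // GEMV_INIT_ROW
      for (std::size_t j = 0; j < cols; ++j)  // packet steps + tail in the real code
        for (int k = 0; k < N; ++k)           // GEMV_WORK_ROW / GEMV_MULT
          acc[k] += A[(i + k) * lda + j] * x[j];
      for (int k = 0; k < N; ++k)             // GEMV_STORE_ROW via storeMaddData
        y[(i + k) * resIncr] += alpha * acc[k];
    }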
+
+ template <typename LhsScalar, typename LhsMapper, typename RhsScalar, typename RhsMapper, typename ResScalar>
+ EIGEN_STRONG_INLINE void gemv_row(Index rows, Index cols, const LhsMapper& alhs, const RhsMapper& rhs, ResScalar* res,
+ Index resIncr, ResScalar alpha) {
+ typedef gemv_traits<LhsScalar, RhsScalar> Traits;
+
+ typedef typename Traits::LhsPacket LhsPacket;
+ typedef typename Traits::RhsPacket RhsPacket;
+ typedef typename Traits::ResPacket ResPacket;
+
+ // The following copy tells the compiler that lhs's attributes are not modified outside this function
+ // This helps GCC to generate proper code.
+ LhsMapper lhs(alhs);
+ typename RhsMapper::LinearMapper rhs2 = rhs.getLinearMapper(0, 0);
+
+ eigen_internal_assert(rhs.stride() == 1);
+ conj_helper<LhsScalar, RhsScalar, false, false> cj;
+ conj_helper<LhsPacket, RhsPacket, false, false> pcj;
+
+ // TODO: fine tune the following heuristic. The rationale is that if the matrix is very large,
+ // processing 8 rows at once might be counter productive wrt cache.
+ #ifndef GCC_ONE_VECTORPAIR_BUG
+ const Index n8 = lhs.stride() * sizeof(LhsScalar) > 32000 ? (rows - 7) : (rows - 7);
+ const Index n4 = rows - 3;
+ const Index n2 = rows - 1;
+ #endif
+
+ // TODO: for padded aligned inputs, we could enable aligned reads
+ enum {
+ LhsAlignment = Unaligned,
+ ResPacketSize = Traits::ResPacketSize,
+ LhsPacketSize = Traits::LhsPacketSize,
+ RhsPacketSize = Traits::RhsPacketSize,
+ };
+
+ Index i = 0;
+ #ifdef USE_GEMV_MMA
+ __vector_quad c0, c1, c2, c3, c4, c5, c6, c7;
+ GEMV_UNUSED_ROW(8, c)
+ #else
+ ResPacket c0, c1, c2, c3, c4, c5, c6, c7;
+ #endif
+ #ifndef GCC_ONE_VECTORPAIR_BUG
+ ScalarBlock<ResScalar, 2> cc0, cc1, cc2, cc3;
+ GEMV_PROCESS_ROW(8)
+ GEMV_PROCESS_ROW(4)
+ GEMV_PROCESS_ROW(2)
+ #endif
+ for (; i < rows; ++i) {
+ ResPacket d0 = pset1<ResPacket>(ResScalar(0));
+ Index j = 0;
+ for (; j + LhsPacketSize <= cols; j += LhsPacketSize) {
+ RhsPacket b0 = rhs2.template load<RhsPacket, Unaligned>(j);
+
+ d0 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + 0, j), b0, d0);
+ }
+ ResScalar dd0 = predux(d0);
+ for (; j < cols; ++j) {
+ dd0 += cj.pmul(lhs(i, j), rhs2(j));
+ }
+ res[i * resIncr] += alpha * dd0;
+ }
+ }
+
+ #define EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL(Scalar) \
+ template <typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
+ struct general_matrix_vector_product<Index, Scalar, LhsMapper, ColMajor, ConjugateLhs, Scalar, RhsMapper, \
+ ConjugateRhs, Version> { \
+ typedef typename ScalarBinaryOpTraits<Scalar, Scalar>::ReturnType ResScalar; \
+ \
+ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run(Index rows, Index cols, const LhsMapper& lhs, \
+ const RhsMapper& rhs, ResScalar* res, Index resIncr, \
+ ResScalar alpha) { \
+ gemv_col<Scalar, LhsMapper, Scalar, RhsMapper, ResScalar>(rows, cols, lhs, rhs, res, resIncr, alpha); \
+ } \
+ };
+
+ #define EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW(Scalar) \
+ template <typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
+ struct general_matrix_vector_product<Index, Scalar, LhsMapper, RowMajor, ConjugateLhs, Scalar, RhsMapper, \
+ ConjugateRhs, Version> { \
+ typedef typename ScalarBinaryOpTraits<Scalar, Scalar>::ReturnType ResScalar; \
+ \
+ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run(Index rows, Index cols, const LhsMapper& lhs, \
+ const RhsMapper& rhs, ResScalar* res, Index resIncr, \
+ ResScalar alpha) { \
+ gemv_row<Scalar, LhsMapper, Scalar, RhsMapper, ResScalar>(rows, cols, lhs, rhs, res, resIncr, alpha); \
+ } \
+ };
+
+ EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL(float)
+ EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL(double)
+ EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW(float)
+ EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW(double)
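These macros stamp out partial specializations of Eigen's internal general_matrix_vector_product entry point, one per scalar type and storage order, each forwarding to the matching hand-written kernel; every combination not specialized here keeps the generic Eigen implementation. The mechanism is ordinary template specialization, sketched below with generic stand-in names (not the Eigen internals):

    #include <cstdio>

    enum StorageOrder { ColMajorOrder, RowMajorOrder };

    // Primary template: the generic fallback kernel.
    template <typename Scalar, int Order>
    struct matvec_dispatch {
      static void run() { std::printf("generic kernel\n"); }
    };

    // Specializations play the role of EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL/ROW:
    template <>
    struct matvec_dispatch<float, ColMajorOrder> {
      static void run() { std::printf("float column-major kernel\n"); }
    };
    template <>
    struct matvec_dispatch<float, RowMajorOrder> {
      static void run() { std::printf("float row-major kernel\n"); }
    };

    int main() {
      matvec_dispatch<float, ColMajorOrder>::run();   // picks the specialization
      matvec_dispatch<double, RowMajorOrder>::run();  // falls back to the primary
    }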
+
+ #ifdef USE_GEMV_MMA
+ #define gemv_bf16_col gemvMMA_bfloat16_col
+ #define gemv_bf16_row gemvMMA_bfloat16_row
+ #else
+ #define gemv_bf16_col gemv_bfloat16_col
+ #define gemv_bf16_row gemv_bfloat16_row
+ #endif
+
+ #define EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL_BFLOAT16() \
+ template <typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
+ struct general_matrix_vector_product<Index, bfloat16, LhsMapper, ColMajor, ConjugateLhs, bfloat16, RhsMapper, \
+ ConjugateRhs, Version> { \
+ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run(Index rows, Index cols, const LhsMapper& lhs, \
+ const RhsMapper& rhs, bfloat16* res, Index resIncr, \
+ bfloat16 alpha) { \
+ gemv_bf16_col<LhsMapper, RhsMapper>(rows, cols, lhs, rhs, res, resIncr, alpha); \
+ } \
+ };
+
+ #define EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW_BFLOAT16() \
+ template <typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
+ struct general_matrix_vector_product<Index, bfloat16, LhsMapper, RowMajor, ConjugateLhs, bfloat16, RhsMapper, \
+ ConjugateRhs, Version> { \
+ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run(Index rows, Index cols, const LhsMapper& lhs, \
+ const RhsMapper& rhs, bfloat16* res, Index resIncr, \
+ bfloat16 alpha) { \
+ gemv_bf16_row<LhsMapper, RhsMapper>(rows, cols, lhs, rhs, res, resIncr, alpha); \
+ } \
+ };
+
+ EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL_BFLOAT16()
+ EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW_BFLOAT16()
+
+ template <typename ResScalar, typename PResPacket, typename ResPacket, typename LhsPacket, typename RhsPacket>
+ EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_complex(PResPacket& a0, PResPacket& b0, ResPacket& a1,
+ ResPacket& b1) {
+ if (GEMV_IS_COMPLEX_COMPLEX) {
+ a0 = padd(a0, a1);
+ b0 = padd(b0, b1);
+ }
+ return predux_complex<ResScalar, PResPacket>(a0, b0);
+ }
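A recurring pattern in the complex kernels: each unrolled row carries two accumulators (the c0/a0 and c1/a1 pairs), and only in the complex-lhs-times-complex-rhs case does the second one hold live data, which the padd above folds back in before the horizontal reduction. One plausible reading, hedged and purely illustrative, is that the pair lets the two halves of each complex product accumulate independently:

    #include <complex>

    // Accumulate sum_j a[j]*b[j] with two partial sums: acc0 takes real(a)*b,
    // acc1 takes i*imag(a)*b; they are combined once, like padd(c0, c1).
    std::complex<float> dot_split(const std::complex<float>* a,
                                  const std::complex<float>* b, int n) {
      std::complex<float> acc0(0), acc1(0);
      for (int j = 0; j < n; ++j) {
        acc0 += a[j].real() * b[j];
        acc1 += std::complex<float>(0.0f, a[j].imag()) * b[j];
      }
      return acc0 + acc1;  // a*b = real(a)*b + i*imag(a)*b
    }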
+
+ #define GEMV_LOADPACKET_ROW_COMPLEX(iter) loadLhsPacket<Scalar, LhsScalar, LhsMapper, PLhsPacket>(lhs, i + (iter), j)
+
+ #define GEMV_LOADPACKET_ROW_COMPLEX_DATA(iter) convertReal(GEMV_LOADPACKET_ROW_COMPLEX(iter))
+
+ #define GEMV_PROCESS_ROW_COMPLEX_SINGLE_WORK(which, N) \
+ j = 0; \
+ for (; j + LhsPacketSize <= cols; j += LhsPacketSize) { \
+ const RhsScalar& b1 = rhs2(j); \
+ RhsScalar* b = const_cast<RhsScalar*>(&b1); \
+ GEMV_UNROLL_ROW(which, N) \
+ }
+
+ #define GEMV_PROCESS_END_ROW_COMPLEX(N) \
+ for (; j < cols; ++j) { \
+ RhsScalar b0 = rhs2(j); \
+ GEMV_UNROLL_ROW_HALF(GEMV_MULT_COMPLEX, (N >> 1)) \
+ } \
+ GEMV_UNROLL_ROW_HALF(GEMV_STORE_ROW_COMPLEX, (N >> 1))
+
+ #ifdef USE_GEMV_MMA
+ #define GEMV_INIT_ROW_COMPLEX_MMA(iter, N) \
+ if (GEMV_GETN_COMPLEX(N) > iter) { \
+ __builtin_mma_xxsetaccz(&e0##iter); \
+ }
+
+ #define GEMV_LOADPAIR_ROW_COMPLEX_MMA(iter1, iter2) \
+ GEMV_BUILDPAIR_MMA(a##iter1, GEMV_LOADPACKET_ROW_COMPLEX_DATA(iter2), GEMV_LOADPACKET_ROW_COMPLEX_DATA((iter2) + 1));
+
+ #define GEMV_WORK_ROW_COMPLEX_MMA(iter, N) \
+ if (GEMV_GETN_COMPLEX(N) > iter) { \
+ if (GEMV_IS_COMPLEX_FLOAT) { \
+ PLhsPacket a##iter = GEMV_LOADPACKET_ROW_COMPLEX(iter); \
+ gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, PLhsPacket, RhsScalar, RhsPacket, ResPacket, \
+ ConjugateLhs, ConjugateRhs, RowMajor>(a##iter, b, &e0##iter); \
+ } else { \
+ __vector_pair a##iter; \
+ GEMV_LOADPAIR_ROW_COMPLEX_MMA(iter, iter << 1) \
+ gemv_mult_complex_MMA<ScalarPacket, LhsScalar, PLhsPacket, __vector_pair, RhsScalar, RhsPacket, ResPacket, \
+ ConjugateLhs, ConjugateRhs, RowMajor>(a##iter, b, &e0##iter); \
+ } \
+ }
+
+ #define GEMV_PREDUX4_COMPLEX_MMA(iter1, iter2, iter3, N) \
+ if (N > iter1) { \
+ if (GEMV_IS_COMPLEX_FLOAT) { \
+ cc##iter1 = predux_complex<ResScalar, ScalarPacket, LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs>( \
+ &e0##iter2, &e0##iter3); \
+ } else { \
+ cc##iter1 = \
+ predux_complex<ResScalar, ScalarPacket, LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs>(&e0##iter1); \
+ } \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(cc##iter1); \
+ }
+
+ #define GEMV_PROCESS_ROW_COMPLEX_SINGLE_MMA(N) \
+ GEMV_UNROLL_ROW(GEMV_INIT_ROW_COMPLEX_MMA, N) \
+ GEMV_PROCESS_ROW_COMPLEX_SINGLE_WORK(GEMV_WORK_ROW_COMPLEX_MMA, N)
+
+ #define GEMV_PROCESS_ROW_COMPLEX_ONE_MMA(N) \
+ for (; i < n##N; i += N) { \
+ GEMV_PROCESS_ROW_COMPLEX_SINGLE_MMA(N) \
+ GEMV_UNROLL_ROW_HALF(GEMV_PREDUX4_COMPLEX_MMA, (N >> 1)) \
+ GEMV_PROCESS_END_ROW_COMPLEX(N); \
+ }
+ #endif
+
+ #define GEMV_WORK_ROW_COMPLEX(iter, N) \
+ if (N > iter) { \
+ PLhsPacket a##iter = GEMV_LOADPACKET_ROW_COMPLEX(iter); \
+ gemv_mult_complex<ScalarPacket, PLhsPacket, RhsScalar, RhsPacket, PResPacket, ResPacket, ConjugateLhs, \
+ ConjugateRhs, RowMajor>(a##iter, b, c0##iter, c1##iter); \
+ }
+
+ #define GEMV_PREDUX4_COMPLEX(iter1, iter2, iter3, N) \
+ if (N > iter1) { \
+ cc##iter1 = predux_complex<ResScalar, PResPacket, ResPacket, LhsPacket, RhsPacket>(c0##iter2, c0##iter3, \
+ c1##iter2, c1##iter3); \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(cc##iter1); \
+ }
+
+ #define GEMV_MULT_COMPLEX(iter1, iter2, iter3, N) \
+ if (N > iter1) { \
+ cc##iter1.scalar[0] += cj.pmul(lhs(i + iter2, j), b0); \
+ cc##iter1.scalar[1] += cj.pmul(lhs(i + iter3, j), b0); \
+ }
+
+ #define GEMV_STORE_ROW_COMPLEX(iter1, iter2, iter3, N) \
+ if (N > iter1) { \
+ storeMaddData<ResScalar>(res + ((i + iter2) * resIncr), alpha, cc##iter1.scalar[0]); \
+ storeMaddData<ResScalar>(res + ((i + iter3) * resIncr), alpha, cc##iter1.scalar[1]); \
+ }
+
+ #define GEMV_PROCESS_ROW_COMPLEX_SINGLE_NEW(N) \
+ GEMV_UNROLL_ROW(GEMV_INIT_COMPLEX, N) \
+ GEMV_PROCESS_ROW_COMPLEX_SINGLE_WORK(GEMV_WORK_ROW_COMPLEX, N)
+
+ /** \internal main macro for gemv_complex_row - initialize accumulators, multiply and add inputs, predux and store
+ * results */
+ #define GEMV_PROCESS_ROW_COMPLEX_ONE_NEW(N) \
+ for (; i < n##N; i += N) { \
+ GEMV_PROCESS_ROW_COMPLEX_SINGLE_NEW(N) \
+ GEMV_UNROLL_ROW_HALF(GEMV_PREDUX4_COMPLEX, (N >> 1)) \
+ GEMV_PROCESS_END_ROW_COMPLEX(N); \
+ }
+
+ #define GEMV_PROCESS_ROW_COMPLEX_PREDUX_NEW(iter) \
+ if (GEMV_IS_COMPLEX_COMPLEX) { \
+ c0##iter = padd(c0##iter, c1##iter); \
+ } \
+ dd0 = predux(c0##iter);
+
+ #if EIGEN_COMP_LLVM
+ #define GEMV_PROCESS_ROW_COMPLEX_SINGLE(N) GEMV_PROCESS_ROW_COMPLEX_SINGLE_NEW(N)
+
+ #define GEMV_PROCESS_ROW_COMPLEX_ONE(N) GEMV_PROCESS_ROW_COMPLEX_ONE_NEW(N)
+
+ #define GEMV_PROCESS_ROW_COMPLEX_PREDUX(iter) GEMV_PROCESS_ROW_COMPLEX_PREDUX_NEW(iter)
+ #else
+ // gcc seems to be reading and writing registers unnecessarily to memory.
+ // Use the old way for complex double until it is fixed.
+
+ #define GEMV_LOADPACKET_ROW_COMPLEX_OLD(iter) lhs.template load<LhsPacket, LhsAlignment>(i + (iter), j)
+
+ #define GEMV_INIT_COMPLEX_OLD(iter, N) \
+ EIGEN_UNUSED_VARIABLE(c0##iter); \
+ if (N > iter) { \
+ c1##iter = pset_zero<ResPacket>(); \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(c1##iter); \
+ }
+
+ #define GEMV_WORK_ROW_COMPLEX_OLD(iter, N) \
+ if (N > iter) { \
+ LhsPacket a##iter = GEMV_LOADPACKET_ROW_COMPLEX_OLD(iter); \
+ c1##iter = pcj.pmadd(a##iter, b0, c1##iter); \
+ }
+
+ #define GEMV_PREDUX4_COMPLEX_OLD(iter1, iter2, iter3, N) \
+ if (N > iter1) { \
+ cc##iter1.scalar[0] = predux(c1##iter2); \
+ cc##iter1.scalar[1] = predux(c1##iter3); \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(cc##iter1); \
+ }
+
+ #define GEMV_PROCESS_ROW_COMPLEX_SINGLE_OLD(N) \
+ GEMV_UNROLL_ROW(GEMV_INIT_COMPLEX_OLD, N) \
+ j = 0; \
+ for (; j + LhsPacketSize <= cols; j += LhsPacketSize) { \
+ RhsPacket b0 = rhs2.template load<RhsPacket, Unaligned>(j); \
+ GEMV_UNROLL_ROW(GEMV_WORK_ROW_COMPLEX_OLD, N) \
+ }
+
+ #define GEMV_PROCESS_ROW_COMPLEX_ONE_OLD(N) \
+ for (; i < n##N; i += N) { \
+ GEMV_PROCESS_ROW_COMPLEX_SINGLE_OLD(N) \
+ GEMV_UNROLL_ROW_HALF(GEMV_PREDUX4_COMPLEX_OLD, (N >> 1)) \
+ GEMV_PROCESS_END_ROW_COMPLEX(N) \
+ }
+
+ #define GEMV_PROCESS_ROW_COMPLEX_PREDUX_OLD(iter) dd0 = predux(c1##iter);
+
+ #if (__GNUC__ > 10)
+ #define GEMV_PROCESS_ROW_COMPLEX_IS_NEW 1
+ #else
+ #define GEMV_PROCESS_ROW_COMPLEX_IS_NEW (sizeof(Scalar) == sizeof(float)) || GEMV_IS_COMPLEX_COMPLEX
+ #endif
+
+ #define GEMV_PROCESS_ROW_COMPLEX_SINGLE(N) \
+ if (GEMV_PROCESS_ROW_COMPLEX_IS_NEW) { \
+ GEMV_PROCESS_ROW_COMPLEX_SINGLE_NEW(N) \
+ } else { \
+ GEMV_PROCESS_ROW_COMPLEX_SINGLE_OLD(N) \
+ }
+
+ #define GEMV_PROCESS_ROW_COMPLEX_ONE(N) \
+ if (GEMV_PROCESS_ROW_COMPLEX_IS_NEW) { \
+ GEMV_PROCESS_ROW_COMPLEX_ONE_NEW(N) \
+ } else { \
+ GEMV_PROCESS_ROW_COMPLEX_ONE_OLD(N) \
+ }
+
+ #define GEMV_PROCESS_ROW_COMPLEX_PREDUX(iter) \
+ if (GEMV_PROCESS_ROW_COMPLEX_IS_NEW) { \
+ GEMV_PROCESS_ROW_COMPLEX_PREDUX_NEW(iter) \
+ } else { \
+ GEMV_PROCESS_ROW_COMPLEX_PREDUX_OLD(iter) \
+ }
+ #endif
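Note how GEMV_PROCESS_ROW_COMPLEX_IS_NEW mixes compile-time-constant predicates (the sizeof comparison) into an ordinary runtime if: the condition folds to a constant in each template instantiation, so the compiler discards the dead branch, letting the NEW and OLD paths coexist without if constexpr. A minimal illustration of the idiom:

    #include <cstdio>

    template <typename Scalar>
    void kernel() {
      // Constant per instantiation; the untaken branch is eliminated as dead
      // code (the pre-C++17 stand-in for "if constexpr").
      if (sizeof(Scalar) == sizeof(float)) {
        std::printf("new path\n");
      } else {
        std::printf("old path\n");
      }
    }

    int main() {
      kernel<float>();   // new path
      kernel<double>();  // old path
    }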
+
+ #ifdef USE_GEMV_MMA
+ #define GEMV_PROCESS_ROW_COMPLEX(N) GEMV_PROCESS_ROW_COMPLEX_ONE_MMA(N)
+ #else
+ #define GEMV_PROCESS_ROW_COMPLEX(N) GEMV_PROCESS_ROW_COMPLEX_ONE(N)
+ #endif
+
+ template <typename Scalar, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, bool LhsIsReal,
+ typename RhsScalar, typename RhsMapper, bool ConjugateRhs, bool RhsIsReal, typename ResScalar>
+ EIGEN_STRONG_INLINE void gemv_complex_row(Index rows, Index cols, const LhsMapper& alhs, const RhsMapper& rhs,
+ ResScalar* res, Index resIncr, ResScalar alpha) {
+ typedef gemv_traits<LhsScalar, RhsScalar> Traits;
+
+ typedef typename Traits::LhsPacket LhsPacket;
+ typedef typename Traits::RhsPacket RhsPacket;
+ typedef typename Traits::ResPacket ResPacket;
+
+ typedef typename packet_traits<Scalar>::type ScalarPacket;
+ typedef typename packet_traits<LhsScalar>::type PLhsPacket;
+ typedef typename packet_traits<ResScalar>::type PResPacket;
+ typedef gemv_traits<ResPacket, ResPacket> PTraits;
+
+ // The following copy tells the compiler that lhs's attributes are not modified outside this function
+ // This helps GCC to generate proper code.
+ LhsMapper lhs(alhs);
+ typename RhsMapper::LinearMapper rhs2 = rhs.getLinearMapper(0, 0);
+
+ eigen_internal_assert(rhs.stride() == 1);
+ conj_helper<LhsScalar, RhsScalar, ConjugateLhs, ConjugateRhs> cj;
+ #if !EIGEN_COMP_LLVM
+ conj_helper<LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs> pcj;
+ #endif
+
+ // TODO: fine tune the following heuristic. The rationale is that if the matrix is very large,
+ // processing 8 rows at once might be counter productive wrt cache.
+ #ifndef GCC_ONE_VECTORPAIR_BUG
+ const Index n8 = lhs.stride() * sizeof(LhsScalar) > 32000 ? (rows - 7) : (rows - 7);
+ const Index n4 = rows - 3;
+ const Index n2 = rows - 1;
+ #endif
+
+ // TODO: for padded aligned inputs, we could enable aligned reads
+ enum {
+ LhsAlignment = Unaligned,
+ ResPacketSize = PTraits::ResPacketSize,
+ LhsPacketSize = PTraits::LhsPacketSize,
+ RhsPacketSize = PTraits::RhsPacketSize,
+ };
+
+ Index i = 0, j;
+ PResPacket c00, c01, c02, c03, c04, c05, c06, c07;
+ ResPacket c10, c11, c12, c13, c14, c15, c16, c17;
+ #ifdef USE_GEMV_MMA
+ __vector_quad e00, e01, e02, e03, e04, e05, e06, e07;
+ GEMV_UNUSED_ROW(8, e0)
+ GEMV_UNUSED_EXTRA(1, c0)
+ GEMV_UNUSED_EXTRA(1, c1)
+ #endif
+ ResScalar dd0;
+ #ifndef GCC_ONE_VECTORPAIR_BUG
+ ScalarBlock<ResScalar, 2> cc0, cc1, cc2, cc3;
+ #ifdef USE_GEMV_MMA
+ if (!GEMV_IS_COMPLEX_COMPLEX)
+ #endif
+ {
+ GEMV_PROCESS_ROW_COMPLEX(8)
+ }
+ GEMV_PROCESS_ROW_COMPLEX(4)
+ GEMV_PROCESS_ROW_COMPLEX(2)
+ #endif
+ for (; i < rows; ++i) {
+ GEMV_PROCESS_ROW_COMPLEX_SINGLE(1)
+ GEMV_PROCESS_ROW_COMPLEX_PREDUX(0)
+ for (; j < cols; ++j) {
+ dd0 += cj.pmul(lhs(i, j), rhs2(j));
+ }
+ res[i * resIncr] += alpha * dd0;
+ }
+ }
+
+ #define EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(Scalar, LhsScalar, RhsScalar) \
+ template <typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
+ struct general_matrix_vector_product<Index, LhsScalar, LhsMapper, ColMajor, ConjugateLhs, RhsScalar, RhsMapper, \
+ ConjugateRhs, Version> { \
+ typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar; \
+ \
+ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run(Index rows, Index cols, const LhsMapper& lhs, \
+ const RhsMapper& rhs, ResScalar* res, Index resIncr, \
+ ResScalar alpha) { \
+ gemv_complex_col<Scalar, LhsScalar, LhsMapper, ConjugateLhs, sizeof(Scalar) == sizeof(LhsScalar), RhsScalar, \
+ RhsMapper, ConjugateRhs, sizeof(Scalar) == sizeof(RhsScalar), ResScalar>(rows, cols, lhs, rhs, \
+ res, resIncr, alpha); \
+ } \
+ };
+
+ #define EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_ROW(Scalar, LhsScalar, RhsScalar) \
+ template <typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
+ struct general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, ConjugateLhs, RhsScalar, RhsMapper, \
+ ConjugateRhs, Version> { \
+ typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar; \
+ \
+ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run(Index rows, Index cols, const LhsMapper& lhs, \
+ const RhsMapper& rhs, ResScalar* res, Index resIncr, \
+ ResScalar alpha) { \
+ gemv_complex_row<Scalar, LhsScalar, LhsMapper, ConjugateLhs, sizeof(Scalar) == sizeof(LhsScalar), RhsScalar, \
+ RhsMapper, ConjugateRhs, sizeof(Scalar) == sizeof(RhsScalar), ResScalar>(rows, cols, lhs, rhs, \
+ res, resIncr, alpha); \
+ } \
+ };
+
+ EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(float, float, std::complex<float>)
+ EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(float, std::complex<float>, float)
+ EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(float, std::complex<float>, std::complex<float>)
+ EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(double, double, std::complex<double>)
+ EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(double, std::complex<double>, double)
+ EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_COL(double, std::complex<double>, std::complex<double>)
+ EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_ROW(float, float, std::complex<float>)
+ EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_ROW(float, std::complex<float>, float)
+ EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_ROW(float, std::complex<float>, std::complex<float>)
+ EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_ROW(double, double, std::complex<double>)
+ EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_ROW(double, std::complex<double>, double)
+ EIGEN_POWER_GEMV_COMPLEX_SPECIALIZE_ROW(double, std::complex<double>, std::complex<double>)
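The twelve instantiations cover every real/complex pairing per precision, and in each mixed case ScalarBinaryOpTraits resolves ResScalar to the complex type, consistent with ordinary C++ arithmetic promotion. A hedged sketch of that promotion using std::complex alone (no Eigen dependency):

    #include <complex>
    #include <type_traits>

    // Mixed real/complex products promote to the complex type, matching the
    // ResScalar these specializations obtain from ScalarBinaryOpTraits.
    static_assert(std::is_same<decltype(1.0f * std::complex<float>{}),
                               std::complex<float>>::value,
                  "float * complex<float> -> complex<float>");
    static_assert(std::is_same<decltype(std::complex<double>{} * 1.0),
                               std::complex<double>>::value,
                  "complex<double> * double -> complex<double>");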
+
+ #endif // EIGEN_MATRIX_VECTOR_PRODUCT_ALTIVEC_H