@smake/eigen 1.1.0 → 1.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +1 -1
- package/eigen/Eigen/AccelerateSupport +52 -0
- package/eigen/Eigen/Cholesky +18 -20
- package/eigen/Eigen/CholmodSupport +28 -28
- package/eigen/Eigen/Core +187 -120
- package/eigen/Eigen/Eigenvalues +16 -13
- package/eigen/Eigen/Geometry +18 -18
- package/eigen/Eigen/Householder +9 -7
- package/eigen/Eigen/IterativeLinearSolvers +8 -4
- package/eigen/Eigen/Jacobi +14 -13
- package/eigen/Eigen/KLUSupport +23 -21
- package/eigen/Eigen/LU +15 -16
- package/eigen/Eigen/MetisSupport +12 -12
- package/eigen/Eigen/OrderingMethods +54 -51
- package/eigen/Eigen/PaStiXSupport +23 -21
- package/eigen/Eigen/PardisoSupport +17 -14
- package/eigen/Eigen/QR +18 -20
- package/eigen/Eigen/QtAlignedMalloc +5 -12
- package/eigen/Eigen/SPQRSupport +21 -14
- package/eigen/Eigen/SVD +23 -17
- package/eigen/Eigen/Sparse +1 -2
- package/eigen/Eigen/SparseCholesky +18 -15
- package/eigen/Eigen/SparseCore +18 -17
- package/eigen/Eigen/SparseLU +9 -9
- package/eigen/Eigen/SparseQR +16 -14
- package/eigen/Eigen/StdDeque +5 -2
- package/eigen/Eigen/StdList +5 -2
- package/eigen/Eigen/StdVector +5 -2
- package/eigen/Eigen/SuperLUSupport +30 -24
- package/eigen/Eigen/ThreadPool +80 -0
- package/eigen/Eigen/UmfPackSupport +19 -17
- package/eigen/Eigen/Version +14 -0
- package/eigen/Eigen/src/AccelerateSupport/AccelerateSupport.h +423 -0
- package/eigen/Eigen/src/AccelerateSupport/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/Cholesky/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/Cholesky/LDLT.h +366 -405
- package/eigen/Eigen/src/Cholesky/LLT.h +323 -367
- package/eigen/Eigen/src/Cholesky/LLT_LAPACKE.h +81 -56
- package/eigen/Eigen/src/CholmodSupport/CholmodSupport.h +585 -529
- package/eigen/Eigen/src/CholmodSupport/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/Core/ArithmeticSequence.h +143 -317
- package/eigen/Eigen/src/Core/Array.h +329 -370
- package/eigen/Eigen/src/Core/ArrayBase.h +190 -203
- package/eigen/Eigen/src/Core/ArrayWrapper.h +126 -170
- package/eigen/Eigen/src/Core/Assign.h +30 -40
- package/eigen/Eigen/src/Core/AssignEvaluator.h +651 -604
- package/eigen/Eigen/src/Core/Assign_MKL.h +125 -120
- package/eigen/Eigen/src/Core/BandMatrix.h +267 -282
- package/eigen/Eigen/src/Core/Block.h +371 -390
- package/eigen/Eigen/src/Core/CommaInitializer.h +85 -100
- package/eigen/Eigen/src/Core/ConditionEstimator.h +51 -53
- package/eigen/Eigen/src/Core/CoreEvaluators.h +1214 -937
- package/eigen/Eigen/src/Core/CoreIterators.h +72 -63
- package/eigen/Eigen/src/Core/CwiseBinaryOp.h +112 -129
- package/eigen/Eigen/src/Core/CwiseNullaryOp.h +676 -702
- package/eigen/Eigen/src/Core/CwiseTernaryOp.h +77 -103
- package/eigen/Eigen/src/Core/CwiseUnaryOp.h +55 -67
- package/eigen/Eigen/src/Core/CwiseUnaryView.h +127 -92
- package/eigen/Eigen/src/Core/DenseBase.h +630 -658
- package/eigen/Eigen/src/Core/DenseCoeffsBase.h +511 -628
- package/eigen/Eigen/src/Core/DenseStorage.h +511 -590
- package/eigen/Eigen/src/Core/DeviceWrapper.h +153 -0
- package/eigen/Eigen/src/Core/Diagonal.h +168 -207
- package/eigen/Eigen/src/Core/DiagonalMatrix.h +346 -317
- package/eigen/Eigen/src/Core/DiagonalProduct.h +12 -10
- package/eigen/Eigen/src/Core/Dot.h +167 -217
- package/eigen/Eigen/src/Core/EigenBase.h +74 -85
- package/eigen/Eigen/src/Core/Fill.h +138 -0
- package/eigen/Eigen/src/Core/FindCoeff.h +464 -0
- package/eigen/Eigen/src/Core/ForceAlignedAccess.h +90 -113
- package/eigen/Eigen/src/Core/Fuzzy.h +82 -105
- package/eigen/Eigen/src/Core/GeneralProduct.h +315 -261
- package/eigen/Eigen/src/Core/GenericPacketMath.h +1182 -520
- package/eigen/Eigen/src/Core/GlobalFunctions.h +193 -157
- package/eigen/Eigen/src/Core/IO.h +131 -156
- package/eigen/Eigen/src/Core/IndexedView.h +209 -125
- package/eigen/Eigen/src/Core/InnerProduct.h +260 -0
- package/eigen/Eigen/src/Core/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/Core/Inverse.h +50 -59
- package/eigen/Eigen/src/Core/Map.h +123 -141
- package/eigen/Eigen/src/Core/MapBase.h +255 -282
- package/eigen/Eigen/src/Core/MathFunctions.h +1247 -1201
- package/eigen/Eigen/src/Core/MathFunctionsImpl.h +162 -99
- package/eigen/Eigen/src/Core/Matrix.h +463 -494
- package/eigen/Eigen/src/Core/MatrixBase.h +468 -470
- package/eigen/Eigen/src/Core/NestByValue.h +58 -52
- package/eigen/Eigen/src/Core/NoAlias.h +79 -86
- package/eigen/Eigen/src/Core/NumTraits.h +206 -206
- package/eigen/Eigen/src/Core/PartialReduxEvaluator.h +163 -142
- package/eigen/Eigen/src/Core/PermutationMatrix.h +461 -511
- package/eigen/Eigen/src/Core/PlainObjectBase.h +858 -972
- package/eigen/Eigen/src/Core/Product.h +246 -130
- package/eigen/Eigen/src/Core/ProductEvaluators.h +779 -671
- package/eigen/Eigen/src/Core/Random.h +153 -164
- package/eigen/Eigen/src/Core/RandomImpl.h +262 -0
- package/eigen/Eigen/src/Core/RealView.h +250 -0
- package/eigen/Eigen/src/Core/Redux.h +334 -314
- package/eigen/Eigen/src/Core/Ref.h +259 -257
- package/eigen/Eigen/src/Core/Replicate.h +92 -104
- package/eigen/Eigen/src/Core/Reshaped.h +215 -271
- package/eigen/Eigen/src/Core/ReturnByValue.h +47 -55
- package/eigen/Eigen/src/Core/Reverse.h +133 -148
- package/eigen/Eigen/src/Core/Select.h +68 -140
- package/eigen/Eigen/src/Core/SelfAdjointView.h +254 -290
- package/eigen/Eigen/src/Core/SelfCwiseBinaryOp.h +23 -20
- package/eigen/Eigen/src/Core/SkewSymmetricMatrix3.h +382 -0
- package/eigen/Eigen/src/Core/Solve.h +88 -102
- package/eigen/Eigen/src/Core/SolveTriangular.h +126 -124
- package/eigen/Eigen/src/Core/SolverBase.h +132 -133
- package/eigen/Eigen/src/Core/StableNorm.h +113 -147
- package/eigen/Eigen/src/Core/StlIterators.h +404 -248
- package/eigen/Eigen/src/Core/Stride.h +90 -92
- package/eigen/Eigen/src/Core/Swap.h +70 -39
- package/eigen/Eigen/src/Core/Transpose.h +258 -295
- package/eigen/Eigen/src/Core/Transpositions.h +270 -333
- package/eigen/Eigen/src/Core/TriangularMatrix.h +642 -743
- package/eigen/Eigen/src/Core/VectorBlock.h +59 -72
- package/eigen/Eigen/src/Core/VectorwiseOp.h +653 -704
- package/eigen/Eigen/src/Core/Visitor.h +464 -308
- package/eigen/Eigen/src/Core/arch/AVX/Complex.h +380 -187
- package/eigen/Eigen/src/Core/arch/AVX/MathFunctions.h +65 -163
- package/eigen/Eigen/src/Core/arch/AVX/PacketMath.h +2145 -638
- package/eigen/Eigen/src/Core/arch/AVX/Reductions.h +353 -0
- package/eigen/Eigen/src/Core/arch/AVX/TypeCasting.h +253 -60
- package/eigen/Eigen/src/Core/arch/AVX512/Complex.h +278 -228
- package/eigen/Eigen/src/Core/arch/AVX512/GemmKernel.h +1245 -0
- package/eigen/Eigen/src/Core/arch/AVX512/MathFunctions.h +48 -269
- package/eigen/Eigen/src/Core/arch/AVX512/MathFunctionsFP16.h +75 -0
- package/eigen/Eigen/src/Core/arch/AVX512/PacketMath.h +1597 -754
- package/eigen/Eigen/src/Core/arch/AVX512/PacketMathFP16.h +1413 -0
- package/eigen/Eigen/src/Core/arch/AVX512/Reductions.h +297 -0
- package/eigen/Eigen/src/Core/arch/AVX512/TrsmKernel.h +1167 -0
- package/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc +1219 -0
- package/eigen/Eigen/src/Core/arch/AVX512/TypeCasting.h +229 -41
- package/eigen/Eigen/src/Core/arch/AVX512/TypeCastingFP16.h +130 -0
- package/eigen/Eigen/src/Core/arch/AltiVec/Complex.h +420 -184
- package/eigen/Eigen/src/Core/arch/AltiVec/MathFunctions.h +40 -49
- package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProduct.h +2962 -2213
- package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h +196 -212
- package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h +713 -441
- package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h +742 -0
- package/eigen/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.inc +2818 -0
- package/eigen/Eigen/src/Core/arch/AltiVec/PacketMath.h +2380 -1362
- package/eigen/Eigen/src/Core/arch/AltiVec/TypeCasting.h +153 -0
- package/eigen/Eigen/src/Core/arch/Default/BFloat16.h +390 -224
- package/eigen/Eigen/src/Core/arch/Default/ConjHelper.h +78 -67
- package/eigen/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +1784 -799
- package/eigen/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h +167 -50
- package/eigen/Eigen/src/Core/arch/Default/Half.h +528 -379
- package/eigen/Eigen/src/Core/arch/Default/Settings.h +10 -12
- package/eigen/Eigen/src/Core/arch/GPU/Complex.h +244 -0
- package/eigen/Eigen/src/Core/arch/GPU/MathFunctions.h +41 -40
- package/eigen/Eigen/src/Core/arch/GPU/PacketMath.h +550 -523
- package/eigen/Eigen/src/Core/arch/GPU/Tuple.h +268 -0
- package/eigen/Eigen/src/Core/arch/GPU/TypeCasting.h +27 -30
- package/eigen/Eigen/src/Core/arch/HIP/hcc/math_constants.h +8 -8
- package/eigen/Eigen/src/Core/arch/HVX/PacketMath.h +1088 -0
- package/eigen/Eigen/src/Core/arch/LSX/Complex.h +520 -0
- package/eigen/Eigen/src/Core/arch/LSX/GeneralBlockPanelKernel.h +23 -0
- package/eigen/Eigen/src/Core/arch/LSX/MathFunctions.h +43 -0
- package/eigen/Eigen/src/Core/arch/LSX/PacketMath.h +2866 -0
- package/eigen/Eigen/src/Core/arch/LSX/TypeCasting.h +526 -0
- package/eigen/Eigen/src/Core/arch/MSA/Complex.h +54 -82
- package/eigen/Eigen/src/Core/arch/MSA/MathFunctions.h +84 -92
- package/eigen/Eigen/src/Core/arch/MSA/PacketMath.h +51 -47
- package/eigen/Eigen/src/Core/arch/NEON/Complex.h +454 -306
- package/eigen/Eigen/src/Core/arch/NEON/GeneralBlockPanelKernel.h +175 -115
- package/eigen/Eigen/src/Core/arch/NEON/MathFunctions.h +23 -30
- package/eigen/Eigen/src/Core/arch/NEON/PacketMath.h +4366 -2857
- package/eigen/Eigen/src/Core/arch/NEON/TypeCasting.h +616 -393
- package/eigen/Eigen/src/Core/arch/NEON/UnaryFunctors.h +57 -0
- package/eigen/Eigen/src/Core/arch/SSE/Complex.h +350 -198
- package/eigen/Eigen/src/Core/arch/SSE/MathFunctions.h +38 -149
- package/eigen/Eigen/src/Core/arch/SSE/PacketMath.h +1791 -912
- package/eigen/Eigen/src/Core/arch/SSE/Reductions.h +324 -0
- package/eigen/Eigen/src/Core/arch/SSE/TypeCasting.h +128 -40
- package/eigen/Eigen/src/Core/arch/SVE/MathFunctions.h +10 -6
- package/eigen/Eigen/src/Core/arch/SVE/PacketMath.h +156 -234
- package/eigen/Eigen/src/Core/arch/SVE/TypeCasting.h +6 -3
- package/eigen/Eigen/src/Core/arch/SYCL/InteropHeaders.h +27 -32
- package/eigen/Eigen/src/Core/arch/SYCL/MathFunctions.h +119 -117
- package/eigen/Eigen/src/Core/arch/SYCL/PacketMath.h +325 -419
- package/eigen/Eigen/src/Core/arch/SYCL/TypeCasting.h +15 -17
- package/eigen/Eigen/src/Core/arch/ZVector/Complex.h +325 -181
- package/eigen/Eigen/src/Core/arch/ZVector/MathFunctions.h +94 -83
- package/eigen/Eigen/src/Core/arch/ZVector/PacketMath.h +811 -458
- package/eigen/Eigen/src/Core/functors/AssignmentFunctors.h +121 -124
- package/eigen/Eigen/src/Core/functors/BinaryFunctors.h +576 -370
- package/eigen/Eigen/src/Core/functors/NullaryFunctors.h +194 -109
- package/eigen/Eigen/src/Core/functors/StlFunctors.h +95 -112
- package/eigen/Eigen/src/Core/functors/TernaryFunctors.h +34 -7
- package/eigen/Eigen/src/Core/functors/UnaryFunctors.h +1038 -749
- package/eigen/Eigen/src/Core/products/GeneralBlockPanelKernel.h +1883 -1375
- package/eigen/Eigen/src/Core/products/GeneralMatrixMatrix.h +312 -370
- package/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h +189 -176
- package/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h +84 -81
- package/eigen/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h +154 -73
- package/eigen/Eigen/src/Core/products/GeneralMatrixVector.h +292 -337
- package/eigen/Eigen/src/Core/products/GeneralMatrixVector_BLAS.h +80 -77
- package/eigen/Eigen/src/Core/products/Parallelizer.h +207 -105
- package/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix.h +327 -388
- package/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix_BLAS.h +206 -224
- package/eigen/Eigen/src/Core/products/SelfadjointMatrixVector.h +138 -147
- package/eigen/Eigen/src/Core/products/SelfadjointMatrixVector_BLAS.h +58 -61
- package/eigen/Eigen/src/Core/products/SelfadjointProduct.h +71 -71
- package/eigen/Eigen/src/Core/products/SelfadjointRank2Update.h +48 -47
- package/eigen/Eigen/src/Core/products/TriangularMatrixMatrix.h +294 -369
- package/eigen/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h +246 -238
- package/eigen/Eigen/src/Core/products/TriangularMatrixVector.h +244 -247
- package/eigen/Eigen/src/Core/products/TriangularMatrixVector_BLAS.h +212 -192
- package/eigen/Eigen/src/Core/products/TriangularSolverMatrix.h +328 -277
- package/eigen/Eigen/src/Core/products/TriangularSolverMatrix_BLAS.h +108 -109
- package/eigen/Eigen/src/Core/products/TriangularSolverVector.h +68 -94
- package/eigen/Eigen/src/Core/util/Assert.h +158 -0
- package/eigen/Eigen/src/Core/util/BlasUtil.h +342 -303
- package/eigen/Eigen/src/Core/util/ConfigureVectorization.h +348 -317
- package/eigen/Eigen/src/Core/util/Constants.h +297 -262
- package/eigen/Eigen/src/Core/util/DisableStupidWarnings.h +130 -90
- package/eigen/Eigen/src/Core/util/EmulateArray.h +270 -0
- package/eigen/Eigen/src/Core/util/ForwardDeclarations.h +449 -247
- package/eigen/Eigen/src/Core/util/GpuHipCudaDefines.inc +101 -0
- package/eigen/Eigen/src/Core/util/GpuHipCudaUndefines.inc +45 -0
- package/eigen/Eigen/src/Core/util/IndexedViewHelper.h +417 -116
- package/eigen/Eigen/src/Core/util/IntegralConstant.h +211 -204
- package/eigen/Eigen/src/Core/util/MKL_support.h +39 -37
- package/eigen/Eigen/src/Core/util/Macros.h +655 -773
- package/eigen/Eigen/src/Core/util/MaxSizeVector.h +139 -0
- package/eigen/Eigen/src/Core/util/Memory.h +970 -748
- package/eigen/Eigen/src/Core/util/Meta.h +581 -633
- package/eigen/Eigen/src/Core/util/MoreMeta.h +638 -0
- package/eigen/Eigen/src/Core/util/ReenableStupidWarnings.h +32 -19
- package/eigen/Eigen/src/Core/util/ReshapedHelper.h +17 -17
- package/eigen/Eigen/src/Core/util/Serializer.h +209 -0
- package/eigen/Eigen/src/Core/util/StaticAssert.h +50 -166
- package/eigen/Eigen/src/Core/util/SymbolicIndex.h +377 -225
- package/eigen/Eigen/src/Core/util/XprHelper.h +784 -547
- package/eigen/Eigen/src/Eigenvalues/ComplexEigenSolver.h +246 -277
- package/eigen/Eigen/src/Eigenvalues/ComplexSchur.h +299 -319
- package/eigen/Eigen/src/Eigenvalues/ComplexSchur_LAPACKE.h +52 -48
- package/eigen/Eigen/src/Eigenvalues/EigenSolver.h +413 -456
- package/eigen/Eigen/src/Eigenvalues/GeneralizedEigenSolver.h +309 -325
- package/eigen/Eigen/src/Eigenvalues/GeneralizedSelfAdjointEigenSolver.h +157 -171
- package/eigen/Eigen/src/Eigenvalues/HessenbergDecomposition.h +292 -310
- package/eigen/Eigen/src/Eigenvalues/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/Eigenvalues/MatrixBaseEigenvalues.h +89 -105
- package/eigen/Eigen/src/Eigenvalues/RealQZ.h +537 -607
- package/eigen/Eigen/src/Eigenvalues/RealSchur.h +342 -381
- package/eigen/Eigen/src/Eigenvalues/RealSchur_LAPACKE.h +41 -35
- package/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver.h +541 -595
- package/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver_LAPACKE.h +47 -44
- package/eigen/Eigen/src/Eigenvalues/Tridiagonalization.h +430 -462
- package/eigen/Eigen/src/Geometry/AlignedBox.h +226 -227
- package/eigen/Eigen/src/Geometry/AngleAxis.h +131 -133
- package/eigen/Eigen/src/Geometry/EulerAngles.h +163 -74
- package/eigen/Eigen/src/Geometry/Homogeneous.h +285 -333
- package/eigen/Eigen/src/Geometry/Hyperplane.h +151 -160
- package/eigen/Eigen/src/Geometry/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/Geometry/OrthoMethods.h +168 -146
- package/eigen/Eigen/src/Geometry/ParametrizedLine.h +127 -127
- package/eigen/Eigen/src/Geometry/Quaternion.h +566 -506
- package/eigen/Eigen/src/Geometry/Rotation2D.h +107 -105
- package/eigen/Eigen/src/Geometry/RotationBase.h +148 -145
- package/eigen/Eigen/src/Geometry/Scaling.h +113 -106
- package/eigen/Eigen/src/Geometry/Transform.h +858 -936
- package/eigen/Eigen/src/Geometry/Translation.h +94 -92
- package/eigen/Eigen/src/Geometry/Umeyama.h +79 -84
- package/eigen/Eigen/src/Geometry/arch/Geometry_SIMD.h +90 -104
- package/eigen/Eigen/src/Householder/BlockHouseholder.h +51 -46
- package/eigen/Eigen/src/Householder/Householder.h +102 -124
- package/eigen/Eigen/src/Householder/HouseholderSequence.h +412 -453
- package/eigen/Eigen/src/Householder/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/IterativeLinearSolvers/BasicPreconditioners.h +149 -162
- package/eigen/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h +124 -119
- package/eigen/Eigen/src/IterativeLinearSolvers/ConjugateGradient.h +92 -104
- package/eigen/Eigen/src/IterativeLinearSolvers/IncompleteCholesky.h +251 -243
- package/eigen/Eigen/src/IterativeLinearSolvers/IncompleteLUT.h +224 -228
- package/eigen/Eigen/src/IterativeLinearSolvers/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h +178 -227
- package/eigen/Eigen/src/IterativeLinearSolvers/LeastSquareConjugateGradient.h +79 -84
- package/eigen/Eigen/src/IterativeLinearSolvers/SolveWithGuess.h +54 -60
- package/eigen/Eigen/src/Jacobi/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/Jacobi/Jacobi.h +252 -308
- package/eigen/Eigen/src/KLUSupport/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/KLUSupport/KLUSupport.h +208 -227
- package/eigen/Eigen/src/LU/Determinant.h +50 -69
- package/eigen/Eigen/src/LU/FullPivLU.h +545 -596
- package/eigen/Eigen/src/LU/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/LU/InverseImpl.h +206 -285
- package/eigen/Eigen/src/LU/PartialPivLU.h +390 -428
- package/eigen/Eigen/src/LU/PartialPivLU_LAPACKE.h +54 -40
- package/eigen/Eigen/src/LU/arch/InverseSize4.h +72 -70
- package/eigen/Eigen/src/MetisSupport/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/MetisSupport/MetisSupport.h +81 -93
- package/eigen/Eigen/src/OrderingMethods/Amd.h +243 -265
- package/eigen/Eigen/src/OrderingMethods/Eigen_Colamd.h +831 -1004
- package/eigen/Eigen/src/OrderingMethods/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/OrderingMethods/Ordering.h +112 -119
- package/eigen/Eigen/src/PaStiXSupport/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/PaStiXSupport/PaStiXSupport.h +524 -570
- package/eigen/Eigen/src/PardisoSupport/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/PardisoSupport/PardisoSupport.h +385 -430
- package/eigen/Eigen/src/QR/ColPivHouseholderQR.h +479 -479
- package/eigen/Eigen/src/QR/ColPivHouseholderQR_LAPACKE.h +120 -56
- package/eigen/Eigen/src/QR/CompleteOrthogonalDecomposition.h +166 -153
- package/eigen/Eigen/src/QR/FullPivHouseholderQR.h +495 -475
- package/eigen/Eigen/src/QR/HouseholderQR.h +394 -285
- package/eigen/Eigen/src/QR/HouseholderQR_LAPACKE.h +32 -23
- package/eigen/Eigen/src/QR/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/SPQRSupport/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/SPQRSupport/SuiteSparseQRSupport.h +244 -264
- package/eigen/Eigen/src/SVD/BDCSVD.h +817 -713
- package/eigen/Eigen/src/SVD/BDCSVD_LAPACKE.h +174 -0
- package/eigen/Eigen/src/SVD/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/SVD/JacobiSVD.h +577 -543
- package/eigen/Eigen/src/SVD/JacobiSVD_LAPACKE.h +85 -49
- package/eigen/Eigen/src/SVD/SVDBase.h +242 -182
- package/eigen/Eigen/src/SVD/UpperBidiagonalization.h +200 -235
- package/eigen/Eigen/src/SparseCholesky/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/SparseCholesky/SimplicialCholesky.h +765 -594
- package/eigen/Eigen/src/SparseCholesky/SimplicialCholesky_impl.h +308 -94
- package/eigen/Eigen/src/SparseCore/AmbiVector.h +202 -251
- package/eigen/Eigen/src/SparseCore/CompressedStorage.h +184 -252
- package/eigen/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h +134 -178
- package/eigen/Eigen/src/SparseCore/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/SparseCore/SparseAssign.h +149 -140
- package/eigen/Eigen/src/SparseCore/SparseBlock.h +403 -440
- package/eigen/Eigen/src/SparseCore/SparseColEtree.h +100 -112
- package/eigen/Eigen/src/SparseCore/SparseCompressedBase.h +525 -303
- package/eigen/Eigen/src/SparseCore/SparseCwiseBinaryOp.h +555 -339
- package/eigen/Eigen/src/SparseCore/SparseCwiseUnaryOp.h +100 -108
- package/eigen/Eigen/src/SparseCore/SparseDenseProduct.h +169 -197
- package/eigen/Eigen/src/SparseCore/SparseDiagonalProduct.h +71 -71
- package/eigen/Eigen/src/SparseCore/SparseDot.h +49 -47
- package/eigen/Eigen/src/SparseCore/SparseFuzzy.h +13 -11
- package/eigen/Eigen/src/SparseCore/SparseMap.h +243 -253
- package/eigen/Eigen/src/SparseCore/SparseMatrix.h +1603 -1245
- package/eigen/Eigen/src/SparseCore/SparseMatrixBase.h +403 -350
- package/eigen/Eigen/src/SparseCore/SparsePermutation.h +186 -115
- package/eigen/Eigen/src/SparseCore/SparseProduct.h +94 -97
- package/eigen/Eigen/src/SparseCore/SparseRedux.h +22 -24
- package/eigen/Eigen/src/SparseCore/SparseRef.h +268 -295
- package/eigen/Eigen/src/SparseCore/SparseSelfAdjointView.h +370 -416
- package/eigen/Eigen/src/SparseCore/SparseSolverBase.h +78 -87
- package/eigen/Eigen/src/SparseCore/SparseSparseProductWithPruning.h +81 -95
- package/eigen/Eigen/src/SparseCore/SparseTranspose.h +62 -71
- package/eigen/Eigen/src/SparseCore/SparseTriangularView.h +132 -144
- package/eigen/Eigen/src/SparseCore/SparseUtil.h +138 -115
- package/eigen/Eigen/src/SparseCore/SparseVector.h +426 -372
- package/eigen/Eigen/src/SparseCore/SparseView.h +164 -193
- package/eigen/Eigen/src/SparseCore/TriangularSolver.h +129 -170
- package/eigen/Eigen/src/SparseLU/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/SparseLU/SparseLU.h +756 -710
- package/eigen/Eigen/src/SparseLU/SparseLUImpl.h +61 -48
- package/eigen/Eigen/src/SparseLU/SparseLU_Memory.h +102 -118
- package/eigen/Eigen/src/SparseLU/SparseLU_Structs.h +38 -35
- package/eigen/Eigen/src/SparseLU/SparseLU_SupernodalMatrix.h +245 -301
- package/eigen/Eigen/src/SparseLU/SparseLU_Utils.h +44 -49
- package/eigen/Eigen/src/SparseLU/SparseLU_column_bmod.h +104 -108
- package/eigen/Eigen/src/SparseLU/SparseLU_column_dfs.h +89 -100
- package/eigen/Eigen/src/SparseLU/SparseLU_copy_to_ucol.h +57 -58
- package/eigen/Eigen/src/SparseLU/SparseLU_heap_relax_snode.h +43 -55
- package/eigen/Eigen/src/SparseLU/SparseLU_kernel_bmod.h +74 -71
- package/eigen/Eigen/src/SparseLU/SparseLU_panel_bmod.h +124 -132
- package/eigen/Eigen/src/SparseLU/SparseLU_panel_dfs.h +136 -159
- package/eigen/Eigen/src/SparseLU/SparseLU_pivotL.h +51 -52
- package/eigen/Eigen/src/SparseLU/SparseLU_pruneL.h +67 -73
- package/eigen/Eigen/src/SparseLU/SparseLU_relax_snode.h +24 -26
- package/eigen/Eigen/src/SparseQR/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/SparseQR/SparseQR.h +450 -502
- package/eigen/Eigen/src/StlSupport/StdDeque.h +28 -93
- package/eigen/Eigen/src/StlSupport/StdList.h +28 -84
- package/eigen/Eigen/src/StlSupport/StdVector.h +28 -108
- package/eigen/Eigen/src/StlSupport/details.h +48 -50
- package/eigen/Eigen/src/SuperLUSupport/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/SuperLUSupport/SuperLUSupport.h +634 -730
- package/eigen/Eigen/src/ThreadPool/Barrier.h +70 -0
- package/eigen/Eigen/src/ThreadPool/CoreThreadPoolDevice.h +336 -0
- package/eigen/Eigen/src/ThreadPool/EventCount.h +241 -0
- package/eigen/Eigen/src/ThreadPool/ForkJoin.h +140 -0
- package/eigen/Eigen/src/ThreadPool/InternalHeaderCheck.h +4 -0
- package/eigen/Eigen/src/ThreadPool/NonBlockingThreadPool.h +587 -0
- package/eigen/Eigen/src/ThreadPool/RunQueue.h +230 -0
- package/eigen/Eigen/src/ThreadPool/ThreadCancel.h +21 -0
- package/eigen/Eigen/src/ThreadPool/ThreadEnvironment.h +43 -0
- package/eigen/Eigen/src/ThreadPool/ThreadLocal.h +289 -0
- package/eigen/Eigen/src/ThreadPool/ThreadPoolInterface.h +50 -0
- package/eigen/Eigen/src/ThreadPool/ThreadYield.h +16 -0
- package/eigen/Eigen/src/UmfPackSupport/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/UmfPackSupport/UmfPackSupport.h +428 -464
- package/eigen/Eigen/src/misc/Image.h +41 -43
- package/eigen/Eigen/src/misc/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/misc/Kernel.h +39 -41
- package/eigen/Eigen/src/misc/RealSvd2x2.h +19 -21
- package/eigen/Eigen/src/misc/blas.h +83 -426
- package/eigen/Eigen/src/misc/lapacke.h +9972 -16179
- package/eigen/Eigen/src/misc/lapacke_helpers.h +163 -0
- package/eigen/Eigen/src/misc/lapacke_mangling.h +4 -5
- package/eigen/Eigen/src/plugins/ArrayCwiseBinaryOps.inc +344 -0
- package/eigen/Eigen/src/plugins/ArrayCwiseUnaryOps.inc +544 -0
- package/eigen/Eigen/src/plugins/{BlockMethods.h → BlockMethods.inc} +434 -506
- package/eigen/Eigen/src/plugins/CommonCwiseBinaryOps.inc +116 -0
- package/eigen/Eigen/src/plugins/{CommonCwiseUnaryOps.h → CommonCwiseUnaryOps.inc} +58 -68
- package/eigen/Eigen/src/plugins/IndexedViewMethods.inc +192 -0
- package/eigen/Eigen/src/plugins/InternalHeaderCheck.inc +3 -0
- package/eigen/Eigen/src/plugins/MatrixCwiseBinaryOps.inc +331 -0
- package/eigen/Eigen/src/plugins/MatrixCwiseUnaryOps.inc +118 -0
- package/eigen/Eigen/src/plugins/ReshapedMethods.inc +133 -0
- package/package.json +1 -1
- package/eigen/COPYING.APACHE +0 -203
- package/eigen/COPYING.BSD +0 -26
- package/eigen/COPYING.GPL +0 -674
- package/eigen/COPYING.LGPL +0 -502
- package/eigen/COPYING.MINPACK +0 -51
- package/eigen/COPYING.MPL2 +0 -373
- package/eigen/COPYING.README +0 -18
- package/eigen/Eigen/src/Core/BooleanRedux.h +0 -162
- package/eigen/Eigen/src/Core/arch/CUDA/Complex.h +0 -258
- package/eigen/Eigen/src/Core/arch/Default/TypeCasting.h +0 -120
- package/eigen/Eigen/src/Core/arch/SYCL/SyclMemoryModel.h +0 -694
- package/eigen/Eigen/src/Core/util/NonMPL2.h +0 -3
- package/eigen/Eigen/src/SparseCore/MappedSparseMatrix.h +0 -67
- package/eigen/Eigen/src/SparseLU/SparseLU_gemm_kernel.h +0 -280
- package/eigen/Eigen/src/misc/lapack.h +0 -152
- package/eigen/Eigen/src/plugins/ArrayCwiseBinaryOps.h +0 -358
- package/eigen/Eigen/src/plugins/ArrayCwiseUnaryOps.h +0 -696
- package/eigen/Eigen/src/plugins/CommonCwiseBinaryOps.h +0 -115
- package/eigen/Eigen/src/plugins/IndexedViewMethods.h +0 -262
- package/eigen/Eigen/src/plugins/MatrixCwiseBinaryOps.h +0 -152
- package/eigen/Eigen/src/plugins/MatrixCwiseUnaryOps.h +0 -95
- package/eigen/Eigen/src/plugins/ReshapedMethods.h +0 -149
- package/eigen/README.md +0 -5
|
@@ -0,0 +1,1245 @@
|
|
|
1
|
+
// This file is part of Eigen, a lightweight C++ template library
|
|
2
|
+
// for linear algebra.
|
|
3
|
+
//
|
|
4
|
+
// Copyright (C) 2022 Intel Corporation
|
|
5
|
+
//
|
|
6
|
+
// This Source Code Form is subject to the terms of the Mozilla
|
|
7
|
+
// Public License v. 2.0. If a copy of the MPL was not distributed
|
|
8
|
+
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
9
|
+
|
|
10
|
+
#ifndef EIGEN_CORE_ARCH_AVX512_GEMM_KERNEL_H
|
|
11
|
+
#define EIGEN_CORE_ARCH_AVX512_GEMM_KERNEL_H
|
|
12
|
+
|
|
13
|
+
#if EIGEN_COMP_MSVC
|
|
14
|
+
#include <intrin.h>
|
|
15
|
+
#else
|
|
16
|
+
#include <x86intrin.h>
|
|
17
|
+
#endif
|
|
18
|
+
#include <immintrin.h>
|
|
19
|
+
#include <type_traits>
|
|
20
|
+
|
|
21
|
+
// IWYU pragma: private
|
|
22
|
+
#include "../../InternalHeaderCheck.h"
|
|
23
|
+
|
|
24
|
+
#if !defined(EIGEN_USE_AVX512_GEMM_KERNELS)
|
|
25
|
+
#define EIGEN_USE_AVX512_GEMM_KERNELS 1
|
|
26
|
+
#endif
|
|
27
|
+
|
|
28
|
+
#define SECOND_FETCH (32)
|
|
29
|
+
#if (EIGEN_COMP_GNUC_STRICT != 0) && !defined(EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS)
|
|
30
|
+
// Use less registers to load A elements to workaround compiler spills. Loose a
|
|
31
|
+
// bit of performance (less than ~2%).
|
|
32
|
+
#define EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS
|
|
33
|
+
#endif
|
|
34
|
+
|
|
35
|
+
namespace Eigen {
|
|
36
|
+
namespace internal {
|
|
37
|
+
|
|
38
|
+
template <typename Scalar, bool is_unit_inc>
|
|
39
|
+
class gemm_class {
|
|
40
|
+
using vec = typename packet_traits<Scalar>::type;
|
|
41
|
+
using vec_ymm = typename unpacket_traits<vec>::half;
|
|
42
|
+
using vec_xmm = typename unpacket_traits<vec_ymm>::half;
|
|
43
|
+
using umask_t = typename unpacket_traits<vec>::mask_t;
|
|
44
|
+
|
|
45
|
+
static constexpr bool is_f32 = sizeof(Scalar) == sizeof(float);
|
|
46
|
+
static constexpr bool is_f64 = sizeof(Scalar) == sizeof(double);
|
|
47
|
+
|
|
48
|
+
#ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_A_REGS
|
|
49
|
+
static constexpr bool use_less_a_regs = !is_unit_inc;
|
|
50
|
+
#else
|
|
51
|
+
static constexpr bool use_less_a_regs = true;
|
|
52
|
+
#endif
|
|
53
|
+
#ifndef EIGEN_ARCH_AVX512_GEMM_KERNEL_USE_LESS_B_REGS
|
|
54
|
+
static constexpr bool use_less_b_regs = !is_unit_inc;
|
|
55
|
+
#else
|
|
56
|
+
static constexpr bool use_less_b_regs = true;
|
|
57
|
+
#endif
|
|
58
|
+
|
|
59
|
+
static constexpr int a_regs[] = {0, 1, 2, use_less_a_regs ? 0 : 3, use_less_a_regs ? 1 : 4, use_less_a_regs ? 2 : 5};
|
|
60
|
+
static constexpr int b_regs[] = {6, use_less_b_regs ? 6 : 7};
|
|
61
|
+
static constexpr int c_regs[] = {
|
|
62
|
+
8, 16, 24, 9, 17, 25, 10, 18, 26, 11, 19, 27, 12, 20, 28, 13, 21, 29, 14, 22, 30, 15, 23, 31,
|
|
63
|
+
};
|
|
64
|
+
|
|
65
|
+
static constexpr int alpha_load_reg = 0;
|
|
66
|
+
static constexpr int c_load_regs[] = {1, 2, 6};
|
|
67
|
+
|
|
68
|
+
static constexpr int a_shift = 128;
|
|
69
|
+
static constexpr int b_shift = 128;
|
|
70
|
+
|
|
71
|
+
static constexpr int nelems_in_cache_line = is_f32 ? 16 : 8;
|
|
72
|
+
static constexpr int a_prefetch_size = nelems_in_cache_line * 2;
|
|
73
|
+
static constexpr int b_prefetch_size = nelems_in_cache_line * 8;
|
|
74
|
+
|
|
75
|
+
vec zmm[32];
|
|
76
|
+
umask_t mask;
|
|
77
|
+
|
|
78
|
+
// gemm arguments.
|
|
79
|
+
Index m;
|
|
80
|
+
const Index n, k, ldc;
|
|
81
|
+
const Index inc;
|
|
82
|
+
const Scalar *alpha;
|
|
83
|
+
|
|
84
|
+
const Scalar *a, *b;
|
|
85
|
+
Scalar *c;
|
|
86
|
+
|
|
87
|
+
const bool is_alpha1;
|
|
88
|
+
const bool is_beta0;
|
|
89
|
+
|
|
90
|
+
const Index a_stride, b_stride;
|
|
91
|
+
const Index a_off, b_off;
|
|
92
|
+
|
|
93
|
+
// Prefetch (into all cache levels, T0 hint) the next chunk of A data.
// Subtracting a_shift undoes the fixed bias the kernel applies to A
// pointers; a_prefetch_size looks ahead past the current cache line.
EIGEN_ALWAYS_INLINE void prefetch_a(const Scalar *a_addr) {
  const Scalar *target = a_addr - a_shift + a_prefetch_size;
  _mm_prefetch(reinterpret_cast<const char *>(target), _MM_HINT_T0);
}
|
|
96
|
+
|
|
97
|
+
// Prefetch (into all cache levels, T0 hint) the next chunk of B data.
// Subtracting b_shift undoes the fixed bias the kernel applies to B
// pointers; b_prefetch_size looks ahead past the current cache line.
EIGEN_ALWAYS_INLINE void prefetch_b(const Scalar *b_addr) {
  const Scalar *target = b_addr - b_shift + b_prefetch_size;
  _mm_prefetch(reinterpret_cast<const char *>(target), _MM_HINT_T0);
}
|
|
100
|
+
|
|
101
|
+
// Prefetch with the T2 hint (kept only in the outer cache levels) for data
// needed further in the future; a_shift undoes the A-pointer bias.
EIGEN_ALWAYS_INLINE void prefetch_x(const Scalar *x_addr) {
  _mm_prefetch(reinterpret_cast<const char *>(x_addr - a_shift), _MM_HINT_T2);
}
|
|
102
|
+
|
|
103
|
+
// Prefetch an address of the C (output) matrix. When the PREFETCHW
// instruction is available, use it so the line is fetched with write
// intent; otherwise fall back to an ordinary T0 read prefetch.
EIGEN_ALWAYS_INLINE void prefetch_c(const Scalar *c_addr) {
#if defined(__PRFCHW__) && __PRFCHW__ == 1
  _m_prefetchw(const_cast<void *>(static_cast<const void *>(c_addr)));
#else
  _mm_prefetch(reinterpret_cast<const char *>(c_addr), _MM_HINT_T0);
#endif
}
|
|
110
|
+
|
|
111
|
+
template <int nelems>
|
|
112
|
+
EIGEN_ALWAYS_INLINE void a_load(vec &a_reg, const Scalar *a_addr) {
|
|
113
|
+
switch (nelems * sizeof(*a_addr) * 8) {
|
|
114
|
+
default:
|
|
115
|
+
case 512 * 3:
|
|
116
|
+
a_reg = ploadu<vec>(a_addr);
|
|
117
|
+
break;
|
|
118
|
+
case 512 * 2:
|
|
119
|
+
a_reg = ploadu<vec>(a_addr);
|
|
120
|
+
break;
|
|
121
|
+
case 512 * 1:
|
|
122
|
+
a_reg = ploadu<vec>(a_addr);
|
|
123
|
+
break;
|
|
124
|
+
case 256 * 1:
|
|
125
|
+
a_reg = preinterpret<vec>(_mm512_broadcast_f64x4(ploadu<Packet4d>(reinterpret_cast<const double *>(a_addr))));
|
|
126
|
+
break;
|
|
127
|
+
case 128 * 1:
|
|
128
|
+
a_reg = preinterpret<vec>(_mm512_broadcast_f32x4(ploadu<Packet4f>(reinterpret_cast<const float *>(a_addr))));
|
|
129
|
+
break;
|
|
130
|
+
case 64 * 1:
|
|
131
|
+
a_reg = preinterpret<vec>(pload1<Packet8d>(reinterpret_cast<const double *>(a_addr)));
|
|
132
|
+
break;
|
|
133
|
+
case 32 * 1:
|
|
134
|
+
a_reg = pload1<vec>(a_addr);
|
|
135
|
+
break;
|
|
136
|
+
}
|
|
137
|
+
}
|
|
138
|
+
|
|
139
|
+
// Broadcast one element of the B matrix into every lane of the register.
EIGEN_ALWAYS_INLINE void b_load(vec &b_reg, const Scalar *b_addr) {
  b_reg = pload1<vec>(b_addr);
}
|
|
140
|
+
|
|
141
|
+
// Store `nelems` elements of an accumulator back to the C matrix.
// Unit-stride C uses unaligned stores sized by the tail (full vector /
// 256-bit / 128-bit / 64-bit / scalar); strided C uses scatters, with the
// masked overloads for partial vectors (mask is set up in c_update).
template <int nelems>
EIGEN_ALWAYS_INLINE void c_store(Scalar *mem, vec &src) {
  if (is_unit_inc) {
    switch (nelems * sizeof(*mem) * 8) {
      default:
      case 512 * 3:
        pstoreu(mem, src);
        break;
      case 512 * 2:
        pstoreu(mem, src);
        break;
      case 512 * 1:
        pstoreu(mem, src);
        break;
      case 256 * 1:
        pstoreu(mem, preinterpret<vec_ymm>(src));
        break;
      case 128 * 1:
        pstoreu(mem, preinterpret<vec_xmm>(src));
        break;
      case 64 * 1:
        pstorel(mem, preinterpret<vec_xmm>(src));
        break;
      case 32 * 1:
        pstores(mem, preinterpret<vec_xmm>(src));
        break;
    }
  } else {
    switch (nelems * sizeof(*mem) * 8) {
      default:
      case 512 * 3:
        pscatter(mem, src, inc);
        break;
      case 512 * 2:
        pscatter(mem, src, inc);
        break;
      case 512 * 1:
        pscatter(mem, src, inc);
        break;
      case 256 * 1:
        pscatter(mem, src, inc, mask);
        break;
      case 128 * 1:
        pscatter(mem, src, inc, mask);
        break;
      case 64 * 1:
        pscatter(mem, src, inc, mask);
        break;
      case 32 * 1:
        pscatter(mem, src, inc, mask);
        break;
    }
  }
}
|
|
195
|
+
|
|
196
|
+
// dst = src + C(mem) for `nelems` elements — the beta==1, alpha==1 C-update
// path. Unit-stride C loads directly; strided C gathers through the scratch
// register `reg`. Sub-vector tails use narrowed (ymm/xmm) adds so lanes past
// the tail are never read.
template <int nelems>
EIGEN_ALWAYS_INLINE void vaddm(vec &dst, const Scalar *mem, vec &src, vec &reg) {
  if (is_unit_inc) {
    switch (nelems * sizeof(*mem) * 8) {
      default:
      case 512 * 3:
        dst = padd(src, ploadu<vec>(mem));
        break;
      case 512 * 2:
        dst = padd(src, ploadu<vec>(mem));
        break;
      case 512 * 1:
        dst = padd(src, ploadu<vec>(mem));
        break;
      case 256 * 1:
        dst = preinterpret<vec>(padd(preinterpret<vec_ymm>(src), ploadu<vec_ymm>(mem)));
        break;
      case 128 * 1:
        dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), ploadu<vec_xmm>(mem)));
        break;
      case 64 * 1:
        dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem)));
        break;
      case 32 * 1:
        dst = preinterpret<vec>(padds(preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem)));
        break;
    }
  } else {
    // Zero out scratch register so unused lanes of masked gathers are defined.
    reg = pzero(reg);

    switch (nelems * sizeof(*mem) * 8) {
      default:
      case 512 * 3:
        reg = pgather<Scalar, vec>(mem, inc);
        dst = padd(src, reg);
        break;
      case 512 * 2:
        reg = pgather<Scalar, vec>(mem, inc);
        dst = padd(src, reg);
        break;
      case 512 * 1:
        reg = pgather<Scalar, vec>(mem, inc);
        dst = padd(src, reg);
        break;
      case 256 * 1:
        reg = preinterpret<vec>(pgather<Scalar, vec_ymm>(mem, inc));
        dst = preinterpret<vec>(padd(preinterpret<vec_ymm>(src), preinterpret<vec_ymm>(reg)));
        break;
      case 128 * 1:
        reg = preinterpret<vec>(pgather<Scalar, vec_xmm>(mem, inc));
        dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), preinterpret<vec_xmm>(reg)));
        break;
      case 64 * 1:
        if (is_f32) {
          // 2 floats: masked gather needed; for doubles a 64-bit load suffices.
          reg = pgather(reg, mem, inc, mask);
          dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), preinterpret<vec_xmm>(reg)));
        } else {
          dst = preinterpret<vec>(padd(preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem)));
        }
        break;
      case 32 * 1:
        // Single element: stride is irrelevant, plain scalar load/add.
        dst = preinterpret<vec>(padds(preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem)));
        break;
    }
  }
}
|
|
263
|
+
|
|
264
|
+
// dst += src1 * src2 (fused multiply-add) — the core accumulation step.
EIGEN_STRONG_INLINE void vfmadd(vec &dst, const vec &src1, const vec &src2) {
  dst = pmadd(src1, src2, dst);

#if (EIGEN_COMP_GNUC != 0) || (EIGEN_COMP_CLANG != 0)
  // Workaround register spills for gcc and clang: the empty asm ties the
  // operands to registers ("v" = vector register; "%" marks src1/src2 as
  // commutable), preventing the optimizer from rematerializing/spilling them.
  __asm__("#" : [dst] "+v"(dst) : [src1] "%v"(src1), [src2] "v"(src2));
#endif
}
|
|
272
|
+
|
|
273
|
+
// dst = scale * src + C(mem) for `nelems` elements — the beta==1,
// alpha!=1 C-update path. Same size/stride dispatch structure as vaddm, with
// the add replaced by an fma against the broadcast alpha in `scale`.
template <int nelems>
EIGEN_ALWAYS_INLINE void vfmaddm(vec &dst, const Scalar *mem, vec &src, vec &scale, vec &reg) {
  if (is_unit_inc) {
    switch (nelems * sizeof(*mem) * 8) {
      default:
      case 512 * 3:
        dst = pmadd(scale, src, ploadu<vec>(mem));
        break;
      case 512 * 2:
        dst = pmadd(scale, src, ploadu<vec>(mem));
        break;
      case 512 * 1:
        dst = pmadd(scale, src, ploadu<vec>(mem));
        break;
      case 256 * 1:
        dst =
            preinterpret<vec>(pmadd(preinterpret<vec_ymm>(scale), preinterpret<vec_ymm>(src), ploadu<vec_ymm>(mem)));
        break;
      case 128 * 1:
        dst =
            preinterpret<vec>(pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploadu<vec_xmm>(mem)));
        break;
      case 64 * 1:
        dst =
            preinterpret<vec>(pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem)));
        break;
      case 32 * 1:
        dst =
            preinterpret<vec>(pmadds(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem)));
        break;
    }
  } else {
    // Zero out scratch register so unused lanes of masked gathers are defined.
    reg = pzero(reg);

    switch (nelems * sizeof(*mem) * 8) {
      default:
      case 512 * 3:
        reg = pgather<Scalar, vec>(mem, inc);
        dst = pmadd(scale, src, reg);
        break;
      case 512 * 2:
        reg = pgather<Scalar, vec>(mem, inc);
        dst = pmadd(scale, src, reg);
        break;
      case 512 * 1:
        reg = pgather<Scalar, vec>(mem, inc);
        dst = pmadd(scale, src, reg);
        break;
      case 256 * 1:
        reg = preinterpret<vec>(pgather<Scalar, vec_ymm>(mem, inc));
        dst = preinterpret<vec>(
            pmadd(preinterpret<vec_ymm>(scale), preinterpret<vec_ymm>(src), preinterpret<vec_ymm>(reg)));
        break;
      case 128 * 1:
        reg = preinterpret<vec>(pgather<Scalar, vec_xmm>(mem, inc));
        dst = preinterpret<vec>(
            pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), preinterpret<vec_xmm>(reg)));
        break;
      case 64 * 1:
        if (is_f32) {
          // 2 floats: masked gather; for doubles a 64-bit load suffices.
          reg = pgather(reg, mem, inc, mask);
          dst = preinterpret<vec>(
              pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), preinterpret<vec_xmm>(reg)));
        } else {
          dst = preinterpret<vec>(
              pmadd(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploadl<vec_xmm>(mem)));
        }
        break;
      case 32 * 1:
        // Single element: stride is irrelevant, plain scalar load/fma.
        dst =
            preinterpret<vec>(pmadds(preinterpret<vec_xmm>(scale), preinterpret<vec_xmm>(src), ploads<vec_xmm>(mem)));
        break;
    }
  }
}
|
|
349
|
+
|
|
350
|
+
// Recursion terminator for a_loads: indices ran past the (j, i) ranges.
template <int j, int endX, int i, int endY, int nelems>
EIGEN_ALWAYS_INLINE std::enable_if_t<(j > endX) || (i > endY)> a_loads(const Scalar *ao) {
  EIGEN_UNUSED_VARIABLE(ao);
}
|
|
354
|
+
|
|
355
|
+
// Compile-time unrolled double loop that preloads the first endX k-iterations
// of the A panel (endY vectors each) before entering the k-loop. Register
// selection (j % 2) matches the double-buffering scheme used by load_a/compute.
template <int j, int endX, int i, int endY, int nelems>
EIGEN_ALWAYS_INLINE std::enable_if_t<(j <= endX) && (i <= endY)> a_loads(const Scalar *ao) {
  if (j < endX) {
    if (i < endY) {
      auto &a_reg = zmm[a_regs[i + (j % 2) * 3]];
      const Scalar *a_addr = ao + nelems * j + nelems_in_cache_line * i - a_shift;
      a_load<nelems>(a_reg, a_addr);

      a_loads<j, endX, i + 1, endY, nelems>(ao);
    } else {
      // Inner (i) loop done: advance to the next k-iteration (j).
      a_loads<j + 1, endX, 0, endY, nelems>(ao);
    }
  }
}
|
|
369
|
+
|
|
370
|
+
// Recursion terminator for prefetch_cs: indices ran past the (un, i) ranges.
template <int un, int max_b_unroll, int i, int um_vecs, int a_unroll, int b_unroll>
EIGEN_ALWAYS_INLINE std::enable_if_t<(un > max_b_unroll) || (i > um_vecs)> prefetch_cs(const Scalar *co1,
                                                                                      const Scalar *co2) {
  EIGEN_UNUSED_VARIABLE(co1);
  EIGEN_UNUSED_VARIABLE(co2);
}
|
|
376
|
+
|
|
377
|
+
/* C prefetch loop structure.
 * for (int un = 0; un < 8; un++) {
 *   if (b_unroll >= un + 1) {
 *     if (un == 4) co2 = co1 + 4 * ldc;
 *
 *     for (int i = 0; i < um_vecs; i++) {
 *       Scalar *co = (un + 1 <= 4) ? co1 : co2;
 *       auto co_off = (un % 4) * ldc + a_unroll - 1 + i * nelems_in_cache_line * sizeof *co;
 *       prefetch_c(co + co_off);
 *     }
 *   }
 * }
 */

// Compile-time unrolled version of the loop nest above: prefetches the whole
// a_unroll x b_unroll C tile before the k-loop. Columns 0-3 are addressed
// from co1, columns 4-7 from co2 (set once when un reaches 4).
template <int un, int max_b_unroll, int i, int um_vecs, int a_unroll, int b_unroll>
EIGEN_ALWAYS_INLINE std::enable_if_t<(un <= max_b_unroll) && (i <= um_vecs)> prefetch_cs(Scalar *&co1, Scalar *&co2) {
  if (un < max_b_unroll) {
    if (b_unroll >= un + 1) {
      if (un == 4 && i == 0) co2 = co1 + 4 * ldc;

      if (i < um_vecs) {
        Scalar *co = (un + 1 <= 4) ? co1 : co2;
        auto co_off = (un % 4) * ldc + a_unroll - 1 + i * nelems_in_cache_line * sizeof *co;
        prefetch_c(co + co_off);

        prefetch_cs<un, max_b_unroll, i + 1, um_vecs, a_unroll, b_unroll>(co1, co2);
      } else {
        // Inner (i) loop done: advance to the next column (un).
        prefetch_cs<un + 1, max_b_unroll, 0, um_vecs, a_unroll, b_unroll>(co1, co2);
      }

    } else {
      // Column not active for this b_unroll: skip to the next one.
      prefetch_cs<un + 1, max_b_unroll, 0, um_vecs, a_unroll, b_unroll>(co1, co2);
    }
  }
}
|
|
412
|
+
|
|
413
|
+
// load_c
// Recursion terminator for scale_load_c: vector index i ran past um_vecs.
template <int i, int um_vecs, int idx, int nelems>
EIGEN_ALWAYS_INLINE std::enable_if_t<(i > um_vecs)> scale_load_c(const Scalar *cox, vec &alpha_reg) {
  EIGEN_UNUSED_VARIABLE(cox);
  EIGEN_UNUSED_VARIABLE(alpha_reg);
}
|
|
419
|
+
|
|
420
|
+
// Combine accumulators of C column `idx` with the existing C data, one vector
// at a time:
//   beta==1, alpha==1 : c_reg += C          (vaddm)
//   beta==1, alpha!=1 : c_reg = alpha*c_reg + C  (vfmaddm)
//   beta==0, alpha!=1 : c_reg *= alpha
//   beta==0, alpha==1 : nothing — accumulator already holds the result.
template <int i, int um_vecs, int idx, int nelems>
EIGEN_ALWAYS_INLINE std::enable_if_t<(i <= um_vecs)> scale_load_c(const Scalar *cox, vec &alpha_reg) {
  if (i < um_vecs) {
    auto &c_reg = zmm[c_regs[i + idx * 3]];
    auto &c_load_reg = zmm[c_load_regs[i % 3]];
    auto c_mem = cox;
    // Step to the i-th vector of the column, accounting for C's stride.
    if (is_unit_inc)
      c_mem += i * nelems_in_cache_line;
    else
      c_mem += i * nelems_in_cache_line * inc;

    if (!is_beta0 && is_alpha1)
      vaddm<nelems>(c_reg, c_mem, c_reg, c_load_reg);
    else if (!is_beta0 && !is_alpha1)
      vfmaddm<nelems>(c_reg, c_mem, c_reg, alpha_reg, c_load_reg);
    else if (is_beta0 && !is_alpha1)
      c_reg = pmul(alpha_reg, c_reg);

    scale_load_c<i + 1, um_vecs, idx, nelems>(cox, alpha_reg);
  }
}
|
|
441
|
+
|
|
442
|
+
// store_c
// Recursion terminator for write_c: vector index i ran past um_vecs.
template <int i, int um_vecs, int idx, int nelems>
EIGEN_ALWAYS_INLINE std::enable_if_t<(i > um_vecs)> write_c(Scalar *cox) {
  EIGEN_UNUSED_VARIABLE(cox);
}
|
|
447
|
+
|
|
448
|
+
// Store the accumulators of C column `idx` back to memory, one vector at a
// time, and zero each accumulator afterwards so it is ready for the next tile.
template <int i, int um_vecs, int idx, int nelems>
EIGEN_ALWAYS_INLINE std::enable_if_t<(i <= um_vecs)> write_c(Scalar *cox) {
  if (i < um_vecs) {
    auto &c_reg = zmm[c_regs[i + idx * 3]];
    auto c_mem = cox;
    // Step to the i-th vector of the column, accounting for C's stride.
    if (is_unit_inc)
      c_mem += i * nelems_in_cache_line;
    else
      c_mem += i * nelems_in_cache_line * inc;

    c_store<nelems>(c_mem, c_reg);
    c_reg = pzero(c_reg);  // reset accumulator for the next tile

    write_c<i + 1, um_vecs, idx, nelems>(cox);
  }
}
|
|
464
|
+
|
|
465
|
+
/* C update loop structure.
 * co2 = co1 + ldc;
 *
 * auto &alpha_reg = zmm[alpha_load_reg];
 * if (!is_alpha1) alpha_reg = pload1<vec>(alpha);
 *
 * int idx = 0;
 * for (pow = 1; pow <= 8; pow <<= 1) {
 *
 *   if (b_unroll >= pow) {
 *     for (count = 1; count < (pow + 1) / 2 + 1; count++) {
 *       if (pow >= 4) co2 += ldc;
 *
 *       const Scalar *cox = (idx == 0) ? co1 : co2;
 *
 *       const int um_vecs = numext::div_ceil(a_unroll, nelems_in_cache_line);
 *       scale_load_c<0, um_vecs, idx, a_unroll>(cox, alpha_reg);
 *       write_c<0, um_vecs, idx, a_unroll>(cox);
 *
 *       idx++;
 *     }
 *   }
 * }
 *
 * if (b_unroll == 1)
 *   co1 += ldc;
 * else
 *   co1 = co2 + ldc;
 */

// One `count` iteration of the loop above: scale/merge and store C column
// `idx`, advancing cox by a row pitch first for the later power groups.
template <int pow, int a_unroll, int idx>
EIGEN_ALWAYS_INLINE void c_update_1count(Scalar *&cox) {
  if (pow >= 4) cox += ldc;

  const int um_vecs = numext::div_ceil(a_unroll, nelems_in_cache_line);
  auto &alpha_reg = zmm[alpha_load_reg];

  scale_load_c<0, um_vecs, idx, a_unroll>(cox, alpha_reg);
  write_c<0, um_vecs, idx, a_unroll>(cox);
}
|
|
505
|
+
|
|
506
|
+
// One power group of the C update: processes (pow + 1) / 2 columns starting
// at column index pow / 2. Column group 0 writes through co1, all later
// groups through co2.
template <int pow, int a_unroll>
EIGEN_ALWAYS_INLINE void c_update_1pow(Scalar *&co1, Scalar *&co2) {
  constexpr int idx = pow / 2;
  Scalar *&cox = idx == 0 ? co1 : co2;

  constexpr int max_count = (pow + 1) / 2;
  static_assert(max_count <= 4, "Unsupported max_count.");

  // Compile-time unrolled count-loop (max_count is 1, 1, 2 or 4).
  if (1 <= max_count) c_update_1count<pow, a_unroll, idx + 0>(cox);
  if (2 <= max_count) c_update_1count<pow, a_unroll, idx + 1>(cox);
  if (3 <= max_count) c_update_1count<pow, a_unroll, idx + 2>(cox);
  if (4 <= max_count) c_update_1count<pow, a_unroll, idx + 3>(cox);
}
|
|
519
|
+
|
|
520
|
+
// Write the full a_unroll x b_unroll accumulator tile back into C, applying
// alpha/beta, then advance co1 to the start of the next column block.
template <int max_b_unroll, int a_unroll, int b_unroll>
EIGEN_ALWAYS_INLINE void c_update(Scalar *&co1, Scalar *&co2) {
  auto &alpha_reg = zmm[alpha_load_reg];

  co2 = co1 + ldc;
  // Broadcast alpha once for the whole tile (only needed when alpha != 1).
  if (!is_alpha1) alpha_reg = pload1<vec>(alpha);
  // For strided C with a partial row tile, build the lane mask used by the
  // masked gathers/scatters in vaddm/vfmaddm/c_store.
  if (!is_unit_inc && a_unroll < nelems_in_cache_line) mask = static_cast<umask_t>((1ull << a_unroll) - 1);

  static_assert(max_b_unroll <= 8, "Unsupported max_b_unroll");

  // Process columns in power-of-two groups: 1, 1, 2, 4 columns.
  if (1 <= max_b_unroll && 1 <= b_unroll) c_update_1pow<1, a_unroll>(co1, co2);
  if (2 <= max_b_unroll && 2 <= b_unroll) c_update_1pow<2, a_unroll>(co1, co2);
  if (4 <= max_b_unroll && 4 <= b_unroll) c_update_1pow<4, a_unroll>(co1, co2);
  if (8 <= max_b_unroll && 8 <= b_unroll) c_update_1pow<8, a_unroll>(co1, co2);

  // Advance co1 past the columns just written.
  if (b_unroll == 1)
    co1 += ldc;
  else
    co1 = co2 + ldc;
}
|
|
540
|
+
|
|
541
|
+
// compute
// Recursion terminator for compute: vector index um ran past um_vecs.
template <int um, int um_vecs, int idx, int uk, bool fetch_x, bool ktail>
EIGEN_ALWAYS_INLINE std::enable_if_t<(um > um_vecs)> compute(const Scalar *ao, const Scalar *bo, int &fetchA_idx,
                                                             int &fetchB_idx) {
  EIGEN_UNUSED_VARIABLE(ao);
  EIGEN_UNUSED_VARIABLE(bo);
  EIGEN_UNUSED_VARIABLE(fetchA_idx);
  EIGEN_UNUSED_VARIABLE(fetchB_idx);
}
|
|
551
|
+
|
|
552
|
+
// Core FMA step: for each A vector `um` of column `idx`, accumulate
// c += a * b, interleaving A/B prefetches at fixed (um, idx, uk) slots so
// they are spread evenly across the inner kernel. The idx 0/3/6 and idx 1
// trigger points appear tuned for the unroll schedule — confirm against the
// kernel's scheduling notes before changing them.
template <int um, int um_vecs, int idx, int uk, bool fetch_x, bool ktail>
EIGEN_ALWAYS_INLINE std::enable_if_t<(um <= um_vecs)> compute(const Scalar *ao, const Scalar *bo, int &fetchA_idx,
                                                              int &fetchB_idx) {
  if (um < um_vecs) {
    auto &c_reg = zmm[c_regs[um + idx * 3]];
    auto &a_reg = zmm[a_regs[um + (uk % 2) * 3]];

    vfmadd(c_reg, a_reg, b_reg);

    // Prefetch the next A cache line at selected slots (first vector only).
    if (!fetch_x && um == 0 &&
        (((idx == 0 || idx == 6) && (uk % 2 == 0 || is_f64 || ktail)) ||
         (idx == 3 && (uk % 2 == 1 || is_f64 || ktail)))) {
      prefetch_a(ao + nelems_in_cache_line * fetchA_idx);
      fetchA_idx++;
    }

    // Prefetch the next B cache line at its own slot.
    if (um == 0 && idx == 1 && (uk % 2 == 0 || is_f64 || ktail)) {
      prefetch_b(bo + nelems_in_cache_line * fetchB_idx);
      fetchB_idx++;
    }

    compute<um + 1, um_vecs, idx, uk, fetch_x, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg);
  }
}
|
|
576
|
+
|
|
577
|
+
// load_a
// Recursion terminator for load_a: vector index um ran past um_vecs.
template <int um, int um_vecs, int uk, int nelems, bool ktail>
EIGEN_ALWAYS_INLINE std::enable_if_t<(um > um_vecs)> load_a(const Scalar *ao) {
  EIGEN_UNUSED_VARIABLE(ao);
}
|
|
582
|
+
|
|
583
|
+
// Preload the A vectors for a future k-iteration into the double-buffered
// register set ((uk % 2) selects the buffer). The look-ahead distance shrinks
// for the k-tail (ktail) and when running with fewer A registers, so loads
// never run past the packed buffer.
template <int um, int um_vecs, int uk, int nelems, bool ktail>
EIGEN_ALWAYS_INLINE std::enable_if_t<(um <= um_vecs)> load_a(const Scalar *ao) {
  if (um < um_vecs) {
    auto &a_reg = zmm[a_regs[um + (uk % 2) * 3]];
    const Scalar *a_addr = ao + nelems * (1 + !ktail * !use_less_a_regs + uk) + nelems_in_cache_line * um - a_shift;
    a_load<nelems>(a_reg, a_addr);

    load_a<um + 1, um_vecs, uk, nelems, ktail>(ao);
  }
}
|
|
593
|
+
// Recursion terminator for innerkernel_1pow: count ran past (pow + 1) / 2.
template <int uk, int pow, int count, int um_vecs, int b_unroll, bool ktail, bool fetch_x, bool c_fetch>
EIGEN_ALWAYS_INLINE std::enable_if_t<(count > (pow + 1) / 2)> innerkernel_1pow(const Scalar *&aa,
                                                                               const Scalar *const &ao,
                                                                               const Scalar *const &bo, Scalar *&co2,
                                                                               int &fetchA_idx, int &fetchB_idx) {
  EIGEN_UNUSED_VARIABLE(aa);
  EIGEN_UNUSED_VARIABLE(ao);
  EIGEN_UNUSED_VARIABLE(bo);
  EIGEN_UNUSED_VARIABLE(co2);
  EIGEN_UNUSED_VARIABLE(fetchA_idx);
  EIGEN_UNUSED_VARIABLE(fetchB_idx);
}
|
|
605
|
+
|
|
606
|
+
// One power group of the inner kernel for k-iteration `uk`: handles columns
// idx = pow/2 .. pow/2 + (pow+1)/2 - 1. For each column it runs the FMA sweep
// (compute) and then reloads the B register with the element needed two
// columns later (double-buffered via b_regs[idx % 2]). After the count-loop,
// the pow == 2 instantiation optionally prefetches C.
template <int uk, int pow, int count, int um_vecs, int b_unroll, bool ktail, bool fetch_x, bool c_fetch>
EIGEN_ALWAYS_INLINE std::enable_if_t<(count <= (pow + 1) / 2)> innerkernel_1pow(const Scalar *&aa,
                                                                                const Scalar *const &ao,
                                                                                const Scalar *const &bo, Scalar *&co2,
                                                                                int &fetchA_idx, int &fetchB_idx) {
  const int idx = (pow / 2) + count;

  if (count < (pow + 1) / 2) {
    auto &b_reg = zmm[b_regs[idx % 2]];

    // Look-ahead prefetch of the next A panel, issued once per uk == 3 pass.
    if (fetch_x && uk == 3 && idx == 0) prefetch_x(aa);
    if (fetch_x && uk == 3 && idx == 4) aa += 8;

    if (b_unroll >= pow) {
      compute<0, um_vecs, idx, uk, fetch_x, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg);

      const Scalar *b_addr = bo + b_unroll * uk + idx + 1 + (b_unroll > 1) * !use_less_b_regs - b_shift;
      b_load(b_reg, b_addr);
    }

    // Go to the next count.
    innerkernel_1pow<uk, pow, count + 1, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx,
                                                                                    fetchB_idx);

  } else {
    // Maybe prefetch C data after count-loop.
    if (pow == 2 && c_fetch) {
      if (uk % 3 == 0 && uk > 0) {
        co2 += ldc;
      } else {
        prefetch_c(co2 + (uk % 3) * nelems_in_cache_line);
      }
    }
  }
}
|
|
641
|
+
|
|
642
|
+
// One k-iteration of the inner kernel: runs the power groups (1, 2, 4, 8
// columns) for this uk, then preloads A for a later iteration unless
// explicitly skipped on the very last one.
template <int uk, int max_b_unroll, int a_unroll, int b_unroll, bool ktail, bool fetch_x, bool c_fetch,
          bool no_a_preload = false>
EIGEN_ALWAYS_INLINE void innerkernel_1uk(const Scalar *&aa, const Scalar *const &ao, const Scalar *const &bo,
                                         Scalar *&co2, int &fetchA_idx, int &fetchB_idx) {
  const int um_vecs = numext::div_ceil(a_unroll, nelems_in_cache_line);

  if (max_b_unroll >= 1)
    innerkernel_1pow<uk, 1, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
  if (max_b_unroll >= 2)
    innerkernel_1pow<uk, 2, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
  if (max_b_unroll >= 4)
    innerkernel_1pow<uk, 4, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);
  if (max_b_unroll >= 8)
    innerkernel_1pow<uk, 8, 0, um_vecs, b_unroll, ktail, fetch_x, c_fetch>(aa, ao, bo, co2, fetchA_idx, fetchB_idx);

  // Load A after pow-loop. Skip this at the end to prevent running over the buffer
  if (!no_a_preload) load_a<0, um_vecs, uk, a_unroll, ktail>(ao);
}
|
|
660
|
+
|
|
661
|
+
/* Inner kernel loop structure.
 * for (int uk = 0; uk < kfactor; uk++) {
 *   int idx = 0;
 *
 *   for (pow = 1; pow < max_b_unroll << 1; pow <<= 1) {
 *     for (int count = 0; count < (pow + 1) / 2; count++) {
 *       auto &b_reg = zmm[b_regs[idx % 2]];
 *
 *       if (fetch_x && uk == 3 && idx == 0) prefetch_x(aa);
 *       if (fetch_x && uk == 3 && idx == 4) aa += 8;
 *
 *       if (b_unroll >= pow) {
 *         compute<0, um_vecs, idx, uk, fetchx, ktail>(ao, bo, fetchA_idx, fetchB_idx, b_reg);
 *
 *         const Scalar *b_addr = bo + b_unroll * uk + idx + 1 + (b_unroll > 1) - b_shift ;
 *         b_load(b_reg, b_addr);
 *       }
 *       idx++;
 *     }
 *
 *     Maybe prefetch C data.
 *     if (pow == 2 && c_fetch) {
 *       if (uk % 3 == 0 && uk > 0) {
 *         co2 += ldc;
 *       } else {
 *         prefetch_c(co2 + (uk % 3) * nelems_in_cache_line);
 *       }
 *     }
 *   }
 *
 *   Load A.
 *   load_a<0, um_vecs, uk, ktail, a_unroll>(ao);
 * }
 *
 * Advance A/B pointers after uk-loop.
 * ao += a_unroll * kfactor;
 * bo += b_unroll * kfactor;
 */

// Run k_factor k-iterations of the inner kernel (compile-time unrolled uk
// loop, max 4), then advance the packed A/B pointers past the consumed data.
template <int a_unroll, int b_unroll, int k_factor, int max_b_unroll, int max_k_factor, bool c_fetch,
          bool no_a_preload = false>
EIGEN_ALWAYS_INLINE void innerkernel(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co2) {
  int fetchA_idx = 0;
  int fetchB_idx = 0;

  // Only the fully unrolled body issues the look-ahead A prefetches;
  // the k-tail (k_factor == 1) may also skip the final A preload.
  const bool fetch_x = k_factor == max_k_factor;
  const bool ktail = k_factor == 1;

  static_assert(k_factor <= 4 && k_factor > 0, "innerkernel maximum k_factor supported is 4");
  static_assert(no_a_preload == false || (no_a_preload == true && k_factor == 1),
                "skipping a preload only allowed when k unroll is 1");

  if (k_factor > 0)
    innerkernel_1uk<0, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(
        aa, ao, bo, co2, fetchA_idx, fetchB_idx);
  if (k_factor > 1)
    innerkernel_1uk<1, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(
        aa, ao, bo, co2, fetchA_idx, fetchB_idx);
  if (k_factor > 2)
    innerkernel_1uk<2, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(
        aa, ao, bo, co2, fetchA_idx, fetchB_idx);
  if (k_factor > 3)
    innerkernel_1uk<3, max_b_unroll, a_unroll, b_unroll, ktail, fetch_x, c_fetch, no_a_preload>(
        aa, ao, bo, co2, fetchA_idx, fetchB_idx);

  // Advance A/B pointers after uk-loop.
  ao += a_unroll * k_factor;
  bo += b_unroll * k_factor;
}
|
|
730
|
+
|
|
731
|
+
// Full k-loop for one a_unroll x b_unroll tile of C: preload A/B registers,
// prefetch the C tile, run the 4x-unrolled main loop plus remainder, then
// merge the accumulators into C.
template <int a_unroll, int b_unroll, int max_b_unroll>
EIGEN_ALWAYS_INLINE void kloop(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) {
  const int um_vecs = numext::div_ceil(a_unroll, nelems_in_cache_line);
  // Preload A for the first one or two k-iterations (two when double-buffering).
  if (!use_less_a_regs && k > 1)
    a_loads<0, 2, 0, um_vecs, a_unroll>(ao);
  else
    a_loads<0, 1, 0, um_vecs, a_unroll>(ao);

  b_load(zmm[b_regs[0]], bo - b_shift + 0);
  if (!use_less_b_regs) b_load(zmm[b_regs[1]], bo - b_shift + 1);

#ifndef SECOND_FETCH
  // Prefetch the whole C tile up front when the deferred scheme is disabled.
  prefetch_cs<0, max_b_unroll, 0, um_vecs, a_unroll, b_unroll>(co1, co2);
#endif  // SECOND_FETCH

  // Unrolling k-loop by a factor of 4.
  const int max_k_factor = 4;
  Index kRem = k % max_k_factor;
  Index k_ = k - kRem;
  // Keep one full unrolled group in the remainder so the final iterations can
  // run with the reduced-lookahead (ktail) kernel.
  if (k_ >= max_k_factor) {
    k_ -= max_k_factor;
    kRem += max_k_factor;
  }
  Index loop_count = k_ / max_k_factor;

  if (loop_count > 0) {
#ifdef SECOND_FETCH
    loop_count -= SECOND_FETCH;
#endif
    while (loop_count > 0) {
      innerkernel<a_unroll, b_unroll, max_k_factor, max_b_unroll, max_k_factor, 0>(aa, ao, bo, co2);
      loop_count--;
    }
#ifdef SECOND_FETCH
    // Deferred C prefetch: interleave C prefetches (c_fetch = 1) into the
    // last SECOND_FETCH unrolled groups instead of prefetching up front.
    co2 = co1 + nelems_in_cache_line - 1;

    loop_count += b_unroll;
    while (loop_count > 0) {
      innerkernel<a_unroll, b_unroll, max_k_factor, max_b_unroll, max_k_factor, 1>(aa, ao, bo, co2);
      loop_count--;
    }

    loop_count += SECOND_FETCH - b_unroll;
    while (loop_count > 0) {
      innerkernel<a_unroll, b_unroll, max_k_factor, max_b_unroll, max_k_factor, 0>(aa, ao, bo, co2);
      loop_count--;
    }
#endif
  }

  // k-loop remainder handling.
  loop_count = kRem;
  while (loop_count > 1) {
    innerkernel<a_unroll, b_unroll, 1, max_b_unroll, max_k_factor, 0>(aa, ao, bo, co2);
    loop_count--;
  }
  // Last iteration skips the A preload so we never read past the packed buffer.
  if (loop_count > 0) {
    innerkernel<a_unroll, b_unroll, 1, max_b_unroll, max_k_factor, 0, true>(aa, ao, bo, co2);
  }

  // Update C matrix.
  c_update<max_b_unroll, a_unroll, b_unroll>(co1, co2);
}
|
|
794
|
+
|
|
795
|
+
// One step of the n-loop: position the packed A/B pointers for this column
// block, run the k-loop, then advance B to the next panel and the A-prefetch
// pointer past the lines already fetched.
template <int a_unroll, int b_unroll, int max_b_unroll>
EIGEN_ALWAYS_INLINE void nloop(const Scalar *&aa, const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) {
  // Set A matrix pointer.
  ao = a + a_off * a_unroll;

  // Set B matrix pointer if needed.
  bo += b_unroll * b_off;

  kloop<a_unroll, b_unroll, max_b_unroll>(aa, ao, bo, co1, co2);

  // Advance B matrix pointer if needed.
  bo += b_unroll * (b_stride - k - b_off);

  // Advance prefetch A pointer.
  aa += 16;
}
|
|
811
|
+
|
|
812
|
+
// One step of the m-loop: handle a_unroll rows of C across all n columns —
// main n-loop in blocks of max_b_unroll plus n remainders — then advance the
// A panel pointer for the next row block.
template <int a_unroll, int max_a_unroll, int max_b_unroll>
EIGEN_ALWAYS_INLINE void mloop(const Scalar *&ao, const Scalar *&bo, Scalar *&co1, Scalar *&co2) {
  // Set prefetch A pointers.
  const Scalar *aa = a + a_unroll * a_stride;

  // Set C matrix pointers.
  co1 = c;
  if (a_unroll >= max_a_unroll) co2 = c + 2 * ldc;
  // Advance the C base pointer by this row block (stride-aware).
  if (is_unit_inc)
    c += a_unroll;
  else
    c += a_unroll * inc;

  // Set B matrix pointer.
  bo = b;

  // Main n-loop.
  for (Index i = n / max_b_unroll; i > 0; i--) nloop<a_unroll, max_b_unroll, max_b_unroll>(aa, ao, bo, co1, co2);

  // n-remainders.
  if (n & 4 && max_b_unroll > 4) nloop<a_unroll, 4, max_b_unroll>(aa, ao, bo, co1, co2);
#if 0
  if (n & 2 && max_b_unroll > 2) nloop<a_unroll, 2, max_b_unroll>(aa, ao, bo, co1, co2);
  if (n & 1 && max_b_unroll > 1) nloop<a_unroll, 1, max_b_unroll>(aa, ao, bo, co1, co2);
#else
  // Copy kernels don't support tails of n = 2 for single/double precision.
  // Loop over ones.
  int n_rem = 2 * ((n & 2) != 0) + 1 * ((n & 1) != 0);
  while (n_rem > 0) {
    nloop<a_unroll, 1, max_b_unroll>(aa, ao, bo, co1, co2);
    n_rem--;
  }
#endif

  // Advance A matrix pointer.
  a = ao + a_unroll * (a_stride - k - a_off);
}
|
|
849
|
+
|
|
850
|
+
 public:
  // Compute kernel unrolling C matrix by max_a_unroll x max_b_unroll.
  // Drives the whole GEMM: main m-loop in blocks of max_a_unroll, then
  // power-of-two m remainders down to single rows.
  template <int max_a_unroll, int max_b_unroll>
  EIGEN_ALWAYS_INLINE void compute_kern() {
    // Pre-bias the packed pointers; the per-access arithmetic subtracts
    // a_shift/b_shift back out (see prefetch_a/b, a_load, b_load).
    a -= -a_shift;
    b -= -b_shift;

    const Scalar *ao = nullptr;
    const Scalar *bo = nullptr;
    Scalar *co1 = nullptr;
    Scalar *co2 = nullptr;

    // Main m-loop.
    for (; m >= max_a_unroll; m -= max_a_unroll) mloop<max_a_unroll, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);

    // m-remainders.
    if (m & 32 && max_a_unroll > 32) mloop<32, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
    if (m & 16 && max_a_unroll > 16) mloop<16, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
    if (m & 8 && max_a_unroll > 8) mloop<8, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
    if (m & 4 && max_a_unroll > 4) mloop<4, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
    if (m & 2 && max_a_unroll > 2 && is_f64) mloop<2, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
    if (m & 1 && max_a_unroll > 1 && is_f64) mloop<1, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);

    // Copy kernels don't support tails of m = 2 for single precision.
    // Loop over ones.
    if (is_f32) {
      int m_rem = 2 * ((m & 2) != 0) + 1 * ((m & 1) != 0);
      while (m_rem > 0) {
        mloop<1, max_a_unroll, max_b_unroll>(ao, bo, co1, co2);
        m_rem--;
      }
    }
  }
|
|
883
|
+
|
|
884
|
+
// Construct the kernel state: problem sizes (m, n, k), C leading dimension
// and element stride (ldc, inc), alpha pointer, packed A/B panels, C pointer,
// alpha/beta flags, packed-panel strides, and panel start offsets.
gemm_class(Index m_, Index n_, Index k_, Index ldc_, Index inc_, const Scalar *alpha_, const Scalar *a_,
           const Scalar *b_, Scalar *c_, bool is_alpha1_, bool is_beta0_, Index a_stride_, Index b_stride_,
           Index a_off_, Index b_off_)
    : m(m_),
      n(n_),
      k(k_),
      ldc(ldc_),
      inc(inc_),
      alpha(alpha_),
      a(a_),
      b(b_),
      c(c_),
      is_alpha1(is_alpha1_),
      is_beta0(is_beta0_),
      a_stride(a_stride_),
      b_stride(b_stride_),
      a_off(a_off_),
      b_off(b_off_) {
  // Zero out all accumulation registers (zmm[8..31]; the lower entries are
  // used for A/B/scratch and loaded before first use).
  zmm[8] = pzero(zmm[8]);
  zmm[9] = pzero(zmm[9]);
  zmm[10] = pzero(zmm[10]);
  zmm[11] = pzero(zmm[11]);
  zmm[12] = pzero(zmm[12]);
  zmm[13] = pzero(zmm[13]);
  zmm[14] = pzero(zmm[14]);
  zmm[15] = pzero(zmm[15]);
  zmm[16] = pzero(zmm[16]);
  zmm[17] = pzero(zmm[17]);
  zmm[18] = pzero(zmm[18]);
  zmm[19] = pzero(zmm[19]);
  zmm[20] = pzero(zmm[20]);
  zmm[21] = pzero(zmm[21]);
  zmm[22] = pzero(zmm[22]);
  zmm[23] = pzero(zmm[23]);
  zmm[24] = pzero(zmm[24]);
  zmm[25] = pzero(zmm[25]);
  zmm[26] = pzero(zmm[26]);
  zmm[27] = pzero(zmm[27]);
  zmm[28] = pzero(zmm[28]);
  zmm[29] = pzero(zmm[29]);
  zmm[30] = pzero(zmm[30]);
  zmm[31] = pzero(zmm[31]);
}
|
|
928
|
+
};
|
|
929
|
+
|
|
930
|
+
// Compute kernel with max unroll support of:
|
|
931
|
+
// Single precision:
|
|
932
|
+
// max_a_unroll: 48, 32, 16, 8, 4, 2, 1
|
|
933
|
+
// max_b_unroll: 8, 4, 2, 1
|
|
934
|
+
// Double precision:
|
|
935
|
+
// max_a_unroll: 24, 16, 8, 4, 2, 1
|
|
936
|
+
// max_b_unroll: 8, 4, 2, 1
|
|
937
|
+
template <typename Scalar, int max_a_unroll, int max_b_unroll, bool is_alpha1, bool is_beta0, bool is_unit_inc>
|
|
938
|
+
EIGEN_DONT_INLINE void gemm_kern_avx512(Index m, Index n, Index k, Scalar *alpha, const Scalar *a, const Scalar *b,
|
|
939
|
+
Scalar *c, Index ldc, Index inc = 1, Index a_stride = -1, Index b_stride = -1,
|
|
940
|
+
Index a_off = 0, Index b_off = 0) {
|
|
941
|
+
if (a_stride == -1) a_stride = k;
|
|
942
|
+
if (b_stride == -1) b_stride = k;
|
|
943
|
+
|
|
944
|
+
gemm_class<Scalar, is_unit_inc> g(m, n, k, ldc, inc, alpha, a, b, c, is_alpha1, is_beta0, a_stride, b_stride, a_off,
|
|
945
|
+
b_off);
|
|
946
|
+
g.template compute_kern<max_a_unroll, max_b_unroll>();
|
|
947
|
+
}
|
|
948
|
+
|
|
949
|
+
// Template specializations of GEBP kernels with nr = 8.
|
|
950
|
+
#if EIGEN_USE_AVX512_GEMM_KERNELS
|
|
951
|
+
// AVX512 specialization of the single-precision GEBP traits: widens the
// rhs panel width (nr) from the generic 4 to 8 when vectorization is
// available, so the nr == 8 packing and kernel specializations below are
// selected.
template <bool ConjLhs_, bool ConjRhs_, int PacketSize_>
class gebp_traits<float, float, ConjLhs_, ConjRhs_, Architecture::Target, PacketSize_>
    : public gebp_traits<float, float, ConjLhs_, ConjRhs_, Architecture::Generic, PacketSize_> {
  using Base = gebp_traits<float, float, ConjLhs_, ConjRhs_, Architecture::Generic, PacketSize_>;

 public:
  // 8 rhs columns per panel when packets are usable; generic 4 otherwise.
  enum { nr = Base::Vectorizable ? 8 : 4 };
};
|
|
959
|
+
|
|
960
|
+
// AVX512 specialization of the double-precision GEBP traits: same as the
// float specialization above — nr == 8 when vectorization is available so
// the nr == 8 packing and kernel specializations are selected.
template <bool ConjLhs_, bool ConjRhs_, int PacketSize_>
class gebp_traits<double, double, ConjLhs_, ConjRhs_, Architecture::Target, PacketSize_>
    : public gebp_traits<double, double, ConjLhs_, ConjRhs_, Architecture::Generic, PacketSize_> {
  using Base = gebp_traits<double, double, ConjLhs_, ConjRhs_, Architecture::Generic, PacketSize_>;

 public:
  // 8 rhs columns per panel when packets are usable; generic 4 otherwise.
  enum { nr = Base::Vectorizable ? 8 : 4 };
};
|
|
968
|
+
|
|
969
|
+
// Packs a column-major rhs block into blockB with an interleave factor of
// nr == 8 (out-of-line definition below).  stride/offset implement
// panel-mode padding around each packed column group.
template <typename Scalar, typename Index, typename DataMapper, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<Scalar, Index, DataMapper, 8, ColMajor, Conjugate, PanelMode> {
  typedef typename packet_traits<Scalar>::type Packet;
  typedef typename DataMapper::LinearMapper LinearMapper;
  enum { PacketSize = packet_traits<Scalar>::size };
  EIGEN_DONT_INLINE void operator()(Scalar *blockB, const DataMapper &rhs, Index depth, Index cols, Index stride = 0,
                                    Index offset = 0);
};
|
|
977
|
+
|
|
978
|
+
// Packs `depth` rows of `cols` rhs columns (column-major storage) into
// blockB so the GEBP kernel can stream them: columns are interleaved in
// groups of 8, then 4, then 1, and each stored value is conjugated when
// Conjugate is set for a complex Scalar.  `count` tracks the running write
// position in blockB throughout.
template <typename Scalar, typename Index, typename DataMapper, bool Conjugate, bool PanelMode>
EIGEN_DONT_INLINE void gemm_pack_rhs<Scalar, Index, DataMapper, 8, ColMajor, Conjugate, PanelMode>::operator()(
    Scalar *blockB, const DataMapper &rhs, Index depth, Index cols, Index stride, Index offset) {
  constexpr int nr = 8;
  EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK RHS COLMAJOR");
  EIGEN_UNUSED_VARIABLE(stride);
  EIGEN_UNUSED_VARIABLE(offset);
  // stride/offset are only meaningful in PanelMode; otherwise both must be 0.
  eigen_assert(((!PanelMode) && stride == 0 && offset == 0) || (PanelMode && stride >= depth && offset <= stride));
  // Conjugates only when Scalar is complex AND conjugation was requested.
  conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
  // Column counts handled by the 8-wide and 4-wide interleaving loops.
  Index packet_cols8 = nr >= 8 ? (cols / 8) * 8 : 0;
  Index packet_cols4 = nr >= 4 ? (cols / 4) * 4 : 0;
  Index count = 0;  // running write position in blockB
  // Depth rounded down to a whole number of packets (vectorized portion).
  const Index peeled_k = (depth / PacketSize) * PacketSize;
  if (nr >= 8) {
    // Interleave 8 columns at a time.
    for (Index j2 = 0; j2 < packet_cols8; j2 += 8) {
      // skip what we have before
      if (PanelMode) count += 8 * offset;
      const LinearMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
      const LinearMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
      const LinearMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
      const LinearMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
      const LinearMapper dm4 = rhs.getLinearMapper(0, j2 + 4);
      const LinearMapper dm5 = rhs.getLinearMapper(0, j2 + 5);
      const LinearMapper dm6 = rhs.getLinearMapper(0, j2 + 6);
      const LinearMapper dm7 = rhs.getLinearMapper(0, j2 + 7);
      Index k = 0;
      if ((PacketSize % 8) == 0)  // TODO enable vectorized transposition for PacketSize==4
      {
        // Vectorized path: load one packet per column, transpose the
        // PacketSize x 8 tile in registers, store interleaved.
        for (; k < peeled_k; k += PacketSize) {
          PacketBlock<Packet, (PacketSize % 8) == 0 ? 8 : PacketSize> kernel;

          kernel.packet[0] = dm0.template loadPacket<Packet>(k);
          kernel.packet[1] = dm1.template loadPacket<Packet>(k);
          kernel.packet[2] = dm2.template loadPacket<Packet>(k);
          kernel.packet[3] = dm3.template loadPacket<Packet>(k);
          kernel.packet[4] = dm4.template loadPacket<Packet>(k);
          kernel.packet[5] = dm5.template loadPacket<Packet>(k);
          kernel.packet[6] = dm6.template loadPacket<Packet>(k);
          kernel.packet[7] = dm7.template loadPacket<Packet>(k);

          ptranspose(kernel);

          // In this branch PacketSize >= 8, so the "% PacketSize" on the
          // indices is a no-op kept for symmetry with the 4-wide path.
          pstoreu(blockB + count + 0 * PacketSize, cj.pconj(kernel.packet[0]));
          pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel.packet[1 % PacketSize]));
          pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel.packet[2 % PacketSize]));
          pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel.packet[3 % PacketSize]));
          pstoreu(blockB + count + 4 * PacketSize, cj.pconj(kernel.packet[4 % PacketSize]));
          pstoreu(blockB + count + 5 * PacketSize, cj.pconj(kernel.packet[5 % PacketSize]));
          pstoreu(blockB + count + 6 * PacketSize, cj.pconj(kernel.packet[6 % PacketSize]));
          pstoreu(blockB + count + 7 * PacketSize, cj.pconj(kernel.packet[7 % PacketSize]));
          count += 8 * PacketSize;
        }
      }
      // Scalar tail for the rows not covered by whole packets.
      for (; k < depth; k++) {
        blockB[count + 0] = cj(dm0(k));
        blockB[count + 1] = cj(dm1(k));
        blockB[count + 2] = cj(dm2(k));
        blockB[count + 3] = cj(dm3(k));
        blockB[count + 4] = cj(dm4(k));
        blockB[count + 5] = cj(dm5(k));
        blockB[count + 6] = cj(dm6(k));
        blockB[count + 7] = cj(dm7(k));
        count += 8;
      }
      // skip what we have after
      if (PanelMode) count += 8 * (stride - offset - depth);
    }
  }

  if (nr >= 4) {
    // Interleave the remaining groups of 4 columns.
    for (Index j2 = packet_cols8; j2 < packet_cols4; j2 += 4) {
      // skip what we have before
      if (PanelMode) count += 4 * offset;
      const LinearMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
      const LinearMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
      const LinearMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
      const LinearMapper dm3 = rhs.getLinearMapper(0, j2 + 3);

      Index k = 0;
      if ((PacketSize % 4) == 0)  // TODO enable vectorized transposition for PacketSize==2 ??
      {
        for (; k < peeled_k; k += PacketSize) {
          PacketBlock<Packet, (PacketSize % 4) == 0 ? 4 : PacketSize> kernel;
          // "% PacketSize" guards the index against a PacketBlock smaller
          // than 4 (the non-(PacketSize%4==0) instantiation of this code).
          kernel.packet[0] = dm0.template loadPacket<Packet>(k);
          kernel.packet[1 % PacketSize] = dm1.template loadPacket<Packet>(k);
          kernel.packet[2 % PacketSize] = dm2.template loadPacket<Packet>(k);
          kernel.packet[3 % PacketSize] = dm3.template loadPacket<Packet>(k);
          ptranspose(kernel);
          pstoreu(blockB + count + 0 * PacketSize, cj.pconj(kernel.packet[0]));
          pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel.packet[1 % PacketSize]));
          pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel.packet[2 % PacketSize]));
          pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel.packet[3 % PacketSize]));
          count += 4 * PacketSize;
        }
      }
      // Scalar tail rows.
      for (; k < depth; k++) {
        blockB[count + 0] = cj(dm0(k));
        blockB[count + 1] = cj(dm1(k));
        blockB[count + 2] = cj(dm2(k));
        blockB[count + 3] = cj(dm3(k));
        count += 4;
      }
      // skip what we have after
      if (PanelMode) count += 4 * (stride - offset - depth);
    }
  }

  // copy the remaining columns one at a time (nr==1)
  for (Index j2 = packet_cols4; j2 < cols; ++j2) {
    if (PanelMode) count += offset;
    const LinearMapper dm0 = rhs.getLinearMapper(0, j2);
    for (Index k = 0; k < depth; k++) {
      blockB[count] = cj(dm0(k));
      count += 1;
    }
    if (PanelMode) count += (stride - offset - depth);
  }
}
|
|
1096
|
+
|
|
1097
|
+
// Packs a row-major rhs block into blockB with an interleave factor of
// nr == 8.  Because the source is row-major, each group of 8 (or 4)
// consecutive columns of one row is already contiguous, so no transpose is
// needed — the widest available packet (full, half, or quarter) whose size
// matches the group width is used for the copy.
template <typename Scalar, typename Index, typename DataMapper, bool Conjugate, bool PanelMode>
struct gemm_pack_rhs<Scalar, Index, DataMapper, 8, RowMajor, Conjugate, PanelMode> {
  typedef typename packet_traits<Scalar>::type Packet;
  typedef typename unpacket_traits<Packet>::half HalfPacket;
  typedef typename unpacket_traits<typename unpacket_traits<Packet>::half>::half QuarterPacket;
  typedef typename DataMapper::LinearMapper LinearMapper;
  enum {
    PacketSize = packet_traits<Scalar>::size,
    HalfPacketSize = unpacket_traits<HalfPacket>::size,
    QuarterPacketSize = unpacket_traits<QuarterPacket>::size
  };
  EIGEN_DONT_INLINE void operator()(Scalar *blockB, const DataMapper &rhs, Index depth, Index cols, Index stride = 0,
                                    Index offset = 0) {
    constexpr int nr = 8;
    EIGEN_ASM_COMMENT("EIGEN PRODUCT PACK RHS ROWMAJOR");
    EIGEN_UNUSED_VARIABLE(stride);
    EIGEN_UNUSED_VARIABLE(offset);
    // stride/offset are only meaningful in PanelMode; otherwise both must be 0.
    eigen_assert(((!PanelMode) && stride == 0 && offset == 0) || (PanelMode && stride >= depth && offset <= stride));
    // Whether genuinely narrower half/quarter packet types exist for Scalar.
    const bool HasHalf = (int)HalfPacketSize < (int)PacketSize;
    const bool HasQuarter = (int)QuarterPacketSize < (int)HalfPacketSize;
    // Conjugates only when Scalar is complex AND conjugation was requested.
    conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
    // Column counts handled by the 8-wide and 4-wide loops.
    Index packet_cols8 = nr >= 8 ? (cols / 8) * 8 : 0;
    Index packet_cols4 = nr >= 4 ? (cols / 4) * 4 : 0;
    Index count = 0;  // running write position in blockB

    if (nr >= 8) {
      // Copy 8 consecutive columns per row; they are contiguous in the
      // row-major source, so this is a straight (possibly packetized) copy.
      for (Index j2 = 0; j2 < packet_cols8; j2 += 8) {
        // skip what we have before
        if (PanelMode) count += 8 * offset;
        for (Index k = 0; k < depth; k++) {
          if (PacketSize == 8) {
            // Packet A = ploadu<Packet>(&rhs.data()[k*rhs.stride() + j2]);
            Packet A = rhs.template loadPacket<Packet>(k, j2);
            pstoreu(blockB + count, cj.pconj(A));
          } else if (HasHalf && HalfPacketSize == 8) {
            HalfPacket A = rhs.template loadPacket<HalfPacket>(k, j2);
            pstoreu(blockB + count, cj.pconj(A));
          } else if (HasQuarter && QuarterPacketSize == 8) {
            QuarterPacket A = rhs.template loadPacket<QuarterPacket>(k, j2);
            pstoreu(blockB + count, cj.pconj(A));
          } else if (PacketSize == 4) {
            // Two 4-wide packets cover the 8 columns.
            // Packet A = ploadu<Packet>(&rhs.data()[k*rhs.stride() + j2]);
            // Packet B = ploadu<Packet>(&rhs.data()[k*rhs.stride() + j2 + PacketSize]);
            Packet A = rhs.template loadPacket<Packet>(k, j2);
            Packet B = rhs.template loadPacket<Packet>(k, j2 + PacketSize);
            pstoreu(blockB + count, cj.pconj(A));
            pstoreu(blockB + count + PacketSize, cj.pconj(B));
          } else {
            // Scalar fallback when no packet width fits.
            // const Scalar* b0 = &rhs.data()[k*rhs.stride() + j2];
            const LinearMapper dm0 = rhs.getLinearMapper(k, j2);
            blockB[count + 0] = cj(dm0(0));
            blockB[count + 1] = cj(dm0(1));
            blockB[count + 2] = cj(dm0(2));
            blockB[count + 3] = cj(dm0(3));
            blockB[count + 4] = cj(dm0(4));
            blockB[count + 5] = cj(dm0(5));
            blockB[count + 6] = cj(dm0(6));
            blockB[count + 7] = cj(dm0(7));
          }
          count += 8;
        }
        // skip what we have after
        if (PanelMode) count += 8 * (stride - offset - depth);
      }
    }

    if (nr >= 4) {
      // Copy the remaining groups of 4 consecutive columns per row.
      for (Index j2 = packet_cols8; j2 < packet_cols4; j2 += 4) {
        // skip what we have before
        if (PanelMode) count += 4 * offset;
        for (Index k = 0; k < depth; k++) {
          if (PacketSize == 4) {
            Packet A = rhs.template loadPacket<Packet>(k, j2);
            pstoreu(blockB + count, cj.pconj(A));
            count += PacketSize;
          } else if (HasHalf && HalfPacketSize == 4) {
            HalfPacket A = rhs.template loadPacket<HalfPacket>(k, j2);
            pstoreu(blockB + count, cj.pconj(A));
            count += HalfPacketSize;
          } else if (HasQuarter && QuarterPacketSize == 4) {
            QuarterPacket A = rhs.template loadPacket<QuarterPacket>(k, j2);
            pstoreu(blockB + count, cj.pconj(A));
            count += QuarterPacketSize;
          } else {
            // Scalar fallback when no packet width fits.
            const LinearMapper dm0 = rhs.getLinearMapper(k, j2);
            blockB[count + 0] = cj(dm0(0));
            blockB[count + 1] = cj(dm0(1));
            blockB[count + 2] = cj(dm0(2));
            blockB[count + 3] = cj(dm0(3));
            count += 4;
          }
        }
        // skip what we have after
        if (PanelMode) count += 4 * (stride - offset - depth);
      }
    }
    // copy the remaining columns one at a time (nr==1)
    for (Index j2 = packet_cols4; j2 < cols; ++j2) {
      if (PanelMode) count += offset;
      for (Index k = 0; k < depth; k++) {
        blockB[count] = cj(rhs(k, j2));
        count += 1;
      }
      if (PanelMode) count += stride - offset - depth;
    }
  }
};
|
|
1204
|
+
|
|
1205
|
+
// GEBP kernel specialization for nr == 8 panels with identical lhs/rhs
// scalar types; it forwards the packed-panel product to the AVX512 gemm
// kernel (out-of-line definition below).
template <typename Scalar, typename Index, typename DataMapper, int mr, bool ConjugateLhs, bool ConjugateRhs>
struct gebp_kernel<Scalar, Scalar, Index, DataMapper, mr, 8, ConjugateLhs, ConjugateRhs> {
  EIGEN_ALWAYS_INLINE void operator()(const DataMapper &res, const Scalar *blockA, const Scalar *blockB, Index rows,
                                      Index depth, Index cols, Scalar alpha, Index strideA = -1, Index strideB = -1,
                                      Index offsetA = 0, Index offsetB = 0);
};
|
|
1211
|
+
|
|
1212
|
+
// Dispatches the packed-panel product res += alpha * blockA * blockB to
// gemm_kern_avx512.  The runtime conditions (alpha == 1 and unit output
// increment) select among four distinct template instantiations, since
// those flags are compile-time parameters of the kernel; is_beta0 is
// always false here (no overwrite path at this call site).
template <typename Scalar, typename Index, typename DataMapper, int mr, bool ConjugateLhs, bool ConjugateRhs>
EIGEN_ALWAYS_INLINE void gebp_kernel<Scalar, Scalar, Index, DataMapper, mr, 8, ConjugateLhs, ConjugateRhs>::operator()(
    const DataMapper &res, const Scalar *blockA, const Scalar *blockB, Index rows, Index depth, Index cols,
    Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) {
  if (res.incr() == 1) {
    // Contiguous output (is_unit_inc = true).
    if (alpha == 1) {
      gemm_kern_avx512<Scalar, mr, 8, true, false, true>(rows, cols, depth, &alpha, blockA, blockB,
                                                         (Scalar *)res.data(), res.stride(), res.incr(), strideA,
                                                         strideB, offsetA, offsetB);
    } else {
      gemm_kern_avx512<Scalar, mr, 8, false, false, true>(rows, cols, depth, &alpha, blockA, blockB,
                                                          (Scalar *)res.data(), res.stride(), res.incr(), strideA,
                                                          strideB, offsetA, offsetB);
    }
  } else {
    // Strided output (is_unit_inc = false).
    if (alpha == 1) {
      gemm_kern_avx512<Scalar, mr, 8, true, false, false>(rows, cols, depth, &alpha, blockA, blockB,
                                                          (Scalar *)res.data(), res.stride(), res.incr(), strideA,
                                                          strideB, offsetA, offsetB);
    } else {
      gemm_kern_avx512<Scalar, mr, 8, false, false, false>(rows, cols, depth, &alpha, blockA, blockB,
                                                           (Scalar *)res.data(), res.stride(), res.incr(), strideA,
                                                           strideB, offsetA, offsetB);
    }
  }
}
|
|
1238
|
+
#endif // EIGEN_USE_AVX512_GEMM_KERNELS
|
|
1239
|
+
|
|
1240
|
+
} // namespace internal
|
|
1241
|
+
} // namespace Eigen
|
|
1242
|
+
|
|
1243
|
+
#undef SECOND_FETCH
|
|
1244
|
+
|
|
1245
|
+
#endif // EIGEN_CORE_ARCH_AVX512_GEMM_KERNEL_H
|