@smake/eigen 1.0.2 → 1.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +1 -1
- package/eigen/Eigen/AccelerateSupport +52 -0
- package/eigen/Eigen/Cholesky +18 -21
- package/eigen/Eigen/CholmodSupport +28 -28
- package/eigen/Eigen/Core +235 -326
- package/eigen/Eigen/Eigenvalues +16 -14
- package/eigen/Eigen/Geometry +21 -24
- package/eigen/Eigen/Householder +9 -8
- package/eigen/Eigen/IterativeLinearSolvers +8 -4
- package/eigen/Eigen/Jacobi +14 -14
- package/eigen/Eigen/KLUSupport +43 -0
- package/eigen/Eigen/LU +16 -20
- package/eigen/Eigen/MetisSupport +12 -12
- package/eigen/Eigen/OrderingMethods +54 -54
- package/eigen/Eigen/PaStiXSupport +23 -20
- package/eigen/Eigen/PardisoSupport +17 -14
- package/eigen/Eigen/QR +18 -21
- package/eigen/Eigen/QtAlignedMalloc +5 -13
- package/eigen/Eigen/SPQRSupport +21 -14
- package/eigen/Eigen/SVD +23 -18
- package/eigen/Eigen/Sparse +1 -4
- package/eigen/Eigen/SparseCholesky +18 -23
- package/eigen/Eigen/SparseCore +18 -17
- package/eigen/Eigen/SparseLU +12 -8
- package/eigen/Eigen/SparseQR +16 -14
- package/eigen/Eigen/StdDeque +5 -2
- package/eigen/Eigen/StdList +5 -2
- package/eigen/Eigen/StdVector +5 -2
- package/eigen/Eigen/SuperLUSupport +30 -24
- package/eigen/Eigen/ThreadPool +80 -0
- package/eigen/Eigen/UmfPackSupport +19 -17
- package/eigen/Eigen/Version +14 -0
- package/eigen/Eigen/src/AccelerateSupport/AccelerateSupport.h +423 -0
- package/eigen/Eigen/src/AccelerateSupport/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/Cholesky/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/Cholesky/LDLT.h +377 -401
- package/eigen/Eigen/src/Cholesky/LLT.h +332 -360
- package/eigen/Eigen/src/Cholesky/LLT_LAPACKE.h +81 -56
- package/eigen/Eigen/src/CholmodSupport/CholmodSupport.h +620 -521
- package/eigen/Eigen/src/CholmodSupport/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/Core/ArithmeticSequence.h +239 -0
- package/eigen/Eigen/src/Core/Array.h +341 -294
- package/eigen/Eigen/src/Core/ArrayBase.h +190 -203
- package/eigen/Eigen/src/Core/ArrayWrapper.h +127 -171
- package/eigen/Eigen/src/Core/Assign.h +30 -40
- package/eigen/Eigen/src/Core/AssignEvaluator.h +711 -589
- package/eigen/Eigen/src/Core/Assign_MKL.h +130 -125
- package/eigen/Eigen/src/Core/BandMatrix.h +268 -283
- package/eigen/Eigen/src/Core/Block.h +375 -398
- package/eigen/Eigen/src/Core/CommaInitializer.h +86 -97
- package/eigen/Eigen/src/Core/ConditionEstimator.h +51 -53
- package/eigen/Eigen/src/Core/CoreEvaluators.h +1356 -1026
- package/eigen/Eigen/src/Core/CoreIterators.h +73 -59
- package/eigen/Eigen/src/Core/CwiseBinaryOp.h +114 -132
- package/eigen/Eigen/src/Core/CwiseNullaryOp.h +726 -617
- package/eigen/Eigen/src/Core/CwiseTernaryOp.h +77 -103
- package/eigen/Eigen/src/Core/CwiseUnaryOp.h +56 -68
- package/eigen/Eigen/src/Core/CwiseUnaryView.h +132 -95
- package/eigen/Eigen/src/Core/DenseBase.h +632 -571
- package/eigen/Eigen/src/Core/DenseCoeffsBase.h +511 -624
- package/eigen/Eigen/src/Core/DenseStorage.h +512 -509
- package/eigen/Eigen/src/Core/DeviceWrapper.h +153 -0
- package/eigen/Eigen/src/Core/Diagonal.h +169 -210
- package/eigen/Eigen/src/Core/DiagonalMatrix.h +351 -274
- package/eigen/Eigen/src/Core/DiagonalProduct.h +12 -10
- package/eigen/Eigen/src/Core/Dot.h +172 -222
- package/eigen/Eigen/src/Core/EigenBase.h +75 -85
- package/eigen/Eigen/src/Core/Fill.h +138 -0
- package/eigen/Eigen/src/Core/FindCoeff.h +464 -0
- package/eigen/Eigen/src/Core/ForceAlignedAccess.h +90 -109
- package/eigen/Eigen/src/Core/Fuzzy.h +82 -105
- package/eigen/Eigen/src/Core/GeneralProduct.h +327 -263
- package/eigen/Eigen/src/Core/GenericPacketMath.h +1472 -360
- package/eigen/Eigen/src/Core/GlobalFunctions.h +194 -151
- package/eigen/Eigen/src/Core/IO.h +147 -139
- package/eigen/Eigen/src/Core/IndexedView.h +321 -0
- package/eigen/Eigen/src/Core/InnerProduct.h +260 -0
- package/eigen/Eigen/src/Core/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/Core/Inverse.h +56 -66
- package/eigen/Eigen/src/Core/Map.h +124 -142
- package/eigen/Eigen/src/Core/MapBase.h +256 -281
- package/eigen/Eigen/src/Core/MathFunctions.h +1620 -938
- package/eigen/Eigen/src/Core/MathFunctionsImpl.h +233 -71
- package/eigen/Eigen/src/Core/Matrix.h +491 -416
- package/eigen/Eigen/src/Core/MatrixBase.h +468 -453
- package/eigen/Eigen/src/Core/NestByValue.h +66 -85
- package/eigen/Eigen/src/Core/NoAlias.h +79 -85
- package/eigen/Eigen/src/Core/NumTraits.h +235 -148
- package/eigen/Eigen/src/Core/PartialReduxEvaluator.h +253 -0
- package/eigen/Eigen/src/Core/PermutationMatrix.h +461 -511
- package/eigen/Eigen/src/Core/PlainObjectBase.h +871 -894
- package/eigen/Eigen/src/Core/Product.h +260 -139
- package/eigen/Eigen/src/Core/ProductEvaluators.h +863 -714
- package/eigen/Eigen/src/Core/Random.h +161 -136
- package/eigen/Eigen/src/Core/RandomImpl.h +262 -0
- package/eigen/Eigen/src/Core/RealView.h +250 -0
- package/eigen/Eigen/src/Core/Redux.h +366 -336
- package/eigen/Eigen/src/Core/Ref.h +308 -209
- package/eigen/Eigen/src/Core/Replicate.h +94 -106
- package/eigen/Eigen/src/Core/Reshaped.h +398 -0
- package/eigen/Eigen/src/Core/ReturnByValue.h +49 -55
- package/eigen/Eigen/src/Core/Reverse.h +136 -145
- package/eigen/Eigen/src/Core/Select.h +70 -140
- package/eigen/Eigen/src/Core/SelfAdjointView.h +262 -285
- package/eigen/Eigen/src/Core/SelfCwiseBinaryOp.h +23 -20
- package/eigen/Eigen/src/Core/SkewSymmetricMatrix3.h +382 -0
- package/eigen/Eigen/src/Core/Solve.h +97 -111
- package/eigen/Eigen/src/Core/SolveTriangular.h +131 -129
- package/eigen/Eigen/src/Core/SolverBase.h +138 -101
- package/eigen/Eigen/src/Core/StableNorm.h +156 -160
- package/eigen/Eigen/src/Core/StlIterators.h +619 -0
- package/eigen/Eigen/src/Core/Stride.h +91 -88
- package/eigen/Eigen/src/Core/Swap.h +70 -38
- package/eigen/Eigen/src/Core/Transpose.h +295 -273
- package/eigen/Eigen/src/Core/Transpositions.h +272 -317
- package/eigen/Eigen/src/Core/TriangularMatrix.h +670 -755
- package/eigen/Eigen/src/Core/VectorBlock.h +59 -72
- package/eigen/Eigen/src/Core/VectorwiseOp.h +668 -630
- package/eigen/Eigen/src/Core/Visitor.h +480 -216
- package/eigen/Eigen/src/Core/arch/AVX/Complex.h +407 -293
- package/eigen/Eigen/src/Core/arch/AVX/MathFunctions.h +79 -388
- package/eigen/Eigen/src/Core/arch/AVX/PacketMath.h +2935 -491
- package/eigen/Eigen/src/Core/arch/AVX/Reductions.h +353 -0
- package/eigen/Eigen/src/Core/arch/AVX/TypeCasting.h +279 -22
- package/eigen/Eigen/src/Core/arch/AVX512/Complex.h +472 -0
- package/eigen/Eigen/src/Core/arch/AVX512/GemmKernel.h +1245 -0
- package/eigen/Eigen/src/Core/arch/AVX512/MathFunctions.h +85 -333
- package/eigen/Eigen/src/Core/arch/AVX512/MathFunctionsFP16.h +75 -0
- package/eigen/Eigen/src/Core/arch/AVX512/PacketMath.h +2490 -649
- package/eigen/Eigen/src/Core/arch/AVX512/PacketMathFP16.h +1413 -0
- package/eigen/Eigen/src/Core/arch/AVX512/Reductions.h +297 -0
- package/eigen/Eigen/src/Core/arch/AVX512/TrsmKernel.h +1167 -0
- package/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc +1219 -0
- package/eigen/Eigen/src/Core/arch/AVX512/TypeCasting.h +277 -0
- package/eigen/Eigen/src/Core/arch/AVX512/TypeCastingFP16.h +130 -0
- package/eigen/Eigen/src/Core/arch/AltiVec/Complex.h +521 -298
- package/eigen/Eigen/src/Core/arch/AltiVec/MathFunctions.h +39 -280
- package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProduct.h +3686 -0
- package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h +205 -0
- package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h +901 -0
- package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h +742 -0
- package/eigen/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.inc +2818 -0
- package/eigen/Eigen/src/Core/arch/AltiVec/PacketMath.h +3391 -723
- package/eigen/Eigen/src/Core/arch/AltiVec/TypeCasting.h +153 -0
- package/eigen/Eigen/src/Core/arch/Default/BFloat16.h +866 -0
- package/eigen/Eigen/src/Core/arch/Default/ConjHelper.h +113 -14
- package/eigen/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +2634 -0
- package/eigen/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h +227 -0
- package/eigen/Eigen/src/Core/arch/Default/Half.h +1091 -0
- package/eigen/Eigen/src/Core/arch/Default/Settings.h +11 -13
- package/eigen/Eigen/src/Core/arch/GPU/Complex.h +244 -0
- package/eigen/Eigen/src/Core/arch/GPU/MathFunctions.h +104 -0
- package/eigen/Eigen/src/Core/arch/GPU/PacketMath.h +1712 -0
- package/eigen/Eigen/src/Core/arch/GPU/Tuple.h +268 -0
- package/eigen/Eigen/src/Core/arch/GPU/TypeCasting.h +77 -0
- package/eigen/Eigen/src/Core/arch/HIP/hcc/math_constants.h +23 -0
- package/eigen/Eigen/src/Core/arch/HVX/PacketMath.h +1088 -0
- package/eigen/Eigen/src/Core/arch/LSX/Complex.h +520 -0
- package/eigen/Eigen/src/Core/arch/LSX/GeneralBlockPanelKernel.h +23 -0
- package/eigen/Eigen/src/Core/arch/LSX/MathFunctions.h +43 -0
- package/eigen/Eigen/src/Core/arch/LSX/PacketMath.h +2866 -0
- package/eigen/Eigen/src/Core/arch/LSX/TypeCasting.h +526 -0
- package/eigen/Eigen/src/Core/arch/MSA/Complex.h +620 -0
- package/eigen/Eigen/src/Core/arch/MSA/MathFunctions.h +379 -0
- package/eigen/Eigen/src/Core/arch/MSA/PacketMath.h +1237 -0
- package/eigen/Eigen/src/Core/arch/NEON/Complex.h +531 -289
- package/eigen/Eigen/src/Core/arch/NEON/GeneralBlockPanelKernel.h +243 -0
- package/eigen/Eigen/src/Core/arch/NEON/MathFunctions.h +50 -73
- package/eigen/Eigen/src/Core/arch/NEON/PacketMath.h +5915 -579
- package/eigen/Eigen/src/Core/arch/NEON/TypeCasting.h +1642 -0
- package/eigen/Eigen/src/Core/arch/NEON/UnaryFunctors.h +57 -0
- package/eigen/Eigen/src/Core/arch/SSE/Complex.h +366 -334
- package/eigen/Eigen/src/Core/arch/SSE/MathFunctions.h +40 -514
- package/eigen/Eigen/src/Core/arch/SSE/PacketMath.h +2164 -675
- package/eigen/Eigen/src/Core/arch/SSE/Reductions.h +324 -0
- package/eigen/Eigen/src/Core/arch/SSE/TypeCasting.h +188 -35
- package/eigen/Eigen/src/Core/arch/SVE/MathFunctions.h +48 -0
- package/eigen/Eigen/src/Core/arch/SVE/PacketMath.h +674 -0
- package/eigen/Eigen/src/Core/arch/SVE/TypeCasting.h +52 -0
- package/eigen/Eigen/src/Core/arch/SYCL/InteropHeaders.h +227 -0
- package/eigen/Eigen/src/Core/arch/SYCL/MathFunctions.h +303 -0
- package/eigen/Eigen/src/Core/arch/SYCL/PacketMath.h +576 -0
- package/eigen/Eigen/src/Core/arch/SYCL/TypeCasting.h +83 -0
- package/eigen/Eigen/src/Core/arch/ZVector/Complex.h +434 -261
- package/eigen/Eigen/src/Core/arch/ZVector/MathFunctions.h +160 -53
- package/eigen/Eigen/src/Core/arch/ZVector/PacketMath.h +1073 -605
- package/eigen/Eigen/src/Core/functors/AssignmentFunctors.h +123 -117
- package/eigen/Eigen/src/Core/functors/BinaryFunctors.h +594 -322
- package/eigen/Eigen/src/Core/functors/NullaryFunctors.h +204 -118
- package/eigen/Eigen/src/Core/functors/StlFunctors.h +110 -97
- package/eigen/Eigen/src/Core/functors/TernaryFunctors.h +34 -7
- package/eigen/Eigen/src/Core/functors/UnaryFunctors.h +1158 -530
- package/eigen/Eigen/src/Core/products/GeneralBlockPanelKernel.h +2329 -1333
- package/eigen/Eigen/src/Core/products/GeneralMatrixMatrix.h +328 -364
- package/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h +191 -178
- package/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h +85 -82
- package/eigen/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h +154 -73
- package/eigen/Eigen/src/Core/products/GeneralMatrixVector.h +396 -542
- package/eigen/Eigen/src/Core/products/GeneralMatrixVector_BLAS.h +80 -77
- package/eigen/Eigen/src/Core/products/Parallelizer.h +208 -92
- package/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix.h +331 -375
- package/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix_BLAS.h +206 -224
- package/eigen/Eigen/src/Core/products/SelfadjointMatrixVector.h +139 -146
- package/eigen/Eigen/src/Core/products/SelfadjointMatrixVector_BLAS.h +58 -61
- package/eigen/Eigen/src/Core/products/SelfadjointProduct.h +71 -71
- package/eigen/Eigen/src/Core/products/SelfadjointRank2Update.h +48 -46
- package/eigen/Eigen/src/Core/products/TriangularMatrixMatrix.h +294 -369
- package/eigen/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h +246 -238
- package/eigen/Eigen/src/Core/products/TriangularMatrixVector.h +244 -247
- package/eigen/Eigen/src/Core/products/TriangularMatrixVector_BLAS.h +212 -192
- package/eigen/Eigen/src/Core/products/TriangularSolverMatrix.h +328 -275
- package/eigen/Eigen/src/Core/products/TriangularSolverMatrix_BLAS.h +108 -109
- package/eigen/Eigen/src/Core/products/TriangularSolverVector.h +70 -93
- package/eigen/Eigen/src/Core/util/Assert.h +158 -0
- package/eigen/Eigen/src/Core/util/BlasUtil.h +413 -290
- package/eigen/Eigen/src/Core/util/ConfigureVectorization.h +543 -0
- package/eigen/Eigen/src/Core/util/Constants.h +314 -263
- package/eigen/Eigen/src/Core/util/DisableStupidWarnings.h +130 -78
- package/eigen/Eigen/src/Core/util/EmulateArray.h +270 -0
- package/eigen/Eigen/src/Core/util/ForwardDeclarations.h +450 -224
- package/eigen/Eigen/src/Core/util/GpuHipCudaDefines.inc +101 -0
- package/eigen/Eigen/src/Core/util/GpuHipCudaUndefines.inc +45 -0
- package/eigen/Eigen/src/Core/util/IndexedViewHelper.h +487 -0
- package/eigen/Eigen/src/Core/util/IntegralConstant.h +279 -0
- package/eigen/Eigen/src/Core/util/MKL_support.h +39 -30
- package/eigen/Eigen/src/Core/util/Macros.h +939 -646
- package/eigen/Eigen/src/Core/util/MaxSizeVector.h +139 -0
- package/eigen/Eigen/src/Core/util/Memory.h +1042 -650
- package/eigen/Eigen/src/Core/util/Meta.h +618 -426
- package/eigen/Eigen/src/Core/util/MoreMeta.h +638 -0
- package/eigen/Eigen/src/Core/util/ReenableStupidWarnings.h +32 -19
- package/eigen/Eigen/src/Core/util/ReshapedHelper.h +51 -0
- package/eigen/Eigen/src/Core/util/Serializer.h +209 -0
- package/eigen/Eigen/src/Core/util/StaticAssert.h +51 -164
- package/eigen/Eigen/src/Core/util/SymbolicIndex.h +445 -0
- package/eigen/Eigen/src/Core/util/XprHelper.h +793 -538
- package/eigen/Eigen/src/Eigenvalues/ComplexEigenSolver.h +246 -277
- package/eigen/Eigen/src/Eigenvalues/ComplexSchur.h +299 -319
- package/eigen/Eigen/src/Eigenvalues/ComplexSchur_LAPACKE.h +52 -48
- package/eigen/Eigen/src/Eigenvalues/EigenSolver.h +413 -456
- package/eigen/Eigen/src/Eigenvalues/GeneralizedEigenSolver.h +309 -325
- package/eigen/Eigen/src/Eigenvalues/GeneralizedSelfAdjointEigenSolver.h +157 -171
- package/eigen/Eigen/src/Eigenvalues/HessenbergDecomposition.h +292 -310
- package/eigen/Eigen/src/Eigenvalues/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/Eigenvalues/MatrixBaseEigenvalues.h +91 -107
- package/eigen/Eigen/src/Eigenvalues/RealQZ.h +539 -606
- package/eigen/Eigen/src/Eigenvalues/RealSchur.h +348 -382
- package/eigen/Eigen/src/Eigenvalues/RealSchur_LAPACKE.h +41 -35
- package/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver.h +579 -600
- package/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver_LAPACKE.h +47 -44
- package/eigen/Eigen/src/Eigenvalues/Tridiagonalization.h +434 -461
- package/eigen/Eigen/src/Geometry/AlignedBox.h +307 -214
- package/eigen/Eigen/src/Geometry/AngleAxis.h +135 -137
- package/eigen/Eigen/src/Geometry/EulerAngles.h +163 -74
- package/eigen/Eigen/src/Geometry/Homogeneous.h +289 -333
- package/eigen/Eigen/src/Geometry/Hyperplane.h +152 -161
- package/eigen/Eigen/src/Geometry/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/Geometry/OrthoMethods.h +168 -145
- package/eigen/Eigen/src/Geometry/ParametrizedLine.h +141 -104
- package/eigen/Eigen/src/Geometry/Quaternion.h +595 -497
- package/eigen/Eigen/src/Geometry/Rotation2D.h +110 -108
- package/eigen/Eigen/src/Geometry/RotationBase.h +148 -145
- package/eigen/Eigen/src/Geometry/Scaling.h +115 -90
- package/eigen/Eigen/src/Geometry/Transform.h +896 -953
- package/eigen/Eigen/src/Geometry/Translation.h +100 -98
- package/eigen/Eigen/src/Geometry/Umeyama.h +79 -84
- package/eigen/Eigen/src/Geometry/arch/Geometry_SIMD.h +154 -0
- package/eigen/Eigen/src/Householder/BlockHouseholder.h +54 -42
- package/eigen/Eigen/src/Householder/Householder.h +104 -122
- package/eigen/Eigen/src/Householder/HouseholderSequence.h +416 -382
- package/eigen/Eigen/src/Householder/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/IterativeLinearSolvers/BasicPreconditioners.h +153 -166
- package/eigen/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h +127 -138
- package/eigen/Eigen/src/IterativeLinearSolvers/ConjugateGradient.h +95 -124
- package/eigen/Eigen/src/IterativeLinearSolvers/IncompleteCholesky.h +269 -267
- package/eigen/Eigen/src/IterativeLinearSolvers/IncompleteLUT.h +246 -259
- package/eigen/Eigen/src/IterativeLinearSolvers/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h +218 -217
- package/eigen/Eigen/src/IterativeLinearSolvers/LeastSquareConjugateGradient.h +80 -103
- package/eigen/Eigen/src/IterativeLinearSolvers/SolveWithGuess.h +59 -63
- package/eigen/Eigen/src/Jacobi/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/Jacobi/Jacobi.h +256 -291
- package/eigen/Eigen/src/KLUSupport/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/KLUSupport/KLUSupport.h +339 -0
- package/eigen/Eigen/src/LU/Determinant.h +60 -63
- package/eigen/Eigen/src/LU/FullPivLU.h +561 -626
- package/eigen/Eigen/src/LU/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/LU/InverseImpl.h +213 -275
- package/eigen/Eigen/src/LU/PartialPivLU.h +407 -435
- package/eigen/Eigen/src/LU/PartialPivLU_LAPACKE.h +54 -40
- package/eigen/Eigen/src/LU/arch/InverseSize4.h +353 -0
- package/eigen/Eigen/src/MetisSupport/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/MetisSupport/MetisSupport.h +81 -93
- package/eigen/Eigen/src/OrderingMethods/Amd.h +250 -282
- package/eigen/Eigen/src/OrderingMethods/Eigen_Colamd.h +950 -1103
- package/eigen/Eigen/src/OrderingMethods/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/OrderingMethods/Ordering.h +111 -122
- package/eigen/Eigen/src/PaStiXSupport/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/PaStiXSupport/PaStiXSupport.h +524 -570
- package/eigen/Eigen/src/PardisoSupport/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/PardisoSupport/PardisoSupport.h +385 -429
- package/eigen/Eigen/src/QR/ColPivHouseholderQR.h +494 -473
- package/eigen/Eigen/src/QR/ColPivHouseholderQR_LAPACKE.h +120 -56
- package/eigen/Eigen/src/QR/CompleteOrthogonalDecomposition.h +223 -137
- package/eigen/Eigen/src/QR/FullPivHouseholderQR.h +517 -460
- package/eigen/Eigen/src/QR/HouseholderQR.h +412 -278
- package/eigen/Eigen/src/QR/HouseholderQR_LAPACKE.h +32 -23
- package/eigen/Eigen/src/QR/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/SPQRSupport/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/SPQRSupport/SuiteSparseQRSupport.h +263 -261
- package/eigen/Eigen/src/SVD/BDCSVD.h +872 -679
- package/eigen/Eigen/src/SVD/BDCSVD_LAPACKE.h +174 -0
- package/eigen/Eigen/src/SVD/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/SVD/JacobiSVD.h +585 -543
- package/eigen/Eigen/src/SVD/JacobiSVD_LAPACKE.h +85 -49
- package/eigen/Eigen/src/SVD/SVDBase.h +281 -160
- package/eigen/Eigen/src/SVD/UpperBidiagonalization.h +202 -237
- package/eigen/Eigen/src/SparseCholesky/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/SparseCholesky/SimplicialCholesky.h +769 -590
- package/eigen/Eigen/src/SparseCholesky/SimplicialCholesky_impl.h +318 -129
- package/eigen/Eigen/src/SparseCore/AmbiVector.h +202 -251
- package/eigen/Eigen/src/SparseCore/CompressedStorage.h +184 -236
- package/eigen/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h +140 -184
- package/eigen/Eigen/src/SparseCore/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/SparseCore/SparseAssign.h +174 -111
- package/eigen/Eigen/src/SparseCore/SparseBlock.h +408 -477
- package/eigen/Eigen/src/SparseCore/SparseColEtree.h +100 -112
- package/eigen/Eigen/src/SparseCore/SparseCompressedBase.h +531 -280
- package/eigen/Eigen/src/SparseCore/SparseCwiseBinaryOp.h +559 -347
- package/eigen/Eigen/src/SparseCore/SparseCwiseUnaryOp.h +100 -108
- package/eigen/Eigen/src/SparseCore/SparseDenseProduct.h +185 -191
- package/eigen/Eigen/src/SparseCore/SparseDiagonalProduct.h +71 -71
- package/eigen/Eigen/src/SparseCore/SparseDot.h +49 -47
- package/eigen/Eigen/src/SparseCore/SparseFuzzy.h +13 -11
- package/eigen/Eigen/src/SparseCore/SparseMap.h +243 -253
- package/eigen/Eigen/src/SparseCore/SparseMatrix.h +1614 -1142
- package/eigen/Eigen/src/SparseCore/SparseMatrixBase.h +403 -357
- package/eigen/Eigen/src/SparseCore/SparsePermutation.h +186 -115
- package/eigen/Eigen/src/SparseCore/SparseProduct.h +100 -91
- package/eigen/Eigen/src/SparseCore/SparseRedux.h +22 -24
- package/eigen/Eigen/src/SparseCore/SparseRef.h +268 -295
- package/eigen/Eigen/src/SparseCore/SparseSelfAdjointView.h +371 -414
- package/eigen/Eigen/src/SparseCore/SparseSolverBase.h +78 -87
- package/eigen/Eigen/src/SparseCore/SparseSparseProductWithPruning.h +81 -95
- package/eigen/Eigen/src/SparseCore/SparseTranspose.h +62 -71
- package/eigen/Eigen/src/SparseCore/SparseTriangularView.h +132 -144
- package/eigen/Eigen/src/SparseCore/SparseUtil.h +146 -115
- package/eigen/Eigen/src/SparseCore/SparseVector.h +426 -372
- package/eigen/Eigen/src/SparseCore/SparseView.h +164 -193
- package/eigen/Eigen/src/SparseCore/TriangularSolver.h +129 -170
- package/eigen/Eigen/src/SparseLU/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/SparseLU/SparseLU.h +814 -618
- package/eigen/Eigen/src/SparseLU/SparseLUImpl.h +61 -48
- package/eigen/Eigen/src/SparseLU/SparseLU_Memory.h +102 -118
- package/eigen/Eigen/src/SparseLU/SparseLU_Structs.h +38 -35
- package/eigen/Eigen/src/SparseLU/SparseLU_SupernodalMatrix.h +273 -255
- package/eigen/Eigen/src/SparseLU/SparseLU_Utils.h +44 -49
- package/eigen/Eigen/src/SparseLU/SparseLU_column_bmod.h +104 -108
- package/eigen/Eigen/src/SparseLU/SparseLU_column_dfs.h +90 -101
- package/eigen/Eigen/src/SparseLU/SparseLU_copy_to_ucol.h +57 -58
- package/eigen/Eigen/src/SparseLU/SparseLU_heap_relax_snode.h +43 -55
- package/eigen/Eigen/src/SparseLU/SparseLU_kernel_bmod.h +74 -71
- package/eigen/Eigen/src/SparseLU/SparseLU_panel_bmod.h +125 -133
- package/eigen/Eigen/src/SparseLU/SparseLU_panel_dfs.h +136 -159
- package/eigen/Eigen/src/SparseLU/SparseLU_pivotL.h +51 -52
- package/eigen/Eigen/src/SparseLU/SparseLU_pruneL.h +67 -73
- package/eigen/Eigen/src/SparseLU/SparseLU_relax_snode.h +24 -26
- package/eigen/Eigen/src/SparseQR/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/SparseQR/SparseQR.h +451 -490
- package/eigen/Eigen/src/StlSupport/StdDeque.h +28 -105
- package/eigen/Eigen/src/StlSupport/StdList.h +28 -84
- package/eigen/Eigen/src/StlSupport/StdVector.h +28 -108
- package/eigen/Eigen/src/StlSupport/details.h +48 -50
- package/eigen/Eigen/src/SuperLUSupport/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/SuperLUSupport/SuperLUSupport.h +634 -732
- package/eigen/Eigen/src/ThreadPool/Barrier.h +70 -0
- package/eigen/Eigen/src/ThreadPool/CoreThreadPoolDevice.h +336 -0
- package/eigen/Eigen/src/ThreadPool/EventCount.h +241 -0
- package/eigen/Eigen/src/ThreadPool/ForkJoin.h +140 -0
- package/eigen/Eigen/src/ThreadPool/InternalHeaderCheck.h +4 -0
- package/eigen/Eigen/src/ThreadPool/NonBlockingThreadPool.h +587 -0
- package/eigen/Eigen/src/ThreadPool/RunQueue.h +230 -0
- package/eigen/Eigen/src/ThreadPool/ThreadCancel.h +21 -0
- package/eigen/Eigen/src/ThreadPool/ThreadEnvironment.h +43 -0
- package/eigen/Eigen/src/ThreadPool/ThreadLocal.h +289 -0
- package/eigen/Eigen/src/ThreadPool/ThreadPoolInterface.h +50 -0
- package/eigen/Eigen/src/ThreadPool/ThreadYield.h +16 -0
- package/eigen/Eigen/src/UmfPackSupport/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/UmfPackSupport/UmfPackSupport.h +480 -380
- package/eigen/Eigen/src/misc/Image.h +41 -43
- package/eigen/Eigen/src/misc/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/misc/Kernel.h +39 -41
- package/eigen/Eigen/src/misc/RealSvd2x2.h +19 -21
- package/eigen/Eigen/src/misc/blas.h +83 -426
- package/eigen/Eigen/src/misc/lapacke.h +9976 -16182
- package/eigen/Eigen/src/misc/lapacke_helpers.h +163 -0
- package/eigen/Eigen/src/misc/lapacke_mangling.h +4 -5
- package/eigen/Eigen/src/plugins/ArrayCwiseBinaryOps.inc +344 -0
- package/eigen/Eigen/src/plugins/ArrayCwiseUnaryOps.inc +544 -0
- package/eigen/Eigen/src/plugins/BlockMethods.inc +1370 -0
- package/eigen/Eigen/src/plugins/CommonCwiseBinaryOps.inc +116 -0
- package/eigen/Eigen/src/plugins/CommonCwiseUnaryOps.inc +167 -0
- package/eigen/Eigen/src/plugins/IndexedViewMethods.inc +192 -0
- package/eigen/Eigen/src/plugins/InternalHeaderCheck.inc +3 -0
- package/eigen/Eigen/src/plugins/MatrixCwiseBinaryOps.inc +331 -0
- package/eigen/Eigen/src/plugins/MatrixCwiseUnaryOps.inc +118 -0
- package/eigen/Eigen/src/plugins/ReshapedMethods.inc +133 -0
- package/lib/LibEigen.d.ts +4 -0
- package/lib/LibEigen.js +14 -0
- package/lib/index.d.ts +1 -1
- package/lib/index.js +7 -3
- package/package.json +2 -10
- package/eigen/Eigen/CMakeLists.txt +0 -19
- package/eigen/Eigen/src/Core/BooleanRedux.h +0 -164
- package/eigen/Eigen/src/Core/arch/CUDA/Complex.h +0 -103
- package/eigen/Eigen/src/Core/arch/CUDA/Half.h +0 -675
- package/eigen/Eigen/src/Core/arch/CUDA/MathFunctions.h +0 -91
- package/eigen/Eigen/src/Core/arch/CUDA/PacketMath.h +0 -333
- package/eigen/Eigen/src/Core/arch/CUDA/PacketMathHalf.h +0 -1124
- package/eigen/Eigen/src/Core/arch/CUDA/TypeCasting.h +0 -212
- package/eigen/Eigen/src/Core/util/NonMPL2.h +0 -3
- package/eigen/Eigen/src/Geometry/arch/Geometry_SSE.h +0 -161
- package/eigen/Eigen/src/LU/arch/Inverse_SSE.h +0 -338
- package/eigen/Eigen/src/SparseCore/MappedSparseMatrix.h +0 -67
- package/eigen/Eigen/src/SparseLU/SparseLU_gemm_kernel.h +0 -280
- package/eigen/Eigen/src/misc/lapack.h +0 -152
- package/eigen/Eigen/src/plugins/ArrayCwiseBinaryOps.h +0 -332
- package/eigen/Eigen/src/plugins/ArrayCwiseUnaryOps.h +0 -552
- package/eigen/Eigen/src/plugins/BlockMethods.h +0 -1058
- package/eigen/Eigen/src/plugins/CommonCwiseBinaryOps.h +0 -115
- package/eigen/Eigen/src/plugins/CommonCwiseUnaryOps.h +0 -163
- package/eigen/Eigen/src/plugins/MatrixCwiseBinaryOps.h +0 -152
- package/eigen/Eigen/src/plugins/MatrixCwiseUnaryOps.h +0 -85
- package/lib/eigen.d.ts +0 -2
- package/lib/eigen.js +0 -15
|
@@ -0,0 +1,1219 @@
|
|
|
1
|
+
// This file is part of Eigen, a lightweight C++ template library
|
|
2
|
+
// for linear algebra.
|
|
3
|
+
//
|
|
4
|
+
// Copyright (C) 2022 Intel Corporation
|
|
5
|
+
//
|
|
6
|
+
// This Source Code Form is subject to the terms of the Mozilla
|
|
7
|
+
// Public License v. 2.0. If a copy of the MPL was not distributed
|
|
8
|
+
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
9
|
+
|
|
10
|
+
#ifndef EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H
|
|
11
|
+
#define EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H
|
|
12
|
+
|
|
13
|
+
template <bool isARowMajor = true>
|
|
14
|
+
EIGEN_ALWAYS_INLINE int64_t idA(int64_t i, int64_t j, int64_t LDA) {
|
|
15
|
+
EIGEN_IF_CONSTEXPR(isARowMajor) return i * LDA + j;
|
|
16
|
+
else return i + j * LDA;
|
|
17
|
+
}
|
|
18
|
+
|
|
19
|
+
/**
|
|
20
|
+
* This namespace contains various classes used to generate compile-time unrolls which are
|
|
21
|
+
* used throughout the trsm/gemm kernels. The unrolls are characterized as for-loops (1-D), nested
|
|
22
|
+
* for-loops (2-D), or triple nested for-loops (3-D). Unrolls are generated using template recursion.
|
|
23
|
+
*
|
|
24
|
+
* Example, the 2-D for-loop is unrolled recursively by first flattening to a 1-D loop.
|
|
25
|
+
*
|
|
26
|
+
*   for(startI = 0; startI < endI; startI++)            for(startC = 0; startC < endI*endJ; startC++)
|
|
27
|
+
*     for(startJ = 0; startJ < endJ; startJ++)  ---->     startI = (startC)/(endJ)
|
|
28
|
+
*       func(startI,startJ)                               startJ = (startC)%(endJ)
|
|
29
|
+
*                                                         func(...)
|
|
30
|
+
*
|
|
31
|
+
* The 1-D loop can be unrolled recursively by using enable_if and defining an auxiliary function
|
|
32
|
+
* with a template parameter used as a counter.
|
|
33
|
+
*
|
|
34
|
+
* template <endI, endJ, counter>
|
|
35
|
+
* std::enable_if_t<(counter <= 0)> <---- tail case.
|
|
36
|
+
* aux_func {}
|
|
37
|
+
*
|
|
38
|
+
* template <endI, endJ, counter>
|
|
39
|
+
* std::enable_if_t<(counter > 0)> <---- actual for-loop
|
|
40
|
+
* aux_func {
|
|
41
|
+
* startC = endI*endJ - counter
|
|
42
|
+
* startI = (startC)/(endJ)
|
|
43
|
+
* startJ = (startC)%(endJ)
|
|
44
|
+
* func(startI, startJ)
|
|
45
|
+
* aux_func<endI, endJ, counter-1>()
|
|
46
|
+
* }
|
|
47
|
+
*
|
|
48
|
+
* Note: Additional wrapper functions are provided for aux_func which hide the counter template
|
|
49
|
+
* parameter since counter usually depends on endI, endJ, etc...
|
|
50
|
+
*
|
|
51
|
+
* Conventions:
|
|
52
|
+
* 1) endX: specifies the terminal value for the for-loop, (ex: for(startX = 0; startX < endX; startX++))
|
|
53
|
+
*
|
|
54
|
+
* 2) rem, remM, remK template parameters are used for deciding whether to use masked operations for
|
|
55
|
+
* handling remaining tails (when sizes are not multiples of PacketSize or EIGEN_AVX_MAX_NUM_ROW)
|
|
56
|
+
*/
|
|
57
|
+
namespace unrolls {
|
|
58
|
+
|
|
59
|
+
template <int64_t N>
|
|
60
|
+
EIGEN_ALWAYS_INLINE auto remMask(int64_t m) {
|
|
61
|
+
EIGEN_IF_CONSTEXPR(N == 16) { return 0xFFFF >> (16 - m); }
|
|
62
|
+
else EIGEN_IF_CONSTEXPR(N == 8) {
|
|
63
|
+
return 0xFF >> (8 - m);
|
|
64
|
+
}
|
|
65
|
+
else EIGEN_IF_CONSTEXPR(N == 4) {
|
|
66
|
+
return 0x0F >> (4 - m);
|
|
67
|
+
}
|
|
68
|
+
return 0;
|
|
69
|
+
}
|
|
70
|
+
|
|
71
|
+
// Primary template for trans8x8blocks: declaration only. It takes a block of 8 packets by
// reference and is void, so any result is produced in place in `kernel`. The bodies visible in
// this file are the explicit specializations for Packet16f and Packet8d.
// NOTE(review): presumably this transposes 8x8 sub-blocks of `kernel` - confirm against callers.
template <typename Packet>
|
|
72
|
+
EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet, 8> &kernel);
|
|
73
|
+
|
|
74
|
+
template <>
|
|
75
|
+
EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet16f, 8> &kernel) {
|
|
76
|
+
__m512 T0 = _mm512_unpacklo_ps(kernel.packet[0], kernel.packet[1]);
|
|
77
|
+
__m512 T1 = _mm512_unpackhi_ps(kernel.packet[0], kernel.packet[1]);
|
|
78
|
+
__m512 T2 = _mm512_unpacklo_ps(kernel.packet[2], kernel.packet[3]);
|
|
79
|
+
__m512 T3 = _mm512_unpackhi_ps(kernel.packet[2], kernel.packet[3]);
|
|
80
|
+
__m512 T4 = _mm512_unpacklo_ps(kernel.packet[4], kernel.packet[5]);
|
|
81
|
+
__m512 T5 = _mm512_unpackhi_ps(kernel.packet[4], kernel.packet[5]);
|
|
82
|
+
__m512 T6 = _mm512_unpacklo_ps(kernel.packet[6], kernel.packet[7]);
|
|
83
|
+
__m512 T7 = _mm512_unpackhi_ps(kernel.packet[6], kernel.packet[7]);
|
|
84
|
+
|
|
85
|
+
kernel.packet[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2)));
|
|
86
|
+
kernel.packet[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T0), _mm512_castps_pd(T2)));
|
|
87
|
+
kernel.packet[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3)));
|
|
88
|
+
kernel.packet[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T1), _mm512_castps_pd(T3)));
|
|
89
|
+
kernel.packet[4] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6)));
|
|
90
|
+
kernel.packet[5] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T4), _mm512_castps_pd(T6)));
|
|
91
|
+
kernel.packet[6] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7)));
|
|
92
|
+
kernel.packet[7] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(T5), _mm512_castps_pd(T7)));
|
|
93
|
+
|
|
94
|
+
T0 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[4]), 0x4E));
|
|
95
|
+
T0 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[0], T0);
|
|
96
|
+
T4 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[0]), 0x4E));
|
|
97
|
+
T4 = _mm512_mask_blend_ps(0xF0F0, T4, kernel.packet[4]);
|
|
98
|
+
T1 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[5]), 0x4E));
|
|
99
|
+
T1 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[1], T1);
|
|
100
|
+
T5 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[1]), 0x4E));
|
|
101
|
+
T5 = _mm512_mask_blend_ps(0xF0F0, T5, kernel.packet[5]);
|
|
102
|
+
T2 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[6]), 0x4E));
|
|
103
|
+
T2 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[2], T2);
|
|
104
|
+
T6 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[2]), 0x4E));
|
|
105
|
+
T6 = _mm512_mask_blend_ps(0xF0F0, T6, kernel.packet[6]);
|
|
106
|
+
T3 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[7]), 0x4E));
|
|
107
|
+
T3 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[3], T3);
|
|
108
|
+
T7 = _mm512_castpd_ps(_mm512_permutex_pd(_mm512_castps_pd(kernel.packet[3]), 0x4E));
|
|
109
|
+
T7 = _mm512_mask_blend_ps(0xF0F0, T7, kernel.packet[7]);
|
|
110
|
+
|
|
111
|
+
kernel.packet[0] = T0;
|
|
112
|
+
kernel.packet[1] = T1;
|
|
113
|
+
kernel.packet[2] = T2;
|
|
114
|
+
kernel.packet[3] = T3;
|
|
115
|
+
kernel.packet[4] = T4;
|
|
116
|
+
kernel.packet[5] = T5;
|
|
117
|
+
kernel.packet[6] = T6;
|
|
118
|
+
kernel.packet[7] = T7;
|
|
119
|
+
}
|
|
120
|
+
|
|
121
|
+
template <>
|
|
122
|
+
EIGEN_ALWAYS_INLINE void trans8x8blocks(PacketBlock<Packet8d, 8> &kernel) {
|
|
123
|
+
ptranspose(kernel);
|
|
124
|
+
}
|
|
125
|
+
|
|
126
|
+
/***
|
|
127
|
+
* Unrolls for transposed C stores
|
|
128
|
+
*/
|
|
129
|
+
template <typename Scalar>
|
|
130
|
+
class trans {
|
|
131
|
+
public:
|
|
132
|
+
using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
|
|
133
|
+
using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecFullDouble>::type;
|
|
134
|
+
static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
|
|
135
|
+
|
|
136
|
+
/***********************************
|
|
137
|
+
* Auxiliary Functions for:
|
|
138
|
+
* - storeC
|
|
139
|
+
***********************************
|
|
140
|
+
*/
|
|
141
|
+
|
|
142
|
+
/**
|
|
143
|
+
* aux_storeC
|
|
144
|
+
*
|
|
145
|
+
* 1-D unroll
|
|
146
|
+
* for(startN = 0; startN < endN; startN++)
|
|
147
|
+
*
|
|
148
|
+
* (endN <= PacketSize) is required to handle the fp32 case, see comments in transStoreC
|
|
149
|
+
*
|
|
150
|
+
**/
|
|
151
|
+
template <int64_t endN, int64_t counter, int64_t unrollN, int64_t packetIndexOffset, bool remM>
|
|
152
|
+
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && endN <= PacketSize)> aux_storeC(
|
|
153
|
+
Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t remM_ = 0) {
|
|
154
|
+
constexpr int64_t counterReverse = endN - counter;
|
|
155
|
+
constexpr int64_t startN = counterReverse;
|
|
156
|
+
|
|
157
|
+
EIGEN_IF_CONSTEXPR(startN < EIGEN_AVX_MAX_NUM_ROW) {
|
|
158
|
+
EIGEN_IF_CONSTEXPR(remM) {
|
|
159
|
+
pstoreu<Scalar>(
|
|
160
|
+
C_arr + LDC * startN,
|
|
161
|
+
padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN, remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
|
|
162
|
+
preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN / PacketSize) * startN]),
|
|
163
|
+
remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
|
|
164
|
+
remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
|
|
165
|
+
}
|
|
166
|
+
else {
|
|
167
|
+
pstoreu<Scalar>(C_arr + LDC * startN,
|
|
168
|
+
padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN),
|
|
169
|
+
preinterpret<vecHalf>(zmm.packet[packetIndexOffset + (unrollN / PacketSize) * startN])));
|
|
170
|
+
}
|
|
171
|
+
}
|
|
172
|
+
else { // This block is only needed for fp32 case
|
|
173
|
+
// Reinterpret as __m512 for _mm512_shuffle_f32x4
|
|
174
|
+
vecFullFloat zmm2vecFullFloat = preinterpret<vecFullFloat>(
|
|
175
|
+
zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)]);
|
|
176
|
+
// Swap lower and upper half of avx register.
|
|
177
|
+
zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)] =
|
|
178
|
+
preinterpret<vec>(_mm512_shuffle_f32x4(zmm2vecFullFloat, zmm2vecFullFloat, 0b01001110));
|
|
179
|
+
|
|
180
|
+
EIGEN_IF_CONSTEXPR(remM) {
|
|
181
|
+
pstoreu<Scalar>(
|
|
182
|
+
C_arr + LDC * startN,
|
|
183
|
+
padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN, remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_)),
|
|
184
|
+
preinterpret<vecHalf>(
|
|
185
|
+
zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)])),
|
|
186
|
+
remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
|
|
187
|
+
}
|
|
188
|
+
else {
|
|
189
|
+
pstoreu<Scalar>(
|
|
190
|
+
C_arr + LDC * startN,
|
|
191
|
+
padd(ploadu<vecHalf>((const Scalar *)C_arr + LDC * startN),
|
|
192
|
+
preinterpret<vecHalf>(
|
|
193
|
+
zmm.packet[packetIndexOffset + (unrollN / PacketSize) * (startN - EIGEN_AVX_MAX_NUM_ROW)])));
|
|
194
|
+
}
|
|
195
|
+
}
|
|
196
|
+
aux_storeC<endN, counter - 1, unrollN, packetIndexOffset, remM>(C_arr, LDC, zmm, remM_);
|
|
197
|
+
}
|
|
198
|
+
|
|
199
|
+
template <int64_t endN, int64_t counter, int64_t unrollN, int64_t packetIndexOffset, bool remM>
|
|
200
|
+
static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0 && endN <= PacketSize)> aux_storeC(
|
|
201
|
+
Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t remM_ = 0) {
|
|
202
|
+
EIGEN_UNUSED_VARIABLE(C_arr);
|
|
203
|
+
EIGEN_UNUSED_VARIABLE(LDC);
|
|
204
|
+
EIGEN_UNUSED_VARIABLE(zmm);
|
|
205
|
+
EIGEN_UNUSED_VARIABLE(remM_);
|
|
206
|
+
}
|
|
207
|
+
|
|
208
|
+
template <int64_t endN, int64_t unrollN, int64_t packetIndexOffset, bool remM>
|
|
209
|
+
static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC,
|
|
210
|
+
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
|
|
211
|
+
int64_t remM_ = 0) {
|
|
212
|
+
aux_storeC<endN, endN, unrollN, packetIndexOffset, remM>(C_arr, LDC, zmm, remM_);
|
|
213
|
+
}
|
|
214
|
+
|
|
215
|
+
/**
|
|
216
|
+
* Transposes LxunrollN row major block of matrices stored `EIGEN_AVX_MAX_NUM_ACC` zmm registers to
|
|
217
|
+
* "unrollN"xL ymm registers to be stored col-major into C.
|
|
218
|
+
*
|
|
219
|
+
* For 8x48, the 8x48 block (row-major) is stored in zmm as follows:
|
|
220
|
+
*
|
|
221
|
+
* ```
|
|
222
|
+
* row0: zmm0 zmm1 zmm2
|
|
223
|
+
* row1: zmm3 zmm4 zmm5
|
|
224
|
+
* .
|
|
225
|
+
* .
|
|
226
|
+
* row7: zmm21 zmm22 zmm23
|
|
227
|
+
*
|
|
228
|
+
* For 8x32, the 8x32 block (row-major) is stored in zmm as follows:
|
|
229
|
+
*
|
|
230
|
+
* row0: zmm0 zmm1
|
|
231
|
+
* row1: zmm2 zmm3
|
|
232
|
+
* .
|
|
233
|
+
* .
|
|
234
|
+
* row7: zmm14 zmm15
|
|
235
|
+
* ```
|
|
236
|
+
*
|
|
237
|
+
* In general we will have {1,2,3} groups of avx registers each of size
|
|
238
|
+
* `EIGEN_AVX_MAX_NUM_ROW`. packetIndexOffset is used to select which "block" of
|
|
239
|
+
* avx registers are being transposed.
|
|
240
|
+
*/
|
|
241
|
+
template <int64_t unrollN, int64_t packetIndexOffset>
|
|
242
|
+
static EIGEN_ALWAYS_INLINE void transpose(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
|
|
243
|
+
// Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted
|
|
244
|
+
// accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller.
|
|
245
|
+
constexpr int64_t zmmStride = unrollN / PacketSize;
|
|
246
|
+
PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> r;
|
|
247
|
+
r.packet[0] = zmm.packet[packetIndexOffset + zmmStride * 0];
|
|
248
|
+
r.packet[1] = zmm.packet[packetIndexOffset + zmmStride * 1];
|
|
249
|
+
r.packet[2] = zmm.packet[packetIndexOffset + zmmStride * 2];
|
|
250
|
+
r.packet[3] = zmm.packet[packetIndexOffset + zmmStride * 3];
|
|
251
|
+
r.packet[4] = zmm.packet[packetIndexOffset + zmmStride * 4];
|
|
252
|
+
r.packet[5] = zmm.packet[packetIndexOffset + zmmStride * 5];
|
|
253
|
+
r.packet[6] = zmm.packet[packetIndexOffset + zmmStride * 6];
|
|
254
|
+
r.packet[7] = zmm.packet[packetIndexOffset + zmmStride * 7];
|
|
255
|
+
trans8x8blocks(r);
|
|
256
|
+
zmm.packet[packetIndexOffset + zmmStride * 0] = r.packet[0];
|
|
257
|
+
zmm.packet[packetIndexOffset + zmmStride * 1] = r.packet[1];
|
|
258
|
+
zmm.packet[packetIndexOffset + zmmStride * 2] = r.packet[2];
|
|
259
|
+
zmm.packet[packetIndexOffset + zmmStride * 3] = r.packet[3];
|
|
260
|
+
zmm.packet[packetIndexOffset + zmmStride * 4] = r.packet[4];
|
|
261
|
+
zmm.packet[packetIndexOffset + zmmStride * 5] = r.packet[5];
|
|
262
|
+
zmm.packet[packetIndexOffset + zmmStride * 6] = r.packet[6];
|
|
263
|
+
zmm.packet[packetIndexOffset + zmmStride * 7] = r.packet[7];
|
|
264
|
+
}
|
|
265
|
+
};
|
|
266
|
+
|
|
267
|
+
/**
|
|
268
|
+
* Unrolls for copyBToRowMajor
|
|
269
|
+
*
|
|
270
|
+
* Idea:
|
|
271
|
+
* 1) Load a block of right-hand sides to registers (using loadB).
|
|
272
|
+
* 2) Convert the block from column-major to row-major (transposeLxL)
|
|
273
|
+
* 3) Store the blocks from register either to a temp array (toTemp == true), or back to B (toTemp == false).
|
|
274
|
+
*
|
|
275
|
+
* We use at most EIGEN_AVX_MAX_NUM_ACC avx registers to store the blocks of B. The remaining registers are
|
|
276
|
+
* used as temps for transposing.
|
|
277
|
+
*
|
|
278
|
+
* Blocks will be of size Lx{U1,U2,U3}. packetIndexOffset is used to index between these subblocks
|
|
279
|
+
* For fp32, PacketSize = 2*EIGEN_AVX_MAX_NUM_ROW, so we reinterpret packets as packets half the size (zmm -> ymm).
|
|
280
|
+
*/
|
|
281
|
+
template <typename Scalar>
|
|
282
|
+
class transB {
|
|
283
|
+
public:
|
|
284
|
+
using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
|
|
285
|
+
using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecFullDouble>::type;
|
|
286
|
+
static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
|
|
287
|
+
|
|
288
|
+
/***********************************
|
|
289
|
+
* Auxiliary Functions for:
|
|
290
|
+
* - loadB
|
|
291
|
+
* - storeB
|
|
292
|
+
* - loadBBlock
|
|
293
|
+
* - storeBBlock
|
|
294
|
+
***********************************
|
|
295
|
+
*/
|
|
296
|
+
|
|
297
|
+
/**
|
|
298
|
+
* aux_loadB
|
|
299
|
+
*
|
|
300
|
+
* 1-D unroll
|
|
301
|
+
* for(startN = 0; startN < endN; startN++)
|
|
302
|
+
**/
|
|
303
|
+
template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM, int64_t remN_>
|
|
304
|
+
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB(
|
|
305
|
+
Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
|
|
306
|
+
int64_t remM_ = 0) {
|
|
307
|
+
constexpr int64_t counterReverse = endN - counter;
|
|
308
|
+
constexpr int64_t startN = counterReverse;
|
|
309
|
+
|
|
310
|
+
EIGEN_IF_CONSTEXPR(remM) {
|
|
311
|
+
ymm.packet[packetIndexOffset + startN] =
|
|
312
|
+
ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB], remMask<EIGEN_AVX_MAX_NUM_ROW>(remM_));
|
|
313
|
+
}
|
|
314
|
+
else {
|
|
315
|
+
EIGEN_IF_CONSTEXPR(remN_ == 0) {
|
|
316
|
+
ymm.packet[packetIndexOffset + startN] = ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB]);
|
|
317
|
+
}
|
|
318
|
+
else ymm.packet[packetIndexOffset + startN] =
|
|
319
|
+
ploadu<vecHalf>((const Scalar *)&B_arr[startN * LDB], remMask<EIGEN_AVX_MAX_NUM_ROW>(remN_));
|
|
320
|
+
}
|
|
321
|
+
|
|
322
|
+
aux_loadB<endN, counter - 1, packetIndexOffset, remM, remN_>(B_arr, LDB, ymm, remM_);
|
|
323
|
+
}
|
|
324
|
+
|
|
325
|
+
template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remM, int64_t remN_>
|
|
326
|
+
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB(
|
|
327
|
+
Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
|
|
328
|
+
int64_t remM_ = 0) {
|
|
329
|
+
EIGEN_UNUSED_VARIABLE(B_arr);
|
|
330
|
+
EIGEN_UNUSED_VARIABLE(LDB);
|
|
331
|
+
EIGEN_UNUSED_VARIABLE(ymm);
|
|
332
|
+
EIGEN_UNUSED_VARIABLE(remM_);
|
|
333
|
+
}
|
|
334
|
+
|
|
335
|
+
/**
|
|
336
|
+
* aux_storeB
|
|
337
|
+
*
|
|
338
|
+
* 1-D unroll
|
|
339
|
+
* for(startN = 0; startN < endN; startN++)
|
|
340
|
+
**/
|
|
341
|
+
template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remK, bool remM>
|
|
342
|
+
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeB(
|
|
343
|
+
Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t rem_ = 0) {
|
|
344
|
+
constexpr int64_t counterReverse = endN - counter;
|
|
345
|
+
constexpr int64_t startN = counterReverse;
|
|
346
|
+
|
|
347
|
+
EIGEN_IF_CONSTEXPR(remK || remM) {
|
|
348
|
+
pstoreu<Scalar>(&B_arr[startN * LDB], ymm.packet[packetIndexOffset + startN],
|
|
349
|
+
remMask<EIGEN_AVX_MAX_NUM_ROW>(rem_));
|
|
350
|
+
}
|
|
351
|
+
else {
|
|
352
|
+
pstoreu<Scalar>(&B_arr[startN * LDB], ymm.packet[packetIndexOffset + startN]);
|
|
353
|
+
}
|
|
354
|
+
|
|
355
|
+
aux_storeB<endN, counter - 1, packetIndexOffset, remK, remM>(B_arr, LDB, ymm, rem_);
|
|
356
|
+
}
|
|
357
|
+
|
|
358
|
+
template <int64_t endN, int64_t counter, int64_t packetIndexOffset, bool remK, bool remM>
|
|
359
|
+
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeB(
|
|
360
|
+
Scalar *B_arr, int64_t LDB, PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t rem_ = 0) {
|
|
361
|
+
EIGEN_UNUSED_VARIABLE(B_arr);
|
|
362
|
+
EIGEN_UNUSED_VARIABLE(LDB);
|
|
363
|
+
EIGEN_UNUSED_VARIABLE(ymm);
|
|
364
|
+
EIGEN_UNUSED_VARIABLE(rem_);
|
|
365
|
+
}
|
|
366
|
+
|
|
367
|
+
/**
|
|
368
|
+
* aux_loadBBlock
|
|
369
|
+
*
|
|
370
|
+
* 1-D unroll
|
|
371
|
+
* for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW)
|
|
372
|
+
**/
|
|
373
|
+
template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remN_>
|
|
374
|
+
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadBBlock(
|
|
375
|
+
Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
|
|
376
|
+
PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
|
|
377
|
+
constexpr int64_t counterReverse = endN - counter;
|
|
378
|
+
constexpr int64_t startN = counterReverse;
|
|
379
|
+
transB::template loadB<EIGEN_AVX_MAX_NUM_ROW, startN, false, (toTemp ? 0 : remN_)>(&B_temp[startN], LDB_, ymm);
|
|
380
|
+
aux_loadBBlock<endN, counter - EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, remN_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
|
|
381
|
+
}
|
|
382
|
+
|
|
383
|
+
template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remN_>
|
|
384
|
+
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadBBlock(
|
|
385
|
+
Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
|
|
386
|
+
PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
|
|
387
|
+
EIGEN_UNUSED_VARIABLE(B_arr);
|
|
388
|
+
EIGEN_UNUSED_VARIABLE(LDB);
|
|
389
|
+
EIGEN_UNUSED_VARIABLE(B_temp);
|
|
390
|
+
EIGEN_UNUSED_VARIABLE(LDB_);
|
|
391
|
+
EIGEN_UNUSED_VARIABLE(ymm);
|
|
392
|
+
EIGEN_UNUSED_VARIABLE(remM_);
|
|
393
|
+
}
|
|
394
|
+
|
|
395
|
+
/**
|
|
396
|
+
* aux_storeBBlock
|
|
397
|
+
*
|
|
398
|
+
* 1-D unroll
|
|
399
|
+
* for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW)
|
|
400
|
+
**/
|
|
401
|
+
template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remK_>
|
|
402
|
+
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeBBlock(
|
|
403
|
+
Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
|
|
404
|
+
PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
|
|
405
|
+
constexpr int64_t counterReverse = endN - counter;
|
|
406
|
+
constexpr int64_t startN = counterReverse;
|
|
407
|
+
|
|
408
|
+
EIGEN_IF_CONSTEXPR(toTemp) {
|
|
409
|
+
transB::template storeB<EIGEN_AVX_MAX_NUM_ROW, startN, remK_ != 0, false>(&B_temp[startN], LDB_, ymm, remK_);
|
|
410
|
+
}
|
|
411
|
+
else {
|
|
412
|
+
transB::template storeB<std::min(EIGEN_AVX_MAX_NUM_ROW, endN), startN, false, remM>(&B_arr[0 + startN * LDB], LDB,
|
|
413
|
+
ymm, remM_);
|
|
414
|
+
}
|
|
415
|
+
aux_storeBBlock<endN, counter - EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, remK_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
|
|
416
|
+
}
|
|
417
|
+
|
|
418
|
+
template <int64_t endN, int64_t counter, bool toTemp, bool remM, int64_t remK_>
|
|
419
|
+
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeBBlock(
|
|
420
|
+
Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
|
|
421
|
+
PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm, int64_t remM_ = 0) {
|
|
422
|
+
EIGEN_UNUSED_VARIABLE(B_arr);
|
|
423
|
+
EIGEN_UNUSED_VARIABLE(LDB);
|
|
424
|
+
EIGEN_UNUSED_VARIABLE(B_temp);
|
|
425
|
+
EIGEN_UNUSED_VARIABLE(LDB_);
|
|
426
|
+
EIGEN_UNUSED_VARIABLE(ymm);
|
|
427
|
+
EIGEN_UNUSED_VARIABLE(remM_);
|
|
428
|
+
}
|
|
429
|
+
|
|
430
|
+
/********************************************************
|
|
431
|
+
* Wrappers for aux_XXXX to hide counter parameter
|
|
432
|
+
********************************************************/
|
|
433
|
+
|
|
434
|
+
template <int64_t endN, int64_t packetIndexOffset, bool remM, int64_t remN_>
|
|
435
|
+
static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_arr, int64_t LDB,
|
|
436
|
+
PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
|
|
437
|
+
int64_t remM_ = 0) {
|
|
438
|
+
aux_loadB<endN, endN, packetIndexOffset, remM, remN_>(B_arr, LDB, ymm, remM_);
|
|
439
|
+
}
|
|
440
|
+
|
|
441
|
+
template <int64_t endN, int64_t packetIndexOffset, bool remK, bool remM>
|
|
442
|
+
static EIGEN_ALWAYS_INLINE void storeB(Scalar *B_arr, int64_t LDB,
|
|
443
|
+
PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
|
|
444
|
+
int64_t rem_ = 0) {
|
|
445
|
+
aux_storeB<endN, endN, packetIndexOffset, remK, remM>(B_arr, LDB, ymm, rem_);
|
|
446
|
+
}
|
|
447
|
+
|
|
448
|
+
template <int64_t unrollN, bool toTemp, bool remM, int64_t remN_ = 0>
|
|
449
|
+
static EIGEN_ALWAYS_INLINE void loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
|
|
450
|
+
PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
|
|
451
|
+
int64_t remM_ = 0) {
|
|
452
|
+
EIGEN_IF_CONSTEXPR(toTemp) { transB::template loadB<unrollN, 0, remM, 0>(&B_arr[0], LDB, ymm, remM_); }
|
|
453
|
+
else {
|
|
454
|
+
aux_loadBBlock<unrollN, unrollN, toTemp, remM, remN_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
|
|
455
|
+
}
|
|
456
|
+
}
|
|
457
|
+
|
|
458
|
+
template <int64_t unrollN, bool toTemp, bool remM, int64_t remK_>
|
|
459
|
+
static EIGEN_ALWAYS_INLINE void storeBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
|
|
460
|
+
PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
|
|
461
|
+
int64_t remM_ = 0) {
|
|
462
|
+
aux_storeBBlock<unrollN, unrollN, toTemp, remM, remK_>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
|
|
463
|
+
}
|
|
464
|
+
|
|
465
|
+
template <int64_t packetIndexOffset>
|
|
466
|
+
static EIGEN_ALWAYS_INLINE void transposeLxL(PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm) {
|
|
467
|
+
// Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted
|
|
468
|
+
// accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller.
|
|
469
|
+
PacketBlock<vecHalf, EIGEN_AVX_MAX_NUM_ROW> r;
|
|
470
|
+
r.packet[0] = ymm.packet[packetIndexOffset + 0];
|
|
471
|
+
r.packet[1] = ymm.packet[packetIndexOffset + 1];
|
|
472
|
+
r.packet[2] = ymm.packet[packetIndexOffset + 2];
|
|
473
|
+
r.packet[3] = ymm.packet[packetIndexOffset + 3];
|
|
474
|
+
r.packet[4] = ymm.packet[packetIndexOffset + 4];
|
|
475
|
+
r.packet[5] = ymm.packet[packetIndexOffset + 5];
|
|
476
|
+
r.packet[6] = ymm.packet[packetIndexOffset + 6];
|
|
477
|
+
r.packet[7] = ymm.packet[packetIndexOffset + 7];
|
|
478
|
+
ptranspose(r);
|
|
479
|
+
ymm.packet[packetIndexOffset + 0] = r.packet[0];
|
|
480
|
+
ymm.packet[packetIndexOffset + 1] = r.packet[1];
|
|
481
|
+
ymm.packet[packetIndexOffset + 2] = r.packet[2];
|
|
482
|
+
ymm.packet[packetIndexOffset + 3] = r.packet[3];
|
|
483
|
+
ymm.packet[packetIndexOffset + 4] = r.packet[4];
|
|
484
|
+
ymm.packet[packetIndexOffset + 5] = r.packet[5];
|
|
485
|
+
ymm.packet[packetIndexOffset + 6] = r.packet[6];
|
|
486
|
+
ymm.packet[packetIndexOffset + 7] = r.packet[7];
|
|
487
|
+
}
|
|
488
|
+
|
|
489
|
+
template <int64_t unrollN, bool toTemp, bool remM>
|
|
490
|
+
static EIGEN_ALWAYS_INLINE void transB_kernel(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_,
|
|
491
|
+
PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &ymm,
|
|
492
|
+
int64_t remM_ = 0) {
|
|
493
|
+
constexpr int64_t U3 = PacketSize * 3;
|
|
494
|
+
constexpr int64_t U2 = PacketSize * 2;
|
|
495
|
+
constexpr int64_t U1 = PacketSize * 1;
|
|
496
|
+
/**
|
|
497
|
+
* Unrolls needed for each case:
|
|
498
|
+
* - AVX512 fp32 48 32 16 8 4 2 1
|
|
499
|
+
* - AVX512 fp64 24 16 8 4 2 1
|
|
500
|
+
*
|
|
501
|
+
* For fp32 L and U1 are 1:2 so for U3/U2 cases the loads/stores need to be split up.
|
|
502
|
+
*/
|
|
503
|
+
EIGEN_IF_CONSTEXPR(unrollN == U3) {
|
|
504
|
+
// load LxU3 B col major, transpose LxU3 row major
|
|
505
|
+
constexpr int64_t maxUBlock = std::min(3 * EIGEN_AVX_MAX_NUM_ROW, U3);
|
|
506
|
+
transB::template loadBBlock<maxUBlock, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
|
|
507
|
+
transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
|
|
508
|
+
transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
|
|
509
|
+
transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
|
|
510
|
+
transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
|
|
511
|
+
|
|
512
|
+
EIGEN_IF_CONSTEXPR(maxUBlock < U3) {
|
|
513
|
+
transB::template loadBBlock<maxUBlock, toTemp, remM>(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_,
|
|
514
|
+
ymm, remM_);
|
|
515
|
+
transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
|
|
516
|
+
transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
|
|
517
|
+
transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
|
|
518
|
+
transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(&B_arr[maxUBlock * LDB], LDB, &B_temp[maxUBlock], LDB_,
|
|
519
|
+
ymm, remM_);
|
|
520
|
+
}
|
|
521
|
+
}
|
|
522
|
+
else EIGEN_IF_CONSTEXPR(unrollN == U2) {
|
|
523
|
+
// load LxU2 B col major, transpose LxU2 row major
|
|
524
|
+
constexpr int64_t maxUBlock = std::min(3 * EIGEN_AVX_MAX_NUM_ROW, U2);
|
|
525
|
+
transB::template loadBBlock<maxUBlock, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
|
|
526
|
+
transB::template transposeLxL<0 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
|
|
527
|
+
transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
|
|
528
|
+
EIGEN_IF_CONSTEXPR(maxUBlock < U2) transB::template transposeLxL<2 * EIGEN_AVX_MAX_NUM_ROW>(ymm);
|
|
529
|
+
transB::template storeBBlock<maxUBlock, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
|
|
530
|
+
|
|
531
|
+
EIGEN_IF_CONSTEXPR(maxUBlock < U2) {
|
|
532
|
+
transB::template loadBBlock<EIGEN_AVX_MAX_NUM_ROW, toTemp, remM>(&B_arr[maxUBlock * LDB], LDB,
|
|
533
|
+
&B_temp[maxUBlock], LDB_, ymm, remM_);
|
|
534
|
+
transB::template transposeLxL<0>(ymm);
|
|
535
|
+
transB::template storeBBlock<EIGEN_AVX_MAX_NUM_ROW, toTemp, remM, 0>(&B_arr[maxUBlock * LDB], LDB,
|
|
536
|
+
&B_temp[maxUBlock], LDB_, ymm, remM_);
|
|
537
|
+
}
|
|
538
|
+
}
|
|
539
|
+
else EIGEN_IF_CONSTEXPR(unrollN == U1) {
|
|
540
|
+
// load LxU1 B col major, transpose LxU1 row major
|
|
541
|
+
transB::template loadBBlock<U1, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
|
|
542
|
+
transB::template transposeLxL<0>(ymm);
|
|
543
|
+
EIGEN_IF_CONSTEXPR(EIGEN_AVX_MAX_NUM_ROW < U1) { transB::template transposeLxL<1 * EIGEN_AVX_MAX_NUM_ROW>(ymm); }
|
|
544
|
+
transB::template storeBBlock<U1, toTemp, remM, 0>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
|
|
545
|
+
}
|
|
546
|
+
else EIGEN_IF_CONSTEXPR(unrollN == 8 && U1 > 8) {
|
|
547
|
+
// load Lx4 B col major, transpose Lx4 row major
|
|
548
|
+
transB::template loadBBlock<8, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
|
|
549
|
+
transB::template transposeLxL<0>(ymm);
|
|
550
|
+
transB::template storeBBlock<8, toTemp, remM, 8>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
|
|
551
|
+
}
|
|
552
|
+
else EIGEN_IF_CONSTEXPR(unrollN == 4 && U1 > 4) {
|
|
553
|
+
// load Lx4 B col major, transpose Lx4 row major
|
|
554
|
+
transB::template loadBBlock<4, toTemp, remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
|
|
555
|
+
transB::template transposeLxL<0>(ymm);
|
|
556
|
+
transB::template storeBBlock<4, toTemp, remM, 4>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
|
|
557
|
+
}
|
|
558
|
+
else EIGEN_IF_CONSTEXPR(unrollN == 2) {
|
|
559
|
+
// load Lx2 B col major, transpose Lx2 row major
|
|
560
|
+
transB::template loadBBlock<2, toTemp, remM, 2>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
|
|
561
|
+
transB::template transposeLxL<0>(ymm);
|
|
562
|
+
transB::template storeBBlock<2, toTemp, remM, 2>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
|
|
563
|
+
}
|
|
564
|
+
else EIGEN_IF_CONSTEXPR(unrollN == 1) {
|
|
565
|
+
// load Lx1 B col major, transpose Lx1 row major
|
|
566
|
+
transB::template loadBBlock<1, toTemp, remM, 1>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
|
|
567
|
+
transB::template transposeLxL<0>(ymm);
|
|
568
|
+
transB::template storeBBlock<1, toTemp, remM, 1>(B_arr, LDB, B_temp, LDB_, ymm, remM_);
|
|
569
|
+
}
|
|
570
|
+
}
|
|
571
|
+
};
|
|
572
|
+
|
|
573
|
+
/**
|
|
574
|
+
* Unrolls for triSolveKernel
|
|
575
|
+
*
|
|
576
|
+
* Idea:
|
|
577
|
+
* 1) Load a block of right-hand sides to registers in RHSInPacket (using loadRHS).
|
|
578
|
+
* 2) Do triangular solve with RHSInPacket and a small block of A (triangular matrix)
|
|
579
|
+
* stored in AInPacket (using triSolveMicroKernel).
|
|
580
|
+
* 3) Store final results (in avx registers) back into memory (using storeRHS).
|
|
581
|
+
*
|
|
582
|
+
* RHSInPacket uses at most EIGEN_AVX_MAX_NUM_ACC avx registers and AInPacket uses at most
|
|
583
|
+
* EIGEN_AVX_MAX_NUM_ROW registers.
|
|
584
|
+
*/
|
|
585
|
+
template <typename Scalar>
|
|
586
|
+
class trsm {
|
|
587
|
+
public:
|
|
588
|
+
using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
|
|
589
|
+
static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
|
|
590
|
+
|
|
591
|
+
/***********************************
|
|
592
|
+
* Auxiliary Functions for:
|
|
593
|
+
* - loadRHS
|
|
594
|
+
* - storeRHS
|
|
595
|
+
* - divRHSByDiag
|
|
596
|
+
* - updateRHS
|
|
597
|
+
* - triSolveMicroKernel
|
|
598
|
+
************************************/
|
|
599
|
+
/**
|
|
600
|
+
* aux_loadRHS
|
|
601
|
+
*
|
|
602
|
+
* 2-D unroll
|
|
603
|
+
* for(startM = 0; startM < endM; startM++)
|
|
604
|
+
* for(startK = 0; startK < endK; startK++)
|
|
605
|
+
**/
|
|
606
|
+
template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
|
|
607
|
+
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadRHS(
|
|
608
|
+
Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
|
|
609
|
+
constexpr int64_t counterReverse = endM * endK - counter;
|
|
610
|
+
constexpr int64_t startM = counterReverse / (endK);
|
|
611
|
+
constexpr int64_t startK = counterReverse % endK;
|
|
612
|
+
|
|
613
|
+
constexpr int64_t packetIndex = startM * endK + startK;
|
|
614
|
+
constexpr int64_t startM_ = isFWDSolve ? startM : -startM;
|
|
615
|
+
const int64_t rhsIndex = (startK * PacketSize) + startM_ * LDB;
|
|
616
|
+
EIGEN_IF_CONSTEXPR(krem) {
|
|
617
|
+
RHSInPacket.packet[packetIndex] = ploadu<vec>(&B_arr[rhsIndex], remMask<PacketSize>(rem));
|
|
618
|
+
}
|
|
619
|
+
else {
|
|
620
|
+
RHSInPacket.packet[packetIndex] = ploadu<vec>(&B_arr[rhsIndex]);
|
|
621
|
+
}
|
|
622
|
+
aux_loadRHS<isFWDSolve, endM, endK, counter - 1, krem>(B_arr, LDB, RHSInPacket, rem);
|
|
623
|
+
}
|
|
624
|
+
|
|
625
|
+
template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
|
|
626
|
+
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadRHS(
|
|
627
|
+
Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
|
|
628
|
+
EIGEN_UNUSED_VARIABLE(B_arr);
|
|
629
|
+
EIGEN_UNUSED_VARIABLE(LDB);
|
|
630
|
+
EIGEN_UNUSED_VARIABLE(RHSInPacket);
|
|
631
|
+
EIGEN_UNUSED_VARIABLE(rem);
|
|
632
|
+
}
|
|
633
|
+
|
|
634
|
+
/**
|
|
635
|
+
* aux_storeRHS
|
|
636
|
+
*
|
|
637
|
+
* 2-D unroll
|
|
638
|
+
* for(startM = 0; startM < endM; startM++)
|
|
639
|
+
* for(startK = 0; startK < endK; startK++)
|
|
640
|
+
**/
|
|
641
|
+
template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
|
|
642
|
+
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeRHS(
|
|
643
|
+
Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
|
|
644
|
+
constexpr int64_t counterReverse = endM * endK - counter;
|
|
645
|
+
constexpr int64_t startM = counterReverse / (endK);
|
|
646
|
+
constexpr int64_t startK = counterReverse % endK;
|
|
647
|
+
|
|
648
|
+
constexpr int64_t packetIndex = startM * endK + startK;
|
|
649
|
+
constexpr int64_t startM_ = isFWDSolve ? startM : -startM;
|
|
650
|
+
const int64_t rhsIndex = (startK * PacketSize) + startM_ * LDB;
|
|
651
|
+
EIGEN_IF_CONSTEXPR(krem) {
|
|
652
|
+
pstoreu<Scalar>(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex], remMask<PacketSize>(rem));
|
|
653
|
+
}
|
|
654
|
+
else {
|
|
655
|
+
pstoreu<Scalar>(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex]);
|
|
656
|
+
}
|
|
657
|
+
aux_storeRHS<isFWDSolve, endM, endK, counter - 1, krem>(B_arr, LDB, RHSInPacket, rem);
|
|
658
|
+
}
|
|
659
|
+
|
|
660
|
+
template <bool isFWDSolve, int64_t endM, int64_t endK, int64_t counter, bool krem>
|
|
661
|
+
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeRHS(
|
|
662
|
+
Scalar *B_arr, int64_t LDB, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
|
|
663
|
+
EIGEN_UNUSED_VARIABLE(B_arr);
|
|
664
|
+
EIGEN_UNUSED_VARIABLE(LDB);
|
|
665
|
+
EIGEN_UNUSED_VARIABLE(RHSInPacket);
|
|
666
|
+
EIGEN_UNUSED_VARIABLE(rem);
|
|
667
|
+
}
|
|
668
|
+
|
|
669
|
+
/**
|
|
670
|
+
* aux_divRHSByDiag
|
|
671
|
+
*
|
|
672
|
+
* currM may be -1, (currM >=0) in enable_if checks for this
|
|
673
|
+
*
|
|
674
|
+
* 1-D unroll
|
|
675
|
+
* for(startK = 0; startK < endK; startK++)
|
|
676
|
+
**/
|
|
677
|
+
template <int64_t currM, int64_t endK, int64_t counter>
|
|
678
|
+
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && currM >= 0)> aux_divRHSByDiag(
|
|
679
|
+
PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
|
|
680
|
+
constexpr int64_t counterReverse = endK - counter;
|
|
681
|
+
constexpr int64_t startK = counterReverse;
|
|
682
|
+
|
|
683
|
+
constexpr int64_t packetIndex = currM * endK + startK;
|
|
684
|
+
RHSInPacket.packet[packetIndex] = pmul(AInPacket.packet[currM], RHSInPacket.packet[packetIndex]);
|
|
685
|
+
aux_divRHSByDiag<currM, endK, counter - 1>(RHSInPacket, AInPacket);
|
|
686
|
+
}
|
|
687
|
+
|
|
688
|
+
template <int64_t currM, int64_t endK, int64_t counter>
|
|
689
|
+
static EIGEN_ALWAYS_INLINE std::enable_if_t<!(counter > 0 && currM >= 0)> aux_divRHSByDiag(
|
|
690
|
+
PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
|
|
691
|
+
EIGEN_UNUSED_VARIABLE(RHSInPacket);
|
|
692
|
+
EIGEN_UNUSED_VARIABLE(AInPacket);
|
|
693
|
+
}
|
|
694
|
+
|
|
695
|
+
/**
|
|
696
|
+
* aux_updateRHS
|
|
697
|
+
*
|
|
698
|
+
* 2-D unroll
|
|
699
|
+
* for(startM = initM; startM < endM; startM++)
|
|
700
|
+
* for(startK = 0; startK < endK; startK++)
|
|
701
|
+
**/
|
|
702
|
+
template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t initM, int64_t endM, int64_t endK,
|
|
703
|
+
int64_t counter, int64_t currentM>
|
|
704
|
+
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateRHS(
|
|
705
|
+
Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
|
|
706
|
+
PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
|
|
707
|
+
constexpr int64_t counterReverse = (endM - initM) * endK - counter;
|
|
708
|
+
constexpr int64_t startM = initM + counterReverse / (endK);
|
|
709
|
+
constexpr int64_t startK = counterReverse % endK;
|
|
710
|
+
|
|
711
|
+
// For each row of A, first update all corresponding RHS
|
|
712
|
+
constexpr int64_t packetIndex = startM * endK + startK;
|
|
713
|
+
EIGEN_IF_CONSTEXPR(currentM > 0) {
|
|
714
|
+
RHSInPacket.packet[packetIndex] =
|
|
715
|
+
pnmadd(AInPacket.packet[startM], RHSInPacket.packet[(currentM - 1) * endK + startK],
|
|
716
|
+
RHSInPacket.packet[packetIndex]);
|
|
717
|
+
}
|
|
718
|
+
|
|
719
|
+
EIGEN_IF_CONSTEXPR(startK == endK - 1) {
|
|
720
|
+
// Once all RHS for previous row of A is updated, we broadcast the next element in the column A_{i, currentM}.
|
|
721
|
+
EIGEN_IF_CONSTEXPR(startM == currentM && !isUnitDiag) {
|
|
722
|
+
// If diagonal is not unit, we broadcast reciprocals of diagonals AinPacket.packet[currentM].
|
|
723
|
+
// This will be used in divRHSByDiag
|
|
724
|
+
EIGEN_IF_CONSTEXPR(isFWDSolve)
|
|
725
|
+
AInPacket.packet[currentM] = pset1<vec>(Scalar(1) / A_arr[idA<isARowMajor>(currentM, currentM, LDA)]);
|
|
726
|
+
else AInPacket.packet[currentM] = pset1<vec>(Scalar(1) / A_arr[idA<isARowMajor>(-currentM, -currentM, LDA)]);
|
|
727
|
+
}
|
|
728
|
+
else {
|
|
729
|
+
// Broadcast next off diagonal element of A
|
|
730
|
+
EIGEN_IF_CONSTEXPR(isFWDSolve)
|
|
731
|
+
AInPacket.packet[startM] = pset1<vec>(A_arr[idA<isARowMajor>(startM, currentM, LDA)]);
|
|
732
|
+
else AInPacket.packet[startM] = pset1<vec>(A_arr[idA<isARowMajor>(-startM, -currentM, LDA)]);
|
|
733
|
+
}
|
|
734
|
+
}
|
|
735
|
+
|
|
736
|
+
aux_updateRHS<isARowMajor, isFWDSolve, isUnitDiag, initM, endM, endK, counter - 1, currentM>(
|
|
737
|
+
A_arr, LDA, RHSInPacket, AInPacket);
|
|
738
|
+
}
|
|
739
|
+
|
|
740
|
+
template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t initM, int64_t endM, int64_t endK,
|
|
741
|
+
int64_t counter, int64_t currentM>
|
|
742
|
+
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateRHS(
|
|
743
|
+
Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
|
|
744
|
+
PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
|
|
745
|
+
EIGEN_UNUSED_VARIABLE(A_arr);
|
|
746
|
+
EIGEN_UNUSED_VARIABLE(LDA);
|
|
747
|
+
EIGEN_UNUSED_VARIABLE(RHSInPacket);
|
|
748
|
+
EIGEN_UNUSED_VARIABLE(AInPacket);
|
|
749
|
+
}
|
|
750
|
+
|
|
751
|
+
/**
|
|
752
|
+
* aux_triSolveMicroKernel
|
|
753
|
+
*
|
|
754
|
+
* 1-D unroll
|
|
755
|
+
* for(startM = 0; startM < endM; startM++)
|
|
756
|
+
**/
|
|
757
|
+
template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t counter, int64_t numK>
|
|
758
|
+
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_triSolveMicroKernel(
|
|
759
|
+
Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
|
|
760
|
+
PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
|
|
761
|
+
constexpr int64_t counterReverse = endM - counter;
|
|
762
|
+
constexpr int64_t startM = counterReverse;
|
|
763
|
+
|
|
764
|
+
constexpr int64_t currentM = startM;
|
|
765
|
+
// Divides the right-hand side in row startM by the diagonal value of A
|
|
766
|
+
// broadcasted to AInPacket.packet[startM-1] in the previous iteration.
|
|
767
|
+
//
|
|
768
|
+
// Without "if constexpr" the compiler instantiates the case <-1, numK>
|
|
769
|
+
// this is handled with enable_if to prevent out-of-bound warnings
|
|
770
|
+
// from the compiler
|
|
771
|
+
EIGEN_IF_CONSTEXPR(!isUnitDiag && startM > 0)
|
|
772
|
+
trsm::template divRHSByDiag<startM - 1, numK>(RHSInPacket, AInPacket);
|
|
773
|
+
|
|
774
|
+
// After division, the rhs corresponding to subsequent rows of A can be partially updated
|
|
775
|
+
// We also broadcast the reciprocal of the next diagonal to AInPacket.packet[currentM] (if needed)
|
|
776
|
+
// to be used in the next iteration.
|
|
777
|
+
trsm::template updateRHS<isARowMajor, isFWDSolve, isUnitDiag, startM, endM, numK, currentM>(A_arr, LDA, RHSInPacket,
|
|
778
|
+
AInPacket);
|
|
779
|
+
|
|
780
|
+
// Handle division for the RHS corresponding to the final row of A.
|
|
781
|
+
EIGEN_IF_CONSTEXPR(!isUnitDiag && startM == endM - 1)
|
|
782
|
+
trsm::template divRHSByDiag<startM, numK>(RHSInPacket, AInPacket);
|
|
783
|
+
|
|
784
|
+
aux_triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, endM, counter - 1, numK>(A_arr, LDA, RHSInPacket,
|
|
785
|
+
AInPacket);
|
|
786
|
+
}
|
|
787
|
+
|
|
788
|
+
template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t counter, int64_t numK>
|
|
789
|
+
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_triSolveMicroKernel(
|
|
790
|
+
Scalar *A_arr, int64_t LDA, PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
|
|
791
|
+
PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
|
|
792
|
+
EIGEN_UNUSED_VARIABLE(A_arr);
|
|
793
|
+
EIGEN_UNUSED_VARIABLE(LDA);
|
|
794
|
+
EIGEN_UNUSED_VARIABLE(RHSInPacket);
|
|
795
|
+
EIGEN_UNUSED_VARIABLE(AInPacket);
|
|
796
|
+
}
|
|
797
|
+
|
|
798
|
+
/********************************************************
|
|
799
|
+
* Wrappers for aux_XXXX to hide counter parameter
|
|
800
|
+
********************************************************/
|
|
801
|
+
|
|
802
|
+
/**
|
|
803
|
+
* Load endMxendK block of B to RHSInPacket
|
|
804
|
+
* Masked loads are used for cases where endK is not a multiple of PacketSize
|
|
805
|
+
*/
|
|
806
|
+
template <bool isFWDSolve, int64_t endM, int64_t endK, bool krem = false>
|
|
807
|
+
static EIGEN_ALWAYS_INLINE void loadRHS(Scalar *B_arr, int64_t LDB,
|
|
808
|
+
PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
|
|
809
|
+
aux_loadRHS<isFWDSolve, endM, endK, endM * endK, krem>(B_arr, LDB, RHSInPacket, rem);
|
|
810
|
+
}
|
|
811
|
+
|
|
812
|
+
/**
|
|
813
|
+
* Store endMxendK block of RHSInPacket to B
|
|
814
|
+
* Masked stores are used for cases where endK is not a multiple of PacketSize
|
|
815
|
+
*/
|
|
816
|
+
template <bool isFWDSolve, int64_t endM, int64_t endK, bool krem = false>
|
|
817
|
+
static EIGEN_ALWAYS_INLINE void storeRHS(Scalar *B_arr, int64_t LDB,
|
|
818
|
+
PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket, int64_t rem = 0) {
|
|
819
|
+
aux_storeRHS<isFWDSolve, endM, endK, endM * endK, krem>(B_arr, LDB, RHSInPacket, rem);
|
|
820
|
+
}
|
|
821
|
+
|
|
822
|
+
/**
|
|
823
|
+
* Only used if Triangular matrix has non-unit diagonal values
|
|
824
|
+
*/
|
|
825
|
+
template <int64_t currM, int64_t endK>
|
|
826
|
+
static EIGEN_ALWAYS_INLINE void divRHSByDiag(PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
|
|
827
|
+
PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
|
|
828
|
+
aux_divRHSByDiag<currM, endK, endK>(RHSInPacket, AInPacket);
|
|
829
|
+
}
|
|
830
|
+
|
|
831
|
+
/**
|
|
832
|
+
* Update right-hand sides (stored in avx registers)
|
|
833
|
+
* Traversing along the column A_{i,currentM}, where currentM <= i < endM, and broadcasting each value to AInPacket.
|
|
834
|
+
**/
|
|
835
|
+
template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t startM, int64_t endM, int64_t endK,
|
|
836
|
+
int64_t currentM>
|
|
837
|
+
static EIGEN_ALWAYS_INLINE void updateRHS(Scalar *A_arr, int64_t LDA,
|
|
838
|
+
PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
|
|
839
|
+
PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
|
|
840
|
+
aux_updateRHS<isARowMajor, isFWDSolve, isUnitDiag, startM, endM, endK, (endM - startM) * endK, currentM>(
|
|
841
|
+
A_arr, LDA, RHSInPacket, AInPacket);
|
|
842
|
+
}
|
|
843
|
+
|
|
844
|
+
/**
|
|
845
|
+
* endM: dimension of A. 1 <= endM <= EIGEN_AVX_MAX_NUM_ROW
|
|
846
|
+
* numK: number of avx registers to use for each row of B (ex fp32: 48 rhs => 3 avx reg used). 1 <= numK <= 3.
|
|
847
|
+
* isFWDSolve: true => forward substitution, false => backwards substitution
|
|
848
|
+
* isUnitDiag: true => triangular matrix has unit diagonal.
|
|
849
|
+
*/
|
|
850
|
+
template <bool isARowMajor, bool isFWDSolve, bool isUnitDiag, int64_t endM, int64_t numK>
|
|
851
|
+
static EIGEN_ALWAYS_INLINE void triSolveMicroKernel(Scalar *A_arr, int64_t LDA,
|
|
852
|
+
PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> &RHSInPacket,
|
|
853
|
+
PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> &AInPacket) {
|
|
854
|
+
static_assert(numK >= 1 && numK <= 3, "numK out of range");
|
|
855
|
+
aux_triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, endM, endM, numK>(A_arr, LDA, RHSInPacket, AInPacket);
|
|
856
|
+
}
|
|
857
|
+
};
|
|
858
|
+
|
|
859
|
+
/**
|
|
860
|
+
* Unrolls for gemm kernel
|
|
861
|
+
*
|
|
862
|
+
* isAdd: true => C += A*B, false => C -= A*B
|
|
863
|
+
*/
|
|
864
|
+
template <typename Scalar, bool isAdd>
|
|
865
|
+
class gemm {
|
|
866
|
+
public:
|
|
867
|
+
using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
|
|
868
|
+
static constexpr int64_t PacketSize = packet_traits<Scalar>::size;
|
|
869
|
+
|
|
870
|
+
/***********************************
|
|
871
|
+
* Auxiliary Functions for:
|
|
872
|
+
* - setzero
|
|
873
|
+
* - updateC
|
|
874
|
+
* - storeC
|
|
875
|
+
* - startLoadB
|
|
876
|
+
* - startBCastA
* - loadB
* - microKernel
|
|
877
|
+
************************************/
|
|
878
|
+
|
|
879
|
+
/**
|
|
880
|
+
* aux_setzero
|
|
881
|
+
*
|
|
882
|
+
* 2-D unroll
|
|
883
|
+
* for(startM = 0; startM < endM; startM++)
|
|
884
|
+
* for(startN = 0; startN < endN; startN++)
|
|
885
|
+
**/
|
|
886
|
+
template <int64_t endM, int64_t endN, int64_t counter>
|
|
887
|
+
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_setzero(
|
|
888
|
+
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
|
|
889
|
+
constexpr int64_t counterReverse = endM * endN - counter;
|
|
890
|
+
constexpr int64_t startM = counterReverse / (endN);
|
|
891
|
+
constexpr int64_t startN = counterReverse % endN;
|
|
892
|
+
|
|
893
|
+
zmm.packet[startN * endM + startM] = pzero(zmm.packet[startN * endM + startM]);
|
|
894
|
+
aux_setzero<endM, endN, counter - 1>(zmm);
|
|
895
|
+
}
|
|
896
|
+
|
|
897
|
+
template <int64_t endM, int64_t endN, int64_t counter>
|
|
898
|
+
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_setzero(
|
|
899
|
+
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
|
|
900
|
+
EIGEN_UNUSED_VARIABLE(zmm);
|
|
901
|
+
}
|
|
902
|
+
|
|
903
|
+
/**
|
|
904
|
+
* aux_updateC
|
|
905
|
+
*
|
|
906
|
+
* 2-D unroll
|
|
907
|
+
* for(startM = 0; startM < endM; startM++)
|
|
908
|
+
* for(startN = 0; startN < endN; startN++)
|
|
909
|
+
**/
|
|
910
|
+
template <int64_t endM, int64_t endN, int64_t counter, bool rem>
|
|
911
|
+
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_updateC(
|
|
912
|
+
Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
|
|
913
|
+
EIGEN_UNUSED_VARIABLE(rem_);
|
|
914
|
+
constexpr int64_t counterReverse = endM * endN - counter;
|
|
915
|
+
constexpr int64_t startM = counterReverse / (endN);
|
|
916
|
+
constexpr int64_t startN = counterReverse % endN;
|
|
917
|
+
|
|
918
|
+
EIGEN_IF_CONSTEXPR(rem)
|
|
919
|
+
zmm.packet[startN * endM + startM] =
|
|
920
|
+
padd(ploadu<vec>(&C_arr[(startN)*LDC + startM * PacketSize], remMask<PacketSize>(rem_)),
|
|
921
|
+
zmm.packet[startN * endM + startM], remMask<PacketSize>(rem_));
|
|
922
|
+
else zmm.packet[startN * endM + startM] =
|
|
923
|
+
padd(ploadu<vec>(&C_arr[(startN)*LDC + startM * PacketSize]), zmm.packet[startN * endM + startM]);
|
|
924
|
+
aux_updateC<endM, endN, counter - 1, rem>(C_arr, LDC, zmm, rem_);
|
|
925
|
+
}
|
|
926
|
+
|
|
927
|
+
template <int64_t endM, int64_t endN, int64_t counter, bool rem>
|
|
928
|
+
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_updateC(
|
|
929
|
+
Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
|
|
930
|
+
EIGEN_UNUSED_VARIABLE(C_arr);
|
|
931
|
+
EIGEN_UNUSED_VARIABLE(LDC);
|
|
932
|
+
EIGEN_UNUSED_VARIABLE(zmm);
|
|
933
|
+
EIGEN_UNUSED_VARIABLE(rem_);
|
|
934
|
+
}
|
|
935
|
+
|
|
936
|
+
/**
|
|
937
|
+
* aux_storeC
|
|
938
|
+
*
|
|
939
|
+
* 2-D unroll
|
|
940
|
+
* for(startM = 0; startM < endM; startM++)
|
|
941
|
+
* for(startN = 0; startN < endN; startN++)
|
|
942
|
+
**/
|
|
943
|
+
template <int64_t endM, int64_t endN, int64_t counter, bool rem>
|
|
944
|
+
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_storeC(
|
|
945
|
+
Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
|
|
946
|
+
EIGEN_UNUSED_VARIABLE(rem_);
|
|
947
|
+
constexpr int64_t counterReverse = endM * endN - counter;
|
|
948
|
+
constexpr int64_t startM = counterReverse / (endN);
|
|
949
|
+
constexpr int64_t startN = counterReverse % endN;
|
|
950
|
+
|
|
951
|
+
EIGEN_IF_CONSTEXPR(rem)
|
|
952
|
+
pstoreu<Scalar>(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM],
|
|
953
|
+
remMask<PacketSize>(rem_));
|
|
954
|
+
else pstoreu<Scalar>(&C_arr[(startN)*LDC + startM * PacketSize], zmm.packet[startN * endM + startM]);
|
|
955
|
+
aux_storeC<endM, endN, counter - 1, rem>(C_arr, LDC, zmm, rem_);
|
|
956
|
+
}
|
|
957
|
+
|
|
958
|
+
template <int64_t endM, int64_t endN, int64_t counter, bool rem>
|
|
959
|
+
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_storeC(
|
|
960
|
+
Scalar *C_arr, int64_t LDC, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
|
|
961
|
+
EIGEN_UNUSED_VARIABLE(C_arr);
|
|
962
|
+
EIGEN_UNUSED_VARIABLE(LDC);
|
|
963
|
+
EIGEN_UNUSED_VARIABLE(zmm);
|
|
964
|
+
EIGEN_UNUSED_VARIABLE(rem_);
|
|
965
|
+
}
|
|
966
|
+
|
|
967
|
+
/**
|
|
968
|
+
* aux_startLoadB
|
|
969
|
+
*
|
|
970
|
+
* 1-D unroll
|
|
971
|
+
* for(startL = 0; startL < endL; startL++)
|
|
972
|
+
**/
|
|
973
|
+
template <int64_t unrollM, int64_t unrollN, int64_t endL, int64_t counter, bool rem>
|
|
974
|
+
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startLoadB(
|
|
975
|
+
Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
|
|
976
|
+
EIGEN_UNUSED_VARIABLE(rem_);
|
|
977
|
+
constexpr int64_t counterReverse = endL - counter;
|
|
978
|
+
constexpr int64_t startL = counterReverse;
|
|
979
|
+
|
|
980
|
+
EIGEN_IF_CONSTEXPR(rem)
|
|
981
|
+
zmm.packet[unrollM * unrollN + startL] =
|
|
982
|
+
ploadu<vec>(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize], remMask<PacketSize>(rem_));
|
|
983
|
+
else zmm.packet[unrollM * unrollN + startL] =
|
|
984
|
+
ploadu<vec>(&B_t[(startL / unrollM) * LDB + (startL % unrollM) * PacketSize]);
|
|
985
|
+
|
|
986
|
+
aux_startLoadB<unrollM, unrollN, endL, counter - 1, rem>(B_t, LDB, zmm, rem_);
|
|
987
|
+
}
|
|
988
|
+
|
|
989
|
+
template <int64_t unrollM, int64_t unrollN, int64_t endL, int64_t counter, bool rem>
|
|
990
|
+
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startLoadB(
|
|
991
|
+
Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
|
|
992
|
+
EIGEN_UNUSED_VARIABLE(B_t);
|
|
993
|
+
EIGEN_UNUSED_VARIABLE(LDB);
|
|
994
|
+
EIGEN_UNUSED_VARIABLE(zmm);
|
|
995
|
+
EIGEN_UNUSED_VARIABLE(rem_);
|
|
996
|
+
}
|
|
997
|
+
|
|
998
|
+
/**
|
|
999
|
+
* aux_startBCastA
|
|
1000
|
+
*
|
|
1001
|
+
* 1-D unroll
|
|
1002
|
+
* for(startB = 0; startB < endB; startB++)
|
|
1003
|
+
**/
|
|
1004
|
+
template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t counter, int64_t numLoad>
|
|
1005
|
+
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_startBCastA(
|
|
1006
|
+
Scalar *A_t, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
|
|
1007
|
+
constexpr int64_t counterReverse = endB - counter;
|
|
1008
|
+
constexpr int64_t startB = counterReverse;
|
|
1009
|
+
|
|
1010
|
+
zmm.packet[unrollM * unrollN + numLoad + startB] = pload1<vec>(&A_t[idA<isARowMajor>(startB, 0, LDA)]);
|
|
1011
|
+
|
|
1012
|
+
aux_startBCastA<isARowMajor, unrollM, unrollN, endB, counter - 1, numLoad>(A_t, LDA, zmm);
|
|
1013
|
+
}
|
|
1014
|
+
|
|
1015
|
+
template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t counter, int64_t numLoad>
|
|
1016
|
+
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_startBCastA(
|
|
1017
|
+
Scalar *A_t, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
|
|
1018
|
+
EIGEN_UNUSED_VARIABLE(A_t);
|
|
1019
|
+
EIGEN_UNUSED_VARIABLE(LDA);
|
|
1020
|
+
EIGEN_UNUSED_VARIABLE(zmm);
|
|
1021
|
+
}
|
|
1022
|
+
|
|
1023
|
+
/**
|
|
1024
|
+
* aux_loadB
|
|
1025
|
+
* currK: current K
|
|
1026
|
+
*
|
|
1027
|
+
* 1-D unroll
|
|
1028
|
+
* for(startM = 0; startM < endM; startM++)
|
|
1029
|
+
**/
|
|
1030
|
+
template <int64_t endM, int64_t counter, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad,
|
|
1031
|
+
int64_t numBCast, bool rem>
|
|
1032
|
+
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_loadB(
|
|
1033
|
+
Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
|
|
1034
|
+
EIGEN_UNUSED_VARIABLE(rem_);
|
|
1035
|
+
if ((numLoad / endM + currK < unrollK)) {
|
|
1036
|
+
constexpr int64_t counterReverse = endM - counter;
|
|
1037
|
+
constexpr int64_t startM = counterReverse;
|
|
1038
|
+
|
|
1039
|
+
EIGEN_IF_CONSTEXPR(rem) {
|
|
1040
|
+
zmm.packet[endM * unrollN + (startM + currK * endM) % numLoad] =
|
|
1041
|
+
ploadu<vec>(&B_t[(numLoad / endM + currK) * LDB + startM * PacketSize], remMask<PacketSize>(rem_));
|
|
1042
|
+
}
|
|
1043
|
+
else {
|
|
1044
|
+
zmm.packet[endM * unrollN + (startM + currK * endM) % numLoad] =
|
|
1045
|
+
ploadu<vec>(&B_t[(numLoad / endM + currK) * LDB + startM * PacketSize]);
|
|
1046
|
+
}
|
|
1047
|
+
|
|
1048
|
+
aux_loadB<endM, counter - 1, unrollN, currK, unrollK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
|
|
1049
|
+
}
|
|
1050
|
+
}
|
|
1051
|
+
|
|
1052
|
+
template <int64_t endM, int64_t counter, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad,
|
|
1053
|
+
int64_t numBCast, bool rem>
|
|
1054
|
+
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_loadB(
|
|
1055
|
+
Scalar *B_t, int64_t LDB, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, int64_t rem_ = 0) {
|
|
1056
|
+
EIGEN_UNUSED_VARIABLE(B_t);
|
|
1057
|
+
EIGEN_UNUSED_VARIABLE(LDB);
|
|
1058
|
+
EIGEN_UNUSED_VARIABLE(zmm);
|
|
1059
|
+
EIGEN_UNUSED_VARIABLE(rem_);
|
|
1060
|
+
}
|
|
1061
|
+
|
|
1062
|
+
/**
|
|
1063
|
+
* aux_microKernel
|
|
1064
|
+
*
|
|
1065
|
+
* 3-D unroll
|
|
1066
|
+
* for(startM = 0; startM < endM; startM++)
|
|
1067
|
+
* for(startN = 0; startN < endN; startN++)
|
|
1068
|
+
* for(startK = 0; startK < endK; startK++)
|
|
1069
|
+
**/
|
|
1070
|
+
template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t counter, int64_t numLoad,
|
|
1071
|
+
int64_t numBCast, bool rem>
|
|
1072
|
+
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> aux_microKernel(
|
|
1073
|
+
Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
|
|
1074
|
+
int64_t rem_ = 0) {
|
|
1075
|
+
EIGEN_UNUSED_VARIABLE(rem_);
|
|
1076
|
+
constexpr int64_t counterReverse = endM * endN * endK - counter;
|
|
1077
|
+
constexpr int startK = counterReverse / (endM * endN);
|
|
1078
|
+
constexpr int startN = (counterReverse / (endM)) % endN;
|
|
1079
|
+
constexpr int startM = counterReverse % endM;
|
|
1080
|
+
|
|
1081
|
+
EIGEN_IF_CONSTEXPR(startK == 0 && startM == 0 && startN == 0) {
|
|
1082
|
+
gemm::template startLoadB<endM, endN, numLoad, rem>(B_t, LDB, zmm, rem_);
|
|
1083
|
+
gemm::template startBCastA<isARowMajor, endM, endN, numBCast, numLoad>(A_t, LDA, zmm);
|
|
1084
|
+
}
|
|
1085
|
+
|
|
1086
|
+
{
|
|
1087
|
+
// Interleave FMA and Bcast
|
|
1088
|
+
EIGEN_IF_CONSTEXPR(isAdd) {
|
|
1089
|
+
zmm.packet[startN * endM + startM] =
|
|
1090
|
+
pmadd(zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast],
|
|
1091
|
+
zmm.packet[endM * endN + (startM + startK * endM) % numLoad], zmm.packet[startN * endM + startM]);
|
|
1092
|
+
}
|
|
1093
|
+
else {
|
|
1094
|
+
zmm.packet[startN * endM + startM] =
|
|
1095
|
+
pnmadd(zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast],
|
|
1096
|
+
zmm.packet[endM * endN + (startM + startK * endM) % numLoad], zmm.packet[startN * endM + startM]);
|
|
1097
|
+
}
|
|
1098
|
+
// Bcast
|
|
1099
|
+
EIGEN_IF_CONSTEXPR(startM == endM - 1 && (numBCast + startN + startK * endN < endK * endN)) {
|
|
1100
|
+
zmm.packet[endM * endN + numLoad + (startN + startK * endN) % numBCast] = pload1<vec>(&A_t[idA<isARowMajor>(
|
|
1101
|
+
(numBCast + startN + startK * endN) % endN, (numBCast + startN + startK * endN) / endN, LDA)]);
|
|
1102
|
+
}
|
|
1103
|
+
}
|
|
1104
|
+
|
|
1105
|
+
// We have updated all accumulators, time to load next set of B's
|
|
1106
|
+
EIGEN_IF_CONSTEXPR((startN == endN - 1) && (startM == endM - 1)) {
|
|
1107
|
+
gemm::template loadB<endM, endN, startK, endK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
|
|
1108
|
+
}
|
|
1109
|
+
aux_microKernel<isARowMajor, endM, endN, endK, counter - 1, numLoad, numBCast, rem>(B_t, A_t, LDB, LDA, zmm, rem_);
|
|
1110
|
+
}
|
|
1111
|
+
|
|
1112
|
+
template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t counter, int64_t numLoad,
|
|
1113
|
+
int64_t numBCast, bool rem>
|
|
1114
|
+
static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> aux_microKernel(
|
|
1115
|
+
Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA, PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
|
|
1116
|
+
int64_t rem_ = 0) {
|
|
1117
|
+
EIGEN_UNUSED_VARIABLE(B_t);
|
|
1118
|
+
EIGEN_UNUSED_VARIABLE(A_t);
|
|
1119
|
+
EIGEN_UNUSED_VARIABLE(LDB);
|
|
1120
|
+
EIGEN_UNUSED_VARIABLE(LDA);
|
|
1121
|
+
EIGEN_UNUSED_VARIABLE(zmm);
|
|
1122
|
+
EIGEN_UNUSED_VARIABLE(rem_);
|
|
1123
|
+
}
|
|
1124
|
+
|
|
1125
|
+
/********************************************************
|
|
1126
|
+
* Wrappers for aux_XXXX to hide counter parameter
|
|
1127
|
+
********************************************************/
|
|
1128
|
+
|
|
1129
|
+
template <int64_t endM, int64_t endN>
|
|
1130
|
+
static EIGEN_ALWAYS_INLINE void setzero(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
|
|
1131
|
+
aux_setzero<endM, endN, endM * endN>(zmm);
|
|
1132
|
+
}
|
|
1133
|
+
|
|
1134
|
+
/**
|
|
1135
|
+
* Ideally the compiler folds these into vaddp{s,d} with an embedded memory load.
|
|
1136
|
+
*/
|
|
1137
|
+
template <int64_t endM, int64_t endN, bool rem = false>
|
|
1138
|
+
static EIGEN_ALWAYS_INLINE void updateC(Scalar *C_arr, int64_t LDC,
|
|
1139
|
+
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
|
|
1140
|
+
int64_t rem_ = 0) {
|
|
1141
|
+
EIGEN_UNUSED_VARIABLE(rem_);
|
|
1142
|
+
aux_updateC<endM, endN, endM * endN, rem>(C_arr, LDC, zmm, rem_);
|
|
1143
|
+
}
|
|
1144
|
+
|
|
1145
|
+
template <int64_t endM, int64_t endN, bool rem = false>
|
|
1146
|
+
static EIGEN_ALWAYS_INLINE void storeC(Scalar *C_arr, int64_t LDC,
|
|
1147
|
+
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
|
|
1148
|
+
int64_t rem_ = 0) {
|
|
1149
|
+
EIGEN_UNUSED_VARIABLE(rem_);
|
|
1150
|
+
aux_storeC<endM, endN, endM * endN, rem>(C_arr, LDC, zmm, rem_);
|
|
1151
|
+
}
|
|
1152
|
+
|
|
1153
|
+
/**
|
|
1154
|
+
* Use numLoad registers for loading B at start of microKernel
|
|
1155
|
+
*/
|
|
1156
|
+
template <int64_t unrollM, int64_t unrollN, int64_t endL, bool rem>
|
|
1157
|
+
static EIGEN_ALWAYS_INLINE void startLoadB(Scalar *B_t, int64_t LDB,
|
|
1158
|
+
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
|
|
1159
|
+
int64_t rem_ = 0) {
|
|
1160
|
+
EIGEN_UNUSED_VARIABLE(rem_);
|
|
1161
|
+
aux_startLoadB<unrollM, unrollN, endL, endL, rem>(B_t, LDB, zmm, rem_);
|
|
1162
|
+
}
|
|
1163
|
+
|
|
1164
|
+
/**
|
|
1165
|
+
* Use numBCast registers for broadcasting A at start of microKernel
|
|
1166
|
+
*/
|
|
1167
|
+
template <bool isARowMajor, int64_t unrollM, int64_t unrollN, int64_t endB, int64_t numLoad>
|
|
1168
|
+
static EIGEN_ALWAYS_INLINE void startBCastA(Scalar *A_t, int64_t LDA,
|
|
1169
|
+
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm) {
|
|
1170
|
+
aux_startBCastA<isARowMajor, unrollM, unrollN, endB, endB, numLoad>(A_t, LDA, zmm);
|
|
1171
|
+
}
|
|
1172
|
+
|
|
1173
|
+
/**
|
|
1174
|
+
* Loads next set of B into vector registers between each K unroll.
|
|
1175
|
+
*/
|
|
1176
|
+
template <int64_t endM, int64_t unrollN, int64_t currK, int64_t unrollK, int64_t numLoad, int64_t numBCast, bool rem>
|
|
1177
|
+
static EIGEN_ALWAYS_INLINE void loadB(Scalar *B_t, int64_t LDB,
|
|
1178
|
+
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
|
|
1179
|
+
int64_t rem_ = 0) {
|
|
1180
|
+
EIGEN_UNUSED_VARIABLE(rem_);
|
|
1181
|
+
aux_loadB<endM, endM, unrollN, currK, unrollK, numLoad, numBCast, rem>(B_t, LDB, zmm, rem_);
|
|
1182
|
+
}
|
|
1183
|
+
|
|
1184
|
+
/**
|
|
1185
|
+
* Generates a microkernel for gemm (row-major) with unrolls {1,2,4,8}x{U1,U2,U3} to compute C += A*B (isAdd) or C -= A*B (!isAdd).
|
|
1186
|
+
* A matrix can be row/col-major. B matrix is assumed row-major.
|
|
1187
|
+
*
|
|
1188
|
+
* isARowMajor: is A row major
|
|
1189
|
+
* endM: Number of registers per row
|
|
1190
|
+
* endN: Number of rows
|
|
1191
|
+
* endK: Loop unroll for K.
|
|
1192
|
+
* numLoad: Number of registers for loading B.
|
|
1193
|
+
* numBCast: Number of registers for broadcasting A.
|
|
1194
|
+
*
|
|
1195
|
+
* Ex: microKernel<isARowMajor,3,8,4,6,2> (fp32): 8x48 unroll (24 accumulators), k unrolled 4 times,
|
|
1196
|
+
* 6 register for loading B, 2 for broadcasting A.
|
|
1197
|
+
*
|
|
1198
|
+
* Note: Ideally the microkernel should not have any register spilling.
|
|
1199
|
+
* The avx instruction counts should be:
|
|
1200
|
+
* - endK*endN vbroadcasts{s,d}
|
|
1201
|
+
* - endK*endM vmovup{s,d}
|
|
1202
|
+
* - endK*endN*endM FMAs
|
|
1203
|
+
*
|
|
1204
|
+
* From testing, there are no register spills with clang. There are register spills with GNU, which
|
|
1205
|
+
* causes a performance hit.
|
|
1206
|
+
*/
|
|
1207
|
+
template <bool isARowMajor, int64_t endM, int64_t endN, int64_t endK, int64_t numLoad, int64_t numBCast,
|
|
1208
|
+
bool rem = false>
|
|
1209
|
+
static EIGEN_ALWAYS_INLINE void microKernel(Scalar *B_t, Scalar *A_t, int64_t LDB, int64_t LDA,
|
|
1210
|
+
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm,
|
|
1211
|
+
int64_t rem_ = 0) {
|
|
1212
|
+
EIGEN_UNUSED_VARIABLE(rem_);
|
|
1213
|
+
aux_microKernel<isARowMajor, endM, endN, endK, endM * endN * endK, numLoad, numBCast, rem>(B_t, A_t, LDB, LDA, zmm,
|
|
1214
|
+
rem_);
|
|
1215
|
+
}
|
|
1216
|
+
};
|
|
1217
|
+
} // namespace unrolls
|
|
1218
|
+
|
|
1219
|
+
#endif // EIGEN_CORE_ARCH_AVX512_TRSM_UNROLLS_H
|