@smake/eigen 1.0.2 → 1.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +1 -1
- package/eigen/Eigen/AccelerateSupport +52 -0
- package/eigen/Eigen/Cholesky +18 -21
- package/eigen/Eigen/CholmodSupport +28 -28
- package/eigen/Eigen/Core +235 -326
- package/eigen/Eigen/Eigenvalues +16 -14
- package/eigen/Eigen/Geometry +21 -24
- package/eigen/Eigen/Householder +9 -8
- package/eigen/Eigen/IterativeLinearSolvers +8 -4
- package/eigen/Eigen/Jacobi +14 -14
- package/eigen/Eigen/KLUSupport +43 -0
- package/eigen/Eigen/LU +16 -20
- package/eigen/Eigen/MetisSupport +12 -12
- package/eigen/Eigen/OrderingMethods +54 -54
- package/eigen/Eigen/PaStiXSupport +23 -20
- package/eigen/Eigen/PardisoSupport +17 -14
- package/eigen/Eigen/QR +18 -21
- package/eigen/Eigen/QtAlignedMalloc +5 -13
- package/eigen/Eigen/SPQRSupport +21 -14
- package/eigen/Eigen/SVD +23 -18
- package/eigen/Eigen/Sparse +1 -4
- package/eigen/Eigen/SparseCholesky +18 -23
- package/eigen/Eigen/SparseCore +18 -17
- package/eigen/Eigen/SparseLU +12 -8
- package/eigen/Eigen/SparseQR +16 -14
- package/eigen/Eigen/StdDeque +5 -2
- package/eigen/Eigen/StdList +5 -2
- package/eigen/Eigen/StdVector +5 -2
- package/eigen/Eigen/SuperLUSupport +30 -24
- package/eigen/Eigen/ThreadPool +80 -0
- package/eigen/Eigen/UmfPackSupport +19 -17
- package/eigen/Eigen/Version +14 -0
- package/eigen/Eigen/src/AccelerateSupport/AccelerateSupport.h +423 -0
- package/eigen/Eigen/src/AccelerateSupport/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/Cholesky/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/Cholesky/LDLT.h +377 -401
- package/eigen/Eigen/src/Cholesky/LLT.h +332 -360
- package/eigen/Eigen/src/Cholesky/LLT_LAPACKE.h +81 -56
- package/eigen/Eigen/src/CholmodSupport/CholmodSupport.h +620 -521
- package/eigen/Eigen/src/CholmodSupport/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/Core/ArithmeticSequence.h +239 -0
- package/eigen/Eigen/src/Core/Array.h +341 -294
- package/eigen/Eigen/src/Core/ArrayBase.h +190 -203
- package/eigen/Eigen/src/Core/ArrayWrapper.h +127 -171
- package/eigen/Eigen/src/Core/Assign.h +30 -40
- package/eigen/Eigen/src/Core/AssignEvaluator.h +711 -589
- package/eigen/Eigen/src/Core/Assign_MKL.h +130 -125
- package/eigen/Eigen/src/Core/BandMatrix.h +268 -283
- package/eigen/Eigen/src/Core/Block.h +375 -398
- package/eigen/Eigen/src/Core/CommaInitializer.h +86 -97
- package/eigen/Eigen/src/Core/ConditionEstimator.h +51 -53
- package/eigen/Eigen/src/Core/CoreEvaluators.h +1356 -1026
- package/eigen/Eigen/src/Core/CoreIterators.h +73 -59
- package/eigen/Eigen/src/Core/CwiseBinaryOp.h +114 -132
- package/eigen/Eigen/src/Core/CwiseNullaryOp.h +726 -617
- package/eigen/Eigen/src/Core/CwiseTernaryOp.h +77 -103
- package/eigen/Eigen/src/Core/CwiseUnaryOp.h +56 -68
- package/eigen/Eigen/src/Core/CwiseUnaryView.h +132 -95
- package/eigen/Eigen/src/Core/DenseBase.h +632 -571
- package/eigen/Eigen/src/Core/DenseCoeffsBase.h +511 -624
- package/eigen/Eigen/src/Core/DenseStorage.h +512 -509
- package/eigen/Eigen/src/Core/DeviceWrapper.h +153 -0
- package/eigen/Eigen/src/Core/Diagonal.h +169 -210
- package/eigen/Eigen/src/Core/DiagonalMatrix.h +351 -274
- package/eigen/Eigen/src/Core/DiagonalProduct.h +12 -10
- package/eigen/Eigen/src/Core/Dot.h +172 -222
- package/eigen/Eigen/src/Core/EigenBase.h +75 -85
- package/eigen/Eigen/src/Core/Fill.h +138 -0
- package/eigen/Eigen/src/Core/FindCoeff.h +464 -0
- package/eigen/Eigen/src/Core/ForceAlignedAccess.h +90 -109
- package/eigen/Eigen/src/Core/Fuzzy.h +82 -105
- package/eigen/Eigen/src/Core/GeneralProduct.h +327 -263
- package/eigen/Eigen/src/Core/GenericPacketMath.h +1472 -360
- package/eigen/Eigen/src/Core/GlobalFunctions.h +194 -151
- package/eigen/Eigen/src/Core/IO.h +147 -139
- package/eigen/Eigen/src/Core/IndexedView.h +321 -0
- package/eigen/Eigen/src/Core/InnerProduct.h +260 -0
- package/eigen/Eigen/src/Core/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/Core/Inverse.h +56 -66
- package/eigen/Eigen/src/Core/Map.h +124 -142
- package/eigen/Eigen/src/Core/MapBase.h +256 -281
- package/eigen/Eigen/src/Core/MathFunctions.h +1620 -938
- package/eigen/Eigen/src/Core/MathFunctionsImpl.h +233 -71
- package/eigen/Eigen/src/Core/Matrix.h +491 -416
- package/eigen/Eigen/src/Core/MatrixBase.h +468 -453
- package/eigen/Eigen/src/Core/NestByValue.h +66 -85
- package/eigen/Eigen/src/Core/NoAlias.h +79 -85
- package/eigen/Eigen/src/Core/NumTraits.h +235 -148
- package/eigen/Eigen/src/Core/PartialReduxEvaluator.h +253 -0
- package/eigen/Eigen/src/Core/PermutationMatrix.h +461 -511
- package/eigen/Eigen/src/Core/PlainObjectBase.h +871 -894
- package/eigen/Eigen/src/Core/Product.h +260 -139
- package/eigen/Eigen/src/Core/ProductEvaluators.h +863 -714
- package/eigen/Eigen/src/Core/Random.h +161 -136
- package/eigen/Eigen/src/Core/RandomImpl.h +262 -0
- package/eigen/Eigen/src/Core/RealView.h +250 -0
- package/eigen/Eigen/src/Core/Redux.h +366 -336
- package/eigen/Eigen/src/Core/Ref.h +308 -209
- package/eigen/Eigen/src/Core/Replicate.h +94 -106
- package/eigen/Eigen/src/Core/Reshaped.h +398 -0
- package/eigen/Eigen/src/Core/ReturnByValue.h +49 -55
- package/eigen/Eigen/src/Core/Reverse.h +136 -145
- package/eigen/Eigen/src/Core/Select.h +70 -140
- package/eigen/Eigen/src/Core/SelfAdjointView.h +262 -285
- package/eigen/Eigen/src/Core/SelfCwiseBinaryOp.h +23 -20
- package/eigen/Eigen/src/Core/SkewSymmetricMatrix3.h +382 -0
- package/eigen/Eigen/src/Core/Solve.h +97 -111
- package/eigen/Eigen/src/Core/SolveTriangular.h +131 -129
- package/eigen/Eigen/src/Core/SolverBase.h +138 -101
- package/eigen/Eigen/src/Core/StableNorm.h +156 -160
- package/eigen/Eigen/src/Core/StlIterators.h +619 -0
- package/eigen/Eigen/src/Core/Stride.h +91 -88
- package/eigen/Eigen/src/Core/Swap.h +70 -38
- package/eigen/Eigen/src/Core/Transpose.h +295 -273
- package/eigen/Eigen/src/Core/Transpositions.h +272 -317
- package/eigen/Eigen/src/Core/TriangularMatrix.h +670 -755
- package/eigen/Eigen/src/Core/VectorBlock.h +59 -72
- package/eigen/Eigen/src/Core/VectorwiseOp.h +668 -630
- package/eigen/Eigen/src/Core/Visitor.h +480 -216
- package/eigen/Eigen/src/Core/arch/AVX/Complex.h +407 -293
- package/eigen/Eigen/src/Core/arch/AVX/MathFunctions.h +79 -388
- package/eigen/Eigen/src/Core/arch/AVX/PacketMath.h +2935 -491
- package/eigen/Eigen/src/Core/arch/AVX/Reductions.h +353 -0
- package/eigen/Eigen/src/Core/arch/AVX/TypeCasting.h +279 -22
- package/eigen/Eigen/src/Core/arch/AVX512/Complex.h +472 -0
- package/eigen/Eigen/src/Core/arch/AVX512/GemmKernel.h +1245 -0
- package/eigen/Eigen/src/Core/arch/AVX512/MathFunctions.h +85 -333
- package/eigen/Eigen/src/Core/arch/AVX512/MathFunctionsFP16.h +75 -0
- package/eigen/Eigen/src/Core/arch/AVX512/PacketMath.h +2490 -649
- package/eigen/Eigen/src/Core/arch/AVX512/PacketMathFP16.h +1413 -0
- package/eigen/Eigen/src/Core/arch/AVX512/Reductions.h +297 -0
- package/eigen/Eigen/src/Core/arch/AVX512/TrsmKernel.h +1167 -0
- package/eigen/Eigen/src/Core/arch/AVX512/TrsmUnrolls.inc +1219 -0
- package/eigen/Eigen/src/Core/arch/AVX512/TypeCasting.h +277 -0
- package/eigen/Eigen/src/Core/arch/AVX512/TypeCastingFP16.h +130 -0
- package/eigen/Eigen/src/Core/arch/AltiVec/Complex.h +521 -298
- package/eigen/Eigen/src/Core/arch/AltiVec/MathFunctions.h +39 -280
- package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProduct.h +3686 -0
- package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h +205 -0
- package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h +901 -0
- package/eigen/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h +742 -0
- package/eigen/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.inc +2818 -0
- package/eigen/Eigen/src/Core/arch/AltiVec/PacketMath.h +3391 -723
- package/eigen/Eigen/src/Core/arch/AltiVec/TypeCasting.h +153 -0
- package/eigen/Eigen/src/Core/arch/Default/BFloat16.h +866 -0
- package/eigen/Eigen/src/Core/arch/Default/ConjHelper.h +113 -14
- package/eigen/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +2634 -0
- package/eigen/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h +227 -0
- package/eigen/Eigen/src/Core/arch/Default/Half.h +1091 -0
- package/eigen/Eigen/src/Core/arch/Default/Settings.h +11 -13
- package/eigen/Eigen/src/Core/arch/GPU/Complex.h +244 -0
- package/eigen/Eigen/src/Core/arch/GPU/MathFunctions.h +104 -0
- package/eigen/Eigen/src/Core/arch/GPU/PacketMath.h +1712 -0
- package/eigen/Eigen/src/Core/arch/GPU/Tuple.h +268 -0
- package/eigen/Eigen/src/Core/arch/GPU/TypeCasting.h +77 -0
- package/eigen/Eigen/src/Core/arch/HIP/hcc/math_constants.h +23 -0
- package/eigen/Eigen/src/Core/arch/HVX/PacketMath.h +1088 -0
- package/eigen/Eigen/src/Core/arch/LSX/Complex.h +520 -0
- package/eigen/Eigen/src/Core/arch/LSX/GeneralBlockPanelKernel.h +23 -0
- package/eigen/Eigen/src/Core/arch/LSX/MathFunctions.h +43 -0
- package/eigen/Eigen/src/Core/arch/LSX/PacketMath.h +2866 -0
- package/eigen/Eigen/src/Core/arch/LSX/TypeCasting.h +526 -0
- package/eigen/Eigen/src/Core/arch/MSA/Complex.h +620 -0
- package/eigen/Eigen/src/Core/arch/MSA/MathFunctions.h +379 -0
- package/eigen/Eigen/src/Core/arch/MSA/PacketMath.h +1237 -0
- package/eigen/Eigen/src/Core/arch/NEON/Complex.h +531 -289
- package/eigen/Eigen/src/Core/arch/NEON/GeneralBlockPanelKernel.h +243 -0
- package/eigen/Eigen/src/Core/arch/NEON/MathFunctions.h +50 -73
- package/eigen/Eigen/src/Core/arch/NEON/PacketMath.h +5915 -579
- package/eigen/Eigen/src/Core/arch/NEON/TypeCasting.h +1642 -0
- package/eigen/Eigen/src/Core/arch/NEON/UnaryFunctors.h +57 -0
- package/eigen/Eigen/src/Core/arch/SSE/Complex.h +366 -334
- package/eigen/Eigen/src/Core/arch/SSE/MathFunctions.h +40 -514
- package/eigen/Eigen/src/Core/arch/SSE/PacketMath.h +2164 -675
- package/eigen/Eigen/src/Core/arch/SSE/Reductions.h +324 -0
- package/eigen/Eigen/src/Core/arch/SSE/TypeCasting.h +188 -35
- package/eigen/Eigen/src/Core/arch/SVE/MathFunctions.h +48 -0
- package/eigen/Eigen/src/Core/arch/SVE/PacketMath.h +674 -0
- package/eigen/Eigen/src/Core/arch/SVE/TypeCasting.h +52 -0
- package/eigen/Eigen/src/Core/arch/SYCL/InteropHeaders.h +227 -0
- package/eigen/Eigen/src/Core/arch/SYCL/MathFunctions.h +303 -0
- package/eigen/Eigen/src/Core/arch/SYCL/PacketMath.h +576 -0
- package/eigen/Eigen/src/Core/arch/SYCL/TypeCasting.h +83 -0
- package/eigen/Eigen/src/Core/arch/ZVector/Complex.h +434 -261
- package/eigen/Eigen/src/Core/arch/ZVector/MathFunctions.h +160 -53
- package/eigen/Eigen/src/Core/arch/ZVector/PacketMath.h +1073 -605
- package/eigen/Eigen/src/Core/functors/AssignmentFunctors.h +123 -117
- package/eigen/Eigen/src/Core/functors/BinaryFunctors.h +594 -322
- package/eigen/Eigen/src/Core/functors/NullaryFunctors.h +204 -118
- package/eigen/Eigen/src/Core/functors/StlFunctors.h +110 -97
- package/eigen/Eigen/src/Core/functors/TernaryFunctors.h +34 -7
- package/eigen/Eigen/src/Core/functors/UnaryFunctors.h +1158 -530
- package/eigen/Eigen/src/Core/products/GeneralBlockPanelKernel.h +2329 -1333
- package/eigen/Eigen/src/Core/products/GeneralMatrixMatrix.h +328 -364
- package/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h +191 -178
- package/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h +85 -82
- package/eigen/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h +154 -73
- package/eigen/Eigen/src/Core/products/GeneralMatrixVector.h +396 -542
- package/eigen/Eigen/src/Core/products/GeneralMatrixVector_BLAS.h +80 -77
- package/eigen/Eigen/src/Core/products/Parallelizer.h +208 -92
- package/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix.h +331 -375
- package/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix_BLAS.h +206 -224
- package/eigen/Eigen/src/Core/products/SelfadjointMatrixVector.h +139 -146
- package/eigen/Eigen/src/Core/products/SelfadjointMatrixVector_BLAS.h +58 -61
- package/eigen/Eigen/src/Core/products/SelfadjointProduct.h +71 -71
- package/eigen/Eigen/src/Core/products/SelfadjointRank2Update.h +48 -46
- package/eigen/Eigen/src/Core/products/TriangularMatrixMatrix.h +294 -369
- package/eigen/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h +246 -238
- package/eigen/Eigen/src/Core/products/TriangularMatrixVector.h +244 -247
- package/eigen/Eigen/src/Core/products/TriangularMatrixVector_BLAS.h +212 -192
- package/eigen/Eigen/src/Core/products/TriangularSolverMatrix.h +328 -275
- package/eigen/Eigen/src/Core/products/TriangularSolverMatrix_BLAS.h +108 -109
- package/eigen/Eigen/src/Core/products/TriangularSolverVector.h +70 -93
- package/eigen/Eigen/src/Core/util/Assert.h +158 -0
- package/eigen/Eigen/src/Core/util/BlasUtil.h +413 -290
- package/eigen/Eigen/src/Core/util/ConfigureVectorization.h +543 -0
- package/eigen/Eigen/src/Core/util/Constants.h +314 -263
- package/eigen/Eigen/src/Core/util/DisableStupidWarnings.h +130 -78
- package/eigen/Eigen/src/Core/util/EmulateArray.h +270 -0
- package/eigen/Eigen/src/Core/util/ForwardDeclarations.h +450 -224
- package/eigen/Eigen/src/Core/util/GpuHipCudaDefines.inc +101 -0
- package/eigen/Eigen/src/Core/util/GpuHipCudaUndefines.inc +45 -0
- package/eigen/Eigen/src/Core/util/IndexedViewHelper.h +487 -0
- package/eigen/Eigen/src/Core/util/IntegralConstant.h +279 -0
- package/eigen/Eigen/src/Core/util/MKL_support.h +39 -30
- package/eigen/Eigen/src/Core/util/Macros.h +939 -646
- package/eigen/Eigen/src/Core/util/MaxSizeVector.h +139 -0
- package/eigen/Eigen/src/Core/util/Memory.h +1042 -650
- package/eigen/Eigen/src/Core/util/Meta.h +618 -426
- package/eigen/Eigen/src/Core/util/MoreMeta.h +638 -0
- package/eigen/Eigen/src/Core/util/ReenableStupidWarnings.h +32 -19
- package/eigen/Eigen/src/Core/util/ReshapedHelper.h +51 -0
- package/eigen/Eigen/src/Core/util/Serializer.h +209 -0
- package/eigen/Eigen/src/Core/util/StaticAssert.h +51 -164
- package/eigen/Eigen/src/Core/util/SymbolicIndex.h +445 -0
- package/eigen/Eigen/src/Core/util/XprHelper.h +793 -538
- package/eigen/Eigen/src/Eigenvalues/ComplexEigenSolver.h +246 -277
- package/eigen/Eigen/src/Eigenvalues/ComplexSchur.h +299 -319
- package/eigen/Eigen/src/Eigenvalues/ComplexSchur_LAPACKE.h +52 -48
- package/eigen/Eigen/src/Eigenvalues/EigenSolver.h +413 -456
- package/eigen/Eigen/src/Eigenvalues/GeneralizedEigenSolver.h +309 -325
- package/eigen/Eigen/src/Eigenvalues/GeneralizedSelfAdjointEigenSolver.h +157 -171
- package/eigen/Eigen/src/Eigenvalues/HessenbergDecomposition.h +292 -310
- package/eigen/Eigen/src/Eigenvalues/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/Eigenvalues/MatrixBaseEigenvalues.h +91 -107
- package/eigen/Eigen/src/Eigenvalues/RealQZ.h +539 -606
- package/eigen/Eigen/src/Eigenvalues/RealSchur.h +348 -382
- package/eigen/Eigen/src/Eigenvalues/RealSchur_LAPACKE.h +41 -35
- package/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver.h +579 -600
- package/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver_LAPACKE.h +47 -44
- package/eigen/Eigen/src/Eigenvalues/Tridiagonalization.h +434 -461
- package/eigen/Eigen/src/Geometry/AlignedBox.h +307 -214
- package/eigen/Eigen/src/Geometry/AngleAxis.h +135 -137
- package/eigen/Eigen/src/Geometry/EulerAngles.h +163 -74
- package/eigen/Eigen/src/Geometry/Homogeneous.h +289 -333
- package/eigen/Eigen/src/Geometry/Hyperplane.h +152 -161
- package/eigen/Eigen/src/Geometry/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/Geometry/OrthoMethods.h +168 -145
- package/eigen/Eigen/src/Geometry/ParametrizedLine.h +141 -104
- package/eigen/Eigen/src/Geometry/Quaternion.h +595 -497
- package/eigen/Eigen/src/Geometry/Rotation2D.h +110 -108
- package/eigen/Eigen/src/Geometry/RotationBase.h +148 -145
- package/eigen/Eigen/src/Geometry/Scaling.h +115 -90
- package/eigen/Eigen/src/Geometry/Transform.h +896 -953
- package/eigen/Eigen/src/Geometry/Translation.h +100 -98
- package/eigen/Eigen/src/Geometry/Umeyama.h +79 -84
- package/eigen/Eigen/src/Geometry/arch/Geometry_SIMD.h +154 -0
- package/eigen/Eigen/src/Householder/BlockHouseholder.h +54 -42
- package/eigen/Eigen/src/Householder/Householder.h +104 -122
- package/eigen/Eigen/src/Householder/HouseholderSequence.h +416 -382
- package/eigen/Eigen/src/Householder/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/IterativeLinearSolvers/BasicPreconditioners.h +153 -166
- package/eigen/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h +127 -138
- package/eigen/Eigen/src/IterativeLinearSolvers/ConjugateGradient.h +95 -124
- package/eigen/Eigen/src/IterativeLinearSolvers/IncompleteCholesky.h +269 -267
- package/eigen/Eigen/src/IterativeLinearSolvers/IncompleteLUT.h +246 -259
- package/eigen/Eigen/src/IterativeLinearSolvers/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h +218 -217
- package/eigen/Eigen/src/IterativeLinearSolvers/LeastSquareConjugateGradient.h +80 -103
- package/eigen/Eigen/src/IterativeLinearSolvers/SolveWithGuess.h +59 -63
- package/eigen/Eigen/src/Jacobi/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/Jacobi/Jacobi.h +256 -291
- package/eigen/Eigen/src/KLUSupport/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/KLUSupport/KLUSupport.h +339 -0
- package/eigen/Eigen/src/LU/Determinant.h +60 -63
- package/eigen/Eigen/src/LU/FullPivLU.h +561 -626
- package/eigen/Eigen/src/LU/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/LU/InverseImpl.h +213 -275
- package/eigen/Eigen/src/LU/PartialPivLU.h +407 -435
- package/eigen/Eigen/src/LU/PartialPivLU_LAPACKE.h +54 -40
- package/eigen/Eigen/src/LU/arch/InverseSize4.h +353 -0
- package/eigen/Eigen/src/MetisSupport/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/MetisSupport/MetisSupport.h +81 -93
- package/eigen/Eigen/src/OrderingMethods/Amd.h +250 -282
- package/eigen/Eigen/src/OrderingMethods/Eigen_Colamd.h +950 -1103
- package/eigen/Eigen/src/OrderingMethods/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/OrderingMethods/Ordering.h +111 -122
- package/eigen/Eigen/src/PaStiXSupport/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/PaStiXSupport/PaStiXSupport.h +524 -570
- package/eigen/Eigen/src/PardisoSupport/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/PardisoSupport/PardisoSupport.h +385 -429
- package/eigen/Eigen/src/QR/ColPivHouseholderQR.h +494 -473
- package/eigen/Eigen/src/QR/ColPivHouseholderQR_LAPACKE.h +120 -56
- package/eigen/Eigen/src/QR/CompleteOrthogonalDecomposition.h +223 -137
- package/eigen/Eigen/src/QR/FullPivHouseholderQR.h +517 -460
- package/eigen/Eigen/src/QR/HouseholderQR.h +412 -278
- package/eigen/Eigen/src/QR/HouseholderQR_LAPACKE.h +32 -23
- package/eigen/Eigen/src/QR/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/SPQRSupport/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/SPQRSupport/SuiteSparseQRSupport.h +263 -261
- package/eigen/Eigen/src/SVD/BDCSVD.h +872 -679
- package/eigen/Eigen/src/SVD/BDCSVD_LAPACKE.h +174 -0
- package/eigen/Eigen/src/SVD/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/SVD/JacobiSVD.h +585 -543
- package/eigen/Eigen/src/SVD/JacobiSVD_LAPACKE.h +85 -49
- package/eigen/Eigen/src/SVD/SVDBase.h +281 -160
- package/eigen/Eigen/src/SVD/UpperBidiagonalization.h +202 -237
- package/eigen/Eigen/src/SparseCholesky/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/SparseCholesky/SimplicialCholesky.h +769 -590
- package/eigen/Eigen/src/SparseCholesky/SimplicialCholesky_impl.h +318 -129
- package/eigen/Eigen/src/SparseCore/AmbiVector.h +202 -251
- package/eigen/Eigen/src/SparseCore/CompressedStorage.h +184 -236
- package/eigen/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h +140 -184
- package/eigen/Eigen/src/SparseCore/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/SparseCore/SparseAssign.h +174 -111
- package/eigen/Eigen/src/SparseCore/SparseBlock.h +408 -477
- package/eigen/Eigen/src/SparseCore/SparseColEtree.h +100 -112
- package/eigen/Eigen/src/SparseCore/SparseCompressedBase.h +531 -280
- package/eigen/Eigen/src/SparseCore/SparseCwiseBinaryOp.h +559 -347
- package/eigen/Eigen/src/SparseCore/SparseCwiseUnaryOp.h +100 -108
- package/eigen/Eigen/src/SparseCore/SparseDenseProduct.h +185 -191
- package/eigen/Eigen/src/SparseCore/SparseDiagonalProduct.h +71 -71
- package/eigen/Eigen/src/SparseCore/SparseDot.h +49 -47
- package/eigen/Eigen/src/SparseCore/SparseFuzzy.h +13 -11
- package/eigen/Eigen/src/SparseCore/SparseMap.h +243 -253
- package/eigen/Eigen/src/SparseCore/SparseMatrix.h +1614 -1142
- package/eigen/Eigen/src/SparseCore/SparseMatrixBase.h +403 -357
- package/eigen/Eigen/src/SparseCore/SparsePermutation.h +186 -115
- package/eigen/Eigen/src/SparseCore/SparseProduct.h +100 -91
- package/eigen/Eigen/src/SparseCore/SparseRedux.h +22 -24
- package/eigen/Eigen/src/SparseCore/SparseRef.h +268 -295
- package/eigen/Eigen/src/SparseCore/SparseSelfAdjointView.h +371 -414
- package/eigen/Eigen/src/SparseCore/SparseSolverBase.h +78 -87
- package/eigen/Eigen/src/SparseCore/SparseSparseProductWithPruning.h +81 -95
- package/eigen/Eigen/src/SparseCore/SparseTranspose.h +62 -71
- package/eigen/Eigen/src/SparseCore/SparseTriangularView.h +132 -144
- package/eigen/Eigen/src/SparseCore/SparseUtil.h +146 -115
- package/eigen/Eigen/src/SparseCore/SparseVector.h +426 -372
- package/eigen/Eigen/src/SparseCore/SparseView.h +164 -193
- package/eigen/Eigen/src/SparseCore/TriangularSolver.h +129 -170
- package/eigen/Eigen/src/SparseLU/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/SparseLU/SparseLU.h +814 -618
- package/eigen/Eigen/src/SparseLU/SparseLUImpl.h +61 -48
- package/eigen/Eigen/src/SparseLU/SparseLU_Memory.h +102 -118
- package/eigen/Eigen/src/SparseLU/SparseLU_Structs.h +38 -35
- package/eigen/Eigen/src/SparseLU/SparseLU_SupernodalMatrix.h +273 -255
- package/eigen/Eigen/src/SparseLU/SparseLU_Utils.h +44 -49
- package/eigen/Eigen/src/SparseLU/SparseLU_column_bmod.h +104 -108
- package/eigen/Eigen/src/SparseLU/SparseLU_column_dfs.h +90 -101
- package/eigen/Eigen/src/SparseLU/SparseLU_copy_to_ucol.h +57 -58
- package/eigen/Eigen/src/SparseLU/SparseLU_heap_relax_snode.h +43 -55
- package/eigen/Eigen/src/SparseLU/SparseLU_kernel_bmod.h +74 -71
- package/eigen/Eigen/src/SparseLU/SparseLU_panel_bmod.h +125 -133
- package/eigen/Eigen/src/SparseLU/SparseLU_panel_dfs.h +136 -159
- package/eigen/Eigen/src/SparseLU/SparseLU_pivotL.h +51 -52
- package/eigen/Eigen/src/SparseLU/SparseLU_pruneL.h +67 -73
- package/eigen/Eigen/src/SparseLU/SparseLU_relax_snode.h +24 -26
- package/eigen/Eigen/src/SparseQR/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/SparseQR/SparseQR.h +451 -490
- package/eigen/Eigen/src/StlSupport/StdDeque.h +28 -105
- package/eigen/Eigen/src/StlSupport/StdList.h +28 -84
- package/eigen/Eigen/src/StlSupport/StdVector.h +28 -108
- package/eigen/Eigen/src/StlSupport/details.h +48 -50
- package/eigen/Eigen/src/SuperLUSupport/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/SuperLUSupport/SuperLUSupport.h +634 -732
- package/eigen/Eigen/src/ThreadPool/Barrier.h +70 -0
- package/eigen/Eigen/src/ThreadPool/CoreThreadPoolDevice.h +336 -0
- package/eigen/Eigen/src/ThreadPool/EventCount.h +241 -0
- package/eigen/Eigen/src/ThreadPool/ForkJoin.h +140 -0
- package/eigen/Eigen/src/ThreadPool/InternalHeaderCheck.h +4 -0
- package/eigen/Eigen/src/ThreadPool/NonBlockingThreadPool.h +587 -0
- package/eigen/Eigen/src/ThreadPool/RunQueue.h +230 -0
- package/eigen/Eigen/src/ThreadPool/ThreadCancel.h +21 -0
- package/eigen/Eigen/src/ThreadPool/ThreadEnvironment.h +43 -0
- package/eigen/Eigen/src/ThreadPool/ThreadLocal.h +289 -0
- package/eigen/Eigen/src/ThreadPool/ThreadPoolInterface.h +50 -0
- package/eigen/Eigen/src/ThreadPool/ThreadYield.h +16 -0
- package/eigen/Eigen/src/UmfPackSupport/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/UmfPackSupport/UmfPackSupport.h +480 -380
- package/eigen/Eigen/src/misc/Image.h +41 -43
- package/eigen/Eigen/src/misc/InternalHeaderCheck.h +3 -0
- package/eigen/Eigen/src/misc/Kernel.h +39 -41
- package/eigen/Eigen/src/misc/RealSvd2x2.h +19 -21
- package/eigen/Eigen/src/misc/blas.h +83 -426
- package/eigen/Eigen/src/misc/lapacke.h +9976 -16182
- package/eigen/Eigen/src/misc/lapacke_helpers.h +163 -0
- package/eigen/Eigen/src/misc/lapacke_mangling.h +4 -5
- package/eigen/Eigen/src/plugins/ArrayCwiseBinaryOps.inc +344 -0
- package/eigen/Eigen/src/plugins/ArrayCwiseUnaryOps.inc +544 -0
- package/eigen/Eigen/src/plugins/BlockMethods.inc +1370 -0
- package/eigen/Eigen/src/plugins/CommonCwiseBinaryOps.inc +116 -0
- package/eigen/Eigen/src/plugins/CommonCwiseUnaryOps.inc +167 -0
- package/eigen/Eigen/src/plugins/IndexedViewMethods.inc +192 -0
- package/eigen/Eigen/src/plugins/InternalHeaderCheck.inc +3 -0
- package/eigen/Eigen/src/plugins/MatrixCwiseBinaryOps.inc +331 -0
- package/eigen/Eigen/src/plugins/MatrixCwiseUnaryOps.inc +118 -0
- package/eigen/Eigen/src/plugins/ReshapedMethods.inc +133 -0
- package/lib/LibEigen.d.ts +4 -0
- package/lib/LibEigen.js +14 -0
- package/lib/index.d.ts +1 -1
- package/lib/index.js +7 -3
- package/package.json +2 -10
- package/eigen/Eigen/CMakeLists.txt +0 -19
- package/eigen/Eigen/src/Core/BooleanRedux.h +0 -164
- package/eigen/Eigen/src/Core/arch/CUDA/Complex.h +0 -103
- package/eigen/Eigen/src/Core/arch/CUDA/Half.h +0 -675
- package/eigen/Eigen/src/Core/arch/CUDA/MathFunctions.h +0 -91
- package/eigen/Eigen/src/Core/arch/CUDA/PacketMath.h +0 -333
- package/eigen/Eigen/src/Core/arch/CUDA/PacketMathHalf.h +0 -1124
- package/eigen/Eigen/src/Core/arch/CUDA/TypeCasting.h +0 -212
- package/eigen/Eigen/src/Core/util/NonMPL2.h +0 -3
- package/eigen/Eigen/src/Geometry/arch/Geometry_SSE.h +0 -161
- package/eigen/Eigen/src/LU/arch/Inverse_SSE.h +0 -338
- package/eigen/Eigen/src/SparseCore/MappedSparseMatrix.h +0 -67
- package/eigen/Eigen/src/SparseLU/SparseLU_gemm_kernel.h +0 -280
- package/eigen/Eigen/src/misc/lapack.h +0 -152
- package/eigen/Eigen/src/plugins/ArrayCwiseBinaryOps.h +0 -332
- package/eigen/Eigen/src/plugins/ArrayCwiseUnaryOps.h +0 -552
- package/eigen/Eigen/src/plugins/BlockMethods.h +0 -1058
- package/eigen/Eigen/src/plugins/CommonCwiseBinaryOps.h +0 -115
- package/eigen/Eigen/src/plugins/CommonCwiseUnaryOps.h +0 -163
- package/eigen/Eigen/src/plugins/MatrixCwiseBinaryOps.h +0 -152
- package/eigen/Eigen/src/plugins/MatrixCwiseUnaryOps.h +0 -85
- package/lib/eigen.d.ts +0 -2
- package/lib/eigen.js +0 -15
|
@@ -0,0 +1,1167 @@
|
|
|
1
|
+
// This file is part of Eigen, a lightweight C++ template library
|
|
2
|
+
// for linear algebra.
|
|
3
|
+
//
|
|
4
|
+
// Copyright (C) 2022 Intel Corporation
|
|
5
|
+
//
|
|
6
|
+
// This Source Code Form is subject to the terms of the Mozilla
|
|
7
|
+
// Public License v. 2.0. If a copy of the MPL was not distributed
|
|
8
|
+
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
9
|
+
|
|
10
|
+
#ifndef EIGEN_CORE_ARCH_AVX512_TRSM_KERNEL_H
|
|
11
|
+
#define EIGEN_CORE_ARCH_AVX512_TRSM_KERNEL_H
|
|
12
|
+
|
|
13
|
+
// IWYU pragma: private
|
|
14
|
+
#include "../../InternalHeaderCheck.h"
|
|
15
|
+
|
|
16
|
+
#if !defined(EIGEN_USE_AVX512_TRSM_KERNELS)
|
|
17
|
+
#define EIGEN_USE_AVX512_TRSM_KERNELS 1
|
|
18
|
+
#endif
|
|
19
|
+
|
|
20
|
+
// TRSM kernels currently unconditionally rely on malloc with AVX512.
|
|
21
|
+
// Disable them if malloc is explicitly disabled at compile-time.
|
|
22
|
+
#ifdef EIGEN_NO_MALLOC
|
|
23
|
+
#undef EIGEN_USE_AVX512_TRSM_KERNELS
|
|
24
|
+
#define EIGEN_USE_AVX512_TRSM_KERNELS 0
|
|
25
|
+
#endif
|
|
26
|
+
|
|
27
|
+
#if EIGEN_USE_AVX512_TRSM_KERNELS
|
|
28
|
+
#if !defined(EIGEN_USE_AVX512_TRSM_R_KERNELS)
|
|
29
|
+
#define EIGEN_USE_AVX512_TRSM_R_KERNELS 1
|
|
30
|
+
#endif
|
|
31
|
+
#if !defined(EIGEN_USE_AVX512_TRSM_L_KERNELS)
|
|
32
|
+
#define EIGEN_USE_AVX512_TRSM_L_KERNELS 1
|
|
33
|
+
#endif
|
|
34
|
+
#else // EIGEN_USE_AVX512_TRSM_KERNELS == 0
|
|
35
|
+
#define EIGEN_USE_AVX512_TRSM_R_KERNELS 0
|
|
36
|
+
#define EIGEN_USE_AVX512_TRSM_L_KERNELS 0
|
|
37
|
+
#endif
|
|
38
|
+
|
|
39
|
+
// Need this for some std::min calls.
|
|
40
|
+
#ifdef min
|
|
41
|
+
#undef min
|
|
42
|
+
#endif
|
|
43
|
+
|
|
44
|
+
namespace Eigen {
|
|
45
|
+
namespace internal {
|
|
46
|
+
|
|
47
|
+
#define EIGEN_AVX_MAX_NUM_ACC (int64_t(24))
|
|
48
|
+
#define EIGEN_AVX_MAX_NUM_ROW (int64_t(8)) // Denoted L in code.
|
|
49
|
+
#define EIGEN_AVX_MAX_K_UNROL (int64_t(4))
|
|
50
|
+
#define EIGEN_AVX_B_LOAD_SETS (int64_t(2))
|
|
51
|
+
#define EIGEN_AVX_MAX_A_BCAST (int64_t(2))
|
|
52
|
+
typedef Packet16f vecFullFloat;
|
|
53
|
+
typedef Packet8d vecFullDouble;
|
|
54
|
+
typedef Packet8f vecHalfFloat;
|
|
55
|
+
typedef Packet4d vecHalfDouble;
|
|
56
|
+
|
|
57
|
+
// Compile-time unrolls are implemented here.
|
|
58
|
+
// Note: this depends on macros and typedefs above.
|
|
59
|
+
#include "TrsmUnrolls.inc"
|
|
60
|
+
|
|
61
|
+
#if (EIGEN_USE_AVX512_TRSM_KERNELS) && (EIGEN_COMP_CLANG != 0)
|
|
62
|
+
/**
|
|
63
|
+
* For smaller problem sizes, and certain compilers, using the optimized kernels trsmKernelL/R directly
|
|
64
|
+
* is faster than the packed versions in TriangularSolverMatrix.h.
|
|
65
|
+
*
|
|
66
|
+
 * The current heuristic is based on having all arrays used in the largest gemm-update
|
|
67
|
+
* in triSolve fit in roughly L2Cap (percentage) of the L2 cache. These cutoffs are a bit conservative and could be
|
|
68
|
+
* larger for some trsm cases.
|
|
69
|
+
* The formula:
|
|
70
|
+
*
|
|
71
|
+
* (L*M + M*N + L*N)*sizeof(Scalar) < L2Cache*L2Cap
|
|
72
|
+
*
|
|
73
|
+
* L = number of rows to solve at a time
|
|
74
|
+
* N = number of rhs
|
|
75
|
+
* M = Dimension of triangular matrix
|
|
76
|
+
*
|
|
77
|
+
*/
|
|
78
|
+
#if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS)
|
|
79
|
+
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS 1
|
|
80
|
+
#endif
|
|
81
|
+
|
|
82
|
+
#if EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS
|
|
83
|
+
|
|
84
|
+
#if EIGEN_USE_AVX512_TRSM_R_KERNELS
|
|
85
|
+
#if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS)
|
|
86
|
+
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 1
|
|
87
|
+
#endif // !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS)
|
|
88
|
+
#endif
|
|
89
|
+
|
|
90
|
+
#if EIGEN_USE_AVX512_TRSM_L_KERNELS
|
|
91
|
+
#if !defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS)
|
|
92
|
+
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS 1
|
|
93
|
+
#endif
|
|
94
|
+
#endif // EIGEN_USE_AVX512_TRSM_L_KERNELS
|
|
95
|
+
|
|
96
|
+
#else // EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS == 0
|
|
97
|
+
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 0
|
|
98
|
+
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS 0
|
|
99
|
+
#endif // EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS
|
|
100
|
+
|
|
101
|
+
template <typename Scalar>
|
|
102
|
+
int64_t avx512_trsm_cutoff(int64_t L2Size, int64_t N, double L2Cap) {
|
|
103
|
+
const int64_t U3 = 3 * packet_traits<Scalar>::size;
|
|
104
|
+
const int64_t MaxNb = 5 * U3;
|
|
105
|
+
int64_t Nb = std::min(MaxNb, N);
|
|
106
|
+
double cutoff_d =
|
|
107
|
+
(((L2Size * L2Cap) / (sizeof(Scalar))) - (EIGEN_AVX_MAX_NUM_ROW)*Nb) / ((EIGEN_AVX_MAX_NUM_ROW) + Nb);
|
|
108
|
+
int64_t cutoff_l = static_cast<int64_t>(cutoff_d);
|
|
109
|
+
return (cutoff_l / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW;
|
|
110
|
+
}
|
|
111
|
+
#else // !(EIGEN_USE_AVX512_TRSM_KERNELS) || !(EIGEN_COMP_CLANG != 0)
|
|
112
|
+
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS 0
|
|
113
|
+
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_R_CUTOFFS 0
|
|
114
|
+
#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_L_CUTOFFS 0
|
|
115
|
+
#endif
|
|
116
|
+
|
|
117
|
+
/**
|
|
118
|
+
* Used by gemmKernel for the case A/B row-major and C col-major.
|
|
119
|
+
*/
|
|
120
|
+
// Transposes the accumulator registers in `zmm` and stores them into the
// column-major output C_arr. remM/remN select the partial-row/partial-column
// store paths; the runtime counts are passed in remM_/remN_.
template <typename Scalar, typename vec, int64_t unrollM, int64_t unrollN, bool remM, bool remN>
EIGEN_ALWAYS_INLINE void transStoreC(PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> &zmm, Scalar *C_arr,
                                     int64_t LDC, int64_t remM_ = 0, int64_t remN_ = 0) {
  // The remainder counts are only read on the remM/remN paths below.
  EIGEN_UNUSED_VARIABLE(remN_);
  EIGEN_UNUSED_VARIABLE(remM_);
  using urolls = unrolls::trans<Scalar>;

  // Panel widths in scalars: 1, 2 or 3 packets.
  constexpr int64_t U3 = urolls::PacketSize * 3;
  constexpr int64_t U2 = urolls::PacketSize * 2;
  constexpr int64_t U1 = urolls::PacketSize * 1;

  static_assert(unrollN == U1 || unrollN == U2 || unrollN == U3, "unrollN should be a multiple of PacketSize");
  static_assert(unrollM == EIGEN_AVX_MAX_NUM_ROW, "unrollM should be equal to EIGEN_AVX_MAX_NUM_ROW");

  // In-register transpose, one call per packet-wide slice that is present.
  urolls::template transpose<unrollN, 0>(zmm);
  EIGEN_IF_CONSTEXPR(unrollN > U2) urolls::template transpose<unrollN, 2>(zmm);
  EIGEN_IF_CONSTEXPR(unrollN > U1) urolls::template transpose<unrollN, 1>(zmm);

  static_assert((remN && unrollN == U1) || !remN, "When handling N remainder set unrollN=U1");
  EIGEN_IF_CONSTEXPR(!remN) {
    // Full column width: one storeC per packet-wide slice, each slice advancing
    // the destination by U1 columns (U1 * LDC scalars in column-major C).
    urolls::template storeC<std::min(unrollN, U1), unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
    EIGEN_IF_CONSTEXPR(unrollN > U1) {
      constexpr int64_t unrollN_ = std::min(unrollN - U1, U1);
      urolls::template storeC<unrollN_, unrollN, 1, remM>(C_arr + U1 * LDC, LDC, zmm, remM_);
    }
    EIGEN_IF_CONSTEXPR(unrollN > U2) {
      constexpr int64_t unrollN_ = std::min(unrollN - U2, U1);
      urolls::template storeC<unrollN_, unrollN, 2, remM>(C_arr + U2 * LDC, LDC, zmm, remM_);
    }
  }
  else {
    // Column-remainder path: map the runtime leftover count to a compile-time
    // storeC instantiation. The float branch dispatches counts 15..1, the
    // non-float (double) branch 7..1.
    EIGEN_IF_CONSTEXPR((std::is_same<Scalar, float>::value)) {
      // Note: without "if constexpr" this section of code will also be
      // parsed by the compiler so each of the storeC will still be instantiated.
      // We use enable_if in aux_storeC to set it to an empty function for
      // these cases.
      if (remN_ == 15)
        urolls::template storeC<15, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 14)
        urolls::template storeC<14, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 13)
        urolls::template storeC<13, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 12)
        urolls::template storeC<12, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 11)
        urolls::template storeC<11, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 10)
        urolls::template storeC<10, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 9)
        urolls::template storeC<9, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 8)
        urolls::template storeC<8, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 7)
        urolls::template storeC<7, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 6)
        urolls::template storeC<6, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 5)
        urolls::template storeC<5, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 4)
        urolls::template storeC<4, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 3)
        urolls::template storeC<3, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 2)
        urolls::template storeC<2, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 1)
        urolls::template storeC<1, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
    }
    else {
      if (remN_ == 7)
        urolls::template storeC<7, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 6)
        urolls::template storeC<6, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 5)
        urolls::template storeC<5, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 4)
        urolls::template storeC<4, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 3)
        urolls::template storeC<3, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 2)
        urolls::template storeC<2, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
      else if (remN_ == 1)
        urolls::template storeC<1, unrollN, 0, remM>(C_arr, LDC, zmm, remM_);
    }
  }
}
|
|
205
|
+
|
|
206
|
+
/**
 * GEMM-like operation for trsm panel updates.
 * Computes: C -= A*B (or C += A*B when isAdd).
 * K must be a multiple of 4.
 *
 * Unrolls used are {1,2,4,8}x{U1,U2,U3};
 * For good performance we want K to be large with M/N relatively small, but also large enough
 * to use the {8,U3} unroll block.
 *
 * isARowMajor: is A_arr row-major?
 * isCRowMajor: is C_arr row-major? (B_arr is assumed to be row-major).
 * isAdd: C += A*B or C -= A*B (used by trsm)
 * handleKRem: Handle arbitrary K? This is not needed for trsm.
 */
|
|
220
|
+
template <typename Scalar, bool isARowMajor, bool isCRowMajor, bool isAdd, bool handleKRem>
|
|
221
|
+
void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr, int64_t M, int64_t N, int64_t K, int64_t LDA, int64_t LDB,
|
|
222
|
+
int64_t LDC) {
|
|
223
|
+
using urolls = unrolls::gemm<Scalar, isAdd>;
|
|
224
|
+
constexpr int64_t U3 = urolls::PacketSize * 3;
|
|
225
|
+
constexpr int64_t U2 = urolls::PacketSize * 2;
|
|
226
|
+
constexpr int64_t U1 = urolls::PacketSize * 1;
|
|
227
|
+
using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
|
|
228
|
+
int64_t N_ = (N / U3) * U3;
|
|
229
|
+
int64_t M_ = (M / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW;
|
|
230
|
+
int64_t K_ = (K / EIGEN_AVX_MAX_K_UNROL) * EIGEN_AVX_MAX_K_UNROL;
|
|
231
|
+
int64_t j = 0;
|
|
232
|
+
for (; j < N_; j += U3) {
|
|
233
|
+
constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS * 3;
|
|
234
|
+
int64_t i = 0;
|
|
235
|
+
for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
|
|
236
|
+
Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)], *B_t = &B_arr[0 * LDB + j];
|
|
237
|
+
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
|
|
238
|
+
urolls::template setzero<3, EIGEN_AVX_MAX_NUM_ROW>(zmm);
|
|
239
|
+
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
|
|
240
|
+
urolls::template microKernel<isARowMajor, 3, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
|
|
241
|
+
EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
|
|
242
|
+
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
|
|
243
|
+
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
|
|
244
|
+
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
|
|
245
|
+
}
|
|
246
|
+
EIGEN_IF_CONSTEXPR(handleKRem) {
|
|
247
|
+
for (int64_t k = K_; k < K; k++) {
|
|
248
|
+
urolls::template microKernel<isARowMajor, 3, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_B_LOAD_SETS * 3,
|
|
249
|
+
EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
|
|
250
|
+
B_t += LDB;
|
|
251
|
+
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
|
|
252
|
+
else A_t += LDA;
|
|
253
|
+
}
|
|
254
|
+
}
|
|
255
|
+
EIGEN_IF_CONSTEXPR(isCRowMajor) {
|
|
256
|
+
urolls::template updateC<3, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
|
|
257
|
+
urolls::template storeC<3, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
|
|
258
|
+
}
|
|
259
|
+
else {
|
|
260
|
+
transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, false, false>(zmm, &C_arr[i + j * LDC], LDC);
|
|
261
|
+
}
|
|
262
|
+
}
|
|
263
|
+
if (M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise
|
|
264
|
+
Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
|
|
265
|
+
Scalar *B_t = &B_arr[0 * LDB + j];
|
|
266
|
+
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
|
|
267
|
+
urolls::template setzero<3, 4>(zmm);
|
|
268
|
+
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
|
|
269
|
+
urolls::template microKernel<isARowMajor, 3, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_B_LOAD_SETS * 3,
|
|
270
|
+
EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
|
|
271
|
+
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
|
|
272
|
+
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
|
|
273
|
+
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
|
|
274
|
+
}
|
|
275
|
+
EIGEN_IF_CONSTEXPR(handleKRem) {
|
|
276
|
+
for (int64_t k = K_; k < K; k++) {
|
|
277
|
+
urolls::template microKernel<isARowMajor, 3, 4, 1, EIGEN_AVX_B_LOAD_SETS * 3, EIGEN_AVX_MAX_A_BCAST>(
|
|
278
|
+
B_t, A_t, LDB, LDA, zmm);
|
|
279
|
+
B_t += LDB;
|
|
280
|
+
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
|
|
281
|
+
else A_t += LDA;
|
|
282
|
+
}
|
|
283
|
+
}
|
|
284
|
+
EIGEN_IF_CONSTEXPR(isCRowMajor) {
|
|
285
|
+
urolls::template updateC<3, 4>(&C_arr[i * LDC + j], LDC, zmm);
|
|
286
|
+
urolls::template storeC<3, 4>(&C_arr[i * LDC + j], LDC, zmm);
|
|
287
|
+
}
|
|
288
|
+
else {
|
|
289
|
+
transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, true, false>(zmm, &C_arr[i + j * LDC], LDC, 4);
|
|
290
|
+
}
|
|
291
|
+
i += 4;
|
|
292
|
+
}
|
|
293
|
+
if (M - i >= 2) {
|
|
294
|
+
Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
|
|
295
|
+
Scalar *B_t = &B_arr[0 * LDB + j];
|
|
296
|
+
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
|
|
297
|
+
urolls::template setzero<3, 2>(zmm);
|
|
298
|
+
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
|
|
299
|
+
urolls::template microKernel<isARowMajor, 3, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_B_LOAD_SETS * 3,
|
|
300
|
+
EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
|
|
301
|
+
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
|
|
302
|
+
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
|
|
303
|
+
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
|
|
304
|
+
}
|
|
305
|
+
EIGEN_IF_CONSTEXPR(handleKRem) {
|
|
306
|
+
for (int64_t k = K_; k < K; k++) {
|
|
307
|
+
urolls::template microKernel<isARowMajor, 3, 2, 1, EIGEN_AVX_B_LOAD_SETS * 3, EIGEN_AVX_MAX_A_BCAST>(
|
|
308
|
+
B_t, A_t, LDB, LDA, zmm);
|
|
309
|
+
B_t += LDB;
|
|
310
|
+
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
|
|
311
|
+
else A_t += LDA;
|
|
312
|
+
}
|
|
313
|
+
}
|
|
314
|
+
EIGEN_IF_CONSTEXPR(isCRowMajor) {
|
|
315
|
+
urolls::template updateC<3, 2>(&C_arr[i * LDC + j], LDC, zmm);
|
|
316
|
+
urolls::template storeC<3, 2>(&C_arr[i * LDC + j], LDC, zmm);
|
|
317
|
+
}
|
|
318
|
+
else {
|
|
319
|
+
transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, true, false>(zmm, &C_arr[i + j * LDC], LDC, 2);
|
|
320
|
+
}
|
|
321
|
+
i += 2;
|
|
322
|
+
}
|
|
323
|
+
if (M - i > 0) {
|
|
324
|
+
Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
|
|
325
|
+
Scalar *B_t = &B_arr[0 * LDB + j];
|
|
326
|
+
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
|
|
327
|
+
urolls::template setzero<3, 1>(zmm);
|
|
328
|
+
{
|
|
329
|
+
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
|
|
330
|
+
urolls::template microKernel<isARowMajor, 3, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_B_LOAD_SETS * 3, 1>(
|
|
331
|
+
B_t, A_t, LDB, LDA, zmm);
|
|
332
|
+
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
|
|
333
|
+
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
|
|
334
|
+
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
|
|
335
|
+
}
|
|
336
|
+
EIGEN_IF_CONSTEXPR(handleKRem) {
|
|
337
|
+
for (int64_t k = K_; k < K; k++) {
|
|
338
|
+
urolls::template microKernel<isARowMajor, 3, 1, 1, EIGEN_AVX_B_LOAD_SETS * 3, 1>(B_t, A_t, LDB, LDA, zmm);
|
|
339
|
+
B_t += LDB;
|
|
340
|
+
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
|
|
341
|
+
else A_t += LDA;
|
|
342
|
+
}
|
|
343
|
+
}
|
|
344
|
+
EIGEN_IF_CONSTEXPR(isCRowMajor) {
|
|
345
|
+
urolls::template updateC<3, 1>(&C_arr[i * LDC + j], LDC, zmm);
|
|
346
|
+
urolls::template storeC<3, 1>(&C_arr[i * LDC + j], LDC, zmm);
|
|
347
|
+
}
|
|
348
|
+
else {
|
|
349
|
+
transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U3, true, false>(zmm, &C_arr[i + j * LDC], LDC, 1);
|
|
350
|
+
}
|
|
351
|
+
}
|
|
352
|
+
}
|
|
353
|
+
}
|
|
354
|
+
if (N - j >= U2) {
|
|
355
|
+
constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS * 2;
|
|
356
|
+
int64_t i = 0;
|
|
357
|
+
for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
|
|
358
|
+
Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)], *B_t = &B_arr[0 * LDB + j];
|
|
359
|
+
EIGEN_IF_CONSTEXPR(isCRowMajor) B_t = &B_arr[0 * LDB + j];
|
|
360
|
+
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
|
|
361
|
+
urolls::template setzero<2, EIGEN_AVX_MAX_NUM_ROW>(zmm);
|
|
362
|
+
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
|
|
363
|
+
urolls::template microKernel<isARowMajor, 2, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
|
|
364
|
+
EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
|
|
365
|
+
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
|
|
366
|
+
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
|
|
367
|
+
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
|
|
368
|
+
}
|
|
369
|
+
EIGEN_IF_CONSTEXPR(handleKRem) {
|
|
370
|
+
for (int64_t k = K_; k < K; k++) {
|
|
371
|
+
urolls::template microKernel<isARowMajor, 2, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_MAX_B_LOAD,
|
|
372
|
+
EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
|
|
373
|
+
B_t += LDB;
|
|
374
|
+
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
|
|
375
|
+
else A_t += LDA;
|
|
376
|
+
}
|
|
377
|
+
}
|
|
378
|
+
EIGEN_IF_CONSTEXPR(isCRowMajor) {
|
|
379
|
+
urolls::template updateC<2, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
|
|
380
|
+
urolls::template storeC<2, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
|
|
381
|
+
}
|
|
382
|
+
else {
|
|
383
|
+
transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, false, false>(zmm, &C_arr[i + j * LDC], LDC);
|
|
384
|
+
}
|
|
385
|
+
}
|
|
386
|
+
if (M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise
|
|
387
|
+
Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
|
|
388
|
+
Scalar *B_t = &B_arr[0 * LDB + j];
|
|
389
|
+
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
|
|
390
|
+
urolls::template setzero<2, 4>(zmm);
|
|
391
|
+
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
|
|
392
|
+
urolls::template microKernel<isARowMajor, 2, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
|
|
393
|
+
EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
|
|
394
|
+
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
|
|
395
|
+
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
|
|
396
|
+
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
|
|
397
|
+
}
|
|
398
|
+
EIGEN_IF_CONSTEXPR(handleKRem) {
|
|
399
|
+
for (int64_t k = K_; k < K; k++) {
|
|
400
|
+
urolls::template microKernel<isARowMajor, 2, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
|
|
401
|
+
LDA, zmm);
|
|
402
|
+
B_t += LDB;
|
|
403
|
+
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
|
|
404
|
+
else A_t += LDA;
|
|
405
|
+
}
|
|
406
|
+
}
|
|
407
|
+
EIGEN_IF_CONSTEXPR(isCRowMajor) {
|
|
408
|
+
urolls::template updateC<2, 4>(&C_arr[i * LDC + j], LDC, zmm);
|
|
409
|
+
urolls::template storeC<2, 4>(&C_arr[i * LDC + j], LDC, zmm);
|
|
410
|
+
}
|
|
411
|
+
else {
|
|
412
|
+
transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, true, false>(zmm, &C_arr[i + j * LDC], LDC, 4);
|
|
413
|
+
}
|
|
414
|
+
i += 4;
|
|
415
|
+
}
|
|
416
|
+
if (M - i >= 2) {
|
|
417
|
+
Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
|
|
418
|
+
Scalar *B_t = &B_arr[0 * LDB + j];
|
|
419
|
+
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
|
|
420
|
+
urolls::template setzero<2, 2>(zmm);
|
|
421
|
+
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
|
|
422
|
+
urolls::template microKernel<isARowMajor, 2, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
|
|
423
|
+
EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
|
|
424
|
+
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
|
|
425
|
+
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
|
|
426
|
+
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
|
|
427
|
+
}
|
|
428
|
+
EIGEN_IF_CONSTEXPR(handleKRem) {
|
|
429
|
+
for (int64_t k = K_; k < K; k++) {
|
|
430
|
+
urolls::template microKernel<isARowMajor, 2, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
|
|
431
|
+
LDA, zmm);
|
|
432
|
+
B_t += LDB;
|
|
433
|
+
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
|
|
434
|
+
else A_t += LDA;
|
|
435
|
+
}
|
|
436
|
+
}
|
|
437
|
+
EIGEN_IF_CONSTEXPR(isCRowMajor) {
|
|
438
|
+
urolls::template updateC<2, 2>(&C_arr[i * LDC + j], LDC, zmm);
|
|
439
|
+
urolls::template storeC<2, 2>(&C_arr[i * LDC + j], LDC, zmm);
|
|
440
|
+
}
|
|
441
|
+
else {
|
|
442
|
+
transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, true, false>(zmm, &C_arr[i + j * LDC], LDC, 2);
|
|
443
|
+
}
|
|
444
|
+
i += 2;
|
|
445
|
+
}
|
|
446
|
+
if (M - i > 0) {
|
|
447
|
+
Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
|
|
448
|
+
Scalar *B_t = &B_arr[0 * LDB + j];
|
|
449
|
+
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
|
|
450
|
+
urolls::template setzero<2, 1>(zmm);
|
|
451
|
+
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
|
|
452
|
+
urolls::template microKernel<isARowMajor, 2, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1>(B_t, A_t, LDB,
|
|
453
|
+
LDA, zmm);
|
|
454
|
+
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
|
|
455
|
+
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
|
|
456
|
+
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
|
|
457
|
+
}
|
|
458
|
+
EIGEN_IF_CONSTEXPR(handleKRem) {
|
|
459
|
+
for (int64_t k = K_; k < K; k++) {
|
|
460
|
+
urolls::template microKernel<isARowMajor, 2, 1, 1, EIGEN_AVX_MAX_B_LOAD, 1>(B_t, A_t, LDB, LDA, zmm);
|
|
461
|
+
B_t += LDB;
|
|
462
|
+
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
|
|
463
|
+
else A_t += LDA;
|
|
464
|
+
}
|
|
465
|
+
}
|
|
466
|
+
EIGEN_IF_CONSTEXPR(isCRowMajor) {
|
|
467
|
+
urolls::template updateC<2, 1>(&C_arr[i * LDC + j], LDC, zmm);
|
|
468
|
+
urolls::template storeC<2, 1>(&C_arr[i * LDC + j], LDC, zmm);
|
|
469
|
+
}
|
|
470
|
+
else {
|
|
471
|
+
transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U2, true, false>(zmm, &C_arr[i + j * LDC], LDC, 1);
|
|
472
|
+
}
|
|
473
|
+
}
|
|
474
|
+
j += U2;
|
|
475
|
+
}
|
|
476
|
+
if (N - j >= U1) {
|
|
477
|
+
constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS * 1;
|
|
478
|
+
int64_t i = 0;
|
|
479
|
+
for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
|
|
480
|
+
Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)], *B_t = &B_arr[0 * LDB + j];
|
|
481
|
+
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
|
|
482
|
+
urolls::template setzero<1, EIGEN_AVX_MAX_NUM_ROW>(zmm);
|
|
483
|
+
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
|
|
484
|
+
urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
|
|
485
|
+
EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
|
|
486
|
+
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
|
|
487
|
+
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
|
|
488
|
+
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
|
|
489
|
+
}
|
|
490
|
+
EIGEN_IF_CONSTEXPR(handleKRem) {
|
|
491
|
+
for (int64_t k = K_; k < K; k++) {
|
|
492
|
+
urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_B_LOAD_SETS * 1,
|
|
493
|
+
EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
|
|
494
|
+
B_t += LDB;
|
|
495
|
+
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
|
|
496
|
+
else A_t += LDA;
|
|
497
|
+
}
|
|
498
|
+
}
|
|
499
|
+
EIGEN_IF_CONSTEXPR(isCRowMajor) {
|
|
500
|
+
urolls::template updateC<1, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
|
|
501
|
+
urolls::template storeC<1, EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i * LDC + j], LDC, zmm);
|
|
502
|
+
}
|
|
503
|
+
else {
|
|
504
|
+
transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, false, false>(zmm, &C_arr[i + j * LDC], LDC);
|
|
505
|
+
}
|
|
506
|
+
}
|
|
507
|
+
if (M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise
|
|
508
|
+
Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
|
|
509
|
+
Scalar *B_t = &B_arr[0 * LDB + j];
|
|
510
|
+
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
|
|
511
|
+
urolls::template setzero<1, 4>(zmm);
|
|
512
|
+
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
|
|
513
|
+
urolls::template microKernel<isARowMajor, 1, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
|
|
514
|
+
EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
|
|
515
|
+
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
|
|
516
|
+
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
|
|
517
|
+
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
|
|
518
|
+
}
|
|
519
|
+
EIGEN_IF_CONSTEXPR(handleKRem) {
|
|
520
|
+
for (int64_t k = K_; k < K; k++) {
|
|
521
|
+
urolls::template microKernel<isARowMajor, 1, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
|
|
522
|
+
LDA, zmm);
|
|
523
|
+
B_t += LDB;
|
|
524
|
+
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
|
|
525
|
+
else A_t += LDA;
|
|
526
|
+
}
|
|
527
|
+
}
|
|
528
|
+
EIGEN_IF_CONSTEXPR(isCRowMajor) {
|
|
529
|
+
urolls::template updateC<1, 4>(&C_arr[i * LDC + j], LDC, zmm);
|
|
530
|
+
urolls::template storeC<1, 4>(&C_arr[i * LDC + j], LDC, zmm);
|
|
531
|
+
}
|
|
532
|
+
else {
|
|
533
|
+
transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, false>(zmm, &C_arr[i + j * LDC], LDC, 4);
|
|
534
|
+
}
|
|
535
|
+
i += 4;
|
|
536
|
+
}
|
|
537
|
+
if (M - i >= 2) {
|
|
538
|
+
Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
|
|
539
|
+
Scalar *B_t = &B_arr[0 * LDB + j];
|
|
540
|
+
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
|
|
541
|
+
urolls::template setzero<1, 2>(zmm);
|
|
542
|
+
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
|
|
543
|
+
urolls::template microKernel<isARowMajor, 1, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
|
|
544
|
+
EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB, LDA, zmm);
|
|
545
|
+
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
|
|
546
|
+
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
|
|
547
|
+
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
|
|
548
|
+
}
|
|
549
|
+
EIGEN_IF_CONSTEXPR(handleKRem) {
|
|
550
|
+
for (int64_t k = K_; k < K; k++) {
|
|
551
|
+
urolls::template microKernel<isARowMajor, 1, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST>(B_t, A_t, LDB,
|
|
552
|
+
LDA, zmm);
|
|
553
|
+
B_t += LDB;
|
|
554
|
+
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
|
|
555
|
+
else A_t += LDA;
|
|
556
|
+
}
|
|
557
|
+
}
|
|
558
|
+
EIGEN_IF_CONSTEXPR(isCRowMajor) {
|
|
559
|
+
urolls::template updateC<1, 2>(&C_arr[i * LDC + j], LDC, zmm);
|
|
560
|
+
urolls::template storeC<1, 2>(&C_arr[i * LDC + j], LDC, zmm);
|
|
561
|
+
}
|
|
562
|
+
else {
|
|
563
|
+
transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, false>(zmm, &C_arr[i + j * LDC], LDC, 2);
|
|
564
|
+
}
|
|
565
|
+
i += 2;
|
|
566
|
+
}
|
|
567
|
+
if (M - i > 0) {
|
|
568
|
+
Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
|
|
569
|
+
Scalar *B_t = &B_arr[0 * LDB + j];
|
|
570
|
+
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
|
|
571
|
+
urolls::template setzero<1, 1>(zmm);
|
|
572
|
+
{
|
|
573
|
+
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
|
|
574
|
+
urolls::template microKernel<isARowMajor, 1, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1>(B_t, A_t, LDB,
|
|
575
|
+
LDA, zmm);
|
|
576
|
+
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
|
|
577
|
+
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
|
|
578
|
+
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
|
|
579
|
+
}
|
|
580
|
+
EIGEN_IF_CONSTEXPR(handleKRem) {
|
|
581
|
+
for (int64_t k = K_; k < K; k++) {
|
|
582
|
+
urolls::template microKernel<isARowMajor, 1, 1, 1, EIGEN_AVX_B_LOAD_SETS * 1, 1>(B_t, A_t, LDB, LDA, zmm);
|
|
583
|
+
B_t += LDB;
|
|
584
|
+
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
|
|
585
|
+
else A_t += LDA;
|
|
586
|
+
}
|
|
587
|
+
}
|
|
588
|
+
EIGEN_IF_CONSTEXPR(isCRowMajor) {
|
|
589
|
+
urolls::template updateC<1, 1>(&C_arr[i * LDC + j], LDC, zmm);
|
|
590
|
+
urolls::template storeC<1, 1>(&C_arr[i * LDC + j], LDC, zmm);
|
|
591
|
+
}
|
|
592
|
+
else {
|
|
593
|
+
transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, false>(zmm, &C_arr[i + j * LDC], LDC, 1);
|
|
594
|
+
}
|
|
595
|
+
}
|
|
596
|
+
}
|
|
597
|
+
j += U1;
|
|
598
|
+
}
|
|
599
|
+
if (N - j > 0) {
|
|
600
|
+
constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS * 1;
|
|
601
|
+
int64_t i = 0;
|
|
602
|
+
for (; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
|
|
603
|
+
Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
|
|
604
|
+
Scalar *B_t = &B_arr[0 * LDB + j];
|
|
605
|
+
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
|
|
606
|
+
urolls::template setzero<1, EIGEN_AVX_MAX_NUM_ROW>(zmm);
|
|
607
|
+
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
|
|
608
|
+
urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
|
|
609
|
+
EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
|
|
610
|
+
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
|
|
611
|
+
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
|
|
612
|
+
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
|
|
613
|
+
}
|
|
614
|
+
EIGEN_IF_CONSTEXPR(handleKRem) {
|
|
615
|
+
for (int64_t k = K_; k < K; k++) {
|
|
616
|
+
urolls::template microKernel<isARowMajor, 1, EIGEN_AVX_MAX_NUM_ROW, 1, EIGEN_AVX_MAX_B_LOAD,
|
|
617
|
+
EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
|
|
618
|
+
B_t += LDB;
|
|
619
|
+
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
|
|
620
|
+
else A_t += LDA;
|
|
621
|
+
}
|
|
622
|
+
}
|
|
623
|
+
EIGEN_IF_CONSTEXPR(isCRowMajor) {
|
|
624
|
+
urolls::template updateC<1, EIGEN_AVX_MAX_NUM_ROW, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
|
|
625
|
+
urolls::template storeC<1, EIGEN_AVX_MAX_NUM_ROW, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
|
|
626
|
+
}
|
|
627
|
+
else {
|
|
628
|
+
transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, false, true>(zmm, &C_arr[i + j * LDC], LDC, 0, N - j);
|
|
629
|
+
}
|
|
630
|
+
}
|
|
631
|
+
if (M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise
|
|
632
|
+
Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
|
|
633
|
+
Scalar *B_t = &B_arr[0 * LDB + j];
|
|
634
|
+
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
|
|
635
|
+
urolls::template setzero<1, 4>(zmm);
|
|
636
|
+
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
|
|
637
|
+
urolls::template microKernel<isARowMajor, 1, 4, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
|
|
638
|
+
EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
|
|
639
|
+
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
|
|
640
|
+
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
|
|
641
|
+
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
|
|
642
|
+
}
|
|
643
|
+
EIGEN_IF_CONSTEXPR(handleKRem) {
|
|
644
|
+
for (int64_t k = K_; k < K; k++) {
|
|
645
|
+
urolls::template microKernel<isARowMajor, 1, 4, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST, true>(
|
|
646
|
+
B_t, A_t, LDB, LDA, zmm, N - j);
|
|
647
|
+
B_t += LDB;
|
|
648
|
+
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
|
|
649
|
+
else A_t += LDA;
|
|
650
|
+
}
|
|
651
|
+
}
|
|
652
|
+
EIGEN_IF_CONSTEXPR(isCRowMajor) {
|
|
653
|
+
urolls::template updateC<1, 4, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
|
|
654
|
+
urolls::template storeC<1, 4, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
|
|
655
|
+
}
|
|
656
|
+
else {
|
|
657
|
+
transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, true>(zmm, &C_arr[i + j * LDC], LDC, 4, N - j);
|
|
658
|
+
}
|
|
659
|
+
i += 4;
|
|
660
|
+
}
|
|
661
|
+
if (M - i >= 2) {
|
|
662
|
+
Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
|
|
663
|
+
Scalar *B_t = &B_arr[0 * LDB + j];
|
|
664
|
+
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
|
|
665
|
+
urolls::template setzero<1, 2>(zmm);
|
|
666
|
+
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
|
|
667
|
+
urolls::template microKernel<isARowMajor, 1, 2, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD,
|
|
668
|
+
EIGEN_AVX_MAX_A_BCAST, true>(B_t, A_t, LDB, LDA, zmm, N - j);
|
|
669
|
+
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
|
|
670
|
+
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
|
|
671
|
+
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
|
|
672
|
+
}
|
|
673
|
+
EIGEN_IF_CONSTEXPR(handleKRem) {
|
|
674
|
+
for (int64_t k = K_; k < K; k++) {
|
|
675
|
+
urolls::template microKernel<isARowMajor, 1, 2, 1, EIGEN_AVX_MAX_B_LOAD, EIGEN_AVX_MAX_A_BCAST, true>(
|
|
676
|
+
B_t, A_t, LDB, LDA, zmm, N - j);
|
|
677
|
+
B_t += LDB;
|
|
678
|
+
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
|
|
679
|
+
else A_t += LDA;
|
|
680
|
+
}
|
|
681
|
+
}
|
|
682
|
+
EIGEN_IF_CONSTEXPR(isCRowMajor) {
|
|
683
|
+
urolls::template updateC<1, 2, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
|
|
684
|
+
urolls::template storeC<1, 2, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
|
|
685
|
+
}
|
|
686
|
+
else {
|
|
687
|
+
transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, true>(zmm, &C_arr[i + j * LDC], LDC, 2, N - j);
|
|
688
|
+
}
|
|
689
|
+
i += 2;
|
|
690
|
+
}
|
|
691
|
+
if (M - i > 0) {
|
|
692
|
+
Scalar *A_t = &A_arr[idA<isARowMajor>(i, 0, LDA)];
|
|
693
|
+
Scalar *B_t = &B_arr[0 * LDB + j];
|
|
694
|
+
PacketBlock<vec, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> zmm;
|
|
695
|
+
urolls::template setzero<1, 1>(zmm);
|
|
696
|
+
for (int64_t k = 0; k < K_; k += EIGEN_AVX_MAX_K_UNROL) {
|
|
697
|
+
urolls::template microKernel<isARowMajor, 1, 1, EIGEN_AVX_MAX_K_UNROL, EIGEN_AVX_MAX_B_LOAD, 1, true>(
|
|
698
|
+
B_t, A_t, LDB, LDA, zmm, N - j);
|
|
699
|
+
B_t += EIGEN_AVX_MAX_K_UNROL * LDB;
|
|
700
|
+
EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL;
|
|
701
|
+
else A_t += EIGEN_AVX_MAX_K_UNROL * LDA;
|
|
702
|
+
}
|
|
703
|
+
EIGEN_IF_CONSTEXPR(handleKRem) {
|
|
704
|
+
for (int64_t k = K_; k < K; k++) {
|
|
705
|
+
urolls::template microKernel<isARowMajor, 1, 1, 1, EIGEN_AVX_MAX_B_LOAD, 1, true>(B_t, A_t, LDB, LDA, zmm,
|
|
706
|
+
N - j);
|
|
707
|
+
B_t += LDB;
|
|
708
|
+
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++;
|
|
709
|
+
else A_t += LDA;
|
|
710
|
+
}
|
|
711
|
+
}
|
|
712
|
+
EIGEN_IF_CONSTEXPR(isCRowMajor) {
|
|
713
|
+
urolls::template updateC<1, 1, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
|
|
714
|
+
urolls::template storeC<1, 1, true>(&C_arr[i * LDC + j], LDC, zmm, N - j);
|
|
715
|
+
}
|
|
716
|
+
else {
|
|
717
|
+
transStoreC<Scalar, vec, EIGEN_AVX_MAX_NUM_ROW, U1, true, true>(zmm, &C_arr[i + j * LDC], LDC, 1, N - j);
|
|
718
|
+
}
|
|
719
|
+
}
|
|
720
|
+
}
|
|
721
|
+
}
|
|
722
|
+
|
|
723
|
+
/**
 * Triangular solve kernel with A on left with K number of rhs. dim(A) = unrollM
 *
 * unrollM: dimension of A matrix (triangular matrix). unrollM should be <= EIGEN_AVX_MAX_NUM_ROW
 * isFWDSolve: is forward solve?
 * isUnitDiag: is the diagonal of A all ones?
 * The B matrix (RHS) is assumed to be row-major
 */
|
|
731
|
+
template <typename Scalar, typename vec, int64_t unrollM, bool isARowMajor, bool isFWDSolve, bool isUnitDiag>
|
|
732
|
+
EIGEN_ALWAYS_INLINE void triSolveKernel(Scalar *A_arr, Scalar *B_arr, int64_t K, int64_t LDA, int64_t LDB) {
|
|
733
|
+
static_assert(unrollM <= EIGEN_AVX_MAX_NUM_ROW, "unrollM should be equal to EIGEN_AVX_MAX_NUM_ROW");
|
|
734
|
+
using urolls = unrolls::trsm<Scalar>;
|
|
735
|
+
constexpr int64_t U3 = urolls::PacketSize * 3;
|
|
736
|
+
constexpr int64_t U2 = urolls::PacketSize * 2;
|
|
737
|
+
constexpr int64_t U1 = urolls::PacketSize * 1;
|
|
738
|
+
|
|
739
|
+
PacketBlock<vec, EIGEN_AVX_MAX_NUM_ACC> RHSInPacket;
|
|
740
|
+
PacketBlock<vec, EIGEN_AVX_MAX_NUM_ROW> AInPacket;
|
|
741
|
+
|
|
742
|
+
int64_t k = 0;
|
|
743
|
+
while (K - k >= U3) {
|
|
744
|
+
urolls::template loadRHS<isFWDSolve, unrollM, 3>(B_arr + k, LDB, RHSInPacket);
|
|
745
|
+
urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 3>(A_arr, LDA, RHSInPacket,
|
|
746
|
+
AInPacket);
|
|
747
|
+
urolls::template storeRHS<isFWDSolve, unrollM, 3>(B_arr + k, LDB, RHSInPacket);
|
|
748
|
+
k += U3;
|
|
749
|
+
}
|
|
750
|
+
if (K - k >= U2) {
|
|
751
|
+
urolls::template loadRHS<isFWDSolve, unrollM, 2>(B_arr + k, LDB, RHSInPacket);
|
|
752
|
+
urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 2>(A_arr, LDA, RHSInPacket,
|
|
753
|
+
AInPacket);
|
|
754
|
+
urolls::template storeRHS<isFWDSolve, unrollM, 2>(B_arr + k, LDB, RHSInPacket);
|
|
755
|
+
k += U2;
|
|
756
|
+
}
|
|
757
|
+
if (K - k >= U1) {
|
|
758
|
+
urolls::template loadRHS<isFWDSolve, unrollM, 1>(B_arr + k, LDB, RHSInPacket);
|
|
759
|
+
urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 1>(A_arr, LDA, RHSInPacket,
|
|
760
|
+
AInPacket);
|
|
761
|
+
urolls::template storeRHS<isFWDSolve, unrollM, 1>(B_arr + k, LDB, RHSInPacket);
|
|
762
|
+
k += U1;
|
|
763
|
+
}
|
|
764
|
+
if (K - k > 0) {
|
|
765
|
+
// Handle remaining number of RHS
|
|
766
|
+
urolls::template loadRHS<isFWDSolve, unrollM, 1, true>(B_arr + k, LDB, RHSInPacket, K - k);
|
|
767
|
+
urolls::template triSolveMicroKernel<isARowMajor, isFWDSolve, isUnitDiag, unrollM, 1>(A_arr, LDA, RHSInPacket,
|
|
768
|
+
AInPacket);
|
|
769
|
+
urolls::template storeRHS<isFWDSolve, unrollM, 1, true>(B_arr + k, LDB, RHSInPacket, K - k);
|
|
770
|
+
}
|
|
771
|
+
}
|
|
772
|
+
|
|
773
|
+
/**
|
|
774
|
+
* Triangular solve routine with A on left and dimension of at most L with K number of rhs. This is essentially
|
|
775
|
+
* a wrapper for triSolveMicrokernel for M = {1,2,3,4,5,6,7,8}.
|
|
776
|
+
*
|
|
777
|
+
* isFWDSolve: is forward solve?
|
|
778
|
+
* isUnitDiag: is the diagonal of A all ones?
|
|
779
|
+
* The B matrix (RHS) is assumed to be row-major
|
|
780
|
+
*/
|
|
781
|
+
template <typename Scalar, bool isARowMajor, bool isFWDSolve, bool isUnitDiag>
|
|
782
|
+
void triSolveKernelLxK(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t K, int64_t LDA, int64_t LDB) {
|
|
783
|
+
// Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted
|
|
784
|
+
// accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller.
|
|
785
|
+
using vec = typename std::conditional<std::is_same<Scalar, float>::value, vecFullFloat, vecFullDouble>::type;
|
|
786
|
+
if (M == 8)
|
|
787
|
+
triSolveKernel<Scalar, vec, 8, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
|
|
788
|
+
else if (M == 7)
|
|
789
|
+
triSolveKernel<Scalar, vec, 7, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
|
|
790
|
+
else if (M == 6)
|
|
791
|
+
triSolveKernel<Scalar, vec, 6, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
|
|
792
|
+
else if (M == 5)
|
|
793
|
+
triSolveKernel<Scalar, vec, 5, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
|
|
794
|
+
else if (M == 4)
|
|
795
|
+
triSolveKernel<Scalar, vec, 4, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
|
|
796
|
+
else if (M == 3)
|
|
797
|
+
triSolveKernel<Scalar, vec, 3, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
|
|
798
|
+
else if (M == 2)
|
|
799
|
+
triSolveKernel<Scalar, vec, 2, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
|
|
800
|
+
else if (M == 1)
|
|
801
|
+
triSolveKernel<Scalar, vec, 1, isARowMajor, isFWDSolve, isUnitDiag>(A_arr, B_arr, K, LDA, LDB);
|
|
802
|
+
return;
|
|
803
|
+
}
|
|
804
|
+
|
|
805
|
+
/**
|
|
806
|
+
* This routine is used to copy B to/from a temporary array (row-major) for cases where B is column-major.
|
|
807
|
+
*
|
|
808
|
+
* toTemp: true => copy to temporary array, false => copy from temporary array
|
|
809
|
+
* remM: true = need to handle remainder values for M (M < EIGEN_AVX_MAX_NUM_ROW)
|
|
810
|
+
*
|
|
811
|
+
*/
|
|
812
|
+
template <typename Scalar, bool toTemp = true, bool remM = false>
|
|
813
|
+
EIGEN_ALWAYS_INLINE void copyBToRowMajor(Scalar *B_arr, int64_t LDB, int64_t K, Scalar *B_temp, int64_t LDB_,
|
|
814
|
+
int64_t remM_ = 0) {
|
|
815
|
+
EIGEN_UNUSED_VARIABLE(remM_);
|
|
816
|
+
using urolls = unrolls::transB<Scalar>;
|
|
817
|
+
using vecHalf = typename std::conditional<std::is_same<Scalar, float>::value, vecHalfFloat, vecFullDouble>::type;
|
|
818
|
+
PacketBlock<vecHalf, EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS> ymm;
|
|
819
|
+
constexpr int64_t U3 = urolls::PacketSize * 3;
|
|
820
|
+
constexpr int64_t U2 = urolls::PacketSize * 2;
|
|
821
|
+
constexpr int64_t U1 = urolls::PacketSize * 1;
|
|
822
|
+
int64_t K_ = K / U3 * U3;
|
|
823
|
+
int64_t k = 0;
|
|
824
|
+
|
|
825
|
+
for (; k < K_; k += U3) {
|
|
826
|
+
urolls::template transB_kernel<U3, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
|
|
827
|
+
B_temp += U3;
|
|
828
|
+
}
|
|
829
|
+
if (K - k >= U2) {
|
|
830
|
+
urolls::template transB_kernel<U2, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
|
|
831
|
+
B_temp += U2;
|
|
832
|
+
k += U2;
|
|
833
|
+
}
|
|
834
|
+
if (K - k >= U1) {
|
|
835
|
+
urolls::template transB_kernel<U1, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
|
|
836
|
+
B_temp += U1;
|
|
837
|
+
k += U1;
|
|
838
|
+
}
|
|
839
|
+
EIGEN_IF_CONSTEXPR(U1 > 8) {
|
|
840
|
+
// Note: without "if constexpr" this section of code will also be
|
|
841
|
+
// parsed by the compiler so there is an additional check in {load/store}BBlock
|
|
842
|
+
// to make sure the counter is not non-negative.
|
|
843
|
+
if (K - k >= 8) {
|
|
844
|
+
urolls::template transB_kernel<8, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
|
|
845
|
+
B_temp += 8;
|
|
846
|
+
k += 8;
|
|
847
|
+
}
|
|
848
|
+
}
|
|
849
|
+
EIGEN_IF_CONSTEXPR(U1 > 4) {
|
|
850
|
+
// Note: without "if constexpr" this section of code will also be
|
|
851
|
+
// parsed by the compiler so there is an additional check in {load/store}BBlock
|
|
852
|
+
// to make sure the counter is not non-negative.
|
|
853
|
+
if (K - k >= 4) {
|
|
854
|
+
urolls::template transB_kernel<4, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
|
|
855
|
+
B_temp += 4;
|
|
856
|
+
k += 4;
|
|
857
|
+
}
|
|
858
|
+
}
|
|
859
|
+
if (K - k >= 2) {
|
|
860
|
+
urolls::template transB_kernel<2, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
|
|
861
|
+
B_temp += 2;
|
|
862
|
+
k += 2;
|
|
863
|
+
}
|
|
864
|
+
if (K - k >= 1) {
|
|
865
|
+
urolls::template transB_kernel<1, toTemp, remM>(B_arr + k * LDB, LDB, B_temp, LDB_, ymm, remM_);
|
|
866
|
+
B_temp += 1;
|
|
867
|
+
k += 1;
|
|
868
|
+
}
|
|
869
|
+
}
|
|
870
|
+
|
|
871
|
+
/**
|
|
872
|
+
* Main triangular solve driver
|
|
873
|
+
*
|
|
874
|
+
* Triangular solve with A on the left.
|
|
875
|
+
* Scalar: Scalar precision, only float/double is supported.
|
|
876
|
+
* isARowMajor: is A row-major?
|
|
877
|
+
* isBRowMajor: is B row-major?
|
|
878
|
+
* isFWDSolve: is this forward solve or backward (true => forward)?
|
|
879
|
+
* isUnitDiag: is diagonal of A unit or nonunit (true => A has unit diagonal)?
|
|
880
|
+
*
|
|
881
|
+
* M: dimension of A
|
|
882
|
+
* numRHS: number of right hand sides (coincides with K dimension for gemm updates)
|
|
883
|
+
*
|
|
884
|
+
* Here are the mapping between the different TRSM cases (col-major) and triSolve:
|
|
885
|
+
*
|
|
886
|
+
* LLN (left , lower, A non-transposed) :: isARowMajor=false, isBRowMajor=false, isFWDSolve=true
|
|
887
|
+
* LUT (left , upper, A transposed) :: isARowMajor=true, isBRowMajor=false, isFWDSolve=true
|
|
888
|
+
* LUN (left , upper, A non-transposed) :: isARowMajor=false, isBRowMajor=false, isFWDSolve=false
|
|
889
|
+
* LLT (left , lower, A transposed) :: isARowMajor=true, isBRowMajor=false, isFWDSolve=false
|
|
890
|
+
* RUN (right, upper, A non-transposed) :: isARowMajor=true, isBRowMajor=true, isFWDSolve=true
|
|
891
|
+
* RLT (right, lower, A transposed) :: isARowMajor=false, isBRowMajor=true, isFWDSolve=true
|
|
892
|
+
* RUT (right, upper, A transposed) :: isARowMajor=false, isBRowMajor=true, isFWDSolve=false
|
|
893
|
+
* RLN (right, lower, A non-transposed) :: isARowMajor=true, isBRowMajor=true, isFWDSolve=false
|
|
894
|
+
*
|
|
895
|
+
* Note: For RXX cases M,numRHS should be swapped.
|
|
896
|
+
*
|
|
897
|
+
*/
|
|
898
|
+
template <typename Scalar, bool isARowMajor = true, bool isBRowMajor = true, bool isFWDSolve = true,
|
|
899
|
+
bool isUnitDiag = false>
|
|
900
|
+
void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t LDA, int64_t LDB) {
|
|
901
|
+
constexpr int64_t psize = packet_traits<Scalar>::size;
|
|
902
|
+
/**
|
|
903
|
+
* The values for kB, numM were determined experimentally.
|
|
904
|
+
* kB: Number of RHS we process at a time.
|
|
905
|
+
* numM: number of rows of B we will store in a temporary array (see below.) This should be a multiple of L.
|
|
906
|
+
*
|
|
907
|
+
* kB was determined by initially setting kB = numRHS and benchmarking triSolve (TRSM-RUN case)
|
|
908
|
+
* performance with M=numRHS.
|
|
909
|
+
* It was observed that performance started to drop around M=numRHS=240. This is likely machine dependent.
|
|
910
|
+
*
|
|
911
|
+
* numM was chosen "arbitrarily". It should be relatively small so B_temp is not too large, but it should be
|
|
912
|
+
* large enough to allow GEMM updates to have larger "K"s (see below.) No benchmarking has been done so far to
|
|
913
|
+
* determine optimal values for numM.
|
|
914
|
+
*/
|
|
915
|
+
constexpr int64_t kB = (3 * psize) * 5; // 5*U3
|
|
916
|
+
constexpr int64_t numM = 8 * EIGEN_AVX_MAX_NUM_ROW;
|
|
917
|
+
|
|
918
|
+
int64_t sizeBTemp = 0;
|
|
919
|
+
Scalar *B_temp = NULL;
|
|
920
|
+
EIGEN_IF_CONSTEXPR(!isBRowMajor) {
|
|
921
|
+
/**
|
|
922
|
+
* If B is col-major, we copy it to a fixed-size temporary array of size at most ~numM*kB and
|
|
923
|
+
* transpose it to row-major. Call the solve routine, and copy+transpose it back to the original array.
|
|
924
|
+
* The updated row-major copy of B is reused in the GEMM updates.
|
|
925
|
+
*/
|
|
926
|
+
sizeBTemp = (((std::min(kB, numRHS) + psize - 1) / psize + 4) * psize) * numM;
|
|
927
|
+
}
|
|
928
|
+
|
|
929
|
+
EIGEN_IF_CONSTEXPR(!isBRowMajor) B_temp = (Scalar *)handmade_aligned_malloc(sizeof(Scalar) * sizeBTemp, 64);
|
|
930
|
+
|
|
931
|
+
for (int64_t k = 0; k < numRHS; k += kB) {
|
|
932
|
+
int64_t bK = numRHS - k > kB ? kB : numRHS - k;
|
|
933
|
+
int64_t M_ = (M / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW, gemmOff = 0;
|
|
934
|
+
|
|
935
|
+
// bK rounded up to next multiple of L=EIGEN_AVX_MAX_NUM_ROW. When B_temp is used, we solve for bkL RHS
|
|
936
|
+
// instead of bK RHS in triSolveKernelLxK.
|
|
937
|
+
int64_t bkL = ((bK + (EIGEN_AVX_MAX_NUM_ROW - 1)) / EIGEN_AVX_MAX_NUM_ROW) * EIGEN_AVX_MAX_NUM_ROW;
|
|
938
|
+
const int64_t numScalarPerCache = 64 / sizeof(Scalar);
|
|
939
|
+
// Leading dimension of B_temp, will be a multiple of the cache line size.
|
|
940
|
+
int64_t LDT = ((bkL + (numScalarPerCache - 1)) / numScalarPerCache) * numScalarPerCache;
|
|
941
|
+
int64_t offsetBTemp = 0;
|
|
942
|
+
for (int64_t i = 0; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) {
|
|
943
|
+
EIGEN_IF_CONSTEXPR(!isBRowMajor) {
|
|
944
|
+
int64_t indA_i = isFWDSolve ? i : M - 1 - i;
|
|
945
|
+
int64_t indB_i = isFWDSolve ? i : M - (i + EIGEN_AVX_MAX_NUM_ROW);
|
|
946
|
+
int64_t offB_1 = isFWDSolve ? offsetBTemp : sizeBTemp - EIGEN_AVX_MAX_NUM_ROW * LDT - offsetBTemp;
|
|
947
|
+
int64_t offB_2 = isFWDSolve ? offsetBTemp : sizeBTemp - LDT - offsetBTemp;
|
|
948
|
+
// Copy values from B to B_temp.
|
|
949
|
+
copyBToRowMajor<Scalar, true, false>(B_arr + indB_i + k * LDB, LDB, bK, B_temp + offB_1, LDT);
|
|
950
|
+
// Triangular solve with a small block of A and long horizontal blocks of B (or B_temp if B col-major)
|
|
951
|
+
triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(
|
|
952
|
+
&A_arr[idA<isARowMajor>(indA_i, indA_i, LDA)], B_temp + offB_2, EIGEN_AVX_MAX_NUM_ROW, bkL, LDA, LDT);
|
|
953
|
+
// Copy values from B_temp back to B. B_temp will be reused in gemm call below.
|
|
954
|
+
copyBToRowMajor<Scalar, false, false>(B_arr + indB_i + k * LDB, LDB, bK, B_temp + offB_1, LDT);
|
|
955
|
+
|
|
956
|
+
offsetBTemp += EIGEN_AVX_MAX_NUM_ROW * LDT;
|
|
957
|
+
}
|
|
958
|
+
else {
|
|
959
|
+
int64_t ind = isFWDSolve ? i : M - 1 - i;
|
|
960
|
+
triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(
|
|
961
|
+
&A_arr[idA<isARowMajor>(ind, ind, LDA)], B_arr + k + ind * LDB, EIGEN_AVX_MAX_NUM_ROW, bK, LDA, LDB);
|
|
962
|
+
}
|
|
963
|
+
if (i + EIGEN_AVX_MAX_NUM_ROW < M_) {
|
|
964
|
+
/**
|
|
965
|
+
* For the GEMM updates, we want "K" (K=i+8 in this case) to be large as soon as possible
|
|
966
|
+
* to reuse the accumulators in GEMM as much as possible. So we only update 8xbK blocks of
|
|
967
|
+
* B as follows:
|
|
968
|
+
*
|
|
969
|
+
* A B
|
|
970
|
+
* __
|
|
971
|
+
* |__|__ |__|
|
|
972
|
+
* |__|__|__ |__|
|
|
973
|
+
* |__|__|__|__ |__|
|
|
974
|
+
* |********|__| |**|
|
|
975
|
+
*/
|
|
976
|
+
EIGEN_IF_CONSTEXPR(isBRowMajor) {
|
|
977
|
+
int64_t indA_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW);
|
|
978
|
+
int64_t indA_j = isFWDSolve ? 0 : M - (i + EIGEN_AVX_MAX_NUM_ROW);
|
|
979
|
+
int64_t indB_i = isFWDSolve ? 0 : M - (i + EIGEN_AVX_MAX_NUM_ROW);
|
|
980
|
+
int64_t indB_i2 = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW);
|
|
981
|
+
gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
|
|
982
|
+
&A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_arr + k + indB_i * LDB, B_arr + k + indB_i2 * LDB,
|
|
983
|
+
EIGEN_AVX_MAX_NUM_ROW, bK, i + EIGEN_AVX_MAX_NUM_ROW, LDA, LDB, LDB);
|
|
984
|
+
}
|
|
985
|
+
else {
|
|
986
|
+
if (offsetBTemp + EIGEN_AVX_MAX_NUM_ROW * LDT > sizeBTemp) {
|
|
987
|
+
/**
|
|
988
|
+
* Similar idea as mentioned above, but here we are limited by the number of updated values of B
|
|
989
|
+
* that can be stored (row-major) in B_temp.
|
|
990
|
+
*
|
|
991
|
+
* If there is not enough space to store the next batch of 8xbK of B in B_temp, we call GEMM
|
|
992
|
+
* update and partially update the remaining old values of B which depends on the new values
|
|
993
|
+
* of B stored in B_temp. These values are then no longer needed and can be overwritten.
|
|
994
|
+
*/
|
|
995
|
+
int64_t indA_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : 0;
|
|
996
|
+
int64_t indA_j = isFWDSolve ? gemmOff : M - (i + EIGEN_AVX_MAX_NUM_ROW);
|
|
997
|
+
int64_t indB_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : 0;
|
|
998
|
+
int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
|
|
999
|
+
gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
|
|
1000
|
+
&A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_temp + offB_1, B_arr + indB_i + (k)*LDB,
|
|
1001
|
+
M - (i + EIGEN_AVX_MAX_NUM_ROW), bK, i + EIGEN_AVX_MAX_NUM_ROW - gemmOff, LDA, LDT, LDB);
|
|
1002
|
+
offsetBTemp = 0;
|
|
1003
|
+
gemmOff = i + EIGEN_AVX_MAX_NUM_ROW;
|
|
1004
|
+
} else {
|
|
1005
|
+
/**
|
|
1006
|
+
* If there is enough space in B_temp, we only update the next 8xbK values of B.
|
|
1007
|
+
*/
|
|
1008
|
+
int64_t indA_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW);
|
|
1009
|
+
int64_t indA_j = isFWDSolve ? gemmOff : M - (i + EIGEN_AVX_MAX_NUM_ROW);
|
|
1010
|
+
int64_t indB_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2 * EIGEN_AVX_MAX_NUM_ROW);
|
|
1011
|
+
int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
|
|
1012
|
+
gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
|
|
1013
|
+
&A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_temp + offB_1, B_arr + indB_i + (k)*LDB,
|
|
1014
|
+
EIGEN_AVX_MAX_NUM_ROW, bK, i + EIGEN_AVX_MAX_NUM_ROW - gemmOff, LDA, LDT, LDB);
|
|
1015
|
+
}
|
|
1016
|
+
}
|
|
1017
|
+
}
|
|
1018
|
+
}
|
|
1019
|
+
// Handle M remainder..
|
|
1020
|
+
int64_t bM = M - M_;
|
|
1021
|
+
if (bM > 0) {
|
|
1022
|
+
if (M_ > 0) {
|
|
1023
|
+
EIGEN_IF_CONSTEXPR(isBRowMajor) {
|
|
1024
|
+
int64_t indA_i = isFWDSolve ? M_ : 0;
|
|
1025
|
+
int64_t indA_j = isFWDSolve ? 0 : bM;
|
|
1026
|
+
int64_t indB_i = isFWDSolve ? 0 : bM;
|
|
1027
|
+
int64_t indB_i2 = isFWDSolve ? M_ : 0;
|
|
1028
|
+
gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(
|
|
1029
|
+
&A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)], B_arr + k + indB_i * LDB, B_arr + k + indB_i2 * LDB, bM,
|
|
1030
|
+
bK, M_, LDA, LDB, LDB);
|
|
1031
|
+
}
|
|
1032
|
+
else {
|
|
1033
|
+
int64_t indA_i = isFWDSolve ? M_ : 0;
|
|
1034
|
+
int64_t indA_j = isFWDSolve ? gemmOff : bM;
|
|
1035
|
+
int64_t indB_i = isFWDSolve ? M_ : 0;
|
|
1036
|
+
int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp;
|
|
1037
|
+
gemmKernel<Scalar, isARowMajor, isBRowMajor, false, false>(&A_arr[idA<isARowMajor>(indA_i, indA_j, LDA)],
|
|
1038
|
+
B_temp + offB_1, B_arr + indB_i + (k)*LDB, bM, bK,
|
|
1039
|
+
M_ - gemmOff, LDA, LDT, LDB);
|
|
1040
|
+
}
|
|
1041
|
+
}
|
|
1042
|
+
EIGEN_IF_CONSTEXPR(!isBRowMajor) {
|
|
1043
|
+
int64_t indA_i = isFWDSolve ? M_ : M - 1 - M_;
|
|
1044
|
+
int64_t indB_i = isFWDSolve ? M_ : 0;
|
|
1045
|
+
int64_t offB_1 = isFWDSolve ? 0 : (bM - 1) * bkL;
|
|
1046
|
+
copyBToRowMajor<Scalar, true, true>(B_arr + indB_i + k * LDB, LDB, bK, B_temp, bkL, bM);
|
|
1047
|
+
triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(&A_arr[idA<isARowMajor>(indA_i, indA_i, LDA)],
|
|
1048
|
+
B_temp + offB_1, bM, bkL, LDA, bkL);
|
|
1049
|
+
copyBToRowMajor<Scalar, false, true>(B_arr + indB_i + k * LDB, LDB, bK, B_temp, bkL, bM);
|
|
1050
|
+
}
|
|
1051
|
+
else {
|
|
1052
|
+
int64_t ind = isFWDSolve ? M_ : M - 1 - M_;
|
|
1053
|
+
triSolveKernelLxK<Scalar, isARowMajor, isFWDSolve, isUnitDiag>(&A_arr[idA<isARowMajor>(ind, ind, LDA)],
|
|
1054
|
+
B_arr + k + ind * LDB, bM, bK, LDA, LDB);
|
|
1055
|
+
}
|
|
1056
|
+
}
|
|
1057
|
+
}
|
|
1058
|
+
|
|
1059
|
+
EIGEN_IF_CONSTEXPR(!isBRowMajor) handmade_aligned_free(B_temp);
|
|
1060
|
+
}
|
|
1061
|
+
|
|
1062
|
+
// Template specializations of trsmKernelL/R for float/double and inner strides of 1.
|
|
1063
|
+
#if (EIGEN_USE_AVX512_TRSM_KERNELS)
|
|
1064
|
+
#if (EIGEN_USE_AVX512_TRSM_R_KERNELS)
|
|
1065
|
+
// Forward declaration of the generic right-side TRSM kernel; the generic
// (Specialized=false) implementation lives with the product kernels elsewhere.
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride,
          bool Specialized>
struct trsmKernelR;

// AVX512 specialization: float, non-conjugated, unit inner stride.
template <typename Index, int Mode, int TriStorageOrder>
struct trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1, true> {
  static void kernel(Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
                     Index otherStride);
};

// AVX512 specialization: double, non-conjugated, unit inner stride.
template <typename Index, int Mode, int TriStorageOrder>
struct trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1, true> {
  static void kernel(Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
                     Index otherStride);
};
|
|
1080
|
+
|
|
1081
|
+
template <typename Index, int Mode, int TriStorageOrder>
|
|
1082
|
+
EIGEN_DONT_INLINE void trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
|
|
1083
|
+
Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
|
|
1084
|
+
Index otherStride) {
|
|
1085
|
+
EIGEN_UNUSED_VARIABLE(otherIncr);
|
|
1086
|
+
#ifdef EIGEN_RUNTIME_NO_MALLOC
|
|
1087
|
+
if (!is_malloc_allowed()) {
|
|
1088
|
+
trsmKernelR<float, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
|
|
1089
|
+
size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
|
|
1090
|
+
return;
|
|
1091
|
+
}
|
|
1092
|
+
#endif
|
|
1093
|
+
triSolve<float, TriStorageOrder != RowMajor, true, (Mode & Lower) != Lower, (Mode & UnitDiag) != 0>(
|
|
1094
|
+
const_cast<float *>(_tri), _other, size, otherSize, triStride, otherStride);
|
|
1095
|
+
}
|
|
1096
|
+
|
|
1097
|
+
template <typename Index, int Mode, int TriStorageOrder>
|
|
1098
|
+
EIGEN_DONT_INLINE void trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
|
|
1099
|
+
Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
|
|
1100
|
+
Index otherStride) {
|
|
1101
|
+
EIGEN_UNUSED_VARIABLE(otherIncr);
|
|
1102
|
+
#ifdef EIGEN_RUNTIME_NO_MALLOC
|
|
1103
|
+
if (!is_malloc_allowed()) {
|
|
1104
|
+
trsmKernelR<double, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
|
|
1105
|
+
size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
|
|
1106
|
+
return;
|
|
1107
|
+
}
|
|
1108
|
+
#endif
|
|
1109
|
+
triSolve<double, TriStorageOrder != RowMajor, true, (Mode & Lower) != Lower, (Mode & UnitDiag) != 0>(
|
|
1110
|
+
const_cast<double *>(_tri), _other, size, otherSize, triStride, otherStride);
|
|
1111
|
+
}
|
|
1112
|
+
#endif // (EIGEN_USE_AVX512_TRSM_R_KERNELS)
|
|
1113
|
+
|
|
1114
|
+
// These trsm kernels require temporary memory allocation
|
|
1115
|
+
#if (EIGEN_USE_AVX512_TRSM_L_KERNELS)
|
|
1116
|
+
// Forward declaration of the generic left-side TRSM kernel (Specialized defaults to true).
template <typename Scalar, typename Index, int Mode, bool Conjugate, int TriStorageOrder, int OtherInnerStride,
          bool Specialized = true>
struct trsmKernelL;

// AVX512 specialization: float, non-conjugated, unit inner stride.
template <typename Index, int Mode, int TriStorageOrder>
struct trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1, true> {
  static void kernel(Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
                     Index otherStride);
};

// AVX512 specialization: double, non-conjugated, unit inner stride.
template <typename Index, int Mode, int TriStorageOrder>
struct trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1, true> {
  static void kernel(Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
                     Index otherStride);
};
|
|
1131
|
+
|
|
1132
|
+
template <typename Index, int Mode, int TriStorageOrder>
|
|
1133
|
+
EIGEN_DONT_INLINE void trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
|
|
1134
|
+
Index size, Index otherSize, const float *_tri, Index triStride, float *_other, Index otherIncr,
|
|
1135
|
+
Index otherStride) {
|
|
1136
|
+
EIGEN_UNUSED_VARIABLE(otherIncr);
|
|
1137
|
+
#ifdef EIGEN_RUNTIME_NO_MALLOC
|
|
1138
|
+
if (!is_malloc_allowed()) {
|
|
1139
|
+
trsmKernelL<float, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
|
|
1140
|
+
size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
|
|
1141
|
+
return;
|
|
1142
|
+
}
|
|
1143
|
+
#endif
|
|
1144
|
+
triSolve<float, TriStorageOrder == RowMajor, false, (Mode & Lower) == Lower, (Mode & UnitDiag) != 0>(
|
|
1145
|
+
const_cast<float *>(_tri), _other, size, otherSize, triStride, otherStride);
|
|
1146
|
+
}
|
|
1147
|
+
|
|
1148
|
+
template <typename Index, int Mode, int TriStorageOrder>
|
|
1149
|
+
EIGEN_DONT_INLINE void trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1, true>::kernel(
|
|
1150
|
+
Index size, Index otherSize, const double *_tri, Index triStride, double *_other, Index otherIncr,
|
|
1151
|
+
Index otherStride) {
|
|
1152
|
+
EIGEN_UNUSED_VARIABLE(otherIncr);
|
|
1153
|
+
#ifdef EIGEN_RUNTIME_NO_MALLOC
|
|
1154
|
+
if (!is_malloc_allowed()) {
|
|
1155
|
+
trsmKernelL<double, Index, Mode, false, TriStorageOrder, 1, /*Specialized=*/false>::kernel(
|
|
1156
|
+
size, otherSize, _tri, triStride, _other, otherIncr, otherStride);
|
|
1157
|
+
return;
|
|
1158
|
+
}
|
|
1159
|
+
#endif
|
|
1160
|
+
triSolve<double, TriStorageOrder == RowMajor, false, (Mode & Lower) == Lower, (Mode & UnitDiag) != 0>(
|
|
1161
|
+
const_cast<double *>(_tri), _other, size, otherSize, triStride, otherStride);
|
|
1162
|
+
}
|
|
1163
|
+
#endif // EIGEN_USE_AVX512_TRSM_L_KERNELS
|
|
1164
|
+
#endif // EIGEN_USE_AVX512_TRSM_KERNELS
|
|
1165
|
+
} // namespace internal
|
|
1166
|
+
} // namespace Eigen
|
|
1167
|
+
#endif // EIGEN_CORE_ARCH_AVX512_TRSM_KERNEL_H
|