tomoto 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +22 -0
- data/README.md +123 -0
- data/ext/tomoto/ext.cpp +245 -0
- data/ext/tomoto/extconf.rb +28 -0
- data/lib/tomoto.rb +12 -0
- data/lib/tomoto/ct.rb +11 -0
- data/lib/tomoto/hdp.rb +11 -0
- data/lib/tomoto/lda.rb +67 -0
- data/lib/tomoto/version.rb +3 -0
- data/vendor/EigenRand/EigenRand/Core.h +1139 -0
- data/vendor/EigenRand/EigenRand/Dists/Basic.h +111 -0
- data/vendor/EigenRand/EigenRand/Dists/Discrete.h +877 -0
- data/vendor/EigenRand/EigenRand/Dists/GammaPoisson.h +108 -0
- data/vendor/EigenRand/EigenRand/Dists/NormalExp.h +626 -0
- data/vendor/EigenRand/EigenRand/EigenRand +19 -0
- data/vendor/EigenRand/EigenRand/Macro.h +24 -0
- data/vendor/EigenRand/EigenRand/MorePacketMath.h +978 -0
- data/vendor/EigenRand/EigenRand/PacketFilter.h +286 -0
- data/vendor/EigenRand/EigenRand/PacketRandomEngine.h +624 -0
- data/vendor/EigenRand/EigenRand/RandUtils.h +413 -0
- data/vendor/EigenRand/EigenRand/doc.h +220 -0
- data/vendor/EigenRand/LICENSE +21 -0
- data/vendor/EigenRand/README.md +288 -0
- data/vendor/eigen/COPYING.BSD +26 -0
- data/vendor/eigen/COPYING.GPL +674 -0
- data/vendor/eigen/COPYING.LGPL +502 -0
- data/vendor/eigen/COPYING.MINPACK +52 -0
- data/vendor/eigen/COPYING.MPL2 +373 -0
- data/vendor/eigen/COPYING.README +18 -0
- data/vendor/eigen/Eigen/CMakeLists.txt +19 -0
- data/vendor/eigen/Eigen/Cholesky +46 -0
- data/vendor/eigen/Eigen/CholmodSupport +48 -0
- data/vendor/eigen/Eigen/Core +537 -0
- data/vendor/eigen/Eigen/Dense +7 -0
- data/vendor/eigen/Eigen/Eigen +2 -0
- data/vendor/eigen/Eigen/Eigenvalues +61 -0
- data/vendor/eigen/Eigen/Geometry +62 -0
- data/vendor/eigen/Eigen/Householder +30 -0
- data/vendor/eigen/Eigen/IterativeLinearSolvers +48 -0
- data/vendor/eigen/Eigen/Jacobi +33 -0
- data/vendor/eigen/Eigen/LU +50 -0
- data/vendor/eigen/Eigen/MetisSupport +35 -0
- data/vendor/eigen/Eigen/OrderingMethods +73 -0
- data/vendor/eigen/Eigen/PaStiXSupport +48 -0
- data/vendor/eigen/Eigen/PardisoSupport +35 -0
- data/vendor/eigen/Eigen/QR +51 -0
- data/vendor/eigen/Eigen/QtAlignedMalloc +40 -0
- data/vendor/eigen/Eigen/SPQRSupport +34 -0
- data/vendor/eigen/Eigen/SVD +51 -0
- data/vendor/eigen/Eigen/Sparse +36 -0
- data/vendor/eigen/Eigen/SparseCholesky +45 -0
- data/vendor/eigen/Eigen/SparseCore +69 -0
- data/vendor/eigen/Eigen/SparseLU +46 -0
- data/vendor/eigen/Eigen/SparseQR +37 -0
- data/vendor/eigen/Eigen/StdDeque +27 -0
- data/vendor/eigen/Eigen/StdList +26 -0
- data/vendor/eigen/Eigen/StdVector +27 -0
- data/vendor/eigen/Eigen/SuperLUSupport +64 -0
- data/vendor/eigen/Eigen/UmfPackSupport +40 -0
- data/vendor/eigen/Eigen/src/Cholesky/LDLT.h +673 -0
- data/vendor/eigen/Eigen/src/Cholesky/LLT.h +542 -0
- data/vendor/eigen/Eigen/src/Cholesky/LLT_LAPACKE.h +99 -0
- data/vendor/eigen/Eigen/src/CholmodSupport/CholmodSupport.h +639 -0
- data/vendor/eigen/Eigen/src/Core/Array.h +329 -0
- data/vendor/eigen/Eigen/src/Core/ArrayBase.h +226 -0
- data/vendor/eigen/Eigen/src/Core/ArrayWrapper.h +209 -0
- data/vendor/eigen/Eigen/src/Core/Assign.h +90 -0
- data/vendor/eigen/Eigen/src/Core/AssignEvaluator.h +935 -0
- data/vendor/eigen/Eigen/src/Core/Assign_MKL.h +178 -0
- data/vendor/eigen/Eigen/src/Core/BandMatrix.h +353 -0
- data/vendor/eigen/Eigen/src/Core/Block.h +452 -0
- data/vendor/eigen/Eigen/src/Core/BooleanRedux.h +164 -0
- data/vendor/eigen/Eigen/src/Core/CommaInitializer.h +160 -0
- data/vendor/eigen/Eigen/src/Core/ConditionEstimator.h +175 -0
- data/vendor/eigen/Eigen/src/Core/CoreEvaluators.h +1688 -0
- data/vendor/eigen/Eigen/src/Core/CoreIterators.h +127 -0
- data/vendor/eigen/Eigen/src/Core/CwiseBinaryOp.h +184 -0
- data/vendor/eigen/Eigen/src/Core/CwiseNullaryOp.h +866 -0
- data/vendor/eigen/Eigen/src/Core/CwiseTernaryOp.h +197 -0
- data/vendor/eigen/Eigen/src/Core/CwiseUnaryOp.h +103 -0
- data/vendor/eigen/Eigen/src/Core/CwiseUnaryView.h +128 -0
- data/vendor/eigen/Eigen/src/Core/DenseBase.h +611 -0
- data/vendor/eigen/Eigen/src/Core/DenseCoeffsBase.h +681 -0
- data/vendor/eigen/Eigen/src/Core/DenseStorage.h +570 -0
- data/vendor/eigen/Eigen/src/Core/Diagonal.h +260 -0
- data/vendor/eigen/Eigen/src/Core/DiagonalMatrix.h +343 -0
- data/vendor/eigen/Eigen/src/Core/DiagonalProduct.h +28 -0
- data/vendor/eigen/Eigen/src/Core/Dot.h +318 -0
- data/vendor/eigen/Eigen/src/Core/EigenBase.h +159 -0
- data/vendor/eigen/Eigen/src/Core/ForceAlignedAccess.h +146 -0
- data/vendor/eigen/Eigen/src/Core/Fuzzy.h +155 -0
- data/vendor/eigen/Eigen/src/Core/GeneralProduct.h +455 -0
- data/vendor/eigen/Eigen/src/Core/GenericPacketMath.h +593 -0
- data/vendor/eigen/Eigen/src/Core/GlobalFunctions.h +187 -0
- data/vendor/eigen/Eigen/src/Core/IO.h +225 -0
- data/vendor/eigen/Eigen/src/Core/Inverse.h +118 -0
- data/vendor/eigen/Eigen/src/Core/Map.h +171 -0
- data/vendor/eigen/Eigen/src/Core/MapBase.h +303 -0
- data/vendor/eigen/Eigen/src/Core/MathFunctions.h +1415 -0
- data/vendor/eigen/Eigen/src/Core/MathFunctionsImpl.h +101 -0
- data/vendor/eigen/Eigen/src/Core/Matrix.h +459 -0
- data/vendor/eigen/Eigen/src/Core/MatrixBase.h +529 -0
- data/vendor/eigen/Eigen/src/Core/NestByValue.h +110 -0
- data/vendor/eigen/Eigen/src/Core/NoAlias.h +108 -0
- data/vendor/eigen/Eigen/src/Core/NumTraits.h +248 -0
- data/vendor/eigen/Eigen/src/Core/PermutationMatrix.h +633 -0
- data/vendor/eigen/Eigen/src/Core/PlainObjectBase.h +1035 -0
- data/vendor/eigen/Eigen/src/Core/Product.h +186 -0
- data/vendor/eigen/Eigen/src/Core/ProductEvaluators.h +1112 -0
- data/vendor/eigen/Eigen/src/Core/Random.h +182 -0
- data/vendor/eigen/Eigen/src/Core/Redux.h +505 -0
- data/vendor/eigen/Eigen/src/Core/Ref.h +283 -0
- data/vendor/eigen/Eigen/src/Core/Replicate.h +142 -0
- data/vendor/eigen/Eigen/src/Core/ReturnByValue.h +117 -0
- data/vendor/eigen/Eigen/src/Core/Reverse.h +211 -0
- data/vendor/eigen/Eigen/src/Core/Select.h +162 -0
- data/vendor/eigen/Eigen/src/Core/SelfAdjointView.h +352 -0
- data/vendor/eigen/Eigen/src/Core/SelfCwiseBinaryOp.h +47 -0
- data/vendor/eigen/Eigen/src/Core/Solve.h +188 -0
- data/vendor/eigen/Eigen/src/Core/SolveTriangular.h +235 -0
- data/vendor/eigen/Eigen/src/Core/SolverBase.h +130 -0
- data/vendor/eigen/Eigen/src/Core/StableNorm.h +221 -0
- data/vendor/eigen/Eigen/src/Core/Stride.h +111 -0
- data/vendor/eigen/Eigen/src/Core/Swap.h +67 -0
- data/vendor/eigen/Eigen/src/Core/Transpose.h +403 -0
- data/vendor/eigen/Eigen/src/Core/Transpositions.h +407 -0
- data/vendor/eigen/Eigen/src/Core/TriangularMatrix.h +983 -0
- data/vendor/eigen/Eigen/src/Core/VectorBlock.h +96 -0
- data/vendor/eigen/Eigen/src/Core/VectorwiseOp.h +695 -0
- data/vendor/eigen/Eigen/src/Core/Visitor.h +273 -0
- data/vendor/eigen/Eigen/src/Core/arch/AVX/Complex.h +451 -0
- data/vendor/eigen/Eigen/src/Core/arch/AVX/MathFunctions.h +439 -0
- data/vendor/eigen/Eigen/src/Core/arch/AVX/PacketMath.h +637 -0
- data/vendor/eigen/Eigen/src/Core/arch/AVX/TypeCasting.h +51 -0
- data/vendor/eigen/Eigen/src/Core/arch/AVX512/MathFunctions.h +391 -0
- data/vendor/eigen/Eigen/src/Core/arch/AVX512/PacketMath.h +1316 -0
- data/vendor/eigen/Eigen/src/Core/arch/AltiVec/Complex.h +430 -0
- data/vendor/eigen/Eigen/src/Core/arch/AltiVec/MathFunctions.h +322 -0
- data/vendor/eigen/Eigen/src/Core/arch/AltiVec/PacketMath.h +1061 -0
- data/vendor/eigen/Eigen/src/Core/arch/CUDA/Complex.h +103 -0
- data/vendor/eigen/Eigen/src/Core/arch/CUDA/Half.h +674 -0
- data/vendor/eigen/Eigen/src/Core/arch/CUDA/MathFunctions.h +91 -0
- data/vendor/eigen/Eigen/src/Core/arch/CUDA/PacketMath.h +333 -0
- data/vendor/eigen/Eigen/src/Core/arch/CUDA/PacketMathHalf.h +1124 -0
- data/vendor/eigen/Eigen/src/Core/arch/CUDA/TypeCasting.h +212 -0
- data/vendor/eigen/Eigen/src/Core/arch/Default/ConjHelper.h +29 -0
- data/vendor/eigen/Eigen/src/Core/arch/Default/Settings.h +49 -0
- data/vendor/eigen/Eigen/src/Core/arch/NEON/Complex.h +490 -0
- data/vendor/eigen/Eigen/src/Core/arch/NEON/MathFunctions.h +91 -0
- data/vendor/eigen/Eigen/src/Core/arch/NEON/PacketMath.h +760 -0
- data/vendor/eigen/Eigen/src/Core/arch/SSE/Complex.h +471 -0
- data/vendor/eigen/Eigen/src/Core/arch/SSE/MathFunctions.h +562 -0
- data/vendor/eigen/Eigen/src/Core/arch/SSE/PacketMath.h +895 -0
- data/vendor/eigen/Eigen/src/Core/arch/SSE/TypeCasting.h +77 -0
- data/vendor/eigen/Eigen/src/Core/arch/ZVector/Complex.h +397 -0
- data/vendor/eigen/Eigen/src/Core/arch/ZVector/MathFunctions.h +137 -0
- data/vendor/eigen/Eigen/src/Core/arch/ZVector/PacketMath.h +945 -0
- data/vendor/eigen/Eigen/src/Core/functors/AssignmentFunctors.h +168 -0
- data/vendor/eigen/Eigen/src/Core/functors/BinaryFunctors.h +475 -0
- data/vendor/eigen/Eigen/src/Core/functors/NullaryFunctors.h +188 -0
- data/vendor/eigen/Eigen/src/Core/functors/StlFunctors.h +136 -0
- data/vendor/eigen/Eigen/src/Core/functors/TernaryFunctors.h +25 -0
- data/vendor/eigen/Eigen/src/Core/functors/UnaryFunctors.h +792 -0
- data/vendor/eigen/Eigen/src/Core/products/GeneralBlockPanelKernel.h +2156 -0
- data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrix.h +492 -0
- data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h +311 -0
- data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h +145 -0
- data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h +122 -0
- data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixVector.h +619 -0
- data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixVector_BLAS.h +136 -0
- data/vendor/eigen/Eigen/src/Core/products/Parallelizer.h +163 -0
- data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix.h +521 -0
- data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix_BLAS.h +287 -0
- data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixVector.h +260 -0
- data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixVector_BLAS.h +118 -0
- data/vendor/eigen/Eigen/src/Core/products/SelfadjointProduct.h +133 -0
- data/vendor/eigen/Eigen/src/Core/products/SelfadjointRank2Update.h +93 -0
- data/vendor/eigen/Eigen/src/Core/products/TriangularMatrixMatrix.h +466 -0
- data/vendor/eigen/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h +315 -0
- data/vendor/eigen/Eigen/src/Core/products/TriangularMatrixVector.h +350 -0
- data/vendor/eigen/Eigen/src/Core/products/TriangularMatrixVector_BLAS.h +255 -0
- data/vendor/eigen/Eigen/src/Core/products/TriangularSolverMatrix.h +335 -0
- data/vendor/eigen/Eigen/src/Core/products/TriangularSolverMatrix_BLAS.h +163 -0
- data/vendor/eigen/Eigen/src/Core/products/TriangularSolverVector.h +145 -0
- data/vendor/eigen/Eigen/src/Core/util/BlasUtil.h +398 -0
- data/vendor/eigen/Eigen/src/Core/util/Constants.h +547 -0
- data/vendor/eigen/Eigen/src/Core/util/DisableStupidWarnings.h +83 -0
- data/vendor/eigen/Eigen/src/Core/util/ForwardDeclarations.h +302 -0
- data/vendor/eigen/Eigen/src/Core/util/MKL_support.h +130 -0
- data/vendor/eigen/Eigen/src/Core/util/Macros.h +1001 -0
- data/vendor/eigen/Eigen/src/Core/util/Memory.h +993 -0
- data/vendor/eigen/Eigen/src/Core/util/Meta.h +534 -0
- data/vendor/eigen/Eigen/src/Core/util/NonMPL2.h +3 -0
- data/vendor/eigen/Eigen/src/Core/util/ReenableStupidWarnings.h +27 -0
- data/vendor/eigen/Eigen/src/Core/util/StaticAssert.h +218 -0
- data/vendor/eigen/Eigen/src/Core/util/XprHelper.h +821 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/ComplexEigenSolver.h +346 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/ComplexSchur.h +459 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/ComplexSchur_LAPACKE.h +91 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/EigenSolver.h +622 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/GeneralizedEigenSolver.h +418 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/GeneralizedSelfAdjointEigenSolver.h +226 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/HessenbergDecomposition.h +374 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/MatrixBaseEigenvalues.h +158 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/RealQZ.h +654 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/RealSchur.h +546 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/RealSchur_LAPACKE.h +77 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver.h +870 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver_LAPACKE.h +87 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/Tridiagonalization.h +556 -0
- data/vendor/eigen/Eigen/src/Geometry/AlignedBox.h +392 -0
- data/vendor/eigen/Eigen/src/Geometry/AngleAxis.h +247 -0
- data/vendor/eigen/Eigen/src/Geometry/EulerAngles.h +114 -0
- data/vendor/eigen/Eigen/src/Geometry/Homogeneous.h +497 -0
- data/vendor/eigen/Eigen/src/Geometry/Hyperplane.h +282 -0
- data/vendor/eigen/Eigen/src/Geometry/OrthoMethods.h +234 -0
- data/vendor/eigen/Eigen/src/Geometry/ParametrizedLine.h +195 -0
- data/vendor/eigen/Eigen/src/Geometry/Quaternion.h +814 -0
- data/vendor/eigen/Eigen/src/Geometry/Rotation2D.h +199 -0
- data/vendor/eigen/Eigen/src/Geometry/RotationBase.h +206 -0
- data/vendor/eigen/Eigen/src/Geometry/Scaling.h +170 -0
- data/vendor/eigen/Eigen/src/Geometry/Transform.h +1542 -0
- data/vendor/eigen/Eigen/src/Geometry/Translation.h +208 -0
- data/vendor/eigen/Eigen/src/Geometry/Umeyama.h +166 -0
- data/vendor/eigen/Eigen/src/Geometry/arch/Geometry_SSE.h +161 -0
- data/vendor/eigen/Eigen/src/Householder/BlockHouseholder.h +103 -0
- data/vendor/eigen/Eigen/src/Householder/Householder.h +172 -0
- data/vendor/eigen/Eigen/src/Householder/HouseholderSequence.h +470 -0
- data/vendor/eigen/Eigen/src/IterativeLinearSolvers/BasicPreconditioners.h +226 -0
- data/vendor/eigen/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h +228 -0
- data/vendor/eigen/Eigen/src/IterativeLinearSolvers/ConjugateGradient.h +246 -0
- data/vendor/eigen/Eigen/src/IterativeLinearSolvers/IncompleteCholesky.h +400 -0
- data/vendor/eigen/Eigen/src/IterativeLinearSolvers/IncompleteLUT.h +462 -0
- data/vendor/eigen/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h +394 -0
- data/vendor/eigen/Eigen/src/IterativeLinearSolvers/LeastSquareConjugateGradient.h +216 -0
- data/vendor/eigen/Eigen/src/IterativeLinearSolvers/SolveWithGuess.h +115 -0
- data/vendor/eigen/Eigen/src/Jacobi/Jacobi.h +462 -0
- data/vendor/eigen/Eigen/src/LU/Determinant.h +101 -0
- data/vendor/eigen/Eigen/src/LU/FullPivLU.h +891 -0
- data/vendor/eigen/Eigen/src/LU/InverseImpl.h +415 -0
- data/vendor/eigen/Eigen/src/LU/PartialPivLU.h +611 -0
- data/vendor/eigen/Eigen/src/LU/PartialPivLU_LAPACKE.h +83 -0
- data/vendor/eigen/Eigen/src/LU/arch/Inverse_SSE.h +338 -0
- data/vendor/eigen/Eigen/src/MetisSupport/MetisSupport.h +137 -0
- data/vendor/eigen/Eigen/src/OrderingMethods/Amd.h +445 -0
- data/vendor/eigen/Eigen/src/OrderingMethods/Eigen_Colamd.h +1843 -0
- data/vendor/eigen/Eigen/src/OrderingMethods/Ordering.h +157 -0
- data/vendor/eigen/Eigen/src/PaStiXSupport/PaStiXSupport.h +678 -0
- data/vendor/eigen/Eigen/src/PardisoSupport/PardisoSupport.h +543 -0
- data/vendor/eigen/Eigen/src/QR/ColPivHouseholderQR.h +653 -0
- data/vendor/eigen/Eigen/src/QR/ColPivHouseholderQR_LAPACKE.h +97 -0
- data/vendor/eigen/Eigen/src/QR/CompleteOrthogonalDecomposition.h +562 -0
- data/vendor/eigen/Eigen/src/QR/FullPivHouseholderQR.h +676 -0
- data/vendor/eigen/Eigen/src/QR/HouseholderQR.h +409 -0
- data/vendor/eigen/Eigen/src/QR/HouseholderQR_LAPACKE.h +68 -0
- data/vendor/eigen/Eigen/src/SPQRSupport/SuiteSparseQRSupport.h +313 -0
- data/vendor/eigen/Eigen/src/SVD/BDCSVD.h +1246 -0
- data/vendor/eigen/Eigen/src/SVD/JacobiSVD.h +804 -0
- data/vendor/eigen/Eigen/src/SVD/JacobiSVD_LAPACKE.h +91 -0
- data/vendor/eigen/Eigen/src/SVD/SVDBase.h +315 -0
- data/vendor/eigen/Eigen/src/SVD/UpperBidiagonalization.h +414 -0
- data/vendor/eigen/Eigen/src/SparseCholesky/SimplicialCholesky.h +689 -0
- data/vendor/eigen/Eigen/src/SparseCholesky/SimplicialCholesky_impl.h +199 -0
- data/vendor/eigen/Eigen/src/SparseCore/AmbiVector.h +377 -0
- data/vendor/eigen/Eigen/src/SparseCore/CompressedStorage.h +258 -0
- data/vendor/eigen/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h +352 -0
- data/vendor/eigen/Eigen/src/SparseCore/MappedSparseMatrix.h +67 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseAssign.h +216 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseBlock.h +603 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseColEtree.h +206 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseCompressedBase.h +341 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseCwiseBinaryOp.h +726 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseCwiseUnaryOp.h +148 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseDenseProduct.h +320 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseDiagonalProduct.h +138 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseDot.h +98 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseFuzzy.h +29 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseMap.h +305 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseMatrix.h +1403 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseMatrixBase.h +405 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparsePermutation.h +178 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseProduct.h +169 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseRedux.h +49 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseRef.h +397 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseSelfAdjointView.h +656 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseSolverBase.h +124 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseSparseProductWithPruning.h +198 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseTranspose.h +92 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseTriangularView.h +189 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseUtil.h +178 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseVector.h +478 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseView.h +253 -0
- data/vendor/eigen/Eigen/src/SparseCore/TriangularSolver.h +315 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU.h +773 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLUImpl.h +66 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_Memory.h +226 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_Structs.h +110 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_SupernodalMatrix.h +301 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_Utils.h +80 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_column_bmod.h +181 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_column_dfs.h +179 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_copy_to_ucol.h +107 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_gemm_kernel.h +280 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_heap_relax_snode.h +126 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_kernel_bmod.h +130 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_panel_bmod.h +223 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_panel_dfs.h +258 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_pivotL.h +137 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_pruneL.h +136 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_relax_snode.h +83 -0
- data/vendor/eigen/Eigen/src/SparseQR/SparseQR.h +745 -0
- data/vendor/eigen/Eigen/src/StlSupport/StdDeque.h +126 -0
- data/vendor/eigen/Eigen/src/StlSupport/StdList.h +106 -0
- data/vendor/eigen/Eigen/src/StlSupport/StdVector.h +131 -0
- data/vendor/eigen/Eigen/src/StlSupport/details.h +84 -0
- data/vendor/eigen/Eigen/src/SuperLUSupport/SuperLUSupport.h +1027 -0
- data/vendor/eigen/Eigen/src/UmfPackSupport/UmfPackSupport.h +506 -0
- data/vendor/eigen/Eigen/src/misc/Image.h +82 -0
- data/vendor/eigen/Eigen/src/misc/Kernel.h +79 -0
- data/vendor/eigen/Eigen/src/misc/RealSvd2x2.h +55 -0
- data/vendor/eigen/Eigen/src/misc/blas.h +440 -0
- data/vendor/eigen/Eigen/src/misc/lapack.h +152 -0
- data/vendor/eigen/Eigen/src/misc/lapacke.h +16291 -0
- data/vendor/eigen/Eigen/src/misc/lapacke_mangling.h +17 -0
- data/vendor/eigen/Eigen/src/plugins/ArrayCwiseBinaryOps.h +332 -0
- data/vendor/eigen/Eigen/src/plugins/ArrayCwiseUnaryOps.h +552 -0
- data/vendor/eigen/Eigen/src/plugins/BlockMethods.h +1058 -0
- data/vendor/eigen/Eigen/src/plugins/CommonCwiseBinaryOps.h +115 -0
- data/vendor/eigen/Eigen/src/plugins/CommonCwiseUnaryOps.h +163 -0
- data/vendor/eigen/Eigen/src/plugins/MatrixCwiseBinaryOps.h +152 -0
- data/vendor/eigen/Eigen/src/plugins/MatrixCwiseUnaryOps.h +85 -0
- data/vendor/eigen/README.md +3 -0
- data/vendor/eigen/bench/README.txt +55 -0
- data/vendor/eigen/bench/btl/COPYING +340 -0
- data/vendor/eigen/bench/btl/README +154 -0
- data/vendor/eigen/bench/tensors/README +21 -0
- data/vendor/eigen/blas/README.txt +6 -0
- data/vendor/eigen/demos/mandelbrot/README +10 -0
- data/vendor/eigen/demos/mix_eigen_and_c/README +9 -0
- data/vendor/eigen/demos/opengl/README +13 -0
- data/vendor/eigen/unsupported/Eigen/CXX11/src/Tensor/README.md +1760 -0
- data/vendor/eigen/unsupported/README.txt +50 -0
- data/vendor/tomotopy/LICENSE +21 -0
- data/vendor/tomotopy/README.kr.rst +375 -0
- data/vendor/tomotopy/README.rst +382 -0
- data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +362 -0
- data/vendor/tomotopy/src/Labeling/FoRelevance.h +88 -0
- data/vendor/tomotopy/src/Labeling/Labeler.h +50 -0
- data/vendor/tomotopy/src/TopicModel/CT.h +37 -0
- data/vendor/tomotopy/src/TopicModel/CTModel.cpp +13 -0
- data/vendor/tomotopy/src/TopicModel/CTModel.hpp +293 -0
- data/vendor/tomotopy/src/TopicModel/DMR.h +51 -0
- data/vendor/tomotopy/src/TopicModel/DMRModel.cpp +13 -0
- data/vendor/tomotopy/src/TopicModel/DMRModel.hpp +374 -0
- data/vendor/tomotopy/src/TopicModel/DT.h +65 -0
- data/vendor/tomotopy/src/TopicModel/DTM.h +22 -0
- data/vendor/tomotopy/src/TopicModel/DTModel.cpp +15 -0
- data/vendor/tomotopy/src/TopicModel/DTModel.hpp +572 -0
- data/vendor/tomotopy/src/TopicModel/GDMR.h +37 -0
- data/vendor/tomotopy/src/TopicModel/GDMRModel.cpp +14 -0
- data/vendor/tomotopy/src/TopicModel/GDMRModel.hpp +485 -0
- data/vendor/tomotopy/src/TopicModel/HDP.h +74 -0
- data/vendor/tomotopy/src/TopicModel/HDPModel.cpp +13 -0
- data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +592 -0
- data/vendor/tomotopy/src/TopicModel/HLDA.h +40 -0
- data/vendor/tomotopy/src/TopicModel/HLDAModel.cpp +13 -0
- data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +681 -0
- data/vendor/tomotopy/src/TopicModel/HPA.h +27 -0
- data/vendor/tomotopy/src/TopicModel/HPAModel.cpp +21 -0
- data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +588 -0
- data/vendor/tomotopy/src/TopicModel/LDA.h +144 -0
- data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +442 -0
- data/vendor/tomotopy/src/TopicModel/LDAModel.cpp +13 -0
- data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +1058 -0
- data/vendor/tomotopy/src/TopicModel/LLDA.h +45 -0
- data/vendor/tomotopy/src/TopicModel/LLDAModel.cpp +13 -0
- data/vendor/tomotopy/src/TopicModel/LLDAModel.hpp +203 -0
- data/vendor/tomotopy/src/TopicModel/MGLDA.h +63 -0
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.cpp +17 -0
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +558 -0
- data/vendor/tomotopy/src/TopicModel/PA.h +43 -0
- data/vendor/tomotopy/src/TopicModel/PAModel.cpp +13 -0
- data/vendor/tomotopy/src/TopicModel/PAModel.hpp +467 -0
- data/vendor/tomotopy/src/TopicModel/PLDA.h +17 -0
- data/vendor/tomotopy/src/TopicModel/PLDAModel.cpp +13 -0
- data/vendor/tomotopy/src/TopicModel/PLDAModel.hpp +214 -0
- data/vendor/tomotopy/src/TopicModel/SLDA.h +54 -0
- data/vendor/tomotopy/src/TopicModel/SLDAModel.cpp +17 -0
- data/vendor/tomotopy/src/TopicModel/SLDAModel.hpp +456 -0
- data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +692 -0
- data/vendor/tomotopy/src/Utils/AliasMethod.hpp +169 -0
- data/vendor/tomotopy/src/Utils/Dictionary.h +80 -0
- data/vendor/tomotopy/src/Utils/EigenAddonOps.hpp +181 -0
- data/vendor/tomotopy/src/Utils/LBFGS.h +202 -0
- data/vendor/tomotopy/src/Utils/LBFGS/LineSearchBacktracking.h +120 -0
- data/vendor/tomotopy/src/Utils/LBFGS/LineSearchBracketing.h +122 -0
- data/vendor/tomotopy/src/Utils/LBFGS/Param.h +213 -0
- data/vendor/tomotopy/src/Utils/LUT.hpp +82 -0
- data/vendor/tomotopy/src/Utils/MultiNormalDistribution.hpp +69 -0
- data/vendor/tomotopy/src/Utils/PolyaGamma.hpp +200 -0
- data/vendor/tomotopy/src/Utils/PolyaGammaHybrid.hpp +672 -0
- data/vendor/tomotopy/src/Utils/ThreadPool.hpp +150 -0
- data/vendor/tomotopy/src/Utils/Trie.hpp +220 -0
- data/vendor/tomotopy/src/Utils/TruncMultiNormal.hpp +94 -0
- data/vendor/tomotopy/src/Utils/Utils.hpp +337 -0
- data/vendor/tomotopy/src/Utils/avx_gamma.h +46 -0
- data/vendor/tomotopy/src/Utils/avx_mathfun.h +736 -0
- data/vendor/tomotopy/src/Utils/exception.h +28 -0
- data/vendor/tomotopy/src/Utils/math.h +281 -0
- data/vendor/tomotopy/src/Utils/rtnorm.hpp +2690 -0
- data/vendor/tomotopy/src/Utils/sample.hpp +192 -0
- data/vendor/tomotopy/src/Utils/serializer.hpp +695 -0
- data/vendor/tomotopy/src/Utils/slp.hpp +131 -0
- data/vendor/tomotopy/src/Utils/sse_gamma.h +48 -0
- data/vendor/tomotopy/src/Utils/sse_mathfun.h +710 -0
- data/vendor/tomotopy/src/Utils/text.hpp +49 -0
- data/vendor/tomotopy/src/Utils/tvector.hpp +543 -0
- metadata +531 -0
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
#include "LDAModel.hpp"

namespace tomoto
{
	// Explicit instantiations are intentionally left disabled; the factory
	// below instantiates the needed specializations on demand.
	/*template class LDAModel<TermWeight::one>;
	template class LDAModel<TermWeight::idf>;
	template class LDAModel<TermWeight::pmi>;*/

	// Factory: constructs the LDAModel specialization matching the runtime
	// term-weighting scheme `_weight`. Returns a heap-allocated model owned by
	// the caller, or nullptr when `_weight` is not one of the handled values.
	// NOTE(review): TMT_SWITCH_TW (see LDAModel.hpp) ignores its SRNG argument
	// and picks the RNG type at build time via TMT_SCALAR_RNG, so `scalarRng`
	// currently has no effect — confirm whether that is intended.
	ILDAModel* ILDAModel::create(TermWeight _weight, size_t _K, Float _alpha, Float _eta, size_t seed, bool scalarRng)
	{
		TMT_SWITCH_TW(_weight, scalarRng, LDAModel, _K, _alpha, _eta, seed);
	}
}
|
|
@@ -0,0 +1,1058 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
#include <unordered_set>
|
|
3
|
+
#include <numeric>
|
|
4
|
+
#include "TopicModel.hpp"
|
|
5
|
+
#include "../Utils/EigenAddonOps.hpp"
|
|
6
|
+
#include "../Utils/Utils.hpp"
|
|
7
|
+
#include "../Utils/math.h"
|
|
8
|
+
#include "../Utils/sample.hpp"
|
|
9
|
+
#include "LDA.h"
|
|
10
|
+
|
|
11
|
+
/*
|
|
12
|
+
Implementation of LDA using Gibbs sampling by bab2min
|
|
13
|
+
|
|
14
|
+
* Blei, D. M., Ng, A. Y., & Jordan, M. I. (2003). Latent dirichlet allocation. Journal of machine Learning research, 3(Jan), 993-1022.
|
|
15
|
+
* Newman, D., Asuncion, A., Smyth, P., & Welling, M. (2009). Distributed algorithms for topic models. Journal of Machine Learning Research, 10(Aug), 1801-1828.
|
|
16
|
+
|
|
17
|
+
Term Weighting Scheme is based on following paper:
|
|
18
|
+
* Wilson, A. T., & Chew, P. A. (2010, June). Term weighting schemes for latent dirichlet allocation. In human language technologies: The 2010 annual conference of the North American Chapter of the Association for Computational Linguistics (pp. 465-473). Association for Computational Linguistics.
|
|
19
|
+
|
|
20
|
+
*/
|
|
21
|
+
|
|
22
|
+
// TMT_SWITCH_TW(TW, SRNG, MDL, ...): expands to a switch over the runtime
// TermWeight value `TW`, returning `new MDL<tw, RngType>(__VA_ARGS__)` for the
// matching weighting, or nullptr when `TW` matches no case. The RNG type is
// chosen at compile time by the TMT_SCALAR_RNG build flag (ScalarRandGen vs
// RandGen); the `SRNG` macro parameter is unused in both expansions —
// presumably reserved for a runtime RNG switch. TODO confirm.
// (No comments inside the macro bodies: a // comment would swallow the
// trailing line-continuation backslash.)
#ifdef TMT_SCALAR_RNG
#define TMT_SWITCH_TW(TW, SRNG, MDL, ...) do{\
	{\
		switch (TW){\
		case TermWeight::one:\
			return new MDL<TermWeight::one, ScalarRandGen>(__VA_ARGS__);\
		case TermWeight::idf:\
			return new MDL<TermWeight::idf, ScalarRandGen>(__VA_ARGS__);\
		case TermWeight::pmi:\
			return new MDL<TermWeight::pmi, ScalarRandGen>(__VA_ARGS__);\
		}\
	}\
	return nullptr; } while(0)
#else
#define TMT_SWITCH_TW(TW, SRNG, MDL, ...) do{\
	{\
		switch (TW){\
		case TermWeight::one:\
			return new MDL<TermWeight::one, RandGen>(__VA_ARGS__);\
		case TermWeight::idf:\
			return new MDL<TermWeight::idf, RandGen>(__VA_ARGS__);\
		case TermWeight::pmi:\
			return new MDL<TermWeight::pmi, RandGen>(__VA_ARGS__);\
		}\
	}\
	return nullptr; } while(0)
#endif
|
|
49
|
+
|
|
50
|
+
// GETTER(name, type, field): declares a trivial const virtual-override
// accessor `type getname() const` that returns the member `field`.
#define GETTER(name, type, field) type get##name() const override { return field; }
|
|
51
|
+
|
|
52
|
+
namespace tomoto
|
|
53
|
+
{
|
|
54
|
+
// Per-sampler mutable state for LDA Gibbs sampling, parameterized by the
// term-weighting scheme.
template<TermWeight _tw>
struct ModelStateLDA
{
	// With TermWeight::one every token counts exactly 1, so integer counts
	// suffice; any other weighting produces fractional counts.
	using WeightType = typename std::conditional<_tw == TermWeight::one, int32_t, float>::type;

	// Transient per-topic likelihood buffer; deliberately excluded from the
	// serializer below.
	Eigen::Matrix<Float, -1, 1> zLikelihood;
	Eigen::Matrix<WeightType, -1, 1> numByTopic; // Dim: (Topic, 1)
	Eigen::Matrix<WeightType, -1, -1> numByTopicWord; // Dim: (Topic, Vocabs)
	DEFINE_SERIALIZER(numByTopic, numByTopicWord);
};
|
|
64
|
+
|
|
65
|
+
namespace flags
{
	// Model-option bit flags. Each model layer continues the bit sequence
	// from its base class's end_flag so flag values never collide.
	enum
	{
		generator_by_doc = end_flag_of_TopicModel,
		// First free bit for classes deriving from LDAModel.
		end_flag_of_LDAModel = generator_by_doc << 1,
	};
}
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
template<typename _Model, bool _asymEta>
|
|
76
|
+
class EtaHelper
|
|
77
|
+
{
|
|
78
|
+
const _Model& _this;
|
|
79
|
+
public:
|
|
80
|
+
EtaHelper(const _Model& p) : _this(p) {}
|
|
81
|
+
|
|
82
|
+
Float getEta(size_t vid) const
|
|
83
|
+
{
|
|
84
|
+
return _this.eta;
|
|
85
|
+
}
|
|
86
|
+
|
|
87
|
+
Float getEtaSum() const
|
|
88
|
+
{
|
|
89
|
+
return _this.eta * _this.realV;
|
|
90
|
+
}
|
|
91
|
+
};
|
|
92
|
+
|
|
93
|
+
// Asymmetric-eta specialization (_asymEta == true): eta varies per
// (topic, word), so the accessors return Eigen array expressions over whole
// per-topic columns instead of scalars — hence the decltype trailing return
// types, which preserve the exact (lazy) expression type.
template<typename _Model>
class EtaHelper<_Model, true>
{
	const _Model& _this;
public:
	EtaHelper(const _Model& p) : _this(p) {}

	// Per-topic eta values for vocabulary id `vid` (one column of the
	// (K, V) etaByTopicWord matrix, viewed as an array expression).
	auto getEta(size_t vid) const
		-> decltype(_this.etaByTopicWord.col(vid).array())
	{
		return _this.etaByTopicWord.col(vid).array();
	}

	// Per-topic eta totals (array view over the precomputed etaSumByTopic).
	auto getEtaSum() const
		-> decltype(_this.etaSumByTopic.array())
	{
		return _this.etaSumByTopic.array();
	}
};
|
|
112
|
+
|
|
113
|
+
// Compile-time tag string for each TermWeight — presumably embedded in the
// serialized model format as an identifier (TODO confirm against the
// serializer). Only the three listed weightings are specialized; any other
// value fails to compile. Note each literal carries an explicit '\0', so the
// array is 5 bytes including the implicit terminator.
template<TermWeight _tw>
struct TwId;

template<>
struct TwId<TermWeight::one>
{
	static constexpr char TWID[] = "one\0";
};

template<>
struct TwId<TermWeight::idf>
{
	static constexpr char TWID[] = "idf\0";
};

template<>
struct TwId<TermWeight::pmi>
{
	static constexpr char TWID[] = "pmi\0";
};
|
|
133
|
+
|
|
134
|
+
// Forward declaration so LDAModel can declare HDPModel a friend, which
// HDPModel::converToLDA needs in order to reach LDA internals.
template<TermWeight _tw,
	typename _RandGen,
	typename _Interface,
	typename _Derived,
	typename _DocType,
	typename _ModelState>
class HDPModel;
|
|
142
|
+
|
|
143
|
+
template<TermWeight _tw, typename _RandGen,
|
|
144
|
+
size_t _Flags = flags::partitioned_multisampling,
|
|
145
|
+
typename _Interface = ILDAModel,
|
|
146
|
+
typename _Derived = void,
|
|
147
|
+
typename _DocType = DocumentLDA<_tw>,
|
|
148
|
+
typename _ModelState = ModelStateLDA<_tw>>
|
|
149
|
+
class LDAModel : public TopicModel<_RandGen, _Flags, _Interface,
|
|
150
|
+
typename std::conditional<std::is_same<_Derived, void>::value, LDAModel<_tw, _RandGen, _Flags>, _Derived>::type,
|
|
151
|
+
_DocType, _ModelState>,
|
|
152
|
+
protected TwId<_tw>
|
|
153
|
+
{
|
|
154
|
+
protected:
|
|
155
|
+
using DerivedClass = typename std::conditional<std::is_same<_Derived, void>::value, LDAModel, _Derived>::type;
|
|
156
|
+
using BaseClass = TopicModel<_RandGen, _Flags, _Interface, DerivedClass, _DocType, _ModelState>;
|
|
157
|
+
friend BaseClass;
|
|
158
|
+
friend EtaHelper<DerivedClass, true>;
|
|
159
|
+
friend EtaHelper<DerivedClass, false>;
|
|
160
|
+
|
|
161
|
+
template<TermWeight,
|
|
162
|
+
typename,
|
|
163
|
+
typename,
|
|
164
|
+
typename,
|
|
165
|
+
typename,
|
|
166
|
+
typename>
|
|
167
|
+
friend class HDPModel;
|
|
168
|
+
|
|
169
|
+
static constexpr char TMID[] = "LDA\0";
|
|
170
|
+
using WeightType = typename std::conditional<_tw == TermWeight::one, int32_t, float>::type;
|
|
171
|
+
|
|
172
|
+
enum { m_flags = _Flags };
|
|
173
|
+
|
|
174
|
+
std::vector<Float> vocabWeights;
|
|
175
|
+
std::vector<Tid> sharedZs;
|
|
176
|
+
std::vector<Float> sharedWordWeights;
|
|
177
|
+
Tid K;
|
|
178
|
+
Float alpha, eta;
|
|
179
|
+
Eigen::Matrix<Float, -1, 1> alphas;
|
|
180
|
+
std::unordered_map<std::string, std::vector<Float>> etaByWord;
|
|
181
|
+
Eigen::Matrix<Float, -1, -1> etaByTopicWord; // (K, V)
|
|
182
|
+
Eigen::Matrix<Float, -1, 1> etaSumByTopic; // (K, )
|
|
183
|
+
uint32_t optimInterval = 10, burnIn = 0;
|
|
184
|
+
Eigen::Matrix<WeightType, -1, -1> numByTopicDoc;
|
|
185
|
+
|
|
186
|
+
struct ExtraDocData
|
|
187
|
+
{
|
|
188
|
+
std::vector<Vid> vChunkOffset;
|
|
189
|
+
Eigen::Matrix<uint32_t, -1, -1> chunkOffsetByDoc;
|
|
190
|
+
};
|
|
191
|
+
|
|
192
|
+
ExtraDocData eddTrain;
|
|
193
|
+
|
|
194
|
+
template<typename _List>
|
|
195
|
+
static Float calcDigammaSum(ThreadPool* pool, _List list, size_t len, Float alpha)
|
|
196
|
+
{
|
|
197
|
+
auto listExpr = Eigen::Matrix<Float, -1, 1>::NullaryExpr(len, list);
|
|
198
|
+
auto dAlpha = math::digammaT(alpha);
|
|
199
|
+
|
|
200
|
+
size_t suggested = (len + 127) / 128;
|
|
201
|
+
if (pool && suggested > pool->getNumWorkers()) suggested = pool->getNumWorkers();
|
|
202
|
+
if (suggested <= 1 || !pool)
|
|
203
|
+
{
|
|
204
|
+
return (math::digammaApprox(listExpr.array() + alpha) - dAlpha).sum();
|
|
205
|
+
}
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
std::vector<std::future<Float>> futures;
|
|
209
|
+
for (size_t i = 0; i < suggested; ++i)
|
|
210
|
+
{
|
|
211
|
+
size_t start = (len * i / suggested + 15) & ~0xF,
|
|
212
|
+
end = std::min((len * (i + 1) / suggested + 15) & ~0xF, len);
|
|
213
|
+
futures.emplace_back(pool->enqueue([&, start, end, dAlpha](size_t)
|
|
214
|
+
{
|
|
215
|
+
return (math::digammaApprox(listExpr.array().segment(start, end - start) + alpha) - dAlpha).sum();
|
|
216
|
+
}));
|
|
217
|
+
}
|
|
218
|
+
Float ret = 0;
|
|
219
|
+
for (auto& f : futures) ret += f.get();
|
|
220
|
+
return ret;
|
|
221
|
+
}
|
|
222
|
+
|
|
223
|
+
/*
|
|
224
|
+
function for optimizing hyperparameters
|
|
225
|
+
*/
|
|
226
|
+
void optimizeParameters(ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
|
|
227
|
+
{
|
|
228
|
+
const auto K = this->K;
|
|
229
|
+
for (size_t i = 0; i < 10; ++i)
|
|
230
|
+
{
|
|
231
|
+
Float denom = calcDigammaSum(&pool, [&](size_t i) { return this->docs[i].getSumWordWeight(); }, this->docs.size(), alphas.sum());
|
|
232
|
+
for (size_t k = 0; k < K; ++k)
|
|
233
|
+
{
|
|
234
|
+
Float nom = calcDigammaSum(&pool, [&](size_t i) { return this->docs[i].numByTopic[k]; }, this->docs.size(), alphas(k));
|
|
235
|
+
alphas(k) = std::max(nom / denom * alphas(k), 1e-5f);
|
|
236
|
+
}
|
|
237
|
+
}
|
|
238
|
+
}
|
|
239
|
+
|
|
240
|
+
template<bool _asymEta>
|
|
241
|
+
EtaHelper<DerivedClass, _asymEta> getEtaHelper() const
|
|
242
|
+
{
|
|
243
|
+
return EtaHelper<DerivedClass, _asymEta>{ *static_cast<const DerivedClass*>(this) };
|
|
244
|
+
}
|
|
245
|
+
|
|
246
|
+
template<bool _asymEta>
|
|
247
|
+
Float* getZLikelihoods(_ModelState& ld, const _DocType& doc, size_t docId, size_t vid) const
|
|
248
|
+
{
|
|
249
|
+
const size_t V = this->realV;
|
|
250
|
+
assert(vid < V);
|
|
251
|
+
auto etaHelper = this->template getEtaHelper<_asymEta>();
|
|
252
|
+
auto& zLikelihood = ld.zLikelihood;
|
|
253
|
+
zLikelihood = (doc.numByTopic.array().template cast<Float>() + alphas.array())
|
|
254
|
+
* (ld.numByTopicWord.col(vid).array().template cast<Float>() + etaHelper.getEta(vid))
|
|
255
|
+
/ (ld.numByTopic.array().template cast<Float>() + etaHelper.getEtaSum());
|
|
256
|
+
sample::prefixSum(zLikelihood.data(), K);
|
|
257
|
+
return &zLikelihood[0];
|
|
258
|
+
}
|
|
259
|
+
|
|
260
|
+
template<int _inc>
|
|
261
|
+
inline void addWordTo(_ModelState& ld, _DocType& doc, uint32_t pid, Vid vid, Tid tid) const
|
|
262
|
+
{
|
|
263
|
+
assert(tid < K);
|
|
264
|
+
assert(vid < this->realV);
|
|
265
|
+
constexpr bool _dec = _inc < 0 && _tw != TermWeight::one;
|
|
266
|
+
typename std::conditional<_tw != TermWeight::one, float, int32_t>::type weight
|
|
267
|
+
= _tw != TermWeight::one ? doc.wordWeights[pid] : 1;
|
|
268
|
+
|
|
269
|
+
updateCnt<_dec>(doc.numByTopic[tid], _inc * weight);
|
|
270
|
+
updateCnt<_dec>(ld.numByTopic[tid], _inc * weight);
|
|
271
|
+
updateCnt<_dec>(ld.numByTopicWord(tid, vid), _inc * weight);
|
|
272
|
+
}
|
|
273
|
+
|
|
274
|
+
void resetStatistics()
|
|
275
|
+
{
|
|
276
|
+
this->globalState.numByTopic.setZero();
|
|
277
|
+
this->globalState.numByTopicWord.setZero();
|
|
278
|
+
for (auto& doc : this->docs)
|
|
279
|
+
{
|
|
280
|
+
doc.numByTopic.setZero();
|
|
281
|
+
for (size_t w = 0; w < doc.words.size(); ++w)
|
|
282
|
+
{
|
|
283
|
+
if (doc.words[w] >= this->realV) continue;
|
|
284
|
+
addWordTo<1>(this->globalState, doc, w, doc.words[w], doc.Zs[w]);
|
|
285
|
+
}
|
|
286
|
+
}
|
|
287
|
+
}
|
|
288
|
+
|
|
289
|
+
/*
|
|
290
|
+
called once before sampleDocument
|
|
291
|
+
*/
|
|
292
|
+
void presampleDocument(_DocType& doc, size_t docId, _ModelState& ld, _RandGen& rgs, size_t iterationCnt) const
|
|
293
|
+
{
|
|
294
|
+
}
|
|
295
|
+
|
|
296
|
+
/*
|
|
297
|
+
main sampling procedure (can be called one or more by ParallelScheme)
|
|
298
|
+
*/
|
|
299
|
+
template<ParallelScheme _ps, bool _infer, typename _ExtraDocData>
|
|
300
|
+
void sampleDocument(_DocType& doc, const _ExtraDocData& edd, size_t docId, _ModelState& ld, _RandGen& rgs, size_t iterationCnt, size_t partitionId = 0) const
|
|
301
|
+
{
|
|
302
|
+
size_t b = 0, e = doc.words.size();
|
|
303
|
+
if (_ps == ParallelScheme::partition)
|
|
304
|
+
{
|
|
305
|
+
b = edd.chunkOffsetByDoc(partitionId, docId);
|
|
306
|
+
e = edd.chunkOffsetByDoc(partitionId + 1, docId);
|
|
307
|
+
}
|
|
308
|
+
|
|
309
|
+
size_t vOffset = (_ps == ParallelScheme::partition && partitionId) ? edd.vChunkOffset[partitionId - 1] : 0;
|
|
310
|
+
|
|
311
|
+
for (size_t w = b; w < e; ++w)
|
|
312
|
+
{
|
|
313
|
+
if (doc.words[w] >= this->realV) continue;
|
|
314
|
+
addWordTo<-1>(ld, doc, w, doc.words[w] - vOffset, doc.Zs[w]);
|
|
315
|
+
Float* dist;
|
|
316
|
+
if (etaByTopicWord.size())
|
|
317
|
+
{
|
|
318
|
+
dist = static_cast<const DerivedClass*>(this)->template
|
|
319
|
+
getZLikelihoods<true>(ld, doc, docId, doc.words[w] - vOffset);
|
|
320
|
+
}
|
|
321
|
+
else
|
|
322
|
+
{
|
|
323
|
+
dist = static_cast<const DerivedClass*>(this)->template
|
|
324
|
+
getZLikelihoods<false>(ld, doc, docId, doc.words[w] - vOffset);
|
|
325
|
+
}
|
|
326
|
+
doc.Zs[w] = sample::sampleFromDiscreteAcc(dist, dist + K, rgs);
|
|
327
|
+
addWordTo<1>(ld, doc, w, doc.words[w] - vOffset, doc.Zs[w]);
|
|
328
|
+
}
|
|
329
|
+
}
|
|
330
|
+
|
|
331
|
+
template<ParallelScheme _ps, bool _infer, typename _DocIter, typename _ExtraDocData>
|
|
332
|
+
void performSampling(ThreadPool& pool, _ModelState* localData, _RandGen* rgs, std::vector<std::future<void>>& res,
|
|
333
|
+
_DocIter docFirst, _DocIter docLast, const _ExtraDocData& edd) const
|
|
334
|
+
{
|
|
335
|
+
// single-threaded sampling
|
|
336
|
+
if (_ps == ParallelScheme::none)
|
|
337
|
+
{
|
|
338
|
+
forRandom((size_t)std::distance(docFirst, docLast), rgs[0](), [&](size_t id)
|
|
339
|
+
{
|
|
340
|
+
static_cast<const DerivedClass*>(this)->presampleDocument(docFirst[id], id, *localData, *rgs, this->globalStep);
|
|
341
|
+
static_cast<const DerivedClass*>(this)->template sampleDocument<_ps, _infer>(
|
|
342
|
+
docFirst[id], edd, id,
|
|
343
|
+
*localData, *rgs, this->globalStep, 0);
|
|
344
|
+
|
|
345
|
+
});
|
|
346
|
+
}
|
|
347
|
+
// multi-threaded sampling on partition ad update into global
|
|
348
|
+
else if (_ps == ParallelScheme::partition)
|
|
349
|
+
{
|
|
350
|
+
const size_t chStride = pool.getNumWorkers();
|
|
351
|
+
for (size_t i = 0; i < chStride; ++i)
|
|
352
|
+
{
|
|
353
|
+
res = pool.enqueueToAll([&, i, chStride](size_t partitionId)
|
|
354
|
+
{
|
|
355
|
+
size_t didx = (i + partitionId) % chStride;
|
|
356
|
+
forRandom(((size_t)std::distance(docFirst, docLast) + (chStride - 1) - didx) / chStride, rgs[partitionId](), [&](size_t id)
|
|
357
|
+
{
|
|
358
|
+
if (i == 0)
|
|
359
|
+
{
|
|
360
|
+
static_cast<const DerivedClass*>(this)->presampleDocument(
|
|
361
|
+
docFirst[id * chStride + didx], id * chStride + didx,
|
|
362
|
+
localData[partitionId], rgs[partitionId], this->globalStep
|
|
363
|
+
);
|
|
364
|
+
}
|
|
365
|
+
static_cast<const DerivedClass*>(this)->template sampleDocument<_ps, _infer>(
|
|
366
|
+
docFirst[id * chStride + didx], edd, id * chStride + didx,
|
|
367
|
+
localData[partitionId], rgs[partitionId], this->globalStep, partitionId
|
|
368
|
+
);
|
|
369
|
+
});
|
|
370
|
+
});
|
|
371
|
+
for (auto& r : res) r.get();
|
|
372
|
+
res.clear();
|
|
373
|
+
}
|
|
374
|
+
}
|
|
375
|
+
// multi-threaded sampling on copy and merge into global
|
|
376
|
+
else if(_ps == ParallelScheme::copy_merge)
|
|
377
|
+
{
|
|
378
|
+
const size_t chStride = std::min(pool.getNumWorkers() * 8, (size_t)std::distance(docFirst, docLast));
|
|
379
|
+
for (size_t ch = 0; ch < chStride; ++ch)
|
|
380
|
+
{
|
|
381
|
+
res.emplace_back(pool.enqueue([&, ch, chStride](size_t threadId)
|
|
382
|
+
{
|
|
383
|
+
forRandom(((size_t)std::distance(docFirst, docLast) + (chStride - 1) - ch) / chStride, rgs[threadId](), [&](size_t id)
|
|
384
|
+
{
|
|
385
|
+
static_cast<const DerivedClass*>(this)->presampleDocument(
|
|
386
|
+
docFirst[id * chStride + ch], id * chStride + ch,
|
|
387
|
+
localData[threadId], rgs[threadId], this->globalStep
|
|
388
|
+
);
|
|
389
|
+
static_cast<const DerivedClass*>(this)->template sampleDocument<_ps, _infer>(
|
|
390
|
+
docFirst[id * chStride + ch], edd, id * chStride + ch,
|
|
391
|
+
localData[threadId], rgs[threadId], this->globalStep, 0
|
|
392
|
+
);
|
|
393
|
+
});
|
|
394
|
+
}));
|
|
395
|
+
}
|
|
396
|
+
for (auto& r : res) r.get();
|
|
397
|
+
res.clear();
|
|
398
|
+
}
|
|
399
|
+
}
|
|
400
|
+
|
|
401
|
+
template<typename _DocIter, typename _ExtraDocData>
|
|
402
|
+
void updatePartition(ThreadPool& pool, const _ModelState& globalState, _ModelState* localData, _DocIter first, _DocIter last, _ExtraDocData& edd) const
|
|
403
|
+
{
|
|
404
|
+
size_t numPools = pool.getNumWorkers();
|
|
405
|
+
if (edd.vChunkOffset.size() != numPools)
|
|
406
|
+
{
|
|
407
|
+
edd.vChunkOffset.clear();
|
|
408
|
+
size_t totCnt = std::accumulate(this->vocabCf.begin(), this->vocabCf.begin() + this->realV, 0);
|
|
409
|
+
size_t cumCnt = 0;
|
|
410
|
+
for (size_t i = 0; i < this->realV; ++i)
|
|
411
|
+
{
|
|
412
|
+
cumCnt += this->vocabCf[i];
|
|
413
|
+
if (cumCnt * numPools >= totCnt * (edd.vChunkOffset.size() + 1)) edd.vChunkOffset.emplace_back(i + 1);
|
|
414
|
+
}
|
|
415
|
+
|
|
416
|
+
edd.chunkOffsetByDoc.resize(numPools + 1, std::distance(first, last));
|
|
417
|
+
size_t i = 0;
|
|
418
|
+
for (; first != last; ++first, ++i)
|
|
419
|
+
{
|
|
420
|
+
auto& doc = *first;
|
|
421
|
+
edd.chunkOffsetByDoc(0, i) = 0;
|
|
422
|
+
size_t g = 0;
|
|
423
|
+
for (size_t j = 0; j < doc.words.size(); ++j)
|
|
424
|
+
{
|
|
425
|
+
for (; g < numPools && doc.words[j] >= edd.vChunkOffset[g]; ++g)
|
|
426
|
+
{
|
|
427
|
+
edd.chunkOffsetByDoc(g + 1, i) = j;
|
|
428
|
+
}
|
|
429
|
+
}
|
|
430
|
+
for (; g < numPools; ++g)
|
|
431
|
+
{
|
|
432
|
+
edd.chunkOffsetByDoc(g + 1, i) = doc.words.size();
|
|
433
|
+
}
|
|
434
|
+
}
|
|
435
|
+
}
|
|
436
|
+
static_cast<const DerivedClass*>(this)->distributePartition(pool, globalState, localData, edd);
|
|
437
|
+
}
|
|
438
|
+
|
|
439
|
+
template<typename _ExtraDocData>
|
|
440
|
+
void distributePartition(ThreadPool& pool, const _ModelState& globalState, _ModelState* localData, const _ExtraDocData& edd) const
|
|
441
|
+
{
|
|
442
|
+
std::vector<std::future<void>> res = pool.enqueueToAll([&](size_t partitionId)
|
|
443
|
+
{
|
|
444
|
+
size_t b = partitionId ? edd.vChunkOffset[partitionId - 1] : 0,
|
|
445
|
+
e = edd.vChunkOffset[partitionId];
|
|
446
|
+
|
|
447
|
+
localData[partitionId].numByTopicWord = globalState.numByTopicWord.block(0, b, globalState.numByTopicWord.rows(), e - b);
|
|
448
|
+
localData[partitionId].numByTopic = globalState.numByTopic;
|
|
449
|
+
if (!localData[partitionId].zLikelihood.size()) localData[partitionId].zLikelihood = globalState.zLikelihood;
|
|
450
|
+
});
|
|
451
|
+
|
|
452
|
+
for (auto& r : res) r.get();
|
|
453
|
+
}
|
|
454
|
+
|
|
455
|
+
template<ParallelScheme _ps>
|
|
456
|
+
size_t estimateMaxThreads() const
|
|
457
|
+
{
|
|
458
|
+
if (_ps == ParallelScheme::partition)
|
|
459
|
+
{
|
|
460
|
+
return (this->realV + 3) / 4;
|
|
461
|
+
}
|
|
462
|
+
if (_ps == ParallelScheme::copy_merge)
|
|
463
|
+
{
|
|
464
|
+
return (this->docs.size() + 1) / 2;
|
|
465
|
+
}
|
|
466
|
+
return (size_t)-1;
|
|
467
|
+
}
|
|
468
|
+
|
|
469
|
+
template<ParallelScheme _ps>
|
|
470
|
+
void trainOne(ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
|
|
471
|
+
{
|
|
472
|
+
std::vector<std::future<void>> res;
|
|
473
|
+
try
|
|
474
|
+
{
|
|
475
|
+
performSampling<_ps, false>(pool, localData, rgs, res,
|
|
476
|
+
this->docs.begin(), this->docs.end(), eddTrain);
|
|
477
|
+
static_cast<DerivedClass*>(this)->updateGlobalInfo(pool, localData);
|
|
478
|
+
static_cast<DerivedClass*>(this)->template mergeState<_ps>(pool, this->globalState, this->tState, localData, rgs, eddTrain);
|
|
479
|
+
static_cast<DerivedClass*>(this)->template sampleGlobalLevel<>(&pool, localData, rgs, this->docs.begin(), this->docs.end());
|
|
480
|
+
if (this->globalStep >= this->burnIn && optimInterval && (this->globalStep + 1) % optimInterval == 0)
|
|
481
|
+
{
|
|
482
|
+
static_cast<DerivedClass*>(this)->optimizeParameters(pool, localData, rgs);
|
|
483
|
+
}
|
|
484
|
+
}
|
|
485
|
+
catch (const exception::TrainingError&)
|
|
486
|
+
{
|
|
487
|
+
for (auto& r : res) if(r.valid()) r.get();
|
|
488
|
+
throw;
|
|
489
|
+
}
|
|
490
|
+
}
|
|
491
|
+
|
|
492
|
+
/*
|
|
493
|
+
updates global informations after sampling documents
|
|
494
|
+
ex) update new global K at HDP model
|
|
495
|
+
*/
|
|
496
|
+
void updateGlobalInfo(ThreadPool& pool, _ModelState* localData)
|
|
497
|
+
{
|
|
498
|
+
}
|
|
499
|
+
|
|
500
|
+
/*
|
|
501
|
+
merges multithreaded document sampling result
|
|
502
|
+
*/
|
|
503
|
+
template<ParallelScheme _ps, typename _ExtraDocData>
|
|
504
|
+
void mergeState(ThreadPool& pool, _ModelState& globalState, _ModelState& tState, _ModelState* localData, _RandGen*, const _ExtraDocData& edd) const
|
|
505
|
+
{
|
|
506
|
+
std::vector<std::future<void>> res;
|
|
507
|
+
|
|
508
|
+
if (_ps == ParallelScheme::copy_merge)
|
|
509
|
+
{
|
|
510
|
+
tState = globalState;
|
|
511
|
+
globalState = localData[0];
|
|
512
|
+
for (size_t i = 1; i < pool.getNumWorkers(); ++i)
|
|
513
|
+
{
|
|
514
|
+
globalState.numByTopicWord += localData[i].numByTopicWord - tState.numByTopicWord;
|
|
515
|
+
}
|
|
516
|
+
|
|
517
|
+
// make all count being positive
|
|
518
|
+
if (_tw != TermWeight::one)
|
|
519
|
+
{
|
|
520
|
+
globalState.numByTopicWord = globalState.numByTopicWord.cwiseMax(0);
|
|
521
|
+
}
|
|
522
|
+
globalState.numByTopic = globalState.numByTopicWord.rowwise().sum();
|
|
523
|
+
|
|
524
|
+
for (size_t i = 0; i < pool.getNumWorkers(); ++i)
|
|
525
|
+
{
|
|
526
|
+
res.emplace_back(pool.enqueue([&, i](size_t)
|
|
527
|
+
{
|
|
528
|
+
localData[i] = globalState;
|
|
529
|
+
}));
|
|
530
|
+
}
|
|
531
|
+
}
|
|
532
|
+
else if (_ps == ParallelScheme::partition)
|
|
533
|
+
{
|
|
534
|
+
res = pool.enqueueToAll([&](size_t partitionId)
|
|
535
|
+
{
|
|
536
|
+
size_t b = partitionId ? edd.vChunkOffset[partitionId - 1] : 0,
|
|
537
|
+
e = edd.vChunkOffset[partitionId];
|
|
538
|
+
globalState.numByTopicWord.block(0, b, globalState.numByTopicWord.rows(), e - b) = localData[partitionId].numByTopicWord;
|
|
539
|
+
});
|
|
540
|
+
for (auto& r : res) r.get();
|
|
541
|
+
res.clear();
|
|
542
|
+
|
|
543
|
+
// make all count being positive
|
|
544
|
+
if (_tw != TermWeight::one)
|
|
545
|
+
{
|
|
546
|
+
globalState.numByTopicWord = globalState.numByTopicWord.cwiseMax(0);
|
|
547
|
+
}
|
|
548
|
+
globalState.numByTopic = globalState.numByTopicWord.rowwise().sum();
|
|
549
|
+
|
|
550
|
+
res = pool.enqueueToAll([&](size_t threadId)
|
|
551
|
+
{
|
|
552
|
+
localData[threadId].numByTopic = globalState.numByTopic;
|
|
553
|
+
});
|
|
554
|
+
}
|
|
555
|
+
for (auto& r : res) r.get();
|
|
556
|
+
}
|
|
557
|
+
|
|
558
|
+
/*
|
|
559
|
+
performs sampling which needs global state modification
|
|
560
|
+
ex) document pathing at hLDA model
|
|
561
|
+
* if pool is nullptr, workers has been already pooled and cannot branch works more.
|
|
562
|
+
*/
|
|
563
|
+
template<typename _DocIter>
|
|
564
|
+
void sampleGlobalLevel(ThreadPool* pool, _ModelState* localData, _RandGen* rgs, _DocIter first, _DocIter last) const
|
|
565
|
+
{
|
|
566
|
+
}
|
|
567
|
+
|
|
568
|
+
template<typename _DocIter>
|
|
569
|
+
void sampleGlobalLevel(ThreadPool* pool, _ModelState* localData, _RandGen* rgs, _DocIter first, _DocIter last)
|
|
570
|
+
{
|
|
571
|
+
}
|
|
572
|
+
|
|
573
|
+
template<typename _DocIter>
|
|
574
|
+
double getLLDocs(_DocIter _first, _DocIter _last) const
|
|
575
|
+
{
|
|
576
|
+
double ll = 0;
|
|
577
|
+
// doc-topic distribution
|
|
578
|
+
for (; _first != _last; ++_first)
|
|
579
|
+
{
|
|
580
|
+
auto& doc = *_first;
|
|
581
|
+
ll -= math::lgammaT(doc.getSumWordWeight() + alphas.sum()) - math::lgammaT(alphas.sum());
|
|
582
|
+
for (Tid k = 0; k < K; ++k)
|
|
583
|
+
{
|
|
584
|
+
ll += math::lgammaT(doc.numByTopic[k] + alphas[k]) - math::lgammaT(alphas[k]);
|
|
585
|
+
}
|
|
586
|
+
}
|
|
587
|
+
return ll;
|
|
588
|
+
}
|
|
589
|
+
|
|
590
|
+
double getLLRest(const _ModelState& ld) const
|
|
591
|
+
{
|
|
592
|
+
double ll = 0;
|
|
593
|
+
const size_t V = this->realV;
|
|
594
|
+
// topic-word distribution
|
|
595
|
+
auto lgammaEta = math::lgammaT(eta);
|
|
596
|
+
ll += math::lgammaT(V*eta) * K;
|
|
597
|
+
for (Tid k = 0; k < K; ++k)
|
|
598
|
+
{
|
|
599
|
+
ll -= math::lgammaT(ld.numByTopic[k] + V * eta);
|
|
600
|
+
for (Vid v = 0; v < V; ++v)
|
|
601
|
+
{
|
|
602
|
+
if (!ld.numByTopicWord(k, v)) continue;
|
|
603
|
+
ll += math::lgammaT(ld.numByTopicWord(k, v) + eta) - lgammaEta;
|
|
604
|
+
assert(std::isfinite(ll));
|
|
605
|
+
}
|
|
606
|
+
}
|
|
607
|
+
return ll;
|
|
608
|
+
}
|
|
609
|
+
|
|
610
|
+
double getLL() const
|
|
611
|
+
{
|
|
612
|
+
return static_cast<const DerivedClass*>(this)->template getLLDocs<>(this->docs.begin(), this->docs.end())
|
|
613
|
+
+ static_cast<const DerivedClass*>(this)->getLLRest(this->globalState);
|
|
614
|
+
}
|
|
615
|
+
|
|
616
|
+
void prepareShared()
|
|
617
|
+
{
|
|
618
|
+
auto txZs = [](_DocType& doc) { return &doc.Zs; };
|
|
619
|
+
tvector<Tid>::trade(sharedZs,
|
|
620
|
+
makeTransformIter(this->docs.begin(), txZs),
|
|
621
|
+
makeTransformIter(this->docs.end(), txZs));
|
|
622
|
+
if (_tw != TermWeight::one)
|
|
623
|
+
{
|
|
624
|
+
auto txWeights = [](_DocType& doc) { return &doc.wordWeights; };
|
|
625
|
+
tvector<Float>::trade(sharedWordWeights,
|
|
626
|
+
makeTransformIter(this->docs.begin(), txWeights),
|
|
627
|
+
makeTransformIter(this->docs.end(), txWeights));
|
|
628
|
+
}
|
|
629
|
+
}
|
|
630
|
+
|
|
631
|
+
// Pointer into the contiguous per-document topic-count matrix; nullptr when
// continuous storage is disabled or docId is the invalid marker (size_t)-1.
WeightType* getTopicDocPtr(size_t docId) const
{
	if (!(m_flags & flags::continuous_doc_data) || docId == (size_t)-1) return nullptr;
	return (WeightType*)numByTopicDoc.col(docId).data();
}
|
|
636
|
+
|
|
637
|
+
// Initializes per-document buffers: sorted word order, topic counts, topic
// assignments and (for weighted schemes) word weights.
void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
{
	sortAndWriteOrder(doc.words, doc.wOrder);
	doc.numByTopic.init(getTopicDocPtr(docId), K);
	doc.Zs = tvector<Tid>(wordSize);
	if(_tw != TermWeight::one) doc.wordWeights.resize(wordSize, 1);
}
|
|
644
|
+
|
|
645
|
+
void prepareWordPriors()
|
|
646
|
+
{
|
|
647
|
+
if (etaByWord.empty()) return;
|
|
648
|
+
etaByTopicWord.resize(K, this->realV);
|
|
649
|
+
etaSumByTopic.resize(K);
|
|
650
|
+
etaByTopicWord.array() = eta;
|
|
651
|
+
for (auto& it : etaByWord)
|
|
652
|
+
{
|
|
653
|
+
auto id = this->dict.toWid(it.first);
|
|
654
|
+
if (id == (Vid)-1 || id >= this->realV) continue;
|
|
655
|
+
etaByTopicWord.col(id) = Eigen::Map<Eigen::Matrix<Float, -1, 1>>{ it.second.data(), (Eigen::Index)it.second.size() };
|
|
656
|
+
}
|
|
657
|
+
etaSumByTopic = etaByTopicWord.rowwise().sum();
|
|
658
|
+
}
|
|
659
|
+
|
|
660
|
+
void initGlobalState(bool initDocs)
|
|
661
|
+
{
|
|
662
|
+
const size_t V = this->realV;
|
|
663
|
+
this->globalState.zLikelihood = Eigen::Matrix<Float, -1, 1>::Zero(K);
|
|
664
|
+
if (initDocs)
|
|
665
|
+
{
|
|
666
|
+
this->globalState.numByTopic = Eigen::Matrix<WeightType, -1, 1>::Zero(K);
|
|
667
|
+
this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(K, V);
|
|
668
|
+
}
|
|
669
|
+
if(m_flags & flags::continuous_doc_data) numByTopicDoc = Eigen::Matrix<WeightType, -1, -1>::Zero(K, this->docs.size());
|
|
670
|
+
}
|
|
671
|
+
|
|
672
|
+
struct Generator
|
|
673
|
+
{
|
|
674
|
+
std::uniform_int_distribution<Tid> theta;
|
|
675
|
+
};
|
|
676
|
+
|
|
677
|
+
// Creates a generator that draws initial topics uniformly from [0, K).
Generator makeGeneratorForInit(const _DocType*) const
{
	return Generator{ std::uniform_int_distribution<Tid>{0, (Tid)(K - 1)} };
}
|
|
681
|
+
|
|
682
|
+
template<bool _Infer>
|
|
683
|
+
void updateStateWithDoc(Generator& g, _ModelState& ld, _RandGen& rgs, _DocType& doc, size_t i) const
|
|
684
|
+
{
|
|
685
|
+
auto& z = doc.Zs[i];
|
|
686
|
+
auto w = doc.words[i];
|
|
687
|
+
if (etaByTopicWord.size())
|
|
688
|
+
{
|
|
689
|
+
auto col = etaByTopicWord.col(w);
|
|
690
|
+
z = sample::sampleFromDiscrete(col.data(), col.data() + col.size(), rgs);
|
|
691
|
+
}
|
|
692
|
+
else
|
|
693
|
+
{
|
|
694
|
+
z = g.theta(rgs);
|
|
695
|
+
}
|
|
696
|
+
addWordTo<1>(ld, doc, i, w, z);
|
|
697
|
+
}
|
|
698
|
+
|
|
699
|
+
template<bool _Infer, typename _Generator>
|
|
700
|
+
void initializeDocState(_DocType& doc, size_t docId, _Generator& g, _ModelState& ld, _RandGen& rgs) const
|
|
701
|
+
{
|
|
702
|
+
std::vector<uint32_t> tf(this->realV);
|
|
703
|
+
static_cast<const DerivedClass*>(this)->prepareDoc(doc, docId, doc.words.size());
|
|
704
|
+
_Generator g2;
|
|
705
|
+
_Generator* selectedG = &g;
|
|
706
|
+
if (m_flags & flags::generator_by_doc)
|
|
707
|
+
{
|
|
708
|
+
g2 = static_cast<const DerivedClass*>(this)->makeGeneratorForInit(&doc);
|
|
709
|
+
selectedG = &g2;
|
|
710
|
+
}
|
|
711
|
+
if (_tw == TermWeight::pmi)
|
|
712
|
+
{
|
|
713
|
+
std::fill(tf.begin(), tf.end(), 0);
|
|
714
|
+
for (auto& w : doc.words) if(w < this->realV) ++tf[w];
|
|
715
|
+
}
|
|
716
|
+
|
|
717
|
+
for (size_t i = 0; i < doc.words.size(); ++i)
|
|
718
|
+
{
|
|
719
|
+
if (doc.words[i] >= this->realV) continue;
|
|
720
|
+
if (_tw == TermWeight::idf)
|
|
721
|
+
{
|
|
722
|
+
doc.wordWeights[i] = vocabWeights[doc.words[i]];
|
|
723
|
+
}
|
|
724
|
+
else if (_tw == TermWeight::pmi)
|
|
725
|
+
{
|
|
726
|
+
doc.wordWeights[i] = std::max((Float)log(tf[doc.words[i]] / vocabWeights[doc.words[i]] / doc.words.size()), (Float)0);
|
|
727
|
+
}
|
|
728
|
+
static_cast<const DerivedClass*>(this)->template updateStateWithDoc<_Infer>(*selectedG, ld, rgs, doc, i);
|
|
729
|
+
}
|
|
730
|
+
doc.updateSumWordWeight(this->realV);
|
|
731
|
+
}
|
|
732
|
+
|
|
733
|
+
std::vector<uint64_t> _getTopicsCount() const
|
|
734
|
+
{
|
|
735
|
+
std::vector<uint64_t> cnt(K);
|
|
736
|
+
for (auto& doc : this->docs)
|
|
737
|
+
{
|
|
738
|
+
for (size_t i = 0; i < doc.Zs.size(); ++i)
|
|
739
|
+
{
|
|
740
|
+
if (doc.words[i] < this->realV) ++cnt[doc.Zs[i]];
|
|
741
|
+
}
|
|
742
|
+
}
|
|
743
|
+
return cnt;
|
|
744
|
+
}
|
|
745
|
+
|
|
746
|
+
std::vector<Float> _getWidsByTopic(size_t tid) const
|
|
747
|
+
{
|
|
748
|
+
assert(tid < this->globalState.numByTopic.rows());
|
|
749
|
+
const size_t V = this->realV;
|
|
750
|
+
std::vector<Float> ret(V);
|
|
751
|
+
Float sum = this->globalState.numByTopic[tid] + V * eta;
|
|
752
|
+
auto r = this->globalState.numByTopicWord.row(tid);
|
|
753
|
+
for (size_t v = 0; v < V; ++v)
|
|
754
|
+
{
|
|
755
|
+
ret[v] = (r[v] + eta) / sum;
|
|
756
|
+
}
|
|
757
|
+
return ret;
|
|
758
|
+
}
|
|
759
|
+
|
|
760
|
+
template<bool _Together, ParallelScheme _ps, typename _Iter>
|
|
761
|
+
std::vector<double> _infer(_Iter docFirst, _Iter docLast, size_t maxIter, Float tolerance, size_t numWorkers) const
|
|
762
|
+
{
|
|
763
|
+
decltype(static_cast<const DerivedClass*>(this)->makeGeneratorForInit(nullptr)) generator;
|
|
764
|
+
if (!(m_flags & flags::generator_by_doc))
|
|
765
|
+
{
|
|
766
|
+
generator = static_cast<const DerivedClass*>(this)->makeGeneratorForInit(nullptr);
|
|
767
|
+
}
|
|
768
|
+
|
|
769
|
+
if (_Together)
|
|
770
|
+
{
|
|
771
|
+
numWorkers = std::min(numWorkers, this->maxThreads[(size_t)_ps]);
|
|
772
|
+
ThreadPool pool{ numWorkers };
|
|
773
|
+
// temporary state variable
|
|
774
|
+
_RandGen rgc{};
|
|
775
|
+
auto tmpState = this->globalState, tState = this->globalState;
|
|
776
|
+
for (auto d = docFirst; d != docLast; ++d)
|
|
777
|
+
{
|
|
778
|
+
initializeDocState<true>(*d, -1, generator, tmpState, rgc);
|
|
779
|
+
}
|
|
780
|
+
|
|
781
|
+
std::vector<decltype(tmpState)> localData((m_flags & flags::shared_state) ? 0 : pool.getNumWorkers(), tmpState);
|
|
782
|
+
std::vector<_RandGen> rgs;
|
|
783
|
+
for (size_t i = 0; i < pool.getNumWorkers(); ++i) rgs.emplace_back(rgc());
|
|
784
|
+
|
|
785
|
+
ExtraDocData edd;
|
|
786
|
+
if (_ps == ParallelScheme::partition)
|
|
787
|
+
{
|
|
788
|
+
updatePartition(pool, tmpState, localData.data(), docFirst, docLast, edd);
|
|
789
|
+
}
|
|
790
|
+
|
|
791
|
+
for (size_t i = 0; i < maxIter; ++i)
|
|
792
|
+
{
|
|
793
|
+
std::vector<std::future<void>> res;
|
|
794
|
+
performSampling<_ps, true>(pool,
|
|
795
|
+
(m_flags & flags::shared_state) ? &tmpState : localData.data(), rgs.data(), res,
|
|
796
|
+
docFirst, docLast, edd);
|
|
797
|
+
static_cast<const DerivedClass*>(this)->template mergeState<_ps>(pool, tmpState, tState, localData.data(), rgs.data(), edd);
|
|
798
|
+
static_cast<const DerivedClass*>(this)->template sampleGlobalLevel<>(
|
|
799
|
+
&pool, (m_flags & flags::shared_state) ? &tmpState : localData.data(), rgs.data(), docFirst, docLast);
|
|
800
|
+
}
|
|
801
|
+
double ll = static_cast<const DerivedClass*>(this)->getLLRest(tmpState) - static_cast<const DerivedClass*>(this)->getLLRest(this->globalState);
|
|
802
|
+
ll += static_cast<const DerivedClass*>(this)->template getLLDocs<>(docFirst, docLast);
|
|
803
|
+
return { ll };
|
|
804
|
+
}
|
|
805
|
+
else if (m_flags & flags::shared_state)
|
|
806
|
+
{
|
|
807
|
+
ThreadPool pool{ numWorkers };
|
|
808
|
+
ExtraDocData edd;
|
|
809
|
+
std::vector<double> ret;
|
|
810
|
+
const double gllRest = static_cast<const DerivedClass*>(this)->getLLRest(this->globalState);
|
|
811
|
+
for (auto d = docFirst; d != docLast; ++d)
|
|
812
|
+
{
|
|
813
|
+
_RandGen rgc{};
|
|
814
|
+
auto tmpState = this->globalState;
|
|
815
|
+
initializeDocState<true>(*d, -1, generator, tmpState, rgc);
|
|
816
|
+
for (size_t i = 0; i < maxIter; ++i)
|
|
817
|
+
{
|
|
818
|
+
static_cast<const DerivedClass*>(this)->presampleDocument(*d, -1, tmpState, rgc, i);
|
|
819
|
+
static_cast<const DerivedClass*>(this)->template sampleDocument<ParallelScheme::none, true>(*d, edd, -1, tmpState, rgc, i);
|
|
820
|
+
static_cast<const DerivedClass*>(this)->template sampleGlobalLevel<>(
|
|
821
|
+
&pool, &tmpState, &rgc, &*d, &*d + 1);
|
|
822
|
+
}
|
|
823
|
+
double ll = static_cast<const DerivedClass*>(this)->getLLRest(tmpState) - gllRest;
|
|
824
|
+
ll += static_cast<const DerivedClass*>(this)->template getLLDocs<>(&*d, &*d + 1);
|
|
825
|
+
ret.emplace_back(ll);
|
|
826
|
+
}
|
|
827
|
+
return ret;
|
|
828
|
+
}
|
|
829
|
+
else
|
|
830
|
+
{
|
|
831
|
+
ThreadPool pool{ numWorkers, numWorkers * 8 };
|
|
832
|
+
ExtraDocData edd;
|
|
833
|
+
std::vector<std::future<double>> res;
|
|
834
|
+
const double gllRest = static_cast<const DerivedClass*>(this)->getLLRest(this->globalState);
|
|
835
|
+
for (auto d = docFirst; d != docLast; ++d)
|
|
836
|
+
{
|
|
837
|
+
res.emplace_back(pool.enqueue([&, d](size_t threadId)
|
|
838
|
+
{
|
|
839
|
+
_RandGen rgc{};
|
|
840
|
+
auto tmpState = this->globalState;
|
|
841
|
+
initializeDocState<true>(*d, -1, generator, tmpState, rgc);
|
|
842
|
+
for (size_t i = 0; i < maxIter; ++i)
|
|
843
|
+
{
|
|
844
|
+
static_cast<const DerivedClass*>(this)->presampleDocument(*d, -1, tmpState, rgc, i);
|
|
845
|
+
static_cast<const DerivedClass*>(this)->template sampleDocument<ParallelScheme::none, true>(
|
|
846
|
+
*d, edd, -1, tmpState, rgc, i
|
|
847
|
+
);
|
|
848
|
+
static_cast<const DerivedClass*>(this)->template sampleGlobalLevel<>(
|
|
849
|
+
nullptr, &tmpState, &rgc, &*d, &*d + 1
|
|
850
|
+
);
|
|
851
|
+
}
|
|
852
|
+
double ll = static_cast<const DerivedClass*>(this)->getLLRest(tmpState) - gllRest;
|
|
853
|
+
ll += static_cast<const DerivedClass*>(this)->template getLLDocs<>(&*d, &*d + 1);
|
|
854
|
+
return ll;
|
|
855
|
+
}));
|
|
856
|
+
}
|
|
857
|
+
std::vector<double> ret;
|
|
858
|
+
for (auto& r : res) ret.emplace_back(r.get());
|
|
859
|
+
return ret;
|
|
860
|
+
}
|
|
861
|
+
}
|
|
862
|
+
|
|
863
|
+
public:
|
|
864
|
+
DEFINE_SERIALIZER_WITH_VERSION(0, vocabWeights, alpha, alphas, eta, K);
|
|
865
|
+
|
|
866
|
+
DEFINE_TAGGED_SERIALIZER_WITH_VERSION(1, 0x00010001, vocabWeights, alpha, alphas, eta, K, etaByWord,
|
|
867
|
+
burnIn, optimInterval);
|
|
868
|
+
|
|
869
|
+
LDAModel(size_t _K = 1, Float _alpha = 0.1, Float _eta = 0.01, size_t _rg = std::random_device{}())
|
|
870
|
+
: BaseClass(_rg), K(_K), alpha(_alpha), eta(_eta)
|
|
871
|
+
{
|
|
872
|
+
if (_K == 0 || _K >= 0x80000000) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong K value (K = %zd)", _K));
|
|
873
|
+
if (_alpha <= 0) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong alpha value (alpha = %f)", _alpha));
|
|
874
|
+
if (_eta <= 0) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong eta value (eta = %f)", _eta));
|
|
875
|
+
alphas = Eigen::Matrix<Float, -1, 1>::Constant(K, alpha);
|
|
876
|
+
}
|
|
877
|
+
|
|
878
|
+
// Read-only accessors generated by the GETTER macro.
GETTER(K, size_t, K);
GETTER(Alpha, Float, alpha);
GETTER(Eta, Float, eta);
GETTER(OptimInterval, size_t, optimInterval);
GETTER(BurnInIteration, size_t, burnIn);
|
|
883
|
+
|
|
884
|
+
// Per-topic alpha; may differ across topics after hyperparameter optimization.
Float getAlpha(size_t k1) const override { return alphas[k1]; }
|
|
885
|
+
|
|
886
|
+
TermWeight getTermWeight() const override
|
|
887
|
+
{
|
|
888
|
+
return _tw;
|
|
889
|
+
}
|
|
890
|
+
|
|
891
|
+
void setOptimInterval(size_t _optimInterval) override
|
|
892
|
+
{
|
|
893
|
+
optimInterval = _optimInterval;
|
|
894
|
+
}
|
|
895
|
+
|
|
896
|
+
void setBurnInIteration(size_t iteration) override
|
|
897
|
+
{
|
|
898
|
+
burnIn = iteration;
|
|
899
|
+
}
|
|
900
|
+
|
|
901
|
+
size_t addDoc(const std::vector<std::string>& words) override
|
|
902
|
+
{
|
|
903
|
+
return this->_addDoc(this->_makeDoc(words));
|
|
904
|
+
}
|
|
905
|
+
|
|
906
|
+
std::unique_ptr<DocumentBase> makeDoc(const std::vector<std::string>& words) const override
|
|
907
|
+
{
|
|
908
|
+
return make_unique<_DocType>(as_mutable(this)->template _makeDoc<true>(words));
|
|
909
|
+
}
|
|
910
|
+
|
|
911
|
+
size_t addDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer) override
|
|
912
|
+
{
|
|
913
|
+
return this->_addDoc(this->template _makeRawDoc<false>(rawStr, tokenizer));
|
|
914
|
+
}
|
|
915
|
+
|
|
916
|
+
std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer) const override
|
|
917
|
+
{
|
|
918
|
+
return make_unique<_DocType>(as_mutable(this)->template _makeRawDoc<true>(rawStr, tokenizer));
|
|
919
|
+
}
|
|
920
|
+
|
|
921
|
+
// Pre-tokenized raw-document overload: `words` are already vocabulary ids;
// pos/len appear to record each token's offset and length inside rawStr --
// confirm against _makeRawDoc. Appends the document and returns its index.
size_t addDoc(const std::string& rawStr, const std::vector<Vid>& words,
    const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len) override
{
    return this->_addDoc(this->_makeRawDoc(rawStr, words, pos, len));
}
|
|
926
|
+
|
|
927
|
+
// Standalone counterpart of the pre-tokenized addDoc overload: builds the
// document without registering it. No as_mutable needed here because this
// _makeRawDoc overload is callable on a const model.
std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const std::vector<Vid>& words,
    const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len) const override
{
    return make_unique<_DocType>(this->_makeRawDoc(rawStr, words, pos, len));
}
|
|
932
|
+
|
|
933
|
+
void setWordPrior(const std::string& word, const std::vector<Float>& priors) override
|
|
934
|
+
{
|
|
935
|
+
if (priors.size() != K) THROW_ERROR_WITH_INFO(exception::InvalidArgument, "priors.size() must be equal to K.");
|
|
936
|
+
for (auto p : priors)
|
|
937
|
+
{
|
|
938
|
+
if (p < 0) THROW_ERROR_WITH_INFO(exception::InvalidArgument, "priors must not be less than 0.");
|
|
939
|
+
}
|
|
940
|
+
this->dict.add(word);
|
|
941
|
+
etaByWord.emplace(word, priors);
|
|
942
|
+
}
|
|
943
|
+
|
|
944
|
+
std::vector<Float> getWordPrior(const std::string& word) const override
|
|
945
|
+
{
|
|
946
|
+
if (etaByTopicWord.size())
|
|
947
|
+
{
|
|
948
|
+
auto id = this->dict.toWid(word);
|
|
949
|
+
if (id == (Vid)-1) return {};
|
|
950
|
+
auto col = etaByTopicWord.col(id);
|
|
951
|
+
return std::vector<Float>{ col.data(), col.data() + col.size() };
|
|
952
|
+
}
|
|
953
|
+
else
|
|
954
|
+
{
|
|
955
|
+
auto it = etaByWord.find(word);
|
|
956
|
+
if (it == etaByWord.end()) return {};
|
|
957
|
+
return it->second;
|
|
958
|
+
}
|
|
959
|
+
}
|
|
960
|
+
|
|
961
|
+
void updateDocs()
|
|
962
|
+
{
|
|
963
|
+
size_t docId = 0;
|
|
964
|
+
for (auto& doc : this->docs)
|
|
965
|
+
{
|
|
966
|
+
doc.template update<>(getTopicDocPtr(docId++), *static_cast<DerivedClass*>(this));
|
|
967
|
+
}
|
|
968
|
+
}
|
|
969
|
+
|
|
970
|
+
// Finalizes the vocabulary and builds the sampling state.
// initDocs == true  : fresh training run -- stopwords are removed, term
//                     weights are computed, and every document's topic
//                     assignments are initialized.
// initDocs == false : re-preparation (e.g. after deserialization) -- existing
//                     documents are only re-bound to model storage.
void prepare(bool initDocs = true, size_t minWordCnt = 0, size_t minWordDf = 0, size_t removeTopN = 0) override
{
    if (initDocs) this->removeStopwords(minWordCnt, minWordDf, removeTopN);
    // CRTP dispatch into the derived model's setup hooks.
    static_cast<DerivedClass*>(this)->updateWeakArray();
    static_cast<DerivedClass*>(this)->initGlobalState(initDocs);
    static_cast<DerivedClass*>(this)->prepareWordPriors();

    const size_t V = this->realV;

    if (initDocs)
    {
        // NOTE(review): df and totCf are only filled when _tw != one; the
        // idf/pmi branches below rely on that, since both imply _tw != one.
        std::vector<uint32_t> df, cf, tf;
        uint32_t totCf;

        // calculate weighting
        if (_tw != TermWeight::one)
        {
            df.resize(V);
            tf.resize(V);
            // Document frequency: count each vocab id at most once per doc
            // by deduplicating the word list through an unordered_set.
            for (auto& doc : this->docs)
            {
                for (auto w : std::unordered_set<Vid>{ doc.words.begin(), doc.words.end() })
                {
                    if (w >= this->realV) continue;  // skip removed/raw-only words
                    ++df[w];
                }
            }
            totCf = accumulate(this->vocabCf.begin(), this->vocabCf.end(), 0);
        }
        if (_tw == TermWeight::idf)
        {
            vocabWeights.resize(V);
            for (size_t i = 0; i < V; ++i)
            {
                // classic idf: log(N / df)
                vocabWeights[i] = log(this->docs.size() / (Float)df[i]);
            }
        }
        else if (_tw == TermWeight::pmi)
        {
            vocabWeights.resize(V);
            for (size_t i = 0; i < V; ++i)
            {
                // relative corpus frequency; presumably the PMI combination
                // happens at sampling time -- confirm in the sampler.
                vocabWeights[i] = this->vocabCf[i] / (float)totCf;
            }
        }

        // One shared generator unless the derived class requested a
        // per-document generator via flags::generator_by_doc.
        decltype(static_cast<DerivedClass*>(this)->makeGeneratorForInit(nullptr)) generator;
        if(!(m_flags & flags::generator_by_doc)) generator = static_cast<DerivedClass*>(this)->makeGeneratorForInit(nullptr);
        for (auto& doc : this->docs)
        {
            // &doc - &docs[0] recovers the document index from the range-for.
            initializeDocState<false>(doc, &doc - &this->docs[0], generator, this->globalState, this->rg);
        }
    }
    else
    {
        static_cast<DerivedClass*>(this)->updateDocs();
        for (auto& doc : this->docs) doc.updateSumWordWeight(this->realV);
    }
    static_cast<DerivedClass*>(this)->prepareShared();
    BaseClass::prepare(initDocs, minWordCnt, minWordDf, removeTopN);
}
|
|
1031
|
+
|
|
1032
|
+
std::vector<uint64_t> getCountByTopic() const override
|
|
1033
|
+
{
|
|
1034
|
+
return static_cast<const DerivedClass*>(this)->_getTopicsCount();
|
|
1035
|
+
}
|
|
1036
|
+
|
|
1037
|
+
std::vector<Float> getTopicsByDoc(const _DocType& doc) const
|
|
1038
|
+
{
|
|
1039
|
+
std::vector<Float> ret(K);
|
|
1040
|
+
Eigen::Map<Eigen::Matrix<Float, -1, 1>> { ret.data(), K }.array() =
|
|
1041
|
+
(doc.numByTopic.array().template cast<Float>() + alphas.array()) / (doc.getSumWordWeight() + alphas.sum());
|
|
1042
|
+
return ret;
|
|
1043
|
+
}
|
|
1044
|
+
|
|
1045
|
+
};
|
|
1046
|
+
|
|
1047
|
+
template<TermWeight _tw>
|
|
1048
|
+
template<typename _TopicModel>
|
|
1049
|
+
void DocumentLDA<_tw>::update(WeightType* ptr, const _TopicModel& mdl)
|
|
1050
|
+
{
|
|
1051
|
+
numByTopic.init(ptr, mdl.getK());
|
|
1052
|
+
for (size_t i = 0; i < Zs.size(); ++i)
|
|
1053
|
+
{
|
|
1054
|
+
if (this->words[i] >= mdl.getV()) continue;
|
|
1055
|
+
numByTopic[Zs[i]] += _tw != TermWeight::one ? wordWeights[i] : 1;
|
|
1056
|
+
}
|
|
1057
|
+
}
|
|
1058
|
+
}
|