tomoto 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +22 -0
- data/README.md +123 -0
- data/ext/tomoto/ext.cpp +245 -0
- data/ext/tomoto/extconf.rb +28 -0
- data/lib/tomoto.rb +12 -0
- data/lib/tomoto/ct.rb +11 -0
- data/lib/tomoto/hdp.rb +11 -0
- data/lib/tomoto/lda.rb +67 -0
- data/lib/tomoto/version.rb +3 -0
- data/vendor/EigenRand/EigenRand/Core.h +1139 -0
- data/vendor/EigenRand/EigenRand/Dists/Basic.h +111 -0
- data/vendor/EigenRand/EigenRand/Dists/Discrete.h +877 -0
- data/vendor/EigenRand/EigenRand/Dists/GammaPoisson.h +108 -0
- data/vendor/EigenRand/EigenRand/Dists/NormalExp.h +626 -0
- data/vendor/EigenRand/EigenRand/EigenRand +19 -0
- data/vendor/EigenRand/EigenRand/Macro.h +24 -0
- data/vendor/EigenRand/EigenRand/MorePacketMath.h +978 -0
- data/vendor/EigenRand/EigenRand/PacketFilter.h +286 -0
- data/vendor/EigenRand/EigenRand/PacketRandomEngine.h +624 -0
- data/vendor/EigenRand/EigenRand/RandUtils.h +413 -0
- data/vendor/EigenRand/EigenRand/doc.h +220 -0
- data/vendor/EigenRand/LICENSE +21 -0
- data/vendor/EigenRand/README.md +288 -0
- data/vendor/eigen/COPYING.BSD +26 -0
- data/vendor/eigen/COPYING.GPL +674 -0
- data/vendor/eigen/COPYING.LGPL +502 -0
- data/vendor/eigen/COPYING.MINPACK +52 -0
- data/vendor/eigen/COPYING.MPL2 +373 -0
- data/vendor/eigen/COPYING.README +18 -0
- data/vendor/eigen/Eigen/CMakeLists.txt +19 -0
- data/vendor/eigen/Eigen/Cholesky +46 -0
- data/vendor/eigen/Eigen/CholmodSupport +48 -0
- data/vendor/eigen/Eigen/Core +537 -0
- data/vendor/eigen/Eigen/Dense +7 -0
- data/vendor/eigen/Eigen/Eigen +2 -0
- data/vendor/eigen/Eigen/Eigenvalues +61 -0
- data/vendor/eigen/Eigen/Geometry +62 -0
- data/vendor/eigen/Eigen/Householder +30 -0
- data/vendor/eigen/Eigen/IterativeLinearSolvers +48 -0
- data/vendor/eigen/Eigen/Jacobi +33 -0
- data/vendor/eigen/Eigen/LU +50 -0
- data/vendor/eigen/Eigen/MetisSupport +35 -0
- data/vendor/eigen/Eigen/OrderingMethods +73 -0
- data/vendor/eigen/Eigen/PaStiXSupport +48 -0
- data/vendor/eigen/Eigen/PardisoSupport +35 -0
- data/vendor/eigen/Eigen/QR +51 -0
- data/vendor/eigen/Eigen/QtAlignedMalloc +40 -0
- data/vendor/eigen/Eigen/SPQRSupport +34 -0
- data/vendor/eigen/Eigen/SVD +51 -0
- data/vendor/eigen/Eigen/Sparse +36 -0
- data/vendor/eigen/Eigen/SparseCholesky +45 -0
- data/vendor/eigen/Eigen/SparseCore +69 -0
- data/vendor/eigen/Eigen/SparseLU +46 -0
- data/vendor/eigen/Eigen/SparseQR +37 -0
- data/vendor/eigen/Eigen/StdDeque +27 -0
- data/vendor/eigen/Eigen/StdList +26 -0
- data/vendor/eigen/Eigen/StdVector +27 -0
- data/vendor/eigen/Eigen/SuperLUSupport +64 -0
- data/vendor/eigen/Eigen/UmfPackSupport +40 -0
- data/vendor/eigen/Eigen/src/Cholesky/LDLT.h +673 -0
- data/vendor/eigen/Eigen/src/Cholesky/LLT.h +542 -0
- data/vendor/eigen/Eigen/src/Cholesky/LLT_LAPACKE.h +99 -0
- data/vendor/eigen/Eigen/src/CholmodSupport/CholmodSupport.h +639 -0
- data/vendor/eigen/Eigen/src/Core/Array.h +329 -0
- data/vendor/eigen/Eigen/src/Core/ArrayBase.h +226 -0
- data/vendor/eigen/Eigen/src/Core/ArrayWrapper.h +209 -0
- data/vendor/eigen/Eigen/src/Core/Assign.h +90 -0
- data/vendor/eigen/Eigen/src/Core/AssignEvaluator.h +935 -0
- data/vendor/eigen/Eigen/src/Core/Assign_MKL.h +178 -0
- data/vendor/eigen/Eigen/src/Core/BandMatrix.h +353 -0
- data/vendor/eigen/Eigen/src/Core/Block.h +452 -0
- data/vendor/eigen/Eigen/src/Core/BooleanRedux.h +164 -0
- data/vendor/eigen/Eigen/src/Core/CommaInitializer.h +160 -0
- data/vendor/eigen/Eigen/src/Core/ConditionEstimator.h +175 -0
- data/vendor/eigen/Eigen/src/Core/CoreEvaluators.h +1688 -0
- data/vendor/eigen/Eigen/src/Core/CoreIterators.h +127 -0
- data/vendor/eigen/Eigen/src/Core/CwiseBinaryOp.h +184 -0
- data/vendor/eigen/Eigen/src/Core/CwiseNullaryOp.h +866 -0
- data/vendor/eigen/Eigen/src/Core/CwiseTernaryOp.h +197 -0
- data/vendor/eigen/Eigen/src/Core/CwiseUnaryOp.h +103 -0
- data/vendor/eigen/Eigen/src/Core/CwiseUnaryView.h +128 -0
- data/vendor/eigen/Eigen/src/Core/DenseBase.h +611 -0
- data/vendor/eigen/Eigen/src/Core/DenseCoeffsBase.h +681 -0
- data/vendor/eigen/Eigen/src/Core/DenseStorage.h +570 -0
- data/vendor/eigen/Eigen/src/Core/Diagonal.h +260 -0
- data/vendor/eigen/Eigen/src/Core/DiagonalMatrix.h +343 -0
- data/vendor/eigen/Eigen/src/Core/DiagonalProduct.h +28 -0
- data/vendor/eigen/Eigen/src/Core/Dot.h +318 -0
- data/vendor/eigen/Eigen/src/Core/EigenBase.h +159 -0
- data/vendor/eigen/Eigen/src/Core/ForceAlignedAccess.h +146 -0
- data/vendor/eigen/Eigen/src/Core/Fuzzy.h +155 -0
- data/vendor/eigen/Eigen/src/Core/GeneralProduct.h +455 -0
- data/vendor/eigen/Eigen/src/Core/GenericPacketMath.h +593 -0
- data/vendor/eigen/Eigen/src/Core/GlobalFunctions.h +187 -0
- data/vendor/eigen/Eigen/src/Core/IO.h +225 -0
- data/vendor/eigen/Eigen/src/Core/Inverse.h +118 -0
- data/vendor/eigen/Eigen/src/Core/Map.h +171 -0
- data/vendor/eigen/Eigen/src/Core/MapBase.h +303 -0
- data/vendor/eigen/Eigen/src/Core/MathFunctions.h +1415 -0
- data/vendor/eigen/Eigen/src/Core/MathFunctionsImpl.h +101 -0
- data/vendor/eigen/Eigen/src/Core/Matrix.h +459 -0
- data/vendor/eigen/Eigen/src/Core/MatrixBase.h +529 -0
- data/vendor/eigen/Eigen/src/Core/NestByValue.h +110 -0
- data/vendor/eigen/Eigen/src/Core/NoAlias.h +108 -0
- data/vendor/eigen/Eigen/src/Core/NumTraits.h +248 -0
- data/vendor/eigen/Eigen/src/Core/PermutationMatrix.h +633 -0
- data/vendor/eigen/Eigen/src/Core/PlainObjectBase.h +1035 -0
- data/vendor/eigen/Eigen/src/Core/Product.h +186 -0
- data/vendor/eigen/Eigen/src/Core/ProductEvaluators.h +1112 -0
- data/vendor/eigen/Eigen/src/Core/Random.h +182 -0
- data/vendor/eigen/Eigen/src/Core/Redux.h +505 -0
- data/vendor/eigen/Eigen/src/Core/Ref.h +283 -0
- data/vendor/eigen/Eigen/src/Core/Replicate.h +142 -0
- data/vendor/eigen/Eigen/src/Core/ReturnByValue.h +117 -0
- data/vendor/eigen/Eigen/src/Core/Reverse.h +211 -0
- data/vendor/eigen/Eigen/src/Core/Select.h +162 -0
- data/vendor/eigen/Eigen/src/Core/SelfAdjointView.h +352 -0
- data/vendor/eigen/Eigen/src/Core/SelfCwiseBinaryOp.h +47 -0
- data/vendor/eigen/Eigen/src/Core/Solve.h +188 -0
- data/vendor/eigen/Eigen/src/Core/SolveTriangular.h +235 -0
- data/vendor/eigen/Eigen/src/Core/SolverBase.h +130 -0
- data/vendor/eigen/Eigen/src/Core/StableNorm.h +221 -0
- data/vendor/eigen/Eigen/src/Core/Stride.h +111 -0
- data/vendor/eigen/Eigen/src/Core/Swap.h +67 -0
- data/vendor/eigen/Eigen/src/Core/Transpose.h +403 -0
- data/vendor/eigen/Eigen/src/Core/Transpositions.h +407 -0
- data/vendor/eigen/Eigen/src/Core/TriangularMatrix.h +983 -0
- data/vendor/eigen/Eigen/src/Core/VectorBlock.h +96 -0
- data/vendor/eigen/Eigen/src/Core/VectorwiseOp.h +695 -0
- data/vendor/eigen/Eigen/src/Core/Visitor.h +273 -0
- data/vendor/eigen/Eigen/src/Core/arch/AVX/Complex.h +451 -0
- data/vendor/eigen/Eigen/src/Core/arch/AVX/MathFunctions.h +439 -0
- data/vendor/eigen/Eigen/src/Core/arch/AVX/PacketMath.h +637 -0
- data/vendor/eigen/Eigen/src/Core/arch/AVX/TypeCasting.h +51 -0
- data/vendor/eigen/Eigen/src/Core/arch/AVX512/MathFunctions.h +391 -0
- data/vendor/eigen/Eigen/src/Core/arch/AVX512/PacketMath.h +1316 -0
- data/vendor/eigen/Eigen/src/Core/arch/AltiVec/Complex.h +430 -0
- data/vendor/eigen/Eigen/src/Core/arch/AltiVec/MathFunctions.h +322 -0
- data/vendor/eigen/Eigen/src/Core/arch/AltiVec/PacketMath.h +1061 -0
- data/vendor/eigen/Eigen/src/Core/arch/CUDA/Complex.h +103 -0
- data/vendor/eigen/Eigen/src/Core/arch/CUDA/Half.h +674 -0
- data/vendor/eigen/Eigen/src/Core/arch/CUDA/MathFunctions.h +91 -0
- data/vendor/eigen/Eigen/src/Core/arch/CUDA/PacketMath.h +333 -0
- data/vendor/eigen/Eigen/src/Core/arch/CUDA/PacketMathHalf.h +1124 -0
- data/vendor/eigen/Eigen/src/Core/arch/CUDA/TypeCasting.h +212 -0
- data/vendor/eigen/Eigen/src/Core/arch/Default/ConjHelper.h +29 -0
- data/vendor/eigen/Eigen/src/Core/arch/Default/Settings.h +49 -0
- data/vendor/eigen/Eigen/src/Core/arch/NEON/Complex.h +490 -0
- data/vendor/eigen/Eigen/src/Core/arch/NEON/MathFunctions.h +91 -0
- data/vendor/eigen/Eigen/src/Core/arch/NEON/PacketMath.h +760 -0
- data/vendor/eigen/Eigen/src/Core/arch/SSE/Complex.h +471 -0
- data/vendor/eigen/Eigen/src/Core/arch/SSE/MathFunctions.h +562 -0
- data/vendor/eigen/Eigen/src/Core/arch/SSE/PacketMath.h +895 -0
- data/vendor/eigen/Eigen/src/Core/arch/SSE/TypeCasting.h +77 -0
- data/vendor/eigen/Eigen/src/Core/arch/ZVector/Complex.h +397 -0
- data/vendor/eigen/Eigen/src/Core/arch/ZVector/MathFunctions.h +137 -0
- data/vendor/eigen/Eigen/src/Core/arch/ZVector/PacketMath.h +945 -0
- data/vendor/eigen/Eigen/src/Core/functors/AssignmentFunctors.h +168 -0
- data/vendor/eigen/Eigen/src/Core/functors/BinaryFunctors.h +475 -0
- data/vendor/eigen/Eigen/src/Core/functors/NullaryFunctors.h +188 -0
- data/vendor/eigen/Eigen/src/Core/functors/StlFunctors.h +136 -0
- data/vendor/eigen/Eigen/src/Core/functors/TernaryFunctors.h +25 -0
- data/vendor/eigen/Eigen/src/Core/functors/UnaryFunctors.h +792 -0
- data/vendor/eigen/Eigen/src/Core/products/GeneralBlockPanelKernel.h +2156 -0
- data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrix.h +492 -0
- data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h +311 -0
- data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h +145 -0
- data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h +122 -0
- data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixVector.h +619 -0
- data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixVector_BLAS.h +136 -0
- data/vendor/eigen/Eigen/src/Core/products/Parallelizer.h +163 -0
- data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix.h +521 -0
- data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix_BLAS.h +287 -0
- data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixVector.h +260 -0
- data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixVector_BLAS.h +118 -0
- data/vendor/eigen/Eigen/src/Core/products/SelfadjointProduct.h +133 -0
- data/vendor/eigen/Eigen/src/Core/products/SelfadjointRank2Update.h +93 -0
- data/vendor/eigen/Eigen/src/Core/products/TriangularMatrixMatrix.h +466 -0
- data/vendor/eigen/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h +315 -0
- data/vendor/eigen/Eigen/src/Core/products/TriangularMatrixVector.h +350 -0
- data/vendor/eigen/Eigen/src/Core/products/TriangularMatrixVector_BLAS.h +255 -0
- data/vendor/eigen/Eigen/src/Core/products/TriangularSolverMatrix.h +335 -0
- data/vendor/eigen/Eigen/src/Core/products/TriangularSolverMatrix_BLAS.h +163 -0
- data/vendor/eigen/Eigen/src/Core/products/TriangularSolverVector.h +145 -0
- data/vendor/eigen/Eigen/src/Core/util/BlasUtil.h +398 -0
- data/vendor/eigen/Eigen/src/Core/util/Constants.h +547 -0
- data/vendor/eigen/Eigen/src/Core/util/DisableStupidWarnings.h +83 -0
- data/vendor/eigen/Eigen/src/Core/util/ForwardDeclarations.h +302 -0
- data/vendor/eigen/Eigen/src/Core/util/MKL_support.h +130 -0
- data/vendor/eigen/Eigen/src/Core/util/Macros.h +1001 -0
- data/vendor/eigen/Eigen/src/Core/util/Memory.h +993 -0
- data/vendor/eigen/Eigen/src/Core/util/Meta.h +534 -0
- data/vendor/eigen/Eigen/src/Core/util/NonMPL2.h +3 -0
- data/vendor/eigen/Eigen/src/Core/util/ReenableStupidWarnings.h +27 -0
- data/vendor/eigen/Eigen/src/Core/util/StaticAssert.h +218 -0
- data/vendor/eigen/Eigen/src/Core/util/XprHelper.h +821 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/ComplexEigenSolver.h +346 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/ComplexSchur.h +459 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/ComplexSchur_LAPACKE.h +91 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/EigenSolver.h +622 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/GeneralizedEigenSolver.h +418 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/GeneralizedSelfAdjointEigenSolver.h +226 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/HessenbergDecomposition.h +374 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/MatrixBaseEigenvalues.h +158 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/RealQZ.h +654 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/RealSchur.h +546 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/RealSchur_LAPACKE.h +77 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver.h +870 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver_LAPACKE.h +87 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/Tridiagonalization.h +556 -0
- data/vendor/eigen/Eigen/src/Geometry/AlignedBox.h +392 -0
- data/vendor/eigen/Eigen/src/Geometry/AngleAxis.h +247 -0
- data/vendor/eigen/Eigen/src/Geometry/EulerAngles.h +114 -0
- data/vendor/eigen/Eigen/src/Geometry/Homogeneous.h +497 -0
- data/vendor/eigen/Eigen/src/Geometry/Hyperplane.h +282 -0
- data/vendor/eigen/Eigen/src/Geometry/OrthoMethods.h +234 -0
- data/vendor/eigen/Eigen/src/Geometry/ParametrizedLine.h +195 -0
- data/vendor/eigen/Eigen/src/Geometry/Quaternion.h +814 -0
- data/vendor/eigen/Eigen/src/Geometry/Rotation2D.h +199 -0
- data/vendor/eigen/Eigen/src/Geometry/RotationBase.h +206 -0
- data/vendor/eigen/Eigen/src/Geometry/Scaling.h +170 -0
- data/vendor/eigen/Eigen/src/Geometry/Transform.h +1542 -0
- data/vendor/eigen/Eigen/src/Geometry/Translation.h +208 -0
- data/vendor/eigen/Eigen/src/Geometry/Umeyama.h +166 -0
- data/vendor/eigen/Eigen/src/Geometry/arch/Geometry_SSE.h +161 -0
- data/vendor/eigen/Eigen/src/Householder/BlockHouseholder.h +103 -0
- data/vendor/eigen/Eigen/src/Householder/Householder.h +172 -0
- data/vendor/eigen/Eigen/src/Householder/HouseholderSequence.h +470 -0
- data/vendor/eigen/Eigen/src/IterativeLinearSolvers/BasicPreconditioners.h +226 -0
- data/vendor/eigen/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h +228 -0
- data/vendor/eigen/Eigen/src/IterativeLinearSolvers/ConjugateGradient.h +246 -0
- data/vendor/eigen/Eigen/src/IterativeLinearSolvers/IncompleteCholesky.h +400 -0
- data/vendor/eigen/Eigen/src/IterativeLinearSolvers/IncompleteLUT.h +462 -0
- data/vendor/eigen/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h +394 -0
- data/vendor/eigen/Eigen/src/IterativeLinearSolvers/LeastSquareConjugateGradient.h +216 -0
- data/vendor/eigen/Eigen/src/IterativeLinearSolvers/SolveWithGuess.h +115 -0
- data/vendor/eigen/Eigen/src/Jacobi/Jacobi.h +462 -0
- data/vendor/eigen/Eigen/src/LU/Determinant.h +101 -0
- data/vendor/eigen/Eigen/src/LU/FullPivLU.h +891 -0
- data/vendor/eigen/Eigen/src/LU/InverseImpl.h +415 -0
- data/vendor/eigen/Eigen/src/LU/PartialPivLU.h +611 -0
- data/vendor/eigen/Eigen/src/LU/PartialPivLU_LAPACKE.h +83 -0
- data/vendor/eigen/Eigen/src/LU/arch/Inverse_SSE.h +338 -0
- data/vendor/eigen/Eigen/src/MetisSupport/MetisSupport.h +137 -0
- data/vendor/eigen/Eigen/src/OrderingMethods/Amd.h +445 -0
- data/vendor/eigen/Eigen/src/OrderingMethods/Eigen_Colamd.h +1843 -0
- data/vendor/eigen/Eigen/src/OrderingMethods/Ordering.h +157 -0
- data/vendor/eigen/Eigen/src/PaStiXSupport/PaStiXSupport.h +678 -0
- data/vendor/eigen/Eigen/src/PardisoSupport/PardisoSupport.h +543 -0
- data/vendor/eigen/Eigen/src/QR/ColPivHouseholderQR.h +653 -0
- data/vendor/eigen/Eigen/src/QR/ColPivHouseholderQR_LAPACKE.h +97 -0
- data/vendor/eigen/Eigen/src/QR/CompleteOrthogonalDecomposition.h +562 -0
- data/vendor/eigen/Eigen/src/QR/FullPivHouseholderQR.h +676 -0
- data/vendor/eigen/Eigen/src/QR/HouseholderQR.h +409 -0
- data/vendor/eigen/Eigen/src/QR/HouseholderQR_LAPACKE.h +68 -0
- data/vendor/eigen/Eigen/src/SPQRSupport/SuiteSparseQRSupport.h +313 -0
- data/vendor/eigen/Eigen/src/SVD/BDCSVD.h +1246 -0
- data/vendor/eigen/Eigen/src/SVD/JacobiSVD.h +804 -0
- data/vendor/eigen/Eigen/src/SVD/JacobiSVD_LAPACKE.h +91 -0
- data/vendor/eigen/Eigen/src/SVD/SVDBase.h +315 -0
- data/vendor/eigen/Eigen/src/SVD/UpperBidiagonalization.h +414 -0
- data/vendor/eigen/Eigen/src/SparseCholesky/SimplicialCholesky.h +689 -0
- data/vendor/eigen/Eigen/src/SparseCholesky/SimplicialCholesky_impl.h +199 -0
- data/vendor/eigen/Eigen/src/SparseCore/AmbiVector.h +377 -0
- data/vendor/eigen/Eigen/src/SparseCore/CompressedStorage.h +258 -0
- data/vendor/eigen/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h +352 -0
- data/vendor/eigen/Eigen/src/SparseCore/MappedSparseMatrix.h +67 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseAssign.h +216 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseBlock.h +603 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseColEtree.h +206 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseCompressedBase.h +341 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseCwiseBinaryOp.h +726 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseCwiseUnaryOp.h +148 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseDenseProduct.h +320 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseDiagonalProduct.h +138 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseDot.h +98 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseFuzzy.h +29 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseMap.h +305 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseMatrix.h +1403 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseMatrixBase.h +405 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparsePermutation.h +178 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseProduct.h +169 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseRedux.h +49 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseRef.h +397 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseSelfAdjointView.h +656 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseSolverBase.h +124 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseSparseProductWithPruning.h +198 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseTranspose.h +92 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseTriangularView.h +189 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseUtil.h +178 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseVector.h +478 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseView.h +253 -0
- data/vendor/eigen/Eigen/src/SparseCore/TriangularSolver.h +315 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU.h +773 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLUImpl.h +66 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_Memory.h +226 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_Structs.h +110 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_SupernodalMatrix.h +301 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_Utils.h +80 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_column_bmod.h +181 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_column_dfs.h +179 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_copy_to_ucol.h +107 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_gemm_kernel.h +280 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_heap_relax_snode.h +126 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_kernel_bmod.h +130 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_panel_bmod.h +223 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_panel_dfs.h +258 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_pivotL.h +137 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_pruneL.h +136 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_relax_snode.h +83 -0
- data/vendor/eigen/Eigen/src/SparseQR/SparseQR.h +745 -0
- data/vendor/eigen/Eigen/src/StlSupport/StdDeque.h +126 -0
- data/vendor/eigen/Eigen/src/StlSupport/StdList.h +106 -0
- data/vendor/eigen/Eigen/src/StlSupport/StdVector.h +131 -0
- data/vendor/eigen/Eigen/src/StlSupport/details.h +84 -0
- data/vendor/eigen/Eigen/src/SuperLUSupport/SuperLUSupport.h +1027 -0
- data/vendor/eigen/Eigen/src/UmfPackSupport/UmfPackSupport.h +506 -0
- data/vendor/eigen/Eigen/src/misc/Image.h +82 -0
- data/vendor/eigen/Eigen/src/misc/Kernel.h +79 -0
- data/vendor/eigen/Eigen/src/misc/RealSvd2x2.h +55 -0
- data/vendor/eigen/Eigen/src/misc/blas.h +440 -0
- data/vendor/eigen/Eigen/src/misc/lapack.h +152 -0
- data/vendor/eigen/Eigen/src/misc/lapacke.h +16291 -0
- data/vendor/eigen/Eigen/src/misc/lapacke_mangling.h +17 -0
- data/vendor/eigen/Eigen/src/plugins/ArrayCwiseBinaryOps.h +332 -0
- data/vendor/eigen/Eigen/src/plugins/ArrayCwiseUnaryOps.h +552 -0
- data/vendor/eigen/Eigen/src/plugins/BlockMethods.h +1058 -0
- data/vendor/eigen/Eigen/src/plugins/CommonCwiseBinaryOps.h +115 -0
- data/vendor/eigen/Eigen/src/plugins/CommonCwiseUnaryOps.h +163 -0
- data/vendor/eigen/Eigen/src/plugins/MatrixCwiseBinaryOps.h +152 -0
- data/vendor/eigen/Eigen/src/plugins/MatrixCwiseUnaryOps.h +85 -0
- data/vendor/eigen/README.md +3 -0
- data/vendor/eigen/bench/README.txt +55 -0
- data/vendor/eigen/bench/btl/COPYING +340 -0
- data/vendor/eigen/bench/btl/README +154 -0
- data/vendor/eigen/bench/tensors/README +21 -0
- data/vendor/eigen/blas/README.txt +6 -0
- data/vendor/eigen/demos/mandelbrot/README +10 -0
- data/vendor/eigen/demos/mix_eigen_and_c/README +9 -0
- data/vendor/eigen/demos/opengl/README +13 -0
- data/vendor/eigen/unsupported/Eigen/CXX11/src/Tensor/README.md +1760 -0
- data/vendor/eigen/unsupported/README.txt +50 -0
- data/vendor/tomotopy/LICENSE +21 -0
- data/vendor/tomotopy/README.kr.rst +375 -0
- data/vendor/tomotopy/README.rst +382 -0
- data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +362 -0
- data/vendor/tomotopy/src/Labeling/FoRelevance.h +88 -0
- data/vendor/tomotopy/src/Labeling/Labeler.h +50 -0
- data/vendor/tomotopy/src/TopicModel/CT.h +37 -0
- data/vendor/tomotopy/src/TopicModel/CTModel.cpp +13 -0
- data/vendor/tomotopy/src/TopicModel/CTModel.hpp +293 -0
- data/vendor/tomotopy/src/TopicModel/DMR.h +51 -0
- data/vendor/tomotopy/src/TopicModel/DMRModel.cpp +13 -0
- data/vendor/tomotopy/src/TopicModel/DMRModel.hpp +374 -0
- data/vendor/tomotopy/src/TopicModel/DT.h +65 -0
- data/vendor/tomotopy/src/TopicModel/DTM.h +22 -0
- data/vendor/tomotopy/src/TopicModel/DTModel.cpp +15 -0
- data/vendor/tomotopy/src/TopicModel/DTModel.hpp +572 -0
- data/vendor/tomotopy/src/TopicModel/GDMR.h +37 -0
- data/vendor/tomotopy/src/TopicModel/GDMRModel.cpp +14 -0
- data/vendor/tomotopy/src/TopicModel/GDMRModel.hpp +485 -0
- data/vendor/tomotopy/src/TopicModel/HDP.h +74 -0
- data/vendor/tomotopy/src/TopicModel/HDPModel.cpp +13 -0
- data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +592 -0
- data/vendor/tomotopy/src/TopicModel/HLDA.h +40 -0
- data/vendor/tomotopy/src/TopicModel/HLDAModel.cpp +13 -0
- data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +681 -0
- data/vendor/tomotopy/src/TopicModel/HPA.h +27 -0
- data/vendor/tomotopy/src/TopicModel/HPAModel.cpp +21 -0
- data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +588 -0
- data/vendor/tomotopy/src/TopicModel/LDA.h +144 -0
- data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +442 -0
- data/vendor/tomotopy/src/TopicModel/LDAModel.cpp +13 -0
- data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +1058 -0
- data/vendor/tomotopy/src/TopicModel/LLDA.h +45 -0
- data/vendor/tomotopy/src/TopicModel/LLDAModel.cpp +13 -0
- data/vendor/tomotopy/src/TopicModel/LLDAModel.hpp +203 -0
- data/vendor/tomotopy/src/TopicModel/MGLDA.h +63 -0
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.cpp +17 -0
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +558 -0
- data/vendor/tomotopy/src/TopicModel/PA.h +43 -0
- data/vendor/tomotopy/src/TopicModel/PAModel.cpp +13 -0
- data/vendor/tomotopy/src/TopicModel/PAModel.hpp +467 -0
- data/vendor/tomotopy/src/TopicModel/PLDA.h +17 -0
- data/vendor/tomotopy/src/TopicModel/PLDAModel.cpp +13 -0
- data/vendor/tomotopy/src/TopicModel/PLDAModel.hpp +214 -0
- data/vendor/tomotopy/src/TopicModel/SLDA.h +54 -0
- data/vendor/tomotopy/src/TopicModel/SLDAModel.cpp +17 -0
- data/vendor/tomotopy/src/TopicModel/SLDAModel.hpp +456 -0
- data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +692 -0
- data/vendor/tomotopy/src/Utils/AliasMethod.hpp +169 -0
- data/vendor/tomotopy/src/Utils/Dictionary.h +80 -0
- data/vendor/tomotopy/src/Utils/EigenAddonOps.hpp +181 -0
- data/vendor/tomotopy/src/Utils/LBFGS.h +202 -0
- data/vendor/tomotopy/src/Utils/LBFGS/LineSearchBacktracking.h +120 -0
- data/vendor/tomotopy/src/Utils/LBFGS/LineSearchBracketing.h +122 -0
- data/vendor/tomotopy/src/Utils/LBFGS/Param.h +213 -0
- data/vendor/tomotopy/src/Utils/LUT.hpp +82 -0
- data/vendor/tomotopy/src/Utils/MultiNormalDistribution.hpp +69 -0
- data/vendor/tomotopy/src/Utils/PolyaGamma.hpp +200 -0
- data/vendor/tomotopy/src/Utils/PolyaGammaHybrid.hpp +672 -0
- data/vendor/tomotopy/src/Utils/ThreadPool.hpp +150 -0
- data/vendor/tomotopy/src/Utils/Trie.hpp +220 -0
- data/vendor/tomotopy/src/Utils/TruncMultiNormal.hpp +94 -0
- data/vendor/tomotopy/src/Utils/Utils.hpp +337 -0
- data/vendor/tomotopy/src/Utils/avx_gamma.h +46 -0
- data/vendor/tomotopy/src/Utils/avx_mathfun.h +736 -0
- data/vendor/tomotopy/src/Utils/exception.h +28 -0
- data/vendor/tomotopy/src/Utils/math.h +281 -0
- data/vendor/tomotopy/src/Utils/rtnorm.hpp +2690 -0
- data/vendor/tomotopy/src/Utils/sample.hpp +192 -0
- data/vendor/tomotopy/src/Utils/serializer.hpp +695 -0
- data/vendor/tomotopy/src/Utils/slp.hpp +131 -0
- data/vendor/tomotopy/src/Utils/sse_gamma.h +48 -0
- data/vendor/tomotopy/src/Utils/sse_mathfun.h +710 -0
- data/vendor/tomotopy/src/Utils/text.hpp +49 -0
- data/vendor/tomotopy/src/Utils/tvector.hpp +543 -0
- metadata +531 -0
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
#include "LDA.h"
|
|
3
|
+
|
|
4
|
+
namespace tomoto
|
|
5
|
+
{
|
|
6
|
+
template<TermWeight _tw>
|
|
7
|
+
struct DocumentCTM : public DocumentLDA<_tw>
|
|
8
|
+
{
|
|
9
|
+
using BaseDocument = DocumentLDA<_tw>;
|
|
10
|
+
using DocumentLDA<_tw>::DocumentLDA;
|
|
11
|
+
Eigen::Matrix<Float, -1, -1> beta; // Dim: (K, betaSample)
|
|
12
|
+
Eigen::Matrix<Float, -1, 1> smBeta; // Dim: K
|
|
13
|
+
|
|
14
|
+
DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 0, smBeta);
|
|
15
|
+
DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 1, 0x00010001, smBeta);
|
|
16
|
+
};
|
|
17
|
+
|
|
18
|
+
class ICTModel : public ILDAModel
|
|
19
|
+
{
|
|
20
|
+
public:
|
|
21
|
+
using DefaultDocType = DocumentCTM<TermWeight::one>;
|
|
22
|
+
static ICTModel* create(TermWeight _weight, size_t _K = 1,
|
|
23
|
+
Float smoothingAlpha = 0.1, Float _eta = 0.01,
|
|
24
|
+
size_t seed = std::random_device{}(),
|
|
25
|
+
bool scalarRng = false);
|
|
26
|
+
|
|
27
|
+
virtual void setNumBetaSample(size_t numSample) = 0;
|
|
28
|
+
virtual size_t getNumBetaSample() const = 0;
|
|
29
|
+
virtual void setNumTMNSample(size_t numSample) = 0;
|
|
30
|
+
virtual size_t getNumTMNSample() const = 0;
|
|
31
|
+
virtual void setNumDocBetaSample(size_t numSample) = 0;
|
|
32
|
+
virtual size_t getNumDocBetaSample() const = 0;
|
|
33
|
+
virtual std::vector<Float> getPriorMean() const = 0;
|
|
34
|
+
virtual std::vector<Float> getPriorCov() const = 0;
|
|
35
|
+
virtual std::vector<Float> getCorrelationTopic(Tid k) const = 0;
|
|
36
|
+
};
|
|
37
|
+
}
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
#include "CTModel.hpp"
|
|
2
|
+
|
|
3
|
+
namespace tomoto
|
|
4
|
+
{
|
|
5
|
+
/*template class CTModel<TermWeight::one>;
|
|
6
|
+
template class CTModel<TermWeight::idf>;
|
|
7
|
+
template class CTModel<TermWeight::pmi>;*/
|
|
8
|
+
|
|
9
|
+
// Factory for the CTM implementation. Hands the term-weight scheme, the
// scalar-RNG flag, the CTModel template and the constructor arguments to
// TMT_SWITCH_TW. NOTE(review): the macro is defined elsewhere and is expected
// to return from this function on every path — confirm.
ICTModel* ICTModel::create(
    TermWeight _weight,
    size_t _K,
    Float smoothingAlpha,
    Float _eta,
    size_t seed,
    bool scalarRng)
{
    TMT_SWITCH_TW(_weight, scalarRng, CTModel, _K, smoothingAlpha, _eta, seed);
}
|
|
13
|
+
}
|
|
@@ -0,0 +1,293 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
#include "LDAModel.hpp"
|
|
3
|
+
#include "../Utils/MultiNormalDistribution.hpp"
|
|
4
|
+
#include "../Utils/TruncMultiNormal.hpp"
|
|
5
|
+
#include "CT.h"
|
|
6
|
+
/*
|
|
7
|
+
Implementation of CTM using Gibbs sampling by bab2min
|
|
8
|
+
* Blei, D., & Lafferty, J. (2006). Correlated topic models. Advances in neural information processing systems, 18, 147.
|
|
9
|
+
* Mimno, D., Wallach, H., & McCallum, A. (2008, December). Gibbs sampling for logistic normal topic models with graph-based priors. In NIPS Workshop on Analyzing Graphs (Vol. 61).
|
|
10
|
+
*/
|
|
11
|
+
|
|
12
|
+
namespace tomoto
|
|
13
|
+
{
|
|
14
|
+
template<TermWeight _tw>
|
|
15
|
+
struct ModelStateCTM : public ModelStateLDA<_tw>
|
|
16
|
+
{
|
|
17
|
+
};
|
|
18
|
+
|
|
19
|
+
template<TermWeight _tw, typename _RandGen,
|
|
20
|
+
size_t _Flags = flags::partitioned_multisampling,
|
|
21
|
+
typename _Interface = ICTModel,
|
|
22
|
+
typename _Derived = void,
|
|
23
|
+
typename _DocType = DocumentCTM<_tw>,
|
|
24
|
+
typename _ModelState = ModelStateCTM<_tw>>
|
|
25
|
+
class CTModel : public LDAModel<_tw, _RandGen, _Flags, _Interface,
|
|
26
|
+
typename std::conditional<std::is_same<_Derived, void>::value, CTModel<_tw, _RandGen, _Flags>, _Derived>::type,
|
|
27
|
+
_DocType, _ModelState>
|
|
28
|
+
{
|
|
29
|
+
protected:
|
|
30
|
+
using DerivedClass = typename std::conditional<std::is_same<_Derived, void>::value, CTModel<_tw, _RandGen>, _Derived>::type;
|
|
31
|
+
using BaseClass = LDAModel<_tw, _RandGen, _Flags, _Interface, DerivedClass, _DocType, _ModelState>;
|
|
32
|
+
friend BaseClass;
|
|
33
|
+
friend typename BaseClass::BaseClass;
|
|
34
|
+
using WeightType = typename BaseClass::WeightType;
|
|
35
|
+
|
|
36
|
+
static constexpr char TMID[] = "CTM\0";
|
|
37
|
+
|
|
38
|
+
uint64_t numBetaSample = 10;
|
|
39
|
+
uint64_t numTMNSample = 5;
|
|
40
|
+
uint64_t numDocBetaSample = -1;
|
|
41
|
+
math::MultiNormalDistribution<Float> topicPrior;
|
|
42
|
+
|
|
43
|
+
template<bool _asymEta>
|
|
44
|
+
Float* getZLikelihoods(_ModelState& ld, const _DocType& doc, size_t docId, size_t vid) const
|
|
45
|
+
{
|
|
46
|
+
const size_t V = this->realV;
|
|
47
|
+
assert(vid < V);
|
|
48
|
+
auto etaHelper = this->template getEtaHelper<_asymEta>();
|
|
49
|
+
auto& zLikelihood = ld.zLikelihood;
|
|
50
|
+
zLikelihood = doc.smBeta.array()
|
|
51
|
+
* (ld.numByTopicWord.col(vid).array().template cast<Float>() + etaHelper.getEta(vid))
|
|
52
|
+
/ (ld.numByTopic.array().template cast<Float>() + etaHelper.getEtaSum());
|
|
53
|
+
sample::prefixSum(zLikelihood.data(), this->K);
|
|
54
|
+
return &zLikelihood[0];
|
|
55
|
+
}
|
|
56
|
+
|
|
57
|
+
// Resamples the beta values of one document and refreshes doc.smBeta.
//
// Runs numBetaSample + burnIn sweeps. Each sweep
//   1. takes the current beta column (all ones on the very first sweep),
//      exponentiates and normalises it,
//   2. derives a lower and an upper bound per topic from two uniform draws,
//   3. draws the next beta column from the truncated multivariate normal
//      defined by topicPrior and those bounds.
// Columns of doc.beta are reused cyclically (index modulo numBetaSample), so
// numBetaSample must be non-zero. Afterwards doc.smBeta becomes the softmax of
// the row-wise mean over the first min(numBetaSample, numDocBetaSample) columns.
//
// Throws exception::TrainingError when a lower bound exceeds its upper bound,
// when the first entry of a freshly sampled column is not finite, or when a
// std::runtime_error is raised while sampling (its message is also written to
// std::cerr before being rethrown as TrainingError).
void updateBeta(_DocType& doc, _RandGen& rg) const
{
    Eigen::Matrix<Float, -1, 1> pbeta, lowerBound, upperBound;
    constexpr Float epsilon = 1e-8;
    constexpr size_t burnIn = 3;

    pbeta = lowerBound = upperBound = Eigen::Matrix<Float, -1, 1>::Zero(this->K);
    for (size_t i = 0; i < numBetaSample + burnIn; ++i)
    {
        // Column that receives this sweep's sample.
        const auto nextCol = (i + 1) % numBetaSample;

        if (i == 0) pbeta = Eigen::Matrix<Float, -1, 1>::Ones(this->K);
        else pbeta = doc.beta.col(i % numBetaSample).array().exp();
        Float betaESum = pbeta.sum() + 1;
        pbeta /= betaESum;
        for (size_t k = 0; k < this->K; ++k)
        {
            Float N_k = doc.numByTopic[k] + this->alpha;
            Float N_nk = doc.getSumWordWeight() + this->alpha * (this->K + 1) - N_k;
            Float u1 = rg.uniform_real(), u2 = rg.uniform_real();
            Float max_uk = epsilon + pow(u1, (Float)1 / N_k) * (pbeta[k] - epsilon);
            Float min_unk = (1 - pow(u2, (Float)1 / N_nk))
                * (1 - pbeta[k]) + pbeta[k];

            Float c = betaESum * (1 - pbeta[k]);
            lowerBound[k] = log(c * max_uk / (1 - max_uk));
            upperBound[k] = log(c * min_unk / (1 - min_unk));
            if (lowerBound[k] > upperBound[k])
            {
                THROW_ERROR_WITH_INFO(exception::TrainingError,
                    text::format("Bound Error: LB(%f) > UB(%f)\n"
                        "max_uk: %f, min_unk: %f, c: %f", lowerBound[k], upperBound[k], max_uk, min_unk, c));
            }
        }

        try
        {
            math::sampleFromTruncatedMultiNormal(doc.beta.col(nextCol),
                topicPrior, lowerBound, upperBound, rg, numTMNSample);

            if (!std::isfinite(doc.beta.col(nextCol)[0]))
                // The column index is a 64-bit unsigned value: it must not be
                // passed to a "%d" conversion, so it is cast to match "%llu".
                THROW_ERROR_WITH_INFO(exception::TrainingError,
                    text::format("doc.beta.col(%llu) is %f", static_cast<unsigned long long>(nextCol),
                        doc.beta.col(nextCol)[0]));
        }
        catch (const std::runtime_error& e)
        {
            std::cerr << e.what() << std::endl;
            THROW_ERROR_WITH_INFO(exception::TrainingError, e.what());
        }
    }

    // update softmax-applied beta coefficient
    doc.smBeta.head(this->K) = doc.beta.block(0, 0, this->K, std::min(numBetaSample, numDocBetaSample)).rowwise().mean();
    doc.smBeta = doc.smBeta.array().exp();
    doc.smBeta /= doc.smBeta.array().sum();
}
|
|
112
|
+
|
|
113
|
+
template<ParallelScheme _ps, bool _infer, typename _ExtraDocData>
|
|
114
|
+
void sampleDocument(_DocType& doc, const _ExtraDocData& edd, size_t docId, _ModelState& ld, _RandGen& rgs, size_t iterationCnt, size_t partitionId = 0) const
|
|
115
|
+
{
|
|
116
|
+
BaseClass::template sampleDocument<_ps, _infer>(doc, edd, docId, ld, rgs, iterationCnt, partitionId);
|
|
117
|
+
/*if (iterationCnt >= this->burnIn && this->optimInterval && (iterationCnt + 1) % this->optimInterval == 0)
|
|
118
|
+
{
|
|
119
|
+
updateBeta(doc, rgs);
|
|
120
|
+
}*/
|
|
121
|
+
}
|
|
122
|
+
|
|
123
|
+
template<typename _DocIter>
|
|
124
|
+
void sampleGlobalLevel(ThreadPool* pool, _ModelState* localData, _RandGen* rgs, _DocIter first, _DocIter last) const
|
|
125
|
+
{
|
|
126
|
+
if (this->globalStep < this->burnIn || !this->optimInterval || (this->globalStep + 1) % this->optimInterval != 0) return;
|
|
127
|
+
|
|
128
|
+
if (pool)
|
|
129
|
+
{
|
|
130
|
+
std::vector<std::future<void>> res;
|
|
131
|
+
const size_t chStride = pool->getNumWorkers() * 8;
|
|
132
|
+
size_t dist = std::distance(first, last);
|
|
133
|
+
for (size_t ch = 0; ch < chStride; ++ch)
|
|
134
|
+
{
|
|
135
|
+
auto b = first, e = first;
|
|
136
|
+
std::advance(b, dist * ch / chStride);
|
|
137
|
+
std::advance(e, dist * (ch + 1) / chStride);
|
|
138
|
+
res.emplace_back(pool->enqueue([&, ch, chStride](size_t threadId, _DocIter b, _DocIter e)
|
|
139
|
+
{
|
|
140
|
+
for (auto doc = b; doc != e; ++doc)
|
|
141
|
+
{
|
|
142
|
+
updateBeta(*doc, rgs[threadId]);
|
|
143
|
+
}
|
|
144
|
+
}, b, e));
|
|
145
|
+
}
|
|
146
|
+
for (auto& r : res) r.get();
|
|
147
|
+
}
|
|
148
|
+
else
|
|
149
|
+
{
|
|
150
|
+
for (auto doc = first; doc != last; ++doc)
|
|
151
|
+
{
|
|
152
|
+
updateBeta(*doc, rgs[0]);
|
|
153
|
+
}
|
|
154
|
+
}
|
|
155
|
+
}
|
|
156
|
+
|
|
157
|
+
/**
 * Recovery hook for a TrainingError: logs a message to stderr, resets topicPrior
 * to a freshly constructed K-dimensional distribution, zeroes every document's
 * beta and re-runs updateBeta() on all documents using the thread pool.
 *
 * @param e         the error being recovered from (not inspected)
 * @param pool      worker pool used to process the documents
 * @param localData unused by this method
 * @param rgs       array of random generators, indexed by the worker's thread id
 * @return always 0
 */
int restoreFromTrainingError(const exception::TrainingError& e, ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
{
	std::cerr << "Failed to sample! Reset prior and retry!" << std::endl;
	topicPrior = math::MultiNormalDistribution<Float>{ this->K };

	// never schedule more tasks than there are documents
	const size_t numTasks = std::min(pool.getNumWorkers() * 8, this->docs.size());
	std::vector<std::future<void>> pending;
	for (size_t task = 0; task < numTasks; ++task)
	{
		// each task handles documents offset, offset + numTasks, offset + 2 * numTasks, ...
		pending.emplace_back(pool.enqueue([this, rgs, numTasks](size_t threadId, size_t offset)
		{
			for (size_t d = offset; d < this->docs.size(); d += numTasks)
			{
				auto& doc = this->docs[d];
				doc.beta.setZero();
				updateBeta(doc, rgs[threadId]);
			}
		}, task));
	}
	for (auto& p : pending) p.get();
	return 0;
}
|
|
177
|
+
|
|
178
|
+
/**
 * Re-estimates topicPrior from the beta samples of all documents.
 *
 * Sample i handed to the estimator is column (i % numBetaSample) of the beta
 * matrix of document (i / numBetaSample), for docs.size() * numBetaSample
 * samples in total.
 *
 * @param pool      unused by this method
 * @param localData unused by this method
 * @param rgs       unused by this method
 * @throws exception::TrainingError when the first component of the estimated mean is not finite
 */
void optimizeParameters(ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
{
	topicPrior = math::MultiNormalDistribution<Float>::estimate([this](size_t i)
	{
		return this->docs[i / numBetaSample].beta.col(i % numBetaSample);
	}, this->docs.size() * numBetaSample);

	if (!std::isfinite(topicPrior.mean[0]))
		THROW_ERROR_WITH_INFO(exception::TrainingError,
			text::format("topicPrior.mean is %f", topicPrior.mean[0]));
}
|
|
189
|
+
|
|
190
|
+
template<typename _DocIter>
|
|
191
|
+
double getLLDocs(_DocIter _first, _DocIter _last) const
|
|
192
|
+
{
|
|
193
|
+
const auto K = this->K;
|
|
194
|
+
const auto alpha = this->alpha;
|
|
195
|
+
|
|
196
|
+
double ll = 0;
|
|
197
|
+
for (; _first != _last; ++_first)
|
|
198
|
+
{
|
|
199
|
+
auto& doc = *_first;
|
|
200
|
+
Eigen::Matrix<Float, -1, 1> pbeta = doc.smBeta.array().log();
|
|
201
|
+
Float last = pbeta[K - 1];
|
|
202
|
+
for (Tid k = 0; k < K; ++k)
|
|
203
|
+
{
|
|
204
|
+
ll += pbeta[k] * (doc.numByTopic[k] + alpha) - math::lgammaT(doc.numByTopic[k] + alpha + 1);
|
|
205
|
+
}
|
|
206
|
+
pbeta.array() -= last;
|
|
207
|
+
ll += topicPrior.getLL(pbeta.head(this->K));
|
|
208
|
+
ll += math::lgammaT(doc.getSumWordWeight() + alpha * K + 1);
|
|
209
|
+
}
|
|
210
|
+
return ll;
|
|
211
|
+
}
|
|
212
|
+
|
|
213
|
+
/**
 * Prepares a document via the base class, then allocates its CT-specific state:
 * beta becomes a K x numBetaSample zero matrix and smBeta a length-K vector
 * filled with 1/K.
 */
void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
{
	using BetaMatrix = Eigen::Matrix<Float, -1, -1>;
	using BetaVector = Eigen::Matrix<Float, -1, 1>;

	BaseClass::prepareDoc(doc, docId, wordSize);
	doc.beta = BetaMatrix::Zero(this->K, numBetaSample);
	doc.smBeta = BetaVector::Constant(this->K, (Float)1 / this->K);
}
|
|
219
|
+
|
|
220
|
+
void updateDocs()
|
|
221
|
+
{
|
|
222
|
+
BaseClass::updateDocs();
|
|
223
|
+
for (auto& doc : this->docs)
|
|
224
|
+
{
|
|
225
|
+
doc.beta = Eigen::Matrix<Float, -1, -1>::Zero(this->K, numBetaSample);
|
|
226
|
+
}
|
|
227
|
+
}
|
|
228
|
+
|
|
229
|
+
void initGlobalState(bool initDocs)
|
|
230
|
+
{
|
|
231
|
+
BaseClass::initGlobalState(initDocs);
|
|
232
|
+
if (initDocs)
|
|
233
|
+
{
|
|
234
|
+
topicPrior = math::MultiNormalDistribution<Float>{ this->K };
|
|
235
|
+
}
|
|
236
|
+
}
|
|
237
|
+
|
|
238
|
+
public:
|
|
239
|
+
DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, numBetaSample, numTMNSample, topicPrior);
|
|
240
|
+
DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, numBetaSample, numTMNSample, topicPrior);
|
|
241
|
+
|
|
242
|
+
CTModel(size_t _K = 1, Float smoothingAlpha = 0.1, Float _eta = 0.01, size_t _rg = std::random_device{}())
|
|
243
|
+
: BaseClass(_K, smoothingAlpha, _eta, _rg)
|
|
244
|
+
{
|
|
245
|
+
this->optimInterval = 2;
|
|
246
|
+
}
|
|
247
|
+
|
|
248
|
+
std::vector<Float> getTopicsByDoc(const _DocType& doc) const
|
|
249
|
+
{
|
|
250
|
+
std::vector<Float> ret(this->K);
|
|
251
|
+
Eigen::Map<Eigen::Matrix<Float, -1, 1>>{ret.data(), this->K}.array() =
|
|
252
|
+
doc.numByTopic.array().template cast<Float>() / doc.getSumWordWeight();
|
|
253
|
+
return ret;
|
|
254
|
+
}
|
|
255
|
+
|
|
256
|
+
std::vector<Float> getPriorMean() const override
|
|
257
|
+
{
|
|
258
|
+
return { topicPrior.mean.data(), topicPrior.mean.data() + topicPrior.mean.size() };
|
|
259
|
+
}
|
|
260
|
+
|
|
261
|
+
std::vector<Float> getPriorCov() const override
|
|
262
|
+
{
|
|
263
|
+
return { topicPrior.cov.data(), topicPrior.cov.data() + topicPrior.cov.size() };
|
|
264
|
+
}
|
|
265
|
+
|
|
266
|
+
/**
 * Returns column k of topicPrior.cov with entry i divided by
 * sqrt(cov(i, i) * cov(k, k)).
 */
std::vector<Float> getCorrelationTopic(Tid k) const override
{
	Eigen::Matrix<Float, -1, 1> normalized = topicPrior.cov.col(k).array() / (topicPrior.cov.diagonal().array() * topicPrior.cov(k, k)).sqrt();
	const Float* first = normalized.data();
	return std::vector<Float>(first, first + normalized.size());
}
|
|
271
|
+
|
|
272
|
+
GETTER(NumBetaSample, size_t, numBetaSample);
|
|
273
|
+
|
|
274
|
+
void setNumBetaSample(size_t _numSample) override
|
|
275
|
+
{
|
|
276
|
+
numBetaSample = _numSample;
|
|
277
|
+
}
|
|
278
|
+
|
|
279
|
+
GETTER(NumDocBetaSample, size_t, numDocBetaSample);
|
|
280
|
+
|
|
281
|
+
void setNumDocBetaSample(size_t _numSample) override
|
|
282
|
+
{
|
|
283
|
+
numDocBetaSample = _numSample;
|
|
284
|
+
}
|
|
285
|
+
|
|
286
|
+
GETTER(NumTMNSample, size_t, numTMNSample);
|
|
287
|
+
|
|
288
|
+
void setNumTMNSample(size_t _numSample) override
|
|
289
|
+
{
|
|
290
|
+
numTMNSample = _numSample;
|
|
291
|
+
}
|
|
292
|
+
};
|
|
293
|
+
}
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
#include "LDA.h"
|
|
3
|
+
|
|
4
|
+
namespace tomoto
|
|
5
|
+
{
|
|
6
|
+
template<TermWeight _tw>
|
|
7
|
+
struct DocumentDMR : public DocumentLDA<_tw>
|
|
8
|
+
{
|
|
9
|
+
using BaseDocument = DocumentLDA<_tw>;
|
|
10
|
+
using DocumentLDA<_tw>::DocumentLDA;
|
|
11
|
+
size_t metadata = 0;
|
|
12
|
+
|
|
13
|
+
DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 0, metadata);
|
|
14
|
+
DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 1, 0x00010001, metadata);
|
|
15
|
+
};
|
|
16
|
+
|
|
17
|
+
class IDMRModel : public ILDAModel
|
|
18
|
+
{
|
|
19
|
+
public:
|
|
20
|
+
using DefaultDocType = DocumentDMR<TermWeight::one>;
|
|
21
|
+
static IDMRModel* create(TermWeight _weight, size_t _K = 1,
|
|
22
|
+
Float defaultAlpha = 1.0, Float _sigma = 1.0, Float _eta = 0.01, Float _alphaEps = 1e-10,
|
|
23
|
+
size_t seed = std::random_device{}(),
|
|
24
|
+
bool scalarRng = false);
|
|
25
|
+
|
|
26
|
+
virtual size_t addDoc(const std::vector<std::string>& words, const std::vector<std::string>& metadata) = 0;
|
|
27
|
+
virtual std::unique_ptr<DocumentBase> makeDoc(const std::vector<std::string>& words, const std::vector<std::string>& metadata) const = 0;
|
|
28
|
+
|
|
29
|
+
virtual size_t addDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
|
|
30
|
+
const std::vector<std::string>& metadata) = 0;
|
|
31
|
+
virtual std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
|
|
32
|
+
const std::vector<std::string>& metadata) const = 0;
|
|
33
|
+
|
|
34
|
+
virtual size_t addDoc(const std::string& rawStr, const std::vector<Vid>& words,
|
|
35
|
+
const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
|
|
36
|
+
const std::vector<std::string>& metadata) = 0;
|
|
37
|
+
virtual std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const std::vector<Vid>& words,
|
|
38
|
+
const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
|
|
39
|
+
const std::vector<std::string>& metadata) const = 0;
|
|
40
|
+
|
|
41
|
+
virtual void setAlphaEps(Float _alphaEps) = 0;
|
|
42
|
+
virtual Float getAlphaEps() const = 0;
|
|
43
|
+
virtual void setOptimRepeat(size_t repeat) = 0;
|
|
44
|
+
virtual size_t getOptimRepeat() const = 0;
|
|
45
|
+
virtual size_t getF() const = 0;
|
|
46
|
+
virtual Float getSigma() const = 0;
|
|
47
|
+
virtual const Dictionary& getMetadataDict() const = 0;
|
|
48
|
+
virtual std::vector<Float> getLambdaByMetadata(size_t metadataId) const = 0;
|
|
49
|
+
virtual std::vector<Float> getLambdaByTopic(Tid tid) const = 0;
|
|
50
|
+
};
|
|
51
|
+
}
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
#include "DMRModel.hpp"
|
|
2
|
+
|
|
3
|
+
namespace tomoto
|
|
4
|
+
{
|
|
5
|
+
/*template class DMRModel<TermWeight::one>;
|
|
6
|
+
template class DMRModel<TermWeight::idf>;
|
|
7
|
+
template class DMRModel<TermWeight::pmi>;*/
|
|
8
|
+
|
|
9
|
+
/**
 * Factory for DMR models. The concrete DMRModel instantiation is selected by
 * the TMT_SWITCH_TW macro from _weight and scalarRng; the remaining arguments
 * are handed to the macro for the model's construction.
 */
IDMRModel* IDMRModel::create(
	TermWeight _weight,
	size_t _K,
	Float _defaultAlpha,
	Float _sigma,
	Float _eta,
	Float _alphaEps,
	size_t seed,
	bool scalarRng)
{
	TMT_SWITCH_TW(_weight, scalarRng, DMRModel, _K, _defaultAlpha, _sigma, _eta, _alphaEps, seed);
}
|
|
13
|
+
}
|
|
@@ -0,0 +1,374 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
#include "LDAModel.hpp"
|
|
3
|
+
#include "../Utils/LBFGS.h"
|
|
4
|
+
#include "../Utils/text.hpp"
|
|
5
|
+
#include "DMR.h"
|
|
6
|
+
/*
|
|
7
|
+
Implementation of DMR using Gibbs sampling by bab2min
|
|
8
|
+
* Mimno, D., & McCallum, A. (2012). Topic models conditioned on arbitrary features with dirichlet-multinomial regression. arXiv preprint arXiv:1206.3278.
|
|
9
|
+
*/
|
|
10
|
+
|
|
11
|
+
namespace tomoto
|
|
12
|
+
{
|
|
13
|
+
template<TermWeight _tw>
|
|
14
|
+
struct ModelStateDMR : public ModelStateLDA<_tw>
|
|
15
|
+
{
|
|
16
|
+
Eigen::Matrix<Float, -1, 1> tmpK;
|
|
17
|
+
};
|
|
18
|
+
|
|
19
|
+
template<TermWeight _tw, typename _RandGen,
|
|
20
|
+
size_t _Flags = flags::partitioned_multisampling,
|
|
21
|
+
typename _Interface = IDMRModel,
|
|
22
|
+
typename _Derived = void,
|
|
23
|
+
typename _DocType = DocumentDMR<_tw>,
|
|
24
|
+
typename _ModelState = ModelStateDMR<_tw>>
|
|
25
|
+
class DMRModel : public LDAModel<_tw, _RandGen, _Flags, _Interface,
|
|
26
|
+
typename std::conditional<std::is_same<_Derived, void>::value, DMRModel<_tw, _RandGen, _Flags>, _Derived>::type,
|
|
27
|
+
_DocType, _ModelState>
|
|
28
|
+
{
|
|
29
|
+
protected:
|
|
30
|
+
using DerivedClass = typename std::conditional<std::is_same<_Derived, void>::value, DMRModel<_tw, _RandGen>, _Derived>::type;
|
|
31
|
+
using BaseClass = LDAModel<_tw, _RandGen, _Flags, _Interface, DerivedClass, _DocType, _ModelState>;
|
|
32
|
+
friend BaseClass;
|
|
33
|
+
friend typename BaseClass::BaseClass;
|
|
34
|
+
using WeightType = typename BaseClass::WeightType;
|
|
35
|
+
|
|
36
|
+
static constexpr char TMID[] = "DMR\0";
|
|
37
|
+
|
|
38
|
+
Eigen::Matrix<Float, -1, -1> lambda;
|
|
39
|
+
Eigen::Matrix<Float, -1, -1> expLambda;
|
|
40
|
+
Float sigma;
|
|
41
|
+
uint32_t F = 0;
|
|
42
|
+
uint32_t optimRepeat = 5;
|
|
43
|
+
Float alphaEps = 1e-10;
|
|
44
|
+
Float temperatureScale = 0;
|
|
45
|
+
static constexpr Float maxLambda = 10;
|
|
46
|
+
static constexpr size_t maxBFGSIteration = 10;
|
|
47
|
+
|
|
48
|
+
Dictionary metadataDict;
|
|
49
|
+
LBFGSpp::LBFGSSolver<Float, LBFGSpp::LineSearchBracketing> solver;
|
|
50
|
+
|
|
51
|
+
/**
 * Quadratic penalty on x around log(alpha).
 *
 * @param x point at which to evaluate
 * @param g overwritten with (x - log(alpha)) / sigma^2
 * @return sum((x - log(alpha))^2) / 2 / sigma^2
 */
Float getNegativeLambdaLL(Eigen::Ref<Eigen::Matrix<Float, -1, 1>> x, Eigen::Matrix<Float, -1, 1>& g) const
{
	const auto logAlpha = log(this->alpha);
	const auto variance = pow(sigma, 2);

	g = (x.array() - logAlpha) / variance;
	return (x.array() - logAlpha).pow(2).sum() / 2 / variance;
}
|
|
56
|
+
|
|
57
|
+
/**
 * Objective and gradient for the lambda optimisation, evaluated in parallel over documents.
 *
 * The prior part comes from getNegativeLambdaLL(), which also initialises g. The
 * documents are then split into (workers * 8) interleaved partitions; each task
 * returns a vector of length K*F + 1 whose first K*F entries are added to g and
 * whose last entry is added to the running value fx.
 *
 * @param x         current parameters, laid out as F consecutive segments of K values
 * @param g         receives the gradient
 * @param pool      worker pool running the per-partition tasks
 * @param localData per-thread states; tmpK of the executing thread is used as scratch space
 * @return -fx, or INFINITY when any component of x exceeds maxLambda or when fx is positive
 */
Float evaluateLambdaObj(Eigen::Ref<Eigen::Matrix<Float, -1, 1>> x, Eigen::Matrix<Float, -1, 1>& g, ThreadPool& pool, _ModelState* localData) const
{
	// if one of x is greater than maxLambda, return +inf for preventing searching more
	if ((x.array() > maxLambda).any()) return INFINITY;

	const auto K = this->K;

	Float fx = - static_cast<const DerivedClass*>(this)->getNegativeLambdaLL(x, g);
	auto alphas = (x.array().exp() + alphaEps).eval();

	std::vector<std::future<Eigen::Matrix<Float, -1, 1>>> res;
	const size_t chStride = pool.getNumWorkers() * 8;
	res.reserve(chStride);
	for (size_t ch = 0; ch < chStride; ++ch)
	{
		// `ch` must be captured by value: the task runs asynchronously, after the loop
		// counter has been incremented (or destroyed). Everything else captured by
		// reference outlives the tasks because all futures are joined below.
		res.emplace_back(pool.enqueue([&, ch](size_t threadId)
		{
			auto& tmpK = localData[threadId].tmpK;
			if (!tmpK.size()) tmpK.resize(this->K);
			// entries [0, K*F) accumulate the gradient part, entry K*F the objective part
			Eigen::Matrix<Float, -1, 1> val = Eigen::Matrix<Float, -1, 1>::Zero(K * F + 1);
			for (size_t docId = ch; docId < this->docs.size(); docId += chStride)
			{
				const auto& doc = this->docs[docId];
				auto alphaDoc = alphas.segment(doc.metadata * K, K);
				Float alphaSum = alphaDoc.sum();
				for (Tid k = 0; k < K; ++k)
				{
					val[K * F] -= math::lgammaT(alphaDoc[k]) - math::lgammaT(doc.numByTopic[k] + alphaDoc[k]);
					// an infinite alpha contributes no gradient
					if (!std::isfinite(alphaDoc[k]) && alphaDoc[k] > 0) tmpK[k] = 0;
					else tmpK[k] = -(math::digammaT(alphaDoc[k]) - math::digammaT(doc.numByTopic[k] + alphaDoc[k]));
				}
				val[K * F] += math::lgammaT(alphaSum) - math::lgammaT(doc.getSumWordWeight() + alphaSum);
				Float t = math::digammaT(alphaSum) - math::digammaT(doc.getSumWordWeight() + alphaSum);
				if (!std::isfinite(alphaSum) && alphaSum > 0)
				{
					val[K * F] = -INFINITY;
					t = 0;
				}
				val.segment(doc.metadata * K, K).array() -= alphaDoc.array() * (tmpK.array() + t);
			}
			return val;
		}));
	}
	for (auto& r : res)
	{
		auto ret = r.get();
		fx += ret[K * F];
		g += ret.head(K * F);
	}

	// positive fx is an error from limited precision of float.
	if (fx > 0) return INFINITY;
	return -fx;
}
|
|
112
|
+
|
|
113
|
+
void initParameters()
|
|
114
|
+
{
|
|
115
|
+
auto dist = std::normal_distribution<Float>(log(this->alpha), sigma);
|
|
116
|
+
for (size_t i = 0; i < this->K; ++i) for (size_t j = 0; j < F; ++j)
|
|
117
|
+
{
|
|
118
|
+
lambda(i, j) = dist(this->rg);
|
|
119
|
+
}
|
|
120
|
+
}
|
|
121
|
+
|
|
122
|
+
/**
 * Optimises lambda with the L-BFGS solver, restarting optimRepeat times from
 * freshly drawn parameters (initParameters) and keeping the restart with the
 * smallest objective value. Afterwards expLambda is recomputed as
 * exp(lambda) + alphaEps.
 *
 * @param pool      worker pool forwarded to evaluateLambdaObj
 * @param localData per-thread states forwarded to evaluateLambdaObj
 * @param rgs       unused by this method
 * @throws exception::TrainingError when no restart produced a finite objective value
 */
void optimizeParameters(ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
{
	using Vector = Eigen::Matrix<Float, -1, 1>;

	auto objective = [this, &pool, localData](Eigen::Ref<Vector> x, Vector& g)
	{
		return static_cast<DerivedClass*>(this)->evaluateLambdaObj(x, g, pool, localData);
	};

	Eigen::Matrix<Float, -1, -1> bestLambda;
	Float fx = 0, bestFx = INFINITY;
	for (size_t attempt = 0; attempt < optimRepeat; ++attempt)
	{
		static_cast<DerivedClass*>(this)->initParameters();
		// lambda is optimised in place through a flat view of its storage
		solver.minimize(objective, Eigen::Map<Vector>(lambda.data(), lambda.size()), fx);

		if (fx < bestFx)
		{
			bestLambda = lambda;
			bestFx = fx;
		}
	}

	if (!std::isfinite(bestFx))
	{
		throw exception::TrainingError{ "optimizing parameters has been failed!" };
	}

	lambda = bestLambda;
	expLambda = lambda.array().exp() + alphaEps;
}
|
|
149
|
+
|
|
150
|
+
/**
 * Recovery hook for a TrainingError: logs a message to stderr, zeroes lambda
 * and recomputes expLambda as exp(lambda) + alphaEps.
 *
 * @param e         the error being recovered from (not inspected)
 * @param pool      unused by this method
 * @param localData unused by this method
 * @param rgs       unused by this method
 * @return always 0
 */
int restoreFromTrainingError(const exception::TrainingError& e, ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
{
	std::cerr << "Failed to optimize! Reset prior and retry!" << std::endl;

	lambda.setZero();
	expLambda = lambda.array().exp() + alphaEps;

	return 0;
}
|
|
157
|
+
|
|
158
|
+
template<bool _asymEta>
|
|
159
|
+
Float* getZLikelihoods(_ModelState& ld, const _DocType& doc, size_t docId, size_t vid) const
|
|
160
|
+
{
|
|
161
|
+
const size_t V = this->realV;
|
|
162
|
+
assert(vid < V);
|
|
163
|
+
auto etaHelper = this->template getEtaHelper<_asymEta>();
|
|
164
|
+
auto& zLikelihood = ld.zLikelihood;
|
|
165
|
+
zLikelihood = (doc.numByTopic.array().template cast<Float>() + this->expLambda.col(doc.metadata).array())
|
|
166
|
+
* (ld.numByTopicWord.col(vid).array().template cast<Float>() + etaHelper.getEta(vid))
|
|
167
|
+
/ (ld.numByTopic.array().template cast<Float>() + etaHelper.getEtaSum());
|
|
168
|
+
|
|
169
|
+
sample::prefixSum(zLikelihood.data(), this->K);
|
|
170
|
+
return &zLikelihood[0];
|
|
171
|
+
}
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
double getLLDocTopic(const _DocType& doc) const
|
|
175
|
+
{
|
|
176
|
+
const size_t V = this->realV;
|
|
177
|
+
const auto K = this->K;
|
|
178
|
+
|
|
179
|
+
auto alphaDoc = expLambda.col(doc.metadata);
|
|
180
|
+
|
|
181
|
+
Float ll = 0;
|
|
182
|
+
Float alphaSum = alphaDoc.sum();
|
|
183
|
+
for (Tid k = 0; k < K; ++k)
|
|
184
|
+
{
|
|
185
|
+
ll += math::lgammaT(doc.numByTopic[k] + alphaDoc[k]);
|
|
186
|
+
ll -= math::lgammaT(alphaDoc[k]);
|
|
187
|
+
}
|
|
188
|
+
ll -= math::lgammaT(doc.getSumWordWeight() + alphaSum);
|
|
189
|
+
ll += math::lgammaT(alphaSum);
|
|
190
|
+
return ll;
|
|
191
|
+
}
|
|
192
|
+
|
|
193
|
+
template<typename _DocIter>
|
|
194
|
+
double getLLDocs(_DocIter _first, _DocIter _last) const
|
|
195
|
+
{
|
|
196
|
+
const auto K = this->K;
|
|
197
|
+
|
|
198
|
+
double ll = 0;
|
|
199
|
+
for (; _first != _last; ++_first)
|
|
200
|
+
{
|
|
201
|
+
auto& doc = *_first;
|
|
202
|
+
auto alphaDoc = expLambda.col(doc.metadata);
|
|
203
|
+
Float alphaSum = alphaDoc.sum();
|
|
204
|
+
|
|
205
|
+
for (Tid k = 0; k < K; ++k)
|
|
206
|
+
{
|
|
207
|
+
ll += math::lgammaT(doc.numByTopic[k] + alphaDoc[k]) - math::lgammaT(alphaDoc[k]);
|
|
208
|
+
}
|
|
209
|
+
ll -= math::lgammaT(doc.getSumWordWeight() + alphaSum) - math::lgammaT(alphaSum);
|
|
210
|
+
}
|
|
211
|
+
return ll;
|
|
212
|
+
}
|
|
213
|
+
|
|
214
|
+
double getLLRest(const _ModelState& ld) const
|
|
215
|
+
{
|
|
216
|
+
const auto K = this->K;
|
|
217
|
+
const auto alpha = this->alpha;
|
|
218
|
+
const auto eta = this->eta;
|
|
219
|
+
const size_t V = this->realV;
|
|
220
|
+
|
|
221
|
+
double ll = -(lambda.array() - log(alpha)).pow(2).sum() / 2 / pow(sigma, 2);
|
|
222
|
+
// topic-word distribution
|
|
223
|
+
auto lgammaEta = math::lgammaT(eta);
|
|
224
|
+
ll += math::lgammaT(V*eta) * K;
|
|
225
|
+
for (Tid k = 0; k < K; ++k)
|
|
226
|
+
{
|
|
227
|
+
ll -= math::lgammaT(ld.numByTopic[k] + V * eta);
|
|
228
|
+
for (Vid v = 0; v < V; ++v)
|
|
229
|
+
{
|
|
230
|
+
if (!ld.numByTopicWord(k, v)) continue;
|
|
231
|
+
ll += math::lgammaT(ld.numByTopicWord(k, v) + eta) - lgammaEta;
|
|
232
|
+
}
|
|
233
|
+
}
|
|
234
|
+
return ll;
|
|
235
|
+
}
|
|
236
|
+
|
|
237
|
+
void initGlobalState(bool initDocs)
|
|
238
|
+
{
|
|
239
|
+
BaseClass::initGlobalState(initDocs);
|
|
240
|
+
this->globalState.tmpK = Eigen::Matrix<Float, -1, 1>::Zero(this->K);
|
|
241
|
+
F = metadataDict.size();
|
|
242
|
+
if (initDocs)
|
|
243
|
+
{
|
|
244
|
+
lambda = Eigen::Matrix<Float, -1, -1>::Constant(this->K, F, log(this->alpha));
|
|
245
|
+
}
|
|
246
|
+
if (_Flags & flags::continuous_doc_data) this->numByTopicDoc = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K, this->docs.size());
|
|
247
|
+
expLambda = lambda.array().exp();
|
|
248
|
+
LBFGSpp::LBFGSParam<Float> param;
|
|
249
|
+
param.max_iterations = maxBFGSIteration;
|
|
250
|
+
solver = decltype(solver){ param };
|
|
251
|
+
}
|
|
252
|
+
|
|
253
|
+
public:
|
|
254
|
+
DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, sigma, alphaEps, metadataDict, lambda);
|
|
255
|
+
DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, sigma, alphaEps, metadataDict, lambda);
|
|
256
|
+
|
|
257
|
+
DMRModel(size_t _K = 1, Float defaultAlpha = 1.0, Float _sigma = 1.0, Float _eta = 0.01,
|
|
258
|
+
Float _alphaEps = 0, size_t _rg = std::random_device{}())
|
|
259
|
+
: BaseClass(_K, defaultAlpha, _eta, _rg), sigma(_sigma), alphaEps(_alphaEps)
|
|
260
|
+
{
|
|
261
|
+
if (_sigma <= 0) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong sigma value (sigma = %f)", _sigma));
|
|
262
|
+
}
|
|
263
|
+
|
|
264
|
+
template<bool _const = false>
|
|
265
|
+
_DocType& _updateDoc(_DocType& doc, const std::vector<std::string>& metadata)
|
|
266
|
+
{
|
|
267
|
+
std::string metadataJoined = text::join(metadata.begin(), metadata.end(), "_");
|
|
268
|
+
Vid xid;
|
|
269
|
+
if (_const)
|
|
270
|
+
{
|
|
271
|
+
xid = metadataDict.toWid(metadataJoined);
|
|
272
|
+
if (xid == (Vid)-1) throw std::invalid_argument("unknown metadata");
|
|
273
|
+
}
|
|
274
|
+
else
|
|
275
|
+
{
|
|
276
|
+
xid = metadataDict.add(metadataJoined);
|
|
277
|
+
}
|
|
278
|
+
doc.metadata = xid;
|
|
279
|
+
return doc;
|
|
280
|
+
}
|
|
281
|
+
|
|
282
|
+
size_t addDoc(const std::vector<std::string>& words, const std::vector<std::string>& metadata) override
|
|
283
|
+
{
|
|
284
|
+
auto doc = this->_makeDoc(words);
|
|
285
|
+
return this->_addDoc(_updateDoc(doc, metadata));
|
|
286
|
+
}
|
|
287
|
+
|
|
288
|
+
std::unique_ptr<DocumentBase> makeDoc(const std::vector<std::string>& words, const std::vector<std::string>& metadata) const override
|
|
289
|
+
{
|
|
290
|
+
auto doc = as_mutable(this)->template _makeDoc<true>(words);
|
|
291
|
+
return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, metadata));
|
|
292
|
+
}
|
|
293
|
+
|
|
294
|
+
/**
 * Builds a document from a raw string using `tokenizer`, attaches `metadata`
 * to it (registering the metadata key in the dictionary) and adds it to the model.
 *
 * @return the value returned by _addDoc
 */
size_t addDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
	const std::vector<std::string>& metadata) override
{
	auto newDoc = this->template _makeRawDoc<false>(rawStr, tokenizer);
	_updateDoc(newDoc, metadata);
	return this->_addDoc(newDoc);
}
|
|
300
|
+
|
|
301
|
+
/**
 * Builds a standalone document from a raw string using `tokenizer`, without
 * adding it to the model. The metadata is only looked up in the dictionary.
 *
 * @throws std::invalid_argument from _updateDoc<true> when the metadata is unknown
 */
std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
	const std::vector<std::string>& metadata) const override
{
	auto* self = as_mutable(this);
	auto newDoc = self->template _makeRawDoc<true>(rawStr, tokenizer);
	self->template _updateDoc<true>(newDoc, metadata);
	return make_unique<_DocType>(newDoc);
}
|
|
307
|
+
|
|
308
|
+
size_t addDoc(const std::string& rawStr, const std::vector<Vid>& words,
|
|
309
|
+
const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
|
|
310
|
+
const std::vector<std::string>& metadata) override
|
|
311
|
+
{
|
|
312
|
+
auto doc = this->_makeRawDoc(rawStr, words, pos, len);
|
|
313
|
+
return this->_addDoc(_updateDoc(doc, metadata));
|
|
314
|
+
}
|
|
315
|
+
|
|
316
|
+
std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const std::vector<Vid>& words,
|
|
317
|
+
const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
|
|
318
|
+
const std::vector<std::string>& metadata) const override
|
|
319
|
+
{
|
|
320
|
+
auto doc = this->_makeRawDoc(rawStr, words, pos, len);
|
|
321
|
+
return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, metadata));
|
|
322
|
+
}
|
|
323
|
+
|
|
324
|
+
GETTER(F, size_t, F);
|
|
325
|
+
GETTER(Sigma, Float, sigma);
|
|
326
|
+
GETTER(AlphaEps, Float, alphaEps);
|
|
327
|
+
GETTER(OptimRepeat, size_t, optimRepeat);
|
|
328
|
+
|
|
329
|
+
/** Stores _alphaEps in alphaEps; expLambda is not recomputed here. */
void setAlphaEps(Float _alphaEps) override
{
	this->alphaEps = _alphaEps;
}
|
|
333
|
+
|
|
334
|
+
void setOptimRepeat(size_t _optimRepeat) override
|
|
335
|
+
{
|
|
336
|
+
optimRepeat = _optimRepeat;
|
|
337
|
+
}
|
|
338
|
+
|
|
339
|
+
std::vector<Float> getTopicsByDoc(const _DocType& doc) const
|
|
340
|
+
{
|
|
341
|
+
std::vector<Float> ret(this->K);
|
|
342
|
+
auto alphaDoc = expLambda.col(doc.metadata);
|
|
343
|
+
Eigen::Map<Eigen::Matrix<Float, -1, 1>>{ret.data(), this->K}.array() =
|
|
344
|
+
(doc.numByTopic.array().template cast<Float>() + alphaDoc.array()) / (doc.getSumWordWeight() + alphaDoc.sum());
|
|
345
|
+
return ret;
|
|
346
|
+
}
|
|
347
|
+
|
|
348
|
+
std::vector<Float> getLambdaByMetadata(size_t metadataId) const override
|
|
349
|
+
{
|
|
350
|
+
assert(metadataId < metadataDict.size());
|
|
351
|
+
auto l = lambda.col(metadataId);
|
|
352
|
+
return { l.data(), l.data() + this->K };
|
|
353
|
+
}
|
|
354
|
+
|
|
355
|
+
/**
 * Returns the F lambda values of topic `tid`, i.e. row `tid` of lambda.
 *
 * lambda is declared with Eigen's default storage order, in which a row is
 * not contiguous in memory, so the row is copied element by element instead
 * of through a [data(), data() + F) pointer range. The topic id is checked
 * against K by assert only.
 */
std::vector<Float> getLambdaByTopic(Tid tid) const override
{
	assert(tid < this->K);
	std::vector<Float> ret(F);
	for (size_t f = 0; f < ret.size(); ++f)
	{
		ret[f] = lambda(tid, f);
	}
	return ret;
}
|
|
361
|
+
|
|
362
|
+
/** Returns a const reference to the model's metadata dictionary. */
const Dictionary& getMetadataDict() const override
{
	return metadataDict;
}
|
|
363
|
+
};
|
|
364
|
+
|
|
365
|
+
/* Out-of-class definitions of the static constexpr members; they prevent 'undefined symbol' errors when compiling with clang. */
|
|
366
|
+
template<TermWeight _tw, typename _RandGen, size_t _Flags,
|
|
367
|
+
typename _Interface, typename _Derived, typename _DocType, typename _ModelState>
|
|
368
|
+
constexpr Float DMRModel<_tw, _RandGen, _Flags, _Interface, _Derived, _DocType, _ModelState>::maxLambda;
|
|
369
|
+
|
|
370
|
+
template<TermWeight _tw, typename _RandGen, size_t _Flags,
|
|
371
|
+
typename _Interface, typename _Derived, typename _DocType, typename _ModelState>
|
|
372
|
+
constexpr size_t DMRModel<_tw, _RandGen, _Flags, _Interface, _Derived, _DocType, _ModelState>::maxBFGSIteration;
|
|
373
|
+
|
|
374
|
+
}
|