tomoto 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +22 -0
- data/README.md +123 -0
- data/ext/tomoto/ext.cpp +245 -0
- data/ext/tomoto/extconf.rb +28 -0
- data/lib/tomoto.rb +12 -0
- data/lib/tomoto/ct.rb +11 -0
- data/lib/tomoto/hdp.rb +11 -0
- data/lib/tomoto/lda.rb +67 -0
- data/lib/tomoto/version.rb +3 -0
- data/vendor/EigenRand/EigenRand/Core.h +1139 -0
- data/vendor/EigenRand/EigenRand/Dists/Basic.h +111 -0
- data/vendor/EigenRand/EigenRand/Dists/Discrete.h +877 -0
- data/vendor/EigenRand/EigenRand/Dists/GammaPoisson.h +108 -0
- data/vendor/EigenRand/EigenRand/Dists/NormalExp.h +626 -0
- data/vendor/EigenRand/EigenRand/EigenRand +19 -0
- data/vendor/EigenRand/EigenRand/Macro.h +24 -0
- data/vendor/EigenRand/EigenRand/MorePacketMath.h +978 -0
- data/vendor/EigenRand/EigenRand/PacketFilter.h +286 -0
- data/vendor/EigenRand/EigenRand/PacketRandomEngine.h +624 -0
- data/vendor/EigenRand/EigenRand/RandUtils.h +413 -0
- data/vendor/EigenRand/EigenRand/doc.h +220 -0
- data/vendor/EigenRand/LICENSE +21 -0
- data/vendor/EigenRand/README.md +288 -0
- data/vendor/eigen/COPYING.BSD +26 -0
- data/vendor/eigen/COPYING.GPL +674 -0
- data/vendor/eigen/COPYING.LGPL +502 -0
- data/vendor/eigen/COPYING.MINPACK +52 -0
- data/vendor/eigen/COPYING.MPL2 +373 -0
- data/vendor/eigen/COPYING.README +18 -0
- data/vendor/eigen/Eigen/CMakeLists.txt +19 -0
- data/vendor/eigen/Eigen/Cholesky +46 -0
- data/vendor/eigen/Eigen/CholmodSupport +48 -0
- data/vendor/eigen/Eigen/Core +537 -0
- data/vendor/eigen/Eigen/Dense +7 -0
- data/vendor/eigen/Eigen/Eigen +2 -0
- data/vendor/eigen/Eigen/Eigenvalues +61 -0
- data/vendor/eigen/Eigen/Geometry +62 -0
- data/vendor/eigen/Eigen/Householder +30 -0
- data/vendor/eigen/Eigen/IterativeLinearSolvers +48 -0
- data/vendor/eigen/Eigen/Jacobi +33 -0
- data/vendor/eigen/Eigen/LU +50 -0
- data/vendor/eigen/Eigen/MetisSupport +35 -0
- data/vendor/eigen/Eigen/OrderingMethods +73 -0
- data/vendor/eigen/Eigen/PaStiXSupport +48 -0
- data/vendor/eigen/Eigen/PardisoSupport +35 -0
- data/vendor/eigen/Eigen/QR +51 -0
- data/vendor/eigen/Eigen/QtAlignedMalloc +40 -0
- data/vendor/eigen/Eigen/SPQRSupport +34 -0
- data/vendor/eigen/Eigen/SVD +51 -0
- data/vendor/eigen/Eigen/Sparse +36 -0
- data/vendor/eigen/Eigen/SparseCholesky +45 -0
- data/vendor/eigen/Eigen/SparseCore +69 -0
- data/vendor/eigen/Eigen/SparseLU +46 -0
- data/vendor/eigen/Eigen/SparseQR +37 -0
- data/vendor/eigen/Eigen/StdDeque +27 -0
- data/vendor/eigen/Eigen/StdList +26 -0
- data/vendor/eigen/Eigen/StdVector +27 -0
- data/vendor/eigen/Eigen/SuperLUSupport +64 -0
- data/vendor/eigen/Eigen/UmfPackSupport +40 -0
- data/vendor/eigen/Eigen/src/Cholesky/LDLT.h +673 -0
- data/vendor/eigen/Eigen/src/Cholesky/LLT.h +542 -0
- data/vendor/eigen/Eigen/src/Cholesky/LLT_LAPACKE.h +99 -0
- data/vendor/eigen/Eigen/src/CholmodSupport/CholmodSupport.h +639 -0
- data/vendor/eigen/Eigen/src/Core/Array.h +329 -0
- data/vendor/eigen/Eigen/src/Core/ArrayBase.h +226 -0
- data/vendor/eigen/Eigen/src/Core/ArrayWrapper.h +209 -0
- data/vendor/eigen/Eigen/src/Core/Assign.h +90 -0
- data/vendor/eigen/Eigen/src/Core/AssignEvaluator.h +935 -0
- data/vendor/eigen/Eigen/src/Core/Assign_MKL.h +178 -0
- data/vendor/eigen/Eigen/src/Core/BandMatrix.h +353 -0
- data/vendor/eigen/Eigen/src/Core/Block.h +452 -0
- data/vendor/eigen/Eigen/src/Core/BooleanRedux.h +164 -0
- data/vendor/eigen/Eigen/src/Core/CommaInitializer.h +160 -0
- data/vendor/eigen/Eigen/src/Core/ConditionEstimator.h +175 -0
- data/vendor/eigen/Eigen/src/Core/CoreEvaluators.h +1688 -0
- data/vendor/eigen/Eigen/src/Core/CoreIterators.h +127 -0
- data/vendor/eigen/Eigen/src/Core/CwiseBinaryOp.h +184 -0
- data/vendor/eigen/Eigen/src/Core/CwiseNullaryOp.h +866 -0
- data/vendor/eigen/Eigen/src/Core/CwiseTernaryOp.h +197 -0
- data/vendor/eigen/Eigen/src/Core/CwiseUnaryOp.h +103 -0
- data/vendor/eigen/Eigen/src/Core/CwiseUnaryView.h +128 -0
- data/vendor/eigen/Eigen/src/Core/DenseBase.h +611 -0
- data/vendor/eigen/Eigen/src/Core/DenseCoeffsBase.h +681 -0
- data/vendor/eigen/Eigen/src/Core/DenseStorage.h +570 -0
- data/vendor/eigen/Eigen/src/Core/Diagonal.h +260 -0
- data/vendor/eigen/Eigen/src/Core/DiagonalMatrix.h +343 -0
- data/vendor/eigen/Eigen/src/Core/DiagonalProduct.h +28 -0
- data/vendor/eigen/Eigen/src/Core/Dot.h +318 -0
- data/vendor/eigen/Eigen/src/Core/EigenBase.h +159 -0
- data/vendor/eigen/Eigen/src/Core/ForceAlignedAccess.h +146 -0
- data/vendor/eigen/Eigen/src/Core/Fuzzy.h +155 -0
- data/vendor/eigen/Eigen/src/Core/GeneralProduct.h +455 -0
- data/vendor/eigen/Eigen/src/Core/GenericPacketMath.h +593 -0
- data/vendor/eigen/Eigen/src/Core/GlobalFunctions.h +187 -0
- data/vendor/eigen/Eigen/src/Core/IO.h +225 -0
- data/vendor/eigen/Eigen/src/Core/Inverse.h +118 -0
- data/vendor/eigen/Eigen/src/Core/Map.h +171 -0
- data/vendor/eigen/Eigen/src/Core/MapBase.h +303 -0
- data/vendor/eigen/Eigen/src/Core/MathFunctions.h +1415 -0
- data/vendor/eigen/Eigen/src/Core/MathFunctionsImpl.h +101 -0
- data/vendor/eigen/Eigen/src/Core/Matrix.h +459 -0
- data/vendor/eigen/Eigen/src/Core/MatrixBase.h +529 -0
- data/vendor/eigen/Eigen/src/Core/NestByValue.h +110 -0
- data/vendor/eigen/Eigen/src/Core/NoAlias.h +108 -0
- data/vendor/eigen/Eigen/src/Core/NumTraits.h +248 -0
- data/vendor/eigen/Eigen/src/Core/PermutationMatrix.h +633 -0
- data/vendor/eigen/Eigen/src/Core/PlainObjectBase.h +1035 -0
- data/vendor/eigen/Eigen/src/Core/Product.h +186 -0
- data/vendor/eigen/Eigen/src/Core/ProductEvaluators.h +1112 -0
- data/vendor/eigen/Eigen/src/Core/Random.h +182 -0
- data/vendor/eigen/Eigen/src/Core/Redux.h +505 -0
- data/vendor/eigen/Eigen/src/Core/Ref.h +283 -0
- data/vendor/eigen/Eigen/src/Core/Replicate.h +142 -0
- data/vendor/eigen/Eigen/src/Core/ReturnByValue.h +117 -0
- data/vendor/eigen/Eigen/src/Core/Reverse.h +211 -0
- data/vendor/eigen/Eigen/src/Core/Select.h +162 -0
- data/vendor/eigen/Eigen/src/Core/SelfAdjointView.h +352 -0
- data/vendor/eigen/Eigen/src/Core/SelfCwiseBinaryOp.h +47 -0
- data/vendor/eigen/Eigen/src/Core/Solve.h +188 -0
- data/vendor/eigen/Eigen/src/Core/SolveTriangular.h +235 -0
- data/vendor/eigen/Eigen/src/Core/SolverBase.h +130 -0
- data/vendor/eigen/Eigen/src/Core/StableNorm.h +221 -0
- data/vendor/eigen/Eigen/src/Core/Stride.h +111 -0
- data/vendor/eigen/Eigen/src/Core/Swap.h +67 -0
- data/vendor/eigen/Eigen/src/Core/Transpose.h +403 -0
- data/vendor/eigen/Eigen/src/Core/Transpositions.h +407 -0
- data/vendor/eigen/Eigen/src/Core/TriangularMatrix.h +983 -0
- data/vendor/eigen/Eigen/src/Core/VectorBlock.h +96 -0
- data/vendor/eigen/Eigen/src/Core/VectorwiseOp.h +695 -0
- data/vendor/eigen/Eigen/src/Core/Visitor.h +273 -0
- data/vendor/eigen/Eigen/src/Core/arch/AVX/Complex.h +451 -0
- data/vendor/eigen/Eigen/src/Core/arch/AVX/MathFunctions.h +439 -0
- data/vendor/eigen/Eigen/src/Core/arch/AVX/PacketMath.h +637 -0
- data/vendor/eigen/Eigen/src/Core/arch/AVX/TypeCasting.h +51 -0
- data/vendor/eigen/Eigen/src/Core/arch/AVX512/MathFunctions.h +391 -0
- data/vendor/eigen/Eigen/src/Core/arch/AVX512/PacketMath.h +1316 -0
- data/vendor/eigen/Eigen/src/Core/arch/AltiVec/Complex.h +430 -0
- data/vendor/eigen/Eigen/src/Core/arch/AltiVec/MathFunctions.h +322 -0
- data/vendor/eigen/Eigen/src/Core/arch/AltiVec/PacketMath.h +1061 -0
- data/vendor/eigen/Eigen/src/Core/arch/CUDA/Complex.h +103 -0
- data/vendor/eigen/Eigen/src/Core/arch/CUDA/Half.h +674 -0
- data/vendor/eigen/Eigen/src/Core/arch/CUDA/MathFunctions.h +91 -0
- data/vendor/eigen/Eigen/src/Core/arch/CUDA/PacketMath.h +333 -0
- data/vendor/eigen/Eigen/src/Core/arch/CUDA/PacketMathHalf.h +1124 -0
- data/vendor/eigen/Eigen/src/Core/arch/CUDA/TypeCasting.h +212 -0
- data/vendor/eigen/Eigen/src/Core/arch/Default/ConjHelper.h +29 -0
- data/vendor/eigen/Eigen/src/Core/arch/Default/Settings.h +49 -0
- data/vendor/eigen/Eigen/src/Core/arch/NEON/Complex.h +490 -0
- data/vendor/eigen/Eigen/src/Core/arch/NEON/MathFunctions.h +91 -0
- data/vendor/eigen/Eigen/src/Core/arch/NEON/PacketMath.h +760 -0
- data/vendor/eigen/Eigen/src/Core/arch/SSE/Complex.h +471 -0
- data/vendor/eigen/Eigen/src/Core/arch/SSE/MathFunctions.h +562 -0
- data/vendor/eigen/Eigen/src/Core/arch/SSE/PacketMath.h +895 -0
- data/vendor/eigen/Eigen/src/Core/arch/SSE/TypeCasting.h +77 -0
- data/vendor/eigen/Eigen/src/Core/arch/ZVector/Complex.h +397 -0
- data/vendor/eigen/Eigen/src/Core/arch/ZVector/MathFunctions.h +137 -0
- data/vendor/eigen/Eigen/src/Core/arch/ZVector/PacketMath.h +945 -0
- data/vendor/eigen/Eigen/src/Core/functors/AssignmentFunctors.h +168 -0
- data/vendor/eigen/Eigen/src/Core/functors/BinaryFunctors.h +475 -0
- data/vendor/eigen/Eigen/src/Core/functors/NullaryFunctors.h +188 -0
- data/vendor/eigen/Eigen/src/Core/functors/StlFunctors.h +136 -0
- data/vendor/eigen/Eigen/src/Core/functors/TernaryFunctors.h +25 -0
- data/vendor/eigen/Eigen/src/Core/functors/UnaryFunctors.h +792 -0
- data/vendor/eigen/Eigen/src/Core/products/GeneralBlockPanelKernel.h +2156 -0
- data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrix.h +492 -0
- data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h +311 -0
- data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h +145 -0
- data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h +122 -0
- data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixVector.h +619 -0
- data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixVector_BLAS.h +136 -0
- data/vendor/eigen/Eigen/src/Core/products/Parallelizer.h +163 -0
- data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix.h +521 -0
- data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix_BLAS.h +287 -0
- data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixVector.h +260 -0
- data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixVector_BLAS.h +118 -0
- data/vendor/eigen/Eigen/src/Core/products/SelfadjointProduct.h +133 -0
- data/vendor/eigen/Eigen/src/Core/products/SelfadjointRank2Update.h +93 -0
- data/vendor/eigen/Eigen/src/Core/products/TriangularMatrixMatrix.h +466 -0
- data/vendor/eigen/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h +315 -0
- data/vendor/eigen/Eigen/src/Core/products/TriangularMatrixVector.h +350 -0
- data/vendor/eigen/Eigen/src/Core/products/TriangularMatrixVector_BLAS.h +255 -0
- data/vendor/eigen/Eigen/src/Core/products/TriangularSolverMatrix.h +335 -0
- data/vendor/eigen/Eigen/src/Core/products/TriangularSolverMatrix_BLAS.h +163 -0
- data/vendor/eigen/Eigen/src/Core/products/TriangularSolverVector.h +145 -0
- data/vendor/eigen/Eigen/src/Core/util/BlasUtil.h +398 -0
- data/vendor/eigen/Eigen/src/Core/util/Constants.h +547 -0
- data/vendor/eigen/Eigen/src/Core/util/DisableStupidWarnings.h +83 -0
- data/vendor/eigen/Eigen/src/Core/util/ForwardDeclarations.h +302 -0
- data/vendor/eigen/Eigen/src/Core/util/MKL_support.h +130 -0
- data/vendor/eigen/Eigen/src/Core/util/Macros.h +1001 -0
- data/vendor/eigen/Eigen/src/Core/util/Memory.h +993 -0
- data/vendor/eigen/Eigen/src/Core/util/Meta.h +534 -0
- data/vendor/eigen/Eigen/src/Core/util/NonMPL2.h +3 -0
- data/vendor/eigen/Eigen/src/Core/util/ReenableStupidWarnings.h +27 -0
- data/vendor/eigen/Eigen/src/Core/util/StaticAssert.h +218 -0
- data/vendor/eigen/Eigen/src/Core/util/XprHelper.h +821 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/ComplexEigenSolver.h +346 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/ComplexSchur.h +459 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/ComplexSchur_LAPACKE.h +91 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/EigenSolver.h +622 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/GeneralizedEigenSolver.h +418 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/GeneralizedSelfAdjointEigenSolver.h +226 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/HessenbergDecomposition.h +374 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/MatrixBaseEigenvalues.h +158 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/RealQZ.h +654 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/RealSchur.h +546 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/RealSchur_LAPACKE.h +77 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver.h +870 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver_LAPACKE.h +87 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/Tridiagonalization.h +556 -0
- data/vendor/eigen/Eigen/src/Geometry/AlignedBox.h +392 -0
- data/vendor/eigen/Eigen/src/Geometry/AngleAxis.h +247 -0
- data/vendor/eigen/Eigen/src/Geometry/EulerAngles.h +114 -0
- data/vendor/eigen/Eigen/src/Geometry/Homogeneous.h +497 -0
- data/vendor/eigen/Eigen/src/Geometry/Hyperplane.h +282 -0
- data/vendor/eigen/Eigen/src/Geometry/OrthoMethods.h +234 -0
- data/vendor/eigen/Eigen/src/Geometry/ParametrizedLine.h +195 -0
- data/vendor/eigen/Eigen/src/Geometry/Quaternion.h +814 -0
- data/vendor/eigen/Eigen/src/Geometry/Rotation2D.h +199 -0
- data/vendor/eigen/Eigen/src/Geometry/RotationBase.h +206 -0
- data/vendor/eigen/Eigen/src/Geometry/Scaling.h +170 -0
- data/vendor/eigen/Eigen/src/Geometry/Transform.h +1542 -0
- data/vendor/eigen/Eigen/src/Geometry/Translation.h +208 -0
- data/vendor/eigen/Eigen/src/Geometry/Umeyama.h +166 -0
- data/vendor/eigen/Eigen/src/Geometry/arch/Geometry_SSE.h +161 -0
- data/vendor/eigen/Eigen/src/Householder/BlockHouseholder.h +103 -0
- data/vendor/eigen/Eigen/src/Householder/Householder.h +172 -0
- data/vendor/eigen/Eigen/src/Householder/HouseholderSequence.h +470 -0
- data/vendor/eigen/Eigen/src/IterativeLinearSolvers/BasicPreconditioners.h +226 -0
- data/vendor/eigen/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h +228 -0
- data/vendor/eigen/Eigen/src/IterativeLinearSolvers/ConjugateGradient.h +246 -0
- data/vendor/eigen/Eigen/src/IterativeLinearSolvers/IncompleteCholesky.h +400 -0
- data/vendor/eigen/Eigen/src/IterativeLinearSolvers/IncompleteLUT.h +462 -0
- data/vendor/eigen/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h +394 -0
- data/vendor/eigen/Eigen/src/IterativeLinearSolvers/LeastSquareConjugateGradient.h +216 -0
- data/vendor/eigen/Eigen/src/IterativeLinearSolvers/SolveWithGuess.h +115 -0
- data/vendor/eigen/Eigen/src/Jacobi/Jacobi.h +462 -0
- data/vendor/eigen/Eigen/src/LU/Determinant.h +101 -0
- data/vendor/eigen/Eigen/src/LU/FullPivLU.h +891 -0
- data/vendor/eigen/Eigen/src/LU/InverseImpl.h +415 -0
- data/vendor/eigen/Eigen/src/LU/PartialPivLU.h +611 -0
- data/vendor/eigen/Eigen/src/LU/PartialPivLU_LAPACKE.h +83 -0
- data/vendor/eigen/Eigen/src/LU/arch/Inverse_SSE.h +338 -0
- data/vendor/eigen/Eigen/src/MetisSupport/MetisSupport.h +137 -0
- data/vendor/eigen/Eigen/src/OrderingMethods/Amd.h +445 -0
- data/vendor/eigen/Eigen/src/OrderingMethods/Eigen_Colamd.h +1843 -0
- data/vendor/eigen/Eigen/src/OrderingMethods/Ordering.h +157 -0
- data/vendor/eigen/Eigen/src/PaStiXSupport/PaStiXSupport.h +678 -0
- data/vendor/eigen/Eigen/src/PardisoSupport/PardisoSupport.h +543 -0
- data/vendor/eigen/Eigen/src/QR/ColPivHouseholderQR.h +653 -0
- data/vendor/eigen/Eigen/src/QR/ColPivHouseholderQR_LAPACKE.h +97 -0
- data/vendor/eigen/Eigen/src/QR/CompleteOrthogonalDecomposition.h +562 -0
- data/vendor/eigen/Eigen/src/QR/FullPivHouseholderQR.h +676 -0
- data/vendor/eigen/Eigen/src/QR/HouseholderQR.h +409 -0
- data/vendor/eigen/Eigen/src/QR/HouseholderQR_LAPACKE.h +68 -0
- data/vendor/eigen/Eigen/src/SPQRSupport/SuiteSparseQRSupport.h +313 -0
- data/vendor/eigen/Eigen/src/SVD/BDCSVD.h +1246 -0
- data/vendor/eigen/Eigen/src/SVD/JacobiSVD.h +804 -0
- data/vendor/eigen/Eigen/src/SVD/JacobiSVD_LAPACKE.h +91 -0
- data/vendor/eigen/Eigen/src/SVD/SVDBase.h +315 -0
- data/vendor/eigen/Eigen/src/SVD/UpperBidiagonalization.h +414 -0
- data/vendor/eigen/Eigen/src/SparseCholesky/SimplicialCholesky.h +689 -0
- data/vendor/eigen/Eigen/src/SparseCholesky/SimplicialCholesky_impl.h +199 -0
- data/vendor/eigen/Eigen/src/SparseCore/AmbiVector.h +377 -0
- data/vendor/eigen/Eigen/src/SparseCore/CompressedStorage.h +258 -0
- data/vendor/eigen/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h +352 -0
- data/vendor/eigen/Eigen/src/SparseCore/MappedSparseMatrix.h +67 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseAssign.h +216 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseBlock.h +603 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseColEtree.h +206 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseCompressedBase.h +341 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseCwiseBinaryOp.h +726 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseCwiseUnaryOp.h +148 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseDenseProduct.h +320 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseDiagonalProduct.h +138 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseDot.h +98 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseFuzzy.h +29 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseMap.h +305 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseMatrix.h +1403 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseMatrixBase.h +405 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparsePermutation.h +178 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseProduct.h +169 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseRedux.h +49 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseRef.h +397 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseSelfAdjointView.h +656 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseSolverBase.h +124 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseSparseProductWithPruning.h +198 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseTranspose.h +92 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseTriangularView.h +189 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseUtil.h +178 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseVector.h +478 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseView.h +253 -0
- data/vendor/eigen/Eigen/src/SparseCore/TriangularSolver.h +315 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU.h +773 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLUImpl.h +66 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_Memory.h +226 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_Structs.h +110 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_SupernodalMatrix.h +301 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_Utils.h +80 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_column_bmod.h +181 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_column_dfs.h +179 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_copy_to_ucol.h +107 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_gemm_kernel.h +280 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_heap_relax_snode.h +126 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_kernel_bmod.h +130 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_panel_bmod.h +223 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_panel_dfs.h +258 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_pivotL.h +137 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_pruneL.h +136 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_relax_snode.h +83 -0
- data/vendor/eigen/Eigen/src/SparseQR/SparseQR.h +745 -0
- data/vendor/eigen/Eigen/src/StlSupport/StdDeque.h +126 -0
- data/vendor/eigen/Eigen/src/StlSupport/StdList.h +106 -0
- data/vendor/eigen/Eigen/src/StlSupport/StdVector.h +131 -0
- data/vendor/eigen/Eigen/src/StlSupport/details.h +84 -0
- data/vendor/eigen/Eigen/src/SuperLUSupport/SuperLUSupport.h +1027 -0
- data/vendor/eigen/Eigen/src/UmfPackSupport/UmfPackSupport.h +506 -0
- data/vendor/eigen/Eigen/src/misc/Image.h +82 -0
- data/vendor/eigen/Eigen/src/misc/Kernel.h +79 -0
- data/vendor/eigen/Eigen/src/misc/RealSvd2x2.h +55 -0
- data/vendor/eigen/Eigen/src/misc/blas.h +440 -0
- data/vendor/eigen/Eigen/src/misc/lapack.h +152 -0
- data/vendor/eigen/Eigen/src/misc/lapacke.h +16291 -0
- data/vendor/eigen/Eigen/src/misc/lapacke_mangling.h +17 -0
- data/vendor/eigen/Eigen/src/plugins/ArrayCwiseBinaryOps.h +332 -0
- data/vendor/eigen/Eigen/src/plugins/ArrayCwiseUnaryOps.h +552 -0
- data/vendor/eigen/Eigen/src/plugins/BlockMethods.h +1058 -0
- data/vendor/eigen/Eigen/src/plugins/CommonCwiseBinaryOps.h +115 -0
- data/vendor/eigen/Eigen/src/plugins/CommonCwiseUnaryOps.h +163 -0
- data/vendor/eigen/Eigen/src/plugins/MatrixCwiseBinaryOps.h +152 -0
- data/vendor/eigen/Eigen/src/plugins/MatrixCwiseUnaryOps.h +85 -0
- data/vendor/eigen/README.md +3 -0
- data/vendor/eigen/bench/README.txt +55 -0
- data/vendor/eigen/bench/btl/COPYING +340 -0
- data/vendor/eigen/bench/btl/README +154 -0
- data/vendor/eigen/bench/tensors/README +21 -0
- data/vendor/eigen/blas/README.txt +6 -0
- data/vendor/eigen/demos/mandelbrot/README +10 -0
- data/vendor/eigen/demos/mix_eigen_and_c/README +9 -0
- data/vendor/eigen/demos/opengl/README +13 -0
- data/vendor/eigen/unsupported/Eigen/CXX11/src/Tensor/README.md +1760 -0
- data/vendor/eigen/unsupported/README.txt +50 -0
- data/vendor/tomotopy/LICENSE +21 -0
- data/vendor/tomotopy/README.kr.rst +375 -0
- data/vendor/tomotopy/README.rst +382 -0
- data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +362 -0
- data/vendor/tomotopy/src/Labeling/FoRelevance.h +88 -0
- data/vendor/tomotopy/src/Labeling/Labeler.h +50 -0
- data/vendor/tomotopy/src/TopicModel/CT.h +37 -0
- data/vendor/tomotopy/src/TopicModel/CTModel.cpp +13 -0
- data/vendor/tomotopy/src/TopicModel/CTModel.hpp +293 -0
- data/vendor/tomotopy/src/TopicModel/DMR.h +51 -0
- data/vendor/tomotopy/src/TopicModel/DMRModel.cpp +13 -0
- data/vendor/tomotopy/src/TopicModel/DMRModel.hpp +374 -0
- data/vendor/tomotopy/src/TopicModel/DT.h +65 -0
- data/vendor/tomotopy/src/TopicModel/DTM.h +22 -0
- data/vendor/tomotopy/src/TopicModel/DTModel.cpp +15 -0
- data/vendor/tomotopy/src/TopicModel/DTModel.hpp +572 -0
- data/vendor/tomotopy/src/TopicModel/GDMR.h +37 -0
- data/vendor/tomotopy/src/TopicModel/GDMRModel.cpp +14 -0
- data/vendor/tomotopy/src/TopicModel/GDMRModel.hpp +485 -0
- data/vendor/tomotopy/src/TopicModel/HDP.h +74 -0
- data/vendor/tomotopy/src/TopicModel/HDPModel.cpp +13 -0
- data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +592 -0
- data/vendor/tomotopy/src/TopicModel/HLDA.h +40 -0
- data/vendor/tomotopy/src/TopicModel/HLDAModel.cpp +13 -0
- data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +681 -0
- data/vendor/tomotopy/src/TopicModel/HPA.h +27 -0
- data/vendor/tomotopy/src/TopicModel/HPAModel.cpp +21 -0
- data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +588 -0
- data/vendor/tomotopy/src/TopicModel/LDA.h +144 -0
- data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +442 -0
- data/vendor/tomotopy/src/TopicModel/LDAModel.cpp +13 -0
- data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +1058 -0
- data/vendor/tomotopy/src/TopicModel/LLDA.h +45 -0
- data/vendor/tomotopy/src/TopicModel/LLDAModel.cpp +13 -0
- data/vendor/tomotopy/src/TopicModel/LLDAModel.hpp +203 -0
- data/vendor/tomotopy/src/TopicModel/MGLDA.h +63 -0
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.cpp +17 -0
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +558 -0
- data/vendor/tomotopy/src/TopicModel/PA.h +43 -0
- data/vendor/tomotopy/src/TopicModel/PAModel.cpp +13 -0
- data/vendor/tomotopy/src/TopicModel/PAModel.hpp +467 -0
- data/vendor/tomotopy/src/TopicModel/PLDA.h +17 -0
- data/vendor/tomotopy/src/TopicModel/PLDAModel.cpp +13 -0
- data/vendor/tomotopy/src/TopicModel/PLDAModel.hpp +214 -0
- data/vendor/tomotopy/src/TopicModel/SLDA.h +54 -0
- data/vendor/tomotopy/src/TopicModel/SLDAModel.cpp +17 -0
- data/vendor/tomotopy/src/TopicModel/SLDAModel.hpp +456 -0
- data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +692 -0
- data/vendor/tomotopy/src/Utils/AliasMethod.hpp +169 -0
- data/vendor/tomotopy/src/Utils/Dictionary.h +80 -0
- data/vendor/tomotopy/src/Utils/EigenAddonOps.hpp +181 -0
- data/vendor/tomotopy/src/Utils/LBFGS.h +202 -0
- data/vendor/tomotopy/src/Utils/LBFGS/LineSearchBacktracking.h +120 -0
- data/vendor/tomotopy/src/Utils/LBFGS/LineSearchBracketing.h +122 -0
- data/vendor/tomotopy/src/Utils/LBFGS/Param.h +213 -0
- data/vendor/tomotopy/src/Utils/LUT.hpp +82 -0
- data/vendor/tomotopy/src/Utils/MultiNormalDistribution.hpp +69 -0
- data/vendor/tomotopy/src/Utils/PolyaGamma.hpp +200 -0
- data/vendor/tomotopy/src/Utils/PolyaGammaHybrid.hpp +672 -0
- data/vendor/tomotopy/src/Utils/ThreadPool.hpp +150 -0
- data/vendor/tomotopy/src/Utils/Trie.hpp +220 -0
- data/vendor/tomotopy/src/Utils/TruncMultiNormal.hpp +94 -0
- data/vendor/tomotopy/src/Utils/Utils.hpp +337 -0
- data/vendor/tomotopy/src/Utils/avx_gamma.h +46 -0
- data/vendor/tomotopy/src/Utils/avx_mathfun.h +736 -0
- data/vendor/tomotopy/src/Utils/exception.h +28 -0
- data/vendor/tomotopy/src/Utils/math.h +281 -0
- data/vendor/tomotopy/src/Utils/rtnorm.hpp +2690 -0
- data/vendor/tomotopy/src/Utils/sample.hpp +192 -0
- data/vendor/tomotopy/src/Utils/serializer.hpp +695 -0
- data/vendor/tomotopy/src/Utils/slp.hpp +131 -0
- data/vendor/tomotopy/src/Utils/sse_gamma.h +48 -0
- data/vendor/tomotopy/src/Utils/sse_mathfun.h +710 -0
- data/vendor/tomotopy/src/Utils/text.hpp +49 -0
- data/vendor/tomotopy/src/Utils/tvector.hpp +543 -0
- metadata +531 -0
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
#include "LDA.h"
|
|
3
|
+
|
|
4
|
+
namespace tomoto
|
|
5
|
+
{
|
|
6
|
+
/// Document type for Labeled LDA: extends the plain LDA document with a
/// per-topic label mask restricting which topics this document may be
/// assigned during sampling.
template<TermWeight _tw>
struct DocumentLLDA : public DocumentLDA<_tw>
{
	using BaseDocument = DocumentLDA<_tw>;
	using DocumentLDA<_tw>::DocumentLDA;
	using WeightType = typename DocumentLDA<_tw>::WeightType;
	// One int8 entry per topic: nonzero means the topic is allowed for this
	// document, zero means it is masked out (see LLDAModel::_updateDoc /
	// prepareDoc for how it is filled).
	Eigen::Matrix<int8_t, -1, 1> labelMask;

	// Serialize labelMask after the base document's fields; the tagged
	// variant adds the 0x00010001 format tag used by newer model files.
	DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 0, labelMask);
	DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 1, 0x00010001, labelMask);
};
|
|
17
|
+
|
|
18
|
+
/// Abstract interface for Labeled LDA models. Extends ILDAModel with
/// overloads of addDoc/makeDoc that accept a list of label strings per
/// document, plus access to the label dictionary.
class ILLDAModel : public ILDAModel
{
public:
	using DefaultDocType = DocumentLLDA<TermWeight::one>;
	/// Factory: instantiates the concrete LLDAModel specialization for the
	/// requested term-weighting scheme.
	/// @param _weight   term weighting (one / idf / pmi)
	/// @param _K        initial number of topics (may grow to cover labels)
	/// @param alpha     symmetric document-topic prior
	/// @param eta       symmetric topic-word prior
	/// @param seed      RNG seed (defaults to a random device draw)
	/// @param scalarRng use the scalar RNG instead of the vectorized one
	static ILLDAModel* create(TermWeight _weight, size_t _K = 1,
		Float alpha = 0.1, Float eta = 0.01, size_t seed = std::random_device{}(),
		bool scalarRng = false);

	/// Add a training document given as raw word strings plus its labels.
	virtual size_t addDoc(const std::vector<std::string>& words, const std::vector<std::string>& label) = 0;
	/// Build (without adding) a document from word strings plus labels,
	/// e.g. for held-out inference.
	virtual std::unique_ptr<DocumentBase> makeDoc(const std::vector<std::string>& words, const std::vector<std::string>& label) const = 0;

	/// Add a document from a raw string, tokenized by `tokenizer`.
	virtual size_t addDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
		const std::vector<std::string>& label) = 0;
	virtual std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
		const std::vector<std::string>& label) const = 0;

	/// Add a document from pre-tokenized vocabulary ids with the original
	/// raw string and per-token positions/lengths (for later highlighting).
	virtual size_t addDoc(const std::string& rawStr, const std::vector<Vid>& words,
		const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
		const std::vector<std::string>& label) = 0;
	virtual std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const std::vector<Vid>& words,
		const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
		const std::vector<std::string>& label) const = 0;

	/// Dictionary mapping label strings to topic ids.
	virtual const Dictionary& getTopicLabelDict() const = 0;

	/// Number of topics allotted per label.
	virtual size_t getNumTopicsPerLabel() const = 0;
};
|
|
45
|
+
}
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
#include "LLDAModel.hpp"
|
|
2
|
+
|
|
3
|
+
namespace tomoto
|
|
4
|
+
{
|
|
5
|
+
/*template class LLDAModel<TermWeight::one>;
|
|
6
|
+
template class LLDAModel<TermWeight::idf>;
|
|
7
|
+
template class LLDAModel<TermWeight::pmi>;*/
|
|
8
|
+
|
|
9
|
+
/// Factory for LLDA models: dispatches on the term-weighting scheme and
/// RNG kind, expanding (via the project's TMT_SWITCH_TW macro) into a
/// `new LLDAModel<_tw, ...>(_K, _alpha, _eta, seed)` for the matching
/// specialization.
ILLDAModel* ILLDAModel::create(TermWeight _weight, size_t _K, Float _alpha, Float _eta, size_t seed, bool scalarRng)
{
	TMT_SWITCH_TW(_weight, scalarRng, LLDAModel, _K, _alpha, _eta, seed);
}
|
|
13
|
+
}
|
|
@@ -0,0 +1,203 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
#include "LDAModel.hpp"
|
|
3
|
+
#include "LLDA.h"
|
|
4
|
+
|
|
5
|
+
/*
|
|
6
|
+
Implementation of Labeled LDA using Gibbs sampling by bab2min
|
|
7
|
+
|
|
8
|
+
* Ramage, D., Hall, D., Nallapati, R., & Manning, C. D. (2009, August). Labeled LDA: A supervised topic model for credit attribution in multi-labeled corpora. In Proceedings of the 2009 Conference on Empirical Methods in Natural Language Processing: Volume 1-Volume 1 (pp. 248-256). Association for Computational Linguistics.
|
|
9
|
+
*/
|
|
10
|
+
|
|
11
|
+
namespace tomoto
|
|
12
|
+
{
|
|
13
|
+
template<TermWeight _tw, typename _RandGen,
|
|
14
|
+
typename _Interface = ILLDAModel,
|
|
15
|
+
typename _Derived = void,
|
|
16
|
+
typename _DocType = DocumentLLDA<_tw>,
|
|
17
|
+
typename _ModelState = ModelStateLDA<_tw>>
|
|
18
|
+
class LLDAModel : public LDAModel<_tw, _RandGen, flags::generator_by_doc | flags::partitioned_multisampling, _Interface,
|
|
19
|
+
typename std::conditional<std::is_same<_Derived, void>::value, LLDAModel<_tw, _RandGen>, _Derived>::type,
|
|
20
|
+
_DocType, _ModelState>
|
|
21
|
+
{
|
|
22
|
+
protected:
|
|
23
|
+
using DerivedClass = typename std::conditional<std::is_same<_Derived, void>::value, LLDAModel<_tw, _RandGen>, _Derived>::type;
|
|
24
|
+
using BaseClass = LDAModel<_tw, _RandGen, flags::generator_by_doc | flags::partitioned_multisampling, _Interface, DerivedClass, _DocType, _ModelState>;
|
|
25
|
+
friend BaseClass;
|
|
26
|
+
friend typename BaseClass::BaseClass;
|
|
27
|
+
using WeightType = typename BaseClass::WeightType;
|
|
28
|
+
|
|
29
|
+
static constexpr char TMID[] = "LLDA";
|
|
30
|
+
|
|
31
|
+
Dictionary topicLabelDict;
|
|
32
|
+
|
|
33
|
+
	/// Computes the (unnormalized, cumulative) topic sampling distribution
	/// for word `vid` of `doc` under collapsed Gibbs sampling, with topics
	/// outside the document's label mask zeroed out.
	/// Returns a pointer to the cumulative weights (length K) held in `ld`;
	/// the caller samples a topic from this prefix-summed array.
	template<bool _asymEta>
	Float* getZLikelihoods(_ModelState& ld, const _DocType& doc, size_t docId, size_t vid) const
	{
		const size_t V = this->realV;
		assert(vid < V);
		auto& zLikelihood = ld.zLikelihood;
		// Standard LDA conditional: (n_dk + alpha_k) * (n_kw + eta) / (n_k + V*eta)
		zLikelihood = (doc.numByTopic.array().template cast<Float>() + this->alphas.array())
			* (ld.numByTopicWord.col(vid).array().template cast<Float>() + this->eta)
			/ (ld.numByTopic.array().template cast<Float>() + V * this->eta);
		// Labeled-LDA restriction: topics not allowed by the document's
		// labels get probability zero.
		zLikelihood.array() *= doc.labelMask.array().template cast<Float>();
		// Convert to a cumulative sum in place so the sampler can do a
		// single uniform draw + binary/linear search.
		sample::prefixSum(zLikelihood.data(), this->K);
		return &zLikelihood[0];
	}
|
|
46
|
+
|
|
47
|
+
	/// Prepares a document for sampling; on top of the base preparation,
	/// ensures labelMask has exactly K entries.
	void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
	{
		BaseClass::prepareDoc(doc, docId, wordSize);
		if (doc.labelMask.size() == 0)
		{
			// No labels were given: allow every topic.
			doc.labelMask.resize(this->K);
			doc.labelMask.setOnes();
		}
		else if (doc.labelMask.size() < this->K)
		{
			// K has grown since this mask was built (e.g. new labels were
			// added): keep existing entries, forbid the newly added label
			// topics, and allow the unlabeled "free" topics beyond the
			// label dictionary.
			// NOTE(review): assumes oldSize <= topicLabelDict.size();
			// otherwise the first segment length underflows — TODO confirm
			// callers guarantee this.
			size_t oldSize = doc.labelMask.size();
			doc.labelMask.conservativeResize(this->K);
			doc.labelMask.segment(oldSize, topicLabelDict.size() - oldSize).setZero();
			doc.labelMask.segment(topicLabelDict.size(), this->K - topicLabelDict.size()).setOnes();
		}
	}
|
|
63
|
+
|
|
64
|
+
	/// Initializes global model state. Grows K so every label in the label
	/// dictionary has at least one topic, resizes the per-topic alpha vector
	/// to the (possibly enlarged) K, then delegates to the LDA base.
	void initGlobalState(bool initDocs)
	{
		this->K = std::max(topicLabelDict.size(), (size_t)this->K);
		this->alphas.resize(this->K);
		this->alphas.array() = this->alpha;
		BaseClass::initGlobalState(initDocs);
	}
|
|
71
|
+
|
|
72
|
+
	/// Per-document state used during initial topic assignment: a discrete
	/// distribution over topics built from the document's label mask.
	struct Generator
	{
		std::discrete_distribution<> theta;
	};
|
|
76
|
+
|
|
77
|
+
	/// Builds the per-document init-time generator: a discrete distribution
	/// whose weights are the 0/1 entries of labelMask, i.e. uniform over the
	/// topics this document's labels allow.
	Generator makeGeneratorForInit(const _DocType* doc) const
	{
		std::discrete_distribution<> theta{ doc->labelMask.data(), doc->labelMask.data() + this->K };
		return Generator{ theta };
	}
|
|
82
|
+
|
|
83
|
+
	/// Assigns an initial topic to word i of `doc` and updates the count
	/// tables. With a per-topic-word eta matrix present, the draw combines
	/// that prior with the label-restricted theta; otherwise it samples
	/// directly from the label-restricted distribution.
	template<bool _Infer>
	void updateStateWithDoc(Generator& g, _ModelState& ld, _RandGen& rgs, _DocType& doc, size_t i) const
	{
		auto& z = doc.Zs[i];
		auto w = doc.words[i];
		if (this->etaByTopicWord.size())
		{
			// Weight the word's per-topic eta column by the label-masked
			// theta, then sample from the resulting discrete distribution.
			// NOTE(review): probabilities() is called once per topic inside
			// the loop and typically returns by value — likely avoidable
			// allocation on a hot path; confirm before changing upstream.
			Eigen::Array<Float, -1, 1> col = this->etaByTopicWord.col(w);
			for (size_t k = 0; k < col.size(); ++k) col[k] *= g.theta.probabilities()[k];
			z = sample::sampleFromDiscrete(col.data(), col.data() + col.size(), rgs);
		}
		else
		{
			// Uniform draw among the topics permitted by the labels.
			z = g.theta(rgs);
		}
		// Record the assignment in document- and model-level count tables.
		this->template addWordTo<1>(ld, doc, i, w, z);
	}
|
|
100
|
+
|
|
101
|
+
public:
|
|
102
|
+
DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, topicLabelDict);
|
|
103
|
+
DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, topicLabelDict);
|
|
104
|
+
|
|
105
|
+
LLDAModel(size_t _K = 1, Float _alpha = 1.0, Float _eta = 0.01, size_t _rg = std::random_device{}())
|
|
106
|
+
: BaseClass(_K, _alpha, _eta, _rg)
|
|
107
|
+
{
|
|
108
|
+
}
|
|
109
|
+
|
|
110
|
+
template<bool _const = false>
|
|
111
|
+
_DocType& _updateDoc(_DocType& doc, const std::vector<std::string>& labels)
|
|
112
|
+
{
|
|
113
|
+
if (_const)
|
|
114
|
+
{
|
|
115
|
+
doc.labelMask.resize(this->K);
|
|
116
|
+
doc.labelMask.setOnes();
|
|
117
|
+
|
|
118
|
+
std::vector<Vid> topicLabelIds;
|
|
119
|
+
for (auto& label : labels)
|
|
120
|
+
{
|
|
121
|
+
auto tid = topicLabelDict.toWid(label);
|
|
122
|
+
if (tid == (Vid)-1) continue;
|
|
123
|
+
topicLabelIds.emplace_back(tid);
|
|
124
|
+
}
|
|
125
|
+
|
|
126
|
+
if (!topicLabelIds.empty())
|
|
127
|
+
{
|
|
128
|
+
doc.labelMask.head(topicLabelDict.size()).setZero();
|
|
129
|
+
for (auto tid : topicLabelIds) doc.labelMask[tid] = 1;
|
|
130
|
+
}
|
|
131
|
+
}
|
|
132
|
+
else
|
|
133
|
+
{
|
|
134
|
+
if (!labels.empty())
|
|
135
|
+
{
|
|
136
|
+
std::vector<Vid> topicLabelIds;
|
|
137
|
+
for (auto& label : labels) topicLabelIds.emplace_back(topicLabelDict.add(label));
|
|
138
|
+
auto maxVal = *std::max_element(topicLabelIds.begin(), topicLabelIds.end());
|
|
139
|
+
doc.labelMask.resize(maxVal + 1);
|
|
140
|
+
doc.labelMask.setZero();
|
|
141
|
+
for (auto i : topicLabelIds) doc.labelMask[i] = 1;
|
|
142
|
+
}
|
|
143
|
+
}
|
|
144
|
+
return doc;
|
|
145
|
+
}
|
|
146
|
+
|
|
147
|
+
size_t addDoc(const std::vector<std::string>& words, const std::vector<std::string>& labels) override
|
|
148
|
+
{
|
|
149
|
+
auto doc = this->_makeDoc(words);
|
|
150
|
+
return this->_addDoc(_updateDoc(doc, labels));
|
|
151
|
+
}
|
|
152
|
+
|
|
153
|
+
std::unique_ptr<DocumentBase> makeDoc(const std::vector<std::string>& words, const std::vector<std::string>& labels) const override
|
|
154
|
+
{
|
|
155
|
+
auto doc = as_mutable(this)->template _makeDoc<true>(words);
|
|
156
|
+
return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, labels));
|
|
157
|
+
}
|
|
158
|
+
|
|
159
|
+
size_t addDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
|
|
160
|
+
const std::vector<std::string>& labels) override
|
|
161
|
+
{
|
|
162
|
+
auto doc = this->template _makeRawDoc<false>(rawStr, tokenizer);
|
|
163
|
+
return this->_addDoc(_updateDoc(doc, labels));
|
|
164
|
+
}
|
|
165
|
+
|
|
166
|
+
std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
|
|
167
|
+
const std::vector<std::string>& labels) const override
|
|
168
|
+
{
|
|
169
|
+
auto doc = as_mutable(this)->template _makeRawDoc<true>(rawStr, tokenizer);
|
|
170
|
+
return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, labels));
|
|
171
|
+
}
|
|
172
|
+
|
|
173
|
+
size_t addDoc(const std::string& rawStr, const std::vector<Vid>& words,
|
|
174
|
+
const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
|
|
175
|
+
const std::vector<std::string>& labels) override
|
|
176
|
+
{
|
|
177
|
+
auto doc = this->_makeRawDoc(rawStr, words, pos, len);
|
|
178
|
+
return this->_addDoc(_updateDoc(doc, labels));
|
|
179
|
+
}
|
|
180
|
+
|
|
181
|
+
std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const std::vector<Vid>& words,
|
|
182
|
+
const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
|
|
183
|
+
const std::vector<std::string>& labels) const override
|
|
184
|
+
{
|
|
185
|
+
auto doc = this->_makeRawDoc(rawStr, words, pos, len);
|
|
186
|
+
return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, labels));
|
|
187
|
+
}
|
|
188
|
+
|
|
189
|
+
std::vector<Float> getTopicsByDoc(const _DocType& doc) const
|
|
190
|
+
{
|
|
191
|
+
std::vector<Float> ret(this->K);
|
|
192
|
+
auto maskedAlphas = this->alphas.array() * doc.labelMask.template cast<Float>().array();
|
|
193
|
+
Eigen::Map<Eigen::Matrix<Float, -1, 1>> { ret.data(), this->K }.array() =
|
|
194
|
+
(doc.numByTopic.array().template cast<Float>() + maskedAlphas)
|
|
195
|
+
/ (doc.getSumWordWeight() + maskedAlphas.sum());
|
|
196
|
+
return ret;
|
|
197
|
+
}
|
|
198
|
+
|
|
199
|
+
const Dictionary& getTopicLabelDict() const override { return topicLabelDict; }
|
|
200
|
+
|
|
201
|
+
size_t getNumTopicsPerLabel() const override { return 1; }
|
|
202
|
+
};
|
|
203
|
+
}
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
#include "LDA.h"
|
|
3
|
+
|
|
4
|
+
namespace tomoto
|
|
5
|
+
{
|
|
6
|
+
template<TermWeight _tw>
|
|
7
|
+
struct DocumentMGLDA : public DocumentLDA<_tw>
|
|
8
|
+
{
|
|
9
|
+
using BaseDocument = DocumentLDA<_tw>;
|
|
10
|
+
using DocumentLDA<_tw>::DocumentLDA;
|
|
11
|
+
using WeightType = typename DocumentLDA<_tw>::WeightType;
|
|
12
|
+
|
|
13
|
+
std::vector<uint16_t> sents; // sentence id of each word (const)
|
|
14
|
+
std::vector<WeightType> numBySent; // number of words in the sentence (const)
|
|
15
|
+
|
|
16
|
+
//std::vector<Tid> Zs; // gl./loc. and topic assignment
|
|
17
|
+
std::vector<uint8_t> Vs; // window assignment
|
|
18
|
+
WeightType numGl = 0; // number of words assigned as gl.
|
|
19
|
+
//std::vector<uint32_t> numByTopic; // len = K + KL
|
|
20
|
+
Eigen::Matrix<WeightType, -1, -1> numBySentWin; // len = S * T
|
|
21
|
+
Eigen::Matrix<WeightType, -1, 1> numByWinL; // number of words assigned as loc. in the window (len = S + T - 1)
|
|
22
|
+
Eigen::Matrix<WeightType, -1, 1> numByWin; // number of words in the window (len = S + T - 1)
|
|
23
|
+
Eigen::Matrix<WeightType, -1, -1> numByWinTopicL; // number of words in the loc. topic in the window (len = KL * (S + T - 1))
|
|
24
|
+
|
|
25
|
+
DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 0, sents, Vs, numGl, numBySentWin, numByWinL, numByWin, numByWinTopicL);
|
|
26
|
+
DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 1, 0x00010001, sents, Vs, numGl, numBySentWin, numByWinL, numByWin, numByWinTopicL);
|
|
27
|
+
|
|
28
|
+
template<typename _TopicModel> void update(WeightType* ptr, const _TopicModel& mdl);
|
|
29
|
+
};
|
|
30
|
+
|
|
31
|
+
class IMGLDAModel : public ILDAModel
|
|
32
|
+
{
|
|
33
|
+
public:
|
|
34
|
+
using DefaultDocType = DocumentMGLDA<TermWeight::one>;
|
|
35
|
+
static IMGLDAModel* create(TermWeight _weight, size_t _KG = 1, size_t _KL = 1, size_t _T = 3,
|
|
36
|
+
Float _alphaG = 0.1, Float _alphaL = 0.1, Float _alphaMG = 0.1, Float _alphaML = 0.1,
|
|
37
|
+
Float _etaG = 0.01, Float _etaL = 0.01, Float _gamma = 0.1, size_t seed = std::random_device{}(),
|
|
38
|
+
bool scalarRng = false);
|
|
39
|
+
|
|
40
|
+
virtual size_t addDoc(const std::vector<std::string>& words, const std::string& delimiter) = 0;
|
|
41
|
+
virtual std::unique_ptr<DocumentBase> makeDoc(const std::vector<std::string>& words, const std::string& delimiter) const = 0;
|
|
42
|
+
|
|
43
|
+
virtual size_t addDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
|
|
44
|
+
const std::string& delimiter) = 0;
|
|
45
|
+
virtual std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
|
|
46
|
+
const std::string& delimiter) const = 0;
|
|
47
|
+
|
|
48
|
+
virtual size_t addDoc(const std::string& rawStr, const std::vector<Vid>& words,
|
|
49
|
+
const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
|
|
50
|
+
const std::string& delimiter) = 0;
|
|
51
|
+
virtual std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const std::vector<Vid>& words,
|
|
52
|
+
const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
|
|
53
|
+
const std::string& delimiter) const = 0;
|
|
54
|
+
|
|
55
|
+
virtual size_t getKL() const = 0;
|
|
56
|
+
virtual size_t getT() const = 0;
|
|
57
|
+
virtual Float getAlphaL() const = 0;
|
|
58
|
+
virtual Float getEtaL() const = 0;
|
|
59
|
+
virtual Float getGamma() const = 0;
|
|
60
|
+
virtual Float getAlphaM() const = 0;
|
|
61
|
+
virtual Float getAlphaML() const = 0;
|
|
62
|
+
};
|
|
63
|
+
}
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
#include "MGLDAModel.hpp"
|
|
2
|
+
|
|
3
|
+
namespace tomoto
|
|
4
|
+
{
|
|
5
|
+
/*template class MGLDAModel<TermWeight::one>;
|
|
6
|
+
template class MGLDAModel<TermWeight::idf>;
|
|
7
|
+
template class MGLDAModel<TermWeight::pmi>;*/
|
|
8
|
+
|
|
9
|
+
IMGLDAModel* IMGLDAModel::create(TermWeight _weight, size_t _KG, size_t _KL, size_t _T,
|
|
10
|
+
Float _alphaG, Float _alphaL, Float _alphaMG, Float _alphaML,
|
|
11
|
+
Float _etaG, Float _etaL, Float _gamma, size_t seed, bool scalarRng)
|
|
12
|
+
{
|
|
13
|
+
TMT_SWITCH_TW(_weight, scalarRng, MGLDAModel, _KG, _KL, _T,
|
|
14
|
+
_alphaG, _alphaL, _alphaMG, _alphaML,
|
|
15
|
+
_etaG, _etaL, _gamma, seed);
|
|
16
|
+
}
|
|
17
|
+
}
|
|
@@ -0,0 +1,558 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
#include "LDAModel.hpp"
|
|
3
|
+
#include "MGLDA.h"
|
|
4
|
+
/*
|
|
5
|
+
Implementation of MG-LDA using Gibbs sampling by bab2min
|
|
6
|
+
Improved version of java implementation(https://github.com/yinfeiy/MG-LDA)
|
|
7
|
+
|
|
8
|
+
* Titov, I., & McDonald, R. (2008, April). Modeling online reviews with multi-grain topic models. In Proceedings of the 17th international conference on World Wide Web (pp. 111-120). ACM.
|
|
9
|
+
|
|
10
|
+
*/
|
|
11
|
+
|
|
12
|
+
namespace tomoto
|
|
13
|
+
{
|
|
14
|
+
template<TermWeight _tw, typename _RandGen,
|
|
15
|
+
typename _Interface = IMGLDAModel,
|
|
16
|
+
typename _Derived = void,
|
|
17
|
+
typename _DocType = DocumentMGLDA<_tw>,
|
|
18
|
+
typename _ModelState = ModelStateLDA<_tw>>
|
|
19
|
+
class MGLDAModel : public LDAModel<_tw, _RandGen, flags::partitioned_multisampling, _Interface,
|
|
20
|
+
typename std::conditional<std::is_same<_Derived, void>::value, MGLDAModel<_tw, _RandGen>, _Derived>::type,
|
|
21
|
+
_DocType, _ModelState>
|
|
22
|
+
{
|
|
23
|
+
protected:
|
|
24
|
+
using DerivedClass = typename std::conditional<std::is_same<_Derived, void>::value, MGLDAModel<_tw, _RandGen>, _Derived>::type;
|
|
25
|
+
using BaseClass = LDAModel<_tw, _RandGen, flags::partitioned_multisampling, _Interface, DerivedClass, _DocType, _ModelState>;
|
|
26
|
+
friend BaseClass;
|
|
27
|
+
friend typename BaseClass::BaseClass;
|
|
28
|
+
using WeightType = typename BaseClass::WeightType;
|
|
29
|
+
|
|
30
|
+
Float alphaL;
|
|
31
|
+
Float alphaM, alphaML;
|
|
32
|
+
Float etaL;
|
|
33
|
+
Float gamma;
|
|
34
|
+
Tid KL;
|
|
35
|
+
uint32_t T; // window size
|
|
36
|
+
|
|
37
|
+
// window and gl./loc. and topic assignment likelihoods for new word. ret T*(K+KL) FLOATs
|
|
38
|
+
Float* getVZLikelihoods(_ModelState& ld, const _DocType& doc, Vid vid, uint16_t s) const
|
|
39
|
+
{
|
|
40
|
+
const auto V = this->realV;
|
|
41
|
+
const auto K = this->K;
|
|
42
|
+
const auto alpha = this->alpha;
|
|
43
|
+
const auto eta = this->eta;
|
|
44
|
+
assert(vid < V);
|
|
45
|
+
auto& zLikelihood = ld.zLikelihood;
|
|
46
|
+
for (size_t v = 0; v < T; ++v)
|
|
47
|
+
{
|
|
48
|
+
Float pLoc = (doc.numByWinL[s + v] + alphaML) / (doc.numByWin[s + v] + alphaM + alphaML);
|
|
49
|
+
Float pW = doc.numBySentWin(s, v) + gamma;
|
|
50
|
+
if (K)
|
|
51
|
+
{
|
|
52
|
+
zLikelihood.segment(v * (K + KL), K) = (1 - pLoc) * pW
|
|
53
|
+
* (doc.numByTopic.segment(0, K).array().template cast<Float>() + alpha) / (doc.numGl + K * alpha)
|
|
54
|
+
* (ld.numByTopicWord.block(0, vid, K, 1).array().template cast<Float>() + eta) / (ld.numByTopic.segment(0, K).array().template cast<Float>() + V * eta);
|
|
55
|
+
}
|
|
56
|
+
zLikelihood.segment(v * (K + KL) + K, KL) = pLoc * pW
|
|
57
|
+
* (doc.numByWinTopicL.col(s + v).array().template cast<Float>()) / (doc.numByWinL[s + v] + KL * alphaL)
|
|
58
|
+
* (ld.numByTopicWord.block(K, vid, KL, 1).array().template cast<Float>() + etaL) / (ld.numByTopic.segment(K, KL).array().template cast<Float>() + V * etaL);
|
|
59
|
+
}
|
|
60
|
+
|
|
61
|
+
sample::prefixSum(zLikelihood.data(), T * (K + KL));
|
|
62
|
+
return &zLikelihood[0];
|
|
63
|
+
}
|
|
64
|
+
|
|
65
|
+
template<int _inc>
|
|
66
|
+
inline void addWordTo(_ModelState& ld, _DocType& doc, uint32_t pid, Vid vid, Tid tid, uint16_t s, uint8_t w, uint8_t r) const
|
|
67
|
+
{
|
|
68
|
+
const auto K = this->K;
|
|
69
|
+
|
|
70
|
+
assert(r != 0 || tid < K);
|
|
71
|
+
assert(r == 0 || tid < KL);
|
|
72
|
+
assert(w < T);
|
|
73
|
+
assert(r < 2);
|
|
74
|
+
assert(vid < this->realV);
|
|
75
|
+
assert(s < doc.numBySent.size());
|
|
76
|
+
|
|
77
|
+
constexpr bool _dec = _inc < 0 && _tw != TermWeight::one;
|
|
78
|
+
typename std::conditional<_tw != TermWeight::one, float, int32_t>::type weight
|
|
79
|
+
= _tw != TermWeight::one ? doc.wordWeights[pid] : 1;
|
|
80
|
+
|
|
81
|
+
updateCnt<_dec>(doc.numByWin[s + w], _inc * weight);
|
|
82
|
+
updateCnt<_dec>(doc.numBySentWin(s, w), _inc * weight);
|
|
83
|
+
if (r == 0)
|
|
84
|
+
{
|
|
85
|
+
updateCnt<_dec>(doc.numByTopic[tid], _inc * weight);
|
|
86
|
+
updateCnt<_dec>(doc.numGl, _inc * weight);
|
|
87
|
+
updateCnt<_dec>(ld.numByTopic[tid], _inc * weight);
|
|
88
|
+
updateCnt<_dec>(ld.numByTopicWord(tid, vid), _inc * weight);
|
|
89
|
+
}
|
|
90
|
+
else
|
|
91
|
+
{
|
|
92
|
+
updateCnt<_dec>(doc.numByTopic[tid + K], _inc * weight);
|
|
93
|
+
updateCnt<_dec>(doc.numByWinL[s + w], _inc * weight);
|
|
94
|
+
updateCnt<_dec>(doc.numByWinTopicL(tid, s + w), _inc * weight);
|
|
95
|
+
updateCnt<_dec>(ld.numByTopic[tid + K], _inc * weight);
|
|
96
|
+
updateCnt<_dec>(ld.numByTopicWord(tid + K, vid), _inc * weight);
|
|
97
|
+
}
|
|
98
|
+
}
|
|
99
|
+
|
|
100
|
+
template<ParallelScheme _ps, bool _infer, typename _ExtraDocData>
|
|
101
|
+
void sampleDocument(_DocType& doc, const _ExtraDocData& edd, size_t docId, _ModelState& ld, _RandGen& rgs, size_t iterationCnt, size_t partitionId = 0) const
|
|
102
|
+
{
|
|
103
|
+
size_t b = 0, e = doc.words.size();
|
|
104
|
+
if (_ps == ParallelScheme::partition)
|
|
105
|
+
{
|
|
106
|
+
b = edd.chunkOffsetByDoc(partitionId, docId);
|
|
107
|
+
e = edd.chunkOffsetByDoc(partitionId + 1, docId);
|
|
108
|
+
}
|
|
109
|
+
|
|
110
|
+
size_t vOffset = (_ps == ParallelScheme::partition && partitionId) ? edd.vChunkOffset[partitionId - 1] : 0;
|
|
111
|
+
|
|
112
|
+
const auto K = this->K;
|
|
113
|
+
for (size_t w = b; w < e; ++w)
|
|
114
|
+
{
|
|
115
|
+
if (doc.words[w] >= this->realV) continue;
|
|
116
|
+
addWordTo<-1>(ld, doc, w, doc.words[w] - vOffset, doc.Zs[w] - (doc.Zs[w] < K ? 0 : K), doc.sents[w], doc.Vs[w], doc.Zs[w] < K ? 0 : 1);
|
|
117
|
+
auto dist = getVZLikelihoods(ld, doc, doc.words[w] - vOffset, doc.sents[w]);
|
|
118
|
+
auto vz = sample::sampleFromDiscreteAcc(dist, dist + T * (K + KL), rgs);
|
|
119
|
+
doc.Vs[w] = vz / (K + KL);
|
|
120
|
+
doc.Zs[w] = vz % (K + KL);
|
|
121
|
+
addWordTo<1>(ld, doc, w, doc.words[w] - vOffset, doc.Zs[w] - (doc.Zs[w] < K ? 0 : K), doc.sents[w], doc.Vs[w], doc.Zs[w] < K ? 0 : 1);
|
|
122
|
+
}
|
|
123
|
+
}
|
|
124
|
+
|
|
125
|
+
template<typename _DocIter>
|
|
126
|
+
double getLLDocs(_DocIter _first, _DocIter _last) const
|
|
127
|
+
{
|
|
128
|
+
const auto K = this->K;
|
|
129
|
+
const auto alpha = this->alpha;
|
|
130
|
+
|
|
131
|
+
size_t totSents = 0, totWins = 0;
|
|
132
|
+
double ll = 0;
|
|
133
|
+
if (K) ll += (math::lgammaT(K*alpha) - math::lgammaT(alpha)*K) * std::distance(_first, _last);
|
|
134
|
+
for (; _first != _last; ++_first)
|
|
135
|
+
{
|
|
136
|
+
auto& doc = *_first;
|
|
137
|
+
const size_t S = doc.numBySent.size();
|
|
138
|
+
if (K)
|
|
139
|
+
{
|
|
140
|
+
ll -= math::lgammaT(doc.numGl + K * alpha);
|
|
141
|
+
for (Tid k = 0; k < K; ++k)
|
|
142
|
+
{
|
|
143
|
+
ll += math::lgammaT(doc.numByTopic[k] + alpha);
|
|
144
|
+
}
|
|
145
|
+
}
|
|
146
|
+
|
|
147
|
+
for (size_t v = 0; v < S + T - 1; ++v)
|
|
148
|
+
{
|
|
149
|
+
ll -= math::lgammaT(doc.numByWinL[v] + KL * alphaL);
|
|
150
|
+
for (Tid k = 0; k < KL; ++k)
|
|
151
|
+
{
|
|
152
|
+
ll += math::lgammaT(doc.numByWinTopicL(k, v) + alphaL);
|
|
153
|
+
}
|
|
154
|
+
if (K)
|
|
155
|
+
{
|
|
156
|
+
ll += math::lgammaT(std::max((Float)doc.numByWin[v] - doc.numByWinL[v], (Float)0) + alphaM);
|
|
157
|
+
ll += math::lgammaT(doc.numByWinL[v] + alphaML);
|
|
158
|
+
ll -= math::lgammaT(doc.numByWin[v] + alphaM + alphaML);
|
|
159
|
+
}
|
|
160
|
+
}
|
|
161
|
+
|
|
162
|
+
totWins += S + T - 1;
|
|
163
|
+
totSents += S;
|
|
164
|
+
for (size_t s = 0; s < S; ++s)
|
|
165
|
+
{
|
|
166
|
+
ll -= math::lgammaT(doc.numBySent[s] + T * gamma);
|
|
167
|
+
for (size_t v = 0; v < T; ++v)
|
|
168
|
+
{
|
|
169
|
+
ll += math::lgammaT(doc.numBySentWin(s, v) + gamma);
|
|
170
|
+
}
|
|
171
|
+
}
|
|
172
|
+
}
|
|
173
|
+
ll += (math::lgammaT(KL*alphaL) - math::lgammaT(alphaL)*KL) * totWins;
|
|
174
|
+
if (K) ll += (math::lgammaT(alphaM + alphaML) - math::lgammaT(alphaM) - math::lgammaT(alphaML)) * totWins;
|
|
175
|
+
ll += (math::lgammaT(T * gamma) - math::lgammaT(gamma) * T) * totSents;
|
|
176
|
+
|
|
177
|
+
return ll;
|
|
178
|
+
}
|
|
179
|
+
|
|
180
|
+
double getLLRest(const _ModelState& ld) const
|
|
181
|
+
{
|
|
182
|
+
const auto V = this->realV;
|
|
183
|
+
const auto K = this->K;
|
|
184
|
+
const auto eta = this->eta;
|
|
185
|
+
|
|
186
|
+
double ll = 0;
|
|
187
|
+
ll += (math::lgammaT(V*eta) - math::lgammaT(eta)*V) * K;
|
|
188
|
+
for (Tid k = 0; k < K; ++k)
|
|
189
|
+
{
|
|
190
|
+
ll -= math::lgammaT(ld.numByTopic[k] + V * eta);
|
|
191
|
+
for (Vid w = 0; w < V; ++w)
|
|
192
|
+
{
|
|
193
|
+
ll += math::lgammaT(ld.numByTopicWord(k, w) + eta);
|
|
194
|
+
}
|
|
195
|
+
}
|
|
196
|
+
ll += (math::lgammaT(V*etaL) - math::lgammaT(etaL)*V) * KL;
|
|
197
|
+
for (Tid k = 0; k < KL; ++k)
|
|
198
|
+
{
|
|
199
|
+
ll -= math::lgammaT(ld.numByTopic[k + K] + V * etaL);
|
|
200
|
+
for (Vid w = 0; w < V; ++w)
|
|
201
|
+
{
|
|
202
|
+
ll += math::lgammaT(ld.numByTopicWord(k + K, w) + etaL);
|
|
203
|
+
}
|
|
204
|
+
}
|
|
205
|
+
return ll;
|
|
206
|
+
}
|
|
207
|
+
|
|
208
|
+
double getLL() const
|
|
209
|
+
{
|
|
210
|
+
double ll = 0;
|
|
211
|
+
const auto V = this->realV;
|
|
212
|
+
const auto K = this->K;
|
|
213
|
+
const auto alpha = this->alpha;
|
|
214
|
+
const auto eta = this->eta;
|
|
215
|
+
size_t totSents = 0, totWins = 0;
|
|
216
|
+
if(K) ll += (math::lgammaT(K*alpha) - math::lgammaT(alpha)*K) * this->docs.size();
|
|
217
|
+
for (size_t i = 0; i < this->docs.size(); ++i)
|
|
218
|
+
{
|
|
219
|
+
auto& doc = this->docs[i];
|
|
220
|
+
const size_t S = doc.numBySent.size();
|
|
221
|
+
if (K)
|
|
222
|
+
{
|
|
223
|
+
ll -= math::lgammaT(doc.numGl + K * alpha);
|
|
224
|
+
for (Tid k = 0; k < K; ++k)
|
|
225
|
+
{
|
|
226
|
+
ll += math::lgammaT(doc.numByTopic[k] + alpha);
|
|
227
|
+
}
|
|
228
|
+
}
|
|
229
|
+
|
|
230
|
+
for (size_t v = 0; v < S + T - 1; ++v)
|
|
231
|
+
{
|
|
232
|
+
ll -= math::lgammaT(doc.numByWinL[v] + KL * alphaL);
|
|
233
|
+
for (Tid k = 0; k < KL; ++k)
|
|
234
|
+
{
|
|
235
|
+
ll += math::lgammaT(doc.numByWinTopicL(k, v) + alphaL);
|
|
236
|
+
}
|
|
237
|
+
if (K)
|
|
238
|
+
{
|
|
239
|
+
ll += math::lgammaT(std::max((Float)doc.numByWin[v] - doc.numByWinL[v], (Float)0) + alphaM);
|
|
240
|
+
ll += math::lgammaT(doc.numByWinL[v] + alphaML);
|
|
241
|
+
ll -= math::lgammaT(doc.numByWin[v] + alphaM + alphaML);
|
|
242
|
+
}
|
|
243
|
+
}
|
|
244
|
+
|
|
245
|
+
totWins += S + T - 1;
|
|
246
|
+
totSents += S;
|
|
247
|
+
for (size_t s = 0; s < S; ++s)
|
|
248
|
+
{
|
|
249
|
+
ll -= math::lgammaT(doc.numBySent[s] + T * gamma);
|
|
250
|
+
for (size_t v = 0; v < T; ++v)
|
|
251
|
+
{
|
|
252
|
+
ll += math::lgammaT(doc.numBySentWin(s, v) + gamma);
|
|
253
|
+
}
|
|
254
|
+
}
|
|
255
|
+
}
|
|
256
|
+
ll += (math::lgammaT(KL*alphaL) - math::lgammaT(alphaL)*KL) * totWins;
|
|
257
|
+
if(K) ll += (math::lgammaT(alphaM + alphaML) - math::lgammaT(alphaM) - math::lgammaT(alphaML)) * totWins;
|
|
258
|
+
ll += (math::lgammaT(T * gamma) - math::lgammaT(gamma) * T) * totSents;
|
|
259
|
+
|
|
260
|
+
//
|
|
261
|
+
ll += (math::lgammaT(V*eta) - math::lgammaT(eta)*V) * K;
|
|
262
|
+
for (Tid k = 0; k < K; ++k)
|
|
263
|
+
{
|
|
264
|
+
ll -= math::lgammaT(this->globalState.numByTopic[k] + V * eta);
|
|
265
|
+
for (Vid w = 0; w < V; ++w)
|
|
266
|
+
{
|
|
267
|
+
ll += math::lgammaT(this->globalState.numByTopicWord(k, w) + eta);
|
|
268
|
+
}
|
|
269
|
+
}
|
|
270
|
+
ll += (math::lgammaT(V*etaL) - math::lgammaT(etaL)*V) * KL;
|
|
271
|
+
for (Tid k = 0; k < KL; ++k)
|
|
272
|
+
{
|
|
273
|
+
ll -= math::lgammaT(this->globalState.numByTopic[k + K] + V * etaL);
|
|
274
|
+
for (Vid w = 0; w < V; ++w)
|
|
275
|
+
{
|
|
276
|
+
ll += math::lgammaT(this->globalState.numByTopicWord(k + K, w) + etaL);
|
|
277
|
+
}
|
|
278
|
+
}
|
|
279
|
+
|
|
280
|
+
return ll;
|
|
281
|
+
}
|
|
282
|
+
|
|
283
|
+
void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
|
|
284
|
+
{
|
|
285
|
+
sortAndWriteOrder(doc.words, doc.wOrder);
|
|
286
|
+
auto tmp = doc.sents;
|
|
287
|
+
for (size_t i = 0; i < doc.wOrder.size(); ++i)
|
|
288
|
+
{
|
|
289
|
+
doc.sents[doc.wOrder[i]] = tmp[i];
|
|
290
|
+
}
|
|
291
|
+
|
|
292
|
+
const size_t S = doc.numBySent.size();
|
|
293
|
+
std::fill(doc.numBySent.begin(), doc.numBySent.end(), 0);
|
|
294
|
+
doc.Zs = tvector<Tid>(wordSize);
|
|
295
|
+
doc.Vs.resize(wordSize);
|
|
296
|
+
if (_tw != TermWeight::one) doc.wordWeights.resize(wordSize);
|
|
297
|
+
doc.numByTopic.init(nullptr, this->K + KL);
|
|
298
|
+
doc.numBySentWin = Eigen::Matrix<WeightType, -1, -1>::Zero(S, T);
|
|
299
|
+
doc.numByWin = Eigen::Matrix<WeightType, -1, 1>::Zero(S + T - 1);
|
|
300
|
+
doc.numByWinL = Eigen::Matrix<WeightType, -1, 1>::Zero(S + T - 1);
|
|
301
|
+
doc.numByWinTopicL = Eigen::Matrix<WeightType, -1, -1>::Zero(KL, S + T - 1);
|
|
302
|
+
}
|
|
303
|
+
|
|
304
|
+
void initGlobalState(bool initDocs)
|
|
305
|
+
{
|
|
306
|
+
const size_t V = this->realV;
|
|
307
|
+
this->globalState.zLikelihood = Eigen::Matrix<Float, -1, 1>::Zero(T * (this->K + KL));
|
|
308
|
+
if (initDocs)
|
|
309
|
+
{
|
|
310
|
+
this->globalState.numByTopic = Eigen::Matrix<WeightType, -1, 1>::Zero(this->K + KL);
|
|
311
|
+
this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K + KL, V);
|
|
312
|
+
}
|
|
313
|
+
}
|
|
314
|
+
|
|
315
|
+
struct Generator
|
|
316
|
+
{
|
|
317
|
+
std::discrete_distribution<uint16_t> pi;
|
|
318
|
+
std::uniform_int_distribution<Tid> theta;
|
|
319
|
+
std::uniform_int_distribution<Tid> thetaL;
|
|
320
|
+
std::uniform_int_distribution<uint16_t> psi;
|
|
321
|
+
};
|
|
322
|
+
|
|
323
|
+
Generator makeGeneratorForInit(const _DocType*) const
|
|
324
|
+
{
|
|
325
|
+
return Generator{ std::discrete_distribution<uint16_t>{ alphaM, alphaML },
|
|
326
|
+
std::uniform_int_distribution<Tid>{ 0, (Tid)(this->K - 1) },
|
|
327
|
+
std::uniform_int_distribution<Tid>{ 0, (Tid)(KL - 1) },
|
|
328
|
+
std::uniform_int_distribution<uint16_t>{ 0, (uint16_t)(T - 1) } };
|
|
329
|
+
}
|
|
330
|
+
|
|
331
|
+
template<bool _Infer>
|
|
332
|
+
void updateStateWithDoc(Generator& g, _ModelState& ld, _RandGen& rgs, _DocType& doc, size_t i) const
|
|
333
|
+
{
|
|
334
|
+
doc.numBySent[doc.sents[i]] += _tw == TermWeight::one ? 1 : doc.wordWeights[i];
|
|
335
|
+
auto w = doc.words[i];
|
|
336
|
+
size_t r, z;
|
|
337
|
+
if (this->etaByTopicWord.size())
|
|
338
|
+
{
|
|
339
|
+
Eigen::Array<Float, -1, 1> col = this->etaByTopicWord.col(w);
|
|
340
|
+
col.head(this->K) *= alphaM / this->K;
|
|
341
|
+
col.tail(this->KL) *= alphaML / this->KL;
|
|
342
|
+
doc.Zs[i] = z = sample::sampleFromDiscrete(col.data(), col.data() + col.size(), rgs);
|
|
343
|
+
r = z < this->K;
|
|
344
|
+
if (z >= this->K) z -= this->K;
|
|
345
|
+
}
|
|
346
|
+
else
|
|
347
|
+
{
|
|
348
|
+
r = g.pi(rgs);
|
|
349
|
+
z = (r ? g.thetaL : g.theta)(rgs);
|
|
350
|
+
doc.Zs[i] = z + (r ? this->K : 0);
|
|
351
|
+
}
|
|
352
|
+
|
|
353
|
+
auto& win = doc.Vs[i];
|
|
354
|
+
win = g.psi(rgs);
|
|
355
|
+
addWordTo<1>(ld, doc, i, w, z, doc.sents[i], win, r);
|
|
356
|
+
}
|
|
357
|
+
|
|
358
|
+
std::vector<uint64_t> _getTopicsCount() const
|
|
359
|
+
{
|
|
360
|
+
std::vector<uint64_t> cnt(this->K + KL);
|
|
361
|
+
for (auto& doc : this->docs)
|
|
362
|
+
{
|
|
363
|
+
for (size_t i = 0; i < doc.Zs.size(); ++i)
|
|
364
|
+
{
|
|
365
|
+
if (doc.words[i] < this->realV) ++cnt[doc.Zs[i]];
|
|
366
|
+
}
|
|
367
|
+
}
|
|
368
|
+
return cnt;
|
|
369
|
+
}
|
|
370
|
+
|
|
371
|
+
public:
|
|
372
|
+
DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, alphaL, alphaM, alphaML, etaL, gamma, KL, T);
|
|
373
|
+
DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, alphaL, alphaM, alphaML, etaL, gamma, KL, T);
|
|
374
|
+
|
|
375
|
+
MGLDAModel(size_t _KG = 1, size_t _KL = 1, size_t _T = 3,
|
|
376
|
+
Float _alphaG = 0.1, Float _alphaL = 0.1, Float _alphaMG = 0.1, Float _alphaML = 0.1,
|
|
377
|
+
Float _etaG = 0.01, Float _etaL = 0.01, Float _gamma = 0.1, size_t _rg = std::random_device{}())
|
|
378
|
+
: BaseClass(_KG, _alphaG, _etaG, _rg), KL(_KL), T(_T),
|
|
379
|
+
alphaL(_alphaL), alphaM(_KG ? _alphaMG : 0), alphaML(_alphaML),
|
|
380
|
+
etaL(_etaL), gamma(_gamma)
|
|
381
|
+
{
|
|
382
|
+
if (_KL == 0 || _KL >= 0x80000000) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong KL value (KL = %zd)", _KL));
|
|
383
|
+
if (_T == 0 || _T >= 0x80000000) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong T value (T = %zd)", _T));
|
|
384
|
+
if (_alphaL <= 0) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong alphaL value (alphaL = %f)", _alphaL));
|
|
385
|
+
if (_etaL <= 0) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong etaL value (etaL = %f)", _etaL));
|
|
386
|
+
}
|
|
387
|
+
|
|
388
|
+
|
|
389
|
+
template<bool _const = false>
|
|
390
|
+
_DocType _makeDoc(const std::vector<std::string>& words, const std::string& delimiter)
|
|
391
|
+
{
|
|
392
|
+
_DocType doc{ 1.f };
|
|
393
|
+
size_t numSent = 0;
|
|
394
|
+
for (auto& w : words)
|
|
395
|
+
{
|
|
396
|
+
if (w == delimiter)
|
|
397
|
+
{
|
|
398
|
+
++numSent;
|
|
399
|
+
continue;
|
|
400
|
+
}
|
|
401
|
+
|
|
402
|
+
Vid id;
|
|
403
|
+
if (_const)
|
|
404
|
+
{
|
|
405
|
+
id = this->dict.toWid(w);
|
|
406
|
+
if (id == (Vid)-1) continue;
|
|
407
|
+
}
|
|
408
|
+
else
|
|
409
|
+
{
|
|
410
|
+
id = this->dict.add(w);
|
|
411
|
+
}
|
|
412
|
+
doc.words.emplace_back(id);
|
|
413
|
+
doc.sents.emplace_back(numSent);
|
|
414
|
+
}
|
|
415
|
+
doc.numBySent.resize(doc.sents.empty() ? 0 : (doc.sents.back() + 1));
|
|
416
|
+
return doc;
|
|
417
|
+
}
|
|
418
|
+
|
|
419
|
+
size_t addDoc(const std::vector<std::string>& words, const std::string& delimiter) override
|
|
420
|
+
{
|
|
421
|
+
return this->_addDoc(_makeDoc(words, delimiter));
|
|
422
|
+
}
|
|
423
|
+
|
|
424
|
+
std::unique_ptr<DocumentBase> makeDoc(const std::vector<std::string>& words, const std::string& delimiter) const override
|
|
425
|
+
{
|
|
426
|
+
return make_unique<_DocType>(as_mutable(this)->template _makeDoc<true>(words, delimiter));
|
|
427
|
+
}
|
|
428
|
+
|
|
429
|
+
template<bool _const, typename _FnTokenizer>
|
|
430
|
+
_DocType _makeRawDoc(const std::string& rawStr, _FnTokenizer&& tokenizer, const std::string& delimiter)
|
|
431
|
+
{
|
|
432
|
+
_DocType doc{ 1.f };
|
|
433
|
+
size_t numSent = 0;
|
|
434
|
+
doc.rawStr = rawStr;
|
|
435
|
+
for (auto& p : tokenizer(doc.rawStr))
|
|
436
|
+
{
|
|
437
|
+
if (std::get<0>(p) == delimiter)
|
|
438
|
+
{
|
|
439
|
+
++numSent;
|
|
440
|
+
continue;
|
|
441
|
+
}
|
|
442
|
+
|
|
443
|
+
Vid wid;
|
|
444
|
+
if (_const)
|
|
445
|
+
{
|
|
446
|
+
wid = this->dict.toWid(std::get<0>(p));
|
|
447
|
+
if (wid == (Vid)-1) continue;
|
|
448
|
+
}
|
|
449
|
+
else
|
|
450
|
+
{
|
|
451
|
+
wid = this->dict.add(std::get<0>(p));
|
|
452
|
+
}
|
|
453
|
+
auto pos = std::get<1>(p);
|
|
454
|
+
auto len = std::get<2>(p);
|
|
455
|
+
doc.words.emplace_back(wid);
|
|
456
|
+
doc.sents.emplace_back(numSent);
|
|
457
|
+
doc.origWordPos.emplace_back(pos);
|
|
458
|
+
doc.origWordLen.emplace_back(len);
|
|
459
|
+
}
|
|
460
|
+
doc.numBySent.resize(doc.sents.empty() ? 0 : (doc.sents.back() + 1));
|
|
461
|
+
return doc;
|
|
462
|
+
}
|
|
463
|
+
|
|
464
|
+
size_t addDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
|
|
465
|
+
const std::string& delimiter)
|
|
466
|
+
{
|
|
467
|
+
return this->_addDoc(_makeRawDoc<false>(rawStr, tokenizer, delimiter));
|
|
468
|
+
}
|
|
469
|
+
|
|
470
|
+
std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
|
|
471
|
+
const std::string& delimiter) const
|
|
472
|
+
{
|
|
473
|
+
return make_unique<_DocType>(as_mutable(this)->template _makeRawDoc<true>(rawStr, tokenizer, delimiter));
|
|
474
|
+
}
|
|
475
|
+
|
|
476
|
+
_DocType _makeRawDoc(const std::string& rawStr, const std::vector<Vid>& words,
|
|
477
|
+
const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len, const std::string& delimiter) const
|
|
478
|
+
{
|
|
479
|
+
_DocType doc{ 1.f };
|
|
480
|
+
doc.rawStr = rawStr;
|
|
481
|
+
size_t numSent = 0;
|
|
482
|
+
Vid delimiterId = this->dict.toWid(delimiter);
|
|
483
|
+
for (size_t i = 0; i < words.size(); ++i)
|
|
484
|
+
{
|
|
485
|
+
auto& w = words[i];
|
|
486
|
+
if (w == delimiterId)
|
|
487
|
+
{
|
|
488
|
+
++numSent;
|
|
489
|
+
continue;
|
|
490
|
+
}
|
|
491
|
+
doc.words.emplace_back(w);
|
|
492
|
+
doc.sents.emplace_back(numSent);
|
|
493
|
+
if (words.size() == pos.size())
|
|
494
|
+
{
|
|
495
|
+
doc.origWordPos.emplace_back(pos[i]);
|
|
496
|
+
doc.origWordLen.emplace_back(len[i]);
|
|
497
|
+
}
|
|
498
|
+
}
|
|
499
|
+
doc.numBySent.resize(doc.sents.empty() ? 0 : (doc.sents.back() + 1));
|
|
500
|
+
return doc;
|
|
501
|
+
}
|
|
502
|
+
|
|
503
|
+
size_t addDoc(const std::string& rawStr, const std::vector<Vid>& words,
|
|
504
|
+
const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
|
|
505
|
+
const std::string& delimiter)
|
|
506
|
+
{
|
|
507
|
+
return this->_addDoc(_makeRawDoc(rawStr, words, pos, len, delimiter));
|
|
508
|
+
}
|
|
509
|
+
|
|
510
|
+
std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const std::vector<Vid>& words,
|
|
511
|
+
const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
|
|
512
|
+
const std::string& delimiter) const
|
|
513
|
+
{
|
|
514
|
+
return make_unique<_DocType>(_makeRawDoc(rawStr, words, pos, len, delimiter));
|
|
515
|
+
}
|
|
516
|
+
|
|
517
|
+
void setWordPrior(const std::string& word, const std::vector<Float>& priors) override
|
|
518
|
+
{
|
|
519
|
+
if (priors.size() != this->K + KL) THROW_ERROR_WITH_INFO(exception::InvalidArgument, "priors.size() must be equal to K.");
|
|
520
|
+
for (auto p : priors)
|
|
521
|
+
{
|
|
522
|
+
if (p < 0) THROW_ERROR_WITH_INFO(exception::InvalidArgument, "priors must not be less than 0.");
|
|
523
|
+
}
|
|
524
|
+
this->dict.add(word);
|
|
525
|
+
this->etaByWord.emplace(word, priors);
|
|
526
|
+
}
|
|
527
|
+
|
|
528
|
+
std::vector<Float> getTopicsByDoc(const _DocType& doc) const
|
|
529
|
+
{
|
|
530
|
+
std::vector<Float> ret(this->K + KL);
|
|
531
|
+
Eigen::Map<Eigen::Matrix<Float, -1, 1>> { ret.data(), this->K + KL }.array() =
|
|
532
|
+
doc.numByTopic.array().template cast<Float>() / doc.getSumWordWeight();
|
|
533
|
+
return ret;
|
|
534
|
+
}
|
|
535
|
+
|
|
536
|
+
GETTER(KL, size_t, KL);
|
|
537
|
+
GETTER(T, size_t, T);
|
|
538
|
+
GETTER(Gamma, Float, gamma);
|
|
539
|
+
GETTER(AlphaL, Float, alphaL);
|
|
540
|
+
GETTER(EtaL, Float, etaL);
|
|
541
|
+
GETTER(AlphaM, Float, alphaM);
|
|
542
|
+
GETTER(AlphaML, Float, alphaML);
|
|
543
|
+
};
|
|
544
|
+
|
|
545
|
+
template<TermWeight _tw>
|
|
546
|
+
template<typename _TopicModel>
|
|
547
|
+
void DocumentMGLDA<_tw>::update(WeightType * ptr, const _TopicModel & mdl)
|
|
548
|
+
{
|
|
549
|
+
this->numByTopic.init(ptr, mdl.getK() + mdl.getKL());
|
|
550
|
+
numBySent.resize(*std::max_element(sents.begin(), sents.end()) + 1);
|
|
551
|
+
for (size_t i = 0; i < this->Zs.size(); ++i)
|
|
552
|
+
{
|
|
553
|
+
if (this->words[i] >= mdl.getV()) continue;
|
|
554
|
+
this->numByTopic[this->Zs[i]] += _tw != TermWeight::one ? this->wordWeights[i] : 1;
|
|
555
|
+
numBySent[sents[i]] += _tw != TermWeight::one ? this->wordWeights[i] : 1;
|
|
556
|
+
}
|
|
557
|
+
}
|
|
558
|
+
}
|