tomoto 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +22 -0
- data/README.md +123 -0
- data/ext/tomoto/ext.cpp +245 -0
- data/ext/tomoto/extconf.rb +28 -0
- data/lib/tomoto.rb +12 -0
- data/lib/tomoto/ct.rb +11 -0
- data/lib/tomoto/hdp.rb +11 -0
- data/lib/tomoto/lda.rb +67 -0
- data/lib/tomoto/version.rb +3 -0
- data/vendor/EigenRand/EigenRand/Core.h +1139 -0
- data/vendor/EigenRand/EigenRand/Dists/Basic.h +111 -0
- data/vendor/EigenRand/EigenRand/Dists/Discrete.h +877 -0
- data/vendor/EigenRand/EigenRand/Dists/GammaPoisson.h +108 -0
- data/vendor/EigenRand/EigenRand/Dists/NormalExp.h +626 -0
- data/vendor/EigenRand/EigenRand/EigenRand +19 -0
- data/vendor/EigenRand/EigenRand/Macro.h +24 -0
- data/vendor/EigenRand/EigenRand/MorePacketMath.h +978 -0
- data/vendor/EigenRand/EigenRand/PacketFilter.h +286 -0
- data/vendor/EigenRand/EigenRand/PacketRandomEngine.h +624 -0
- data/vendor/EigenRand/EigenRand/RandUtils.h +413 -0
- data/vendor/EigenRand/EigenRand/doc.h +220 -0
- data/vendor/EigenRand/LICENSE +21 -0
- data/vendor/EigenRand/README.md +288 -0
- data/vendor/eigen/COPYING.BSD +26 -0
- data/vendor/eigen/COPYING.GPL +674 -0
- data/vendor/eigen/COPYING.LGPL +502 -0
- data/vendor/eigen/COPYING.MINPACK +52 -0
- data/vendor/eigen/COPYING.MPL2 +373 -0
- data/vendor/eigen/COPYING.README +18 -0
- data/vendor/eigen/Eigen/CMakeLists.txt +19 -0
- data/vendor/eigen/Eigen/Cholesky +46 -0
- data/vendor/eigen/Eigen/CholmodSupport +48 -0
- data/vendor/eigen/Eigen/Core +537 -0
- data/vendor/eigen/Eigen/Dense +7 -0
- data/vendor/eigen/Eigen/Eigen +2 -0
- data/vendor/eigen/Eigen/Eigenvalues +61 -0
- data/vendor/eigen/Eigen/Geometry +62 -0
- data/vendor/eigen/Eigen/Householder +30 -0
- data/vendor/eigen/Eigen/IterativeLinearSolvers +48 -0
- data/vendor/eigen/Eigen/Jacobi +33 -0
- data/vendor/eigen/Eigen/LU +50 -0
- data/vendor/eigen/Eigen/MetisSupport +35 -0
- data/vendor/eigen/Eigen/OrderingMethods +73 -0
- data/vendor/eigen/Eigen/PaStiXSupport +48 -0
- data/vendor/eigen/Eigen/PardisoSupport +35 -0
- data/vendor/eigen/Eigen/QR +51 -0
- data/vendor/eigen/Eigen/QtAlignedMalloc +40 -0
- data/vendor/eigen/Eigen/SPQRSupport +34 -0
- data/vendor/eigen/Eigen/SVD +51 -0
- data/vendor/eigen/Eigen/Sparse +36 -0
- data/vendor/eigen/Eigen/SparseCholesky +45 -0
- data/vendor/eigen/Eigen/SparseCore +69 -0
- data/vendor/eigen/Eigen/SparseLU +46 -0
- data/vendor/eigen/Eigen/SparseQR +37 -0
- data/vendor/eigen/Eigen/StdDeque +27 -0
- data/vendor/eigen/Eigen/StdList +26 -0
- data/vendor/eigen/Eigen/StdVector +27 -0
- data/vendor/eigen/Eigen/SuperLUSupport +64 -0
- data/vendor/eigen/Eigen/UmfPackSupport +40 -0
- data/vendor/eigen/Eigen/src/Cholesky/LDLT.h +673 -0
- data/vendor/eigen/Eigen/src/Cholesky/LLT.h +542 -0
- data/vendor/eigen/Eigen/src/Cholesky/LLT_LAPACKE.h +99 -0
- data/vendor/eigen/Eigen/src/CholmodSupport/CholmodSupport.h +639 -0
- data/vendor/eigen/Eigen/src/Core/Array.h +329 -0
- data/vendor/eigen/Eigen/src/Core/ArrayBase.h +226 -0
- data/vendor/eigen/Eigen/src/Core/ArrayWrapper.h +209 -0
- data/vendor/eigen/Eigen/src/Core/Assign.h +90 -0
- data/vendor/eigen/Eigen/src/Core/AssignEvaluator.h +935 -0
- data/vendor/eigen/Eigen/src/Core/Assign_MKL.h +178 -0
- data/vendor/eigen/Eigen/src/Core/BandMatrix.h +353 -0
- data/vendor/eigen/Eigen/src/Core/Block.h +452 -0
- data/vendor/eigen/Eigen/src/Core/BooleanRedux.h +164 -0
- data/vendor/eigen/Eigen/src/Core/CommaInitializer.h +160 -0
- data/vendor/eigen/Eigen/src/Core/ConditionEstimator.h +175 -0
- data/vendor/eigen/Eigen/src/Core/CoreEvaluators.h +1688 -0
- data/vendor/eigen/Eigen/src/Core/CoreIterators.h +127 -0
- data/vendor/eigen/Eigen/src/Core/CwiseBinaryOp.h +184 -0
- data/vendor/eigen/Eigen/src/Core/CwiseNullaryOp.h +866 -0
- data/vendor/eigen/Eigen/src/Core/CwiseTernaryOp.h +197 -0
- data/vendor/eigen/Eigen/src/Core/CwiseUnaryOp.h +103 -0
- data/vendor/eigen/Eigen/src/Core/CwiseUnaryView.h +128 -0
- data/vendor/eigen/Eigen/src/Core/DenseBase.h +611 -0
- data/vendor/eigen/Eigen/src/Core/DenseCoeffsBase.h +681 -0
- data/vendor/eigen/Eigen/src/Core/DenseStorage.h +570 -0
- data/vendor/eigen/Eigen/src/Core/Diagonal.h +260 -0
- data/vendor/eigen/Eigen/src/Core/DiagonalMatrix.h +343 -0
- data/vendor/eigen/Eigen/src/Core/DiagonalProduct.h +28 -0
- data/vendor/eigen/Eigen/src/Core/Dot.h +318 -0
- data/vendor/eigen/Eigen/src/Core/EigenBase.h +159 -0
- data/vendor/eigen/Eigen/src/Core/ForceAlignedAccess.h +146 -0
- data/vendor/eigen/Eigen/src/Core/Fuzzy.h +155 -0
- data/vendor/eigen/Eigen/src/Core/GeneralProduct.h +455 -0
- data/vendor/eigen/Eigen/src/Core/GenericPacketMath.h +593 -0
- data/vendor/eigen/Eigen/src/Core/GlobalFunctions.h +187 -0
- data/vendor/eigen/Eigen/src/Core/IO.h +225 -0
- data/vendor/eigen/Eigen/src/Core/Inverse.h +118 -0
- data/vendor/eigen/Eigen/src/Core/Map.h +171 -0
- data/vendor/eigen/Eigen/src/Core/MapBase.h +303 -0
- data/vendor/eigen/Eigen/src/Core/MathFunctions.h +1415 -0
- data/vendor/eigen/Eigen/src/Core/MathFunctionsImpl.h +101 -0
- data/vendor/eigen/Eigen/src/Core/Matrix.h +459 -0
- data/vendor/eigen/Eigen/src/Core/MatrixBase.h +529 -0
- data/vendor/eigen/Eigen/src/Core/NestByValue.h +110 -0
- data/vendor/eigen/Eigen/src/Core/NoAlias.h +108 -0
- data/vendor/eigen/Eigen/src/Core/NumTraits.h +248 -0
- data/vendor/eigen/Eigen/src/Core/PermutationMatrix.h +633 -0
- data/vendor/eigen/Eigen/src/Core/PlainObjectBase.h +1035 -0
- data/vendor/eigen/Eigen/src/Core/Product.h +186 -0
- data/vendor/eigen/Eigen/src/Core/ProductEvaluators.h +1112 -0
- data/vendor/eigen/Eigen/src/Core/Random.h +182 -0
- data/vendor/eigen/Eigen/src/Core/Redux.h +505 -0
- data/vendor/eigen/Eigen/src/Core/Ref.h +283 -0
- data/vendor/eigen/Eigen/src/Core/Replicate.h +142 -0
- data/vendor/eigen/Eigen/src/Core/ReturnByValue.h +117 -0
- data/vendor/eigen/Eigen/src/Core/Reverse.h +211 -0
- data/vendor/eigen/Eigen/src/Core/Select.h +162 -0
- data/vendor/eigen/Eigen/src/Core/SelfAdjointView.h +352 -0
- data/vendor/eigen/Eigen/src/Core/SelfCwiseBinaryOp.h +47 -0
- data/vendor/eigen/Eigen/src/Core/Solve.h +188 -0
- data/vendor/eigen/Eigen/src/Core/SolveTriangular.h +235 -0
- data/vendor/eigen/Eigen/src/Core/SolverBase.h +130 -0
- data/vendor/eigen/Eigen/src/Core/StableNorm.h +221 -0
- data/vendor/eigen/Eigen/src/Core/Stride.h +111 -0
- data/vendor/eigen/Eigen/src/Core/Swap.h +67 -0
- data/vendor/eigen/Eigen/src/Core/Transpose.h +403 -0
- data/vendor/eigen/Eigen/src/Core/Transpositions.h +407 -0
- data/vendor/eigen/Eigen/src/Core/TriangularMatrix.h +983 -0
- data/vendor/eigen/Eigen/src/Core/VectorBlock.h +96 -0
- data/vendor/eigen/Eigen/src/Core/VectorwiseOp.h +695 -0
- data/vendor/eigen/Eigen/src/Core/Visitor.h +273 -0
- data/vendor/eigen/Eigen/src/Core/arch/AVX/Complex.h +451 -0
- data/vendor/eigen/Eigen/src/Core/arch/AVX/MathFunctions.h +439 -0
- data/vendor/eigen/Eigen/src/Core/arch/AVX/PacketMath.h +637 -0
- data/vendor/eigen/Eigen/src/Core/arch/AVX/TypeCasting.h +51 -0
- data/vendor/eigen/Eigen/src/Core/arch/AVX512/MathFunctions.h +391 -0
- data/vendor/eigen/Eigen/src/Core/arch/AVX512/PacketMath.h +1316 -0
- data/vendor/eigen/Eigen/src/Core/arch/AltiVec/Complex.h +430 -0
- data/vendor/eigen/Eigen/src/Core/arch/AltiVec/MathFunctions.h +322 -0
- data/vendor/eigen/Eigen/src/Core/arch/AltiVec/PacketMath.h +1061 -0
- data/vendor/eigen/Eigen/src/Core/arch/CUDA/Complex.h +103 -0
- data/vendor/eigen/Eigen/src/Core/arch/CUDA/Half.h +674 -0
- data/vendor/eigen/Eigen/src/Core/arch/CUDA/MathFunctions.h +91 -0
- data/vendor/eigen/Eigen/src/Core/arch/CUDA/PacketMath.h +333 -0
- data/vendor/eigen/Eigen/src/Core/arch/CUDA/PacketMathHalf.h +1124 -0
- data/vendor/eigen/Eigen/src/Core/arch/CUDA/TypeCasting.h +212 -0
- data/vendor/eigen/Eigen/src/Core/arch/Default/ConjHelper.h +29 -0
- data/vendor/eigen/Eigen/src/Core/arch/Default/Settings.h +49 -0
- data/vendor/eigen/Eigen/src/Core/arch/NEON/Complex.h +490 -0
- data/vendor/eigen/Eigen/src/Core/arch/NEON/MathFunctions.h +91 -0
- data/vendor/eigen/Eigen/src/Core/arch/NEON/PacketMath.h +760 -0
- data/vendor/eigen/Eigen/src/Core/arch/SSE/Complex.h +471 -0
- data/vendor/eigen/Eigen/src/Core/arch/SSE/MathFunctions.h +562 -0
- data/vendor/eigen/Eigen/src/Core/arch/SSE/PacketMath.h +895 -0
- data/vendor/eigen/Eigen/src/Core/arch/SSE/TypeCasting.h +77 -0
- data/vendor/eigen/Eigen/src/Core/arch/ZVector/Complex.h +397 -0
- data/vendor/eigen/Eigen/src/Core/arch/ZVector/MathFunctions.h +137 -0
- data/vendor/eigen/Eigen/src/Core/arch/ZVector/PacketMath.h +945 -0
- data/vendor/eigen/Eigen/src/Core/functors/AssignmentFunctors.h +168 -0
- data/vendor/eigen/Eigen/src/Core/functors/BinaryFunctors.h +475 -0
- data/vendor/eigen/Eigen/src/Core/functors/NullaryFunctors.h +188 -0
- data/vendor/eigen/Eigen/src/Core/functors/StlFunctors.h +136 -0
- data/vendor/eigen/Eigen/src/Core/functors/TernaryFunctors.h +25 -0
- data/vendor/eigen/Eigen/src/Core/functors/UnaryFunctors.h +792 -0
- data/vendor/eigen/Eigen/src/Core/products/GeneralBlockPanelKernel.h +2156 -0
- data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrix.h +492 -0
- data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h +311 -0
- data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h +145 -0
- data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h +122 -0
- data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixVector.h +619 -0
- data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixVector_BLAS.h +136 -0
- data/vendor/eigen/Eigen/src/Core/products/Parallelizer.h +163 -0
- data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix.h +521 -0
- data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix_BLAS.h +287 -0
- data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixVector.h +260 -0
- data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixVector_BLAS.h +118 -0
- data/vendor/eigen/Eigen/src/Core/products/SelfadjointProduct.h +133 -0
- data/vendor/eigen/Eigen/src/Core/products/SelfadjointRank2Update.h +93 -0
- data/vendor/eigen/Eigen/src/Core/products/TriangularMatrixMatrix.h +466 -0
- data/vendor/eigen/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h +315 -0
- data/vendor/eigen/Eigen/src/Core/products/TriangularMatrixVector.h +350 -0
- data/vendor/eigen/Eigen/src/Core/products/TriangularMatrixVector_BLAS.h +255 -0
- data/vendor/eigen/Eigen/src/Core/products/TriangularSolverMatrix.h +335 -0
- data/vendor/eigen/Eigen/src/Core/products/TriangularSolverMatrix_BLAS.h +163 -0
- data/vendor/eigen/Eigen/src/Core/products/TriangularSolverVector.h +145 -0
- data/vendor/eigen/Eigen/src/Core/util/BlasUtil.h +398 -0
- data/vendor/eigen/Eigen/src/Core/util/Constants.h +547 -0
- data/vendor/eigen/Eigen/src/Core/util/DisableStupidWarnings.h +83 -0
- data/vendor/eigen/Eigen/src/Core/util/ForwardDeclarations.h +302 -0
- data/vendor/eigen/Eigen/src/Core/util/MKL_support.h +130 -0
- data/vendor/eigen/Eigen/src/Core/util/Macros.h +1001 -0
- data/vendor/eigen/Eigen/src/Core/util/Memory.h +993 -0
- data/vendor/eigen/Eigen/src/Core/util/Meta.h +534 -0
- data/vendor/eigen/Eigen/src/Core/util/NonMPL2.h +3 -0
- data/vendor/eigen/Eigen/src/Core/util/ReenableStupidWarnings.h +27 -0
- data/vendor/eigen/Eigen/src/Core/util/StaticAssert.h +218 -0
- data/vendor/eigen/Eigen/src/Core/util/XprHelper.h +821 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/ComplexEigenSolver.h +346 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/ComplexSchur.h +459 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/ComplexSchur_LAPACKE.h +91 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/EigenSolver.h +622 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/GeneralizedEigenSolver.h +418 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/GeneralizedSelfAdjointEigenSolver.h +226 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/HessenbergDecomposition.h +374 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/MatrixBaseEigenvalues.h +158 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/RealQZ.h +654 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/RealSchur.h +546 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/RealSchur_LAPACKE.h +77 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver.h +870 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver_LAPACKE.h +87 -0
- data/vendor/eigen/Eigen/src/Eigenvalues/Tridiagonalization.h +556 -0
- data/vendor/eigen/Eigen/src/Geometry/AlignedBox.h +392 -0
- data/vendor/eigen/Eigen/src/Geometry/AngleAxis.h +247 -0
- data/vendor/eigen/Eigen/src/Geometry/EulerAngles.h +114 -0
- data/vendor/eigen/Eigen/src/Geometry/Homogeneous.h +497 -0
- data/vendor/eigen/Eigen/src/Geometry/Hyperplane.h +282 -0
- data/vendor/eigen/Eigen/src/Geometry/OrthoMethods.h +234 -0
- data/vendor/eigen/Eigen/src/Geometry/ParametrizedLine.h +195 -0
- data/vendor/eigen/Eigen/src/Geometry/Quaternion.h +814 -0
- data/vendor/eigen/Eigen/src/Geometry/Rotation2D.h +199 -0
- data/vendor/eigen/Eigen/src/Geometry/RotationBase.h +206 -0
- data/vendor/eigen/Eigen/src/Geometry/Scaling.h +170 -0
- data/vendor/eigen/Eigen/src/Geometry/Transform.h +1542 -0
- data/vendor/eigen/Eigen/src/Geometry/Translation.h +208 -0
- data/vendor/eigen/Eigen/src/Geometry/Umeyama.h +166 -0
- data/vendor/eigen/Eigen/src/Geometry/arch/Geometry_SSE.h +161 -0
- data/vendor/eigen/Eigen/src/Householder/BlockHouseholder.h +103 -0
- data/vendor/eigen/Eigen/src/Householder/Householder.h +172 -0
- data/vendor/eigen/Eigen/src/Householder/HouseholderSequence.h +470 -0
- data/vendor/eigen/Eigen/src/IterativeLinearSolvers/BasicPreconditioners.h +226 -0
- data/vendor/eigen/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h +228 -0
- data/vendor/eigen/Eigen/src/IterativeLinearSolvers/ConjugateGradient.h +246 -0
- data/vendor/eigen/Eigen/src/IterativeLinearSolvers/IncompleteCholesky.h +400 -0
- data/vendor/eigen/Eigen/src/IterativeLinearSolvers/IncompleteLUT.h +462 -0
- data/vendor/eigen/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h +394 -0
- data/vendor/eigen/Eigen/src/IterativeLinearSolvers/LeastSquareConjugateGradient.h +216 -0
- data/vendor/eigen/Eigen/src/IterativeLinearSolvers/SolveWithGuess.h +115 -0
- data/vendor/eigen/Eigen/src/Jacobi/Jacobi.h +462 -0
- data/vendor/eigen/Eigen/src/LU/Determinant.h +101 -0
- data/vendor/eigen/Eigen/src/LU/FullPivLU.h +891 -0
- data/vendor/eigen/Eigen/src/LU/InverseImpl.h +415 -0
- data/vendor/eigen/Eigen/src/LU/PartialPivLU.h +611 -0
- data/vendor/eigen/Eigen/src/LU/PartialPivLU_LAPACKE.h +83 -0
- data/vendor/eigen/Eigen/src/LU/arch/Inverse_SSE.h +338 -0
- data/vendor/eigen/Eigen/src/MetisSupport/MetisSupport.h +137 -0
- data/vendor/eigen/Eigen/src/OrderingMethods/Amd.h +445 -0
- data/vendor/eigen/Eigen/src/OrderingMethods/Eigen_Colamd.h +1843 -0
- data/vendor/eigen/Eigen/src/OrderingMethods/Ordering.h +157 -0
- data/vendor/eigen/Eigen/src/PaStiXSupport/PaStiXSupport.h +678 -0
- data/vendor/eigen/Eigen/src/PardisoSupport/PardisoSupport.h +543 -0
- data/vendor/eigen/Eigen/src/QR/ColPivHouseholderQR.h +653 -0
- data/vendor/eigen/Eigen/src/QR/ColPivHouseholderQR_LAPACKE.h +97 -0
- data/vendor/eigen/Eigen/src/QR/CompleteOrthogonalDecomposition.h +562 -0
- data/vendor/eigen/Eigen/src/QR/FullPivHouseholderQR.h +676 -0
- data/vendor/eigen/Eigen/src/QR/HouseholderQR.h +409 -0
- data/vendor/eigen/Eigen/src/QR/HouseholderQR_LAPACKE.h +68 -0
- data/vendor/eigen/Eigen/src/SPQRSupport/SuiteSparseQRSupport.h +313 -0
- data/vendor/eigen/Eigen/src/SVD/BDCSVD.h +1246 -0
- data/vendor/eigen/Eigen/src/SVD/JacobiSVD.h +804 -0
- data/vendor/eigen/Eigen/src/SVD/JacobiSVD_LAPACKE.h +91 -0
- data/vendor/eigen/Eigen/src/SVD/SVDBase.h +315 -0
- data/vendor/eigen/Eigen/src/SVD/UpperBidiagonalization.h +414 -0
- data/vendor/eigen/Eigen/src/SparseCholesky/SimplicialCholesky.h +689 -0
- data/vendor/eigen/Eigen/src/SparseCholesky/SimplicialCholesky_impl.h +199 -0
- data/vendor/eigen/Eigen/src/SparseCore/AmbiVector.h +377 -0
- data/vendor/eigen/Eigen/src/SparseCore/CompressedStorage.h +258 -0
- data/vendor/eigen/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h +352 -0
- data/vendor/eigen/Eigen/src/SparseCore/MappedSparseMatrix.h +67 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseAssign.h +216 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseBlock.h +603 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseColEtree.h +206 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseCompressedBase.h +341 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseCwiseBinaryOp.h +726 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseCwiseUnaryOp.h +148 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseDenseProduct.h +320 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseDiagonalProduct.h +138 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseDot.h +98 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseFuzzy.h +29 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseMap.h +305 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseMatrix.h +1403 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseMatrixBase.h +405 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparsePermutation.h +178 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseProduct.h +169 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseRedux.h +49 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseRef.h +397 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseSelfAdjointView.h +656 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseSolverBase.h +124 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseSparseProductWithPruning.h +198 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseTranspose.h +92 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseTriangularView.h +189 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseUtil.h +178 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseVector.h +478 -0
- data/vendor/eigen/Eigen/src/SparseCore/SparseView.h +253 -0
- data/vendor/eigen/Eigen/src/SparseCore/TriangularSolver.h +315 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU.h +773 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLUImpl.h +66 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_Memory.h +226 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_Structs.h +110 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_SupernodalMatrix.h +301 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_Utils.h +80 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_column_bmod.h +181 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_column_dfs.h +179 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_copy_to_ucol.h +107 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_gemm_kernel.h +280 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_heap_relax_snode.h +126 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_kernel_bmod.h +130 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_panel_bmod.h +223 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_panel_dfs.h +258 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_pivotL.h +137 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_pruneL.h +136 -0
- data/vendor/eigen/Eigen/src/SparseLU/SparseLU_relax_snode.h +83 -0
- data/vendor/eigen/Eigen/src/SparseQR/SparseQR.h +745 -0
- data/vendor/eigen/Eigen/src/StlSupport/StdDeque.h +126 -0
- data/vendor/eigen/Eigen/src/StlSupport/StdList.h +106 -0
- data/vendor/eigen/Eigen/src/StlSupport/StdVector.h +131 -0
- data/vendor/eigen/Eigen/src/StlSupport/details.h +84 -0
- data/vendor/eigen/Eigen/src/SuperLUSupport/SuperLUSupport.h +1027 -0
- data/vendor/eigen/Eigen/src/UmfPackSupport/UmfPackSupport.h +506 -0
- data/vendor/eigen/Eigen/src/misc/Image.h +82 -0
- data/vendor/eigen/Eigen/src/misc/Kernel.h +79 -0
- data/vendor/eigen/Eigen/src/misc/RealSvd2x2.h +55 -0
- data/vendor/eigen/Eigen/src/misc/blas.h +440 -0
- data/vendor/eigen/Eigen/src/misc/lapack.h +152 -0
- data/vendor/eigen/Eigen/src/misc/lapacke.h +16291 -0
- data/vendor/eigen/Eigen/src/misc/lapacke_mangling.h +17 -0
- data/vendor/eigen/Eigen/src/plugins/ArrayCwiseBinaryOps.h +332 -0
- data/vendor/eigen/Eigen/src/plugins/ArrayCwiseUnaryOps.h +552 -0
- data/vendor/eigen/Eigen/src/plugins/BlockMethods.h +1058 -0
- data/vendor/eigen/Eigen/src/plugins/CommonCwiseBinaryOps.h +115 -0
- data/vendor/eigen/Eigen/src/plugins/CommonCwiseUnaryOps.h +163 -0
- data/vendor/eigen/Eigen/src/plugins/MatrixCwiseBinaryOps.h +152 -0
- data/vendor/eigen/Eigen/src/plugins/MatrixCwiseUnaryOps.h +85 -0
- data/vendor/eigen/README.md +3 -0
- data/vendor/eigen/bench/README.txt +55 -0
- data/vendor/eigen/bench/btl/COPYING +340 -0
- data/vendor/eigen/bench/btl/README +154 -0
- data/vendor/eigen/bench/tensors/README +21 -0
- data/vendor/eigen/blas/README.txt +6 -0
- data/vendor/eigen/demos/mandelbrot/README +10 -0
- data/vendor/eigen/demos/mix_eigen_and_c/README +9 -0
- data/vendor/eigen/demos/opengl/README +13 -0
- data/vendor/eigen/unsupported/Eigen/CXX11/src/Tensor/README.md +1760 -0
- data/vendor/eigen/unsupported/README.txt +50 -0
- data/vendor/tomotopy/LICENSE +21 -0
- data/vendor/tomotopy/README.kr.rst +375 -0
- data/vendor/tomotopy/README.rst +382 -0
- data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +362 -0
- data/vendor/tomotopy/src/Labeling/FoRelevance.h +88 -0
- data/vendor/tomotopy/src/Labeling/Labeler.h +50 -0
- data/vendor/tomotopy/src/TopicModel/CT.h +37 -0
- data/vendor/tomotopy/src/TopicModel/CTModel.cpp +13 -0
- data/vendor/tomotopy/src/TopicModel/CTModel.hpp +293 -0
- data/vendor/tomotopy/src/TopicModel/DMR.h +51 -0
- data/vendor/tomotopy/src/TopicModel/DMRModel.cpp +13 -0
- data/vendor/tomotopy/src/TopicModel/DMRModel.hpp +374 -0
- data/vendor/tomotopy/src/TopicModel/DT.h +65 -0
- data/vendor/tomotopy/src/TopicModel/DTM.h +22 -0
- data/vendor/tomotopy/src/TopicModel/DTModel.cpp +15 -0
- data/vendor/tomotopy/src/TopicModel/DTModel.hpp +572 -0
- data/vendor/tomotopy/src/TopicModel/GDMR.h +37 -0
- data/vendor/tomotopy/src/TopicModel/GDMRModel.cpp +14 -0
- data/vendor/tomotopy/src/TopicModel/GDMRModel.hpp +485 -0
- data/vendor/tomotopy/src/TopicModel/HDP.h +74 -0
- data/vendor/tomotopy/src/TopicModel/HDPModel.cpp +13 -0
- data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +592 -0
- data/vendor/tomotopy/src/TopicModel/HLDA.h +40 -0
- data/vendor/tomotopy/src/TopicModel/HLDAModel.cpp +13 -0
- data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +681 -0
- data/vendor/tomotopy/src/TopicModel/HPA.h +27 -0
- data/vendor/tomotopy/src/TopicModel/HPAModel.cpp +21 -0
- data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +588 -0
- data/vendor/tomotopy/src/TopicModel/LDA.h +144 -0
- data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +442 -0
- data/vendor/tomotopy/src/TopicModel/LDAModel.cpp +13 -0
- data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +1058 -0
- data/vendor/tomotopy/src/TopicModel/LLDA.h +45 -0
- data/vendor/tomotopy/src/TopicModel/LLDAModel.cpp +13 -0
- data/vendor/tomotopy/src/TopicModel/LLDAModel.hpp +203 -0
- data/vendor/tomotopy/src/TopicModel/MGLDA.h +63 -0
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.cpp +17 -0
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +558 -0
- data/vendor/tomotopy/src/TopicModel/PA.h +43 -0
- data/vendor/tomotopy/src/TopicModel/PAModel.cpp +13 -0
- data/vendor/tomotopy/src/TopicModel/PAModel.hpp +467 -0
- data/vendor/tomotopy/src/TopicModel/PLDA.h +17 -0
- data/vendor/tomotopy/src/TopicModel/PLDAModel.cpp +13 -0
- data/vendor/tomotopy/src/TopicModel/PLDAModel.hpp +214 -0
- data/vendor/tomotopy/src/TopicModel/SLDA.h +54 -0
- data/vendor/tomotopy/src/TopicModel/SLDAModel.cpp +17 -0
- data/vendor/tomotopy/src/TopicModel/SLDAModel.hpp +456 -0
- data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +692 -0
- data/vendor/tomotopy/src/Utils/AliasMethod.hpp +169 -0
- data/vendor/tomotopy/src/Utils/Dictionary.h +80 -0
- data/vendor/tomotopy/src/Utils/EigenAddonOps.hpp +181 -0
- data/vendor/tomotopy/src/Utils/LBFGS.h +202 -0
- data/vendor/tomotopy/src/Utils/LBFGS/LineSearchBacktracking.h +120 -0
- data/vendor/tomotopy/src/Utils/LBFGS/LineSearchBracketing.h +122 -0
- data/vendor/tomotopy/src/Utils/LBFGS/Param.h +213 -0
- data/vendor/tomotopy/src/Utils/LUT.hpp +82 -0
- data/vendor/tomotopy/src/Utils/MultiNormalDistribution.hpp +69 -0
- data/vendor/tomotopy/src/Utils/PolyaGamma.hpp +200 -0
- data/vendor/tomotopy/src/Utils/PolyaGammaHybrid.hpp +672 -0
- data/vendor/tomotopy/src/Utils/ThreadPool.hpp +150 -0
- data/vendor/tomotopy/src/Utils/Trie.hpp +220 -0
- data/vendor/tomotopy/src/Utils/TruncMultiNormal.hpp +94 -0
- data/vendor/tomotopy/src/Utils/Utils.hpp +337 -0
- data/vendor/tomotopy/src/Utils/avx_gamma.h +46 -0
- data/vendor/tomotopy/src/Utils/avx_mathfun.h +736 -0
- data/vendor/tomotopy/src/Utils/exception.h +28 -0
- data/vendor/tomotopy/src/Utils/math.h +281 -0
- data/vendor/tomotopy/src/Utils/rtnorm.hpp +2690 -0
- data/vendor/tomotopy/src/Utils/sample.hpp +192 -0
- data/vendor/tomotopy/src/Utils/serializer.hpp +695 -0
- data/vendor/tomotopy/src/Utils/slp.hpp +131 -0
- data/vendor/tomotopy/src/Utils/sse_gamma.h +48 -0
- data/vendor/tomotopy/src/Utils/sse_mathfun.h +710 -0
- data/vendor/tomotopy/src/Utils/text.hpp +49 -0
- data/vendor/tomotopy/src/Utils/tvector.hpp +543 -0
- metadata +531 -0
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
#include "LLDA.h"
|
|
3
|
+
|
|
4
|
+
namespace tomoto
|
|
5
|
+
{
|
|
6
|
+
|
|
7
|
+
class IPLDAModel : public ILLDAModel
|
|
8
|
+
{
|
|
9
|
+
public:
|
|
10
|
+
using DefaultDocType = DocumentLLDA<TermWeight::one>;
|
|
11
|
+
static IPLDAModel* create(TermWeight _weight, size_t _numLatentTopics = 0, size_t _numTopicsPerLabel = 1,
|
|
12
|
+
Float alpha = 0.1, Float eta = 0.01, size_t seed = std::random_device{}(),
|
|
13
|
+
bool scalarRng = false);
|
|
14
|
+
|
|
15
|
+
virtual size_t getNumLatentTopics() const = 0;
|
|
16
|
+
};
|
|
17
|
+
}
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
#include "PLDAModel.hpp"
|
|
2
|
+
|
|
3
|
+
namespace tomoto
|
|
4
|
+
{
|
|
5
|
+
/*template class PLDAModel<TermWeight::one>;
|
|
6
|
+
template class PLDAModel<TermWeight::idf>;
|
|
7
|
+
template class PLDAModel<TermWeight::pmi>;*/
|
|
8
|
+
|
|
9
|
+
// Factory for IPLDAModel: selects a concrete PLDAModel instantiation through the
// TMT_SWITCH_TW macro, keyed on the term-weighting scheme (`_weight`) and on
// `scalarRng`. The remaining arguments are handed to the macro in the same order
// as PLDAModel's constructor parameters.
// NOTE(review): TMT_SWITCH_TW is defined elsewhere; it presumably returns a
// heap-allocated PLDAModel<...> (and handles unknown weights) — confirm there.
IPLDAModel* IPLDAModel::create(TermWeight _weight, size_t _numLatentTopics, size_t _numTopicsPerLabel, Float _alpha, Float _eta, size_t seed, bool scalarRng)
{
	TMT_SWITCH_TW(_weight, scalarRng, PLDAModel, _numLatentTopics, _numTopicsPerLabel, _alpha, _eta, seed);
}
|
|
13
|
+
}
|
|
@@ -0,0 +1,214 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
#include <utility>

#include "LDAModel.hpp"
#include "PLDA.h"
|
|
4
|
+
|
|
5
|
+
/*
|
|
6
|
+
Implementation of Partially Labeled LDA (PLDA) using Gibbs sampling by bab2min
|
|
7
|
+
|
|
8
|
+
* Ramage, D., Manning, C. D., & Dumais, S. (2011, August). Partially labeled topic models for interpretable text mining. In Proceedings of the 17th ACM SIGKDD international conference on Knowledge discovery and data mining (pp. 457-465). ACM.
|
|
9
|
+
*/
|
|
10
|
+
|
|
11
|
+
namespace tomoto
|
|
12
|
+
{
|
|
13
|
+
template<TermWeight _tw, typename _RandGen,
|
|
14
|
+
typename _Interface = IPLDAModel,
|
|
15
|
+
typename _Derived = void,
|
|
16
|
+
typename _DocType = DocumentLLDA<_tw>,
|
|
17
|
+
typename _ModelState = ModelStateLDA<_tw>>
|
|
18
|
+
class PLDAModel : public LDAModel<_tw, _RandGen, flags::generator_by_doc | flags::partitioned_multisampling, _Interface,
|
|
19
|
+
typename std::conditional<std::is_same<_Derived, void>::value, PLDAModel<_tw, _RandGen>, _Derived>::type,
|
|
20
|
+
_DocType, _ModelState>
|
|
21
|
+
{
|
|
22
|
+
protected:
|
|
23
|
+
using DerivedClass = typename std::conditional<std::is_same<_Derived, void>::value, PLDAModel<_tw, _RandGen>, _Derived>::type;
|
|
24
|
+
using BaseClass = LDAModel<_tw, _RandGen, flags::generator_by_doc | flags::partitioned_multisampling, _Interface, DerivedClass, _DocType, _ModelState>;
|
|
25
|
+
friend BaseClass;
|
|
26
|
+
friend typename BaseClass::BaseClass;
|
|
27
|
+
using WeightType = typename BaseClass::WeightType;
|
|
28
|
+
|
|
29
|
+
static constexpr char TMID[] = "PLDA";
|
|
30
|
+
|
|
31
|
+
Dictionary topicLabelDict;
|
|
32
|
+
|
|
33
|
+
uint64_t numLatentTopics, numTopicsPerLabel;
|
|
34
|
+
|
|
35
|
+
template<bool _asymEta>
|
|
36
|
+
Float* getZLikelihoods(_ModelState& ld, const _DocType& doc, size_t docId, size_t vid) const
|
|
37
|
+
{
|
|
38
|
+
const size_t V = this->realV;
|
|
39
|
+
assert(vid < V);
|
|
40
|
+
auto etaHelper = this->template getEtaHelper<_asymEta>();
|
|
41
|
+
auto& zLikelihood = ld.zLikelihood;
|
|
42
|
+
zLikelihood = (doc.numByTopic.array().template cast<Float>() + this->alphas.array())
|
|
43
|
+
* (ld.numByTopicWord.col(vid).array().template cast<Float>() + etaHelper.getEta(vid))
|
|
44
|
+
/ (ld.numByTopic.array().template cast<Float>() + etaHelper.getEtaSum());
|
|
45
|
+
zLikelihood.array() *= doc.labelMask.array().template cast<Float>();
|
|
46
|
+
sample::prefixSum(zLikelihood.data(), this->K);
|
|
47
|
+
return &zLikelihood[0];
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
/**
 * Runs the base-class preparation, then normalizes `doc.labelMask` to length K.
 *
 * - An empty mask becomes K entries, all zero except the trailing
 *   numLatentTopics entries (latent topics), which are set to one.
 * - A non-empty mask shorter than K keeps its existing entries; the new entries
 *   are zeroed, and then the trailing numLatentTopics entries are set to one.
 * - A non-empty mask of length >= K is left untouched.
 */
void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
{
	BaseClass::prepareDoc(doc, docId, wordSize);

	const size_t prevSize = doc.labelMask.size();
	const bool needsPadding = prevSize == 0 || prevSize < this->K;
	if (!needsPadding) return;

	// Growing an empty mask this way is equivalent to a fresh zero-filled mask.
	doc.labelMask.conservativeResize(this->K);
	doc.labelMask.tail(this->K - prevSize).setZero();
	// Latent topics live at the tail of the topic range and are enabled here.
	doc.labelMask.tail(numLatentTopics).setOnes();
}
|
|
67
|
+
|
|
68
|
+
void initGlobalState(bool initDocs)
|
|
69
|
+
{
|
|
70
|
+
this->K = topicLabelDict.size() * numTopicsPerLabel + numLatentTopics;
|
|
71
|
+
this->alphas.resize(this->K);
|
|
72
|
+
this->alphas.array() = this->alpha;
|
|
73
|
+
BaseClass::initGlobalState(initDocs);
|
|
74
|
+
}
|
|
75
|
+
|
|
76
|
+
struct Generator
|
|
77
|
+
{
|
|
78
|
+
std::discrete_distribution<> theta;
|
|
79
|
+
};
|
|
80
|
+
|
|
81
|
+
Generator makeGeneratorForInit(const _DocType* doc) const
|
|
82
|
+
{
|
|
83
|
+
return Generator{
|
|
84
|
+
std::discrete_distribution<>{ doc->labelMask.data(), doc->labelMask.data() + doc->labelMask.size() }
|
|
85
|
+
};
|
|
86
|
+
}
|
|
87
|
+
|
|
88
|
+
template<bool _Infer>
|
|
89
|
+
void updateStateWithDoc(Generator& g, _ModelState& ld, _RandGen& rgs, _DocType& doc, size_t i) const
|
|
90
|
+
{
|
|
91
|
+
auto& z = doc.Zs[i];
|
|
92
|
+
auto w = doc.words[i];
|
|
93
|
+
if (this->etaByTopicWord.size())
|
|
94
|
+
{
|
|
95
|
+
Eigen::Array<Float, -1, 1> col = this->etaByTopicWord.col(w);
|
|
96
|
+
for (size_t k = 0; k < col.size(); ++k) col[k] *= g.theta.probabilities()[k];
|
|
97
|
+
z = sample::sampleFromDiscrete(col.data(), col.data() + col.size(), rgs);
|
|
98
|
+
}
|
|
99
|
+
else
|
|
100
|
+
{
|
|
101
|
+
z = g.theta(rgs);
|
|
102
|
+
}
|
|
103
|
+
this->template addWordTo<1>(ld, doc, i, w, z);
|
|
104
|
+
}
|
|
105
|
+
|
|
106
|
+
public:
|
|
107
|
+
DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, topicLabelDict, numLatentTopics, numTopicsPerLabel);
|
|
108
|
+
DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, topicLabelDict, numLatentTopics, numTopicsPerLabel);
|
|
109
|
+
|
|
110
|
+
PLDAModel(size_t _numLatentTopics = 0, size_t _numTopicsPerLabel = 1,
|
|
111
|
+
Float _alpha = 1.0, Float _eta = 0.01, size_t _rg = std::random_device{}())
|
|
112
|
+
: BaseClass(1, _alpha, _eta, _rg),
|
|
113
|
+
numLatentTopics(_numLatentTopics), numTopicsPerLabel(_numTopicsPerLabel)
|
|
114
|
+
{
|
|
115
|
+
if (_numLatentTopics >= 0x80000000)
|
|
116
|
+
THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong numLatentTopics value (numLatentTopics = %zd)", _numLatentTopics));
|
|
117
|
+
if (_numTopicsPerLabel == 0 || _numTopicsPerLabel >= 0x80000000)
|
|
118
|
+
THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong numTopicsPerLabel value (numTopicsPerLabel = %zd)", _numTopicsPerLabel));
|
|
119
|
+
}
|
|
120
|
+
|
|
121
|
+
template<bool _const = false>
|
|
122
|
+
_DocType& _updateDoc(_DocType& doc, const std::vector<std::string>& labels)
|
|
123
|
+
{
|
|
124
|
+
if (_const)
|
|
125
|
+
{
|
|
126
|
+
doc.labelMask.resize(this->K);
|
|
127
|
+
doc.labelMask.setZero();
|
|
128
|
+
doc.labelMask.tail(numLatentTopics).setOnes();
|
|
129
|
+
|
|
130
|
+
std::vector<Vid> topicLabelIds;
|
|
131
|
+
for (auto& label : labels)
|
|
132
|
+
{
|
|
133
|
+
auto tid = topicLabelDict.toWid(label);
|
|
134
|
+
if (tid == (Vid)-1) continue;
|
|
135
|
+
topicLabelIds.emplace_back(tid);
|
|
136
|
+
}
|
|
137
|
+
|
|
138
|
+
for (auto tid : topicLabelIds) doc.labelMask.segment(tid * numTopicsPerLabel, numTopicsPerLabel).setOnes();
|
|
139
|
+
if (labels.empty()) doc.labelMask.setOnes();
|
|
140
|
+
}
|
|
141
|
+
else
|
|
142
|
+
{
|
|
143
|
+
if (!labels.empty())
|
|
144
|
+
{
|
|
145
|
+
std::vector<Vid> topicLabelIds;
|
|
146
|
+
for (auto& label : labels) topicLabelIds.emplace_back(topicLabelDict.add(label));
|
|
147
|
+
auto maxVal = *std::max_element(topicLabelIds.begin(), topicLabelIds.end());
|
|
148
|
+
doc.labelMask.resize((maxVal + 1) * numTopicsPerLabel);
|
|
149
|
+
doc.labelMask.setZero();
|
|
150
|
+
for (auto i : topicLabelIds) doc.labelMask.segment(i * numTopicsPerLabel, numTopicsPerLabel).setOnes();
|
|
151
|
+
}
|
|
152
|
+
}
|
|
153
|
+
return doc;
|
|
154
|
+
}
|
|
155
|
+
|
|
156
|
+
size_t addDoc(const std::vector<std::string>& words, const std::vector<std::string>& labels) override
|
|
157
|
+
{
|
|
158
|
+
auto doc = this->_makeDoc(words);
|
|
159
|
+
return this->_addDoc(_updateDoc(doc, labels));
|
|
160
|
+
}
|
|
161
|
+
|
|
162
|
+
std::unique_ptr<DocumentBase> makeDoc(const std::vector<std::string>& words, const std::vector<std::string>& labels) const override
|
|
163
|
+
{
|
|
164
|
+
auto doc = as_mutable(this)->template _makeDoc<true>(words);
|
|
165
|
+
return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, labels));
|
|
166
|
+
}
|
|
167
|
+
|
|
168
|
+
size_t addDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
    const std::vector<std::string>& labels) override
{
    // Tokenize the raw text, attach the label mask, then hand the document to _addDoc.
    auto newDoc = this->template _makeRawDoc<false>(rawStr, tokenizer);
    auto& labeledDoc = _updateDoc(newDoc, labels);
    return this->_addDoc(labeledDoc);
}
|
|
174
|
+
|
|
175
|
+
std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
    const std::vector<std::string>& labels) const override
{
    // Standalone document built from raw text via the <true> (const) variants.
    auto self = as_mutable(this);
    auto newDoc = self->template _makeRawDoc<true>(rawStr, tokenizer);
    return make_unique<_DocType>(self->template _updateDoc<true>(newDoc, labels));
}
|
|
181
|
+
|
|
182
|
+
size_t addDoc(const std::string& rawStr, const std::vector<Vid>& words,
|
|
183
|
+
const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
|
|
184
|
+
const std::vector<std::string>& labels) override
|
|
185
|
+
{
|
|
186
|
+
auto doc = this->_makeRawDoc(rawStr, words, pos, len);
|
|
187
|
+
return this->_addDoc(_updateDoc(doc, labels));
|
|
188
|
+
}
|
|
189
|
+
|
|
190
|
+
std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const std::vector<Vid>& words,
|
|
191
|
+
const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
|
|
192
|
+
const std::vector<std::string>& labels) const override
|
|
193
|
+
{
|
|
194
|
+
auto doc = this->_makeRawDoc(rawStr, words, pos, len);
|
|
195
|
+
return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, labels));
|
|
196
|
+
}
|
|
197
|
+
|
|
198
|
+
std::vector<Float> getTopicsByDoc(const _DocType& doc) const
|
|
199
|
+
{
|
|
200
|
+
std::vector<Float> ret(this->K);
|
|
201
|
+
auto maskedAlphas = this->alphas.array() * doc.labelMask.template cast<Float>().array();
|
|
202
|
+
Eigen::Map<Eigen::Matrix<Float, -1, 1>> { ret.data(), this->K }.array() =
|
|
203
|
+
(doc.numByTopic.array().template cast<Float>() + maskedAlphas)
|
|
204
|
+
/ (doc.getSumWordWeight() + maskedAlphas.sum());
|
|
205
|
+
return ret;
|
|
206
|
+
}
|
|
207
|
+
|
|
208
|
+
const Dictionary& getTopicLabelDict() const override
{
    // Dictionary of label strings (labels are added during training, looked up at inference).
    return this->topicLabelDict;
}
|
|
209
|
+
|
|
210
|
+
size_t getNumLatentTopics() const override { return numLatentTopics; }
|
|
211
|
+
|
|
212
|
+
size_t getNumTopicsPerLabel() const override { return numTopicsPerLabel; }
|
|
213
|
+
};
|
|
214
|
+
}
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
#include "LDA.h"
|
|
3
|
+
|
|
4
|
+
namespace tomoto
|
|
5
|
+
{
|
|
6
|
+
template<TermWeight _tw>
|
|
7
|
+
struct DocumentSLDA : public DocumentLDA<_tw>
|
|
8
|
+
{
|
|
9
|
+
using BaseDocument = DocumentLDA<_tw>;
|
|
10
|
+
using DocumentLDA<_tw>::DocumentLDA;
|
|
11
|
+
std::vector<Float> y;
|
|
12
|
+
DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 0, y);
|
|
13
|
+
DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 1, 0x00010001, y);
|
|
14
|
+
};
|
|
15
|
+
|
|
16
|
+
class ISLDAModel : public ILDAModel
|
|
17
|
+
{
|
|
18
|
+
public:
|
|
19
|
+
enum class GLM
|
|
20
|
+
{
|
|
21
|
+
linear = 0,
|
|
22
|
+
binary_logistic = 1,
|
|
23
|
+
};
|
|
24
|
+
|
|
25
|
+
using DefaultDocType = DocumentSLDA<TermWeight::one>;
|
|
26
|
+
static ISLDAModel* create(TermWeight _weight, size_t _K = 1,
|
|
27
|
+
const std::vector<ISLDAModel::GLM>& vars = {},
|
|
28
|
+
Float alpha = 0.1, Float _eta = 0.01,
|
|
29
|
+
const std::vector<Float>& _mu = {}, const std::vector<Float>& _nuSq = {},
|
|
30
|
+
const std::vector<Float>& _glmParam = {},
|
|
31
|
+
size_t seed = std::random_device{}(),
|
|
32
|
+
bool scalarRng = false);
|
|
33
|
+
|
|
34
|
+
virtual size_t addDoc(const std::vector<std::string>& words, const std::vector<Float>& y) = 0;
|
|
35
|
+
virtual std::unique_ptr<DocumentBase> makeDoc(const std::vector<std::string>& words, const std::vector<Float>& y) const = 0;
|
|
36
|
+
|
|
37
|
+
virtual size_t addDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
|
|
38
|
+
const std::vector<Float>& y) = 0;
|
|
39
|
+
virtual std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
|
|
40
|
+
const std::vector<Float>& y) const = 0;
|
|
41
|
+
|
|
42
|
+
virtual size_t addDoc(const std::string& rawStr, const std::vector<Vid>& words,
|
|
43
|
+
const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
|
|
44
|
+
const std::vector<Float>& y) = 0;
|
|
45
|
+
virtual std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const std::vector<Vid>& words,
|
|
46
|
+
const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
|
|
47
|
+
const std::vector<Float>& y) const = 0;
|
|
48
|
+
|
|
49
|
+
virtual size_t getF() const = 0;
|
|
50
|
+
virtual std::vector<Float> getRegressionCoef(size_t f) const = 0;
|
|
51
|
+
virtual GLM getTypeOfVar(size_t f) const = 0;
|
|
52
|
+
virtual std::vector<Float> estimateVars(const DocumentBase* doc) const = 0;
|
|
53
|
+
};
|
|
54
|
+
}
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
#include "SLDAModel.hpp"
|
|
2
|
+
|
|
3
|
+
namespace tomoto
|
|
4
|
+
{
|
|
5
|
+
/*template class SLDAModel<TermWeight::one>;
|
|
6
|
+
template class SLDAModel<TermWeight::idf>;
|
|
7
|
+
template class SLDAModel<TermWeight::pmi>;*/
|
|
8
|
+
|
|
9
|
+
ISLDAModel* ISLDAModel::create(TermWeight _weight, size_t _K, const std::vector<ISLDAModel::GLM>& vars,
    Float _alpha, Float _eta,
    const std::vector<Float>& _mu, const std::vector<Float>& _nuSq,
    const std::vector<Float>& _glmParam,
    size_t seed, bool scalarRng)
{
    // TMT_SWITCH_TW presumably selects the SLDAModel instantiation matching
    // `_weight` and `scalarRng` and constructs it with the remaining arguments
    // (see the macro definition).
    TMT_SWITCH_TW(_weight, scalarRng, SLDAModel, _K, vars, _alpha, _eta, _mu, _nuSq, _glmParam, seed);
}
|
|
17
|
+
}
|
|
@@ -0,0 +1,456 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
#include "LDAModel.hpp"
|
|
3
|
+
#include "../Utils/PolyaGamma.hpp"
|
|
4
|
+
#include "SLDA.h"
|
|
5
|
+
|
|
6
|
+
/*
|
|
7
|
+
Implementation of sLDA using Gibbs sampling by bab2min
|
|
8
|
+
* McAuliffe, J. D., & Blei, D. M. (2008). Supervised topic models. In Advances in Neural Information Processing Systems (pp. 121-128).
|
|
9
|
+
* Python version implementation using Gibbs sampling : https://github.com/Savvysherpa/slda
|
|
10
|
+
*/
|
|
11
|
+
|
|
12
|
+
namespace tomoto
|
|
13
|
+
{
|
|
14
|
+
namespace detail
|
|
15
|
+
{
|
|
16
|
+
template<typename _WeightType>
|
|
17
|
+
struct GLMFunctor
|
|
18
|
+
{
|
|
19
|
+
Eigen::Matrix<Float, -1, 1> regressionCoef; // Dim : (K)
|
|
20
|
+
|
|
21
|
+
GLMFunctor(size_t K = 0, Float mu = 0) : regressionCoef(Eigen::Matrix<Float, -1, 1>::Constant(K, mu))
|
|
22
|
+
{
|
|
23
|
+
}
|
|
24
|
+
|
|
25
|
+
virtual ISLDAModel::GLM getType() const = 0;
|
|
26
|
+
|
|
27
|
+
virtual void updateZLL(
|
|
28
|
+
Eigen::Matrix<Float, -1, 1>& zLikelihood,
|
|
29
|
+
Float y, const Eigen::Matrix<_WeightType, -1, 1>& numByTopic, size_t docId, Float docSize) const = 0;
|
|
30
|
+
|
|
31
|
+
virtual void optimizeCoef(
|
|
32
|
+
const Eigen::Matrix<Float, -1, -1>& normZ,
|
|
33
|
+
Float mu, Float nuSq,
|
|
34
|
+
Eigen::Block<Eigen::Matrix<Float, -1, -1>, -1, 1, true> ys
|
|
35
|
+
) = 0;
|
|
36
|
+
|
|
37
|
+
virtual double getLL(Float y, const Eigen::Matrix<_WeightType, -1, 1>& numByTopic,
|
|
38
|
+
Float docSize) const = 0;
|
|
39
|
+
|
|
40
|
+
virtual Float estimate(const Eigen::Matrix<_WeightType, -1, 1>& numByTopic,
|
|
41
|
+
Float docSize) const = 0;
|
|
42
|
+
|
|
43
|
+
virtual ~GLMFunctor() {};
|
|
44
|
+
|
|
45
|
+
DEFINE_SERIALIZER_VIRTUAL(regressionCoef);
|
|
46
|
+
|
|
47
|
+
static void serializerWrite(const std::unique_ptr<GLMFunctor>& p, std::ostream& ostr)
|
|
48
|
+
{
|
|
49
|
+
if (!p) serializer::writeToStream<uint32_t>(ostr, 0);
|
|
50
|
+
else
|
|
51
|
+
{
|
|
52
|
+
serializer::writeToStream<uint32_t>(ostr, (uint32_t)p->getType() + 1);
|
|
53
|
+
p->serializerWrite(ostr);
|
|
54
|
+
}
|
|
55
|
+
}
|
|
56
|
+
|
|
57
|
+
static void serializerRead(std::unique_ptr<GLMFunctor>& p, std::istream& istr);
|
|
58
|
+
};
|
|
59
|
+
|
|
60
|
+
template<typename _WeightType>
|
|
61
|
+
struct LinearFunctor : public GLMFunctor<_WeightType>
|
|
62
|
+
{
|
|
63
|
+
Float sigmaSq = 1;
|
|
64
|
+
|
|
65
|
+
LinearFunctor(size_t K = 0, Float mu = 0, Float _sigmaSq = 1)
|
|
66
|
+
: GLMFunctor<_WeightType>(K, mu), sigmaSq(_sigmaSq)
|
|
67
|
+
{
|
|
68
|
+
}
|
|
69
|
+
|
|
70
|
+
ISLDAModel::GLM getType() const override { return ISLDAModel::GLM::linear; }
|
|
71
|
+
|
|
72
|
+
void updateZLL(
|
|
73
|
+
Eigen::Matrix<Float, -1, 1>& zLikelihood,
|
|
74
|
+
Float y, const Eigen::Matrix<_WeightType, -1, 1>& numByTopic, size_t docId, Float docSize) const override
|
|
75
|
+
{
|
|
76
|
+
Float yErr = y -
|
|
77
|
+
(this->regressionCoef.array() * numByTopic.array().template cast<Float>()).sum()
|
|
78
|
+
/ docSize;
|
|
79
|
+
zLikelihood.array() *= (this->regressionCoef.array() / docSize / 2 / sigmaSq *
|
|
80
|
+
(2 * yErr - this->regressionCoef.array() / docSize)).exp();
|
|
81
|
+
}
|
|
82
|
+
|
|
83
|
+
void optimizeCoef(
|
|
84
|
+
const Eigen::Matrix<Float, -1, -1>& normZ,
|
|
85
|
+
Float mu, Float nuSq,
|
|
86
|
+
Eigen::Block<Eigen::Matrix<Float, -1, -1>, -1, 1, true> ys
|
|
87
|
+
) override
|
|
88
|
+
{
|
|
89
|
+
Eigen::Matrix<Float, -1, -1> selectedNormZ = normZ.array().rowwise() * (!ys.array().transpose().isNaN()).template cast<Float>();
|
|
90
|
+
Eigen::Matrix<Float, -1, -1> normZZT = selectedNormZ * selectedNormZ.transpose();
|
|
91
|
+
normZZT += Eigen::Matrix<Float, -1, -1>::Identity(normZZT.cols(), normZZT.cols()) / nuSq;
|
|
92
|
+
this->regressionCoef = normZZT.colPivHouseholderQr().solve(selectedNormZ * ys.array().isNaN().select(0, ys).matrix());
|
|
93
|
+
}
|
|
94
|
+
|
|
95
|
+
double getLL(Float y, const Eigen::Matrix<_WeightType, -1, 1>& numByTopic,
|
|
96
|
+
Float docSize) const override
|
|
97
|
+
{
|
|
98
|
+
Float estimatedY = estimate(numByTopic, docSize);
|
|
99
|
+
return -pow(estimatedY - y, 2) / 2 / sigmaSq;
|
|
100
|
+
}
|
|
101
|
+
|
|
102
|
+
Float estimate(const Eigen::Matrix<_WeightType, -1, 1>& numByTopic,
|
|
103
|
+
Float docSize) const override
|
|
104
|
+
{
|
|
105
|
+
return (this->regressionCoef.array() * numByTopic.array().template cast<Float>()).sum()
|
|
106
|
+
/ std::max(docSize, 0.01f);
|
|
107
|
+
}
|
|
108
|
+
|
|
109
|
+
DEFINE_SERIALIZER_AFTER_BASE(GLMFunctor<_WeightType>, sigmaSq);
|
|
110
|
+
};
|
|
111
|
+
|
|
112
|
+
template<typename _WeightType>
|
|
113
|
+
struct BinaryLogisticFunctor : public GLMFunctor<_WeightType>
|
|
114
|
+
{
|
|
115
|
+
Float b = 1;
|
|
116
|
+
Eigen::Matrix<Float, -1, 1> omega;
|
|
117
|
+
|
|
118
|
+
BinaryLogisticFunctor(size_t K = 0, Float mu = 0, Float _b = 1, size_t numDocs = 0)
|
|
119
|
+
: GLMFunctor<_WeightType>(K, mu), b(_b), omega{ Eigen::Matrix<Float, -1, 1>::Ones(numDocs) }
|
|
120
|
+
{
|
|
121
|
+
}
|
|
122
|
+
|
|
123
|
+
ISLDAModel::GLM getType() const override { return ISLDAModel::GLM::binary_logistic; }
|
|
124
|
+
|
|
125
|
+
void updateZLL(
|
|
126
|
+
Eigen::Matrix<Float, -1, 1>& zLikelihood,
|
|
127
|
+
Float y, const Eigen::Matrix<_WeightType, -1, 1>& numByTopic, size_t docId, Float docSize) const override
|
|
128
|
+
{
|
|
129
|
+
Float yErr = b * (y - 0.5f) -
|
|
130
|
+
(this->regressionCoef.array() * numByTopic.array().template cast<Float>()).sum()
|
|
131
|
+
/ docSize * omega[docId];
|
|
132
|
+
zLikelihood.array() *= (this->regressionCoef.array() / docSize *
|
|
133
|
+
(yErr - omega[docId] / 2 * this->regressionCoef.array() / docSize)).exp();
|
|
134
|
+
}
|
|
135
|
+
|
|
136
|
+
void optimizeCoef(
|
|
137
|
+
const Eigen::Matrix<Float, -1, -1>& normZ,
|
|
138
|
+
Float mu, Float nuSq,
|
|
139
|
+
Eigen::Block<Eigen::Matrix<Float, -1, -1>, -1, 1, true> ys
|
|
140
|
+
) override
|
|
141
|
+
{
|
|
142
|
+
Eigen::Matrix<Float, -1, -1> selectedNormZ = normZ.array().rowwise() * (!ys.array().transpose().isNaN()).template cast<Float>();
|
|
143
|
+
Eigen::Matrix<Float, -1, -1> normZZT = selectedNormZ * Eigen::DiagonalMatrix<Float, -1>{ omega } * selectedNormZ.transpose();
|
|
144
|
+
normZZT += Eigen::Matrix<Float, -1, -1>::Identity(normZZT.cols(), normZZT.cols()) / nuSq;
|
|
145
|
+
|
|
146
|
+
this->regressionCoef = normZZT
|
|
147
|
+
.colPivHouseholderQr().solve(selectedNormZ * ys.array().isNaN().select(0, b * (ys.array() - 0.5f)).matrix()
|
|
148
|
+
+ Eigen::Matrix<Float, -1, 1>::Constant(selectedNormZ.rows(), mu / nuSq));
|
|
149
|
+
|
|
150
|
+
RandGen rng;
|
|
151
|
+
for (size_t i = 0; i < omega.size(); ++i)
|
|
152
|
+
{
|
|
153
|
+
if (std::isnan(ys[i])) continue;
|
|
154
|
+
omega[i] = math::drawPolyaGamma(b, (this->regressionCoef.array() * normZ.col(i).array()).sum(), rng);
|
|
155
|
+
}
|
|
156
|
+
}
|
|
157
|
+
|
|
158
|
+
double getLL(Float y, const Eigen::Matrix<_WeightType, -1, 1>& numByTopic,
|
|
159
|
+
Float docSize) const override
|
|
160
|
+
{
|
|
161
|
+
Float z = (this->regressionCoef.array() * numByTopic.array().template cast<Float>()).sum()
|
|
162
|
+
/ std::max(docSize, 0.01f);
|
|
163
|
+
return b * (y * z - log(1 + exp(z)));
|
|
164
|
+
}
|
|
165
|
+
|
|
166
|
+
Float estimate(const Eigen::Matrix<_WeightType, -1, 1>& numByTopic,
|
|
167
|
+
Float docSize) const override
|
|
168
|
+
{
|
|
169
|
+
Float z = (this->regressionCoef.array() * numByTopic.array().template cast<Float>()).sum()
|
|
170
|
+
/ std::max(docSize, 0.01f);
|
|
171
|
+
return 1 / (1 + exp(-z));
|
|
172
|
+
}
|
|
173
|
+
|
|
174
|
+
DEFINE_SERIALIZER_AFTER_BASE(GLMFunctor<_WeightType>, b, omega);
|
|
175
|
+
};
|
|
176
|
+
}
|
|
177
|
+
|
|
178
|
+
template<TermWeight _tw, typename _RandGen,
|
|
179
|
+
size_t _Flags = flags::partitioned_multisampling,
|
|
180
|
+
typename _Interface = ISLDAModel,
|
|
181
|
+
typename _Derived = void,
|
|
182
|
+
typename _DocType = DocumentSLDA<_tw>,
|
|
183
|
+
typename _ModelState = ModelStateLDA<_tw>>
|
|
184
|
+
class SLDAModel : public LDAModel<_tw, _RandGen, _Flags, _Interface,
|
|
185
|
+
typename std::conditional<std::is_same<_Derived, void>::value, SLDAModel<_tw, _RandGen, _Flags>, _Derived>::type,
|
|
186
|
+
_DocType, _ModelState>
|
|
187
|
+
{
|
|
188
|
+
protected:
|
|
189
|
+
using DerivedClass = typename std::conditional<std::is_same<_Derived, void>::value, SLDAModel<_tw, _RandGen>, _Derived>::type;
|
|
190
|
+
using BaseClass = LDAModel<_tw, _RandGen, _Flags, _Interface, DerivedClass, _DocType, _ModelState>;
|
|
191
|
+
friend BaseClass;
|
|
192
|
+
friend typename BaseClass::BaseClass;
|
|
193
|
+
using WeightType = typename BaseClass::WeightType;
|
|
194
|
+
|
|
195
|
+
static constexpr char TMID[] = "SLDA";
|
|
196
|
+
|
|
197
|
+
uint64_t F; // number of response variables
|
|
198
|
+
std::vector<ISLDAModel::GLM> varTypes;
|
|
199
|
+
std::vector<Float> glmParam;
|
|
200
|
+
|
|
201
|
+
Eigen::Matrix<Float, -1, 1> mu; // Mean of regression coefficients, Dim : (F)
|
|
202
|
+
Eigen::Matrix<Float, -1, 1> nuSq; // Variance of regression coefficients, Dim : (F)
|
|
203
|
+
|
|
204
|
+
std::vector<std::unique_ptr<detail::GLMFunctor<WeightType>>> responseVars;
|
|
205
|
+
Eigen::Matrix<Float, -1, -1> normZ; // topic proportions for all docs, Dim : (K, D)
|
|
206
|
+
Eigen::Matrix<Float, -1, -1> Ys; // response variables, Dim : (D, F)
|
|
207
|
+
|
|
208
|
+
template<bool _asymEta>
|
|
209
|
+
Float* getZLikelihoods(_ModelState& ld, const _DocType& doc, size_t docId, size_t vid) const
|
|
210
|
+
{
|
|
211
|
+
const size_t V = this->realV;
|
|
212
|
+
assert(vid < V);
|
|
213
|
+
auto etaHelper = this->template getEtaHelper<_asymEta>();
|
|
214
|
+
auto& zLikelihood = ld.zLikelihood;
|
|
215
|
+
zLikelihood = (doc.numByTopic.array().template cast<Float>() + this->alphas.array())
|
|
216
|
+
* (ld.numByTopicWord.col(vid).array().template cast<Float>() + etaHelper.getEta(vid))
|
|
217
|
+
/ (ld.numByTopic.array().template cast<Float>() + etaHelper.getEtaSum());
|
|
218
|
+
|
|
219
|
+
for (size_t f = 0; f < F; ++f)
|
|
220
|
+
{
|
|
221
|
+
if (std::isnan(doc.y[f])) continue;
|
|
222
|
+
responseVars[f]->updateZLL(zLikelihood, doc.y[f], doc.numByTopic,
|
|
223
|
+
docId, doc.getSumWordWeight());
|
|
224
|
+
}
|
|
225
|
+
sample::prefixSum(zLikelihood.data(), this->K);
|
|
226
|
+
return &zLikelihood[0];
|
|
227
|
+
}
|
|
228
|
+
|
|
229
|
+
void optimizeRegressionCoef()
|
|
230
|
+
{
|
|
231
|
+
for (size_t i = 0; i < this->docs.size(); ++i)
|
|
232
|
+
{
|
|
233
|
+
normZ.col(i) = this->docs[i].numByTopic.array().template cast<Float>() /
|
|
234
|
+
std::max((Float)this->docs[i].getSumWordWeight(), 0.01f);
|
|
235
|
+
}
|
|
236
|
+
|
|
237
|
+
for (size_t f = 0; f < F; ++f)
|
|
238
|
+
{
|
|
239
|
+
responseVars[f]->optimizeCoef(normZ, mu[f], nuSq[f], Ys.col(f));
|
|
240
|
+
}
|
|
241
|
+
}
|
|
242
|
+
|
|
243
|
+
void optimizeParameters(ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
|
|
244
|
+
{
|
|
245
|
+
BaseClass::optimizeParameters(pool, localData, rgs);
|
|
246
|
+
}
|
|
247
|
+
|
|
248
|
+
void updateGlobalInfo(ThreadPool& pool, _ModelState* localData)
|
|
249
|
+
{
|
|
250
|
+
optimizeRegressionCoef();
|
|
251
|
+
}
|
|
252
|
+
|
|
253
|
+
template<typename _DocIter>
|
|
254
|
+
double getLLDocs(_DocIter _first, _DocIter _last) const
|
|
255
|
+
{
|
|
256
|
+
const auto K = this->K;
|
|
257
|
+
|
|
258
|
+
double ll = 0;
|
|
259
|
+
for (; _first != _last; ++_first)
|
|
260
|
+
{
|
|
261
|
+
auto& doc = *_first;
|
|
262
|
+
ll -= math::lgammaT(doc.getSumWordWeight() + this->alphas.sum()) - math::lgammaT(this->alphas.sum());
|
|
263
|
+
for (size_t f = 0; f < F; ++f)
|
|
264
|
+
{
|
|
265
|
+
if (std::isnan(doc.y[f])) continue;
|
|
266
|
+
ll += responseVars[f]->getLL(doc.y[f], doc.numByTopic, doc.getSumWordWeight());
|
|
267
|
+
}
|
|
268
|
+
for (Tid k = 0; k < K; ++k)
|
|
269
|
+
{
|
|
270
|
+
ll += math::lgammaT(doc.numByTopic[k] + this->alphas[k]) - math::lgammaT(this->alphas[k]);
|
|
271
|
+
}
|
|
272
|
+
}
|
|
273
|
+
return ll;
|
|
274
|
+
}
|
|
275
|
+
|
|
276
|
+
double getLLRest(const _ModelState& ld) const
|
|
277
|
+
{
|
|
278
|
+
double ll = BaseClass::getLLRest(ld);
|
|
279
|
+
for (size_t f = 0; f < F; ++f)
|
|
280
|
+
{
|
|
281
|
+
ll -= (responseVars[f]->regressionCoef.array() - mu[f]).pow(2).sum() / 2 / nuSq[f];
|
|
282
|
+
}
|
|
283
|
+
return ll;
|
|
284
|
+
}
|
|
285
|
+
|
|
286
|
+
void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
|
|
287
|
+
{
|
|
288
|
+
BaseClass::prepareDoc(doc, docId, wordSize);
|
|
289
|
+
}
|
|
290
|
+
|
|
291
|
+
void initGlobalState(bool initDocs)
|
|
292
|
+
{
|
|
293
|
+
BaseClass::initGlobalState(initDocs);
|
|
294
|
+
if (initDocs)
|
|
295
|
+
{
|
|
296
|
+
for (size_t f = 0; f < F; ++f)
|
|
297
|
+
{
|
|
298
|
+
std::unique_ptr<detail::GLMFunctor<WeightType>> v;
|
|
299
|
+
switch (varTypes[f])
|
|
300
|
+
{
|
|
301
|
+
case ISLDAModel::GLM::linear:
|
|
302
|
+
v = make_unique<detail::LinearFunctor<WeightType>>(this->K, mu[f],
|
|
303
|
+
f < glmParam.size() ? glmParam[f] : 1.f);
|
|
304
|
+
break;
|
|
305
|
+
case ISLDAModel::GLM::binary_logistic:
|
|
306
|
+
v = make_unique<detail::BinaryLogisticFunctor<WeightType>>(this->K, mu[f],
|
|
307
|
+
f < glmParam.size() ? glmParam[f] : 1.f, this->docs.size());
|
|
308
|
+
break;
|
|
309
|
+
}
|
|
310
|
+
responseVars.emplace_back(std::move(v));
|
|
311
|
+
}
|
|
312
|
+
}
|
|
313
|
+
Ys.resize(this->docs.size(), F);
|
|
314
|
+
normZ.resize(this->K, this->docs.size());
|
|
315
|
+
for (size_t i = 0; i < this->docs.size(); ++i)
|
|
316
|
+
{
|
|
317
|
+
Ys.row(i) = Eigen::Map<Eigen::Matrix<Float, 1, -1>>(this->docs[i].y.data(), F);
|
|
318
|
+
}
|
|
319
|
+
}
|
|
320
|
+
|
|
321
|
+
public:
|
|
322
|
+
DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, F, responseVars, mu, nuSq);
|
|
323
|
+
DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, F, responseVars, mu, nuSq);
|
|
324
|
+
|
|
325
|
+
SLDAModel(size_t _K = 1, const std::vector<ISLDAModel::GLM>& vars = {},
|
|
326
|
+
Float _alpha = 0.1, Float _eta = 0.01,
|
|
327
|
+
const std::vector<Float>& _mu = {}, const std::vector<Float>& _nuSq = {},
|
|
328
|
+
const std::vector<Float>& _glmParam = {},
|
|
329
|
+
size_t _rg = std::random_device{}())
|
|
330
|
+
: BaseClass(_K, _alpha, _eta, _rg), F(vars.size()), varTypes(vars),
|
|
331
|
+
glmParam(_glmParam)
|
|
332
|
+
{
|
|
333
|
+
for (auto t : varTypes)
|
|
334
|
+
{
|
|
335
|
+
if (t != ISLDAModel::GLM::linear && t != ISLDAModel::GLM::binary_logistic) THROW_ERROR_WITH_INFO(std::runtime_error, "unknown var GLM type in 'vars'");
|
|
336
|
+
}
|
|
337
|
+
mu = decltype(mu)::Zero(F);
|
|
338
|
+
std::copy(_mu.begin(), _mu.end(), mu.data());
|
|
339
|
+
nuSq = decltype(nuSq)::Ones(F);
|
|
340
|
+
std::copy(_nuSq.begin(), _nuSq.end(), nuSq.data());
|
|
341
|
+
}
|
|
342
|
+
|
|
343
|
+
std::vector<Float> getRegressionCoef(size_t f) const override
|
|
344
|
+
{
|
|
345
|
+
return { responseVars[f]->regressionCoef.data(), responseVars[f]->regressionCoef.data() + this->K };
|
|
346
|
+
}
|
|
347
|
+
|
|
348
|
+
GETTER(F, size_t, F);
|
|
349
|
+
|
|
350
|
+
ISLDAModel::GLM getTypeOfVar(size_t f) const override
|
|
351
|
+
{
|
|
352
|
+
return responseVars[f]->getType();
|
|
353
|
+
}
|
|
354
|
+
|
|
355
|
+
template<bool _const = false>
|
|
356
|
+
_DocType& _updateDoc(_DocType& doc, const std::vector<Float>& y)
|
|
357
|
+
{
|
|
358
|
+
if (_const)
|
|
359
|
+
{
|
|
360
|
+
if (y.size() > F) throw std::runtime_error{ text::format(
|
|
361
|
+
"size of 'y' is greater than the number of vars.\n"
|
|
362
|
+
"size of 'y' : %zd, number of vars: %zd", y.size(), F) };
|
|
363
|
+
doc.y = y;
|
|
364
|
+
while (doc.y.size() < F)
|
|
365
|
+
{
|
|
366
|
+
doc.y.emplace_back(NAN);
|
|
367
|
+
}
|
|
368
|
+
}
|
|
369
|
+
else
|
|
370
|
+
{
|
|
371
|
+
if (y.size() != F) throw std::runtime_error{ text::format(
|
|
372
|
+
"size of 'y' must be equal to the number of vars.\n"
|
|
373
|
+
"size of 'y' : %zd, number of vars: %zd", y.size(), F) };
|
|
374
|
+
doc.y = y;
|
|
375
|
+
}
|
|
376
|
+
return doc;
|
|
377
|
+
}
|
|
378
|
+
|
|
379
|
+
size_t addDoc(const std::vector<std::string>& words, const std::vector<Float>& y) override
|
|
380
|
+
{
|
|
381
|
+
auto doc = this->_makeDoc(words);
|
|
382
|
+
return this->_addDoc(_updateDoc(doc, y));
|
|
383
|
+
}
|
|
384
|
+
|
|
385
|
+
std::unique_ptr<DocumentBase> makeDoc(const std::vector<std::string>& words, const std::vector<Float>& y) const override
|
|
386
|
+
{
|
|
387
|
+
auto doc = as_mutable(this)->template _makeDoc<true>(words);
|
|
388
|
+
return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, y));
|
|
389
|
+
}
|
|
390
|
+
|
|
391
|
+
size_t addDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
|
|
392
|
+
const std::vector<Float>& y) override
|
|
393
|
+
{
|
|
394
|
+
auto doc = this->template _makeRawDoc<false>(rawStr, tokenizer);
|
|
395
|
+
return this->_addDoc(_updateDoc(doc, y));
|
|
396
|
+
}
|
|
397
|
+
|
|
398
|
+
std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
|
|
399
|
+
const std::vector<Float>& y) const override
|
|
400
|
+
{
|
|
401
|
+
auto doc = as_mutable(this)->template _makeRawDoc<true>(rawStr, tokenizer);
|
|
402
|
+
return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, y));
|
|
403
|
+
}
|
|
404
|
+
|
|
405
|
+
size_t addDoc(const std::string& rawStr, const std::vector<Vid>& words,
|
|
406
|
+
const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
|
|
407
|
+
const std::vector<Float>& y) override
|
|
408
|
+
{
|
|
409
|
+
auto doc = this->_makeRawDoc(rawStr, words, pos, len);
|
|
410
|
+
return this->_addDoc(_updateDoc(doc, y));
|
|
411
|
+
}
|
|
412
|
+
|
|
413
|
+
std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const std::vector<Vid>& words,
|
|
414
|
+
const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
|
|
415
|
+
const std::vector<Float>& y) const override
|
|
416
|
+
{
|
|
417
|
+
auto doc = this->_makeRawDoc(rawStr, words, pos, len);
|
|
418
|
+
return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, y));
|
|
419
|
+
}
|
|
420
|
+
|
|
421
|
+
std::vector<Float> estimateVars(const DocumentBase* doc) const override
|
|
422
|
+
{
|
|
423
|
+
std::vector<Float> ret;
|
|
424
|
+
auto pdoc = dynamic_cast<const _DocType*>(doc);
|
|
425
|
+
if (!pdoc) return ret;
|
|
426
|
+
for (auto& f : responseVars)
|
|
427
|
+
{
|
|
428
|
+
ret.emplace_back(f->estimate(pdoc->numByTopic, pdoc->getSumWordWeight()));
|
|
429
|
+
}
|
|
430
|
+
return ret;
|
|
431
|
+
}
|
|
432
|
+
};
|
|
433
|
+
|
|
434
|
+
template<typename _WeightType>
|
|
435
|
+
void detail::GLMFunctor<_WeightType>::serializerRead(
|
|
436
|
+
std::unique_ptr<detail::GLMFunctor<_WeightType>>& p, std::istream& istr)
|
|
437
|
+
{
|
|
438
|
+
uint32_t t = serializer::readFromStream<uint32_t>(istr);
|
|
439
|
+
if (!t) p.reset();
|
|
440
|
+
else
|
|
441
|
+
{
|
|
442
|
+
switch ((ISLDAModel::GLM)(t - 1))
|
|
443
|
+
{
|
|
444
|
+
case ISLDAModel::GLM::linear:
|
|
445
|
+
p = make_unique<LinearFunctor<_WeightType>>();
|
|
446
|
+
break;
|
|
447
|
+
case ISLDAModel::GLM::binary_logistic:
|
|
448
|
+
p = make_unique<BinaryLogisticFunctor<_WeightType>>();
|
|
449
|
+
break;
|
|
450
|
+
default:
|
|
451
|
+
throw std::ios_base::failure(text::format("wrong GLMFunctor type id %d", (t - 1)));
|
|
452
|
+
}
|
|
453
|
+
p->serializerRead(istr);
|
|
454
|
+
}
|
|
455
|
+
}
|
|
456
|
+
}
|