tomoto 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (420) hide show
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +3 -0
  3. data/LICENSE.txt +22 -0
  4. data/README.md +123 -0
  5. data/ext/tomoto/ext.cpp +245 -0
  6. data/ext/tomoto/extconf.rb +28 -0
  7. data/lib/tomoto.rb +12 -0
  8. data/lib/tomoto/ct.rb +11 -0
  9. data/lib/tomoto/hdp.rb +11 -0
  10. data/lib/tomoto/lda.rb +67 -0
  11. data/lib/tomoto/version.rb +3 -0
  12. data/vendor/EigenRand/EigenRand/Core.h +1139 -0
  13. data/vendor/EigenRand/EigenRand/Dists/Basic.h +111 -0
  14. data/vendor/EigenRand/EigenRand/Dists/Discrete.h +877 -0
  15. data/vendor/EigenRand/EigenRand/Dists/GammaPoisson.h +108 -0
  16. data/vendor/EigenRand/EigenRand/Dists/NormalExp.h +626 -0
  17. data/vendor/EigenRand/EigenRand/EigenRand +19 -0
  18. data/vendor/EigenRand/EigenRand/Macro.h +24 -0
  19. data/vendor/EigenRand/EigenRand/MorePacketMath.h +978 -0
  20. data/vendor/EigenRand/EigenRand/PacketFilter.h +286 -0
  21. data/vendor/EigenRand/EigenRand/PacketRandomEngine.h +624 -0
  22. data/vendor/EigenRand/EigenRand/RandUtils.h +413 -0
  23. data/vendor/EigenRand/EigenRand/doc.h +220 -0
  24. data/vendor/EigenRand/LICENSE +21 -0
  25. data/vendor/EigenRand/README.md +288 -0
  26. data/vendor/eigen/COPYING.BSD +26 -0
  27. data/vendor/eigen/COPYING.GPL +674 -0
  28. data/vendor/eigen/COPYING.LGPL +502 -0
  29. data/vendor/eigen/COPYING.MINPACK +52 -0
  30. data/vendor/eigen/COPYING.MPL2 +373 -0
  31. data/vendor/eigen/COPYING.README +18 -0
  32. data/vendor/eigen/Eigen/CMakeLists.txt +19 -0
  33. data/vendor/eigen/Eigen/Cholesky +46 -0
  34. data/vendor/eigen/Eigen/CholmodSupport +48 -0
  35. data/vendor/eigen/Eigen/Core +537 -0
  36. data/vendor/eigen/Eigen/Dense +7 -0
  37. data/vendor/eigen/Eigen/Eigen +2 -0
  38. data/vendor/eigen/Eigen/Eigenvalues +61 -0
  39. data/vendor/eigen/Eigen/Geometry +62 -0
  40. data/vendor/eigen/Eigen/Householder +30 -0
  41. data/vendor/eigen/Eigen/IterativeLinearSolvers +48 -0
  42. data/vendor/eigen/Eigen/Jacobi +33 -0
  43. data/vendor/eigen/Eigen/LU +50 -0
  44. data/vendor/eigen/Eigen/MetisSupport +35 -0
  45. data/vendor/eigen/Eigen/OrderingMethods +73 -0
  46. data/vendor/eigen/Eigen/PaStiXSupport +48 -0
  47. data/vendor/eigen/Eigen/PardisoSupport +35 -0
  48. data/vendor/eigen/Eigen/QR +51 -0
  49. data/vendor/eigen/Eigen/QtAlignedMalloc +40 -0
  50. data/vendor/eigen/Eigen/SPQRSupport +34 -0
  51. data/vendor/eigen/Eigen/SVD +51 -0
  52. data/vendor/eigen/Eigen/Sparse +36 -0
  53. data/vendor/eigen/Eigen/SparseCholesky +45 -0
  54. data/vendor/eigen/Eigen/SparseCore +69 -0
  55. data/vendor/eigen/Eigen/SparseLU +46 -0
  56. data/vendor/eigen/Eigen/SparseQR +37 -0
  57. data/vendor/eigen/Eigen/StdDeque +27 -0
  58. data/vendor/eigen/Eigen/StdList +26 -0
  59. data/vendor/eigen/Eigen/StdVector +27 -0
  60. data/vendor/eigen/Eigen/SuperLUSupport +64 -0
  61. data/vendor/eigen/Eigen/UmfPackSupport +40 -0
  62. data/vendor/eigen/Eigen/src/Cholesky/LDLT.h +673 -0
  63. data/vendor/eigen/Eigen/src/Cholesky/LLT.h +542 -0
  64. data/vendor/eigen/Eigen/src/Cholesky/LLT_LAPACKE.h +99 -0
  65. data/vendor/eigen/Eigen/src/CholmodSupport/CholmodSupport.h +639 -0
  66. data/vendor/eigen/Eigen/src/Core/Array.h +329 -0
  67. data/vendor/eigen/Eigen/src/Core/ArrayBase.h +226 -0
  68. data/vendor/eigen/Eigen/src/Core/ArrayWrapper.h +209 -0
  69. data/vendor/eigen/Eigen/src/Core/Assign.h +90 -0
  70. data/vendor/eigen/Eigen/src/Core/AssignEvaluator.h +935 -0
  71. data/vendor/eigen/Eigen/src/Core/Assign_MKL.h +178 -0
  72. data/vendor/eigen/Eigen/src/Core/BandMatrix.h +353 -0
  73. data/vendor/eigen/Eigen/src/Core/Block.h +452 -0
  74. data/vendor/eigen/Eigen/src/Core/BooleanRedux.h +164 -0
  75. data/vendor/eigen/Eigen/src/Core/CommaInitializer.h +160 -0
  76. data/vendor/eigen/Eigen/src/Core/ConditionEstimator.h +175 -0
  77. data/vendor/eigen/Eigen/src/Core/CoreEvaluators.h +1688 -0
  78. data/vendor/eigen/Eigen/src/Core/CoreIterators.h +127 -0
  79. data/vendor/eigen/Eigen/src/Core/CwiseBinaryOp.h +184 -0
  80. data/vendor/eigen/Eigen/src/Core/CwiseNullaryOp.h +866 -0
  81. data/vendor/eigen/Eigen/src/Core/CwiseTernaryOp.h +197 -0
  82. data/vendor/eigen/Eigen/src/Core/CwiseUnaryOp.h +103 -0
  83. data/vendor/eigen/Eigen/src/Core/CwiseUnaryView.h +128 -0
  84. data/vendor/eigen/Eigen/src/Core/DenseBase.h +611 -0
  85. data/vendor/eigen/Eigen/src/Core/DenseCoeffsBase.h +681 -0
  86. data/vendor/eigen/Eigen/src/Core/DenseStorage.h +570 -0
  87. data/vendor/eigen/Eigen/src/Core/Diagonal.h +260 -0
  88. data/vendor/eigen/Eigen/src/Core/DiagonalMatrix.h +343 -0
  89. data/vendor/eigen/Eigen/src/Core/DiagonalProduct.h +28 -0
  90. data/vendor/eigen/Eigen/src/Core/Dot.h +318 -0
  91. data/vendor/eigen/Eigen/src/Core/EigenBase.h +159 -0
  92. data/vendor/eigen/Eigen/src/Core/ForceAlignedAccess.h +146 -0
  93. data/vendor/eigen/Eigen/src/Core/Fuzzy.h +155 -0
  94. data/vendor/eigen/Eigen/src/Core/GeneralProduct.h +455 -0
  95. data/vendor/eigen/Eigen/src/Core/GenericPacketMath.h +593 -0
  96. data/vendor/eigen/Eigen/src/Core/GlobalFunctions.h +187 -0
  97. data/vendor/eigen/Eigen/src/Core/IO.h +225 -0
  98. data/vendor/eigen/Eigen/src/Core/Inverse.h +118 -0
  99. data/vendor/eigen/Eigen/src/Core/Map.h +171 -0
  100. data/vendor/eigen/Eigen/src/Core/MapBase.h +303 -0
  101. data/vendor/eigen/Eigen/src/Core/MathFunctions.h +1415 -0
  102. data/vendor/eigen/Eigen/src/Core/MathFunctionsImpl.h +101 -0
  103. data/vendor/eigen/Eigen/src/Core/Matrix.h +459 -0
  104. data/vendor/eigen/Eigen/src/Core/MatrixBase.h +529 -0
  105. data/vendor/eigen/Eigen/src/Core/NestByValue.h +110 -0
  106. data/vendor/eigen/Eigen/src/Core/NoAlias.h +108 -0
  107. data/vendor/eigen/Eigen/src/Core/NumTraits.h +248 -0
  108. data/vendor/eigen/Eigen/src/Core/PermutationMatrix.h +633 -0
  109. data/vendor/eigen/Eigen/src/Core/PlainObjectBase.h +1035 -0
  110. data/vendor/eigen/Eigen/src/Core/Product.h +186 -0
  111. data/vendor/eigen/Eigen/src/Core/ProductEvaluators.h +1112 -0
  112. data/vendor/eigen/Eigen/src/Core/Random.h +182 -0
  113. data/vendor/eigen/Eigen/src/Core/Redux.h +505 -0
  114. data/vendor/eigen/Eigen/src/Core/Ref.h +283 -0
  115. data/vendor/eigen/Eigen/src/Core/Replicate.h +142 -0
  116. data/vendor/eigen/Eigen/src/Core/ReturnByValue.h +117 -0
  117. data/vendor/eigen/Eigen/src/Core/Reverse.h +211 -0
  118. data/vendor/eigen/Eigen/src/Core/Select.h +162 -0
  119. data/vendor/eigen/Eigen/src/Core/SelfAdjointView.h +352 -0
  120. data/vendor/eigen/Eigen/src/Core/SelfCwiseBinaryOp.h +47 -0
  121. data/vendor/eigen/Eigen/src/Core/Solve.h +188 -0
  122. data/vendor/eigen/Eigen/src/Core/SolveTriangular.h +235 -0
  123. data/vendor/eigen/Eigen/src/Core/SolverBase.h +130 -0
  124. data/vendor/eigen/Eigen/src/Core/StableNorm.h +221 -0
  125. data/vendor/eigen/Eigen/src/Core/Stride.h +111 -0
  126. data/vendor/eigen/Eigen/src/Core/Swap.h +67 -0
  127. data/vendor/eigen/Eigen/src/Core/Transpose.h +403 -0
  128. data/vendor/eigen/Eigen/src/Core/Transpositions.h +407 -0
  129. data/vendor/eigen/Eigen/src/Core/TriangularMatrix.h +983 -0
  130. data/vendor/eigen/Eigen/src/Core/VectorBlock.h +96 -0
  131. data/vendor/eigen/Eigen/src/Core/VectorwiseOp.h +695 -0
  132. data/vendor/eigen/Eigen/src/Core/Visitor.h +273 -0
  133. data/vendor/eigen/Eigen/src/Core/arch/AVX/Complex.h +451 -0
  134. data/vendor/eigen/Eigen/src/Core/arch/AVX/MathFunctions.h +439 -0
  135. data/vendor/eigen/Eigen/src/Core/arch/AVX/PacketMath.h +637 -0
  136. data/vendor/eigen/Eigen/src/Core/arch/AVX/TypeCasting.h +51 -0
  137. data/vendor/eigen/Eigen/src/Core/arch/AVX512/MathFunctions.h +391 -0
  138. data/vendor/eigen/Eigen/src/Core/arch/AVX512/PacketMath.h +1316 -0
  139. data/vendor/eigen/Eigen/src/Core/arch/AltiVec/Complex.h +430 -0
  140. data/vendor/eigen/Eigen/src/Core/arch/AltiVec/MathFunctions.h +322 -0
  141. data/vendor/eigen/Eigen/src/Core/arch/AltiVec/PacketMath.h +1061 -0
  142. data/vendor/eigen/Eigen/src/Core/arch/CUDA/Complex.h +103 -0
  143. data/vendor/eigen/Eigen/src/Core/arch/CUDA/Half.h +674 -0
  144. data/vendor/eigen/Eigen/src/Core/arch/CUDA/MathFunctions.h +91 -0
  145. data/vendor/eigen/Eigen/src/Core/arch/CUDA/PacketMath.h +333 -0
  146. data/vendor/eigen/Eigen/src/Core/arch/CUDA/PacketMathHalf.h +1124 -0
  147. data/vendor/eigen/Eigen/src/Core/arch/CUDA/TypeCasting.h +212 -0
  148. data/vendor/eigen/Eigen/src/Core/arch/Default/ConjHelper.h +29 -0
  149. data/vendor/eigen/Eigen/src/Core/arch/Default/Settings.h +49 -0
  150. data/vendor/eigen/Eigen/src/Core/arch/NEON/Complex.h +490 -0
  151. data/vendor/eigen/Eigen/src/Core/arch/NEON/MathFunctions.h +91 -0
  152. data/vendor/eigen/Eigen/src/Core/arch/NEON/PacketMath.h +760 -0
  153. data/vendor/eigen/Eigen/src/Core/arch/SSE/Complex.h +471 -0
  154. data/vendor/eigen/Eigen/src/Core/arch/SSE/MathFunctions.h +562 -0
  155. data/vendor/eigen/Eigen/src/Core/arch/SSE/PacketMath.h +895 -0
  156. data/vendor/eigen/Eigen/src/Core/arch/SSE/TypeCasting.h +77 -0
  157. data/vendor/eigen/Eigen/src/Core/arch/ZVector/Complex.h +397 -0
  158. data/vendor/eigen/Eigen/src/Core/arch/ZVector/MathFunctions.h +137 -0
  159. data/vendor/eigen/Eigen/src/Core/arch/ZVector/PacketMath.h +945 -0
  160. data/vendor/eigen/Eigen/src/Core/functors/AssignmentFunctors.h +168 -0
  161. data/vendor/eigen/Eigen/src/Core/functors/BinaryFunctors.h +475 -0
  162. data/vendor/eigen/Eigen/src/Core/functors/NullaryFunctors.h +188 -0
  163. data/vendor/eigen/Eigen/src/Core/functors/StlFunctors.h +136 -0
  164. data/vendor/eigen/Eigen/src/Core/functors/TernaryFunctors.h +25 -0
  165. data/vendor/eigen/Eigen/src/Core/functors/UnaryFunctors.h +792 -0
  166. data/vendor/eigen/Eigen/src/Core/products/GeneralBlockPanelKernel.h +2156 -0
  167. data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrix.h +492 -0
  168. data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h +311 -0
  169. data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h +145 -0
  170. data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h +122 -0
  171. data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixVector.h +619 -0
  172. data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixVector_BLAS.h +136 -0
  173. data/vendor/eigen/Eigen/src/Core/products/Parallelizer.h +163 -0
  174. data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix.h +521 -0
  175. data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix_BLAS.h +287 -0
  176. data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixVector.h +260 -0
  177. data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixVector_BLAS.h +118 -0
  178. data/vendor/eigen/Eigen/src/Core/products/SelfadjointProduct.h +133 -0
  179. data/vendor/eigen/Eigen/src/Core/products/SelfadjointRank2Update.h +93 -0
  180. data/vendor/eigen/Eigen/src/Core/products/TriangularMatrixMatrix.h +466 -0
  181. data/vendor/eigen/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h +315 -0
  182. data/vendor/eigen/Eigen/src/Core/products/TriangularMatrixVector.h +350 -0
  183. data/vendor/eigen/Eigen/src/Core/products/TriangularMatrixVector_BLAS.h +255 -0
  184. data/vendor/eigen/Eigen/src/Core/products/TriangularSolverMatrix.h +335 -0
  185. data/vendor/eigen/Eigen/src/Core/products/TriangularSolverMatrix_BLAS.h +163 -0
  186. data/vendor/eigen/Eigen/src/Core/products/TriangularSolverVector.h +145 -0
  187. data/vendor/eigen/Eigen/src/Core/util/BlasUtil.h +398 -0
  188. data/vendor/eigen/Eigen/src/Core/util/Constants.h +547 -0
  189. data/vendor/eigen/Eigen/src/Core/util/DisableStupidWarnings.h +83 -0
  190. data/vendor/eigen/Eigen/src/Core/util/ForwardDeclarations.h +302 -0
  191. data/vendor/eigen/Eigen/src/Core/util/MKL_support.h +130 -0
  192. data/vendor/eigen/Eigen/src/Core/util/Macros.h +1001 -0
  193. data/vendor/eigen/Eigen/src/Core/util/Memory.h +993 -0
  194. data/vendor/eigen/Eigen/src/Core/util/Meta.h +534 -0
  195. data/vendor/eigen/Eigen/src/Core/util/NonMPL2.h +3 -0
  196. data/vendor/eigen/Eigen/src/Core/util/ReenableStupidWarnings.h +27 -0
  197. data/vendor/eigen/Eigen/src/Core/util/StaticAssert.h +218 -0
  198. data/vendor/eigen/Eigen/src/Core/util/XprHelper.h +821 -0
  199. data/vendor/eigen/Eigen/src/Eigenvalues/ComplexEigenSolver.h +346 -0
  200. data/vendor/eigen/Eigen/src/Eigenvalues/ComplexSchur.h +459 -0
  201. data/vendor/eigen/Eigen/src/Eigenvalues/ComplexSchur_LAPACKE.h +91 -0
  202. data/vendor/eigen/Eigen/src/Eigenvalues/EigenSolver.h +622 -0
  203. data/vendor/eigen/Eigen/src/Eigenvalues/GeneralizedEigenSolver.h +418 -0
  204. data/vendor/eigen/Eigen/src/Eigenvalues/GeneralizedSelfAdjointEigenSolver.h +226 -0
  205. data/vendor/eigen/Eigen/src/Eigenvalues/HessenbergDecomposition.h +374 -0
  206. data/vendor/eigen/Eigen/src/Eigenvalues/MatrixBaseEigenvalues.h +158 -0
  207. data/vendor/eigen/Eigen/src/Eigenvalues/RealQZ.h +654 -0
  208. data/vendor/eigen/Eigen/src/Eigenvalues/RealSchur.h +546 -0
  209. data/vendor/eigen/Eigen/src/Eigenvalues/RealSchur_LAPACKE.h +77 -0
  210. data/vendor/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver.h +870 -0
  211. data/vendor/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver_LAPACKE.h +87 -0
  212. data/vendor/eigen/Eigen/src/Eigenvalues/Tridiagonalization.h +556 -0
  213. data/vendor/eigen/Eigen/src/Geometry/AlignedBox.h +392 -0
  214. data/vendor/eigen/Eigen/src/Geometry/AngleAxis.h +247 -0
  215. data/vendor/eigen/Eigen/src/Geometry/EulerAngles.h +114 -0
  216. data/vendor/eigen/Eigen/src/Geometry/Homogeneous.h +497 -0
  217. data/vendor/eigen/Eigen/src/Geometry/Hyperplane.h +282 -0
  218. data/vendor/eigen/Eigen/src/Geometry/OrthoMethods.h +234 -0
  219. data/vendor/eigen/Eigen/src/Geometry/ParametrizedLine.h +195 -0
  220. data/vendor/eigen/Eigen/src/Geometry/Quaternion.h +814 -0
  221. data/vendor/eigen/Eigen/src/Geometry/Rotation2D.h +199 -0
  222. data/vendor/eigen/Eigen/src/Geometry/RotationBase.h +206 -0
  223. data/vendor/eigen/Eigen/src/Geometry/Scaling.h +170 -0
  224. data/vendor/eigen/Eigen/src/Geometry/Transform.h +1542 -0
  225. data/vendor/eigen/Eigen/src/Geometry/Translation.h +208 -0
  226. data/vendor/eigen/Eigen/src/Geometry/Umeyama.h +166 -0
  227. data/vendor/eigen/Eigen/src/Geometry/arch/Geometry_SSE.h +161 -0
  228. data/vendor/eigen/Eigen/src/Householder/BlockHouseholder.h +103 -0
  229. data/vendor/eigen/Eigen/src/Householder/Householder.h +172 -0
  230. data/vendor/eigen/Eigen/src/Householder/HouseholderSequence.h +470 -0
  231. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/BasicPreconditioners.h +226 -0
  232. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h +228 -0
  233. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/ConjugateGradient.h +246 -0
  234. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/IncompleteCholesky.h +400 -0
  235. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/IncompleteLUT.h +462 -0
  236. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h +394 -0
  237. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/LeastSquareConjugateGradient.h +216 -0
  238. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/SolveWithGuess.h +115 -0
  239. data/vendor/eigen/Eigen/src/Jacobi/Jacobi.h +462 -0
  240. data/vendor/eigen/Eigen/src/LU/Determinant.h +101 -0
  241. data/vendor/eigen/Eigen/src/LU/FullPivLU.h +891 -0
  242. data/vendor/eigen/Eigen/src/LU/InverseImpl.h +415 -0
  243. data/vendor/eigen/Eigen/src/LU/PartialPivLU.h +611 -0
  244. data/vendor/eigen/Eigen/src/LU/PartialPivLU_LAPACKE.h +83 -0
  245. data/vendor/eigen/Eigen/src/LU/arch/Inverse_SSE.h +338 -0
  246. data/vendor/eigen/Eigen/src/MetisSupport/MetisSupport.h +137 -0
  247. data/vendor/eigen/Eigen/src/OrderingMethods/Amd.h +445 -0
  248. data/vendor/eigen/Eigen/src/OrderingMethods/Eigen_Colamd.h +1843 -0
  249. data/vendor/eigen/Eigen/src/OrderingMethods/Ordering.h +157 -0
  250. data/vendor/eigen/Eigen/src/PaStiXSupport/PaStiXSupport.h +678 -0
  251. data/vendor/eigen/Eigen/src/PardisoSupport/PardisoSupport.h +543 -0
  252. data/vendor/eigen/Eigen/src/QR/ColPivHouseholderQR.h +653 -0
  253. data/vendor/eigen/Eigen/src/QR/ColPivHouseholderQR_LAPACKE.h +97 -0
  254. data/vendor/eigen/Eigen/src/QR/CompleteOrthogonalDecomposition.h +562 -0
  255. data/vendor/eigen/Eigen/src/QR/FullPivHouseholderQR.h +676 -0
  256. data/vendor/eigen/Eigen/src/QR/HouseholderQR.h +409 -0
  257. data/vendor/eigen/Eigen/src/QR/HouseholderQR_LAPACKE.h +68 -0
  258. data/vendor/eigen/Eigen/src/SPQRSupport/SuiteSparseQRSupport.h +313 -0
  259. data/vendor/eigen/Eigen/src/SVD/BDCSVD.h +1246 -0
  260. data/vendor/eigen/Eigen/src/SVD/JacobiSVD.h +804 -0
  261. data/vendor/eigen/Eigen/src/SVD/JacobiSVD_LAPACKE.h +91 -0
  262. data/vendor/eigen/Eigen/src/SVD/SVDBase.h +315 -0
  263. data/vendor/eigen/Eigen/src/SVD/UpperBidiagonalization.h +414 -0
  264. data/vendor/eigen/Eigen/src/SparseCholesky/SimplicialCholesky.h +689 -0
  265. data/vendor/eigen/Eigen/src/SparseCholesky/SimplicialCholesky_impl.h +199 -0
  266. data/vendor/eigen/Eigen/src/SparseCore/AmbiVector.h +377 -0
  267. data/vendor/eigen/Eigen/src/SparseCore/CompressedStorage.h +258 -0
  268. data/vendor/eigen/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h +352 -0
  269. data/vendor/eigen/Eigen/src/SparseCore/MappedSparseMatrix.h +67 -0
  270. data/vendor/eigen/Eigen/src/SparseCore/SparseAssign.h +216 -0
  271. data/vendor/eigen/Eigen/src/SparseCore/SparseBlock.h +603 -0
  272. data/vendor/eigen/Eigen/src/SparseCore/SparseColEtree.h +206 -0
  273. data/vendor/eigen/Eigen/src/SparseCore/SparseCompressedBase.h +341 -0
  274. data/vendor/eigen/Eigen/src/SparseCore/SparseCwiseBinaryOp.h +726 -0
  275. data/vendor/eigen/Eigen/src/SparseCore/SparseCwiseUnaryOp.h +148 -0
  276. data/vendor/eigen/Eigen/src/SparseCore/SparseDenseProduct.h +320 -0
  277. data/vendor/eigen/Eigen/src/SparseCore/SparseDiagonalProduct.h +138 -0
  278. data/vendor/eigen/Eigen/src/SparseCore/SparseDot.h +98 -0
  279. data/vendor/eigen/Eigen/src/SparseCore/SparseFuzzy.h +29 -0
  280. data/vendor/eigen/Eigen/src/SparseCore/SparseMap.h +305 -0
  281. data/vendor/eigen/Eigen/src/SparseCore/SparseMatrix.h +1403 -0
  282. data/vendor/eigen/Eigen/src/SparseCore/SparseMatrixBase.h +405 -0
  283. data/vendor/eigen/Eigen/src/SparseCore/SparsePermutation.h +178 -0
  284. data/vendor/eigen/Eigen/src/SparseCore/SparseProduct.h +169 -0
  285. data/vendor/eigen/Eigen/src/SparseCore/SparseRedux.h +49 -0
  286. data/vendor/eigen/Eigen/src/SparseCore/SparseRef.h +397 -0
  287. data/vendor/eigen/Eigen/src/SparseCore/SparseSelfAdjointView.h +656 -0
  288. data/vendor/eigen/Eigen/src/SparseCore/SparseSolverBase.h +124 -0
  289. data/vendor/eigen/Eigen/src/SparseCore/SparseSparseProductWithPruning.h +198 -0
  290. data/vendor/eigen/Eigen/src/SparseCore/SparseTranspose.h +92 -0
  291. data/vendor/eigen/Eigen/src/SparseCore/SparseTriangularView.h +189 -0
  292. data/vendor/eigen/Eigen/src/SparseCore/SparseUtil.h +178 -0
  293. data/vendor/eigen/Eigen/src/SparseCore/SparseVector.h +478 -0
  294. data/vendor/eigen/Eigen/src/SparseCore/SparseView.h +253 -0
  295. data/vendor/eigen/Eigen/src/SparseCore/TriangularSolver.h +315 -0
  296. data/vendor/eigen/Eigen/src/SparseLU/SparseLU.h +773 -0
  297. data/vendor/eigen/Eigen/src/SparseLU/SparseLUImpl.h +66 -0
  298. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_Memory.h +226 -0
  299. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_Structs.h +110 -0
  300. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_SupernodalMatrix.h +301 -0
  301. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_Utils.h +80 -0
  302. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_column_bmod.h +181 -0
  303. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_column_dfs.h +179 -0
  304. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_copy_to_ucol.h +107 -0
  305. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_gemm_kernel.h +280 -0
  306. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_heap_relax_snode.h +126 -0
  307. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_kernel_bmod.h +130 -0
  308. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_panel_bmod.h +223 -0
  309. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_panel_dfs.h +258 -0
  310. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_pivotL.h +137 -0
  311. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_pruneL.h +136 -0
  312. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_relax_snode.h +83 -0
  313. data/vendor/eigen/Eigen/src/SparseQR/SparseQR.h +745 -0
  314. data/vendor/eigen/Eigen/src/StlSupport/StdDeque.h +126 -0
  315. data/vendor/eigen/Eigen/src/StlSupport/StdList.h +106 -0
  316. data/vendor/eigen/Eigen/src/StlSupport/StdVector.h +131 -0
  317. data/vendor/eigen/Eigen/src/StlSupport/details.h +84 -0
  318. data/vendor/eigen/Eigen/src/SuperLUSupport/SuperLUSupport.h +1027 -0
  319. data/vendor/eigen/Eigen/src/UmfPackSupport/UmfPackSupport.h +506 -0
  320. data/vendor/eigen/Eigen/src/misc/Image.h +82 -0
  321. data/vendor/eigen/Eigen/src/misc/Kernel.h +79 -0
  322. data/vendor/eigen/Eigen/src/misc/RealSvd2x2.h +55 -0
  323. data/vendor/eigen/Eigen/src/misc/blas.h +440 -0
  324. data/vendor/eigen/Eigen/src/misc/lapack.h +152 -0
  325. data/vendor/eigen/Eigen/src/misc/lapacke.h +16291 -0
  326. data/vendor/eigen/Eigen/src/misc/lapacke_mangling.h +17 -0
  327. data/vendor/eigen/Eigen/src/plugins/ArrayCwiseBinaryOps.h +332 -0
  328. data/vendor/eigen/Eigen/src/plugins/ArrayCwiseUnaryOps.h +552 -0
  329. data/vendor/eigen/Eigen/src/plugins/BlockMethods.h +1058 -0
  330. data/vendor/eigen/Eigen/src/plugins/CommonCwiseBinaryOps.h +115 -0
  331. data/vendor/eigen/Eigen/src/plugins/CommonCwiseUnaryOps.h +163 -0
  332. data/vendor/eigen/Eigen/src/plugins/MatrixCwiseBinaryOps.h +152 -0
  333. data/vendor/eigen/Eigen/src/plugins/MatrixCwiseUnaryOps.h +85 -0
  334. data/vendor/eigen/README.md +3 -0
  335. data/vendor/eigen/bench/README.txt +55 -0
  336. data/vendor/eigen/bench/btl/COPYING +340 -0
  337. data/vendor/eigen/bench/btl/README +154 -0
  338. data/vendor/eigen/bench/tensors/README +21 -0
  339. data/vendor/eigen/blas/README.txt +6 -0
  340. data/vendor/eigen/demos/mandelbrot/README +10 -0
  341. data/vendor/eigen/demos/mix_eigen_and_c/README +9 -0
  342. data/vendor/eigen/demos/opengl/README +13 -0
  343. data/vendor/eigen/unsupported/Eigen/CXX11/src/Tensor/README.md +1760 -0
  344. data/vendor/eigen/unsupported/README.txt +50 -0
  345. data/vendor/tomotopy/LICENSE +21 -0
  346. data/vendor/tomotopy/README.kr.rst +375 -0
  347. data/vendor/tomotopy/README.rst +382 -0
  348. data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +362 -0
  349. data/vendor/tomotopy/src/Labeling/FoRelevance.h +88 -0
  350. data/vendor/tomotopy/src/Labeling/Labeler.h +50 -0
  351. data/vendor/tomotopy/src/TopicModel/CT.h +37 -0
  352. data/vendor/tomotopy/src/TopicModel/CTModel.cpp +13 -0
  353. data/vendor/tomotopy/src/TopicModel/CTModel.hpp +293 -0
  354. data/vendor/tomotopy/src/TopicModel/DMR.h +51 -0
  355. data/vendor/tomotopy/src/TopicModel/DMRModel.cpp +13 -0
  356. data/vendor/tomotopy/src/TopicModel/DMRModel.hpp +374 -0
  357. data/vendor/tomotopy/src/TopicModel/DT.h +65 -0
  358. data/vendor/tomotopy/src/TopicModel/DTM.h +22 -0
  359. data/vendor/tomotopy/src/TopicModel/DTModel.cpp +15 -0
  360. data/vendor/tomotopy/src/TopicModel/DTModel.hpp +572 -0
  361. data/vendor/tomotopy/src/TopicModel/GDMR.h +37 -0
  362. data/vendor/tomotopy/src/TopicModel/GDMRModel.cpp +14 -0
  363. data/vendor/tomotopy/src/TopicModel/GDMRModel.hpp +485 -0
  364. data/vendor/tomotopy/src/TopicModel/HDP.h +74 -0
  365. data/vendor/tomotopy/src/TopicModel/HDPModel.cpp +13 -0
  366. data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +592 -0
  367. data/vendor/tomotopy/src/TopicModel/HLDA.h +40 -0
  368. data/vendor/tomotopy/src/TopicModel/HLDAModel.cpp +13 -0
  369. data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +681 -0
  370. data/vendor/tomotopy/src/TopicModel/HPA.h +27 -0
  371. data/vendor/tomotopy/src/TopicModel/HPAModel.cpp +21 -0
  372. data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +588 -0
  373. data/vendor/tomotopy/src/TopicModel/LDA.h +144 -0
  374. data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +442 -0
  375. data/vendor/tomotopy/src/TopicModel/LDAModel.cpp +13 -0
  376. data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +1058 -0
  377. data/vendor/tomotopy/src/TopicModel/LLDA.h +45 -0
  378. data/vendor/tomotopy/src/TopicModel/LLDAModel.cpp +13 -0
  379. data/vendor/tomotopy/src/TopicModel/LLDAModel.hpp +203 -0
  380. data/vendor/tomotopy/src/TopicModel/MGLDA.h +63 -0
  381. data/vendor/tomotopy/src/TopicModel/MGLDAModel.cpp +17 -0
  382. data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +558 -0
  383. data/vendor/tomotopy/src/TopicModel/PA.h +43 -0
  384. data/vendor/tomotopy/src/TopicModel/PAModel.cpp +13 -0
  385. data/vendor/tomotopy/src/TopicModel/PAModel.hpp +467 -0
  386. data/vendor/tomotopy/src/TopicModel/PLDA.h +17 -0
  387. data/vendor/tomotopy/src/TopicModel/PLDAModel.cpp +13 -0
  388. data/vendor/tomotopy/src/TopicModel/PLDAModel.hpp +214 -0
  389. data/vendor/tomotopy/src/TopicModel/SLDA.h +54 -0
  390. data/vendor/tomotopy/src/TopicModel/SLDAModel.cpp +17 -0
  391. data/vendor/tomotopy/src/TopicModel/SLDAModel.hpp +456 -0
  392. data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +692 -0
  393. data/vendor/tomotopy/src/Utils/AliasMethod.hpp +169 -0
  394. data/vendor/tomotopy/src/Utils/Dictionary.h +80 -0
  395. data/vendor/tomotopy/src/Utils/EigenAddonOps.hpp +181 -0
  396. data/vendor/tomotopy/src/Utils/LBFGS.h +202 -0
  397. data/vendor/tomotopy/src/Utils/LBFGS/LineSearchBacktracking.h +120 -0
  398. data/vendor/tomotopy/src/Utils/LBFGS/LineSearchBracketing.h +122 -0
  399. data/vendor/tomotopy/src/Utils/LBFGS/Param.h +213 -0
  400. data/vendor/tomotopy/src/Utils/LUT.hpp +82 -0
  401. data/vendor/tomotopy/src/Utils/MultiNormalDistribution.hpp +69 -0
  402. data/vendor/tomotopy/src/Utils/PolyaGamma.hpp +200 -0
  403. data/vendor/tomotopy/src/Utils/PolyaGammaHybrid.hpp +672 -0
  404. data/vendor/tomotopy/src/Utils/ThreadPool.hpp +150 -0
  405. data/vendor/tomotopy/src/Utils/Trie.hpp +220 -0
  406. data/vendor/tomotopy/src/Utils/TruncMultiNormal.hpp +94 -0
  407. data/vendor/tomotopy/src/Utils/Utils.hpp +337 -0
  408. data/vendor/tomotopy/src/Utils/avx_gamma.h +46 -0
  409. data/vendor/tomotopy/src/Utils/avx_mathfun.h +736 -0
  410. data/vendor/tomotopy/src/Utils/exception.h +28 -0
  411. data/vendor/tomotopy/src/Utils/math.h +281 -0
  412. data/vendor/tomotopy/src/Utils/rtnorm.hpp +2690 -0
  413. data/vendor/tomotopy/src/Utils/sample.hpp +192 -0
  414. data/vendor/tomotopy/src/Utils/serializer.hpp +695 -0
  415. data/vendor/tomotopy/src/Utils/slp.hpp +131 -0
  416. data/vendor/tomotopy/src/Utils/sse_gamma.h +48 -0
  417. data/vendor/tomotopy/src/Utils/sse_mathfun.h +710 -0
  418. data/vendor/tomotopy/src/Utils/text.hpp +49 -0
  419. data/vendor/tomotopy/src/Utils/tvector.hpp +543 -0
  420. metadata +531 -0
@@ -0,0 +1,45 @@
1
+ #pragma once
2
+ #include "LDA.h"
3
+
4
+ namespace tomoto
5
+ {
6
+ template<TermWeight _tw> // _tw: term-weighting scheme, forwarded unchanged to DocumentLDA
7
+ struct DocumentLLDA : public DocumentLDA<_tw> // Document type for Labeled LDA: a DocumentLDA extended with a label mask
8
+ {
9
+ using BaseDocument = DocumentLDA<_tw>; // alias for the base class; passed to the serializer macros below
10
+ using DocumentLDA<_tw>::DocumentLDA; // inherit every constructor of the base document
11
+ using WeightType = typename DocumentLDA<_tw>::WeightType; // re-export the base's weight type
12
+ Eigen::Matrix<int8_t, -1, 1> labelMask; // dynamically sized int8_t column vector; presumably one flag per topic marking the topics this document's labels allow -- TODO confirm in LLDAModel
13
+
14
+ DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 0, labelMask); // macro defined elsewhere; names BaseDocument, version 0 and labelMask -- presumably the untagged format that writes labelMask after the base fields
15
+ DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 1, 0x00010001, labelMask); // tagged counterpart: version 1 with tag 0x00010001, covering the same single field
16
+ };
17
+
18
+ class ILLDAModel : public ILDAModel // Abstract interface for Labeled LDA models: ILDAModel plus label-aware document methods
19
+ {
20
+ public:
21
+ using DefaultDocType = DocumentLLDA<TermWeight::one>; // document type to use when no term weighting is requested explicitly
22
+ static ILLDAModel* create(TermWeight _weight, size_t _K = 1, // factory returning a raw pointer to a model for the given term weight; ownership presumably passes to the caller -- TODO confirm
23
+ Float alpha = 0.1, Float eta = 0.01, size_t seed = std::random_device{}(), // the default seed is drawn from std::random_device at call time
24
+ bool scalarRng = false); // scalarRng is forwarded to the implementation; presumably selects a non-vectorized random generator -- TODO confirm
25
+
26
+ virtual size_t addDoc(const std::vector<std::string>& words, const std::vector<std::string>& label) = 0; // add a pre-tokenized document together with its label strings; the size_t result is presumably the new document's index -- TODO confirm
27
+ virtual std::unique_ptr<DocumentBase> makeDoc(const std::vector<std::string>& words, const std::vector<std::string>& label) const = 0; // const counterpart: builds and returns a document object instead of a size_t
28
+
29
+ virtual size_t addDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer, // overload taking a raw string plus a tokenizer factory instead of a word list
30
+ const std::vector<std::string>& label) = 0;
31
+ virtual std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer, // const counterpart of the raw-string/tokenizer overload
32
+ const std::vector<std::string>& label) const = 0;
33
+
34
+ virtual size_t addDoc(const std::string& rawStr, const std::vector<Vid>& words, // overload taking the raw string with already-resolved vocabulary ids
35
+ const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len, // pos/len presumably give each word's offset and length within rawStr -- TODO confirm
36
+ const std::vector<std::string>& label) = 0;
37
+ virtual std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const std::vector<Vid>& words, // const counterpart of the vocabulary-id overload
38
+ const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
39
+ const std::vector<std::string>& label) const = 0;
40
+
41
+ virtual const Dictionary& getTopicLabelDict() const = 0; // read-only access to the dictionary of topic labels
42
+
43
+ virtual size_t getNumTopicsPerLabel() const = 0; // number of topics per label, as reported by the implementation
44
+ };
45
+ }
@@ -0,0 +1,13 @@
1
+ #include "LLDAModel.hpp"
2
+
3
+ namespace tomoto
4
+ {
5
+ /*template class LLDAModel<TermWeight::one>;
6
+ template class LLDAModel<TermWeight::idf>;
7
+ template class LLDAModel<TermWeight::pmi>;*/
8
+
9
+ ILLDAModel* ILLDAModel::create(TermWeight _weight, size_t _K, Float _alpha, Float _eta, size_t seed, bool scalarRng) // definition of the factory declared in ILLDAModel
10
+ {
11
+ TMT_SWITCH_TW(_weight, scalarRng, LLDAModel, _K, _alpha, _eta, seed); // macro defined elsewhere; presumably picks the LLDAModel instantiation matching _weight and scalarRng, constructs it with (_K, _alpha, _eta, seed) and returns it. NOTE(review): no return statement is visible here, so the macro must supply every return path -- confirm
12
+ }
13
+ }
@@ -0,0 +1,203 @@
1
+ #pragma once
2
+ #include "LDAModel.hpp"
3
+ #include "LLDA.h"
4
+
5
+ /*
6
+ Implementation of Labeled LDA using Gibbs sampling by bab2min
7
+
8
+ * Ramage, D., Hall, D., Nallapati, R., & Manning, C. D. (2009, August). Labeled LDA: A supervised topic model for credit attribution in multi-labeled corpora. In Proceedings of the 2009 Conference on Empirical Methods in Natural Language Processing: Volume 1-Volume 1 (pp. 248-256). Association for Computational Linguistics.
9
+ */
10
+
11
+ namespace tomoto
12
+ {
13
+ template<TermWeight _tw, typename _RandGen,
14
+ typename _Interface = ILLDAModel,
15
+ typename _Derived = void,
16
+ typename _DocType = DocumentLLDA<_tw>,
17
+ typename _ModelState = ModelStateLDA<_tw>>
18
+ class LLDAModel : public LDAModel<_tw, _RandGen, flags::generator_by_doc | flags::partitioned_multisampling, _Interface,
19
+ typename std::conditional<std::is_same<_Derived, void>::value, LLDAModel<_tw, _RandGen>, _Derived>::type,
20
+ _DocType, _ModelState>
21
+ {
22
+ protected:
23
+ using DerivedClass = typename std::conditional<std::is_same<_Derived, void>::value, LLDAModel<_tw, _RandGen>, _Derived>::type;
24
+ using BaseClass = LDAModel<_tw, _RandGen, flags::generator_by_doc | flags::partitioned_multisampling, _Interface, DerivedClass, _DocType, _ModelState>;
25
+ friend BaseClass;
26
+ friend typename BaseClass::BaseClass;
27
+ using WeightType = typename BaseClass::WeightType;
28
+
29
+ static constexpr char TMID[] = "LLDA";
30
+
31
+ Dictionary topicLabelDict;
32
+
33
+ template<bool _asymEta>
34
+ Float* getZLikelihoods(_ModelState& ld, const _DocType& doc, size_t docId, size_t vid) const
35
+ {
36
+ const size_t V = this->realV;
37
+ assert(vid < V);
38
+ auto& zLikelihood = ld.zLikelihood;
39
+ zLikelihood = (doc.numByTopic.array().template cast<Float>() + this->alphas.array())
40
+ * (ld.numByTopicWord.col(vid).array().template cast<Float>() + this->eta)
41
+ / (ld.numByTopic.array().template cast<Float>() + V * this->eta);
42
+ zLikelihood.array() *= doc.labelMask.array().template cast<Float>();
43
+ sample::prefixSum(zLikelihood.data(), this->K);
44
+ return &zLikelihood[0];
45
+ }
46
+
47
+ void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
48
+ {
49
+ BaseClass::prepareDoc(doc, docId, wordSize);
50
+ if (doc.labelMask.size() == 0)
51
+ {
52
+ doc.labelMask.resize(this->K);
53
+ doc.labelMask.setOnes();
54
+ }
55
+ else if (doc.labelMask.size() < this->K)
56
+ {
57
+ size_t oldSize = doc.labelMask.size();
58
+ doc.labelMask.conservativeResize(this->K);
59
+ doc.labelMask.segment(oldSize, topicLabelDict.size() - oldSize).setZero();
60
+ doc.labelMask.segment(topicLabelDict.size(), this->K - topicLabelDict.size()).setOnes();
61
+ }
62
+ }
63
+
64
+ void initGlobalState(bool initDocs)
65
+ {
66
+ this->K = std::max(topicLabelDict.size(), (size_t)this->K);
67
+ this->alphas.resize(this->K);
68
+ this->alphas.array() = this->alpha;
69
+ BaseClass::initGlobalState(initDocs);
70
+ }
71
+
72
+ struct Generator // state handed to updateStateWithDoc while initial topics are drawn
73
+ {
74
+ std::discrete_distribution<> theta; // built from a document's labelMask in makeGeneratorForInit
75
+ };
76
+
77
+ Generator makeGeneratorForInit(const _DocType* doc) const
78
+ {
79
+ std::discrete_distribution<> theta{ doc->labelMask.data(), doc->labelMask.data() + this->K };
80
+ return Generator{ theta };
81
+ }
82
+
83
+ template<bool _Infer>
84
+ void updateStateWithDoc(Generator& g, _ModelState& ld, _RandGen& rgs, _DocType& doc, size_t i) const
85
+ {
86
+ auto& z = doc.Zs[i];
87
+ auto w = doc.words[i];
88
+ if (this->etaByTopicWord.size())
89
+ {
90
+ Eigen::Array<Float, -1, 1> col = this->etaByTopicWord.col(w);
91
+ for (size_t k = 0; k < col.size(); ++k) col[k] *= g.theta.probabilities()[k];
92
+ z = sample::sampleFromDiscrete(col.data(), col.data() + col.size(), rgs);
93
+ }
94
+ else
95
+ {
96
+ z = g.theta(rgs);
97
+ }
98
+ this->template addWordTo<1>(ld, doc, i, w, z);
99
+ }
100
+
101
+ public:
102
+ DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, topicLabelDict);
103
+ DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, topicLabelDict);
104
+
105
+ LLDAModel(size_t _K = 1, Float _alpha = 1.0, Float _eta = 0.01, size_t _rg = std::random_device{}()) // NOTE(review): _alpha defaults to 1.0 here while ILLDAModel::create defaults alpha to 0.1 — confirm which is intended
106
+ : BaseClass(_K, _alpha, _eta, _rg) // all arguments are forwarded unchanged to the LDA base class
107
+ {
108
+ }
109
+
110
+ template<bool _const = false>
111
+ _DocType& _updateDoc(_DocType& doc, const std::vector<std::string>& labels)
112
+ {
113
+ if (_const)
114
+ {
115
+ doc.labelMask.resize(this->K);
116
+ doc.labelMask.setOnes();
117
+
118
+ std::vector<Vid> topicLabelIds;
119
+ for (auto& label : labels)
120
+ {
121
+ auto tid = topicLabelDict.toWid(label);
122
+ if (tid == (Vid)-1) continue;
123
+ topicLabelIds.emplace_back(tid);
124
+ }
125
+
126
+ if (!topicLabelIds.empty())
127
+ {
128
+ doc.labelMask.head(topicLabelDict.size()).setZero();
129
+ for (auto tid : topicLabelIds) doc.labelMask[tid] = 1;
130
+ }
131
+ }
132
+ else
133
+ {
134
+ if (!labels.empty())
135
+ {
136
+ std::vector<Vid> topicLabelIds;
137
+ for (auto& label : labels) topicLabelIds.emplace_back(topicLabelDict.add(label));
138
+ auto maxVal = *std::max_element(topicLabelIds.begin(), topicLabelIds.end());
139
+ doc.labelMask.resize(maxVal + 1);
140
+ doc.labelMask.setZero();
141
+ for (auto i : topicLabelIds) doc.labelMask[i] = 1;
142
+ }
143
+ }
144
+ return doc;
145
+ }
146
+
147
+ size_t addDoc(const std::vector<std::string>& words, const std::vector<std::string>& labels) override
148
+ {
149
+ auto doc = this->_makeDoc(words);
150
+ return this->_addDoc(_updateDoc(doc, labels));
151
+ }
152
+
153
+ std::unique_ptr<DocumentBase> makeDoc(const std::vector<std::string>& words, const std::vector<std::string>& labels) const override
154
+ {
155
+ auto doc = as_mutable(this)->template _makeDoc<true>(words);
156
+ return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, labels));
157
+ }
158
+
159
+ size_t addDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
160
+ const std::vector<std::string>& labels) override
161
+ {
162
+ auto doc = this->template _makeRawDoc<false>(rawStr, tokenizer);
163
+ return this->_addDoc(_updateDoc(doc, labels));
164
+ }
165
+
166
+ std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
167
+ const std::vector<std::string>& labels) const override
168
+ {
169
+ auto doc = as_mutable(this)->template _makeRawDoc<true>(rawStr, tokenizer);
170
+ return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, labels));
171
+ }
172
+
173
+ size_t addDoc(const std::string& rawStr, const std::vector<Vid>& words,
174
+ const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
175
+ const std::vector<std::string>& labels) override
176
+ {
177
+ auto doc = this->_makeRawDoc(rawStr, words, pos, len);
178
+ return this->_addDoc(_updateDoc(doc, labels));
179
+ }
180
+
181
+ std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const std::vector<Vid>& words,
182
+ const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
183
+ const std::vector<std::string>& labels) const override
184
+ {
185
+ auto doc = this->_makeRawDoc(rawStr, words, pos, len);
186
+ return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, labels));
187
+ }
188
+
189
+ std::vector<Float> getTopicsByDoc(const _DocType& doc) const
190
+ {
191
+ std::vector<Float> ret(this->K);
192
+ auto maskedAlphas = this->alphas.array() * doc.labelMask.template cast<Float>().array();
193
+ Eigen::Map<Eigen::Matrix<Float, -1, 1>> { ret.data(), this->K }.array() =
194
+ (doc.numByTopic.array().template cast<Float>() + maskedAlphas)
195
+ / (doc.getSumWordWeight() + maskedAlphas.sum());
196
+ return ret;
197
+ }
198
+
199
+ const Dictionary& getTopicLabelDict() const override { return topicLabelDict; } // read-only access to the label dictionary member
200
+
201
+ size_t getNumTopicsPerLabel() const override { return 1; } // this model always reports one topic per label
202
+ };
203
+ }
@@ -0,0 +1,63 @@
1
+ #pragma once
2
+ #include "LDA.h"
3
+
4
+ namespace tomoto
5
+ {
6
+ template<TermWeight _tw>
7
+ struct DocumentMGLDA : public DocumentLDA<_tw> // LDA document extended with the sentence / window bookkeeping used by MG-LDA
8
+ {
9
+ using BaseDocument = DocumentLDA<_tw>;
10
+ using DocumentLDA<_tw>::DocumentLDA;
11
+ using WeightType = typename DocumentLDA<_tw>::WeightType;
12
+
13
+ std::vector<uint16_t> sents; // sentence id of each word (const)
14
+ std::vector<WeightType> numBySent; // per-sentence word count (word weights are summed when term weighting is used) (const)
15
+
16
+ //std::vector<Tid> Zs; // global/local flag and topic assignment
17
+ std::vector<uint8_t> Vs; // window assignment
18
+ WeightType numGl = 0; // number of words assigned to global topics
19
+ //std::vector<uint32_t> numByTopic; // len = K + KL
20
+ Eigen::Matrix<WeightType, -1, -1> numBySentWin; // S x T matrix: words per (sentence, window offset)
21
+ Eigen::Matrix<WeightType, -1, 1> numByWinL; // number of words assigned to local topics in the window (len = S + T - 1)
22
+ Eigen::Matrix<WeightType, -1, 1> numByWin; // number of words in the window (len = S + T - 1)
23
+ Eigen::Matrix<WeightType, -1, -1> numByWinTopicL; // number of words per local topic in the window (KL x (S + T - 1))
24
+
25
+ DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 0, sents, Vs, numGl, numBySentWin, numByWinL, numByWin, numByWinTopicL);
26
+ DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 1, 0x00010001, sents, Vs, numGl, numBySentWin, numByWinL, numByWin, numByWinTopicL);
27
+
28
+ template<typename _TopicModel> void update(WeightType* ptr, const _TopicModel& mdl); // rebuilds numByTopic and numBySent; defined in MGLDAModel.hpp
29
+ };
30
+
31
+ class IMGLDAModel : public ILDAModel // abstract interface of the MG-LDA model: ILDAModel plus sentence-delimited document entry points
32
+ {
33
+ public:
34
+ using DefaultDocType = DocumentMGLDA<TermWeight::one>;
35
+ static IMGLDAModel* create(TermWeight _weight, size_t _KG = 1, size_t _KL = 1, size_t _T = 3, // factory; returns a raw pointer (ownership is not expressed in the type)
36
+ Float _alphaG = 0.1, Float _alphaL = 0.1, Float _alphaMG = 0.1, Float _alphaML = 0.1,
37
+ Float _etaG = 0.01, Float _etaL = 0.01, Float _gamma = 0.1, size_t seed = std::random_device{}(), // seed defaults to a fresh std::random_device draw
38
+ bool scalarRng = false);
39
+
40
+ virtual size_t addDoc(const std::vector<std::string>& words, const std::string& delimiter) = 0; // `delimiter` is the token that separates sentences
41
+ virtual std::unique_ptr<DocumentBase> makeDoc(const std::vector<std::string>& words, const std::string& delimiter) const = 0;
42
+
43
+ virtual size_t addDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer, // raw text + tokenizer factory
44
+ const std::string& delimiter) = 0;
45
+ virtual std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
46
+ const std::string& delimiter) const = 0;
47
+
48
+ virtual size_t addDoc(const std::string& rawStr, const std::vector<Vid>& words, // raw text + already-resolved word ids
49
+ const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
50
+ const std::string& delimiter) = 0;
51
+ virtual std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const std::vector<Vid>& words,
52
+ const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
53
+ const std::string& delimiter) const = 0;
54
+
55
+ virtual size_t getKL() const = 0; // number of local topics
56
+ virtual size_t getT() const = 0; // window size
57
+ virtual Float getAlphaL() const = 0;
58
+ virtual Float getEtaL() const = 0;
59
+ virtual Float getGamma() const = 0;
60
+ virtual Float getAlphaM() const = 0;
61
+ virtual Float getAlphaML() const = 0;
62
+ };
63
+ }
@@ -0,0 +1,17 @@
1
+ #include "MGLDAModel.hpp"
2
+
3
+ namespace tomoto
4
+ {
5
+ /*template class MGLDAModel<TermWeight::one>;
6
+ template class MGLDAModel<TermWeight::idf>;
7
+ template class MGLDAModel<TermWeight::pmi>;*/
8
+
9
+ IMGLDAModel* IMGLDAModel::create(TermWeight _weight, size_t _KG, size_t _KL, size_t _T, // factory declared in IMGLDAModel
10
+ Float _alphaG, Float _alphaL, Float _alphaMG, Float _alphaML,
11
+ Float _etaG, Float _etaL, Float _gamma, size_t seed, bool scalarRng)
12
+ {
13
+ TMT_SWITCH_TW(_weight, scalarRng, MGLDAModel, _KG, _KL, _T, // macro defined elsewhere; presumably picks the MGLDAModel instantiation for _weight / scalarRng and returns a new instance — TODO confirm
14
+ _alphaG, _alphaL, _alphaMG, _alphaML,
15
+ _etaG, _etaL, _gamma, seed);
16
+ }
17
+ }
@@ -0,0 +1,558 @@
1
+ #pragma once
2
+ #include "LDAModel.hpp"
3
+ #include "MGLDA.h"
4
+ /*
5
+ Implementation of MG-LDA using Gibbs sampling by bab2min
6
+ Improved version of the Java implementation (https://github.com/yinfeiy/MG-LDA)
7
+
8
+ * Titov, I., & McDonald, R. (2008, April). Modeling online reviews with multi-grain topic models. In Proceedings of the 17th international conference on World Wide Web (pp. 111-120). ACM.
9
+
10
+ */
11
+
12
+ namespace tomoto
13
+ {
14
+ template<TermWeight _tw, typename _RandGen,
15
+ typename _Interface = IMGLDAModel,
16
+ typename _Derived = void,
17
+ typename _DocType = DocumentMGLDA<_tw>,
18
+ typename _ModelState = ModelStateLDA<_tw>>
19
+ class MGLDAModel : public LDAModel<_tw, _RandGen, flags::partitioned_multisampling, _Interface,
20
+ typename std::conditional<std::is_same<_Derived, void>::value, MGLDAModel<_tw, _RandGen>, _Derived>::type,
21
+ _DocType, _ModelState>
22
+ {
23
+ protected:
24
+ using DerivedClass = typename std::conditional<std::is_same<_Derived, void>::value, MGLDAModel<_tw, _RandGen>, _Derived>::type;
25
+ using BaseClass = LDAModel<_tw, _RandGen, flags::partitioned_multisampling, _Interface, DerivedClass, _DocType, _ModelState>;
26
+ friend BaseClass;
27
+ friend typename BaseClass::BaseClass;
28
+ using WeightType = typename BaseClass::WeightType;
29
+
30
+ Float alphaL;
31
+ Float alphaM, alphaML;
32
+ Float etaL;
33
+ Float gamma;
34
+ Tid KL;
35
+ uint32_t T; // window size
36
+
37
+ // window and gl./loc. and topic assignment likelihoods for new word. ret T*(K+KL) FLOATs
38
+ Float* getVZLikelihoods(_ModelState& ld, const _DocType& doc, Vid vid, uint16_t s) const
39
+ {
40
+ const auto V = this->realV;
41
+ const auto K = this->K;
42
+ const auto alpha = this->alpha;
43
+ const auto eta = this->eta;
44
+ assert(vid < V);
45
+ auto& zLikelihood = ld.zLikelihood;
46
+ for (size_t v = 0; v < T; ++v)
47
+ {
48
+ Float pLoc = (doc.numByWinL[s + v] + alphaML) / (doc.numByWin[s + v] + alphaM + alphaML);
49
+ Float pW = doc.numBySentWin(s, v) + gamma;
50
+ if (K)
51
+ {
52
+ zLikelihood.segment(v * (K + KL), K) = (1 - pLoc) * pW
53
+ * (doc.numByTopic.segment(0, K).array().template cast<Float>() + alpha) / (doc.numGl + K * alpha)
54
+ * (ld.numByTopicWord.block(0, vid, K, 1).array().template cast<Float>() + eta) / (ld.numByTopic.segment(0, K).array().template cast<Float>() + V * eta);
55
+ }
56
+ zLikelihood.segment(v * (K + KL) + K, KL) = pLoc * pW
57
+ * (doc.numByWinTopicL.col(s + v).array().template cast<Float>()) / (doc.numByWinL[s + v] + KL * alphaL)
58
+ * (ld.numByTopicWord.block(K, vid, KL, 1).array().template cast<Float>() + etaL) / (ld.numByTopic.segment(K, KL).array().template cast<Float>() + V * etaL);
59
+ }
60
+
61
+ sample::prefixSum(zLikelihood.data(), T * (K + KL));
62
+ return &zLikelihood[0];
63
+ }
64
+
65
+ template<int _inc>
66
+ inline void addWordTo(_ModelState& ld, _DocType& doc, uint32_t pid, Vid vid, Tid tid, uint16_t s, uint8_t w, uint8_t r) const
67
+ {
68
+ const auto K = this->K;
69
+
70
+ assert(r != 0 || tid < K);
71
+ assert(r == 0 || tid < KL);
72
+ assert(w < T);
73
+ assert(r < 2);
74
+ assert(vid < this->realV);
75
+ assert(s < doc.numBySent.size());
76
+
77
+ constexpr bool _dec = _inc < 0 && _tw != TermWeight::one;
78
+ typename std::conditional<_tw != TermWeight::one, float, int32_t>::type weight
79
+ = _tw != TermWeight::one ? doc.wordWeights[pid] : 1;
80
+
81
+ updateCnt<_dec>(doc.numByWin[s + w], _inc * weight);
82
+ updateCnt<_dec>(doc.numBySentWin(s, w), _inc * weight);
83
+ if (r == 0)
84
+ {
85
+ updateCnt<_dec>(doc.numByTopic[tid], _inc * weight);
86
+ updateCnt<_dec>(doc.numGl, _inc * weight);
87
+ updateCnt<_dec>(ld.numByTopic[tid], _inc * weight);
88
+ updateCnt<_dec>(ld.numByTopicWord(tid, vid), _inc * weight);
89
+ }
90
+ else
91
+ {
92
+ updateCnt<_dec>(doc.numByTopic[tid + K], _inc * weight);
93
+ updateCnt<_dec>(doc.numByWinL[s + w], _inc * weight);
94
+ updateCnt<_dec>(doc.numByWinTopicL(tid, s + w), _inc * weight);
95
+ updateCnt<_dec>(ld.numByTopic[tid + K], _inc * weight);
96
+ updateCnt<_dec>(ld.numByTopicWord(tid + K, vid), _inc * weight);
97
+ }
98
+ }
99
+
100
+ template<ParallelScheme _ps, bool _infer, typename _ExtraDocData>
101
+ void sampleDocument(_DocType& doc, const _ExtraDocData& edd, size_t docId, _ModelState& ld, _RandGen& rgs, size_t iterationCnt, size_t partitionId = 0) const
102
+ {
103
+ size_t b = 0, e = doc.words.size();
104
+ if (_ps == ParallelScheme::partition)
105
+ {
106
+ b = edd.chunkOffsetByDoc(partitionId, docId);
107
+ e = edd.chunkOffsetByDoc(partitionId + 1, docId);
108
+ }
109
+
110
+ size_t vOffset = (_ps == ParallelScheme::partition && partitionId) ? edd.vChunkOffset[partitionId - 1] : 0;
111
+
112
+ const auto K = this->K;
113
+ for (size_t w = b; w < e; ++w)
114
+ {
115
+ if (doc.words[w] >= this->realV) continue;
116
+ addWordTo<-1>(ld, doc, w, doc.words[w] - vOffset, doc.Zs[w] - (doc.Zs[w] < K ? 0 : K), doc.sents[w], doc.Vs[w], doc.Zs[w] < K ? 0 : 1);
117
+ auto dist = getVZLikelihoods(ld, doc, doc.words[w] - vOffset, doc.sents[w]);
118
+ auto vz = sample::sampleFromDiscreteAcc(dist, dist + T * (K + KL), rgs);
119
+ doc.Vs[w] = vz / (K + KL);
120
+ doc.Zs[w] = vz % (K + KL);
121
+ addWordTo<1>(ld, doc, w, doc.words[w] - vOffset, doc.Zs[w] - (doc.Zs[w] < K ? 0 : K), doc.sents[w], doc.Vs[w], doc.Zs[w] < K ? 0 : 1);
122
+ }
123
+ }
124
+
125
+ template<typename _DocIter>
126
+ double getLLDocs(_DocIter _first, _DocIter _last) const
127
+ {
128
+ const auto K = this->K;
129
+ const auto alpha = this->alpha;
130
+
131
+ size_t totSents = 0, totWins = 0;
132
+ double ll = 0;
133
+ if (K) ll += (math::lgammaT(K*alpha) - math::lgammaT(alpha)*K) * std::distance(_first, _last);
134
+ for (; _first != _last; ++_first)
135
+ {
136
+ auto& doc = *_first;
137
+ const size_t S = doc.numBySent.size();
138
+ if (K)
139
+ {
140
+ ll -= math::lgammaT(doc.numGl + K * alpha);
141
+ for (Tid k = 0; k < K; ++k)
142
+ {
143
+ ll += math::lgammaT(doc.numByTopic[k] + alpha);
144
+ }
145
+ }
146
+
147
+ for (size_t v = 0; v < S + T - 1; ++v)
148
+ {
149
+ ll -= math::lgammaT(doc.numByWinL[v] + KL * alphaL);
150
+ for (Tid k = 0; k < KL; ++k)
151
+ {
152
+ ll += math::lgammaT(doc.numByWinTopicL(k, v) + alphaL);
153
+ }
154
+ if (K)
155
+ {
156
+ ll += math::lgammaT(std::max((Float)doc.numByWin[v] - doc.numByWinL[v], (Float)0) + alphaM);
157
+ ll += math::lgammaT(doc.numByWinL[v] + alphaML);
158
+ ll -= math::lgammaT(doc.numByWin[v] + alphaM + alphaML);
159
+ }
160
+ }
161
+
162
+ totWins += S + T - 1;
163
+ totSents += S;
164
+ for (size_t s = 0; s < S; ++s)
165
+ {
166
+ ll -= math::lgammaT(doc.numBySent[s] + T * gamma);
167
+ for (size_t v = 0; v < T; ++v)
168
+ {
169
+ ll += math::lgammaT(doc.numBySentWin(s, v) + gamma);
170
+ }
171
+ }
172
+ }
173
+ ll += (math::lgammaT(KL*alphaL) - math::lgammaT(alphaL)*KL) * totWins;
174
+ if (K) ll += (math::lgammaT(alphaM + alphaML) - math::lgammaT(alphaM) - math::lgammaT(alphaML)) * totWins;
175
+ ll += (math::lgammaT(T * gamma) - math::lgammaT(gamma) * T) * totSents;
176
+
177
+ return ll;
178
+ }
179
+
180
+ double getLLRest(const _ModelState& ld) const
181
+ {
182
+ const auto V = this->realV;
183
+ const auto K = this->K;
184
+ const auto eta = this->eta;
185
+
186
+ double ll = 0;
187
+ ll += (math::lgammaT(V*eta) - math::lgammaT(eta)*V) * K;
188
+ for (Tid k = 0; k < K; ++k)
189
+ {
190
+ ll -= math::lgammaT(ld.numByTopic[k] + V * eta);
191
+ for (Vid w = 0; w < V; ++w)
192
+ {
193
+ ll += math::lgammaT(ld.numByTopicWord(k, w) + eta);
194
+ }
195
+ }
196
+ ll += (math::lgammaT(V*etaL) - math::lgammaT(etaL)*V) * KL;
197
+ for (Tid k = 0; k < KL; ++k)
198
+ {
199
+ ll -= math::lgammaT(ld.numByTopic[k + K] + V * etaL);
200
+ for (Vid w = 0; w < V; ++w)
201
+ {
202
+ ll += math::lgammaT(ld.numByTopicWord(k + K, w) + etaL);
203
+ }
204
+ }
205
+ return ll;
206
+ }
207
+
208
+ double getLL() const
209
+ {
210
+ double ll = 0;
211
+ const auto V = this->realV;
212
+ const auto K = this->K;
213
+ const auto alpha = this->alpha;
214
+ const auto eta = this->eta;
215
+ size_t totSents = 0, totWins = 0;
216
+ if(K) ll += (math::lgammaT(K*alpha) - math::lgammaT(alpha)*K) * this->docs.size();
217
+ for (size_t i = 0; i < this->docs.size(); ++i)
218
+ {
219
+ auto& doc = this->docs[i];
220
+ const size_t S = doc.numBySent.size();
221
+ if (K)
222
+ {
223
+ ll -= math::lgammaT(doc.numGl + K * alpha);
224
+ for (Tid k = 0; k < K; ++k)
225
+ {
226
+ ll += math::lgammaT(doc.numByTopic[k] + alpha);
227
+ }
228
+ }
229
+
230
+ for (size_t v = 0; v < S + T - 1; ++v)
231
+ {
232
+ ll -= math::lgammaT(doc.numByWinL[v] + KL * alphaL);
233
+ for (Tid k = 0; k < KL; ++k)
234
+ {
235
+ ll += math::lgammaT(doc.numByWinTopicL(k, v) + alphaL);
236
+ }
237
+ if (K)
238
+ {
239
+ ll += math::lgammaT(std::max((Float)doc.numByWin[v] - doc.numByWinL[v], (Float)0) + alphaM);
240
+ ll += math::lgammaT(doc.numByWinL[v] + alphaML);
241
+ ll -= math::lgammaT(doc.numByWin[v] + alphaM + alphaML);
242
+ }
243
+ }
244
+
245
+ totWins += S + T - 1;
246
+ totSents += S;
247
+ for (size_t s = 0; s < S; ++s)
248
+ {
249
+ ll -= math::lgammaT(doc.numBySent[s] + T * gamma);
250
+ for (size_t v = 0; v < T; ++v)
251
+ {
252
+ ll += math::lgammaT(doc.numBySentWin(s, v) + gamma);
253
+ }
254
+ }
255
+ }
256
+ ll += (math::lgammaT(KL*alphaL) - math::lgammaT(alphaL)*KL) * totWins;
257
+ if(K) ll += (math::lgammaT(alphaM + alphaML) - math::lgammaT(alphaM) - math::lgammaT(alphaML)) * totWins;
258
+ ll += (math::lgammaT(T * gamma) - math::lgammaT(gamma) * T) * totSents;
259
+
260
+ //
261
+ ll += (math::lgammaT(V*eta) - math::lgammaT(eta)*V) * K;
262
+ for (Tid k = 0; k < K; ++k)
263
+ {
264
+ ll -= math::lgammaT(this->globalState.numByTopic[k] + V * eta);
265
+ for (Vid w = 0; w < V; ++w)
266
+ {
267
+ ll += math::lgammaT(this->globalState.numByTopicWord(k, w) + eta);
268
+ }
269
+ }
270
+ ll += (math::lgammaT(V*etaL) - math::lgammaT(etaL)*V) * KL;
271
+ for (Tid k = 0; k < KL; ++k)
272
+ {
273
+ ll -= math::lgammaT(this->globalState.numByTopic[k + K] + V * etaL);
274
+ for (Vid w = 0; w < V; ++w)
275
+ {
276
+ ll += math::lgammaT(this->globalState.numByTopicWord(k + K, w) + etaL);
277
+ }
278
+ }
279
+
280
+ return ll;
281
+ }
282
+
283
+ void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
284
+ {
285
+ sortAndWriteOrder(doc.words, doc.wOrder);
286
+ auto tmp = doc.sents;
287
+ for (size_t i = 0; i < doc.wOrder.size(); ++i)
288
+ {
289
+ doc.sents[doc.wOrder[i]] = tmp[i];
290
+ }
291
+
292
+ const size_t S = doc.numBySent.size();
293
+ std::fill(doc.numBySent.begin(), doc.numBySent.end(), 0);
294
+ doc.Zs = tvector<Tid>(wordSize);
295
+ doc.Vs.resize(wordSize);
296
+ if (_tw != TermWeight::one) doc.wordWeights.resize(wordSize);
297
+ doc.numByTopic.init(nullptr, this->K + KL);
298
+ doc.numBySentWin = Eigen::Matrix<WeightType, -1, -1>::Zero(S, T);
299
+ doc.numByWin = Eigen::Matrix<WeightType, -1, 1>::Zero(S + T - 1);
300
+ doc.numByWinL = Eigen::Matrix<WeightType, -1, 1>::Zero(S + T - 1);
301
+ doc.numByWinTopicL = Eigen::Matrix<WeightType, -1, -1>::Zero(KL, S + T - 1);
302
+ }
303
+
304
+ void initGlobalState(bool initDocs)
305
+ {
306
+ const size_t V = this->realV;
307
+ this->globalState.zLikelihood = Eigen::Matrix<Float, -1, 1>::Zero(T * (this->K + KL));
308
+ if (initDocs)
309
+ {
310
+ this->globalState.numByTopic = Eigen::Matrix<WeightType, -1, 1>::Zero(this->K + KL);
311
+ this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K + KL, V);
312
+ }
313
+ }
314
+
315
+ struct Generator // random sources used by updateStateWithDoc for the initial assignment
316
+ {
317
+ std::discrete_distribution<uint16_t> pi; // global(0)/local(1) choice, weighted by alphaM / alphaML
318
+ std::uniform_int_distribution<Tid> theta; // global topic in [0, K - 1]
319
+ std::uniform_int_distribution<Tid> thetaL; // local topic in [0, KL - 1]
320
+ std::uniform_int_distribution<uint16_t> psi; // window offset in [0, T - 1]
321
+ };
322
+
323
+ Generator makeGeneratorForInit(const _DocType*) const
324
+ {
325
+ return Generator{ std::discrete_distribution<uint16_t>{ alphaM, alphaML },
326
+ std::uniform_int_distribution<Tid>{ 0, (Tid)(this->K - 1) },
327
+ std::uniform_int_distribution<Tid>{ 0, (Tid)(KL - 1) },
328
+ std::uniform_int_distribution<uint16_t>{ 0, (uint16_t)(T - 1) } };
329
+ }
330
+
331
+ template<bool _Infer>
332
+ void updateStateWithDoc(Generator& g, _ModelState& ld, _RandGen& rgs, _DocType& doc, size_t i) const
333
+ {
334
+ doc.numBySent[doc.sents[i]] += _tw == TermWeight::one ? 1 : doc.wordWeights[i];
335
+ auto w = doc.words[i];
336
+ size_t r, z;
337
+ if (this->etaByTopicWord.size())
338
+ {
339
+ Eigen::Array<Float, -1, 1> col = this->etaByTopicWord.col(w);
340
+ col.head(this->K) *= alphaM / this->K;
341
+ col.tail(this->KL) *= alphaML / this->KL;
342
+ doc.Zs[i] = z = sample::sampleFromDiscrete(col.data(), col.data() + col.size(), rgs);
343
+ r = z < this->K;
344
+ if (z >= this->K) z -= this->K;
345
+ }
346
+ else
347
+ {
348
+ r = g.pi(rgs);
349
+ z = (r ? g.thetaL : g.theta)(rgs);
350
+ doc.Zs[i] = z + (r ? this->K : 0);
351
+ }
352
+
353
+ auto& win = doc.Vs[i];
354
+ win = g.psi(rgs);
355
+ addWordTo<1>(ld, doc, i, w, z, doc.sents[i], win, r);
356
+ }
357
+
358
+ std::vector<uint64_t> _getTopicsCount() const
359
+ {
360
+ std::vector<uint64_t> cnt(this->K + KL);
361
+ for (auto& doc : this->docs)
362
+ {
363
+ for (size_t i = 0; i < doc.Zs.size(); ++i)
364
+ {
365
+ if (doc.words[i] < this->realV) ++cnt[doc.Zs[i]];
366
+ }
367
+ }
368
+ return cnt;
369
+ }
370
+
371
+ public:
372
+ DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, alphaL, alphaM, alphaML, etaL, gamma, KL, T);
373
+ DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, alphaL, alphaM, alphaML, etaL, gamma, KL, T);
374
+
375
+ MGLDAModel(size_t _KG = 1, size_t _KL = 1, size_t _T = 3,
376
+ Float _alphaG = 0.1, Float _alphaL = 0.1, Float _alphaMG = 0.1, Float _alphaML = 0.1,
377
+ Float _etaG = 0.01, Float _etaL = 0.01, Float _gamma = 0.1, size_t _rg = std::random_device{}())
378
+ : BaseClass(_KG, _alphaG, _etaG, _rg), KL(_KL), T(_T),
379
+ alphaL(_alphaL), alphaM(_KG ? _alphaMG : 0), alphaML(_alphaML),
380
+ etaL(_etaL), gamma(_gamma)
381
+ {
382
+ if (_KL == 0 || _KL >= 0x80000000) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong KL value (KL = %zd)", _KL));
383
+ if (_T == 0 || _T >= 0x80000000) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong T value (T = %zd)", _T));
384
+ if (_alphaL <= 0) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong alphaL value (alphaL = %f)", _alphaL));
385
+ if (_etaL <= 0) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong etaL value (etaL = %f)", _etaL));
386
+ }
387
+
388
+
389
+ template<bool _const = false>
390
+ _DocType _makeDoc(const std::vector<std::string>& words, const std::string& delimiter)
391
+ {
392
+ _DocType doc{ 1.f };
393
+ size_t numSent = 0;
394
+ for (auto& w : words)
395
+ {
396
+ if (w == delimiter)
397
+ {
398
+ ++numSent;
399
+ continue;
400
+ }
401
+
402
+ Vid id;
403
+ if (_const)
404
+ {
405
+ id = this->dict.toWid(w);
406
+ if (id == (Vid)-1) continue;
407
+ }
408
+ else
409
+ {
410
+ id = this->dict.add(w);
411
+ }
412
+ doc.words.emplace_back(id);
413
+ doc.sents.emplace_back(numSent);
414
+ }
415
+ doc.numBySent.resize(doc.sents.empty() ? 0 : (doc.sents.back() + 1));
416
+ return doc;
417
+ }
418
+
419
+ size_t addDoc(const std::vector<std::string>& words, const std::string& delimiter) override // builds the document with _makeDoc (dictionary-extending variant) and returns _addDoc's result
420
+ {
421
+ return this->_addDoc(_makeDoc(words, delimiter));
422
+ }
423
+
424
+ std::unique_ptr<DocumentBase> makeDoc(const std::vector<std::string>& words, const std::string& delimiter) const override // const variant: uses the lookup-only _makeDoc<true> and returns the document on the heap
425
+ {
426
+ return make_unique<_DocType>(as_mutable(this)->template _makeDoc<true>(words, delimiter)); // as_mutable is needed because _makeDoc is a non-const member
427
+ }
428
+
429
+ template<bool _const, typename _FnTokenizer>
430
+ _DocType _makeRawDoc(const std::string& rawStr, _FnTokenizer&& tokenizer, const std::string& delimiter)
431
+ {
432
+ _DocType doc{ 1.f };
433
+ size_t numSent = 0;
434
+ doc.rawStr = rawStr;
435
+ for (auto& p : tokenizer(doc.rawStr))
436
+ {
437
+ if (std::get<0>(p) == delimiter)
438
+ {
439
+ ++numSent;
440
+ continue;
441
+ }
442
+
443
+ Vid wid;
444
+ if (_const)
445
+ {
446
+ wid = this->dict.toWid(std::get<0>(p));
447
+ if (wid == (Vid)-1) continue;
448
+ }
449
+ else
450
+ {
451
+ wid = this->dict.add(std::get<0>(p));
452
+ }
453
+ auto pos = std::get<1>(p);
454
+ auto len = std::get<2>(p);
455
+ doc.words.emplace_back(wid);
456
+ doc.sents.emplace_back(numSent);
457
+ doc.origWordPos.emplace_back(pos);
458
+ doc.origWordLen.emplace_back(len);
459
+ }
460
+ doc.numBySent.resize(doc.sents.empty() ? 0 : (doc.sents.back() + 1));
461
+ return doc;
462
+ }
463
+
464
+ size_t addDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
465
+ const std::string& delimiter)
466
+ {
467
+ return this->_addDoc(_makeRawDoc<false>(rawStr, tokenizer, delimiter));
468
+ }
469
+
470
+ std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
471
+ const std::string& delimiter) const
472
+ {
473
+ return make_unique<_DocType>(as_mutable(this)->template _makeRawDoc<true>(rawStr, tokenizer, delimiter));
474
+ }
475
+
476
+ _DocType _makeRawDoc(const std::string& rawStr, const std::vector<Vid>& words,
477
+ const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len, const std::string& delimiter) const
478
+ {
479
+ _DocType doc{ 1.f };
480
+ doc.rawStr = rawStr;
481
+ size_t numSent = 0;
482
+ Vid delimiterId = this->dict.toWid(delimiter);
483
+ for (size_t i = 0; i < words.size(); ++i)
484
+ {
485
+ auto& w = words[i];
486
+ if (w == delimiterId)
487
+ {
488
+ ++numSent;
489
+ continue;
490
+ }
491
+ doc.words.emplace_back(w);
492
+ doc.sents.emplace_back(numSent);
493
+ if (words.size() == pos.size())
494
+ {
495
+ doc.origWordPos.emplace_back(pos[i]);
496
+ doc.origWordLen.emplace_back(len[i]);
497
+ }
498
+ }
499
+ doc.numBySent.resize(doc.sents.empty() ? 0 : (doc.sents.back() + 1));
500
+ return doc;
501
+ }
502
+
503
+ size_t addDoc(const std::string& rawStr, const std::vector<Vid>& words,
504
+ const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
505
+ const std::string& delimiter)
506
+ {
507
+ return this->_addDoc(_makeRawDoc(rawStr, words, pos, len, delimiter));
508
+ }
509
+
510
+ std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const std::vector<Vid>& words,
511
+ const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
512
+ const std::string& delimiter) const
513
+ {
514
+ return make_unique<_DocType>(_makeRawDoc(rawStr, words, pos, len, delimiter));
515
+ }
516
+
517
+ void setWordPrior(const std::string& word, const std::vector<Float>& priors) override
518
+ {
519
+ if (priors.size() != this->K + KL) THROW_ERROR_WITH_INFO(exception::InvalidArgument, "priors.size() must be equal to K.");
520
+ for (auto p : priors)
521
+ {
522
+ if (p < 0) THROW_ERROR_WITH_INFO(exception::InvalidArgument, "priors must not be less than 0.");
523
+ }
524
+ this->dict.add(word);
525
+ this->etaByWord.emplace(word, priors);
526
+ }
527
+
528
+ std::vector<Float> getTopicsByDoc(const _DocType& doc) const
529
+ {
530
+ std::vector<Float> ret(this->K + KL);
531
+ Eigen::Map<Eigen::Matrix<Float, -1, 1>> { ret.data(), this->K + KL }.array() =
532
+ doc.numByTopic.array().template cast<Float>() / doc.getSumWordWeight();
533
+ return ret;
534
+ }
535
+
536
+ GETTER(KL, size_t, KL);
537
+ GETTER(T, size_t, T);
538
+ GETTER(Gamma, Float, gamma);
539
+ GETTER(AlphaL, Float, alphaL);
540
+ GETTER(EtaL, Float, etaL);
541
+ GETTER(AlphaM, Float, alphaM);
542
+ GETTER(AlphaML, Float, alphaML);
543
+ };
544
+
545
+ template<TermWeight _tw>
546
+ template<typename _TopicModel>
547
+ void DocumentMGLDA<_tw>::update(WeightType * ptr, const _TopicModel & mdl)
548
+ {
549
+ this->numByTopic.init(ptr, mdl.getK() + mdl.getKL());
550
+ numBySent.resize(*std::max_element(sents.begin(), sents.end()) + 1);
551
+ for (size_t i = 0; i < this->Zs.size(); ++i)
552
+ {
553
+ if (this->words[i] >= mdl.getV()) continue;
554
+ this->numByTopic[this->Zs[i]] += _tw != TermWeight::one ? this->wordWeights[i] : 1;
555
+ numBySent[sents[i]] += _tw != TermWeight::one ? this->wordWeights[i] : 1;
556
+ }
557
+ }
558
+ }