tomoto 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (420)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +3 -0
  3. data/LICENSE.txt +22 -0
  4. data/README.md +123 -0
  5. data/ext/tomoto/ext.cpp +245 -0
  6. data/ext/tomoto/extconf.rb +28 -0
  7. data/lib/tomoto.rb +12 -0
  8. data/lib/tomoto/ct.rb +11 -0
  9. data/lib/tomoto/hdp.rb +11 -0
  10. data/lib/tomoto/lda.rb +67 -0
  11. data/lib/tomoto/version.rb +3 -0
  12. data/vendor/EigenRand/EigenRand/Core.h +1139 -0
  13. data/vendor/EigenRand/EigenRand/Dists/Basic.h +111 -0
  14. data/vendor/EigenRand/EigenRand/Dists/Discrete.h +877 -0
  15. data/vendor/EigenRand/EigenRand/Dists/GammaPoisson.h +108 -0
  16. data/vendor/EigenRand/EigenRand/Dists/NormalExp.h +626 -0
  17. data/vendor/EigenRand/EigenRand/EigenRand +19 -0
  18. data/vendor/EigenRand/EigenRand/Macro.h +24 -0
  19. data/vendor/EigenRand/EigenRand/MorePacketMath.h +978 -0
  20. data/vendor/EigenRand/EigenRand/PacketFilter.h +286 -0
  21. data/vendor/EigenRand/EigenRand/PacketRandomEngine.h +624 -0
  22. data/vendor/EigenRand/EigenRand/RandUtils.h +413 -0
  23. data/vendor/EigenRand/EigenRand/doc.h +220 -0
  24. data/vendor/EigenRand/LICENSE +21 -0
  25. data/vendor/EigenRand/README.md +288 -0
  26. data/vendor/eigen/COPYING.BSD +26 -0
  27. data/vendor/eigen/COPYING.GPL +674 -0
  28. data/vendor/eigen/COPYING.LGPL +502 -0
  29. data/vendor/eigen/COPYING.MINPACK +52 -0
  30. data/vendor/eigen/COPYING.MPL2 +373 -0
  31. data/vendor/eigen/COPYING.README +18 -0
  32. data/vendor/eigen/Eigen/CMakeLists.txt +19 -0
  33. data/vendor/eigen/Eigen/Cholesky +46 -0
  34. data/vendor/eigen/Eigen/CholmodSupport +48 -0
  35. data/vendor/eigen/Eigen/Core +537 -0
  36. data/vendor/eigen/Eigen/Dense +7 -0
  37. data/vendor/eigen/Eigen/Eigen +2 -0
  38. data/vendor/eigen/Eigen/Eigenvalues +61 -0
  39. data/vendor/eigen/Eigen/Geometry +62 -0
  40. data/vendor/eigen/Eigen/Householder +30 -0
  41. data/vendor/eigen/Eigen/IterativeLinearSolvers +48 -0
  42. data/vendor/eigen/Eigen/Jacobi +33 -0
  43. data/vendor/eigen/Eigen/LU +50 -0
  44. data/vendor/eigen/Eigen/MetisSupport +35 -0
  45. data/vendor/eigen/Eigen/OrderingMethods +73 -0
  46. data/vendor/eigen/Eigen/PaStiXSupport +48 -0
  47. data/vendor/eigen/Eigen/PardisoSupport +35 -0
  48. data/vendor/eigen/Eigen/QR +51 -0
  49. data/vendor/eigen/Eigen/QtAlignedMalloc +40 -0
  50. data/vendor/eigen/Eigen/SPQRSupport +34 -0
  51. data/vendor/eigen/Eigen/SVD +51 -0
  52. data/vendor/eigen/Eigen/Sparse +36 -0
  53. data/vendor/eigen/Eigen/SparseCholesky +45 -0
  54. data/vendor/eigen/Eigen/SparseCore +69 -0
  55. data/vendor/eigen/Eigen/SparseLU +46 -0
  56. data/vendor/eigen/Eigen/SparseQR +37 -0
  57. data/vendor/eigen/Eigen/StdDeque +27 -0
  58. data/vendor/eigen/Eigen/StdList +26 -0
  59. data/vendor/eigen/Eigen/StdVector +27 -0
  60. data/vendor/eigen/Eigen/SuperLUSupport +64 -0
  61. data/vendor/eigen/Eigen/UmfPackSupport +40 -0
  62. data/vendor/eigen/Eigen/src/Cholesky/LDLT.h +673 -0
  63. data/vendor/eigen/Eigen/src/Cholesky/LLT.h +542 -0
  64. data/vendor/eigen/Eigen/src/Cholesky/LLT_LAPACKE.h +99 -0
  65. data/vendor/eigen/Eigen/src/CholmodSupport/CholmodSupport.h +639 -0
  66. data/vendor/eigen/Eigen/src/Core/Array.h +329 -0
  67. data/vendor/eigen/Eigen/src/Core/ArrayBase.h +226 -0
  68. data/vendor/eigen/Eigen/src/Core/ArrayWrapper.h +209 -0
  69. data/vendor/eigen/Eigen/src/Core/Assign.h +90 -0
  70. data/vendor/eigen/Eigen/src/Core/AssignEvaluator.h +935 -0
  71. data/vendor/eigen/Eigen/src/Core/Assign_MKL.h +178 -0
  72. data/vendor/eigen/Eigen/src/Core/BandMatrix.h +353 -0
  73. data/vendor/eigen/Eigen/src/Core/Block.h +452 -0
  74. data/vendor/eigen/Eigen/src/Core/BooleanRedux.h +164 -0
  75. data/vendor/eigen/Eigen/src/Core/CommaInitializer.h +160 -0
  76. data/vendor/eigen/Eigen/src/Core/ConditionEstimator.h +175 -0
  77. data/vendor/eigen/Eigen/src/Core/CoreEvaluators.h +1688 -0
  78. data/vendor/eigen/Eigen/src/Core/CoreIterators.h +127 -0
  79. data/vendor/eigen/Eigen/src/Core/CwiseBinaryOp.h +184 -0
  80. data/vendor/eigen/Eigen/src/Core/CwiseNullaryOp.h +866 -0
  81. data/vendor/eigen/Eigen/src/Core/CwiseTernaryOp.h +197 -0
  82. data/vendor/eigen/Eigen/src/Core/CwiseUnaryOp.h +103 -0
  83. data/vendor/eigen/Eigen/src/Core/CwiseUnaryView.h +128 -0
  84. data/vendor/eigen/Eigen/src/Core/DenseBase.h +611 -0
  85. data/vendor/eigen/Eigen/src/Core/DenseCoeffsBase.h +681 -0
  86. data/vendor/eigen/Eigen/src/Core/DenseStorage.h +570 -0
  87. data/vendor/eigen/Eigen/src/Core/Diagonal.h +260 -0
  88. data/vendor/eigen/Eigen/src/Core/DiagonalMatrix.h +343 -0
  89. data/vendor/eigen/Eigen/src/Core/DiagonalProduct.h +28 -0
  90. data/vendor/eigen/Eigen/src/Core/Dot.h +318 -0
  91. data/vendor/eigen/Eigen/src/Core/EigenBase.h +159 -0
  92. data/vendor/eigen/Eigen/src/Core/ForceAlignedAccess.h +146 -0
  93. data/vendor/eigen/Eigen/src/Core/Fuzzy.h +155 -0
  94. data/vendor/eigen/Eigen/src/Core/GeneralProduct.h +455 -0
  95. data/vendor/eigen/Eigen/src/Core/GenericPacketMath.h +593 -0
  96. data/vendor/eigen/Eigen/src/Core/GlobalFunctions.h +187 -0
  97. data/vendor/eigen/Eigen/src/Core/IO.h +225 -0
  98. data/vendor/eigen/Eigen/src/Core/Inverse.h +118 -0
  99. data/vendor/eigen/Eigen/src/Core/Map.h +171 -0
  100. data/vendor/eigen/Eigen/src/Core/MapBase.h +303 -0
  101. data/vendor/eigen/Eigen/src/Core/MathFunctions.h +1415 -0
  102. data/vendor/eigen/Eigen/src/Core/MathFunctionsImpl.h +101 -0
  103. data/vendor/eigen/Eigen/src/Core/Matrix.h +459 -0
  104. data/vendor/eigen/Eigen/src/Core/MatrixBase.h +529 -0
  105. data/vendor/eigen/Eigen/src/Core/NestByValue.h +110 -0
  106. data/vendor/eigen/Eigen/src/Core/NoAlias.h +108 -0
  107. data/vendor/eigen/Eigen/src/Core/NumTraits.h +248 -0
  108. data/vendor/eigen/Eigen/src/Core/PermutationMatrix.h +633 -0
  109. data/vendor/eigen/Eigen/src/Core/PlainObjectBase.h +1035 -0
  110. data/vendor/eigen/Eigen/src/Core/Product.h +186 -0
  111. data/vendor/eigen/Eigen/src/Core/ProductEvaluators.h +1112 -0
  112. data/vendor/eigen/Eigen/src/Core/Random.h +182 -0
  113. data/vendor/eigen/Eigen/src/Core/Redux.h +505 -0
  114. data/vendor/eigen/Eigen/src/Core/Ref.h +283 -0
  115. data/vendor/eigen/Eigen/src/Core/Replicate.h +142 -0
  116. data/vendor/eigen/Eigen/src/Core/ReturnByValue.h +117 -0
  117. data/vendor/eigen/Eigen/src/Core/Reverse.h +211 -0
  118. data/vendor/eigen/Eigen/src/Core/Select.h +162 -0
  119. data/vendor/eigen/Eigen/src/Core/SelfAdjointView.h +352 -0
  120. data/vendor/eigen/Eigen/src/Core/SelfCwiseBinaryOp.h +47 -0
  121. data/vendor/eigen/Eigen/src/Core/Solve.h +188 -0
  122. data/vendor/eigen/Eigen/src/Core/SolveTriangular.h +235 -0
  123. data/vendor/eigen/Eigen/src/Core/SolverBase.h +130 -0
  124. data/vendor/eigen/Eigen/src/Core/StableNorm.h +221 -0
  125. data/vendor/eigen/Eigen/src/Core/Stride.h +111 -0
  126. data/vendor/eigen/Eigen/src/Core/Swap.h +67 -0
  127. data/vendor/eigen/Eigen/src/Core/Transpose.h +403 -0
  128. data/vendor/eigen/Eigen/src/Core/Transpositions.h +407 -0
  129. data/vendor/eigen/Eigen/src/Core/TriangularMatrix.h +983 -0
  130. data/vendor/eigen/Eigen/src/Core/VectorBlock.h +96 -0
  131. data/vendor/eigen/Eigen/src/Core/VectorwiseOp.h +695 -0
  132. data/vendor/eigen/Eigen/src/Core/Visitor.h +273 -0
  133. data/vendor/eigen/Eigen/src/Core/arch/AVX/Complex.h +451 -0
  134. data/vendor/eigen/Eigen/src/Core/arch/AVX/MathFunctions.h +439 -0
  135. data/vendor/eigen/Eigen/src/Core/arch/AVX/PacketMath.h +637 -0
  136. data/vendor/eigen/Eigen/src/Core/arch/AVX/TypeCasting.h +51 -0
  137. data/vendor/eigen/Eigen/src/Core/arch/AVX512/MathFunctions.h +391 -0
  138. data/vendor/eigen/Eigen/src/Core/arch/AVX512/PacketMath.h +1316 -0
  139. data/vendor/eigen/Eigen/src/Core/arch/AltiVec/Complex.h +430 -0
  140. data/vendor/eigen/Eigen/src/Core/arch/AltiVec/MathFunctions.h +322 -0
  141. data/vendor/eigen/Eigen/src/Core/arch/AltiVec/PacketMath.h +1061 -0
  142. data/vendor/eigen/Eigen/src/Core/arch/CUDA/Complex.h +103 -0
  143. data/vendor/eigen/Eigen/src/Core/arch/CUDA/Half.h +674 -0
  144. data/vendor/eigen/Eigen/src/Core/arch/CUDA/MathFunctions.h +91 -0
  145. data/vendor/eigen/Eigen/src/Core/arch/CUDA/PacketMath.h +333 -0
  146. data/vendor/eigen/Eigen/src/Core/arch/CUDA/PacketMathHalf.h +1124 -0
  147. data/vendor/eigen/Eigen/src/Core/arch/CUDA/TypeCasting.h +212 -0
  148. data/vendor/eigen/Eigen/src/Core/arch/Default/ConjHelper.h +29 -0
  149. data/vendor/eigen/Eigen/src/Core/arch/Default/Settings.h +49 -0
  150. data/vendor/eigen/Eigen/src/Core/arch/NEON/Complex.h +490 -0
  151. data/vendor/eigen/Eigen/src/Core/arch/NEON/MathFunctions.h +91 -0
  152. data/vendor/eigen/Eigen/src/Core/arch/NEON/PacketMath.h +760 -0
  153. data/vendor/eigen/Eigen/src/Core/arch/SSE/Complex.h +471 -0
  154. data/vendor/eigen/Eigen/src/Core/arch/SSE/MathFunctions.h +562 -0
  155. data/vendor/eigen/Eigen/src/Core/arch/SSE/PacketMath.h +895 -0
  156. data/vendor/eigen/Eigen/src/Core/arch/SSE/TypeCasting.h +77 -0
  157. data/vendor/eigen/Eigen/src/Core/arch/ZVector/Complex.h +397 -0
  158. data/vendor/eigen/Eigen/src/Core/arch/ZVector/MathFunctions.h +137 -0
  159. data/vendor/eigen/Eigen/src/Core/arch/ZVector/PacketMath.h +945 -0
  160. data/vendor/eigen/Eigen/src/Core/functors/AssignmentFunctors.h +168 -0
  161. data/vendor/eigen/Eigen/src/Core/functors/BinaryFunctors.h +475 -0
  162. data/vendor/eigen/Eigen/src/Core/functors/NullaryFunctors.h +188 -0
  163. data/vendor/eigen/Eigen/src/Core/functors/StlFunctors.h +136 -0
  164. data/vendor/eigen/Eigen/src/Core/functors/TernaryFunctors.h +25 -0
  165. data/vendor/eigen/Eigen/src/Core/functors/UnaryFunctors.h +792 -0
  166. data/vendor/eigen/Eigen/src/Core/products/GeneralBlockPanelKernel.h +2156 -0
  167. data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrix.h +492 -0
  168. data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h +311 -0
  169. data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h +145 -0
  170. data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h +122 -0
  171. data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixVector.h +619 -0
  172. data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixVector_BLAS.h +136 -0
  173. data/vendor/eigen/Eigen/src/Core/products/Parallelizer.h +163 -0
  174. data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix.h +521 -0
  175. data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix_BLAS.h +287 -0
  176. data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixVector.h +260 -0
  177. data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixVector_BLAS.h +118 -0
  178. data/vendor/eigen/Eigen/src/Core/products/SelfadjointProduct.h +133 -0
  179. data/vendor/eigen/Eigen/src/Core/products/SelfadjointRank2Update.h +93 -0
  180. data/vendor/eigen/Eigen/src/Core/products/TriangularMatrixMatrix.h +466 -0
  181. data/vendor/eigen/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h +315 -0
  182. data/vendor/eigen/Eigen/src/Core/products/TriangularMatrixVector.h +350 -0
  183. data/vendor/eigen/Eigen/src/Core/products/TriangularMatrixVector_BLAS.h +255 -0
  184. data/vendor/eigen/Eigen/src/Core/products/TriangularSolverMatrix.h +335 -0
  185. data/vendor/eigen/Eigen/src/Core/products/TriangularSolverMatrix_BLAS.h +163 -0
  186. data/vendor/eigen/Eigen/src/Core/products/TriangularSolverVector.h +145 -0
  187. data/vendor/eigen/Eigen/src/Core/util/BlasUtil.h +398 -0
  188. data/vendor/eigen/Eigen/src/Core/util/Constants.h +547 -0
  189. data/vendor/eigen/Eigen/src/Core/util/DisableStupidWarnings.h +83 -0
  190. data/vendor/eigen/Eigen/src/Core/util/ForwardDeclarations.h +302 -0
  191. data/vendor/eigen/Eigen/src/Core/util/MKL_support.h +130 -0
  192. data/vendor/eigen/Eigen/src/Core/util/Macros.h +1001 -0
  193. data/vendor/eigen/Eigen/src/Core/util/Memory.h +993 -0
  194. data/vendor/eigen/Eigen/src/Core/util/Meta.h +534 -0
  195. data/vendor/eigen/Eigen/src/Core/util/NonMPL2.h +3 -0
  196. data/vendor/eigen/Eigen/src/Core/util/ReenableStupidWarnings.h +27 -0
  197. data/vendor/eigen/Eigen/src/Core/util/StaticAssert.h +218 -0
  198. data/vendor/eigen/Eigen/src/Core/util/XprHelper.h +821 -0
  199. data/vendor/eigen/Eigen/src/Eigenvalues/ComplexEigenSolver.h +346 -0
  200. data/vendor/eigen/Eigen/src/Eigenvalues/ComplexSchur.h +459 -0
  201. data/vendor/eigen/Eigen/src/Eigenvalues/ComplexSchur_LAPACKE.h +91 -0
  202. data/vendor/eigen/Eigen/src/Eigenvalues/EigenSolver.h +622 -0
  203. data/vendor/eigen/Eigen/src/Eigenvalues/GeneralizedEigenSolver.h +418 -0
  204. data/vendor/eigen/Eigen/src/Eigenvalues/GeneralizedSelfAdjointEigenSolver.h +226 -0
  205. data/vendor/eigen/Eigen/src/Eigenvalues/HessenbergDecomposition.h +374 -0
  206. data/vendor/eigen/Eigen/src/Eigenvalues/MatrixBaseEigenvalues.h +158 -0
  207. data/vendor/eigen/Eigen/src/Eigenvalues/RealQZ.h +654 -0
  208. data/vendor/eigen/Eigen/src/Eigenvalues/RealSchur.h +546 -0
  209. data/vendor/eigen/Eigen/src/Eigenvalues/RealSchur_LAPACKE.h +77 -0
  210. data/vendor/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver.h +870 -0
  211. data/vendor/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver_LAPACKE.h +87 -0
  212. data/vendor/eigen/Eigen/src/Eigenvalues/Tridiagonalization.h +556 -0
  213. data/vendor/eigen/Eigen/src/Geometry/AlignedBox.h +392 -0
  214. data/vendor/eigen/Eigen/src/Geometry/AngleAxis.h +247 -0
  215. data/vendor/eigen/Eigen/src/Geometry/EulerAngles.h +114 -0
  216. data/vendor/eigen/Eigen/src/Geometry/Homogeneous.h +497 -0
  217. data/vendor/eigen/Eigen/src/Geometry/Hyperplane.h +282 -0
  218. data/vendor/eigen/Eigen/src/Geometry/OrthoMethods.h +234 -0
  219. data/vendor/eigen/Eigen/src/Geometry/ParametrizedLine.h +195 -0
  220. data/vendor/eigen/Eigen/src/Geometry/Quaternion.h +814 -0
  221. data/vendor/eigen/Eigen/src/Geometry/Rotation2D.h +199 -0
  222. data/vendor/eigen/Eigen/src/Geometry/RotationBase.h +206 -0
  223. data/vendor/eigen/Eigen/src/Geometry/Scaling.h +170 -0
  224. data/vendor/eigen/Eigen/src/Geometry/Transform.h +1542 -0
  225. data/vendor/eigen/Eigen/src/Geometry/Translation.h +208 -0
  226. data/vendor/eigen/Eigen/src/Geometry/Umeyama.h +166 -0
  227. data/vendor/eigen/Eigen/src/Geometry/arch/Geometry_SSE.h +161 -0
  228. data/vendor/eigen/Eigen/src/Householder/BlockHouseholder.h +103 -0
  229. data/vendor/eigen/Eigen/src/Householder/Householder.h +172 -0
  230. data/vendor/eigen/Eigen/src/Householder/HouseholderSequence.h +470 -0
  231. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/BasicPreconditioners.h +226 -0
  232. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h +228 -0
  233. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/ConjugateGradient.h +246 -0
  234. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/IncompleteCholesky.h +400 -0
  235. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/IncompleteLUT.h +462 -0
  236. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h +394 -0
  237. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/LeastSquareConjugateGradient.h +216 -0
  238. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/SolveWithGuess.h +115 -0
  239. data/vendor/eigen/Eigen/src/Jacobi/Jacobi.h +462 -0
  240. data/vendor/eigen/Eigen/src/LU/Determinant.h +101 -0
  241. data/vendor/eigen/Eigen/src/LU/FullPivLU.h +891 -0
  242. data/vendor/eigen/Eigen/src/LU/InverseImpl.h +415 -0
  243. data/vendor/eigen/Eigen/src/LU/PartialPivLU.h +611 -0
  244. data/vendor/eigen/Eigen/src/LU/PartialPivLU_LAPACKE.h +83 -0
  245. data/vendor/eigen/Eigen/src/LU/arch/Inverse_SSE.h +338 -0
  246. data/vendor/eigen/Eigen/src/MetisSupport/MetisSupport.h +137 -0
  247. data/vendor/eigen/Eigen/src/OrderingMethods/Amd.h +445 -0
  248. data/vendor/eigen/Eigen/src/OrderingMethods/Eigen_Colamd.h +1843 -0
  249. data/vendor/eigen/Eigen/src/OrderingMethods/Ordering.h +157 -0
  250. data/vendor/eigen/Eigen/src/PaStiXSupport/PaStiXSupport.h +678 -0
  251. data/vendor/eigen/Eigen/src/PardisoSupport/PardisoSupport.h +543 -0
  252. data/vendor/eigen/Eigen/src/QR/ColPivHouseholderQR.h +653 -0
  253. data/vendor/eigen/Eigen/src/QR/ColPivHouseholderQR_LAPACKE.h +97 -0
  254. data/vendor/eigen/Eigen/src/QR/CompleteOrthogonalDecomposition.h +562 -0
  255. data/vendor/eigen/Eigen/src/QR/FullPivHouseholderQR.h +676 -0
  256. data/vendor/eigen/Eigen/src/QR/HouseholderQR.h +409 -0
  257. data/vendor/eigen/Eigen/src/QR/HouseholderQR_LAPACKE.h +68 -0
  258. data/vendor/eigen/Eigen/src/SPQRSupport/SuiteSparseQRSupport.h +313 -0
  259. data/vendor/eigen/Eigen/src/SVD/BDCSVD.h +1246 -0
  260. data/vendor/eigen/Eigen/src/SVD/JacobiSVD.h +804 -0
  261. data/vendor/eigen/Eigen/src/SVD/JacobiSVD_LAPACKE.h +91 -0
  262. data/vendor/eigen/Eigen/src/SVD/SVDBase.h +315 -0
  263. data/vendor/eigen/Eigen/src/SVD/UpperBidiagonalization.h +414 -0
  264. data/vendor/eigen/Eigen/src/SparseCholesky/SimplicialCholesky.h +689 -0
  265. data/vendor/eigen/Eigen/src/SparseCholesky/SimplicialCholesky_impl.h +199 -0
  266. data/vendor/eigen/Eigen/src/SparseCore/AmbiVector.h +377 -0
  267. data/vendor/eigen/Eigen/src/SparseCore/CompressedStorage.h +258 -0
  268. data/vendor/eigen/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h +352 -0
  269. data/vendor/eigen/Eigen/src/SparseCore/MappedSparseMatrix.h +67 -0
  270. data/vendor/eigen/Eigen/src/SparseCore/SparseAssign.h +216 -0
  271. data/vendor/eigen/Eigen/src/SparseCore/SparseBlock.h +603 -0
  272. data/vendor/eigen/Eigen/src/SparseCore/SparseColEtree.h +206 -0
  273. data/vendor/eigen/Eigen/src/SparseCore/SparseCompressedBase.h +341 -0
  274. data/vendor/eigen/Eigen/src/SparseCore/SparseCwiseBinaryOp.h +726 -0
  275. data/vendor/eigen/Eigen/src/SparseCore/SparseCwiseUnaryOp.h +148 -0
  276. data/vendor/eigen/Eigen/src/SparseCore/SparseDenseProduct.h +320 -0
  277. data/vendor/eigen/Eigen/src/SparseCore/SparseDiagonalProduct.h +138 -0
  278. data/vendor/eigen/Eigen/src/SparseCore/SparseDot.h +98 -0
  279. data/vendor/eigen/Eigen/src/SparseCore/SparseFuzzy.h +29 -0
  280. data/vendor/eigen/Eigen/src/SparseCore/SparseMap.h +305 -0
  281. data/vendor/eigen/Eigen/src/SparseCore/SparseMatrix.h +1403 -0
  282. data/vendor/eigen/Eigen/src/SparseCore/SparseMatrixBase.h +405 -0
  283. data/vendor/eigen/Eigen/src/SparseCore/SparsePermutation.h +178 -0
  284. data/vendor/eigen/Eigen/src/SparseCore/SparseProduct.h +169 -0
  285. data/vendor/eigen/Eigen/src/SparseCore/SparseRedux.h +49 -0
  286. data/vendor/eigen/Eigen/src/SparseCore/SparseRef.h +397 -0
  287. data/vendor/eigen/Eigen/src/SparseCore/SparseSelfAdjointView.h +656 -0
  288. data/vendor/eigen/Eigen/src/SparseCore/SparseSolverBase.h +124 -0
  289. data/vendor/eigen/Eigen/src/SparseCore/SparseSparseProductWithPruning.h +198 -0
  290. data/vendor/eigen/Eigen/src/SparseCore/SparseTranspose.h +92 -0
  291. data/vendor/eigen/Eigen/src/SparseCore/SparseTriangularView.h +189 -0
  292. data/vendor/eigen/Eigen/src/SparseCore/SparseUtil.h +178 -0
  293. data/vendor/eigen/Eigen/src/SparseCore/SparseVector.h +478 -0
  294. data/vendor/eigen/Eigen/src/SparseCore/SparseView.h +253 -0
  295. data/vendor/eigen/Eigen/src/SparseCore/TriangularSolver.h +315 -0
  296. data/vendor/eigen/Eigen/src/SparseLU/SparseLU.h +773 -0
  297. data/vendor/eigen/Eigen/src/SparseLU/SparseLUImpl.h +66 -0
  298. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_Memory.h +226 -0
  299. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_Structs.h +110 -0
  300. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_SupernodalMatrix.h +301 -0
  301. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_Utils.h +80 -0
  302. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_column_bmod.h +181 -0
  303. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_column_dfs.h +179 -0
  304. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_copy_to_ucol.h +107 -0
  305. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_gemm_kernel.h +280 -0
  306. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_heap_relax_snode.h +126 -0
  307. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_kernel_bmod.h +130 -0
  308. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_panel_bmod.h +223 -0
  309. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_panel_dfs.h +258 -0
  310. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_pivotL.h +137 -0
  311. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_pruneL.h +136 -0
  312. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_relax_snode.h +83 -0
  313. data/vendor/eigen/Eigen/src/SparseQR/SparseQR.h +745 -0
  314. data/vendor/eigen/Eigen/src/StlSupport/StdDeque.h +126 -0
  315. data/vendor/eigen/Eigen/src/StlSupport/StdList.h +106 -0
  316. data/vendor/eigen/Eigen/src/StlSupport/StdVector.h +131 -0
  317. data/vendor/eigen/Eigen/src/StlSupport/details.h +84 -0
  318. data/vendor/eigen/Eigen/src/SuperLUSupport/SuperLUSupport.h +1027 -0
  319. data/vendor/eigen/Eigen/src/UmfPackSupport/UmfPackSupport.h +506 -0
  320. data/vendor/eigen/Eigen/src/misc/Image.h +82 -0
  321. data/vendor/eigen/Eigen/src/misc/Kernel.h +79 -0
  322. data/vendor/eigen/Eigen/src/misc/RealSvd2x2.h +55 -0
  323. data/vendor/eigen/Eigen/src/misc/blas.h +440 -0
  324. data/vendor/eigen/Eigen/src/misc/lapack.h +152 -0
  325. data/vendor/eigen/Eigen/src/misc/lapacke.h +16291 -0
  326. data/vendor/eigen/Eigen/src/misc/lapacke_mangling.h +17 -0
  327. data/vendor/eigen/Eigen/src/plugins/ArrayCwiseBinaryOps.h +332 -0
  328. data/vendor/eigen/Eigen/src/plugins/ArrayCwiseUnaryOps.h +552 -0
  329. data/vendor/eigen/Eigen/src/plugins/BlockMethods.h +1058 -0
  330. data/vendor/eigen/Eigen/src/plugins/CommonCwiseBinaryOps.h +115 -0
  331. data/vendor/eigen/Eigen/src/plugins/CommonCwiseUnaryOps.h +163 -0
  332. data/vendor/eigen/Eigen/src/plugins/MatrixCwiseBinaryOps.h +152 -0
  333. data/vendor/eigen/Eigen/src/plugins/MatrixCwiseUnaryOps.h +85 -0
  334. data/vendor/eigen/README.md +3 -0
  335. data/vendor/eigen/bench/README.txt +55 -0
  336. data/vendor/eigen/bench/btl/COPYING +340 -0
  337. data/vendor/eigen/bench/btl/README +154 -0
  338. data/vendor/eigen/bench/tensors/README +21 -0
  339. data/vendor/eigen/blas/README.txt +6 -0
  340. data/vendor/eigen/demos/mandelbrot/README +10 -0
  341. data/vendor/eigen/demos/mix_eigen_and_c/README +9 -0
  342. data/vendor/eigen/demos/opengl/README +13 -0
  343. data/vendor/eigen/unsupported/Eigen/CXX11/src/Tensor/README.md +1760 -0
  344. data/vendor/eigen/unsupported/README.txt +50 -0
  345. data/vendor/tomotopy/LICENSE +21 -0
  346. data/vendor/tomotopy/README.kr.rst +375 -0
  347. data/vendor/tomotopy/README.rst +382 -0
  348. data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +362 -0
  349. data/vendor/tomotopy/src/Labeling/FoRelevance.h +88 -0
  350. data/vendor/tomotopy/src/Labeling/Labeler.h +50 -0
  351. data/vendor/tomotopy/src/TopicModel/CT.h +37 -0
  352. data/vendor/tomotopy/src/TopicModel/CTModel.cpp +13 -0
  353. data/vendor/tomotopy/src/TopicModel/CTModel.hpp +293 -0
  354. data/vendor/tomotopy/src/TopicModel/DMR.h +51 -0
  355. data/vendor/tomotopy/src/TopicModel/DMRModel.cpp +13 -0
  356. data/vendor/tomotopy/src/TopicModel/DMRModel.hpp +374 -0
  357. data/vendor/tomotopy/src/TopicModel/DT.h +65 -0
  358. data/vendor/tomotopy/src/TopicModel/DTM.h +22 -0
  359. data/vendor/tomotopy/src/TopicModel/DTModel.cpp +15 -0
  360. data/vendor/tomotopy/src/TopicModel/DTModel.hpp +572 -0
  361. data/vendor/tomotopy/src/TopicModel/GDMR.h +37 -0
  362. data/vendor/tomotopy/src/TopicModel/GDMRModel.cpp +14 -0
  363. data/vendor/tomotopy/src/TopicModel/GDMRModel.hpp +485 -0
  364. data/vendor/tomotopy/src/TopicModel/HDP.h +74 -0
  365. data/vendor/tomotopy/src/TopicModel/HDPModel.cpp +13 -0
  366. data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +592 -0
  367. data/vendor/tomotopy/src/TopicModel/HLDA.h +40 -0
  368. data/vendor/tomotopy/src/TopicModel/HLDAModel.cpp +13 -0
  369. data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +681 -0
  370. data/vendor/tomotopy/src/TopicModel/HPA.h +27 -0
  371. data/vendor/tomotopy/src/TopicModel/HPAModel.cpp +21 -0
  372. data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +588 -0
  373. data/vendor/tomotopy/src/TopicModel/LDA.h +144 -0
  374. data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +442 -0
  375. data/vendor/tomotopy/src/TopicModel/LDAModel.cpp +13 -0
  376. data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +1058 -0
  377. data/vendor/tomotopy/src/TopicModel/LLDA.h +45 -0
  378. data/vendor/tomotopy/src/TopicModel/LLDAModel.cpp +13 -0
  379. data/vendor/tomotopy/src/TopicModel/LLDAModel.hpp +203 -0
  380. data/vendor/tomotopy/src/TopicModel/MGLDA.h +63 -0
  381. data/vendor/tomotopy/src/TopicModel/MGLDAModel.cpp +17 -0
  382. data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +558 -0
  383. data/vendor/tomotopy/src/TopicModel/PA.h +43 -0
  384. data/vendor/tomotopy/src/TopicModel/PAModel.cpp +13 -0
  385. data/vendor/tomotopy/src/TopicModel/PAModel.hpp +467 -0
  386. data/vendor/tomotopy/src/TopicModel/PLDA.h +17 -0
  387. data/vendor/tomotopy/src/TopicModel/PLDAModel.cpp +13 -0
  388. data/vendor/tomotopy/src/TopicModel/PLDAModel.hpp +214 -0
  389. data/vendor/tomotopy/src/TopicModel/SLDA.h +54 -0
  390. data/vendor/tomotopy/src/TopicModel/SLDAModel.cpp +17 -0
  391. data/vendor/tomotopy/src/TopicModel/SLDAModel.hpp +456 -0
  392. data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +692 -0
  393. data/vendor/tomotopy/src/Utils/AliasMethod.hpp +169 -0
  394. data/vendor/tomotopy/src/Utils/Dictionary.h +80 -0
  395. data/vendor/tomotopy/src/Utils/EigenAddonOps.hpp +181 -0
  396. data/vendor/tomotopy/src/Utils/LBFGS.h +202 -0
  397. data/vendor/tomotopy/src/Utils/LBFGS/LineSearchBacktracking.h +120 -0
  398. data/vendor/tomotopy/src/Utils/LBFGS/LineSearchBracketing.h +122 -0
  399. data/vendor/tomotopy/src/Utils/LBFGS/Param.h +213 -0
  400. data/vendor/tomotopy/src/Utils/LUT.hpp +82 -0
  401. data/vendor/tomotopy/src/Utils/MultiNormalDistribution.hpp +69 -0
  402. data/vendor/tomotopy/src/Utils/PolyaGamma.hpp +200 -0
  403. data/vendor/tomotopy/src/Utils/PolyaGammaHybrid.hpp +672 -0
  404. data/vendor/tomotopy/src/Utils/ThreadPool.hpp +150 -0
  405. data/vendor/tomotopy/src/Utils/Trie.hpp +220 -0
  406. data/vendor/tomotopy/src/Utils/TruncMultiNormal.hpp +94 -0
  407. data/vendor/tomotopy/src/Utils/Utils.hpp +337 -0
  408. data/vendor/tomotopy/src/Utils/avx_gamma.h +46 -0
  409. data/vendor/tomotopy/src/Utils/avx_mathfun.h +736 -0
  410. data/vendor/tomotopy/src/Utils/exception.h +28 -0
  411. data/vendor/tomotopy/src/Utils/math.h +281 -0
  412. data/vendor/tomotopy/src/Utils/rtnorm.hpp +2690 -0
  413. data/vendor/tomotopy/src/Utils/sample.hpp +192 -0
  414. data/vendor/tomotopy/src/Utils/serializer.hpp +695 -0
  415. data/vendor/tomotopy/src/Utils/slp.hpp +131 -0
  416. data/vendor/tomotopy/src/Utils/sse_gamma.h +48 -0
  417. data/vendor/tomotopy/src/Utils/sse_mathfun.h +710 -0
  418. data/vendor/tomotopy/src/Utils/text.hpp +49 -0
  419. data/vendor/tomotopy/src/Utils/tvector.hpp +543 -0
  420. metadata +531 -0
@@ -0,0 +1,37 @@
1
+ #pragma once
2
+ #include "LDA.h"
3
+
4
+ namespace tomoto
5
+ {
6
+ template<TermWeight _tw>
7
+ struct DocumentCTM : public DocumentLDA<_tw> // Document type for the CTM model: DocumentLDA plus the two Eigen members below.
8
+ {
9
+ using BaseDocument = DocumentLDA<_tw>; // alias handed to the serializer macros below
10
+ using DocumentLDA<_tw>::DocumentLDA; // inherit the constructors of DocumentLDA
11
+ Eigen::Matrix<Float, -1, -1> beta; // Dim: (K, betaSample); both dimensions are dynamic (-1)
12
+ Eigen::Matrix<Float, -1, 1> smBeta; // Dim: K; dynamic-length column vector
13
+
14
+ DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 0, smBeta); // version 0; NOTE(review): only smBeta is listed, beta is not - presumably beta is not persisted, confirm against the macro definition
15
+ DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 1, 0x00010001, smBeta); // version 1 with tag 0x00010001; same field list as above
16
+ };
17
+
18
+ class ICTModel : public ILDAModel // Abstract interface for CTM models: extends ILDAModel with the pure-virtual accessors declared below.
19
+ {
20
+ public:
21
+ using DefaultDocType = DocumentCTM<TermWeight::one>; // document type instantiated with TermWeight::one
22
+ static ICTModel* create(TermWeight _weight, size_t _K = 1, // factory returning a raw ICTModel pointer
23
+ Float smoothingAlpha = 0.1, Float _eta = 0.01,
24
+ size_t seed = std::random_device{}(), // default seed is drawn from std::random_device at each call
25
+ bool scalarRng = false); // NOTE(review): ownership of the returned pointer presumably passes to the caller - confirm
26
+
27
+ virtual void setNumBetaSample(size_t numSample) = 0; // NOTE(review): the meaning of the three sample counts is defined by the implementing class, not visible here
28
+ virtual size_t getNumBetaSample() const = 0;
29
+ virtual void setNumTMNSample(size_t numSample) = 0;
30
+ virtual size_t getNumTMNSample() const = 0;
31
+ virtual void setNumDocBetaSample(size_t numSample) = 0;
32
+ virtual size_t getNumDocBetaSample() const = 0;
33
+ virtual std::vector<Float> getPriorMean() const = 0; // returned by value as a std::vector<Float>
34
+ virtual std::vector<Float> getPriorCov() const = 0; // returned as a one-dimensional vector; NOTE(review): element layout of the covariance is not visible here - confirm in the implementation
35
+ virtual std::vector<Float> getCorrelationTopic(Tid k) const = 0; // takes a topic id k; NOTE(review): presumably one value per topic - confirm
36
+ };
37
+ }
@@ -0,0 +1,13 @@
1
+ #include "CTModel.hpp"
2
+
3
+ namespace tomoto
4
+ {
5
+ /*template class CTModel<TermWeight::one>;
6
+ template class CTModel<TermWeight::idf>;
7
+ template class CTModel<TermWeight::pmi>;*/
8
+
9
+ ICTModel* ICTModel::create(TermWeight _weight, size_t _K, Float smoothingAlpha, Float _eta, size_t seed, bool scalarRng) // Factory definition: hands all arguments to the TMT_SWITCH_TW macro together with the CTModel template name.
10
+ {
11
+ TMT_SWITCH_TW(_weight, scalarRng, CTModel, _K, smoothingAlpha, _eta, seed); // NOTE(review): no explicit return in this body - the macro is expected to expand to the return statement(s); confirm in its definition
12
+ }
13
+ }
@@ -0,0 +1,293 @@
1
+ #pragma once
2
+ #include "LDAModel.hpp"
3
+ #include "../Utils/MultiNormalDistribution.hpp"
4
+ #include "../Utils/TruncMultiNormal.hpp"
5
+ #include "CT.h"
6
+ /*
7
+ Implementation of CTM using Gibbs sampling by bab2min
8
+ * Blei, D., & Lafferty, J. (2006). Correlated topic models. Advances in neural information processing systems, 18, 147.
9
+ * Mimno, D., Wallach, H., & McCallum, A. (2008, December). Gibbs sampling for logistic normal topic models with graph-based priors. In NIPS Workshop on Analyzing Graphs (Vol. 61).
10
+ */
11
+
12
+ namespace tomoto
13
+ {
14
+ template<TermWeight _tw>
15
+ struct ModelStateCTM : public ModelStateLDA<_tw>
16
+ {
17
+ };
18
+
19
+ template<TermWeight _tw, typename _RandGen,
20
+ size_t _Flags = flags::partitioned_multisampling,
21
+ typename _Interface = ICTModel,
22
+ typename _Derived = void,
23
+ typename _DocType = DocumentCTM<_tw>,
24
+ typename _ModelState = ModelStateCTM<_tw>>
25
+ class CTModel : public LDAModel<_tw, _RandGen, _Flags, _Interface,
26
+ typename std::conditional<std::is_same<_Derived, void>::value, CTModel<_tw, _RandGen, _Flags>, _Derived>::type,
27
+ _DocType, _ModelState>
28
+ {
29
+ protected:
30
+ using DerivedClass = typename std::conditional<std::is_same<_Derived, void>::value, CTModel<_tw, _RandGen>, _Derived>::type;
31
+ using BaseClass = LDAModel<_tw, _RandGen, _Flags, _Interface, DerivedClass, _DocType, _ModelState>;
32
+ friend BaseClass;
33
+ friend typename BaseClass::BaseClass;
34
+ using WeightType = typename BaseClass::WeightType;
35
+
36
+ static constexpr char TMID[] = "CTM\0";
37
+
38
+ uint64_t numBetaSample = 10;
39
+ uint64_t numTMNSample = 5;
40
+ uint64_t numDocBetaSample = -1;
41
+ math::MultiNormalDistribution<Float> topicPrior;
42
+
43
+ template<bool _asymEta>
44
+ Float* getZLikelihoods(_ModelState& ld, const _DocType& doc, size_t docId, size_t vid) const
45
+ {
46
+ const size_t V = this->realV;
47
+ assert(vid < V);
48
+ auto etaHelper = this->template getEtaHelper<_asymEta>();
49
+ auto& zLikelihood = ld.zLikelihood;
50
+ zLikelihood = doc.smBeta.array()
51
+ * (ld.numByTopicWord.col(vid).array().template cast<Float>() + etaHelper.getEta(vid))
52
+ / (ld.numByTopic.array().template cast<Float>() + etaHelper.getEtaSum());
53
+ sample::prefixSum(zLikelihood.data(), this->K);
54
+ return &zLikelihood[0];
55
+ }
56
+
57
+ void updateBeta(_DocType& doc, _RandGen& rg) const
58
+ {
59
+ Eigen::Matrix<Float, -1, 1> pbeta, lowerBound, upperBound;
60
+ constexpr Float epsilon = 1e-8;
61
+ constexpr size_t burnIn = 3;
62
+
63
+ pbeta = lowerBound = upperBound = Eigen::Matrix<Float, -1, 1>::Zero(this->K);
64
+ for (size_t i = 0; i < numBetaSample + burnIn; ++i)
65
+ {
66
+ if (i == 0) pbeta = Eigen::Matrix<Float, -1, 1>::Ones(this->K);
67
+ else pbeta = doc.beta.col(i % numBetaSample).array().exp();
68
+ Float betaESum = pbeta.sum() + 1;
69
+ pbeta /= betaESum;
70
+ for (size_t k = 0; k < this->K; ++k)
71
+ {
72
+ Float N_k = doc.numByTopic[k] + this->alpha;
73
+ Float N_nk = doc.getSumWordWeight() + this->alpha * (this->K + 1) - N_k;
74
+ Float u1 = rg.uniform_real(), u2 = rg.uniform_real();
75
+ Float max_uk = epsilon + pow(u1, (Float)1 / N_k) * (pbeta[k] - epsilon);
76
+ Float min_unk = (1 - pow(u2, (Float)1 / N_nk))
77
+ * (1 - pbeta[k]) + pbeta[k];
78
+
79
+ Float c = betaESum * (1 - pbeta[k]);
80
+ lowerBound[k] = log(c * max_uk / (1 - max_uk));
81
+ upperBound[k] = log(c * min_unk / (1 - min_unk));
82
+ if (lowerBound[k] > upperBound[k])
83
+ {
84
+ THROW_ERROR_WITH_INFO(exception::TrainingError,
85
+ text::format("Bound Error: LB(%f) > UB(%f)\n"
86
+ "max_uk: %f, min_unk: %f, c: %f", lowerBound[k], upperBound[k], max_uk, min_unk, c));
87
+ }
88
+ }
89
+
90
+ try
91
+ {
92
+ math::sampleFromTruncatedMultiNormal(doc.beta.col((i + 1) % numBetaSample),
93
+ topicPrior, lowerBound, upperBound, rg, numTMNSample);
94
+
95
+ if (!std::isfinite(doc.beta.col((i + 1) % numBetaSample)[0]))
96
+ THROW_ERROR_WITH_INFO(exception::TrainingError,
97
+ text::format("doc.beta.col(%d) is %f", (i + 1) % numBetaSample,
98
+ doc.beta.col((i + 1) % numBetaSample)[0]));
99
+ }
100
+ catch (const std::runtime_error& e)
101
+ {
102
+ std::cerr << e.what() << std::endl;
103
+ THROW_ERROR_WITH_INFO(exception::TrainingError, e.what());
104
+ }
105
+ }
106
+
107
+ // update softmax-applied beta coefficient
108
+ doc.smBeta.head(this->K) = doc.beta.block(0, 0, this->K, std::min(numBetaSample, numDocBetaSample)).rowwise().mean();
109
+ doc.smBeta = doc.smBeta.array().exp();
110
+ doc.smBeta /= doc.smBeta.array().sum();
111
+ }
112
+
113
+ template<ParallelScheme _ps, bool _infer, typename _ExtraDocData>
114
+ void sampleDocument(_DocType& doc, const _ExtraDocData& edd, size_t docId, _ModelState& ld, _RandGen& rgs, size_t iterationCnt, size_t partitionId = 0) const
115
+ {
116
+ BaseClass::template sampleDocument<_ps, _infer>(doc, edd, docId, ld, rgs, iterationCnt, partitionId);
117
+ /*if (iterationCnt >= this->burnIn && this->optimInterval && (iterationCnt + 1) % this->optimInterval == 0)
118
+ {
119
+ updateBeta(doc, rgs);
120
+ }*/
121
+ }
122
+
123
+ template<typename _DocIter>
124
+ void sampleGlobalLevel(ThreadPool* pool, _ModelState* localData, _RandGen* rgs, _DocIter first, _DocIter last) const
125
+ {
126
+ if (this->globalStep < this->burnIn || !this->optimInterval || (this->globalStep + 1) % this->optimInterval != 0) return;
127
+
128
+ if (pool)
129
+ {
130
+ std::vector<std::future<void>> res;
131
+ const size_t chStride = pool->getNumWorkers() * 8;
132
+ size_t dist = std::distance(first, last);
133
+ for (size_t ch = 0; ch < chStride; ++ch)
134
+ {
135
+ auto b = first, e = first;
136
+ std::advance(b, dist * ch / chStride);
137
+ std::advance(e, dist * (ch + 1) / chStride);
138
+ res.emplace_back(pool->enqueue([&, ch, chStride](size_t threadId, _DocIter b, _DocIter e)
139
+ {
140
+ for (auto doc = b; doc != e; ++doc)
141
+ {
142
+ updateBeta(*doc, rgs[threadId]);
143
+ }
144
+ }, b, e));
145
+ }
146
+ for (auto& r : res) r.get();
147
+ }
148
+ else
149
+ {
150
+ for (auto doc = first; doc != last; ++doc)
151
+ {
152
+ updateBeta(*doc, rgs[0]);
153
+ }
154
+ }
155
+ }
156
+
157
+ int restoreFromTrainingError(const exception::TrainingError& e, ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
158
+ {
159
+ std::cerr << "Failed to sample! Reset prior and retry!" << std::endl;
160
+ const size_t chStride = std::min(pool.getNumWorkers() * 8, this->docs.size());
161
+ topicPrior = math::MultiNormalDistribution<Float>{ this->K };
162
+ std::vector<std::future<void>> res;
163
+ for (size_t ch = 0; ch < chStride; ++ch)
164
+ {
165
+ res.emplace_back(pool.enqueue([&, this](size_t threadId, size_t ch)
166
+ {
167
+ for (size_t i = ch; i < this->docs.size(); i += chStride)
168
+ {
169
+ this->docs[i].beta.setZero();
170
+ updateBeta(this->docs[i], rgs[threadId]);
171
+ }
172
+ }, ch));
173
+ }
174
+ for (auto& r : res) r.get();
175
+ return 0;
176
+ }
177
+
178
+ void optimizeParameters(ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
179
+ {
180
+ std::vector<std::future<void>> res;
181
+ topicPrior = math::MultiNormalDistribution<Float>::estimate([this](size_t i)
182
+ {
183
+ return this->docs[i / numBetaSample].beta.col(i % numBetaSample);
184
+ }, this->docs.size() * numBetaSample);
185
+ if (!std::isfinite(topicPrior.mean[0]))
186
+ THROW_ERROR_WITH_INFO(exception::TrainingError,
187
+ text::format("topicPrior.mean is %f", topicPrior.mean[0]));
188
+ }
189
+
190
+ template<typename _DocIter>
191
+ double getLLDocs(_DocIter _first, _DocIter _last) const
192
+ {
193
+ const auto K = this->K;
194
+ const auto alpha = this->alpha;
195
+
196
+ double ll = 0;
197
+ for (; _first != _last; ++_first)
198
+ {
199
+ auto& doc = *_first;
200
+ Eigen::Matrix<Float, -1, 1> pbeta = doc.smBeta.array().log();
201
+ Float last = pbeta[K - 1];
202
+ for (Tid k = 0; k < K; ++k)
203
+ {
204
+ ll += pbeta[k] * (doc.numByTopic[k] + alpha) - math::lgammaT(doc.numByTopic[k] + alpha + 1);
205
+ }
206
+ pbeta.array() -= last;
207
+ ll += topicPrior.getLL(pbeta.head(this->K));
208
+ ll += math::lgammaT(doc.getSumWordWeight() + alpha * K + 1);
209
+ }
210
+ return ll;
211
+ }
212
+
213
+ void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
214
+ {
215
+ BaseClass::prepareDoc(doc, docId, wordSize);
216
+ doc.beta = Eigen::Matrix<Float, -1, -1>::Zero(this->K, numBetaSample);
217
+ doc.smBeta = Eigen::Matrix<Float, -1, 1>::Constant(this->K, (Float)1 / this->K);
218
+ }
219
+
220
+ void updateDocs()
221
+ {
222
+ BaseClass::updateDocs();
223
+ for (auto& doc : this->docs)
224
+ {
225
+ doc.beta = Eigen::Matrix<Float, -1, -1>::Zero(this->K, numBetaSample);
226
+ }
227
+ }
228
+
229
+ void initGlobalState(bool initDocs)
230
+ {
231
+ BaseClass::initGlobalState(initDocs);
232
+ if (initDocs)
233
+ {
234
+ topicPrior = math::MultiNormalDistribution<Float>{ this->K };
235
+ }
236
+ }
237
+
238
+ public:
239
+ DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, numBetaSample, numTMNSample, topicPrior);
240
+ DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, numBetaSample, numTMNSample, topicPrior);
241
+
242
+ CTModel(size_t _K = 1, Float smoothingAlpha = 0.1, Float _eta = 0.01, size_t _rg = std::random_device{}())
243
+ : BaseClass(_K, smoothingAlpha, _eta, _rg)
244
+ {
245
+ this->optimInterval = 2;
246
+ }
247
+
248
+ std::vector<Float> getTopicsByDoc(const _DocType& doc) const
249
+ {
250
+ std::vector<Float> ret(this->K);
251
+ Eigen::Map<Eigen::Matrix<Float, -1, 1>>{ret.data(), this->K}.array() =
252
+ doc.numByTopic.array().template cast<Float>() / doc.getSumWordWeight();
253
+ return ret;
254
+ }
255
+
256
+ std::vector<Float> getPriorMean() const override
257
+ {
258
+ return { topicPrior.mean.data(), topicPrior.mean.data() + topicPrior.mean.size() };
259
+ }
260
+
261
+ std::vector<Float> getPriorCov() const override
262
+ {
263
+ return { topicPrior.cov.data(), topicPrior.cov.data() + topicPrior.cov.size() };
264
+ }
265
+
266
+ std::vector<Float> getCorrelationTopic(Tid k) const override
267
+ {
268
+ Eigen::Matrix<Float, -1, 1> ret = topicPrior.cov.col(k).array() / (topicPrior.cov.diagonal().array() * topicPrior.cov(k, k)).sqrt();
269
+ return { ret.data(), ret.data() + ret.size() };
270
+ }
271
+
272
+ GETTER(NumBetaSample, size_t, numBetaSample);
273
+
274
+ void setNumBetaSample(size_t _numSample) override
275
+ {
276
+ numBetaSample = _numSample;
277
+ }
278
+
279
+ GETTER(NumDocBetaSample, size_t, numDocBetaSample);
280
+
281
+ void setNumDocBetaSample(size_t _numSample) override
282
+ {
283
+ numDocBetaSample = _numSample;
284
+ }
285
+
286
+ GETTER(NumTMNSample, size_t, numTMNSample);
287
+
288
+ void setNumTMNSample(size_t _numSample) override
289
+ {
290
+ numTMNSample = _numSample;
291
+ }
292
+ };
293
+ }
@@ -0,0 +1,51 @@
1
+ #pragma once
2
+ #include "LDA.h"
3
+
4
+ namespace tomoto
5
+ {
6
+ template<TermWeight _tw>
7
+ struct DocumentDMR : public DocumentLDA<_tw>
8
+ {
9
+ using BaseDocument = DocumentLDA<_tw>;
10
+ using DocumentLDA<_tw>::DocumentLDA;
11
+ size_t metadata = 0;
12
+
13
+ DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 0, metadata);
14
+ DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 1, 0x00010001, metadata);
15
+ };
16
+
17
+ class IDMRModel : public ILDAModel
18
+ {
19
+ public:
20
+ using DefaultDocType = DocumentDMR<TermWeight::one>;
21
+ static IDMRModel* create(TermWeight _weight, size_t _K = 1,
22
+ Float defaultAlpha = 1.0, Float _sigma = 1.0, Float _eta = 0.01, Float _alphaEps = 1e-10,
23
+ size_t seed = std::random_device{}(),
24
+ bool scalarRng = false);
25
+
26
+ virtual size_t addDoc(const std::vector<std::string>& words, const std::vector<std::string>& metadata) = 0;
27
+ virtual std::unique_ptr<DocumentBase> makeDoc(const std::vector<std::string>& words, const std::vector<std::string>& metadata) const = 0;
28
+
29
+ virtual size_t addDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
30
+ const std::vector<std::string>& metadata) = 0;
31
+ virtual std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
32
+ const std::vector<std::string>& metadata) const = 0;
33
+
34
+ virtual size_t addDoc(const std::string& rawStr, const std::vector<Vid>& words,
35
+ const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
36
+ const std::vector<std::string>& metadata) = 0;
37
+ virtual std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const std::vector<Vid>& words,
38
+ const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
39
+ const std::vector<std::string>& metadata) const = 0;
40
+
41
+ virtual void setAlphaEps(Float _alphaEps) = 0;
42
+ virtual Float getAlphaEps() const = 0;
43
+ virtual void setOptimRepeat(size_t repeat) = 0;
44
+ virtual size_t getOptimRepeat() const = 0;
45
+ virtual size_t getF() const = 0;
46
+ virtual Float getSigma() const = 0;
47
+ virtual const Dictionary& getMetadataDict() const = 0;
48
+ virtual std::vector<Float> getLambdaByMetadata(size_t metadataId) const = 0;
49
+ virtual std::vector<Float> getLambdaByTopic(Tid tid) const = 0;
50
+ };
51
+ }
@@ -0,0 +1,13 @@
1
+ #include "DMRModel.hpp"
2
+
3
+ namespace tomoto
4
+ {
5
+ /*template class DMRModel<TermWeight::one>;
6
+ template class DMRModel<TermWeight::idf>;
7
+ template class DMRModel<TermWeight::pmi>;*/
8
+
9
+ IDMRModel* IDMRModel::create(TermWeight _weight, size_t _K, Float _defaultAlpha, Float _sigma, Float _eta, Float _alphaEps, size_t seed, bool scalarRng)
10
+ {
11
+ TMT_SWITCH_TW(_weight, scalarRng, DMRModel, _K, _defaultAlpha, _sigma, _eta, _alphaEps, seed);
12
+ }
13
+ }
@@ -0,0 +1,374 @@
1
+ #pragma once
2
+ #include "LDAModel.hpp"
3
+ #include "../Utils/LBFGS.h"
4
+ #include "../Utils/text.hpp"
5
+ #include "DMR.h"
6
+ /*
7
+ Implementation of DMR using Gibbs sampling by bab2min
8
+ * Mimno, D., & McCallum, A. (2012). Topic models conditioned on arbitrary features with dirichlet-multinomial regression. arXiv preprint arXiv:1206.3278.
9
+ */
10
+
11
+ namespace tomoto
12
+ {
13
+ template<TermWeight _tw>
14
+ struct ModelStateDMR : public ModelStateLDA<_tw>
15
+ {
16
+ Eigen::Matrix<Float, -1, 1> tmpK;
17
+ };
18
+
19
+ template<TermWeight _tw, typename _RandGen,
20
+ size_t _Flags = flags::partitioned_multisampling,
21
+ typename _Interface = IDMRModel,
22
+ typename _Derived = void,
23
+ typename _DocType = DocumentDMR<_tw>,
24
+ typename _ModelState = ModelStateDMR<_tw>>
25
+ class DMRModel : public LDAModel<_tw, _RandGen, _Flags, _Interface,
26
+ typename std::conditional<std::is_same<_Derived, void>::value, DMRModel<_tw, _RandGen, _Flags>, _Derived>::type,
27
+ _DocType, _ModelState>
28
+ {
29
+ protected:
30
+ using DerivedClass = typename std::conditional<std::is_same<_Derived, void>::value, DMRModel<_tw, _RandGen>, _Derived>::type;
31
+ using BaseClass = LDAModel<_tw, _RandGen, _Flags, _Interface, DerivedClass, _DocType, _ModelState>;
32
+ friend BaseClass;
33
+ friend typename BaseClass::BaseClass;
34
+ using WeightType = typename BaseClass::WeightType;
35
+
36
+ static constexpr char TMID[] = "DMR\0";
37
+
38
+ Eigen::Matrix<Float, -1, -1> lambda;
39
+ Eigen::Matrix<Float, -1, -1> expLambda;
40
+ Float sigma;
41
+ uint32_t F = 0;
42
+ uint32_t optimRepeat = 5;
43
+ Float alphaEps = 1e-10;
44
+ Float temperatureScale = 0;
45
+ static constexpr Float maxLambda = 10;
46
+ static constexpr size_t maxBFGSIteration = 10;
47
+
48
+ Dictionary metadataDict;
49
+ LBFGSpp::LBFGSSolver<Float, LBFGSpp::LineSearchBracketing> solver;
50
+
51
+ Float getNegativeLambdaLL(Eigen::Ref<Eigen::Matrix<Float, -1, 1>> x, Eigen::Matrix<Float, -1, 1>& g) const
52
+ {
53
+ g = (x.array() - log(this->alpha)) / pow(sigma, 2);
54
+ return (x.array() - log(this->alpha)).pow(2).sum() / 2 / pow(sigma, 2);
55
+ }
56
+
57
+ Float evaluateLambdaObj(Eigen::Ref<Eigen::Matrix<Float, -1, 1>> x, Eigen::Matrix<Float, -1, 1>& g, ThreadPool& pool, _ModelState* localData) const
58
+ {
59
+ // if one of x is greater than maxLambda, return +inf for preventing searching more
60
+ if ((x.array() > maxLambda).any()) return INFINITY;
61
+
62
+ const auto K = this->K;
63
+
64
+ Float fx = - static_cast<const DerivedClass*>(this)->getNegativeLambdaLL(x, g);
65
+ auto alphas = (x.array().exp() + alphaEps).eval();
66
+
67
+ std::vector<std::future<Eigen::Matrix<Float, -1, 1>>> res;
68
+ const size_t chStride = pool.getNumWorkers() * 8;
69
+ for (size_t ch = 0; ch < chStride; ++ch)
70
+ {
71
+ res.emplace_back(pool.enqueue([&](size_t threadId)
72
+ {
73
+ auto& tmpK = localData[threadId].tmpK;
74
+ if (!tmpK.size()) tmpK.resize(this->K);
75
+ Eigen::Matrix<Float, -1, 1> val = Eigen::Matrix<Float, -1, 1>::Zero(K * F + 1);
76
+ for (size_t docId = ch; docId < this->docs.size(); docId += chStride)
77
+ {
78
+ const auto& doc = this->docs[docId];
79
+ auto alphaDoc = alphas.segment(doc.metadata * K, K);
80
+ Float alphaSum = alphaDoc.sum();
81
+ for (Tid k = 0; k < K; ++k)
82
+ {
83
+ val[K * F] -= math::lgammaT(alphaDoc[k]) - math::lgammaT(doc.numByTopic[k] + alphaDoc[k]);
84
+ if (!std::isfinite(alphaDoc[k]) && alphaDoc[k] > 0) tmpK[k] = 0;
85
+ else tmpK[k] = -(math::digammaT(alphaDoc[k]) - math::digammaT(doc.numByTopic[k] + alphaDoc[k]));
86
+ }
87
+ //val[K * F] = -(lgammaApprox(alphaDoc.array()) - lgammaApprox(doc.numByTopic.array().cast<Float>() + alphaDoc.array())).sum();
88
+ //tmpK = -(digammaApprox(alphaDoc.array()) - digammaApprox(doc.numByTopic.array().cast<Float>() + alphaDoc.array()));
89
+ val[K * F] += math::lgammaT(alphaSum) - math::lgammaT(doc.getSumWordWeight() + alphaSum);
90
+ Float t = math::digammaT(alphaSum) - math::digammaT(doc.getSumWordWeight() + alphaSum);
91
+ if (!std::isfinite(alphaSum) && alphaSum > 0)
92
+ {
93
+ val[K * F] = -INFINITY;
94
+ t = 0;
95
+ }
96
+ val.segment(doc.metadata * K, K).array() -= alphaDoc.array() * (tmpK.array() + t);
97
+ }
98
+ return val;
99
+ }));
100
+ }
101
+ for (auto& r : res)
102
+ {
103
+ auto ret = r.get();
104
+ fx += ret[K * F];
105
+ g += ret.head(K * F);
106
+ }
107
+
108
+ // positive fx is an error from limited precision of float.
109
+ if (fx > 0) return INFINITY;
110
+ return -fx;
111
+ }
112
+
113
+ void initParameters()
114
+ {
115
+ auto dist = std::normal_distribution<Float>(log(this->alpha), sigma);
116
+ for (size_t i = 0; i < this->K; ++i) for (size_t j = 0; j < F; ++j)
117
+ {
118
+ lambda(i, j) = dist(this->rg);
119
+ }
120
+ }
121
+
122
+ void optimizeParameters(ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
123
+ {
124
+ Eigen::Matrix<Float, -1, -1> bLambda;
125
+ Float fx = 0, bestFx = INFINITY;
126
+ for (size_t i = 0; i < optimRepeat; ++i)
127
+ {
128
+ static_cast<DerivedClass*>(this)->initParameters();
129
+ int ret = solver.minimize([this, &pool, localData](Eigen::Ref<Eigen::Matrix<Float, -1, 1>> x, Eigen::Matrix<Float, -1, 1>& g)
130
+ {
131
+ return static_cast<DerivedClass*>(this)->evaluateLambdaObj(x, g, pool, localData);
132
+ }, Eigen::Map<Eigen::Matrix<Float, -1, 1>>(lambda.data(), lambda.size()), fx);
133
+
134
+ if (fx < bestFx)
135
+ {
136
+ bLambda = lambda;
137
+ bestFx = fx;
138
+ //printf("\t(%d) %e\n", ret, fx);
139
+ }
140
+ }
141
+ if (!std::isfinite(bestFx))
142
+ {
143
+ throw exception::TrainingError{ "optimizing parameters has been failed!" };
144
+ }
145
+ lambda = bLambda;
146
+ //std::cerr << fx << std::endl;
147
+ expLambda = lambda.array().exp() + alphaEps;
148
+ }
149
+
150
+ int restoreFromTrainingError(const exception::TrainingError& e, ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
151
+ {
152
+ std::cerr << "Failed to optimize! Reset prior and retry!" << std::endl;
153
+ lambda.setZero();
154
+ expLambda = lambda.array().exp() + alphaEps;
155
+ return 0;
156
+ }
157
+
158
+ template<bool _asymEta>
159
+ Float* getZLikelihoods(_ModelState& ld, const _DocType& doc, size_t docId, size_t vid) const
160
+ {
161
+ const size_t V = this->realV;
162
+ assert(vid < V);
163
+ auto etaHelper = this->template getEtaHelper<_asymEta>();
164
+ auto& zLikelihood = ld.zLikelihood;
165
+ zLikelihood = (doc.numByTopic.array().template cast<Float>() + this->expLambda.col(doc.metadata).array())
166
+ * (ld.numByTopicWord.col(vid).array().template cast<Float>() + etaHelper.getEta(vid))
167
+ / (ld.numByTopic.array().template cast<Float>() + etaHelper.getEtaSum());
168
+
169
+ sample::prefixSum(zLikelihood.data(), this->K);
170
+ return &zLikelihood[0];
171
+ }
172
+
173
+
174
+ double getLLDocTopic(const _DocType& doc) const
175
+ {
176
+ const size_t V = this->realV;
177
+ const auto K = this->K;
178
+
179
+ auto alphaDoc = expLambda.col(doc.metadata);
180
+
181
+ Float ll = 0;
182
+ Float alphaSum = alphaDoc.sum();
183
+ for (Tid k = 0; k < K; ++k)
184
+ {
185
+ ll += math::lgammaT(doc.numByTopic[k] + alphaDoc[k]);
186
+ ll -= math::lgammaT(alphaDoc[k]);
187
+ }
188
+ ll -= math::lgammaT(doc.getSumWordWeight() + alphaSum);
189
+ ll += math::lgammaT(alphaSum);
190
+ return ll;
191
+ }
192
+
193
+ template<typename _DocIter>
194
+ double getLLDocs(_DocIter _first, _DocIter _last) const
195
+ {
196
+ const auto K = this->K;
197
+
198
+ double ll = 0;
199
+ for (; _first != _last; ++_first)
200
+ {
201
+ auto& doc = *_first;
202
+ auto alphaDoc = expLambda.col(doc.metadata);
203
+ Float alphaSum = alphaDoc.sum();
204
+
205
+ for (Tid k = 0; k < K; ++k)
206
+ {
207
+ ll += math::lgammaT(doc.numByTopic[k] + alphaDoc[k]) - math::lgammaT(alphaDoc[k]);
208
+ }
209
+ ll -= math::lgammaT(doc.getSumWordWeight() + alphaSum) - math::lgammaT(alphaSum);
210
+ }
211
+ return ll;
212
+ }
213
+
214
+ double getLLRest(const _ModelState& ld) const
215
+ {
216
+ const auto K = this->K;
217
+ const auto alpha = this->alpha;
218
+ const auto eta = this->eta;
219
+ const size_t V = this->realV;
220
+
221
+ double ll = -(lambda.array() - log(alpha)).pow(2).sum() / 2 / pow(sigma, 2);
222
+ // topic-word distribution
223
+ auto lgammaEta = math::lgammaT(eta);
224
+ ll += math::lgammaT(V*eta) * K;
225
+ for (Tid k = 0; k < K; ++k)
226
+ {
227
+ ll -= math::lgammaT(ld.numByTopic[k] + V * eta);
228
+ for (Vid v = 0; v < V; ++v)
229
+ {
230
+ if (!ld.numByTopicWord(k, v)) continue;
231
+ ll += math::lgammaT(ld.numByTopicWord(k, v) + eta) - lgammaEta;
232
+ }
233
+ }
234
+ return ll;
235
+ }
236
+
237
+ void initGlobalState(bool initDocs)
238
+ {
239
+ BaseClass::initGlobalState(initDocs);
240
+ this->globalState.tmpK = Eigen::Matrix<Float, -1, 1>::Zero(this->K);
241
+ F = metadataDict.size();
242
+ if (initDocs)
243
+ {
244
+ lambda = Eigen::Matrix<Float, -1, -1>::Constant(this->K, F, log(this->alpha));
245
+ }
246
+ if (_Flags & flags::continuous_doc_data) this->numByTopicDoc = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K, this->docs.size());
247
+ expLambda = lambda.array().exp();
248
+ LBFGSpp::LBFGSParam<Float> param;
249
+ param.max_iterations = maxBFGSIteration;
250
+ solver = decltype(solver){ param };
251
+ }
252
+
253
+ public:
254
+ DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, sigma, alphaEps, metadataDict, lambda);
255
+ DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, sigma, alphaEps, metadataDict, lambda);
256
+
257
+ DMRModel(size_t _K = 1, Float defaultAlpha = 1.0, Float _sigma = 1.0, Float _eta = 0.01,
258
+ Float _alphaEps = 0, size_t _rg = std::random_device{}())
259
+ : BaseClass(_K, defaultAlpha, _eta, _rg), sigma(_sigma), alphaEps(_alphaEps)
260
+ {
261
+ if (_sigma <= 0) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong sigma value (sigma = %f)", _sigma));
262
+ }
263
+
264
+ template<bool _const = false>
265
+ _DocType& _updateDoc(_DocType& doc, const std::vector<std::string>& metadata)
266
+ {
267
+ std::string metadataJoined = text::join(metadata.begin(), metadata.end(), "_");
268
+ Vid xid;
269
+ if (_const)
270
+ {
271
+ xid = metadataDict.toWid(metadataJoined);
272
+ if (xid == (Vid)-1) throw std::invalid_argument("unknown metadata");
273
+ }
274
+ else
275
+ {
276
+ xid = metadataDict.add(metadataJoined);
277
+ }
278
+ doc.metadata = xid;
279
+ return doc;
280
+ }
281
+
282
+ size_t addDoc(const std::vector<std::string>& words, const std::vector<std::string>& metadata) override
283
+ {
284
+ auto doc = this->_makeDoc(words);
285
+ return this->_addDoc(_updateDoc(doc, metadata));
286
+ }
287
+
288
+ std::unique_ptr<DocumentBase> makeDoc(const std::vector<std::string>& words, const std::vector<std::string>& metadata) const override
289
+ {
290
+ auto doc = as_mutable(this)->template _makeDoc<true>(words);
291
+ return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, metadata));
292
+ }
293
+
294
+ size_t addDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
295
+ const std::vector<std::string>& metadata) override
296
+ {
297
+ auto doc = this->template _makeRawDoc<false>(rawStr, tokenizer);
298
+ return this->_addDoc(_updateDoc(doc, metadata));
299
+ }
300
+
301
+ std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
302
+ const std::vector<std::string>& metadata) const override
303
+ {
304
+ auto doc = as_mutable(this)->template _makeRawDoc<true>(rawStr, tokenizer);
305
+ return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, metadata));
306
+ }
307
+
308
+ size_t addDoc(const std::string& rawStr, const std::vector<Vid>& words,
309
+ const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
310
+ const std::vector<std::string>& metadata) override
311
+ {
312
+ auto doc = this->_makeRawDoc(rawStr, words, pos, len);
313
+ return this->_addDoc(_updateDoc(doc, metadata));
314
+ }
315
+
316
+ std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const std::vector<Vid>& words,
317
+ const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
318
+ const std::vector<std::string>& metadata) const override
319
+ {
320
+ auto doc = this->_makeRawDoc(rawStr, words, pos, len);
321
+ return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, metadata));
322
+ }
323
+
324
+ GETTER(F, size_t, F);
325
+ GETTER(Sigma, Float, sigma);
326
+ GETTER(AlphaEps, Float, alphaEps);
327
+ GETTER(OptimRepeat, size_t, optimRepeat);
328
+
329
+ void setAlphaEps(Float _alphaEps) override
330
+ {
331
+ alphaEps = _alphaEps;
332
+ }
333
+
334
+ void setOptimRepeat(size_t _optimRepeat) override
335
+ {
336
+ optimRepeat = _optimRepeat;
337
+ }
338
+
339
+ std::vector<Float> getTopicsByDoc(const _DocType& doc) const
340
+ {
341
+ std::vector<Float> ret(this->K);
342
+ auto alphaDoc = expLambda.col(doc.metadata);
343
+ Eigen::Map<Eigen::Matrix<Float, -1, 1>>{ret.data(), this->K}.array() =
344
+ (doc.numByTopic.array().template cast<Float>() + alphaDoc.array()) / (doc.getSumWordWeight() + alphaDoc.sum());
345
+ return ret;
346
+ }
347
+
348
+ std::vector<Float> getLambdaByMetadata(size_t metadataId) const override
349
+ {
350
+ assert(metadataId < metadataDict.size());
351
+ auto l = lambda.col(metadataId);
352
+ return { l.data(), l.data() + this->K };
353
+ }
354
+
355
+ std::vector<Float> getLambdaByTopic(Tid tid) const override
356
+ {
357
+ assert(tid < this->K);
358
+ auto l = lambda.row(tid);
359
+ return { l.data(), l.data() + F };
360
+ }
361
+
362
+ const Dictionary& getMetadataDict() const override { return metadataDict; }
363
+ };
364
+
365
+ /* This is for preventing 'undefined symbol' problem in compiling by clang. */
366
+ template<TermWeight _tw, typename _RandGen, size_t _Flags,
367
+ typename _Interface, typename _Derived, typename _DocType, typename _ModelState>
368
+ constexpr Float DMRModel<_tw, _RandGen, _Flags, _Interface, _Derived, _DocType, _ModelState>::maxLambda;
369
+
370
+ template<TermWeight _tw, typename _RandGen, size_t _Flags,
371
+ typename _Interface, typename _Derived, typename _DocType, typename _ModelState>
372
+ constexpr size_t DMRModel<_tw, _RandGen, _Flags, _Interface, _Derived, _DocType, _ModelState>::maxBFGSIteration;
373
+
374
+ }