tomoto 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (420) hide show
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +3 -0
  3. data/LICENSE.txt +22 -0
  4. data/README.md +123 -0
  5. data/ext/tomoto/ext.cpp +245 -0
  6. data/ext/tomoto/extconf.rb +28 -0
  7. data/lib/tomoto.rb +12 -0
  8. data/lib/tomoto/ct.rb +11 -0
  9. data/lib/tomoto/hdp.rb +11 -0
  10. data/lib/tomoto/lda.rb +67 -0
  11. data/lib/tomoto/version.rb +3 -0
  12. data/vendor/EigenRand/EigenRand/Core.h +1139 -0
  13. data/vendor/EigenRand/EigenRand/Dists/Basic.h +111 -0
  14. data/vendor/EigenRand/EigenRand/Dists/Discrete.h +877 -0
  15. data/vendor/EigenRand/EigenRand/Dists/GammaPoisson.h +108 -0
  16. data/vendor/EigenRand/EigenRand/Dists/NormalExp.h +626 -0
  17. data/vendor/EigenRand/EigenRand/EigenRand +19 -0
  18. data/vendor/EigenRand/EigenRand/Macro.h +24 -0
  19. data/vendor/EigenRand/EigenRand/MorePacketMath.h +978 -0
  20. data/vendor/EigenRand/EigenRand/PacketFilter.h +286 -0
  21. data/vendor/EigenRand/EigenRand/PacketRandomEngine.h +624 -0
  22. data/vendor/EigenRand/EigenRand/RandUtils.h +413 -0
  23. data/vendor/EigenRand/EigenRand/doc.h +220 -0
  24. data/vendor/EigenRand/LICENSE +21 -0
  25. data/vendor/EigenRand/README.md +288 -0
  26. data/vendor/eigen/COPYING.BSD +26 -0
  27. data/vendor/eigen/COPYING.GPL +674 -0
  28. data/vendor/eigen/COPYING.LGPL +502 -0
  29. data/vendor/eigen/COPYING.MINPACK +52 -0
  30. data/vendor/eigen/COPYING.MPL2 +373 -0
  31. data/vendor/eigen/COPYING.README +18 -0
  32. data/vendor/eigen/Eigen/CMakeLists.txt +19 -0
  33. data/vendor/eigen/Eigen/Cholesky +46 -0
  34. data/vendor/eigen/Eigen/CholmodSupport +48 -0
  35. data/vendor/eigen/Eigen/Core +537 -0
  36. data/vendor/eigen/Eigen/Dense +7 -0
  37. data/vendor/eigen/Eigen/Eigen +2 -0
  38. data/vendor/eigen/Eigen/Eigenvalues +61 -0
  39. data/vendor/eigen/Eigen/Geometry +62 -0
  40. data/vendor/eigen/Eigen/Householder +30 -0
  41. data/vendor/eigen/Eigen/IterativeLinearSolvers +48 -0
  42. data/vendor/eigen/Eigen/Jacobi +33 -0
  43. data/vendor/eigen/Eigen/LU +50 -0
  44. data/vendor/eigen/Eigen/MetisSupport +35 -0
  45. data/vendor/eigen/Eigen/OrderingMethods +73 -0
  46. data/vendor/eigen/Eigen/PaStiXSupport +48 -0
  47. data/vendor/eigen/Eigen/PardisoSupport +35 -0
  48. data/vendor/eigen/Eigen/QR +51 -0
  49. data/vendor/eigen/Eigen/QtAlignedMalloc +40 -0
  50. data/vendor/eigen/Eigen/SPQRSupport +34 -0
  51. data/vendor/eigen/Eigen/SVD +51 -0
  52. data/vendor/eigen/Eigen/Sparse +36 -0
  53. data/vendor/eigen/Eigen/SparseCholesky +45 -0
  54. data/vendor/eigen/Eigen/SparseCore +69 -0
  55. data/vendor/eigen/Eigen/SparseLU +46 -0
  56. data/vendor/eigen/Eigen/SparseQR +37 -0
  57. data/vendor/eigen/Eigen/StdDeque +27 -0
  58. data/vendor/eigen/Eigen/StdList +26 -0
  59. data/vendor/eigen/Eigen/StdVector +27 -0
  60. data/vendor/eigen/Eigen/SuperLUSupport +64 -0
  61. data/vendor/eigen/Eigen/UmfPackSupport +40 -0
  62. data/vendor/eigen/Eigen/src/Cholesky/LDLT.h +673 -0
  63. data/vendor/eigen/Eigen/src/Cholesky/LLT.h +542 -0
  64. data/vendor/eigen/Eigen/src/Cholesky/LLT_LAPACKE.h +99 -0
  65. data/vendor/eigen/Eigen/src/CholmodSupport/CholmodSupport.h +639 -0
  66. data/vendor/eigen/Eigen/src/Core/Array.h +329 -0
  67. data/vendor/eigen/Eigen/src/Core/ArrayBase.h +226 -0
  68. data/vendor/eigen/Eigen/src/Core/ArrayWrapper.h +209 -0
  69. data/vendor/eigen/Eigen/src/Core/Assign.h +90 -0
  70. data/vendor/eigen/Eigen/src/Core/AssignEvaluator.h +935 -0
  71. data/vendor/eigen/Eigen/src/Core/Assign_MKL.h +178 -0
  72. data/vendor/eigen/Eigen/src/Core/BandMatrix.h +353 -0
  73. data/vendor/eigen/Eigen/src/Core/Block.h +452 -0
  74. data/vendor/eigen/Eigen/src/Core/BooleanRedux.h +164 -0
  75. data/vendor/eigen/Eigen/src/Core/CommaInitializer.h +160 -0
  76. data/vendor/eigen/Eigen/src/Core/ConditionEstimator.h +175 -0
  77. data/vendor/eigen/Eigen/src/Core/CoreEvaluators.h +1688 -0
  78. data/vendor/eigen/Eigen/src/Core/CoreIterators.h +127 -0
  79. data/vendor/eigen/Eigen/src/Core/CwiseBinaryOp.h +184 -0
  80. data/vendor/eigen/Eigen/src/Core/CwiseNullaryOp.h +866 -0
  81. data/vendor/eigen/Eigen/src/Core/CwiseTernaryOp.h +197 -0
  82. data/vendor/eigen/Eigen/src/Core/CwiseUnaryOp.h +103 -0
  83. data/vendor/eigen/Eigen/src/Core/CwiseUnaryView.h +128 -0
  84. data/vendor/eigen/Eigen/src/Core/DenseBase.h +611 -0
  85. data/vendor/eigen/Eigen/src/Core/DenseCoeffsBase.h +681 -0
  86. data/vendor/eigen/Eigen/src/Core/DenseStorage.h +570 -0
  87. data/vendor/eigen/Eigen/src/Core/Diagonal.h +260 -0
  88. data/vendor/eigen/Eigen/src/Core/DiagonalMatrix.h +343 -0
  89. data/vendor/eigen/Eigen/src/Core/DiagonalProduct.h +28 -0
  90. data/vendor/eigen/Eigen/src/Core/Dot.h +318 -0
  91. data/vendor/eigen/Eigen/src/Core/EigenBase.h +159 -0
  92. data/vendor/eigen/Eigen/src/Core/ForceAlignedAccess.h +146 -0
  93. data/vendor/eigen/Eigen/src/Core/Fuzzy.h +155 -0
  94. data/vendor/eigen/Eigen/src/Core/GeneralProduct.h +455 -0
  95. data/vendor/eigen/Eigen/src/Core/GenericPacketMath.h +593 -0
  96. data/vendor/eigen/Eigen/src/Core/GlobalFunctions.h +187 -0
  97. data/vendor/eigen/Eigen/src/Core/IO.h +225 -0
  98. data/vendor/eigen/Eigen/src/Core/Inverse.h +118 -0
  99. data/vendor/eigen/Eigen/src/Core/Map.h +171 -0
  100. data/vendor/eigen/Eigen/src/Core/MapBase.h +303 -0
  101. data/vendor/eigen/Eigen/src/Core/MathFunctions.h +1415 -0
  102. data/vendor/eigen/Eigen/src/Core/MathFunctionsImpl.h +101 -0
  103. data/vendor/eigen/Eigen/src/Core/Matrix.h +459 -0
  104. data/vendor/eigen/Eigen/src/Core/MatrixBase.h +529 -0
  105. data/vendor/eigen/Eigen/src/Core/NestByValue.h +110 -0
  106. data/vendor/eigen/Eigen/src/Core/NoAlias.h +108 -0
  107. data/vendor/eigen/Eigen/src/Core/NumTraits.h +248 -0
  108. data/vendor/eigen/Eigen/src/Core/PermutationMatrix.h +633 -0
  109. data/vendor/eigen/Eigen/src/Core/PlainObjectBase.h +1035 -0
  110. data/vendor/eigen/Eigen/src/Core/Product.h +186 -0
  111. data/vendor/eigen/Eigen/src/Core/ProductEvaluators.h +1112 -0
  112. data/vendor/eigen/Eigen/src/Core/Random.h +182 -0
  113. data/vendor/eigen/Eigen/src/Core/Redux.h +505 -0
  114. data/vendor/eigen/Eigen/src/Core/Ref.h +283 -0
  115. data/vendor/eigen/Eigen/src/Core/Replicate.h +142 -0
  116. data/vendor/eigen/Eigen/src/Core/ReturnByValue.h +117 -0
  117. data/vendor/eigen/Eigen/src/Core/Reverse.h +211 -0
  118. data/vendor/eigen/Eigen/src/Core/Select.h +162 -0
  119. data/vendor/eigen/Eigen/src/Core/SelfAdjointView.h +352 -0
  120. data/vendor/eigen/Eigen/src/Core/SelfCwiseBinaryOp.h +47 -0
  121. data/vendor/eigen/Eigen/src/Core/Solve.h +188 -0
  122. data/vendor/eigen/Eigen/src/Core/SolveTriangular.h +235 -0
  123. data/vendor/eigen/Eigen/src/Core/SolverBase.h +130 -0
  124. data/vendor/eigen/Eigen/src/Core/StableNorm.h +221 -0
  125. data/vendor/eigen/Eigen/src/Core/Stride.h +111 -0
  126. data/vendor/eigen/Eigen/src/Core/Swap.h +67 -0
  127. data/vendor/eigen/Eigen/src/Core/Transpose.h +403 -0
  128. data/vendor/eigen/Eigen/src/Core/Transpositions.h +407 -0
  129. data/vendor/eigen/Eigen/src/Core/TriangularMatrix.h +983 -0
  130. data/vendor/eigen/Eigen/src/Core/VectorBlock.h +96 -0
  131. data/vendor/eigen/Eigen/src/Core/VectorwiseOp.h +695 -0
  132. data/vendor/eigen/Eigen/src/Core/Visitor.h +273 -0
  133. data/vendor/eigen/Eigen/src/Core/arch/AVX/Complex.h +451 -0
  134. data/vendor/eigen/Eigen/src/Core/arch/AVX/MathFunctions.h +439 -0
  135. data/vendor/eigen/Eigen/src/Core/arch/AVX/PacketMath.h +637 -0
  136. data/vendor/eigen/Eigen/src/Core/arch/AVX/TypeCasting.h +51 -0
  137. data/vendor/eigen/Eigen/src/Core/arch/AVX512/MathFunctions.h +391 -0
  138. data/vendor/eigen/Eigen/src/Core/arch/AVX512/PacketMath.h +1316 -0
  139. data/vendor/eigen/Eigen/src/Core/arch/AltiVec/Complex.h +430 -0
  140. data/vendor/eigen/Eigen/src/Core/arch/AltiVec/MathFunctions.h +322 -0
  141. data/vendor/eigen/Eigen/src/Core/arch/AltiVec/PacketMath.h +1061 -0
  142. data/vendor/eigen/Eigen/src/Core/arch/CUDA/Complex.h +103 -0
  143. data/vendor/eigen/Eigen/src/Core/arch/CUDA/Half.h +674 -0
  144. data/vendor/eigen/Eigen/src/Core/arch/CUDA/MathFunctions.h +91 -0
  145. data/vendor/eigen/Eigen/src/Core/arch/CUDA/PacketMath.h +333 -0
  146. data/vendor/eigen/Eigen/src/Core/arch/CUDA/PacketMathHalf.h +1124 -0
  147. data/vendor/eigen/Eigen/src/Core/arch/CUDA/TypeCasting.h +212 -0
  148. data/vendor/eigen/Eigen/src/Core/arch/Default/ConjHelper.h +29 -0
  149. data/vendor/eigen/Eigen/src/Core/arch/Default/Settings.h +49 -0
  150. data/vendor/eigen/Eigen/src/Core/arch/NEON/Complex.h +490 -0
  151. data/vendor/eigen/Eigen/src/Core/arch/NEON/MathFunctions.h +91 -0
  152. data/vendor/eigen/Eigen/src/Core/arch/NEON/PacketMath.h +760 -0
  153. data/vendor/eigen/Eigen/src/Core/arch/SSE/Complex.h +471 -0
  154. data/vendor/eigen/Eigen/src/Core/arch/SSE/MathFunctions.h +562 -0
  155. data/vendor/eigen/Eigen/src/Core/arch/SSE/PacketMath.h +895 -0
  156. data/vendor/eigen/Eigen/src/Core/arch/SSE/TypeCasting.h +77 -0
  157. data/vendor/eigen/Eigen/src/Core/arch/ZVector/Complex.h +397 -0
  158. data/vendor/eigen/Eigen/src/Core/arch/ZVector/MathFunctions.h +137 -0
  159. data/vendor/eigen/Eigen/src/Core/arch/ZVector/PacketMath.h +945 -0
  160. data/vendor/eigen/Eigen/src/Core/functors/AssignmentFunctors.h +168 -0
  161. data/vendor/eigen/Eigen/src/Core/functors/BinaryFunctors.h +475 -0
  162. data/vendor/eigen/Eigen/src/Core/functors/NullaryFunctors.h +188 -0
  163. data/vendor/eigen/Eigen/src/Core/functors/StlFunctors.h +136 -0
  164. data/vendor/eigen/Eigen/src/Core/functors/TernaryFunctors.h +25 -0
  165. data/vendor/eigen/Eigen/src/Core/functors/UnaryFunctors.h +792 -0
  166. data/vendor/eigen/Eigen/src/Core/products/GeneralBlockPanelKernel.h +2156 -0
  167. data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrix.h +492 -0
  168. data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h +311 -0
  169. data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h +145 -0
  170. data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h +122 -0
  171. data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixVector.h +619 -0
  172. data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixVector_BLAS.h +136 -0
  173. data/vendor/eigen/Eigen/src/Core/products/Parallelizer.h +163 -0
  174. data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix.h +521 -0
  175. data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix_BLAS.h +287 -0
  176. data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixVector.h +260 -0
  177. data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixVector_BLAS.h +118 -0
  178. data/vendor/eigen/Eigen/src/Core/products/SelfadjointProduct.h +133 -0
  179. data/vendor/eigen/Eigen/src/Core/products/SelfadjointRank2Update.h +93 -0
  180. data/vendor/eigen/Eigen/src/Core/products/TriangularMatrixMatrix.h +466 -0
  181. data/vendor/eigen/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h +315 -0
  182. data/vendor/eigen/Eigen/src/Core/products/TriangularMatrixVector.h +350 -0
  183. data/vendor/eigen/Eigen/src/Core/products/TriangularMatrixVector_BLAS.h +255 -0
  184. data/vendor/eigen/Eigen/src/Core/products/TriangularSolverMatrix.h +335 -0
  185. data/vendor/eigen/Eigen/src/Core/products/TriangularSolverMatrix_BLAS.h +163 -0
  186. data/vendor/eigen/Eigen/src/Core/products/TriangularSolverVector.h +145 -0
  187. data/vendor/eigen/Eigen/src/Core/util/BlasUtil.h +398 -0
  188. data/vendor/eigen/Eigen/src/Core/util/Constants.h +547 -0
  189. data/vendor/eigen/Eigen/src/Core/util/DisableStupidWarnings.h +83 -0
  190. data/vendor/eigen/Eigen/src/Core/util/ForwardDeclarations.h +302 -0
  191. data/vendor/eigen/Eigen/src/Core/util/MKL_support.h +130 -0
  192. data/vendor/eigen/Eigen/src/Core/util/Macros.h +1001 -0
  193. data/vendor/eigen/Eigen/src/Core/util/Memory.h +993 -0
  194. data/vendor/eigen/Eigen/src/Core/util/Meta.h +534 -0
  195. data/vendor/eigen/Eigen/src/Core/util/NonMPL2.h +3 -0
  196. data/vendor/eigen/Eigen/src/Core/util/ReenableStupidWarnings.h +27 -0
  197. data/vendor/eigen/Eigen/src/Core/util/StaticAssert.h +218 -0
  198. data/vendor/eigen/Eigen/src/Core/util/XprHelper.h +821 -0
  199. data/vendor/eigen/Eigen/src/Eigenvalues/ComplexEigenSolver.h +346 -0
  200. data/vendor/eigen/Eigen/src/Eigenvalues/ComplexSchur.h +459 -0
  201. data/vendor/eigen/Eigen/src/Eigenvalues/ComplexSchur_LAPACKE.h +91 -0
  202. data/vendor/eigen/Eigen/src/Eigenvalues/EigenSolver.h +622 -0
  203. data/vendor/eigen/Eigen/src/Eigenvalues/GeneralizedEigenSolver.h +418 -0
  204. data/vendor/eigen/Eigen/src/Eigenvalues/GeneralizedSelfAdjointEigenSolver.h +226 -0
  205. data/vendor/eigen/Eigen/src/Eigenvalues/HessenbergDecomposition.h +374 -0
  206. data/vendor/eigen/Eigen/src/Eigenvalues/MatrixBaseEigenvalues.h +158 -0
  207. data/vendor/eigen/Eigen/src/Eigenvalues/RealQZ.h +654 -0
  208. data/vendor/eigen/Eigen/src/Eigenvalues/RealSchur.h +546 -0
  209. data/vendor/eigen/Eigen/src/Eigenvalues/RealSchur_LAPACKE.h +77 -0
  210. data/vendor/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver.h +870 -0
  211. data/vendor/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver_LAPACKE.h +87 -0
  212. data/vendor/eigen/Eigen/src/Eigenvalues/Tridiagonalization.h +556 -0
  213. data/vendor/eigen/Eigen/src/Geometry/AlignedBox.h +392 -0
  214. data/vendor/eigen/Eigen/src/Geometry/AngleAxis.h +247 -0
  215. data/vendor/eigen/Eigen/src/Geometry/EulerAngles.h +114 -0
  216. data/vendor/eigen/Eigen/src/Geometry/Homogeneous.h +497 -0
  217. data/vendor/eigen/Eigen/src/Geometry/Hyperplane.h +282 -0
  218. data/vendor/eigen/Eigen/src/Geometry/OrthoMethods.h +234 -0
  219. data/vendor/eigen/Eigen/src/Geometry/ParametrizedLine.h +195 -0
  220. data/vendor/eigen/Eigen/src/Geometry/Quaternion.h +814 -0
  221. data/vendor/eigen/Eigen/src/Geometry/Rotation2D.h +199 -0
  222. data/vendor/eigen/Eigen/src/Geometry/RotationBase.h +206 -0
  223. data/vendor/eigen/Eigen/src/Geometry/Scaling.h +170 -0
  224. data/vendor/eigen/Eigen/src/Geometry/Transform.h +1542 -0
  225. data/vendor/eigen/Eigen/src/Geometry/Translation.h +208 -0
  226. data/vendor/eigen/Eigen/src/Geometry/Umeyama.h +166 -0
  227. data/vendor/eigen/Eigen/src/Geometry/arch/Geometry_SSE.h +161 -0
  228. data/vendor/eigen/Eigen/src/Householder/BlockHouseholder.h +103 -0
  229. data/vendor/eigen/Eigen/src/Householder/Householder.h +172 -0
  230. data/vendor/eigen/Eigen/src/Householder/HouseholderSequence.h +470 -0
  231. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/BasicPreconditioners.h +226 -0
  232. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h +228 -0
  233. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/ConjugateGradient.h +246 -0
  234. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/IncompleteCholesky.h +400 -0
  235. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/IncompleteLUT.h +462 -0
  236. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h +394 -0
  237. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/LeastSquareConjugateGradient.h +216 -0
  238. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/SolveWithGuess.h +115 -0
  239. data/vendor/eigen/Eigen/src/Jacobi/Jacobi.h +462 -0
  240. data/vendor/eigen/Eigen/src/LU/Determinant.h +101 -0
  241. data/vendor/eigen/Eigen/src/LU/FullPivLU.h +891 -0
  242. data/vendor/eigen/Eigen/src/LU/InverseImpl.h +415 -0
  243. data/vendor/eigen/Eigen/src/LU/PartialPivLU.h +611 -0
  244. data/vendor/eigen/Eigen/src/LU/PartialPivLU_LAPACKE.h +83 -0
  245. data/vendor/eigen/Eigen/src/LU/arch/Inverse_SSE.h +338 -0
  246. data/vendor/eigen/Eigen/src/MetisSupport/MetisSupport.h +137 -0
  247. data/vendor/eigen/Eigen/src/OrderingMethods/Amd.h +445 -0
  248. data/vendor/eigen/Eigen/src/OrderingMethods/Eigen_Colamd.h +1843 -0
  249. data/vendor/eigen/Eigen/src/OrderingMethods/Ordering.h +157 -0
  250. data/vendor/eigen/Eigen/src/PaStiXSupport/PaStiXSupport.h +678 -0
  251. data/vendor/eigen/Eigen/src/PardisoSupport/PardisoSupport.h +543 -0
  252. data/vendor/eigen/Eigen/src/QR/ColPivHouseholderQR.h +653 -0
  253. data/vendor/eigen/Eigen/src/QR/ColPivHouseholderQR_LAPACKE.h +97 -0
  254. data/vendor/eigen/Eigen/src/QR/CompleteOrthogonalDecomposition.h +562 -0
  255. data/vendor/eigen/Eigen/src/QR/FullPivHouseholderQR.h +676 -0
  256. data/vendor/eigen/Eigen/src/QR/HouseholderQR.h +409 -0
  257. data/vendor/eigen/Eigen/src/QR/HouseholderQR_LAPACKE.h +68 -0
  258. data/vendor/eigen/Eigen/src/SPQRSupport/SuiteSparseQRSupport.h +313 -0
  259. data/vendor/eigen/Eigen/src/SVD/BDCSVD.h +1246 -0
  260. data/vendor/eigen/Eigen/src/SVD/JacobiSVD.h +804 -0
  261. data/vendor/eigen/Eigen/src/SVD/JacobiSVD_LAPACKE.h +91 -0
  262. data/vendor/eigen/Eigen/src/SVD/SVDBase.h +315 -0
  263. data/vendor/eigen/Eigen/src/SVD/UpperBidiagonalization.h +414 -0
  264. data/vendor/eigen/Eigen/src/SparseCholesky/SimplicialCholesky.h +689 -0
  265. data/vendor/eigen/Eigen/src/SparseCholesky/SimplicialCholesky_impl.h +199 -0
  266. data/vendor/eigen/Eigen/src/SparseCore/AmbiVector.h +377 -0
  267. data/vendor/eigen/Eigen/src/SparseCore/CompressedStorage.h +258 -0
  268. data/vendor/eigen/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h +352 -0
  269. data/vendor/eigen/Eigen/src/SparseCore/MappedSparseMatrix.h +67 -0
  270. data/vendor/eigen/Eigen/src/SparseCore/SparseAssign.h +216 -0
  271. data/vendor/eigen/Eigen/src/SparseCore/SparseBlock.h +603 -0
  272. data/vendor/eigen/Eigen/src/SparseCore/SparseColEtree.h +206 -0
  273. data/vendor/eigen/Eigen/src/SparseCore/SparseCompressedBase.h +341 -0
  274. data/vendor/eigen/Eigen/src/SparseCore/SparseCwiseBinaryOp.h +726 -0
  275. data/vendor/eigen/Eigen/src/SparseCore/SparseCwiseUnaryOp.h +148 -0
  276. data/vendor/eigen/Eigen/src/SparseCore/SparseDenseProduct.h +320 -0
  277. data/vendor/eigen/Eigen/src/SparseCore/SparseDiagonalProduct.h +138 -0
  278. data/vendor/eigen/Eigen/src/SparseCore/SparseDot.h +98 -0
  279. data/vendor/eigen/Eigen/src/SparseCore/SparseFuzzy.h +29 -0
  280. data/vendor/eigen/Eigen/src/SparseCore/SparseMap.h +305 -0
  281. data/vendor/eigen/Eigen/src/SparseCore/SparseMatrix.h +1403 -0
  282. data/vendor/eigen/Eigen/src/SparseCore/SparseMatrixBase.h +405 -0
  283. data/vendor/eigen/Eigen/src/SparseCore/SparsePermutation.h +178 -0
  284. data/vendor/eigen/Eigen/src/SparseCore/SparseProduct.h +169 -0
  285. data/vendor/eigen/Eigen/src/SparseCore/SparseRedux.h +49 -0
  286. data/vendor/eigen/Eigen/src/SparseCore/SparseRef.h +397 -0
  287. data/vendor/eigen/Eigen/src/SparseCore/SparseSelfAdjointView.h +656 -0
  288. data/vendor/eigen/Eigen/src/SparseCore/SparseSolverBase.h +124 -0
  289. data/vendor/eigen/Eigen/src/SparseCore/SparseSparseProductWithPruning.h +198 -0
  290. data/vendor/eigen/Eigen/src/SparseCore/SparseTranspose.h +92 -0
  291. data/vendor/eigen/Eigen/src/SparseCore/SparseTriangularView.h +189 -0
  292. data/vendor/eigen/Eigen/src/SparseCore/SparseUtil.h +178 -0
  293. data/vendor/eigen/Eigen/src/SparseCore/SparseVector.h +478 -0
  294. data/vendor/eigen/Eigen/src/SparseCore/SparseView.h +253 -0
  295. data/vendor/eigen/Eigen/src/SparseCore/TriangularSolver.h +315 -0
  296. data/vendor/eigen/Eigen/src/SparseLU/SparseLU.h +773 -0
  297. data/vendor/eigen/Eigen/src/SparseLU/SparseLUImpl.h +66 -0
  298. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_Memory.h +226 -0
  299. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_Structs.h +110 -0
  300. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_SupernodalMatrix.h +301 -0
  301. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_Utils.h +80 -0
  302. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_column_bmod.h +181 -0
  303. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_column_dfs.h +179 -0
  304. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_copy_to_ucol.h +107 -0
  305. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_gemm_kernel.h +280 -0
  306. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_heap_relax_snode.h +126 -0
  307. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_kernel_bmod.h +130 -0
  308. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_panel_bmod.h +223 -0
  309. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_panel_dfs.h +258 -0
  310. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_pivotL.h +137 -0
  311. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_pruneL.h +136 -0
  312. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_relax_snode.h +83 -0
  313. data/vendor/eigen/Eigen/src/SparseQR/SparseQR.h +745 -0
  314. data/vendor/eigen/Eigen/src/StlSupport/StdDeque.h +126 -0
  315. data/vendor/eigen/Eigen/src/StlSupport/StdList.h +106 -0
  316. data/vendor/eigen/Eigen/src/StlSupport/StdVector.h +131 -0
  317. data/vendor/eigen/Eigen/src/StlSupport/details.h +84 -0
  318. data/vendor/eigen/Eigen/src/SuperLUSupport/SuperLUSupport.h +1027 -0
  319. data/vendor/eigen/Eigen/src/UmfPackSupport/UmfPackSupport.h +506 -0
  320. data/vendor/eigen/Eigen/src/misc/Image.h +82 -0
  321. data/vendor/eigen/Eigen/src/misc/Kernel.h +79 -0
  322. data/vendor/eigen/Eigen/src/misc/RealSvd2x2.h +55 -0
  323. data/vendor/eigen/Eigen/src/misc/blas.h +440 -0
  324. data/vendor/eigen/Eigen/src/misc/lapack.h +152 -0
  325. data/vendor/eigen/Eigen/src/misc/lapacke.h +16291 -0
  326. data/vendor/eigen/Eigen/src/misc/lapacke_mangling.h +17 -0
  327. data/vendor/eigen/Eigen/src/plugins/ArrayCwiseBinaryOps.h +332 -0
  328. data/vendor/eigen/Eigen/src/plugins/ArrayCwiseUnaryOps.h +552 -0
  329. data/vendor/eigen/Eigen/src/plugins/BlockMethods.h +1058 -0
  330. data/vendor/eigen/Eigen/src/plugins/CommonCwiseBinaryOps.h +115 -0
  331. data/vendor/eigen/Eigen/src/plugins/CommonCwiseUnaryOps.h +163 -0
  332. data/vendor/eigen/Eigen/src/plugins/MatrixCwiseBinaryOps.h +152 -0
  333. data/vendor/eigen/Eigen/src/plugins/MatrixCwiseUnaryOps.h +85 -0
  334. data/vendor/eigen/README.md +3 -0
  335. data/vendor/eigen/bench/README.txt +55 -0
  336. data/vendor/eigen/bench/btl/COPYING +340 -0
  337. data/vendor/eigen/bench/btl/README +154 -0
  338. data/vendor/eigen/bench/tensors/README +21 -0
  339. data/vendor/eigen/blas/README.txt +6 -0
  340. data/vendor/eigen/demos/mandelbrot/README +10 -0
  341. data/vendor/eigen/demos/mix_eigen_and_c/README +9 -0
  342. data/vendor/eigen/demos/opengl/README +13 -0
  343. data/vendor/eigen/unsupported/Eigen/CXX11/src/Tensor/README.md +1760 -0
  344. data/vendor/eigen/unsupported/README.txt +50 -0
  345. data/vendor/tomotopy/LICENSE +21 -0
  346. data/vendor/tomotopy/README.kr.rst +375 -0
  347. data/vendor/tomotopy/README.rst +382 -0
  348. data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +362 -0
  349. data/vendor/tomotopy/src/Labeling/FoRelevance.h +88 -0
  350. data/vendor/tomotopy/src/Labeling/Labeler.h +50 -0
  351. data/vendor/tomotopy/src/TopicModel/CT.h +37 -0
  352. data/vendor/tomotopy/src/TopicModel/CTModel.cpp +13 -0
  353. data/vendor/tomotopy/src/TopicModel/CTModel.hpp +293 -0
  354. data/vendor/tomotopy/src/TopicModel/DMR.h +51 -0
  355. data/vendor/tomotopy/src/TopicModel/DMRModel.cpp +13 -0
  356. data/vendor/tomotopy/src/TopicModel/DMRModel.hpp +374 -0
  357. data/vendor/tomotopy/src/TopicModel/DT.h +65 -0
  358. data/vendor/tomotopy/src/TopicModel/DTM.h +22 -0
  359. data/vendor/tomotopy/src/TopicModel/DTModel.cpp +15 -0
  360. data/vendor/tomotopy/src/TopicModel/DTModel.hpp +572 -0
  361. data/vendor/tomotopy/src/TopicModel/GDMR.h +37 -0
  362. data/vendor/tomotopy/src/TopicModel/GDMRModel.cpp +14 -0
  363. data/vendor/tomotopy/src/TopicModel/GDMRModel.hpp +485 -0
  364. data/vendor/tomotopy/src/TopicModel/HDP.h +74 -0
  365. data/vendor/tomotopy/src/TopicModel/HDPModel.cpp +13 -0
  366. data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +592 -0
  367. data/vendor/tomotopy/src/TopicModel/HLDA.h +40 -0
  368. data/vendor/tomotopy/src/TopicModel/HLDAModel.cpp +13 -0
  369. data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +681 -0
  370. data/vendor/tomotopy/src/TopicModel/HPA.h +27 -0
  371. data/vendor/tomotopy/src/TopicModel/HPAModel.cpp +21 -0
  372. data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +588 -0
  373. data/vendor/tomotopy/src/TopicModel/LDA.h +144 -0
  374. data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +442 -0
  375. data/vendor/tomotopy/src/TopicModel/LDAModel.cpp +13 -0
  376. data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +1058 -0
  377. data/vendor/tomotopy/src/TopicModel/LLDA.h +45 -0
  378. data/vendor/tomotopy/src/TopicModel/LLDAModel.cpp +13 -0
  379. data/vendor/tomotopy/src/TopicModel/LLDAModel.hpp +203 -0
  380. data/vendor/tomotopy/src/TopicModel/MGLDA.h +63 -0
  381. data/vendor/tomotopy/src/TopicModel/MGLDAModel.cpp +17 -0
  382. data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +558 -0
  383. data/vendor/tomotopy/src/TopicModel/PA.h +43 -0
  384. data/vendor/tomotopy/src/TopicModel/PAModel.cpp +13 -0
  385. data/vendor/tomotopy/src/TopicModel/PAModel.hpp +467 -0
  386. data/vendor/tomotopy/src/TopicModel/PLDA.h +17 -0
  387. data/vendor/tomotopy/src/TopicModel/PLDAModel.cpp +13 -0
  388. data/vendor/tomotopy/src/TopicModel/PLDAModel.hpp +214 -0
  389. data/vendor/tomotopy/src/TopicModel/SLDA.h +54 -0
  390. data/vendor/tomotopy/src/TopicModel/SLDAModel.cpp +17 -0
  391. data/vendor/tomotopy/src/TopicModel/SLDAModel.hpp +456 -0
  392. data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +692 -0
  393. data/vendor/tomotopy/src/Utils/AliasMethod.hpp +169 -0
  394. data/vendor/tomotopy/src/Utils/Dictionary.h +80 -0
  395. data/vendor/tomotopy/src/Utils/EigenAddonOps.hpp +181 -0
  396. data/vendor/tomotopy/src/Utils/LBFGS.h +202 -0
  397. data/vendor/tomotopy/src/Utils/LBFGS/LineSearchBacktracking.h +120 -0
  398. data/vendor/tomotopy/src/Utils/LBFGS/LineSearchBracketing.h +122 -0
  399. data/vendor/tomotopy/src/Utils/LBFGS/Param.h +213 -0
  400. data/vendor/tomotopy/src/Utils/LUT.hpp +82 -0
  401. data/vendor/tomotopy/src/Utils/MultiNormalDistribution.hpp +69 -0
  402. data/vendor/tomotopy/src/Utils/PolyaGamma.hpp +200 -0
  403. data/vendor/tomotopy/src/Utils/PolyaGammaHybrid.hpp +672 -0
  404. data/vendor/tomotopy/src/Utils/ThreadPool.hpp +150 -0
  405. data/vendor/tomotopy/src/Utils/Trie.hpp +220 -0
  406. data/vendor/tomotopy/src/Utils/TruncMultiNormal.hpp +94 -0
  407. data/vendor/tomotopy/src/Utils/Utils.hpp +337 -0
  408. data/vendor/tomotopy/src/Utils/avx_gamma.h +46 -0
  409. data/vendor/tomotopy/src/Utils/avx_mathfun.h +736 -0
  410. data/vendor/tomotopy/src/Utils/exception.h +28 -0
  411. data/vendor/tomotopy/src/Utils/math.h +281 -0
  412. data/vendor/tomotopy/src/Utils/rtnorm.hpp +2690 -0
  413. data/vendor/tomotopy/src/Utils/sample.hpp +192 -0
  414. data/vendor/tomotopy/src/Utils/serializer.hpp +695 -0
  415. data/vendor/tomotopy/src/Utils/slp.hpp +131 -0
  416. data/vendor/tomotopy/src/Utils/sse_gamma.h +48 -0
  417. data/vendor/tomotopy/src/Utils/sse_mathfun.h +710 -0
  418. data/vendor/tomotopy/src/Utils/text.hpp +49 -0
  419. data/vendor/tomotopy/src/Utils/tvector.hpp +543 -0
  420. metadata +531 -0
@@ -0,0 +1,13 @@
1
+ #include "LDAModel.hpp"
2
+
3
+ namespace tomoto
4
+ {
5
+ /*template class LDAModel<TermWeight::one>;
6
+ template class LDAModel<TermWeight::idf>;
7
+ template class LDAModel<TermWeight::pmi>;*/
8
+
9
+ ILDAModel* ILDAModel::create(TermWeight _weight, size_t _K, Float _alpha, Float _eta, size_t seed, bool scalarRng)
10
+ {
11
+ TMT_SWITCH_TW(_weight, scalarRng, LDAModel, _K, _alpha, _eta, seed);
12
+ }
13
+ }
@@ -0,0 +1,1058 @@
1
+ #pragma once
2
+ #include <unordered_set>
3
+ #include <numeric>
4
+ #include "TopicModel.hpp"
5
+ #include "../Utils/EigenAddonOps.hpp"
6
+ #include "../Utils/Utils.hpp"
7
+ #include "../Utils/math.h"
8
+ #include "../Utils/sample.hpp"
9
+ #include "LDA.h"
10
+
11
+ /*
12
+ Implementation of LDA using Gibbs sampling by bab2min
13
+
14
+ * Blei, D. M., Ng, A. Y., & Jordan, M. I. (2003). Latent dirichlet allocation. Journal of machine Learning research, 3(Jan), 993-1022.
15
+ * Newman, D., Asuncion, A., Smyth, P., & Welling, M. (2009). Distributed algorithms for topic models. Journal of Machine Learning Research, 10(Aug), 1801-1828.
16
+
17
+ Term Weighting Scheme is based on following paper:
18
+ * Wilson, A. T., & Chew, P. A. (2010, June). Term weighting schemes for latent dirichlet allocation. In human language technologies: The 2010 annual conference of the North American Chapter of the Association for Computational Linguistics (pp. 465-473). Association for Computational Linguistics.
19
+
20
+ */
21
+
22
+ #ifdef TMT_SCALAR_RNG
23
+ #define TMT_SWITCH_TW(TW, SRNG, MDL, ...) do{\
24
+ {\
25
+ switch (TW){\
26
+ case TermWeight::one:\
27
+ return new MDL<TermWeight::one, ScalarRandGen>(__VA_ARGS__);\
28
+ case TermWeight::idf:\
29
+ return new MDL<TermWeight::idf, ScalarRandGen>(__VA_ARGS__);\
30
+ case TermWeight::pmi:\
31
+ return new MDL<TermWeight::pmi, ScalarRandGen>(__VA_ARGS__);\
32
+ }\
33
+ }\
34
+ return nullptr; } while(0)
35
+ #else
36
+ #define TMT_SWITCH_TW(TW, SRNG, MDL, ...) do{\
37
+ {\
38
+ switch (TW){\
39
+ case TermWeight::one:\
40
+ return new MDL<TermWeight::one, RandGen>(__VA_ARGS__);\
41
+ case TermWeight::idf:\
42
+ return new MDL<TermWeight::idf, RandGen>(__VA_ARGS__);\
43
+ case TermWeight::pmi:\
44
+ return new MDL<TermWeight::pmi, RandGen>(__VA_ARGS__);\
45
+ }\
46
+ }\
47
+ return nullptr; } while(0)
48
+ #endif
49
+
50
+ #define GETTER(name, type, field) type get##name() const override { return field; }
51
+
52
+ namespace tomoto
53
+ {
54
+ template<TermWeight _tw>
55
+ struct ModelStateLDA
56
+ {
57
+ using WeightType = typename std::conditional<_tw == TermWeight::one, int32_t, float>::type;
58
+
59
+ Eigen::Matrix<Float, -1, 1> zLikelihood;
60
+ Eigen::Matrix<WeightType, -1, 1> numByTopic; // Dim: (Topic, 1)
61
+ Eigen::Matrix<WeightType, -1, -1> numByTopicWord; // Dim: (Topic, Vocabs)
62
+ DEFINE_SERIALIZER(numByTopic, numByTopicWord);
63
+ };
64
+
65
+ namespace flags
66
+ {
67
+ enum
68
+ {
69
+ generator_by_doc = end_flag_of_TopicModel,
70
+ end_flag_of_LDAModel = generator_by_doc << 1,
71
+ };
72
+ }
73
+
74
+
75
+ template<typename _Model, bool _asymEta>
76
+ class EtaHelper
77
+ {
78
+ const _Model& _this;
79
+ public:
80
+ EtaHelper(const _Model& p) : _this(p) {}
81
+
82
+ Float getEta(size_t vid) const
83
+ {
84
+ return _this.eta;
85
+ }
86
+
87
+ Float getEtaSum() const
88
+ {
89
+ return _this.eta * _this.realV;
90
+ }
91
+ };
92
+
93
+ template<typename _Model>
94
+ class EtaHelper<_Model, true>
95
+ {
96
+ const _Model& _this;
97
+ public:
98
+ EtaHelper(const _Model& p) : _this(p) {}
99
+
100
+ auto getEta(size_t vid) const
101
+ -> decltype(_this.etaByTopicWord.col(vid).array())
102
+ {
103
+ return _this.etaByTopicWord.col(vid).array();
104
+ }
105
+
106
+ auto getEtaSum() const
107
+ -> decltype(_this.etaSumByTopic.array())
108
+ {
109
+ return _this.etaSumByTopic.array();
110
+ }
111
+ };
112
+
113
+ template<TermWeight _tw>
114
+ struct TwId;
115
+
116
+ template<>
117
+ struct TwId<TermWeight::one>
118
+ {
119
+ static constexpr char TWID[] = "one\0";
120
+ };
121
+
122
+ template<>
123
+ struct TwId<TermWeight::idf>
124
+ {
125
+ static constexpr char TWID[] = "idf\0";
126
+ };
127
+
128
+ template<>
129
+ struct TwId<TermWeight::pmi>
130
+ {
131
+ static constexpr char TWID[] = "pmi\0";
132
+ };
133
+
134
+ // to make HDP friend of LDA for HDPModel::converToLDA
135
+ template<TermWeight _tw,
136
+ typename _RandGen,
137
+ typename _Interface,
138
+ typename _Derived,
139
+ typename _DocType,
140
+ typename _ModelState>
141
+ class HDPModel;
142
+
143
	template<TermWeight _tw, typename _RandGen,
		size_t _Flags = flags::partitioned_multisampling,
		typename _Interface = ILDAModel,
		typename _Derived = void,
		typename _DocType = DocumentLDA<_tw>,
		typename _ModelState = ModelStateLDA<_tw>>
	class LDAModel : public TopicModel<_RandGen, _Flags, _Interface,
		typename std::conditional<std::is_same<_Derived, void>::value, LDAModel<_tw, _RandGen, _Flags>, _Derived>::type,
		_DocType, _ModelState>,
		protected TwId<_tw>
	{
	protected:
		// CRTP: when no _Derived is supplied, this class itself is the most-derived type.
		using DerivedClass = typename std::conditional<std::is_same<_Derived, void>::value, LDAModel, _Derived>::type;
		using BaseClass = TopicModel<_RandGen, _Flags, _Interface, DerivedClass, _DocType, _ModelState>;
		friend BaseClass;
		friend EtaHelper<DerivedClass, true>;
		friend EtaHelper<DerivedClass, false>;

		// to make HDP friend of LDA for HDPModel::converToLDA
		template<TermWeight,
			typename,
			typename,
			typename,
			typename,
			typename>
		friend class HDPModel;

		static constexpr char TMID[] = "LDA\0";
		// Counts are exact integers under TermWeight::one, fractional otherwise.
		using WeightType = typename std::conditional<_tw == TermWeight::one, int32_t, float>::type;

		enum { m_flags = _Flags };

		std::vector<Float> vocabWeights;       // per-word weight for idf/pmi schemes
		std::vector<Tid> sharedZs;             // topic assignments pooled across docs (see prepareShared)
		std::vector<Float> sharedWordWeights;  // word weights pooled across docs
		Tid K;                                 // number of topics
		Float alpha, eta;                      // symmetric hyperparameters
		Eigen::Matrix<Float, -1, 1> alphas;    // per-topic alpha; diverges from `alpha` after optimization
		std::unordered_map<std::string, std::vector<Float>> etaByWord; // user-set per-word priors
		Eigen::Matrix<Float, -1, -1> etaByTopicWord; // (K, V)
		Eigen::Matrix<Float, -1, 1> etaSumByTopic; // (K, )
		uint32_t optimInterval = 10, burnIn = 0; // hyperparameter-optimization schedule
		Eigen::Matrix<WeightType, -1, -1> numByTopicDoc; // (K, D) when flags::continuous_doc_data

		// Partitioning info used by ParallelScheme::partition.
		struct ExtraDocData
		{
			std::vector<Vid> vChunkOffset; // exclusive vocab-id upper bound per partition
			Eigen::Matrix<uint32_t, -1, -1> chunkOffsetByDoc; // (numPools + 1, D) word-index bounds per doc
		};

		ExtraDocData eddTrain;
193
+
194
		// Computes sum_i (digamma(list(i) + alpha) - digamma(alpha)), splitting
		// the range across the thread pool when it is large enough.
		// `list` is a nullary functor usable inside an Eigen NullaryExpr.
		template<typename _List>
		static Float calcDigammaSum(ThreadPool* pool, _List list, size_t len, Float alpha)
		{
			auto listExpr = Eigen::Matrix<Float, -1, 1>::NullaryExpr(len, list);
			auto dAlpha = math::digammaT(alpha);

			// one chunk per ~128 elements, capped by the number of workers
			size_t suggested = (len + 127) / 128;
			if (pool && suggested > pool->getNumWorkers()) suggested = pool->getNumWorkers();
			if (suggested <= 1 || !pool)
			{
				return (math::digammaApprox(listExpr.array() + alpha) - dAlpha).sum();
			}


			std::vector<std::future<Float>> futures;
			for (size_t i = 0; i < suggested; ++i)
			{
				// chunk bounds rounded up to multiples of 16 for aligned segments
				size_t start = (len * i / suggested + 15) & ~0xF,
					end = std::min((len * (i + 1) / suggested + 15) & ~0xF, len);
				futures.emplace_back(pool->enqueue([&, start, end, dAlpha](size_t)
				{
					return (math::digammaApprox(listExpr.array().segment(start, end - start) + alpha) - dAlpha).sum();
				}));
			}
			Float ret = 0;
			// joining here keeps listExpr alive for the lambdas capturing it by reference
			for (auto& f : futures) ret += f.get();
			return ret;
		}
222
+
223
		/*
		function for optimizing hyperparameters
		Runs 10 rounds of a fixed-point update of the per-topic Dirichlet
		parameters `alphas` from the observed doc-topic counts, clamping each
		value at 1e-5 for numerical stability.
		*/
		void optimizeParameters(ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
		{
			const auto K = this->K;
			for (size_t i = 0; i < 10; ++i)
			{
				// shared denominator: sum over docs of digamma(N_d + sum_k alpha_k) - digamma(sum_k alpha_k)
				Float denom = calcDigammaSum(&pool, [&](size_t i) { return this->docs[i].getSumWordWeight(); }, this->docs.size(), alphas.sum());
				for (size_t k = 0; k < K; ++k)
				{
					Float nom = calcDigammaSum(&pool, [&](size_t i) { return this->docs[i].numByTopic[k]; }, this->docs.size(), alphas(k));
					alphas(k) = std::max(nom / denom * alphas(k), 1e-5f);
				}
			}
		}
239
+
240
+ template<bool _asymEta>
241
+ EtaHelper<DerivedClass, _asymEta> getEtaHelper() const
242
+ {
243
+ return EtaHelper<DerivedClass, _asymEta>{ *static_cast<const DerivedClass*>(this) };
244
+ }
245
+
246
		// Fills ld.zLikelihood with the cumulative (prefix-summed, unnormalized)
		// full-conditional distribution over topics for word `vid`, and returns
		// a pointer to its buffer for sampling.
		template<bool _asymEta>
		Float* getZLikelihoods(_ModelState& ld, const _DocType& doc, size_t docId, size_t vid) const
		{
			const size_t V = this->realV;
			assert(vid < V);
			auto etaHelper = this->template getEtaHelper<_asymEta>();
			auto& zLikelihood = ld.zLikelihood;
			// p(z = k | rest) ∝ (n_dk + alpha_k) * (n_kv + eta_kv) / (n_k + sum_v eta_kv)
			zLikelihood = (doc.numByTopic.array().template cast<Float>() + alphas.array())
				* (ld.numByTopicWord.col(vid).array().template cast<Float>() + etaHelper.getEta(vid))
				/ (ld.numByTopic.array().template cast<Float>() + etaHelper.getEtaSum());
			// convert to a running sum so the sampler can search the CDF
			sample::prefixSum(zLikelihood.data(), K);
			return &zLikelihood[0];
		}
259
+
260
		// Adds (_inc = +1) or removes (_inc = -1) the contribution of the word at
		// position `pid` with topic `tid` to/from the doc and model count tables.
		template<int _inc>
		inline void addWordTo(_ModelState& ld, _DocType& doc, uint32_t pid, Vid vid, Tid tid) const
		{
			assert(tid < K);
			assert(vid < this->realV);
			// _dec flags a decrement of a fractional count; updateCnt (project
			// helper) presumably guards against underflow in that case — TODO confirm
			constexpr bool _dec = _inc < 0 && _tw != TermWeight::one;
			typename std::conditional<_tw != TermWeight::one, float, int32_t>::type weight
				= _tw != TermWeight::one ? doc.wordWeights[pid] : 1;

			updateCnt<_dec>(doc.numByTopic[tid], _inc * weight);
			updateCnt<_dec>(ld.numByTopic[tid], _inc * weight);
			updateCnt<_dec>(ld.numByTopicWord(tid, vid), _inc * weight);
		}
273
+
274
+ void resetStatistics()
275
+ {
276
+ this->globalState.numByTopic.setZero();
277
+ this->globalState.numByTopicWord.setZero();
278
+ for (auto& doc : this->docs)
279
+ {
280
+ doc.numByTopic.setZero();
281
+ for (size_t w = 0; w < doc.words.size(); ++w)
282
+ {
283
+ if (doc.words[w] >= this->realV) continue;
284
+ addWordTo<1>(this->globalState, doc, w, doc.words[w], doc.Zs[w]);
285
+ }
286
+ }
287
+ }
288
+
289
		/*
		called once before sampleDocument
		Hook for derived models; base LDA needs no per-document preparation.
		*/
		void presampleDocument(_DocType& doc, size_t docId, _ModelState& ld, _RandGen& rgs, size_t iterationCnt) const
		{
		}
295
+
296
		/*
		main sampling procedure (can be called one or more by ParallelScheme)
		Runs collapsed Gibbs sampling over the words of `doc`. Under
		ParallelScheme::partition only the word range and vocabulary slice
		belonging to `partitionId` are visited.
		*/
		template<ParallelScheme _ps, bool _infer, typename _ExtraDocData>
		void sampleDocument(_DocType& doc, const _ExtraDocData& edd, size_t docId, _ModelState& ld, _RandGen& rgs, size_t iterationCnt, size_t partitionId = 0) const
		{
			size_t b = 0, e = doc.words.size();
			if (_ps == ParallelScheme::partition)
			{
				b = edd.chunkOffsetByDoc(partitionId, docId);
				e = edd.chunkOffsetByDoc(partitionId + 1, docId);
			}

			// partition-local states hold only a vocab slice; shift word ids into it
			size_t vOffset = (_ps == ParallelScheme::partition && partitionId) ? edd.vChunkOffset[partitionId - 1] : 0;

			for (size_t w = b; w < e; ++w)
			{
				if (doc.words[w] >= this->realV) continue;
				// remove current assignment, resample, then re-register the new one
				addWordTo<-1>(ld, doc, w, doc.words[w] - vOffset, doc.Zs[w]);
				Float* dist;
				if (etaByTopicWord.size())
				{
					// asymmetric priors active once prepareWordPriors has run
					dist = static_cast<const DerivedClass*>(this)->template
						getZLikelihoods<true>(ld, doc, docId, doc.words[w] - vOffset);
				}
				else
				{
					dist = static_cast<const DerivedClass*>(this)->template
						getZLikelihoods<false>(ld, doc, docId, doc.words[w] - vOffset);
				}
				doc.Zs[w] = sample::sampleFromDiscreteAcc(dist, dist + K, rgs);
				addWordTo<1>(ld, doc, w, doc.words[w] - vOffset, doc.Zs[w]);
			}
		}
330
+
331
		// Dispatches one sweep of Gibbs sampling over [docFirst, docLast)
		// according to the parallelization scheme _ps.
		template<ParallelScheme _ps, bool _infer, typename _DocIter, typename _ExtraDocData>
		void performSampling(ThreadPool& pool, _ModelState* localData, _RandGen* rgs, std::vector<std::future<void>>& res,
			_DocIter docFirst, _DocIter docLast, const _ExtraDocData& edd) const
		{
			// single-threaded sampling
			if (_ps == ParallelScheme::none)
			{
				forRandom((size_t)std::distance(docFirst, docLast), rgs[0](), [&](size_t id)
				{
					static_cast<const DerivedClass*>(this)->presampleDocument(docFirst[id], id, *localData, *rgs, this->globalStep);
					static_cast<const DerivedClass*>(this)->template sampleDocument<_ps, _infer>(
						docFirst[id], edd, id,
						*localData, *rgs, this->globalStep, 0);

				});
			}
			// multi-threaded sampling on partition ad update into global
			else if (_ps == ParallelScheme::partition)
			{
				const size_t chStride = pool.getNumWorkers();
				// rotate document stripes across partitions so each worker visits
				// every vocabulary slice exactly once per sweep
				for (size_t i = 0; i < chStride; ++i)
				{
					res = pool.enqueueToAll([&, i, chStride](size_t partitionId)
					{
						size_t didx = (i + partitionId) % chStride;
						forRandom(((size_t)std::distance(docFirst, docLast) + (chStride - 1) - didx) / chStride, rgs[partitionId](), [&](size_t id)
						{
							// presample only on the first rotation for each document
							if (i == 0)
							{
								static_cast<const DerivedClass*>(this)->presampleDocument(
									docFirst[id * chStride + didx], id * chStride + didx,
									localData[partitionId], rgs[partitionId], this->globalStep
								);
							}
							static_cast<const DerivedClass*>(this)->template sampleDocument<_ps, _infer>(
								docFirst[id * chStride + didx], edd, id * chStride + didx,
								localData[partitionId], rgs[partitionId], this->globalStep, partitionId
							);
						});
					});
					// barrier between rotations: slices must not be sampled concurrently
					for (auto& r : res) r.get();
					res.clear();
				}
			}
			// multi-threaded sampling on copy and merge into global
			else if(_ps == ParallelScheme::copy_merge)
			{
				const size_t chStride = std::min(pool.getNumWorkers() * 8, (size_t)std::distance(docFirst, docLast));
				for (size_t ch = 0; ch < chStride; ++ch)
				{
					res.emplace_back(pool.enqueue([&, ch, chStride](size_t threadId)
					{
						forRandom(((size_t)std::distance(docFirst, docLast) + (chStride - 1) - ch) / chStride, rgs[threadId](), [&](size_t id)
						{
							static_cast<const DerivedClass*>(this)->presampleDocument(
								docFirst[id * chStride + ch], id * chStride + ch,
								localData[threadId], rgs[threadId], this->globalStep
							);
							static_cast<const DerivedClass*>(this)->template sampleDocument<_ps, _infer>(
								docFirst[id * chStride + ch], edd, id * chStride + ch,
								localData[threadId], rgs[threadId], this->globalStep, 0
							);
						});
					}));
				}
				for (auto& r : res) r.get();
				res.clear();
			}
		}
400
+
401
+ template<typename _DocIter, typename _ExtraDocData>
402
+ void updatePartition(ThreadPool& pool, const _ModelState& globalState, _ModelState* localData, _DocIter first, _DocIter last, _ExtraDocData& edd) const
403
+ {
404
+ size_t numPools = pool.getNumWorkers();
405
+ if (edd.vChunkOffset.size() != numPools)
406
+ {
407
+ edd.vChunkOffset.clear();
408
+ size_t totCnt = std::accumulate(this->vocabCf.begin(), this->vocabCf.begin() + this->realV, 0);
409
+ size_t cumCnt = 0;
410
+ for (size_t i = 0; i < this->realV; ++i)
411
+ {
412
+ cumCnt += this->vocabCf[i];
413
+ if (cumCnt * numPools >= totCnt * (edd.vChunkOffset.size() + 1)) edd.vChunkOffset.emplace_back(i + 1);
414
+ }
415
+
416
+ edd.chunkOffsetByDoc.resize(numPools + 1, std::distance(first, last));
417
+ size_t i = 0;
418
+ for (; first != last; ++first, ++i)
419
+ {
420
+ auto& doc = *first;
421
+ edd.chunkOffsetByDoc(0, i) = 0;
422
+ size_t g = 0;
423
+ for (size_t j = 0; j < doc.words.size(); ++j)
424
+ {
425
+ for (; g < numPools && doc.words[j] >= edd.vChunkOffset[g]; ++g)
426
+ {
427
+ edd.chunkOffsetByDoc(g + 1, i) = j;
428
+ }
429
+ }
430
+ for (; g < numPools; ++g)
431
+ {
432
+ edd.chunkOffsetByDoc(g + 1, i) = doc.words.size();
433
+ }
434
+ }
435
+ }
436
+ static_cast<const DerivedClass*>(this)->distributePartition(pool, globalState, localData, edd);
437
+ }
438
+
439
		// Copies each worker's vocabulary slice of the topic-word counts (and the
		// shared per-topic totals) from the global state into its local state.
		template<typename _ExtraDocData>
		void distributePartition(ThreadPool& pool, const _ModelState& globalState, _ModelState* localData, const _ExtraDocData& edd) const
		{
			std::vector<std::future<void>> res = pool.enqueueToAll([&](size_t partitionId)
			{
				// [b, e): vocabulary id range owned by this partition
				size_t b = partitionId ? edd.vChunkOffset[partitionId - 1] : 0,
					e = edd.vChunkOffset[partitionId];

				localData[partitionId].numByTopicWord = globalState.numByTopicWord.block(0, b, globalState.numByTopicWord.rows(), e - b);
				localData[partitionId].numByTopic = globalState.numByTopic;
				// allocate the likelihood buffer on first use only
				if (!localData[partitionId].zLikelihood.size()) localData[partitionId].zLikelihood = globalState.zLikelihood;
			});

			for (auto& r : res) r.get();
		}
454
+
455
+ template<ParallelScheme _ps>
456
+ size_t estimateMaxThreads() const
457
+ {
458
+ if (_ps == ParallelScheme::partition)
459
+ {
460
+ return (this->realV + 3) / 4;
461
+ }
462
+ if (_ps == ParallelScheme::copy_merge)
463
+ {
464
+ return (this->docs.size() + 1) / 2;
465
+ }
466
+ return (size_t)-1;
467
+ }
468
+
469
		// One training iteration: sample every document, merge thread-local
		// states, run model-specific global sampling, and periodically optimize
		// the alpha hyperparameters.
		template<ParallelScheme _ps>
		void trainOne(ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
		{
			std::vector<std::future<void>> res;
			try
			{
				performSampling<_ps, false>(pool, localData, rgs, res,
					this->docs.begin(), this->docs.end(), eddTrain);
				static_cast<DerivedClass*>(this)->updateGlobalInfo(pool, localData);
				static_cast<DerivedClass*>(this)->template mergeState<_ps>(pool, this->globalState, this->tState, localData, rgs, eddTrain);
				static_cast<DerivedClass*>(this)->template sampleGlobalLevel<>(&pool, localData, rgs, this->docs.begin(), this->docs.end());
				// alpha optimization only after burn-in, every optimInterval steps
				if (this->globalStep >= this->burnIn && optimInterval && (this->globalStep + 1) % optimInterval == 0)
				{
					static_cast<DerivedClass*>(this)->optimizeParameters(pool, localData, rgs);
				}
			}
			catch (const exception::TrainingError&)
			{
				// drain still-running tasks before rethrowing so workers do not
				// touch state while the stack unwinds
				for (auto& r : res) if(r.valid()) r.get();
				throw;
			}
		}
491
+
492
		/*
		updates global informations after sampling documents
		ex) update new global K at HDP model
		Base LDA has nothing to update; hook for derived models.
		*/
		void updateGlobalInfo(ThreadPool& pool, _ModelState* localData)
		{
		}
499
+
500
		/*
		merges multithreaded document sampling result
		*/
		template<ParallelScheme _ps, typename _ExtraDocData>
		void mergeState(ThreadPool& pool, _ModelState& globalState, _ModelState& tState, _ModelState* localData, _RandGen*, const _ExtraDocData& edd) const
		{
			std::vector<std::future<void>> res;

			if (_ps == ParallelScheme::copy_merge)
			{
				// tState snapshots the pre-sweep global counts so each worker's
				// delta (local - snapshot) is accumulated exactly once
				tState = globalState;
				globalState = localData[0];
				for (size_t i = 1; i < pool.getNumWorkers(); ++i)
				{
					globalState.numByTopicWord += localData[i].numByTopicWord - tState.numByTopicWord;
				}

				// make all count being positive
				if (_tw != TermWeight::one)
				{
					globalState.numByTopicWord = globalState.numByTopicWord.cwiseMax(0);
				}
				// per-topic totals re-derived from the merged table for consistency
				globalState.numByTopic = globalState.numByTopicWord.rowwise().sum();

				// redistribute the merged state to every worker copy
				for (size_t i = 0; i < pool.getNumWorkers(); ++i)
				{
					res.emplace_back(pool.enqueue([&, i](size_t)
					{
						localData[i] = globalState;
					}));
				}
			}
			else if (_ps == ParallelScheme::partition)
			{
				// partitions own disjoint vocab slices: copy each slice back as-is
				res = pool.enqueueToAll([&](size_t partitionId)
				{
					size_t b = partitionId ? edd.vChunkOffset[partitionId - 1] : 0,
						e = edd.vChunkOffset[partitionId];
					globalState.numByTopicWord.block(0, b, globalState.numByTopicWord.rows(), e - b) = localData[partitionId].numByTopicWord;
				});
				for (auto& r : res) r.get();
				res.clear();

				// make all count being positive
				if (_tw != TermWeight::one)
				{
					globalState.numByTopicWord = globalState.numByTopicWord.cwiseMax(0);
				}
				globalState.numByTopic = globalState.numByTopicWord.rowwise().sum();

				// only the shared per-topic totals need refreshing in the locals
				res = pool.enqueueToAll([&](size_t threadId)
				{
					localData[threadId].numByTopic = globalState.numByTopic;
				});
			}
			for (auto& r : res) r.get();
		}
557
+
558
		/*
		performs sampling which needs global state modification
		ex) document pathing at hLDA model
		* if pool is nullptr, workers has been already pooled and cannot branch works more.
		Base LDA has no global-level sampling; both overloads are no-op hooks.
		*/
		template<typename _DocIter>
		void sampleGlobalLevel(ThreadPool* pool, _ModelState* localData, _RandGen* rgs, _DocIter first, _DocIter last) const
		{
		}

		// non-const overload for derived models that mutate shared structures
		template<typename _DocIter>
		void sampleGlobalLevel(ThreadPool* pool, _ModelState* localData, _RandGen* rgs, _DocIter first, _DocIter last)
		{
		}
572
+
573
+ template<typename _DocIter>
574
+ double getLLDocs(_DocIter _first, _DocIter _last) const
575
+ {
576
+ double ll = 0;
577
+ // doc-topic distribution
578
+ for (; _first != _last; ++_first)
579
+ {
580
+ auto& doc = *_first;
581
+ ll -= math::lgammaT(doc.getSumWordWeight() + alphas.sum()) - math::lgammaT(alphas.sum());
582
+ for (Tid k = 0; k < K; ++k)
583
+ {
584
+ ll += math::lgammaT(doc.numByTopic[k] + alphas[k]) - math::lgammaT(alphas[k]);
585
+ }
586
+ }
587
+ return ll;
588
+ }
589
+
590
		// Log-likelihood contribution of the topic-word part under prior eta.
		double getLLRest(const _ModelState& ld) const
		{
			double ll = 0;
			const size_t V = this->realV;
			// topic-word distribution
			auto lgammaEta = math::lgammaT(eta);
			ll += math::lgammaT(V*eta) * K;
			for (Tid k = 0; k < K; ++k)
			{
				ll -= math::lgammaT(ld.numByTopic[k] + V * eta);
				for (Vid v = 0; v < V; ++v)
				{
					// zero counts contribute lgamma(eta) - lgamma(eta) == 0; skip
					if (!ld.numByTopicWord(k, v)) continue;
					ll += math::lgammaT(ld.numByTopicWord(k, v) + eta) - lgammaEta;
					assert(std::isfinite(ll));
				}
			}
			return ll;
		}
609
+
610
+ double getLL() const
611
+ {
612
+ return static_cast<const DerivedClass*>(this)->template getLLDocs<>(this->docs.begin(), this->docs.end())
613
+ + static_cast<const DerivedClass*>(this)->getLLRest(this->globalState);
614
+ }
615
+
616
		// Consolidates per-document Zs (and word weights for weighted schemes)
		// into the shared buffers via tvector::trade (project helper; presumably
		// moves storage into one contiguous arena — TODO confirm semantics).
		void prepareShared()
		{
			auto txZs = [](_DocType& doc) { return &doc.Zs; };
			tvector<Tid>::trade(sharedZs,
				makeTransformIter(this->docs.begin(), txZs),
				makeTransformIter(this->docs.end(), txZs));
			if (_tw != TermWeight::one)
			{
				auto txWeights = [](_DocType& doc) { return &doc.wordWeights; };
				tvector<Float>::trade(sharedWordWeights,
					makeTransformIter(this->docs.begin(), txWeights),
					makeTransformIter(this->docs.end(), txWeights));
			}
		}
630
+
631
+ WeightType* getTopicDocPtr(size_t docId) const
632
+ {
633
+ if (!(m_flags & flags::continuous_doc_data) || docId == (size_t)-1) return nullptr;
634
+ return (WeightType*)numByTopicDoc.col(docId).data();
635
+ }
636
+
637
		// Initializes a document's mutable sampling state: canonical word order,
		// topic-count storage (possibly backed by numByTopicDoc), assignment
		// vector, and — for weighted schemes — unit word weights.
		void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
		{
			sortAndWriteOrder(doc.words, doc.wOrder);
			doc.numByTopic.init(getTopicDocPtr(docId), K);
			doc.Zs = tvector<Tid>(wordSize);
			if(_tw != TermWeight::one) doc.wordWeights.resize(wordSize, 1);
		}
644
+
645
		// Expands user-supplied per-word priors (etaByWord) into the dense (K, V)
		// table etaByTopicWord plus per-topic sums; words without an explicit
		// prior fall back to the symmetric eta.
		void prepareWordPriors()
		{
			if (etaByWord.empty()) return;
			etaByTopicWord.resize(K, this->realV);
			etaSumByTopic.resize(K);
			etaByTopicWord.array() = eta;
			for (auto& it : etaByWord)
			{
				auto id = this->dict.toWid(it.first);
				// ignore priors for words pruned from the effective vocabulary
				if (id == (Vid)-1 || id >= this->realV) continue;
				// setWordPrior enforces it.second.size() == K, matching the column
				etaByTopicWord.col(id) = Eigen::Map<Eigen::Matrix<Float, -1, 1>>{ it.second.data(), (Eigen::Index)it.second.size() };
			}
			etaSumByTopic = etaByTopicWord.rowwise().sum();
		}
659
+
660
		// Allocates/zeroes the global count tables; when continuous_doc_data is
		// set, per-document topic counts live in a single (K, D) matrix.
		void initGlobalState(bool initDocs)
		{
			const size_t V = this->realV;
			this->globalState.zLikelihood = Eigen::Matrix<Float, -1, 1>::Zero(K);
			if (initDocs)
			{
				// counts are rebuilt from loaded assignments otherwise
				this->globalState.numByTopic = Eigen::Matrix<WeightType, -1, 1>::Zero(K);
				this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(K, V);
			}
			if(m_flags & flags::continuous_doc_data) numByTopicDoc = Eigen::Matrix<WeightType, -1, -1>::Zero(K, this->docs.size());
		}
671
+
672
		// State used to draw initial random topic assignments.
		struct Generator
		{
			std::uniform_int_distribution<Tid> theta;
		};

		// Base LDA ignores the document and draws topics uniformly; derived
		// models may build a per-document generator (see flags::generator_by_doc).
		Generator makeGeneratorForInit(const _DocType*) const
		{
			return Generator{ std::uniform_int_distribution<Tid>{0, (Tid)(K - 1)} };
		}
681
+
682
		// Assigns an initial topic to word i of `doc` — biased by the word's
		// prior column when word priors exist, uniform otherwise — and registers
		// it in the count tables.
		template<bool _Infer>
		void updateStateWithDoc(Generator& g, _ModelState& ld, _RandGen& rgs, _DocType& doc, size_t i) const
		{
			auto& z = doc.Zs[i];
			auto w = doc.words[i];
			if (etaByTopicWord.size())
			{
				// sample proportionally to the word's per-topic prior column
				auto col = etaByTopicWord.col(w);
				z = sample::sampleFromDiscrete(col.data(), col.data() + col.size(), rgs);
			}
			else
			{
				z = g.theta(rgs);
			}
			addWordTo<1>(ld, doc, i, w, z);
		}
698
+
699
		// Prepares a document and draws its initial topic assignments, computing
		// per-word term weights (idf/pmi) along the way.
		template<bool _Infer, typename _Generator>
		void initializeDocState(_DocType& doc, size_t docId, _Generator& g, _ModelState& ld, _RandGen& rgs) const
		{
			std::vector<uint32_t> tf(this->realV);
			static_cast<const DerivedClass*>(this)->prepareDoc(doc, docId, doc.words.size());
			_Generator g2;
			_Generator* selectedG = &g;
			if (m_flags & flags::generator_by_doc)
			{
				// derived model requested a per-document generator
				g2 = static_cast<const DerivedClass*>(this)->makeGeneratorForInit(&doc);
				selectedG = &g2;
			}
			if (_tw == TermWeight::pmi)
			{
				// in-document term frequencies needed for the PMI weighting
				std::fill(tf.begin(), tf.end(), 0);
				for (auto& w : doc.words) if(w < this->realV) ++tf[w];
			}

			for (size_t i = 0; i < doc.words.size(); ++i)
			{
				if (doc.words[i] >= this->realV) continue;
				if (_tw == TermWeight::idf)
				{
					doc.wordWeights[i] = vocabWeights[doc.words[i]];
				}
				else if (_tw == TermWeight::pmi)
				{
					// pointwise mutual information of the word with the doc, clamped at 0
					doc.wordWeights[i] = std::max((Float)log(tf[doc.words[i]] / vocabWeights[doc.words[i]] / doc.words.size()), (Float)0);
				}
				static_cast<const DerivedClass*>(this)->template updateStateWithDoc<_Infer>(*selectedG, ld, rgs, doc, i);
			}
			doc.updateSumWordWeight(this->realV);
		}
732
+
733
+ std::vector<uint64_t> _getTopicsCount() const
734
+ {
735
+ std::vector<uint64_t> cnt(K);
736
+ for (auto& doc : this->docs)
737
+ {
738
+ for (size_t i = 0; i < doc.Zs.size(); ++i)
739
+ {
740
+ if (doc.words[i] < this->realV) ++cnt[doc.Zs[i]];
741
+ }
742
+ }
743
+ return cnt;
744
+ }
745
+
746
+ std::vector<Float> _getWidsByTopic(size_t tid) const
747
+ {
748
+ assert(tid < this->globalState.numByTopic.rows());
749
+ const size_t V = this->realV;
750
+ std::vector<Float> ret(V);
751
+ Float sum = this->globalState.numByTopic[tid] + V * eta;
752
+ auto r = this->globalState.numByTopicWord.row(tid);
753
+ for (size_t v = 0; v < V; ++v)
754
+ {
755
+ ret[v] = (r[v] + eta) / sum;
756
+ }
757
+ return ret;
758
+ }
759
+
760
		// Infers topic distributions for unseen documents.
		// _Together: all docs are sampled jointly against one temporary state and
		// a single combined log-likelihood is returned; otherwise each doc is
		// processed independently (serially when the model needs shared state,
		// else one pool task per document) and per-doc log-likelihoods returned.
		// NOTE(review): `tolerance` is currently unused — no early stopping.
		template<bool _Together, ParallelScheme _ps, typename _Iter>
		std::vector<double> _infer(_Iter docFirst, _Iter docLast, size_t maxIter, Float tolerance, size_t numWorkers) const
		{
			decltype(static_cast<const DerivedClass*>(this)->makeGeneratorForInit(nullptr)) generator;
			if (!(m_flags & flags::generator_by_doc))
			{
				generator = static_cast<const DerivedClass*>(this)->makeGeneratorForInit(nullptr);
			}

			if (_Together)
			{
				numWorkers = std::min(numWorkers, this->maxThreads[(size_t)_ps]);
				ThreadPool pool{ numWorkers };
				// temporary state variable
				_RandGen rgc{};
				auto tmpState = this->globalState, tState = this->globalState;
				for (auto d = docFirst; d != docLast; ++d)
				{
					// docId -1: inference docs get no slot in numByTopicDoc
					initializeDocState<true>(*d, -1, generator, tmpState, rgc);
				}

				// shared_state models sample directly on tmpState, so no copies
				std::vector<decltype(tmpState)> localData((m_flags & flags::shared_state) ? 0 : pool.getNumWorkers(), tmpState);
				std::vector<_RandGen> rgs;
				for (size_t i = 0; i < pool.getNumWorkers(); ++i) rgs.emplace_back(rgc());

				ExtraDocData edd;
				if (_ps == ParallelScheme::partition)
				{
					updatePartition(pool, tmpState, localData.data(), docFirst, docLast, edd);
				}

				for (size_t i = 0; i < maxIter; ++i)
				{
					std::vector<std::future<void>> res;
					performSampling<_ps, true>(pool,
						(m_flags & flags::shared_state) ? &tmpState : localData.data(), rgs.data(), res,
						docFirst, docLast, edd);
					static_cast<const DerivedClass*>(this)->template mergeState<_ps>(pool, tmpState, tState, localData.data(), rgs.data(), edd);
					static_cast<const DerivedClass*>(this)->template sampleGlobalLevel<>(
						&pool, (m_flags & flags::shared_state) ? &tmpState : localData.data(), rgs.data(), docFirst, docLast);
				}
				// report only the likelihood contributed by the new documents
				double ll = static_cast<const DerivedClass*>(this)->getLLRest(tmpState) - static_cast<const DerivedClass*>(this)->getLLRest(this->globalState);
				ll += static_cast<const DerivedClass*>(this)->template getLLDocs<>(docFirst, docLast);
				return { ll };
			}
			else if (m_flags & flags::shared_state)
			{
				// shared-state models cannot run per-doc tasks in parallel: serial loop
				ThreadPool pool{ numWorkers };
				ExtraDocData edd;
				std::vector<double> ret;
				const double gllRest = static_cast<const DerivedClass*>(this)->getLLRest(this->globalState);
				for (auto d = docFirst; d != docLast; ++d)
				{
					_RandGen rgc{};
					auto tmpState = this->globalState;
					initializeDocState<true>(*d, -1, generator, tmpState, rgc);
					for (size_t i = 0; i < maxIter; ++i)
					{
						static_cast<const DerivedClass*>(this)->presampleDocument(*d, -1, tmpState, rgc, i);
						static_cast<const DerivedClass*>(this)->template sampleDocument<ParallelScheme::none, true>(*d, edd, -1, tmpState, rgc, i);
						static_cast<const DerivedClass*>(this)->template sampleGlobalLevel<>(
							&pool, &tmpState, &rgc, &*d, &*d + 1);
					}
					double ll = static_cast<const DerivedClass*>(this)->getLLRest(tmpState) - gllRest;
					ll += static_cast<const DerivedClass*>(this)->template getLLDocs<>(&*d, &*d + 1);
					ret.emplace_back(ll);
				}
				return ret;
			}
			else
			{
				// independent documents: one pool task per document on a copied state
				ThreadPool pool{ numWorkers, numWorkers * 8 };
				ExtraDocData edd;
				std::vector<std::future<double>> res;
				const double gllRest = static_cast<const DerivedClass*>(this)->getLLRest(this->globalState);
				for (auto d = docFirst; d != docLast; ++d)
				{
					res.emplace_back(pool.enqueue([&, d](size_t threadId)
					{
						_RandGen rgc{};
						auto tmpState = this->globalState;
						initializeDocState<true>(*d, -1, generator, tmpState, rgc);
						for (size_t i = 0; i < maxIter; ++i)
						{
							static_cast<const DerivedClass*>(this)->presampleDocument(*d, -1, tmpState, rgc, i);
							static_cast<const DerivedClass*>(this)->template sampleDocument<ParallelScheme::none, true>(
								*d, edd, -1, tmpState, rgc, i
							);
							// pool is nullptr: already inside a worker, no further branching
							static_cast<const DerivedClass*>(this)->template sampleGlobalLevel<>(
								nullptr, &tmpState, &rgc, &*d, &*d + 1
							);
						}
						double ll = static_cast<const DerivedClass*>(this)->getLLRest(tmpState) - gllRest;
						ll += static_cast<const DerivedClass*>(this)->template getLLDocs<>(&*d, &*d + 1);
						return ll;
					}));
				}
				std::vector<double> ret;
				for (auto& r : res) ret.emplace_back(r.get());
				return ret;
			}
		}
862
+
863
	public:
		// version 0: legacy serialized layout
		DEFINE_SERIALIZER_WITH_VERSION(0, vocabWeights, alpha, alphas, eta, K);

		// version 1 (tag 0x00010001): tagged format adding word priors and the
		// training schedule fields
		DEFINE_TAGGED_SERIALIZER_WITH_VERSION(1, 0x00010001, vocabWeights, alpha, alphas, eta, K, etaByWord,
			burnIn, optimInterval);
868
+
869
+ LDAModel(size_t _K = 1, Float _alpha = 0.1, Float _eta = 0.01, size_t _rg = std::random_device{}())
870
+ : BaseClass(_rg), K(_K), alpha(_alpha), eta(_eta)
871
+ {
872
+ if (_K == 0 || _K >= 0x80000000) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong K value (K = %zd)", _K));
873
+ if (_alpha <= 0) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong alpha value (alpha = %f)", _alpha));
874
+ if (_eta <= 0) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong eta value (eta = %f)", _eta));
875
+ alphas = Eigen::Matrix<Float, -1, 1>::Constant(K, alpha);
876
+ }
877
+
878
		// simple read accessors generated by the project GETTER macro
		GETTER(K, size_t, K);
		GETTER(Alpha, Float, alpha);
		GETTER(Eta, Float, eta);
		GETTER(OptimInterval, size_t, optimInterval);
		GETTER(BurnInIteration, size_t, burnIn);

		// per-topic alpha; may differ from the scalar `alpha` after optimization
		Float getAlpha(size_t k1) const override { return alphas[k1]; }

		TermWeight getTermWeight() const override
		{
			return _tw;
		}

		// how often (in iterations) optimizeParameters runs; 0 disables it
		void setOptimInterval(size_t _optimInterval) override
		{
			optimInterval = _optimInterval;
		}

		// iterations to skip before hyperparameter optimization starts
		void setBurnInIteration(size_t iteration) override
		{
			burnIn = iteration;
		}
900
+
901
		// addDoc variants store the document in the model for training and
		// return its index; makeDoc variants build a standalone document for
		// later inference without touching the training set.
		size_t addDoc(const std::vector<std::string>& words) override
		{
			return this->_addDoc(this->_makeDoc(words));
		}

		std::unique_ptr<DocumentBase> makeDoc(const std::vector<std::string>& words) const override
		{
			return make_unique<_DocType>(as_mutable(this)->template _makeDoc<true>(words));
		}

		// raw-string overloads run `tokenizer` over rawStr first
		size_t addDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer) override
		{
			return this->_addDoc(this->template _makeRawDoc<false>(rawStr, tokenizer));
		}

		std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer) const override
		{
			return make_unique<_DocType>(as_mutable(this)->template _makeRawDoc<true>(rawStr, tokenizer));
		}

		// pre-tokenized overloads take word ids plus each token's position and
		// length within rawStr
		size_t addDoc(const std::string& rawStr, const std::vector<Vid>& words,
			const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len) override
		{
			return this->_addDoc(this->_makeRawDoc(rawStr, words, pos, len));
		}

		std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const std::vector<Vid>& words,
			const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len) const override
		{
			return make_unique<_DocType>(this->_makeRawDoc(rawStr, words, pos, len));
		}
932
+
933
+ void setWordPrior(const std::string& word, const std::vector<Float>& priors) override
934
+ {
935
+ if (priors.size() != K) THROW_ERROR_WITH_INFO(exception::InvalidArgument, "priors.size() must be equal to K.");
936
+ for (auto p : priors)
937
+ {
938
+ if (p < 0) THROW_ERROR_WITH_INFO(exception::InvalidArgument, "priors must not be less than 0.");
939
+ }
940
+ this->dict.add(word);
941
+ etaByWord.emplace(word, priors);
942
+ }
943
+
944
+ std::vector<Float> getWordPrior(const std::string& word) const override
945
+ {
946
+ if (etaByTopicWord.size())
947
+ {
948
+ auto id = this->dict.toWid(word);
949
+ if (id == (Vid)-1) return {};
950
+ auto col = etaByTopicWord.col(id);
951
+ return std::vector<Float>{ col.data(), col.data() + col.size() };
952
+ }
953
+ else
954
+ {
955
+ auto it = etaByWord.find(word);
956
+ if (it == etaByWord.end()) return {};
957
+ return it->second;
958
+ }
959
+ }
960
+
961
+ void updateDocs()
962
+ {
963
+ size_t docId = 0;
964
+ for (auto& doc : this->docs)
965
+ {
966
+ doc.template update<>(getTopicDocPtr(docId++), *static_cast<DerivedClass*>(this));
967
+ }
968
+ }
969
+
970
		// Finalizes the model before training: prunes the vocabulary (when
		// initDocs), allocates global state, resolves word priors, computes
		// per-word term weights, and draws the initial topic assignments.
		void prepare(bool initDocs = true, size_t minWordCnt = 0, size_t minWordDf = 0, size_t removeTopN = 0) override
		{
			if (initDocs) this->removeStopwords(minWordCnt, minWordDf, removeTopN);
			static_cast<DerivedClass*>(this)->updateWeakArray();
			static_cast<DerivedClass*>(this)->initGlobalState(initDocs);
			static_cast<DerivedClass*>(this)->prepareWordPriors();

			const size_t V = this->realV;

			if (initDocs)
			{
				std::vector<uint32_t> df, cf, tf;
				uint32_t totCf;

				// calculate weighting
				if (_tw != TermWeight::one)
				{
					df.resize(V);
					tf.resize(V);
					// document frequency: each word counted once per document
					for (auto& doc : this->docs)
					{
						for (auto w : std::unordered_set<Vid>{ doc.words.begin(), doc.words.end() })
						{
							if (w >= this->realV) continue;
							++df[w];
						}
					}
					totCf = accumulate(this->vocabCf.begin(), this->vocabCf.end(), 0);
				}
				if (_tw == TermWeight::idf)
				{
					vocabWeights.resize(V);
					for (size_t i = 0; i < V; ++i)
					{
						vocabWeights[i] = log(this->docs.size() / (Float)df[i]);
					}
				}
				else if (_tw == TermWeight::pmi)
				{
					// store corpus-level relative frequency; PMI itself is
					// computed per document in initializeDocState
					vocabWeights.resize(V);
					for (size_t i = 0; i < V; ++i)
					{
						vocabWeights[i] = this->vocabCf[i] / (float)totCf;
					}
				}

				decltype(static_cast<DerivedClass*>(this)->makeGeneratorForInit(nullptr)) generator;
				if(!(m_flags & flags::generator_by_doc)) generator = static_cast<DerivedClass*>(this)->makeGeneratorForInit(nullptr);
				for (auto& doc : this->docs)
				{
					initializeDocState<false>(doc, &doc - &this->docs[0], generator, this->globalState, this->rg);
				}
			}
			else
			{
				// docs were loaded with assignments; only re-bind their storage
				static_cast<DerivedClass*>(this)->updateDocs();
				for (auto& doc : this->docs) doc.updateSumWordWeight(this->realV);
			}
			static_cast<DerivedClass*>(this)->prepareShared();
			BaseClass::prepare(initDocs, minWordCnt, minWordDf, removeTopN);
		}
1031
+
1032
+ std::vector<uint64_t> getCountByTopic() const override
1033
+ {
1034
+ return static_cast<const DerivedClass*>(this)->_getTopicsCount();
1035
+ }
1036
+
1037
		// Posterior mean topic distribution of `doc`: counts smoothed by alphas
		// and normalized by total weight plus the alpha sum.
		std::vector<Float> getTopicsByDoc(const _DocType& doc) const
		{
			std::vector<Float> ret(K);
			// write straight into the return buffer through an Eigen map
			Eigen::Map<Eigen::Matrix<Float, -1, 1>> { ret.data(), K }.array() =
				(doc.numByTopic.array().template cast<Float>() + alphas.array()) / (doc.getSumWordWeight() + alphas.sum());
			return ret;
		}
1044
+
1045
+ };
1046
+
1047
	// Out-of-class definition: re-binds numByTopic to external storage (`ptr`,
	// may be nullptr for self-owned storage) and recomputes the per-topic counts
	// from the stored assignments Zs.
	template<TermWeight _tw>
	template<typename _TopicModel>
	void DocumentLDA<_tw>::update(WeightType* ptr, const _TopicModel& mdl)
	{
		numByTopic.init(ptr, mdl.getK());
		for (size_t i = 0; i < Zs.size(); ++i)
		{
			// skip words outside the model's effective vocabulary
			if (this->words[i] >= mdl.getV()) continue;
			numByTopic[Zs[i]] += _tw != TermWeight::one ? wordWeights[i] : 1;
		}
	}
1058
+ }