tomoto 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (420)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +3 -0
  3. data/LICENSE.txt +22 -0
  4. data/README.md +123 -0
  5. data/ext/tomoto/ext.cpp +245 -0
  6. data/ext/tomoto/extconf.rb +28 -0
  7. data/lib/tomoto.rb +12 -0
  8. data/lib/tomoto/ct.rb +11 -0
  9. data/lib/tomoto/hdp.rb +11 -0
  10. data/lib/tomoto/lda.rb +67 -0
  11. data/lib/tomoto/version.rb +3 -0
  12. data/vendor/EigenRand/EigenRand/Core.h +1139 -0
  13. data/vendor/EigenRand/EigenRand/Dists/Basic.h +111 -0
  14. data/vendor/EigenRand/EigenRand/Dists/Discrete.h +877 -0
  15. data/vendor/EigenRand/EigenRand/Dists/GammaPoisson.h +108 -0
  16. data/vendor/EigenRand/EigenRand/Dists/NormalExp.h +626 -0
  17. data/vendor/EigenRand/EigenRand/EigenRand +19 -0
  18. data/vendor/EigenRand/EigenRand/Macro.h +24 -0
  19. data/vendor/EigenRand/EigenRand/MorePacketMath.h +978 -0
  20. data/vendor/EigenRand/EigenRand/PacketFilter.h +286 -0
  21. data/vendor/EigenRand/EigenRand/PacketRandomEngine.h +624 -0
  22. data/vendor/EigenRand/EigenRand/RandUtils.h +413 -0
  23. data/vendor/EigenRand/EigenRand/doc.h +220 -0
  24. data/vendor/EigenRand/LICENSE +21 -0
  25. data/vendor/EigenRand/README.md +288 -0
  26. data/vendor/eigen/COPYING.BSD +26 -0
  27. data/vendor/eigen/COPYING.GPL +674 -0
  28. data/vendor/eigen/COPYING.LGPL +502 -0
  29. data/vendor/eigen/COPYING.MINPACK +52 -0
  30. data/vendor/eigen/COPYING.MPL2 +373 -0
  31. data/vendor/eigen/COPYING.README +18 -0
  32. data/vendor/eigen/Eigen/CMakeLists.txt +19 -0
  33. data/vendor/eigen/Eigen/Cholesky +46 -0
  34. data/vendor/eigen/Eigen/CholmodSupport +48 -0
  35. data/vendor/eigen/Eigen/Core +537 -0
  36. data/vendor/eigen/Eigen/Dense +7 -0
  37. data/vendor/eigen/Eigen/Eigen +2 -0
  38. data/vendor/eigen/Eigen/Eigenvalues +61 -0
  39. data/vendor/eigen/Eigen/Geometry +62 -0
  40. data/vendor/eigen/Eigen/Householder +30 -0
  41. data/vendor/eigen/Eigen/IterativeLinearSolvers +48 -0
  42. data/vendor/eigen/Eigen/Jacobi +33 -0
  43. data/vendor/eigen/Eigen/LU +50 -0
  44. data/vendor/eigen/Eigen/MetisSupport +35 -0
  45. data/vendor/eigen/Eigen/OrderingMethods +73 -0
  46. data/vendor/eigen/Eigen/PaStiXSupport +48 -0
  47. data/vendor/eigen/Eigen/PardisoSupport +35 -0
  48. data/vendor/eigen/Eigen/QR +51 -0
  49. data/vendor/eigen/Eigen/QtAlignedMalloc +40 -0
  50. data/vendor/eigen/Eigen/SPQRSupport +34 -0
  51. data/vendor/eigen/Eigen/SVD +51 -0
  52. data/vendor/eigen/Eigen/Sparse +36 -0
  53. data/vendor/eigen/Eigen/SparseCholesky +45 -0
  54. data/vendor/eigen/Eigen/SparseCore +69 -0
  55. data/vendor/eigen/Eigen/SparseLU +46 -0
  56. data/vendor/eigen/Eigen/SparseQR +37 -0
  57. data/vendor/eigen/Eigen/StdDeque +27 -0
  58. data/vendor/eigen/Eigen/StdList +26 -0
  59. data/vendor/eigen/Eigen/StdVector +27 -0
  60. data/vendor/eigen/Eigen/SuperLUSupport +64 -0
  61. data/vendor/eigen/Eigen/UmfPackSupport +40 -0
  62. data/vendor/eigen/Eigen/src/Cholesky/LDLT.h +673 -0
  63. data/vendor/eigen/Eigen/src/Cholesky/LLT.h +542 -0
  64. data/vendor/eigen/Eigen/src/Cholesky/LLT_LAPACKE.h +99 -0
  65. data/vendor/eigen/Eigen/src/CholmodSupport/CholmodSupport.h +639 -0
  66. data/vendor/eigen/Eigen/src/Core/Array.h +329 -0
  67. data/vendor/eigen/Eigen/src/Core/ArrayBase.h +226 -0
  68. data/vendor/eigen/Eigen/src/Core/ArrayWrapper.h +209 -0
  69. data/vendor/eigen/Eigen/src/Core/Assign.h +90 -0
  70. data/vendor/eigen/Eigen/src/Core/AssignEvaluator.h +935 -0
  71. data/vendor/eigen/Eigen/src/Core/Assign_MKL.h +178 -0
  72. data/vendor/eigen/Eigen/src/Core/BandMatrix.h +353 -0
  73. data/vendor/eigen/Eigen/src/Core/Block.h +452 -0
  74. data/vendor/eigen/Eigen/src/Core/BooleanRedux.h +164 -0
  75. data/vendor/eigen/Eigen/src/Core/CommaInitializer.h +160 -0
  76. data/vendor/eigen/Eigen/src/Core/ConditionEstimator.h +175 -0
  77. data/vendor/eigen/Eigen/src/Core/CoreEvaluators.h +1688 -0
  78. data/vendor/eigen/Eigen/src/Core/CoreIterators.h +127 -0
  79. data/vendor/eigen/Eigen/src/Core/CwiseBinaryOp.h +184 -0
  80. data/vendor/eigen/Eigen/src/Core/CwiseNullaryOp.h +866 -0
  81. data/vendor/eigen/Eigen/src/Core/CwiseTernaryOp.h +197 -0
  82. data/vendor/eigen/Eigen/src/Core/CwiseUnaryOp.h +103 -0
  83. data/vendor/eigen/Eigen/src/Core/CwiseUnaryView.h +128 -0
  84. data/vendor/eigen/Eigen/src/Core/DenseBase.h +611 -0
  85. data/vendor/eigen/Eigen/src/Core/DenseCoeffsBase.h +681 -0
  86. data/vendor/eigen/Eigen/src/Core/DenseStorage.h +570 -0
  87. data/vendor/eigen/Eigen/src/Core/Diagonal.h +260 -0
  88. data/vendor/eigen/Eigen/src/Core/DiagonalMatrix.h +343 -0
  89. data/vendor/eigen/Eigen/src/Core/DiagonalProduct.h +28 -0
  90. data/vendor/eigen/Eigen/src/Core/Dot.h +318 -0
  91. data/vendor/eigen/Eigen/src/Core/EigenBase.h +159 -0
  92. data/vendor/eigen/Eigen/src/Core/ForceAlignedAccess.h +146 -0
  93. data/vendor/eigen/Eigen/src/Core/Fuzzy.h +155 -0
  94. data/vendor/eigen/Eigen/src/Core/GeneralProduct.h +455 -0
  95. data/vendor/eigen/Eigen/src/Core/GenericPacketMath.h +593 -0
  96. data/vendor/eigen/Eigen/src/Core/GlobalFunctions.h +187 -0
  97. data/vendor/eigen/Eigen/src/Core/IO.h +225 -0
  98. data/vendor/eigen/Eigen/src/Core/Inverse.h +118 -0
  99. data/vendor/eigen/Eigen/src/Core/Map.h +171 -0
  100. data/vendor/eigen/Eigen/src/Core/MapBase.h +303 -0
  101. data/vendor/eigen/Eigen/src/Core/MathFunctions.h +1415 -0
  102. data/vendor/eigen/Eigen/src/Core/MathFunctionsImpl.h +101 -0
  103. data/vendor/eigen/Eigen/src/Core/Matrix.h +459 -0
  104. data/vendor/eigen/Eigen/src/Core/MatrixBase.h +529 -0
  105. data/vendor/eigen/Eigen/src/Core/NestByValue.h +110 -0
  106. data/vendor/eigen/Eigen/src/Core/NoAlias.h +108 -0
  107. data/vendor/eigen/Eigen/src/Core/NumTraits.h +248 -0
  108. data/vendor/eigen/Eigen/src/Core/PermutationMatrix.h +633 -0
  109. data/vendor/eigen/Eigen/src/Core/PlainObjectBase.h +1035 -0
  110. data/vendor/eigen/Eigen/src/Core/Product.h +186 -0
  111. data/vendor/eigen/Eigen/src/Core/ProductEvaluators.h +1112 -0
  112. data/vendor/eigen/Eigen/src/Core/Random.h +182 -0
  113. data/vendor/eigen/Eigen/src/Core/Redux.h +505 -0
  114. data/vendor/eigen/Eigen/src/Core/Ref.h +283 -0
  115. data/vendor/eigen/Eigen/src/Core/Replicate.h +142 -0
  116. data/vendor/eigen/Eigen/src/Core/ReturnByValue.h +117 -0
  117. data/vendor/eigen/Eigen/src/Core/Reverse.h +211 -0
  118. data/vendor/eigen/Eigen/src/Core/Select.h +162 -0
  119. data/vendor/eigen/Eigen/src/Core/SelfAdjointView.h +352 -0
  120. data/vendor/eigen/Eigen/src/Core/SelfCwiseBinaryOp.h +47 -0
  121. data/vendor/eigen/Eigen/src/Core/Solve.h +188 -0
  122. data/vendor/eigen/Eigen/src/Core/SolveTriangular.h +235 -0
  123. data/vendor/eigen/Eigen/src/Core/SolverBase.h +130 -0
  124. data/vendor/eigen/Eigen/src/Core/StableNorm.h +221 -0
  125. data/vendor/eigen/Eigen/src/Core/Stride.h +111 -0
  126. data/vendor/eigen/Eigen/src/Core/Swap.h +67 -0
  127. data/vendor/eigen/Eigen/src/Core/Transpose.h +403 -0
  128. data/vendor/eigen/Eigen/src/Core/Transpositions.h +407 -0
  129. data/vendor/eigen/Eigen/src/Core/TriangularMatrix.h +983 -0
  130. data/vendor/eigen/Eigen/src/Core/VectorBlock.h +96 -0
  131. data/vendor/eigen/Eigen/src/Core/VectorwiseOp.h +695 -0
  132. data/vendor/eigen/Eigen/src/Core/Visitor.h +273 -0
  133. data/vendor/eigen/Eigen/src/Core/arch/AVX/Complex.h +451 -0
  134. data/vendor/eigen/Eigen/src/Core/arch/AVX/MathFunctions.h +439 -0
  135. data/vendor/eigen/Eigen/src/Core/arch/AVX/PacketMath.h +637 -0
  136. data/vendor/eigen/Eigen/src/Core/arch/AVX/TypeCasting.h +51 -0
  137. data/vendor/eigen/Eigen/src/Core/arch/AVX512/MathFunctions.h +391 -0
  138. data/vendor/eigen/Eigen/src/Core/arch/AVX512/PacketMath.h +1316 -0
  139. data/vendor/eigen/Eigen/src/Core/arch/AltiVec/Complex.h +430 -0
  140. data/vendor/eigen/Eigen/src/Core/arch/AltiVec/MathFunctions.h +322 -0
  141. data/vendor/eigen/Eigen/src/Core/arch/AltiVec/PacketMath.h +1061 -0
  142. data/vendor/eigen/Eigen/src/Core/arch/CUDA/Complex.h +103 -0
  143. data/vendor/eigen/Eigen/src/Core/arch/CUDA/Half.h +674 -0
  144. data/vendor/eigen/Eigen/src/Core/arch/CUDA/MathFunctions.h +91 -0
  145. data/vendor/eigen/Eigen/src/Core/arch/CUDA/PacketMath.h +333 -0
  146. data/vendor/eigen/Eigen/src/Core/arch/CUDA/PacketMathHalf.h +1124 -0
  147. data/vendor/eigen/Eigen/src/Core/arch/CUDA/TypeCasting.h +212 -0
  148. data/vendor/eigen/Eigen/src/Core/arch/Default/ConjHelper.h +29 -0
  149. data/vendor/eigen/Eigen/src/Core/arch/Default/Settings.h +49 -0
  150. data/vendor/eigen/Eigen/src/Core/arch/NEON/Complex.h +490 -0
  151. data/vendor/eigen/Eigen/src/Core/arch/NEON/MathFunctions.h +91 -0
  152. data/vendor/eigen/Eigen/src/Core/arch/NEON/PacketMath.h +760 -0
  153. data/vendor/eigen/Eigen/src/Core/arch/SSE/Complex.h +471 -0
  154. data/vendor/eigen/Eigen/src/Core/arch/SSE/MathFunctions.h +562 -0
  155. data/vendor/eigen/Eigen/src/Core/arch/SSE/PacketMath.h +895 -0
  156. data/vendor/eigen/Eigen/src/Core/arch/SSE/TypeCasting.h +77 -0
  157. data/vendor/eigen/Eigen/src/Core/arch/ZVector/Complex.h +397 -0
  158. data/vendor/eigen/Eigen/src/Core/arch/ZVector/MathFunctions.h +137 -0
  159. data/vendor/eigen/Eigen/src/Core/arch/ZVector/PacketMath.h +945 -0
  160. data/vendor/eigen/Eigen/src/Core/functors/AssignmentFunctors.h +168 -0
  161. data/vendor/eigen/Eigen/src/Core/functors/BinaryFunctors.h +475 -0
  162. data/vendor/eigen/Eigen/src/Core/functors/NullaryFunctors.h +188 -0
  163. data/vendor/eigen/Eigen/src/Core/functors/StlFunctors.h +136 -0
  164. data/vendor/eigen/Eigen/src/Core/functors/TernaryFunctors.h +25 -0
  165. data/vendor/eigen/Eigen/src/Core/functors/UnaryFunctors.h +792 -0
  166. data/vendor/eigen/Eigen/src/Core/products/GeneralBlockPanelKernel.h +2156 -0
  167. data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrix.h +492 -0
  168. data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h +311 -0
  169. data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h +145 -0
  170. data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h +122 -0
  171. data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixVector.h +619 -0
  172. data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixVector_BLAS.h +136 -0
  173. data/vendor/eigen/Eigen/src/Core/products/Parallelizer.h +163 -0
  174. data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix.h +521 -0
  175. data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix_BLAS.h +287 -0
  176. data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixVector.h +260 -0
  177. data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixVector_BLAS.h +118 -0
  178. data/vendor/eigen/Eigen/src/Core/products/SelfadjointProduct.h +133 -0
  179. data/vendor/eigen/Eigen/src/Core/products/SelfadjointRank2Update.h +93 -0
  180. data/vendor/eigen/Eigen/src/Core/products/TriangularMatrixMatrix.h +466 -0
  181. data/vendor/eigen/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h +315 -0
  182. data/vendor/eigen/Eigen/src/Core/products/TriangularMatrixVector.h +350 -0
  183. data/vendor/eigen/Eigen/src/Core/products/TriangularMatrixVector_BLAS.h +255 -0
  184. data/vendor/eigen/Eigen/src/Core/products/TriangularSolverMatrix.h +335 -0
  185. data/vendor/eigen/Eigen/src/Core/products/TriangularSolverMatrix_BLAS.h +163 -0
  186. data/vendor/eigen/Eigen/src/Core/products/TriangularSolverVector.h +145 -0
  187. data/vendor/eigen/Eigen/src/Core/util/BlasUtil.h +398 -0
  188. data/vendor/eigen/Eigen/src/Core/util/Constants.h +547 -0
  189. data/vendor/eigen/Eigen/src/Core/util/DisableStupidWarnings.h +83 -0
  190. data/vendor/eigen/Eigen/src/Core/util/ForwardDeclarations.h +302 -0
  191. data/vendor/eigen/Eigen/src/Core/util/MKL_support.h +130 -0
  192. data/vendor/eigen/Eigen/src/Core/util/Macros.h +1001 -0
  193. data/vendor/eigen/Eigen/src/Core/util/Memory.h +993 -0
  194. data/vendor/eigen/Eigen/src/Core/util/Meta.h +534 -0
  195. data/vendor/eigen/Eigen/src/Core/util/NonMPL2.h +3 -0
  196. data/vendor/eigen/Eigen/src/Core/util/ReenableStupidWarnings.h +27 -0
  197. data/vendor/eigen/Eigen/src/Core/util/StaticAssert.h +218 -0
  198. data/vendor/eigen/Eigen/src/Core/util/XprHelper.h +821 -0
  199. data/vendor/eigen/Eigen/src/Eigenvalues/ComplexEigenSolver.h +346 -0
  200. data/vendor/eigen/Eigen/src/Eigenvalues/ComplexSchur.h +459 -0
  201. data/vendor/eigen/Eigen/src/Eigenvalues/ComplexSchur_LAPACKE.h +91 -0
  202. data/vendor/eigen/Eigen/src/Eigenvalues/EigenSolver.h +622 -0
  203. data/vendor/eigen/Eigen/src/Eigenvalues/GeneralizedEigenSolver.h +418 -0
  204. data/vendor/eigen/Eigen/src/Eigenvalues/GeneralizedSelfAdjointEigenSolver.h +226 -0
  205. data/vendor/eigen/Eigen/src/Eigenvalues/HessenbergDecomposition.h +374 -0
  206. data/vendor/eigen/Eigen/src/Eigenvalues/MatrixBaseEigenvalues.h +158 -0
  207. data/vendor/eigen/Eigen/src/Eigenvalues/RealQZ.h +654 -0
  208. data/vendor/eigen/Eigen/src/Eigenvalues/RealSchur.h +546 -0
  209. data/vendor/eigen/Eigen/src/Eigenvalues/RealSchur_LAPACKE.h +77 -0
  210. data/vendor/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver.h +870 -0
  211. data/vendor/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver_LAPACKE.h +87 -0
  212. data/vendor/eigen/Eigen/src/Eigenvalues/Tridiagonalization.h +556 -0
  213. data/vendor/eigen/Eigen/src/Geometry/AlignedBox.h +392 -0
  214. data/vendor/eigen/Eigen/src/Geometry/AngleAxis.h +247 -0
  215. data/vendor/eigen/Eigen/src/Geometry/EulerAngles.h +114 -0
  216. data/vendor/eigen/Eigen/src/Geometry/Homogeneous.h +497 -0
  217. data/vendor/eigen/Eigen/src/Geometry/Hyperplane.h +282 -0
  218. data/vendor/eigen/Eigen/src/Geometry/OrthoMethods.h +234 -0
  219. data/vendor/eigen/Eigen/src/Geometry/ParametrizedLine.h +195 -0
  220. data/vendor/eigen/Eigen/src/Geometry/Quaternion.h +814 -0
  221. data/vendor/eigen/Eigen/src/Geometry/Rotation2D.h +199 -0
  222. data/vendor/eigen/Eigen/src/Geometry/RotationBase.h +206 -0
  223. data/vendor/eigen/Eigen/src/Geometry/Scaling.h +170 -0
  224. data/vendor/eigen/Eigen/src/Geometry/Transform.h +1542 -0
  225. data/vendor/eigen/Eigen/src/Geometry/Translation.h +208 -0
  226. data/vendor/eigen/Eigen/src/Geometry/Umeyama.h +166 -0
  227. data/vendor/eigen/Eigen/src/Geometry/arch/Geometry_SSE.h +161 -0
  228. data/vendor/eigen/Eigen/src/Householder/BlockHouseholder.h +103 -0
  229. data/vendor/eigen/Eigen/src/Householder/Householder.h +172 -0
  230. data/vendor/eigen/Eigen/src/Householder/HouseholderSequence.h +470 -0
  231. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/BasicPreconditioners.h +226 -0
  232. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h +228 -0
  233. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/ConjugateGradient.h +246 -0
  234. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/IncompleteCholesky.h +400 -0
  235. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/IncompleteLUT.h +462 -0
  236. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h +394 -0
  237. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/LeastSquareConjugateGradient.h +216 -0
  238. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/SolveWithGuess.h +115 -0
  239. data/vendor/eigen/Eigen/src/Jacobi/Jacobi.h +462 -0
  240. data/vendor/eigen/Eigen/src/LU/Determinant.h +101 -0
  241. data/vendor/eigen/Eigen/src/LU/FullPivLU.h +891 -0
  242. data/vendor/eigen/Eigen/src/LU/InverseImpl.h +415 -0
  243. data/vendor/eigen/Eigen/src/LU/PartialPivLU.h +611 -0
  244. data/vendor/eigen/Eigen/src/LU/PartialPivLU_LAPACKE.h +83 -0
  245. data/vendor/eigen/Eigen/src/LU/arch/Inverse_SSE.h +338 -0
  246. data/vendor/eigen/Eigen/src/MetisSupport/MetisSupport.h +137 -0
  247. data/vendor/eigen/Eigen/src/OrderingMethods/Amd.h +445 -0
  248. data/vendor/eigen/Eigen/src/OrderingMethods/Eigen_Colamd.h +1843 -0
  249. data/vendor/eigen/Eigen/src/OrderingMethods/Ordering.h +157 -0
  250. data/vendor/eigen/Eigen/src/PaStiXSupport/PaStiXSupport.h +678 -0
  251. data/vendor/eigen/Eigen/src/PardisoSupport/PardisoSupport.h +543 -0
  252. data/vendor/eigen/Eigen/src/QR/ColPivHouseholderQR.h +653 -0
  253. data/vendor/eigen/Eigen/src/QR/ColPivHouseholderQR_LAPACKE.h +97 -0
  254. data/vendor/eigen/Eigen/src/QR/CompleteOrthogonalDecomposition.h +562 -0
  255. data/vendor/eigen/Eigen/src/QR/FullPivHouseholderQR.h +676 -0
  256. data/vendor/eigen/Eigen/src/QR/HouseholderQR.h +409 -0
  257. data/vendor/eigen/Eigen/src/QR/HouseholderQR_LAPACKE.h +68 -0
  258. data/vendor/eigen/Eigen/src/SPQRSupport/SuiteSparseQRSupport.h +313 -0
  259. data/vendor/eigen/Eigen/src/SVD/BDCSVD.h +1246 -0
  260. data/vendor/eigen/Eigen/src/SVD/JacobiSVD.h +804 -0
  261. data/vendor/eigen/Eigen/src/SVD/JacobiSVD_LAPACKE.h +91 -0
  262. data/vendor/eigen/Eigen/src/SVD/SVDBase.h +315 -0
  263. data/vendor/eigen/Eigen/src/SVD/UpperBidiagonalization.h +414 -0
  264. data/vendor/eigen/Eigen/src/SparseCholesky/SimplicialCholesky.h +689 -0
  265. data/vendor/eigen/Eigen/src/SparseCholesky/SimplicialCholesky_impl.h +199 -0
  266. data/vendor/eigen/Eigen/src/SparseCore/AmbiVector.h +377 -0
  267. data/vendor/eigen/Eigen/src/SparseCore/CompressedStorage.h +258 -0
  268. data/vendor/eigen/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h +352 -0
  269. data/vendor/eigen/Eigen/src/SparseCore/MappedSparseMatrix.h +67 -0
  270. data/vendor/eigen/Eigen/src/SparseCore/SparseAssign.h +216 -0
  271. data/vendor/eigen/Eigen/src/SparseCore/SparseBlock.h +603 -0
  272. data/vendor/eigen/Eigen/src/SparseCore/SparseColEtree.h +206 -0
  273. data/vendor/eigen/Eigen/src/SparseCore/SparseCompressedBase.h +341 -0
  274. data/vendor/eigen/Eigen/src/SparseCore/SparseCwiseBinaryOp.h +726 -0
  275. data/vendor/eigen/Eigen/src/SparseCore/SparseCwiseUnaryOp.h +148 -0
  276. data/vendor/eigen/Eigen/src/SparseCore/SparseDenseProduct.h +320 -0
  277. data/vendor/eigen/Eigen/src/SparseCore/SparseDiagonalProduct.h +138 -0
  278. data/vendor/eigen/Eigen/src/SparseCore/SparseDot.h +98 -0
  279. data/vendor/eigen/Eigen/src/SparseCore/SparseFuzzy.h +29 -0
  280. data/vendor/eigen/Eigen/src/SparseCore/SparseMap.h +305 -0
  281. data/vendor/eigen/Eigen/src/SparseCore/SparseMatrix.h +1403 -0
  282. data/vendor/eigen/Eigen/src/SparseCore/SparseMatrixBase.h +405 -0
  283. data/vendor/eigen/Eigen/src/SparseCore/SparsePermutation.h +178 -0
  284. data/vendor/eigen/Eigen/src/SparseCore/SparseProduct.h +169 -0
  285. data/vendor/eigen/Eigen/src/SparseCore/SparseRedux.h +49 -0
  286. data/vendor/eigen/Eigen/src/SparseCore/SparseRef.h +397 -0
  287. data/vendor/eigen/Eigen/src/SparseCore/SparseSelfAdjointView.h +656 -0
  288. data/vendor/eigen/Eigen/src/SparseCore/SparseSolverBase.h +124 -0
  289. data/vendor/eigen/Eigen/src/SparseCore/SparseSparseProductWithPruning.h +198 -0
  290. data/vendor/eigen/Eigen/src/SparseCore/SparseTranspose.h +92 -0
  291. data/vendor/eigen/Eigen/src/SparseCore/SparseTriangularView.h +189 -0
  292. data/vendor/eigen/Eigen/src/SparseCore/SparseUtil.h +178 -0
  293. data/vendor/eigen/Eigen/src/SparseCore/SparseVector.h +478 -0
  294. data/vendor/eigen/Eigen/src/SparseCore/SparseView.h +253 -0
  295. data/vendor/eigen/Eigen/src/SparseCore/TriangularSolver.h +315 -0
  296. data/vendor/eigen/Eigen/src/SparseLU/SparseLU.h +773 -0
  297. data/vendor/eigen/Eigen/src/SparseLU/SparseLUImpl.h +66 -0
  298. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_Memory.h +226 -0
  299. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_Structs.h +110 -0
  300. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_SupernodalMatrix.h +301 -0
  301. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_Utils.h +80 -0
  302. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_column_bmod.h +181 -0
  303. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_column_dfs.h +179 -0
  304. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_copy_to_ucol.h +107 -0
  305. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_gemm_kernel.h +280 -0
  306. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_heap_relax_snode.h +126 -0
  307. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_kernel_bmod.h +130 -0
  308. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_panel_bmod.h +223 -0
  309. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_panel_dfs.h +258 -0
  310. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_pivotL.h +137 -0
  311. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_pruneL.h +136 -0
  312. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_relax_snode.h +83 -0
  313. data/vendor/eigen/Eigen/src/SparseQR/SparseQR.h +745 -0
  314. data/vendor/eigen/Eigen/src/StlSupport/StdDeque.h +126 -0
  315. data/vendor/eigen/Eigen/src/StlSupport/StdList.h +106 -0
  316. data/vendor/eigen/Eigen/src/StlSupport/StdVector.h +131 -0
  317. data/vendor/eigen/Eigen/src/StlSupport/details.h +84 -0
  318. data/vendor/eigen/Eigen/src/SuperLUSupport/SuperLUSupport.h +1027 -0
  319. data/vendor/eigen/Eigen/src/UmfPackSupport/UmfPackSupport.h +506 -0
  320. data/vendor/eigen/Eigen/src/misc/Image.h +82 -0
  321. data/vendor/eigen/Eigen/src/misc/Kernel.h +79 -0
  322. data/vendor/eigen/Eigen/src/misc/RealSvd2x2.h +55 -0
  323. data/vendor/eigen/Eigen/src/misc/blas.h +440 -0
  324. data/vendor/eigen/Eigen/src/misc/lapack.h +152 -0
  325. data/vendor/eigen/Eigen/src/misc/lapacke.h +16291 -0
  326. data/vendor/eigen/Eigen/src/misc/lapacke_mangling.h +17 -0
  327. data/vendor/eigen/Eigen/src/plugins/ArrayCwiseBinaryOps.h +332 -0
  328. data/vendor/eigen/Eigen/src/plugins/ArrayCwiseUnaryOps.h +552 -0
  329. data/vendor/eigen/Eigen/src/plugins/BlockMethods.h +1058 -0
  330. data/vendor/eigen/Eigen/src/plugins/CommonCwiseBinaryOps.h +115 -0
  331. data/vendor/eigen/Eigen/src/plugins/CommonCwiseUnaryOps.h +163 -0
  332. data/vendor/eigen/Eigen/src/plugins/MatrixCwiseBinaryOps.h +152 -0
  333. data/vendor/eigen/Eigen/src/plugins/MatrixCwiseUnaryOps.h +85 -0
  334. data/vendor/eigen/README.md +3 -0
  335. data/vendor/eigen/bench/README.txt +55 -0
  336. data/vendor/eigen/bench/btl/COPYING +340 -0
  337. data/vendor/eigen/bench/btl/README +154 -0
  338. data/vendor/eigen/bench/tensors/README +21 -0
  339. data/vendor/eigen/blas/README.txt +6 -0
  340. data/vendor/eigen/demos/mandelbrot/README +10 -0
  341. data/vendor/eigen/demos/mix_eigen_and_c/README +9 -0
  342. data/vendor/eigen/demos/opengl/README +13 -0
  343. data/vendor/eigen/unsupported/Eigen/CXX11/src/Tensor/README.md +1760 -0
  344. data/vendor/eigen/unsupported/README.txt +50 -0
  345. data/vendor/tomotopy/LICENSE +21 -0
  346. data/vendor/tomotopy/README.kr.rst +375 -0
  347. data/vendor/tomotopy/README.rst +382 -0
  348. data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +362 -0
  349. data/vendor/tomotopy/src/Labeling/FoRelevance.h +88 -0
  350. data/vendor/tomotopy/src/Labeling/Labeler.h +50 -0
  351. data/vendor/tomotopy/src/TopicModel/CT.h +37 -0
  352. data/vendor/tomotopy/src/TopicModel/CTModel.cpp +13 -0
  353. data/vendor/tomotopy/src/TopicModel/CTModel.hpp +293 -0
  354. data/vendor/tomotopy/src/TopicModel/DMR.h +51 -0
  355. data/vendor/tomotopy/src/TopicModel/DMRModel.cpp +13 -0
  356. data/vendor/tomotopy/src/TopicModel/DMRModel.hpp +374 -0
  357. data/vendor/tomotopy/src/TopicModel/DT.h +65 -0
  358. data/vendor/tomotopy/src/TopicModel/DTM.h +22 -0
  359. data/vendor/tomotopy/src/TopicModel/DTModel.cpp +15 -0
  360. data/vendor/tomotopy/src/TopicModel/DTModel.hpp +572 -0
  361. data/vendor/tomotopy/src/TopicModel/GDMR.h +37 -0
  362. data/vendor/tomotopy/src/TopicModel/GDMRModel.cpp +14 -0
  363. data/vendor/tomotopy/src/TopicModel/GDMRModel.hpp +485 -0
  364. data/vendor/tomotopy/src/TopicModel/HDP.h +74 -0
  365. data/vendor/tomotopy/src/TopicModel/HDPModel.cpp +13 -0
  366. data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +592 -0
  367. data/vendor/tomotopy/src/TopicModel/HLDA.h +40 -0
  368. data/vendor/tomotopy/src/TopicModel/HLDAModel.cpp +13 -0
  369. data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +681 -0
  370. data/vendor/tomotopy/src/TopicModel/HPA.h +27 -0
  371. data/vendor/tomotopy/src/TopicModel/HPAModel.cpp +21 -0
  372. data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +588 -0
  373. data/vendor/tomotopy/src/TopicModel/LDA.h +144 -0
  374. data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +442 -0
  375. data/vendor/tomotopy/src/TopicModel/LDAModel.cpp +13 -0
  376. data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +1058 -0
  377. data/vendor/tomotopy/src/TopicModel/LLDA.h +45 -0
  378. data/vendor/tomotopy/src/TopicModel/LLDAModel.cpp +13 -0
  379. data/vendor/tomotopy/src/TopicModel/LLDAModel.hpp +203 -0
  380. data/vendor/tomotopy/src/TopicModel/MGLDA.h +63 -0
  381. data/vendor/tomotopy/src/TopicModel/MGLDAModel.cpp +17 -0
  382. data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +558 -0
  383. data/vendor/tomotopy/src/TopicModel/PA.h +43 -0
  384. data/vendor/tomotopy/src/TopicModel/PAModel.cpp +13 -0
  385. data/vendor/tomotopy/src/TopicModel/PAModel.hpp +467 -0
  386. data/vendor/tomotopy/src/TopicModel/PLDA.h +17 -0
  387. data/vendor/tomotopy/src/TopicModel/PLDAModel.cpp +13 -0
  388. data/vendor/tomotopy/src/TopicModel/PLDAModel.hpp +214 -0
  389. data/vendor/tomotopy/src/TopicModel/SLDA.h +54 -0
  390. data/vendor/tomotopy/src/TopicModel/SLDAModel.cpp +17 -0
  391. data/vendor/tomotopy/src/TopicModel/SLDAModel.hpp +456 -0
  392. data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +692 -0
  393. data/vendor/tomotopy/src/Utils/AliasMethod.hpp +169 -0
  394. data/vendor/tomotopy/src/Utils/Dictionary.h +80 -0
  395. data/vendor/tomotopy/src/Utils/EigenAddonOps.hpp +181 -0
  396. data/vendor/tomotopy/src/Utils/LBFGS.h +202 -0
  397. data/vendor/tomotopy/src/Utils/LBFGS/LineSearchBacktracking.h +120 -0
  398. data/vendor/tomotopy/src/Utils/LBFGS/LineSearchBracketing.h +122 -0
  399. data/vendor/tomotopy/src/Utils/LBFGS/Param.h +213 -0
  400. data/vendor/tomotopy/src/Utils/LUT.hpp +82 -0
  401. data/vendor/tomotopy/src/Utils/MultiNormalDistribution.hpp +69 -0
  402. data/vendor/tomotopy/src/Utils/PolyaGamma.hpp +200 -0
  403. data/vendor/tomotopy/src/Utils/PolyaGammaHybrid.hpp +672 -0
  404. data/vendor/tomotopy/src/Utils/ThreadPool.hpp +150 -0
  405. data/vendor/tomotopy/src/Utils/Trie.hpp +220 -0
  406. data/vendor/tomotopy/src/Utils/TruncMultiNormal.hpp +94 -0
  407. data/vendor/tomotopy/src/Utils/Utils.hpp +337 -0
  408. data/vendor/tomotopy/src/Utils/avx_gamma.h +46 -0
  409. data/vendor/tomotopy/src/Utils/avx_mathfun.h +736 -0
  410. data/vendor/tomotopy/src/Utils/exception.h +28 -0
  411. data/vendor/tomotopy/src/Utils/math.h +281 -0
  412. data/vendor/tomotopy/src/Utils/rtnorm.hpp +2690 -0
  413. data/vendor/tomotopy/src/Utils/sample.hpp +192 -0
  414. data/vendor/tomotopy/src/Utils/serializer.hpp +695 -0
  415. data/vendor/tomotopy/src/Utils/slp.hpp +131 -0
  416. data/vendor/tomotopy/src/Utils/sse_gamma.h +48 -0
  417. data/vendor/tomotopy/src/Utils/sse_mathfun.h +710 -0
  418. data/vendor/tomotopy/src/Utils/text.hpp +49 -0
  419. data/vendor/tomotopy/src/Utils/tvector.hpp +543 -0
  420. metadata +531 -0
@@ -0,0 +1,40 @@
1
+ #pragma once
2
+ #include "LDA.h"
3
+
4
+ namespace tomoto
5
+ {
6
+ template<TermWeight _tw>
7
+ struct DocumentHLDA : public DocumentLDA<_tw>
8
+ {
9
+ using BaseDocument = DocumentLDA<_tw>;
10
+ using WeightType = typename DocumentLDA<_tw>::WeightType;
11
+ using DocumentLDA<_tw>::DocumentLDA;
12
+
13
+ // numByTopic indicates numByLevel in HLDAModel.
14
+ // Zs indicates level in HLDAModel.
15
+ std::vector<int32_t> path;
16
+
17
+ template<typename _TopicModel> void update(WeightType* ptr, const _TopicModel& mdl);
18
+
19
+ DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 0, path);
20
+ DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 1, 0x00010001, path);
21
+ };
22
+
23
+ class IHLDAModel : public ILDAModel
24
+ {
25
+ public:
26
+ using DefaultDocType = DocumentHLDA<TermWeight::one>;
27
+ static IHLDAModel* create(TermWeight _weight, size_t levelDepth = 1,
28
+ Float alpha = 0.1, Float eta = 0.01, Float gamma = 0.1, size_t seed = std::random_device{}(),
29
+ bool scalarRng = false);
30
+
31
+ virtual Float getGamma() const = 0;
32
+ virtual size_t getLiveK() const = 0;
33
+ virtual size_t getLevelDepth() const = 0;
34
+ virtual bool isLiveTopic(Tid tid) const = 0;
35
+ virtual size_t getNumDocsOfTopic(Tid tid) const = 0;
36
+ virtual size_t getLevelOfTopic(Tid tid) const = 0;
37
+ virtual size_t getParentTopicId(Tid tid) const = 0;
38
+ virtual std::vector<uint32_t> getChildTopicId(Tid tid) const = 0;
39
+ };
40
+ }
@@ -0,0 +1,13 @@
1
+ #include "HLDAModel.hpp"
2
+
3
+ namespace tomoto
4
+ {
5
+ /*template class HLDAModel<TermWeight::one>;
6
+ template class HLDAModel<TermWeight::idf>;
7
+ template class HLDAModel<TermWeight::pmi>;*/
8
+
9
// Factory for IHLDAModel (declared with its default arguments in HLDA.h).
// Hands the term-weighting scheme `_weight`, the `scalarRng` flag, the HLDAModel
// template name and the remaining arguments (levelDepth, alpha, eta, gamma, seed)
// to the TMT_SWITCH_TW macro, which performs the dispatch.
// NOTE(review): TMT_SWITCH_TW is defined elsewhere; presumably it returns a newly
// allocated HLDAModel specialization for the chosen weighting — confirm ownership.
+ IHLDAModel* IHLDAModel::create(TermWeight _weight, size_t levelDepth, Float _alpha, Float _eta, Float _gamma, size_t seed, bool scalarRng)
10
+ {
11
+ TMT_SWITCH_TW(_weight, scalarRng, HLDAModel, levelDepth, _alpha, _eta, _gamma, seed);
12
+ }
13
+ }
@@ -0,0 +1,681 @@
1
+ #pragma once
2
+ #include "LDAModel.hpp"
3
+ #include "HLDA.h"
4
+
5
+ /*
6
+ Implementation of hLDA using Gibbs sampling by bab2min
7
+
8
+ * Griffiths, T. L., Jordan, M. I., Tenenbaum, J. B., & Blei, D. M. (2004). Hierarchical topic models and the nested Chinese restaurant process. In Advances in neural information processing systems (pp. 17-24).
9
+ */
10
+
11
+ namespace tomoto
12
+ {
13
+ namespace detail
14
+ {
15
+ struct NCRPNode
16
+ {
17
+ int32_t numCustomers = 0, level = 0;
18
+ int32_t parent = 0, sibling = 0, child = 0;
19
+
20
+ DEFINE_SERIALIZER(numCustomers, level, parent, sibling, child);
21
+
22
+ NCRPNode* getParent() const
23
+ {
24
+ if (!parent) return nullptr;
25
+ return (NCRPNode*)(this + parent);
26
+ }
27
+
28
+ NCRPNode* getSibling() const
29
+ {
30
+ if (!sibling) return nullptr;
31
+ return (NCRPNode*)(this + sibling);
32
+ }
33
+
34
+ NCRPNode* getChild() const
35
+ {
36
+ if (!child) return nullptr;
37
+ return (NCRPNode*)(this + child);
38
+ }
39
+
40
+ void setSibling(NCRPNode* node)
41
+ {
42
+ sibling = node ? (node - this) : 0;
43
+ }
44
+
45
+ NCRPNode* addChild(NCRPNode* newChild)
46
+ {
47
+ auto* orgChild = getChild();
48
+ child = newChild - this;
49
+ newChild->parent = this - newChild;
50
+ newChild->setSibling(orgChild);
51
+ return newChild;
52
+ }
53
+
54
+ void removeChild(NCRPNode* del)
55
+ {
56
+ NCRPNode* prev = getChild();
57
+ if (prev == del)
58
+ {
59
+ child = del->getSibling() ? del->getSibling() - this : 0;
60
+ return;
61
+ }
62
+
63
+ for (NCRPNode* node = prev->getSibling(); node; node = node->getSibling())
64
+ {
65
+ if (node == del)
66
+ {
67
+ prev->setSibling(node->getSibling());
68
+ return;
69
+ }
70
+ prev = node;
71
+ }
72
+
73
+ throw std::runtime_error{ "Cannot find the child" };
74
+ }
75
+
76
+ operator bool() const
77
+ {
78
+ return numCustomers || level;
79
+ }
80
+
81
+ bool isLeaf(int totLevel) const
82
+ {
83
+ return level == totLevel - 1;
84
+ }
85
+
86
+ void dropPathOne()
87
+ {
88
+ NCRPNode* node = this;
89
+ size_t _level = this->level;
90
+ for (size_t i = 0; i <= _level; ++i)
91
+ {
92
+ if (!--node->numCustomers)
93
+ {
94
+ node->level = 0;
95
+ node->getParent()->removeChild(node);
96
+ }
97
+ node = node->getParent();
98
+ }
99
+ }
100
+
101
+ void addPathOne()
102
+ {
103
+ NCRPNode* node = this;
104
+ for (size_t i = 0; i <= level; ++i)
105
+ {
106
+ ++node->numCustomers;
107
+ node = node->getParent();
108
+ }
109
+ }
110
+ };
111
+
112
+ struct NodeTrees
113
+ {
114
// Number of node slots per allocation block. nodes[0, blockSize) is not tracked by
// levelBlocks; calcNodeLikelihood treats nodes[0] as the root. Entry b of levelBlocks
// covers nodes[(b + 1) * blockSize, (b + 2) * blockSize).
+ static constexpr size_t blockSize = 8;
115
// All tree nodes, stored contiguously; NCRPNode links are relative offsets within this vector.
+ std::vector<NCRPNode> nodes;
116
// Level tag of each block of blockSize nodes; 0 marks a free block (see markEmptyBlocks / newNode).
+ std::vector<uint8_t> levelBlocks;
117
+ Eigen::Matrix<Float, -1, 1> nodeLikelihoods; // per-node log weight, indexed like `nodes`; filled by calcNodeLikelihood (not serialized)
118
+ Eigen::Matrix<Float, -1, 1> nodeWLikelihoods; // NOTE(review): not used in the visible code; presumably per-node word log-likelihoods — confirm (not serialized)
119
+
120
+ DEFINE_SERIALIZER(nodes, levelBlocks);
121
+
122
+ template<bool _MakeNewPath = true>
123
+ void calcNodeLikelihood(Float gamma, size_t levelDepth)
124
+ {
125
+ nodeLikelihoods.resize(nodes.size());
126
+ nodeLikelihoods.array() = -INFINITY;
127
+ updateNodeLikelihood<_MakeNewPath>(gamma, levelDepth, &nodes[0]);
128
+ }
129
+
130
+ template<bool _MakeNewPath = true>
131
+ void updateNodeLikelihood(Float gamma, size_t levelDepth, NCRPNode* node, Float weight = 0)
132
+ {
133
+ size_t idx = node - nodes.data();
134
+ const Float pNewNode = _MakeNewPath ? log(gamma / (node->numCustomers + gamma)) : -INFINITY;
135
+ nodeLikelihoods[idx] = weight + ((node->level < levelDepth - 1) ? pNewNode : 0);
136
+ for(auto * child = node->getChild(); child; child = child->getSibling())
137
+ {
138
+ updateNodeLikelihood(gamma, levelDepth, child, weight + log(child->numCustomers / (node->numCustomers + gamma)));
139
+ }
140
+ }
141
+
142
+ void markEmptyBlocks()
143
+ {
144
+ for (size_t b = 0; b < levelBlocks.size(); ++b)
145
+ {
146
+ if (!levelBlocks[b]) continue;
147
+ bool filled = std::any_of(nodes.begin() + (b + 1) * blockSize, nodes.begin() + (b + 2) * blockSize, [](const NCRPNode& node)
148
+ {
149
+ return !!node;
150
+ });
151
+ if (!filled) levelBlocks[b] = 0;
152
+ }
153
+ }
154
+
155
+ NCRPNode* newNode(size_t level)
156
+ {
157
+ for (size_t b = 0; b < levelBlocks.size(); ++b)
158
+ {
159
+ if (levelBlocks[b] != level) continue;
160
+ for (size_t i = 0; i < blockSize; ++i)
161
+ {
162
+ const size_t id = blockSize + i + b * blockSize;
163
+ if (!nodes[id]) return &nodes[id];
164
+ }
165
+ }
166
+
167
+ for (size_t b = 0; b < levelBlocks.size(); ++b)
168
+ {
169
+ if (!levelBlocks[b])
170
+ {
171
+ levelBlocks[b] = level;
172
+ return &nodes[blockSize + b * blockSize];
173
+ }
174
+ }
175
+ nodes.insert(nodes.end(), blockSize, NCRPNode{});
176
+ levelBlocks.emplace_back(level);
177
+ return &nodes[nodes.size() - blockSize];
178
+ }
179
+
180
+ template<TermWeight _tw>
181
+ void calcWordLikelihood(Float eta, size_t realV, size_t levelDepth, ThreadPool* pool,
182
+ const DocumentHLDA<_tw>& doc, const std::vector<Float>& newTopicWeights,
183
+ const ModelStateLDA<_tw>& ld)
184
+ {
185
+ nodeWLikelihoods.resize(nodes.size());
186
+ nodeWLikelihoods.setZero();
187
+ std::vector<std::future<void>> futures;
188
+ futures.reserve(levelBlocks.size());
189
+
190
+ auto calc = [this, eta, realV, &doc, &ld](size_t threadId, size_t b)
191
+ {
192
+ Float cnt = 0;
193
+ Vid prevWord = -1;
194
+ const size_t bStart = blockSize + b * blockSize;
195
+ for (size_t w = 0; w < doc.words.size(); ++w)
196
+ {
197
+ if (doc.words[w] >= realV) break;
198
+ if (doc.Zs[w] != levelBlocks[b]) continue;
199
+ if (doc.words[w] != prevWord)
200
+ {
201
+ if (prevWord != (Vid)-1)
202
+ {
203
+ if (cnt == 1) nodeWLikelihoods.segment(bStart, blockSize).array()
204
+ += (ld.numByTopicWord.col(prevWord).segment(bStart, blockSize).array().template cast<Float>() + eta).log();
205
+ else nodeWLikelihoods.segment(bStart, blockSize).array()
206
+ += Eigen::lgamma_subt(ld.numByTopicWord.col(prevWord).segment(bStart, blockSize).array().template cast<Float>() + eta, cnt);
207
+ }
208
+ cnt = 0;
209
+ prevWord = doc.words[w];
210
+ }
211
+ cnt += doc.getWordWeight(w);
212
+ }
213
+ if (prevWord != (Vid)-1)
214
+ {
215
+ if (cnt == 1) nodeWLikelihoods.segment(bStart, blockSize).array()
216
+ += (ld.numByTopicWord.col(prevWord).segment(bStart, blockSize).array().template cast<Float>() + eta).log();
217
+ else nodeWLikelihoods.segment(bStart, blockSize).array()
218
+ += Eigen::lgamma_subt(ld.numByTopicWord.col(prevWord).segment(bStart, blockSize).array().template cast<Float>() + eta, cnt);
219
+ }
220
+ nodeWLikelihoods.segment(bStart, blockSize).array()
221
+ -= Eigen::lgamma_subt(ld.numByTopic.segment(bStart, blockSize).array().template cast<Float>() + realV * eta, (Float)doc.numByTopic[levelBlocks[b]]);
222
+ };
223
+
224
+ // we elide the likelihood for root node because its weight applied to all path and can be seen as constant.
225
+ if (pool)
226
+ {
227
+ const size_t chStride = pool->getNumWorkers() * 8;
228
+ for (size_t ch = 0; ch < chStride; ++ch)
229
+ {
230
+ futures.emplace_back(pool->enqueue([&](size_t threadId, size_t bBegin, size_t bEnd)
231
+ {
232
+ for (size_t b = bBegin; b < bEnd; ++b)
233
+ {
234
+ if (!levelBlocks[b]) continue;
235
+ calc(threadId, b);
236
+ }
237
+ }, levelBlocks.size() * ch / chStride, levelBlocks.size() * (ch + 1) / chStride));
238
+ }
239
+ for (auto& f : futures) f.get();
240
+ }
241
+ else
242
+ {
243
+ for (size_t b = 0; b < levelBlocks.size(); ++b)
244
+ {
245
+ if (!levelBlocks[b]) continue;
246
+ calc(0, b);
247
+ }
248
+ }
249
+
250
+ updateWordLikelihood<_tw>(eta, realV, levelDepth, doc, newTopicWeights, &nodes[0]);
251
+ }
252
+
253
+ template<TermWeight _tw>
254
+ void updateWordLikelihood(Float eta, size_t realV, size_t levelDepth,
255
+ const DocumentHLDA<_tw>& doc, const std::vector<Float>& newTopicWeights,
256
+ detail::NCRPNode* node, Float weight = 0)
257
+ {
258
+ size_t idx = node - nodes.data();
259
+ weight += nodeWLikelihoods[idx];
260
+ nodeLikelihoods[idx] += weight;
261
+ for (size_t l = node->level + 1; l < levelDepth; ++l)
262
+ {
263
+ nodeLikelihoods[idx] += newTopicWeights[l - 1];
264
+ }
265
+ for (auto* child = node->getChild(); child; child = child->getSibling())
266
+ {
267
+ updateWordLikelihood<_tw>(eta, realV, levelDepth, doc, newTopicWeights, child, weight);
268
+ }
269
+ }
270
+
271
+ template<TermWeight _tw>
272
+ size_t generateLeafNode(size_t idx, size_t levelDepth,
273
+ ModelStateLDA<_tw>& ld)
274
+ {
275
+ for (size_t l = nodes[idx].level + 1; l < levelDepth; ++l)
276
+ {
277
+ auto* nnode = newNode(l);
278
+ idx = nodes[idx].addChild(nnode) - nodes.data();
279
+ nodes[idx].level = l;
280
+ }
281
+
282
+ if (ld.numByTopic.size() < nodes.size())
283
+ {
284
+ size_t oldSize = ld.numByTopic.rows();
285
+ size_t newSize = std::max(nodes.size(), ((oldSize + oldSize / 2 + 7) / 8) * 8);
286
+ ld.numByTopic.conservativeResize(newSize);
287
+ ld.numByTopicWord.conservativeResize(newSize, Eigen::NoChange);
288
+ ld.numByTopic.segment(oldSize, newSize - oldSize).setZero();
289
+ ld.numByTopicWord.block(oldSize, 0, newSize - oldSize, ld.numByTopicWord.cols()).setZero();
290
+ }
291
+ return idx;
292
+ }
293
+ };
294
+ }
295
+
296
+ template<TermWeight _tw>
297
+ struct ModelStateHLDA : public ModelStateLDA<_tw>
298
+ {
299
+ std::shared_ptr<detail::NodeTrees> nt;
300
+
301
+ void serializerRead(std::istream& istr)
302
+ {
303
+ ModelStateLDA<_tw>::serializerRead(istr);
304
+ nt = std::make_shared<detail::NodeTrees>();
305
+ nt->serializerRead(istr);
306
+ }
307
+
308
+ void serializerWrite(std::ostream& ostr) const
309
+ {
310
+ ModelStateLDA<_tw>::serializerWrite(ostr);
311
+ nt->serializerWrite(ostr);
312
+ }
313
+ };
314
+
315
+ template<TermWeight _tw, typename _RandGen,
316
+ typename _Interface = IHLDAModel,
317
+ typename _Derived = void,
318
+ typename _DocType = DocumentHLDA<_tw>,
319
+ typename _ModelState = ModelStateHLDA<_tw>>
320
+ class HLDAModel : public LDAModel<_tw, _RandGen, flags::shared_state, _Interface,
321
+ typename std::conditional<std::is_same<_Derived, void>::value, HLDAModel<_tw, _RandGen>, _Derived>::type,
322
+ _DocType, _ModelState>
323
+ {
324
+ protected:
325
+ using DerivedClass = typename std::conditional<std::is_same<_Derived, void>::value, HLDAModel<_tw, _RandGen>, _Derived>::type;
326
+ using BaseClass = LDAModel<_tw, _RandGen, flags::shared_state, _Interface, DerivedClass, _DocType, _ModelState>;
327
+ friend BaseClass;
328
+ friend typename BaseClass::BaseClass;
329
+ using WeightType = typename BaseClass::WeightType;
330
+
331
+ static constexpr char TMID[] = "hLDA";
332
+
333
+ Float gamma;
334
+
335
+ void optimizeParameters(ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
336
+ {
337
+ // for alphas
338
+ BaseClass::optimizeParameters(pool, localData, rgs);
339
+ // to do: gamma
340
+
341
+ }
342
+
343
+ // Words of all documents should be sorted by ascending order.
344
+ template<bool _MakeNewPath = true>
345
+ void samplePathes(_DocType& doc, ThreadPool* pool, _ModelState& ld, _RandGen& rgs) const
346
+ {
347
+ if(_MakeNewPath) ld.nt->nodes[doc.path.back()].dropPathOne();
348
+ ld.nt->template calcNodeLikelihood<_MakeNewPath>(gamma, this->K);
349
+
350
+ std::vector<Float> newTopicWeights(this->K - 1);
351
+ std::vector<WeightType> cntByLevel(this->K);
352
+ Vid prevWord = -1;
353
+ for (size_t w = 0; w < doc.words.size(); ++w)
354
+ {
355
+ if (doc.words[w] >= this->realV) break;
356
+ addWordToOnlyLocal<-1>(ld, doc, w, doc.words[w], doc.Zs[w]);
357
+
358
+ if (_MakeNewPath)
359
+ {
360
+ if (doc.words[w] != prevWord)
361
+ {
362
+ std::fill(cntByLevel.begin(), cntByLevel.end(), 0);
363
+ prevWord = doc.words[w];
364
+ }
365
+ size_t level = doc.Zs[w];
366
+ if (level)
367
+ {
368
+ newTopicWeights[level - 1] += log(this->eta + cntByLevel[level]);
369
+ cntByLevel[level] += doc.getWordWeight(w);
370
+ }
371
+ }
372
+ }
373
+
374
+ if (_MakeNewPath)
375
+ {
376
+ for (size_t l = 1; l < this->K; ++l)
377
+ {
378
+ newTopicWeights[l - 1] -= math::lgammaT(doc.numByTopic[l] + this->realV * this->eta) - math::lgammaT(this->realV * this->eta);
379
+ }
380
+ }
381
+
382
+ ld.nt->template calcWordLikelihood<_tw>(this->eta, this->realV, this->K, pool, doc, newTopicWeights, ld);
383
+
384
+ ld.nt->nodeLikelihoods = (ld.nt->nodeLikelihoods.array() - ld.nt->nodeLikelihoods.maxCoeff()).exp();
385
+ sample::prefixSum(ld.nt->nodeLikelihoods.data(), ld.nt->nodeLikelihoods.size());
386
+ size_t newPath = sample::sampleFromDiscreteAcc(ld.nt->nodeLikelihoods.data(),
387
+ ld.nt->nodeLikelihoods.data() + ld.nt->nodeLikelihoods.size(), rgs);
388
+
389
+ if(_MakeNewPath) newPath = ld.nt->template generateLeafNode<_tw>(newPath, this->K, ld);
390
+ doc.path.back() = newPath;
391
+ for (size_t l = this->K - 2; l > 0; --l)
392
+ {
393
+ doc.path[l] = doc.path[l + 1] + ld.nt->nodes[doc.path[l + 1]].parent;
394
+ }
395
+
396
+ for (size_t w = 0; w < doc.words.size(); ++w)
397
+ {
398
+ if (doc.words[w] >= this->realV) break;
399
+ addWordToOnlyLocal<1>(ld, doc, w, doc.words[w], doc.Zs[w]);
400
+ }
401
+ if (_MakeNewPath) ld.nt->nodes[doc.path.back()].addPathOne();
402
+ }
403
+
404
+ template<int _inc>
405
+ inline void addWordToOnlyLocal(_ModelState& ld, _DocType& doc, uint32_t pid, Vid vid, Tid level) const
406
+ {
407
+ assert(vid < this->realV);
408
+ constexpr bool _dec = _inc < 0 && _tw != TermWeight::one;
409
+ auto weight = doc.getWordWeight(pid);
410
+
411
+ updateCnt<_dec>(ld.numByTopic[doc.path[level]], _inc * weight);
412
+ updateCnt<_dec>(ld.numByTopicWord(doc.path[level], vid), _inc * weight);
413
+ }
414
+
415
+ template<int _inc>
416
+ inline void addWordTo(_ModelState& ld, _DocType& doc, uint32_t pid, Vid vid, Tid level) const
417
+ {
418
+ assert(vid < this->realV);
419
+ constexpr bool _dec = _inc < 0 && _tw != TermWeight::one;
420
+ auto weight = doc.getWordWeight(pid);
421
+
422
+ updateCnt<_dec>(doc.numByTopic[level], _inc * weight);
423
+ addWordToOnlyLocal<_inc>(ld, doc, pid, vid, level);
424
+ }
425
+
426
+ template<bool _asymEta>
427
+ Float* getZLikelihoods(_ModelState& ld, const _DocType& doc, size_t docId, size_t vid) const
428
+ {
429
+ const size_t V = this->realV;
430
+ assert(vid < V);
431
+ auto& zLikelihood = ld.zLikelihood;
432
+ zLikelihood = (doc.numByTopic.array().template cast<Float>() + this->alphas.array());
433
+ for (size_t l = 0; l < this->K; ++l)
434
+ {
435
+ zLikelihood[l] *= (ld.numByTopicWord(doc.path[l], vid) + this->eta)
436
+ / (ld.numByTopic(doc.path[l]) + V * this->eta);
437
+ }
438
+ sample::prefixSum(zLikelihood.data(), zLikelihood.size());
439
+ return &zLikelihood[0];
440
+ }
441
+
442
+ void sampleTopics(_DocType& doc, size_t docId, _ModelState& ld, _RandGen& rgs) const
443
+ {
444
+ for (size_t w = 0; w < doc.words.size(); ++w)
445
+ {
446
+ if (doc.words[w] >= this->realV) continue;
447
+ addWordTo<-1>(ld, doc, w, doc.words[w], doc.Zs[w]);
448
+ Float* dist;
449
+ if (this->etaByTopicWord.size())
450
+ {
451
+ THROW_ERROR_WITH_INFO(exception::Unimplemented, "Unimplemented features");
452
+ }
453
+ else
454
+ {
455
+ dist = static_cast<const DerivedClass*>(this)->template
456
+ getZLikelihoods<false>(ld, doc, docId, doc.words[w]);
457
+ }
458
+ doc.Zs[w] = sample::sampleFromDiscreteAcc(dist, dist + this->K, rgs);
459
+ addWordTo<1>(ld, doc, w, doc.words[w], doc.Zs[w]);
460
+ }
461
+ }
462
+
463
+ template<ParallelScheme _ps, bool _infer, typename _ExtraDocData>
464
+ void sampleDocument(_DocType& doc, const _ExtraDocData& edd, size_t docId, _ModelState& ld, _RandGen& rgs, size_t iterationCnt, size_t partitionId = 0) const
465
+ {
466
+ sampleTopics(doc, docId, ld, rgs);
467
+ }
468
+
469
+ template<typename _DocIter>
470
+ void sampleGlobalLevel(ThreadPool* pool, _ModelState* localData, _RandGen* rgs, _DocIter first, _DocIter last)
471
+ {
472
+ for (auto doc = first; doc != last; ++doc)
473
+ {
474
+ samplePathes<>(*doc, pool, *localData, rgs[0]);
475
+ }
476
+ localData->nt->markEmptyBlocks();
477
+ }
478
+
479
+ template<typename _DocIter>
480
+ void sampleGlobalLevel(ThreadPool* pool, _ModelState* localData, _RandGen* rgs, _DocIter first, _DocIter last) const
481
+ {
482
+ for (auto doc = first; doc != last; ++doc)
483
+ {
484
+ samplePathes<false>(*doc, pool, *localData, rgs[0]);
485
+ }
486
+ }
487
+
488
+ template<typename _DocIter>
489
+ double getLLDocs(_DocIter _first, _DocIter _last) const
490
+ {
491
+ double ll = 0;
492
+ auto lgammaAlpha = math::lgammaT(this->alpha);
493
+ for (; _first != _last; ++_first)
494
+ {
495
+ auto& doc = *_first;
496
+ // doc-path distribution
497
+ for (Tid l = 1; l < this->K; ++l)
498
+ {
499
+ ll += log(this->globalState.nt->nodes[doc.path[l]].numCustomers / (this->globalState.nt->nodes[doc.path[l - 1]].numCustomers + gamma));
500
+ }
501
+
502
+ // doc-level distribution
503
+ ll -= math::lgammaT(doc.getSumWordWeight() + this->alpha * this->K);
504
+ for (Tid l = 0; l < this->K; ++l)
505
+ {
506
+ ll += math::lgammaT(doc.numByTopic[l] + this->alpha) - lgammaAlpha;
507
+ }
508
+ }
509
+ ll += math::lgammaT(this->alpha * this->K) * std::distance(_first, _last);
510
+ return ll;
511
+ }
512
+
513
+ double getLLRest(const _ModelState& ld) const
514
+ {
515
+ double ll = 0;
516
+ const size_t V = this->realV;
517
+ const size_t K = ld.nt->nodes.size();
518
+ size_t liveK = 0;
519
+ // topic-word distribution
520
+ auto lgammaEta = math::lgammaT(this->eta);
521
+ for (Tid k = 0; k < K; ++k)
522
+ {
523
+ if (!ld.nt->nodes[k]) continue;
524
+ ++liveK;
525
+ ll -= math::lgammaT(ld.numByTopic[k] + V * this->eta);
526
+ for (Vid v = 0; v < V; ++v)
527
+ {
528
+ if (!ld.numByTopicWord(k, v)) continue;
529
+ ll += math::lgammaT(ld.numByTopicWord(k, v) + this->eta) - lgammaEta;
530
+ }
531
+ }
532
+ ll += math::lgammaT(V*this->eta) * liveK;
533
+ return ll;
534
+ }
535
+
536
+ void initGlobalState(bool initDocs)
537
+ {
538
+ const size_t V = this->realV;
539
+ if (initDocs)
540
+ {
541
+ this->globalState.numByTopic = Eigen::Matrix<WeightType, -1, 1>::Zero(this->K);
542
+ this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K, V);
543
+ this->globalState.nt->nodes.resize(detail::NodeTrees::blockSize);
544
+ }
545
+ }
546
+
547
+ void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
548
+ {
549
+ sortAndWriteOrder(doc.words, doc.wOrder);
550
+ doc.numByTopic.init(nullptr, this->K);
551
+ doc.Zs = tvector<Tid>(wordSize);
552
+ doc.path.resize(this->K);
553
+ for (size_t l = 0; l < this->K; ++l) doc.path[l] = l;
554
+
555
+ if (_tw != TermWeight::one) doc.wordWeights.resize(wordSize);
556
+ }
557
+
558
+ template<bool _Infer>
559
+ void updateStateWithDoc(typename BaseClass::Generator& g, _ModelState& ld, _RandGen& rgs, _DocType& doc, size_t i) const
560
+ {
561
+ if (i == 0)
562
+ {
563
+ ld.nt->template calcNodeLikelihood<!_Infer>(gamma, this->K);
564
+ ld.nt->nodeLikelihoods = (ld.nt->nodeLikelihoods.array() - ld.nt->nodeLikelihoods.maxCoeff()).exp();
565
+ sample::prefixSum(ld.nt->nodeLikelihoods.data(), ld.nt->nodeLikelihoods.size());
566
+ size_t newPath = sample::sampleFromDiscreteAcc(ld.nt->nodeLikelihoods.data(),
567
+ ld.nt->nodeLikelihoods.data() + ld.nt->nodeLikelihoods.size(), rgs);
568
+
569
+ if (!_Infer) newPath = ld.nt->generateLeafNode(newPath, this->K, ld);
570
+ doc.path.back() = newPath;
571
+ for (size_t l = this->K - 2; l > 0; --l)
572
+ {
573
+ doc.path[l] = doc.path[l + 1] + ld.nt->nodes[doc.path[l + 1]].parent;
574
+ }
575
+
576
+ if (!_Infer) ld.nt->nodes[doc.path.back()].addPathOne();
577
+ }
578
+
579
+ auto& z = doc.Zs[i];
580
+ auto w = doc.words[i];
581
+ z = g.theta(rgs);
582
+ addWordTo<1>(ld, doc, i, w, z);
583
+ }
584
+
585
+ std::vector<uint64_t> _getTopicsCount() const
586
+ {
587
+ std::vector<uint64_t> cnt(this->globalState.nt->nodes.size());
588
+ for (auto& doc : this->docs)
589
+ {
590
+ for (size_t i = 0; i < doc.Zs.size(); ++i)
591
+ {
592
+ if (doc.words[i] < this->realV) ++cnt[doc.path[doc.Zs[i]]];
593
+ }
594
+ }
595
+ return cnt;
596
+ }
597
+
598
+ public:
599
+ DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, gamma);
600
+ DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, gamma);
601
+
602
+ HLDAModel(size_t _levelDepth = 4, Float _alpha = 0.1, Float _eta = 0.01, Float _gamma = 0.1, size_t _rg = std::random_device{}())
603
+ : BaseClass(_levelDepth, _alpha, _eta, _rg), gamma(_gamma)
604
+ {
605
+ if (_levelDepth == 0 || _levelDepth >= 0x80000000) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong levelDepth value (levelDepth = %zd)", _levelDepth));
606
+ if (_gamma <= 0) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong gamma value (gamma = %f)", _gamma));
607
+ this->globalState.nt = std::make_shared<detail::NodeTrees>();
608
+ }
609
+
610
+ size_t getLiveK() const override
611
+ {
612
+ return std::count_if(this->globalState.nt->nodes.begin(), this->globalState.nt->nodes.end(), [](const detail::NCRPNode& n)
613
+ {
614
+ return !!n;
615
+ });
616
+ }
617
+
618
+ size_t getK() const override
619
+ {
620
+ return this->globalState.nt->nodes.size();
621
+ }
622
+
623
+ size_t getLevelDepth() const override
624
+ {
625
+ return this->K;
626
+ }
627
+
628
+ GETTER(Gamma, Float, gamma);
629
+
630
+ bool isLiveTopic(Tid tid) const override
631
+ {
632
+ return this->globalState.nt->nodes[tid];
633
+ }
634
+
635
+ size_t getParentTopicId(Tid tid) const override
636
+ {
637
+ if (!isLiveTopic(tid)) return (size_t)-1;
638
+ return this->globalState.nt->nodes[tid].parent ? (tid + this->globalState.nt->nodes[tid].parent) : (size_t)-1;
639
+ }
640
+
641
+ size_t getNumDocsOfTopic(Tid tid) const override
642
+ {
643
+ if (!isLiveTopic(tid)) return 0;
644
+ return this->globalState.nt->nodes[tid].numCustomers;
645
+ }
646
+
647
+ size_t getLevelOfTopic(Tid tid) const override
648
+ {
649
+ if (!isLiveTopic(tid)) return (size_t)-1;
650
+ return this->globalState.nt->nodes[tid].level;
651
+ }
652
+
653
+ std::vector<uint32_t> getChildTopicId(Tid tid) const override
654
+ {
655
+ std::vector<uint32_t> ret;
656
+ if (!isLiveTopic(tid)) return ret;
657
+ for (auto* node = this->globalState.nt->nodes[tid].getChild(); node; node = node->getSibling())
658
+ {
659
+ ret.emplace_back(node - this->globalState.nt->nodes.data());
660
+ }
661
+ return ret;
662
+ }
663
+
664
+ void setWordPrior(const std::string& word, const std::vector<Float>& priors) override
665
+ {
666
+ THROW_ERROR_WITH_INFO(exception::Unimplemented, "HLDAModel doesn't provide setWordPrior function.");
667
+ }
668
+ };
669
+
670
+ template<TermWeight _tw>
671
+ template<typename _TopicModel>
672
+ inline void DocumentHLDA<_tw>::update(WeightType * ptr, const _TopicModel & mdl)
673
+ {
674
+ this->numByTopic.init(ptr, mdl.getLevelDepth());
675
+ for (size_t i = 0; i < this->Zs.size(); ++i)
676
+ {
677
+ if (this->words[i] >= mdl.getV()) continue;
678
+ this->numByTopic[this->Zs[i]] += _tw != TermWeight::one ? this->wordWeights[i] : 1;
679
+ }
680
+ }
681
+ }