tomoto 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (420) hide show
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +3 -0
  3. data/LICENSE.txt +22 -0
  4. data/README.md +123 -0
  5. data/ext/tomoto/ext.cpp +245 -0
  6. data/ext/tomoto/extconf.rb +28 -0
  7. data/lib/tomoto.rb +12 -0
  8. data/lib/tomoto/ct.rb +11 -0
  9. data/lib/tomoto/hdp.rb +11 -0
  10. data/lib/tomoto/lda.rb +67 -0
  11. data/lib/tomoto/version.rb +3 -0
  12. data/vendor/EigenRand/EigenRand/Core.h +1139 -0
  13. data/vendor/EigenRand/EigenRand/Dists/Basic.h +111 -0
  14. data/vendor/EigenRand/EigenRand/Dists/Discrete.h +877 -0
  15. data/vendor/EigenRand/EigenRand/Dists/GammaPoisson.h +108 -0
  16. data/vendor/EigenRand/EigenRand/Dists/NormalExp.h +626 -0
  17. data/vendor/EigenRand/EigenRand/EigenRand +19 -0
  18. data/vendor/EigenRand/EigenRand/Macro.h +24 -0
  19. data/vendor/EigenRand/EigenRand/MorePacketMath.h +978 -0
  20. data/vendor/EigenRand/EigenRand/PacketFilter.h +286 -0
  21. data/vendor/EigenRand/EigenRand/PacketRandomEngine.h +624 -0
  22. data/vendor/EigenRand/EigenRand/RandUtils.h +413 -0
  23. data/vendor/EigenRand/EigenRand/doc.h +220 -0
  24. data/vendor/EigenRand/LICENSE +21 -0
  25. data/vendor/EigenRand/README.md +288 -0
  26. data/vendor/eigen/COPYING.BSD +26 -0
  27. data/vendor/eigen/COPYING.GPL +674 -0
  28. data/vendor/eigen/COPYING.LGPL +502 -0
  29. data/vendor/eigen/COPYING.MINPACK +52 -0
  30. data/vendor/eigen/COPYING.MPL2 +373 -0
  31. data/vendor/eigen/COPYING.README +18 -0
  32. data/vendor/eigen/Eigen/CMakeLists.txt +19 -0
  33. data/vendor/eigen/Eigen/Cholesky +46 -0
  34. data/vendor/eigen/Eigen/CholmodSupport +48 -0
  35. data/vendor/eigen/Eigen/Core +537 -0
  36. data/vendor/eigen/Eigen/Dense +7 -0
  37. data/vendor/eigen/Eigen/Eigen +2 -0
  38. data/vendor/eigen/Eigen/Eigenvalues +61 -0
  39. data/vendor/eigen/Eigen/Geometry +62 -0
  40. data/vendor/eigen/Eigen/Householder +30 -0
  41. data/vendor/eigen/Eigen/IterativeLinearSolvers +48 -0
  42. data/vendor/eigen/Eigen/Jacobi +33 -0
  43. data/vendor/eigen/Eigen/LU +50 -0
  44. data/vendor/eigen/Eigen/MetisSupport +35 -0
  45. data/vendor/eigen/Eigen/OrderingMethods +73 -0
  46. data/vendor/eigen/Eigen/PaStiXSupport +48 -0
  47. data/vendor/eigen/Eigen/PardisoSupport +35 -0
  48. data/vendor/eigen/Eigen/QR +51 -0
  49. data/vendor/eigen/Eigen/QtAlignedMalloc +40 -0
  50. data/vendor/eigen/Eigen/SPQRSupport +34 -0
  51. data/vendor/eigen/Eigen/SVD +51 -0
  52. data/vendor/eigen/Eigen/Sparse +36 -0
  53. data/vendor/eigen/Eigen/SparseCholesky +45 -0
  54. data/vendor/eigen/Eigen/SparseCore +69 -0
  55. data/vendor/eigen/Eigen/SparseLU +46 -0
  56. data/vendor/eigen/Eigen/SparseQR +37 -0
  57. data/vendor/eigen/Eigen/StdDeque +27 -0
  58. data/vendor/eigen/Eigen/StdList +26 -0
  59. data/vendor/eigen/Eigen/StdVector +27 -0
  60. data/vendor/eigen/Eigen/SuperLUSupport +64 -0
  61. data/vendor/eigen/Eigen/UmfPackSupport +40 -0
  62. data/vendor/eigen/Eigen/src/Cholesky/LDLT.h +673 -0
  63. data/vendor/eigen/Eigen/src/Cholesky/LLT.h +542 -0
  64. data/vendor/eigen/Eigen/src/Cholesky/LLT_LAPACKE.h +99 -0
  65. data/vendor/eigen/Eigen/src/CholmodSupport/CholmodSupport.h +639 -0
  66. data/vendor/eigen/Eigen/src/Core/Array.h +329 -0
  67. data/vendor/eigen/Eigen/src/Core/ArrayBase.h +226 -0
  68. data/vendor/eigen/Eigen/src/Core/ArrayWrapper.h +209 -0
  69. data/vendor/eigen/Eigen/src/Core/Assign.h +90 -0
  70. data/vendor/eigen/Eigen/src/Core/AssignEvaluator.h +935 -0
  71. data/vendor/eigen/Eigen/src/Core/Assign_MKL.h +178 -0
  72. data/vendor/eigen/Eigen/src/Core/BandMatrix.h +353 -0
  73. data/vendor/eigen/Eigen/src/Core/Block.h +452 -0
  74. data/vendor/eigen/Eigen/src/Core/BooleanRedux.h +164 -0
  75. data/vendor/eigen/Eigen/src/Core/CommaInitializer.h +160 -0
  76. data/vendor/eigen/Eigen/src/Core/ConditionEstimator.h +175 -0
  77. data/vendor/eigen/Eigen/src/Core/CoreEvaluators.h +1688 -0
  78. data/vendor/eigen/Eigen/src/Core/CoreIterators.h +127 -0
  79. data/vendor/eigen/Eigen/src/Core/CwiseBinaryOp.h +184 -0
  80. data/vendor/eigen/Eigen/src/Core/CwiseNullaryOp.h +866 -0
  81. data/vendor/eigen/Eigen/src/Core/CwiseTernaryOp.h +197 -0
  82. data/vendor/eigen/Eigen/src/Core/CwiseUnaryOp.h +103 -0
  83. data/vendor/eigen/Eigen/src/Core/CwiseUnaryView.h +128 -0
  84. data/vendor/eigen/Eigen/src/Core/DenseBase.h +611 -0
  85. data/vendor/eigen/Eigen/src/Core/DenseCoeffsBase.h +681 -0
  86. data/vendor/eigen/Eigen/src/Core/DenseStorage.h +570 -0
  87. data/vendor/eigen/Eigen/src/Core/Diagonal.h +260 -0
  88. data/vendor/eigen/Eigen/src/Core/DiagonalMatrix.h +343 -0
  89. data/vendor/eigen/Eigen/src/Core/DiagonalProduct.h +28 -0
  90. data/vendor/eigen/Eigen/src/Core/Dot.h +318 -0
  91. data/vendor/eigen/Eigen/src/Core/EigenBase.h +159 -0
  92. data/vendor/eigen/Eigen/src/Core/ForceAlignedAccess.h +146 -0
  93. data/vendor/eigen/Eigen/src/Core/Fuzzy.h +155 -0
  94. data/vendor/eigen/Eigen/src/Core/GeneralProduct.h +455 -0
  95. data/vendor/eigen/Eigen/src/Core/GenericPacketMath.h +593 -0
  96. data/vendor/eigen/Eigen/src/Core/GlobalFunctions.h +187 -0
  97. data/vendor/eigen/Eigen/src/Core/IO.h +225 -0
  98. data/vendor/eigen/Eigen/src/Core/Inverse.h +118 -0
  99. data/vendor/eigen/Eigen/src/Core/Map.h +171 -0
  100. data/vendor/eigen/Eigen/src/Core/MapBase.h +303 -0
  101. data/vendor/eigen/Eigen/src/Core/MathFunctions.h +1415 -0
  102. data/vendor/eigen/Eigen/src/Core/MathFunctionsImpl.h +101 -0
  103. data/vendor/eigen/Eigen/src/Core/Matrix.h +459 -0
  104. data/vendor/eigen/Eigen/src/Core/MatrixBase.h +529 -0
  105. data/vendor/eigen/Eigen/src/Core/NestByValue.h +110 -0
  106. data/vendor/eigen/Eigen/src/Core/NoAlias.h +108 -0
  107. data/vendor/eigen/Eigen/src/Core/NumTraits.h +248 -0
  108. data/vendor/eigen/Eigen/src/Core/PermutationMatrix.h +633 -0
  109. data/vendor/eigen/Eigen/src/Core/PlainObjectBase.h +1035 -0
  110. data/vendor/eigen/Eigen/src/Core/Product.h +186 -0
  111. data/vendor/eigen/Eigen/src/Core/ProductEvaluators.h +1112 -0
  112. data/vendor/eigen/Eigen/src/Core/Random.h +182 -0
  113. data/vendor/eigen/Eigen/src/Core/Redux.h +505 -0
  114. data/vendor/eigen/Eigen/src/Core/Ref.h +283 -0
  115. data/vendor/eigen/Eigen/src/Core/Replicate.h +142 -0
  116. data/vendor/eigen/Eigen/src/Core/ReturnByValue.h +117 -0
  117. data/vendor/eigen/Eigen/src/Core/Reverse.h +211 -0
  118. data/vendor/eigen/Eigen/src/Core/Select.h +162 -0
  119. data/vendor/eigen/Eigen/src/Core/SelfAdjointView.h +352 -0
  120. data/vendor/eigen/Eigen/src/Core/SelfCwiseBinaryOp.h +47 -0
  121. data/vendor/eigen/Eigen/src/Core/Solve.h +188 -0
  122. data/vendor/eigen/Eigen/src/Core/SolveTriangular.h +235 -0
  123. data/vendor/eigen/Eigen/src/Core/SolverBase.h +130 -0
  124. data/vendor/eigen/Eigen/src/Core/StableNorm.h +221 -0
  125. data/vendor/eigen/Eigen/src/Core/Stride.h +111 -0
  126. data/vendor/eigen/Eigen/src/Core/Swap.h +67 -0
  127. data/vendor/eigen/Eigen/src/Core/Transpose.h +403 -0
  128. data/vendor/eigen/Eigen/src/Core/Transpositions.h +407 -0
  129. data/vendor/eigen/Eigen/src/Core/TriangularMatrix.h +983 -0
  130. data/vendor/eigen/Eigen/src/Core/VectorBlock.h +96 -0
  131. data/vendor/eigen/Eigen/src/Core/VectorwiseOp.h +695 -0
  132. data/vendor/eigen/Eigen/src/Core/Visitor.h +273 -0
  133. data/vendor/eigen/Eigen/src/Core/arch/AVX/Complex.h +451 -0
  134. data/vendor/eigen/Eigen/src/Core/arch/AVX/MathFunctions.h +439 -0
  135. data/vendor/eigen/Eigen/src/Core/arch/AVX/PacketMath.h +637 -0
  136. data/vendor/eigen/Eigen/src/Core/arch/AVX/TypeCasting.h +51 -0
  137. data/vendor/eigen/Eigen/src/Core/arch/AVX512/MathFunctions.h +391 -0
  138. data/vendor/eigen/Eigen/src/Core/arch/AVX512/PacketMath.h +1316 -0
  139. data/vendor/eigen/Eigen/src/Core/arch/AltiVec/Complex.h +430 -0
  140. data/vendor/eigen/Eigen/src/Core/arch/AltiVec/MathFunctions.h +322 -0
  141. data/vendor/eigen/Eigen/src/Core/arch/AltiVec/PacketMath.h +1061 -0
  142. data/vendor/eigen/Eigen/src/Core/arch/CUDA/Complex.h +103 -0
  143. data/vendor/eigen/Eigen/src/Core/arch/CUDA/Half.h +674 -0
  144. data/vendor/eigen/Eigen/src/Core/arch/CUDA/MathFunctions.h +91 -0
  145. data/vendor/eigen/Eigen/src/Core/arch/CUDA/PacketMath.h +333 -0
  146. data/vendor/eigen/Eigen/src/Core/arch/CUDA/PacketMathHalf.h +1124 -0
  147. data/vendor/eigen/Eigen/src/Core/arch/CUDA/TypeCasting.h +212 -0
  148. data/vendor/eigen/Eigen/src/Core/arch/Default/ConjHelper.h +29 -0
  149. data/vendor/eigen/Eigen/src/Core/arch/Default/Settings.h +49 -0
  150. data/vendor/eigen/Eigen/src/Core/arch/NEON/Complex.h +490 -0
  151. data/vendor/eigen/Eigen/src/Core/arch/NEON/MathFunctions.h +91 -0
  152. data/vendor/eigen/Eigen/src/Core/arch/NEON/PacketMath.h +760 -0
  153. data/vendor/eigen/Eigen/src/Core/arch/SSE/Complex.h +471 -0
  154. data/vendor/eigen/Eigen/src/Core/arch/SSE/MathFunctions.h +562 -0
  155. data/vendor/eigen/Eigen/src/Core/arch/SSE/PacketMath.h +895 -0
  156. data/vendor/eigen/Eigen/src/Core/arch/SSE/TypeCasting.h +77 -0
  157. data/vendor/eigen/Eigen/src/Core/arch/ZVector/Complex.h +397 -0
  158. data/vendor/eigen/Eigen/src/Core/arch/ZVector/MathFunctions.h +137 -0
  159. data/vendor/eigen/Eigen/src/Core/arch/ZVector/PacketMath.h +945 -0
  160. data/vendor/eigen/Eigen/src/Core/functors/AssignmentFunctors.h +168 -0
  161. data/vendor/eigen/Eigen/src/Core/functors/BinaryFunctors.h +475 -0
  162. data/vendor/eigen/Eigen/src/Core/functors/NullaryFunctors.h +188 -0
  163. data/vendor/eigen/Eigen/src/Core/functors/StlFunctors.h +136 -0
  164. data/vendor/eigen/Eigen/src/Core/functors/TernaryFunctors.h +25 -0
  165. data/vendor/eigen/Eigen/src/Core/functors/UnaryFunctors.h +792 -0
  166. data/vendor/eigen/Eigen/src/Core/products/GeneralBlockPanelKernel.h +2156 -0
  167. data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrix.h +492 -0
  168. data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h +311 -0
  169. data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrixTriangular_BLAS.h +145 -0
  170. data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixMatrix_BLAS.h +122 -0
  171. data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixVector.h +619 -0
  172. data/vendor/eigen/Eigen/src/Core/products/GeneralMatrixVector_BLAS.h +136 -0
  173. data/vendor/eigen/Eigen/src/Core/products/Parallelizer.h +163 -0
  174. data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix.h +521 -0
  175. data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixMatrix_BLAS.h +287 -0
  176. data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixVector.h +260 -0
  177. data/vendor/eigen/Eigen/src/Core/products/SelfadjointMatrixVector_BLAS.h +118 -0
  178. data/vendor/eigen/Eigen/src/Core/products/SelfadjointProduct.h +133 -0
  179. data/vendor/eigen/Eigen/src/Core/products/SelfadjointRank2Update.h +93 -0
  180. data/vendor/eigen/Eigen/src/Core/products/TriangularMatrixMatrix.h +466 -0
  181. data/vendor/eigen/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h +315 -0
  182. data/vendor/eigen/Eigen/src/Core/products/TriangularMatrixVector.h +350 -0
  183. data/vendor/eigen/Eigen/src/Core/products/TriangularMatrixVector_BLAS.h +255 -0
  184. data/vendor/eigen/Eigen/src/Core/products/TriangularSolverMatrix.h +335 -0
  185. data/vendor/eigen/Eigen/src/Core/products/TriangularSolverMatrix_BLAS.h +163 -0
  186. data/vendor/eigen/Eigen/src/Core/products/TriangularSolverVector.h +145 -0
  187. data/vendor/eigen/Eigen/src/Core/util/BlasUtil.h +398 -0
  188. data/vendor/eigen/Eigen/src/Core/util/Constants.h +547 -0
  189. data/vendor/eigen/Eigen/src/Core/util/DisableStupidWarnings.h +83 -0
  190. data/vendor/eigen/Eigen/src/Core/util/ForwardDeclarations.h +302 -0
  191. data/vendor/eigen/Eigen/src/Core/util/MKL_support.h +130 -0
  192. data/vendor/eigen/Eigen/src/Core/util/Macros.h +1001 -0
  193. data/vendor/eigen/Eigen/src/Core/util/Memory.h +993 -0
  194. data/vendor/eigen/Eigen/src/Core/util/Meta.h +534 -0
  195. data/vendor/eigen/Eigen/src/Core/util/NonMPL2.h +3 -0
  196. data/vendor/eigen/Eigen/src/Core/util/ReenableStupidWarnings.h +27 -0
  197. data/vendor/eigen/Eigen/src/Core/util/StaticAssert.h +218 -0
  198. data/vendor/eigen/Eigen/src/Core/util/XprHelper.h +821 -0
  199. data/vendor/eigen/Eigen/src/Eigenvalues/ComplexEigenSolver.h +346 -0
  200. data/vendor/eigen/Eigen/src/Eigenvalues/ComplexSchur.h +459 -0
  201. data/vendor/eigen/Eigen/src/Eigenvalues/ComplexSchur_LAPACKE.h +91 -0
  202. data/vendor/eigen/Eigen/src/Eigenvalues/EigenSolver.h +622 -0
  203. data/vendor/eigen/Eigen/src/Eigenvalues/GeneralizedEigenSolver.h +418 -0
  204. data/vendor/eigen/Eigen/src/Eigenvalues/GeneralizedSelfAdjointEigenSolver.h +226 -0
  205. data/vendor/eigen/Eigen/src/Eigenvalues/HessenbergDecomposition.h +374 -0
  206. data/vendor/eigen/Eigen/src/Eigenvalues/MatrixBaseEigenvalues.h +158 -0
  207. data/vendor/eigen/Eigen/src/Eigenvalues/RealQZ.h +654 -0
  208. data/vendor/eigen/Eigen/src/Eigenvalues/RealSchur.h +546 -0
  209. data/vendor/eigen/Eigen/src/Eigenvalues/RealSchur_LAPACKE.h +77 -0
  210. data/vendor/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver.h +870 -0
  211. data/vendor/eigen/Eigen/src/Eigenvalues/SelfAdjointEigenSolver_LAPACKE.h +87 -0
  212. data/vendor/eigen/Eigen/src/Eigenvalues/Tridiagonalization.h +556 -0
  213. data/vendor/eigen/Eigen/src/Geometry/AlignedBox.h +392 -0
  214. data/vendor/eigen/Eigen/src/Geometry/AngleAxis.h +247 -0
  215. data/vendor/eigen/Eigen/src/Geometry/EulerAngles.h +114 -0
  216. data/vendor/eigen/Eigen/src/Geometry/Homogeneous.h +497 -0
  217. data/vendor/eigen/Eigen/src/Geometry/Hyperplane.h +282 -0
  218. data/vendor/eigen/Eigen/src/Geometry/OrthoMethods.h +234 -0
  219. data/vendor/eigen/Eigen/src/Geometry/ParametrizedLine.h +195 -0
  220. data/vendor/eigen/Eigen/src/Geometry/Quaternion.h +814 -0
  221. data/vendor/eigen/Eigen/src/Geometry/Rotation2D.h +199 -0
  222. data/vendor/eigen/Eigen/src/Geometry/RotationBase.h +206 -0
  223. data/vendor/eigen/Eigen/src/Geometry/Scaling.h +170 -0
  224. data/vendor/eigen/Eigen/src/Geometry/Transform.h +1542 -0
  225. data/vendor/eigen/Eigen/src/Geometry/Translation.h +208 -0
  226. data/vendor/eigen/Eigen/src/Geometry/Umeyama.h +166 -0
  227. data/vendor/eigen/Eigen/src/Geometry/arch/Geometry_SSE.h +161 -0
  228. data/vendor/eigen/Eigen/src/Householder/BlockHouseholder.h +103 -0
  229. data/vendor/eigen/Eigen/src/Householder/Householder.h +172 -0
  230. data/vendor/eigen/Eigen/src/Householder/HouseholderSequence.h +470 -0
  231. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/BasicPreconditioners.h +226 -0
  232. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/BiCGSTAB.h +228 -0
  233. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/ConjugateGradient.h +246 -0
  234. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/IncompleteCholesky.h +400 -0
  235. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/IncompleteLUT.h +462 -0
  236. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h +394 -0
  237. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/LeastSquareConjugateGradient.h +216 -0
  238. data/vendor/eigen/Eigen/src/IterativeLinearSolvers/SolveWithGuess.h +115 -0
  239. data/vendor/eigen/Eigen/src/Jacobi/Jacobi.h +462 -0
  240. data/vendor/eigen/Eigen/src/LU/Determinant.h +101 -0
  241. data/vendor/eigen/Eigen/src/LU/FullPivLU.h +891 -0
  242. data/vendor/eigen/Eigen/src/LU/InverseImpl.h +415 -0
  243. data/vendor/eigen/Eigen/src/LU/PartialPivLU.h +611 -0
  244. data/vendor/eigen/Eigen/src/LU/PartialPivLU_LAPACKE.h +83 -0
  245. data/vendor/eigen/Eigen/src/LU/arch/Inverse_SSE.h +338 -0
  246. data/vendor/eigen/Eigen/src/MetisSupport/MetisSupport.h +137 -0
  247. data/vendor/eigen/Eigen/src/OrderingMethods/Amd.h +445 -0
  248. data/vendor/eigen/Eigen/src/OrderingMethods/Eigen_Colamd.h +1843 -0
  249. data/vendor/eigen/Eigen/src/OrderingMethods/Ordering.h +157 -0
  250. data/vendor/eigen/Eigen/src/PaStiXSupport/PaStiXSupport.h +678 -0
  251. data/vendor/eigen/Eigen/src/PardisoSupport/PardisoSupport.h +543 -0
  252. data/vendor/eigen/Eigen/src/QR/ColPivHouseholderQR.h +653 -0
  253. data/vendor/eigen/Eigen/src/QR/ColPivHouseholderQR_LAPACKE.h +97 -0
  254. data/vendor/eigen/Eigen/src/QR/CompleteOrthogonalDecomposition.h +562 -0
  255. data/vendor/eigen/Eigen/src/QR/FullPivHouseholderQR.h +676 -0
  256. data/vendor/eigen/Eigen/src/QR/HouseholderQR.h +409 -0
  257. data/vendor/eigen/Eigen/src/QR/HouseholderQR_LAPACKE.h +68 -0
  258. data/vendor/eigen/Eigen/src/SPQRSupport/SuiteSparseQRSupport.h +313 -0
  259. data/vendor/eigen/Eigen/src/SVD/BDCSVD.h +1246 -0
  260. data/vendor/eigen/Eigen/src/SVD/JacobiSVD.h +804 -0
  261. data/vendor/eigen/Eigen/src/SVD/JacobiSVD_LAPACKE.h +91 -0
  262. data/vendor/eigen/Eigen/src/SVD/SVDBase.h +315 -0
  263. data/vendor/eigen/Eigen/src/SVD/UpperBidiagonalization.h +414 -0
  264. data/vendor/eigen/Eigen/src/SparseCholesky/SimplicialCholesky.h +689 -0
  265. data/vendor/eigen/Eigen/src/SparseCholesky/SimplicialCholesky_impl.h +199 -0
  266. data/vendor/eigen/Eigen/src/SparseCore/AmbiVector.h +377 -0
  267. data/vendor/eigen/Eigen/src/SparseCore/CompressedStorage.h +258 -0
  268. data/vendor/eigen/Eigen/src/SparseCore/ConservativeSparseSparseProduct.h +352 -0
  269. data/vendor/eigen/Eigen/src/SparseCore/MappedSparseMatrix.h +67 -0
  270. data/vendor/eigen/Eigen/src/SparseCore/SparseAssign.h +216 -0
  271. data/vendor/eigen/Eigen/src/SparseCore/SparseBlock.h +603 -0
  272. data/vendor/eigen/Eigen/src/SparseCore/SparseColEtree.h +206 -0
  273. data/vendor/eigen/Eigen/src/SparseCore/SparseCompressedBase.h +341 -0
  274. data/vendor/eigen/Eigen/src/SparseCore/SparseCwiseBinaryOp.h +726 -0
  275. data/vendor/eigen/Eigen/src/SparseCore/SparseCwiseUnaryOp.h +148 -0
  276. data/vendor/eigen/Eigen/src/SparseCore/SparseDenseProduct.h +320 -0
  277. data/vendor/eigen/Eigen/src/SparseCore/SparseDiagonalProduct.h +138 -0
  278. data/vendor/eigen/Eigen/src/SparseCore/SparseDot.h +98 -0
  279. data/vendor/eigen/Eigen/src/SparseCore/SparseFuzzy.h +29 -0
  280. data/vendor/eigen/Eigen/src/SparseCore/SparseMap.h +305 -0
  281. data/vendor/eigen/Eigen/src/SparseCore/SparseMatrix.h +1403 -0
  282. data/vendor/eigen/Eigen/src/SparseCore/SparseMatrixBase.h +405 -0
  283. data/vendor/eigen/Eigen/src/SparseCore/SparsePermutation.h +178 -0
  284. data/vendor/eigen/Eigen/src/SparseCore/SparseProduct.h +169 -0
  285. data/vendor/eigen/Eigen/src/SparseCore/SparseRedux.h +49 -0
  286. data/vendor/eigen/Eigen/src/SparseCore/SparseRef.h +397 -0
  287. data/vendor/eigen/Eigen/src/SparseCore/SparseSelfAdjointView.h +656 -0
  288. data/vendor/eigen/Eigen/src/SparseCore/SparseSolverBase.h +124 -0
  289. data/vendor/eigen/Eigen/src/SparseCore/SparseSparseProductWithPruning.h +198 -0
  290. data/vendor/eigen/Eigen/src/SparseCore/SparseTranspose.h +92 -0
  291. data/vendor/eigen/Eigen/src/SparseCore/SparseTriangularView.h +189 -0
  292. data/vendor/eigen/Eigen/src/SparseCore/SparseUtil.h +178 -0
  293. data/vendor/eigen/Eigen/src/SparseCore/SparseVector.h +478 -0
  294. data/vendor/eigen/Eigen/src/SparseCore/SparseView.h +253 -0
  295. data/vendor/eigen/Eigen/src/SparseCore/TriangularSolver.h +315 -0
  296. data/vendor/eigen/Eigen/src/SparseLU/SparseLU.h +773 -0
  297. data/vendor/eigen/Eigen/src/SparseLU/SparseLUImpl.h +66 -0
  298. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_Memory.h +226 -0
  299. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_Structs.h +110 -0
  300. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_SupernodalMatrix.h +301 -0
  301. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_Utils.h +80 -0
  302. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_column_bmod.h +181 -0
  303. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_column_dfs.h +179 -0
  304. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_copy_to_ucol.h +107 -0
  305. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_gemm_kernel.h +280 -0
  306. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_heap_relax_snode.h +126 -0
  307. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_kernel_bmod.h +130 -0
  308. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_panel_bmod.h +223 -0
  309. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_panel_dfs.h +258 -0
  310. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_pivotL.h +137 -0
  311. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_pruneL.h +136 -0
  312. data/vendor/eigen/Eigen/src/SparseLU/SparseLU_relax_snode.h +83 -0
  313. data/vendor/eigen/Eigen/src/SparseQR/SparseQR.h +745 -0
  314. data/vendor/eigen/Eigen/src/StlSupport/StdDeque.h +126 -0
  315. data/vendor/eigen/Eigen/src/StlSupport/StdList.h +106 -0
  316. data/vendor/eigen/Eigen/src/StlSupport/StdVector.h +131 -0
  317. data/vendor/eigen/Eigen/src/StlSupport/details.h +84 -0
  318. data/vendor/eigen/Eigen/src/SuperLUSupport/SuperLUSupport.h +1027 -0
  319. data/vendor/eigen/Eigen/src/UmfPackSupport/UmfPackSupport.h +506 -0
  320. data/vendor/eigen/Eigen/src/misc/Image.h +82 -0
  321. data/vendor/eigen/Eigen/src/misc/Kernel.h +79 -0
  322. data/vendor/eigen/Eigen/src/misc/RealSvd2x2.h +55 -0
  323. data/vendor/eigen/Eigen/src/misc/blas.h +440 -0
  324. data/vendor/eigen/Eigen/src/misc/lapack.h +152 -0
  325. data/vendor/eigen/Eigen/src/misc/lapacke.h +16291 -0
  326. data/vendor/eigen/Eigen/src/misc/lapacke_mangling.h +17 -0
  327. data/vendor/eigen/Eigen/src/plugins/ArrayCwiseBinaryOps.h +332 -0
  328. data/vendor/eigen/Eigen/src/plugins/ArrayCwiseUnaryOps.h +552 -0
  329. data/vendor/eigen/Eigen/src/plugins/BlockMethods.h +1058 -0
  330. data/vendor/eigen/Eigen/src/plugins/CommonCwiseBinaryOps.h +115 -0
  331. data/vendor/eigen/Eigen/src/plugins/CommonCwiseUnaryOps.h +163 -0
  332. data/vendor/eigen/Eigen/src/plugins/MatrixCwiseBinaryOps.h +152 -0
  333. data/vendor/eigen/Eigen/src/plugins/MatrixCwiseUnaryOps.h +85 -0
  334. data/vendor/eigen/README.md +3 -0
  335. data/vendor/eigen/bench/README.txt +55 -0
  336. data/vendor/eigen/bench/btl/COPYING +340 -0
  337. data/vendor/eigen/bench/btl/README +154 -0
  338. data/vendor/eigen/bench/tensors/README +21 -0
  339. data/vendor/eigen/blas/README.txt +6 -0
  340. data/vendor/eigen/demos/mandelbrot/README +10 -0
  341. data/vendor/eigen/demos/mix_eigen_and_c/README +9 -0
  342. data/vendor/eigen/demos/opengl/README +13 -0
  343. data/vendor/eigen/unsupported/Eigen/CXX11/src/Tensor/README.md +1760 -0
  344. data/vendor/eigen/unsupported/README.txt +50 -0
  345. data/vendor/tomotopy/LICENSE +21 -0
  346. data/vendor/tomotopy/README.kr.rst +375 -0
  347. data/vendor/tomotopy/README.rst +382 -0
  348. data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +362 -0
  349. data/vendor/tomotopy/src/Labeling/FoRelevance.h +88 -0
  350. data/vendor/tomotopy/src/Labeling/Labeler.h +50 -0
  351. data/vendor/tomotopy/src/TopicModel/CT.h +37 -0
  352. data/vendor/tomotopy/src/TopicModel/CTModel.cpp +13 -0
  353. data/vendor/tomotopy/src/TopicModel/CTModel.hpp +293 -0
  354. data/vendor/tomotopy/src/TopicModel/DMR.h +51 -0
  355. data/vendor/tomotopy/src/TopicModel/DMRModel.cpp +13 -0
  356. data/vendor/tomotopy/src/TopicModel/DMRModel.hpp +374 -0
  357. data/vendor/tomotopy/src/TopicModel/DT.h +65 -0
  358. data/vendor/tomotopy/src/TopicModel/DTM.h +22 -0
  359. data/vendor/tomotopy/src/TopicModel/DTModel.cpp +15 -0
  360. data/vendor/tomotopy/src/TopicModel/DTModel.hpp +572 -0
  361. data/vendor/tomotopy/src/TopicModel/GDMR.h +37 -0
  362. data/vendor/tomotopy/src/TopicModel/GDMRModel.cpp +14 -0
  363. data/vendor/tomotopy/src/TopicModel/GDMRModel.hpp +485 -0
  364. data/vendor/tomotopy/src/TopicModel/HDP.h +74 -0
  365. data/vendor/tomotopy/src/TopicModel/HDPModel.cpp +13 -0
  366. data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +592 -0
  367. data/vendor/tomotopy/src/TopicModel/HLDA.h +40 -0
  368. data/vendor/tomotopy/src/TopicModel/HLDAModel.cpp +13 -0
  369. data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +681 -0
  370. data/vendor/tomotopy/src/TopicModel/HPA.h +27 -0
  371. data/vendor/tomotopy/src/TopicModel/HPAModel.cpp +21 -0
  372. data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +588 -0
  373. data/vendor/tomotopy/src/TopicModel/LDA.h +144 -0
  374. data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +442 -0
  375. data/vendor/tomotopy/src/TopicModel/LDAModel.cpp +13 -0
  376. data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +1058 -0
  377. data/vendor/tomotopy/src/TopicModel/LLDA.h +45 -0
  378. data/vendor/tomotopy/src/TopicModel/LLDAModel.cpp +13 -0
  379. data/vendor/tomotopy/src/TopicModel/LLDAModel.hpp +203 -0
  380. data/vendor/tomotopy/src/TopicModel/MGLDA.h +63 -0
  381. data/vendor/tomotopy/src/TopicModel/MGLDAModel.cpp +17 -0
  382. data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +558 -0
  383. data/vendor/tomotopy/src/TopicModel/PA.h +43 -0
  384. data/vendor/tomotopy/src/TopicModel/PAModel.cpp +13 -0
  385. data/vendor/tomotopy/src/TopicModel/PAModel.hpp +467 -0
  386. data/vendor/tomotopy/src/TopicModel/PLDA.h +17 -0
  387. data/vendor/tomotopy/src/TopicModel/PLDAModel.cpp +13 -0
  388. data/vendor/tomotopy/src/TopicModel/PLDAModel.hpp +214 -0
  389. data/vendor/tomotopy/src/TopicModel/SLDA.h +54 -0
  390. data/vendor/tomotopy/src/TopicModel/SLDAModel.cpp +17 -0
  391. data/vendor/tomotopy/src/TopicModel/SLDAModel.hpp +456 -0
  392. data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +692 -0
  393. data/vendor/tomotopy/src/Utils/AliasMethod.hpp +169 -0
  394. data/vendor/tomotopy/src/Utils/Dictionary.h +80 -0
  395. data/vendor/tomotopy/src/Utils/EigenAddonOps.hpp +181 -0
  396. data/vendor/tomotopy/src/Utils/LBFGS.h +202 -0
  397. data/vendor/tomotopy/src/Utils/LBFGS/LineSearchBacktracking.h +120 -0
  398. data/vendor/tomotopy/src/Utils/LBFGS/LineSearchBracketing.h +122 -0
  399. data/vendor/tomotopy/src/Utils/LBFGS/Param.h +213 -0
  400. data/vendor/tomotopy/src/Utils/LUT.hpp +82 -0
  401. data/vendor/tomotopy/src/Utils/MultiNormalDistribution.hpp +69 -0
  402. data/vendor/tomotopy/src/Utils/PolyaGamma.hpp +200 -0
  403. data/vendor/tomotopy/src/Utils/PolyaGammaHybrid.hpp +672 -0
  404. data/vendor/tomotopy/src/Utils/ThreadPool.hpp +150 -0
  405. data/vendor/tomotopy/src/Utils/Trie.hpp +220 -0
  406. data/vendor/tomotopy/src/Utils/TruncMultiNormal.hpp +94 -0
  407. data/vendor/tomotopy/src/Utils/Utils.hpp +337 -0
  408. data/vendor/tomotopy/src/Utils/avx_gamma.h +46 -0
  409. data/vendor/tomotopy/src/Utils/avx_mathfun.h +736 -0
  410. data/vendor/tomotopy/src/Utils/exception.h +28 -0
  411. data/vendor/tomotopy/src/Utils/math.h +281 -0
  412. data/vendor/tomotopy/src/Utils/rtnorm.hpp +2690 -0
  413. data/vendor/tomotopy/src/Utils/sample.hpp +192 -0
  414. data/vendor/tomotopy/src/Utils/serializer.hpp +695 -0
  415. data/vendor/tomotopy/src/Utils/slp.hpp +131 -0
  416. data/vendor/tomotopy/src/Utils/sse_gamma.h +48 -0
  417. data/vendor/tomotopy/src/Utils/sse_mathfun.h +710 -0
  418. data/vendor/tomotopy/src/Utils/text.hpp +49 -0
  419. data/vendor/tomotopy/src/Utils/tvector.hpp +543 -0
  420. metadata +531 -0
@@ -0,0 +1,17 @@
1
+ #pragma once
2
+ #include "LLDA.h"
3
+
4
+ namespace tomoto
5
+ {
6
+
7
+ class IPLDAModel : public ILLDAModel
8
+ {
9
+ public:
10
+ using DefaultDocType = DocumentLLDA<TermWeight::one>;
11
+ static IPLDAModel* create(TermWeight _weight, size_t _numLatentTopics = 0, size_t _numTopicsPerLabel = 1,
12
+ Float alpha = 0.1, Float eta = 0.01, size_t seed = std::random_device{}(),
13
+ bool scalarRng = false);
14
+
15
+ virtual size_t getNumLatentTopics() const = 0;
16
+ };
17
+ }
@@ -0,0 +1,13 @@
1
+ #include "PLDAModel.hpp"
2
+
3
+ namespace tomoto
4
+ {
5
+ /*template class PLDAModel<TermWeight::one>;
6
+ template class PLDAModel<TermWeight::idf>;
7
+ template class PLDAModel<TermWeight::pmi>;*/
8
+
9
+ IPLDAModel* IPLDAModel::create(TermWeight _weight, size_t _numLatentTopics, size_t _numTopicsPerLabel, Float _alpha, Float _eta, size_t seed, bool scalarRng)
10
+ {
11
+ TMT_SWITCH_TW(_weight, scalarRng, PLDAModel, _numLatentTopics, _numTopicsPerLabel, _alpha, _eta, seed);
12
+ }
13
+ }
@@ -0,0 +1,214 @@
1
+ #pragma once
2
+ #include "LDAModel.hpp"
3
+ #include "PLDA.h"
4
+
5
+ /*
6
+ Implementation of Partially Labeled LDA using Gibbs sampling by bab2min
7
+
8
+ * Ramage, D., Manning, C. D., & Dumais, S. (2011, August). Partially labeled topic models for interpretable text mining. In Proceedings of the 17th ACM SIGKDD international conference on Knowledge discovery and data mining (pp. 457-465). ACM.
9
+ */
10
+
11
+ namespace tomoto
12
+ {
13
+ template<TermWeight _tw, typename _RandGen,
14
+ typename _Interface = IPLDAModel,
15
+ typename _Derived = void,
16
+ typename _DocType = DocumentLLDA<_tw>,
17
+ typename _ModelState = ModelStateLDA<_tw>>
18
+ class PLDAModel : public LDAModel<_tw, _RandGen, flags::generator_by_doc | flags::partitioned_multisampling, _Interface,
19
+ typename std::conditional<std::is_same<_Derived, void>::value, PLDAModel<_tw, _RandGen>, _Derived>::type,
20
+ _DocType, _ModelState>
21
+ {
22
+ protected:
23
+ using DerivedClass = typename std::conditional<std::is_same<_Derived, void>::value, PLDAModel<_tw, _RandGen>, _Derived>::type;
24
+ using BaseClass = LDAModel<_tw, _RandGen, flags::generator_by_doc | flags::partitioned_multisampling, _Interface, DerivedClass, _DocType, _ModelState>;
25
+ friend BaseClass;
26
+ friend typename BaseClass::BaseClass;
27
+ using WeightType = typename BaseClass::WeightType;
28
+
29
+ static constexpr char TMID[] = "PLDA";
30
+
31
+ Dictionary topicLabelDict;
32
+
33
+ uint64_t numLatentTopics, numTopicsPerLabel;
34
+
35
+ template<bool _asymEta>
36
+ Float* getZLikelihoods(_ModelState& ld, const _DocType& doc, size_t docId, size_t vid) const
37
+ {
38
+ const size_t V = this->realV;
39
+ assert(vid < V);
40
+ auto etaHelper = this->template getEtaHelper<_asymEta>();
41
+ auto& zLikelihood = ld.zLikelihood;
42
+ zLikelihood = (doc.numByTopic.array().template cast<Float>() + this->alphas.array())
43
+ * (ld.numByTopicWord.col(vid).array().template cast<Float>() + etaHelper.getEta(vid))
44
+ / (ld.numByTopic.array().template cast<Float>() + etaHelper.getEtaSum());
45
+ zLikelihood.array() *= doc.labelMask.array().template cast<Float>();
46
+ sample::prefixSum(zLikelihood.data(), this->K);
47
+ return &zLikelihood[0];
48
+ }
49
+
50
+ void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
51
+ {
52
+ BaseClass::prepareDoc(doc, docId, wordSize);
53
+ if (doc.labelMask.size() == 0)
54
+ {
55
+ doc.labelMask.resize(this->K);
56
+ doc.labelMask.setZero();
57
+ doc.labelMask.tail(numLatentTopics).setOnes();
58
+ }
59
+ else if (doc.labelMask.size() < this->K)
60
+ {
61
+ size_t oldSize = doc.labelMask.size();
62
+ doc.labelMask.conservativeResize(this->K);
63
+ doc.labelMask.tail(this->K - oldSize).setZero();
64
+ doc.labelMask.tail(numLatentTopics).setOnes();
65
+ }
66
+ }
67
+
68
+ void initGlobalState(bool initDocs)
69
+ {
70
+ this->K = topicLabelDict.size() * numTopicsPerLabel + numLatentTopics;
71
+ this->alphas.resize(this->K);
72
+ this->alphas.array() = this->alpha;
73
+ BaseClass::initGlobalState(initDocs);
74
+ }
75
+
76
+ struct Generator
77
+ {
78
+ std::discrete_distribution<> theta;
79
+ };
80
+
81
+ Generator makeGeneratorForInit(const _DocType* doc) const
82
+ {
83
+ return Generator{
84
+ std::discrete_distribution<>{ doc->labelMask.data(), doc->labelMask.data() + doc->labelMask.size() }
85
+ };
86
+ }
87
+
88
+ template<bool _Infer>
89
+ void updateStateWithDoc(Generator& g, _ModelState& ld, _RandGen& rgs, _DocType& doc, size_t i) const
90
+ {
91
+ auto& z = doc.Zs[i];
92
+ auto w = doc.words[i];
93
+ if (this->etaByTopicWord.size())
94
+ {
95
+ Eigen::Array<Float, -1, 1> col = this->etaByTopicWord.col(w);
96
+ for (size_t k = 0; k < col.size(); ++k) col[k] *= g.theta.probabilities()[k];
97
+ z = sample::sampleFromDiscrete(col.data(), col.data() + col.size(), rgs);
98
+ }
99
+ else
100
+ {
101
+ z = g.theta(rgs);
102
+ }
103
+ this->template addWordTo<1>(ld, doc, i, w, z);
104
+ }
105
+
106
+ public:
107
+ DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, topicLabelDict, numLatentTopics, numTopicsPerLabel);
108
+ DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, topicLabelDict, numLatentTopics, numTopicsPerLabel);
109
+
110
+ PLDAModel(size_t _numLatentTopics = 0, size_t _numTopicsPerLabel = 1,
111
+ Float _alpha = 1.0, Float _eta = 0.01, size_t _rg = std::random_device{}())
112
+ : BaseClass(1, _alpha, _eta, _rg),
113
+ numLatentTopics(_numLatentTopics), numTopicsPerLabel(_numTopicsPerLabel)
114
+ {
115
+ if (_numLatentTopics >= 0x80000000)
116
+ THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong numLatentTopics value (numLatentTopics = %zd)", _numLatentTopics));
117
+ if (_numTopicsPerLabel == 0 || _numTopicsPerLabel >= 0x80000000)
118
+ THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong numTopicsPerLabel value (numTopicsPerLabel = %zd)", _numTopicsPerLabel));
119
+ }
120
+
121
+ template<bool _const = false>
122
+ _DocType& _updateDoc(_DocType& doc, const std::vector<std::string>& labels)
123
+ {
124
+ if (_const)
125
+ {
126
+ doc.labelMask.resize(this->K);
127
+ doc.labelMask.setZero();
128
+ doc.labelMask.tail(numLatentTopics).setOnes();
129
+
130
+ std::vector<Vid> topicLabelIds;
131
+ for (auto& label : labels)
132
+ {
133
+ auto tid = topicLabelDict.toWid(label);
134
+ if (tid == (Vid)-1) continue;
135
+ topicLabelIds.emplace_back(tid);
136
+ }
137
+
138
+ for (auto tid : topicLabelIds) doc.labelMask.segment(tid * numTopicsPerLabel, numTopicsPerLabel).setOnes();
139
+ if (labels.empty()) doc.labelMask.setOnes();
140
+ }
141
+ else
142
+ {
143
+ if (!labels.empty())
144
+ {
145
+ std::vector<Vid> topicLabelIds;
146
+ for (auto& label : labels) topicLabelIds.emplace_back(topicLabelDict.add(label));
147
+ auto maxVal = *std::max_element(topicLabelIds.begin(), topicLabelIds.end());
148
+ doc.labelMask.resize((maxVal + 1) * numTopicsPerLabel);
149
+ doc.labelMask.setZero();
150
+ for (auto i : topicLabelIds) doc.labelMask.segment(i * numTopicsPerLabel, numTopicsPerLabel).setOnes();
151
+ }
152
+ }
153
+ return doc;
154
+ }
155
+
156
+ size_t addDoc(const std::vector<std::string>& words, const std::vector<std::string>& labels) override
157
+ {
158
+ auto doc = this->_makeDoc(words);
159
+ return this->_addDoc(_updateDoc(doc, labels));
160
+ }
161
+
162
+ std::unique_ptr<DocumentBase> makeDoc(const std::vector<std::string>& words, const std::vector<std::string>& labels) const override
163
+ {
164
+ auto doc = as_mutable(this)->template _makeDoc<true>(words);
165
+ return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, labels));
166
+ }
167
+
168
+ size_t addDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
169
+ const std::vector<std::string>& labels) override
170
+ {
171
+ auto doc = this->template _makeRawDoc<false>(rawStr, tokenizer);
172
+ return this->_addDoc(_updateDoc(doc, labels));
173
+ }
174
+
175
+ std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
176
+ const std::vector<std::string>& labels) const override
177
+ {
178
+ auto doc = as_mutable(this)->template _makeRawDoc<true>(rawStr, tokenizer);
179
+ return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, labels));
180
+ }
181
+
182
+ size_t addDoc(const std::string& rawStr, const std::vector<Vid>& words,
183
+ const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
184
+ const std::vector<std::string>& labels) override
185
+ {
186
+ auto doc = this->_makeRawDoc(rawStr, words, pos, len);
187
+ return this->_addDoc(_updateDoc(doc, labels));
188
+ }
189
+
190
+ std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const std::vector<Vid>& words,
191
+ const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
192
+ const std::vector<std::string>& labels) const override
193
+ {
194
+ auto doc = this->_makeRawDoc(rawStr, words, pos, len);
195
+ return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, labels));
196
+ }
197
+
198
+ std::vector<Float> getTopicsByDoc(const _DocType& doc) const
199
+ {
200
+ std::vector<Float> ret(this->K);
201
+ auto maskedAlphas = this->alphas.array() * doc.labelMask.template cast<Float>().array();
202
+ Eigen::Map<Eigen::Matrix<Float, -1, 1>> { ret.data(), this->K }.array() =
203
+ (doc.numByTopic.array().template cast<Float>() + maskedAlphas)
204
+ / (doc.getSumWordWeight() + maskedAlphas.sum());
205
+ return ret;
206
+ }
207
+
208
+ const Dictionary& getTopicLabelDict() const override { return topicLabelDict; }
209
+
210
+ size_t getNumLatentTopics() const override { return numLatentTopics; }
211
+
212
+ size_t getNumTopicsPerLabel() const override { return numTopicsPerLabel; }
213
+ };
214
+ }
@@ -0,0 +1,54 @@
1
+ #pragma once
2
+ #include "LDA.h"
3
+
4
+ namespace tomoto
5
+ {
6
+ template<TermWeight _tw>
7
+ struct DocumentSLDA : public DocumentLDA<_tw>
8
+ {
9
+ using BaseDocument = DocumentLDA<_tw>;
10
+ using DocumentLDA<_tw>::DocumentLDA;
11
+ std::vector<Float> y;
12
+ DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 0, y);
13
+ DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 1, 0x00010001, y);
14
+ };
15
+
16
+ class ISLDAModel : public ILDAModel
17
+ {
18
+ public:
19
+ enum class GLM
20
+ {
21
+ linear = 0,
22
+ binary_logistic = 1,
23
+ };
24
+
25
+ using DefaultDocType = DocumentSLDA<TermWeight::one>;
26
+ static ISLDAModel* create(TermWeight _weight, size_t _K = 1,
27
+ const std::vector<ISLDAModel::GLM>& vars = {},
28
+ Float alpha = 0.1, Float _eta = 0.01,
29
+ const std::vector<Float>& _mu = {}, const std::vector<Float>& _nuSq = {},
30
+ const std::vector<Float>& _glmParam = {},
31
+ size_t seed = std::random_device{}(),
32
+ bool scalarRng = false);
33
+
34
+ virtual size_t addDoc(const std::vector<std::string>& words, const std::vector<Float>& y) = 0;
35
+ virtual std::unique_ptr<DocumentBase> makeDoc(const std::vector<std::string>& words, const std::vector<Float>& y) const = 0;
36
+
37
+ virtual size_t addDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
38
+ const std::vector<Float>& y) = 0;
39
+ virtual std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
40
+ const std::vector<Float>& y) const = 0;
41
+
42
+ virtual size_t addDoc(const std::string& rawStr, const std::vector<Vid>& words,
43
+ const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
44
+ const std::vector<Float>& y) = 0;
45
+ virtual std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const std::vector<Vid>& words,
46
+ const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
47
+ const std::vector<Float>& y) const = 0;
48
+
49
+ virtual size_t getF() const = 0;
50
+ virtual std::vector<Float> getRegressionCoef(size_t f) const = 0;
51
+ virtual GLM getTypeOfVar(size_t f) const = 0;
52
+ virtual std::vector<Float> estimateVars(const DocumentBase* doc) const = 0;
53
+ };
54
+ }
@@ -0,0 +1,17 @@
1
+ #include "SLDAModel.hpp"
2
+
3
+ namespace tomoto
4
+ {
5
+ /*template class SLDAModel<TermWeight::one>;
6
+ template class SLDAModel<TermWeight::idf>;
7
+ template class SLDAModel<TermWeight::pmi>;*/
8
+
9
+ ISLDAModel* ISLDAModel::create(TermWeight _weight, size_t _K, const std::vector<ISLDAModel::GLM>& vars,
10
+ Float _alpha, Float _eta,
11
+ const std::vector<Float>& _mu, const std::vector<Float>& _nuSq,
12
+ const std::vector<Float>& _glmParam,
13
+ size_t seed, bool scalarRng)
14
+ {
15
+ TMT_SWITCH_TW(_weight, scalarRng, SLDAModel, _K, vars, _alpha, _eta, _mu, _nuSq, _glmParam, seed);
16
+ }
17
+ }
@@ -0,0 +1,456 @@
1
+ #pragma once
2
+ #include "LDAModel.hpp"
3
+ #include "../Utils/PolyaGamma.hpp"
4
+ #include "SLDA.h"
5
+
6
+ /*
7
+ Implementation of sLDA using Gibbs sampling by bab2min
8
+ * McAuliffe, J. D., & Blei, D. M. (2008). Supervised topic models. In Advances in Neural Information Processing Systems (pp. 121-128).
9
+ * Python version implementation using Gibbs sampling : https://github.com/Savvysherpa/slda
10
+ */
11
+
12
+ namespace tomoto
13
+ {
14
+ namespace detail
15
+ {
16
+ template<typename _WeightType>
17
+ struct GLMFunctor
18
+ {
19
+ Eigen::Matrix<Float, -1, 1> regressionCoef; // Dim : (K)
20
+
21
+ GLMFunctor(size_t K = 0, Float mu = 0) : regressionCoef(Eigen::Matrix<Float, -1, 1>::Constant(K, mu))
22
+ {
23
+ }
24
+
25
+ virtual ISLDAModel::GLM getType() const = 0;
26
+
27
+ virtual void updateZLL(
28
+ Eigen::Matrix<Float, -1, 1>& zLikelihood,
29
+ Float y, const Eigen::Matrix<_WeightType, -1, 1>& numByTopic, size_t docId, Float docSize) const = 0;
30
+
31
+ virtual void optimizeCoef(
32
+ const Eigen::Matrix<Float, -1, -1>& normZ,
33
+ Float mu, Float nuSq,
34
+ Eigen::Block<Eigen::Matrix<Float, -1, -1>, -1, 1, true> ys
35
+ ) = 0;
36
+
37
+ virtual double getLL(Float y, const Eigen::Matrix<_WeightType, -1, 1>& numByTopic,
38
+ Float docSize) const = 0;
39
+
40
+ virtual Float estimate(const Eigen::Matrix<_WeightType, -1, 1>& numByTopic,
41
+ Float docSize) const = 0;
42
+
43
+ virtual ~GLMFunctor() {};
44
+
45
+ DEFINE_SERIALIZER_VIRTUAL(regressionCoef);
46
+
47
+ static void serializerWrite(const std::unique_ptr<GLMFunctor>& p, std::ostream& ostr)
48
+ {
49
+ if (!p) serializer::writeToStream<uint32_t>(ostr, 0);
50
+ else
51
+ {
52
+ serializer::writeToStream<uint32_t>(ostr, (uint32_t)p->getType() + 1);
53
+ p->serializerWrite(ostr);
54
+ }
55
+ }
56
+
57
+ static void serializerRead(std::unique_ptr<GLMFunctor>& p, std::istream& istr);
58
+ };
59
+
60
+ template<typename _WeightType>
61
+ struct LinearFunctor : public GLMFunctor<_WeightType>
62
+ {
63
+ Float sigmaSq = 1;
64
+
65
+ LinearFunctor(size_t K = 0, Float mu = 0, Float _sigmaSq = 1)
66
+ : GLMFunctor<_WeightType>(K, mu), sigmaSq(_sigmaSq)
67
+ {
68
+ }
69
+
70
+ ISLDAModel::GLM getType() const override { return ISLDAModel::GLM::linear; }
71
+
72
+ void updateZLL(
73
+ Eigen::Matrix<Float, -1, 1>& zLikelihood,
74
+ Float y, const Eigen::Matrix<_WeightType, -1, 1>& numByTopic, size_t docId, Float docSize) const override
75
+ {
76
+ Float yErr = y -
77
+ (this->regressionCoef.array() * numByTopic.array().template cast<Float>()).sum()
78
+ / docSize;
79
+ zLikelihood.array() *= (this->regressionCoef.array() / docSize / 2 / sigmaSq *
80
+ (2 * yErr - this->regressionCoef.array() / docSize)).exp();
81
+ }
82
+
83
+ void optimizeCoef(
84
+ const Eigen::Matrix<Float, -1, -1>& normZ,
85
+ Float mu, Float nuSq,
86
+ Eigen::Block<Eigen::Matrix<Float, -1, -1>, -1, 1, true> ys
87
+ ) override
88
+ {
89
+ Eigen::Matrix<Float, -1, -1> selectedNormZ = normZ.array().rowwise() * (!ys.array().transpose().isNaN()).template cast<Float>();
90
+ Eigen::Matrix<Float, -1, -1> normZZT = selectedNormZ * selectedNormZ.transpose();
91
+ normZZT += Eigen::Matrix<Float, -1, -1>::Identity(normZZT.cols(), normZZT.cols()) / nuSq;
92
+ this->regressionCoef = normZZT.colPivHouseholderQr().solve(selectedNormZ * ys.array().isNaN().select(0, ys).matrix());
93
+ }
94
+
95
+ double getLL(Float y, const Eigen::Matrix<_WeightType, -1, 1>& numByTopic,
96
+ Float docSize) const override
97
+ {
98
+ Float estimatedY = estimate(numByTopic, docSize);
99
+ return -pow(estimatedY - y, 2) / 2 / sigmaSq;
100
+ }
101
+
102
+ Float estimate(const Eigen::Matrix<_WeightType, -1, 1>& numByTopic,
103
+ Float docSize) const override
104
+ {
105
+ return (this->regressionCoef.array() * numByTopic.array().template cast<Float>()).sum()
106
+ / std::max(docSize, 0.01f);
107
+ }
108
+
109
+ DEFINE_SERIALIZER_AFTER_BASE(GLMFunctor<_WeightType>, sigmaSq);
110
+ };
111
+
112
+ template<typename _WeightType>
113
+ struct BinaryLogisticFunctor : public GLMFunctor<_WeightType>
114
+ {
115
+ Float b = 1;
116
+ Eigen::Matrix<Float, -1, 1> omega;
117
+
118
+ BinaryLogisticFunctor(size_t K = 0, Float mu = 0, Float _b = 1, size_t numDocs = 0)
119
+ : GLMFunctor<_WeightType>(K, mu), b(_b), omega{ Eigen::Matrix<Float, -1, 1>::Ones(numDocs) }
120
+ {
121
+ }
122
+
123
+ ISLDAModel::GLM getType() const override { return ISLDAModel::GLM::binary_logistic; }
124
+
125
+ void updateZLL(
126
+ Eigen::Matrix<Float, -1, 1>& zLikelihood,
127
+ Float y, const Eigen::Matrix<_WeightType, -1, 1>& numByTopic, size_t docId, Float docSize) const override
128
+ {
129
+ Float yErr = b * (y - 0.5f) -
130
+ (this->regressionCoef.array() * numByTopic.array().template cast<Float>()).sum()
131
+ / docSize * omega[docId];
132
+ zLikelihood.array() *= (this->regressionCoef.array() / docSize *
133
+ (yErr - omega[docId] / 2 * this->regressionCoef.array() / docSize)).exp();
134
+ }
135
+
136
+ void optimizeCoef(
137
+ const Eigen::Matrix<Float, -1, -1>& normZ,
138
+ Float mu, Float nuSq,
139
+ Eigen::Block<Eigen::Matrix<Float, -1, -1>, -1, 1, true> ys
140
+ ) override
141
+ {
142
+ Eigen::Matrix<Float, -1, -1> selectedNormZ = normZ.array().rowwise() * (!ys.array().transpose().isNaN()).template cast<Float>();
143
+ Eigen::Matrix<Float, -1, -1> normZZT = selectedNormZ * Eigen::DiagonalMatrix<Float, -1>{ omega } * selectedNormZ.transpose();
144
+ normZZT += Eigen::Matrix<Float, -1, -1>::Identity(normZZT.cols(), normZZT.cols()) / nuSq;
145
+
146
+ this->regressionCoef = normZZT
147
+ .colPivHouseholderQr().solve(selectedNormZ * ys.array().isNaN().select(0, b * (ys.array() - 0.5f)).matrix()
148
+ + Eigen::Matrix<Float, -1, 1>::Constant(selectedNormZ.rows(), mu / nuSq));
149
+
150
+ RandGen rng;
151
+ for (size_t i = 0; i < omega.size(); ++i)
152
+ {
153
+ if (std::isnan(ys[i])) continue;
154
+ omega[i] = math::drawPolyaGamma(b, (this->regressionCoef.array() * normZ.col(i).array()).sum(), rng);
155
+ }
156
+ }
157
+
158
+ double getLL(Float y, const Eigen::Matrix<_WeightType, -1, 1>& numByTopic,
159
+ Float docSize) const override
160
+ {
161
+ Float z = (this->regressionCoef.array() * numByTopic.array().template cast<Float>()).sum()
162
+ / std::max(docSize, 0.01f);
163
+ return b * (y * z - log(1 + exp(z)));
164
+ }
165
+
166
+ Float estimate(const Eigen::Matrix<_WeightType, -1, 1>& numByTopic,
167
+ Float docSize) const override
168
+ {
169
+ Float z = (this->regressionCoef.array() * numByTopic.array().template cast<Float>()).sum()
170
+ / std::max(docSize, 0.01f);
171
+ return 1 / (1 + exp(-z));
172
+ }
173
+
174
+ DEFINE_SERIALIZER_AFTER_BASE(GLMFunctor<_WeightType>, b, omega);
175
+ };
176
+ }
177
+
178
+ template<TermWeight _tw, typename _RandGen,
179
+ size_t _Flags = flags::partitioned_multisampling,
180
+ typename _Interface = ISLDAModel,
181
+ typename _Derived = void,
182
+ typename _DocType = DocumentSLDA<_tw>,
183
+ typename _ModelState = ModelStateLDA<_tw>>
184
+ class SLDAModel : public LDAModel<_tw, _RandGen, _Flags, _Interface,
185
+ typename std::conditional<std::is_same<_Derived, void>::value, SLDAModel<_tw, _RandGen, _Flags>, _Derived>::type,
186
+ _DocType, _ModelState>
187
+ {
188
+ protected:
189
+ using DerivedClass = typename std::conditional<std::is_same<_Derived, void>::value, SLDAModel<_tw, _RandGen>, _Derived>::type;
190
+ using BaseClass = LDAModel<_tw, _RandGen, _Flags, _Interface, DerivedClass, _DocType, _ModelState>;
191
+ friend BaseClass;
192
+ friend typename BaseClass::BaseClass;
193
+ using WeightType = typename BaseClass::WeightType;
194
+
195
+ static constexpr char TMID[] = "SLDA";
196
+
197
+ uint64_t F; // number of response variables
198
+ std::vector<ISLDAModel::GLM> varTypes;
199
+ std::vector<Float> glmParam;
200
+
201
+ Eigen::Matrix<Float, -1, 1> mu; // Mean of regression coefficients, Dim : (F)
202
+ Eigen::Matrix<Float, -1, 1> nuSq; // Variance of regression coefficients, Dim : (F)
203
+
204
+ std::vector<std::unique_ptr<detail::GLMFunctor<WeightType>>> responseVars;
205
+ Eigen::Matrix<Float, -1, -1> normZ; // topic proportions for all docs, Dim : (K, D)
206
+ Eigen::Matrix<Float, -1, -1> Ys; // response variables, Dim : (D, F)
207
+
208
+ template<bool _asymEta>
209
+ Float* getZLikelihoods(_ModelState& ld, const _DocType& doc, size_t docId, size_t vid) const
210
+ {
211
+ const size_t V = this->realV;
212
+ assert(vid < V);
213
+ auto etaHelper = this->template getEtaHelper<_asymEta>();
214
+ auto& zLikelihood = ld.zLikelihood;
215
+ zLikelihood = (doc.numByTopic.array().template cast<Float>() + this->alphas.array())
216
+ * (ld.numByTopicWord.col(vid).array().template cast<Float>() + etaHelper.getEta(vid))
217
+ / (ld.numByTopic.array().template cast<Float>() + etaHelper.getEtaSum());
218
+
219
+ for (size_t f = 0; f < F; ++f)
220
+ {
221
+ if (std::isnan(doc.y[f])) continue;
222
+ responseVars[f]->updateZLL(zLikelihood, doc.y[f], doc.numByTopic,
223
+ docId, doc.getSumWordWeight());
224
+ }
225
+ sample::prefixSum(zLikelihood.data(), this->K);
226
+ return &zLikelihood[0];
227
+ }
228
+
229
+ void optimizeRegressionCoef()
230
+ {
231
+ for (size_t i = 0; i < this->docs.size(); ++i)
232
+ {
233
+ normZ.col(i) = this->docs[i].numByTopic.array().template cast<Float>() /
234
+ std::max((Float)this->docs[i].getSumWordWeight(), 0.01f);
235
+ }
236
+
237
+ for (size_t f = 0; f < F; ++f)
238
+ {
239
+ responseVars[f]->optimizeCoef(normZ, mu[f], nuSq[f], Ys.col(f));
240
+ }
241
+ }
242
+
243
+ void optimizeParameters(ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
244
+ {
245
+ BaseClass::optimizeParameters(pool, localData, rgs);
246
+ }
247
+
248
+ void updateGlobalInfo(ThreadPool& pool, _ModelState* localData)
249
+ {
250
+ optimizeRegressionCoef();
251
+ }
252
+
253
+ template<typename _DocIter>
254
+ double getLLDocs(_DocIter _first, _DocIter _last) const
255
+ {
256
+ const auto K = this->K;
257
+
258
+ double ll = 0;
259
+ for (; _first != _last; ++_first)
260
+ {
261
+ auto& doc = *_first;
262
+ ll -= math::lgammaT(doc.getSumWordWeight() + this->alphas.sum()) - math::lgammaT(this->alphas.sum());
263
+ for (size_t f = 0; f < F; ++f)
264
+ {
265
+ if (std::isnan(doc.y[f])) continue;
266
+ ll += responseVars[f]->getLL(doc.y[f], doc.numByTopic, doc.getSumWordWeight());
267
+ }
268
+ for (Tid k = 0; k < K; ++k)
269
+ {
270
+ ll += math::lgammaT(doc.numByTopic[k] + this->alphas[k]) - math::lgammaT(this->alphas[k]);
271
+ }
272
+ }
273
+ return ll;
274
+ }
275
+
276
+ double getLLRest(const _ModelState& ld) const
277
+ {
278
+ double ll = BaseClass::getLLRest(ld);
279
+ for (size_t f = 0; f < F; ++f)
280
+ {
281
+ ll -= (responseVars[f]->regressionCoef.array() - mu[f]).pow(2).sum() / 2 / nuSq[f];
282
+ }
283
+ return ll;
284
+ }
285
+
286
+ void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
287
+ {
288
+ BaseClass::prepareDoc(doc, docId, wordSize);
289
+ }
290
+
291
+ void initGlobalState(bool initDocs)
292
+ {
293
+ BaseClass::initGlobalState(initDocs);
294
+ if (initDocs)
295
+ {
296
+ for (size_t f = 0; f < F; ++f)
297
+ {
298
+ std::unique_ptr<detail::GLMFunctor<WeightType>> v;
299
+ switch (varTypes[f])
300
+ {
301
+ case ISLDAModel::GLM::linear:
302
+ v = make_unique<detail::LinearFunctor<WeightType>>(this->K, mu[f],
303
+ f < glmParam.size() ? glmParam[f] : 1.f);
304
+ break;
305
+ case ISLDAModel::GLM::binary_logistic:
306
+ v = make_unique<detail::BinaryLogisticFunctor<WeightType>>(this->K, mu[f],
307
+ f < glmParam.size() ? glmParam[f] : 1.f, this->docs.size());
308
+ break;
309
+ }
310
+ responseVars.emplace_back(std::move(v));
311
+ }
312
+ }
313
+ Ys.resize(this->docs.size(), F);
314
+ normZ.resize(this->K, this->docs.size());
315
+ for (size_t i = 0; i < this->docs.size(); ++i)
316
+ {
317
+ Ys.row(i) = Eigen::Map<Eigen::Matrix<Float, 1, -1>>(this->docs[i].y.data(), F);
318
+ }
319
+ }
320
+
321
+ public:
322
+ DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, F, responseVars, mu, nuSq);
323
+ DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, F, responseVars, mu, nuSq);
324
+
325
+ SLDAModel(size_t _K = 1, const std::vector<ISLDAModel::GLM>& vars = {},
326
+ Float _alpha = 0.1, Float _eta = 0.01,
327
+ const std::vector<Float>& _mu = {}, const std::vector<Float>& _nuSq = {},
328
+ const std::vector<Float>& _glmParam = {},
329
+ size_t _rg = std::random_device{}())
330
+ : BaseClass(_K, _alpha, _eta, _rg), F(vars.size()), varTypes(vars),
331
+ glmParam(_glmParam)
332
+ {
333
+ for (auto t : varTypes)
334
+ {
335
+ if (t != ISLDAModel::GLM::linear && t != ISLDAModel::GLM::binary_logistic) THROW_ERROR_WITH_INFO(std::runtime_error, "unknown var GLM type in 'vars'");
336
+ }
337
+ mu = decltype(mu)::Zero(F);
338
+ std::copy(_mu.begin(), _mu.end(), mu.data());
339
+ nuSq = decltype(nuSq)::Ones(F);
340
+ std::copy(_nuSq.begin(), _nuSq.end(), nuSq.data());
341
+ }
342
+
343
+ std::vector<Float> getRegressionCoef(size_t f) const override
344
+ {
345
+ return { responseVars[f]->regressionCoef.data(), responseVars[f]->regressionCoef.data() + this->K };
346
+ }
347
+
348
+ GETTER(F, size_t, F);
349
+
350
+ ISLDAModel::GLM getTypeOfVar(size_t f) const override
351
+ {
352
+ return responseVars[f]->getType();
353
+ }
354
+
355
+ template<bool _const = false>
356
+ _DocType& _updateDoc(_DocType& doc, const std::vector<Float>& y)
357
+ {
358
+ if (_const)
359
+ {
360
+ if (y.size() > F) throw std::runtime_error{ text::format(
361
+ "size of 'y' is greater than the number of vars.\n"
362
+ "size of 'y' : %zd, number of vars: %zd", y.size(), F) };
363
+ doc.y = y;
364
+ while (doc.y.size() < F)
365
+ {
366
+ doc.y.emplace_back(NAN);
367
+ }
368
+ }
369
+ else
370
+ {
371
+ if (y.size() != F) throw std::runtime_error{ text::format(
372
+ "size of 'y' must be equal to the number of vars.\n"
373
+ "size of 'y' : %zd, number of vars: %zd", y.size(), F) };
374
+ doc.y = y;
375
+ }
376
+ return doc;
377
+ }
378
+
379
+ size_t addDoc(const std::vector<std::string>& words, const std::vector<Float>& y) override
380
+ {
381
+ auto doc = this->_makeDoc(words);
382
+ return this->_addDoc(_updateDoc(doc, y));
383
+ }
384
+
385
+ std::unique_ptr<DocumentBase> makeDoc(const std::vector<std::string>& words, const std::vector<Float>& y) const override
386
+ {
387
+ auto doc = as_mutable(this)->template _makeDoc<true>(words);
388
+ return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, y));
389
+ }
390
+
391
+ size_t addDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
392
+ const std::vector<Float>& y) override
393
+ {
394
+ auto doc = this->template _makeRawDoc<false>(rawStr, tokenizer);
395
+ return this->_addDoc(_updateDoc(doc, y));
396
+ }
397
+
398
+ std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
399
+ const std::vector<Float>& y) const override
400
+ {
401
+ auto doc = as_mutable(this)->template _makeRawDoc<true>(rawStr, tokenizer);
402
+ return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, y));
403
+ }
404
+
405
+ size_t addDoc(const std::string& rawStr, const std::vector<Vid>& words,
406
+ const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
407
+ const std::vector<Float>& y) override
408
+ {
409
+ auto doc = this->_makeRawDoc(rawStr, words, pos, len);
410
+ return this->_addDoc(_updateDoc(doc, y));
411
+ }
412
+
413
+ std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const std::vector<Vid>& words,
414
+ const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
415
+ const std::vector<Float>& y) const override
416
+ {
417
+ auto doc = this->_makeRawDoc(rawStr, words, pos, len);
418
+ return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, y));
419
+ }
420
+
421
+ std::vector<Float> estimateVars(const DocumentBase* doc) const override
422
+ {
423
+ std::vector<Float> ret;
424
+ auto pdoc = dynamic_cast<const _DocType*>(doc);
425
+ if (!pdoc) return ret;
426
+ for (auto& f : responseVars)
427
+ {
428
+ ret.emplace_back(f->estimate(pdoc->numByTopic, pdoc->getSumWordWeight()));
429
+ }
430
+ return ret;
431
+ }
432
+ };
433
+
434
+ template<typename _WeightType>
435
+ void detail::GLMFunctor<_WeightType>::serializerRead(
436
+ std::unique_ptr<detail::GLMFunctor<_WeightType>>& p, std::istream& istr)
437
+ {
438
+ uint32_t t = serializer::readFromStream<uint32_t>(istr);
439
+ if (!t) p.reset();
440
+ else
441
+ {
442
+ switch ((ISLDAModel::GLM)(t - 1))
443
+ {
444
+ case ISLDAModel::GLM::linear:
445
+ p = make_unique<LinearFunctor<_WeightType>>();
446
+ break;
447
+ case ISLDAModel::GLM::binary_logistic:
448
+ p = make_unique<BinaryLogisticFunctor<_WeightType>>();
449
+ break;
450
+ default:
451
+ throw std::ios_base::failure(text::format("wrong GLMFunctor type id %d", (t - 1)));
452
+ }
453
+ p->serializerRead(istr);
454
+ }
455
+ }
456
+ }