tomoto 0.1.4 → 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/ext/tomoto/ct.cpp +8 -4
  4. data/ext/tomoto/dmr.cpp +10 -4
  5. data/ext/tomoto/dt.cpp +13 -4
  6. data/ext/tomoto/extconf.rb +1 -1
  7. data/ext/tomoto/gdmr.cpp +14 -6
  8. data/ext/tomoto/hdp.cpp +9 -4
  9. data/ext/tomoto/hlda.cpp +9 -4
  10. data/ext/tomoto/hpa.cpp +9 -4
  11. data/ext/tomoto/lda.cpp +8 -4
  12. data/ext/tomoto/llda.cpp +8 -4
  13. data/ext/tomoto/mglda.cpp +11 -1
  14. data/ext/tomoto/pa.cpp +9 -4
  15. data/ext/tomoto/plda.cpp +8 -4
  16. data/ext/tomoto/slda.cpp +13 -5
  17. data/lib/tomoto/gdmr.rb +2 -2
  18. data/lib/tomoto/version.rb +1 -1
  19. data/vendor/EigenRand/EigenRand/Core.h +6 -1107
  20. data/vendor/EigenRand/EigenRand/Dists/Basic.h +490 -43
  21. data/vendor/EigenRand/EigenRand/Dists/Discrete.h +916 -285
  22. data/vendor/EigenRand/EigenRand/Dists/GammaPoisson.h +85 -36
  23. data/vendor/EigenRand/EigenRand/Dists/NormalExp.h +1038 -290
  24. data/vendor/EigenRand/EigenRand/EigenRand +2 -2
  25. data/vendor/EigenRand/EigenRand/Macro.h +4 -4
  26. data/vendor/EigenRand/EigenRand/MorePacketMath.h +54 -22
  27. data/vendor/EigenRand/EigenRand/MvDists/Multinomial.h +222 -0
  28. data/vendor/EigenRand/EigenRand/MvDists/MvNormal.h +492 -0
  29. data/vendor/EigenRand/EigenRand/PacketFilter.h +2 -2
  30. data/vendor/EigenRand/EigenRand/PacketRandomEngine.h +2 -2
  31. data/vendor/EigenRand/EigenRand/RandUtils.h +65 -11
  32. data/vendor/EigenRand/EigenRand/doc.h +142 -25
  33. data/vendor/EigenRand/LICENSE +1 -1
  34. data/vendor/EigenRand/README.md +109 -24
  35. data/vendor/tomotopy/README.kr.rst +27 -6
  36. data/vendor/tomotopy/README.rst +29 -8
  37. data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +60 -12
  38. data/vendor/tomotopy/src/Labeling/FoRelevance.h +2 -2
  39. data/vendor/tomotopy/src/Labeling/Phraser.hpp +33 -21
  40. data/vendor/tomotopy/src/TopicModel/CT.h +8 -5
  41. data/vendor/tomotopy/src/TopicModel/CTModel.cpp +2 -6
  42. data/vendor/tomotopy/src/TopicModel/CTModel.hpp +29 -23
  43. data/vendor/tomotopy/src/TopicModel/DMR.h +33 -4
  44. data/vendor/tomotopy/src/TopicModel/DMRModel.cpp +2 -6
  45. data/vendor/tomotopy/src/TopicModel/DMRModel.hpp +231 -57
  46. data/vendor/tomotopy/src/TopicModel/DT.h +24 -5
  47. data/vendor/tomotopy/src/TopicModel/DTModel.cpp +2 -8
  48. data/vendor/tomotopy/src/TopicModel/DTModel.hpp +41 -28
  49. data/vendor/tomotopy/src/TopicModel/GDMR.h +31 -5
  50. data/vendor/tomotopy/src/TopicModel/GDMRModel.cpp +2 -7
  51. data/vendor/tomotopy/src/TopicModel/GDMRModel.hpp +211 -104
  52. data/vendor/tomotopy/src/TopicModel/HDP.h +11 -2
  53. data/vendor/tomotopy/src/TopicModel/HDPModel.cpp +2 -6
  54. data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +52 -45
  55. data/vendor/tomotopy/src/TopicModel/HLDA.h +11 -2
  56. data/vendor/tomotopy/src/TopicModel/HLDAModel.cpp +2 -6
  57. data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +13 -16
  58. data/vendor/tomotopy/src/TopicModel/HPA.h +5 -2
  59. data/vendor/tomotopy/src/TopicModel/HPAModel.cpp +2 -6
  60. data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +51 -21
  61. data/vendor/tomotopy/src/TopicModel/LDA.h +9 -2
  62. data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +8 -8
  63. data/vendor/tomotopy/src/TopicModel/LDAModel.cpp +2 -6
  64. data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +70 -28
  65. data/vendor/tomotopy/src/TopicModel/LLDA.h +1 -2
  66. data/vendor/tomotopy/src/TopicModel/LLDAModel.cpp +2 -6
  67. data/vendor/tomotopy/src/TopicModel/LLDAModel.hpp +22 -12
  68. data/vendor/tomotopy/src/TopicModel/MGLDA.h +12 -3
  69. data/vendor/tomotopy/src/TopicModel/MGLDAModel.cpp +2 -10
  70. data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +42 -19
  71. data/vendor/tomotopy/src/TopicModel/PA.h +9 -4
  72. data/vendor/tomotopy/src/TopicModel/PAModel.cpp +2 -6
  73. data/vendor/tomotopy/src/TopicModel/PAModel.hpp +48 -25
  74. data/vendor/tomotopy/src/TopicModel/PLDA.h +13 -2
  75. data/vendor/tomotopy/src/TopicModel/PLDAModel.cpp +2 -6
  76. data/vendor/tomotopy/src/TopicModel/PLDAModel.hpp +27 -19
  77. data/vendor/tomotopy/src/TopicModel/PT.h +12 -5
  78. data/vendor/tomotopy/src/TopicModel/PTModel.cpp +2 -3
  79. data/vendor/tomotopy/src/TopicModel/PTModel.hpp +29 -14
  80. data/vendor/tomotopy/src/TopicModel/SLDA.h +18 -6
  81. data/vendor/tomotopy/src/TopicModel/SLDAModel.cpp +2 -10
  82. data/vendor/tomotopy/src/TopicModel/SLDAModel.hpp +93 -43
  83. data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +58 -23
  84. data/vendor/tomotopy/src/Utils/AliasMethod.hpp +6 -6
  85. data/vendor/tomotopy/src/Utils/Dictionary.h +11 -0
  86. data/vendor/tomotopy/src/Utils/SharedString.hpp +26 -1
  87. data/vendor/tomotopy/src/Utils/Trie.hpp +46 -21
  88. data/vendor/tomotopy/src/Utils/Utils.hpp +99 -14
  89. data/vendor/tomotopy/src/Utils/exception.h +1 -1
  90. data/vendor/tomotopy/src/Utils/math.h +5 -7
  91. data/vendor/tomotopy/src/Utils/serializer.hpp +329 -201
  92. data/vendor/tomotopy/src/Utils/text.hpp +8 -0
  93. data/vendor/tomotopy/src/Utils/tvector.hpp +49 -7
  94. metadata +9 -7
@@ -161,12 +161,19 @@ namespace tomoto
161
161
  }
162
162
  };
163
163
 
164
+ struct LDAArgs
165
+ {
166
+ size_t k = 1;
167
+ std::vector<Float> alpha = { 0.1 };
168
+ Float eta = 0.01;
169
+ size_t seed = std::random_device{}();
170
+ };
171
+
164
172
  class ILDAModel : public ITopicModel
165
173
  {
166
174
  public:
167
175
  using DefaultDocType = DocumentLDA<TermWeight::one>;
168
- static ILDAModel* create(TermWeight _weight, size_t _K = 1,
169
- Float _alpha = 0.1, Float _eta = 0.01, size_t seed = std::random_device{}(),
176
+ static ILDAModel* create(TermWeight _weight, const LDAArgs& args,
170
177
  bool scalarRng = false);
171
178
 
172
179
  virtual TermWeight getTermWeight() const = 0;
@@ -85,7 +85,7 @@ namespace tomoto
85
85
  static constexpr static constexpr char TMID[] = "LDA\0";
86
86
 
87
87
  Float alpha;
88
- Eigen::Matrix<Float, -1, 1> alphas;
88
+ Vector alphas;
89
89
  Float eta;
90
90
  Tid K;
91
91
  size_t optimInterval = 50;
@@ -93,7 +93,7 @@ namespace tomoto
93
93
  template<typename _List>
94
94
  static Float calcDigammaSum(_List list, size_t len, Float alpha)
95
95
  {
96
- auto listExpr = Eigen::Matrix<Float, -1, 1>::NullaryExpr(len, list);
96
+ auto listExpr = Vector::NullaryExpr(len, list);
97
97
  auto dAlpha = math::digammaT(alpha);
98
98
  return (math::digammaApprox(listExpr.array() + alpha) - dAlpha).sum();
99
99
  }
@@ -265,11 +265,11 @@ namespace tomoto
265
265
  void initGlobalState(bool initDocs)
266
266
  {
267
267
  const size_t V = this->realV;
268
- this->globalState.zLikelihood = Eigen::Matrix<Float, -1, 1>::Zero(K);
268
+ this->globalState.zLikelihood = Vector::Zero(K);
269
269
  if (initDocs)
270
270
  {
271
- this->globalState.numByTopic = Eigen::Matrix<Float, -1, 1>::Zero(K);
272
- this->globalState.numByTopicWord = Eigen::Matrix<Float, -1, -1>::Zero(K, V);
271
+ this->globalState.numByTopic = Vector::Zero(K);
272
+ this->globalState.numByTopicWord = Matrix::Zero(K, V);
273
273
  }
274
274
  }
275
275
 
@@ -335,7 +335,7 @@ namespace tomoto
335
335
  LDACVB0Model(size_t _K = 1, Float _alpha = 0.1, Float _eta = 0.01, size_t _rg = std::random_device{}())
336
336
  : BaseClass(_rg), K(_K), alpha(_alpha), eta(_eta)
337
337
  {
338
- alphas = Eigen::Matrix<Float, -1, 1>::Constant(K, alpha);
338
+ alphas = Vector::Constant(K, alpha);
339
339
  }
340
340
  GETTER(K, size_t, K);
341
341
  GETTER(Alpha, Float, alpha);
@@ -355,7 +355,7 @@ namespace tomoto
355
355
 
356
356
  std::unique_ptr<DocumentBase> makeDoc(const std::vector<std::string>& words) const override
357
357
  {
358
- return make_unique<_DocType>(as_mutable(this)->template _makeDoc<true>(words));
358
+ return std::make_unique<_DocType>(as_mutable(this)->template _makeDoc<true>(words));
359
359
  }
360
360
 
361
361
  void updateDocs()
@@ -403,7 +403,7 @@ namespace tomoto
403
403
  return ret;
404
404
  }
405
405
 
406
- std::vector<Float> _getWidsByTopic(Tid tid) const
406
+ std::vector<Float> _getWidsByTopic(Tid tid, bool normalize = true) const
407
407
  {
408
408
  assert(tid < K);
409
409
  const size_t V = this->realV;
@@ -2,12 +2,8 @@
2
2
 
3
3
  namespace tomoto
4
4
  {
5
- /*template class LDAModel<TermWeight::one>;
6
- template class LDAModel<TermWeight::idf>;
7
- template class LDAModel<TermWeight::pmi>;*/
8
-
9
- ILDAModel* ILDAModel::create(TermWeight _weight, size_t _K, Float _alpha, Float _eta, size_t seed, bool scalarRng)
5
+ ILDAModel* ILDAModel::create(TermWeight _weight, const LDAArgs& args, bool scalarRng)
10
6
  {
11
- TMT_SWITCH_TW(_weight, scalarRng, LDAModel, _K, _alpha, _eta, seed);
7
+ TMT_SWITCH_TW(_weight, scalarRng, LDAModel, args);
12
8
  }
13
9
  }
@@ -56,7 +56,7 @@ namespace tomoto
56
56
  {
57
57
  using WeightType = typename std::conditional<_tw == TermWeight::one, int32_t, float>::type;
58
58
 
59
- Eigen::Matrix<Float, -1, 1> zLikelihood;
59
+ Vector zLikelihood;
60
60
  Eigen::Matrix<WeightType, -1, 1> numByTopic; // Dim: (Topic, 1)
61
61
  //Eigen::Matrix<WeightType, -1, -1> numByTopicWord; // Dim: (Topic, Vocabs)
62
62
  ShareableMatrix<WeightType, -1, -1> numByTopicWord; // Dim: (Topic, Vocabs)
@@ -179,10 +179,10 @@ namespace tomoto
179
179
  std::vector<Float> sharedWordWeights;
180
180
  Tid K;
181
181
  Float alpha, eta;
182
- Eigen::Matrix<Float, -1, 1> alphas;
182
+ Vector alphas;
183
183
  std::unordered_map<std::string, std::vector<Float>> etaByWord;
184
- Eigen::Matrix<Float, -1, -1> etaByTopicWord; // (K, V)
185
- Eigen::Matrix<Float, -1, 1> etaSumByTopic; // (K, )
184
+ Matrix etaByTopicWord; // (K, V)
185
+ Vector etaSumByTopic; // (K, )
186
186
  uint32_t optimInterval = 10, burnIn = 0;
187
187
  Eigen::Matrix<WeightType, -1, -1> numByTopicDoc;
188
188
 
@@ -197,7 +197,7 @@ namespace tomoto
197
197
  template<typename _List>
198
198
  static Float calcDigammaSum(ThreadPool* pool, _List list, size_t len, Float alpha)
199
199
  {
200
- auto listExpr = Eigen::Matrix<Float, -1, 1>::NullaryExpr(len, list);
200
+ auto listExpr = Vector::NullaryExpr(len, list);
201
201
  auto dAlpha = math::digammaT(alpha);
202
202
 
203
203
  size_t suggested = (len + 127) / 128;
@@ -507,7 +507,7 @@ namespace tomoto
507
507
  static_cast<DerivedClass*>(this)->optimizeParameters(pool, localData, rgs);
508
508
  }
509
509
  }
510
- catch (const exception::TrainingError&)
510
+ catch (const exc::TrainingError&)
511
511
  {
512
512
  for (auto& r : res) if(r.valid()) r.get();
513
513
  throw;
@@ -663,6 +663,22 @@ namespace tomoto
663
663
  makeTransformIter(this->docs.end(), txWeights));
664
664
  }
665
665
  }
666
+
667
+ void updateForCopy()
668
+ {
669
+ BaseClass::updateForCopy();
670
+ size_t offset = 0;
671
+ for (auto& doc : this->docs)
672
+ {
673
+ size_t size = doc.Zs.size();
674
+ doc.Zs = tvector<Tid>{ sharedZs.data() + offset, size };
675
+ if (_tw != TermWeight::one)
676
+ {
677
+ doc.wordWeights = tvector<Float>{ sharedWordWeights.data() + offset, size };
678
+ }
679
+ offset += size;
680
+ }
681
+ }
666
682
 
667
683
  WeightType* getTopicDocPtr(size_t docId) const
668
684
  {
@@ -670,11 +686,14 @@ namespace tomoto
670
686
  return (WeightType*)numByTopicDoc.col(docId).data();
671
687
  }
672
688
 
689
+ /*
690
+ * called only when initializing a new doc, not when loading from saved model
691
+ */
673
692
  void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
674
693
  {
675
694
  sortAndWriteOrder(doc.words, doc.wOrder);
676
695
  doc.numByTopic.init(getTopicDocPtr(docId), K, 1);
677
- doc.Zs = tvector<Tid>(wordSize);
696
+ doc.Zs = tvector<Tid>(wordSize, non_topic_id);
678
697
  if(_tw != TermWeight::one) doc.wordWeights.resize(wordSize);
679
698
  }
680
699
 
@@ -688,7 +707,7 @@ namespace tomoto
688
707
  {
689
708
  auto id = this->dict.toWid(it.first);
690
709
  if (id == (Vid)-1 || id >= this->realV) continue;
691
- etaByTopicWord.col(id) = Eigen::Map<Eigen::Matrix<Float, -1, 1>>{ it.second.data(), (Eigen::Index)it.second.size() };
710
+ etaByTopicWord.col(id) = Eigen::Map<Vector>{ it.second.data(), (Eigen::Index)it.second.size() };
692
711
  }
693
712
  etaSumByTopic = etaByTopicWord.rowwise().sum();
694
713
  }
@@ -696,7 +715,7 @@ namespace tomoto
696
715
  void initGlobalState(bool initDocs)
697
716
  {
698
717
  const size_t V = this->realV;
699
- this->globalState.zLikelihood = Eigen::Matrix<Float, -1, 1>::Zero(K);
718
+ this->globalState.zLikelihood = Vector::Zero(K);
700
719
  if (initDocs)
701
720
  {
702
721
  this->globalState.numByTopic = Eigen::Matrix<WeightType, -1, 1>::Zero(K);
@@ -708,12 +727,14 @@ namespace tomoto
708
727
 
709
728
  struct Generator
710
729
  {
711
- std::uniform_int_distribution<Tid> theta;
730
+ Eigen::Rand::DiscreteGen<int32_t> theta;
712
731
  };
713
732
 
714
733
  Generator makeGeneratorForInit(const _DocType*) const
715
734
  {
716
- return Generator{ std::uniform_int_distribution<Tid>{0, (Tid)(K - 1)} };
735
+ Generator g;
736
+ g.theta = Eigen::Rand::DiscreteGen<int32_t>{ alphas.data(), alphas.data() + alphas.size() };
737
+ return g;
717
738
  }
718
739
 
719
740
  template<bool _Infer>
@@ -780,12 +801,13 @@ namespace tomoto
780
801
  return cnt;
781
802
  }
782
803
 
783
- std::vector<Float> _getWidsByTopic(size_t tid) const
804
+ std::vector<Float> _getWidsByTopic(size_t tid, bool normalize = true) const
784
805
  {
785
806
  assert(tid < this->globalState.numByTopic.rows());
786
807
  const size_t V = this->realV;
787
808
  std::vector<Float> ret(V);
788
809
  Float sum = this->globalState.numByTopic[tid] + V * eta;
810
+ if (!normalize) sum = 1;
789
811
  auto r = this->globalState.numByTopicWord.row(tid);
790
812
  for (size_t v = 0; v < V; ++v)
791
813
  {
@@ -794,7 +816,7 @@ namespace tomoto
794
816
  return ret;
795
817
  }
796
818
 
797
- template<bool _Together, ParallelScheme _ps, typename _Iter>
819
+ template<bool together, ParallelScheme _ps, typename _Iter>
798
820
  std::vector<double> _infer(_Iter docFirst, _Iter docLast, size_t maxIter, Float tolerance, size_t numWorkers) const
799
821
  {
800
822
  decltype(static_cast<const DerivedClass*>(this)->makeGeneratorForInit(nullptr)) generator;
@@ -803,7 +825,7 @@ namespace tomoto
803
825
  generator = static_cast<const DerivedClass*>(this)->makeGeneratorForInit(nullptr);
804
826
  }
805
827
 
806
- if (_Together)
828
+ if (together)
807
829
  {
808
830
  numWorkers = std::min(numWorkers, this->maxThreads[(size_t)_ps]);
809
831
  ThreadPool pool{ numWorkers };
@@ -913,13 +935,26 @@ namespace tomoto
913
935
  DEFINE_TAGGED_SERIALIZER_WITH_VERSION(1, 0x00010001, vocabWeights, alpha, alphas, eta, K, etaByWord,
914
936
  burnIn, optimInterval);
915
937
 
916
- LDAModel(size_t _K = 1, Float _alpha = 0.1, Float _eta = 0.01, size_t _rg = std::random_device{}())
917
- : BaseClass(_rg), K(_K), alpha(_alpha), eta(_eta)
918
- {
919
- if (_K == 0 || _K >= 0x80000000) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong K value (K = %zd)", _K));
920
- if (_alpha <= 0) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong alpha value (alpha = %f)", _alpha));
921
- if (_eta <= 0) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong eta value (eta = %f)", _eta));
922
- alphas = Eigen::Matrix<Float, -1, 1>::Constant(K, alpha);
938
+ LDAModel(const LDAArgs& args, bool checkAlpha = true)
939
+ : BaseClass(args.seed), K(args.k), alpha(args.alpha[0]), eta(args.eta)
940
+ {
941
+ if (K == 0 || K >= 0x80000000) THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong K value (K = %zd)", K));
942
+
943
+ if (args.alpha.size() == 1)
944
+ {
945
+ alphas = Vector::Constant(K, alpha);
946
+ }
947
+ else if (args.alpha.size() == args.k)
948
+ {
949
+ alphas = Eigen::Map<const Vector>(args.alpha.data(), args.alpha.size());
950
+ }
951
+ else if (checkAlpha)
952
+ {
953
+ THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong alpha value (len = %zd)", args.alpha.size()));
954
+ }
955
+
956
+ if ((alphas.array() <= 0).any()) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "wrong alpha value");
957
+ if (eta <= 0) THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong eta value (eta = %f)", eta));
923
958
  }
924
959
 
925
960
  GETTER(K, size_t, K);
@@ -952,7 +987,7 @@ namespace tomoto
952
987
 
953
988
  std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) const override
954
989
  {
955
- return make_unique<_DocType>(as_mutable(this)->template _makeFromRawDoc<true>(rawDoc, tokenizer));
990
+ return std::make_unique<_DocType>(as_mutable(this)->template _makeFromRawDoc<true>(rawDoc, tokenizer));
956
991
  }
957
992
 
958
993
  size_t addDoc(const RawDoc& rawDoc) override
@@ -962,15 +997,15 @@ namespace tomoto
962
997
 
963
998
  std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc) const override
964
999
  {
965
- return make_unique<_DocType>(as_mutable(this)->template _makeFromRawDoc<true>(rawDoc));
1000
+ return std::make_unique<_DocType>(as_mutable(this)->template _makeFromRawDoc<true>(rawDoc));
966
1001
  }
967
1002
 
968
1003
  void setWordPrior(const std::string& word, const std::vector<Float>& priors) override
969
1004
  {
970
- if (priors.size() != K) THROW_ERROR_WITH_INFO(exception::InvalidArgument, "priors.size() must be equal to K.");
1005
+ if (priors.size() != K) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "priors.size() must be equal to K.");
971
1006
  for (auto p : priors)
972
1007
  {
973
- if (p < 0) THROW_ERROR_WITH_INFO(exception::InvalidArgument, "priors must not be less than 0.");
1008
+ if (p < 0) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "priors must not be less than 0.");
974
1009
  }
975
1010
  this->dict.add(word);
976
1011
  etaByWord.emplace(word, priors);
@@ -1069,11 +1104,18 @@ namespace tomoto
1069
1104
  return static_cast<const DerivedClass*>(this)->_getTopicsCount();
1070
1105
  }
1071
1106
 
1072
- std::vector<Float> getTopicsByDoc(const _DocType& doc) const
1107
+ std::vector<Float> getTopicsByDoc(const _DocType& doc, bool normalize) const
1073
1108
  {
1074
1109
  std::vector<Float> ret(K);
1075
- Eigen::Map<Eigen::Matrix<Float, -1, 1>> { ret.data(), K }.array() =
1076
- (doc.numByTopic.array().template cast<Float>() + alphas.array()) / (doc.getSumWordWeight() + alphas.sum());
1110
+ Eigen::Map<Eigen::Array<Float, -1, 1>> m{ ret.data(), K };
1111
+ if (normalize)
1112
+ {
1113
+ m = (doc.numByTopic.array().template cast<Float>() + alphas.array()) / (doc.getSumWordWeight() + alphas.sum());
1114
+ }
1115
+ else
1116
+ {
1117
+ m = doc.numByTopic.array().template cast<Float>() + alphas.array();
1118
+ }
1077
1119
  return ret;
1078
1120
  }
1079
1121
 
@@ -19,8 +19,7 @@ namespace tomoto
19
19
  {
20
20
  public:
21
21
  using DefaultDocType = DocumentLLDA<TermWeight::one>;
22
- static ILLDAModel* create(TermWeight _weight, size_t _K = 1,
23
- Float alpha = 0.1, Float eta = 0.01, size_t seed = std::random_device{}(),
22
+ static ILLDAModel* create(TermWeight _weight, const LDAArgs& args,
24
23
  bool scalarRng = false);
25
24
 
26
25
  virtual const Dictionary& getTopicLabelDict() const = 0;
@@ -2,12 +2,8 @@
2
2
 
3
3
  namespace tomoto
4
4
  {
5
- /*template class LLDAModel<TermWeight::one>;
6
- template class LLDAModel<TermWeight::idf>;
7
- template class LLDAModel<TermWeight::pmi>;*/
8
-
9
- ILLDAModel* ILLDAModel::create(TermWeight _weight, size_t _K, Float _alpha, Float _eta, size_t seed, bool scalarRng)
5
+ ILLDAModel* ILLDAModel::create(TermWeight _weight, const LDAArgs& args, bool scalarRng)
10
6
  {
11
- TMT_SWITCH_TW(_weight, scalarRng, LLDAModel, _K, _alpha, _eta, seed);
7
+ TMT_SWITCH_TW(_weight, scalarRng, LLDAModel, args);
12
8
  }
13
9
  }
@@ -71,13 +71,16 @@ namespace tomoto
71
71
 
72
72
  struct Generator
73
73
  {
74
- std::discrete_distribution<> theta;
74
+ Eigen::Array<Float, -1, 1> p;
75
+ Eigen::Rand::DiscreteGen<int32_t> theta;
75
76
  };
76
77
 
77
78
  Generator makeGeneratorForInit(const _DocType* doc) const
78
79
  {
79
- std::discrete_distribution<> theta{ doc->labelMask.data(), doc->labelMask.data() + this->K };
80
- return Generator{ theta };
80
+ Generator g;
81
+ g.p = doc->labelMask.array().template cast<Float>() * this->alphas.array();
82
+ g.theta = Eigen::Rand::DiscreteGen<int32_t>{ g.p.data(), g.p.data() + this->K };
83
+ return g;
81
84
  }
82
85
 
83
86
  template<bool _Infer>
@@ -88,7 +91,7 @@ namespace tomoto
88
91
  if (this->etaByTopicWord.size())
89
92
  {
90
93
  Eigen::Array<Float, -1, 1> col = this->etaByTopicWord.col(w);
91
- for (size_t k = 0; k < col.size(); ++k) col[k] *= g.theta.probabilities()[k];
94
+ col *= g.p;
92
95
  z = sample::sampleFromDiscrete(col.data(), col.data() + col.size(), rgs);
93
96
  }
94
97
  else
@@ -102,8 +105,8 @@ namespace tomoto
102
105
  DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, topicLabelDict);
103
106
  DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, topicLabelDict);
104
107
 
105
- LLDAModel(size_t _K = 1, Float _alpha = 1.0, Float _eta = 0.01, size_t _rg = std::random_device{}())
106
- : BaseClass(_K, _alpha, _eta, _rg)
108
+ LLDAModel(const LDAArgs& args)
109
+ : BaseClass(args)
107
110
  {
108
111
  }
109
112
 
@@ -153,7 +156,7 @@ namespace tomoto
153
156
  std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) const override
154
157
  {
155
158
  auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc, tokenizer);
156
- return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
159
+ return std::make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
157
160
  }
158
161
 
159
162
  size_t addDoc(const RawDoc& rawDoc) override
@@ -165,16 +168,23 @@ namespace tomoto
165
168
  std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc) const override
166
169
  {
167
170
  auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc);
168
- return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
171
+ return std::make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
169
172
  }
170
173
 
171
- std::vector<Float> getTopicsByDoc(const _DocType& doc) const
174
+ std::vector<Float> getTopicsByDoc(const _DocType& doc, bool normalize) const
172
175
  {
173
176
  std::vector<Float> ret(this->K);
174
177
  auto maskedAlphas = this->alphas.array() * doc.labelMask.template cast<Float>().array();
175
- Eigen::Map<Eigen::Matrix<Float, -1, 1>> { ret.data(), this->K }.array() =
176
- (doc.numByTopic.array().template cast<Float>() + maskedAlphas)
177
- / (doc.getSumWordWeight() + maskedAlphas.sum());
178
+ Eigen::Map<Eigen::Array<Float, -1, 1>> m{ ret.data(), this->K };
179
+ if (normalize)
180
+ {
181
+ m = (doc.numByTopic.array().template cast<Float>() + maskedAlphas)
182
+ / (doc.getSumWordWeight() + maskedAlphas.sum());
183
+ }
184
+ else
185
+ {
186
+ m = doc.numByTopic.array().template cast<Float>() + maskedAlphas;
187
+ }
178
188
  return ret;
179
189
  }
180
190
 
@@ -28,13 +28,22 @@ namespace tomoto
28
28
  template<typename _TopicModel> void update(WeightType* ptr, const _TopicModel& mdl);
29
29
  };
30
30
 
31
+ struct MGLDAArgs : public LDAArgs
32
+ {
33
+ size_t kL = 1;
34
+ size_t t = 3;
35
+ std::vector<Float> alphaL = { 0.1 };
36
+ Float alphaMG = 0.1;
37
+ Float alphaML = 0.1;
38
+ Float etaL = 0.01;
39
+ Float gamma = 0.1;
40
+ };
41
+
31
42
  class IMGLDAModel : public ILDAModel
32
43
  {
33
44
  public:
34
45
  using DefaultDocType = DocumentMGLDA<TermWeight::one>;
35
- static IMGLDAModel* create(TermWeight _weight, size_t _KG = 1, size_t _KL = 1, size_t _T = 3,
36
- Float _alphaG = 0.1, Float _alphaL = 0.1, Float _alphaMG = 0.1, Float _alphaML = 0.1,
37
- Float _etaG = 0.01, Float _etaL = 0.01, Float _gamma = 0.1, size_t seed = std::random_device{}(),
46
+ static IMGLDAModel* create(TermWeight _weight, const MGLDAArgs& args,
38
47
  bool scalarRng = false);
39
48
 
40
49
  virtual size_t getKL() const = 0;