tomoto 0.1.4 → 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/ext/tomoto/ct.cpp +8 -4
  4. data/ext/tomoto/dmr.cpp +10 -4
  5. data/ext/tomoto/dt.cpp +13 -4
  6. data/ext/tomoto/extconf.rb +1 -1
  7. data/ext/tomoto/gdmr.cpp +14 -6
  8. data/ext/tomoto/hdp.cpp +9 -4
  9. data/ext/tomoto/hlda.cpp +9 -4
  10. data/ext/tomoto/hpa.cpp +9 -4
  11. data/ext/tomoto/lda.cpp +8 -4
  12. data/ext/tomoto/llda.cpp +8 -4
  13. data/ext/tomoto/mglda.cpp +11 -1
  14. data/ext/tomoto/pa.cpp +9 -4
  15. data/ext/tomoto/plda.cpp +8 -4
  16. data/ext/tomoto/slda.cpp +13 -5
  17. data/lib/tomoto/gdmr.rb +2 -2
  18. data/lib/tomoto/version.rb +1 -1
  19. data/vendor/EigenRand/EigenRand/Core.h +6 -1107
  20. data/vendor/EigenRand/EigenRand/Dists/Basic.h +490 -43
  21. data/vendor/EigenRand/EigenRand/Dists/Discrete.h +916 -285
  22. data/vendor/EigenRand/EigenRand/Dists/GammaPoisson.h +85 -36
  23. data/vendor/EigenRand/EigenRand/Dists/NormalExp.h +1038 -290
  24. data/vendor/EigenRand/EigenRand/EigenRand +2 -2
  25. data/vendor/EigenRand/EigenRand/Macro.h +4 -4
  26. data/vendor/EigenRand/EigenRand/MorePacketMath.h +54 -22
  27. data/vendor/EigenRand/EigenRand/MvDists/Multinomial.h +222 -0
  28. data/vendor/EigenRand/EigenRand/MvDists/MvNormal.h +492 -0
  29. data/vendor/EigenRand/EigenRand/PacketFilter.h +2 -2
  30. data/vendor/EigenRand/EigenRand/PacketRandomEngine.h +2 -2
  31. data/vendor/EigenRand/EigenRand/RandUtils.h +65 -11
  32. data/vendor/EigenRand/EigenRand/doc.h +142 -25
  33. data/vendor/EigenRand/LICENSE +1 -1
  34. data/vendor/EigenRand/README.md +109 -24
  35. data/vendor/tomotopy/README.kr.rst +27 -6
  36. data/vendor/tomotopy/README.rst +29 -8
  37. data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +60 -12
  38. data/vendor/tomotopy/src/Labeling/FoRelevance.h +2 -2
  39. data/vendor/tomotopy/src/Labeling/Phraser.hpp +33 -21
  40. data/vendor/tomotopy/src/TopicModel/CT.h +8 -5
  41. data/vendor/tomotopy/src/TopicModel/CTModel.cpp +2 -6
  42. data/vendor/tomotopy/src/TopicModel/CTModel.hpp +29 -23
  43. data/vendor/tomotopy/src/TopicModel/DMR.h +33 -4
  44. data/vendor/tomotopy/src/TopicModel/DMRModel.cpp +2 -6
  45. data/vendor/tomotopy/src/TopicModel/DMRModel.hpp +231 -57
  46. data/vendor/tomotopy/src/TopicModel/DT.h +24 -5
  47. data/vendor/tomotopy/src/TopicModel/DTModel.cpp +2 -8
  48. data/vendor/tomotopy/src/TopicModel/DTModel.hpp +41 -28
  49. data/vendor/tomotopy/src/TopicModel/GDMR.h +31 -5
  50. data/vendor/tomotopy/src/TopicModel/GDMRModel.cpp +2 -7
  51. data/vendor/tomotopy/src/TopicModel/GDMRModel.hpp +211 -104
  52. data/vendor/tomotopy/src/TopicModel/HDP.h +11 -2
  53. data/vendor/tomotopy/src/TopicModel/HDPModel.cpp +2 -6
  54. data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +52 -45
  55. data/vendor/tomotopy/src/TopicModel/HLDA.h +11 -2
  56. data/vendor/tomotopy/src/TopicModel/HLDAModel.cpp +2 -6
  57. data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +13 -16
  58. data/vendor/tomotopy/src/TopicModel/HPA.h +5 -2
  59. data/vendor/tomotopy/src/TopicModel/HPAModel.cpp +2 -6
  60. data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +51 -21
  61. data/vendor/tomotopy/src/TopicModel/LDA.h +9 -2
  62. data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +8 -8
  63. data/vendor/tomotopy/src/TopicModel/LDAModel.cpp +2 -6
  64. data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +70 -28
  65. data/vendor/tomotopy/src/TopicModel/LLDA.h +1 -2
  66. data/vendor/tomotopy/src/TopicModel/LLDAModel.cpp +2 -6
  67. data/vendor/tomotopy/src/TopicModel/LLDAModel.hpp +22 -12
  68. data/vendor/tomotopy/src/TopicModel/MGLDA.h +12 -3
  69. data/vendor/tomotopy/src/TopicModel/MGLDAModel.cpp +2 -10
  70. data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +42 -19
  71. data/vendor/tomotopy/src/TopicModel/PA.h +9 -4
  72. data/vendor/tomotopy/src/TopicModel/PAModel.cpp +2 -6
  73. data/vendor/tomotopy/src/TopicModel/PAModel.hpp +48 -25
  74. data/vendor/tomotopy/src/TopicModel/PLDA.h +13 -2
  75. data/vendor/tomotopy/src/TopicModel/PLDAModel.cpp +2 -6
  76. data/vendor/tomotopy/src/TopicModel/PLDAModel.hpp +27 -19
  77. data/vendor/tomotopy/src/TopicModel/PT.h +12 -5
  78. data/vendor/tomotopy/src/TopicModel/PTModel.cpp +2 -3
  79. data/vendor/tomotopy/src/TopicModel/PTModel.hpp +29 -14
  80. data/vendor/tomotopy/src/TopicModel/SLDA.h +18 -6
  81. data/vendor/tomotopy/src/TopicModel/SLDAModel.cpp +2 -10
  82. data/vendor/tomotopy/src/TopicModel/SLDAModel.hpp +93 -43
  83. data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +58 -23
  84. data/vendor/tomotopy/src/Utils/AliasMethod.hpp +6 -6
  85. data/vendor/tomotopy/src/Utils/Dictionary.h +11 -0
  86. data/vendor/tomotopy/src/Utils/SharedString.hpp +26 -1
  87. data/vendor/tomotopy/src/Utils/Trie.hpp +46 -21
  88. data/vendor/tomotopy/src/Utils/Utils.hpp +99 -14
  89. data/vendor/tomotopy/src/Utils/exception.h +1 -1
  90. data/vendor/tomotopy/src/Utils/math.h +5 -7
  91. data/vendor/tomotopy/src/Utils/serializer.hpp +329 -201
  92. data/vendor/tomotopy/src/Utils/text.hpp +8 -0
  93. data/vendor/tomotopy/src/Utils/tvector.hpp +49 -7
  94. metadata +9 -7
@@ -2,16 +2,8 @@
2
2
 
3
3
  namespace tomoto
4
4
  {
5
- /*template class MGLDAModel<TermWeight::one>;
6
- template class MGLDAModel<TermWeight::idf>;
7
- template class MGLDAModel<TermWeight::pmi>;*/
8
-
9
- IMGLDAModel* IMGLDAModel::create(TermWeight _weight, size_t _KG, size_t _KL, size_t _T,
10
- Float _alphaG, Float _alphaL, Float _alphaMG, Float _alphaML,
11
- Float _etaG, Float _etaL, Float _gamma, size_t seed, bool scalarRng)
5
+ IMGLDAModel* IMGLDAModel::create(TermWeight _weight, const MGLDAArgs& args, bool scalarRng)
12
6
  {
13
- TMT_SWITCH_TW(_weight, scalarRng, MGLDAModel, _KG, _KL, _T,
14
- _alphaG, _alphaL, _alphaMG, _alphaML,
15
- _etaG, _etaL, _gamma, seed);
7
+ TMT_SWITCH_TW(_weight, scalarRng, MGLDAModel, args);
16
8
  }
17
9
  }
@@ -289,7 +289,7 @@ namespace tomoto
289
289
 
290
290
  const size_t S = doc.numBySent.size();
291
291
  std::fill(doc.numBySent.begin(), doc.numBySent.end(), 0);
292
- doc.Zs = tvector<Tid>(wordSize);
292
+ doc.Zs = tvector<Tid>(wordSize, non_topic_id);
293
293
  doc.Vs.resize(wordSize);
294
294
  if (_tw != TermWeight::one) doc.wordWeights.resize(wordSize);
295
295
  doc.numByTopic.init(nullptr, this->K + KL, 1);
@@ -302,7 +302,7 @@ namespace tomoto
302
302
  void initGlobalState(bool initDocs)
303
303
  {
304
304
  const size_t V = this->realV;
305
- this->globalState.zLikelihood = Eigen::Matrix<Float, -1, 1>::Zero(T * (this->K + KL));
305
+ this->globalState.zLikelihood = Vector::Zero(T * (this->K + KL));
306
306
  if (initDocs)
307
307
  {
308
308
  this->globalState.numByTopic = Eigen::Matrix<WeightType, -1, 1>::Zero(this->K + KL);
@@ -371,17 +371,33 @@ namespace tomoto
371
371
  DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, alphaL, alphaM, alphaML, etaL, gamma, KL, T);
372
372
  DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, alphaL, alphaM, alphaML, etaL, gamma, KL, T);
373
373
 
374
- MGLDAModel(size_t _KG = 1, size_t _KL = 1, size_t _T = 3,
375
- Float _alphaG = 0.1, Float _alphaL = 0.1, Float _alphaMG = 0.1, Float _alphaML = 0.1,
376
- Float _etaG = 0.01, Float _etaL = 0.01, Float _gamma = 0.1, size_t _rg = std::random_device{}())
377
- : BaseClass(_KG, _alphaG, _etaG, _rg), KL(_KL), T(_T),
378
- alphaL(_alphaL), alphaM(_KG ? _alphaMG : 0), alphaML(_alphaML),
379
- etaL(_etaL), gamma(_gamma)
374
+ MGLDAModel(const MGLDAArgs& args)
375
+ : BaseClass(args), KL(args.kL), T(args.t),
376
+ alphaL(args.alphaL[0]), alphaM(args.k ? args.alphaMG : 0), alphaML(args.alphaML),
377
+ etaL(args.etaL), gamma(args.gamma)
380
378
  {
381
- if (_KL == 0 || _KL >= 0x80000000) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong KL value (KL = %zd)", _KL));
382
- if (_T == 0 || _T >= 0x80000000) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong T value (T = %zd)", _T));
383
- if (_alphaL <= 0) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong alphaL value (alphaL = %f)", _alphaL));
384
- if (_etaL <= 0) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong etaL value (etaL = %f)", _etaL));
379
+ if (KL == 0 || KL >= 0x80000000) THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong KL value (KL = %zd)", KL));
380
+ if (T == 0 || T >= 0x80000000) THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong T value (T = %zd)", T));
381
+
382
+ if (args.alpha.size() != 1)
383
+ {
384
+ THROW_ERROR_WITH_INFO(exc::Unimplemented, "An asymmetric alpha prior is not supported yet at MGLDA.");
385
+ }
386
+
387
+ if (args.alphaL.size() == 1)
388
+ {
389
+ }
390
+ else if (args.alphaL.size() == args.kL)
391
+ {
392
+ THROW_ERROR_WITH_INFO(exc::Unimplemented, "An asymmetric alphaL prior is not supported yet at MGLDA.");
393
+ }
394
+ else
395
+ {
396
+ THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong alphaL value (len = %zd)", args.alphaL.size()));
397
+ }
398
+
399
+ if (alphaL <= 0) THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong alphaL value (alphaL = %f)", alphaL));
400
+ if (etaL <= 0) THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong etaL value (etaL = %f)", etaL));
385
401
  }
386
402
 
387
403
  template<bool _const, typename _FnTokenizer>
@@ -426,7 +442,7 @@ namespace tomoto
426
442
 
427
443
  std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) const
428
444
  {
429
- return make_unique<_DocType>(as_mutable(this)->template _makeFromRawDoc<true>(rawDoc, tokenizer, rawDoc.template getMisc<std::string>("delimiter")));
445
+ return std::make_unique<_DocType>(as_mutable(this)->template _makeFromRawDoc<true>(rawDoc, tokenizer, rawDoc.template getMisc<std::string>("delimiter")));
430
446
  }
431
447
 
432
448
  template<bool _const = false>
@@ -497,25 +513,32 @@ namespace tomoto
497
513
 
498
514
  std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc) const
499
515
  {
500
- return make_unique<_DocType>(as_mutable(this)->template _makeFromRawDoc<true>(rawDoc));
516
+ return std::make_unique<_DocType>(as_mutable(this)->template _makeFromRawDoc<true>(rawDoc));
501
517
  }
502
518
 
503
519
  void setWordPrior(const std::string& word, const std::vector<Float>& priors) override
504
520
  {
505
- if (priors.size() != this->K + KL) THROW_ERROR_WITH_INFO(exception::InvalidArgument, "priors.size() must be equal to K.");
521
+ if (priors.size() != this->K + KL) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "priors.size() must be equal to K.");
506
522
  for (auto p : priors)
507
523
  {
508
- if (p < 0) THROW_ERROR_WITH_INFO(exception::InvalidArgument, "priors must not be less than 0.");
524
+ if (p < 0) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "priors must not be less than 0.");
509
525
  }
510
526
  this->dict.add(word);
511
527
  this->etaByWord.emplace(word, priors);
512
528
  }
513
529
 
514
- std::vector<Float> getTopicsByDoc(const _DocType& doc) const
530
+ std::vector<Float> getTopicsByDoc(const _DocType& doc, bool normalize) const
515
531
  {
516
532
  std::vector<Float> ret(this->K + KL);
517
- Eigen::Map<Eigen::Matrix<Float, -1, 1>> { ret.data(), this->K + KL }.array() =
518
- doc.numByTopic.array().template cast<Float>() / doc.getSumWordWeight();
533
+ Eigen::Map<Eigen::Array<Float, -1, 1>> m{ ret.data(), this->K + KL };
534
+ if (normalize)
535
+ {
536
+ m = doc.numByTopic.array().template cast<Float>() / doc.getSumWordWeight();
537
+ }
538
+ else
539
+ {
540
+ m = doc.numByTopic.array().template cast<Float>();
541
+ }
519
542
  return ret;
520
543
  }
521
544
 
@@ -18,13 +18,18 @@ namespace tomoto
18
18
  DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 0, Z2s);
19
19
  DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 1, 0x00010001, Z2s);
20
20
  };
21
+
22
+ struct PAArgs : public LDAArgs
23
+ {
24
+ size_t k2 = 1;
25
+ std::vector<Float> subalpha = { 0.1 };
26
+ };
21
27
 
22
28
  class IPAModel : public ILDAModel
23
29
  {
24
30
  public:
25
31
  using DefaultDocType = DocumentPA<TermWeight::one>;
26
- static IPAModel* create(TermWeight _weight, size_t _K1 = 1, size_t _K2 = 1,
27
- Float _alpha = 0.1, Float _eta = 0.01, size_t seed = std::random_device{}(),
32
+ static IPAModel* create(TermWeight _weight, const PAArgs& args,
28
33
  bool scalarRng = false);
29
34
 
30
35
  virtual size_t getDirichletEstIteration() const = 0;
@@ -32,10 +37,10 @@ namespace tomoto
32
37
  virtual size_t getK2() const = 0;
33
38
  virtual Float getSubAlpha(Tid k1, Tid k2) const = 0;
34
39
  virtual std::vector<Float> getSubAlpha(Tid k1) const = 0;
35
- virtual std::vector<Float> getSubTopicBySuperTopic(Tid k) const = 0;
40
+ virtual std::vector<Float> getSubTopicBySuperTopic(Tid k, bool normalize = true) const = 0;
36
41
  virtual std::vector<std::pair<Tid, Float>> getSubTopicBySuperTopicSorted(Tid k, size_t topN) const = 0;
37
42
 
38
- virtual std::vector<Float> getSubTopicsByDoc(const DocumentBase* doc) const = 0;
43
+ virtual std::vector<Float> getSubTopicsByDoc(const DocumentBase* doc, bool normalize = true) const = 0;
39
44
  virtual std::vector<std::pair<Tid, Float>> getSubTopicsByDocSorted(const DocumentBase* doc, size_t topN) const = 0;
40
45
 
41
46
  virtual std::vector<uint64_t> getCountBySuperTopic() const = 0;
@@ -2,12 +2,8 @@
2
2
 
3
3
  namespace tomoto
4
4
  {
5
- /*template class PAModel<TermWeight::one>;
6
- template class PAModel<TermWeight::idf>;
7
- template class PAModel<TermWeight::pmi>;*/
8
-
9
- IPAModel* IPAModel::create(TermWeight _weight, size_t _K, size_t _K2, Float _alpha, Float _eta, size_t seed, bool scalarRng)
5
+ IPAModel* IPAModel::create(TermWeight _weight, const PAArgs& args, bool scalarRng)
10
6
  {
11
- TMT_SWITCH_TW(_weight, scalarRng, PAModel, _K, _K2, _alpha, _eta, seed);
7
+ TMT_SWITCH_TW(_weight, scalarRng, PAModel, args);
12
8
  }
13
9
  }
@@ -16,7 +16,7 @@ namespace tomoto
16
16
  using WeightType = typename ModelStateLDA<_tw>::WeightType;
17
17
  Eigen::Matrix<WeightType, -1, -1> numByTopic1_2;
18
18
  Eigen::Matrix<WeightType, -1, 1> numByTopic2;
19
- Eigen::Matrix<Float, -1, 1> subTmp;
19
+ Vector subTmp;
20
20
 
21
21
  DEFINE_SERIALIZER_AFTER_BASE(ModelStateLDA<_tw>, numByTopic1_2, numByTopic2);
22
22
  };
@@ -41,8 +41,8 @@ namespace tomoto
41
41
  Float epsilon = 1e-5;
42
42
  size_t iteration = 5;
43
43
 
44
- Eigen::Matrix<Float, -1, 1> subAlphaSum; // len = K
45
- Eigen::Matrix<Float, -1, -1> subAlphas; // len = K * K2
44
+ Vector subAlphaSum; // len = K
45
+ Matrix subAlphas; // len = K * K2
46
46
  void optimizeParameters(ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
47
47
  {
48
48
  const auto K = this->K;
@@ -286,7 +286,7 @@ namespace tomoto
286
286
  BaseClass::prepareDoc(doc, docId, wordSize);
287
287
 
288
288
  doc.numByTopic1_2 = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K, K2);
289
- doc.Z2s = tvector<Tid>(wordSize);
289
+ doc.Z2s = tvector<Tid>(wordSize, non_topic_id);
290
290
  }
291
291
 
292
292
  void prepareWordPriors()
@@ -299,7 +299,7 @@ namespace tomoto
299
299
  {
300
300
  auto id = this->dict.toWid(it.first);
301
301
  if (id == (Vid)-1 || id >= this->realV) continue;
302
- this->etaByTopicWord.col(id) = Eigen::Map<Eigen::Matrix<Float, -1, 1>>{ it.second.data(), (Eigen::Index)it.second.size() };
302
+ this->etaByTopicWord.col(id) = Eigen::Map<Vector>{ it.second.data(), (Eigen::Index)it.second.size() };
303
303
  }
304
304
  this->etaSumByTopic = this->etaByTopicWord.rowwise().sum();
305
305
  }
@@ -307,7 +307,7 @@ namespace tomoto
307
307
  void initGlobalState(bool initDocs)
308
308
  {
309
309
  const size_t V = this->realV;
310
- this->globalState.zLikelihood = Eigen::Matrix<Float, -1, 1>::Zero(this->K * K2);
310
+ this->globalState.zLikelihood = Vector::Zero(this->K * K2);
311
311
  if (initDocs)
312
312
  {
313
313
  this->globalState.numByTopic = Eigen::Matrix<WeightType, -1, 1>::Zero(this->K);
@@ -365,12 +365,24 @@ namespace tomoto
365
365
  DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, K2, subAlphas, subAlphaSum);
366
366
  DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, K2, subAlphas, subAlphaSum);
367
367
 
368
- PAModel(size_t _K1 = 1, size_t _K2 = 1, Float _alpha = 0.1, Float _eta = 0.01, size_t _rg = std::random_device{}())
369
- : BaseClass(_K1, _alpha, _eta, _rg), K2(_K2)
368
+ PAModel(const PAArgs& args)
369
+ : BaseClass(args), K2(args.k2)
370
370
  {
371
- if (_K2 == 0 || _K2 >= 0x80000000) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong K2 value (K2 = %zd)", _K2));
372
- subAlphaSum = Eigen::Matrix<Float, -1, 1>::Constant(_K1, _K2 * 0.1);
373
- subAlphas = Eigen::Matrix<Float, -1, -1>::Constant(_K1, _K2, 0.1);
371
+ if (K2 == 0 || K2 >= 0x80000000) THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong K2 value (K2 = %zd)", K2));
372
+
373
+ if (args.subalpha.size() == 1)
374
+ {
375
+ subAlphas = Matrix::Constant(args.k, args.k2, args.subalpha[0]);
376
+ }
377
+ else if(args.subalpha.size() == args.k2)
378
+ {
379
+ subAlphas = Eigen::Map<const Eigen::Matrix<Float, 1, -1>>(args.subalpha.data(), args.subalpha.size()).replicate(args.k, 1);
380
+ }
381
+ else
382
+ {
383
+ THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong subalpha value (len = %zd)", args.subalpha.size()));
384
+ }
385
+ subAlphaSum = subAlphas.rowwise().sum();
374
386
  this->optimInterval = 1;
375
387
  }
376
388
 
@@ -379,7 +391,7 @@ namespace tomoto
379
391
 
380
392
  void setDirichletEstIteration(size_t iter) override
381
393
  {
382
- if (!iter) throw std::invalid_argument("iter must > 0");
394
+ if (!iter) throw exc::InvalidArgument("iter must > 0");
383
395
  iteration = iter;
384
396
  }
385
397
 
@@ -392,43 +404,54 @@ namespace tomoto
392
404
  return ret;
393
405
  }
394
406
 
395
- std::vector<Float> getSubTopicBySuperTopic(Tid k) const override
407
+ std::vector<Float> getSubTopicBySuperTopic(Tid k, bool normalize) const override
396
408
  {
397
409
  assert(k < this->K);
410
+ std::vector<Float> ret(K2);
398
411
  Float sum = this->globalState.numByTopic[k] + subAlphaSum[k];
399
- Eigen::Matrix<Float, -1, 1> ret = (this->globalState.numByTopic1_2.row(k).array().template cast<Float>() + subAlphas.row(k).array()) / sum;
400
- return { ret.data(), ret.data() + K2 };
412
+ if (!normalize) sum = 1;
413
+ Eigen::Map<Eigen::Array<Float, -1, 1>> m{ ret.data(), K2 };
414
+ m = (this->globalState.numByTopic1_2.row(k).array().template cast<Float>() + subAlphas.row(k).array()) / sum;
415
+ return ret;
401
416
  }
402
417
 
403
418
  std::vector<std::pair<Tid, Float>> getSubTopicBySuperTopicSorted(Tid k, size_t topN) const override
404
419
  {
405
- return extractTopN<Tid>(getSubTopicBySuperTopic(k), topN);
420
+ return extractTopN<Tid>(getSubTopicBySuperTopic(k, true), topN);
406
421
  }
407
422
 
408
- std::vector<Float> getSubTopicsByDoc(const _DocType& doc) const
423
+ std::vector<Float> getSubTopicsByDoc(const _DocType& doc, bool normalize) const
409
424
  {
410
425
  std::vector<Float> ret(K2);
411
- Eigen::Map<Eigen::Matrix<Float, -1, 1>> { ret.data(), K2 }.array() =
412
- ((doc.numByTopic1_2.array().template cast<Float>() + subAlphas.array()).colwise().sum()) / (doc.getSumWordWeight() + subAlphas.sum());
426
+ Eigen::Map<Eigen::Array<Float, -1, 1>> m{ ret.data(), K2 };
427
+ if (normalize)
428
+ {
429
+ m = ((doc.numByTopic1_2.array().template cast<Float>() + subAlphas.array()).colwise().sum()) / (doc.getSumWordWeight() + subAlphas.sum());
430
+ }
431
+ else
432
+ {
433
+ m = (doc.numByTopic1_2.array().template cast<Float>() + subAlphas.array()).colwise().sum();
434
+ }
413
435
  return ret;
414
436
  }
415
437
 
416
- std::vector<Float> getSubTopicsByDoc(const DocumentBase* doc) const override
438
+ std::vector<Float> getSubTopicsByDoc(const DocumentBase* doc, bool normalize) const override
417
439
  {
418
- return static_cast<const DerivedClass*>(this)->getSubTopicsByDoc(*static_cast<const _DocType*>(doc));
440
+ return static_cast<const DerivedClass*>(this)->getSubTopicsByDoc(*static_cast<const _DocType*>(doc), normalize);
419
441
  }
420
442
 
421
443
  std::vector<std::pair<Tid, Float>> getSubTopicsByDocSorted(const DocumentBase* doc, size_t topN) const override
422
444
  {
423
- return extractTopN<Tid>(getSubTopicsByDoc(doc), topN);
445
+ return extractTopN<Tid>(getSubTopicsByDoc(doc, true), topN);
424
446
  }
425
447
 
426
- std::vector<Float> _getWidsByTopic(Tid k2) const
448
+ std::vector<Float> _getWidsByTopic(Tid k2, bool normalize = true) const
427
449
  {
428
450
  assert(k2 < K2);
429
451
  const size_t V = this->realV;
430
452
  std::vector<Float> ret(V);
431
453
  Float sum = this->globalState.numByTopic2[k2] + V * this->eta;
454
+ if (!normalize) sum = 1;
432
455
  auto r = this->globalState.numByTopicWord.row(k2);
433
456
  for (size_t v = 0; v < V; ++v)
434
457
  {
@@ -439,10 +462,10 @@ namespace tomoto
439
462
 
440
463
  void setWordPrior(const std::string& word, const std::vector<Float>& priors) override
441
464
  {
442
- if (priors.size() != K2) THROW_ERROR_WITH_INFO(exception::InvalidArgument, "priors.size() must be equal to K2.");
465
+ if (priors.size() != K2) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "priors.size() must be equal to K2.");
443
466
  for (auto p : priors)
444
467
  {
445
- if (p < 0) THROW_ERROR_WITH_INFO(exception::InvalidArgument, "priors must not be less than 0.");
468
+ if (p < 0) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "priors must not be less than 0.");
446
469
  }
447
470
  this->dict.add(word);
448
471
  this->etaByWord.emplace(word, priors);
@@ -3,13 +3,24 @@
3
3
 
4
4
  namespace tomoto
5
5
  {
6
+ struct PLDAArgs : public LDAArgs
7
+ {
8
+ size_t numLatentTopics = 0;
9
+ size_t numTopicsPerLabel = 1;
10
+
11
+ PLDAArgs setK(size_t _k = 1) const
12
+ {
13
+ PLDAArgs ret = *this;
14
+ ret.k = _k;
15
+ return ret;
16
+ }
17
+ };
6
18
 
7
19
  class IPLDAModel : public ILLDAModel
8
20
  {
9
21
  public:
10
22
  using DefaultDocType = DocumentLLDA<TermWeight::one>;
11
- static IPLDAModel* create(TermWeight _weight, size_t _numLatentTopics = 0, size_t _numTopicsPerLabel = 1,
12
- Float alpha = 0.1, Float eta = 0.01, size_t seed = std::random_device{}(),
23
+ static IPLDAModel* create(TermWeight _weight, const PLDAArgs& args,
13
24
  bool scalarRng = false);
14
25
 
15
26
  virtual size_t getNumLatentTopics() const = 0;
@@ -2,12 +2,8 @@
2
2
 
3
3
  namespace tomoto
4
4
  {
5
- /*template class PLDAModel<TermWeight::one>;
6
- template class PLDAModel<TermWeight::idf>;
7
- template class PLDAModel<TermWeight::pmi>;*/
8
-
9
- IPLDAModel* IPLDAModel::create(TermWeight _weight, size_t _numLatentTopics, size_t _numTopicsPerLabel, Float _alpha, Float _eta, size_t seed, bool scalarRng)
5
+ IPLDAModel* IPLDAModel::create(TermWeight _weight, const PLDAArgs& args, bool scalarRng)
10
6
  {
11
- TMT_SWITCH_TW(_weight, scalarRng, PLDAModel, _numLatentTopics, _numTopicsPerLabel, _alpha, _eta, seed);
7
+ TMT_SWITCH_TW(_weight, scalarRng, PLDAModel, args);
12
8
  }
13
9
  }
@@ -75,14 +75,16 @@ namespace tomoto
75
75
 
76
76
  struct Generator
77
77
  {
78
- std::discrete_distribution<> theta;
78
+ Eigen::Array<Float, -1, 1> p;
79
+ Eigen::Rand::DiscreteGen<int32_t> theta;
79
80
  };
80
81
 
81
82
  Generator makeGeneratorForInit(const _DocType* doc) const
82
83
  {
83
- return Generator{
84
- std::discrete_distribution<>{ doc->labelMask.data(), doc->labelMask.data() + doc->labelMask.size() }
85
- };
84
+ Generator g;
85
+ g.p = doc->labelMask.array().template cast<Float>() * this->alphas.array();
86
+ g.theta = Eigen::Rand::DiscreteGen<int32_t>{ g.p.data(), g.p.data() + this->K };
87
+ return g;
86
88
  }
87
89
 
88
90
  template<bool _Infer>
@@ -93,7 +95,7 @@ namespace tomoto
93
95
  if (this->etaByTopicWord.size())
94
96
  {
95
97
  Eigen::Array<Float, -1, 1> col = this->etaByTopicWord.col(w);
96
- for (size_t k = 0; k < col.size(); ++k) col[k] *= g.theta.probabilities()[k];
98
+ col *= g.p;
97
99
  z = sample::sampleFromDiscrete(col.data(), col.data() + col.size(), rgs);
98
100
  }
99
101
  else
@@ -107,15 +109,14 @@ namespace tomoto
107
109
  DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, topicLabelDict, numLatentTopics, numTopicsPerLabel);
108
110
  DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, topicLabelDict, numLatentTopics, numTopicsPerLabel);
109
111
 
110
- PLDAModel(size_t _numLatentTopics = 0, size_t _numTopicsPerLabel = 1,
111
- Float _alpha = 1.0, Float _eta = 0.01, size_t _rg = std::random_device{}())
112
- : BaseClass(1, _alpha, _eta, _rg),
113
- numLatentTopics(_numLatentTopics), numTopicsPerLabel(_numTopicsPerLabel)
112
+ PLDAModel(const PLDAArgs& args)
113
+ : BaseClass(args.setK(1)),
114
+ numLatentTopics(args.numLatentTopics), numTopicsPerLabel(args.numTopicsPerLabel)
114
115
  {
115
- if (_numLatentTopics >= 0x80000000)
116
- THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong numLatentTopics value (numLatentTopics = %zd)", _numLatentTopics));
117
- if (_numTopicsPerLabel == 0 || _numTopicsPerLabel >= 0x80000000)
118
- THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong numTopicsPerLabel value (numTopicsPerLabel = %zd)", _numTopicsPerLabel));
116
+ if (numLatentTopics >= 0x80000000)
117
+ THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong numLatentTopics value (numLatentTopics = %zd)", numLatentTopics));
118
+ if (numTopicsPerLabel == 0 || numTopicsPerLabel >= 0x80000000)
119
+ THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong numTopicsPerLabel value (numTopicsPerLabel = %zd)", numTopicsPerLabel));
119
120
  }
120
121
 
121
122
  template<bool _const = false>
@@ -162,7 +163,7 @@ namespace tomoto
162
163
  std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) const override
163
164
  {
164
165
  auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc, tokenizer);
165
- return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
166
+ return std::make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
166
167
  }
167
168
 
168
169
  size_t addDoc(const RawDoc& rawDoc) override
@@ -174,16 +175,23 @@ namespace tomoto
174
175
  std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc) const override
175
176
  {
176
177
  auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc);
177
- return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
178
+ return std::make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
178
179
  }
179
180
 
180
- std::vector<Float> getTopicsByDoc(const _DocType& doc) const
181
+ std::vector<Float> getTopicsByDoc(const _DocType& doc, bool normalize) const
181
182
  {
182
183
  std::vector<Float> ret(this->K);
183
184
  auto maskedAlphas = this->alphas.array() * doc.labelMask.template cast<Float>().array();
184
- Eigen::Map<Eigen::Matrix<Float, -1, 1>> { ret.data(), this->K }.array() =
185
- (doc.numByTopic.array().template cast<Float>() + maskedAlphas)
186
- / (doc.getSumWordWeight() + maskedAlphas.sum());
185
+ Eigen::Map<Eigen::Array<Float, -1, 1>> m{ ret.data(), this->K };
186
+ if (normalize)
187
+ {
188
+ m = (doc.numByTopic.array().template cast<Float>() + maskedAlphas)
189
+ / (doc.getSumWordWeight() + maskedAlphas.sum());
190
+ }
191
+ else
192
+ {
193
+ m = doc.numByTopic.array().template cast<Float>() + maskedAlphas;
194
+ }
187
195
  return ret;
188
196
  }
189
197