tomoto 0.1.4 → 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/ext/tomoto/ct.cpp +8 -4
  4. data/ext/tomoto/dmr.cpp +10 -4
  5. data/ext/tomoto/dt.cpp +13 -4
  6. data/ext/tomoto/extconf.rb +1 -1
  7. data/ext/tomoto/gdmr.cpp +14 -6
  8. data/ext/tomoto/hdp.cpp +9 -4
  9. data/ext/tomoto/hlda.cpp +9 -4
  10. data/ext/tomoto/hpa.cpp +9 -4
  11. data/ext/tomoto/lda.cpp +8 -4
  12. data/ext/tomoto/llda.cpp +8 -4
  13. data/ext/tomoto/mglda.cpp +11 -1
  14. data/ext/tomoto/pa.cpp +9 -4
  15. data/ext/tomoto/plda.cpp +8 -4
  16. data/ext/tomoto/slda.cpp +13 -5
  17. data/lib/tomoto/gdmr.rb +2 -2
  18. data/lib/tomoto/version.rb +1 -1
  19. data/vendor/EigenRand/EigenRand/Core.h +6 -1107
  20. data/vendor/EigenRand/EigenRand/Dists/Basic.h +490 -43
  21. data/vendor/EigenRand/EigenRand/Dists/Discrete.h +916 -285
  22. data/vendor/EigenRand/EigenRand/Dists/GammaPoisson.h +85 -36
  23. data/vendor/EigenRand/EigenRand/Dists/NormalExp.h +1038 -290
  24. data/vendor/EigenRand/EigenRand/EigenRand +2 -2
  25. data/vendor/EigenRand/EigenRand/Macro.h +4 -4
  26. data/vendor/EigenRand/EigenRand/MorePacketMath.h +54 -22
  27. data/vendor/EigenRand/EigenRand/MvDists/Multinomial.h +222 -0
  28. data/vendor/EigenRand/EigenRand/MvDists/MvNormal.h +492 -0
  29. data/vendor/EigenRand/EigenRand/PacketFilter.h +2 -2
  30. data/vendor/EigenRand/EigenRand/PacketRandomEngine.h +2 -2
  31. data/vendor/EigenRand/EigenRand/RandUtils.h +65 -11
  32. data/vendor/EigenRand/EigenRand/doc.h +142 -25
  33. data/vendor/EigenRand/LICENSE +1 -1
  34. data/vendor/EigenRand/README.md +109 -24
  35. data/vendor/tomotopy/README.kr.rst +27 -6
  36. data/vendor/tomotopy/README.rst +29 -8
  37. data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +60 -12
  38. data/vendor/tomotopy/src/Labeling/FoRelevance.h +2 -2
  39. data/vendor/tomotopy/src/Labeling/Phraser.hpp +33 -21
  40. data/vendor/tomotopy/src/TopicModel/CT.h +8 -5
  41. data/vendor/tomotopy/src/TopicModel/CTModel.cpp +2 -6
  42. data/vendor/tomotopy/src/TopicModel/CTModel.hpp +29 -23
  43. data/vendor/tomotopy/src/TopicModel/DMR.h +33 -4
  44. data/vendor/tomotopy/src/TopicModel/DMRModel.cpp +2 -6
  45. data/vendor/tomotopy/src/TopicModel/DMRModel.hpp +231 -57
  46. data/vendor/tomotopy/src/TopicModel/DT.h +24 -5
  47. data/vendor/tomotopy/src/TopicModel/DTModel.cpp +2 -8
  48. data/vendor/tomotopy/src/TopicModel/DTModel.hpp +41 -28
  49. data/vendor/tomotopy/src/TopicModel/GDMR.h +31 -5
  50. data/vendor/tomotopy/src/TopicModel/GDMRModel.cpp +2 -7
  51. data/vendor/tomotopy/src/TopicModel/GDMRModel.hpp +211 -104
  52. data/vendor/tomotopy/src/TopicModel/HDP.h +11 -2
  53. data/vendor/tomotopy/src/TopicModel/HDPModel.cpp +2 -6
  54. data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +52 -45
  55. data/vendor/tomotopy/src/TopicModel/HLDA.h +11 -2
  56. data/vendor/tomotopy/src/TopicModel/HLDAModel.cpp +2 -6
  57. data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +13 -16
  58. data/vendor/tomotopy/src/TopicModel/HPA.h +5 -2
  59. data/vendor/tomotopy/src/TopicModel/HPAModel.cpp +2 -6
  60. data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +51 -21
  61. data/vendor/tomotopy/src/TopicModel/LDA.h +9 -2
  62. data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +8 -8
  63. data/vendor/tomotopy/src/TopicModel/LDAModel.cpp +2 -6
  64. data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +70 -28
  65. data/vendor/tomotopy/src/TopicModel/LLDA.h +1 -2
  66. data/vendor/tomotopy/src/TopicModel/LLDAModel.cpp +2 -6
  67. data/vendor/tomotopy/src/TopicModel/LLDAModel.hpp +22 -12
  68. data/vendor/tomotopy/src/TopicModel/MGLDA.h +12 -3
  69. data/vendor/tomotopy/src/TopicModel/MGLDAModel.cpp +2 -10
  70. data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +42 -19
  71. data/vendor/tomotopy/src/TopicModel/PA.h +9 -4
  72. data/vendor/tomotopy/src/TopicModel/PAModel.cpp +2 -6
  73. data/vendor/tomotopy/src/TopicModel/PAModel.hpp +48 -25
  74. data/vendor/tomotopy/src/TopicModel/PLDA.h +13 -2
  75. data/vendor/tomotopy/src/TopicModel/PLDAModel.cpp +2 -6
  76. data/vendor/tomotopy/src/TopicModel/PLDAModel.hpp +27 -19
  77. data/vendor/tomotopy/src/TopicModel/PT.h +12 -5
  78. data/vendor/tomotopy/src/TopicModel/PTModel.cpp +2 -3
  79. data/vendor/tomotopy/src/TopicModel/PTModel.hpp +29 -14
  80. data/vendor/tomotopy/src/TopicModel/SLDA.h +18 -6
  81. data/vendor/tomotopy/src/TopicModel/SLDAModel.cpp +2 -10
  82. data/vendor/tomotopy/src/TopicModel/SLDAModel.hpp +93 -43
  83. data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +58 -23
  84. data/vendor/tomotopy/src/Utils/AliasMethod.hpp +6 -6
  85. data/vendor/tomotopy/src/Utils/Dictionary.h +11 -0
  86. data/vendor/tomotopy/src/Utils/SharedString.hpp +26 -1
  87. data/vendor/tomotopy/src/Utils/Trie.hpp +46 -21
  88. data/vendor/tomotopy/src/Utils/Utils.hpp +99 -14
  89. data/vendor/tomotopy/src/Utils/exception.h +1 -1
  90. data/vendor/tomotopy/src/Utils/math.h +5 -7
  91. data/vendor/tomotopy/src/Utils/serializer.hpp +329 -201
  92. data/vendor/tomotopy/src/Utils/text.hpp +8 -0
  93. data/vendor/tomotopy/src/Utils/tvector.hpp +49 -7
  94. metadata +9 -7
@@ -56,12 +56,21 @@ namespace tomoto
56
56
  template<typename _TopicModel> void update(WeightType* ptr, const _TopicModel& mdl);
57
57
  };
58
58
 
59
+ struct HDPArgs : public LDAArgs
60
+ {
61
+ Float gamma = 0.1;
62
+
63
+ HDPArgs()
64
+ {
65
+ k = 2;
66
+ }
67
+ };
68
+
59
69
  class IHDPModel : public ILDAModel
60
70
  {
61
71
  public:
62
72
  using DefaultDocType = DocumentHDP<TermWeight::one>;
63
- static IHDPModel* create(TermWeight _weight, size_t _K = 1,
64
- Float alpha = 0.1, Float eta = 0.01, Float gamma = 0.1, size_t seed = std::random_device{}(),
73
+ static IHDPModel* create(TermWeight _weight, const HDPArgs& args,
65
74
  bool scalarRng = false);
66
75
 
67
76
  virtual Float getGamma() const = 0;
@@ -2,12 +2,8 @@
2
2
 
3
3
  namespace tomoto
4
4
  {
5
- /*template class HDPModel<TermWeight::one>;
6
- template class HDPModel<TermWeight::idf>;
7
- template class HDPModel<TermWeight::pmi>;*/
8
-
9
- IHDPModel* IHDPModel::create(TermWeight _weight, size_t _K, Float _alpha , Float _eta, Float _gamma, size_t seed, bool scalarRng)
5
+ IHDPModel* IHDPModel::create(TermWeight _weight, const HDPArgs& args, bool scalarRng)
10
6
  {
11
- TMT_SWITCH_TW(_weight, scalarRng, HDPModel, _K, _alpha, _eta, _gamma, seed);
7
+ TMT_SWITCH_TW(_weight, scalarRng, HDPModel, args);
12
8
  }
13
9
  }
@@ -14,7 +14,7 @@ namespace tomoto
14
14
  template<TermWeight _tw>
15
15
  struct ModelStateHDP : public ModelStateLDA<_tw>
16
16
  {
17
- Eigen::Matrix<Float, -1, 1> tableLikelihood, topicLikelihood;
17
+ Vector tableLikelihood, topicLikelihood;
18
18
  Eigen::Matrix<int32_t, -1, 1> numTableByTopic;
19
19
  size_t totalTable = 0;
20
20
 
@@ -397,58 +397,47 @@ namespace tomoto
397
397
 
398
398
  void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
399
399
  {
400
+ sortAndWriteOrder(doc.words, doc.wOrder);
400
401
  doc.numByTopic.init(nullptr, this->K, 1);
401
402
  doc.numTopicByTable.clear();
402
- doc.Zs = tvector<Tid>(wordSize);
403
+ doc.Zs = tvector<Tid>(wordSize, non_topic_id);
403
404
  if (_tw != TermWeight::one) doc.wordWeights.resize(wordSize);
404
405
  }
405
406
 
406
- template<bool _Infer>
407
+ template<bool _infer>
407
408
  void updateStateWithDoc(typename BaseClass::Generator& g, _ModelState& ld, _RandGen& rgs, _DocType& doc, size_t i) const
408
409
  {
409
- // generate tables for each topic when inferring
410
- if (_Infer)
410
+ Tid t;
411
+ std::vector<double> dist;
412
+ dist.emplace_back(this->alpha);
413
+ for (auto& d : doc.numTopicByTable) dist.emplace_back(d.num);
414
+ std::discrete_distribution<Tid> ddist{ dist.begin(), dist.end() };
415
+ t = ddist(rgs);
416
+ if (t == 0)
411
417
  {
412
- if (i < this->K)
418
+ // new table
419
+ Tid k;
420
+ if (_infer)
413
421
  {
414
- Tid t = i;
415
- if (isLiveTopic(i))
422
+ std::uniform_int_distribution<> theta{ 0, this->K - 1 };
423
+ do
416
424
  {
417
- t = doc.addNewTable(i);
418
- }
419
- else
420
- {
421
- t = std::uniform_int_distribution<size_t>{ 0, doc.getNumTable() - 1 }(rgs);
422
- }
423
- ++ld.numTableByTopic[doc.numTopicByTable[t].topic];
424
- ++ld.totalTable;
425
- doc.Zs[i] = t;
426
- }
427
- else doc.Zs[i] = std::uniform_int_distribution<size_t>{ 0, doc.getNumTable() - 1 }(rgs);
428
- }
429
- // generate tables following CRP
430
- else
431
- {
432
- Tid t;
433
- std::vector<double> dist;
434
- dist.emplace_back(this->alpha);
435
- for (auto& d : doc.numTopicByTable) dist.emplace_back(d.num);
436
- std::discrete_distribution<Tid> ddist{ dist.begin(), dist.end() };
437
- t = ddist(rgs);
438
- if (t == 0)
439
- {
440
- // new table
441
- Tid k = g.theta(rgs);
442
- t = doc.addNewTable(k);
443
- ++ld.numTableByTopic[k];
444
- ++ld.totalTable;
425
+ k = theta(rgs);
426
+ } while (!isLiveTopic(k));
445
427
  }
446
428
  else
447
429
  {
448
- t -= 1;
430
+ k = g.theta(rgs);
449
431
  }
450
- doc.Zs[i] = t;
432
+ t = doc.addNewTable(k);
433
+ ++ld.numTableByTopic[k];
434
+ ++ld.totalTable;
451
435
  }
436
+ else
437
+ {
438
+ t -= 1;
439
+ }
440
+ doc.Zs[i] = t;
452
441
  addWordTo<1>(ld, doc, i, doc.words[i], doc.Zs[i], doc.numTopicByTable[doc.Zs[i]].topic);
453
442
  }
454
443
 
@@ -469,10 +458,11 @@ namespace tomoto
469
458
  DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, gamma);
470
459
  DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, gamma);
471
460
 
472
- HDPModel(size_t initialK = 2, Float _alpha = 0.1, Float _eta = 0.01, Float _gamma = 0.1, size_t _rg = std::random_device{}())
473
- : BaseClass(initialK, _alpha, _eta, _rg), gamma(_gamma)
461
+ HDPModel(const HDPArgs& args)
462
+ : BaseClass(args), gamma(args.gamma)
474
463
  {
475
- if (_gamma <= 0) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong gamma value (gamma = %f)", _gamma));
464
+ if (gamma <= 0) THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong gamma value (gamma = %f)", gamma));
465
+ if (args.alpha.size() > 1) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "Asymmetric alpha is not supported at HDP.");
476
466
  }
477
467
 
478
468
  size_t getTotalTables() const override
@@ -497,13 +487,21 @@ namespace tomoto
497
487
 
498
488
  void setWordPrior(const std::string& word, const std::vector<Float>& priors) override
499
489
  {
500
- THROW_ERROR_WITH_INFO(exception::Unimplemented, "HDPModel doesn't provide setWordPrior function.");
490
+ THROW_ERROR_WITH_INFO(exc::Unimplemented, "HDPModel doesn't provide setWordPrior function.");
501
491
  }
502
492
 
503
- std::vector<Float> getTopicsByDoc(const _DocType& doc) const
493
+ std::vector<Float> getTopicsByDoc(const _DocType& doc, bool normalize) const
504
494
  {
505
495
  std::vector<Float> ret(this->K);
506
- Eigen::Map<Eigen::Matrix<Float, -1, 1>> { ret.data(), this->K }.array() = doc.numByTopic.array().template cast<Float>() / doc.getSumWordWeight();
496
+ Eigen::Map<Eigen::Array<Float, -1, 1>> m{ ret.data(), this->K };
497
+ if (normalize)
498
+ {
499
+ m = doc.numByTopic.array().template cast<Float>() / doc.getSumWordWeight();
500
+ }
501
+ else
502
+ {
503
+ m = doc.numByTopic.array().template cast<Float>();
504
+ }
507
505
  return ret;
508
506
  }
509
507
 
@@ -528,7 +526,11 @@ namespace tomoto
528
526
  liveK++;
529
527
  }
530
528
 
531
- auto lda = make_unique<LDAModel<_tw, _RandGen>>(liveK, 0.1f, this->eta);
529
+ LDAArgs args;
530
+ args.k = liveK;
531
+ args.alpha[0] = 0.1f;
532
+ args.eta = this->eta;
533
+ auto lda = std::make_unique<LDAModel<_tw, _RandGen>>(args);
532
534
  lda->dict = this->dict;
533
535
 
534
536
  for (auto& doc : this->docs)
@@ -551,6 +553,11 @@ namespace tomoto
551
553
  {
552
554
  for (size_t j = 0; j < this->docs[i].Zs.size(); ++j)
553
555
  {
556
+ if (this->docs[i].Zs[j] == non_topic_id)
557
+ {
558
+ lda->docs[i].Zs[j] = non_topic_id;
559
+ continue;
560
+ }
554
561
  size_t newTopic = newK[this->docs[i].numTopicByTable[this->docs[i].Zs[j]].topic];
555
562
  while (newTopic == (Tid)-1) newTopic = newK[randomTopic(rng)];
556
563
  lda->docs[i].Zs[j] = newTopic;
@@ -20,12 +20,21 @@ namespace tomoto
20
20
  DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 1, 0x00010001, path);
21
21
  };
22
22
 
23
+ struct HLDAArgs : public LDAArgs
24
+ {
25
+ Float gamma = 0.1;
26
+
27
+ HLDAArgs()
28
+ {
29
+ k = 2;
30
+ }
31
+ };
32
+
23
33
  class IHLDAModel : public ILDAModel
24
34
  {
25
35
  public:
26
36
  using DefaultDocType = DocumentHLDA<TermWeight::one>;
27
- static IHLDAModel* create(TermWeight _weight, size_t levelDepth = 1,
28
- Float alpha = 0.1, Float eta = 0.01, Float gamma = 0.1, size_t seed = std::random_device{}(),
37
+ static IHLDAModel* create(TermWeight _weight, const HLDAArgs& args,
29
38
  bool scalarRng = false);
30
39
 
31
40
  virtual Float getGamma() const = 0;
@@ -2,12 +2,8 @@
2
2
 
3
3
  namespace tomoto
4
4
  {
5
- /*template class HLDAModel<TermWeight::one>;
6
- template class HLDAModel<TermWeight::idf>;
7
- template class HLDAModel<TermWeight::pmi>;*/
8
-
9
- IHLDAModel* IHLDAModel::create(TermWeight _weight, size_t levelDepth, Float _alpha, Float _eta, Float _gamma, size_t seed, bool scalarRng)
5
+ IHLDAModel* IHLDAModel::create(TermWeight _weight, const HLDAArgs& args, bool scalarRng)
10
6
  {
11
- TMT_SWITCH_TW(_weight, scalarRng, HLDAModel, levelDepth, _alpha, _eta, _gamma, seed);
7
+ TMT_SWITCH_TW(_weight, scalarRng, HLDAModel, args);
12
8
  }
13
9
  }
@@ -114,8 +114,8 @@ namespace tomoto
114
114
  static constexpr size_t blockSize = 8;
115
115
  std::vector<NCRPNode> nodes;
116
116
  std::vector<uint8_t> levelBlocks;
117
- Eigen::Matrix<Float, -1, 1> nodeLikelihoods; //
118
- Eigen::Matrix<Float, -1, 1> nodeWLikelihoods; //
117
+ Vector nodeLikelihoods; //
118
+ Vector nodeWLikelihoods; //
119
119
 
120
120
  DEFINE_SERIALIZER(nodes, levelBlocks);
121
121
 
@@ -351,6 +351,8 @@ namespace tomoto
351
351
  template<GlobalSampler _gs>
352
352
  void samplePathes(_DocType& doc, ThreadPool* pool, _ModelState& ld, _RandGen& rgs) const
353
353
  {
354
+ if (!doc.getSumWordWeight()) return;
355
+
354
356
  if(_gs != GlobalSampler::inference) ld.nt->nodes[doc.path.back()].dropPathOne();
355
357
  ld.nt->template calcNodeLikelihood<_gs == GlobalSampler::train>(gamma, this->K);
356
358
 
@@ -433,7 +435,7 @@ namespace tomoto
433
435
  template<bool _asymEta>
434
436
  Float* getZLikelihoods(_ModelState& ld, const _DocType& doc, size_t docId, size_t vid) const
435
437
  {
436
- if (_asymEta) THROW_ERROR_WITH_INFO(exception::Unimplemented, "Unimplemented features");
438
+ if (_asymEta) THROW_ERROR_WITH_INFO(exc::Unimplemented, "Unimplemented features");
437
439
  const size_t V = this->realV;
438
440
  assert(vid < V);
439
441
  auto& zLikelihood = ld.zLikelihood;
@@ -461,7 +463,6 @@ namespace tomoto
461
463
  double getLLDocs(_DocIter _first, _DocIter _last) const
462
464
  {
463
465
  double ll = 0;
464
- auto lgammaAlpha = math::lgammaT(this->alpha);
465
466
  for (; _first != _last; ++_first)
466
467
  {
467
468
  auto& doc = *_first;
@@ -472,13 +473,9 @@ namespace tomoto
472
473
  }
473
474
 
474
475
  // doc-level distribution
475
- ll -= math::lgammaT(doc.getSumWordWeight() + this->alpha * this->K);
476
- for (Tid l = 0; l < this->K; ++l)
477
- {
478
- ll += math::lgammaT(doc.numByTopic[l] + this->alpha) - lgammaAlpha;
479
- }
476
+ ll -= math::lgammaSubt(this->alphas.sum(), doc.getSumWordWeight());
477
+ ll += math::lgammaSubt(this->alphas.array(), doc.numByTopic.template cast<Float>().array()).sum();
480
478
  }
481
- ll += math::lgammaT(this->alpha * this->K) * std::distance(_first, _last);
482
479
  return ll;
483
480
  }
484
481
 
@@ -521,7 +518,7 @@ namespace tomoto
521
518
  {
522
519
  sortAndWriteOrder(doc.words, doc.wOrder);
523
520
  doc.numByTopic.init(nullptr, this->K, 1);
524
- doc.Zs = tvector<Tid>(wordSize);
521
+ doc.Zs = tvector<Tid>(wordSize, non_topic_id);
525
522
  doc.path.resize(this->K);
526
523
  for (size_t l = 0; l < this->K; ++l) doc.path[l] = l;
527
524
 
@@ -597,11 +594,11 @@ namespace tomoto
597
594
  DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, gamma);
598
595
  DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, gamma);
599
596
 
600
- HLDAModel(size_t _levelDepth = 4, Float _alpha = 0.1, Float _eta = 0.01, Float _gamma = 0.1, size_t _rg = std::random_device{}())
601
- : BaseClass(_levelDepth, _alpha, _eta, _rg), gamma(_gamma)
597
+ HLDAModel(const HLDAArgs& args)
598
+ : BaseClass(args), gamma(args.gamma)
602
599
  {
603
- if (_levelDepth == 0 || _levelDepth >= 0x80000000) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong levelDepth value (levelDepth = %zd)", _levelDepth));
604
- if (_gamma <= 0) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong gamma value (gamma = %f)", _gamma));
600
+ if (args.k == 0 || args.k >= 0x80000000) THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong levelDepth value (levelDepth = %zd)", args.k));
601
+ if (gamma <= 0) THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong gamma value (gamma = %f)", gamma));
605
602
  this->globalState.nt = std::make_shared<detail::NodeTrees>();
606
603
  }
607
604
 
@@ -661,7 +658,7 @@ namespace tomoto
661
658
 
662
659
  void setWordPrior(const std::string& word, const std::vector<Float>& priors) override
663
660
  {
664
- THROW_ERROR_WITH_INFO(exception::Unimplemented, "HLDAModel doesn't provide setWordPrior function.");
661
+ THROW_ERROR_WITH_INFO(exc::Unimplemented, "HLDAModel doesn't provide setWordPrior function.");
665
662
  }
666
663
  };
667
664
 
@@ -16,12 +16,15 @@ namespace tomoto
16
16
  DEFINE_SERIALIZER_BASE_WITH_VERSION(BaseDocument, 1);
17
17
  };
18
18
 
19
+ struct HPAArgs : public PAArgs
20
+ {
21
+ };
22
+
19
23
  class IHPAModel : public IPAModel
20
24
  {
21
25
  public:
22
26
  using DefaultDocType = DocumentHPA<TermWeight::one>;
23
- static IHPAModel* create(TermWeight _weight, bool _exclusive = false, size_t _K1 = 1, size_t _K2 = 1,
24
- Float _alpha = 50, Float _eta = 0.01, size_t seed = std::random_device{}(),
27
+ static IHPAModel* create(TermWeight _weight, bool _exclusive, const HPAArgs& args,
25
28
  bool scalarRng = false);
26
29
  };
27
30
  }
@@ -2,11 +2,7 @@
2
2
 
3
3
  namespace tomoto
4
4
  {
5
- /*template class HPAModel<TermWeight::one>;
6
- template class HPAModel<TermWeight::idf>;
7
- template class HPAModel<TermWeight::pmi>;*/
8
-
9
- IHPAModel* IHPAModel::create(TermWeight _weight, bool _exclusive, size_t _K, size_t _K2, Float _alphaSum, Float _eta, size_t seed, bool scalarRng)
5
+ IHPAModel* IHPAModel::create(TermWeight _weight, bool _exclusive, const HPAArgs& args, bool scalarRng)
10
6
  {
11
7
  if (_exclusive)
12
8
  {
@@ -14,7 +10,7 @@ namespace tomoto
14
10
  }
15
11
  else
16
12
  {
17
- TMT_SWITCH_TW(_weight, scalarRng, HPAModel, _K, _K2, _alphaSum, _eta, seed);
13
+ TMT_SWITCH_TW(_weight, scalarRng, HPAModel, args);
18
14
  }
19
15
  return nullptr;
20
16
  }
@@ -16,7 +16,7 @@ namespace tomoto
16
16
 
17
17
  std::array<Eigen::Matrix<WeightType, -1, -1>, 3> numByTopicWord;
18
18
  std::array<Eigen::Matrix<WeightType, -1, 1>, 3> numByTopic;
19
- std::array<Eigen::Matrix<Float, -1, 1>, 2> subTmp;
19
+ std::array<Vector, 2> subTmp;
20
20
 
21
21
  Eigen::Matrix<WeightType, -1, -1> numByTopic1_2;
22
22
 
@@ -45,10 +45,10 @@ namespace tomoto
45
45
  Float epsilon = 0.00001;
46
46
  size_t iteration = 5;
47
47
 
48
- //Eigen::Matrix<Float, -1, 1> alphas; // len = (K + 1)
48
+ //Vector alphas; // len = (K + 1)
49
49
 
50
- Eigen::Matrix<Float, -1, 1> subAlphaSum; // len = K
51
- Eigen::Matrix<Float, -1, -1> subAlphas; // len = K * (K2 + 1)
50
+ Vector subAlphaSum; // len = K
51
+ Matrix subAlphas; // len = K * (K2 + 1)
52
52
 
53
53
  void optimizeParameters(ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
54
54
  {
@@ -195,7 +195,7 @@ namespace tomoto
195
195
  Float* dist;
196
196
  if (this->etaByTopicWord.size())
197
197
  {
198
- THROW_ERROR_WITH_INFO(exception::Unimplemented, "Unimplemented features");
198
+ THROW_ERROR_WITH_INFO(exc::Unimplemented, "Unimplemented features");
199
199
  }
200
200
  else
201
201
  {
@@ -379,7 +379,7 @@ namespace tomoto
379
379
  void initGlobalState(bool initDocs)
380
380
  {
381
381
  const size_t V = this->realV;
382
- this->globalState.zLikelihood = Eigen::Matrix<Float, -1, 1>::Zero(1 + this->K + this->K * K2);
382
+ this->globalState.zLikelihood = Vector::Zero(1 + this->K + this->K * K2);
383
383
  if (initDocs)
384
384
  {
385
385
  this->globalState.numByTopic1_2 = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K, K2 + 1);
@@ -440,13 +440,37 @@ namespace tomoto
440
440
  DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, K2, subAlphas, subAlphaSum);
441
441
  DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, K2, subAlphas, subAlphaSum);
442
442
 
443
- HPAModel(size_t _K1 = 1, size_t _K2 = 1, Float _alpha = 0.1, Float _eta = 0.01, size_t _rg = std::random_device{}())
444
- : BaseClass(_K1, _alpha, _eta, _rg), K2(_K2)
443
+ HPAModel(const HPAArgs& args)
444
+ : BaseClass(args, false), K2(args.k2)
445
445
  {
446
- if (_K2 == 0 || _K2 >= 0x80000000) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong K2 value (K2 = %zd)", _K2));
447
- this->alphas = Eigen::Matrix<Float, -1, 1>::Constant(_K1 + 1, _alpha);
448
- subAlphas = Eigen::Matrix<Float, -1, -1>::Constant(_K1, _K2 + 1, _alpha);
449
- subAlphaSum = Eigen::Matrix<Float, -1, 1>::Constant(_K1, (_K2 + 1) * _alpha);
446
+ if (K2 == 0 || K2 >= 0x80000000) THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong K2 value (K2 = %zd)", K2));
447
+
448
+ if (args.alpha.size() == 1)
449
+ {
450
+ this->alphas = Vector::Constant(args.k + 1, args.alpha[0]);
451
+ }
452
+ else if (args.alpha.size() == args.k + 1)
453
+ {
454
+ this->alphas = Eigen::Map<const Vector>(args.alpha.data(), (Eigen::Index)args.alpha.size());
455
+ }
456
+ else
457
+ {
458
+ THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong alpha value (len = %zd)", args.alpha.size()));
459
+ }
460
+
461
+ if (args.subalpha.size() == 1)
462
+ {
463
+ subAlphas = Matrix::Constant(args.k, args.k2 + 1, args.subalpha[0]);
464
+ }
465
+ else if (args.subalpha.size() == args.k2 + 1)
466
+ {
467
+ subAlphas = Eigen::Map<const Eigen::Matrix<Float, 1, -1>>(args.subalpha.data(), args.subalpha.size()).replicate(args.k, 1);
468
+ }
469
+ else
470
+ {
471
+ THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong subalpha value (len = %zd)", args.subalpha.size()));
472
+ }
473
+ subAlphaSum = subAlphas.rowwise().sum();
450
474
  this->optimInterval = 1;
451
475
  }
452
476
 
@@ -455,7 +479,7 @@ namespace tomoto
455
479
 
456
480
  void setDirichletEstIteration(size_t iter) override
457
481
  {
458
- if (!iter) throw std::invalid_argument("iter must > 0");
482
+ if (!iter) throw exc::InvalidArgument("iter must > 0");
459
483
  iteration = iter;
460
484
  }
461
485
 
@@ -475,20 +499,23 @@ namespace tomoto
475
499
  return ret;
476
500
  }
477
501
 
478
- std::vector<Float> getSubTopicBySuperTopic(Tid k) const override
502
+ std::vector<Float> getSubTopicBySuperTopic(Tid k, bool normalize) const override
479
503
  {
504
+ std::vector<Float> ret(K2);
480
505
  assert(k < this->K);
481
506
  Float sum = this->globalState.numByTopic1_2.row(k).sum() + subAlphaSum[k];
482
- Eigen::Matrix<Float, -1, 1> ret = (this->globalState.numByTopic1_2.row(k).array().template cast<Float>() + subAlphas.row(k).array()) / sum;
483
- return { ret.data() + 1, ret.data() + K2 + 1 };
507
+ if (!normalize) sum = 1;
508
+ Eigen::Map<Eigen::Array<Float, -1, 1>> m{ ret.data(), (Eigen::Index)K2 };
509
+ m = (this->globalState.numByTopic1_2.row(k).segment(1, K2).array().template cast<Float>() + subAlphas.row(k).segment(1, K2).array()) / sum;
510
+ return ret;
484
511
  }
485
512
 
486
513
  std::vector<std::pair<Tid, Float>> getSubTopicBySuperTopicSorted(Tid k, size_t topN) const override
487
514
  {
488
- return extractTopN<Tid>(getSubTopicBySuperTopic(k), topN);
515
+ return extractTopN<Tid>(getSubTopicBySuperTopic(k, true), topN);
489
516
  }
490
517
 
491
- std::vector<Float> _getWidsByTopic(Tid k) const
518
+ std::vector<Float> _getWidsByTopic(Tid k, bool normalize = true) const
492
519
  {
493
520
  const size_t V = this->realV;
494
521
  std::vector<Float> ret(V);
@@ -504,6 +531,7 @@ namespace tomoto
504
531
  }
505
532
  }
506
533
  Float sum = this->globalState.numByTopic[level][k] + V * this->eta;
534
+ if (!normalize) sum = 1;
507
535
  auto r = this->globalState.numByTopicWord[level].row(k);
508
536
  for (size_t v = 0; v < V; ++v)
509
537
  {
@@ -512,10 +540,12 @@ namespace tomoto
512
540
  return ret;
513
541
  }
514
542
 
515
- std::vector<Float> getTopicsByDoc(const _DocType& doc) const
543
+ std::vector<Float> getTopicsByDoc(const _DocType& doc, bool normalize) const
516
544
  {
517
545
  std::vector<Float> ret(1 + this->K + K2);
518
546
  Float sum = doc.getSumWordWeight() + this->alphas.sum();
547
+ if (!normalize) sum = 1;
548
+
519
549
  ret[0] = (doc.numByTopic[0] + this->alphas[0]) / sum;
520
550
  for (size_t k = 0; k < this->K; ++k)
521
551
  {
@@ -528,7 +558,7 @@ namespace tomoto
528
558
  return ret;
529
559
  }
530
560
 
531
- std::vector<Float> getSubTopicsByDoc(const DocumentBase* doc) const override
561
+ std::vector<Float> getSubTopicsByDoc(const DocumentBase* doc, bool normalize) const override
532
562
  {
533
563
  throw std::runtime_error{ "not applicable" };
534
564
  }
@@ -540,7 +570,7 @@ namespace tomoto
540
570
 
541
571
  void setWordPrior(const std::string& word, const std::vector<Float>& priors) override
542
572
  {
543
- THROW_ERROR_WITH_INFO(exception::Unimplemented, "HPAModel doesn't provide setWordPrior function.");
573
+ THROW_ERROR_WITH_INFO(exc::Unimplemented, "HPAModel doesn't provide setWordPrior function.");
544
574
  }
545
575
 
546
576
  std::vector<uint64_t> getCountBySuperTopic() const override