tomoto 0.1.4 → 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/ext/tomoto/ct.cpp +8 -4
  4. data/ext/tomoto/dmr.cpp +10 -4
  5. data/ext/tomoto/dt.cpp +13 -4
  6. data/ext/tomoto/extconf.rb +1 -1
  7. data/ext/tomoto/gdmr.cpp +14 -6
  8. data/ext/tomoto/hdp.cpp +9 -4
  9. data/ext/tomoto/hlda.cpp +9 -4
  10. data/ext/tomoto/hpa.cpp +9 -4
  11. data/ext/tomoto/lda.cpp +8 -4
  12. data/ext/tomoto/llda.cpp +8 -4
  13. data/ext/tomoto/mglda.cpp +11 -1
  14. data/ext/tomoto/pa.cpp +9 -4
  15. data/ext/tomoto/plda.cpp +8 -4
  16. data/ext/tomoto/slda.cpp +13 -5
  17. data/lib/tomoto/gdmr.rb +2 -2
  18. data/lib/tomoto/version.rb +1 -1
  19. data/vendor/EigenRand/EigenRand/Core.h +6 -1107
  20. data/vendor/EigenRand/EigenRand/Dists/Basic.h +490 -43
  21. data/vendor/EigenRand/EigenRand/Dists/Discrete.h +916 -285
  22. data/vendor/EigenRand/EigenRand/Dists/GammaPoisson.h +85 -36
  23. data/vendor/EigenRand/EigenRand/Dists/NormalExp.h +1038 -290
  24. data/vendor/EigenRand/EigenRand/EigenRand +2 -2
  25. data/vendor/EigenRand/EigenRand/Macro.h +4 -4
  26. data/vendor/EigenRand/EigenRand/MorePacketMath.h +54 -22
  27. data/vendor/EigenRand/EigenRand/MvDists/Multinomial.h +222 -0
  28. data/vendor/EigenRand/EigenRand/MvDists/MvNormal.h +492 -0
  29. data/vendor/EigenRand/EigenRand/PacketFilter.h +2 -2
  30. data/vendor/EigenRand/EigenRand/PacketRandomEngine.h +2 -2
  31. data/vendor/EigenRand/EigenRand/RandUtils.h +65 -11
  32. data/vendor/EigenRand/EigenRand/doc.h +142 -25
  33. data/vendor/EigenRand/LICENSE +1 -1
  34. data/vendor/EigenRand/README.md +109 -24
  35. data/vendor/tomotopy/README.kr.rst +27 -6
  36. data/vendor/tomotopy/README.rst +29 -8
  37. data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +60 -12
  38. data/vendor/tomotopy/src/Labeling/FoRelevance.h +2 -2
  39. data/vendor/tomotopy/src/Labeling/Phraser.hpp +33 -21
  40. data/vendor/tomotopy/src/TopicModel/CT.h +8 -5
  41. data/vendor/tomotopy/src/TopicModel/CTModel.cpp +2 -6
  42. data/vendor/tomotopy/src/TopicModel/CTModel.hpp +29 -23
  43. data/vendor/tomotopy/src/TopicModel/DMR.h +33 -4
  44. data/vendor/tomotopy/src/TopicModel/DMRModel.cpp +2 -6
  45. data/vendor/tomotopy/src/TopicModel/DMRModel.hpp +231 -57
  46. data/vendor/tomotopy/src/TopicModel/DT.h +24 -5
  47. data/vendor/tomotopy/src/TopicModel/DTModel.cpp +2 -8
  48. data/vendor/tomotopy/src/TopicModel/DTModel.hpp +41 -28
  49. data/vendor/tomotopy/src/TopicModel/GDMR.h +31 -5
  50. data/vendor/tomotopy/src/TopicModel/GDMRModel.cpp +2 -7
  51. data/vendor/tomotopy/src/TopicModel/GDMRModel.hpp +211 -104
  52. data/vendor/tomotopy/src/TopicModel/HDP.h +11 -2
  53. data/vendor/tomotopy/src/TopicModel/HDPModel.cpp +2 -6
  54. data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +52 -45
  55. data/vendor/tomotopy/src/TopicModel/HLDA.h +11 -2
  56. data/vendor/tomotopy/src/TopicModel/HLDAModel.cpp +2 -6
  57. data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +13 -16
  58. data/vendor/tomotopy/src/TopicModel/HPA.h +5 -2
  59. data/vendor/tomotopy/src/TopicModel/HPAModel.cpp +2 -6
  60. data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +51 -21
  61. data/vendor/tomotopy/src/TopicModel/LDA.h +9 -2
  62. data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +8 -8
  63. data/vendor/tomotopy/src/TopicModel/LDAModel.cpp +2 -6
  64. data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +70 -28
  65. data/vendor/tomotopy/src/TopicModel/LLDA.h +1 -2
  66. data/vendor/tomotopy/src/TopicModel/LLDAModel.cpp +2 -6
  67. data/vendor/tomotopy/src/TopicModel/LLDAModel.hpp +22 -12
  68. data/vendor/tomotopy/src/TopicModel/MGLDA.h +12 -3
  69. data/vendor/tomotopy/src/TopicModel/MGLDAModel.cpp +2 -10
  70. data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +42 -19
  71. data/vendor/tomotopy/src/TopicModel/PA.h +9 -4
  72. data/vendor/tomotopy/src/TopicModel/PAModel.cpp +2 -6
  73. data/vendor/tomotopy/src/TopicModel/PAModel.hpp +48 -25
  74. data/vendor/tomotopy/src/TopicModel/PLDA.h +13 -2
  75. data/vendor/tomotopy/src/TopicModel/PLDAModel.cpp +2 -6
  76. data/vendor/tomotopy/src/TopicModel/PLDAModel.hpp +27 -19
  77. data/vendor/tomotopy/src/TopicModel/PT.h +12 -5
  78. data/vendor/tomotopy/src/TopicModel/PTModel.cpp +2 -3
  79. data/vendor/tomotopy/src/TopicModel/PTModel.hpp +29 -14
  80. data/vendor/tomotopy/src/TopicModel/SLDA.h +18 -6
  81. data/vendor/tomotopy/src/TopicModel/SLDAModel.cpp +2 -10
  82. data/vendor/tomotopy/src/TopicModel/SLDAModel.hpp +93 -43
  83. data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +58 -23
  84. data/vendor/tomotopy/src/Utils/AliasMethod.hpp +6 -6
  85. data/vendor/tomotopy/src/Utils/Dictionary.h +11 -0
  86. data/vendor/tomotopy/src/Utils/SharedString.hpp +26 -1
  87. data/vendor/tomotopy/src/Utils/Trie.hpp +46 -21
  88. data/vendor/tomotopy/src/Utils/Utils.hpp +99 -14
  89. data/vendor/tomotopy/src/Utils/exception.h +1 -1
  90. data/vendor/tomotopy/src/Utils/math.h +5 -7
  91. data/vendor/tomotopy/src/Utils/serializer.hpp +329 -201
  92. data/vendor/tomotopy/src/Utils/text.hpp +8 -0
  93. data/vendor/tomotopy/src/Utils/tvector.hpp +49 -7
  94. metadata +9 -7
@@ -14,19 +14,38 @@ namespace tomoto
14
14
  ShareableMatrix<Float, -1, 1> eta;
15
15
  sample::AliasMethod<> aliasTable;
16
16
 
17
+ RawDoc::MiscType makeMisc(const ITopicModel* tm) const override
18
+ {
19
+ RawDoc::MiscType ret = DocumentLDA<_tw>::makeMisc(tm);
20
+ ret["timepoint"] = (uint32_t)timepoint;
21
+ return ret;
22
+ }
23
+
17
24
  DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 0, timepoint);
18
25
  DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 1, 0x00010001, timepoint);
19
26
  };
20
27
 
28
+ struct DTArgs : public LDAArgs
29
+ {
30
+ size_t t = 1;
31
+ Float phi = 0.1;
32
+ Float shapeA = 0.01;
33
+ Float shapeB = 0.1;
34
+ Float shapeC = 0.55;
35
+ Float etaL2Reg = 0;
36
+
37
+ DTArgs()
38
+ {
39
+ alpha[0] = 0.1;
40
+ eta = 0.1;
41
+ }
42
+ };
43
+
21
44
  class IDTModel : public ILDAModel
22
45
  {
23
46
  public:
24
47
  using DefaultDocType = DocumentDTM<TermWeight::one>;
25
- static IDTModel* create(TermWeight _weight, size_t _K = 1, size_t _T = 1,
26
- Float _alphaVar = 1.0, Float _etaVar = 1.0, Float _phiVar = 1.0,
27
- Float _shapeA = 0.03, Float _shapeB = 0.1, Float _shapeC = 0.55,
28
- Float _etaRegL2 = 0,
29
- size_t seed = std::random_device{}(),
48
+ static IDTModel* create(TermWeight _weight, const DTArgs& args,
30
49
  bool scalarRng = false);
31
50
 
32
51
  virtual size_t getT() const = 0;
@@ -2,14 +2,8 @@
2
2
 
3
3
  namespace tomoto
4
4
  {
5
- /*template class DTModel<TermWeight::one>;
6
- template class DTModel<TermWeight::idf>;
7
- template class DTModel<TermWeight::pmi>;*/
8
-
9
- IDTModel* IDTModel::create(TermWeight _weight, size_t _K, size_t _T,
10
- Float _alphaVar, Float _etaVar, Float _phiVar,
11
- Float _shapeA, Float _shapeB, Float _shapeC, Float _etaRegL2, size_t seed, bool scalarRng)
5
+ IDTModel* IDTModel::create(TermWeight _weight, const DTArgs& args, bool scalarRng)
12
6
  {
13
- TMT_SWITCH_TW(_weight, scalarRng, DTModel, _K, _T, _alphaVar, _etaVar, _phiVar, _shapeA, _shapeB, _shapeC, _etaRegL2, seed);
7
+ TMT_SWITCH_TW(_weight, scalarRng, DTModel, args);
14
8
  }
15
9
  }
@@ -45,12 +45,12 @@ namespace tomoto
45
45
 
46
46
  uint64_t T;
47
47
  Float shapeA = 0.03f, shapeB = 0.1f, shapeC = 0.55f;
48
- const Float alphaVar = 1.f, etaVar = 1.f, phiVar = 1.f, etaRegL2 = 0.0f;
48
+ Float alphaVar = 1.f, etaVar = 1.f, phiVar = 1.f, etaRegL2 = 0.0f;
49
49
 
50
- Eigen::Matrix<Float, -1, -1> alphas; // Dim: (Topic, Time)
51
- Eigen::Matrix<Float, -1, -1> etaByDoc; // Dim: (Topic, Docs) : Topic distribution by docs(and time)
50
+ Matrix alphas; // Dim: (Topic, Time)
51
+ Matrix etaByDoc; // Dim: (Topic, Docs) : Topic distribution by docs(and time)
52
52
  std::vector<uint32_t> numDocsByTime; // Dim: (Time)
53
- Eigen::Matrix<Float, -1, -1> phi; // Dim: (Word, Topic * Time)
53
+ Matrix phi; // Dim: (Word, Topic * Time)
54
54
  std::vector<sample::AliasMethod<>> wordAliasTables; // Dim: (Word * Time)
55
55
 
56
56
  template<int _inc>
@@ -84,8 +84,8 @@ namespace tomoto
84
84
 
85
85
  // sampling eta
86
86
  {
87
- Eigen::Matrix<Float, -1, 1> estimatedCnt = (doc.eta.array() - doc.eta.maxCoeff()).exp();
88
- Eigen::Matrix<Float, -1, 1> etaTmp;
87
+ Vector estimatedCnt = (doc.eta.array() - doc.eta.maxCoeff()).exp();
88
+ Vector etaTmp;
89
89
  estimatedCnt *= doc.getSumWordWeight() / estimatedCnt.sum();
90
90
  auto prior = (alphas.col(doc.timepoint) - doc.eta) / std::max(etaVar, eps * 2);
91
91
  auto grad = doc.numByTopic.template cast<Float>() - estimatedCnt;
@@ -181,20 +181,21 @@ namespace tomoto
181
181
  template<typename _DocIter>
182
182
  void _sampleGlobalLevel(ThreadPool* pool, _ModelState*, _RandGen* rgs, _DocIter first, _DocIter last)
183
183
  {
184
+ if (!this->realV) return;
184
185
  const auto K = this->K;
185
186
  const Float eps = shapeA * (std::pow(shapeB + 1 + this->globalStep, -shapeC));
186
187
 
187
188
  // sampling phi
188
189
  for (size_t k = 0; k < K; ++k)
189
190
  {
190
- Eigen::Matrix<Float, -1, -1> phiGrad{ (Eigen::Index)this->realV, (Eigen::Index)T };
191
+ Matrix phiGrad{ (Eigen::Index)this->realV, (Eigen::Index)T };
191
192
  for (size_t t = 0; t < T; ++t)
192
193
  {
193
194
  auto phi_tk = phi.col(k + K * t);
194
- Eigen::Matrix<Float, -1, 1> estimatedCnt = (phi_tk.array() - phi_tk.maxCoeff()).exp();
195
+ Vector estimatedCnt = (phi_tk.array() - phi_tk.maxCoeff()).exp();
195
196
  estimatedCnt *= this->globalState.numByTopic(k, t) / estimatedCnt.sum();
196
197
 
197
- Eigen::Matrix<Float, -1, 1> grad = this->globalState.numByTopicWord.row(k + K * t).template cast<Float>();
198
+ Vector grad = this->globalState.numByTopicWord.row(k + K * t).template cast<Float>();
198
199
  grad -= estimatedCnt;
199
200
  auto epsNoise = Eigen::Rand::normal<Eigen::Array<Float, -1, 1>>(this->realV, 1, *rgs) * eps;
200
201
  if (t == 0)
@@ -228,7 +229,7 @@ namespace tomoto
228
229
  }
229
230
  }
230
231
 
231
- Eigen::Matrix<Float, -1, -1> newAlphas = Eigen::Matrix<Float, -1, -1>::Zero(alphas.rows(), alphas.cols());
232
+ Matrix newAlphas = Matrix::Zero(alphas.rows(), alphas.cols());
232
233
  for (size_t t = 0; t < T; ++t)
233
234
  {
234
235
  // update alias tables for word proposal
@@ -398,9 +399,9 @@ namespace tomoto
398
399
  this->globalState.numByTopic = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K, T);
399
400
  this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K * T, V);
400
401
 
401
- alphas = Eigen::Matrix<Float, -1, -1>::Zero(this->K, T);
402
- etaByDoc = Eigen::Matrix<Float, -1, -1>::Zero(this->K, this->docs.size());
403
- phi = Eigen::Matrix<Float, -1, -1>::Zero(this->realV, this->K * T);
402
+ alphas = Matrix::Zero(this->K, T);
403
+ etaByDoc = Matrix::Zero(this->K, this->docs.size());
404
+ phi = Matrix::Zero(this->realV, this->K * T);
404
405
  }
405
406
 
406
407
  numDocsByTime.resize(T);
@@ -418,7 +419,7 @@ namespace tomoto
418
419
 
419
420
  for (Tid t = 0; t < T; ++t)
420
421
  {
421
- if (initDocs && !numDocsByTime[t]) THROW_ERROR_WITH_INFO(exception::InvalidArgument, text::format("No document with timepoint = %d", t));
422
+ if (initDocs && !numDocsByTime[t]) THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("No document with timepoint = %d", t));
422
423
 
423
424
  // update alias tables for word proposal
424
425
  for (Vid v = 0; v < this->realV; ++v)
@@ -439,23 +440,26 @@ namespace tomoto
439
440
  addWordTo<1>(ld, doc, i, w, z);
440
441
  }
441
442
 
442
- std::vector<Float> _getWidsByTopic(size_t tid) const
443
+ std::vector<Float> _getWidsByTopic(size_t tid, bool normalize = true) const
443
444
  {
444
445
  const size_t V = this->realV;
445
446
  std::vector<Float> ret(V);
446
447
  Eigen::Map<Eigen::Array<Float, -1, 1>> retMap(ret.data(), V);
447
448
  retMap = phi.col(tid).array().exp();
448
- retMap /= retMap.sum();
449
- Eigen::Array<Float, -1, 1> t = this->globalState.numByTopicWord.row(tid).array().template cast<Float>();
450
- t /= std::max(t.sum(), (Float)0.1);
451
- retMap += t;
452
- retMap /= 2;
449
+ if (normalize)
450
+ {
451
+ retMap /= retMap.sum();
452
+ Eigen::Array<Float, -1, 1> t = this->globalState.numByTopicWord.row(tid).array().template cast<Float>();
453
+ t /= std::max(t.sum(), (Float)0.1);
454
+ retMap += t;
455
+ retMap /= 2;
456
+ }
453
457
  return ret;
454
458
  }
455
459
 
456
460
  _DocType& _updateDoc(_DocType& doc, uint32_t timepoint) const
457
461
  {
458
- if (timepoint >= T) THROW_ERROR_WITH_INFO(exception::InvalidArgument, "timepoint must < T");
462
+ if (timepoint >= T) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "timepoint must < T");
459
463
  doc.timepoint = timepoint;
460
464
  return doc;
461
465
  }
@@ -473,6 +477,16 @@ namespace tomoto
473
477
  return cnt;
474
478
  }
475
479
 
480
+ void updateForCopy()
481
+ {
482
+ BaseClass::updateForCopy();
483
+ size_t docId = 0;
484
+ for (auto& doc : this->docs)
485
+ {
486
+ doc.eta.init((Float*)etaByDoc.col(docId++).data(), this->K, 1);
487
+ }
488
+ }
489
+
476
490
  public:
477
491
  DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0,
478
492
  T, shapeA, shapeB, shapeC, alphaVar, etaVar, phiVar, alphas, etaByDoc, phi);
@@ -489,11 +503,10 @@ namespace tomoto
489
503
  GETTER(ShapeB, Float, shapeB);
490
504
  GETTER(ShapeC, Float, shapeC);
491
505
 
492
- DTModel(size_t _K, size_t _T, Float _alphaVar, Float _etaVar, Float _phiVar,
493
- Float _shapeA, Float _shapeB, Float _shapeC, Float _etaRegL2, size_t _rg)
494
- : BaseClass{ _K, _alphaVar, _etaVar, _rg },
495
- T{ _T }, alphaVar{ _alphaVar }, etaVar{ _etaVar }, phiVar{ _phiVar },
496
- shapeA{ _shapeA }, shapeB{ _shapeB }, shapeC{ _shapeC }, etaRegL2{ _etaRegL2 }
506
+ DTModel(const DTArgs& args)
507
+ : BaseClass{ args },
508
+ T{ args.t }, alphaVar{ args.alpha[0] }, etaVar{ args.eta }, phiVar{ args.phi },
509
+ shapeA{ args.shapeA }, shapeB{ args.shapeB }, shapeC{ args.shapeC }, etaRegL2{ args.etaL2Reg }
497
510
  {
498
511
  }
499
512
 
@@ -506,7 +519,7 @@ namespace tomoto
506
519
  std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) const override
507
520
  {
508
521
  auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc, tokenizer);
509
- return make_unique<_DocType>(_updateDoc(doc, rawDoc.template getMisc<uint32_t>("timepoint")));
522
+ return std::make_unique<_DocType>(_updateDoc(doc, rawDoc.template getMisc<uint32_t>("timepoint")));
510
523
  }
511
524
 
512
525
  size_t addDoc(const RawDoc& rawDoc) override
@@ -518,7 +531,7 @@ namespace tomoto
518
531
  std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc) const override
519
532
  {
520
533
  auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc);
521
- return make_unique<_DocType>(_updateDoc(doc, rawDoc.template getMisc<uint32_t>("timepoint")));
534
+ return std::make_unique<_DocType>(_updateDoc(doc, rawDoc.template getMisc<uint32_t>("timepoint")));
522
535
  }
523
536
 
524
537
  Float getAlpha(size_t k, size_t t) const override
@@ -10,26 +10,52 @@ namespace tomoto
10
10
  using DocumentDMR<_tw>::DocumentDMR;
11
11
  std::vector<Float> metadataOrg, metadataNormalized;
12
12
 
13
+ RawDoc::MiscType makeMisc(const ITopicModel* tm) const override
14
+ {
15
+ RawDoc::MiscType ret = DocumentDMR<_tw>::makeMisc(tm);
16
+ ret["numeric_metadata"] = metadataOrg;
17
+ return ret;
18
+ }
19
+
13
20
  DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 0, metadataOrg);
14
21
  DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 1, 0x00010001, metadataOrg, metadataNormalized);
15
22
  };
16
23
 
24
+ struct GDMRArgs : public DMRArgs
25
+ {
26
+ std::vector<uint64_t> degrees;
27
+ Float sigma0 = 3.0;
28
+ Float orderDecay = 0;
29
+ };
30
+
17
31
  class IGDMRModel : public IDMRModel
18
32
  {
19
33
  public:
20
34
  using DefaultDocType = DocumentDMR<TermWeight::one>;
21
- static IGDMRModel* create(TermWeight _weight, size_t _K = 1, const std::vector<uint64_t>& _degreeByF = {},
22
- Float defaultAlpha = 1.0, Float _sigma = 1.0, Float _sigma0 = 1.0, Float _eta = 0.01, Float _alphaEps = 1e-10,
23
- size_t seed = std::random_device{}(),
35
+ static IGDMRModel* create(TermWeight _weight, const GDMRArgs& args,
24
36
  bool scalarRng = false);
25
37
 
26
38
  virtual Float getSigma0() const = 0;
39
+ virtual Float getOrderDecay() const = 0;
27
40
  virtual void setSigma0(Float) = 0;
28
41
  virtual const std::vector<uint64_t>& getFs() const = 0;
29
42
  virtual std::vector<Float> getLambdaByTopic(Tid tid) const = 0;
30
43
 
31
- virtual std::vector<Float> getTDF(const Float* metadata, bool normalize) const = 0;
32
- virtual std::vector<Float> getTDFBatch(const Float* metadata, size_t stride, size_t cnt, bool normalize) const = 0;
44
+ virtual std::vector<Float> getTDF(
45
+ const Float* metadata,
46
+ const std::string& metadataCat,
47
+ const std::vector<std::string>& multiMetadataCat,
48
+ bool normalize
49
+ ) const = 0;
50
+
51
+ virtual std::vector<Float> getTDFBatch(
52
+ const Float* metadata,
53
+ const std::string& metadataCat,
54
+ const std::vector<std::string>& multiMetadataCat,
55
+ size_t stride,
56
+ size_t cnt,
57
+ bool normalize
58
+ ) const = 0;
33
59
 
34
60
  virtual void setMdRange(const std::vector<Float>& vMin, const std::vector<Float>& vMax) = 0;
35
61
  virtual void getMdRange(std::vector<Float>& vMin, std::vector<Float>& vMax) const = 0;
@@ -2,13 +2,8 @@
2
2
 
3
3
  namespace tomoto
4
4
  {
5
- /*template class GDMRModel<TermWeight::one>;
6
- template class GDMRModel<TermWeight::idf>;
7
- template class GDMRModel<TermWeight::pmi>;*/
8
-
9
- IGDMRModel* IGDMRModel::create(TermWeight _weight, size_t _K, const std::vector<uint64_t>& degreeByF,
10
- Float _defaultAlpha, Float _sigma, Float _sigma0, Float _eta, Float _alphaEps, size_t seed, bool scalarRng)
5
+ IGDMRModel* IGDMRModel::create(TermWeight _weight, const GDMRArgs& args, bool scalarRng)
11
6
  {
12
- TMT_SWITCH_TW(_weight, scalarRng, GDMRModel, _K, degreeByF, _defaultAlpha, _sigma, _sigma0, _eta, _alphaEps, seed);
7
+ TMT_SWITCH_TW(_weight, scalarRng, GDMRModel, args);
13
8
  }
14
9
  }
@@ -8,8 +8,8 @@ namespace tomoto
8
8
  template<TermWeight _tw>
9
9
  struct ModelStateGDMR : public ModelStateDMR<_tw>
10
10
  {
11
- /*Eigen::Matrix<Float, -1, 1> alphas;
12
- Eigen::Matrix<Float, -1, 1> terms;
11
+ /*Vector alphas;
12
+ Vector terms;
13
13
  std::vector<std::vector<Float>> slpCache;
14
14
  std::vector<size_t> ndimCnt;*/
15
15
  };
@@ -22,7 +22,8 @@ namespace tomoto
22
22
  typename _ModelState = ModelStateGDMR<_tw>>
23
23
  class GDMRModel : public DMRModel<_tw, _RandGen, _Flags, _Interface,
24
24
  typename std::conditional<std::is_same<_Derived, void>::value, GDMRModel<_tw, _RandGen>, _Derived>::type,
25
- _DocType, _ModelState>
25
+ _DocType, _ModelState
26
+ >
26
27
  {
27
28
  protected:
28
29
  using DerivedClass = typename std::conditional<std::is_same<_Derived, void>::value, GDMRModel<_tw, _RandGen>, _Derived>::type;
@@ -32,51 +33,60 @@ namespace tomoto
32
33
  friend typename BaseClass::BaseClass::BaseClass;
33
34
  using WeightType = typename BaseClass::WeightType;
34
35
 
35
- Float sigma0 = 3;
36
+ Float sigma0 = 3, orderDecay = 0;
36
37
  std::vector<Float> mdCoefs, mdIntercepts, mdMax;
37
38
  std::vector<uint64_t> degreeByF;
39
+ Eigen::Array<Float, -1, 1> orderDecayCached;
40
+ size_t fCont = 1;
38
41
 
39
- Float getIntegratedLambdaSq(const Eigen::Ref<const Eigen::Matrix<Float, -1, 1>, 0, Eigen::InnerStride<>>& lambdas) const
42
+ Float getIntegratedLambdaSq(const Eigen::Ref<const Vector, 0, Eigen::InnerStride<>>& lambdas) const
40
43
  {
41
- Float ret = pow(lambdas[0] - log(this->alpha), 2) / 2 / pow(this->sigma0, 2);
42
- for (size_t i = 1; i < this->F; ++i)
44
+ Float ret = 0;
45
+ for (size_t i = 0; i < this->F; ++i)
43
46
  {
44
- ret += pow(lambdas[i], 2) / 2 / pow(this->sigma, 2);
47
+ ret += pow(lambdas[this->mdVecSize * i] - log(this->alpha), 2) / 2 / pow(this->sigma0, 2);
48
+ ret += (lambdas.segment(this->mdVecSize * i + 1, fCont - 1).array().pow(2) / 2 * orderDecayCached.segment(1, fCont - 1) / pow(this->sigma, 2)).sum();
49
+ ret += lambdas.segment(this->mdVecSize * i + fCont, this->mdVecSize - fCont).array().pow(2).sum() / 2 / pow(this->sigma, 2);
45
50
  }
46
51
  return ret;
47
52
  }
48
53
 
49
- void getIntegratedLambdaSqP(const Eigen::Ref<const Eigen::Matrix<Float, -1, 1>, 0, Eigen::InnerStride<>>& lambdas,
50
- Eigen::Ref<Eigen::Matrix<Float, -1, 1>, 0, Eigen::InnerStride<>> ret) const
54
+ void getIntegratedLambdaSqP(const Eigen::Ref<const Vector, 0, Eigen::InnerStride<>>& lambdas,
55
+ Eigen::Ref<Vector, 0, Eigen::InnerStride<>> ret) const
51
56
  {
52
- ret[0] = (lambdas[0] - log(this->alpha)) / pow(this->sigma0, 2);
53
- for (size_t i = 1; i < this->F; ++i)
57
+ for (size_t i = 0; i < this->F; ++i)
54
58
  {
55
- ret[i] = lambdas[i] / pow(this->sigma, 2);
59
+ ret[this->mdVecSize * i] = (lambdas[this->mdVecSize * i] - log(this->alpha)) / pow(this->sigma0, 2);
60
+ ret.segment(this->mdVecSize * i + 1, fCont - 1) = lambdas.segment(this->mdVecSize * i + 1, fCont - 1).array() * orderDecayCached.segment(1, fCont - 1) / pow(this->sigma, 2);
61
+ ret.segment(this->mdVecSize * i + fCont, this->mdVecSize - fCont) = lambdas.segment(this->mdVecSize * i + fCont, this->mdVecSize - fCont).array() / pow(this->sigma, 2);
56
62
  }
57
63
  }
58
64
 
59
65
  void initParameters()
60
66
  {
61
- auto dist0 = std::normal_distribution<Float>(log(this->alpha), sigma0);
62
- auto dist = std::normal_distribution<Float>(0, this->sigma);
63
- for (size_t i = 0; i < this->K; ++i) for (size_t j = 0; j < this->F; ++j)
67
+ this->lambda = Eigen::Rand::normalLike(this->lambda, this->rg);
68
+
69
+ for (size_t i = 0; i < this->F; ++i)
64
70
  {
65
- if (j == 0)
71
+ this->lambda.col(this->mdVecSize * i).array() *= sigma0;
72
+ this->lambda.col(this->mdVecSize * i).array() += log(this->alphas.array());
73
+
74
+ for (size_t j = 1; j < fCont; ++j)
66
75
  {
67
- this->lambda(i, j) = dist0(this->rg);
76
+ this->lambda.col(this->mdVecSize * i + j).array() *= this->sigma / std::sqrt(orderDecayCached[j]);
68
77
  }
69
- else
78
+
79
+ for (size_t j = fCont; j < this->mdVecSize; ++j)
70
80
  {
71
- this->lambda(i, j) = dist(this->rg);
81
+ this->lambda.col(this->mdVecSize * i + j).array() *= this->sigma;
72
82
  }
73
83
  }
74
84
  }
75
85
 
76
- Float getNegativeLambdaLL(Eigen::Ref<Eigen::Matrix<Float, -1, 1>> x, Eigen::Matrix<Float, -1, 1>& g) const
86
+ Float getNegativeLambdaLL(Eigen::Ref<Vector> x, Vector& g) const
77
87
  {
78
- auto mappedX = Eigen::Map<Eigen::Matrix<Float, -1, -1>>(x.data(), this->K, this->F);
79
- auto mappedG = Eigen::Map<Eigen::Matrix<Float, -1, -1>>(g.data(), this->K, this->F);
88
+ auto mappedX = Eigen::Map<Matrix>(x.data(), this->K, this->F);
89
+ auto mappedG = Eigen::Map<Matrix>(g.data(), this->K, this->F);
80
90
 
81
91
  Float fx = 0;
82
92
  for (size_t k = 0; k < this->K; ++k)
@@ -87,50 +97,51 @@ namespace tomoto
87
97
  return fx;
88
98
  }
89
99
 
90
- Float evaluateLambdaObj(Eigen::Ref<Eigen::Matrix<Float, -1, 1>> x, Eigen::Matrix<Float, -1, 1>& g, ThreadPool& pool, _ModelState* localData) const
100
+ /*Float evaluateLambdaObj(Eigen::Ref<Vector> x, Vector& g, ThreadPool& pool, _ModelState* localData) const
91
101
  {
92
102
  // if one of x is greater than maxLambda, return +inf for preventing search more
93
103
  if ((x.array() > this->maxLambda).any()) return INFINITY;
94
104
 
95
105
  const auto K = this->K;
96
- const auto F = this->F;
106
+ const auto KF = this->K * this->F;
97
107
 
98
- auto mappedX = Eigen::Map<Eigen::Matrix<Float, -1, -1>>(x.data(), K, F);
108
+ auto mappedX = Eigen::Map<Matrix>(x.data(), K, this->F);
99
109
  Float fx = -static_cast<const DerivedClass*>(this)->getNegativeLambdaLL(x, g);
100
110
 
101
- std::vector<std::future<Eigen::Matrix<Float, -1, 1>>> res;
111
+ std::vector<std::future<Vector>> res;
102
112
  const size_t chStride = pool.getNumWorkers() * 8;
103
113
  for (size_t ch = 0; ch < chStride; ++ch)
104
114
  {
105
115
  res.emplace_back(pool.enqueue([&, this](size_t threadId)
106
116
  {
107
117
  auto& ld = localData[threadId];
108
- thread_local Eigen::Matrix<Float, -1, 1> alphas{ K }, tmpK{ K }, terms{ F };
109
- Eigen::Matrix<Float, -1, 1> ret = Eigen::Matrix<Float, -1, 1>::Zero(F * K + 1);
118
+ thread_local Vector alphas{ K }, tmpK{ K }, terms{ fCont };
119
+ Vector ret = Vector::Zero(KF + 1);
110
120
  for (size_t docId = ch; docId < this->docs.size(); docId += chStride)
111
121
  {
112
122
  const auto& doc = this->docs[docId];
113
123
  const auto& vx = doc.metadataNormalized;
124
+ size_t xOffset = doc.metadata * fCont;
114
125
  getTermsFromMd(&vx[0], terms.data());
115
126
  for (Tid k = 0; k < K; ++k)
116
127
  {
117
- alphas[k] = exp(mappedX.row(k) * terms) + this->alphaEps;
118
- ret[K * F] -= math::lgammaT(alphas[k]) - math::lgammaT(doc.numByTopic[k] + alphas[k]);
119
- assert(std::isfinite(ret[K * F]));
128
+ alphas[k] = exp(mappedX.row(k).segment(xOffset, fCont) * terms) + this->alphaEps;
129
+ ret[KF] -= math::lgammaT(alphas[k]) - math::lgammaT(doc.numByTopic[k] + alphas[k]);
130
+ assert(std::isfinite(ret[KF]));
120
131
  if (!std::isfinite(alphas[k]) && alphas[k] > 0) tmpK[k] = 0;
121
132
  else tmpK[k] = -(math::digammaT(alphas[k]) - math::digammaT(doc.numByTopic[k] + alphas[k]));
122
133
  }
123
134
  Float alphaSum = alphas.sum();
124
- ret[K * F] += math::lgammaT(alphaSum) - math::lgammaT(doc.getSumWordWeight() + alphaSum);
135
+ ret[KF] += math::lgammaT(alphaSum) - math::lgammaT(doc.getSumWordWeight() + alphaSum);
125
136
  Float t = math::digammaT(alphaSum) - math::digammaT(doc.getSumWordWeight() + alphaSum);
126
137
  if (!std::isfinite(alphaSum) && alphaSum > 0)
127
138
  {
128
- ret[K * F] = -INFINITY;
139
+ ret[KF] = -INFINITY;
129
140
  t = 0;
130
141
  }
131
- for (size_t f = 0; f < F; ++f)
142
+ for (size_t i = 0; i < fCont; ++i)
132
143
  {
133
- ret.segment(f * K, K).array() -= ((tmpK.array() + t) * alphas.array()) * terms[f];
144
+ ret.segment((i + xOffset) * K, K).array() -= ((tmpK.array() + t) * alphas.array()) * terms[i];
134
145
  }
135
146
  assert(ret.allFinite());
136
147
  }
@@ -140,14 +151,14 @@ namespace tomoto
140
151
  for (auto& r : res)
141
152
  {
142
153
  auto ret = r.get();
143
- fx += ret[K * F];
144
- g += ret.head(K * F);
154
+ fx += ret[KF];
155
+ g += ret.head(KF);
145
156
  }
146
157
 
147
158
  // positive fx is an error from limited precision of float.
148
159
  if (fx > 0) return INFINITY;
149
160
  return -fx;
150
- }
161
+ }*/
151
162
 
152
163
  void getTermsFromMd(const Float* vx, Float* out, bool normalize = false) const
153
164
  {
@@ -172,7 +183,7 @@ namespace tomoto
172
183
  }
173
184
  }
174
185
 
175
- for (size_t i = 0; i < this->F; ++i)
186
+ for (size_t i = 0; i < fCont; ++i)
176
187
  {
177
188
  out[i] = 1;
178
189
  for (size_t n = 0; n < degreeByF.size(); ++n)
@@ -180,47 +191,69 @@ namespace tomoto
180
191
  if(digit[n]) out[i] *= slpCache[n][digit[n] - 1];
181
192
  }
182
193
 
183
- size_t u;
184
- for (u = 0; u < digit.size() && ++digit[u] > degreeByF[u]; ++u)
194
+ for (size_t u = 0; u < digit.size() && ++digit[u] > degreeByF[u]; ++u)
185
195
  {
186
196
  digit[u] = 0;
187
197
  }
188
- u = std::min(u, degreeByF.size() - 1);
189
198
  }
190
199
  }
191
200
 
192
- template<bool _asymEta>
201
+ Eigen::Array<Float, -1, 1> calcOrderDecay() const
202
+ {
203
+ Eigen::Array<Float, -1, 1> ret{ fCont };
204
+ std::vector<size_t> digit(degreeByF.size());
205
+ std::fill(digit.begin(), digit.end(), 0);
206
+
207
+ for (size_t i = 0; i < fCont; ++i)
208
+ {
209
+ ret[i] = 1;
210
+ for (size_t n = 0; n < degreeByF.size(); ++n)
211
+ {
212
+ ret[i] *= pow(digit[n] + 1, orderDecay * 2);
213
+ }
214
+
215
+ for (size_t u = 0; u < digit.size() && ++digit[u] > degreeByF[u]; ++u)
216
+ {
217
+ digit[u] = 0;
218
+ }
219
+ }
220
+ return ret;
221
+ }
222
+
223
+ /*template<bool _asymEta>
193
224
  Float* getZLikelihoods(_ModelState& ld, const _DocType& doc, size_t docId, size_t vid) const
194
225
  {
195
226
  const size_t V = this->realV;
196
227
  assert(vid < V);
197
228
  auto etaHelper = this->template getEtaHelper<_asymEta>();
198
229
  auto& zLikelihood = ld.zLikelihood;
199
- thread_local Eigen::Matrix<Float, -1, 1> terms{ this->F };
230
+ thread_local Vector terms{ fCont };
231
+ size_t xOffset = doc.metadata * fCont;
200
232
  getTermsFromMd(&doc.metadataNormalized[0], terms.data());
201
- zLikelihood = (doc.numByTopic.array().template cast<Float>() + (this->lambda * terms).array().exp() + this->alphaEps)
233
+ zLikelihood = (doc.numByTopic.array().template cast<Float>() + (this->lambda.middleCols(xOffset, fCont) * terms).array().exp() + this->alphaEps)
202
234
  * (ld.numByTopicWord.col(vid).array().template cast<Float>() + etaHelper.getEta(vid))
203
235
  / (ld.numByTopic.array().template cast<Float>() + etaHelper.getEtaSum());
204
236
 
205
237
  sample::prefixSum(zLikelihood.data(), this->K);
206
238
  return &zLikelihood[0];
207
- }
239
+ }*/
208
240
 
209
- template<typename _DocIter>
241
+ /*template<typename _DocIter>
210
242
  double getLLDocs(_DocIter _first, _DocIter _last) const
211
243
  {
212
244
  const auto K = this->K;
213
245
  double ll = 0;
214
246
 
215
- Eigen::Matrix<Float, -1, 1> alphas(K);
247
+ Vector alphas(K);
216
248
  for (; _first != _last; ++_first)
217
249
  {
218
250
  auto& doc = *_first;
219
- thread_local Eigen::Matrix<Float, -1, 1> terms{ this->F };
251
+ thread_local Vector terms{ fCont };
220
252
  getTermsFromMd(&doc.metadataNormalized[0], terms.data());
253
+ size_t xOffset = doc.metadata * fCont;
221
254
  for (Tid k = 0; k < K; ++k)
222
255
  {
223
- alphas[k] = exp(this->lambda.row(k) * terms) + this->alphaEps;
256
+ alphas[k] = exp(this->lambda.row(k).segment(xOffset, fCont) * terms) + this->alphaEps;
224
257
  }
225
258
  Float alphaSum = alphas.sum();
226
259
  for (Tid k = 0; k < K; ++k)
@@ -231,7 +264,7 @@ namespace tomoto
231
264
  ll -= math::lgammaT(doc.getSumWordWeight() + alphaSum) - math::lgammaT(alphaSum);
232
265
  }
233
266
  return ll;
234
- }
267
+ }*/
235
268
 
236
269
  double getLLRest(const _ModelState& ld) const
237
270
  {
@@ -296,15 +329,44 @@ namespace tomoto
296
329
 
297
330
  void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
298
331
  {
299
- BaseClass::prepareDoc(doc, docId, wordSize);
332
+ BaseClass::BaseClass::prepareDoc(doc, docId, wordSize);
300
333
  doc.metadataNormalized = normalizeMetadata(doc.metadataOrg);
334
+
335
+ doc.mdVec = Vector::Zero(this->mdVecSize);
336
+ getTermsFromMd(doc.metadataNormalized.data(), doc.mdVec.data());
337
+ for (auto x : doc.multiMetadata)
338
+ {
339
+ doc.mdVec[fCont + x] = 1;
340
+ }
341
+
342
+ auto p = std::make_pair(doc.metadata, doc.mdVec);
343
+ auto it = this->mdHashMap.find(p);
344
+ if (it == this->mdHashMap.end())
345
+ {
346
+ it = this->mdHashMap.emplace(p, this->mdHashMap.size()).first;
347
+ }
348
+ doc.mdHash = it->second;
301
349
  }
302
350
 
303
351
  void initGlobalState(bool initDocs)
304
352
  {
305
353
  BaseClass::BaseClass::initGlobalState(initDocs);
306
- this->F = accumulate(degreeByF.begin(), degreeByF.end(), 1, [](size_t a, size_t b) {return a * (b + 1); });
307
- if (initDocs) collectMinMaxMetadata();
354
+ fCont = accumulate(degreeByF.begin(), degreeByF.end(), 1, [](size_t a, size_t b) {return a * (b + 1); });
355
+ if (!this->metadataDict.size())
356
+ {
357
+ this->metadataDict.add("");
358
+ }
359
+ this->F = this->metadataDict.size();
360
+ this->mdVecSize = fCont + this->multiMetadataDict.size();
361
+ if (initDocs)
362
+ {
363
+ collectMinMaxMetadata();
364
+ this->lambda = Matrix::Zero(this->K, this->F * this->mdVecSize);
365
+ for (size_t i = 0; i < this->F; ++i)
366
+ {
367
+ this->lambda.col(this->mdVecSize * i) = log(this->alphas.array());
368
+ }
369
+ }
308
370
  else
309
371
  {
310
372
  // Old binary file has metadataNormalized values into `metadataOrg`
@@ -320,13 +382,28 @@ namespace tomoto
320
382
  }
321
383
  }
322
384
  }
385
+
386
+ for (auto& doc : this->docs)
387
+ {
388
+ if (doc.mdVec.size() == this->mdVecSize) continue;
389
+ doc.mdVec = Vector::Zero(this->mdVecSize);
390
+ getTermsFromMd(doc.metadataNormalized.data(), doc.mdVec.data());
391
+ for (auto x : doc.multiMetadata)
392
+ {
393
+ doc.mdVec[fCont + x] = 1;
394
+ }
395
+
396
+ auto p = std::make_pair(doc.metadata, doc.mdVec);
397
+ auto it = this->mdHashMap.find(p);
398
+ if (it == this->mdHashMap.end())
399
+ {
400
+ it = this->mdHashMap.emplace(p, this->mdHashMap.size()).first;
401
+ }
402
+ doc.mdHash = it->second;
403
+ }
323
404
  }
324
405
 
325
- if (initDocs)
326
- {
327
- this->lambda = Eigen::Matrix<Float, -1, -1>::Zero(this->K, this->F);
328
- this->lambda.col(0).fill(log(this->alpha));
329
- }
406
+ orderDecayCached = calcOrderDecay();
330
407
  LBFGSpp::LBFGSParam<Float> param;
331
408
  param.max_iterations = this->maxBFGSIteration;
332
409
  this->solver = decltype(this->solver){ param };
@@ -334,18 +411,17 @@ namespace tomoto
334
411
 
335
412
  public:
336
413
  DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, sigma0, degreeByF, mdCoefs, mdIntercepts);
337
- DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, sigma0, degreeByF, mdCoefs, mdIntercepts, mdMax);
414
+ DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, sigma0, orderDecay, degreeByF, mdCoefs, mdIntercepts, mdMax);
338
415
 
339
- GDMRModel(size_t _K = 1, const std::vector<uint64_t>& _degreeByF = {},
340
- Float defaultAlpha = 1.0, Float _sigma = 1.0, Float _sigma0 = 1.0, Float _eta = 0.01,
341
- Float _alphaEps = 1e-10, size_t _rg = std::random_device{}())
342
- : BaseClass(_K, defaultAlpha, _sigma, _eta, _alphaEps, _rg), sigma0(_sigma0), degreeByF(_degreeByF)
416
+ GDMRModel(const GDMRArgs& args)
417
+ : BaseClass(args), sigma0(args.sigma0), orderDecay(args.orderDecay), degreeByF(args.degrees)
343
418
  {
344
- this->F = accumulate(degreeByF.begin(), degreeByF.end(), 1, [](size_t a, size_t b) {return a * (b + 1); });
419
+ fCont = accumulate(degreeByF.begin(), degreeByF.end(), 1, [](size_t a, size_t b) {return a * (b + 1); });
345
420
  }
346
421
 
347
422
  GETTER(Fs, const std::vector<uint64_t>&, degreeByF);
348
423
  GETTER(Sigma0, Float, sigma0);
424
+ GETTER(OrderDecay, Float, orderDecay);
349
425
 
350
426
  void setSigma0(Float _sigma0) override
351
427
  {
@@ -353,73 +429,94 @@ namespace tomoto
353
429
  }
354
430
 
355
431
  template<bool _const = false>
356
- _DocType& _updateDoc(_DocType& doc, const std::vector<Float>& metadata) const
432
+ _DocType& _updateDoc(_DocType& doc, const std::vector<Float>& metadata, const std::string& metadataCat = {}, const std::vector<std::string>& mdVec = {})
357
433
  {
358
434
  if (metadata.size() != degreeByF.size())
359
- throw std::invalid_argument{ "a length of `metadata` should be equal to a length of `degrees`" };
435
+ throw exc::InvalidArgument{ "a length of `metadata` should be equal to a length of `degrees`" };
360
436
  doc.metadataOrg = metadata;
437
+
438
+ Vid xid;
439
+ if (_const)
440
+ {
441
+ xid = this->metadataDict.toWid(metadataCat);
442
+ if (xid == non_vocab_id) throw exc::InvalidArgument("unknown metadata '" + metadataCat + "'");
443
+
444
+ for (auto& m : mdVec)
445
+ {
446
+ Vid x = this->multiMetadataDict.toWid(m);
447
+ if (x == non_vocab_id) throw exc::InvalidArgument("unknown multi_metadata '" + m + "'");
448
+ doc.multiMetadata.emplace_back(x);
449
+ }
450
+ }
451
+ else
452
+ {
453
+ xid = this->metadataDict.add(metadataCat);
454
+
455
+ for (auto& m : mdVec)
456
+ {
457
+ doc.multiMetadata.emplace_back(this->multiMetadataDict.add(m));
458
+ }
459
+ }
460
+ doc.metadata = xid;
361
461
  return doc;
362
462
  }
363
463
 
364
464
  size_t addDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) override
365
465
  {
366
466
  auto doc = this->template _makeFromRawDoc<false>(rawDoc, tokenizer);
367
- return this->_addDoc(_updateDoc(doc, rawDoc.template getMisc<std::vector<Float>>("metadata")));
467
+ return this->_addDoc(_updateDoc(doc,
468
+ rawDoc.template getMisc<std::vector<Float>>("numeric_metadata"),
469
+ rawDoc.template getMiscDefault<std::string>("metadata"),
470
+ rawDoc.template getMiscDefault<std::vector<std::string>>("multi_metadata")
471
+ ));
368
472
  }
369
473
 
370
474
  std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) const override
371
475
  {
372
476
  auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc, tokenizer);
373
- return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMisc<std::vector<Float>>("metadata")));
477
+ return std::make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc,
478
+ rawDoc.template getMisc<std::vector<Float>>("numeric_metadata"),
479
+ rawDoc.template getMiscDefault<std::string>("metadata"),
480
+ rawDoc.template getMiscDefault<std::vector<std::string>>("multi_metadata")
481
+ ));
374
482
  }
375
483
 
376
484
  size_t addDoc(const RawDoc& rawDoc) override
377
485
  {
378
486
  auto doc = this->_makeFromRawDoc(rawDoc);
379
- return this->_addDoc(_updateDoc(doc, rawDoc.template getMisc<std::vector<Float>>("metadata")));
487
+ return this->_addDoc(_updateDoc(doc,
488
+ rawDoc.template getMisc<std::vector<Float>>("numeric_metadata"),
489
+ rawDoc.template getMiscDefault<std::string>("metadata"),
490
+ rawDoc.template getMiscDefault<std::vector<std::string>>("multi_metadata")
491
+ ));
380
492
  }
381
493
 
382
494
  std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc) const override
383
495
  {
384
496
  auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc);
385
- return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMisc<std::vector<Float>>("metadata")));
497
+ return std::make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc,
498
+ rawDoc.template getMisc<std::vector<Float>>("numeric_metadata"),
499
+ rawDoc.template getMiscDefault<std::string>("metadata"),
500
+ rawDoc.template getMiscDefault<std::vector<std::string>>("multi_metadata")
501
+ ));
386
502
  }
387
503
 
388
- std::vector<Float> getTopicsByDoc(const _DocType& doc) const
504
+ std::vector<Float> getTDF(const Float* metadata, const std::string& metadataCat, const std::vector<std::string>& multiMetadataCat, bool normalize) const override
389
505
  {
390
- Eigen::Matrix<Float, -1, 1> alphas(this->K);
391
- thread_local Eigen::Matrix<Float, -1, 1> terms{ this->F };
392
- getTermsFromMd(&doc.metadataNormalized[0], terms.data());
393
- for (Tid k = 0; k < this->K; ++k)
394
- {
395
- alphas[k] = exp(this->lambda.row(k) * terms) + this->alphaEps;
396
- }
397
- std::vector<Float> ret(this->K);
398
- Float sum = doc.getSumWordWeight() + alphas.sum();
399
- for (size_t k = 0; k < this->K; ++k)
400
- {
401
- ret[k] = (doc.numByTopic[k] + alphas[k]) / sum;
402
- }
403
- return ret;
404
- }
405
-
406
- std::vector<Float> getLambdaByTopic(Tid tid) const override
407
- {
408
- std::vector<Float> ret(this->F);
409
- if (this->lambda.size())
506
+ Vector terms = Vector::Zero(this->mdVecSize);
507
+ getTermsFromMd(metadata, terms.data(), true);
508
+ for (auto& s : multiMetadataCat)
410
509
  {
411
- Eigen::Map<Eigen::Matrix<Float, -1, 1>>{ ret.data(), (Eigen::Index)ret.size() } = this->lambda.row(tid);
510
+ Vid x = this->multiMetadataDict.toWid(s);
511
+ if (x == non_vocab_id) throw exc::InvalidArgument("unknown multi_metadata " + text::quote(s));
512
+ terms[fCont + x] = 1;
412
513
  }
413
- return ret;
414
- }
514
+ Vid x = this->metadataDict.toWid(metadataCat);
515
+ if (x == non_vocab_id) throw exc::InvalidArgument("unknown metadata " + text::quote(metadataCat));
415
516
 
416
- std::vector<Float> getTDF(const Float* metadata, bool normalize) const override
417
- {
418
- Eigen::Matrix<Float, -1, 1> terms{ this->F };
419
- getTermsFromMd(metadata, terms.data(), true);
420
517
  std::vector<Float> ret(this->K);
421
518
  Eigen::Map<Eigen::Array<Float, -1, 1>> retMap{ ret.data(), (Eigen::Index)ret.size() };
422
- retMap = (this->lambda * terms).array();
519
+ retMap = (this->lambda.middleCols(x * this->mdVecSize, this->mdVecSize) * terms).array();
423
520
  if (normalize)
424
521
  {
425
522
  retMap = (retMap - retMap.maxCoeff()).exp();
@@ -428,16 +525,25 @@ namespace tomoto
428
525
  return ret;
429
526
  }
430
527
 
431
- std::vector<Float> getTDFBatch(const Float* metadata, size_t stride, size_t cnt, bool normalize) const override
528
+ std::vector<Float> getTDFBatch(const Float* metadata, const std::string& metadataCat, const std::vector<std::string>& multiMetadataCat, size_t stride, size_t cnt, bool normalize) const override
432
529
  {
433
- Eigen::Matrix<Float, -1, -1> terms{ this->F, (Eigen::Index)cnt };
530
+ Matrix terms = Matrix::Zero(this->mdVecSize, (Eigen::Index)cnt);
434
531
  for (size_t i = 0; i < cnt; ++i)
435
532
  {
436
533
  getTermsFromMd(metadata + stride * i, terms.col(i).data(), true);
437
534
  }
535
+ for (auto& s : multiMetadataCat)
536
+ {
537
+ Vid x = this->multiMetadataDict.toWid(s);
538
+ if (x == non_vocab_id) throw exc::InvalidArgument("unknown multi_metadata " + text::quote(s));
539
+ terms.row(fCont + x).setOnes();
540
+ }
541
+ Vid x = this->metadataDict.toWid(metadataCat);
542
+ if (x == non_vocab_id) throw exc::InvalidArgument("unknown metadata " + text::quote(metadataCat));
543
+
438
544
  std::vector<Float> ret(this->K * cnt);
439
545
  Eigen::Map<Eigen::Array<Float, -1, -1>> retMap{ ret.data(), (Eigen::Index)this->K, (Eigen::Index)cnt };
440
- retMap = (this->lambda * terms).array();
546
+ retMap = (this->lambda.middleCols(x * this->mdVecSize, this->mdVecSize) * terms).array();
441
547
  if (normalize)
442
548
  {
443
549
  retMap.rowwise() -= retMap.colwise().maxCoeff();
@@ -446,6 +552,7 @@ namespace tomoto
446
552
  }
447
553
  return ret;
448
554
  }
555
+
449
556
  void setMdRange(const std::vector<Float>& vMin, const std::vector<Float>& vMax) override
450
557
  {
451
558
  mdIntercepts = vMin;