tomoto 0.1.4 → 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94):
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/ext/tomoto/ct.cpp +8 -4
  4. data/ext/tomoto/dmr.cpp +10 -4
  5. data/ext/tomoto/dt.cpp +13 -4
  6. data/ext/tomoto/extconf.rb +1 -1
  7. data/ext/tomoto/gdmr.cpp +14 -6
  8. data/ext/tomoto/hdp.cpp +9 -4
  9. data/ext/tomoto/hlda.cpp +9 -4
  10. data/ext/tomoto/hpa.cpp +9 -4
  11. data/ext/tomoto/lda.cpp +8 -4
  12. data/ext/tomoto/llda.cpp +8 -4
  13. data/ext/tomoto/mglda.cpp +11 -1
  14. data/ext/tomoto/pa.cpp +9 -4
  15. data/ext/tomoto/plda.cpp +8 -4
  16. data/ext/tomoto/slda.cpp +13 -5
  17. data/lib/tomoto/gdmr.rb +2 -2
  18. data/lib/tomoto/version.rb +1 -1
  19. data/vendor/EigenRand/EigenRand/Core.h +6 -1107
  20. data/vendor/EigenRand/EigenRand/Dists/Basic.h +490 -43
  21. data/vendor/EigenRand/EigenRand/Dists/Discrete.h +916 -285
  22. data/vendor/EigenRand/EigenRand/Dists/GammaPoisson.h +85 -36
  23. data/vendor/EigenRand/EigenRand/Dists/NormalExp.h +1038 -290
  24. data/vendor/EigenRand/EigenRand/EigenRand +2 -2
  25. data/vendor/EigenRand/EigenRand/Macro.h +4 -4
  26. data/vendor/EigenRand/EigenRand/MorePacketMath.h +54 -22
  27. data/vendor/EigenRand/EigenRand/MvDists/Multinomial.h +222 -0
  28. data/vendor/EigenRand/EigenRand/MvDists/MvNormal.h +492 -0
  29. data/vendor/EigenRand/EigenRand/PacketFilter.h +2 -2
  30. data/vendor/EigenRand/EigenRand/PacketRandomEngine.h +2 -2
  31. data/vendor/EigenRand/EigenRand/RandUtils.h +65 -11
  32. data/vendor/EigenRand/EigenRand/doc.h +142 -25
  33. data/vendor/EigenRand/LICENSE +1 -1
  34. data/vendor/EigenRand/README.md +109 -24
  35. data/vendor/tomotopy/README.kr.rst +27 -6
  36. data/vendor/tomotopy/README.rst +29 -8
  37. data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +60 -12
  38. data/vendor/tomotopy/src/Labeling/FoRelevance.h +2 -2
  39. data/vendor/tomotopy/src/Labeling/Phraser.hpp +33 -21
  40. data/vendor/tomotopy/src/TopicModel/CT.h +8 -5
  41. data/vendor/tomotopy/src/TopicModel/CTModel.cpp +2 -6
  42. data/vendor/tomotopy/src/TopicModel/CTModel.hpp +29 -23
  43. data/vendor/tomotopy/src/TopicModel/DMR.h +33 -4
  44. data/vendor/tomotopy/src/TopicModel/DMRModel.cpp +2 -6
  45. data/vendor/tomotopy/src/TopicModel/DMRModel.hpp +231 -57
  46. data/vendor/tomotopy/src/TopicModel/DT.h +24 -5
  47. data/vendor/tomotopy/src/TopicModel/DTModel.cpp +2 -8
  48. data/vendor/tomotopy/src/TopicModel/DTModel.hpp +41 -28
  49. data/vendor/tomotopy/src/TopicModel/GDMR.h +31 -5
  50. data/vendor/tomotopy/src/TopicModel/GDMRModel.cpp +2 -7
  51. data/vendor/tomotopy/src/TopicModel/GDMRModel.hpp +211 -104
  52. data/vendor/tomotopy/src/TopicModel/HDP.h +11 -2
  53. data/vendor/tomotopy/src/TopicModel/HDPModel.cpp +2 -6
  54. data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +52 -45
  55. data/vendor/tomotopy/src/TopicModel/HLDA.h +11 -2
  56. data/vendor/tomotopy/src/TopicModel/HLDAModel.cpp +2 -6
  57. data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +13 -16
  58. data/vendor/tomotopy/src/TopicModel/HPA.h +5 -2
  59. data/vendor/tomotopy/src/TopicModel/HPAModel.cpp +2 -6
  60. data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +51 -21
  61. data/vendor/tomotopy/src/TopicModel/LDA.h +9 -2
  62. data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +8 -8
  63. data/vendor/tomotopy/src/TopicModel/LDAModel.cpp +2 -6
  64. data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +70 -28
  65. data/vendor/tomotopy/src/TopicModel/LLDA.h +1 -2
  66. data/vendor/tomotopy/src/TopicModel/LLDAModel.cpp +2 -6
  67. data/vendor/tomotopy/src/TopicModel/LLDAModel.hpp +22 -12
  68. data/vendor/tomotopy/src/TopicModel/MGLDA.h +12 -3
  69. data/vendor/tomotopy/src/TopicModel/MGLDAModel.cpp +2 -10
  70. data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +42 -19
  71. data/vendor/tomotopy/src/TopicModel/PA.h +9 -4
  72. data/vendor/tomotopy/src/TopicModel/PAModel.cpp +2 -6
  73. data/vendor/tomotopy/src/TopicModel/PAModel.hpp +48 -25
  74. data/vendor/tomotopy/src/TopicModel/PLDA.h +13 -2
  75. data/vendor/tomotopy/src/TopicModel/PLDAModel.cpp +2 -6
  76. data/vendor/tomotopy/src/TopicModel/PLDAModel.hpp +27 -19
  77. data/vendor/tomotopy/src/TopicModel/PT.h +12 -5
  78. data/vendor/tomotopy/src/TopicModel/PTModel.cpp +2 -3
  79. data/vendor/tomotopy/src/TopicModel/PTModel.hpp +29 -14
  80. data/vendor/tomotopy/src/TopicModel/SLDA.h +18 -6
  81. data/vendor/tomotopy/src/TopicModel/SLDAModel.cpp +2 -10
  82. data/vendor/tomotopy/src/TopicModel/SLDAModel.hpp +93 -43
  83. data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +58 -23
  84. data/vendor/tomotopy/src/Utils/AliasMethod.hpp +6 -6
  85. data/vendor/tomotopy/src/Utils/Dictionary.h +11 -0
  86. data/vendor/tomotopy/src/Utils/SharedString.hpp +26 -1
  87. data/vendor/tomotopy/src/Utils/Trie.hpp +46 -21
  88. data/vendor/tomotopy/src/Utils/Utils.hpp +99 -14
  89. data/vendor/tomotopy/src/Utils/exception.h +1 -1
  90. data/vendor/tomotopy/src/Utils/math.h +5 -7
  91. data/vendor/tomotopy/src/Utils/serializer.hpp +329 -201
  92. data/vendor/tomotopy/src/Utils/text.hpp +8 -0
  93. data/vendor/tomotopy/src/Utils/tvector.hpp +49 -7
  94. metadata +9 -7
@@ -3,24 +3,36 @@
3
3
 
4
4
  namespace tomoto
5
5
  {
6
+ class IDMRModel;
7
+
6
8
  template<TermWeight _tw>
7
9
  struct DocumentDMR : public DocumentLDA<_tw>
8
10
  {
9
11
  using BaseDocument = DocumentLDA<_tw>;
10
12
  using DocumentLDA<_tw>::DocumentLDA;
11
13
  uint64_t metadata = 0;
14
+ std::vector<uint64_t> multiMetadata;
15
+ Vector mdVec;
16
+ size_t mdHash = (size_t)-1;
17
+ mutable Matrix cachedAlpha;
18
+
19
+ RawDoc::MiscType makeMisc(const ITopicModel* tm) const override;
12
20
 
13
21
  DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 0, metadata);
14
- DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 1, 0x00010001, metadata);
22
+ DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 1, 0x00010001, metadata, multiMetadata);
23
+ };
24
+
25
+ struct DMRArgs : public LDAArgs
26
+ {
27
+ Float alphaEps = 1e-10;
28
+ Float sigma = 1.0;
15
29
  };
16
30
 
17
31
  class IDMRModel : public ILDAModel
18
32
  {
19
33
  public:
20
34
  using DefaultDocType = DocumentDMR<TermWeight::one>;
21
- static IDMRModel* create(TermWeight _weight, size_t _K = 1,
22
- Float defaultAlpha = 1.0, Float _sigma = 1.0, Float _eta = 0.01, Float _alphaEps = 1e-10,
23
- size_t seed = std::random_device{}(),
35
+ static IDMRModel* create(TermWeight _weight, const DMRArgs& args,
24
36
  bool scalarRng = false);
25
37
 
26
38
  virtual void setAlphaEps(Float _alphaEps) = 0;
@@ -28,9 +40,26 @@ namespace tomoto
28
40
  virtual void setOptimRepeat(size_t repeat) = 0;
29
41
  virtual size_t getOptimRepeat() const = 0;
30
42
  virtual size_t getF() const = 0;
43
+ virtual size_t getMdVecSize() const = 0;
31
44
  virtual Float getSigma() const = 0;
32
45
  virtual const Dictionary& getMetadataDict() const = 0;
46
+ virtual const Dictionary& getMultiMetadataDict() const = 0;
33
47
  virtual std::vector<Float> getLambdaByMetadata(size_t metadataId) const = 0;
34
48
  virtual std::vector<Float> getLambdaByTopic(Tid tid) const = 0;
49
+
50
+ virtual std::vector<Float> getTopicPrior(
51
+ const std::string& metadata,
52
+ const std::vector<std::string>& multiMetadata,
53
+ bool raw = false
54
+ ) const = 0;
35
55
  };
56
+
57
+ template<TermWeight _tw>
58
+ RawDoc::MiscType DocumentDMR<_tw>::makeMisc(const ITopicModel* tm) const
59
+ {
60
+ RawDoc::MiscType ret = DocumentLDA<_tw>::makeMisc(tm);
61
+ auto inst = static_cast<const IDMRModel*>(tm);
62
+ ret["metadata"] = inst->getMetadataDict().toWord(metadata);
63
+ return ret;
64
+ }
36
65
  }
@@ -2,12 +2,8 @@
2
2
 
3
3
  namespace tomoto
4
4
  {
5
- /*template class DMRModel<TermWeight::one>;
6
- template class DMRModel<TermWeight::idf>;
7
- template class DMRModel<TermWeight::pmi>;*/
8
-
9
- IDMRModel* IDMRModel::create(TermWeight _weight, size_t _K, Float _defaultAlpha, Float _sigma, Float _eta, Float _alphaEps, size_t seed, bool scalarRng)
5
+ IDMRModel* IDMRModel::create(TermWeight _weight, const DMRArgs& args, bool scalarRng)
10
6
  {
11
- TMT_SWITCH_TW(_weight, scalarRng, DMRModel, _K, _defaultAlpha, _sigma, _eta, _alphaEps, seed);
7
+ TMT_SWITCH_TW(_weight, scalarRng, DMRModel, args);
12
8
  }
13
9
  }
@@ -13,7 +13,21 @@ namespace tomoto
13
13
  template<TermWeight _tw>
14
14
  struct ModelStateDMR : public ModelStateLDA<_tw>
15
15
  {
16
- Eigen::Matrix<Float, -1, 1> tmpK;
16
+ Vector tmpK;
17
+ };
18
+
19
+ struct MdHash
20
+ {
21
+ size_t operator()(std::pair<uint64_t, Vector> const& p) const
22
+ {
23
+ size_t seed = p.first;
24
+ for (size_t i = 0; i < p.second.size(); ++i)
25
+ {
26
+ auto elem = p.second[i];
27
+ seed ^= std::hash<decltype(elem)>()(elem) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
28
+ }
29
+ return seed;
30
+ }
17
31
  };
18
32
 
19
33
  template<TermWeight _tw, typename _RandGen,
@@ -35,36 +49,37 @@ namespace tomoto
35
49
 
36
50
  static constexpr char TMID[] = "DMR\0";
37
51
 
38
- Eigen::Matrix<Float, -1, -1> lambda;
39
- Eigen::Matrix<Float, -1, -1> expLambda;
52
+ Matrix lambda;
53
+ mutable std::unordered_map<std::pair<uint64_t, Vector>, size_t, MdHash> mdHashMap;
54
+ mutable Matrix cachedAlphas;
40
55
  Float sigma;
41
- uint32_t F = 0;
56
+ uint32_t F = 0, mdVecSize = 1;
42
57
  uint32_t optimRepeat = 5;
43
58
  Float alphaEps = 1e-10;
44
- Float temperatureScale = 0;
45
59
  static constexpr Float maxLambda = 10;
46
60
  static constexpr size_t maxBFGSIteration = 10;
47
61
 
48
62
  Dictionary metadataDict;
63
+ Dictionary multiMetadataDict;
49
64
  LBFGSpp::LBFGSSolver<Float, LBFGSpp::LineSearchBracketing> solver;
50
65
 
51
- Float getNegativeLambdaLL(Eigen::Ref<Eigen::Matrix<Float, -1, 1>> x, Eigen::Matrix<Float, -1, 1>& g) const
66
+ Float getNegativeLambdaLL(Eigen::Ref<Vector> x, Vector& g) const
52
67
  {
53
68
  g = (x.array() - log(this->alpha)) / pow(sigma, 2);
54
69
  return (x.array() - log(this->alpha)).pow(2).sum() / 2 / pow(sigma, 2);
55
70
  }
56
71
 
57
- Float evaluateLambdaObj(Eigen::Ref<Eigen::Matrix<Float, -1, 1>> x, Eigen::Matrix<Float, -1, 1>& g, ThreadPool& pool, _ModelState* localData) const
72
+ Float evaluateLambdaObj(Eigen::Ref<Vector> x, Vector& g, ThreadPool& pool, _ModelState* localData) const
58
73
  {
59
74
  // if one of x is greater than maxLambda, return +inf for preventing searching more
60
75
  if ((x.array() > maxLambda).any()) return INFINITY;
61
76
 
62
77
  const auto K = this->K;
63
78
 
64
- Float fx = - static_cast<const DerivedClass*>(this)->getNegativeLambdaLL(x, g);
65
- auto alphas = (x.array().exp() + alphaEps).eval();
79
+ Float fx = -static_cast<const DerivedClass*>(this)->getNegativeLambdaLL(x, g);
80
+ Eigen::Map<Matrix> xReshaped{ x.data(), (Eigen::Index)K, (Eigen::Index)(F * mdVecSize) };
66
81
 
67
- std::vector<std::future<Eigen::Matrix<Float, -1, 1>>> res;
82
+ std::vector<std::future<Eigen::Array<Float, -1, 1>>> res;
68
83
  const size_t chStride = pool.getNumWorkers() * 8;
69
84
  for (size_t ch = 0; ch < chStride; ++ch)
70
85
  {
@@ -72,28 +87,28 @@ namespace tomoto
72
87
  {
73
88
  auto& tmpK = localData[threadId].tmpK;
74
89
  if (!tmpK.size()) tmpK.resize(this->K);
75
- Eigen::Matrix<Float, -1, 1> val = Eigen::Matrix<Float, -1, 1>::Zero(K * F + 1);
90
+ Eigen::Array<Float, -1, 1> val = Eigen::Array<Float, -1, 1>::Zero(K * F * mdVecSize + 1);
91
+ Eigen::Map<Matrix> grad{ val.data(), (Eigen::Index)K, (Eigen::Index)(F * mdVecSize) };
92
+ Float& fx = val[K * F * mdVecSize];
76
93
  for (size_t docId = ch; docId < this->docs.size(); docId += chStride)
77
94
  {
78
95
  const auto& doc = this->docs[docId];
79
- auto alphaDoc = alphas.segment(doc.metadata * K, K);
96
+ auto alphaDoc = ((xReshaped.middleCols(doc.metadata * mdVecSize, mdVecSize) * doc.mdVec).array().exp() + alphaEps).matrix().eval();
80
97
  Float alphaSum = alphaDoc.sum();
81
98
  for (Tid k = 0; k < K; ++k)
82
99
  {
83
- val[K * F] -= math::lgammaT(alphaDoc[k]) - math::lgammaT(doc.numByTopic[k] + alphaDoc[k]);
100
+ fx -= math::lgammaT(alphaDoc[k]) - math::lgammaT(doc.numByTopic[k] + alphaDoc[k]);
84
101
  if (!std::isfinite(alphaDoc[k]) && alphaDoc[k] > 0) tmpK[k] = 0;
85
102
  else tmpK[k] = -(math::digammaT(alphaDoc[k]) - math::digammaT(doc.numByTopic[k] + alphaDoc[k]));
86
103
  }
87
- //val[K * F] = -(lgammaApprox(alphaDoc.array()) - lgammaApprox(doc.numByTopic.array().cast<Float>() + alphaDoc.array())).sum();
88
- //tmpK = -(digammaApprox(alphaDoc.array()) - digammaApprox(doc.numByTopic.array().cast<Float>() + alphaDoc.array()));
89
- val[K * F] += math::lgammaT(alphaSum) - math::lgammaT(doc.getSumWordWeight() + alphaSum);
104
+ fx += math::lgammaT(alphaSum) - math::lgammaT(doc.getSumWordWeight() + alphaSum);
90
105
  Float t = math::digammaT(alphaSum) - math::digammaT(doc.getSumWordWeight() + alphaSum);
91
106
  if (!std::isfinite(alphaSum) && alphaSum > 0)
92
107
  {
93
- val[K * F] = -INFINITY;
108
+ fx = -INFINITY;
94
109
  t = 0;
95
110
  }
96
- val.segment(doc.metadata * K, K).array() -= alphaDoc.array() * (tmpK.array() + t);
111
+ grad.middleCols(doc.metadata * mdVecSize, mdVecSize) -= (alphaDoc.array() * (tmpK.array() + t)).matrix() * doc.mdVec.transpose();
97
112
  }
98
113
  return val;
99
114
  }));
@@ -101,8 +116,8 @@ namespace tomoto
101
116
  for (auto& r : res)
102
117
  {
103
118
  auto ret = r.get();
104
- fx += ret[K * F];
105
- g += ret.head(K * F);
119
+ fx += ret[K * F * mdVecSize];
120
+ g += ret.head(K * F * mdVecSize).matrix();
106
121
  }
107
122
 
108
123
  // positive fx is an error from limited precision of float.
@@ -112,24 +127,24 @@ namespace tomoto
112
127
 
113
128
  void initParameters()
114
129
  {
115
- auto dist = std::normal_distribution<Float>(log(this->alpha), sigma);
116
- for (size_t i = 0; i < this->K; ++i) for (size_t j = 0; j < F; ++j)
130
+ lambda = Eigen::Rand::normalLike(lambda, this->rg, 0, sigma);
131
+ for (size_t f = 0; f < F; ++f)
117
132
  {
118
- lambda(i, j) = dist(this->rg);
133
+ lambda.col(f * mdVecSize) += this->alphas.array().log().matrix();
119
134
  }
120
135
  }
121
136
 
122
137
  void optimizeParameters(ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
123
138
  {
124
- Eigen::Matrix<Float, -1, -1> bLambda;
139
+ Matrix bLambda;
125
140
  Float fx = 0, bestFx = INFINITY;
126
141
  for (size_t i = 0; i < optimRepeat; ++i)
127
142
  {
128
143
  static_cast<DerivedClass*>(this)->initParameters();
129
- int ret = solver.minimize([this, &pool, localData](Eigen::Ref<Eigen::Matrix<Float, -1, 1>> x, Eigen::Matrix<Float, -1, 1>& g)
144
+ int ret = solver.minimize([this, &pool, localData](Eigen::Ref<Vector> x, Vector& g)
130
145
  {
131
146
  return static_cast<DerivedClass*>(this)->evaluateLambdaObj(x, g, pool, localData);
132
- }, Eigen::Map<Eigen::Matrix<Float, -1, 1>>(lambda.data(), lambda.size()), fx);
147
+ }, Eigen::Map<Vector>(lambda.data(), lambda.size()), fx);
133
148
 
134
149
  if (fx < bestFx)
135
150
  {
@@ -140,44 +155,60 @@ namespace tomoto
140
155
  }
141
156
  if (!std::isfinite(bestFx))
142
157
  {
143
- throw exception::TrainingError{ "optimizing parameters has been failed!" };
158
+ throw exc::TrainingError{ "optimizing parameters has been failed!" };
144
159
  }
145
160
  lambda = bLambda;
161
+ updateCachedAlphas();
146
162
  //std::cerr << fx << std::endl;
147
- expLambda = lambda.array().exp() + alphaEps;
148
163
  }
149
164
 
150
- int restoreFromTrainingError(const exception::TrainingError& e, ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
165
+ int restoreFromTrainingError(const exc::TrainingError& e, ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
151
166
  {
152
167
  std::cerr << "Failed to optimize! Reset prior and retry!" << std::endl;
153
168
  lambda.setZero();
154
- expLambda = lambda.array().exp() + alphaEps;
169
+ updateCachedAlphas();
155
170
  return 0;
156
171
  }
157
172
 
173
+ auto getCachedAlpha(const _DocType& doc) const
174
+ {
175
+ if (doc.mdHash < cachedAlphas.cols())
176
+ {
177
+ return cachedAlphas.col(doc.mdHash);
178
+ }
179
+ else
180
+ {
181
+ if (!doc.cachedAlpha.size())
182
+ {
183
+ doc.cachedAlpha = (lambda.middleCols(doc.metadata * mdVecSize, mdVecSize) * doc.mdVec).array().exp() + alphaEps;
184
+ }
185
+ return doc.cachedAlpha.col(0);
186
+ }
187
+ }
188
+
158
189
  template<bool _asymEta>
159
190
  Float* getZLikelihoods(_ModelState& ld, const _DocType& doc, size_t docId, size_t vid) const
160
191
  {
161
192
  const size_t V = this->realV;
162
193
  assert(vid < V);
163
194
  auto etaHelper = this->template getEtaHelper<_asymEta>();
195
+ auto alphas = getCachedAlpha(doc);
164
196
  auto& zLikelihood = ld.zLikelihood;
165
- zLikelihood = (doc.numByTopic.array().template cast<Float>() + this->expLambda.col(doc.metadata).array())
197
+ zLikelihood = (doc.numByTopic.array().template cast<Float>() + alphas.array())
166
198
  * (ld.numByTopicWord.col(vid).array().template cast<Float>() + etaHelper.getEta(vid))
167
199
  / (ld.numByTopic.array().template cast<Float>() + etaHelper.getEtaSum());
168
200
 
169
201
  sample::prefixSum(zLikelihood.data(), this->K);
170
202
  return &zLikelihood[0];
171
203
  }
172
-
173
204
 
174
205
  double getLLDocTopic(const _DocType& doc) const
175
206
  {
176
207
  const size_t V = this->realV;
177
208
  const auto K = this->K;
178
209
 
179
- auto alphaDoc = expLambda.col(doc.metadata);
180
-
210
+ auto alphaDoc = getCachedAlpha(doc);
211
+
181
212
  Float ll = 0;
182
213
  Float alphaSum = alphaDoc.sum();
183
214
  for (Tid k = 0; k < K; ++k)
@@ -199,7 +230,7 @@ namespace tomoto
199
230
  for (; _first != _last; ++_first)
200
231
  {
201
232
  auto& doc = *_first;
202
- auto alphaDoc = expLambda.col(doc.metadata);
233
+ auto alphaDoc = getCachedAlpha(doc);
203
234
  Float alphaSum = alphaDoc.sum();
204
235
 
205
236
  for (Tid k = 0; k < K; ++k)
@@ -234,45 +265,133 @@ namespace tomoto
234
265
  return ll;
235
266
  }
236
267
 
268
+ void updateCachedAlphas() const
269
+ {
270
+ cachedAlphas.resize(this->K, mdHashMap.size());
271
+
272
+ for (auto& p : mdHashMap)
273
+ {
274
+ cachedAlphas.col(p.second) = (lambda.middleCols(p.first.first * mdVecSize, mdVecSize) * p.first.second).array().exp() + alphaEps;
275
+ }
276
+ }
277
+
278
+ void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
279
+ {
280
+ BaseClass::prepareDoc(doc, docId, wordSize);
281
+
282
+ doc.mdVec = Vector::Zero(mdVecSize);
283
+ doc.mdVec[0] = 1;
284
+ for (auto x : doc.multiMetadata)
285
+ {
286
+ doc.mdVec[x + 1] = 1;
287
+ }
288
+
289
+ auto p = std::make_pair(doc.metadata, doc.mdVec);
290
+ auto it = mdHashMap.find(p);
291
+ if (it == mdHashMap.end())
292
+ {
293
+ it = mdHashMap.emplace(p, mdHashMap.size()).first;
294
+ }
295
+ doc.mdHash = it->second;
296
+ }
297
+
237
298
  void initGlobalState(bool initDocs)
238
299
  {
239
300
  BaseClass::initGlobalState(initDocs);
240
- this->globalState.tmpK = Eigen::Matrix<Float, -1, 1>::Zero(this->K);
301
+ this->globalState.tmpK = Vector::Zero(this->K);
241
302
  F = metadataDict.size();
303
+ mdVecSize = multiMetadataDict.size() + 1;
242
304
  if (initDocs)
243
305
  {
244
- lambda = Eigen::Matrix<Float, -1, -1>::Constant(this->K, F, log(this->alpha));
306
+ lambda.resize(this->K, F * mdVecSize);
307
+ for (size_t f = 0; f < F; ++f)
308
+ {
309
+ lambda.col(f * mdVecSize) = this->alphas.array().log();
310
+ lambda.middleCols(f * mdVecSize + 1, mdVecSize - 1).setZero();
311
+ }
245
312
  }
313
+ else
314
+ {
315
+ for (auto& doc : this->docs)
316
+ {
317
+ if (doc.mdVec.size() == mdVecSize) continue;
318
+ doc.mdVec = Vector::Zero(mdVecSize);
319
+ doc.mdVec[0] = 1;
320
+ for (auto x : doc.multiMetadata)
321
+ {
322
+ doc.mdVec[x + 1] = 1;
323
+ }
324
+
325
+ auto p = std::make_pair(doc.metadata, doc.mdVec);
326
+ auto it = this->mdHashMap.find(p);
327
+ if (it == this->mdHashMap.end())
328
+ {
329
+ it = this->mdHashMap.emplace(p, mdHashMap.size()).first;
330
+ }
331
+ doc.mdHash = it->second;
332
+ }
333
+ }
334
+
246
335
  if (_Flags & flags::continuous_doc_data) this->numByTopicDoc = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K, this->docs.size());
247
- expLambda = lambda.array().exp();
248
336
  LBFGSpp::LBFGSParam<Float> param;
249
337
  param.max_iterations = maxBFGSIteration;
250
338
  solver = decltype(solver){ param };
251
339
  }
252
340
 
341
+ void prepareShared()
342
+ {
343
+ BaseClass::prepareShared();
344
+
345
+ for (auto doc : this->docs)
346
+ {
347
+ if (doc.mdHash != (size_t)-1) continue;
348
+
349
+ auto p = std::make_pair(doc.metadata, doc.mdVec);
350
+ auto it = mdHashMap.find(p);
351
+ if (it == mdHashMap.end())
352
+ {
353
+ it = mdHashMap.emplace(p, mdHashMap.size()).first;
354
+ }
355
+ doc.mdHash = it->second;
356
+ }
357
+
358
+ updateCachedAlphas();
359
+ }
360
+
253
361
  public:
254
362
  DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, sigma, alphaEps, metadataDict, lambda);
255
- DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, sigma, alphaEps, metadataDict, lambda);
363
+ DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, sigma, alphaEps, metadataDict, lambda, multiMetadataDict);
256
364
 
257
- DMRModel(size_t _K = 1, Float defaultAlpha = 1.0, Float _sigma = 1.0, Float _eta = 0.01,
258
- Float _alphaEps = 0, size_t _rg = std::random_device{}())
259
- : BaseClass(_K, defaultAlpha, _eta, _rg), sigma(_sigma), alphaEps(_alphaEps)
365
+ DMRModel(const DMRArgs& args)
366
+ : BaseClass(args), sigma(args.sigma), alphaEps(args.alphaEps)
260
367
  {
261
- if (_sigma <= 0) THROW_ERROR_WITH_INFO(std::runtime_error, text::format("wrong sigma value (sigma = %f)", _sigma));
368
+ if (sigma <= 0) THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong sigma value (sigma = %f)", sigma));
262
369
  }
263
370
 
264
371
  template<bool _const = false>
265
- _DocType& _updateDoc(_DocType& doc, const std::string& metadata)
372
+ _DocType& _updateDoc(_DocType& doc, const std::string& metadata, const std::vector<std::string>& mdVec = {})
266
373
  {
267
374
  Vid xid;
268
375
  if (_const)
269
376
  {
270
377
  xid = metadataDict.toWid(metadata);
271
- if (xid == (Vid)-1) throw std::invalid_argument("unknown metadata");
378
+ if (xid == (Vid)-1) throw exc::InvalidArgument("unknown metadata '" + metadata + "'");
379
+
380
+ for (auto& m : mdVec)
381
+ {
382
+ Vid x = multiMetadataDict.toWid(m);
383
+ if (x == (Vid)-1) throw exc::InvalidArgument("unknown multi_metadata '" + m + "'");
384
+ doc.multiMetadata.emplace_back(x);
385
+ }
272
386
  }
273
387
  else
274
388
  {
275
389
  xid = metadataDict.add(metadata);
390
+
391
+ for (auto& m : mdVec)
392
+ {
393
+ doc.multiMetadata.emplace_back(multiMetadataDict.add(m));
394
+ }
276
395
  }
277
396
  doc.metadata = xid;
278
397
  return doc;
@@ -281,28 +400,41 @@ namespace tomoto
281
400
  size_t addDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) override
282
401
  {
283
402
  auto doc = this->template _makeFromRawDoc<false>(rawDoc, tokenizer);
284
- return this->_addDoc(_updateDoc(doc, rawDoc.template getMisc<std::string>("metadata")));
403
+ return this->_addDoc(_updateDoc(doc,
404
+ rawDoc.template getMisc<std::string>("metadata"),
405
+ rawDoc.template getMiscDefault<std::vector<std::string>>("multi_metadata")
406
+ ));
285
407
  }
286
408
 
287
409
  std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) const override
288
410
  {
289
411
  auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc, tokenizer);
290
- return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMisc<std::string>("metadata")));
412
+ return std::make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc,
413
+ rawDoc.template getMisc<std::string>("metadata"),
414
+ rawDoc.template getMiscDefault<std::vector<std::string>>("multi_metadata")
415
+ ));
291
416
  }
292
417
 
293
418
  size_t addDoc(const RawDoc& rawDoc) override
294
419
  {
295
420
  auto doc = this->_makeFromRawDoc(rawDoc);
296
- return this->_addDoc(_updateDoc(doc, rawDoc.template getMisc<std::string>("metadata")));
421
+ return this->_addDoc(_updateDoc(doc,
422
+ rawDoc.template getMisc<std::string>("metadata"),
423
+ rawDoc.template getMiscDefault<std::vector<std::string>>("multi_metadata")
424
+ ));
297
425
  }
298
426
 
299
427
  std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc) const override
300
428
  {
301
429
  auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc);
302
- return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMisc<std::string>("metadata")));
430
+ return std::make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc,
431
+ rawDoc.template getMisc<std::string>("metadata"),
432
+ rawDoc.template getMiscDefault<std::vector<std::string>>("multi_metadata")
433
+ ));
303
434
  }
304
435
 
305
436
  GETTER(F, size_t, F);
437
+ GETTER(MdVecSize, size_t, mdVecSize);
306
438
  GETTER(Sigma, Float, sigma);
307
439
  GETTER(AlphaEps, Float, alphaEps);
308
440
  GETTER(OptimRepeat, size_t, optimRepeat);
@@ -317,12 +449,19 @@ namespace tomoto
317
449
  optimRepeat = _optimRepeat;
318
450
  }
319
451
 
320
- std::vector<Float> getTopicsByDoc(const _DocType& doc) const
452
+ std::vector<Float> getTopicsByDoc(const _DocType& doc, bool normalize) const
321
453
  {
322
454
  std::vector<Float> ret(this->K);
323
- auto alphaDoc = expLambda.col(doc.metadata);
324
- Eigen::Map<Eigen::Matrix<Float, -1, 1>>{ret.data(), this->K}.array() =
325
- (doc.numByTopic.array().template cast<Float>() + alphaDoc.array()) / (doc.getSumWordWeight() + alphaDoc.sum());
455
+ auto alphaDoc = getCachedAlpha(doc);
456
+ Eigen::Map<Eigen::Array<Float, -1, 1>> m{ ret.data(), this->K };
457
+ if (normalize)
458
+ {
459
+ m = (doc.numByTopic.array().template cast<Float>() + alphaDoc.array()) / (doc.getSumWordWeight() + alphaDoc.sum());
460
+ }
461
+ else
462
+ {
463
+ m = doc.numByTopic.array().template cast<Float>() + alphaDoc.array();
464
+ }
326
465
  return ret;
327
466
  }
328
467
 
@@ -330,17 +469,52 @@ namespace tomoto
330
469
  {
331
470
  assert(metadataId < metadataDict.size());
332
471
  auto l = lambda.col(metadataId);
333
- return { l.data(), l.data() + this->K };
472
+ return { l.data(), l.data() + l.size() };
334
473
  }
335
474
 
336
475
  std::vector<Float> getLambdaByTopic(Tid tid) const override
337
476
  {
338
- assert(tid < this->K);
339
- auto l = lambda.row(tid);
340
- return { l.data(), l.data() + F };
477
+ std::vector<Float> ret(F * mdVecSize);
478
+ if (this->lambda.size())
479
+ {
480
+ Eigen::Map<Vector>{ ret.data(), (Eigen::Index)ret.size() } = this->lambda.row(tid);
481
+ }
482
+ return ret;
483
+ }
484
+
485
+ std::vector<Float> getTopicPrior(const std::string& metadata,
486
+ const std::vector<std::string>& mdVec,
487
+ bool raw = false
488
+ ) const override
489
+ {
490
+ Vid xid = metadataDict.toWid(metadata);
491
+ if (xid == (Vid)-1) throw exc::InvalidArgument("unknown metadata '" + metadata + "'");
492
+
493
+ Vector xs = Vector::Zero(mdVecSize);
494
+ xs[0] = 1;
495
+ for (auto& m : mdVec)
496
+ {
497
+ Vid x = multiMetadataDict.toWid(m);
498
+ if (x == (Vid)-1) throw exc::InvalidArgument("unknown multi_metadata '" + m + "'");
499
+ xs[x + 1] = 1;
500
+ }
501
+
502
+ std::vector<Float> ret(this->K);
503
+ Eigen::Map<Vector> map{ ret.data(), (Eigen::Index)ret.size() };
504
+
505
+ if (raw)
506
+ {
507
+ map = lambda.middleCols(xid * mdVecSize, mdVecSize) * xs;
508
+ }
509
+ else
510
+ {
511
+ map = (lambda.middleCols(xid * mdVecSize, mdVecSize) * xs).array().exp() + alphaEps;
512
+ }
513
+ return ret;
341
514
  }
342
515
 
343
516
  const Dictionary& getMetadataDict() const override { return metadataDict; }
517
+ const Dictionary& getMultiMetadataDict() const override { return multiMetadataDict; }
344
518
  };
345
519
 
346
520
  /* This is for preventing 'undefined symbol' problem in compiling by clang. */