tomoto 0.1.3 → 0.1.4

Sign up to get free protection for your applications and to get access to all the features.
Files changed (50):
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +6 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -0
  5. data/ext/tomoto/ct.cpp +54 -0
  6. data/ext/tomoto/dmr.cpp +62 -0
  7. data/ext/tomoto/dt.cpp +82 -0
  8. data/ext/tomoto/ext.cpp +27 -773
  9. data/ext/tomoto/gdmr.cpp +34 -0
  10. data/ext/tomoto/hdp.cpp +42 -0
  11. data/ext/tomoto/hlda.cpp +66 -0
  12. data/ext/tomoto/hpa.cpp +27 -0
  13. data/ext/tomoto/lda.cpp +250 -0
  14. data/ext/tomoto/llda.cpp +29 -0
  15. data/ext/tomoto/mglda.cpp +71 -0
  16. data/ext/tomoto/pa.cpp +27 -0
  17. data/ext/tomoto/plda.cpp +29 -0
  18. data/ext/tomoto/slda.cpp +40 -0
  19. data/ext/tomoto/utils.h +84 -0
  20. data/lib/tomoto/tomoto.bundle +0 -0
  21. data/lib/tomoto/tomoto.so +0 -0
  22. data/lib/tomoto/version.rb +1 -1
  23. data/vendor/tomotopy/README.kr.rst +12 -3
  24. data/vendor/tomotopy/README.rst +12 -3
  25. data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +47 -2
  26. data/vendor/tomotopy/src/Labeling/FoRelevance.h +21 -151
  27. data/vendor/tomotopy/src/Labeling/Labeler.h +5 -3
  28. data/vendor/tomotopy/src/Labeling/Phraser.hpp +518 -0
  29. data/vendor/tomotopy/src/TopicModel/CTModel.hpp +6 -3
  30. data/vendor/tomotopy/src/TopicModel/DT.h +1 -1
  31. data/vendor/tomotopy/src/TopicModel/DTModel.hpp +8 -23
  32. data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +9 -18
  33. data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +56 -58
  34. data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +4 -14
  35. data/vendor/tomotopy/src/TopicModel/LDA.h +69 -17
  36. data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +1 -1
  37. data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +108 -61
  38. data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +7 -8
  39. data/vendor/tomotopy/src/TopicModel/PAModel.hpp +26 -16
  40. data/vendor/tomotopy/src/TopicModel/PT.h +27 -0
  41. data/vendor/tomotopy/src/TopicModel/PTModel.cpp +10 -0
  42. data/vendor/tomotopy/src/TopicModel/PTModel.hpp +273 -0
  43. data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +16 -11
  44. data/vendor/tomotopy/src/Utils/MultiNormalDistribution.hpp +3 -2
  45. data/vendor/tomotopy/src/Utils/Trie.hpp +39 -8
  46. data/vendor/tomotopy/src/Utils/TruncMultiNormal.hpp +36 -38
  47. data/vendor/tomotopy/src/Utils/Utils.hpp +50 -45
  48. data/vendor/tomotopy/src/Utils/math.h +8 -4
  49. data/vendor/tomotopy/src/Utils/tvector.hpp +4 -0
  50. metadata +24 -60
@@ -107,18 +107,16 @@ namespace tomoto
107
107
  e = edd.chunkOffsetByDoc(partitionId + 1, docId);
108
108
  }
109
109
 
110
- size_t vOffset = (_ps == ParallelScheme::partition && partitionId) ? edd.vChunkOffset[partitionId - 1] : 0;
111
-
112
110
  const auto K = this->K;
113
111
  for (size_t w = b; w < e; ++w)
114
112
  {
115
113
  if (doc.words[w] >= this->realV) continue;
116
- addWordTo<-1>(ld, doc, w, doc.words[w] - vOffset, doc.Zs[w] - (doc.Zs[w] < K ? 0 : K), doc.sents[w], doc.Vs[w], doc.Zs[w] < K ? 0 : 1);
117
- auto dist = getVZLikelihoods(ld, doc, doc.words[w] - vOffset, doc.sents[w]);
114
+ addWordTo<-1>(ld, doc, w, doc.words[w], doc.Zs[w] - (doc.Zs[w] < K ? 0 : K), doc.sents[w], doc.Vs[w], doc.Zs[w] < K ? 0 : 1);
115
+ auto dist = getVZLikelihoods(ld, doc, doc.words[w], doc.sents[w]);
118
116
  auto vz = sample::sampleFromDiscreteAcc(dist, dist + T * (K + KL), rgs);
119
117
  doc.Vs[w] = vz / (K + KL);
120
118
  doc.Zs[w] = vz % (K + KL);
121
- addWordTo<1>(ld, doc, w, doc.words[w] - vOffset, doc.Zs[w] - (doc.Zs[w] < K ? 0 : K), doc.sents[w], doc.Vs[w], doc.Zs[w] < K ? 0 : 1);
119
+ addWordTo<1>(ld, doc, w, doc.words[w], doc.Zs[w] - (doc.Zs[w] < K ? 0 : K), doc.sents[w], doc.Vs[w], doc.Zs[w] < K ? 0 : 1);
122
120
  }
123
121
  }
124
122
 
@@ -294,7 +292,7 @@ namespace tomoto
294
292
  doc.Zs = tvector<Tid>(wordSize);
295
293
  doc.Vs.resize(wordSize);
296
294
  if (_tw != TermWeight::one) doc.wordWeights.resize(wordSize);
297
- doc.numByTopic.init(nullptr, this->K + KL);
295
+ doc.numByTopic.init(nullptr, this->K + KL, 1);
298
296
  doc.numBySentWin = Eigen::Matrix<WeightType, -1, -1>::Zero(S, T);
299
297
  doc.numByWin = Eigen::Matrix<WeightType, -1, 1>::Zero(S + T - 1);
300
298
  doc.numByWinL = Eigen::Matrix<WeightType, -1, 1>::Zero(S + T - 1);
@@ -308,7 +306,8 @@ namespace tomoto
308
306
  if (initDocs)
309
307
  {
310
308
  this->globalState.numByTopic = Eigen::Matrix<WeightType, -1, 1>::Zero(this->K + KL);
311
- this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K + KL, V);
309
+ //this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K + KL, V);
310
+ this->globalState.numByTopicWord.init(nullptr, this->K + KL, V);
312
311
  }
313
312
  }
314
313
 
@@ -533,7 +532,7 @@ namespace tomoto
533
532
  template<typename _TopicModel>
534
533
  void DocumentMGLDA<_tw>::update(WeightType * ptr, const _TopicModel & mdl)
535
534
  {
536
- this->numByTopic.init(ptr, mdl.getK() + mdl.getKL());
535
+ this->numByTopic.init(ptr, mdl.getK() + mdl.getKL(), 1);
537
536
  numBySent.resize(*std::max_element(sents.begin(), sents.end()) + 1);
538
537
  for (size_t i = 0; i < this->Zs.size(); ++i)
539
538
  {
@@ -144,7 +144,7 @@ namespace tomoto
144
144
  size_t b = partitionId ? edd.vChunkOffset[partitionId - 1] : 0,
145
145
  e = edd.vChunkOffset[partitionId];
146
146
 
147
- localData[partitionId].numByTopicWord = globalState.numByTopicWord.block(0, b, globalState.numByTopicWord.rows(), e - b);
147
+ localData[partitionId].numByTopicWord.matrix() = globalState.numByTopicWord.block(0, b, globalState.numByTopicWord.rows(), e - b);
148
148
  localData[partitionId].numByTopic = globalState.numByTopic;
149
149
  localData[partitionId].numByTopic1_2 = globalState.numByTopic1_2;
150
150
  localData[partitionId].numByTopic2 = globalState.numByTopic2;
@@ -157,8 +157,6 @@ namespace tomoto
157
157
  template<ParallelScheme _ps, typename _ExtraDocData>
158
158
  void mergeState(ThreadPool& pool, _ModelState& globalState, _ModelState& tState, _ModelState* localData, _RandGen*, const _ExtraDocData& edd) const
159
159
  {
160
- std::vector<std::future<void>> res;
161
-
162
160
  if (_ps == ParallelScheme::copy_merge)
163
161
  {
164
162
  tState = globalState;
@@ -177,19 +175,12 @@ namespace tomoto
177
175
  globalState.numByTopic = globalState.numByTopic.cwiseMax(0);
178
176
  globalState.numByTopic1_2 = globalState.numByTopic1_2.cwiseMax(0);
179
177
  globalState.numByTopic2 = globalState.numByTopic2.cwiseMax(0);
180
- globalState.numByTopicWord = globalState.numByTopicWord.cwiseMax(0);
181
- }
182
-
183
- for (size_t i = 0; i < pool.getNumWorkers(); ++i)
184
- {
185
- res.emplace_back(pool.enqueue([&, this, i](size_t threadId)
186
- {
187
- localData[i] = globalState;
188
- }));
178
+ globalState.numByTopicWord.matrix() = globalState.numByTopicWord.cwiseMax(0);
189
179
  }
190
180
  }
191
181
  else if (_ps == ParallelScheme::partition)
192
182
  {
183
+ std::vector<std::future<void>> res;
193
184
  res = pool.enqueueToAll([&](size_t partitionId)
194
185
  {
195
186
  size_t b = partitionId ? edd.vChunkOffset[partitionId - 1] : 0,
@@ -197,7 +188,6 @@ namespace tomoto
197
188
  globalState.numByTopicWord.block(0, b, globalState.numByTopicWord.rows(), e - b) = localData[partitionId].numByTopicWord;
198
189
  });
199
190
  for (auto& r : res) r.get();
200
- res.clear();
201
191
 
202
192
  tState.numByTopic1_2 = globalState.numByTopic1_2;
203
193
  globalState.numByTopic1_2 = localData[0].numByTopic1_2;
@@ -209,11 +199,31 @@ namespace tomoto
209
199
  // make all count being positive
210
200
  if (_tw != TermWeight::one)
211
201
  {
212
- globalState.numByTopicWord = globalState.numByTopicWord.cwiseMax(0);
202
+ globalState.numByTopicWord.matrix() = globalState.numByTopicWord.cwiseMax(0);
213
203
  }
214
204
  globalState.numByTopic = globalState.numByTopic1_2.rowwise().sum();
215
205
  globalState.numByTopic2 = globalState.numByTopicWord.rowwise().sum();
216
206
 
207
+ }
208
+ }
209
+
210
+
211
+ template<ParallelScheme _ps>
212
+ void distributeMergedState(ThreadPool& pool, _ModelState& globalState, _ModelState* localData) const
213
+ {
214
+ std::vector<std::future<void>> res;
215
+ if (_ps == ParallelScheme::copy_merge)
216
+ {
217
+ for (size_t i = 0; i < pool.getNumWorkers(); ++i)
218
+ {
219
+ res.emplace_back(pool.enqueue([&, i](size_t)
220
+ {
221
+ localData[i] = globalState;
222
+ }));
223
+ }
224
+ }
225
+ else if (_ps == ParallelScheme::partition)
226
+ {
217
227
  res = pool.enqueueToAll([&](size_t threadId)
218
228
  {
219
229
  localData[threadId].numByTopic = globalState.numByTopic;
@@ -221,7 +231,6 @@ namespace tomoto
221
231
  localData[threadId].numByTopic2 = globalState.numByTopic2;
222
232
  });
223
233
  }
224
-
225
234
  for (auto& r : res) r.get();
226
235
  }
227
236
 
@@ -304,7 +313,8 @@ namespace tomoto
304
313
  this->globalState.numByTopic = Eigen::Matrix<WeightType, -1, 1>::Zero(this->K);
305
314
  this->globalState.numByTopic2 = Eigen::Matrix<WeightType, -1, 1>::Zero(K2);
306
315
  this->globalState.numByTopic1_2 = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K, K2);
307
- this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(K2, V);
316
+ //this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(K2, V);
317
+ this->globalState.numByTopicWord.init(nullptr, K2, V);
308
318
  }
309
319
  }
310
320
 
@@ -0,0 +1,27 @@
1
+ #pragma once
2
+ #include "LDA.h"
3
+
4
+ namespace tomoto
5
+ {
6
+ template<TermWeight _tw>
7
+ struct DocumentPTM : public DocumentLDA<_tw>
8
+ {
9
+ using BaseDocument = DocumentLDA<_tw>;
10
+ using DocumentLDA<_tw>::DocumentLDA;
11
+ using WeightType = typename DocumentLDA<_tw>::WeightType;
12
+
13
+ uint64_t pseudoDoc = 0;
14
+
15
+ DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 0, pseudoDoc);
16
+ DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 1, 0x00010001, pseudoDoc);
17
+ };
18
+
19
+ class IPTModel : public ILDAModel
20
+ {
21
+ public:
22
+ using DefaultDocType = DocumentPTM<TermWeight::one>;
23
+ static IPTModel* create(TermWeight _weight, size_t _K = 1, size_t _P = 100,
24
+ Float alpha = 0.1, Float eta = 0.01, Float lambda = 0.01, size_t seed = std::random_device{}(),
25
+ bool scalarRng = false);
26
+ };
27
+ }
@@ -0,0 +1,10 @@
1
+ #include "PTModel.hpp"
2
+
3
+ namespace tomoto
4
+ {
5
+
6
+ IPTModel* IPTModel::create(TermWeight _weight, size_t _K, size_t _P, Float _alpha, Float _eta, Float _lambda, size_t seed, bool scalarRng)
7
+ {
8
+ TMT_SWITCH_TW(_weight, scalarRng, PTModel, _K, _P, _alpha, _eta, _lambda, seed);
9
+ }
10
+ }
@@ -0,0 +1,273 @@
1
+ #pragma once
2
+ #include "LDAModel.hpp"
3
+ #include "PT.h"
4
+
5
+ /*
6
+ Implementation of Pseudo-document topic model using Gibbs sampling by bab2min
7
+
8
+ Zuo, Y., Wu, J., Zhang, H., Lin, H., Wang, F., Xu, K., & Xiong, H. (2016, August). Topic modeling of short texts: A pseudo-document view. In Proceedings of the 22nd ACM SIGKDD international conference on knowledge discovery and data mining (pp. 2105-2114).
9
+ */
10
+
11
+ namespace tomoto
12
+ {
13
+ template<TermWeight _tw>
14
+ struct ModelStatePTM : public ModelStateLDA<_tw>
15
+ {
16
+ using WeightType = typename ModelStateLDA<_tw>::WeightType;
17
+
18
+ Eigen::Array<Float, -1, 1> pLikelihood;
19
+ Eigen::ArrayXi numDocsByPDoc;
20
+ Eigen::Matrix<WeightType, -1, -1> numByTopicPDoc;
21
+
22
+ //DEFINE_SERIALIZER_AFTER_BASE(ModelStateLDA<_tw>);
23
+ };
24
+
25
+ template<TermWeight _tw, typename _RandGen,
26
+ typename _Interface = IPTModel,
27
+ typename _Derived = void,
28
+ typename _DocType = DocumentPTM<_tw>,
29
+ typename _ModelState = ModelStatePTM<_tw>>
30
+ class PTModel : public LDAModel<_tw, _RandGen, flags::continuous_doc_data | flags::partitioned_multisampling, _Interface,
31
+ typename std::conditional<std::is_same<_Derived, void>::value, PTModel<_tw, _RandGen>, _Derived>::type,
32
+ _DocType, _ModelState>
33
+ {
34
+ protected:
35
+ using DerivedClass = typename std::conditional<std::is_same<_Derived, void>::value, PTModel<_tw, _RandGen>, _Derived>::type;
36
+ using BaseClass = LDAModel<_tw, _RandGen, flags::continuous_doc_data | flags::partitioned_multisampling, _Interface, DerivedClass, _DocType, _ModelState>;
37
+ friend BaseClass;
38
+ friend typename BaseClass::BaseClass;
39
+ using WeightType = typename BaseClass::WeightType;
40
+
41
+ static constexpr char TMID[] = "PTM";
42
+
43
+ uint64_t numPDocs;
44
+ Float lambda;
45
+ uint32_t pseudoDocSamplingInterval = 10;
46
+
47
+ void optimizeParameters(ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
48
+ {
49
+ const auto K = this->K;
50
+ for (size_t i = 0; i < 10; ++i)
51
+ {
52
+ Float denom = this->calcDigammaSum(&pool, [&](size_t i) { return this->globalState.numByTopicPDoc.col(i).sum(); }, numPDocs, this->alphas.sum());
53
+ for (size_t k = 0; k < K; ++k)
54
+ {
55
+ Float nom = this->calcDigammaSum(&pool, [&](size_t i) { return this->globalState.numByTopicPDoc(k, i);}, numPDocs, this->alphas(k));
56
+ this->alphas(k) = std::max(nom / denom * this->alphas(k), 1e-5f);
57
+ }
58
+ }
59
+ }
60
+
61
+ void samplePseudoDoc(ThreadPool* pool, _ModelState& ld, _RandGen& rgs, _DocType& doc) const
62
+ {
63
+ if (doc.getSumWordWeight() == 0) return;
64
+ Eigen::Array<WeightType, -1, 1> docTopicDist = Eigen::Array<WeightType, -1, 1>::Zero(this->K);
65
+ for (size_t i = 0; i < doc.words.size(); ++i)
66
+ {
67
+ if (doc.words[i] >= this->realV) continue;
68
+ this->template addWordTo<-1>(ld, doc, i, doc.words[i], doc.Zs[i]);
69
+ typename std::conditional<_tw != TermWeight::one, float, int32_t>::type weight
70
+ = _tw != TermWeight::one ? doc.wordWeights[i] : 1;
71
+ docTopicDist[doc.Zs[i]] += weight;
72
+ }
73
+ --ld.numDocsByPDoc[doc.pseudoDoc];
74
+
75
+ if (pool)
76
+ {
77
+ std::vector<std::future<void>> futures;
78
+ for (size_t w = 0; w < pool->getNumWorkers(); ++w)
79
+ {
80
+ futures.emplace_back(pool->enqueue([&](size_t, size_t w)
81
+ {
82
+ for (size_t p = w; p < numPDocs; p += pool->getNumWorkers())
83
+ {
84
+ Float ax = math::lgammaSubt(ld.numByTopicPDoc.col(p).array().template cast<Float>() + this->alphas.array(), docTopicDist.template cast<Float>()).sum();
85
+ Float bx = math::lgammaSubt(ld.numByTopicPDoc.col(p).sum() + this->alphas.sum(), docTopicDist.sum());
86
+ ld.pLikelihood[p] = ax - bx;
87
+ }
88
+ }, w));
89
+ }
90
+ for (auto& f : futures) f.get();
91
+ }
92
+ else
93
+ {
94
+ for (size_t p = 0; p < numPDocs; ++p)
95
+ {
96
+ Float ax = math::lgammaSubt(ld.numByTopicPDoc.col(p).array().template cast<Float>() + this->alphas.array(), docTopicDist.template cast<Float>()).sum();
97
+ Float bx = math::lgammaSubt(ld.numByTopicPDoc.col(p).sum() + this->alphas.sum(), docTopicDist.sum());
98
+ ld.pLikelihood[p] = ax - bx;
99
+ }
100
+ }
101
+ ld.pLikelihood = (ld.pLikelihood - ld.pLikelihood.maxCoeff()).exp();
102
+ ld.pLikelihood *= ld.numDocsByPDoc.template cast<Float>() + lambda;
103
+
104
+ sample::prefixSum(ld.pLikelihood.data(), numPDocs);
105
+ doc.pseudoDoc = sample::sampleFromDiscreteAcc(ld.pLikelihood.data(), ld.pLikelihood.data() + numPDocs, rgs);
106
+
107
+ ++ld.numDocsByPDoc[doc.pseudoDoc];
108
+ doc.numByTopic.init(ld.numByTopicPDoc.col(doc.pseudoDoc).data(), this->K, 1);
109
+ for (size_t i = 0; i < doc.words.size(); ++i)
110
+ {
111
+ if (doc.words[i] >= this->realV) continue;
112
+ this->template addWordTo<1>(ld, doc, i, doc.words[i], doc.Zs[i]);
113
+ }
114
+ }
115
+
116
+ template<ParallelScheme _ps, bool _infer, typename _DocIter>
117
+ void performSamplingGlobal(ThreadPool* pool, _ModelState& globalState, _RandGen* rgs,
118
+ _DocIter docFirst, _DocIter docLast) const
119
+ {
120
+ if (this->globalStep % pseudoDocSamplingInterval) return;
121
+ for (; docFirst != docLast; ++docFirst)
122
+ {
123
+ samplePseudoDoc(pool, globalState, rgs[0], *docFirst);
124
+ }
125
+ }
126
+
127
+ template<typename _DocIter>
128
+ double getLLDocs(_DocIter _first, _DocIter _last) const
129
+ {
130
+ double ll = 0;
131
+ // doc-topic distribution
132
+ for (; _first != _last; ++_first)
133
+ {
134
+ auto& doc = *_first;
135
+ }
136
+ return ll;
137
+ }
138
+
139
+ double getLLRest(const _ModelState& ld) const
140
+ {
141
+ double ll = BaseClass::getLLRest(ld);
142
+ const size_t V = this->realV;
143
+ ll -= math::lgammaT(ld.numDocsByPDoc.sum() + lambda * numPDocs) - math::lgammaT(lambda * numPDocs);
144
+ // pseudo_doc-topic distribution
145
+ for (size_t p = 0; p < numPDocs; ++p)
146
+ {
147
+ ll += math::lgammaT(ld.numDocsByPDoc[p] + lambda) - math::lgammaT(lambda);
148
+ ll -= math::lgammaT(ld.numByTopicPDoc.col(p).sum() + this->alphas.sum()) - math::lgammaT(this->alphas.sum());
149
+ for (Tid k = 0; k < this->K; ++k)
150
+ {
151
+ ll += math::lgammaT(ld.numByTopicPDoc(k, p) + this->alphas[k]) - math::lgammaT(this->alphas[k]);
152
+ }
153
+ }
154
+ return ll;
155
+ }
156
+
157
+ void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
158
+ {
159
+ sortAndWriteOrder(doc.words, doc.wOrder);
160
+ doc.numByTopic.init((WeightType*)this->globalState.numByTopicPDoc.col(0).data(), this->K, 1);
161
+ doc.Zs = tvector<Tid>(wordSize);
162
+ if (_tw != TermWeight::one) doc.wordWeights.resize(wordSize);
163
+ }
164
+
165
+ void initGlobalState(bool initDocs)
166
+ {
167
+ this->alphas.resize(this->K);
168
+ this->alphas.array() = this->alpha;
169
+ this->globalState.pLikelihood = Eigen::Matrix<Float, -1, 1>::Zero(numPDocs);
170
+ this->globalState.numDocsByPDoc = Eigen::ArrayXi::Zero(numPDocs);
171
+ this->globalState.numByTopicPDoc = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K, numPDocs);
172
+ BaseClass::initGlobalState(initDocs);
173
+ }
174
+
175
+ struct Generator
176
+ {
177
+ std::uniform_int_distribution<uint64_t> psi;
178
+ std::uniform_int_distribution<Tid> theta;
179
+ };
180
+
181
+ Generator makeGeneratorForInit(const _DocType*) const
182
+ {
183
+ return Generator{
184
+ std::uniform_int_distribution<uint64_t>{0, numPDocs - 1},
185
+ std::uniform_int_distribution<Tid>{0, (Tid)(this->K - 1)}
186
+ };
187
+ }
188
+
189
+ template<bool _Infer>
190
+ void updateStateWithDoc(Generator& g, _ModelState& ld, _RandGen& rgs, _DocType& doc, size_t i) const
191
+ {
192
+ if (i == 0)
193
+ {
194
+ doc.pseudoDoc = g.psi(rgs);
195
+ ++ld.numDocsByPDoc[doc.pseudoDoc];
196
+ doc.numByTopic.init(ld.numByTopicPDoc.col(doc.pseudoDoc).data(), this->K, 1);
197
+ }
198
+ auto& z = doc.Zs[i];
199
+ auto w = doc.words[i];
200
+ if (this->etaByTopicWord.size())
201
+ {
202
+ auto col = this->etaByTopicWord.col(w);
203
+ z = sample::sampleFromDiscrete(col.data(), col.data() + col.size(), rgs);
204
+ }
205
+ else
206
+ {
207
+ z = g.theta(rgs);
208
+ }
209
+ this->template addWordTo<1>(ld, doc, i, w, z);
210
+ }
211
+
212
+ template<ParallelScheme _ps, bool _infer, typename _DocIter, typename _ExtraDocData>
213
+ void performSampling(ThreadPool& pool, _ModelState* localData, _RandGen* rgs, std::vector<std::future<void>>& res,
214
+ _DocIter docFirst, _DocIter docLast, const _ExtraDocData& edd) const
215
+ {
216
+ // single-threaded sampling
217
+ if (_ps == ParallelScheme::none)
218
+ {
219
+ forShuffled((size_t)std::distance(docFirst, docLast), rgs[0](), [&](size_t id)
220
+ {
221
+ static_cast<const DerivedClass*>(this)->presampleDocument(docFirst[id], id, *localData, *rgs, this->globalStep);
222
+ static_cast<const DerivedClass*>(this)->template sampleDocument<_ps, _infer>(
223
+ docFirst[id], edd, id,
224
+ *localData, *rgs, this->globalStep, 0);
225
+
226
+ });
227
+ }
228
+ // multi-threaded sampling on partition and update into global
229
+ else if (_ps == ParallelScheme::partition)
230
+ {
231
+ const size_t chStride = pool.getNumWorkers();
232
+ for (size_t i = 0; i < chStride; ++i)
233
+ {
234
+ res = pool.enqueueToAll([&, i, chStride](size_t partitionId)
235
+ {
236
+ forShuffled((size_t)std::distance(docFirst, docLast), rgs[partitionId](), [&](size_t id)
237
+ {
238
+ if ((docFirst[id].pseudoDoc + partitionId) % chStride != i) return;
239
+ static_cast<const DerivedClass*>(this)->template sampleDocument<_ps, _infer>(
240
+ docFirst[id], edd, id,
241
+ localData[partitionId], rgs[partitionId], this->globalStep, partitionId
242
+ );
243
+ });
244
+ });
245
+ for (auto& r : res) r.get();
246
+ res.clear();
247
+ }
248
+ }
249
+ else
250
+ {
251
+ throw std::runtime_error{ "Unsupported ParallelScheme" };
252
+ }
253
+ }
254
+
255
+ public:
256
+ DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, numPDocs, lambda);
257
+ DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, numPDocs, lambda);
258
+
259
+ PTModel(size_t _K = 1, size_t _P = 100, Float _alpha = 1.0, Float _eta = 0.01, Float _lambda = 0.01,
260
+ size_t _rg = std::random_device{}())
261
+ : BaseClass(_K, _alpha, _eta, _rg), numPDocs(_P), lambda(_lambda)
262
+ {
263
+ }
264
+
265
+ void updateDocs()
266
+ {
267
+ for (auto& doc : this->docs)
268
+ {
269
+ doc.template update<>(this->getTopicDocPtr(doc.pseudoDoc), *static_cast<DerivedClass*>(this));
270
+ }
271
+ }
272
+ };
273
+ }