tomoto 0.1.3 → 0.1.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +6 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -0
  5. data/ext/tomoto/ct.cpp +54 -0
  6. data/ext/tomoto/dmr.cpp +62 -0
  7. data/ext/tomoto/dt.cpp +82 -0
  8. data/ext/tomoto/ext.cpp +27 -773
  9. data/ext/tomoto/gdmr.cpp +34 -0
  10. data/ext/tomoto/hdp.cpp +42 -0
  11. data/ext/tomoto/hlda.cpp +66 -0
  12. data/ext/tomoto/hpa.cpp +27 -0
  13. data/ext/tomoto/lda.cpp +250 -0
  14. data/ext/tomoto/llda.cpp +29 -0
  15. data/ext/tomoto/mglda.cpp +71 -0
  16. data/ext/tomoto/pa.cpp +27 -0
  17. data/ext/tomoto/plda.cpp +29 -0
  18. data/ext/tomoto/slda.cpp +40 -0
  19. data/ext/tomoto/utils.h +84 -0
  20. data/lib/tomoto/tomoto.bundle +0 -0
  21. data/lib/tomoto/tomoto.so +0 -0
  22. data/lib/tomoto/version.rb +1 -1
  23. data/vendor/tomotopy/README.kr.rst +12 -3
  24. data/vendor/tomotopy/README.rst +12 -3
  25. data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +47 -2
  26. data/vendor/tomotopy/src/Labeling/FoRelevance.h +21 -151
  27. data/vendor/tomotopy/src/Labeling/Labeler.h +5 -3
  28. data/vendor/tomotopy/src/Labeling/Phraser.hpp +518 -0
  29. data/vendor/tomotopy/src/TopicModel/CTModel.hpp +6 -3
  30. data/vendor/tomotopy/src/TopicModel/DT.h +1 -1
  31. data/vendor/tomotopy/src/TopicModel/DTModel.hpp +8 -23
  32. data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +9 -18
  33. data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +56 -58
  34. data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +4 -14
  35. data/vendor/tomotopy/src/TopicModel/LDA.h +69 -17
  36. data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +1 -1
  37. data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +108 -61
  38. data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +7 -8
  39. data/vendor/tomotopy/src/TopicModel/PAModel.hpp +26 -16
  40. data/vendor/tomotopy/src/TopicModel/PT.h +27 -0
  41. data/vendor/tomotopy/src/TopicModel/PTModel.cpp +10 -0
  42. data/vendor/tomotopy/src/TopicModel/PTModel.hpp +273 -0
  43. data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +16 -11
  44. data/vendor/tomotopy/src/Utils/MultiNormalDistribution.hpp +3 -2
  45. data/vendor/tomotopy/src/Utils/Trie.hpp +39 -8
  46. data/vendor/tomotopy/src/Utils/TruncMultiNormal.hpp +36 -38
  47. data/vendor/tomotopy/src/Utils/Utils.hpp +50 -45
  48. data/vendor/tomotopy/src/Utils/math.h +8 -4
  49. data/vendor/tomotopy/src/Utils/tvector.hpp +4 -0
  50. metadata +24 -60
@@ -107,18 +107,16 @@ namespace tomoto
107
107
  e = edd.chunkOffsetByDoc(partitionId + 1, docId);
108
108
  }
109
109
 
110
- size_t vOffset = (_ps == ParallelScheme::partition && partitionId) ? edd.vChunkOffset[partitionId - 1] : 0;
111
-
112
110
  const auto K = this->K;
113
111
  for (size_t w = b; w < e; ++w)
114
112
  {
115
113
  if (doc.words[w] >= this->realV) continue;
116
- addWordTo<-1>(ld, doc, w, doc.words[w] - vOffset, doc.Zs[w] - (doc.Zs[w] < K ? 0 : K), doc.sents[w], doc.Vs[w], doc.Zs[w] < K ? 0 : 1);
117
- auto dist = getVZLikelihoods(ld, doc, doc.words[w] - vOffset, doc.sents[w]);
114
+ addWordTo<-1>(ld, doc, w, doc.words[w], doc.Zs[w] - (doc.Zs[w] < K ? 0 : K), doc.sents[w], doc.Vs[w], doc.Zs[w] < K ? 0 : 1);
115
+ auto dist = getVZLikelihoods(ld, doc, doc.words[w], doc.sents[w]);
118
116
  auto vz = sample::sampleFromDiscreteAcc(dist, dist + T * (K + KL), rgs);
119
117
  doc.Vs[w] = vz / (K + KL);
120
118
  doc.Zs[w] = vz % (K + KL);
121
- addWordTo<1>(ld, doc, w, doc.words[w] - vOffset, doc.Zs[w] - (doc.Zs[w] < K ? 0 : K), doc.sents[w], doc.Vs[w], doc.Zs[w] < K ? 0 : 1);
119
+ addWordTo<1>(ld, doc, w, doc.words[w], doc.Zs[w] - (doc.Zs[w] < K ? 0 : K), doc.sents[w], doc.Vs[w], doc.Zs[w] < K ? 0 : 1);
122
120
  }
123
121
  }
124
122
 
@@ -294,7 +292,7 @@ namespace tomoto
294
292
  doc.Zs = tvector<Tid>(wordSize);
295
293
  doc.Vs.resize(wordSize);
296
294
  if (_tw != TermWeight::one) doc.wordWeights.resize(wordSize);
297
- doc.numByTopic.init(nullptr, this->K + KL);
295
+ doc.numByTopic.init(nullptr, this->K + KL, 1);
298
296
  doc.numBySentWin = Eigen::Matrix<WeightType, -1, -1>::Zero(S, T);
299
297
  doc.numByWin = Eigen::Matrix<WeightType, -1, 1>::Zero(S + T - 1);
300
298
  doc.numByWinL = Eigen::Matrix<WeightType, -1, 1>::Zero(S + T - 1);
@@ -308,7 +306,8 @@ namespace tomoto
308
306
  if (initDocs)
309
307
  {
310
308
  this->globalState.numByTopic = Eigen::Matrix<WeightType, -1, 1>::Zero(this->K + KL);
311
- this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K + KL, V);
309
+ //this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K + KL, V);
310
+ this->globalState.numByTopicWord.init(nullptr, this->K + KL, V);
312
311
  }
313
312
  }
314
313
 
@@ -533,7 +532,7 @@ namespace tomoto
533
532
  template<typename _TopicModel>
534
533
  void DocumentMGLDA<_tw>::update(WeightType * ptr, const _TopicModel & mdl)
535
534
  {
536
- this->numByTopic.init(ptr, mdl.getK() + mdl.getKL());
535
+ this->numByTopic.init(ptr, mdl.getK() + mdl.getKL(), 1);
537
536
  numBySent.resize(*std::max_element(sents.begin(), sents.end()) + 1);
538
537
  for (size_t i = 0; i < this->Zs.size(); ++i)
539
538
  {
@@ -144,7 +144,7 @@ namespace tomoto
144
144
  size_t b = partitionId ? edd.vChunkOffset[partitionId - 1] : 0,
145
145
  e = edd.vChunkOffset[partitionId];
146
146
 
147
- localData[partitionId].numByTopicWord = globalState.numByTopicWord.block(0, b, globalState.numByTopicWord.rows(), e - b);
147
+ localData[partitionId].numByTopicWord.matrix() = globalState.numByTopicWord.block(0, b, globalState.numByTopicWord.rows(), e - b);
148
148
  localData[partitionId].numByTopic = globalState.numByTopic;
149
149
  localData[partitionId].numByTopic1_2 = globalState.numByTopic1_2;
150
150
  localData[partitionId].numByTopic2 = globalState.numByTopic2;
@@ -157,8 +157,6 @@ namespace tomoto
157
157
  template<ParallelScheme _ps, typename _ExtraDocData>
158
158
  void mergeState(ThreadPool& pool, _ModelState& globalState, _ModelState& tState, _ModelState* localData, _RandGen*, const _ExtraDocData& edd) const
159
159
  {
160
- std::vector<std::future<void>> res;
161
-
162
160
  if (_ps == ParallelScheme::copy_merge)
163
161
  {
164
162
  tState = globalState;
@@ -177,19 +175,12 @@ namespace tomoto
177
175
  globalState.numByTopic = globalState.numByTopic.cwiseMax(0);
178
176
  globalState.numByTopic1_2 = globalState.numByTopic1_2.cwiseMax(0);
179
177
  globalState.numByTopic2 = globalState.numByTopic2.cwiseMax(0);
180
- globalState.numByTopicWord = globalState.numByTopicWord.cwiseMax(0);
181
- }
182
-
183
- for (size_t i = 0; i < pool.getNumWorkers(); ++i)
184
- {
185
- res.emplace_back(pool.enqueue([&, this, i](size_t threadId)
186
- {
187
- localData[i] = globalState;
188
- }));
178
+ globalState.numByTopicWord.matrix() = globalState.numByTopicWord.cwiseMax(0);
189
179
  }
190
180
  }
191
181
  else if (_ps == ParallelScheme::partition)
192
182
  {
183
+ std::vector<std::future<void>> res;
193
184
  res = pool.enqueueToAll([&](size_t partitionId)
194
185
  {
195
186
  size_t b = partitionId ? edd.vChunkOffset[partitionId - 1] : 0,
@@ -197,7 +188,6 @@ namespace tomoto
197
188
  globalState.numByTopicWord.block(0, b, globalState.numByTopicWord.rows(), e - b) = localData[partitionId].numByTopicWord;
198
189
  });
199
190
  for (auto& r : res) r.get();
200
- res.clear();
201
191
 
202
192
  tState.numByTopic1_2 = globalState.numByTopic1_2;
203
193
  globalState.numByTopic1_2 = localData[0].numByTopic1_2;
@@ -209,11 +199,31 @@ namespace tomoto
209
199
  // make all counts non-negative
210
200
  if (_tw != TermWeight::one)
211
201
  {
212
- globalState.numByTopicWord = globalState.numByTopicWord.cwiseMax(0);
202
+ globalState.numByTopicWord.matrix() = globalState.numByTopicWord.cwiseMax(0);
213
203
  }
214
204
  globalState.numByTopic = globalState.numByTopic1_2.rowwise().sum();
215
205
  globalState.numByTopic2 = globalState.numByTopicWord.rowwise().sum();
216
206
 
207
+ }
208
+ }
209
+
210
+
211
+ template<ParallelScheme _ps>
212
+ void distributeMergedState(ThreadPool& pool, _ModelState& globalState, _ModelState* localData) const
213
+ {
214
+ std::vector<std::future<void>> res;
215
+ if (_ps == ParallelScheme::copy_merge)
216
+ {
217
+ for (size_t i = 0; i < pool.getNumWorkers(); ++i)
218
+ {
219
+ res.emplace_back(pool.enqueue([&, i](size_t)
220
+ {
221
+ localData[i] = globalState;
222
+ }));
223
+ }
224
+ }
225
+ else if (_ps == ParallelScheme::partition)
226
+ {
217
227
  res = pool.enqueueToAll([&](size_t threadId)
218
228
  {
219
229
  localData[threadId].numByTopic = globalState.numByTopic;
@@ -221,7 +231,6 @@ namespace tomoto
221
231
  localData[threadId].numByTopic2 = globalState.numByTopic2;
222
232
  });
223
233
  }
224
-
225
234
  for (auto& r : res) r.get();
226
235
  }
227
236
 
@@ -304,7 +313,8 @@ namespace tomoto
304
313
  this->globalState.numByTopic = Eigen::Matrix<WeightType, -1, 1>::Zero(this->K);
305
314
  this->globalState.numByTopic2 = Eigen::Matrix<WeightType, -1, 1>::Zero(K2);
306
315
  this->globalState.numByTopic1_2 = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K, K2);
307
- this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(K2, V);
316
+ //this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(K2, V);
317
+ this->globalState.numByTopicWord.init(nullptr, K2, V);
308
318
  }
309
319
  }
310
320
 
@@ -0,0 +1,27 @@
1
+ #pragma once
2
+ #include "LDA.h"
3
+
4
+ namespace tomoto
5
+ {
6
+ template<TermWeight _tw>
7
+ struct DocumentPTM : public DocumentLDA<_tw>
8
+ {
9
+ using BaseDocument = DocumentLDA<_tw>;
10
+ using DocumentLDA<_tw>::DocumentLDA;
11
+ using WeightType = typename DocumentLDA<_tw>::WeightType;
12
+
13
+ uint64_t pseudoDoc = 0;
14
+
15
+ DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 0, pseudoDoc);
16
+ DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 1, 0x00010001, pseudoDoc);
17
+ };
18
+
19
+ class IPTModel : public ILDAModel
20
+ {
21
+ public:
22
+ using DefaultDocType = DocumentPTM<TermWeight::one>;
23
+ static IPTModel* create(TermWeight _weight, size_t _K = 1, size_t _P = 100,
24
+ Float alpha = 0.1, Float eta = 0.01, Float lambda = 0.01, size_t seed = std::random_device{}(),
25
+ bool scalarRng = false);
26
+ };
27
+ }
@@ -0,0 +1,10 @@
1
+ #include "PTModel.hpp"
2
+
3
+ namespace tomoto
4
+ {
5
+
6
// Factory behind IPTModel::create: hands the term-weight selector, the scalarRng flag,
// the PTModel template name and the constructor arguments (_K, _P, _alpha, _eta, _lambda, seed)
// to TMT_SWITCH_TW.
// NOTE(review): TMT_SWITCH_TW is defined elsewhere; presumably it instantiates and returns
// the matching PTModel specialisation - confirm that every path of the macro returns a value.
+ IPTModel* IPTModel::create(TermWeight _weight, size_t _K, size_t _P, Float _alpha, Float _eta, Float _lambda, size_t seed, bool scalarRng)
7
+ {
8
+ TMT_SWITCH_TW(_weight, scalarRng, PTModel, _K, _P, _alpha, _eta, _lambda, seed);
9
+ }
10
+ }
@@ -0,0 +1,273 @@
1
+ #pragma once
2
+ #include "LDAModel.hpp"
3
+ #include "PT.h"
4
+
5
+ /*
6
+ Implementation of Pseudo-document topic model using Gibbs sampling by bab2min
7
+
8
+ Zuo, Y., Wu, J., Zhang, H., Lin, H., Wang, F., Xu, K., & Xiong, H. (2016, August). Topic modeling of short texts: A pseudo-document view. In Proceedings of the 22nd ACM SIGKDD international conference on knowledge discovery and data mining (pp. 2105-2114).
9
+ */
10
+
11
+ namespace tomoto
12
+ {
13
+ template<TermWeight _tw>
14
+ struct ModelStatePTM : public ModelStateLDA<_tw>
15
+ {
16
+ using WeightType = typename ModelStateLDA<_tw>::WeightType;
17
+
18
+ Eigen::Array<Float, -1, 1> pLikelihood;
19
+ Eigen::ArrayXi numDocsByPDoc;
20
+ Eigen::Matrix<WeightType, -1, -1> numByTopicPDoc;
21
+
22
+ //DEFINE_SERIALIZER_AFTER_BASE(ModelStateLDA<_tw>);
23
+ };
24
+
25
+ template<TermWeight _tw, typename _RandGen,
26
+ typename _Interface = IPTModel,
27
+ typename _Derived = void,
28
+ typename _DocType = DocumentPTM<_tw>,
29
+ typename _ModelState = ModelStatePTM<_tw>>
30
+ class PTModel : public LDAModel<_tw, _RandGen, flags::continuous_doc_data | flags::partitioned_multisampling, _Interface,
31
+ typename std::conditional<std::is_same<_Derived, void>::value, PTModel<_tw, _RandGen>, _Derived>::type,
32
+ _DocType, _ModelState>
33
+ {
34
+ protected:
35
+ using DerivedClass = typename std::conditional<std::is_same<_Derived, void>::value, PTModel<_tw, _RandGen>, _Derived>::type;
36
+ using BaseClass = LDAModel<_tw, _RandGen, flags::continuous_doc_data | flags::partitioned_multisampling, _Interface, DerivedClass, _DocType, _ModelState>;
37
+ friend BaseClass;
38
+ friend typename BaseClass::BaseClass;
39
+ using WeightType = typename BaseClass::WeightType;
40
+
41
+ static constexpr char TMID[] = "PTM";
42
+
43
+ uint64_t numPDocs;
44
+ Float lambda;
45
+ uint32_t pseudoDocSamplingInterval = 10;
46
+
47
+ void optimizeParameters(ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
48
+ {
49
+ const auto K = this->K;
50
+ for (size_t i = 0; i < 10; ++i)
51
+ {
52
+ Float denom = this->calcDigammaSum(&pool, [&](size_t i) { return this->globalState.numByTopicPDoc.col(i).sum(); }, numPDocs, this->alphas.sum());
53
+ for (size_t k = 0; k < K; ++k)
54
+ {
55
+ Float nom = this->calcDigammaSum(&pool, [&](size_t i) { return this->globalState.numByTopicPDoc(k, i);}, numPDocs, this->alphas(k));
56
+ this->alphas(k) = std::max(nom / denom * this->alphas(k), 1e-5f);
57
+ }
58
+ }
59
+ }
60
+
61
+ void samplePseudoDoc(ThreadPool* pool, _ModelState& ld, _RandGen& rgs, _DocType& doc) const
62
+ {
63
+ if (doc.getSumWordWeight() == 0) return;
64
+ Eigen::Array<WeightType, -1, 1> docTopicDist = Eigen::Array<WeightType, -1, 1>::Zero(this->K);
65
+ for (size_t i = 0; i < doc.words.size(); ++i)
66
+ {
67
+ if (doc.words[i] >= this->realV) continue;
68
+ this->template addWordTo<-1>(ld, doc, i, doc.words[i], doc.Zs[i]);
69
+ typename std::conditional<_tw != TermWeight::one, float, int32_t>::type weight
70
+ = _tw != TermWeight::one ? doc.wordWeights[i] : 1;
71
+ docTopicDist[doc.Zs[i]] += weight;
72
+ }
73
+ --ld.numDocsByPDoc[doc.pseudoDoc];
74
+
75
+ if (pool)
76
+ {
77
+ std::vector<std::future<void>> futures;
78
+ for (size_t w = 0; w < pool->getNumWorkers(); ++w)
79
+ {
80
+ futures.emplace_back(pool->enqueue([&](size_t, size_t w)
81
+ {
82
+ for (size_t p = w; p < numPDocs; p += pool->getNumWorkers())
83
+ {
84
+ Float ax = math::lgammaSubt(ld.numByTopicPDoc.col(p).array().template cast<Float>() + this->alphas.array(), docTopicDist.template cast<Float>()).sum();
85
+ Float bx = math::lgammaSubt(ld.numByTopicPDoc.col(p).sum() + this->alphas.sum(), docTopicDist.sum());
86
+ ld.pLikelihood[p] = ax - bx;
87
+ }
88
+ }, w));
89
+ }
90
+ for (auto& f : futures) f.get();
91
+ }
92
+ else
93
+ {
94
+ for (size_t p = 0; p < numPDocs; ++p)
95
+ {
96
+ Float ax = math::lgammaSubt(ld.numByTopicPDoc.col(p).array().template cast<Float>() + this->alphas.array(), docTopicDist.template cast<Float>()).sum();
97
+ Float bx = math::lgammaSubt(ld.numByTopicPDoc.col(p).sum() + this->alphas.sum(), docTopicDist.sum());
98
+ ld.pLikelihood[p] = ax - bx;
99
+ }
100
+ }
101
+ ld.pLikelihood = (ld.pLikelihood - ld.pLikelihood.maxCoeff()).exp();
102
+ ld.pLikelihood *= ld.numDocsByPDoc.template cast<Float>() + lambda;
103
+
104
+ sample::prefixSum(ld.pLikelihood.data(), numPDocs);
105
+ doc.pseudoDoc = sample::sampleFromDiscreteAcc(ld.pLikelihood.data(), ld.pLikelihood.data() + numPDocs, rgs);
106
+
107
+ ++ld.numDocsByPDoc[doc.pseudoDoc];
108
+ doc.numByTopic.init(ld.numByTopicPDoc.col(doc.pseudoDoc).data(), this->K, 1);
109
+ for (size_t i = 0; i < doc.words.size(); ++i)
110
+ {
111
+ if (doc.words[i] >= this->realV) continue;
112
+ this->template addWordTo<1>(ld, doc, i, doc.words[i], doc.Zs[i]);
113
+ }
114
+ }
115
+
116
+ template<ParallelScheme _ps, bool _infer, typename _DocIter>
117
+ void performSamplingGlobal(ThreadPool* pool, _ModelState& globalState, _RandGen* rgs,
118
+ _DocIter docFirst, _DocIter docLast) const
119
+ {
120
+ if (this->globalStep % pseudoDocSamplingInterval) return;
121
+ for (; docFirst != docLast; ++docFirst)
122
+ {
123
+ samplePseudoDoc(pool, globalState, rgs[0], *docFirst);
124
+ }
125
+ }
126
+
127
+ template<typename _DocIter>
128
+ double getLLDocs(_DocIter _first, _DocIter _last) const
129
+ {
130
+ double ll = 0;
131
+ // doc-topic distribution
132
+ for (; _first != _last; ++_first)
133
+ {
134
+ auto& doc = *_first;
135
+ }
136
+ return ll;
137
+ }
138
+
139
+ double getLLRest(const _ModelState& ld) const
140
+ {
141
+ double ll = BaseClass::getLLRest(ld);
142
+ const size_t V = this->realV;
143
+ ll -= math::lgammaT(ld.numDocsByPDoc.sum() + lambda * numPDocs) - math::lgammaT(lambda * numPDocs);
144
+ // pseudo_doc-topic distribution
145
+ for (size_t p = 0; p < numPDocs; ++p)
146
+ {
147
+ ll += math::lgammaT(ld.numDocsByPDoc[p] + lambda) - math::lgammaT(lambda);
148
+ ll -= math::lgammaT(ld.numByTopicPDoc.col(p).sum() + this->alphas.sum()) - math::lgammaT(this->alphas.sum());
149
+ for (Tid k = 0; k < this->K; ++k)
150
+ {
151
+ ll += math::lgammaT(ld.numByTopicPDoc(k, p) + this->alphas[k]) - math::lgammaT(this->alphas[k]);
152
+ }
153
+ }
154
+ return ll;
155
+ }
156
+
157
+ void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
158
+ {
159
+ sortAndWriteOrder(doc.words, doc.wOrder);
160
+ doc.numByTopic.init((WeightType*)this->globalState.numByTopicPDoc.col(0).data(), this->K, 1);
161
+ doc.Zs = tvector<Tid>(wordSize);
162
+ if (_tw != TermWeight::one) doc.wordWeights.resize(wordSize);
163
+ }
164
+
165
+ void initGlobalState(bool initDocs)
166
+ {
167
+ this->alphas.resize(this->K);
168
+ this->alphas.array() = this->alpha;
169
+ this->globalState.pLikelihood = Eigen::Matrix<Float, -1, 1>::Zero(numPDocs);
170
+ this->globalState.numDocsByPDoc = Eigen::ArrayXi::Zero(numPDocs);
171
+ this->globalState.numByTopicPDoc = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K, numPDocs);
172
+ BaseClass::initGlobalState(initDocs);
173
+ }
174
+
175
// Pair of uniform integer distributions used when documents are first added to the state
// (see makeGeneratorForInit / updateStateWithDoc).
// Field order matters: makeGeneratorForInit brace-initialises the struct positionally.
+ struct Generator
176
+ {
177
+ std::uniform_int_distribution<uint64_t> psi; // draws the initial pseudo-document index of a document
178
+ std::uniform_int_distribution<Tid> theta; // draws the initial topic of a word
179
+ };
180
+
181
+ Generator makeGeneratorForInit(const _DocType*) const
182
+ {
183
+ return Generator{
184
+ std::uniform_int_distribution<uint64_t>{0, numPDocs - 1},
185
+ std::uniform_int_distribution<Tid>{0, (Tid)(this->K - 1)}
186
+ };
187
+ }
188
+
189
+ template<bool _Infer>
190
+ void updateStateWithDoc(Generator& g, _ModelState& ld, _RandGen& rgs, _DocType& doc, size_t i) const
191
+ {
192
+ if (i == 0)
193
+ {
194
+ doc.pseudoDoc = g.psi(rgs);
195
+ ++ld.numDocsByPDoc[doc.pseudoDoc];
196
+ doc.numByTopic.init(ld.numByTopicPDoc.col(doc.pseudoDoc).data(), this->K, 1);
197
+ }
198
+ auto& z = doc.Zs[i];
199
+ auto w = doc.words[i];
200
+ if (this->etaByTopicWord.size())
201
+ {
202
+ auto col = this->etaByTopicWord.col(w);
203
+ z = sample::sampleFromDiscrete(col.data(), col.data() + col.size(), rgs);
204
+ }
205
+ else
206
+ {
207
+ z = g.theta(rgs);
208
+ }
209
+ this->template addWordTo<1>(ld, doc, i, w, z);
210
+ }
211
+
212
+ template<ParallelScheme _ps, bool _infer, typename _DocIter, typename _ExtraDocData>
213
+ void performSampling(ThreadPool& pool, _ModelState* localData, _RandGen* rgs, std::vector<std::future<void>>& res,
214
+ _DocIter docFirst, _DocIter docLast, const _ExtraDocData& edd) const
215
+ {
216
+ // single-threaded sampling
217
+ if (_ps == ParallelScheme::none)
218
+ {
219
+ forShuffled((size_t)std::distance(docFirst, docLast), rgs[0](), [&](size_t id)
220
+ {
221
+ static_cast<const DerivedClass*>(this)->presampleDocument(docFirst[id], id, *localData, *rgs, this->globalStep);
222
+ static_cast<const DerivedClass*>(this)->template sampleDocument<_ps, _infer>(
223
+ docFirst[id], edd, id,
224
+ *localData, *rgs, this->globalStep, 0);
225
+
226
+ });
227
+ }
228
+ // multi-threaded sampling on partition and update into global
229
+ else if (_ps == ParallelScheme::partition)
230
+ {
231
+ const size_t chStride = pool.getNumWorkers();
232
+ for (size_t i = 0; i < chStride; ++i)
233
+ {
234
+ res = pool.enqueueToAll([&, i, chStride](size_t partitionId)
235
+ {
236
+ forShuffled((size_t)std::distance(docFirst, docLast), rgs[partitionId](), [&](size_t id)
237
+ {
238
+ if ((docFirst[id].pseudoDoc + partitionId) % chStride != i) return;
239
+ static_cast<const DerivedClass*>(this)->template sampleDocument<_ps, _infer>(
240
+ docFirst[id], edd, id,
241
+ localData[partitionId], rgs[partitionId], this->globalStep, partitionId
242
+ );
243
+ });
244
+ });
245
+ for (auto& r : res) r.get();
246
+ res.clear();
247
+ }
248
+ }
249
+ else
250
+ {
251
+ throw std::runtime_error{ "Unsupported ParallelScheme" };
252
+ }
253
+ }
254
+
255
+ public:
256
+ DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, numPDocs, lambda);
257
+ DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, numPDocs, lambda);
258
+
259
+ PTModel(size_t _K = 1, size_t _P = 100, Float _alpha = 1.0, Float _eta = 0.01, Float _lambda = 0.01,
260
+ size_t _rg = std::random_device{}())
261
+ : BaseClass(_K, _alpha, _eta, _rg), numPDocs(_P), lambda(_lambda)
262
+ {
263
+ }
264
+
265
+ void updateDocs()
266
+ {
267
+ for (auto& doc : this->docs)
268
+ {
269
+ doc.template update<>(this->getTopicDocPtr(doc.pseudoDoc), *static_cast<DerivedClass*>(this));
270
+ }
271
+ }
272
+ };
273
+ }