tomoto 0.1.3 → 0.1.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +6 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -0
  5. data/ext/tomoto/ct.cpp +54 -0
  6. data/ext/tomoto/dmr.cpp +62 -0
  7. data/ext/tomoto/dt.cpp +82 -0
  8. data/ext/tomoto/ext.cpp +27 -773
  9. data/ext/tomoto/gdmr.cpp +34 -0
  10. data/ext/tomoto/hdp.cpp +42 -0
  11. data/ext/tomoto/hlda.cpp +66 -0
  12. data/ext/tomoto/hpa.cpp +27 -0
  13. data/ext/tomoto/lda.cpp +250 -0
  14. data/ext/tomoto/llda.cpp +29 -0
  15. data/ext/tomoto/mglda.cpp +71 -0
  16. data/ext/tomoto/pa.cpp +27 -0
  17. data/ext/tomoto/plda.cpp +29 -0
  18. data/ext/tomoto/slda.cpp +40 -0
  19. data/ext/tomoto/utils.h +84 -0
  20. data/lib/tomoto/tomoto.bundle +0 -0
  21. data/lib/tomoto/tomoto.so +0 -0
  22. data/lib/tomoto/version.rb +1 -1
  23. data/vendor/tomotopy/README.kr.rst +12 -3
  24. data/vendor/tomotopy/README.rst +12 -3
  25. data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +47 -2
  26. data/vendor/tomotopy/src/Labeling/FoRelevance.h +21 -151
  27. data/vendor/tomotopy/src/Labeling/Labeler.h +5 -3
  28. data/vendor/tomotopy/src/Labeling/Phraser.hpp +518 -0
  29. data/vendor/tomotopy/src/TopicModel/CTModel.hpp +6 -3
  30. data/vendor/tomotopy/src/TopicModel/DT.h +1 -1
  31. data/vendor/tomotopy/src/TopicModel/DTModel.hpp +8 -23
  32. data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +9 -18
  33. data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +56 -58
  34. data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +4 -14
  35. data/vendor/tomotopy/src/TopicModel/LDA.h +69 -17
  36. data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +1 -1
  37. data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +108 -61
  38. data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +7 -8
  39. data/vendor/tomotopy/src/TopicModel/PAModel.hpp +26 -16
  40. data/vendor/tomotopy/src/TopicModel/PT.h +27 -0
  41. data/vendor/tomotopy/src/TopicModel/PTModel.cpp +10 -0
  42. data/vendor/tomotopy/src/TopicModel/PTModel.hpp +273 -0
  43. data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +16 -11
  44. data/vendor/tomotopy/src/Utils/MultiNormalDistribution.hpp +3 -2
  45. data/vendor/tomotopy/src/Utils/Trie.hpp +39 -8
  46. data/vendor/tomotopy/src/Utils/TruncMultiNormal.hpp +36 -38
  47. data/vendor/tomotopy/src/Utils/Utils.hpp +50 -45
  48. data/vendor/tomotopy/src/Utils/math.h +8 -4
  49. data/vendor/tomotopy/src/Utils/tvector.hpp +4 -0
  50. metadata +24 -60
@@ -65,6 +65,7 @@ namespace tomoto
65
65
  {
66
66
  if (i == 0) pbeta = Eigen::Matrix<Float, -1, 1>::Ones(this->K);
67
67
  else pbeta = doc.beta.col(i % numBetaSample).array().exp();
68
+
68
69
  Float betaESum = pbeta.sum() + 1;
69
70
  pbeta /= betaESum;
70
71
  for (size_t k = 0; k < this->K; ++k)
@@ -78,7 +79,9 @@ namespace tomoto
78
79
 
79
80
  Float c = betaESum * (1 - pbeta[k]);
80
81
  lowerBound[k] = log(c * max_uk / (1 - max_uk));
81
- upperBound[k] = log(c * min_unk / (1 - min_unk));
82
+ lowerBound[k] = std::max(std::min(lowerBound[k], (Float)100), (Float)-100);
83
+ upperBound[k] = log(c * min_unk / (1 - min_unk + epsilon));
84
+ upperBound[k] = std::max(std::min(upperBound[k], (Float)100), (Float)-100);
82
85
  if (lowerBound[k] > upperBound[k])
83
86
  {
84
87
  THROW_ERROR_WITH_INFO(exception::TrainingError,
@@ -120,8 +123,8 @@ namespace tomoto
120
123
  }*/
121
124
  }
122
125
 
123
- template<typename _DocIter>
124
- void sampleGlobalLevel(ThreadPool* pool, _ModelState* localData, _RandGen* rgs, _DocIter first, _DocIter last) const
126
+ template<GlobalSampler _gs, typename _DocIter>
127
+ void sampleGlobalLevel(ThreadPool* pool, _ModelState*, _RandGen* rgs, _DocIter first, _DocIter last) const
125
128
  {
126
129
  if (this->globalStep < this->burnIn || !this->optimInterval || (this->globalStep + 1) % this->optimInterval != 0) return;
127
130
 
@@ -11,7 +11,7 @@ namespace tomoto
11
11
  using DocumentLDA<_tw>::DocumentLDA;
12
12
 
13
13
  uint64_t timepoint = 0;
14
- ShareableVector<Float> eta;
14
+ ShareableMatrix<Float, -1, 1> eta;
15
15
  sample::AliasMethod<> aliasTable;
16
16
 
17
17
  DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 0, timepoint);
@@ -20,6 +20,7 @@ namespace tomoto
20
20
 
21
21
  Eigen::Matrix<WeightType, -1, -1> numByTopic; // Dim: (Topic, Time)
22
22
  Eigen::Matrix<WeightType, -1, -1> numByTopicWord; // Dim: (Topic * Time, Vocabs)
23
+ //ShareableMatrix<WeightType, -1, -1> numByTopicWord; // Dim: (Topic * Time, Vocabs)
23
24
  DEFINE_SERIALIZER(numByTopic, numByTopicWord);
24
25
  };
25
26
 
@@ -139,8 +140,6 @@ namespace tomoto
139
140
  template<ParallelScheme _ps, typename _ExtraDocData>
140
141
  void mergeState(ThreadPool& pool, _ModelState& globalState, _ModelState& tState, _ModelState* localData, _RandGen*, const _ExtraDocData& edd) const
141
142
  {
142
- std::vector<std::future<void>> res;
143
-
144
143
  if (_ps == ParallelScheme::copy_merge)
145
144
  {
146
145
  tState = globalState;
@@ -157,17 +156,10 @@ namespace tomoto
157
156
  }
158
157
  Eigen::Map<Eigen::Matrix<WeightType, -1, 1>>{ globalState.numByTopic.data(), globalState.numByTopic.size() }
159
158
  = globalState.numByTopicWord.rowwise().sum();
160
-
161
- for (size_t i = 0; i < pool.getNumWorkers(); ++i)
162
- {
163
- res.emplace_back(pool.enqueue([&, i](size_t)
164
- {
165
- localData[i] = globalState;
166
- }));
167
- }
168
159
  }
169
160
  else if (_ps == ParallelScheme::partition)
170
161
  {
162
+ std::vector<std::future<void>> res;
171
163
  res = pool.enqueueToAll([&](size_t partitionId)
172
164
  {
173
165
  size_t b = partitionId ? edd.vChunkOffset[partitionId - 1] : 0,
@@ -175,7 +167,6 @@ namespace tomoto
175
167
  globalState.numByTopicWord.block(0, b, globalState.numByTopicWord.rows(), e - b) = localData[partitionId].numByTopicWord;
176
168
  });
177
169
  for (auto& r : res) r.get();
178
- res.clear();
179
170
 
180
171
  // make all count being positive
181
172
  if (_tw != TermWeight::one)
@@ -184,17 +175,11 @@ namespace tomoto
184
175
  }
185
176
  Eigen::Map<Eigen::Matrix<WeightType, -1, 1>>{ globalState.numByTopic.data(), globalState.numByTopic.size() }
186
177
  = globalState.numByTopicWord.rowwise().sum();
187
-
188
- res = pool.enqueueToAll([&](size_t threadId)
189
- {
190
- localData[threadId].numByTopic = globalState.numByTopic;
191
- });
192
178
  }
193
- for (auto& r : res) r.get();
194
179
  }
195
180
 
196
181
  template<typename _DocIter>
197
- void sampleGlobalLevel(ThreadPool* pool, _ModelState* localData, _RandGen* rgs, _DocIter first, _DocIter last)
182
+ void _sampleGlobalLevel(ThreadPool* pool, _ModelState*, _RandGen* rgs, _DocIter first, _DocIter last)
198
183
  {
199
184
  const auto K = this->K;
200
185
  const Float eps = shapeA * (std::pow(shapeB + 1 + this->globalStep, -shapeC));
@@ -313,10 +298,10 @@ namespace tomoto
313
298
  alphas = newAlphas;
314
299
  }
315
300
 
316
- template<typename _DocIter>
301
+ template<GlobalSampler _gs, typename _DocIter>
317
302
  void sampleGlobalLevel(ThreadPool* pool, _ModelState* localData, _RandGen* rgs, _DocIter first, _DocIter last) const
318
303
  {
319
- // do nothing
304
+ if (_gs != GlobalSampler::inference) return const_cast<DerivedClass*>(this)->_sampleGlobalLevel(pool, localData, rgs, first, last);
320
305
  }
321
306
 
322
307
  void optimizeParameters(ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
@@ -343,11 +328,11 @@ namespace tomoto
343
328
  BaseClass::prepareDoc(doc, docId, wordSize);
344
329
  if (docId == (size_t)-1)
345
330
  {
346
- doc.eta.init(nullptr, this->K);
331
+ doc.eta.init(nullptr, this->K, 1);
347
332
  }
348
333
  else
349
334
  {
350
- doc.eta.init((Float*)etaByDoc.col(docId).data(), this->K);
335
+ doc.eta.init((Float*)etaByDoc.col(docId).data(), this->K, 1);
351
336
  }
352
337
  }
353
338
 
@@ -427,7 +412,7 @@ namespace tomoto
427
412
  numDocsByTime[doc.timepoint]++;
428
413
  if (!initDocs)
429
414
  {
430
- doc.eta.init((Float*)etaByDoc.col(docId++).data(), this->K);
415
+ doc.eta.init((Float*)etaByDoc.col(docId++).data(), this->K, 1);
431
416
  }
432
417
  }
433
418
 
@@ -96,7 +96,7 @@ namespace tomoto
96
96
  ld.numTableByTopic.tail(newSize - oldSize).setZero();
97
97
  ld.numByTopic.conservativeResize(newSize);
98
98
  ld.numByTopic.tail(newSize - oldSize).setZero();
99
- ld.numByTopicWord.conservativeResize(newSize, Eigen::NoChange);
99
+ ld.numByTopicWord.conservativeResize(newSize, V);
100
100
  ld.numByTopicWord.block(oldSize, 0, newSize - oldSize, V).setZero();
101
101
  }
102
102
  else
@@ -155,7 +155,7 @@ namespace tomoto
155
155
  if (_inc > 0 && tid >= doc.numByTopic.size())
156
156
  {
157
157
  size_t oldSize = doc.numByTopic.size();
158
- doc.numByTopic.conservativeResize(tid + 1);
158
+ doc.numByTopic.conservativeResize(tid + 1, 1);
159
159
  doc.numByTopic.tail(tid + 1 - oldSize).setZero();
160
160
  }
161
161
  constexpr bool _dec = _inc < 0 && _tw != TermWeight::one;
@@ -282,7 +282,7 @@ namespace tomoto
282
282
  auto& doc = this->docs[j];
283
283
  if (doc.numByTopic.size() >= K) continue;
284
284
  size_t oldSize = doc.numByTopic.size();
285
- doc.numByTopic.conservativeResize(K);
285
+ doc.numByTopic.conservativeResize(K, 1);
286
286
  doc.numByTopic.tail(K - oldSize).setZero();
287
287
  }
288
288
  }, this->docs.size() * i / pool.getNumWorkers(), this->docs.size() * (i + 1) / pool.getNumWorkers()));
@@ -293,7 +293,6 @@ namespace tomoto
293
293
  template<ParallelScheme _ps, typename _ExtraDocData>
294
294
  void mergeState(ThreadPool& pool, _ModelState& globalState, _ModelState& tState, _ModelState* localData, _RandGen*, const _ExtraDocData& edd) const
295
295
  {
296
- std::vector<std::future<void>> res;
297
296
  const size_t V = this->realV;
298
297
  auto K = this->K;
299
298
 
@@ -303,7 +302,7 @@ namespace tomoto
303
302
  globalState.numByTopic.conservativeResize(K);
304
303
  globalState.numByTopic.tail(K - oldSize).setZero();
305
304
  globalState.numTableByTopic.resize(K);
306
- globalState.numByTopicWord.conservativeResize(K, Eigen::NoChange);
305
+ globalState.numByTopicWord.conservativeResize(K, V);
307
306
  globalState.numByTopicWord.block(oldSize, 0, K - oldSize, V).setZero();
308
307
  }
309
308
 
@@ -321,7 +320,7 @@ namespace tomoto
321
320
  if (_tw != TermWeight::one)
322
321
  {
323
322
  globalState.numByTopic = globalState.numByTopic.cwiseMax(0);
324
- globalState.numByTopicWord = globalState.numByTopicWord.cwiseMax(0);
323
+ globalState.numByTopicWord.matrix() = globalState.numByTopicWord.cwiseMax(0);
325
324
  }
326
325
 
327
326
 
@@ -334,15 +333,6 @@ namespace tomoto
334
333
  }
335
334
  }
336
335
  globalState.totalTable = globalState.numTableByTopic.sum();
337
-
338
- for (size_t i = 0; i < pool.getNumWorkers(); ++i)
339
- {
340
- res.emplace_back(pool.enqueue([&, this, i](size_t threadId)
341
- {
342
- localData[i] = globalState;
343
- }));
344
- }
345
- for (auto& r : res) r.get();
346
336
  }
347
337
 
348
338
  /* this LL calculation is based on https://github.com/blei-lab/hdp/blob/master/hdp/state.cpp */
@@ -400,13 +390,14 @@ namespace tomoto
400
390
  {
401
391
  this->globalState.numByTopic = Eigen::Matrix<WeightType, -1, 1>::Zero(K);
402
392
  this->globalState.numTableByTopic = Eigen::Matrix<int32_t, -1, 1>::Zero(K);
403
- this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(K, V);
393
+ //this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(K, V);
394
+ this->globalState.numByTopicWord.init(nullptr, K, V);
404
395
  }
405
396
  }
406
397
 
407
398
  void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
408
399
  {
409
- doc.numByTopic.init(nullptr, this->K);
400
+ doc.numByTopic.init(nullptr, this->K, 1);
410
401
  doc.numTopicByTable.clear();
411
402
  doc.Zs = tvector<Tid>(wordSize);
412
403
  if (_tw != TermWeight::one) doc.wordWeights.resize(wordSize);
@@ -577,7 +568,7 @@ namespace tomoto
577
568
  template<typename _TopicModel>
578
569
  void DocumentHDP<_tw>::update(WeightType * ptr, const _TopicModel & mdl)
579
570
  {
580
- this->numByTopic.init(ptr, mdl.getK());
571
+ this->numByTopic.init(ptr, mdl.getK(), 1);
581
572
  for (size_t i = 0; i < this->Zs.size(); ++i)
582
573
  {
583
574
  if (this->words[i] >= mdl.getV()) continue;
@@ -119,19 +119,26 @@ namespace tomoto
119
119
 
120
120
  DEFINE_SERIALIZER(nodes, levelBlocks);
121
121
 
122
- template<bool _MakeNewPath = true>
122
+ template<bool _makeNewPath = true>
123
123
  void calcNodeLikelihood(Float gamma, size_t levelDepth)
124
124
  {
125
125
  nodeLikelihoods.resize(nodes.size());
126
126
  nodeLikelihoods.array() = -INFINITY;
127
- updateNodeLikelihood<_MakeNewPath>(gamma, levelDepth, &nodes[0]);
127
+ updateNodeLikelihood<_makeNewPath>(gamma, levelDepth, &nodes[0]);
128
+ if (!_makeNewPath)
129
+ {
130
+ for (size_t i = 0; i < levelBlocks.size(); ++i)
131
+ {
132
+ if (levelBlocks[i] < levelDepth - 1) nodeLikelihoods.segment((i + 1) * blockSize, blockSize).array() = -INFINITY;
133
+ }
134
+ }
128
135
  }
129
136
 
130
- template<bool _MakeNewPath = true>
137
+ template<bool _makeNewPath = true>
131
138
  void updateNodeLikelihood(Float gamma, size_t levelDepth, NCRPNode* node, Float weight = 0)
132
139
  {
133
140
  size_t idx = node - nodes.data();
134
- const Float pNewNode = _MakeNewPath ? log(gamma / (node->numCustomers + gamma)) : -INFINITY;
141
+ const Float pNewNode = _makeNewPath ? log(gamma / (node->numCustomers + gamma)) : -INFINITY;
135
142
  nodeLikelihoods[idx] = weight + (((size_t)node->level < levelDepth - 1) ? pNewNode : 0);
136
143
  for(auto * child = node->getChild(); child; child = child->getSibling())
137
144
  {
@@ -187,7 +194,7 @@ namespace tomoto
187
194
  std::vector<std::future<void>> futures;
188
195
  futures.reserve(levelBlocks.size());
189
196
 
190
- auto calc = [this, eta, realV, &doc, &ld](size_t threadId, size_t b)
197
+ auto calc = [&, eta, realV](size_t threadId, size_t b)
191
198
  {
192
199
  Float cnt = 0;
193
200
  Vid prevWord = -1;
@@ -284,7 +291,7 @@ namespace tomoto
284
291
  size_t oldSize = ld.numByTopic.rows();
285
292
  size_t newSize = std::max(nodes.size(), ((oldSize + oldSize / 2 + 7) / 8) * 8);
286
293
  ld.numByTopic.conservativeResize(newSize);
287
- ld.numByTopicWord.conservativeResize(newSize, Eigen::NoChange);
294
+ ld.numByTopicWord.conservativeResize(newSize, ld.numByTopicWord.cols());
288
295
  ld.numByTopic.segment(oldSize, newSize - oldSize).setZero();
289
296
  ld.numByTopicWord.block(oldSize, 0, newSize - oldSize, ld.numByTopicWord.cols()).setZero();
290
297
  }
@@ -317,13 +324,13 @@ namespace tomoto
317
324
  typename _Derived = void,
318
325
  typename _DocType = DocumentHLDA<_tw>,
319
326
  typename _ModelState = ModelStateHLDA<_tw>>
320
- class HLDAModel : public LDAModel<_tw, _RandGen, flags::shared_state, _Interface,
327
+ class HLDAModel : public LDAModel<_tw, _RandGen, flags::partitioned_multisampling, _Interface,
321
328
  typename std::conditional<std::is_same<_Derived, void>::value, HLDAModel<_tw, _RandGen>, _Derived>::type,
322
329
  _DocType, _ModelState>
323
330
  {
324
331
  protected:
325
332
  using DerivedClass = typename std::conditional<std::is_same<_Derived, void>::value, HLDAModel<_tw, _RandGen>, _Derived>::type;
326
- using BaseClass = LDAModel<_tw, _RandGen, flags::shared_state, _Interface, DerivedClass, _DocType, _ModelState>;
333
+ using BaseClass = LDAModel<_tw, _RandGen, flags::partitioned_multisampling, _Interface, DerivedClass, _DocType, _ModelState>;
327
334
  friend BaseClass;
328
335
  friend typename BaseClass::BaseClass;
329
336
  using WeightType = typename BaseClass::WeightType;
@@ -341,11 +348,11 @@ namespace tomoto
341
348
  }
342
349
 
343
350
  // Words of all documents should be sorted by ascending order.
344
- template<bool _MakeNewPath = true>
351
+ template<GlobalSampler _gs>
345
352
  void samplePathes(_DocType& doc, ThreadPool* pool, _ModelState& ld, _RandGen& rgs) const
346
353
  {
347
- if(_MakeNewPath) ld.nt->nodes[doc.path.back()].dropPathOne();
348
- ld.nt->template calcNodeLikelihood<_MakeNewPath>(gamma, this->K);
354
+ if(_gs != GlobalSampler::inference) ld.nt->nodes[doc.path.back()].dropPathOne();
355
+ ld.nt->template calcNodeLikelihood<_gs == GlobalSampler::train>(gamma, this->K);
349
356
 
350
357
  std::vector<Float> newTopicWeights(this->K - 1);
351
358
  std::vector<WeightType> cntByLevel(this->K);
@@ -355,7 +362,7 @@ namespace tomoto
355
362
  if (doc.words[w] >= this->realV) break;
356
363
  addWordToOnlyLocal<-1>(ld, doc, w, doc.words[w], doc.Zs[w]);
357
364
 
358
- if (_MakeNewPath)
365
+ if (_gs == GlobalSampler::train)
359
366
  {
360
367
  if (doc.words[w] != prevWord)
361
368
  {
@@ -371,7 +378,7 @@ namespace tomoto
371
378
  }
372
379
  }
373
380
 
374
- if (_MakeNewPath)
381
+ if (_gs == GlobalSampler::train)
375
382
  {
376
383
  for (size_t l = 1; l < this->K; ++l)
377
384
  {
@@ -386,7 +393,7 @@ namespace tomoto
386
393
  size_t newPath = sample::sampleFromDiscreteAcc(ld.nt->nodeLikelihoods.data(),
387
394
  ld.nt->nodeLikelihoods.data() + ld.nt->nodeLikelihoods.size(), rgs);
388
395
 
389
- if(_MakeNewPath) newPath = ld.nt->template generateLeafNode<_tw>(newPath, this->K, ld);
396
+ if(_gs == GlobalSampler::train) newPath = ld.nt->template generateLeafNode<_tw>(newPath, this->K, ld);
390
397
  doc.path.back() = newPath;
391
398
  for (size_t l = this->K - 2; l > 0; --l)
392
399
  {
@@ -398,7 +405,7 @@ namespace tomoto
398
405
  if (doc.words[w] >= this->realV) break;
399
406
  addWordToOnlyLocal<1>(ld, doc, w, doc.words[w], doc.Zs[w]);
400
407
  }
401
- if (_MakeNewPath) ld.nt->nodes[doc.path.back()].addPathOne();
408
+ if (_gs != GlobalSampler::inference) ld.nt->nodes[doc.path.back()].addPathOne();
402
409
  }
403
410
 
404
411
  template<int _inc>
@@ -426,6 +433,7 @@ namespace tomoto
426
433
  template<bool _asymEta>
427
434
  Float* getZLikelihoods(_ModelState& ld, const _DocType& doc, size_t docId, size_t vid) const
428
435
  {
436
+ if (_asymEta) THROW_ERROR_WITH_INFO(exception::Unimplemented, "Unimplemented features");
429
437
  const size_t V = this->realV;
430
438
  assert(vid < V);
431
439
  auto& zLikelihood = ld.zLikelihood;
@@ -439,50 +447,14 @@ namespace tomoto
439
447
  return &zLikelihood[0];
440
448
  }
441
449
 
442
- void sampleTopics(_DocType& doc, size_t docId, _ModelState& ld, _RandGen& rgs) const
443
- {
444
- for (size_t w = 0; w < doc.words.size(); ++w)
445
- {
446
- if (doc.words[w] >= this->realV) continue;
447
- addWordTo<-1>(ld, doc, w, doc.words[w], doc.Zs[w]);
448
- Float* dist;
449
- if (this->etaByTopicWord.size())
450
- {
451
- THROW_ERROR_WITH_INFO(exception::Unimplemented, "Unimplemented features");
452
- }
453
- else
454
- {
455
- dist = static_cast<const DerivedClass*>(this)->template
456
- getZLikelihoods<false>(ld, doc, docId, doc.words[w]);
457
- }
458
- doc.Zs[w] = sample::sampleFromDiscreteAcc(dist, dist + this->K, rgs);
459
- addWordTo<1>(ld, doc, w, doc.words[w], doc.Zs[w]);
460
- }
461
- }
462
-
463
- template<ParallelScheme _ps, bool _infer, typename _ExtraDocData>
464
- void sampleDocument(_DocType& doc, const _ExtraDocData& edd, size_t docId, _ModelState& ld, _RandGen& rgs, size_t iterationCnt, size_t partitionId = 0) const
465
- {
466
- sampleTopics(doc, docId, ld, rgs);
467
- }
468
-
469
- template<typename _DocIter>
470
- void sampleGlobalLevel(ThreadPool* pool, _ModelState* localData, _RandGen* rgs, _DocIter first, _DocIter last)
450
+ template<GlobalSampler _gs, typename _DocIter>
451
+ void sampleGlobalLevel(ThreadPool* pool, _ModelState* globalData, _RandGen* rgs, _DocIter first, _DocIter last) const
471
452
  {
472
453
  for (auto doc = first; doc != last; ++doc)
473
454
  {
474
- samplePathes<>(*doc, pool, *localData, rgs[0]);
475
- }
476
- localData->nt->markEmptyBlocks();
477
- }
478
-
479
- template<typename _DocIter>
480
- void sampleGlobalLevel(ThreadPool* pool, _ModelState* localData, _RandGen* rgs, _DocIter first, _DocIter last) const
481
- {
482
- for (auto doc = first; doc != last; ++doc)
483
- {
484
- samplePathes<false>(*doc, pool, *localData, rgs[0]);
455
+ samplePathes<_gs>(*doc, pool, *globalData, rgs[0]);
485
456
  }
457
+ if (_gs != GlobalSampler::inference) globalData->nt->markEmptyBlocks();
486
458
  }
487
459
 
488
460
  template<typename _DocIter>
@@ -539,7 +511,8 @@ namespace tomoto
539
511
  if (initDocs)
540
512
  {
541
513
  this->globalState.numByTopic = Eigen::Matrix<WeightType, -1, 1>::Zero(this->K);
542
- this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K, V);
514
+ //this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K, V);
515
+ this->globalState.numByTopicWord.init(nullptr, this->K, V);
543
516
  this->globalState.nt->nodes.resize(detail::NodeTrees::blockSize);
544
517
  }
545
518
  }
@@ -547,7 +520,7 @@ namespace tomoto
547
520
  void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
548
521
  {
549
522
  sortAndWriteOrder(doc.words, doc.wOrder);
550
- doc.numByTopic.init(nullptr, this->K);
523
+ doc.numByTopic.init(nullptr, this->K, 1);
551
524
  doc.Zs = tvector<Tid>(wordSize);
552
525
  doc.path.resize(this->K);
553
526
  for (size_t l = 0; l < this->K; ++l) doc.path[l] = l;
@@ -595,6 +568,31 @@ namespace tomoto
595
568
  return cnt;
596
569
  }
597
570
 
571
+ template<ParallelScheme _ps>
572
+ void distributeMergedState(ThreadPool& pool, _ModelState& globalState, _ModelState* localData) const
573
+ {
574
+ std::vector<std::future<void>> res;
575
+ if (_ps == ParallelScheme::copy_merge)
576
+ {
577
+ for (size_t i = 0; i < pool.getNumWorkers(); ++i)
578
+ {
579
+ res.emplace_back(pool.enqueue([&, i](size_t)
580
+ {
581
+ localData[i] = globalState;
582
+ }));
583
+ }
584
+ }
585
+ else if (_ps == ParallelScheme::partition)
586
+ {
587
+ res = pool.enqueueToAll([&](size_t threadId)
588
+ {
589
+ localData[threadId].numByTopicWord.init((WeightType*)globalState.numByTopicWord.data(), globalState.numByTopicWord.rows(), globalState.numByTopicWord.cols());
590
+ localData[threadId].numByTopic = globalState.numByTopic;
591
+ });
592
+ }
593
+ for (auto& r : res) r.get();
594
+ }
595
+
598
596
  public:
599
597
  DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, gamma);
600
598
  DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, gamma);
@@ -671,7 +669,7 @@ namespace tomoto
671
669
  template<typename _TopicModel>
672
670
  inline void DocumentHLDA<_tw>::update(WeightType * ptr, const _TopicModel & mdl)
673
671
  {
674
- this->numByTopic.init(ptr, mdl.getLevelDepth());
672
+ this->numByTopic.init(ptr, mdl.getLevelDepth(), 1);
675
673
  for (size_t i = 0; i < this->Zs.size(); ++i)
676
674
  {
677
675
  if (this->words[i] >= mdl.getV()) continue;