tomoto 0.1.3 → 0.1.4

Sign up to get free protection for your applications and to get access to all the features.
Files changed (50) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +6 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -0
  5. data/ext/tomoto/ct.cpp +54 -0
  6. data/ext/tomoto/dmr.cpp +62 -0
  7. data/ext/tomoto/dt.cpp +82 -0
  8. data/ext/tomoto/ext.cpp +27 -773
  9. data/ext/tomoto/gdmr.cpp +34 -0
  10. data/ext/tomoto/hdp.cpp +42 -0
  11. data/ext/tomoto/hlda.cpp +66 -0
  12. data/ext/tomoto/hpa.cpp +27 -0
  13. data/ext/tomoto/lda.cpp +250 -0
  14. data/ext/tomoto/llda.cpp +29 -0
  15. data/ext/tomoto/mglda.cpp +71 -0
  16. data/ext/tomoto/pa.cpp +27 -0
  17. data/ext/tomoto/plda.cpp +29 -0
  18. data/ext/tomoto/slda.cpp +40 -0
  19. data/ext/tomoto/utils.h +84 -0
  20. data/lib/tomoto/tomoto.bundle +0 -0
  21. data/lib/tomoto/tomoto.so +0 -0
  22. data/lib/tomoto/version.rb +1 -1
  23. data/vendor/tomotopy/README.kr.rst +12 -3
  24. data/vendor/tomotopy/README.rst +12 -3
  25. data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +47 -2
  26. data/vendor/tomotopy/src/Labeling/FoRelevance.h +21 -151
  27. data/vendor/tomotopy/src/Labeling/Labeler.h +5 -3
  28. data/vendor/tomotopy/src/Labeling/Phraser.hpp +518 -0
  29. data/vendor/tomotopy/src/TopicModel/CTModel.hpp +6 -3
  30. data/vendor/tomotopy/src/TopicModel/DT.h +1 -1
  31. data/vendor/tomotopy/src/TopicModel/DTModel.hpp +8 -23
  32. data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +9 -18
  33. data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +56 -58
  34. data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +4 -14
  35. data/vendor/tomotopy/src/TopicModel/LDA.h +69 -17
  36. data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +1 -1
  37. data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +108 -61
  38. data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +7 -8
  39. data/vendor/tomotopy/src/TopicModel/PAModel.hpp +26 -16
  40. data/vendor/tomotopy/src/TopicModel/PT.h +27 -0
  41. data/vendor/tomotopy/src/TopicModel/PTModel.cpp +10 -0
  42. data/vendor/tomotopy/src/TopicModel/PTModel.hpp +273 -0
  43. data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +16 -11
  44. data/vendor/tomotopy/src/Utils/MultiNormalDistribution.hpp +3 -2
  45. data/vendor/tomotopy/src/Utils/Trie.hpp +39 -8
  46. data/vendor/tomotopy/src/Utils/TruncMultiNormal.hpp +36 -38
  47. data/vendor/tomotopy/src/Utils/Utils.hpp +50 -45
  48. data/vendor/tomotopy/src/Utils/math.h +8 -4
  49. data/vendor/tomotopy/src/Utils/tvector.hpp +4 -0
  50. metadata +24 -60
@@ -65,6 +65,7 @@ namespace tomoto
65
65
  {
66
66
  if (i == 0) pbeta = Eigen::Matrix<Float, -1, 1>::Ones(this->K);
67
67
  else pbeta = doc.beta.col(i % numBetaSample).array().exp();
68
+
68
69
  Float betaESum = pbeta.sum() + 1;
69
70
  pbeta /= betaESum;
70
71
  for (size_t k = 0; k < this->K; ++k)
@@ -78,7 +79,9 @@ namespace tomoto
78
79
 
79
80
  Float c = betaESum * (1 - pbeta[k]);
80
81
  lowerBound[k] = log(c * max_uk / (1 - max_uk));
81
- upperBound[k] = log(c * min_unk / (1 - min_unk));
82
+ lowerBound[k] = std::max(std::min(lowerBound[k], (Float)100), (Float)-100);
83
+ upperBound[k] = log(c * min_unk / (1 - min_unk + epsilon));
84
+ upperBound[k] = std::max(std::min(upperBound[k], (Float)100), (Float)-100);
82
85
  if (lowerBound[k] > upperBound[k])
83
86
  {
84
87
  THROW_ERROR_WITH_INFO(exception::TrainingError,
@@ -120,8 +123,8 @@ namespace tomoto
120
123
  }*/
121
124
  }
122
125
 
123
- template<typename _DocIter>
124
- void sampleGlobalLevel(ThreadPool* pool, _ModelState* localData, _RandGen* rgs, _DocIter first, _DocIter last) const
126
+ template<GlobalSampler _gs, typename _DocIter>
127
+ void sampleGlobalLevel(ThreadPool* pool, _ModelState*, _RandGen* rgs, _DocIter first, _DocIter last) const
125
128
  {
126
129
  if (this->globalStep < this->burnIn || !this->optimInterval || (this->globalStep + 1) % this->optimInterval != 0) return;
127
130
 
@@ -11,7 +11,7 @@ namespace tomoto
11
11
  using DocumentLDA<_tw>::DocumentLDA;
12
12
 
13
13
  uint64_t timepoint = 0;
14
- ShareableVector<Float> eta;
14
+ ShareableMatrix<Float, -1, 1> eta;
15
15
  sample::AliasMethod<> aliasTable;
16
16
 
17
17
  DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 0, timepoint);
@@ -20,6 +20,7 @@ namespace tomoto
20
20
 
21
21
  Eigen::Matrix<WeightType, -1, -1> numByTopic; // Dim: (Topic, Time)
22
22
  Eigen::Matrix<WeightType, -1, -1> numByTopicWord; // Dim: (Topic * Time, Vocabs)
23
+ //ShareableMatrix<WeightType, -1, -1> numByTopicWord; // Dim: (Topic * Time, Vocabs)
23
24
  DEFINE_SERIALIZER(numByTopic, numByTopicWord);
24
25
  };
25
26
 
@@ -139,8 +140,6 @@ namespace tomoto
139
140
  template<ParallelScheme _ps, typename _ExtraDocData>
140
141
  void mergeState(ThreadPool& pool, _ModelState& globalState, _ModelState& tState, _ModelState* localData, _RandGen*, const _ExtraDocData& edd) const
141
142
  {
142
- std::vector<std::future<void>> res;
143
-
144
143
  if (_ps == ParallelScheme::copy_merge)
145
144
  {
146
145
  tState = globalState;
@@ -157,17 +156,10 @@ namespace tomoto
157
156
  }
158
157
  Eigen::Map<Eigen::Matrix<WeightType, -1, 1>>{ globalState.numByTopic.data(), globalState.numByTopic.size() }
159
158
  = globalState.numByTopicWord.rowwise().sum();
160
-
161
- for (size_t i = 0; i < pool.getNumWorkers(); ++i)
162
- {
163
- res.emplace_back(pool.enqueue([&, i](size_t)
164
- {
165
- localData[i] = globalState;
166
- }));
167
- }
168
159
  }
169
160
  else if (_ps == ParallelScheme::partition)
170
161
  {
162
+ std::vector<std::future<void>> res;
171
163
  res = pool.enqueueToAll([&](size_t partitionId)
172
164
  {
173
165
  size_t b = partitionId ? edd.vChunkOffset[partitionId - 1] : 0,
@@ -175,7 +167,6 @@ namespace tomoto
175
167
  globalState.numByTopicWord.block(0, b, globalState.numByTopicWord.rows(), e - b) = localData[partitionId].numByTopicWord;
176
168
  });
177
169
  for (auto& r : res) r.get();
178
- res.clear();
179
170
 
180
171
  // make all count being positive
181
172
  if (_tw != TermWeight::one)
@@ -184,17 +175,11 @@ namespace tomoto
184
175
  }
185
176
  Eigen::Map<Eigen::Matrix<WeightType, -1, 1>>{ globalState.numByTopic.data(), globalState.numByTopic.size() }
186
177
  = globalState.numByTopicWord.rowwise().sum();
187
-
188
- res = pool.enqueueToAll([&](size_t threadId)
189
- {
190
- localData[threadId].numByTopic = globalState.numByTopic;
191
- });
192
178
  }
193
- for (auto& r : res) r.get();
194
179
  }
195
180
 
196
181
  template<typename _DocIter>
197
- void sampleGlobalLevel(ThreadPool* pool, _ModelState* localData, _RandGen* rgs, _DocIter first, _DocIter last)
182
+ void _sampleGlobalLevel(ThreadPool* pool, _ModelState*, _RandGen* rgs, _DocIter first, _DocIter last)
198
183
  {
199
184
  const auto K = this->K;
200
185
  const Float eps = shapeA * (std::pow(shapeB + 1 + this->globalStep, -shapeC));
@@ -313,10 +298,10 @@ namespace tomoto
313
298
  alphas = newAlphas;
314
299
  }
315
300
 
316
- template<typename _DocIter>
301
+ template<GlobalSampler _gs, typename _DocIter>
317
302
  void sampleGlobalLevel(ThreadPool* pool, _ModelState* localData, _RandGen* rgs, _DocIter first, _DocIter last) const
318
303
  {
319
- // do nothing
304
+ if (_gs != GlobalSampler::inference) return const_cast<DerivedClass*>(this)->_sampleGlobalLevel(pool, localData, rgs, first, last);
320
305
  }
321
306
 
322
307
  void optimizeParameters(ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
@@ -343,11 +328,11 @@ namespace tomoto
343
328
  BaseClass::prepareDoc(doc, docId, wordSize);
344
329
  if (docId == (size_t)-1)
345
330
  {
346
- doc.eta.init(nullptr, this->K);
331
+ doc.eta.init(nullptr, this->K, 1);
347
332
  }
348
333
  else
349
334
  {
350
- doc.eta.init((Float*)etaByDoc.col(docId).data(), this->K);
335
+ doc.eta.init((Float*)etaByDoc.col(docId).data(), this->K, 1);
351
336
  }
352
337
  }
353
338
 
@@ -427,7 +412,7 @@ namespace tomoto
427
412
  numDocsByTime[doc.timepoint]++;
428
413
  if (!initDocs)
429
414
  {
430
- doc.eta.init((Float*)etaByDoc.col(docId++).data(), this->K);
415
+ doc.eta.init((Float*)etaByDoc.col(docId++).data(), this->K, 1);
431
416
  }
432
417
  }
433
418
 
@@ -96,7 +96,7 @@ namespace tomoto
96
96
  ld.numTableByTopic.tail(newSize - oldSize).setZero();
97
97
  ld.numByTopic.conservativeResize(newSize);
98
98
  ld.numByTopic.tail(newSize - oldSize).setZero();
99
- ld.numByTopicWord.conservativeResize(newSize, Eigen::NoChange);
99
+ ld.numByTopicWord.conservativeResize(newSize, V);
100
100
  ld.numByTopicWord.block(oldSize, 0, newSize - oldSize, V).setZero();
101
101
  }
102
102
  else
@@ -155,7 +155,7 @@ namespace tomoto
155
155
  if (_inc > 0 && tid >= doc.numByTopic.size())
156
156
  {
157
157
  size_t oldSize = doc.numByTopic.size();
158
- doc.numByTopic.conservativeResize(tid + 1);
158
+ doc.numByTopic.conservativeResize(tid + 1, 1);
159
159
  doc.numByTopic.tail(tid + 1 - oldSize).setZero();
160
160
  }
161
161
  constexpr bool _dec = _inc < 0 && _tw != TermWeight::one;
@@ -282,7 +282,7 @@ namespace tomoto
282
282
  auto& doc = this->docs[j];
283
283
  if (doc.numByTopic.size() >= K) continue;
284
284
  size_t oldSize = doc.numByTopic.size();
285
- doc.numByTopic.conservativeResize(K);
285
+ doc.numByTopic.conservativeResize(K, 1);
286
286
  doc.numByTopic.tail(K - oldSize).setZero();
287
287
  }
288
288
  }, this->docs.size() * i / pool.getNumWorkers(), this->docs.size() * (i + 1) / pool.getNumWorkers()));
@@ -293,7 +293,6 @@ namespace tomoto
293
293
  template<ParallelScheme _ps, typename _ExtraDocData>
294
294
  void mergeState(ThreadPool& pool, _ModelState& globalState, _ModelState& tState, _ModelState* localData, _RandGen*, const _ExtraDocData& edd) const
295
295
  {
296
- std::vector<std::future<void>> res;
297
296
  const size_t V = this->realV;
298
297
  auto K = this->K;
299
298
 
@@ -303,7 +302,7 @@ namespace tomoto
303
302
  globalState.numByTopic.conservativeResize(K);
304
303
  globalState.numByTopic.tail(K - oldSize).setZero();
305
304
  globalState.numTableByTopic.resize(K);
306
- globalState.numByTopicWord.conservativeResize(K, Eigen::NoChange);
305
+ globalState.numByTopicWord.conservativeResize(K, V);
307
306
  globalState.numByTopicWord.block(oldSize, 0, K - oldSize, V).setZero();
308
307
  }
309
308
 
@@ -321,7 +320,7 @@ namespace tomoto
321
320
  if (_tw != TermWeight::one)
322
321
  {
323
322
  globalState.numByTopic = globalState.numByTopic.cwiseMax(0);
324
- globalState.numByTopicWord = globalState.numByTopicWord.cwiseMax(0);
323
+ globalState.numByTopicWord.matrix() = globalState.numByTopicWord.cwiseMax(0);
325
324
  }
326
325
 
327
326
 
@@ -334,15 +333,6 @@ namespace tomoto
334
333
  }
335
334
  }
336
335
  globalState.totalTable = globalState.numTableByTopic.sum();
337
-
338
- for (size_t i = 0; i < pool.getNumWorkers(); ++i)
339
- {
340
- res.emplace_back(pool.enqueue([&, this, i](size_t threadId)
341
- {
342
- localData[i] = globalState;
343
- }));
344
- }
345
- for (auto& r : res) r.get();
346
336
  }
347
337
 
348
338
  /* this LL calculation is based on https://github.com/blei-lab/hdp/blob/master/hdp/state.cpp */
@@ -400,13 +390,14 @@ namespace tomoto
400
390
  {
401
391
  this->globalState.numByTopic = Eigen::Matrix<WeightType, -1, 1>::Zero(K);
402
392
  this->globalState.numTableByTopic = Eigen::Matrix<int32_t, -1, 1>::Zero(K);
403
- this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(K, V);
393
+ //this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(K, V);
394
+ this->globalState.numByTopicWord.init(nullptr, K, V);
404
395
  }
405
396
  }
406
397
 
407
398
  void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
408
399
  {
409
- doc.numByTopic.init(nullptr, this->K);
400
+ doc.numByTopic.init(nullptr, this->K, 1);
410
401
  doc.numTopicByTable.clear();
411
402
  doc.Zs = tvector<Tid>(wordSize);
412
403
  if (_tw != TermWeight::one) doc.wordWeights.resize(wordSize);
@@ -577,7 +568,7 @@ namespace tomoto
577
568
  template<typename _TopicModel>
578
569
  void DocumentHDP<_tw>::update(WeightType * ptr, const _TopicModel & mdl)
579
570
  {
580
- this->numByTopic.init(ptr, mdl.getK());
571
+ this->numByTopic.init(ptr, mdl.getK(), 1);
581
572
  for (size_t i = 0; i < this->Zs.size(); ++i)
582
573
  {
583
574
  if (this->words[i] >= mdl.getV()) continue;
@@ -119,19 +119,26 @@ namespace tomoto
119
119
 
120
120
  DEFINE_SERIALIZER(nodes, levelBlocks);
121
121
 
122
- template<bool _MakeNewPath = true>
122
+ template<bool _makeNewPath = true>
123
123
  void calcNodeLikelihood(Float gamma, size_t levelDepth)
124
124
  {
125
125
  nodeLikelihoods.resize(nodes.size());
126
126
  nodeLikelihoods.array() = -INFINITY;
127
- updateNodeLikelihood<_MakeNewPath>(gamma, levelDepth, &nodes[0]);
127
+ updateNodeLikelihood<_makeNewPath>(gamma, levelDepth, &nodes[0]);
128
+ if (!_makeNewPath)
129
+ {
130
+ for (size_t i = 0; i < levelBlocks.size(); ++i)
131
+ {
132
+ if (levelBlocks[i] < levelDepth - 1) nodeLikelihoods.segment((i + 1) * blockSize, blockSize).array() = -INFINITY;
133
+ }
134
+ }
128
135
  }
129
136
 
130
- template<bool _MakeNewPath = true>
137
+ template<bool _makeNewPath = true>
131
138
  void updateNodeLikelihood(Float gamma, size_t levelDepth, NCRPNode* node, Float weight = 0)
132
139
  {
133
140
  size_t idx = node - nodes.data();
134
- const Float pNewNode = _MakeNewPath ? log(gamma / (node->numCustomers + gamma)) : -INFINITY;
141
+ const Float pNewNode = _makeNewPath ? log(gamma / (node->numCustomers + gamma)) : -INFINITY;
135
142
  nodeLikelihoods[idx] = weight + (((size_t)node->level < levelDepth - 1) ? pNewNode : 0);
136
143
  for(auto * child = node->getChild(); child; child = child->getSibling())
137
144
  {
@@ -187,7 +194,7 @@ namespace tomoto
187
194
  std::vector<std::future<void>> futures;
188
195
  futures.reserve(levelBlocks.size());
189
196
 
190
- auto calc = [this, eta, realV, &doc, &ld](size_t threadId, size_t b)
197
+ auto calc = [&, eta, realV](size_t threadId, size_t b)
191
198
  {
192
199
  Float cnt = 0;
193
200
  Vid prevWord = -1;
@@ -284,7 +291,7 @@ namespace tomoto
284
291
  size_t oldSize = ld.numByTopic.rows();
285
292
  size_t newSize = std::max(nodes.size(), ((oldSize + oldSize / 2 + 7) / 8) * 8);
286
293
  ld.numByTopic.conservativeResize(newSize);
287
- ld.numByTopicWord.conservativeResize(newSize, Eigen::NoChange);
294
+ ld.numByTopicWord.conservativeResize(newSize, ld.numByTopicWord.cols());
288
295
  ld.numByTopic.segment(oldSize, newSize - oldSize).setZero();
289
296
  ld.numByTopicWord.block(oldSize, 0, newSize - oldSize, ld.numByTopicWord.cols()).setZero();
290
297
  }
@@ -317,13 +324,13 @@ namespace tomoto
317
324
  typename _Derived = void,
318
325
  typename _DocType = DocumentHLDA<_tw>,
319
326
  typename _ModelState = ModelStateHLDA<_tw>>
320
- class HLDAModel : public LDAModel<_tw, _RandGen, flags::shared_state, _Interface,
327
+ class HLDAModel : public LDAModel<_tw, _RandGen, flags::partitioned_multisampling, _Interface,
321
328
  typename std::conditional<std::is_same<_Derived, void>::value, HLDAModel<_tw, _RandGen>, _Derived>::type,
322
329
  _DocType, _ModelState>
323
330
  {
324
331
  protected:
325
332
  using DerivedClass = typename std::conditional<std::is_same<_Derived, void>::value, HLDAModel<_tw, _RandGen>, _Derived>::type;
326
- using BaseClass = LDAModel<_tw, _RandGen, flags::shared_state, _Interface, DerivedClass, _DocType, _ModelState>;
333
+ using BaseClass = LDAModel<_tw, _RandGen, flags::partitioned_multisampling, _Interface, DerivedClass, _DocType, _ModelState>;
327
334
  friend BaseClass;
328
335
  friend typename BaseClass::BaseClass;
329
336
  using WeightType = typename BaseClass::WeightType;
@@ -341,11 +348,11 @@ namespace tomoto
341
348
  }
342
349
 
343
350
  // Words of all documents should be sorted by ascending order.
344
- template<bool _MakeNewPath = true>
351
+ template<GlobalSampler _gs>
345
352
  void samplePathes(_DocType& doc, ThreadPool* pool, _ModelState& ld, _RandGen& rgs) const
346
353
  {
347
- if(_MakeNewPath) ld.nt->nodes[doc.path.back()].dropPathOne();
348
- ld.nt->template calcNodeLikelihood<_MakeNewPath>(gamma, this->K);
354
+ if(_gs != GlobalSampler::inference) ld.nt->nodes[doc.path.back()].dropPathOne();
355
+ ld.nt->template calcNodeLikelihood<_gs == GlobalSampler::train>(gamma, this->K);
349
356
 
350
357
  std::vector<Float> newTopicWeights(this->K - 1);
351
358
  std::vector<WeightType> cntByLevel(this->K);
@@ -355,7 +362,7 @@ namespace tomoto
355
362
  if (doc.words[w] >= this->realV) break;
356
363
  addWordToOnlyLocal<-1>(ld, doc, w, doc.words[w], doc.Zs[w]);
357
364
 
358
- if (_MakeNewPath)
365
+ if (_gs == GlobalSampler::train)
359
366
  {
360
367
  if (doc.words[w] != prevWord)
361
368
  {
@@ -371,7 +378,7 @@ namespace tomoto
371
378
  }
372
379
  }
373
380
 
374
- if (_MakeNewPath)
381
+ if (_gs == GlobalSampler::train)
375
382
  {
376
383
  for (size_t l = 1; l < this->K; ++l)
377
384
  {
@@ -386,7 +393,7 @@ namespace tomoto
386
393
  size_t newPath = sample::sampleFromDiscreteAcc(ld.nt->nodeLikelihoods.data(),
387
394
  ld.nt->nodeLikelihoods.data() + ld.nt->nodeLikelihoods.size(), rgs);
388
395
 
389
- if(_MakeNewPath) newPath = ld.nt->template generateLeafNode<_tw>(newPath, this->K, ld);
396
+ if(_gs == GlobalSampler::train) newPath = ld.nt->template generateLeafNode<_tw>(newPath, this->K, ld);
390
397
  doc.path.back() = newPath;
391
398
  for (size_t l = this->K - 2; l > 0; --l)
392
399
  {
@@ -398,7 +405,7 @@ namespace tomoto
398
405
  if (doc.words[w] >= this->realV) break;
399
406
  addWordToOnlyLocal<1>(ld, doc, w, doc.words[w], doc.Zs[w]);
400
407
  }
401
- if (_MakeNewPath) ld.nt->nodes[doc.path.back()].addPathOne();
408
+ if (_gs != GlobalSampler::inference) ld.nt->nodes[doc.path.back()].addPathOne();
402
409
  }
403
410
 
404
411
  template<int _inc>
@@ -426,6 +433,7 @@ namespace tomoto
426
433
  template<bool _asymEta>
427
434
  Float* getZLikelihoods(_ModelState& ld, const _DocType& doc, size_t docId, size_t vid) const
428
435
  {
436
+ if (_asymEta) THROW_ERROR_WITH_INFO(exception::Unimplemented, "Unimplemented features");
429
437
  const size_t V = this->realV;
430
438
  assert(vid < V);
431
439
  auto& zLikelihood = ld.zLikelihood;
@@ -439,50 +447,14 @@ namespace tomoto
439
447
  return &zLikelihood[0];
440
448
  }
441
449
 
442
- void sampleTopics(_DocType& doc, size_t docId, _ModelState& ld, _RandGen& rgs) const
443
- {
444
- for (size_t w = 0; w < doc.words.size(); ++w)
445
- {
446
- if (doc.words[w] >= this->realV) continue;
447
- addWordTo<-1>(ld, doc, w, doc.words[w], doc.Zs[w]);
448
- Float* dist;
449
- if (this->etaByTopicWord.size())
450
- {
451
- THROW_ERROR_WITH_INFO(exception::Unimplemented, "Unimplemented features");
452
- }
453
- else
454
- {
455
- dist = static_cast<const DerivedClass*>(this)->template
456
- getZLikelihoods<false>(ld, doc, docId, doc.words[w]);
457
- }
458
- doc.Zs[w] = sample::sampleFromDiscreteAcc(dist, dist + this->K, rgs);
459
- addWordTo<1>(ld, doc, w, doc.words[w], doc.Zs[w]);
460
- }
461
- }
462
-
463
- template<ParallelScheme _ps, bool _infer, typename _ExtraDocData>
464
- void sampleDocument(_DocType& doc, const _ExtraDocData& edd, size_t docId, _ModelState& ld, _RandGen& rgs, size_t iterationCnt, size_t partitionId = 0) const
465
- {
466
- sampleTopics(doc, docId, ld, rgs);
467
- }
468
-
469
- template<typename _DocIter>
470
- void sampleGlobalLevel(ThreadPool* pool, _ModelState* localData, _RandGen* rgs, _DocIter first, _DocIter last)
450
+ template<GlobalSampler _gs, typename _DocIter>
451
+ void sampleGlobalLevel(ThreadPool* pool, _ModelState* globalData, _RandGen* rgs, _DocIter first, _DocIter last) const
471
452
  {
472
453
  for (auto doc = first; doc != last; ++doc)
473
454
  {
474
- samplePathes<>(*doc, pool, *localData, rgs[0]);
475
- }
476
- localData->nt->markEmptyBlocks();
477
- }
478
-
479
- template<typename _DocIter>
480
- void sampleGlobalLevel(ThreadPool* pool, _ModelState* localData, _RandGen* rgs, _DocIter first, _DocIter last) const
481
- {
482
- for (auto doc = first; doc != last; ++doc)
483
- {
484
- samplePathes<false>(*doc, pool, *localData, rgs[0]);
455
+ samplePathes<_gs>(*doc, pool, *globalData, rgs[0]);
485
456
  }
457
+ if (_gs != GlobalSampler::inference) globalData->nt->markEmptyBlocks();
486
458
  }
487
459
 
488
460
  template<typename _DocIter>
@@ -539,7 +511,8 @@ namespace tomoto
539
511
  if (initDocs)
540
512
  {
541
513
  this->globalState.numByTopic = Eigen::Matrix<WeightType, -1, 1>::Zero(this->K);
542
- this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K, V);
514
+ //this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K, V);
515
+ this->globalState.numByTopicWord.init(nullptr, this->K, V);
543
516
  this->globalState.nt->nodes.resize(detail::NodeTrees::blockSize);
544
517
  }
545
518
  }
@@ -547,7 +520,7 @@ namespace tomoto
547
520
  void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
548
521
  {
549
522
  sortAndWriteOrder(doc.words, doc.wOrder);
550
- doc.numByTopic.init(nullptr, this->K);
523
+ doc.numByTopic.init(nullptr, this->K, 1);
551
524
  doc.Zs = tvector<Tid>(wordSize);
552
525
  doc.path.resize(this->K);
553
526
  for (size_t l = 0; l < this->K; ++l) doc.path[l] = l;
@@ -595,6 +568,31 @@ namespace tomoto
595
568
  return cnt;
596
569
  }
597
570
 
571
+ template<ParallelScheme _ps>
572
+ void distributeMergedState(ThreadPool& pool, _ModelState& globalState, _ModelState* localData) const
573
+ {
574
+ std::vector<std::future<void>> res;
575
+ if (_ps == ParallelScheme::copy_merge)
576
+ {
577
+ for (size_t i = 0; i < pool.getNumWorkers(); ++i)
578
+ {
579
+ res.emplace_back(pool.enqueue([&, i](size_t)
580
+ {
581
+ localData[i] = globalState;
582
+ }));
583
+ }
584
+ }
585
+ else if (_ps == ParallelScheme::partition)
586
+ {
587
+ res = pool.enqueueToAll([&](size_t threadId)
588
+ {
589
+ localData[threadId].numByTopicWord.init((WeightType*)globalState.numByTopicWord.data(), globalState.numByTopicWord.rows(), globalState.numByTopicWord.cols());
590
+ localData[threadId].numByTopic = globalState.numByTopic;
591
+ });
592
+ }
593
+ for (auto& r : res) r.get();
594
+ }
595
+
598
596
  public:
599
597
  DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, gamma);
600
598
  DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, gamma);
@@ -671,7 +669,7 @@ namespace tomoto
671
669
  template<typename _TopicModel>
672
670
  inline void DocumentHLDA<_tw>::update(WeightType * ptr, const _TopicModel & mdl)
673
671
  {
674
- this->numByTopic.init(ptr, mdl.getLevelDepth());
672
+ this->numByTopic.init(ptr, mdl.getLevelDepth(), 1);
675
673
  for (size_t i = 0; i < this->Zs.size(); ++i)
676
674
  {
677
675
  if (this->words[i] >= mdl.getV()) continue;