tomoto 0.1.3 → 0.1.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +6 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -0
  5. data/ext/tomoto/ct.cpp +54 -0
  6. data/ext/tomoto/dmr.cpp +62 -0
  7. data/ext/tomoto/dt.cpp +82 -0
  8. data/ext/tomoto/ext.cpp +27 -773
  9. data/ext/tomoto/gdmr.cpp +34 -0
  10. data/ext/tomoto/hdp.cpp +42 -0
  11. data/ext/tomoto/hlda.cpp +66 -0
  12. data/ext/tomoto/hpa.cpp +27 -0
  13. data/ext/tomoto/lda.cpp +250 -0
  14. data/ext/tomoto/llda.cpp +29 -0
  15. data/ext/tomoto/mglda.cpp +71 -0
  16. data/ext/tomoto/pa.cpp +27 -0
  17. data/ext/tomoto/plda.cpp +29 -0
  18. data/ext/tomoto/slda.cpp +40 -0
  19. data/ext/tomoto/utils.h +84 -0
  20. data/lib/tomoto/tomoto.bundle +0 -0
  21. data/lib/tomoto/tomoto.so +0 -0
  22. data/lib/tomoto/version.rb +1 -1
  23. data/vendor/tomotopy/README.kr.rst +12 -3
  24. data/vendor/tomotopy/README.rst +12 -3
  25. data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +47 -2
  26. data/vendor/tomotopy/src/Labeling/FoRelevance.h +21 -151
  27. data/vendor/tomotopy/src/Labeling/Labeler.h +5 -3
  28. data/vendor/tomotopy/src/Labeling/Phraser.hpp +518 -0
  29. data/vendor/tomotopy/src/TopicModel/CTModel.hpp +6 -3
  30. data/vendor/tomotopy/src/TopicModel/DT.h +1 -1
  31. data/vendor/tomotopy/src/TopicModel/DTModel.hpp +8 -23
  32. data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +9 -18
  33. data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +56 -58
  34. data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +4 -14
  35. data/vendor/tomotopy/src/TopicModel/LDA.h +69 -17
  36. data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +1 -1
  37. data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +108 -61
  38. data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +7 -8
  39. data/vendor/tomotopy/src/TopicModel/PAModel.hpp +26 -16
  40. data/vendor/tomotopy/src/TopicModel/PT.h +27 -0
  41. data/vendor/tomotopy/src/TopicModel/PTModel.cpp +10 -0
  42. data/vendor/tomotopy/src/TopicModel/PTModel.hpp +273 -0
  43. data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +16 -11
  44. data/vendor/tomotopy/src/Utils/MultiNormalDistribution.hpp +3 -2
  45. data/vendor/tomotopy/src/Utils/Trie.hpp +39 -8
  46. data/vendor/tomotopy/src/Utils/TruncMultiNormal.hpp +36 -38
  47. data/vendor/tomotopy/src/Utils/Utils.hpp +50 -45
  48. data/vendor/tomotopy/src/Utils/math.h +8 -4
  49. data/vendor/tomotopy/src/Utils/tvector.hpp +4 -0
  50. metadata +24 -60
@@ -28,7 +28,8 @@ namespace tomoto
28
28
  typename _Interface = IHPAModel,
29
29
  typename _Derived = void,
30
30
  typename _DocType = DocumentHPA<_tw>,
31
- typename _ModelState = ModelStateHPA<_tw>>
31
+ typename _ModelState = ModelStateHPA<_tw>
32
+ >
32
33
  class HPAModel : public LDAModel<_tw, _RandGen, 0, _Interface,
33
34
  typename std::conditional<std::is_same<_Derived, void>::value, HPAModel<_tw, _RandGen, _Exclusive>, _Derived>::type,
34
35
  _DocType, _ModelState>
@@ -250,8 +251,6 @@ namespace tomoto
250
251
  template<ParallelScheme _ps, typename _ExtraDocData>
251
252
  void mergeState(ThreadPool& pool, _ModelState& globalState, _ModelState& tState, _ModelState* localData, _RandGen*, const _ExtraDocData& edd) const
252
253
  {
253
- std::vector<std::future<void>> res;
254
-
255
254
  tState = globalState;
256
255
  globalState = localData[0];
257
256
  for (size_t i = 1; i < pool.getNumWorkers(); ++i)
@@ -276,15 +275,6 @@ namespace tomoto
276
275
  globalState.numByTopicWord[1] = globalState.numByTopicWord[1].cwiseMax(0);
277
276
  globalState.numByTopicWord[2] = globalState.numByTopicWord[2].cwiseMax(0);
278
277
  }
279
-
280
- for (size_t i = 0; i < pool.getNumWorkers(); ++i)
281
- {
282
- res.emplace_back(pool.enqueue([&, this, i](size_t threadId)
283
- {
284
- localData[i] = globalState;
285
- }));
286
- }
287
- for (auto& r : res) r.get();
288
278
  }
289
279
 
290
280
  std::vector<uint64_t> _getTopicsCount() const
@@ -379,7 +369,7 @@ namespace tomoto
379
369
 
380
370
  void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
381
371
  {
382
- doc.numByTopic.init(nullptr, this->K + 1);
372
+ doc.numByTopic.init(nullptr, this->K + 1, 1);
383
373
  doc.numByTopic1_2 = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K, K2 + 1);
384
374
  doc.Zs = tvector<Tid>(wordSize);
385
375
  doc.Z2s = tvector<Tid>(wordSize);
@@ -575,7 +565,7 @@ namespace tomoto
575
565
  template<typename _TopicModel>
576
566
  void DocumentHPA<_tw>::update(WeightType * ptr, const _TopicModel & mdl)
577
567
  {
578
- this->numByTopic.init(ptr, mdl.getK() + 1);
568
+ this->numByTopic.init(ptr, mdl.getK() + 1, 1);
579
569
  this->numByTopic1_2 = Eigen::Matrix<WeightType, -1, -1>::Zero(mdl.getK(), mdl.getK2() + 1);
580
570
  for (size_t i = 0; i < this->Zs.size(); ++i)
581
571
  {
@@ -5,32 +5,67 @@ namespace tomoto
5
5
  {
6
6
  enum class TermWeight { one, idf, pmi, size };
7
7
 
8
- template<typename _Scalar>
9
- struct ShareableVector : Eigen::Map<Eigen::Matrix<_Scalar, -1, 1>>
8
+ template<typename _Scalar, Eigen::Index _rows, Eigen::Index _cols>
9
+ struct ShareableMatrix : Eigen::Map<Eigen::Matrix<_Scalar, _rows, _cols>>
10
10
  {
11
- Eigen::Matrix<_Scalar, -1, 1> ownData;
12
- ShareableVector(_Scalar* ptr = nullptr, Eigen::Index len = 0)
13
- : Eigen::Map<Eigen::Matrix<_Scalar, -1, 1>>(nullptr, 0)
11
+ using BaseType = Eigen::Map<Eigen::Matrix<_Scalar, _rows, _cols>>;
12
+ Eigen::Matrix<_Scalar, _rows, _cols> ownData;
13
+
14
+ ShareableMatrix(_Scalar* ptr = nullptr, Eigen::Index rows = 0, Eigen::Index cols = 0)
15
+ : BaseType(nullptr, _rows != -1 ? _rows : 0, _cols != -1 ? _cols : 0)
14
16
  {
15
- init(ptr, len);
17
+ init(ptr, rows, cols);
16
18
  }
17
19
 
18
- void init(_Scalar* ptr, Eigen::Index len)
20
+ ShareableMatrix(const ShareableMatrix& o)
21
+ : BaseType(nullptr, _rows != -1 ? _rows : 0, _cols != -1 ? _cols : 0), ownData{ o.ownData }
19
22
  {
20
- if (!ptr && len)
23
+ if (o.ownData.data())
24
+ {
25
+ new (this) BaseType(ownData.data(), ownData.rows(), ownData.cols());
26
+ }
27
+ else
21
28
  {
22
- ownData = Eigen::Matrix<_Scalar, -1, 1>::Zero(len);
29
+ new (this) BaseType((_Scalar*)o.data(), o.rows(), o.cols());
30
+ }
31
+ }
32
+
33
+ ShareableMatrix(ShareableMatrix&& o) = default;
34
+
35
+ ShareableMatrix& operator=(const ShareableMatrix& o)
36
+ {
37
+ if (o.ownData.data())
38
+ {
39
+ ownData = o.ownData;
40
+ new (this) BaseType(ownData.data(), ownData.rows(), ownData.cols());
41
+ }
42
+ else
43
+ {
44
+ new (this) BaseType((_Scalar*)o.data(), o.rows(), o.cols());
45
+ }
46
+ return *this;
47
+ }
48
+
49
+ ShareableMatrix& operator=(ShareableMatrix&& o) = default;
50
+
51
+ void init(_Scalar* ptr, Eigen::Index rows, Eigen::Index cols)
52
+ {
53
+ if (!ptr && rows && cols)
54
+ {
55
+ ownData = Eigen::Matrix<_Scalar, _rows, _cols>::Zero(_rows != -1 ? _rows : rows, _cols != -1 ? _cols : cols);
23
56
  ptr = ownData.data();
24
57
  }
25
- // is this the best way??
26
- this->m_data = ptr;
27
- ((Eigen::internal::variable_if_dynamic<Eigen::Index, -1>*)&this->m_rows)->setValue(len);
58
+ else
59
+ {
60
+ ownData = Eigen::Matrix<_Scalar, _rows, _cols>{};
61
+ }
62
+ new (this) BaseType(ptr, _rows != -1 ? _rows : rows, _cols != -1 ? _cols : cols);
28
63
  }
29
64
 
30
- void conservativeResize(size_t newSize)
65
+ void conservativeResize(size_t newRows, size_t newCols)
31
66
  {
32
- ownData.conservativeResize(newSize);
33
- init(ownData.data(), ownData.size());
67
+ ownData.conservativeResize(_rows != -1 ? _rows : newRows, _cols != -1 ? _cols : newCols);
68
+ new (this) BaseType(ownData.data(), ownData.rows(), ownData.cols());
34
69
  }
35
70
 
36
71
  void becomeOwner()
@@ -38,9 +73,26 @@ namespace tomoto
38
73
  if (ownData.data() != this->m_data)
39
74
  {
40
75
  ownData = *this;
41
- init(ownData.data(), ownData.size());
76
+ new (this) BaseType(ownData.data(), ownData.rows(), ownData.cols());
42
77
  }
43
78
  }
79
+
80
+ void serializerRead(std::istream& istr)
81
+ {
82
+ uint32_t rows = serializer::readFromStream<uint32_t>(istr);
83
+ uint32_t cols = serializer::readFromStream<uint32_t>(istr);
84
+ init(nullptr, rows, cols);
85
+ if (!istr.read((char*)this->data(), sizeof(_Scalar) * this->size()))
86
+ throw std::ios_base::failure(std::string("reading type '") + typeid(_Scalar).name() + std::string("' is failed"));
87
+ }
88
+
89
+ void serializerWrite(std::ostream& ostr) const
90
+ {
91
+ serializer::writeToStream<uint32_t>(ostr, (uint32_t)this->rows());
92
+ serializer::writeToStream<uint32_t>(ostr, (uint32_t)this->cols());
93
+ if (!ostr.write((const char*)this->data(), sizeof(_Scalar) * this->size()))
94
+ throw std::ios_base::failure(std::string("writing type '") + typeid(_Scalar).name() + std::string("' is failed"));
95
+ }
44
96
  };
45
97
 
46
98
  template<typename _Base, TermWeight _tw>
@@ -85,7 +137,7 @@ namespace tomoto
85
137
 
86
138
  tvector<Tid> Zs;
87
139
  tvector<Float> wordWeights;
88
- ShareableVector<WeightType> numByTopic;
140
+ ShareableMatrix<WeightType, -1, 1> numByTopic;
89
141
 
90
142
  DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(DocumentBase, 0, Zs, wordWeights);
91
143
  DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(DocumentBase, 1, 0x00010001, Zs, wordWeights);
@@ -163,7 +163,7 @@ namespace tomoto
163
163
  {
164
164
  res.emplace_back(pool.enqueue([&, this, ch, chStride](size_t threadId)
165
165
  {
166
- forRandom((this->docs.size() - 1 - ch) / chStride + 1, rgs[threadId](), [&, this](size_t id)
166
+ forShuffled((this->docs.size() - 1 - ch) / chStride + 1, rgs[threadId](), [&, this](size_t id)
167
167
  {
168
168
  static_cast<DerivedClass*>(this)->template sampleDocument<ParallelScheme::copy_merge>(
169
169
  this->docs[id * chStride + ch], 0, id * chStride + ch,
@@ -58,7 +58,8 @@ namespace tomoto
58
58
 
59
59
  Eigen::Matrix<Float, -1, 1> zLikelihood;
60
60
  Eigen::Matrix<WeightType, -1, 1> numByTopic; // Dim: (Topic, 1)
61
- Eigen::Matrix<WeightType, -1, -1> numByTopicWord; // Dim: (Topic, Vocabs)
61
+ //Eigen::Matrix<WeightType, -1, -1> numByTopicWord; // Dim: (Topic, Vocabs)
62
+ ShareableMatrix<WeightType, -1, -1> numByTopicWord; // Dim: (Topic, Vocabs)
62
63
  DEFINE_SERIALIZER(numByTopic, numByTopicWord);
63
64
  };
64
65
 
@@ -137,7 +138,8 @@ namespace tomoto
137
138
  typename _Interface,
138
139
  typename _Derived,
139
140
  typename _DocType,
140
- typename _ModelState>
141
+ typename _ModelState
142
+ >
141
143
  class HDPModel;
142
144
 
143
145
  template<TermWeight _tw, typename _RandGen,
@@ -145,7 +147,8 @@ namespace tomoto
145
147
  typename _Interface = ILDAModel,
146
148
  typename _Derived = void,
147
149
  typename _DocType = DocumentLDA<_tw>,
148
- typename _ModelState = ModelStateLDA<_tw>>
150
+ typename _ModelState = ModelStateLDA<_tw>
151
+ >
149
152
  class LDAModel : public TopicModel<_RandGen, _Flags, _Interface,
150
153
  typename std::conditional<std::is_same<_Derived, void>::value, LDAModel<_tw, _RandGen, _Flags>, _Derived>::type,
151
154
  _DocType, _ModelState>,
@@ -306,25 +309,23 @@ namespace tomoto
306
309
  e = edd.chunkOffsetByDoc(partitionId + 1, docId);
307
310
  }
308
311
 
309
- size_t vOffset = (_ps == ParallelScheme::partition && partitionId) ? edd.vChunkOffset[partitionId - 1] : 0;
310
-
311
312
  for (size_t w = b; w < e; ++w)
312
313
  {
313
314
  if (doc.words[w] >= this->realV) continue;
314
- addWordTo<-1>(ld, doc, w, doc.words[w] - vOffset, doc.Zs[w]);
315
+ static_cast<const DerivedClass*>(this)->template addWordTo<-1>(ld, doc, w, doc.words[w], doc.Zs[w]);
315
316
  Float* dist;
316
317
  if (etaByTopicWord.size())
317
318
  {
318
319
  dist = static_cast<const DerivedClass*>(this)->template
319
- getZLikelihoods<true>(ld, doc, docId, doc.words[w] - vOffset);
320
+ getZLikelihoods<true>(ld, doc, docId, doc.words[w]);
320
321
  }
321
322
  else
322
323
  {
323
324
  dist = static_cast<const DerivedClass*>(this)->template
324
- getZLikelihoods<false>(ld, doc, docId, doc.words[w] - vOffset);
325
+ getZLikelihoods<false>(ld, doc, docId, doc.words[w]);
325
326
  }
326
327
  doc.Zs[w] = sample::sampleFromDiscreteAcc(dist, dist + K, rgs);
327
- addWordTo<1>(ld, doc, w, doc.words[w] - vOffset, doc.Zs[w]);
328
+ static_cast<const DerivedClass*>(this)->template addWordTo<1>(ld, doc, w, doc.words[w], doc.Zs[w]);
328
329
  }
329
330
  }
330
331
 
@@ -335,7 +336,7 @@ namespace tomoto
335
336
  // single-threaded sampling
336
337
  if (_ps == ParallelScheme::none)
337
338
  {
338
- forRandom((size_t)std::distance(docFirst, docLast), rgs[0](), [&](size_t id)
339
+ forShuffled((size_t)std::distance(docFirst, docLast), rgs[0](), [&](size_t id)
339
340
  {
340
341
  static_cast<const DerivedClass*>(this)->presampleDocument(docFirst[id], id, *localData, *rgs, this->globalStep);
341
342
  static_cast<const DerivedClass*>(this)->template sampleDocument<_ps, _infer>(
@@ -344,7 +345,7 @@ namespace tomoto
344
345
 
345
346
  });
346
347
  }
347
- // multi-threaded sampling on partition ad update into global
348
+ // multi-threaded sampling on partition and update into global
348
349
  else if (_ps == ParallelScheme::partition)
349
350
  {
350
351
  const size_t chStride = pool.getNumWorkers();
@@ -353,7 +354,7 @@ namespace tomoto
353
354
  res = pool.enqueueToAll([&, i, chStride](size_t partitionId)
354
355
  {
355
356
  size_t didx = (i + partitionId) % chStride;
356
- forRandom(((size_t)std::distance(docFirst, docLast) + (chStride - 1) - didx) / chStride, rgs[partitionId](), [&](size_t id)
357
+ forShuffled(((size_t)std::distance(docFirst, docLast) + (chStride - 1) - didx) / chStride, rgs[partitionId](), [&](size_t id)
357
358
  {
358
359
  if (i == 0)
359
360
  {
@@ -380,7 +381,7 @@ namespace tomoto
380
381
  {
381
382
  res.emplace_back(pool.enqueue([&, ch, chStride](size_t threadId)
382
383
  {
383
- forRandom(((size_t)std::distance(docFirst, docLast) + (chStride - 1) - ch) / chStride, rgs[threadId](), [&](size_t id)
384
+ forShuffled(((size_t)std::distance(docFirst, docLast) + (chStride - 1) - ch) / chStride, rgs[threadId](), [&](size_t id)
384
385
  {
385
386
  static_cast<const DerivedClass*>(this)->presampleDocument(
386
387
  docFirst[id * chStride + ch], id * chStride + ch,
@@ -396,6 +397,16 @@ namespace tomoto
396
397
  for (auto& r : res) r.get();
397
398
  res.clear();
398
399
  }
400
+ else
401
+ {
402
+ throw std::runtime_error{ "Unsupported ParallelScheme" };
403
+ }
404
+ }
405
+
406
+ template<ParallelScheme _ps, bool _infer, typename _DocIter>
407
+ void performSamplingGlobal(ThreadPool* pool, _ModelState& globalState, _RandGen* rgs,
408
+ _DocIter docFirst, _DocIter docLast) const
409
+ {
399
410
  }
400
411
 
401
412
  template<typename _DocIter, typename _ExtraDocData>
@@ -444,7 +455,8 @@ namespace tomoto
444
455
  size_t b = partitionId ? edd.vChunkOffset[partitionId - 1] : 0,
445
456
  e = edd.vChunkOffset[partitionId];
446
457
 
447
- localData[partitionId].numByTopicWord = globalState.numByTopicWord.block(0, b, globalState.numByTopicWord.rows(), e - b);
458
+ //localData[partitionId].numByTopicWord.matrix() = globalState.numByTopicWord.block(0, b, globalState.numByTopicWord.rows(), e - b);
459
+ localData[partitionId].numByTopicWord.init((WeightType*)globalState.numByTopicWord.data(), globalState.numByTopicWord.rows(), globalState.numByTopicWord.cols());
448
460
  localData[partitionId].numByTopic = globalState.numByTopic;
449
461
  if (!localData[partitionId].zLikelihood.size()) localData[partitionId].zLikelihood = globalState.zLikelihood;
450
462
  });
@@ -467,16 +479,29 @@ namespace tomoto
467
479
  }
468
480
 
469
481
  template<ParallelScheme _ps>
470
- void trainOne(ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
482
+ void trainOne(ThreadPool& pool, _ModelState* localData, _RandGen* rgs, bool freeze_topics = false)
471
483
  {
472
484
  std::vector<std::future<void>> res;
473
485
  try
474
486
  {
475
- performSampling<_ps, false>(pool, localData, rgs, res,
476
- this->docs.begin(), this->docs.end(), eddTrain);
487
+ static_cast<DerivedClass*>(this)->template performSampling<_ps, false>(pool, localData, rgs, res,
488
+ this->docs.begin(), this->docs.end(), eddTrain
489
+ );
477
490
  static_cast<DerivedClass*>(this)->updateGlobalInfo(pool, localData);
478
491
  static_cast<DerivedClass*>(this)->template mergeState<_ps>(pool, this->globalState, this->tState, localData, rgs, eddTrain);
479
- static_cast<DerivedClass*>(this)->template sampleGlobalLevel<>(&pool, localData, rgs, this->docs.begin(), this->docs.end());
492
+ static_cast<DerivedClass*>(this)->template performSamplingGlobal<_ps, false>(&pool, this->globalState, rgs,
493
+ this->docs.begin(), this->docs.end()
494
+ );
495
+
496
+ if(freeze_topics) static_cast<DerivedClass*>(this)->template sampleGlobalLevel<GlobalSampler::freeze_topics>(
497
+ &pool, &this->globalState, rgs, this->docs.begin(), this->docs.end()
498
+ );
499
+ else static_cast<DerivedClass*>(this)->template sampleGlobalLevel<GlobalSampler::train>(
500
+ &pool, &this->globalState, rgs, this->docs.begin(), this->docs.end()
501
+ );
502
+
503
+ static_cast<DerivedClass*>(this)->template distributeMergedState<_ps>(pool, this->globalState, localData);
504
+
480
505
  if (this->globalStep >= this->burnIn && optimInterval && (this->globalStep + 1) % optimInterval == 0)
481
506
  {
482
507
  static_cast<DerivedClass*>(this)->optimizeParameters(pool, localData, rgs);
@@ -503,8 +528,6 @@ namespace tomoto
503
528
  template<ParallelScheme _ps, typename _ExtraDocData>
504
529
  void mergeState(ThreadPool& pool, _ModelState& globalState, _ModelState& tState, _ModelState* localData, _RandGen*, const _ExtraDocData& edd) const
505
530
  {
506
- std::vector<std::future<void>> res;
507
-
508
531
  if (_ps == ParallelScheme::copy_merge)
509
532
  {
510
533
  tState = globalState;
@@ -517,10 +540,27 @@ namespace tomoto
517
540
  // make all count being positive
518
541
  if (_tw != TermWeight::one)
519
542
  {
520
- globalState.numByTopicWord = globalState.numByTopicWord.cwiseMax(0);
543
+ globalState.numByTopicWord.matrix() = globalState.numByTopicWord.cwiseMax(0);
544
+ }
545
+ globalState.numByTopic = globalState.numByTopicWord.rowwise().sum();
546
+ }
547
+ else if (_ps == ParallelScheme::partition)
548
+ {
549
+ // make all count being positive
550
+ if (_tw != TermWeight::one)
551
+ {
552
+ globalState.numByTopicWord.matrix() = globalState.numByTopicWord.cwiseMax(0);
521
553
  }
522
554
  globalState.numByTopic = globalState.numByTopicWord.rowwise().sum();
555
+ }
556
+ }
523
557
 
558
+ template<ParallelScheme _ps>
559
+ void distributeMergedState(ThreadPool& pool, _ModelState& globalState, _ModelState* localData) const
560
+ {
561
+ std::vector<std::future<void>> res;
562
+ if (_ps == ParallelScheme::copy_merge)
563
+ {
524
564
  for (size_t i = 0; i < pool.getNumWorkers(); ++i)
525
565
  {
526
566
  res.emplace_back(pool.enqueue([&, i](size_t)
@@ -531,22 +571,6 @@ namespace tomoto
531
571
  }
532
572
  else if (_ps == ParallelScheme::partition)
533
573
  {
534
- res = pool.enqueueToAll([&](size_t partitionId)
535
- {
536
- size_t b = partitionId ? edd.vChunkOffset[partitionId - 1] : 0,
537
- e = edd.vChunkOffset[partitionId];
538
- globalState.numByTopicWord.block(0, b, globalState.numByTopicWord.rows(), e - b) = localData[partitionId].numByTopicWord;
539
- });
540
- for (auto& r : res) r.get();
541
- res.clear();
542
-
543
- // make all count being positive
544
- if (_tw != TermWeight::one)
545
- {
546
- globalState.numByTopicWord = globalState.numByTopicWord.cwiseMax(0);
547
- }
548
- globalState.numByTopic = globalState.numByTopicWord.rowwise().sum();
549
-
550
574
  res = pool.enqueueToAll([&](size_t threadId)
551
575
  {
552
576
  localData[threadId].numByTopic = globalState.numByTopic;
@@ -560,16 +584,11 @@ namespace tomoto
560
584
  ex) document pathing at hLDA model
561
585
  * if pool is nullptr, workers has been already pooled and cannot branch works more.
562
586
  */
563
- template<typename _DocIter>
587
+ template<GlobalSampler _gs, typename _DocIter>
564
588
  void sampleGlobalLevel(ThreadPool* pool, _ModelState* localData, _RandGen* rgs, _DocIter first, _DocIter last) const
565
589
  {
566
590
  }
567
591
 
568
- template<typename _DocIter>
569
- void sampleGlobalLevel(ThreadPool* pool, _ModelState* localData, _RandGen* rgs, _DocIter first, _DocIter last)
570
- {
571
- }
572
-
573
592
  template<typename _DocIter>
574
593
  double getLLDocs(_DocIter _first, _DocIter _last) const
575
594
  {
@@ -592,16 +611,33 @@ namespace tomoto
592
611
  double ll = 0;
593
612
  const size_t V = this->realV;
594
613
  // topic-word distribution
595
- auto lgammaEta = math::lgammaT(eta);
596
- ll += math::lgammaT(V*eta) * K;
597
- for (Tid k = 0; k < K; ++k)
614
+ if (etaByTopicWord.size())
598
615
  {
599
- ll -= math::lgammaT(ld.numByTopic[k] + V * eta);
600
- for (Vid v = 0; v < V; ++v)
616
+ for (Tid k = 0; k < K; ++k)
601
617
  {
602
- if (!ld.numByTopicWord(k, v)) continue;
603
- ll += math::lgammaT(ld.numByTopicWord(k, v) + eta) - lgammaEta;
604
- assert(std::isfinite(ll));
618
+ Float etasum = etaByTopicWord.row(k).sum();
619
+ ll += math::lgammaT(etasum) - math::lgammaT(ld.numByTopic[k] + etasum);
620
+ for (Vid v = 0; v < V; ++v)
621
+ {
622
+ if (!ld.numByTopicWord(k, v)) continue;
623
+ ll += math::lgammaT(ld.numByTopicWord(k, v) + etaByTopicWord(v, k)) - math::lgammaT(etaByTopicWord(v, k));
624
+ assert(std::isfinite(ll));
625
+ }
626
+ }
627
+ }
628
+ else
629
+ {
630
+ auto lgammaEta = math::lgammaT(eta);
631
+ ll += math::lgammaT(V * eta) * K;
632
+ for (Tid k = 0; k < K; ++k)
633
+ {
634
+ ll -= math::lgammaT(ld.numByTopic[k] + V * eta);
635
+ for (Vid v = 0; v < V; ++v)
636
+ {
637
+ if (!ld.numByTopicWord(k, v)) continue;
638
+ ll += math::lgammaT(ld.numByTopicWord(k, v) + eta) - lgammaEta;
639
+ assert(std::isfinite(ll));
640
+ }
605
641
  }
606
642
  }
607
643
  return ll;
@@ -637,9 +673,9 @@ namespace tomoto
637
673
  void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
638
674
  {
639
675
  sortAndWriteOrder(doc.words, doc.wOrder);
640
- doc.numByTopic.init(getTopicDocPtr(docId), K);
676
+ doc.numByTopic.init(getTopicDocPtr(docId), K, 1);
641
677
  doc.Zs = tvector<Tid>(wordSize);
642
- if(_tw != TermWeight::one) doc.wordWeights.resize(wordSize, 1);
678
+ if(_tw != TermWeight::one) doc.wordWeights.resize(wordSize);
643
679
  }
644
680
 
645
681
  void prepareWordPriors()
@@ -664,7 +700,8 @@ namespace tomoto
664
700
  if (initDocs)
665
701
  {
666
702
  this->globalState.numByTopic = Eigen::Matrix<WeightType, -1, 1>::Zero(K);
667
- this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(K, V);
703
+ //this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(K, V);
704
+ this->globalState.numByTopicWord.init(nullptr, K, V);
668
705
  }
669
706
  if(m_flags & flags::continuous_doc_data) numByTopicDoc = Eigen::Matrix<WeightType, -1, -1>::Zero(K, this->docs.size());
670
707
  }
@@ -791,12 +828,18 @@ namespace tomoto
791
828
  for (size_t i = 0; i < maxIter; ++i)
792
829
  {
793
830
  std::vector<std::future<void>> res;
794
- performSampling<_ps, true>(pool,
831
+ static_cast<const DerivedClass*>(this)->template performSampling<_ps, true>(pool,
795
832
  (m_flags & flags::shared_state) ? &tmpState : localData.data(), rgs.data(), res,
796
- docFirst, docLast, edd);
833
+ docFirst, docLast, edd
834
+ );
797
835
  static_cast<const DerivedClass*>(this)->template mergeState<_ps>(pool, tmpState, tState, localData.data(), rgs.data(), edd);
798
- static_cast<const DerivedClass*>(this)->template sampleGlobalLevel<>(
799
- &pool, (m_flags & flags::shared_state) ? &tmpState : localData.data(), rgs.data(), docFirst, docLast);
836
+ static_cast<const DerivedClass*>(this)->template performSamplingGlobal<_ps, true>(&pool, tmpState, rgs.data(),
837
+ docFirst, docLast
838
+ );
839
+ static_cast<const DerivedClass*>(this)->template sampleGlobalLevel<GlobalSampler::inference>(
840
+ &pool, (m_flags & flags::shared_state) ? &tmpState : localData.data(), rgs.data(), docFirst, docLast
841
+ );
842
+ static_cast<const DerivedClass*>(this)->template distributeMergedState<_ps>(pool, tmpState, localData.data());
800
843
  }
801
844
  double ll = static_cast<const DerivedClass*>(this)->getLLRest(tmpState) - static_cast<const DerivedClass*>(this)->getLLRest(this->globalState);
802
845
  ll += static_cast<const DerivedClass*>(this)->template getLLDocs<>(docFirst, docLast);
@@ -817,7 +860,9 @@ namespace tomoto
817
860
  {
818
861
  static_cast<const DerivedClass*>(this)->presampleDocument(*d, -1, tmpState, rgc, i);
819
862
  static_cast<const DerivedClass*>(this)->template sampleDocument<ParallelScheme::none, true>(*d, edd, -1, tmpState, rgc, i);
820
- static_cast<const DerivedClass*>(this)->template sampleGlobalLevel<>(
863
+ static_cast<const DerivedClass*>(this)->template performSamplingGlobal<_ps, true>(&pool, tmpState, &rgc,
864
+ &*d, &*d + 1);
865
+ static_cast<const DerivedClass*>(this)->template sampleGlobalLevel<GlobalSampler::inference>(
821
866
  &pool, &tmpState, &rgc, &*d, &*d + 1);
822
867
  }
823
868
  double ll = static_cast<const DerivedClass*>(this)->getLLRest(tmpState) - gllRest;
@@ -845,7 +890,9 @@ namespace tomoto
845
890
  static_cast<const DerivedClass*>(this)->template sampleDocument<ParallelScheme::none, true>(
846
891
  *d, edd, -1, tmpState, rgc, i
847
892
  );
848
- static_cast<const DerivedClass*>(this)->template sampleGlobalLevel<>(
893
+ static_cast<const DerivedClass*>(this)->template performSamplingGlobal<_ps, true>(nullptr, tmpState, &rgc,
894
+ &*d, &*d + 1);
895
+ static_cast<const DerivedClass*>(this)->template sampleGlobalLevel<GlobalSampler::inference>(
849
896
  nullptr, &tmpState, &rgc, &*d, &*d + 1
850
897
  );
851
898
  }
@@ -1036,7 +1083,7 @@ namespace tomoto
1036
1083
  template<typename _TopicModel>
1037
1084
  void DocumentLDA<_tw>::update(WeightType* ptr, const _TopicModel& mdl)
1038
1085
  {
1039
- numByTopic.init(ptr, mdl.getK());
1086
+ numByTopic.init(ptr, mdl.getK(), 1);
1040
1087
  for (size_t i = 0; i < Zs.size(); ++i)
1041
1088
  {
1042
1089
  if (this->words[i] >= mdl.getV()) continue;