tomoto 0.1.3 → 0.1.4

Sign up to get free protection for your applications and to get access to all the features.
Files changed (50) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +6 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -0
  5. data/ext/tomoto/ct.cpp +54 -0
  6. data/ext/tomoto/dmr.cpp +62 -0
  7. data/ext/tomoto/dt.cpp +82 -0
  8. data/ext/tomoto/ext.cpp +27 -773
  9. data/ext/tomoto/gdmr.cpp +34 -0
  10. data/ext/tomoto/hdp.cpp +42 -0
  11. data/ext/tomoto/hlda.cpp +66 -0
  12. data/ext/tomoto/hpa.cpp +27 -0
  13. data/ext/tomoto/lda.cpp +250 -0
  14. data/ext/tomoto/llda.cpp +29 -0
  15. data/ext/tomoto/mglda.cpp +71 -0
  16. data/ext/tomoto/pa.cpp +27 -0
  17. data/ext/tomoto/plda.cpp +29 -0
  18. data/ext/tomoto/slda.cpp +40 -0
  19. data/ext/tomoto/utils.h +84 -0
  20. data/lib/tomoto/tomoto.bundle +0 -0
  21. data/lib/tomoto/tomoto.so +0 -0
  22. data/lib/tomoto/version.rb +1 -1
  23. data/vendor/tomotopy/README.kr.rst +12 -3
  24. data/vendor/tomotopy/README.rst +12 -3
  25. data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +47 -2
  26. data/vendor/tomotopy/src/Labeling/FoRelevance.h +21 -151
  27. data/vendor/tomotopy/src/Labeling/Labeler.h +5 -3
  28. data/vendor/tomotopy/src/Labeling/Phraser.hpp +518 -0
  29. data/vendor/tomotopy/src/TopicModel/CTModel.hpp +6 -3
  30. data/vendor/tomotopy/src/TopicModel/DT.h +1 -1
  31. data/vendor/tomotopy/src/TopicModel/DTModel.hpp +8 -23
  32. data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +9 -18
  33. data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +56 -58
  34. data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +4 -14
  35. data/vendor/tomotopy/src/TopicModel/LDA.h +69 -17
  36. data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +1 -1
  37. data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +108 -61
  38. data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +7 -8
  39. data/vendor/tomotopy/src/TopicModel/PAModel.hpp +26 -16
  40. data/vendor/tomotopy/src/TopicModel/PT.h +27 -0
  41. data/vendor/tomotopy/src/TopicModel/PTModel.cpp +10 -0
  42. data/vendor/tomotopy/src/TopicModel/PTModel.hpp +273 -0
  43. data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +16 -11
  44. data/vendor/tomotopy/src/Utils/MultiNormalDistribution.hpp +3 -2
  45. data/vendor/tomotopy/src/Utils/Trie.hpp +39 -8
  46. data/vendor/tomotopy/src/Utils/TruncMultiNormal.hpp +36 -38
  47. data/vendor/tomotopy/src/Utils/Utils.hpp +50 -45
  48. data/vendor/tomotopy/src/Utils/math.h +8 -4
  49. data/vendor/tomotopy/src/Utils/tvector.hpp +4 -0
  50. metadata +24 -60
@@ -28,7 +28,8 @@ namespace tomoto
28
28
  typename _Interface = IHPAModel,
29
29
  typename _Derived = void,
30
30
  typename _DocType = DocumentHPA<_tw>,
31
- typename _ModelState = ModelStateHPA<_tw>>
31
+ typename _ModelState = ModelStateHPA<_tw>
32
+ >
32
33
  class HPAModel : public LDAModel<_tw, _RandGen, 0, _Interface,
33
34
  typename std::conditional<std::is_same<_Derived, void>::value, HPAModel<_tw, _RandGen, _Exclusive>, _Derived>::type,
34
35
  _DocType, _ModelState>
@@ -250,8 +251,6 @@ namespace tomoto
250
251
  template<ParallelScheme _ps, typename _ExtraDocData>
251
252
  void mergeState(ThreadPool& pool, _ModelState& globalState, _ModelState& tState, _ModelState* localData, _RandGen*, const _ExtraDocData& edd) const
252
253
  {
253
- std::vector<std::future<void>> res;
254
-
255
254
  tState = globalState;
256
255
  globalState = localData[0];
257
256
  for (size_t i = 1; i < pool.getNumWorkers(); ++i)
@@ -276,15 +275,6 @@ namespace tomoto
276
275
  globalState.numByTopicWord[1] = globalState.numByTopicWord[1].cwiseMax(0);
277
276
  globalState.numByTopicWord[2] = globalState.numByTopicWord[2].cwiseMax(0);
278
277
  }
279
-
280
- for (size_t i = 0; i < pool.getNumWorkers(); ++i)
281
- {
282
- res.emplace_back(pool.enqueue([&, this, i](size_t threadId)
283
- {
284
- localData[i] = globalState;
285
- }));
286
- }
287
- for (auto& r : res) r.get();
288
278
  }
289
279
 
290
280
  std::vector<uint64_t> _getTopicsCount() const
@@ -379,7 +369,7 @@ namespace tomoto
379
369
 
380
370
  void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
381
371
  {
382
- doc.numByTopic.init(nullptr, this->K + 1);
372
+ doc.numByTopic.init(nullptr, this->K + 1, 1);
383
373
  doc.numByTopic1_2 = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K, K2 + 1);
384
374
  doc.Zs = tvector<Tid>(wordSize);
385
375
  doc.Z2s = tvector<Tid>(wordSize);
@@ -575,7 +565,7 @@ namespace tomoto
575
565
  template<typename _TopicModel>
576
566
  void DocumentHPA<_tw>::update(WeightType * ptr, const _TopicModel & mdl)
577
567
  {
578
- this->numByTopic.init(ptr, mdl.getK() + 1);
568
+ this->numByTopic.init(ptr, mdl.getK() + 1, 1);
579
569
  this->numByTopic1_2 = Eigen::Matrix<WeightType, -1, -1>::Zero(mdl.getK(), mdl.getK2() + 1);
580
570
  for (size_t i = 0; i < this->Zs.size(); ++i)
581
571
  {
@@ -5,32 +5,67 @@ namespace tomoto
5
5
  {
6
6
  enum class TermWeight { one, idf, pmi, size };
7
7
 
8
- template<typename _Scalar>
9
- struct ShareableVector : Eigen::Map<Eigen::Matrix<_Scalar, -1, 1>>
8
+ template<typename _Scalar, Eigen::Index _rows, Eigen::Index _cols>
9
+ struct ShareableMatrix : Eigen::Map<Eigen::Matrix<_Scalar, _rows, _cols>>
10
10
  {
11
- Eigen::Matrix<_Scalar, -1, 1> ownData;
12
- ShareableVector(_Scalar* ptr = nullptr, Eigen::Index len = 0)
13
- : Eigen::Map<Eigen::Matrix<_Scalar, -1, 1>>(nullptr, 0)
11
+ using BaseType = Eigen::Map<Eigen::Matrix<_Scalar, _rows, _cols>>;
12
+ Eigen::Matrix<_Scalar, _rows, _cols> ownData;
13
+
14
+ ShareableMatrix(_Scalar* ptr = nullptr, Eigen::Index rows = 0, Eigen::Index cols = 0)
15
+ : BaseType(nullptr, _rows != -1 ? _rows : 0, _cols != -1 ? _cols : 0)
14
16
  {
15
- init(ptr, len);
17
+ init(ptr, rows, cols);
16
18
  }
17
19
 
18
- void init(_Scalar* ptr, Eigen::Index len)
20
+ ShareableMatrix(const ShareableMatrix& o)
21
+ : BaseType(nullptr, _rows != -1 ? _rows : 0, _cols != -1 ? _cols : 0), ownData{ o.ownData }
19
22
  {
20
- if (!ptr && len)
23
+ if (o.ownData.data())
24
+ {
25
+ new (this) BaseType(ownData.data(), ownData.rows(), ownData.cols());
26
+ }
27
+ else
21
28
  {
22
- ownData = Eigen::Matrix<_Scalar, -1, 1>::Zero(len);
29
+ new (this) BaseType((_Scalar*)o.data(), o.rows(), o.cols());
30
+ }
31
+ }
32
+
33
+ ShareableMatrix(ShareableMatrix&& o) = default;
34
+
35
+ ShareableMatrix& operator=(const ShareableMatrix& o)
36
+ {
37
+ if (o.ownData.data())
38
+ {
39
+ ownData = o.ownData;
40
+ new (this) BaseType(ownData.data(), ownData.rows(), ownData.cols());
41
+ }
42
+ else
43
+ {
44
+ new (this) BaseType((_Scalar*)o.data(), o.rows(), o.cols());
45
+ }
46
+ return *this;
47
+ }
48
+
49
+ ShareableMatrix& operator=(ShareableMatrix&& o) = default;
50
+
51
+ void init(_Scalar* ptr, Eigen::Index rows, Eigen::Index cols)
52
+ {
53
+ if (!ptr && rows && cols)
54
+ {
55
+ ownData = Eigen::Matrix<_Scalar, _rows, _cols>::Zero(_rows != -1 ? _rows : rows, _cols != -1 ? _cols : cols);
23
56
  ptr = ownData.data();
24
57
  }
25
- // is this the best way??
26
- this->m_data = ptr;
27
- ((Eigen::internal::variable_if_dynamic<Eigen::Index, -1>*)&this->m_rows)->setValue(len);
58
+ else
59
+ {
60
+ ownData = Eigen::Matrix<_Scalar, _rows, _cols>{};
61
+ }
62
+ new (this) BaseType(ptr, _rows != -1 ? _rows : rows, _cols != -1 ? _cols : cols);
28
63
  }
29
64
 
30
- void conservativeResize(size_t newSize)
65
+ void conservativeResize(size_t newRows, size_t newCols)
31
66
  {
32
- ownData.conservativeResize(newSize);
33
- init(ownData.data(), ownData.size());
67
+ ownData.conservativeResize(_rows != -1 ? _rows : newRows, _cols != -1 ? _cols : newCols);
68
+ new (this) BaseType(ownData.data(), ownData.rows(), ownData.cols());
34
69
  }
35
70
 
36
71
  void becomeOwner()
@@ -38,9 +73,26 @@ namespace tomoto
38
73
  if (ownData.data() != this->m_data)
39
74
  {
40
75
  ownData = *this;
41
- init(ownData.data(), ownData.size());
76
+ new (this) BaseType(ownData.data(), ownData.rows(), ownData.cols());
42
77
  }
43
78
  }
79
+
80
+ void serializerRead(std::istream& istr)
81
+ {
82
+ uint32_t rows = serializer::readFromStream<uint32_t>(istr);
83
+ uint32_t cols = serializer::readFromStream<uint32_t>(istr);
84
+ init(nullptr, rows, cols);
85
+ if (!istr.read((char*)this->data(), sizeof(_Scalar) * this->size()))
86
+ throw std::ios_base::failure(std::string("reading type '") + typeid(_Scalar).name() + std::string("' is failed"));
87
+ }
88
+
89
+ void serializerWrite(std::ostream& ostr) const
90
+ {
91
+ serializer::writeToStream<uint32_t>(ostr, (uint32_t)this->rows());
92
+ serializer::writeToStream<uint32_t>(ostr, (uint32_t)this->cols());
93
+ if (!ostr.write((const char*)this->data(), sizeof(_Scalar) * this->size()))
94
+ throw std::ios_base::failure(std::string("writing type '") + typeid(_Scalar).name() + std::string("' is failed"));
95
+ }
44
96
  };
45
97
 
46
98
  template<typename _Base, TermWeight _tw>
@@ -85,7 +137,7 @@ namespace tomoto
85
137
 
86
138
  tvector<Tid> Zs;
87
139
  tvector<Float> wordWeights;
88
- ShareableVector<WeightType> numByTopic;
140
+ ShareableMatrix<WeightType, -1, 1> numByTopic;
89
141
 
90
142
  DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(DocumentBase, 0, Zs, wordWeights);
91
143
  DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(DocumentBase, 1, 0x00010001, Zs, wordWeights);
@@ -163,7 +163,7 @@ namespace tomoto
163
163
  {
164
164
  res.emplace_back(pool.enqueue([&, this, ch, chStride](size_t threadId)
165
165
  {
166
- forRandom((this->docs.size() - 1 - ch) / chStride + 1, rgs[threadId](), [&, this](size_t id)
166
+ forShuffled((this->docs.size() - 1 - ch) / chStride + 1, rgs[threadId](), [&, this](size_t id)
167
167
  {
168
168
  static_cast<DerivedClass*>(this)->template sampleDocument<ParallelScheme::copy_merge>(
169
169
  this->docs[id * chStride + ch], 0, id * chStride + ch,
@@ -58,7 +58,8 @@ namespace tomoto
58
58
 
59
59
  Eigen::Matrix<Float, -1, 1> zLikelihood;
60
60
  Eigen::Matrix<WeightType, -1, 1> numByTopic; // Dim: (Topic, 1)
61
- Eigen::Matrix<WeightType, -1, -1> numByTopicWord; // Dim: (Topic, Vocabs)
61
+ //Eigen::Matrix<WeightType, -1, -1> numByTopicWord; // Dim: (Topic, Vocabs)
62
+ ShareableMatrix<WeightType, -1, -1> numByTopicWord; // Dim: (Topic, Vocabs)
62
63
  DEFINE_SERIALIZER(numByTopic, numByTopicWord);
63
64
  };
64
65
 
@@ -137,7 +138,8 @@ namespace tomoto
137
138
  typename _Interface,
138
139
  typename _Derived,
139
140
  typename _DocType,
140
- typename _ModelState>
141
+ typename _ModelState
142
+ >
141
143
  class HDPModel;
142
144
 
143
145
  template<TermWeight _tw, typename _RandGen,
@@ -145,7 +147,8 @@ namespace tomoto
145
147
  typename _Interface = ILDAModel,
146
148
  typename _Derived = void,
147
149
  typename _DocType = DocumentLDA<_tw>,
148
- typename _ModelState = ModelStateLDA<_tw>>
150
+ typename _ModelState = ModelStateLDA<_tw>
151
+ >
149
152
  class LDAModel : public TopicModel<_RandGen, _Flags, _Interface,
150
153
  typename std::conditional<std::is_same<_Derived, void>::value, LDAModel<_tw, _RandGen, _Flags>, _Derived>::type,
151
154
  _DocType, _ModelState>,
@@ -306,25 +309,23 @@ namespace tomoto
306
309
  e = edd.chunkOffsetByDoc(partitionId + 1, docId);
307
310
  }
308
311
 
309
- size_t vOffset = (_ps == ParallelScheme::partition && partitionId) ? edd.vChunkOffset[partitionId - 1] : 0;
310
-
311
312
  for (size_t w = b; w < e; ++w)
312
313
  {
313
314
  if (doc.words[w] >= this->realV) continue;
314
- addWordTo<-1>(ld, doc, w, doc.words[w] - vOffset, doc.Zs[w]);
315
+ static_cast<const DerivedClass*>(this)->template addWordTo<-1>(ld, doc, w, doc.words[w], doc.Zs[w]);
315
316
  Float* dist;
316
317
  if (etaByTopicWord.size())
317
318
  {
318
319
  dist = static_cast<const DerivedClass*>(this)->template
319
- getZLikelihoods<true>(ld, doc, docId, doc.words[w] - vOffset);
320
+ getZLikelihoods<true>(ld, doc, docId, doc.words[w]);
320
321
  }
321
322
  else
322
323
  {
323
324
  dist = static_cast<const DerivedClass*>(this)->template
324
- getZLikelihoods<false>(ld, doc, docId, doc.words[w] - vOffset);
325
+ getZLikelihoods<false>(ld, doc, docId, doc.words[w]);
325
326
  }
326
327
  doc.Zs[w] = sample::sampleFromDiscreteAcc(dist, dist + K, rgs);
327
- addWordTo<1>(ld, doc, w, doc.words[w] - vOffset, doc.Zs[w]);
328
+ static_cast<const DerivedClass*>(this)->template addWordTo<1>(ld, doc, w, doc.words[w], doc.Zs[w]);
328
329
  }
329
330
  }
330
331
 
@@ -335,7 +336,7 @@ namespace tomoto
335
336
  // single-threaded sampling
336
337
  if (_ps == ParallelScheme::none)
337
338
  {
338
- forRandom((size_t)std::distance(docFirst, docLast), rgs[0](), [&](size_t id)
339
+ forShuffled((size_t)std::distance(docFirst, docLast), rgs[0](), [&](size_t id)
339
340
  {
340
341
  static_cast<const DerivedClass*>(this)->presampleDocument(docFirst[id], id, *localData, *rgs, this->globalStep);
341
342
  static_cast<const DerivedClass*>(this)->template sampleDocument<_ps, _infer>(
@@ -344,7 +345,7 @@ namespace tomoto
344
345
 
345
346
  });
346
347
  }
347
- // multi-threaded sampling on partition ad update into global
348
+ // multi-threaded sampling on partition and update into global
348
349
  else if (_ps == ParallelScheme::partition)
349
350
  {
350
351
  const size_t chStride = pool.getNumWorkers();
@@ -353,7 +354,7 @@ namespace tomoto
353
354
  res = pool.enqueueToAll([&, i, chStride](size_t partitionId)
354
355
  {
355
356
  size_t didx = (i + partitionId) % chStride;
356
- forRandom(((size_t)std::distance(docFirst, docLast) + (chStride - 1) - didx) / chStride, rgs[partitionId](), [&](size_t id)
357
+ forShuffled(((size_t)std::distance(docFirst, docLast) + (chStride - 1) - didx) / chStride, rgs[partitionId](), [&](size_t id)
357
358
  {
358
359
  if (i == 0)
359
360
  {
@@ -380,7 +381,7 @@ namespace tomoto
380
381
  {
381
382
  res.emplace_back(pool.enqueue([&, ch, chStride](size_t threadId)
382
383
  {
383
- forRandom(((size_t)std::distance(docFirst, docLast) + (chStride - 1) - ch) / chStride, rgs[threadId](), [&](size_t id)
384
+ forShuffled(((size_t)std::distance(docFirst, docLast) + (chStride - 1) - ch) / chStride, rgs[threadId](), [&](size_t id)
384
385
  {
385
386
  static_cast<const DerivedClass*>(this)->presampleDocument(
386
387
  docFirst[id * chStride + ch], id * chStride + ch,
@@ -396,6 +397,16 @@ namespace tomoto
396
397
  for (auto& r : res) r.get();
397
398
  res.clear();
398
399
  }
400
+ else
401
+ {
402
+ throw std::runtime_error{ "Unsupported ParallelScheme" };
403
+ }
404
+ }
405
+
406
+ template<ParallelScheme _ps, bool _infer, typename _DocIter>
407
+ void performSamplingGlobal(ThreadPool* pool, _ModelState& globalState, _RandGen* rgs,
408
+ _DocIter docFirst, _DocIter docLast) const
409
+ {
399
410
  }
400
411
 
401
412
  template<typename _DocIter, typename _ExtraDocData>
@@ -444,7 +455,8 @@ namespace tomoto
444
455
  size_t b = partitionId ? edd.vChunkOffset[partitionId - 1] : 0,
445
456
  e = edd.vChunkOffset[partitionId];
446
457
 
447
- localData[partitionId].numByTopicWord = globalState.numByTopicWord.block(0, b, globalState.numByTopicWord.rows(), e - b);
458
+ //localData[partitionId].numByTopicWord.matrix() = globalState.numByTopicWord.block(0, b, globalState.numByTopicWord.rows(), e - b);
459
+ localData[partitionId].numByTopicWord.init((WeightType*)globalState.numByTopicWord.data(), globalState.numByTopicWord.rows(), globalState.numByTopicWord.cols());
448
460
  localData[partitionId].numByTopic = globalState.numByTopic;
449
461
  if (!localData[partitionId].zLikelihood.size()) localData[partitionId].zLikelihood = globalState.zLikelihood;
450
462
  });
@@ -467,16 +479,29 @@ namespace tomoto
467
479
  }
468
480
 
469
481
  template<ParallelScheme _ps>
470
- void trainOne(ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
482
+ void trainOne(ThreadPool& pool, _ModelState* localData, _RandGen* rgs, bool freeze_topics = false)
471
483
  {
472
484
  std::vector<std::future<void>> res;
473
485
  try
474
486
  {
475
- performSampling<_ps, false>(pool, localData, rgs, res,
476
- this->docs.begin(), this->docs.end(), eddTrain);
487
+ static_cast<DerivedClass*>(this)->template performSampling<_ps, false>(pool, localData, rgs, res,
488
+ this->docs.begin(), this->docs.end(), eddTrain
489
+ );
477
490
  static_cast<DerivedClass*>(this)->updateGlobalInfo(pool, localData);
478
491
  static_cast<DerivedClass*>(this)->template mergeState<_ps>(pool, this->globalState, this->tState, localData, rgs, eddTrain);
479
- static_cast<DerivedClass*>(this)->template sampleGlobalLevel<>(&pool, localData, rgs, this->docs.begin(), this->docs.end());
492
+ static_cast<DerivedClass*>(this)->template performSamplingGlobal<_ps, false>(&pool, this->globalState, rgs,
493
+ this->docs.begin(), this->docs.end()
494
+ );
495
+
496
+ if(freeze_topics) static_cast<DerivedClass*>(this)->template sampleGlobalLevel<GlobalSampler::freeze_topics>(
497
+ &pool, &this->globalState, rgs, this->docs.begin(), this->docs.end()
498
+ );
499
+ else static_cast<DerivedClass*>(this)->template sampleGlobalLevel<GlobalSampler::train>(
500
+ &pool, &this->globalState, rgs, this->docs.begin(), this->docs.end()
501
+ );
502
+
503
+ static_cast<DerivedClass*>(this)->template distributeMergedState<_ps>(pool, this->globalState, localData);
504
+
480
505
  if (this->globalStep >= this->burnIn && optimInterval && (this->globalStep + 1) % optimInterval == 0)
481
506
  {
482
507
  static_cast<DerivedClass*>(this)->optimizeParameters(pool, localData, rgs);
@@ -503,8 +528,6 @@ namespace tomoto
503
528
  template<ParallelScheme _ps, typename _ExtraDocData>
504
529
  void mergeState(ThreadPool& pool, _ModelState& globalState, _ModelState& tState, _ModelState* localData, _RandGen*, const _ExtraDocData& edd) const
505
530
  {
506
- std::vector<std::future<void>> res;
507
-
508
531
  if (_ps == ParallelScheme::copy_merge)
509
532
  {
510
533
  tState = globalState;
@@ -517,10 +540,27 @@ namespace tomoto
517
540
  // make all count being positive
518
541
  if (_tw != TermWeight::one)
519
542
  {
520
- globalState.numByTopicWord = globalState.numByTopicWord.cwiseMax(0);
543
+ globalState.numByTopicWord.matrix() = globalState.numByTopicWord.cwiseMax(0);
544
+ }
545
+ globalState.numByTopic = globalState.numByTopicWord.rowwise().sum();
546
+ }
547
+ else if (_ps == ParallelScheme::partition)
548
+ {
549
+ // make all count being positive
550
+ if (_tw != TermWeight::one)
551
+ {
552
+ globalState.numByTopicWord.matrix() = globalState.numByTopicWord.cwiseMax(0);
521
553
  }
522
554
  globalState.numByTopic = globalState.numByTopicWord.rowwise().sum();
555
+ }
556
+ }
523
557
 
558
+ template<ParallelScheme _ps>
559
+ void distributeMergedState(ThreadPool& pool, _ModelState& globalState, _ModelState* localData) const
560
+ {
561
+ std::vector<std::future<void>> res;
562
+ if (_ps == ParallelScheme::copy_merge)
563
+ {
524
564
  for (size_t i = 0; i < pool.getNumWorkers(); ++i)
525
565
  {
526
566
  res.emplace_back(pool.enqueue([&, i](size_t)
@@ -531,22 +571,6 @@ namespace tomoto
531
571
  }
532
572
  else if (_ps == ParallelScheme::partition)
533
573
  {
534
- res = pool.enqueueToAll([&](size_t partitionId)
535
- {
536
- size_t b = partitionId ? edd.vChunkOffset[partitionId - 1] : 0,
537
- e = edd.vChunkOffset[partitionId];
538
- globalState.numByTopicWord.block(0, b, globalState.numByTopicWord.rows(), e - b) = localData[partitionId].numByTopicWord;
539
- });
540
- for (auto& r : res) r.get();
541
- res.clear();
542
-
543
- // make all count being positive
544
- if (_tw != TermWeight::one)
545
- {
546
- globalState.numByTopicWord = globalState.numByTopicWord.cwiseMax(0);
547
- }
548
- globalState.numByTopic = globalState.numByTopicWord.rowwise().sum();
549
-
550
574
  res = pool.enqueueToAll([&](size_t threadId)
551
575
  {
552
576
  localData[threadId].numByTopic = globalState.numByTopic;
@@ -560,16 +584,11 @@ namespace tomoto
560
584
  ex) document pathing at hLDA model
561
585
  * if pool is nullptr, workers has been already pooled and cannot branch works more.
562
586
  */
563
- template<typename _DocIter>
587
+ template<GlobalSampler _gs, typename _DocIter>
564
588
  void sampleGlobalLevel(ThreadPool* pool, _ModelState* localData, _RandGen* rgs, _DocIter first, _DocIter last) const
565
589
  {
566
590
  }
567
591
 
568
- template<typename _DocIter>
569
- void sampleGlobalLevel(ThreadPool* pool, _ModelState* localData, _RandGen* rgs, _DocIter first, _DocIter last)
570
- {
571
- }
572
-
573
592
  template<typename _DocIter>
574
593
  double getLLDocs(_DocIter _first, _DocIter _last) const
575
594
  {
@@ -592,16 +611,33 @@ namespace tomoto
592
611
  double ll = 0;
593
612
  const size_t V = this->realV;
594
613
  // topic-word distribution
595
- auto lgammaEta = math::lgammaT(eta);
596
- ll += math::lgammaT(V*eta) * K;
597
- for (Tid k = 0; k < K; ++k)
614
+ if (etaByTopicWord.size())
598
615
  {
599
- ll -= math::lgammaT(ld.numByTopic[k] + V * eta);
600
- for (Vid v = 0; v < V; ++v)
616
+ for (Tid k = 0; k < K; ++k)
601
617
  {
602
- if (!ld.numByTopicWord(k, v)) continue;
603
- ll += math::lgammaT(ld.numByTopicWord(k, v) + eta) - lgammaEta;
604
- assert(std::isfinite(ll));
618
+ Float etasum = etaByTopicWord.row(k).sum();
619
+ ll += math::lgammaT(etasum) - math::lgammaT(ld.numByTopic[k] + etasum);
620
+ for (Vid v = 0; v < V; ++v)
621
+ {
622
+ if (!ld.numByTopicWord(k, v)) continue;
623
+ ll += math::lgammaT(ld.numByTopicWord(k, v) + etaByTopicWord(v, k)) - math::lgammaT(etaByTopicWord(v, k));
624
+ assert(std::isfinite(ll));
625
+ }
626
+ }
627
+ }
628
+ else
629
+ {
630
+ auto lgammaEta = math::lgammaT(eta);
631
+ ll += math::lgammaT(V * eta) * K;
632
+ for (Tid k = 0; k < K; ++k)
633
+ {
634
+ ll -= math::lgammaT(ld.numByTopic[k] + V * eta);
635
+ for (Vid v = 0; v < V; ++v)
636
+ {
637
+ if (!ld.numByTopicWord(k, v)) continue;
638
+ ll += math::lgammaT(ld.numByTopicWord(k, v) + eta) - lgammaEta;
639
+ assert(std::isfinite(ll));
640
+ }
605
641
  }
606
642
  }
607
643
  return ll;
@@ -637,9 +673,9 @@ namespace tomoto
637
673
  void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
638
674
  {
639
675
  sortAndWriteOrder(doc.words, doc.wOrder);
640
- doc.numByTopic.init(getTopicDocPtr(docId), K);
676
+ doc.numByTopic.init(getTopicDocPtr(docId), K, 1);
641
677
  doc.Zs = tvector<Tid>(wordSize);
642
- if(_tw != TermWeight::one) doc.wordWeights.resize(wordSize, 1);
678
+ if(_tw != TermWeight::one) doc.wordWeights.resize(wordSize);
643
679
  }
644
680
 
645
681
  void prepareWordPriors()
@@ -664,7 +700,8 @@ namespace tomoto
664
700
  if (initDocs)
665
701
  {
666
702
  this->globalState.numByTopic = Eigen::Matrix<WeightType, -1, 1>::Zero(K);
667
- this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(K, V);
703
+ //this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(K, V);
704
+ this->globalState.numByTopicWord.init(nullptr, K, V);
668
705
  }
669
706
  if(m_flags & flags::continuous_doc_data) numByTopicDoc = Eigen::Matrix<WeightType, -1, -1>::Zero(K, this->docs.size());
670
707
  }
@@ -791,12 +828,18 @@ namespace tomoto
791
828
  for (size_t i = 0; i < maxIter; ++i)
792
829
  {
793
830
  std::vector<std::future<void>> res;
794
- performSampling<_ps, true>(pool,
831
+ static_cast<const DerivedClass*>(this)->template performSampling<_ps, true>(pool,
795
832
  (m_flags & flags::shared_state) ? &tmpState : localData.data(), rgs.data(), res,
796
- docFirst, docLast, edd);
833
+ docFirst, docLast, edd
834
+ );
797
835
  static_cast<const DerivedClass*>(this)->template mergeState<_ps>(pool, tmpState, tState, localData.data(), rgs.data(), edd);
798
- static_cast<const DerivedClass*>(this)->template sampleGlobalLevel<>(
799
- &pool, (m_flags & flags::shared_state) ? &tmpState : localData.data(), rgs.data(), docFirst, docLast);
836
+ static_cast<const DerivedClass*>(this)->template performSamplingGlobal<_ps, true>(&pool, tmpState, rgs.data(),
837
+ docFirst, docLast
838
+ );
839
+ static_cast<const DerivedClass*>(this)->template sampleGlobalLevel<GlobalSampler::inference>(
840
+ &pool, (m_flags & flags::shared_state) ? &tmpState : localData.data(), rgs.data(), docFirst, docLast
841
+ );
842
+ static_cast<const DerivedClass*>(this)->template distributeMergedState<_ps>(pool, tmpState, localData.data());
800
843
  }
801
844
  double ll = static_cast<const DerivedClass*>(this)->getLLRest(tmpState) - static_cast<const DerivedClass*>(this)->getLLRest(this->globalState);
802
845
  ll += static_cast<const DerivedClass*>(this)->template getLLDocs<>(docFirst, docLast);
@@ -817,7 +860,9 @@ namespace tomoto
817
860
  {
818
861
  static_cast<const DerivedClass*>(this)->presampleDocument(*d, -1, tmpState, rgc, i);
819
862
  static_cast<const DerivedClass*>(this)->template sampleDocument<ParallelScheme::none, true>(*d, edd, -1, tmpState, rgc, i);
820
- static_cast<const DerivedClass*>(this)->template sampleGlobalLevel<>(
863
+ static_cast<const DerivedClass*>(this)->template performSamplingGlobal<_ps, true>(&pool, tmpState, &rgc,
864
+ &*d, &*d + 1);
865
+ static_cast<const DerivedClass*>(this)->template sampleGlobalLevel<GlobalSampler::inference>(
821
866
  &pool, &tmpState, &rgc, &*d, &*d + 1);
822
867
  }
823
868
  double ll = static_cast<const DerivedClass*>(this)->getLLRest(tmpState) - gllRest;
@@ -845,7 +890,9 @@ namespace tomoto
845
890
  static_cast<const DerivedClass*>(this)->template sampleDocument<ParallelScheme::none, true>(
846
891
  *d, edd, -1, tmpState, rgc, i
847
892
  );
848
- static_cast<const DerivedClass*>(this)->template sampleGlobalLevel<>(
893
+ static_cast<const DerivedClass*>(this)->template performSamplingGlobal<_ps, true>(nullptr, tmpState, &rgc,
894
+ &*d, &*d + 1);
895
+ static_cast<const DerivedClass*>(this)->template sampleGlobalLevel<GlobalSampler::inference>(
849
896
  nullptr, &tmpState, &rgc, &*d, &*d + 1
850
897
  );
851
898
  }
@@ -1036,7 +1083,7 @@ namespace tomoto
1036
1083
  template<typename _TopicModel>
1037
1084
  void DocumentLDA<_tw>::update(WeightType* ptr, const _TopicModel& mdl)
1038
1085
  {
1039
- numByTopic.init(ptr, mdl.getK());
1086
+ numByTopic.init(ptr, mdl.getK(), 1);
1040
1087
  for (size_t i = 0; i < Zs.size(); ++i)
1041
1088
  {
1042
1089
  if (this->words[i] >= mdl.getV()) continue;