tomoto 0.1.3 → 0.1.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/LICENSE.txt +1 -1
- data/README.md +7 -0
- data/ext/tomoto/ct.cpp +54 -0
- data/ext/tomoto/dmr.cpp +62 -0
- data/ext/tomoto/dt.cpp +82 -0
- data/ext/tomoto/ext.cpp +27 -773
- data/ext/tomoto/gdmr.cpp +34 -0
- data/ext/tomoto/hdp.cpp +42 -0
- data/ext/tomoto/hlda.cpp +66 -0
- data/ext/tomoto/hpa.cpp +27 -0
- data/ext/tomoto/lda.cpp +250 -0
- data/ext/tomoto/llda.cpp +29 -0
- data/ext/tomoto/mglda.cpp +71 -0
- data/ext/tomoto/pa.cpp +27 -0
- data/ext/tomoto/plda.cpp +29 -0
- data/ext/tomoto/slda.cpp +40 -0
- data/ext/tomoto/utils.h +84 -0
- data/lib/tomoto/tomoto.bundle +0 -0
- data/lib/tomoto/tomoto.so +0 -0
- data/lib/tomoto/version.rb +1 -1
- data/vendor/tomotopy/README.kr.rst +12 -3
- data/vendor/tomotopy/README.rst +12 -3
- data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +47 -2
- data/vendor/tomotopy/src/Labeling/FoRelevance.h +21 -151
- data/vendor/tomotopy/src/Labeling/Labeler.h +5 -3
- data/vendor/tomotopy/src/Labeling/Phraser.hpp +518 -0
- data/vendor/tomotopy/src/TopicModel/CTModel.hpp +6 -3
- data/vendor/tomotopy/src/TopicModel/DT.h +1 -1
- data/vendor/tomotopy/src/TopicModel/DTModel.hpp +8 -23
- data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +9 -18
- data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +56 -58
- data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +4 -14
- data/vendor/tomotopy/src/TopicModel/LDA.h +69 -17
- data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +1 -1
- data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +108 -61
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +7 -8
- data/vendor/tomotopy/src/TopicModel/PAModel.hpp +26 -16
- data/vendor/tomotopy/src/TopicModel/PT.h +27 -0
- data/vendor/tomotopy/src/TopicModel/PTModel.cpp +10 -0
- data/vendor/tomotopy/src/TopicModel/PTModel.hpp +273 -0
- data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +16 -11
- data/vendor/tomotopy/src/Utils/MultiNormalDistribution.hpp +3 -2
- data/vendor/tomotopy/src/Utils/Trie.hpp +39 -8
- data/vendor/tomotopy/src/Utils/TruncMultiNormal.hpp +36 -38
- data/vendor/tomotopy/src/Utils/Utils.hpp +50 -45
- data/vendor/tomotopy/src/Utils/math.h +8 -4
- data/vendor/tomotopy/src/Utils/tvector.hpp +4 -0
- metadata +24 -60
@@ -28,7 +28,8 @@ namespace tomoto
|
|
28
28
|
typename _Interface = IHPAModel,
|
29
29
|
typename _Derived = void,
|
30
30
|
typename _DocType = DocumentHPA<_tw>,
|
31
|
-
typename _ModelState = ModelStateHPA<_tw
|
31
|
+
typename _ModelState = ModelStateHPA<_tw>
|
32
|
+
>
|
32
33
|
class HPAModel : public LDAModel<_tw, _RandGen, 0, _Interface,
|
33
34
|
typename std::conditional<std::is_same<_Derived, void>::value, HPAModel<_tw, _RandGen, _Exclusive>, _Derived>::type,
|
34
35
|
_DocType, _ModelState>
|
@@ -250,8 +251,6 @@ namespace tomoto
|
|
250
251
|
template<ParallelScheme _ps, typename _ExtraDocData>
|
251
252
|
void mergeState(ThreadPool& pool, _ModelState& globalState, _ModelState& tState, _ModelState* localData, _RandGen*, const _ExtraDocData& edd) const
|
252
253
|
{
|
253
|
-
std::vector<std::future<void>> res;
|
254
|
-
|
255
254
|
tState = globalState;
|
256
255
|
globalState = localData[0];
|
257
256
|
for (size_t i = 1; i < pool.getNumWorkers(); ++i)
|
@@ -276,15 +275,6 @@ namespace tomoto
|
|
276
275
|
globalState.numByTopicWord[1] = globalState.numByTopicWord[1].cwiseMax(0);
|
277
276
|
globalState.numByTopicWord[2] = globalState.numByTopicWord[2].cwiseMax(0);
|
278
277
|
}
|
279
|
-
|
280
|
-
for (size_t i = 0; i < pool.getNumWorkers(); ++i)
|
281
|
-
{
|
282
|
-
res.emplace_back(pool.enqueue([&, this, i](size_t threadId)
|
283
|
-
{
|
284
|
-
localData[i] = globalState;
|
285
|
-
}));
|
286
|
-
}
|
287
|
-
for (auto& r : res) r.get();
|
288
278
|
}
|
289
279
|
|
290
280
|
std::vector<uint64_t> _getTopicsCount() const
|
@@ -379,7 +369,7 @@ namespace tomoto
|
|
379
369
|
|
380
370
|
void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
|
381
371
|
{
|
382
|
-
doc.numByTopic.init(nullptr, this->K + 1);
|
372
|
+
doc.numByTopic.init(nullptr, this->K + 1, 1);
|
383
373
|
doc.numByTopic1_2 = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K, K2 + 1);
|
384
374
|
doc.Zs = tvector<Tid>(wordSize);
|
385
375
|
doc.Z2s = tvector<Tid>(wordSize);
|
@@ -575,7 +565,7 @@ namespace tomoto
|
|
575
565
|
template<typename _TopicModel>
|
576
566
|
void DocumentHPA<_tw>::update(WeightType * ptr, const _TopicModel & mdl)
|
577
567
|
{
|
578
|
-
this->numByTopic.init(ptr, mdl.getK() + 1);
|
568
|
+
this->numByTopic.init(ptr, mdl.getK() + 1, 1);
|
579
569
|
this->numByTopic1_2 = Eigen::Matrix<WeightType, -1, -1>::Zero(mdl.getK(), mdl.getK2() + 1);
|
580
570
|
for (size_t i = 0; i < this->Zs.size(); ++i)
|
581
571
|
{
|
@@ -5,32 +5,67 @@ namespace tomoto
|
|
5
5
|
{
|
6
6
|
enum class TermWeight { one, idf, pmi, size };
|
7
7
|
|
8
|
-
template<typename _Scalar>
|
9
|
-
struct
|
8
|
+
template<typename _Scalar, Eigen::Index _rows, Eigen::Index _cols>
|
9
|
+
struct ShareableMatrix : Eigen::Map<Eigen::Matrix<_Scalar, _rows, _cols>>
|
10
10
|
{
|
11
|
-
Eigen::Matrix<_Scalar,
|
12
|
-
|
13
|
-
|
11
|
+
using BaseType = Eigen::Map<Eigen::Matrix<_Scalar, _rows, _cols>>;
|
12
|
+
Eigen::Matrix<_Scalar, _rows, _cols> ownData;
|
13
|
+
|
14
|
+
ShareableMatrix(_Scalar* ptr = nullptr, Eigen::Index rows = 0, Eigen::Index cols = 0)
|
15
|
+
: BaseType(nullptr, _rows != -1 ? _rows : 0, _cols != -1 ? _cols : 0)
|
14
16
|
{
|
15
|
-
init(ptr,
|
17
|
+
init(ptr, rows, cols);
|
16
18
|
}
|
17
19
|
|
18
|
-
|
20
|
+
ShareableMatrix(const ShareableMatrix& o)
|
21
|
+
: BaseType(nullptr, _rows != -1 ? _rows : 0, _cols != -1 ? _cols : 0), ownData{ o.ownData }
|
19
22
|
{
|
20
|
-
if (
|
23
|
+
if (o.ownData.data())
|
24
|
+
{
|
25
|
+
new (this) BaseType(ownData.data(), ownData.rows(), ownData.cols());
|
26
|
+
}
|
27
|
+
else
|
21
28
|
{
|
22
|
-
|
29
|
+
new (this) BaseType((_Scalar*)o.data(), o.rows(), o.cols());
|
30
|
+
}
|
31
|
+
}
|
32
|
+
|
33
|
+
ShareableMatrix(ShareableMatrix&& o) = default;
|
34
|
+
|
35
|
+
ShareableMatrix& operator=(const ShareableMatrix& o)
|
36
|
+
{
|
37
|
+
if (o.ownData.data())
|
38
|
+
{
|
39
|
+
ownData = o.ownData;
|
40
|
+
new (this) BaseType(ownData.data(), ownData.rows(), ownData.cols());
|
41
|
+
}
|
42
|
+
else
|
43
|
+
{
|
44
|
+
new (this) BaseType((_Scalar*)o.data(), o.rows(), o.cols());
|
45
|
+
}
|
46
|
+
return *this;
|
47
|
+
}
|
48
|
+
|
49
|
+
ShareableMatrix& operator=(ShareableMatrix&& o) = default;
|
50
|
+
|
51
|
+
void init(_Scalar* ptr, Eigen::Index rows, Eigen::Index cols)
|
52
|
+
{
|
53
|
+
if (!ptr && rows && cols)
|
54
|
+
{
|
55
|
+
ownData = Eigen::Matrix<_Scalar, _rows, _cols>::Zero(_rows != -1 ? _rows : rows, _cols != -1 ? _cols : cols);
|
23
56
|
ptr = ownData.data();
|
24
57
|
}
|
25
|
-
|
26
|
-
|
27
|
-
|
58
|
+
else
|
59
|
+
{
|
60
|
+
ownData = Eigen::Matrix<_Scalar, _rows, _cols>{};
|
61
|
+
}
|
62
|
+
new (this) BaseType(ptr, _rows != -1 ? _rows : rows, _cols != -1 ? _cols : cols);
|
28
63
|
}
|
29
64
|
|
30
|
-
void conservativeResize(size_t
|
65
|
+
void conservativeResize(size_t newRows, size_t newCols)
|
31
66
|
{
|
32
|
-
ownData.conservativeResize(
|
33
|
-
|
67
|
+
ownData.conservativeResize(_rows != -1 ? _rows : newRows, _cols != -1 ? _cols : newCols);
|
68
|
+
new (this) BaseType(ownData.data(), ownData.rows(), ownData.cols());
|
34
69
|
}
|
35
70
|
|
36
71
|
void becomeOwner()
|
@@ -38,9 +73,26 @@ namespace tomoto
|
|
38
73
|
if (ownData.data() != this->m_data)
|
39
74
|
{
|
40
75
|
ownData = *this;
|
41
|
-
|
76
|
+
new (this) BaseType(ownData.data(), ownData.rows(), ownData.cols());
|
42
77
|
}
|
43
78
|
}
|
79
|
+
|
80
|
+
void serializerRead(std::istream& istr)
|
81
|
+
{
|
82
|
+
uint32_t rows = serializer::readFromStream<uint32_t>(istr);
|
83
|
+
uint32_t cols = serializer::readFromStream<uint32_t>(istr);
|
84
|
+
init(nullptr, rows, cols);
|
85
|
+
if (!istr.read((char*)this->data(), sizeof(_Scalar) * this->size()))
|
86
|
+
throw std::ios_base::failure(std::string("reading type '") + typeid(_Scalar).name() + std::string("' is failed"));
|
87
|
+
}
|
88
|
+
|
89
|
+
void serializerWrite(std::ostream& ostr) const
|
90
|
+
{
|
91
|
+
serializer::writeToStream<uint32_t>(ostr, (uint32_t)this->rows());
|
92
|
+
serializer::writeToStream<uint32_t>(ostr, (uint32_t)this->cols());
|
93
|
+
if (!ostr.write((const char*)this->data(), sizeof(_Scalar) * this->size()))
|
94
|
+
throw std::ios_base::failure(std::string("writing type '") + typeid(_Scalar).name() + std::string("' is failed"));
|
95
|
+
}
|
44
96
|
};
|
45
97
|
|
46
98
|
template<typename _Base, TermWeight _tw>
|
@@ -85,7 +137,7 @@ namespace tomoto
|
|
85
137
|
|
86
138
|
tvector<Tid> Zs;
|
87
139
|
tvector<Float> wordWeights;
|
88
|
-
|
140
|
+
ShareableMatrix<WeightType, -1, 1> numByTopic;
|
89
141
|
|
90
142
|
DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(DocumentBase, 0, Zs, wordWeights);
|
91
143
|
DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(DocumentBase, 1, 0x00010001, Zs, wordWeights);
|
@@ -163,7 +163,7 @@ namespace tomoto
|
|
163
163
|
{
|
164
164
|
res.emplace_back(pool.enqueue([&, this, ch, chStride](size_t threadId)
|
165
165
|
{
|
166
|
-
|
166
|
+
forShuffled((this->docs.size() - 1 - ch) / chStride + 1, rgs[threadId](), [&, this](size_t id)
|
167
167
|
{
|
168
168
|
static_cast<DerivedClass*>(this)->template sampleDocument<ParallelScheme::copy_merge>(
|
169
169
|
this->docs[id * chStride + ch], 0, id * chStride + ch,
|
@@ -58,7 +58,8 @@ namespace tomoto
|
|
58
58
|
|
59
59
|
Eigen::Matrix<Float, -1, 1> zLikelihood;
|
60
60
|
Eigen::Matrix<WeightType, -1, 1> numByTopic; // Dim: (Topic, 1)
|
61
|
-
Eigen::Matrix<WeightType, -1, -1> numByTopicWord; // Dim: (Topic, Vocabs)
|
61
|
+
//Eigen::Matrix<WeightType, -1, -1> numByTopicWord; // Dim: (Topic, Vocabs)
|
62
|
+
ShareableMatrix<WeightType, -1, -1> numByTopicWord; // Dim: (Topic, Vocabs)
|
62
63
|
DEFINE_SERIALIZER(numByTopic, numByTopicWord);
|
63
64
|
};
|
64
65
|
|
@@ -137,7 +138,8 @@ namespace tomoto
|
|
137
138
|
typename _Interface,
|
138
139
|
typename _Derived,
|
139
140
|
typename _DocType,
|
140
|
-
typename _ModelState
|
141
|
+
typename _ModelState
|
142
|
+
>
|
141
143
|
class HDPModel;
|
142
144
|
|
143
145
|
template<TermWeight _tw, typename _RandGen,
|
@@ -145,7 +147,8 @@ namespace tomoto
|
|
145
147
|
typename _Interface = ILDAModel,
|
146
148
|
typename _Derived = void,
|
147
149
|
typename _DocType = DocumentLDA<_tw>,
|
148
|
-
typename _ModelState = ModelStateLDA<_tw
|
150
|
+
typename _ModelState = ModelStateLDA<_tw>
|
151
|
+
>
|
149
152
|
class LDAModel : public TopicModel<_RandGen, _Flags, _Interface,
|
150
153
|
typename std::conditional<std::is_same<_Derived, void>::value, LDAModel<_tw, _RandGen, _Flags>, _Derived>::type,
|
151
154
|
_DocType, _ModelState>,
|
@@ -306,25 +309,23 @@ namespace tomoto
|
|
306
309
|
e = edd.chunkOffsetByDoc(partitionId + 1, docId);
|
307
310
|
}
|
308
311
|
|
309
|
-
size_t vOffset = (_ps == ParallelScheme::partition && partitionId) ? edd.vChunkOffset[partitionId - 1] : 0;
|
310
|
-
|
311
312
|
for (size_t w = b; w < e; ++w)
|
312
313
|
{
|
313
314
|
if (doc.words[w] >= this->realV) continue;
|
314
|
-
addWordTo<-1>(ld, doc, w, doc.words[w]
|
315
|
+
static_cast<const DerivedClass*>(this)->template addWordTo<-1>(ld, doc, w, doc.words[w], doc.Zs[w]);
|
315
316
|
Float* dist;
|
316
317
|
if (etaByTopicWord.size())
|
317
318
|
{
|
318
319
|
dist = static_cast<const DerivedClass*>(this)->template
|
319
|
-
getZLikelihoods<true>(ld, doc, docId, doc.words[w]
|
320
|
+
getZLikelihoods<true>(ld, doc, docId, doc.words[w]);
|
320
321
|
}
|
321
322
|
else
|
322
323
|
{
|
323
324
|
dist = static_cast<const DerivedClass*>(this)->template
|
324
|
-
getZLikelihoods<false>(ld, doc, docId, doc.words[w]
|
325
|
+
getZLikelihoods<false>(ld, doc, docId, doc.words[w]);
|
325
326
|
}
|
326
327
|
doc.Zs[w] = sample::sampleFromDiscreteAcc(dist, dist + K, rgs);
|
327
|
-
addWordTo<1>(ld, doc, w, doc.words[w]
|
328
|
+
static_cast<const DerivedClass*>(this)->template addWordTo<1>(ld, doc, w, doc.words[w], doc.Zs[w]);
|
328
329
|
}
|
329
330
|
}
|
330
331
|
|
@@ -335,7 +336,7 @@ namespace tomoto
|
|
335
336
|
// single-threaded sampling
|
336
337
|
if (_ps == ParallelScheme::none)
|
337
338
|
{
|
338
|
-
|
339
|
+
forShuffled((size_t)std::distance(docFirst, docLast), rgs[0](), [&](size_t id)
|
339
340
|
{
|
340
341
|
static_cast<const DerivedClass*>(this)->presampleDocument(docFirst[id], id, *localData, *rgs, this->globalStep);
|
341
342
|
static_cast<const DerivedClass*>(this)->template sampleDocument<_ps, _infer>(
|
@@ -344,7 +345,7 @@ namespace tomoto
|
|
344
345
|
|
345
346
|
});
|
346
347
|
}
|
347
|
-
// multi-threaded sampling on partition
|
348
|
+
// multi-threaded sampling on partition and update into global
|
348
349
|
else if (_ps == ParallelScheme::partition)
|
349
350
|
{
|
350
351
|
const size_t chStride = pool.getNumWorkers();
|
@@ -353,7 +354,7 @@ namespace tomoto
|
|
353
354
|
res = pool.enqueueToAll([&, i, chStride](size_t partitionId)
|
354
355
|
{
|
355
356
|
size_t didx = (i + partitionId) % chStride;
|
356
|
-
|
357
|
+
forShuffled(((size_t)std::distance(docFirst, docLast) + (chStride - 1) - didx) / chStride, rgs[partitionId](), [&](size_t id)
|
357
358
|
{
|
358
359
|
if (i == 0)
|
359
360
|
{
|
@@ -380,7 +381,7 @@ namespace tomoto
|
|
380
381
|
{
|
381
382
|
res.emplace_back(pool.enqueue([&, ch, chStride](size_t threadId)
|
382
383
|
{
|
383
|
-
|
384
|
+
forShuffled(((size_t)std::distance(docFirst, docLast) + (chStride - 1) - ch) / chStride, rgs[threadId](), [&](size_t id)
|
384
385
|
{
|
385
386
|
static_cast<const DerivedClass*>(this)->presampleDocument(
|
386
387
|
docFirst[id * chStride + ch], id * chStride + ch,
|
@@ -396,6 +397,16 @@ namespace tomoto
|
|
396
397
|
for (auto& r : res) r.get();
|
397
398
|
res.clear();
|
398
399
|
}
|
400
|
+
else
|
401
|
+
{
|
402
|
+
throw std::runtime_error{ "Unsupported ParallelScheme" };
|
403
|
+
}
|
404
|
+
}
|
405
|
+
|
406
|
+
template<ParallelScheme _ps, bool _infer, typename _DocIter>
|
407
|
+
void performSamplingGlobal(ThreadPool* pool, _ModelState& globalState, _RandGen* rgs,
|
408
|
+
_DocIter docFirst, _DocIter docLast) const
|
409
|
+
{
|
399
410
|
}
|
400
411
|
|
401
412
|
template<typename _DocIter, typename _ExtraDocData>
|
@@ -444,7 +455,8 @@ namespace tomoto
|
|
444
455
|
size_t b = partitionId ? edd.vChunkOffset[partitionId - 1] : 0,
|
445
456
|
e = edd.vChunkOffset[partitionId];
|
446
457
|
|
447
|
-
localData[partitionId].numByTopicWord = globalState.numByTopicWord.block(0, b, globalState.numByTopicWord.rows(), e - b);
|
458
|
+
//localData[partitionId].numByTopicWord.matrix() = globalState.numByTopicWord.block(0, b, globalState.numByTopicWord.rows(), e - b);
|
459
|
+
localData[partitionId].numByTopicWord.init((WeightType*)globalState.numByTopicWord.data(), globalState.numByTopicWord.rows(), globalState.numByTopicWord.cols());
|
448
460
|
localData[partitionId].numByTopic = globalState.numByTopic;
|
449
461
|
if (!localData[partitionId].zLikelihood.size()) localData[partitionId].zLikelihood = globalState.zLikelihood;
|
450
462
|
});
|
@@ -467,16 +479,29 @@ namespace tomoto
|
|
467
479
|
}
|
468
480
|
|
469
481
|
template<ParallelScheme _ps>
|
470
|
-
void trainOne(ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
|
482
|
+
void trainOne(ThreadPool& pool, _ModelState* localData, _RandGen* rgs, bool freeze_topics = false)
|
471
483
|
{
|
472
484
|
std::vector<std::future<void>> res;
|
473
485
|
try
|
474
486
|
{
|
475
|
-
performSampling<_ps, false>(pool, localData, rgs, res,
|
476
|
-
this->docs.begin(), this->docs.end(), eddTrain
|
487
|
+
static_cast<DerivedClass*>(this)->template performSampling<_ps, false>(pool, localData, rgs, res,
|
488
|
+
this->docs.begin(), this->docs.end(), eddTrain
|
489
|
+
);
|
477
490
|
static_cast<DerivedClass*>(this)->updateGlobalInfo(pool, localData);
|
478
491
|
static_cast<DerivedClass*>(this)->template mergeState<_ps>(pool, this->globalState, this->tState, localData, rgs, eddTrain);
|
479
|
-
static_cast<DerivedClass*>(this)->template
|
492
|
+
static_cast<DerivedClass*>(this)->template performSamplingGlobal<_ps, false>(&pool, this->globalState, rgs,
|
493
|
+
this->docs.begin(), this->docs.end()
|
494
|
+
);
|
495
|
+
|
496
|
+
if(freeze_topics) static_cast<DerivedClass*>(this)->template sampleGlobalLevel<GlobalSampler::freeze_topics>(
|
497
|
+
&pool, &this->globalState, rgs, this->docs.begin(), this->docs.end()
|
498
|
+
);
|
499
|
+
else static_cast<DerivedClass*>(this)->template sampleGlobalLevel<GlobalSampler::train>(
|
500
|
+
&pool, &this->globalState, rgs, this->docs.begin(), this->docs.end()
|
501
|
+
);
|
502
|
+
|
503
|
+
static_cast<DerivedClass*>(this)->template distributeMergedState<_ps>(pool, this->globalState, localData);
|
504
|
+
|
480
505
|
if (this->globalStep >= this->burnIn && optimInterval && (this->globalStep + 1) % optimInterval == 0)
|
481
506
|
{
|
482
507
|
static_cast<DerivedClass*>(this)->optimizeParameters(pool, localData, rgs);
|
@@ -503,8 +528,6 @@ namespace tomoto
|
|
503
528
|
template<ParallelScheme _ps, typename _ExtraDocData>
|
504
529
|
void mergeState(ThreadPool& pool, _ModelState& globalState, _ModelState& tState, _ModelState* localData, _RandGen*, const _ExtraDocData& edd) const
|
505
530
|
{
|
506
|
-
std::vector<std::future<void>> res;
|
507
|
-
|
508
531
|
if (_ps == ParallelScheme::copy_merge)
|
509
532
|
{
|
510
533
|
tState = globalState;
|
@@ -517,10 +540,27 @@ namespace tomoto
|
|
517
540
|
// make all count being positive
|
518
541
|
if (_tw != TermWeight::one)
|
519
542
|
{
|
520
|
-
globalState.numByTopicWord = globalState.numByTopicWord.cwiseMax(0);
|
543
|
+
globalState.numByTopicWord.matrix() = globalState.numByTopicWord.cwiseMax(0);
|
544
|
+
}
|
545
|
+
globalState.numByTopic = globalState.numByTopicWord.rowwise().sum();
|
546
|
+
}
|
547
|
+
else if (_ps == ParallelScheme::partition)
|
548
|
+
{
|
549
|
+
// make all count being positive
|
550
|
+
if (_tw != TermWeight::one)
|
551
|
+
{
|
552
|
+
globalState.numByTopicWord.matrix() = globalState.numByTopicWord.cwiseMax(0);
|
521
553
|
}
|
522
554
|
globalState.numByTopic = globalState.numByTopicWord.rowwise().sum();
|
555
|
+
}
|
556
|
+
}
|
523
557
|
|
558
|
+
template<ParallelScheme _ps>
|
559
|
+
void distributeMergedState(ThreadPool& pool, _ModelState& globalState, _ModelState* localData) const
|
560
|
+
{
|
561
|
+
std::vector<std::future<void>> res;
|
562
|
+
if (_ps == ParallelScheme::copy_merge)
|
563
|
+
{
|
524
564
|
for (size_t i = 0; i < pool.getNumWorkers(); ++i)
|
525
565
|
{
|
526
566
|
res.emplace_back(pool.enqueue([&, i](size_t)
|
@@ -531,22 +571,6 @@ namespace tomoto
|
|
531
571
|
}
|
532
572
|
else if (_ps == ParallelScheme::partition)
|
533
573
|
{
|
534
|
-
res = pool.enqueueToAll([&](size_t partitionId)
|
535
|
-
{
|
536
|
-
size_t b = partitionId ? edd.vChunkOffset[partitionId - 1] : 0,
|
537
|
-
e = edd.vChunkOffset[partitionId];
|
538
|
-
globalState.numByTopicWord.block(0, b, globalState.numByTopicWord.rows(), e - b) = localData[partitionId].numByTopicWord;
|
539
|
-
});
|
540
|
-
for (auto& r : res) r.get();
|
541
|
-
res.clear();
|
542
|
-
|
543
|
-
// make all count being positive
|
544
|
-
if (_tw != TermWeight::one)
|
545
|
-
{
|
546
|
-
globalState.numByTopicWord = globalState.numByTopicWord.cwiseMax(0);
|
547
|
-
}
|
548
|
-
globalState.numByTopic = globalState.numByTopicWord.rowwise().sum();
|
549
|
-
|
550
574
|
res = pool.enqueueToAll([&](size_t threadId)
|
551
575
|
{
|
552
576
|
localData[threadId].numByTopic = globalState.numByTopic;
|
@@ -560,16 +584,11 @@ namespace tomoto
|
|
560
584
|
ex) document pathing at hLDA model
|
561
585
|
* if pool is nullptr, workers has been already pooled and cannot branch works more.
|
562
586
|
*/
|
563
|
-
template<typename _DocIter>
|
587
|
+
template<GlobalSampler _gs, typename _DocIter>
|
564
588
|
void sampleGlobalLevel(ThreadPool* pool, _ModelState* localData, _RandGen* rgs, _DocIter first, _DocIter last) const
|
565
589
|
{
|
566
590
|
}
|
567
591
|
|
568
|
-
template<typename _DocIter>
|
569
|
-
void sampleGlobalLevel(ThreadPool* pool, _ModelState* localData, _RandGen* rgs, _DocIter first, _DocIter last)
|
570
|
-
{
|
571
|
-
}
|
572
|
-
|
573
592
|
template<typename _DocIter>
|
574
593
|
double getLLDocs(_DocIter _first, _DocIter _last) const
|
575
594
|
{
|
@@ -592,16 +611,33 @@ namespace tomoto
|
|
592
611
|
double ll = 0;
|
593
612
|
const size_t V = this->realV;
|
594
613
|
// topic-word distribution
|
595
|
-
|
596
|
-
ll += math::lgammaT(V*eta) * K;
|
597
|
-
for (Tid k = 0; k < K; ++k)
|
614
|
+
if (etaByTopicWord.size())
|
598
615
|
{
|
599
|
-
|
600
|
-
for (Vid v = 0; v < V; ++v)
|
616
|
+
for (Tid k = 0; k < K; ++k)
|
601
617
|
{
|
602
|
-
|
603
|
-
ll += math::lgammaT(ld.
|
604
|
-
|
618
|
+
Float etasum = etaByTopicWord.row(k).sum();
|
619
|
+
ll += math::lgammaT(etasum) - math::lgammaT(ld.numByTopic[k] + etasum);
|
620
|
+
for (Vid v = 0; v < V; ++v)
|
621
|
+
{
|
622
|
+
if (!ld.numByTopicWord(k, v)) continue;
|
623
|
+
ll += math::lgammaT(ld.numByTopicWord(k, v) + etaByTopicWord(v, k)) - math::lgammaT(etaByTopicWord(v, k));
|
624
|
+
assert(std::isfinite(ll));
|
625
|
+
}
|
626
|
+
}
|
627
|
+
}
|
628
|
+
else
|
629
|
+
{
|
630
|
+
auto lgammaEta = math::lgammaT(eta);
|
631
|
+
ll += math::lgammaT(V * eta) * K;
|
632
|
+
for (Tid k = 0; k < K; ++k)
|
633
|
+
{
|
634
|
+
ll -= math::lgammaT(ld.numByTopic[k] + V * eta);
|
635
|
+
for (Vid v = 0; v < V; ++v)
|
636
|
+
{
|
637
|
+
if (!ld.numByTopicWord(k, v)) continue;
|
638
|
+
ll += math::lgammaT(ld.numByTopicWord(k, v) + eta) - lgammaEta;
|
639
|
+
assert(std::isfinite(ll));
|
640
|
+
}
|
605
641
|
}
|
606
642
|
}
|
607
643
|
return ll;
|
@@ -637,9 +673,9 @@ namespace tomoto
|
|
637
673
|
void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
|
638
674
|
{
|
639
675
|
sortAndWriteOrder(doc.words, doc.wOrder);
|
640
|
-
doc.numByTopic.init(getTopicDocPtr(docId), K);
|
676
|
+
doc.numByTopic.init(getTopicDocPtr(docId), K, 1);
|
641
677
|
doc.Zs = tvector<Tid>(wordSize);
|
642
|
-
if(_tw != TermWeight::one) doc.wordWeights.resize(wordSize
|
678
|
+
if(_tw != TermWeight::one) doc.wordWeights.resize(wordSize);
|
643
679
|
}
|
644
680
|
|
645
681
|
void prepareWordPriors()
|
@@ -664,7 +700,8 @@ namespace tomoto
|
|
664
700
|
if (initDocs)
|
665
701
|
{
|
666
702
|
this->globalState.numByTopic = Eigen::Matrix<WeightType, -1, 1>::Zero(K);
|
667
|
-
this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(K, V);
|
703
|
+
//this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(K, V);
|
704
|
+
this->globalState.numByTopicWord.init(nullptr, K, V);
|
668
705
|
}
|
669
706
|
if(m_flags & flags::continuous_doc_data) numByTopicDoc = Eigen::Matrix<WeightType, -1, -1>::Zero(K, this->docs.size());
|
670
707
|
}
|
@@ -791,12 +828,18 @@ namespace tomoto
|
|
791
828
|
for (size_t i = 0; i < maxIter; ++i)
|
792
829
|
{
|
793
830
|
std::vector<std::future<void>> res;
|
794
|
-
performSampling<_ps, true>(pool,
|
831
|
+
static_cast<const DerivedClass*>(this)->template performSampling<_ps, true>(pool,
|
795
832
|
(m_flags & flags::shared_state) ? &tmpState : localData.data(), rgs.data(), res,
|
796
|
-
docFirst, docLast, edd
|
833
|
+
docFirst, docLast, edd
|
834
|
+
);
|
797
835
|
static_cast<const DerivedClass*>(this)->template mergeState<_ps>(pool, tmpState, tState, localData.data(), rgs.data(), edd);
|
798
|
-
static_cast<const DerivedClass*>(this)->template
|
799
|
-
|
836
|
+
static_cast<const DerivedClass*>(this)->template performSamplingGlobal<_ps, true>(&pool, tmpState, rgs.data(),
|
837
|
+
docFirst, docLast
|
838
|
+
);
|
839
|
+
static_cast<const DerivedClass*>(this)->template sampleGlobalLevel<GlobalSampler::inference>(
|
840
|
+
&pool, (m_flags & flags::shared_state) ? &tmpState : localData.data(), rgs.data(), docFirst, docLast
|
841
|
+
);
|
842
|
+
static_cast<const DerivedClass*>(this)->template distributeMergedState<_ps>(pool, tmpState, localData.data());
|
800
843
|
}
|
801
844
|
double ll = static_cast<const DerivedClass*>(this)->getLLRest(tmpState) - static_cast<const DerivedClass*>(this)->getLLRest(this->globalState);
|
802
845
|
ll += static_cast<const DerivedClass*>(this)->template getLLDocs<>(docFirst, docLast);
|
@@ -817,7 +860,9 @@ namespace tomoto
|
|
817
860
|
{
|
818
861
|
static_cast<const DerivedClass*>(this)->presampleDocument(*d, -1, tmpState, rgc, i);
|
819
862
|
static_cast<const DerivedClass*>(this)->template sampleDocument<ParallelScheme::none, true>(*d, edd, -1, tmpState, rgc, i);
|
820
|
-
static_cast<const DerivedClass*>(this)->template
|
863
|
+
static_cast<const DerivedClass*>(this)->template performSamplingGlobal<_ps, true>(&pool, tmpState, &rgc,
|
864
|
+
&*d, &*d + 1);
|
865
|
+
static_cast<const DerivedClass*>(this)->template sampleGlobalLevel<GlobalSampler::inference>(
|
821
866
|
&pool, &tmpState, &rgc, &*d, &*d + 1);
|
822
867
|
}
|
823
868
|
double ll = static_cast<const DerivedClass*>(this)->getLLRest(tmpState) - gllRest;
|
@@ -845,7 +890,9 @@ namespace tomoto
|
|
845
890
|
static_cast<const DerivedClass*>(this)->template sampleDocument<ParallelScheme::none, true>(
|
846
891
|
*d, edd, -1, tmpState, rgc, i
|
847
892
|
);
|
848
|
-
static_cast<const DerivedClass*>(this)->template
|
893
|
+
static_cast<const DerivedClass*>(this)->template performSamplingGlobal<_ps, true>(nullptr, tmpState, &rgc,
|
894
|
+
&*d, &*d + 1);
|
895
|
+
static_cast<const DerivedClass*>(this)->template sampleGlobalLevel<GlobalSampler::inference>(
|
849
896
|
nullptr, &tmpState, &rgc, &*d, &*d + 1
|
850
897
|
);
|
851
898
|
}
|
@@ -1036,7 +1083,7 @@ namespace tomoto
|
|
1036
1083
|
template<typename _TopicModel>
|
1037
1084
|
void DocumentLDA<_tw>::update(WeightType* ptr, const _TopicModel& mdl)
|
1038
1085
|
{
|
1039
|
-
numByTopic.init(ptr, mdl.getK());
|
1086
|
+
numByTopic.init(ptr, mdl.getK(), 1);
|
1040
1087
|
for (size_t i = 0; i < Zs.size(); ++i)
|
1041
1088
|
{
|
1042
1089
|
if (this->words[i] >= mdl.getV()) continue;
|