tomoto 0.1.3 → 0.1.4
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/LICENSE.txt +1 -1
- data/README.md +7 -0
- data/ext/tomoto/ct.cpp +54 -0
- data/ext/tomoto/dmr.cpp +62 -0
- data/ext/tomoto/dt.cpp +82 -0
- data/ext/tomoto/ext.cpp +27 -773
- data/ext/tomoto/gdmr.cpp +34 -0
- data/ext/tomoto/hdp.cpp +42 -0
- data/ext/tomoto/hlda.cpp +66 -0
- data/ext/tomoto/hpa.cpp +27 -0
- data/ext/tomoto/lda.cpp +250 -0
- data/ext/tomoto/llda.cpp +29 -0
- data/ext/tomoto/mglda.cpp +71 -0
- data/ext/tomoto/pa.cpp +27 -0
- data/ext/tomoto/plda.cpp +29 -0
- data/ext/tomoto/slda.cpp +40 -0
- data/ext/tomoto/utils.h +84 -0
- data/lib/tomoto/tomoto.bundle +0 -0
- data/lib/tomoto/tomoto.so +0 -0
- data/lib/tomoto/version.rb +1 -1
- data/vendor/tomotopy/README.kr.rst +12 -3
- data/vendor/tomotopy/README.rst +12 -3
- data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +47 -2
- data/vendor/tomotopy/src/Labeling/FoRelevance.h +21 -151
- data/vendor/tomotopy/src/Labeling/Labeler.h +5 -3
- data/vendor/tomotopy/src/Labeling/Phraser.hpp +518 -0
- data/vendor/tomotopy/src/TopicModel/CTModel.hpp +6 -3
- data/vendor/tomotopy/src/TopicModel/DT.h +1 -1
- data/vendor/tomotopy/src/TopicModel/DTModel.hpp +8 -23
- data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +9 -18
- data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +56 -58
- data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +4 -14
- data/vendor/tomotopy/src/TopicModel/LDA.h +69 -17
- data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +1 -1
- data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +108 -61
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +7 -8
- data/vendor/tomotopy/src/TopicModel/PAModel.hpp +26 -16
- data/vendor/tomotopy/src/TopicModel/PT.h +27 -0
- data/vendor/tomotopy/src/TopicModel/PTModel.cpp +10 -0
- data/vendor/tomotopy/src/TopicModel/PTModel.hpp +273 -0
- data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +16 -11
- data/vendor/tomotopy/src/Utils/MultiNormalDistribution.hpp +3 -2
- data/vendor/tomotopy/src/Utils/Trie.hpp +39 -8
- data/vendor/tomotopy/src/Utils/TruncMultiNormal.hpp +36 -38
- data/vendor/tomotopy/src/Utils/Utils.hpp +50 -45
- data/vendor/tomotopy/src/Utils/math.h +8 -4
- data/vendor/tomotopy/src/Utils/tvector.hpp +4 -0
- metadata +24 -60
@@ -28,7 +28,8 @@ namespace tomoto
|
|
28
28
|
typename _Interface = IHPAModel,
|
29
29
|
typename _Derived = void,
|
30
30
|
typename _DocType = DocumentHPA<_tw>,
|
31
|
-
typename _ModelState = ModelStateHPA<_tw
|
31
|
+
typename _ModelState = ModelStateHPA<_tw>
|
32
|
+
>
|
32
33
|
class HPAModel : public LDAModel<_tw, _RandGen, 0, _Interface,
|
33
34
|
typename std::conditional<std::is_same<_Derived, void>::value, HPAModel<_tw, _RandGen, _Exclusive>, _Derived>::type,
|
34
35
|
_DocType, _ModelState>
|
@@ -250,8 +251,6 @@ namespace tomoto
|
|
250
251
|
template<ParallelScheme _ps, typename _ExtraDocData>
|
251
252
|
void mergeState(ThreadPool& pool, _ModelState& globalState, _ModelState& tState, _ModelState* localData, _RandGen*, const _ExtraDocData& edd) const
|
252
253
|
{
|
253
|
-
std::vector<std::future<void>> res;
|
254
|
-
|
255
254
|
tState = globalState;
|
256
255
|
globalState = localData[0];
|
257
256
|
for (size_t i = 1; i < pool.getNumWorkers(); ++i)
|
@@ -276,15 +275,6 @@ namespace tomoto
|
|
276
275
|
globalState.numByTopicWord[1] = globalState.numByTopicWord[1].cwiseMax(0);
|
277
276
|
globalState.numByTopicWord[2] = globalState.numByTopicWord[2].cwiseMax(0);
|
278
277
|
}
|
279
|
-
|
280
|
-
for (size_t i = 0; i < pool.getNumWorkers(); ++i)
|
281
|
-
{
|
282
|
-
res.emplace_back(pool.enqueue([&, this, i](size_t threadId)
|
283
|
-
{
|
284
|
-
localData[i] = globalState;
|
285
|
-
}));
|
286
|
-
}
|
287
|
-
for (auto& r : res) r.get();
|
288
278
|
}
|
289
279
|
|
290
280
|
std::vector<uint64_t> _getTopicsCount() const
|
@@ -379,7 +369,7 @@ namespace tomoto
|
|
379
369
|
|
380
370
|
void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
|
381
371
|
{
|
382
|
-
doc.numByTopic.init(nullptr, this->K + 1);
|
372
|
+
doc.numByTopic.init(nullptr, this->K + 1, 1);
|
383
373
|
doc.numByTopic1_2 = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K, K2 + 1);
|
384
374
|
doc.Zs = tvector<Tid>(wordSize);
|
385
375
|
doc.Z2s = tvector<Tid>(wordSize);
|
@@ -575,7 +565,7 @@ namespace tomoto
|
|
575
565
|
template<typename _TopicModel>
|
576
566
|
void DocumentHPA<_tw>::update(WeightType * ptr, const _TopicModel & mdl)
|
577
567
|
{
|
578
|
-
this->numByTopic.init(ptr, mdl.getK() + 1);
|
568
|
+
this->numByTopic.init(ptr, mdl.getK() + 1, 1);
|
579
569
|
this->numByTopic1_2 = Eigen::Matrix<WeightType, -1, -1>::Zero(mdl.getK(), mdl.getK2() + 1);
|
580
570
|
for (size_t i = 0; i < this->Zs.size(); ++i)
|
581
571
|
{
|
@@ -5,32 +5,67 @@ namespace tomoto
|
|
5
5
|
{
|
6
6
|
enum class TermWeight { one, idf, pmi, size };
|
7
7
|
|
8
|
-
template<typename _Scalar>
|
9
|
-
struct
|
8
|
+
template<typename _Scalar, Eigen::Index _rows, Eigen::Index _cols>
|
9
|
+
struct ShareableMatrix : Eigen::Map<Eigen::Matrix<_Scalar, _rows, _cols>>
|
10
10
|
{
|
11
|
-
Eigen::Matrix<_Scalar,
|
12
|
-
|
13
|
-
|
11
|
+
using BaseType = Eigen::Map<Eigen::Matrix<_Scalar, _rows, _cols>>;
|
12
|
+
Eigen::Matrix<_Scalar, _rows, _cols> ownData;
|
13
|
+
|
14
|
+
ShareableMatrix(_Scalar* ptr = nullptr, Eigen::Index rows = 0, Eigen::Index cols = 0)
|
15
|
+
: BaseType(nullptr, _rows != -1 ? _rows : 0, _cols != -1 ? _cols : 0)
|
14
16
|
{
|
15
|
-
init(ptr,
|
17
|
+
init(ptr, rows, cols);
|
16
18
|
}
|
17
19
|
|
18
|
-
|
20
|
+
ShareableMatrix(const ShareableMatrix& o)
|
21
|
+
: BaseType(nullptr, _rows != -1 ? _rows : 0, _cols != -1 ? _cols : 0), ownData{ o.ownData }
|
19
22
|
{
|
20
|
-
if (
|
23
|
+
if (o.ownData.data())
|
24
|
+
{
|
25
|
+
new (this) BaseType(ownData.data(), ownData.rows(), ownData.cols());
|
26
|
+
}
|
27
|
+
else
|
21
28
|
{
|
22
|
-
|
29
|
+
new (this) BaseType((_Scalar*)o.data(), o.rows(), o.cols());
|
30
|
+
}
|
31
|
+
}
|
32
|
+
|
33
|
+
ShareableMatrix(ShareableMatrix&& o) = default;
|
34
|
+
|
35
|
+
ShareableMatrix& operator=(const ShareableMatrix& o)
|
36
|
+
{
|
37
|
+
if (o.ownData.data())
|
38
|
+
{
|
39
|
+
ownData = o.ownData;
|
40
|
+
new (this) BaseType(ownData.data(), ownData.rows(), ownData.cols());
|
41
|
+
}
|
42
|
+
else
|
43
|
+
{
|
44
|
+
new (this) BaseType((_Scalar*)o.data(), o.rows(), o.cols());
|
45
|
+
}
|
46
|
+
return *this;
|
47
|
+
}
|
48
|
+
|
49
|
+
ShareableMatrix& operator=(ShareableMatrix&& o) = default;
|
50
|
+
|
51
|
+
void init(_Scalar* ptr, Eigen::Index rows, Eigen::Index cols)
|
52
|
+
{
|
53
|
+
if (!ptr && rows && cols)
|
54
|
+
{
|
55
|
+
ownData = Eigen::Matrix<_Scalar, _rows, _cols>::Zero(_rows != -1 ? _rows : rows, _cols != -1 ? _cols : cols);
|
23
56
|
ptr = ownData.data();
|
24
57
|
}
|
25
|
-
|
26
|
-
|
27
|
-
|
58
|
+
else
|
59
|
+
{
|
60
|
+
ownData = Eigen::Matrix<_Scalar, _rows, _cols>{};
|
61
|
+
}
|
62
|
+
new (this) BaseType(ptr, _rows != -1 ? _rows : rows, _cols != -1 ? _cols : cols);
|
28
63
|
}
|
29
64
|
|
30
|
-
void conservativeResize(size_t
|
65
|
+
void conservativeResize(size_t newRows, size_t newCols)
|
31
66
|
{
|
32
|
-
ownData.conservativeResize(
|
33
|
-
|
67
|
+
ownData.conservativeResize(_rows != -1 ? _rows : newRows, _cols != -1 ? _cols : newCols);
|
68
|
+
new (this) BaseType(ownData.data(), ownData.rows(), ownData.cols());
|
34
69
|
}
|
35
70
|
|
36
71
|
void becomeOwner()
|
@@ -38,9 +73,26 @@ namespace tomoto
|
|
38
73
|
if (ownData.data() != this->m_data)
|
39
74
|
{
|
40
75
|
ownData = *this;
|
41
|
-
|
76
|
+
new (this) BaseType(ownData.data(), ownData.rows(), ownData.cols());
|
42
77
|
}
|
43
78
|
}
|
79
|
+
|
80
|
+
void serializerRead(std::istream& istr)
|
81
|
+
{
|
82
|
+
uint32_t rows = serializer::readFromStream<uint32_t>(istr);
|
83
|
+
uint32_t cols = serializer::readFromStream<uint32_t>(istr);
|
84
|
+
init(nullptr, rows, cols);
|
85
|
+
if (!istr.read((char*)this->data(), sizeof(_Scalar) * this->size()))
|
86
|
+
throw std::ios_base::failure(std::string("reading type '") + typeid(_Scalar).name() + std::string("' is failed"));
|
87
|
+
}
|
88
|
+
|
89
|
+
void serializerWrite(std::ostream& ostr) const
|
90
|
+
{
|
91
|
+
serializer::writeToStream<uint32_t>(ostr, (uint32_t)this->rows());
|
92
|
+
serializer::writeToStream<uint32_t>(ostr, (uint32_t)this->cols());
|
93
|
+
if (!ostr.write((const char*)this->data(), sizeof(_Scalar) * this->size()))
|
94
|
+
throw std::ios_base::failure(std::string("writing type '") + typeid(_Scalar).name() + std::string("' is failed"));
|
95
|
+
}
|
44
96
|
};
|
45
97
|
|
46
98
|
template<typename _Base, TermWeight _tw>
|
@@ -85,7 +137,7 @@ namespace tomoto
|
|
85
137
|
|
86
138
|
tvector<Tid> Zs;
|
87
139
|
tvector<Float> wordWeights;
|
88
|
-
|
140
|
+
ShareableMatrix<WeightType, -1, 1> numByTopic;
|
89
141
|
|
90
142
|
DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(DocumentBase, 0, Zs, wordWeights);
|
91
143
|
DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(DocumentBase, 1, 0x00010001, Zs, wordWeights);
|
@@ -163,7 +163,7 @@ namespace tomoto
|
|
163
163
|
{
|
164
164
|
res.emplace_back(pool.enqueue([&, this, ch, chStride](size_t threadId)
|
165
165
|
{
|
166
|
-
|
166
|
+
forShuffled((this->docs.size() - 1 - ch) / chStride + 1, rgs[threadId](), [&, this](size_t id)
|
167
167
|
{
|
168
168
|
static_cast<DerivedClass*>(this)->template sampleDocument<ParallelScheme::copy_merge>(
|
169
169
|
this->docs[id * chStride + ch], 0, id * chStride + ch,
|
@@ -58,7 +58,8 @@ namespace tomoto
|
|
58
58
|
|
59
59
|
Eigen::Matrix<Float, -1, 1> zLikelihood;
|
60
60
|
Eigen::Matrix<WeightType, -1, 1> numByTopic; // Dim: (Topic, 1)
|
61
|
-
Eigen::Matrix<WeightType, -1, -1> numByTopicWord; // Dim: (Topic, Vocabs)
|
61
|
+
//Eigen::Matrix<WeightType, -1, -1> numByTopicWord; // Dim: (Topic, Vocabs)
|
62
|
+
ShareableMatrix<WeightType, -1, -1> numByTopicWord; // Dim: (Topic, Vocabs)
|
62
63
|
DEFINE_SERIALIZER(numByTopic, numByTopicWord);
|
63
64
|
};
|
64
65
|
|
@@ -137,7 +138,8 @@ namespace tomoto
|
|
137
138
|
typename _Interface,
|
138
139
|
typename _Derived,
|
139
140
|
typename _DocType,
|
140
|
-
typename _ModelState
|
141
|
+
typename _ModelState
|
142
|
+
>
|
141
143
|
class HDPModel;
|
142
144
|
|
143
145
|
template<TermWeight _tw, typename _RandGen,
|
@@ -145,7 +147,8 @@ namespace tomoto
|
|
145
147
|
typename _Interface = ILDAModel,
|
146
148
|
typename _Derived = void,
|
147
149
|
typename _DocType = DocumentLDA<_tw>,
|
148
|
-
typename _ModelState = ModelStateLDA<_tw
|
150
|
+
typename _ModelState = ModelStateLDA<_tw>
|
151
|
+
>
|
149
152
|
class LDAModel : public TopicModel<_RandGen, _Flags, _Interface,
|
150
153
|
typename std::conditional<std::is_same<_Derived, void>::value, LDAModel<_tw, _RandGen, _Flags>, _Derived>::type,
|
151
154
|
_DocType, _ModelState>,
|
@@ -306,25 +309,23 @@ namespace tomoto
|
|
306
309
|
e = edd.chunkOffsetByDoc(partitionId + 1, docId);
|
307
310
|
}
|
308
311
|
|
309
|
-
size_t vOffset = (_ps == ParallelScheme::partition && partitionId) ? edd.vChunkOffset[partitionId - 1] : 0;
|
310
|
-
|
311
312
|
for (size_t w = b; w < e; ++w)
|
312
313
|
{
|
313
314
|
if (doc.words[w] >= this->realV) continue;
|
314
|
-
addWordTo<-1>(ld, doc, w, doc.words[w]
|
315
|
+
static_cast<const DerivedClass*>(this)->template addWordTo<-1>(ld, doc, w, doc.words[w], doc.Zs[w]);
|
315
316
|
Float* dist;
|
316
317
|
if (etaByTopicWord.size())
|
317
318
|
{
|
318
319
|
dist = static_cast<const DerivedClass*>(this)->template
|
319
|
-
getZLikelihoods<true>(ld, doc, docId, doc.words[w]
|
320
|
+
getZLikelihoods<true>(ld, doc, docId, doc.words[w]);
|
320
321
|
}
|
321
322
|
else
|
322
323
|
{
|
323
324
|
dist = static_cast<const DerivedClass*>(this)->template
|
324
|
-
getZLikelihoods<false>(ld, doc, docId, doc.words[w]
|
325
|
+
getZLikelihoods<false>(ld, doc, docId, doc.words[w]);
|
325
326
|
}
|
326
327
|
doc.Zs[w] = sample::sampleFromDiscreteAcc(dist, dist + K, rgs);
|
327
|
-
addWordTo<1>(ld, doc, w, doc.words[w]
|
328
|
+
static_cast<const DerivedClass*>(this)->template addWordTo<1>(ld, doc, w, doc.words[w], doc.Zs[w]);
|
328
329
|
}
|
329
330
|
}
|
330
331
|
|
@@ -335,7 +336,7 @@ namespace tomoto
|
|
335
336
|
// single-threaded sampling
|
336
337
|
if (_ps == ParallelScheme::none)
|
337
338
|
{
|
338
|
-
|
339
|
+
forShuffled((size_t)std::distance(docFirst, docLast), rgs[0](), [&](size_t id)
|
339
340
|
{
|
340
341
|
static_cast<const DerivedClass*>(this)->presampleDocument(docFirst[id], id, *localData, *rgs, this->globalStep);
|
341
342
|
static_cast<const DerivedClass*>(this)->template sampleDocument<_ps, _infer>(
|
@@ -344,7 +345,7 @@ namespace tomoto
|
|
344
345
|
|
345
346
|
});
|
346
347
|
}
|
347
|
-
// multi-threaded sampling on partition
|
348
|
+
// multi-threaded sampling on partition and update into global
|
348
349
|
else if (_ps == ParallelScheme::partition)
|
349
350
|
{
|
350
351
|
const size_t chStride = pool.getNumWorkers();
|
@@ -353,7 +354,7 @@ namespace tomoto
|
|
353
354
|
res = pool.enqueueToAll([&, i, chStride](size_t partitionId)
|
354
355
|
{
|
355
356
|
size_t didx = (i + partitionId) % chStride;
|
356
|
-
|
357
|
+
forShuffled(((size_t)std::distance(docFirst, docLast) + (chStride - 1) - didx) / chStride, rgs[partitionId](), [&](size_t id)
|
357
358
|
{
|
358
359
|
if (i == 0)
|
359
360
|
{
|
@@ -380,7 +381,7 @@ namespace tomoto
|
|
380
381
|
{
|
381
382
|
res.emplace_back(pool.enqueue([&, ch, chStride](size_t threadId)
|
382
383
|
{
|
383
|
-
|
384
|
+
forShuffled(((size_t)std::distance(docFirst, docLast) + (chStride - 1) - ch) / chStride, rgs[threadId](), [&](size_t id)
|
384
385
|
{
|
385
386
|
static_cast<const DerivedClass*>(this)->presampleDocument(
|
386
387
|
docFirst[id * chStride + ch], id * chStride + ch,
|
@@ -396,6 +397,16 @@ namespace tomoto
|
|
396
397
|
for (auto& r : res) r.get();
|
397
398
|
res.clear();
|
398
399
|
}
|
400
|
+
else
|
401
|
+
{
|
402
|
+
throw std::runtime_error{ "Unsupported ParallelScheme" };
|
403
|
+
}
|
404
|
+
}
|
405
|
+
|
406
|
+
template<ParallelScheme _ps, bool _infer, typename _DocIter>
|
407
|
+
void performSamplingGlobal(ThreadPool* pool, _ModelState& globalState, _RandGen* rgs,
|
408
|
+
_DocIter docFirst, _DocIter docLast) const
|
409
|
+
{
|
399
410
|
}
|
400
411
|
|
401
412
|
template<typename _DocIter, typename _ExtraDocData>
|
@@ -444,7 +455,8 @@ namespace tomoto
|
|
444
455
|
size_t b = partitionId ? edd.vChunkOffset[partitionId - 1] : 0,
|
445
456
|
e = edd.vChunkOffset[partitionId];
|
446
457
|
|
447
|
-
localData[partitionId].numByTopicWord = globalState.numByTopicWord.block(0, b, globalState.numByTopicWord.rows(), e - b);
|
458
|
+
//localData[partitionId].numByTopicWord.matrix() = globalState.numByTopicWord.block(0, b, globalState.numByTopicWord.rows(), e - b);
|
459
|
+
localData[partitionId].numByTopicWord.init((WeightType*)globalState.numByTopicWord.data(), globalState.numByTopicWord.rows(), globalState.numByTopicWord.cols());
|
448
460
|
localData[partitionId].numByTopic = globalState.numByTopic;
|
449
461
|
if (!localData[partitionId].zLikelihood.size()) localData[partitionId].zLikelihood = globalState.zLikelihood;
|
450
462
|
});
|
@@ -467,16 +479,29 @@ namespace tomoto
|
|
467
479
|
}
|
468
480
|
|
469
481
|
template<ParallelScheme _ps>
|
470
|
-
void trainOne(ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
|
482
|
+
void trainOne(ThreadPool& pool, _ModelState* localData, _RandGen* rgs, bool freeze_topics = false)
|
471
483
|
{
|
472
484
|
std::vector<std::future<void>> res;
|
473
485
|
try
|
474
486
|
{
|
475
|
-
performSampling<_ps, false>(pool, localData, rgs, res,
|
476
|
-
this->docs.begin(), this->docs.end(), eddTrain
|
487
|
+
static_cast<DerivedClass*>(this)->template performSampling<_ps, false>(pool, localData, rgs, res,
|
488
|
+
this->docs.begin(), this->docs.end(), eddTrain
|
489
|
+
);
|
477
490
|
static_cast<DerivedClass*>(this)->updateGlobalInfo(pool, localData);
|
478
491
|
static_cast<DerivedClass*>(this)->template mergeState<_ps>(pool, this->globalState, this->tState, localData, rgs, eddTrain);
|
479
|
-
static_cast<DerivedClass*>(this)->template
|
492
|
+
static_cast<DerivedClass*>(this)->template performSamplingGlobal<_ps, false>(&pool, this->globalState, rgs,
|
493
|
+
this->docs.begin(), this->docs.end()
|
494
|
+
);
|
495
|
+
|
496
|
+
if(freeze_topics) static_cast<DerivedClass*>(this)->template sampleGlobalLevel<GlobalSampler::freeze_topics>(
|
497
|
+
&pool, &this->globalState, rgs, this->docs.begin(), this->docs.end()
|
498
|
+
);
|
499
|
+
else static_cast<DerivedClass*>(this)->template sampleGlobalLevel<GlobalSampler::train>(
|
500
|
+
&pool, &this->globalState, rgs, this->docs.begin(), this->docs.end()
|
501
|
+
);
|
502
|
+
|
503
|
+
static_cast<DerivedClass*>(this)->template distributeMergedState<_ps>(pool, this->globalState, localData);
|
504
|
+
|
480
505
|
if (this->globalStep >= this->burnIn && optimInterval && (this->globalStep + 1) % optimInterval == 0)
|
481
506
|
{
|
482
507
|
static_cast<DerivedClass*>(this)->optimizeParameters(pool, localData, rgs);
|
@@ -503,8 +528,6 @@ namespace tomoto
|
|
503
528
|
template<ParallelScheme _ps, typename _ExtraDocData>
|
504
529
|
void mergeState(ThreadPool& pool, _ModelState& globalState, _ModelState& tState, _ModelState* localData, _RandGen*, const _ExtraDocData& edd) const
|
505
530
|
{
|
506
|
-
std::vector<std::future<void>> res;
|
507
|
-
|
508
531
|
if (_ps == ParallelScheme::copy_merge)
|
509
532
|
{
|
510
533
|
tState = globalState;
|
@@ -517,10 +540,27 @@ namespace tomoto
|
|
517
540
|
// make all count being positive
|
518
541
|
if (_tw != TermWeight::one)
|
519
542
|
{
|
520
|
-
globalState.numByTopicWord = globalState.numByTopicWord.cwiseMax(0);
|
543
|
+
globalState.numByTopicWord.matrix() = globalState.numByTopicWord.cwiseMax(0);
|
544
|
+
}
|
545
|
+
globalState.numByTopic = globalState.numByTopicWord.rowwise().sum();
|
546
|
+
}
|
547
|
+
else if (_ps == ParallelScheme::partition)
|
548
|
+
{
|
549
|
+
// make all count being positive
|
550
|
+
if (_tw != TermWeight::one)
|
551
|
+
{
|
552
|
+
globalState.numByTopicWord.matrix() = globalState.numByTopicWord.cwiseMax(0);
|
521
553
|
}
|
522
554
|
globalState.numByTopic = globalState.numByTopicWord.rowwise().sum();
|
555
|
+
}
|
556
|
+
}
|
523
557
|
|
558
|
+
template<ParallelScheme _ps>
|
559
|
+
void distributeMergedState(ThreadPool& pool, _ModelState& globalState, _ModelState* localData) const
|
560
|
+
{
|
561
|
+
std::vector<std::future<void>> res;
|
562
|
+
if (_ps == ParallelScheme::copy_merge)
|
563
|
+
{
|
524
564
|
for (size_t i = 0; i < pool.getNumWorkers(); ++i)
|
525
565
|
{
|
526
566
|
res.emplace_back(pool.enqueue([&, i](size_t)
|
@@ -531,22 +571,6 @@ namespace tomoto
|
|
531
571
|
}
|
532
572
|
else if (_ps == ParallelScheme::partition)
|
533
573
|
{
|
534
|
-
res = pool.enqueueToAll([&](size_t partitionId)
|
535
|
-
{
|
536
|
-
size_t b = partitionId ? edd.vChunkOffset[partitionId - 1] : 0,
|
537
|
-
e = edd.vChunkOffset[partitionId];
|
538
|
-
globalState.numByTopicWord.block(0, b, globalState.numByTopicWord.rows(), e - b) = localData[partitionId].numByTopicWord;
|
539
|
-
});
|
540
|
-
for (auto& r : res) r.get();
|
541
|
-
res.clear();
|
542
|
-
|
543
|
-
// make all count being positive
|
544
|
-
if (_tw != TermWeight::one)
|
545
|
-
{
|
546
|
-
globalState.numByTopicWord = globalState.numByTopicWord.cwiseMax(0);
|
547
|
-
}
|
548
|
-
globalState.numByTopic = globalState.numByTopicWord.rowwise().sum();
|
549
|
-
|
550
574
|
res = pool.enqueueToAll([&](size_t threadId)
|
551
575
|
{
|
552
576
|
localData[threadId].numByTopic = globalState.numByTopic;
|
@@ -560,16 +584,11 @@ namespace tomoto
|
|
560
584
|
ex) document pathing at hLDA model
|
561
585
|
* if pool is nullptr, workers has been already pooled and cannot branch works more.
|
562
586
|
*/
|
563
|
-
template<typename _DocIter>
|
587
|
+
template<GlobalSampler _gs, typename _DocIter>
|
564
588
|
void sampleGlobalLevel(ThreadPool* pool, _ModelState* localData, _RandGen* rgs, _DocIter first, _DocIter last) const
|
565
589
|
{
|
566
590
|
}
|
567
591
|
|
568
|
-
template<typename _DocIter>
|
569
|
-
void sampleGlobalLevel(ThreadPool* pool, _ModelState* localData, _RandGen* rgs, _DocIter first, _DocIter last)
|
570
|
-
{
|
571
|
-
}
|
572
|
-
|
573
592
|
template<typename _DocIter>
|
574
593
|
double getLLDocs(_DocIter _first, _DocIter _last) const
|
575
594
|
{
|
@@ -592,16 +611,33 @@ namespace tomoto
|
|
592
611
|
double ll = 0;
|
593
612
|
const size_t V = this->realV;
|
594
613
|
// topic-word distribution
|
595
|
-
|
596
|
-
ll += math::lgammaT(V*eta) * K;
|
597
|
-
for (Tid k = 0; k < K; ++k)
|
614
|
+
if (etaByTopicWord.size())
|
598
615
|
{
|
599
|
-
|
600
|
-
for (Vid v = 0; v < V; ++v)
|
616
|
+
for (Tid k = 0; k < K; ++k)
|
601
617
|
{
|
602
|
-
|
603
|
-
ll += math::lgammaT(ld.
|
604
|
-
|
618
|
+
Float etasum = etaByTopicWord.row(k).sum();
|
619
|
+
ll += math::lgammaT(etasum) - math::lgammaT(ld.numByTopic[k] + etasum);
|
620
|
+
for (Vid v = 0; v < V; ++v)
|
621
|
+
{
|
622
|
+
if (!ld.numByTopicWord(k, v)) continue;
|
623
|
+
ll += math::lgammaT(ld.numByTopicWord(k, v) + etaByTopicWord(v, k)) - math::lgammaT(etaByTopicWord(v, k));
|
624
|
+
assert(std::isfinite(ll));
|
625
|
+
}
|
626
|
+
}
|
627
|
+
}
|
628
|
+
else
|
629
|
+
{
|
630
|
+
auto lgammaEta = math::lgammaT(eta);
|
631
|
+
ll += math::lgammaT(V * eta) * K;
|
632
|
+
for (Tid k = 0; k < K; ++k)
|
633
|
+
{
|
634
|
+
ll -= math::lgammaT(ld.numByTopic[k] + V * eta);
|
635
|
+
for (Vid v = 0; v < V; ++v)
|
636
|
+
{
|
637
|
+
if (!ld.numByTopicWord(k, v)) continue;
|
638
|
+
ll += math::lgammaT(ld.numByTopicWord(k, v) + eta) - lgammaEta;
|
639
|
+
assert(std::isfinite(ll));
|
640
|
+
}
|
605
641
|
}
|
606
642
|
}
|
607
643
|
return ll;
|
@@ -637,9 +673,9 @@ namespace tomoto
|
|
637
673
|
void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
|
638
674
|
{
|
639
675
|
sortAndWriteOrder(doc.words, doc.wOrder);
|
640
|
-
doc.numByTopic.init(getTopicDocPtr(docId), K);
|
676
|
+
doc.numByTopic.init(getTopicDocPtr(docId), K, 1);
|
641
677
|
doc.Zs = tvector<Tid>(wordSize);
|
642
|
-
if(_tw != TermWeight::one) doc.wordWeights.resize(wordSize
|
678
|
+
if(_tw != TermWeight::one) doc.wordWeights.resize(wordSize);
|
643
679
|
}
|
644
680
|
|
645
681
|
void prepareWordPriors()
|
@@ -664,7 +700,8 @@ namespace tomoto
|
|
664
700
|
if (initDocs)
|
665
701
|
{
|
666
702
|
this->globalState.numByTopic = Eigen::Matrix<WeightType, -1, 1>::Zero(K);
|
667
|
-
this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(K, V);
|
703
|
+
//this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(K, V);
|
704
|
+
this->globalState.numByTopicWord.init(nullptr, K, V);
|
668
705
|
}
|
669
706
|
if(m_flags & flags::continuous_doc_data) numByTopicDoc = Eigen::Matrix<WeightType, -1, -1>::Zero(K, this->docs.size());
|
670
707
|
}
|
@@ -791,12 +828,18 @@ namespace tomoto
|
|
791
828
|
for (size_t i = 0; i < maxIter; ++i)
|
792
829
|
{
|
793
830
|
std::vector<std::future<void>> res;
|
794
|
-
performSampling<_ps, true>(pool,
|
831
|
+
static_cast<const DerivedClass*>(this)->template performSampling<_ps, true>(pool,
|
795
832
|
(m_flags & flags::shared_state) ? &tmpState : localData.data(), rgs.data(), res,
|
796
|
-
docFirst, docLast, edd
|
833
|
+
docFirst, docLast, edd
|
834
|
+
);
|
797
835
|
static_cast<const DerivedClass*>(this)->template mergeState<_ps>(pool, tmpState, tState, localData.data(), rgs.data(), edd);
|
798
|
-
static_cast<const DerivedClass*>(this)->template
|
799
|
-
|
836
|
+
static_cast<const DerivedClass*>(this)->template performSamplingGlobal<_ps, true>(&pool, tmpState, rgs.data(),
|
837
|
+
docFirst, docLast
|
838
|
+
);
|
839
|
+
static_cast<const DerivedClass*>(this)->template sampleGlobalLevel<GlobalSampler::inference>(
|
840
|
+
&pool, (m_flags & flags::shared_state) ? &tmpState : localData.data(), rgs.data(), docFirst, docLast
|
841
|
+
);
|
842
|
+
static_cast<const DerivedClass*>(this)->template distributeMergedState<_ps>(pool, tmpState, localData.data());
|
800
843
|
}
|
801
844
|
double ll = static_cast<const DerivedClass*>(this)->getLLRest(tmpState) - static_cast<const DerivedClass*>(this)->getLLRest(this->globalState);
|
802
845
|
ll += static_cast<const DerivedClass*>(this)->template getLLDocs<>(docFirst, docLast);
|
@@ -817,7 +860,9 @@ namespace tomoto
|
|
817
860
|
{
|
818
861
|
static_cast<const DerivedClass*>(this)->presampleDocument(*d, -1, tmpState, rgc, i);
|
819
862
|
static_cast<const DerivedClass*>(this)->template sampleDocument<ParallelScheme::none, true>(*d, edd, -1, tmpState, rgc, i);
|
820
|
-
static_cast<const DerivedClass*>(this)->template
|
863
|
+
static_cast<const DerivedClass*>(this)->template performSamplingGlobal<_ps, true>(&pool, tmpState, &rgc,
|
864
|
+
&*d, &*d + 1);
|
865
|
+
static_cast<const DerivedClass*>(this)->template sampleGlobalLevel<GlobalSampler::inference>(
|
821
866
|
&pool, &tmpState, &rgc, &*d, &*d + 1);
|
822
867
|
}
|
823
868
|
double ll = static_cast<const DerivedClass*>(this)->getLLRest(tmpState) - gllRest;
|
@@ -845,7 +890,9 @@ namespace tomoto
|
|
845
890
|
static_cast<const DerivedClass*>(this)->template sampleDocument<ParallelScheme::none, true>(
|
846
891
|
*d, edd, -1, tmpState, rgc, i
|
847
892
|
);
|
848
|
-
static_cast<const DerivedClass*>(this)->template
|
893
|
+
static_cast<const DerivedClass*>(this)->template performSamplingGlobal<_ps, true>(nullptr, tmpState, &rgc,
|
894
|
+
&*d, &*d + 1);
|
895
|
+
static_cast<const DerivedClass*>(this)->template sampleGlobalLevel<GlobalSampler::inference>(
|
849
896
|
nullptr, &tmpState, &rgc, &*d, &*d + 1
|
850
897
|
);
|
851
898
|
}
|
@@ -1036,7 +1083,7 @@ namespace tomoto
|
|
1036
1083
|
template<typename _TopicModel>
|
1037
1084
|
void DocumentLDA<_tw>::update(WeightType* ptr, const _TopicModel& mdl)
|
1038
1085
|
{
|
1039
|
-
numByTopic.init(ptr, mdl.getK());
|
1086
|
+
numByTopic.init(ptr, mdl.getK(), 1);
|
1040
1087
|
for (size_t i = 0; i < Zs.size(); ++i)
|
1041
1088
|
{
|
1042
1089
|
if (this->words[i] >= mdl.getV()) continue;
|