tomoto 0.1.3 → 0.1.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/LICENSE.txt +1 -1
- data/README.md +7 -0
- data/ext/tomoto/ct.cpp +54 -0
- data/ext/tomoto/dmr.cpp +62 -0
- data/ext/tomoto/dt.cpp +82 -0
- data/ext/tomoto/ext.cpp +27 -773
- data/ext/tomoto/gdmr.cpp +34 -0
- data/ext/tomoto/hdp.cpp +42 -0
- data/ext/tomoto/hlda.cpp +66 -0
- data/ext/tomoto/hpa.cpp +27 -0
- data/ext/tomoto/lda.cpp +250 -0
- data/ext/tomoto/llda.cpp +29 -0
- data/ext/tomoto/mglda.cpp +71 -0
- data/ext/tomoto/pa.cpp +27 -0
- data/ext/tomoto/plda.cpp +29 -0
- data/ext/tomoto/slda.cpp +40 -0
- data/ext/tomoto/utils.h +84 -0
- data/lib/tomoto/tomoto.bundle +0 -0
- data/lib/tomoto/tomoto.so +0 -0
- data/lib/tomoto/version.rb +1 -1
- data/vendor/tomotopy/README.kr.rst +12 -3
- data/vendor/tomotopy/README.rst +12 -3
- data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +47 -2
- data/vendor/tomotopy/src/Labeling/FoRelevance.h +21 -151
- data/vendor/tomotopy/src/Labeling/Labeler.h +5 -3
- data/vendor/tomotopy/src/Labeling/Phraser.hpp +518 -0
- data/vendor/tomotopy/src/TopicModel/CTModel.hpp +6 -3
- data/vendor/tomotopy/src/TopicModel/DT.h +1 -1
- data/vendor/tomotopy/src/TopicModel/DTModel.hpp +8 -23
- data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +9 -18
- data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +56 -58
- data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +4 -14
- data/vendor/tomotopy/src/TopicModel/LDA.h +69 -17
- data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +1 -1
- data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +108 -61
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +7 -8
- data/vendor/tomotopy/src/TopicModel/PAModel.hpp +26 -16
- data/vendor/tomotopy/src/TopicModel/PT.h +27 -0
- data/vendor/tomotopy/src/TopicModel/PTModel.cpp +10 -0
- data/vendor/tomotopy/src/TopicModel/PTModel.hpp +273 -0
- data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +16 -11
- data/vendor/tomotopy/src/Utils/MultiNormalDistribution.hpp +3 -2
- data/vendor/tomotopy/src/Utils/Trie.hpp +39 -8
- data/vendor/tomotopy/src/Utils/TruncMultiNormal.hpp +36 -38
- data/vendor/tomotopy/src/Utils/Utils.hpp +50 -45
- data/vendor/tomotopy/src/Utils/math.h +8 -4
- data/vendor/tomotopy/src/Utils/tvector.hpp +4 -0
- metadata +24 -60
@@ -107,18 +107,16 @@ namespace tomoto
|
|
107
107
|
e = edd.chunkOffsetByDoc(partitionId + 1, docId);
|
108
108
|
}
|
109
109
|
|
110
|
-
size_t vOffset = (_ps == ParallelScheme::partition && partitionId) ? edd.vChunkOffset[partitionId - 1] : 0;
|
111
|
-
|
112
110
|
const auto K = this->K;
|
113
111
|
for (size_t w = b; w < e; ++w)
|
114
112
|
{
|
115
113
|
if (doc.words[w] >= this->realV) continue;
|
116
|
-
addWordTo<-1>(ld, doc, w, doc.words[w]
|
117
|
-
auto dist = getVZLikelihoods(ld, doc, doc.words[w]
|
114
|
+
addWordTo<-1>(ld, doc, w, doc.words[w], doc.Zs[w] - (doc.Zs[w] < K ? 0 : K), doc.sents[w], doc.Vs[w], doc.Zs[w] < K ? 0 : 1);
|
115
|
+
auto dist = getVZLikelihoods(ld, doc, doc.words[w], doc.sents[w]);
|
118
116
|
auto vz = sample::sampleFromDiscreteAcc(dist, dist + T * (K + KL), rgs);
|
119
117
|
doc.Vs[w] = vz / (K + KL);
|
120
118
|
doc.Zs[w] = vz % (K + KL);
|
121
|
-
addWordTo<1>(ld, doc, w, doc.words[w]
|
119
|
+
addWordTo<1>(ld, doc, w, doc.words[w], doc.Zs[w] - (doc.Zs[w] < K ? 0 : K), doc.sents[w], doc.Vs[w], doc.Zs[w] < K ? 0 : 1);
|
122
120
|
}
|
123
121
|
}
|
124
122
|
|
@@ -294,7 +292,7 @@ namespace tomoto
|
|
294
292
|
doc.Zs = tvector<Tid>(wordSize);
|
295
293
|
doc.Vs.resize(wordSize);
|
296
294
|
if (_tw != TermWeight::one) doc.wordWeights.resize(wordSize);
|
297
|
-
doc.numByTopic.init(nullptr, this->K + KL);
|
295
|
+
doc.numByTopic.init(nullptr, this->K + KL, 1);
|
298
296
|
doc.numBySentWin = Eigen::Matrix<WeightType, -1, -1>::Zero(S, T);
|
299
297
|
doc.numByWin = Eigen::Matrix<WeightType, -1, 1>::Zero(S + T - 1);
|
300
298
|
doc.numByWinL = Eigen::Matrix<WeightType, -1, 1>::Zero(S + T - 1);
|
@@ -308,7 +306,8 @@ namespace tomoto
|
|
308
306
|
if (initDocs)
|
309
307
|
{
|
310
308
|
this->globalState.numByTopic = Eigen::Matrix<WeightType, -1, 1>::Zero(this->K + KL);
|
311
|
-
this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K + KL, V);
|
309
|
+
//this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K + KL, V);
|
310
|
+
this->globalState.numByTopicWord.init(nullptr, this->K + KL, V);
|
312
311
|
}
|
313
312
|
}
|
314
313
|
|
@@ -533,7 +532,7 @@ namespace tomoto
|
|
533
532
|
template<typename _TopicModel>
|
534
533
|
void DocumentMGLDA<_tw>::update(WeightType * ptr, const _TopicModel & mdl)
|
535
534
|
{
|
536
|
-
this->numByTopic.init(ptr, mdl.getK() + mdl.getKL());
|
535
|
+
this->numByTopic.init(ptr, mdl.getK() + mdl.getKL(), 1);
|
537
536
|
numBySent.resize(*std::max_element(sents.begin(), sents.end()) + 1);
|
538
537
|
for (size_t i = 0; i < this->Zs.size(); ++i)
|
539
538
|
{
|
@@ -144,7 +144,7 @@ namespace tomoto
|
|
144
144
|
size_t b = partitionId ? edd.vChunkOffset[partitionId - 1] : 0,
|
145
145
|
e = edd.vChunkOffset[partitionId];
|
146
146
|
|
147
|
-
localData[partitionId].numByTopicWord = globalState.numByTopicWord.block(0, b, globalState.numByTopicWord.rows(), e - b);
|
147
|
+
localData[partitionId].numByTopicWord.matrix() = globalState.numByTopicWord.block(0, b, globalState.numByTopicWord.rows(), e - b);
|
148
148
|
localData[partitionId].numByTopic = globalState.numByTopic;
|
149
149
|
localData[partitionId].numByTopic1_2 = globalState.numByTopic1_2;
|
150
150
|
localData[partitionId].numByTopic2 = globalState.numByTopic2;
|
@@ -157,8 +157,6 @@ namespace tomoto
|
|
157
157
|
template<ParallelScheme _ps, typename _ExtraDocData>
|
158
158
|
void mergeState(ThreadPool& pool, _ModelState& globalState, _ModelState& tState, _ModelState* localData, _RandGen*, const _ExtraDocData& edd) const
|
159
159
|
{
|
160
|
-
std::vector<std::future<void>> res;
|
161
|
-
|
162
160
|
if (_ps == ParallelScheme::copy_merge)
|
163
161
|
{
|
164
162
|
tState = globalState;
|
@@ -177,19 +175,12 @@ namespace tomoto
|
|
177
175
|
globalState.numByTopic = globalState.numByTopic.cwiseMax(0);
|
178
176
|
globalState.numByTopic1_2 = globalState.numByTopic1_2.cwiseMax(0);
|
179
177
|
globalState.numByTopic2 = globalState.numByTopic2.cwiseMax(0);
|
180
|
-
globalState.numByTopicWord = globalState.numByTopicWord.cwiseMax(0);
|
181
|
-
}
|
182
|
-
|
183
|
-
for (size_t i = 0; i < pool.getNumWorkers(); ++i)
|
184
|
-
{
|
185
|
-
res.emplace_back(pool.enqueue([&, this, i](size_t threadId)
|
186
|
-
{
|
187
|
-
localData[i] = globalState;
|
188
|
-
}));
|
178
|
+
globalState.numByTopicWord.matrix() = globalState.numByTopicWord.cwiseMax(0);
|
189
179
|
}
|
190
180
|
}
|
191
181
|
else if (_ps == ParallelScheme::partition)
|
192
182
|
{
|
183
|
+
std::vector<std::future<void>> res;
|
193
184
|
res = pool.enqueueToAll([&](size_t partitionId)
|
194
185
|
{
|
195
186
|
size_t b = partitionId ? edd.vChunkOffset[partitionId - 1] : 0,
|
@@ -197,7 +188,6 @@ namespace tomoto
|
|
197
188
|
globalState.numByTopicWord.block(0, b, globalState.numByTopicWord.rows(), e - b) = localData[partitionId].numByTopicWord;
|
198
189
|
});
|
199
190
|
for (auto& r : res) r.get();
|
200
|
-
res.clear();
|
201
191
|
|
202
192
|
tState.numByTopic1_2 = globalState.numByTopic1_2;
|
203
193
|
globalState.numByTopic1_2 = localData[0].numByTopic1_2;
|
@@ -209,11 +199,31 @@ namespace tomoto
|
|
209
199
|
// make all count being positive
|
210
200
|
if (_tw != TermWeight::one)
|
211
201
|
{
|
212
|
-
globalState.numByTopicWord = globalState.numByTopicWord.cwiseMax(0);
|
202
|
+
globalState.numByTopicWord.matrix() = globalState.numByTopicWord.cwiseMax(0);
|
213
203
|
}
|
214
204
|
globalState.numByTopic = globalState.numByTopic1_2.rowwise().sum();
|
215
205
|
globalState.numByTopic2 = globalState.numByTopicWord.rowwise().sum();
|
216
206
|
|
207
|
+
}
|
208
|
+
}
|
209
|
+
|
210
|
+
|
211
|
+
template<ParallelScheme _ps>
|
212
|
+
void distributeMergedState(ThreadPool& pool, _ModelState& globalState, _ModelState* localData) const
|
213
|
+
{
|
214
|
+
std::vector<std::future<void>> res;
|
215
|
+
if (_ps == ParallelScheme::copy_merge)
|
216
|
+
{
|
217
|
+
for (size_t i = 0; i < pool.getNumWorkers(); ++i)
|
218
|
+
{
|
219
|
+
res.emplace_back(pool.enqueue([&, i](size_t)
|
220
|
+
{
|
221
|
+
localData[i] = globalState;
|
222
|
+
}));
|
223
|
+
}
|
224
|
+
}
|
225
|
+
else if (_ps == ParallelScheme::partition)
|
226
|
+
{
|
217
227
|
res = pool.enqueueToAll([&](size_t threadId)
|
218
228
|
{
|
219
229
|
localData[threadId].numByTopic = globalState.numByTopic;
|
@@ -221,7 +231,6 @@ namespace tomoto
|
|
221
231
|
localData[threadId].numByTopic2 = globalState.numByTopic2;
|
222
232
|
});
|
223
233
|
}
|
224
|
-
|
225
234
|
for (auto& r : res) r.get();
|
226
235
|
}
|
227
236
|
|
@@ -304,7 +313,8 @@ namespace tomoto
|
|
304
313
|
this->globalState.numByTopic = Eigen::Matrix<WeightType, -1, 1>::Zero(this->K);
|
305
314
|
this->globalState.numByTopic2 = Eigen::Matrix<WeightType, -1, 1>::Zero(K2);
|
306
315
|
this->globalState.numByTopic1_2 = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K, K2);
|
307
|
-
this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(K2, V);
|
316
|
+
//this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(K2, V);
|
317
|
+
this->globalState.numByTopicWord.init(nullptr, K2, V);
|
308
318
|
}
|
309
319
|
}
|
310
320
|
|
@@ -0,0 +1,27 @@
|
|
1
|
+
#pragma once
|
2
|
+
#include "LDA.h"
|
3
|
+
|
4
|
+
namespace tomoto
|
5
|
+
{
|
6
|
+
template<TermWeight _tw>
|
7
|
+
struct DocumentPTM : public DocumentLDA<_tw>
|
8
|
+
{
|
9
|
+
using BaseDocument = DocumentLDA<_tw>;
|
10
|
+
using DocumentLDA<_tw>::DocumentLDA;
|
11
|
+
using WeightType = typename DocumentLDA<_tw>::WeightType;
|
12
|
+
|
13
|
+
uint64_t pseudoDoc = 0;
|
14
|
+
|
15
|
+
DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 0, pseudoDoc);
|
16
|
+
DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 1, 0x00010001, pseudoDoc);
|
17
|
+
};
|
18
|
+
|
19
|
+
class IPTModel : public ILDAModel
|
20
|
+
{
|
21
|
+
public:
|
22
|
+
using DefaultDocType = DocumentPTM<TermWeight::one>;
|
23
|
+
static IPTModel* create(TermWeight _weight, size_t _K = 1, size_t _P = 100,
|
24
|
+
Float alpha = 0.1, Float eta = 0.01, Float lambda = 0.01, size_t seed = std::random_device{}(),
|
25
|
+
bool scalarRng = false);
|
26
|
+
};
|
27
|
+
}
|
@@ -0,0 +1,10 @@
|
|
1
|
+
#include "PTModel.hpp"
|
2
|
+
|
3
|
+
namespace tomoto
|
4
|
+
{
|
5
|
+
|
6
|
+
IPTModel* IPTModel::create(TermWeight _weight, size_t _K, size_t _P, Float _alpha, Float _eta, Float _lambda, size_t seed, bool scalarRng)
{
	// TMT_SWITCH_TW presumably selects the PTModel<...> instantiation matching
	// `_weight` and `scalarRng` and forwards the remaining arguments to its
	// constructor (K, P, alpha, eta, lambda, seed) — see the macro definition.
	TMT_SWITCH_TW(_weight, scalarRng, PTModel,
		_K, _P,
		_alpha, _eta, _lambda,
		seed);
}
|
10
|
+
}
|
@@ -0,0 +1,273 @@
|
|
1
|
+
#pragma once
|
2
|
+
#include "LDAModel.hpp"
|
3
|
+
#include "PT.h"
|
4
|
+
|
5
|
+
/*
|
6
|
+
Implementation of the Pseudo-document-based Topic Model (PTM) using Gibbs sampling, by bab2min
|
7
|
+
|
8
|
+
Zuo, Y., Wu, J., Zhang, H., Lin, H., Wang, F., Xu, K., & Xiong, H. (2016, August). Topic modeling of short texts: A pseudo-document view. In Proceedings of the 22nd ACM SIGKDD international conference on knowledge discovery and data mining (pp. 2105-2114).
|
9
|
+
*/
|
10
|
+
|
11
|
+
namespace tomoto
|
12
|
+
{
|
13
|
+
template<TermWeight _tw>
|
14
|
+
struct ModelStatePTM : public ModelStateLDA<_tw>
|
15
|
+
{
|
16
|
+
using WeightType = typename ModelStateLDA<_tw>::WeightType;
|
17
|
+
|
18
|
+
Eigen::Array<Float, -1, 1> pLikelihood;
|
19
|
+
Eigen::ArrayXi numDocsByPDoc;
|
20
|
+
Eigen::Matrix<WeightType, -1, -1> numByTopicPDoc;
|
21
|
+
|
22
|
+
//DEFINE_SERIALIZER_AFTER_BASE(ModelStateLDA<_tw>);
|
23
|
+
};
|
24
|
+
|
25
|
+
template<TermWeight _tw, typename _RandGen,
|
26
|
+
typename _Interface = IPTModel,
|
27
|
+
typename _Derived = void,
|
28
|
+
typename _DocType = DocumentPTM<_tw>,
|
29
|
+
typename _ModelState = ModelStatePTM<_tw>>
|
30
|
+
class PTModel : public LDAModel<_tw, _RandGen, flags::continuous_doc_data | flags::partitioned_multisampling, _Interface,
|
31
|
+
typename std::conditional<std::is_same<_Derived, void>::value, PTModel<_tw, _RandGen>, _Derived>::type,
|
32
|
+
_DocType, _ModelState>
|
33
|
+
{
|
34
|
+
protected:
|
35
|
+
using DerivedClass = typename std::conditional<std::is_same<_Derived, void>::value, PTModel<_tw, _RandGen>, _Derived>::type;
|
36
|
+
using BaseClass = LDAModel<_tw, _RandGen, flags::continuous_doc_data | flags::partitioned_multisampling, _Interface, DerivedClass, _DocType, _ModelState>;
|
37
|
+
friend BaseClass;
|
38
|
+
friend typename BaseClass::BaseClass;
|
39
|
+
using WeightType = typename BaseClass::WeightType;
|
40
|
+
|
41
|
+
static constexpr char TMID[] = "PTM";
|
42
|
+
|
43
|
+
uint64_t numPDocs;
|
44
|
+
Float lambda;
|
45
|
+
uint32_t pseudoDocSamplingInterval = 10;
|
46
|
+
|
47
|
+
void optimizeParameters(ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
|
48
|
+
{
|
49
|
+
const auto K = this->K;
|
50
|
+
for (size_t i = 0; i < 10; ++i)
|
51
|
+
{
|
52
|
+
Float denom = this->calcDigammaSum(&pool, [&](size_t i) { return this->globalState.numByTopicPDoc.col(i).sum(); }, numPDocs, this->alphas.sum());
|
53
|
+
for (size_t k = 0; k < K; ++k)
|
54
|
+
{
|
55
|
+
Float nom = this->calcDigammaSum(&pool, [&](size_t i) { return this->globalState.numByTopicPDoc(k, i);}, numPDocs, this->alphas(k));
|
56
|
+
this->alphas(k) = std::max(nom / denom * this->alphas(k), 1e-5f);
|
57
|
+
}
|
58
|
+
}
|
59
|
+
}
|
60
|
+
|
61
|
+
void samplePseudoDoc(ThreadPool* pool, _ModelState& ld, _RandGen& rgs, _DocType& doc) const
|
62
|
+
{
|
63
|
+
if (doc.getSumWordWeight() == 0) return;
|
64
|
+
Eigen::Array<WeightType, -1, 1> docTopicDist = Eigen::Array<WeightType, -1, 1>::Zero(this->K);
|
65
|
+
for (size_t i = 0; i < doc.words.size(); ++i)
|
66
|
+
{
|
67
|
+
if (doc.words[i] >= this->realV) continue;
|
68
|
+
this->template addWordTo<-1>(ld, doc, i, doc.words[i], doc.Zs[i]);
|
69
|
+
typename std::conditional<_tw != TermWeight::one, float, int32_t>::type weight
|
70
|
+
= _tw != TermWeight::one ? doc.wordWeights[i] : 1;
|
71
|
+
docTopicDist[doc.Zs[i]] += weight;
|
72
|
+
}
|
73
|
+
--ld.numDocsByPDoc[doc.pseudoDoc];
|
74
|
+
|
75
|
+
if (pool)
|
76
|
+
{
|
77
|
+
std::vector<std::future<void>> futures;
|
78
|
+
for (size_t w = 0; w < pool->getNumWorkers(); ++w)
|
79
|
+
{
|
80
|
+
futures.emplace_back(pool->enqueue([&](size_t, size_t w)
|
81
|
+
{
|
82
|
+
for (size_t p = w; p < numPDocs; p += pool->getNumWorkers())
|
83
|
+
{
|
84
|
+
Float ax = math::lgammaSubt(ld.numByTopicPDoc.col(p).array().template cast<Float>() + this->alphas.array(), docTopicDist.template cast<Float>()).sum();
|
85
|
+
Float bx = math::lgammaSubt(ld.numByTopicPDoc.col(p).sum() + this->alphas.sum(), docTopicDist.sum());
|
86
|
+
ld.pLikelihood[p] = ax - bx;
|
87
|
+
}
|
88
|
+
}, w));
|
89
|
+
}
|
90
|
+
for (auto& f : futures) f.get();
|
91
|
+
}
|
92
|
+
else
|
93
|
+
{
|
94
|
+
for (size_t p = 0; p < numPDocs; ++p)
|
95
|
+
{
|
96
|
+
Float ax = math::lgammaSubt(ld.numByTopicPDoc.col(p).array().template cast<Float>() + this->alphas.array(), docTopicDist.template cast<Float>()).sum();
|
97
|
+
Float bx = math::lgammaSubt(ld.numByTopicPDoc.col(p).sum() + this->alphas.sum(), docTopicDist.sum());
|
98
|
+
ld.pLikelihood[p] = ax - bx;
|
99
|
+
}
|
100
|
+
}
|
101
|
+
ld.pLikelihood = (ld.pLikelihood - ld.pLikelihood.maxCoeff()).exp();
|
102
|
+
ld.pLikelihood *= ld.numDocsByPDoc.template cast<Float>() + lambda;
|
103
|
+
|
104
|
+
sample::prefixSum(ld.pLikelihood.data(), numPDocs);
|
105
|
+
doc.pseudoDoc = sample::sampleFromDiscreteAcc(ld.pLikelihood.data(), ld.pLikelihood.data() + numPDocs, rgs);
|
106
|
+
|
107
|
+
++ld.numDocsByPDoc[doc.pseudoDoc];
|
108
|
+
doc.numByTopic.init(ld.numByTopicPDoc.col(doc.pseudoDoc).data(), this->K, 1);
|
109
|
+
for (size_t i = 0; i < doc.words.size(); ++i)
|
110
|
+
{
|
111
|
+
if (doc.words[i] >= this->realV) continue;
|
112
|
+
this->template addWordTo<1>(ld, doc, i, doc.words[i], doc.Zs[i]);
|
113
|
+
}
|
114
|
+
}
|
115
|
+
|
116
|
+
template<ParallelScheme _ps, bool _infer, typename _DocIter>
|
117
|
+
void performSamplingGlobal(ThreadPool* pool, _ModelState& globalState, _RandGen* rgs,
|
118
|
+
_DocIter docFirst, _DocIter docLast) const
|
119
|
+
{
|
120
|
+
if (this->globalStep % pseudoDocSamplingInterval) return;
|
121
|
+
for (; docFirst != docLast; ++docFirst)
|
122
|
+
{
|
123
|
+
samplePseudoDoc(pool, globalState, rgs[0], *docFirst);
|
124
|
+
}
|
125
|
+
}
|
126
|
+
|
127
|
+
template<typename _DocIter>
|
128
|
+
double getLLDocs(_DocIter _first, _DocIter _last) const
|
129
|
+
{
|
130
|
+
double ll = 0;
|
131
|
+
// doc-topic distribution
|
132
|
+
for (; _first != _last; ++_first)
|
133
|
+
{
|
134
|
+
auto& doc = *_first;
|
135
|
+
}
|
136
|
+
return ll;
|
137
|
+
}
|
138
|
+
|
139
|
+
double getLLRest(const _ModelState& ld) const
|
140
|
+
{
|
141
|
+
double ll = BaseClass::getLLRest(ld);
|
142
|
+
const size_t V = this->realV;
|
143
|
+
ll -= math::lgammaT(ld.numDocsByPDoc.sum() + lambda * numPDocs) - math::lgammaT(lambda * numPDocs);
|
144
|
+
// pseudo_doc-topic distribution
|
145
|
+
for (size_t p = 0; p < numPDocs; ++p)
|
146
|
+
{
|
147
|
+
ll += math::lgammaT(ld.numDocsByPDoc[p] + lambda) - math::lgammaT(lambda);
|
148
|
+
ll -= math::lgammaT(ld.numByTopicPDoc.col(p).sum() + this->alphas.sum()) - math::lgammaT(this->alphas.sum());
|
149
|
+
for (Tid k = 0; k < this->K; ++k)
|
150
|
+
{
|
151
|
+
ll += math::lgammaT(ld.numByTopicPDoc(k, p) + this->alphas[k]) - math::lgammaT(this->alphas[k]);
|
152
|
+
}
|
153
|
+
}
|
154
|
+
return ll;
|
155
|
+
}
|
156
|
+
|
157
|
+
void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
|
158
|
+
{
|
159
|
+
sortAndWriteOrder(doc.words, doc.wOrder);
|
160
|
+
doc.numByTopic.init((WeightType*)this->globalState.numByTopicPDoc.col(0).data(), this->K, 1);
|
161
|
+
doc.Zs = tvector<Tid>(wordSize);
|
162
|
+
if (_tw != TermWeight::one) doc.wordWeights.resize(wordSize);
|
163
|
+
}
|
164
|
+
|
165
|
+
void initGlobalState(bool initDocs)
|
166
|
+
{
|
167
|
+
this->alphas.resize(this->K);
|
168
|
+
this->alphas.array() = this->alpha;
|
169
|
+
this->globalState.pLikelihood = Eigen::Matrix<Float, -1, 1>::Zero(numPDocs);
|
170
|
+
this->globalState.numDocsByPDoc = Eigen::ArrayXi::Zero(numPDocs);
|
171
|
+
this->globalState.numByTopicPDoc = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K, numPDocs);
|
172
|
+
BaseClass::initGlobalState(initDocs);
|
173
|
+
}
|
174
|
+
|
175
|
+
struct Generator
|
176
|
+
{
|
177
|
+
std::uniform_int_distribution<uint64_t> psi;
|
178
|
+
std::uniform_int_distribution<Tid> theta;
|
179
|
+
};
|
180
|
+
|
181
|
+
Generator makeGeneratorForInit(const _DocType*) const
|
182
|
+
{
|
183
|
+
return Generator{
|
184
|
+
std::uniform_int_distribution<uint64_t>{0, numPDocs - 1},
|
185
|
+
std::uniform_int_distribution<Tid>{0, (Tid)(this->K - 1)}
|
186
|
+
};
|
187
|
+
}
|
188
|
+
|
189
|
+
template<bool _Infer>
|
190
|
+
void updateStateWithDoc(Generator& g, _ModelState& ld, _RandGen& rgs, _DocType& doc, size_t i) const
|
191
|
+
{
|
192
|
+
if (i == 0)
|
193
|
+
{
|
194
|
+
doc.pseudoDoc = g.psi(rgs);
|
195
|
+
++ld.numDocsByPDoc[doc.pseudoDoc];
|
196
|
+
doc.numByTopic.init(ld.numByTopicPDoc.col(doc.pseudoDoc).data(), this->K, 1);
|
197
|
+
}
|
198
|
+
auto& z = doc.Zs[i];
|
199
|
+
auto w = doc.words[i];
|
200
|
+
if (this->etaByTopicWord.size())
|
201
|
+
{
|
202
|
+
auto col = this->etaByTopicWord.col(w);
|
203
|
+
z = sample::sampleFromDiscrete(col.data(), col.data() + col.size(), rgs);
|
204
|
+
}
|
205
|
+
else
|
206
|
+
{
|
207
|
+
z = g.theta(rgs);
|
208
|
+
}
|
209
|
+
this->template addWordTo<1>(ld, doc, i, w, z);
|
210
|
+
}
|
211
|
+
|
212
|
+
template<ParallelScheme _ps, bool _infer, typename _DocIter, typename _ExtraDocData>
|
213
|
+
void performSampling(ThreadPool& pool, _ModelState* localData, _RandGen* rgs, std::vector<std::future<void>>& res,
|
214
|
+
_DocIter docFirst, _DocIter docLast, const _ExtraDocData& edd) const
|
215
|
+
{
|
216
|
+
// single-threaded sampling
|
217
|
+
if (_ps == ParallelScheme::none)
|
218
|
+
{
|
219
|
+
forShuffled((size_t)std::distance(docFirst, docLast), rgs[0](), [&](size_t id)
|
220
|
+
{
|
221
|
+
static_cast<const DerivedClass*>(this)->presampleDocument(docFirst[id], id, *localData, *rgs, this->globalStep);
|
222
|
+
static_cast<const DerivedClass*>(this)->template sampleDocument<_ps, _infer>(
|
223
|
+
docFirst[id], edd, id,
|
224
|
+
*localData, *rgs, this->globalStep, 0);
|
225
|
+
|
226
|
+
});
|
227
|
+
}
|
228
|
+
// multi-threaded sampling on partition and update into global
|
229
|
+
else if (_ps == ParallelScheme::partition)
|
230
|
+
{
|
231
|
+
const size_t chStride = pool.getNumWorkers();
|
232
|
+
for (size_t i = 0; i < chStride; ++i)
|
233
|
+
{
|
234
|
+
res = pool.enqueueToAll([&, i, chStride](size_t partitionId)
|
235
|
+
{
|
236
|
+
forShuffled((size_t)std::distance(docFirst, docLast), rgs[partitionId](), [&](size_t id)
|
237
|
+
{
|
238
|
+
if ((docFirst[id].pseudoDoc + partitionId) % chStride != i) return;
|
239
|
+
static_cast<const DerivedClass*>(this)->template sampleDocument<_ps, _infer>(
|
240
|
+
docFirst[id], edd, id,
|
241
|
+
localData[partitionId], rgs[partitionId], this->globalStep, partitionId
|
242
|
+
);
|
243
|
+
});
|
244
|
+
});
|
245
|
+
for (auto& r : res) r.get();
|
246
|
+
res.clear();
|
247
|
+
}
|
248
|
+
}
|
249
|
+
else
|
250
|
+
{
|
251
|
+
throw std::runtime_error{ "Unsupported ParallelScheme" };
|
252
|
+
}
|
253
|
+
}
|
254
|
+
|
255
|
+
public:
|
256
|
+
DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, numPDocs, lambda);
|
257
|
+
DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, numPDocs, lambda);
|
258
|
+
|
259
|
+
PTModel(size_t _K = 1, size_t _P = 100, Float _alpha = 1.0, Float _eta = 0.01, Float _lambda = 0.01,
|
260
|
+
size_t _rg = std::random_device{}())
|
261
|
+
: BaseClass(_K, _alpha, _eta, _rg), numPDocs(_P), lambda(_lambda)
|
262
|
+
{
|
263
|
+
}
|
264
|
+
|
265
|
+
void updateDocs()
|
266
|
+
{
|
267
|
+
for (auto& doc : this->docs)
|
268
|
+
{
|
269
|
+
doc.template update<>(this->getTopicDocPtr(doc.pseudoDoc), *static_cast<DerivedClass*>(this));
|
270
|
+
}
|
271
|
+
}
|
272
|
+
};
|
273
|
+
}
|