tomoto 0.1.3 → 0.1.4
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/LICENSE.txt +1 -1
- data/README.md +7 -0
- data/ext/tomoto/ct.cpp +54 -0
- data/ext/tomoto/dmr.cpp +62 -0
- data/ext/tomoto/dt.cpp +82 -0
- data/ext/tomoto/ext.cpp +27 -773
- data/ext/tomoto/gdmr.cpp +34 -0
- data/ext/tomoto/hdp.cpp +42 -0
- data/ext/tomoto/hlda.cpp +66 -0
- data/ext/tomoto/hpa.cpp +27 -0
- data/ext/tomoto/lda.cpp +250 -0
- data/ext/tomoto/llda.cpp +29 -0
- data/ext/tomoto/mglda.cpp +71 -0
- data/ext/tomoto/pa.cpp +27 -0
- data/ext/tomoto/plda.cpp +29 -0
- data/ext/tomoto/slda.cpp +40 -0
- data/ext/tomoto/utils.h +84 -0
- data/lib/tomoto/tomoto.bundle +0 -0
- data/lib/tomoto/tomoto.so +0 -0
- data/lib/tomoto/version.rb +1 -1
- data/vendor/tomotopy/README.kr.rst +12 -3
- data/vendor/tomotopy/README.rst +12 -3
- data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +47 -2
- data/vendor/tomotopy/src/Labeling/FoRelevance.h +21 -151
- data/vendor/tomotopy/src/Labeling/Labeler.h +5 -3
- data/vendor/tomotopy/src/Labeling/Phraser.hpp +518 -0
- data/vendor/tomotopy/src/TopicModel/CTModel.hpp +6 -3
- data/vendor/tomotopy/src/TopicModel/DT.h +1 -1
- data/vendor/tomotopy/src/TopicModel/DTModel.hpp +8 -23
- data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +9 -18
- data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +56 -58
- data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +4 -14
- data/vendor/tomotopy/src/TopicModel/LDA.h +69 -17
- data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +1 -1
- data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +108 -61
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +7 -8
- data/vendor/tomotopy/src/TopicModel/PAModel.hpp +26 -16
- data/vendor/tomotopy/src/TopicModel/PT.h +27 -0
- data/vendor/tomotopy/src/TopicModel/PTModel.cpp +10 -0
- data/vendor/tomotopy/src/TopicModel/PTModel.hpp +273 -0
- data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +16 -11
- data/vendor/tomotopy/src/Utils/MultiNormalDistribution.hpp +3 -2
- data/vendor/tomotopy/src/Utils/Trie.hpp +39 -8
- data/vendor/tomotopy/src/Utils/TruncMultiNormal.hpp +36 -38
- data/vendor/tomotopy/src/Utils/Utils.hpp +50 -45
- data/vendor/tomotopy/src/Utils/math.h +8 -4
- data/vendor/tomotopy/src/Utils/tvector.hpp +4 -0
- metadata +24 -60

**data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp**

```diff
@@ -107,18 +107,16 @@ namespace tomoto
             e = edd.chunkOffsetByDoc(partitionId + 1, docId);
         }
 
-        size_t vOffset = (_ps == ParallelScheme::partition && partitionId) ? edd.vChunkOffset[partitionId - 1] : 0;
-
         const auto K = this->K;
         for (size_t w = b; w < e; ++w)
         {
             if (doc.words[w] >= this->realV) continue;
-            addWordTo<-1>(ld, doc, w, doc.words[w] - vOffset, doc.Zs[w] - (doc.Zs[w] < K ? 0 : K), doc.sents[w], doc.Vs[w], doc.Zs[w] < K ? 0 : 1);
-            auto dist = getVZLikelihoods(ld, doc, doc.words[w] - vOffset, doc.sents[w]);
+            addWordTo<-1>(ld, doc, w, doc.words[w], doc.Zs[w] - (doc.Zs[w] < K ? 0 : K), doc.sents[w], doc.Vs[w], doc.Zs[w] < K ? 0 : 1);
+            auto dist = getVZLikelihoods(ld, doc, doc.words[w], doc.sents[w]);
             auto vz = sample::sampleFromDiscreteAcc(dist, dist + T * (K + KL), rgs);
             doc.Vs[w] = vz / (K + KL);
             doc.Zs[w] = vz % (K + KL);
-            addWordTo<1>(ld, doc, w, doc.words[w] - vOffset, doc.Zs[w] - (doc.Zs[w] < K ? 0 : K), doc.sents[w], doc.Vs[w], doc.Zs[w] < K ? 0 : 1);
+            addWordTo<1>(ld, doc, w, doc.words[w], doc.Zs[w] - (doc.Zs[w] < K ? 0 : K), doc.sents[w], doc.Vs[w], doc.Zs[w] < K ? 0 : 1);
         }
     }
 
```
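For context, the hunk above shows the MGLDA word sampler drawing a joint (sentence-window, topic) assignment from one flattened distribution of length `T * (K + KL)` and splitting the flat index back into its two parts. A minimal standalone sketch of that decomposition, using `std::discrete_distribution` in place of tomotopy's `sample::sampleFromDiscreteAcc` (which samples from a prefix-summed array):

```cpp
#include <cstddef>
#include <random>
#include <vector>

int main()
{
    const std::size_t T = 3, K = 4, KL = 2;      // windows, global topics, local topics
    std::vector<double> dist(T * (K + KL), 1.0); // flattened joint (window, topic) weights
    std::mt19937 rgs{42};

    // equivalent of sampleFromDiscreteAcc over the accumulated weights
    std::discrete_distribution<std::size_t> draw(dist.begin(), dist.end());
    std::size_t vz = draw(rgs);

    std::size_t v = vz / (K + KL); // window index, stored in doc.Vs[w]
    std::size_t z = vz % (K + KL); // topic index, stored in doc.Zs[w];
                                   // global topic if z < K, local topic otherwise
    (void)v; (void)z;
    return 0;
}
```

The remaining hunks for this file follow below.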
```diff
@@ -294,7 +292,7 @@ namespace tomoto
         doc.Zs = tvector<Tid>(wordSize);
         doc.Vs.resize(wordSize);
         if (_tw != TermWeight::one) doc.wordWeights.resize(wordSize);
-        doc.numByTopic.init(nullptr, this->K + KL);
+        doc.numByTopic.init(nullptr, this->K + KL, 1);
         doc.numBySentWin = Eigen::Matrix<WeightType, -1, -1>::Zero(S, T);
         doc.numByWin = Eigen::Matrix<WeightType, -1, 1>::Zero(S + T - 1);
         doc.numByWinL = Eigen::Matrix<WeightType, -1, 1>::Zero(S + T - 1);
@@ -308,7 +306,8 @@ namespace tomoto
         if (initDocs)
         {
             this->globalState.numByTopic = Eigen::Matrix<WeightType, -1, 1>::Zero(this->K + KL);
-            this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K + KL, V);
+            //this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K + KL, V);
+            this->globalState.numByTopicWord.init(nullptr, this->K + KL, V);
         }
     }
 
@@ -533,7 +532,7 @@ namespace tomoto
     template<typename _TopicModel>
     void DocumentMGLDA<_tw>::update(WeightType * ptr, const _TopicModel & mdl)
     {
-        this->numByTopic.init(ptr, mdl.getK() + mdl.getKL());
+        this->numByTopic.init(ptr, mdl.getK() + mdl.getKL(), 1);
         numBySent.resize(*std::max_element(sents.begin(), sents.end()) + 1);
         for (size_t i = 0; i < this->Zs.size(); ++i)
         {
```
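The common thread in these hunks (and in the PAModel.hpp ones below) is that the count matrices stop being plain `Eigen::Matrix` members: the `Zero(...)` initializations are commented out in favor of `init(ptr, rows, cols)`, where a non-null pointer maps external memory (as in `DocumentMGLDA::update` above) and a null pointer apparently means self-owned storage, while writes go through `.matrix()`. The new type itself is defined in the Utils changes, which this page does not show; purely as a hypothetical sketch, a wrapper with that call pattern could look like:

```cpp
#include <Eigen/Dense>
#include <new>

// Hypothetical map-or-own wrapper matching the init()/matrix() call
// pattern in the diff; NOT the actual type tomotopy introduces.
template<typename Scalar>
class MapOrOwnMatrix
{
    using Mat = Eigen::Matrix<Scalar, -1, -1>;
    Mat own;                            // backing storage when ptr == nullptr
    Eigen::Map<Mat> map{nullptr, 0, 0};
public:
    void init(Scalar* ptr, Eigen::Index rows, Eigen::Index cols)
    {
        if (!ptr)
        {
            own.setZero(rows, cols);    // allocate and zero our own buffer
            ptr = own.data();
        }
        // reseat the Eigen::Map onto the chosen buffer (placement-new idiom)
        new (&map) Eigen::Map<Mat>{ptr, rows, cols};
    }
    // write target, e.g. m.matrix() = x.cwiseMax(0)
    Eigen::Map<Mat>& matrix() { return map; }
};
```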
**data/vendor/tomotopy/src/TopicModel/PAModel.hpp**

```diff
@@ -144,7 +144,7 @@ namespace tomoto
         size_t b = partitionId ? edd.vChunkOffset[partitionId - 1] : 0,
             e = edd.vChunkOffset[partitionId];
 
-        localData[partitionId].numByTopicWord = globalState.numByTopicWord.block(0, b, globalState.numByTopicWord.rows(), e - b);
+        localData[partitionId].numByTopicWord.matrix() = globalState.numByTopicWord.block(0, b, globalState.numByTopicWord.rows(), e - b);
         localData[partitionId].numByTopic = globalState.numByTopic;
         localData[partitionId].numByTopic1_2 = globalState.numByTopic1_2;
         localData[partitionId].numByTopic2 = globalState.numByTopic2;
@@ -157,8 +157,6 @@ namespace tomoto
     template<ParallelScheme _ps, typename _ExtraDocData>
     void mergeState(ThreadPool& pool, _ModelState& globalState, _ModelState& tState, _ModelState* localData, _RandGen*, const _ExtraDocData& edd) const
     {
-        std::vector<std::future<void>> res;
-
         if (_ps == ParallelScheme::copy_merge)
         {
             tState = globalState;
@@ -177,19 +175,12 @@ namespace tomoto
                 globalState.numByTopic = globalState.numByTopic.cwiseMax(0);
                 globalState.numByTopic1_2 = globalState.numByTopic1_2.cwiseMax(0);
                 globalState.numByTopic2 = globalState.numByTopic2.cwiseMax(0);
-                globalState.numByTopicWord = globalState.numByTopicWord.cwiseMax(0);
-            }
-
-            for (size_t i = 0; i < pool.getNumWorkers(); ++i)
-            {
-                res.emplace_back(pool.enqueue([&, this, i](size_t threadId)
-                {
-                    localData[i] = globalState;
-                }));
+                globalState.numByTopicWord.matrix() = globalState.numByTopicWord.cwiseMax(0);
             }
         }
         else if (_ps == ParallelScheme::partition)
         {
+            std::vector<std::future<void>> res;
             res = pool.enqueueToAll([&](size_t partitionId)
             {
                 size_t b = partitionId ? edd.vChunkOffset[partitionId - 1] : 0,
@@ -197,7 +188,6 @@ namespace tomoto
                 globalState.numByTopicWord.block(0, b, globalState.numByTopicWord.rows(), e - b) = localData[partitionId].numByTopicWord;
             });
             for (auto& r : res) r.get();
-            res.clear();
 
             tState.numByTopic1_2 = globalState.numByTopic1_2;
             globalState.numByTopic1_2 = localData[0].numByTopic1_2;
@@ -209,11 +199,31 @@ namespace tomoto
             // make all count being positive
             if (_tw != TermWeight::one)
             {
-                globalState.numByTopicWord = globalState.numByTopicWord.cwiseMax(0);
+                globalState.numByTopicWord.matrix() = globalState.numByTopicWord.cwiseMax(0);
             }
             globalState.numByTopic = globalState.numByTopic1_2.rowwise().sum();
             globalState.numByTopic2 = globalState.numByTopicWord.rowwise().sum();
 
+        }
+    }
+
+
+    template<ParallelScheme _ps>
+    void distributeMergedState(ThreadPool& pool, _ModelState& globalState, _ModelState* localData) const
+    {
+        std::vector<std::future<void>> res;
+        if (_ps == ParallelScheme::copy_merge)
+        {
+            for (size_t i = 0; i < pool.getNumWorkers(); ++i)
+            {
+                res.emplace_back(pool.enqueue([&, i](size_t)
+                {
+                    localData[i] = globalState;
+                }));
+            }
+        }
+        else if (_ps == ParallelScheme::partition)
+        {
             res = pool.enqueueToAll([&](size_t threadId)
             {
                 localData[threadId].numByTopic = globalState.numByTopic;
@@ -221,7 +231,6 @@ namespace tomoto
                 localData[threadId].numByTopic2 = globalState.numByTopic2;
             });
         }
-
        for (auto& r : res) r.get();
     }
 
@@ -304,7 +313,8 @@ namespace tomoto
             this->globalState.numByTopic = Eigen::Matrix<WeightType, -1, 1>::Zero(this->K);
             this->globalState.numByTopic2 = Eigen::Matrix<WeightType, -1, 1>::Zero(K2);
             this->globalState.numByTopic1_2 = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K, K2);
-            this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(K2, V);
+            //this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(K2, V);
+            this->globalState.numByTopicWord.init(nullptr, K2, V);
         }
     }
 
```
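Taken together, these PAModel.hpp hunks split the old merge-and-redistribute step in two: `mergeState` now only folds per-worker (or per-partition) counts into the global state, and the redistribution that used to run inline moved into the new `distributeMergedState`. A schematic of the two phases; `State` and the plain loops are stand-ins for tomotopy's `_ModelState` and thread-pool dispatch, and the merge arithmetic is simplified:

```cpp
#include <cstddef>
#include <vector>

struct State { std::vector<int> numByTopic; };  // stand-in for _ModelState

// Phase 1 (mergeState): fold per-worker counts into the global state.
void mergeState(State& global, const std::vector<State>& local)
{
    for (const auto& l : local)
        for (std::size_t k = 0; k < global.numByTopic.size(); ++k)
            global.numByTopic[k] += l.numByTopic[k];
}

// Phase 2 (distributeMergedState): copy the merged state back out to every
// worker, as the new copy_merge branch does via pool.enqueue.
void distributeMergedState(const State& global, std::vector<State>& local)
{
    for (auto& l : local) l = global;
}

int main()
{
    State global{std::vector<int>(4, 0)};
    std::vector<State> local(2, State{std::vector<int>(4, 1)});
    mergeState(global, local);            // global now holds the summed counts
    distributeMergedState(global, local); // every worker sees the same state
    return 0;
}
```

Keeping the phases separate lets the caller decide when redistribution happens, which is presumably why `res` and the worker loop left `mergeState`.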
**data/vendor/tomotopy/src/TopicModel/PT.h** (new file)

```diff
@@ -0,0 +1,27 @@
+#pragma once
+#include "LDA.h"
+
+namespace tomoto
+{
+    template<TermWeight _tw>
+    struct DocumentPTM : public DocumentLDA<_tw>
+    {
+        using BaseDocument = DocumentLDA<_tw>;
+        using DocumentLDA<_tw>::DocumentLDA;
+        using WeightType = typename DocumentLDA<_tw>::WeightType;
+
+        uint64_t pseudoDoc = 0;
+
+        DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 0, pseudoDoc);
+        DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 1, 0x00010001, pseudoDoc);
+    };
+
+    class IPTModel : public ILDAModel
+    {
+    public:
+        using DefaultDocType = DocumentPTM<TermWeight::one>;
+        static IPTModel* create(TermWeight _weight, size_t _K = 1, size_t _P = 100,
+            Float alpha = 0.1, Float eta = 0.01, Float lambda = 0.01, size_t seed = std::random_device{}(),
+            bool scalarRng = false);
+    };
+}
```
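PT.h is the new public header for the pseudo-document topic model (PTM). A hypothetical construction sketch against the factory declared above; the calls for adding documents and training belong to the inherited `ILDAModel`/`ITopicModel` interfaces, which this diff does not show, so they are left as a comment:

```cpp
#include "PT.h"

int main()
{
    // Defaults per the declaration above: _K = 1, _P = 100, alpha = 0.1,
    // eta = 0.01, lambda = 0.01.
    tomoto::IPTModel* model = tomoto::IPTModel::create(
        tomoto::TermWeight::one,
        10,     // _K: number of topics
        1000,   // _P: number of pseudo-documents
        0.1f,   // alpha: pseudo-document/topic prior
        0.01f,  // eta: topic/word prior
        0.01f); // lambda: prior on pseudo-document assignment
    // ... add documents and train via the base interface ...
    delete model;  // assumes the interface has a virtual destructor
    return 0;
}
```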
**data/vendor/tomotopy/src/TopicModel/PTModel.cpp** (new file)

```diff
@@ -0,0 +1,10 @@
+#include "PTModel.hpp"
+
+namespace tomoto
+{
+
+    IPTModel* IPTModel::create(TermWeight _weight, size_t _K, size_t _P, Float _alpha, Float _eta, Float _lambda, size_t seed, bool scalarRng)
+    {
+        TMT_SWITCH_TW(_weight, scalarRng, PTModel, _K, _P, _alpha, _eta, _lambda, seed);
+    }
+}
```
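`TMT_SWITCH_TW` is defined elsewhere in tomotopy; judging from its arguments, it dispatches on the runtime `TermWeight` (and on `scalarRng` for the generator type) to the matching `PTModel` template instantiation. Roughly, and only as an illustration of the term-weight half of that dispatch (the RNG switch is omitted):

```cpp
#include "PTModel.hpp"

namespace tomoto
{
    // Hypothetical hand-written equivalent of the term-weight switch.
    template<typename _RandGen>
    IPTModel* createPT(TermWeight _weight, size_t _K, size_t _P,
        Float _alpha, Float _eta, Float _lambda, size_t seed)
    {
        switch (_weight)
        {
        case TermWeight::one:
            return new PTModel<TermWeight::one, _RandGen>(_K, _P, _alpha, _eta, _lambda, seed);
        case TermWeight::idf:
            return new PTModel<TermWeight::idf, _RandGen>(_K, _P, _alpha, _eta, _lambda, seed);
        case TermWeight::pmi:
            return new PTModel<TermWeight::pmi, _RandGen>(_K, _P, _alpha, _eta, _lambda, seed);
        }
        return nullptr;
    }
}
```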
**data/vendor/tomotopy/src/TopicModel/PTModel.hpp** (new file)

```diff
@@ -0,0 +1,273 @@
+#pragma once
+#include "LDAModel.hpp"
+#include "PT.h"
+
+/*
+Implementation of Pseudo-document topic model using Gibbs sampling by bab2min
+
+Zuo, Y., Wu, J., Zhang, H., Lin, H., Wang, F., Xu, K., & Xiong, H. (2016, August). Topic modeling of short texts: A pseudo-document view. In Proceedings of the 22nd ACM SIGKDD international conference on knowledge discovery and data mining (pp. 2105-2114).
+*/
+
+namespace tomoto
+{
+    template<TermWeight _tw>
+    struct ModelStatePTM : public ModelStateLDA<_tw>
+    {
+        using WeightType = typename ModelStateLDA<_tw>::WeightType;
+
+        Eigen::Array<Float, -1, 1> pLikelihood;
+        Eigen::ArrayXi numDocsByPDoc;
+        Eigen::Matrix<WeightType, -1, -1> numByTopicPDoc;
+
+        //DEFINE_SERIALIZER_AFTER_BASE(ModelStateLDA<_tw>);
+    };
+
+    template<TermWeight _tw, typename _RandGen,
+        typename _Interface = IPTModel,
+        typename _Derived = void,
+        typename _DocType = DocumentPTM<_tw>,
+        typename _ModelState = ModelStatePTM<_tw>>
+    class PTModel : public LDAModel<_tw, _RandGen, flags::continuous_doc_data | flags::partitioned_multisampling, _Interface,
+        typename std::conditional<std::is_same<_Derived, void>::value, PTModel<_tw, _RandGen>, _Derived>::type,
+        _DocType, _ModelState>
+    {
+    protected:
+        using DerivedClass = typename std::conditional<std::is_same<_Derived, void>::value, PTModel<_tw, _RandGen>, _Derived>::type;
+        using BaseClass = LDAModel<_tw, _RandGen, flags::continuous_doc_data | flags::partitioned_multisampling, _Interface, DerivedClass, _DocType, _ModelState>;
+        friend BaseClass;
+        friend typename BaseClass::BaseClass;
+        using WeightType = typename BaseClass::WeightType;
+
+        static constexpr char TMID[] = "PTM";
+
+        uint64_t numPDocs;
+        Float lambda;
+        uint32_t pseudoDocSamplingInterval = 10;
+
+        void optimizeParameters(ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
+        {
+            const auto K = this->K;
+            for (size_t i = 0; i < 10; ++i)
+            {
+                Float denom = this->calcDigammaSum(&pool, [&](size_t i) { return this->globalState.numByTopicPDoc.col(i).sum(); }, numPDocs, this->alphas.sum());
+                for (size_t k = 0; k < K; ++k)
+                {
+                    Float nom = this->calcDigammaSum(&pool, [&](size_t i) { return this->globalState.numByTopicPDoc(k, i);}, numPDocs, this->alphas(k));
+                    this->alphas(k) = std::max(nom / denom * this->alphas(k), 1e-5f);
+                }
+            }
+        }
+
+        void samplePseudoDoc(ThreadPool* pool, _ModelState& ld, _RandGen& rgs, _DocType& doc) const
+        {
+            if (doc.getSumWordWeight() == 0) return;
+            Eigen::Array<WeightType, -1, 1> docTopicDist = Eigen::Array<WeightType, -1, 1>::Zero(this->K);
+            for (size_t i = 0; i < doc.words.size(); ++i)
+            {
+                if (doc.words[i] >= this->realV) continue;
+                this->template addWordTo<-1>(ld, doc, i, doc.words[i], doc.Zs[i]);
+                typename std::conditional<_tw != TermWeight::one, float, int32_t>::type weight
+                    = _tw != TermWeight::one ? doc.wordWeights[i] : 1;
+                docTopicDist[doc.Zs[i]] += weight;
+            }
+            --ld.numDocsByPDoc[doc.pseudoDoc];
+
+            if (pool)
+            {
+                std::vector<std::future<void>> futures;
+                for (size_t w = 0; w < pool->getNumWorkers(); ++w)
+                {
+                    futures.emplace_back(pool->enqueue([&](size_t, size_t w)
+                    {
+                        for (size_t p = w; p < numPDocs; p += pool->getNumWorkers())
+                        {
+                            Float ax = math::lgammaSubt(ld.numByTopicPDoc.col(p).array().template cast<Float>() + this->alphas.array(), docTopicDist.template cast<Float>()).sum();
+                            Float bx = math::lgammaSubt(ld.numByTopicPDoc.col(p).sum() + this->alphas.sum(), docTopicDist.sum());
+                            ld.pLikelihood[p] = ax - bx;
+                        }
+                    }, w));
+                }
+                for (auto& f : futures) f.get();
+            }
+            else
+            {
+                for (size_t p = 0; p < numPDocs; ++p)
+                {
+                    Float ax = math::lgammaSubt(ld.numByTopicPDoc.col(p).array().template cast<Float>() + this->alphas.array(), docTopicDist.template cast<Float>()).sum();
+                    Float bx = math::lgammaSubt(ld.numByTopicPDoc.col(p).sum() + this->alphas.sum(), docTopicDist.sum());
+                    ld.pLikelihood[p] = ax - bx;
+                }
+            }
+            ld.pLikelihood = (ld.pLikelihood - ld.pLikelihood.maxCoeff()).exp();
+            ld.pLikelihood *= ld.numDocsByPDoc.template cast<Float>() + lambda;
+
+            sample::prefixSum(ld.pLikelihood.data(), numPDocs);
+            doc.pseudoDoc = sample::sampleFromDiscreteAcc(ld.pLikelihood.data(), ld.pLikelihood.data() + numPDocs, rgs);
+
+            ++ld.numDocsByPDoc[doc.pseudoDoc];
+            doc.numByTopic.init(ld.numByTopicPDoc.col(doc.pseudoDoc).data(), this->K, 1);
+            for (size_t i = 0; i < doc.words.size(); ++i)
+            {
+                if (doc.words[i] >= this->realV) continue;
+                this->template addWordTo<1>(ld, doc, i, doc.words[i], doc.Zs[i]);
+            }
+        }
+
+        template<ParallelScheme _ps, bool _infer, typename _DocIter>
+        void performSamplingGlobal(ThreadPool* pool, _ModelState& globalState, _RandGen* rgs,
+            _DocIter docFirst, _DocIter docLast) const
+        {
+            if (this->globalStep % pseudoDocSamplingInterval) return;
+            for (; docFirst != docLast; ++docFirst)
+            {
+                samplePseudoDoc(pool, globalState, rgs[0], *docFirst);
+            }
+        }
+
+        template<typename _DocIter>
+        double getLLDocs(_DocIter _first, _DocIter _last) const
+        {
+            double ll = 0;
+            // doc-topic distribution
+            for (; _first != _last; ++_first)
+            {
+                auto& doc = *_first;
+            }
+            return ll;
+        }
+
+        double getLLRest(const _ModelState& ld) const
+        {
+            double ll = BaseClass::getLLRest(ld);
+            const size_t V = this->realV;
+            ll -= math::lgammaT(ld.numDocsByPDoc.sum() + lambda * numPDocs) - math::lgammaT(lambda * numPDocs);
+            // pseudo_doc-topic distribution
+            for (size_t p = 0; p < numPDocs; ++p)
+            {
+                ll += math::lgammaT(ld.numDocsByPDoc[p] + lambda) - math::lgammaT(lambda);
+                ll -= math::lgammaT(ld.numByTopicPDoc.col(p).sum() + this->alphas.sum()) - math::lgammaT(this->alphas.sum());
+                for (Tid k = 0; k < this->K; ++k)
+                {
+                    ll += math::lgammaT(ld.numByTopicPDoc(k, p) + this->alphas[k]) - math::lgammaT(this->alphas[k]);
+                }
+            }
+            return ll;
+        }
+
+        void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
+        {
+            sortAndWriteOrder(doc.words, doc.wOrder);
+            doc.numByTopic.init((WeightType*)this->globalState.numByTopicPDoc.col(0).data(), this->K, 1);
+            doc.Zs = tvector<Tid>(wordSize);
+            if (_tw != TermWeight::one) doc.wordWeights.resize(wordSize);
+        }
+
+        void initGlobalState(bool initDocs)
+        {
+            this->alphas.resize(this->K);
+            this->alphas.array() = this->alpha;
+            this->globalState.pLikelihood = Eigen::Matrix<Float, -1, 1>::Zero(numPDocs);
+            this->globalState.numDocsByPDoc = Eigen::ArrayXi::Zero(numPDocs);
+            this->globalState.numByTopicPDoc = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K, numPDocs);
+            BaseClass::initGlobalState(initDocs);
+        }
+
+        struct Generator
+        {
+            std::uniform_int_distribution<uint64_t> psi;
+            std::uniform_int_distribution<Tid> theta;
+        };
+
+        Generator makeGeneratorForInit(const _DocType*) const
+        {
+            return Generator{
+                std::uniform_int_distribution<uint64_t>{0, numPDocs - 1},
+                std::uniform_int_distribution<Tid>{0, (Tid)(this->K - 1)}
+            };
+        }
+
+        template<bool _Infer>
+        void updateStateWithDoc(Generator& g, _ModelState& ld, _RandGen& rgs, _DocType& doc, size_t i) const
+        {
+            if (i == 0)
+            {
+                doc.pseudoDoc = g.psi(rgs);
+                ++ld.numDocsByPDoc[doc.pseudoDoc];
+                doc.numByTopic.init(ld.numByTopicPDoc.col(doc.pseudoDoc).data(), this->K, 1);
+            }
+            auto& z = doc.Zs[i];
+            auto w = doc.words[i];
+            if (this->etaByTopicWord.size())
+            {
+                auto col = this->etaByTopicWord.col(w);
+                z = sample::sampleFromDiscrete(col.data(), col.data() + col.size(), rgs);
+            }
+            else
+            {
+                z = g.theta(rgs);
+            }
+            this->template addWordTo<1>(ld, doc, i, w, z);
+        }
+
+        template<ParallelScheme _ps, bool _infer, typename _DocIter, typename _ExtraDocData>
+        void performSampling(ThreadPool& pool, _ModelState* localData, _RandGen* rgs, std::vector<std::future<void>>& res,
+            _DocIter docFirst, _DocIter docLast, const _ExtraDocData& edd) const
+        {
+            // single-threaded sampling
+            if (_ps == ParallelScheme::none)
+            {
+                forShuffled((size_t)std::distance(docFirst, docLast), rgs[0](), [&](size_t id)
+                {
+                    static_cast<const DerivedClass*>(this)->presampleDocument(docFirst[id], id, *localData, *rgs, this->globalStep);
+                    static_cast<const DerivedClass*>(this)->template sampleDocument<_ps, _infer>(
+                        docFirst[id], edd, id,
+                        *localData, *rgs, this->globalStep, 0);
+
+                });
+            }
+            // multi-threaded sampling on partition and update into global
+            else if (_ps == ParallelScheme::partition)
+            {
+                const size_t chStride = pool.getNumWorkers();
+                for (size_t i = 0; i < chStride; ++i)
+                {
+                    res = pool.enqueueToAll([&, i, chStride](size_t partitionId)
+                    {
+                        forShuffled((size_t)std::distance(docFirst, docLast), rgs[partitionId](), [&](size_t id)
+                        {
+                            if ((docFirst[id].pseudoDoc + partitionId) % chStride != i) return;
+                            static_cast<const DerivedClass*>(this)->template sampleDocument<_ps, _infer>(
+                                docFirst[id], edd, id,
+                                localData[partitionId], rgs[partitionId], this->globalStep, partitionId
+                            );
+                        });
+                    });
+                    for (auto& r : res) r.get();
+                    res.clear();
+                }
+            }
+            else
+            {
+                throw std::runtime_error{ "Unsupported ParallelScheme" };
+            }
+        }
+
+    public:
+        DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, numPDocs, lambda);
+        DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, numPDocs, lambda);
+
+        PTModel(size_t _K = 1, size_t _P = 100, Float _alpha = 1.0, Float _eta = 0.01, Float _lambda = 0.01,
+            size_t _rg = std::random_device{}())
+            : BaseClass(_K, _alpha, _eta, _rg), numPDocs(_P), lambda(_lambda)
+        {
+        }
+
+        void updateDocs()
+        {
+            for (auto& doc : this->docs)
+            {
+                doc.template update<>(this->getTopicDocPtr(doc.pseudoDoc), *static_cast<DerivedClass*>(this));
+            }
+        }
+    };
+}
```
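Reading `samplePseudoDoc` above, and assuming `math::lgammaSubt(x, d)` computes `lgamma(x + d) - lgamma(x)`, the collapsed Gibbs update that reassigns document d to pseudo-document p (with d's words first removed from the counts) is:

```latex
P(p_d = p \mid \cdot) \propto (N_p + \lambda)
  \cdot \frac{\Gamma\left(n_{p} + \sum_k \alpha_k\right)}
             {\Gamma\left(n_{p} + \sum_k \alpha_k + m_{d}\right)}
  \cdot \prod_{k=1}^{K}
        \frac{\Gamma\left(n_{p,k} + \alpha_k + m_{d,k}\right)}
             {\Gamma\left(n_{p,k} + \alpha_k\right)}
```

Here N_p is the number of documents currently assigned to pseudo-document p (`numDocsByPDoc`), n_{p,k} its topic-k count with the document removed (`numByTopicPDoc`), m_{d,k} the document's own topic counts (`docTopicDist`), and n_p, m_d the corresponding sums over topics. In the code, `ax` is the log of the product term, `bx` the log of the denominator ratio, subtracting `maxCoeff()` before `exp()` is the usual overflow guard, and `prefixSum` plus `sampleFromDiscreteAcc` draw the new assignment.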