tomoto 0.1.3 → 0.1.4
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/LICENSE.txt +1 -1
- data/README.md +7 -0
- data/ext/tomoto/ct.cpp +54 -0
- data/ext/tomoto/dmr.cpp +62 -0
- data/ext/tomoto/dt.cpp +82 -0
- data/ext/tomoto/ext.cpp +27 -773
- data/ext/tomoto/gdmr.cpp +34 -0
- data/ext/tomoto/hdp.cpp +42 -0
- data/ext/tomoto/hlda.cpp +66 -0
- data/ext/tomoto/hpa.cpp +27 -0
- data/ext/tomoto/lda.cpp +250 -0
- data/ext/tomoto/llda.cpp +29 -0
- data/ext/tomoto/mglda.cpp +71 -0
- data/ext/tomoto/pa.cpp +27 -0
- data/ext/tomoto/plda.cpp +29 -0
- data/ext/tomoto/slda.cpp +40 -0
- data/ext/tomoto/utils.h +84 -0
- data/lib/tomoto/tomoto.bundle +0 -0
- data/lib/tomoto/tomoto.so +0 -0
- data/lib/tomoto/version.rb +1 -1
- data/vendor/tomotopy/README.kr.rst +12 -3
- data/vendor/tomotopy/README.rst +12 -3
- data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +47 -2
- data/vendor/tomotopy/src/Labeling/FoRelevance.h +21 -151
- data/vendor/tomotopy/src/Labeling/Labeler.h +5 -3
- data/vendor/tomotopy/src/Labeling/Phraser.hpp +518 -0
- data/vendor/tomotopy/src/TopicModel/CTModel.hpp +6 -3
- data/vendor/tomotopy/src/TopicModel/DT.h +1 -1
- data/vendor/tomotopy/src/TopicModel/DTModel.hpp +8 -23
- data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +9 -18
- data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +56 -58
- data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +4 -14
- data/vendor/tomotopy/src/TopicModel/LDA.h +69 -17
- data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +1 -1
- data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +108 -61
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +7 -8
- data/vendor/tomotopy/src/TopicModel/PAModel.hpp +26 -16
- data/vendor/tomotopy/src/TopicModel/PT.h +27 -0
- data/vendor/tomotopy/src/TopicModel/PTModel.cpp +10 -0
- data/vendor/tomotopy/src/TopicModel/PTModel.hpp +273 -0
- data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +16 -11
- data/vendor/tomotopy/src/Utils/MultiNormalDistribution.hpp +3 -2
- data/vendor/tomotopy/src/Utils/Trie.hpp +39 -8
- data/vendor/tomotopy/src/Utils/TruncMultiNormal.hpp +36 -38
- data/vendor/tomotopy/src/Utils/Utils.hpp +50 -45
- data/vendor/tomotopy/src/Utils/math.h +8 -4
- data/vendor/tomotopy/src/Utils/tvector.hpp +4 -0
- metadata +24 -60
@@ -65,6 +65,7 @@ namespace tomoto
 {
   if (i == 0) pbeta = Eigen::Matrix<Float, -1, 1>::Ones(this->K);
   else pbeta = doc.beta.col(i % numBetaSample).array().exp();
+
   Float betaESum = pbeta.sum() + 1;
   pbeta /= betaESum;
   for (size_t k = 0; k < this->K; ++k)
@@ -78,7 +79,9 @@ namespace tomoto
 
   Float c = betaESum * (1 - pbeta[k]);
   lowerBound[k] = log(c * max_uk / (1 - max_uk));
-
+  lowerBound[k] = std::max(std::min(lowerBound[k], (Float)100), (Float)-100);
+  upperBound[k] = log(c * min_unk / (1 - min_unk + epsilon));
+  upperBound[k] = std::max(std::min(upperBound[k], (Float)100), (Float)-100);
   if (lowerBound[k] > upperBound[k])
   {
     THROW_ERROR_WITH_INFO(exception::TrainingError,
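The two std::max(std::min(...)) lines above clamp the log-scale bounds handed to the truncated sampler into [-100, 100], so a max_uk or min_unk at the edge of its range cannot turn a bound into ±inf or NaN. A minimal sketch of the same guard, assuming a generic Float; the helper name clampLog is illustrative, not part of tomotopy:

    #include <algorithm>
    #include <cmath>

    using Float = float;

    // Keep a log-scale bound finite so that exp(bound) stays representable.
    inline Float clampLog(Float x, Float lo = (Float)-100, Float hi = (Float)100)
    {
        if (std::isnan(x)) return lo;             // fall back to the most conservative bound
        return std::max(lo, std::min(x, hi));     // same max(min(...)) pattern as the hunk above
    }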
@@ -120,8 +123,8 @@ namespace tomoto
   }*/
 }
 
-template<typename _DocIter>
-void sampleGlobalLevel(ThreadPool* pool, _ModelState
+template<GlobalSampler _gs, typename _DocIter>
+void sampleGlobalLevel(ThreadPool* pool, _ModelState*, _RandGen* rgs, _DocIter first, _DocIter last) const
 {
   if (this->globalStep < this->burnIn || !this->optimInterval || (this->globalStep + 1) % this->optimInterval != 0) return;
 
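Several sampleGlobalLevel overloads in this release gain a GlobalSampler non-type template parameter, so the training and inference variants are selected at compile time rather than by a runtime flag. A minimal sketch of that dispatch, assuming only the two enumerators that actually appear in this diff (train and inference); everything else here is illustrative:

    enum class GlobalSampler { train, inference };

    template<GlobalSampler _gs>
    void sampleGlobalLevel()
    {
        // The inference instantiation skips updates to shared model structure.
        if (_gs == GlobalSampler::inference) return;
        // ... training-only global updates go here ...
    }

    // The caller picks the instantiation once, outside the sampling loop:
    //   sampleGlobalLevel<GlobalSampler::train>();
    //   sampleGlobalLevel<GlobalSampler::inference>();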
@@ -11,7 +11,7 @@ namespace tomoto
 using DocumentLDA<_tw>::DocumentLDA;
 
 uint64_t timepoint = 0;
-
+ShareableMatrix<Float, -1, 1> eta;
 sample::AliasMethod<> aliasTable;
 
 DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 0, timepoint);
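The per-document eta switches from an owned Eigen vector to a ShareableMatrix, and the call sites below pass init(ptr, rows, cols), which suggests the matrix either wraps external memory or allocates its own when ptr is null. A rough view-or-own sketch built on Eigen::Map; this is an assumption about the semantics, not tomotopy's actual ShareableMatrix:

    #include <Eigen/Dense>
    #include <vector>

    template<typename T>
    class ViewOrOwnMatrix
    {
        std::vector<T> storage;          // used only when no external buffer is supplied
        T* ptr = nullptr;
        Eigen::Index r = 0, c = 0;
    public:
        void init(T* external, Eigen::Index rows, Eigen::Index cols)
        {
            r = rows; c = cols;
            if (external) ptr = external;                                             // view over caller-owned memory
            else { storage.assign((size_t)(rows * cols), T{}); ptr = storage.data(); } // own the buffer
        }
        Eigen::Map<Eigen::Matrix<T, -1, -1>> matrix()
        {
            return Eigen::Map<Eigen::Matrix<T, -1, -1>>(ptr, r, c);
        }
    };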
@@ -20,6 +20,7 @@ namespace tomoto
 
 Eigen::Matrix<WeightType, -1, -1> numByTopic; // Dim: (Topic, Time)
 Eigen::Matrix<WeightType, -1, -1> numByTopicWord; // Dim: (Topic * Time, Vocabs)
+//ShareableMatrix<WeightType, -1, -1> numByTopicWord; // Dim: (Topic * Time, Vocabs)
 DEFINE_SERIALIZER(numByTopic, numByTopicWord);
 };
 
@@ -139,8 +140,6 @@ namespace tomoto
 template<ParallelScheme _ps, typename _ExtraDocData>
 void mergeState(ThreadPool& pool, _ModelState& globalState, _ModelState& tState, _ModelState* localData, _RandGen*, const _ExtraDocData& edd) const
 {
-  std::vector<std::future<void>> res;
-
   if (_ps == ParallelScheme::copy_merge)
   {
     tState = globalState;
@@ -157,17 +156,10 @@ namespace tomoto
     }
     Eigen::Map<Eigen::Matrix<WeightType, -1, 1>>{ globalState.numByTopic.data(), globalState.numByTopic.size() }
       = globalState.numByTopicWord.rowwise().sum();
-
-    for (size_t i = 0; i < pool.getNumWorkers(); ++i)
-    {
-      res.emplace_back(pool.enqueue([&, i](size_t)
-      {
-        localData[i] = globalState;
-      }));
-    }
   }
   else if (_ps == ParallelScheme::partition)
   {
+    std::vector<std::future<void>> res;
     res = pool.enqueueToAll([&](size_t partitionId)
     {
       size_t b = partitionId ? edd.vChunkOffset[partitionId - 1] : 0,
@@ -175,7 +167,6 @@ namespace tomoto
       globalState.numByTopicWord.block(0, b, globalState.numByTopicWord.rows(), e - b) = localData[partitionId].numByTopicWord;
     });
     for (auto& r : res) r.get();
-    res.clear();
 
     // make all count being positive
     if (_tw != TermWeight::one)
@@ -184,17 +175,11 @@ namespace tomoto
     }
     Eigen::Map<Eigen::Matrix<WeightType, -1, 1>>{ globalState.numByTopic.data(), globalState.numByTopic.size() }
       = globalState.numByTopicWord.rowwise().sum();
-
-    res = pool.enqueueToAll([&](size_t threadId)
-    {
-      localData[threadId].numByTopic = globalState.numByTopic;
-    });
   }
-  for (auto& r : res) r.get();
 }
 
 template<typename _DocIter>
-void
+void _sampleGlobalLevel(ThreadPool* pool, _ModelState*, _RandGen* rgs, _DocIter first, _DocIter last)
 {
   const auto K = this->K;
   const Float eps = shapeA * (std::pow(shapeB + 1 + this->globalStep, -shapeC));
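In the partition branch above each worker owns a contiguous vocabulary chunk [b, e), so the merged counts are assembled by copying column blocks rather than summing full matrices. A self-contained sketch of that assembly step; the names and the plain std::vector of locals are illustrative:

    #include <Eigen/Dense>
    #include <vector>

    // Rebuild the global (topic x vocab) count matrix from per-partition chunks.
    // Partition p owns vocabulary columns [offsets[p - 1], offsets[p]).
    void mergePartitions(Eigen::MatrixXf& global,
                         const std::vector<Eigen::MatrixXf>& local,
                         const std::vector<size_t>& offsets)
    {
        for (size_t p = 0; p < local.size(); ++p)
        {
            size_t b = p ? offsets[p - 1] : 0, e = offsets[p];
            global.block(0, b, global.rows(), e - b) = local[p];  // local[p] is rows x (e - b)
        }
    }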
@@ -313,10 +298,10 @@ namespace tomoto
   alphas = newAlphas;
 }
 
-template<typename _DocIter>
+template<GlobalSampler _gs, typename _DocIter>
 void sampleGlobalLevel(ThreadPool* pool, _ModelState* localData, _RandGen* rgs, _DocIter first, _DocIter last) const
 {
-
+  if (_gs != GlobalSampler::inference) return const_cast<DerivedClass*>(this)->_sampleGlobalLevel(pool, localData, rgs, first, last);
 }
 
 void optimizeParameters(ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
@@ -343,11 +328,11 @@ namespace tomoto
 BaseClass::prepareDoc(doc, docId, wordSize);
 if (docId == (size_t)-1)
 {
-  doc.eta.init(nullptr, this->K);
+  doc.eta.init(nullptr, this->K, 1);
 }
 else
 {
-  doc.eta.init((Float*)etaByDoc.col(docId).data(), this->K);
+  doc.eta.init((Float*)etaByDoc.col(docId).data(), this->K, 1);
 }
 }
 
@@ -427,7 +412,7 @@ namespace tomoto
 numDocsByTime[doc.timepoint]++;
 if (!initDocs)
 {
-  doc.eta.init((Float*)etaByDoc.col(docId++).data(), this->K);
+  doc.eta.init((Float*)etaByDoc.col(docId++).data(), this->K, 1);
 }
 }
 
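Both call sites use the same convention: a regular document views its column of the shared etaByDoc matrix, while docId == (size_t)-1 (a detached document, e.g. one built for inference) gets private storage. A tiny sketch of that convention using the ViewOrOwnMatrix stand-in from above; etaByDoc here is an ordinary Eigen matrix and bindEta is a hypothetical helper:

    #include <Eigen/Dense>

    void bindEta(ViewOrOwnMatrix<float>& eta, Eigen::MatrixXf& etaByDoc, size_t docId, size_t K)
    {
        if (docId == (size_t)-1) eta.init(nullptr, (Eigen::Index)K, 1);      // own a K x 1 buffer
        else eta.init(etaByDoc.col(docId).data(), (Eigen::Index)K, 1);       // view one column of the shared matrix
    }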
@@ -96,7 +96,7 @@ namespace tomoto
 ld.numTableByTopic.tail(newSize - oldSize).setZero();
 ld.numByTopic.conservativeResize(newSize);
 ld.numByTopic.tail(newSize - oldSize).setZero();
-ld.numByTopicWord.conservativeResize(newSize,
+ld.numByTopicWord.conservativeResize(newSize, V);
 ld.numByTopicWord.block(oldSize, 0, newSize - oldSize, V).setZero();
 }
 else
@@ -155,7 +155,7 @@ namespace tomoto
 if (_inc > 0 && tid >= doc.numByTopic.size())
 {
   size_t oldSize = doc.numByTopic.size();
-  doc.numByTopic.conservativeResize(tid + 1);
+  doc.numByTopic.conservativeResize(tid + 1, 1);
   doc.numByTopic.tail(tid + 1 - oldSize).setZero();
 }
 constexpr bool _dec = _inc < 0 && _tw != TermWeight::one;
@@ -282,7 +282,7 @@ namespace tomoto
 auto& doc = this->docs[j];
 if (doc.numByTopic.size() >= K) continue;
 size_t oldSize = doc.numByTopic.size();
-doc.numByTopic.conservativeResize(K);
+doc.numByTopic.conservativeResize(K, 1);
 doc.numByTopic.tail(K - oldSize).setZero();
 }
 }, this->docs.size() * i / pool.getNumWorkers(), this->docs.size() * (i + 1) / pool.getNumWorkers()));
@@ -293,7 +293,6 @@ namespace tomoto
 template<ParallelScheme _ps, typename _ExtraDocData>
 void mergeState(ThreadPool& pool, _ModelState& globalState, _ModelState& tState, _ModelState* localData, _RandGen*, const _ExtraDocData& edd) const
 {
-  std::vector<std::future<void>> res;
   const size_t V = this->realV;
   auto K = this->K;
 
@@ -303,7 +302,7 @@ namespace tomoto
 globalState.numByTopic.conservativeResize(K);
 globalState.numByTopic.tail(K - oldSize).setZero();
 globalState.numTableByTopic.resize(K);
-globalState.numByTopicWord.conservativeResize(K,
+globalState.numByTopicWord.conservativeResize(K, V);
 globalState.numByTopicWord.block(oldSize, 0, K - oldSize, V).setZero();
 }
 
@@ -321,7 +320,7 @@ namespace tomoto
 if (_tw != TermWeight::one)
 {
   globalState.numByTopic = globalState.numByTopic.cwiseMax(0);
-  globalState.numByTopicWord = globalState.numByTopicWord.cwiseMax(0);
+  globalState.numByTopicWord.matrix() = globalState.numByTopicWord.cwiseMax(0);
 }
 
 
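The cwiseMax(0) lines exist because, with any term weighting other than TermWeight::one, counts are fractional floats and a parallel merge can leave tiny negative residues; clamping keeps later log and ratio computations valid. The same idea in isolation, with illustrative names:

    #include <Eigen/Dense>

    // Clamp fractional topic counts at zero after a parallel merge.
    void clampCounts(Eigen::VectorXf& numByTopic, Eigen::MatrixXf& numByTopicWord)
    {
        numByTopic = numByTopic.cwiseMax(0.f);
        numByTopicWord = numByTopicWord.cwiseMax(0.f);
    }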
@@ -334,15 +333,6 @@ namespace tomoto
   }
 }
 globalState.totalTable = globalState.numTableByTopic.sum();
-
-for (size_t i = 0; i < pool.getNumWorkers(); ++i)
-{
-  res.emplace_back(pool.enqueue([&, this, i](size_t threadId)
-  {
-    localData[i] = globalState;
-  }));
-}
-for (auto& r : res) r.get();
 }
 
 /* this LL calculation is based on https://github.com/blei-lab/hdp/blob/master/hdp/state.cpp */
@@ -400,13 +390,14 @@ namespace tomoto
 {
   this->globalState.numByTopic = Eigen::Matrix<WeightType, -1, 1>::Zero(K);
   this->globalState.numTableByTopic = Eigen::Matrix<int32_t, -1, 1>::Zero(K);
-  this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(K, V);
+  //this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(K, V);
+  this->globalState.numByTopicWord.init(nullptr, K, V);
 }
 }
 
 void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
 {
-  doc.numByTopic.init(nullptr, this->K);
+  doc.numByTopic.init(nullptr, this->K, 1);
   doc.numTopicByTable.clear();
   doc.Zs = tvector<Tid>(wordSize);
   if (_tw != TermWeight::one) doc.wordWeights.resize(wordSize);
@@ -577,7 +568,7 @@ namespace tomoto
 template<typename _TopicModel>
 void DocumentHDP<_tw>::update(WeightType * ptr, const _TopicModel & mdl)
 {
-  this->numByTopic.init(ptr, mdl.getK());
+  this->numByTopic.init(ptr, mdl.getK(), 1);
   for (size_t i = 0; i < this->Zs.size(); ++i)
   {
     if (this->words[i] >= mdl.getV()) continue;
@@ -119,19 +119,26 @@ namespace tomoto
 
 DEFINE_SERIALIZER(nodes, levelBlocks);
 
-template<bool
+template<bool _makeNewPath = true>
 void calcNodeLikelihood(Float gamma, size_t levelDepth)
 {
   nodeLikelihoods.resize(nodes.size());
   nodeLikelihoods.array() = -INFINITY;
-  updateNodeLikelihood<
+  updateNodeLikelihood<_makeNewPath>(gamma, levelDepth, &nodes[0]);
+  if (!_makeNewPath)
+  {
+    for (size_t i = 0; i < levelBlocks.size(); ++i)
+    {
+      if (levelBlocks[i] < levelDepth - 1) nodeLikelihoods.segment((i + 1) * blockSize, blockSize).array() = -INFINITY;
+    }
+  }
 }
 
-template<bool
+template<bool _makeNewPath = true>
 void updateNodeLikelihood(Float gamma, size_t levelDepth, NCRPNode* node, Float weight = 0)
 {
   size_t idx = node - nodes.data();
-  const Float pNewNode =
+  const Float pNewNode = _makeNewPath ? log(gamma / (node->numCustomers + gamma)) : -INFINITY;
   nodeLikelihoods[idx] = weight + (((size_t)node->level < levelDepth - 1) ? pNewNode : 0);
   for(auto * child = node->getChild(); child; child = child->getSibling())
   {
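The pNewNode expression is the nested-CRP probability of opening a new child under a node that already has numCustomers customers, gamma / (n + gamma), taken in log space; when _makeNewPath is false (the inference-only path) it is forced to -INFINITY so no new nodes can be created. The same quantity in isolation:

    #include <cmath>

    // Log-probability that a document opens a new path below a node in the nested CRP.
    inline float newPathLogProb(float gamma, float numCustomers, bool allowNewPath)
    {
        return allowNewPath ? std::log(gamma / (numCustomers + gamma)) : -INFINITY;
    }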
@@ -187,7 +194,7 @@ namespace tomoto
 std::vector<std::future<void>> futures;
 futures.reserve(levelBlocks.size());
 
-auto calc = [
+auto calc = [&, eta, realV](size_t threadId, size_t b)
 {
   Float cnt = 0;
   Vid prevWord = -1;
@@ -284,7 +291,7 @@ namespace tomoto
 size_t oldSize = ld.numByTopic.rows();
 size_t newSize = std::max(nodes.size(), ((oldSize + oldSize / 2 + 7) / 8) * 8);
 ld.numByTopic.conservativeResize(newSize);
-ld.numByTopicWord.conservativeResize(newSize,
+ld.numByTopicWord.conservativeResize(newSize, ld.numByTopicWord.cols());
 ld.numByTopic.segment(oldSize, newSize - oldSize).setZero();
 ld.numByTopicWord.block(oldSize, 0, newSize - oldSize, ld.numByTopicWord.cols()).setZero();
 }
@@ -317,13 +324,13 @@ namespace tomoto
 typename _Derived = void,
 typename _DocType = DocumentHLDA<_tw>,
 typename _ModelState = ModelStateHLDA<_tw>>
-class HLDAModel : public LDAModel<_tw, _RandGen, flags::
+class HLDAModel : public LDAModel<_tw, _RandGen, flags::partitioned_multisampling, _Interface,
 typename std::conditional<std::is_same<_Derived, void>::value, HLDAModel<_tw, _RandGen>, _Derived>::type,
 _DocType, _ModelState>
 {
 protected:
 using DerivedClass = typename std::conditional<std::is_same<_Derived, void>::value, HLDAModel<_tw, _RandGen>, _Derived>::type;
-using BaseClass = LDAModel<_tw, _RandGen, flags::
+using BaseClass = LDAModel<_tw, _RandGen, flags::partitioned_multisampling, _Interface, DerivedClass, _DocType, _ModelState>;
 friend BaseClass;
 friend typename BaseClass::BaseClass;
 using WeightType = typename BaseClass::WeightType;
@@ -341,11 +348,11 @@ namespace tomoto
 }
 
 // Words of all documents should be sorted by ascending order.
-template<
+template<GlobalSampler _gs>
 void samplePathes(_DocType& doc, ThreadPool* pool, _ModelState& ld, _RandGen& rgs) const
 {
-  if(
-  ld.nt->template calcNodeLikelihood<
+  if(_gs != GlobalSampler::inference) ld.nt->nodes[doc.path.back()].dropPathOne();
+  ld.nt->template calcNodeLikelihood<_gs == GlobalSampler::train>(gamma, this->K);
 
   std::vector<Float> newTopicWeights(this->K - 1);
   std::vector<WeightType> cntByLevel(this->K);
@@ -355,7 +362,7 @@ namespace tomoto
   if (doc.words[w] >= this->realV) break;
   addWordToOnlyLocal<-1>(ld, doc, w, doc.words[w], doc.Zs[w]);
 
-  if (
+  if (_gs == GlobalSampler::train)
   {
     if (doc.words[w] != prevWord)
     {
@@ -371,7 +378,7 @@ namespace tomoto
   }
 }
 
-if (
+if (_gs == GlobalSampler::train)
 {
   for (size_t l = 1; l < this->K; ++l)
   {
@@ -386,7 +393,7 @@ namespace tomoto
 size_t newPath = sample::sampleFromDiscreteAcc(ld.nt->nodeLikelihoods.data(),
   ld.nt->nodeLikelihoods.data() + ld.nt->nodeLikelihoods.size(), rgs);
 
-if(
+if(_gs == GlobalSampler::train) newPath = ld.nt->template generateLeafNode<_tw>(newPath, this->K, ld);
 doc.path.back() = newPath;
 for (size_t l = this->K - 2; l > 0; --l)
 {
@@ -398,7 +405,7 @@ namespace tomoto
   if (doc.words[w] >= this->realV) break;
   addWordToOnlyLocal<1>(ld, doc, w, doc.words[w], doc.Zs[w]);
 }
-if (
+if (_gs != GlobalSampler::inference) ld.nt->nodes[doc.path.back()].addPathOne();
 }
 
 template<int _inc>
@@ -426,6 +433,7 @@ namespace tomoto
 template<bool _asymEta>
 Float* getZLikelihoods(_ModelState& ld, const _DocType& doc, size_t docId, size_t vid) const
 {
+  if (_asymEta) THROW_ERROR_WITH_INFO(exception::Unimplemented, "Unimplemented features");
   const size_t V = this->realV;
   assert(vid < V);
   auto& zLikelihood = ld.zLikelihood;
@@ -439,50 +447,14 @@ namespace tomoto
   return &zLikelihood[0];
 }
 
-
-
-  for (size_t w = 0; w < doc.words.size(); ++w)
-  {
-    if (doc.words[w] >= this->realV) continue;
-    addWordTo<-1>(ld, doc, w, doc.words[w], doc.Zs[w]);
-    Float* dist;
-    if (this->etaByTopicWord.size())
-    {
-      THROW_ERROR_WITH_INFO(exception::Unimplemented, "Unimplemented features");
-    }
-    else
-    {
-      dist = static_cast<const DerivedClass*>(this)->template
-        getZLikelihoods<false>(ld, doc, docId, doc.words[w]);
-    }
-    doc.Zs[w] = sample::sampleFromDiscreteAcc(dist, dist + this->K, rgs);
-    addWordTo<1>(ld, doc, w, doc.words[w], doc.Zs[w]);
-  }
-}
-
-template<ParallelScheme _ps, bool _infer, typename _ExtraDocData>
-void sampleDocument(_DocType& doc, const _ExtraDocData& edd, size_t docId, _ModelState& ld, _RandGen& rgs, size_t iterationCnt, size_t partitionId = 0) const
-{
-  sampleTopics(doc, docId, ld, rgs);
-}
-
-template<typename _DocIter>
-void sampleGlobalLevel(ThreadPool* pool, _ModelState* localData, _RandGen* rgs, _DocIter first, _DocIter last)
+template<GlobalSampler _gs, typename _DocIter>
+void sampleGlobalLevel(ThreadPool* pool, _ModelState* globalData, _RandGen* rgs, _DocIter first, _DocIter last) const
 {
   for (auto doc = first; doc != last; ++doc)
   {
-    samplePathes
-  }
-  localData->nt->markEmptyBlocks();
-}
-
-template<typename _DocIter>
-void sampleGlobalLevel(ThreadPool* pool, _ModelState* localData, _RandGen* rgs, _DocIter first, _DocIter last) const
-{
-  for (auto doc = first; doc != last; ++doc)
-  {
-    samplePathes<false>(*doc, pool, *localData, rgs[0]);
+    samplePathes<_gs>(*doc, pool, *globalData, rgs[0]);
   }
+  if (_gs != GlobalSampler::inference) globalData->nt->markEmptyBlocks();
 }
 
 template<typename _DocIter>
@@ -539,7 +511,8 @@ namespace tomoto
 if (initDocs)
 {
   this->globalState.numByTopic = Eigen::Matrix<WeightType, -1, 1>::Zero(this->K);
-  this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K, V);
+  //this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K, V);
+  this->globalState.numByTopicWord.init(nullptr, this->K, V);
   this->globalState.nt->nodes.resize(detail::NodeTrees::blockSize);
 }
 }
@@ -547,7 +520,7 @@ namespace tomoto
 void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
 {
   sortAndWriteOrder(doc.words, doc.wOrder);
-  doc.numByTopic.init(nullptr, this->K);
+  doc.numByTopic.init(nullptr, this->K, 1);
   doc.Zs = tvector<Tid>(wordSize);
   doc.path.resize(this->K);
   for (size_t l = 0; l < this->K; ++l) doc.path[l] = l;
@@ -595,6 +568,31 @@ namespace tomoto
 return cnt;
 }
 
+template<ParallelScheme _ps>
+void distributeMergedState(ThreadPool& pool, _ModelState& globalState, _ModelState* localData) const
+{
+  std::vector<std::future<void>> res;
+  if (_ps == ParallelScheme::copy_merge)
+  {
+    for (size_t i = 0; i < pool.getNumWorkers(); ++i)
+    {
+      res.emplace_back(pool.enqueue([&, i](size_t)
+      {
+        localData[i] = globalState;
+      }));
+    }
+  }
+  else if (_ps == ParallelScheme::partition)
+  {
+    res = pool.enqueueToAll([&](size_t threadId)
+    {
+      localData[threadId].numByTopicWord.init((WeightType*)globalState.numByTopicWord.data(), globalState.numByTopicWord.rows(), globalState.numByTopicWord.cols());
+      localData[threadId].numByTopic = globalState.numByTopic;
+    });
+  }
+  for (auto& r : res) r.get();
+}
+
 public:
 DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, gamma);
 DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, gamma);
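distributeMergedState takes over the broadcast work removed from the mergeState bodies earlier in this diff: under ParallelScheme::copy_merge every worker gets a full copy of the merged state, while under ParallelScheme::partition each worker's numByTopicWord is re-initialized as a view onto the single global buffer. A minimal sketch of why the partition case is cheap, using a plain Eigen::Map as the stand-in view; the function and parameter names are illustrative:

    #include <Eigen/Dense>

    // A worker aliases the shared K x V count matrix instead of copying it, then
    // updates only its own vocabulary column range [chunkBegin, chunkEnd), which
    // ParallelScheme::partition guarantees is disjoint across workers.
    void workerUpdatesSharedCounts(Eigen::MatrixXf& globalCounts, size_t chunkBegin, size_t chunkEnd)
    {
        Eigen::Map<Eigen::MatrixXf> view(globalCounts.data(), globalCounts.rows(), globalCounts.cols());
        view.block(0, chunkBegin, view.rows(), chunkEnd - chunkBegin).array() += 1.0f;  // example count update
    }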
@@ -671,7 +669,7 @@ namespace tomoto
 template<typename _TopicModel>
 inline void DocumentHLDA<_tw>::update(WeightType * ptr, const _TopicModel & mdl)
 {
-  this->numByTopic.init(ptr, mdl.getLevelDepth());
+  this->numByTopic.init(ptr, mdl.getLevelDepth(), 1);
   for (size_t i = 0; i < this->Zs.size(); ++i)
   {
     if (this->words[i] >= mdl.getV()) continue;