tomoto 0.1.3 → 0.1.4
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/LICENSE.txt +1 -1
- data/README.md +7 -0
- data/ext/tomoto/ct.cpp +54 -0
- data/ext/tomoto/dmr.cpp +62 -0
- data/ext/tomoto/dt.cpp +82 -0
- data/ext/tomoto/ext.cpp +27 -773
- data/ext/tomoto/gdmr.cpp +34 -0
- data/ext/tomoto/hdp.cpp +42 -0
- data/ext/tomoto/hlda.cpp +66 -0
- data/ext/tomoto/hpa.cpp +27 -0
- data/ext/tomoto/lda.cpp +250 -0
- data/ext/tomoto/llda.cpp +29 -0
- data/ext/tomoto/mglda.cpp +71 -0
- data/ext/tomoto/pa.cpp +27 -0
- data/ext/tomoto/plda.cpp +29 -0
- data/ext/tomoto/slda.cpp +40 -0
- data/ext/tomoto/utils.h +84 -0
- data/lib/tomoto/tomoto.bundle +0 -0
- data/lib/tomoto/tomoto.so +0 -0
- data/lib/tomoto/version.rb +1 -1
- data/vendor/tomotopy/README.kr.rst +12 -3
- data/vendor/tomotopy/README.rst +12 -3
- data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +47 -2
- data/vendor/tomotopy/src/Labeling/FoRelevance.h +21 -151
- data/vendor/tomotopy/src/Labeling/Labeler.h +5 -3
- data/vendor/tomotopy/src/Labeling/Phraser.hpp +518 -0
- data/vendor/tomotopy/src/TopicModel/CTModel.hpp +6 -3
- data/vendor/tomotopy/src/TopicModel/DT.h +1 -1
- data/vendor/tomotopy/src/TopicModel/DTModel.hpp +8 -23
- data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +9 -18
- data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +56 -58
- data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +4 -14
- data/vendor/tomotopy/src/TopicModel/LDA.h +69 -17
- data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +1 -1
- data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +108 -61
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +7 -8
- data/vendor/tomotopy/src/TopicModel/PAModel.hpp +26 -16
- data/vendor/tomotopy/src/TopicModel/PT.h +27 -0
- data/vendor/tomotopy/src/TopicModel/PTModel.cpp +10 -0
- data/vendor/tomotopy/src/TopicModel/PTModel.hpp +273 -0
- data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +16 -11
- data/vendor/tomotopy/src/Utils/MultiNormalDistribution.hpp +3 -2
- data/vendor/tomotopy/src/Utils/Trie.hpp +39 -8
- data/vendor/tomotopy/src/Utils/TruncMultiNormal.hpp +36 -38
- data/vendor/tomotopy/src/Utils/Utils.hpp +50 -45
- data/vendor/tomotopy/src/Utils/math.h +8 -4
- data/vendor/tomotopy/src/Utils/tvector.hpp +4 -0
- metadata +24 -60
@@ -65,6 +65,7 @@ namespace tomoto
 {
 if (i == 0) pbeta = Eigen::Matrix<Float, -1, 1>::Ones(this->K);
 else pbeta = doc.beta.col(i % numBetaSample).array().exp();
+
 Float betaESum = pbeta.sum() + 1;
 pbeta /= betaESum;
 for (size_t k = 0; k < this->K; ++k)

@@ -78,7 +79,9 @@ namespace tomoto
 
 Float c = betaESum * (1 - pbeta[k]);
 lowerBound[k] = log(c * max_uk / (1 - max_uk));
-
+ lowerBound[k] = std::max(std::min(lowerBound[k], (Float)100), (Float)-100);
+ upperBound[k] = log(c * min_unk / (1 - min_unk + epsilon));
+ upperBound[k] = std::max(std::min(upperBound[k], (Float)100), (Float)-100);
 if (lowerBound[k] > upperBound[k])
 {
 THROW_ERROR_WITH_INFO(exception::TrainingError,

@@ -120,8 +123,8 @@ namespace tomoto
 }*/
 }
 
- template<typename _DocIter>
- void sampleGlobalLevel(ThreadPool* pool, _ModelState
+ template<GlobalSampler _gs, typename _DocIter>
+ void sampleGlobalLevel(ThreadPool* pool, _ModelState*, _RandGen* rgs, _DocIter first, _DocIter last) const
 {
 if (this->globalStep < this->burnIn || !this->optimInterval || (this->globalStep + 1) % this->optimInterval != 0) return;
 

@@ -11,7 +11,7 @@ namespace tomoto
 using DocumentLDA<_tw>::DocumentLDA;
 
 uint64_t timepoint = 0;
-
+ ShareableMatrix<Float, -1, 1> eta;
 sample::AliasMethod<> aliasTable;
 
 DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 0, timepoint);

@@ -20,6 +20,7 @@ namespace tomoto
 
 Eigen::Matrix<WeightType, -1, -1> numByTopic; // Dim: (Topic, Time)
 Eigen::Matrix<WeightType, -1, -1> numByTopicWord; // Dim: (Topic * Time, Vocabs)
+ //ShareableMatrix<WeightType, -1, -1> numByTopicWord; // Dim: (Topic * Time, Vocabs)
 DEFINE_SERIALIZER(numByTopic, numByTopicWord);
 };
 

@@ -139,8 +140,6 @@ namespace tomoto
 template<ParallelScheme _ps, typename _ExtraDocData>
 void mergeState(ThreadPool& pool, _ModelState& globalState, _ModelState& tState, _ModelState* localData, _RandGen*, const _ExtraDocData& edd) const
 {
- std::vector<std::future<void>> res;
-
 if (_ps == ParallelScheme::copy_merge)
 {
 tState = globalState;

@@ -157,17 +156,10 @@ namespace tomoto
 }
 Eigen::Map<Eigen::Matrix<WeightType, -1, 1>>{ globalState.numByTopic.data(), globalState.numByTopic.size() }
 = globalState.numByTopicWord.rowwise().sum();
-
- for (size_t i = 0; i < pool.getNumWorkers(); ++i)
- {
- res.emplace_back(pool.enqueue([&, i](size_t)
- {
- localData[i] = globalState;
- }));
- }
 }
 else if (_ps == ParallelScheme::partition)
 {
+ std::vector<std::future<void>> res;
 res = pool.enqueueToAll([&](size_t partitionId)
 {
 size_t b = partitionId ? edd.vChunkOffset[partitionId - 1] : 0,

@@ -175,7 +167,6 @@ namespace tomoto
 globalState.numByTopicWord.block(0, b, globalState.numByTopicWord.rows(), e - b) = localData[partitionId].numByTopicWord;
 });
 for (auto& r : res) r.get();
- res.clear();
 
 // make all count being positive
 if (_tw != TermWeight::one)

@@ -184,17 +175,11 @@ namespace tomoto
 }
 Eigen::Map<Eigen::Matrix<WeightType, -1, 1>>{ globalState.numByTopic.data(), globalState.numByTopic.size() }
 = globalState.numByTopicWord.rowwise().sum();
-
- res = pool.enqueueToAll([&](size_t threadId)
- {
- localData[threadId].numByTopic = globalState.numByTopic;
- });
 }
- for (auto& r : res) r.get();
 }
 
 template<typename _DocIter>
- void
+ void _sampleGlobalLevel(ThreadPool* pool, _ModelState*, _RandGen* rgs, _DocIter first, _DocIter last)
 {
 const auto K = this->K;
 const Float eps = shapeA * (std::pow(shapeB + 1 + this->globalStep, -shapeC));

@@ -313,10 +298,10 @@ namespace tomoto
 alphas = newAlphas;
 }
 
- template<typename _DocIter>
+ template<GlobalSampler _gs, typename _DocIter>
 void sampleGlobalLevel(ThreadPool* pool, _ModelState* localData, _RandGen* rgs, _DocIter first, _DocIter last) const
 {
-
+ if (_gs != GlobalSampler::inference) return const_cast<DerivedClass*>(this)->_sampleGlobalLevel(pool, localData, rgs, first, last);
 }
 
 void optimizeParameters(ThreadPool& pool, _ModelState* localData, _RandGen* rgs)

@@ -343,11 +328,11 @@ namespace tomoto
 BaseClass::prepareDoc(doc, docId, wordSize);
 if (docId == (size_t)-1)
 {
- doc.eta.init(nullptr, this->K);
+ doc.eta.init(nullptr, this->K, 1);
 }
 else
 {
- doc.eta.init((Float*)etaByDoc.col(docId).data(), this->K);
+ doc.eta.init((Float*)etaByDoc.col(docId).data(), this->K, 1);
 }
 }
 

@@ -427,7 +412,7 @@ namespace tomoto
 numDocsByTime[doc.timepoint]++;
 if (!initDocs)
 {
- doc.eta.init((Float*)etaByDoc.col(docId++).data(), this->K);
+ doc.eta.init((Float*)etaByDoc.col(docId++).data(), this->K, 1);
 }
 }
 

@@ -96,7 +96,7 @@ namespace tomoto
 ld.numTableByTopic.tail(newSize - oldSize).setZero();
 ld.numByTopic.conservativeResize(newSize);
 ld.numByTopic.tail(newSize - oldSize).setZero();
- ld.numByTopicWord.conservativeResize(newSize,
+ ld.numByTopicWord.conservativeResize(newSize, V);
 ld.numByTopicWord.block(oldSize, 0, newSize - oldSize, V).setZero();
 }
 else

@@ -155,7 +155,7 @@ namespace tomoto
 if (_inc > 0 && tid >= doc.numByTopic.size())
 {
 size_t oldSize = doc.numByTopic.size();
- doc.numByTopic.conservativeResize(tid + 1);
+ doc.numByTopic.conservativeResize(tid + 1, 1);
 doc.numByTopic.tail(tid + 1 - oldSize).setZero();
 }
 constexpr bool _dec = _inc < 0 && _tw != TermWeight::one;

@@ -282,7 +282,7 @@ namespace tomoto
 auto& doc = this->docs[j];
 if (doc.numByTopic.size() >= K) continue;
 size_t oldSize = doc.numByTopic.size();
- doc.numByTopic.conservativeResize(K);
+ doc.numByTopic.conservativeResize(K, 1);
 doc.numByTopic.tail(K - oldSize).setZero();
 }
 }, this->docs.size() * i / pool.getNumWorkers(), this->docs.size() * (i + 1) / pool.getNumWorkers()));

@@ -293,7 +293,6 @@ namespace tomoto
 template<ParallelScheme _ps, typename _ExtraDocData>
 void mergeState(ThreadPool& pool, _ModelState& globalState, _ModelState& tState, _ModelState* localData, _RandGen*, const _ExtraDocData& edd) const
 {
- std::vector<std::future<void>> res;
 const size_t V = this->realV;
 auto K = this->K;
 

@@ -303,7 +302,7 @@ namespace tomoto
 globalState.numByTopic.conservativeResize(K);
 globalState.numByTopic.tail(K - oldSize).setZero();
 globalState.numTableByTopic.resize(K);
- globalState.numByTopicWord.conservativeResize(K,
+ globalState.numByTopicWord.conservativeResize(K, V);
 globalState.numByTopicWord.block(oldSize, 0, K - oldSize, V).setZero();
 }
 

@@ -321,7 +320,7 @@ namespace tomoto
 if (_tw != TermWeight::one)
 {
 globalState.numByTopic = globalState.numByTopic.cwiseMax(0);
- globalState.numByTopicWord = globalState.numByTopicWord.cwiseMax(0);
+ globalState.numByTopicWord.matrix() = globalState.numByTopicWord.cwiseMax(0);
 }
 
 

@@ -334,15 +333,6 @@ namespace tomoto
 }
 }
 globalState.totalTable = globalState.numTableByTopic.sum();
-
- for (size_t i = 0; i < pool.getNumWorkers(); ++i)
- {
- res.emplace_back(pool.enqueue([&, this, i](size_t threadId)
- {
- localData[i] = globalState;
- }));
- }
- for (auto& r : res) r.get();
 }
 
 /* this LL calculation is based on https://github.com/blei-lab/hdp/blob/master/hdp/state.cpp */

@@ -400,13 +390,14 @@ namespace tomoto
 {
 this->globalState.numByTopic = Eigen::Matrix<WeightType, -1, 1>::Zero(K);
 this->globalState.numTableByTopic = Eigen::Matrix<int32_t, -1, 1>::Zero(K);
- this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(K, V);
+ //this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(K, V);
+ this->globalState.numByTopicWord.init(nullptr, K, V);
 }
 }
 
 void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
 {
- doc.numByTopic.init(nullptr, this->K);
+ doc.numByTopic.init(nullptr, this->K, 1);
 doc.numTopicByTable.clear();
 doc.Zs = tvector<Tid>(wordSize);
 if (_tw != TermWeight::one) doc.wordWeights.resize(wordSize);

@@ -577,7 +568,7 @@ namespace tomoto
 template<typename _TopicModel>
 void DocumentHDP<_tw>::update(WeightType * ptr, const _TopicModel & mdl)
 {
- this->numByTopic.init(ptr, mdl.getK());
+ this->numByTopic.init(ptr, mdl.getK(), 1);
 for (size_t i = 0; i < this->Zs.size(); ++i)
 {
 if (this->words[i] >= mdl.getV()) continue;

@@ -119,19 +119,26 @@ namespace tomoto
 
 DEFINE_SERIALIZER(nodes, levelBlocks);
 
- template<bool
+ template<bool _makeNewPath = true>
 void calcNodeLikelihood(Float gamma, size_t levelDepth)
 {
 nodeLikelihoods.resize(nodes.size());
 nodeLikelihoods.array() = -INFINITY;
- updateNodeLikelihood<
+ updateNodeLikelihood<_makeNewPath>(gamma, levelDepth, &nodes[0]);
+ if (!_makeNewPath)
+ {
+ for (size_t i = 0; i < levelBlocks.size(); ++i)
+ {
+ if (levelBlocks[i] < levelDepth - 1) nodeLikelihoods.segment((i + 1) * blockSize, blockSize).array() = -INFINITY;
+ }
+ }
 }
 
- template<bool
+ template<bool _makeNewPath = true>
 void updateNodeLikelihood(Float gamma, size_t levelDepth, NCRPNode* node, Float weight = 0)
 {
 size_t idx = node - nodes.data();
- const Float pNewNode =
+ const Float pNewNode = _makeNewPath ? log(gamma / (node->numCustomers + gamma)) : -INFINITY;
 nodeLikelihoods[idx] = weight + (((size_t)node->level < levelDepth - 1) ? pNewNode : 0);
 for(auto * child = node->getChild(); child; child = child->getSibling())
 {

@@ -187,7 +194,7 @@ namespace tomoto
 std::vector<std::future<void>> futures;
 futures.reserve(levelBlocks.size());
 
- auto calc = [
+ auto calc = [&, eta, realV](size_t threadId, size_t b)
 {
 Float cnt = 0;
 Vid prevWord = -1;

@@ -284,7 +291,7 @@ namespace tomoto
 size_t oldSize = ld.numByTopic.rows();
 size_t newSize = std::max(nodes.size(), ((oldSize + oldSize / 2 + 7) / 8) * 8);
 ld.numByTopic.conservativeResize(newSize);
- ld.numByTopicWord.conservativeResize(newSize,
+ ld.numByTopicWord.conservativeResize(newSize, ld.numByTopicWord.cols());
 ld.numByTopic.segment(oldSize, newSize - oldSize).setZero();
 ld.numByTopicWord.block(oldSize, 0, newSize - oldSize, ld.numByTopicWord.cols()).setZero();
 }

@@ -317,13 +324,13 @@ namespace tomoto
 typename _Derived = void,
 typename _DocType = DocumentHLDA<_tw>,
 typename _ModelState = ModelStateHLDA<_tw>>
- class HLDAModel : public LDAModel<_tw, _RandGen, flags::
+ class HLDAModel : public LDAModel<_tw, _RandGen, flags::partitioned_multisampling, _Interface,
 typename std::conditional<std::is_same<_Derived, void>::value, HLDAModel<_tw, _RandGen>, _Derived>::type,
 _DocType, _ModelState>
 {
 protected:
 using DerivedClass = typename std::conditional<std::is_same<_Derived, void>::value, HLDAModel<_tw, _RandGen>, _Derived>::type;
- using BaseClass = LDAModel<_tw, _RandGen, flags::
+ using BaseClass = LDAModel<_tw, _RandGen, flags::partitioned_multisampling, _Interface, DerivedClass, _DocType, _ModelState>;
 friend BaseClass;
 friend typename BaseClass::BaseClass;
 using WeightType = typename BaseClass::WeightType;

@@ -341,11 +348,11 @@ namespace tomoto
 }
 
 // Words of all documents should be sorted by ascending order.
- template<
+ template<GlobalSampler _gs>
 void samplePathes(_DocType& doc, ThreadPool* pool, _ModelState& ld, _RandGen& rgs) const
 {
- if(
- ld.nt->template calcNodeLikelihood<
+ if(_gs != GlobalSampler::inference) ld.nt->nodes[doc.path.back()].dropPathOne();
+ ld.nt->template calcNodeLikelihood<_gs == GlobalSampler::train>(gamma, this->K);
 
 std::vector<Float> newTopicWeights(this->K - 1);
 std::vector<WeightType> cntByLevel(this->K);

@@ -355,7 +362,7 @@ namespace tomoto
 if (doc.words[w] >= this->realV) break;
 addWordToOnlyLocal<-1>(ld, doc, w, doc.words[w], doc.Zs[w]);
 
- if (
+ if (_gs == GlobalSampler::train)
 {
 if (doc.words[w] != prevWord)
 {

@@ -371,7 +378,7 @@ namespace tomoto
 }
 }
 
- if (
+ if (_gs == GlobalSampler::train)
 {
 for (size_t l = 1; l < this->K; ++l)
 {

@@ -386,7 +393,7 @@ namespace tomoto
 size_t newPath = sample::sampleFromDiscreteAcc(ld.nt->nodeLikelihoods.data(),
 ld.nt->nodeLikelihoods.data() + ld.nt->nodeLikelihoods.size(), rgs);
 
- if(
+ if(_gs == GlobalSampler::train) newPath = ld.nt->template generateLeafNode<_tw>(newPath, this->K, ld);
 doc.path.back() = newPath;
 for (size_t l = this->K - 2; l > 0; --l)
 {

@@ -398,7 +405,7 @@ namespace tomoto
 if (doc.words[w] >= this->realV) break;
 addWordToOnlyLocal<1>(ld, doc, w, doc.words[w], doc.Zs[w]);
 }
- if (
+ if (_gs != GlobalSampler::inference) ld.nt->nodes[doc.path.back()].addPathOne();
 }
 
 template<int _inc>

@@ -426,6 +433,7 @@ namespace tomoto
 template<bool _asymEta>
 Float* getZLikelihoods(_ModelState& ld, const _DocType& doc, size_t docId, size_t vid) const
 {
+ if (_asymEta) THROW_ERROR_WITH_INFO(exception::Unimplemented, "Unimplemented features");
 const size_t V = this->realV;
 assert(vid < V);
 auto& zLikelihood = ld.zLikelihood;

@@ -439,50 +447,14 @@ namespace tomoto
 return &zLikelihood[0];
 }
 
-
-
- for (size_t w = 0; w < doc.words.size(); ++w)
- {
- if (doc.words[w] >= this->realV) continue;
- addWordTo<-1>(ld, doc, w, doc.words[w], doc.Zs[w]);
- Float* dist;
- if (this->etaByTopicWord.size())
- {
- THROW_ERROR_WITH_INFO(exception::Unimplemented, "Unimplemented features");
- }
- else
- {
- dist = static_cast<const DerivedClass*>(this)->template
- getZLikelihoods<false>(ld, doc, docId, doc.words[w]);
- }
- doc.Zs[w] = sample::sampleFromDiscreteAcc(dist, dist + this->K, rgs);
- addWordTo<1>(ld, doc, w, doc.words[w], doc.Zs[w]);
- }
- }
-
- template<ParallelScheme _ps, bool _infer, typename _ExtraDocData>
- void sampleDocument(_DocType& doc, const _ExtraDocData& edd, size_t docId, _ModelState& ld, _RandGen& rgs, size_t iterationCnt, size_t partitionId = 0) const
- {
- sampleTopics(doc, docId, ld, rgs);
- }
-
- template<typename _DocIter>
- void sampleGlobalLevel(ThreadPool* pool, _ModelState* localData, _RandGen* rgs, _DocIter first, _DocIter last)
+ template<GlobalSampler _gs, typename _DocIter>
+ void sampleGlobalLevel(ThreadPool* pool, _ModelState* globalData, _RandGen* rgs, _DocIter first, _DocIter last) const
 {
 for (auto doc = first; doc != last; ++doc)
 {
- samplePathes
- }
- localData->nt->markEmptyBlocks();
- }
-
- template<typename _DocIter>
- void sampleGlobalLevel(ThreadPool* pool, _ModelState* localData, _RandGen* rgs, _DocIter first, _DocIter last) const
- {
- for (auto doc = first; doc != last; ++doc)
- {
- samplePathes<false>(*doc, pool, *localData, rgs[0]);
+ samplePathes<_gs>(*doc, pool, *globalData, rgs[0]);
 }
+ if (_gs != GlobalSampler::inference) globalData->nt->markEmptyBlocks();
 }
 
 template<typename _DocIter>

@@ -539,7 +511,8 @@ namespace tomoto
 if (initDocs)
 {
 this->globalState.numByTopic = Eigen::Matrix<WeightType, -1, 1>::Zero(this->K);
- this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K, V);
+ //this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K, V);
+ this->globalState.numByTopicWord.init(nullptr, this->K, V);
 this->globalState.nt->nodes.resize(detail::NodeTrees::blockSize);
 }
 }

@@ -547,7 +520,7 @@ namespace tomoto
 void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
 {
 sortAndWriteOrder(doc.words, doc.wOrder);
- doc.numByTopic.init(nullptr, this->K);
+ doc.numByTopic.init(nullptr, this->K, 1);
 doc.Zs = tvector<Tid>(wordSize);
 doc.path.resize(this->K);
 for (size_t l = 0; l < this->K; ++l) doc.path[l] = l;

@@ -595,6 +568,31 @@ namespace tomoto
 return cnt;
 }
 
+ template<ParallelScheme _ps>
+ void distributeMergedState(ThreadPool& pool, _ModelState& globalState, _ModelState* localData) const
+ {
+ std::vector<std::future<void>> res;
+ if (_ps == ParallelScheme::copy_merge)
+ {
+ for (size_t i = 0; i < pool.getNumWorkers(); ++i)
+ {
+ res.emplace_back(pool.enqueue([&, i](size_t)
+ {
+ localData[i] = globalState;
+ }));
+ }
+ }
+ else if (_ps == ParallelScheme::partition)
+ {
+ res = pool.enqueueToAll([&](size_t threadId)
+ {
+ localData[threadId].numByTopicWord.init((WeightType*)globalState.numByTopicWord.data(), globalState.numByTopicWord.rows(), globalState.numByTopicWord.cols());
+ localData[threadId].numByTopic = globalState.numByTopic;
+ });
+ }
+ for (auto& r : res) r.get();
+ }
+
 public:
 DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, gamma);
 DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, gamma);

@@ -671,7 +669,7 @@ namespace tomoto
 template<typename _TopicModel>
 inline void DocumentHLDA<_tw>::update(WeightType * ptr, const _TopicModel & mdl)
 {
- this->numByTopic.init(ptr, mdl.getLevelDepth());
+ this->numByTopic.init(ptr, mdl.getLevelDepth(), 1);
 for (size_t i = 0; i < this->Zs.size(); ++i)
 {
 if (this->words[i] >= mdl.getV()) continue;