tomoto 0.1.4 → 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/ext/tomoto/ct.cpp +8 -4
- data/ext/tomoto/dmr.cpp +10 -4
- data/ext/tomoto/dt.cpp +13 -4
- data/ext/tomoto/extconf.rb +1 -1
- data/ext/tomoto/gdmr.cpp +14 -6
- data/ext/tomoto/hdp.cpp +9 -4
- data/ext/tomoto/hlda.cpp +9 -4
- data/ext/tomoto/hpa.cpp +9 -4
- data/ext/tomoto/lda.cpp +8 -4
- data/ext/tomoto/llda.cpp +8 -4
- data/ext/tomoto/mglda.cpp +11 -1
- data/ext/tomoto/pa.cpp +9 -4
- data/ext/tomoto/plda.cpp +8 -4
- data/ext/tomoto/slda.cpp +13 -5
- data/lib/tomoto/gdmr.rb +2 -2
- data/lib/tomoto/version.rb +1 -1
- data/vendor/EigenRand/EigenRand/Core.h +6 -1107
- data/vendor/EigenRand/EigenRand/Dists/Basic.h +490 -43
- data/vendor/EigenRand/EigenRand/Dists/Discrete.h +916 -285
- data/vendor/EigenRand/EigenRand/Dists/GammaPoisson.h +85 -36
- data/vendor/EigenRand/EigenRand/Dists/NormalExp.h +1038 -290
- data/vendor/EigenRand/EigenRand/EigenRand +2 -2
- data/vendor/EigenRand/EigenRand/Macro.h +4 -4
- data/vendor/EigenRand/EigenRand/MorePacketMath.h +54 -22
- data/vendor/EigenRand/EigenRand/MvDists/Multinomial.h +222 -0
- data/vendor/EigenRand/EigenRand/MvDists/MvNormal.h +492 -0
- data/vendor/EigenRand/EigenRand/PacketFilter.h +2 -2
- data/vendor/EigenRand/EigenRand/PacketRandomEngine.h +2 -2
- data/vendor/EigenRand/EigenRand/RandUtils.h +65 -11
- data/vendor/EigenRand/EigenRand/doc.h +142 -25
- data/vendor/EigenRand/LICENSE +1 -1
- data/vendor/EigenRand/README.md +109 -24
- data/vendor/tomotopy/README.kr.rst +27 -6
- data/vendor/tomotopy/README.rst +29 -8
- data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +60 -12
- data/vendor/tomotopy/src/Labeling/FoRelevance.h +2 -2
- data/vendor/tomotopy/src/Labeling/Phraser.hpp +33 -21
- data/vendor/tomotopy/src/TopicModel/CT.h +8 -5
- data/vendor/tomotopy/src/TopicModel/CTModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/CTModel.hpp +29 -23
- data/vendor/tomotopy/src/TopicModel/DMR.h +33 -4
- data/vendor/tomotopy/src/TopicModel/DMRModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/DMRModel.hpp +231 -57
- data/vendor/tomotopy/src/TopicModel/DT.h +24 -5
- data/vendor/tomotopy/src/TopicModel/DTModel.cpp +2 -8
- data/vendor/tomotopy/src/TopicModel/DTModel.hpp +41 -28
- data/vendor/tomotopy/src/TopicModel/GDMR.h +31 -5
- data/vendor/tomotopy/src/TopicModel/GDMRModel.cpp +2 -7
- data/vendor/tomotopy/src/TopicModel/GDMRModel.hpp +211 -104
- data/vendor/tomotopy/src/TopicModel/HDP.h +11 -2
- data/vendor/tomotopy/src/TopicModel/HDPModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +52 -45
- data/vendor/tomotopy/src/TopicModel/HLDA.h +11 -2
- data/vendor/tomotopy/src/TopicModel/HLDAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +13 -16
- data/vendor/tomotopy/src/TopicModel/HPA.h +5 -2
- data/vendor/tomotopy/src/TopicModel/HPAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +51 -21
- data/vendor/tomotopy/src/TopicModel/LDA.h +9 -2
- data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +8 -8
- data/vendor/tomotopy/src/TopicModel/LDAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +70 -28
- data/vendor/tomotopy/src/TopicModel/LLDA.h +1 -2
- data/vendor/tomotopy/src/TopicModel/LLDAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/LLDAModel.hpp +22 -12
- data/vendor/tomotopy/src/TopicModel/MGLDA.h +12 -3
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.cpp +2 -10
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +42 -19
- data/vendor/tomotopy/src/TopicModel/PA.h +9 -4
- data/vendor/tomotopy/src/TopicModel/PAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/PAModel.hpp +48 -25
- data/vendor/tomotopy/src/TopicModel/PLDA.h +13 -2
- data/vendor/tomotopy/src/TopicModel/PLDAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/PLDAModel.hpp +27 -19
- data/vendor/tomotopy/src/TopicModel/PT.h +12 -5
- data/vendor/tomotopy/src/TopicModel/PTModel.cpp +2 -3
- data/vendor/tomotopy/src/TopicModel/PTModel.hpp +29 -14
- data/vendor/tomotopy/src/TopicModel/SLDA.h +18 -6
- data/vendor/tomotopy/src/TopicModel/SLDAModel.cpp +2 -10
- data/vendor/tomotopy/src/TopicModel/SLDAModel.hpp +93 -43
- data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +58 -23
- data/vendor/tomotopy/src/Utils/AliasMethod.hpp +6 -6
- data/vendor/tomotopy/src/Utils/Dictionary.h +11 -0
- data/vendor/tomotopy/src/Utils/SharedString.hpp +26 -1
- data/vendor/tomotopy/src/Utils/Trie.hpp +46 -21
- data/vendor/tomotopy/src/Utils/Utils.hpp +99 -14
- data/vendor/tomotopy/src/Utils/exception.h +1 -1
- data/vendor/tomotopy/src/Utils/math.h +5 -7
- data/vendor/tomotopy/src/Utils/serializer.hpp +329 -201
- data/vendor/tomotopy/src/Utils/text.hpp +8 -0
- data/vendor/tomotopy/src/Utils/tvector.hpp +49 -7
- metadata +9 -7
|
@@ -2,16 +2,8 @@
|
|
|
2
2
|
|
|
3
3
|
namespace tomoto
|
|
4
4
|
{
|
|
5
|
-
|
|
6
|
-
template class MGLDAModel<TermWeight::idf>;
|
|
7
|
-
template class MGLDAModel<TermWeight::pmi>;*/
|
|
8
|
-
|
|
9
|
-
IMGLDAModel* IMGLDAModel::create(TermWeight _weight, size_t _KG, size_t _KL, size_t _T,
|
|
10
|
-
Float _alphaG, Float _alphaL, Float _alphaMG, Float _alphaML,
|
|
11
|
-
Float _etaG, Float _etaL, Float _gamma, size_t seed, bool scalarRng)
|
|
5
|
+
IMGLDAModel* IMGLDAModel::create(TermWeight _weight, const MGLDAArgs& args, bool scalarRng)
|
|
12
6
|
{
|
|
13
|
-
TMT_SWITCH_TW(_weight, scalarRng, MGLDAModel,
|
|
14
|
-
_alphaG, _alphaL, _alphaMG, _alphaML,
|
|
15
|
-
_etaG, _etaL, _gamma, seed);
|
|
7
|
+
TMT_SWITCH_TW(_weight, scalarRng, MGLDAModel, args);
|
|
16
8
|
}
|
|
17
9
|
}
|
|
@@ -289,7 +289,7 @@ namespace tomoto
|
|
|
289
289
|
|
|
290
290
|
const size_t S = doc.numBySent.size();
|
|
291
291
|
std::fill(doc.numBySent.begin(), doc.numBySent.end(), 0);
|
|
292
|
-
doc.Zs = tvector<Tid>(wordSize);
|
|
292
|
+
doc.Zs = tvector<Tid>(wordSize, non_topic_id);
|
|
293
293
|
doc.Vs.resize(wordSize);
|
|
294
294
|
if (_tw != TermWeight::one) doc.wordWeights.resize(wordSize);
|
|
295
295
|
doc.numByTopic.init(nullptr, this->K + KL, 1);
|
|
@@ -302,7 +302,7 @@ namespace tomoto
|
|
|
302
302
|
void initGlobalState(bool initDocs)
|
|
303
303
|
{
|
|
304
304
|
const size_t V = this->realV;
|
|
305
|
-
this->globalState.zLikelihood =
|
|
305
|
+
this->globalState.zLikelihood = Vector::Zero(T * (this->K + KL));
|
|
306
306
|
if (initDocs)
|
|
307
307
|
{
|
|
308
308
|
this->globalState.numByTopic = Eigen::Matrix<WeightType, -1, 1>::Zero(this->K + KL);
|
|
@@ -371,17 +371,33 @@ namespace tomoto
|
|
|
371
371
|
DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, alphaL, alphaM, alphaML, etaL, gamma, KL, T);
|
|
372
372
|
DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, alphaL, alphaM, alphaML, etaL, gamma, KL, T);
|
|
373
373
|
|
|
374
|
-
MGLDAModel(
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
alphaL(_alphaL), alphaM(_KG ? _alphaMG : 0), alphaML(_alphaML),
|
|
379
|
-
etaL(_etaL), gamma(_gamma)
|
|
374
|
+
MGLDAModel(const MGLDAArgs& args)
|
|
375
|
+
: BaseClass(args), KL(args.kL), T(args.t),
|
|
376
|
+
alphaL(args.alphaL[0]), alphaM(args.k ? args.alphaMG : 0), alphaML(args.alphaML),
|
|
377
|
+
etaL(args.etaL), gamma(args.gamma)
|
|
380
378
|
{
|
|
381
|
-
if (
|
|
382
|
-
if (
|
|
383
|
-
|
|
384
|
-
if (
|
|
379
|
+
if (KL == 0 || KL >= 0x80000000) THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong KL value (KL = %zd)", KL));
|
|
380
|
+
if (T == 0 || T >= 0x80000000) THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong T value (T = %zd)", T));
|
|
381
|
+
|
|
382
|
+
if (args.alpha.size() != 1)
|
|
383
|
+
{
|
|
384
|
+
THROW_ERROR_WITH_INFO(exc::Unimplemented, "An asymmetric alpha prior is not supported yet at MGLDA.");
|
|
385
|
+
}
|
|
386
|
+
|
|
387
|
+
if (args.alphaL.size() == 1)
|
|
388
|
+
{
|
|
389
|
+
}
|
|
390
|
+
else if (args.alphaL.size() == args.kL)
|
|
391
|
+
{
|
|
392
|
+
THROW_ERROR_WITH_INFO(exc::Unimplemented, "An asymmetric alphaL prior is not supported yet at MGLDA.");
|
|
393
|
+
}
|
|
394
|
+
else
|
|
395
|
+
{
|
|
396
|
+
THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong alphaL value (len = %zd)", args.alphaL.size()));
|
|
397
|
+
}
|
|
398
|
+
|
|
399
|
+
if (alphaL <= 0) THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong alphaL value (alphaL = %f)", alphaL));
|
|
400
|
+
if (etaL <= 0) THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong etaL value (etaL = %f)", etaL));
|
|
385
401
|
}
|
|
386
402
|
|
|
387
403
|
template<bool _const, typename _FnTokenizer>
|
|
@@ -426,7 +442,7 @@ namespace tomoto
|
|
|
426
442
|
|
|
427
443
|
std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) const
|
|
428
444
|
{
|
|
429
|
-
return make_unique<_DocType>(as_mutable(this)->template _makeFromRawDoc<true>(rawDoc, tokenizer, rawDoc.template getMisc<std::string>("delimiter")));
|
|
445
|
+
return std::make_unique<_DocType>(as_mutable(this)->template _makeFromRawDoc<true>(rawDoc, tokenizer, rawDoc.template getMisc<std::string>("delimiter")));
|
|
430
446
|
}
|
|
431
447
|
|
|
432
448
|
template<bool _const = false>
|
|
@@ -497,25 +513,32 @@ namespace tomoto
|
|
|
497
513
|
|
|
498
514
|
std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc) const
|
|
499
515
|
{
|
|
500
|
-
return make_unique<_DocType>(as_mutable(this)->template _makeFromRawDoc<true>(rawDoc));
|
|
516
|
+
return std::make_unique<_DocType>(as_mutable(this)->template _makeFromRawDoc<true>(rawDoc));
|
|
501
517
|
}
|
|
502
518
|
|
|
503
519
|
void setWordPrior(const std::string& word, const std::vector<Float>& priors) override
|
|
504
520
|
{
|
|
505
|
-
if (priors.size() != this->K + KL) THROW_ERROR_WITH_INFO(
|
|
521
|
+
if (priors.size() != this->K + KL) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "priors.size() must be equal to K.");
|
|
506
522
|
for (auto p : priors)
|
|
507
523
|
{
|
|
508
|
-
if (p < 0) THROW_ERROR_WITH_INFO(
|
|
524
|
+
if (p < 0) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "priors must not be less than 0.");
|
|
509
525
|
}
|
|
510
526
|
this->dict.add(word);
|
|
511
527
|
this->etaByWord.emplace(word, priors);
|
|
512
528
|
}
|
|
513
529
|
|
|
514
|
-
std::vector<Float> getTopicsByDoc(const _DocType& doc) const
|
|
530
|
+
std::vector<Float> getTopicsByDoc(const _DocType& doc, bool normalize) const
|
|
515
531
|
{
|
|
516
532
|
std::vector<Float> ret(this->K + KL);
|
|
517
|
-
Eigen::Map<Eigen::
|
|
518
|
-
|
|
533
|
+
Eigen::Map<Eigen::Array<Float, -1, 1>> m{ ret.data(), this->K + KL };
|
|
534
|
+
if (normalize)
|
|
535
|
+
{
|
|
536
|
+
m = doc.numByTopic.array().template cast<Float>() / doc.getSumWordWeight();
|
|
537
|
+
}
|
|
538
|
+
else
|
|
539
|
+
{
|
|
540
|
+
m = doc.numByTopic.array().template cast<Float>();
|
|
541
|
+
}
|
|
519
542
|
return ret;
|
|
520
543
|
}
|
|
521
544
|
|
|
@@ -18,13 +18,18 @@ namespace tomoto
|
|
|
18
18
|
DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 0, Z2s);
|
|
19
19
|
DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 1, 0x00010001, Z2s);
|
|
20
20
|
};
|
|
21
|
+
|
|
22
|
+
struct PAArgs : public LDAArgs
|
|
23
|
+
{
|
|
24
|
+
size_t k2 = 1;
|
|
25
|
+
std::vector<Float> subalpha = { 0.1 };
|
|
26
|
+
};
|
|
21
27
|
|
|
22
28
|
class IPAModel : public ILDAModel
|
|
23
29
|
{
|
|
24
30
|
public:
|
|
25
31
|
using DefaultDocType = DocumentPA<TermWeight::one>;
|
|
26
|
-
static IPAModel* create(TermWeight _weight,
|
|
27
|
-
Float _alpha = 0.1, Float _eta = 0.01, size_t seed = std::random_device{}(),
|
|
32
|
+
static IPAModel* create(TermWeight _weight, const PAArgs& args,
|
|
28
33
|
bool scalarRng = false);
|
|
29
34
|
|
|
30
35
|
virtual size_t getDirichletEstIteration() const = 0;
|
|
@@ -32,10 +37,10 @@ namespace tomoto
|
|
|
32
37
|
virtual size_t getK2() const = 0;
|
|
33
38
|
virtual Float getSubAlpha(Tid k1, Tid k2) const = 0;
|
|
34
39
|
virtual std::vector<Float> getSubAlpha(Tid k1) const = 0;
|
|
35
|
-
virtual std::vector<Float> getSubTopicBySuperTopic(Tid k) const = 0;
|
|
40
|
+
virtual std::vector<Float> getSubTopicBySuperTopic(Tid k, bool normalize = true) const = 0;
|
|
36
41
|
virtual std::vector<std::pair<Tid, Float>> getSubTopicBySuperTopicSorted(Tid k, size_t topN) const = 0;
|
|
37
42
|
|
|
38
|
-
virtual std::vector<Float> getSubTopicsByDoc(const DocumentBase* doc) const = 0;
|
|
43
|
+
virtual std::vector<Float> getSubTopicsByDoc(const DocumentBase* doc, bool normalize = true) const = 0;
|
|
39
44
|
virtual std::vector<std::pair<Tid, Float>> getSubTopicsByDocSorted(const DocumentBase* doc, size_t topN) const = 0;
|
|
40
45
|
|
|
41
46
|
virtual std::vector<uint64_t> getCountBySuperTopic() const = 0;
|
|
@@ -2,12 +2,8 @@
|
|
|
2
2
|
|
|
3
3
|
namespace tomoto
|
|
4
4
|
{
|
|
5
|
-
|
|
6
|
-
template class PAModel<TermWeight::idf>;
|
|
7
|
-
template class PAModel<TermWeight::pmi>;*/
|
|
8
|
-
|
|
9
|
-
IPAModel* IPAModel::create(TermWeight _weight, size_t _K, size_t _K2, Float _alpha, Float _eta, size_t seed, bool scalarRng)
|
|
5
|
+
IPAModel* IPAModel::create(TermWeight _weight, const PAArgs& args, bool scalarRng)
|
|
10
6
|
{
|
|
11
|
-
TMT_SWITCH_TW(_weight, scalarRng, PAModel,
|
|
7
|
+
TMT_SWITCH_TW(_weight, scalarRng, PAModel, args);
|
|
12
8
|
}
|
|
13
9
|
}
|
|
@@ -16,7 +16,7 @@ namespace tomoto
|
|
|
16
16
|
using WeightType = typename ModelStateLDA<_tw>::WeightType;
|
|
17
17
|
Eigen::Matrix<WeightType, -1, -1> numByTopic1_2;
|
|
18
18
|
Eigen::Matrix<WeightType, -1, 1> numByTopic2;
|
|
19
|
-
|
|
19
|
+
Vector subTmp;
|
|
20
20
|
|
|
21
21
|
DEFINE_SERIALIZER_AFTER_BASE(ModelStateLDA<_tw>, numByTopic1_2, numByTopic2);
|
|
22
22
|
};
|
|
@@ -41,8 +41,8 @@ namespace tomoto
|
|
|
41
41
|
Float epsilon = 1e-5;
|
|
42
42
|
size_t iteration = 5;
|
|
43
43
|
|
|
44
|
-
|
|
45
|
-
|
|
44
|
+
Vector subAlphaSum; // len = K
|
|
45
|
+
Matrix subAlphas; // len = K * K2
|
|
46
46
|
void optimizeParameters(ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
|
|
47
47
|
{
|
|
48
48
|
const auto K = this->K;
|
|
@@ -286,7 +286,7 @@ namespace tomoto
|
|
|
286
286
|
BaseClass::prepareDoc(doc, docId, wordSize);
|
|
287
287
|
|
|
288
288
|
doc.numByTopic1_2 = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K, K2);
|
|
289
|
-
doc.Z2s = tvector<Tid>(wordSize);
|
|
289
|
+
doc.Z2s = tvector<Tid>(wordSize, non_topic_id);
|
|
290
290
|
}
|
|
291
291
|
|
|
292
292
|
void prepareWordPriors()
|
|
@@ -299,7 +299,7 @@ namespace tomoto
|
|
|
299
299
|
{
|
|
300
300
|
auto id = this->dict.toWid(it.first);
|
|
301
301
|
if (id == (Vid)-1 || id >= this->realV) continue;
|
|
302
|
-
this->etaByTopicWord.col(id) = Eigen::Map<
|
|
302
|
+
this->etaByTopicWord.col(id) = Eigen::Map<Vector>{ it.second.data(), (Eigen::Index)it.second.size() };
|
|
303
303
|
}
|
|
304
304
|
this->etaSumByTopic = this->etaByTopicWord.rowwise().sum();
|
|
305
305
|
}
|
|
@@ -307,7 +307,7 @@ namespace tomoto
|
|
|
307
307
|
void initGlobalState(bool initDocs)
|
|
308
308
|
{
|
|
309
309
|
const size_t V = this->realV;
|
|
310
|
-
this->globalState.zLikelihood =
|
|
310
|
+
this->globalState.zLikelihood = Vector::Zero(this->K * K2);
|
|
311
311
|
if (initDocs)
|
|
312
312
|
{
|
|
313
313
|
this->globalState.numByTopic = Eigen::Matrix<WeightType, -1, 1>::Zero(this->K);
|
|
@@ -365,12 +365,24 @@ namespace tomoto
|
|
|
365
365
|
DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, K2, subAlphas, subAlphaSum);
|
|
366
366
|
DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, K2, subAlphas, subAlphaSum);
|
|
367
367
|
|
|
368
|
-
PAModel(
|
|
369
|
-
: BaseClass(
|
|
368
|
+
PAModel(const PAArgs& args)
|
|
369
|
+
: BaseClass(args), K2(args.k2)
|
|
370
370
|
{
|
|
371
|
-
if (
|
|
372
|
-
|
|
373
|
-
|
|
371
|
+
if (K2 == 0 || K2 >= 0x80000000) THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong K2 value (K2 = %zd)", K2));
|
|
372
|
+
|
|
373
|
+
if (args.subalpha.size() == 1)
|
|
374
|
+
{
|
|
375
|
+
subAlphas = Matrix::Constant(args.k, args.k2, args.subalpha[0]);
|
|
376
|
+
}
|
|
377
|
+
else if(args.subalpha.size() == args.k2)
|
|
378
|
+
{
|
|
379
|
+
subAlphas = Eigen::Map<const Eigen::Matrix<Float, 1, -1>>(args.subalpha.data(), args.subalpha.size()).replicate(args.k, 1);
|
|
380
|
+
}
|
|
381
|
+
else
|
|
382
|
+
{
|
|
383
|
+
THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong subalpha value (len = %zd)", args.subalpha.size()));
|
|
384
|
+
}
|
|
385
|
+
subAlphaSum = subAlphas.rowwise().sum();
|
|
374
386
|
this->optimInterval = 1;
|
|
375
387
|
}
|
|
376
388
|
|
|
@@ -379,7 +391,7 @@ namespace tomoto
|
|
|
379
391
|
|
|
380
392
|
void setDirichletEstIteration(size_t iter) override
|
|
381
393
|
{
|
|
382
|
-
if (!iter) throw
|
|
394
|
+
if (!iter) throw exc::InvalidArgument("iter must > 0");
|
|
383
395
|
iteration = iter;
|
|
384
396
|
}
|
|
385
397
|
|
|
@@ -392,43 +404,54 @@ namespace tomoto
|
|
|
392
404
|
return ret;
|
|
393
405
|
}
|
|
394
406
|
|
|
395
|
-
std::vector<Float> getSubTopicBySuperTopic(Tid k) const override
|
|
407
|
+
std::vector<Float> getSubTopicBySuperTopic(Tid k, bool normalize) const override
|
|
396
408
|
{
|
|
397
409
|
assert(k < this->K);
|
|
410
|
+
std::vector<Float> ret(K2);
|
|
398
411
|
Float sum = this->globalState.numByTopic[k] + subAlphaSum[k];
|
|
399
|
-
|
|
400
|
-
|
|
412
|
+
if (!normalize) sum = 1;
|
|
413
|
+
Eigen::Map<Eigen::Array<Float, -1, 1>> m{ ret.data(), K2 };
|
|
414
|
+
m = (this->globalState.numByTopic1_2.row(k).array().template cast<Float>() + subAlphas.row(k).array()) / sum;
|
|
415
|
+
return ret;
|
|
401
416
|
}
|
|
402
417
|
|
|
403
418
|
std::vector<std::pair<Tid, Float>> getSubTopicBySuperTopicSorted(Tid k, size_t topN) const override
|
|
404
419
|
{
|
|
405
|
-
return extractTopN<Tid>(getSubTopicBySuperTopic(k), topN);
|
|
420
|
+
return extractTopN<Tid>(getSubTopicBySuperTopic(k, true), topN);
|
|
406
421
|
}
|
|
407
422
|
|
|
408
|
-
std::vector<Float> getSubTopicsByDoc(const _DocType& doc) const
|
|
423
|
+
std::vector<Float> getSubTopicsByDoc(const _DocType& doc, bool normalize) const
|
|
409
424
|
{
|
|
410
425
|
std::vector<Float> ret(K2);
|
|
411
|
-
Eigen::Map<Eigen::
|
|
412
|
-
|
|
426
|
+
Eigen::Map<Eigen::Array<Float, -1, 1>> m{ ret.data(), K2 };
|
|
427
|
+
if (normalize)
|
|
428
|
+
{
|
|
429
|
+
m = ((doc.numByTopic1_2.array().template cast<Float>() + subAlphas.array()).colwise().sum()) / (doc.getSumWordWeight() + subAlphas.sum());
|
|
430
|
+
}
|
|
431
|
+
else
|
|
432
|
+
{
|
|
433
|
+
m = (doc.numByTopic1_2.array().template cast<Float>() + subAlphas.array()).colwise().sum();
|
|
434
|
+
}
|
|
413
435
|
return ret;
|
|
414
436
|
}
|
|
415
437
|
|
|
416
|
-
std::vector<Float> getSubTopicsByDoc(const DocumentBase* doc) const override
|
|
438
|
+
std::vector<Float> getSubTopicsByDoc(const DocumentBase* doc, bool normalize) const override
|
|
417
439
|
{
|
|
418
|
-
return static_cast<const DerivedClass*>(this)->getSubTopicsByDoc(*static_cast<const _DocType*>(doc));
|
|
440
|
+
return static_cast<const DerivedClass*>(this)->getSubTopicsByDoc(*static_cast<const _DocType*>(doc), normalize);
|
|
419
441
|
}
|
|
420
442
|
|
|
421
443
|
std::vector<std::pair<Tid, Float>> getSubTopicsByDocSorted(const DocumentBase* doc, size_t topN) const override
|
|
422
444
|
{
|
|
423
|
-
return extractTopN<Tid>(getSubTopicsByDoc(doc), topN);
|
|
445
|
+
return extractTopN<Tid>(getSubTopicsByDoc(doc, true), topN);
|
|
424
446
|
}
|
|
425
447
|
|
|
426
|
-
std::vector<Float> _getWidsByTopic(Tid k2) const
|
|
448
|
+
std::vector<Float> _getWidsByTopic(Tid k2, bool normalize = true) const
|
|
427
449
|
{
|
|
428
450
|
assert(k2 < K2);
|
|
429
451
|
const size_t V = this->realV;
|
|
430
452
|
std::vector<Float> ret(V);
|
|
431
453
|
Float sum = this->globalState.numByTopic2[k2] + V * this->eta;
|
|
454
|
+
if (!normalize) sum = 1;
|
|
432
455
|
auto r = this->globalState.numByTopicWord.row(k2);
|
|
433
456
|
for (size_t v = 0; v < V; ++v)
|
|
434
457
|
{
|
|
@@ -439,10 +462,10 @@ namespace tomoto
|
|
|
439
462
|
|
|
440
463
|
void setWordPrior(const std::string& word, const std::vector<Float>& priors) override
|
|
441
464
|
{
|
|
442
|
-
if (priors.size() != K2) THROW_ERROR_WITH_INFO(
|
|
465
|
+
if (priors.size() != K2) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "priors.size() must be equal to K2.");
|
|
443
466
|
for (auto p : priors)
|
|
444
467
|
{
|
|
445
|
-
if (p < 0) THROW_ERROR_WITH_INFO(
|
|
468
|
+
if (p < 0) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "priors must not be less than 0.");
|
|
446
469
|
}
|
|
447
470
|
this->dict.add(word);
|
|
448
471
|
this->etaByWord.emplace(word, priors);
|
|
@@ -3,13 +3,24 @@
|
|
|
3
3
|
|
|
4
4
|
namespace tomoto
|
|
5
5
|
{
|
|
6
|
+
struct PLDAArgs : public LDAArgs
|
|
7
|
+
{
|
|
8
|
+
size_t numLatentTopics = 0;
|
|
9
|
+
size_t numTopicsPerLabel = 1;
|
|
10
|
+
|
|
11
|
+
PLDAArgs setK(size_t _k = 1) const
|
|
12
|
+
{
|
|
13
|
+
PLDAArgs ret = *this;
|
|
14
|
+
ret.k = _k;
|
|
15
|
+
return ret;
|
|
16
|
+
}
|
|
17
|
+
};
|
|
6
18
|
|
|
7
19
|
class IPLDAModel : public ILLDAModel
|
|
8
20
|
{
|
|
9
21
|
public:
|
|
10
22
|
using DefaultDocType = DocumentLLDA<TermWeight::one>;
|
|
11
|
-
static IPLDAModel* create(TermWeight _weight,
|
|
12
|
-
Float alpha = 0.1, Float eta = 0.01, size_t seed = std::random_device{}(),
|
|
23
|
+
static IPLDAModel* create(TermWeight _weight, const PLDAArgs& args,
|
|
13
24
|
bool scalarRng = false);
|
|
14
25
|
|
|
15
26
|
virtual size_t getNumLatentTopics() const = 0;
|
|
@@ -2,12 +2,8 @@
|
|
|
2
2
|
|
|
3
3
|
namespace tomoto
|
|
4
4
|
{
|
|
5
|
-
|
|
6
|
-
template class PLDAModel<TermWeight::idf>;
|
|
7
|
-
template class PLDAModel<TermWeight::pmi>;*/
|
|
8
|
-
|
|
9
|
-
IPLDAModel* IPLDAModel::create(TermWeight _weight, size_t _numLatentTopics, size_t _numTopicsPerLabel, Float _alpha, Float _eta, size_t seed, bool scalarRng)
|
|
5
|
+
IPLDAModel* IPLDAModel::create(TermWeight _weight, const PLDAArgs& args, bool scalarRng)
|
|
10
6
|
{
|
|
11
|
-
TMT_SWITCH_TW(_weight, scalarRng, PLDAModel,
|
|
7
|
+
TMT_SWITCH_TW(_weight, scalarRng, PLDAModel, args);
|
|
12
8
|
}
|
|
13
9
|
}
|
|
@@ -75,14 +75,16 @@ namespace tomoto
|
|
|
75
75
|
|
|
76
76
|
struct Generator
|
|
77
77
|
{
|
|
78
|
-
|
|
78
|
+
Eigen::Array<Float, -1, 1> p;
|
|
79
|
+
Eigen::Rand::DiscreteGen<int32_t> theta;
|
|
79
80
|
};
|
|
80
81
|
|
|
81
82
|
Generator makeGeneratorForInit(const _DocType* doc) const
|
|
82
83
|
{
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
};
|
|
84
|
+
Generator g;
|
|
85
|
+
g.p = doc->labelMask.array().template cast<Float>() * this->alphas.array();
|
|
86
|
+
g.theta = Eigen::Rand::DiscreteGen<int32_t>{ g.p.data(), g.p.data() + this->K };
|
|
87
|
+
return g;
|
|
86
88
|
}
|
|
87
89
|
|
|
88
90
|
template<bool _Infer>
|
|
@@ -93,7 +95,7 @@ namespace tomoto
|
|
|
93
95
|
if (this->etaByTopicWord.size())
|
|
94
96
|
{
|
|
95
97
|
Eigen::Array<Float, -1, 1> col = this->etaByTopicWord.col(w);
|
|
96
|
-
|
|
98
|
+
col *= g.p;
|
|
97
99
|
z = sample::sampleFromDiscrete(col.data(), col.data() + col.size(), rgs);
|
|
98
100
|
}
|
|
99
101
|
else
|
|
@@ -107,15 +109,14 @@ namespace tomoto
|
|
|
107
109
|
DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, topicLabelDict, numLatentTopics, numTopicsPerLabel);
|
|
108
110
|
DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, topicLabelDict, numLatentTopics, numTopicsPerLabel);
|
|
109
111
|
|
|
110
|
-
PLDAModel(
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
numLatentTopics(_numLatentTopics), numTopicsPerLabel(_numTopicsPerLabel)
|
|
112
|
+
PLDAModel(const PLDAArgs& args)
|
|
113
|
+
: BaseClass(args.setK(1)),
|
|
114
|
+
numLatentTopics(args.numLatentTopics), numTopicsPerLabel(args.numTopicsPerLabel)
|
|
114
115
|
{
|
|
115
|
-
if (
|
|
116
|
-
THROW_ERROR_WITH_INFO(
|
|
117
|
-
if (
|
|
118
|
-
THROW_ERROR_WITH_INFO(
|
|
116
|
+
if (numLatentTopics >= 0x80000000)
|
|
117
|
+
THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong numLatentTopics value (numLatentTopics = %zd)", numLatentTopics));
|
|
118
|
+
if (numTopicsPerLabel == 0 || numTopicsPerLabel >= 0x80000000)
|
|
119
|
+
THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong numTopicsPerLabel value (numTopicsPerLabel = %zd)", numTopicsPerLabel));
|
|
119
120
|
}
|
|
120
121
|
|
|
121
122
|
template<bool _const = false>
|
|
@@ -162,7 +163,7 @@ namespace tomoto
|
|
|
162
163
|
std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) const override
|
|
163
164
|
{
|
|
164
165
|
auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc, tokenizer);
|
|
165
|
-
return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
|
|
166
|
+
return std::make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
|
|
166
167
|
}
|
|
167
168
|
|
|
168
169
|
size_t addDoc(const RawDoc& rawDoc) override
|
|
@@ -174,16 +175,23 @@ namespace tomoto
|
|
|
174
175
|
std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc) const override
|
|
175
176
|
{
|
|
176
177
|
auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc);
|
|
177
|
-
return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
|
|
178
|
+
return std::make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
|
|
178
179
|
}
|
|
179
180
|
|
|
180
|
-
std::vector<Float> getTopicsByDoc(const _DocType& doc) const
|
|
181
|
+
std::vector<Float> getTopicsByDoc(const _DocType& doc, bool normalize) const
|
|
181
182
|
{
|
|
182
183
|
std::vector<Float> ret(this->K);
|
|
183
184
|
auto maskedAlphas = this->alphas.array() * doc.labelMask.template cast<Float>().array();
|
|
184
|
-
Eigen::Map<Eigen::
|
|
185
|
-
|
|
186
|
-
|
|
185
|
+
Eigen::Map<Eigen::Array<Float, -1, 1>> m{ ret.data(), this->K };
|
|
186
|
+
if (normalize)
|
|
187
|
+
{
|
|
188
|
+
m = (doc.numByTopic.array().template cast<Float>() + maskedAlphas)
|
|
189
|
+
/ (doc.getSumWordWeight() + maskedAlphas.sum());
|
|
190
|
+
}
|
|
191
|
+
else
|
|
192
|
+
{
|
|
193
|
+
m = doc.numByTopic.array().template cast<Float>() + maskedAlphas;
|
|
194
|
+
}
|
|
187
195
|
return ret;
|
|
188
196
|
}
|
|
189
197
|
|