tomoto 0.1.4 → 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/ext/tomoto/ct.cpp +8 -4
- data/ext/tomoto/dmr.cpp +10 -4
- data/ext/tomoto/dt.cpp +13 -4
- data/ext/tomoto/extconf.rb +1 -1
- data/ext/tomoto/gdmr.cpp +14 -6
- data/ext/tomoto/hdp.cpp +9 -4
- data/ext/tomoto/hlda.cpp +9 -4
- data/ext/tomoto/hpa.cpp +9 -4
- data/ext/tomoto/lda.cpp +8 -4
- data/ext/tomoto/llda.cpp +8 -4
- data/ext/tomoto/mglda.cpp +11 -1
- data/ext/tomoto/pa.cpp +9 -4
- data/ext/tomoto/plda.cpp +8 -4
- data/ext/tomoto/slda.cpp +13 -5
- data/lib/tomoto/gdmr.rb +2 -2
- data/lib/tomoto/version.rb +1 -1
- data/vendor/EigenRand/EigenRand/Core.h +6 -1107
- data/vendor/EigenRand/EigenRand/Dists/Basic.h +490 -43
- data/vendor/EigenRand/EigenRand/Dists/Discrete.h +916 -285
- data/vendor/EigenRand/EigenRand/Dists/GammaPoisson.h +85 -36
- data/vendor/EigenRand/EigenRand/Dists/NormalExp.h +1038 -290
- data/vendor/EigenRand/EigenRand/EigenRand +2 -2
- data/vendor/EigenRand/EigenRand/Macro.h +4 -4
- data/vendor/EigenRand/EigenRand/MorePacketMath.h +54 -22
- data/vendor/EigenRand/EigenRand/MvDists/Multinomial.h +222 -0
- data/vendor/EigenRand/EigenRand/MvDists/MvNormal.h +492 -0
- data/vendor/EigenRand/EigenRand/PacketFilter.h +2 -2
- data/vendor/EigenRand/EigenRand/PacketRandomEngine.h +2 -2
- data/vendor/EigenRand/EigenRand/RandUtils.h +65 -11
- data/vendor/EigenRand/EigenRand/doc.h +142 -25
- data/vendor/EigenRand/LICENSE +1 -1
- data/vendor/EigenRand/README.md +109 -24
- data/vendor/tomotopy/README.kr.rst +27 -6
- data/vendor/tomotopy/README.rst +29 -8
- data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +60 -12
- data/vendor/tomotopy/src/Labeling/FoRelevance.h +2 -2
- data/vendor/tomotopy/src/Labeling/Phraser.hpp +33 -21
- data/vendor/tomotopy/src/TopicModel/CT.h +8 -5
- data/vendor/tomotopy/src/TopicModel/CTModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/CTModel.hpp +29 -23
- data/vendor/tomotopy/src/TopicModel/DMR.h +33 -4
- data/vendor/tomotopy/src/TopicModel/DMRModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/DMRModel.hpp +231 -57
- data/vendor/tomotopy/src/TopicModel/DT.h +24 -5
- data/vendor/tomotopy/src/TopicModel/DTModel.cpp +2 -8
- data/vendor/tomotopy/src/TopicModel/DTModel.hpp +41 -28
- data/vendor/tomotopy/src/TopicModel/GDMR.h +31 -5
- data/vendor/tomotopy/src/TopicModel/GDMRModel.cpp +2 -7
- data/vendor/tomotopy/src/TopicModel/GDMRModel.hpp +211 -104
- data/vendor/tomotopy/src/TopicModel/HDP.h +11 -2
- data/vendor/tomotopy/src/TopicModel/HDPModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +52 -45
- data/vendor/tomotopy/src/TopicModel/HLDA.h +11 -2
- data/vendor/tomotopy/src/TopicModel/HLDAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +13 -16
- data/vendor/tomotopy/src/TopicModel/HPA.h +5 -2
- data/vendor/tomotopy/src/TopicModel/HPAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +51 -21
- data/vendor/tomotopy/src/TopicModel/LDA.h +9 -2
- data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +8 -8
- data/vendor/tomotopy/src/TopicModel/LDAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +70 -28
- data/vendor/tomotopy/src/TopicModel/LLDA.h +1 -2
- data/vendor/tomotopy/src/TopicModel/LLDAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/LLDAModel.hpp +22 -12
- data/vendor/tomotopy/src/TopicModel/MGLDA.h +12 -3
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.cpp +2 -10
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +42 -19
- data/vendor/tomotopy/src/TopicModel/PA.h +9 -4
- data/vendor/tomotopy/src/TopicModel/PAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/PAModel.hpp +48 -25
- data/vendor/tomotopy/src/TopicModel/PLDA.h +13 -2
- data/vendor/tomotopy/src/TopicModel/PLDAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/PLDAModel.hpp +27 -19
- data/vendor/tomotopy/src/TopicModel/PT.h +12 -5
- data/vendor/tomotopy/src/TopicModel/PTModel.cpp +2 -3
- data/vendor/tomotopy/src/TopicModel/PTModel.hpp +29 -14
- data/vendor/tomotopy/src/TopicModel/SLDA.h +18 -6
- data/vendor/tomotopy/src/TopicModel/SLDAModel.cpp +2 -10
- data/vendor/tomotopy/src/TopicModel/SLDAModel.hpp +93 -43
- data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +58 -23
- data/vendor/tomotopy/src/Utils/AliasMethod.hpp +6 -6
- data/vendor/tomotopy/src/Utils/Dictionary.h +11 -0
- data/vendor/tomotopy/src/Utils/SharedString.hpp +26 -1
- data/vendor/tomotopy/src/Utils/Trie.hpp +46 -21
- data/vendor/tomotopy/src/Utils/Utils.hpp +99 -14
- data/vendor/tomotopy/src/Utils/exception.h +1 -1
- data/vendor/tomotopy/src/Utils/math.h +5 -7
- data/vendor/tomotopy/src/Utils/serializer.hpp +329 -201
- data/vendor/tomotopy/src/Utils/text.hpp +8 -0
- data/vendor/tomotopy/src/Utils/tvector.hpp +49 -7
- metadata +9 -7
|
@@ -56,12 +56,21 @@ namespace tomoto
|
|
|
56
56
|
template<typename _TopicModel> void update(WeightType* ptr, const _TopicModel& mdl);
|
|
57
57
|
};
|
|
58
58
|
|
|
59
|
+
struct HDPArgs : public LDAArgs
|
|
60
|
+
{
|
|
61
|
+
Float gamma = 0.1;
|
|
62
|
+
|
|
63
|
+
HDPArgs()
|
|
64
|
+
{
|
|
65
|
+
k = 2;
|
|
66
|
+
}
|
|
67
|
+
};
|
|
68
|
+
|
|
59
69
|
class IHDPModel : public ILDAModel
|
|
60
70
|
{
|
|
61
71
|
public:
|
|
62
72
|
using DefaultDocType = DocumentHDP<TermWeight::one>;
|
|
63
|
-
static IHDPModel* create(TermWeight _weight,
|
|
64
|
-
Float alpha = 0.1, Float eta = 0.01, Float gamma = 0.1, size_t seed = std::random_device{}(),
|
|
73
|
+
static IHDPModel* create(TermWeight _weight, const HDPArgs& args,
|
|
65
74
|
bool scalarRng = false);
|
|
66
75
|
|
|
67
76
|
virtual Float getGamma() const = 0;
|
|
@@ -2,12 +2,8 @@
|
|
|
2
2
|
|
|
3
3
|
namespace tomoto
|
|
4
4
|
{
|
|
5
|
-
|
|
6
|
-
template class HDPModel<TermWeight::idf>;
|
|
7
|
-
template class HDPModel<TermWeight::pmi>;*/
|
|
8
|
-
|
|
9
|
-
IHDPModel* IHDPModel::create(TermWeight _weight, size_t _K, Float _alpha , Float _eta, Float _gamma, size_t seed, bool scalarRng)
|
|
5
|
+
IHDPModel* IHDPModel::create(TermWeight _weight, const HDPArgs& args, bool scalarRng)
|
|
10
6
|
{
|
|
11
|
-
TMT_SWITCH_TW(_weight, scalarRng, HDPModel,
|
|
7
|
+
TMT_SWITCH_TW(_weight, scalarRng, HDPModel, args);
|
|
12
8
|
}
|
|
13
9
|
}
|
|
@@ -14,7 +14,7 @@ namespace tomoto
|
|
|
14
14
|
template<TermWeight _tw>
|
|
15
15
|
struct ModelStateHDP : public ModelStateLDA<_tw>
|
|
16
16
|
{
|
|
17
|
-
|
|
17
|
+
Vector tableLikelihood, topicLikelihood;
|
|
18
18
|
Eigen::Matrix<int32_t, -1, 1> numTableByTopic;
|
|
19
19
|
size_t totalTable = 0;
|
|
20
20
|
|
|
@@ -397,58 +397,47 @@ namespace tomoto
|
|
|
397
397
|
|
|
398
398
|
void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
|
|
399
399
|
{
|
|
400
|
+
sortAndWriteOrder(doc.words, doc.wOrder);
|
|
400
401
|
doc.numByTopic.init(nullptr, this->K, 1);
|
|
401
402
|
doc.numTopicByTable.clear();
|
|
402
|
-
doc.Zs = tvector<Tid>(wordSize);
|
|
403
|
+
doc.Zs = tvector<Tid>(wordSize, non_topic_id);
|
|
403
404
|
if (_tw != TermWeight::one) doc.wordWeights.resize(wordSize);
|
|
404
405
|
}
|
|
405
406
|
|
|
406
|
-
template<bool
|
|
407
|
+
template<bool _infer>
|
|
407
408
|
void updateStateWithDoc(typename BaseClass::Generator& g, _ModelState& ld, _RandGen& rgs, _DocType& doc, size_t i) const
|
|
408
409
|
{
|
|
409
|
-
|
|
410
|
-
|
|
410
|
+
Tid t;
|
|
411
|
+
std::vector<double> dist;
|
|
412
|
+
dist.emplace_back(this->alpha);
|
|
413
|
+
for (auto& d : doc.numTopicByTable) dist.emplace_back(d.num);
|
|
414
|
+
std::discrete_distribution<Tid> ddist{ dist.begin(), dist.end() };
|
|
415
|
+
t = ddist(rgs);
|
|
416
|
+
if (t == 0)
|
|
411
417
|
{
|
|
412
|
-
|
|
418
|
+
// new table
|
|
419
|
+
Tid k;
|
|
420
|
+
if (_infer)
|
|
413
421
|
{
|
|
414
|
-
|
|
415
|
-
|
|
422
|
+
std::uniform_int_distribution<> theta{ 0, this->K - 1 };
|
|
423
|
+
do
|
|
416
424
|
{
|
|
417
|
-
|
|
418
|
-
}
|
|
419
|
-
else
|
|
420
|
-
{
|
|
421
|
-
t = std::uniform_int_distribution<size_t>{ 0, doc.getNumTable() - 1 }(rgs);
|
|
422
|
-
}
|
|
423
|
-
++ld.numTableByTopic[doc.numTopicByTable[t].topic];
|
|
424
|
-
++ld.totalTable;
|
|
425
|
-
doc.Zs[i] = t;
|
|
426
|
-
}
|
|
427
|
-
else doc.Zs[i] = std::uniform_int_distribution<size_t>{ 0, doc.getNumTable() - 1 }(rgs);
|
|
428
|
-
}
|
|
429
|
-
// generate tables following CRP
|
|
430
|
-
else
|
|
431
|
-
{
|
|
432
|
-
Tid t;
|
|
433
|
-
std::vector<double> dist;
|
|
434
|
-
dist.emplace_back(this->alpha);
|
|
435
|
-
for (auto& d : doc.numTopicByTable) dist.emplace_back(d.num);
|
|
436
|
-
std::discrete_distribution<Tid> ddist{ dist.begin(), dist.end() };
|
|
437
|
-
t = ddist(rgs);
|
|
438
|
-
if (t == 0)
|
|
439
|
-
{
|
|
440
|
-
// new table
|
|
441
|
-
Tid k = g.theta(rgs);
|
|
442
|
-
t = doc.addNewTable(k);
|
|
443
|
-
++ld.numTableByTopic[k];
|
|
444
|
-
++ld.totalTable;
|
|
425
|
+
k = theta(rgs);
|
|
426
|
+
} while (!isLiveTopic(k));
|
|
445
427
|
}
|
|
446
428
|
else
|
|
447
429
|
{
|
|
448
|
-
|
|
430
|
+
k = g.theta(rgs);
|
|
449
431
|
}
|
|
450
|
-
|
|
432
|
+
t = doc.addNewTable(k);
|
|
433
|
+
++ld.numTableByTopic[k];
|
|
434
|
+
++ld.totalTable;
|
|
451
435
|
}
|
|
436
|
+
else
|
|
437
|
+
{
|
|
438
|
+
t -= 1;
|
|
439
|
+
}
|
|
440
|
+
doc.Zs[i] = t;
|
|
452
441
|
addWordTo<1>(ld, doc, i, doc.words[i], doc.Zs[i], doc.numTopicByTable[doc.Zs[i]].topic);
|
|
453
442
|
}
|
|
454
443
|
|
|
@@ -469,10 +458,11 @@ namespace tomoto
|
|
|
469
458
|
DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, gamma);
|
|
470
459
|
DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, gamma);
|
|
471
460
|
|
|
472
|
-
HDPModel(
|
|
473
|
-
: BaseClass(
|
|
461
|
+
HDPModel(const HDPArgs& args)
|
|
462
|
+
: BaseClass(args), gamma(args.gamma)
|
|
474
463
|
{
|
|
475
|
-
if (
|
|
464
|
+
if (gamma <= 0) THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong gamma value (gamma = %f)", gamma));
|
|
465
|
+
if (args.alpha.size() > 1) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "Asymmetric alpha is not supported at HDP.");
|
|
476
466
|
}
|
|
477
467
|
|
|
478
468
|
size_t getTotalTables() const override
|
|
@@ -497,13 +487,21 @@ namespace tomoto
|
|
|
497
487
|
|
|
498
488
|
void setWordPrior(const std::string& word, const std::vector<Float>& priors) override
|
|
499
489
|
{
|
|
500
|
-
THROW_ERROR_WITH_INFO(
|
|
490
|
+
THROW_ERROR_WITH_INFO(exc::Unimplemented, "HDPModel doesn't provide setWordPrior function.");
|
|
501
491
|
}
|
|
502
492
|
|
|
503
|
-
std::vector<Float> getTopicsByDoc(const _DocType& doc) const
|
|
493
|
+
std::vector<Float> getTopicsByDoc(const _DocType& doc, bool normalize) const
|
|
504
494
|
{
|
|
505
495
|
std::vector<Float> ret(this->K);
|
|
506
|
-
Eigen::Map<Eigen::
|
|
496
|
+
Eigen::Map<Eigen::Array<Float, -1, 1>> m{ ret.data(), this->K };
|
|
497
|
+
if (normalize)
|
|
498
|
+
{
|
|
499
|
+
m = doc.numByTopic.array().template cast<Float>() / doc.getSumWordWeight();
|
|
500
|
+
}
|
|
501
|
+
else
|
|
502
|
+
{
|
|
503
|
+
m = doc.numByTopic.array().template cast<Float>();
|
|
504
|
+
}
|
|
507
505
|
return ret;
|
|
508
506
|
}
|
|
509
507
|
|
|
@@ -528,7 +526,11 @@ namespace tomoto
|
|
|
528
526
|
liveK++;
|
|
529
527
|
}
|
|
530
528
|
|
|
531
|
-
|
|
529
|
+
LDAArgs args;
|
|
530
|
+
args.k = liveK;
|
|
531
|
+
args.alpha[0] = 0.1f;
|
|
532
|
+
args.eta = this->eta;
|
|
533
|
+
auto lda = std::make_unique<LDAModel<_tw, _RandGen>>(args);
|
|
532
534
|
lda->dict = this->dict;
|
|
533
535
|
|
|
534
536
|
for (auto& doc : this->docs)
|
|
@@ -551,6 +553,11 @@ namespace tomoto
|
|
|
551
553
|
{
|
|
552
554
|
for (size_t j = 0; j < this->docs[i].Zs.size(); ++j)
|
|
553
555
|
{
|
|
556
|
+
if (this->docs[i].Zs[j] == non_topic_id)
|
|
557
|
+
{
|
|
558
|
+
lda->docs[i].Zs[j] = non_topic_id;
|
|
559
|
+
continue;
|
|
560
|
+
}
|
|
554
561
|
size_t newTopic = newK[this->docs[i].numTopicByTable[this->docs[i].Zs[j]].topic];
|
|
555
562
|
while (newTopic == (Tid)-1) newTopic = newK[randomTopic(rng)];
|
|
556
563
|
lda->docs[i].Zs[j] = newTopic;
|
|
@@ -20,12 +20,21 @@ namespace tomoto
|
|
|
20
20
|
DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 1, 0x00010001, path);
|
|
21
21
|
};
|
|
22
22
|
|
|
23
|
+
struct HLDAArgs : public LDAArgs
|
|
24
|
+
{
|
|
25
|
+
Float gamma = 0.1;
|
|
26
|
+
|
|
27
|
+
HLDAArgs()
|
|
28
|
+
{
|
|
29
|
+
k = 2;
|
|
30
|
+
}
|
|
31
|
+
};
|
|
32
|
+
|
|
23
33
|
class IHLDAModel : public ILDAModel
|
|
24
34
|
{
|
|
25
35
|
public:
|
|
26
36
|
using DefaultDocType = DocumentHLDA<TermWeight::one>;
|
|
27
|
-
static IHLDAModel* create(TermWeight _weight,
|
|
28
|
-
Float alpha = 0.1, Float eta = 0.01, Float gamma = 0.1, size_t seed = std::random_device{}(),
|
|
37
|
+
static IHLDAModel* create(TermWeight _weight, const HLDAArgs& args,
|
|
29
38
|
bool scalarRng = false);
|
|
30
39
|
|
|
31
40
|
virtual Float getGamma() const = 0;
|
|
@@ -2,12 +2,8 @@
|
|
|
2
2
|
|
|
3
3
|
namespace tomoto
|
|
4
4
|
{
|
|
5
|
-
|
|
6
|
-
template class HLDAModel<TermWeight::idf>;
|
|
7
|
-
template class HLDAModel<TermWeight::pmi>;*/
|
|
8
|
-
|
|
9
|
-
IHLDAModel* IHLDAModel::create(TermWeight _weight, size_t levelDepth, Float _alpha, Float _eta, Float _gamma, size_t seed, bool scalarRng)
|
|
5
|
+
IHLDAModel* IHLDAModel::create(TermWeight _weight, const HLDAArgs& args, bool scalarRng)
|
|
10
6
|
{
|
|
11
|
-
TMT_SWITCH_TW(_weight, scalarRng, HLDAModel,
|
|
7
|
+
TMT_SWITCH_TW(_weight, scalarRng, HLDAModel, args);
|
|
12
8
|
}
|
|
13
9
|
}
|
|
@@ -114,8 +114,8 @@ namespace tomoto
|
|
|
114
114
|
static constexpr size_t blockSize = 8;
|
|
115
115
|
std::vector<NCRPNode> nodes;
|
|
116
116
|
std::vector<uint8_t> levelBlocks;
|
|
117
|
-
|
|
118
|
-
|
|
117
|
+
Vector nodeLikelihoods; //
|
|
118
|
+
Vector nodeWLikelihoods; //
|
|
119
119
|
|
|
120
120
|
DEFINE_SERIALIZER(nodes, levelBlocks);
|
|
121
121
|
|
|
@@ -351,6 +351,8 @@ namespace tomoto
|
|
|
351
351
|
template<GlobalSampler _gs>
|
|
352
352
|
void samplePathes(_DocType& doc, ThreadPool* pool, _ModelState& ld, _RandGen& rgs) const
|
|
353
353
|
{
|
|
354
|
+
if (!doc.getSumWordWeight()) return;
|
|
355
|
+
|
|
354
356
|
if(_gs != GlobalSampler::inference) ld.nt->nodes[doc.path.back()].dropPathOne();
|
|
355
357
|
ld.nt->template calcNodeLikelihood<_gs == GlobalSampler::train>(gamma, this->K);
|
|
356
358
|
|
|
@@ -433,7 +435,7 @@ namespace tomoto
|
|
|
433
435
|
template<bool _asymEta>
|
|
434
436
|
Float* getZLikelihoods(_ModelState& ld, const _DocType& doc, size_t docId, size_t vid) const
|
|
435
437
|
{
|
|
436
|
-
if (_asymEta) THROW_ERROR_WITH_INFO(
|
|
438
|
+
if (_asymEta) THROW_ERROR_WITH_INFO(exc::Unimplemented, "Unimplemented features");
|
|
437
439
|
const size_t V = this->realV;
|
|
438
440
|
assert(vid < V);
|
|
439
441
|
auto& zLikelihood = ld.zLikelihood;
|
|
@@ -461,7 +463,6 @@ namespace tomoto
|
|
|
461
463
|
double getLLDocs(_DocIter _first, _DocIter _last) const
|
|
462
464
|
{
|
|
463
465
|
double ll = 0;
|
|
464
|
-
auto lgammaAlpha = math::lgammaT(this->alpha);
|
|
465
466
|
for (; _first != _last; ++_first)
|
|
466
467
|
{
|
|
467
468
|
auto& doc = *_first;
|
|
@@ -472,13 +473,9 @@ namespace tomoto
|
|
|
472
473
|
}
|
|
473
474
|
|
|
474
475
|
// doc-level distribution
|
|
475
|
-
ll -= math::
|
|
476
|
-
|
|
477
|
-
{
|
|
478
|
-
ll += math::lgammaT(doc.numByTopic[l] + this->alpha) - lgammaAlpha;
|
|
479
|
-
}
|
|
476
|
+
ll -= math::lgammaSubt(this->alphas.sum(), doc.getSumWordWeight());
|
|
477
|
+
ll += math::lgammaSubt(this->alphas.array(), doc.numByTopic.template cast<Float>().array()).sum();
|
|
480
478
|
}
|
|
481
|
-
ll += math::lgammaT(this->alpha * this->K) * std::distance(_first, _last);
|
|
482
479
|
return ll;
|
|
483
480
|
}
|
|
484
481
|
|
|
@@ -521,7 +518,7 @@ namespace tomoto
|
|
|
521
518
|
{
|
|
522
519
|
sortAndWriteOrder(doc.words, doc.wOrder);
|
|
523
520
|
doc.numByTopic.init(nullptr, this->K, 1);
|
|
524
|
-
doc.Zs = tvector<Tid>(wordSize);
|
|
521
|
+
doc.Zs = tvector<Tid>(wordSize, non_topic_id);
|
|
525
522
|
doc.path.resize(this->K);
|
|
526
523
|
for (size_t l = 0; l < this->K; ++l) doc.path[l] = l;
|
|
527
524
|
|
|
@@ -597,11 +594,11 @@ namespace tomoto
|
|
|
597
594
|
DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, gamma);
|
|
598
595
|
DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, gamma);
|
|
599
596
|
|
|
600
|
-
HLDAModel(
|
|
601
|
-
: BaseClass(
|
|
597
|
+
HLDAModel(const HLDAArgs& args)
|
|
598
|
+
: BaseClass(args), gamma(args.gamma)
|
|
602
599
|
{
|
|
603
|
-
if (
|
|
604
|
-
if (
|
|
600
|
+
if (args.k == 0 || args.k >= 0x80000000) THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong levelDepth value (levelDepth = %zd)", args.k));
|
|
601
|
+
if (gamma <= 0) THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong gamma value (gamma = %f)", gamma));
|
|
605
602
|
this->globalState.nt = std::make_shared<detail::NodeTrees>();
|
|
606
603
|
}
|
|
607
604
|
|
|
@@ -661,7 +658,7 @@ namespace tomoto
|
|
|
661
658
|
|
|
662
659
|
void setWordPrior(const std::string& word, const std::vector<Float>& priors) override
|
|
663
660
|
{
|
|
664
|
-
THROW_ERROR_WITH_INFO(
|
|
661
|
+
THROW_ERROR_WITH_INFO(exc::Unimplemented, "HLDAModel doesn't provide setWordPrior function.");
|
|
665
662
|
}
|
|
666
663
|
};
|
|
667
664
|
|
|
@@ -16,12 +16,15 @@ namespace tomoto
|
|
|
16
16
|
DEFINE_SERIALIZER_BASE_WITH_VERSION(BaseDocument, 1);
|
|
17
17
|
};
|
|
18
18
|
|
|
19
|
+
struct HPAArgs : public PAArgs
|
|
20
|
+
{
|
|
21
|
+
};
|
|
22
|
+
|
|
19
23
|
class IHPAModel : public IPAModel
|
|
20
24
|
{
|
|
21
25
|
public:
|
|
22
26
|
using DefaultDocType = DocumentHPA<TermWeight::one>;
|
|
23
|
-
static IHPAModel* create(TermWeight _weight, bool _exclusive
|
|
24
|
-
Float _alpha = 50, Float _eta = 0.01, size_t seed = std::random_device{}(),
|
|
27
|
+
static IHPAModel* create(TermWeight _weight, bool _exclusive, const HPAArgs& args,
|
|
25
28
|
bool scalarRng = false);
|
|
26
29
|
};
|
|
27
30
|
}
|
|
@@ -2,11 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
namespace tomoto
|
|
4
4
|
{
|
|
5
|
-
|
|
6
|
-
template class HPAModel<TermWeight::idf>;
|
|
7
|
-
template class HPAModel<TermWeight::pmi>;*/
|
|
8
|
-
|
|
9
|
-
IHPAModel* IHPAModel::create(TermWeight _weight, bool _exclusive, size_t _K, size_t _K2, Float _alphaSum, Float _eta, size_t seed, bool scalarRng)
|
|
5
|
+
IHPAModel* IHPAModel::create(TermWeight _weight, bool _exclusive, const HPAArgs& args, bool scalarRng)
|
|
10
6
|
{
|
|
11
7
|
if (_exclusive)
|
|
12
8
|
{
|
|
@@ -14,7 +10,7 @@ namespace tomoto
|
|
|
14
10
|
}
|
|
15
11
|
else
|
|
16
12
|
{
|
|
17
|
-
TMT_SWITCH_TW(_weight, scalarRng, HPAModel,
|
|
13
|
+
TMT_SWITCH_TW(_weight, scalarRng, HPAModel, args);
|
|
18
14
|
}
|
|
19
15
|
return nullptr;
|
|
20
16
|
}
|
|
@@ -16,7 +16,7 @@ namespace tomoto
|
|
|
16
16
|
|
|
17
17
|
std::array<Eigen::Matrix<WeightType, -1, -1>, 3> numByTopicWord;
|
|
18
18
|
std::array<Eigen::Matrix<WeightType, -1, 1>, 3> numByTopic;
|
|
19
|
-
std::array<
|
|
19
|
+
std::array<Vector, 2> subTmp;
|
|
20
20
|
|
|
21
21
|
Eigen::Matrix<WeightType, -1, -1> numByTopic1_2;
|
|
22
22
|
|
|
@@ -45,10 +45,10 @@ namespace tomoto
|
|
|
45
45
|
Float epsilon = 0.00001;
|
|
46
46
|
size_t iteration = 5;
|
|
47
47
|
|
|
48
|
-
//
|
|
48
|
+
//Vector alphas; // len = (K + 1)
|
|
49
49
|
|
|
50
|
-
|
|
51
|
-
|
|
50
|
+
Vector subAlphaSum; // len = K
|
|
51
|
+
Matrix subAlphas; // len = K * (K2 + 1)
|
|
52
52
|
|
|
53
53
|
void optimizeParameters(ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
|
|
54
54
|
{
|
|
@@ -195,7 +195,7 @@ namespace tomoto
|
|
|
195
195
|
Float* dist;
|
|
196
196
|
if (this->etaByTopicWord.size())
|
|
197
197
|
{
|
|
198
|
-
THROW_ERROR_WITH_INFO(
|
|
198
|
+
THROW_ERROR_WITH_INFO(exc::Unimplemented, "Unimplemented features");
|
|
199
199
|
}
|
|
200
200
|
else
|
|
201
201
|
{
|
|
@@ -379,7 +379,7 @@ namespace tomoto
|
|
|
379
379
|
void initGlobalState(bool initDocs)
|
|
380
380
|
{
|
|
381
381
|
const size_t V = this->realV;
|
|
382
|
-
this->globalState.zLikelihood =
|
|
382
|
+
this->globalState.zLikelihood = Vector::Zero(1 + this->K + this->K * K2);
|
|
383
383
|
if (initDocs)
|
|
384
384
|
{
|
|
385
385
|
this->globalState.numByTopic1_2 = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K, K2 + 1);
|
|
@@ -440,13 +440,37 @@ namespace tomoto
|
|
|
440
440
|
DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, K2, subAlphas, subAlphaSum);
|
|
441
441
|
DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, K2, subAlphas, subAlphaSum);
|
|
442
442
|
|
|
443
|
-
HPAModel(
|
|
444
|
-
: BaseClass(
|
|
443
|
+
HPAModel(const HPAArgs& args)
|
|
444
|
+
: BaseClass(args, false), K2(args.k2)
|
|
445
445
|
{
|
|
446
|
-
if (
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
446
|
+
if (K2 == 0 || K2 >= 0x80000000) THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong K2 value (K2 = %zd)", K2));
|
|
447
|
+
|
|
448
|
+
if (args.alpha.size() == 1)
|
|
449
|
+
{
|
|
450
|
+
this->alphas = Vector::Constant(args.k + 1, args.alpha[0]);
|
|
451
|
+
}
|
|
452
|
+
else if (args.alpha.size() == args.k + 1)
|
|
453
|
+
{
|
|
454
|
+
this->alphas = Eigen::Map<const Vector>(args.alpha.data(), (Eigen::Index)args.alpha.size());
|
|
455
|
+
}
|
|
456
|
+
else
|
|
457
|
+
{
|
|
458
|
+
THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong alpha value (len = %zd)", args.alpha.size()));
|
|
459
|
+
}
|
|
460
|
+
|
|
461
|
+
if (args.subalpha.size() == 1)
|
|
462
|
+
{
|
|
463
|
+
subAlphas = Matrix::Constant(args.k, args.k2 + 1, args.subalpha[0]);
|
|
464
|
+
}
|
|
465
|
+
else if (args.subalpha.size() == args.k2 + 1)
|
|
466
|
+
{
|
|
467
|
+
subAlphas = Eigen::Map<const Eigen::Matrix<Float, 1, -1>>(args.subalpha.data(), args.subalpha.size()).replicate(args.k, 1);
|
|
468
|
+
}
|
|
469
|
+
else
|
|
470
|
+
{
|
|
471
|
+
THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong subalpha value (len = %zd)", args.subalpha.size()));
|
|
472
|
+
}
|
|
473
|
+
subAlphaSum = subAlphas.rowwise().sum();
|
|
450
474
|
this->optimInterval = 1;
|
|
451
475
|
}
|
|
452
476
|
|
|
@@ -455,7 +479,7 @@ namespace tomoto
|
|
|
455
479
|
|
|
456
480
|
void setDirichletEstIteration(size_t iter) override
|
|
457
481
|
{
|
|
458
|
-
if (!iter) throw
|
|
482
|
+
if (!iter) throw exc::InvalidArgument("iter must > 0");
|
|
459
483
|
iteration = iter;
|
|
460
484
|
}
|
|
461
485
|
|
|
@@ -475,20 +499,23 @@ namespace tomoto
|
|
|
475
499
|
return ret;
|
|
476
500
|
}
|
|
477
501
|
|
|
478
|
-
std::vector<Float> getSubTopicBySuperTopic(Tid k) const override
|
|
502
|
+
std::vector<Float> getSubTopicBySuperTopic(Tid k, bool normalize) const override
|
|
479
503
|
{
|
|
504
|
+
std::vector<Float> ret(K2);
|
|
480
505
|
assert(k < this->K);
|
|
481
506
|
Float sum = this->globalState.numByTopic1_2.row(k).sum() + subAlphaSum[k];
|
|
482
|
-
|
|
483
|
-
|
|
507
|
+
if (!normalize) sum = 1;
|
|
508
|
+
Eigen::Map<Eigen::Array<Float, -1, 1>> m{ ret.data(), (Eigen::Index)K2 };
|
|
509
|
+
m = (this->globalState.numByTopic1_2.row(k).segment(1, K2).array().template cast<Float>() + subAlphas.row(k).segment(1, K2).array()) / sum;
|
|
510
|
+
return ret;
|
|
484
511
|
}
|
|
485
512
|
|
|
486
513
|
std::vector<std::pair<Tid, Float>> getSubTopicBySuperTopicSorted(Tid k, size_t topN) const override
|
|
487
514
|
{
|
|
488
|
-
return extractTopN<Tid>(getSubTopicBySuperTopic(k), topN);
|
|
515
|
+
return extractTopN<Tid>(getSubTopicBySuperTopic(k, true), topN);
|
|
489
516
|
}
|
|
490
517
|
|
|
491
|
-
std::vector<Float> _getWidsByTopic(Tid k) const
|
|
518
|
+
std::vector<Float> _getWidsByTopic(Tid k, bool normalize = true) const
|
|
492
519
|
{
|
|
493
520
|
const size_t V = this->realV;
|
|
494
521
|
std::vector<Float> ret(V);
|
|
@@ -504,6 +531,7 @@ namespace tomoto
|
|
|
504
531
|
}
|
|
505
532
|
}
|
|
506
533
|
Float sum = this->globalState.numByTopic[level][k] + V * this->eta;
|
|
534
|
+
if (!normalize) sum = 1;
|
|
507
535
|
auto r = this->globalState.numByTopicWord[level].row(k);
|
|
508
536
|
for (size_t v = 0; v < V; ++v)
|
|
509
537
|
{
|
|
@@ -512,10 +540,12 @@ namespace tomoto
|
|
|
512
540
|
return ret;
|
|
513
541
|
}
|
|
514
542
|
|
|
515
|
-
std::vector<Float> getTopicsByDoc(const _DocType& doc) const
|
|
543
|
+
std::vector<Float> getTopicsByDoc(const _DocType& doc, bool normalize) const
|
|
516
544
|
{
|
|
517
545
|
std::vector<Float> ret(1 + this->K + K2);
|
|
518
546
|
Float sum = doc.getSumWordWeight() + this->alphas.sum();
|
|
547
|
+
if (!normalize) sum = 1;
|
|
548
|
+
|
|
519
549
|
ret[0] = (doc.numByTopic[0] + this->alphas[0]) / sum;
|
|
520
550
|
for (size_t k = 0; k < this->K; ++k)
|
|
521
551
|
{
|
|
@@ -528,7 +558,7 @@ namespace tomoto
|
|
|
528
558
|
return ret;
|
|
529
559
|
}
|
|
530
560
|
|
|
531
|
-
std::vector<Float> getSubTopicsByDoc(const DocumentBase* doc) const override
|
|
561
|
+
std::vector<Float> getSubTopicsByDoc(const DocumentBase* doc, bool normalize) const override
|
|
532
562
|
{
|
|
533
563
|
throw std::runtime_error{ "not applicable" };
|
|
534
564
|
}
|
|
@@ -540,7 +570,7 @@ namespace tomoto
|
|
|
540
570
|
|
|
541
571
|
void setWordPrior(const std::string& word, const std::vector<Float>& priors) override
|
|
542
572
|
{
|
|
543
|
-
THROW_ERROR_WITH_INFO(
|
|
573
|
+
THROW_ERROR_WITH_INFO(exc::Unimplemented, "HPAModel doesn't provide setWordPrior function.");
|
|
544
574
|
}
|
|
545
575
|
|
|
546
576
|
std::vector<uint64_t> getCountBySuperTopic() const override
|