tomoto 0.1.4 → 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/ext/tomoto/ct.cpp +8 -4
- data/ext/tomoto/dmr.cpp +10 -4
- data/ext/tomoto/dt.cpp +13 -4
- data/ext/tomoto/extconf.rb +1 -1
- data/ext/tomoto/gdmr.cpp +14 -6
- data/ext/tomoto/hdp.cpp +9 -4
- data/ext/tomoto/hlda.cpp +9 -4
- data/ext/tomoto/hpa.cpp +9 -4
- data/ext/tomoto/lda.cpp +8 -4
- data/ext/tomoto/llda.cpp +8 -4
- data/ext/tomoto/mglda.cpp +11 -1
- data/ext/tomoto/pa.cpp +9 -4
- data/ext/tomoto/plda.cpp +8 -4
- data/ext/tomoto/slda.cpp +13 -5
- data/lib/tomoto/gdmr.rb +2 -2
- data/lib/tomoto/version.rb +1 -1
- data/vendor/EigenRand/EigenRand/Core.h +6 -1107
- data/vendor/EigenRand/EigenRand/Dists/Basic.h +490 -43
- data/vendor/EigenRand/EigenRand/Dists/Discrete.h +916 -285
- data/vendor/EigenRand/EigenRand/Dists/GammaPoisson.h +85 -36
- data/vendor/EigenRand/EigenRand/Dists/NormalExp.h +1038 -290
- data/vendor/EigenRand/EigenRand/EigenRand +2 -2
- data/vendor/EigenRand/EigenRand/Macro.h +4 -4
- data/vendor/EigenRand/EigenRand/MorePacketMath.h +54 -22
- data/vendor/EigenRand/EigenRand/MvDists/Multinomial.h +222 -0
- data/vendor/EigenRand/EigenRand/MvDists/MvNormal.h +492 -0
- data/vendor/EigenRand/EigenRand/PacketFilter.h +2 -2
- data/vendor/EigenRand/EigenRand/PacketRandomEngine.h +2 -2
- data/vendor/EigenRand/EigenRand/RandUtils.h +65 -11
- data/vendor/EigenRand/EigenRand/doc.h +142 -25
- data/vendor/EigenRand/LICENSE +1 -1
- data/vendor/EigenRand/README.md +109 -24
- data/vendor/tomotopy/README.kr.rst +27 -6
- data/vendor/tomotopy/README.rst +29 -8
- data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +60 -12
- data/vendor/tomotopy/src/Labeling/FoRelevance.h +2 -2
- data/vendor/tomotopy/src/Labeling/Phraser.hpp +33 -21
- data/vendor/tomotopy/src/TopicModel/CT.h +8 -5
- data/vendor/tomotopy/src/TopicModel/CTModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/CTModel.hpp +29 -23
- data/vendor/tomotopy/src/TopicModel/DMR.h +33 -4
- data/vendor/tomotopy/src/TopicModel/DMRModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/DMRModel.hpp +231 -57
- data/vendor/tomotopy/src/TopicModel/DT.h +24 -5
- data/vendor/tomotopy/src/TopicModel/DTModel.cpp +2 -8
- data/vendor/tomotopy/src/TopicModel/DTModel.hpp +41 -28
- data/vendor/tomotopy/src/TopicModel/GDMR.h +31 -5
- data/vendor/tomotopy/src/TopicModel/GDMRModel.cpp +2 -7
- data/vendor/tomotopy/src/TopicModel/GDMRModel.hpp +211 -104
- data/vendor/tomotopy/src/TopicModel/HDP.h +11 -2
- data/vendor/tomotopy/src/TopicModel/HDPModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +52 -45
- data/vendor/tomotopy/src/TopicModel/HLDA.h +11 -2
- data/vendor/tomotopy/src/TopicModel/HLDAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +13 -16
- data/vendor/tomotopy/src/TopicModel/HPA.h +5 -2
- data/vendor/tomotopy/src/TopicModel/HPAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +51 -21
- data/vendor/tomotopy/src/TopicModel/LDA.h +9 -2
- data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +8 -8
- data/vendor/tomotopy/src/TopicModel/LDAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +70 -28
- data/vendor/tomotopy/src/TopicModel/LLDA.h +1 -2
- data/vendor/tomotopy/src/TopicModel/LLDAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/LLDAModel.hpp +22 -12
- data/vendor/tomotopy/src/TopicModel/MGLDA.h +12 -3
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.cpp +2 -10
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +42 -19
- data/vendor/tomotopy/src/TopicModel/PA.h +9 -4
- data/vendor/tomotopy/src/TopicModel/PAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/PAModel.hpp +48 -25
- data/vendor/tomotopy/src/TopicModel/PLDA.h +13 -2
- data/vendor/tomotopy/src/TopicModel/PLDAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/PLDAModel.hpp +27 -19
- data/vendor/tomotopy/src/TopicModel/PT.h +12 -5
- data/vendor/tomotopy/src/TopicModel/PTModel.cpp +2 -3
- data/vendor/tomotopy/src/TopicModel/PTModel.hpp +29 -14
- data/vendor/tomotopy/src/TopicModel/SLDA.h +18 -6
- data/vendor/tomotopy/src/TopicModel/SLDAModel.cpp +2 -10
- data/vendor/tomotopy/src/TopicModel/SLDAModel.hpp +93 -43
- data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +58 -23
- data/vendor/tomotopy/src/Utils/AliasMethod.hpp +6 -6
- data/vendor/tomotopy/src/Utils/Dictionary.h +11 -0
- data/vendor/tomotopy/src/Utils/SharedString.hpp +26 -1
- data/vendor/tomotopy/src/Utils/Trie.hpp +46 -21
- data/vendor/tomotopy/src/Utils/Utils.hpp +99 -14
- data/vendor/tomotopy/src/Utils/exception.h +1 -1
- data/vendor/tomotopy/src/Utils/math.h +5 -7
- data/vendor/tomotopy/src/Utils/serializer.hpp +329 -201
- data/vendor/tomotopy/src/Utils/text.hpp +8 -0
- data/vendor/tomotopy/src/Utils/tvector.hpp +49 -7
- metadata +9 -7
|
@@ -14,19 +14,38 @@ namespace tomoto
|
|
|
14
14
|
ShareableMatrix<Float, -1, 1> eta;
|
|
15
15
|
sample::AliasMethod<> aliasTable;
|
|
16
16
|
|
|
17
|
+
RawDoc::MiscType makeMisc(const ITopicModel* tm) const override
|
|
18
|
+
{
|
|
19
|
+
RawDoc::MiscType ret = DocumentLDA<_tw>::makeMisc(tm);
|
|
20
|
+
ret["timepoint"] = (uint32_t)timepoint;
|
|
21
|
+
return ret;
|
|
22
|
+
}
|
|
23
|
+
|
|
17
24
|
DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 0, timepoint);
|
|
18
25
|
DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 1, 0x00010001, timepoint);
|
|
19
26
|
};
|
|
20
27
|
|
|
28
|
+
struct DTArgs : public LDAArgs
|
|
29
|
+
{
|
|
30
|
+
size_t t = 1;
|
|
31
|
+
Float phi = 0.1;
|
|
32
|
+
Float shapeA = 0.01;
|
|
33
|
+
Float shapeB = 0.1;
|
|
34
|
+
Float shapeC = 0.55;
|
|
35
|
+
Float etaL2Reg = 0;
|
|
36
|
+
|
|
37
|
+
DTArgs()
|
|
38
|
+
{
|
|
39
|
+
alpha[0] = 0.1;
|
|
40
|
+
eta = 0.1;
|
|
41
|
+
}
|
|
42
|
+
};
|
|
43
|
+
|
|
21
44
|
class IDTModel : public ILDAModel
|
|
22
45
|
{
|
|
23
46
|
public:
|
|
24
47
|
using DefaultDocType = DocumentDTM<TermWeight::one>;
|
|
25
|
-
static IDTModel* create(TermWeight _weight,
|
|
26
|
-
Float _alphaVar = 1.0, Float _etaVar = 1.0, Float _phiVar = 1.0,
|
|
27
|
-
Float _shapeA = 0.03, Float _shapeB = 0.1, Float _shapeC = 0.55,
|
|
28
|
-
Float _etaRegL2 = 0,
|
|
29
|
-
size_t seed = std::random_device{}(),
|
|
48
|
+
static IDTModel* create(TermWeight _weight, const DTArgs& args,
|
|
30
49
|
bool scalarRng = false);
|
|
31
50
|
|
|
32
51
|
virtual size_t getT() const = 0;
|
|
@@ -2,14 +2,8 @@
|
|
|
2
2
|
|
|
3
3
|
namespace tomoto
|
|
4
4
|
{
|
|
5
|
-
|
|
6
|
-
template class DTModel<TermWeight::idf>;
|
|
7
|
-
template class DTModel<TermWeight::pmi>;*/
|
|
8
|
-
|
|
9
|
-
IDTModel* IDTModel::create(TermWeight _weight, size_t _K, size_t _T,
|
|
10
|
-
Float _alphaVar, Float _etaVar, Float _phiVar,
|
|
11
|
-
Float _shapeA, Float _shapeB, Float _shapeC, Float _etaRegL2, size_t seed, bool scalarRng)
|
|
5
|
+
IDTModel* IDTModel::create(TermWeight _weight, const DTArgs& args, bool scalarRng)
|
|
12
6
|
{
|
|
13
|
-
TMT_SWITCH_TW(_weight, scalarRng, DTModel,
|
|
7
|
+
TMT_SWITCH_TW(_weight, scalarRng, DTModel, args);
|
|
14
8
|
}
|
|
15
9
|
}
|
|
@@ -45,12 +45,12 @@ namespace tomoto
|
|
|
45
45
|
|
|
46
46
|
uint64_t T;
|
|
47
47
|
Float shapeA = 0.03f, shapeB = 0.1f, shapeC = 0.55f;
|
|
48
|
-
|
|
48
|
+
Float alphaVar = 1.f, etaVar = 1.f, phiVar = 1.f, etaRegL2 = 0.0f;
|
|
49
49
|
|
|
50
|
-
|
|
51
|
-
|
|
50
|
+
Matrix alphas; // Dim: (Topic, Time)
|
|
51
|
+
Matrix etaByDoc; // Dim: (Topic, Docs) : Topic distribution by docs(and time)
|
|
52
52
|
std::vector<uint32_t> numDocsByTime; // Dim: (Time)
|
|
53
|
-
|
|
53
|
+
Matrix phi; // Dim: (Word, Topic * Time)
|
|
54
54
|
std::vector<sample::AliasMethod<>> wordAliasTables; // Dim: (Word * Time)
|
|
55
55
|
|
|
56
56
|
template<int _inc>
|
|
@@ -84,8 +84,8 @@ namespace tomoto
|
|
|
84
84
|
|
|
85
85
|
// sampling eta
|
|
86
86
|
{
|
|
87
|
-
|
|
88
|
-
|
|
87
|
+
Vector estimatedCnt = (doc.eta.array() - doc.eta.maxCoeff()).exp();
|
|
88
|
+
Vector etaTmp;
|
|
89
89
|
estimatedCnt *= doc.getSumWordWeight() / estimatedCnt.sum();
|
|
90
90
|
auto prior = (alphas.col(doc.timepoint) - doc.eta) / std::max(etaVar, eps * 2);
|
|
91
91
|
auto grad = doc.numByTopic.template cast<Float>() - estimatedCnt;
|
|
@@ -181,20 +181,21 @@ namespace tomoto
|
|
|
181
181
|
template<typename _DocIter>
|
|
182
182
|
void _sampleGlobalLevel(ThreadPool* pool, _ModelState*, _RandGen* rgs, _DocIter first, _DocIter last)
|
|
183
183
|
{
|
|
184
|
+
if (!this->realV) return;
|
|
184
185
|
const auto K = this->K;
|
|
185
186
|
const Float eps = shapeA * (std::pow(shapeB + 1 + this->globalStep, -shapeC));
|
|
186
187
|
|
|
187
188
|
// sampling phi
|
|
188
189
|
for (size_t k = 0; k < K; ++k)
|
|
189
190
|
{
|
|
190
|
-
|
|
191
|
+
Matrix phiGrad{ (Eigen::Index)this->realV, (Eigen::Index)T };
|
|
191
192
|
for (size_t t = 0; t < T; ++t)
|
|
192
193
|
{
|
|
193
194
|
auto phi_tk = phi.col(k + K * t);
|
|
194
|
-
|
|
195
|
+
Vector estimatedCnt = (phi_tk.array() - phi_tk.maxCoeff()).exp();
|
|
195
196
|
estimatedCnt *= this->globalState.numByTopic(k, t) / estimatedCnt.sum();
|
|
196
197
|
|
|
197
|
-
|
|
198
|
+
Vector grad = this->globalState.numByTopicWord.row(k + K * t).template cast<Float>();
|
|
198
199
|
grad -= estimatedCnt;
|
|
199
200
|
auto epsNoise = Eigen::Rand::normal<Eigen::Array<Float, -1, 1>>(this->realV, 1, *rgs) * eps;
|
|
200
201
|
if (t == 0)
|
|
@@ -228,7 +229,7 @@ namespace tomoto
|
|
|
228
229
|
}
|
|
229
230
|
}
|
|
230
231
|
|
|
231
|
-
|
|
232
|
+
Matrix newAlphas = Matrix::Zero(alphas.rows(), alphas.cols());
|
|
232
233
|
for (size_t t = 0; t < T; ++t)
|
|
233
234
|
{
|
|
234
235
|
// update alias tables for word proposal
|
|
@@ -398,9 +399,9 @@ namespace tomoto
|
|
|
398
399
|
this->globalState.numByTopic = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K, T);
|
|
399
400
|
this->globalState.numByTopicWord = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K * T, V);
|
|
400
401
|
|
|
401
|
-
alphas =
|
|
402
|
-
etaByDoc =
|
|
403
|
-
phi =
|
|
402
|
+
alphas = Matrix::Zero(this->K, T);
|
|
403
|
+
etaByDoc = Matrix::Zero(this->K, this->docs.size());
|
|
404
|
+
phi = Matrix::Zero(this->realV, this->K * T);
|
|
404
405
|
}
|
|
405
406
|
|
|
406
407
|
numDocsByTime.resize(T);
|
|
@@ -418,7 +419,7 @@ namespace tomoto
|
|
|
418
419
|
|
|
419
420
|
for (Tid t = 0; t < T; ++t)
|
|
420
421
|
{
|
|
421
|
-
if (initDocs && !numDocsByTime[t]) THROW_ERROR_WITH_INFO(
|
|
422
|
+
if (initDocs && !numDocsByTime[t]) THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("No document with timepoint = %d", t));
|
|
422
423
|
|
|
423
424
|
// update alias tables for word proposal
|
|
424
425
|
for (Vid v = 0; v < this->realV; ++v)
|
|
@@ -439,23 +440,26 @@ namespace tomoto
|
|
|
439
440
|
addWordTo<1>(ld, doc, i, w, z);
|
|
440
441
|
}
|
|
441
442
|
|
|
442
|
-
std::vector<Float> _getWidsByTopic(size_t tid) const
|
|
443
|
+
std::vector<Float> _getWidsByTopic(size_t tid, bool normalize = true) const
|
|
443
444
|
{
|
|
444
445
|
const size_t V = this->realV;
|
|
445
446
|
std::vector<Float> ret(V);
|
|
446
447
|
Eigen::Map<Eigen::Array<Float, -1, 1>> retMap(ret.data(), V);
|
|
447
448
|
retMap = phi.col(tid).array().exp();
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
449
|
+
if (normalize)
|
|
450
|
+
{
|
|
451
|
+
retMap /= retMap.sum();
|
|
452
|
+
Eigen::Array<Float, -1, 1> t = this->globalState.numByTopicWord.row(tid).array().template cast<Float>();
|
|
453
|
+
t /= std::max(t.sum(), (Float)0.1);
|
|
454
|
+
retMap += t;
|
|
455
|
+
retMap /= 2;
|
|
456
|
+
}
|
|
453
457
|
return ret;
|
|
454
458
|
}
|
|
455
459
|
|
|
456
460
|
_DocType& _updateDoc(_DocType& doc, uint32_t timepoint) const
|
|
457
461
|
{
|
|
458
|
-
if (timepoint >= T) THROW_ERROR_WITH_INFO(
|
|
462
|
+
if (timepoint >= T) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "timepoint must < T");
|
|
459
463
|
doc.timepoint = timepoint;
|
|
460
464
|
return doc;
|
|
461
465
|
}
|
|
@@ -473,6 +477,16 @@ namespace tomoto
|
|
|
473
477
|
return cnt;
|
|
474
478
|
}
|
|
475
479
|
|
|
480
|
+
void updateForCopy()
|
|
481
|
+
{
|
|
482
|
+
BaseClass::updateForCopy();
|
|
483
|
+
size_t docId = 0;
|
|
484
|
+
for (auto& doc : this->docs)
|
|
485
|
+
{
|
|
486
|
+
doc.eta.init((Float*)etaByDoc.col(docId++).data(), this->K, 1);
|
|
487
|
+
}
|
|
488
|
+
}
|
|
489
|
+
|
|
476
490
|
public:
|
|
477
491
|
DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0,
|
|
478
492
|
T, shapeA, shapeB, shapeC, alphaVar, etaVar, phiVar, alphas, etaByDoc, phi);
|
|
@@ -489,11 +503,10 @@ namespace tomoto
|
|
|
489
503
|
GETTER(ShapeB, Float, shapeB);
|
|
490
504
|
GETTER(ShapeC, Float, shapeC);
|
|
491
505
|
|
|
492
|
-
DTModel(
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
shapeA{ _shapeA }, shapeB{ _shapeB }, shapeC{ _shapeC }, etaRegL2{ _etaRegL2 }
|
|
506
|
+
DTModel(const DTArgs& args)
|
|
507
|
+
: BaseClass{ args },
|
|
508
|
+
T{ args.t }, alphaVar{ args.alpha[0] }, etaVar{ args.eta }, phiVar{ args.phi },
|
|
509
|
+
shapeA{ args.shapeA }, shapeB{ args.shapeB }, shapeC{ args.shapeC }, etaRegL2{ args.etaL2Reg }
|
|
497
510
|
{
|
|
498
511
|
}
|
|
499
512
|
|
|
@@ -506,7 +519,7 @@ namespace tomoto
|
|
|
506
519
|
std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) const override
|
|
507
520
|
{
|
|
508
521
|
auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc, tokenizer);
|
|
509
|
-
return make_unique<_DocType>(_updateDoc(doc, rawDoc.template getMisc<uint32_t>("timepoint")));
|
|
522
|
+
return std::make_unique<_DocType>(_updateDoc(doc, rawDoc.template getMisc<uint32_t>("timepoint")));
|
|
510
523
|
}
|
|
511
524
|
|
|
512
525
|
size_t addDoc(const RawDoc& rawDoc) override
|
|
@@ -518,7 +531,7 @@ namespace tomoto
|
|
|
518
531
|
std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc) const override
|
|
519
532
|
{
|
|
520
533
|
auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc);
|
|
521
|
-
return make_unique<_DocType>(_updateDoc(doc, rawDoc.template getMisc<uint32_t>("timepoint")));
|
|
534
|
+
return std::make_unique<_DocType>(_updateDoc(doc, rawDoc.template getMisc<uint32_t>("timepoint")));
|
|
522
535
|
}
|
|
523
536
|
|
|
524
537
|
Float getAlpha(size_t k, size_t t) const override
|
|
@@ -10,26 +10,52 @@ namespace tomoto
|
|
|
10
10
|
using DocumentDMR<_tw>::DocumentDMR;
|
|
11
11
|
std::vector<Float> metadataOrg, metadataNormalized;
|
|
12
12
|
|
|
13
|
+
RawDoc::MiscType makeMisc(const ITopicModel* tm) const override
|
|
14
|
+
{
|
|
15
|
+
RawDoc::MiscType ret = DocumentDMR<_tw>::makeMisc(tm);
|
|
16
|
+
ret["numeric_metadata"] = metadataOrg;
|
|
17
|
+
return ret;
|
|
18
|
+
}
|
|
19
|
+
|
|
13
20
|
DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 0, metadataOrg);
|
|
14
21
|
DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 1, 0x00010001, metadataOrg, metadataNormalized);
|
|
15
22
|
};
|
|
16
23
|
|
|
24
|
+
struct GDMRArgs : public DMRArgs
|
|
25
|
+
{
|
|
26
|
+
std::vector<uint64_t> degrees;
|
|
27
|
+
Float sigma0 = 3.0;
|
|
28
|
+
Float orderDecay = 0;
|
|
29
|
+
};
|
|
30
|
+
|
|
17
31
|
class IGDMRModel : public IDMRModel
|
|
18
32
|
{
|
|
19
33
|
public:
|
|
20
34
|
using DefaultDocType = DocumentDMR<TermWeight::one>;
|
|
21
|
-
static IGDMRModel* create(TermWeight _weight,
|
|
22
|
-
Float defaultAlpha = 1.0, Float _sigma = 1.0, Float _sigma0 = 1.0, Float _eta = 0.01, Float _alphaEps = 1e-10,
|
|
23
|
-
size_t seed = std::random_device{}(),
|
|
35
|
+
static IGDMRModel* create(TermWeight _weight, const GDMRArgs& args,
|
|
24
36
|
bool scalarRng = false);
|
|
25
37
|
|
|
26
38
|
virtual Float getSigma0() const = 0;
|
|
39
|
+
virtual Float getOrderDecay() const = 0;
|
|
27
40
|
virtual void setSigma0(Float) = 0;
|
|
28
41
|
virtual const std::vector<uint64_t>& getFs() const = 0;
|
|
29
42
|
virtual std::vector<Float> getLambdaByTopic(Tid tid) const = 0;
|
|
30
43
|
|
|
31
|
-
virtual std::vector<Float> getTDF(
|
|
32
|
-
|
|
44
|
+
virtual std::vector<Float> getTDF(
|
|
45
|
+
const Float* metadata,
|
|
46
|
+
const std::string& metadataCat,
|
|
47
|
+
const std::vector<std::string>& multiMetadataCat,
|
|
48
|
+
bool normalize
|
|
49
|
+
) const = 0;
|
|
50
|
+
|
|
51
|
+
virtual std::vector<Float> getTDFBatch(
|
|
52
|
+
const Float* metadata,
|
|
53
|
+
const std::string& metadataCat,
|
|
54
|
+
const std::vector<std::string>& multiMetadataCat,
|
|
55
|
+
size_t stride,
|
|
56
|
+
size_t cnt,
|
|
57
|
+
bool normalize
|
|
58
|
+
) const = 0;
|
|
33
59
|
|
|
34
60
|
virtual void setMdRange(const std::vector<Float>& vMin, const std::vector<Float>& vMax) = 0;
|
|
35
61
|
virtual void getMdRange(std::vector<Float>& vMin, std::vector<Float>& vMax) const = 0;
|
|
@@ -2,13 +2,8 @@
|
|
|
2
2
|
|
|
3
3
|
namespace tomoto
|
|
4
4
|
{
|
|
5
|
-
|
|
6
|
-
template class GDMRModel<TermWeight::idf>;
|
|
7
|
-
template class GDMRModel<TermWeight::pmi>;*/
|
|
8
|
-
|
|
9
|
-
IGDMRModel* IGDMRModel::create(TermWeight _weight, size_t _K, const std::vector<uint64_t>& degreeByF,
|
|
10
|
-
Float _defaultAlpha, Float _sigma, Float _sigma0, Float _eta, Float _alphaEps, size_t seed, bool scalarRng)
|
|
5
|
+
IGDMRModel* IGDMRModel::create(TermWeight _weight, const GDMRArgs& args, bool scalarRng)
|
|
11
6
|
{
|
|
12
|
-
TMT_SWITCH_TW(_weight, scalarRng, GDMRModel,
|
|
7
|
+
TMT_SWITCH_TW(_weight, scalarRng, GDMRModel, args);
|
|
13
8
|
}
|
|
14
9
|
}
|
|
@@ -8,8 +8,8 @@ namespace tomoto
|
|
|
8
8
|
template<TermWeight _tw>
|
|
9
9
|
struct ModelStateGDMR : public ModelStateDMR<_tw>
|
|
10
10
|
{
|
|
11
|
-
/*
|
|
12
|
-
|
|
11
|
+
/*Vector alphas;
|
|
12
|
+
Vector terms;
|
|
13
13
|
std::vector<std::vector<Float>> slpCache;
|
|
14
14
|
std::vector<size_t> ndimCnt;*/
|
|
15
15
|
};
|
|
@@ -22,7 +22,8 @@ namespace tomoto
|
|
|
22
22
|
typename _ModelState = ModelStateGDMR<_tw>>
|
|
23
23
|
class GDMRModel : public DMRModel<_tw, _RandGen, _Flags, _Interface,
|
|
24
24
|
typename std::conditional<std::is_same<_Derived, void>::value, GDMRModel<_tw, _RandGen>, _Derived>::type,
|
|
25
|
-
_DocType, _ModelState
|
|
25
|
+
_DocType, _ModelState
|
|
26
|
+
>
|
|
26
27
|
{
|
|
27
28
|
protected:
|
|
28
29
|
using DerivedClass = typename std::conditional<std::is_same<_Derived, void>::value, GDMRModel<_tw, _RandGen>, _Derived>::type;
|
|
@@ -32,51 +33,60 @@ namespace tomoto
|
|
|
32
33
|
friend typename BaseClass::BaseClass::BaseClass;
|
|
33
34
|
using WeightType = typename BaseClass::WeightType;
|
|
34
35
|
|
|
35
|
-
Float sigma0 = 3;
|
|
36
|
+
Float sigma0 = 3, orderDecay = 0;
|
|
36
37
|
std::vector<Float> mdCoefs, mdIntercepts, mdMax;
|
|
37
38
|
std::vector<uint64_t> degreeByF;
|
|
39
|
+
Eigen::Array<Float, -1, 1> orderDecayCached;
|
|
40
|
+
size_t fCont = 1;
|
|
38
41
|
|
|
39
|
-
Float getIntegratedLambdaSq(const Eigen::Ref<const
|
|
42
|
+
Float getIntegratedLambdaSq(const Eigen::Ref<const Vector, 0, Eigen::InnerStride<>>& lambdas) const
|
|
40
43
|
{
|
|
41
|
-
Float ret =
|
|
42
|
-
for (size_t i =
|
|
44
|
+
Float ret = 0;
|
|
45
|
+
for (size_t i = 0; i < this->F; ++i)
|
|
43
46
|
{
|
|
44
|
-
ret += pow(lambdas[i], 2) / 2 / pow(this->
|
|
47
|
+
ret += pow(lambdas[this->mdVecSize * i] - log(this->alpha), 2) / 2 / pow(this->sigma0, 2);
|
|
48
|
+
ret += (lambdas.segment(this->mdVecSize * i + 1, fCont - 1).array().pow(2) / 2 * orderDecayCached.segment(1, fCont - 1) / pow(this->sigma, 2)).sum();
|
|
49
|
+
ret += lambdas.segment(this->mdVecSize * i + fCont, this->mdVecSize - fCont).array().pow(2).sum() / 2 / pow(this->sigma, 2);
|
|
45
50
|
}
|
|
46
51
|
return ret;
|
|
47
52
|
}
|
|
48
53
|
|
|
49
|
-
void getIntegratedLambdaSqP(const Eigen::Ref<const
|
|
50
|
-
Eigen::Ref<
|
|
54
|
+
void getIntegratedLambdaSqP(const Eigen::Ref<const Vector, 0, Eigen::InnerStride<>>& lambdas,
|
|
55
|
+
Eigen::Ref<Vector, 0, Eigen::InnerStride<>> ret) const
|
|
51
56
|
{
|
|
52
|
-
|
|
53
|
-
for (size_t i = 1; i < this->F; ++i)
|
|
57
|
+
for (size_t i = 0; i < this->F; ++i)
|
|
54
58
|
{
|
|
55
|
-
ret[i] = lambdas[i] / pow(this->
|
|
59
|
+
ret[this->mdVecSize * i] = (lambdas[this->mdVecSize * i] - log(this->alpha)) / pow(this->sigma0, 2);
|
|
60
|
+
ret.segment(this->mdVecSize * i + 1, fCont - 1) = lambdas.segment(this->mdVecSize * i + 1, fCont - 1).array() * orderDecayCached.segment(1, fCont - 1) / pow(this->sigma, 2);
|
|
61
|
+
ret.segment(this->mdVecSize * i + fCont, this->mdVecSize - fCont) = lambdas.segment(this->mdVecSize * i + fCont, this->mdVecSize - fCont).array() / pow(this->sigma, 2);
|
|
56
62
|
}
|
|
57
63
|
}
|
|
58
64
|
|
|
59
65
|
void initParameters()
|
|
60
66
|
{
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
for (size_t i = 0; i < this->
|
|
67
|
+
this->lambda = Eigen::Rand::normalLike(this->lambda, this->rg);
|
|
68
|
+
|
|
69
|
+
for (size_t i = 0; i < this->F; ++i)
|
|
64
70
|
{
|
|
65
|
-
|
|
71
|
+
this->lambda.col(this->mdVecSize * i).array() *= sigma0;
|
|
72
|
+
this->lambda.col(this->mdVecSize * i).array() += log(this->alphas.array());
|
|
73
|
+
|
|
74
|
+
for (size_t j = 1; j < fCont; ++j)
|
|
66
75
|
{
|
|
67
|
-
this->lambda(i
|
|
76
|
+
this->lambda.col(this->mdVecSize * i + j).array() *= this->sigma / std::sqrt(orderDecayCached[j]);
|
|
68
77
|
}
|
|
69
|
-
|
|
78
|
+
|
|
79
|
+
for (size_t j = fCont; j < this->mdVecSize; ++j)
|
|
70
80
|
{
|
|
71
|
-
this->lambda(i
|
|
81
|
+
this->lambda.col(this->mdVecSize * i + j).array() *= this->sigma;
|
|
72
82
|
}
|
|
73
83
|
}
|
|
74
84
|
}
|
|
75
85
|
|
|
76
|
-
Float getNegativeLambdaLL(Eigen::Ref<
|
|
86
|
+
Float getNegativeLambdaLL(Eigen::Ref<Vector> x, Vector& g) const
|
|
77
87
|
{
|
|
78
|
-
auto mappedX = Eigen::Map<
|
|
79
|
-
auto mappedG = Eigen::Map<
|
|
88
|
+
auto mappedX = Eigen::Map<Matrix>(x.data(), this->K, this->F);
|
|
89
|
+
auto mappedG = Eigen::Map<Matrix>(g.data(), this->K, this->F);
|
|
80
90
|
|
|
81
91
|
Float fx = 0;
|
|
82
92
|
for (size_t k = 0; k < this->K; ++k)
|
|
@@ -87,50 +97,51 @@ namespace tomoto
|
|
|
87
97
|
return fx;
|
|
88
98
|
}
|
|
89
99
|
|
|
90
|
-
Float evaluateLambdaObj(Eigen::Ref<
|
|
100
|
+
/*Float evaluateLambdaObj(Eigen::Ref<Vector> x, Vector& g, ThreadPool& pool, _ModelState* localData) const
|
|
91
101
|
{
|
|
92
102
|
// if one of x is greater than maxLambda, return +inf for preventing search more
|
|
93
103
|
if ((x.array() > this->maxLambda).any()) return INFINITY;
|
|
94
104
|
|
|
95
105
|
const auto K = this->K;
|
|
96
|
-
const auto
|
|
106
|
+
const auto KF = this->K * this->F;
|
|
97
107
|
|
|
98
|
-
auto mappedX = Eigen::Map<
|
|
108
|
+
auto mappedX = Eigen::Map<Matrix>(x.data(), K, this->F);
|
|
99
109
|
Float fx = -static_cast<const DerivedClass*>(this)->getNegativeLambdaLL(x, g);
|
|
100
110
|
|
|
101
|
-
std::vector<std::future<
|
|
111
|
+
std::vector<std::future<Vector>> res;
|
|
102
112
|
const size_t chStride = pool.getNumWorkers() * 8;
|
|
103
113
|
for (size_t ch = 0; ch < chStride; ++ch)
|
|
104
114
|
{
|
|
105
115
|
res.emplace_back(pool.enqueue([&, this](size_t threadId)
|
|
106
116
|
{
|
|
107
117
|
auto& ld = localData[threadId];
|
|
108
|
-
thread_local
|
|
109
|
-
|
|
118
|
+
thread_local Vector alphas{ K }, tmpK{ K }, terms{ fCont };
|
|
119
|
+
Vector ret = Vector::Zero(KF + 1);
|
|
110
120
|
for (size_t docId = ch; docId < this->docs.size(); docId += chStride)
|
|
111
121
|
{
|
|
112
122
|
const auto& doc = this->docs[docId];
|
|
113
123
|
const auto& vx = doc.metadataNormalized;
|
|
124
|
+
size_t xOffset = doc.metadata * fCont;
|
|
114
125
|
getTermsFromMd(&vx[0], terms.data());
|
|
115
126
|
for (Tid k = 0; k < K; ++k)
|
|
116
127
|
{
|
|
117
|
-
alphas[k] = exp(mappedX.row(k) * terms) + this->alphaEps;
|
|
118
|
-
ret[
|
|
119
|
-
assert(std::isfinite(ret[
|
|
128
|
+
alphas[k] = exp(mappedX.row(k).segment(xOffset, fCont) * terms) + this->alphaEps;
|
|
129
|
+
ret[KF] -= math::lgammaT(alphas[k]) - math::lgammaT(doc.numByTopic[k] + alphas[k]);
|
|
130
|
+
assert(std::isfinite(ret[KF]));
|
|
120
131
|
if (!std::isfinite(alphas[k]) && alphas[k] > 0) tmpK[k] = 0;
|
|
121
132
|
else tmpK[k] = -(math::digammaT(alphas[k]) - math::digammaT(doc.numByTopic[k] + alphas[k]));
|
|
122
133
|
}
|
|
123
134
|
Float alphaSum = alphas.sum();
|
|
124
|
-
ret[
|
|
135
|
+
ret[KF] += math::lgammaT(alphaSum) - math::lgammaT(doc.getSumWordWeight() + alphaSum);
|
|
125
136
|
Float t = math::digammaT(alphaSum) - math::digammaT(doc.getSumWordWeight() + alphaSum);
|
|
126
137
|
if (!std::isfinite(alphaSum) && alphaSum > 0)
|
|
127
138
|
{
|
|
128
|
-
ret[
|
|
139
|
+
ret[KF] = -INFINITY;
|
|
129
140
|
t = 0;
|
|
130
141
|
}
|
|
131
|
-
for (size_t
|
|
142
|
+
for (size_t i = 0; i < fCont; ++i)
|
|
132
143
|
{
|
|
133
|
-
ret.segment(
|
|
144
|
+
ret.segment((i + xOffset) * K, K).array() -= ((tmpK.array() + t) * alphas.array()) * terms[i];
|
|
134
145
|
}
|
|
135
146
|
assert(ret.allFinite());
|
|
136
147
|
}
|
|
@@ -140,14 +151,14 @@ namespace tomoto
|
|
|
140
151
|
for (auto& r : res)
|
|
141
152
|
{
|
|
142
153
|
auto ret = r.get();
|
|
143
|
-
fx += ret[
|
|
144
|
-
g += ret.head(
|
|
154
|
+
fx += ret[KF];
|
|
155
|
+
g += ret.head(KF);
|
|
145
156
|
}
|
|
146
157
|
|
|
147
158
|
// positive fx is an error from limited precision of float.
|
|
148
159
|
if (fx > 0) return INFINITY;
|
|
149
160
|
return -fx;
|
|
150
|
-
}
|
|
161
|
+
}*/
|
|
151
162
|
|
|
152
163
|
void getTermsFromMd(const Float* vx, Float* out, bool normalize = false) const
|
|
153
164
|
{
|
|
@@ -172,7 +183,7 @@ namespace tomoto
|
|
|
172
183
|
}
|
|
173
184
|
}
|
|
174
185
|
|
|
175
|
-
for (size_t i = 0; i <
|
|
186
|
+
for (size_t i = 0; i < fCont; ++i)
|
|
176
187
|
{
|
|
177
188
|
out[i] = 1;
|
|
178
189
|
for (size_t n = 0; n < degreeByF.size(); ++n)
|
|
@@ -180,47 +191,69 @@ namespace tomoto
|
|
|
180
191
|
if(digit[n]) out[i] *= slpCache[n][digit[n] - 1];
|
|
181
192
|
}
|
|
182
193
|
|
|
183
|
-
size_t u;
|
|
184
|
-
for (u = 0; u < digit.size() && ++digit[u] > degreeByF[u]; ++u)
|
|
194
|
+
for (size_t u = 0; u < digit.size() && ++digit[u] > degreeByF[u]; ++u)
|
|
185
195
|
{
|
|
186
196
|
digit[u] = 0;
|
|
187
197
|
}
|
|
188
|
-
u = std::min(u, degreeByF.size() - 1);
|
|
189
198
|
}
|
|
190
199
|
}
|
|
191
200
|
|
|
192
|
-
|
|
201
|
+
Eigen::Array<Float, -1, 1> calcOrderDecay() const
|
|
202
|
+
{
|
|
203
|
+
Eigen::Array<Float, -1, 1> ret{ fCont };
|
|
204
|
+
std::vector<size_t> digit(degreeByF.size());
|
|
205
|
+
std::fill(digit.begin(), digit.end(), 0);
|
|
206
|
+
|
|
207
|
+
for (size_t i = 0; i < fCont; ++i)
|
|
208
|
+
{
|
|
209
|
+
ret[i] = 1;
|
|
210
|
+
for (size_t n = 0; n < degreeByF.size(); ++n)
|
|
211
|
+
{
|
|
212
|
+
ret[i] *= pow(digit[n] + 1, orderDecay * 2);
|
|
213
|
+
}
|
|
214
|
+
|
|
215
|
+
for (size_t u = 0; u < digit.size() && ++digit[u] > degreeByF[u]; ++u)
|
|
216
|
+
{
|
|
217
|
+
digit[u] = 0;
|
|
218
|
+
}
|
|
219
|
+
}
|
|
220
|
+
return ret;
|
|
221
|
+
}
|
|
222
|
+
|
|
223
|
+
/*template<bool _asymEta>
|
|
193
224
|
Float* getZLikelihoods(_ModelState& ld, const _DocType& doc, size_t docId, size_t vid) const
|
|
194
225
|
{
|
|
195
226
|
const size_t V = this->realV;
|
|
196
227
|
assert(vid < V);
|
|
197
228
|
auto etaHelper = this->template getEtaHelper<_asymEta>();
|
|
198
229
|
auto& zLikelihood = ld.zLikelihood;
|
|
199
|
-
thread_local
|
|
230
|
+
thread_local Vector terms{ fCont };
|
|
231
|
+
size_t xOffset = doc.metadata * fCont;
|
|
200
232
|
getTermsFromMd(&doc.metadataNormalized[0], terms.data());
|
|
201
|
-
zLikelihood = (doc.numByTopic.array().template cast<Float>() + (this->lambda * terms).array().exp() + this->alphaEps)
|
|
233
|
+
zLikelihood = (doc.numByTopic.array().template cast<Float>() + (this->lambda.middleCols(xOffset, fCont) * terms).array().exp() + this->alphaEps)
|
|
202
234
|
* (ld.numByTopicWord.col(vid).array().template cast<Float>() + etaHelper.getEta(vid))
|
|
203
235
|
/ (ld.numByTopic.array().template cast<Float>() + etaHelper.getEtaSum());
|
|
204
236
|
|
|
205
237
|
sample::prefixSum(zLikelihood.data(), this->K);
|
|
206
238
|
return &zLikelihood[0];
|
|
207
|
-
}
|
|
239
|
+
}*/
|
|
208
240
|
|
|
209
|
-
template<typename _DocIter>
|
|
241
|
+
/*template<typename _DocIter>
|
|
210
242
|
double getLLDocs(_DocIter _first, _DocIter _last) const
|
|
211
243
|
{
|
|
212
244
|
const auto K = this->K;
|
|
213
245
|
double ll = 0;
|
|
214
246
|
|
|
215
|
-
|
|
247
|
+
Vector alphas(K);
|
|
216
248
|
for (; _first != _last; ++_first)
|
|
217
249
|
{
|
|
218
250
|
auto& doc = *_first;
|
|
219
|
-
thread_local
|
|
251
|
+
thread_local Vector terms{ fCont };
|
|
220
252
|
getTermsFromMd(&doc.metadataNormalized[0], terms.data());
|
|
253
|
+
size_t xOffset = doc.metadata * fCont;
|
|
221
254
|
for (Tid k = 0; k < K; ++k)
|
|
222
255
|
{
|
|
223
|
-
alphas[k] = exp(this->lambda.row(k) * terms) + this->alphaEps;
|
|
256
|
+
alphas[k] = exp(this->lambda.row(k).segment(xOffset, fCont) * terms) + this->alphaEps;
|
|
224
257
|
}
|
|
225
258
|
Float alphaSum = alphas.sum();
|
|
226
259
|
for (Tid k = 0; k < K; ++k)
|
|
@@ -231,7 +264,7 @@ namespace tomoto
|
|
|
231
264
|
ll -= math::lgammaT(doc.getSumWordWeight() + alphaSum) - math::lgammaT(alphaSum);
|
|
232
265
|
}
|
|
233
266
|
return ll;
|
|
234
|
-
}
|
|
267
|
+
}*/
|
|
235
268
|
|
|
236
269
|
double getLLRest(const _ModelState& ld) const
|
|
237
270
|
{
|
|
@@ -296,15 +329,44 @@ namespace tomoto
|
|
|
296
329
|
|
|
297
330
|
void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
|
|
298
331
|
{
|
|
299
|
-
BaseClass::prepareDoc(doc, docId, wordSize);
|
|
332
|
+
BaseClass::BaseClass::prepareDoc(doc, docId, wordSize);
|
|
300
333
|
doc.metadataNormalized = normalizeMetadata(doc.metadataOrg);
|
|
334
|
+
|
|
335
|
+
doc.mdVec = Vector::Zero(this->mdVecSize);
|
|
336
|
+
getTermsFromMd(doc.metadataNormalized.data(), doc.mdVec.data());
|
|
337
|
+
for (auto x : doc.multiMetadata)
|
|
338
|
+
{
|
|
339
|
+
doc.mdVec[fCont + x] = 1;
|
|
340
|
+
}
|
|
341
|
+
|
|
342
|
+
auto p = std::make_pair(doc.metadata, doc.mdVec);
|
|
343
|
+
auto it = this->mdHashMap.find(p);
|
|
344
|
+
if (it == this->mdHashMap.end())
|
|
345
|
+
{
|
|
346
|
+
it = this->mdHashMap.emplace(p, this->mdHashMap.size()).first;
|
|
347
|
+
}
|
|
348
|
+
doc.mdHash = it->second;
|
|
301
349
|
}
|
|
302
350
|
|
|
303
351
|
void initGlobalState(bool initDocs)
|
|
304
352
|
{
|
|
305
353
|
BaseClass::BaseClass::initGlobalState(initDocs);
|
|
306
|
-
|
|
307
|
-
if (
|
|
354
|
+
fCont = accumulate(degreeByF.begin(), degreeByF.end(), 1, [](size_t a, size_t b) {return a * (b + 1); });
|
|
355
|
+
if (!this->metadataDict.size())
|
|
356
|
+
{
|
|
357
|
+
this->metadataDict.add("");
|
|
358
|
+
}
|
|
359
|
+
this->F = this->metadataDict.size();
|
|
360
|
+
this->mdVecSize = fCont + this->multiMetadataDict.size();
|
|
361
|
+
if (initDocs)
|
|
362
|
+
{
|
|
363
|
+
collectMinMaxMetadata();
|
|
364
|
+
this->lambda = Matrix::Zero(this->K, this->F * this->mdVecSize);
|
|
365
|
+
for (size_t i = 0; i < this->F; ++i)
|
|
366
|
+
{
|
|
367
|
+
this->lambda.col(this->mdVecSize * i) = log(this->alphas.array());
|
|
368
|
+
}
|
|
369
|
+
}
|
|
308
370
|
else
|
|
309
371
|
{
|
|
310
372
|
// Old binary file has metadataNormalized values into `metadataOrg`
|
|
@@ -320,13 +382,28 @@ namespace tomoto
|
|
|
320
382
|
}
|
|
321
383
|
}
|
|
322
384
|
}
|
|
385
|
+
|
|
386
|
+
for (auto& doc : this->docs)
|
|
387
|
+
{
|
|
388
|
+
if (doc.mdVec.size() == this->mdVecSize) continue;
|
|
389
|
+
doc.mdVec = Vector::Zero(this->mdVecSize);
|
|
390
|
+
getTermsFromMd(doc.metadataNormalized.data(), doc.mdVec.data());
|
|
391
|
+
for (auto x : doc.multiMetadata)
|
|
392
|
+
{
|
|
393
|
+
doc.mdVec[fCont + x] = 1;
|
|
394
|
+
}
|
|
395
|
+
|
|
396
|
+
auto p = std::make_pair(doc.metadata, doc.mdVec);
|
|
397
|
+
auto it = this->mdHashMap.find(p);
|
|
398
|
+
if (it == this->mdHashMap.end())
|
|
399
|
+
{
|
|
400
|
+
it = this->mdHashMap.emplace(p, this->mdHashMap.size()).first;
|
|
401
|
+
}
|
|
402
|
+
doc.mdHash = it->second;
|
|
403
|
+
}
|
|
323
404
|
}
|
|
324
405
|
|
|
325
|
-
|
|
326
|
-
{
|
|
327
|
-
this->lambda = Eigen::Matrix<Float, -1, -1>::Zero(this->K, this->F);
|
|
328
|
-
this->lambda.col(0).fill(log(this->alpha));
|
|
329
|
-
}
|
|
406
|
+
orderDecayCached = calcOrderDecay();
|
|
330
407
|
LBFGSpp::LBFGSParam<Float> param;
|
|
331
408
|
param.max_iterations = this->maxBFGSIteration;
|
|
332
409
|
this->solver = decltype(this->solver){ param };
|
|
@@ -334,18 +411,17 @@ namespace tomoto
|
|
|
334
411
|
|
|
335
412
|
public:
|
|
336
413
|
DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, sigma0, degreeByF, mdCoefs, mdIntercepts);
|
|
337
|
-
DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, sigma0, degreeByF, mdCoefs, mdIntercepts, mdMax);
|
|
414
|
+
DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, sigma0, orderDecay, degreeByF, mdCoefs, mdIntercepts, mdMax);
|
|
338
415
|
|
|
339
|
-
GDMRModel(
|
|
340
|
-
|
|
341
|
-
Float _alphaEps = 1e-10, size_t _rg = std::random_device{}())
|
|
342
|
-
: BaseClass(_K, defaultAlpha, _sigma, _eta, _alphaEps, _rg), sigma0(_sigma0), degreeByF(_degreeByF)
|
|
416
|
+
GDMRModel(const GDMRArgs& args)
|
|
417
|
+
: BaseClass(args), sigma0(args.sigma0), orderDecay(args.orderDecay), degreeByF(args.degrees)
|
|
343
418
|
{
|
|
344
|
-
|
|
419
|
+
fCont = accumulate(degreeByF.begin(), degreeByF.end(), 1, [](size_t a, size_t b) {return a * (b + 1); });
|
|
345
420
|
}
|
|
346
421
|
|
|
347
422
|
GETTER(Fs, const std::vector<uint64_t>&, degreeByF);
|
|
348
423
|
GETTER(Sigma0, Float, sigma0);
|
|
424
|
+
GETTER(OrderDecay, Float, orderDecay);
|
|
349
425
|
|
|
350
426
|
void setSigma0(Float _sigma0) override
|
|
351
427
|
{
|
|
@@ -353,73 +429,94 @@ namespace tomoto
|
|
|
353
429
|
}
|
|
354
430
|
|
|
355
431
|
template<bool _const = false>
|
|
356
|
-
_DocType& _updateDoc(_DocType& doc, const std::vector<Float>& metadata
|
|
432
|
+
_DocType& _updateDoc(_DocType& doc, const std::vector<Float>& metadata, const std::string& metadataCat = {}, const std::vector<std::string>& mdVec = {})
|
|
357
433
|
{
|
|
358
434
|
if (metadata.size() != degreeByF.size())
|
|
359
|
-
throw
|
|
435
|
+
throw exc::InvalidArgument{ "a length of `metadata` should be equal to a length of `degrees`" };
|
|
360
436
|
doc.metadataOrg = metadata;
|
|
437
|
+
|
|
438
|
+
Vid xid;
|
|
439
|
+
if (_const)
|
|
440
|
+
{
|
|
441
|
+
xid = this->metadataDict.toWid(metadataCat);
|
|
442
|
+
if (xid == non_vocab_id) throw exc::InvalidArgument("unknown metadata '" + metadataCat + "'");
|
|
443
|
+
|
|
444
|
+
for (auto& m : mdVec)
|
|
445
|
+
{
|
|
446
|
+
Vid x = this->multiMetadataDict.toWid(m);
|
|
447
|
+
if (x == non_vocab_id) throw exc::InvalidArgument("unknown multi_metadata '" + m + "'");
|
|
448
|
+
doc.multiMetadata.emplace_back(x);
|
|
449
|
+
}
|
|
450
|
+
}
|
|
451
|
+
else
|
|
452
|
+
{
|
|
453
|
+
xid = this->metadataDict.add(metadataCat);
|
|
454
|
+
|
|
455
|
+
for (auto& m : mdVec)
|
|
456
|
+
{
|
|
457
|
+
doc.multiMetadata.emplace_back(this->multiMetadataDict.add(m));
|
|
458
|
+
}
|
|
459
|
+
}
|
|
460
|
+
doc.metadata = xid;
|
|
361
461
|
return doc;
|
|
362
462
|
}
|
|
363
463
|
|
|
364
464
|
size_t addDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) override
|
|
365
465
|
{
|
|
366
466
|
auto doc = this->template _makeFromRawDoc<false>(rawDoc, tokenizer);
|
|
367
|
-
return this->_addDoc(_updateDoc(doc,
|
|
467
|
+
return this->_addDoc(_updateDoc(doc,
|
|
468
|
+
rawDoc.template getMisc<std::vector<Float>>("numeric_metadata"),
|
|
469
|
+
rawDoc.template getMiscDefault<std::string>("metadata"),
|
|
470
|
+
rawDoc.template getMiscDefault<std::vector<std::string>>("multi_metadata")
|
|
471
|
+
));
|
|
368
472
|
}
|
|
369
473
|
|
|
370
474
|
std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) const override
|
|
371
475
|
{
|
|
372
476
|
auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc, tokenizer);
|
|
373
|
-
return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc,
|
|
477
|
+
return std::make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc,
|
|
478
|
+
rawDoc.template getMisc<std::vector<Float>>("numeric_metadata"),
|
|
479
|
+
rawDoc.template getMiscDefault<std::string>("metadata"),
|
|
480
|
+
rawDoc.template getMiscDefault<std::vector<std::string>>("multi_metadata")
|
|
481
|
+
));
|
|
374
482
|
}
|
|
375
483
|
|
|
376
484
|
size_t addDoc(const RawDoc& rawDoc) override
|
|
377
485
|
{
|
|
378
486
|
auto doc = this->_makeFromRawDoc(rawDoc);
|
|
379
|
-
return this->_addDoc(_updateDoc(doc,
|
|
487
|
+
return this->_addDoc(_updateDoc(doc,
|
|
488
|
+
rawDoc.template getMisc<std::vector<Float>>("numeric_metadata"),
|
|
489
|
+
rawDoc.template getMiscDefault<std::string>("metadata"),
|
|
490
|
+
rawDoc.template getMiscDefault<std::vector<std::string>>("multi_metadata")
|
|
491
|
+
));
|
|
380
492
|
}
|
|
381
493
|
|
|
382
494
|
std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc) const override
|
|
383
495
|
{
|
|
384
496
|
auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc);
|
|
385
|
-
return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc,
|
|
497
|
+
return std::make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc,
|
|
498
|
+
rawDoc.template getMisc<std::vector<Float>>("numeric_metadata"),
|
|
499
|
+
rawDoc.template getMiscDefault<std::string>("metadata"),
|
|
500
|
+
rawDoc.template getMiscDefault<std::vector<std::string>>("multi_metadata")
|
|
501
|
+
));
|
|
386
502
|
}
|
|
387
503
|
|
|
388
|
-
std::vector<Float>
|
|
504
|
+
std::vector<Float> getTDF(const Float* metadata, const std::string& metadataCat, const std::vector<std::string>& multiMetadataCat, bool normalize) const override
|
|
389
505
|
{
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
for (Tid k = 0; k < this->K; ++k)
|
|
394
|
-
{
|
|
395
|
-
alphas[k] = exp(this->lambda.row(k) * terms) + this->alphaEps;
|
|
396
|
-
}
|
|
397
|
-
std::vector<Float> ret(this->K);
|
|
398
|
-
Float sum = doc.getSumWordWeight() + alphas.sum();
|
|
399
|
-
for (size_t k = 0; k < this->K; ++k)
|
|
400
|
-
{
|
|
401
|
-
ret[k] = (doc.numByTopic[k] + alphas[k]) / sum;
|
|
402
|
-
}
|
|
403
|
-
return ret;
|
|
404
|
-
}
|
|
405
|
-
|
|
406
|
-
std::vector<Float> getLambdaByTopic(Tid tid) const override
|
|
407
|
-
{
|
|
408
|
-
std::vector<Float> ret(this->F);
|
|
409
|
-
if (this->lambda.size())
|
|
506
|
+
Vector terms = Vector::Zero(this->mdVecSize);
|
|
507
|
+
getTermsFromMd(metadata, terms.data(), true);
|
|
508
|
+
for (auto& s : multiMetadataCat)
|
|
410
509
|
{
|
|
411
|
-
|
|
510
|
+
Vid x = this->multiMetadataDict.toWid(s);
|
|
511
|
+
if (x == non_vocab_id) throw exc::InvalidArgument("unknown multi_metadata " + text::quote(s));
|
|
512
|
+
terms[fCont + x] = 1;
|
|
412
513
|
}
|
|
413
|
-
|
|
414
|
-
|
|
514
|
+
Vid x = this->metadataDict.toWid(metadataCat);
|
|
515
|
+
if (x == non_vocab_id) throw exc::InvalidArgument("unknown metadata " + text::quote(metadataCat));
|
|
415
516
|
|
|
416
|
-
std::vector<Float> getTDF(const Float* metadata, bool normalize) const override
|
|
417
|
-
{
|
|
418
|
-
Eigen::Matrix<Float, -1, 1> terms{ this->F };
|
|
419
|
-
getTermsFromMd(metadata, terms.data(), true);
|
|
420
517
|
std::vector<Float> ret(this->K);
|
|
421
518
|
Eigen::Map<Eigen::Array<Float, -1, 1>> retMap{ ret.data(), (Eigen::Index)ret.size() };
|
|
422
|
-
retMap = (this->lambda * terms).array();
|
|
519
|
+
retMap = (this->lambda.middleCols(x * this->mdVecSize, this->mdVecSize) * terms).array();
|
|
423
520
|
if (normalize)
|
|
424
521
|
{
|
|
425
522
|
retMap = (retMap - retMap.maxCoeff()).exp();
|
|
@@ -428,16 +525,25 @@ namespace tomoto
|
|
|
428
525
|
return ret;
|
|
429
526
|
}
|
|
430
527
|
|
|
431
|
-
std::vector<Float> getTDFBatch(const Float* metadata, size_t stride, size_t cnt, bool normalize) const override
|
|
528
|
+
std::vector<Float> getTDFBatch(const Float* metadata, const std::string& metadataCat, const std::vector<std::string>& multiMetadataCat, size_t stride, size_t cnt, bool normalize) const override
|
|
432
529
|
{
|
|
433
|
-
|
|
530
|
+
Matrix terms = Matrix::Zero(this->mdVecSize, (Eigen::Index)cnt);
|
|
434
531
|
for (size_t i = 0; i < cnt; ++i)
|
|
435
532
|
{
|
|
436
533
|
getTermsFromMd(metadata + stride * i, terms.col(i).data(), true);
|
|
437
534
|
}
|
|
535
|
+
for (auto& s : multiMetadataCat)
|
|
536
|
+
{
|
|
537
|
+
Vid x = this->multiMetadataDict.toWid(s);
|
|
538
|
+
if (x == non_vocab_id) throw exc::InvalidArgument("unknown multi_metadata " + text::quote(s));
|
|
539
|
+
terms.row(fCont + x).setOnes();
|
|
540
|
+
}
|
|
541
|
+
Vid x = this->metadataDict.toWid(metadataCat);
|
|
542
|
+
if (x == non_vocab_id) throw exc::InvalidArgument("unknown metadata " + text::quote(metadataCat));
|
|
543
|
+
|
|
438
544
|
std::vector<Float> ret(this->K * cnt);
|
|
439
545
|
Eigen::Map<Eigen::Array<Float, -1, -1>> retMap{ ret.data(), (Eigen::Index)this->K, (Eigen::Index)cnt };
|
|
440
|
-
retMap = (this->lambda * terms).array();
|
|
546
|
+
retMap = (this->lambda.middleCols(x * this->mdVecSize, this->mdVecSize) * terms).array();
|
|
441
547
|
if (normalize)
|
|
442
548
|
{
|
|
443
549
|
retMap.rowwise() -= retMap.colwise().maxCoeff();
|
|
@@ -446,6 +552,7 @@ namespace tomoto
|
|
|
446
552
|
}
|
|
447
553
|
return ret;
|
|
448
554
|
}
|
|
555
|
+
|
|
449
556
|
void setMdRange(const std::vector<Float>& vMin, const std::vector<Float>& vMax) override
|
|
450
557
|
{
|
|
451
558
|
mdIntercepts = vMin;
|