tomoto 0.1.4 → 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/ext/tomoto/ct.cpp +8 -4
- data/ext/tomoto/dmr.cpp +10 -4
- data/ext/tomoto/dt.cpp +13 -4
- data/ext/tomoto/extconf.rb +1 -1
- data/ext/tomoto/gdmr.cpp +14 -6
- data/ext/tomoto/hdp.cpp +9 -4
- data/ext/tomoto/hlda.cpp +9 -4
- data/ext/tomoto/hpa.cpp +9 -4
- data/ext/tomoto/lda.cpp +8 -4
- data/ext/tomoto/llda.cpp +8 -4
- data/ext/tomoto/mglda.cpp +11 -1
- data/ext/tomoto/pa.cpp +9 -4
- data/ext/tomoto/plda.cpp +8 -4
- data/ext/tomoto/slda.cpp +13 -5
- data/lib/tomoto/gdmr.rb +2 -2
- data/lib/tomoto/version.rb +1 -1
- data/vendor/EigenRand/EigenRand/Core.h +6 -1107
- data/vendor/EigenRand/EigenRand/Dists/Basic.h +490 -43
- data/vendor/EigenRand/EigenRand/Dists/Discrete.h +916 -285
- data/vendor/EigenRand/EigenRand/Dists/GammaPoisson.h +85 -36
- data/vendor/EigenRand/EigenRand/Dists/NormalExp.h +1038 -290
- data/vendor/EigenRand/EigenRand/EigenRand +2 -2
- data/vendor/EigenRand/EigenRand/Macro.h +4 -4
- data/vendor/EigenRand/EigenRand/MorePacketMath.h +54 -22
- data/vendor/EigenRand/EigenRand/MvDists/Multinomial.h +222 -0
- data/vendor/EigenRand/EigenRand/MvDists/MvNormal.h +492 -0
- data/vendor/EigenRand/EigenRand/PacketFilter.h +2 -2
- data/vendor/EigenRand/EigenRand/PacketRandomEngine.h +2 -2
- data/vendor/EigenRand/EigenRand/RandUtils.h +65 -11
- data/vendor/EigenRand/EigenRand/doc.h +142 -25
- data/vendor/EigenRand/LICENSE +1 -1
- data/vendor/EigenRand/README.md +109 -24
- data/vendor/tomotopy/README.kr.rst +27 -6
- data/vendor/tomotopy/README.rst +29 -8
- data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +60 -12
- data/vendor/tomotopy/src/Labeling/FoRelevance.h +2 -2
- data/vendor/tomotopy/src/Labeling/Phraser.hpp +33 -21
- data/vendor/tomotopy/src/TopicModel/CT.h +8 -5
- data/vendor/tomotopy/src/TopicModel/CTModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/CTModel.hpp +29 -23
- data/vendor/tomotopy/src/TopicModel/DMR.h +33 -4
- data/vendor/tomotopy/src/TopicModel/DMRModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/DMRModel.hpp +231 -57
- data/vendor/tomotopy/src/TopicModel/DT.h +24 -5
- data/vendor/tomotopy/src/TopicModel/DTModel.cpp +2 -8
- data/vendor/tomotopy/src/TopicModel/DTModel.hpp +41 -28
- data/vendor/tomotopy/src/TopicModel/GDMR.h +31 -5
- data/vendor/tomotopy/src/TopicModel/GDMRModel.cpp +2 -7
- data/vendor/tomotopy/src/TopicModel/GDMRModel.hpp +211 -104
- data/vendor/tomotopy/src/TopicModel/HDP.h +11 -2
- data/vendor/tomotopy/src/TopicModel/HDPModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +52 -45
- data/vendor/tomotopy/src/TopicModel/HLDA.h +11 -2
- data/vendor/tomotopy/src/TopicModel/HLDAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +13 -16
- data/vendor/tomotopy/src/TopicModel/HPA.h +5 -2
- data/vendor/tomotopy/src/TopicModel/HPAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +51 -21
- data/vendor/tomotopy/src/TopicModel/LDA.h +9 -2
- data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +8 -8
- data/vendor/tomotopy/src/TopicModel/LDAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +70 -28
- data/vendor/tomotopy/src/TopicModel/LLDA.h +1 -2
- data/vendor/tomotopy/src/TopicModel/LLDAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/LLDAModel.hpp +22 -12
- data/vendor/tomotopy/src/TopicModel/MGLDA.h +12 -3
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.cpp +2 -10
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +42 -19
- data/vendor/tomotopy/src/TopicModel/PA.h +9 -4
- data/vendor/tomotopy/src/TopicModel/PAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/PAModel.hpp +48 -25
- data/vendor/tomotopy/src/TopicModel/PLDA.h +13 -2
- data/vendor/tomotopy/src/TopicModel/PLDAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/PLDAModel.hpp +27 -19
- data/vendor/tomotopy/src/TopicModel/PT.h +12 -5
- data/vendor/tomotopy/src/TopicModel/PTModel.cpp +2 -3
- data/vendor/tomotopy/src/TopicModel/PTModel.hpp +29 -14
- data/vendor/tomotopy/src/TopicModel/SLDA.h +18 -6
- data/vendor/tomotopy/src/TopicModel/SLDAModel.cpp +2 -10
- data/vendor/tomotopy/src/TopicModel/SLDAModel.hpp +93 -43
- data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +58 -23
- data/vendor/tomotopy/src/Utils/AliasMethod.hpp +6 -6
- data/vendor/tomotopy/src/Utils/Dictionary.h +11 -0
- data/vendor/tomotopy/src/Utils/SharedString.hpp +26 -1
- data/vendor/tomotopy/src/Utils/Trie.hpp +46 -21
- data/vendor/tomotopy/src/Utils/Utils.hpp +99 -14
- data/vendor/tomotopy/src/Utils/exception.h +1 -1
- data/vendor/tomotopy/src/Utils/math.h +5 -7
- data/vendor/tomotopy/src/Utils/serializer.hpp +329 -201
- data/vendor/tomotopy/src/Utils/text.hpp +8 -0
- data/vendor/tomotopy/src/Utils/tvector.hpp +49 -7
- metadata +9 -7
|
@@ -3,24 +3,36 @@
|
|
|
3
3
|
|
|
4
4
|
namespace tomoto
|
|
5
5
|
{
|
|
6
|
+
class IDMRModel;
|
|
7
|
+
|
|
6
8
|
template<TermWeight _tw>
|
|
7
9
|
struct DocumentDMR : public DocumentLDA<_tw>
|
|
8
10
|
{
|
|
9
11
|
using BaseDocument = DocumentLDA<_tw>;
|
|
10
12
|
using DocumentLDA<_tw>::DocumentLDA;
|
|
11
13
|
uint64_t metadata = 0;
|
|
14
|
+
std::vector<uint64_t> multiMetadata;
|
|
15
|
+
Vector mdVec;
|
|
16
|
+
size_t mdHash = (size_t)-1;
|
|
17
|
+
mutable Matrix cachedAlpha;
|
|
18
|
+
|
|
19
|
+
RawDoc::MiscType makeMisc(const ITopicModel* tm) const override;
|
|
12
20
|
|
|
13
21
|
DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 0, metadata);
|
|
14
|
-
DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 1, 0x00010001, metadata);
|
|
22
|
+
DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 1, 0x00010001, metadata, multiMetadata);
|
|
23
|
+
};
|
|
24
|
+
|
|
25
|
+
struct DMRArgs : public LDAArgs
|
|
26
|
+
{
|
|
27
|
+
Float alphaEps = 1e-10;
|
|
28
|
+
Float sigma = 1.0;
|
|
15
29
|
};
|
|
16
30
|
|
|
17
31
|
class IDMRModel : public ILDAModel
|
|
18
32
|
{
|
|
19
33
|
public:
|
|
20
34
|
using DefaultDocType = DocumentDMR<TermWeight::one>;
|
|
21
|
-
static IDMRModel* create(TermWeight _weight,
|
|
22
|
-
Float defaultAlpha = 1.0, Float _sigma = 1.0, Float _eta = 0.01, Float _alphaEps = 1e-10,
|
|
23
|
-
size_t seed = std::random_device{}(),
|
|
35
|
+
static IDMRModel* create(TermWeight _weight, const DMRArgs& args,
|
|
24
36
|
bool scalarRng = false);
|
|
25
37
|
|
|
26
38
|
virtual void setAlphaEps(Float _alphaEps) = 0;
|
|
@@ -28,9 +40,26 @@ namespace tomoto
|
|
|
28
40
|
virtual void setOptimRepeat(size_t repeat) = 0;
|
|
29
41
|
virtual size_t getOptimRepeat() const = 0;
|
|
30
42
|
virtual size_t getF() const = 0;
|
|
43
|
+
virtual size_t getMdVecSize() const = 0;
|
|
31
44
|
virtual Float getSigma() const = 0;
|
|
32
45
|
virtual const Dictionary& getMetadataDict() const = 0;
|
|
46
|
+
virtual const Dictionary& getMultiMetadataDict() const = 0;
|
|
33
47
|
virtual std::vector<Float> getLambdaByMetadata(size_t metadataId) const = 0;
|
|
34
48
|
virtual std::vector<Float> getLambdaByTopic(Tid tid) const = 0;
|
|
49
|
+
|
|
50
|
+
virtual std::vector<Float> getTopicPrior(
|
|
51
|
+
const std::string& metadata,
|
|
52
|
+
const std::vector<std::string>& multiMetadata,
|
|
53
|
+
bool raw = false
|
|
54
|
+
) const = 0;
|
|
35
55
|
};
|
|
56
|
+
|
|
57
|
+
template<TermWeight _tw>
|
|
58
|
+
RawDoc::MiscType DocumentDMR<_tw>::makeMisc(const ITopicModel* tm) const
|
|
59
|
+
{
|
|
60
|
+
RawDoc::MiscType ret = DocumentLDA<_tw>::makeMisc(tm);
|
|
61
|
+
auto inst = static_cast<const IDMRModel*>(tm);
|
|
62
|
+
ret["metadata"] = inst->getMetadataDict().toWord(metadata);
|
|
63
|
+
return ret;
|
|
64
|
+
}
|
|
36
65
|
}
|
|
@@ -2,12 +2,8 @@
|
|
|
2
2
|
|
|
3
3
|
namespace tomoto
|
|
4
4
|
{
|
|
5
|
-
|
|
6
|
-
template class DMRModel<TermWeight::idf>;
|
|
7
|
-
template class DMRModel<TermWeight::pmi>;*/
|
|
8
|
-
|
|
9
|
-
IDMRModel* IDMRModel::create(TermWeight _weight, size_t _K, Float _defaultAlpha, Float _sigma, Float _eta, Float _alphaEps, size_t seed, bool scalarRng)
|
|
5
|
+
IDMRModel* IDMRModel::create(TermWeight _weight, const DMRArgs& args, bool scalarRng)
|
|
10
6
|
{
|
|
11
|
-
TMT_SWITCH_TW(_weight, scalarRng, DMRModel,
|
|
7
|
+
TMT_SWITCH_TW(_weight, scalarRng, DMRModel, args);
|
|
12
8
|
}
|
|
13
9
|
}
|
|
@@ -13,7 +13,21 @@ namespace tomoto
|
|
|
13
13
|
template<TermWeight _tw>
|
|
14
14
|
struct ModelStateDMR : public ModelStateLDA<_tw>
|
|
15
15
|
{
|
|
16
|
-
|
|
16
|
+
Vector tmpK;
|
|
17
|
+
};
|
|
18
|
+
|
|
19
|
+
struct MdHash
|
|
20
|
+
{
|
|
21
|
+
size_t operator()(std::pair<uint64_t, Vector> const& p) const
|
|
22
|
+
{
|
|
23
|
+
size_t seed = p.first;
|
|
24
|
+
for (size_t i = 0; i < p.second.size(); ++i)
|
|
25
|
+
{
|
|
26
|
+
auto elem = p.second[i];
|
|
27
|
+
seed ^= std::hash<decltype(elem)>()(elem) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
|
|
28
|
+
}
|
|
29
|
+
return seed;
|
|
30
|
+
}
|
|
17
31
|
};
|
|
18
32
|
|
|
19
33
|
template<TermWeight _tw, typename _RandGen,
|
|
@@ -35,36 +49,37 @@ namespace tomoto
|
|
|
35
49
|
|
|
36
50
|
static constexpr char TMID[] = "DMR\0";
|
|
37
51
|
|
|
38
|
-
|
|
39
|
-
|
|
52
|
+
Matrix lambda;
|
|
53
|
+
mutable std::unordered_map<std::pair<uint64_t, Vector>, size_t, MdHash> mdHashMap;
|
|
54
|
+
mutable Matrix cachedAlphas;
|
|
40
55
|
Float sigma;
|
|
41
|
-
uint32_t F = 0;
|
|
56
|
+
uint32_t F = 0, mdVecSize = 1;
|
|
42
57
|
uint32_t optimRepeat = 5;
|
|
43
58
|
Float alphaEps = 1e-10;
|
|
44
|
-
Float temperatureScale = 0;
|
|
45
59
|
static constexpr Float maxLambda = 10;
|
|
46
60
|
static constexpr size_t maxBFGSIteration = 10;
|
|
47
61
|
|
|
48
62
|
Dictionary metadataDict;
|
|
63
|
+
Dictionary multiMetadataDict;
|
|
49
64
|
LBFGSpp::LBFGSSolver<Float, LBFGSpp::LineSearchBracketing> solver;
|
|
50
65
|
|
|
51
|
-
Float getNegativeLambdaLL(Eigen::Ref<
|
|
66
|
+
Float getNegativeLambdaLL(Eigen::Ref<Vector> x, Vector& g) const
|
|
52
67
|
{
|
|
53
68
|
g = (x.array() - log(this->alpha)) / pow(sigma, 2);
|
|
54
69
|
return (x.array() - log(this->alpha)).pow(2).sum() / 2 / pow(sigma, 2);
|
|
55
70
|
}
|
|
56
71
|
|
|
57
|
-
Float evaluateLambdaObj(Eigen::Ref<
|
|
72
|
+
Float evaluateLambdaObj(Eigen::Ref<Vector> x, Vector& g, ThreadPool& pool, _ModelState* localData) const
|
|
58
73
|
{
|
|
59
74
|
// if one of x is greater than maxLambda, return +inf for preventing searching more
|
|
60
75
|
if ((x.array() > maxLambda).any()) return INFINITY;
|
|
61
76
|
|
|
62
77
|
const auto K = this->K;
|
|
63
78
|
|
|
64
|
-
Float fx = -
|
|
65
|
-
|
|
79
|
+
Float fx = -static_cast<const DerivedClass*>(this)->getNegativeLambdaLL(x, g);
|
|
80
|
+
Eigen::Map<Matrix> xReshaped{ x.data(), (Eigen::Index)K, (Eigen::Index)(F * mdVecSize) };
|
|
66
81
|
|
|
67
|
-
std::vector<std::future<Eigen::
|
|
82
|
+
std::vector<std::future<Eigen::Array<Float, -1, 1>>> res;
|
|
68
83
|
const size_t chStride = pool.getNumWorkers() * 8;
|
|
69
84
|
for (size_t ch = 0; ch < chStride; ++ch)
|
|
70
85
|
{
|
|
@@ -72,28 +87,28 @@ namespace tomoto
|
|
|
72
87
|
{
|
|
73
88
|
auto& tmpK = localData[threadId].tmpK;
|
|
74
89
|
if (!tmpK.size()) tmpK.resize(this->K);
|
|
75
|
-
Eigen::
|
|
90
|
+
Eigen::Array<Float, -1, 1> val = Eigen::Array<Float, -1, 1>::Zero(K * F * mdVecSize + 1);
|
|
91
|
+
Eigen::Map<Matrix> grad{ val.data(), (Eigen::Index)K, (Eigen::Index)(F * mdVecSize) };
|
|
92
|
+
Float& fx = val[K * F * mdVecSize];
|
|
76
93
|
for (size_t docId = ch; docId < this->docs.size(); docId += chStride)
|
|
77
94
|
{
|
|
78
95
|
const auto& doc = this->docs[docId];
|
|
79
|
-
auto alphaDoc =
|
|
96
|
+
auto alphaDoc = ((xReshaped.middleCols(doc.metadata * mdVecSize, mdVecSize) * doc.mdVec).array().exp() + alphaEps).matrix().eval();
|
|
80
97
|
Float alphaSum = alphaDoc.sum();
|
|
81
98
|
for (Tid k = 0; k < K; ++k)
|
|
82
99
|
{
|
|
83
|
-
|
|
100
|
+
fx -= math::lgammaT(alphaDoc[k]) - math::lgammaT(doc.numByTopic[k] + alphaDoc[k]);
|
|
84
101
|
if (!std::isfinite(alphaDoc[k]) && alphaDoc[k] > 0) tmpK[k] = 0;
|
|
85
102
|
else tmpK[k] = -(math::digammaT(alphaDoc[k]) - math::digammaT(doc.numByTopic[k] + alphaDoc[k]));
|
|
86
103
|
}
|
|
87
|
-
|
|
88
|
-
//tmpK = -(digammaApprox(alphaDoc.array()) - digammaApprox(doc.numByTopic.array().cast<Float>() + alphaDoc.array()));
|
|
89
|
-
val[K * F] += math::lgammaT(alphaSum) - math::lgammaT(doc.getSumWordWeight() + alphaSum);
|
|
104
|
+
fx += math::lgammaT(alphaSum) - math::lgammaT(doc.getSumWordWeight() + alphaSum);
|
|
90
105
|
Float t = math::digammaT(alphaSum) - math::digammaT(doc.getSumWordWeight() + alphaSum);
|
|
91
106
|
if (!std::isfinite(alphaSum) && alphaSum > 0)
|
|
92
107
|
{
|
|
93
|
-
|
|
108
|
+
fx = -INFINITY;
|
|
94
109
|
t = 0;
|
|
95
110
|
}
|
|
96
|
-
|
|
111
|
+
grad.middleCols(doc.metadata * mdVecSize, mdVecSize) -= (alphaDoc.array() * (tmpK.array() + t)).matrix() * doc.mdVec.transpose();
|
|
97
112
|
}
|
|
98
113
|
return val;
|
|
99
114
|
}));
|
|
@@ -101,8 +116,8 @@ namespace tomoto
|
|
|
101
116
|
for (auto& r : res)
|
|
102
117
|
{
|
|
103
118
|
auto ret = r.get();
|
|
104
|
-
fx += ret[K * F];
|
|
105
|
-
g += ret.head(K * F);
|
|
119
|
+
fx += ret[K * F * mdVecSize];
|
|
120
|
+
g += ret.head(K * F * mdVecSize).matrix();
|
|
106
121
|
}
|
|
107
122
|
|
|
108
123
|
// positive fx is an error from limited precision of float.
|
|
@@ -112,24 +127,24 @@ namespace tomoto
|
|
|
112
127
|
|
|
113
128
|
void initParameters()
|
|
114
129
|
{
|
|
115
|
-
|
|
116
|
-
for (size_t
|
|
130
|
+
lambda = Eigen::Rand::normalLike(lambda, this->rg, 0, sigma);
|
|
131
|
+
for (size_t f = 0; f < F; ++f)
|
|
117
132
|
{
|
|
118
|
-
lambda(
|
|
133
|
+
lambda.col(f * mdVecSize) += this->alphas.array().log().matrix();
|
|
119
134
|
}
|
|
120
135
|
}
|
|
121
136
|
|
|
122
137
|
void optimizeParameters(ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
|
|
123
138
|
{
|
|
124
|
-
|
|
139
|
+
Matrix bLambda;
|
|
125
140
|
Float fx = 0, bestFx = INFINITY;
|
|
126
141
|
for (size_t i = 0; i < optimRepeat; ++i)
|
|
127
142
|
{
|
|
128
143
|
static_cast<DerivedClass*>(this)->initParameters();
|
|
129
|
-
int ret = solver.minimize([this, &pool, localData](Eigen::Ref<
|
|
144
|
+
int ret = solver.minimize([this, &pool, localData](Eigen::Ref<Vector> x, Vector& g)
|
|
130
145
|
{
|
|
131
146
|
return static_cast<DerivedClass*>(this)->evaluateLambdaObj(x, g, pool, localData);
|
|
132
|
-
}, Eigen::Map<
|
|
147
|
+
}, Eigen::Map<Vector>(lambda.data(), lambda.size()), fx);
|
|
133
148
|
|
|
134
149
|
if (fx < bestFx)
|
|
135
150
|
{
|
|
@@ -140,44 +155,60 @@ namespace tomoto
|
|
|
140
155
|
}
|
|
141
156
|
if (!std::isfinite(bestFx))
|
|
142
157
|
{
|
|
143
|
-
throw
|
|
158
|
+
throw exc::TrainingError{ "optimizing parameters has been failed!" };
|
|
144
159
|
}
|
|
145
160
|
lambda = bLambda;
|
|
161
|
+
updateCachedAlphas();
|
|
146
162
|
//std::cerr << fx << std::endl;
|
|
147
|
-
expLambda = lambda.array().exp() + alphaEps;
|
|
148
163
|
}
|
|
149
164
|
|
|
150
|
-
int restoreFromTrainingError(const
|
|
165
|
+
int restoreFromTrainingError(const exc::TrainingError& e, ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
|
|
151
166
|
{
|
|
152
167
|
std::cerr << "Failed to optimize! Reset prior and retry!" << std::endl;
|
|
153
168
|
lambda.setZero();
|
|
154
|
-
|
|
169
|
+
updateCachedAlphas();
|
|
155
170
|
return 0;
|
|
156
171
|
}
|
|
157
172
|
|
|
173
|
+
auto getCachedAlpha(const _DocType& doc) const
|
|
174
|
+
{
|
|
175
|
+
if (doc.mdHash < cachedAlphas.cols())
|
|
176
|
+
{
|
|
177
|
+
return cachedAlphas.col(doc.mdHash);
|
|
178
|
+
}
|
|
179
|
+
else
|
|
180
|
+
{
|
|
181
|
+
if (!doc.cachedAlpha.size())
|
|
182
|
+
{
|
|
183
|
+
doc.cachedAlpha = (lambda.middleCols(doc.metadata * mdVecSize, mdVecSize) * doc.mdVec).array().exp() + alphaEps;
|
|
184
|
+
}
|
|
185
|
+
return doc.cachedAlpha.col(0);
|
|
186
|
+
}
|
|
187
|
+
}
|
|
188
|
+
|
|
158
189
|
template<bool _asymEta>
|
|
159
190
|
Float* getZLikelihoods(_ModelState& ld, const _DocType& doc, size_t docId, size_t vid) const
|
|
160
191
|
{
|
|
161
192
|
const size_t V = this->realV;
|
|
162
193
|
assert(vid < V);
|
|
163
194
|
auto etaHelper = this->template getEtaHelper<_asymEta>();
|
|
195
|
+
auto alphas = getCachedAlpha(doc);
|
|
164
196
|
auto& zLikelihood = ld.zLikelihood;
|
|
165
|
-
zLikelihood = (doc.numByTopic.array().template cast<Float>() +
|
|
197
|
+
zLikelihood = (doc.numByTopic.array().template cast<Float>() + alphas.array())
|
|
166
198
|
* (ld.numByTopicWord.col(vid).array().template cast<Float>() + etaHelper.getEta(vid))
|
|
167
199
|
/ (ld.numByTopic.array().template cast<Float>() + etaHelper.getEtaSum());
|
|
168
200
|
|
|
169
201
|
sample::prefixSum(zLikelihood.data(), this->K);
|
|
170
202
|
return &zLikelihood[0];
|
|
171
203
|
}
|
|
172
|
-
|
|
173
204
|
|
|
174
205
|
double getLLDocTopic(const _DocType& doc) const
|
|
175
206
|
{
|
|
176
207
|
const size_t V = this->realV;
|
|
177
208
|
const auto K = this->K;
|
|
178
209
|
|
|
179
|
-
auto alphaDoc =
|
|
180
|
-
|
|
210
|
+
auto alphaDoc = getCachedAlpha(doc);
|
|
211
|
+
|
|
181
212
|
Float ll = 0;
|
|
182
213
|
Float alphaSum = alphaDoc.sum();
|
|
183
214
|
for (Tid k = 0; k < K; ++k)
|
|
@@ -199,7 +230,7 @@ namespace tomoto
|
|
|
199
230
|
for (; _first != _last; ++_first)
|
|
200
231
|
{
|
|
201
232
|
auto& doc = *_first;
|
|
202
|
-
auto alphaDoc =
|
|
233
|
+
auto alphaDoc = getCachedAlpha(doc);
|
|
203
234
|
Float alphaSum = alphaDoc.sum();
|
|
204
235
|
|
|
205
236
|
for (Tid k = 0; k < K; ++k)
|
|
@@ -234,45 +265,133 @@ namespace tomoto
|
|
|
234
265
|
return ll;
|
|
235
266
|
}
|
|
236
267
|
|
|
268
|
+
void updateCachedAlphas() const
|
|
269
|
+
{
|
|
270
|
+
cachedAlphas.resize(this->K, mdHashMap.size());
|
|
271
|
+
|
|
272
|
+
for (auto& p : mdHashMap)
|
|
273
|
+
{
|
|
274
|
+
cachedAlphas.col(p.second) = (lambda.middleCols(p.first.first * mdVecSize, mdVecSize) * p.first.second).array().exp() + alphaEps;
|
|
275
|
+
}
|
|
276
|
+
}
|
|
277
|
+
|
|
278
|
+
void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
|
|
279
|
+
{
|
|
280
|
+
BaseClass::prepareDoc(doc, docId, wordSize);
|
|
281
|
+
|
|
282
|
+
doc.mdVec = Vector::Zero(mdVecSize);
|
|
283
|
+
doc.mdVec[0] = 1;
|
|
284
|
+
for (auto x : doc.multiMetadata)
|
|
285
|
+
{
|
|
286
|
+
doc.mdVec[x + 1] = 1;
|
|
287
|
+
}
|
|
288
|
+
|
|
289
|
+
auto p = std::make_pair(doc.metadata, doc.mdVec);
|
|
290
|
+
auto it = mdHashMap.find(p);
|
|
291
|
+
if (it == mdHashMap.end())
|
|
292
|
+
{
|
|
293
|
+
it = mdHashMap.emplace(p, mdHashMap.size()).first;
|
|
294
|
+
}
|
|
295
|
+
doc.mdHash = it->second;
|
|
296
|
+
}
|
|
297
|
+
|
|
237
298
|
void initGlobalState(bool initDocs)
|
|
238
299
|
{
|
|
239
300
|
BaseClass::initGlobalState(initDocs);
|
|
240
|
-
this->globalState.tmpK =
|
|
301
|
+
this->globalState.tmpK = Vector::Zero(this->K);
|
|
241
302
|
F = metadataDict.size();
|
|
303
|
+
mdVecSize = multiMetadataDict.size() + 1;
|
|
242
304
|
if (initDocs)
|
|
243
305
|
{
|
|
244
|
-
lambda
|
|
306
|
+
lambda.resize(this->K, F * mdVecSize);
|
|
307
|
+
for (size_t f = 0; f < F; ++f)
|
|
308
|
+
{
|
|
309
|
+
lambda.col(f * mdVecSize) = this->alphas.array().log();
|
|
310
|
+
lambda.middleCols(f * mdVecSize + 1, mdVecSize - 1).setZero();
|
|
311
|
+
}
|
|
245
312
|
}
|
|
313
|
+
else
|
|
314
|
+
{
|
|
315
|
+
for (auto& doc : this->docs)
|
|
316
|
+
{
|
|
317
|
+
if (doc.mdVec.size() == mdVecSize) continue;
|
|
318
|
+
doc.mdVec = Vector::Zero(mdVecSize);
|
|
319
|
+
doc.mdVec[0] = 1;
|
|
320
|
+
for (auto x : doc.multiMetadata)
|
|
321
|
+
{
|
|
322
|
+
doc.mdVec[x + 1] = 1;
|
|
323
|
+
}
|
|
324
|
+
|
|
325
|
+
auto p = std::make_pair(doc.metadata, doc.mdVec);
|
|
326
|
+
auto it = this->mdHashMap.find(p);
|
|
327
|
+
if (it == this->mdHashMap.end())
|
|
328
|
+
{
|
|
329
|
+
it = this->mdHashMap.emplace(p, mdHashMap.size()).first;
|
|
330
|
+
}
|
|
331
|
+
doc.mdHash = it->second;
|
|
332
|
+
}
|
|
333
|
+
}
|
|
334
|
+
|
|
246
335
|
if (_Flags & flags::continuous_doc_data) this->numByTopicDoc = Eigen::Matrix<WeightType, -1, -1>::Zero(this->K, this->docs.size());
|
|
247
|
-
expLambda = lambda.array().exp();
|
|
248
336
|
LBFGSpp::LBFGSParam<Float> param;
|
|
249
337
|
param.max_iterations = maxBFGSIteration;
|
|
250
338
|
solver = decltype(solver){ param };
|
|
251
339
|
}
|
|
252
340
|
|
|
341
|
+
void prepareShared()
|
|
342
|
+
{
|
|
343
|
+
BaseClass::prepareShared();
|
|
344
|
+
|
|
345
|
+
for (auto doc : this->docs)
|
|
346
|
+
{
|
|
347
|
+
if (doc.mdHash != (size_t)-1) continue;
|
|
348
|
+
|
|
349
|
+
auto p = std::make_pair(doc.metadata, doc.mdVec);
|
|
350
|
+
auto it = mdHashMap.find(p);
|
|
351
|
+
if (it == mdHashMap.end())
|
|
352
|
+
{
|
|
353
|
+
it = mdHashMap.emplace(p, mdHashMap.size()).first;
|
|
354
|
+
}
|
|
355
|
+
doc.mdHash = it->second;
|
|
356
|
+
}
|
|
357
|
+
|
|
358
|
+
updateCachedAlphas();
|
|
359
|
+
}
|
|
360
|
+
|
|
253
361
|
public:
|
|
254
362
|
DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, sigma, alphaEps, metadataDict, lambda);
|
|
255
|
-
DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, sigma, alphaEps, metadataDict, lambda);
|
|
363
|
+
DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, sigma, alphaEps, metadataDict, lambda, multiMetadataDict);
|
|
256
364
|
|
|
257
|
-
DMRModel(
|
|
258
|
-
|
|
259
|
-
: BaseClass(_K, defaultAlpha, _eta, _rg), sigma(_sigma), alphaEps(_alphaEps)
|
|
365
|
+
DMRModel(const DMRArgs& args)
|
|
366
|
+
: BaseClass(args), sigma(args.sigma), alphaEps(args.alphaEps)
|
|
260
367
|
{
|
|
261
|
-
if (
|
|
368
|
+
if (sigma <= 0) THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong sigma value (sigma = %f)", sigma));
|
|
262
369
|
}
|
|
263
370
|
|
|
264
371
|
template<bool _const = false>
|
|
265
|
-
_DocType& _updateDoc(_DocType& doc, const std::string& metadata)
|
|
372
|
+
_DocType& _updateDoc(_DocType& doc, const std::string& metadata, const std::vector<std::string>& mdVec = {})
|
|
266
373
|
{
|
|
267
374
|
Vid xid;
|
|
268
375
|
if (_const)
|
|
269
376
|
{
|
|
270
377
|
xid = metadataDict.toWid(metadata);
|
|
271
|
-
if (xid == (Vid)-1) throw
|
|
378
|
+
if (xid == (Vid)-1) throw exc::InvalidArgument("unknown metadata '" + metadata + "'");
|
|
379
|
+
|
|
380
|
+
for (auto& m : mdVec)
|
|
381
|
+
{
|
|
382
|
+
Vid x = multiMetadataDict.toWid(m);
|
|
383
|
+
if (x == (Vid)-1) throw exc::InvalidArgument("unknown multi_metadata '" + m + "'");
|
|
384
|
+
doc.multiMetadata.emplace_back(x);
|
|
385
|
+
}
|
|
272
386
|
}
|
|
273
387
|
else
|
|
274
388
|
{
|
|
275
389
|
xid = metadataDict.add(metadata);
|
|
390
|
+
|
|
391
|
+
for (auto& m : mdVec)
|
|
392
|
+
{
|
|
393
|
+
doc.multiMetadata.emplace_back(multiMetadataDict.add(m));
|
|
394
|
+
}
|
|
276
395
|
}
|
|
277
396
|
doc.metadata = xid;
|
|
278
397
|
return doc;
|
|
@@ -281,28 +400,41 @@ namespace tomoto
|
|
|
281
400
|
size_t addDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) override
|
|
282
401
|
{
|
|
283
402
|
auto doc = this->template _makeFromRawDoc<false>(rawDoc, tokenizer);
|
|
284
|
-
return this->_addDoc(_updateDoc(doc,
|
|
403
|
+
return this->_addDoc(_updateDoc(doc,
|
|
404
|
+
rawDoc.template getMisc<std::string>("metadata"),
|
|
405
|
+
rawDoc.template getMiscDefault<std::vector<std::string>>("multi_metadata")
|
|
406
|
+
));
|
|
285
407
|
}
|
|
286
408
|
|
|
287
409
|
std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) const override
|
|
288
410
|
{
|
|
289
411
|
auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc, tokenizer);
|
|
290
|
-
return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc,
|
|
412
|
+
return std::make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc,
|
|
413
|
+
rawDoc.template getMisc<std::string>("metadata"),
|
|
414
|
+
rawDoc.template getMiscDefault<std::vector<std::string>>("multi_metadata")
|
|
415
|
+
));
|
|
291
416
|
}
|
|
292
417
|
|
|
293
418
|
size_t addDoc(const RawDoc& rawDoc) override
|
|
294
419
|
{
|
|
295
420
|
auto doc = this->_makeFromRawDoc(rawDoc);
|
|
296
|
-
return this->_addDoc(_updateDoc(doc,
|
|
421
|
+
return this->_addDoc(_updateDoc(doc,
|
|
422
|
+
rawDoc.template getMisc<std::string>("metadata"),
|
|
423
|
+
rawDoc.template getMiscDefault<std::vector<std::string>>("multi_metadata")
|
|
424
|
+
));
|
|
297
425
|
}
|
|
298
426
|
|
|
299
427
|
std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc) const override
|
|
300
428
|
{
|
|
301
429
|
auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc);
|
|
302
|
-
return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc,
|
|
430
|
+
return std::make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc,
|
|
431
|
+
rawDoc.template getMisc<std::string>("metadata"),
|
|
432
|
+
rawDoc.template getMiscDefault<std::vector<std::string>>("multi_metadata")
|
|
433
|
+
));
|
|
303
434
|
}
|
|
304
435
|
|
|
305
436
|
GETTER(F, size_t, F);
|
|
437
|
+
GETTER(MdVecSize, size_t, mdVecSize);
|
|
306
438
|
GETTER(Sigma, Float, sigma);
|
|
307
439
|
GETTER(AlphaEps, Float, alphaEps);
|
|
308
440
|
GETTER(OptimRepeat, size_t, optimRepeat);
|
|
@@ -317,12 +449,19 @@ namespace tomoto
|
|
|
317
449
|
optimRepeat = _optimRepeat;
|
|
318
450
|
}
|
|
319
451
|
|
|
320
|
-
std::vector<Float> getTopicsByDoc(const _DocType& doc) const
|
|
452
|
+
std::vector<Float> getTopicsByDoc(const _DocType& doc, bool normalize) const
|
|
321
453
|
{
|
|
322
454
|
std::vector<Float> ret(this->K);
|
|
323
|
-
auto alphaDoc =
|
|
324
|
-
Eigen::Map<Eigen::
|
|
325
|
-
|
|
455
|
+
auto alphaDoc = getCachedAlpha(doc);
|
|
456
|
+
Eigen::Map<Eigen::Array<Float, -1, 1>> m{ ret.data(), this->K };
|
|
457
|
+
if (normalize)
|
|
458
|
+
{
|
|
459
|
+
m = (doc.numByTopic.array().template cast<Float>() + alphaDoc.array()) / (doc.getSumWordWeight() + alphaDoc.sum());
|
|
460
|
+
}
|
|
461
|
+
else
|
|
462
|
+
{
|
|
463
|
+
m = doc.numByTopic.array().template cast<Float>() + alphaDoc.array();
|
|
464
|
+
}
|
|
326
465
|
return ret;
|
|
327
466
|
}
|
|
328
467
|
|
|
@@ -330,17 +469,52 @@ namespace tomoto
|
|
|
330
469
|
{
|
|
331
470
|
assert(metadataId < metadataDict.size());
|
|
332
471
|
auto l = lambda.col(metadataId);
|
|
333
|
-
return { l.data(), l.data() +
|
|
472
|
+
return { l.data(), l.data() + l.size() };
|
|
334
473
|
}
|
|
335
474
|
|
|
336
475
|
std::vector<Float> getLambdaByTopic(Tid tid) const override
|
|
337
476
|
{
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
477
|
+
std::vector<Float> ret(F * mdVecSize);
|
|
478
|
+
if (this->lambda.size())
|
|
479
|
+
{
|
|
480
|
+
Eigen::Map<Vector>{ ret.data(), (Eigen::Index)ret.size() } = this->lambda.row(tid);
|
|
481
|
+
}
|
|
482
|
+
return ret;
|
|
483
|
+
}
|
|
484
|
+
|
|
485
|
+
std::vector<Float> getTopicPrior(const std::string& metadata,
|
|
486
|
+
const std::vector<std::string>& mdVec,
|
|
487
|
+
bool raw = false
|
|
488
|
+
) const override
|
|
489
|
+
{
|
|
490
|
+
Vid xid = metadataDict.toWid(metadata);
|
|
491
|
+
if (xid == (Vid)-1) throw exc::InvalidArgument("unknown metadata '" + metadata + "'");
|
|
492
|
+
|
|
493
|
+
Vector xs = Vector::Zero(mdVecSize);
|
|
494
|
+
xs[0] = 1;
|
|
495
|
+
for (auto& m : mdVec)
|
|
496
|
+
{
|
|
497
|
+
Vid x = multiMetadataDict.toWid(m);
|
|
498
|
+
if (x == (Vid)-1) throw exc::InvalidArgument("unknown multi_metadata '" + m + "'");
|
|
499
|
+
xs[x + 1] = 1;
|
|
500
|
+
}
|
|
501
|
+
|
|
502
|
+
std::vector<Float> ret(this->K);
|
|
503
|
+
Eigen::Map<Vector> map{ ret.data(), (Eigen::Index)ret.size() };
|
|
504
|
+
|
|
505
|
+
if (raw)
|
|
506
|
+
{
|
|
507
|
+
map = lambda.middleCols(xid * mdVecSize, mdVecSize) * xs;
|
|
508
|
+
}
|
|
509
|
+
else
|
|
510
|
+
{
|
|
511
|
+
map = (lambda.middleCols(xid * mdVecSize, mdVecSize) * xs).array().exp() + alphaEps;
|
|
512
|
+
}
|
|
513
|
+
return ret;
|
|
341
514
|
}
|
|
342
515
|
|
|
343
516
|
const Dictionary& getMetadataDict() const override { return metadataDict; }
|
|
517
|
+
const Dictionary& getMultiMetadataDict() const override { return multiMetadataDict; }
|
|
344
518
|
};
|
|
345
519
|
|
|
346
520
|
/* This is for preventing 'undefined symbol' problem in compiling by clang. */
|