tomoto 0.2.2 → 0.2.3
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/ext/tomoto/ct.cpp +11 -11
- data/ext/tomoto/dmr.cpp +14 -13
- data/ext/tomoto/dt.cpp +14 -14
- data/ext/tomoto/ext.cpp +7 -7
- data/ext/tomoto/extconf.rb +1 -3
- data/ext/tomoto/gdmr.cpp +7 -7
- data/ext/tomoto/hdp.cpp +9 -9
- data/ext/tomoto/hlda.cpp +13 -13
- data/ext/tomoto/hpa.cpp +5 -5
- data/ext/tomoto/lda.cpp +42 -39
- data/ext/tomoto/llda.cpp +6 -6
- data/ext/tomoto/mglda.cpp +15 -15
- data/ext/tomoto/pa.cpp +6 -6
- data/ext/tomoto/plda.cpp +6 -6
- data/ext/tomoto/slda.cpp +8 -8
- data/ext/tomoto/utils.h +16 -70
- data/lib/tomoto/version.rb +1 -1
- data/vendor/tomotopy/README.kr.rst +57 -0
- data/vendor/tomotopy/README.rst +55 -0
- data/vendor/tomotopy/src/Labeling/Phraser.hpp +3 -3
- data/vendor/tomotopy/src/TopicModel/CTModel.hpp +5 -2
- data/vendor/tomotopy/src/TopicModel/DMRModel.hpp +5 -2
- data/vendor/tomotopy/src/TopicModel/DTModel.hpp +5 -2
- data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +4 -4
- data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +5 -2
- data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +2 -2
- data/vendor/tomotopy/src/TopicModel/LDA.h +3 -3
- data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +3 -3
- data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +34 -14
- data/vendor/tomotopy/src/TopicModel/LLDAModel.hpp +5 -2
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +2 -2
- data/vendor/tomotopy/src/TopicModel/PAModel.hpp +1 -1
- data/vendor/tomotopy/src/TopicModel/PLDAModel.hpp +5 -2
- data/vendor/tomotopy/src/TopicModel/PTModel.hpp +5 -2
- data/vendor/tomotopy/src/TopicModel/SLDAModel.hpp +4 -1
- data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +48 -21
- data/vendor/tomotopy/src/Utils/AliasMethod.hpp +5 -4
- data/vendor/tomotopy/src/Utils/Dictionary.h +2 -2
- data/vendor/tomotopy/src/Utils/MultiNormalDistribution.hpp +1 -1
- data/vendor/tomotopy/src/Utils/TruncMultiNormal.hpp +1 -1
- data/vendor/tomotopy/src/Utils/math.h +2 -2
- data/vendor/tomotopy/src/Utils/serializer.hpp +30 -5
- metadata +6 -6
@@ -335,7 +335,10 @@ namespace tomoto
|
|
335
335
|
friend typename BaseClass::BaseClass;
|
336
336
|
using WeightType = typename BaseClass::WeightType;
|
337
337
|
|
338
|
-
static constexpr
|
338
|
+
static constexpr auto tmid()
|
339
|
+
{
|
340
|
+
return serializer::to_key("hLDA");
|
341
|
+
}
|
339
342
|
|
340
343
|
Float gamma;
|
341
344
|
|
@@ -422,7 +425,7 @@ namespace tomoto
|
|
422
425
|
}
|
423
426
|
|
424
427
|
template<int _inc>
|
425
|
-
inline void addWordTo(_ModelState& ld, _DocType& doc,
|
428
|
+
inline void addWordTo(_ModelState& ld, _DocType& doc, size_t pid, Vid vid, Tid level) const
|
426
429
|
{
|
427
430
|
assert(vid < this->realV);
|
428
431
|
constexpr bool _dec = _inc < 0 && _tw != TermWeight::one;
|
@@ -143,7 +143,7 @@ namespace tomoto
|
|
143
143
|
}
|
144
144
|
|
145
145
|
template<int _inc>
|
146
|
-
inline void addWordTo(_ModelState& ld, _DocType& doc,
|
146
|
+
inline void addWordTo(_ModelState& ld, _DocType& doc, size_t pid, Vid vid, Tid z1, Tid z2) const
|
147
147
|
{
|
148
148
|
assert(vid < this->realV);
|
149
149
|
constexpr bool _dec = _inc < 0 && _tw != TermWeight::one;
|
@@ -540,7 +540,7 @@ namespace tomoto
|
|
540
540
|
return ret;
|
541
541
|
}
|
542
542
|
|
543
|
-
std::vector<Float>
|
543
|
+
std::vector<Float> _getTopicsByDoc(const _DocType& doc, bool normalize) const
|
544
544
|
{
|
545
545
|
std::vector<Float> ret(1 + this->K + K2);
|
546
546
|
Float sum = doc.getSumWordWeight() + this->alphas.sum();
|
@@ -121,7 +121,7 @@ namespace tomoto
|
|
121
121
|
|
122
122
|
void updateSumWordWeight(size_t realV)
|
123
123
|
{
|
124
|
-
sumWordWeight = std::count_if(static_cast<_Base*>(this)->words.begin(), static_cast<_Base*>(this)->words.end(), [realV](Vid w)
|
124
|
+
sumWordWeight = (int32_t)std::count_if(static_cast<_Base*>(this)->words.begin(), static_cast<_Base*>(this)->words.end(), [realV](Vid w)
|
125
125
|
{
|
126
126
|
return w < realV;
|
127
127
|
});
|
@@ -164,8 +164,8 @@ namespace tomoto
|
|
164
164
|
struct LDAArgs
|
165
165
|
{
|
166
166
|
size_t k = 1;
|
167
|
-
std::vector<Float> alpha = { 0.1 };
|
168
|
-
Float eta = 0.01;
|
167
|
+
std::vector<Float> alpha = { (Float)0.1 };
|
168
|
+
Float eta = (Float)0.01;
|
169
169
|
size_t seed = std::random_device{}();
|
170
170
|
};
|
171
171
|
|
@@ -82,7 +82,7 @@ namespace tomoto
|
|
82
82
|
friend BaseClass;
|
83
83
|
|
84
84
|
static constexpr const char TWID[] = "one\0";
|
85
|
-
static constexpr
|
85
|
+
static constexpr const char TMID[] = "LDA\0";
|
86
86
|
|
87
87
|
Float alpha;
|
88
88
|
Vector alphas;
|
@@ -125,7 +125,7 @@ namespace tomoto
|
|
125
125
|
}
|
126
126
|
|
127
127
|
template<int _Inc, typename _Vec>
|
128
|
-
inline void addWordTo(_ModelState& ld, _DocType& doc,
|
128
|
+
inline void addWordTo(_ModelState& ld, _DocType& doc, size_t pid, Vid vid, _Vec tDist) const
|
129
129
|
{
|
130
130
|
assert(vid < this->realV);
|
131
131
|
constexpr bool _dec = _Inc < 0;
|
@@ -392,7 +392,7 @@ namespace tomoto
|
|
392
392
|
return static_cast<const DerivedClass*>(this)->_getTopicsCount();
|
393
393
|
}
|
394
394
|
|
395
|
-
std::vector<Float>
|
395
|
+
std::vector<Float> _getTopicsByDoc(const _DocType& doc) const
|
396
396
|
{
|
397
397
|
std::vector<Float> ret(K);
|
398
398
|
Float sum = doc.getSumWordWeight() + K * alpha;
|
@@ -117,19 +117,28 @@ namespace tomoto
|
|
117
117
|
template<>
|
118
118
|
struct TwId<TermWeight::one>
|
119
119
|
{
|
120
|
-
static constexpr
|
120
|
+
static constexpr auto twid()
|
121
|
+
{
|
122
|
+
return serializer::to_key("one\0");
|
123
|
+
}
|
121
124
|
};
|
122
125
|
|
123
126
|
template<>
|
124
127
|
struct TwId<TermWeight::idf>
|
125
128
|
{
|
126
|
-
static constexpr
|
129
|
+
static constexpr auto twid()
|
130
|
+
{
|
131
|
+
return serializer::to_key("idf\0");
|
132
|
+
}
|
127
133
|
};
|
128
134
|
|
129
135
|
template<>
|
130
136
|
struct TwId<TermWeight::pmi>
|
131
137
|
{
|
132
|
-
static constexpr
|
138
|
+
static constexpr auto twid()
|
139
|
+
{
|
140
|
+
return serializer::to_key("pmi\0");
|
141
|
+
}
|
133
142
|
};
|
134
143
|
|
135
144
|
// to make HDP friend of LDA for HDPModel::converToLDA
|
@@ -169,7 +178,11 @@ namespace tomoto
|
|
169
178
|
typename>
|
170
179
|
friend class HDPModel;
|
171
180
|
|
172
|
-
static constexpr
|
181
|
+
static constexpr auto tmid()
|
182
|
+
{
|
183
|
+
return serializer::to_key("LDA\0");
|
184
|
+
}
|
185
|
+
|
173
186
|
using WeightType = typename std::conditional<_tw == TermWeight::one, int32_t, float>::type;
|
174
187
|
|
175
188
|
enum { m_flags = _Flags };
|
@@ -189,7 +202,7 @@ namespace tomoto
|
|
189
202
|
struct ExtraDocData
|
190
203
|
{
|
191
204
|
std::vector<Vid> vChunkOffset;
|
192
|
-
Eigen::Matrix<
|
205
|
+
Eigen::Matrix<size_t, -1, -1> chunkOffsetByDoc;
|
193
206
|
};
|
194
207
|
|
195
208
|
ExtraDocData eddTrain;
|
@@ -261,7 +274,7 @@ namespace tomoto
|
|
261
274
|
}
|
262
275
|
|
263
276
|
template<int _inc>
|
264
|
-
inline void addWordTo(_ModelState& ld, _DocType& doc,
|
277
|
+
inline void addWordTo(_ModelState& ld, _DocType& doc, size_t pid, Vid vid, Tid tid) const
|
265
278
|
{
|
266
279
|
assert(tid < K);
|
267
280
|
assert(vid < this->realV);
|
@@ -620,7 +633,7 @@ namespace tomoto
|
|
620
633
|
for (Vid v = 0; v < V; ++v)
|
621
634
|
{
|
622
635
|
if (!ld.numByTopicWord(k, v)) continue;
|
623
|
-
ll += math::lgammaT(ld.numByTopicWord(k, v) + etaByTopicWord(
|
636
|
+
ll += math::lgammaT(ld.numByTopicWord(k, v) + etaByTopicWord(k, v)) - math::lgammaT(etaByTopicWord(k, v));
|
624
637
|
assert(std::isfinite(ll));
|
625
638
|
}
|
626
639
|
}
|
@@ -972,12 +985,14 @@ namespace tomoto
|
|
972
985
|
|
973
986
|
void setOptimInterval(size_t _optimInterval) override
|
974
987
|
{
|
975
|
-
|
988
|
+
if (_optimInterval > 0x7FFFFFFF) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "wrong value");
|
989
|
+
optimInterval = (uint32_t)_optimInterval;
|
976
990
|
}
|
977
991
|
|
978
992
|
void setBurnInIteration(size_t iteration) override
|
979
993
|
{
|
980
|
-
|
994
|
+
if (iteration > 0x7FFFFFFF) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "wrong value");
|
995
|
+
burnIn = (uint32_t)iteration;
|
981
996
|
}
|
982
997
|
|
983
998
|
size_t addDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) override
|
@@ -1008,6 +1023,11 @@ namespace tomoto
|
|
1008
1023
|
if (p < 0) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "priors must not be less than 0.");
|
1009
1024
|
}
|
1010
1025
|
this->dict.add(word);
|
1026
|
+
if (this->dict.size() > this->vocabCf.size())
|
1027
|
+
{
|
1028
|
+
this->vocabCf.resize(this->dict.size());
|
1029
|
+
this->vocabDf.resize(this->dict.size());
|
1030
|
+
}
|
1011
1031
|
etaByWord.emplace(word, priors);
|
1012
1032
|
}
|
1013
1033
|
|
@@ -1049,7 +1069,7 @@ namespace tomoto
|
|
1049
1069
|
if (initDocs)
|
1050
1070
|
{
|
1051
1071
|
std::vector<uint32_t> df, cf, tf;
|
1052
|
-
|
1072
|
+
size_t totCf;
|
1053
1073
|
|
1054
1074
|
// calculate weighting
|
1055
1075
|
if (_tw != TermWeight::one)
|
@@ -1064,14 +1084,14 @@ namespace tomoto
|
|
1064
1084
|
++df[w];
|
1065
1085
|
}
|
1066
1086
|
}
|
1067
|
-
totCf = accumulate(this->vocabCf.begin(), this->vocabCf.end(), 0);
|
1087
|
+
totCf = std::accumulate(this->vocabCf.begin(), this->vocabCf.end(), 0);
|
1068
1088
|
}
|
1069
1089
|
if (_tw == TermWeight::idf)
|
1070
1090
|
{
|
1071
1091
|
vocabWeights.resize(V);
|
1072
1092
|
for (size_t i = 0; i < V; ++i)
|
1073
1093
|
{
|
1074
|
-
vocabWeights[i] = log(this->docs.size() / (
|
1094
|
+
vocabWeights[i] = (Float)log(this->docs.size() / (double)df[i]);
|
1075
1095
|
}
|
1076
1096
|
}
|
1077
1097
|
else if (_tw == TermWeight::pmi)
|
@@ -1079,7 +1099,7 @@ namespace tomoto
|
|
1079
1099
|
vocabWeights.resize(V);
|
1080
1100
|
for (size_t i = 0; i < V; ++i)
|
1081
1101
|
{
|
1082
|
-
vocabWeights[i] = this->vocabCf[i] / (
|
1102
|
+
vocabWeights[i] = (Float)(this->vocabCf[i] / (double)totCf);
|
1083
1103
|
}
|
1084
1104
|
}
|
1085
1105
|
|
@@ -1104,7 +1124,7 @@ namespace tomoto
|
|
1104
1124
|
return static_cast<const DerivedClass*>(this)->_getTopicsCount();
|
1105
1125
|
}
|
1106
1126
|
|
1107
|
-
std::vector<Float>
|
1127
|
+
std::vector<Float> _getTopicsByDoc(const _DocType& doc, bool normalize) const
|
1108
1128
|
{
|
1109
1129
|
std::vector<Float> ret(K);
|
1110
1130
|
Eigen::Map<Eigen::Array<Float, -1, 1>> m{ ret.data(), K };
|
@@ -26,7 +26,10 @@ namespace tomoto
|
|
26
26
|
friend typename BaseClass::BaseClass;
|
27
27
|
using WeightType = typename BaseClass::WeightType;
|
28
28
|
|
29
|
-
static constexpr
|
29
|
+
static constexpr auto tmid()
|
30
|
+
{
|
31
|
+
return serializer::to_key("LLDA");
|
32
|
+
}
|
30
33
|
|
31
34
|
Dictionary topicLabelDict;
|
32
35
|
|
@@ -171,7 +174,7 @@ namespace tomoto
|
|
171
174
|
return std::make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
|
172
175
|
}
|
173
176
|
|
174
|
-
std::vector<Float>
|
177
|
+
std::vector<Float> _getTopicsByDoc(const _DocType& doc, bool normalize) const
|
175
178
|
{
|
176
179
|
std::vector<Float> ret(this->K);
|
177
180
|
auto maskedAlphas = this->alphas.array() * doc.labelMask.template cast<Float>().array();
|
@@ -63,7 +63,7 @@ namespace tomoto
|
|
63
63
|
}
|
64
64
|
|
65
65
|
template<int _inc>
|
66
|
-
inline void addWordTo(_ModelState& ld, _DocType& doc,
|
66
|
+
inline void addWordTo(_ModelState& ld, _DocType& doc, size_t pid, Vid vid, Tid tid, uint16_t s, uint8_t w, uint8_t r) const
|
67
67
|
{
|
68
68
|
const auto K = this->K;
|
69
69
|
|
@@ -527,7 +527,7 @@ namespace tomoto
|
|
527
527
|
this->etaByWord.emplace(word, priors);
|
528
528
|
}
|
529
529
|
|
530
|
-
std::vector<Float>
|
530
|
+
std::vector<Float> _getTopicsByDoc(const _DocType& doc, bool normalize) const
|
531
531
|
{
|
532
532
|
std::vector<Float> ret(this->K + KL);
|
533
533
|
Eigen::Map<Eigen::Array<Float, -1, 1>> m{ ret.data(), this->K + KL };
|
@@ -90,7 +90,7 @@ namespace tomoto
|
|
90
90
|
}
|
91
91
|
|
92
92
|
template<int _inc>
|
93
|
-
inline void addWordTo(_ModelState& ld, _DocType& doc,
|
93
|
+
inline void addWordTo(_ModelState& ld, _DocType& doc, size_t pid, Vid vid, Tid z1, Tid z2) const
|
94
94
|
{
|
95
95
|
assert(vid < this->realV);
|
96
96
|
constexpr bool _dec = _inc < 0 && _tw != TermWeight::one;
|
@@ -26,7 +26,10 @@ namespace tomoto
|
|
26
26
|
friend typename BaseClass::BaseClass;
|
27
27
|
using WeightType = typename BaseClass::WeightType;
|
28
28
|
|
29
|
-
static constexpr
|
29
|
+
static constexpr auto tmid()
|
30
|
+
{
|
31
|
+
return serializer::to_key("PLDA");
|
32
|
+
}
|
30
33
|
|
31
34
|
Dictionary topicLabelDict;
|
32
35
|
|
@@ -178,7 +181,7 @@ namespace tomoto
|
|
178
181
|
return std::make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
|
179
182
|
}
|
180
183
|
|
181
|
-
std::vector<Float>
|
184
|
+
std::vector<Float> _getTopicsByDoc(const _DocType& doc, bool normalize) const
|
182
185
|
{
|
183
186
|
std::vector<Float> ret(this->K);
|
184
187
|
auto maskedAlphas = this->alphas.array() * doc.labelMask.template cast<Float>().array();
|
@@ -38,7 +38,10 @@ namespace tomoto
|
|
38
38
|
friend typename BaseClass::BaseClass;
|
39
39
|
using WeightType = typename BaseClass::WeightType;
|
40
40
|
|
41
|
-
static constexpr
|
41
|
+
static constexpr auto tmid()
|
42
|
+
{
|
43
|
+
return serializer::to_key("PTM");
|
44
|
+
}
|
42
45
|
|
43
46
|
uint64_t numPDocs;
|
44
47
|
Float lambda;
|
@@ -261,7 +264,7 @@ namespace tomoto
|
|
261
264
|
{
|
262
265
|
}
|
263
266
|
|
264
|
-
std::vector<Float>
|
267
|
+
std::vector<Float> _getTopicsByDoc(const _DocType& doc, bool normalize) const
|
265
268
|
{
|
266
269
|
std::vector<Float> ret(this->K);
|
267
270
|
Eigen::Map<Eigen::Array<Float, -1, 1>> m{ ret.data(), this->K };
|
@@ -216,7 +216,10 @@ namespace tomoto
|
|
216
216
|
friend typename BaseClass::BaseClass;
|
217
217
|
using WeightType = typename BaseClass::WeightType;
|
218
218
|
|
219
|
-
static constexpr
|
219
|
+
static constexpr auto tmid()
|
220
|
+
{
|
221
|
+
return serializer::to_key("SLDA");
|
222
|
+
}
|
220
223
|
|
221
224
|
uint64_t F; // number of response variables
|
222
225
|
std::vector<ISLDAModel::GLM> varTypes;
|
@@ -249,6 +249,7 @@ namespace tomoto
|
|
249
249
|
virtual size_t getNumDocs() const = 0;
|
250
250
|
virtual const Dictionary& getVocabDict() const = 0;
|
251
251
|
virtual const std::vector<uint64_t>& getVocabCf() const = 0;
|
252
|
+
virtual std::vector<double> getVocabWeightedCf() const = 0;
|
252
253
|
virtual const std::vector<uint64_t>& getVocabDf() const = 0;
|
253
254
|
|
254
255
|
virtual int train(size_t iteration, size_t numWorkers, ParallelScheme ps = ParallelScheme::default_, bool freeze_topics = false) = 0;
|
@@ -319,6 +320,7 @@ namespace tomoto
|
|
319
320
|
Dictionary dict;
|
320
321
|
uint64_t realV = 0; // vocab size after removing stopwords
|
321
322
|
uint64_t realN = 0; // total word size after removing stopwords
|
323
|
+
double weightedN = 0;
|
322
324
|
size_t maxThreads[(size_t)ParallelScheme::size] = { 0, };
|
323
325
|
size_t minWordCf = 0, minWordDf = 0, removeTopN = 0;
|
324
326
|
|
@@ -327,15 +329,17 @@ namespace tomoto
|
|
327
329
|
void _saveModel(std::ostream& writer, bool fullModel, const std::vector<uint8_t>* extra_data) const
|
328
330
|
{
|
329
331
|
serializer::writeMany(writer,
|
330
|
-
serializer::to_keyz(static_cast<const _Derived*>(this)->
|
331
|
-
serializer::to_keyz(static_cast<const _Derived*>(this)->
|
332
|
+
serializer::to_keyz(static_cast<const _Derived*>(this)->tmid()),
|
333
|
+
serializer::to_keyz(static_cast<const _Derived*>(this)->twid())
|
334
|
+
);
|
332
335
|
serializer::writeTaggedMany(writer, 0x00010001,
|
333
336
|
serializer::to_keyz("dict"), dict,
|
334
337
|
serializer::to_keyz("vocabCf"), vocabCf,
|
335
338
|
serializer::to_keyz("vocabDf"), vocabDf,
|
336
339
|
serializer::to_keyz("realV"), realV,
|
337
340
|
serializer::to_keyz("globalStep"), globalStep,
|
338
|
-
serializer::to_keyz("extra"), extra_data ? *extra_data : std::vector<uint8_t>(0)
|
341
|
+
serializer::to_keyz("extra"), extra_data ? *extra_data : std::vector<uint8_t>(0)
|
342
|
+
);
|
339
343
|
serializer::writeMany(writer, *static_cast<const _Derived*>(this));
|
340
344
|
globalState.serializerWrite(writer);
|
341
345
|
if (fullModel)
|
@@ -355,8 +359,9 @@ namespace tomoto
|
|
355
359
|
{
|
356
360
|
std::vector<uint8_t> extra;
|
357
361
|
serializer::readMany(reader,
|
358
|
-
serializer::to_keyz(static_cast<_Derived*>(this)->
|
359
|
-
serializer::to_keyz(static_cast<_Derived*>(this)->
|
362
|
+
serializer::to_keyz(static_cast<_Derived*>(this)->tmid()),
|
363
|
+
serializer::to_keyz(static_cast<_Derived*>(this)->twid())
|
364
|
+
);
|
360
365
|
serializer::readTaggedMany(reader, 0x00010001,
|
361
366
|
serializer::to_keyz("dict"), dict,
|
362
367
|
serializer::to_keyz("vocabCf"), vocabCf,
|
@@ -370,14 +375,17 @@ namespace tomoto
|
|
370
375
|
{
|
371
376
|
reader.seekg(start_pos);
|
372
377
|
serializer::readMany(reader,
|
373
|
-
serializer::to_key(static_cast<_Derived*>(this)->
|
374
|
-
serializer::to_key(static_cast<_Derived*>(this)->
|
375
|
-
dict, vocabCf, realV
|
378
|
+
serializer::to_key(static_cast<_Derived*>(this)->tmid()),
|
379
|
+
serializer::to_key(static_cast<_Derived*>(this)->twid()),
|
380
|
+
dict, vocabCf, realV
|
381
|
+
);
|
376
382
|
}
|
377
383
|
serializer::readMany(reader, *static_cast<_Derived*>(this));
|
378
384
|
globalState.serializerRead(reader);
|
379
385
|
serializer::readMany(reader, docs);
|
380
|
-
|
386
|
+
auto p = countRealN();
|
387
|
+
realN = p.first;
|
388
|
+
weightedN = p.second;
|
381
389
|
}
|
382
390
|
|
383
391
|
template<typename _DocTy>
|
@@ -490,17 +498,23 @@ namespace tomoto
|
|
490
498
|
}
|
491
499
|
}
|
492
500
|
|
493
|
-
size_t countRealN() const
|
501
|
+
std::pair<size_t, double> countRealN() const
|
494
502
|
{
|
495
503
|
size_t n = 0;
|
504
|
+
double weighted = 0;
|
496
505
|
for (auto& doc : docs)
|
497
506
|
{
|
498
|
-
for (
|
507
|
+
for (size_t i = 0; i < doc.words.size(); ++i)
|
499
508
|
{
|
500
|
-
|
509
|
+
auto w = doc.words[i];
|
510
|
+
if (w < realV)
|
511
|
+
{
|
512
|
+
++n;
|
513
|
+
weighted += doc.wordWeights.empty() ? 1 : doc.wordWeights[i];
|
514
|
+
}
|
501
515
|
}
|
502
516
|
}
|
503
|
-
return n;
|
517
|
+
return std::make_pair(n, weighted);
|
504
518
|
}
|
505
519
|
|
506
520
|
void removeStopwords(size_t minWordCnt, size_t minWordDf, size_t removeTopN)
|
@@ -544,14 +558,9 @@ namespace tomoto
|
|
544
558
|
}
|
545
559
|
|
546
560
|
dict.reorder(order);
|
547
|
-
realN = 0;
|
548
561
|
for (auto& doc : docs)
|
549
562
|
{
|
550
|
-
for (auto& w : doc.words)
|
551
|
-
{
|
552
|
-
w = order[w];
|
553
|
-
if (w < realV) ++realN;
|
554
|
-
}
|
563
|
+
for (auto& w : doc.words) w = order[w];
|
555
564
|
}
|
556
565
|
}
|
557
566
|
|
@@ -598,6 +607,10 @@ namespace tomoto
|
|
598
607
|
|
599
608
|
void prepare(bool initDocs = true, size_t minWordCnt = 0, size_t minWordDf = 0, size_t removeTopN = 0) override
|
600
609
|
{
|
610
|
+
auto p = countRealN();
|
611
|
+
realN = p.first;
|
612
|
+
weightedN = p.second;
|
613
|
+
|
601
614
|
maxThreads[(size_t)ParallelScheme::default_] = -1;
|
602
615
|
maxThreads[(size_t)ParallelScheme::none] = -1;
|
603
616
|
maxThreads[(size_t)ParallelScheme::copy_merge] = static_cast<_Derived*>(this)->template estimateMaxThreads<ParallelScheme::copy_merge>();
|
@@ -697,7 +710,7 @@ namespace tomoto
|
|
697
710
|
|
698
711
|
double getLLPerWord() const override
|
699
712
|
{
|
700
|
-
return words.empty() ? 0 : static_cast<const _Derived*>(this)->getLL() /
|
713
|
+
return words.empty() ? 0 : static_cast<const _Derived*>(this)->getLL() / weightedN;
|
701
714
|
}
|
702
715
|
|
703
716
|
double getPerplexity() const override
|
@@ -797,7 +810,7 @@ namespace tomoto
|
|
797
810
|
|
798
811
|
std::vector<Float> getTopicsByDoc(const DocumentBase* doc, bool normalize) const override
|
799
812
|
{
|
800
|
-
return static_cast<const _Derived*>(this)->
|
813
|
+
return static_cast<const _Derived*>(this)->_getTopicsByDoc(*static_cast<const DocType*>(doc), normalize);
|
801
814
|
}
|
802
815
|
|
803
816
|
std::vector<std::pair<Tid, Float>> getTopicsByDocSorted(const DocumentBase* doc, size_t topN) const override
|
@@ -832,6 +845,20 @@ namespace tomoto
|
|
832
845
|
return vocabCf;
|
833
846
|
}
|
834
847
|
|
848
|
+
std::vector<double> getVocabWeightedCf() const override
|
849
|
+
{
|
850
|
+
std::vector<double> ret(realV);
|
851
|
+
for (auto& doc : docs)
|
852
|
+
{
|
853
|
+
for (size_t i = 0; i < doc.words.size(); ++i)
|
854
|
+
{
|
855
|
+
if (doc.words[i] >= realV) continue;
|
856
|
+
ret[doc.words[i]] += doc.wordWeights.empty() ? 1 : doc.wordWeights[i];
|
857
|
+
}
|
858
|
+
}
|
859
|
+
return ret;
|
860
|
+
}
|
861
|
+
|
835
862
|
const std::vector<uint64_t>& getVocabDf() const override
|
836
863
|
{
|
837
864
|
return vocabDf;
|