tomoto 0.2.2 → 0.2.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/ext/tomoto/ct.cpp +11 -11
- data/ext/tomoto/dmr.cpp +14 -13
- data/ext/tomoto/dt.cpp +14 -14
- data/ext/tomoto/ext.cpp +7 -7
- data/ext/tomoto/extconf.rb +1 -3
- data/ext/tomoto/gdmr.cpp +7 -7
- data/ext/tomoto/hdp.cpp +9 -9
- data/ext/tomoto/hlda.cpp +13 -13
- data/ext/tomoto/hpa.cpp +5 -5
- data/ext/tomoto/lda.cpp +42 -39
- data/ext/tomoto/llda.cpp +6 -6
- data/ext/tomoto/mglda.cpp +15 -15
- data/ext/tomoto/pa.cpp +6 -6
- data/ext/tomoto/plda.cpp +6 -6
- data/ext/tomoto/slda.cpp +8 -8
- data/ext/tomoto/utils.h +16 -70
- data/lib/tomoto/version.rb +1 -1
- data/vendor/tomotopy/README.kr.rst +57 -0
- data/vendor/tomotopy/README.rst +55 -0
- data/vendor/tomotopy/src/Labeling/Phraser.hpp +3 -3
- data/vendor/tomotopy/src/TopicModel/CTModel.hpp +5 -2
- data/vendor/tomotopy/src/TopicModel/DMRModel.hpp +5 -2
- data/vendor/tomotopy/src/TopicModel/DTModel.hpp +5 -2
- data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +4 -4
- data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +5 -2
- data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +2 -2
- data/vendor/tomotopy/src/TopicModel/LDA.h +3 -3
- data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +3 -3
- data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +34 -14
- data/vendor/tomotopy/src/TopicModel/LLDAModel.hpp +5 -2
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +2 -2
- data/vendor/tomotopy/src/TopicModel/PAModel.hpp +1 -1
- data/vendor/tomotopy/src/TopicModel/PLDAModel.hpp +5 -2
- data/vendor/tomotopy/src/TopicModel/PTModel.hpp +5 -2
- data/vendor/tomotopy/src/TopicModel/SLDAModel.hpp +4 -1
- data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +48 -21
- data/vendor/tomotopy/src/Utils/AliasMethod.hpp +5 -4
- data/vendor/tomotopy/src/Utils/Dictionary.h +2 -2
- data/vendor/tomotopy/src/Utils/MultiNormalDistribution.hpp +1 -1
- data/vendor/tomotopy/src/Utils/TruncMultiNormal.hpp +1 -1
- data/vendor/tomotopy/src/Utils/math.h +2 -2
- data/vendor/tomotopy/src/Utils/serializer.hpp +30 -5
- metadata +6 -6
@@ -335,7 +335,10 @@ namespace tomoto
|
|
335
335
|
friend typename BaseClass::BaseClass;
|
336
336
|
using WeightType = typename BaseClass::WeightType;
|
337
337
|
|
338
|
-
static constexpr
|
338
|
+
static constexpr auto tmid()
|
339
|
+
{
|
340
|
+
return serializer::to_key("hLDA");
|
341
|
+
}
|
339
342
|
|
340
343
|
Float gamma;
|
341
344
|
|
@@ -422,7 +425,7 @@ namespace tomoto
|
|
422
425
|
}
|
423
426
|
|
424
427
|
template<int _inc>
|
425
|
-
inline void addWordTo(_ModelState& ld, _DocType& doc,
|
428
|
+
inline void addWordTo(_ModelState& ld, _DocType& doc, size_t pid, Vid vid, Tid level) const
|
426
429
|
{
|
427
430
|
assert(vid < this->realV);
|
428
431
|
constexpr bool _dec = _inc < 0 && _tw != TermWeight::one;
|
@@ -143,7 +143,7 @@ namespace tomoto
|
|
143
143
|
}
|
144
144
|
|
145
145
|
template<int _inc>
|
146
|
-
inline void addWordTo(_ModelState& ld, _DocType& doc,
|
146
|
+
inline void addWordTo(_ModelState& ld, _DocType& doc, size_t pid, Vid vid, Tid z1, Tid z2) const
|
147
147
|
{
|
148
148
|
assert(vid < this->realV);
|
149
149
|
constexpr bool _dec = _inc < 0 && _tw != TermWeight::one;
|
@@ -540,7 +540,7 @@ namespace tomoto
|
|
540
540
|
return ret;
|
541
541
|
}
|
542
542
|
|
543
|
-
std::vector<Float>
|
543
|
+
std::vector<Float> _getTopicsByDoc(const _DocType& doc, bool normalize) const
|
544
544
|
{
|
545
545
|
std::vector<Float> ret(1 + this->K + K2);
|
546
546
|
Float sum = doc.getSumWordWeight() + this->alphas.sum();
|
@@ -121,7 +121,7 @@ namespace tomoto
|
|
121
121
|
|
122
122
|
void updateSumWordWeight(size_t realV)
|
123
123
|
{
|
124
|
-
sumWordWeight = std::count_if(static_cast<_Base*>(this)->words.begin(), static_cast<_Base*>(this)->words.end(), [realV](Vid w)
|
124
|
+
sumWordWeight = (int32_t)std::count_if(static_cast<_Base*>(this)->words.begin(), static_cast<_Base*>(this)->words.end(), [realV](Vid w)
|
125
125
|
{
|
126
126
|
return w < realV;
|
127
127
|
});
|
@@ -164,8 +164,8 @@ namespace tomoto
|
|
164
164
|
struct LDAArgs
|
165
165
|
{
|
166
166
|
size_t k = 1;
|
167
|
-
std::vector<Float> alpha = { 0.1 };
|
168
|
-
Float eta = 0.01;
|
167
|
+
std::vector<Float> alpha = { (Float)0.1 };
|
168
|
+
Float eta = (Float)0.01;
|
169
169
|
size_t seed = std::random_device{}();
|
170
170
|
};
|
171
171
|
|
@@ -82,7 +82,7 @@ namespace tomoto
|
|
82
82
|
friend BaseClass;
|
83
83
|
|
84
84
|
static constexpr const char TWID[] = "one\0";
|
85
|
-
static constexpr
|
85
|
+
static constexpr const char TMID[] = "LDA\0";
|
86
86
|
|
87
87
|
Float alpha;
|
88
88
|
Vector alphas;
|
@@ -125,7 +125,7 @@ namespace tomoto
|
|
125
125
|
}
|
126
126
|
|
127
127
|
template<int _Inc, typename _Vec>
|
128
|
-
inline void addWordTo(_ModelState& ld, _DocType& doc,
|
128
|
+
inline void addWordTo(_ModelState& ld, _DocType& doc, size_t pid, Vid vid, _Vec tDist) const
|
129
129
|
{
|
130
130
|
assert(vid < this->realV);
|
131
131
|
constexpr bool _dec = _Inc < 0;
|
@@ -392,7 +392,7 @@ namespace tomoto
|
|
392
392
|
return static_cast<const DerivedClass*>(this)->_getTopicsCount();
|
393
393
|
}
|
394
394
|
|
395
|
-
std::vector<Float>
|
395
|
+
std::vector<Float> _getTopicsByDoc(const _DocType& doc) const
|
396
396
|
{
|
397
397
|
std::vector<Float> ret(K);
|
398
398
|
Float sum = doc.getSumWordWeight() + K * alpha;
|
@@ -117,19 +117,28 @@ namespace tomoto
|
|
117
117
|
template<>
|
118
118
|
struct TwId<TermWeight::one>
|
119
119
|
{
|
120
|
-
static constexpr
|
120
|
+
static constexpr auto twid()
|
121
|
+
{
|
122
|
+
return serializer::to_key("one\0");
|
123
|
+
}
|
121
124
|
};
|
122
125
|
|
123
126
|
template<>
|
124
127
|
struct TwId<TermWeight::idf>
|
125
128
|
{
|
126
|
-
static constexpr
|
129
|
+
static constexpr auto twid()
|
130
|
+
{
|
131
|
+
return serializer::to_key("idf\0");
|
132
|
+
}
|
127
133
|
};
|
128
134
|
|
129
135
|
template<>
|
130
136
|
struct TwId<TermWeight::pmi>
|
131
137
|
{
|
132
|
-
static constexpr
|
138
|
+
static constexpr auto twid()
|
139
|
+
{
|
140
|
+
return serializer::to_key("pmi\0");
|
141
|
+
}
|
133
142
|
};
|
134
143
|
|
135
144
|
// to make HDP friend of LDA for HDPModel::converToLDA
|
@@ -169,7 +178,11 @@ namespace tomoto
|
|
169
178
|
typename>
|
170
179
|
friend class HDPModel;
|
171
180
|
|
172
|
-
static constexpr
|
181
|
+
static constexpr auto tmid()
|
182
|
+
{
|
183
|
+
return serializer::to_key("LDA\0");
|
184
|
+
}
|
185
|
+
|
173
186
|
using WeightType = typename std::conditional<_tw == TermWeight::one, int32_t, float>::type;
|
174
187
|
|
175
188
|
enum { m_flags = _Flags };
|
@@ -189,7 +202,7 @@ namespace tomoto
|
|
189
202
|
struct ExtraDocData
|
190
203
|
{
|
191
204
|
std::vector<Vid> vChunkOffset;
|
192
|
-
Eigen::Matrix<
|
205
|
+
Eigen::Matrix<size_t, -1, -1> chunkOffsetByDoc;
|
193
206
|
};
|
194
207
|
|
195
208
|
ExtraDocData eddTrain;
|
@@ -261,7 +274,7 @@ namespace tomoto
|
|
261
274
|
}
|
262
275
|
|
263
276
|
template<int _inc>
|
264
|
-
inline void addWordTo(_ModelState& ld, _DocType& doc,
|
277
|
+
inline void addWordTo(_ModelState& ld, _DocType& doc, size_t pid, Vid vid, Tid tid) const
|
265
278
|
{
|
266
279
|
assert(tid < K);
|
267
280
|
assert(vid < this->realV);
|
@@ -620,7 +633,7 @@ namespace tomoto
|
|
620
633
|
for (Vid v = 0; v < V; ++v)
|
621
634
|
{
|
622
635
|
if (!ld.numByTopicWord(k, v)) continue;
|
623
|
-
ll += math::lgammaT(ld.numByTopicWord(k, v) + etaByTopicWord(
|
636
|
+
ll += math::lgammaT(ld.numByTopicWord(k, v) + etaByTopicWord(k, v)) - math::lgammaT(etaByTopicWord(k, v));
|
624
637
|
assert(std::isfinite(ll));
|
625
638
|
}
|
626
639
|
}
|
@@ -972,12 +985,14 @@ namespace tomoto
|
|
972
985
|
|
973
986
|
void setOptimInterval(size_t _optimInterval) override
|
974
987
|
{
|
975
|
-
|
988
|
+
if (_optimInterval > 0x7FFFFFFF) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "wrong value");
|
989
|
+
optimInterval = (uint32_t)_optimInterval;
|
976
990
|
}
|
977
991
|
|
978
992
|
void setBurnInIteration(size_t iteration) override
|
979
993
|
{
|
980
|
-
|
994
|
+
if (iteration > 0x7FFFFFFF) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "wrong value");
|
995
|
+
burnIn = (uint32_t)iteration;
|
981
996
|
}
|
982
997
|
|
983
998
|
size_t addDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) override
|
@@ -1008,6 +1023,11 @@ namespace tomoto
|
|
1008
1023
|
if (p < 0) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "priors must not be less than 0.");
|
1009
1024
|
}
|
1010
1025
|
this->dict.add(word);
|
1026
|
+
if (this->dict.size() > this->vocabCf.size())
|
1027
|
+
{
|
1028
|
+
this->vocabCf.resize(this->dict.size());
|
1029
|
+
this->vocabDf.resize(this->dict.size());
|
1030
|
+
}
|
1011
1031
|
etaByWord.emplace(word, priors);
|
1012
1032
|
}
|
1013
1033
|
|
@@ -1049,7 +1069,7 @@ namespace tomoto
|
|
1049
1069
|
if (initDocs)
|
1050
1070
|
{
|
1051
1071
|
std::vector<uint32_t> df, cf, tf;
|
1052
|
-
|
1072
|
+
size_t totCf;
|
1053
1073
|
|
1054
1074
|
// calculate weighting
|
1055
1075
|
if (_tw != TermWeight::one)
|
@@ -1064,14 +1084,14 @@ namespace tomoto
|
|
1064
1084
|
++df[w];
|
1065
1085
|
}
|
1066
1086
|
}
|
1067
|
-
totCf = accumulate(this->vocabCf.begin(), this->vocabCf.end(), 0);
|
1087
|
+
totCf = std::accumulate(this->vocabCf.begin(), this->vocabCf.end(), 0);
|
1068
1088
|
}
|
1069
1089
|
if (_tw == TermWeight::idf)
|
1070
1090
|
{
|
1071
1091
|
vocabWeights.resize(V);
|
1072
1092
|
for (size_t i = 0; i < V; ++i)
|
1073
1093
|
{
|
1074
|
-
vocabWeights[i] = log(this->docs.size() / (
|
1094
|
+
vocabWeights[i] = (Float)log(this->docs.size() / (double)df[i]);
|
1075
1095
|
}
|
1076
1096
|
}
|
1077
1097
|
else if (_tw == TermWeight::pmi)
|
@@ -1079,7 +1099,7 @@ namespace tomoto
|
|
1079
1099
|
vocabWeights.resize(V);
|
1080
1100
|
for (size_t i = 0; i < V; ++i)
|
1081
1101
|
{
|
1082
|
-
vocabWeights[i] = this->vocabCf[i] / (
|
1102
|
+
vocabWeights[i] = (Float)(this->vocabCf[i] / (double)totCf);
|
1083
1103
|
}
|
1084
1104
|
}
|
1085
1105
|
|
@@ -1104,7 +1124,7 @@ namespace tomoto
|
|
1104
1124
|
return static_cast<const DerivedClass*>(this)->_getTopicsCount();
|
1105
1125
|
}
|
1106
1126
|
|
1107
|
-
std::vector<Float>
|
1127
|
+
std::vector<Float> _getTopicsByDoc(const _DocType& doc, bool normalize) const
|
1108
1128
|
{
|
1109
1129
|
std::vector<Float> ret(K);
|
1110
1130
|
Eigen::Map<Eigen::Array<Float, -1, 1>> m{ ret.data(), K };
|
@@ -26,7 +26,10 @@ namespace tomoto
|
|
26
26
|
friend typename BaseClass::BaseClass;
|
27
27
|
using WeightType = typename BaseClass::WeightType;
|
28
28
|
|
29
|
-
static constexpr
|
29
|
+
static constexpr auto tmid()
|
30
|
+
{
|
31
|
+
return serializer::to_key("LLDA");
|
32
|
+
}
|
30
33
|
|
31
34
|
Dictionary topicLabelDict;
|
32
35
|
|
@@ -171,7 +174,7 @@ namespace tomoto
|
|
171
174
|
return std::make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
|
172
175
|
}
|
173
176
|
|
174
|
-
std::vector<Float>
|
177
|
+
std::vector<Float> _getTopicsByDoc(const _DocType& doc, bool normalize) const
|
175
178
|
{
|
176
179
|
std::vector<Float> ret(this->K);
|
177
180
|
auto maskedAlphas = this->alphas.array() * doc.labelMask.template cast<Float>().array();
|
@@ -63,7 +63,7 @@ namespace tomoto
|
|
63
63
|
}
|
64
64
|
|
65
65
|
template<int _inc>
|
66
|
-
inline void addWordTo(_ModelState& ld, _DocType& doc,
|
66
|
+
inline void addWordTo(_ModelState& ld, _DocType& doc, size_t pid, Vid vid, Tid tid, uint16_t s, uint8_t w, uint8_t r) const
|
67
67
|
{
|
68
68
|
const auto K = this->K;
|
69
69
|
|
@@ -527,7 +527,7 @@ namespace tomoto
|
|
527
527
|
this->etaByWord.emplace(word, priors);
|
528
528
|
}
|
529
529
|
|
530
|
-
std::vector<Float>
|
530
|
+
std::vector<Float> _getTopicsByDoc(const _DocType& doc, bool normalize) const
|
531
531
|
{
|
532
532
|
std::vector<Float> ret(this->K + KL);
|
533
533
|
Eigen::Map<Eigen::Array<Float, -1, 1>> m{ ret.data(), this->K + KL };
|
@@ -90,7 +90,7 @@ namespace tomoto
|
|
90
90
|
}
|
91
91
|
|
92
92
|
template<int _inc>
|
93
|
-
inline void addWordTo(_ModelState& ld, _DocType& doc,
|
93
|
+
inline void addWordTo(_ModelState& ld, _DocType& doc, size_t pid, Vid vid, Tid z1, Tid z2) const
|
94
94
|
{
|
95
95
|
assert(vid < this->realV);
|
96
96
|
constexpr bool _dec = _inc < 0 && _tw != TermWeight::one;
|
@@ -26,7 +26,10 @@ namespace tomoto
|
|
26
26
|
friend typename BaseClass::BaseClass;
|
27
27
|
using WeightType = typename BaseClass::WeightType;
|
28
28
|
|
29
|
-
static constexpr
|
29
|
+
static constexpr auto tmid()
|
30
|
+
{
|
31
|
+
return serializer::to_key("PLDA");
|
32
|
+
}
|
30
33
|
|
31
34
|
Dictionary topicLabelDict;
|
32
35
|
|
@@ -178,7 +181,7 @@ namespace tomoto
|
|
178
181
|
return std::make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
|
179
182
|
}
|
180
183
|
|
181
|
-
std::vector<Float>
|
184
|
+
std::vector<Float> _getTopicsByDoc(const _DocType& doc, bool normalize) const
|
182
185
|
{
|
183
186
|
std::vector<Float> ret(this->K);
|
184
187
|
auto maskedAlphas = this->alphas.array() * doc.labelMask.template cast<Float>().array();
|
@@ -38,7 +38,10 @@ namespace tomoto
|
|
38
38
|
friend typename BaseClass::BaseClass;
|
39
39
|
using WeightType = typename BaseClass::WeightType;
|
40
40
|
|
41
|
-
static constexpr
|
41
|
+
static constexpr auto tmid()
|
42
|
+
{
|
43
|
+
return serializer::to_key("PTM");
|
44
|
+
}
|
42
45
|
|
43
46
|
uint64_t numPDocs;
|
44
47
|
Float lambda;
|
@@ -261,7 +264,7 @@ namespace tomoto
|
|
261
264
|
{
|
262
265
|
}
|
263
266
|
|
264
|
-
std::vector<Float>
|
267
|
+
std::vector<Float> _getTopicsByDoc(const _DocType& doc, bool normalize) const
|
265
268
|
{
|
266
269
|
std::vector<Float> ret(this->K);
|
267
270
|
Eigen::Map<Eigen::Array<Float, -1, 1>> m{ ret.data(), this->K };
|
@@ -216,7 +216,10 @@ namespace tomoto
|
|
216
216
|
friend typename BaseClass::BaseClass;
|
217
217
|
using WeightType = typename BaseClass::WeightType;
|
218
218
|
|
219
|
-
static constexpr
|
219
|
+
static constexpr auto tmid()
|
220
|
+
{
|
221
|
+
return serializer::to_key("SLDA");
|
222
|
+
}
|
220
223
|
|
221
224
|
uint64_t F; // number of response variables
|
222
225
|
std::vector<ISLDAModel::GLM> varTypes;
|
@@ -249,6 +249,7 @@ namespace tomoto
|
|
249
249
|
virtual size_t getNumDocs() const = 0;
|
250
250
|
virtual const Dictionary& getVocabDict() const = 0;
|
251
251
|
virtual const std::vector<uint64_t>& getVocabCf() const = 0;
|
252
|
+
virtual std::vector<double> getVocabWeightedCf() const = 0;
|
252
253
|
virtual const std::vector<uint64_t>& getVocabDf() const = 0;
|
253
254
|
|
254
255
|
virtual int train(size_t iteration, size_t numWorkers, ParallelScheme ps = ParallelScheme::default_, bool freeze_topics = false) = 0;
|
@@ -319,6 +320,7 @@ namespace tomoto
|
|
319
320
|
Dictionary dict;
|
320
321
|
uint64_t realV = 0; // vocab size after removing stopwords
|
321
322
|
uint64_t realN = 0; // total word size after removing stopwords
|
323
|
+
double weightedN = 0;
|
322
324
|
size_t maxThreads[(size_t)ParallelScheme::size] = { 0, };
|
323
325
|
size_t minWordCf = 0, minWordDf = 0, removeTopN = 0;
|
324
326
|
|
@@ -327,15 +329,17 @@ namespace tomoto
|
|
327
329
|
void _saveModel(std::ostream& writer, bool fullModel, const std::vector<uint8_t>* extra_data) const
|
328
330
|
{
|
329
331
|
serializer::writeMany(writer,
|
330
|
-
serializer::to_keyz(static_cast<const _Derived*>(this)->
|
331
|
-
serializer::to_keyz(static_cast<const _Derived*>(this)->
|
332
|
+
serializer::to_keyz(static_cast<const _Derived*>(this)->tmid()),
|
333
|
+
serializer::to_keyz(static_cast<const _Derived*>(this)->twid())
|
334
|
+
);
|
332
335
|
serializer::writeTaggedMany(writer, 0x00010001,
|
333
336
|
serializer::to_keyz("dict"), dict,
|
334
337
|
serializer::to_keyz("vocabCf"), vocabCf,
|
335
338
|
serializer::to_keyz("vocabDf"), vocabDf,
|
336
339
|
serializer::to_keyz("realV"), realV,
|
337
340
|
serializer::to_keyz("globalStep"), globalStep,
|
338
|
-
serializer::to_keyz("extra"), extra_data ? *extra_data : std::vector<uint8_t>(0)
|
341
|
+
serializer::to_keyz("extra"), extra_data ? *extra_data : std::vector<uint8_t>(0)
|
342
|
+
);
|
339
343
|
serializer::writeMany(writer, *static_cast<const _Derived*>(this));
|
340
344
|
globalState.serializerWrite(writer);
|
341
345
|
if (fullModel)
|
@@ -355,8 +359,9 @@ namespace tomoto
|
|
355
359
|
{
|
356
360
|
std::vector<uint8_t> extra;
|
357
361
|
serializer::readMany(reader,
|
358
|
-
serializer::to_keyz(static_cast<_Derived*>(this)->
|
359
|
-
serializer::to_keyz(static_cast<_Derived*>(this)->
|
362
|
+
serializer::to_keyz(static_cast<_Derived*>(this)->tmid()),
|
363
|
+
serializer::to_keyz(static_cast<_Derived*>(this)->twid())
|
364
|
+
);
|
360
365
|
serializer::readTaggedMany(reader, 0x00010001,
|
361
366
|
serializer::to_keyz("dict"), dict,
|
362
367
|
serializer::to_keyz("vocabCf"), vocabCf,
|
@@ -370,14 +375,17 @@ namespace tomoto
|
|
370
375
|
{
|
371
376
|
reader.seekg(start_pos);
|
372
377
|
serializer::readMany(reader,
|
373
|
-
serializer::to_key(static_cast<_Derived*>(this)->
|
374
|
-
serializer::to_key(static_cast<_Derived*>(this)->
|
375
|
-
dict, vocabCf, realV
|
378
|
+
serializer::to_key(static_cast<_Derived*>(this)->tmid()),
|
379
|
+
serializer::to_key(static_cast<_Derived*>(this)->twid()),
|
380
|
+
dict, vocabCf, realV
|
381
|
+
);
|
376
382
|
}
|
377
383
|
serializer::readMany(reader, *static_cast<_Derived*>(this));
|
378
384
|
globalState.serializerRead(reader);
|
379
385
|
serializer::readMany(reader, docs);
|
380
|
-
|
386
|
+
auto p = countRealN();
|
387
|
+
realN = p.first;
|
388
|
+
weightedN = p.second;
|
381
389
|
}
|
382
390
|
|
383
391
|
template<typename _DocTy>
|
@@ -490,17 +498,23 @@ namespace tomoto
|
|
490
498
|
}
|
491
499
|
}
|
492
500
|
|
493
|
-
size_t countRealN() const
|
501
|
+
std::pair<size_t, double> countRealN() const
|
494
502
|
{
|
495
503
|
size_t n = 0;
|
504
|
+
double weighted = 0;
|
496
505
|
for (auto& doc : docs)
|
497
506
|
{
|
498
|
-
for (
|
507
|
+
for (size_t i = 0; i < doc.words.size(); ++i)
|
499
508
|
{
|
500
|
-
|
509
|
+
auto w = doc.words[i];
|
510
|
+
if (w < realV)
|
511
|
+
{
|
512
|
+
++n;
|
513
|
+
weighted += doc.wordWeights.empty() ? 1 : doc.wordWeights[i];
|
514
|
+
}
|
501
515
|
}
|
502
516
|
}
|
503
|
-
return n;
|
517
|
+
return std::make_pair(n, weighted);
|
504
518
|
}
|
505
519
|
|
506
520
|
void removeStopwords(size_t minWordCnt, size_t minWordDf, size_t removeTopN)
|
@@ -544,14 +558,9 @@ namespace tomoto
|
|
544
558
|
}
|
545
559
|
|
546
560
|
dict.reorder(order);
|
547
|
-
realN = 0;
|
548
561
|
for (auto& doc : docs)
|
549
562
|
{
|
550
|
-
for (auto& w : doc.words)
|
551
|
-
{
|
552
|
-
w = order[w];
|
553
|
-
if (w < realV) ++realN;
|
554
|
-
}
|
563
|
+
for (auto& w : doc.words) w = order[w];
|
555
564
|
}
|
556
565
|
}
|
557
566
|
|
@@ -598,6 +607,10 @@ namespace tomoto
|
|
598
607
|
|
599
608
|
void prepare(bool initDocs = true, size_t minWordCnt = 0, size_t minWordDf = 0, size_t removeTopN = 0) override
|
600
609
|
{
|
610
|
+
auto p = countRealN();
|
611
|
+
realN = p.first;
|
612
|
+
weightedN = p.second;
|
613
|
+
|
601
614
|
maxThreads[(size_t)ParallelScheme::default_] = -1;
|
602
615
|
maxThreads[(size_t)ParallelScheme::none] = -1;
|
603
616
|
maxThreads[(size_t)ParallelScheme::copy_merge] = static_cast<_Derived*>(this)->template estimateMaxThreads<ParallelScheme::copy_merge>();
|
@@ -697,7 +710,7 @@ namespace tomoto
|
|
697
710
|
|
698
711
|
double getLLPerWord() const override
|
699
712
|
{
|
700
|
-
return words.empty() ? 0 : static_cast<const _Derived*>(this)->getLL() /
|
713
|
+
return words.empty() ? 0 : static_cast<const _Derived*>(this)->getLL() / weightedN;
|
701
714
|
}
|
702
715
|
|
703
716
|
double getPerplexity() const override
|
@@ -797,7 +810,7 @@ namespace tomoto
|
|
797
810
|
|
798
811
|
std::vector<Float> getTopicsByDoc(const DocumentBase* doc, bool normalize) const override
|
799
812
|
{
|
800
|
-
return static_cast<const _Derived*>(this)->
|
813
|
+
return static_cast<const _Derived*>(this)->_getTopicsByDoc(*static_cast<const DocType*>(doc), normalize);
|
801
814
|
}
|
802
815
|
|
803
816
|
std::vector<std::pair<Tid, Float>> getTopicsByDocSorted(const DocumentBase* doc, size_t topN) const override
|
@@ -832,6 +845,20 @@ namespace tomoto
|
|
832
845
|
return vocabCf;
|
833
846
|
}
|
834
847
|
|
848
|
+
std::vector<double> getVocabWeightedCf() const override
|
849
|
+
{
|
850
|
+
std::vector<double> ret(realV);
|
851
|
+
for (auto& doc : docs)
|
852
|
+
{
|
853
|
+
for (size_t i = 0; i < doc.words.size(); ++i)
|
854
|
+
{
|
855
|
+
if (doc.words[i] >= realV) continue;
|
856
|
+
ret[doc.words[i]] += doc.wordWeights.empty() ? 1 : doc.wordWeights[i];
|
857
|
+
}
|
858
|
+
}
|
859
|
+
return ret;
|
860
|
+
}
|
861
|
+
|
835
862
|
const std::vector<uint64_t>& getVocabDf() const override
|
836
863
|
{
|
837
864
|
return vocabDf;
|