tomoto 0.1.2 → 0.1.3
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/LICENSE.txt +1 -1
- data/README.md +3 -3
- data/ext/tomoto/ext.cpp +34 -9
- data/ext/tomoto/extconf.rb +2 -1
- data/lib/tomoto/dmr.rb +1 -1
- data/lib/tomoto/gdmr.rb +1 -1
- data/lib/tomoto/version.rb +1 -1
- data/vendor/tomotopy/LICENSE +1 -1
- data/vendor/tomotopy/README.kr.rst +32 -3
- data/vendor/tomotopy/README.rst +30 -1
- data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +133 -147
- data/vendor/tomotopy/src/Labeling/FoRelevance.h +158 -5
- data/vendor/tomotopy/src/TopicModel/DMR.h +1 -16
- data/vendor/tomotopy/src/TopicModel/DMRModel.hpp +15 -34
- data/vendor/tomotopy/src/TopicModel/DT.h +1 -16
- data/vendor/tomotopy/src/TopicModel/DTModel.hpp +15 -32
- data/vendor/tomotopy/src/TopicModel/GDMRModel.hpp +18 -37
- data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +16 -20
- data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +3 -3
- data/vendor/tomotopy/src/TopicModel/LDA.h +0 -11
- data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +9 -21
- data/vendor/tomotopy/src/TopicModel/LLDA.h +0 -15
- data/vendor/tomotopy/src/TopicModel/LLDAModel.hpp +12 -30
- data/vendor/tomotopy/src/TopicModel/MGLDA.h +0 -15
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +59 -72
- data/vendor/tomotopy/src/TopicModel/PLDAModel.hpp +12 -30
- data/vendor/tomotopy/src/TopicModel/SLDA.h +0 -15
- data/vendor/tomotopy/src/TopicModel/SLDAModel.hpp +17 -35
- data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +158 -38
- data/vendor/tomotopy/src/Utils/Dictionary.h +40 -2
- data/vendor/tomotopy/src/Utils/EigenAddonOps.hpp +122 -3
- data/vendor/tomotopy/src/Utils/SharedString.hpp +181 -0
- data/vendor/tomotopy/src/Utils/math.h +1 -1
- data/vendor/tomotopy/src/Utils/sample.hpp +1 -1
- data/vendor/tomotopy/src/Utils/serializer.hpp +17 -0
- data/vendor/variant/LICENSE +25 -0
- data/vendor/variant/LICENSE_1_0.txt +23 -0
- data/vendor/variant/README.md +102 -0
- data/vendor/variant/include/mapbox/optional.hpp +74 -0
- data/vendor/variant/include/mapbox/recursive_wrapper.hpp +122 -0
- data/vendor/variant/include/mapbox/variant.hpp +974 -0
- data/vendor/variant/include/mapbox/variant_io.hpp +45 -0
- metadata +15 -7
@@ -153,46 +153,28 @@ namespace tomoto
|
|
153
153
|
return doc;
|
154
154
|
}
|
155
155
|
|
156
|
-
size_t addDoc(const
|
156
|
+
size_t addDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) override
|
157
157
|
{
|
158
|
-
auto doc = this->
|
159
|
-
return this->_addDoc(_updateDoc(doc, labels));
|
158
|
+
auto doc = this->template _makeFromRawDoc<false>(rawDoc, tokenizer);
|
159
|
+
return this->_addDoc(_updateDoc(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
|
160
160
|
}
|
161
161
|
|
162
|
-
std::unique_ptr<DocumentBase> makeDoc(const
|
162
|
+
std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) const override
|
163
163
|
{
|
164
|
-
auto doc = as_mutable(this)->template
|
165
|
-
return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, labels));
|
164
|
+
auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc, tokenizer);
|
165
|
+
return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
|
166
166
|
}
|
167
167
|
|
168
|
-
size_t addDoc(const
|
169
|
-
const std::vector<std::string>& labels) override
|
168
|
+
size_t addDoc(const RawDoc& rawDoc) override
|
170
169
|
{
|
171
|
-
auto doc = this->
|
172
|
-
return this->_addDoc(_updateDoc(doc, labels));
|
170
|
+
auto doc = this->_makeFromRawDoc(rawDoc);
|
171
|
+
return this->_addDoc(_updateDoc(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
|
173
172
|
}
|
174
173
|
|
175
|
-
std::unique_ptr<DocumentBase> makeDoc(const
|
176
|
-
const std::vector<std::string>& labels) const override
|
174
|
+
std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc) const override
|
177
175
|
{
|
178
|
-
auto doc = as_mutable(this)->template
|
179
|
-
return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, labels));
|
180
|
-
}
|
181
|
-
|
182
|
-
size_t addDoc(const std::string& rawStr, const std::vector<Vid>& words,
|
183
|
-
const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
|
184
|
-
const std::vector<std::string>& labels) override
|
185
|
-
{
|
186
|
-
auto doc = this->_makeRawDoc(rawStr, words, pos, len);
|
187
|
-
return this->_addDoc(_updateDoc(doc, labels));
|
188
|
-
}
|
189
|
-
|
190
|
-
std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const std::vector<Vid>& words,
|
191
|
-
const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
|
192
|
-
const std::vector<std::string>& labels) const override
|
193
|
-
{
|
194
|
-
auto doc = this->_makeRawDoc(rawStr, words, pos, len);
|
195
|
-
return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, labels));
|
176
|
+
auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc);
|
177
|
+
return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
|
196
178
|
}
|
197
179
|
|
198
180
|
std::vector<Float> getTopicsByDoc(const _DocType& doc) const
|
@@ -31,21 +31,6 @@ namespace tomoto
|
|
31
31
|
size_t seed = std::random_device{}(),
|
32
32
|
bool scalarRng = false);
|
33
33
|
|
34
|
-
virtual size_t addDoc(const std::vector<std::string>& words, const std::vector<Float>& y) = 0;
|
35
|
-
virtual std::unique_ptr<DocumentBase> makeDoc(const std::vector<std::string>& words, const std::vector<Float>& y) const = 0;
|
36
|
-
|
37
|
-
virtual size_t addDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
|
38
|
-
const std::vector<Float>& y) = 0;
|
39
|
-
virtual std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
|
40
|
-
const std::vector<Float>& y) const = 0;
|
41
|
-
|
42
|
-
virtual size_t addDoc(const std::string& rawStr, const std::vector<Vid>& words,
|
43
|
-
const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
|
44
|
-
const std::vector<Float>& y) = 0;
|
45
|
-
virtual std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const std::vector<Vid>& words,
|
46
|
-
const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
|
47
|
-
const std::vector<Float>& y) const = 0;
|
48
|
-
|
49
34
|
virtual size_t getF() const = 0;
|
50
35
|
virtual std::vector<Float> getRegressionCoef(size_t f) const = 0;
|
51
36
|
virtual GLM getTypeOfVar(size_t f) const = 0;
|
@@ -148,7 +148,7 @@ namespace tomoto
|
|
148
148
|
+ Eigen::Matrix<Float, -1, 1>::Constant(selectedNormZ.rows(), mu / nuSq));
|
149
149
|
|
150
150
|
RandGen rng;
|
151
|
-
for (size_t i = 0; i < omega.size(); ++i)
|
151
|
+
for (size_t i = 0; i < (size_t)omega.size(); ++i)
|
152
152
|
{
|
153
153
|
if (std::isnan(ys[i])) continue;
|
154
154
|
omega[i] = math::drawPolyaGamma(b, (this->regressionCoef.array() * normZ.col(i).array()).sum(), rng);
|
@@ -358,8 +358,8 @@ namespace tomoto
|
|
358
358
|
if (_const)
|
359
359
|
{
|
360
360
|
if (y.size() > F) throw std::runtime_error{ text::format(
|
361
|
-
"size of
|
362
|
-
"size of
|
361
|
+
"size of `y` is greater than the number of vars.\n"
|
362
|
+
"size of `y` : %zd, number of vars: %zd", y.size(), F) };
|
363
363
|
doc.y = y;
|
364
364
|
while (doc.y.size() < F)
|
365
365
|
{
|
@@ -369,53 +369,35 @@ namespace tomoto
|
|
369
369
|
else
|
370
370
|
{
|
371
371
|
if (y.size() != F) throw std::runtime_error{ text::format(
|
372
|
-
"size of
|
373
|
-
"size of
|
372
|
+
"size of `y` must be equal to the number of vars.\n"
|
373
|
+
"size of `y` : %zd, number of vars: %zd", y.size(), F) };
|
374
374
|
doc.y = y;
|
375
375
|
}
|
376
376
|
return doc;
|
377
377
|
}
|
378
378
|
|
379
|
-
size_t addDoc(const
|
379
|
+
size_t addDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) override
|
380
380
|
{
|
381
|
-
auto doc = this->
|
382
|
-
return this->_addDoc(_updateDoc(doc, y));
|
381
|
+
auto doc = this->template _makeFromRawDoc<false>(rawDoc, tokenizer);
|
382
|
+
return this->_addDoc(_updateDoc(doc, rawDoc.template getMiscDefault<std::vector<Float>>("y")));
|
383
383
|
}
|
384
384
|
|
385
|
-
std::unique_ptr<DocumentBase> makeDoc(const
|
385
|
+
std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) const override
|
386
386
|
{
|
387
|
-
auto doc = as_mutable(this)->template
|
388
|
-
return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, y));
|
387
|
+
auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc, tokenizer);
|
388
|
+
return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<Float>>("y")));
|
389
389
|
}
|
390
390
|
|
391
|
-
size_t addDoc(const
|
392
|
-
const std::vector<Float>& y) override
|
391
|
+
size_t addDoc(const RawDoc& rawDoc) override
|
393
392
|
{
|
394
|
-
auto doc = this->
|
395
|
-
return this->_addDoc(_updateDoc(doc, y));
|
393
|
+
auto doc = this->_makeFromRawDoc(rawDoc);
|
394
|
+
return this->_addDoc(_updateDoc(doc, rawDoc.template getMiscDefault<std::vector<Float>>("y")));
|
396
395
|
}
|
397
396
|
|
398
|
-
std::unique_ptr<DocumentBase> makeDoc(const
|
399
|
-
const std::vector<Float>& y) const override
|
397
|
+
std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc) const override
|
400
398
|
{
|
401
|
-
auto doc = as_mutable(this)->template
|
402
|
-
return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, y));
|
403
|
-
}
|
404
|
-
|
405
|
-
size_t addDoc(const std::string& rawStr, const std::vector<Vid>& words,
|
406
|
-
const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
|
407
|
-
const std::vector<Float>& y) override
|
408
|
-
{
|
409
|
-
auto doc = this->_makeRawDoc(rawStr, words, pos, len);
|
410
|
-
return this->_addDoc(_updateDoc(doc, y));
|
411
|
-
}
|
412
|
-
|
413
|
-
std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const std::vector<Vid>& words,
|
414
|
-
const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
|
415
|
-
const std::vector<Float>& y) const override
|
416
|
-
{
|
417
|
-
auto doc = this->_makeRawDoc(rawStr, words, pos, len);
|
418
|
-
return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, y));
|
399
|
+
auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc);
|
400
|
+
return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<Float>>("y")));
|
419
401
|
}
|
420
402
|
|
421
403
|
std::vector<Float> estimateVars(const DocumentBase* doc) const override
|
@@ -7,27 +7,112 @@
|
|
7
7
|
#include "../Utils/ThreadPool.hpp"
|
8
8
|
#include "../Utils/serializer.hpp"
|
9
9
|
#include "../Utils/exception.h"
|
10
|
+
#include "../Utils/SharedString.hpp"
|
10
11
|
#include <EigenRand/EigenRand>
|
12
|
+
#include <mapbox/variant.hpp>
|
11
13
|
|
12
14
|
namespace tomoto
|
13
15
|
{
|
14
16
|
using RandGen = Eigen::Rand::P8_mt19937_64<uint32_t>;
|
15
17
|
using ScalarRandGen = Eigen::Rand::UniversalRandomEngine<uint32_t, std::mt19937_64>;
|
16
18
|
|
17
|
-
|
19
|
+
struct RawDocKernel
|
18
20
|
{
|
19
|
-
public:
|
20
21
|
Float weight = 1;
|
22
|
+
SharedString docUid;
|
23
|
+
SharedString rawStr;
|
24
|
+
std::vector<uint32_t> origWordPos;
|
25
|
+
std::vector<uint16_t> origWordLen;
|
26
|
+
|
27
|
+
RawDocKernel(const RawDocKernel&) = default;
|
28
|
+
RawDocKernel(RawDocKernel&&) = default;
|
29
|
+
|
30
|
+
RawDocKernel(Float _weight = 1, const SharedString& _docUid = {})
|
31
|
+
: weight{ _weight }, docUid{ _docUid }
|
32
|
+
{
|
33
|
+
}
|
34
|
+
};
|
35
|
+
|
36
|
+
struct RawDoc : public RawDocKernel
|
37
|
+
{
|
38
|
+
using Var = mapbox::util::variant<
|
39
|
+
std::string, uint32_t, Float,
|
40
|
+
std::vector<std::string>, std::vector<uint32_t>, std::vector<Float>,
|
41
|
+
std::shared_ptr<void>
|
42
|
+
>;
|
43
|
+
using MiscType = std::unordered_map<std::string, Var>;
|
44
|
+
|
45
|
+
std::vector<Vid> words;
|
46
|
+
std::vector<std::string> rawWords;
|
47
|
+
MiscType misc;
|
48
|
+
|
49
|
+
RawDoc() = default;
|
50
|
+
RawDoc(const RawDoc&) = default;
|
51
|
+
RawDoc(RawDoc&&) = default;
|
52
|
+
|
53
|
+
RawDoc(const RawDocKernel& o)
|
54
|
+
: RawDocKernel{ o }
|
55
|
+
{
|
56
|
+
}
|
57
|
+
|
58
|
+
template<typename _Ty>
|
59
|
+
const _Ty& getMisc(const std::string& name) const
|
60
|
+
{
|
61
|
+
auto it = misc.find(name);
|
62
|
+
if (it == misc.end()) throw std::invalid_argument{ "There is no value named `" + name + "` in misc data" };
|
63
|
+
if (!it->second.template is<_Ty>()) throw std::invalid_argument{ "Value named `" + name + "` is not in right type." };
|
64
|
+
return it->second.template get<_Ty>();
|
65
|
+
}
|
66
|
+
|
67
|
+
template<typename _Ty>
|
68
|
+
_Ty getMiscDefault(const std::string& name) const
|
69
|
+
{
|
70
|
+
auto it = misc.find(name);
|
71
|
+
if (it == misc.end()) return {};
|
72
|
+
if (!it->second.template is<_Ty>()) throw std::invalid_argument{ "Value named `" + name + "` is not in right type." };
|
73
|
+
return it->second.template get<_Ty>();
|
74
|
+
}
|
75
|
+
};
|
76
|
+
|
77
|
+
class DocumentBase : public RawDocKernel
|
78
|
+
{
|
79
|
+
public:
|
21
80
|
tvector<Vid> words; // word id of each word
|
22
81
|
std::vector<uint32_t> wOrder; // original word order (optional)
|
82
|
+
|
83
|
+
DocumentBase(const DocumentBase&) = default;
|
84
|
+
DocumentBase(DocumentBase&&) = default;
|
85
|
+
|
86
|
+
DocumentBase(const RawDocKernel& o)
|
87
|
+
: RawDocKernel{ o }
|
88
|
+
{
|
89
|
+
}
|
90
|
+
|
91
|
+
DocumentBase(Float _weight = 1, const SharedString& _docUid = {})
|
92
|
+
: RawDocKernel{ _weight, _docUid }
|
93
|
+
{
|
94
|
+
}
|
23
95
|
|
24
|
-
std::string docUid;
|
25
|
-
std::string rawStr;
|
26
|
-
std::vector<uint32_t> origWordPos;
|
27
|
-
std::vector<uint16_t> origWordLen;
|
28
|
-
DocumentBase(Float _weight = 1) : weight(_weight) {}
|
29
96
|
virtual ~DocumentBase() {}
|
30
97
|
|
98
|
+
virtual operator RawDoc() const
|
99
|
+
{
|
100
|
+
RawDoc raw{ *this };
|
101
|
+
if (wOrder.empty())
|
102
|
+
{
|
103
|
+
raw.words.insert(raw.words.begin(), words.begin(), words.end());
|
104
|
+
}
|
105
|
+
else
|
106
|
+
{
|
107
|
+
raw.words.resize(words.size());
|
108
|
+
for (size_t i = 0; i < words.size(); ++i)
|
109
|
+
{
|
110
|
+
raw.words[i] = words[wOrder[i]];
|
111
|
+
}
|
112
|
+
}
|
113
|
+
return raw;
|
114
|
+
}
|
115
|
+
|
31
116
|
DEFINE_SERIALIZER_WITH_VERSION(0, serializer::to_key("Docu"), weight, words, wOrder);
|
32
117
|
DEFINE_TAGGED_SERIALIZER_WITH_VERSION(1, 0x00010001, weight, words, wOrder,
|
33
118
|
rawStr, origWordPos, origWordLen,
|
@@ -127,8 +212,20 @@ namespace tomoto
|
|
127
212
|
virtual void loadModel(std::istream& reader,
|
128
213
|
std::vector<uint8_t>* extra_data = nullptr) = 0;
|
129
214
|
virtual const DocumentBase* getDoc(size_t docId) const = 0;
|
215
|
+
virtual size_t getDocIdByUid(const std::string& docUid) const = 0;
|
216
|
+
|
217
|
+
// it tokenizes rawDoc.rawStr to get words, pos and len of the document
|
218
|
+
virtual size_t addDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) = 0;
|
219
|
+
virtual std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) const = 0;
|
220
|
+
|
221
|
+
// it uses words, pos and len of rawDoc itself.
|
222
|
+
virtual size_t addDoc(const RawDoc& rawDoc) = 0;
|
223
|
+
virtual std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc) const = 0;
|
224
|
+
|
225
|
+
virtual bool updateVocab(const std::vector<std::string>& words) = 0;
|
130
226
|
|
131
|
-
virtual
|
227
|
+
virtual double getDocLL(const DocumentBase* doc) const = 0;
|
228
|
+
virtual double getStateLL() const = 0;
|
132
229
|
|
133
230
|
virtual double getLLPerWord() const = 0;
|
134
231
|
virtual double getPerplexity() const = 0;
|
@@ -201,6 +298,7 @@ namespace tomoto
|
|
201
298
|
std::vector<DocType> docs;
|
202
299
|
std::vector<uint64_t> vocabCf;
|
203
300
|
std::vector<uint64_t> vocabDf;
|
301
|
+
std::unordered_map<SharedString, size_t> uidMap;
|
204
302
|
size_t globalStep = 0;
|
205
303
|
_ModelState globalState, tState;
|
206
304
|
Dictionary dict;
|
@@ -273,6 +371,8 @@ namespace tomoto
|
|
273
371
|
>::value, size_t>::type _addDoc(_DocTy&& doc)
|
274
372
|
{
|
275
373
|
if (doc.words.empty()) return -1;
|
374
|
+
if (!doc.docUid.empty() && uidMap.count(doc.docUid))
|
375
|
+
throw exception::InvalidArgument{ "there is a document with uid = " + std::string{ doc.docUid } + " already." };
|
276
376
|
size_t maxWid = *std::max_element(doc.words.begin(), doc.words.end());
|
277
377
|
if (vocabCf.size() <= maxWid)
|
278
378
|
{
|
@@ -282,47 +382,48 @@ namespace tomoto
|
|
282
382
|
for (auto w : doc.words) ++vocabCf[w];
|
283
383
|
std::unordered_set<Vid> uniq{ doc.words.begin(), doc.words.end() };
|
284
384
|
for (auto w : uniq) ++vocabDf[w];
|
385
|
+
uidMap.emplace(doc.docUid, docs.size());
|
285
386
|
docs.emplace_back(std::forward<_DocTy>(doc));
|
286
387
|
return docs.size() - 1;
|
287
388
|
}
|
288
389
|
|
289
390
|
template<bool _const = false>
|
290
|
-
DocType
|
391
|
+
DocType _makeFromRawDoc(const RawDoc& rawDoc)
|
291
392
|
{
|
292
|
-
DocType doc{
|
293
|
-
|
393
|
+
DocType doc{ rawDoc };
|
394
|
+
if (!rawDoc.rawWords.empty())
|
294
395
|
{
|
295
|
-
|
296
|
-
if (_const)
|
396
|
+
for (auto& w : rawDoc.rawWords)
|
297
397
|
{
|
298
|
-
id
|
299
|
-
if (
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
398
|
+
Vid id;
|
399
|
+
if (_const)
|
400
|
+
{
|
401
|
+
id = dict.toWid(w);
|
402
|
+
if (id == (Vid)-1) continue;
|
403
|
+
}
|
404
|
+
else
|
405
|
+
{
|
406
|
+
id = dict.add(w);
|
407
|
+
}
|
408
|
+
doc.words.emplace_back(id);
|
304
409
|
}
|
305
|
-
doc.words.emplace_back(id);
|
306
410
|
}
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
for (auto& w : words) doc.words.emplace_back(w);
|
316
|
-
doc.origWordPos = pos;
|
317
|
-
doc.origWordLen = len;
|
411
|
+
else if(!rawDoc.words.empty())
|
412
|
+
{
|
413
|
+
for (auto& w : rawDoc.words) doc.words.emplace_back(w);
|
414
|
+
}
|
415
|
+
else
|
416
|
+
{
|
417
|
+
throw std::invalid_argument{ "Either `words` or `rawWords` must be filled." };
|
418
|
+
}
|
318
419
|
return doc;
|
319
420
|
}
|
320
421
|
|
321
422
|
template<bool _const, typename _FnTokenizer>
|
322
|
-
DocType
|
423
|
+
DocType _makeFromRawDoc(const RawDoc& rawDoc, _FnTokenizer&& tokenizer)
|
323
424
|
{
|
324
|
-
DocType doc{
|
325
|
-
doc.rawStr = rawStr;
|
425
|
+
DocType doc{ rawDoc };
|
426
|
+
doc.rawStr = rawDoc.rawStr;
|
326
427
|
for (auto& p : tokenizer(doc.rawStr))
|
327
428
|
{
|
328
429
|
Vid wid;
|
@@ -452,10 +553,11 @@ namespace tomoto
|
|
452
553
|
return realV;
|
453
554
|
}
|
454
555
|
|
455
|
-
|
556
|
+
bool updateVocab(const std::vector<std::string>& words) override
|
456
557
|
{
|
457
|
-
|
458
|
-
for(auto& w : words) dict.add(w);
|
558
|
+
bool empty = dict.size() == 0;
|
559
|
+
for (auto& w : words) dict.add(w);
|
560
|
+
return empty;
|
459
561
|
}
|
460
562
|
|
461
563
|
void prepare(bool initDocs = true, size_t minWordCnt = 0, size_t minWordDf = 0, size_t removeTopN = 0) override
|
@@ -606,6 +708,18 @@ namespace tomoto
|
|
606
708
|
return vid2String(getWidsByDocSorted(doc, topN));
|
607
709
|
}
|
608
710
|
|
711
|
+
double getDocLL(const DocumentBase* doc) const override
|
712
|
+
{
|
713
|
+
auto* p = dynamic_cast<const DocType*>(doc);
|
714
|
+
if (!p) throw std::invalid_argument{ "wrong `doc` type." };
|
715
|
+
return static_cast<const _Derived*>(this)->getLLDocs(p, p + 1);
|
716
|
+
}
|
717
|
+
|
718
|
+
double getStateLL() const override
|
719
|
+
{
|
720
|
+
return static_cast<const _Derived*>(this)->getLLRest(this->globalState);
|
721
|
+
}
|
722
|
+
|
609
723
|
std::vector<double> infer(const std::vector<DocumentBase*>& docs, size_t maxIter, Float tolerance, size_t numWorkers, ParallelScheme ps, bool together) const override
|
610
724
|
{
|
611
725
|
if (!numWorkers) numWorkers = std::thread::hardware_concurrency();
|
@@ -651,12 +765,18 @@ namespace tomoto
|
|
651
765
|
return extractTopN<Tid>(getTopicsByDoc(doc), topN);
|
652
766
|
}
|
653
767
|
|
654
|
-
|
655
768
|
const DocumentBase* getDoc(size_t docId) const override
|
656
769
|
{
|
657
770
|
return &_getDoc(docId);
|
658
771
|
}
|
659
772
|
|
773
|
+
size_t getDocIdByUid(const std::string& docUid) const override
|
774
|
+
{
|
775
|
+
auto it = uidMap.find(SharedString{ docUid });
|
776
|
+
if (it == uidMap.end()) return -1;
|
777
|
+
return it->second;
|
778
|
+
}
|
779
|
+
|
660
780
|
size_t getGlobalStep() const override
|
661
781
|
{
|
662
782
|
return globalStep;
|