tomoto 0.1.2 → 0.1.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/LICENSE.txt +1 -1
- data/README.md +3 -3
- data/ext/tomoto/ext.cpp +34 -9
- data/ext/tomoto/extconf.rb +2 -1
- data/lib/tomoto/dmr.rb +1 -1
- data/lib/tomoto/gdmr.rb +1 -1
- data/lib/tomoto/version.rb +1 -1
- data/vendor/tomotopy/LICENSE +1 -1
- data/vendor/tomotopy/README.kr.rst +32 -3
- data/vendor/tomotopy/README.rst +30 -1
- data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +133 -147
- data/vendor/tomotopy/src/Labeling/FoRelevance.h +158 -5
- data/vendor/tomotopy/src/TopicModel/DMR.h +1 -16
- data/vendor/tomotopy/src/TopicModel/DMRModel.hpp +15 -34
- data/vendor/tomotopy/src/TopicModel/DT.h +1 -16
- data/vendor/tomotopy/src/TopicModel/DTModel.hpp +15 -32
- data/vendor/tomotopy/src/TopicModel/GDMRModel.hpp +18 -37
- data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +16 -20
- data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +3 -3
- data/vendor/tomotopy/src/TopicModel/LDA.h +0 -11
- data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +9 -21
- data/vendor/tomotopy/src/TopicModel/LLDA.h +0 -15
- data/vendor/tomotopy/src/TopicModel/LLDAModel.hpp +12 -30
- data/vendor/tomotopy/src/TopicModel/MGLDA.h +0 -15
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +59 -72
- data/vendor/tomotopy/src/TopicModel/PLDAModel.hpp +12 -30
- data/vendor/tomotopy/src/TopicModel/SLDA.h +0 -15
- data/vendor/tomotopy/src/TopicModel/SLDAModel.hpp +17 -35
- data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +158 -38
- data/vendor/tomotopy/src/Utils/Dictionary.h +40 -2
- data/vendor/tomotopy/src/Utils/EigenAddonOps.hpp +122 -3
- data/vendor/tomotopy/src/Utils/SharedString.hpp +181 -0
- data/vendor/tomotopy/src/Utils/math.h +1 -1
- data/vendor/tomotopy/src/Utils/sample.hpp +1 -1
- data/vendor/tomotopy/src/Utils/serializer.hpp +17 -0
- data/vendor/variant/LICENSE +25 -0
- data/vendor/variant/LICENSE_1_0.txt +23 -0
- data/vendor/variant/README.md +102 -0
- data/vendor/variant/include/mapbox/optional.hpp +74 -0
- data/vendor/variant/include/mapbox/recursive_wrapper.hpp +122 -0
- data/vendor/variant/include/mapbox/variant.hpp +974 -0
- data/vendor/variant/include/mapbox/variant_io.hpp +45 -0
- metadata +15 -7
@@ -153,46 +153,28 @@ namespace tomoto
|
|
153
153
|
return doc;
|
154
154
|
}
|
155
155
|
|
156
|
-
size_t addDoc(const
|
156
|
+
size_t addDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) override
|
157
157
|
{
|
158
|
-
auto doc = this->
|
159
|
-
return this->_addDoc(_updateDoc(doc, labels));
|
158
|
+
auto doc = this->template _makeFromRawDoc<false>(rawDoc, tokenizer);
|
159
|
+
return this->_addDoc(_updateDoc(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
|
160
160
|
}
|
161
161
|
|
162
|
-
std::unique_ptr<DocumentBase> makeDoc(const
|
162
|
+
std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) const override
|
163
163
|
{
|
164
|
-
auto doc = as_mutable(this)->template
|
165
|
-
return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, labels));
|
164
|
+
auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc, tokenizer);
|
165
|
+
return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
|
166
166
|
}
|
167
167
|
|
168
|
-
size_t addDoc(const
|
169
|
-
const std::vector<std::string>& labels) override
|
168
|
+
size_t addDoc(const RawDoc& rawDoc) override
|
170
169
|
{
|
171
|
-
auto doc = this->
|
172
|
-
return this->_addDoc(_updateDoc(doc, labels));
|
170
|
+
auto doc = this->_makeFromRawDoc(rawDoc);
|
171
|
+
return this->_addDoc(_updateDoc(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
|
173
172
|
}
|
174
173
|
|
175
|
-
std::unique_ptr<DocumentBase> makeDoc(const
|
176
|
-
const std::vector<std::string>& labels) const override
|
174
|
+
std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc) const override
|
177
175
|
{
|
178
|
-
auto doc = as_mutable(this)->template
|
179
|
-
return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, labels));
|
180
|
-
}
|
181
|
-
|
182
|
-
size_t addDoc(const std::string& rawStr, const std::vector<Vid>& words,
|
183
|
-
const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
|
184
|
-
const std::vector<std::string>& labels) override
|
185
|
-
{
|
186
|
-
auto doc = this->_makeRawDoc(rawStr, words, pos, len);
|
187
|
-
return this->_addDoc(_updateDoc(doc, labels));
|
188
|
-
}
|
189
|
-
|
190
|
-
std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const std::vector<Vid>& words,
|
191
|
-
const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
|
192
|
-
const std::vector<std::string>& labels) const override
|
193
|
-
{
|
194
|
-
auto doc = this->_makeRawDoc(rawStr, words, pos, len);
|
195
|
-
return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, labels));
|
176
|
+
auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc);
|
177
|
+
return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
|
196
178
|
}
|
197
179
|
|
198
180
|
std::vector<Float> getTopicsByDoc(const _DocType& doc) const
|
@@ -31,21 +31,6 @@ namespace tomoto
|
|
31
31
|
size_t seed = std::random_device{}(),
|
32
32
|
bool scalarRng = false);
|
33
33
|
|
34
|
-
virtual size_t addDoc(const std::vector<std::string>& words, const std::vector<Float>& y) = 0;
|
35
|
-
virtual std::unique_ptr<DocumentBase> makeDoc(const std::vector<std::string>& words, const std::vector<Float>& y) const = 0;
|
36
|
-
|
37
|
-
virtual size_t addDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
|
38
|
-
const std::vector<Float>& y) = 0;
|
39
|
-
virtual std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
|
40
|
-
const std::vector<Float>& y) const = 0;
|
41
|
-
|
42
|
-
virtual size_t addDoc(const std::string& rawStr, const std::vector<Vid>& words,
|
43
|
-
const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
|
44
|
-
const std::vector<Float>& y) = 0;
|
45
|
-
virtual std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const std::vector<Vid>& words,
|
46
|
-
const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
|
47
|
-
const std::vector<Float>& y) const = 0;
|
48
|
-
|
49
34
|
virtual size_t getF() const = 0;
|
50
35
|
virtual std::vector<Float> getRegressionCoef(size_t f) const = 0;
|
51
36
|
virtual GLM getTypeOfVar(size_t f) const = 0;
|
@@ -148,7 +148,7 @@ namespace tomoto
|
|
148
148
|
+ Eigen::Matrix<Float, -1, 1>::Constant(selectedNormZ.rows(), mu / nuSq));
|
149
149
|
|
150
150
|
RandGen rng;
|
151
|
-
for (size_t i = 0; i < omega.size(); ++i)
|
151
|
+
for (size_t i = 0; i < (size_t)omega.size(); ++i)
|
152
152
|
{
|
153
153
|
if (std::isnan(ys[i])) continue;
|
154
154
|
omega[i] = math::drawPolyaGamma(b, (this->regressionCoef.array() * normZ.col(i).array()).sum(), rng);
|
@@ -358,8 +358,8 @@ namespace tomoto
|
|
358
358
|
if (_const)
|
359
359
|
{
|
360
360
|
if (y.size() > F) throw std::runtime_error{ text::format(
|
361
|
-
"size of
|
362
|
-
"size of
|
361
|
+
"size of `y` is greater than the number of vars.\n"
|
362
|
+
"size of `y` : %zd, number of vars: %zd", y.size(), F) };
|
363
363
|
doc.y = y;
|
364
364
|
while (doc.y.size() < F)
|
365
365
|
{
|
@@ -369,53 +369,35 @@ namespace tomoto
|
|
369
369
|
else
|
370
370
|
{
|
371
371
|
if (y.size() != F) throw std::runtime_error{ text::format(
|
372
|
-
"size of
|
373
|
-
"size of
|
372
|
+
"size of `y` must be equal to the number of vars.\n"
|
373
|
+
"size of `y` : %zd, number of vars: %zd", y.size(), F) };
|
374
374
|
doc.y = y;
|
375
375
|
}
|
376
376
|
return doc;
|
377
377
|
}
|
378
378
|
|
379
|
-
size_t addDoc(const
|
379
|
+
size_t addDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) override
|
380
380
|
{
|
381
|
-
auto doc = this->
|
382
|
-
return this->_addDoc(_updateDoc(doc, y));
|
381
|
+
auto doc = this->template _makeFromRawDoc<false>(rawDoc, tokenizer);
|
382
|
+
return this->_addDoc(_updateDoc(doc, rawDoc.template getMiscDefault<std::vector<Float>>("y")));
|
383
383
|
}
|
384
384
|
|
385
|
-
std::unique_ptr<DocumentBase> makeDoc(const
|
385
|
+
std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) const override
|
386
386
|
{
|
387
|
-
auto doc = as_mutable(this)->template
|
388
|
-
return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, y));
|
387
|
+
auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc, tokenizer);
|
388
|
+
return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<Float>>("y")));
|
389
389
|
}
|
390
390
|
|
391
|
-
size_t addDoc(const
|
392
|
-
const std::vector<Float>& y) override
|
391
|
+
size_t addDoc(const RawDoc& rawDoc) override
|
393
392
|
{
|
394
|
-
auto doc = this->
|
395
|
-
return this->_addDoc(_updateDoc(doc, y));
|
393
|
+
auto doc = this->_makeFromRawDoc(rawDoc);
|
394
|
+
return this->_addDoc(_updateDoc(doc, rawDoc.template getMiscDefault<std::vector<Float>>("y")));
|
396
395
|
}
|
397
396
|
|
398
|
-
std::unique_ptr<DocumentBase> makeDoc(const
|
399
|
-
const std::vector<Float>& y) const override
|
397
|
+
std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc) const override
|
400
398
|
{
|
401
|
-
auto doc = as_mutable(this)->template
|
402
|
-
return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, y));
|
403
|
-
}
|
404
|
-
|
405
|
-
size_t addDoc(const std::string& rawStr, const std::vector<Vid>& words,
|
406
|
-
const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
|
407
|
-
const std::vector<Float>& y) override
|
408
|
-
{
|
409
|
-
auto doc = this->_makeRawDoc(rawStr, words, pos, len);
|
410
|
-
return this->_addDoc(_updateDoc(doc, y));
|
411
|
-
}
|
412
|
-
|
413
|
-
std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const std::vector<Vid>& words,
|
414
|
-
const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
|
415
|
-
const std::vector<Float>& y) const override
|
416
|
-
{
|
417
|
-
auto doc = this->_makeRawDoc(rawStr, words, pos, len);
|
418
|
-
return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, y));
|
399
|
+
auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc);
|
400
|
+
return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<Float>>("y")));
|
419
401
|
}
|
420
402
|
|
421
403
|
std::vector<Float> estimateVars(const DocumentBase* doc) const override
|
@@ -7,27 +7,112 @@
|
|
7
7
|
#include "../Utils/ThreadPool.hpp"
|
8
8
|
#include "../Utils/serializer.hpp"
|
9
9
|
#include "../Utils/exception.h"
|
10
|
+
#include "../Utils/SharedString.hpp"
|
10
11
|
#include <EigenRand/EigenRand>
|
12
|
+
#include <mapbox/variant.hpp>
|
11
13
|
|
12
14
|
namespace tomoto
|
13
15
|
{
|
14
16
|
using RandGen = Eigen::Rand::P8_mt19937_64<uint32_t>;
|
15
17
|
using ScalarRandGen = Eigen::Rand::UniversalRandomEngine<uint32_t, std::mt19937_64>;
|
16
18
|
|
17
|
-
|
19
|
+
struct RawDocKernel
|
18
20
|
{
|
19
|
-
public:
|
20
21
|
Float weight = 1;
|
22
|
+
SharedString docUid;
|
23
|
+
SharedString rawStr;
|
24
|
+
std::vector<uint32_t> origWordPos;
|
25
|
+
std::vector<uint16_t> origWordLen;
|
26
|
+
|
27
|
+
RawDocKernel(const RawDocKernel&) = default;
|
28
|
+
RawDocKernel(RawDocKernel&&) = default;
|
29
|
+
|
30
|
+
RawDocKernel(Float _weight = 1, const SharedString& _docUid = {})
|
31
|
+
: weight{ _weight }, docUid{ _docUid }
|
32
|
+
{
|
33
|
+
}
|
34
|
+
};
|
35
|
+
|
36
|
+
struct RawDoc : public RawDocKernel
|
37
|
+
{
|
38
|
+
using Var = mapbox::util::variant<
|
39
|
+
std::string, uint32_t, Float,
|
40
|
+
std::vector<std::string>, std::vector<uint32_t>, std::vector<Float>,
|
41
|
+
std::shared_ptr<void>
|
42
|
+
>;
|
43
|
+
using MiscType = std::unordered_map<std::string, Var>;
|
44
|
+
|
45
|
+
std::vector<Vid> words;
|
46
|
+
std::vector<std::string> rawWords;
|
47
|
+
MiscType misc;
|
48
|
+
|
49
|
+
RawDoc() = default;
|
50
|
+
RawDoc(const RawDoc&) = default;
|
51
|
+
RawDoc(RawDoc&&) = default;
|
52
|
+
|
53
|
+
RawDoc(const RawDocKernel& o)
|
54
|
+
: RawDocKernel{ o }
|
55
|
+
{
|
56
|
+
}
|
57
|
+
|
58
|
+
template<typename _Ty>
|
59
|
+
const _Ty& getMisc(const std::string& name) const
|
60
|
+
{
|
61
|
+
auto it = misc.find(name);
|
62
|
+
if (it == misc.end()) throw std::invalid_argument{ "There is no value named `" + name + "` in misc data" };
|
63
|
+
if (!it->second.template is<_Ty>()) throw std::invalid_argument{ "Value named `" + name + "` is not in right type." };
|
64
|
+
return it->second.template get<_Ty>();
|
65
|
+
}
|
66
|
+
|
67
|
+
template<typename _Ty>
|
68
|
+
_Ty getMiscDefault(const std::string& name) const
|
69
|
+
{
|
70
|
+
auto it = misc.find(name);
|
71
|
+
if (it == misc.end()) return {};
|
72
|
+
if (!it->second.template is<_Ty>()) throw std::invalid_argument{ "Value named `" + name + "` is not in right type." };
|
73
|
+
return it->second.template get<_Ty>();
|
74
|
+
}
|
75
|
+
};
|
76
|
+
|
77
|
+
class DocumentBase : public RawDocKernel
|
78
|
+
{
|
79
|
+
public:
|
21
80
|
tvector<Vid> words; // word id of each word
|
22
81
|
std::vector<uint32_t> wOrder; // original word order (optional)
|
82
|
+
|
83
|
+
DocumentBase(const DocumentBase&) = default;
|
84
|
+
DocumentBase(DocumentBase&&) = default;
|
85
|
+
|
86
|
+
DocumentBase(const RawDocKernel& o)
|
87
|
+
: RawDocKernel{ o }
|
88
|
+
{
|
89
|
+
}
|
90
|
+
|
91
|
+
DocumentBase(Float _weight = 1, const SharedString& _docUid = {})
|
92
|
+
: RawDocKernel{ _weight, _docUid }
|
93
|
+
{
|
94
|
+
}
|
23
95
|
|
24
|
-
std::string docUid;
|
25
|
-
std::string rawStr;
|
26
|
-
std::vector<uint32_t> origWordPos;
|
27
|
-
std::vector<uint16_t> origWordLen;
|
28
|
-
DocumentBase(Float _weight = 1) : weight(_weight) {}
|
29
96
|
virtual ~DocumentBase() {}
|
30
97
|
|
98
|
+
virtual operator RawDoc() const
|
99
|
+
{
|
100
|
+
RawDoc raw{ *this };
|
101
|
+
if (wOrder.empty())
|
102
|
+
{
|
103
|
+
raw.words.insert(raw.words.begin(), words.begin(), words.end());
|
104
|
+
}
|
105
|
+
else
|
106
|
+
{
|
107
|
+
raw.words.resize(words.size());
|
108
|
+
for (size_t i = 0; i < words.size(); ++i)
|
109
|
+
{
|
110
|
+
raw.words[i] = words[wOrder[i]];
|
111
|
+
}
|
112
|
+
}
|
113
|
+
return raw;
|
114
|
+
}
|
115
|
+
|
31
116
|
DEFINE_SERIALIZER_WITH_VERSION(0, serializer::to_key("Docu"), weight, words, wOrder);
|
32
117
|
DEFINE_TAGGED_SERIALIZER_WITH_VERSION(1, 0x00010001, weight, words, wOrder,
|
33
118
|
rawStr, origWordPos, origWordLen,
|
@@ -127,8 +212,20 @@ namespace tomoto
|
|
127
212
|
virtual void loadModel(std::istream& reader,
|
128
213
|
std::vector<uint8_t>* extra_data = nullptr) = 0;
|
129
214
|
virtual const DocumentBase* getDoc(size_t docId) const = 0;
|
215
|
+
virtual size_t getDocIdByUid(const std::string& docUid) const = 0;
|
216
|
+
|
217
|
+
// it tokenizes rawDoc.rawStr to get words, pos and len of the document
|
218
|
+
virtual size_t addDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) = 0;
|
219
|
+
virtual std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) const = 0;
|
220
|
+
|
221
|
+
// it uses words, pos and len of rawDoc itself.
|
222
|
+
virtual size_t addDoc(const RawDoc& rawDoc) = 0;
|
223
|
+
virtual std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc) const = 0;
|
224
|
+
|
225
|
+
virtual bool updateVocab(const std::vector<std::string>& words) = 0;
|
130
226
|
|
131
|
-
virtual
|
227
|
+
virtual double getDocLL(const DocumentBase* doc) const = 0;
|
228
|
+
virtual double getStateLL() const = 0;
|
132
229
|
|
133
230
|
virtual double getLLPerWord() const = 0;
|
134
231
|
virtual double getPerplexity() const = 0;
|
@@ -201,6 +298,7 @@ namespace tomoto
|
|
201
298
|
std::vector<DocType> docs;
|
202
299
|
std::vector<uint64_t> vocabCf;
|
203
300
|
std::vector<uint64_t> vocabDf;
|
301
|
+
std::unordered_map<SharedString, size_t> uidMap;
|
204
302
|
size_t globalStep = 0;
|
205
303
|
_ModelState globalState, tState;
|
206
304
|
Dictionary dict;
|
@@ -273,6 +371,8 @@ namespace tomoto
|
|
273
371
|
>::value, size_t>::type _addDoc(_DocTy&& doc)
|
274
372
|
{
|
275
373
|
if (doc.words.empty()) return -1;
|
374
|
+
if (!doc.docUid.empty() && uidMap.count(doc.docUid))
|
375
|
+
throw exception::InvalidArgument{ "there is a document with uid = " + std::string{ doc.docUid } + " already." };
|
276
376
|
size_t maxWid = *std::max_element(doc.words.begin(), doc.words.end());
|
277
377
|
if (vocabCf.size() <= maxWid)
|
278
378
|
{
|
@@ -282,47 +382,48 @@ namespace tomoto
|
|
282
382
|
for (auto w : doc.words) ++vocabCf[w];
|
283
383
|
std::unordered_set<Vid> uniq{ doc.words.begin(), doc.words.end() };
|
284
384
|
for (auto w : uniq) ++vocabDf[w];
|
385
|
+
uidMap.emplace(doc.docUid, docs.size());
|
285
386
|
docs.emplace_back(std::forward<_DocTy>(doc));
|
286
387
|
return docs.size() - 1;
|
287
388
|
}
|
288
389
|
|
289
390
|
template<bool _const = false>
|
290
|
-
DocType
|
391
|
+
DocType _makeFromRawDoc(const RawDoc& rawDoc)
|
291
392
|
{
|
292
|
-
DocType doc{
|
293
|
-
|
393
|
+
DocType doc{ rawDoc };
|
394
|
+
if (!rawDoc.rawWords.empty())
|
294
395
|
{
|
295
|
-
|
296
|
-
if (_const)
|
396
|
+
for (auto& w : rawDoc.rawWords)
|
297
397
|
{
|
298
|
-
id
|
299
|
-
if (
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
398
|
+
Vid id;
|
399
|
+
if (_const)
|
400
|
+
{
|
401
|
+
id = dict.toWid(w);
|
402
|
+
if (id == (Vid)-1) continue;
|
403
|
+
}
|
404
|
+
else
|
405
|
+
{
|
406
|
+
id = dict.add(w);
|
407
|
+
}
|
408
|
+
doc.words.emplace_back(id);
|
304
409
|
}
|
305
|
-
doc.words.emplace_back(id);
|
306
410
|
}
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
for (auto& w : words) doc.words.emplace_back(w);
|
316
|
-
doc.origWordPos = pos;
|
317
|
-
doc.origWordLen = len;
|
411
|
+
else if(!rawDoc.words.empty())
|
412
|
+
{
|
413
|
+
for (auto& w : rawDoc.words) doc.words.emplace_back(w);
|
414
|
+
}
|
415
|
+
else
|
416
|
+
{
|
417
|
+
throw std::invalid_argument{ "Either `words` or `rawWords` must be filled." };
|
418
|
+
}
|
318
419
|
return doc;
|
319
420
|
}
|
320
421
|
|
321
422
|
template<bool _const, typename _FnTokenizer>
|
322
|
-
DocType
|
423
|
+
DocType _makeFromRawDoc(const RawDoc& rawDoc, _FnTokenizer&& tokenizer)
|
323
424
|
{
|
324
|
-
DocType doc{
|
325
|
-
doc.rawStr = rawStr;
|
425
|
+
DocType doc{ rawDoc };
|
426
|
+
doc.rawStr = rawDoc.rawStr;
|
326
427
|
for (auto& p : tokenizer(doc.rawStr))
|
327
428
|
{
|
328
429
|
Vid wid;
|
@@ -452,10 +553,11 @@ namespace tomoto
|
|
452
553
|
return realV;
|
453
554
|
}
|
454
555
|
|
455
|
-
|
556
|
+
bool updateVocab(const std::vector<std::string>& words) override
|
456
557
|
{
|
457
|
-
|
458
|
-
for(auto& w : words) dict.add(w);
|
558
|
+
bool empty = dict.size() == 0;
|
559
|
+
for (auto& w : words) dict.add(w);
|
560
|
+
return empty;
|
459
561
|
}
|
460
562
|
|
461
563
|
void prepare(bool initDocs = true, size_t minWordCnt = 0, size_t minWordDf = 0, size_t removeTopN = 0) override
|
@@ -606,6 +708,18 @@ namespace tomoto
|
|
606
708
|
return vid2String(getWidsByDocSorted(doc, topN));
|
607
709
|
}
|
608
710
|
|
711
|
+
double getDocLL(const DocumentBase* doc) const override
|
712
|
+
{
|
713
|
+
auto* p = dynamic_cast<const DocType*>(doc);
|
714
|
+
if (!p) throw std::invalid_argument{ "wrong `doc` type." };
|
715
|
+
return static_cast<const _Derived*>(this)->getLLDocs(p, p + 1);
|
716
|
+
}
|
717
|
+
|
718
|
+
double getStateLL() const override
|
719
|
+
{
|
720
|
+
return static_cast<const _Derived*>(this)->getLLRest(this->globalState);
|
721
|
+
}
|
722
|
+
|
609
723
|
std::vector<double> infer(const std::vector<DocumentBase*>& docs, size_t maxIter, Float tolerance, size_t numWorkers, ParallelScheme ps, bool together) const override
|
610
724
|
{
|
611
725
|
if (!numWorkers) numWorkers = std::thread::hardware_concurrency();
|
@@ -651,12 +765,18 @@ namespace tomoto
|
|
651
765
|
return extractTopN<Tid>(getTopicsByDoc(doc), topN);
|
652
766
|
}
|
653
767
|
|
654
|
-
|
655
768
|
const DocumentBase* getDoc(size_t docId) const override
|
656
769
|
{
|
657
770
|
return &_getDoc(docId);
|
658
771
|
}
|
659
772
|
|
773
|
+
size_t getDocIdByUid(const std::string& docUid) const override
|
774
|
+
{
|
775
|
+
auto it = uidMap.find(SharedString{ docUid });
|
776
|
+
if (it == uidMap.end()) return -1;
|
777
|
+
return it->second;
|
778
|
+
}
|
779
|
+
|
660
780
|
size_t getGlobalStep() const override
|
661
781
|
{
|
662
782
|
return globalStep;
|