tomoto 0.1.2 → 0.1.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +3 -3
  5. data/ext/tomoto/ext.cpp +34 -9
  6. data/ext/tomoto/extconf.rb +2 -1
  7. data/lib/tomoto/dmr.rb +1 -1
  8. data/lib/tomoto/gdmr.rb +1 -1
  9. data/lib/tomoto/version.rb +1 -1
  10. data/vendor/tomotopy/LICENSE +1 -1
  11. data/vendor/tomotopy/README.kr.rst +32 -3
  12. data/vendor/tomotopy/README.rst +30 -1
  13. data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +133 -147
  14. data/vendor/tomotopy/src/Labeling/FoRelevance.h +158 -5
  15. data/vendor/tomotopy/src/TopicModel/DMR.h +1 -16
  16. data/vendor/tomotopy/src/TopicModel/DMRModel.hpp +15 -34
  17. data/vendor/tomotopy/src/TopicModel/DT.h +1 -16
  18. data/vendor/tomotopy/src/TopicModel/DTModel.hpp +15 -32
  19. data/vendor/tomotopy/src/TopicModel/GDMRModel.hpp +18 -37
  20. data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +16 -20
  21. data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +3 -3
  22. data/vendor/tomotopy/src/TopicModel/LDA.h +0 -11
  23. data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +9 -21
  24. data/vendor/tomotopy/src/TopicModel/LLDA.h +0 -15
  25. data/vendor/tomotopy/src/TopicModel/LLDAModel.hpp +12 -30
  26. data/vendor/tomotopy/src/TopicModel/MGLDA.h +0 -15
  27. data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +59 -72
  28. data/vendor/tomotopy/src/TopicModel/PLDAModel.hpp +12 -30
  29. data/vendor/tomotopy/src/TopicModel/SLDA.h +0 -15
  30. data/vendor/tomotopy/src/TopicModel/SLDAModel.hpp +17 -35
  31. data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +158 -38
  32. data/vendor/tomotopy/src/Utils/Dictionary.h +40 -2
  33. data/vendor/tomotopy/src/Utils/EigenAddonOps.hpp +122 -3
  34. data/vendor/tomotopy/src/Utils/SharedString.hpp +181 -0
  35. data/vendor/tomotopy/src/Utils/math.h +1 -1
  36. data/vendor/tomotopy/src/Utils/sample.hpp +1 -1
  37. data/vendor/tomotopy/src/Utils/serializer.hpp +17 -0
  38. data/vendor/variant/LICENSE +25 -0
  39. data/vendor/variant/LICENSE_1_0.txt +23 -0
  40. data/vendor/variant/README.md +102 -0
  41. data/vendor/variant/include/mapbox/optional.hpp +74 -0
  42. data/vendor/variant/include/mapbox/recursive_wrapper.hpp +122 -0
  43. data/vendor/variant/include/mapbox/variant.hpp +974 -0
  44. data/vendor/variant/include/mapbox/variant_io.hpp +45 -0
  45. metadata +15 -7
@@ -153,46 +153,28 @@ namespace tomoto
153
153
  return doc;
154
154
  }
155
155
 
156
- size_t addDoc(const std::vector<std::string>& words, const std::vector<std::string>& labels) override
156
+ size_t addDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) override
157
157
  {
158
- auto doc = this->_makeDoc(words);
159
- return this->_addDoc(_updateDoc(doc, labels));
158
+ auto doc = this->template _makeFromRawDoc<false>(rawDoc, tokenizer);
159
+ return this->_addDoc(_updateDoc(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
160
160
  }
161
161
 
162
- std::unique_ptr<DocumentBase> makeDoc(const std::vector<std::string>& words, const std::vector<std::string>& labels) const override
162
+ std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) const override
163
163
  {
164
- auto doc = as_mutable(this)->template _makeDoc<true>(words);
165
- return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, labels));
164
+ auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc, tokenizer);
165
+ return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
166
166
  }
167
167
 
168
- size_t addDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
169
- const std::vector<std::string>& labels) override
168
+ size_t addDoc(const RawDoc& rawDoc) override
170
169
  {
171
- auto doc = this->template _makeRawDoc<false>(rawStr, tokenizer);
172
- return this->_addDoc(_updateDoc(doc, labels));
170
+ auto doc = this->_makeFromRawDoc(rawDoc);
171
+ return this->_addDoc(_updateDoc(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
173
172
  }
174
173
 
175
- std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
176
- const std::vector<std::string>& labels) const override
174
+ std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc) const override
177
175
  {
178
- auto doc = as_mutable(this)->template _makeRawDoc<true>(rawStr, tokenizer);
179
- return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, labels));
180
- }
181
-
182
- size_t addDoc(const std::string& rawStr, const std::vector<Vid>& words,
183
- const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
184
- const std::vector<std::string>& labels) override
185
- {
186
- auto doc = this->_makeRawDoc(rawStr, words, pos, len);
187
- return this->_addDoc(_updateDoc(doc, labels));
188
- }
189
-
190
- std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const std::vector<Vid>& words,
191
- const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
192
- const std::vector<std::string>& labels) const override
193
- {
194
- auto doc = this->_makeRawDoc(rawStr, words, pos, len);
195
- return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, labels));
176
+ auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc);
177
+ return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
196
178
  }
197
179
 
198
180
  std::vector<Float> getTopicsByDoc(const _DocType& doc) const
@@ -31,21 +31,6 @@ namespace tomoto
31
31
  size_t seed = std::random_device{}(),
32
32
  bool scalarRng = false);
33
33
 
34
- virtual size_t addDoc(const std::vector<std::string>& words, const std::vector<Float>& y) = 0;
35
- virtual std::unique_ptr<DocumentBase> makeDoc(const std::vector<std::string>& words, const std::vector<Float>& y) const = 0;
36
-
37
- virtual size_t addDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
38
- const std::vector<Float>& y) = 0;
39
- virtual std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
40
- const std::vector<Float>& y) const = 0;
41
-
42
- virtual size_t addDoc(const std::string& rawStr, const std::vector<Vid>& words,
43
- const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
44
- const std::vector<Float>& y) = 0;
45
- virtual std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const std::vector<Vid>& words,
46
- const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
47
- const std::vector<Float>& y) const = 0;
48
-
49
34
  virtual size_t getF() const = 0;
50
35
  virtual std::vector<Float> getRegressionCoef(size_t f) const = 0;
51
36
  virtual GLM getTypeOfVar(size_t f) const = 0;
@@ -148,7 +148,7 @@ namespace tomoto
148
148
  + Eigen::Matrix<Float, -1, 1>::Constant(selectedNormZ.rows(), mu / nuSq));
149
149
 
150
150
  RandGen rng;
151
- for (size_t i = 0; i < omega.size(); ++i)
151
+ for (size_t i = 0; i < (size_t)omega.size(); ++i)
152
152
  {
153
153
  if (std::isnan(ys[i])) continue;
154
154
  omega[i] = math::drawPolyaGamma(b, (this->regressionCoef.array() * normZ.col(i).array()).sum(), rng);
@@ -358,8 +358,8 @@ namespace tomoto
358
358
  if (_const)
359
359
  {
360
360
  if (y.size() > F) throw std::runtime_error{ text::format(
361
- "size of 'y' is greater than the number of vars.\n"
362
- "size of 'y' : %zd, number of vars: %zd", y.size(), F) };
361
+ "size of `y` is greater than the number of vars.\n"
362
+ "size of `y` : %zd, number of vars: %zd", y.size(), F) };
363
363
  doc.y = y;
364
364
  while (doc.y.size() < F)
365
365
  {
@@ -369,53 +369,35 @@ namespace tomoto
369
369
  else
370
370
  {
371
371
  if (y.size() != F) throw std::runtime_error{ text::format(
372
- "size of 'y' must be equal to the number of vars.\n"
373
- "size of 'y' : %zd, number of vars: %zd", y.size(), F) };
372
+ "size of `y` must be equal to the number of vars.\n"
373
+ "size of `y` : %zd, number of vars: %zd", y.size(), F) };
374
374
  doc.y = y;
375
375
  }
376
376
  return doc;
377
377
  }
378
378
 
379
- size_t addDoc(const std::vector<std::string>& words, const std::vector<Float>& y) override
379
+ size_t addDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) override
380
380
  {
381
- auto doc = this->_makeDoc(words);
382
- return this->_addDoc(_updateDoc(doc, y));
381
+ auto doc = this->template _makeFromRawDoc<false>(rawDoc, tokenizer);
382
+ return this->_addDoc(_updateDoc(doc, rawDoc.template getMiscDefault<std::vector<Float>>("y")));
383
383
  }
384
384
 
385
- std::unique_ptr<DocumentBase> makeDoc(const std::vector<std::string>& words, const std::vector<Float>& y) const override
385
+ std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) const override
386
386
  {
387
- auto doc = as_mutable(this)->template _makeDoc<true>(words);
388
- return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, y));
387
+ auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc, tokenizer);
388
+ return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<Float>>("y")));
389
389
  }
390
390
 
391
- size_t addDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
392
- const std::vector<Float>& y) override
391
+ size_t addDoc(const RawDoc& rawDoc) override
393
392
  {
394
- auto doc = this->template _makeRawDoc<false>(rawStr, tokenizer);
395
- return this->_addDoc(_updateDoc(doc, y));
393
+ auto doc = this->_makeFromRawDoc(rawDoc);
394
+ return this->_addDoc(_updateDoc(doc, rawDoc.template getMiscDefault<std::vector<Float>>("y")));
396
395
  }
397
396
 
398
- std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
399
- const std::vector<Float>& y) const override
397
+ std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc) const override
400
398
  {
401
- auto doc = as_mutable(this)->template _makeRawDoc<true>(rawStr, tokenizer);
402
- return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, y));
403
- }
404
-
405
- size_t addDoc(const std::string& rawStr, const std::vector<Vid>& words,
406
- const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
407
- const std::vector<Float>& y) override
408
- {
409
- auto doc = this->_makeRawDoc(rawStr, words, pos, len);
410
- return this->_addDoc(_updateDoc(doc, y));
411
- }
412
-
413
- std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const std::vector<Vid>& words,
414
- const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
415
- const std::vector<Float>& y) const override
416
- {
417
- auto doc = this->_makeRawDoc(rawStr, words, pos, len);
418
- return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, y));
399
+ auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc);
400
+ return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<Float>>("y")));
419
401
  }
420
402
 
421
403
  std::vector<Float> estimateVars(const DocumentBase* doc) const override
@@ -7,27 +7,112 @@
7
7
  #include "../Utils/ThreadPool.hpp"
8
8
  #include "../Utils/serializer.hpp"
9
9
  #include "../Utils/exception.h"
10
+ #include "../Utils/SharedString.hpp"
10
11
  #include <EigenRand/EigenRand>
12
+ #include <mapbox/variant.hpp>
11
13
 
12
14
  namespace tomoto
13
15
  {
14
16
  using RandGen = Eigen::Rand::P8_mt19937_64<uint32_t>;
15
17
  using ScalarRandGen = Eigen::Rand::UniversalRandomEngine<uint32_t, std::mt19937_64>;
16
18
 
17
- class DocumentBase
19
+ struct RawDocKernel
18
20
  {
19
- public:
20
21
  Float weight = 1;
22
+ SharedString docUid;
23
+ SharedString rawStr;
24
+ std::vector<uint32_t> origWordPos;
25
+ std::vector<uint16_t> origWordLen;
26
+
27
+ RawDocKernel(const RawDocKernel&) = default;
28
+ RawDocKernel(RawDocKernel&&) = default;
29
+
30
+ RawDocKernel(Float _weight = 1, const SharedString& _docUid = {})
31
+ : weight{ _weight }, docUid{ _docUid }
32
+ {
33
+ }
34
+ };
35
+
36
+ struct RawDoc : public RawDocKernel
37
+ {
38
+ using Var = mapbox::util::variant<
39
+ std::string, uint32_t, Float,
40
+ std::vector<std::string>, std::vector<uint32_t>, std::vector<Float>,
41
+ std::shared_ptr<void>
42
+ >;
43
+ using MiscType = std::unordered_map<std::string, Var>;
44
+
45
+ std::vector<Vid> words;
46
+ std::vector<std::string> rawWords;
47
+ MiscType misc;
48
+
49
+ RawDoc() = default;
50
+ RawDoc(const RawDoc&) = default;
51
+ RawDoc(RawDoc&&) = default;
52
+
53
+ RawDoc(const RawDocKernel& o)
54
+ : RawDocKernel{ o }
55
+ {
56
+ }
57
+
58
+ template<typename _Ty>
59
+ const _Ty& getMisc(const std::string& name) const
60
+ {
61
+ auto it = misc.find(name);
62
+ if (it == misc.end()) throw std::invalid_argument{ "There is no value named `" + name + "` in misc data" };
63
+ if (!it->second.template is<_Ty>()) throw std::invalid_argument{ "Value named `" + name + "` is not in right type." };
64
+ return it->second.template get<_Ty>();
65
+ }
66
+
67
+ template<typename _Ty>
68
+ _Ty getMiscDefault(const std::string& name) const
69
+ {
70
+ auto it = misc.find(name);
71
+ if (it == misc.end()) return {};
72
+ if (!it->second.template is<_Ty>()) throw std::invalid_argument{ "Value named `" + name + "` is not in right type." };
73
+ return it->second.template get<_Ty>();
74
+ }
75
+ };
76
+
77
+ class DocumentBase : public RawDocKernel
78
+ {
79
+ public:
21
80
  tvector<Vid> words; // word id of each word
22
81
  std::vector<uint32_t> wOrder; // original word order (optional)
82
+
83
+ DocumentBase(const DocumentBase&) = default;
84
+ DocumentBase(DocumentBase&&) = default;
85
+
86
+ DocumentBase(const RawDocKernel& o)
87
+ : RawDocKernel{ o }
88
+ {
89
+ }
90
+
91
+ DocumentBase(Float _weight = 1, const SharedString& _docUid = {})
92
+ : RawDocKernel{ _weight, _docUid }
93
+ {
94
+ }
23
95
 
24
- std::string docUid;
25
- std::string rawStr;
26
- std::vector<uint32_t> origWordPos;
27
- std::vector<uint16_t> origWordLen;
28
- DocumentBase(Float _weight = 1) : weight(_weight) {}
29
96
  virtual ~DocumentBase() {}
30
97
 
98
+ virtual operator RawDoc() const
99
+ {
100
+ RawDoc raw{ *this };
101
+ if (wOrder.empty())
102
+ {
103
+ raw.words.insert(raw.words.begin(), words.begin(), words.end());
104
+ }
105
+ else
106
+ {
107
+ raw.words.resize(words.size());
108
+ for (size_t i = 0; i < words.size(); ++i)
109
+ {
110
+ raw.words[i] = words[wOrder[i]];
111
+ }
112
+ }
113
+ return raw;
114
+ }
115
+
31
116
  DEFINE_SERIALIZER_WITH_VERSION(0, serializer::to_key("Docu"), weight, words, wOrder);
32
117
  DEFINE_TAGGED_SERIALIZER_WITH_VERSION(1, 0x00010001, weight, words, wOrder,
33
118
  rawStr, origWordPos, origWordLen,
@@ -127,8 +212,20 @@ namespace tomoto
127
212
  virtual void loadModel(std::istream& reader,
128
213
  std::vector<uint8_t>* extra_data = nullptr) = 0;
129
214
  virtual const DocumentBase* getDoc(size_t docId) const = 0;
215
+ virtual size_t getDocIdByUid(const std::string& docUid) const = 0;
216
+
217
+ // it tokenizes rawDoc.rawStr to get words, pos and len of the document
218
+ virtual size_t addDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) = 0;
219
+ virtual std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) const = 0;
220
+
221
+ // it uses words, pos and len of rawDoc itself.
222
+ virtual size_t addDoc(const RawDoc& rawDoc) = 0;
223
+ virtual std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc) const = 0;
224
+
225
+ virtual bool updateVocab(const std::vector<std::string>& words) = 0;
130
226
 
131
- virtual void updateVocab(const std::vector<std::string>& words) = 0;
227
+ virtual double getDocLL(const DocumentBase* doc) const = 0;
228
+ virtual double getStateLL() const = 0;
132
229
 
133
230
  virtual double getLLPerWord() const = 0;
134
231
  virtual double getPerplexity() const = 0;
@@ -201,6 +298,7 @@ namespace tomoto
201
298
  std::vector<DocType> docs;
202
299
  std::vector<uint64_t> vocabCf;
203
300
  std::vector<uint64_t> vocabDf;
301
+ std::unordered_map<SharedString, size_t> uidMap;
204
302
  size_t globalStep = 0;
205
303
  _ModelState globalState, tState;
206
304
  Dictionary dict;
@@ -273,6 +371,8 @@ namespace tomoto
273
371
  >::value, size_t>::type _addDoc(_DocTy&& doc)
274
372
  {
275
373
  if (doc.words.empty()) return -1;
374
+ if (!doc.docUid.empty() && uidMap.count(doc.docUid))
375
+ throw exception::InvalidArgument{ "there is a document with uid = " + std::string{ doc.docUid } + " already." };
276
376
  size_t maxWid = *std::max_element(doc.words.begin(), doc.words.end());
277
377
  if (vocabCf.size() <= maxWid)
278
378
  {
@@ -282,47 +382,48 @@ namespace tomoto
282
382
  for (auto w : doc.words) ++vocabCf[w];
283
383
  std::unordered_set<Vid> uniq{ doc.words.begin(), doc.words.end() };
284
384
  for (auto w : uniq) ++vocabDf[w];
385
+ uidMap.emplace(doc.docUid, docs.size());
285
386
  docs.emplace_back(std::forward<_DocTy>(doc));
286
387
  return docs.size() - 1;
287
388
  }
288
389
 
289
390
  template<bool _const = false>
290
- DocType _makeDoc(const std::vector<std::string>& words, Float weight = 1)
391
+ DocType _makeFromRawDoc(const RawDoc& rawDoc)
291
392
  {
292
- DocType doc{ weight };
293
- for (auto& w : words)
393
+ DocType doc{ rawDoc };
394
+ if (!rawDoc.rawWords.empty())
294
395
  {
295
- Vid id;
296
- if (_const)
396
+ for (auto& w : rawDoc.rawWords)
297
397
  {
298
- id = dict.toWid(w);
299
- if (id == (Vid)-1) continue;
300
- }
301
- else
302
- {
303
- id = dict.add(w);
398
+ Vid id;
399
+ if (_const)
400
+ {
401
+ id = dict.toWid(w);
402
+ if (id == (Vid)-1) continue;
403
+ }
404
+ else
405
+ {
406
+ id = dict.add(w);
407
+ }
408
+ doc.words.emplace_back(id);
304
409
  }
305
- doc.words.emplace_back(id);
306
410
  }
307
- return doc;
308
- }
309
-
310
- DocType _makeRawDoc(const std::string& rawStr, const std::vector<Vid>& words,
311
- const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len, Float weight = 1) const
312
- {
313
- DocType doc{ weight };
314
- doc.rawStr = rawStr;
315
- for (auto& w : words) doc.words.emplace_back(w);
316
- doc.origWordPos = pos;
317
- doc.origWordLen = len;
411
+ else if(!rawDoc.words.empty())
412
+ {
413
+ for (auto& w : rawDoc.words) doc.words.emplace_back(w);
414
+ }
415
+ else
416
+ {
417
+ throw std::invalid_argument{ "Either `words` or `rawWords` must be filled." };
418
+ }
318
419
  return doc;
319
420
  }
320
421
 
321
422
  template<bool _const, typename _FnTokenizer>
322
- DocType _makeRawDoc(const std::string& rawStr, _FnTokenizer&& tokenizer, Float weight = 1)
423
+ DocType _makeFromRawDoc(const RawDoc& rawDoc, _FnTokenizer&& tokenizer)
323
424
  {
324
- DocType doc{ weight };
325
- doc.rawStr = rawStr;
425
+ DocType doc{ rawDoc };
426
+ doc.rawStr = rawDoc.rawStr;
326
427
  for (auto& p : tokenizer(doc.rawStr))
327
428
  {
328
429
  Vid wid;
@@ -452,10 +553,11 @@ namespace tomoto
452
553
  return realV;
453
554
  }
454
555
 
455
- void updateVocab(const std::vector<std::string>& words) override
556
+ bool updateVocab(const std::vector<std::string>& words) override
456
557
  {
457
- if(dict.size()) THROW_ERROR_WITH_INFO(exception::InvalidArgument, "updateVocab after addDoc");
458
- for(auto& w : words) dict.add(w);
558
+ bool empty = dict.size() == 0;
559
+ for (auto& w : words) dict.add(w);
560
+ return empty;
459
561
  }
460
562
 
461
563
  void prepare(bool initDocs = true, size_t minWordCnt = 0, size_t minWordDf = 0, size_t removeTopN = 0) override
@@ -606,6 +708,18 @@ namespace tomoto
606
708
  return vid2String(getWidsByDocSorted(doc, topN));
607
709
  }
608
710
 
711
+ double getDocLL(const DocumentBase* doc) const override
712
+ {
713
+ auto* p = dynamic_cast<const DocType*>(doc);
714
+ if (!p) throw std::invalid_argument{ "wrong `doc` type." };
715
+ return static_cast<const _Derived*>(this)->getLLDocs(p, p + 1);
716
+ }
717
+
718
+ double getStateLL() const override
719
+ {
720
+ return static_cast<const _Derived*>(this)->getLLRest(this->globalState);
721
+ }
722
+
609
723
  std::vector<double> infer(const std::vector<DocumentBase*>& docs, size_t maxIter, Float tolerance, size_t numWorkers, ParallelScheme ps, bool together) const override
610
724
  {
611
725
  if (!numWorkers) numWorkers = std::thread::hardware_concurrency();
@@ -651,12 +765,18 @@ namespace tomoto
651
765
  return extractTopN<Tid>(getTopicsByDoc(doc), topN);
652
766
  }
653
767
 
654
-
655
768
  const DocumentBase* getDoc(size_t docId) const override
656
769
  {
657
770
  return &_getDoc(docId);
658
771
  }
659
772
 
773
+ size_t getDocIdByUid(const std::string& docUid) const override
774
+ {
775
+ auto it = uidMap.find(SharedString{ docUid });
776
+ if (it == uidMap.end()) return -1;
777
+ return it->second;
778
+ }
779
+
660
780
  size_t getGlobalStep() const override
661
781
  {
662
782
  return globalStep;