tomoto 0.1.2 → 0.1.3

Sign up to get free protection for your applications and to get access to all the features.
Files changed (45)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +3 -3
  5. data/ext/tomoto/ext.cpp +34 -9
  6. data/ext/tomoto/extconf.rb +2 -1
  7. data/lib/tomoto/dmr.rb +1 -1
  8. data/lib/tomoto/gdmr.rb +1 -1
  9. data/lib/tomoto/version.rb +1 -1
  10. data/vendor/tomotopy/LICENSE +1 -1
  11. data/vendor/tomotopy/README.kr.rst +32 -3
  12. data/vendor/tomotopy/README.rst +30 -1
  13. data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +133 -147
  14. data/vendor/tomotopy/src/Labeling/FoRelevance.h +158 -5
  15. data/vendor/tomotopy/src/TopicModel/DMR.h +1 -16
  16. data/vendor/tomotopy/src/TopicModel/DMRModel.hpp +15 -34
  17. data/vendor/tomotopy/src/TopicModel/DT.h +1 -16
  18. data/vendor/tomotopy/src/TopicModel/DTModel.hpp +15 -32
  19. data/vendor/tomotopy/src/TopicModel/GDMRModel.hpp +18 -37
  20. data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +16 -20
  21. data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +3 -3
  22. data/vendor/tomotopy/src/TopicModel/LDA.h +0 -11
  23. data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +9 -21
  24. data/vendor/tomotopy/src/TopicModel/LLDA.h +0 -15
  25. data/vendor/tomotopy/src/TopicModel/LLDAModel.hpp +12 -30
  26. data/vendor/tomotopy/src/TopicModel/MGLDA.h +0 -15
  27. data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +59 -72
  28. data/vendor/tomotopy/src/TopicModel/PLDAModel.hpp +12 -30
  29. data/vendor/tomotopy/src/TopicModel/SLDA.h +0 -15
  30. data/vendor/tomotopy/src/TopicModel/SLDAModel.hpp +17 -35
  31. data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +158 -38
  32. data/vendor/tomotopy/src/Utils/Dictionary.h +40 -2
  33. data/vendor/tomotopy/src/Utils/EigenAddonOps.hpp +122 -3
  34. data/vendor/tomotopy/src/Utils/SharedString.hpp +181 -0
  35. data/vendor/tomotopy/src/Utils/math.h +1 -1
  36. data/vendor/tomotopy/src/Utils/sample.hpp +1 -1
  37. data/vendor/tomotopy/src/Utils/serializer.hpp +17 -0
  38. data/vendor/variant/LICENSE +25 -0
  39. data/vendor/variant/LICENSE_1_0.txt +23 -0
  40. data/vendor/variant/README.md +102 -0
  41. data/vendor/variant/include/mapbox/optional.hpp +74 -0
  42. data/vendor/variant/include/mapbox/recursive_wrapper.hpp +122 -0
  43. data/vendor/variant/include/mapbox/variant.hpp +974 -0
  44. data/vendor/variant/include/mapbox/variant_io.hpp +45 -0
  45. metadata +15 -7
@@ -153,46 +153,28 @@ namespace tomoto
153
153
  return doc;
154
154
  }
155
155
 
156
- size_t addDoc(const std::vector<std::string>& words, const std::vector<std::string>& labels) override
156
+ size_t addDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) override
157
157
  {
158
- auto doc = this->_makeDoc(words);
159
- return this->_addDoc(_updateDoc(doc, labels));
158
+ auto doc = this->template _makeFromRawDoc<false>(rawDoc, tokenizer);
159
+ return this->_addDoc(_updateDoc(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
160
160
  }
161
161
 
162
- std::unique_ptr<DocumentBase> makeDoc(const std::vector<std::string>& words, const std::vector<std::string>& labels) const override
162
+ std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) const override
163
163
  {
164
- auto doc = as_mutable(this)->template _makeDoc<true>(words);
165
- return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, labels));
164
+ auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc, tokenizer);
165
+ return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
166
166
  }
167
167
 
168
- size_t addDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
169
- const std::vector<std::string>& labels) override
168
+ size_t addDoc(const RawDoc& rawDoc) override
170
169
  {
171
- auto doc = this->template _makeRawDoc<false>(rawStr, tokenizer);
172
- return this->_addDoc(_updateDoc(doc, labels));
170
+ auto doc = this->_makeFromRawDoc(rawDoc);
171
+ return this->_addDoc(_updateDoc(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
173
172
  }
174
173
 
175
- std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
176
- const std::vector<std::string>& labels) const override
174
+ std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc) const override
177
175
  {
178
- auto doc = as_mutable(this)->template _makeRawDoc<true>(rawStr, tokenizer);
179
- return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, labels));
180
- }
181
-
182
- size_t addDoc(const std::string& rawStr, const std::vector<Vid>& words,
183
- const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
184
- const std::vector<std::string>& labels) override
185
- {
186
- auto doc = this->_makeRawDoc(rawStr, words, pos, len);
187
- return this->_addDoc(_updateDoc(doc, labels));
188
- }
189
-
190
- std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const std::vector<Vid>& words,
191
- const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
192
- const std::vector<std::string>& labels) const override
193
- {
194
- auto doc = this->_makeRawDoc(rawStr, words, pos, len);
195
- return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, labels));
176
+ auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc);
177
+ return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
196
178
  }
197
179
 
198
180
  std::vector<Float> getTopicsByDoc(const _DocType& doc) const
@@ -31,21 +31,6 @@ namespace tomoto
31
31
  size_t seed = std::random_device{}(),
32
32
  bool scalarRng = false);
33
33
 
34
- virtual size_t addDoc(const std::vector<std::string>& words, const std::vector<Float>& y) = 0;
35
- virtual std::unique_ptr<DocumentBase> makeDoc(const std::vector<std::string>& words, const std::vector<Float>& y) const = 0;
36
-
37
- virtual size_t addDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
38
- const std::vector<Float>& y) = 0;
39
- virtual std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
40
- const std::vector<Float>& y) const = 0;
41
-
42
- virtual size_t addDoc(const std::string& rawStr, const std::vector<Vid>& words,
43
- const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
44
- const std::vector<Float>& y) = 0;
45
- virtual std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const std::vector<Vid>& words,
46
- const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
47
- const std::vector<Float>& y) const = 0;
48
-
49
34
  virtual size_t getF() const = 0;
50
35
  virtual std::vector<Float> getRegressionCoef(size_t f) const = 0;
51
36
  virtual GLM getTypeOfVar(size_t f) const = 0;
@@ -148,7 +148,7 @@ namespace tomoto
148
148
  + Eigen::Matrix<Float, -1, 1>::Constant(selectedNormZ.rows(), mu / nuSq));
149
149
 
150
150
  RandGen rng;
151
- for (size_t i = 0; i < omega.size(); ++i)
151
+ for (size_t i = 0; i < (size_t)omega.size(); ++i)
152
152
  {
153
153
  if (std::isnan(ys[i])) continue;
154
154
  omega[i] = math::drawPolyaGamma(b, (this->regressionCoef.array() * normZ.col(i).array()).sum(), rng);
@@ -358,8 +358,8 @@ namespace tomoto
358
358
  if (_const)
359
359
  {
360
360
  if (y.size() > F) throw std::runtime_error{ text::format(
361
- "size of 'y' is greater than the number of vars.\n"
362
- "size of 'y' : %zd, number of vars: %zd", y.size(), F) };
361
+ "size of `y` is greater than the number of vars.\n"
362
+ "size of `y` : %zd, number of vars: %zd", y.size(), F) };
363
363
  doc.y = y;
364
364
  while (doc.y.size() < F)
365
365
  {
@@ -369,53 +369,35 @@ namespace tomoto
369
369
  else
370
370
  {
371
371
  if (y.size() != F) throw std::runtime_error{ text::format(
372
- "size of 'y' must be equal to the number of vars.\n"
373
- "size of 'y' : %zd, number of vars: %zd", y.size(), F) };
372
+ "size of `y` must be equal to the number of vars.\n"
373
+ "size of `y` : %zd, number of vars: %zd", y.size(), F) };
374
374
  doc.y = y;
375
375
  }
376
376
  return doc;
377
377
  }
378
378
 
379
- size_t addDoc(const std::vector<std::string>& words, const std::vector<Float>& y) override
379
+ size_t addDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) override
380
380
  {
381
- auto doc = this->_makeDoc(words);
382
- return this->_addDoc(_updateDoc(doc, y));
381
+ auto doc = this->template _makeFromRawDoc<false>(rawDoc, tokenizer);
382
+ return this->_addDoc(_updateDoc(doc, rawDoc.template getMiscDefault<std::vector<Float>>("y")));
383
383
  }
384
384
 
385
- std::unique_ptr<DocumentBase> makeDoc(const std::vector<std::string>& words, const std::vector<Float>& y) const override
385
+ std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) const override
386
386
  {
387
- auto doc = as_mutable(this)->template _makeDoc<true>(words);
388
- return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, y));
387
+ auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc, tokenizer);
388
+ return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<Float>>("y")));
389
389
  }
390
390
 
391
- size_t addDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
392
- const std::vector<Float>& y) override
391
+ size_t addDoc(const RawDoc& rawDoc) override
393
392
  {
394
- auto doc = this->template _makeRawDoc<false>(rawStr, tokenizer);
395
- return this->_addDoc(_updateDoc(doc, y));
393
+ auto doc = this->_makeFromRawDoc(rawDoc);
394
+ return this->_addDoc(_updateDoc(doc, rawDoc.template getMiscDefault<std::vector<Float>>("y")));
396
395
  }
397
396
 
398
- std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const RawDocTokenizer::Factory& tokenizer,
399
- const std::vector<Float>& y) const override
397
+ std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc) const override
400
398
  {
401
- auto doc = as_mutable(this)->template _makeRawDoc<true>(rawStr, tokenizer);
402
- return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, y));
403
- }
404
-
405
- size_t addDoc(const std::string& rawStr, const std::vector<Vid>& words,
406
- const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
407
- const std::vector<Float>& y) override
408
- {
409
- auto doc = this->_makeRawDoc(rawStr, words, pos, len);
410
- return this->_addDoc(_updateDoc(doc, y));
411
- }
412
-
413
- std::unique_ptr<DocumentBase> makeDoc(const std::string& rawStr, const std::vector<Vid>& words,
414
- const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len,
415
- const std::vector<Float>& y) const override
416
- {
417
- auto doc = this->_makeRawDoc(rawStr, words, pos, len);
418
- return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, y));
399
+ auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc);
400
+ return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<Float>>("y")));
419
401
  }
420
402
 
421
403
  std::vector<Float> estimateVars(const DocumentBase* doc) const override
@@ -7,27 +7,112 @@
7
7
  #include "../Utils/ThreadPool.hpp"
8
8
  #include "../Utils/serializer.hpp"
9
9
  #include "../Utils/exception.h"
10
+ #include "../Utils/SharedString.hpp"
10
11
  #include <EigenRand/EigenRand>
12
+ #include <mapbox/variant.hpp>
11
13
 
12
14
  namespace tomoto
13
15
  {
14
16
  using RandGen = Eigen::Rand::P8_mt19937_64<uint32_t>;
15
17
  using ScalarRandGen = Eigen::Rand::UniversalRandomEngine<uint32_t, std::mt19937_64>;
16
18
 
17
- class DocumentBase
19
+ struct RawDocKernel
18
20
  {
19
- public:
20
21
  Float weight = 1;
22
+ SharedString docUid;
23
+ SharedString rawStr;
24
+ std::vector<uint32_t> origWordPos;
25
+ std::vector<uint16_t> origWordLen;
26
+
27
+ RawDocKernel(const RawDocKernel&) = default;
28
+ RawDocKernel(RawDocKernel&&) = default;
29
+
30
+ RawDocKernel(Float _weight = 1, const SharedString& _docUid = {})
31
+ : weight{ _weight }, docUid{ _docUid }
32
+ {
33
+ }
34
+ };
35
+
36
+ struct RawDoc : public RawDocKernel
37
+ {
38
+ using Var = mapbox::util::variant<
39
+ std::string, uint32_t, Float,
40
+ std::vector<std::string>, std::vector<uint32_t>, std::vector<Float>,
41
+ std::shared_ptr<void>
42
+ >;
43
+ using MiscType = std::unordered_map<std::string, Var>;
44
+
45
+ std::vector<Vid> words;
46
+ std::vector<std::string> rawWords;
47
+ MiscType misc;
48
+
49
+ RawDoc() = default;
50
+ RawDoc(const RawDoc&) = default;
51
+ RawDoc(RawDoc&&) = default;
52
+
53
+ RawDoc(const RawDocKernel& o)
54
+ : RawDocKernel{ o }
55
+ {
56
+ }
57
+
58
+ template<typename _Ty>
59
+ const _Ty& getMisc(const std::string& name) const
60
+ {
61
+ auto it = misc.find(name);
62
+ if (it == misc.end()) throw std::invalid_argument{ "There is no value named `" + name + "` in misc data" };
63
+ if (!it->second.template is<_Ty>()) throw std::invalid_argument{ "Value named `" + name + "` is not in right type." };
64
+ return it->second.template get<_Ty>();
65
+ }
66
+
67
+ template<typename _Ty>
68
+ _Ty getMiscDefault(const std::string& name) const
69
+ {
70
+ auto it = misc.find(name);
71
+ if (it == misc.end()) return {};
72
+ if (!it->second.template is<_Ty>()) throw std::invalid_argument{ "Value named `" + name + "` is not in right type." };
73
+ return it->second.template get<_Ty>();
74
+ }
75
+ };
76
+
77
+ class DocumentBase : public RawDocKernel
78
+ {
79
+ public:
21
80
  tvector<Vid> words; // word id of each word
22
81
  std::vector<uint32_t> wOrder; // original word order (optional)
82
+
83
+ DocumentBase(const DocumentBase&) = default;
84
+ DocumentBase(DocumentBase&&) = default;
85
+
86
+ DocumentBase(const RawDocKernel& o)
87
+ : RawDocKernel{ o }
88
+ {
89
+ }
90
+
91
+ DocumentBase(Float _weight = 1, const SharedString& _docUid = {})
92
+ : RawDocKernel{ _weight, _docUid }
93
+ {
94
+ }
23
95
 
24
- std::string docUid;
25
- std::string rawStr;
26
- std::vector<uint32_t> origWordPos;
27
- std::vector<uint16_t> origWordLen;
28
- DocumentBase(Float _weight = 1) : weight(_weight) {}
29
96
  virtual ~DocumentBase() {}
30
97
 
98
+ virtual operator RawDoc() const
99
+ {
100
+ RawDoc raw{ *this };
101
+ if (wOrder.empty())
102
+ {
103
+ raw.words.insert(raw.words.begin(), words.begin(), words.end());
104
+ }
105
+ else
106
+ {
107
+ raw.words.resize(words.size());
108
+ for (size_t i = 0; i < words.size(); ++i)
109
+ {
110
+ raw.words[i] = words[wOrder[i]];
111
+ }
112
+ }
113
+ return raw;
114
+ }
115
+
31
116
  DEFINE_SERIALIZER_WITH_VERSION(0, serializer::to_key("Docu"), weight, words, wOrder);
32
117
  DEFINE_TAGGED_SERIALIZER_WITH_VERSION(1, 0x00010001, weight, words, wOrder,
33
118
  rawStr, origWordPos, origWordLen,
@@ -127,8 +212,20 @@ namespace tomoto
127
212
  virtual void loadModel(std::istream& reader,
128
213
  std::vector<uint8_t>* extra_data = nullptr) = 0;
129
214
  virtual const DocumentBase* getDoc(size_t docId) const = 0;
215
+ virtual size_t getDocIdByUid(const std::string& docUid) const = 0;
216
+
217
+ // it tokenizes rawDoc.rawStr to get words, pos and len of the document
218
+ virtual size_t addDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) = 0;
219
+ virtual std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) const = 0;
220
+
221
+ // it uses words, pos and len of rawDoc itself.
222
+ virtual size_t addDoc(const RawDoc& rawDoc) = 0;
223
+ virtual std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc) const = 0;
224
+
225
+ virtual bool updateVocab(const std::vector<std::string>& words) = 0;
130
226
 
131
- virtual void updateVocab(const std::vector<std::string>& words) = 0;
227
+ virtual double getDocLL(const DocumentBase* doc) const = 0;
228
+ virtual double getStateLL() const = 0;
132
229
 
133
230
  virtual double getLLPerWord() const = 0;
134
231
  virtual double getPerplexity() const = 0;
@@ -201,6 +298,7 @@ namespace tomoto
201
298
  std::vector<DocType> docs;
202
299
  std::vector<uint64_t> vocabCf;
203
300
  std::vector<uint64_t> vocabDf;
301
+ std::unordered_map<SharedString, size_t> uidMap;
204
302
  size_t globalStep = 0;
205
303
  _ModelState globalState, tState;
206
304
  Dictionary dict;
@@ -273,6 +371,8 @@ namespace tomoto
273
371
  >::value, size_t>::type _addDoc(_DocTy&& doc)
274
372
  {
275
373
  if (doc.words.empty()) return -1;
374
+ if (!doc.docUid.empty() && uidMap.count(doc.docUid))
375
+ throw exception::InvalidArgument{ "there is a document with uid = " + std::string{ doc.docUid } + " already." };
276
376
  size_t maxWid = *std::max_element(doc.words.begin(), doc.words.end());
277
377
  if (vocabCf.size() <= maxWid)
278
378
  {
@@ -282,47 +382,48 @@ namespace tomoto
282
382
  for (auto w : doc.words) ++vocabCf[w];
283
383
  std::unordered_set<Vid> uniq{ doc.words.begin(), doc.words.end() };
284
384
  for (auto w : uniq) ++vocabDf[w];
385
+ uidMap.emplace(doc.docUid, docs.size());
285
386
  docs.emplace_back(std::forward<_DocTy>(doc));
286
387
  return docs.size() - 1;
287
388
  }
288
389
 
289
390
  template<bool _const = false>
290
- DocType _makeDoc(const std::vector<std::string>& words, Float weight = 1)
391
+ DocType _makeFromRawDoc(const RawDoc& rawDoc)
291
392
  {
292
- DocType doc{ weight };
293
- for (auto& w : words)
393
+ DocType doc{ rawDoc };
394
+ if (!rawDoc.rawWords.empty())
294
395
  {
295
- Vid id;
296
- if (_const)
396
+ for (auto& w : rawDoc.rawWords)
297
397
  {
298
- id = dict.toWid(w);
299
- if (id == (Vid)-1) continue;
300
- }
301
- else
302
- {
303
- id = dict.add(w);
398
+ Vid id;
399
+ if (_const)
400
+ {
401
+ id = dict.toWid(w);
402
+ if (id == (Vid)-1) continue;
403
+ }
404
+ else
405
+ {
406
+ id = dict.add(w);
407
+ }
408
+ doc.words.emplace_back(id);
304
409
  }
305
- doc.words.emplace_back(id);
306
410
  }
307
- return doc;
308
- }
309
-
310
- DocType _makeRawDoc(const std::string& rawStr, const std::vector<Vid>& words,
311
- const std::vector<uint32_t>& pos, const std::vector<uint16_t>& len, Float weight = 1) const
312
- {
313
- DocType doc{ weight };
314
- doc.rawStr = rawStr;
315
- for (auto& w : words) doc.words.emplace_back(w);
316
- doc.origWordPos = pos;
317
- doc.origWordLen = len;
411
+ else if(!rawDoc.words.empty())
412
+ {
413
+ for (auto& w : rawDoc.words) doc.words.emplace_back(w);
414
+ }
415
+ else
416
+ {
417
+ throw std::invalid_argument{ "Either `words` or `rawWords` must be filled." };
418
+ }
318
419
  return doc;
319
420
  }
320
421
 
321
422
  template<bool _const, typename _FnTokenizer>
322
- DocType _makeRawDoc(const std::string& rawStr, _FnTokenizer&& tokenizer, Float weight = 1)
423
+ DocType _makeFromRawDoc(const RawDoc& rawDoc, _FnTokenizer&& tokenizer)
323
424
  {
324
- DocType doc{ weight };
325
- doc.rawStr = rawStr;
425
+ DocType doc{ rawDoc };
426
+ doc.rawStr = rawDoc.rawStr;
326
427
  for (auto& p : tokenizer(doc.rawStr))
327
428
  {
328
429
  Vid wid;
@@ -452,10 +553,11 @@ namespace tomoto
452
553
  return realV;
453
554
  }
454
555
 
455
- void updateVocab(const std::vector<std::string>& words) override
556
+ bool updateVocab(const std::vector<std::string>& words) override
456
557
  {
457
- if(dict.size()) THROW_ERROR_WITH_INFO(exception::InvalidArgument, "updateVocab after addDoc");
458
- for(auto& w : words) dict.add(w);
558
+ bool empty = dict.size() == 0;
559
+ for (auto& w : words) dict.add(w);
560
+ return empty;
459
561
  }
460
562
 
461
563
  void prepare(bool initDocs = true, size_t minWordCnt = 0, size_t minWordDf = 0, size_t removeTopN = 0) override
@@ -606,6 +708,18 @@ namespace tomoto
606
708
  return vid2String(getWidsByDocSorted(doc, topN));
607
709
  }
608
710
 
711
+ double getDocLL(const DocumentBase* doc) const override
712
+ {
713
+ auto* p = dynamic_cast<const DocType*>(doc);
714
+ if (!p) throw std::invalid_argument{ "wrong `doc` type." };
715
+ return static_cast<const _Derived*>(this)->getLLDocs(p, p + 1);
716
+ }
717
+
718
+ double getStateLL() const override
719
+ {
720
+ return static_cast<const _Derived*>(this)->getLLRest(this->globalState);
721
+ }
722
+
609
723
  std::vector<double> infer(const std::vector<DocumentBase*>& docs, size_t maxIter, Float tolerance, size_t numWorkers, ParallelScheme ps, bool together) const override
610
724
  {
611
725
  if (!numWorkers) numWorkers = std::thread::hardware_concurrency();
@@ -651,12 +765,18 @@ namespace tomoto
651
765
  return extractTopN<Tid>(getTopicsByDoc(doc), topN);
652
766
  }
653
767
 
654
-
655
768
  const DocumentBase* getDoc(size_t docId) const override
656
769
  {
657
770
  return &_getDoc(docId);
658
771
  }
659
772
 
773
+ size_t getDocIdByUid(const std::string& docUid) const override
774
+ {
775
+ auto it = uidMap.find(SharedString{ docUid });
776
+ if (it == uidMap.end()) return -1;
777
+ return it->second;
778
+ }
779
+
660
780
  size_t getGlobalStep() const override
661
781
  {
662
782
  return globalStep;