tomoto 0.2.2 → 0.2.3

Sign up to get free protection for your applications and to get access to all the features.
Files changed (45)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/ext/tomoto/ct.cpp +11 -11
  4. data/ext/tomoto/dmr.cpp +14 -13
  5. data/ext/tomoto/dt.cpp +14 -14
  6. data/ext/tomoto/ext.cpp +7 -7
  7. data/ext/tomoto/extconf.rb +1 -3
  8. data/ext/tomoto/gdmr.cpp +7 -7
  9. data/ext/tomoto/hdp.cpp +9 -9
  10. data/ext/tomoto/hlda.cpp +13 -13
  11. data/ext/tomoto/hpa.cpp +5 -5
  12. data/ext/tomoto/lda.cpp +42 -39
  13. data/ext/tomoto/llda.cpp +6 -6
  14. data/ext/tomoto/mglda.cpp +15 -15
  15. data/ext/tomoto/pa.cpp +6 -6
  16. data/ext/tomoto/plda.cpp +6 -6
  17. data/ext/tomoto/slda.cpp +8 -8
  18. data/ext/tomoto/utils.h +16 -70
  19. data/lib/tomoto/version.rb +1 -1
  20. data/vendor/tomotopy/README.kr.rst +57 -0
  21. data/vendor/tomotopy/README.rst +55 -0
  22. data/vendor/tomotopy/src/Labeling/Phraser.hpp +3 -3
  23. data/vendor/tomotopy/src/TopicModel/CTModel.hpp +5 -2
  24. data/vendor/tomotopy/src/TopicModel/DMRModel.hpp +5 -2
  25. data/vendor/tomotopy/src/TopicModel/DTModel.hpp +5 -2
  26. data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +4 -4
  27. data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +5 -2
  28. data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +2 -2
  29. data/vendor/tomotopy/src/TopicModel/LDA.h +3 -3
  30. data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +3 -3
  31. data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +34 -14
  32. data/vendor/tomotopy/src/TopicModel/LLDAModel.hpp +5 -2
  33. data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +2 -2
  34. data/vendor/tomotopy/src/TopicModel/PAModel.hpp +1 -1
  35. data/vendor/tomotopy/src/TopicModel/PLDAModel.hpp +5 -2
  36. data/vendor/tomotopy/src/TopicModel/PTModel.hpp +5 -2
  37. data/vendor/tomotopy/src/TopicModel/SLDAModel.hpp +4 -1
  38. data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +48 -21
  39. data/vendor/tomotopy/src/Utils/AliasMethod.hpp +5 -4
  40. data/vendor/tomotopy/src/Utils/Dictionary.h +2 -2
  41. data/vendor/tomotopy/src/Utils/MultiNormalDistribution.hpp +1 -1
  42. data/vendor/tomotopy/src/Utils/TruncMultiNormal.hpp +1 -1
  43. data/vendor/tomotopy/src/Utils/math.h +2 -2
  44. data/vendor/tomotopy/src/Utils/serializer.hpp +30 -5
  45. metadata +6 -6
@@ -335,7 +335,10 @@ namespace tomoto
335
335
  friend typename BaseClass::BaseClass;
336
336
  using WeightType = typename BaseClass::WeightType;
337
337
 
338
- static constexpr char TMID[] = "hLDA";
338
+ static constexpr auto tmid()
339
+ {
340
+ return serializer::to_key("hLDA");
341
+ }
339
342
 
340
343
  Float gamma;
341
344
 
@@ -422,7 +425,7 @@ namespace tomoto
422
425
  }
423
426
 
424
427
  template<int _inc>
425
- inline void addWordTo(_ModelState& ld, _DocType& doc, uint32_t pid, Vid vid, Tid level) const
428
+ inline void addWordTo(_ModelState& ld, _DocType& doc, size_t pid, Vid vid, Tid level) const
426
429
  {
427
430
  assert(vid < this->realV);
428
431
  constexpr bool _dec = _inc < 0 && _tw != TermWeight::one;
@@ -143,7 +143,7 @@ namespace tomoto
143
143
  }
144
144
 
145
145
  template<int _inc>
146
- inline void addWordTo(_ModelState& ld, _DocType& doc, uint32_t pid, Vid vid, Tid z1, Tid z2) const
146
+ inline void addWordTo(_ModelState& ld, _DocType& doc, size_t pid, Vid vid, Tid z1, Tid z2) const
147
147
  {
148
148
  assert(vid < this->realV);
149
149
  constexpr bool _dec = _inc < 0 && _tw != TermWeight::one;
@@ -540,7 +540,7 @@ namespace tomoto
540
540
  return ret;
541
541
  }
542
542
 
543
- std::vector<Float> getTopicsByDoc(const _DocType& doc, bool normalize) const
543
+ std::vector<Float> _getTopicsByDoc(const _DocType& doc, bool normalize) const
544
544
  {
545
545
  std::vector<Float> ret(1 + this->K + K2);
546
546
  Float sum = doc.getSumWordWeight() + this->alphas.sum();
@@ -121,7 +121,7 @@ namespace tomoto
121
121
 
122
122
  void updateSumWordWeight(size_t realV)
123
123
  {
124
- sumWordWeight = std::count_if(static_cast<_Base*>(this)->words.begin(), static_cast<_Base*>(this)->words.end(), [realV](Vid w)
124
+ sumWordWeight = (int32_t)std::count_if(static_cast<_Base*>(this)->words.begin(), static_cast<_Base*>(this)->words.end(), [realV](Vid w)
125
125
  {
126
126
  return w < realV;
127
127
  });
@@ -164,8 +164,8 @@ namespace tomoto
164
164
  struct LDAArgs
165
165
  {
166
166
  size_t k = 1;
167
- std::vector<Float> alpha = { 0.1 };
168
- Float eta = 0.01;
167
+ std::vector<Float> alpha = { (Float)0.1 };
168
+ Float eta = (Float)0.01;
169
169
  size_t seed = std::random_device{}();
170
170
  };
171
171
 
@@ -82,7 +82,7 @@ namespace tomoto
82
82
  friend BaseClass;
83
83
 
84
84
  static constexpr const char TWID[] = "one\0";
85
- static constexpr static constexpr char TMID[] = "LDA\0";
85
+ static constexpr const char TMID[] = "LDA\0";
86
86
 
87
87
  Float alpha;
88
88
  Vector alphas;
@@ -125,7 +125,7 @@ namespace tomoto
125
125
  }
126
126
 
127
127
  template<int _Inc, typename _Vec>
128
- inline void addWordTo(_ModelState& ld, _DocType& doc, uint32_t pid, Vid vid, _Vec tDist) const
128
+ inline void addWordTo(_ModelState& ld, _DocType& doc, size_t pid, Vid vid, _Vec tDist) const
129
129
  {
130
130
  assert(vid < this->realV);
131
131
  constexpr bool _dec = _Inc < 0;
@@ -392,7 +392,7 @@ namespace tomoto
392
392
  return static_cast<const DerivedClass*>(this)->_getTopicsCount();
393
393
  }
394
394
 
395
- std::vector<Float> getTopicsByDoc(const _DocType& doc) const
395
+ std::vector<Float> _getTopicsByDoc(const _DocType& doc) const
396
396
  {
397
397
  std::vector<Float> ret(K);
398
398
  Float sum = doc.getSumWordWeight() + K * alpha;
@@ -117,19 +117,28 @@ namespace tomoto
117
117
  template<>
118
118
  struct TwId<TermWeight::one>
119
119
  {
120
- static constexpr char TWID[] = "one\0";
120
+ static constexpr auto twid()
121
+ {
122
+ return serializer::to_key("one\0");
123
+ }
121
124
  };
122
125
 
123
126
  template<>
124
127
  struct TwId<TermWeight::idf>
125
128
  {
126
- static constexpr char TWID[] = "idf\0";
129
+ static constexpr auto twid()
130
+ {
131
+ return serializer::to_key("idf\0");
132
+ }
127
133
  };
128
134
 
129
135
  template<>
130
136
  struct TwId<TermWeight::pmi>
131
137
  {
132
- static constexpr char TWID[] = "pmi\0";
138
+ static constexpr auto twid()
139
+ {
140
+ return serializer::to_key("pmi\0");
141
+ }
133
142
  };
134
143
 
135
144
  // to make HDP friend of LDA for HDPModel::converToLDA
@@ -169,7 +178,11 @@ namespace tomoto
169
178
  typename>
170
179
  friend class HDPModel;
171
180
 
172
- static constexpr char TMID[] = "LDA\0";
181
+ static constexpr auto tmid()
182
+ {
183
+ return serializer::to_key("LDA\0");
184
+ }
185
+
173
186
  using WeightType = typename std::conditional<_tw == TermWeight::one, int32_t, float>::type;
174
187
 
175
188
  enum { m_flags = _Flags };
@@ -189,7 +202,7 @@ namespace tomoto
189
202
  struct ExtraDocData
190
203
  {
191
204
  std::vector<Vid> vChunkOffset;
192
- Eigen::Matrix<uint32_t, -1, -1> chunkOffsetByDoc;
205
+ Eigen::Matrix<size_t, -1, -1> chunkOffsetByDoc;
193
206
  };
194
207
 
195
208
  ExtraDocData eddTrain;
@@ -261,7 +274,7 @@ namespace tomoto
261
274
  }
262
275
 
263
276
  template<int _inc>
264
- inline void addWordTo(_ModelState& ld, _DocType& doc, uint32_t pid, Vid vid, Tid tid) const
277
+ inline void addWordTo(_ModelState& ld, _DocType& doc, size_t pid, Vid vid, Tid tid) const
265
278
  {
266
279
  assert(tid < K);
267
280
  assert(vid < this->realV);
@@ -620,7 +633,7 @@ namespace tomoto
620
633
  for (Vid v = 0; v < V; ++v)
621
634
  {
622
635
  if (!ld.numByTopicWord(k, v)) continue;
623
- ll += math::lgammaT(ld.numByTopicWord(k, v) + etaByTopicWord(v, k)) - math::lgammaT(etaByTopicWord(v, k));
636
+ ll += math::lgammaT(ld.numByTopicWord(k, v) + etaByTopicWord(k, v)) - math::lgammaT(etaByTopicWord(k, v));
624
637
  assert(std::isfinite(ll));
625
638
  }
626
639
  }
@@ -972,12 +985,14 @@ namespace tomoto
972
985
 
973
986
  void setOptimInterval(size_t _optimInterval) override
974
987
  {
975
- optimInterval = _optimInterval;
988
+ if (_optimInterval > 0x7FFFFFFF) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "wrong value");
989
+ optimInterval = (uint32_t)_optimInterval;
976
990
  }
977
991
 
978
992
  void setBurnInIteration(size_t iteration) override
979
993
  {
980
- burnIn = iteration;
994
+ if (iteration > 0x7FFFFFFF) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "wrong value");
995
+ burnIn = (uint32_t)iteration;
981
996
  }
982
997
 
983
998
  size_t addDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) override
@@ -1008,6 +1023,11 @@ namespace tomoto
1008
1023
  if (p < 0) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "priors must not be less than 0.");
1009
1024
  }
1010
1025
  this->dict.add(word);
1026
+ if (this->dict.size() > this->vocabCf.size())
1027
+ {
1028
+ this->vocabCf.resize(this->dict.size());
1029
+ this->vocabDf.resize(this->dict.size());
1030
+ }
1011
1031
  etaByWord.emplace(word, priors);
1012
1032
  }
1013
1033
 
@@ -1049,7 +1069,7 @@ namespace tomoto
1049
1069
  if (initDocs)
1050
1070
  {
1051
1071
  std::vector<uint32_t> df, cf, tf;
1052
- uint32_t totCf;
1072
+ size_t totCf;
1053
1073
 
1054
1074
  // calculate weighting
1055
1075
  if (_tw != TermWeight::one)
@@ -1064,14 +1084,14 @@ namespace tomoto
1064
1084
  ++df[w];
1065
1085
  }
1066
1086
  }
1067
- totCf = accumulate(this->vocabCf.begin(), this->vocabCf.end(), 0);
1087
+ totCf = std::accumulate(this->vocabCf.begin(), this->vocabCf.end(), 0);
1068
1088
  }
1069
1089
  if (_tw == TermWeight::idf)
1070
1090
  {
1071
1091
  vocabWeights.resize(V);
1072
1092
  for (size_t i = 0; i < V; ++i)
1073
1093
  {
1074
- vocabWeights[i] = log(this->docs.size() / (Float)df[i]);
1094
+ vocabWeights[i] = (Float)log(this->docs.size() / (double)df[i]);
1075
1095
  }
1076
1096
  }
1077
1097
  else if (_tw == TermWeight::pmi)
@@ -1079,7 +1099,7 @@ namespace tomoto
1079
1099
  vocabWeights.resize(V);
1080
1100
  for (size_t i = 0; i < V; ++i)
1081
1101
  {
1082
- vocabWeights[i] = this->vocabCf[i] / (float)totCf;
1102
+ vocabWeights[i] = (Float)(this->vocabCf[i] / (double)totCf);
1083
1103
  }
1084
1104
  }
1085
1105
 
@@ -1104,7 +1124,7 @@ namespace tomoto
1104
1124
  return static_cast<const DerivedClass*>(this)->_getTopicsCount();
1105
1125
  }
1106
1126
 
1107
- std::vector<Float> getTopicsByDoc(const _DocType& doc, bool normalize) const
1127
+ std::vector<Float> _getTopicsByDoc(const _DocType& doc, bool normalize) const
1108
1128
  {
1109
1129
  std::vector<Float> ret(K);
1110
1130
  Eigen::Map<Eigen::Array<Float, -1, 1>> m{ ret.data(), K };
@@ -26,7 +26,10 @@ namespace tomoto
26
26
  friend typename BaseClass::BaseClass;
27
27
  using WeightType = typename BaseClass::WeightType;
28
28
 
29
- static constexpr char TMID[] = "LLDA";
29
+ static constexpr auto tmid()
30
+ {
31
+ return serializer::to_key("LLDA");
32
+ }
30
33
 
31
34
  Dictionary topicLabelDict;
32
35
 
@@ -171,7 +174,7 @@ namespace tomoto
171
174
  return std::make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
172
175
  }
173
176
 
174
- std::vector<Float> getTopicsByDoc(const _DocType& doc, bool normalize) const
177
+ std::vector<Float> _getTopicsByDoc(const _DocType& doc, bool normalize) const
175
178
  {
176
179
  std::vector<Float> ret(this->K);
177
180
  auto maskedAlphas = this->alphas.array() * doc.labelMask.template cast<Float>().array();
@@ -63,7 +63,7 @@ namespace tomoto
63
63
  }
64
64
 
65
65
  template<int _inc>
66
- inline void addWordTo(_ModelState& ld, _DocType& doc, uint32_t pid, Vid vid, Tid tid, uint16_t s, uint8_t w, uint8_t r) const
66
+ inline void addWordTo(_ModelState& ld, _DocType& doc, size_t pid, Vid vid, Tid tid, uint16_t s, uint8_t w, uint8_t r) const
67
67
  {
68
68
  const auto K = this->K;
69
69
 
@@ -527,7 +527,7 @@ namespace tomoto
527
527
  this->etaByWord.emplace(word, priors);
528
528
  }
529
529
 
530
- std::vector<Float> getTopicsByDoc(const _DocType& doc, bool normalize) const
530
+ std::vector<Float> _getTopicsByDoc(const _DocType& doc, bool normalize) const
531
531
  {
532
532
  std::vector<Float> ret(this->K + KL);
533
533
  Eigen::Map<Eigen::Array<Float, -1, 1>> m{ ret.data(), this->K + KL };
@@ -90,7 +90,7 @@ namespace tomoto
90
90
  }
91
91
 
92
92
  template<int _inc>
93
- inline void addWordTo(_ModelState& ld, _DocType& doc, uint32_t pid, Vid vid, Tid z1, Tid z2) const
93
+ inline void addWordTo(_ModelState& ld, _DocType& doc, size_t pid, Vid vid, Tid z1, Tid z2) const
94
94
  {
95
95
  assert(vid < this->realV);
96
96
  constexpr bool _dec = _inc < 0 && _tw != TermWeight::one;
@@ -26,7 +26,10 @@ namespace tomoto
26
26
  friend typename BaseClass::BaseClass;
27
27
  using WeightType = typename BaseClass::WeightType;
28
28
 
29
- static constexpr char TMID[] = "PLDA";
29
+ static constexpr auto tmid()
30
+ {
31
+ return serializer::to_key("PLDA");
32
+ }
30
33
 
31
34
  Dictionary topicLabelDict;
32
35
 
@@ -178,7 +181,7 @@ namespace tomoto
178
181
  return std::make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
179
182
  }
180
183
 
181
- std::vector<Float> getTopicsByDoc(const _DocType& doc, bool normalize) const
184
+ std::vector<Float> _getTopicsByDoc(const _DocType& doc, bool normalize) const
182
185
  {
183
186
  std::vector<Float> ret(this->K);
184
187
  auto maskedAlphas = this->alphas.array() * doc.labelMask.template cast<Float>().array();
@@ -38,7 +38,10 @@ namespace tomoto
38
38
  friend typename BaseClass::BaseClass;
39
39
  using WeightType = typename BaseClass::WeightType;
40
40
 
41
- static constexpr char TMID[] = "PTM";
41
+ static constexpr auto tmid()
42
+ {
43
+ return serializer::to_key("PTM");
44
+ }
42
45
 
43
46
  uint64_t numPDocs;
44
47
  Float lambda;
@@ -261,7 +264,7 @@ namespace tomoto
261
264
  {
262
265
  }
263
266
 
264
- std::vector<Float> getTopicsByDoc(const _DocType& doc, bool normalize) const
267
+ std::vector<Float> _getTopicsByDoc(const _DocType& doc, bool normalize) const
265
268
  {
266
269
  std::vector<Float> ret(this->K);
267
270
  Eigen::Map<Eigen::Array<Float, -1, 1>> m{ ret.data(), this->K };
@@ -216,7 +216,10 @@ namespace tomoto
216
216
  friend typename BaseClass::BaseClass;
217
217
  using WeightType = typename BaseClass::WeightType;
218
218
 
219
- static constexpr char TMID[] = "SLDA";
219
+ static constexpr auto tmid()
220
+ {
221
+ return serializer::to_key("SLDA");
222
+ }
220
223
 
221
224
  uint64_t F; // number of response variables
222
225
  std::vector<ISLDAModel::GLM> varTypes;
@@ -249,6 +249,7 @@ namespace tomoto
249
249
  virtual size_t getNumDocs() const = 0;
250
250
  virtual const Dictionary& getVocabDict() const = 0;
251
251
  virtual const std::vector<uint64_t>& getVocabCf() const = 0;
252
+ virtual std::vector<double> getVocabWeightedCf() const = 0;
252
253
  virtual const std::vector<uint64_t>& getVocabDf() const = 0;
253
254
 
254
255
  virtual int train(size_t iteration, size_t numWorkers, ParallelScheme ps = ParallelScheme::default_, bool freeze_topics = false) = 0;
@@ -319,6 +320,7 @@ namespace tomoto
319
320
  Dictionary dict;
320
321
  uint64_t realV = 0; // vocab size after removing stopwords
321
322
  uint64_t realN = 0; // total word size after removing stopwords
323
+ double weightedN = 0;
322
324
  size_t maxThreads[(size_t)ParallelScheme::size] = { 0, };
323
325
  size_t minWordCf = 0, minWordDf = 0, removeTopN = 0;
324
326
 
@@ -327,15 +329,17 @@ namespace tomoto
327
329
  void _saveModel(std::ostream& writer, bool fullModel, const std::vector<uint8_t>* extra_data) const
328
330
  {
329
331
  serializer::writeMany(writer,
330
- serializer::to_keyz(static_cast<const _Derived*>(this)->TMID),
331
- serializer::to_keyz(static_cast<const _Derived*>(this)->TWID));
332
+ serializer::to_keyz(static_cast<const _Derived*>(this)->tmid()),
333
+ serializer::to_keyz(static_cast<const _Derived*>(this)->twid())
334
+ );
332
335
  serializer::writeTaggedMany(writer, 0x00010001,
333
336
  serializer::to_keyz("dict"), dict,
334
337
  serializer::to_keyz("vocabCf"), vocabCf,
335
338
  serializer::to_keyz("vocabDf"), vocabDf,
336
339
  serializer::to_keyz("realV"), realV,
337
340
  serializer::to_keyz("globalStep"), globalStep,
338
- serializer::to_keyz("extra"), extra_data ? *extra_data : std::vector<uint8_t>(0));
341
+ serializer::to_keyz("extra"), extra_data ? *extra_data : std::vector<uint8_t>(0)
342
+ );
339
343
  serializer::writeMany(writer, *static_cast<const _Derived*>(this));
340
344
  globalState.serializerWrite(writer);
341
345
  if (fullModel)
@@ -355,8 +359,9 @@ namespace tomoto
355
359
  {
356
360
  std::vector<uint8_t> extra;
357
361
  serializer::readMany(reader,
358
- serializer::to_keyz(static_cast<_Derived*>(this)->TMID),
359
- serializer::to_keyz(static_cast<_Derived*>(this)->TWID));
362
+ serializer::to_keyz(static_cast<_Derived*>(this)->tmid()),
363
+ serializer::to_keyz(static_cast<_Derived*>(this)->twid())
364
+ );
360
365
  serializer::readTaggedMany(reader, 0x00010001,
361
366
  serializer::to_keyz("dict"), dict,
362
367
  serializer::to_keyz("vocabCf"), vocabCf,
@@ -370,14 +375,17 @@ namespace tomoto
370
375
  {
371
376
  reader.seekg(start_pos);
372
377
  serializer::readMany(reader,
373
- serializer::to_key(static_cast<_Derived*>(this)->TMID),
374
- serializer::to_key(static_cast<_Derived*>(this)->TWID),
375
- dict, vocabCf, realV);
378
+ serializer::to_key(static_cast<_Derived*>(this)->tmid()),
379
+ serializer::to_key(static_cast<_Derived*>(this)->twid()),
380
+ dict, vocabCf, realV
381
+ );
376
382
  }
377
383
  serializer::readMany(reader, *static_cast<_Derived*>(this));
378
384
  globalState.serializerRead(reader);
379
385
  serializer::readMany(reader, docs);
380
- realN = countRealN();
386
+ auto p = countRealN();
387
+ realN = p.first;
388
+ weightedN = p.second;
381
389
  }
382
390
 
383
391
  template<typename _DocTy>
@@ -490,17 +498,23 @@ namespace tomoto
490
498
  }
491
499
  }
492
500
 
493
- size_t countRealN() const
501
+ std::pair<size_t, double> countRealN() const
494
502
  {
495
503
  size_t n = 0;
504
+ double weighted = 0;
496
505
  for (auto& doc : docs)
497
506
  {
498
- for (auto& w : doc.words)
507
+ for (size_t i = 0; i < doc.words.size(); ++i)
499
508
  {
500
- if (w < realV) ++n;
509
+ auto w = doc.words[i];
510
+ if (w < realV)
511
+ {
512
+ ++n;
513
+ weighted += doc.wordWeights.empty() ? 1 : doc.wordWeights[i];
514
+ }
501
515
  }
502
516
  }
503
- return n;
517
+ return std::make_pair(n, weighted);
504
518
  }
505
519
 
506
520
  void removeStopwords(size_t minWordCnt, size_t minWordDf, size_t removeTopN)
@@ -544,14 +558,9 @@ namespace tomoto
544
558
  }
545
559
 
546
560
  dict.reorder(order);
547
- realN = 0;
548
561
  for (auto& doc : docs)
549
562
  {
550
- for (auto& w : doc.words)
551
- {
552
- w = order[w];
553
- if (w < realV) ++realN;
554
- }
563
+ for (auto& w : doc.words) w = order[w];
555
564
  }
556
565
  }
557
566
 
@@ -598,6 +607,10 @@ namespace tomoto
598
607
 
599
608
  void prepare(bool initDocs = true, size_t minWordCnt = 0, size_t minWordDf = 0, size_t removeTopN = 0) override
600
609
  {
610
+ auto p = countRealN();
611
+ realN = p.first;
612
+ weightedN = p.second;
613
+
601
614
  maxThreads[(size_t)ParallelScheme::default_] = -1;
602
615
  maxThreads[(size_t)ParallelScheme::none] = -1;
603
616
  maxThreads[(size_t)ParallelScheme::copy_merge] = static_cast<_Derived*>(this)->template estimateMaxThreads<ParallelScheme::copy_merge>();
@@ -697,7 +710,7 @@ namespace tomoto
697
710
 
698
711
  double getLLPerWord() const override
699
712
  {
700
- return words.empty() ? 0 : static_cast<const _Derived*>(this)->getLL() / realN;
713
+ return words.empty() ? 0 : static_cast<const _Derived*>(this)->getLL() / weightedN;
701
714
  }
702
715
 
703
716
  double getPerplexity() const override
@@ -797,7 +810,7 @@ namespace tomoto
797
810
 
798
811
  std::vector<Float> getTopicsByDoc(const DocumentBase* doc, bool normalize) const override
799
812
  {
800
- return static_cast<const _Derived*>(this)->getTopicsByDoc(*static_cast<const DocType*>(doc), normalize);
813
+ return static_cast<const _Derived*>(this)->_getTopicsByDoc(*static_cast<const DocType*>(doc), normalize);
801
814
  }
802
815
 
803
816
  std::vector<std::pair<Tid, Float>> getTopicsByDocSorted(const DocumentBase* doc, size_t topN) const override
@@ -832,6 +845,20 @@ namespace tomoto
832
845
  return vocabCf;
833
846
  }
834
847
 
848
+ std::vector<double> getVocabWeightedCf() const override
849
+ {
850
+ std::vector<double> ret(realV);
851
+ for (auto& doc : docs)
852
+ {
853
+ for (size_t i = 0; i < doc.words.size(); ++i)
854
+ {
855
+ if (doc.words[i] >= realV) continue;
856
+ ret[doc.words[i]] += doc.wordWeights.empty() ? 1 : doc.wordWeights[i];
857
+ }
858
+ }
859
+ return ret;
860
+ }
861
+
835
862
  const std::vector<uint64_t>& getVocabDf() const override
836
863
  {
837
864
  return vocabDf;