tomoto 0.2.2 → 0.2.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/ext/tomoto/ct.cpp +11 -11
  4. data/ext/tomoto/dmr.cpp +14 -13
  5. data/ext/tomoto/dt.cpp +14 -14
  6. data/ext/tomoto/ext.cpp +7 -7
  7. data/ext/tomoto/extconf.rb +1 -3
  8. data/ext/tomoto/gdmr.cpp +7 -7
  9. data/ext/tomoto/hdp.cpp +9 -9
  10. data/ext/tomoto/hlda.cpp +13 -13
  11. data/ext/tomoto/hpa.cpp +5 -5
  12. data/ext/tomoto/lda.cpp +42 -39
  13. data/ext/tomoto/llda.cpp +6 -6
  14. data/ext/tomoto/mglda.cpp +15 -15
  15. data/ext/tomoto/pa.cpp +6 -6
  16. data/ext/tomoto/plda.cpp +6 -6
  17. data/ext/tomoto/slda.cpp +8 -8
  18. data/ext/tomoto/utils.h +16 -70
  19. data/lib/tomoto/version.rb +1 -1
  20. data/vendor/tomotopy/README.kr.rst +57 -0
  21. data/vendor/tomotopy/README.rst +55 -0
  22. data/vendor/tomotopy/src/Labeling/Phraser.hpp +3 -3
  23. data/vendor/tomotopy/src/TopicModel/CTModel.hpp +5 -2
  24. data/vendor/tomotopy/src/TopicModel/DMRModel.hpp +5 -2
  25. data/vendor/tomotopy/src/TopicModel/DTModel.hpp +5 -2
  26. data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +4 -4
  27. data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +5 -2
  28. data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +2 -2
  29. data/vendor/tomotopy/src/TopicModel/LDA.h +3 -3
  30. data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +3 -3
  31. data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +34 -14
  32. data/vendor/tomotopy/src/TopicModel/LLDAModel.hpp +5 -2
  33. data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +2 -2
  34. data/vendor/tomotopy/src/TopicModel/PAModel.hpp +1 -1
  35. data/vendor/tomotopy/src/TopicModel/PLDAModel.hpp +5 -2
  36. data/vendor/tomotopy/src/TopicModel/PTModel.hpp +5 -2
  37. data/vendor/tomotopy/src/TopicModel/SLDAModel.hpp +4 -1
  38. data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +48 -21
  39. data/vendor/tomotopy/src/Utils/AliasMethod.hpp +5 -4
  40. data/vendor/tomotopy/src/Utils/Dictionary.h +2 -2
  41. data/vendor/tomotopy/src/Utils/MultiNormalDistribution.hpp +1 -1
  42. data/vendor/tomotopy/src/Utils/TruncMultiNormal.hpp +1 -1
  43. data/vendor/tomotopy/src/Utils/math.h +2 -2
  44. data/vendor/tomotopy/src/Utils/serializer.hpp +30 -5
  45. metadata +6 -6
@@ -335,7 +335,10 @@ namespace tomoto
335
335
  friend typename BaseClass::BaseClass;
336
336
  using WeightType = typename BaseClass::WeightType;
337
337
 
338
- static constexpr char TMID[] = "hLDA";
338
+ static constexpr auto tmid()
339
+ {
340
+ return serializer::to_key("hLDA");
341
+ }
339
342
 
340
343
  Float gamma;
341
344
 
@@ -422,7 +425,7 @@ namespace tomoto
422
425
  }
423
426
 
424
427
  template<int _inc>
425
- inline void addWordTo(_ModelState& ld, _DocType& doc, uint32_t pid, Vid vid, Tid level) const
428
+ inline void addWordTo(_ModelState& ld, _DocType& doc, size_t pid, Vid vid, Tid level) const
426
429
  {
427
430
  assert(vid < this->realV);
428
431
  constexpr bool _dec = _inc < 0 && _tw != TermWeight::one;
@@ -143,7 +143,7 @@ namespace tomoto
143
143
  }
144
144
 
145
145
  template<int _inc>
146
- inline void addWordTo(_ModelState& ld, _DocType& doc, uint32_t pid, Vid vid, Tid z1, Tid z2) const
146
+ inline void addWordTo(_ModelState& ld, _DocType& doc, size_t pid, Vid vid, Tid z1, Tid z2) const
147
147
  {
148
148
  assert(vid < this->realV);
149
149
  constexpr bool _dec = _inc < 0 && _tw != TermWeight::one;
@@ -540,7 +540,7 @@ namespace tomoto
540
540
  return ret;
541
541
  }
542
542
 
543
- std::vector<Float> getTopicsByDoc(const _DocType& doc, bool normalize) const
543
+ std::vector<Float> _getTopicsByDoc(const _DocType& doc, bool normalize) const
544
544
  {
545
545
  std::vector<Float> ret(1 + this->K + K2);
546
546
  Float sum = doc.getSumWordWeight() + this->alphas.sum();
@@ -121,7 +121,7 @@ namespace tomoto
121
121
 
122
122
  void updateSumWordWeight(size_t realV)
123
123
  {
124
- sumWordWeight = std::count_if(static_cast<_Base*>(this)->words.begin(), static_cast<_Base*>(this)->words.end(), [realV](Vid w)
124
+ sumWordWeight = (int32_t)std::count_if(static_cast<_Base*>(this)->words.begin(), static_cast<_Base*>(this)->words.end(), [realV](Vid w)
125
125
  {
126
126
  return w < realV;
127
127
  });
@@ -164,8 +164,8 @@ namespace tomoto
164
164
  struct LDAArgs
165
165
  {
166
166
  size_t k = 1;
167
- std::vector<Float> alpha = { 0.1 };
168
- Float eta = 0.01;
167
+ std::vector<Float> alpha = { (Float)0.1 };
168
+ Float eta = (Float)0.01;
169
169
  size_t seed = std::random_device{}();
170
170
  };
171
171
 
@@ -82,7 +82,7 @@ namespace tomoto
82
82
  friend BaseClass;
83
83
 
84
84
  static constexpr const char TWID[] = "one\0";
85
- static constexpr static constexpr char TMID[] = "LDA\0";
85
+ static constexpr const char TMID[] = "LDA\0";
86
86
 
87
87
  Float alpha;
88
88
  Vector alphas;
@@ -125,7 +125,7 @@ namespace tomoto
125
125
  }
126
126
 
127
127
  template<int _Inc, typename _Vec>
128
- inline void addWordTo(_ModelState& ld, _DocType& doc, uint32_t pid, Vid vid, _Vec tDist) const
128
+ inline void addWordTo(_ModelState& ld, _DocType& doc, size_t pid, Vid vid, _Vec tDist) const
129
129
  {
130
130
  assert(vid < this->realV);
131
131
  constexpr bool _dec = _Inc < 0;
@@ -392,7 +392,7 @@ namespace tomoto
392
392
  return static_cast<const DerivedClass*>(this)->_getTopicsCount();
393
393
  }
394
394
 
395
- std::vector<Float> getTopicsByDoc(const _DocType& doc) const
395
+ std::vector<Float> _getTopicsByDoc(const _DocType& doc) const
396
396
  {
397
397
  std::vector<Float> ret(K);
398
398
  Float sum = doc.getSumWordWeight() + K * alpha;
@@ -117,19 +117,28 @@ namespace tomoto
117
117
  template<>
118
118
  struct TwId<TermWeight::one>
119
119
  {
120
- static constexpr char TWID[] = "one\0";
120
+ static constexpr auto twid()
121
+ {
122
+ return serializer::to_key("one\0");
123
+ }
121
124
  };
122
125
 
123
126
  template<>
124
127
  struct TwId<TermWeight::idf>
125
128
  {
126
- static constexpr char TWID[] = "idf\0";
129
+ static constexpr auto twid()
130
+ {
131
+ return serializer::to_key("idf\0");
132
+ }
127
133
  };
128
134
 
129
135
  template<>
130
136
  struct TwId<TermWeight::pmi>
131
137
  {
132
- static constexpr char TWID[] = "pmi\0";
138
+ static constexpr auto twid()
139
+ {
140
+ return serializer::to_key("pmi\0");
141
+ }
133
142
  };
134
143
 
135
144
  // to make HDP friend of LDA for HDPModel::converToLDA
@@ -169,7 +178,11 @@ namespace tomoto
169
178
  typename>
170
179
  friend class HDPModel;
171
180
 
172
- static constexpr char TMID[] = "LDA\0";
181
+ static constexpr auto tmid()
182
+ {
183
+ return serializer::to_key("LDA\0");
184
+ }
185
+
173
186
  using WeightType = typename std::conditional<_tw == TermWeight::one, int32_t, float>::type;
174
187
 
175
188
  enum { m_flags = _Flags };
@@ -189,7 +202,7 @@ namespace tomoto
189
202
  struct ExtraDocData
190
203
  {
191
204
  std::vector<Vid> vChunkOffset;
192
- Eigen::Matrix<uint32_t, -1, -1> chunkOffsetByDoc;
205
+ Eigen::Matrix<size_t, -1, -1> chunkOffsetByDoc;
193
206
  };
194
207
 
195
208
  ExtraDocData eddTrain;
@@ -261,7 +274,7 @@ namespace tomoto
261
274
  }
262
275
 
263
276
  template<int _inc>
264
- inline void addWordTo(_ModelState& ld, _DocType& doc, uint32_t pid, Vid vid, Tid tid) const
277
+ inline void addWordTo(_ModelState& ld, _DocType& doc, size_t pid, Vid vid, Tid tid) const
265
278
  {
266
279
  assert(tid < K);
267
280
  assert(vid < this->realV);
@@ -620,7 +633,7 @@ namespace tomoto
620
633
  for (Vid v = 0; v < V; ++v)
621
634
  {
622
635
  if (!ld.numByTopicWord(k, v)) continue;
623
- ll += math::lgammaT(ld.numByTopicWord(k, v) + etaByTopicWord(v, k)) - math::lgammaT(etaByTopicWord(v, k));
636
+ ll += math::lgammaT(ld.numByTopicWord(k, v) + etaByTopicWord(k, v)) - math::lgammaT(etaByTopicWord(k, v));
624
637
  assert(std::isfinite(ll));
625
638
  }
626
639
  }
@@ -972,12 +985,14 @@ namespace tomoto
972
985
 
973
986
  void setOptimInterval(size_t _optimInterval) override
974
987
  {
975
- optimInterval = _optimInterval;
988
+ if (_optimInterval > 0x7FFFFFFF) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "wrong value");
989
+ optimInterval = (uint32_t)_optimInterval;
976
990
  }
977
991
 
978
992
  void setBurnInIteration(size_t iteration) override
979
993
  {
980
- burnIn = iteration;
994
+ if (iteration > 0x7FFFFFFF) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "wrong value");
995
+ burnIn = (uint32_t)iteration;
981
996
  }
982
997
 
983
998
  size_t addDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) override
@@ -1008,6 +1023,11 @@ namespace tomoto
1008
1023
  if (p < 0) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "priors must not be less than 0.");
1009
1024
  }
1010
1025
  this->dict.add(word);
1026
+ if (this->dict.size() > this->vocabCf.size())
1027
+ {
1028
+ this->vocabCf.resize(this->dict.size());
1029
+ this->vocabDf.resize(this->dict.size());
1030
+ }
1011
1031
  etaByWord.emplace(word, priors);
1012
1032
  }
1013
1033
 
@@ -1049,7 +1069,7 @@ namespace tomoto
1049
1069
  if (initDocs)
1050
1070
  {
1051
1071
  std::vector<uint32_t> df, cf, tf;
1052
- uint32_t totCf;
1072
+ size_t totCf;
1053
1073
 
1054
1074
  // calculate weighting
1055
1075
  if (_tw != TermWeight::one)
@@ -1064,14 +1084,14 @@ namespace tomoto
1064
1084
  ++df[w];
1065
1085
  }
1066
1086
  }
1067
- totCf = accumulate(this->vocabCf.begin(), this->vocabCf.end(), 0);
1087
+ totCf = std::accumulate(this->vocabCf.begin(), this->vocabCf.end(), 0);
1068
1088
  }
1069
1089
  if (_tw == TermWeight::idf)
1070
1090
  {
1071
1091
  vocabWeights.resize(V);
1072
1092
  for (size_t i = 0; i < V; ++i)
1073
1093
  {
1074
- vocabWeights[i] = log(this->docs.size() / (Float)df[i]);
1094
+ vocabWeights[i] = (Float)log(this->docs.size() / (double)df[i]);
1075
1095
  }
1076
1096
  }
1077
1097
  else if (_tw == TermWeight::pmi)
@@ -1079,7 +1099,7 @@ namespace tomoto
1079
1099
  vocabWeights.resize(V);
1080
1100
  for (size_t i = 0; i < V; ++i)
1081
1101
  {
1082
- vocabWeights[i] = this->vocabCf[i] / (float)totCf;
1102
+ vocabWeights[i] = (Float)(this->vocabCf[i] / (double)totCf);
1083
1103
  }
1084
1104
  }
1085
1105
 
@@ -1104,7 +1124,7 @@ namespace tomoto
1104
1124
  return static_cast<const DerivedClass*>(this)->_getTopicsCount();
1105
1125
  }
1106
1126
 
1107
- std::vector<Float> getTopicsByDoc(const _DocType& doc, bool normalize) const
1127
+ std::vector<Float> _getTopicsByDoc(const _DocType& doc, bool normalize) const
1108
1128
  {
1109
1129
  std::vector<Float> ret(K);
1110
1130
  Eigen::Map<Eigen::Array<Float, -1, 1>> m{ ret.data(), K };
@@ -26,7 +26,10 @@ namespace tomoto
26
26
  friend typename BaseClass::BaseClass;
27
27
  using WeightType = typename BaseClass::WeightType;
28
28
 
29
- static constexpr char TMID[] = "LLDA";
29
+ static constexpr auto tmid()
30
+ {
31
+ return serializer::to_key("LLDA");
32
+ }
30
33
 
31
34
  Dictionary topicLabelDict;
32
35
 
@@ -171,7 +174,7 @@ namespace tomoto
171
174
  return std::make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
172
175
  }
173
176
 
174
- std::vector<Float> getTopicsByDoc(const _DocType& doc, bool normalize) const
177
+ std::vector<Float> _getTopicsByDoc(const _DocType& doc, bool normalize) const
175
178
  {
176
179
  std::vector<Float> ret(this->K);
177
180
  auto maskedAlphas = this->alphas.array() * doc.labelMask.template cast<Float>().array();
@@ -63,7 +63,7 @@ namespace tomoto
63
63
  }
64
64
 
65
65
  template<int _inc>
66
- inline void addWordTo(_ModelState& ld, _DocType& doc, uint32_t pid, Vid vid, Tid tid, uint16_t s, uint8_t w, uint8_t r) const
66
+ inline void addWordTo(_ModelState& ld, _DocType& doc, size_t pid, Vid vid, Tid tid, uint16_t s, uint8_t w, uint8_t r) const
67
67
  {
68
68
  const auto K = this->K;
69
69
 
@@ -527,7 +527,7 @@ namespace tomoto
527
527
  this->etaByWord.emplace(word, priors);
528
528
  }
529
529
 
530
- std::vector<Float> getTopicsByDoc(const _DocType& doc, bool normalize) const
530
+ std::vector<Float> _getTopicsByDoc(const _DocType& doc, bool normalize) const
531
531
  {
532
532
  std::vector<Float> ret(this->K + KL);
533
533
  Eigen::Map<Eigen::Array<Float, -1, 1>> m{ ret.data(), this->K + KL };
@@ -90,7 +90,7 @@ namespace tomoto
90
90
  }
91
91
 
92
92
  template<int _inc>
93
- inline void addWordTo(_ModelState& ld, _DocType& doc, uint32_t pid, Vid vid, Tid z1, Tid z2) const
93
+ inline void addWordTo(_ModelState& ld, _DocType& doc, size_t pid, Vid vid, Tid z1, Tid z2) const
94
94
  {
95
95
  assert(vid < this->realV);
96
96
  constexpr bool _dec = _inc < 0 && _tw != TermWeight::one;
@@ -26,7 +26,10 @@ namespace tomoto
26
26
  friend typename BaseClass::BaseClass;
27
27
  using WeightType = typename BaseClass::WeightType;
28
28
 
29
- static constexpr char TMID[] = "PLDA";
29
+ static constexpr auto tmid()
30
+ {
31
+ return serializer::to_key("PLDA");
32
+ }
30
33
 
31
34
  Dictionary topicLabelDict;
32
35
 
@@ -178,7 +181,7 @@ namespace tomoto
178
181
  return std::make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
179
182
  }
180
183
 
181
- std::vector<Float> getTopicsByDoc(const _DocType& doc, bool normalize) const
184
+ std::vector<Float> _getTopicsByDoc(const _DocType& doc, bool normalize) const
182
185
  {
183
186
  std::vector<Float> ret(this->K);
184
187
  auto maskedAlphas = this->alphas.array() * doc.labelMask.template cast<Float>().array();
@@ -38,7 +38,10 @@ namespace tomoto
38
38
  friend typename BaseClass::BaseClass;
39
39
  using WeightType = typename BaseClass::WeightType;
40
40
 
41
- static constexpr char TMID[] = "PTM";
41
+ static constexpr auto tmid()
42
+ {
43
+ return serializer::to_key("PTM");
44
+ }
42
45
 
43
46
  uint64_t numPDocs;
44
47
  Float lambda;
@@ -261,7 +264,7 @@ namespace tomoto
261
264
  {
262
265
  }
263
266
 
264
- std::vector<Float> getTopicsByDoc(const _DocType& doc, bool normalize) const
267
+ std::vector<Float> _getTopicsByDoc(const _DocType& doc, bool normalize) const
265
268
  {
266
269
  std::vector<Float> ret(this->K);
267
270
  Eigen::Map<Eigen::Array<Float, -1, 1>> m{ ret.data(), this->K };
@@ -216,7 +216,10 @@ namespace tomoto
216
216
  friend typename BaseClass::BaseClass;
217
217
  using WeightType = typename BaseClass::WeightType;
218
218
 
219
- static constexpr char TMID[] = "SLDA";
219
+ static constexpr auto tmid()
220
+ {
221
+ return serializer::to_key("SLDA");
222
+ }
220
223
 
221
224
  uint64_t F; // number of response variables
222
225
  std::vector<ISLDAModel::GLM> varTypes;
@@ -249,6 +249,7 @@ namespace tomoto
249
249
  virtual size_t getNumDocs() const = 0;
250
250
  virtual const Dictionary& getVocabDict() const = 0;
251
251
  virtual const std::vector<uint64_t>& getVocabCf() const = 0;
252
+ virtual std::vector<double> getVocabWeightedCf() const = 0;
252
253
  virtual const std::vector<uint64_t>& getVocabDf() const = 0;
253
254
 
254
255
  virtual int train(size_t iteration, size_t numWorkers, ParallelScheme ps = ParallelScheme::default_, bool freeze_topics = false) = 0;
@@ -319,6 +320,7 @@ namespace tomoto
319
320
  Dictionary dict;
320
321
  uint64_t realV = 0; // vocab size after removing stopwords
321
322
  uint64_t realN = 0; // total word size after removing stopwords
323
+ double weightedN = 0;
322
324
  size_t maxThreads[(size_t)ParallelScheme::size] = { 0, };
323
325
  size_t minWordCf = 0, minWordDf = 0, removeTopN = 0;
324
326
 
@@ -327,15 +329,17 @@ namespace tomoto
327
329
  void _saveModel(std::ostream& writer, bool fullModel, const std::vector<uint8_t>* extra_data) const
328
330
  {
329
331
  serializer::writeMany(writer,
330
- serializer::to_keyz(static_cast<const _Derived*>(this)->TMID),
331
- serializer::to_keyz(static_cast<const _Derived*>(this)->TWID));
332
+ serializer::to_keyz(static_cast<const _Derived*>(this)->tmid()),
333
+ serializer::to_keyz(static_cast<const _Derived*>(this)->twid())
334
+ );
332
335
  serializer::writeTaggedMany(writer, 0x00010001,
333
336
  serializer::to_keyz("dict"), dict,
334
337
  serializer::to_keyz("vocabCf"), vocabCf,
335
338
  serializer::to_keyz("vocabDf"), vocabDf,
336
339
  serializer::to_keyz("realV"), realV,
337
340
  serializer::to_keyz("globalStep"), globalStep,
338
- serializer::to_keyz("extra"), extra_data ? *extra_data : std::vector<uint8_t>(0));
341
+ serializer::to_keyz("extra"), extra_data ? *extra_data : std::vector<uint8_t>(0)
342
+ );
339
343
  serializer::writeMany(writer, *static_cast<const _Derived*>(this));
340
344
  globalState.serializerWrite(writer);
341
345
  if (fullModel)
@@ -355,8 +359,9 @@ namespace tomoto
355
359
  {
356
360
  std::vector<uint8_t> extra;
357
361
  serializer::readMany(reader,
358
- serializer::to_keyz(static_cast<_Derived*>(this)->TMID),
359
- serializer::to_keyz(static_cast<_Derived*>(this)->TWID));
362
+ serializer::to_keyz(static_cast<_Derived*>(this)->tmid()),
363
+ serializer::to_keyz(static_cast<_Derived*>(this)->twid())
364
+ );
360
365
  serializer::readTaggedMany(reader, 0x00010001,
361
366
  serializer::to_keyz("dict"), dict,
362
367
  serializer::to_keyz("vocabCf"), vocabCf,
@@ -370,14 +375,17 @@ namespace tomoto
370
375
  {
371
376
  reader.seekg(start_pos);
372
377
  serializer::readMany(reader,
373
- serializer::to_key(static_cast<_Derived*>(this)->TMID),
374
- serializer::to_key(static_cast<_Derived*>(this)->TWID),
375
- dict, vocabCf, realV);
378
+ serializer::to_key(static_cast<_Derived*>(this)->tmid()),
379
+ serializer::to_key(static_cast<_Derived*>(this)->twid()),
380
+ dict, vocabCf, realV
381
+ );
376
382
  }
377
383
  serializer::readMany(reader, *static_cast<_Derived*>(this));
378
384
  globalState.serializerRead(reader);
379
385
  serializer::readMany(reader, docs);
380
- realN = countRealN();
386
+ auto p = countRealN();
387
+ realN = p.first;
388
+ weightedN = p.second;
381
389
  }
382
390
 
383
391
  template<typename _DocTy>
@@ -490,17 +498,23 @@ namespace tomoto
490
498
  }
491
499
  }
492
500
 
493
- size_t countRealN() const
501
+ std::pair<size_t, double> countRealN() const
494
502
  {
495
503
  size_t n = 0;
504
+ double weighted = 0;
496
505
  for (auto& doc : docs)
497
506
  {
498
- for (auto& w : doc.words)
507
+ for (size_t i = 0; i < doc.words.size(); ++i)
499
508
  {
500
- if (w < realV) ++n;
509
+ auto w = doc.words[i];
510
+ if (w < realV)
511
+ {
512
+ ++n;
513
+ weighted += doc.wordWeights.empty() ? 1 : doc.wordWeights[i];
514
+ }
501
515
  }
502
516
  }
503
- return n;
517
+ return std::make_pair(n, weighted);
504
518
  }
505
519
 
506
520
  void removeStopwords(size_t minWordCnt, size_t minWordDf, size_t removeTopN)
@@ -544,14 +558,9 @@ namespace tomoto
544
558
  }
545
559
 
546
560
  dict.reorder(order);
547
- realN = 0;
548
561
  for (auto& doc : docs)
549
562
  {
550
- for (auto& w : doc.words)
551
- {
552
- w = order[w];
553
- if (w < realV) ++realN;
554
- }
563
+ for (auto& w : doc.words) w = order[w];
555
564
  }
556
565
  }
557
566
 
@@ -598,6 +607,10 @@ namespace tomoto
598
607
 
599
608
  void prepare(bool initDocs = true, size_t minWordCnt = 0, size_t minWordDf = 0, size_t removeTopN = 0) override
600
609
  {
610
+ auto p = countRealN();
611
+ realN = p.first;
612
+ weightedN = p.second;
613
+
601
614
  maxThreads[(size_t)ParallelScheme::default_] = -1;
602
615
  maxThreads[(size_t)ParallelScheme::none] = -1;
603
616
  maxThreads[(size_t)ParallelScheme::copy_merge] = static_cast<_Derived*>(this)->template estimateMaxThreads<ParallelScheme::copy_merge>();
@@ -697,7 +710,7 @@ namespace tomoto
697
710
 
698
711
  double getLLPerWord() const override
699
712
  {
700
- return words.empty() ? 0 : static_cast<const _Derived*>(this)->getLL() / realN;
713
+ return words.empty() ? 0 : static_cast<const _Derived*>(this)->getLL() / weightedN;
701
714
  }
702
715
 
703
716
  double getPerplexity() const override
@@ -797,7 +810,7 @@ namespace tomoto
797
810
 
798
811
  std::vector<Float> getTopicsByDoc(const DocumentBase* doc, bool normalize) const override
799
812
  {
800
- return static_cast<const _Derived*>(this)->getTopicsByDoc(*static_cast<const DocType*>(doc), normalize);
813
+ return static_cast<const _Derived*>(this)->_getTopicsByDoc(*static_cast<const DocType*>(doc), normalize);
801
814
  }
802
815
 
803
816
  std::vector<std::pair<Tid, Float>> getTopicsByDocSorted(const DocumentBase* doc, size_t topN) const override
@@ -832,6 +845,20 @@ namespace tomoto
832
845
  return vocabCf;
833
846
  }
834
847
 
848
+ std::vector<double> getVocabWeightedCf() const override
849
+ {
850
+ std::vector<double> ret(realV);
851
+ for (auto& doc : docs)
852
+ {
853
+ for (size_t i = 0; i < doc.words.size(); ++i)
854
+ {
855
+ if (doc.words[i] >= realV) continue;
856
+ ret[doc.words[i]] += doc.wordWeights.empty() ? 1 : doc.wordWeights[i];
857
+ }
858
+ }
859
+ return ret;
860
+ }
861
+
835
862
  const std::vector<uint64_t>& getVocabDf() const override
836
863
  {
837
864
  return vocabDf;