tomoto 0.3.3 → 0.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +11 -0
  3. data/README.md +1 -1
  4. data/ext/tomoto/extconf.rb +4 -2
  5. data/lib/tomoto/version.rb +1 -1
  6. data/lib/tomoto.rb +14 -14
  7. data/vendor/tomotopy/README.kr.rst +27 -1
  8. data/vendor/tomotopy/README.rst +27 -1
  9. data/vendor/tomotopy/src/TopicModel/CT.h +2 -2
  10. data/vendor/tomotopy/src/TopicModel/CTModel.cpp +5 -0
  11. data/vendor/tomotopy/src/TopicModel/CTModel.hpp +1 -0
  12. data/vendor/tomotopy/src/TopicModel/DMR.h +2 -2
  13. data/vendor/tomotopy/src/TopicModel/DMRModel.cpp +5 -0
  14. data/vendor/tomotopy/src/TopicModel/DMRModel.hpp +1 -0
  15. data/vendor/tomotopy/src/TopicModel/DT.h +2 -2
  16. data/vendor/tomotopy/src/TopicModel/DTModel.cpp +5 -0
  17. data/vendor/tomotopy/src/TopicModel/DTModel.hpp +4 -0
  18. data/vendor/tomotopy/src/TopicModel/GDMR.h +2 -2
  19. data/vendor/tomotopy/src/TopicModel/GDMRModel.cpp +5 -0
  20. data/vendor/tomotopy/src/TopicModel/GDMRModel.hpp +1 -0
  21. data/vendor/tomotopy/src/TopicModel/HDP.h +2 -2
  22. data/vendor/tomotopy/src/TopicModel/HDPModel.cpp +5 -0
  23. data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +2 -0
  24. data/vendor/tomotopy/src/TopicModel/HLDA.h +2 -2
  25. data/vendor/tomotopy/src/TopicModel/HLDAModel.cpp +5 -0
  26. data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +9 -0
  27. data/vendor/tomotopy/src/TopicModel/HPA.h +2 -2
  28. data/vendor/tomotopy/src/TopicModel/HPAModel.cpp +5 -0
  29. data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +2 -0
  30. data/vendor/tomotopy/src/TopicModel/LDA.h +8 -2
  31. data/vendor/tomotopy/src/TopicModel/LDAModel.cpp +5 -0
  32. data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +8 -0
  33. data/vendor/tomotopy/src/TopicModel/LLDA.h +2 -2
  34. data/vendor/tomotopy/src/TopicModel/LLDAModel.cpp +5 -0
  35. data/vendor/tomotopy/src/TopicModel/LLDAModel.hpp +1 -0
  36. data/vendor/tomotopy/src/TopicModel/MGLDA.h +2 -2
  37. data/vendor/tomotopy/src/TopicModel/MGLDAModel.cpp +5 -0
  38. data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +7 -1
  39. data/vendor/tomotopy/src/TopicModel/PA.h +2 -2
  40. data/vendor/tomotopy/src/TopicModel/PAModel.cpp +5 -0
  41. data/vendor/tomotopy/src/TopicModel/PAModel.hpp +7 -0
  42. data/vendor/tomotopy/src/TopicModel/PLDAModel.hpp +1 -0
  43. data/vendor/tomotopy/src/TopicModel/PT.h +3 -3
  44. data/vendor/tomotopy/src/TopicModel/PTModel.cpp +5 -0
  45. data/vendor/tomotopy/src/TopicModel/PTModel.hpp +1 -0
  46. data/vendor/tomotopy/src/TopicModel/SLDA.h +3 -2
  47. data/vendor/tomotopy/src/TopicModel/SLDAModel.cpp +5 -0
  48. data/vendor/tomotopy/src/TopicModel/SLDAModel.hpp +1 -0
  49. data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +83 -3
  50. data/vendor/tomotopy/src/Utils/Dictionary.cpp +102 -0
  51. data/vendor/tomotopy/src/Utils/Dictionary.h +26 -75
  52. data/vendor/tomotopy/src/Utils/EigenAddonOps.hpp +1 -1
  53. data/vendor/tomotopy/src/Utils/Mmap.cpp +146 -0
  54. data/vendor/tomotopy/src/Utils/Mmap.h +139 -0
  55. data/vendor/tomotopy/src/Utils/MultiNormalDistribution.hpp +1 -0
  56. data/vendor/tomotopy/src/Utils/SharedString.cpp +134 -0
  57. data/vendor/tomotopy/src/Utils/SharedString.h +104 -0
  58. data/vendor/tomotopy/src/Utils/serializer.cpp +166 -0
  59. data/vendor/tomotopy/src/Utils/serializer.hpp +261 -85
  60. metadata +12 -7
  61. data/vendor/tomotopy/src/Utils/SharedString.hpp +0 -206
@@ -15,8 +15,8 @@ namespace tomoto
15
15
 
16
16
  template<typename _TopicModel> void update(WeightType* ptr, const _TopicModel& mdl);
17
17
 
18
- DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 0, Z2s);
19
- DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 1, 0x00010001, Z2s);
18
+ DECLARE_SERIALIZER_WITH_VERSION(0);
19
+ DECLARE_SERIALIZER_WITH_VERSION(1);
20
20
  };
21
21
 
22
22
  struct PAArgs : public LDAArgs
@@ -2,6 +2,11 @@
2
2
 
3
3
  namespace tomoto
4
4
  {
5
+ DEFINE_OUT_SERIALIZER_AFTER_BASE_WITH_VERSION(DocumentPA, BaseDocument, 0, Z2s);
6
+ DEFINE_OUT_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(DocumentPA, BaseDocument, 1, 0x00010001, Z2s);
7
+
8
+ TMT_INSTANTIATE_DOC(DocumentPA);
9
+
5
10
  IPAModel* IPAModel::create(TermWeight _weight, const PAArgs& args, bool scalarRng)
6
11
  {
7
12
  TMT_SWITCH_TW(_weight, scalarRng, PAModel, args);
@@ -19,6 +19,7 @@ namespace tomoto
19
19
  Vector subTmp;
20
20
 
21
21
  DEFINE_SERIALIZER_AFTER_BASE(ModelStateLDA<_tw>, numByTopic1_2, numByTopic2);
22
+ DEFINE_HASHER_AFTER_BASE(ModelStateLDA<_tw>, numByTopic1_2, numByTopic2);
22
23
  };
23
24
 
24
25
  template<TermWeight _tw, typename _RandGen,
@@ -364,6 +365,7 @@ namespace tomoto
364
365
  public:
365
366
  DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, K2, subAlphas, subAlphaSum);
366
367
  DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, K2, subAlphas, subAlphaSum);
368
+ DEFINE_HASHER_AFTER_BASE(BaseClass, K2, subAlphas, subAlphaSum);
367
369
 
368
370
  PAModel(const PAArgs& args)
369
371
  : BaseClass(args), K2(args.k2)
@@ -460,6 +462,11 @@ namespace tomoto
460
462
  return ret;
461
463
  }
462
464
 
465
+ size_t getNumTopicsForPrior() const override
466
+ {
467
+ return this->K2;
468
+ }
469
+
463
470
  void setWordPrior(const std::string& word, const std::vector<Float>& priors) override
464
471
  {
465
472
  if (priors.size() != K2) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "priors.size() must be equal to K2.");
@@ -111,6 +111,7 @@ namespace tomoto
111
111
  public:
112
112
  DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, topicLabelDict, numLatentTopics, numTopicsPerLabel);
113
113
  DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, topicLabelDict, numLatentTopics, numTopicsPerLabel);
114
+ DEFINE_HASHER_AFTER_BASE(BaseClass, topicLabelDict, numLatentTopics, numTopicsPerLabel);
114
115
 
115
116
  PLDAModel(const PLDAArgs& args)
116
117
  : BaseClass(args.setK(1)),
@@ -11,9 +11,9 @@ namespace tomoto
11
11
  using WeightType = typename DocumentLDA<_tw>::WeightType;
12
12
 
13
13
  uint64_t pseudoDoc = 0;
14
-
15
- DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 0, pseudoDoc);
16
- DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 1, 0x00010001, pseudoDoc);
14
+
15
+ DECLARE_SERIALIZER_WITH_VERSION(0);
16
+ DECLARE_SERIALIZER_WITH_VERSION(1);
17
17
  };
18
18
 
19
19
  struct PTArgs : public LDAArgs
@@ -2,6 +2,11 @@
2
2
 
3
3
  namespace tomoto
4
4
  {
5
+ DEFINE_OUT_SERIALIZER_AFTER_BASE_WITH_VERSION(DocumentPT, BaseDocument, 0, pseudoDoc);
6
+ DEFINE_OUT_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(DocumentPT, BaseDocument, 1, 0x00010001, pseudoDoc);
7
+
8
+ TMT_INSTANTIATE_DOC(DocumentPT);
9
+
5
10
  IPTModel* IPTModel::create(TermWeight _weight, const PTArgs& args, bool scalarRng)
6
11
  {
7
12
  TMT_SWITCH_TW(_weight, scalarRng, PTModel, args);
@@ -266,6 +266,7 @@ namespace tomoto
266
266
  public:
267
267
  DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, numPDocs, lambda);
268
268
  DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, numPDocs, lambda);
269
+ DEFINE_HASHER_AFTER_BASE(BaseClass, numPDocs, lambda);
269
270
 
270
271
  GETTER(P, size_t, numPDocs);
271
272
 
@@ -16,8 +16,9 @@ namespace tomoto
16
16
  ret["y"] = y;
17
17
  return ret;
18
18
  }
19
- DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 0, y);
20
- DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 1, 0x00010001, y);
19
+
20
+ DECLARE_SERIALIZER_WITH_VERSION(0);
21
+ DECLARE_SERIALIZER_WITH_VERSION(1);
21
22
  };
22
23
 
23
24
  struct SLDAArgs;
@@ -2,6 +2,11 @@
2
2
 
3
3
  namespace tomoto
4
4
  {
5
+ DEFINE_OUT_SERIALIZER_AFTER_BASE_WITH_VERSION(DocumentSLDA, BaseDocument, 0, y);
6
+ DEFINE_OUT_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(DocumentSLDA, BaseDocument, 1, 0x00010001, y);
7
+
8
+ TMT_INSTANTIATE_DOC(DocumentSLDA);
9
+
5
10
  ISLDAModel* ISLDAModel::create(TermWeight _weight, const SLDAArgs& args, bool scalarRng)
6
11
  {
7
12
  TMT_SWITCH_TW(_weight, scalarRng, SLDAModel, args);
@@ -348,6 +348,7 @@ namespace tomoto
348
348
  public:
349
349
  DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, F, responseVars, mu, nuSq);
350
350
  DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, F, responseVars, mu, nuSq);
351
+ DEFINE_HASHER_AFTER_BASE(BaseClass, F, mu, nuSq);
351
352
 
352
353
  SLDAModel(const SLDAArgs& args)
353
354
  : BaseClass(args), F(args.vars.size()), varTypes(args.vars),
@@ -1,4 +1,4 @@
1
- #pragma once
1
+ #pragma once
2
2
  #include <numeric>
3
3
  #include <unordered_set>
4
4
  #include "../Utils/Utils.hpp"
@@ -7,7 +7,7 @@
7
7
  #include "../Utils/ThreadPool.hpp"
8
8
  #include "../Utils/serializer.hpp"
9
9
  #include "../Utils/exception.h"
10
- #include "../Utils/SharedString.hpp"
10
+ #include "../Utils/SharedString.h"
11
11
  #include <EigenRand/EigenRand>
12
12
  #include <mapbox/variant.hpp>
13
13
 
@@ -107,7 +107,7 @@ namespace tomoto
107
107
 
108
108
  virtual operator RawDoc() const
109
109
  {
110
- RawDoc raw{ *this };
110
+ RawDoc raw{ *static_cast<const RawDocKernel*>(this) };
111
111
  if (wOrder.empty())
112
112
  {
113
113
  raw.words.insert(raw.words.begin(), words.begin(), words.end());
@@ -224,6 +224,8 @@ namespace tomoto
224
224
  virtual void loadModel(std::istream& reader,
225
225
  std::vector<uint8_t>* extra_data = nullptr) = 0;
226
226
 
227
+ virtual std::array<uint64_t, 2> getHash() const = 0;
228
+
227
229
  virtual std::unique_ptr<ITopicModel> copy() const = 0;
228
230
 
229
231
  virtual const DocumentBase* getDoc(size_t docId) const = 0;
@@ -251,14 +253,17 @@ namespace tomoto
251
253
  virtual const std::vector<uint64_t>& getVocabCf() const = 0;
252
254
  virtual std::vector<double> getVocabWeightedCf() const = 0;
253
255
  virtual const std::vector<uint64_t>& getVocabDf() const = 0;
256
+ virtual const std::vector<std::vector<std::pair<std::string, size_t>>>& getWordFormCnts() const = 0;
254
257
 
255
258
  virtual int train(size_t iteration, size_t numWorkers, ParallelScheme ps = ParallelScheme::default_, bool freeze_topics = false) = 0;
256
259
  virtual size_t getGlobalStep() const = 0;
257
260
  virtual void prepare(bool initDocs = true, size_t minWordCnt = 0, size_t minWordDf = 0, size_t removeTopN = 0, bool updateStopwords = true) = 0;
258
261
 
259
262
  virtual size_t getK() const = 0;
263
+ virtual size_t getNumTopicsForPrior() const = 0;
260
264
  virtual std::vector<Float> getWidsByTopic(size_t tid, bool normalize = true) const = 0;
261
265
  virtual std::vector<std::pair<std::string, Float>> getWordsByTopicSorted(size_t tid, size_t topN) const = 0;
266
+ virtual std::vector<std::tuple<std::string, Vid, Float>> getWordIdsByTopicSorted(size_t tid, size_t topN) const = 0;
262
267
 
263
268
  virtual std::vector<std::pair<std::string, Float>> getWordsByDocSorted(const DocumentBase* doc, size_t topN) const = 0;
264
269
 
@@ -318,6 +323,7 @@ namespace tomoto
318
323
  size_t globalStep = 0;
319
324
  _ModelState globalState, tState;
320
325
  Dictionary dict;
326
+ std::vector<std::vector<std::pair<std::string, size_t>>> wordFormCnts;
321
327
  uint64_t realV = 0; // vocab size after removing stopwords
322
328
  uint64_t realN = 0; // total word size after removing stopwords
323
329
  double weightedN = 0;
@@ -564,6 +570,44 @@ namespace tomoto
564
570
  }
565
571
  }
566
572
 
573
+ void updateWordFormCnts()
574
+ {
575
+ wordFormCnts.clear();
576
+ wordFormCnts.resize(realV);
577
+ std::vector<std::unordered_map<std::string, size_t>> cnts(realV);
578
+ for (auto& doc : docs)
579
+ {
580
+ for (size_t i = 0; i < doc.words.size(); ++i)
581
+ {
582
+ auto w = doc.words[doc.wOrder.empty() ? i : doc.wOrder[i]];
583
+ if (w >= realV) continue;
584
+ auto& cnt = cnts[w];
585
+ std::string word;
586
+ if (!doc.rawStr.empty() && i < doc.origWordPos.size())
587
+ {
588
+ word = doc.rawStr.substr(doc.origWordPos[i], doc.origWordLen[i]);
589
+ }
590
+ else
591
+ {
592
+ word = dict.toWord(w);
593
+ }
594
+ ++cnt[word];
595
+ }
596
+ }
597
+
598
+ for (size_t i = 0; i < realV; ++i)
599
+ {
600
+ auto& cnt = cnts[i];
601
+ std::vector<std::pair<std::string, size_t>> v{ std::make_move_iterator(cnt.begin()), std::make_move_iterator(cnt.end()) };
602
+ std::sort(v.begin(), v.end(), [](const std::pair<std::string, size_t>& a, const std::pair<std::string, size_t>& b)
603
+ {
604
+ return a.second > b.second;
605
+ });
606
+ wordFormCnts[i] = move(v);
607
+ cnt.clear();
608
+ }
609
+ }
610
+
567
611
  int restoreFromTrainingError(const exc::TrainingError& e, ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
568
612
  {
569
613
  throw e;
@@ -725,6 +769,11 @@ namespace tomoto
725
769
  return 0;
726
770
  }
727
771
 
772
+ size_t getNumTopicsForPrior() const override
773
+ {
774
+ return this->getK();
775
+ }
776
+
728
777
  std::vector<Float> getWidsByTopic(size_t tid, bool normalize) const override
729
778
  {
730
779
  return static_cast<const _Derived*>(this)->_getWidsByTopic(tid, normalize);
@@ -745,11 +794,26 @@ namespace tomoto
745
794
  return ret;
746
795
  }
747
796
 
797
+ std::vector<std::tuple<std::string, Vid, Float>> vid2StringVid(const std::vector<std::pair<Vid, Float>>& vids) const
798
+ {
799
+ std::vector<std::tuple<std::string, Vid, Float>> ret(vids.size());
800
+ for (size_t i = 0; i < vids.size(); ++i)
801
+ {
802
+ ret[i] = std::make_tuple(dict.toWord(vids[i].first), vids[i].first, vids[i].second);
803
+ }
804
+ return ret;
805
+ }
806
+
748
807
  std::vector<std::pair<std::string, Float>> getWordsByTopicSorted(size_t tid, size_t topN) const override
749
808
  {
750
809
  return vid2String(getWidsByTopicSorted(tid, topN));
751
810
  }
752
811
 
812
+ std::vector<std::tuple<std::string, Vid, Float>> getWordIdsByTopicSorted(size_t tid, size_t topN) const override
813
+ {
814
+ return vid2StringVid(getWidsByTopicSorted(tid, topN));
815
+ }
816
+
753
817
  std::vector<std::pair<Vid, Float>> getWidsByDocSorted(const DocumentBase* doc, size_t topN) const
754
818
  {
755
819
  std::vector<Float> cnt(dict.size());
@@ -866,6 +930,11 @@ namespace tomoto
866
930
  return vocabDf;
867
931
  }
868
932
 
933
+ const std::vector<std::vector<std::pair<std::string, size_t>>>& getWordFormCnts() const override
934
+ {
935
+ return wordFormCnts;
936
+ }
937
+
869
938
  void saveModel(std::ostream& writer, bool fullModel, const std::vector<uint8_t>* extra_data) const override
870
939
  {
871
940
  static_cast<const _Derived*>(this)->_saveModel(writer, fullModel, extra_data);
@@ -876,6 +945,17 @@ namespace tomoto
876
945
  static_cast<_Derived*>(this)->_loadModel(reader, extra_data);
877
946
  static_cast<_Derived*>(this)->prepare(false);
878
947
  }
948
+
949
+ std::array<uint64_t, 2> getHash() const override
950
+ {
951
+ std::array<uint64_t, 2> ret;
952
+ ret[0] = dict.computeHash(0);
953
+ const std::string s = static_cast<const _Derived*>(this)->tmid().str() + static_cast<const _Derived*>(this)->twid().str();
954
+ ret[0] = serializer::computeHashMany(ret[0], s, realV, globalStep, docs.size());
955
+ ret[1] = globalState.computeHash(0);
956
+ ret[1] = static_cast<const _Derived*>(this)->computeHash(ret[1]);
957
+ return ret;
958
+ }
879
959
  };
880
960
 
881
961
  }
@@ -0,0 +1,102 @@
1
+ #include "Dictionary.h"
2
+
3
+ namespace tomoto
4
+ {
5
+ Dictionary::Dictionary() = default;
6
+ Dictionary::~Dictionary() = default;
7
+
8
+ Dictionary::Dictionary(const Dictionary&) = default;
9
+ Dictionary& Dictionary::operator=(const Dictionary&) = default;
10
+
11
+ Dictionary::Dictionary(Dictionary&&) noexcept = default;
12
+ Dictionary& Dictionary::operator=(Dictionary&&) noexcept = default;
13
+
14
+ Vid Dictionary::add(const std::string& word)
15
+ {
16
+ auto it = dict.find(word);
17
+ if (it == dict.end())
18
+ {
19
+ dict.emplace(word, (Vid)dict.size());
20
+ id2word.emplace_back(word);
21
+ return (Vid)(dict.size() - 1);
22
+ }
23
+ return it->second;
24
+ }
25
+
26
+ const std::string& Dictionary::toWord(Vid vid) const
27
+ {
28
+ assert(vid < id2word.size());
29
+ return id2word[vid];
30
+ }
31
+
32
+ Vid Dictionary::toWid(const std::string& word) const
33
+ {
34
+ auto it = dict.find(word);
35
+ if (it == dict.end()) return non_vocab_id;
36
+ return it->second;
37
+ }
38
+
39
+ void Dictionary::serializerWrite(std::ostream& writer) const
40
+ {
41
+ serializer::writeMany(writer, serializer::to_key("Dict"), id2word);
42
+ }
43
+
44
+ void Dictionary::serializerRead(std::istream& reader)
45
+ {
46
+ serializer::readMany(reader, serializer::to_key("Dict"), id2word);
47
+ for (size_t i = 0; i < id2word.size(); ++i)
48
+ {
49
+ dict.emplace(id2word[i], (Vid)i);
50
+ }
51
+ }
52
+
53
+ uint64_t Dictionary::computeHash(uint64_t seed) const
54
+ {
55
+ return serializer::computeHashMany(seed, id2word);
56
+ }
57
+
58
+ void Dictionary::swap(Dictionary& rhs)
59
+ {
60
+ std::swap(dict, rhs.dict);
61
+ std::swap(id2word, rhs.id2word);
62
+ }
63
+
64
+ void Dictionary::reorder(const std::vector<Vid>& order)
65
+ {
66
+ for (auto& p : dict)
67
+ {
68
+ p.second = order[p.second];
69
+ id2word[p.second] = p.first;
70
+ }
71
+ }
72
+
73
+ const std::vector<std::string>& Dictionary::getRaw() const
74
+ {
75
+ return id2word;
76
+ }
77
+
78
+ Vid Dictionary::mapToNewDict(Vid v, const Dictionary& newDict) const
79
+ {
80
+ return newDict.toWid(toWord(v));
81
+ }
82
+
83
+ std::vector<Vid> Dictionary::mapToNewDict(const std::vector<Vid>& v, const Dictionary& newDict) const
84
+ {
85
+ std::vector<Vid> r(v.size());
86
+ for (size_t i = 0; i < v.size(); ++i)
87
+ {
88
+ r[i] = mapToNewDict(v[i], newDict);
89
+ }
90
+ return r;
91
+ }
92
+
93
+ std::vector<Vid> Dictionary::mapToNewDictAdd(const std::vector<Vid>& v, Dictionary& newDict) const
94
+ {
95
+ std::vector<Vid> r(v.size());
96
+ for (size_t i = 0; i < v.size(); ++i)
97
+ {
98
+ r[i] = mapToNewDict(v[i], newDict);
99
+ }
100
+ return r;
101
+ }
102
+ }
@@ -12,8 +12,9 @@ namespace tomoto
12
12
  {
13
13
  using Vid = uint32_t;
14
14
  static constexpr Vid non_vocab_id = (Vid)-1;
15
+ static constexpr Vid rm_vocab_id = (Vid)-2;
15
16
  using Tid = uint16_t;
16
- static constexpr Vid non_topic_id = (Tid)-1;
17
+ static constexpr Tid non_topic_id = (Tid)-1;
17
18
  using Float = float;
18
19
 
19
20
  struct VidPair : public std::pair<Vid, Vid>
@@ -27,91 +28,41 @@ namespace tomoto
27
28
  std::unordered_map<std::string, Vid> dict;
28
29
  std::vector<std::string> id2word;
29
30
  public:
30
- Vid add(const std::string& word)
31
- {
32
- auto it = dict.find(word);
33
- if (it == dict.end())
34
- {
35
- dict.emplace(word, (Vid)dict.size());
36
- id2word.emplace_back(word);
37
- return (Vid)(dict.size() - 1);
38
- }
39
- return it->second;
40
- }
31
+
32
+ Dictionary();
33
+ ~Dictionary();
34
+
35
+ Dictionary(const Dictionary&);
36
+ Dictionary& operator=(const Dictionary&);
37
+
38
+ Dictionary(Dictionary&&) noexcept;
39
+ Dictionary& operator=(Dictionary&&) noexcept;
40
+
41
+ Vid add(const std::string& word);
41
42
 
42
43
  size_t size() const { return dict.size(); }
43
44
 
44
- const std::string& toWord(Vid vid) const
45
- {
46
- assert(vid < id2word.size());
47
- return id2word[vid];
48
- }
45
+ const std::string& toWord(Vid vid) const;
49
46
 
50
- Vid toWid(const std::string& word) const
51
- {
52
- auto it = dict.find(word);
53
- if (it == dict.end()) return non_vocab_id;
54
- return it->second;
55
- }
47
+ Vid toWid(const std::string& word) const;
56
48
 
57
- void serializerWrite(std::ostream& writer) const
58
- {
59
- serializer::writeMany(writer, serializer::to_key("Dict"), id2word);
60
- }
49
+ void serializerWrite(std::ostream& writer) const;
61
50
 
62
- void serializerRead(std::istream& reader)
63
- {
64
- serializer::readMany(reader, serializer::to_key("Dict"), id2word);
65
- for (size_t i = 0; i < id2word.size(); ++i)
66
- {
67
- dict.emplace(id2word[i], (Vid)i);
68
- }
69
- }
51
+ void serializerRead(std::istream& reader);
70
52
 
71
- void swap(Dictionary& rhs)
72
- {
73
- std::swap(dict, rhs.dict);
74
- std::swap(id2word, rhs.id2word);
75
- }
53
+ uint64_t computeHash(uint64_t seed) const;
76
54
 
77
- void reorder(const std::vector<Vid>& order)
78
- {
79
- for (auto& p : dict)
80
- {
81
- p.second = order[p.second];
82
- id2word[p.second] = p.first;
83
- }
84
- }
55
+ void swap(Dictionary& rhs);
85
56
 
86
- const std::vector<std::string>& getRaw() const
87
- {
88
- return id2word;
89
- }
57
+ void reorder(const std::vector<Vid>& order);
90
58
 
91
- Vid mapToNewDict(Vid v, const Dictionary& newDict) const
92
- {
93
- return newDict.toWid(toWord(v));
94
- }
59
+ const std::vector<std::string>& getRaw() const;
95
60
 
96
- std::vector<Vid> mapToNewDict(const std::vector<Vid>& v, const Dictionary& newDict) const
97
- {
98
- std::vector<Vid> r(v.size());
99
- for (size_t i = 0; i < v.size(); ++i)
100
- {
101
- r[i] = mapToNewDict(v[i], newDict);
102
- }
103
- return r;
104
- }
61
+ Vid mapToNewDict(Vid v, const Dictionary& newDict) const;
105
62
 
106
- std::vector<Vid> mapToNewDictAdd(const std::vector<Vid>& v, Dictionary& newDict) const
107
- {
108
- std::vector<Vid> r(v.size());
109
- for (size_t i = 0; i < v.size(); ++i)
110
- {
111
- r[i] = mapToNewDict(v[i], newDict);
112
- }
113
- return r;
114
- }
63
+ std::vector<Vid> mapToNewDict(const std::vector<Vid>& v, const Dictionary& newDict) const;
64
+
65
+ std::vector<Vid> mapToNewDictAdd(const std::vector<Vid>& v, Dictionary& newDict) const;
115
66
  };
116
67
 
117
68
  }
@@ -126,4 +77,4 @@ namespace std
126
77
  return hash<size_t>{}(p.first) ^ hash<size_t>{}(p.second);
127
78
  }
128
79
  };
129
- }
80
+ }
@@ -116,7 +116,7 @@ namespace Eigen
116
116
 
117
117
  EIGEN_STRONG_INLINE Packet4f p_bool2float(const Packet4f& a)
118
118
  {
119
- return vcvtq_f32_s32(vandq_s32(a, vdupq_n_s32(1)));
119
+ return vcvtq_f32_s32(vandq_s32((Packet4i)a, vdupq_n_s32(1)));
120
120
  }
121
121
 
122
122
  EIGEN_STRONG_INLINE Packet4f p_bool2float(const Packet4i& a)
@@ -0,0 +1,146 @@
1
+ #include <cstdint>
2
+ #include "Mmap.h"
3
+
4
namespace tomoto
{
	namespace utils
	{
		/**
		 * Decodes a UTF-8 encoded string into UTF-16. Code points above U+FFFF
		 * are emitted as surrogate pairs.
		 *
		 * @throws std::invalid_argument on an invalid lead byte ("unicode error"),
		 *         a truncated sequence ("unexpected ending"), a malformed
		 *         continuation byte ("unexpected trailing byte"), or a code point
		 *         above U+10FFFF ("unicode error").
		 *
		 * NOTE(review): overlong encodings and encoded surrogates (U+D800..U+DFFF)
		 * are not rejected.
		 */
		static std::u16string utf8To16(const std::string& str)
		{
			std::u16string ret;
			ret.reserve(str.size());
			for (auto it = str.begin(); it != str.end(); ++it)
			{
				const uint32_t lead = static_cast<uint8_t>(*it);
				uint32_t code = 0;
				size_t trailing = 0;
				if ((lead & 0xF8) == 0xF0) { code = lead & 0x07; trailing = 3; }
				else if ((lead & 0xF0) == 0xE0) { code = lead & 0x0F; trailing = 2; }
				else if ((lead & 0xE0) == 0xC0) { code = lead & 0x1F; trailing = 1; }
				else if ((lead & 0x80) == 0x00) { code = lead; }
				else throw std::invalid_argument{ "unicode error" };

				// Each continuation byte contributes its low 6 bits.
				for (; trailing > 0; --trailing)
				{
					if (++it == str.end()) throw std::invalid_argument{ "unexpected ending" };
					// Cast through uint8_t so a signed `char` cannot sign-extend.
					const uint32_t byte = static_cast<uint8_t>(*it);
					if ((byte & 0xC0) != 0x80) throw std::invalid_argument{ "unexpected trailing byte" };
					code = (code << 6) | (byte & 0x3F);
				}

				if (code < 0x10000)
				{
					ret.push_back(static_cast<char16_t>(code));
				}
				else if (code <= 0x10FFFF) // U+10FFFF is the last valid code point
				{
					code -= 0x10000;
					ret.push_back(static_cast<char16_t>(0xD800 | (code >> 10)));
					ret.push_back(static_cast<char16_t>(0xDC00 | (code & 0x3FF)));
				}
				else
				{
					throw std::invalid_argument{ "unicode error" };
				}
			}
			return ret;
		}
	}
}
73
+
74
+ namespace tomoto
75
+ {
76
+ namespace utils
77
+ {
78
+ MMap::MMap(const std::string& filepath)
79
+ {
80
+ #ifdef _WIN32
81
+ hFile = CreateFileW((const wchar_t*)utf8To16(filepath).c_str(), GENERIC_READ, FILE_SHARE_READ, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_READONLY, nullptr);
82
+ if (hFile == INVALID_HANDLE_VALUE) throw std::ios_base::failure("Cannot open '" + filepath + "'");
83
+ hFileMap = CreateFileMapping(hFile, nullptr, PAGE_READONLY, 0, 0, nullptr);
84
+ if (hFileMap == nullptr) throw std::ios_base::failure("Cannot open '" + filepath + "' Code:" + std::to_string(GetLastError()));
85
+ view = (const char*)MapViewOfFile(hFileMap, FILE_MAP_READ, 0, 0, 0);
86
+ if (!view) throw std::ios_base::failure("Cannot MapViewOfFile() Code:" + std::to_string(GetLastError()));
87
+ DWORD high;
88
+ len = GetFileSize(hFile, &high);
89
+ len |= (uint64_t)high << 32;
90
+ #else
91
+ fd = open(filepath.c_str(), O_RDONLY);
92
+ if (fd == -1) throw std::ios_base::failure("Cannot open '" + filepath + "'");
93
+ struct stat sb;
94
+ if (fstat(fd, &sb) < 0) throw std::ios_base::failure("Cannot open '" + filepath + "'");
95
+ len = sb.st_size;
96
+ view = (const char*)mmap(nullptr, len, PROT_READ, MAP_PRIVATE, fd, 0);
97
+ if (view == MAP_FAILED) throw std::ios_base::failure("Mapping failed");
98
+ #endif
99
+ }
100
+
101
+ #ifdef _WIN32
102
+ MMap::MMap(MMap&& o) noexcept
103
+ : view{ o.view }, len{ o.len }
104
+ {
105
+ o.view = nullptr;
106
+ std::swap(hFile, o.hFile);
107
+ std::swap(hFileMap, o.hFileMap);
108
+ }
109
+ #else
110
+ MMap::MMap(MMap&& o) noexcept
111
+ : len{ o.len }, fd{ std::move(o.fd) }
112
+ {
113
+ std::swap(view, o.view);
114
+ }
115
+ #endif
116
+
117
+ MMap& MMap::operator=(MMap&& o) noexcept
118
+ {
119
+ std::swap(view, o.view);
120
+ std::swap(len, o.len);
121
+ #ifdef _WIN32
122
+ std::swap(hFile, o.hFile);
123
+ std::swap(hFileMap, o.hFileMap);
124
+ #else
125
+ std::swap(fd, o.fd);
126
+ #endif
127
+ return *this;
128
+ }
129
+
130
+ MMap::~MMap()
131
+ {
132
+ #ifdef _WIN32
133
+ if (hFileMap)
134
+ {
135
+ UnmapViewOfFile(view);
136
+ view = nullptr;
137
+ }
138
+ #else
139
+ if (view)
140
+ {
141
+ munmap((void*)view, len);
142
+ }
143
+ #endif
144
+ }
145
+ }
146
+ }