tomoto 0.3.3 → 0.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +11 -0
- data/README.md +1 -1
- data/ext/tomoto/extconf.rb +4 -2
- data/lib/tomoto/version.rb +1 -1
- data/lib/tomoto.rb +14 -14
- data/vendor/tomotopy/README.kr.rst +27 -1
- data/vendor/tomotopy/README.rst +27 -1
- data/vendor/tomotopy/src/TopicModel/CT.h +2 -2
- data/vendor/tomotopy/src/TopicModel/CTModel.cpp +5 -0
- data/vendor/tomotopy/src/TopicModel/CTModel.hpp +1 -0
- data/vendor/tomotopy/src/TopicModel/DMR.h +2 -2
- data/vendor/tomotopy/src/TopicModel/DMRModel.cpp +5 -0
- data/vendor/tomotopy/src/TopicModel/DMRModel.hpp +1 -0
- data/vendor/tomotopy/src/TopicModel/DT.h +2 -2
- data/vendor/tomotopy/src/TopicModel/DTModel.cpp +5 -0
- data/vendor/tomotopy/src/TopicModel/DTModel.hpp +4 -0
- data/vendor/tomotopy/src/TopicModel/GDMR.h +2 -2
- data/vendor/tomotopy/src/TopicModel/GDMRModel.cpp +5 -0
- data/vendor/tomotopy/src/TopicModel/GDMRModel.hpp +1 -0
- data/vendor/tomotopy/src/TopicModel/HDP.h +2 -2
- data/vendor/tomotopy/src/TopicModel/HDPModel.cpp +5 -0
- data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +2 -0
- data/vendor/tomotopy/src/TopicModel/HLDA.h +2 -2
- data/vendor/tomotopy/src/TopicModel/HLDAModel.cpp +5 -0
- data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +9 -0
- data/vendor/tomotopy/src/TopicModel/HPA.h +2 -2
- data/vendor/tomotopy/src/TopicModel/HPAModel.cpp +5 -0
- data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +2 -0
- data/vendor/tomotopy/src/TopicModel/LDA.h +8 -2
- data/vendor/tomotopy/src/TopicModel/LDAModel.cpp +5 -0
- data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +8 -0
- data/vendor/tomotopy/src/TopicModel/LLDA.h +2 -2
- data/vendor/tomotopy/src/TopicModel/LLDAModel.cpp +5 -0
- data/vendor/tomotopy/src/TopicModel/LLDAModel.hpp +1 -0
- data/vendor/tomotopy/src/TopicModel/MGLDA.h +2 -2
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.cpp +5 -0
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +7 -1
- data/vendor/tomotopy/src/TopicModel/PA.h +2 -2
- data/vendor/tomotopy/src/TopicModel/PAModel.cpp +5 -0
- data/vendor/tomotopy/src/TopicModel/PAModel.hpp +7 -0
- data/vendor/tomotopy/src/TopicModel/PLDAModel.hpp +1 -0
- data/vendor/tomotopy/src/TopicModel/PT.h +3 -3
- data/vendor/tomotopy/src/TopicModel/PTModel.cpp +5 -0
- data/vendor/tomotopy/src/TopicModel/PTModel.hpp +1 -0
- data/vendor/tomotopy/src/TopicModel/SLDA.h +3 -2
- data/vendor/tomotopy/src/TopicModel/SLDAModel.cpp +5 -0
- data/vendor/tomotopy/src/TopicModel/SLDAModel.hpp +1 -0
- data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +83 -3
- data/vendor/tomotopy/src/Utils/Dictionary.cpp +102 -0
- data/vendor/tomotopy/src/Utils/Dictionary.h +26 -75
- data/vendor/tomotopy/src/Utils/EigenAddonOps.hpp +1 -1
- data/vendor/tomotopy/src/Utils/Mmap.cpp +146 -0
- data/vendor/tomotopy/src/Utils/Mmap.h +139 -0
- data/vendor/tomotopy/src/Utils/MultiNormalDistribution.hpp +1 -0
- data/vendor/tomotopy/src/Utils/SharedString.cpp +134 -0
- data/vendor/tomotopy/src/Utils/SharedString.h +104 -0
- data/vendor/tomotopy/src/Utils/serializer.cpp +166 -0
- data/vendor/tomotopy/src/Utils/serializer.hpp +261 -85
- metadata +12 -7
- data/vendor/tomotopy/src/Utils/SharedString.hpp +0 -206
@@ -15,8 +15,8 @@ namespace tomoto
|
|
15
15
|
|
16
16
|
template<typename _TopicModel> void update(WeightType* ptr, const _TopicModel& mdl);
|
17
17
|
|
18
|
-
|
19
|
-
|
18
|
+
DECLARE_SERIALIZER_WITH_VERSION(0);
|
19
|
+
DECLARE_SERIALIZER_WITH_VERSION(1);
|
20
20
|
};
|
21
21
|
|
22
22
|
struct PAArgs : public LDAArgs
|
@@ -2,6 +2,11 @@
|
|
2
2
|
|
3
3
|
namespace tomoto
|
4
4
|
{
|
5
|
+
// Version-0 out-of-class serializer for DocumentPA: per its arguments, the BaseDocument part
// followed by Z2s (macro presumably from Utils/serializer.hpp — confirm).
DEFINE_OUT_SERIALIZER_AFTER_BASE_WITH_VERSION(DocumentPA, BaseDocument, 0, Z2s);
|
6
|
+
// Version-1 (tag 0x00010001) tagged serializer for DocumentPA over the same field list (Z2s).
DEFINE_OUT_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(DocumentPA, BaseDocument, 1, 0x00010001, Z2s);
|
7
|
+
|
8
|
+
// Instantiation hook for DocumentPA (TMT_INSTANTIATE_DOC is defined elsewhere; presumably
// emits the out-of-line serializer instantiations declared above — confirm).
TMT_INSTANTIATE_DOC(DocumentPA);
|
9
|
+
|
5
10
|
IPAModel* IPAModel::create(TermWeight _weight, const PAArgs& args, bool scalarRng)
|
6
11
|
{
|
7
12
|
TMT_SWITCH_TW(_weight, scalarRng, PAModel, args);
|
@@ -19,6 +19,7 @@ namespace tomoto
|
|
19
19
|
Vector subTmp;
|
20
20
|
|
21
21
|
DEFINE_SERIALIZER_AFTER_BASE(ModelStateLDA<_tw>, numByTopic1_2, numByTopic2);
|
22
|
+
DEFINE_HASHER_AFTER_BASE(ModelStateLDA<_tw>, numByTopic1_2, numByTopic2);
|
22
23
|
};
|
23
24
|
|
24
25
|
template<TermWeight _tw, typename _RandGen,
|
@@ -364,6 +365,7 @@ namespace tomoto
|
|
364
365
|
public:
|
365
366
|
DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, K2, subAlphas, subAlphaSum);
|
366
367
|
DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, K2, subAlphas, subAlphaSum);
|
368
|
+
DEFINE_HASHER_AFTER_BASE(BaseClass, K2, subAlphas, subAlphaSum);
|
367
369
|
|
368
370
|
PAModel(const PAArgs& args)
|
369
371
|
: BaseClass(args), K2(args.k2)
|
@@ -460,6 +462,11 @@ namespace tomoto
|
|
460
462
|
return ret;
|
461
463
|
}
|
462
464
|
|
465
|
+
size_t getNumTopicsForPrior() const override
|
466
|
+
{
|
467
|
+
return this->K2;
|
468
|
+
}
|
469
|
+
|
463
470
|
void setWordPrior(const std::string& word, const std::vector<Float>& priors) override
|
464
471
|
{
|
465
472
|
if (priors.size() != K2) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "priors.size() must be equal to K2.");
|
@@ -111,6 +111,7 @@ namespace tomoto
|
|
111
111
|
public:
|
112
112
|
DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, topicLabelDict, numLatentTopics, numTopicsPerLabel);
|
113
113
|
DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, topicLabelDict, numLatentTopics, numTopicsPerLabel);
|
114
|
+
DEFINE_HASHER_AFTER_BASE(BaseClass, topicLabelDict, numLatentTopics, numTopicsPerLabel);
|
114
115
|
|
115
116
|
PLDAModel(const PLDAArgs& args)
|
116
117
|
: BaseClass(args.setK(1)),
|
@@ -11,9 +11,9 @@ namespace tomoto
|
|
11
11
|
using WeightType = typename DocumentLDA<_tw>::WeightType;
|
12
12
|
|
13
13
|
uint64_t pseudoDoc = 0;
|
14
|
-
|
15
|
-
|
16
|
-
|
14
|
+
|
15
|
+
DECLARE_SERIALIZER_WITH_VERSION(0);
|
16
|
+
DECLARE_SERIALIZER_WITH_VERSION(1);
|
17
17
|
};
|
18
18
|
|
19
19
|
struct PTArgs : public LDAArgs
|
@@ -2,6 +2,11 @@
|
|
2
2
|
|
3
3
|
namespace tomoto
|
4
4
|
{
|
5
|
+
// Version-0 out-of-class serializer for DocumentPT: per its arguments, the BaseDocument part
// followed by pseudoDoc (macro presumably from Utils/serializer.hpp — confirm).
DEFINE_OUT_SERIALIZER_AFTER_BASE_WITH_VERSION(DocumentPT, BaseDocument, 0, pseudoDoc);
|
6
|
+
// Version-1 (tag 0x00010001) tagged serializer for DocumentPT over the same field list (pseudoDoc).
DEFINE_OUT_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(DocumentPT, BaseDocument, 1, 0x00010001, pseudoDoc);
|
7
|
+
|
8
|
+
// Instantiation hook for DocumentPT (TMT_INSTANTIATE_DOC is defined elsewhere — confirm what it emits).
TMT_INSTANTIATE_DOC(DocumentPT);
|
9
|
+
|
5
10
|
IPTModel* IPTModel::create(TermWeight _weight, const PTArgs& args, bool scalarRng)
|
6
11
|
{
|
7
12
|
TMT_SWITCH_TW(_weight, scalarRng, PTModel, args);
|
@@ -266,6 +266,7 @@ namespace tomoto
|
|
266
266
|
public:
|
267
267
|
DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, numPDocs, lambda);
|
268
268
|
DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, numPDocs, lambda);
|
269
|
+
DEFINE_HASHER_AFTER_BASE(BaseClass, numPDocs, lambda);
|
269
270
|
|
270
271
|
GETTER(P, size_t, numPDocs);
|
271
272
|
|
@@ -16,8 +16,9 @@ namespace tomoto
|
|
16
16
|
ret["y"] = y;
|
17
17
|
return ret;
|
18
18
|
}
|
19
|
-
|
20
|
-
|
19
|
+
|
20
|
+
DECLARE_SERIALIZER_WITH_VERSION(0);
|
21
|
+
DECLARE_SERIALIZER_WITH_VERSION(1);
|
21
22
|
};
|
22
23
|
|
23
24
|
struct SLDAArgs;
|
@@ -2,6 +2,11 @@
|
|
2
2
|
|
3
3
|
namespace tomoto
|
4
4
|
{
|
5
|
+
// Version-0 out-of-class serializer for DocumentSLDA: per its arguments, the BaseDocument part
// followed by the response values y (macro presumably from Utils/serializer.hpp — confirm).
DEFINE_OUT_SERIALIZER_AFTER_BASE_WITH_VERSION(DocumentSLDA, BaseDocument, 0, y);
|
6
|
+
// Version-1 (tag 0x00010001) tagged serializer for DocumentSLDA over the same field list (y).
DEFINE_OUT_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(DocumentSLDA, BaseDocument, 1, 0x00010001, y);
|
7
|
+
|
8
|
+
// Instantiation hook for DocumentSLDA (TMT_INSTANTIATE_DOC is defined elsewhere — confirm what it emits).
TMT_INSTANTIATE_DOC(DocumentSLDA);
|
9
|
+
|
5
10
|
ISLDAModel* ISLDAModel::create(TermWeight _weight, const SLDAArgs& args, bool scalarRng)
|
6
11
|
{
|
7
12
|
TMT_SWITCH_TW(_weight, scalarRng, SLDAModel, args);
|
@@ -348,6 +348,7 @@ namespace tomoto
|
|
348
348
|
public:
|
349
349
|
DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, F, responseVars, mu, nuSq);
|
350
350
|
DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, F, responseVars, mu, nuSq);
|
351
|
+
DEFINE_HASHER_AFTER_BASE(BaseClass, F, mu, nuSq);
|
351
352
|
|
352
353
|
SLDAModel(const SLDAArgs& args)
|
353
354
|
: BaseClass(args), F(args.vars.size()), varTypes(args.vars),
|
@@ -1,4 +1,4 @@
|
|
1
|
-
|
1
|
+
#pragma once
|
2
2
|
#include <numeric>
|
3
3
|
#include <unordered_set>
|
4
4
|
#include "../Utils/Utils.hpp"
|
@@ -7,7 +7,7 @@
|
|
7
7
|
#include "../Utils/ThreadPool.hpp"
|
8
8
|
#include "../Utils/serializer.hpp"
|
9
9
|
#include "../Utils/exception.h"
|
10
|
-
#include "../Utils/SharedString.
|
10
|
+
#include "../Utils/SharedString.h"
|
11
11
|
#include <EigenRand/EigenRand>
|
12
12
|
#include <mapbox/variant.hpp>
|
13
13
|
|
@@ -107,7 +107,7 @@ namespace tomoto
|
|
107
107
|
|
108
108
|
virtual operator RawDoc() const
|
109
109
|
{
|
110
|
-
RawDoc raw{ *this };
|
110
|
+
RawDoc raw{ *static_cast<const RawDocKernel*>(this) };
|
111
111
|
if (wOrder.empty())
|
112
112
|
{
|
113
113
|
raw.words.insert(raw.words.begin(), words.begin(), words.end());
|
@@ -224,6 +224,8 @@ namespace tomoto
|
|
224
224
|
virtual void loadModel(std::istream& reader,
|
225
225
|
std::vector<uint8_t>* extra_data = nullptr) = 0;
|
226
226
|
|
227
|
+
virtual std::array<uint64_t, 2> getHash() const = 0;
|
228
|
+
|
227
229
|
virtual std::unique_ptr<ITopicModel> copy() const = 0;
|
228
230
|
|
229
231
|
virtual const DocumentBase* getDoc(size_t docId) const = 0;
|
@@ -251,14 +253,17 @@ namespace tomoto
|
|
251
253
|
virtual const std::vector<uint64_t>& getVocabCf() const = 0;
|
252
254
|
virtual std::vector<double> getVocabWeightedCf() const = 0;
|
253
255
|
virtual const std::vector<uint64_t>& getVocabDf() const = 0;
|
256
|
+
virtual const std::vector<std::vector<std::pair<std::string, size_t>>>& getWordFormCnts() const = 0;
|
254
257
|
|
255
258
|
virtual int train(size_t iteration, size_t numWorkers, ParallelScheme ps = ParallelScheme::default_, bool freeze_topics = false) = 0;
|
256
259
|
virtual size_t getGlobalStep() const = 0;
|
257
260
|
virtual void prepare(bool initDocs = true, size_t minWordCnt = 0, size_t minWordDf = 0, size_t removeTopN = 0, bool updateStopwords = true) = 0;
|
258
261
|
|
259
262
|
virtual size_t getK() const = 0;
|
263
|
+
virtual size_t getNumTopicsForPrior() const = 0;
|
260
264
|
virtual std::vector<Float> getWidsByTopic(size_t tid, bool normalize = true) const = 0;
|
261
265
|
virtual std::vector<std::pair<std::string, Float>> getWordsByTopicSorted(size_t tid, size_t topN) const = 0;
|
266
|
+
virtual std::vector<std::tuple<std::string, Vid, Float>> getWordIdsByTopicSorted(size_t tid, size_t topN) const = 0;
|
262
267
|
|
263
268
|
virtual std::vector<std::pair<std::string, Float>> getWordsByDocSorted(const DocumentBase* doc, size_t topN) const = 0;
|
264
269
|
|
@@ -318,6 +323,7 @@ namespace tomoto
|
|
318
323
|
size_t globalStep = 0;
|
319
324
|
_ModelState globalState, tState;
|
320
325
|
Dictionary dict;
|
326
|
+
std::vector<std::vector<std::pair<std::string, size_t>>> wordFormCnts;
|
321
327
|
uint64_t realV = 0; // vocab size after removing stopwords
|
322
328
|
uint64_t realN = 0; // total word size after removing stopwords
|
323
329
|
double weightedN = 0;
|
@@ -564,6 +570,44 @@ namespace tomoto
|
|
564
570
|
}
|
565
571
|
}
|
566
572
|
|
573
|
+
void updateWordFormCnts()
|
574
|
+
{
|
575
|
+
wordFormCnts.clear();
|
576
|
+
wordFormCnts.resize(realV);
|
577
|
+
std::vector<std::unordered_map<std::string, size_t>> cnts(realV);
|
578
|
+
for (auto& doc : docs)
|
579
|
+
{
|
580
|
+
for (size_t i = 0; i < doc.words.size(); ++i)
|
581
|
+
{
|
582
|
+
auto w = doc.words[doc.wOrder.empty() ? i : doc.wOrder[i]];
|
583
|
+
if (w >= realV) continue;
|
584
|
+
auto& cnt = cnts[w];
|
585
|
+
std::string word;
|
586
|
+
if (!doc.rawStr.empty() && i < doc.origWordPos.size())
|
587
|
+
{
|
588
|
+
word = doc.rawStr.substr(doc.origWordPos[i], doc.origWordLen[i]);
|
589
|
+
}
|
590
|
+
else
|
591
|
+
{
|
592
|
+
word = dict.toWord(w);
|
593
|
+
}
|
594
|
+
++cnt[word];
|
595
|
+
}
|
596
|
+
}
|
597
|
+
|
598
|
+
for (size_t i = 0; i < realV; ++i)
|
599
|
+
{
|
600
|
+
auto& cnt = cnts[i];
|
601
|
+
std::vector<std::pair<std::string, size_t>> v{ std::make_move_iterator(cnt.begin()), std::make_move_iterator(cnt.end()) };
|
602
|
+
std::sort(v.begin(), v.end(), [](const std::pair<std::string, size_t>& a, const std::pair<std::string, size_t>& b)
|
603
|
+
{
|
604
|
+
return a.second > b.second;
|
605
|
+
});
|
606
|
+
wordFormCnts[i] = move(v);
|
607
|
+
cnt.clear();
|
608
|
+
}
|
609
|
+
}
|
610
|
+
|
567
611
|
int restoreFromTrainingError(const exc::TrainingError& e, ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
|
568
612
|
{
|
569
613
|
throw e;
|
@@ -725,6 +769,11 @@ namespace tomoto
|
|
725
769
|
return 0;
|
726
770
|
}
|
727
771
|
|
772
|
+
size_t getNumTopicsForPrior() const override
|
773
|
+
{
|
774
|
+
return this->getK();
|
775
|
+
}
|
776
|
+
|
728
777
|
std::vector<Float> getWidsByTopic(size_t tid, bool normalize) const override
|
729
778
|
{
|
730
779
|
return static_cast<const _Derived*>(this)->_getWidsByTopic(tid, normalize);
|
@@ -745,11 +794,26 @@ namespace tomoto
|
|
745
794
|
return ret;
|
746
795
|
}
|
747
796
|
|
797
|
+
std::vector<std::tuple<std::string, Vid, Float>> vid2StringVid(const std::vector<std::pair<Vid, Float>>& vids) const
|
798
|
+
{
|
799
|
+
std::vector<std::tuple<std::string, Vid, Float>> ret(vids.size());
|
800
|
+
for (size_t i = 0; i < vids.size(); ++i)
|
801
|
+
{
|
802
|
+
ret[i] = std::make_tuple(dict.toWord(vids[i].first), vids[i].first, vids[i].second);
|
803
|
+
}
|
804
|
+
return ret;
|
805
|
+
}
|
806
|
+
|
748
807
|
std::vector<std::pair<std::string, Float>> getWordsByTopicSorted(size_t tid, size_t topN) const override
|
749
808
|
{
|
750
809
|
return vid2String(getWidsByTopicSorted(tid, topN));
|
751
810
|
}
|
752
811
|
|
812
|
+
std::vector<std::tuple<std::string, Vid, Float>> getWordIdsByTopicSorted(size_t tid, size_t topN) const override
|
813
|
+
{
|
814
|
+
return vid2StringVid(getWidsByTopicSorted(tid, topN));
|
815
|
+
}
|
816
|
+
|
753
817
|
std::vector<std::pair<Vid, Float>> getWidsByDocSorted(const DocumentBase* doc, size_t topN) const
|
754
818
|
{
|
755
819
|
std::vector<Float> cnt(dict.size());
|
@@ -866,6 +930,11 @@ namespace tomoto
|
|
866
930
|
return vocabDf;
|
867
931
|
}
|
868
932
|
|
933
|
+
const std::vector<std::vector<std::pair<std::string, size_t>>>& getWordFormCnts() const override
|
934
|
+
{
|
935
|
+
return wordFormCnts;
|
936
|
+
}
|
937
|
+
|
869
938
|
void saveModel(std::ostream& writer, bool fullModel, const std::vector<uint8_t>* extra_data) const override
|
870
939
|
{
|
871
940
|
static_cast<const _Derived*>(this)->_saveModel(writer, fullModel, extra_data);
|
@@ -876,6 +945,17 @@ namespace tomoto
|
|
876
945
|
static_cast<_Derived*>(this)->_loadModel(reader, extra_data);
|
877
946
|
static_cast<_Derived*>(this)->prepare(false);
|
878
947
|
}
|
948
|
+
|
949
|
+
std::array<uint64_t, 2> getHash() const override
|
950
|
+
{
|
951
|
+
std::array<uint64_t, 2> ret;
|
952
|
+
ret[0] = dict.computeHash(0);
|
953
|
+
const std::string s = static_cast<const _Derived*>(this)->tmid().str() + static_cast<const _Derived*>(this)->twid().str();
|
954
|
+
ret[0] = serializer::computeHashMany(ret[0], s, realV, globalStep, docs.size());
|
955
|
+
ret[1] = globalState.computeHash(0);
|
956
|
+
ret[1] = static_cast<const _Derived*>(this)->computeHash(ret[1]);
|
957
|
+
return ret;
|
958
|
+
}
|
879
959
|
};
|
880
960
|
|
881
961
|
}
|
@@ -0,0 +1,102 @@
|
|
1
|
+
#include "Dictionary.h"
|
2
|
+
|
3
|
+
namespace tomoto
|
4
|
+
{
|
5
|
+
Dictionary::Dictionary() = default;
|
6
|
+
Dictionary::~Dictionary() = default;
|
7
|
+
|
8
|
+
Dictionary::Dictionary(const Dictionary&) = default;
|
9
|
+
Dictionary& Dictionary::operator=(const Dictionary&) = default;
|
10
|
+
|
11
|
+
Dictionary::Dictionary(Dictionary&&) noexcept = default;
|
12
|
+
Dictionary& Dictionary::operator=(Dictionary&&) noexcept = default;
|
13
|
+
|
14
|
+
Vid Dictionary::add(const std::string& word)
|
15
|
+
{
|
16
|
+
auto it = dict.find(word);
|
17
|
+
if (it == dict.end())
|
18
|
+
{
|
19
|
+
dict.emplace(word, (Vid)dict.size());
|
20
|
+
id2word.emplace_back(word);
|
21
|
+
return (Vid)(dict.size() - 1);
|
22
|
+
}
|
23
|
+
return it->second;
|
24
|
+
}
|
25
|
+
|
26
|
+
const std::string& Dictionary::toWord(Vid vid) const
|
27
|
+
{
|
28
|
+
assert(vid < id2word.size());
|
29
|
+
return id2word[vid];
|
30
|
+
}
|
31
|
+
|
32
|
+
Vid Dictionary::toWid(const std::string& word) const
|
33
|
+
{
|
34
|
+
auto it = dict.find(word);
|
35
|
+
if (it == dict.end()) return non_vocab_id;
|
36
|
+
return it->second;
|
37
|
+
}
|
38
|
+
|
39
|
+
void Dictionary::serializerWrite(std::ostream& writer) const
|
40
|
+
{
|
41
|
+
serializer::writeMany(writer, serializer::to_key("Dict"), id2word);
|
42
|
+
}
|
43
|
+
|
44
|
+
void Dictionary::serializerRead(std::istream& reader)
|
45
|
+
{
|
46
|
+
serializer::readMany(reader, serializer::to_key("Dict"), id2word);
|
47
|
+
for (size_t i = 0; i < id2word.size(); ++i)
|
48
|
+
{
|
49
|
+
dict.emplace(id2word[i], (Vid)i);
|
50
|
+
}
|
51
|
+
}
|
52
|
+
|
53
|
+
uint64_t Dictionary::computeHash(uint64_t seed) const
|
54
|
+
{
|
55
|
+
return serializer::computeHashMany(seed, id2word);
|
56
|
+
}
|
57
|
+
|
58
|
+
void Dictionary::swap(Dictionary& rhs)
|
59
|
+
{
|
60
|
+
std::swap(dict, rhs.dict);
|
61
|
+
std::swap(id2word, rhs.id2word);
|
62
|
+
}
|
63
|
+
|
64
|
+
void Dictionary::reorder(const std::vector<Vid>& order)
|
65
|
+
{
|
66
|
+
for (auto& p : dict)
|
67
|
+
{
|
68
|
+
p.second = order[p.second];
|
69
|
+
id2word[p.second] = p.first;
|
70
|
+
}
|
71
|
+
}
|
72
|
+
|
73
|
+
const std::vector<std::string>& Dictionary::getRaw() const
|
74
|
+
{
|
75
|
+
return id2word;
|
76
|
+
}
|
77
|
+
|
78
|
+
Vid Dictionary::mapToNewDict(Vid v, const Dictionary& newDict) const
|
79
|
+
{
|
80
|
+
return newDict.toWid(toWord(v));
|
81
|
+
}
|
82
|
+
|
83
|
+
std::vector<Vid> Dictionary::mapToNewDict(const std::vector<Vid>& v, const Dictionary& newDict) const
|
84
|
+
{
|
85
|
+
std::vector<Vid> r(v.size());
|
86
|
+
for (size_t i = 0; i < v.size(); ++i)
|
87
|
+
{
|
88
|
+
r[i] = mapToNewDict(v[i], newDict);
|
89
|
+
}
|
90
|
+
return r;
|
91
|
+
}
|
92
|
+
|
93
|
+
std::vector<Vid> Dictionary::mapToNewDictAdd(const std::vector<Vid>& v, Dictionary& newDict) const
|
94
|
+
{
|
95
|
+
std::vector<Vid> r(v.size());
|
96
|
+
for (size_t i = 0; i < v.size(); ++i)
|
97
|
+
{
|
98
|
+
r[i] = mapToNewDict(v[i], newDict);
|
99
|
+
}
|
100
|
+
return r;
|
101
|
+
}
|
102
|
+
}
|
@@ -12,8 +12,9 @@ namespace tomoto
|
|
12
12
|
{
|
13
13
|
using Vid = uint32_t;
|
14
14
|
static constexpr Vid non_vocab_id = (Vid)-1;
|
15
|
+
static constexpr Vid rm_vocab_id = (Vid)-2;
|
15
16
|
using Tid = uint16_t;
|
16
|
-
static constexpr
|
17
|
+
static constexpr Tid non_topic_id = (Tid)-1;
|
17
18
|
using Float = float;
|
18
19
|
|
19
20
|
struct VidPair : public std::pair<Vid, Vid>
|
@@ -27,91 +28,41 @@ namespace tomoto
|
|
27
28
|
std::unordered_map<std::string, Vid> dict;
|
28
29
|
std::vector<std::string> id2word;
|
29
30
|
public:
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
31
|
+
|
32
|
+
Dictionary();
|
33
|
+
~Dictionary();
|
34
|
+
|
35
|
+
Dictionary(const Dictionary&);
|
36
|
+
Dictionary& operator=(const Dictionary&);
|
37
|
+
|
38
|
+
Dictionary(Dictionary&&) noexcept;
|
39
|
+
Dictionary& operator=(Dictionary&&) noexcept;
|
40
|
+
|
41
|
+
Vid add(const std::string& word);
|
41
42
|
|
42
43
|
size_t size() const { return dict.size(); }
|
43
44
|
|
44
|
-
const std::string& toWord(Vid vid) const
|
45
|
-
{
|
46
|
-
assert(vid < id2word.size());
|
47
|
-
return id2word[vid];
|
48
|
-
}
|
45
|
+
const std::string& toWord(Vid vid) const;
|
49
46
|
|
50
|
-
Vid toWid(const std::string& word) const
|
51
|
-
{
|
52
|
-
auto it = dict.find(word);
|
53
|
-
if (it == dict.end()) return non_vocab_id;
|
54
|
-
return it->second;
|
55
|
-
}
|
47
|
+
Vid toWid(const std::string& word) const;
|
56
48
|
|
57
|
-
void serializerWrite(std::ostream& writer) const
|
58
|
-
{
|
59
|
-
serializer::writeMany(writer, serializer::to_key("Dict"), id2word);
|
60
|
-
}
|
49
|
+
void serializerWrite(std::ostream& writer) const;
|
61
50
|
|
62
|
-
void serializerRead(std::istream& reader)
|
63
|
-
{
|
64
|
-
serializer::readMany(reader, serializer::to_key("Dict"), id2word);
|
65
|
-
for (size_t i = 0; i < id2word.size(); ++i)
|
66
|
-
{
|
67
|
-
dict.emplace(id2word[i], (Vid)i);
|
68
|
-
}
|
69
|
-
}
|
51
|
+
void serializerRead(std::istream& reader);
|
70
52
|
|
71
|
-
|
72
|
-
{
|
73
|
-
std::swap(dict, rhs.dict);
|
74
|
-
std::swap(id2word, rhs.id2word);
|
75
|
-
}
|
53
|
+
uint64_t computeHash(uint64_t seed) const;
|
76
54
|
|
77
|
-
void
|
78
|
-
{
|
79
|
-
for (auto& p : dict)
|
80
|
-
{
|
81
|
-
p.second = order[p.second];
|
82
|
-
id2word[p.second] = p.first;
|
83
|
-
}
|
84
|
-
}
|
55
|
+
void swap(Dictionary& rhs);
|
85
56
|
|
86
|
-
const std::vector<
|
87
|
-
{
|
88
|
-
return id2word;
|
89
|
-
}
|
57
|
+
void reorder(const std::vector<Vid>& order);
|
90
58
|
|
91
|
-
|
92
|
-
{
|
93
|
-
return newDict.toWid(toWord(v));
|
94
|
-
}
|
59
|
+
const std::vector<std::string>& getRaw() const;
|
95
60
|
|
96
|
-
|
97
|
-
{
|
98
|
-
std::vector<Vid> r(v.size());
|
99
|
-
for (size_t i = 0; i < v.size(); ++i)
|
100
|
-
{
|
101
|
-
r[i] = mapToNewDict(v[i], newDict);
|
102
|
-
}
|
103
|
-
return r;
|
104
|
-
}
|
61
|
+
Vid mapToNewDict(Vid v, const Dictionary& newDict) const;
|
105
62
|
|
106
|
-
std::vector<Vid>
|
107
|
-
|
108
|
-
|
109
|
-
for (size_t i = 0; i < v.size(); ++i)
|
110
|
-
{
|
111
|
-
r[i] = mapToNewDict(v[i], newDict);
|
112
|
-
}
|
113
|
-
return r;
|
114
|
-
}
|
63
|
+
std::vector<Vid> mapToNewDict(const std::vector<Vid>& v, const Dictionary& newDict) const;
|
64
|
+
|
65
|
+
std::vector<Vid> mapToNewDictAdd(const std::vector<Vid>& v, Dictionary& newDict) const;
|
115
66
|
};
|
116
67
|
|
117
68
|
}
|
@@ -126,4 +77,4 @@ namespace std
|
|
126
77
|
return hash<size_t>{}(p.first) ^ hash<size_t>{}(p.second);
|
127
78
|
}
|
128
79
|
};
|
129
|
-
}
|
80
|
+
}
|
@@ -116,7 +116,7 @@ namespace Eigen
|
|
116
116
|
|
117
117
|
EIGEN_STRONG_INLINE Packet4f p_bool2float(const Packet4f& a)
{
	// `a` presumably holds per-lane boolean masks (all bits set or clear) in a float register.
	// Reinterpret the bits as int32 with the named NEON intrinsic instead of a C-style vector
	// cast, keep bit 0 (1 or 0) and convert that to 1.0f / 0.0f.
	return vcvtq_f32_s32(vandq_s32(vreinterpretq_s32_f32(a), vdupq_n_s32(1)));
}
|
121
121
|
|
122
122
|
EIGEN_STRONG_INLINE Packet4f p_bool2float(const Packet4i& a)
|
@@ -0,0 +1,146 @@
|
|
1
|
+
#include <cstdint>
|
2
|
+
#include "Mmap.h"
|
3
|
+
|
4
|
+
namespace tomoto
{
	namespace utils
	{
		/**
		 * Converts UTF-8 text to UTF-16, emitting a surrogate pair for each code point above U+FFFF.
		 * Used here to build the wide-character path passed to CreateFileW on Windows.
		 *
		 * @throws std::invalid_argument on an invalid lead byte, a truncated sequence,
		 *         a bad continuation byte, or a code point above U+10FFFF.
		 */
		static std::u16string utf8To16(const std::string& str)
		{
			std::u16string ret;
			for (auto it = str.begin(); it != str.end(); ++it)
			{
				// Advances to the next byte and returns its 6 payload bits; it must be a 10xxxxxx byte.
				auto nextContinuation = [&]() -> uint32_t
				{
					if (++it == str.end()) throw std::invalid_argument{ "unexpected ending" };
					const uint32_t b = (uint8_t)*it;
					if ((b & 0xC0) != 0x80) throw std::invalid_argument{ "unexpected trailing byte" };
					return b & 0x3F;
				};

				uint32_t code = 0;
				const uint32_t lead = (uint8_t)*it;
				if ((lead & 0xF8) == 0xF0)
				{
					code = (lead & 0x07) << 18;
					code |= nextContinuation() << 12;
					code |= nextContinuation() << 6;
					code |= nextContinuation();
				}
				else if ((lead & 0xF0) == 0xE0)
				{
					code = (lead & 0x0F) << 12;
					code |= nextContinuation() << 6;
					code |= nextContinuation();
				}
				else if ((lead & 0xE0) == 0xC0)
				{
					code = (lead & 0x1F) << 6;
					code |= nextContinuation();
				}
				else if ((lead & 0x80) == 0x00)
				{
					code = lead;
				}
				else
				{
					throw std::invalid_argument{ "unicode error" };
				}

				if (code < 0x10000)
				{
					ret.push_back((char16_t)code);
				}
				else if (code <= 0x10FFFF) // U+10FFFF is the last valid code point (was wrongly excluded)
				{
					code -= 0x10000;
					ret.push_back((char16_t)(0xD800 | (code >> 10)));
					ret.push_back((char16_t)(0xDC00 | (code & 0x3FF)));
				}
				else
				{
					throw std::invalid_argument{ "unicode error" };
				}
			}
			return ret;
		}
	}
}
|
73
|
+
|
74
|
+
namespace tomoto
|
75
|
+
{
|
76
|
+
namespace utils
|
77
|
+
{
|
78
|
+
MMap::MMap(const std::string& filepath)
|
79
|
+
{
|
80
|
+
#ifdef _WIN32
|
81
|
+
hFile = CreateFileW((const wchar_t*)utf8To16(filepath).c_str(), GENERIC_READ, FILE_SHARE_READ, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_READONLY, nullptr);
|
82
|
+
if (hFile == INVALID_HANDLE_VALUE) throw std::ios_base::failure("Cannot open '" + filepath + "'");
|
83
|
+
hFileMap = CreateFileMapping(hFile, nullptr, PAGE_READONLY, 0, 0, nullptr);
|
84
|
+
if (hFileMap == nullptr) throw std::ios_base::failure("Cannot open '" + filepath + "' Code:" + std::to_string(GetLastError()));
|
85
|
+
view = (const char*)MapViewOfFile(hFileMap, FILE_MAP_READ, 0, 0, 0);
|
86
|
+
if (!view) throw std::ios_base::failure("Cannot MapViewOfFile() Code:" + std::to_string(GetLastError()));
|
87
|
+
DWORD high;
|
88
|
+
len = GetFileSize(hFile, &high);
|
89
|
+
len |= (uint64_t)high << 32;
|
90
|
+
#else
|
91
|
+
fd = open(filepath.c_str(), O_RDONLY);
|
92
|
+
if (fd == -1) throw std::ios_base::failure("Cannot open '" + filepath + "'");
|
93
|
+
struct stat sb;
|
94
|
+
if (fstat(fd, &sb) < 0) throw std::ios_base::failure("Cannot open '" + filepath + "'");
|
95
|
+
len = sb.st_size;
|
96
|
+
view = (const char*)mmap(nullptr, len, PROT_READ, MAP_PRIVATE, fd, 0);
|
97
|
+
if (view == MAP_FAILED) throw std::ios_base::failure("Mapping failed");
|
98
|
+
#endif
|
99
|
+
}
|
100
|
+
|
101
|
+
#ifdef _WIN32
|
102
|
+
MMap::MMap(MMap&& o) noexcept
|
103
|
+
: view{ o.view }, len{ o.len }
|
104
|
+
{
|
105
|
+
o.view = nullptr;
|
106
|
+
std::swap(hFile, o.hFile);
|
107
|
+
std::swap(hFileMap, o.hFileMap);
|
108
|
+
}
|
109
|
+
#else
|
110
|
+
MMap::MMap(MMap&& o) noexcept
|
111
|
+
: len{ o.len }, fd{ std::move(o.fd) }
|
112
|
+
{
|
113
|
+
std::swap(view, o.view);
|
114
|
+
}
|
115
|
+
#endif
|
116
|
+
|
117
|
+
MMap& MMap::operator=(MMap&& o) noexcept
|
118
|
+
{
|
119
|
+
std::swap(view, o.view);
|
120
|
+
std::swap(len, o.len);
|
121
|
+
#ifdef _WIN32
|
122
|
+
std::swap(hFile, o.hFile);
|
123
|
+
std::swap(hFileMap, o.hFileMap);
|
124
|
+
#else
|
125
|
+
std::swap(fd, o.fd);
|
126
|
+
#endif
|
127
|
+
return *this;
|
128
|
+
}
|
129
|
+
|
130
|
+
MMap::~MMap()
|
131
|
+
{
|
132
|
+
#ifdef _WIN32
|
133
|
+
if (hFileMap)
|
134
|
+
{
|
135
|
+
UnmapViewOfFile(view);
|
136
|
+
view = nullptr;
|
137
|
+
}
|
138
|
+
#else
|
139
|
+
if (view)
|
140
|
+
{
|
141
|
+
munmap((void*)view, len);
|
142
|
+
}
|
143
|
+
#endif
|
144
|
+
}
|
145
|
+
}
|
146
|
+
}
|