tomoto 0.1.4 → 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/ext/tomoto/ct.cpp +8 -4
- data/ext/tomoto/dmr.cpp +10 -4
- data/ext/tomoto/dt.cpp +13 -4
- data/ext/tomoto/extconf.rb +1 -1
- data/ext/tomoto/gdmr.cpp +14 -6
- data/ext/tomoto/hdp.cpp +9 -4
- data/ext/tomoto/hlda.cpp +9 -4
- data/ext/tomoto/hpa.cpp +9 -4
- data/ext/tomoto/lda.cpp +8 -4
- data/ext/tomoto/llda.cpp +8 -4
- data/ext/tomoto/mglda.cpp +11 -1
- data/ext/tomoto/pa.cpp +9 -4
- data/ext/tomoto/plda.cpp +8 -4
- data/ext/tomoto/slda.cpp +13 -5
- data/lib/tomoto/gdmr.rb +2 -2
- data/lib/tomoto/version.rb +1 -1
- data/vendor/EigenRand/EigenRand/Core.h +6 -1107
- data/vendor/EigenRand/EigenRand/Dists/Basic.h +490 -43
- data/vendor/EigenRand/EigenRand/Dists/Discrete.h +916 -285
- data/vendor/EigenRand/EigenRand/Dists/GammaPoisson.h +85 -36
- data/vendor/EigenRand/EigenRand/Dists/NormalExp.h +1038 -290
- data/vendor/EigenRand/EigenRand/EigenRand +2 -2
- data/vendor/EigenRand/EigenRand/Macro.h +4 -4
- data/vendor/EigenRand/EigenRand/MorePacketMath.h +54 -22
- data/vendor/EigenRand/EigenRand/MvDists/Multinomial.h +222 -0
- data/vendor/EigenRand/EigenRand/MvDists/MvNormal.h +492 -0
- data/vendor/EigenRand/EigenRand/PacketFilter.h +2 -2
- data/vendor/EigenRand/EigenRand/PacketRandomEngine.h +2 -2
- data/vendor/EigenRand/EigenRand/RandUtils.h +65 -11
- data/vendor/EigenRand/EigenRand/doc.h +142 -25
- data/vendor/EigenRand/LICENSE +1 -1
- data/vendor/EigenRand/README.md +109 -24
- data/vendor/tomotopy/README.kr.rst +27 -6
- data/vendor/tomotopy/README.rst +29 -8
- data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +60 -12
- data/vendor/tomotopy/src/Labeling/FoRelevance.h +2 -2
- data/vendor/tomotopy/src/Labeling/Phraser.hpp +33 -21
- data/vendor/tomotopy/src/TopicModel/CT.h +8 -5
- data/vendor/tomotopy/src/TopicModel/CTModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/CTModel.hpp +29 -23
- data/vendor/tomotopy/src/TopicModel/DMR.h +33 -4
- data/vendor/tomotopy/src/TopicModel/DMRModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/DMRModel.hpp +231 -57
- data/vendor/tomotopy/src/TopicModel/DT.h +24 -5
- data/vendor/tomotopy/src/TopicModel/DTModel.cpp +2 -8
- data/vendor/tomotopy/src/TopicModel/DTModel.hpp +41 -28
- data/vendor/tomotopy/src/TopicModel/GDMR.h +31 -5
- data/vendor/tomotopy/src/TopicModel/GDMRModel.cpp +2 -7
- data/vendor/tomotopy/src/TopicModel/GDMRModel.hpp +211 -104
- data/vendor/tomotopy/src/TopicModel/HDP.h +11 -2
- data/vendor/tomotopy/src/TopicModel/HDPModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +52 -45
- data/vendor/tomotopy/src/TopicModel/HLDA.h +11 -2
- data/vendor/tomotopy/src/TopicModel/HLDAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +13 -16
- data/vendor/tomotopy/src/TopicModel/HPA.h +5 -2
- data/vendor/tomotopy/src/TopicModel/HPAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +51 -21
- data/vendor/tomotopy/src/TopicModel/LDA.h +9 -2
- data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +8 -8
- data/vendor/tomotopy/src/TopicModel/LDAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +70 -28
- data/vendor/tomotopy/src/TopicModel/LLDA.h +1 -2
- data/vendor/tomotopy/src/TopicModel/LLDAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/LLDAModel.hpp +22 -12
- data/vendor/tomotopy/src/TopicModel/MGLDA.h +12 -3
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.cpp +2 -10
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +42 -19
- data/vendor/tomotopy/src/TopicModel/PA.h +9 -4
- data/vendor/tomotopy/src/TopicModel/PAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/PAModel.hpp +48 -25
- data/vendor/tomotopy/src/TopicModel/PLDA.h +13 -2
- data/vendor/tomotopy/src/TopicModel/PLDAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/PLDAModel.hpp +27 -19
- data/vendor/tomotopy/src/TopicModel/PT.h +12 -5
- data/vendor/tomotopy/src/TopicModel/PTModel.cpp +2 -3
- data/vendor/tomotopy/src/TopicModel/PTModel.hpp +29 -14
- data/vendor/tomotopy/src/TopicModel/SLDA.h +18 -6
- data/vendor/tomotopy/src/TopicModel/SLDAModel.cpp +2 -10
- data/vendor/tomotopy/src/TopicModel/SLDAModel.hpp +93 -43
- data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +58 -23
- data/vendor/tomotopy/src/Utils/AliasMethod.hpp +6 -6
- data/vendor/tomotopy/src/Utils/Dictionary.h +11 -0
- data/vendor/tomotopy/src/Utils/SharedString.hpp +26 -1
- data/vendor/tomotopy/src/Utils/Trie.hpp +46 -21
- data/vendor/tomotopy/src/Utils/Utils.hpp +99 -14
- data/vendor/tomotopy/src/Utils/exception.h +1 -1
- data/vendor/tomotopy/src/Utils/math.h +5 -7
- data/vendor/tomotopy/src/Utils/serializer.hpp +329 -201
- data/vendor/tomotopy/src/Utils/text.hpp +8 -0
- data/vendor/tomotopy/src/Utils/tvector.hpp +49 -7
- metadata +9 -7
|
@@ -16,6 +16,9 @@ namespace tomoto
|
|
|
16
16
|
using RandGen = Eigen::Rand::P8_mt19937_64<uint32_t>;
|
|
17
17
|
using ScalarRandGen = Eigen::Rand::UniversalRandomEngine<uint32_t, std::mt19937_64>;
|
|
18
18
|
|
|
19
|
+
using Vector = Eigen::Matrix<Float, -1, 1>;
|
|
20
|
+
using Matrix = Eigen::Matrix<Float, -1, -1>;
|
|
21
|
+
|
|
19
22
|
struct RawDocKernel
|
|
20
23
|
{
|
|
21
24
|
Float weight = 1;
|
|
@@ -59,8 +62,8 @@ namespace tomoto
|
|
|
59
62
|
const _Ty& getMisc(const std::string& name) const
|
|
60
63
|
{
|
|
61
64
|
auto it = misc.find(name);
|
|
62
|
-
if (it == misc.end()) throw
|
|
63
|
-
if (!it->second.template is<_Ty>()) throw
|
|
65
|
+
if (it == misc.end()) throw exc::InvalidArgument{ "There is no value named `" + name + "` in misc data" };
|
|
66
|
+
if (!it->second.template is<_Ty>()) throw exc::InvalidArgument{ "Value named `" + name + "` is not in right type." };
|
|
64
67
|
return it->second.template get<_Ty>();
|
|
65
68
|
}
|
|
66
69
|
|
|
@@ -69,11 +72,13 @@ namespace tomoto
|
|
|
69
72
|
{
|
|
70
73
|
auto it = misc.find(name);
|
|
71
74
|
if (it == misc.end()) return {};
|
|
72
|
-
if (!it->second.template is<_Ty>()) throw
|
|
75
|
+
if (!it->second.template is<_Ty>()) throw exc::InvalidArgument{ "Value named `" + name + "` is not in right type." };
|
|
73
76
|
return it->second.template get<_Ty>();
|
|
74
77
|
}
|
|
75
78
|
};
|
|
76
79
|
|
|
80
|
+
class ITopicModel;
|
|
81
|
+
|
|
77
82
|
class DocumentBase : public RawDocKernel
|
|
78
83
|
{
|
|
79
84
|
public:
|
|
@@ -95,6 +100,11 @@ namespace tomoto
|
|
|
95
100
|
|
|
96
101
|
virtual ~DocumentBase() {}
|
|
97
102
|
|
|
103
|
+
virtual RawDoc::MiscType makeMisc(const ITopicModel*) const
|
|
104
|
+
{
|
|
105
|
+
return {};
|
|
106
|
+
}
|
|
107
|
+
|
|
98
108
|
virtual operator RawDoc() const
|
|
99
109
|
{
|
|
100
110
|
RawDoc raw{ *this };
|
|
@@ -110,6 +120,7 @@ namespace tomoto
|
|
|
110
120
|
raw.words[i] = words[wOrder[i]];
|
|
111
121
|
}
|
|
112
122
|
}
|
|
123
|
+
//raw.misc = makeMisc();
|
|
113
124
|
return raw;
|
|
114
125
|
}
|
|
115
126
|
|
|
@@ -212,6 +223,9 @@ namespace tomoto
|
|
|
212
223
|
const std::vector<uint8_t>* extra_data = nullptr) const = 0;
|
|
213
224
|
virtual void loadModel(std::istream& reader,
|
|
214
225
|
std::vector<uint8_t>* extra_data = nullptr) = 0;
|
|
226
|
+
|
|
227
|
+
virtual std::unique_ptr<ITopicModel> copy() const = 0;
|
|
228
|
+
|
|
215
229
|
virtual const DocumentBase* getDoc(size_t docId) const = 0;
|
|
216
230
|
virtual size_t getDocIdByUid(const std::string& docUid) const = 0;
|
|
217
231
|
|
|
@@ -242,12 +256,12 @@ namespace tomoto
|
|
|
242
256
|
virtual void prepare(bool initDocs = true, size_t minWordCnt = 0, size_t minWordDf = 0, size_t removeTopN = 0) = 0;
|
|
243
257
|
|
|
244
258
|
virtual size_t getK() const = 0;
|
|
245
|
-
virtual std::vector<Float> getWidsByTopic(size_t tid) const = 0;
|
|
259
|
+
virtual std::vector<Float> getWidsByTopic(size_t tid, bool normalize = true) const = 0;
|
|
246
260
|
virtual std::vector<std::pair<std::string, Float>> getWordsByTopicSorted(size_t tid, size_t topN) const = 0;
|
|
247
261
|
|
|
248
262
|
virtual std::vector<std::pair<std::string, Float>> getWordsByDocSorted(const DocumentBase* doc, size_t topN) const = 0;
|
|
249
263
|
|
|
250
|
-
virtual std::vector<Float> getTopicsByDoc(const DocumentBase* doc) const = 0;
|
|
264
|
+
virtual std::vector<Float> getTopicsByDoc(const DocumentBase* doc, bool normalize = true) const = 0;
|
|
251
265
|
virtual std::vector<std::pair<Tid, Float>> getTopicsByDocSorted(const DocumentBase* doc, size_t topN) const = 0;
|
|
252
266
|
virtual std::vector<double> infer(const std::vector<DocumentBase*>& docs, size_t maxIter, Float tolerance, size_t numWorkers, ParallelScheme ps, bool together) const = 0;
|
|
253
267
|
virtual ~ITopicModel() {}
|
|
@@ -308,7 +322,7 @@ namespace tomoto
|
|
|
308
322
|
size_t maxThreads[(size_t)ParallelScheme::size] = { 0, };
|
|
309
323
|
size_t minWordCf = 0, minWordDf = 0, removeTopN = 0;
|
|
310
324
|
|
|
311
|
-
std::unique_ptr<ThreadPool
|
|
325
|
+
PreventCopy<std::unique_ptr<ThreadPool>> cachedPool;
|
|
312
326
|
|
|
313
327
|
void _saveModel(std::ostream& writer, bool fullModel, const std::vector<uint8_t>* extra_data) const
|
|
314
328
|
{
|
|
@@ -373,7 +387,7 @@ namespace tomoto
|
|
|
373
387
|
{
|
|
374
388
|
if (doc.words.empty()) return -1;
|
|
375
389
|
if (!doc.docUid.empty() && uidMap.count(doc.docUid))
|
|
376
|
-
throw
|
|
390
|
+
throw exc::InvalidArgument{ "there is a document with uid = '" + std::string{ doc.docUid } + "' already." };
|
|
377
391
|
size_t maxWid = *std::max_element(doc.words.begin(), doc.words.end());
|
|
378
392
|
if (vocabCf.size() <= maxWid)
|
|
379
393
|
{
|
|
@@ -383,7 +397,7 @@ namespace tomoto
|
|
|
383
397
|
for (auto w : doc.words) ++vocabCf[w];
|
|
384
398
|
std::unordered_set<Vid> uniq{ doc.words.begin(), doc.words.end() };
|
|
385
399
|
for (auto w : uniq) ++vocabDf[w];
|
|
386
|
-
uidMap.emplace(doc.docUid, docs.size());
|
|
400
|
+
if (!doc.docUid.empty()) uidMap.emplace(doc.docUid, docs.size());
|
|
387
401
|
docs.emplace_back(std::forward<_DocTy>(doc));
|
|
388
402
|
return docs.size() - 1;
|
|
389
403
|
}
|
|
@@ -415,7 +429,7 @@ namespace tomoto
|
|
|
415
429
|
}
|
|
416
430
|
else
|
|
417
431
|
{
|
|
418
|
-
throw
|
|
432
|
+
throw exc::InvalidArgument{ "Either `words` or `rawWords` must be filled." };
|
|
419
433
|
}
|
|
420
434
|
return doc;
|
|
421
435
|
}
|
|
@@ -461,7 +475,19 @@ namespace tomoto
|
|
|
461
475
|
auto tx = [](_DocType& doc) { return &doc.words; };
|
|
462
476
|
tvector<Vid>::trade(words,
|
|
463
477
|
makeTransformIter(docs.begin(), tx),
|
|
464
|
-
makeTransformIter(docs.end(), tx)
|
|
478
|
+
makeTransformIter(docs.end(), tx)
|
|
479
|
+
);
|
|
480
|
+
}
|
|
481
|
+
|
|
482
|
+
void updateForCopy()
|
|
483
|
+
{
|
|
484
|
+
size_t offset = 0;
|
|
485
|
+
for (auto& doc : docs)
|
|
486
|
+
{
|
|
487
|
+
size_t size = doc.words.size();
|
|
488
|
+
doc.words = tvector<Vid>{ words.data() + offset, size };
|
|
489
|
+
offset += size;
|
|
490
|
+
}
|
|
465
491
|
}
|
|
466
492
|
|
|
467
493
|
size_t countRealN() const
|
|
@@ -529,7 +555,7 @@ namespace tomoto
|
|
|
529
555
|
}
|
|
530
556
|
}
|
|
531
557
|
|
|
532
|
-
int restoreFromTrainingError(const
|
|
558
|
+
int restoreFromTrainingError(const exc::TrainingError& e, ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
|
|
533
559
|
{
|
|
534
560
|
throw e;
|
|
535
561
|
}
|
|
@@ -539,6 +565,15 @@ namespace tomoto
|
|
|
539
565
|
{
|
|
540
566
|
}
|
|
541
567
|
|
|
568
|
+
TopicModel(const TopicModel&) = default;
|
|
569
|
+
|
|
570
|
+
std::unique_ptr<ITopicModel> copy() const override
|
|
571
|
+
{
|
|
572
|
+
auto ret = std::make_unique<_Derived>(*static_cast<const _Derived*>(this));
|
|
573
|
+
ret->updateForCopy();
|
|
574
|
+
return ret;
|
|
575
|
+
}
|
|
576
|
+
|
|
542
577
|
size_t getNumDocs() const override
|
|
543
578
|
{
|
|
544
579
|
return docs.size();
|
|
@@ -578,11 +613,11 @@ namespace tomoto
|
|
|
578
613
|
if ((_Flags & flags::shared_state)) return ParallelScheme::none;
|
|
579
614
|
return ParallelScheme::copy_merge;
|
|
580
615
|
case ParallelScheme::copy_merge:
|
|
581
|
-
if ((_Flags & flags::shared_state)) THROW_ERROR_WITH_INFO(
|
|
616
|
+
if ((_Flags & flags::shared_state)) THROW_ERROR_WITH_INFO(exc::InvalidArgument,
|
|
582
617
|
std::string{ "This model doesn't provide ParallelScheme::" } + toString(ps));
|
|
583
618
|
break;
|
|
584
619
|
case ParallelScheme::partition:
|
|
585
|
-
if (!(_Flags & flags::partitioned_multisampling)) THROW_ERROR_WITH_INFO(
|
|
620
|
+
if (!(_Flags & flags::partitioned_multisampling)) THROW_ERROR_WITH_INFO(exc::InvalidArgument,
|
|
586
621
|
std::string{ "This model doesn't provide ParallelScheme::" } + toString(ps));
|
|
587
622
|
break;
|
|
588
623
|
}
|
|
@@ -597,7 +632,7 @@ namespace tomoto
|
|
|
597
632
|
if (numWorkers == 1 || (_Flags & flags::shared_state)) ps = ParallelScheme::none;
|
|
598
633
|
if (!cachedPool || cachedPool->getNumWorkers() != numWorkers)
|
|
599
634
|
{
|
|
600
|
-
cachedPool = make_unique<ThreadPool>(numWorkers);
|
|
635
|
+
cachedPool = std::make_unique<ThreadPool>(numWorkers);
|
|
601
636
|
}
|
|
602
637
|
|
|
603
638
|
std::vector<_ModelState> localData;
|
|
@@ -647,7 +682,7 @@ namespace tomoto
|
|
|
647
682
|
}
|
|
648
683
|
break;
|
|
649
684
|
}
|
|
650
|
-
catch (const
|
|
685
|
+
catch (const exc::TrainingError& e)
|
|
651
686
|
{
|
|
652
687
|
std::cerr << e.what() << std::endl;
|
|
653
688
|
int ret = static_cast<_Derived*>(this)->restoreFromTrainingError(
|
|
@@ -675,14 +710,14 @@ namespace tomoto
|
|
|
675
710
|
return 0;
|
|
676
711
|
}
|
|
677
712
|
|
|
678
|
-
std::vector<Float> getWidsByTopic(size_t tid) const override
|
|
713
|
+
std::vector<Float> getWidsByTopic(size_t tid, bool normalize) const override
|
|
679
714
|
{
|
|
680
|
-
return static_cast<const _Derived*>(this)->_getWidsByTopic(tid);
|
|
715
|
+
return static_cast<const _Derived*>(this)->_getWidsByTopic(tid, normalize);
|
|
681
716
|
}
|
|
682
717
|
|
|
683
718
|
std::vector<std::pair<Vid, Float>> getWidsByTopicSorted(size_t tid, size_t topN) const
|
|
684
719
|
{
|
|
685
|
-
return extractTopN<Vid>(static_cast<const _Derived*>(this)->_getWidsByTopic(tid), topN);
|
|
720
|
+
return extractTopN<Vid>(static_cast<const _Derived*>(this)->_getWidsByTopic(tid, true), topN);
|
|
686
721
|
}
|
|
687
722
|
|
|
688
723
|
std::vector<std::pair<std::string, Float>> vid2String(const std::vector<std::pair<Vid, Float>>& vids) const
|
|
@@ -716,7 +751,7 @@ namespace tomoto
|
|
|
716
751
|
double getDocLL(const DocumentBase* doc) const override
|
|
717
752
|
{
|
|
718
753
|
auto* p = dynamic_cast<const DocType*>(doc);
|
|
719
|
-
if (!p) throw
|
|
754
|
+
if (!p) throw exc::InvalidArgument{ "wrong `doc` type." };
|
|
720
755
|
return static_cast<const _Derived*>(this)->getLLDocs(p, p + 1);
|
|
721
756
|
}
|
|
722
757
|
|
|
@@ -757,17 +792,17 @@ namespace tomoto
|
|
|
757
792
|
return static_cast<const _Derived*>(this)->template _infer<false, ParallelScheme::partition>(b, e, maxIter, tolerance, numWorkers);
|
|
758
793
|
}
|
|
759
794
|
}
|
|
760
|
-
THROW_ERROR_WITH_INFO(
|
|
795
|
+
THROW_ERROR_WITH_INFO(exc::InvalidArgument, "invalid ParallelScheme");
|
|
761
796
|
}
|
|
762
797
|
|
|
763
|
-
std::vector<Float> getTopicsByDoc(const DocumentBase* doc) const override
|
|
798
|
+
std::vector<Float> getTopicsByDoc(const DocumentBase* doc, bool normalize) const override
|
|
764
799
|
{
|
|
765
|
-
return static_cast<const _Derived*>(this)->getTopicsByDoc(*static_cast<const DocType*>(doc));
|
|
800
|
+
return static_cast<const _Derived*>(this)->getTopicsByDoc(*static_cast<const DocType*>(doc), normalize);
|
|
766
801
|
}
|
|
767
802
|
|
|
768
803
|
std::vector<std::pair<Tid, Float>> getTopicsByDocSorted(const DocumentBase* doc, size_t topN) const override
|
|
769
804
|
{
|
|
770
|
-
return extractTopN<Tid>(getTopicsByDoc(doc), topN);
|
|
805
|
+
return extractTopN<Tid>(getTopicsByDoc(doc, true), topN);
|
|
771
806
|
}
|
|
772
807
|
|
|
773
808
|
const DocumentBase* getDoc(size_t docId) const override
|
|
@@ -35,8 +35,8 @@ namespace tomoto
|
|
|
35
35
|
bitsize = o.bitsize;
|
|
36
36
|
if (msize)
|
|
37
37
|
{
|
|
38
|
-
arr = make_unique<_Precision[]>(1 << bitsize);
|
|
39
|
-
alias = make_unique<size_t[]>(1 << bitsize);
|
|
38
|
+
arr = std::make_unique<_Precision[]>(1 << bitsize);
|
|
39
|
+
alias = std::make_unique<size_t[]>(1 << bitsize);
|
|
40
40
|
|
|
41
41
|
std::copy(o.arr.get(), o.arr.get() + (1 << bitsize), arr.get());
|
|
42
42
|
std::copy(o.alias.get(), o.alias.get() + (1 << bitsize), alias.get());
|
|
@@ -70,7 +70,7 @@ namespace tomoto
|
|
|
70
70
|
sum += *it;
|
|
71
71
|
}
|
|
72
72
|
|
|
73
|
-
if (!std::isfinite(sum)) THROW_ERROR_WITH_INFO(
|
|
73
|
+
if (!std::isfinite(sum)) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "cannot build NaN value distribution");
|
|
74
74
|
|
|
75
75
|
// ceil to power of 2
|
|
76
76
|
nbsize = log2_ceil(msize);
|
|
@@ -78,15 +78,15 @@ namespace tomoto
|
|
|
78
78
|
|
|
79
79
|
if (nbsize != bitsize)
|
|
80
80
|
{
|
|
81
|
-
arr = make_unique<_Precision[]>(psize);
|
|
81
|
+
arr = std::make_unique<_Precision[]>(psize);
|
|
82
82
|
std::fill(arr.get(), arr.get() + psize, 0);
|
|
83
|
-
alias = make_unique<size_t[]>(psize);
|
|
83
|
+
alias = std::make_unique<size_t[]>(psize);
|
|
84
84
|
bitsize = nbsize;
|
|
85
85
|
}
|
|
86
86
|
|
|
87
87
|
sum /= psize;
|
|
88
88
|
|
|
89
|
-
auto f = make_unique<double[]>(psize);
|
|
89
|
+
auto f = std::make_unique<double[]>(psize);
|
|
90
90
|
auto pf = f.get();
|
|
91
91
|
for (auto it = first; it != last; ++it, ++pf)
|
|
92
92
|
{
|
|
@@ -13,6 +13,7 @@ namespace tomoto
|
|
|
13
13
|
using Vid = uint32_t;
|
|
14
14
|
static constexpr Vid non_vocab_id = (Vid)-1;
|
|
15
15
|
using Tid = uint16_t;
|
|
16
|
+
static constexpr Vid non_topic_id = (Tid)-1;
|
|
16
17
|
using Float = float;
|
|
17
18
|
|
|
18
19
|
struct VidPair : public std::pair<Vid, Vid>
|
|
@@ -101,6 +102,16 @@ namespace tomoto
|
|
|
101
102
|
}
|
|
102
103
|
return r;
|
|
103
104
|
}
|
|
105
|
+
|
|
106
|
+
std::vector<Vid> mapToNewDictAdd(const std::vector<Vid>& v, Dictionary& newDict) const
|
|
107
|
+
{
|
|
108
|
+
std::vector<Vid> r(v.size());
|
|
109
|
+
for (size_t i = 0; i < v.size(); ++i)
|
|
110
|
+
{
|
|
111
|
+
r[i] = mapToNewDict(v[i], newDict);
|
|
112
|
+
}
|
|
113
|
+
return r;
|
|
114
|
+
}
|
|
104
115
|
};
|
|
105
116
|
|
|
106
117
|
}
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
#pragma once
|
|
2
2
|
|
|
3
3
|
#include <string>
|
|
4
|
+
#include "serializer.hpp"
|
|
4
5
|
|
|
5
6
|
namespace tomoto
|
|
6
7
|
{
|
|
@@ -109,7 +110,7 @@ namespace tomoto
|
|
|
109
110
|
|
|
110
111
|
bool empty() const
|
|
111
112
|
{
|
|
112
|
-
return ptr == nullptr;
|
|
113
|
+
return ptr == nullptr || size() == 0;
|
|
113
114
|
}
|
|
114
115
|
|
|
115
116
|
operator std::string() const
|
|
@@ -167,6 +168,30 @@ namespace tomoto
|
|
|
167
168
|
return !operator==(o);
|
|
168
169
|
}
|
|
169
170
|
};
|
|
171
|
+
|
|
172
|
+
namespace serializer
|
|
173
|
+
{
|
|
174
|
+
template<>
|
|
175
|
+
struct Serializer<SharedString>
|
|
176
|
+
{
|
|
177
|
+
using VTy = SharedString;
|
|
178
|
+
void write(std::ostream& ostr, const VTy& v)
|
|
179
|
+
{
|
|
180
|
+
writeToStream(ostr, (uint32_t)v.size());
|
|
181
|
+
if (!ostr.write((const char*)v.data(), v.size()))
|
|
182
|
+
throw std::ios_base::failure(std::string("writing type 'SharedString' is failed"));
|
|
183
|
+
}
|
|
184
|
+
|
|
185
|
+
void read(std::istream& istr, VTy& v)
|
|
186
|
+
{
|
|
187
|
+
auto size = readFromStream<uint32_t>(istr);
|
|
188
|
+
std::vector<char> t(size);
|
|
189
|
+
if (!istr.read((char*)t.data(), t.size()))
|
|
190
|
+
throw std::ios_base::failure(std::string("reading type 'SharedString' is failed"));
|
|
191
|
+
v = SharedString{ t.data(), t.data() + t.size() };
|
|
192
|
+
}
|
|
193
|
+
};
|
|
194
|
+
}
|
|
170
195
|
}
|
|
171
196
|
|
|
172
197
|
namespace std
|
|
@@ -4,6 +4,7 @@
|
|
|
4
4
|
#include <deque>
|
|
5
5
|
#include <functional>
|
|
6
6
|
#include <iterator>
|
|
7
|
+
#include "serializer.hpp"
|
|
7
8
|
|
|
8
9
|
namespace tomoto
|
|
9
10
|
{
|
|
@@ -24,6 +25,17 @@ namespace tomoto
|
|
|
24
25
|
if (it == this->end()) return this->emplace(key, typename _Map::mapped_type{}).first->second;
|
|
25
26
|
else return it->second;
|
|
26
27
|
}
|
|
28
|
+
|
|
29
|
+
void serializerWrite(std::ostream& os) const
|
|
30
|
+
{
|
|
31
|
+
serializer::writeMany(os, (const _Map&)*this);
|
|
32
|
+
}
|
|
33
|
+
|
|
34
|
+
template<typename _Istr>
|
|
35
|
+
void serializerRead(_Istr& is)
|
|
36
|
+
{
|
|
37
|
+
serializer::readMany(is, (_Map&)*this);
|
|
38
|
+
}
|
|
27
39
|
};
|
|
28
40
|
|
|
29
41
|
template<class _Map, class _Node>
|
|
@@ -50,10 +62,13 @@ namespace tomoto
|
|
|
50
62
|
struct Trie
|
|
51
63
|
{
|
|
52
64
|
using Node = typename std::conditional<std::is_same<_Trie, void>::value, Trie, _Trie>::type;
|
|
65
|
+
using Key = _Key;
|
|
66
|
+
using KeyStore = _KeyStore;
|
|
53
67
|
using iterator = TrieIterator<_KeyStore, Node>;
|
|
54
68
|
_KeyStore next = {};
|
|
55
|
-
int32_t fail = 0;
|
|
56
69
|
_Value val = {};
|
|
70
|
+
int32_t fail = 0;
|
|
71
|
+
uint32_t depth = 0;
|
|
57
72
|
|
|
58
73
|
Trie() {}
|
|
59
74
|
~Trie() {}
|
|
@@ -84,13 +99,14 @@ namespace tomoto
|
|
|
84
99
|
if (first == last)
|
|
85
100
|
{
|
|
86
101
|
if (!val) val = _val;
|
|
87
|
-
return
|
|
102
|
+
return this;
|
|
88
103
|
}
|
|
89
104
|
|
|
90
105
|
auto v = *first;
|
|
91
106
|
if (!getNext(v))
|
|
92
107
|
{
|
|
93
108
|
next[v] = alloc() - this;
|
|
109
|
+
getNext(v)->depth = depth + 1;
|
|
94
110
|
}
|
|
95
111
|
return getNext(v)->build(++first, last, _val, alloc);
|
|
96
112
|
}
|
|
@@ -104,50 +120,48 @@ namespace tomoto
|
|
|
104
120
|
return nullptr;
|
|
105
121
|
}
|
|
106
122
|
|
|
107
|
-
template<
|
|
108
|
-
void
|
|
123
|
+
template<typename _Fn>
|
|
124
|
+
void traverse_with_keys(_Fn&& fn, std::vector<_Key>& rkeys) const
|
|
109
125
|
{
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
if (func(val)) return;
|
|
113
|
-
}
|
|
126
|
+
fn((Node*)this, rkeys);
|
|
127
|
+
|
|
114
128
|
for (auto& p : next)
|
|
115
129
|
{
|
|
116
|
-
if (
|
|
130
|
+
if (p.second)
|
|
117
131
|
{
|
|
118
|
-
|
|
132
|
+
rkeys.emplace_back(p.first);
|
|
133
|
+
getNext(p.first)->traverse_with_keys(fn, rkeys);
|
|
134
|
+
rkeys.pop_back();
|
|
119
135
|
}
|
|
120
136
|
}
|
|
121
|
-
return;
|
|
122
137
|
}
|
|
123
138
|
|
|
124
139
|
template<typename _Fn>
|
|
125
|
-
void
|
|
140
|
+
void traverse_with_keys_post(_Fn&& fn, std::vector<_Key>& rkeys) const
|
|
126
141
|
{
|
|
127
|
-
fn((Node*)this, rkeys);
|
|
128
|
-
|
|
129
142
|
for (auto& p : next)
|
|
130
143
|
{
|
|
131
|
-
if (p.
|
|
144
|
+
if (p.second)
|
|
132
145
|
{
|
|
133
146
|
rkeys.emplace_back(p.first);
|
|
134
|
-
getNext(p.first)->
|
|
147
|
+
getNext(p.first)->traverse_with_keys_post(fn, rkeys);
|
|
135
148
|
rkeys.pop_back();
|
|
136
149
|
}
|
|
137
150
|
}
|
|
151
|
+
fn((Node*)this, rkeys);
|
|
138
152
|
}
|
|
139
153
|
|
|
140
154
|
template<class _Iterator>
|
|
141
|
-
std::pair<
|
|
155
|
+
std::pair<Node*, size_t> findMaximumMatch(_Iterator begin, _Iterator end, size_t idxCnt = 0) const
|
|
142
156
|
{
|
|
143
|
-
if (begin == end) return std::make_pair(
|
|
157
|
+
if (begin == end) return std::make_pair((Node*)this, idxCnt);
|
|
144
158
|
auto n = getNext(*begin);
|
|
145
159
|
if (n)
|
|
146
160
|
{
|
|
147
161
|
auto v = n->findMaximumMatch(++begin, end, idxCnt + 1);
|
|
148
|
-
if (v.first) return v;
|
|
162
|
+
if (v.first->val) return v;
|
|
149
163
|
}
|
|
150
|
-
return std::make_pair(
|
|
164
|
+
return std::make_pair((Node*)this, idxCnt);
|
|
151
165
|
}
|
|
152
166
|
|
|
153
167
|
Node* findFail(_Key i) const
|
|
@@ -172,7 +186,7 @@ namespace tomoto
|
|
|
172
186
|
void fillFail()
|
|
173
187
|
{
|
|
174
188
|
std::deque<Node*> dq;
|
|
175
|
-
for (dq.emplace_back(this); !dq.empty(); dq.pop_front())
|
|
189
|
+
for (dq.emplace_back((Node*)this); !dq.empty(); dq.pop_front())
|
|
176
190
|
{
|
|
177
191
|
auto p = dq.front();
|
|
178
192
|
for (auto&& kv : p->next)
|
|
@@ -196,6 +210,17 @@ namespace tomoto
|
|
|
196
210
|
}
|
|
197
211
|
}
|
|
198
212
|
}
|
|
213
|
+
|
|
214
|
+
void serializerWrite(std::ostream& os) const
|
|
215
|
+
{
|
|
216
|
+
serializer::writeMany(os, next, val, fail, depth);
|
|
217
|
+
}
|
|
218
|
+
|
|
219
|
+
template<typename _Istr>
|
|
220
|
+
void serializerRead(_Istr& is)
|
|
221
|
+
{
|
|
222
|
+
serializer::readMany(is, next, val, fail, depth);
|
|
223
|
+
}
|
|
199
224
|
};
|
|
200
225
|
|
|
201
226
|
template<class _Key, class _Value, class _KeyStore = ConstAccess<std::map<_Key, int32_t>>>
|