tomoto 0.1.4 → 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/ext/tomoto/ct.cpp +8 -4
- data/ext/tomoto/dmr.cpp +10 -4
- data/ext/tomoto/dt.cpp +13 -4
- data/ext/tomoto/extconf.rb +1 -1
- data/ext/tomoto/gdmr.cpp +14 -6
- data/ext/tomoto/hdp.cpp +9 -4
- data/ext/tomoto/hlda.cpp +9 -4
- data/ext/tomoto/hpa.cpp +9 -4
- data/ext/tomoto/lda.cpp +8 -4
- data/ext/tomoto/llda.cpp +8 -4
- data/ext/tomoto/mglda.cpp +11 -1
- data/ext/tomoto/pa.cpp +9 -4
- data/ext/tomoto/plda.cpp +8 -4
- data/ext/tomoto/slda.cpp +13 -5
- data/lib/tomoto/gdmr.rb +2 -2
- data/lib/tomoto/version.rb +1 -1
- data/vendor/EigenRand/EigenRand/Core.h +6 -1107
- data/vendor/EigenRand/EigenRand/Dists/Basic.h +490 -43
- data/vendor/EigenRand/EigenRand/Dists/Discrete.h +916 -285
- data/vendor/EigenRand/EigenRand/Dists/GammaPoisson.h +85 -36
- data/vendor/EigenRand/EigenRand/Dists/NormalExp.h +1038 -290
- data/vendor/EigenRand/EigenRand/EigenRand +2 -2
- data/vendor/EigenRand/EigenRand/Macro.h +4 -4
- data/vendor/EigenRand/EigenRand/MorePacketMath.h +54 -22
- data/vendor/EigenRand/EigenRand/MvDists/Multinomial.h +222 -0
- data/vendor/EigenRand/EigenRand/MvDists/MvNormal.h +492 -0
- data/vendor/EigenRand/EigenRand/PacketFilter.h +2 -2
- data/vendor/EigenRand/EigenRand/PacketRandomEngine.h +2 -2
- data/vendor/EigenRand/EigenRand/RandUtils.h +65 -11
- data/vendor/EigenRand/EigenRand/doc.h +142 -25
- data/vendor/EigenRand/LICENSE +1 -1
- data/vendor/EigenRand/README.md +109 -24
- data/vendor/tomotopy/README.kr.rst +27 -6
- data/vendor/tomotopy/README.rst +29 -8
- data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +60 -12
- data/vendor/tomotopy/src/Labeling/FoRelevance.h +2 -2
- data/vendor/tomotopy/src/Labeling/Phraser.hpp +33 -21
- data/vendor/tomotopy/src/TopicModel/CT.h +8 -5
- data/vendor/tomotopy/src/TopicModel/CTModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/CTModel.hpp +29 -23
- data/vendor/tomotopy/src/TopicModel/DMR.h +33 -4
- data/vendor/tomotopy/src/TopicModel/DMRModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/DMRModel.hpp +231 -57
- data/vendor/tomotopy/src/TopicModel/DT.h +24 -5
- data/vendor/tomotopy/src/TopicModel/DTModel.cpp +2 -8
- data/vendor/tomotopy/src/TopicModel/DTModel.hpp +41 -28
- data/vendor/tomotopy/src/TopicModel/GDMR.h +31 -5
- data/vendor/tomotopy/src/TopicModel/GDMRModel.cpp +2 -7
- data/vendor/tomotopy/src/TopicModel/GDMRModel.hpp +211 -104
- data/vendor/tomotopy/src/TopicModel/HDP.h +11 -2
- data/vendor/tomotopy/src/TopicModel/HDPModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +52 -45
- data/vendor/tomotopy/src/TopicModel/HLDA.h +11 -2
- data/vendor/tomotopy/src/TopicModel/HLDAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +13 -16
- data/vendor/tomotopy/src/TopicModel/HPA.h +5 -2
- data/vendor/tomotopy/src/TopicModel/HPAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +51 -21
- data/vendor/tomotopy/src/TopicModel/LDA.h +9 -2
- data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +8 -8
- data/vendor/tomotopy/src/TopicModel/LDAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +70 -28
- data/vendor/tomotopy/src/TopicModel/LLDA.h +1 -2
- data/vendor/tomotopy/src/TopicModel/LLDAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/LLDAModel.hpp +22 -12
- data/vendor/tomotopy/src/TopicModel/MGLDA.h +12 -3
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.cpp +2 -10
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +42 -19
- data/vendor/tomotopy/src/TopicModel/PA.h +9 -4
- data/vendor/tomotopy/src/TopicModel/PAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/PAModel.hpp +48 -25
- data/vendor/tomotopy/src/TopicModel/PLDA.h +13 -2
- data/vendor/tomotopy/src/TopicModel/PLDAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/PLDAModel.hpp +27 -19
- data/vendor/tomotopy/src/TopicModel/PT.h +12 -5
- data/vendor/tomotopy/src/TopicModel/PTModel.cpp +2 -3
- data/vendor/tomotopy/src/TopicModel/PTModel.hpp +29 -14
- data/vendor/tomotopy/src/TopicModel/SLDA.h +18 -6
- data/vendor/tomotopy/src/TopicModel/SLDAModel.cpp +2 -10
- data/vendor/tomotopy/src/TopicModel/SLDAModel.hpp +93 -43
- data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +58 -23
- data/vendor/tomotopy/src/Utils/AliasMethod.hpp +6 -6
- data/vendor/tomotopy/src/Utils/Dictionary.h +11 -0
- data/vendor/tomotopy/src/Utils/SharedString.hpp +26 -1
- data/vendor/tomotopy/src/Utils/Trie.hpp +46 -21
- data/vendor/tomotopy/src/Utils/Utils.hpp +99 -14
- data/vendor/tomotopy/src/Utils/exception.h +1 -1
- data/vendor/tomotopy/src/Utils/math.h +5 -7
- data/vendor/tomotopy/src/Utils/serializer.hpp +329 -201
- data/vendor/tomotopy/src/Utils/text.hpp +8 -0
- data/vendor/tomotopy/src/Utils/tvector.hpp +49 -7
- metadata +9 -7
|
@@ -161,12 +161,19 @@ namespace tomoto
|
|
|
161
161
|
}
|
|
162
162
|
};
|
|
163
163
|
|
|
164
|
+
struct LDAArgs
|
|
165
|
+
{
|
|
166
|
+
size_t k = 1;
|
|
167
|
+
std::vector<Float> alpha = { 0.1 };
|
|
168
|
+
Float eta = 0.01;
|
|
169
|
+
size_t seed = std::random_device{}();
|
|
170
|
+
};
|
|
171
|
+
|
|
164
172
|
class ILDAModel : public ITopicModel
|
|
165
173
|
{
|
|
166
174
|
public:
|
|
167
175
|
using DefaultDocType = DocumentLDA<TermWeight::one>;
|
|
168
|
-
static ILDAModel* create(TermWeight _weight,
|
|
169
|
-
Float _alpha = 0.1, Float _eta = 0.01, size_t seed = std::random_device{}(),
|
|
176
|
+
static ILDAModel* create(TermWeight _weight, const LDAArgs& args,
|
|
170
177
|
bool scalarRng = false);
|
|
171
178
|
|
|
172
179
|
virtual TermWeight getTermWeight() const = 0;
|
|
@@ -85,7 +85,7 @@ namespace tomoto
|
|
|
85
85
|
static constexpr static constexpr char TMID[] = "LDA\0";
|
|
86
86
|
|
|
87
87
|
Float alpha;
|
|
88
|
-
|
|
88
|
+
Vector alphas;
|
|
89
89
|
Float eta;
|
|
90
90
|
Tid K;
|
|
91
91
|
size_t optimInterval = 50;
|
|
@@ -93,7 +93,7 @@ namespace tomoto
|
|
|
93
93
|
template<typename _List>
|
|
94
94
|
static Float calcDigammaSum(_List list, size_t len, Float alpha)
|
|
95
95
|
{
|
|
96
|
-
auto listExpr =
|
|
96
|
+
auto listExpr = Vector::NullaryExpr(len, list);
|
|
97
97
|
auto dAlpha = math::digammaT(alpha);
|
|
98
98
|
return (math::digammaApprox(listExpr.array() + alpha) - dAlpha).sum();
|
|
99
99
|
}
|
|
@@ -265,11 +265,11 @@ namespace tomoto
|
|
|
265
265
|
void initGlobalState(bool initDocs)
|
|
266
266
|
{
|
|
267
267
|
const size_t V = this->realV;
|
|
268
|
-
this->globalState.zLikelihood =
|
|
268
|
+
this->globalState.zLikelihood = Vector::Zero(K);
|
|
269
269
|
if (initDocs)
|
|
270
270
|
{
|
|
271
|
-
this->globalState.numByTopic =
|
|
272
|
-
this->globalState.numByTopicWord =
|
|
271
|
+
this->globalState.numByTopic = Vector::Zero(K);
|
|
272
|
+
this->globalState.numByTopicWord = Matrix::Zero(K, V);
|
|
273
273
|
}
|
|
274
274
|
}
|
|
275
275
|
|
|
@@ -335,7 +335,7 @@ namespace tomoto
|
|
|
335
335
|
LDACVB0Model(size_t _K = 1, Float _alpha = 0.1, Float _eta = 0.01, size_t _rg = std::random_device{}())
|
|
336
336
|
: BaseClass(_rg), K(_K), alpha(_alpha), eta(_eta)
|
|
337
337
|
{
|
|
338
|
-
alphas =
|
|
338
|
+
alphas = Vector::Constant(K, alpha);
|
|
339
339
|
}
|
|
340
340
|
GETTER(K, size_t, K);
|
|
341
341
|
GETTER(Alpha, Float, alpha);
|
|
@@ -355,7 +355,7 @@ namespace tomoto
|
|
|
355
355
|
|
|
356
356
|
std::unique_ptr<DocumentBase> makeDoc(const std::vector<std::string>& words) const override
|
|
357
357
|
{
|
|
358
|
-
return make_unique<_DocType>(as_mutable(this)->template _makeDoc<true>(words));
|
|
358
|
+
return std::make_unique<_DocType>(as_mutable(this)->template _makeDoc<true>(words));
|
|
359
359
|
}
|
|
360
360
|
|
|
361
361
|
void updateDocs()
|
|
@@ -403,7 +403,7 @@ namespace tomoto
|
|
|
403
403
|
return ret;
|
|
404
404
|
}
|
|
405
405
|
|
|
406
|
-
std::vector<Float> _getWidsByTopic(Tid tid) const
|
|
406
|
+
std::vector<Float> _getWidsByTopic(Tid tid, bool normalize = true) const
|
|
407
407
|
{
|
|
408
408
|
assert(tid < K);
|
|
409
409
|
const size_t V = this->realV;
|
|
@@ -2,12 +2,8 @@
|
|
|
2
2
|
|
|
3
3
|
namespace tomoto
|
|
4
4
|
{
|
|
5
|
-
|
|
6
|
-
template class LDAModel<TermWeight::idf>;
|
|
7
|
-
template class LDAModel<TermWeight::pmi>;*/
|
|
8
|
-
|
|
9
|
-
ILDAModel* ILDAModel::create(TermWeight _weight, size_t _K, Float _alpha, Float _eta, size_t seed, bool scalarRng)
|
|
5
|
+
ILDAModel* ILDAModel::create(TermWeight _weight, const LDAArgs& args, bool scalarRng)
|
|
10
6
|
{
|
|
11
|
-
TMT_SWITCH_TW(_weight, scalarRng, LDAModel,
|
|
7
|
+
TMT_SWITCH_TW(_weight, scalarRng, LDAModel, args);
|
|
12
8
|
}
|
|
13
9
|
}
|
|
@@ -56,7 +56,7 @@ namespace tomoto
|
|
|
56
56
|
{
|
|
57
57
|
using WeightType = typename std::conditional<_tw == TermWeight::one, int32_t, float>::type;
|
|
58
58
|
|
|
59
|
-
|
|
59
|
+
Vector zLikelihood;
|
|
60
60
|
Eigen::Matrix<WeightType, -1, 1> numByTopic; // Dim: (Topic, 1)
|
|
61
61
|
//Eigen::Matrix<WeightType, -1, -1> numByTopicWord; // Dim: (Topic, Vocabs)
|
|
62
62
|
ShareableMatrix<WeightType, -1, -1> numByTopicWord; // Dim: (Topic, Vocabs)
|
|
@@ -179,10 +179,10 @@ namespace tomoto
|
|
|
179
179
|
std::vector<Float> sharedWordWeights;
|
|
180
180
|
Tid K;
|
|
181
181
|
Float alpha, eta;
|
|
182
|
-
|
|
182
|
+
Vector alphas;
|
|
183
183
|
std::unordered_map<std::string, std::vector<Float>> etaByWord;
|
|
184
|
-
|
|
185
|
-
|
|
184
|
+
Matrix etaByTopicWord; // (K, V)
|
|
185
|
+
Vector etaSumByTopic; // (K, )
|
|
186
186
|
uint32_t optimInterval = 10, burnIn = 0;
|
|
187
187
|
Eigen::Matrix<WeightType, -1, -1> numByTopicDoc;
|
|
188
188
|
|
|
@@ -197,7 +197,7 @@ namespace tomoto
|
|
|
197
197
|
template<typename _List>
|
|
198
198
|
static Float calcDigammaSum(ThreadPool* pool, _List list, size_t len, Float alpha)
|
|
199
199
|
{
|
|
200
|
-
auto listExpr =
|
|
200
|
+
auto listExpr = Vector::NullaryExpr(len, list);
|
|
201
201
|
auto dAlpha = math::digammaT(alpha);
|
|
202
202
|
|
|
203
203
|
size_t suggested = (len + 127) / 128;
|
|
@@ -507,7 +507,7 @@ namespace tomoto
|
|
|
507
507
|
static_cast<DerivedClass*>(this)->optimizeParameters(pool, localData, rgs);
|
|
508
508
|
}
|
|
509
509
|
}
|
|
510
|
-
catch (const
|
|
510
|
+
catch (const exc::TrainingError&)
|
|
511
511
|
{
|
|
512
512
|
for (auto& r : res) if(r.valid()) r.get();
|
|
513
513
|
throw;
|
|
@@ -663,6 +663,22 @@ namespace tomoto
|
|
|
663
663
|
makeTransformIter(this->docs.end(), txWeights));
|
|
664
664
|
}
|
|
665
665
|
}
|
|
666
|
+
|
|
667
|
+
void updateForCopy()
|
|
668
|
+
{
|
|
669
|
+
BaseClass::updateForCopy();
|
|
670
|
+
size_t offset = 0;
|
|
671
|
+
for (auto& doc : this->docs)
|
|
672
|
+
{
|
|
673
|
+
size_t size = doc.Zs.size();
|
|
674
|
+
doc.Zs = tvector<Tid>{ sharedZs.data() + offset, size };
|
|
675
|
+
if (_tw != TermWeight::one)
|
|
676
|
+
{
|
|
677
|
+
doc.wordWeights = tvector<Float>{ sharedWordWeights.data() + offset, size };
|
|
678
|
+
}
|
|
679
|
+
offset += size;
|
|
680
|
+
}
|
|
681
|
+
}
|
|
666
682
|
|
|
667
683
|
WeightType* getTopicDocPtr(size_t docId) const
|
|
668
684
|
{
|
|
@@ -670,11 +686,14 @@ namespace tomoto
|
|
|
670
686
|
return (WeightType*)numByTopicDoc.col(docId).data();
|
|
671
687
|
}
|
|
672
688
|
|
|
689
|
+
/*
|
|
690
|
+
* called only when initializing a new doc, not when loading from saved model
|
|
691
|
+
*/
|
|
673
692
|
void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
|
|
674
693
|
{
|
|
675
694
|
sortAndWriteOrder(doc.words, doc.wOrder);
|
|
676
695
|
doc.numByTopic.init(getTopicDocPtr(docId), K, 1);
|
|
677
|
-
doc.Zs = tvector<Tid>(wordSize);
|
|
696
|
+
doc.Zs = tvector<Tid>(wordSize, non_topic_id);
|
|
678
697
|
if(_tw != TermWeight::one) doc.wordWeights.resize(wordSize);
|
|
679
698
|
}
|
|
680
699
|
|
|
@@ -688,7 +707,7 @@ namespace tomoto
|
|
|
688
707
|
{
|
|
689
708
|
auto id = this->dict.toWid(it.first);
|
|
690
709
|
if (id == (Vid)-1 || id >= this->realV) continue;
|
|
691
|
-
etaByTopicWord.col(id) = Eigen::Map<
|
|
710
|
+
etaByTopicWord.col(id) = Eigen::Map<Vector>{ it.second.data(), (Eigen::Index)it.second.size() };
|
|
692
711
|
}
|
|
693
712
|
etaSumByTopic = etaByTopicWord.rowwise().sum();
|
|
694
713
|
}
|
|
@@ -696,7 +715,7 @@ namespace tomoto
|
|
|
696
715
|
void initGlobalState(bool initDocs)
|
|
697
716
|
{
|
|
698
717
|
const size_t V = this->realV;
|
|
699
|
-
this->globalState.zLikelihood =
|
|
718
|
+
this->globalState.zLikelihood = Vector::Zero(K);
|
|
700
719
|
if (initDocs)
|
|
701
720
|
{
|
|
702
721
|
this->globalState.numByTopic = Eigen::Matrix<WeightType, -1, 1>::Zero(K);
|
|
@@ -708,12 +727,14 @@ namespace tomoto
|
|
|
708
727
|
|
|
709
728
|
struct Generator
|
|
710
729
|
{
|
|
711
|
-
|
|
730
|
+
Eigen::Rand::DiscreteGen<int32_t> theta;
|
|
712
731
|
};
|
|
713
732
|
|
|
714
733
|
Generator makeGeneratorForInit(const _DocType*) const
|
|
715
734
|
{
|
|
716
|
-
|
|
735
|
+
Generator g;
|
|
736
|
+
g.theta = Eigen::Rand::DiscreteGen<int32_t>{ alphas.data(), alphas.data() + alphas.size() };
|
|
737
|
+
return g;
|
|
717
738
|
}
|
|
718
739
|
|
|
719
740
|
template<bool _Infer>
|
|
@@ -780,12 +801,13 @@ namespace tomoto
|
|
|
780
801
|
return cnt;
|
|
781
802
|
}
|
|
782
803
|
|
|
783
|
-
std::vector<Float> _getWidsByTopic(size_t tid) const
|
|
804
|
+
std::vector<Float> _getWidsByTopic(size_t tid, bool normalize = true) const
|
|
784
805
|
{
|
|
785
806
|
assert(tid < this->globalState.numByTopic.rows());
|
|
786
807
|
const size_t V = this->realV;
|
|
787
808
|
std::vector<Float> ret(V);
|
|
788
809
|
Float sum = this->globalState.numByTopic[tid] + V * eta;
|
|
810
|
+
if (!normalize) sum = 1;
|
|
789
811
|
auto r = this->globalState.numByTopicWord.row(tid);
|
|
790
812
|
for (size_t v = 0; v < V; ++v)
|
|
791
813
|
{
|
|
@@ -794,7 +816,7 @@ namespace tomoto
|
|
|
794
816
|
return ret;
|
|
795
817
|
}
|
|
796
818
|
|
|
797
|
-
template<bool
|
|
819
|
+
template<bool together, ParallelScheme _ps, typename _Iter>
|
|
798
820
|
std::vector<double> _infer(_Iter docFirst, _Iter docLast, size_t maxIter, Float tolerance, size_t numWorkers) const
|
|
799
821
|
{
|
|
800
822
|
decltype(static_cast<const DerivedClass*>(this)->makeGeneratorForInit(nullptr)) generator;
|
|
@@ -803,7 +825,7 @@ namespace tomoto
|
|
|
803
825
|
generator = static_cast<const DerivedClass*>(this)->makeGeneratorForInit(nullptr);
|
|
804
826
|
}
|
|
805
827
|
|
|
806
|
-
if (
|
|
828
|
+
if (together)
|
|
807
829
|
{
|
|
808
830
|
numWorkers = std::min(numWorkers, this->maxThreads[(size_t)_ps]);
|
|
809
831
|
ThreadPool pool{ numWorkers };
|
|
@@ -913,13 +935,26 @@ namespace tomoto
|
|
|
913
935
|
DEFINE_TAGGED_SERIALIZER_WITH_VERSION(1, 0x00010001, vocabWeights, alpha, alphas, eta, K, etaByWord,
|
|
914
936
|
burnIn, optimInterval);
|
|
915
937
|
|
|
916
|
-
LDAModel(
|
|
917
|
-
: BaseClass(
|
|
918
|
-
{
|
|
919
|
-
if (
|
|
920
|
-
|
|
921
|
-
if (
|
|
922
|
-
|
|
938
|
+
LDAModel(const LDAArgs& args, bool checkAlpha = true)
|
|
939
|
+
: BaseClass(args.seed), K(args.k), alpha(args.alpha[0]), eta(args.eta)
|
|
940
|
+
{
|
|
941
|
+
if (K == 0 || K >= 0x80000000) THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong K value (K = %zd)", K));
|
|
942
|
+
|
|
943
|
+
if (args.alpha.size() == 1)
|
|
944
|
+
{
|
|
945
|
+
alphas = Vector::Constant(K, alpha);
|
|
946
|
+
}
|
|
947
|
+
else if (args.alpha.size() == args.k)
|
|
948
|
+
{
|
|
949
|
+
alphas = Eigen::Map<const Vector>(args.alpha.data(), args.alpha.size());
|
|
950
|
+
}
|
|
951
|
+
else if (checkAlpha)
|
|
952
|
+
{
|
|
953
|
+
THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong alpha value (len = %zd)", args.alpha.size()));
|
|
954
|
+
}
|
|
955
|
+
|
|
956
|
+
if ((alphas.array() <= 0).any()) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "wrong alpha value");
|
|
957
|
+
if (eta <= 0) THROW_ERROR_WITH_INFO(exc::InvalidArgument, text::format("wrong eta value (eta = %f)", eta));
|
|
923
958
|
}
|
|
924
959
|
|
|
925
960
|
GETTER(K, size_t, K);
|
|
@@ -952,7 +987,7 @@ namespace tomoto
|
|
|
952
987
|
|
|
953
988
|
std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) const override
|
|
954
989
|
{
|
|
955
|
-
return make_unique<_DocType>(as_mutable(this)->template _makeFromRawDoc<true>(rawDoc, tokenizer));
|
|
990
|
+
return std::make_unique<_DocType>(as_mutable(this)->template _makeFromRawDoc<true>(rawDoc, tokenizer));
|
|
956
991
|
}
|
|
957
992
|
|
|
958
993
|
size_t addDoc(const RawDoc& rawDoc) override
|
|
@@ -962,15 +997,15 @@ namespace tomoto
|
|
|
962
997
|
|
|
963
998
|
std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc) const override
|
|
964
999
|
{
|
|
965
|
-
return make_unique<_DocType>(as_mutable(this)->template _makeFromRawDoc<true>(rawDoc));
|
|
1000
|
+
return std::make_unique<_DocType>(as_mutable(this)->template _makeFromRawDoc<true>(rawDoc));
|
|
966
1001
|
}
|
|
967
1002
|
|
|
968
1003
|
void setWordPrior(const std::string& word, const std::vector<Float>& priors) override
|
|
969
1004
|
{
|
|
970
|
-
if (priors.size() != K) THROW_ERROR_WITH_INFO(
|
|
1005
|
+
if (priors.size() != K) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "priors.size() must be equal to K.");
|
|
971
1006
|
for (auto p : priors)
|
|
972
1007
|
{
|
|
973
|
-
if (p < 0) THROW_ERROR_WITH_INFO(
|
|
1008
|
+
if (p < 0) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "priors must not be less than 0.");
|
|
974
1009
|
}
|
|
975
1010
|
this->dict.add(word);
|
|
976
1011
|
etaByWord.emplace(word, priors);
|
|
@@ -1069,11 +1104,18 @@ namespace tomoto
|
|
|
1069
1104
|
return static_cast<const DerivedClass*>(this)->_getTopicsCount();
|
|
1070
1105
|
}
|
|
1071
1106
|
|
|
1072
|
-
std::vector<Float> getTopicsByDoc(const _DocType& doc) const
|
|
1107
|
+
std::vector<Float> getTopicsByDoc(const _DocType& doc, bool normalize) const
|
|
1073
1108
|
{
|
|
1074
1109
|
std::vector<Float> ret(K);
|
|
1075
|
-
Eigen::Map<Eigen::
|
|
1076
|
-
|
|
1110
|
+
Eigen::Map<Eigen::Array<Float, -1, 1>> m{ ret.data(), K };
|
|
1111
|
+
if (normalize)
|
|
1112
|
+
{
|
|
1113
|
+
m = (doc.numByTopic.array().template cast<Float>() + alphas.array()) / (doc.getSumWordWeight() + alphas.sum());
|
|
1114
|
+
}
|
|
1115
|
+
else
|
|
1116
|
+
{
|
|
1117
|
+
m = doc.numByTopic.array().template cast<Float>() + alphas.array();
|
|
1118
|
+
}
|
|
1077
1119
|
return ret;
|
|
1078
1120
|
}
|
|
1079
1121
|
|
|
@@ -19,8 +19,7 @@ namespace tomoto
|
|
|
19
19
|
{
|
|
20
20
|
public:
|
|
21
21
|
using DefaultDocType = DocumentLLDA<TermWeight::one>;
|
|
22
|
-
static ILLDAModel* create(TermWeight _weight,
|
|
23
|
-
Float alpha = 0.1, Float eta = 0.01, size_t seed = std::random_device{}(),
|
|
22
|
+
static ILLDAModel* create(TermWeight _weight, const LDAArgs& args,
|
|
24
23
|
bool scalarRng = false);
|
|
25
24
|
|
|
26
25
|
virtual const Dictionary& getTopicLabelDict() const = 0;
|
|
@@ -2,12 +2,8 @@
|
|
|
2
2
|
|
|
3
3
|
namespace tomoto
|
|
4
4
|
{
|
|
5
|
-
|
|
6
|
-
template class LLDAModel<TermWeight::idf>;
|
|
7
|
-
template class LLDAModel<TermWeight::pmi>;*/
|
|
8
|
-
|
|
9
|
-
ILLDAModel* ILLDAModel::create(TermWeight _weight, size_t _K, Float _alpha, Float _eta, size_t seed, bool scalarRng)
|
|
5
|
+
ILLDAModel* ILLDAModel::create(TermWeight _weight, const LDAArgs& args, bool scalarRng)
|
|
10
6
|
{
|
|
11
|
-
TMT_SWITCH_TW(_weight, scalarRng, LLDAModel,
|
|
7
|
+
TMT_SWITCH_TW(_weight, scalarRng, LLDAModel, args);
|
|
12
8
|
}
|
|
13
9
|
}
|
|
@@ -71,13 +71,16 @@ namespace tomoto
|
|
|
71
71
|
|
|
72
72
|
struct Generator
|
|
73
73
|
{
|
|
74
|
-
|
|
74
|
+
Eigen::Array<Float, -1, 1> p;
|
|
75
|
+
Eigen::Rand::DiscreteGen<int32_t> theta;
|
|
75
76
|
};
|
|
76
77
|
|
|
77
78
|
Generator makeGeneratorForInit(const _DocType* doc) const
|
|
78
79
|
{
|
|
79
|
-
|
|
80
|
-
|
|
80
|
+
Generator g;
|
|
81
|
+
g.p = doc->labelMask.array().template cast<Float>() * this->alphas.array();
|
|
82
|
+
g.theta = Eigen::Rand::DiscreteGen<int32_t>{ g.p.data(), g.p.data() + this->K };
|
|
83
|
+
return g;
|
|
81
84
|
}
|
|
82
85
|
|
|
83
86
|
template<bool _Infer>
|
|
@@ -88,7 +91,7 @@ namespace tomoto
|
|
|
88
91
|
if (this->etaByTopicWord.size())
|
|
89
92
|
{
|
|
90
93
|
Eigen::Array<Float, -1, 1> col = this->etaByTopicWord.col(w);
|
|
91
|
-
|
|
94
|
+
col *= g.p;
|
|
92
95
|
z = sample::sampleFromDiscrete(col.data(), col.data() + col.size(), rgs);
|
|
93
96
|
}
|
|
94
97
|
else
|
|
@@ -102,8 +105,8 @@ namespace tomoto
|
|
|
102
105
|
DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, topicLabelDict);
|
|
103
106
|
DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, topicLabelDict);
|
|
104
107
|
|
|
105
|
-
LLDAModel(
|
|
106
|
-
: BaseClass(
|
|
108
|
+
LLDAModel(const LDAArgs& args)
|
|
109
|
+
: BaseClass(args)
|
|
107
110
|
{
|
|
108
111
|
}
|
|
109
112
|
|
|
@@ -153,7 +156,7 @@ namespace tomoto
|
|
|
153
156
|
std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) const override
|
|
154
157
|
{
|
|
155
158
|
auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc, tokenizer);
|
|
156
|
-
return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
|
|
159
|
+
return std::make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
|
|
157
160
|
}
|
|
158
161
|
|
|
159
162
|
size_t addDoc(const RawDoc& rawDoc) override
|
|
@@ -165,16 +168,23 @@ namespace tomoto
|
|
|
165
168
|
std::unique_ptr<DocumentBase> makeDoc(const RawDoc& rawDoc) const override
|
|
166
169
|
{
|
|
167
170
|
auto doc = as_mutable(this)->template _makeFromRawDoc<true>(rawDoc);
|
|
168
|
-
return make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
|
|
171
|
+
return std::make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
|
|
169
172
|
}
|
|
170
173
|
|
|
171
|
-
std::vector<Float> getTopicsByDoc(const _DocType& doc) const
|
|
174
|
+
std::vector<Float> getTopicsByDoc(const _DocType& doc, bool normalize) const
|
|
172
175
|
{
|
|
173
176
|
std::vector<Float> ret(this->K);
|
|
174
177
|
auto maskedAlphas = this->alphas.array() * doc.labelMask.template cast<Float>().array();
|
|
175
|
-
Eigen::Map<Eigen::
|
|
176
|
-
|
|
177
|
-
|
|
178
|
+
Eigen::Map<Eigen::Array<Float, -1, 1>> m{ ret.data(), this->K };
|
|
179
|
+
if (normalize)
|
|
180
|
+
{
|
|
181
|
+
m = (doc.numByTopic.array().template cast<Float>() + maskedAlphas)
|
|
182
|
+
/ (doc.getSumWordWeight() + maskedAlphas.sum());
|
|
183
|
+
}
|
|
184
|
+
else
|
|
185
|
+
{
|
|
186
|
+
m = doc.numByTopic.array().template cast<Float>() + maskedAlphas;
|
|
187
|
+
}
|
|
178
188
|
return ret;
|
|
179
189
|
}
|
|
180
190
|
|
|
@@ -28,13 +28,22 @@ namespace tomoto
|
|
|
28
28
|
template<typename _TopicModel> void update(WeightType* ptr, const _TopicModel& mdl);
|
|
29
29
|
};
|
|
30
30
|
|
|
31
|
+
struct MGLDAArgs : public LDAArgs
|
|
32
|
+
{
|
|
33
|
+
size_t kL = 1;
|
|
34
|
+
size_t t = 3;
|
|
35
|
+
std::vector<Float> alphaL = { 0.1 };
|
|
36
|
+
Float alphaMG = 0.1;
|
|
37
|
+
Float alphaML = 0.1;
|
|
38
|
+
Float etaL = 0.01;
|
|
39
|
+
Float gamma = 0.1;
|
|
40
|
+
};
|
|
41
|
+
|
|
31
42
|
class IMGLDAModel : public ILDAModel
|
|
32
43
|
{
|
|
33
44
|
public:
|
|
34
45
|
using DefaultDocType = DocumentMGLDA<TermWeight::one>;
|
|
35
|
-
static IMGLDAModel* create(TermWeight _weight,
|
|
36
|
-
Float _alphaG = 0.1, Float _alphaL = 0.1, Float _alphaMG = 0.1, Float _alphaML = 0.1,
|
|
37
|
-
Float _etaG = 0.01, Float _etaL = 0.01, Float _gamma = 0.1, size_t seed = std::random_device{}(),
|
|
46
|
+
static IMGLDAModel* create(TermWeight _weight, const MGLDAArgs& args,
|
|
38
47
|
bool scalarRng = false);
|
|
39
48
|
|
|
40
49
|
virtual size_t getKL() const = 0;
|