tomoto 0.2.2 → 0.2.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/ext/tomoto/ct.cpp +11 -11
- data/ext/tomoto/dmr.cpp +14 -13
- data/ext/tomoto/dt.cpp +14 -14
- data/ext/tomoto/ext.cpp +7 -7
- data/ext/tomoto/extconf.rb +1 -3
- data/ext/tomoto/gdmr.cpp +7 -7
- data/ext/tomoto/hdp.cpp +9 -9
- data/ext/tomoto/hlda.cpp +13 -13
- data/ext/tomoto/hpa.cpp +5 -5
- data/ext/tomoto/lda.cpp +42 -39
- data/ext/tomoto/llda.cpp +6 -6
- data/ext/tomoto/mglda.cpp +15 -15
- data/ext/tomoto/pa.cpp +6 -6
- data/ext/tomoto/plda.cpp +6 -6
- data/ext/tomoto/slda.cpp +8 -8
- data/ext/tomoto/utils.h +16 -70
- data/lib/tomoto/version.rb +1 -1
- data/vendor/tomotopy/README.kr.rst +57 -0
- data/vendor/tomotopy/README.rst +55 -0
- data/vendor/tomotopy/src/Labeling/Phraser.hpp +3 -3
- data/vendor/tomotopy/src/TopicModel/CTModel.hpp +5 -2
- data/vendor/tomotopy/src/TopicModel/DMRModel.hpp +5 -2
- data/vendor/tomotopy/src/TopicModel/DTModel.hpp +5 -2
- data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +4 -4
- data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +5 -2
- data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +2 -2
- data/vendor/tomotopy/src/TopicModel/LDA.h +3 -3
- data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +3 -3
- data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +34 -14
- data/vendor/tomotopy/src/TopicModel/LLDAModel.hpp +5 -2
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +2 -2
- data/vendor/tomotopy/src/TopicModel/PAModel.hpp +1 -1
- data/vendor/tomotopy/src/TopicModel/PLDAModel.hpp +5 -2
- data/vendor/tomotopy/src/TopicModel/PTModel.hpp +5 -2
- data/vendor/tomotopy/src/TopicModel/SLDAModel.hpp +4 -1
- data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +48 -21
- data/vendor/tomotopy/src/Utils/AliasMethod.hpp +5 -4
- data/vendor/tomotopy/src/Utils/Dictionary.h +2 -2
- data/vendor/tomotopy/src/Utils/MultiNormalDistribution.hpp +1 -1
- data/vendor/tomotopy/src/Utils/TruncMultiNormal.hpp +1 -1
- data/vendor/tomotopy/src/Utils/math.h +2 -2
- data/vendor/tomotopy/src/Utils/serializer.hpp +30 -5
- metadata +6 -6
| @@ -335,7 +335,10 @@ namespace tomoto | |
| 335 335 | 
             
            		friend typename BaseClass::BaseClass;
         | 
| 336 336 | 
             
            		using WeightType = typename BaseClass::WeightType;
         | 
| 337 337 |  | 
| 338 | 
            -
            		static constexpr  | 
| 338 | 
            +
            		static constexpr auto tmid()
         | 
| 339 | 
            +
            		{
         | 
| 340 | 
            +
            			return serializer::to_key("hLDA");
         | 
| 341 | 
            +
            		}
         | 
| 339 342 |  | 
| 340 343 | 
             
            		Float gamma;
         | 
| 341 344 |  | 
| @@ -422,7 +425,7 @@ namespace tomoto | |
| 422 425 | 
             
            		}
         | 
| 423 426 |  | 
| 424 427 | 
             
            		template<int _inc>
         | 
| 425 | 
            -
            		inline void addWordTo(_ModelState& ld, _DocType& doc,  | 
| 428 | 
            +
            		inline void addWordTo(_ModelState& ld, _DocType& doc, size_t pid, Vid vid, Tid level) const
         | 
| 426 429 | 
             
            		{
         | 
| 427 430 | 
             
            			assert(vid < this->realV);
         | 
| 428 431 | 
             
            			constexpr bool _dec = _inc < 0 && _tw != TermWeight::one;
         | 
| @@ -143,7 +143,7 @@ namespace tomoto | |
| 143 143 | 
             
            		}
         | 
| 144 144 |  | 
| 145 145 | 
             
            		template<int _inc>
         | 
| 146 | 
            -
            		inline void addWordTo(_ModelState& ld, _DocType& doc,  | 
| 146 | 
            +
            		inline void addWordTo(_ModelState& ld, _DocType& doc, size_t pid, Vid vid, Tid z1, Tid z2) const
         | 
| 147 147 | 
             
            		{
         | 
| 148 148 | 
             
            			assert(vid < this->realV);
         | 
| 149 149 | 
             
            			constexpr bool _dec = _inc < 0 && _tw != TermWeight::one;
         | 
| @@ -540,7 +540,7 @@ namespace tomoto | |
| 540 540 | 
             
            			return ret;
         | 
| 541 541 | 
             
            		}
         | 
| 542 542 |  | 
| 543 | 
            -
            		std::vector<Float>  | 
| 543 | 
            +
            		std::vector<Float> _getTopicsByDoc(const _DocType& doc, bool normalize) const
         | 
| 544 544 | 
             
            		{
         | 
| 545 545 | 
             
            			std::vector<Float> ret(1 + this->K + K2);
         | 
| 546 546 | 
             
            			Float sum = doc.getSumWordWeight() + this->alphas.sum();
         | 
| @@ -121,7 +121,7 @@ namespace tomoto | |
| 121 121 |  | 
| 122 122 | 
             
            		void updateSumWordWeight(size_t realV)
         | 
| 123 123 | 
             
            		{
         | 
| 124 | 
            -
            			sumWordWeight = std::count_if(static_cast<_Base*>(this)->words.begin(), static_cast<_Base*>(this)->words.end(), [realV](Vid w)
         | 
| 124 | 
            +
            			sumWordWeight = (int32_t)std::count_if(static_cast<_Base*>(this)->words.begin(), static_cast<_Base*>(this)->words.end(), [realV](Vid w)
         | 
| 125 125 | 
             
            			{
         | 
| 126 126 | 
             
            				return w < realV;
         | 
| 127 127 | 
             
            			});
         | 
| @@ -164,8 +164,8 @@ namespace tomoto | |
| 164 164 | 
             
            	struct LDAArgs
         | 
| 165 165 | 
             
            	{
         | 
| 166 166 | 
             
            		size_t k = 1;
         | 
| 167 | 
            -
            		std::vector<Float> alpha = { 0.1 };
         | 
| 168 | 
            -
            		Float eta = 0.01;
         | 
| 167 | 
            +
            		std::vector<Float> alpha = { (Float)0.1 };
         | 
| 168 | 
            +
            		Float eta = (Float)0.01;
         | 
| 169 169 | 
             
            		size_t seed = std::random_device{}();
         | 
| 170 170 | 
             
            	};
         | 
| 171 171 |  | 
| @@ -82,7 +82,7 @@ namespace tomoto | |
| 82 82 | 
             
            		friend BaseClass;
         | 
| 83 83 |  | 
| 84 84 | 
             
            		static constexpr const char TWID[] = "one\0";
         | 
| 85 | 
            -
            		static constexpr  | 
| 85 | 
            +
            		static constexpr const char TMID[] = "LDA\0";
         | 
| 86 86 |  | 
| 87 87 | 
             
            		Float alpha;
         | 
| 88 88 | 
             
            		Vector alphas;
         | 
| @@ -125,7 +125,7 @@ namespace tomoto | |
| 125 125 | 
             
            		}
         | 
| 126 126 |  | 
| 127 127 | 
             
            		template<int _Inc, typename _Vec>
         | 
| 128 | 
            -
            		inline void addWordTo(_ModelState& ld, _DocType& doc,  | 
| 128 | 
            +
            		inline void addWordTo(_ModelState& ld, _DocType& doc, size_t pid, Vid vid, _Vec tDist) const
         | 
| 129 129 | 
             
            		{
         | 
| 130 130 | 
             
            			assert(vid < this->realV);
         | 
| 131 131 | 
             
            			constexpr bool _dec = _Inc < 0;
         | 
| @@ -392,7 +392,7 @@ namespace tomoto | |
| 392 392 | 
             
            			return static_cast<const DerivedClass*>(this)->_getTopicsCount();
         | 
| 393 393 | 
             
            		}
         | 
| 394 394 |  | 
| 395 | 
            -
            		std::vector<Float>  | 
| 395 | 
            +
            		std::vector<Float> _getTopicsByDoc(const _DocType& doc) const
         | 
| 396 396 | 
             
            		{
         | 
| 397 397 | 
             
            			std::vector<Float> ret(K);
         | 
| 398 398 | 
             
            			Float sum = doc.getSumWordWeight() + K * alpha;
         | 
| @@ -117,19 +117,28 @@ namespace tomoto | |
| 117 117 | 
             
            	template<>
         | 
| 118 118 | 
             
            	struct TwId<TermWeight::one>
         | 
| 119 119 | 
             
            	{
         | 
| 120 | 
            -
            		static constexpr  | 
| 120 | 
            +
            		static constexpr auto twid()
         | 
| 121 | 
            +
            		{
         | 
| 122 | 
            +
            			return serializer::to_key("one\0");
         | 
| 123 | 
            +
            		}
         | 
| 121 124 | 
             
            	};
         | 
| 122 125 |  | 
| 123 126 | 
             
            	template<>
         | 
| 124 127 | 
             
            	struct TwId<TermWeight::idf>
         | 
| 125 128 | 
             
            	{
         | 
| 126 | 
            -
            		static constexpr  | 
| 129 | 
            +
            		static constexpr auto twid()
         | 
| 130 | 
            +
            		{
         | 
| 131 | 
            +
            			return serializer::to_key("idf\0");
         | 
| 132 | 
            +
            		}
         | 
| 127 133 | 
             
            	};
         | 
| 128 134 |  | 
| 129 135 | 
             
            	template<>
         | 
| 130 136 | 
             
            	struct TwId<TermWeight::pmi>
         | 
| 131 137 | 
             
            	{
         | 
| 132 | 
            -
            		static constexpr  | 
| 138 | 
            +
            		static constexpr auto twid()
         | 
| 139 | 
            +
            		{
         | 
| 140 | 
            +
            			return serializer::to_key("pmi\0");
         | 
| 141 | 
            +
            		}
         | 
| 133 142 | 
             
            	};
         | 
| 134 143 |  | 
| 135 144 | 
             
            	// to make HDP friend of LDA for HDPModel::converToLDA
         | 
| @@ -169,7 +178,11 @@ namespace tomoto | |
| 169 178 | 
             
            			typename>
         | 
| 170 179 | 
             
            		friend class HDPModel;
         | 
| 171 180 |  | 
| 172 | 
            -
            		static constexpr  | 
| 181 | 
            +
            		static constexpr auto tmid()
         | 
| 182 | 
            +
            		{
         | 
| 183 | 
            +
            			return serializer::to_key("LDA\0");
         | 
| 184 | 
            +
            		}
         | 
| 185 | 
            +
             | 
| 173 186 | 
             
            		using WeightType = typename std::conditional<_tw == TermWeight::one, int32_t, float>::type;
         | 
| 174 187 |  | 
| 175 188 | 
             
            		enum { m_flags = _Flags };
         | 
| @@ -189,7 +202,7 @@ namespace tomoto | |
| 189 202 | 
             
            		struct ExtraDocData
         | 
| 190 203 | 
             
            		{
         | 
| 191 204 | 
             
            			std::vector<Vid> vChunkOffset;
         | 
| 192 | 
            -
            			Eigen::Matrix< | 
| 205 | 
            +
            			Eigen::Matrix<size_t, -1, -1> chunkOffsetByDoc;
         | 
| 193 206 | 
             
            		};
         | 
| 194 207 |  | 
| 195 208 | 
             
            		ExtraDocData eddTrain;
         | 
| @@ -261,7 +274,7 @@ namespace tomoto | |
| 261 274 | 
             
            		}
         | 
| 262 275 |  | 
| 263 276 | 
             
            		template<int _inc>
         | 
| 264 | 
            -
            		inline void addWordTo(_ModelState& ld, _DocType& doc,  | 
| 277 | 
            +
            		inline void addWordTo(_ModelState& ld, _DocType& doc, size_t pid, Vid vid, Tid tid) const
         | 
| 265 278 | 
             
            		{
         | 
| 266 279 | 
             
            			assert(tid < K);
         | 
| 267 280 | 
             
            			assert(vid < this->realV);
         | 
| @@ -620,7 +633,7 @@ namespace tomoto | |
| 620 633 | 
             
            					for (Vid v = 0; v < V; ++v)
         | 
| 621 634 | 
             
            					{
         | 
| 622 635 | 
             
            						if (!ld.numByTopicWord(k, v)) continue;
         | 
| 623 | 
            -
            						ll += math::lgammaT(ld.numByTopicWord(k, v) + etaByTopicWord( | 
| 636 | 
            +
            						ll += math::lgammaT(ld.numByTopicWord(k, v) + etaByTopicWord(k, v)) - math::lgammaT(etaByTopicWord(k, v));
         | 
| 624 637 | 
             
            						assert(std::isfinite(ll));
         | 
| 625 638 | 
             
            					}
         | 
| 626 639 | 
             
            				}
         | 
| @@ -972,12 +985,14 @@ namespace tomoto | |
| 972 985 |  | 
| 973 986 | 
             
            		void setOptimInterval(size_t _optimInterval) override
         | 
| 974 987 | 
             
            		{
         | 
| 975 | 
            -
            			 | 
| 988 | 
            +
            			if (_optimInterval > 0x7FFFFFFF) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "wrong value");
         | 
| 989 | 
            +
            			optimInterval = (uint32_t)_optimInterval;
         | 
| 976 990 | 
             
            		}
         | 
| 977 991 |  | 
| 978 992 | 
             
            		void setBurnInIteration(size_t iteration) override
         | 
| 979 993 | 
             
            		{
         | 
| 980 | 
            -
            			 | 
| 994 | 
            +
            			if (iteration > 0x7FFFFFFF) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "wrong value");
         | 
| 995 | 
            +
            			burnIn = (uint32_t)iteration;
         | 
| 981 996 | 
             
            		}
         | 
| 982 997 |  | 
| 983 998 | 
             
            		size_t addDoc(const RawDoc& rawDoc, const RawDocTokenizer::Factory& tokenizer) override
         | 
| @@ -1008,6 +1023,11 @@ namespace tomoto | |
| 1008 1023 | 
             
            				if (p < 0) THROW_ERROR_WITH_INFO(exc::InvalidArgument, "priors must not be less than 0.");
         | 
| 1009 1024 | 
             
            			}
         | 
| 1010 1025 | 
             
            			this->dict.add(word);
         | 
| 1026 | 
            +
            			if (this->dict.size() > this->vocabCf.size())
         | 
| 1027 | 
            +
            			{
         | 
| 1028 | 
            +
            				this->vocabCf.resize(this->dict.size());
         | 
| 1029 | 
            +
            				this->vocabDf.resize(this->dict.size());
         | 
| 1030 | 
            +
            			}
         | 
| 1011 1031 | 
             
            			etaByWord.emplace(word, priors);
         | 
| 1012 1032 | 
             
            		}
         | 
| 1013 1033 |  | 
| @@ -1049,7 +1069,7 @@ namespace tomoto | |
| 1049 1069 | 
             
            			if (initDocs)
         | 
| 1050 1070 | 
             
            			{
         | 
| 1051 1071 | 
             
            				std::vector<uint32_t> df, cf, tf;
         | 
| 1052 | 
            -
            				 | 
| 1072 | 
            +
            				size_t totCf;
         | 
| 1053 1073 |  | 
| 1054 1074 | 
             
            				// calculate weighting
         | 
| 1055 1075 | 
             
            				if (_tw != TermWeight::one)
         | 
| @@ -1064,14 +1084,14 @@ namespace tomoto | |
| 1064 1084 | 
             
            							++df[w];
         | 
| 1065 1085 | 
             
            						}
         | 
| 1066 1086 | 
             
            					}
         | 
| 1067 | 
            -
            					totCf = accumulate(this->vocabCf.begin(), this->vocabCf.end(), 0);
         | 
| 1087 | 
            +
            					totCf = std::accumulate(this->vocabCf.begin(), this->vocabCf.end(), 0);
         | 
| 1068 1088 | 
             
            				}
         | 
| 1069 1089 | 
             
            				if (_tw == TermWeight::idf)
         | 
| 1070 1090 | 
             
            				{
         | 
| 1071 1091 | 
             
            					vocabWeights.resize(V);
         | 
| 1072 1092 | 
             
            					for (size_t i = 0; i < V; ++i)
         | 
| 1073 1093 | 
             
            					{
         | 
| 1074 | 
            -
            						vocabWeights[i] = log(this->docs.size() / ( | 
| 1094 | 
            +
            						vocabWeights[i] = (Float)log(this->docs.size() / (double)df[i]);
         | 
| 1075 1095 | 
             
            					}
         | 
| 1076 1096 | 
             
            				}
         | 
| 1077 1097 | 
             
            				else if (_tw == TermWeight::pmi)
         | 
| @@ -1079,7 +1099,7 @@ namespace tomoto | |
| 1079 1099 | 
             
            					vocabWeights.resize(V);
         | 
| 1080 1100 | 
             
            					for (size_t i = 0; i < V; ++i)
         | 
| 1081 1101 | 
             
            					{
         | 
| 1082 | 
            -
            						vocabWeights[i] = this->vocabCf[i] / ( | 
| 1102 | 
            +
            						vocabWeights[i] = (Float)(this->vocabCf[i] / (double)totCf);
         | 
| 1083 1103 | 
             
            					}
         | 
| 1084 1104 | 
             
            				}
         | 
| 1085 1105 |  | 
| @@ -1104,7 +1124,7 @@ namespace tomoto | |
| 1104 1124 | 
             
            			return static_cast<const DerivedClass*>(this)->_getTopicsCount();
         | 
| 1105 1125 | 
             
            		}
         | 
| 1106 1126 |  | 
| 1107 | 
            -
            		std::vector<Float>  | 
| 1127 | 
            +
            		std::vector<Float> _getTopicsByDoc(const _DocType& doc, bool normalize) const
         | 
| 1108 1128 | 
             
            		{
         | 
| 1109 1129 | 
             
            			std::vector<Float> ret(K);
         | 
| 1110 1130 | 
             
            			Eigen::Map<Eigen::Array<Float, -1, 1>> m{ ret.data(), K };
         | 
| @@ -26,7 +26,10 @@ namespace tomoto | |
| 26 26 | 
             
            		friend typename BaseClass::BaseClass;
         | 
| 27 27 | 
             
            		using WeightType = typename BaseClass::WeightType;
         | 
| 28 28 |  | 
| 29 | 
            -
            		static constexpr  | 
| 29 | 
            +
            		static constexpr auto tmid()
         | 
| 30 | 
            +
            		{
         | 
| 31 | 
            +
            			return serializer::to_key("LLDA");
         | 
| 32 | 
            +
            		}
         | 
| 30 33 |  | 
| 31 34 | 
             
            		Dictionary topicLabelDict;
         | 
| 32 35 |  | 
| @@ -171,7 +174,7 @@ namespace tomoto | |
| 171 174 | 
             
            			return std::make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
         | 
| 172 175 | 
             
            		}
         | 
| 173 176 |  | 
| 174 | 
            -
            		std::vector<Float>  | 
| 177 | 
            +
            		std::vector<Float> _getTopicsByDoc(const _DocType& doc, bool normalize) const
         | 
| 175 178 | 
             
            		{
         | 
| 176 179 | 
             
            			std::vector<Float> ret(this->K);
         | 
| 177 180 | 
             
            			auto maskedAlphas = this->alphas.array() * doc.labelMask.template cast<Float>().array();
         | 
| @@ -63,7 +63,7 @@ namespace tomoto | |
| 63 63 | 
             
            		}
         | 
| 64 64 |  | 
| 65 65 | 
             
            		template<int _inc> 
         | 
| 66 | 
            -
            		inline void addWordTo(_ModelState& ld, _DocType& doc,  | 
| 66 | 
            +
            		inline void addWordTo(_ModelState& ld, _DocType& doc, size_t pid, Vid vid, Tid tid, uint16_t s, uint8_t w, uint8_t r) const
         | 
| 67 67 | 
             
            		{
         | 
| 68 68 | 
             
            			const auto K = this->K;
         | 
| 69 69 |  | 
| @@ -527,7 +527,7 @@ namespace tomoto | |
| 527 527 | 
             
            			this->etaByWord.emplace(word, priors);
         | 
| 528 528 | 
             
            		}
         | 
| 529 529 |  | 
| 530 | 
            -
            		std::vector<Float>  | 
| 530 | 
            +
            		std::vector<Float> _getTopicsByDoc(const _DocType& doc, bool normalize) const
         | 
| 531 531 | 
             
            		{
         | 
| 532 532 | 
             
            			std::vector<Float> ret(this->K + KL);
         | 
| 533 533 | 
             
            			Eigen::Map<Eigen::Array<Float, -1, 1>> m{ ret.data(), this->K + KL };
         | 
| @@ -90,7 +90,7 @@ namespace tomoto | |
| 90 90 | 
             
            		}
         | 
| 91 91 |  | 
| 92 92 | 
             
            		template<int _inc> 
         | 
| 93 | 
            -
            		inline void addWordTo(_ModelState& ld, _DocType& doc,  | 
| 93 | 
            +
            		inline void addWordTo(_ModelState& ld, _DocType& doc, size_t pid, Vid vid, Tid z1, Tid z2) const
         | 
| 94 94 | 
             
            		{
         | 
| 95 95 | 
             
            			assert(vid < this->realV);
         | 
| 96 96 | 
             
            			constexpr bool _dec = _inc < 0 && _tw != TermWeight::one;
         | 
| @@ -26,7 +26,10 @@ namespace tomoto | |
| 26 26 | 
             
            		friend typename BaseClass::BaseClass;
         | 
| 27 27 | 
             
            		using WeightType = typename BaseClass::WeightType;
         | 
| 28 28 |  | 
| 29 | 
            -
            		static constexpr  | 
| 29 | 
            +
            		static constexpr auto tmid()
         | 
| 30 | 
            +
            		{
         | 
| 31 | 
            +
            			return serializer::to_key("PLDA");
         | 
| 32 | 
            +
            		}
         | 
| 30 33 |  | 
| 31 34 | 
             
            		Dictionary topicLabelDict;
         | 
| 32 35 |  | 
| @@ -178,7 +181,7 @@ namespace tomoto | |
| 178 181 | 
             
            			return std::make_unique<_DocType>(as_mutable(this)->template _updateDoc<true>(doc, rawDoc.template getMiscDefault<std::vector<std::string>>("labels")));
         | 
| 179 182 | 
             
            		}
         | 
| 180 183 |  | 
| 181 | 
            -
            		std::vector<Float>  | 
| 184 | 
            +
            		std::vector<Float> _getTopicsByDoc(const _DocType& doc, bool normalize) const
         | 
| 182 185 | 
             
            		{
         | 
| 183 186 | 
             
            			std::vector<Float> ret(this->K);
         | 
| 184 187 | 
             
            			auto maskedAlphas = this->alphas.array() * doc.labelMask.template cast<Float>().array();
         | 
| @@ -38,7 +38,10 @@ namespace tomoto | |
| 38 38 | 
             
            		friend typename BaseClass::BaseClass;
         | 
| 39 39 | 
             
            		using WeightType = typename BaseClass::WeightType;
         | 
| 40 40 |  | 
| 41 | 
            -
            		static constexpr  | 
| 41 | 
            +
            		static constexpr auto tmid()
         | 
| 42 | 
            +
            		{
         | 
| 43 | 
            +
            			return serializer::to_key("PTM");
         | 
| 44 | 
            +
            		}
         | 
| 42 45 |  | 
| 43 46 | 
             
            		uint64_t numPDocs;
         | 
| 44 47 | 
             
            		Float lambda;
         | 
| @@ -261,7 +264,7 @@ namespace tomoto | |
| 261 264 | 
             
            		{
         | 
| 262 265 | 
             
            		}
         | 
| 263 266 |  | 
| 264 | 
            -
            		std::vector<Float>  | 
| 267 | 
            +
            		std::vector<Float> _getTopicsByDoc(const _DocType& doc, bool normalize) const
         | 
| 265 268 | 
             
            		{
         | 
| 266 269 | 
             
            			std::vector<Float> ret(this->K);
         | 
| 267 270 | 
             
            			Eigen::Map<Eigen::Array<Float, -1, 1>> m{ ret.data(), this->K };
         | 
| @@ -216,7 +216,10 @@ namespace tomoto | |
| 216 216 | 
             
            		friend typename BaseClass::BaseClass;
         | 
| 217 217 | 
             
            		using WeightType = typename BaseClass::WeightType;
         | 
| 218 218 |  | 
| 219 | 
            -
            		static constexpr  | 
| 219 | 
            +
            		static constexpr auto tmid()
         | 
| 220 | 
            +
            		{
         | 
| 221 | 
            +
            			return serializer::to_key("SLDA");
         | 
| 222 | 
            +
            		}
         | 
| 220 223 |  | 
| 221 224 | 
             
            		uint64_t F; // number of response variables
         | 
| 222 225 | 
             
            		std::vector<ISLDAModel::GLM> varTypes;
         | 
| @@ -249,6 +249,7 @@ namespace tomoto | |
| 249 249 | 
             
            		virtual size_t getNumDocs() const = 0;
         | 
| 250 250 | 
             
            		virtual const Dictionary& getVocabDict() const = 0;
         | 
| 251 251 | 
             
            		virtual const std::vector<uint64_t>& getVocabCf() const = 0;
         | 
| 252 | 
            +
            		virtual std::vector<double> getVocabWeightedCf() const = 0;
         | 
| 252 253 | 
             
            		virtual const std::vector<uint64_t>& getVocabDf() const = 0;
         | 
| 253 254 |  | 
| 254 255 | 
             
            		virtual int train(size_t iteration, size_t numWorkers, ParallelScheme ps = ParallelScheme::default_, bool freeze_topics = false) = 0;
         | 
| @@ -319,6 +320,7 @@ namespace tomoto | |
| 319 320 | 
             
            		Dictionary dict;
         | 
| 320 321 | 
             
            		uint64_t realV = 0; // vocab size after removing stopwords
         | 
| 321 322 | 
             
            		uint64_t realN = 0; // total word size after removing stopwords
         | 
| 323 | 
            +
            		double weightedN = 0;
         | 
| 322 324 | 
             
            		size_t maxThreads[(size_t)ParallelScheme::size] = { 0, };
         | 
| 323 325 | 
             
            		size_t minWordCf = 0, minWordDf = 0, removeTopN = 0;
         | 
| 324 326 |  | 
| @@ -327,15 +329,17 @@ namespace tomoto | |
| 327 329 | 
             
            		void _saveModel(std::ostream& writer, bool fullModel, const std::vector<uint8_t>* extra_data) const
         | 
| 328 330 | 
             
            		{
         | 
| 329 331 | 
             
            			serializer::writeMany(writer,
         | 
| 330 | 
            -
            				serializer::to_keyz(static_cast<const _Derived*>(this)-> | 
| 331 | 
            -
            				serializer::to_keyz(static_cast<const _Derived*>(this)-> | 
| 332 | 
            +
            				serializer::to_keyz(static_cast<const _Derived*>(this)->tmid()),
         | 
| 333 | 
            +
            				serializer::to_keyz(static_cast<const _Derived*>(this)->twid())
         | 
| 334 | 
            +
            			);
         | 
| 332 335 | 
             
            			serializer::writeTaggedMany(writer, 0x00010001,
         | 
| 333 336 | 
             
            				serializer::to_keyz("dict"), dict, 
         | 
| 334 337 | 
             
            				serializer::to_keyz("vocabCf"), vocabCf,
         | 
| 335 338 | 
             
            				serializer::to_keyz("vocabDf"), vocabDf,
         | 
| 336 339 | 
             
            				serializer::to_keyz("realV"), realV,
         | 
| 337 340 | 
             
            				serializer::to_keyz("globalStep"), globalStep,
         | 
| 338 | 
            -
            				serializer::to_keyz("extra"), extra_data ? *extra_data : std::vector<uint8_t>(0) | 
| 341 | 
            +
            				serializer::to_keyz("extra"), extra_data ? *extra_data : std::vector<uint8_t>(0)
         | 
| 342 | 
            +
            			);
         | 
| 339 343 | 
             
            			serializer::writeMany(writer, *static_cast<const _Derived*>(this));
         | 
| 340 344 | 
             
            			globalState.serializerWrite(writer);
         | 
| 341 345 | 
             
            			if (fullModel)
         | 
| @@ -355,8 +359,9 @@ namespace tomoto | |
| 355 359 | 
             
            			{
         | 
| 356 360 | 
             
            				std::vector<uint8_t> extra;
         | 
| 357 361 | 
             
            				serializer::readMany(reader, 
         | 
| 358 | 
            -
            					serializer::to_keyz(static_cast<_Derived*>(this)-> | 
| 359 | 
            -
            					serializer::to_keyz(static_cast<_Derived*>(this)-> | 
| 362 | 
            +
            					serializer::to_keyz(static_cast<_Derived*>(this)->tmid()),
         | 
| 363 | 
            +
            					serializer::to_keyz(static_cast<_Derived*>(this)->twid())
         | 
| 364 | 
            +
            				);
         | 
| 360 365 | 
             
            				serializer::readTaggedMany(reader, 0x00010001, 
         | 
| 361 366 | 
             
            					serializer::to_keyz("dict"), dict,
         | 
| 362 367 | 
             
            					serializer::to_keyz("vocabCf"), vocabCf,
         | 
| @@ -370,14 +375,17 @@ namespace tomoto | |
| 370 375 | 
             
            			{
         | 
| 371 376 | 
             
            				reader.seekg(start_pos);
         | 
| 372 377 | 
             
            				serializer::readMany(reader,
         | 
| 373 | 
            -
            					serializer::to_key(static_cast<_Derived*>(this)-> | 
| 374 | 
            -
            					serializer::to_key(static_cast<_Derived*>(this)-> | 
| 375 | 
            -
            					dict, vocabCf, realV | 
| 378 | 
            +
            					serializer::to_key(static_cast<_Derived*>(this)->tmid()),
         | 
| 379 | 
            +
            					serializer::to_key(static_cast<_Derived*>(this)->twid()),
         | 
| 380 | 
            +
            					dict, vocabCf, realV
         | 
| 381 | 
            +
            				);
         | 
| 376 382 | 
             
            			}
         | 
| 377 383 | 
             
            			serializer::readMany(reader, *static_cast<_Derived*>(this));
         | 
| 378 384 | 
             
            			globalState.serializerRead(reader);
         | 
| 379 385 | 
             
            			serializer::readMany(reader, docs);
         | 
| 380 | 
            -
            			 | 
| 386 | 
            +
            			auto p = countRealN();
         | 
| 387 | 
            +
            			realN = p.first;
         | 
| 388 | 
            +
            			weightedN = p.second;
         | 
| 381 389 | 
             
            		}
         | 
| 382 390 |  | 
| 383 391 | 
             
            		template<typename _DocTy>
         | 
| @@ -490,17 +498,23 @@ namespace tomoto | |
| 490 498 | 
             
            			}
         | 
| 491 499 | 
             
            		}
         | 
| 492 500 |  | 
| 493 | 
            -
            		size_t countRealN() const
         | 
| 501 | 
            +
            		std::pair<size_t, double> countRealN() const
         | 
| 494 502 | 
             
            		{
         | 
| 495 503 | 
             
            			size_t n = 0;
         | 
| 504 | 
            +
            			double weighted = 0;
         | 
| 496 505 | 
             
            			for (auto& doc : docs)
         | 
| 497 506 | 
             
            			{
         | 
| 498 | 
            -
            				for ( | 
| 507 | 
            +
            				for (size_t i = 0; i < doc.words.size(); ++i)
         | 
| 499 508 | 
             
            				{
         | 
| 500 | 
            -
            					 | 
| 509 | 
            +
            					auto w = doc.words[i];
         | 
| 510 | 
            +
            					if (w < realV)
         | 
| 511 | 
            +
            					{
         | 
| 512 | 
            +
            						++n;
         | 
| 513 | 
            +
            						weighted += doc.wordWeights.empty() ? 1 : doc.wordWeights[i];
         | 
| 514 | 
            +
            					}
         | 
| 501 515 | 
             
            				}
         | 
| 502 516 | 
             
            			}
         | 
| 503 | 
            -
            			return n;
         | 
| 517 | 
            +
            			return std::make_pair(n, weighted);
         | 
| 504 518 | 
             
            		}
         | 
| 505 519 |  | 
| 506 520 | 
             
            		void removeStopwords(size_t minWordCnt, size_t minWordDf, size_t removeTopN)
         | 
| @@ -544,14 +558,9 @@ namespace tomoto | |
| 544 558 | 
             
            			}
         | 
| 545 559 |  | 
| 546 560 | 
             
            			dict.reorder(order);
         | 
| 547 | 
            -
            			realN = 0;
         | 
| 548 561 | 
             
            			for (auto& doc : docs)
         | 
| 549 562 | 
             
            			{
         | 
| 550 | 
            -
            				for (auto& w : doc.words)
         | 
| 551 | 
            -
            				{
         | 
| 552 | 
            -
            					w = order[w];
         | 
| 553 | 
            -
            					if (w < realV) ++realN;
         | 
| 554 | 
            -
            				}
         | 
| 563 | 
            +
            				for (auto& w : doc.words) w = order[w];
         | 
| 555 564 | 
             
            			}
         | 
| 556 565 | 
             
            		}
         | 
| 557 566 |  | 
| @@ -598,6 +607,10 @@ namespace tomoto | |
| 598 607 |  | 
| 599 608 | 
             
            		void prepare(bool initDocs = true, size_t minWordCnt = 0, size_t minWordDf = 0, size_t removeTopN = 0) override
         | 
| 600 609 | 
             
            		{
         | 
| 610 | 
            +
            			auto p = countRealN();
         | 
| 611 | 
            +
            			realN = p.first;
         | 
| 612 | 
            +
            			weightedN = p.second;
         | 
| 613 | 
            +
             | 
| 601 614 | 
             
            			maxThreads[(size_t)ParallelScheme::default_] = -1;
         | 
| 602 615 | 
             
            			maxThreads[(size_t)ParallelScheme::none] = -1;
         | 
| 603 616 | 
             
            			maxThreads[(size_t)ParallelScheme::copy_merge] = static_cast<_Derived*>(this)->template estimateMaxThreads<ParallelScheme::copy_merge>();
         | 
| @@ -697,7 +710,7 @@ namespace tomoto | |
| 697 710 |  | 
| 698 711 | 
             
            		double getLLPerWord() const override
         | 
| 699 712 | 
             
            		{
         | 
| 700 | 
            -
            			return words.empty() ? 0 : static_cast<const _Derived*>(this)->getLL() /  | 
| 713 | 
            +
            			return words.empty() ? 0 : static_cast<const _Derived*>(this)->getLL() / weightedN;
         | 
| 701 714 | 
             
            		}
         | 
| 702 715 |  | 
| 703 716 | 
             
            		double getPerplexity() const override
         | 
| @@ -797,7 +810,7 @@ namespace tomoto | |
| 797 810 |  | 
| 798 811 | 
             
            		std::vector<Float> getTopicsByDoc(const DocumentBase* doc, bool normalize) const override
         | 
| 799 812 | 
             
            		{
         | 
| 800 | 
            -
            			return static_cast<const _Derived*>(this)-> | 
| 813 | 
            +
            			return static_cast<const _Derived*>(this)->_getTopicsByDoc(*static_cast<const DocType*>(doc), normalize);
         | 
| 801 814 | 
             
            		}
         | 
| 802 815 |  | 
| 803 816 | 
             
            		std::vector<std::pair<Tid, Float>> getTopicsByDocSorted(const DocumentBase* doc, size_t topN) const override
         | 
| @@ -832,6 +845,20 @@ namespace tomoto | |
| 832 845 | 
             
            			return vocabCf;
         | 
| 833 846 | 
             
            		}
         | 
| 834 847 |  | 
| 848 | 
            +
            		std::vector<double> getVocabWeightedCf() const override
         | 
| 849 | 
            +
            		{
         | 
| 850 | 
            +
            			std::vector<double> ret(realV);
         | 
| 851 | 
            +
            			for (auto& doc : docs)
         | 
| 852 | 
            +
            			{
         | 
| 853 | 
            +
            				for (size_t i = 0; i < doc.words.size(); ++i)
         | 
| 854 | 
            +
            				{
         | 
| 855 | 
            +
            					if (doc.words[i] >= realV) continue;
         | 
| 856 | 
            +
            					ret[doc.words[i]] += doc.wordWeights.empty() ? 1 : doc.wordWeights[i];
         | 
| 857 | 
            +
            				}
         | 
| 858 | 
            +
            			}
         | 
| 859 | 
            +
            			return ret;
         | 
| 860 | 
            +
            		}
         | 
| 861 | 
            +
             | 
| 835 862 | 
             
            		const std::vector<uint64_t>& getVocabDf() const override
         | 
| 836 863 | 
             
            		{
         | 
| 837 864 | 
             
            			return vocabDf;
         |