tomoto 0.1.4 → 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/ext/tomoto/ct.cpp +8 -4
- data/ext/tomoto/dmr.cpp +10 -4
- data/ext/tomoto/dt.cpp +13 -4
- data/ext/tomoto/extconf.rb +1 -1
- data/ext/tomoto/gdmr.cpp +14 -6
- data/ext/tomoto/hdp.cpp +9 -4
- data/ext/tomoto/hlda.cpp +9 -4
- data/ext/tomoto/hpa.cpp +9 -4
- data/ext/tomoto/lda.cpp +8 -4
- data/ext/tomoto/llda.cpp +8 -4
- data/ext/tomoto/mglda.cpp +11 -1
- data/ext/tomoto/pa.cpp +9 -4
- data/ext/tomoto/plda.cpp +8 -4
- data/ext/tomoto/slda.cpp +13 -5
- data/lib/tomoto/gdmr.rb +2 -2
- data/lib/tomoto/version.rb +1 -1
- data/vendor/EigenRand/EigenRand/Core.h +6 -1107
- data/vendor/EigenRand/EigenRand/Dists/Basic.h +490 -43
- data/vendor/EigenRand/EigenRand/Dists/Discrete.h +916 -285
- data/vendor/EigenRand/EigenRand/Dists/GammaPoisson.h +85 -36
- data/vendor/EigenRand/EigenRand/Dists/NormalExp.h +1038 -290
- data/vendor/EigenRand/EigenRand/EigenRand +2 -2
- data/vendor/EigenRand/EigenRand/Macro.h +4 -4
- data/vendor/EigenRand/EigenRand/MorePacketMath.h +54 -22
- data/vendor/EigenRand/EigenRand/MvDists/Multinomial.h +222 -0
- data/vendor/EigenRand/EigenRand/MvDists/MvNormal.h +492 -0
- data/vendor/EigenRand/EigenRand/PacketFilter.h +2 -2
- data/vendor/EigenRand/EigenRand/PacketRandomEngine.h +2 -2
- data/vendor/EigenRand/EigenRand/RandUtils.h +65 -11
- data/vendor/EigenRand/EigenRand/doc.h +142 -25
- data/vendor/EigenRand/LICENSE +1 -1
- data/vendor/EigenRand/README.md +109 -24
- data/vendor/tomotopy/README.kr.rst +27 -6
- data/vendor/tomotopy/README.rst +29 -8
- data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +60 -12
- data/vendor/tomotopy/src/Labeling/FoRelevance.h +2 -2
- data/vendor/tomotopy/src/Labeling/Phraser.hpp +33 -21
- data/vendor/tomotopy/src/TopicModel/CT.h +8 -5
- data/vendor/tomotopy/src/TopicModel/CTModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/CTModel.hpp +29 -23
- data/vendor/tomotopy/src/TopicModel/DMR.h +33 -4
- data/vendor/tomotopy/src/TopicModel/DMRModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/DMRModel.hpp +231 -57
- data/vendor/tomotopy/src/TopicModel/DT.h +24 -5
- data/vendor/tomotopy/src/TopicModel/DTModel.cpp +2 -8
- data/vendor/tomotopy/src/TopicModel/DTModel.hpp +41 -28
- data/vendor/tomotopy/src/TopicModel/GDMR.h +31 -5
- data/vendor/tomotopy/src/TopicModel/GDMRModel.cpp +2 -7
- data/vendor/tomotopy/src/TopicModel/GDMRModel.hpp +211 -104
- data/vendor/tomotopy/src/TopicModel/HDP.h +11 -2
- data/vendor/tomotopy/src/TopicModel/HDPModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +52 -45
- data/vendor/tomotopy/src/TopicModel/HLDA.h +11 -2
- data/vendor/tomotopy/src/TopicModel/HLDAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +13 -16
- data/vendor/tomotopy/src/TopicModel/HPA.h +5 -2
- data/vendor/tomotopy/src/TopicModel/HPAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +51 -21
- data/vendor/tomotopy/src/TopicModel/LDA.h +9 -2
- data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +8 -8
- data/vendor/tomotopy/src/TopicModel/LDAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +70 -28
- data/vendor/tomotopy/src/TopicModel/LLDA.h +1 -2
- data/vendor/tomotopy/src/TopicModel/LLDAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/LLDAModel.hpp +22 -12
- data/vendor/tomotopy/src/TopicModel/MGLDA.h +12 -3
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.cpp +2 -10
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +42 -19
- data/vendor/tomotopy/src/TopicModel/PA.h +9 -4
- data/vendor/tomotopy/src/TopicModel/PAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/PAModel.hpp +48 -25
- data/vendor/tomotopy/src/TopicModel/PLDA.h +13 -2
- data/vendor/tomotopy/src/TopicModel/PLDAModel.cpp +2 -6
- data/vendor/tomotopy/src/TopicModel/PLDAModel.hpp +27 -19
- data/vendor/tomotopy/src/TopicModel/PT.h +12 -5
- data/vendor/tomotopy/src/TopicModel/PTModel.cpp +2 -3
- data/vendor/tomotopy/src/TopicModel/PTModel.hpp +29 -14
- data/vendor/tomotopy/src/TopicModel/SLDA.h +18 -6
- data/vendor/tomotopy/src/TopicModel/SLDAModel.cpp +2 -10
- data/vendor/tomotopy/src/TopicModel/SLDAModel.hpp +93 -43
- data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +58 -23
- data/vendor/tomotopy/src/Utils/AliasMethod.hpp +6 -6
- data/vendor/tomotopy/src/Utils/Dictionary.h +11 -0
- data/vendor/tomotopy/src/Utils/SharedString.hpp +26 -1
- data/vendor/tomotopy/src/Utils/Trie.hpp +46 -21
- data/vendor/tomotopy/src/Utils/Utils.hpp +99 -14
- data/vendor/tomotopy/src/Utils/exception.h +1 -1
- data/vendor/tomotopy/src/Utils/math.h +5 -7
- data/vendor/tomotopy/src/Utils/serializer.hpp +329 -201
- data/vendor/tomotopy/src/Utils/text.hpp +8 -0
- data/vendor/tomotopy/src/Utils/tvector.hpp +49 -7
- metadata +9 -7
|
@@ -32,11 +32,10 @@ tomotopy 란?
|
|
|
32
32
|
* Hierarchical PA (`tomotopy.HPAModel`)
|
|
33
33
|
* Correlated Topic Model (`tomotopy.CTModel`)
|
|
34
34
|
* Dynamic Topic Model (`tomotopy.DTModel`)
|
|
35
|
+
* Pseudo-document based Topic Model (`tomotopy.PTModel`)
|
|
35
36
|
|
|
36
37
|
더 자세한 정보는 https://bab2min.github.io/tomotopy/index.kr.html 에서 확인하시길 바랍니다.
|
|
37
38
|
|
|
38
|
-
tomotopy의 가장 최신버전은 0.10.2 입니다.
|
|
39
|
-
|
|
40
39
|
시작하기
|
|
41
40
|
---------------
|
|
42
41
|
다음과 같이 pip를 이용하면 tomotopy를 쉽게 설치할 수 있습니다.
|
|
@@ -47,10 +46,10 @@ tomotopy의 가장 최신버전은 0.10.2 입니다.
|
|
|
47
46
|
|
|
48
47
|
지원하는 운영체제 및 Python 버전은 다음과 같습니다:
|
|
49
48
|
|
|
50
|
-
* Python 3.
|
|
51
|
-
* Python 3.
|
|
52
|
-
* Python 3.
|
|
53
|
-
* Python 3.
|
|
49
|
+
* Python 3.6 이상이 설치된 Linux (x86-64)
|
|
50
|
+
* Python 3.6 이상이 설치된 macOS 10.13나 그 이후 버전
|
|
51
|
+
* Python 3.6 이상이 설치된 Windows 7이나 그 이후 버전 (x86, x86-64)
|
|
52
|
+
* Python 3.6 이상이 설치된 다른 운영체제: 이 경우는 c++14 호환 컴파일러를 통한 소스코드 컴파일이 필요합니다.
|
|
54
53
|
|
|
55
54
|
설치가 끝난 뒤에는 다음과 같이 Python3에서 바로 import하여 tomotopy를 사용할 수 있습니다.
|
|
56
55
|
::
|
|
@@ -255,6 +254,28 @@ tomotopy의 Python3 예제 코드는 https://github.com/bab2min/tomotopy/blob/ma
|
|
|
255
254
|
|
|
256
255
|
역사
|
|
257
256
|
-------
|
|
257
|
+
* 0.12.0 (2021-04-26)
|
|
258
|
+
* 이제 `tomotopy.DMRModel`와 `tomotopy.GDMRModel`가 다중 메타데이터를 지원합니다. (https://github.com/bab2min/tomotopy/blob/main/examples/dmr_multi_label.py 참조)
|
|
259
|
+
* `tomotopy.GDMRModel`의 성능이 개선되었습니다.
|
|
260
|
+
* 깊은 복사를 수행하는 `copy()` 메소드가 모든 토픽 모델 클래스에 추가되었습니다.
|
|
261
|
+
* `min_cf`, `min_df` 등에 의해 학습에서 제외된 단어가 잘못된 토픽id값을 가지는 문제가 해결되었습니다. 이제 제외된 단어들은 토픽id로 모두 `-1` 값을 가집니다.
|
|
262
|
+
* 이제 `tomotopy`에 의해 생성되는 예외 및 경고가 모두 Python 표준 타입을 따릅니다.
|
|
263
|
+
* 컴파일러 요구사항이 C++14로 상향되었습니다.
|
|
264
|
+
|
|
265
|
+
* 0.11.1 (2021-03-28)
|
|
266
|
+
* 비대칭 alpha와 관련된 치명적인 버그가 수정되었습니다. 이 버그로 인해 0.11.0 버전은 릴리즈에서 삭제되었습니다.
|
|
267
|
+
|
|
268
|
+
* 0.11.0 (2021-03-26)
|
|
269
|
+
* 짧은 텍스트를 위한 토픽 모델인 `tomotopy.PTModel`가 추가되었습니다.
|
|
270
|
+
* `tomotopy.HDPModel.infer`가 종종 segmentation fault를 발생시키는 문제가 해결되었습니다.
|
|
271
|
+
* numpy API 버전 충돌이 해결되었습니다.
|
|
272
|
+
* 이제 비대칭 문헌-토픽 사전 분포가 지원됩니다.
|
|
273
|
+
* 토픽 모델 객체를 메모리 상의 `bytes`로 직렬화하는 기능이 지원됩니다.
|
|
274
|
+
* `get_topic_dist()`, `get_topic_word_dist()`, `get_sub_topic_dist()`에 결과의 정규화 여부를 조절하는 `normalize` 인자가 추가되었습니다.
|
|
275
|
+
* `tomotopy.DMRModel.lambdas`와 `tomotopy.DMRModel.alpha`가 잘못된 값을 제공하던 문제가 해결되었습니다.
|
|
276
|
+
* `tomotopy.GDMRModel`에 범주형 메타데이터 지원이 추가되었습니다. (https://github.com/bab2min/tomotopy/blob/main/examples/gdmr_both_categorical_and_numerical.py 참조)
|
|
277
|
+
* Python3.5 지원이 종료되었습니다.
|
|
278
|
+
|
|
258
279
|
* 0.10.2 (2021-02-16)
|
|
259
280
|
* `tomotopy.CTModel.train`가 큰 K값에 대해 실패하는 문제가 수정되었습니다.
|
|
260
281
|
* `tomotopy.utils.Corpus`가 `uid`값을 잃는 문제가 수정되었습니다.
|
data/vendor/tomotopy/README.rst
CHANGED
|
@@ -32,12 +32,11 @@ The current version of `tomoto` supports several major topic models including
|
|
|
32
32
|
* Pachinko Allocation (`tomotopy.PAModel`)
|
|
33
33
|
* Hierarchical PA (`tomotopy.HPAModel`)
|
|
34
34
|
* Correlated Topic Model (`tomotopy.CTModel`)
|
|
35
|
-
* Dynamic Topic Model (`tomotopy.DTModel`)
|
|
35
|
+
* Dynamic Topic Model (`tomotopy.DTModel`)
|
|
36
|
+
* Pseudo-document based Topic Model (`tomotopy.PTModel`)
|
|
36
37
|
|
|
37
38
|
Please visit https://bab2min.github.io/tomotopy to see more information.
|
|
38
39
|
|
|
39
|
-
The most recent version of tomotopy is 0.10.2.
|
|
40
|
-
|
|
41
40
|
Getting Started
|
|
42
41
|
---------------
|
|
43
42
|
You can install tomotopy easily using pip. (https://pypi.org/project/tomotopy/)
|
|
@@ -48,10 +47,10 @@ You can install tomotopy easily using pip. (https://pypi.org/project/tomotopy/)
|
|
|
48
47
|
|
|
49
48
|
The supported OS and Python versions are:
|
|
50
49
|
|
|
51
|
-
* Linux (x86-64) with Python >= 3.
|
|
52
|
-
* macOS >= 10.13 with Python >= 3.
|
|
53
|
-
* Windows 7 or later (x86, x86-64) with Python >= 3.
|
|
54
|
-
* Other OS with Python >= 3.
|
|
50
|
+
* Linux (x86-64) with Python >= 3.6
|
|
51
|
+
* macOS >= 10.13 with Python >= 3.6
|
|
52
|
+
* Windows 7 or later (x86, x86-64) with Python >= 3.6
|
|
53
|
+
* Other OS with Python >= 3.6: Compilation from source code required (with c++14 compatible compiler)
|
|
55
54
|
|
|
56
55
|
After installing, you can start tomotopy by just importing.
|
|
57
56
|
::
|
|
@@ -261,6 +260,28 @@ meaning you can use it for any reasonable purpose and remain in complete ownersh
|
|
|
261
260
|
|
|
262
261
|
History
|
|
263
262
|
-------
|
|
263
|
+
* 0.12.0 (2021-04-26)
|
|
264
|
+
* Now `tomotopy.DMRModel` and `tomotopy.GDMRModel` support multiple values of metadata (see https://github.com/bab2min/tomotopy/blob/main/examples/dmr_multi_label.py )
|
|
265
|
+
* The performance of `tomotopy.GDMRModel` was improved.
|
|
266
|
+
* A `copy()` method has been added for all topic models to do a deep copy.
|
|
267
|
+
* An issue was fixed where words that are excluded from training (by `min_cf`, `min_df`) had an incorrect topic id. Now all excluded words have `-1` as their topic id.
|
|
268
|
+
* Now all exceptions and warnings that are generated by `tomotopy` follow standard Python types.
|
|
269
|
+
* Compiler requirements have been raised to C++14.
|
|
270
|
+
|
|
271
|
+
* 0.11.1 (2021-03-28)
|
|
272
|
+
* A critical bug related to asymmetric alphas was fixed. Due to this bug, version 0.11.0 has been removed from releases.
|
|
273
|
+
|
|
274
|
+
* 0.11.0 (2021-03-26) (removed)
|
|
275
|
+
* A new topic model `tomotopy.PTModel` for short texts was added into the package.
|
|
276
|
+
* An issue was fixed where `tomotopy.HDPModel.infer` sometimes caused a segmentation fault.
|
|
277
|
+
* A mismatch of numpy API version was fixed.
|
|
278
|
+
* Now asymmetric document-topic priors are supported.
|
|
279
|
+
* Serializing topic models to `bytes` in memory is supported.
|
|
280
|
+
* An argument `normalize` was added to `get_topic_dist()`, `get_topic_word_dist()` and `get_sub_topic_dist()` for controlling normalization of results.
|
|
281
|
+
* Now `tomotopy.DMRModel.lambdas` and `tomotopy.DMRModel.alpha` give correct values.
|
|
282
|
+
* Support for categorical metadata was added to `tomotopy.GDMRModel` (see https://github.com/bab2min/tomotopy/blob/main/examples/gdmr_both_categorical_and_numerical.py ).
|
|
283
|
+
* Python3.5 support was dropped.
|
|
284
|
+
|
|
264
285
|
* 0.10.2 (2021-02-16)
|
|
265
286
|
* An issue was fixed where `tomotopy.CTModel.train` fails with large K.
|
|
266
287
|
* An issue was fixed where `tomotopy.utils.Corpus` loses its `uid` values.
|
|
@@ -273,7 +294,7 @@ History
|
|
|
273
294
|
|
|
274
295
|
* 0.10.0 (2020-12-19)
|
|
275
296
|
* The interfaces of `tomotopy.utils.Corpus` and of `tomotopy.LDAModel.docs` were unified. Now you can access the documents in a corpus in the same manner.
|
|
276
|
-
* __getitem__ of `tomotopy.utils.Corpus` was improved. Not only indexing by int, but also by Iterable[int], slicing are supported. Also indexing by uid is supported.
|
|
297
|
+
* `__getitem__` of `tomotopy.utils.Corpus` was improved. Not only indexing by int, but also indexing by Iterable[int] and slicing are supported. Indexing by uid is also supported.
|
|
277
298
|
* New methods `tomotopy.utils.Corpus.extract_ngrams` and `tomotopy.utils.Corpus.concat_ngrams` were added. They extract n-gram collocations using PMI and concatenate them into single words.
|
|
278
299
|
* A new method `tomotopy.LDAModel.add_corpus` was added, and `tomotopy.LDAModel.infer` can receive corpus as input.
|
|
279
300
|
* A new module `tomotopy.coherence` was added. It provides a way to calculate the coherence of a model.
|
|
@@ -6,6 +6,55 @@
|
|
|
6
6
|
|
|
7
7
|
using namespace tomoto::label;
|
|
8
8
|
|
|
9
|
+
template<bool reverse = false>
|
|
10
|
+
class DocWordIterator
|
|
11
|
+
{
|
|
12
|
+
const tomoto::DocumentBase* doc = nullptr;
|
|
13
|
+
size_t n = 0;
|
|
14
|
+
public:
|
|
15
|
+
DocWordIterator(const tomoto::DocumentBase* _doc = nullptr, size_t _n = 0)
|
|
16
|
+
: doc{ _doc }, n{ _n }
|
|
17
|
+
{
|
|
18
|
+
}
|
|
19
|
+
|
|
20
|
+
tomoto::Vid operator[](size_t i) const
|
|
21
|
+
{
|
|
22
|
+
return doc->words[doc->wOrder.empty() ? (n + i) : doc->wOrder[n + i]];
|
|
23
|
+
}
|
|
24
|
+
|
|
25
|
+
tomoto::Vid operator*() const
|
|
26
|
+
{
|
|
27
|
+
return doc->words[doc->wOrder.empty() ? n : doc->wOrder[n]];
|
|
28
|
+
}
|
|
29
|
+
|
|
30
|
+
bool operator==(const DocWordIterator& o) const
|
|
31
|
+
{
|
|
32
|
+
return doc == o.doc && n == o.n;
|
|
33
|
+
}
|
|
34
|
+
|
|
35
|
+
bool operator!=(const DocWordIterator& o) const
|
|
36
|
+
{
|
|
37
|
+
return !operator==(o);
|
|
38
|
+
}
|
|
39
|
+
|
|
40
|
+
DocWordIterator& operator++()
|
|
41
|
+
{
|
|
42
|
+
if (reverse) --n;
|
|
43
|
+
else ++n;
|
|
44
|
+
return *this;
|
|
45
|
+
}
|
|
46
|
+
|
|
47
|
+
DocWordIterator operator+(ptrdiff_t o) const
|
|
48
|
+
{
|
|
49
|
+
return { doc, (size_t)((ptrdiff_t)n + o) };
|
|
50
|
+
}
|
|
51
|
+
|
|
52
|
+
DocWordIterator operator-(ptrdiff_t o) const
|
|
53
|
+
{
|
|
54
|
+
return { doc, (size_t)((ptrdiff_t)n - o) };
|
|
55
|
+
}
|
|
56
|
+
};
|
|
57
|
+
|
|
9
58
|
class DocWrapper
|
|
10
59
|
{
|
|
11
60
|
const tomoto::DocumentBase* doc;
|
|
@@ -25,24 +74,24 @@ public:
|
|
|
25
74
|
return doc->words[doc->wOrder.empty() ? idx : doc->wOrder[idx]];
|
|
26
75
|
}
|
|
27
76
|
|
|
28
|
-
|
|
77
|
+
DocWordIterator<> begin() const
|
|
29
78
|
{
|
|
30
|
-
return doc
|
|
79
|
+
return { doc, 0 };
|
|
31
80
|
}
|
|
32
81
|
|
|
33
|
-
|
|
82
|
+
DocWordIterator<> end() const
|
|
34
83
|
{
|
|
35
|
-
return doc->words.
|
|
84
|
+
return { doc, doc->words.size() };
|
|
36
85
|
}
|
|
37
86
|
|
|
38
|
-
|
|
87
|
+
DocWordIterator<true> rbegin() const
|
|
39
88
|
{
|
|
40
|
-
return doc->words.
|
|
89
|
+
return { doc, doc->words.size() };
|
|
41
90
|
}
|
|
42
91
|
|
|
43
|
-
|
|
92
|
+
DocWordIterator<true> rend() const
|
|
44
93
|
{
|
|
45
|
-
return doc
|
|
94
|
+
return { doc, 0 };
|
|
46
95
|
}
|
|
47
96
|
};
|
|
48
97
|
|
|
@@ -99,7 +148,6 @@ std::vector<Candidate> PMIExtractor::extract(const tomoto::ITopicModel* tm) cons
|
|
|
99
148
|
return candidates;
|
|
100
149
|
}
|
|
101
150
|
|
|
102
|
-
|
|
103
151
|
std::vector<Candidate> tomoto::label::PMIBEExtractor::extract(const ITopicModel* tm) const
|
|
104
152
|
{
|
|
105
153
|
auto& vocabFreqs = tm->getVocabCf();
|
|
@@ -217,11 +265,11 @@ void FoRelevance::estimateContexts()
|
|
|
217
265
|
}
|
|
218
266
|
}
|
|
219
267
|
|
|
220
|
-
|
|
268
|
+
Matrix wordTopicDist{ tm->getV(), tm->getK() };
|
|
221
269
|
for (size_t i = 0; i < tm->getK(); ++i)
|
|
222
270
|
{
|
|
223
271
|
auto dist = tm->getWidsByTopic(i);
|
|
224
|
-
wordTopicDist.col(i) = Eigen::Map<
|
|
272
|
+
wordTopicDist.col(i) = Eigen::Map<Vector>{ dist.data(), (Eigen::Index)dist.size() };
|
|
225
273
|
}
|
|
226
274
|
|
|
227
275
|
size_t totDocCnt = 0;
|
|
@@ -256,7 +304,7 @@ void FoRelevance::estimateContexts()
|
|
|
256
304
|
}
|
|
257
305
|
|
|
258
306
|
size_t docCnt = 0;
|
|
259
|
-
|
|
307
|
+
Vector wcPMI = Vector::Zero(this->tm->getV());
|
|
260
308
|
for (auto& docId : c.docIds)
|
|
261
309
|
{
|
|
262
310
|
thread_local Eigen::VectorXi bdf(this->tm->getV());
|
|
@@ -93,8 +93,8 @@ namespace tomoto
|
|
|
93
93
|
if (!numWorkers) numWorkers = std::thread::hardware_concurrency();
|
|
94
94
|
if (numWorkers > 1)
|
|
95
95
|
{
|
|
96
|
-
pool = make_unique<ThreadPool>(numWorkers);
|
|
97
|
-
mtx = make_unique<std::mutex[]>(numWorkers);
|
|
96
|
+
pool = std::make_unique<ThreadPool>(numWorkers);
|
|
97
|
+
mtx = std::make_unique<std::mutex[]>(numWorkers);
|
|
98
98
|
}
|
|
99
99
|
|
|
100
100
|
for (; candFirst != candEnd; ++candFirst)
|
|
@@ -1,14 +1,37 @@
|
|
|
1
1
|
#pragma once
|
|
2
2
|
|
|
3
3
|
#include <vector>
|
|
4
|
+
#include <map>
|
|
4
5
|
#include <unordered_map>
|
|
5
6
|
#include "Labeler.h"
|
|
6
7
|
#include "../Utils/Trie.hpp"
|
|
7
8
|
|
|
9
|
+
#ifdef TMT_USE_BTREE
|
|
10
|
+
#include "btree/map.h"
|
|
11
|
+
#else
|
|
12
|
+
#endif
|
|
13
|
+
|
|
8
14
|
namespace tomoto
|
|
9
15
|
{
|
|
10
16
|
namespace phraser
|
|
11
17
|
{
|
|
18
|
+
#ifdef TMT_USE_BTREE
|
|
19
|
+
template<typename K, typename V> using map = btree::map<K, V>;
|
|
20
|
+
#else
|
|
21
|
+
template<typename K, typename V> using map = std::map<K, V>;
|
|
22
|
+
#endif
|
|
23
|
+
|
|
24
|
+
namespace detail
|
|
25
|
+
{
|
|
26
|
+
struct vvhash
|
|
27
|
+
{
|
|
28
|
+
size_t operator()(const std::pair<Vid, Vid>& k) const
|
|
29
|
+
{
|
|
30
|
+
return std::hash<Vid>{}(k.first) ^ std::hash<Vid>{}(k.second);
|
|
31
|
+
}
|
|
32
|
+
};
|
|
33
|
+
}
|
|
34
|
+
|
|
12
35
|
template<typename _DocIter>
|
|
13
36
|
void countUnigrams(std::vector<size_t>& unigramCf, std::vector<size_t>& unigramDf,
|
|
14
37
|
_DocIter docBegin, _DocIter docEnd
|
|
@@ -30,9 +53,9 @@ namespace tomoto
|
|
|
30
53
|
}
|
|
31
54
|
}
|
|
32
55
|
|
|
33
|
-
template<typename _DocIter, typename
|
|
34
|
-
void countBigrams(
|
|
35
|
-
|
|
56
|
+
template<typename _DocIter, typename _Freqs>
|
|
57
|
+
void countBigrams(map<std::pair<Vid, Vid>, size_t>& bigramCf,
|
|
58
|
+
map<std::pair<Vid, Vid>, size_t>& bigramDf,
|
|
36
59
|
_DocIter docBegin, _DocIter docEnd,
|
|
37
60
|
_Freqs&& vocabFreqs, _Freqs&& vocabDf,
|
|
38
61
|
size_t candMinCnt, size_t candMinDf
|
|
@@ -40,7 +63,7 @@ namespace tomoto
|
|
|
40
63
|
{
|
|
41
64
|
for (auto docIt = docBegin; docIt != docEnd; ++docIt)
|
|
42
65
|
{
|
|
43
|
-
std::unordered_set<std::pair<Vid, Vid>,
|
|
66
|
+
std::unordered_set<std::pair<Vid, Vid>, detail::vvhash> uniqBigram;
|
|
44
67
|
auto doc = *docIt;
|
|
45
68
|
if (!doc.size()) continue;
|
|
46
69
|
Vid prevWord = doc[0];
|
|
@@ -202,17 +225,6 @@ namespace tomoto
|
|
|
202
225
|
return std::move(data[0]);
|
|
203
226
|
}
|
|
204
227
|
|
|
205
|
-
namespace detail
|
|
206
|
-
{
|
|
207
|
-
struct vvhash
|
|
208
|
-
{
|
|
209
|
-
size_t operator()(const std::pair<Vid, Vid>& k) const
|
|
210
|
-
{
|
|
211
|
-
return std::hash<Vid>{}(k.first) ^ std::hash<Vid>{}(k.second);
|
|
212
|
-
}
|
|
213
|
-
};
|
|
214
|
-
}
|
|
215
|
-
|
|
216
228
|
template<typename _DocIter, typename _Freqs>
|
|
217
229
|
std::vector<label::Candidate> extractPMINgrams(_DocIter docBegin, _DocIter docEnd,
|
|
218
230
|
_Freqs&& vocabFreqs, _Freqs&& vocabDf,
|
|
@@ -221,13 +233,13 @@ namespace tomoto
|
|
|
221
233
|
ThreadPool* pool = nullptr)
|
|
222
234
|
{
|
|
223
235
|
// counting unigrams & bigrams
|
|
224
|
-
|
|
236
|
+
map<std::pair<Vid, Vid>, size_t> bigramCnt, bigramDf;
|
|
225
237
|
|
|
226
238
|
if (pool && pool->getNumWorkers() > 1)
|
|
227
239
|
{
|
|
228
240
|
using LocalCfDf = std::pair<
|
|
229
|
-
|
|
230
|
-
|
|
241
|
+
decltype(bigramCnt),
|
|
242
|
+
decltype(bigramDf)
|
|
231
243
|
>;
|
|
232
244
|
std::vector<LocalCfDf> localdata(pool->getNumWorkers());
|
|
233
245
|
std::vector<std::future<void>> futures;
|
|
@@ -363,13 +375,13 @@ namespace tomoto
|
|
|
363
375
|
ThreadPool* pool = nullptr)
|
|
364
376
|
{
|
|
365
377
|
// counting unigrams & bigrams
|
|
366
|
-
|
|
378
|
+
map<std::pair<Vid, Vid>, size_t> bigramCnt, bigramDf;
|
|
367
379
|
|
|
368
380
|
if (pool && pool->getNumWorkers() > 1)
|
|
369
381
|
{
|
|
370
382
|
using LocalCfDf = std::pair<
|
|
371
|
-
|
|
372
|
-
|
|
383
|
+
decltype(bigramCnt),
|
|
384
|
+
decltype(bigramDf)
|
|
373
385
|
>;
|
|
374
386
|
std::vector<LocalCfDf> localdata(pool->getNumWorkers());
|
|
375
387
|
std::vector<std::future<void>> futures;
|
|
@@ -8,20 +8,23 @@ namespace tomoto
|
|
|
8
8
|
{
|
|
9
9
|
using BaseDocument = DocumentLDA<_tw>;
|
|
10
10
|
using DocumentLDA<_tw>::DocumentLDA;
|
|
11
|
-
|
|
12
|
-
|
|
11
|
+
Matrix beta; // Dim: (K, betaSample)
|
|
12
|
+
Vector smBeta; // Dim: K
|
|
13
13
|
|
|
14
14
|
DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 0, smBeta);
|
|
15
15
|
DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseDocument, 1, 0x00010001, smBeta);
|
|
16
16
|
};
|
|
17
17
|
|
|
18
|
+
struct CTArgs : public LDAArgs
|
|
19
|
+
{
|
|
20
|
+
|
|
21
|
+
};
|
|
22
|
+
|
|
18
23
|
class ICTModel : public ILDAModel
|
|
19
24
|
{
|
|
20
25
|
public:
|
|
21
26
|
using DefaultDocType = DocumentCTM<TermWeight::one>;
|
|
22
|
-
static ICTModel* create(TermWeight _weight,
|
|
23
|
-
Float smoothingAlpha = 0.1, Float _eta = 0.01,
|
|
24
|
-
size_t seed = std::random_device{}(),
|
|
27
|
+
static ICTModel* create(TermWeight _weight, const CTArgs& args,
|
|
25
28
|
bool scalarRng = false);
|
|
26
29
|
|
|
27
30
|
virtual void setNumBetaSample(size_t numSample) = 0;
|
|
@@ -2,12 +2,8 @@
|
|
|
2
2
|
|
|
3
3
|
namespace tomoto
|
|
4
4
|
{
|
|
5
|
-
|
|
6
|
-
template class CTModel<TermWeight::idf>;
|
|
7
|
-
template class CTModel<TermWeight::pmi>;*/
|
|
8
|
-
|
|
9
|
-
ICTModel* ICTModel::create(TermWeight _weight, size_t _K, Float smoothingAlpha, Float _eta, size_t seed, bool scalarRng)
|
|
5
|
+
ICTModel* ICTModel::create(TermWeight _weight, const CTArgs& args, bool scalarRng)
|
|
10
6
|
{
|
|
11
|
-
TMT_SWITCH_TW(_weight, scalarRng, CTModel,
|
|
7
|
+
TMT_SWITCH_TW(_weight, scalarRng, CTModel, args);
|
|
12
8
|
}
|
|
13
9
|
}
|
|
@@ -56,22 +56,22 @@ namespace tomoto
|
|
|
56
56
|
|
|
57
57
|
void updateBeta(_DocType& doc, _RandGen& rg) const
|
|
58
58
|
{
|
|
59
|
-
|
|
59
|
+
Vector pbeta, lowerBound, upperBound;
|
|
60
60
|
constexpr Float epsilon = 1e-8;
|
|
61
61
|
constexpr size_t burnIn = 3;
|
|
62
62
|
|
|
63
|
-
pbeta = lowerBound = upperBound =
|
|
63
|
+
pbeta = lowerBound = upperBound = Vector::Zero(this->K);
|
|
64
64
|
for (size_t i = 0; i < numBetaSample + burnIn; ++i)
|
|
65
65
|
{
|
|
66
|
-
if (i == 0) pbeta =
|
|
66
|
+
if (i == 0) pbeta = Vector::Ones(this->K);
|
|
67
67
|
else pbeta = doc.beta.col(i % numBetaSample).array().exp();
|
|
68
68
|
|
|
69
69
|
Float betaESum = pbeta.sum() + 1;
|
|
70
70
|
pbeta /= betaESum;
|
|
71
71
|
for (size_t k = 0; k < this->K; ++k)
|
|
72
72
|
{
|
|
73
|
-
Float N_k = doc.numByTopic[k] + this->
|
|
74
|
-
Float N_nk = doc.getSumWordWeight() + this->
|
|
73
|
+
Float N_k = doc.numByTopic[k] + this->alphas[k];
|
|
74
|
+
Float N_nk = doc.getSumWordWeight() + this->alphas[k] * (this->K + 1) - N_k;
|
|
75
75
|
Float u1 = rg.uniform_real(), u2 = rg.uniform_real();
|
|
76
76
|
Float max_uk = epsilon + pow(u1, (Float)1 / N_k) * (pbeta[k] - epsilon);
|
|
77
77
|
Float min_unk = (1 - pow(u2, (Float)1 / N_nk))
|
|
@@ -84,7 +84,7 @@ namespace tomoto
|
|
|
84
84
|
upperBound[k] = std::max(std::min(upperBound[k], (Float)100), (Float)-100);
|
|
85
85
|
if (lowerBound[k] > upperBound[k])
|
|
86
86
|
{
|
|
87
|
-
THROW_ERROR_WITH_INFO(
|
|
87
|
+
THROW_ERROR_WITH_INFO(exc::TrainingError,
|
|
88
88
|
text::format("Bound Error: LB(%f) > UB(%f)\n"
|
|
89
89
|
"max_uk: %f, min_unk: %f, c: %f", lowerBound[k], upperBound[k], max_uk, min_unk, c));
|
|
90
90
|
}
|
|
@@ -96,14 +96,14 @@ namespace tomoto
|
|
|
96
96
|
topicPrior, lowerBound, upperBound, rg, numTMNSample);
|
|
97
97
|
|
|
98
98
|
if (!std::isfinite(doc.beta.col((i + 1) % numBetaSample)[0]))
|
|
99
|
-
THROW_ERROR_WITH_INFO(
|
|
99
|
+
THROW_ERROR_WITH_INFO(exc::TrainingError,
|
|
100
100
|
text::format("doc.beta.col(%d) is %f", (i + 1) % numBetaSample,
|
|
101
101
|
doc.beta.col((i + 1) % numBetaSample)[0]));
|
|
102
102
|
}
|
|
103
103
|
catch (const std::runtime_error& e)
|
|
104
104
|
{
|
|
105
105
|
std::cerr << e.what() << std::endl;
|
|
106
|
-
THROW_ERROR_WITH_INFO(
|
|
106
|
+
THROW_ERROR_WITH_INFO(exc::TrainingError, e.what());
|
|
107
107
|
}
|
|
108
108
|
}
|
|
109
109
|
|
|
@@ -157,7 +157,7 @@ namespace tomoto
|
|
|
157
157
|
}
|
|
158
158
|
}
|
|
159
159
|
|
|
160
|
-
int restoreFromTrainingError(const
|
|
160
|
+
int restoreFromTrainingError(const exc::TrainingError& e, ThreadPool& pool, _ModelState* localData, _RandGen* rgs)
|
|
161
161
|
{
|
|
162
162
|
std::cerr << "Failed to sample! Reset prior and retry!" << std::endl;
|
|
163
163
|
const size_t chStride = std::min(pool.getNumWorkers() * 8, this->docs.size());
|
|
@@ -186,7 +186,7 @@ namespace tomoto
|
|
|
186
186
|
return this->docs[i / numBetaSample].beta.col(i % numBetaSample);
|
|
187
187
|
}, this->docs.size() * numBetaSample);
|
|
188
188
|
if (!std::isfinite(topicPrior.mean[0]))
|
|
189
|
-
THROW_ERROR_WITH_INFO(
|
|
189
|
+
THROW_ERROR_WITH_INFO(exc::TrainingError,
|
|
190
190
|
text::format("topicPrior.mean is %f", topicPrior.mean[0]));
|
|
191
191
|
}
|
|
192
192
|
|
|
@@ -194,21 +194,20 @@ namespace tomoto
|
|
|
194
194
|
double getLLDocs(_DocIter _first, _DocIter _last) const
|
|
195
195
|
{
|
|
196
196
|
const auto K = this->K;
|
|
197
|
-
const auto alpha = this->alpha;
|
|
198
197
|
|
|
199
198
|
double ll = 0;
|
|
200
199
|
for (; _first != _last; ++_first)
|
|
201
200
|
{
|
|
202
201
|
auto& doc = *_first;
|
|
203
|
-
|
|
202
|
+
Vector pbeta = doc.smBeta.array().log();
|
|
204
203
|
Float last = pbeta[K - 1];
|
|
205
204
|
for (Tid k = 0; k < K; ++k)
|
|
206
205
|
{
|
|
207
|
-
ll += pbeta[k] * (doc.numByTopic[k] +
|
|
206
|
+
ll += pbeta[k] * (doc.numByTopic[k] + this->alphas[k]) - math::lgammaT(doc.numByTopic[k] + this->alphas[k] + 1);
|
|
208
207
|
}
|
|
209
208
|
pbeta.array() -= last;
|
|
210
209
|
ll += topicPrior.getLL(pbeta.head(this->K));
|
|
211
|
-
ll += math::lgammaT(doc.getSumWordWeight() +
|
|
210
|
+
ll += math::lgammaT(doc.getSumWordWeight() + this->alphas.sum() + 1);
|
|
212
211
|
}
|
|
213
212
|
return ll;
|
|
214
213
|
}
|
|
@@ -216,8 +215,8 @@ namespace tomoto
|
|
|
216
215
|
void prepareDoc(_DocType& doc, size_t docId, size_t wordSize) const
|
|
217
216
|
{
|
|
218
217
|
BaseClass::prepareDoc(doc, docId, wordSize);
|
|
219
|
-
doc.beta =
|
|
220
|
-
doc.smBeta =
|
|
218
|
+
doc.beta = Matrix::Zero(this->K, numBetaSample);
|
|
219
|
+
doc.smBeta = Vector::Constant(this->K, (Float)1 / this->K);
|
|
221
220
|
}
|
|
222
221
|
|
|
223
222
|
void updateDocs()
|
|
@@ -225,7 +224,7 @@ namespace tomoto
|
|
|
225
224
|
BaseClass::updateDocs();
|
|
226
225
|
for (auto& doc : this->docs)
|
|
227
226
|
{
|
|
228
|
-
doc.beta =
|
|
227
|
+
doc.beta = Matrix::Zero(this->K, numBetaSample);
|
|
229
228
|
}
|
|
230
229
|
}
|
|
231
230
|
|
|
@@ -242,17 +241,24 @@ namespace tomoto
|
|
|
242
241
|
DEFINE_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 0, numBetaSample, numTMNSample, topicPrior);
|
|
243
242
|
DEFINE_TAGGED_SERIALIZER_AFTER_BASE_WITH_VERSION(BaseClass, 1, 0x00010001, numBetaSample, numTMNSample, topicPrior);
|
|
244
243
|
|
|
245
|
-
CTModel(
|
|
246
|
-
: BaseClass(
|
|
244
|
+
CTModel(const CTArgs& args)
|
|
245
|
+
: BaseClass(args)
|
|
247
246
|
{
|
|
248
247
|
this->optimInterval = 2;
|
|
249
248
|
}
|
|
250
249
|
|
|
251
|
-
std::vector<Float> getTopicsByDoc(const _DocType& doc) const
|
|
250
|
+
std::vector<Float> getTopicsByDoc(const _DocType& doc, bool normalize) const
|
|
252
251
|
{
|
|
253
252
|
std::vector<Float> ret(this->K);
|
|
254
|
-
Eigen::Map<Eigen::
|
|
255
|
-
|
|
253
|
+
Eigen::Map<Eigen::Array<Float, -1, 1>> m{ ret.data(), this->K };
|
|
254
|
+
if (normalize)
|
|
255
|
+
{
|
|
256
|
+
m = (doc.numByTopic.array().template cast<Float>() + this->alphas.array()) / (doc.getSumWordWeight() + this->alphas.sum());
|
|
257
|
+
}
|
|
258
|
+
else
|
|
259
|
+
{
|
|
260
|
+
m = doc.numByTopic.array().template cast<Float>() + this->alphas.array();
|
|
261
|
+
}
|
|
256
262
|
return ret;
|
|
257
263
|
}
|
|
258
264
|
|
|
@@ -268,7 +274,7 @@ namespace tomoto
|
|
|
268
274
|
|
|
269
275
|
std::vector<Float> getCorrelationTopic(Tid k) const override
|
|
270
276
|
{
|
|
271
|
-
|
|
277
|
+
Vector ret = topicPrior.cov.col(k).array() / (topicPrior.cov.diagonal().array() * topicPrior.cov(k, k)).sqrt();
|
|
272
278
|
return { ret.data(), ret.data() + ret.size() };
|
|
273
279
|
}
|
|
274
280
|
|