tomoto 0.1.3 → 0.1.4
This diff shows the changes between two publicly released versions of the package, as published to their respective registries. It is provided for informational purposes only.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/LICENSE.txt +1 -1
- data/README.md +7 -0
- data/ext/tomoto/ct.cpp +54 -0
- data/ext/tomoto/dmr.cpp +62 -0
- data/ext/tomoto/dt.cpp +82 -0
- data/ext/tomoto/ext.cpp +27 -773
- data/ext/tomoto/gdmr.cpp +34 -0
- data/ext/tomoto/hdp.cpp +42 -0
- data/ext/tomoto/hlda.cpp +66 -0
- data/ext/tomoto/hpa.cpp +27 -0
- data/ext/tomoto/lda.cpp +250 -0
- data/ext/tomoto/llda.cpp +29 -0
- data/ext/tomoto/mglda.cpp +71 -0
- data/ext/tomoto/pa.cpp +27 -0
- data/ext/tomoto/plda.cpp +29 -0
- data/ext/tomoto/slda.cpp +40 -0
- data/ext/tomoto/utils.h +84 -0
- data/lib/tomoto/tomoto.bundle +0 -0
- data/lib/tomoto/tomoto.so +0 -0
- data/lib/tomoto/version.rb +1 -1
- data/vendor/tomotopy/README.kr.rst +12 -3
- data/vendor/tomotopy/README.rst +12 -3
- data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +47 -2
- data/vendor/tomotopy/src/Labeling/FoRelevance.h +21 -151
- data/vendor/tomotopy/src/Labeling/Labeler.h +5 -3
- data/vendor/tomotopy/src/Labeling/Phraser.hpp +518 -0
- data/vendor/tomotopy/src/TopicModel/CTModel.hpp +6 -3
- data/vendor/tomotopy/src/TopicModel/DT.h +1 -1
- data/vendor/tomotopy/src/TopicModel/DTModel.hpp +8 -23
- data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +9 -18
- data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +56 -58
- data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +4 -14
- data/vendor/tomotopy/src/TopicModel/LDA.h +69 -17
- data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +1 -1
- data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +108 -61
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +7 -8
- data/vendor/tomotopy/src/TopicModel/PAModel.hpp +26 -16
- data/vendor/tomotopy/src/TopicModel/PT.h +27 -0
- data/vendor/tomotopy/src/TopicModel/PTModel.cpp +10 -0
- data/vendor/tomotopy/src/TopicModel/PTModel.hpp +273 -0
- data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +16 -11
- data/vendor/tomotopy/src/Utils/MultiNormalDistribution.hpp +3 -2
- data/vendor/tomotopy/src/Utils/Trie.hpp +39 -8
- data/vendor/tomotopy/src/Utils/TruncMultiNormal.hpp +36 -38
- data/vendor/tomotopy/src/Utils/Utils.hpp +50 -45
- data/vendor/tomotopy/src/Utils/math.h +8 -4
- data/vendor/tomotopy/src/Utils/tvector.hpp +4 -0
- metadata +24 -60
data/vendor/tomotopy/src/Labeling/Labeler.h

@@ -10,6 +10,7 @@ namespace tomoto
         struct Candidate
         {
             float score = 0;
+            size_t cf = 0, df = 0;
             std::vector<Vid> w;
             std::string name;
 
@@ -18,17 +19,17 @@ namespace tomoto
             }
 
             Candidate(float _score, Vid w1)
-                :
+                : score{ _score }, w{ w1 }
             {
             }
 
             Candidate(float _score, Vid w1, Vid w2)
-                :
+                : score{ _score }, w{ w1, w2 }
             {
             }
 
             Candidate(float _score, const std::vector<Vid>& _w)
-                :
+                : score{ _score }, w{ _w }
             {
             }
         };
@@ -36,6 +37,7 @@ namespace tomoto
         class IExtractor
         {
         public:
+
             virtual std::vector<Candidate> extract(const ITopicModel* tm) const = 0;
             virtual ~IExtractor() {}
         };
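The Labeler.h change above gives label::Candidate two new bookkeeping fields, cf and df, which the new Phraser code below fills with a candidate's collection frequency and document frequency; it also moves score/word initialization into the constructors' member-init lists. A minimal standalone sketch of the resulting struct in use — not part of the package, and assuming Vid is a 32-bit unsigned vocabulary id, which matches tomotopy's convention:

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

using Vid = std::uint32_t; // assumption: tomotopy's vocabulary id type

struct Candidate
{
    float score = 0;
    size_t cf = 0, df = 0; // new in this release: collection/document frequency
    std::vector<Vid> w;
    std::string name;

    Candidate(float _score, Vid w1, Vid w2)
        : score{ _score }, w{ w1, w2 }
    {
    }
};

int main()
{
    Candidate c{ 0.8f, 3, 17 }; // a bigram candidate (word ids 3 and 17)
    c.cf = 42;                  // occurs 42 times in the corpus
    c.df = 10;                  // occurs in 10 distinct documents
    std::printf("%zu-gram: cf=%zu df=%zu score=%.2f\n", c.w.size(), c.cf, c.df, c.score);
    return 0;
}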
data/vendor/tomotopy/src/Labeling/Phraser.hpp (new file)

@@ -0,0 +1,518 @@
+#pragma once
+
+#include <vector>
+#include <unordered_map>
+#include "Labeler.h"
+#include "../Utils/Trie.hpp"
+
+namespace tomoto
+{
+    namespace phraser
+    {
+        template<typename _DocIter>
+        void countUnigrams(std::vector<size_t>& unigramCf, std::vector<size_t>& unigramDf,
+            _DocIter docBegin, _DocIter docEnd
+        )
+        {
+            for (auto docIt = docBegin; docIt != docEnd; ++docIt)
+            {
+                auto doc = *docIt;
+                if (!doc.size()) continue;
+                std::unordered_set<Vid> uniqs;
+                for (size_t i = 0; i < doc.size(); ++i)
+                {
+                    if (doc[i] == non_vocab_id) continue;
+                    unigramCf[doc[i]]++;
+                    uniqs.emplace(doc[i]);
+                }
+
+                for (auto w : uniqs) unigramDf[w]++;
+            }
+        }
+
+        template<typename _DocIter, typename _VvHash, typename _Freqs>
+        void countBigrams(std::unordered_map<std::pair<Vid, Vid>, size_t, _VvHash>& bigramCf,
+            std::unordered_map<std::pair<Vid, Vid>, size_t, _VvHash>& bigramDf,
+            _DocIter docBegin, _DocIter docEnd,
+            _Freqs&& vocabFreqs, _Freqs&& vocabDf,
+            size_t candMinCnt, size_t candMinDf
+        )
+        {
+            for (auto docIt = docBegin; docIt != docEnd; ++docIt)
+            {
+                std::unordered_set<std::pair<Vid, Vid>, _VvHash> uniqBigram;
+                auto doc = *docIt;
+                if (!doc.size()) continue;
+                Vid prevWord = doc[0];
+                for (size_t j = 1; j < doc.size(); ++j)
+                {
+                    Vid curWord = doc[j];
+                    if (curWord != non_vocab_id && vocabFreqs[curWord] >= candMinCnt && vocabDf[curWord] >= candMinDf)
+                    {
+                        if (prevWord != non_vocab_id && vocabFreqs[prevWord] >= candMinCnt && vocabDf[prevWord] >= candMinDf)
+                        {
+                            bigramCf[std::make_pair(prevWord, curWord)]++;
+                            uniqBigram.emplace(prevWord, curWord);
+                        }
+                    }
+                    prevWord = curWord;
+                }
+
+                for (auto& p : uniqBigram) bigramDf[p]++;
+            }
+        }
+
+        template<bool _reverse, typename _DocIter, typename _Freqs, typename _BigramPairs>
+        void countNgrams(std::vector<TrieEx<Vid, size_t>>& dest,
+            _DocIter docBegin, _DocIter docEnd,
+            _Freqs&& vocabFreqs, _Freqs&& vocabDf, _BigramPairs&& validPairs,
+            size_t candMinCnt, size_t candMinDf, size_t maxNgrams
+        )
+        {
+            if (dest.empty())
+            {
+                dest.resize(1);
+                dest.reserve(1024);
+            }
+            auto allocNode = [&]() { return dest.emplace_back(), & dest.back(); };
+
+            for (auto docIt = docBegin; docIt != docEnd; ++docIt)
+            {
+                auto doc = *docIt;
+                if (!doc.size()) continue;
+                if (dest.capacity() < dest.size() + doc.size() * maxNgrams)
+                {
+                    dest.reserve(std::max(dest.size() + doc.size() * maxNgrams, dest.capacity() * 2));
+                }
+
+                Vid prevWord = _reverse ? *doc.rbegin() : *doc.begin();
+                size_t labelLen = 0;
+                auto node = &dest[0];
+                if (prevWord != non_vocab_id && vocabFreqs[prevWord] >= candMinCnt && vocabDf[prevWord] >= candMinDf)
+                {
+                    node = dest[0].makeNext(prevWord, allocNode);
+                    node->val++;
+                    labelLen = 1;
+                }
+
+                const auto func = [&](Vid curWord)
+                {
+                    if (curWord != non_vocab_id && (vocabFreqs[curWord] < candMinCnt || vocabDf[curWord] < candMinDf))
+                    {
+                        node = &dest[0];
+                        labelLen = 0;
+                    }
+                    else
+                    {
+                        if (labelLen >= maxNgrams)
+                        {
+                            node = node->getFail();
+                            labelLen--;
+                        }
+
+                        if (validPairs.count(_reverse ? std::make_pair(curWord, prevWord) : std::make_pair(prevWord, curWord)))
+                        {
+                            auto nnode = node->makeNext(curWord, allocNode);
+                            node = nnode;
+                            do
+                            {
+                                nnode->val++;
+                            } while ((nnode = nnode->getFail()));
+                            labelLen++;
+                        }
+                        else
+                        {
+                            node = dest[0].makeNext(curWord, allocNode);
+                            node->val++;
+                            labelLen = 1;
+                        }
+                    }
+                    prevWord = curWord;
+                };
+
+                if (_reverse) std::for_each(doc.rbegin() + 1, doc.rend(), func);
+                else std::for_each(doc.begin() + 1, doc.end(), func);
+            }
+        }
+
+        inline void mergeNgramCounts(std::vector<TrieEx<Vid, size_t>>& dest, std::vector<TrieEx<Vid, size_t>>&& src)
+        {
+            if (src.empty()) return;
+            if (dest.empty()) dest.resize(1);
+
+            auto allocNode = [&]() { return dest.emplace_back(), & dest.back(); };
+
+            std::vector<Vid> rkeys;
+            src[0].traverse_with_keys([&](const TrieEx<Vid, size_t>* node, const std::vector<Vid>& rkeys)
+            {
+                if (dest.capacity() < dest.size() + rkeys.size() * rkeys.size())
+                {
+                    dest.reserve(std::max(dest.size() + rkeys.size() * rkeys.size(), dest.capacity() * 2));
+                }
+                dest[0].build(rkeys.begin(), rkeys.end(), 0, allocNode)->val += node->val;
+            }, rkeys);
+        }
+
+        inline float branchingEntropy(const TrieEx<Vid, size_t>* node, size_t minCnt)
+        {
+            float entropy = 0;
+            size_t rest = node->val;
+            for (auto n : *node)
+            {
+                float p = n.second->val / (float)node->val;
+                entropy -= p * std::log(p);
+                rest -= n.second->val;
+            }
+            if (rest > 0)
+            {
+                float p = rest / (float)node->val;
+                entropy -= p * std::log(std::min(std::max(minCnt, (size_t)1), (size_t)rest) / (float)node->val);
+            }
+            return entropy;
+        }
+
+        template<typename _LocalData, typename _ReduceFn>
+        _LocalData parallelReduce(std::vector<_LocalData>&& data, _ReduceFn&& fn, ThreadPool* pool = nullptr)
+        {
+            if (pool)
+            {
+                for (size_t s = data.size(); s > 1; s = (s + 1) / 2)
+                {
+                    std::vector<std::future<void>> futures;
+                    size_t h = (s + 1) / 2;
+                    for (size_t i = h; i < s; ++i)
+                    {
+                        futures.emplace_back(pool->enqueue([&, i, h](size_t)
+                        {
+                            _LocalData d = std::move(data[i]);
+                            fn(data[i - h], std::move(d));
+                        }));
+                    }
+                    for (auto& f : futures) f.get();
+                }
+            }
+            else
+            {
+                for (size_t i = 1; i < data.size(); ++i)
+                {
+                    _LocalData d = std::move(data[i]);
+                    fn(data[0], std::move(d));
+                }
+            }
+            return std::move(data[0]);
+        }
+
+        namespace detail
+        {
+            struct vvhash
+            {
+                size_t operator()(const std::pair<Vid, Vid>& k) const
+                {
+                    return std::hash<Vid>{}(k.first) ^ std::hash<Vid>{}(k.second);
+                }
+            };
+        }
+
+        template<typename _DocIter, typename _Freqs>
+        std::vector<label::Candidate> extractPMINgrams(_DocIter docBegin, _DocIter docEnd,
+            _Freqs&& vocabFreqs, _Freqs&& vocabDf,
+            size_t candMinCnt, size_t candMinDf, size_t minNgrams, size_t maxNgrams, size_t maxCandidates,
+            float minScore, bool normalized = false,
+            ThreadPool* pool = nullptr)
+        {
+            // counting unigrams & bigrams
+            std::unordered_map<std::pair<Vid, Vid>, size_t, detail::vvhash> bigramCnt, bigramDf;
+
+            if (pool && pool->getNumWorkers() > 1)
+            {
+                using LocalCfDf = std::pair<
+                    std::unordered_map<std::pair<Vid, Vid>, size_t, detail::vvhash>,
+                    std::unordered_map<std::pair<Vid, Vid>, size_t, detail::vvhash>
+                >;
+                std::vector<LocalCfDf> localdata(pool->getNumWorkers());
+                std::vector<std::future<void>> futures;
+                const size_t stride = pool->getNumWorkers() * 8;
+                auto docIt = docBegin;
+                for (size_t i = 0; i < stride && docIt != docEnd; ++i, ++docIt)
+                {
+                    futures.emplace_back(pool->enqueue([&, docIt, stride](size_t tid)
+                    {
+                        countBigrams(localdata[tid].first, localdata[tid].second, makeStrideIter(docIt, stride, docEnd), makeStrideIter(docEnd, stride, docEnd), vocabFreqs, vocabDf, candMinCnt, candMinDf);
+                    }));
+                }
+
+                for (auto& f : futures) f.get();
+
+                auto r = parallelReduce(std::move(localdata), [](LocalCfDf& dest, LocalCfDf&& src)
+                {
+                    for (auto& p : src.first) dest.first[p.first] += p.second;
+                    for (auto& p : src.second) dest.second[p.first] += p.second;
+                }, pool);
+
+                bigramCnt = std::move(r.first);
+                bigramDf = std::move(r.second);
+            }
+            else
+            {
+                countBigrams(bigramCnt, bigramDf, docBegin, docEnd, vocabFreqs, vocabDf, candMinCnt, candMinDf);
+            }
+
+            // counting ngrams
+            std::vector<TrieEx<Vid, size_t>> trieNodes;
+            if (maxNgrams > 2)
+            {
+                std::unordered_set<std::pair<Vid, Vid>, detail::vvhash> validPairs;
+                for (auto& p : bigramCnt)
+                {
+                    if (p.second >= candMinCnt && bigramDf[p.first] >= candMinDf) validPairs.emplace(p.first);
+                }
+
+                if (pool && pool->getNumWorkers() > 1)
+                {
+                    using LocalFwBw = std::vector<TrieEx<Vid, size_t>>;
+                    std::vector<LocalFwBw> localdata(pool->getNumWorkers());
+                    std::vector<std::future<void>> futures;
+                    const size_t stride = pool->getNumWorkers() * 8;
+                    auto docIt = docBegin;
+                    for (size_t i = 0; i < stride && docIt != docEnd; ++i, ++docIt)
+                    {
+                        futures.emplace_back(pool->enqueue([&, docIt, stride](size_t tid)
+                        {
+                            countNgrams<false>(localdata[tid],
+                                makeStrideIter(docIt, stride, docEnd),
+                                makeStrideIter(docEnd, stride, docEnd),
+                                vocabFreqs, vocabDf, validPairs, candMinCnt, candMinDf, maxNgrams
+                            );
+                        }));
+                    }
+
+                    for (auto& f : futures) f.get();
+
+                    auto r = parallelReduce(std::move(localdata), [&](LocalFwBw& dest, LocalFwBw&& src)
+                    {
+                        mergeNgramCounts(dest, std::move(src));
+                    }, pool);
+
+                    trieNodes = std::move(r);
+                }
+                else
+                {
+                    countNgrams<false>(trieNodes,
+                        docBegin, docEnd,
+                        vocabFreqs, vocabDf, validPairs, candMinCnt, candMinDf, maxNgrams
+                    );
+                }
+            }
+
+            float totN = std::accumulate(vocabFreqs.begin(), vocabFreqs.end(), (size_t)0);
+            const float logTotN = std::log(totN);
+
+            // calculating PMIs
+            std::vector<label::Candidate> candidates;
+            for (auto& p : bigramCnt)
+            {
+                auto& bigram = p.first;
+                if (p.second < candMinCnt) continue;
+                if (bigramDf[bigram] < candMinDf) continue;
+                auto pmi = std::log(p.second * totN
+                    / vocabFreqs[bigram.first] / vocabFreqs[bigram.second]);
+                if (normalized)
+                {
+                    pmi /= std::log(totN / p.second);
+                }
+                if (pmi < minScore) continue;
+                candidates.emplace_back(pmi, bigram.first, bigram.second);
+                candidates.back().cf = p.second;
+                candidates.back().df = bigramDf[bigram];
+            }
+
+            if (maxNgrams > 2)
+            {
+                std::vector<Vid> rkeys;
+                trieNodes[0].traverse_with_keys([&](const TrieEx<Vid, size_t>* node, const std::vector<Vid>& rkeys)
+                {
+                    if (rkeys.size() <= 2 || rkeys.size() < minNgrams || rkeys.size() > maxNgrams || node->val < candMinCnt) return;
+                    auto pmi = std::log((float)node->val) - logTotN;
+                    for (auto k : rkeys)
+                    {
+                        pmi += logTotN - std::log((float)vocabFreqs[k]);
+                    }
+                    if (normalized)
+                    {
+                        pmi /= (logTotN - std::log((float)node->val)) * (rkeys.size() - 1);
+                    }
+                    if (pmi < minScore) return;
+                    candidates.emplace_back(pmi, rkeys);
+                    candidates.back().cf = node->val;
+                }, rkeys);
+            }
+
+            std::sort(candidates.begin(), candidates.end(), [](const label::Candidate& a, const label::Candidate& b)
+            {
+                return a.score > b.score;
+            });
+            if (candidates.size() > maxCandidates) candidates.erase(candidates.begin() + maxCandidates, candidates.end());
+            return candidates;
+        }
+
+        template<typename _DocIter, typename _Freqs>
+        std::vector<label::Candidate> extractPMIBENgrams(_DocIter docBegin, _DocIter docEnd,
+            _Freqs&& vocabFreqs, _Freqs&& vocabDf,
+            size_t candMinCnt, size_t candMinDf, size_t minNgrams, size_t maxNgrams, size_t maxCandidates,
+            float minNPMI = 0, float minNBE = 0,
+            ThreadPool* pool = nullptr)
+        {
+            // counting unigrams & bigrams
+            std::unordered_map<std::pair<Vid, Vid>, size_t, detail::vvhash> bigramCnt, bigramDf;
+
+            if (pool && pool->getNumWorkers() > 1)
+            {
+                using LocalCfDf = std::pair<
+                    std::unordered_map<std::pair<Vid, Vid>, size_t, detail::vvhash>,
+                    std::unordered_map<std::pair<Vid, Vid>, size_t, detail::vvhash>
+                >;
+                std::vector<LocalCfDf> localdata(pool->getNumWorkers());
+                std::vector<std::future<void>> futures;
+                const size_t stride = pool->getNumWorkers() * 8;
+                auto docIt = docBegin;
+                for (size_t i = 0; i < stride && docIt != docEnd; ++i, ++docIt)
+                {
+                    futures.emplace_back(pool->enqueue([&, docIt, stride](size_t tid)
+                    {
+                        countBigrams(localdata[tid].first, localdata[tid].second, makeStrideIter(docIt, stride, docEnd), makeStrideIter(docEnd, stride, docEnd), vocabFreqs, vocabDf, candMinCnt, candMinDf);
+                    }));
+                }
+
+                for (auto& f : futures) f.get();
+
+                auto r = parallelReduce(std::move(localdata), [](LocalCfDf& dest, LocalCfDf&& src)
+                {
+                    for (auto& p : src.first) dest.first[p.first] += p.second;
+                    for (auto& p : src.second) dest.second[p.first] += p.second;
+                }, pool);
+
+                bigramCnt = std::move(r.first);
+                bigramDf = std::move(r.second);
+            }
+            else
+            {
+                countBigrams(bigramCnt, bigramDf, docBegin, docEnd, vocabFreqs, vocabDf, candMinCnt, candMinDf);
+            }
+
+            // counting ngrams
+            std::vector<TrieEx<Vid, size_t>> trieNodes, trieNodesBw;
+            if (maxNgrams > 2)
+            {
+                std::unordered_set<std::pair<Vid, Vid>, detail::vvhash> validPairs;
+                for (auto& p : bigramCnt)
+                {
+                    if (p.second >= candMinCnt && bigramDf[p.first] >= candMinDf) validPairs.emplace(p.first);
+                }
+
+                if (pool && pool->getNumWorkers() > 1)
+                {
+                    using LocalFwBw = std::pair<
+                        std::vector<TrieEx<Vid, size_t>>,
+                        std::vector<TrieEx<Vid, size_t>>
+                    >;
+                    std::vector<LocalFwBw> localdata(pool->getNumWorkers());
+                    std::vector<std::future<void>> futures;
+                    const size_t stride = pool->getNumWorkers() * 8;
+                    auto docIt = docBegin;
+                    for (size_t i = 0; i < stride && docIt != docEnd; ++i, ++docIt)
+                    {
+                        futures.emplace_back(pool->enqueue([&, docIt, stride](size_t tid)
+                        {
+                            countNgrams<false>(localdata[tid].first,
+                                makeStrideIter(docIt, stride, docEnd),
+                                makeStrideIter(docEnd, stride, docEnd),
+                                vocabFreqs, vocabDf, validPairs, candMinCnt, candMinDf, maxNgrams + 1
+                            );
+                            countNgrams<true>(localdata[tid].second,
+                                makeStrideIter(docIt, stride, docEnd),
+                                makeStrideIter(docEnd, stride, docEnd),
+                                vocabFreqs, vocabDf, validPairs, candMinCnt, candMinDf, maxNgrams + 1
+                            );
+                        }));
+                    }
+
+                    for (auto& f : futures) f.get();
+
+                    auto r = parallelReduce(std::move(localdata), [&](LocalFwBw& dest, LocalFwBw&& src)
+                    {
+                        mergeNgramCounts(dest.first, std::move(src.first));
+                        mergeNgramCounts(dest.second, std::move(src.second));
+                    }, pool);
+
+                    trieNodes = std::move(r.first);
+                    trieNodesBw = std::move(r.second);
+                }
+                else
+                {
+                    countNgrams<false>(trieNodes,
+                        docBegin, docEnd,
+                        vocabFreqs, vocabDf, validPairs, candMinCnt, candMinDf, maxNgrams + 1
+                    );
+                    countNgrams<true>(trieNodesBw,
+                        docBegin, docEnd,
+                        vocabFreqs, vocabDf, validPairs, candMinCnt, candMinDf, maxNgrams + 1
+                    );
+                }
+            }
+
+            float totN = std::accumulate(vocabFreqs.begin(), vocabFreqs.end(), (size_t)0);
+            const float logTotN = std::log(totN);
+
+            // calculating PMIs
+            std::vector<label::Candidate> candidates;
+            for (auto& p : bigramCnt)
+            {
+                auto& bigram = p.first;
+                if (p.second < candMinCnt) continue;
+                if (bigramDf[bigram] < candMinDf) continue;
+                float npmi = std::log(p.second * totN
+                    / vocabFreqs[bigram.first] / vocabFreqs[bigram.second]);
+                npmi /= std::log(totN / p.second);
+                if (npmi < minNPMI) continue;
+
+                float rbe = branchingEntropy(trieNodes[0].getNext(bigram.first)->getNext(bigram.second), candMinCnt);
+                float lbe = branchingEntropy(trieNodesBw[0].getNext(bigram.second)->getNext(bigram.first), candMinCnt);
+                float nbe = std::sqrt(rbe * lbe) / std::log(p.second);
+                if (nbe < minNBE) continue;
+                candidates.emplace_back(npmi * nbe, bigram.first, bigram.second);
+                candidates.back().cf = p.second;
+                candidates.back().df = bigramDf[bigram];
+            }
+
+            if (maxNgrams > 2)
+            {
+                std::vector<Vid> rkeys;
+                trieNodes[0].traverse_with_keys([&](const TrieEx<Vid, size_t>* node, const std::vector<Vid>& rkeys)
+                {
+                    if (rkeys.size() <= 2 || rkeys.size() < minNgrams || rkeys.size() > maxNgrams || node->val < candMinCnt) return;
+                    auto npmi = std::log((float)node->val) - logTotN;
+                    for (auto k : rkeys)
+                    {
+                        npmi += logTotN - std::log((float)vocabFreqs[k]);
+                    }
+                    npmi /= (logTotN - std::log((float)node->val)) * (rkeys.size() - 1);
+                    if (npmi < minNPMI) return;
+
+                    float rbe = branchingEntropy(node, candMinCnt);
+                    float lbe = branchingEntropy(trieNodesBw[0].findNode(rkeys.rbegin(), rkeys.rend()), candMinCnt);
+                    float nbe = std::sqrt(rbe * lbe) / std::log(node->val);
+                    if (nbe < minNBE) return;
+                    candidates.emplace_back(npmi * nbe, rkeys);
+                    candidates.back().cf = node->val;
+                }, rkeys);
+            }
+
+            std::sort(candidates.begin(), candidates.end(), [](const label::Candidate& a, const label::Candidate& b)
+            {
+                return a.score > b.score;
+            });
+            if (candidates.size() > maxCandidates) candidates.erase(candidates.begin() + maxCandidates, candidates.end());
+            return candidates;
+        }
+    }
+}
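As reconstructed above, extractPMINgrams scores a bigram (w1, w2) as pmi = log(cf(w1,w2) · N / (cf(w1) · cf(w2))), where N is the total token count accumulated from vocabFreqs, and with normalized = true divides by log(N / cf(w1,w2)) to obtain NPMI. A toy, self-contained recomputation of the bigram case — all counts invented, not part of the package:

// Toy recomputation of the bigram scoring in extractPMINgrams (counts invented).
#include <cmath>
#include <cstdio>

int main()
{
    const float totN = 1000.0f; // total token count, as accumulated from vocabFreqs
    const float cf1  = 40.0f;   // vocabFreqs[w1]
    const float cf2  = 25.0f;   // vocabFreqs[w2]
    const float cf12 = 20.0f;   // bigramCnt[{w1, w2}]

    // pmi = log(cf12 * N / (cf1 * cf2)), as in the diff above
    float pmi = std::log(cf12 * totN / cf1 / cf2);

    // the normalized variant divides by log(N / cf12), mapping the score into (-1, 1]
    float npmi = pmi / std::log(totN / cf12);

    std::printf("pmi = %.3f, npmi = %.3f\n", pmi, npmi); // pmi ≈ 2.996, npmi ≈ 0.766
    return 0;
}

For n-grams longer than 2, the same quantity generalizes to log cf(g) − log N + Σ_k (log N − log cf(w_k)), normalized by (log N − log cf(g)) · (n − 1), which is what the traverse_with_keys callbacks compute. extractPMIBENgrams then multiplies NPMI by a normalized branching entropy, nbe = sqrt(rbe · lbe) / log cf(g), where rbe and lbe are the entropies of the candidate's forward and backward continuations in the two tries, and drops candidates below minNPMI or minNBE before sorting by score.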