tomoto 0.1.3 → 0.1.4
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/LICENSE.txt +1 -1
- data/README.md +7 -0
- data/ext/tomoto/ct.cpp +54 -0
- data/ext/tomoto/dmr.cpp +62 -0
- data/ext/tomoto/dt.cpp +82 -0
- data/ext/tomoto/ext.cpp +27 -773
- data/ext/tomoto/gdmr.cpp +34 -0
- data/ext/tomoto/hdp.cpp +42 -0
- data/ext/tomoto/hlda.cpp +66 -0
- data/ext/tomoto/hpa.cpp +27 -0
- data/ext/tomoto/lda.cpp +250 -0
- data/ext/tomoto/llda.cpp +29 -0
- data/ext/tomoto/mglda.cpp +71 -0
- data/ext/tomoto/pa.cpp +27 -0
- data/ext/tomoto/plda.cpp +29 -0
- data/ext/tomoto/slda.cpp +40 -0
- data/ext/tomoto/utils.h +84 -0
- data/lib/tomoto/tomoto.bundle +0 -0
- data/lib/tomoto/tomoto.so +0 -0
- data/lib/tomoto/version.rb +1 -1
- data/vendor/tomotopy/README.kr.rst +12 -3
- data/vendor/tomotopy/README.rst +12 -3
- data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +47 -2
- data/vendor/tomotopy/src/Labeling/FoRelevance.h +21 -151
- data/vendor/tomotopy/src/Labeling/Labeler.h +5 -3
- data/vendor/tomotopy/src/Labeling/Phraser.hpp +518 -0
- data/vendor/tomotopy/src/TopicModel/CTModel.hpp +6 -3
- data/vendor/tomotopy/src/TopicModel/DT.h +1 -1
- data/vendor/tomotopy/src/TopicModel/DTModel.hpp +8 -23
- data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +9 -18
- data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +56 -58
- data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +4 -14
- data/vendor/tomotopy/src/TopicModel/LDA.h +69 -17
- data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +1 -1
- data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +108 -61
- data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +7 -8
- data/vendor/tomotopy/src/TopicModel/PAModel.hpp +26 -16
- data/vendor/tomotopy/src/TopicModel/PT.h +27 -0
- data/vendor/tomotopy/src/TopicModel/PTModel.cpp +10 -0
- data/vendor/tomotopy/src/TopicModel/PTModel.hpp +273 -0
- data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +16 -11
- data/vendor/tomotopy/src/Utils/MultiNormalDistribution.hpp +3 -2
- data/vendor/tomotopy/src/Utils/Trie.hpp +39 -8
- data/vendor/tomotopy/src/Utils/TruncMultiNormal.hpp +36 -38
- data/vendor/tomotopy/src/Utils/Utils.hpp +50 -45
- data/vendor/tomotopy/src/Utils/math.h +8 -4
- data/vendor/tomotopy/src/Utils/tvector.hpp +4 -0
- metadata +24 -60
@@ -10,6 +10,7 @@ namespace tomoto
|
|
10
10
|
struct Candidate
|
11
11
|
{
|
12
12
|
float score = 0;
|
13
|
+
size_t cf = 0, df = 0;
|
13
14
|
std::vector<Vid> w;
|
14
15
|
std::string name;
|
15
16
|
|
@@ -18,17 +19,17 @@ namespace tomoto
|
|
18
19
|
}
|
19
20
|
|
20
21
|
Candidate(float _score, Vid w1)
|
21
|
-
:
|
22
|
+
: score{ _score }, w{ w1 }
|
22
23
|
{
|
23
24
|
}
|
24
25
|
|
25
26
|
Candidate(float _score, Vid w1, Vid w2)
|
26
|
-
:
|
27
|
+
: score{ _score }, w{ w1, w2 }
|
27
28
|
{
|
28
29
|
}
|
29
30
|
|
30
31
|
Candidate(float _score, const std::vector<Vid>& _w)
|
31
|
-
:
|
32
|
+
: score{ _score }, w{ _w }
|
32
33
|
{
|
33
34
|
}
|
34
35
|
};
|
@@ -36,6 +37,7 @@ namespace tomoto
|
|
36
37
|
// Interface for algorithms that extract topic-label candidates
// (e.g. PMI-scored collocations) from a trained topic model.
class IExtractor
{
public:

    // Returns the list of label candidates found in the model `tm`.
    virtual std::vector<Candidate> extract(const ITopicModel* tm) const = 0;
    virtual ~IExtractor() {}
};
|
@@ -0,0 +1,518 @@
|
|
1
|
+
#pragma once
|
2
|
+
|
3
|
+
#include <vector>
|
4
|
+
#include <unordered_map>
|
5
|
+
#include "Labeler.h"
|
6
|
+
#include "../Utils/Trie.hpp"
|
7
|
+
|
8
|
+
namespace tomoto
|
9
|
+
{
|
10
|
+
namespace phraser
|
11
|
+
{
|
12
|
+
template<typename _DocIter>
|
13
|
+
void countUnigrams(std::vector<size_t>& unigramCf, std::vector<size_t>& unigramDf,
|
14
|
+
_DocIter docBegin, _DocIter docEnd
|
15
|
+
)
|
16
|
+
{
|
17
|
+
for (auto docIt = docBegin; docIt != docEnd; ++docIt)
|
18
|
+
{
|
19
|
+
auto doc = *docIt;
|
20
|
+
if (!doc.size()) continue;
|
21
|
+
std::unordered_set<Vid> uniqs;
|
22
|
+
for (size_t i = 0; i < doc.size(); ++i)
|
23
|
+
{
|
24
|
+
if (doc[i] == non_vocab_id) continue;
|
25
|
+
unigramCf[doc[i]]++;
|
26
|
+
uniqs.emplace(doc[i]);
|
27
|
+
}
|
28
|
+
|
29
|
+
for (auto w : uniqs) unigramDf[w]++;
|
30
|
+
}
|
31
|
+
}
|
32
|
+
|
33
|
+
template<typename _DocIter, typename _VvHash, typename _Freqs>
|
34
|
+
void countBigrams(std::unordered_map<std::pair<Vid, Vid>, size_t, _VvHash>& bigramCf,
|
35
|
+
std::unordered_map<std::pair<Vid, Vid>, size_t, _VvHash>& bigramDf,
|
36
|
+
_DocIter docBegin, _DocIter docEnd,
|
37
|
+
_Freqs&& vocabFreqs, _Freqs&& vocabDf,
|
38
|
+
size_t candMinCnt, size_t candMinDf
|
39
|
+
)
|
40
|
+
{
|
41
|
+
for (auto docIt = docBegin; docIt != docEnd; ++docIt)
|
42
|
+
{
|
43
|
+
std::unordered_set<std::pair<Vid, Vid>, _VvHash> uniqBigram;
|
44
|
+
auto doc = *docIt;
|
45
|
+
if (!doc.size()) continue;
|
46
|
+
Vid prevWord = doc[0];
|
47
|
+
for (size_t j = 1; j < doc.size(); ++j)
|
48
|
+
{
|
49
|
+
Vid curWord = doc[j];
|
50
|
+
if (curWord != non_vocab_id && vocabFreqs[curWord] >= candMinCnt && vocabDf[curWord] >= candMinDf)
|
51
|
+
{
|
52
|
+
if (prevWord != non_vocab_id && vocabFreqs[prevWord] >= candMinCnt && vocabDf[prevWord] >= candMinDf)
|
53
|
+
{
|
54
|
+
bigramCf[std::make_pair(prevWord, curWord)]++;
|
55
|
+
uniqBigram.emplace(prevWord, curWord);
|
56
|
+
}
|
57
|
+
}
|
58
|
+
prevWord = curWord;
|
59
|
+
}
|
60
|
+
|
61
|
+
for (auto& p : uniqBigram) bigramDf[p]++;
|
62
|
+
}
|
63
|
+
}
|
64
|
+
|
65
|
+
// Builds a frequency-annotated trie of n-grams (up to maxNgrams words) over
// the corpus [docBegin, docEnd). dest[0] is the root node; additional nodes
// are appended to `dest` via `allocNode`. When `_reverse` is true, each
// document is scanned back-to-front, producing a trie of reversed n-grams
// (used for left-side branching entropy). A word failing the cf/df
// thresholds terminates the n-gram currently being tracked; only pairs in
// `validPairs` may extend an n-gram.
template<bool _reverse, typename _DocIter, typename _Freqs, typename _BigramPairs>
void countNgrams(std::vector<TrieEx<Vid, size_t>>& dest,
    _DocIter docBegin, _DocIter docEnd,
    _Freqs&& vocabFreqs, _Freqs&& vocabDf, _BigramPairs&& validPairs,
    size_t candMinCnt, size_t candMinDf, size_t maxNgrams
)
{
    if (dest.empty())
    {
        dest.resize(1); // node 0 acts as the trie root
        dest.reserve(1024);
    }
    // Nodes are appended to `dest`; raw pointers into `dest` remain valid
    // only because capacity is reserved ahead of each document (below), so
    // emplace_back never reallocates while `node`/`nnode` are live.
    auto allocNode = [&]() { return dest.emplace_back(), & dest.back(); };

    for (auto docIt = docBegin; docIt != docEnd; ++docIt)
    {
        auto doc = *docIt;
        if (!doc.size()) continue;
        // Worst case: every token of this document opens maxNgrams new nodes.
        if (dest.capacity() < dest.size() + doc.size() * maxNgrams)
        {
            dest.reserve(std::max(dest.size() + doc.size() * maxNgrams, dest.capacity() * 2));
        }

        Vid prevWord = _reverse ? *doc.rbegin() : *doc.begin();
        size_t labelLen = 0; // length of the n-gram tracked by `node`
        auto node = &dest[0];
        if (prevWord != non_vocab_id && vocabFreqs[prevWord] >= candMinCnt && vocabDf[prevWord] >= candMinDf)
        {
            node = dest[0].makeNext(prevWord, allocNode);
            node->val++;
            labelLen = 1;
        }

        const auto func = [&](Vid curWord)
        {
            // An in-vocabulary word below the thresholds breaks the n-gram.
            // NOTE(review): a curWord equal to non_vocab_id falls through to
            // the else branch and is inserted into the trie — confirm this
            // is intended (countUnigrams/countBigrams skip such tokens).
            if (curWord != non_vocab_id && (vocabFreqs[curWord] < candMinCnt || vocabDf[curWord] < candMinDf))
            {
                node = &dest[0];
                labelLen = 0;
            }
            else
            {
                if (labelLen >= maxNgrams)
                {
                    // Drop the oldest word by following the fail link
                    // (presumably the longest proper suffix of the current
                    // n-gram — TODO confirm TrieEx::getFail semantics).
                    node = node->getFail();
                    labelLen--;
                }

                if (validPairs.count(_reverse ? std::make_pair(curWord, prevWord) : std::make_pair(prevWord, curWord)))
                {
                    // Extend the n-gram; bump the count of the new node and
                    // of every node reachable through its fail chain.
                    auto nnode = node->makeNext(curWord, allocNode);
                    node = nnode;
                    do
                    {
                        nnode->val++;
                    } while ((nnode = nnode->getFail()));
                    labelLen++;
                }
                else
                {
                    // Pair not eligible: restart a fresh 1-gram at curWord.
                    node = dest[0].makeNext(curWord, allocNode);
                    node->val++;
                    labelLen = 1;
                }
            }
            prevWord = curWord;
        };

        if (_reverse) std::for_each(doc.rbegin() + 1, doc.rend(), func);
        else std::for_each(doc.begin() + 1, doc.end(), func);
    }
}
|
137
|
+
|
138
|
+
// Merges the n-gram counts stored in the trie `src` into `dest`, creating
// any missing nodes in `dest` on demand.
inline void mergeNgramCounts(std::vector<TrieEx<Vid, size_t>>& dest, std::vector<TrieEx<Vid, size_t>>&& src)
{
    if (src.empty()) return;
    if (dest.empty()) dest.resize(1); // ensure a root node exists

    auto allocNode = [&]() { return dest.emplace_back(), & dest.back(); };

    std::vector<Vid> rkeys;
    // Walk every node of `src` together with its key path, and add its count
    // onto the corresponding node of `dest`.
    src[0].traverse_with_keys([&](const TrieEx<Vid, size_t>* node, const std::vector<Vid>& rkeys)
    {
        // Reserve ahead so allocNode's emplace_back cannot reallocate while
        // `build` holds pointers into `dest`.
        if (dest.capacity() < dest.size() + rkeys.size() * rkeys.size())
        {
            dest.reserve(std::max(dest.size() + rkeys.size() * rkeys.size(), dest.capacity() * 2));
        }
        dest[0].build(rkeys.begin(), rkeys.end(), 0, allocNode)->val += node->val;
    }, rkeys);
}
|
155
|
+
|
156
|
+
inline float branchingEntropy(const TrieEx<Vid, size_t>* node, size_t minCnt)
|
157
|
+
{
|
158
|
+
float entropy = 0;
|
159
|
+
size_t rest = node->val;
|
160
|
+
for (auto n : *node)
|
161
|
+
{
|
162
|
+
float p = n.second->val / (float)node->val;
|
163
|
+
entropy -= p * std::log(p);
|
164
|
+
rest -= n.second->val;
|
165
|
+
}
|
166
|
+
if (rest > 0)
|
167
|
+
{
|
168
|
+
float p = rest / (float)node->val;
|
169
|
+
entropy -= p * std::log(std::min(std::max(minCnt, (size_t)1), (size_t)rest) / (float)node->val);
|
170
|
+
}
|
171
|
+
return entropy;
|
172
|
+
}
|
173
|
+
|
174
|
+
// Reduces `data` into a single value using `fn(dest, src)` and returns it.
// With a pool, performs a tree reduction: each round merges the upper half
// of the remaining slots into the lower half in parallel. The pairs
// (i - h, i) within one round are disjoint, so concurrent tasks touch
// distinct elements; rounds are separated by waiting on all futures.
template<typename _LocalData, typename _ReduceFn>
_LocalData parallelReduce(std::vector<_LocalData>&& data, _ReduceFn&& fn, ThreadPool* pool = nullptr)
{
    if (pool)
    {
        // s = number of live slots; halves (rounding up) every round.
        for (size_t s = data.size(); s > 1; s = (s + 1) / 2)
        {
            std::vector<std::future<void>> futures;
            size_t h = (s + 1) / 2;
            for (size_t i = h; i < s; ++i)
            {
                futures.emplace_back(pool->enqueue([&, i, h](size_t)
                {
                    _LocalData d = std::move(data[i]);
                    fn(data[i - h], std::move(d));
                }));
            }
            // Finish this round before the next halving reuses the slots.
            for (auto& f : futures) f.get();
        }
    }
    else
    {
        // Serial fallback: fold everything into data[0], left to right.
        for (size_t i = 1; i < data.size(); ++i)
        {
            _LocalData d = std::move(data[i]);
            fn(data[0], std::move(d));
        }
    }
    // data[0] is an element (not a local), so std::move is required here.
    return std::move(data[0]);
}
|
204
|
+
|
205
|
+
namespace detail
|
206
|
+
{
|
207
|
+
struct vvhash
|
208
|
+
{
|
209
|
+
size_t operator()(const std::pair<Vid, Vid>& k) const
|
210
|
+
{
|
211
|
+
return std::hash<Vid>{}(k.first) ^ std::hash<Vid>{}(k.second);
|
212
|
+
}
|
213
|
+
};
|
214
|
+
}
|
215
|
+
|
216
|
+
// Extracts bigram and n-gram (3..maxNgrams) collocation candidates scored by
// (optionally normalized) pointwise mutual information. Results are sorted
// by descending score and truncated to maxCandidates.
// - vocabFreqs / vocabDf: per-word collection / document frequencies
// - candMinCnt / candMinDf: minimum cf / df for words and n-grams to qualify
// - minScore: candidates scoring below this threshold are dropped
// - pool: optional thread pool; when present, counting is sharded by
//   strided iterators and merged with parallelReduce.
// NOTE(review): bigram candidates are emitted regardless of minNgrams —
// confirm that is intended.
template<typename _DocIter, typename _Freqs>
std::vector<label::Candidate> extractPMINgrams(_DocIter docBegin, _DocIter docEnd,
    _Freqs&& vocabFreqs, _Freqs&& vocabDf,
    size_t candMinCnt, size_t candMinDf, size_t minNgrams, size_t maxNgrams, size_t maxCandidates,
    float minScore, bool normalized = false,
    ThreadPool* pool = nullptr)
{
    // counting unigrams & bigrams
    std::unordered_map<std::pair<Vid, Vid>, size_t, detail::vvhash> bigramCnt, bigramDf;

    if (pool && pool->getNumWorkers() > 1)
    {
        // Per-worker (cf map, df map) pairs, merged after counting.
        using LocalCfDf = std::pair<
            std::unordered_map<std::pair<Vid, Vid>, size_t, detail::vvhash>,
            std::unordered_map<std::pair<Vid, Vid>, size_t, detail::vvhash>
        >;
        std::vector<LocalCfDf> localdata(pool->getNumWorkers());
        std::vector<std::future<void>> futures;
        const size_t stride = pool->getNumWorkers() * 8;
        auto docIt = docBegin;
        // Each task counts every stride-th document starting at its offset.
        for (size_t i = 0; i < stride && docIt != docEnd; ++i, ++docIt)
        {
            futures.emplace_back(pool->enqueue([&, docIt, stride](size_t tid)
            {
                countBigrams(localdata[tid].first, localdata[tid].second, makeStrideIter(docIt, stride, docEnd), makeStrideIter(docEnd, stride, docEnd), vocabFreqs, vocabDf, candMinCnt, candMinDf);
            }));
        }

        for (auto& f : futures) f.get();

        auto r = parallelReduce(std::move(localdata), [](LocalCfDf& dest, LocalCfDf&& src)
        {
            for (auto& p : src.first) dest.first[p.first] += p.second;
            for (auto& p : src.second) dest.second[p.first] += p.second;
        }, pool);

        bigramCnt = std::move(r.first);
        bigramDf = std::move(r.second);
    }
    else
    {
        countBigrams(bigramCnt, bigramDf, docBegin, docEnd, vocabFreqs, vocabDf, candMinCnt, candMinDf);
    }

    // counting ngrams
    std::vector<TrieEx<Vid, size_t>> trieNodes;
    if (maxNgrams > 2)
    {
        // Only bigrams that themselves pass the thresholds may be extended
        // into longer n-grams.
        std::unordered_set<std::pair<Vid, Vid>, detail::vvhash> validPairs;
        for (auto& p : bigramCnt)
        {
            if (p.second >= candMinCnt && bigramDf[p.first] >= candMinDf) validPairs.emplace(p.first);
        }

        if (pool && pool->getNumWorkers() > 1)
        {
            using LocalFwBw = std::vector<TrieEx<Vid, size_t>>;
            std::vector<LocalFwBw> localdata(pool->getNumWorkers());
            std::vector<std::future<void>> futures;
            const size_t stride = pool->getNumWorkers() * 8;
            auto docIt = docBegin;
            for (size_t i = 0; i < stride && docIt != docEnd; ++i, ++docIt)
            {
                futures.emplace_back(pool->enqueue([&, docIt, stride](size_t tid)
                {
                    countNgrams<false>(localdata[tid],
                        makeStrideIter(docIt, stride, docEnd),
                        makeStrideIter(docEnd, stride, docEnd),
                        vocabFreqs, vocabDf, validPairs, candMinCnt, candMinDf, maxNgrams
                    );
                }));
            }

            for (auto& f : futures) f.get();

            auto r = parallelReduce(std::move(localdata), [&](LocalFwBw& dest, LocalFwBw&& src)
            {
                mergeNgramCounts(dest, std::move(src));
            }, pool);

            trieNodes = std::move(r);
        }
        else
        {
            countNgrams<false>(trieNodes,
                docBegin, docEnd,
                vocabFreqs, vocabDf, validPairs, candMinCnt, candMinDf, maxNgrams
            );
        }
    }

    // Total token count; accumulated as size_t (integer init), then widened.
    float totN = std::accumulate(vocabFreqs.begin(), vocabFreqs.end(), (size_t)0);
    const float logTotN = std::log(totN);

    // calculating PMIs
    std::vector<label::Candidate> candidates;
    for (auto& p : bigramCnt)
    {
        auto& bigram = p.first;
        if (p.second < candMinCnt) continue;
        if (bigramDf[bigram] < candMinDf) continue;
        // PMI = log( cf(a,b) * N / (cf(a) * cf(b)) )
        auto pmi = std::log(p.second * totN
            / vocabFreqs[bigram.first] / vocabFreqs[bigram.second]);
        if (normalized)
        {
            // Normalized PMI: divide by -log p(a,b).
            pmi /= std::log(totN / p.second);
        }
        if (pmi < minScore) continue;
        candidates.emplace_back(pmi, bigram.first, bigram.second);
        candidates.back().cf = p.second;
        candidates.back().df = bigramDf[bigram];
    }

    if (maxNgrams > 2)
    {
        std::vector<Vid> rkeys;
        trieNodes[0].traverse_with_keys([&](const TrieEx<Vid, size_t>* node, const std::vector<Vid>& rkeys)
        {
            // Only n-grams with 3 <= n, minNgrams <= n <= maxNgrams and
            // sufficient count become candidates.
            if (rkeys.size() <= 2 || rkeys.size() < minNgrams || rkeys.size() > maxNgrams || node->val < candMinCnt) return;
            // Generalized PMI: log p(w1..wn) - sum log p(wi).
            auto pmi = std::log((float)node->val) - logTotN;
            for (auto k : rkeys)
            {
                pmi += logTotN - std::log((float)vocabFreqs[k]);
            }
            if (normalized)
            {
                pmi /= (logTotN - std::log((float)node->val)) * (rkeys.size() - 1);
            }
            if (pmi < minScore) return;
            candidates.emplace_back(pmi, rkeys);
            // The trie stores no document frequency, so df stays at its
            // default 0 for n-grams longer than 2.
            candidates.back().cf = node->val;
        }, rkeys);
    }

    std::sort(candidates.begin(), candidates.end(), [](const label::Candidate& a, const label::Candidate& b)
    {
        return a.score > b.score;
    });
    if (candidates.size() > maxCandidates) candidates.erase(candidates.begin() + maxCandidates, candidates.end());
    return candidates;
}
|
357
|
+
|
358
|
+
// Extracts collocation candidates scored by normalized PMI multiplied by
// normalized branching entropy (geometric mean of the right/forward and
// left/backward branching entropies). Results are sorted by descending
// score and truncated to maxCandidates.
// - minNPMI / minNBE: minimum normalized PMI / normalized branching entropy
// - pool: optional thread pool for sharded counting.
template<typename _DocIter, typename _Freqs>
std::vector<label::Candidate> extractPMIBENgrams(_DocIter docBegin, _DocIter docEnd,
    _Freqs&& vocabFreqs, _Freqs&& vocabDf,
    size_t candMinCnt, size_t candMinDf, size_t minNgrams, size_t maxNgrams, size_t maxCandidates,
    float minNPMI = 0, float minNBE = 0,
    ThreadPool* pool = nullptr)
{
    // counting unigrams & bigrams
    std::unordered_map<std::pair<Vid, Vid>, size_t, detail::vvhash> bigramCnt, bigramDf;

    if (pool && pool->getNumWorkers() > 1)
    {
        // Per-worker (cf map, df map) pairs, merged after counting.
        using LocalCfDf = std::pair<
            std::unordered_map<std::pair<Vid, Vid>, size_t, detail::vvhash>,
            std::unordered_map<std::pair<Vid, Vid>, size_t, detail::vvhash>
        >;
        std::vector<LocalCfDf> localdata(pool->getNumWorkers());
        std::vector<std::future<void>> futures;
        const size_t stride = pool->getNumWorkers() * 8;
        auto docIt = docBegin;
        // Each task counts every stride-th document starting at its offset.
        for (size_t i = 0; i < stride && docIt != docEnd; ++i, ++docIt)
        {
            futures.emplace_back(pool->enqueue([&, docIt, stride](size_t tid)
            {
                countBigrams(localdata[tid].first, localdata[tid].second, makeStrideIter(docIt, stride, docEnd), makeStrideIter(docEnd, stride, docEnd), vocabFreqs, vocabDf, candMinCnt, candMinDf);
            }));
        }

        for (auto& f : futures) f.get();

        auto r = parallelReduce(std::move(localdata), [](LocalCfDf& dest, LocalCfDf&& src)
        {
            for (auto& p : src.first) dest.first[p.first] += p.second;
            for (auto& p : src.second) dest.second[p.first] += p.second;
        }, pool);

        bigramCnt = std::move(r.first);
        bigramDf = std::move(r.second);
    }
    else
    {
        countBigrams(bigramCnt, bigramDf, docBegin, docEnd, vocabFreqs, vocabDf, candMinCnt, candMinDf);
    }

    // counting ngrams — forward trie for right BE, backward trie for left BE
    std::vector<TrieEx<Vid, size_t>> trieNodes, trieNodesBw;
    if (maxNgrams > 2)
    {
        std::unordered_set<std::pair<Vid, Vid>, detail::vvhash> validPairs;
        for (auto& p : bigramCnt)
        {
            if (p.second >= candMinCnt && bigramDf[p.first] >= candMinDf) validPairs.emplace(p.first);
        }

        if (pool && pool->getNumWorkers() > 1)
        {
            using LocalFwBw = std::pair<
                std::vector<TrieEx<Vid, size_t>>,
                std::vector<TrieEx<Vid, size_t>>
            >;
            std::vector<LocalFwBw> localdata(pool->getNumWorkers());
            std::vector<std::future<void>> futures;
            const size_t stride = pool->getNumWorkers() * 8;
            auto docIt = docBegin;
            for (size_t i = 0; i < stride && docIt != docEnd; ++i, ++docIt)
            {
                futures.emplace_back(pool->enqueue([&, docIt, stride](size_t tid)
                {
                    // maxNgrams + 1 — presumably so a one-word-longer context
                    // exists for branching entropy of maxNgrams-grams.
                    countNgrams<false>(localdata[tid].first,
                        makeStrideIter(docIt, stride, docEnd),
                        makeStrideIter(docEnd, stride, docEnd),
                        vocabFreqs, vocabDf, validPairs, candMinCnt, candMinDf, maxNgrams + 1
                    );
                    countNgrams<true>(localdata[tid].second,
                        makeStrideIter(docIt, stride, docEnd),
                        makeStrideIter(docEnd, stride, docEnd),
                        vocabFreqs, vocabDf, validPairs, candMinCnt, candMinDf, maxNgrams + 1
                    );
                }));
            }

            for (auto& f : futures) f.get();

            auto r = parallelReduce(std::move(localdata), [&](LocalFwBw& dest, LocalFwBw&& src)
            {
                mergeNgramCounts(dest.first, std::move(src.first));
                mergeNgramCounts(dest.second, std::move(src.second));
            }, pool);

            trieNodes = std::move(r.first);
            trieNodesBw = std::move(r.second);
        }
        else
        {
            countNgrams<false>(trieNodes,
                docBegin, docEnd,
                vocabFreqs, vocabDf, validPairs, candMinCnt, candMinDf, maxNgrams + 1
            );
            countNgrams<true>(trieNodesBw,
                docBegin, docEnd,
                vocabFreqs, vocabDf, validPairs, candMinCnt, candMinDf, maxNgrams + 1
            );
        }
    }

    // Total token count; accumulated as size_t (integer init), then widened.
    float totN = std::accumulate(vocabFreqs.begin(), vocabFreqs.end(), (size_t)0);
    const float logTotN = std::log(totN);

    // calculating PMIs
    std::vector<label::Candidate> candidates;
    for (auto& p : bigramCnt)
    {
        auto& bigram = p.first;
        if (p.second < candMinCnt) continue;
        if (bigramDf[bigram] < candMinDf) continue;
        // Normalized PMI: log(cf(a,b) N / (cf(a) cf(b))) / -log p(a,b).
        float npmi = std::log(p.second * totN
            / vocabFreqs[bigram.first] / vocabFreqs[bigram.second]);
        npmi /= std::log(totN / p.second);
        if (npmi < minNPMI) continue;

        // NOTE(review): trieNodes/trieNodesBw are only populated when
        // maxNgrams > 2; with maxNgrams == 2 these lookups index an empty
        // vector — confirm callers always pass maxNgrams > 2.
        float rbe = branchingEntropy(trieNodes[0].getNext(bigram.first)->getNext(bigram.second), candMinCnt);
        float lbe = branchingEntropy(trieNodesBw[0].getNext(bigram.second)->getNext(bigram.first), candMinCnt);
        // NOTE(review): log(p.second) is 0 when cf == 1 (possible when
        // candMinCnt <= 1), yielding inf/NaN here — confirm inputs.
        float nbe = std::sqrt(rbe * lbe) / std::log(p.second);
        if (nbe < minNBE) continue;
        candidates.emplace_back(npmi * nbe, bigram.first, bigram.second);
        candidates.back().cf = p.second;
        candidates.back().df = bigramDf[bigram];
    }

    if (maxNgrams > 2)
    {
        std::vector<Vid> rkeys;
        trieNodes[0].traverse_with_keys([&](const TrieEx<Vid, size_t>* node, const std::vector<Vid>& rkeys)
        {
            if (rkeys.size() <= 2 || rkeys.size() < minNgrams || rkeys.size() > maxNgrams || node->val < candMinCnt) return;
            // Generalized normalized PMI over the whole n-gram.
            auto npmi = std::log((float)node->val) - logTotN;
            for (auto k : rkeys)
            {
                npmi += logTotN - std::log((float)vocabFreqs[k]);
            }
            npmi /= (logTotN - std::log((float)node->val)) * (rkeys.size() - 1);
            if (npmi < minNPMI) return;

            // Right BE from this node; left BE from the reversed-key node
            // in the backward trie.
            float rbe = branchingEntropy(node, candMinCnt);
            float lbe = branchingEntropy(trieNodesBw[0].findNode(rkeys.rbegin(), rkeys.rend()), candMinCnt);
            float nbe = std::sqrt(rbe * lbe) / std::log(node->val);
            if (nbe < minNBE) return;
            candidates.emplace_back(npmi * nbe, rkeys);
            // The trie stores no document frequency, so df stays at its
            // default 0 for n-grams longer than 2.
            candidates.back().cf = node->val;
        }, rkeys);
    }

    std::sort(candidates.begin(), candidates.end(), [](const label::Candidate& a, const label::Candidate& b)
    {
        return a.score > b.score;
    });
    if (candidates.size() > maxCandidates) candidates.erase(candidates.begin() + maxCandidates, candidates.end());
    return candidates;
}
|
517
|
+
}
|
518
|
+
}
|