tomoto 0.1.3 → 0.1.4

Files changed (50)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +6 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -0
  5. data/ext/tomoto/ct.cpp +54 -0
  6. data/ext/tomoto/dmr.cpp +62 -0
  7. data/ext/tomoto/dt.cpp +82 -0
  8. data/ext/tomoto/ext.cpp +27 -773
  9. data/ext/tomoto/gdmr.cpp +34 -0
  10. data/ext/tomoto/hdp.cpp +42 -0
  11. data/ext/tomoto/hlda.cpp +66 -0
  12. data/ext/tomoto/hpa.cpp +27 -0
  13. data/ext/tomoto/lda.cpp +250 -0
  14. data/ext/tomoto/llda.cpp +29 -0
  15. data/ext/tomoto/mglda.cpp +71 -0
  16. data/ext/tomoto/pa.cpp +27 -0
  17. data/ext/tomoto/plda.cpp +29 -0
  18. data/ext/tomoto/slda.cpp +40 -0
  19. data/ext/tomoto/utils.h +84 -0
  20. data/lib/tomoto/tomoto.bundle +0 -0
  21. data/lib/tomoto/tomoto.so +0 -0
  22. data/lib/tomoto/version.rb +1 -1
  23. data/vendor/tomotopy/README.kr.rst +12 -3
  24. data/vendor/tomotopy/README.rst +12 -3
  25. data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +47 -2
  26. data/vendor/tomotopy/src/Labeling/FoRelevance.h +21 -151
  27. data/vendor/tomotopy/src/Labeling/Labeler.h +5 -3
  28. data/vendor/tomotopy/src/Labeling/Phraser.hpp +518 -0
  29. data/vendor/tomotopy/src/TopicModel/CTModel.hpp +6 -3
  30. data/vendor/tomotopy/src/TopicModel/DT.h +1 -1
  31. data/vendor/tomotopy/src/TopicModel/DTModel.hpp +8 -23
  32. data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +9 -18
  33. data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +56 -58
  34. data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +4 -14
  35. data/vendor/tomotopy/src/TopicModel/LDA.h +69 -17
  36. data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +1 -1
  37. data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +108 -61
  38. data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +7 -8
  39. data/vendor/tomotopy/src/TopicModel/PAModel.hpp +26 -16
  40. data/vendor/tomotopy/src/TopicModel/PT.h +27 -0
  41. data/vendor/tomotopy/src/TopicModel/PTModel.cpp +10 -0
  42. data/vendor/tomotopy/src/TopicModel/PTModel.hpp +273 -0
  43. data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +16 -11
  44. data/vendor/tomotopy/src/Utils/MultiNormalDistribution.hpp +3 -2
  45. data/vendor/tomotopy/src/Utils/Trie.hpp +39 -8
  46. data/vendor/tomotopy/src/Utils/TruncMultiNormal.hpp +36 -38
  47. data/vendor/tomotopy/src/Utils/Utils.hpp +50 -45
  48. data/vendor/tomotopy/src/Utils/math.h +8 -4
  49. data/vendor/tomotopy/src/Utils/tvector.hpp +4 -0
  50. metadata +24 -60
data/vendor/tomotopy/src/Labeling/Labeler.h
@@ -10,6 +10,7 @@ namespace tomoto
     struct Candidate
     {
         float score = 0;
+        size_t cf = 0, df = 0;
         std::vector<Vid> w;
         std::string name;

@@ -18,17 +19,17 @@ namespace tomoto
         }

         Candidate(float _score, Vid w1)
-            : w{ w1 }, score{ _score }
+            : score{ _score }, w{ w1 }
         {
         }

         Candidate(float _score, Vid w1, Vid w2)
-            : w{ w1, w2 }, score{ _score }
+            : score{ _score }, w{ w1, w2 }
         {
         }

         Candidate(float _score, const std::vector<Vid>& _w)
-            : w{ _w }, score{ _score }
+            : score{ _score }, w{ _w }
         {
         }
     };
@@ -36,6 +37,7 @@ namespace tomoto
     class IExtractor
     {
     public:
+
        virtual std::vector<Candidate> extract(const ITopicModel* tm) const = 0;
        virtual ~IExtractor() {}
    };
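The Candidate changes above add cf and df fields (a candidate n-gram's collection frequency and document frequency, filled in by the new Phraser.hpp below) and reorder the constructors' mem-initializer lists. The reorder is cosmetic: C++ initializes non-static members in declaration order regardless of list order, so the old `w{ ... }, score{ ... }` spelling already initialized score first and merely drew -Wreorder warnings. A minimal standalone sketch of that rule (names hypothetical, not from the gem):

    #include <iostream>
    #include <vector>

    struct Demo
    {
        float score = 0;     // declared first, so initialized first
        std::vector<int> w;  // declared second, so initialized second

        // List order now matches declaration order; writing
        // `: w{ x }, score{ s }` would behave identically but warn.
        Demo(float s, int x) : score{ s }, w{ x } {}
    };

    int main()
    {
        Demo d{ 0.5f, 7 };
        std::cout << d.score << " " << d.w.size() << "\n";  // prints "0.5 1"
    }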
data/vendor/tomotopy/src/Labeling/Phraser.hpp
@@ -0,0 +1,518 @@
+#pragma once
+
+#include <vector>
+#include <unordered_map>
+#include "Labeler.h"
+#include "../Utils/Trie.hpp"
+
+namespace tomoto
+{
+    namespace phraser
+    {
+        template<typename _DocIter>
+        void countUnigrams(std::vector<size_t>& unigramCf, std::vector<size_t>& unigramDf,
+            _DocIter docBegin, _DocIter docEnd
+        )
+        {
+            for (auto docIt = docBegin; docIt != docEnd; ++docIt)
+            {
+                auto doc = *docIt;
+                if (!doc.size()) continue;
+                std::unordered_set<Vid> uniqs;
+                for (size_t i = 0; i < doc.size(); ++i)
+                {
+                    if (doc[i] == non_vocab_id) continue;
+                    unigramCf[doc[i]]++;
+                    uniqs.emplace(doc[i]);
+                }
+
+                for (auto w : uniqs) unigramDf[w]++;
+            }
+        }
+
+        template<typename _DocIter, typename _VvHash, typename _Freqs>
+        void countBigrams(std::unordered_map<std::pair<Vid, Vid>, size_t, _VvHash>& bigramCf,
+            std::unordered_map<std::pair<Vid, Vid>, size_t, _VvHash>& bigramDf,
+            _DocIter docBegin, _DocIter docEnd,
+            _Freqs&& vocabFreqs, _Freqs&& vocabDf,
+            size_t candMinCnt, size_t candMinDf
+        )
+        {
+            for (auto docIt = docBegin; docIt != docEnd; ++docIt)
+            {
+                std::unordered_set<std::pair<Vid, Vid>, _VvHash> uniqBigram;
+                auto doc = *docIt;
+                if (!doc.size()) continue;
+                Vid prevWord = doc[0];
+                for (size_t j = 1; j < doc.size(); ++j)
+                {
+                    Vid curWord = doc[j];
+                    if (curWord != non_vocab_id && vocabFreqs[curWord] >= candMinCnt && vocabDf[curWord] >= candMinDf)
+                    {
+                        if (prevWord != non_vocab_id && vocabFreqs[prevWord] >= candMinCnt && vocabDf[prevWord] >= candMinDf)
+                        {
+                            bigramCf[std::make_pair(prevWord, curWord)]++;
+                            uniqBigram.emplace(prevWord, curWord);
+                        }
+                    }
+                    prevWord = curWord;
+                }
+
+                for (auto& p : uniqBigram) bigramDf[p]++;
+            }
+        }
+
+        template<bool _reverse, typename _DocIter, typename _Freqs, typename _BigramPairs>
+        void countNgrams(std::vector<TrieEx<Vid, size_t>>& dest,
+            _DocIter docBegin, _DocIter docEnd,
+            _Freqs&& vocabFreqs, _Freqs&& vocabDf, _BigramPairs&& validPairs,
+            size_t candMinCnt, size_t candMinDf, size_t maxNgrams
+        )
+        {
+            if (dest.empty())
+            {
+                dest.resize(1);
+                dest.reserve(1024);
+            }
+            auto allocNode = [&]() { return dest.emplace_back(), &dest.back(); };
+
+            for (auto docIt = docBegin; docIt != docEnd; ++docIt)
+            {
+                auto doc = *docIt;
+                if (!doc.size()) continue;
+                if (dest.capacity() < dest.size() + doc.size() * maxNgrams)
+                {
+                    dest.reserve(std::max(dest.size() + doc.size() * maxNgrams, dest.capacity() * 2));
+                }
+
+                Vid prevWord = _reverse ? *doc.rbegin() : *doc.begin();
+                size_t labelLen = 0;
+                auto node = &dest[0];
+                if (prevWord != non_vocab_id && vocabFreqs[prevWord] >= candMinCnt && vocabDf[prevWord] >= candMinDf)
+                {
+                    node = dest[0].makeNext(prevWord, allocNode);
+                    node->val++;
+                    labelLen = 1;
+                }
+
+                const auto func = [&](Vid curWord)
+                {
+                    if (curWord != non_vocab_id && (vocabFreqs[curWord] < candMinCnt || vocabDf[curWord] < candMinDf))
+                    {
+                        node = &dest[0];
+                        labelLen = 0;
+                    }
+                    else
+                    {
+                        if (labelLen >= maxNgrams)
+                        {
+                            node = node->getFail();
+                            labelLen--;
+                        }
+
+                        if (validPairs.count(_reverse ? std::make_pair(curWord, prevWord) : std::make_pair(prevWord, curWord)))
+                        {
+                            auto nnode = node->makeNext(curWord, allocNode);
+                            node = nnode;
+                            do
+                            {
+                                nnode->val++;
+                            } while ((nnode = nnode->getFail()));
+                            labelLen++;
+                        }
+                        else
+                        {
+                            node = dest[0].makeNext(curWord, allocNode);
+                            node->val++;
+                            labelLen = 1;
+                        }
+                    }
+                    prevWord = curWord;
+                };
+
+                if (_reverse) std::for_each(doc.rbegin() + 1, doc.rend(), func);
+                else std::for_each(doc.begin() + 1, doc.end(), func);
+            }
+        }
+
+        inline void mergeNgramCounts(std::vector<TrieEx<Vid, size_t>>& dest, std::vector<TrieEx<Vid, size_t>>&& src)
+        {
+            if (src.empty()) return;
+            if (dest.empty()) dest.resize(1);
+
+            auto allocNode = [&]() { return dest.emplace_back(), &dest.back(); };
+
+            std::vector<Vid> rkeys;
+            src[0].traverse_with_keys([&](const TrieEx<Vid, size_t>* node, const std::vector<Vid>& rkeys)
+            {
+                if (dest.capacity() < dest.size() + rkeys.size() * rkeys.size())
+                {
+                    dest.reserve(std::max(dest.size() + rkeys.size() * rkeys.size(), dest.capacity() * 2));
+                }
+                dest[0].build(rkeys.begin(), rkeys.end(), 0, allocNode)->val += node->val;
+            }, rkeys);
+        }
+
+        inline float branchingEntropy(const TrieEx<Vid, size_t>* node, size_t minCnt)
+        {
+            float entropy = 0;
+            size_t rest = node->val;
+            for (auto n : *node)
+            {
+                float p = n.second->val / (float)node->val;
+                entropy -= p * std::log(p);
+                rest -= n.second->val;
+            }
+            if (rest > 0)
+            {
+                float p = rest / (float)node->val;
+                entropy -= p * std::log(std::min(std::max(minCnt, (size_t)1), (size_t)rest) / (float)node->val);
+            }
+            return entropy;
+        }
+
+        template<typename _LocalData, typename _ReduceFn>
+        _LocalData parallelReduce(std::vector<_LocalData>&& data, _ReduceFn&& fn, ThreadPool* pool = nullptr)
+        {
+            if (pool)
+            {
+                for (size_t s = data.size(); s > 1; s = (s + 1) / 2)
+                {
+                    std::vector<std::future<void>> futures;
+                    size_t h = (s + 1) / 2;
+                    for (size_t i = h; i < s; ++i)
+                    {
+                        futures.emplace_back(pool->enqueue([&, i, h](size_t)
+                        {
+                            _LocalData d = std::move(data[i]);
+                            fn(data[i - h], std::move(d));
+                        }));
+                    }
+                    for (auto& f : futures) f.get();
+                }
+            }
+            else
+            {
+                for (size_t i = 1; i < data.size(); ++i)
+                {
+                    _LocalData d = std::move(data[i]);
+                    fn(data[0], std::move(d));
+                }
+            }
+            return std::move(data[0]);
+        }
+
+        namespace detail
+        {
+            struct vvhash
+            {
+                size_t operator()(const std::pair<Vid, Vid>& k) const
+                {
+                    return std::hash<Vid>{}(k.first) ^ std::hash<Vid>{}(k.second);
+                }
+            };
+        }
+
+        template<typename _DocIter, typename _Freqs>
+        std::vector<label::Candidate> extractPMINgrams(_DocIter docBegin, _DocIter docEnd,
+            _Freqs&& vocabFreqs, _Freqs&& vocabDf,
+            size_t candMinCnt, size_t candMinDf, size_t minNgrams, size_t maxNgrams, size_t maxCandidates,
+            float minScore, bool normalized = false,
+            ThreadPool* pool = nullptr)
+        {
+            // counting unigrams & bigrams
+            std::unordered_map<std::pair<Vid, Vid>, size_t, detail::vvhash> bigramCnt, bigramDf;
+
+            if (pool && pool->getNumWorkers() > 1)
+            {
+                using LocalCfDf = std::pair<
+                    std::unordered_map<std::pair<Vid, Vid>, size_t, detail::vvhash>,
+                    std::unordered_map<std::pair<Vid, Vid>, size_t, detail::vvhash>
+                >;
+                std::vector<LocalCfDf> localdata(pool->getNumWorkers());
+                std::vector<std::future<void>> futures;
+                const size_t stride = pool->getNumWorkers() * 8;
+                auto docIt = docBegin;
+                for (size_t i = 0; i < stride && docIt != docEnd; ++i, ++docIt)
+                {
+                    futures.emplace_back(pool->enqueue([&, docIt, stride](size_t tid)
+                    {
+                        countBigrams(localdata[tid].first, localdata[tid].second, makeStrideIter(docIt, stride, docEnd), makeStrideIter(docEnd, stride, docEnd), vocabFreqs, vocabDf, candMinCnt, candMinDf);
+                    }));
+                }
+
+                for (auto& f : futures) f.get();
+
+                auto r = parallelReduce(std::move(localdata), [](LocalCfDf& dest, LocalCfDf&& src)
+                {
+                    for (auto& p : src.first) dest.first[p.first] += p.second;
+                    for (auto& p : src.second) dest.second[p.first] += p.second;
+                }, pool);
+
+                bigramCnt = std::move(r.first);
+                bigramDf = std::move(r.second);
+            }
+            else
+            {
+                countBigrams(bigramCnt, bigramDf, docBegin, docEnd, vocabFreqs, vocabDf, candMinCnt, candMinDf);
+            }
+
+            // counting ngrams
+            std::vector<TrieEx<Vid, size_t>> trieNodes;
+            if (maxNgrams > 2)
+            {
+                std::unordered_set<std::pair<Vid, Vid>, detail::vvhash> validPairs;
+                for (auto& p : bigramCnt)
+                {
+                    if (p.second >= candMinCnt && bigramDf[p.first] >= candMinDf) validPairs.emplace(p.first);
+                }
+
+                if (pool && pool->getNumWorkers() > 1)
+                {
+                    using LocalFwBw = std::vector<TrieEx<Vid, size_t>>;
+                    std::vector<LocalFwBw> localdata(pool->getNumWorkers());
+                    std::vector<std::future<void>> futures;
+                    const size_t stride = pool->getNumWorkers() * 8;
+                    auto docIt = docBegin;
+                    for (size_t i = 0; i < stride && docIt != docEnd; ++i, ++docIt)
+                    {
+                        futures.emplace_back(pool->enqueue([&, docIt, stride](size_t tid)
+                        {
+                            countNgrams<false>(localdata[tid],
+                                makeStrideIter(docIt, stride, docEnd),
+                                makeStrideIter(docEnd, stride, docEnd),
+                                vocabFreqs, vocabDf, validPairs, candMinCnt, candMinDf, maxNgrams
+                            );
+                        }));
+                    }
+
+                    for (auto& f : futures) f.get();
+
+                    auto r = parallelReduce(std::move(localdata), [&](LocalFwBw& dest, LocalFwBw&& src)
+                    {
+                        mergeNgramCounts(dest, std::move(src));
+                    }, pool);
+
+                    trieNodes = std::move(r);
+                }
+                else
+                {
+                    countNgrams<false>(trieNodes,
+                        docBegin, docEnd,
+                        vocabFreqs, vocabDf, validPairs, candMinCnt, candMinDf, maxNgrams
+                    );
+                }
+            }
+
+            float totN = std::accumulate(vocabFreqs.begin(), vocabFreqs.end(), (size_t)0);
+            const float logTotN = std::log(totN);
+
+            // calculating PMIs
+            std::vector<label::Candidate> candidates;
+            for (auto& p : bigramCnt)
+            {
+                auto& bigram = p.first;
+                if (p.second < candMinCnt) continue;
+                if (bigramDf[bigram] < candMinDf) continue;
+                auto pmi = std::log(p.second * totN
+                    / vocabFreqs[bigram.first] / vocabFreqs[bigram.second]);
+                if (normalized)
+                {
+                    pmi /= std::log(totN / p.second);
+                }
+                if (pmi < minScore) continue;
+                candidates.emplace_back(pmi, bigram.first, bigram.second);
+                candidates.back().cf = p.second;
+                candidates.back().df = bigramDf[bigram];
+            }
+
+            if (maxNgrams > 2)
+            {
+                std::vector<Vid> rkeys;
+                trieNodes[0].traverse_with_keys([&](const TrieEx<Vid, size_t>* node, const std::vector<Vid>& rkeys)
+                {
+                    if (rkeys.size() <= 2 || rkeys.size() < minNgrams || rkeys.size() > maxNgrams || node->val < candMinCnt) return;
+                    auto pmi = std::log((float)node->val) - logTotN;
+                    for (auto k : rkeys)
+                    {
+                        pmi += logTotN - std::log((float)vocabFreqs[k]);
+                    }
+                    if (normalized)
+                    {
+                        pmi /= (logTotN - std::log((float)node->val)) * (rkeys.size() - 1);
+                    }
+                    if (pmi < minScore) return;
+                    candidates.emplace_back(pmi, rkeys);
+                    candidates.back().cf = node->val;
+                }, rkeys);
+            }
+
+            std::sort(candidates.begin(), candidates.end(), [](const label::Candidate& a, const label::Candidate& b)
+            {
+                return a.score > b.score;
+            });
+            if (candidates.size() > maxCandidates) candidates.erase(candidates.begin() + maxCandidates, candidates.end());
+            return candidates;
+        }
+
+        template<typename _DocIter, typename _Freqs>
+        std::vector<label::Candidate> extractPMIBENgrams(_DocIter docBegin, _DocIter docEnd,
+            _Freqs&& vocabFreqs, _Freqs&& vocabDf,
+            size_t candMinCnt, size_t candMinDf, size_t minNgrams, size_t maxNgrams, size_t maxCandidates,
+            float minNPMI = 0, float minNBE = 0,
+            ThreadPool* pool = nullptr)
+        {
+            // counting unigrams & bigrams
+            std::unordered_map<std::pair<Vid, Vid>, size_t, detail::vvhash> bigramCnt, bigramDf;
+
+            if (pool && pool->getNumWorkers() > 1)
+            {
+                using LocalCfDf = std::pair<
+                    std::unordered_map<std::pair<Vid, Vid>, size_t, detail::vvhash>,
+                    std::unordered_map<std::pair<Vid, Vid>, size_t, detail::vvhash>
+                >;
+                std::vector<LocalCfDf> localdata(pool->getNumWorkers());
+                std::vector<std::future<void>> futures;
+                const size_t stride = pool->getNumWorkers() * 8;
+                auto docIt = docBegin;
+                for (size_t i = 0; i < stride && docIt != docEnd; ++i, ++docIt)
+                {
+                    futures.emplace_back(pool->enqueue([&, docIt, stride](size_t tid)
+                    {
+                        countBigrams(localdata[tid].first, localdata[tid].second, makeStrideIter(docIt, stride, docEnd), makeStrideIter(docEnd, stride, docEnd), vocabFreqs, vocabDf, candMinCnt, candMinDf);
+                    }));
+                }
+
+                for (auto& f : futures) f.get();
+
+                auto r = parallelReduce(std::move(localdata), [](LocalCfDf& dest, LocalCfDf&& src)
+                {
+                    for (auto& p : src.first) dest.first[p.first] += p.second;
+                    for (auto& p : src.second) dest.second[p.first] += p.second;
+                }, pool);
+
+                bigramCnt = std::move(r.first);
+                bigramDf = std::move(r.second);
+            }
+            else
+            {
+                countBigrams(bigramCnt, bigramDf, docBegin, docEnd, vocabFreqs, vocabDf, candMinCnt, candMinDf);
+            }
+
+            // counting ngrams
+            std::vector<TrieEx<Vid, size_t>> trieNodes, trieNodesBw;
+            if (maxNgrams > 2)
+            {
+                std::unordered_set<std::pair<Vid, Vid>, detail::vvhash> validPairs;
+                for (auto& p : bigramCnt)
+                {
+                    if (p.second >= candMinCnt && bigramDf[p.first] >= candMinDf) validPairs.emplace(p.first);
+                }
+
+                if (pool && pool->getNumWorkers() > 1)
+                {
+                    using LocalFwBw = std::pair<
+                        std::vector<TrieEx<Vid, size_t>>,
+                        std::vector<TrieEx<Vid, size_t>>
+                    >;
+                    std::vector<LocalFwBw> localdata(pool->getNumWorkers());
+                    std::vector<std::future<void>> futures;
+                    const size_t stride = pool->getNumWorkers() * 8;
+                    auto docIt = docBegin;
+                    for (size_t i = 0; i < stride && docIt != docEnd; ++i, ++docIt)
+                    {
+                        futures.emplace_back(pool->enqueue([&, docIt, stride](size_t tid)
+                        {
+                            countNgrams<false>(localdata[tid].first,
+                                makeStrideIter(docIt, stride, docEnd),
+                                makeStrideIter(docEnd, stride, docEnd),
+                                vocabFreqs, vocabDf, validPairs, candMinCnt, candMinDf, maxNgrams + 1
+                            );
+                            countNgrams<true>(localdata[tid].second,
+                                makeStrideIter(docIt, stride, docEnd),
+                                makeStrideIter(docEnd, stride, docEnd),
+                                vocabFreqs, vocabDf, validPairs, candMinCnt, candMinDf, maxNgrams + 1
+                            );
+                        }));
+                    }
+
+                    for (auto& f : futures) f.get();
+
+                    auto r = parallelReduce(std::move(localdata), [&](LocalFwBw& dest, LocalFwBw&& src)
+                    {
+                        mergeNgramCounts(dest.first, std::move(src.first));
+                        mergeNgramCounts(dest.second, std::move(src.second));
+                    }, pool);
+
+                    trieNodes = std::move(r.first);
+                    trieNodesBw = std::move(r.second);
+                }
+                else
+                {
+                    countNgrams<false>(trieNodes,
+                        docBegin, docEnd,
+                        vocabFreqs, vocabDf, validPairs, candMinCnt, candMinDf, maxNgrams + 1
+                    );
+                    countNgrams<true>(trieNodesBw,
+                        docBegin, docEnd,
+                        vocabFreqs, vocabDf, validPairs, candMinCnt, candMinDf, maxNgrams + 1
+                    );
+                }
+            }
+
+            float totN = std::accumulate(vocabFreqs.begin(), vocabFreqs.end(), (size_t)0);
+            const float logTotN = std::log(totN);
+
+            // calculating PMIs
+            std::vector<label::Candidate> candidates;
+            for (auto& p : bigramCnt)
+            {
+                auto& bigram = p.first;
+                if (p.second < candMinCnt) continue;
+                if (bigramDf[bigram] < candMinDf) continue;
+                float npmi = std::log(p.second * totN
+                    / vocabFreqs[bigram.first] / vocabFreqs[bigram.second]);
+                npmi /= std::log(totN / p.second);
+                if (npmi < minNPMI) continue;
+
+                float rbe = branchingEntropy(trieNodes[0].getNext(bigram.first)->getNext(bigram.second), candMinCnt);
+                float lbe = branchingEntropy(trieNodesBw[0].getNext(bigram.second)->getNext(bigram.first), candMinCnt);
+                float nbe = std::sqrt(rbe * lbe) / std::log(p.second);
+                if (nbe < minNBE) continue;
+                candidates.emplace_back(npmi * nbe, bigram.first, bigram.second);
+                candidates.back().cf = p.second;
+                candidates.back().df = bigramDf[bigram];
+            }
+
+            if (maxNgrams > 2)
+            {
+                std::vector<Vid> rkeys;
+                trieNodes[0].traverse_with_keys([&](const TrieEx<Vid, size_t>* node, const std::vector<Vid>& rkeys)
+                {
+                    if (rkeys.size() <= 2 || rkeys.size() < minNgrams || rkeys.size() > maxNgrams || node->val < candMinCnt) return;
+                    auto npmi = std::log((float)node->val) - logTotN;
+                    for (auto k : rkeys)
+                    {
+                        npmi += logTotN - std::log((float)vocabFreqs[k]);
+                    }
+                    npmi /= (logTotN - std::log((float)node->val)) * (rkeys.size() - 1);
+                    if (npmi < minNPMI) return;
+
+                    float rbe = branchingEntropy(node, candMinCnt);
+                    float lbe = branchingEntropy(trieNodesBw[0].findNode(rkeys.rbegin(), rkeys.rend()), candMinCnt);
+                    float nbe = std::sqrt(rbe * lbe) / std::log(node->val);
+                    if (nbe < minNBE) return;
+                    candidates.emplace_back(npmi * nbe, rkeys);
+                    candidates.back().cf = node->val;
+                }, rkeys);
+            }
+
+            std::sort(candidates.begin(), candidates.end(), [](const label::Candidate& a, const label::Candidate& b)
+            {
+                return a.score > b.score;
+            });
+            if (candidates.size() > maxCandidates) candidates.erase(candidates.begin() + maxCandidates, candidates.end());
+            return candidates;
+        }
+    }
+}
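For reference, the scoring the new file implements: extractPMINgrams gives a bigram (w1, w2) the score PMI = log(cf(w1,w2) * N / (cf(w1) * cf(w2))), where N is the total token count; the `normalized` branch, and extractPMIBENgrams always, divides by log(N / cf(w1,w2)) to get NPMI. extractPMIBENgrams then multiplies NPMI by a normalized branching entropy, sqrt(rightBE * leftBE) / log(cf), taken from the forward and backward tries. A standalone sketch with made-up counts (all numbers hypothetical; only the formulas come from the code above, and the entropy here omits the unseen-successor correction that branchingEntropy applies):

    #include <cmath>
    #include <cstdio>

    int main()
    {
        const float N    = 100000;  // total token count (totN above)
        const float cf1  = 120;     // collection frequency of w1
        const float cf2  = 90;      // collection frequency of w2
        const float cf12 = 60;      // collection frequency of the bigram (w1, w2)

        // extractPMINgrams: pointwise mutual information of the bigram
        float pmi = std::log(cf12 * N / (cf1 * cf2));
        // `normalized` branch / extractPMIBENgrams: NPMI, bounded above by 1
        float npmi = pmi / std::log(N / cf12);

        // branching entropy over hypothetical successor counts of the bigram
        const float follow[] = { 30, 20, 10 };
        float rbe = 0;
        for (float c : follow)
        {
            float p = c / cf12;
            rbe -= p * std::log(p);
        }
        float lbe = rbe;  // pretend left and right contexts look alike
        float nbe = std::sqrt(rbe * lbe) / std::log(cf12);

        std::printf("pmi=%.3f npmi=%.3f score=%.3f\n", pmi, npmi, npmi * nbe);
    }

Candidates that clear the minScore (or minNPMI/minNBE) thresholds are sorted by score and truncated to maxCandidates, with cf and df recorded on each Candidate as added in Labeler.h above.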