tomoto 0.1.3 → 0.1.4

This diff compares the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (50)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +6 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -0
  5. data/ext/tomoto/ct.cpp +54 -0
  6. data/ext/tomoto/dmr.cpp +62 -0
  7. data/ext/tomoto/dt.cpp +82 -0
  8. data/ext/tomoto/ext.cpp +27 -773
  9. data/ext/tomoto/gdmr.cpp +34 -0
  10. data/ext/tomoto/hdp.cpp +42 -0
  11. data/ext/tomoto/hlda.cpp +66 -0
  12. data/ext/tomoto/hpa.cpp +27 -0
  13. data/ext/tomoto/lda.cpp +250 -0
  14. data/ext/tomoto/llda.cpp +29 -0
  15. data/ext/tomoto/mglda.cpp +71 -0
  16. data/ext/tomoto/pa.cpp +27 -0
  17. data/ext/tomoto/plda.cpp +29 -0
  18. data/ext/tomoto/slda.cpp +40 -0
  19. data/ext/tomoto/utils.h +84 -0
  20. data/lib/tomoto/tomoto.bundle +0 -0
  21. data/lib/tomoto/tomoto.so +0 -0
  22. data/lib/tomoto/version.rb +1 -1
  23. data/vendor/tomotopy/README.kr.rst +12 -3
  24. data/vendor/tomotopy/README.rst +12 -3
  25. data/vendor/tomotopy/src/Labeling/FoRelevance.cpp +47 -2
  26. data/vendor/tomotopy/src/Labeling/FoRelevance.h +21 -151
  27. data/vendor/tomotopy/src/Labeling/Labeler.h +5 -3
  28. data/vendor/tomotopy/src/Labeling/Phraser.hpp +518 -0
  29. data/vendor/tomotopy/src/TopicModel/CTModel.hpp +6 -3
  30. data/vendor/tomotopy/src/TopicModel/DT.h +1 -1
  31. data/vendor/tomotopy/src/TopicModel/DTModel.hpp +8 -23
  32. data/vendor/tomotopy/src/TopicModel/HDPModel.hpp +9 -18
  33. data/vendor/tomotopy/src/TopicModel/HLDAModel.hpp +56 -58
  34. data/vendor/tomotopy/src/TopicModel/HPAModel.hpp +4 -14
  35. data/vendor/tomotopy/src/TopicModel/LDA.h +69 -17
  36. data/vendor/tomotopy/src/TopicModel/LDACVB0Model.hpp +1 -1
  37. data/vendor/tomotopy/src/TopicModel/LDAModel.hpp +108 -61
  38. data/vendor/tomotopy/src/TopicModel/MGLDAModel.hpp +7 -8
  39. data/vendor/tomotopy/src/TopicModel/PAModel.hpp +26 -16
  40. data/vendor/tomotopy/src/TopicModel/PT.h +27 -0
  41. data/vendor/tomotopy/src/TopicModel/PTModel.cpp +10 -0
  42. data/vendor/tomotopy/src/TopicModel/PTModel.hpp +273 -0
  43. data/vendor/tomotopy/src/TopicModel/TopicModel.hpp +16 -11
  44. data/vendor/tomotopy/src/Utils/MultiNormalDistribution.hpp +3 -2
  45. data/vendor/tomotopy/src/Utils/Trie.hpp +39 -8
  46. data/vendor/tomotopy/src/Utils/TruncMultiNormal.hpp +36 -38
  47. data/vendor/tomotopy/src/Utils/Utils.hpp +50 -45
  48. data/vendor/tomotopy/src/Utils/math.h +8 -4
  49. data/vendor/tomotopy/src/Utils/tvector.hpp +4 -0
  50. metadata +24 -60
data/vendor/tomotopy/src/Labeling/Labeler.h

@@ -10,6 +10,7 @@ namespace tomoto
      struct Candidate
      {
          float score = 0;
+         size_t cf = 0, df = 0;
          std::vector<Vid> w;
          std::string name;
@@ -18,17 +19,17 @@ namespace tomoto
          }

          Candidate(float _score, Vid w1)
-             : w{ w1 }, score{ _score }
+             : score{ _score }, w{ w1 }
          {
          }

          Candidate(float _score, Vid w1, Vid w2)
-             : w{ w1, w2 }, score{ _score }
+             : score{ _score }, w{ w1, w2 }
          {
          }

          Candidate(float _score, const std::vector<Vid>& _w)
-             : w{ _w }, score{ _score }
+             : score{ _score }, w{ _w }
          {
          }
      };
@@ -36,6 +37,7 @@ namespace tomoto
      class IExtractor
      {
      public:
+
          virtual std::vector<Candidate> extract(const ITopicModel* tm) const = 0;
          virtual ~IExtractor() {}
      };
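The constructor edits in the hunk above reorder each mem-initializer list to match the declaration order of the members. C++ always initializes non-static data members in declaration order (here score, then the new cf/df, then w) regardless of how the initializer list is written, so the old w-before-score spelling was misleading and is flagged by GCC/Clang's -Wreorder. A minimal self-contained sketch of the point (the Example type is illustrative, not part of the package):

#include <vector>

struct Example
{
    float score = 0;        // declared first  -> always initialized first
    std::vector<int> w;     // declared second -> always initialized second

    // Matches declaration order; this is the spelling the diff switches to.
    Example(float s, int w1) : score{ s }, w{ w1 } {}

    // The pre-0.1.4 style. It compiles and behaves identically here, but the
    // written order misstates the actual initialization order, so
    // -Wall / -Wreorder warns about it.
    // Example(float s, int w1) : w{ w1 }, score{ s } {}
};

int main()
{
    Example e{ 0.5f, 42 };
    return e.w.front() == 42 ? 0 : 1;
}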
data/vendor/tomotopy/src/Labeling/Phraser.hpp

@@ -0,0 +1,518 @@
+ #pragma once
+
+ #include <vector>
+ #include <unordered_map>
+ #include "Labeler.h"
+ #include "../Utils/Trie.hpp"
+
+ namespace tomoto
+ {
+     namespace phraser
+     {
+         template<typename _DocIter>
+         void countUnigrams(std::vector<size_t>& unigramCf, std::vector<size_t>& unigramDf,
+             _DocIter docBegin, _DocIter docEnd
+         )
+         {
+             for (auto docIt = docBegin; docIt != docEnd; ++docIt)
+             {
+                 auto doc = *docIt;
+                 if (!doc.size()) continue;
+                 std::unordered_set<Vid> uniqs;
+                 for (size_t i = 0; i < doc.size(); ++i)
+                 {
+                     if (doc[i] == non_vocab_id) continue;
+                     unigramCf[doc[i]]++;
+                     uniqs.emplace(doc[i]);
+                 }
+
+                 for (auto w : uniqs) unigramDf[w]++;
+             }
+         }
+
+         template<typename _DocIter, typename _VvHash, typename _Freqs>
+         void countBigrams(std::unordered_map<std::pair<Vid, Vid>, size_t, _VvHash>& bigramCf,
+             std::unordered_map<std::pair<Vid, Vid>, size_t, _VvHash>& bigramDf,
+             _DocIter docBegin, _DocIter docEnd,
+             _Freqs&& vocabFreqs, _Freqs&& vocabDf,
+             size_t candMinCnt, size_t candMinDf
+         )
+         {
+             for (auto docIt = docBegin; docIt != docEnd; ++docIt)
+             {
+                 std::unordered_set<std::pair<Vid, Vid>, _VvHash> uniqBigram;
+                 auto doc = *docIt;
+                 if (!doc.size()) continue;
+                 Vid prevWord = doc[0];
+                 for (size_t j = 1; j < doc.size(); ++j)
+                 {
+                     Vid curWord = doc[j];
+                     if (curWord != non_vocab_id && vocabFreqs[curWord] >= candMinCnt && vocabDf[curWord] >= candMinDf)
+                     {
+                         if (prevWord != non_vocab_id && vocabFreqs[prevWord] >= candMinCnt && vocabDf[prevWord] >= candMinDf)
+                         {
+                             bigramCf[std::make_pair(prevWord, curWord)]++;
+                             uniqBigram.emplace(prevWord, curWord);
+                         }
+                     }
+                     prevWord = curWord;
+                 }
+
+                 for (auto& p : uniqBigram) bigramDf[p]++;
+             }
+         }
+
+         template<bool _reverse, typename _DocIter, typename _Freqs, typename _BigramPairs>
+         void countNgrams(std::vector<TrieEx<Vid, size_t>>& dest,
+             _DocIter docBegin, _DocIter docEnd,
+             _Freqs&& vocabFreqs, _Freqs&& vocabDf, _BigramPairs&& validPairs,
+             size_t candMinCnt, size_t candMinDf, size_t maxNgrams
+         )
+         {
+             if (dest.empty())
+             {
+                 dest.resize(1);
+                 dest.reserve(1024);
+             }
+             auto allocNode = [&]() { return dest.emplace_back(), &dest.back(); };
+
+             for (auto docIt = docBegin; docIt != docEnd; ++docIt)
+             {
+                 auto doc = *docIt;
+                 if (!doc.size()) continue;
+                 if (dest.capacity() < dest.size() + doc.size() * maxNgrams)
+                 {
+                     dest.reserve(std::max(dest.size() + doc.size() * maxNgrams, dest.capacity() * 2));
+                 }
+
+                 Vid prevWord = _reverse ? *doc.rbegin() : *doc.begin();
+                 size_t labelLen = 0;
+                 auto node = &dest[0];
+                 if (prevWord != non_vocab_id && vocabFreqs[prevWord] >= candMinCnt && vocabDf[prevWord] >= candMinDf)
+                 {
+                     node = dest[0].makeNext(prevWord, allocNode);
+                     node->val++;
+                     labelLen = 1;
+                 }
+
+                 const auto func = [&](Vid curWord)
+                 {
+                     if (curWord != non_vocab_id && (vocabFreqs[curWord] < candMinCnt || vocabDf[curWord] < candMinDf))
+                     {
+                         node = &dest[0];
+                         labelLen = 0;
+                     }
+                     else
+                     {
+                         if (labelLen >= maxNgrams)
+                         {
+                             node = node->getFail();
+                             labelLen--;
+                         }
+
+                         if (validPairs.count(_reverse ? std::make_pair(curWord, prevWord) : std::make_pair(prevWord, curWord)))
+                         {
+                             auto nnode = node->makeNext(curWord, allocNode);
+                             node = nnode;
+                             do
+                             {
+                                 nnode->val++;
+                             } while ((nnode = nnode->getFail()));
+                             labelLen++;
+                         }
+                         else
+                         {
+                             node = dest[0].makeNext(curWord, allocNode);
+                             node->val++;
+                             labelLen = 1;
+                         }
+                     }
+                     prevWord = curWord;
+                 };
+
+                 if (_reverse) std::for_each(doc.rbegin() + 1, doc.rend(), func);
+                 else std::for_each(doc.begin() + 1, doc.end(), func);
+             }
+         }
+
+         inline void mergeNgramCounts(std::vector<TrieEx<Vid, size_t>>& dest, std::vector<TrieEx<Vid, size_t>>&& src)
+         {
+             if (src.empty()) return;
+             if (dest.empty()) dest.resize(1);
+
+             auto allocNode = [&]() { return dest.emplace_back(), &dest.back(); };
+
+             std::vector<Vid> rkeys;
+             src[0].traverse_with_keys([&](const TrieEx<Vid, size_t>* node, const std::vector<Vid>& rkeys)
+             {
+                 if (dest.capacity() < dest.size() + rkeys.size() * rkeys.size())
+                 {
+                     dest.reserve(std::max(dest.size() + rkeys.size() * rkeys.size(), dest.capacity() * 2));
+                 }
+                 dest[0].build(rkeys.begin(), rkeys.end(), 0, allocNode)->val += node->val;
+             }, rkeys);
+         }
+
+         inline float branchingEntropy(const TrieEx<Vid, size_t>* node, size_t minCnt)
+         {
+             float entropy = 0;
+             size_t rest = node->val;
+             for (auto n : *node)
+             {
+                 float p = n.second->val / (float)node->val;
+                 entropy -= p * std::log(p);
+                 rest -= n.second->val;
+             }
+             if (rest > 0)
+             {
+                 float p = rest / (float)node->val;
+                 entropy -= p * std::log(std::min(std::max(minCnt, (size_t)1), (size_t)rest) / (float)node->val);
+             }
+             return entropy;
+         }
+
+         template<typename _LocalData, typename _ReduceFn>
+         _LocalData parallelReduce(std::vector<_LocalData>&& data, _ReduceFn&& fn, ThreadPool* pool = nullptr)
+         {
+             if (pool)
+             {
+                 for (size_t s = data.size(); s > 1; s = (s + 1) / 2)
+                 {
+                     std::vector<std::future<void>> futures;
+                     size_t h = (s + 1) / 2;
+                     for (size_t i = h; i < s; ++i)
+                     {
+                         futures.emplace_back(pool->enqueue([&, i, h](size_t)
+                         {
+                             _LocalData d = std::move(data[i]);
+                             fn(data[i - h], std::move(d));
+                         }));
+                     }
+                     for (auto& f : futures) f.get();
+                 }
+             }
+             else
+             {
+                 for (size_t i = 1; i < data.size(); ++i)
+                 {
+                     _LocalData d = std::move(data[i]);
+                     fn(data[0], std::move(d));
+                 }
+             }
+             return std::move(data[0]);
+         }
+
+         namespace detail
+         {
+             struct vvhash
+             {
+                 size_t operator()(const std::pair<Vid, Vid>& k) const
+                 {
+                     return std::hash<Vid>{}(k.first) ^ std::hash<Vid>{}(k.second);
+                 }
+             };
+         }
+
+         template<typename _DocIter, typename _Freqs>
+         std::vector<label::Candidate> extractPMINgrams(_DocIter docBegin, _DocIter docEnd,
+             _Freqs&& vocabFreqs, _Freqs&& vocabDf,
+             size_t candMinCnt, size_t candMinDf, size_t minNgrams, size_t maxNgrams, size_t maxCandidates,
+             float minScore, bool normalized = false,
+             ThreadPool* pool = nullptr)
+         {
+             // counting unigrams & bigrams
+             std::unordered_map<std::pair<Vid, Vid>, size_t, detail::vvhash> bigramCnt, bigramDf;
+
+             if (pool && pool->getNumWorkers() > 1)
+             {
+                 using LocalCfDf = std::pair<
+                     std::unordered_map<std::pair<Vid, Vid>, size_t, detail::vvhash>,
+                     std::unordered_map<std::pair<Vid, Vid>, size_t, detail::vvhash>
+                 >;
+                 std::vector<LocalCfDf> localdata(pool->getNumWorkers());
+                 std::vector<std::future<void>> futures;
+                 const size_t stride = pool->getNumWorkers() * 8;
+                 auto docIt = docBegin;
+                 for (size_t i = 0; i < stride && docIt != docEnd; ++i, ++docIt)
+                 {
+                     futures.emplace_back(pool->enqueue([&, docIt, stride](size_t tid)
+                     {
+                         countBigrams(localdata[tid].first, localdata[tid].second, makeStrideIter(docIt, stride, docEnd), makeStrideIter(docEnd, stride, docEnd), vocabFreqs, vocabDf, candMinCnt, candMinDf);
+                     }));
+                 }
+
+                 for (auto& f : futures) f.get();
+
+                 auto r = parallelReduce(std::move(localdata), [](LocalCfDf& dest, LocalCfDf&& src)
+                 {
+                     for (auto& p : src.first) dest.first[p.first] += p.second;
+                     for (auto& p : src.second) dest.second[p.first] += p.second;
+                 }, pool);
+
+                 bigramCnt = std::move(r.first);
+                 bigramDf = std::move(r.second);
+             }
+             else
+             {
+                 countBigrams(bigramCnt, bigramDf, docBegin, docEnd, vocabFreqs, vocabDf, candMinCnt, candMinDf);
+             }
+
+             // counting ngrams
+             std::vector<TrieEx<Vid, size_t>> trieNodes;
+             if (maxNgrams > 2)
+             {
+                 std::unordered_set<std::pair<Vid, Vid>, detail::vvhash> validPairs;
+                 for (auto& p : bigramCnt)
+                 {
+                     if (p.second >= candMinCnt && bigramDf[p.first] >= candMinDf) validPairs.emplace(p.first);
+                 }
+
+                 if (pool && pool->getNumWorkers() > 1)
+                 {
+                     using LocalFwBw = std::vector<TrieEx<Vid, size_t>>;
+                     std::vector<LocalFwBw> localdata(pool->getNumWorkers());
+                     std::vector<std::future<void>> futures;
+                     const size_t stride = pool->getNumWorkers() * 8;
+                     auto docIt = docBegin;
+                     for (size_t i = 0; i < stride && docIt != docEnd; ++i, ++docIt)
+                     {
+                         futures.emplace_back(pool->enqueue([&, docIt, stride](size_t tid)
+                         {
+                             countNgrams<false>(localdata[tid],
+                                 makeStrideIter(docIt, stride, docEnd),
+                                 makeStrideIter(docEnd, stride, docEnd),
+                                 vocabFreqs, vocabDf, validPairs, candMinCnt, candMinDf, maxNgrams
+                             );
+                         }));
+                     }
+
+                     for (auto& f : futures) f.get();
+
+                     auto r = parallelReduce(std::move(localdata), [&](LocalFwBw& dest, LocalFwBw&& src)
+                     {
+                         mergeNgramCounts(dest, std::move(src));
+                     }, pool);
+
+                     trieNodes = std::move(r);
+                 }
+                 else
+                 {
+                     countNgrams<false>(trieNodes,
+                         docBegin, docEnd,
+                         vocabFreqs, vocabDf, validPairs, candMinCnt, candMinDf, maxNgrams
+                     );
+                 }
+             }
+
+             float totN = std::accumulate(vocabFreqs.begin(), vocabFreqs.end(), (size_t)0);
+             const float logTotN = std::log(totN);
+
+             // calculating PMIs
+             std::vector<label::Candidate> candidates;
+             for (auto& p : bigramCnt)
+             {
+                 auto& bigram = p.first;
+                 if (p.second < candMinCnt) continue;
+                 if (bigramDf[bigram] < candMinDf) continue;
+                 auto pmi = std::log(p.second * totN
+                     / vocabFreqs[bigram.first] / vocabFreqs[bigram.second]);
+                 if (normalized)
+                 {
+                     pmi /= std::log(totN / p.second);
+                 }
+                 if (pmi < minScore) continue;
+                 candidates.emplace_back(pmi, bigram.first, bigram.second);
+                 candidates.back().cf = p.second;
+                 candidates.back().df = bigramDf[bigram];
+             }
+
+             if (maxNgrams > 2)
+             {
+                 std::vector<Vid> rkeys;
+                 trieNodes[0].traverse_with_keys([&](const TrieEx<Vid, size_t>* node, const std::vector<Vid>& rkeys)
+                 {
+                     if (rkeys.size() <= 2 || rkeys.size() < minNgrams || rkeys.size() > maxNgrams || node->val < candMinCnt) return;
+                     auto pmi = std::log((float)node->val) - logTotN;
+                     for (auto k : rkeys)
+                     {
+                         pmi += logTotN - std::log((float)vocabFreqs[k]);
+                     }
+                     if (normalized)
+                     {
+                         pmi /= (logTotN - std::log((float)node->val)) * (rkeys.size() - 1);
+                     }
+                     if (pmi < minScore) return;
+                     candidates.emplace_back(pmi, rkeys);
+                     candidates.back().cf = node->val;
+                 }, rkeys);
+             }
+
+             std::sort(candidates.begin(), candidates.end(), [](const label::Candidate& a, const label::Candidate& b)
+             {
+                 return a.score > b.score;
+             });
+             if (candidates.size() > maxCandidates) candidates.erase(candidates.begin() + maxCandidates, candidates.end());
+             return candidates;
+         }
+
+         template<typename _DocIter, typename _Freqs>
+         std::vector<label::Candidate> extractPMIBENgrams(_DocIter docBegin, _DocIter docEnd,
+             _Freqs&& vocabFreqs, _Freqs&& vocabDf,
+             size_t candMinCnt, size_t candMinDf, size_t minNgrams, size_t maxNgrams, size_t maxCandidates,
+             float minNPMI = 0, float minNBE = 0,
+             ThreadPool* pool = nullptr)
+         {
+             // counting unigrams & bigrams
+             std::unordered_map<std::pair<Vid, Vid>, size_t, detail::vvhash> bigramCnt, bigramDf;
+
+             if (pool && pool->getNumWorkers() > 1)
+             {
+                 using LocalCfDf = std::pair<
+                     std::unordered_map<std::pair<Vid, Vid>, size_t, detail::vvhash>,
+                     std::unordered_map<std::pair<Vid, Vid>, size_t, detail::vvhash>
+                 >;
+                 std::vector<LocalCfDf> localdata(pool->getNumWorkers());
+                 std::vector<std::future<void>> futures;
+                 const size_t stride = pool->getNumWorkers() * 8;
+                 auto docIt = docBegin;
+                 for (size_t i = 0; i < stride && docIt != docEnd; ++i, ++docIt)
+                 {
+                     futures.emplace_back(pool->enqueue([&, docIt, stride](size_t tid)
+                     {
+                         countBigrams(localdata[tid].first, localdata[tid].second, makeStrideIter(docIt, stride, docEnd), makeStrideIter(docEnd, stride, docEnd), vocabFreqs, vocabDf, candMinCnt, candMinDf);
+                     }));
+                 }
+
+                 for (auto& f : futures) f.get();
+
+                 auto r = parallelReduce(std::move(localdata), [](LocalCfDf& dest, LocalCfDf&& src)
+                 {
+                     for (auto& p : src.first) dest.first[p.first] += p.second;
+                     for (auto& p : src.second) dest.second[p.first] += p.second;
+                 }, pool);
+
+                 bigramCnt = std::move(r.first);
+                 bigramDf = std::move(r.second);
+             }
+             else
+             {
+                 countBigrams(bigramCnt, bigramDf, docBegin, docEnd, vocabFreqs, vocabDf, candMinCnt, candMinDf);
+             }
+
+             // counting ngrams
+             std::vector<TrieEx<Vid, size_t>> trieNodes, trieNodesBw;
+             if (maxNgrams > 2)
+             {
+                 std::unordered_set<std::pair<Vid, Vid>, detail::vvhash> validPairs;
+                 for (auto& p : bigramCnt)
+                 {
+                     if (p.second >= candMinCnt && bigramDf[p.first] >= candMinDf) validPairs.emplace(p.first);
+                 }
+
+                 if (pool && pool->getNumWorkers() > 1)
+                 {
+                     using LocalFwBw = std::pair<
+                         std::vector<TrieEx<Vid, size_t>>,
+                         std::vector<TrieEx<Vid, size_t>>
+                     >;
+                     std::vector<LocalFwBw> localdata(pool->getNumWorkers());
+                     std::vector<std::future<void>> futures;
+                     const size_t stride = pool->getNumWorkers() * 8;
+                     auto docIt = docBegin;
+                     for (size_t i = 0; i < stride && docIt != docEnd; ++i, ++docIt)
+                     {
+                         futures.emplace_back(pool->enqueue([&, docIt, stride](size_t tid)
+                         {
+                             countNgrams<false>(localdata[tid].first,
+                                 makeStrideIter(docIt, stride, docEnd),
+                                 makeStrideIter(docEnd, stride, docEnd),
+                                 vocabFreqs, vocabDf, validPairs, candMinCnt, candMinDf, maxNgrams + 1
+                             );
+                             countNgrams<true>(localdata[tid].second,
+                                 makeStrideIter(docIt, stride, docEnd),
+                                 makeStrideIter(docEnd, stride, docEnd),
+                                 vocabFreqs, vocabDf, validPairs, candMinCnt, candMinDf, maxNgrams + 1
+                             );
+                         }));
+                     }
+
+                     for (auto& f : futures) f.get();
+
+                     auto r = parallelReduce(std::move(localdata), [&](LocalFwBw& dest, LocalFwBw&& src)
+                     {
+                         mergeNgramCounts(dest.first, std::move(src.first));
+                         mergeNgramCounts(dest.second, std::move(src.second));
+                     }, pool);
+
+                     trieNodes = std::move(r.first);
+                     trieNodesBw = std::move(r.second);
+                 }
+                 else
+                 {
+                     countNgrams<false>(trieNodes,
+                         docBegin, docEnd,
+                         vocabFreqs, vocabDf, validPairs, candMinCnt, candMinDf, maxNgrams + 1
+                     );
+                     countNgrams<true>(trieNodesBw,
+                         docBegin, docEnd,
+                         vocabFreqs, vocabDf, validPairs, candMinCnt, candMinDf, maxNgrams + 1
+                     );
+                 }
+             }
+
+             float totN = std::accumulate(vocabFreqs.begin(), vocabFreqs.end(), (size_t)0);
+             const float logTotN = std::log(totN);
+
+             // calculating PMIs
+             std::vector<label::Candidate> candidates;
+             for (auto& p : bigramCnt)
+             {
+                 auto& bigram = p.first;
+                 if (p.second < candMinCnt) continue;
+                 if (bigramDf[bigram] < candMinDf) continue;
+                 float npmi = std::log(p.second * totN
+                     / vocabFreqs[bigram.first] / vocabFreqs[bigram.second]);
+                 npmi /= std::log(totN / p.second);
+                 if (npmi < minNPMI) continue;
+
+                 float rbe = branchingEntropy(trieNodes[0].getNext(bigram.first)->getNext(bigram.second), candMinCnt);
+                 float lbe = branchingEntropy(trieNodesBw[0].getNext(bigram.second)->getNext(bigram.first), candMinCnt);
+                 float nbe = std::sqrt(rbe * lbe) / std::log(p.second);
+                 if (nbe < minNBE) continue;
+                 candidates.emplace_back(npmi * nbe, bigram.first, bigram.second);
+                 candidates.back().cf = p.second;
+                 candidates.back().df = bigramDf[bigram];
+             }
+
+             if (maxNgrams > 2)
+             {
+                 std::vector<Vid> rkeys;
+                 trieNodes[0].traverse_with_keys([&](const TrieEx<Vid, size_t>* node, const std::vector<Vid>& rkeys)
+                 {
+                     if (rkeys.size() <= 2 || rkeys.size() < minNgrams || rkeys.size() > maxNgrams || node->val < candMinCnt) return;
+                     auto npmi = std::log((float)node->val) - logTotN;
+                     for (auto k : rkeys)
+                     {
+                         npmi += logTotN - std::log((float)vocabFreqs[k]);
+                     }
+                     npmi /= (logTotN - std::log((float)node->val)) * (rkeys.size() - 1);
+                     if (npmi < minNPMI) return;
+
+                     float rbe = branchingEntropy(node, candMinCnt);
+                     float lbe = branchingEntropy(trieNodesBw[0].findNode(rkeys.rbegin(), rkeys.rend()), candMinCnt);
+                     float nbe = std::sqrt(rbe * lbe) / std::log(node->val);
+                     if (nbe < minNBE) return;
+                     candidates.emplace_back(npmi * nbe, rkeys);
+                     candidates.back().cf = node->val;
+                 }, rkeys);
+             }
+
+             std::sort(candidates.begin(), candidates.end(), [](const label::Candidate& a, const label::Candidate& b)
+             {
+                 return a.score > b.score;
+             });
+             if (candidates.size() > maxCandidates) candidates.erase(candidates.begin() + maxCandidates, candidates.end());
+             return candidates;
+         }
+     }
+ }
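As a reading aid for the new Phraser.hpp: extractPMINgrams scores a candidate bigram by pointwise mutual information over corpus counts (optionally normalized), and extractPMIBENgrams multiplies the normalized PMI by a normalized branching entropy taken from the forward and backward tries. In the LaTeX sketch below the notation is mine, not the package's: c(w) is a unigram's corpus frequency (vocabFreqs), c(a,b) a bigram count (bigramCnt), N the total token count (totN), and H_r/H_l the right and left branching entropies returned by branchingEntropy on the forward and backward tries:

% Sketch of the scoring in extractPMINgrams / extractPMIBENgrams above.
\[
\mathrm{PMI}(a,b) = \log\frac{c(a,b)\,N}{c(a)\,c(b)},
\qquad
\mathrm{NPMI}(a,b) = \frac{\mathrm{PMI}(a,b)}{\log\bigl(N / c(a,b)\bigr)}
\]
\[
\mathrm{score}_{\text{PMI+BE}}(a,b)
  = \mathrm{NPMI}(a,b)\cdot\frac{\sqrt{H_r(a,b)\,H_l(a,b)}}{\log c(a,b)}
\]

Candidates falling below minScore (or minNPMI/minNBE) are skipped; the survivors are sorted by score and truncated to maxCandidates, carrying their collection frequency and document frequency in the cf/df fields added to Candidate in Labeler.h above.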