fasttext 0.1.2 → 0.2.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -16,6 +16,9 @@
16
16
 
17
17
  namespace fasttext {
18
18
 
19
+ constexpr int32_t kAllLabels = -1;
20
+ constexpr real falseNegativeScore = -1.0;
21
+
19
22
  void Meter::log(
20
23
  const std::vector<int32_t>& labels,
21
24
  const Predictions& predictions) {
@@ -26,14 +29,23 @@ void Meter::log(
26
29
  for (const auto& prediction : predictions) {
27
30
  labelMetrics_[prediction.second].predicted++;
28
31
 
32
+ real score = std::min(std::exp(prediction.first), 1.0f);
33
+ real gold = 0.0;
29
34
  if (utils::contains(labels, prediction.second)) {
30
35
  labelMetrics_[prediction.second].predictedGold++;
31
36
  metrics_.predictedGold++;
37
+ gold = 1.0;
32
38
  }
39
+ labelMetrics_[prediction.second].scoreVsTrue.emplace_back(score, gold);
33
40
  }
34
41
 
35
- for (const auto& label : labels) {
36
- labelMetrics_[label].gold++;
42
+ if (falseNegativeLabels_) {
43
+ for (const auto& label : labels) {
44
+ labelMetrics_[label].gold++;
45
+ if (!utils::containsSecond(predictions, label)) {
46
+ labelMetrics_[label].scoreVsTrue.emplace_back(falseNegativeScore, 1.0);
47
+ }
48
+ }
37
49
  }
38
50
  }
39
51
 
@@ -57,6 +69,15 @@ double Meter::recall() const {
57
69
  return metrics_.recall();
58
70
  }
59
71
 
72
+ double Meter::f1Score() const {
73
+ const double precision = this->precision();
74
+ const double recall = this->recall();
75
+ if (precision + recall != 0) {
76
+ return 2 * precision * recall / (precision + recall);
77
+ }
78
+ return std::numeric_limits<double>::quiet_NaN();
79
+ }
80
+
60
81
  void Meter::writeGeneralMetrics(std::ostream& out, int32_t k) const {
61
82
  out << "N"
62
83
  << "\t" << nexamples_ << std::endl;
@@ -65,4 +86,129 @@ void Meter::writeGeneralMetrics(std::ostream& out, int32_t k) const {
65
86
  out << "R@" << k << "\t" << metrics_.recall() << std::endl;
66
87
  }
67
88
 
89
+ std::vector<std::pair<uint64_t, uint64_t>> Meter::getPositiveCounts(
90
+ int32_t labelId) const {
91
+ std::vector<std::pair<uint64_t, uint64_t>> positiveCounts;
92
+
93
+ const auto& v = scoreVsTrue(labelId);
94
+ uint64_t truePositives = 0;
95
+ uint64_t falsePositives = 0;
96
+ double lastScore = falseNegativeScore - 1.0;
97
+
98
+ for (auto it = v.rbegin(); it != v.rend(); ++it) {
99
+ double score = it->first;
100
+ double gold = it->second;
101
+ if (score < 0) { // only reachable recall
102
+ break;
103
+ }
104
+ if (gold == 1.0) {
105
+ truePositives++;
106
+ } else {
107
+ falsePositives++;
108
+ }
109
+ if (score == lastScore && positiveCounts.size()) { // squeeze tied scores
110
+ positiveCounts.back() = {truePositives, falsePositives};
111
+ } else {
112
+ positiveCounts.emplace_back(truePositives, falsePositives);
113
+ }
114
+ lastScore = score;
115
+ }
116
+
117
+ return positiveCounts;
118
+ }
119
+
120
+ double Meter::precisionAtRecall(double recallQuery) const {
121
+ return precisionAtRecall(kAllLabels, recallQuery);
122
+ }
123
+
124
+ double Meter::precisionAtRecall(int32_t labelId, double recallQuery) const {
125
+ const auto& precisionRecall = precisionRecallCurve(labelId);
126
+ double bestPrecision = 0.0;
127
+ std::for_each(
128
+ precisionRecall.begin(),
129
+ precisionRecall.end(),
130
+ [&bestPrecision, recallQuery](const std::pair<double, double>& element) {
131
+ if (element.second >= recallQuery) {
132
+ bestPrecision = std::max(bestPrecision, element.first);
133
+ };
134
+ });
135
+ return bestPrecision;
136
+ }
137
+
138
+ double Meter::recallAtPrecision(double precisionQuery) const {
139
+ return recallAtPrecision(kAllLabels, precisionQuery);
140
+ }
141
+
142
+ double Meter::recallAtPrecision(int32_t labelId, double precisionQuery) const {
143
+ const auto& precisionRecall = precisionRecallCurve(labelId);
144
+ double bestRecall = 0.0;
145
+ std::for_each(
146
+ precisionRecall.begin(),
147
+ precisionRecall.end(),
148
+ [&bestRecall, precisionQuery](const std::pair<double, double>& element) {
149
+ if (element.first >= precisionQuery) {
150
+ bestRecall = std::max(bestRecall, element.second);
151
+ };
152
+ });
153
+ return bestRecall;
154
+ }
155
+
156
+ std::vector<std::pair<double, double>> Meter::precisionRecallCurve() const {
157
+ return precisionRecallCurve(kAllLabels);
158
+ }
159
+
160
+ std::vector<std::pair<double, double>> Meter::precisionRecallCurve(
161
+ int32_t labelId) const {
162
+ std::vector<std::pair<double, double>> precisionRecallCurve;
163
+ const auto& positiveCounts = getPositiveCounts(labelId);
164
+ if (positiveCounts.empty()) {
165
+ return precisionRecallCurve;
166
+ }
167
+
168
+ uint64_t golds =
169
+ (labelId == kAllLabels) ? metrics_.gold : labelMetrics_.at(labelId).gold;
170
+
171
+ auto fullRecall = std::lower_bound(
172
+ positiveCounts.begin(),
173
+ positiveCounts.end(),
174
+ golds,
175
+ utils::compareFirstLess);
176
+
177
+ if (fullRecall != positiveCounts.end()) {
178
+ fullRecall = std::next(fullRecall);
179
+ }
180
+
181
+ for (auto it = positiveCounts.begin(); it != fullRecall; it++) {
182
+ double precision = 0.0;
183
+ double truePositives = it->first;
184
+ double falsePositives = it->second;
185
+ if (truePositives + falsePositives != 0.0) {
186
+ precision = truePositives / (truePositives + falsePositives);
187
+ }
188
+ double recall = golds != 0 ? (truePositives / double(golds))
189
+ : std::numeric_limits<double>::quiet_NaN();
190
+ precisionRecallCurve.emplace_back(precision, recall);
191
+ }
192
+ precisionRecallCurve.emplace_back(1.0, 0.0);
193
+
194
+ return precisionRecallCurve;
195
+ }
196
+
197
+ std::vector<std::pair<real, real>> Meter::scoreVsTrue(int32_t labelId) const {
198
+ std::vector<std::pair<real, real>> ret;
199
+ if (labelId == kAllLabels) {
200
+ for (const auto& k : labelMetrics_) {
201
+ auto& labelScoreVsTrue = labelMetrics_.at(k.first).scoreVsTrue;
202
+ ret.insert(ret.end(), labelScoreVsTrue.begin(), labelScoreVsTrue.end());
203
+ }
204
+ } else {
205
+ if (labelMetrics_.count(labelId)) {
206
+ ret = labelMetrics_.at(labelId).scoreVsTrue;
207
+ }
208
+ }
209
+ sort(ret.begin(), ret.end());
210
+
211
+ return ret;
212
+ }
213
+
68
214
  } // namespace fasttext
@@ -22,8 +22,9 @@ class Meter {
22
22
  uint64_t gold;
23
23
  uint64_t predicted;
24
24
  uint64_t predictedGold;
25
+ mutable std::vector<std::pair<real, real>> scoreVsTrue;
25
26
 
26
- Metrics() : gold(0), predicted(0), predictedGold(0) {}
27
+ Metrics() : gold(0), predicted(0), predictedGold(0), scoreVsTrue() {}
27
28
 
28
29
  double precision() const {
29
30
  if (predicted == 0) {
@@ -43,18 +44,38 @@ class Meter {
43
44
  }
44
45
  return 2 * predictedGold / double(predicted + gold);
45
46
  }
47
+
48
+ std::vector<std::pair<real, real>> getScoreVsTrue() {
49
+ return scoreVsTrue;
50
+ }
46
51
  };
52
+ std::vector<std::pair<uint64_t, uint64_t>> getPositiveCounts(
53
+ int32_t labelId) const;
47
54
 
48
55
  public:
49
- Meter() : metrics_(), nexamples_(0), labelMetrics_() {}
56
+ Meter() = delete;
57
+ explicit Meter(bool falseNegativeLabels)
58
+ : metrics_(),
59
+ nexamples_(0),
60
+ labelMetrics_(),
61
+ falseNegativeLabels_(falseNegativeLabels) {}
50
62
 
51
63
  void log(const std::vector<int32_t>& labels, const Predictions& predictions);
52
64
 
53
65
  double precision(int32_t);
54
66
  double recall(int32_t);
55
67
  double f1Score(int32_t);
68
+ std::vector<std::pair<real, real>> scoreVsTrue(int32_t labelId) const;
69
+ double precisionAtRecall(int32_t labelId, double recall) const;
70
+ double precisionAtRecall(double recall) const;
71
+ double recallAtPrecision(int32_t labelId, double recall) const;
72
+ double recallAtPrecision(double recall) const;
73
+ std::vector<std::pair<double, double>> precisionRecallCurve(
74
+ int32_t labelId) const;
75
+ std::vector<std::pair<double, double>> precisionRecallCurve() const;
56
76
  double precision() const;
57
77
  double recall() const;
78
+ double f1Score() const;
58
79
  uint64_t nexamples() const {
59
80
  return nexamples_;
60
81
  }
@@ -64,6 +85,7 @@ class Meter {
64
85
  Metrics metrics_{};
65
86
  uint64_t nexamples_;
66
87
  std::unordered_map<int32_t, Metrics> labelMetrics_;
88
+ bool falseNegativeLabels_;
67
89
  };
68
90
 
69
91
  } // namespace fasttext
@@ -10,7 +10,6 @@
10
10
  #include "loss.h"
11
11
  #include "utils.h"
12
12
 
13
- #include <assert.h>
14
13
  #include <algorithm>
15
14
  #include <stdexcept>
16
15
 
@@ -11,5 +11,4 @@
11
11
  namespace fasttext {
12
12
 
13
13
  typedef float real;
14
-
15
14
  }
@@ -8,6 +8,7 @@
8
8
 
9
9
  #include "utils.h"
10
10
 
11
+ #include <iomanip>
11
12
  #include <ios>
12
13
 
13
14
  namespace fasttext {
@@ -23,6 +24,30 @@ void seek(std::ifstream& ifs, int64_t pos) {
23
24
  ifs.clear();
24
25
  ifs.seekg(std::streampos(pos));
25
26
  }
27
+
28
// Elapsed time from `start` to `end`, in seconds, as a double. The result is
// negative when `end` precedes `start`.
double getDuration(
    const std::chrono::steady_clock::time_point& start,
    const std::chrono::steady_clock::time_point& end) {
  using Seconds = std::chrono::duration<double>;
  const Seconds elapsed = std::chrono::duration_cast<Seconds>(end - start);
  return elapsed.count();
}
34
+
35
+ ClockPrint::ClockPrint(int32_t duration) : duration_(duration) {}
36
+
37
+ std::ostream& operator<<(std::ostream& out, const ClockPrint& me) {
38
+ int32_t etah = me.duration_ / 3600;
39
+ int32_t etam = (me.duration_ % 3600) / 60;
40
+ int32_t etas = (me.duration_ % 3600) % 60;
41
+
42
+ out << std::setw(3) << etah << "h" << std::setw(2) << etam << "m";
43
+ out << std::setw(2) << etas << "s";
44
+ return out;
45
+ }
46
+
47
// Comparator for std::lower_bound-style searches: true when the first
// member of `l` is strictly less than `r`.
bool compareFirstLess(const std::pair<double, double>& l, const double& r) {
  const double key = l.first;
  return key < r;
}
50
+
26
51
  } // namespace utils
27
52
 
28
53
  } // namespace fasttext
@@ -11,7 +11,9 @@
11
11
  #include "real.h"
12
12
 
13
13
  #include <algorithm>
14
+ #include <chrono>
14
15
  #include <fstream>
16
+ #include <ostream>
15
17
  #include <vector>
16
18
 
17
19
  #if defined(__clang__) || defined(__GNUC__)
@@ -38,6 +40,33 @@ bool contains(const std::vector<T>& container, const T& value) {
38
40
  container.end();
39
41
  }
40
42
 
43
// True when any pair in `container` has a `second` member equal to `value`.
template <typename T1, typename T2>
bool containsSecond(
    const std::vector<std::pair<T1, T2>>& container,
    const T2& value) {
  for (const std::pair<T1, T2>& item : container) {
    if (item.second == value) {
      return true;
    }
  }
  return false;
}
54
+
55
+ double getDuration(
56
+ const std::chrono::steady_clock::time_point& start,
57
+ const std::chrono::steady_clock::time_point& end);
58
+
59
// Wraps an integer duration so that it can be streamed with operator<<,
// which formats it as hours, minutes and seconds.
class ClockPrint {
  int32_t duration_;

 public:
  explicit ClockPrint(int32_t duration);
  friend std::ostream& operator<<(std::ostream& out, const ClockPrint& me);
};
67
+
68
+ bool compareFirstLess(const std::pair<double, double>& l, const double& r);
69
+
41
70
  } // namespace utils
42
71
 
43
72
  } // namespace fasttext
@@ -12,7 +12,6 @@
12
12
 
13
13
  #include <cmath>
14
14
  #include <iomanip>
15
- #include <utility>
16
15
 
17
16
  #include "matrix.h"
18
17
 
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: fasttext
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.2
4
+ version: 0.2.2
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
- autorequire:
8
+ autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-01-11 00:00:00.000000000 Z
11
+ date: 2021-10-16 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -16,72 +16,16 @@ dependencies:
16
16
  requirements:
17
17
  - - ">="
18
18
  - !ruby/object:Gem::Version
19
- version: '2.2'
19
+ version: 4.0.2
20
20
  type: :runtime
21
21
  prerelease: false
22
22
  version_requirements: !ruby/object:Gem::Requirement
23
23
  requirements:
24
24
  - - ">="
25
25
  - !ruby/object:Gem::Version
26
- version: '2.2'
27
- - !ruby/object:Gem::Dependency
28
- name: bundler
29
- requirement: !ruby/object:Gem::Requirement
30
- requirements:
31
- - - ">="
32
- - !ruby/object:Gem::Version
33
- version: '0'
34
- type: :development
35
- prerelease: false
36
- version_requirements: !ruby/object:Gem::Requirement
37
- requirements:
38
- - - ">="
39
- - !ruby/object:Gem::Version
40
- version: '0'
41
- - !ruby/object:Gem::Dependency
42
- name: rake
43
- requirement: !ruby/object:Gem::Requirement
44
- requirements:
45
- - - ">="
46
- - !ruby/object:Gem::Version
47
- version: '0'
48
- type: :development
49
- prerelease: false
50
- version_requirements: !ruby/object:Gem::Requirement
51
- requirements:
52
- - - ">="
53
- - !ruby/object:Gem::Version
54
- version: '0'
55
- - !ruby/object:Gem::Dependency
56
- name: rake-compiler
57
- requirement: !ruby/object:Gem::Requirement
58
- requirements:
59
- - - ">="
60
- - !ruby/object:Gem::Version
61
- version: '0'
62
- type: :development
63
- prerelease: false
64
- version_requirements: !ruby/object:Gem::Requirement
65
- requirements:
66
- - - ">="
67
- - !ruby/object:Gem::Version
68
- version: '0'
69
- - !ruby/object:Gem::Dependency
70
- name: minitest
71
- requirement: !ruby/object:Gem::Requirement
72
- requirements:
73
- - - ">="
74
- - !ruby/object:Gem::Version
75
- version: '5'
76
- type: :development
77
- prerelease: false
78
- version_requirements: !ruby/object:Gem::Requirement
79
- requirements:
80
- - - ">="
81
- - !ruby/object:Gem::Version
82
- version: '5'
83
- description:
84
- email: andrew@chartkick.com
26
+ version: 4.0.2
27
+ description:
28
+ email: andrew@ankane.org
85
29
  executables: []
86
30
  extensions:
87
31
  - ext/fasttext/extconf.rb
@@ -94,7 +38,6 @@ files:
94
38
  - ext/fasttext/extconf.rb
95
39
  - lib/fasttext.rb
96
40
  - lib/fasttext/classifier.rb
97
- - lib/fasttext/ext.bundle
98
41
  - lib/fasttext/model.rb
99
42
  - lib/fasttext/vectorizer.rb
100
43
  - lib/fasttext/version.rb
@@ -102,6 +45,8 @@ files:
102
45
  - vendor/fastText/README.md
103
46
  - vendor/fastText/src/args.cc
104
47
  - vendor/fastText/src/args.h
48
+ - vendor/fastText/src/autotune.cc
49
+ - vendor/fastText/src/autotune.h
105
50
  - vendor/fastText/src/densematrix.cc
106
51
  - vendor/fastText/src/densematrix.h
107
52
  - vendor/fastText/src/dictionary.cc
@@ -126,11 +71,11 @@ files:
126
71
  - vendor/fastText/src/utils.h
127
72
  - vendor/fastText/src/vector.cc
128
73
  - vendor/fastText/src/vector.h
129
- homepage: https://github.com/ankane/fasttext
74
+ homepage: https://github.com/ankane/fastText
130
75
  licenses:
131
76
  - MIT
132
77
  metadata: {}
133
- post_install_message:
78
+ post_install_message:
134
79
  rdoc_options: []
135
80
  require_paths:
136
81
  - lib
@@ -138,15 +83,15 @@ required_ruby_version: !ruby/object:Gem::Requirement
138
83
  requirements:
139
84
  - - ">="
140
85
  - !ruby/object:Gem::Version
141
- version: '2.4'
86
+ version: '2.6'
142
87
  required_rubygems_version: !ruby/object:Gem::Requirement
143
88
  requirements:
144
89
  - - ">="
145
90
  - !ruby/object:Gem::Version
146
91
  version: '0'
147
92
  requirements: []
148
- rubygems_version: 3.1.2
149
- signing_key:
93
+ rubygems_version: 3.2.22
94
+ signing_key:
150
95
  specification_version: 4
151
96
  summary: fastText - efficient text classification and representation learning - for
152
97
  Ruby
Binary file