fasttext 0.1.2 → 0.2.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +19 -0
- data/LICENSE.txt +18 -18
- data/README.md +26 -19
- data/ext/fasttext/ext.cpp +131 -134
- data/ext/fasttext/extconf.rb +2 -4
- data/lib/fasttext/classifier.rb +23 -10
- data/lib/fasttext/model.rb +10 -0
- data/lib/fasttext/vectorizer.rb +11 -5
- data/lib/fasttext/version.rb +1 -1
- data/vendor/fastText/README.md +3 -3
- data/vendor/fastText/src/args.cc +179 -6
- data/vendor/fastText/src/args.h +29 -1
- data/vendor/fastText/src/autotune.cc +477 -0
- data/vendor/fastText/src/autotune.h +89 -0
- data/vendor/fastText/src/densematrix.cc +27 -7
- data/vendor/fastText/src/densematrix.h +10 -2
- data/vendor/fastText/src/fasttext.cc +125 -114
- data/vendor/fastText/src/fasttext.h +31 -52
- data/vendor/fastText/src/main.cc +32 -13
- data/vendor/fastText/src/meter.cc +148 -2
- data/vendor/fastText/src/meter.h +24 -2
- data/vendor/fastText/src/model.cc +0 -1
- data/vendor/fastText/src/real.h +0 -1
- data/vendor/fastText/src/utils.cc +25 -0
- data/vendor/fastText/src/utils.h +29 -0
- data/vendor/fastText/src/vector.cc +0 -1
- metadata +14 -69
- data/lib/fasttext/ext.bundle +0 -0
@@ -16,6 +16,9 @@
|
|
16
16
|
|
17
17
|
namespace fasttext {
|
18
18
|
|
19
|
+
constexpr int32_t kAllLabels = -1;
|
20
|
+
constexpr real falseNegativeScore = -1.0;
|
21
|
+
|
19
22
|
void Meter::log(
|
20
23
|
const std::vector<int32_t>& labels,
|
21
24
|
const Predictions& predictions) {
|
@@ -26,14 +29,23 @@ void Meter::log(
|
|
26
29
|
for (const auto& prediction : predictions) {
|
27
30
|
labelMetrics_[prediction.second].predicted++;
|
28
31
|
|
32
|
+
real score = std::min(std::exp(prediction.first), 1.0f);
|
33
|
+
real gold = 0.0;
|
29
34
|
if (utils::contains(labels, prediction.second)) {
|
30
35
|
labelMetrics_[prediction.second].predictedGold++;
|
31
36
|
metrics_.predictedGold++;
|
37
|
+
gold = 1.0;
|
32
38
|
}
|
39
|
+
labelMetrics_[prediction.second].scoreVsTrue.emplace_back(score, gold);
|
33
40
|
}
|
34
41
|
|
35
|
-
|
36
|
-
|
42
|
+
if (falseNegativeLabels_) {
|
43
|
+
for (const auto& label : labels) {
|
44
|
+
labelMetrics_[label].gold++;
|
45
|
+
if (!utils::containsSecond(predictions, label)) {
|
46
|
+
labelMetrics_[label].scoreVsTrue.emplace_back(falseNegativeScore, 1.0);
|
47
|
+
}
|
48
|
+
}
|
37
49
|
}
|
38
50
|
}
|
39
51
|
|
@@ -57,6 +69,15 @@ double Meter::recall() const {
|
|
57
69
|
return metrics_.recall();
|
58
70
|
}
|
59
71
|
|
72
|
+
double Meter::f1Score() const {
|
73
|
+
const double precision = this->precision();
|
74
|
+
const double recall = this->recall();
|
75
|
+
if (precision + recall != 0) {
|
76
|
+
return 2 * precision * recall / (precision + recall);
|
77
|
+
}
|
78
|
+
return std::numeric_limits<double>::quiet_NaN();
|
79
|
+
}
|
80
|
+
|
60
81
|
void Meter::writeGeneralMetrics(std::ostream& out, int32_t k) const {
|
61
82
|
out << "N"
|
62
83
|
<< "\t" << nexamples_ << std::endl;
|
@@ -65,4 +86,129 @@ void Meter::writeGeneralMetrics(std::ostream& out, int32_t k) const {
|
|
65
86
|
out << "R@" << k << "\t" << metrics_.recall() << std::endl;
|
66
87
|
}
|
67
88
|
|
89
|
+
std::vector<std::pair<uint64_t, uint64_t>> Meter::getPositiveCounts(
|
90
|
+
int32_t labelId) const {
|
91
|
+
std::vector<std::pair<uint64_t, uint64_t>> positiveCounts;
|
92
|
+
|
93
|
+
const auto& v = scoreVsTrue(labelId);
|
94
|
+
uint64_t truePositives = 0;
|
95
|
+
uint64_t falsePositives = 0;
|
96
|
+
double lastScore = falseNegativeScore - 1.0;
|
97
|
+
|
98
|
+
for (auto it = v.rbegin(); it != v.rend(); ++it) {
|
99
|
+
double score = it->first;
|
100
|
+
double gold = it->second;
|
101
|
+
if (score < 0) { // only reachable recall
|
102
|
+
break;
|
103
|
+
}
|
104
|
+
if (gold == 1.0) {
|
105
|
+
truePositives++;
|
106
|
+
} else {
|
107
|
+
falsePositives++;
|
108
|
+
}
|
109
|
+
if (score == lastScore && positiveCounts.size()) { // squeeze tied scores
|
110
|
+
positiveCounts.back() = {truePositives, falsePositives};
|
111
|
+
} else {
|
112
|
+
positiveCounts.emplace_back(truePositives, falsePositives);
|
113
|
+
}
|
114
|
+
lastScore = score;
|
115
|
+
}
|
116
|
+
|
117
|
+
return positiveCounts;
|
118
|
+
}
|
119
|
+
|
120
|
+
double Meter::precisionAtRecall(double recallQuery) const {
|
121
|
+
return precisionAtRecall(kAllLabels, recallQuery);
|
122
|
+
}
|
123
|
+
|
124
|
+
double Meter::precisionAtRecall(int32_t labelId, double recallQuery) const {
|
125
|
+
const auto& precisionRecall = precisionRecallCurve(labelId);
|
126
|
+
double bestPrecision = 0.0;
|
127
|
+
std::for_each(
|
128
|
+
precisionRecall.begin(),
|
129
|
+
precisionRecall.end(),
|
130
|
+
[&bestPrecision, recallQuery](const std::pair<double, double>& element) {
|
131
|
+
if (element.second >= recallQuery) {
|
132
|
+
bestPrecision = std::max(bestPrecision, element.first);
|
133
|
+
};
|
134
|
+
});
|
135
|
+
return bestPrecision;
|
136
|
+
}
|
137
|
+
|
138
|
+
double Meter::recallAtPrecision(double precisionQuery) const {
|
139
|
+
return recallAtPrecision(kAllLabels, precisionQuery);
|
140
|
+
}
|
141
|
+
|
142
|
+
double Meter::recallAtPrecision(int32_t labelId, double precisionQuery) const {
|
143
|
+
const auto& precisionRecall = precisionRecallCurve(labelId);
|
144
|
+
double bestRecall = 0.0;
|
145
|
+
std::for_each(
|
146
|
+
precisionRecall.begin(),
|
147
|
+
precisionRecall.end(),
|
148
|
+
[&bestRecall, precisionQuery](const std::pair<double, double>& element) {
|
149
|
+
if (element.first >= precisionQuery) {
|
150
|
+
bestRecall = std::max(bestRecall, element.second);
|
151
|
+
};
|
152
|
+
});
|
153
|
+
return bestRecall;
|
154
|
+
}
|
155
|
+
|
156
|
+
std::vector<std::pair<double, double>> Meter::precisionRecallCurve() const {
|
157
|
+
return precisionRecallCurve(kAllLabels);
|
158
|
+
}
|
159
|
+
|
160
|
+
std::vector<std::pair<double, double>> Meter::precisionRecallCurve(
|
161
|
+
int32_t labelId) const {
|
162
|
+
std::vector<std::pair<double, double>> precisionRecallCurve;
|
163
|
+
const auto& positiveCounts = getPositiveCounts(labelId);
|
164
|
+
if (positiveCounts.empty()) {
|
165
|
+
return precisionRecallCurve;
|
166
|
+
}
|
167
|
+
|
168
|
+
uint64_t golds =
|
169
|
+
(labelId == kAllLabels) ? metrics_.gold : labelMetrics_.at(labelId).gold;
|
170
|
+
|
171
|
+
auto fullRecall = std::lower_bound(
|
172
|
+
positiveCounts.begin(),
|
173
|
+
positiveCounts.end(),
|
174
|
+
golds,
|
175
|
+
utils::compareFirstLess);
|
176
|
+
|
177
|
+
if (fullRecall != positiveCounts.end()) {
|
178
|
+
fullRecall = std::next(fullRecall);
|
179
|
+
}
|
180
|
+
|
181
|
+
for (auto it = positiveCounts.begin(); it != fullRecall; it++) {
|
182
|
+
double precision = 0.0;
|
183
|
+
double truePositives = it->first;
|
184
|
+
double falsePositives = it->second;
|
185
|
+
if (truePositives + falsePositives != 0.0) {
|
186
|
+
precision = truePositives / (truePositives + falsePositives);
|
187
|
+
}
|
188
|
+
double recall = golds != 0 ? (truePositives / double(golds))
|
189
|
+
: std::numeric_limits<double>::quiet_NaN();
|
190
|
+
precisionRecallCurve.emplace_back(precision, recall);
|
191
|
+
}
|
192
|
+
precisionRecallCurve.emplace_back(1.0, 0.0);
|
193
|
+
|
194
|
+
return precisionRecallCurve;
|
195
|
+
}
|
196
|
+
|
197
|
+
std::vector<std::pair<real, real>> Meter::scoreVsTrue(int32_t labelId) const {
|
198
|
+
std::vector<std::pair<real, real>> ret;
|
199
|
+
if (labelId == kAllLabels) {
|
200
|
+
for (const auto& k : labelMetrics_) {
|
201
|
+
auto& labelScoreVsTrue = labelMetrics_.at(k.first).scoreVsTrue;
|
202
|
+
ret.insert(ret.end(), labelScoreVsTrue.begin(), labelScoreVsTrue.end());
|
203
|
+
}
|
204
|
+
} else {
|
205
|
+
if (labelMetrics_.count(labelId)) {
|
206
|
+
ret = labelMetrics_.at(labelId).scoreVsTrue;
|
207
|
+
}
|
208
|
+
}
|
209
|
+
sort(ret.begin(), ret.end());
|
210
|
+
|
211
|
+
return ret;
|
212
|
+
}
|
213
|
+
|
68
214
|
} // namespace fasttext
|
data/vendor/fastText/src/meter.h
CHANGED
@@ -22,8 +22,9 @@ class Meter {
|
|
22
22
|
uint64_t gold;
|
23
23
|
uint64_t predicted;
|
24
24
|
uint64_t predictedGold;
|
25
|
+
mutable std::vector<std::pair<real, real>> scoreVsTrue;
|
25
26
|
|
26
|
-
Metrics() : gold(0), predicted(0), predictedGold(0) {}
|
27
|
+
Metrics() : gold(0), predicted(0), predictedGold(0), scoreVsTrue() {}
|
27
28
|
|
28
29
|
double precision() const {
|
29
30
|
if (predicted == 0) {
|
@@ -43,18 +44,38 @@ class Meter {
|
|
43
44
|
}
|
44
45
|
return 2 * predictedGold / double(predicted + gold);
|
45
46
|
}
|
47
|
+
|
48
|
+
std::vector<std::pair<real, real>> getScoreVsTrue() {
|
49
|
+
return scoreVsTrue;
|
50
|
+
}
|
46
51
|
};
|
52
|
+
std::vector<std::pair<uint64_t, uint64_t>> getPositiveCounts(
|
53
|
+
int32_t labelId) const;
|
47
54
|
|
48
55
|
public:
|
49
|
-
Meter()
|
56
|
+
Meter() = delete;
|
57
|
+
explicit Meter(bool falseNegativeLabels)
|
58
|
+
: metrics_(),
|
59
|
+
nexamples_(0),
|
60
|
+
labelMetrics_(),
|
61
|
+
falseNegativeLabels_(falseNegativeLabels) {}
|
50
62
|
|
51
63
|
void log(const std::vector<int32_t>& labels, const Predictions& predictions);
|
52
64
|
|
53
65
|
double precision(int32_t);
|
54
66
|
double recall(int32_t);
|
55
67
|
double f1Score(int32_t);
|
68
|
+
std::vector<std::pair<real, real>> scoreVsTrue(int32_t labelId) const;
|
69
|
+
double precisionAtRecall(int32_t labelId, double recall) const;
|
70
|
+
double precisionAtRecall(double recall) const;
|
71
|
+
double recallAtPrecision(int32_t labelId, double recall) const;
|
72
|
+
double recallAtPrecision(double recall) const;
|
73
|
+
std::vector<std::pair<double, double>> precisionRecallCurve(
|
74
|
+
int32_t labelId) const;
|
75
|
+
std::vector<std::pair<double, double>> precisionRecallCurve() const;
|
56
76
|
double precision() const;
|
57
77
|
double recall() const;
|
78
|
+
double f1Score() const;
|
58
79
|
uint64_t nexamples() const {
|
59
80
|
return nexamples_;
|
60
81
|
}
|
@@ -64,6 +85,7 @@ class Meter {
|
|
64
85
|
Metrics metrics_{};
|
65
86
|
uint64_t nexamples_;
|
66
87
|
std::unordered_map<int32_t, Metrics> labelMetrics_;
|
88
|
+
bool falseNegativeLabels_;
|
67
89
|
};
|
68
90
|
|
69
91
|
} // namespace fasttext
|
data/vendor/fastText/src/real.h
CHANGED
@@ -8,6 +8,7 @@
|
|
8
8
|
|
9
9
|
#include "utils.h"
|
10
10
|
|
11
|
+
#include <iomanip>
|
11
12
|
#include <ios>
|
12
13
|
|
13
14
|
namespace fasttext {
|
@@ -23,6 +24,30 @@ void seek(std::ifstream& ifs, int64_t pos) {
|
|
23
24
|
ifs.clear();
|
24
25
|
ifs.seekg(std::streampos(pos));
|
25
26
|
}
|
27
|
+
|
28
|
+
double getDuration(
|
29
|
+
const std::chrono::steady_clock::time_point& start,
|
30
|
+
const std::chrono::steady_clock::time_point& end) {
|
31
|
+
return std::chrono::duration_cast<std::chrono::duration<double>>(end - start)
|
32
|
+
.count();
|
33
|
+
}
|
34
|
+
|
35
|
+
ClockPrint::ClockPrint(int32_t duration) : duration_(duration) {}
|
36
|
+
|
37
|
+
std::ostream& operator<<(std::ostream& out, const ClockPrint& me) {
|
38
|
+
int32_t etah = me.duration_ / 3600;
|
39
|
+
int32_t etam = (me.duration_ % 3600) / 60;
|
40
|
+
int32_t etas = (me.duration_ % 3600) % 60;
|
41
|
+
|
42
|
+
out << std::setw(3) << etah << "h" << std::setw(2) << etam << "m";
|
43
|
+
out << std::setw(2) << etas << "s";
|
44
|
+
return out;
|
45
|
+
}
|
46
|
+
|
47
|
+
bool compareFirstLess(const std::pair<double, double>& l, const double& r) {
|
48
|
+
return l.first < r;
|
49
|
+
}
|
50
|
+
|
26
51
|
} // namespace utils
|
27
52
|
|
28
53
|
} // namespace fasttext
|
data/vendor/fastText/src/utils.h
CHANGED
@@ -11,7 +11,9 @@
|
|
11
11
|
#include "real.h"
|
12
12
|
|
13
13
|
#include <algorithm>
|
14
|
+
#include <chrono>
|
14
15
|
#include <fstream>
|
16
|
+
#include <ostream>
|
15
17
|
#include <vector>
|
16
18
|
|
17
19
|
#if defined(__clang__) || defined(__GNUC__)
|
@@ -38,6 +40,33 @@ bool contains(const std::vector<T>& container, const T& value) {
|
|
38
40
|
container.end();
|
39
41
|
}
|
40
42
|
|
43
|
+
template <typename T1, typename T2>
|
44
|
+
bool containsSecond(
|
45
|
+
const std::vector<std::pair<T1, T2>>& container,
|
46
|
+
const T2& value) {
|
47
|
+
return std::find_if(
|
48
|
+
container.begin(),
|
49
|
+
container.end(),
|
50
|
+
[&value](const std::pair<T1, T2>& item) {
|
51
|
+
return item.second == value;
|
52
|
+
}) != container.end();
|
53
|
+
}
|
54
|
+
|
55
|
+
double getDuration(
|
56
|
+
const std::chrono::steady_clock::time_point& start,
|
57
|
+
const std::chrono::steady_clock::time_point& end);
|
58
|
+
|
59
|
+
class ClockPrint {
|
60
|
+
public:
|
61
|
+
explicit ClockPrint(int32_t duration);
|
62
|
+
friend std::ostream& operator<<(std::ostream& out, const ClockPrint& me);
|
63
|
+
|
64
|
+
private:
|
65
|
+
int32_t duration_;
|
66
|
+
};
|
67
|
+
|
68
|
+
bool compareFirstLess(const std::pair<double, double>& l, const double& r);
|
69
|
+
|
41
70
|
} // namespace utils
|
42
71
|
|
43
72
|
} // namespace fasttext
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: fasttext
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.2.2
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
|
-
autorequire:
|
8
|
+
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date:
|
11
|
+
date: 2021-10-16 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -16,72 +16,16 @@ dependencies:
|
|
16
16
|
requirements:
|
17
17
|
- - ">="
|
18
18
|
- !ruby/object:Gem::Version
|
19
|
-
version:
|
19
|
+
version: 4.0.2
|
20
20
|
type: :runtime
|
21
21
|
prerelease: false
|
22
22
|
version_requirements: !ruby/object:Gem::Requirement
|
23
23
|
requirements:
|
24
24
|
- - ">="
|
25
25
|
- !ruby/object:Gem::Version
|
26
|
-
version:
|
27
|
-
|
28
|
-
|
29
|
-
requirement: !ruby/object:Gem::Requirement
|
30
|
-
requirements:
|
31
|
-
- - ">="
|
32
|
-
- !ruby/object:Gem::Version
|
33
|
-
version: '0'
|
34
|
-
type: :development
|
35
|
-
prerelease: false
|
36
|
-
version_requirements: !ruby/object:Gem::Requirement
|
37
|
-
requirements:
|
38
|
-
- - ">="
|
39
|
-
- !ruby/object:Gem::Version
|
40
|
-
version: '0'
|
41
|
-
- !ruby/object:Gem::Dependency
|
42
|
-
name: rake
|
43
|
-
requirement: !ruby/object:Gem::Requirement
|
44
|
-
requirements:
|
45
|
-
- - ">="
|
46
|
-
- !ruby/object:Gem::Version
|
47
|
-
version: '0'
|
48
|
-
type: :development
|
49
|
-
prerelease: false
|
50
|
-
version_requirements: !ruby/object:Gem::Requirement
|
51
|
-
requirements:
|
52
|
-
- - ">="
|
53
|
-
- !ruby/object:Gem::Version
|
54
|
-
version: '0'
|
55
|
-
- !ruby/object:Gem::Dependency
|
56
|
-
name: rake-compiler
|
57
|
-
requirement: !ruby/object:Gem::Requirement
|
58
|
-
requirements:
|
59
|
-
- - ">="
|
60
|
-
- !ruby/object:Gem::Version
|
61
|
-
version: '0'
|
62
|
-
type: :development
|
63
|
-
prerelease: false
|
64
|
-
version_requirements: !ruby/object:Gem::Requirement
|
65
|
-
requirements:
|
66
|
-
- - ">="
|
67
|
-
- !ruby/object:Gem::Version
|
68
|
-
version: '0'
|
69
|
-
- !ruby/object:Gem::Dependency
|
70
|
-
name: minitest
|
71
|
-
requirement: !ruby/object:Gem::Requirement
|
72
|
-
requirements:
|
73
|
-
- - ">="
|
74
|
-
- !ruby/object:Gem::Version
|
75
|
-
version: '5'
|
76
|
-
type: :development
|
77
|
-
prerelease: false
|
78
|
-
version_requirements: !ruby/object:Gem::Requirement
|
79
|
-
requirements:
|
80
|
-
- - ">="
|
81
|
-
- !ruby/object:Gem::Version
|
82
|
-
version: '5'
|
83
|
-
description:
|
84
|
-
email: andrew@chartkick.com
|
26
|
+
version: 4.0.2
|
27
|
+
description:
|
28
|
+
email: andrew@ankane.org
|
85
29
|
executables: []
|
86
30
|
extensions:
|
87
31
|
- ext/fasttext/extconf.rb
|
@@ -94,7 +38,6 @@ files:
|
|
94
38
|
- ext/fasttext/extconf.rb
|
95
39
|
- lib/fasttext.rb
|
96
40
|
- lib/fasttext/classifier.rb
|
97
|
-
- lib/fasttext/ext.bundle
|
98
41
|
- lib/fasttext/model.rb
|
99
42
|
- lib/fasttext/vectorizer.rb
|
100
43
|
- lib/fasttext/version.rb
|
@@ -102,6 +45,8 @@ files:
|
|
102
45
|
- vendor/fastText/README.md
|
103
46
|
- vendor/fastText/src/args.cc
|
104
47
|
- vendor/fastText/src/args.h
|
48
|
+
- vendor/fastText/src/autotune.cc
|
49
|
+
- vendor/fastText/src/autotune.h
|
105
50
|
- vendor/fastText/src/densematrix.cc
|
106
51
|
- vendor/fastText/src/densematrix.h
|
107
52
|
- vendor/fastText/src/dictionary.cc
|
@@ -126,11 +71,11 @@ files:
|
|
126
71
|
- vendor/fastText/src/utils.h
|
127
72
|
- vendor/fastText/src/vector.cc
|
128
73
|
- vendor/fastText/src/vector.h
|
129
|
-
homepage: https://github.com/ankane/
|
74
|
+
homepage: https://github.com/ankane/fastText
|
130
75
|
licenses:
|
131
76
|
- MIT
|
132
77
|
metadata: {}
|
133
|
-
post_install_message:
|
78
|
+
post_install_message:
|
134
79
|
rdoc_options: []
|
135
80
|
require_paths:
|
136
81
|
- lib
|
@@ -138,15 +83,15 @@ required_ruby_version: !ruby/object:Gem::Requirement
|
|
138
83
|
requirements:
|
139
84
|
- - ">="
|
140
85
|
- !ruby/object:Gem::Version
|
141
|
-
version: '2.
|
86
|
+
version: '2.6'
|
142
87
|
required_rubygems_version: !ruby/object:Gem::Requirement
|
143
88
|
requirements:
|
144
89
|
- - ">="
|
145
90
|
- !ruby/object:Gem::Version
|
146
91
|
version: '0'
|
147
92
|
requirements: []
|
148
|
-
rubygems_version: 3.
|
149
|
-
signing_key:
|
93
|
+
rubygems_version: 3.2.22
|
94
|
+
signing_key:
|
150
95
|
specification_version: 4
|
151
96
|
summary: fastText - efficient text classification and representation learning - for
|
152
97
|
Ruby
|
data/lib/fasttext/ext.bundle
DELETED
Binary file
|