fasttext 0.1.2 → 0.1.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +18 -8
- data/ext/fasttext/ext.cpp +66 -35
- data/ext/fasttext/extconf.rb +2 -3
- data/lib/fasttext/classifier.rb +13 -3
- data/lib/fasttext/vectorizer.rb +6 -1
- data/lib/fasttext/version.rb +1 -1
- data/vendor/fastText/README.md +3 -3
- data/vendor/fastText/src/args.cc +179 -6
- data/vendor/fastText/src/args.h +29 -1
- data/vendor/fastText/src/autotune.cc +477 -0
- data/vendor/fastText/src/autotune.h +89 -0
- data/vendor/fastText/src/densematrix.cc +27 -7
- data/vendor/fastText/src/densematrix.h +10 -2
- data/vendor/fastText/src/fasttext.cc +125 -114
- data/vendor/fastText/src/fasttext.h +31 -52
- data/vendor/fastText/src/main.cc +32 -13
- data/vendor/fastText/src/meter.cc +148 -2
- data/vendor/fastText/src/meter.h +24 -2
- data/vendor/fastText/src/model.cc +0 -1
- data/vendor/fastText/src/real.h +0 -1
- data/vendor/fastText/src/utils.cc +25 -0
- data/vendor/fastText/src/utils.h +29 -0
- data/vendor/fastText/src/vector.cc +0 -1
- metadata +5 -4
- data/lib/fasttext/ext.bundle +0 -0
data/vendor/fastText/src/fasttext.h
CHANGED
@@ -12,6 +12,7 @@
|
|
12
12
|
|
13
13
|
#include <atomic>
|
14
14
|
#include <chrono>
|
15
|
+
#include <functional>
|
15
16
|
#include <iostream>
|
16
17
|
#include <memory>
|
17
18
|
#include <queue>
|
@@ -31,24 +32,29 @@
|
|
31
32
|
namespace fasttext {
|
32
33
|
|
33
34
|
class FastText {
|
35
|
+
public:
|
36
|
+
using TrainCallback =
|
37
|
+
std::function<void(float, float, double, double, int64_t)>;
|
38
|
+
|
34
39
|
protected:
|
35
40
|
std::shared_ptr<Args> args_;
|
36
41
|
std::shared_ptr<Dictionary> dict_;
|
37
|
-
|
38
42
|
std::shared_ptr<Matrix> input_;
|
39
43
|
std::shared_ptr<Matrix> output_;
|
40
|
-
|
41
44
|
std::shared_ptr<Model> model_;
|
42
|
-
|
43
45
|
std::atomic<int64_t> tokenCount_{};
|
44
46
|
std::atomic<real> loss_{};
|
45
|
-
|
46
47
|
std::chrono::steady_clock::time_point start_;
|
48
|
+
bool quant_;
|
49
|
+
int32_t version;
|
50
|
+
std::unique_ptr<DenseMatrix> wordVectors_;
|
51
|
+
std::exception_ptr trainException_;
|
52
|
+
|
47
53
|
void signModel(std::ostream&);
|
48
54
|
bool checkModel(std::istream&);
|
49
|
-
void startThreads();
|
55
|
+
void startThreads(const TrainCallback& callback = {});
|
50
56
|
void addInputVector(Vector&, int32_t) const;
|
51
|
-
void trainThread(int32_t);
|
57
|
+
void trainThread(int32_t, const TrainCallback& callback);
|
52
58
|
std::vector<std::pair<real, std::string>> getNN(
|
53
59
|
const DenseMatrix& wordVectors,
|
54
60
|
const Vector& queryVec,
|
@@ -68,10 +74,11 @@ class FastText {
|
|
68
74
|
const std::vector<int32_t>& labels);
|
69
75
|
void cbow(Model::State& state, real lr, const std::vector<int32_t>& line);
|
70
76
|
void skipgram(Model::State& state, real lr, const std::vector<int32_t>& line);
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
77
|
+
std::vector<int32_t> selectEmbeddings(int32_t cutoff) const;
|
78
|
+
void precomputeWordVectors(DenseMatrix& wordVectors);
|
79
|
+
bool keepTraining(const int64_t ntokens) const;
|
80
|
+
void buildModel();
|
81
|
+
std::tuple<int64_t, double, double> progressInfo(real progress);
|
75
82
|
|
76
83
|
public:
|
77
84
|
FastText();
|
@@ -80,6 +87,8 @@ class FastText {
|
|
80
87
|
|
81
88
|
int32_t getSubwordId(const std::string& subword) const;
|
82
89
|
|
90
|
+
int32_t getLabelId(const std::string& label) const;
|
91
|
+
|
83
92
|
void getWordVector(Vector& vec, const std::string& word) const;
|
84
93
|
|
85
94
|
void getSubwordVector(Vector& vec, const std::string& subword) const;
|
@@ -95,6 +104,10 @@ class FastText {
|
|
95
104
|
|
96
105
|
std::shared_ptr<const DenseMatrix> getInputMatrix() const;
|
97
106
|
|
107
|
+
void setMatrices(
|
108
|
+
const std::shared_ptr<DenseMatrix>& inputMatrix,
|
109
|
+
const std::shared_ptr<DenseMatrix>& outputMatrix);
|
110
|
+
|
98
111
|
std::shared_ptr<const DenseMatrix> getOutputMatrix() const;
|
99
112
|
|
100
113
|
void saveVectors(const std::string& filename);
|
@@ -109,7 +122,7 @@ class FastText {
|
|
109
122
|
|
110
123
|
void getSentenceVector(std::istream& in, Vector& vec);
|
111
124
|
|
112
|
-
void quantize(const Args& qargs);
|
125
|
+
void quantize(const Args& qargs, const TrainCallback& callback = {});
|
113
126
|
|
114
127
|
std::tuple<int64_t, double, double>
|
115
128
|
test(std::istream& in, int32_t k, real threshold = 0.0);
|
@@ -141,51 +154,17 @@ class FastText {
|
|
141
154
|
const std::string& wordB,
|
142
155
|
const std::string& wordC);
|
143
156
|
|
144
|
-
void train(const Args& args);
|
157
|
+
void train(const Args& args, const TrainCallback& callback = {});
|
158
|
+
|
159
|
+
void abort();
|
145
160
|
|
146
161
|
int getDimension() const;
|
147
162
|
|
148
163
|
bool isQuant() const;
|
149
164
|
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
"getVector is being deprecated and replaced by getWordVector.")
|
155
|
-
void getVector(Vector& vec, const std::string& word) const;
|
156
|
-
|
157
|
-
FASTTEXT_DEPRECATED(
|
158
|
-
"ngramVectors is being deprecated and replaced by getNgramVectors.")
|
159
|
-
void ngramVectors(std::string word);
|
160
|
-
|
161
|
-
FASTTEXT_DEPRECATED(
|
162
|
-
"analogies is being deprecated and replaced by getAnalogies.")
|
163
|
-
void analogies(int32_t k);
|
164
|
-
|
165
|
-
FASTTEXT_DEPRECATED("selectEmbeddings is being deprecated.")
|
166
|
-
std::vector<int32_t> selectEmbeddings(int32_t cutoff) const;
|
167
|
-
|
168
|
-
FASTTEXT_DEPRECATED(
|
169
|
-
"saveVectors is being deprecated, please use the other signature.")
|
170
|
-
void saveVectors();
|
171
|
-
|
172
|
-
FASTTEXT_DEPRECATED(
|
173
|
-
"saveOutput is being deprecated, please use the other signature.")
|
174
|
-
void saveOutput();
|
175
|
-
|
176
|
-
FASTTEXT_DEPRECATED(
|
177
|
-
"saveModel is being deprecated, please use the other signature.")
|
178
|
-
void saveModel();
|
179
|
-
|
180
|
-
FASTTEXT_DEPRECATED("precomputeWordVectors is being deprecated.")
|
181
|
-
void precomputeWordVectors(DenseMatrix& wordVectors);
|
182
|
-
|
183
|
-
FASTTEXT_DEPRECATED("findNN is being deprecated and replaced by getNN.")
|
184
|
-
void findNN(
|
185
|
-
const DenseMatrix& wordVectors,
|
186
|
-
const Vector& query,
|
187
|
-
int32_t k,
|
188
|
-
const std::set<std::string>& banSet,
|
189
|
-
std::vector<std::pair<real, std::string>>& results);
|
165
|
+
class AbortError : public std::runtime_error {
|
166
|
+
public:
|
167
|
+
AbortError() : std::runtime_error("Aborted.") {}
|
168
|
+
};
|
190
169
|
};
|
191
170
|
} // namespace fasttext
|
data/vendor/fastText/src/main.cc
CHANGED
@@ -11,6 +11,7 @@
|
|
11
11
|
#include <queue>
|
12
12
|
#include <stdexcept>
|
13
13
|
#include "args.h"
|
14
|
+
#include "autotune.h"
|
14
15
|
#include "fasttext.h"
|
15
16
|
|
16
17
|
using namespace fasttext;
|
@@ -20,19 +21,25 @@ void printUsage() {
|
|
20
21
|
<< "usage: fasttext <command> <args>\n\n"
|
21
22
|
<< "The commands supported by fasttext are:\n\n"
|
22
23
|
<< " supervised train a supervised classifier\n"
|
23
|
-
<< " quantize quantize a model to reduce the memory
|
24
|
+
<< " quantize quantize a model to reduce the memory "
|
25
|
+
"usage\n"
|
24
26
|
<< " test evaluate a supervised classifier\n"
|
25
|
-
<< " test-label print labels with precision and recall
|
27
|
+
<< " test-label print labels with precision and recall "
|
28
|
+
"scores\n"
|
26
29
|
<< " predict predict most likely labels\n"
|
27
|
-
<< " predict-prob predict most likely labels with
|
30
|
+
<< " predict-prob predict most likely labels with "
|
31
|
+
"probabilities\n"
|
28
32
|
<< " skipgram train a skipgram model\n"
|
29
33
|
<< " cbow train a cbow model\n"
|
30
34
|
<< " print-word-vectors print word vectors given a trained model\n"
|
31
|
-
<< " print-sentence-vectors print sentence vectors given a trained
|
32
|
-
|
35
|
+
<< " print-sentence-vectors print sentence vectors given a trained "
|
36
|
+
"model\n"
|
37
|
+
<< " print-ngrams print ngrams given a trained model and "
|
38
|
+
"word\n"
|
33
39
|
<< " nn query for nearest neighbors\n"
|
34
40
|
<< " analogies query for analogies\n"
|
35
|
-
<< " dump dump arguments,dictionary,input/output
|
41
|
+
<< " dump dump arguments,dictionary,input/output "
|
42
|
+
"vectors\n"
|
36
43
|
<< std::endl;
|
37
44
|
}
|
38
45
|
|
@@ -141,7 +148,7 @@ void test(const std::vector<std::string>& args) {
|
|
141
148
|
FastText fasttext;
|
142
149
|
fasttext.loadModel(model);
|
143
150
|
|
144
|
-
Meter meter;
|
151
|
+
Meter meter(false);
|
145
152
|
|
146
153
|
if (input == "-") {
|
147
154
|
fasttext.test(std::cin, k, threshold, meter);
|
@@ -351,19 +358,31 @@ void analogies(const std::vector<std::string> args) {
|
|
351
358
|
void train(const std::vector<std::string> args) {
|
352
359
|
Args a = Args();
|
353
360
|
a.parseArgs(args);
|
354
|
-
FastText fasttext;
|
355
|
-
std::string outputFileName
|
361
|
+
std::shared_ptr<FastText> fasttext = std::make_shared<FastText>();
|
362
|
+
std::string outputFileName;
|
363
|
+
|
364
|
+
if (a.hasAutotune() &&
|
365
|
+
a.getAutotuneModelSize() != Args::kUnlimitedModelSize) {
|
366
|
+
outputFileName = a.output + ".ftz";
|
367
|
+
} else {
|
368
|
+
outputFileName = a.output + ".bin";
|
369
|
+
}
|
356
370
|
std::ofstream ofs(outputFileName);
|
357
371
|
if (!ofs.is_open()) {
|
358
372
|
throw std::invalid_argument(
|
359
373
|
outputFileName + " cannot be opened for saving.");
|
360
374
|
}
|
361
375
|
ofs.close();
|
362
|
-
|
363
|
-
|
364
|
-
|
376
|
+
if (a.hasAutotune()) {
|
377
|
+
Autotune autotune(fasttext);
|
378
|
+
autotune.train(a);
|
379
|
+
} else {
|
380
|
+
fasttext->train(a);
|
381
|
+
}
|
382
|
+
fasttext->saveModel(outputFileName);
|
383
|
+
fasttext->saveVectors(a.output + ".vec");
|
365
384
|
if (a.saveOutput) {
|
366
|
-
fasttext
|
385
|
+
fasttext->saveOutput(a.output + ".output");
|
367
386
|
}
|
368
387
|
}
|
369
388
|
|
data/vendor/fastText/src/meter.cc
CHANGED
@@ -16,6 +16,9 @@
|
|
16
16
|
|
17
17
|
namespace fasttext {
|
18
18
|
|
19
|
+
constexpr int32_t kAllLabels = -1;
|
20
|
+
constexpr real falseNegativeScore = -1.0;
|
21
|
+
|
19
22
|
void Meter::log(
|
20
23
|
const std::vector<int32_t>& labels,
|
21
24
|
const Predictions& predictions) {
|
@@ -26,14 +29,23 @@ void Meter::log(
|
|
26
29
|
for (const auto& prediction : predictions) {
|
27
30
|
labelMetrics_[prediction.second].predicted++;
|
28
31
|
|
32
|
+
real score = std::min(std::exp(prediction.first), 1.0f);
|
33
|
+
real gold = 0.0;
|
29
34
|
if (utils::contains(labels, prediction.second)) {
|
30
35
|
labelMetrics_[prediction.second].predictedGold++;
|
31
36
|
metrics_.predictedGold++;
|
37
|
+
gold = 1.0;
|
32
38
|
}
|
39
|
+
labelMetrics_[prediction.second].scoreVsTrue.emplace_back(score, gold);
|
33
40
|
}
|
34
41
|
|
35
|
-
|
36
|
-
|
42
|
+
if (falseNegativeLabels_) {
|
43
|
+
for (const auto& label : labels) {
|
44
|
+
labelMetrics_[label].gold++;
|
45
|
+
if (!utils::containsSecond(predictions, label)) {
|
46
|
+
labelMetrics_[label].scoreVsTrue.emplace_back(falseNegativeScore, 1.0);
|
47
|
+
}
|
48
|
+
}
|
37
49
|
}
|
38
50
|
}
|
39
51
|
|
@@ -57,6 +69,15 @@ double Meter::recall() const {
|
|
57
69
|
return metrics_.recall();
|
58
70
|
}
|
59
71
|
|
72
|
+
double Meter::f1Score() const {
|
73
|
+
const double precision = this->precision();
|
74
|
+
const double recall = this->recall();
|
75
|
+
if (precision + recall != 0) {
|
76
|
+
return 2 * precision * recall / (precision + recall);
|
77
|
+
}
|
78
|
+
return std::numeric_limits<double>::quiet_NaN();
|
79
|
+
}
|
80
|
+
|
60
81
|
void Meter::writeGeneralMetrics(std::ostream& out, int32_t k) const {
|
61
82
|
out << "N"
|
62
83
|
<< "\t" << nexamples_ << std::endl;
|
@@ -65,4 +86,129 @@ void Meter::writeGeneralMetrics(std::ostream& out, int32_t k) const {
|
|
65
86
|
out << "R@" << k << "\t" << metrics_.recall() << std::endl;
|
66
87
|
}
|
67
88
|
|
89
|
+
std::vector<std::pair<uint64_t, uint64_t>> Meter::getPositiveCounts(
|
90
|
+
int32_t labelId) const {
|
91
|
+
std::vector<std::pair<uint64_t, uint64_t>> positiveCounts;
|
92
|
+
|
93
|
+
const auto& v = scoreVsTrue(labelId);
|
94
|
+
uint64_t truePositives = 0;
|
95
|
+
uint64_t falsePositives = 0;
|
96
|
+
double lastScore = falseNegativeScore - 1.0;
|
97
|
+
|
98
|
+
for (auto it = v.rbegin(); it != v.rend(); ++it) {
|
99
|
+
double score = it->first;
|
100
|
+
double gold = it->second;
|
101
|
+
if (score < 0) { // only reachable recall
|
102
|
+
break;
|
103
|
+
}
|
104
|
+
if (gold == 1.0) {
|
105
|
+
truePositives++;
|
106
|
+
} else {
|
107
|
+
falsePositives++;
|
108
|
+
}
|
109
|
+
if (score == lastScore && positiveCounts.size()) { // squeeze tied scores
|
110
|
+
positiveCounts.back() = {truePositives, falsePositives};
|
111
|
+
} else {
|
112
|
+
positiveCounts.emplace_back(truePositives, falsePositives);
|
113
|
+
}
|
114
|
+
lastScore = score;
|
115
|
+
}
|
116
|
+
|
117
|
+
return positiveCounts;
|
118
|
+
}
|
119
|
+
|
120
|
+
double Meter::precisionAtRecall(double recallQuery) const {
|
121
|
+
return precisionAtRecall(kAllLabels, recallQuery);
|
122
|
+
}
|
123
|
+
|
124
|
+
double Meter::precisionAtRecall(int32_t labelId, double recallQuery) const {
|
125
|
+
const auto& precisionRecall = precisionRecallCurve(labelId);
|
126
|
+
double bestPrecision = 0.0;
|
127
|
+
std::for_each(
|
128
|
+
precisionRecall.begin(),
|
129
|
+
precisionRecall.end(),
|
130
|
+
[&bestPrecision, recallQuery](const std::pair<double, double>& element) {
|
131
|
+
if (element.second >= recallQuery) {
|
132
|
+
bestPrecision = std::max(bestPrecision, element.first);
|
133
|
+
};
|
134
|
+
});
|
135
|
+
return bestPrecision;
|
136
|
+
}
|
137
|
+
|
138
|
+
double Meter::recallAtPrecision(double precisionQuery) const {
|
139
|
+
return recallAtPrecision(kAllLabels, precisionQuery);
|
140
|
+
}
|
141
|
+
|
142
|
+
double Meter::recallAtPrecision(int32_t labelId, double precisionQuery) const {
|
143
|
+
const auto& precisionRecall = precisionRecallCurve(labelId);
|
144
|
+
double bestRecall = 0.0;
|
145
|
+
std::for_each(
|
146
|
+
precisionRecall.begin(),
|
147
|
+
precisionRecall.end(),
|
148
|
+
[&bestRecall, precisionQuery](const std::pair<double, double>& element) {
|
149
|
+
if (element.first >= precisionQuery) {
|
150
|
+
bestRecall = std::max(bestRecall, element.second);
|
151
|
+
};
|
152
|
+
});
|
153
|
+
return bestRecall;
|
154
|
+
}
|
155
|
+
|
156
|
+
std::vector<std::pair<double, double>> Meter::precisionRecallCurve() const {
|
157
|
+
return precisionRecallCurve(kAllLabels);
|
158
|
+
}
|
159
|
+
|
160
|
+
std::vector<std::pair<double, double>> Meter::precisionRecallCurve(
|
161
|
+
int32_t labelId) const {
|
162
|
+
std::vector<std::pair<double, double>> precisionRecallCurve;
|
163
|
+
const auto& positiveCounts = getPositiveCounts(labelId);
|
164
|
+
if (positiveCounts.empty()) {
|
165
|
+
return precisionRecallCurve;
|
166
|
+
}
|
167
|
+
|
168
|
+
uint64_t golds =
|
169
|
+
(labelId == kAllLabels) ? metrics_.gold : labelMetrics_.at(labelId).gold;
|
170
|
+
|
171
|
+
auto fullRecall = std::lower_bound(
|
172
|
+
positiveCounts.begin(),
|
173
|
+
positiveCounts.end(),
|
174
|
+
golds,
|
175
|
+
utils::compareFirstLess);
|
176
|
+
|
177
|
+
if (fullRecall != positiveCounts.end()) {
|
178
|
+
fullRecall = std::next(fullRecall);
|
179
|
+
}
|
180
|
+
|
181
|
+
for (auto it = positiveCounts.begin(); it != fullRecall; it++) {
|
182
|
+
double precision = 0.0;
|
183
|
+
double truePositives = it->first;
|
184
|
+
double falsePositives = it->second;
|
185
|
+
if (truePositives + falsePositives != 0.0) {
|
186
|
+
precision = truePositives / (truePositives + falsePositives);
|
187
|
+
}
|
188
|
+
double recall = golds != 0 ? (truePositives / double(golds))
|
189
|
+
: std::numeric_limits<double>::quiet_NaN();
|
190
|
+
precisionRecallCurve.emplace_back(precision, recall);
|
191
|
+
}
|
192
|
+
precisionRecallCurve.emplace_back(1.0, 0.0);
|
193
|
+
|
194
|
+
return precisionRecallCurve;
|
195
|
+
}
|
196
|
+
|
197
|
+
std::vector<std::pair<real, real>> Meter::scoreVsTrue(int32_t labelId) const {
|
198
|
+
std::vector<std::pair<real, real>> ret;
|
199
|
+
if (labelId == kAllLabels) {
|
200
|
+
for (const auto& k : labelMetrics_) {
|
201
|
+
auto& labelScoreVsTrue = labelMetrics_.at(k.first).scoreVsTrue;
|
202
|
+
ret.insert(ret.end(), labelScoreVsTrue.begin(), labelScoreVsTrue.end());
|
203
|
+
}
|
204
|
+
} else {
|
205
|
+
if (labelMetrics_.count(labelId)) {
|
206
|
+
ret = labelMetrics_.at(labelId).scoreVsTrue;
|
207
|
+
}
|
208
|
+
}
|
209
|
+
sort(ret.begin(), ret.end());
|
210
|
+
|
211
|
+
return ret;
|
212
|
+
}
|
213
|
+
|
68
214
|
} // namespace fasttext
|
data/vendor/fastText/src/meter.h
CHANGED
@@ -22,8 +22,9 @@ class Meter {
|
|
22
22
|
uint64_t gold;
|
23
23
|
uint64_t predicted;
|
24
24
|
uint64_t predictedGold;
|
25
|
+
mutable std::vector<std::pair<real, real>> scoreVsTrue;
|
25
26
|
|
26
|
-
Metrics() : gold(0), predicted(0), predictedGold(0) {}
|
27
|
+
Metrics() : gold(0), predicted(0), predictedGold(0), scoreVsTrue() {}
|
27
28
|
|
28
29
|
double precision() const {
|
29
30
|
if (predicted == 0) {
|
@@ -43,18 +44,38 @@ class Meter {
|
|
43
44
|
}
|
44
45
|
return 2 * predictedGold / double(predicted + gold);
|
45
46
|
}
|
47
|
+
|
48
|
+
std::vector<std::pair<real, real>> getScoreVsTrue() {
|
49
|
+
return scoreVsTrue;
|
50
|
+
}
|
46
51
|
};
|
52
|
+
std::vector<std::pair<uint64_t, uint64_t>> getPositiveCounts(
|
53
|
+
int32_t labelId) const;
|
47
54
|
|
48
55
|
public:
|
49
|
-
Meter()
|
56
|
+
Meter() = delete;
|
57
|
+
explicit Meter(bool falseNegativeLabels)
|
58
|
+
: metrics_(),
|
59
|
+
nexamples_(0),
|
60
|
+
labelMetrics_(),
|
61
|
+
falseNegativeLabels_(falseNegativeLabels) {}
|
50
62
|
|
51
63
|
void log(const std::vector<int32_t>& labels, const Predictions& predictions);
|
52
64
|
|
53
65
|
double precision(int32_t);
|
54
66
|
double recall(int32_t);
|
55
67
|
double f1Score(int32_t);
|
68
|
+
std::vector<std::pair<real, real>> scoreVsTrue(int32_t labelId) const;
|
69
|
+
double precisionAtRecall(int32_t labelId, double recall) const;
|
70
|
+
double precisionAtRecall(double recall) const;
|
71
|
+
double recallAtPrecision(int32_t labelId, double recall) const;
|
72
|
+
double recallAtPrecision(double recall) const;
|
73
|
+
std::vector<std::pair<double, double>> precisionRecallCurve(
|
74
|
+
int32_t labelId) const;
|
75
|
+
std::vector<std::pair<double, double>> precisionRecallCurve() const;
|
56
76
|
double precision() const;
|
57
77
|
double recall() const;
|
78
|
+
double f1Score() const;
|
58
79
|
uint64_t nexamples() const {
|
59
80
|
return nexamples_;
|
60
81
|
}
|
@@ -64,6 +85,7 @@ class Meter {
|
|
64
85
|
Metrics metrics_{};
|
65
86
|
uint64_t nexamples_;
|
66
87
|
std::unordered_map<int32_t, Metrics> labelMetrics_;
|
88
|
+
bool falseNegativeLabels_;
|
67
89
|
};
|
68
90
|
|
69
91
|
} // namespace fasttext
|