fasttext 0.1.2 → 0.1.3

@@ -12,6 +12,7 @@
 
 #include <atomic>
 #include <chrono>
+#include <functional>
 #include <iostream>
 #include <memory>
 #include <queue>
@@ -31,24 +32,29 @@
 namespace fasttext {
 
 class FastText {
+ public:
+  using TrainCallback =
+      std::function<void(float, float, double, double, int64_t)>;
+
  protected:
   std::shared_ptr<Args> args_;
   std::shared_ptr<Dictionary> dict_;
-
   std::shared_ptr<Matrix> input_;
   std::shared_ptr<Matrix> output_;
-
   std::shared_ptr<Model> model_;
-
   std::atomic<int64_t> tokenCount_{};
   std::atomic<real> loss_{};
-
   std::chrono::steady_clock::time_point start_;
+  bool quant_;
+  int32_t version;
+  std::unique_ptr<DenseMatrix> wordVectors_;
+  std::exception_ptr trainException_;
+
   void signModel(std::ostream&);
   bool checkModel(std::istream&);
-  void startThreads();
+  void startThreads(const TrainCallback& callback = {});
   void addInputVector(Vector&, int32_t) const;
-  void trainThread(int32_t);
+  void trainThread(int32_t, const TrainCallback& callback);
   std::vector<std::pair<real, std::string>> getNN(
       const DenseMatrix& wordVectors,
       const Vector& queryVec,
@@ -68,10 +74,11 @@ class FastText {
       const std::vector<int32_t>& labels);
   void cbow(Model::State& state, real lr, const std::vector<int32_t>& line);
   void skipgram(Model::State& state, real lr, const std::vector<int32_t>& line);
-
-  bool quant_;
-  int32_t version;
-  std::unique_ptr<DenseMatrix> wordVectors_;
+  std::vector<int32_t> selectEmbeddings(int32_t cutoff) const;
+  void precomputeWordVectors(DenseMatrix& wordVectors);
+  bool keepTraining(const int64_t ntokens) const;
+  void buildModel();
+  std::tuple<int64_t, double, double> progressInfo(real progress);
 
  public:
   FastText();
@@ -80,6 +87,8 @@ class FastText {
 
   int32_t getSubwordId(const std::string& subword) const;
 
+  int32_t getLabelId(const std::string& label) const;
+
   void getWordVector(Vector& vec, const std::string& word) const;
 
   void getSubwordVector(Vector& vec, const std::string& subword) const;
@@ -95,6 +104,10 @@ class FastText {
 
   std::shared_ptr<const DenseMatrix> getInputMatrix() const;
 
+  void setMatrices(
+      const std::shared_ptr<DenseMatrix>& inputMatrix,
+      const std::shared_ptr<DenseMatrix>& outputMatrix);
+
   std::shared_ptr<const DenseMatrix> getOutputMatrix() const;
 
   void saveVectors(const std::string& filename);
@@ -109,7 +122,7 @@ class FastText {
 
   void getSentenceVector(std::istream& in, Vector& vec);
 
-  void quantize(const Args& qargs);
+  void quantize(const Args& qargs, const TrainCallback& callback = {});
 
   std::tuple<int64_t, double, double>
   test(std::istream& in, int32_t k, real threshold = 0.0);
@@ -141,51 +154,17 @@ class FastText {
       const std::string& wordB,
       const std::string& wordC);
 
-  void train(const Args& args);
+  void train(const Args& args, const TrainCallback& callback = {});
+
+  void abort();
 
   int getDimension() const;
 
   bool isQuant() const;
 
-  FASTTEXT_DEPRECATED("loadVectors is being deprecated.")
-  void loadVectors(const std::string& filename);
-
-  FASTTEXT_DEPRECATED(
-      "getVector is being deprecated and replaced by getWordVector.")
-  void getVector(Vector& vec, const std::string& word) const;
-
-  FASTTEXT_DEPRECATED(
-      "ngramVectors is being deprecated and replaced by getNgramVectors.")
-  void ngramVectors(std::string word);
-
-  FASTTEXT_DEPRECATED(
-      "analogies is being deprecated and replaced by getAnalogies.")
-  void analogies(int32_t k);
-
-  FASTTEXT_DEPRECATED("selectEmbeddings is being deprecated.")
-  std::vector<int32_t> selectEmbeddings(int32_t cutoff) const;
-
-  FASTTEXT_DEPRECATED(
-      "saveVectors is being deprecated, please use the other signature.")
-  void saveVectors();
-
-  FASTTEXT_DEPRECATED(
-      "saveOutput is being deprecated, please use the other signature.")
-  void saveOutput();
-
-  FASTTEXT_DEPRECATED(
-      "saveModel is being deprecated, please use the other signature.")
-  void saveModel();
-
-  FASTTEXT_DEPRECATED("precomputeWordVectors is being deprecated.")
-  void precomputeWordVectors(DenseMatrix& wordVectors);
-
-  FASTTEXT_DEPRECATED("findNN is being deprecated and replaced by getNN.")
-  void findNN(
-      const DenseMatrix& wordVectors,
-      const Vector& query,
-      int32_t k,
-      const std::set<std::string>& banSet,
-      std::vector<std::pair<real, std::string>>& results);
+  class AbortError : public std::runtime_error {
+   public:
+    AbortError() : std::runtime_error("Aborted.") {}
+  };
 };
 } // namespace fasttext
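
The FastText header above now exposes a public TrainCallback hook, an abort() method, and a nested AbortError exception, so an embedding application can observe training progress and stop it early. The sketch below shows one plausible way to use these additions; the meaning and order of the callback parameters (progress, loss, words/sec/thread, learning rate, ETA in seconds) are inferred from the signature rather than documented here, the file names are placeholders, and the assumption that an abort surfaces as an AbortError thrown out of train() is suggested only by the trainException_ member, so treat all of that as hypothetical.

#include <iostream>
#include "fasttext.h"

int main() {
  fasttext::Args args;
  // parseArgs expects argv-style input: args[0] is the program name.
  args.parseArgs(
      {"fasttext", "supervised", "-input", "train.txt", "-output", "model"});

  fasttext::FastText ft;
  try {
    // The callback is invoked periodically while training runs.
    ft.train(args, [](float progress, float loss, double wordsPerSecPerThread,
                      double lr, int64_t eta) {
      std::cerr << "progress=" << progress << " loss=" << loss << " lr=" << lr
                << " eta=" << eta << "s" << std::endl;
    });
  } catch (const fasttext::FastText::AbortError&) {
    // Assumed behavior: raised when FastText::abort() is called (e.g. from
    // another thread) while training is in flight.
    std::cerr << "training aborted" << std::endl;
  }
  return 0;
}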
@@ -11,6 +11,7 @@
 #include <queue>
 #include <stdexcept>
 #include "args.h"
+#include "autotune.h"
 #include "fasttext.h"
 
 using namespace fasttext;
@@ -20,19 +21,25 @@ void printUsage() {
       << "usage: fasttext <command> <args>\n\n"
       << "The commands supported by fasttext are:\n\n"
       << "  supervised              train a supervised classifier\n"
-      << "  quantize                quantize a model to reduce the memory usage\n"
+      << "  quantize                quantize a model to reduce the memory "
+         "usage\n"
       << "  test                    evaluate a supervised classifier\n"
-      << "  test-label              print labels with precision and recall scores\n"
+      << "  test-label              print labels with precision and recall "
+         "scores\n"
      << "  predict                 predict most likely labels\n"
-      << "  predict-prob            predict most likely labels with probabilities\n"
+      << "  predict-prob            predict most likely labels with "
+         "probabilities\n"
       << "  skipgram                train a skipgram model\n"
       << "  cbow                    train a cbow model\n"
       << "  print-word-vectors      print word vectors given a trained model\n"
-      << "  print-sentence-vectors  print sentence vectors given a trained model\n"
-      << "  print-ngrams            print ngrams given a trained model and word\n"
+      << "  print-sentence-vectors  print sentence vectors given a trained "
+         "model\n"
+      << "  print-ngrams            print ngrams given a trained model and "
+         "word\n"
       << "  nn                      query for nearest neighbors\n"
       << "  analogies               query for analogies\n"
-      << "  dump                    dump arguments,dictionary,input/output vectors\n"
+      << "  dump                    dump arguments,dictionary,input/output "
+         "vectors\n"
       << std::endl;
 }
 
@@ -141,7 +148,7 @@ void test(const std::vector<std::string>& args) {
   FastText fasttext;
   fasttext.loadModel(model);
 
-  Meter meter;
+  Meter meter(false);
 
   if (input == "-") {
     fasttext.test(std::cin, k, threshold, meter);
@@ -351,19 +358,31 @@ void analogies(const std::vector<std::string> args) {
 void train(const std::vector<std::string> args) {
   Args a = Args();
   a.parseArgs(args);
-  FastText fasttext;
-  std::string outputFileName(a.output + ".bin");
+  std::shared_ptr<FastText> fasttext = std::make_shared<FastText>();
+  std::string outputFileName;
+
+  if (a.hasAutotune() &&
+      a.getAutotuneModelSize() != Args::kUnlimitedModelSize) {
+    outputFileName = a.output + ".ftz";
+  } else {
+    outputFileName = a.output + ".bin";
+  }
   std::ofstream ofs(outputFileName);
   if (!ofs.is_open()) {
     throw std::invalid_argument(
         outputFileName + " cannot be opened for saving.");
   }
   ofs.close();
-  fasttext.train(a);
-  fasttext.saveModel(outputFileName);
-  fasttext.saveVectors(a.output + ".vec");
+  if (a.hasAutotune()) {
+    Autotune autotune(fasttext);
+    autotune.train(a);
+  } else {
+    fasttext->train(a);
+  }
+  fasttext->saveModel(outputFileName);
+  fasttext->saveVectors(a.output + ".vec");
   if (a.saveOutput) {
-    fasttext.saveOutput(a.output + ".output");
+    fasttext->saveOutput(a.output + ".output");
   }
 }
 
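
In the command-line driver, train() now routes through the new Autotune class when autotune options are present, and picks the output extension accordingly: a compressed .ftz model when an autotune model-size budget is set, a plain .bin otherwise. A programmatic sketch of that branch follows; the -autotune-validation flag name and the validation file are assumptions based on upstream fastText's autotune options, not something this diff itself shows.

#include <memory>
#include "autotune.h"
#include "fasttext.h"

int main() {
  fasttext::Args a;
  // Assumed flag: -autotune-validation enables hyperparameter search against
  // a held-out file (both file names here are placeholders).
  a.parseArgs({"fasttext", "supervised", "-input", "train.txt", "-output",
               "model", "-autotune-validation", "valid.txt"});

  auto ft = std::make_shared<fasttext::FastText>();
  if (a.hasAutotune()) {
    fasttext::Autotune autotune(ft);
    autotune.train(a); // explores hyperparameters, leaves the best model in ft
  } else {
    ft->train(a);
  }
  ft->saveModel("model.bin");
  return 0;
}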
@@ -16,6 +16,9 @@
 
 namespace fasttext {
 
+constexpr int32_t kAllLabels = -1;
+constexpr real falseNegativeScore = -1.0;
+
 void Meter::log(
     const std::vector<int32_t>& labels,
     const Predictions& predictions) {
@@ -26,14 +29,23 @@ void Meter::log(
   for (const auto& prediction : predictions) {
     labelMetrics_[prediction.second].predicted++;
 
+    real score = std::min(std::exp(prediction.first), 1.0f);
+    real gold = 0.0;
     if (utils::contains(labels, prediction.second)) {
       labelMetrics_[prediction.second].predictedGold++;
       metrics_.predictedGold++;
+      gold = 1.0;
     }
+    labelMetrics_[prediction.second].scoreVsTrue.emplace_back(score, gold);
   }
 
-  for (const auto& label : labels) {
-    labelMetrics_[label].gold++;
+  if (falseNegativeLabels_) {
+    for (const auto& label : labels) {
+      labelMetrics_[label].gold++;
+      if (!utils::containsSecond(predictions, label)) {
+        labelMetrics_[label].scoreVsTrue.emplace_back(falseNegativeScore, 1.0);
+      }
+    }
   }
 }
 
@@ -57,6 +69,15 @@ double Meter::recall() const {
   return metrics_.recall();
 }
 
+double Meter::f1Score() const {
+  const double precision = this->precision();
+  const double recall = this->recall();
+  if (precision + recall != 0) {
+    return 2 * precision * recall / (precision + recall);
+  }
+  return std::numeric_limits<double>::quiet_NaN();
+}
+
 void Meter::writeGeneralMetrics(std::ostream& out, int32_t k) const {
   out << "N"
       << "\t" << nexamples_ << std::endl;
@@ -65,4 +86,129 @@ void Meter::writeGeneralMetrics(std::ostream& out, int32_t k) const {
   out << "R@" << k << "\t" << metrics_.recall() << std::endl;
 }
 
+std::vector<std::pair<uint64_t, uint64_t>> Meter::getPositiveCounts(
+    int32_t labelId) const {
+  std::vector<std::pair<uint64_t, uint64_t>> positiveCounts;
+
+  const auto& v = scoreVsTrue(labelId);
+  uint64_t truePositives = 0;
+  uint64_t falsePositives = 0;
+  double lastScore = falseNegativeScore - 1.0;
+
+  for (auto it = v.rbegin(); it != v.rend(); ++it) {
+    double score = it->first;
+    double gold = it->second;
+    if (score < 0) { // only reachable recall
+      break;
+    }
+    if (gold == 1.0) {
+      truePositives++;
+    } else {
+      falsePositives++;
+    }
+    if (score == lastScore && positiveCounts.size()) { // squeeze tied scores
+      positiveCounts.back() = {truePositives, falsePositives};
+    } else {
+      positiveCounts.emplace_back(truePositives, falsePositives);
+    }
+    lastScore = score;
+  }
+
+  return positiveCounts;
+}
+
+double Meter::precisionAtRecall(double recallQuery) const {
+  return precisionAtRecall(kAllLabels, recallQuery);
+}
+
+double Meter::precisionAtRecall(int32_t labelId, double recallQuery) const {
+  const auto& precisionRecall = precisionRecallCurve(labelId);
+  double bestPrecision = 0.0;
+  std::for_each(
+      precisionRecall.begin(),
+      precisionRecall.end(),
+      [&bestPrecision, recallQuery](const std::pair<double, double>& element) {
+        if (element.second >= recallQuery) {
+          bestPrecision = std::max(bestPrecision, element.first);
+        };
+      });
+  return bestPrecision;
+}
+
+double Meter::recallAtPrecision(double precisionQuery) const {
+  return recallAtPrecision(kAllLabels, precisionQuery);
+}
+
+double Meter::recallAtPrecision(int32_t labelId, double precisionQuery) const {
+  const auto& precisionRecall = precisionRecallCurve(labelId);
+  double bestRecall = 0.0;
+  std::for_each(
+      precisionRecall.begin(),
+      precisionRecall.end(),
+      [&bestRecall, precisionQuery](const std::pair<double, double>& element) {
+        if (element.first >= precisionQuery) {
+          bestRecall = std::max(bestRecall, element.second);
+        };
+      });
+  return bestRecall;
+}
+
+std::vector<std::pair<double, double>> Meter::precisionRecallCurve() const {
+  return precisionRecallCurve(kAllLabels);
+}
+
+std::vector<std::pair<double, double>> Meter::precisionRecallCurve(
+    int32_t labelId) const {
+  std::vector<std::pair<double, double>> precisionRecallCurve;
+  const auto& positiveCounts = getPositiveCounts(labelId);
+  if (positiveCounts.empty()) {
+    return precisionRecallCurve;
+  }
+
+  uint64_t golds =
+      (labelId == kAllLabels) ? metrics_.gold : labelMetrics_.at(labelId).gold;
+
+  auto fullRecall = std::lower_bound(
+      positiveCounts.begin(),
+      positiveCounts.end(),
+      golds,
+      utils::compareFirstLess);
+
+  if (fullRecall != positiveCounts.end()) {
+    fullRecall = std::next(fullRecall);
+  }
+
+  for (auto it = positiveCounts.begin(); it != fullRecall; it++) {
+    double precision = 0.0;
+    double truePositives = it->first;
+    double falsePositives = it->second;
+    if (truePositives + falsePositives != 0.0) {
+      precision = truePositives / (truePositives + falsePositives);
+    }
+    double recall = golds != 0 ? (truePositives / double(golds))
+                               : std::numeric_limits<double>::quiet_NaN();
+    precisionRecallCurve.emplace_back(precision, recall);
+  }
+  precisionRecallCurve.emplace_back(1.0, 0.0);
+
+  return precisionRecallCurve;
+}
+
+std::vector<std::pair<real, real>> Meter::scoreVsTrue(int32_t labelId) const {
+  std::vector<std::pair<real, real>> ret;
+  if (labelId == kAllLabels) {
+    for (const auto& k : labelMetrics_) {
+      auto& labelScoreVsTrue = labelMetrics_.at(k.first).scoreVsTrue;
+      ret.insert(ret.end(), labelScoreVsTrue.begin(), labelScoreVsTrue.end());
+    }
+  } else {
+    if (labelMetrics_.count(labelId)) {
+      ret = labelMetrics_.at(labelId).scoreVsTrue;
+    }
+  }
+  sort(ret.begin(), ret.end());
+
+  return ret;
+}
+
 } // namespace fasttext
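
The new Meter code records, for every prediction, a (score, gold) pair per label; getPositiveCounts then walks those pairs from the highest score down, accumulating true and false positives, and precisionRecallCurve turns each prefix into a (precision, recall) point with precision = TP / (TP + FP) and recall = TP / gold. The toy program below replays that cumulative walk on a hypothetical score list to show where the curve's points come from; it is a standalone illustration, not part of the fastText API.

#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

int main() {
  // (score, gold) pairs as Meter::log would record them, sorted ascending.
  std::vector<std::pair<double, double>> scoreVsTrue = {
      {0.2, 0.0}, {0.5, 1.0}, {0.7, 0.0}, {0.9, 1.0}};
  const uint64_t golds = 2; // total positive examples for this label

  uint64_t tp = 0, fp = 0;
  // Walk from the highest score down, mirroring the reverse iteration above.
  for (auto it = scoreVsTrue.rbegin(); it != scoreVsTrue.rend(); ++it) {
    (it->second == 1.0 ? tp : fp)++;
    double precision = double(tp) / double(tp + fp);
    double recall = double(tp) / double(golds);
    std::cout << "threshold>=" << it->first << "  P=" << precision
              << "  R=" << recall << "\n";
  }
  return 0;
}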
@@ -22,8 +22,9 @@ class Meter {
     uint64_t gold;
     uint64_t predicted;
     uint64_t predictedGold;
+    mutable std::vector<std::pair<real, real>> scoreVsTrue;
 
-    Metrics() : gold(0), predicted(0), predictedGold(0) {}
+    Metrics() : gold(0), predicted(0), predictedGold(0), scoreVsTrue() {}
 
     double precision() const {
       if (predicted == 0) {
@@ -43,18 +44,38 @@ class Meter {
       }
       return 2 * predictedGold / double(predicted + gold);
     }
+
+    std::vector<std::pair<real, real>> getScoreVsTrue() {
+      return scoreVsTrue;
+    }
   };
+  std::vector<std::pair<uint64_t, uint64_t>> getPositiveCounts(
+      int32_t labelId) const;
 
  public:
-  Meter() : metrics_(), nexamples_(0), labelMetrics_() {}
+  Meter() = delete;
+  explicit Meter(bool falseNegativeLabels)
+      : metrics_(),
+        nexamples_(0),
+        labelMetrics_(),
+        falseNegativeLabels_(falseNegativeLabels) {}
 
   void log(const std::vector<int32_t>& labels, const Predictions& predictions);
 
   double precision(int32_t);
   double recall(int32_t);
   double f1Score(int32_t);
+  std::vector<std::pair<real, real>> scoreVsTrue(int32_t labelId) const;
+  double precisionAtRecall(int32_t labelId, double recall) const;
+  double precisionAtRecall(double recall) const;
+  double recallAtPrecision(int32_t labelId, double recall) const;
+  double recallAtPrecision(double recall) const;
+  std::vector<std::pair<double, double>> precisionRecallCurve(
+      int32_t labelId) const;
+  std::vector<std::pair<double, double>> precisionRecallCurve() const;
   double precision() const;
   double recall() const;
+  double f1Score() const;
   uint64_t nexamples() const {
     return nexamples_;
   }
@@ -64,6 +85,7 @@ class Meter {
   Metrics metrics_{};
   uint64_t nexamples_;
   std::unordered_map<int32_t, Metrics> labelMetrics_;
+  bool falseNegativeLabels_;
 };
 
 } // namespace fasttext
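
Because Meter's default constructor is now deleted, every caller has to say whether unpredicted gold labels should be tracked as false negatives; the command-line test path above passes false. A hedged usage sketch follows, with placeholder model and test file names; it relies only on the test(istream, k, threshold, meter) overload already used in main.cc and on the accessors declared in this header.

#include <fstream>
#include <iostream>
#include "fasttext.h"
#include "meter.h"

int main() {
  fasttext::FastText ft;
  ft.loadModel("model.bin");

  // false: do not record unpredicted gold labels as false negatives,
  // mirroring the CLI's `Meter meter(false)` above.
  fasttext::Meter meter(false);
  std::ifstream in("test.txt");
  ft.test(in, /*k=*/1, /*threshold=*/0.0, meter);

  std::cout << "P@1 " << meter.precision() << "  R@1 " << meter.recall()
            << "  F1 " << meter.f1Score() << std::endl;
  return 0;
}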