fasttext 0.1.2 → 0.1.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -12,6 +12,7 @@
12
12
 
13
13
  #include <atomic>
14
14
  #include <chrono>
15
+ #include <functional>
15
16
  #include <iostream>
16
17
  #include <memory>
17
18
  #include <queue>
@@ -31,24 +32,29 @@
31
32
  namespace fasttext {
32
33
 
33
34
  class FastText {
35
+ public:
36
+ using TrainCallback =
37
+ std::function<void(float, float, double, double, int64_t)>;
38
+
34
39
  protected:
35
40
  std::shared_ptr<Args> args_;
36
41
  std::shared_ptr<Dictionary> dict_;
37
-
38
42
  std::shared_ptr<Matrix> input_;
39
43
  std::shared_ptr<Matrix> output_;
40
-
41
44
  std::shared_ptr<Model> model_;
42
-
43
45
  std::atomic<int64_t> tokenCount_{};
44
46
  std::atomic<real> loss_{};
45
-
46
47
  std::chrono::steady_clock::time_point start_;
48
+ bool quant_;
49
+ int32_t version;
50
+ std::unique_ptr<DenseMatrix> wordVectors_;
51
+ std::exception_ptr trainException_;
52
+
47
53
  void signModel(std::ostream&);
48
54
  bool checkModel(std::istream&);
49
- void startThreads();
55
+ void startThreads(const TrainCallback& callback = {});
50
56
  void addInputVector(Vector&, int32_t) const;
51
- void trainThread(int32_t);
57
+ void trainThread(int32_t, const TrainCallback& callback);
52
58
  std::vector<std::pair<real, std::string>> getNN(
53
59
  const DenseMatrix& wordVectors,
54
60
  const Vector& queryVec,
@@ -68,10 +74,11 @@ class FastText {
68
74
  const std::vector<int32_t>& labels);
69
75
  void cbow(Model::State& state, real lr, const std::vector<int32_t>& line);
70
76
  void skipgram(Model::State& state, real lr, const std::vector<int32_t>& line);
71
-
72
- bool quant_;
73
- int32_t version;
74
- std::unique_ptr<DenseMatrix> wordVectors_;
77
+ std::vector<int32_t> selectEmbeddings(int32_t cutoff) const;
78
+ void precomputeWordVectors(DenseMatrix& wordVectors);
79
+ bool keepTraining(const int64_t ntokens) const;
80
+ void buildModel();
81
+ std::tuple<int64_t, double, double> progressInfo(real progress);
75
82
 
76
83
  public:
77
84
  FastText();
@@ -80,6 +87,8 @@ class FastText {
80
87
 
81
88
  int32_t getSubwordId(const std::string& subword) const;
82
89
 
90
+ int32_t getLabelId(const std::string& label) const;
91
+
83
92
  void getWordVector(Vector& vec, const std::string& word) const;
84
93
 
85
94
  void getSubwordVector(Vector& vec, const std::string& subword) const;
@@ -95,6 +104,10 @@ class FastText {
95
104
 
96
105
  std::shared_ptr<const DenseMatrix> getInputMatrix() const;
97
106
 
107
+ void setMatrices(
108
+ const std::shared_ptr<DenseMatrix>& inputMatrix,
109
+ const std::shared_ptr<DenseMatrix>& outputMatrix);
110
+
98
111
  std::shared_ptr<const DenseMatrix> getOutputMatrix() const;
99
112
 
100
113
  void saveVectors(const std::string& filename);
@@ -109,7 +122,7 @@ class FastText {
109
122
 
110
123
  void getSentenceVector(std::istream& in, Vector& vec);
111
124
 
112
- void quantize(const Args& qargs);
125
+ void quantize(const Args& qargs, const TrainCallback& callback = {});
113
126
 
114
127
  std::tuple<int64_t, double, double>
115
128
  test(std::istream& in, int32_t k, real threshold = 0.0);
@@ -141,51 +154,17 @@ class FastText {
141
154
  const std::string& wordB,
142
155
  const std::string& wordC);
143
156
 
144
- void train(const Args& args);
157
+ void train(const Args& args, const TrainCallback& callback = {});
158
+
159
+ void abort();
145
160
 
146
161
  int getDimension() const;
147
162
 
148
163
  bool isQuant() const;
149
164
 
150
- FASTTEXT_DEPRECATED("loadVectors is being deprecated.")
151
- void loadVectors(const std::string& filename);
152
-
153
- FASTTEXT_DEPRECATED(
154
- "getVector is being deprecated and replaced by getWordVector.")
155
- void getVector(Vector& vec, const std::string& word) const;
156
-
157
- FASTTEXT_DEPRECATED(
158
- "ngramVectors is being deprecated and replaced by getNgramVectors.")
159
- void ngramVectors(std::string word);
160
-
161
- FASTTEXT_DEPRECATED(
162
- "analogies is being deprecated and replaced by getAnalogies.")
163
- void analogies(int32_t k);
164
-
165
- FASTTEXT_DEPRECATED("selectEmbeddings is being deprecated.")
166
- std::vector<int32_t> selectEmbeddings(int32_t cutoff) const;
167
-
168
- FASTTEXT_DEPRECATED(
169
- "saveVectors is being deprecated, please use the other signature.")
170
- void saveVectors();
171
-
172
- FASTTEXT_DEPRECATED(
173
- "saveOutput is being deprecated, please use the other signature.")
174
- void saveOutput();
175
-
176
- FASTTEXT_DEPRECATED(
177
- "saveModel is being deprecated, please use the other signature.")
178
- void saveModel();
179
-
180
- FASTTEXT_DEPRECATED("precomputeWordVectors is being deprecated.")
181
- void precomputeWordVectors(DenseMatrix& wordVectors);
182
-
183
- FASTTEXT_DEPRECATED("findNN is being deprecated and replaced by getNN.")
184
- void findNN(
185
- const DenseMatrix& wordVectors,
186
- const Vector& query,
187
- int32_t k,
188
- const std::set<std::string>& banSet,
189
- std::vector<std::pair<real, std::string>>& results);
165
// Exception type whose what() message is always "Aborted.".
// It derives from std::runtime_error, so a handler written for
// std::runtime_error (or std::exception) also catches it.
// NOTE(review): presumably raised in connection with FastText::abort();
// the throw site is not visible here - confirm in the implementation file.
class AbortError : public std::runtime_error {
 public:
  AbortError() : std::runtime_error{"Aborted."} {}
};
190
169
  };
191
170
  } // namespace fasttext
@@ -11,6 +11,7 @@
11
11
  #include <queue>
12
12
  #include <stdexcept>
13
13
  #include "args.h"
14
+ #include "autotune.h"
14
15
  #include "fasttext.h"
15
16
 
16
17
  using namespace fasttext;
@@ -20,19 +21,25 @@ void printUsage() {
20
21
  << "usage: fasttext <command> <args>\n\n"
21
22
  << "The commands supported by fasttext are:\n\n"
22
23
  << " supervised train a supervised classifier\n"
23
- << " quantize quantize a model to reduce the memory usage\n"
24
+ << " quantize quantize a model to reduce the memory "
25
+ "usage\n"
24
26
  << " test evaluate a supervised classifier\n"
25
- << " test-label print labels with precision and recall scores\n"
27
+ << " test-label print labels with precision and recall "
28
+ "scores\n"
26
29
  << " predict predict most likely labels\n"
27
- << " predict-prob predict most likely labels with probabilities\n"
30
+ << " predict-prob predict most likely labels with "
31
+ "probabilities\n"
28
32
  << " skipgram train a skipgram model\n"
29
33
  << " cbow train a cbow model\n"
30
34
  << " print-word-vectors print word vectors given a trained model\n"
31
- << " print-sentence-vectors print sentence vectors given a trained model\n"
32
- << " print-ngrams print ngrams given a trained model and word\n"
35
+ << " print-sentence-vectors print sentence vectors given a trained "
36
+ "model\n"
37
+ << " print-ngrams print ngrams given a trained model and "
38
+ "word\n"
33
39
  << " nn query for nearest neighbors\n"
34
40
  << " analogies query for analogies\n"
35
- << " dump dump arguments,dictionary,input/output vectors\n"
41
+ << " dump dump arguments,dictionary,input/output "
42
+ "vectors\n"
36
43
  << std::endl;
37
44
  }
38
45
 
@@ -141,7 +148,7 @@ void test(const std::vector<std::string>& args) {
141
148
  FastText fasttext;
142
149
  fasttext.loadModel(model);
143
150
 
144
- Meter meter;
151
+ Meter meter(false);
145
152
 
146
153
  if (input == "-") {
147
154
  fasttext.test(std::cin, k, threshold, meter);
@@ -351,19 +358,31 @@ void analogies(const std::vector<std::string> args) {
351
358
  void train(const std::vector<std::string> args) {
352
359
  Args a = Args();
353
360
  a.parseArgs(args);
354
- FastText fasttext;
355
- std::string outputFileName(a.output + ".bin");
361
+ std::shared_ptr<FastText> fasttext = std::make_shared<FastText>();
362
+ std::string outputFileName;
363
+
364
+ if (a.hasAutotune() &&
365
+ a.getAutotuneModelSize() != Args::kUnlimitedModelSize) {
366
+ outputFileName = a.output + ".ftz";
367
+ } else {
368
+ outputFileName = a.output + ".bin";
369
+ }
356
370
  std::ofstream ofs(outputFileName);
357
371
  if (!ofs.is_open()) {
358
372
  throw std::invalid_argument(
359
373
  outputFileName + " cannot be opened for saving.");
360
374
  }
361
375
  ofs.close();
362
- fasttext.train(a);
363
- fasttext.saveModel(outputFileName);
364
- fasttext.saveVectors(a.output + ".vec");
376
+ if (a.hasAutotune()) {
377
+ Autotune autotune(fasttext);
378
+ autotune.train(a);
379
+ } else {
380
+ fasttext->train(a);
381
+ }
382
+ fasttext->saveModel(outputFileName);
383
+ fasttext->saveVectors(a.output + ".vec");
365
384
  if (a.saveOutput) {
366
- fasttext.saveOutput(a.output + ".output");
385
+ fasttext->saveOutput(a.output + ".output");
367
386
  }
368
387
  }
369
388
 
@@ -16,6 +16,9 @@
16
16
 
17
17
  namespace fasttext {
18
18
 
19
+ constexpr int32_t kAllLabels = -1;
20
+ constexpr real falseNegativeScore = -1.0;
21
+
19
22
  void Meter::log(
20
23
  const std::vector<int32_t>& labels,
21
24
  const Predictions& predictions) {
@@ -26,14 +29,23 @@ void Meter::log(
26
29
  for (const auto& prediction : predictions) {
27
30
  labelMetrics_[prediction.second].predicted++;
28
31
 
32
+ real score = std::min(std::exp(prediction.first), 1.0f);
33
+ real gold = 0.0;
29
34
  if (utils::contains(labels, prediction.second)) {
30
35
  labelMetrics_[prediction.second].predictedGold++;
31
36
  metrics_.predictedGold++;
37
+ gold = 1.0;
32
38
  }
39
+ labelMetrics_[prediction.second].scoreVsTrue.emplace_back(score, gold);
33
40
  }
34
41
 
35
- for (const auto& label : labels) {
36
- labelMetrics_[label].gold++;
42
+ if (falseNegativeLabels_) {
43
+ for (const auto& label : labels) {
44
+ labelMetrics_[label].gold++;
45
+ if (!utils::containsSecond(predictions, label)) {
46
+ labelMetrics_[label].scoreVsTrue.emplace_back(falseNegativeScore, 1.0);
47
+ }
48
+ }
37
49
  }
38
50
  }
39
51
 
@@ -57,6 +69,15 @@ double Meter::recall() const {
57
69
  return metrics_.recall();
58
70
  }
59
71
 
72
+ double Meter::f1Score() const {
73
+ const double precision = this->precision();
74
+ const double recall = this->recall();
75
+ if (precision + recall != 0) {
76
+ return 2 * precision * recall / (precision + recall);
77
+ }
78
+ return std::numeric_limits<double>::quiet_NaN();
79
+ }
80
+
60
81
  void Meter::writeGeneralMetrics(std::ostream& out, int32_t k) const {
61
82
  out << "N"
62
83
  << "\t" << nexamples_ << std::endl;
@@ -65,4 +86,129 @@ void Meter::writeGeneralMetrics(std::ostream& out, int32_t k) const {
65
86
  out << "R@" << k << "\t" << metrics_.recall() << std::endl;
66
87
  }
67
88
 
89
+ std::vector<std::pair<uint64_t, uint64_t>> Meter::getPositiveCounts(
90
+ int32_t labelId) const {
91
+ std::vector<std::pair<uint64_t, uint64_t>> positiveCounts;
92
+
93
+ const auto& v = scoreVsTrue(labelId);
94
+ uint64_t truePositives = 0;
95
+ uint64_t falsePositives = 0;
96
+ double lastScore = falseNegativeScore - 1.0;
97
+
98
+ for (auto it = v.rbegin(); it != v.rend(); ++it) {
99
+ double score = it->first;
100
+ double gold = it->second;
101
+ if (score < 0) { // only reachable recall
102
+ break;
103
+ }
104
+ if (gold == 1.0) {
105
+ truePositives++;
106
+ } else {
107
+ falsePositives++;
108
+ }
109
+ if (score == lastScore && positiveCounts.size()) { // squeeze tied scores
110
+ positiveCounts.back() = {truePositives, falsePositives};
111
+ } else {
112
+ positiveCounts.emplace_back(truePositives, falsePositives);
113
+ }
114
+ lastScore = score;
115
+ }
116
+
117
+ return positiveCounts;
118
+ }
119
+
120
+ double Meter::precisionAtRecall(double recallQuery) const {
121
+ return precisionAtRecall(kAllLabels, recallQuery);
122
+ }
123
+
124
+ double Meter::precisionAtRecall(int32_t labelId, double recallQuery) const {
125
+ const auto& precisionRecall = precisionRecallCurve(labelId);
126
+ double bestPrecision = 0.0;
127
+ std::for_each(
128
+ precisionRecall.begin(),
129
+ precisionRecall.end(),
130
+ [&bestPrecision, recallQuery](const std::pair<double, double>& element) {
131
+ if (element.second >= recallQuery) {
132
+ bestPrecision = std::max(bestPrecision, element.first);
133
+ };
134
+ });
135
+ return bestPrecision;
136
+ }
137
+
138
+ double Meter::recallAtPrecision(double precisionQuery) const {
139
+ return recallAtPrecision(kAllLabels, precisionQuery);
140
+ }
141
+
142
+ double Meter::recallAtPrecision(int32_t labelId, double precisionQuery) const {
143
+ const auto& precisionRecall = precisionRecallCurve(labelId);
144
+ double bestRecall = 0.0;
145
+ std::for_each(
146
+ precisionRecall.begin(),
147
+ precisionRecall.end(),
148
+ [&bestRecall, precisionQuery](const std::pair<double, double>& element) {
149
+ if (element.first >= precisionQuery) {
150
+ bestRecall = std::max(bestRecall, element.second);
151
+ };
152
+ });
153
+ return bestRecall;
154
+ }
155
+
156
+ std::vector<std::pair<double, double>> Meter::precisionRecallCurve() const {
157
+ return precisionRecallCurve(kAllLabels);
158
+ }
159
+
160
+ std::vector<std::pair<double, double>> Meter::precisionRecallCurve(
161
+ int32_t labelId) const {
162
+ std::vector<std::pair<double, double>> precisionRecallCurve;
163
+ const auto& positiveCounts = getPositiveCounts(labelId);
164
+ if (positiveCounts.empty()) {
165
+ return precisionRecallCurve;
166
+ }
167
+
168
+ uint64_t golds =
169
+ (labelId == kAllLabels) ? metrics_.gold : labelMetrics_.at(labelId).gold;
170
+
171
+ auto fullRecall = std::lower_bound(
172
+ positiveCounts.begin(),
173
+ positiveCounts.end(),
174
+ golds,
175
+ utils::compareFirstLess);
176
+
177
+ if (fullRecall != positiveCounts.end()) {
178
+ fullRecall = std::next(fullRecall);
179
+ }
180
+
181
+ for (auto it = positiveCounts.begin(); it != fullRecall; it++) {
182
+ double precision = 0.0;
183
+ double truePositives = it->first;
184
+ double falsePositives = it->second;
185
+ if (truePositives + falsePositives != 0.0) {
186
+ precision = truePositives / (truePositives + falsePositives);
187
+ }
188
+ double recall = golds != 0 ? (truePositives / double(golds))
189
+ : std::numeric_limits<double>::quiet_NaN();
190
+ precisionRecallCurve.emplace_back(precision, recall);
191
+ }
192
+ precisionRecallCurve.emplace_back(1.0, 0.0);
193
+
194
+ return precisionRecallCurve;
195
+ }
196
+
197
+ std::vector<std::pair<real, real>> Meter::scoreVsTrue(int32_t labelId) const {
198
+ std::vector<std::pair<real, real>> ret;
199
+ if (labelId == kAllLabels) {
200
+ for (const auto& k : labelMetrics_) {
201
+ auto& labelScoreVsTrue = labelMetrics_.at(k.first).scoreVsTrue;
202
+ ret.insert(ret.end(), labelScoreVsTrue.begin(), labelScoreVsTrue.end());
203
+ }
204
+ } else {
205
+ if (labelMetrics_.count(labelId)) {
206
+ ret = labelMetrics_.at(labelId).scoreVsTrue;
207
+ }
208
+ }
209
+ sort(ret.begin(), ret.end());
210
+
211
+ return ret;
212
+ }
213
+
68
214
  } // namespace fasttext
@@ -22,8 +22,9 @@ class Meter {
22
22
  uint64_t gold;
23
23
  uint64_t predicted;
24
24
  uint64_t predictedGold;
25
+ mutable std::vector<std::pair<real, real>> scoreVsTrue;
25
26
 
26
- Metrics() : gold(0), predicted(0), predictedGold(0) {}
27
+ Metrics() : gold(0), predicted(0), predictedGold(0), scoreVsTrue() {}
27
28
 
28
29
  double precision() const {
29
30
  if (predicted == 0) {
@@ -43,18 +44,38 @@ class Meter {
43
44
  }
44
45
  return 2 * predictedGold / double(predicted + gold);
45
46
  }
47
+
48
+ std::vector<std::pair<real, real>> getScoreVsTrue() {
49
+ return scoreVsTrue;
50
+ }
46
51
  };
52
+ std::vector<std::pair<uint64_t, uint64_t>> getPositiveCounts(
53
+ int32_t labelId) const;
47
54
 
48
55
  public:
49
- Meter() : metrics_(), nexamples_(0), labelMetrics_() {}
56
+ Meter() = delete;
57
+ explicit Meter(bool falseNegativeLabels)
58
+ : metrics_(),
59
+ nexamples_(0),
60
+ labelMetrics_(),
61
+ falseNegativeLabels_(falseNegativeLabels) {}
50
62
 
51
63
  void log(const std::vector<int32_t>& labels, const Predictions& predictions);
52
64
 
53
65
  double precision(int32_t);
54
66
  double recall(int32_t);
55
67
  double f1Score(int32_t);
68
+ std::vector<std::pair<real, real>> scoreVsTrue(int32_t labelId) const;
69
+ double precisionAtRecall(int32_t labelId, double recall) const;
70
+ double precisionAtRecall(double recall) const;
71
+ double recallAtPrecision(int32_t labelId, double recall) const;
72
+ double recallAtPrecision(double recall) const;
73
+ std::vector<std::pair<double, double>> precisionRecallCurve(
74
+ int32_t labelId) const;
75
+ std::vector<std::pair<double, double>> precisionRecallCurve() const;
56
76
  double precision() const;
57
77
  double recall() const;
78
+ double f1Score() const;
58
79
  uint64_t nexamples() const {
59
80
  return nexamples_;
60
81
  }
@@ -64,6 +85,7 @@ class Meter {
64
85
  Metrics metrics_{};
65
86
  uint64_t nexamples_;
66
87
  std::unordered_map<int32_t, Metrics> labelMetrics_;
88
+ bool falseNegativeLabels_;
67
89
  };
68
90
 
69
91
  } // namespace fasttext