fasttext 0.1.2 → 0.1.3

Sign up to get free protection for your applications and to get access to all the features.
@@ -11,18 +11,28 @@
11
11
  #include <istream>
12
12
  #include <ostream>
13
13
  #include <string>
14
+ #include <unordered_set>
14
15
  #include <vector>
15
16
 
16
17
  namespace fasttext {
17
18
 
18
19
  enum class model_name : int { cbow = 1, sg, sup };
19
20
  enum class loss_name : int { hs = 1, ns, softmax, ova };
21
+ enum class metric_name : int {
22
+ f1score = 1,
23
+ f1scoreLabel,
24
+ precisionAtRecall,
25
+ precisionAtRecallLabel,
26
+ recallAtPrecision,
27
+ recallAtPrecisionLabel
28
+ };
20
29
 
21
30
  class Args {
22
31
  protected:
23
- std::string lossToString(loss_name) const;
24
32
  std::string boolToString(bool) const;
25
33
  std::string modelToString(model_name) const;
34
+ std::string metricToString(metric_name) const;
35
+ std::unordered_set<std::string> manualArgs_;
26
36
 
27
37
  public:
28
38
  Args();
@@ -48,6 +58,7 @@ class Args {
48
58
  int verbose;
49
59
  std::string pretrainedVectors;
50
60
  bool saveOutput;
61
+ int seed;
51
62
 
52
63
  bool qout;
53
64
  bool retrain;
@@ -55,14 +66,31 @@ class Args {
55
66
  size_t cutoff;
56
67
  size_t dsub;
57
68
 
69
+ std::string autotuneValidationFile;
70
+ std::string autotuneMetric;
71
+ int autotunePredictions;
72
+ int autotuneDuration;
73
+ std::string autotuneModelSize;
74
+
58
75
  void parseArgs(const std::vector<std::string>& args);
59
76
  void printHelp();
60
77
  void printBasicHelp();
61
78
  void printDictionaryHelp();
62
79
  void printTrainingHelp();
80
+ void printAutotuneHelp();
63
81
  void printQuantizationHelp();
64
82
  void save(std::ostream&);
65
83
  void load(std::istream&);
66
84
  void dump(std::ostream&) const;
85
+ bool hasAutotune() const;
86
+ bool isManual(const std::string& argName) const;
87
+ void setManual(const std::string& argName);
88
+ std::string lossToString(loss_name) const;
89
+ metric_name getAutotuneMetric() const;
90
+ std::string getAutotuneMetricLabel() const;
91
+ double getAutotuneMetricValue() const;
92
+ int64_t getAutotuneModelSize() const;
93
+
94
+ static constexpr double kUnlimitedModelSize = -1.0;
67
95
  };
68
96
  } // namespace fasttext
@@ -0,0 +1,477 @@
1
+ /**
2
+ * Copyright (c) 2016-present, Facebook, Inc.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the MIT license found in the
6
+ * LICENSE file in the root directory of this source tree.
7
+ */
8
+
9
+ #include "autotune.h"
10
+
11
#include <algorithm>
#include <chrono>
#include <cmath>
#include <csignal>
#include <fstream>
#include <functional>
#include <iomanip>
#include <iostream>
#include <limits>
#include <random>
#include <stdexcept>
#include <thread>
18
+
19
// Debug logging helpers for the autotune loop.
//
// Both macros expand to a braced `if` statement, so call sites may omit the
// trailing semicolon. They implicitly read a variable named `autotuneArgs`
// that must be in scope where the macro is used. Output goes to std::cout
// only when autotuneArgs.verbose is at least 3.
#define LOG_VAL(name, val)                        \
  if (autotuneArgs.verbose >= 3) {                \
    std::cout << #name " = " << val << std::endl; \
  }

// Like LOG_VAL, but prints the literal "NaN" when `val` is NaN.
#define LOG_VAL_NAN(name, val)                      \
  if (autotuneArgs.verbose >= 3) {                  \
    if (!std::isnan(val)) {                         \
      std::cout << #name " = " << val << std::endl; \
    } else {                                        \
      std::cout << #name " = NaN" << std::endl;     \
    }                                               \
  }
31
+
32
+ namespace {
33
+
34
// Callback run when SIGINT is received during autotuning; assigned by
// Autotune::startTimer().
std::function<void()> interruptSignalHandler;

// C signal handler forwarding SIGINT to interruptSignalHandler.
//
// startTimer() installs this handler with std::signal() *before* it assigns
// interruptSignalHandler, so a SIGINT can arrive while the std::function is
// still empty. Calling an empty std::function throws std::bad_function_call,
// which must not escape a signal handler, so that case is ignored.
void signalHandler(int signal) {
  if (signal == SIGINT && interruptSignalHandler) {
    interruptSignalHandler();
  }
}
41
+
42
+ class ElapsedTimeMarker {
43
+ std::chrono::steady_clock::time_point start_;
44
+
45
+ public:
46
+ ElapsedTimeMarker() {
47
+ start_ = std::chrono::steady_clock::now();
48
+ }
49
+ double getElapsed() {
50
+ return fasttext::utils::getDuration(
51
+ start_, std::chrono::steady_clock::now());
52
+ }
53
+ };
54
+
55
+ } // namespace
56
+
57
+ namespace fasttext {
58
+
59
// Sentinel held in Autotune::bestScore_ until a trial produces a score.
constexpr double kUnknownBestScore{-1.0};
// Lower bound returned by Autotune::getCutoffForFileSize(). Autotune::quantize()
// treats a cutoff equal to it as "size constraint cannot be met", and also
// uses it as the output-row threshold for quantizing the output matrix.
constexpr int kCutoffLimit{256};
61
+
62
// Draws a Gaussian perturbation of `val`.
//
// The standard deviation anneals linearly from startSigma (for t <= 0.25)
// to endSigma (for t >= 0.75), where t is the fraction of the time budget
// already used. In linear mode the draw is added to val; otherwise val is
// multiplied by 2^draw. The result is cast back to T, so an integral T is
// truncated toward zero.
template <typename T>
T getArgGauss(
    T val,
    std::minstd_rand& rng,
    double startSigma,
    double endSigma,
    double t,
    bool linear) {
  const double ramp = std::min(0.5, std::max((t - 0.25), 0.0));
  const double stddev = startSigma - ((startSigma - endSigma) / 0.5) * ramp;

  std::normal_distribution<double> gaussian(0.0, stddev);
  const double draw = gaussian(rng);

  if (!linear) {
    return static_cast<T>(std::pow(2.0, draw) * val);
  }
  return static_cast<T>(draw + val);
}

// Perturbs `val` with getArgGauss() and clamps the result into [min, max]
// (the upper bound is applied first, then the lower bound).
template <typename T>
T updateArgGauss(
    T val,
    T min,
    T max,
    double startSigma,
    double endSigma,
    double t,
    bool linear,
    std::minstd_rand& rng) {
  T candidate = getArgGauss(val, rng, startSigma, endSigma, t, linear);
  candidate = (candidate > max) ? max : candidate;
  candidate = (candidate < min) ? min : candidate;
  return candidate;
}
110
+
111
+ AutotuneStrategy::AutotuneStrategy(
112
+ const Args& originalArgs,
113
+ std::minstd_rand::result_type seed)
114
+ : bestArgs_(),
115
+ maxDuration_(originalArgs.autotuneDuration),
116
+ rng_(seed),
117
+ trials_(0),
118
+ bestMinnIndex_(0),
119
+ bestDsubExponent_(1),
120
+ bestNonzeroBucket_(2000000),
121
+ originalBucket_(originalArgs.bucket) {
122
+ minnChoices_ = {0, 2, 3};
123
+ updateBest(originalArgs);
124
+ }
125
+
126
+ Args AutotuneStrategy::ask(double elapsed) {
127
+ const double t = std::min(1.0, elapsed / maxDuration_);
128
+ trials_++;
129
+
130
+ if (trials_ == 1) {
131
+ return bestArgs_;
132
+ }
133
+
134
+ Args args = bestArgs_;
135
+
136
+ if (!args.isManual("epoch")) {
137
+ args.epoch = updateArgGauss(args.epoch, 1, 100, 2.8, 2.5, t, false, rng_);
138
+ }
139
+ if (!args.isManual("lr")) {
140
+ args.lr = updateArgGauss(args.lr, 0.01, 5.0, 1.9, 1.0, t, false, rng_);
141
+ };
142
+ if (!args.isManual("dim")) {
143
+ args.dim = updateArgGauss(args.dim, 1, 1000, 1.4, 0.3, t, false, rng_);
144
+ }
145
+ if (!args.isManual("wordNgrams")) {
146
+ args.wordNgrams =
147
+ updateArgGauss(args.wordNgrams, 1, 5, 4.3, 2.4, t, true, rng_);
148
+ }
149
+ if (!args.isManual("dsub")) {
150
+ int dsubExponent =
151
+ updateArgGauss(bestDsubExponent_, 1, 4, 2.0, 1.0, t, true, rng_);
152
+ args.dsub = (1 << dsubExponent);
153
+ }
154
+ if (!args.isManual("minn")) {
155
+ int minnIndex = updateArgGauss(
156
+ bestMinnIndex_,
157
+ 0,
158
+ static_cast<int>(minnChoices_.size() - 1),
159
+ 4.0,
160
+ 1.4,
161
+ t,
162
+ true,
163
+ rng_);
164
+ args.minn = minnChoices_[minnIndex];
165
+ }
166
+ if (!args.isManual("maxn")) {
167
+ if (args.minn == 0) {
168
+ args.maxn = 0;
169
+ } else {
170
+ args.maxn = args.minn + 3;
171
+ }
172
+ }
173
+ if (!args.isManual("bucket")) {
174
+ int nonZeroBucket = updateArgGauss(
175
+ bestNonzeroBucket_, 10000, 10000000, 2.0, 1.5, t, false, rng_);
176
+ args.bucket = nonZeroBucket;
177
+ } else {
178
+ args.bucket = originalBucket_;
179
+ }
180
+ if (args.wordNgrams <= 1 && args.maxn == 0) {
181
+ args.bucket = 0;
182
+ }
183
+ if (!args.isManual("loss")) {
184
+ args.loss = loss_name::softmax;
185
+ }
186
+
187
+ return args;
188
+ }
189
+
190
+ int AutotuneStrategy::getIndex(int val, const std::vector<int>& choices) {
191
+ auto found = std::find(choices.begin(), choices.end(), val);
192
+ int ind = 0;
193
+ if (found != choices.end()) {
194
+ ind = std::distance(choices.begin(), found);
195
+ }
196
+ return ind;
197
+ }
198
+
199
+ void AutotuneStrategy::updateBest(const Args& args) {
200
+ bestArgs_ = args;
201
+ bestMinnIndex_ = getIndex(args.minn, minnChoices_);
202
+ bestDsubExponent_ = log2(args.dsub);
203
+ if (args.bucket != 0) {
204
+ bestNonzeroBucket_ = args.bucket;
205
+ }
206
+ }
207
+
208
+ Autotune::Autotune(const std::shared_ptr<FastText>& fastText)
209
+ : fastText_(fastText),
210
+ elapsed_(0.),
211
+ bestScore_(0.),
212
+ trials_(0),
213
+ sizeConstraintFailed_(0),
214
+ continueTraining_(false),
215
+ strategy_(),
216
+ timer_() {}
217
+
218
+ void Autotune::printInfo(double maxDuration) {
219
+ double progress = elapsed_ * 100 / maxDuration;
220
+ progress = std::min(progress, 100.0);
221
+
222
+ std::cerr << "\r";
223
+ std::cerr << std::fixed;
224
+ std::cerr << "Progress: ";
225
+ std::cerr << std::setprecision(1) << std::setw(5) << progress << "%";
226
+ std::cerr << " Trials: " << std::setw(4) << trials_;
227
+ std::cerr << " Best score: " << std::setw(9) << std::setprecision(6);
228
+ if (bestScore_ == kUnknownBestScore) {
229
+ std::cerr << "unknown";
230
+ } else {
231
+ std::cerr << bestScore_;
232
+ }
233
+ std::cerr << " ETA: "
234
+ << utils::ClockPrint(std::max(maxDuration - elapsed_, 0.0));
235
+ std::cerr << std::flush;
236
+ }
237
+
238
+ void Autotune::timer(
239
+ const std::chrono::steady_clock::time_point& start,
240
+ double maxDuration) {
241
+ elapsed_ = 0.0;
242
+ while (keepTraining(maxDuration)) {
243
+ std::this_thread::sleep_for(std::chrono::milliseconds(500));
244
+ elapsed_ = utils::getDuration(start, std::chrono::steady_clock::now());
245
+ printInfo(maxDuration);
246
+ }
247
+ abort();
248
+ }
249
+
250
+ bool Autotune::keepTraining(double maxDuration) const {
251
+ return continueTraining_ && elapsed_ < maxDuration;
252
+ }
253
+
254
+ void Autotune::abort() {
255
+ if (continueTraining_) {
256
+ continueTraining_ = false;
257
+ fastText_->abort();
258
+ }
259
+ }
260
+
261
+ void Autotune::startTimer(const Args& args) {
262
+ std::chrono::steady_clock::time_point start =
263
+ std::chrono::steady_clock::now();
264
+ timer_ = std::thread([=]() { timer(start, args.autotuneDuration); });
265
+ bestScore_ = kUnknownBestScore;
266
+ trials_ = 0;
267
+ continueTraining_ = true;
268
+
269
+ auto previousSignalHandler = std::signal(SIGINT, signalHandler);
270
+ interruptSignalHandler = [&]() {
271
+ std::signal(SIGINT, previousSignalHandler);
272
+ std::cerr << std::endl << "Aborting autotune..." << std::endl;
273
+ abort();
274
+ };
275
+ }
276
+
277
+ double Autotune::getMetricScore(
278
+ Meter& meter,
279
+ const metric_name& metricName,
280
+ const double metricValue,
281
+ const std::string& metricLabel) const {
282
+ double score = 0.0;
283
+ int32_t labelId = -1;
284
+ if (!metricLabel.empty()) {
285
+ labelId = fastText_->getLabelId(metricLabel);
286
+ if (labelId == -1) {
287
+ throw std::runtime_error("Unknown autotune metric label");
288
+ }
289
+ }
290
+ if (metricName == metric_name::f1score) {
291
+ score = meter.f1Score();
292
+ } else if (metricName == metric_name::f1scoreLabel) {
293
+ score = meter.f1Score(labelId);
294
+ } else if (metricName == metric_name::precisionAtRecall) {
295
+ score = meter.precisionAtRecall(metricValue);
296
+ } else if (metricName == metric_name::precisionAtRecallLabel) {
297
+ score = meter.precisionAtRecall(labelId, metricValue);
298
+ } else if (metricName == metric_name::recallAtPrecision) {
299
+ score = meter.recallAtPrecision(metricValue);
300
+ } else if (metricName == metric_name::recallAtPrecisionLabel) {
301
+ score = meter.recallAtPrecision(labelId, metricValue);
302
+ } else {
303
+ throw std::runtime_error("Unknown metric");
304
+ }
305
+ return score;
306
+ }
307
+
308
+ void Autotune::printArgs(const Args& args, const Args& autotuneArgs) {
309
+ LOG_VAL(epoch, args.epoch)
310
+ LOG_VAL(lr, args.lr)
311
+ LOG_VAL(dim, args.dim)
312
+ LOG_VAL(minCount, args.minCount)
313
+ LOG_VAL(wordNgrams, args.wordNgrams)
314
+ LOG_VAL(minn, args.minn)
315
+ LOG_VAL(maxn, args.maxn)
316
+ LOG_VAL(bucket, args.bucket)
317
+ LOG_VAL(dsub, args.dsub)
318
+ LOG_VAL(loss, args.lossToString(args.loss))
319
+ }
320
+
321
+ int Autotune::getCutoffForFileSize(
322
+ bool qout,
323
+ bool qnorm,
324
+ int dsub,
325
+ int64_t fileSize) const {
326
+ int64_t outModelSize = 0;
327
+ const int64_t outM = fastText_->getOutputMatrix()->size(0);
328
+ const int64_t outN = fastText_->getOutputMatrix()->size(1);
329
+ if (qout) {
330
+ const int64_t outputPqSize = 16 + 4 * (outN * (1 << 8));
331
+ outModelSize =
332
+ 21 + (outM * ((outN + 2 - 1) / 2)) + outputPqSize + (qnorm ? outM : 0);
333
+ } else {
334
+ outModelSize = 16 + 4 * (outM * outN);
335
+ }
336
+ const int64_t dim = fastText_->getInputMatrix()->size(1);
337
+
338
+ int target = (fileSize - (107) - 4 * (1 << 8) * dim - outModelSize);
339
+ int cutoff = target / ((dim + dsub - 1) / dsub + (qnorm ? 1 : 0) + 10);
340
+
341
+ return std::max(cutoff, kCutoffLimit);
342
+ }
343
+
344
+ bool Autotune::quantize(Args& args, const Args& autotuneArgs) {
345
+ if (autotuneArgs.getAutotuneModelSize() == Args::kUnlimitedModelSize) {
346
+ return true;
347
+ }
348
+ auto outputSize = fastText_->getOutputMatrix()->size(0);
349
+
350
+ args.qnorm = true;
351
+ args.qout = (outputSize >= kCutoffLimit);
352
+ args.retrain = true;
353
+ args.cutoff = getCutoffForFileSize(
354
+ args.qout, args.qnorm, args.dsub, autotuneArgs.getAutotuneModelSize());
355
+ LOG_VAL(cutoff, args.cutoff);
356
+ if (args.cutoff == kCutoffLimit) {
357
+ return false;
358
+ }
359
+ fastText_->quantize(args);
360
+
361
+ return true;
362
+ }
363
+
364
+ void Autotune::printSkippedArgs(const Args& autotuneArgs) {
365
+ std::unordered_set<std::string> argsToCheck = {"epoch",
366
+ "lr",
367
+ "dim",
368
+ "wordNgrams",
369
+ "loss",
370
+ "bucket",
371
+ "minn",
372
+ "maxn",
373
+ "dsub"};
374
+ for (const auto& arg : argsToCheck) {
375
+ if (autotuneArgs.isManual(arg)) {
376
+ std::cerr << "Warning : " << arg
377
+ << " is manually set to a specific value. "
378
+ << "It will not be automatically optimized." << std::endl;
379
+ }
380
+ }
381
+ }
382
+
383
+ void Autotune::train(const Args& autotuneArgs) {
384
+ std::ifstream validationFileStream(autotuneArgs.autotuneValidationFile);
385
+ if (!validationFileStream.is_open()) {
386
+ throw std::invalid_argument("Validation file cannot be opened!");
387
+ }
388
+ printSkippedArgs(autotuneArgs);
389
+
390
+ bool sizeConstraintWarning = false;
391
+ int verbose = autotuneArgs.verbose;
392
+ Args bestTrainArgs(autotuneArgs);
393
+ Args trainArgs(autotuneArgs);
394
+ trainArgs.verbose = 0;
395
+ strategy_ = std::unique_ptr<AutotuneStrategy>(
396
+ new AutotuneStrategy(trainArgs, autotuneArgs.seed));
397
+ startTimer(autotuneArgs);
398
+
399
+ while (keepTraining(autotuneArgs.autotuneDuration)) {
400
+ trials_++;
401
+
402
+ trainArgs = strategy_->ask(elapsed_);
403
+ LOG_VAL(Trial, trials_)
404
+ printArgs(trainArgs, autotuneArgs);
405
+ ElapsedTimeMarker elapsedTimeMarker;
406
+ double currentScore = std::numeric_limits<double>::quiet_NaN();
407
+ try {
408
+ fastText_->train(trainArgs);
409
+ bool sizeConstraintOK = quantize(trainArgs, autotuneArgs);
410
+ if (sizeConstraintOK) {
411
+ const auto& metricLabel = autotuneArgs.getAutotuneMetricLabel();
412
+ Meter meter(!metricLabel.empty());
413
+ fastText_->test(
414
+ validationFileStream, autotuneArgs.autotunePredictions, 0.0, meter);
415
+
416
+ currentScore = getMetricScore(
417
+ meter,
418
+ autotuneArgs.getAutotuneMetric(),
419
+ autotuneArgs.getAutotuneMetricValue(),
420
+ metricLabel);
421
+
422
+ if (bestScore_ == kUnknownBestScore || (currentScore > bestScore_)) {
423
+ bestTrainArgs = trainArgs;
424
+ bestScore_ = currentScore;
425
+ strategy_->updateBest(bestTrainArgs);
426
+ }
427
+ } else {
428
+ sizeConstraintFailed_++;
429
+ if (!sizeConstraintWarning && trials_ > 10 &&
430
+ sizeConstraintFailed_ > (trials_ / 2)) {
431
+ sizeConstraintWarning = true;
432
+ std::cerr << std::endl
433
+ << "Warning : requested model size is probably too small. "
434
+ "You may want to increase `autotune-modelsize`."
435
+ << std::endl;
436
+ }
437
+ }
438
+ } catch (DenseMatrix::EncounteredNaNError&) {
439
+ // ignore diverging loss and go on
440
+ } catch (std::bad_alloc&) {
441
+ // ignore parameter samples asking too much memory
442
+ } catch (TimeoutError&) {
443
+ break;
444
+ } catch (FastText::AbortError&) {
445
+ break;
446
+ }
447
+ LOG_VAL_NAN(currentScore, currentScore)
448
+ LOG_VAL(train took, elapsedTimeMarker.getElapsed())
449
+ }
450
+ if (timer_.joinable()) {
451
+ timer_.join();
452
+ }
453
+
454
+ if (bestScore_ == kUnknownBestScore) {
455
+ std::string errorMessage;
456
+ if (sizeConstraintWarning) {
457
+ errorMessage =
458
+ "Couldn't fulfil model size constraint: please increase "
459
+ "`autotune-modelsize`.";
460
+ } else {
461
+ errorMessage =
462
+ "Didn't have enough time to train once: please increase "
463
+ "`autotune-duration`.";
464
+ }
465
+ throw std::runtime_error(errorMessage);
466
+ } else {
467
+ std::cerr << std::endl;
468
+ std::cerr << "Training again with best arguments" << std::endl;
469
+ bestTrainArgs.verbose = verbose;
470
+ LOG_VAL(Best selected args, 0)
471
+ printArgs(bestTrainArgs, autotuneArgs);
472
+ fastText_->train(bestTrainArgs);
473
+ quantize(bestTrainArgs, autotuneArgs);
474
+ }
475
+ }
476
+
477
+ } // namespace fasttext