fasttext 0.1.2 → 0.2.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -47,7 +47,8 @@ std::shared_ptr<Loss> FastText::createLoss(std::shared_ptr<Matrix>& output) {
47
47
  }
48
48
  }
49
49
 
50
- FastText::FastText() : quant_(false), wordVectors_(nullptr) {}
50
+ FastText::FastText()
51
+ : quant_(false), wordVectors_(nullptr), trainException_(nullptr) {}
51
52
 
52
53
  void FastText::addInputVector(Vector& vec, int32_t ind) const {
53
54
  vec.addRow(*input_, ind);
@@ -69,6 +70,19 @@ std::shared_ptr<const DenseMatrix> FastText::getInputMatrix() const {
69
70
  return std::dynamic_pointer_cast<DenseMatrix>(input_);
70
71
  }
71
72
 
73
+ void FastText::setMatrices(
74
+ const std::shared_ptr<DenseMatrix>& inputMatrix,
75
+ const std::shared_ptr<DenseMatrix>& outputMatrix) {
76
+ assert(input_->size(1) == output_->size(1));
77
+
78
+ input_ = std::dynamic_pointer_cast<Matrix>(inputMatrix);
79
+ output_ = std::dynamic_pointer_cast<Matrix>(outputMatrix);
80
+ wordVectors_.reset();
81
+ args_->dim = input_->size(1);
82
+
83
+ buildModel();
84
+ }
85
+
72
86
  std::shared_ptr<const DenseMatrix> FastText::getOutputMatrix() const {
73
87
  if (quant_ && args_->qout) {
74
88
  throw std::runtime_error("Can't export quantized matrix");
@@ -86,6 +100,14 @@ int32_t FastText::getSubwordId(const std::string& subword) const {
86
100
  return dict_->nwords() + h;
87
101
  }
88
102
 
103
+ int32_t FastText::getLabelId(const std::string& label) const {
104
+ int32_t labelId = dict_->getId(label);
105
+ if (labelId != -1) {
106
+ labelId -= dict_->nwords();
107
+ }
108
+ return labelId;
109
+ }
110
+
89
111
  void FastText::getWordVector(Vector& vec, const std::string& word) const {
90
112
  const std::vector<int32_t>& ngrams = dict_->getSubwords(word);
91
113
  vec.zero();
@@ -97,10 +119,6 @@ void FastText::getWordVector(Vector& vec, const std::string& word) const {
97
119
  }
98
120
  }
99
121
 
100
- void FastText::getVector(Vector& vec, const std::string& word) const {
101
- getWordVector(vec, word);
102
- }
103
-
104
122
  void FastText::getSubwordVector(Vector& vec, const std::string& subword) const {
105
123
  vec.zero();
106
124
  int32_t h = dict_->hash(subword) % args_->bucket;
@@ -109,6 +127,9 @@ void FastText::getSubwordVector(Vector& vec, const std::string& subword) const {
109
127
  }
110
128
 
111
129
  void FastText::saveVectors(const std::string& filename) {
130
+ if (!input_ || !output_) {
131
+ throw std::runtime_error("Model never trained");
132
+ }
112
133
  std::ofstream ofs(filename);
113
134
  if (!ofs.is_open()) {
114
135
  throw std::invalid_argument(
@@ -124,10 +145,6 @@ void FastText::saveVectors(const std::string& filename) {
124
145
  ofs.close();
125
146
  }
126
147
 
127
- void FastText::saveVectors() {
128
- saveVectors(args_->output + ".vec");
129
- }
130
-
131
148
  void FastText::saveOutput(const std::string& filename) {
132
149
  std::ofstream ofs(filename);
133
150
  if (!ofs.is_open()) {
@@ -152,10 +169,6 @@ void FastText::saveOutput(const std::string& filename) {
152
169
  ofs.close();
153
170
  }
154
171
 
155
- void FastText::saveOutput() {
156
- saveOutput(args_->output + ".output");
157
- }
158
-
159
172
  bool FastText::checkModel(std::istream& in) {
160
173
  int32_t magic;
161
174
  in.read((char*)&(magic), sizeof(int32_t));
@@ -176,21 +189,14 @@ void FastText::signModel(std::ostream& out) {
176
189
  out.write((char*)&(version), sizeof(int32_t));
177
190
  }
178
191
 
179
- void FastText::saveModel() {
180
- std::string fn(args_->output);
181
- if (quant_) {
182
- fn += ".ftz";
183
- } else {
184
- fn += ".bin";
185
- }
186
- saveModel(fn);
187
- }
188
-
189
192
  void FastText::saveModel(const std::string& filename) {
190
193
  std::ofstream ofs(filename, std::ofstream::binary);
191
194
  if (!ofs.is_open()) {
192
195
  throw std::invalid_argument(filename + " cannot be opened for saving!");
193
196
  }
197
+ if (!input_ || !output_) {
198
+ throw std::runtime_error("Model never trained");
199
+ }
194
200
  signModel(ofs);
195
201
  args_->save(ofs);
196
202
  dict_->save(ofs);
@@ -224,6 +230,12 @@ std::vector<int64_t> FastText::getTargetCounts() const {
224
230
  }
225
231
  }
226
232
 
233
+ void FastText::buildModel() {
234
+ auto loss = createLoss(output_);
235
+ bool normalizeGradient = (args_->model == model_name::sup);
236
+ model_ = std::make_shared<Model>(input_, output_, loss, normalizeGradient);
237
+ }
238
+
227
239
  void FastText::loadModel(std::istream& in) {
228
240
  args_ = std::make_shared<Args>();
229
241
  input_ = std::make_shared<DenseMatrix>();
@@ -256,37 +268,37 @@ void FastText::loadModel(std::istream& in) {
256
268
  }
257
269
  output_->load(in);
258
270
 
259
- auto loss = createLoss(output_);
260
- bool normalizeGradient = (args_->model == model_name::sup);
261
- model_ = std::make_shared<Model>(input_, output_, loss, normalizeGradient);
271
+ buildModel();
262
272
  }
263
273
 
264
- void FastText::printInfo(real progress, real loss, std::ostream& log_stream) {
265
- std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
266
- double t =
267
- std::chrono::duration_cast<std::chrono::duration<double>>(end - start_)
268
- .count();
274
+ std::tuple<int64_t, double, double> FastText::progressInfo(real progress) {
275
+ double t = utils::getDuration(start_, std::chrono::steady_clock::now());
269
276
  double lr = args_->lr * (1.0 - progress);
270
277
  double wst = 0;
271
278
 
272
279
  int64_t eta = 2592000; // Default to one month in seconds (720 * 3600)
273
280
 
274
281
  if (progress > 0 && t >= 0) {
275
- progress = progress * 100;
276
- eta = t * (100 - progress) / progress;
282
+ eta = t * (1 - progress) / progress;
277
283
  wst = double(tokenCount_) / t / args_->thread;
278
284
  }
279
- int32_t etah = eta / 3600;
280
- int32_t etam = (eta % 3600) / 60;
285
+
286
+ return std::tuple<double, double, int64_t>(wst, lr, eta);
287
+ }
288
+
289
+ void FastText::printInfo(real progress, real loss, std::ostream& log_stream) {
290
+ double wst;
291
+ double lr;
292
+ int64_t eta;
293
+ std::tie<double, double, int64_t>(wst, lr, eta) = progressInfo(progress);
281
294
 
282
295
  log_stream << std::fixed;
283
296
  log_stream << "Progress: ";
284
- log_stream << std::setprecision(1) << std::setw(5) << progress << "%";
297
+ log_stream << std::setprecision(1) << std::setw(5) << (progress * 100) << "%";
285
298
  log_stream << " words/sec/thread: " << std::setw(7) << int64_t(wst);
286
299
  log_stream << " lr: " << std::setw(9) << std::setprecision(6) << lr;
287
- log_stream << " loss: " << std::setw(9) << std::setprecision(6) << loss;
288
- log_stream << " ETA: " << std::setw(3) << etah;
289
- log_stream << "h" << std::setw(2) << etam << "m";
300
+ log_stream << " avg.loss: " << std::setw(9) << std::setprecision(6) << loss;
301
+ log_stream << " ETA: " << utils::ClockPrint(eta);
290
302
  log_stream << std::flush;
291
303
  }
292
304
 
@@ -299,13 +311,16 @@ std::vector<int32_t> FastText::selectEmbeddings(int32_t cutoff) const {
299
311
  std::iota(idx.begin(), idx.end(), 0);
300
312
  auto eosid = dict_->getId(Dictionary::EOS);
301
313
  std::sort(idx.begin(), idx.end(), [&norms, eosid](size_t i1, size_t i2) {
314
+ if (i1 == eosid && i2 == eosid) { // satisfy strict weak ordering
315
+ return false;
316
+ }
302
317
  return eosid == i1 || (eosid != i2 && norms[i1] > norms[i2]);
303
318
  });
304
319
  idx.erase(idx.begin() + cutoff, idx.end());
305
320
  return idx;
306
321
  }
307
322
 
308
- void FastText::quantize(const Args& qargs) {
323
+ void FastText::quantize(const Args& qargs, const TrainCallback& callback) {
309
324
  if (args_->model != model_name::sup) {
310
325
  throw std::invalid_argument(
311
326
  "For now we only support quantization of supervised models");
@@ -337,10 +352,9 @@ void FastText::quantize(const Args& qargs) {
337
352
  args_->verbose = qargs.verbose;
338
353
  auto loss = createLoss(output_);
339
354
  model_ = std::make_shared<Model>(input, output, loss, normalizeGradient);
340
- startThreads();
355
+ startThreads(callback);
341
356
  }
342
357
  }
343
-
344
358
  input_ = std::make_shared<QuantMatrix>(
345
359
  std::move(*(input.get())), qargs.dsub, qargs.qnorm);
346
360
 
@@ -348,7 +362,6 @@ void FastText::quantize(const Args& qargs) {
348
362
  output_ = std::make_shared<QuantMatrix>(
349
363
  std::move(*(output.get())), 2, qargs.qnorm);
350
364
  }
351
-
352
365
  quant_ = true;
353
366
  auto loss = createLoss(output_);
354
367
  model_ = std::make_shared<Model>(input_, output_, loss, normalizeGradient);
@@ -408,7 +421,7 @@ void FastText::skipgram(
408
421
 
409
422
  std::tuple<int64_t, double, double>
410
423
  FastText::test(std::istream& in, int32_t k, real threshold) {
411
- Meter meter;
424
+ Meter meter(false);
412
425
  test(in, k, threshold, meter);
413
426
 
414
427
  return std::tuple<int64_t, double, double>(
@@ -420,6 +433,9 @@ void FastText::test(std::istream& in, int32_t k, real threshold, Meter& meter)
420
433
  std::vector<int32_t> line;
421
434
  std::vector<int32_t> labels;
422
435
  Predictions predictions;
436
+ Model::State state(args_->dim, dict_->nlabels(), 0);
437
+ in.clear();
438
+ in.seekg(0, std::ios_base::beg);
423
439
 
424
440
  while (in.peek() != EOF) {
425
441
  line.clear();
@@ -521,16 +537,6 @@ std::vector<std::pair<std::string, Vector>> FastText::getNgramVectors(
521
537
  return result;
522
538
  }
523
539
 
524
- // deprecated. use getNgramVectors instead
525
- void FastText::ngramVectors(std::string word) {
526
- std::vector<std::pair<std::string, Vector>> ngramVectors =
527
- getNgramVectors(word);
528
-
529
- for (const auto& ngramVector : ngramVectors) {
530
- std::cout << ngramVector.first << " " << ngramVector.second << std::endl;
531
- }
532
- }
533
-
534
540
  void FastText::precomputeWordVectors(DenseMatrix& wordVectors) {
535
541
  Vector vec(args_->dim);
536
542
  wordVectors.zero();
@@ -598,17 +604,6 @@ std::vector<std::pair<real, std::string>> FastText::getNN(
598
604
  return heap;
599
605
  }
600
606
 
601
- // depracted. use getNN instead
602
- void FastText::findNN(
603
- const DenseMatrix& wordVectors,
604
- const Vector& query,
605
- int32_t k,
606
- const std::set<std::string>& banSet,
607
- std::vector<std::pair<real, std::string>>& results) {
608
- results.clear();
609
- results = getNN(wordVectors, query, k, banSet);
610
- }
611
-
612
607
  std::vector<std::pair<real, std::string>> FastText::getAnalogies(
613
608
  int32_t k,
614
609
  const std::string& wordA,
@@ -630,52 +625,52 @@ std::vector<std::pair<real, std::string>> FastText::getAnalogies(
630
625
  return getNN(*wordVectors_, query, k, {wordA, wordB, wordC});
631
626
  }
632
627
 
633
- // depreacted, use getAnalogies instead
634
- void FastText::analogies(int32_t k) {
635
- std::string prompt("Query triplet (A - B + C)? ");
636
- std::string wordA, wordB, wordC;
637
- std::cout << prompt;
638
- while (true) {
639
- std::cin >> wordA;
640
- std::cin >> wordB;
641
- std::cin >> wordC;
642
- auto results = getAnalogies(k, wordA, wordB, wordC);
643
-
644
- for (auto& pair : results) {
645
- std::cout << pair.second << " " << pair.first << std::endl;
646
- }
647
- std::cout << prompt;
648
- }
628
+ bool FastText::keepTraining(const int64_t ntokens) const {
629
+ return tokenCount_ < args_->epoch * ntokens && !trainException_;
649
630
  }
650
631
 
651
- void FastText::trainThread(int32_t threadId) {
632
+ void FastText::trainThread(int32_t threadId, const TrainCallback& callback) {
652
633
  std::ifstream ifs(args_->input);
653
634
  utils::seek(ifs, threadId * utils::size(ifs) / args_->thread);
654
635
 
655
- Model::State state(args_->dim, output_->size(0), threadId);
636
+ Model::State state(args_->dim, output_->size(0), threadId + args_->seed);
656
637
 
657
638
  const int64_t ntokens = dict_->ntokens();
658
639
  int64_t localTokenCount = 0;
659
640
  std::vector<int32_t> line, labels;
660
- while (tokenCount_ < args_->epoch * ntokens) {
661
- real progress = real(tokenCount_) / (args_->epoch * ntokens);
662
- real lr = args_->lr * (1.0 - progress);
663
- if (args_->model == model_name::sup) {
664
- localTokenCount += dict_->getLine(ifs, line, labels);
665
- supervised(state, lr, line, labels);
666
- } else if (args_->model == model_name::cbow) {
667
- localTokenCount += dict_->getLine(ifs, line, state.rng);
668
- cbow(state, lr, line);
669
- } else if (args_->model == model_name::sg) {
670
- localTokenCount += dict_->getLine(ifs, line, state.rng);
671
- skipgram(state, lr, line);
672
- }
673
- if (localTokenCount > args_->lrUpdateRate) {
674
- tokenCount_ += localTokenCount;
675
- localTokenCount = 0;
676
- if (threadId == 0 && args_->verbose > 1)
677
- loss_ = state.getLoss();
641
+ uint64_t callbackCounter = 0;
642
+ try {
643
+ while (keepTraining(ntokens)) {
644
+ real progress = real(tokenCount_) / (args_->epoch * ntokens);
645
+ if (callback && ((callbackCounter++ % 64) == 0)) {
646
+ double wst;
647
+ double lr;
648
+ int64_t eta;
649
+ std::tie<double, double, int64_t>(wst, lr, eta) =
650
+ progressInfo(progress);
651
+ callback(progress, loss_, wst, lr, eta);
652
+ }
653
+ real lr = args_->lr * (1.0 - progress);
654
+ if (args_->model == model_name::sup) {
655
+ localTokenCount += dict_->getLine(ifs, line, labels);
656
+ supervised(state, lr, line, labels);
657
+ } else if (args_->model == model_name::cbow) {
658
+ localTokenCount += dict_->getLine(ifs, line, state.rng);
659
+ cbow(state, lr, line);
660
+ } else if (args_->model == model_name::sg) {
661
+ localTokenCount += dict_->getLine(ifs, line, state.rng);
662
+ skipgram(state, lr, line);
663
+ }
664
+ if (localTokenCount > args_->lrUpdateRate) {
665
+ tokenCount_ += localTokenCount;
666
+ localTokenCount = 0;
667
+ if (threadId == 0 && args_->verbose > 1) {
668
+ loss_ = state.getLoss();
669
+ }
670
+ }
678
671
  }
672
+ } catch (DenseMatrix::EncounteredNaNError&) {
673
+ trainException_ = std::current_exception();
679
674
  }
680
675
  if (threadId == 0)
681
676
  loss_ = state.getLoss();
@@ -713,7 +708,7 @@ std::shared_ptr<Matrix> FastText::getInputMatrixFromFile(
713
708
  dict_->init();
714
709
  std::shared_ptr<DenseMatrix> input = std::make_shared<DenseMatrix>(
715
710
  dict_->nwords() + args_->bucket, args_->dim);
716
- input->uniform(1.0 / args_->dim);
711
+ input->uniform(1.0 / args_->dim, args_->thread, args_->seed);
717
712
 
718
713
  for (size_t i = 0; i < n; i++) {
719
714
  int32_t idx = dict_->getId(words[i]);
@@ -727,14 +722,10 @@ std::shared_ptr<Matrix> FastText::getInputMatrixFromFile(
727
722
  return input;
728
723
  }
729
724
 
730
- void FastText::loadVectors(const std::string& filename) {
731
- input_ = getInputMatrixFromFile(filename);
732
- }
733
-
734
725
  std::shared_ptr<Matrix> FastText::createRandomMatrix() const {
735
726
  std::shared_ptr<DenseMatrix> input = std::make_shared<DenseMatrix>(
736
727
  dict_->nwords() + args_->bucket, args_->dim);
737
- input->uniform(1.0 / args_->dim);
728
+ input->uniform(1.0 / args_->dim, args_->thread, args_->seed);
738
729
 
739
730
  return input;
740
731
  }
@@ -749,7 +740,7 @@ std::shared_ptr<Matrix> FastText::createTrainOutputMatrix() const {
749
740
  return output;
750
741
  }
751
742
 
752
- void FastText::train(const Args& args) {
743
+ void FastText::train(const Args& args, const TrainCallback& callback) {
753
744
  args_ = std::make_shared<Args>(args);
754
745
  dict_ = std::make_shared<Dictionary>(args_);
755
746
  if (args_->input == "-") {
@@ -770,23 +761,38 @@ void FastText::train(const Args& args) {
770
761
  input_ = createRandomMatrix();
771
762
  }
772
763
  output_ = createTrainOutputMatrix();
764
+ quant_ = false;
773
765
  auto loss = createLoss(output_);
774
766
  bool normalizeGradient = (args_->model == model_name::sup);
775
767
  model_ = std::make_shared<Model>(input_, output_, loss, normalizeGradient);
776
- startThreads();
768
+ startThreads(callback);
769
+ }
770
+
771
+ void FastText::abort() {
772
+ try {
773
+ throw AbortError();
774
+ } catch (AbortError&) {
775
+ trainException_ = std::current_exception();
776
+ }
777
777
  }
778
778
 
779
- void FastText::startThreads() {
779
+ void FastText::startThreads(const TrainCallback& callback) {
780
780
  start_ = std::chrono::steady_clock::now();
781
781
  tokenCount_ = 0;
782
782
  loss_ = -1;
783
+ trainException_ = nullptr;
783
784
  std::vector<std::thread> threads;
784
- for (int32_t i = 0; i < args_->thread; i++) {
785
- threads.push_back(std::thread([=]() { trainThread(i); }));
785
+ if (args_->thread > 1) {
786
+ for (int32_t i = 0; i < args_->thread; i++) {
787
+ threads.push_back(std::thread([=]() { trainThread(i, callback); }));
788
+ }
789
+ } else {
790
+ // webassembly can't instantiate `std::thread`
791
+ trainThread(0, callback);
786
792
  }
787
793
  const int64_t ntokens = dict_->ntokens();
788
794
  // Same condition as trainThread
789
- while (tokenCount_ < args_->epoch * ntokens) {
795
+ while (keepTraining(ntokens)) {
790
796
  std::this_thread::sleep_for(std::chrono::milliseconds(100));
791
797
  if (loss_ >= 0 && args_->verbose > 1) {
792
798
  real progress = real(tokenCount_) / (args_->epoch * ntokens);
@@ -794,9 +800,14 @@ void FastText::startThreads() {
794
800
  printInfo(progress, loss_, std::cerr);
795
801
  }
796
802
  }
797
- for (int32_t i = 0; i < args_->thread; i++) {
803
+ for (int32_t i = 0; i < threads.size(); i++) {
798
804
  threads[i].join();
799
805
  }
806
+ if (trainException_) {
807
+ std::exception_ptr exception = trainException_;
808
+ trainException_ = nullptr;
809
+ std::rethrow_exception(exception);
810
+ }
800
811
  if (args_->verbose > 0) {
801
812
  std::cerr << "\r";
802
813
  printInfo(1.0, loss_, std::cerr);
@@ -12,6 +12,7 @@
12
12
 
13
13
  #include <atomic>
14
14
  #include <chrono>
15
+ #include <functional>
15
16
  #include <iostream>
16
17
  #include <memory>
17
18
  #include <queue>
@@ -31,24 +32,29 @@
31
32
  namespace fasttext {
32
33
 
33
34
  class FastText {
35
+ public:
36
+ using TrainCallback =
37
+ std::function<void(float, float, double, double, int64_t)>;
38
+
34
39
  protected:
35
40
  std::shared_ptr<Args> args_;
36
41
  std::shared_ptr<Dictionary> dict_;
37
-
38
42
  std::shared_ptr<Matrix> input_;
39
43
  std::shared_ptr<Matrix> output_;
40
-
41
44
  std::shared_ptr<Model> model_;
42
-
43
45
  std::atomic<int64_t> tokenCount_{};
44
46
  std::atomic<real> loss_{};
45
-
46
47
  std::chrono::steady_clock::time_point start_;
48
+ bool quant_;
49
+ int32_t version;
50
+ std::unique_ptr<DenseMatrix> wordVectors_;
51
+ std::exception_ptr trainException_;
52
+
47
53
  void signModel(std::ostream&);
48
54
  bool checkModel(std::istream&);
49
- void startThreads();
55
+ void startThreads(const TrainCallback& callback = {});
50
56
  void addInputVector(Vector&, int32_t) const;
51
- void trainThread(int32_t);
57
+ void trainThread(int32_t, const TrainCallback& callback);
52
58
  std::vector<std::pair<real, std::string>> getNN(
53
59
  const DenseMatrix& wordVectors,
54
60
  const Vector& queryVec,
@@ -68,10 +74,11 @@ class FastText {
68
74
  const std::vector<int32_t>& labels);
69
75
  void cbow(Model::State& state, real lr, const std::vector<int32_t>& line);
70
76
  void skipgram(Model::State& state, real lr, const std::vector<int32_t>& line);
71
-
72
- bool quant_;
73
- int32_t version;
74
- std::unique_ptr<DenseMatrix> wordVectors_;
77
+ std::vector<int32_t> selectEmbeddings(int32_t cutoff) const;
78
+ void precomputeWordVectors(DenseMatrix& wordVectors);
79
+ bool keepTraining(const int64_t ntokens) const;
80
+ void buildModel();
81
+ std::tuple<int64_t, double, double> progressInfo(real progress);
75
82
 
76
83
  public:
77
84
  FastText();
@@ -80,6 +87,8 @@ class FastText {
80
87
 
81
88
  int32_t getSubwordId(const std::string& subword) const;
82
89
 
90
+ int32_t getLabelId(const std::string& label) const;
91
+
83
92
  void getWordVector(Vector& vec, const std::string& word) const;
84
93
 
85
94
  void getSubwordVector(Vector& vec, const std::string& subword) const;
@@ -95,6 +104,10 @@ class FastText {
95
104
 
96
105
  std::shared_ptr<const DenseMatrix> getInputMatrix() const;
97
106
 
107
+ void setMatrices(
108
+ const std::shared_ptr<DenseMatrix>& inputMatrix,
109
+ const std::shared_ptr<DenseMatrix>& outputMatrix);
110
+
98
111
  std::shared_ptr<const DenseMatrix> getOutputMatrix() const;
99
112
 
100
113
  void saveVectors(const std::string& filename);
@@ -109,7 +122,7 @@ class FastText {
109
122
 
110
123
  void getSentenceVector(std::istream& in, Vector& vec);
111
124
 
112
- void quantize(const Args& qargs);
125
+ void quantize(const Args& qargs, const TrainCallback& callback = {});
113
126
 
114
127
  std::tuple<int64_t, double, double>
115
128
  test(std::istream& in, int32_t k, real threshold = 0.0);
@@ -141,51 +154,17 @@ class FastText {
141
154
  const std::string& wordB,
142
155
  const std::string& wordC);
143
156
 
144
- void train(const Args& args);
157
+ void train(const Args& args, const TrainCallback& callback = {});
158
+
159
+ void abort();
145
160
 
146
161
  int getDimension() const;
147
162
 
148
163
  bool isQuant() const;
149
164
 
150
- FASTTEXT_DEPRECATED("loadVectors is being deprecated.")
151
- void loadVectors(const std::string& filename);
152
-
153
- FASTTEXT_DEPRECATED(
154
- "getVector is being deprecated and replaced by getWordVector.")
155
- void getVector(Vector& vec, const std::string& word) const;
156
-
157
- FASTTEXT_DEPRECATED(
158
- "ngramVectors is being deprecated and replaced by getNgramVectors.")
159
- void ngramVectors(std::string word);
160
-
161
- FASTTEXT_DEPRECATED(
162
- "analogies is being deprecated and replaced by getAnalogies.")
163
- void analogies(int32_t k);
164
-
165
- FASTTEXT_DEPRECATED("selectEmbeddings is being deprecated.")
166
- std::vector<int32_t> selectEmbeddings(int32_t cutoff) const;
167
-
168
- FASTTEXT_DEPRECATED(
169
- "saveVectors is being deprecated, please use the other signature.")
170
- void saveVectors();
171
-
172
- FASTTEXT_DEPRECATED(
173
- "saveOutput is being deprecated, please use the other signature.")
174
- void saveOutput();
175
-
176
- FASTTEXT_DEPRECATED(
177
- "saveModel is being deprecated, please use the other signature.")
178
- void saveModel();
179
-
180
- FASTTEXT_DEPRECATED("precomputeWordVectors is being deprecated.")
181
- void precomputeWordVectors(DenseMatrix& wordVectors);
182
-
183
- FASTTEXT_DEPRECATED("findNN is being deprecated and replaced by getNN.")
184
- void findNN(
185
- const DenseMatrix& wordVectors,
186
- const Vector& query,
187
- int32_t k,
188
- const std::set<std::string>& banSet,
189
- std::vector<std::pair<real, std::string>>& results);
165
// Exception type whose what() message is always "Aborted.".
class AbortError : public std::runtime_error {
 public:
  AbortError() : std::runtime_error{"Aborted."} {}
};
190
169
  };
191
170
  } // namespace fasttext
@@ -11,6 +11,7 @@
11
11
  #include <queue>
12
12
  #include <stdexcept>
13
13
  #include "args.h"
14
+ #include "autotune.h"
14
15
  #include "fasttext.h"
15
16
 
16
17
  using namespace fasttext;
@@ -20,19 +21,25 @@ void printUsage() {
20
21
  << "usage: fasttext <command> <args>\n\n"
21
22
  << "The commands supported by fasttext are:\n\n"
22
23
  << " supervised train a supervised classifier\n"
23
- << " quantize quantize a model to reduce the memory usage\n"
24
+ << " quantize quantize a model to reduce the memory "
25
+ "usage\n"
24
26
  << " test evaluate a supervised classifier\n"
25
- << " test-label print labels with precision and recall scores\n"
27
+ << " test-label print labels with precision and recall "
28
+ "scores\n"
26
29
  << " predict predict most likely labels\n"
27
- << " predict-prob predict most likely labels with probabilities\n"
30
+ << " predict-prob predict most likely labels with "
31
+ "probabilities\n"
28
32
  << " skipgram train a skipgram model\n"
29
33
  << " cbow train a cbow model\n"
30
34
  << " print-word-vectors print word vectors given a trained model\n"
31
- << " print-sentence-vectors print sentence vectors given a trained model\n"
32
- << " print-ngrams print ngrams given a trained model and word\n"
35
+ << " print-sentence-vectors print sentence vectors given a trained "
36
+ "model\n"
37
+ << " print-ngrams print ngrams given a trained model and "
38
+ "word\n"
33
39
  << " nn query for nearest neighbors\n"
34
40
  << " analogies query for analogies\n"
35
- << " dump dump arguments,dictionary,input/output vectors\n"
41
+ << " dump dump arguments,dictionary,input/output "
42
+ "vectors\n"
36
43
  << std::endl;
37
44
  }
38
45
 
@@ -141,7 +148,7 @@ void test(const std::vector<std::string>& args) {
141
148
  FastText fasttext;
142
149
  fasttext.loadModel(model);
143
150
 
144
- Meter meter;
151
+ Meter meter(false);
145
152
 
146
153
  if (input == "-") {
147
154
  fasttext.test(std::cin, k, threshold, meter);
@@ -351,19 +358,31 @@ void analogies(const std::vector<std::string> args) {
351
358
  void train(const std::vector<std::string> args) {
352
359
  Args a = Args();
353
360
  a.parseArgs(args);
354
- FastText fasttext;
355
- std::string outputFileName(a.output + ".bin");
361
+ std::shared_ptr<FastText> fasttext = std::make_shared<FastText>();
362
+ std::string outputFileName;
363
+
364
+ if (a.hasAutotune() &&
365
+ a.getAutotuneModelSize() != Args::kUnlimitedModelSize) {
366
+ outputFileName = a.output + ".ftz";
367
+ } else {
368
+ outputFileName = a.output + ".bin";
369
+ }
356
370
  std::ofstream ofs(outputFileName);
357
371
  if (!ofs.is_open()) {
358
372
  throw std::invalid_argument(
359
373
  outputFileName + " cannot be opened for saving.");
360
374
  }
361
375
  ofs.close();
362
- fasttext.train(a);
363
- fasttext.saveModel(outputFileName);
364
- fasttext.saveVectors(a.output + ".vec");
376
+ if (a.hasAutotune()) {
377
+ Autotune autotune(fasttext);
378
+ autotune.train(a);
379
+ } else {
380
+ fasttext->train(a);
381
+ }
382
+ fasttext->saveModel(outputFileName);
383
+ fasttext->saveVectors(a.output + ".vec");
365
384
  if (a.saveOutput) {
366
- fasttext.saveOutput(a.output + ".output");
385
+ fasttext->saveOutput(a.output + ".output");
367
386
  }
368
387
  }
369
388