fasttext 0.1.2 → 0.1.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,89 @@
1
+ /**
2
+ * Copyright (c) 2016-present, Facebook, Inc.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the MIT license found in the
6
+ * LICENSE file in the root directory of this source tree.
7
+ */
8
+
9
+ #pragma once
10
+
11
+ #include <istream>
12
+ #include <memory>
13
+ #include <random>
14
+ #include <thread>
15
+ #include <vector>
16
+
17
+ #include "args.h"
18
+ #include "fasttext.h"
19
+
20
+ namespace fasttext {
21
+
22
+ class AutotuneStrategy {
23
+ private:
24
+ Args bestArgs_;
25
+ int maxDuration_;
26
+ std::minstd_rand rng_;
27
+ int trials_;
28
+ int bestMinnIndex_;
29
+ int bestDsubExponent_;
30
+ int bestNonzeroBucket_;
31
+ int originalBucket_;
32
+ std::vector<int> minnChoices_;
33
+ int getIndex(int val, const std::vector<int>& choices);
34
+
35
+ public:
36
+ explicit AutotuneStrategy(
37
+ const Args& args,
38
+ std::minstd_rand::result_type seed);
39
+ Args ask(double elapsed);
40
+ void updateBest(const Args& args);
41
+ };
42
+
43
+ class Autotune {
44
+ protected:
45
+ std::shared_ptr<FastText> fastText_;
46
+ double elapsed_;
47
+ double bestScore_;
48
+ int32_t trials_;
49
+ int32_t sizeConstraintFailed_;
50
+ std::atomic<bool> continueTraining_;
51
+ std::unique_ptr<AutotuneStrategy> strategy_;
52
+ std::thread timer_;
53
+
54
+ bool keepTraining(double maxDuration) const;
55
+ void printInfo(double maxDuration);
56
+ void timer(
57
+ const std::chrono::steady_clock::time_point& start,
58
+ double maxDuration);
59
+ void abort();
60
+ void startTimer(const Args& args);
61
+ double getMetricScore(
62
+ Meter& meter,
63
+ const metric_name& metricName,
64
+ const double metricValue,
65
+ const std::string& metricLabel) const;
66
+ void printArgs(const Args& args, const Args& autotuneArgs);
67
+ void printSkippedArgs(const Args& autotuneArgs);
68
+ bool quantize(Args& args, const Args& autotuneArgs);
69
+ int getCutoffForFileSize(bool qout, bool qnorm, int dsub, int64_t fileSize)
70
+ const;
71
+
72
+ class TimeoutError : public std::runtime_error {
73
+ public:
74
+ TimeoutError() : std::runtime_error("Autotune timed out.") {}
75
+ };
76
+
77
+ public:
78
+ Autotune() = delete;
79
+ explicit Autotune(const std::shared_ptr<FastText>& fastText);
80
+ Autotune(const Autotune&) = delete;
81
+ Autotune(Autotune&&) = delete;
82
+ Autotune& operator=(const Autotune&) = delete;
83
+ Autotune& operator=(Autotune&&) = delete;
84
+ ~Autotune() noexcept = default;
85
+
86
+ void train(const Args& args);
87
+ };
88
+
89
+ } // namespace fasttext
@@ -8,11 +8,10 @@
8
8
 
9
9
  #include "densematrix.h"
10
10
 
11
- #include <exception>
12
11
  #include <random>
13
12
  #include <stdexcept>
13
+ #include <thread>
14
14
  #include <utility>
15
-
16
15
  #include "utils.h"
17
16
  #include "vector.h"
18
17
 
@@ -25,18 +24,39 @@ DenseMatrix::DenseMatrix(int64_t m, int64_t n) : Matrix(m, n), data_(m * n) {}
25
24
  DenseMatrix::DenseMatrix(DenseMatrix&& other) noexcept
26
25
  : Matrix(other.m_, other.n_), data_(std::move(other.data_)) {}
27
26
 
27
+ DenseMatrix::DenseMatrix(int64_t m, int64_t n, real* dataPtr)
28
+ : Matrix(m, n), data_(dataPtr, dataPtr + (m * n)) {}
29
+
28
30
  void DenseMatrix::zero() {
29
31
  std::fill(data_.begin(), data_.end(), 0.0);
30
32
  }
31
33
 
32
- void DenseMatrix::uniform(real a) {
33
- std::minstd_rand rng(1);
34
+ void DenseMatrix::uniformThread(real a, int block, int32_t seed) {
35
+ std::minstd_rand rng(block + seed);
34
36
  std::uniform_real_distribution<> uniform(-a, a);
35
- for (int64_t i = 0; i < (m_ * n_); i++) {
37
+ int64_t blockSize = (m_ * n_) / 10;
38
+ for (int64_t i = blockSize * block;
39
+ i < (m_ * n_) && i < blockSize * (block + 1);
40
+ i++) {
36
41
  data_[i] = uniform(rng);
37
42
  }
38
43
  }
39
44
 
45
+ void DenseMatrix::uniform(real a, unsigned int thread, int32_t seed) {
46
+ if (thread > 1) {
47
+ std::vector<std::thread> threads;
48
+ for (int i = 0; i < thread; i++) {
49
+ threads.push_back(std::thread([=]() { uniformThread(a, i, seed); }));
50
+ }
51
+ for (int32_t i = 0; i < threads.size(); i++) {
52
+ threads[i].join();
53
+ }
54
+ } else {
55
+ // webassembly can't instantiate `std::thread`
56
+ uniformThread(a, 0, seed);
57
+ }
58
+ }
59
+
40
60
  void DenseMatrix::multiplyRow(const Vector& nums, int64_t ib, int64_t ie) {
41
61
  if (ie == -1) {
42
62
  ie = m_;
@@ -73,7 +93,7 @@ real DenseMatrix::l2NormRow(int64_t i) const {
73
93
  norm += at(i, j) * at(i, j);
74
94
  }
75
95
  if (std::isnan(norm)) {
76
- throw std::runtime_error("Encountered NaN.");
96
+ throw EncounteredNaNError();
77
97
  }
78
98
  return std::sqrt(norm);
79
99
  }
@@ -94,7 +114,7 @@ real DenseMatrix::dotRow(const Vector& vec, int64_t i) const {
94
114
  d += at(i, j) * vec[j];
95
115
  }
96
116
  if (std::isnan(d)) {
97
- throw std::runtime_error("Encountered NaN.");
117
+ throw EncounteredNaNError();
98
118
  }
99
119
  return d;
100
120
  }
@@ -8,12 +8,13 @@
8
8
 
9
9
  #pragma once
10
10
 
11
+ #include <assert.h>
11
12
  #include <cstdint>
12
13
  #include <istream>
13
14
  #include <ostream>
15
+ #include <stdexcept>
14
16
  #include <vector>
15
17
 
16
- #include <assert.h>
17
18
  #include "matrix.h"
18
19
  #include "real.h"
19
20
 
@@ -24,10 +25,12 @@ class Vector;
24
25
  class DenseMatrix : public Matrix {
25
26
  protected:
26
27
  std::vector<real> data_;
28
+ void uniformThread(real, int, int32_t);
27
29
 
28
30
  public:
29
31
  DenseMatrix();
30
32
  explicit DenseMatrix(int64_t, int64_t);
33
+ explicit DenseMatrix(int64_t m, int64_t n, real* dataPtr);
31
34
  DenseMatrix(const DenseMatrix&) = default;
32
35
  DenseMatrix(DenseMatrix&&) noexcept;
33
36
  DenseMatrix& operator=(const DenseMatrix&) = delete;
@@ -56,7 +59,7 @@ class DenseMatrix : public Matrix {
56
59
  return n_;
57
60
  }
58
61
  void zero();
59
- void uniform(real);
62
+ void uniform(real, unsigned int, int32_t);
60
63
 
61
64
  void multiplyRow(const Vector& nums, int64_t ib = 0, int64_t ie = -1);
62
65
  void divideRow(const Vector& denoms, int64_t ib = 0, int64_t ie = -1);
@@ -71,5 +74,10 @@ class DenseMatrix : public Matrix {
71
74
  void save(std::ostream&) const override;
72
75
  void load(std::istream&) override;
73
76
  void dump(std::ostream&) const override;
77
+
78
// Exception thrown by l2NormRow and dotRow when their result is NaN. It is
// caught in trainThread to stop training.
struct EncounteredNaNError : std::runtime_error {
  EncounteredNaNError() : std::runtime_error("Encountered NaN.") {}
};
74
82
  };
75
83
  } // namespace fasttext
@@ -47,7 +47,8 @@ std::shared_ptr<Loss> FastText::createLoss(std::shared_ptr<Matrix>& output) {
47
47
  }
48
48
  }
49
49
 
50
- FastText::FastText() : quant_(false), wordVectors_(nullptr) {}
50
+ FastText::FastText()
51
+ : quant_(false), wordVectors_(nullptr), trainException_(nullptr) {}
51
52
 
52
53
  void FastText::addInputVector(Vector& vec, int32_t ind) const {
53
54
  vec.addRow(*input_, ind);
@@ -69,6 +70,19 @@ std::shared_ptr<const DenseMatrix> FastText::getInputMatrix() const {
69
70
  return std::dynamic_pointer_cast<DenseMatrix>(input_);
70
71
  }
71
72
 
73
+ void FastText::setMatrices(
74
+ const std::shared_ptr<DenseMatrix>& inputMatrix,
75
+ const std::shared_ptr<DenseMatrix>& outputMatrix) {
76
+ assert(input_->size(1) == output_->size(1));
77
+
78
+ input_ = std::dynamic_pointer_cast<Matrix>(inputMatrix);
79
+ output_ = std::dynamic_pointer_cast<Matrix>(outputMatrix);
80
+ wordVectors_.reset();
81
+ args_->dim = input_->size(1);
82
+
83
+ buildModel();
84
+ }
85
+
72
86
  std::shared_ptr<const DenseMatrix> FastText::getOutputMatrix() const {
73
87
  if (quant_ && args_->qout) {
74
88
  throw std::runtime_error("Can't export quantized matrix");
@@ -86,6 +100,14 @@ int32_t FastText::getSubwordId(const std::string& subword) const {
86
100
  return dict_->nwords() + h;
87
101
  }
88
102
 
103
+ int32_t FastText::getLabelId(const std::string& label) const {
104
+ int32_t labelId = dict_->getId(label);
105
+ if (labelId != -1) {
106
+ labelId -= dict_->nwords();
107
+ }
108
+ return labelId;
109
+ }
110
+
89
111
  void FastText::getWordVector(Vector& vec, const std::string& word) const {
90
112
  const std::vector<int32_t>& ngrams = dict_->getSubwords(word);
91
113
  vec.zero();
@@ -97,10 +119,6 @@ void FastText::getWordVector(Vector& vec, const std::string& word) const {
97
119
  }
98
120
  }
99
121
 
100
- void FastText::getVector(Vector& vec, const std::string& word) const {
101
- getWordVector(vec, word);
102
- }
103
-
104
122
  void FastText::getSubwordVector(Vector& vec, const std::string& subword) const {
105
123
  vec.zero();
106
124
  int32_t h = dict_->hash(subword) % args_->bucket;
@@ -109,6 +127,9 @@ void FastText::getSubwordVector(Vector& vec, const std::string& subword) const {
109
127
  }
110
128
 
111
129
  void FastText::saveVectors(const std::string& filename) {
130
+ if (!input_ || !output_) {
131
+ throw std::runtime_error("Model never trained");
132
+ }
112
133
  std::ofstream ofs(filename);
113
134
  if (!ofs.is_open()) {
114
135
  throw std::invalid_argument(
@@ -124,10 +145,6 @@ void FastText::saveVectors(const std::string& filename) {
124
145
  ofs.close();
125
146
  }
126
147
 
127
- void FastText::saveVectors() {
128
- saveVectors(args_->output + ".vec");
129
- }
130
-
131
148
  void FastText::saveOutput(const std::string& filename) {
132
149
  std::ofstream ofs(filename);
133
150
  if (!ofs.is_open()) {
@@ -152,10 +169,6 @@ void FastText::saveOutput(const std::string& filename) {
152
169
  ofs.close();
153
170
  }
154
171
 
155
- void FastText::saveOutput() {
156
- saveOutput(args_->output + ".output");
157
- }
158
-
159
172
  bool FastText::checkModel(std::istream& in) {
160
173
  int32_t magic;
161
174
  in.read((char*)&(magic), sizeof(int32_t));
@@ -176,21 +189,14 @@ void FastText::signModel(std::ostream& out) {
176
189
  out.write((char*)&(version), sizeof(int32_t));
177
190
  }
178
191
 
179
- void FastText::saveModel() {
180
- std::string fn(args_->output);
181
- if (quant_) {
182
- fn += ".ftz";
183
- } else {
184
- fn += ".bin";
185
- }
186
- saveModel(fn);
187
- }
188
-
189
192
  void FastText::saveModel(const std::string& filename) {
190
193
  std::ofstream ofs(filename, std::ofstream::binary);
191
194
  if (!ofs.is_open()) {
192
195
  throw std::invalid_argument(filename + " cannot be opened for saving!");
193
196
  }
197
+ if (!input_ || !output_) {
198
+ throw std::runtime_error("Model never trained");
199
+ }
194
200
  signModel(ofs);
195
201
  args_->save(ofs);
196
202
  dict_->save(ofs);
@@ -224,6 +230,12 @@ std::vector<int64_t> FastText::getTargetCounts() const {
224
230
  }
225
231
  }
226
232
 
233
+ void FastText::buildModel() {
234
+ auto loss = createLoss(output_);
235
+ bool normalizeGradient = (args_->model == model_name::sup);
236
+ model_ = std::make_shared<Model>(input_, output_, loss, normalizeGradient);
237
+ }
238
+
227
239
  void FastText::loadModel(std::istream& in) {
228
240
  args_ = std::make_shared<Args>();
229
241
  input_ = std::make_shared<DenseMatrix>();
@@ -256,37 +268,37 @@ void FastText::loadModel(std::istream& in) {
256
268
  }
257
269
  output_->load(in);
258
270
 
259
- auto loss = createLoss(output_);
260
- bool normalizeGradient = (args_->model == model_name::sup);
261
- model_ = std::make_shared<Model>(input_, output_, loss, normalizeGradient);
271
+ buildModel();
262
272
  }
263
273
 
264
- void FastText::printInfo(real progress, real loss, std::ostream& log_stream) {
265
- std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
266
- double t =
267
- std::chrono::duration_cast<std::chrono::duration<double>>(end - start_)
268
- .count();
274
+ std::tuple<int64_t, double, double> FastText::progressInfo(real progress) {
275
+ double t = utils::getDuration(start_, std::chrono::steady_clock::now());
269
276
  double lr = args_->lr * (1.0 - progress);
270
277
  double wst = 0;
271
278
 
272
279
  int64_t eta = 2592000; // Default to one month in seconds (720 * 3600)
273
280
 
274
281
  if (progress > 0 && t >= 0) {
275
- progress = progress * 100;
276
- eta = t * (100 - progress) / progress;
282
+ eta = t * (1 - progress) / progress;
277
283
  wst = double(tokenCount_) / t / args_->thread;
278
284
  }
279
- int32_t etah = eta / 3600;
280
- int32_t etam = (eta % 3600) / 60;
285
+
286
+ return std::tuple<double, double, int64_t>(wst, lr, eta);
287
+ }
288
+
289
+ void FastText::printInfo(real progress, real loss, std::ostream& log_stream) {
290
+ double wst;
291
+ double lr;
292
+ int64_t eta;
293
+ std::tie<double, double, int64_t>(wst, lr, eta) = progressInfo(progress);
281
294
 
282
295
  log_stream << std::fixed;
283
296
  log_stream << "Progress: ";
284
- log_stream << std::setprecision(1) << std::setw(5) << progress << "%";
297
+ log_stream << std::setprecision(1) << std::setw(5) << (progress * 100) << "%";
285
298
  log_stream << " words/sec/thread: " << std::setw(7) << int64_t(wst);
286
299
  log_stream << " lr: " << std::setw(9) << std::setprecision(6) << lr;
287
- log_stream << " loss: " << std::setw(9) << std::setprecision(6) << loss;
288
- log_stream << " ETA: " << std::setw(3) << etah;
289
- log_stream << "h" << std::setw(2) << etam << "m";
300
+ log_stream << " avg.loss: " << std::setw(9) << std::setprecision(6) << loss;
301
+ log_stream << " ETA: " << utils::ClockPrint(eta);
290
302
  log_stream << std::flush;
291
303
  }
292
304
 
@@ -299,13 +311,16 @@ std::vector<int32_t> FastText::selectEmbeddings(int32_t cutoff) const {
299
311
  std::iota(idx.begin(), idx.end(), 0);
300
312
  auto eosid = dict_->getId(Dictionary::EOS);
301
313
  std::sort(idx.begin(), idx.end(), [&norms, eosid](size_t i1, size_t i2) {
314
+ if (i1 == eosid && i2 == eosid) { // satisfy strict weak ordering
315
+ return false;
316
+ }
302
317
  return eosid == i1 || (eosid != i2 && norms[i1] > norms[i2]);
303
318
  });
304
319
  idx.erase(idx.begin() + cutoff, idx.end());
305
320
  return idx;
306
321
  }
307
322
 
308
- void FastText::quantize(const Args& qargs) {
323
+ void FastText::quantize(const Args& qargs, const TrainCallback& callback) {
309
324
  if (args_->model != model_name::sup) {
310
325
  throw std::invalid_argument(
311
326
  "For now we only support quantization of supervised models");
@@ -337,10 +352,9 @@ void FastText::quantize(const Args& qargs) {
337
352
  args_->verbose = qargs.verbose;
338
353
  auto loss = createLoss(output_);
339
354
  model_ = std::make_shared<Model>(input, output, loss, normalizeGradient);
340
- startThreads();
355
+ startThreads(callback);
341
356
  }
342
357
  }
343
-
344
358
  input_ = std::make_shared<QuantMatrix>(
345
359
  std::move(*(input.get())), qargs.dsub, qargs.qnorm);
346
360
 
@@ -348,7 +362,6 @@ void FastText::quantize(const Args& qargs) {
348
362
  output_ = std::make_shared<QuantMatrix>(
349
363
  std::move(*(output.get())), 2, qargs.qnorm);
350
364
  }
351
-
352
365
  quant_ = true;
353
366
  auto loss = createLoss(output_);
354
367
  model_ = std::make_shared<Model>(input_, output_, loss, normalizeGradient);
@@ -408,7 +421,7 @@ void FastText::skipgram(
408
421
 
409
422
  std::tuple<int64_t, double, double>
410
423
  FastText::test(std::istream& in, int32_t k, real threshold) {
411
- Meter meter;
424
+ Meter meter(false);
412
425
  test(in, k, threshold, meter);
413
426
 
414
427
  return std::tuple<int64_t, double, double>(
@@ -420,6 +433,9 @@ void FastText::test(std::istream& in, int32_t k, real threshold, Meter& meter)
420
433
  std::vector<int32_t> line;
421
434
  std::vector<int32_t> labels;
422
435
  Predictions predictions;
436
+ Model::State state(args_->dim, dict_->nlabels(), 0);
437
+ in.clear();
438
+ in.seekg(0, std::ios_base::beg);
423
439
 
424
440
  while (in.peek() != EOF) {
425
441
  line.clear();
@@ -521,16 +537,6 @@ std::vector<std::pair<std::string, Vector>> FastText::getNgramVectors(
521
537
  return result;
522
538
  }
523
539
 
524
- // deprecated. use getNgramVectors instead
525
- void FastText::ngramVectors(std::string word) {
526
- std::vector<std::pair<std::string, Vector>> ngramVectors =
527
- getNgramVectors(word);
528
-
529
- for (const auto& ngramVector : ngramVectors) {
530
- std::cout << ngramVector.first << " " << ngramVector.second << std::endl;
531
- }
532
- }
533
-
534
540
  void FastText::precomputeWordVectors(DenseMatrix& wordVectors) {
535
541
  Vector vec(args_->dim);
536
542
  wordVectors.zero();
@@ -598,17 +604,6 @@ std::vector<std::pair<real, std::string>> FastText::getNN(
598
604
  return heap;
599
605
  }
600
606
 
601
- // depracted. use getNN instead
602
- void FastText::findNN(
603
- const DenseMatrix& wordVectors,
604
- const Vector& query,
605
- int32_t k,
606
- const std::set<std::string>& banSet,
607
- std::vector<std::pair<real, std::string>>& results) {
608
- results.clear();
609
- results = getNN(wordVectors, query, k, banSet);
610
- }
611
-
612
607
  std::vector<std::pair<real, std::string>> FastText::getAnalogies(
613
608
  int32_t k,
614
609
  const std::string& wordA,
@@ -630,52 +625,52 @@ std::vector<std::pair<real, std::string>> FastText::getAnalogies(
630
625
  return getNN(*wordVectors_, query, k, {wordA, wordB, wordC});
631
626
  }
632
627
 
633
- // depreacted, use getAnalogies instead
634
- void FastText::analogies(int32_t k) {
635
- std::string prompt("Query triplet (A - B + C)? ");
636
- std::string wordA, wordB, wordC;
637
- std::cout << prompt;
638
- while (true) {
639
- std::cin >> wordA;
640
- std::cin >> wordB;
641
- std::cin >> wordC;
642
- auto results = getAnalogies(k, wordA, wordB, wordC);
643
-
644
- for (auto& pair : results) {
645
- std::cout << pair.second << " " << pair.first << std::endl;
646
- }
647
- std::cout << prompt;
648
- }
628
+ bool FastText::keepTraining(const int64_t ntokens) const {
629
+ return tokenCount_ < args_->epoch * ntokens && !trainException_;
649
630
  }
650
631
 
651
- void FastText::trainThread(int32_t threadId) {
632
+ void FastText::trainThread(int32_t threadId, const TrainCallback& callback) {
652
633
  std::ifstream ifs(args_->input);
653
634
  utils::seek(ifs, threadId * utils::size(ifs) / args_->thread);
654
635
 
655
- Model::State state(args_->dim, output_->size(0), threadId);
636
+ Model::State state(args_->dim, output_->size(0), threadId + args_->seed);
656
637
 
657
638
  const int64_t ntokens = dict_->ntokens();
658
639
  int64_t localTokenCount = 0;
659
640
  std::vector<int32_t> line, labels;
660
- while (tokenCount_ < args_->epoch * ntokens) {
661
- real progress = real(tokenCount_) / (args_->epoch * ntokens);
662
- real lr = args_->lr * (1.0 - progress);
663
- if (args_->model == model_name::sup) {
664
- localTokenCount += dict_->getLine(ifs, line, labels);
665
- supervised(state, lr, line, labels);
666
- } else if (args_->model == model_name::cbow) {
667
- localTokenCount += dict_->getLine(ifs, line, state.rng);
668
- cbow(state, lr, line);
669
- } else if (args_->model == model_name::sg) {
670
- localTokenCount += dict_->getLine(ifs, line, state.rng);
671
- skipgram(state, lr, line);
672
- }
673
- if (localTokenCount > args_->lrUpdateRate) {
674
- tokenCount_ += localTokenCount;
675
- localTokenCount = 0;
676
- if (threadId == 0 && args_->verbose > 1)
677
- loss_ = state.getLoss();
641
+ uint64_t callbackCounter = 0;
642
+ try {
643
+ while (keepTraining(ntokens)) {
644
+ real progress = real(tokenCount_) / (args_->epoch * ntokens);
645
+ if (callback && ((callbackCounter++ % 64) == 0)) {
646
+ double wst;
647
+ double lr;
648
+ int64_t eta;
649
+ std::tie<double, double, int64_t>(wst, lr, eta) =
650
+ progressInfo(progress);
651
+ callback(progress, loss_, wst, lr, eta);
652
+ }
653
+ real lr = args_->lr * (1.0 - progress);
654
+ if (args_->model == model_name::sup) {
655
+ localTokenCount += dict_->getLine(ifs, line, labels);
656
+ supervised(state, lr, line, labels);
657
+ } else if (args_->model == model_name::cbow) {
658
+ localTokenCount += dict_->getLine(ifs, line, state.rng);
659
+ cbow(state, lr, line);
660
+ } else if (args_->model == model_name::sg) {
661
+ localTokenCount += dict_->getLine(ifs, line, state.rng);
662
+ skipgram(state, lr, line);
663
+ }
664
+ if (localTokenCount > args_->lrUpdateRate) {
665
+ tokenCount_ += localTokenCount;
666
+ localTokenCount = 0;
667
+ if (threadId == 0 && args_->verbose > 1) {
668
+ loss_ = state.getLoss();
669
+ }
670
+ }
678
671
  }
672
+ } catch (DenseMatrix::EncounteredNaNError&) {
673
+ trainException_ = std::current_exception();
679
674
  }
680
675
  if (threadId == 0)
681
676
  loss_ = state.getLoss();
@@ -713,7 +708,7 @@ std::shared_ptr<Matrix> FastText::getInputMatrixFromFile(
713
708
  dict_->init();
714
709
  std::shared_ptr<DenseMatrix> input = std::make_shared<DenseMatrix>(
715
710
  dict_->nwords() + args_->bucket, args_->dim);
716
- input->uniform(1.0 / args_->dim);
711
+ input->uniform(1.0 / args_->dim, args_->thread, args_->seed);
717
712
 
718
713
  for (size_t i = 0; i < n; i++) {
719
714
  int32_t idx = dict_->getId(words[i]);
@@ -727,14 +722,10 @@ std::shared_ptr<Matrix> FastText::getInputMatrixFromFile(
727
722
  return input;
728
723
  }
729
724
 
730
- void FastText::loadVectors(const std::string& filename) {
731
- input_ = getInputMatrixFromFile(filename);
732
- }
733
-
734
725
  std::shared_ptr<Matrix> FastText::createRandomMatrix() const {
735
726
  std::shared_ptr<DenseMatrix> input = std::make_shared<DenseMatrix>(
736
727
  dict_->nwords() + args_->bucket, args_->dim);
737
- input->uniform(1.0 / args_->dim);
728
+ input->uniform(1.0 / args_->dim, args_->thread, args_->seed);
738
729
 
739
730
  return input;
740
731
  }
@@ -749,7 +740,7 @@ std::shared_ptr<Matrix> FastText::createTrainOutputMatrix() const {
749
740
  return output;
750
741
  }
751
742
 
752
- void FastText::train(const Args& args) {
743
+ void FastText::train(const Args& args, const TrainCallback& callback) {
753
744
  args_ = std::make_shared<Args>(args);
754
745
  dict_ = std::make_shared<Dictionary>(args_);
755
746
  if (args_->input == "-") {
@@ -770,23 +761,38 @@ void FastText::train(const Args& args) {
770
761
  input_ = createRandomMatrix();
771
762
  }
772
763
  output_ = createTrainOutputMatrix();
764
+ quant_ = false;
773
765
  auto loss = createLoss(output_);
774
766
  bool normalizeGradient = (args_->model == model_name::sup);
775
767
  model_ = std::make_shared<Model>(input_, output_, loss, normalizeGradient);
776
- startThreads();
768
+ startThreads(callback);
769
+ }
770
+
771
+ void FastText::abort() {
772
+ try {
773
+ throw AbortError();
774
+ } catch (AbortError&) {
775
+ trainException_ = std::current_exception();
776
+ }
777
777
  }
778
778
 
779
- void FastText::startThreads() {
779
+ void FastText::startThreads(const TrainCallback& callback) {
780
780
  start_ = std::chrono::steady_clock::now();
781
781
  tokenCount_ = 0;
782
782
  loss_ = -1;
783
+ trainException_ = nullptr;
783
784
  std::vector<std::thread> threads;
784
- for (int32_t i = 0; i < args_->thread; i++) {
785
- threads.push_back(std::thread([=]() { trainThread(i); }));
785
+ if (args_->thread > 1) {
786
+ for (int32_t i = 0; i < args_->thread; i++) {
787
+ threads.push_back(std::thread([=]() { trainThread(i, callback); }));
788
+ }
789
+ } else {
790
+ // webassembly can't instantiate `std::thread`
791
+ trainThread(0, callback);
786
792
  }
787
793
  const int64_t ntokens = dict_->ntokens();
788
794
  // Same condition as trainThread
789
- while (tokenCount_ < args_->epoch * ntokens) {
795
+ while (keepTraining(ntokens)) {
790
796
  std::this_thread::sleep_for(std::chrono::milliseconds(100));
791
797
  if (loss_ >= 0 && args_->verbose > 1) {
792
798
  real progress = real(tokenCount_) / (args_->epoch * ntokens);
@@ -794,9 +800,14 @@ void FastText::startThreads() {
794
800
  printInfo(progress, loss_, std::cerr);
795
801
  }
796
802
  }
797
- for (int32_t i = 0; i < args_->thread; i++) {
803
+ for (int32_t i = 0; i < threads.size(); i++) {
798
804
  threads[i].join();
799
805
  }
806
+ if (trainException_) {
807
+ std::exception_ptr exception = trainException_;
808
+ trainException_ = nullptr;
809
+ std::rethrow_exception(exception);
810
+ }
800
811
  if (args_->verbose > 0) {
801
812
  std::cerr << "\r";
802
813
  printInfo(1.0, loss_, std::cerr);