fasttext 0.1.2 → 0.1.3

@@ -0,0 +1,89 @@
+ /**
+  * Copyright (c) 2016-present, Facebook, Inc.
+  * All rights reserved.
+  *
+  * This source code is licensed under the MIT license found in the
+  * LICENSE file in the root directory of this source tree.
+  */
+
+ #pragma once
+
+ #include <istream>
+ #include <memory>
+ #include <random>
+ #include <thread>
+ #include <vector>
+
+ #include "args.h"
+ #include "fasttext.h"
+
+ namespace fasttext {
+
+ class AutotuneStrategy {
+  private:
+   Args bestArgs_;
+   int maxDuration_;
+   std::minstd_rand rng_;
+   int trials_;
+   int bestMinnIndex_;
+   int bestDsubExponent_;
+   int bestNonzeroBucket_;
+   int originalBucket_;
+   std::vector<int> minnChoices_;
+   int getIndex(int val, const std::vector<int>& choices);
+
+  public:
+   explicit AutotuneStrategy(
+       const Args& args,
+       std::minstd_rand::result_type seed);
+   Args ask(double elapsed);
+   void updateBest(const Args& args);
+ };
+
+ class Autotune {
+  protected:
+   std::shared_ptr<FastText> fastText_;
+   double elapsed_;
+   double bestScore_;
+   int32_t trials_;
+   int32_t sizeConstraintFailed_;
+   std::atomic<bool> continueTraining_;
+   std::unique_ptr<AutotuneStrategy> strategy_;
+   std::thread timer_;
+
+   bool keepTraining(double maxDuration) const;
+   void printInfo(double maxDuration);
+   void timer(
+       const std::chrono::steady_clock::time_point& start,
+       double maxDuration);
+   void abort();
+   void startTimer(const Args& args);
+   double getMetricScore(
+       Meter& meter,
+       const metric_name& metricName,
+       const double metricValue,
+       const std::string& metricLabel) const;
+   void printArgs(const Args& args, const Args& autotuneArgs);
+   void printSkippedArgs(const Args& autotuneArgs);
+   bool quantize(Args& args, const Args& autotuneArgs);
+   int getCutoffForFileSize(bool qout, bool qnorm, int dsub, int64_t fileSize)
+       const;
+
+   class TimeoutError : public std::runtime_error {
+    public:
+     TimeoutError() : std::runtime_error("Autotune timed out.") {}
+   };
+
+  public:
+   Autotune() = delete;
+   explicit Autotune(const std::shared_ptr<FastText>& fastText);
+   Autotune(const Autotune&) = delete;
+   Autotune(Autotune&&) = delete;
+   Autotune& operator=(const Autotune&) = delete;
+   Autotune& operator=(Autotune&&) = delete;
+   ~Autotune() noexcept = default;
+
+   void train(const Args& args);
+ };
+
+ } // namespace fasttext
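The new Autotune/AutotuneStrategy pair drives fastText's hyperparameter search: Autotune wraps a shared FastText instance, repeatedly retrains it with candidate Args proposed by AutotuneStrategy, and keeps the best-scoring configuration until its time budget runs out (TimeoutError is declared for signalling an exhausted budget). A minimal sketch of how it might be driven from C++; the autotune-specific Args fields used below (autotuneValidationFile, autotuneDuration) are assumed from upstream fastText and do not appear in this diff:

    // Sketch only: exercise the new autotune entry point.
    #include "autotune.h"
    #include "fasttext.h"

    using namespace fasttext;

    int main() {
      Args args;
      args.input = "train.txt";                  // labelled training data (made up)
      args.model = model_name::sup;              // autotune targets supervised models
      args.autotuneValidationFile = "valid.txt"; // assumed Args field
      args.autotuneDuration = 300;               // assumed Args field, in seconds

      auto ft = std::make_shared<FastText>();
      Autotune autotune(ft);  // constructor declared in the header above
      autotune.train(args);   // search; ft presumably ends up holding the best model
      return 0;
    }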
@@ -8,11 +8,10 @@
  
  #include "densematrix.h"
  
- #include <exception>
  #include <random>
  #include <stdexcept>
+ #include <thread>
  #include <utility>
-
  #include "utils.h"
  #include "vector.h"
  
@@ -25,18 +24,39 @@ DenseMatrix::DenseMatrix(int64_t m, int64_t n) : Matrix(m, n), data_(m * n) {}
  DenseMatrix::DenseMatrix(DenseMatrix&& other) noexcept
      : Matrix(other.m_, other.n_), data_(std::move(other.data_)) {}
  
+ DenseMatrix::DenseMatrix(int64_t m, int64_t n, real* dataPtr)
+     : Matrix(m, n), data_(dataPtr, dataPtr + (m * n)) {}
+
  void DenseMatrix::zero() {
    std::fill(data_.begin(), data_.end(), 0.0);
  }
  
- void DenseMatrix::uniform(real a) {
-   std::minstd_rand rng(1);
+ void DenseMatrix::uniformThread(real a, int block, int32_t seed) {
+   std::minstd_rand rng(block + seed);
    std::uniform_real_distribution<> uniform(-a, a);
-   for (int64_t i = 0; i < (m_ * n_); i++) {
+   int64_t blockSize = (m_ * n_) / 10;
+   for (int64_t i = blockSize * block;
+        i < (m_ * n_) && i < blockSize * (block + 1);
+        i++) {
      data_[i] = uniform(rng);
    }
  }
  
+ void DenseMatrix::uniform(real a, unsigned int thread, int32_t seed) {
+   if (thread > 1) {
+     std::vector<std::thread> threads;
+     for (int i = 0; i < thread; i++) {
+       threads.push_back(std::thread([=]() { uniformThread(a, i, seed); }));
+     }
+     for (int32_t i = 0; i < threads.size(); i++) {
+       threads[i].join();
+     }
+   } else {
+     // webassembly can't instantiate `std::thread`
+     uniformThread(a, 0, seed);
+   }
+ }
+
  void DenseMatrix::multiplyRow(const Vector& nums, int64_t ib, int64_t ie) {
    if (ie == -1) {
      ie = m_;
@@ -73,7 +93,7 @@ real DenseMatrix::l2NormRow(int64_t i) const {
      norm += at(i, j) * at(i, j);
    }
    if (std::isnan(norm)) {
-     throw std::runtime_error("Encountered NaN.");
+     throw EncounteredNaNError();
    }
    return std::sqrt(norm);
  }
@@ -94,7 +114,7 @@ real DenseMatrix::dotRow(const Vector& vec, int64_t i) const {
      d += at(i, j) * vec[j];
    }
    if (std::isnan(d)) {
-     throw std::runtime_error("Encountered NaN.");
+     throw EncounteredNaNError();
    }
    return d;
  }
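Two DenseMatrix behaviours change above: uniform initialization now takes a thread count and a seed (each worker seeds its own std::minstd_rand with block + seed and fills one block of the flat data_ array, with a single-threaded fallback because WebAssembly cannot spawn std::thread), and NaN detection throws the new DenseMatrix::EncounteredNaNError so callers can catch it specifically, as the training loop later in this diff does. A minimal sketch of the new surface, with the sizes and seed made up:

    // Sketch only: the three-argument uniform() and the new exception type.
    #include "densematrix.h"
    #include "vector.h"
    #include <iostream>

    using namespace fasttext;

    void initAndCheck() {
      DenseMatrix mat(10000, 100);    // (nwords + bucket) x dim, values made up
      mat.uniform(1.0 / 100, 4, 42);  // range +/- 1/dim, 4 threads, seed 42

      Vector row(100);
      row.zero();
      try {
        mat.dotRow(row, 0);  // throws EncounteredNaNError if the result is NaN
      } catch (const DenseMatrix::EncounteredNaNError& e) {
        std::cerr << e.what() << std::endl;
      }
    }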
@@ -8,12 +8,13 @@
  
  #pragma once
  
+ #include <assert.h>
  #include <cstdint>
  #include <istream>
  #include <ostream>
+ #include <stdexcept>
  #include <vector>
  
- #include <assert.h>
  #include "matrix.h"
  #include "real.h"
  
@@ -24,10 +25,12 @@ class Vector;
  class DenseMatrix : public Matrix {
   protected:
    std::vector<real> data_;
+   void uniformThread(real, int, int32_t);
  
   public:
    DenseMatrix();
    explicit DenseMatrix(int64_t, int64_t);
+   explicit DenseMatrix(int64_t m, int64_t n, real* dataPtr);
    DenseMatrix(const DenseMatrix&) = default;
    DenseMatrix(DenseMatrix&&) noexcept;
    DenseMatrix& operator=(const DenseMatrix&) = delete;
@@ -56,7 +59,7 @@ class DenseMatrix : public Matrix {
      return n_;
    }
    void zero();
-   void uniform(real);
+   void uniform(real, unsigned int, int32_t);
  
    void multiplyRow(const Vector& nums, int64_t ib = 0, int64_t ie = -1);
    void divideRow(const Vector& denoms, int64_t ib = 0, int64_t ie = -1);
@@ -71,5 +74,10 @@ class DenseMatrix : public Matrix {
    void save(std::ostream&) const override;
    void load(std::istream&) override;
    void dump(std::ostream&) const override;
+
+   class EncounteredNaNError : public std::runtime_error {
+    public:
+     EncounteredNaNError() : std::runtime_error("Encountered NaN.") {}
+   };
  };
  } // namespace fasttext
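The header also gains a constructor that copies an existing row-major buffer of real values into the matrix, presumably so bindings can hand pretrained weights to the new FastText::setMatrices further down in this diff. A small sketch, with the buffer made up:

    // Sketch only: wrap an existing flat float buffer as a DenseMatrix.
    #include "densematrix.h"
    #include <memory>
    #include <vector>

    using namespace fasttext;

    std::shared_ptr<DenseMatrix> fromBuffer(
        std::vector<real>& flat, int64_t rows, int64_t cols) {
      // The new constructor copies rows * cols values starting at flat.data().
      return std::make_shared<DenseMatrix>(rows, cols, flat.data());
    }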
@@ -47,7 +47,8 @@ std::shared_ptr<Loss> FastText::createLoss(std::shared_ptr<Matrix>& output) {
    }
  }
  
- FastText::FastText() : quant_(false), wordVectors_(nullptr) {}
+ FastText::FastText()
+     : quant_(false), wordVectors_(nullptr), trainException_(nullptr) {}
  
  void FastText::addInputVector(Vector& vec, int32_t ind) const {
    vec.addRow(*input_, ind);
@@ -69,6 +70,19 @@ std::shared_ptr<const DenseMatrix> FastText::getInputMatrix() const {
    return std::dynamic_pointer_cast<DenseMatrix>(input_);
  }
  
+ void FastText::setMatrices(
+     const std::shared_ptr<DenseMatrix>& inputMatrix,
+     const std::shared_ptr<DenseMatrix>& outputMatrix) {
+   assert(input_->size(1) == output_->size(1));
+
+   input_ = std::dynamic_pointer_cast<Matrix>(inputMatrix);
+   output_ = std::dynamic_pointer_cast<Matrix>(outputMatrix);
+   wordVectors_.reset();
+   args_->dim = input_->size(1);
+
+   buildModel();
+ }
+
  std::shared_ptr<const DenseMatrix> FastText::getOutputMatrix() const {
    if (quant_ && args_->qout) {
      throw std::runtime_error("Can't export quantized matrix");
@@ -86,6 +100,14 @@ int32_t FastText::getSubwordId(const std::string& subword) const {
    return dict_->nwords() + h;
  }
  
+ int32_t FastText::getLabelId(const std::string& label) const {
+   int32_t labelId = dict_->getId(label);
+   if (labelId != -1) {
+     labelId -= dict_->nwords();
+   }
+   return labelId;
+ }
+
  void FastText::getWordVector(Vector& vec, const std::string& word) const {
    const std::vector<int32_t>& ngrams = dict_->getSubwords(word);
    vec.zero();
@@ -97,10 +119,6 @@ void FastText::getWordVector(Vector& vec, const std::string& word) const {
    }
  }
  
- void FastText::getVector(Vector& vec, const std::string& word) const {
-   getWordVector(vec, word);
- }
-
  void FastText::getSubwordVector(Vector& vec, const std::string& subword) const {
    vec.zero();
    int32_t h = dict_->hash(subword) % args_->bucket;
@@ -109,6 +127,9 @@ void FastText::getSubwordVector(Vector& vec, const std::string& subword) const {
  }
  
  void FastText::saveVectors(const std::string& filename) {
+   if (!input_ || !output_) {
+     throw std::runtime_error("Model never trained");
+   }
    std::ofstream ofs(filename);
    if (!ofs.is_open()) {
      throw std::invalid_argument(
@@ -124,10 +145,6 @@ void FastText::saveVectors(const std::string& filename) {
    ofs.close();
  }
  
- void FastText::saveVectors() {
-   saveVectors(args_->output + ".vec");
- }
-
  void FastText::saveOutput(const std::string& filename) {
    std::ofstream ofs(filename);
    if (!ofs.is_open()) {
@@ -152,10 +169,6 @@ void FastText::saveOutput(const std::string& filename) {
    ofs.close();
  }
  
- void FastText::saveOutput() {
-   saveOutput(args_->output + ".output");
- }
-
  bool FastText::checkModel(std::istream& in) {
    int32_t magic;
    in.read((char*)&(magic), sizeof(int32_t));
@@ -176,21 +189,14 @@ void FastText::signModel(std::ostream& out) {
    out.write((char*)&(version), sizeof(int32_t));
  }
  
- void FastText::saveModel() {
-   std::string fn(args_->output);
-   if (quant_) {
-     fn += ".ftz";
-   } else {
-     fn += ".bin";
-   }
-   saveModel(fn);
- }
-
  void FastText::saveModel(const std::string& filename) {
    std::ofstream ofs(filename, std::ofstream::binary);
    if (!ofs.is_open()) {
      throw std::invalid_argument(filename + " cannot be opened for saving!");
    }
+   if (!input_ || !output_) {
+     throw std::runtime_error("Model never trained");
+   }
    signModel(ofs);
    args_->save(ofs);
    dict_->save(ofs);
@@ -224,6 +230,12 @@ std::vector<int64_t> FastText::getTargetCounts() const {
    }
  }
  
+ void FastText::buildModel() {
+   auto loss = createLoss(output_);
+   bool normalizeGradient = (args_->model == model_name::sup);
+   model_ = std::make_shared<Model>(input_, output_, loss, normalizeGradient);
+ }
+
  void FastText::loadModel(std::istream& in) {
    args_ = std::make_shared<Args>();
    input_ = std::make_shared<DenseMatrix>();
@@ -256,37 +268,37 @@ void FastText::loadModel(std::istream& in) {
    }
    output_->load(in);
  
-   auto loss = createLoss(output_);
-   bool normalizeGradient = (args_->model == model_name::sup);
-   model_ = std::make_shared<Model>(input_, output_, loss, normalizeGradient);
+   buildModel();
  }
  
- void FastText::printInfo(real progress, real loss, std::ostream& log_stream) {
-   std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
-   double t =
-       std::chrono::duration_cast<std::chrono::duration<double>>(end - start_)
-           .count();
+ std::tuple<int64_t, double, double> FastText::progressInfo(real progress) {
+   double t = utils::getDuration(start_, std::chrono::steady_clock::now());
    double lr = args_->lr * (1.0 - progress);
    double wst = 0;
  
    int64_t eta = 2592000; // Default to one month in seconds (720 * 3600)
  
    if (progress > 0 && t >= 0) {
-     progress = progress * 100;
-     eta = t * (100 - progress) / progress;
+     eta = t * (1 - progress) / progress;
      wst = double(tokenCount_) / t / args_->thread;
    }
-   int32_t etah = eta / 3600;
-   int32_t etam = (eta % 3600) / 60;
+
+   return std::tuple<double, double, int64_t>(wst, lr, eta);
+ }
+
+ void FastText::printInfo(real progress, real loss, std::ostream& log_stream) {
+   double wst;
+   double lr;
+   int64_t eta;
+   std::tie<double, double, int64_t>(wst, lr, eta) = progressInfo(progress);
  
    log_stream << std::fixed;
    log_stream << "Progress: ";
-   log_stream << std::setprecision(1) << std::setw(5) << progress << "%";
+   log_stream << std::setprecision(1) << std::setw(5) << (progress * 100) << "%";
    log_stream << " words/sec/thread: " << std::setw(7) << int64_t(wst);
    log_stream << " lr: " << std::setw(9) << std::setprecision(6) << lr;
-   log_stream << " loss: " << std::setw(9) << std::setprecision(6) << loss;
-   log_stream << " ETA: " << std::setw(3) << etah;
-   log_stream << "h" << std::setw(2) << etam << "m";
+   log_stream << " avg.loss: " << std::setw(9) << std::setprecision(6) << loss;
+   log_stream << " ETA: " << utils::ClockPrint(eta);
    log_stream << std::flush;
  }
  
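printInfo is split in two above: progressInfo computes words/sec/thread, the decayed learning rate, and the ETA directly from the 0-1 progress fraction (eta = t * (1 - progress) / progress, so 120 s elapsed at progress 0.4 leaves an estimated 180 s), while printInfo only formats, labelling the loss column "avg.loss" and printing the ETA through utils::ClockPrint instead of hand-rolled hour/minute arithmetic. The old code scaled progress to a percentage in place before reusing it, which is why both the ETA formula and the "%" line change here.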
@@ -299,13 +311,16 @@ std::vector<int32_t> FastText::selectEmbeddings(int32_t cutoff) const {
    std::iota(idx.begin(), idx.end(), 0);
    auto eosid = dict_->getId(Dictionary::EOS);
    std::sort(idx.begin(), idx.end(), [&norms, eosid](size_t i1, size_t i2) {
+     if (i1 == eosid && i2 == eosid) { // satisfy strict weak ordering
+       return false;
+     }
      return eosid == i1 || (eosid != i2 && norms[i1] > norms[i2]);
    });
    idx.erase(idx.begin() + cutoff, idx.end());
    return idx;
  }
  
- void FastText::quantize(const Args& qargs) {
+ void FastText::quantize(const Args& qargs, const TrainCallback& callback) {
    if (args_->model != model_name::sup) {
      throw std::invalid_argument(
          "For now we only support quantization of supervised models");
@@ -337,10 +352,9 @@ void FastText::quantize(const Args& qargs) {
        args_->verbose = qargs.verbose;
        auto loss = createLoss(output_);
        model_ = std::make_shared<Model>(input, output, loss, normalizeGradient);
-       startThreads();
+       startThreads(callback);
      }
    }
-
    input_ = std::make_shared<QuantMatrix>(
        std::move(*(input.get())), qargs.dsub, qargs.qnorm);
  
@@ -348,7 +362,6 @@ void FastText::quantize(const Args& qargs) {
      output_ = std::make_shared<QuantMatrix>(
          std::move(*(output.get())), 2, qargs.qnorm);
    }
-
    quant_ = true;
    auto loss = createLoss(output_);
    model_ = std::make_shared<Model>(input_, output_, loss, normalizeGradient);
@@ -408,7 +421,7 @@ void FastText::skipgram(
  
  std::tuple<int64_t, double, double>
  FastText::test(std::istream& in, int32_t k, real threshold) {
-   Meter meter;
+   Meter meter(false);
    test(in, k, threshold, meter);
  
    return std::tuple<int64_t, double, double>(
@@ -420,6 +433,9 @@ void FastText::test(std::istream& in, int32_t k, real threshold, Meter& meter)
    std::vector<int32_t> line;
    std::vector<int32_t> labels;
    Predictions predictions;
+   Model::State state(args_->dim, dict_->nlabels(), 0);
+   in.clear();
+   in.seekg(0, std::ios_base::beg);
  
    while (in.peek() != EOF) {
      line.clear();
@@ -521,16 +537,6 @@ std::vector<std::pair<std::string, Vector>> FastText::getNgramVectors(
    return result;
  }
  
- // deprecated. use getNgramVectors instead
- void FastText::ngramVectors(std::string word) {
-   std::vector<std::pair<std::string, Vector>> ngramVectors =
-       getNgramVectors(word);
-
-   for (const auto& ngramVector : ngramVectors) {
-     std::cout << ngramVector.first << " " << ngramVector.second << std::endl;
-   }
- }
-
  void FastText::precomputeWordVectors(DenseMatrix& wordVectors) {
    Vector vec(args_->dim);
    wordVectors.zero();
@@ -598,17 +604,6 @@ std::vector<std::pair<real, std::string>> FastText::getNN(
    return heap;
  }
  
- // depracted. use getNN instead
- void FastText::findNN(
-     const DenseMatrix& wordVectors,
-     const Vector& query,
-     int32_t k,
-     const std::set<std::string>& banSet,
-     std::vector<std::pair<real, std::string>>& results) {
-   results.clear();
-   results = getNN(wordVectors, query, k, banSet);
- }
-
  std::vector<std::pair<real, std::string>> FastText::getAnalogies(
      int32_t k,
      const std::string& wordA,
@@ -630,52 +625,52 @@ std::vector<std::pair<real, std::string>> FastText::getAnalogies(
    return getNN(*wordVectors_, query, k, {wordA, wordB, wordC});
  }
  
- // depreacted, use getAnalogies instead
- void FastText::analogies(int32_t k) {
-   std::string prompt("Query triplet (A - B + C)? ");
-   std::string wordA, wordB, wordC;
-   std::cout << prompt;
-   while (true) {
-     std::cin >> wordA;
-     std::cin >> wordB;
-     std::cin >> wordC;
-     auto results = getAnalogies(k, wordA, wordB, wordC);
-
-     for (auto& pair : results) {
-       std::cout << pair.second << " " << pair.first << std::endl;
-     }
-     std::cout << prompt;
-   }
+ bool FastText::keepTraining(const int64_t ntokens) const {
+   return tokenCount_ < args_->epoch * ntokens && !trainException_;
  }
  
- void FastText::trainThread(int32_t threadId) {
+ void FastText::trainThread(int32_t threadId, const TrainCallback& callback) {
    std::ifstream ifs(args_->input);
    utils::seek(ifs, threadId * utils::size(ifs) / args_->thread);
  
-   Model::State state(args_->dim, output_->size(0), threadId);
+   Model::State state(args_->dim, output_->size(0), threadId + args_->seed);
  
    const int64_t ntokens = dict_->ntokens();
    int64_t localTokenCount = 0;
    std::vector<int32_t> line, labels;
-   while (tokenCount_ < args_->epoch * ntokens) {
-     real progress = real(tokenCount_) / (args_->epoch * ntokens);
-     real lr = args_->lr * (1.0 - progress);
-     if (args_->model == model_name::sup) {
-       localTokenCount += dict_->getLine(ifs, line, labels);
-       supervised(state, lr, line, labels);
-     } else if (args_->model == model_name::cbow) {
-       localTokenCount += dict_->getLine(ifs, line, state.rng);
-       cbow(state, lr, line);
-     } else if (args_->model == model_name::sg) {
-       localTokenCount += dict_->getLine(ifs, line, state.rng);
-       skipgram(state, lr, line);
-     }
-     if (localTokenCount > args_->lrUpdateRate) {
-       tokenCount_ += localTokenCount;
-       localTokenCount = 0;
-       if (threadId == 0 && args_->verbose > 1)
-         loss_ = state.getLoss();
+   uint64_t callbackCounter = 0;
+   try {
+     while (keepTraining(ntokens)) {
+       real progress = real(tokenCount_) / (args_->epoch * ntokens);
+       if (callback && ((callbackCounter++ % 64) == 0)) {
+         double wst;
+         double lr;
+         int64_t eta;
+         std::tie<double, double, int64_t>(wst, lr, eta) =
+             progressInfo(progress);
+         callback(progress, loss_, wst, lr, eta);
+       }
+       real lr = args_->lr * (1.0 - progress);
+       if (args_->model == model_name::sup) {
+         localTokenCount += dict_->getLine(ifs, line, labels);
+         supervised(state, lr, line, labels);
+       } else if (args_->model == model_name::cbow) {
+         localTokenCount += dict_->getLine(ifs, line, state.rng);
+         cbow(state, lr, line);
+       } else if (args_->model == model_name::sg) {
+         localTokenCount += dict_->getLine(ifs, line, state.rng);
+         skipgram(state, lr, line);
+       }
+       if (localTokenCount > args_->lrUpdateRate) {
+         tokenCount_ += localTokenCount;
+         localTokenCount = 0;
+         if (threadId == 0 && args_->verbose > 1) {
+           loss_ = state.getLoss();
+         }
+       }
      }
+   } catch (DenseMatrix::EncounteredNaNError&) {
+     trainException_ = std::current_exception();
    }
    if (threadId == 0)
      loss_ = state.getLoss();
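trainThread now accepts a progress callback and wraps the training loop in a try/catch, so a NaN surfacing in the weights sets trainException_ and stops every worker via keepTraining instead of crashing the process. The callback is invoked roughly every 64 loop iterations with (progress, loss, words/sec/thread, learning rate, ETA). A hedged sketch of passing one to train(); the TrainCallback typedef lives in fasttext.h and is not shown in this diff, so the parameter types below are inferred from the call site callback(progress, loss_, wst, lr, eta):

    // Sketch only: report training progress from the new callback hook.
    #include "fasttext.h"
    #include <cstdint>
    #include <iostream>

    using namespace fasttext;

    int main() {
      Args args;
      args.input = "train.txt";   // made-up paths
      args.output = "model";
      args.model = model_name::sup;

      FastText ft;
      ft.train(args, [](float progress, float loss, double wst, double lr,
                        int64_t eta) {
        std::cerr << "progress=" << progress * 100 << "% avg.loss=" << loss
                  << " eta=" << eta << "s" << std::endl;
      });
      ft.saveModel("model.bin");  // the no-argument saveModel() overload was removed
      return 0;
    }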
@@ -713,7 +708,7 @@ std::shared_ptr<Matrix> FastText::getInputMatrixFromFile(
    dict_->init();
    std::shared_ptr<DenseMatrix> input = std::make_shared<DenseMatrix>(
        dict_->nwords() + args_->bucket, args_->dim);
-   input->uniform(1.0 / args_->dim);
+   input->uniform(1.0 / args_->dim, args_->thread, args_->seed);
  
    for (size_t i = 0; i < n; i++) {
      int32_t idx = dict_->getId(words[i]);
@@ -727,14 +722,10 @@ std::shared_ptr<Matrix> FastText::getInputMatrixFromFile(
    return input;
  }
  
- void FastText::loadVectors(const std::string& filename) {
-   input_ = getInputMatrixFromFile(filename);
- }
-
  std::shared_ptr<Matrix> FastText::createRandomMatrix() const {
    std::shared_ptr<DenseMatrix> input = std::make_shared<DenseMatrix>(
        dict_->nwords() + args_->bucket, args_->dim);
-   input->uniform(1.0 / args_->dim);
+   input->uniform(1.0 / args_->dim, args_->thread, args_->seed);
  
    return input;
  }
@@ -749,7 +740,7 @@ std::shared_ptr<Matrix> FastText::createTrainOutputMatrix() const {
    return output;
  }
  
- void FastText::train(const Args& args) {
+ void FastText::train(const Args& args, const TrainCallback& callback) {
    args_ = std::make_shared<Args>(args);
    dict_ = std::make_shared<Dictionary>(args_);
    if (args_->input == "-") {
@@ -770,23 +761,38 @@ void FastText::train(const Args& args) {
      input_ = createRandomMatrix();
    }
    output_ = createTrainOutputMatrix();
+   quant_ = false;
    auto loss = createLoss(output_);
    bool normalizeGradient = (args_->model == model_name::sup);
    model_ = std::make_shared<Model>(input_, output_, loss, normalizeGradient);
-   startThreads();
+   startThreads(callback);
+ }
+
+ void FastText::abort() {
+   try {
+     throw AbortError();
+   } catch (AbortError&) {
+     trainException_ = std::current_exception();
+   }
  }
  
- void FastText::startThreads() {
+ void FastText::startThreads(const TrainCallback& callback) {
    start_ = std::chrono::steady_clock::now();
    tokenCount_ = 0;
    loss_ = -1;
+   trainException_ = nullptr;
    std::vector<std::thread> threads;
-   for (int32_t i = 0; i < args_->thread; i++) {
-     threads.push_back(std::thread([=]() { trainThread(i); }));
+   if (args_->thread > 1) {
+     for (int32_t i = 0; i < args_->thread; i++) {
+       threads.push_back(std::thread([=]() { trainThread(i, callback); }));
+     }
+   } else {
+     // webassembly can't instantiate `std::thread`
+     trainThread(0, callback);
    }
    const int64_t ntokens = dict_->ntokens();
    // Same condition as trainThread
-   while (tokenCount_ < args_->epoch * ntokens) {
+   while (keepTraining(ntokens)) {
      std::this_thread::sleep_for(std::chrono::milliseconds(100));
      if (loss_ >= 0 && args_->verbose > 1) {
        real progress = real(tokenCount_) / (args_->epoch * ntokens);
@@ -794,9 +800,14 @@ void FastText::startThreads() {
        printInfo(progress, loss_, std::cerr);
      }
    }
-   for (int32_t i = 0; i < args_->thread; i++) {
+   for (int32_t i = 0; i < threads.size(); i++) {
      threads[i].join();
    }
+   if (trainException_) {
+     std::exception_ptr exception = trainException_;
+     trainException_ = nullptr;
+     std::rethrow_exception(exception);
+   }
    if (args_->verbose > 0) {
      std::cerr << "\r";
      printInfo(1.0, loss_, std::cerr);