fasttext 0.1.2 → 0.2.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +19 -0
- data/LICENSE.txt +18 -18
- data/README.md +26 -19
- data/ext/fasttext/ext.cpp +131 -134
- data/ext/fasttext/extconf.rb +2 -4
- data/lib/fasttext/classifier.rb +23 -10
- data/lib/fasttext/model.rb +10 -0
- data/lib/fasttext/vectorizer.rb +11 -5
- data/lib/fasttext/version.rb +1 -1
- data/vendor/fastText/README.md +3 -3
- data/vendor/fastText/src/args.cc +179 -6
- data/vendor/fastText/src/args.h +29 -1
- data/vendor/fastText/src/autotune.cc +477 -0
- data/vendor/fastText/src/autotune.h +89 -0
- data/vendor/fastText/src/densematrix.cc +27 -7
- data/vendor/fastText/src/densematrix.h +10 -2
- data/vendor/fastText/src/fasttext.cc +125 -114
- data/vendor/fastText/src/fasttext.h +31 -52
- data/vendor/fastText/src/main.cc +32 -13
- data/vendor/fastText/src/meter.cc +148 -2
- data/vendor/fastText/src/meter.h +24 -2
- data/vendor/fastText/src/model.cc +0 -1
- data/vendor/fastText/src/real.h +0 -1
- data/vendor/fastText/src/utils.cc +25 -0
- data/vendor/fastText/src/utils.h +29 -0
- data/vendor/fastText/src/vector.cc +0 -1
- metadata +14 -69
- data/lib/fasttext/ext.bundle +0 -0
@@ -47,7 +47,8 @@ std::shared_ptr<Loss> FastText::createLoss(std::shared_ptr<Matrix>& output) {
|
|
47
47
|
}
|
48
48
|
}
|
49
49
|
|
50
|
-
FastText::FastText()
|
50
|
+
FastText::FastText()
|
51
|
+
: quant_(false), wordVectors_(nullptr), trainException_(nullptr) {}
|
51
52
|
|
52
53
|
void FastText::addInputVector(Vector& vec, int32_t ind) const {
|
53
54
|
vec.addRow(*input_, ind);
|
@@ -69,6 +70,19 @@ std::shared_ptr<const DenseMatrix> FastText::getInputMatrix() const {
|
|
69
70
|
return std::dynamic_pointer_cast<DenseMatrix>(input_);
|
70
71
|
}
|
71
72
|
|
73
|
+
void FastText::setMatrices(
|
74
|
+
const std::shared_ptr<DenseMatrix>& inputMatrix,
|
75
|
+
const std::shared_ptr<DenseMatrix>& outputMatrix) {
|
76
|
+
assert(input_->size(1) == output_->size(1));
|
77
|
+
|
78
|
+
input_ = std::dynamic_pointer_cast<Matrix>(inputMatrix);
|
79
|
+
output_ = std::dynamic_pointer_cast<Matrix>(outputMatrix);
|
80
|
+
wordVectors_.reset();
|
81
|
+
args_->dim = input_->size(1);
|
82
|
+
|
83
|
+
buildModel();
|
84
|
+
}
|
85
|
+
|
72
86
|
std::shared_ptr<const DenseMatrix> FastText::getOutputMatrix() const {
|
73
87
|
if (quant_ && args_->qout) {
|
74
88
|
throw std::runtime_error("Can't export quantized matrix");
|
@@ -86,6 +100,14 @@ int32_t FastText::getSubwordId(const std::string& subword) const {
|
|
86
100
|
return dict_->nwords() + h;
|
87
101
|
}
|
88
102
|
|
103
|
+
int32_t FastText::getLabelId(const std::string& label) const {
|
104
|
+
int32_t labelId = dict_->getId(label);
|
105
|
+
if (labelId != -1) {
|
106
|
+
labelId -= dict_->nwords();
|
107
|
+
}
|
108
|
+
return labelId;
|
109
|
+
}
|
110
|
+
|
89
111
|
void FastText::getWordVector(Vector& vec, const std::string& word) const {
|
90
112
|
const std::vector<int32_t>& ngrams = dict_->getSubwords(word);
|
91
113
|
vec.zero();
|
@@ -97,10 +119,6 @@ void FastText::getWordVector(Vector& vec, const std::string& word) const {
|
|
97
119
|
}
|
98
120
|
}
|
99
121
|
|
100
|
-
void FastText::getVector(Vector& vec, const std::string& word) const {
|
101
|
-
getWordVector(vec, word);
|
102
|
-
}
|
103
|
-
|
104
122
|
void FastText::getSubwordVector(Vector& vec, const std::string& subword) const {
|
105
123
|
vec.zero();
|
106
124
|
int32_t h = dict_->hash(subword) % args_->bucket;
|
@@ -109,6 +127,9 @@ void FastText::getSubwordVector(Vector& vec, const std::string& subword) const {
|
|
109
127
|
}
|
110
128
|
|
111
129
|
void FastText::saveVectors(const std::string& filename) {
|
130
|
+
if (!input_ || !output_) {
|
131
|
+
throw std::runtime_error("Model never trained");
|
132
|
+
}
|
112
133
|
std::ofstream ofs(filename);
|
113
134
|
if (!ofs.is_open()) {
|
114
135
|
throw std::invalid_argument(
|
@@ -124,10 +145,6 @@ void FastText::saveVectors(const std::string& filename) {
|
|
124
145
|
ofs.close();
|
125
146
|
}
|
126
147
|
|
127
|
-
void FastText::saveVectors() {
|
128
|
-
saveVectors(args_->output + ".vec");
|
129
|
-
}
|
130
|
-
|
131
148
|
void FastText::saveOutput(const std::string& filename) {
|
132
149
|
std::ofstream ofs(filename);
|
133
150
|
if (!ofs.is_open()) {
|
@@ -152,10 +169,6 @@ void FastText::saveOutput(const std::string& filename) {
|
|
152
169
|
ofs.close();
|
153
170
|
}
|
154
171
|
|
155
|
-
void FastText::saveOutput() {
|
156
|
-
saveOutput(args_->output + ".output");
|
157
|
-
}
|
158
|
-
|
159
172
|
bool FastText::checkModel(std::istream& in) {
|
160
173
|
int32_t magic;
|
161
174
|
in.read((char*)&(magic), sizeof(int32_t));
|
@@ -176,21 +189,14 @@ void FastText::signModel(std::ostream& out) {
|
|
176
189
|
out.write((char*)&(version), sizeof(int32_t));
|
177
190
|
}
|
178
191
|
|
179
|
-
void FastText::saveModel() {
|
180
|
-
std::string fn(args_->output);
|
181
|
-
if (quant_) {
|
182
|
-
fn += ".ftz";
|
183
|
-
} else {
|
184
|
-
fn += ".bin";
|
185
|
-
}
|
186
|
-
saveModel(fn);
|
187
|
-
}
|
188
|
-
|
189
192
|
void FastText::saveModel(const std::string& filename) {
|
190
193
|
std::ofstream ofs(filename, std::ofstream::binary);
|
191
194
|
if (!ofs.is_open()) {
|
192
195
|
throw std::invalid_argument(filename + " cannot be opened for saving!");
|
193
196
|
}
|
197
|
+
if (!input_ || !output_) {
|
198
|
+
throw std::runtime_error("Model never trained");
|
199
|
+
}
|
194
200
|
signModel(ofs);
|
195
201
|
args_->save(ofs);
|
196
202
|
dict_->save(ofs);
|
@@ -224,6 +230,12 @@ std::vector<int64_t> FastText::getTargetCounts() const {
|
|
224
230
|
}
|
225
231
|
}
|
226
232
|
|
233
|
+
void FastText::buildModel() {
|
234
|
+
auto loss = createLoss(output_);
|
235
|
+
bool normalizeGradient = (args_->model == model_name::sup);
|
236
|
+
model_ = std::make_shared<Model>(input_, output_, loss, normalizeGradient);
|
237
|
+
}
|
238
|
+
|
227
239
|
void FastText::loadModel(std::istream& in) {
|
228
240
|
args_ = std::make_shared<Args>();
|
229
241
|
input_ = std::make_shared<DenseMatrix>();
|
@@ -256,37 +268,37 @@ void FastText::loadModel(std::istream& in) {
|
|
256
268
|
}
|
257
269
|
output_->load(in);
|
258
270
|
|
259
|
-
|
260
|
-
bool normalizeGradient = (args_->model == model_name::sup);
|
261
|
-
model_ = std::make_shared<Model>(input_, output_, loss, normalizeGradient);
|
271
|
+
buildModel();
|
262
272
|
}
|
263
273
|
|
264
|
-
|
265
|
-
|
266
|
-
double t =
|
267
|
-
std::chrono::duration_cast<std::chrono::duration<double>>(end - start_)
|
268
|
-
.count();
|
274
|
+
std::tuple<int64_t, double, double> FastText::progressInfo(real progress) {
|
275
|
+
double t = utils::getDuration(start_, std::chrono::steady_clock::now());
|
269
276
|
double lr = args_->lr * (1.0 - progress);
|
270
277
|
double wst = 0;
|
271
278
|
|
272
279
|
int64_t eta = 2592000; // Default to one month in seconds (720 * 3600)
|
273
280
|
|
274
281
|
if (progress > 0 && t >= 0) {
|
275
|
-
|
276
|
-
eta = t * (100 - progress) / progress;
|
282
|
+
eta = t * (1 - progress) / progress;
|
277
283
|
wst = double(tokenCount_) / t / args_->thread;
|
278
284
|
}
|
279
|
-
|
280
|
-
|
285
|
+
|
286
|
+
return std::tuple<double, double, int64_t>(wst, lr, eta);
|
287
|
+
}
|
288
|
+
|
289
|
+
void FastText::printInfo(real progress, real loss, std::ostream& log_stream) {
|
290
|
+
double wst;
|
291
|
+
double lr;
|
292
|
+
int64_t eta;
|
293
|
+
std::tie<double, double, int64_t>(wst, lr, eta) = progressInfo(progress);
|
281
294
|
|
282
295
|
log_stream << std::fixed;
|
283
296
|
log_stream << "Progress: ";
|
284
|
-
log_stream << std::setprecision(1) << std::setw(5) << progress << "%";
|
297
|
+
log_stream << std::setprecision(1) << std::setw(5) << (progress * 100) << "%";
|
285
298
|
log_stream << " words/sec/thread: " << std::setw(7) << int64_t(wst);
|
286
299
|
log_stream << " lr: " << std::setw(9) << std::setprecision(6) << lr;
|
287
|
-
log_stream << " loss: " << std::setw(9) << std::setprecision(6) << loss;
|
288
|
-
log_stream << " ETA: " <<
|
289
|
-
log_stream << "h" << std::setw(2) << etam << "m";
|
300
|
+
log_stream << " avg.loss: " << std::setw(9) << std::setprecision(6) << loss;
|
301
|
+
log_stream << " ETA: " << utils::ClockPrint(eta);
|
290
302
|
log_stream << std::flush;
|
291
303
|
}
|
292
304
|
|
@@ -299,13 +311,16 @@ std::vector<int32_t> FastText::selectEmbeddings(int32_t cutoff) const {
|
|
299
311
|
std::iota(idx.begin(), idx.end(), 0);
|
300
312
|
auto eosid = dict_->getId(Dictionary::EOS);
|
301
313
|
std::sort(idx.begin(), idx.end(), [&norms, eosid](size_t i1, size_t i2) {
|
314
|
+
if (i1 == eosid && i2 == eosid) { // satisfy strict weak ordering
|
315
|
+
return false;
|
316
|
+
}
|
302
317
|
return eosid == i1 || (eosid != i2 && norms[i1] > norms[i2]);
|
303
318
|
});
|
304
319
|
idx.erase(idx.begin() + cutoff, idx.end());
|
305
320
|
return idx;
|
306
321
|
}
|
307
322
|
|
308
|
-
void FastText::quantize(const Args& qargs) {
|
323
|
+
void FastText::quantize(const Args& qargs, const TrainCallback& callback) {
|
309
324
|
if (args_->model != model_name::sup) {
|
310
325
|
throw std::invalid_argument(
|
311
326
|
"For now we only support quantization of supervised models");
|
@@ -337,10 +352,9 @@ void FastText::quantize(const Args& qargs) {
|
|
337
352
|
args_->verbose = qargs.verbose;
|
338
353
|
auto loss = createLoss(output_);
|
339
354
|
model_ = std::make_shared<Model>(input, output, loss, normalizeGradient);
|
340
|
-
startThreads();
|
355
|
+
startThreads(callback);
|
341
356
|
}
|
342
357
|
}
|
343
|
-
|
344
358
|
input_ = std::make_shared<QuantMatrix>(
|
345
359
|
std::move(*(input.get())), qargs.dsub, qargs.qnorm);
|
346
360
|
|
@@ -348,7 +362,6 @@ void FastText::quantize(const Args& qargs) {
|
|
348
362
|
output_ = std::make_shared<QuantMatrix>(
|
349
363
|
std::move(*(output.get())), 2, qargs.qnorm);
|
350
364
|
}
|
351
|
-
|
352
365
|
quant_ = true;
|
353
366
|
auto loss = createLoss(output_);
|
354
367
|
model_ = std::make_shared<Model>(input_, output_, loss, normalizeGradient);
|
@@ -408,7 +421,7 @@ void FastText::skipgram(
|
|
408
421
|
|
409
422
|
std::tuple<int64_t, double, double>
|
410
423
|
FastText::test(std::istream& in, int32_t k, real threshold) {
|
411
|
-
Meter meter;
|
424
|
+
Meter meter(false);
|
412
425
|
test(in, k, threshold, meter);
|
413
426
|
|
414
427
|
return std::tuple<int64_t, double, double>(
|
@@ -420,6 +433,9 @@ void FastText::test(std::istream& in, int32_t k, real threshold, Meter& meter)
|
|
420
433
|
std::vector<int32_t> line;
|
421
434
|
std::vector<int32_t> labels;
|
422
435
|
Predictions predictions;
|
436
|
+
Model::State state(args_->dim, dict_->nlabels(), 0);
|
437
|
+
in.clear();
|
438
|
+
in.seekg(0, std::ios_base::beg);
|
423
439
|
|
424
440
|
while (in.peek() != EOF) {
|
425
441
|
line.clear();
|
@@ -521,16 +537,6 @@ std::vector<std::pair<std::string, Vector>> FastText::getNgramVectors(
|
|
521
537
|
return result;
|
522
538
|
}
|
523
539
|
|
524
|
-
// deprecated. use getNgramVectors instead
|
525
|
-
void FastText::ngramVectors(std::string word) {
|
526
|
-
std::vector<std::pair<std::string, Vector>> ngramVectors =
|
527
|
-
getNgramVectors(word);
|
528
|
-
|
529
|
-
for (const auto& ngramVector : ngramVectors) {
|
530
|
-
std::cout << ngramVector.first << " " << ngramVector.second << std::endl;
|
531
|
-
}
|
532
|
-
}
|
533
|
-
|
534
540
|
void FastText::precomputeWordVectors(DenseMatrix& wordVectors) {
|
535
541
|
Vector vec(args_->dim);
|
536
542
|
wordVectors.zero();
|
@@ -598,17 +604,6 @@ std::vector<std::pair<real, std::string>> FastText::getNN(
|
|
598
604
|
return heap;
|
599
605
|
}
|
600
606
|
|
601
|
-
// deprecated. use getNN instead
|
602
|
-
void FastText::findNN(
|
603
|
-
const DenseMatrix& wordVectors,
|
604
|
-
const Vector& query,
|
605
|
-
int32_t k,
|
606
|
-
const std::set<std::string>& banSet,
|
607
|
-
std::vector<std::pair<real, std::string>>& results) {
|
608
|
-
results.clear();
|
609
|
-
results = getNN(wordVectors, query, k, banSet);
|
610
|
-
}
|
611
|
-
|
612
607
|
std::vector<std::pair<real, std::string>> FastText::getAnalogies(
|
613
608
|
int32_t k,
|
614
609
|
const std::string& wordA,
|
@@ -630,52 +625,52 @@ std::vector<std::pair<real, std::string>> FastText::getAnalogies(
|
|
630
625
|
return getNN(*wordVectors_, query, k, {wordA, wordB, wordC});
|
631
626
|
}
|
632
627
|
|
633
|
-
|
634
|
-
|
635
|
-
std::string prompt("Query triplet (A - B + C)? ");
|
636
|
-
std::string wordA, wordB, wordC;
|
637
|
-
std::cout << prompt;
|
638
|
-
while (true) {
|
639
|
-
std::cin >> wordA;
|
640
|
-
std::cin >> wordB;
|
641
|
-
std::cin >> wordC;
|
642
|
-
auto results = getAnalogies(k, wordA, wordB, wordC);
|
643
|
-
|
644
|
-
for (auto& pair : results) {
|
645
|
-
std::cout << pair.second << " " << pair.first << std::endl;
|
646
|
-
}
|
647
|
-
std::cout << prompt;
|
648
|
-
}
|
628
|
+
bool FastText::keepTraining(const int64_t ntokens) const {
|
629
|
+
return tokenCount_ < args_->epoch * ntokens && !trainException_;
|
649
630
|
}
|
650
631
|
|
651
|
-
void FastText::trainThread(int32_t threadId) {
|
632
|
+
void FastText::trainThread(int32_t threadId, const TrainCallback& callback) {
|
652
633
|
std::ifstream ifs(args_->input);
|
653
634
|
utils::seek(ifs, threadId * utils::size(ifs) / args_->thread);
|
654
635
|
|
655
|
-
Model::State state(args_->dim, output_->size(0), threadId);
|
636
|
+
Model::State state(args_->dim, output_->size(0), threadId + args_->seed);
|
656
637
|
|
657
638
|
const int64_t ntokens = dict_->ntokens();
|
658
639
|
int64_t localTokenCount = 0;
|
659
640
|
std::vector<int32_t> line, labels;
|
660
|
-
|
661
|
-
|
662
|
-
|
663
|
-
|
664
|
-
|
665
|
-
|
666
|
-
|
667
|
-
|
668
|
-
|
669
|
-
|
670
|
-
|
671
|
-
|
672
|
-
|
673
|
-
|
674
|
-
|
675
|
-
|
676
|
-
if (
|
677
|
-
|
641
|
+
uint64_t callbackCounter = 0;
|
642
|
+
try {
|
643
|
+
while (keepTraining(ntokens)) {
|
644
|
+
real progress = real(tokenCount_) / (args_->epoch * ntokens);
|
645
|
+
if (callback && ((callbackCounter++ % 64) == 0)) {
|
646
|
+
double wst;
|
647
|
+
double lr;
|
648
|
+
int64_t eta;
|
649
|
+
std::tie<double, double, int64_t>(wst, lr, eta) =
|
650
|
+
progressInfo(progress);
|
651
|
+
callback(progress, loss_, wst, lr, eta);
|
652
|
+
}
|
653
|
+
real lr = args_->lr * (1.0 - progress);
|
654
|
+
if (args_->model == model_name::sup) {
|
655
|
+
localTokenCount += dict_->getLine(ifs, line, labels);
|
656
|
+
supervised(state, lr, line, labels);
|
657
|
+
} else if (args_->model == model_name::cbow) {
|
658
|
+
localTokenCount += dict_->getLine(ifs, line, state.rng);
|
659
|
+
cbow(state, lr, line);
|
660
|
+
} else if (args_->model == model_name::sg) {
|
661
|
+
localTokenCount += dict_->getLine(ifs, line, state.rng);
|
662
|
+
skipgram(state, lr, line);
|
663
|
+
}
|
664
|
+
if (localTokenCount > args_->lrUpdateRate) {
|
665
|
+
tokenCount_ += localTokenCount;
|
666
|
+
localTokenCount = 0;
|
667
|
+
if (threadId == 0 && args_->verbose > 1) {
|
668
|
+
loss_ = state.getLoss();
|
669
|
+
}
|
670
|
+
}
|
678
671
|
}
|
672
|
+
} catch (DenseMatrix::EncounteredNaNError&) {
|
673
|
+
trainException_ = std::current_exception();
|
679
674
|
}
|
680
675
|
if (threadId == 0)
|
681
676
|
loss_ = state.getLoss();
|
@@ -713,7 +708,7 @@ std::shared_ptr<Matrix> FastText::getInputMatrixFromFile(
|
|
713
708
|
dict_->init();
|
714
709
|
std::shared_ptr<DenseMatrix> input = std::make_shared<DenseMatrix>(
|
715
710
|
dict_->nwords() + args_->bucket, args_->dim);
|
716
|
-
input->uniform(1.0 / args_->dim);
|
711
|
+
input->uniform(1.0 / args_->dim, args_->thread, args_->seed);
|
717
712
|
|
718
713
|
for (size_t i = 0; i < n; i++) {
|
719
714
|
int32_t idx = dict_->getId(words[i]);
|
@@ -727,14 +722,10 @@ std::shared_ptr<Matrix> FastText::getInputMatrixFromFile(
|
|
727
722
|
return input;
|
728
723
|
}
|
729
724
|
|
730
|
-
void FastText::loadVectors(const std::string& filename) {
|
731
|
-
input_ = getInputMatrixFromFile(filename);
|
732
|
-
}
|
733
|
-
|
734
725
|
std::shared_ptr<Matrix> FastText::createRandomMatrix() const {
|
735
726
|
std::shared_ptr<DenseMatrix> input = std::make_shared<DenseMatrix>(
|
736
727
|
dict_->nwords() + args_->bucket, args_->dim);
|
737
|
-
input->uniform(1.0 / args_->dim);
|
728
|
+
input->uniform(1.0 / args_->dim, args_->thread, args_->seed);
|
738
729
|
|
739
730
|
return input;
|
740
731
|
}
|
@@ -749,7 +740,7 @@ std::shared_ptr<Matrix> FastText::createTrainOutputMatrix() const {
|
|
749
740
|
return output;
|
750
741
|
}
|
751
742
|
|
752
|
-
void FastText::train(const Args& args) {
|
743
|
+
void FastText::train(const Args& args, const TrainCallback& callback) {
|
753
744
|
args_ = std::make_shared<Args>(args);
|
754
745
|
dict_ = std::make_shared<Dictionary>(args_);
|
755
746
|
if (args_->input == "-") {
|
@@ -770,23 +761,38 @@ void FastText::train(const Args& args) {
|
|
770
761
|
input_ = createRandomMatrix();
|
771
762
|
}
|
772
763
|
output_ = createTrainOutputMatrix();
|
764
|
+
quant_ = false;
|
773
765
|
auto loss = createLoss(output_);
|
774
766
|
bool normalizeGradient = (args_->model == model_name::sup);
|
775
767
|
model_ = std::make_shared<Model>(input_, output_, loss, normalizeGradient);
|
776
|
-
startThreads();
|
768
|
+
startThreads(callback);
|
769
|
+
}
|
770
|
+
|
771
|
+
void FastText::abort() {
|
772
|
+
try {
|
773
|
+
throw AbortError();
|
774
|
+
} catch (AbortError&) {
|
775
|
+
trainException_ = std::current_exception();
|
776
|
+
}
|
777
777
|
}
|
778
778
|
|
779
|
-
void FastText::startThreads() {
|
779
|
+
void FastText::startThreads(const TrainCallback& callback) {
|
780
780
|
start_ = std::chrono::steady_clock::now();
|
781
781
|
tokenCount_ = 0;
|
782
782
|
loss_ = -1;
|
783
|
+
trainException_ = nullptr;
|
783
784
|
std::vector<std::thread> threads;
|
784
|
-
|
785
|
-
|
785
|
+
if (args_->thread > 1) {
|
786
|
+
for (int32_t i = 0; i < args_->thread; i++) {
|
787
|
+
threads.push_back(std::thread([=]() { trainThread(i, callback); }));
|
788
|
+
}
|
789
|
+
} else {
|
790
|
+
// webassembly can't instantiate `std::thread`
|
791
|
+
trainThread(0, callback);
|
786
792
|
}
|
787
793
|
const int64_t ntokens = dict_->ntokens();
|
788
794
|
// Same condition as trainThread
|
789
|
-
while (
|
795
|
+
while (keepTraining(ntokens)) {
|
790
796
|
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
791
797
|
if (loss_ >= 0 && args_->verbose > 1) {
|
792
798
|
real progress = real(tokenCount_) / (args_->epoch * ntokens);
|
@@ -794,9 +800,14 @@ void FastText::startThreads() {
|
|
794
800
|
printInfo(progress, loss_, std::cerr);
|
795
801
|
}
|
796
802
|
}
|
797
|
-
for (int32_t i = 0; i <
|
803
|
+
for (int32_t i = 0; i < threads.size(); i++) {
|
798
804
|
threads[i].join();
|
799
805
|
}
|
806
|
+
if (trainException_) {
|
807
|
+
std::exception_ptr exception = trainException_;
|
808
|
+
trainException_ = nullptr;
|
809
|
+
std::rethrow_exception(exception);
|
810
|
+
}
|
800
811
|
if (args_->verbose > 0) {
|
801
812
|
std::cerr << "\r";
|
802
813
|
printInfo(1.0, loss_, std::cerr);
|
@@ -12,6 +12,7 @@
|
|
12
12
|
|
13
13
|
#include <atomic>
|
14
14
|
#include <chrono>
|
15
|
+
#include <functional>
|
15
16
|
#include <iostream>
|
16
17
|
#include <memory>
|
17
18
|
#include <queue>
|
@@ -31,24 +32,29 @@
|
|
31
32
|
namespace fasttext {
|
32
33
|
|
33
34
|
class FastText {
|
35
|
+
public:
|
36
|
+
using TrainCallback =
|
37
|
+
std::function<void(float, float, double, double, int64_t)>;
|
38
|
+
|
34
39
|
protected:
|
35
40
|
std::shared_ptr<Args> args_;
|
36
41
|
std::shared_ptr<Dictionary> dict_;
|
37
|
-
|
38
42
|
std::shared_ptr<Matrix> input_;
|
39
43
|
std::shared_ptr<Matrix> output_;
|
40
|
-
|
41
44
|
std::shared_ptr<Model> model_;
|
42
|
-
|
43
45
|
std::atomic<int64_t> tokenCount_{};
|
44
46
|
std::atomic<real> loss_{};
|
45
|
-
|
46
47
|
std::chrono::steady_clock::time_point start_;
|
48
|
+
bool quant_;
|
49
|
+
int32_t version;
|
50
|
+
std::unique_ptr<DenseMatrix> wordVectors_;
|
51
|
+
std::exception_ptr trainException_;
|
52
|
+
|
47
53
|
void signModel(std::ostream&);
|
48
54
|
bool checkModel(std::istream&);
|
49
|
-
void startThreads();
|
55
|
+
void startThreads(const TrainCallback& callback = {});
|
50
56
|
void addInputVector(Vector&, int32_t) const;
|
51
|
-
void trainThread(int32_t);
|
57
|
+
void trainThread(int32_t, const TrainCallback& callback);
|
52
58
|
std::vector<std::pair<real, std::string>> getNN(
|
53
59
|
const DenseMatrix& wordVectors,
|
54
60
|
const Vector& queryVec,
|
@@ -68,10 +74,11 @@ class FastText {
|
|
68
74
|
const std::vector<int32_t>& labels);
|
69
75
|
void cbow(Model::State& state, real lr, const std::vector<int32_t>& line);
|
70
76
|
void skipgram(Model::State& state, real lr, const std::vector<int32_t>& line);
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
77
|
+
std::vector<int32_t> selectEmbeddings(int32_t cutoff) const;
|
78
|
+
void precomputeWordVectors(DenseMatrix& wordVectors);
|
79
|
+
bool keepTraining(const int64_t ntokens) const;
|
80
|
+
void buildModel();
|
81
|
+
std::tuple<int64_t, double, double> progressInfo(real progress);
|
75
82
|
|
76
83
|
public:
|
77
84
|
FastText();
|
@@ -80,6 +87,8 @@ class FastText {
|
|
80
87
|
|
81
88
|
int32_t getSubwordId(const std::string& subword) const;
|
82
89
|
|
90
|
+
int32_t getLabelId(const std::string& label) const;
|
91
|
+
|
83
92
|
void getWordVector(Vector& vec, const std::string& word) const;
|
84
93
|
|
85
94
|
void getSubwordVector(Vector& vec, const std::string& subword) const;
|
@@ -95,6 +104,10 @@ class FastText {
|
|
95
104
|
|
96
105
|
std::shared_ptr<const DenseMatrix> getInputMatrix() const;
|
97
106
|
|
107
|
+
void setMatrices(
|
108
|
+
const std::shared_ptr<DenseMatrix>& inputMatrix,
|
109
|
+
const std::shared_ptr<DenseMatrix>& outputMatrix);
|
110
|
+
|
98
111
|
std::shared_ptr<const DenseMatrix> getOutputMatrix() const;
|
99
112
|
|
100
113
|
void saveVectors(const std::string& filename);
|
@@ -109,7 +122,7 @@ class FastText {
|
|
109
122
|
|
110
123
|
void getSentenceVector(std::istream& in, Vector& vec);
|
111
124
|
|
112
|
-
void quantize(const Args& qargs);
|
125
|
+
void quantize(const Args& qargs, const TrainCallback& callback = {});
|
113
126
|
|
114
127
|
std::tuple<int64_t, double, double>
|
115
128
|
test(std::istream& in, int32_t k, real threshold = 0.0);
|
@@ -141,51 +154,17 @@ class FastText {
|
|
141
154
|
const std::string& wordB,
|
142
155
|
const std::string& wordC);
|
143
156
|
|
144
|
-
void train(const Args& args);
|
157
|
+
void train(const Args& args, const TrainCallback& callback = {});
|
158
|
+
|
159
|
+
void abort();
|
145
160
|
|
146
161
|
int getDimension() const;
|
147
162
|
|
148
163
|
bool isQuant() const;
|
149
164
|
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
"getVector is being deprecated and replaced by getWordVector.")
|
155
|
-
void getVector(Vector& vec, const std::string& word) const;
|
156
|
-
|
157
|
-
FASTTEXT_DEPRECATED(
|
158
|
-
"ngramVectors is being deprecated and replaced by getNgramVectors.")
|
159
|
-
void ngramVectors(std::string word);
|
160
|
-
|
161
|
-
FASTTEXT_DEPRECATED(
|
162
|
-
"analogies is being deprecated and replaced by getAnalogies.")
|
163
|
-
void analogies(int32_t k);
|
164
|
-
|
165
|
-
FASTTEXT_DEPRECATED("selectEmbeddings is being deprecated.")
|
166
|
-
std::vector<int32_t> selectEmbeddings(int32_t cutoff) const;
|
167
|
-
|
168
|
-
FASTTEXT_DEPRECATED(
|
169
|
-
"saveVectors is being deprecated, please use the other signature.")
|
170
|
-
void saveVectors();
|
171
|
-
|
172
|
-
FASTTEXT_DEPRECATED(
|
173
|
-
"saveOutput is being deprecated, please use the other signature.")
|
174
|
-
void saveOutput();
|
175
|
-
|
176
|
-
FASTTEXT_DEPRECATED(
|
177
|
-
"saveModel is being deprecated, please use the other signature.")
|
178
|
-
void saveModel();
|
179
|
-
|
180
|
-
FASTTEXT_DEPRECATED("precomputeWordVectors is being deprecated.")
|
181
|
-
void precomputeWordVectors(DenseMatrix& wordVectors);
|
182
|
-
|
183
|
-
FASTTEXT_DEPRECATED("findNN is being deprecated and replaced by getNN.")
|
184
|
-
void findNN(
|
185
|
-
const DenseMatrix& wordVectors,
|
186
|
-
const Vector& query,
|
187
|
-
int32_t k,
|
188
|
-
const std::set<std::string>& banSet,
|
189
|
-
std::vector<std::pair<real, std::string>>& results);
|
165
|
+
class AbortError : public std::runtime_error {
|
166
|
+
public:
|
167
|
+
AbortError() : std::runtime_error("Aborted.") {}
|
168
|
+
};
|
190
169
|
};
|
191
170
|
} // namespace fasttext
|
data/vendor/fastText/src/main.cc
CHANGED
@@ -11,6 +11,7 @@
|
|
11
11
|
#include <queue>
|
12
12
|
#include <stdexcept>
|
13
13
|
#include "args.h"
|
14
|
+
#include "autotune.h"
|
14
15
|
#include "fasttext.h"
|
15
16
|
|
16
17
|
using namespace fasttext;
|
@@ -20,19 +21,25 @@ void printUsage() {
|
|
20
21
|
<< "usage: fasttext <command> <args>\n\n"
|
21
22
|
<< "The commands supported by fasttext are:\n\n"
|
22
23
|
<< " supervised train a supervised classifier\n"
|
23
|
-
<< " quantize quantize a model to reduce the memory
|
24
|
+
<< " quantize quantize a model to reduce the memory "
|
25
|
+
"usage\n"
|
24
26
|
<< " test evaluate a supervised classifier\n"
|
25
|
-
<< " test-label print labels with precision and recall
|
27
|
+
<< " test-label print labels with precision and recall "
|
28
|
+
"scores\n"
|
26
29
|
<< " predict predict most likely labels\n"
|
27
|
-
<< " predict-prob predict most likely labels with
|
30
|
+
<< " predict-prob predict most likely labels with "
|
31
|
+
"probabilities\n"
|
28
32
|
<< " skipgram train a skipgram model\n"
|
29
33
|
<< " cbow train a cbow model\n"
|
30
34
|
<< " print-word-vectors print word vectors given a trained model\n"
|
31
|
-
<< " print-sentence-vectors print sentence vectors given a trained
|
32
|
-
|
35
|
+
<< " print-sentence-vectors print sentence vectors given a trained "
|
36
|
+
"model\n"
|
37
|
+
<< " print-ngrams print ngrams given a trained model and "
|
38
|
+
"word\n"
|
33
39
|
<< " nn query for nearest neighbors\n"
|
34
40
|
<< " analogies query for analogies\n"
|
35
|
-
<< " dump dump arguments,dictionary,input/output
|
41
|
+
<< " dump dump arguments,dictionary,input/output "
|
42
|
+
"vectors\n"
|
36
43
|
<< std::endl;
|
37
44
|
}
|
38
45
|
|
@@ -141,7 +148,7 @@ void test(const std::vector<std::string>& args) {
|
|
141
148
|
FastText fasttext;
|
142
149
|
fasttext.loadModel(model);
|
143
150
|
|
144
|
-
Meter meter;
|
151
|
+
Meter meter(false);
|
145
152
|
|
146
153
|
if (input == "-") {
|
147
154
|
fasttext.test(std::cin, k, threshold, meter);
|
@@ -351,19 +358,31 @@ void analogies(const std::vector<std::string> args) {
|
|
351
358
|
void train(const std::vector<std::string> args) {
|
352
359
|
Args a = Args();
|
353
360
|
a.parseArgs(args);
|
354
|
-
FastText fasttext;
|
355
|
-
std::string outputFileName
|
361
|
+
std::shared_ptr<FastText> fasttext = std::make_shared<FastText>();
|
362
|
+
std::string outputFileName;
|
363
|
+
|
364
|
+
if (a.hasAutotune() &&
|
365
|
+
a.getAutotuneModelSize() != Args::kUnlimitedModelSize) {
|
366
|
+
outputFileName = a.output + ".ftz";
|
367
|
+
} else {
|
368
|
+
outputFileName = a.output + ".bin";
|
369
|
+
}
|
356
370
|
std::ofstream ofs(outputFileName);
|
357
371
|
if (!ofs.is_open()) {
|
358
372
|
throw std::invalid_argument(
|
359
373
|
outputFileName + " cannot be opened for saving.");
|
360
374
|
}
|
361
375
|
ofs.close();
|
362
|
-
|
363
|
-
|
364
|
-
|
376
|
+
if (a.hasAutotune()) {
|
377
|
+
Autotune autotune(fasttext);
|
378
|
+
autotune.train(a);
|
379
|
+
} else {
|
380
|
+
fasttext->train(a);
|
381
|
+
}
|
382
|
+
fasttext->saveModel(outputFileName);
|
383
|
+
fasttext->saveVectors(a.output + ".vec");
|
365
384
|
if (a.saveOutput) {
|
366
|
-
fasttext
|
385
|
+
fasttext->saveOutput(a.output + ".output");
|
367
386
|
}
|
368
387
|
}
|
369
388
|
|