fasttext 0.1.2 → 0.1.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +18 -8
- data/ext/fasttext/ext.cpp +66 -35
- data/ext/fasttext/extconf.rb +2 -3
- data/lib/fasttext/classifier.rb +13 -3
- data/lib/fasttext/vectorizer.rb +6 -1
- data/lib/fasttext/version.rb +1 -1
- data/vendor/fastText/README.md +3 -3
- data/vendor/fastText/src/args.cc +179 -6
- data/vendor/fastText/src/args.h +29 -1
- data/vendor/fastText/src/autotune.cc +477 -0
- data/vendor/fastText/src/autotune.h +89 -0
- data/vendor/fastText/src/densematrix.cc +27 -7
- data/vendor/fastText/src/densematrix.h +10 -2
- data/vendor/fastText/src/fasttext.cc +125 -114
- data/vendor/fastText/src/fasttext.h +31 -52
- data/vendor/fastText/src/main.cc +32 -13
- data/vendor/fastText/src/meter.cc +148 -2
- data/vendor/fastText/src/meter.h +24 -2
- data/vendor/fastText/src/model.cc +0 -1
- data/vendor/fastText/src/real.h +0 -1
- data/vendor/fastText/src/utils.cc +25 -0
- data/vendor/fastText/src/utils.h +29 -0
- data/vendor/fastText/src/vector.cc +0 -1
- metadata +5 -4
- data/lib/fasttext/ext.bundle +0 -0
data/vendor/fastText/src/autotune.h (new file)
@@ -0,0 +1,89 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <istream>
+#include <memory>
+#include <random>
+#include <thread>
+#include <vector>
+
+#include "args.h"
+#include "fasttext.h"
+
+namespace fasttext {
+
+class AutotuneStrategy {
+ private:
+  Args bestArgs_;
+  int maxDuration_;
+  std::minstd_rand rng_;
+  int trials_;
+  int bestMinnIndex_;
+  int bestDsubExponent_;
+  int bestNonzeroBucket_;
+  int originalBucket_;
+  std::vector<int> minnChoices_;
+  int getIndex(int val, const std::vector<int>& choices);
+
+ public:
+  explicit AutotuneStrategy(
+      const Args& args,
+      std::minstd_rand::result_type seed);
+  Args ask(double elapsed);
+  void updateBest(const Args& args);
+};
+
+class Autotune {
+ protected:
+  std::shared_ptr<FastText> fastText_;
+  double elapsed_;
+  double bestScore_;
+  int32_t trials_;
+  int32_t sizeConstraintFailed_;
+  std::atomic<bool> continueTraining_;
+  std::unique_ptr<AutotuneStrategy> strategy_;
+  std::thread timer_;
+
+  bool keepTraining(double maxDuration) const;
+  void printInfo(double maxDuration);
+  void timer(
+      const std::chrono::steady_clock::time_point& start,
+      double maxDuration);
+  void abort();
+  void startTimer(const Args& args);
+  double getMetricScore(
+      Meter& meter,
+      const metric_name& metricName,
+      const double metricValue,
+      const std::string& metricLabel) const;
+  void printArgs(const Args& args, const Args& autotuneArgs);
+  void printSkippedArgs(const Args& autotuneArgs);
+  bool quantize(Args& args, const Args& autotuneArgs);
+  int getCutoffForFileSize(bool qout, bool qnorm, int dsub, int64_t fileSize)
+      const;
+
+  class TimeoutError : public std::runtime_error {
+   public:
+    TimeoutError() : std::runtime_error("Autotune timed out.") {}
+  };
+
+ public:
+  Autotune() = delete;
+  explicit Autotune(const std::shared_ptr<FastText>& fastText);
+  Autotune(const Autotune&) = delete;
+  Autotune(Autotune&&) = delete;
+  Autotune& operator=(const Autotune&) = delete;
+  Autotune& operator=(Autotune&&) = delete;
+  ~Autotune() noexcept = default;
+
+  void train(const Args& args);
+};
+
+} // namespace fasttext
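
The header above only declares the tuner. A minimal driver sketch, assuming an Args already configured for supervised training (the autotune-specific fields such as autotuneValidationFile live in args.h, whose changes are not shown in this section):

    #include "autotune.h"
    #include "fasttext.h"

    int main() {
      auto ft = std::make_shared<fasttext::FastText>();
      fasttext::Args args;                        // assumed: set up for supervised training
      args.input = "train.txt";                   // hypothetical path
      args.autotuneValidationFile = "valid.txt";  // assumed field added in args.h
      fasttext::Autotune autotune(ft);
      autotune.train(args);  // runs trials until the time budget, keeping the best args
      return 0;
    }
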
data/vendor/fastText/src/densematrix.cc
@@ -8,11 +8,10 @@
 
 #include "densematrix.h"
 
-#include <exception>
 #include <random>
 #include <stdexcept>
+#include <thread>
 #include <utility>
-
 #include "utils.h"
 #include "vector.h"
 
@@ -25,18 +24,39 @@ DenseMatrix::DenseMatrix(int64_t m, int64_t n) : Matrix(m, n), data_(m * n) {}
 DenseMatrix::DenseMatrix(DenseMatrix&& other) noexcept
     : Matrix(other.m_, other.n_), data_(std::move(other.data_)) {}
 
+DenseMatrix::DenseMatrix(int64_t m, int64_t n, real* dataPtr)
+    : Matrix(m, n), data_(dataPtr, dataPtr + (m * n)) {}
+
 void DenseMatrix::zero() {
   std::fill(data_.begin(), data_.end(), 0.0);
 }
 
-void DenseMatrix::uniform(real a) {
-  std::minstd_rand rng(1);
+void DenseMatrix::uniformThread(real a, int block, int32_t seed) {
+  std::minstd_rand rng(block + seed);
   std::uniform_real_distribution<> uniform(-a, a);
-  for (int64_t i = 0; i < (m_ * n_); i++) {
+  int64_t blockSize = (m_ * n_) / 10;
+  for (int64_t i = blockSize * block;
+       i < (m_ * n_) && i < blockSize * (block + 1);
+       i++) {
     data_[i] = uniform(rng);
   }
 }
 
+void DenseMatrix::uniform(real a, unsigned int thread, int32_t seed) {
+  if (thread > 1) {
+    std::vector<std::thread> threads;
+    for (int i = 0; i < thread; i++) {
+      threads.push_back(std::thread([=]() { uniformThread(a, i, seed); }));
+    }
+    for (int32_t i = 0; i < threads.size(); i++) {
+      threads[i].join();
+    }
+  } else {
+    // webassembly can't instantiate `std::thread`
+    uniformThread(a, 0, seed);
+  }
+}
+
 void DenseMatrix::multiplyRow(const Vector& nums, int64_t ib, int64_t ie) {
   if (ie == -1) {
     ie = m_;
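
With the new signature, initialization is reproducible and parallel: the caller passes the thread count and a seed, and each block is filled by its own seeded std::minstd_rand. A quick sketch with illustrative values:

    fasttext::DenseMatrix m(100000, 100);  // rows = nwords + bucket, cols = dim
    m.uniform(1.0 / 100, 4, 42);           // a = 1/dim, 4 init threads, seed 42
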
@@ -73,7 +93,7 @@ real DenseMatrix::l2NormRow(int64_t i) const {
     norm += at(i, j) * at(i, j);
   }
   if (std::isnan(norm)) {
-    throw std::runtime_error("Encountered NaN.");
+    throw EncounteredNaNError();
   }
   return std::sqrt(norm);
 }
@@ -94,7 +114,7 @@ real DenseMatrix::dotRow(const Vector& vec, int64_t i) const {
     d += at(i, j) * vec[j];
   }
   if (std::isnan(d)) {
-    throw std::runtime_error("Encountered NaN.");
+    throw EncounteredNaNError();
   }
   return d;
 }
data/vendor/fastText/src/densematrix.h
@@ -8,12 +8,13 @@
 
 #pragma once
 
+#include <assert.h>
 #include <cstdint>
 #include <istream>
 #include <ostream>
+#include <stdexcept>
 #include <vector>
 
-#include <assert.h>
 #include "matrix.h"
 #include "real.h"
 
@@ -24,10 +25,12 @@ class Vector;
 class DenseMatrix : public Matrix {
  protected:
   std::vector<real> data_;
+  void uniformThread(real, int, int32_t);
 
  public:
   DenseMatrix();
   explicit DenseMatrix(int64_t, int64_t);
+  explicit DenseMatrix(int64_t m, int64_t n, real* dataPtr);
   DenseMatrix(const DenseMatrix&) = default;
   DenseMatrix(DenseMatrix&&) noexcept;
   DenseMatrix& operator=(const DenseMatrix&) = delete;
@@ -56,7 +59,7 @@ class DenseMatrix : public Matrix {
     return n_;
   }
   void zero();
-  void uniform(real);
+  void uniform(real, unsigned int, int32_t);
 
   void multiplyRow(const Vector& nums, int64_t ib = 0, int64_t ie = -1);
   void divideRow(const Vector& denoms, int64_t ib = 0, int64_t ie = -1);
@@ -71,5 +74,10 @@ class DenseMatrix : public Matrix {
   void save(std::ostream&) const override;
   void load(std::istream&) override;
   void dump(std::ostream&) const override;
+
+  class EncounteredNaNError : public std::runtime_error {
+   public:
+    EncounteredNaNError() : std::runtime_error("Encountered NaN.") {}
+  };
 };
 } // namespace fasttext
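
Because trainThread (later in this diff) stores this exception and startThreads rethrows it on the calling thread, divergence can now be caught by type; a minimal sketch, assuming ft and args are set up as elsewhere in this diff:

    try {
      ft.train(args);
    } catch (const fasttext::DenseMatrix::EncounteredNaNError& e) {
      std::cerr << "Training diverged: " << e.what() << '\n';  // "Encountered NaN."
    }
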
data/vendor/fastText/src/fasttext.cc
@@ -47,7 +47,8 @@ std::shared_ptr<Loss> FastText::createLoss(std::shared_ptr<Matrix>& output) {
   }
 }
 
-FastText::FastText() : quant_(false), wordVectors_(nullptr) {}
+FastText::FastText()
+    : quant_(false), wordVectors_(nullptr), trainException_(nullptr) {}
 
 void FastText::addInputVector(Vector& vec, int32_t ind) const {
   vec.addRow(*input_, ind);
@@ -69,6 +70,19 @@ std::shared_ptr<const DenseMatrix> FastText::getInputMatrix() const {
   return std::dynamic_pointer_cast<DenseMatrix>(input_);
 }
 
+void FastText::setMatrices(
+    const std::shared_ptr<DenseMatrix>& inputMatrix,
+    const std::shared_ptr<DenseMatrix>& outputMatrix) {
+  assert(input_->size(1) == output_->size(1));
+
+  input_ = std::dynamic_pointer_cast<Matrix>(inputMatrix);
+  output_ = std::dynamic_pointer_cast<Matrix>(outputMatrix);
+  wordVectors_.reset();
+  args_->dim = input_->size(1);
+
+  buildModel();
+}
+
 std::shared_ptr<const DenseMatrix> FastText::getOutputMatrix() const {
   if (quant_ && args_->qout) {
     throw std::runtime_error("Can't export quantized matrix");
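
setMatrices pairs naturally with the new DenseMatrix(int64_t, int64_t, real*) constructor: externally built buffers can be wrapped and swapped into a model. A sketch with hypothetical names (nrows, dim, nlabels, inputData, outputData are placeholders):

    auto in  = std::make_shared<fasttext::DenseMatrix>(nrows, dim, inputData);
    auto out = std::make_shared<fasttext::DenseMatrix>(nlabels, dim, outputData);
    ft.setMatrices(in, out);  // resets cached word vectors and rebuilds the model
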
@@ -86,6 +100,14 @@ int32_t FastText::getSubwordId(const std::string& subword) const {
   return dict_->nwords() + h;
 }
 
+int32_t FastText::getLabelId(const std::string& label) const {
+  int32_t labelId = dict_->getId(label);
+  if (labelId != -1) {
+    labelId -= dict_->nwords();
+  }
+  return labelId;
+}
+
 void FastText::getWordVector(Vector& vec, const std::string& word) const {
   const std::vector<int32_t>& ngrams = dict_->getSubwords(word);
   vec.zero();
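
getLabelId converts a dictionary id into an output-matrix row: the dictionary stores all words before labels, so a known label's id is offset by nwords(). For example:

    int32_t row = ft.getLabelId("__label__positive");  // output row, or -1 if unknown
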
@@ -97,10 +119,6 @@ void FastText::getWordVector(Vector& vec, const std::string& word) const {
   }
 }
 
-void FastText::getVector(Vector& vec, const std::string& word) const {
-  getWordVector(vec, word);
-}
-
 void FastText::getSubwordVector(Vector& vec, const std::string& subword) const {
   vec.zero();
   int32_t h = dict_->hash(subword) % args_->bucket;
@@ -109,6 +127,9 @@ void FastText::getSubwordVector(Vector& vec, const std::string& subword) const {
 }
 
 void FastText::saveVectors(const std::string& filename) {
+  if (!input_ || !output_) {
+    throw std::runtime_error("Model never trained");
+  }
   std::ofstream ofs(filename);
   if (!ofs.is_open()) {
     throw std::invalid_argument(
@@ -124,10 +145,6 @@ void FastText::saveVectors(const std::string& filename) {
   ofs.close();
 }
 
-void FastText::saveVectors() {
-  saveVectors(args_->output + ".vec");
-}
-
 void FastText::saveOutput(const std::string& filename) {
   std::ofstream ofs(filename);
   if (!ofs.is_open()) {
@@ -152,10 +169,6 @@ void FastText::saveOutput(const std::string& filename) {
   ofs.close();
 }
 
-void FastText::saveOutput() {
-  saveOutput(args_->output + ".output");
-}
-
 bool FastText::checkModel(std::istream& in) {
   int32_t magic;
   in.read((char*)&(magic), sizeof(int32_t));
@@ -176,21 +189,14 @@ void FastText::signModel(std::ostream& out) {
   out.write((char*)&(version), sizeof(int32_t));
 }
 
-void FastText::saveModel() {
-  std::string fn(args_->output);
-  if (quant_) {
-    fn += ".ftz";
-  } else {
-    fn += ".bin";
-  }
-  saveModel(fn);
-}
-
 void FastText::saveModel(const std::string& filename) {
   std::ofstream ofs(filename, std::ofstream::binary);
   if (!ofs.is_open()) {
     throw std::invalid_argument(filename + " cannot be opened for saving!");
   }
+  if (!input_ || !output_) {
+    throw std::runtime_error("Model never trained");
+  }
   signModel(ofs);
   args_->save(ofs);
   dict_->save(ofs);
@@ -224,6 +230,12 @@ std::vector<int64_t> FastText::getTargetCounts() const {
   }
 }
 
+void FastText::buildModel() {
+  auto loss = createLoss(output_);
+  bool normalizeGradient = (args_->model == model_name::sup);
+  model_ = std::make_shared<Model>(input_, output_, loss, normalizeGradient);
+}
+
 void FastText::loadModel(std::istream& in) {
   args_ = std::make_shared<Args>();
   input_ = std::make_shared<DenseMatrix>();
@@ -256,37 +268,37 @@ void FastText::loadModel(std::istream& in) {
   }
   output_->load(in);
 
-  auto loss = createLoss(output_);
-  bool normalizeGradient = (args_->model == model_name::sup);
-  model_ = std::make_shared<Model>(input_, output_, loss, normalizeGradient);
+  buildModel();
 }
 
-void FastText::printInfo(real progress, real loss, std::ostream& log_stream) {
-  std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
-  double t =
-      std::chrono::duration_cast<std::chrono::duration<double>>(end - start_)
-          .count();
+std::tuple<int64_t, double, double> FastText::progressInfo(real progress) {
+  double t = utils::getDuration(start_, std::chrono::steady_clock::now());
   double lr = args_->lr * (1.0 - progress);
   double wst = 0;
 
   int64_t eta = 2592000; // Default to one month in seconds (720 * 3600)
 
   if (progress > 0 && t >= 0) {
-    progress = progress * 100;
-    eta = t * (100 - progress) / progress;
+    eta = t * (1 - progress) / progress;
     wst = double(tokenCount_) / t / args_->thread;
   }
-  int32_t etah = eta / 3600;
-  int32_t etam = (eta % 3600) / 60;
+
+  return std::tuple<double, double, int64_t>(wst, lr, eta);
+}
+
+void FastText::printInfo(real progress, real loss, std::ostream& log_stream) {
+  double wst;
+  double lr;
+  int64_t eta;
+  std::tie<double, double, int64_t>(wst, lr, eta) = progressInfo(progress);
 
   log_stream << std::fixed;
   log_stream << "Progress: ";
-  log_stream << std::setprecision(1) << std::setw(5) << progress << "%";
+  log_stream << std::setprecision(1) << std::setw(5) << (progress * 100) << "%";
   log_stream << " words/sec/thread: " << std::setw(7) << int64_t(wst);
   log_stream << " lr: " << std::setw(9) << std::setprecision(6) << lr;
-  log_stream << " loss: " << std::setw(9) << std::setprecision(6) << loss;
-  log_stream << " ETA: " << std::setw(3) << etah;
-  log_stream << "h" << std::setw(2) << etam << "m";
+  log_stream << " avg.loss: " << std::setw(9) << std::setprecision(6) << loss;
+  log_stream << " ETA: " << utils::ClockPrint(eta);
   log_stream << std::flush;
 }
 
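
The ETA arithmetic changes because progress now stays in [0, 1] for the whole function instead of being rescaled to a percentage in place; the two formulas agree, e.g.:

    // t = 60 s elapsed at a quarter done:
    // old: progress = 0.25 * 100 = 25;  eta = 60 * (100 - 25) / 25 = 180 s
    // new: eta = 60 * (1 - 0.25) / 0.25 = 180 s
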
@@ -299,13 +311,16 @@ std::vector<int32_t> FastText::selectEmbeddings(int32_t cutoff) const {
   std::iota(idx.begin(), idx.end(), 0);
   auto eosid = dict_->getId(Dictionary::EOS);
   std::sort(idx.begin(), idx.end(), [&norms, eosid](size_t i1, size_t i2) {
+    if (i1 == eosid && i2 == eosid) { // satisfy strict weak ordering
+      return false;
+    }
     return eosid == i1 || (eosid != i2 && norms[i1] > norms[i2]);
   });
   idx.erase(idx.begin() + cutoff, idx.end());
   return idx;
 }
 
-void FastText::quantize(const Args& qargs) {
+void FastText::quantize(const Args& qargs, const TrainCallback& callback) {
   if (args_->model != model_name::sup) {
     throw std::invalid_argument(
         "For now we only support quantization of supervised models");
@@ -337,10 +352,9 @@ void FastText::quantize(const Args& qargs) {
       args_->verbose = qargs.verbose;
       auto loss = createLoss(output_);
       model_ = std::make_shared<Model>(input, output, loss, normalizeGradient);
-      startThreads();
+      startThreads(callback);
     }
   }
-
   input_ = std::make_shared<QuantMatrix>(
       std::move(*(input.get())), qargs.dsub, qargs.qnorm);
 
@@ -348,7 +362,6 @@ void FastText::quantize(const Args& qargs) {
     output_ = std::make_shared<QuantMatrix>(
         std::move(*(output.get())), 2, qargs.qnorm);
   }
-
   quant_ = true;
   auto loss = createLoss(output_);
   model_ = std::make_shared<Model>(input_, output_, loss, normalizeGradient);
@@ -408,7 +421,7 @@ void FastText::skipgram(
 
 std::tuple<int64_t, double, double>
 FastText::test(std::istream& in, int32_t k, real threshold) {
-  Meter meter;
+  Meter meter(false);
   test(in, k, threshold, meter);
 
   return std::tuple<int64_t, double, double>(
@@ -420,6 +433,9 @@ void FastText::test(std::istream& in, int32_t k, real threshold, Meter& meter)
   std::vector<int32_t> line;
   std::vector<int32_t> labels;
   Predictions predictions;
+  Model::State state(args_->dim, dict_->nlabels(), 0);
+  in.clear();
+  in.seekg(0, std::ios_base::beg);
 
   while (in.peek() != EOF) {
     line.clear();
@@ -521,16 +537,6 @@ std::vector<std::pair<std::string, Vector>> FastText::getNgramVectors(
   return result;
 }
 
-// deprecated. use getNgramVectors instead
-void FastText::ngramVectors(std::string word) {
-  std::vector<std::pair<std::string, Vector>> ngramVectors =
-      getNgramVectors(word);
-
-  for (const auto& ngramVector : ngramVectors) {
-    std::cout << ngramVector.first << " " << ngramVector.second << std::endl;
-  }
-}
-
 void FastText::precomputeWordVectors(DenseMatrix& wordVectors) {
   Vector vec(args_->dim);
   wordVectors.zero();
@@ -598,17 +604,6 @@ std::vector<std::pair<real, std::string>> FastText::getNN(
   return heap;
 }
 
-// depracted. use getNN instead
-void FastText::findNN(
-    const DenseMatrix& wordVectors,
-    const Vector& query,
-    int32_t k,
-    const std::set<std::string>& banSet,
-    std::vector<std::pair<real, std::string>>& results) {
-  results.clear();
-  results = getNN(wordVectors, query, k, banSet);
-}
-
 std::vector<std::pair<real, std::string>> FastText::getAnalogies(
     int32_t k,
     const std::string& wordA,
@@ -630,52 +625,52 @@ std::vector<std::pair<real, std::string>> FastText::getAnalogies(
   return getNN(*wordVectors_, query, k, {wordA, wordB, wordC});
 }
 
-// deprecated. use getAnalogies instead
-void FastText::analogies(int32_t k) {
-  std::string prompt("Query triplet (A - B + C)? ");
-  std::string wordA, wordB, wordC;
-  std::cout << prompt;
-  while (true) {
-    std::cin >> wordA;
-    std::cin >> wordB;
-    std::cin >> wordC;
-    auto results = getAnalogies(k, wordA, wordB, wordC);
-
-    for (auto& pair : results) {
-      std::cout << pair.second << " " << pair.first << std::endl;
-    }
-    std::cout << prompt;
-  }
+bool FastText::keepTraining(const int64_t ntokens) const {
+  return tokenCount_ < args_->epoch * ntokens && !trainException_;
 }
 
-void FastText::trainThread(int32_t threadId) {
+void FastText::trainThread(int32_t threadId, const TrainCallback& callback) {
  std::ifstream ifs(args_->input);
  utils::seek(ifs, threadId * utils::size(ifs) / args_->thread);
 
-  Model::State state(args_->dim, output_->size(0), threadId);
+  Model::State state(args_->dim, output_->size(0), threadId + args_->seed);
 
   const int64_t ntokens = dict_->ntokens();
   int64_t localTokenCount = 0;
   std::vector<int32_t> line, labels;
-  while (tokenCount_ < args_->epoch * ntokens) {
-    real progress = real(tokenCount_) / (args_->epoch * ntokens);
-    real lr = args_->lr * (1.0 - progress);
-    if (args_->model == model_name::sup) {
-      localTokenCount += dict_->getLine(ifs, line, labels);
-      supervised(state, lr, line, labels);
-    } else if (args_->model == model_name::cbow) {
-      localTokenCount += dict_->getLine(ifs, line, state.rng);
-      cbow(state, lr, line);
-    } else if (args_->model == model_name::sg) {
-      localTokenCount += dict_->getLine(ifs, line, state.rng);
-      skipgram(state, lr, line);
-    }
-    if (localTokenCount > args_->lrUpdateRate) {
-      tokenCount_ += localTokenCount;
-      localTokenCount = 0;
-      if (threadId == 0 && args_->verbose > 1)
-        loss_ = state.getLoss();
+  uint64_t callbackCounter = 0;
+  try {
+    while (keepTraining(ntokens)) {
+      real progress = real(tokenCount_) / (args_->epoch * ntokens);
+      if (callback && ((callbackCounter++ % 64) == 0)) {
+        double wst;
+        double lr;
+        int64_t eta;
+        std::tie<double, double, int64_t>(wst, lr, eta) =
+            progressInfo(progress);
+        callback(progress, loss_, wst, lr, eta);
+      }
+      real lr = args_->lr * (1.0 - progress);
+      if (args_->model == model_name::sup) {
+        localTokenCount += dict_->getLine(ifs, line, labels);
+        supervised(state, lr, line, labels);
+      } else if (args_->model == model_name::cbow) {
+        localTokenCount += dict_->getLine(ifs, line, state.rng);
+        cbow(state, lr, line);
+      } else if (args_->model == model_name::sg) {
+        localTokenCount += dict_->getLine(ifs, line, state.rng);
+        skipgram(state, lr, line);
+      }
+      if (localTokenCount > args_->lrUpdateRate) {
+        tokenCount_ += localTokenCount;
+        localTokenCount = 0;
+        if (threadId == 0 && args_->verbose > 1) {
+          loss_ = state.getLoss();
+        }
+      }
     }
+  } catch (DenseMatrix::EncounteredNaNError&) {
+    trainException_ = std::current_exception();
   }
   if (threadId == 0)
     loss_ = state.getLoss();
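
The TrainCallback type itself is declared in fasttext.h (not shown here), but the call site above implies the shape (progress, loss, words/sec/thread, lr, eta). A sketch of wiring one through the new train() overload:

    auto onProgress = [](float progress, float loss, double wst, double lr,
                         int64_t eta) {
      std::cerr << int(progress * 100) << "% avg.loss=" << loss
                << " eta=" << eta << "s\n";
    };
    ft.train(args, onProgress);  // invoked about once per 64 processed lines per worker
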
@@ -713,7 +708,7 @@ std::shared_ptr<Matrix> FastText::getInputMatrixFromFile(
   dict_->init();
   std::shared_ptr<DenseMatrix> input = std::make_shared<DenseMatrix>(
       dict_->nwords() + args_->bucket, args_->dim);
-  input->uniform(1.0 / args_->dim);
+  input->uniform(1.0 / args_->dim, args_->thread, args_->seed);
 
   for (size_t i = 0; i < n; i++) {
     int32_t idx = dict_->getId(words[i]);
@@ -727,14 +722,10 @@ std::shared_ptr<Matrix> FastText::getInputMatrixFromFile(
   return input;
 }
 
-void FastText::loadVectors(const std::string& filename) {
-  input_ = getInputMatrixFromFile(filename);
-}
-
 std::shared_ptr<Matrix> FastText::createRandomMatrix() const {
   std::shared_ptr<DenseMatrix> input = std::make_shared<DenseMatrix>(
       dict_->nwords() + args_->bucket, args_->dim);
-  input->uniform(1.0 / args_->dim);
+  input->uniform(1.0 / args_->dim, args_->thread, args_->seed);
 
   return input;
 }
@@ -749,7 +740,7 @@ std::shared_ptr<Matrix> FastText::createTrainOutputMatrix() const {
   return output;
 }
 
-void FastText::train(const Args& args) {
+void FastText::train(const Args& args, const TrainCallback& callback) {
   args_ = std::make_shared<Args>(args);
   dict_ = std::make_shared<Dictionary>(args_);
   if (args_->input == "-") {
@@ -770,23 +761,38 @@ void FastText::train(const Args& args) {
     input_ = createRandomMatrix();
   }
   output_ = createTrainOutputMatrix();
+  quant_ = false;
   auto loss = createLoss(output_);
   bool normalizeGradient = (args_->model == model_name::sup);
   model_ = std::make_shared<Model>(input_, output_, loss, normalizeGradient);
-  startThreads();
+  startThreads(callback);
+}
+
+void FastText::abort() {
+  try {
+    throw AbortError();
+  } catch (AbortError&) {
+    trainException_ = std::current_exception();
+  }
 }
 
-void FastText::startThreads() {
+void FastText::startThreads(const TrainCallback& callback) {
   start_ = std::chrono::steady_clock::now();
   tokenCount_ = 0;
   loss_ = -1;
+  trainException_ = nullptr;
   std::vector<std::thread> threads;
-  for (int32_t i = 0; i < args_->thread; i++) {
-    threads.push_back(std::thread([=]() { trainThread(i); }));
+  if (args_->thread > 1) {
+    for (int32_t i = 0; i < args_->thread; i++) {
+      threads.push_back(std::thread([=]() { trainThread(i, callback); }));
+    }
+  } else {
+    // webassembly can't instantiate `std::thread`
+    trainThread(0, callback);
   }
   const int64_t ntokens = dict_->ntokens();
   // Same condition as trainThread
-  while (tokenCount_ < args_->epoch * ntokens) {
+  while (keepTraining(ntokens)) {
     std::this_thread::sleep_for(std::chrono::milliseconds(100));
     if (loss_ >= 0 && args_->verbose > 1) {
       real progress = real(tokenCount_) / (args_->epoch * ntokens);
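
abort() works through the same exception channel as the NaN handling: the stored AbortError makes keepTraining() return false in every worker, and startThreads() rethrows it once the threads have joined. A sketch of cancelling from another thread, assuming AbortError (declared in fasttext.h, not shown here) derives from std::runtime_error:

    std::thread trainer([&]() {
      try {
        ft.train(args);
      } catch (const std::runtime_error&) {
        // AbortError lands here after the workers stop
      }
    });
    ft.abort();      // request cancellation once training is underway
    trainer.join();
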
@@ -794,9 +800,14 @@ void FastText::startThreads() {
       printInfo(progress, loss_, std::cerr);
     }
   }
-  for (int32_t i = 0; i < args_->thread; i++) {
+  for (int32_t i = 0; i < threads.size(); i++) {
     threads[i].join();
   }
+  if (trainException_) {
+    std::exception_ptr exception = trainException_;
+    trainException_ = nullptr;
+    std::rethrow_exception(exception);
+  }
   if (args_->verbose > 0) {
     std::cerr << "\r";
     printInfo(1.0, loss_, std::cerr);