fasttext 0.1.2 → 0.1.3
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +18 -8
- data/ext/fasttext/ext.cpp +66 -35
- data/ext/fasttext/extconf.rb +2 -3
- data/lib/fasttext/classifier.rb +13 -3
- data/lib/fasttext/vectorizer.rb +6 -1
- data/lib/fasttext/version.rb +1 -1
- data/vendor/fastText/README.md +3 -3
- data/vendor/fastText/src/args.cc +179 -6
- data/vendor/fastText/src/args.h +29 -1
- data/vendor/fastText/src/autotune.cc +477 -0
- data/vendor/fastText/src/autotune.h +89 -0
- data/vendor/fastText/src/densematrix.cc +27 -7
- data/vendor/fastText/src/densematrix.h +10 -2
- data/vendor/fastText/src/fasttext.cc +125 -114
- data/vendor/fastText/src/fasttext.h +31 -52
- data/vendor/fastText/src/main.cc +32 -13
- data/vendor/fastText/src/meter.cc +148 -2
- data/vendor/fastText/src/meter.h +24 -2
- data/vendor/fastText/src/model.cc +0 -1
- data/vendor/fastText/src/real.h +0 -1
- data/vendor/fastText/src/utils.cc +25 -0
- data/vendor/fastText/src/utils.h +29 -0
- data/vendor/fastText/src/vector.cc +0 -1
- metadata +5 -4
- data/lib/fasttext/ext.bundle +0 -0
@@ -0,0 +1,89 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) 2016-present, Facebook, Inc.
|
3
|
+
* All rights reserved.
|
4
|
+
*
|
5
|
+
* This source code is licensed under the MIT license found in the
|
6
|
+
* LICENSE file in the root directory of this source tree.
|
7
|
+
*/
|
8
|
+
|
9
|
+
#pragma once
|
10
|
+
|
11
|
+
#include <atomic>
#include <chrono>
#include <cstdint>
#include <istream>
#include <memory>
#include <random>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>

#include "args.h"
#include "fasttext.h"
|
19
|
+
|
20
|
+
namespace fasttext {
|
21
|
+
|
22
|
+
class AutotuneStrategy {
|
23
|
+
private:
|
24
|
+
Args bestArgs_;
|
25
|
+
int maxDuration_;
|
26
|
+
std::minstd_rand rng_;
|
27
|
+
int trials_;
|
28
|
+
int bestMinnIndex_;
|
29
|
+
int bestDsubExponent_;
|
30
|
+
int bestNonzeroBucket_;
|
31
|
+
int originalBucket_;
|
32
|
+
std::vector<int> minnChoices_;
|
33
|
+
int getIndex(int val, const std::vector<int>& choices);
|
34
|
+
|
35
|
+
public:
|
36
|
+
explicit AutotuneStrategy(
|
37
|
+
const Args& args,
|
38
|
+
std::minstd_rand::result_type seed);
|
39
|
+
Args ask(double elapsed);
|
40
|
+
void updateBest(const Args& args);
|
41
|
+
};
|
42
|
+
|
43
|
+
class Autotune {
|
44
|
+
protected:
|
45
|
+
std::shared_ptr<FastText> fastText_;
|
46
|
+
double elapsed_;
|
47
|
+
double bestScore_;
|
48
|
+
int32_t trials_;
|
49
|
+
int32_t sizeConstraintFailed_;
|
50
|
+
std::atomic<bool> continueTraining_;
|
51
|
+
std::unique_ptr<AutotuneStrategy> strategy_;
|
52
|
+
std::thread timer_;
|
53
|
+
|
54
|
+
bool keepTraining(double maxDuration) const;
|
55
|
+
void printInfo(double maxDuration);
|
56
|
+
void timer(
|
57
|
+
const std::chrono::steady_clock::time_point& start,
|
58
|
+
double maxDuration);
|
59
|
+
void abort();
|
60
|
+
void startTimer(const Args& args);
|
61
|
+
double getMetricScore(
|
62
|
+
Meter& meter,
|
63
|
+
const metric_name& metricName,
|
64
|
+
const double metricValue,
|
65
|
+
const std::string& metricLabel) const;
|
66
|
+
void printArgs(const Args& args, const Args& autotuneArgs);
|
67
|
+
void printSkippedArgs(const Args& autotuneArgs);
|
68
|
+
bool quantize(Args& args, const Args& autotuneArgs);
|
69
|
+
int getCutoffForFileSize(bool qout, bool qnorm, int dsub, int64_t fileSize)
|
70
|
+
const;
|
71
|
+
|
72
|
+
class TimeoutError : public std::runtime_error {
|
73
|
+
public:
|
74
|
+
TimeoutError() : std::runtime_error("Autotune timed out.") {}
|
75
|
+
};
|
76
|
+
|
77
|
+
public:
|
78
|
+
Autotune() = delete;
|
79
|
+
explicit Autotune(const std::shared_ptr<FastText>& fastText);
|
80
|
+
Autotune(const Autotune&) = delete;
|
81
|
+
Autotune(Autotune&&) = delete;
|
82
|
+
Autotune& operator=(const Autotune&) = delete;
|
83
|
+
Autotune& operator=(Autotune&&) = delete;
|
84
|
+
~Autotune() noexcept = default;
|
85
|
+
|
86
|
+
void train(const Args& args);
|
87
|
+
};
|
88
|
+
|
89
|
+
} // namespace fasttext
|
@@ -8,11 +8,10 @@
|
|
8
8
|
|
9
9
|
#include "densematrix.h"
|
10
10
|
|
11
|
-
#include <exception>
|
12
11
|
#include <random>
|
13
12
|
#include <stdexcept>
|
13
|
+
#include <thread>
|
14
14
|
#include <utility>
|
15
|
-
|
16
15
|
#include "utils.h"
|
17
16
|
#include "vector.h"
|
18
17
|
|
@@ -25,18 +24,39 @@ DenseMatrix::DenseMatrix(int64_t m, int64_t n) : Matrix(m, n), data_(m * n) {}
|
|
25
24
|
DenseMatrix::DenseMatrix(DenseMatrix&& other) noexcept
|
26
25
|
: Matrix(other.m_, other.n_), data_(std::move(other.data_)) {}
|
27
26
|
|
27
|
+
// Builds an m x n matrix whose storage is a copy of the m * n values starting
// at dataPtr. The values are copied, so the caller keeps ownership of dataPtr,
// which must point to at least m * n readable elements.
DenseMatrix::DenseMatrix(int64_t m, int64_t n, real* dataPtr) : Matrix(m, n) {
  const real* const first = dataPtr;
  const real* const last = first + m * n;
  data_.assign(first, last);
}
|
29
|
+
|
28
30
|
void DenseMatrix::zero() {
|
29
31
|
std::fill(data_.begin(), data_.end(), 0.0);
|
30
32
|
}
|
31
33
|
|
32
|
-
void DenseMatrix::
|
33
|
-
std::minstd_rand rng(
|
34
|
+
void DenseMatrix::uniformThread(real a, int block, int32_t seed) {
|
35
|
+
std::minstd_rand rng(block + seed);
|
34
36
|
std::uniform_real_distribution<> uniform(-a, a);
|
35
|
-
|
37
|
+
int64_t blockSize = (m_ * n_) / 10;
|
38
|
+
for (int64_t i = blockSize * block;
|
39
|
+
i < (m_ * n_) && i < blockSize * (block + 1);
|
40
|
+
i++) {
|
36
41
|
data_[i] = uniform(rng);
|
37
42
|
}
|
38
43
|
}
|
39
44
|
|
45
|
+
void DenseMatrix::uniform(real a, unsigned int thread, int32_t seed) {
|
46
|
+
if (thread > 1) {
|
47
|
+
std::vector<std::thread> threads;
|
48
|
+
for (int i = 0; i < thread; i++) {
|
49
|
+
threads.push_back(std::thread([=]() { uniformThread(a, i, seed); }));
|
50
|
+
}
|
51
|
+
for (int32_t i = 0; i < threads.size(); i++) {
|
52
|
+
threads[i].join();
|
53
|
+
}
|
54
|
+
} else {
|
55
|
+
// webassembly can't instantiate `std::thread`
|
56
|
+
uniformThread(a, 0, seed);
|
57
|
+
}
|
58
|
+
}
|
59
|
+
|
40
60
|
void DenseMatrix::multiplyRow(const Vector& nums, int64_t ib, int64_t ie) {
|
41
61
|
if (ie == -1) {
|
42
62
|
ie = m_;
|
@@ -73,7 +93,7 @@ real DenseMatrix::l2NormRow(int64_t i) const {
|
|
73
93
|
norm += at(i, j) * at(i, j);
|
74
94
|
}
|
75
95
|
if (std::isnan(norm)) {
|
76
|
-
throw
|
96
|
+
throw EncounteredNaNError();
|
77
97
|
}
|
78
98
|
return std::sqrt(norm);
|
79
99
|
}
|
@@ -94,7 +114,7 @@ real DenseMatrix::dotRow(const Vector& vec, int64_t i) const {
|
|
94
114
|
d += at(i, j) * vec[j];
|
95
115
|
}
|
96
116
|
if (std::isnan(d)) {
|
97
|
-
throw
|
117
|
+
throw EncounteredNaNError();
|
98
118
|
}
|
99
119
|
return d;
|
100
120
|
}
|
@@ -8,12 +8,13 @@
|
|
8
8
|
|
9
9
|
#pragma once
|
10
10
|
|
11
|
+
#include <assert.h>
|
11
12
|
#include <cstdint>
|
12
13
|
#include <istream>
|
13
14
|
#include <ostream>
|
15
|
+
#include <stdexcept>
|
14
16
|
#include <vector>
|
15
17
|
|
16
|
-
#include <assert.h>
|
17
18
|
#include "matrix.h"
|
18
19
|
#include "real.h"
|
19
20
|
|
@@ -24,10 +25,12 @@ class Vector;
|
|
24
25
|
class DenseMatrix : public Matrix {
|
25
26
|
protected:
|
26
27
|
std::vector<real> data_;
|
28
|
+
void uniformThread(real, int, int32_t);
|
27
29
|
|
28
30
|
public:
|
29
31
|
DenseMatrix();
|
30
32
|
explicit DenseMatrix(int64_t, int64_t);
|
33
|
+
explicit DenseMatrix(int64_t m, int64_t n, real* dataPtr);
|
31
34
|
DenseMatrix(const DenseMatrix&) = default;
|
32
35
|
DenseMatrix(DenseMatrix&&) noexcept;
|
33
36
|
DenseMatrix& operator=(const DenseMatrix&) = delete;
|
@@ -56,7 +59,7 @@ class DenseMatrix : public Matrix {
|
|
56
59
|
return n_;
|
57
60
|
}
|
58
61
|
void zero();
|
59
|
-
void uniform(real);
|
62
|
+
void uniform(real, unsigned int, int32_t);
|
60
63
|
|
61
64
|
void multiplyRow(const Vector& nums, int64_t ib = 0, int64_t ie = -1);
|
62
65
|
void divideRow(const Vector& denoms, int64_t ib = 0, int64_t ie = -1);
|
@@ -71,5 +74,10 @@ class DenseMatrix : public Matrix {
|
|
71
74
|
void save(std::ostream&) const override;
|
72
75
|
void load(std::istream&) override;
|
73
76
|
void dump(std::ostream&) const override;
|
77
|
+
|
78
|
+
// Exception raised by the DenseMatrix row operations when a computed value is
// NaN; what() returns "Encountered NaN.".
struct EncounteredNaNError : std::runtime_error {
  EncounteredNaNError() : runtime_error("Encountered NaN.") {}
};
|
74
82
|
};
|
75
83
|
} // namespace fasttext
|
@@ -47,7 +47,8 @@ std::shared_ptr<Loss> FastText::createLoss(std::shared_ptr<Matrix>& output) {
|
|
47
47
|
}
|
48
48
|
}
|
49
49
|
|
50
|
-
FastText::FastText()
|
50
|
+
FastText::FastText()
|
51
|
+
: quant_(false), wordVectors_(nullptr), trainException_(nullptr) {}
|
51
52
|
|
52
53
|
void FastText::addInputVector(Vector& vec, int32_t ind) const {
|
53
54
|
vec.addRow(*input_, ind);
|
@@ -69,6 +70,19 @@ std::shared_ptr<const DenseMatrix> FastText::getInputMatrix() const {
|
|
69
70
|
return std::dynamic_pointer_cast<DenseMatrix>(input_);
|
70
71
|
}
|
71
72
|
|
73
|
+
void FastText::setMatrices(
|
74
|
+
const std::shared_ptr<DenseMatrix>& inputMatrix,
|
75
|
+
const std::shared_ptr<DenseMatrix>& outputMatrix) {
|
76
|
+
assert(input_->size(1) == output_->size(1));
|
77
|
+
|
78
|
+
input_ = std::dynamic_pointer_cast<Matrix>(inputMatrix);
|
79
|
+
output_ = std::dynamic_pointer_cast<Matrix>(outputMatrix);
|
80
|
+
wordVectors_.reset();
|
81
|
+
args_->dim = input_->size(1);
|
82
|
+
|
83
|
+
buildModel();
|
84
|
+
}
|
85
|
+
|
72
86
|
std::shared_ptr<const DenseMatrix> FastText::getOutputMatrix() const {
|
73
87
|
if (quant_ && args_->qout) {
|
74
88
|
throw std::runtime_error("Can't export quantized matrix");
|
@@ -86,6 +100,14 @@ int32_t FastText::getSubwordId(const std::string& subword) const {
|
|
86
100
|
return dict_->nwords() + h;
|
87
101
|
}
|
88
102
|
|
103
|
+
int32_t FastText::getLabelId(const std::string& label) const {
|
104
|
+
int32_t labelId = dict_->getId(label);
|
105
|
+
if (labelId != -1) {
|
106
|
+
labelId -= dict_->nwords();
|
107
|
+
}
|
108
|
+
return labelId;
|
109
|
+
}
|
110
|
+
|
89
111
|
void FastText::getWordVector(Vector& vec, const std::string& word) const {
|
90
112
|
const std::vector<int32_t>& ngrams = dict_->getSubwords(word);
|
91
113
|
vec.zero();
|
@@ -97,10 +119,6 @@ void FastText::getWordVector(Vector& vec, const std::string& word) const {
|
|
97
119
|
}
|
98
120
|
}
|
99
121
|
|
100
|
-
void FastText::getVector(Vector& vec, const std::string& word) const {
|
101
|
-
getWordVector(vec, word);
|
102
|
-
}
|
103
|
-
|
104
122
|
void FastText::getSubwordVector(Vector& vec, const std::string& subword) const {
|
105
123
|
vec.zero();
|
106
124
|
int32_t h = dict_->hash(subword) % args_->bucket;
|
@@ -109,6 +127,9 @@ void FastText::getSubwordVector(Vector& vec, const std::string& subword) const {
|
|
109
127
|
}
|
110
128
|
|
111
129
|
void FastText::saveVectors(const std::string& filename) {
|
130
|
+
if (!input_ || !output_) {
|
131
|
+
throw std::runtime_error("Model never trained");
|
132
|
+
}
|
112
133
|
std::ofstream ofs(filename);
|
113
134
|
if (!ofs.is_open()) {
|
114
135
|
throw std::invalid_argument(
|
@@ -124,10 +145,6 @@ void FastText::saveVectors(const std::string& filename) {
|
|
124
145
|
ofs.close();
|
125
146
|
}
|
126
147
|
|
127
|
-
void FastText::saveVectors() {
|
128
|
-
saveVectors(args_->output + ".vec");
|
129
|
-
}
|
130
|
-
|
131
148
|
void FastText::saveOutput(const std::string& filename) {
|
132
149
|
std::ofstream ofs(filename);
|
133
150
|
if (!ofs.is_open()) {
|
@@ -152,10 +169,6 @@ void FastText::saveOutput(const std::string& filename) {
|
|
152
169
|
ofs.close();
|
153
170
|
}
|
154
171
|
|
155
|
-
void FastText::saveOutput() {
|
156
|
-
saveOutput(args_->output + ".output");
|
157
|
-
}
|
158
|
-
|
159
172
|
bool FastText::checkModel(std::istream& in) {
|
160
173
|
int32_t magic;
|
161
174
|
in.read((char*)&(magic), sizeof(int32_t));
|
@@ -176,21 +189,14 @@ void FastText::signModel(std::ostream& out) {
|
|
176
189
|
out.write((char*)&(version), sizeof(int32_t));
|
177
190
|
}
|
178
191
|
|
179
|
-
void FastText::saveModel() {
|
180
|
-
std::string fn(args_->output);
|
181
|
-
if (quant_) {
|
182
|
-
fn += ".ftz";
|
183
|
-
} else {
|
184
|
-
fn += ".bin";
|
185
|
-
}
|
186
|
-
saveModel(fn);
|
187
|
-
}
|
188
|
-
|
189
192
|
void FastText::saveModel(const std::string& filename) {
|
190
193
|
std::ofstream ofs(filename, std::ofstream::binary);
|
191
194
|
if (!ofs.is_open()) {
|
192
195
|
throw std::invalid_argument(filename + " cannot be opened for saving!");
|
193
196
|
}
|
197
|
+
if (!input_ || !output_) {
|
198
|
+
throw std::runtime_error("Model never trained");
|
199
|
+
}
|
194
200
|
signModel(ofs);
|
195
201
|
args_->save(ofs);
|
196
202
|
dict_->save(ofs);
|
@@ -224,6 +230,12 @@ std::vector<int64_t> FastText::getTargetCounts() const {
|
|
224
230
|
}
|
225
231
|
}
|
226
232
|
|
233
|
+
void FastText::buildModel() {
|
234
|
+
auto loss = createLoss(output_);
|
235
|
+
bool normalizeGradient = (args_->model == model_name::sup);
|
236
|
+
model_ = std::make_shared<Model>(input_, output_, loss, normalizeGradient);
|
237
|
+
}
|
238
|
+
|
227
239
|
void FastText::loadModel(std::istream& in) {
|
228
240
|
args_ = std::make_shared<Args>();
|
229
241
|
input_ = std::make_shared<DenseMatrix>();
|
@@ -256,37 +268,37 @@ void FastText::loadModel(std::istream& in) {
|
|
256
268
|
}
|
257
269
|
output_->load(in);
|
258
270
|
|
259
|
-
|
260
|
-
bool normalizeGradient = (args_->model == model_name::sup);
|
261
|
-
model_ = std::make_shared<Model>(input_, output_, loss, normalizeGradient);
|
271
|
+
buildModel();
|
262
272
|
}
|
263
273
|
|
264
|
-
|
265
|
-
|
266
|
-
double t =
|
267
|
-
std::chrono::duration_cast<std::chrono::duration<double>>(end - start_)
|
268
|
-
.count();
|
274
|
+
std::tuple<int64_t, double, double> FastText::progressInfo(real progress) {
|
275
|
+
double t = utils::getDuration(start_, std::chrono::steady_clock::now());
|
269
276
|
double lr = args_->lr * (1.0 - progress);
|
270
277
|
double wst = 0;
|
271
278
|
|
272
279
|
int64_t eta = 2592000; // Default to one month in seconds (720 * 3600)
|
273
280
|
|
274
281
|
if (progress > 0 && t >= 0) {
|
275
|
-
|
276
|
-
eta = t * (100 - progress) / progress;
|
282
|
+
eta = t * (1 - progress) / progress;
|
277
283
|
wst = double(tokenCount_) / t / args_->thread;
|
278
284
|
}
|
279
|
-
|
280
|
-
|
285
|
+
|
286
|
+
return std::tuple<double, double, int64_t>(wst, lr, eta);
|
287
|
+
}
|
288
|
+
|
289
|
+
void FastText::printInfo(real progress, real loss, std::ostream& log_stream) {
|
290
|
+
double wst;
|
291
|
+
double lr;
|
292
|
+
int64_t eta;
|
293
|
+
std::tie<double, double, int64_t>(wst, lr, eta) = progressInfo(progress);
|
281
294
|
|
282
295
|
log_stream << std::fixed;
|
283
296
|
log_stream << "Progress: ";
|
284
|
-
log_stream << std::setprecision(1) << std::setw(5) << progress << "%";
|
297
|
+
log_stream << std::setprecision(1) << std::setw(5) << (progress * 100) << "%";
|
285
298
|
log_stream << " words/sec/thread: " << std::setw(7) << int64_t(wst);
|
286
299
|
log_stream << " lr: " << std::setw(9) << std::setprecision(6) << lr;
|
287
|
-
log_stream << " loss: " << std::setw(9) << std::setprecision(6) << loss;
|
288
|
-
log_stream << " ETA: " <<
|
289
|
-
log_stream << "h" << std::setw(2) << etam << "m";
|
300
|
+
log_stream << " avg.loss: " << std::setw(9) << std::setprecision(6) << loss;
|
301
|
+
log_stream << " ETA: " << utils::ClockPrint(eta);
|
290
302
|
log_stream << std::flush;
|
291
303
|
}
|
292
304
|
|
@@ -299,13 +311,16 @@ std::vector<int32_t> FastText::selectEmbeddings(int32_t cutoff) const {
|
|
299
311
|
std::iota(idx.begin(), idx.end(), 0);
|
300
312
|
auto eosid = dict_->getId(Dictionary::EOS);
|
301
313
|
std::sort(idx.begin(), idx.end(), [&norms, eosid](size_t i1, size_t i2) {
|
314
|
+
if (i1 == eosid && i2 == eosid) { // satisfy strict weak ordering
|
315
|
+
return false;
|
316
|
+
}
|
302
317
|
return eosid == i1 || (eosid != i2 && norms[i1] > norms[i2]);
|
303
318
|
});
|
304
319
|
idx.erase(idx.begin() + cutoff, idx.end());
|
305
320
|
return idx;
|
306
321
|
}
|
307
322
|
|
308
|
-
void FastText::quantize(const Args& qargs) {
|
323
|
+
void FastText::quantize(const Args& qargs, const TrainCallback& callback) {
|
309
324
|
if (args_->model != model_name::sup) {
|
310
325
|
throw std::invalid_argument(
|
311
326
|
"For now we only support quantization of supervised models");
|
@@ -337,10 +352,9 @@ void FastText::quantize(const Args& qargs) {
|
|
337
352
|
args_->verbose = qargs.verbose;
|
338
353
|
auto loss = createLoss(output_);
|
339
354
|
model_ = std::make_shared<Model>(input, output, loss, normalizeGradient);
|
340
|
-
startThreads();
|
355
|
+
startThreads(callback);
|
341
356
|
}
|
342
357
|
}
|
343
|
-
|
344
358
|
input_ = std::make_shared<QuantMatrix>(
|
345
359
|
std::move(*(input.get())), qargs.dsub, qargs.qnorm);
|
346
360
|
|
@@ -348,7 +362,6 @@ void FastText::quantize(const Args& qargs) {
|
|
348
362
|
output_ = std::make_shared<QuantMatrix>(
|
349
363
|
std::move(*(output.get())), 2, qargs.qnorm);
|
350
364
|
}
|
351
|
-
|
352
365
|
quant_ = true;
|
353
366
|
auto loss = createLoss(output_);
|
354
367
|
model_ = std::make_shared<Model>(input_, output_, loss, normalizeGradient);
|
@@ -408,7 +421,7 @@ void FastText::skipgram(
|
|
408
421
|
|
409
422
|
std::tuple<int64_t, double, double>
|
410
423
|
FastText::test(std::istream& in, int32_t k, real threshold) {
|
411
|
-
Meter meter;
|
424
|
+
Meter meter(false);
|
412
425
|
test(in, k, threshold, meter);
|
413
426
|
|
414
427
|
return std::tuple<int64_t, double, double>(
|
@@ -420,6 +433,9 @@ void FastText::test(std::istream& in, int32_t k, real threshold, Meter& meter)
|
|
420
433
|
std::vector<int32_t> line;
|
421
434
|
std::vector<int32_t> labels;
|
422
435
|
Predictions predictions;
|
436
|
+
Model::State state(args_->dim, dict_->nlabels(), 0);
|
437
|
+
in.clear();
|
438
|
+
in.seekg(0, std::ios_base::beg);
|
423
439
|
|
424
440
|
while (in.peek() != EOF) {
|
425
441
|
line.clear();
|
@@ -521,16 +537,6 @@ std::vector<std::pair<std::string, Vector>> FastText::getNgramVectors(
|
|
521
537
|
return result;
|
522
538
|
}
|
523
539
|
|
524
|
-
// deprecated. use getNgramVectors instead
|
525
|
-
void FastText::ngramVectors(std::string word) {
|
526
|
-
std::vector<std::pair<std::string, Vector>> ngramVectors =
|
527
|
-
getNgramVectors(word);
|
528
|
-
|
529
|
-
for (const auto& ngramVector : ngramVectors) {
|
530
|
-
std::cout << ngramVector.first << " " << ngramVector.second << std::endl;
|
531
|
-
}
|
532
|
-
}
|
533
|
-
|
534
540
|
void FastText::precomputeWordVectors(DenseMatrix& wordVectors) {
|
535
541
|
Vector vec(args_->dim);
|
536
542
|
wordVectors.zero();
|
@@ -598,17 +604,6 @@ std::vector<std::pair<real, std::string>> FastText::getNN(
|
|
598
604
|
return heap;
|
599
605
|
}
|
600
606
|
|
601
|
-
// depracted. use getNN instead
|
602
|
-
void FastText::findNN(
|
603
|
-
const DenseMatrix& wordVectors,
|
604
|
-
const Vector& query,
|
605
|
-
int32_t k,
|
606
|
-
const std::set<std::string>& banSet,
|
607
|
-
std::vector<std::pair<real, std::string>>& results) {
|
608
|
-
results.clear();
|
609
|
-
results = getNN(wordVectors, query, k, banSet);
|
610
|
-
}
|
611
|
-
|
612
607
|
std::vector<std::pair<real, std::string>> FastText::getAnalogies(
|
613
608
|
int32_t k,
|
614
609
|
const std::string& wordA,
|
@@ -630,52 +625,52 @@ std::vector<std::pair<real, std::string>> FastText::getAnalogies(
|
|
630
625
|
return getNN(*wordVectors_, query, k, {wordA, wordB, wordC});
|
631
626
|
}
|
632
627
|
|
633
|
-
|
634
|
-
|
635
|
-
std::string prompt("Query triplet (A - B + C)? ");
|
636
|
-
std::string wordA, wordB, wordC;
|
637
|
-
std::cout << prompt;
|
638
|
-
while (true) {
|
639
|
-
std::cin >> wordA;
|
640
|
-
std::cin >> wordB;
|
641
|
-
std::cin >> wordC;
|
642
|
-
auto results = getAnalogies(k, wordA, wordB, wordC);
|
643
|
-
|
644
|
-
for (auto& pair : results) {
|
645
|
-
std::cout << pair.second << " " << pair.first << std::endl;
|
646
|
-
}
|
647
|
-
std::cout << prompt;
|
648
|
-
}
|
628
|
+
bool FastText::keepTraining(const int64_t ntokens) const {
|
629
|
+
return tokenCount_ < args_->epoch * ntokens && !trainException_;
|
649
630
|
}
|
650
631
|
|
651
|
-
void FastText::trainThread(int32_t threadId) {
|
632
|
+
void FastText::trainThread(int32_t threadId, const TrainCallback& callback) {
|
652
633
|
std::ifstream ifs(args_->input);
|
653
634
|
utils::seek(ifs, threadId * utils::size(ifs) / args_->thread);
|
654
635
|
|
655
|
-
Model::State state(args_->dim, output_->size(0), threadId);
|
636
|
+
Model::State state(args_->dim, output_->size(0), threadId + args_->seed);
|
656
637
|
|
657
638
|
const int64_t ntokens = dict_->ntokens();
|
658
639
|
int64_t localTokenCount = 0;
|
659
640
|
std::vector<int32_t> line, labels;
|
660
|
-
|
661
|
-
|
662
|
-
|
663
|
-
|
664
|
-
|
665
|
-
|
666
|
-
|
667
|
-
|
668
|
-
|
669
|
-
|
670
|
-
|
671
|
-
|
672
|
-
|
673
|
-
|
674
|
-
|
675
|
-
|
676
|
-
if (
|
677
|
-
|
641
|
+
uint64_t callbackCounter = 0;
|
642
|
+
try {
|
643
|
+
while (keepTraining(ntokens)) {
|
644
|
+
real progress = real(tokenCount_) / (args_->epoch * ntokens);
|
645
|
+
if (callback && ((callbackCounter++ % 64) == 0)) {
|
646
|
+
double wst;
|
647
|
+
double lr;
|
648
|
+
int64_t eta;
|
649
|
+
std::tie<double, double, int64_t>(wst, lr, eta) =
|
650
|
+
progressInfo(progress);
|
651
|
+
callback(progress, loss_, wst, lr, eta);
|
652
|
+
}
|
653
|
+
real lr = args_->lr * (1.0 - progress);
|
654
|
+
if (args_->model == model_name::sup) {
|
655
|
+
localTokenCount += dict_->getLine(ifs, line, labels);
|
656
|
+
supervised(state, lr, line, labels);
|
657
|
+
} else if (args_->model == model_name::cbow) {
|
658
|
+
localTokenCount += dict_->getLine(ifs, line, state.rng);
|
659
|
+
cbow(state, lr, line);
|
660
|
+
} else if (args_->model == model_name::sg) {
|
661
|
+
localTokenCount += dict_->getLine(ifs, line, state.rng);
|
662
|
+
skipgram(state, lr, line);
|
663
|
+
}
|
664
|
+
if (localTokenCount > args_->lrUpdateRate) {
|
665
|
+
tokenCount_ += localTokenCount;
|
666
|
+
localTokenCount = 0;
|
667
|
+
if (threadId == 0 && args_->verbose > 1) {
|
668
|
+
loss_ = state.getLoss();
|
669
|
+
}
|
670
|
+
}
|
678
671
|
}
|
672
|
+
} catch (DenseMatrix::EncounteredNaNError&) {
|
673
|
+
trainException_ = std::current_exception();
|
679
674
|
}
|
680
675
|
if (threadId == 0)
|
681
676
|
loss_ = state.getLoss();
|
@@ -713,7 +708,7 @@ std::shared_ptr<Matrix> FastText::getInputMatrixFromFile(
|
|
713
708
|
dict_->init();
|
714
709
|
std::shared_ptr<DenseMatrix> input = std::make_shared<DenseMatrix>(
|
715
710
|
dict_->nwords() + args_->bucket, args_->dim);
|
716
|
-
input->uniform(1.0 / args_->dim);
|
711
|
+
input->uniform(1.0 / args_->dim, args_->thread, args_->seed);
|
717
712
|
|
718
713
|
for (size_t i = 0; i < n; i++) {
|
719
714
|
int32_t idx = dict_->getId(words[i]);
|
@@ -727,14 +722,10 @@ std::shared_ptr<Matrix> FastText::getInputMatrixFromFile(
|
|
727
722
|
return input;
|
728
723
|
}
|
729
724
|
|
730
|
-
void FastText::loadVectors(const std::string& filename) {
|
731
|
-
input_ = getInputMatrixFromFile(filename);
|
732
|
-
}
|
733
|
-
|
734
725
|
std::shared_ptr<Matrix> FastText::createRandomMatrix() const {
|
735
726
|
std::shared_ptr<DenseMatrix> input = std::make_shared<DenseMatrix>(
|
736
727
|
dict_->nwords() + args_->bucket, args_->dim);
|
737
|
-
input->uniform(1.0 / args_->dim);
|
728
|
+
input->uniform(1.0 / args_->dim, args_->thread, args_->seed);
|
738
729
|
|
739
730
|
return input;
|
740
731
|
}
|
@@ -749,7 +740,7 @@ std::shared_ptr<Matrix> FastText::createTrainOutputMatrix() const {
|
|
749
740
|
return output;
|
750
741
|
}
|
751
742
|
|
752
|
-
void FastText::train(const Args& args) {
|
743
|
+
void FastText::train(const Args& args, const TrainCallback& callback) {
|
753
744
|
args_ = std::make_shared<Args>(args);
|
754
745
|
dict_ = std::make_shared<Dictionary>(args_);
|
755
746
|
if (args_->input == "-") {
|
@@ -770,23 +761,38 @@ void FastText::train(const Args& args) {
|
|
770
761
|
input_ = createRandomMatrix();
|
771
762
|
}
|
772
763
|
output_ = createTrainOutputMatrix();
|
764
|
+
quant_ = false;
|
773
765
|
auto loss = createLoss(output_);
|
774
766
|
bool normalizeGradient = (args_->model == model_name::sup);
|
775
767
|
model_ = std::make_shared<Model>(input_, output_, loss, normalizeGradient);
|
776
|
-
startThreads();
|
768
|
+
startThreads(callback);
|
769
|
+
}
|
770
|
+
|
771
|
+
void FastText::abort() {
|
772
|
+
try {
|
773
|
+
throw AbortError();
|
774
|
+
} catch (AbortError&) {
|
775
|
+
trainException_ = std::current_exception();
|
776
|
+
}
|
777
777
|
}
|
778
778
|
|
779
|
-
void FastText::startThreads() {
|
779
|
+
void FastText::startThreads(const TrainCallback& callback) {
|
780
780
|
start_ = std::chrono::steady_clock::now();
|
781
781
|
tokenCount_ = 0;
|
782
782
|
loss_ = -1;
|
783
|
+
trainException_ = nullptr;
|
783
784
|
std::vector<std::thread> threads;
|
784
|
-
|
785
|
-
|
785
|
+
if (args_->thread > 1) {
|
786
|
+
for (int32_t i = 0; i < args_->thread; i++) {
|
787
|
+
threads.push_back(std::thread([=]() { trainThread(i, callback); }));
|
788
|
+
}
|
789
|
+
} else {
|
790
|
+
// webassembly can't instantiate `std::thread`
|
791
|
+
trainThread(0, callback);
|
786
792
|
}
|
787
793
|
const int64_t ntokens = dict_->ntokens();
|
788
794
|
// Same condition as trainThread
|
789
|
-
while (
|
795
|
+
while (keepTraining(ntokens)) {
|
790
796
|
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
791
797
|
if (loss_ >= 0 && args_->verbose > 1) {
|
792
798
|
real progress = real(tokenCount_) / (args_->epoch * ntokens);
|
@@ -794,9 +800,14 @@ void FastText::startThreads() {
|
|
794
800
|
printInfo(progress, loss_, std::cerr);
|
795
801
|
}
|
796
802
|
}
|
797
|
-
for (int32_t i = 0; i <
|
803
|
+
for (int32_t i = 0; i < threads.size(); i++) {
|
798
804
|
threads[i].join();
|
799
805
|
}
|
806
|
+
if (trainException_) {
|
807
|
+
std::exception_ptr exception = trainException_;
|
808
|
+
trainException_ = nullptr;
|
809
|
+
std::rethrow_exception(exception);
|
810
|
+
}
|
800
811
|
if (args_->verbose > 0) {
|
801
812
|
std::cerr << "\r";
|
802
813
|
printInfo(1.0, loss_, std::cerr);
|