fasttext 0.1.2 → 0.2.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +19 -0
- data/LICENSE.txt +18 -18
- data/README.md +26 -19
- data/ext/fasttext/ext.cpp +131 -134
- data/ext/fasttext/extconf.rb +2 -4
- data/lib/fasttext/classifier.rb +23 -10
- data/lib/fasttext/model.rb +10 -0
- data/lib/fasttext/vectorizer.rb +11 -5
- data/lib/fasttext/version.rb +1 -1
- data/vendor/fastText/README.md +3 -3
- data/vendor/fastText/src/args.cc +179 -6
- data/vendor/fastText/src/args.h +29 -1
- data/vendor/fastText/src/autotune.cc +477 -0
- data/vendor/fastText/src/autotune.h +89 -0
- data/vendor/fastText/src/densematrix.cc +27 -7
- data/vendor/fastText/src/densematrix.h +10 -2
- data/vendor/fastText/src/fasttext.cc +125 -114
- data/vendor/fastText/src/fasttext.h +31 -52
- data/vendor/fastText/src/main.cc +32 -13
- data/vendor/fastText/src/meter.cc +148 -2
- data/vendor/fastText/src/meter.h +24 -2
- data/vendor/fastText/src/model.cc +0 -1
- data/vendor/fastText/src/real.h +0 -1
- data/vendor/fastText/src/utils.cc +25 -0
- data/vendor/fastText/src/utils.h +29 -0
- data/vendor/fastText/src/vector.cc +0 -1
- metadata +14 -69
- data/lib/fasttext/ext.bundle +0 -0
@@ -0,0 +1,477 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) 2016-present, Facebook, Inc.
|
3
|
+
* All rights reserved.
|
4
|
+
*
|
5
|
+
* This source code is licensed under the MIT license found in the
|
6
|
+
* LICENSE file in the root directory of this source tree.
|
7
|
+
*/
|
8
|
+
|
9
|
+
#include "autotune.h"
|
10
|
+
|
11
|
+
#include <algorithm>
#include <chrono>
#include <cmath>
#include <csignal>
#include <fstream>
#include <functional>
#include <iomanip>
#include <iostream>
#include <limits>
#include <random>
#include <stdexcept>
#include <string>
#include <thread>
|
18
|
+
|
19
|
+
// Prints "<name> = <val>" followed by a newline on stdout when the
// `autotuneArgs` object visible at the expansion site has `verbose > 2`.
//
// `val` is parenthesised so that an argument containing operators of lower
// precedence than `<<` (e.g. `a & b`, `c ? x : y`) is streamed as one value.
// The macro is deliberately a bare `if` block rather than `do { } while (0)`:
// it is invoked in this file both with and without a trailing semicolon.
#define LOG_VAL(name, val)                          \
  if (autotuneArgs.verbose > 2) {                   \
    std::cout << #name " = " << (val) << std::endl; \
  }
|
23
|
+
// Like LOG_VAL, but prints "<name> = NaN" when `val` is not-a-number instead
// of relying on the stream's own NaN formatting. Output is produced only when
// the `autotuneArgs` object visible at the expansion site has `verbose > 2`.
//
// `val` is parenthesised in the stream insertion so operator expressions are
// printed as a single value. Note that `val` is evaluated twice (once for the
// NaN test, once for printing), so it should be free of side effects.
#define LOG_VAL_NAN(name, val)                        \
  if (autotuneArgs.verbose > 2) {                     \
    if (std::isnan(val)) {                            \
      std::cout << #name " = NaN" << std::endl;       \
    } else {                                          \
      std::cout << #name " = " << (val) << std::endl; \
    }                                                 \
  }
|
31
|
+
|
32
|
+
namespace {
|
33
|
+
|
34
|
+
// Callback invoked by signalHandler() when SIGINT is delivered. It is empty
// until some code assigns a callable to it.
std::function<void()> interruptSignalHandler;

// Signal trampoline registered with std::signal(): forwards SIGINT to
// `interruptSignalHandler` and ignores every other signal number.
//
// The callback is only invoked when one is actually installed. Calling an
// empty std::function throws std::bad_function_call, which would otherwise
// escape from the signal handler if SIGINT arrived after this trampoline was
// registered but before the callback was assigned.
void signalHandler(int signal) {
  if (signal == SIGINT && interruptSignalHandler) {
    interruptSignalHandler();
  }
}
|
41
|
+
|
42
|
+
class ElapsedTimeMarker {
|
43
|
+
std::chrono::steady_clock::time_point start_;
|
44
|
+
|
45
|
+
public:
|
46
|
+
ElapsedTimeMarker() {
|
47
|
+
start_ = std::chrono::steady_clock::now();
|
48
|
+
}
|
49
|
+
double getElapsed() {
|
50
|
+
return fasttext::utils::getDuration(
|
51
|
+
start_, std::chrono::steady_clock::now());
|
52
|
+
}
|
53
|
+
};
|
54
|
+
|
55
|
+
} // namespace
|
56
|
+
|
57
|
+
namespace fasttext {
|
58
|
+
|
59
|
+
constexpr double kUnknownBestScore = -1.0;
|
60
|
+
constexpr int kCutoffLimit = 256;
|
61
|
+
|
62
|
+
// Draws a randomly perturbed version of `val`.
//
// A zero-mean Gaussian sample is drawn from `rng`. Its standard deviation is
// `startSigma` while `t <= 0.25`, moves linearly towards `endSigma` for `t` in
// [0.25, 0.75], and stays at `endSigma` afterwards.
//
// When `linear` is true the sample is added to `val`; otherwise `val` is
// multiplied by 2^sample. The result is converted to `T` with static_cast, so
// integral `T` truncates towards zero.
template <typename T>
T getArgGauss(
    T val,
    std::minstd_rand& rng,
    double startSigma,
    double endSigma,
    double t,
    bool linear) {
  // Portion of the [0.25, 0.75] annealing window already consumed, in [0, 0.5].
  const double annealed = std::min(0.5, std::max((t - 0.25), 0.0));
  const double stddev =
      startSigma - ((startSigma - endSigma) / 0.5) * annealed;

  std::normal_distribution<double> gaussian(0.0, stddev);
  const double noise = gaussian(rng);

  if (linear) {
    return static_cast<T>(noise + val);
  }
  return static_cast<T>(std::pow(2.0, noise) * val);
}
|
90
|
+
|
91
|
+
// Perturbs `val` with getArgGauss() and clamps the outcome: values above `max`
// become `max`, then values below `min` become `min`.
//
// The remaining parameters are forwarded unchanged to getArgGauss().
template <typename T>
T updateArgGauss(
    T val,
    T min,
    T max,
    double startSigma,
    double endSigma,
    double t,
    bool linear,
    std::minstd_rand& rng) {
  const T sampled = getArgGauss(val, rng, startSigma, endSigma, t, linear);
  // Upper bound first, lower bound second. Keeping `sampled` as the first
  // argument of each call leaves an unordered value (NaN) untouched, exactly
  // like two independent comparisons would.
  return std::max(std::min(sampled, max), min);
}
|
110
|
+
|
111
|
+
AutotuneStrategy::AutotuneStrategy(
|
112
|
+
const Args& originalArgs,
|
113
|
+
std::minstd_rand::result_type seed)
|
114
|
+
: bestArgs_(),
|
115
|
+
maxDuration_(originalArgs.autotuneDuration),
|
116
|
+
rng_(seed),
|
117
|
+
trials_(0),
|
118
|
+
bestMinnIndex_(0),
|
119
|
+
bestDsubExponent_(1),
|
120
|
+
bestNonzeroBucket_(2000000),
|
121
|
+
originalBucket_(originalArgs.bucket) {
|
122
|
+
minnChoices_ = {0, 2, 3};
|
123
|
+
updateBest(originalArgs);
|
124
|
+
}
|
125
|
+
|
126
|
+
// Proposes the next set of arguments to try.
//
// `elapsed` is compared with the stored time budget to obtain a progress value
// `t` capped at 1.0, which is passed to updateArgGauss() for every sampled
// parameter. The very first call returns the current best arguments
// unchanged; later calls start from the best arguments and resample each
// parameter that Args::isManual() does not report as manually set.
//
// The order of the sampling calls below determines the sequence of numbers
// drawn from rng_, so it must not be rearranged.
Args AutotuneStrategy::ask(double elapsed) {
  const double t = std::min(1.0, elapsed / maxDuration_);
  ++trials_;

  if (trials_ == 1) {
    return bestArgs_;
  }

  Args args = bestArgs_;
  const auto tunable = [&args](const char* name) {
    return !args.isManual(name);
  };

  if (tunable("epoch")) {
    args.epoch = updateArgGauss(args.epoch, 1, 100, 2.8, 2.5, t, false, rng_);
  }
  if (tunable("lr")) {
    args.lr = updateArgGauss(args.lr, 0.01, 5.0, 1.9, 1.0, t, false, rng_);
  }
  if (tunable("dim")) {
    args.dim = updateArgGauss(args.dim, 1, 1000, 1.4, 0.3, t, false, rng_);
  }
  if (tunable("wordNgrams")) {
    args.wordNgrams =
        updateArgGauss(args.wordNgrams, 1, 5, 4.3, 2.4, t, true, rng_);
  }
  if (tunable("dsub")) {
    // dsub is searched as a power of two: exponent in [1, 4] -> dsub in [2, 16].
    const int exponent =
        updateArgGauss(bestDsubExponent_, 1, 4, 2.0, 1.0, t, true, rng_);
    args.dsub = (1 << exponent);
  }
  if (tunable("minn")) {
    // minn is searched as an index into the fixed list of candidate values.
    const int lastChoice = static_cast<int>(minnChoices_.size() - 1);
    const int choice = updateArgGauss(
        bestMinnIndex_, 0, lastChoice, 4.0, 1.4, t, true, rng_);
    args.minn = minnChoices_[choice];
  }
  if (tunable("maxn")) {
    // maxn follows minn: 0 when minn is 0, otherwise minn + 3.
    args.maxn = (args.minn == 0) ? 0 : args.minn + 3;
  }

  args.bucket = tunable("bucket")
      ? updateArgGauss(
            bestNonzeroBucket_, 10000, 10000000, 2.0, 1.5, t, false, rng_)
      : originalBucket_;
  // Overrides the value chosen above when wordNgrams <= 1 and maxn == 0.
  if (args.wordNgrams <= 1 && args.maxn == 0) {
    args.bucket = 0;
  }

  if (tunable("loss")) {
    args.loss = loss_name::softmax;
  }

  return args;
}
|
189
|
+
|
190
|
+
// Returns the position of the first element of `choices` equal to `val`,
// or 0 when `val` does not occur in `choices`.
int AutotuneStrategy::getIndex(int val, const std::vector<int>& choices) {
  const auto position = std::find(choices.begin(), choices.end(), val);
  if (position == choices.end()) {
    return 0;
  }
  return static_cast<int>(std::distance(choices.begin(), position));
}
|
198
|
+
|
199
|
+
void AutotuneStrategy::updateBest(const Args& args) {
|
200
|
+
bestArgs_ = args;
|
201
|
+
bestMinnIndex_ = getIndex(args.minn, minnChoices_);
|
202
|
+
bestDsubExponent_ = log2(args.dsub);
|
203
|
+
if (args.bucket != 0) {
|
204
|
+
bestNonzeroBucket_ = args.bucket;
|
205
|
+
}
|
206
|
+
}
|
207
|
+
|
208
|
+
// Creates an autotuner bound to `fastText`. Counters and scores start at
// zero, no training is in progress, and neither a strategy nor a timer thread
// exists until train() sets them up.
Autotune::Autotune(const std::shared_ptr<FastText>& fastText)
    : fastText_(fastText),
      elapsed_(0.),
      bestScore_(0.),
      trials_(0),
      sizeConstraintFailed_(0),
      continueTraining_(false),
      strategy_(),
      timer_() {
}
|
217
|
+
|
218
|
+
void Autotune::printInfo(double maxDuration) {
|
219
|
+
double progress = elapsed_ * 100 / maxDuration;
|
220
|
+
progress = std::min(progress, 100.0);
|
221
|
+
|
222
|
+
std::cerr << "\r";
|
223
|
+
std::cerr << std::fixed;
|
224
|
+
std::cerr << "Progress: ";
|
225
|
+
std::cerr << std::setprecision(1) << std::setw(5) << progress << "%";
|
226
|
+
std::cerr << " Trials: " << std::setw(4) << trials_;
|
227
|
+
std::cerr << " Best score: " << std::setw(9) << std::setprecision(6);
|
228
|
+
if (bestScore_ == kUnknownBestScore) {
|
229
|
+
std::cerr << "unknown";
|
230
|
+
} else {
|
231
|
+
std::cerr << bestScore_;
|
232
|
+
}
|
233
|
+
std::cerr << " ETA: "
|
234
|
+
<< utils::ClockPrint(std::max(maxDuration - elapsed_, 0.0));
|
235
|
+
std::cerr << std::flush;
|
236
|
+
}
|
237
|
+
|
238
|
+
void Autotune::timer(
|
239
|
+
const std::chrono::steady_clock::time_point& start,
|
240
|
+
double maxDuration) {
|
241
|
+
elapsed_ = 0.0;
|
242
|
+
while (keepTraining(maxDuration)) {
|
243
|
+
std::this_thread::sleep_for(std::chrono::milliseconds(500));
|
244
|
+
elapsed_ = utils::getDuration(start, std::chrono::steady_clock::now());
|
245
|
+
printInfo(maxDuration);
|
246
|
+
}
|
247
|
+
abort();
|
248
|
+
}
|
249
|
+
|
250
|
+
bool Autotune::keepTraining(double maxDuration) const {
|
251
|
+
return continueTraining_ && elapsed_ < maxDuration;
|
252
|
+
}
|
253
|
+
|
254
|
+
void Autotune::abort() {
|
255
|
+
if (continueTraining_) {
|
256
|
+
continueTraining_ = false;
|
257
|
+
fastText_->abort();
|
258
|
+
}
|
259
|
+
}
|
260
|
+
|
261
|
+
void Autotune::startTimer(const Args& args) {
|
262
|
+
std::chrono::steady_clock::time_point start =
|
263
|
+
std::chrono::steady_clock::now();
|
264
|
+
timer_ = std::thread([=]() { timer(start, args.autotuneDuration); });
|
265
|
+
bestScore_ = kUnknownBestScore;
|
266
|
+
trials_ = 0;
|
267
|
+
continueTraining_ = true;
|
268
|
+
|
269
|
+
auto previousSignalHandler = std::signal(SIGINT, signalHandler);
|
270
|
+
interruptSignalHandler = [&]() {
|
271
|
+
std::signal(SIGINT, previousSignalHandler);
|
272
|
+
std::cerr << std::endl << "Aborting autotune..." << std::endl;
|
273
|
+
abort();
|
274
|
+
};
|
275
|
+
}
|
276
|
+
|
277
|
+
// Reads the score selected by `metricName` from `meter`.
//
// When `metricLabel` is non-empty it is first resolved to a label id through
// FastText::getLabelId(); the per-label metrics use that id, the
// precision/recall trade-off metrics additionally use `metricValue`.
//
// Throws std::runtime_error if a non-empty `metricLabel` resolves to -1
// (checked before the metric is inspected) or if `metricName` matches none of
// the handled metrics.
double Autotune::getMetricScore(
    Meter& meter,
    const metric_name& metricName,
    const double metricValue,
    const std::string& metricLabel) const {
  int32_t labelId = -1;
  if (!metricLabel.empty()) {
    labelId = fastText_->getLabelId(metricLabel);
    if (labelId == -1) {
      throw std::runtime_error("Unknown autotune metric label");
    }
  }

  if (metricName == metric_name::f1score) {
    return meter.f1Score();
  }
  if (metricName == metric_name::f1scoreLabel) {
    return meter.f1Score(labelId);
  }
  if (metricName == metric_name::precisionAtRecall) {
    return meter.precisionAtRecall(metricValue);
  }
  if (metricName == metric_name::precisionAtRecallLabel) {
    return meter.precisionAtRecall(labelId, metricValue);
  }
  if (metricName == metric_name::recallAtPrecision) {
    return meter.recallAtPrecision(metricValue);
  }
  if (metricName == metric_name::recallAtPrecisionLabel) {
    return meter.recallAtPrecision(labelId, metricValue);
  }
  throw std::runtime_error("Unknown metric");
}
|
307
|
+
|
308
|
+
void Autotune::printArgs(const Args& args, const Args& autotuneArgs) {
|
309
|
+
LOG_VAL(epoch, args.epoch)
|
310
|
+
LOG_VAL(lr, args.lr)
|
311
|
+
LOG_VAL(dim, args.dim)
|
312
|
+
LOG_VAL(minCount, args.minCount)
|
313
|
+
LOG_VAL(wordNgrams, args.wordNgrams)
|
314
|
+
LOG_VAL(minn, args.minn)
|
315
|
+
LOG_VAL(maxn, args.maxn)
|
316
|
+
LOG_VAL(bucket, args.bucket)
|
317
|
+
LOG_VAL(dsub, args.dsub)
|
318
|
+
LOG_VAL(loss, args.lossToString(args.loss))
|
319
|
+
}
|
320
|
+
|
321
|
+
int Autotune::getCutoffForFileSize(
|
322
|
+
bool qout,
|
323
|
+
bool qnorm,
|
324
|
+
int dsub,
|
325
|
+
int64_t fileSize) const {
|
326
|
+
int64_t outModelSize = 0;
|
327
|
+
const int64_t outM = fastText_->getOutputMatrix()->size(0);
|
328
|
+
const int64_t outN = fastText_->getOutputMatrix()->size(1);
|
329
|
+
if (qout) {
|
330
|
+
const int64_t outputPqSize = 16 + 4 * (outN * (1 << 8));
|
331
|
+
outModelSize =
|
332
|
+
21 + (outM * ((outN + 2 - 1) / 2)) + outputPqSize + (qnorm ? outM : 0);
|
333
|
+
} else {
|
334
|
+
outModelSize = 16 + 4 * (outM * outN);
|
335
|
+
}
|
336
|
+
const int64_t dim = fastText_->getInputMatrix()->size(1);
|
337
|
+
|
338
|
+
int target = (fileSize - (107) - 4 * (1 << 8) * dim - outModelSize);
|
339
|
+
int cutoff = target / ((dim + dsub - 1) / dsub + (qnorm ? 1 : 0) + 10);
|
340
|
+
|
341
|
+
return std::max(cutoff, kCutoffLimit);
|
342
|
+
}
|
343
|
+
|
344
|
+
// Quantizes the freshly trained model so that it fits the model-size budget
// requested in `autotuneArgs`, updating the quantization fields of `args`
// (qnorm, qout, retrain, cutoff) along the way.
//
// Returns true when no size limit is configured (nothing is done) or after
// the model was quantized. Returns false, without quantizing, when the
// computed cutoff equals kCutoffLimit.
bool Autotune::quantize(Args& args, const Args& autotuneArgs) {
  if (autotuneArgs.getAutotuneModelSize() != Args::kUnlimitedModelSize) {
    args.qnorm = true;
    // The output matrix is quantized only when it has at least kCutoffLimit rows.
    args.qout = (fastText_->getOutputMatrix()->size(0) >= kCutoffLimit);
    args.retrain = true;
    args.cutoff = getCutoffForFileSize(
        args.qout, args.qnorm, args.dsub, autotuneArgs.getAutotuneModelSize());
    LOG_VAL(cutoff, args.cutoff);
    if (args.cutoff == kCutoffLimit) {
      return false;
    }
    fastText_->quantize(args);
  }
  return true;
}
|
363
|
+
|
364
|
+
void Autotune::printSkippedArgs(const Args& autotuneArgs) {
|
365
|
+
std::unordered_set<std::string> argsToCheck = {"epoch",
|
366
|
+
"lr",
|
367
|
+
"dim",
|
368
|
+
"wordNgrams",
|
369
|
+
"loss",
|
370
|
+
"bucket",
|
371
|
+
"minn",
|
372
|
+
"maxn",
|
373
|
+
"dsub"};
|
374
|
+
for (const auto& arg : argsToCheck) {
|
375
|
+
if (autotuneArgs.isManual(arg)) {
|
376
|
+
std::cerr << "Warning : " << arg
|
377
|
+
<< " is manually set to a specific value. "
|
378
|
+
<< "It will not be automatically optimized." << std::endl;
|
379
|
+
}
|
380
|
+
}
|
381
|
+
}
|
382
|
+
|
383
|
+
void Autotune::train(const Args& autotuneArgs) {
|
384
|
+
std::ifstream validationFileStream(autotuneArgs.autotuneValidationFile);
|
385
|
+
if (!validationFileStream.is_open()) {
|
386
|
+
throw std::invalid_argument("Validation file cannot be opened!");
|
387
|
+
}
|
388
|
+
printSkippedArgs(autotuneArgs);
|
389
|
+
|
390
|
+
bool sizeConstraintWarning = false;
|
391
|
+
int verbose = autotuneArgs.verbose;
|
392
|
+
Args bestTrainArgs(autotuneArgs);
|
393
|
+
Args trainArgs(autotuneArgs);
|
394
|
+
trainArgs.verbose = 0;
|
395
|
+
strategy_ = std::unique_ptr<AutotuneStrategy>(
|
396
|
+
new AutotuneStrategy(trainArgs, autotuneArgs.seed));
|
397
|
+
startTimer(autotuneArgs);
|
398
|
+
|
399
|
+
while (keepTraining(autotuneArgs.autotuneDuration)) {
|
400
|
+
trials_++;
|
401
|
+
|
402
|
+
trainArgs = strategy_->ask(elapsed_);
|
403
|
+
LOG_VAL(Trial, trials_)
|
404
|
+
printArgs(trainArgs, autotuneArgs);
|
405
|
+
ElapsedTimeMarker elapsedTimeMarker;
|
406
|
+
double currentScore = std::numeric_limits<double>::quiet_NaN();
|
407
|
+
try {
|
408
|
+
fastText_->train(trainArgs);
|
409
|
+
bool sizeConstraintOK = quantize(trainArgs, autotuneArgs);
|
410
|
+
if (sizeConstraintOK) {
|
411
|
+
const auto& metricLabel = autotuneArgs.getAutotuneMetricLabel();
|
412
|
+
Meter meter(!metricLabel.empty());
|
413
|
+
fastText_->test(
|
414
|
+
validationFileStream, autotuneArgs.autotunePredictions, 0.0, meter);
|
415
|
+
|
416
|
+
currentScore = getMetricScore(
|
417
|
+
meter,
|
418
|
+
autotuneArgs.getAutotuneMetric(),
|
419
|
+
autotuneArgs.getAutotuneMetricValue(),
|
420
|
+
metricLabel);
|
421
|
+
|
422
|
+
if (bestScore_ == kUnknownBestScore || (currentScore > bestScore_)) {
|
423
|
+
bestTrainArgs = trainArgs;
|
424
|
+
bestScore_ = currentScore;
|
425
|
+
strategy_->updateBest(bestTrainArgs);
|
426
|
+
}
|
427
|
+
} else {
|
428
|
+
sizeConstraintFailed_++;
|
429
|
+
if (!sizeConstraintWarning && trials_ > 10 &&
|
430
|
+
sizeConstraintFailed_ > (trials_ / 2)) {
|
431
|
+
sizeConstraintWarning = true;
|
432
|
+
std::cerr << std::endl
|
433
|
+
<< "Warning : requested model size is probably too small. "
|
434
|
+
"You may want to increase `autotune-modelsize`."
|
435
|
+
<< std::endl;
|
436
|
+
}
|
437
|
+
}
|
438
|
+
} catch (DenseMatrix::EncounteredNaNError&) {
|
439
|
+
// ignore diverging loss and go on
|
440
|
+
} catch (std::bad_alloc&) {
|
441
|
+
// ignore parameter samples asking too much memory
|
442
|
+
} catch (TimeoutError&) {
|
443
|
+
break;
|
444
|
+
} catch (FastText::AbortError&) {
|
445
|
+
break;
|
446
|
+
}
|
447
|
+
LOG_VAL_NAN(currentScore, currentScore)
|
448
|
+
LOG_VAL(train took, elapsedTimeMarker.getElapsed())
|
449
|
+
}
|
450
|
+
if (timer_.joinable()) {
|
451
|
+
timer_.join();
|
452
|
+
}
|
453
|
+
|
454
|
+
if (bestScore_ == kUnknownBestScore) {
|
455
|
+
std::string errorMessage;
|
456
|
+
if (sizeConstraintWarning) {
|
457
|
+
errorMessage =
|
458
|
+
"Couldn't fulfil model size constraint: please increase "
|
459
|
+
"`autotune-modelsize`.";
|
460
|
+
} else {
|
461
|
+
errorMessage =
|
462
|
+
"Didn't have enough time to train once: please increase "
|
463
|
+
"`autotune-duration`.";
|
464
|
+
}
|
465
|
+
throw std::runtime_error(errorMessage);
|
466
|
+
} else {
|
467
|
+
std::cerr << std::endl;
|
468
|
+
std::cerr << "Training again with best arguments" << std::endl;
|
469
|
+
bestTrainArgs.verbose = verbose;
|
470
|
+
LOG_VAL(Best selected args, 0)
|
471
|
+
printArgs(bestTrainArgs, autotuneArgs);
|
472
|
+
fastText_->train(bestTrainArgs);
|
473
|
+
quantize(bestTrainArgs, autotuneArgs);
|
474
|
+
}
|
475
|
+
}
|
476
|
+
|
477
|
+
} // namespace fasttext
|
@@ -0,0 +1,89 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) 2016-present, Facebook, Inc.
|
3
|
+
* All rights reserved.
|
4
|
+
*
|
5
|
+
* This source code is licensed under the MIT license found in the
|
6
|
+
* LICENSE file in the root directory of this source tree.
|
7
|
+
*/
|
8
|
+
|
9
|
+
#pragma once
|
10
|
+
|
11
|
+
#include <istream>
|
12
|
+
#include <memory>
|
13
|
+
#include <random>
|
14
|
+
#include <thread>
|
15
|
+
#include <vector>
|
16
|
+
|
17
|
+
#include "args.h"
|
18
|
+
#include "fasttext.h"
|
19
|
+
|
20
|
+
namespace fasttext {
|
21
|
+
|
22
|
+
class AutotuneStrategy {
|
23
|
+
private:
|
24
|
+
Args bestArgs_;
|
25
|
+
int maxDuration_;
|
26
|
+
std::minstd_rand rng_;
|
27
|
+
int trials_;
|
28
|
+
int bestMinnIndex_;
|
29
|
+
int bestDsubExponent_;
|
30
|
+
int bestNonzeroBucket_;
|
31
|
+
int originalBucket_;
|
32
|
+
std::vector<int> minnChoices_;
|
33
|
+
int getIndex(int val, const std::vector<int>& choices);
|
34
|
+
|
35
|
+
public:
|
36
|
+
explicit AutotuneStrategy(
|
37
|
+
const Args& args,
|
38
|
+
std::minstd_rand::result_type seed);
|
39
|
+
Args ask(double elapsed);
|
40
|
+
void updateBest(const Args& args);
|
41
|
+
};
|
42
|
+
|
43
|
+
class Autotune {
|
44
|
+
protected:
|
45
|
+
std::shared_ptr<FastText> fastText_;
|
46
|
+
double elapsed_;
|
47
|
+
double bestScore_;
|
48
|
+
int32_t trials_;
|
49
|
+
int32_t sizeConstraintFailed_;
|
50
|
+
std::atomic<bool> continueTraining_;
|
51
|
+
std::unique_ptr<AutotuneStrategy> strategy_;
|
52
|
+
std::thread timer_;
|
53
|
+
|
54
|
+
bool keepTraining(double maxDuration) const;
|
55
|
+
void printInfo(double maxDuration);
|
56
|
+
void timer(
|
57
|
+
const std::chrono::steady_clock::time_point& start,
|
58
|
+
double maxDuration);
|
59
|
+
void abort();
|
60
|
+
void startTimer(const Args& args);
|
61
|
+
double getMetricScore(
|
62
|
+
Meter& meter,
|
63
|
+
const metric_name& metricName,
|
64
|
+
const double metricValue,
|
65
|
+
const std::string& metricLabel) const;
|
66
|
+
void printArgs(const Args& args, const Args& autotuneArgs);
|
67
|
+
void printSkippedArgs(const Args& autotuneArgs);
|
68
|
+
bool quantize(Args& args, const Args& autotuneArgs);
|
69
|
+
int getCutoffForFileSize(bool qout, bool qnorm, int dsub, int64_t fileSize)
|
70
|
+
const;
|
71
|
+
|
72
|
+
class TimeoutError : public std::runtime_error {
|
73
|
+
public:
|
74
|
+
TimeoutError() : std::runtime_error("Autotune timed out.") {}
|
75
|
+
};
|
76
|
+
|
77
|
+
public:
|
78
|
+
Autotune() = delete;
|
79
|
+
explicit Autotune(const std::shared_ptr<FastText>& fastText);
|
80
|
+
Autotune(const Autotune&) = delete;
|
81
|
+
Autotune(Autotune&&) = delete;
|
82
|
+
Autotune& operator=(const Autotune&) = delete;
|
83
|
+
Autotune& operator=(Autotune&&) = delete;
|
84
|
+
~Autotune() noexcept = default;
|
85
|
+
|
86
|
+
void train(const Args& args);
|
87
|
+
};
|
88
|
+
|
89
|
+
} // namespace fasttext
|
@@ -8,11 +8,10 @@
|
|
8
8
|
|
9
9
|
#include "densematrix.h"
|
10
10
|
|
11
|
-
#include <exception>
|
12
11
|
#include <random>
|
13
12
|
#include <stdexcept>
|
13
|
+
#include <thread>
|
14
14
|
#include <utility>
|
15
|
-
|
16
15
|
#include "utils.h"
|
17
16
|
#include "vector.h"
|
18
17
|
|
@@ -25,18 +24,39 @@ DenseMatrix::DenseMatrix(int64_t m, int64_t n) : Matrix(m, n), data_(m * n) {}
|
|
25
24
|
DenseMatrix::DenseMatrix(DenseMatrix&& other) noexcept
|
26
25
|
: Matrix(other.m_, other.n_), data_(std::move(other.data_)) {}
|
27
26
|
|
27
|
+
// Builds an m-by-n matrix whose storage is a copy of the `m * n` values
// starting at `dataPtr`. The caller keeps ownership of the source buffer.
DenseMatrix::DenseMatrix(int64_t m, int64_t n, real* dataPtr)
    : Matrix(m, n),
      data_(dataPtr, dataPtr + (m * n)) {
}
|
29
|
+
|
28
30
|
void DenseMatrix::zero() {
|
29
31
|
std::fill(data_.begin(), data_.end(), 0.0);
|
30
32
|
}
|
31
33
|
|
32
|
-
void DenseMatrix::
|
33
|
-
std::minstd_rand rng(
|
34
|
+
// Fills one block of the flat storage with values drawn uniformly from
// [-a, a), using a generator seeded with `block + seed`.
//
// The storage is divided into blocks of (m_ * n_) / 10 elements; block `block`
// covers indices [blockSize * block, blockSize * (block + 1)), cut off at the
// end of the storage. Elements outside that range are not touched.
void DenseMatrix::uniformThread(real a, int block, int32_t seed) {
  std::minstd_rand rng(block + seed);
  std::uniform_real_distribution<> uniform(-a, a);

  const int64_t total = m_ * n_;
  const int64_t blockSize = total / 10;
  const int64_t first = blockSize * block;
  const int64_t blockEnd = blockSize * (block + 1);
  const int64_t last = (blockEnd < total) ? blockEnd : total;

  for (int64_t i = first; i < last; ++i) {
    data_[i] = uniform(rng);
  }
}
|
39
44
|
|
45
|
+
// Fills the matrix with values drawn uniformly from [-a, a).
//
// The work is done block by block through uniformThread(), which always
// partitions the storage into blocks of (m_ * n_) / 10 elements and seeds
// each block with `block + seed`. That partition does not depend on the
// number of threads, so every block has to be processed no matter how many
// workers are available: running only one block per thread (or only block 0
// when `thread <= 1`) leaves most of the matrix at its previous contents.
//
// Here the blocks are distributed round-robin over at most `thread` worker
// threads, or processed sequentially on the calling thread when
// `thread <= 1`. Because each block's seed depends only on its index, the
// resulting values are the same for every thread count.
//
// NOTE(review): when the matrix has fewer than 10 elements the block size
// computed by uniformThread() is 0 and no element is written; fixing that
// needs a change in uniformThread() itself.
void DenseMatrix::uniform(real a, unsigned int thread, int32_t seed) {
  // Mirrors the partition used by uniformThread().
  const int64_t total = m_ * n_;
  const int64_t blockSize = total / 10;
  const int64_t blocks =
      (blockSize > 0) ? (total + blockSize - 1) / blockSize : 0;

  if (thread > 1) {
    const int64_t requested = static_cast<int64_t>(thread);
    const int64_t workers = (requested < blocks) ? requested : blocks;
    std::vector<std::thread> threads;
    threads.reserve(static_cast<size_t>(workers));
    for (int64_t w = 0; w < workers; w++) {
      threads.emplace_back([this, a, seed, w, workers, blocks]() {
        for (int64_t block = w; block < blocks; block += workers) {
          uniformThread(a, static_cast<int>(block), seed);
        }
      });
    }
    for (std::thread& worker : threads) {
      worker.join();
    }
  } else {
    // webassembly can't instantiate `std::thread`
    for (int64_t block = 0; block < blocks; block++) {
      uniformThread(a, static_cast<int>(block), seed);
    }
  }
}
|
59
|
+
|
40
60
|
void DenseMatrix::multiplyRow(const Vector& nums, int64_t ib, int64_t ie) {
|
41
61
|
if (ie == -1) {
|
42
62
|
ie = m_;
|
@@ -73,7 +93,7 @@ real DenseMatrix::l2NormRow(int64_t i) const {
|
|
73
93
|
norm += at(i, j) * at(i, j);
|
74
94
|
}
|
75
95
|
if (std::isnan(norm)) {
|
76
|
-
throw
|
96
|
+
throw EncounteredNaNError();
|
77
97
|
}
|
78
98
|
return std::sqrt(norm);
|
79
99
|
}
|
@@ -94,7 +114,7 @@ real DenseMatrix::dotRow(const Vector& vec, int64_t i) const {
|
|
94
114
|
d += at(i, j) * vec[j];
|
95
115
|
}
|
96
116
|
if (std::isnan(d)) {
|
97
|
-
throw
|
117
|
+
throw EncounteredNaNError();
|
98
118
|
}
|
99
119
|
return d;
|
100
120
|
}
|
@@ -8,12 +8,13 @@
|
|
8
8
|
|
9
9
|
#pragma once
|
10
10
|
|
11
|
+
#include <assert.h>
|
11
12
|
#include <cstdint>
|
12
13
|
#include <istream>
|
13
14
|
#include <ostream>
|
15
|
+
#include <stdexcept>
|
14
16
|
#include <vector>
|
15
17
|
|
16
|
-
#include <assert.h>
|
17
18
|
#include "matrix.h"
|
18
19
|
#include "real.h"
|
19
20
|
|
@@ -24,10 +25,12 @@ class Vector;
|
|
24
25
|
class DenseMatrix : public Matrix {
|
25
26
|
protected:
|
26
27
|
std::vector<real> data_;
|
28
|
+
void uniformThread(real, int, int32_t);
|
27
29
|
|
28
30
|
public:
|
29
31
|
DenseMatrix();
|
30
32
|
explicit DenseMatrix(int64_t, int64_t);
|
33
|
+
explicit DenseMatrix(int64_t m, int64_t n, real* dataPtr);
|
31
34
|
DenseMatrix(const DenseMatrix&) = default;
|
32
35
|
DenseMatrix(DenseMatrix&&) noexcept;
|
33
36
|
DenseMatrix& operator=(const DenseMatrix&) = delete;
|
@@ -56,7 +59,7 @@ class DenseMatrix : public Matrix {
|
|
56
59
|
return n_;
|
57
60
|
}
|
58
61
|
void zero();
|
59
|
-
void uniform(real);
|
62
|
+
void uniform(real, unsigned int, int32_t);
|
60
63
|
|
61
64
|
void multiplyRow(const Vector& nums, int64_t ib = 0, int64_t ie = -1);
|
62
65
|
void divideRow(const Vector& denoms, int64_t ib = 0, int64_t ie = -1);
|
@@ -71,5 +74,10 @@ class DenseMatrix : public Matrix {
|
|
71
74
|
void save(std::ostream&) const override;
|
72
75
|
void load(std::istream&) override;
|
73
76
|
void dump(std::ostream&) const override;
|
77
|
+
|
78
|
+
class EncounteredNaNError : public std::runtime_error {
|
79
|
+
public:
|
80
|
+
EncounteredNaNError() : std::runtime_error("Encountered NaN.") {}
|
81
|
+
};
|
74
82
|
};
|
75
83
|
} // namespace fasttext
|