fasttext 0.1.2 → 0.1.3
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +18 -8
- data/ext/fasttext/ext.cpp +66 -35
- data/ext/fasttext/extconf.rb +2 -3
- data/lib/fasttext/classifier.rb +13 -3
- data/lib/fasttext/vectorizer.rb +6 -1
- data/lib/fasttext/version.rb +1 -1
- data/vendor/fastText/README.md +3 -3
- data/vendor/fastText/src/args.cc +179 -6
- data/vendor/fastText/src/args.h +29 -1
- data/vendor/fastText/src/autotune.cc +477 -0
- data/vendor/fastText/src/autotune.h +89 -0
- data/vendor/fastText/src/densematrix.cc +27 -7
- data/vendor/fastText/src/densematrix.h +10 -2
- data/vendor/fastText/src/fasttext.cc +125 -114
- data/vendor/fastText/src/fasttext.h +31 -52
- data/vendor/fastText/src/main.cc +32 -13
- data/vendor/fastText/src/meter.cc +148 -2
- data/vendor/fastText/src/meter.h +24 -2
- data/vendor/fastText/src/model.cc +0 -1
- data/vendor/fastText/src/real.h +0 -1
- data/vendor/fastText/src/utils.cc +25 -0
- data/vendor/fastText/src/utils.h +29 -0
- data/vendor/fastText/src/vector.cc +0 -1
- metadata +5 -4
- data/lib/fasttext/ext.bundle +0 -0
data/vendor/fastText/src/args.h
CHANGED
@@ -11,18 +11,28 @@
|
|
11
11
|
#include <istream>
|
12
12
|
#include <ostream>
|
13
13
|
#include <string>
|
14
|
+
#include <unordered_set>
|
14
15
|
#include <vector>
|
15
16
|
|
16
17
|
namespace fasttext {
|
17
18
|
|
18
19
|
// Enumerator values are spelled out explicitly; they are the same values the
// implicit "previous + 1" numbering produces.
enum class model_name : int {
  cbow = 1,
  sg = 2,
  sup = 3
};

enum class loss_name : int {
  hs = 1,
  ns = 2,
  softmax = 3,
  ova = 4
};

// Metric selectable for autotune. The *Label variants are scored for a single
// label, and the precisionAt*/recallAt* variants additionally take a numeric
// threshold (see Autotune::getMetricScore).
enum class metric_name : int {
  f1score = 1,
  f1scoreLabel = 2,
  precisionAtRecall = 3,
  precisionAtRecallLabel = 4,
  recallAtPrecision = 5,
  recallAtPrecisionLabel = 6
};
|
20
29
|
|
21
30
|
class Args {
|
22
31
|
protected:
|
23
|
-
std::string lossToString(loss_name) const;
|
24
32
|
std::string boolToString(bool) const;
|
25
33
|
std::string modelToString(model_name) const;
|
34
|
+
std::string metricToString(metric_name) const;
|
35
|
+
std::unordered_set<std::string> manualArgs_;
|
26
36
|
|
27
37
|
public:
|
28
38
|
Args();
|
@@ -48,6 +58,7 @@ class Args {
|
|
48
58
|
int verbose;
|
49
59
|
std::string pretrainedVectors;
|
50
60
|
bool saveOutput;
|
61
|
+
int seed;
|
51
62
|
|
52
63
|
bool qout;
|
53
64
|
bool retrain;
|
@@ -55,14 +66,31 @@ class Args {
|
|
55
66
|
size_t cutoff;
|
56
67
|
size_t dsub;
|
57
68
|
|
69
|
+
std::string autotuneValidationFile;
|
70
|
+
std::string autotuneMetric;
|
71
|
+
int autotunePredictions;
|
72
|
+
int autotuneDuration;
|
73
|
+
std::string autotuneModelSize;
|
74
|
+
|
58
75
|
void parseArgs(const std::vector<std::string>& args);
|
59
76
|
void printHelp();
|
60
77
|
void printBasicHelp();
|
61
78
|
void printDictionaryHelp();
|
62
79
|
void printTrainingHelp();
|
80
|
+
void printAutotuneHelp();
|
63
81
|
void printQuantizationHelp();
|
64
82
|
void save(std::ostream&);
|
65
83
|
void load(std::istream&);
|
66
84
|
void dump(std::ostream&) const;
|
85
|
+
bool hasAutotune() const;
|
86
|
+
bool isManual(const std::string& argName) const;
|
87
|
+
void setManual(const std::string& argName);
|
88
|
+
std::string lossToString(loss_name) const;
|
89
|
+
metric_name getAutotuneMetric() const;
|
90
|
+
std::string getAutotuneMetricLabel() const;
|
91
|
+
double getAutotuneMetricValue() const;
|
92
|
+
int64_t getAutotuneModelSize() const;
|
93
|
+
|
94
|
+
static constexpr double kUnlimitedModelSize = -1.0;
|
67
95
|
};
|
68
96
|
} // namespace fasttext
|
@@ -0,0 +1,477 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) 2016-present, Facebook, Inc.
|
3
|
+
* All rights reserved.
|
4
|
+
*
|
5
|
+
* This source code is licensed under the MIT license found in the
|
6
|
+
* LICENSE file in the root directory of this source tree.
|
7
|
+
*/
|
8
|
+
|
9
|
+
#include "autotune.h"

#include <algorithm>
#include <cmath>
#include <csignal>
#include <functional>
#include <iomanip>
#include <iostream>
#include <limits>
#include <random>
#include <thread>
|
18
|
+
|
19
|
+
// Trace helpers. Both macros expect an object named `autotuneArgs` with a
// `verbose` member to be visible where they are expanded, and they print to
// stdout only when that verbosity is greater than 2.
//
// They expand to a bare `if` block, so no trailing semicolon is needed and
// they must not be placed directly in front of an `else`.

// Prints "<name> = <val>" followed by a newline (and a flush).
#define LOG_VAL(name, val)                        \
  if (autotuneArgs.verbose > 2) {                 \
    std::cout << #name " = " << val << std::endl; \
  }

// Like LOG_VAL, except that a NaN value is printed as the literal text "NaN".
// Note that `val` is evaluated twice.
#define LOG_VAL_NAN(name, val)                      \
  if (autotuneArgs.verbose > 2) {                   \
    if (!std::isnan(val)) {                         \
      std::cout << #name " = " << val << std::endl; \
    } else {                                        \
      std::cout << #name " = NaN" << std::endl;     \
    }                                               \
  }
|
31
|
+
|
32
|
+
namespace {
|
33
|
+
|
34
|
+
// Callback run when SIGINT is delivered. It stays empty until
// Autotune::startTimer() assigns it.
std::function<void()> interruptSignalHandler;

// Signal handler registered with std::signal(): forwards SIGINT to
// interruptSignalHandler and ignores every other signal.
void signalHandler(int signal) {
  // The handler is registered before the callback is assigned, so a SIGINT
  // can arrive while the callback is still empty. Calling an empty
  // std::function throws std::bad_function_call, and an exception leaving a
  // signal handler terminates the process; skip the call instead.
  if (signal == SIGINT && interruptSignalHandler) {
    interruptSignalHandler();
  }
}
|
41
|
+
|
42
|
+
class ElapsedTimeMarker {
|
43
|
+
std::chrono::steady_clock::time_point start_;
|
44
|
+
|
45
|
+
public:
|
46
|
+
ElapsedTimeMarker() {
|
47
|
+
start_ = std::chrono::steady_clock::now();
|
48
|
+
}
|
49
|
+
double getElapsed() {
|
50
|
+
return fasttext::utils::getDuration(
|
51
|
+
start_, std::chrono::steady_clock::now());
|
52
|
+
}
|
53
|
+
};
|
54
|
+
|
55
|
+
} // namespace
|
56
|
+
|
57
|
+
namespace fasttext {
|
58
|
+
|
59
|
+
// Sentinel for Autotune::bestScore_ meaning "no trial has produced a score
// yet"; printInfo() shows it as "unknown".
constexpr double kUnknownBestScore{-1.0};

// Smallest cutoff getCutoffForFileSize() will return. quantize() treats a
// cutoff equal to this limit as a failed size constraint, and also uses it as
// the output-row threshold for enabling `qout`.
constexpr int kCutoffLimit{256};
|
61
|
+
|
62
|
+
// Draws a Gaussian perturbation of `val` and returns it as a double, i.e.
// before any narrowing to the argument's own type.
//
// The standard deviation is `startSigma` while t <= 0.25, then moves linearly
// to `endSigma` as t goes from 0.25 to 0.75, and stays at `endSigma` after
// that. With `linear` the drawn coefficient is added to `val`; otherwise
// `val` is multiplied by 2^coefficient.
template <typename T>
double sampleArgGauss(
    T val,
    std::minstd_rand& rng,
    double startSigma,
    double endSigma,
    double t,
    bool linear) {
  const double stddev = startSigma -
      ((startSigma - endSigma) / 0.5) *
          std::min(0.5, std::max((t - 0.25), 0.0));

  std::normal_distribution<double> normal(0.0, stddev);
  const double coeff = normal(rng);

  return linear ? (coeff + val) : (std::pow(2.0, coeff) * val);
}

// Returns the perturbed value converted to T without any range limiting.
// The caller must make sure the result is representable in T; prefer
// updateArgGauss(), which clamps before converting.
template <typename T>
T getArgGauss(
    T val,
    std::minstd_rand& rng,
    double startSigma,
    double endSigma,
    double t,
    bool linear) {
  return static_cast<T>(
      sampleArgGauss(val, rng, startSigma, endSigma, t, linear));
}

// Returns a perturbed copy of `val` limited to the range [min, max].
//
// The clamp is applied to the double sample *before* it is converted to T.
// Converting first is undefined behaviour when the sample does not fit in T
// (e.g. a multiplicative update of a large int) and in practice can turn a
// "far too large" draw into `min`. For samples that do fit in T the result is
// unchanged, because truncation is monotonic and the bounds are exact.
template <typename T>
T updateArgGauss(
    T val,
    T min,
    T max,
    double startSigma,
    double endSigma,
    double t,
    bool linear,
    std::minstd_rand& rng) {
  const double sampled =
      sampleArgGauss(val, rng, startSigma, endSigma, t, linear);
  const double lowerBounded = std::max(static_cast<double>(min), sampled);
  const double clamped = std::min(static_cast<double>(max), lowerBounded);
  return static_cast<T>(clamped);
}
|
110
|
+
|
111
|
+
// Builds a search strategy seeded with the caller's arguments.
//
// `originalArgs` supplies the time budget (autotuneDuration), the original
// bucket count and the starting point of the search; `seed` initialises the
// random generator used for all later draws.
AutotuneStrategy::AutotuneStrategy(
    const Args& originalArgs,
    std::minstd_rand::result_type seed)
    : bestArgs_(),
      maxDuration_(originalArgs.autotuneDuration),
      rng_(seed),
      trials_(0),
      bestMinnIndex_(0),
      bestDsubExponent_(1),
      bestNonzeroBucket_(2000000),
      originalBucket_(originalArgs.bucket) {
  // Candidate values for `minn`; ask() picks among them by index.
  minnChoices_ = std::vector<int>{0, 2, 3};
  // Must run after minnChoices_ is filled: updateBest() looks `minn` up in it.
  updateBest(originalArgs);
}
|
125
|
+
|
126
|
+
// Proposes the argument set for the next trial.
//
// `elapsed` is the time spent so far; divided by the time budget it gives the
// progress `t` (capped at 1) that narrows the sampling distributions. The
// very first call returns the current best arguments unchanged. Afterwards
// every argument that was not set manually is re-drawn around the best known
// value. The order of the draws is fixed (epoch, lr, dim, wordNgrams, dsub,
// minn, bucket) because they all consume the same random generator.
Args AutotuneStrategy::ask(double elapsed) {
  const double t = std::min(1.0, elapsed / maxDuration_);
  trials_++;

  if (trials_ == 1) {
    return bestArgs_;
  }

  Args args = bestArgs_;
  const auto tunable = [&args](const char* name) {
    return !args.isManual(name);
  };

  if (tunable("epoch")) {
    args.epoch = updateArgGauss(args.epoch, 1, 100, 2.8, 2.5, t, false, rng_);
  }
  if (tunable("lr")) {
    args.lr = updateArgGauss(args.lr, 0.01, 5.0, 1.9, 1.0, t, false, rng_);
  }
  if (tunable("dim")) {
    args.dim = updateArgGauss(args.dim, 1, 1000, 1.4, 0.3, t, false, rng_);
  }
  if (tunable("wordNgrams")) {
    args.wordNgrams =
        updateArgGauss(args.wordNgrams, 1, 5, 4.3, 2.4, t, true, rng_);
  }
  if (tunable("dsub")) {
    // dsub is searched through its exponent, so it is always a power of two.
    const int exponent =
        updateArgGauss(bestDsubExponent_, 1, 4, 2.0, 1.0, t, true, rng_);
    args.dsub = (1 << exponent);
  }
  if (tunable("minn")) {
    const int lastChoice = static_cast<int>(minnChoices_.size() - 1);
    const int choice =
        updateArgGauss(bestMinnIndex_, 0, lastChoice, 4.0, 1.4, t, true, rng_);
    args.minn = minnChoices_[choice];
  }
  if (tunable("maxn")) {
    // maxn follows minn: zero when minn is zero, otherwise minn + 3.
    args.maxn = (args.minn == 0) ? 0 : args.minn + 3;
  }

  if (tunable("bucket")) {
    const int sampledBucket = updateArgGauss(
        bestNonzeroBucket_, 10000, 10000000, 2.0, 1.5, t, false, rng_);
    args.bucket = sampledBucket;
  } else {
    args.bucket = originalBucket_;
  }
  // This override applies to a manually set bucket as well.
  if (args.wordNgrams <= 1 && args.maxn == 0) {
    args.bucket = 0;
  }

  if (tunable("loss")) {
    args.loss = loss_name::softmax;
  }

  return args;
}
|
189
|
+
|
190
|
+
// Returns the position of the first occurrence of `val` in `choices`,
// or 0 when `val` is not present.
int AutotuneStrategy::getIndex(int val, const std::vector<int>& choices) {
  const auto position = std::find(choices.begin(), choices.end(), val);
  if (position == choices.end()) {
    return 0;
  }
  return static_cast<int>(std::distance(choices.begin(), position));
}
|
198
|
+
|
199
|
+
void AutotuneStrategy::updateBest(const Args& args) {
|
200
|
+
bestArgs_ = args;
|
201
|
+
bestMinnIndex_ = getIndex(args.minn, minnChoices_);
|
202
|
+
bestDsubExponent_ = log2(args.dsub);
|
203
|
+
if (args.bucket != 0) {
|
204
|
+
bestNonzeroBucket_ = args.bucket;
|
205
|
+
}
|
206
|
+
}
|
207
|
+
|
208
|
+
// Creates an idle autotuner bound to `fastText`. Nothing is started here:
// the strategy and the timer thread are set up later by train().
Autotune::Autotune(const std::shared_ptr<FastText>& fastText)
    : fastText_(fastText),
      elapsed_(0.0),
      bestScore_(0.0),
      trials_(0),
      sizeConstraintFailed_(0),
      continueTraining_(false),
      strategy_(),
      timer_() {
}
|
217
|
+
|
218
|
+
void Autotune::printInfo(double maxDuration) {
|
219
|
+
double progress = elapsed_ * 100 / maxDuration;
|
220
|
+
progress = std::min(progress, 100.0);
|
221
|
+
|
222
|
+
std::cerr << "\r";
|
223
|
+
std::cerr << std::fixed;
|
224
|
+
std::cerr << "Progress: ";
|
225
|
+
std::cerr << std::setprecision(1) << std::setw(5) << progress << "%";
|
226
|
+
std::cerr << " Trials: " << std::setw(4) << trials_;
|
227
|
+
std::cerr << " Best score: " << std::setw(9) << std::setprecision(6);
|
228
|
+
if (bestScore_ == kUnknownBestScore) {
|
229
|
+
std::cerr << "unknown";
|
230
|
+
} else {
|
231
|
+
std::cerr << bestScore_;
|
232
|
+
}
|
233
|
+
std::cerr << " ETA: "
|
234
|
+
<< utils::ClockPrint(std::max(maxDuration - elapsed_, 0.0));
|
235
|
+
std::cerr << std::flush;
|
236
|
+
}
|
237
|
+
|
238
|
+
void Autotune::timer(
|
239
|
+
const std::chrono::steady_clock::time_point& start,
|
240
|
+
double maxDuration) {
|
241
|
+
elapsed_ = 0.0;
|
242
|
+
while (keepTraining(maxDuration)) {
|
243
|
+
std::this_thread::sleep_for(std::chrono::milliseconds(500));
|
244
|
+
elapsed_ = utils::getDuration(start, std::chrono::steady_clock::now());
|
245
|
+
printInfo(maxDuration);
|
246
|
+
}
|
247
|
+
abort();
|
248
|
+
}
|
249
|
+
|
250
|
+
bool Autotune::keepTraining(double maxDuration) const {
|
251
|
+
return continueTraining_ && elapsed_ < maxDuration;
|
252
|
+
}
|
253
|
+
|
254
|
+
void Autotune::abort() {
|
255
|
+
if (continueTraining_) {
|
256
|
+
continueTraining_ = false;
|
257
|
+
fastText_->abort();
|
258
|
+
}
|
259
|
+
}
|
260
|
+
|
261
|
+
void Autotune::startTimer(const Args& args) {
|
262
|
+
std::chrono::steady_clock::time_point start =
|
263
|
+
std::chrono::steady_clock::now();
|
264
|
+
timer_ = std::thread([=]() { timer(start, args.autotuneDuration); });
|
265
|
+
bestScore_ = kUnknownBestScore;
|
266
|
+
trials_ = 0;
|
267
|
+
continueTraining_ = true;
|
268
|
+
|
269
|
+
auto previousSignalHandler = std::signal(SIGINT, signalHandler);
|
270
|
+
interruptSignalHandler = [&]() {
|
271
|
+
std::signal(SIGINT, previousSignalHandler);
|
272
|
+
std::cerr << std::endl << "Aborting autotune..." << std::endl;
|
273
|
+
abort();
|
274
|
+
};
|
275
|
+
}
|
276
|
+
|
277
|
+
// Reads the score for `metricName` from `meter`.
//
// `metricLabel`, when non-empty, is resolved to a label id first (whatever
// the metric) and is used by the *Label metrics. `metricValue` is the
// threshold handed to the precisionAtRecall / recallAtPrecision metrics.
//
// Throws std::runtime_error when the label is unknown or the metric is not
// one of the metric_name enumerators.
double Autotune::getMetricScore(
    Meter& meter,
    const metric_name& metricName,
    const double metricValue,
    const std::string& metricLabel) const {
  int32_t labelId = -1;
  if (!metricLabel.empty()) {
    labelId = fastText_->getLabelId(metricLabel);
    if (labelId == -1) {
      throw std::runtime_error("Unknown autotune metric label");
    }
  }

  switch (metricName) {
    case metric_name::f1score:
      return meter.f1Score();
    case metric_name::f1scoreLabel:
      return meter.f1Score(labelId);
    case metric_name::precisionAtRecall:
      return meter.precisionAtRecall(metricValue);
    case metric_name::precisionAtRecallLabel:
      return meter.precisionAtRecall(labelId, metricValue);
    case metric_name::recallAtPrecision:
      return meter.recallAtPrecision(metricValue);
    case metric_name::recallAtPrecisionLabel:
      return meter.recallAtPrecision(labelId, metricValue);
  }
  // Reached only for a value outside the enumerators listed above.
  throw std::runtime_error("Unknown metric");
}
|
307
|
+
|
308
|
+
void Autotune::printArgs(const Args& args, const Args& autotuneArgs) {
|
309
|
+
LOG_VAL(epoch, args.epoch)
|
310
|
+
LOG_VAL(lr, args.lr)
|
311
|
+
LOG_VAL(dim, args.dim)
|
312
|
+
LOG_VAL(minCount, args.minCount)
|
313
|
+
LOG_VAL(wordNgrams, args.wordNgrams)
|
314
|
+
LOG_VAL(minn, args.minn)
|
315
|
+
LOG_VAL(maxn, args.maxn)
|
316
|
+
LOG_VAL(bucket, args.bucket)
|
317
|
+
LOG_VAL(dsub, args.dsub)
|
318
|
+
LOG_VAL(loss, args.lossToString(args.loss))
|
319
|
+
}
|
320
|
+
|
321
|
+
int Autotune::getCutoffForFileSize(
|
322
|
+
bool qout,
|
323
|
+
bool qnorm,
|
324
|
+
int dsub,
|
325
|
+
int64_t fileSize) const {
|
326
|
+
int64_t outModelSize = 0;
|
327
|
+
const int64_t outM = fastText_->getOutputMatrix()->size(0);
|
328
|
+
const int64_t outN = fastText_->getOutputMatrix()->size(1);
|
329
|
+
if (qout) {
|
330
|
+
const int64_t outputPqSize = 16 + 4 * (outN * (1 << 8));
|
331
|
+
outModelSize =
|
332
|
+
21 + (outM * ((outN + 2 - 1) / 2)) + outputPqSize + (qnorm ? outM : 0);
|
333
|
+
} else {
|
334
|
+
outModelSize = 16 + 4 * (outM * outN);
|
335
|
+
}
|
336
|
+
const int64_t dim = fastText_->getInputMatrix()->size(1);
|
337
|
+
|
338
|
+
int target = (fileSize - (107) - 4 * (1 << 8) * dim - outModelSize);
|
339
|
+
int cutoff = target / ((dim + dsub - 1) / dsub + (qnorm ? 1 : 0) + 10);
|
340
|
+
|
341
|
+
return std::max(cutoff, kCutoffLimit);
|
342
|
+
}
|
343
|
+
|
344
|
+
// Quantizes the freshly trained model so that it respects the requested
// autotune model size.
//
// Returns true when no size limit was requested (nothing is done) or when the
// model was quantized. Returns false, without quantizing, when the computed
// cutoff equals kCutoffLimit. On the quantizing path `args` is updated with
// the options used (qnorm, qout, retrain, cutoff).
bool Autotune::quantize(Args& args, const Args& autotuneArgs) {
  const int64_t requestedSize = autotuneArgs.getAutotuneModelSize();
  if (requestedSize == Args::kUnlimitedModelSize) {
    return true;
  }

  const auto outputRows = fastText_->getOutputMatrix()->size(0);

  args.qnorm = true;
  args.qout = (outputRows >= kCutoffLimit);
  args.retrain = true;
  args.cutoff =
      getCutoffForFileSize(args.qout, args.qnorm, args.dsub, requestedSize);
  LOG_VAL(cutoff, args.cutoff);

  const bool sizeReachable = !(args.cutoff == kCutoffLimit);
  if (sizeReachable) {
    fastText_->quantize(args);
  }
  return sizeReachable;
}
|
363
|
+
|
364
|
+
void Autotune::printSkippedArgs(const Args& autotuneArgs) {
|
365
|
+
std::unordered_set<std::string> argsToCheck = {"epoch",
|
366
|
+
"lr",
|
367
|
+
"dim",
|
368
|
+
"wordNgrams",
|
369
|
+
"loss",
|
370
|
+
"bucket",
|
371
|
+
"minn",
|
372
|
+
"maxn",
|
373
|
+
"dsub"};
|
374
|
+
for (const auto& arg : argsToCheck) {
|
375
|
+
if (autotuneArgs.isManual(arg)) {
|
376
|
+
std::cerr << "Warning : " << arg
|
377
|
+
<< " is manually set to a specific value. "
|
378
|
+
<< "It will not be automatically optimized." << std::endl;
|
379
|
+
}
|
380
|
+
}
|
381
|
+
}
|
382
|
+
|
383
|
+
void Autotune::train(const Args& autotuneArgs) {
|
384
|
+
std::ifstream validationFileStream(autotuneArgs.autotuneValidationFile);
|
385
|
+
if (!validationFileStream.is_open()) {
|
386
|
+
throw std::invalid_argument("Validation file cannot be opened!");
|
387
|
+
}
|
388
|
+
printSkippedArgs(autotuneArgs);
|
389
|
+
|
390
|
+
bool sizeConstraintWarning = false;
|
391
|
+
int verbose = autotuneArgs.verbose;
|
392
|
+
Args bestTrainArgs(autotuneArgs);
|
393
|
+
Args trainArgs(autotuneArgs);
|
394
|
+
trainArgs.verbose = 0;
|
395
|
+
strategy_ = std::unique_ptr<AutotuneStrategy>(
|
396
|
+
new AutotuneStrategy(trainArgs, autotuneArgs.seed));
|
397
|
+
startTimer(autotuneArgs);
|
398
|
+
|
399
|
+
while (keepTraining(autotuneArgs.autotuneDuration)) {
|
400
|
+
trials_++;
|
401
|
+
|
402
|
+
trainArgs = strategy_->ask(elapsed_);
|
403
|
+
LOG_VAL(Trial, trials_)
|
404
|
+
printArgs(trainArgs, autotuneArgs);
|
405
|
+
ElapsedTimeMarker elapsedTimeMarker;
|
406
|
+
double currentScore = std::numeric_limits<double>::quiet_NaN();
|
407
|
+
try {
|
408
|
+
fastText_->train(trainArgs);
|
409
|
+
bool sizeConstraintOK = quantize(trainArgs, autotuneArgs);
|
410
|
+
if (sizeConstraintOK) {
|
411
|
+
const auto& metricLabel = autotuneArgs.getAutotuneMetricLabel();
|
412
|
+
Meter meter(!metricLabel.empty());
|
413
|
+
fastText_->test(
|
414
|
+
validationFileStream, autotuneArgs.autotunePredictions, 0.0, meter);
|
415
|
+
|
416
|
+
currentScore = getMetricScore(
|
417
|
+
meter,
|
418
|
+
autotuneArgs.getAutotuneMetric(),
|
419
|
+
autotuneArgs.getAutotuneMetricValue(),
|
420
|
+
metricLabel);
|
421
|
+
|
422
|
+
if (bestScore_ == kUnknownBestScore || (currentScore > bestScore_)) {
|
423
|
+
bestTrainArgs = trainArgs;
|
424
|
+
bestScore_ = currentScore;
|
425
|
+
strategy_->updateBest(bestTrainArgs);
|
426
|
+
}
|
427
|
+
} else {
|
428
|
+
sizeConstraintFailed_++;
|
429
|
+
if (!sizeConstraintWarning && trials_ > 10 &&
|
430
|
+
sizeConstraintFailed_ > (trials_ / 2)) {
|
431
|
+
sizeConstraintWarning = true;
|
432
|
+
std::cerr << std::endl
|
433
|
+
<< "Warning : requested model size is probably too small. "
|
434
|
+
"You may want to increase `autotune-modelsize`."
|
435
|
+
<< std::endl;
|
436
|
+
}
|
437
|
+
}
|
438
|
+
} catch (DenseMatrix::EncounteredNaNError&) {
|
439
|
+
// ignore diverging loss and go on
|
440
|
+
} catch (std::bad_alloc&) {
|
441
|
+
// ignore parameter samples asking too much memory
|
442
|
+
} catch (TimeoutError&) {
|
443
|
+
break;
|
444
|
+
} catch (FastText::AbortError&) {
|
445
|
+
break;
|
446
|
+
}
|
447
|
+
LOG_VAL_NAN(currentScore, currentScore)
|
448
|
+
LOG_VAL(train took, elapsedTimeMarker.getElapsed())
|
449
|
+
}
|
450
|
+
if (timer_.joinable()) {
|
451
|
+
timer_.join();
|
452
|
+
}
|
453
|
+
|
454
|
+
if (bestScore_ == kUnknownBestScore) {
|
455
|
+
std::string errorMessage;
|
456
|
+
if (sizeConstraintWarning) {
|
457
|
+
errorMessage =
|
458
|
+
"Couldn't fulfil model size constraint: please increase "
|
459
|
+
"`autotune-modelsize`.";
|
460
|
+
} else {
|
461
|
+
errorMessage =
|
462
|
+
"Didn't have enough time to train once: please increase "
|
463
|
+
"`autotune-duration`.";
|
464
|
+
}
|
465
|
+
throw std::runtime_error(errorMessage);
|
466
|
+
} else {
|
467
|
+
std::cerr << std::endl;
|
468
|
+
std::cerr << "Training again with best arguments" << std::endl;
|
469
|
+
bestTrainArgs.verbose = verbose;
|
470
|
+
LOG_VAL(Best selected args, 0)
|
471
|
+
printArgs(bestTrainArgs, autotuneArgs);
|
472
|
+
fastText_->train(bestTrainArgs);
|
473
|
+
quantize(bestTrainArgs, autotuneArgs);
|
474
|
+
}
|
475
|
+
}
|
476
|
+
|
477
|
+
} // namespace fasttext
|