fasttext 0.1.2 → 0.2.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +19 -0
- data/LICENSE.txt +18 -18
- data/README.md +26 -19
- data/ext/fasttext/ext.cpp +131 -134
- data/ext/fasttext/extconf.rb +2 -4
- data/lib/fasttext/classifier.rb +23 -10
- data/lib/fasttext/model.rb +10 -0
- data/lib/fasttext/vectorizer.rb +11 -5
- data/lib/fasttext/version.rb +1 -1
- data/vendor/fastText/README.md +3 -3
- data/vendor/fastText/src/args.cc +179 -6
- data/vendor/fastText/src/args.h +29 -1
- data/vendor/fastText/src/autotune.cc +477 -0
- data/vendor/fastText/src/autotune.h +89 -0
- data/vendor/fastText/src/densematrix.cc +27 -7
- data/vendor/fastText/src/densematrix.h +10 -2
- data/vendor/fastText/src/fasttext.cc +125 -114
- data/vendor/fastText/src/fasttext.h +31 -52
- data/vendor/fastText/src/main.cc +32 -13
- data/vendor/fastText/src/meter.cc +148 -2
- data/vendor/fastText/src/meter.h +24 -2
- data/vendor/fastText/src/model.cc +0 -1
- data/vendor/fastText/src/real.h +0 -1
- data/vendor/fastText/src/utils.cc +25 -0
- data/vendor/fastText/src/utils.h +29 -0
- data/vendor/fastText/src/vector.cc +0 -1
- metadata +14 -69
- data/lib/fasttext/ext.bundle +0 -0
data/lib/fasttext/vectorizer.rb
CHANGED
@@ -20,13 +20,19 @@ module FastText
|
|
20
20
|
verbose: 2,
|
21
21
|
pretrained_vectors: "",
|
22
22
|
save_output: false,
|
23
|
-
|
23
|
+
seed: 0,
|
24
|
+
autotune_validation_file: "",
|
25
|
+
autotune_metric: "f1",
|
26
|
+
autotune_predictions: 1,
|
27
|
+
autotune_duration: 60 * 5,
|
28
|
+
autotune_model_size: ""
|
24
29
|
}
|
25
30
|
|
26
31
|
def fit(x)
|
27
|
-
input = input_path(x)
|
28
32
|
@m ||= Ext::Model.new
|
29
|
-
|
33
|
+
a = build_args(DEFAULT_OPTIONS)
|
34
|
+
a.input, _ref = input_path(x)
|
35
|
+
m.train(a)
|
30
36
|
end
|
31
37
|
|
32
38
|
def nearest_neighbors(word, k: 10)
|
@@ -43,7 +49,7 @@ module FastText
|
|
43
49
|
# https://github.com/facebookresearch/fastText/issues/518
|
44
50
|
def input_path(x)
|
45
51
|
if x.is_a?(String)
|
46
|
-
x
|
52
|
+
[x, nil]
|
47
53
|
else
|
48
54
|
tempfile = Tempfile.new("fasttext")
|
49
55
|
x.each do |xi|
|
@@ -51,7 +57,7 @@ module FastText
|
|
51
57
|
tempfile.write("\n")
|
52
58
|
end
|
53
59
|
tempfile.close
|
54
|
-
tempfile.path
|
60
|
+
[tempfile.path, tempfile]
|
55
61
|
end
|
56
62
|
end
|
57
63
|
end
|
data/lib/fasttext/version.rb
CHANGED
data/vendor/fastText/README.md
CHANGED
@@ -89,9 +89,9 @@ There is also the master branch that contains all of our most recent work, but c
|
|
89
89
|
### Building fastText using make (preferred)
|
90
90
|
|
91
91
|
```
|
92
|
-
$ wget https://github.com/facebookresearch/fastText/archive/v0.9.1.zip
|
93
|
-
$ unzip v0.9.1.zip
|
94
|
-
$ cd fastText-0.9.1
|
92
|
+
$ wget https://github.com/facebookresearch/fastText/archive/v0.9.2.zip
|
93
|
+
$ unzip v0.9.2.zip
|
94
|
+
$ cd fastText-0.9.2
|
95
95
|
$ make
|
96
96
|
```
|
97
97
|
|
data/vendor/fastText/src/args.cc
CHANGED
@@ -12,6 +12,8 @@
|
|
12
12
|
|
13
13
|
#include <iostream>
|
14
14
|
#include <stdexcept>
|
15
|
+
#include <string>
|
16
|
+
#include <unordered_map>
|
15
17
|
|
16
18
|
namespace fasttext {
|
17
19
|
|
@@ -36,12 +38,19 @@ Args::Args() {
|
|
36
38
|
verbose = 2;
|
37
39
|
pretrainedVectors = "";
|
38
40
|
saveOutput = false;
|
41
|
+
seed = 0;
|
39
42
|
|
40
43
|
qout = false;
|
41
44
|
retrain = false;
|
42
45
|
qnorm = false;
|
43
46
|
cutoff = 0;
|
44
47
|
dsub = 2;
|
48
|
+
|
49
|
+
autotuneValidationFile = "";
|
50
|
+
autotuneMetric = "f1";
|
51
|
+
autotunePredictions = 1;
|
52
|
+
autotuneDuration = 60 * 5; // 5 minutes
|
53
|
+
autotuneModelSize = "";
|
45
54
|
}
|
46
55
|
|
47
56
|
std::string Args::lossToString(loss_name ln) const {
|
@@ -78,6 +87,24 @@ std::string Args::modelToString(model_name mn) const {
|
|
78
87
|
return "Unknown model name!"; // should never happen
|
79
88
|
}
|
80
89
|
|
90
|
+
std::string Args::metricToString(metric_name mn) const {
|
91
|
+
switch (mn) {
|
92
|
+
case metric_name::f1score:
|
93
|
+
return "f1score";
|
94
|
+
case metric_name::f1scoreLabel:
|
95
|
+
return "f1scoreLabel";
|
96
|
+
case metric_name::precisionAtRecall:
|
97
|
+
return "precisionAtRecall";
|
98
|
+
case metric_name::precisionAtRecallLabel:
|
99
|
+
return "precisionAtRecallLabel";
|
100
|
+
case metric_name::recallAtPrecision:
|
101
|
+
return "recallAtPrecision";
|
102
|
+
case metric_name::recallAtPrecisionLabel:
|
103
|
+
return "recallAtPrecisionLabel";
|
104
|
+
}
|
105
|
+
return "Unknown metric name!"; // should never happen
|
106
|
+
}
|
107
|
+
|
81
108
|
void Args::parseArgs(const std::vector<std::string>& args) {
|
82
109
|
std::string command(args[1]);
|
83
110
|
if (command == "supervised") {
|
@@ -97,6 +124,8 @@ void Args::parseArgs(const std::vector<std::string>& args) {
|
|
97
124
|
exit(EXIT_FAILURE);
|
98
125
|
}
|
99
126
|
try {
|
127
|
+
setManual(args[ai].substr(1));
|
128
|
+
|
100
129
|
if (args[ai] == "-h") {
|
101
130
|
std::cerr << "Here is the help! Usage:" << std::endl;
|
102
131
|
printHelp();
|
@@ -157,6 +186,8 @@ void Args::parseArgs(const std::vector<std::string>& args) {
|
|
157
186
|
} else if (args[ai] == "-saveOutput") {
|
158
187
|
saveOutput = true;
|
159
188
|
ai--;
|
189
|
+
} else if (args[ai] == "-seed") {
|
190
|
+
seed = std::stoi(args.at(ai + 1));
|
160
191
|
} else if (args[ai] == "-qnorm") {
|
161
192
|
qnorm = true;
|
162
193
|
ai--;
|
@@ -170,6 +201,18 @@ void Args::parseArgs(const std::vector<std::string>& args) {
|
|
170
201
|
cutoff = std::stoi(args.at(ai + 1));
|
171
202
|
} else if (args[ai] == "-dsub") {
|
172
203
|
dsub = std::stoi(args.at(ai + 1));
|
204
|
+
} else if (args[ai] == "-autotune-validation") {
|
205
|
+
autotuneValidationFile = std::string(args.at(ai + 1));
|
206
|
+
} else if (args[ai] == "-autotune-metric") {
|
207
|
+
autotuneMetric = std::string(args.at(ai + 1));
|
208
|
+
getAutotuneMetric(); // throws exception if not able to parse
|
209
|
+
getAutotuneMetricLabel(); // throws exception if not able to parse
|
210
|
+
} else if (args[ai] == "-autotune-predictions") {
|
211
|
+
autotunePredictions = std::stoi(args.at(ai + 1));
|
212
|
+
} else if (args[ai] == "-autotune-duration") {
|
213
|
+
autotuneDuration = std::stoi(args.at(ai + 1));
|
214
|
+
} else if (args[ai] == "-autotune-modelsize") {
|
215
|
+
autotuneModelSize = std::string(args.at(ai + 1));
|
173
216
|
} else {
|
174
217
|
std::cerr << "Unknown argument: " << args[ai] << std::endl;
|
175
218
|
printHelp();
|
@@ -186,7 +229,7 @@ void Args::parseArgs(const std::vector<std::string>& args) {
|
|
186
229
|
printHelp();
|
187
230
|
exit(EXIT_FAILURE);
|
188
231
|
}
|
189
|
-
if (wordNgrams <= 1 && maxn == 0) {
|
232
|
+
if (wordNgrams <= 1 && maxn == 0 && !hasAutotune()) {
|
190
233
|
bucket = 0;
|
191
234
|
}
|
192
235
|
}
|
@@ -195,6 +238,7 @@ void Args::printHelp() {
|
|
195
238
|
printBasicHelp();
|
196
239
|
printDictionaryHelp();
|
197
240
|
printTrainingHelp();
|
241
|
+
printAutotuneHelp();
|
198
242
|
printQuantizationHelp();
|
199
243
|
}
|
200
244
|
|
@@ -227,7 +271,8 @@ void Args::printTrainingHelp() {
|
|
227
271
|
std::cerr
|
228
272
|
<< "\nThe following arguments for training are optional:\n"
|
229
273
|
<< " -lr learning rate [" << lr << "]\n"
|
230
|
-
<< " -lrUpdateRate change the rate of updates for the learning rate ["
|
274
|
+
<< " -lrUpdateRate change the rate of updates for the learning "
|
275
|
+
"rate ["
|
231
276
|
<< lrUpdateRate << "]\n"
|
232
277
|
<< " -dim size of word vectors [" << dim << "]\n"
|
233
278
|
<< " -ws size of the context window [" << ws << "]\n"
|
@@ -235,11 +280,31 @@ void Args::printTrainingHelp() {
|
|
235
280
|
<< " -neg number of negatives sampled [" << neg << "]\n"
|
236
281
|
<< " -loss loss function {ns, hs, softmax, one-vs-all} ["
|
237
282
|
<< lossToString(loss) << "]\n"
|
238
|
-
<< " -thread number of threads [" << thread << "]\n"
|
239
|
-
<< " -pretrainedVectors pretrained word vectors for supervised learning ["
|
283
|
+
<< " -thread number of threads (set to 1 to ensure "
|
284
|
+
"reproducible results) ["
|
285
|
+
<< thread << "]\n"
|
286
|
+
<< " -pretrainedVectors pretrained word vectors for supervised "
|
287
|
+
"learning ["
|
240
288
|
<< pretrainedVectors << "]\n"
|
241
289
|
<< " -saveOutput whether output params should be saved ["
|
242
|
-
<< boolToString(saveOutput) << "]\n";
|
290
|
+
<< boolToString(saveOutput) << "]\n"
|
291
|
+
<< " -seed random generator seed [" << seed << "]\n";
|
292
|
+
}
|
293
|
+
|
294
|
+
void Args::printAutotuneHelp() {
|
295
|
+
std::cerr << "\nThe following arguments are for autotune:\n"
|
296
|
+
<< " -autotune-validation validation file to be used "
|
297
|
+
"for evaluation\n"
|
298
|
+
<< " -autotune-metric metric objective {f1, "
|
299
|
+
"f1:labelname} ["
|
300
|
+
<< autotuneMetric << "]\n"
|
301
|
+
<< " -autotune-predictions number of predictions used "
|
302
|
+
"for evaluation ["
|
303
|
+
<< autotunePredictions << "]\n"
|
304
|
+
<< " -autotune-duration maximum duration in seconds ["
|
305
|
+
<< autotuneDuration << "]\n"
|
306
|
+
<< " -autotune-modelsize constraint model file size ["
|
307
|
+
<< autotuneModelSize << "] (empty = do not quantize)\n";
|
243
308
|
}
|
244
309
|
|
245
310
|
void Args::printQuantizationHelp() {
|
@@ -247,7 +312,8 @@ void Args::printQuantizationHelp() {
|
|
247
312
|
<< "\nThe following arguments for quantization are optional:\n"
|
248
313
|
<< " -cutoff number of words and ngrams to retain ["
|
249
314
|
<< cutoff << "]\n"
|
250
|
-
<< " -retrain whether embeddings are finetuned if a cutoff is applied ["
|
315
|
+
<< " -retrain whether embeddings are finetuned if a cutoff "
|
316
|
+
"is applied ["
|
251
317
|
<< boolToString(retrain) << "]\n"
|
252
318
|
<< " -qnorm whether the norm is quantized separately ["
|
253
319
|
<< boolToString(qnorm) << "]\n"
|
@@ -317,4 +383,111 @@ void Args::dump(std::ostream& out) const {
|
|
317
383
|
<< " " << t << std::endl;
|
318
384
|
}
|
319
385
|
|
386
|
+
bool Args::hasAutotune() const {
|
387
|
+
return !autotuneValidationFile.empty();
|
388
|
+
}
|
389
|
+
|
390
|
+
bool Args::isManual(const std::string& argName) const {
|
391
|
+
return (manualArgs_.count(argName) != 0);
|
392
|
+
}
|
393
|
+
|
394
|
+
void Args::setManual(const std::string& argName) {
|
395
|
+
manualArgs_.emplace(argName);
|
396
|
+
}
|
397
|
+
|
398
|
+
metric_name Args::getAutotuneMetric() const {
|
399
|
+
if (autotuneMetric.substr(0, 3) == "f1:") {
|
400
|
+
return metric_name::f1scoreLabel;
|
401
|
+
} else if (autotuneMetric == "f1") {
|
402
|
+
return metric_name::f1score;
|
403
|
+
} else if (autotuneMetric.substr(0, 18) == "precisionAtRecall:") {
|
404
|
+
size_t semicolon = autotuneMetric.find(":", 18);
|
405
|
+
if (semicolon != std::string::npos) {
|
406
|
+
return metric_name::precisionAtRecallLabel;
|
407
|
+
}
|
408
|
+
return metric_name::precisionAtRecall;
|
409
|
+
} else if (autotuneMetric.substr(0, 18) == "recallAtPrecision:") {
|
410
|
+
size_t semicolon = autotuneMetric.find(":", 18);
|
411
|
+
if (semicolon != std::string::npos) {
|
412
|
+
return metric_name::recallAtPrecisionLabel;
|
413
|
+
}
|
414
|
+
return metric_name::recallAtPrecision;
|
415
|
+
}
|
416
|
+
throw std::runtime_error("Unknown metric : " + autotuneMetric);
|
417
|
+
}
|
418
|
+
|
419
|
+
std::string Args::getAutotuneMetricLabel() const {
|
420
|
+
metric_name metric = getAutotuneMetric();
|
421
|
+
std::string label;
|
422
|
+
if (metric == metric_name::f1scoreLabel) {
|
423
|
+
label = autotuneMetric.substr(3);
|
424
|
+
} else if (
|
425
|
+
metric == metric_name::precisionAtRecallLabel ||
|
426
|
+
metric == metric_name::recallAtPrecisionLabel) {
|
427
|
+
size_t semicolon = autotuneMetric.find(":", 18);
|
428
|
+
label = autotuneMetric.substr(semicolon + 1);
|
429
|
+
} else {
|
430
|
+
return label;
|
431
|
+
}
|
432
|
+
|
433
|
+
if (label.empty()) {
|
434
|
+
throw std::runtime_error("Empty metric label : " + autotuneMetric);
|
435
|
+
}
|
436
|
+
return label;
|
437
|
+
}
|
438
|
+
|
439
|
+
double Args::getAutotuneMetricValue() const {
|
440
|
+
metric_name metric = getAutotuneMetric();
|
441
|
+
double value = 0.0;
|
442
|
+
if (metric == metric_name::precisionAtRecallLabel ||
|
443
|
+
metric == metric_name::precisionAtRecall ||
|
444
|
+
metric == metric_name::recallAtPrecisionLabel ||
|
445
|
+
metric == metric_name::recallAtPrecision) {
|
446
|
+
size_t firstSemicolon = 18; // semicolon position in "precisionAtRecall:"
|
447
|
+
size_t secondSemicolon = autotuneMetric.find(":", firstSemicolon);
|
448
|
+
const std::string valueStr =
|
449
|
+
autotuneMetric.substr(firstSemicolon, secondSemicolon - firstSemicolon);
|
450
|
+
value = std::stof(valueStr) / 100.0;
|
451
|
+
}
|
452
|
+
return value;
|
453
|
+
}
|
454
|
+
|
455
|
+
int64_t Args::getAutotuneModelSize() const {
|
456
|
+
std::string modelSize = autotuneModelSize;
|
457
|
+
if (modelSize.empty()) {
|
458
|
+
return Args::kUnlimitedModelSize;
|
459
|
+
}
|
460
|
+
std::unordered_map<char, int> units = {
|
461
|
+
{'k', 1000},
|
462
|
+
{'K', 1000},
|
463
|
+
{'m', 1000000},
|
464
|
+
{'M', 1000000},
|
465
|
+
{'g', 1000000000},
|
466
|
+
{'G', 1000000000},
|
467
|
+
};
|
468
|
+
uint64_t multiplier = 1;
|
469
|
+
char lastCharacter = modelSize.back();
|
470
|
+
if (units.count(lastCharacter)) {
|
471
|
+
multiplier = units[lastCharacter];
|
472
|
+
modelSize = modelSize.substr(0, modelSize.size() - 1);
|
473
|
+
}
|
474
|
+
uint64_t size = 0;
|
475
|
+
size_t nonNumericCharacter = 0;
|
476
|
+
bool parseError = false;
|
477
|
+
try {
|
478
|
+
size = std::stol(modelSize, &nonNumericCharacter);
|
479
|
+
} catch (std::invalid_argument&) {
|
480
|
+
parseError = true;
|
481
|
+
}
|
482
|
+
if (!parseError && nonNumericCharacter != modelSize.size()) {
|
483
|
+
parseError = true;
|
484
|
+
}
|
485
|
+
if (parseError) {
|
486
|
+
throw std::invalid_argument(
|
487
|
+
"Unable to parse model size " + autotuneModelSize);
|
488
|
+
}
|
489
|
+
|
490
|
+
return size * multiplier;
|
491
|
+
}
|
492
|
+
|
320
493
|
} // namespace fasttext
|
data/vendor/fastText/src/args.h
CHANGED
@@ -11,18 +11,28 @@
|
|
11
11
|
#include <istream>
|
12
12
|
#include <ostream>
|
13
13
|
#include <string>
|
14
|
+
#include <unordered_set>
|
14
15
|
#include <vector>
|
15
16
|
|
16
17
|
namespace fasttext {
|
17
18
|
|
18
19
|
enum class model_name : int { cbow = 1, sg, sup };
|
19
20
|
enum class loss_name : int { hs = 1, ns, softmax, ova };
|
21
|
+
enum class metric_name : int {
|
22
|
+
f1score = 1,
|
23
|
+
f1scoreLabel,
|
24
|
+
precisionAtRecall,
|
25
|
+
precisionAtRecallLabel,
|
26
|
+
recallAtPrecision,
|
27
|
+
recallAtPrecisionLabel
|
28
|
+
};
|
20
29
|
|
21
30
|
class Args {
|
22
31
|
protected:
|
23
|
-
std::string lossToString(loss_name) const;
|
24
32
|
std::string boolToString(bool) const;
|
25
33
|
std::string modelToString(model_name) const;
|
34
|
+
std::string metricToString(metric_name) const;
|
35
|
+
std::unordered_set<std::string> manualArgs_;
|
26
36
|
|
27
37
|
public:
|
28
38
|
Args();
|
@@ -48,6 +58,7 @@ class Args {
|
|
48
58
|
int verbose;
|
49
59
|
std::string pretrainedVectors;
|
50
60
|
bool saveOutput;
|
61
|
+
int seed;
|
51
62
|
|
52
63
|
bool qout;
|
53
64
|
bool retrain;
|
@@ -55,14 +66,31 @@ class Args {
|
|
55
66
|
size_t cutoff;
|
56
67
|
size_t dsub;
|
57
68
|
|
69
|
+
std::string autotuneValidationFile;
|
70
|
+
std::string autotuneMetric;
|
71
|
+
int autotunePredictions;
|
72
|
+
int autotuneDuration;
|
73
|
+
std::string autotuneModelSize;
|
74
|
+
|
58
75
|
void parseArgs(const std::vector<std::string>& args);
|
59
76
|
void printHelp();
|
60
77
|
void printBasicHelp();
|
61
78
|
void printDictionaryHelp();
|
62
79
|
void printTrainingHelp();
|
80
|
+
void printAutotuneHelp();
|
63
81
|
void printQuantizationHelp();
|
64
82
|
void save(std::ostream&);
|
65
83
|
void load(std::istream&);
|
66
84
|
void dump(std::ostream&) const;
|
85
|
+
bool hasAutotune() const;
|
86
|
+
bool isManual(const std::string& argName) const;
|
87
|
+
void setManual(const std::string& argName);
|
88
|
+
std::string lossToString(loss_name) const;
|
89
|
+
metric_name getAutotuneMetric() const;
|
90
|
+
std::string getAutotuneMetricLabel() const;
|
91
|
+
double getAutotuneMetricValue() const;
|
92
|
+
int64_t getAutotuneModelSize() const;
|
93
|
+
|
94
|
+
static constexpr double kUnlimitedModelSize = -1.0;
|
67
95
|
};
|
68
96
|
} // namespace fasttext
|