fasttext 0.1.2 → 0.1.3
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +18 -8
- data/ext/fasttext/ext.cpp +66 -35
- data/ext/fasttext/extconf.rb +2 -3
- data/lib/fasttext/classifier.rb +13 -3
- data/lib/fasttext/vectorizer.rb +6 -1
- data/lib/fasttext/version.rb +1 -1
- data/vendor/fastText/README.md +3 -3
- data/vendor/fastText/src/args.cc +179 -6
- data/vendor/fastText/src/args.h +29 -1
- data/vendor/fastText/src/autotune.cc +477 -0
- data/vendor/fastText/src/autotune.h +89 -0
- data/vendor/fastText/src/densematrix.cc +27 -7
- data/vendor/fastText/src/densematrix.h +10 -2
- data/vendor/fastText/src/fasttext.cc +125 -114
- data/vendor/fastText/src/fasttext.h +31 -52
- data/vendor/fastText/src/main.cc +32 -13
- data/vendor/fastText/src/meter.cc +148 -2
- data/vendor/fastText/src/meter.h +24 -2
- data/vendor/fastText/src/model.cc +0 -1
- data/vendor/fastText/src/real.h +0 -1
- data/vendor/fastText/src/utils.cc +25 -0
- data/vendor/fastText/src/utils.h +29 -0
- data/vendor/fastText/src/vector.cc +0 -1
- metadata +5 -4
- data/lib/fasttext/ext.bundle +0 -0
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: fb6649e0a3992c6e12572d672e8ba768220662efc37b982c278a8d0713716029
+  data.tar.gz: 12a0441cf1030bfbfe99d26824fe757d5b83c2f47de899fecf13a73aa657bd76
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: dadcbc9a0d39468b4d1070cbf74abece85e687fe2d5de9dfb03c8e79f630b47a0ce37bf73f1383bccf917190bf4f1a2acb12e4741b4aaf45d0c688991baf7893
+  data.tar.gz: b9615f2edf557b7a2bb9c90457581489dd86be7539e7f4e9b75ee72cb3fe26ca9a216c5ecda4ff6e82ef348fbcc0482cd4e727e9c6652f2b38195ad3edba1752
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -2,7 +2,7 @@

 [fastText](https://fasttext.cc) - efficient text classification and representation learning - for Ruby

-[![Build Status](https://travis-ci.org/ankane/fasttext.svg?branch=master)](https://travis-ci.org/ankane/fasttext)
+[![Build Status](https://travis-ci.org/ankane/fasttext.svg?branch=master)](https://travis-ci.org/ankane/fasttext) [![Build status](https://ci.appveyor.com/api/projects/status/67yby3w6mth766y9/branch/master?svg=true)](https://ci.appveyor.com/project/ankane/fasttext/branch/master)

 ## Installation

@@ -77,6 +77,12 @@ model.labels

 > Use `include_freq: true` to get their frequency

+Search for the best hyperparameters
+
+```ruby
+model.fit(x, y, autotune_set: [x_valid, y_valid])
+```
+
 Compress the model - significantly reduces size but sacrifices a little performance

 ```ruby
@@ -168,7 +174,11 @@ FastText::Classifier.new(
   t: 0.0001, # sampling threshold
   label_prefix: "__label__" # label prefix
   verbose: 2, # verbose
-  pretrained_vectors: nil
+  pretrained_vectors: nil, # pretrained word vectors (.vec file)
+  autotune_metric: "f1", # autotune optimization metric
+  autotune_predictions: 1, # autotune predictions
+  autotune_duration: 300, # autotune search time in seconds
+  autotune_model_size: nil # autotune model size, like 2M
 )
 ```

@@ -200,7 +210,7 @@ FastText::Vectorizer.new(
 Input can be read directly from files

 ```ruby
-model.fit("train.txt")
+model.fit("train.txt", autotune_set: "valid.txt")
 model.test("test.txt")
 ```

@@ -260,12 +270,12 @@ Everyone is encouraged to help improve this project. Here are a few ways you can
 - Write, clarify, or fix documentation
 - Suggest or add new features

-To get started with development
+To get started with development:

 ```sh
-git clone https://github.com/ankane/
-cd
+git clone https://github.com/ankane/fastText.git
+cd fastText
 bundle install
-rake compile
-rake test
+bundle exec rake compile
+bundle exec rake test
 ```
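As a quick illustration of the autotune additions documented in the README hunks above, here is a minimal sketch; the labeled data is made up, and the shortened `autotune_duration` is only there so the example finishes quickly (the default is 300 seconds):

```ruby
require "fasttext"

# hypothetical labeled data
x = ["great service", "very disappointing", "would buy again", "not worth it"]
y = ["positive", "negative", "positive", "negative"]
x_valid = ["pleasant experience", "awful quality"]
y_valid = ["positive", "negative"]

model = FastText::Classifier.new(autotune_duration: 60)
model.fit(x, y, autotune_set: [x_valid, y_valid]) # tunes hyperparameters against the validation set
model.predict("a pleasant surprise")
```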
data/ext/fasttext/ext.cpp
CHANGED
@@ -1,18 +1,33 @@
+// stdlib
+#include <cmath>
+#include <iterator>
+#include <sstream>
+#include <stdexcept>
+
+// fasttext
 #include <args.h>
+#include <autotune.h>
 #include <densematrix.h>
 #include <fasttext.h>
-#include <rice/Data_Type.hpp>
-#include <rice/Constructor.hpp>
-#include <rice/Array.hpp>
-#include <rice/Hash.hpp>
 #include <real.h>
 #include <vector.h>
-#include <cmath>
-#include <iterator>
-#include <sstream>
-#include <stdexcept>

-
+// rice
+#include <rice/Array.hpp>
+#include <rice/Constructor.hpp>
+#include <rice/Data_Type.hpp>
+#include <rice/Hash.hpp>
+
+using fasttext::FastText;
+
+using Rice::Array;
+using Rice::Constructor;
+using Rice::Hash;
+using Rice::Module;
+using Rice::Object;
+using Rice::define_class_under;
+using Rice::define_module;
+using Rice::define_module_under;

 template<>
 inline
@@ -104,8 +119,18 @@ fasttext::Args buildArgs(Hash h) {
       a.pretrainedVectors = from_ruby<std::string>(value);
     } else if (name == "save_output") {
       a.saveOutput = from_ruby<bool>(value);
-
-
+    } else if (name == "seed") {
+      a.seed = from_ruby<int>(value);
+    } else if (name == "autotune_validation_file") {
+      a.autotuneValidationFile = from_ruby<std::string>(value);
+    } else if (name == "autotune_metric") {
+      a.autotuneMetric = from_ruby<std::string>(value);
+    } else if (name == "autotune_predictions") {
+      a.autotunePredictions = from_ruby<int>(value);
+    } else if (name == "autotune_duration") {
+      a.autotuneDuration = from_ruby<int>(value);
+    } else if (name == "autotune_model_size") {
+      a.autotuneModelSize = from_ruby<std::string>(value);
     } else {
       throw std::invalid_argument("Unknown argument: " + name);
     }
@@ -119,11 +144,11 @@ void Init_ext()
   Module rb_mFastText = define_module("FastText");
   Module rb_mExt = define_module_under(rb_mFastText, "Ext");

-  define_class_under<
-    .define_constructor(Constructor<
+  define_class_under<FastText>(rb_mExt, "Model")
+    .define_constructor(Constructor<FastText>())
     .define_method(
       "words",
-      *[](
+      *[](FastText& m) {
         std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
         std::vector<int64_t> freq = d->getCounts(fasttext::entry_type::word);

@@ -141,7 +166,7 @@ void Init_ext()
       })
     .define_method(
       "labels",
-      *[](
+      *[](FastText& m) {
         std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
         std::vector<int64_t> freq = d->getCounts(fasttext::entry_type::label);

@@ -159,12 +184,12 @@ void Init_ext()
       })
     .define_method(
       "test",
-      *[](
+      *[](FastText& m, const std::string filename, int32_t k) {
         std::ifstream ifs(filename);
         if (!ifs.is_open()) {
           throw std::invalid_argument("Test file cannot be opened!");
         }
-        fasttext::Meter meter;
+        fasttext::Meter meter(false);
         m.test(ifs, k, 0.0, meter);
         ifs.close();

@@ -176,17 +201,17 @@ void Init_ext()
       })
     .define_method(
       "load_model",
-      *[](
+      *[](FastText& m, std::string s) { m.loadModel(s); })
     .define_method(
       "save_model",
-      *[](
-    .define_method("dimension", &
-    .define_method("quantized?", &
-    .define_method("word_id", &
-    .define_method("subword_id", &
+      *[](FastText& m, std::string s) { m.saveModel(s); })
+    .define_method("dimension", &FastText::getDimension)
+    .define_method("quantized?", &FastText::isQuant)
+    .define_method("word_id", &FastText::getWordId)
+    .define_method("subword_id", &FastText::getSubwordId)
     .define_method(
       "predict",
-      *[](
+      *[](FastText& m, const std::string text, int32_t k, float threshold) {
         std::stringstream ioss(text);
         std::vector<std::pair<fasttext::real, std::string>> predictions;
         m.predictLine(ioss, predictions, k, threshold);
@@ -194,14 +219,14 @@ void Init_ext()
       })
     .define_method(
       "nearest_neighbors",
-      *[](
+      *[](FastText& m, const std::string& word, int32_t k) {
         return m.getNN(word, k);
       })
-    .define_method("analogies", &
-    .define_method("ngram_vectors", &
+    .define_method("analogies", &FastText::getAnalogies)
+    .define_method("ngram_vectors", &FastText::getNgramVectors)
     .define_method(
       "word_vector",
-      *[](
+      *[](FastText& m, const std::string word) {
         int dimension = m.getDimension();
         fasttext::Vector vec = fasttext::Vector(dimension);
         m.getWordVector(vec, word);
@@ -214,7 +239,7 @@ void Init_ext()
       })
     .define_method(
       "subwords",
-      *[](
+      *[](FastText& m, const std::string word) {
         std::vector<std::string> subwords;
         std::vector<int32_t> ngrams;
         std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
@@ -228,7 +253,7 @@ void Init_ext()
       })
     .define_method(
       "sentence_vector",
-      *[](
+      *[](FastText& m, const std::string text) {
         std::istringstream in(text);
         int dimension = m.getDimension();
         fasttext::Vector vec = fasttext::Vector(dimension);
@@ -242,22 +267,28 @@ void Init_ext()
       })
     .define_method(
       "train",
-      *[](
-
+      *[](FastText& m, Hash h) {
+        auto a = buildArgs(h);
+        if (a.hasAutotune()) {
+          fasttext::Autotune autotune(std::shared_ptr<fasttext::FastText>(&m, [](fasttext::FastText*) {}));
+          autotune.train(a);
+        } else {
+          m.train(a);
+        }
       })
     .define_method(
       "quantize",
-      *[](
+      *[](FastText& m, Hash h) {
         m.quantize(buildArgs(h));
       })
     .define_method(
       "supervised?",
-      *[](
+      *[](FastText& m) {
         return m.getArgs().model == fasttext::model_name::sup;
       })
     .define_method(
       "label_prefix",
-      *[](
+      *[](FastText& m) {
         return m.getArgs().label;
       });
 }
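For orientation, the new hash keys handled by `buildArgs` above map one-to-one onto fastText's autotune arguments. A hedged sketch of driving the low-level binding directly; normally `FastText::Classifier` (below) builds this hash for you, the file names are hypothetical, and unspecified keys fall back to fastText's defaults:

```ruby
require "fasttext"

m = FastText::Ext::Model.new
m.train(
  input: "train.txt",                    # hypothetical training file in __label__ format
  model: "supervised",
  autotune_validation_file: "valid.txt", # a non-empty value takes the Autotune branch above
  autotune_metric: "f1",
  autotune_predictions: 1,
  autotune_duration: 120,
  autotune_model_size: ""                # empty string = no size constraint
)
```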
data/ext/fasttext/extconf.rb
CHANGED
@@ -1,9 +1,8 @@
 require "mkmf-rice"

-abort "Missing stdc++" unless have_library("stdc++")
-
 # TODO use -std=c++14 when available
-
+# -pthread and -O3 set by default
+$CXXFLAGS << " -std=c++11 -funroll-loops " << with_config("optflags", "-march=native")

 ext = File.expand_path(".", __dir__)
 fasttext = File.expand_path("../../vendor/fastText/src", __dir__)
data/lib/fasttext/classifier.rb
CHANGED
@@ -21,13 +21,23 @@ module FastText
       verbose: 2,
       pretrained_vectors: "",
       save_output: false,
-
+      seed: 0,
+      autotune_validation_file: "",
+      autotune_metric: "f1",
+      autotune_predictions: 1,
+      autotune_duration: 60 * 5,
+      autotune_model_size: ""
     }

-    def fit(x, y = nil)
+    def fit(x, y = nil, autotune_set: nil)
       input = input_path(x, y)
       @m ||= Ext::Model.new
-
+      opts = DEFAULT_OPTIONS.merge(@options).merge(input: input, model: "supervised")
+      if autotune_set
+        x, y = autotune_set
+        opts.merge!(autotune_validation_file: input_path(x, y))
+      end
+      m.train(opts)
     end

     def predict(text, k: 1, threshold: 0.0)
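Because `autotune_set` is destructured with `x, y = autotune_set`, it accepts either an `[x_valid, y_valid]` pair or a single file path (a lone String leaves `y` as `nil`, and `input_path` then uses the path directly). A file-based sketch with hypothetical file names:

```ruby
model = FastText::Classifier.new(autotune_duration: 120)
model.fit("train.txt", autotune_set: "valid.txt") # both files in fastText __label__ format
model.test("test.txt")
```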
data/lib/fasttext/vectorizer.rb
CHANGED
@@ -20,7 +20,12 @@ module FastText
       verbose: 2,
       pretrained_vectors: "",
       save_output: false,
-
+      seed: 0,
+      autotune_validation_file: "",
+      autotune_metric: "f1",
+      autotune_predictions: 1,
+      autotune_duration: 60 * 5,
+      autotune_model_size: ""
     }

     def fit(x)
data/lib/fasttext/version.rb
CHANGED
data/vendor/fastText/README.md
CHANGED
@@ -89,9 +89,9 @@ There is also the master branch that contains all of our most recent work, but c
 ### Building fastText using make (preferred)

 ```
-$ wget https://github.com/facebookresearch/fastText/archive/v0.9.
-$ unzip v0.9.
-$ cd fastText-0.9.
+$ wget https://github.com/facebookresearch/fastText/archive/v0.9.2.zip
+$ unzip v0.9.2.zip
+$ cd fastText-0.9.2
 $ make
 ```

data/vendor/fastText/src/args.cc
CHANGED
@@ -12,6 +12,8 @@

 #include <iostream>
 #include <stdexcept>
+#include <string>
+#include <unordered_map>

 namespace fasttext {

@@ -36,12 +38,19 @@ Args::Args() {
   verbose = 2;
   pretrainedVectors = "";
   saveOutput = false;
+  seed = 0;

   qout = false;
   retrain = false;
   qnorm = false;
   cutoff = 0;
   dsub = 2;
+
+  autotuneValidationFile = "";
+  autotuneMetric = "f1";
+  autotunePredictions = 1;
+  autotuneDuration = 60 * 5; // 5 minutes
+  autotuneModelSize = "";
 }

 std::string Args::lossToString(loss_name ln) const {
@@ -78,6 +87,24 @@ std::string Args::modelToString(model_name mn) const {
   return "Unknown model name!"; // should never happen
 }

+std::string Args::metricToString(metric_name mn) const {
+  switch (mn) {
+    case metric_name::f1score:
+      return "f1score";
+    case metric_name::f1scoreLabel:
+      return "f1scoreLabel";
+    case metric_name::precisionAtRecall:
+      return "precisionAtRecall";
+    case metric_name::precisionAtRecallLabel:
+      return "precisionAtRecallLabel";
+    case metric_name::recallAtPrecision:
+      return "recallAtPrecision";
+    case metric_name::recallAtPrecisionLabel:
+      return "recallAtPrecisionLabel";
+  }
+  return "Unknown metric name!"; // should never happen
+}
+
 void Args::parseArgs(const std::vector<std::string>& args) {
   std::string command(args[1]);
   if (command == "supervised") {
@@ -97,6 +124,8 @@ void Args::parseArgs(const std::vector<std::string>& args) {
       exit(EXIT_FAILURE);
     }
     try {
+      setManual(args[ai].substr(1));
+
       if (args[ai] == "-h") {
         std::cerr << "Here is the help! Usage:" << std::endl;
         printHelp();
@@ -157,6 +186,8 @@ void Args::parseArgs(const std::vector<std::string>& args) {
       } else if (args[ai] == "-saveOutput") {
         saveOutput = true;
         ai--;
+      } else if (args[ai] == "-seed") {
+        seed = std::stoi(args.at(ai + 1));
       } else if (args[ai] == "-qnorm") {
         qnorm = true;
         ai--;
@@ -170,6 +201,18 @@ void Args::parseArgs(const std::vector<std::string>& args) {
         cutoff = std::stoi(args.at(ai + 1));
       } else if (args[ai] == "-dsub") {
         dsub = std::stoi(args.at(ai + 1));
+      } else if (args[ai] == "-autotune-validation") {
+        autotuneValidationFile = std::string(args.at(ai + 1));
+      } else if (args[ai] == "-autotune-metric") {
+        autotuneMetric = std::string(args.at(ai + 1));
+        getAutotuneMetric(); // throws exception if not able to parse
+        getAutotuneMetricLabel(); // throws exception if not able to parse
+      } else if (args[ai] == "-autotune-predictions") {
+        autotunePredictions = std::stoi(args.at(ai + 1));
+      } else if (args[ai] == "-autotune-duration") {
+        autotuneDuration = std::stoi(args.at(ai + 1));
+      } else if (args[ai] == "-autotune-modelsize") {
+        autotuneModelSize = std::string(args.at(ai + 1));
       } else {
         std::cerr << "Unknown argument: " << args[ai] << std::endl;
         printHelp();
@@ -186,7 +229,7 @@ void Args::parseArgs(const std::vector<std::string>& args) {
     printHelp();
     exit(EXIT_FAILURE);
   }
-  if (wordNgrams <= 1 && maxn == 0) {
+  if (wordNgrams <= 1 && maxn == 0 && !hasAutotune()) {
     bucket = 0;
   }
 }
@@ -195,6 +238,7 @@ void Args::printHelp() {
   printBasicHelp();
   printDictionaryHelp();
   printTrainingHelp();
+  printAutotuneHelp();
   printQuantizationHelp();
 }

@@ -227,7 +271,8 @@ void Args::printTrainingHelp() {
   std::cerr
       << "\nThe following arguments for training are optional:\n"
       << " -lr learning rate [" << lr << "]\n"
-      << " -lrUpdateRate change the rate of updates for the learning
+      << " -lrUpdateRate change the rate of updates for the learning "
+         "rate ["
       << lrUpdateRate << "]\n"
       << " -dim size of word vectors [" << dim << "]\n"
       << " -ws size of the context window [" << ws << "]\n"
@@ -235,11 +280,31 @@ void Args::printTrainingHelp() {
       << " -neg number of negatives sampled [" << neg << "]\n"
       << " -loss loss function {ns, hs, softmax, one-vs-all} ["
       << lossToString(loss) << "]\n"
-      << " -thread number of threads
-
+      << " -thread number of threads (set to 1 to ensure "
+         "reproducible results) ["
+      << thread << "]\n"
+      << " -pretrainedVectors pretrained word vectors for supervised "
+         "learning ["
       << pretrainedVectors << "]\n"
       << " -saveOutput whether output params should be saved ["
-      << boolToString(saveOutput) << "]\n"
+      << boolToString(saveOutput) << "]\n"
+      << " -seed random generator seed [" << seed << "]\n";
+}
+
+void Args::printAutotuneHelp() {
+  std::cerr << "\nThe following arguments are for autotune:\n"
+            << " -autotune-validation validation file to be used "
+               "for evaluation\n"
+            << " -autotune-metric metric objective {f1, "
+               "f1:labelname} ["
+            << autotuneMetric << "]\n"
+            << " -autotune-predictions number of predictions used "
+               "for evaluation ["
+            << autotunePredictions << "]\n"
+            << " -autotune-duration maximum duration in seconds ["
+            << autotuneDuration << "]\n"
+            << " -autotune-modelsize constraint model file size ["
+            << autotuneModelSize << "] (empty = do not quantize)\n";
 }

 void Args::printQuantizationHelp() {
@@ -247,7 +312,8 @@ void Args::printQuantizationHelp() {
       << "\nThe following arguments for quantization are optional:\n"
       << " -cutoff number of words and ngrams to retain ["
       << cutoff << "]\n"
-      << " -retrain whether embeddings are finetuned if a cutoff
+      << " -retrain whether embeddings are finetuned if a cutoff "
+         "is applied ["
       << boolToString(retrain) << "]\n"
       << " -qnorm whether the norm is quantized separately ["
       << boolToString(qnorm) << "]\n"
@@ -317,4 +383,111 @@ void Args::dump(std::ostream& out) const {
       << " " << t << std::endl;
 }

+bool Args::hasAutotune() const {
+  return !autotuneValidationFile.empty();
+}
+
+bool Args::isManual(const std::string& argName) const {
+  return (manualArgs_.count(argName) != 0);
+}
+
+void Args::setManual(const std::string& argName) {
+  manualArgs_.emplace(argName);
+}
+
+metric_name Args::getAutotuneMetric() const {
+  if (autotuneMetric.substr(0, 3) == "f1:") {
+    return metric_name::f1scoreLabel;
+  } else if (autotuneMetric == "f1") {
+    return metric_name::f1score;
+  } else if (autotuneMetric.substr(0, 18) == "precisionAtRecall:") {
+    size_t semicolon = autotuneMetric.find(":", 18);
+    if (semicolon != std::string::npos) {
+      return metric_name::precisionAtRecallLabel;
+    }
+    return metric_name::precisionAtRecall;
+  } else if (autotuneMetric.substr(0, 18) == "recallAtPrecision:") {
+    size_t semicolon = autotuneMetric.find(":", 18);
+    if (semicolon != std::string::npos) {
+      return metric_name::recallAtPrecisionLabel;
+    }
+    return metric_name::recallAtPrecision;
+  }
+  throw std::runtime_error("Unknown metric : " + autotuneMetric);
+}
+
+std::string Args::getAutotuneMetricLabel() const {
+  metric_name metric = getAutotuneMetric();
+  std::string label;
+  if (metric == metric_name::f1scoreLabel) {
+    label = autotuneMetric.substr(3);
+  } else if (
+      metric == metric_name::precisionAtRecallLabel ||
+      metric == metric_name::recallAtPrecisionLabel) {
+    size_t semicolon = autotuneMetric.find(":", 18);
+    label = autotuneMetric.substr(semicolon + 1);
+  } else {
+    return label;
+  }
+
+  if (label.empty()) {
+    throw std::runtime_error("Empty metric label : " + autotuneMetric);
+  }
+  return label;
+}
+
+double Args::getAutotuneMetricValue() const {
+  metric_name metric = getAutotuneMetric();
+  double value = 0.0;
+  if (metric == metric_name::precisionAtRecallLabel ||
+      metric == metric_name::precisionAtRecall ||
+      metric == metric_name::recallAtPrecisionLabel ||
+      metric == metric_name::recallAtPrecision) {
+    size_t firstSemicolon = 18; // semicolon position in "precisionAtRecall:"
+    size_t secondSemicolon = autotuneMetric.find(":", firstSemicolon);
+    const std::string valueStr =
+        autotuneMetric.substr(firstSemicolon, secondSemicolon - firstSemicolon);
+    value = std::stof(valueStr) / 100.0;
+  }
+  return value;
+}
+
+int64_t Args::getAutotuneModelSize() const {
+  std::string modelSize = autotuneModelSize;
+  if (modelSize.empty()) {
+    return Args::kUnlimitedModelSize;
+  }
+  std::unordered_map<char, int> units = {
+      {'k', 1000},
+      {'K', 1000},
+      {'m', 1000000},
+      {'M', 1000000},
+      {'g', 1000000000},
+      {'G', 1000000000},
+  };
+  uint64_t multiplier = 1;
+  char lastCharacter = modelSize.back();
+  if (units.count(lastCharacter)) {
+    multiplier = units[lastCharacter];
+    modelSize = modelSize.substr(0, modelSize.size() - 1);
+  }
+  uint64_t size = 0;
+  size_t nonNumericCharacter = 0;
+  bool parseError = false;
+  try {
+    size = std::stol(modelSize, &nonNumericCharacter);
+  } catch (std::invalid_argument&) {
+    parseError = true;
+  }
+  if (!parseError && nonNumericCharacter != modelSize.size()) {
+    parseError = true;
+  }
+  if (parseError) {
+    throw std::invalid_argument(
+        "Unable to parse model size " + autotuneModelSize);
+  }
+
+  return size * multiplier;
+}
+
 } // namespace fasttext
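To make the autotune-modelsize convention concrete, here is an illustrative Ruby sketch of the unit-suffix rule that Args::getAutotuneModelSize applies; the helper name is made up, and the real parsing is the C++ above, which also treats an empty value as "unlimited, do not quantize":

```ruby
# Illustrative only: mirrors the k/K, m/M, g/G handling in Args::getAutotuneModelSize.
UNITS = { "k" => 1_000, "m" => 1_000_000, "g" => 1_000_000_000 }.freeze

def parse_model_size(str)
  return :unlimited if str.empty?        # empty string means "do not quantize"
  multiplier = UNITS[str[-1].downcase]   # nil when there is no unit suffix
  digits = multiplier ? str[0..-2] : str
  raise ArgumentError, "Unable to parse model size #{str}" unless digits.match?(/\A\d+\z/)
  digits.to_i * (multiplier || 1)
end

parse_model_size("2M")   # => 2_000_000
parse_model_size("500K") # => 500_000
```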