fasttext 0.1.2 → 0.1.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +18 -8
- data/ext/fasttext/ext.cpp +66 -35
- data/ext/fasttext/extconf.rb +2 -3
- data/lib/fasttext/classifier.rb +13 -3
- data/lib/fasttext/vectorizer.rb +6 -1
- data/lib/fasttext/version.rb +1 -1
- data/vendor/fastText/README.md +3 -3
- data/vendor/fastText/src/args.cc +179 -6
- data/vendor/fastText/src/args.h +29 -1
- data/vendor/fastText/src/autotune.cc +477 -0
- data/vendor/fastText/src/autotune.h +89 -0
- data/vendor/fastText/src/densematrix.cc +27 -7
- data/vendor/fastText/src/densematrix.h +10 -2
- data/vendor/fastText/src/fasttext.cc +125 -114
- data/vendor/fastText/src/fasttext.h +31 -52
- data/vendor/fastText/src/main.cc +32 -13
- data/vendor/fastText/src/meter.cc +148 -2
- data/vendor/fastText/src/meter.h +24 -2
- data/vendor/fastText/src/model.cc +0 -1
- data/vendor/fastText/src/real.h +0 -1
- data/vendor/fastText/src/utils.cc +25 -0
- data/vendor/fastText/src/utils.h +29 -0
- data/vendor/fastText/src/vector.cc +0 -1
- metadata +5 -4
- data/lib/fasttext/ext.bundle +0 -0
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: fb6649e0a3992c6e12572d672e8ba768220662efc37b982c278a8d0713716029
|
4
|
+
data.tar.gz: 12a0441cf1030bfbfe99d26824fe757d5b83c2f47de899fecf13a73aa657bd76
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: dadcbc9a0d39468b4d1070cbf74abece85e687fe2d5de9dfb03c8e79f630b47a0ce37bf73f1383bccf917190bf4f1a2acb12e4741b4aaf45d0c688991baf7893
|
7
|
+
data.tar.gz: b9615f2edf557b7a2bb9c90457581489dd86be7539e7f4e9b75ee72cb3fe26ca9a216c5ecda4ff6e82ef348fbcc0482cd4e727e9c6652f2b38195ad3edba1752
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -2,7 +2,7 @@
|
|
2
2
|
|
3
3
|
[fastText](https://fasttext.cc) - efficient text classification and representation learning - for Ruby
|
4
4
|
|
5
|
-
[](https://travis-ci.org/ankane/fasttext)
|
5
|
+
[](https://travis-ci.org/ankane/fasttext) [](https://ci.appveyor.com/project/ankane/fasttext/branch/master)
|
6
6
|
|
7
7
|
## Installation
|
8
8
|
|
@@ -77,6 +77,12 @@ model.labels
|
|
77
77
|
|
78
78
|
> Use `include_freq: true` to get their frequency
|
79
79
|
|
80
|
+
Search for the best hyperparameters
|
81
|
+
|
82
|
+
```ruby
|
83
|
+
model.fit(x, y, autotune_set: [x_valid, y_valid])
|
84
|
+
```
|
85
|
+
|
80
86
|
Compress the model - significantly reduces size but sacrifices a little performance
|
81
87
|
|
82
88
|
```ruby
|
@@ -168,7 +174,11 @@ FastText::Classifier.new(
|
|
168
174
|
t: 0.0001, # sampling threshold
|
169
175
|
label_prefix: "__label__" # label prefix
|
170
176
|
verbose: 2, # verbose
|
171
|
-
pretrained_vectors: nil
|
177
|
+
pretrained_vectors: nil, # pretrained word vectors (.vec file)
|
178
|
+
autotune_metric: "f1", # autotune optimization metric
|
179
|
+
autotune_predictions: 1, # autotune predictions
|
180
|
+
autotune_duration: 300, # autotune search time in seconds
|
181
|
+
autotune_model_size: nil # autotune model size, like 2M
|
172
182
|
)
|
173
183
|
```
|
174
184
|
|
@@ -200,7 +210,7 @@ FastText::Vectorizer.new(
|
|
200
210
|
Input can be read directly from files
|
201
211
|
|
202
212
|
```ruby
|
203
|
-
model.fit("train.txt")
|
213
|
+
model.fit("train.txt", autotune_set: "valid.txt")
|
204
214
|
model.test("test.txt")
|
205
215
|
```
|
206
216
|
|
@@ -260,12 +270,12 @@ Everyone is encouraged to help improve this project. Here are a few ways you can
|
|
260
270
|
- Write, clarify, or fix documentation
|
261
271
|
- Suggest or add new features
|
262
272
|
|
263
|
-
To get started with development
|
273
|
+
To get started with development:
|
264
274
|
|
265
275
|
```sh
|
266
|
-
git clone https://github.com/ankane/
|
267
|
-
cd
|
276
|
+
git clone https://github.com/ankane/fastText.git
|
277
|
+
cd fastText
|
268
278
|
bundle install
|
269
|
-
rake compile
|
270
|
-
rake test
|
279
|
+
bundle exec rake compile
|
280
|
+
bundle exec rake test
|
271
281
|
```
|
data/ext/fasttext/ext.cpp
CHANGED
@@ -1,18 +1,33 @@
|
|
1
|
+
// stdlib
|
2
|
+
#include <cmath>
|
3
|
+
#include <iterator>
|
4
|
+
#include <sstream>
|
5
|
+
#include <stdexcept>
|
6
|
+
|
7
|
+
// fasttext
|
1
8
|
#include <args.h>
|
9
|
+
#include <autotune.h>
|
2
10
|
#include <densematrix.h>
|
3
11
|
#include <fasttext.h>
|
4
|
-
#include <rice/Data_Type.hpp>
|
5
|
-
#include <rice/Constructor.hpp>
|
6
|
-
#include <rice/Array.hpp>
|
7
|
-
#include <rice/Hash.hpp>
|
8
12
|
#include <real.h>
|
9
13
|
#include <vector.h>
|
10
|
-
#include <cmath>
|
11
|
-
#include <iterator>
|
12
|
-
#include <sstream>
|
13
|
-
#include <stdexcept>
|
14
14
|
|
15
|
-
|
15
|
+
// rice
|
16
|
+
#include <rice/Array.hpp>
|
17
|
+
#include <rice/Constructor.hpp>
|
18
|
+
#include <rice/Data_Type.hpp>
|
19
|
+
#include <rice/Hash.hpp>
|
20
|
+
|
21
|
+
using fasttext::FastText;
|
22
|
+
|
23
|
+
using Rice::Array;
|
24
|
+
using Rice::Constructor;
|
25
|
+
using Rice::Hash;
|
26
|
+
using Rice::Module;
|
27
|
+
using Rice::Object;
|
28
|
+
using Rice::define_class_under;
|
29
|
+
using Rice::define_module;
|
30
|
+
using Rice::define_module_under;
|
16
31
|
|
17
32
|
template<>
|
18
33
|
inline
|
@@ -104,8 +119,18 @@ fasttext::Args buildArgs(Hash h) {
|
|
104
119
|
a.pretrainedVectors = from_ruby<std::string>(value);
|
105
120
|
} else if (name == "save_output") {
|
106
121
|
a.saveOutput = from_ruby<bool>(value);
|
107
|
-
|
108
|
-
|
122
|
+
} else if (name == "seed") {
|
123
|
+
a.seed = from_ruby<int>(value);
|
124
|
+
} else if (name == "autotune_validation_file") {
|
125
|
+
a.autotuneValidationFile = from_ruby<std::string>(value);
|
126
|
+
} else if (name == "autotune_metric") {
|
127
|
+
a.autotuneMetric = from_ruby<std::string>(value);
|
128
|
+
} else if (name == "autotune_predictions") {
|
129
|
+
a.autotunePredictions = from_ruby<int>(value);
|
130
|
+
} else if (name == "autotune_duration") {
|
131
|
+
a.autotuneDuration = from_ruby<int>(value);
|
132
|
+
} else if (name == "autotune_model_size") {
|
133
|
+
a.autotuneModelSize = from_ruby<std::string>(value);
|
109
134
|
} else {
|
110
135
|
throw std::invalid_argument("Unknown argument: " + name);
|
111
136
|
}
|
@@ -119,11 +144,11 @@ void Init_ext()
|
|
119
144
|
Module rb_mFastText = define_module("FastText");
|
120
145
|
Module rb_mExt = define_module_under(rb_mFastText, "Ext");
|
121
146
|
|
122
|
-
define_class_under<
|
123
|
-
.define_constructor(Constructor<
|
147
|
+
define_class_under<FastText>(rb_mExt, "Model")
|
148
|
+
.define_constructor(Constructor<FastText>())
|
124
149
|
.define_method(
|
125
150
|
"words",
|
126
|
-
*[](
|
151
|
+
*[](FastText& m) {
|
127
152
|
std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
|
128
153
|
std::vector<int64_t> freq = d->getCounts(fasttext::entry_type::word);
|
129
154
|
|
@@ -141,7 +166,7 @@ void Init_ext()
|
|
141
166
|
})
|
142
167
|
.define_method(
|
143
168
|
"labels",
|
144
|
-
*[](
|
169
|
+
*[](FastText& m) {
|
145
170
|
std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
|
146
171
|
std::vector<int64_t> freq = d->getCounts(fasttext::entry_type::label);
|
147
172
|
|
@@ -159,12 +184,12 @@ void Init_ext()
|
|
159
184
|
})
|
160
185
|
.define_method(
|
161
186
|
"test",
|
162
|
-
*[](
|
187
|
+
*[](FastText& m, const std::string filename, int32_t k) {
|
163
188
|
std::ifstream ifs(filename);
|
164
189
|
if (!ifs.is_open()) {
|
165
190
|
throw std::invalid_argument("Test file cannot be opened!");
|
166
191
|
}
|
167
|
-
fasttext::Meter meter;
|
192
|
+
fasttext::Meter meter(false);
|
168
193
|
m.test(ifs, k, 0.0, meter);
|
169
194
|
ifs.close();
|
170
195
|
|
@@ -176,17 +201,17 @@ void Init_ext()
|
|
176
201
|
})
|
177
202
|
.define_method(
|
178
203
|
"load_model",
|
179
|
-
*[](
|
204
|
+
*[](FastText& m, std::string s) { m.loadModel(s); })
|
180
205
|
.define_method(
|
181
206
|
"save_model",
|
182
|
-
*[](
|
183
|
-
.define_method("dimension", &
|
184
|
-
.define_method("quantized?", &
|
185
|
-
.define_method("word_id", &
|
186
|
-
.define_method("subword_id", &
|
207
|
+
*[](FastText& m, std::string s) { m.saveModel(s); })
|
208
|
+
.define_method("dimension", &FastText::getDimension)
|
209
|
+
.define_method("quantized?", &FastText::isQuant)
|
210
|
+
.define_method("word_id", &FastText::getWordId)
|
211
|
+
.define_method("subword_id", &FastText::getSubwordId)
|
187
212
|
.define_method(
|
188
213
|
"predict",
|
189
|
-
*[](
|
214
|
+
*[](FastText& m, const std::string text, int32_t k, float threshold) {
|
190
215
|
std::stringstream ioss(text);
|
191
216
|
std::vector<std::pair<fasttext::real, std::string>> predictions;
|
192
217
|
m.predictLine(ioss, predictions, k, threshold);
|
@@ -194,14 +219,14 @@ void Init_ext()
|
|
194
219
|
})
|
195
220
|
.define_method(
|
196
221
|
"nearest_neighbors",
|
197
|
-
*[](
|
222
|
+
*[](FastText& m, const std::string& word, int32_t k) {
|
198
223
|
return m.getNN(word, k);
|
199
224
|
})
|
200
|
-
.define_method("analogies", &
|
201
|
-
.define_method("ngram_vectors", &
|
225
|
+
.define_method("analogies", &FastText::getAnalogies)
|
226
|
+
.define_method("ngram_vectors", &FastText::getNgramVectors)
|
202
227
|
.define_method(
|
203
228
|
"word_vector",
|
204
|
-
*[](
|
229
|
+
*[](FastText& m, const std::string word) {
|
205
230
|
int dimension = m.getDimension();
|
206
231
|
fasttext::Vector vec = fasttext::Vector(dimension);
|
207
232
|
m.getWordVector(vec, word);
|
@@ -214,7 +239,7 @@ void Init_ext()
|
|
214
239
|
})
|
215
240
|
.define_method(
|
216
241
|
"subwords",
|
217
|
-
*[](
|
242
|
+
*[](FastText& m, const std::string word) {
|
218
243
|
std::vector<std::string> subwords;
|
219
244
|
std::vector<int32_t> ngrams;
|
220
245
|
std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
|
@@ -228,7 +253,7 @@ void Init_ext()
|
|
228
253
|
})
|
229
254
|
.define_method(
|
230
255
|
"sentence_vector",
|
231
|
-
*[](
|
256
|
+
*[](FastText& m, const std::string text) {
|
232
257
|
std::istringstream in(text);
|
233
258
|
int dimension = m.getDimension();
|
234
259
|
fasttext::Vector vec = fasttext::Vector(dimension);
|
@@ -242,22 +267,28 @@ void Init_ext()
|
|
242
267
|
})
|
243
268
|
.define_method(
|
244
269
|
"train",
|
245
|
-
*[](
|
246
|
-
|
270
|
+
*[](FastText& m, Hash h) {
|
271
|
+
auto a = buildArgs(h);
|
272
|
+
if (a.hasAutotune()) {
|
273
|
+
fasttext::Autotune autotune(std::shared_ptr<fasttext::FastText>(&m, [](fasttext::FastText*) {}));
|
274
|
+
autotune.train(a);
|
275
|
+
} else {
|
276
|
+
m.train(a);
|
277
|
+
}
|
247
278
|
})
|
248
279
|
.define_method(
|
249
280
|
"quantize",
|
250
|
-
*[](
|
281
|
+
*[](FastText& m, Hash h) {
|
251
282
|
m.quantize(buildArgs(h));
|
252
283
|
})
|
253
284
|
.define_method(
|
254
285
|
"supervised?",
|
255
|
-
*[](
|
286
|
+
*[](FastText& m) {
|
256
287
|
return m.getArgs().model == fasttext::model_name::sup;
|
257
288
|
})
|
258
289
|
.define_method(
|
259
290
|
"label_prefix",
|
260
|
-
*[](
|
291
|
+
*[](FastText& m) {
|
261
292
|
return m.getArgs().label;
|
262
293
|
});
|
263
294
|
}
|
data/ext/fasttext/extconf.rb
CHANGED
@@ -1,9 +1,8 @@
|
|
1
1
|
require "mkmf-rice"
|
2
2
|
|
3
|
-
abort "Missing stdc++" unless have_library("stdc++")
|
4
|
-
|
5
3
|
# TODO use -std=c++14 when available
|
6
|
-
|
4
|
+
# -pthread and -O3 set by default
|
5
|
+
$CXXFLAGS << " -std=c++11 -funroll-loops " << with_config("optflags", "-march=native")
|
7
6
|
|
8
7
|
ext = File.expand_path(".", __dir__)
|
9
8
|
fasttext = File.expand_path("../../vendor/fastText/src", __dir__)
|
data/lib/fasttext/classifier.rb
CHANGED
@@ -21,13 +21,23 @@ module FastText
|
|
21
21
|
verbose: 2,
|
22
22
|
pretrained_vectors: "",
|
23
23
|
save_output: false,
|
24
|
-
|
24
|
+
seed: 0,
|
25
|
+
autotune_validation_file: "",
|
26
|
+
autotune_metric: "f1",
|
27
|
+
autotune_predictions: 1,
|
28
|
+
autotune_duration: 60 * 5,
|
29
|
+
autotune_model_size: ""
|
25
30
|
}
|
26
31
|
|
27
|
-
def fit(x, y = nil)
|
32
|
+
# Trains a supervised classifier.
#
# x, y         - training data; both are resolved to a native-readable input by
#                input_path (see its definition for the accepted shapes).
# autotune_set - optional validation data, resolved the same way via
#                input_path. When truthy, it is passed to the native trainer as
#                autotune_validation_file, which switches training into
#                hyperparameter-search mode.
def fit(x, y = nil, autotune_set: nil)
  training_input = input_path(x, y)
  @m ||= Ext::Model.new

  train_opts = DEFAULT_OPTIONS.merge(@options)
  train_opts[:input] = training_input
  train_opts[:model] = "supervised"

  if autotune_set
    valid_x, valid_y = autotune_set
    train_opts[:autotune_validation_file] = input_path(valid_x, valid_y)
  end

  m.train(train_opts)
end
|
32
42
|
|
33
43
|
def predict(text, k: 1, threshold: 0.0)
|
data/lib/fasttext/vectorizer.rb
CHANGED
@@ -20,7 +20,12 @@ module FastText
|
|
20
20
|
verbose: 2,
|
21
21
|
pretrained_vectors: "",
|
22
22
|
save_output: false,
|
23
|
-
|
23
|
+
seed: 0,
|
24
|
+
autotune_validation_file: "",
|
25
|
+
autotune_metric: "f1",
|
26
|
+
autotune_predictions: 1,
|
27
|
+
autotune_duration: 60 * 5,
|
28
|
+
autotune_model_size: ""
|
24
29
|
}
|
25
30
|
|
26
31
|
def fit(x)
|
data/lib/fasttext/version.rb
CHANGED
data/vendor/fastText/README.md
CHANGED
@@ -89,9 +89,9 @@ There is also the master branch that contains all of our most recent work, but c
|
|
89
89
|
### Building fastText using make (preferred)
|
90
90
|
|
91
91
|
```
|
92
|
-
$ wget https://github.com/facebookresearch/fastText/archive/v0.9.
|
93
|
-
$ unzip v0.9.
|
94
|
-
$ cd fastText-0.9.
|
92
|
+
$ wget https://github.com/facebookresearch/fastText/archive/v0.9.2.zip
|
93
|
+
$ unzip v0.9.2.zip
|
94
|
+
$ cd fastText-0.9.2
|
95
95
|
$ make
|
96
96
|
```
|
97
97
|
|
data/vendor/fastText/src/args.cc
CHANGED
@@ -12,6 +12,8 @@
|
|
12
12
|
|
13
13
|
#include <iostream>
|
14
14
|
#include <stdexcept>
|
15
|
+
#include <string>
|
16
|
+
#include <unordered_map>
|
15
17
|
|
16
18
|
namespace fasttext {
|
17
19
|
|
@@ -36,12 +38,19 @@ Args::Args() {
|
|
36
38
|
verbose = 2;
|
37
39
|
pretrainedVectors = "";
|
38
40
|
saveOutput = false;
|
41
|
+
seed = 0;
|
39
42
|
|
40
43
|
qout = false;
|
41
44
|
retrain = false;
|
42
45
|
qnorm = false;
|
43
46
|
cutoff = 0;
|
44
47
|
dsub = 2;
|
48
|
+
|
49
|
+
autotuneValidationFile = "";
|
50
|
+
autotuneMetric = "f1";
|
51
|
+
autotunePredictions = 1;
|
52
|
+
autotuneDuration = 60 * 5; // 5 minutes
|
53
|
+
autotuneModelSize = "";
|
45
54
|
}
|
46
55
|
|
47
56
|
std::string Args::lossToString(loss_name ln) const {
|
@@ -78,6 +87,24 @@ std::string Args::modelToString(model_name mn) const {
|
|
78
87
|
return "Unknown model name!"; // should never happen
|
79
88
|
}
|
80
89
|
|
90
|
+
// Returns the canonical spelling of an autotune metric enumerator.
// A value outside the enumeration yields the sentinel "Unknown metric name!"
// (should never happen for a well-formed metric_name).
std::string Args::metricToString(metric_name mn) const {
  const char* name = "Unknown metric name!"; // fallback when no case matches
  switch (mn) {
    case metric_name::f1score:
      name = "f1score";
      break;
    case metric_name::f1scoreLabel:
      name = "f1scoreLabel";
      break;
    case metric_name::precisionAtRecall:
      name = "precisionAtRecall";
      break;
    case metric_name::precisionAtRecallLabel:
      name = "precisionAtRecallLabel";
      break;
    case metric_name::recallAtPrecision:
      name = "recallAtPrecision";
      break;
    case metric_name::recallAtPrecisionLabel:
      name = "recallAtPrecisionLabel";
      break;
  }
  return name;
}
|
107
|
+
|
81
108
|
void Args::parseArgs(const std::vector<std::string>& args) {
|
82
109
|
std::string command(args[1]);
|
83
110
|
if (command == "supervised") {
|
@@ -97,6 +124,8 @@ void Args::parseArgs(const std::vector<std::string>& args) {
|
|
97
124
|
exit(EXIT_FAILURE);
|
98
125
|
}
|
99
126
|
try {
|
127
|
+
setManual(args[ai].substr(1));
|
128
|
+
|
100
129
|
if (args[ai] == "-h") {
|
101
130
|
std::cerr << "Here is the help! Usage:" << std::endl;
|
102
131
|
printHelp();
|
@@ -157,6 +186,8 @@ void Args::parseArgs(const std::vector<std::string>& args) {
|
|
157
186
|
} else if (args[ai] == "-saveOutput") {
|
158
187
|
saveOutput = true;
|
159
188
|
ai--;
|
189
|
+
} else if (args[ai] == "-seed") {
|
190
|
+
seed = std::stoi(args.at(ai + 1));
|
160
191
|
} else if (args[ai] == "-qnorm") {
|
161
192
|
qnorm = true;
|
162
193
|
ai--;
|
@@ -170,6 +201,18 @@ void Args::parseArgs(const std::vector<std::string>& args) {
|
|
170
201
|
cutoff = std::stoi(args.at(ai + 1));
|
171
202
|
} else if (args[ai] == "-dsub") {
|
172
203
|
dsub = std::stoi(args.at(ai + 1));
|
204
|
+
} else if (args[ai] == "-autotune-validation") {
|
205
|
+
autotuneValidationFile = std::string(args.at(ai + 1));
|
206
|
+
} else if (args[ai] == "-autotune-metric") {
|
207
|
+
autotuneMetric = std::string(args.at(ai + 1));
|
208
|
+
getAutotuneMetric(); // throws exception if not able to parse
|
209
|
+
getAutotuneMetricLabel(); // throws exception if not able to parse
|
210
|
+
} else if (args[ai] == "-autotune-predictions") {
|
211
|
+
autotunePredictions = std::stoi(args.at(ai + 1));
|
212
|
+
} else if (args[ai] == "-autotune-duration") {
|
213
|
+
autotuneDuration = std::stoi(args.at(ai + 1));
|
214
|
+
} else if (args[ai] == "-autotune-modelsize") {
|
215
|
+
autotuneModelSize = std::string(args.at(ai + 1));
|
173
216
|
} else {
|
174
217
|
std::cerr << "Unknown argument: " << args[ai] << std::endl;
|
175
218
|
printHelp();
|
@@ -186,7 +229,7 @@ void Args::parseArgs(const std::vector<std::string>& args) {
|
|
186
229
|
printHelp();
|
187
230
|
exit(EXIT_FAILURE);
|
188
231
|
}
|
189
|
-
if (wordNgrams <= 1 && maxn == 0) {
|
232
|
+
if (wordNgrams <= 1 && maxn == 0 && !hasAutotune()) {
|
190
233
|
bucket = 0;
|
191
234
|
}
|
192
235
|
}
|
@@ -195,6 +238,7 @@ void Args::printHelp() {
|
|
195
238
|
printBasicHelp();
|
196
239
|
printDictionaryHelp();
|
197
240
|
printTrainingHelp();
|
241
|
+
printAutotuneHelp();
|
198
242
|
printQuantizationHelp();
|
199
243
|
}
|
200
244
|
|
@@ -227,7 +271,8 @@ void Args::printTrainingHelp() {
|
|
227
271
|
std::cerr
|
228
272
|
<< "\nThe following arguments for training are optional:\n"
|
229
273
|
<< " -lr learning rate [" << lr << "]\n"
|
230
|
-
<< " -lrUpdateRate change the rate of updates for the learning
|
274
|
+
<< " -lrUpdateRate change the rate of updates for the learning "
|
275
|
+
"rate ["
|
231
276
|
<< lrUpdateRate << "]\n"
|
232
277
|
<< " -dim size of word vectors [" << dim << "]\n"
|
233
278
|
<< " -ws size of the context window [" << ws << "]\n"
|
@@ -235,11 +280,31 @@ void Args::printTrainingHelp() {
|
|
235
280
|
<< " -neg number of negatives sampled [" << neg << "]\n"
|
236
281
|
<< " -loss loss function {ns, hs, softmax, one-vs-all} ["
|
237
282
|
<< lossToString(loss) << "]\n"
|
238
|
-
<< " -thread number of threads
|
239
|
-
|
283
|
+
<< " -thread number of threads (set to 1 to ensure "
|
284
|
+
"reproducible results) ["
|
285
|
+
<< thread << "]\n"
|
286
|
+
<< " -pretrainedVectors pretrained word vectors for supervised "
|
287
|
+
"learning ["
|
240
288
|
<< pretrainedVectors << "]\n"
|
241
289
|
<< " -saveOutput whether output params should be saved ["
|
242
|
-
<< boolToString(saveOutput) << "]\n"
|
290
|
+
<< boolToString(saveOutput) << "]\n"
|
291
|
+
<< " -seed random generator seed [" << seed << "]\n";
|
292
|
+
}
|
293
|
+
|
294
|
+
void Args::printAutotuneHelp() {
|
295
|
+
std::cerr << "\nThe following arguments are for autotune:\n"
|
296
|
+
<< " -autotune-validation validation file to be used "
|
297
|
+
"for evaluation\n"
|
298
|
+
<< " -autotune-metric metric objective {f1, "
|
299
|
+
"f1:labelname} ["
|
300
|
+
<< autotuneMetric << "]\n"
|
301
|
+
<< " -autotune-predictions number of predictions used "
|
302
|
+
"for evaluation ["
|
303
|
+
<< autotunePredictions << "]\n"
|
304
|
+
<< " -autotune-duration maximum duration in seconds ["
|
305
|
+
<< autotuneDuration << "]\n"
|
306
|
+
<< " -autotune-modelsize constraint model file size ["
|
307
|
+
<< autotuneModelSize << "] (empty = do not quantize)\n";
|
243
308
|
}
|
244
309
|
|
245
310
|
void Args::printQuantizationHelp() {
|
@@ -247,7 +312,8 @@ void Args::printQuantizationHelp() {
|
|
247
312
|
<< "\nThe following arguments for quantization are optional:\n"
|
248
313
|
<< " -cutoff number of words and ngrams to retain ["
|
249
314
|
<< cutoff << "]\n"
|
250
|
-
<< " -retrain whether embeddings are finetuned if a cutoff
|
315
|
+
<< " -retrain whether embeddings are finetuned if a cutoff "
|
316
|
+
"is applied ["
|
251
317
|
<< boolToString(retrain) << "]\n"
|
252
318
|
<< " -qnorm whether the norm is quantized separately ["
|
253
319
|
<< boolToString(qnorm) << "]\n"
|
@@ -317,4 +383,111 @@ void Args::dump(std::ostream& out) const {
|
|
317
383
|
<< " " << t << std::endl;
|
318
384
|
}
|
319
385
|
|
386
|
+
bool Args::hasAutotune() const {
|
387
|
+
return !autotuneValidationFile.empty();
|
388
|
+
}
|
389
|
+
|
390
|
+
bool Args::isManual(const std::string& argName) const {
|
391
|
+
return (manualArgs_.count(argName) != 0);
|
392
|
+
}
|
393
|
+
|
394
|
+
// Marks argName as explicitly set by the user.
// manualArgs_ is presumably a std::unordered_set<std::string> (see args.h),
// so recording the same name twice has no further effect.
void Args::setManual(const std::string& argName) {
  manualArgs_.insert(argName);
}
|
397
|
+
|
398
|
+
// Classifies autotuneMetric. Recognised spellings:
//   "f1"                                  -> f1score
//   "f1:<label>"                          -> f1scoreLabel
//   "precisionAtRecall:<value>"           -> precisionAtRecall
//   "precisionAtRecall:<value>:<label>"   -> precisionAtRecallLabel
//   "recallAtPrecision:<value>"           -> recallAtPrecision
//   "recallAtPrecision:<value>:<label>"   -> recallAtPrecisionLabel
// Any other string throws std::runtime_error.
metric_name Args::getAutotuneMetric() const {
  const std::string& spec = autotuneMetric;

  if (spec.compare(0, 3, "f1:") == 0) {
    return metric_name::f1scoreLabel;
  }
  if (spec == "f1") {
    return metric_name::f1score;
  }

  // Both long prefixes are 18 characters, so a label (if present) follows
  // a second ':' found at or after index 18.
  if (spec.compare(0, 18, "precisionAtRecall:") == 0) {
    const bool hasLabel = spec.find(":", 18) != std::string::npos;
    return hasLabel ? metric_name::precisionAtRecallLabel
                    : metric_name::precisionAtRecall;
  }
  if (spec.compare(0, 18, "recallAtPrecision:") == 0) {
    const bool hasLabel = spec.find(":", 18) != std::string::npos;
    return hasLabel ? metric_name::recallAtPrecisionLabel
                    : metric_name::recallAtPrecision;
  }

  throw std::runtime_error("Unknown metric : " + autotuneMetric);
}
|
418
|
+
|
419
|
+
std::string Args::getAutotuneMetricLabel() const {
|
420
|
+
metric_name metric = getAutotuneMetric();
|
421
|
+
std::string label;
|
422
|
+
if (metric == metric_name::f1scoreLabel) {
|
423
|
+
label = autotuneMetric.substr(3);
|
424
|
+
} else if (
|
425
|
+
metric == metric_name::precisionAtRecallLabel ||
|
426
|
+
metric == metric_name::recallAtPrecisionLabel) {
|
427
|
+
size_t semicolon = autotuneMetric.find(":", 18);
|
428
|
+
label = autotuneMetric.substr(semicolon + 1);
|
429
|
+
} else {
|
430
|
+
return label;
|
431
|
+
}
|
432
|
+
|
433
|
+
if (label.empty()) {
|
434
|
+
throw std::runtime_error("Empty metric label : " + autotuneMetric);
|
435
|
+
}
|
436
|
+
return label;
|
437
|
+
}
|
438
|
+
|
439
|
+
double Args::getAutotuneMetricValue() const {
|
440
|
+
metric_name metric = getAutotuneMetric();
|
441
|
+
double value = 0.0;
|
442
|
+
if (metric == metric_name::precisionAtRecallLabel ||
|
443
|
+
metric == metric_name::precisionAtRecall ||
|
444
|
+
metric == metric_name::recallAtPrecisionLabel ||
|
445
|
+
metric == metric_name::recallAtPrecision) {
|
446
|
+
size_t firstSemicolon = 18; // semicolon position in "precisionAtRecall:"
|
447
|
+
size_t secondSemicolon = autotuneMetric.find(":", firstSemicolon);
|
448
|
+
const std::string valueStr =
|
449
|
+
autotuneMetric.substr(firstSemicolon, secondSemicolon - firstSemicolon);
|
450
|
+
value = std::stof(valueStr) / 100.0;
|
451
|
+
}
|
452
|
+
return value;
|
453
|
+
}
|
454
|
+
|
455
|
+
int64_t Args::getAutotuneModelSize() const {
|
456
|
+
std::string modelSize = autotuneModelSize;
|
457
|
+
if (modelSize.empty()) {
|
458
|
+
return Args::kUnlimitedModelSize;
|
459
|
+
}
|
460
|
+
std::unordered_map<char, int> units = {
|
461
|
+
{'k', 1000},
|
462
|
+
{'K', 1000},
|
463
|
+
{'m', 1000000},
|
464
|
+
{'M', 1000000},
|
465
|
+
{'g', 1000000000},
|
466
|
+
{'G', 1000000000},
|
467
|
+
};
|
468
|
+
uint64_t multiplier = 1;
|
469
|
+
char lastCharacter = modelSize.back();
|
470
|
+
if (units.count(lastCharacter)) {
|
471
|
+
multiplier = units[lastCharacter];
|
472
|
+
modelSize = modelSize.substr(0, modelSize.size() - 1);
|
473
|
+
}
|
474
|
+
uint64_t size = 0;
|
475
|
+
size_t nonNumericCharacter = 0;
|
476
|
+
bool parseError = false;
|
477
|
+
try {
|
478
|
+
size = std::stol(modelSize, &nonNumericCharacter);
|
479
|
+
} catch (std::invalid_argument&) {
|
480
|
+
parseError = true;
|
481
|
+
}
|
482
|
+
if (!parseError && nonNumericCharacter != modelSize.size()) {
|
483
|
+
parseError = true;
|
484
|
+
}
|
485
|
+
if (parseError) {
|
486
|
+
throw std::invalid_argument(
|
487
|
+
"Unable to parse model size " + autotuneModelSize);
|
488
|
+
}
|
489
|
+
|
490
|
+
return size * multiplier;
|
491
|
+
}
|
492
|
+
|
320
493
|
} // namespace fasttext
|