fasttext 0.1.2 → 0.1.3

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: f83be8c01c6a45a90758ccee430b3898396bcfdda5a2c338126ed9dc3620aea5
- data.tar.gz: 7e3dee8eb3afe12745f78448fd01ac68a2f0ac946bd01195f2f6e6081f62fbad
+ metadata.gz: fb6649e0a3992c6e12572d672e8ba768220662efc37b982c278a8d0713716029
+ data.tar.gz: 12a0441cf1030bfbfe99d26824fe757d5b83c2f47de899fecf13a73aa657bd76
  SHA512:
- metadata.gz: be3117e1aceed3f6126fc1d84eb87caf53abb3be802a419ebcaa284cb567ee9ac033d5860842690bbd1c0477a6e5013dfc36eb96f6fb067a63015150fa18a1fe
- data.tar.gz: dc2467f3f7317b5e1955ede144d9ad50c5abb1cc2b9dc5ad356350f192631b0142c91aecc4dd03b331d90f1cf65c7f3ba3ec176c17dc6b66a1226171261de1b6
+ metadata.gz: dadcbc9a0d39468b4d1070cbf74abece85e687fe2d5de9dfb03c8e79f630b47a0ce37bf73f1383bccf917190bf4f1a2acb12e4741b4aaf45d0c688991baf7893
+ data.tar.gz: b9615f2edf557b7a2bb9c90457581489dd86be7539e7f4e9b75ee72cb3fe26ca9a216c5ecda4ff6e82ef348fbcc0482cd4e727e9c6652f2b38195ad3edba1752
@@ -1,3 +1,9 @@
+ ## 0.1.3 (2020-04-28)
+
+ - Updated fastText to 0.9.2
+ - Added support for autotune
+ - Added `--with-optflags` option
+
  ## 0.1.2 (2020-01-10)

  - Fixed installation error with Ruby 2.7
data/README.md CHANGED
@@ -2,7 +2,7 @@

  [fastText](https://fasttext.cc) - efficient text classification and representation learning - for Ruby

- [![Build Status](https://travis-ci.org/ankane/fasttext.svg?branch=master)](https://travis-ci.org/ankane/fasttext)
+ [![Build Status](https://travis-ci.org/ankane/fasttext.svg?branch=master)](https://travis-ci.org/ankane/fasttext) [![Build status](https://ci.appveyor.com/api/projects/status/67yby3w6mth766y9/branch/master?svg=true)](https://ci.appveyor.com/project/ankane/fasttext/branch/master)

  ## Installation

@@ -77,6 +77,12 @@ model.labels

  > Use `include_freq: true` to get their frequency

+ Search for the best hyperparameters
+
+ ```ruby
+ model.fit(x, y, autotune_set: [x_valid, y_valid])
+ ```
+
  Compress the model - significantly reduces size but sacrifices a little performance

  ```ruby
@@ -168,7 +174,11 @@ FastText::Classifier.new(
  t: 0.0001, # sampling threshold
  label_prefix: "__label__", # label prefix
  verbose: 2, # verbose
- pretrained_vectors: nil # pretrained word vectors (.vec file)
+ pretrained_vectors: nil, # pretrained word vectors (.vec file)
+ autotune_metric: "f1", # autotune optimization metric
+ autotune_predictions: 1, # autotune predictions
+ autotune_duration: 300, # autotune search time in seconds
+ autotune_model_size: nil # autotune model size, like 2M
  )
  ```

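For illustration, a minimal sketch of how the new autotune options above combine with `fit`; the training and validation data below are placeholders, not taken from the gem's docs.

```ruby
require "fasttext"

# placeholder data for illustration only
x = ["great value", "would not buy again"]
y = ["positive", "negative"]
x_valid = ["really happy with it"]
y_valid = ["positive"]

model = FastText::Classifier.new(
  autotune_duration: 120,    # stop the hyperparameter search after 2 minutes
  autotune_model_size: "2M"  # also constrain (quantize) the tuned model to roughly 2 MB
)
model.fit(x, y, autotune_set: [x_valid, y_valid])
```

Autotune only runs when a validation set (or file) is supplied; per the fastText help text further down, `autotune_metric` also accepts an `f1:labelname` form for optimizing a single label.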
@@ -200,7 +210,7 @@ FastText::Vectorizer.new(
  Input can be read directly from files

  ```ruby
- model.fit("train.txt")
+ model.fit("train.txt", autotune_set: "valid.txt")
  model.test("test.txt")
  ```

@@ -260,12 +270,12 @@ Everyone is encouraged to help improve this project. Here are a few ways you can
  - Write, clarify, or fix documentation
  - Suggest or add new features

- To get started with development and testing:
+ To get started with development:

  ```sh
- git clone https://github.com/ankane/fasttext.git
- cd fasttext
+ git clone https://github.com/ankane/fastText.git
+ cd fastText
  bundle install
- rake compile
- rake test
+ bundle exec rake compile
+ bundle exec rake test
  ```
@@ -1,18 +1,33 @@
+ // stdlib
+ #include <cmath>
+ #include <iterator>
+ #include <sstream>
+ #include <stdexcept>
+
+ // fasttext
  #include <args.h>
+ #include <autotune.h>
  #include <densematrix.h>
  #include <fasttext.h>
- #include <rice/Data_Type.hpp>
- #include <rice/Constructor.hpp>
- #include <rice/Array.hpp>
- #include <rice/Hash.hpp>
  #include <real.h>
  #include <vector.h>
- #include <cmath>
- #include <iterator>
- #include <sstream>
- #include <stdexcept>

- using namespace Rice;
+ // rice
+ #include <rice/Array.hpp>
+ #include <rice/Constructor.hpp>
+ #include <rice/Data_Type.hpp>
+ #include <rice/Hash.hpp>
+
+ using fasttext::FastText;
+
+ using Rice::Array;
+ using Rice::Constructor;
+ using Rice::Hash;
+ using Rice::Module;
+ using Rice::Object;
+ using Rice::define_class_under;
+ using Rice::define_module;
+ using Rice::define_module_under;

  template<>
  inline
@@ -104,8 +119,18 @@ fasttext::Args buildArgs(Hash h) {
  a.pretrainedVectors = from_ruby<std::string>(value);
  } else if (name == "save_output") {
  a.saveOutput = from_ruby<bool>(value);
- // } else if (name == "seed") {
- // a.seed = from_ruby<int>(value);
+ } else if (name == "seed") {
+ a.seed = from_ruby<int>(value);
+ } else if (name == "autotune_validation_file") {
+ a.autotuneValidationFile = from_ruby<std::string>(value);
+ } else if (name == "autotune_metric") {
+ a.autotuneMetric = from_ruby<std::string>(value);
+ } else if (name == "autotune_predictions") {
+ a.autotunePredictions = from_ruby<int>(value);
+ } else if (name == "autotune_duration") {
+ a.autotuneDuration = from_ruby<int>(value);
+ } else if (name == "autotune_model_size") {
+ a.autotuneModelSize = from_ruby<std::string>(value);
  } else {
  throw std::invalid_argument("Unknown argument: " + name);
  }
@@ -119,11 +144,11 @@ void Init_ext()
  Module rb_mFastText = define_module("FastText");
  Module rb_mExt = define_module_under(rb_mFastText, "Ext");

- define_class_under<fasttext::FastText>(rb_mExt, "Model")
- .define_constructor(Constructor<fasttext::FastText>())
+ define_class_under<FastText>(rb_mExt, "Model")
+ .define_constructor(Constructor<FastText>())
  .define_method(
  "words",
- *[](fasttext::FastText& m) {
+ *[](FastText& m) {
  std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  std::vector<int64_t> freq = d->getCounts(fasttext::entry_type::word);

@@ -141,7 +166,7 @@ void Init_ext()
  })
  .define_method(
  "labels",
- *[](fasttext::FastText& m) {
+ *[](FastText& m) {
  std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  std::vector<int64_t> freq = d->getCounts(fasttext::entry_type::label);

@@ -159,12 +184,12 @@ void Init_ext()
  })
  .define_method(
  "test",
- *[](fasttext::FastText& m, const std::string filename, int32_t k) {
+ *[](FastText& m, const std::string filename, int32_t k) {
  std::ifstream ifs(filename);
  if (!ifs.is_open()) {
  throw std::invalid_argument("Test file cannot be opened!");
  }
- fasttext::Meter meter;
+ fasttext::Meter meter(false);
  m.test(ifs, k, 0.0, meter);
  ifs.close();

@@ -176,17 +201,17 @@ void Init_ext()
  })
  .define_method(
  "load_model",
- *[](fasttext::FastText& m, std::string s) { m.loadModel(s); })
+ *[](FastText& m, std::string s) { m.loadModel(s); })
  .define_method(
  "save_model",
- *[](fasttext::FastText& m, std::string s) { m.saveModel(s); })
- .define_method("dimension", &fasttext::FastText::getDimension)
- .define_method("quantized?", &fasttext::FastText::isQuant)
- .define_method("word_id", &fasttext::FastText::getWordId)
- .define_method("subword_id", &fasttext::FastText::getSubwordId)
+ *[](FastText& m, std::string s) { m.saveModel(s); })
+ .define_method("dimension", &FastText::getDimension)
+ .define_method("quantized?", &FastText::isQuant)
+ .define_method("word_id", &FastText::getWordId)
+ .define_method("subword_id", &FastText::getSubwordId)
  .define_method(
  "predict",
- *[](fasttext::FastText& m, const std::string text, int32_t k, float threshold) {
+ *[](FastText& m, const std::string text, int32_t k, float threshold) {
  std::stringstream ioss(text);
  std::vector<std::pair<fasttext::real, std::string>> predictions;
  m.predictLine(ioss, predictions, k, threshold);
@@ -194,14 +219,14 @@ void Init_ext()
  })
  .define_method(
  "nearest_neighbors",
- *[](fasttext::FastText& m, const std::string& word, int32_t k) {
+ *[](FastText& m, const std::string& word, int32_t k) {
  return m.getNN(word, k);
  })
- .define_method("analogies", &fasttext::FastText::getAnalogies)
- .define_method("ngram_vectors", &fasttext::FastText::getNgramVectors)
+ .define_method("analogies", &FastText::getAnalogies)
+ .define_method("ngram_vectors", &FastText::getNgramVectors)
  .define_method(
  "word_vector",
- *[](fasttext::FastText& m, const std::string word) {
+ *[](FastText& m, const std::string word) {
  int dimension = m.getDimension();
  fasttext::Vector vec = fasttext::Vector(dimension);
  m.getWordVector(vec, word);
@@ -214,7 +239,7 @@ void Init_ext()
  })
  .define_method(
  "subwords",
- *[](fasttext::FastText& m, const std::string word) {
+ *[](FastText& m, const std::string word) {
  std::vector<std::string> subwords;
  std::vector<int32_t> ngrams;
  std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
@@ -228,7 +253,7 @@ void Init_ext()
  })
  .define_method(
  "sentence_vector",
- *[](fasttext::FastText& m, const std::string text) {
+ *[](FastText& m, const std::string text) {
  std::istringstream in(text);
  int dimension = m.getDimension();
  fasttext::Vector vec = fasttext::Vector(dimension);
@@ -242,22 +267,28 @@ void Init_ext()
  })
  .define_method(
  "train",
- *[](fasttext::FastText& m, Hash h) {
- m.train(buildArgs(h));
+ *[](FastText& m, Hash h) {
+ auto a = buildArgs(h);
+ if (a.hasAutotune()) {
+ fasttext::Autotune autotune(std::shared_ptr<fasttext::FastText>(&m, [](fasttext::FastText*) {}));
+ autotune.train(a);
+ } else {
+ m.train(a);
+ }
  })
  .define_method(
  "quantize",
- *[](fasttext::FastText& m, Hash h) {
+ *[](FastText& m, Hash h) {
  m.quantize(buildArgs(h));
  })
  .define_method(
  "supervised?",
- *[](fasttext::FastText& m) {
+ *[](FastText& m) {
  return m.getArgs().model == fasttext::model_name::sup;
  })
  .define_method(
  "label_prefix",
- *[](fasttext::FastText& m) {
+ *[](FastText& m) {
  return m.getArgs().label;
  });
  }
@@ -1,9 +1,8 @@
  require "mkmf-rice"

- abort "Missing stdc++" unless have_library("stdc++")
-
  # TODO use -std=c++14 when available
- $CXXFLAGS << " -pthread -std=c++11 -funroll-loops -O3 -march=native"
+ # -pthread and -O3 set by default
+ $CXXFLAGS << " -std=c++11 -funroll-loops " << with_config("optflags", "-march=native")

  ext = File.expand_path(".", __dir__)
  fasttext = File.expand_path("../../vendor/fastText/src", __dir__)
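A sketch of how the new `--with-optflags` option is consumed; the install commands in the comments assume the standard RubyGems convention of passing build options after `--`, and the `-O2` value is only an example:

```ruby
require "mkmf"

# with_config("optflags", default) returns the value passed as --with-optflags=...,
# or the default when the option is omitted, e.g.:
#   gem install fasttext -- --with-optflags=-O2   # => "-O2"
#   gem install fasttext                          # => "-march=native"
optflags = with_config("optflags", "-march=native")
$CXXFLAGS << " -std=c++11 -funroll-loops " << optflags
```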
@@ -21,13 +21,23 @@ module FastText
  verbose: 2,
  pretrained_vectors: "",
  save_output: false,
- # seed: 0
+ seed: 0,
+ autotune_validation_file: "",
+ autotune_metric: "f1",
+ autotune_predictions: 1,
+ autotune_duration: 60 * 5,
+ autotune_model_size: ""
  }

- def fit(x, y = nil)
+ def fit(x, y = nil, autotune_set: nil)
  input = input_path(x, y)
  @m ||= Ext::Model.new
- m.train(DEFAULT_OPTIONS.merge(@options).merge(input: input, model: "supervised"))
+ opts = DEFAULT_OPTIONS.merge(@options).merge(input: input, model: "supervised")
+ if autotune_set
+ x, y = autotune_set
+ opts.merge!(autotune_validation_file: input_path(x, y))
+ end
+ m.train(opts)
  end

  def predict(text, k: 1, threshold: 0.0)
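As in the README changes above, `autotune_set` mirrors the two input forms `fit` already accepts: an `[x, y]` pair of arrays, or a path to a file in fastText's `__label__` format. A file-based sketch, with placeholder file names and contents:

```ruby
require "fasttext"

# toy training and validation files in fastText's __label__ format
File.write("train.txt", "__label__positive great value\n__label__negative would not buy again\n")
File.write("valid.txt", "__label__positive really happy with it\n")

model = FastText::Classifier.new(autotune_duration: 60)
model.fit("train.txt", autotune_set: "valid.txt") # the validation file drives the search
```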
@@ -20,7 +20,12 @@ module FastText
  verbose: 2,
  pretrained_vectors: "",
  save_output: false,
- # seed: 0
+ seed: 0,
+ autotune_validation_file: "",
+ autotune_metric: "f1",
+ autotune_predictions: 1,
+ autotune_duration: 60 * 5,
+ autotune_model_size: ""
  }

  def fit(x)
@@ -1,3 +1,3 @@
  module FastText
- VERSION = "0.1.2"
+ VERSION = "0.1.3"
  end
@@ -89,9 +89,9 @@ There is also the master branch that contains all of our most recent work, but c
  ### Building fastText using make (preferred)

  ```
- $ wget https://github.com/facebookresearch/fastText/archive/v0.9.1.zip
- $ unzip v0.9.1.zip
- $ cd fastText-0.9.1
+ $ wget https://github.com/facebookresearch/fastText/archive/v0.9.2.zip
+ $ unzip v0.9.2.zip
+ $ cd fastText-0.9.2
  $ make
  ```

@@ -12,6 +12,8 @@

  #include <iostream>
  #include <stdexcept>
+ #include <string>
+ #include <unordered_map>

  namespace fasttext {

@@ -36,12 +38,19 @@ Args::Args() {
  verbose = 2;
  pretrainedVectors = "";
  saveOutput = false;
+ seed = 0;

  qout = false;
  retrain = false;
  qnorm = false;
  cutoff = 0;
  dsub = 2;
+
+ autotuneValidationFile = "";
+ autotuneMetric = "f1";
+ autotunePredictions = 1;
+ autotuneDuration = 60 * 5; // 5 minutes
+ autotuneModelSize = "";
  }

  std::string Args::lossToString(loss_name ln) const {
@@ -78,6 +87,24 @@ std::string Args::modelToString(model_name mn) const {
  return "Unknown model name!"; // should never happen
  }

+ std::string Args::metricToString(metric_name mn) const {
+ switch (mn) {
+ case metric_name::f1score:
+ return "f1score";
+ case metric_name::f1scoreLabel:
+ return "f1scoreLabel";
+ case metric_name::precisionAtRecall:
+ return "precisionAtRecall";
+ case metric_name::precisionAtRecallLabel:
+ return "precisionAtRecallLabel";
+ case metric_name::recallAtPrecision:
+ return "recallAtPrecision";
+ case metric_name::recallAtPrecisionLabel:
+ return "recallAtPrecisionLabel";
+ }
+ return "Unknown metric name!"; // should never happen
+ }
+
  void Args::parseArgs(const std::vector<std::string>& args) {
  std::string command(args[1]);
  if (command == "supervised") {
@@ -97,6 +124,8 @@ void Args::parseArgs(const std::vector<std::string>& args) {
  exit(EXIT_FAILURE);
  }
  try {
+ setManual(args[ai].substr(1));
+
  if (args[ai] == "-h") {
  std::cerr << "Here is the help! Usage:" << std::endl;
  printHelp();
@@ -157,6 +186,8 @@ void Args::parseArgs(const std::vector<std::string>& args) {
  } else if (args[ai] == "-saveOutput") {
  saveOutput = true;
  ai--;
+ } else if (args[ai] == "-seed") {
+ seed = std::stoi(args.at(ai + 1));
  } else if (args[ai] == "-qnorm") {
  qnorm = true;
  ai--;
@@ -170,6 +201,18 @@ void Args::parseArgs(const std::vector<std::string>& args) {
  cutoff = std::stoi(args.at(ai + 1));
  } else if (args[ai] == "-dsub") {
  dsub = std::stoi(args.at(ai + 1));
+ } else if (args[ai] == "-autotune-validation") {
+ autotuneValidationFile = std::string(args.at(ai + 1));
+ } else if (args[ai] == "-autotune-metric") {
+ autotuneMetric = std::string(args.at(ai + 1));
+ getAutotuneMetric(); // throws exception if not able to parse
+ getAutotuneMetricLabel(); // throws exception if not able to parse
+ } else if (args[ai] == "-autotune-predictions") {
+ autotunePredictions = std::stoi(args.at(ai + 1));
+ } else if (args[ai] == "-autotune-duration") {
+ autotuneDuration = std::stoi(args.at(ai + 1));
+ } else if (args[ai] == "-autotune-modelsize") {
+ autotuneModelSize = std::string(args.at(ai + 1));
  } else {
  std::cerr << "Unknown argument: " << args[ai] << std::endl;
  printHelp();
@@ -186,7 +229,7 @@ void Args::parseArgs(const std::vector<std::string>& args) {
  printHelp();
  exit(EXIT_FAILURE);
  }
- if (wordNgrams <= 1 && maxn == 0) {
+ if (wordNgrams <= 1 && maxn == 0 && !hasAutotune()) {
  bucket = 0;
  }
  }
@@ -195,6 +238,7 @@ void Args::printHelp() {
  printBasicHelp();
  printDictionaryHelp();
  printTrainingHelp();
+ printAutotuneHelp();
  printQuantizationHelp();
  }

@@ -227,7 +271,8 @@ void Args::printTrainingHelp() {
  std::cerr
  << "\nThe following arguments for training are optional:\n"
  << " -lr learning rate [" << lr << "]\n"
- << " -lrUpdateRate change the rate of updates for the learning rate ["
+ << " -lrUpdateRate change the rate of updates for the learning "
+ "rate ["
  << lrUpdateRate << "]\n"
  << " -dim size of word vectors [" << dim << "]\n"
  << " -ws size of the context window [" << ws << "]\n"
@@ -235,11 +280,31 @@ void Args::printTrainingHelp() {
  << " -neg number of negatives sampled [" << neg << "]\n"
  << " -loss loss function {ns, hs, softmax, one-vs-all} ["
  << lossToString(loss) << "]\n"
- << " -thread number of threads [" << thread << "]\n"
- << " -pretrainedVectors pretrained word vectors for supervised learning ["
+ << " -thread number of threads (set to 1 to ensure "
+ "reproducible results) ["
+ << thread << "]\n"
+ << " -pretrainedVectors pretrained word vectors for supervised "
+ "learning ["
  << pretrainedVectors << "]\n"
  << " -saveOutput whether output params should be saved ["
- << boolToString(saveOutput) << "]\n";
+ << boolToString(saveOutput) << "]\n"
+ << " -seed random generator seed [" << seed << "]\n";
+ }
+
+ void Args::printAutotuneHelp() {
+ std::cerr << "\nThe following arguments are for autotune:\n"
+ << " -autotune-validation validation file to be used "
+ "for evaluation\n"
+ << " -autotune-metric metric objective {f1, "
+ "f1:labelname} ["
+ << autotuneMetric << "]\n"
+ << " -autotune-predictions number of predictions used "
+ "for evaluation ["
+ << autotunePredictions << "]\n"
+ << " -autotune-duration maximum duration in seconds ["
+ << autotuneDuration << "]\n"
+ << " -autotune-modelsize constraint model file size ["
+ << autotuneModelSize << "] (empty = do not quantize)\n";
  }

  void Args::printQuantizationHelp() {
@@ -247,7 +312,8 @@ void Args::printQuantizationHelp() {
  << "\nThe following arguments for quantization are optional:\n"
  << " -cutoff number of words and ngrams to retain ["
  << cutoff << "]\n"
- << " -retrain whether embeddings are finetuned if a cutoff is applied ["
+ << " -retrain whether embeddings are finetuned if a cutoff "
+ "is applied ["
  << boolToString(retrain) << "]\n"
  << " -qnorm whether the norm is quantized separately ["
  << boolToString(qnorm) << "]\n"
@@ -317,4 +383,111 @@ void Args::dump(std::ostream& out) const {
  << " " << t << std::endl;
  }

+ bool Args::hasAutotune() const {
+ return !autotuneValidationFile.empty();
+ }
+
+ bool Args::isManual(const std::string& argName) const {
+ return (manualArgs_.count(argName) != 0);
+ }
+
+ void Args::setManual(const std::string& argName) {
+ manualArgs_.emplace(argName);
+ }
+
+ metric_name Args::getAutotuneMetric() const {
+ if (autotuneMetric.substr(0, 3) == "f1:") {
+ return metric_name::f1scoreLabel;
+ } else if (autotuneMetric == "f1") {
+ return metric_name::f1score;
+ } else if (autotuneMetric.substr(0, 18) == "precisionAtRecall:") {
+ size_t semicolon = autotuneMetric.find(":", 18);
+ if (semicolon != std::string::npos) {
+ return metric_name::precisionAtRecallLabel;
+ }
+ return metric_name::precisionAtRecall;
+ } else if (autotuneMetric.substr(0, 18) == "recallAtPrecision:") {
+ size_t semicolon = autotuneMetric.find(":", 18);
+ if (semicolon != std::string::npos) {
+ return metric_name::recallAtPrecisionLabel;
+ }
+ return metric_name::recallAtPrecision;
+ }
+ throw std::runtime_error("Unknown metric : " + autotuneMetric);
+ }
+
+ std::string Args::getAutotuneMetricLabel() const {
+ metric_name metric = getAutotuneMetric();
+ std::string label;
+ if (metric == metric_name::f1scoreLabel) {
+ label = autotuneMetric.substr(3);
+ } else if (
+ metric == metric_name::precisionAtRecallLabel ||
+ metric == metric_name::recallAtPrecisionLabel) {
+ size_t semicolon = autotuneMetric.find(":", 18);
+ label = autotuneMetric.substr(semicolon + 1);
+ } else {
+ return label;
+ }
+
+ if (label.empty()) {
+ throw std::runtime_error("Empty metric label : " + autotuneMetric);
+ }
+ return label;
+ }
+
+ double Args::getAutotuneMetricValue() const {
+ metric_name metric = getAutotuneMetric();
+ double value = 0.0;
+ if (metric == metric_name::precisionAtRecallLabel ||
+ metric == metric_name::precisionAtRecall ||
+ metric == metric_name::recallAtPrecisionLabel ||
+ metric == metric_name::recallAtPrecision) {
+ size_t firstSemicolon = 18; // semicolon position in "precisionAtRecall:"
+ size_t secondSemicolon = autotuneMetric.find(":", firstSemicolon);
+ const std::string valueStr =
+ autotuneMetric.substr(firstSemicolon, secondSemicolon - firstSemicolon);
+ value = std::stof(valueStr) / 100.0;
+ }
+ return value;
+ }
+
+ int64_t Args::getAutotuneModelSize() const {
+ std::string modelSize = autotuneModelSize;
+ if (modelSize.empty()) {
+ return Args::kUnlimitedModelSize;
+ }
+ std::unordered_map<char, int> units = {
+ {'k', 1000},
+ {'K', 1000},
+ {'m', 1000000},
+ {'M', 1000000},
+ {'g', 1000000000},
+ {'G', 1000000000},
+ };
+ uint64_t multiplier = 1;
+ char lastCharacter = modelSize.back();
+ if (units.count(lastCharacter)) {
+ multiplier = units[lastCharacter];
+ modelSize = modelSize.substr(0, modelSize.size() - 1);
+ }
+ uint64_t size = 0;
+ size_t nonNumericCharacter = 0;
+ bool parseError = false;
+ try {
+ size = std::stol(modelSize, &nonNumericCharacter);
+ } catch (std::invalid_argument&) {
+ parseError = true;
+ }
+ if (!parseError && nonNumericCharacter != modelSize.size()) {
+ parseError = true;
+ }
+ if (parseError) {
+ throw std::invalid_argument(
+ "Unable to parse model size " + autotuneModelSize);
+ }
+
+ return size * multiplier;
+ }
+
  } // namespace fasttext
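The model-size parsing above accepts an optional k/K, m/M, or g/G suffix with decimal multipliers, and an empty string means "do not quantize". A rough Ruby equivalent of `Args::getAutotuneModelSize`, for illustration only:

```ruby
# a sketch mirroring Args::getAutotuneModelSize above; not part of the gem
UNITS = { "k" => 1_000, "m" => 1_000_000, "g" => 1_000_000_000 }.freeze

def parse_model_size(str)
  return nil if str.empty?                  # empty = unlimited, no quantization
  multiplier = UNITS[str[-1].downcase] || 1 # optional unit suffix
  digits = multiplier == 1 ? str : str[0..-2]
  raise ArgumentError, "Unable to parse model size #{str}" unless digits.match?(/\A\d+\z/)
  digits.to_i * multiplier
end

parse_model_size("2M")   # => 2_000_000
parse_model_size("500k") # => 500_000
```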