fasttext 0.1.2 → 0.1.3

This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: f83be8c01c6a45a90758ccee430b3898396bcfdda5a2c338126ed9dc3620aea5
- data.tar.gz: 7e3dee8eb3afe12745f78448fd01ac68a2f0ac946bd01195f2f6e6081f62fbad
+ metadata.gz: fb6649e0a3992c6e12572d672e8ba768220662efc37b982c278a8d0713716029
+ data.tar.gz: 12a0441cf1030bfbfe99d26824fe757d5b83c2f47de899fecf13a73aa657bd76
  SHA512:
- metadata.gz: be3117e1aceed3f6126fc1d84eb87caf53abb3be802a419ebcaa284cb567ee9ac033d5860842690bbd1c0477a6e5013dfc36eb96f6fb067a63015150fa18a1fe
- data.tar.gz: dc2467f3f7317b5e1955ede144d9ad50c5abb1cc2b9dc5ad356350f192631b0142c91aecc4dd03b331d90f1cf65c7f3ba3ec176c17dc6b66a1226171261de1b6
+ metadata.gz: dadcbc9a0d39468b4d1070cbf74abece85e687fe2d5de9dfb03c8e79f630b47a0ce37bf73f1383bccf917190bf4f1a2acb12e4741b4aaf45d0c688991baf7893
+ data.tar.gz: b9615f2edf557b7a2bb9c90457581489dd86be7539e7f4e9b75ee72cb3fe26ca9a216c5ecda4ff6e82ef348fbcc0482cd4e727e9c6652f2b38195ad3edba1752
data/CHANGELOG.md CHANGED
@@ -1,3 +1,9 @@
+ ## 0.1.3 (2020-04-28)
+
+ - Updated fastText to 0.9.2
+ - Added support for autotune
+ - Added `--with-optflags` option
+
  ## 0.1.2 (2020-01-10)

  - Fixed installation error with Ruby 2.7
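The new `--with-optflags` option is an install-time build flag read via `with_config` in the extconf.rb change further below. A minimal sketch of how it would likely be passed, assuming standard RubyGems and Bundler build-flag conventions; the flag value is only an example:

```sh
# Hypothetical usage: override the default "-march=native" optimization flags
# used when the native extension is compiled.
gem install fasttext -- --with-optflags="-march=x86-64"

# Or, when installing through Bundler:
bundle config build.fasttext --with-optflags="-march=x86-64"
```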
data/README.md CHANGED
@@ -2,7 +2,7 @@

  [fastText](https://fasttext.cc) - efficient text classification and representation learning - for Ruby

- [![Build Status](https://travis-ci.org/ankane/fasttext.svg?branch=master)](https://travis-ci.org/ankane/fasttext)
+ [![Build Status](https://travis-ci.org/ankane/fasttext.svg?branch=master)](https://travis-ci.org/ankane/fasttext) [![Build status](https://ci.appveyor.com/api/projects/status/67yby3w6mth766y9/branch/master?svg=true)](https://ci.appveyor.com/project/ankane/fasttext/branch/master)

  ## Installation

@@ -77,6 +77,12 @@ model.labels

  > Use `include_freq: true` to get their frequency

+ Search for the best hyperparameters
+
+ ```ruby
+ model.fit(x, y, autotune_set: [x_valid, y_valid])
+ ```
+
  Compress the model - significantly reduces size but sacrifices a little performance

  ```ruby
@@ -168,7 +174,11 @@ FastText::Classifier.new(
  t: 0.0001, # sampling threshold
  label_prefix: "__label__", # label prefix
  verbose: 2, # verbose
- pretrained_vectors: nil # pretrained word vectors (.vec file)
+ pretrained_vectors: nil, # pretrained word vectors (.vec file)
+ autotune_metric: "f1", # autotune optimization metric
+ autotune_predictions: 1, # autotune predictions
+ autotune_duration: 300, # autotune search time in seconds
+ autotune_model_size: nil # autotune model size, like 2M
  )
  ```

@@ -200,7 +210,7 @@ FastText::Vectorizer.new(
  Input can be read directly from files

  ```ruby
- model.fit("train.txt")
+ model.fit("train.txt", autotune_set: "valid.txt")
  model.test("test.txt")
  ```

@@ -260,12 +270,12 @@ Everyone is encouraged to help improve this project. Here are a few ways you can
  - Write, clarify, or fix documentation
  - Suggest or add new features

- To get started with development and testing:
+ To get started with development:

  ```sh
- git clone https://github.com/ankane/fasttext.git
- cd fasttext
+ git clone https://github.com/ankane/fastText.git
+ cd fastText
  bundle install
- rake compile
- rake test
+ bundle exec rake compile
+ bundle exec rake test
  ```
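Taken together, the README changes above describe the new autotune workflow: pass a validation set to `fit` and, optionally, autotune options to the constructor. A short sketch combining them (the data and parameter values are illustrative, not defaults):

```ruby
x = ["this is great", "this is terrible"]   # training texts
y = ["positive", "negative"]                # training labels
x_valid = ["really great stuff"]            # validation texts
y_valid = ["positive"]                      # validation labels

# Tune hyperparameters for 10 minutes and constrain the final
# (quantized) model to roughly 2 MB.
model = FastText::Classifier.new(autotune_duration: 600, autotune_model_size: "2M")
model.fit(x, y, autotune_set: [x_valid, y_valid])
```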
data/ext/fasttext/ext.cpp CHANGED
@@ -1,18 +1,33 @@
+ // stdlib
+ #include <cmath>
+ #include <iterator>
+ #include <sstream>
+ #include <stdexcept>
+
+ // fasttext
  #include <args.h>
+ #include <autotune.h>
  #include <densematrix.h>
  #include <fasttext.h>
- #include <rice/Data_Type.hpp>
- #include <rice/Constructor.hpp>
- #include <rice/Array.hpp>
- #include <rice/Hash.hpp>
  #include <real.h>
  #include <vector.h>
- #include <cmath>
- #include <iterator>
- #include <sstream>
- #include <stdexcept>

- using namespace Rice;
+ // rice
+ #include <rice/Array.hpp>
+ #include <rice/Constructor.hpp>
+ #include <rice/Data_Type.hpp>
+ #include <rice/Hash.hpp>
+
+ using fasttext::FastText;
+
+ using Rice::Array;
+ using Rice::Constructor;
+ using Rice::Hash;
+ using Rice::Module;
+ using Rice::Object;
+ using Rice::define_class_under;
+ using Rice::define_module;
+ using Rice::define_module_under;

  template<>
  inline
@@ -104,8 +119,18 @@ fasttext::Args buildArgs(Hash h) {
  a.pretrainedVectors = from_ruby<std::string>(value);
  } else if (name == "save_output") {
  a.saveOutput = from_ruby<bool>(value);
- // } else if (name == "seed") {
- // a.seed = from_ruby<int>(value);
+ } else if (name == "seed") {
+ a.seed = from_ruby<int>(value);
+ } else if (name == "autotune_validation_file") {
+ a.autotuneValidationFile = from_ruby<std::string>(value);
+ } else if (name == "autotune_metric") {
+ a.autotuneMetric = from_ruby<std::string>(value);
+ } else if (name == "autotune_predictions") {
+ a.autotunePredictions = from_ruby<int>(value);
+ } else if (name == "autotune_duration") {
+ a.autotuneDuration = from_ruby<int>(value);
+ } else if (name == "autotune_model_size") {
+ a.autotuneModelSize = from_ruby<std::string>(value);
  } else {
  throw std::invalid_argument("Unknown argument: " + name);
  }
@@ -119,11 +144,11 @@ void Init_ext()
  Module rb_mFastText = define_module("FastText");
  Module rb_mExt = define_module_under(rb_mFastText, "Ext");

- define_class_under<fasttext::FastText>(rb_mExt, "Model")
- .define_constructor(Constructor<fasttext::FastText>())
+ define_class_under<FastText>(rb_mExt, "Model")
+ .define_constructor(Constructor<FastText>())
  .define_method(
  "words",
- *[](fasttext::FastText& m) {
+ *[](FastText& m) {
  std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  std::vector<int64_t> freq = d->getCounts(fasttext::entry_type::word);

@@ -141,7 +166,7 @@ void Init_ext()
  })
  .define_method(
  "labels",
- *[](fasttext::FastText& m) {
+ *[](FastText& m) {
  std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
  std::vector<int64_t> freq = d->getCounts(fasttext::entry_type::label);

@@ -159,12 +184,12 @@ void Init_ext()
  })
  .define_method(
  "test",
- *[](fasttext::FastText& m, const std::string filename, int32_t k) {
+ *[](FastText& m, const std::string filename, int32_t k) {
  std::ifstream ifs(filename);
  if (!ifs.is_open()) {
  throw std::invalid_argument("Test file cannot be opened!");
  }
- fasttext::Meter meter;
+ fasttext::Meter meter(false);
  m.test(ifs, k, 0.0, meter);
  ifs.close();

@@ -176,17 +201,17 @@ void Init_ext()
  })
  .define_method(
  "load_model",
- *[](fasttext::FastText& m, std::string s) { m.loadModel(s); })
+ *[](FastText& m, std::string s) { m.loadModel(s); })
  .define_method(
  "save_model",
- *[](fasttext::FastText& m, std::string s) { m.saveModel(s); })
- .define_method("dimension", &fasttext::FastText::getDimension)
- .define_method("quantized?", &fasttext::FastText::isQuant)
- .define_method("word_id", &fasttext::FastText::getWordId)
- .define_method("subword_id", &fasttext::FastText::getSubwordId)
+ *[](FastText& m, std::string s) { m.saveModel(s); })
+ .define_method("dimension", &FastText::getDimension)
+ .define_method("quantized?", &FastText::isQuant)
+ .define_method("word_id", &FastText::getWordId)
+ .define_method("subword_id", &FastText::getSubwordId)
  .define_method(
  "predict",
- *[](fasttext::FastText& m, const std::string text, int32_t k, float threshold) {
+ *[](FastText& m, const std::string text, int32_t k, float threshold) {
  std::stringstream ioss(text);
  std::vector<std::pair<fasttext::real, std::string>> predictions;
  m.predictLine(ioss, predictions, k, threshold);
@@ -194,14 +219,14 @@ void Init_ext()
  })
  .define_method(
  "nearest_neighbors",
- *[](fasttext::FastText& m, const std::string& word, int32_t k) {
+ *[](FastText& m, const std::string& word, int32_t k) {
  return m.getNN(word, k);
  })
- .define_method("analogies", &fasttext::FastText::getAnalogies)
- .define_method("ngram_vectors", &fasttext::FastText::getNgramVectors)
+ .define_method("analogies", &FastText::getAnalogies)
+ .define_method("ngram_vectors", &FastText::getNgramVectors)
  .define_method(
  "word_vector",
- *[](fasttext::FastText& m, const std::string word) {
+ *[](FastText& m, const std::string word) {
  int dimension = m.getDimension();
  fasttext::Vector vec = fasttext::Vector(dimension);
  m.getWordVector(vec, word);
@@ -214,7 +239,7 @@ void Init_ext()
  })
  .define_method(
  "subwords",
- *[](fasttext::FastText& m, const std::string word) {
+ *[](FastText& m, const std::string word) {
  std::vector<std::string> subwords;
  std::vector<int32_t> ngrams;
  std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
@@ -228,7 +253,7 @@ void Init_ext()
  })
  .define_method(
  "sentence_vector",
- *[](fasttext::FastText& m, const std::string text) {
+ *[](FastText& m, const std::string text) {
  std::istringstream in(text);
  int dimension = m.getDimension();
  fasttext::Vector vec = fasttext::Vector(dimension);
@@ -242,22 +267,28 @@ void Init_ext()
  })
  .define_method(
  "train",
- *[](fasttext::FastText& m, Hash h) {
- m.train(buildArgs(h));
+ *[](FastText& m, Hash h) {
+ auto a = buildArgs(h);
+ if (a.hasAutotune()) {
+ fasttext::Autotune autotune(std::shared_ptr<fasttext::FastText>(&m, [](fasttext::FastText*) {}));
+ autotune.train(a);
+ } else {
+ m.train(a);
+ }
  })
  .define_method(
  "quantize",
- *[](fasttext::FastText& m, Hash h) {
+ *[](FastText& m, Hash h) {
  m.quantize(buildArgs(h));
  })
  .define_method(
  "supervised?",
- *[](fasttext::FastText& m) {
+ *[](FastText& m) {
  return m.getArgs().model == fasttext::model_name::sup;
  })
  .define_method(
  "label_prefix",
- *[](fasttext::FastText& m) {
+ *[](FastText& m) {
  return m.getArgs().label;
  });
  }
data/ext/fasttext/extconf.rb CHANGED
@@ -1,9 +1,8 @@
  require "mkmf-rice"

- abort "Missing stdc++" unless have_library("stdc++")
-
  # TODO use -std=c++14 when available
- $CXXFLAGS << " -pthread -std=c++11 -funroll-loops -O3 -march=native"
+ # -pthread and -O3 set by default
+ $CXXFLAGS << " -std=c++11 -funroll-loops " << with_config("optflags", "-march=native")

  ext = File.expand_path(".", __dir__)
  fasttext = File.expand_path("../../vendor/fastText/src", __dir__)
data/lib/fasttext/classifier.rb CHANGED
@@ -21,13 +21,23 @@ module FastText
  verbose: 2,
  pretrained_vectors: "",
  save_output: false,
- # seed: 0
+ seed: 0,
+ autotune_validation_file: "",
+ autotune_metric: "f1",
+ autotune_predictions: 1,
+ autotune_duration: 60 * 5,
+ autotune_model_size: ""
  }

- def fit(x, y = nil)
+ def fit(x, y = nil, autotune_set: nil)
  input = input_path(x, y)
  @m ||= Ext::Model.new
- m.train(DEFAULT_OPTIONS.merge(@options).merge(input: input, model: "supervised"))
+ opts = DEFAULT_OPTIONS.merge(@options).merge(input: input, model: "supervised")
+ if autotune_set
+ x, y = autotune_set
+ opts.merge!(autotune_validation_file: input_path(x, y))
+ end
+ m.train(opts)
  end

  def predict(text, k: 1, threshold: 0.0)
data/lib/fasttext/vectorizer.rb CHANGED
@@ -20,7 +20,12 @@ module FastText
  verbose: 2,
  pretrained_vectors: "",
  save_output: false,
- # seed: 0
+ seed: 0,
+ autotune_validation_file: "",
+ autotune_metric: "f1",
+ autotune_predictions: 1,
+ autotune_duration: 60 * 5,
+ autotune_model_size: ""
  }

  def fit(x)
data/lib/fasttext/version.rb CHANGED
@@ -1,3 +1,3 @@
  module FastText
- VERSION = "0.1.2"
+ VERSION = "0.1.3"
  end
data/vendor/fastText/README.md CHANGED
@@ -89,9 +89,9 @@ There is also the master branch that contains all of our most recent work, but c
  ### Building fastText using make (preferred)

  ```
- $ wget https://github.com/facebookresearch/fastText/archive/v0.9.1.zip
- $ unzip v0.9.1.zip
- $ cd fastText-0.9.1
+ $ wget https://github.com/facebookresearch/fastText/archive/v0.9.2.zip
+ $ unzip v0.9.2.zip
+ $ cd fastText-0.9.2
  $ make
  ```

data/vendor/fastText/src/args.cc CHANGED
@@ -12,6 +12,8 @@

  #include <iostream>
  #include <stdexcept>
+ #include <string>
+ #include <unordered_map>

  namespace fasttext {

@@ -36,12 +38,19 @@ Args::Args() {
  verbose = 2;
  pretrainedVectors = "";
  saveOutput = false;
+ seed = 0;

  qout = false;
  retrain = false;
  qnorm = false;
  cutoff = 0;
  dsub = 2;
+
+ autotuneValidationFile = "";
+ autotuneMetric = "f1";
+ autotunePredictions = 1;
+ autotuneDuration = 60 * 5; // 5 minutes
+ autotuneModelSize = "";
  }

  std::string Args::lossToString(loss_name ln) const {
@@ -78,6 +87,24 @@ std::string Args::modelToString(model_name mn) const {
  return "Unknown model name!"; // should never happen
  }

+ std::string Args::metricToString(metric_name mn) const {
+ switch (mn) {
+ case metric_name::f1score:
+ return "f1score";
+ case metric_name::f1scoreLabel:
+ return "f1scoreLabel";
+ case metric_name::precisionAtRecall:
+ return "precisionAtRecall";
+ case metric_name::precisionAtRecallLabel:
+ return "precisionAtRecallLabel";
+ case metric_name::recallAtPrecision:
+ return "recallAtPrecision";
+ case metric_name::recallAtPrecisionLabel:
+ return "recallAtPrecisionLabel";
+ }
+ return "Unknown metric name!"; // should never happen
+ }
+
  void Args::parseArgs(const std::vector<std::string>& args) {
  std::string command(args[1]);
  if (command == "supervised") {
@@ -97,6 +124,8 @@ void Args::parseArgs(const std::vector<std::string>& args) {
  exit(EXIT_FAILURE);
  }
  try {
+ setManual(args[ai].substr(1));
+
  if (args[ai] == "-h") {
  std::cerr << "Here is the help! Usage:" << std::endl;
  printHelp();
@@ -157,6 +186,8 @@ void Args::parseArgs(const std::vector<std::string>& args) {
  } else if (args[ai] == "-saveOutput") {
  saveOutput = true;
  ai--;
+ } else if (args[ai] == "-seed") {
+ seed = std::stoi(args.at(ai + 1));
  } else if (args[ai] == "-qnorm") {
  qnorm = true;
  ai--;
@@ -170,6 +201,18 @@ void Args::parseArgs(const std::vector<std::string>& args) {
  cutoff = std::stoi(args.at(ai + 1));
  } else if (args[ai] == "-dsub") {
  dsub = std::stoi(args.at(ai + 1));
+ } else if (args[ai] == "-autotune-validation") {
+ autotuneValidationFile = std::string(args.at(ai + 1));
+ } else if (args[ai] == "-autotune-metric") {
+ autotuneMetric = std::string(args.at(ai + 1));
+ getAutotuneMetric(); // throws exception if not able to parse
+ getAutotuneMetricLabel(); // throws exception if not able to parse
+ } else if (args[ai] == "-autotune-predictions") {
+ autotunePredictions = std::stoi(args.at(ai + 1));
+ } else if (args[ai] == "-autotune-duration") {
+ autotuneDuration = std::stoi(args.at(ai + 1));
+ } else if (args[ai] == "-autotune-modelsize") {
+ autotuneModelSize = std::string(args.at(ai + 1));
  } else {
  std::cerr << "Unknown argument: " << args[ai] << std::endl;
  printHelp();
@@ -186,7 +229,7 @@ void Args::parseArgs(const std::vector<std::string>& args) {
  printHelp();
  exit(EXIT_FAILURE);
  }
- if (wordNgrams <= 1 && maxn == 0) {
+ if (wordNgrams <= 1 && maxn == 0 && !hasAutotune()) {
  bucket = 0;
  }
  }
@@ -195,6 +238,7 @@ void Args::printHelp() {
  printBasicHelp();
  printDictionaryHelp();
  printTrainingHelp();
+ printAutotuneHelp();
  printQuantizationHelp();
  }

@@ -227,7 +271,8 @@ void Args::printTrainingHelp() {
  std::cerr
  << "\nThe following arguments for training are optional:\n"
  << " -lr learning rate [" << lr << "]\n"
- << " -lrUpdateRate change the rate of updates for the learning rate ["
+ << " -lrUpdateRate change the rate of updates for the learning "
+ "rate ["
  << lrUpdateRate << "]\n"
  << " -dim size of word vectors [" << dim << "]\n"
  << " -ws size of the context window [" << ws << "]\n"
@@ -235,11 +280,31 @@ void Args::printTrainingHelp() {
  << " -neg number of negatives sampled [" << neg << "]\n"
  << " -loss loss function {ns, hs, softmax, one-vs-all} ["
  << lossToString(loss) << "]\n"
- << " -thread number of threads [" << thread << "]\n"
- << " -pretrainedVectors pretrained word vectors for supervised learning ["
+ << " -thread number of threads (set to 1 to ensure "
+ "reproducible results) ["
+ << thread << "]\n"
+ << " -pretrainedVectors pretrained word vectors for supervised "
+ "learning ["
  << pretrainedVectors << "]\n"
  << " -saveOutput whether output params should be saved ["
- << boolToString(saveOutput) << "]\n";
+ << boolToString(saveOutput) << "]\n"
+ << " -seed random generator seed [" << seed << "]\n";
+ }
+
+ void Args::printAutotuneHelp() {
+ std::cerr << "\nThe following arguments are for autotune:\n"
+ << " -autotune-validation validation file to be used "
+ "for evaluation\n"
+ << " -autotune-metric metric objective {f1, "
+ "f1:labelname} ["
+ << autotuneMetric << "]\n"
+ << " -autotune-predictions number of predictions used "
+ "for evaluation ["
+ << autotunePredictions << "]\n"
+ << " -autotune-duration maximum duration in seconds ["
+ << autotuneDuration << "]\n"
+ << " -autotune-modelsize constraint model file size ["
+ << autotuneModelSize << "] (empty = do not quantize)\n";
  }

  void Args::printQuantizationHelp() {
@@ -247,7 +312,8 @@ void Args::printQuantizationHelp() {
  << "\nThe following arguments for quantization are optional:\n"
  << " -cutoff number of words and ngrams to retain ["
  << cutoff << "]\n"
- << " -retrain whether embeddings are finetuned if a cutoff is applied ["
+ << " -retrain whether embeddings are finetuned if a cutoff "
+ "is applied ["
  << boolToString(retrain) << "]\n"
  << " -qnorm whether the norm is quantized separately ["
  << boolToString(qnorm) << "]\n"
@@ -317,4 +383,111 @@ void Args::dump(std::ostream& out) const {
  << " " << t << std::endl;
  }

+ bool Args::hasAutotune() const {
+ return !autotuneValidationFile.empty();
+ }
+
+ bool Args::isManual(const std::string& argName) const {
+ return (manualArgs_.count(argName) != 0);
+ }
+
+ void Args::setManual(const std::string& argName) {
+ manualArgs_.emplace(argName);
+ }
+
+ metric_name Args::getAutotuneMetric() const {
+ if (autotuneMetric.substr(0, 3) == "f1:") {
+ return metric_name::f1scoreLabel;
+ } else if (autotuneMetric == "f1") {
+ return metric_name::f1score;
+ } else if (autotuneMetric.substr(0, 18) == "precisionAtRecall:") {
+ size_t semicolon = autotuneMetric.find(":", 18);
+ if (semicolon != std::string::npos) {
+ return metric_name::precisionAtRecallLabel;
+ }
+ return metric_name::precisionAtRecall;
+ } else if (autotuneMetric.substr(0, 18) == "recallAtPrecision:") {
+ size_t semicolon = autotuneMetric.find(":", 18);
+ if (semicolon != std::string::npos) {
+ return metric_name::recallAtPrecisionLabel;
+ }
+ return metric_name::recallAtPrecision;
+ }
+ throw std::runtime_error("Unknown metric : " + autotuneMetric);
+ }
+
+ std::string Args::getAutotuneMetricLabel() const {
+ metric_name metric = getAutotuneMetric();
+ std::string label;
+ if (metric == metric_name::f1scoreLabel) {
+ label = autotuneMetric.substr(3);
+ } else if (
+ metric == metric_name::precisionAtRecallLabel ||
+ metric == metric_name::recallAtPrecisionLabel) {
+ size_t semicolon = autotuneMetric.find(":", 18);
+ label = autotuneMetric.substr(semicolon + 1);
+ } else {
+ return label;
+ }
+
+ if (label.empty()) {
+ throw std::runtime_error("Empty metric label : " + autotuneMetric);
+ }
+ return label;
+ }
+
+ double Args::getAutotuneMetricValue() const {
+ metric_name metric = getAutotuneMetric();
+ double value = 0.0;
+ if (metric == metric_name::precisionAtRecallLabel ||
+ metric == metric_name::precisionAtRecall ||
+ metric == metric_name::recallAtPrecisionLabel ||
+ metric == metric_name::recallAtPrecision) {
+ size_t firstSemicolon = 18; // semicolon position in "precisionAtRecall:"
+ size_t secondSemicolon = autotuneMetric.find(":", firstSemicolon);
+ const std::string valueStr =
+ autotuneMetric.substr(firstSemicolon, secondSemicolon - firstSemicolon);
+ value = std::stof(valueStr) / 100.0;
+ }
+ return value;
+ }
+
+ int64_t Args::getAutotuneModelSize() const {
+ std::string modelSize = autotuneModelSize;
+ if (modelSize.empty()) {
+ return Args::kUnlimitedModelSize;
+ }
+ std::unordered_map<char, int> units = {
+ {'k', 1000},
+ {'K', 1000},
+ {'m', 1000000},
+ {'M', 1000000},
+ {'g', 1000000000},
+ {'G', 1000000000},
+ };
+ uint64_t multiplier = 1;
+ char lastCharacter = modelSize.back();
+ if (units.count(lastCharacter)) {
+ multiplier = units[lastCharacter];
+ modelSize = modelSize.substr(0, modelSize.size() - 1);
+ }
+ uint64_t size = 0;
+ size_t nonNumericCharacter = 0;
+ bool parseError = false;
+ try {
+ size = std::stol(modelSize, &nonNumericCharacter);
+ } catch (std::invalid_argument&) {
+ parseError = true;
+ }
+ if (!parseError && nonNumericCharacter != modelSize.size()) {
+ parseError = true;
+ }
+ if (parseError) {
+ throw std::invalid_argument(
+ "Unable to parse model size " + autotuneModelSize);
+ }
+
+ return size * multiplier;
+ }
+
  } // namespace fasttext
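For reference, the `autotuneMetric` strings accepted by `getAutotuneMetric`, `getAutotuneMetricLabel`, and `getAutotuneMetricValue` above combine a metric name with an optional numeric target and label. A few illustrative values as they would be passed through the Ruby wrapper's `autotune_metric` option (the label names are hypothetical):

```ruby
FastText::Classifier.new(autotune_metric: "f1")                        # overall F1 (default)
FastText::Classifier.new(autotune_metric: "f1:__label__spam")          # F1 for one label
FastText::Classifier.new(autotune_metric: "precisionAtRecall:30")      # precision at 30% recall
FastText::Classifier.new(autotune_metric: "recallAtPrecision:90:__label__spam")
```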