fasttext 0.2.1 → 0.2.2

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 597a9c73a6720397b11fb9bd0492c33ed458783747d904deef74618a6f319133
4
- data.tar.gz: b03ab1e051cda71fafc499dbf0847d3da1c27050646e949eac148a5548460054
3
+ metadata.gz: 9aae7e20933f51ebebd802276d7e006ce792fcfc6dd94ea8ceda887ab0f4eca8
4
+ data.tar.gz: a90a7bbfffe424829052afc3519bc0473a7fc1b22a2d1ea0c8786d9372a0cdb5
5
5
  SHA512:
6
- metadata.gz: 844ae824ad06aa270a5c58b6dc109701660b2bb0e5bfe63e285df57bc5db854028703c1c0476a845d5b0e8c32dd02ac42e1c8b1c5c0684713928c895c776b1da
7
- data.tar.gz: 7f7a07f0041981eb046e84f6b44f4207192eb9e8a786edf41b53ae15b370eb875f5b4d578248c90a371805697ad3f0fa076532d1f6789c44d336d72402a54333
6
+ metadata.gz: 3042a798560e5960d18d8bcefb66833bf621bc1dcf26f77742602d0d4e7b1e4cc33b078f9225c209ac9a0d2f89c9b3c98ad2446582af4a3ac4effc761192c58d
7
+ data.tar.gz: d70ce1005916a809a78b02f23232b1efd7cc59dc9d3ed2df7929f946db2f326c0df40729b2ea953846e906c2774d6b6e971650f3d585da65ae6562bdcb9a9b5e
data/CHANGELOG.md CHANGED
@@ -1,3 +1,7 @@
1
+ ## 0.2.2 (2021-10-16)
2
+
3
+ - Fixed `file cannot be opened` errors
4
+
1
5
  ## 0.2.1 (2021-05-23)
2
6
 
3
7
  - Improved performance
data/ext/fasttext/ext.cpp CHANGED
@@ -16,13 +16,12 @@
16
16
  #include <rice/rice.hpp>
17
17
  #include <rice/stl.hpp>
18
18
 
19
+ using fasttext::Args;
19
20
  using fasttext::FastText;
20
21
 
21
22
  using Rice::Array;
22
23
  using Rice::Constructor;
23
- using Rice::Hash;
24
24
  using Rice::Module;
25
- using Rice::Object;
26
25
  using Rice::define_class_under;
27
26
  using Rice::define_module;
28
27
  using Rice::define_module_under;
@@ -47,103 +46,69 @@ namespace Rice::detail
47
46
  };
48
47
  }
49
48
 
50
- fasttext::Args buildArgs(Hash h) {
51
- fasttext::Args a;
52
-
53
- for (const auto& it : h)
54
- {
55
- auto name = it.key.to_s().str();
56
- auto value = (it.value).value();
57
-
58
- if (name == "input") {
59
- a.input = Rice::detail::From_Ruby<std::string>().convert(value);
60
- } else if (name == "output") {
61
- a.output = Rice::detail::From_Ruby<std::string>().convert(value);
62
- } else if (name == "lr") {
63
- a.lr = Rice::detail::From_Ruby<double>().convert(value);
64
- } else if (name == "lr_update_rate") {
65
- a.lrUpdateRate = Rice::detail::From_Ruby<int>().convert(value);
66
- } else if (name == "dim") {
67
- a.dim = Rice::detail::From_Ruby<int>().convert(value);
68
- } else if (name == "ws") {
69
- a.ws = Rice::detail::From_Ruby<int>().convert(value);
70
- } else if (name == "epoch") {
71
- a.epoch = Rice::detail::From_Ruby<int>().convert(value);
72
- } else if (name == "min_count") {
73
- a.minCount = Rice::detail::From_Ruby<int>().convert(value);
74
- } else if (name == "min_count_label") {
75
- a.minCountLabel = Rice::detail::From_Ruby<int>().convert(value);
76
- } else if (name == "neg") {
77
- a.neg = Rice::detail::From_Ruby<int>().convert(value);
78
- } else if (name == "word_ngrams") {
79
- a.wordNgrams = Rice::detail::From_Ruby<int>().convert(value);
80
- } else if (name == "loss") {
81
- std::string str = Rice::detail::From_Ruby<std::string>().convert(value);
82
- if (str == "softmax") {
83
- a.loss = fasttext::loss_name::softmax;
84
- } else if (str == "ns") {
85
- a.loss = fasttext::loss_name::ns;
86
- } else if (str == "hs") {
87
- a.loss = fasttext::loss_name::hs;
88
- } else if (str == "ova") {
89
- a.loss = fasttext::loss_name::ova;
90
- } else {
91
- throw std::invalid_argument("Unknown loss: " + str);
92
- }
93
- } else if (name == "model") {
94
- std::string str = Rice::detail::From_Ruby<std::string>().convert(value);
95
- if (str == "supervised") {
96
- a.model = fasttext::model_name::sup;
97
- } else if (str == "skipgram") {
98
- a.model = fasttext::model_name::sg;
99
- } else if (str == "cbow") {
100
- a.model = fasttext::model_name::cbow;
101
- } else {
102
- throw std::invalid_argument("Unknown model: " + str);
103
- }
104
- } else if (name == "bucket") {
105
- a.bucket = Rice::detail::From_Ruby<int>().convert(value);
106
- } else if (name == "minn") {
107
- a.minn = Rice::detail::From_Ruby<int>().convert(value);
108
- } else if (name == "maxn") {
109
- a.maxn = Rice::detail::From_Ruby<int>().convert(value);
110
- } else if (name == "thread") {
111
- a.thread = Rice::detail::From_Ruby<int>().convert(value);
112
- } else if (name == "t") {
113
- a.t = Rice::detail::From_Ruby<double>().convert(value);
114
- } else if (name == "label_prefix") {
115
- a.label = Rice::detail::From_Ruby<std::string>().convert(value);
116
- } else if (name == "verbose") {
117
- a.verbose = Rice::detail::From_Ruby<int>().convert(value);
118
- } else if (name == "pretrained_vectors") {
119
- a.pretrainedVectors = Rice::detail::From_Ruby<std::string>().convert(value);
120
- } else if (name == "save_output") {
121
- a.saveOutput = Rice::detail::From_Ruby<bool>().convert(value);
122
- } else if (name == "seed") {
123
- a.seed = Rice::detail::From_Ruby<int>().convert(value);
124
- } else if (name == "autotune_validation_file") {
125
- a.autotuneValidationFile = Rice::detail::From_Ruby<std::string>().convert(value);
126
- } else if (name == "autotune_metric") {
127
- a.autotuneMetric = Rice::detail::From_Ruby<std::string>().convert(value);
128
- } else if (name == "autotune_predictions") {
129
- a.autotunePredictions = Rice::detail::From_Ruby<int>().convert(value);
130
- } else if (name == "autotune_duration") {
131
- a.autotuneDuration = Rice::detail::From_Ruby<int>().convert(value);
132
- } else if (name == "autotune_model_size") {
133
- a.autotuneModelSize = Rice::detail::From_Ruby<std::string>().convert(value);
134
- } else {
135
- throw std::invalid_argument("Unknown argument: " + name);
136
- }
137
- }
138
- return a;
139
- }
140
-
141
49
  extern "C"
142
50
  void Init_ext()
143
51
  {
144
52
  Module rb_mFastText = define_module("FastText");
145
53
  Module rb_mExt = define_module_under(rb_mFastText, "Ext");
146
54
 
55
+ define_class_under<Args>(rb_mExt, "Args")
56
+ .define_constructor(Constructor<Args>())
57
+ .define_attr("input", &Args::input)
58
+ .define_attr("output", &Args::output)
59
+ .define_attr("lr", &Args::lr)
60
+ .define_attr("lr_update_rate", &Args::lrUpdateRate)
61
+ .define_attr("dim", &Args::dim)
62
+ .define_attr("ws", &Args::ws)
63
+ .define_attr("epoch", &Args::epoch)
64
+ .define_attr("min_count", &Args::minCount)
65
+ .define_attr("min_count_label", &Args::minCountLabel)
66
+ .define_attr("neg", &Args::neg)
67
+ .define_attr("word_ngrams", &Args::wordNgrams)
68
+ .define_method(
69
+ "loss=",
70
+ [](Args& a, const std::string& str) {
71
+ if (str == "softmax") {
72
+ a.loss = fasttext::loss_name::softmax;
73
+ } else if (str == "ns") {
74
+ a.loss = fasttext::loss_name::ns;
75
+ } else if (str == "hs") {
76
+ a.loss = fasttext::loss_name::hs;
77
+ } else if (str == "ova") {
78
+ a.loss = fasttext::loss_name::ova;
79
+ } else {
80
+ throw std::invalid_argument("Unknown loss: " + str);
81
+ }
82
+ })
83
+ .define_method(
84
+ "model=",
85
+ [](Args& a, const std::string& str) {
86
+ if (str == "supervised") {
87
+ a.model = fasttext::model_name::sup;
88
+ } else if (str == "skipgram") {
89
+ a.model = fasttext::model_name::sg;
90
+ } else if (str == "cbow") {
91
+ a.model = fasttext::model_name::cbow;
92
+ } else {
93
+ throw std::invalid_argument("Unknown model: " + str);
94
+ }
95
+ })
96
+ .define_attr("bucket", &Args::bucket)
97
+ .define_attr("minn", &Args::minn)
98
+ .define_attr("maxn", &Args::maxn)
99
+ .define_attr("thread", &Args::thread)
100
+ .define_attr("t", &Args::t)
101
+ .define_attr("label_prefix", &Args::label)
102
+ .define_attr("verbose", &Args::verbose)
103
+ .define_attr("pretrained_vectors", &Args::pretrainedVectors)
104
+ .define_attr("save_output", &Args::saveOutput)
105
+ .define_attr("seed", &Args::seed)
106
+ .define_attr("autotune_validation_file", &Args::autotuneValidationFile)
107
+ .define_attr("autotune_metric", &Args::autotuneMetric)
108
+ .define_attr("autotune_predictions", &Args::autotunePredictions)
109
+ .define_attr("autotune_duration", &Args::autotuneDuration)
110
+ .define_attr("autotune_model_size", &Args::autotuneModelSize);
111
+
147
112
  define_class_under<FastText>(rb_mExt, "Model")
148
113
  .define_constructor(Constructor<FastText>())
149
114
  .define_method(
@@ -231,13 +196,12 @@ void Init_ext()
231
196
  .define_method(
232
197
  "word_vector",
233
198
  [](FastText& m, const std::string& word) {
234
- int dimension = m.getDimension();
199
+ auto dimension = m.getDimension();
235
200
  fasttext::Vector vec = fasttext::Vector(dimension);
236
201
  m.getWordVector(vec, word);
237
- float* data = vec.data();
238
202
  Array ret;
239
- for (int i = 0; i < dimension; i++) {
240
- ret.push(data[i]);
203
+ for (size_t i = 0; i < vec.size(); i++) {
204
+ ret.push(vec[i]);
241
205
  }
242
206
  return ret;
243
207
  })
@@ -259,20 +223,18 @@ void Init_ext()
259
223
  "sentence_vector",
260
224
  [](FastText& m, const std::string& text) {
261
225
  std::istringstream in(text);
262
- int dimension = m.getDimension();
226
+ auto dimension = m.getDimension();
263
227
  fasttext::Vector vec = fasttext::Vector(dimension);
264
228
  m.getSentenceVector(in, vec);
265
- float* data = vec.data();
266
229
  Array ret;
267
- for (int i = 0; i < dimension; i++) {
268
- ret.push(data[i]);
230
+ for (size_t i = 0; i < vec.size(); i++) {
231
+ ret.push(vec[i]);
269
232
  }
270
233
  return ret;
271
234
  })
272
235
  .define_method(
273
236
  "train",
274
- [](FastText& m, Hash h) {
275
- auto a = buildArgs(h);
237
+ [](FastText& m, Args& a) {
276
238
  if (a.hasAutotune()) {
277
239
  fasttext::Autotune autotune(std::shared_ptr<fasttext::FastText>(&m, [](fasttext::FastText*) {}));
278
240
  autotune.train(a);
@@ -282,8 +244,8 @@ void Init_ext()
282
244
  })
283
245
  .define_method(
284
246
  "quantize",
285
- [](FastText& m, Hash h) {
286
- m.quantize(buildArgs(h));
247
+ [](FastText& m, Args& a) {
248
+ m.quantize(a);
287
249
  })
288
250
  .define_method(
289
251
  "supervised?",
@@ -30,14 +30,16 @@ module FastText
30
30
  }
31
31
 
32
32
  def fit(x, y = nil, autotune_set: nil)
33
- input = input_path(x, y)
33
+ input, _ref = input_path(x, y)
34
34
  @m ||= Ext::Model.new
35
- opts = DEFAULT_OPTIONS.merge(@options).merge(input: input, model: "supervised")
35
+ a = build_args(DEFAULT_OPTIONS)
36
+ a.input = input
37
+ a.model = "supervised"
36
38
  if autotune_set
37
39
  x, y = autotune_set
38
- opts.merge!(autotune_validation_file: input_path(x, y))
40
+ a.autotune_validation_file, _autotune_ref = input_path(x, y)
39
41
  end
40
- m.train(opts)
42
+ m.train(a)
41
43
  end
42
44
 
43
45
  def predict(text, k: 1, threshold: 0.0)
@@ -47,16 +49,16 @@ module FastText
47
49
  # TODO predict multiple in C++ for performance
48
50
  result =
49
51
  text.map do |t|
50
- m.predict(prep_text(t), k, threshold).map do |v|
52
+ m.predict(prep_text(t), k, threshold).to_h do |v|
51
53
  [remove_prefix(v[1]), v[0]]
52
- end.to_h
54
+ end
53
55
  end
54
56
 
55
57
  multiple ? result : result.first
56
58
  end
57
59
 
58
60
  def test(x, y = nil, k: 1)
59
- input = input_path(x, y)
61
+ input, _ref = input_path(x, y)
60
62
  res = m.test(input, k)
61
63
  {
62
64
  examples: res[0],
@@ -67,7 +69,8 @@ module FastText
67
69
 
68
70
  # TODO support options
69
71
  def quantize
70
- m.quantize({})
72
+ a = Ext::Args.new
73
+ m.quantize(a)
71
74
  end
72
75
 
73
76
  def labels(include_freq: false)
@@ -85,7 +88,7 @@ module FastText
85
88
  def input_path(x, y)
86
89
  if x.is_a?(String)
87
90
  raise ArgumentError, "Cannot pass y with file" if y
88
- x
91
+ [x, nil]
89
92
  else
90
93
  tempfile = Tempfile.new("fasttext")
91
94
  x.zip(y) do |xi, yi|
@@ -95,7 +98,7 @@ module FastText
95
98
  tempfile.write("\n")
96
99
  end
97
100
  tempfile.close
98
- tempfile.path
101
+ [tempfile.path, tempfile]
99
102
  end
100
103
  end
101
104
 
@@ -56,5 +56,15 @@ module FastText
56
56
  def m
57
57
  @m || (raise Error, "Not fit")
58
58
  end
59
+
60
+ def build_args(default_options)
61
+ a = Ext::Args.new
62
+ opts = @options.dup
63
+ default_options.each do |k, v|
64
+ a.send("#{k}=", opts.delete(k) || v)
65
+ end
66
+ raise ArgumentError, "Unknown argument: #{opts.keys.first}" if opts.any?
67
+ a
68
+ end
59
69
  end
60
70
  end
@@ -29,9 +29,10 @@ module FastText
29
29
  }
30
30
 
31
31
  def fit(x)
32
- input = input_path(x)
33
32
  @m ||= Ext::Model.new
34
- m.train(DEFAULT_OPTIONS.merge(@options).merge(input: input))
33
+ a = build_args(DEFAULT_OPTIONS)
34
+ a.input, _ref = input_path(x)
35
+ m.train(a)
35
36
  end
36
37
 
37
38
  def nearest_neighbors(word, k: 10)
@@ -48,7 +49,7 @@ module FastText
48
49
  # https://github.com/facebookresearch/fastText/issues/518
49
50
  def input_path(x)
50
51
  if x.is_a?(String)
51
- x
52
+ [x, nil]
52
53
  else
53
54
  tempfile = Tempfile.new("fasttext")
54
55
  x.each do |xi|
@@ -56,7 +57,7 @@ module FastText
56
57
  tempfile.write("\n")
57
58
  end
58
59
  tempfile.close
59
- tempfile.path
60
+ [tempfile.path, tempfile]
60
61
  end
61
62
  end
62
63
  end
@@ -1,3 +1,3 @@
1
1
  module FastText
2
- VERSION = "0.2.1"
2
+ VERSION = "0.2.2"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: fasttext
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.1
4
+ version: 0.2.2
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2021-05-23 00:00:00.000000000 Z
11
+ date: 2021-10-16 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -90,7 +90,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
90
90
  - !ruby/object:Gem::Version
91
91
  version: '0'
92
92
  requirements: []
93
- rubygems_version: 3.2.3
93
+ rubygems_version: 3.2.22
94
94
  signing_key:
95
95
  specification_version: 4
96
96
  summary: fastText - efficient text classification and representation learning - for