fasttext 0.2.1 → 0.2.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/ext/fasttext/ext.cpp +67 -105
- data/lib/fasttext/classifier.rb +13 -10
- data/lib/fasttext/model.rb +10 -0
- data/lib/fasttext/vectorizer.rb +5 -4
- data/lib/fasttext/version.rb +1 -1
- metadata +3 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 9aae7e20933f51ebebd802276d7e006ce792fcfc6dd94ea8ceda887ab0f4eca8
|
4
|
+
data.tar.gz: a90a7bbfffe424829052afc3519bc0473a7fc1b22a2d1ea0c8786d9372a0cdb5
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 3042a798560e5960d18d8bcefb66833bf621bc1dcf26f77742602d0d4e7b1e4cc33b078f9225c209ac9a0d2f89c9b3c98ad2446582af4a3ac4effc761192c58d
|
7
|
+
data.tar.gz: d70ce1005916a809a78b02f23232b1efd7cc59dc9d3ed2df7929f946db2f326c0df40729b2ea953846e906c2774d6b6e971650f3d585da65ae6562bdcb9a9b5e
|
data/CHANGELOG.md
CHANGED
data/ext/fasttext/ext.cpp
CHANGED
@@ -16,13 +16,12 @@
|
|
16
16
|
#include <rice/rice.hpp>
|
17
17
|
#include <rice/stl.hpp>
|
18
18
|
|
19
|
+
using fasttext::Args;
|
19
20
|
using fasttext::FastText;
|
20
21
|
|
21
22
|
using Rice::Array;
|
22
23
|
using Rice::Constructor;
|
23
|
-
using Rice::Hash;
|
24
24
|
using Rice::Module;
|
25
|
-
using Rice::Object;
|
26
25
|
using Rice::define_class_under;
|
27
26
|
using Rice::define_module;
|
28
27
|
using Rice::define_module_under;
|
@@ -47,103 +46,69 @@ namespace Rice::detail
|
|
47
46
|
};
|
48
47
|
}
|
49
48
|
|
50
|
-
fasttext::Args buildArgs(Hash h) {
|
51
|
-
fasttext::Args a;
|
52
|
-
|
53
|
-
for (const auto& it : h)
|
54
|
-
{
|
55
|
-
auto name = it.key.to_s().str();
|
56
|
-
auto value = (it.value).value();
|
57
|
-
|
58
|
-
if (name == "input") {
|
59
|
-
a.input = Rice::detail::From_Ruby<std::string>().convert(value);
|
60
|
-
} else if (name == "output") {
|
61
|
-
a.output = Rice::detail::From_Ruby<std::string>().convert(value);
|
62
|
-
} else if (name == "lr") {
|
63
|
-
a.lr = Rice::detail::From_Ruby<double>().convert(value);
|
64
|
-
} else if (name == "lr_update_rate") {
|
65
|
-
a.lrUpdateRate = Rice::detail::From_Ruby<int>().convert(value);
|
66
|
-
} else if (name == "dim") {
|
67
|
-
a.dim = Rice::detail::From_Ruby<int>().convert(value);
|
68
|
-
} else if (name == "ws") {
|
69
|
-
a.ws = Rice::detail::From_Ruby<int>().convert(value);
|
70
|
-
} else if (name == "epoch") {
|
71
|
-
a.epoch = Rice::detail::From_Ruby<int>().convert(value);
|
72
|
-
} else if (name == "min_count") {
|
73
|
-
a.minCount = Rice::detail::From_Ruby<int>().convert(value);
|
74
|
-
} else if (name == "min_count_label") {
|
75
|
-
a.minCountLabel = Rice::detail::From_Ruby<int>().convert(value);
|
76
|
-
} else if (name == "neg") {
|
77
|
-
a.neg = Rice::detail::From_Ruby<int>().convert(value);
|
78
|
-
} else if (name == "word_ngrams") {
|
79
|
-
a.wordNgrams = Rice::detail::From_Ruby<int>().convert(value);
|
80
|
-
} else if (name == "loss") {
|
81
|
-
std::string str = Rice::detail::From_Ruby<std::string>().convert(value);
|
82
|
-
if (str == "softmax") {
|
83
|
-
a.loss = fasttext::loss_name::softmax;
|
84
|
-
} else if (str == "ns") {
|
85
|
-
a.loss = fasttext::loss_name::ns;
|
86
|
-
} else if (str == "hs") {
|
87
|
-
a.loss = fasttext::loss_name::hs;
|
88
|
-
} else if (str == "ova") {
|
89
|
-
a.loss = fasttext::loss_name::ova;
|
90
|
-
} else {
|
91
|
-
throw std::invalid_argument("Unknown loss: " + str);
|
92
|
-
}
|
93
|
-
} else if (name == "model") {
|
94
|
-
std::string str = Rice::detail::From_Ruby<std::string>().convert(value);
|
95
|
-
if (str == "supervised") {
|
96
|
-
a.model = fasttext::model_name::sup;
|
97
|
-
} else if (str == "skipgram") {
|
98
|
-
a.model = fasttext::model_name::sg;
|
99
|
-
} else if (str == "cbow") {
|
100
|
-
a.model = fasttext::model_name::cbow;
|
101
|
-
} else {
|
102
|
-
throw std::invalid_argument("Unknown model: " + str);
|
103
|
-
}
|
104
|
-
} else if (name == "bucket") {
|
105
|
-
a.bucket = Rice::detail::From_Ruby<int>().convert(value);
|
106
|
-
} else if (name == "minn") {
|
107
|
-
a.minn = Rice::detail::From_Ruby<int>().convert(value);
|
108
|
-
} else if (name == "maxn") {
|
109
|
-
a.maxn = Rice::detail::From_Ruby<int>().convert(value);
|
110
|
-
} else if (name == "thread") {
|
111
|
-
a.thread = Rice::detail::From_Ruby<int>().convert(value);
|
112
|
-
} else if (name == "t") {
|
113
|
-
a.t = Rice::detail::From_Ruby<double>().convert(value);
|
114
|
-
} else if (name == "label_prefix") {
|
115
|
-
a.label = Rice::detail::From_Ruby<std::string>().convert(value);
|
116
|
-
} else if (name == "verbose") {
|
117
|
-
a.verbose = Rice::detail::From_Ruby<int>().convert(value);
|
118
|
-
} else if (name == "pretrained_vectors") {
|
119
|
-
a.pretrainedVectors = Rice::detail::From_Ruby<std::string>().convert(value);
|
120
|
-
} else if (name == "save_output") {
|
121
|
-
a.saveOutput = Rice::detail::From_Ruby<bool>().convert(value);
|
122
|
-
} else if (name == "seed") {
|
123
|
-
a.seed = Rice::detail::From_Ruby<int>().convert(value);
|
124
|
-
} else if (name == "autotune_validation_file") {
|
125
|
-
a.autotuneValidationFile = Rice::detail::From_Ruby<std::string>().convert(value);
|
126
|
-
} else if (name == "autotune_metric") {
|
127
|
-
a.autotuneMetric = Rice::detail::From_Ruby<std::string>().convert(value);
|
128
|
-
} else if (name == "autotune_predictions") {
|
129
|
-
a.autotunePredictions = Rice::detail::From_Ruby<int>().convert(value);
|
130
|
-
} else if (name == "autotune_duration") {
|
131
|
-
a.autotuneDuration = Rice::detail::From_Ruby<int>().convert(value);
|
132
|
-
} else if (name == "autotune_model_size") {
|
133
|
-
a.autotuneModelSize = Rice::detail::From_Ruby<std::string>().convert(value);
|
134
|
-
} else {
|
135
|
-
throw std::invalid_argument("Unknown argument: " + name);
|
136
|
-
}
|
137
|
-
}
|
138
|
-
return a;
|
139
|
-
}
|
140
|
-
|
141
49
|
extern "C"
|
142
50
|
void Init_ext()
|
143
51
|
{
|
144
52
|
Module rb_mFastText = define_module("FastText");
|
145
53
|
Module rb_mExt = define_module_under(rb_mFastText, "Ext");
|
146
54
|
|
55
|
+
define_class_under<Args>(rb_mExt, "Args")
|
56
|
+
.define_constructor(Constructor<Args>())
|
57
|
+
.define_attr("input", &Args::input)
|
58
|
+
.define_attr("output", &Args::output)
|
59
|
+
.define_attr("lr", &Args::lr)
|
60
|
+
.define_attr("lr_update_rate", &Args::lrUpdateRate)
|
61
|
+
.define_attr("dim", &Args::dim)
|
62
|
+
.define_attr("ws", &Args::ws)
|
63
|
+
.define_attr("epoch", &Args::epoch)
|
64
|
+
.define_attr("min_count", &Args::minCount)
|
65
|
+
.define_attr("min_count_label", &Args::minCountLabel)
|
66
|
+
.define_attr("neg", &Args::neg)
|
67
|
+
.define_attr("word_ngrams", &Args::wordNgrams)
|
68
|
+
.define_method(
|
69
|
+
"loss=",
|
70
|
+
[](Args& a, const std::string& str) {
|
71
|
+
if (str == "softmax") {
|
72
|
+
a.loss = fasttext::loss_name::softmax;
|
73
|
+
} else if (str == "ns") {
|
74
|
+
a.loss = fasttext::loss_name::ns;
|
75
|
+
} else if (str == "hs") {
|
76
|
+
a.loss = fasttext::loss_name::hs;
|
77
|
+
} else if (str == "ova") {
|
78
|
+
a.loss = fasttext::loss_name::ova;
|
79
|
+
} else {
|
80
|
+
throw std::invalid_argument("Unknown loss: " + str);
|
81
|
+
}
|
82
|
+
})
|
83
|
+
.define_method(
|
84
|
+
"model=",
|
85
|
+
[](Args& a, const std::string& str) {
|
86
|
+
if (str == "supervised") {
|
87
|
+
a.model = fasttext::model_name::sup;
|
88
|
+
} else if (str == "skipgram") {
|
89
|
+
a.model = fasttext::model_name::sg;
|
90
|
+
} else if (str == "cbow") {
|
91
|
+
a.model = fasttext::model_name::cbow;
|
92
|
+
} else {
|
93
|
+
throw std::invalid_argument("Unknown model: " + str);
|
94
|
+
}
|
95
|
+
})
|
96
|
+
.define_attr("bucket", &Args::bucket)
|
97
|
+
.define_attr("minn", &Args::minn)
|
98
|
+
.define_attr("maxn", &Args::maxn)
|
99
|
+
.define_attr("thread", &Args::thread)
|
100
|
+
.define_attr("t", &Args::t)
|
101
|
+
.define_attr("label_prefix", &Args::label)
|
102
|
+
.define_attr("verbose", &Args::verbose)
|
103
|
+
.define_attr("pretrained_vectors", &Args::pretrainedVectors)
|
104
|
+
.define_attr("save_output", &Args::saveOutput)
|
105
|
+
.define_attr("seed", &Args::seed)
|
106
|
+
.define_attr("autotune_validation_file", &Args::autotuneValidationFile)
|
107
|
+
.define_attr("autotune_metric", &Args::autotuneMetric)
|
108
|
+
.define_attr("autotune_predictions", &Args::autotunePredictions)
|
109
|
+
.define_attr("autotune_duration", &Args::autotuneDuration)
|
110
|
+
.define_attr("autotune_model_size", &Args::autotuneModelSize);
|
111
|
+
|
147
112
|
define_class_under<FastText>(rb_mExt, "Model")
|
148
113
|
.define_constructor(Constructor<FastText>())
|
149
114
|
.define_method(
|
@@ -231,13 +196,12 @@ void Init_ext()
|
|
231
196
|
.define_method(
|
232
197
|
"word_vector",
|
233
198
|
[](FastText& m, const std::string& word) {
|
234
|
-
|
199
|
+
auto dimension = m.getDimension();
|
235
200
|
fasttext::Vector vec = fasttext::Vector(dimension);
|
236
201
|
m.getWordVector(vec, word);
|
237
|
-
float* data = vec.data();
|
238
202
|
Array ret;
|
239
|
-
for (
|
240
|
-
ret.push(
|
203
|
+
for (size_t i = 0; i < vec.size(); i++) {
|
204
|
+
ret.push(vec[i]);
|
241
205
|
}
|
242
206
|
return ret;
|
243
207
|
})
|
@@ -259,20 +223,18 @@ void Init_ext()
|
|
259
223
|
"sentence_vector",
|
260
224
|
[](FastText& m, const std::string& text) {
|
261
225
|
std::istringstream in(text);
|
262
|
-
|
226
|
+
auto dimension = m.getDimension();
|
263
227
|
fasttext::Vector vec = fasttext::Vector(dimension);
|
264
228
|
m.getSentenceVector(in, vec);
|
265
|
-
float* data = vec.data();
|
266
229
|
Array ret;
|
267
|
-
for (
|
268
|
-
ret.push(
|
230
|
+
for (size_t i = 0; i < vec.size(); i++) {
|
231
|
+
ret.push(vec[i]);
|
269
232
|
}
|
270
233
|
return ret;
|
271
234
|
})
|
272
235
|
.define_method(
|
273
236
|
"train",
|
274
|
-
[](FastText& m,
|
275
|
-
auto a = buildArgs(h);
|
237
|
+
[](FastText& m, Args& a) {
|
276
238
|
if (a.hasAutotune()) {
|
277
239
|
fasttext::Autotune autotune(std::shared_ptr<fasttext::FastText>(&m, [](fasttext::FastText*) {}));
|
278
240
|
autotune.train(a);
|
@@ -282,8 +244,8 @@ void Init_ext()
|
|
282
244
|
})
|
283
245
|
.define_method(
|
284
246
|
"quantize",
|
285
|
-
[](FastText& m,
|
286
|
-
m.quantize(
|
247
|
+
[](FastText& m, Args& a) {
|
248
|
+
m.quantize(a);
|
287
249
|
})
|
288
250
|
.define_method(
|
289
251
|
"supervised?",
|
data/lib/fasttext/classifier.rb
CHANGED
@@ -30,14 +30,16 @@ module FastText
|
|
30
30
|
}
|
31
31
|
|
32
32
|
def fit(x, y = nil, autotune_set: nil)
|
33
|
-
input = input_path(x, y)
|
33
|
+
input, _ref = input_path(x, y)
|
34
34
|
@m ||= Ext::Model.new
|
35
|
-
|
35
|
+
a = build_args(DEFAULT_OPTIONS)
|
36
|
+
a.input = input
|
37
|
+
a.model = "supervised"
|
36
38
|
if autotune_set
|
37
39
|
x, y = autotune_set
|
38
|
-
|
40
|
+
a.autotune_validation_file, _autotune_ref = input_path(x, y)
|
39
41
|
end
|
40
|
-
m.train(
|
42
|
+
m.train(a)
|
41
43
|
end
|
42
44
|
|
43
45
|
def predict(text, k: 1, threshold: 0.0)
|
@@ -47,16 +49,16 @@ module FastText
|
|
47
49
|
# TODO predict multiple in C++ for performance
|
48
50
|
result =
|
49
51
|
text.map do |t|
|
50
|
-
m.predict(prep_text(t), k, threshold).
|
52
|
+
m.predict(prep_text(t), k, threshold).to_h do |v|
|
51
53
|
[remove_prefix(v[1]), v[0]]
|
52
|
-
end
|
54
|
+
end
|
53
55
|
end
|
54
56
|
|
55
57
|
multiple ? result : result.first
|
56
58
|
end
|
57
59
|
|
58
60
|
def test(x, y = nil, k: 1)
|
59
|
-
input = input_path(x, y)
|
61
|
+
input, _ref = input_path(x, y)
|
60
62
|
res = m.test(input, k)
|
61
63
|
{
|
62
64
|
examples: res[0],
|
@@ -67,7 +69,8 @@ module FastText
|
|
67
69
|
|
68
70
|
# TODO support options
|
69
71
|
def quantize
|
70
|
-
|
72
|
+
a = Ext::Args.new
|
73
|
+
m.quantize(a)
|
71
74
|
end
|
72
75
|
|
73
76
|
def labels(include_freq: false)
|
@@ -85,7 +88,7 @@ module FastText
|
|
85
88
|
def input_path(x, y)
|
86
89
|
if x.is_a?(String)
|
87
90
|
raise ArgumentError, "Cannot pass y with file" if y
|
88
|
-
x
|
91
|
+
[x, nil]
|
89
92
|
else
|
90
93
|
tempfile = Tempfile.new("fasttext")
|
91
94
|
x.zip(y) do |xi, yi|
|
@@ -95,7 +98,7 @@ module FastText
|
|
95
98
|
tempfile.write("\n")
|
96
99
|
end
|
97
100
|
tempfile.close
|
98
|
-
tempfile.path
|
101
|
+
[tempfile.path, tempfile]
|
99
102
|
end
|
100
103
|
end
|
101
104
|
|
data/lib/fasttext/model.rb
CHANGED
@@ -56,5 +56,15 @@ module FastText
|
|
56
56
|
def m
|
57
57
|
@m || (raise Error, "Not fit")
|
58
58
|
end
|
59
|
+
|
60
|
+
def build_args(default_options)
|
61
|
+
a = Ext::Args.new
|
62
|
+
opts = @options.dup
|
63
|
+
default_options.each do |k, v|
|
64
|
+
a.send("#{k}=", opts.delete(k) || v)
|
65
|
+
end
|
66
|
+
raise ArgumentError, "Unknown argument: #{opts.keys.first}" if opts.any?
|
67
|
+
a
|
68
|
+
end
|
59
69
|
end
|
60
70
|
end
|
data/lib/fasttext/vectorizer.rb
CHANGED
@@ -29,9 +29,10 @@ module FastText
|
|
29
29
|
}
|
30
30
|
|
31
31
|
def fit(x)
|
32
|
-
input = input_path(x)
|
33
32
|
@m ||= Ext::Model.new
|
34
|
-
|
33
|
+
a = build_args(DEFAULT_OPTIONS)
|
34
|
+
a.input, _ref = input_path(x)
|
35
|
+
m.train(a)
|
35
36
|
end
|
36
37
|
|
37
38
|
def nearest_neighbors(word, k: 10)
|
@@ -48,7 +49,7 @@ module FastText
|
|
48
49
|
# https://github.com/facebookresearch/fastText/issues/518
|
49
50
|
def input_path(x)
|
50
51
|
if x.is_a?(String)
|
51
|
-
x
|
52
|
+
[x, nil]
|
52
53
|
else
|
53
54
|
tempfile = Tempfile.new("fasttext")
|
54
55
|
x.each do |xi|
|
@@ -56,7 +57,7 @@ module FastText
|
|
56
57
|
tempfile.write("\n")
|
57
58
|
end
|
58
59
|
tempfile.close
|
59
|
-
tempfile.path
|
60
|
+
[tempfile.path, tempfile]
|
60
61
|
end
|
61
62
|
end
|
62
63
|
end
|
data/lib/fasttext/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: fasttext
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.2.
|
4
|
+
version: 0.2.2
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2021-
|
11
|
+
date: 2021-10-16 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -90,7 +90,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
90
90
|
- !ruby/object:Gem::Version
|
91
91
|
version: '0'
|
92
92
|
requirements: []
|
93
|
-
rubygems_version: 3.2.
|
93
|
+
rubygems_version: 3.2.22
|
94
94
|
signing_key:
|
95
95
|
specification_version: 4
|
96
96
|
summary: fastText - efficient text classification and representation learning - for
|