fasttext 0.2.1 → 0.2.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/ext/fasttext/ext.cpp +67 -105
- data/lib/fasttext/classifier.rb +13 -10
- data/lib/fasttext/model.rb +10 -0
- data/lib/fasttext/vectorizer.rb +5 -4
- data/lib/fasttext/version.rb +1 -1
- metadata +3 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 9aae7e20933f51ebebd802276d7e006ce792fcfc6dd94ea8ceda887ab0f4eca8
|
4
|
+
data.tar.gz: a90a7bbfffe424829052afc3519bc0473a7fc1b22a2d1ea0c8786d9372a0cdb5
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 3042a798560e5960d18d8bcefb66833bf621bc1dcf26f77742602d0d4e7b1e4cc33b078f9225c209ac9a0d2f89c9b3c98ad2446582af4a3ac4effc761192c58d
|
7
|
+
data.tar.gz: d70ce1005916a809a78b02f23232b1efd7cc59dc9d3ed2df7929f946db2f326c0df40729b2ea953846e906c2774d6b6e971650f3d585da65ae6562bdcb9a9b5e
|
data/CHANGELOG.md
CHANGED
data/ext/fasttext/ext.cpp
CHANGED
@@ -16,13 +16,12 @@
|
|
16
16
|
#include <rice/rice.hpp>
|
17
17
|
#include <rice/stl.hpp>
|
18
18
|
|
19
|
+
using fasttext::Args;
|
19
20
|
using fasttext::FastText;
|
20
21
|
|
21
22
|
using Rice::Array;
|
22
23
|
using Rice::Constructor;
|
23
|
-
using Rice::Hash;
|
24
24
|
using Rice::Module;
|
25
|
-
using Rice::Object;
|
26
25
|
using Rice::define_class_under;
|
27
26
|
using Rice::define_module;
|
28
27
|
using Rice::define_module_under;
|
@@ -47,103 +46,69 @@ namespace Rice::detail
|
|
47
46
|
};
|
48
47
|
}
|
49
48
|
|
50
|
-
fasttext::Args buildArgs(Hash h) {
|
51
|
-
fasttext::Args a;
|
52
|
-
|
53
|
-
for (const auto& it : h)
|
54
|
-
{
|
55
|
-
auto name = it.key.to_s().str();
|
56
|
-
auto value = (it.value).value();
|
57
|
-
|
58
|
-
if (name == "input") {
|
59
|
-
a.input = Rice::detail::From_Ruby<std::string>().convert(value);
|
60
|
-
} else if (name == "output") {
|
61
|
-
a.output = Rice::detail::From_Ruby<std::string>().convert(value);
|
62
|
-
} else if (name == "lr") {
|
63
|
-
a.lr = Rice::detail::From_Ruby<double>().convert(value);
|
64
|
-
} else if (name == "lr_update_rate") {
|
65
|
-
a.lrUpdateRate = Rice::detail::From_Ruby<int>().convert(value);
|
66
|
-
} else if (name == "dim") {
|
67
|
-
a.dim = Rice::detail::From_Ruby<int>().convert(value);
|
68
|
-
} else if (name == "ws") {
|
69
|
-
a.ws = Rice::detail::From_Ruby<int>().convert(value);
|
70
|
-
} else if (name == "epoch") {
|
71
|
-
a.epoch = Rice::detail::From_Ruby<int>().convert(value);
|
72
|
-
} else if (name == "min_count") {
|
73
|
-
a.minCount = Rice::detail::From_Ruby<int>().convert(value);
|
74
|
-
} else if (name == "min_count_label") {
|
75
|
-
a.minCountLabel = Rice::detail::From_Ruby<int>().convert(value);
|
76
|
-
} else if (name == "neg") {
|
77
|
-
a.neg = Rice::detail::From_Ruby<int>().convert(value);
|
78
|
-
} else if (name == "word_ngrams") {
|
79
|
-
a.wordNgrams = Rice::detail::From_Ruby<int>().convert(value);
|
80
|
-
} else if (name == "loss") {
|
81
|
-
std::string str = Rice::detail::From_Ruby<std::string>().convert(value);
|
82
|
-
if (str == "softmax") {
|
83
|
-
a.loss = fasttext::loss_name::softmax;
|
84
|
-
} else if (str == "ns") {
|
85
|
-
a.loss = fasttext::loss_name::ns;
|
86
|
-
} else if (str == "hs") {
|
87
|
-
a.loss = fasttext::loss_name::hs;
|
88
|
-
} else if (str == "ova") {
|
89
|
-
a.loss = fasttext::loss_name::ova;
|
90
|
-
} else {
|
91
|
-
throw std::invalid_argument("Unknown loss: " + str);
|
92
|
-
}
|
93
|
-
} else if (name == "model") {
|
94
|
-
std::string str = Rice::detail::From_Ruby<std::string>().convert(value);
|
95
|
-
if (str == "supervised") {
|
96
|
-
a.model = fasttext::model_name::sup;
|
97
|
-
} else if (str == "skipgram") {
|
98
|
-
a.model = fasttext::model_name::sg;
|
99
|
-
} else if (str == "cbow") {
|
100
|
-
a.model = fasttext::model_name::cbow;
|
101
|
-
} else {
|
102
|
-
throw std::invalid_argument("Unknown model: " + str);
|
103
|
-
}
|
104
|
-
} else if (name == "bucket") {
|
105
|
-
a.bucket = Rice::detail::From_Ruby<int>().convert(value);
|
106
|
-
} else if (name == "minn") {
|
107
|
-
a.minn = Rice::detail::From_Ruby<int>().convert(value);
|
108
|
-
} else if (name == "maxn") {
|
109
|
-
a.maxn = Rice::detail::From_Ruby<int>().convert(value);
|
110
|
-
} else if (name == "thread") {
|
111
|
-
a.thread = Rice::detail::From_Ruby<int>().convert(value);
|
112
|
-
} else if (name == "t") {
|
113
|
-
a.t = Rice::detail::From_Ruby<double>().convert(value);
|
114
|
-
} else if (name == "label_prefix") {
|
115
|
-
a.label = Rice::detail::From_Ruby<std::string>().convert(value);
|
116
|
-
} else if (name == "verbose") {
|
117
|
-
a.verbose = Rice::detail::From_Ruby<int>().convert(value);
|
118
|
-
} else if (name == "pretrained_vectors") {
|
119
|
-
a.pretrainedVectors = Rice::detail::From_Ruby<std::string>().convert(value);
|
120
|
-
} else if (name == "save_output") {
|
121
|
-
a.saveOutput = Rice::detail::From_Ruby<bool>().convert(value);
|
122
|
-
} else if (name == "seed") {
|
123
|
-
a.seed = Rice::detail::From_Ruby<int>().convert(value);
|
124
|
-
} else if (name == "autotune_validation_file") {
|
125
|
-
a.autotuneValidationFile = Rice::detail::From_Ruby<std::string>().convert(value);
|
126
|
-
} else if (name == "autotune_metric") {
|
127
|
-
a.autotuneMetric = Rice::detail::From_Ruby<std::string>().convert(value);
|
128
|
-
} else if (name == "autotune_predictions") {
|
129
|
-
a.autotunePredictions = Rice::detail::From_Ruby<int>().convert(value);
|
130
|
-
} else if (name == "autotune_duration") {
|
131
|
-
a.autotuneDuration = Rice::detail::From_Ruby<int>().convert(value);
|
132
|
-
} else if (name == "autotune_model_size") {
|
133
|
-
a.autotuneModelSize = Rice::detail::From_Ruby<std::string>().convert(value);
|
134
|
-
} else {
|
135
|
-
throw std::invalid_argument("Unknown argument: " + name);
|
136
|
-
}
|
137
|
-
}
|
138
|
-
return a;
|
139
|
-
}
|
140
|
-
|
141
49
|
extern "C"
|
142
50
|
void Init_ext()
|
143
51
|
{
|
144
52
|
Module rb_mFastText = define_module("FastText");
|
145
53
|
Module rb_mExt = define_module_under(rb_mFastText, "Ext");
|
146
54
|
|
55
|
+
define_class_under<Args>(rb_mExt, "Args")
|
56
|
+
.define_constructor(Constructor<Args>())
|
57
|
+
.define_attr("input", &Args::input)
|
58
|
+
.define_attr("output", &Args::output)
|
59
|
+
.define_attr("lr", &Args::lr)
|
60
|
+
.define_attr("lr_update_rate", &Args::lrUpdateRate)
|
61
|
+
.define_attr("dim", &Args::dim)
|
62
|
+
.define_attr("ws", &Args::ws)
|
63
|
+
.define_attr("epoch", &Args::epoch)
|
64
|
+
.define_attr("min_count", &Args::minCount)
|
65
|
+
.define_attr("min_count_label", &Args::minCountLabel)
|
66
|
+
.define_attr("neg", &Args::neg)
|
67
|
+
.define_attr("word_ngrams", &Args::wordNgrams)
|
68
|
+
.define_method(
|
69
|
+
"loss=",
|
70
|
+
[](Args& a, const std::string& str) {
|
71
|
+
if (str == "softmax") {
|
72
|
+
a.loss = fasttext::loss_name::softmax;
|
73
|
+
} else if (str == "ns") {
|
74
|
+
a.loss = fasttext::loss_name::ns;
|
75
|
+
} else if (str == "hs") {
|
76
|
+
a.loss = fasttext::loss_name::hs;
|
77
|
+
} else if (str == "ova") {
|
78
|
+
a.loss = fasttext::loss_name::ova;
|
79
|
+
} else {
|
80
|
+
throw std::invalid_argument("Unknown loss: " + str);
|
81
|
+
}
|
82
|
+
})
|
83
|
+
.define_method(
|
84
|
+
"model=",
|
85
|
+
[](Args& a, const std::string& str) {
|
86
|
+
if (str == "supervised") {
|
87
|
+
a.model = fasttext::model_name::sup;
|
88
|
+
} else if (str == "skipgram") {
|
89
|
+
a.model = fasttext::model_name::sg;
|
90
|
+
} else if (str == "cbow") {
|
91
|
+
a.model = fasttext::model_name::cbow;
|
92
|
+
} else {
|
93
|
+
throw std::invalid_argument("Unknown model: " + str);
|
94
|
+
}
|
95
|
+
})
|
96
|
+
.define_attr("bucket", &Args::bucket)
|
97
|
+
.define_attr("minn", &Args::minn)
|
98
|
+
.define_attr("maxn", &Args::maxn)
|
99
|
+
.define_attr("thread", &Args::thread)
|
100
|
+
.define_attr("t", &Args::t)
|
101
|
+
.define_attr("label_prefix", &Args::label)
|
102
|
+
.define_attr("verbose", &Args::verbose)
|
103
|
+
.define_attr("pretrained_vectors", &Args::pretrainedVectors)
|
104
|
+
.define_attr("save_output", &Args::saveOutput)
|
105
|
+
.define_attr("seed", &Args::seed)
|
106
|
+
.define_attr("autotune_validation_file", &Args::autotuneValidationFile)
|
107
|
+
.define_attr("autotune_metric", &Args::autotuneMetric)
|
108
|
+
.define_attr("autotune_predictions", &Args::autotunePredictions)
|
109
|
+
.define_attr("autotune_duration", &Args::autotuneDuration)
|
110
|
+
.define_attr("autotune_model_size", &Args::autotuneModelSize);
|
111
|
+
|
147
112
|
define_class_under<FastText>(rb_mExt, "Model")
|
148
113
|
.define_constructor(Constructor<FastText>())
|
149
114
|
.define_method(
|
@@ -231,13 +196,12 @@ void Init_ext()
|
|
231
196
|
.define_method(
|
232
197
|
"word_vector",
|
233
198
|
[](FastText& m, const std::string& word) {
|
234
|
-
|
199
|
+
auto dimension = m.getDimension();
|
235
200
|
fasttext::Vector vec = fasttext::Vector(dimension);
|
236
201
|
m.getWordVector(vec, word);
|
237
|
-
float* data = vec.data();
|
238
202
|
Array ret;
|
239
|
-
for (
|
240
|
-
ret.push(
|
203
|
+
for (size_t i = 0; i < vec.size(); i++) {
|
204
|
+
ret.push(vec[i]);
|
241
205
|
}
|
242
206
|
return ret;
|
243
207
|
})
|
@@ -259,20 +223,18 @@ void Init_ext()
|
|
259
223
|
"sentence_vector",
|
260
224
|
[](FastText& m, const std::string& text) {
|
261
225
|
std::istringstream in(text);
|
262
|
-
|
226
|
+
auto dimension = m.getDimension();
|
263
227
|
fasttext::Vector vec = fasttext::Vector(dimension);
|
264
228
|
m.getSentenceVector(in, vec);
|
265
|
-
float* data = vec.data();
|
266
229
|
Array ret;
|
267
|
-
for (
|
268
|
-
ret.push(
|
230
|
+
for (size_t i = 0; i < vec.size(); i++) {
|
231
|
+
ret.push(vec[i]);
|
269
232
|
}
|
270
233
|
return ret;
|
271
234
|
})
|
272
235
|
.define_method(
|
273
236
|
"train",
|
274
|
-
[](FastText& m,
|
275
|
-
auto a = buildArgs(h);
|
237
|
+
[](FastText& m, Args& a) {
|
276
238
|
if (a.hasAutotune()) {
|
277
239
|
fasttext::Autotune autotune(std::shared_ptr<fasttext::FastText>(&m, [](fasttext::FastText*) {}));
|
278
240
|
autotune.train(a);
|
@@ -282,8 +244,8 @@ void Init_ext()
|
|
282
244
|
})
|
283
245
|
.define_method(
|
284
246
|
"quantize",
|
285
|
-
[](FastText& m,
|
286
|
-
m.quantize(
|
247
|
+
[](FastText& m, Args& a) {
|
248
|
+
m.quantize(a);
|
287
249
|
})
|
288
250
|
.define_method(
|
289
251
|
"supervised?",
|
data/lib/fasttext/classifier.rb
CHANGED
@@ -30,14 +30,16 @@ module FastText
|
|
30
30
|
}
|
31
31
|
|
32
32
|
def fit(x, y = nil, autotune_set: nil)
|
33
|
-
input = input_path(x, y)
|
33
|
+
input, _ref = input_path(x, y)
|
34
34
|
@m ||= Ext::Model.new
|
35
|
-
|
35
|
+
a = build_args(DEFAULT_OPTIONS)
|
36
|
+
a.input = input
|
37
|
+
a.model = "supervised"
|
36
38
|
if autotune_set
|
37
39
|
x, y = autotune_set
|
38
|
-
|
40
|
+
a.autotune_validation_file, _autotune_ref = input_path(x, y)
|
39
41
|
end
|
40
|
-
m.train(
|
42
|
+
m.train(a)
|
41
43
|
end
|
42
44
|
|
43
45
|
def predict(text, k: 1, threshold: 0.0)
|
@@ -47,16 +49,16 @@ module FastText
|
|
47
49
|
# TODO predict multiple in C++ for performance
|
48
50
|
result =
|
49
51
|
text.map do |t|
|
50
|
-
m.predict(prep_text(t), k, threshold).
|
52
|
+
m.predict(prep_text(t), k, threshold).to_h do |v|
|
51
53
|
[remove_prefix(v[1]), v[0]]
|
52
|
-
end
|
54
|
+
end
|
53
55
|
end
|
54
56
|
|
55
57
|
multiple ? result : result.first
|
56
58
|
end
|
57
59
|
|
58
60
|
def test(x, y = nil, k: 1)
|
59
|
-
input = input_path(x, y)
|
61
|
+
input, _ref = input_path(x, y)
|
60
62
|
res = m.test(input, k)
|
61
63
|
{
|
62
64
|
examples: res[0],
|
@@ -67,7 +69,8 @@ module FastText
|
|
67
69
|
|
68
70
|
# TODO support options
|
69
71
|
def quantize
|
70
|
-
|
72
|
+
a = Ext::Args.new
|
73
|
+
m.quantize(a)
|
71
74
|
end
|
72
75
|
|
73
76
|
def labels(include_freq: false)
|
@@ -85,7 +88,7 @@ module FastText
|
|
85
88
|
def input_path(x, y)
|
86
89
|
if x.is_a?(String)
|
87
90
|
raise ArgumentError, "Cannot pass y with file" if y
|
88
|
-
x
|
91
|
+
[x, nil]
|
89
92
|
else
|
90
93
|
tempfile = Tempfile.new("fasttext")
|
91
94
|
x.zip(y) do |xi, yi|
|
@@ -95,7 +98,7 @@ module FastText
|
|
95
98
|
tempfile.write("\n")
|
96
99
|
end
|
97
100
|
tempfile.close
|
98
|
-
tempfile.path
|
101
|
+
[tempfile.path, tempfile]
|
99
102
|
end
|
100
103
|
end
|
101
104
|
|
data/lib/fasttext/model.rb
CHANGED
@@ -56,5 +56,15 @@ module FastText
|
|
56
56
|
def m
|
57
57
|
@m || (raise Error, "Not fit")
|
58
58
|
end
|
59
|
+
|
60
|
+
def build_args(default_options)
|
61
|
+
a = Ext::Args.new
|
62
|
+
opts = @options.dup
|
63
|
+
default_options.each do |k, v|
|
64
|
+
a.send("#{k}=", opts.delete(k) || v)
|
65
|
+
end
|
66
|
+
raise ArgumentError, "Unknown argument: #{opts.keys.first}" if opts.any?
|
67
|
+
a
|
68
|
+
end
|
59
69
|
end
|
60
70
|
end
|
data/lib/fasttext/vectorizer.rb
CHANGED
@@ -29,9 +29,10 @@ module FastText
|
|
29
29
|
}
|
30
30
|
|
31
31
|
def fit(x)
|
32
|
-
input = input_path(x)
|
33
32
|
@m ||= Ext::Model.new
|
34
|
-
|
33
|
+
a = build_args(DEFAULT_OPTIONS)
|
34
|
+
a.input, _ref = input_path(x)
|
35
|
+
m.train(a)
|
35
36
|
end
|
36
37
|
|
37
38
|
def nearest_neighbors(word, k: 10)
|
@@ -48,7 +49,7 @@ module FastText
|
|
48
49
|
# https://github.com/facebookresearch/fastText/issues/518
|
49
50
|
def input_path(x)
|
50
51
|
if x.is_a?(String)
|
51
|
-
x
|
52
|
+
[x, nil]
|
52
53
|
else
|
53
54
|
tempfile = Tempfile.new("fasttext")
|
54
55
|
x.each do |xi|
|
@@ -56,7 +57,7 @@ module FastText
|
|
56
57
|
tempfile.write("\n")
|
57
58
|
end
|
58
59
|
tempfile.close
|
59
|
-
tempfile.path
|
60
|
+
[tempfile.path, tempfile]
|
60
61
|
end
|
61
62
|
end
|
62
63
|
end
|
data/lib/fasttext/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: fasttext
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.2.
|
4
|
+
version: 0.2.2
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2021-
|
11
|
+
date: 2021-10-16 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -90,7 +90,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
90
90
|
- !ruby/object:Gem::Version
|
91
91
|
version: '0'
|
92
92
|
requirements: []
|
93
|
-
rubygems_version: 3.2.
|
93
|
+
rubygems_version: 3.2.22
|
94
94
|
signing_key:
|
95
95
|
specification_version: 4
|
96
96
|
summary: fastText - efficient text classification and representation learning - for
|