ffi-fasttext 0.1.0
- checksums.yaml +7 -0
- data/.gitignore +44 -0
- data/.travis.yml +5 -0
- data/Gemfile +6 -0
- data/LICENSE.txt +21 -0
- data/README.md +59 -0
- data/Rakefile +19 -0
- data/bin/console +14 -0
- data/bin/setup +8 -0
- data/ext/ffi/fasttext/Rakefile +71 -0
- data/ffi-fasttext.gemspec +40 -0
- data/lib/ffi/fasttext.rb +108 -0
- data/lib/ffi/fasttext/version.rb +5 -0
- data/vendor/fasttext/LICENSE +30 -0
- data/vendor/fasttext/PATENTS +33 -0
- data/vendor/fasttext/args.cc +250 -0
- data/vendor/fasttext/args.h +71 -0
- data/vendor/fasttext/dictionary.cc +475 -0
- data/vendor/fasttext/dictionary.h +112 -0
- data/vendor/fasttext/fasttext.cc +693 -0
- data/vendor/fasttext/fasttext.h +97 -0
- data/vendor/fasttext/ffi_fasttext.cc +66 -0
- data/vendor/fasttext/main.cc +270 -0
- data/vendor/fasttext/matrix.cc +144 -0
- data/vendor/fasttext/matrix.h +57 -0
- data/vendor/fasttext/model.cc +341 -0
- data/vendor/fasttext/model.h +110 -0
- data/vendor/fasttext/productquantizer.cc +211 -0
- data/vendor/fasttext/productquantizer.h +67 -0
- data/vendor/fasttext/qmatrix.cc +121 -0
- data/vendor/fasttext/qmatrix.h +65 -0
- data/vendor/fasttext/real.h +19 -0
- data/vendor/fasttext/utils.cc +29 -0
- data/vendor/fasttext/utils.h +25 -0
- data/vendor/fasttext/vector.cc +137 -0
- data/vendor/fasttext/vector.h +53 -0
- metadata +151 -0
data/vendor/fasttext/PATENTS
@@ -0,0 +1,33 @@
Additional Grant of Patent Rights Version 2

"Software" means the fastText software distributed by Facebook, Inc.

Facebook, Inc. ("Facebook") hereby grants to each recipient of the Software
("you") a perpetual, worldwide, royalty-free, non-exclusive, irrevocable
(subject to the termination provision below) license under any Necessary
Claims, to make, have made, use, sell, offer to sell, import, and otherwise
transfer the Software. For avoidance of doubt, no license is granted under
Facebook’s rights in any patent claims that are infringed by (i) modifications
to the Software made by you or any third party or (ii) the Software in
combination with any software or other technology.

The license granted hereunder will terminate, automatically and without notice,
if you (or any of your subsidiaries, corporate affiliates or agents) initiate
directly or indirectly, or take a direct financial interest in, any Patent
Assertion: (i) against Facebook or any of its subsidiaries or corporate
affiliates, (ii) against any party if such Patent Assertion arises in whole or
in part from any software, technology, product or service of Facebook or any of
its subsidiaries or corporate affiliates, or (iii) against any party relating
to the Software. Notwithstanding the foregoing, if Facebook or any of its
subsidiaries or corporate affiliates files a lawsuit alleging patent
infringement against you in the first instance, and you respond by filing a
patent infringement counterclaim in that lawsuit against that party that is
unrelated to the Software, the license granted hereunder will not terminate
under section (i) of this paragraph due to such counterclaim.

A "Necessary Claim" is a claim of a patent owned by Facebook that is
necessarily infringed by the Software standing alone.

A "Patent Assertion" is any lawsuit or other action alleging direct, indirect,
or contributory infringement or inducement to infringe any patent, including a
cross-claim or counterclaim.
data/vendor/fasttext/args.cc
@@ -0,0 +1,250 @@
/**
 * Copyright (c) 2016-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree. An additional grant
 * of patent rights can be found in the PATENTS file in the same directory.
 */

#include "args.h"

#include <stdlib.h>

#include <iostream>

namespace fasttext {

Args::Args() {
  lr = 0.05;
  dim = 100;
  ws = 5;
  epoch = 5;
  minCount = 5;
  minCountLabel = 0;
  neg = 5;
  wordNgrams = 1;
  loss = loss_name::ns;
  model = model_name::sg;
  bucket = 2000000;
  minn = 3;
  maxn = 6;
  thread = 12;
  lrUpdateRate = 100;
  t = 1e-4;
  label = "__label__";
  verbose = 2;
  pretrainedVectors = "";
  saveOutput = 0;

  qout = false;
  retrain = false;
  qnorm = false;
  cutoff = 0;
  dsub = 2;
}

std::string Args::lossToString(loss_name ln) {
  switch (ln) {
    case loss_name::hs:
      return "hs";
    case loss_name::ns:
      return "ns";
    case loss_name::softmax:
      return "softmax";
  }
  return "Unknown loss!"; // should never happen
}

void Args::parseArgs(const std::vector<std::string>& args) {
  std::string command(args[1]);
  if (command == "supervised") {
    model = model_name::sup;
    loss = loss_name::softmax;
    minCount = 1;
    minn = 0;
    maxn = 0;
    lr = 0.1;
  } else if (command == "cbow") {
    model = model_name::cbow;
  }
  int ai = 2;
  while (ai < args.size()) {
    if (args[ai][0] != '-') {
      std::cerr << "Provided argument without a dash! Usage:" << std::endl;
      printHelp();
      exit(EXIT_FAILURE);
    }
    if (args[ai] == "-h") {
      std::cerr << "Here is the help! Usage:" << std::endl;
      printHelp();
      exit(EXIT_FAILURE);
    } else if (args[ai] == "-input") {
      input = std::string(args[ai + 1]);
    } else if (args[ai] == "-test") {
      test = std::string(args[ai + 1]);
    } else if (args[ai] == "-output") {
      output = std::string(args[ai + 1]);
    } else if (args[ai] == "-lr") {
      lr = std::stof(args[ai + 1]);
    } else if (args[ai] == "-lrUpdateRate") {
      lrUpdateRate = std::stoi(args[ai + 1]);
    } else if (args[ai] == "-dim") {
      dim = std::stoi(args[ai + 1]);
    } else if (args[ai] == "-ws") {
      ws = std::stoi(args[ai + 1]);
    } else if (args[ai] == "-epoch") {
      epoch = std::stoi(args[ai + 1]);
    } else if (args[ai] == "-minCount") {
      minCount = std::stoi(args[ai + 1]);
    } else if (args[ai] == "-minCountLabel") {
      minCountLabel = std::stoi(args[ai + 1]);
    } else if (args[ai] == "-neg") {
      neg = std::stoi(args[ai + 1]);
    } else if (args[ai] == "-wordNgrams") {
      wordNgrams = std::stoi(args[ai + 1]);
    } else if (args[ai] == "-loss") {
      if (args[ai + 1] == "hs") {
        loss = loss_name::hs;
      } else if (args[ai + 1] == "ns") {
        loss = loss_name::ns;
      } else if (args[ai + 1] == "softmax") {
        loss = loss_name::softmax;
      } else {
        std::cerr << "Unknown loss: " << args[ai + 1] << std::endl;
        printHelp();
        exit(EXIT_FAILURE);
      }
    } else if (args[ai] == "-bucket") {
      bucket = std::stoi(args[ai + 1]);
    } else if (args[ai] == "-minn") {
      minn = std::stoi(args[ai + 1]);
    } else if (args[ai] == "-maxn") {
      maxn = std::stoi(args[ai + 1]);
    } else if (args[ai] == "-thread") {
      thread = std::stoi(args[ai + 1]);
    } else if (args[ai] == "-t") {
      t = std::stof(args[ai + 1]);
    } else if (args[ai] == "-label") {
      label = std::string(args[ai + 1]);
    } else if (args[ai] == "-verbose") {
      verbose = std::stoi(args[ai + 1]);
    } else if (args[ai] == "-pretrainedVectors") {
      pretrainedVectors = std::string(args[ai + 1]);
    } else if (args[ai] == "-saveOutput") {
      saveOutput = std::stoi(args[ai + 1]);
    } else if (args[ai] == "-qnorm") {
      qnorm = true; ai--;  // boolean flag: no value token, offset the ai += 2 below
    } else if (args[ai] == "-retrain") {
      retrain = true; ai--;
    } else if (args[ai] == "-qout") {
      qout = true; ai--;
    } else if (args[ai] == "-cutoff") {
      cutoff = std::stoi(args[ai + 1]);
    } else if (args[ai] == "-dsub") {
      dsub = std::stoi(args[ai + 1]);
    } else {
      std::cerr << "Unknown argument: " << args[ai] << std::endl;
      printHelp();
      exit(EXIT_FAILURE);
    }
    ai += 2;
  }
  if (input.empty() || output.empty()) {
    std::cerr << "Empty input or output path." << std::endl;
    printHelp();
    exit(EXIT_FAILURE);
  }
  if (wordNgrams <= 1 && maxn == 0) {
    bucket = 0;  // no char or word n-grams, so no hash buckets are needed
  }
}

void Args::printHelp() {
  printBasicHelp();
  printDictionaryHelp();
  printTrainingHelp();
  printQuantizationHelp();
}

void Args::printBasicHelp() {
  std::cerr
    << "\nThe following arguments are mandatory:\n"
    << "  -input              training file path\n"
    << "  -output             output file path\n"
    << "\nThe following arguments are optional:\n"
    << "  -verbose            verbosity level [" << verbose << "]\n";
}

void Args::printDictionaryHelp() {
  std::cerr
    << "\nThe following arguments for the dictionary are optional:\n"
    << "  -minCount           minimal number of word occurrences [" << minCount << "]\n"
    << "  -minCountLabel      minimal number of label occurrences [" << minCountLabel << "]\n"
    << "  -wordNgrams         max length of word ngram [" << wordNgrams << "]\n"
    << "  -bucket             number of buckets [" << bucket << "]\n"
    << "  -minn               min length of char ngram [" << minn << "]\n"
    << "  -maxn               max length of char ngram [" << maxn << "]\n"
    << "  -t                  sampling threshold [" << t << "]\n"
    << "  -label              labels prefix [" << label << "]\n";
}

void Args::printTrainingHelp() {
  std::cerr
    << "\nThe following arguments for training are optional:\n"
    << "  -lr                 learning rate [" << lr << "]\n"
    << "  -lrUpdateRate       change the rate of updates for the learning rate [" << lrUpdateRate << "]\n"
    << "  -dim                size of word vectors [" << dim << "]\n"
    << "  -ws                 size of the context window [" << ws << "]\n"
    << "  -epoch              number of epochs [" << epoch << "]\n"
    << "  -neg                number of negatives sampled [" << neg << "]\n"
    << "  -loss               loss function {ns, hs, softmax} [" << lossToString(loss) << "]\n"
    << "  -thread             number of threads [" << thread << "]\n"
    << "  -pretrainedVectors  pretrained word vectors for supervised learning [" << pretrainedVectors << "]\n"
    << "  -saveOutput         whether output params should be saved [" << saveOutput << "]\n";
}

void Args::printQuantizationHelp() {
  std::cerr
    << "\nThe following arguments for quantization are optional:\n"
    << "  -cutoff             number of words and ngrams to retain [" << cutoff << "]\n"
    << "  -retrain            finetune embeddings if a cutoff is applied [" << retrain << "]\n"
    << "  -qnorm              quantizing the norm separately [" << qnorm << "]\n"
    << "  -qout               quantizing the classifier [" << qout << "]\n"
    << "  -dsub               size of each sub-vector [" << dsub << "]\n";
}

void Args::save(std::ostream& out) {
  out.write((char*) &(dim), sizeof(int));
  out.write((char*) &(ws), sizeof(int));
  out.write((char*) &(epoch), sizeof(int));
  out.write((char*) &(minCount), sizeof(int));
  out.write((char*) &(neg), sizeof(int));
  out.write((char*) &(wordNgrams), sizeof(int));
  out.write((char*) &(loss), sizeof(loss_name));
  out.write((char*) &(model), sizeof(model_name));
  out.write((char*) &(bucket), sizeof(int));
  out.write((char*) &(minn), sizeof(int));
  out.write((char*) &(maxn), sizeof(int));
  out.write((char*) &(lrUpdateRate), sizeof(int));
  out.write((char*) &(t), sizeof(double));
}

void Args::load(std::istream& in) {
  in.read((char*) &(dim), sizeof(int));
  in.read((char*) &(ws), sizeof(int));
  in.read((char*) &(epoch), sizeof(int));
  in.read((char*) &(minCount), sizeof(int));
  in.read((char*) &(neg), sizeof(int));
  in.read((char*) &(wordNgrams), sizeof(int));
  in.read((char*) &(loss), sizeof(loss_name));
  in.read((char*) &(model), sizeof(model_name));
  in.read((char*) &(bucket), sizeof(int));
  in.read((char*) &(minn), sizeof(int));
  in.read((char*) &(maxn), sizeof(int));
  in.read((char*) &(lrUpdateRate), sizeof(int));
  in.read((char*) &(t), sizeof(double));
}

}
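parseArgs expects a full argv-style vector: element 0 is the program name, element 1 selects the subcommand, and options are then consumed two tokens at a time (flag, value), with the boolean quantization flags stepping back one so the net advance is a single token. A minimal driver sketch, assuming the vendored sources are on the include path; "data.txt" and "model" are placeholder paths:

#include <string>
#include <vector>

#include "args.h"

int main() {
  fasttext::Args args;
  // Full argv-style vector: args[0] is skipped, args[1] is the command.
  std::vector<std::string> argv = {
      "fasttext", "supervised",
      "-input",  "data.txt",
      "-output", "model",
      "-lr", "0.5",
      "-epoch", "25"};
  args.parseArgs(argv);  // prints usage and exits on any malformed option
  return 0;
}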
data/vendor/fasttext/args.h
@@ -0,0 +1,71 @@
/**
 * Copyright (c) 2016-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree. An additional grant
 * of patent rights can be found in the PATENTS file in the same directory.
 */

#ifndef FASTTEXT_ARGS_H
#define FASTTEXT_ARGS_H

#include <istream>
#include <ostream>
#include <string>
#include <vector>

namespace fasttext {

enum class model_name : int {cbow=1, sg, sup};
enum class loss_name : int {hs=1, ns, softmax};

class Args {
  private:
    std::string lossToString(loss_name);

  public:
    Args();
    std::string input;
    std::string test;
    std::string output;
    double lr;
    int lrUpdateRate;
    int dim;
    int ws;
    int epoch;
    int minCount;
    int minCountLabel;
    int neg;
    int wordNgrams;
    loss_name loss;
    model_name model;
    int bucket;
    int minn;
    int maxn;
    int thread;
    double t;
    std::string label;
    int verbose;
    std::string pretrainedVectors;
    int saveOutput;

    bool qout;
    bool retrain;
    bool qnorm;
    size_t cutoff;
    size_t dsub;

    void parseArgs(const std::vector<std::string>& args);
    void printHelp();
    void printBasicHelp();
    void printDictionaryHelp();
    void printTrainingHelp();
    void printQuantizationHelp();
    void save(std::ostream&);
    void load(std::istream&);
};

}

#endif
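Args::save and Args::load serialize only a fixed subset of the fields (dim through lrUpdateRate, plus t), each as a raw binary write in the same fixed order; lr, thread, label and the quantization settings are not persisted. A round-trip sketch, with "args.bin" as a hypothetical file name:

#include <fstream>

#include "args.h"

int main() {
  fasttext::Args a;
  a.dim = 300;
  {
    std::ofstream out("args.bin", std::ofstream::binary);
    a.save(out);  // raw dump, field by field, in a fixed order
  }
  fasttext::Args b;
  std::ifstream in("args.bin", std::ifstream::binary);
  b.load(in);     // must read the fields back in exactly the same order
  return b.dim == 300 ? 0 : 1;
}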
data/vendor/fasttext/dictionary.cc
@@ -0,0 +1,475 @@
/**
 * Copyright (c) 2016-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree. An additional grant
 * of patent rights can be found in the PATENTS file in the same directory.
 */

#include "dictionary.h"

#include <assert.h>

#include <iostream>
#include <fstream>
#include <algorithm>
#include <iterator>
#include <cmath>

namespace fasttext {

const std::string Dictionary::EOS = "</s>";
const std::string Dictionary::BOW = "<";
const std::string Dictionary::EOW = ">";

Dictionary::Dictionary(std::shared_ptr<Args> args) : args_(args),
  word2int_(MAX_VOCAB_SIZE, -1), size_(0), nwords_(0), nlabels_(0),
  ntokens_(0), pruneidx_size_(-1) {}

int32_t Dictionary::find(const std::string& w) const {
  return find(w, hash(w));
}

int32_t Dictionary::find(const std::string& w, uint32_t h) const {
  int32_t id = h % MAX_VOCAB_SIZE;
  while (word2int_[id] != -1 && words_[word2int_[id]].word != w) {
    id = (id + 1) % MAX_VOCAB_SIZE;
  }
  return id;
}

void Dictionary::add(const std::string& w) {
  int32_t h = find(w);
  ntokens_++;
  if (word2int_[h] == -1) {
    entry e;
    e.word = w;
    e.count = 1;
    e.type = getType(w);
    words_.push_back(e);
    word2int_[h] = size_++;
  } else {
    words_[word2int_[h]].count++;
  }
}

int32_t Dictionary::nwords() const {
  return nwords_;
}

int32_t Dictionary::nlabels() const {
  return nlabels_;
}

int64_t Dictionary::ntokens() const {
  return ntokens_;
}

const std::vector<int32_t>& Dictionary::getSubwords(int32_t i) const {
  assert(i >= 0);
  assert(i < nwords_);
  return words_[i].subwords;
}

const std::vector<int32_t> Dictionary::getSubwords(
    const std::string& word) const {
  int32_t i = getId(word);
  if (i >= 0) {
    return getSubwords(i);
  }
  std::vector<int32_t> ngrams;
  computeSubwords(BOW + word + EOW, ngrams);
  return ngrams;
}

void Dictionary::getSubwords(const std::string& word,
                             std::vector<int32_t>& ngrams,
                             std::vector<std::string>& substrings) const {
  int32_t i = getId(word);
  ngrams.clear();
  substrings.clear();
  if (i >= 0) {
    ngrams.push_back(i);
    substrings.push_back(words_[i].word);
  } else {
    ngrams.push_back(-1);
    substrings.push_back(word);
  }
  computeSubwords(BOW + word + EOW, ngrams, substrings);
}

bool Dictionary::discard(int32_t id, real rand) const {
  assert(id >= 0);
  assert(id < nwords_);
  if (args_->model == model_name::sup) return false;
  return rand > pdiscard_[id];
}

int32_t Dictionary::getId(const std::string& w, uint32_t h) const {
  int32_t id = find(w, h);
  return word2int_[id];
}

int32_t Dictionary::getId(const std::string& w) const {
  int32_t h = find(w);
  return word2int_[h];
}

entry_type Dictionary::getType(int32_t id) const {
  assert(id >= 0);
  assert(id < size_);
  return words_[id].type;
}

entry_type Dictionary::getType(const std::string& w) const {
  return (w.find(args_->label) == 0) ? entry_type::label : entry_type::word;
}

std::string Dictionary::getWord(int32_t id) const {
  assert(id >= 0);
  assert(id < size_);
  return words_[id].word;
}

uint32_t Dictionary::hash(const std::string& str) const {
  uint32_t h = 2166136261;
  for (size_t i = 0; i < str.size(); i++) {
    h = h ^ uint32_t(str[i]);
    h = h * 16777619;
  }
  return h;
}

void Dictionary::computeSubwords(const std::string& word,
                                 std::vector<int32_t>& ngrams,
                                 std::vector<std::string>& substrings) const {
  for (size_t i = 0; i < word.size(); i++) {
    std::string ngram;
    if ((word[i] & 0xC0) == 0x80) continue;  // skip UTF-8 continuation bytes
    for (size_t j = i, n = 1; j < word.size() && n <= args_->maxn; n++) {
      ngram.push_back(word[j++]);
      while (j < word.size() && (word[j] & 0xC0) == 0x80) {
        ngram.push_back(word[j++]);
      }
      if (n >= args_->minn && !(n == 1 && (i == 0 || j == word.size()))) {
        int32_t h = hash(ngram) % args_->bucket;
        ngrams.push_back(nwords_ + h);
        substrings.push_back(ngram);
      }
    }
  }
}

void Dictionary::computeSubwords(const std::string& word,
                                 std::vector<int32_t>& ngrams) const {
  for (size_t i = 0; i < word.size(); i++) {
    std::string ngram;
    if ((word[i] & 0xC0) == 0x80) continue;
    for (size_t j = i, n = 1; j < word.size() && n <= args_->maxn; n++) {
      ngram.push_back(word[j++]);
      while (j < word.size() && (word[j] & 0xC0) == 0x80) {
        ngram.push_back(word[j++]);
      }
      if (n >= args_->minn && !(n == 1 && (i == 0 || j == word.size()))) {
        int32_t h = hash(ngram) % args_->bucket;
        pushHash(ngrams, h);
      }
    }
  }
}

void Dictionary::initNgrams() {
  for (size_t i = 0; i < size_; i++) {
    std::string word = BOW + words_[i].word + EOW;
    words_[i].subwords.clear();
    words_[i].subwords.push_back(i);
    if (words_[i].word != EOS) {
      computeSubwords(word, words_[i].subwords);
    }
  }
}

bool Dictionary::readWord(std::istream& in, std::string& word) const
{
  int c;  // int, not char, so the EOF sentinel from sbumpc() compares correctly
  std::streambuf& sb = *in.rdbuf();
  word.clear();
  while ((c = sb.sbumpc()) != EOF) {
    if (c == ' ' || c == '\n' || c == '\r' || c == '\t' || c == '\v' ||
        c == '\f' || c == '\0') {
      if (word.empty()) {
        if (c == '\n') {
          word += EOS;
          return true;
        }
        continue;
      } else {
        if (c == '\n')
          sb.sungetc();
        return true;
      }
    }
    word.push_back(c);
  }
  // trigger eofbit
  in.get();
  return !word.empty();
}

void Dictionary::readFromFile(std::istream& in) {
  std::string word;
  int64_t minThreshold = 1;
  while (readWord(in, word)) {
    add(word);
    if (ntokens_ % 1000000 == 0 && args_->verbose > 1) {
      std::cerr << "\rRead " << ntokens_ / 1000000 << "M words" << std::flush;
    }
    if (size_ > 0.75 * MAX_VOCAB_SIZE) {
      minThreshold++;
      threshold(minThreshold, minThreshold);
    }
  }
  threshold(args_->minCount, args_->minCountLabel);
  initTableDiscard();
  initNgrams();
  if (args_->verbose > 0) {
    std::cerr << "\rRead " << ntokens_ / 1000000 << "M words" << std::endl;
    std::cerr << "Number of words: " << nwords_ << std::endl;
    std::cerr << "Number of labels: " << nlabels_ << std::endl;
  }
  if (size_ == 0) {
    std::cerr << "Empty vocabulary. Try a smaller -minCount value."
              << std::endl;
    exit(EXIT_FAILURE);
  }
}

void Dictionary::threshold(int64_t t, int64_t tl) {
  sort(words_.begin(), words_.end(), [](const entry& e1, const entry& e2) {
    if (e1.type != e2.type) return e1.type < e2.type;
    return e1.count > e2.count;
  });
  words_.erase(remove_if(words_.begin(), words_.end(), [&](const entry& e) {
    return (e.type == entry_type::word && e.count < t) ||
           (e.type == entry_type::label && e.count < tl);
  }), words_.end());
  words_.shrink_to_fit();
  size_ = 0;
  nwords_ = 0;
  nlabels_ = 0;
  std::fill(word2int_.begin(), word2int_.end(), -1);
  for (auto it = words_.begin(); it != words_.end(); ++it) {
    int32_t h = find(it->word);
    word2int_[h] = size_++;
    if (it->type == entry_type::word) nwords_++;
    if (it->type == entry_type::label) nlabels_++;
  }
}

void Dictionary::initTableDiscard() {
  pdiscard_.resize(size_);
  for (size_t i = 0; i < size_; i++) {
    real f = real(words_[i].count) / real(ntokens_);
    pdiscard_[i] = std::sqrt(args_->t / f) + args_->t / f;
  }
}

std::vector<int64_t> Dictionary::getCounts(entry_type type) const {
  std::vector<int64_t> counts;
  for (auto& w : words_) {
    if (w.type == type) counts.push_back(w.count);
  }
  return counts;
}

void Dictionary::addWordNgrams(std::vector<int32_t>& line,
                               const std::vector<int32_t>& hashes,
                               int32_t n) const {
  for (int32_t i = 0; i < hashes.size(); i++) {
    uint64_t h = hashes[i];
    for (int32_t j = i + 1; j < hashes.size() && j < i + n; j++) {
      h = h * 116049371 + hashes[j];
      pushHash(line, h % args_->bucket);
    }
  }
}

void Dictionary::addSubwords(std::vector<int32_t>& line,
                             const std::string& token,
                             int32_t wid) const {
  if (wid < 0) { // out of vocab
    computeSubwords(BOW + token + EOW, line);
  } else {
    if (args_->maxn <= 0) { // in vocab w/o subwords
      line.push_back(wid);
    } else { // in vocab w/ subwords
      const std::vector<int32_t>& ngrams = getSubwords(wid);
      line.insert(line.end(), ngrams.cbegin(), ngrams.cend());
    }
  }
}

void Dictionary::reset(std::istream& in) const {
  if (in.eof()) {
    in.clear();
    in.seekg(std::streampos(0));
  }
}

int32_t Dictionary::getLine(std::istream& in,
                            std::vector<int32_t>& words,
                            std::minstd_rand& rng) const {
  std::uniform_real_distribution<> uniform(0, 1);
  std::string token;
  int32_t ntokens = 0;

  reset(in);
  words.clear();
  while (readWord(in, token)) {
    int32_t h = find(token);
    int32_t wid = word2int_[h];
    if (wid < 0) continue;

    ntokens++;
    if (getType(wid) == entry_type::word && !discard(wid, uniform(rng))) {
      words.push_back(wid);
    }
    if (ntokens > MAX_LINE_SIZE || token == EOS) break;
  }
  return ntokens;
}

int32_t Dictionary::getLine(std::istream& in,
                            std::vector<int32_t>& words,
                            std::vector<int32_t>& labels,
                            std::minstd_rand& rng) const {
  std::vector<int32_t> word_hashes;
  std::string token;
  int32_t ntokens = 0;

  reset(in);
  words.clear();
  labels.clear();
  while (readWord(in, token)) {
    uint32_t h = hash(token);
    int32_t wid = getId(token, h);
    entry_type type = wid < 0 ? getType(token) : getType(wid);

    ntokens++;
    if (type == entry_type::word) {
      addSubwords(words, token, wid);
      word_hashes.push_back(h);
    } else if (type == entry_type::label && wid >= 0) {
      labels.push_back(wid - nwords_);
    }
    if (token == EOS) break;
  }
  addWordNgrams(words, word_hashes, args_->wordNgrams);
  return ntokens;
}

void Dictionary::pushHash(std::vector<int32_t>& hashes, int32_t id) const {
  if (pruneidx_size_ == 0 || id < 0) return;  // size 0: all n-grams pruned; -1 (default): no pruning
  if (pruneidx_size_ > 0) {
    if (pruneidx_.count(id)) {
      id = pruneidx_.at(id);
    } else {
      return;
    }
  }
  hashes.push_back(nwords_ + id);
}

std::string Dictionary::getLabel(int32_t lid) const {
  assert(lid >= 0);
  assert(lid < nlabels_);
  return words_[lid + nwords_].word;
}

void Dictionary::save(std::ostream& out) const {
  out.write((char*) &size_, sizeof(int32_t));
  out.write((char*) &nwords_, sizeof(int32_t));
  out.write((char*) &nlabels_, sizeof(int32_t));
  out.write((char*) &ntokens_, sizeof(int64_t));
  out.write((char*) &pruneidx_size_, sizeof(int64_t));
  for (int32_t i = 0; i < size_; i++) {
    entry e = words_[i];
    out.write(e.word.data(), e.word.size() * sizeof(char));
    out.put(0);
    out.write((char*) &(e.count), sizeof(int64_t));
    out.write((char*) &(e.type), sizeof(entry_type));
  }
  for (const auto pair : pruneidx_) {
    out.write((char*) &(pair.first), sizeof(int32_t));
    out.write((char*) &(pair.second), sizeof(int32_t));
  }
}

void Dictionary::load(std::istream& in) {
  words_.clear();
  std::fill(word2int_.begin(), word2int_.end(), -1);
  in.read((char*) &size_, sizeof(int32_t));
  in.read((char*) &nwords_, sizeof(int32_t));
  in.read((char*) &nlabels_, sizeof(int32_t));
  in.read((char*) &ntokens_, sizeof(int64_t));
  in.read((char*) &pruneidx_size_, sizeof(int64_t));
  for (int32_t i = 0; i < size_; i++) {
    char c;
    entry e;
    while ((c = in.get()) != 0) {
      e.word.push_back(c);
    }
    in.read((char*) &e.count, sizeof(int64_t));
    in.read((char*) &e.type, sizeof(entry_type));
    words_.push_back(e);
    word2int_[find(e.word)] = i;
  }
  pruneidx_.clear();
  for (int32_t i = 0; i < pruneidx_size_; i++) {
    int32_t first;
    int32_t second;
    in.read((char*) &first, sizeof(int32_t));
    in.read((char*) &second, sizeof(int32_t));
    pruneidx_[first] = second;
  }
  initTableDiscard();
  initNgrams();
}

void Dictionary::prune(std::vector<int32_t>& idx) {
  std::vector<int32_t> words, ngrams;
  for (auto it = idx.cbegin(); it != idx.cend(); ++it) {
    if (*it < nwords_) {words.push_back(*it);}
    else {ngrams.push_back(*it);}
  }
  std::sort(words.begin(), words.end());
  idx = words;

  if (ngrams.size() != 0) {
    int32_t j = 0;
    for (const auto ngram : ngrams) {
      pruneidx_[ngram - nwords_] = j;
      j++;
    }
    idx.insert(idx.end(), ngrams.begin(), ngrams.end());
  }
  pruneidx_size_ = pruneidx_.size();

  std::fill(word2int_.begin(), word2int_.end(), -1);

  int32_t j = 0;
  for (int32_t i = 0; i < words_.size(); i++) {
    if (getType(i) == entry_type::label || (j < words.size() && words[j] == i)) {
      words_[j] = words_[i];
      word2int_[find(words_[j].word)] = j;
      j++;
    }
  }
  nwords_ = words.size();
  size_ = nwords_ + nlabels_;
  words_.erase(words_.begin() + size_, words_.end());
  initNgrams();
}

}
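Subword ids in the dictionary come from the FNV-1a hash in Dictionary::hash, taken modulo args_->bucket and offset by nwords_. A standalone sketch of the same hash, using the default bucket count of 2000000 from Args::Args(); "<wh" is one char n-gram of "<where>" under the default minn=3, maxn=6:

#include <cstdint>
#include <iostream>
#include <string>

// Same 32-bit FNV-1a as Dictionary::hash, including the char-to-uint32_t cast.
uint32_t fnv1a(const std::string& str) {
  uint32_t h = 2166136261;     // FNV-1a offset basis
  for (size_t i = 0; i < str.size(); i++) {
    h = h ^ uint32_t(str[i]);  // xor the next byte ...
    h = h * 16777619;          // ... then multiply by the FNV prime
  }
  return h;
}

int main() {
  const uint32_t bucket = 2000000;  // default -bucket value
  // Bucket id of one char n-gram; the dictionary stores it as nwords_ + id.
  std::cout << fnv1a("<wh") % bucket << std::endl;
  return 0;
}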