ffi-fasttext 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.gitignore +44 -0
- data/.travis.yml +5 -0
- data/Gemfile +6 -0
- data/LICENSE.txt +21 -0
- data/README.md +59 -0
- data/Rakefile +19 -0
- data/bin/console +14 -0
- data/bin/setup +8 -0
- data/ext/ffi/fasttext/Rakefile +71 -0
- data/ffi-fasttext.gemspec +40 -0
- data/lib/ffi/fasttext.rb +108 -0
- data/lib/ffi/fasttext/version.rb +5 -0
- data/vendor/fasttext/LICENSE +30 -0
- data/vendor/fasttext/PATENTS +33 -0
- data/vendor/fasttext/args.cc +250 -0
- data/vendor/fasttext/args.h +71 -0
- data/vendor/fasttext/dictionary.cc +475 -0
- data/vendor/fasttext/dictionary.h +112 -0
- data/vendor/fasttext/fasttext.cc +693 -0
- data/vendor/fasttext/fasttext.h +97 -0
- data/vendor/fasttext/ffi_fasttext.cc +66 -0
- data/vendor/fasttext/main.cc +270 -0
- data/vendor/fasttext/matrix.cc +144 -0
- data/vendor/fasttext/matrix.h +57 -0
- data/vendor/fasttext/model.cc +341 -0
- data/vendor/fasttext/model.h +110 -0
- data/vendor/fasttext/productquantizer.cc +211 -0
- data/vendor/fasttext/productquantizer.h +67 -0
- data/vendor/fasttext/qmatrix.cc +121 -0
- data/vendor/fasttext/qmatrix.h +65 -0
- data/vendor/fasttext/real.h +19 -0
- data/vendor/fasttext/utils.cc +29 -0
- data/vendor/fasttext/utils.h +25 -0
- data/vendor/fasttext/vector.cc +137 -0
- data/vendor/fasttext/vector.h +53 -0
- metadata +151 -0
@@ -0,0 +1,33 @@
|
|
1
|
+
Additional Grant of Patent Rights Version 2
|
2
|
+
|
3
|
+
"Software" means the fastText software distributed by Facebook, Inc.
|
4
|
+
|
5
|
+
Facebook, Inc. ("Facebook") hereby grants to each recipient of the Software
|
6
|
+
("you") a perpetual, worldwide, royalty-free, non-exclusive, irrevocable
|
7
|
+
(subject to the termination provision below) license under any Necessary
|
8
|
+
Claims, to make, have made, use, sell, offer to sell, import, and otherwise
|
9
|
+
transfer the Software. For avoidance of doubt, no license is granted under
|
10
|
+
Facebook’s rights in any patent claims that are infringed by (i) modifications
|
11
|
+
to the Software made by you or any third party or (ii) the Software in
|
12
|
+
combination with any software or other technology.
|
13
|
+
|
14
|
+
The license granted hereunder will terminate, automatically and without notice,
|
15
|
+
if you (or any of your subsidiaries, corporate affiliates or agents) initiate
|
16
|
+
directly or indirectly, or take a direct financial interest in, any Patent
|
17
|
+
Assertion: (i) against Facebook or any of its subsidiaries or corporate
|
18
|
+
affiliates, (ii) against any party if such Patent Assertion arises in whole or
|
19
|
+
in part from any software, technology, product or service of Facebook or any of
|
20
|
+
its subsidiaries or corporate affiliates, or (iii) against any party relating
|
21
|
+
to the Software. Notwithstanding the foregoing, if Facebook or any of its
|
22
|
+
subsidiaries or corporate affiliates files a lawsuit alleging patent
|
23
|
+
infringement against you in the first instance, and you respond by filing a
|
24
|
+
patent infringement counterclaim in that lawsuit against that party that is
|
25
|
+
unrelated to the Software, the license granted hereunder will not terminate
|
26
|
+
under section (i) of this paragraph due to such counterclaim.
|
27
|
+
|
28
|
+
A "Necessary Claim" is a claim of a patent owned by Facebook that is
|
29
|
+
necessarily infringed by the Software standing alone.
|
30
|
+
|
31
|
+
A "Patent Assertion" is any lawsuit or other action alleging direct, indirect,
|
32
|
+
or contributory infringement or inducement to infringe any patent, including a
|
33
|
+
cross-claim or counterclaim.
|
@@ -0,0 +1,250 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) 2016-present, Facebook, Inc.
|
3
|
+
* All rights reserved.
|
4
|
+
*
|
5
|
+
* This source code is licensed under the BSD-style license found in the
|
6
|
+
* LICENSE file in the root directory of this source tree. An additional grant
|
7
|
+
* of patent rights can be found in the PATENTS file in the same directory.
|
8
|
+
*/
|
9
|
+
|
10
|
+
#include "args.h"
|
11
|
+
|
12
|
+
#include <stdlib.h>
|
13
|
+
|
14
|
+
#include <iostream>
|
15
|
+
|
16
|
+
namespace fasttext {
|
17
|
+
|
18
|
+
// Default hyper-parameters, matching the reference fastText CLI defaults.
Args::Args() {
  // optimisation
  lr = 0.05;
  lrUpdateRate = 100;
  epoch = 5;
  thread = 12;
  // model geometry
  dim = 100;
  ws = 5;
  neg = 5;
  loss = loss_name::ns;
  model = model_name::sg;
  // dictionary / n-grams
  minCount = 5;
  minCountLabel = 0;
  wordNgrams = 1;
  bucket = 2000000;
  minn = 3;
  maxn = 6;
  t = 1e-4;
  label = "__label__";
  // misc
  verbose = 2;
  pretrainedVectors = "";
  saveOutput = 0;
  // quantization
  qout = false;
  retrain = false;
  qnorm = false;
  cutoff = 0;
  dsub = 2;
}
|
46
|
+
|
47
|
+
// Returns the CLI spelling of a loss enumerator ("hs", "ns" or "softmax").
std::string Args::lossToString(loss_name ln) {
  if (ln == loss_name::hs) {
    return "hs";
  }
  if (ln == loss_name::ns) {
    return "ns";
  }
  if (ln == loss_name::softmax) {
    return "softmax";
  }
  return "Unknown loss!"; // should never happen
}
|
58
|
+
|
59
|
+
void Args::parseArgs(const std::vector<std::string>& args) {
|
60
|
+
std::string command(args[1]);
|
61
|
+
if (command == "supervised") {
|
62
|
+
model = model_name::sup;
|
63
|
+
loss = loss_name::softmax;
|
64
|
+
minCount = 1;
|
65
|
+
minn = 0;
|
66
|
+
maxn = 0;
|
67
|
+
lr = 0.1;
|
68
|
+
} else if (command == "cbow") {
|
69
|
+
model = model_name::cbow;
|
70
|
+
}
|
71
|
+
int ai = 2;
|
72
|
+
while (ai < args.size()) {
|
73
|
+
if (args[ai][0] != '-') {
|
74
|
+
std::cerr << "Provided argument without a dash! Usage:" << std::endl;
|
75
|
+
printHelp();
|
76
|
+
exit(EXIT_FAILURE);
|
77
|
+
}
|
78
|
+
if (args[ai] == "-h") {
|
79
|
+
std::cerr << "Here is the help! Usage:" << std::endl;
|
80
|
+
printHelp();
|
81
|
+
exit(EXIT_FAILURE);
|
82
|
+
} else if (args[ai] == "-input") {
|
83
|
+
input = std::string(args[ai + 1]);
|
84
|
+
} else if (args[ai] == "-test") {
|
85
|
+
test = std::string(args[ai + 1]);
|
86
|
+
} else if (args[ai] == "-output") {
|
87
|
+
output = std::string(args[ai + 1]);
|
88
|
+
} else if (args[ai] == "-lr") {
|
89
|
+
lr = std::stof(args[ai + 1]);
|
90
|
+
} else if (args[ai] == "-lrUpdateRate") {
|
91
|
+
lrUpdateRate = std::stoi(args[ai + 1]);
|
92
|
+
} else if (args[ai] == "-dim") {
|
93
|
+
dim = std::stoi(args[ai + 1]);
|
94
|
+
} else if (args[ai] == "-ws") {
|
95
|
+
ws = std::stoi(args[ai + 1]);
|
96
|
+
} else if (args[ai] == "-epoch") {
|
97
|
+
epoch = std::stoi(args[ai + 1]);
|
98
|
+
} else if (args[ai] == "-minCount") {
|
99
|
+
minCount = std::stoi(args[ai + 1]);
|
100
|
+
} else if (args[ai] == "-minCountLabel") {
|
101
|
+
minCountLabel = std::stoi(args[ai + 1]);
|
102
|
+
} else if (args[ai] == "-neg") {
|
103
|
+
neg = std::stoi(args[ai + 1]);
|
104
|
+
} else if (args[ai] == "-wordNgrams") {
|
105
|
+
wordNgrams = std::stoi(args[ai + 1]);
|
106
|
+
} else if (args[ai] == "-loss") {
|
107
|
+
if (args[ai + 1] == "hs") {
|
108
|
+
loss = loss_name::hs;
|
109
|
+
} else if (args[ai + 1] == "ns") {
|
110
|
+
loss = loss_name::ns;
|
111
|
+
} else if (args[ai + 1] == "softmax") {
|
112
|
+
loss = loss_name::softmax;
|
113
|
+
} else {
|
114
|
+
std::cerr << "Unknown loss: " << args[ai + 1] << std::endl;
|
115
|
+
printHelp();
|
116
|
+
exit(EXIT_FAILURE);
|
117
|
+
}
|
118
|
+
} else if (args[ai] == "-bucket") {
|
119
|
+
bucket = std::stoi(args[ai + 1]);
|
120
|
+
} else if (args[ai] == "-minn") {
|
121
|
+
minn = std::stoi(args[ai + 1]);
|
122
|
+
} else if (args[ai] == "-maxn") {
|
123
|
+
maxn = std::stoi(args[ai + 1]);
|
124
|
+
} else if (args[ai] == "-thread") {
|
125
|
+
thread = std::stoi(args[ai + 1]);
|
126
|
+
} else if (args[ai] == "-t") {
|
127
|
+
t = std::stof(args[ai + 1]);
|
128
|
+
} else if (args[ai] == "-label") {
|
129
|
+
label = std::string(args[ai + 1]);
|
130
|
+
} else if (args[ai] == "-verbose") {
|
131
|
+
verbose = std::stoi(args[ai + 1]);
|
132
|
+
} else if (args[ai] == "-pretrainedVectors") {
|
133
|
+
pretrainedVectors = std::string(args[ai + 1]);
|
134
|
+
} else if (args[ai] == "-saveOutput") {
|
135
|
+
saveOutput = std::stoi(args[ai + 1]);
|
136
|
+
} else if (args[ai] == "-qnorm") {
|
137
|
+
qnorm = true; ai--;
|
138
|
+
} else if (args[ai] == "-retrain") {
|
139
|
+
retrain = true; ai--;
|
140
|
+
} else if (args[ai] == "-qout") {
|
141
|
+
qout = true; ai--;
|
142
|
+
} else if (args[ai] == "-cutoff") {
|
143
|
+
cutoff = std::stoi(args[ai + 1]);
|
144
|
+
} else if (args[ai] == "-dsub") {
|
145
|
+
dsub = std::stoi(args[ai + 1]);
|
146
|
+
} else {
|
147
|
+
std::cerr << "Unknown argument: " << args[ai] << std::endl;
|
148
|
+
printHelp();
|
149
|
+
exit(EXIT_FAILURE);
|
150
|
+
}
|
151
|
+
ai += 2;
|
152
|
+
}
|
153
|
+
if (input.empty() || output.empty()) {
|
154
|
+
std::cerr << "Empty input or output path." << std::endl;
|
155
|
+
printHelp();
|
156
|
+
exit(EXIT_FAILURE);
|
157
|
+
}
|
158
|
+
if (wordNgrams <= 1 && maxn == 0) {
|
159
|
+
bucket = 0;
|
160
|
+
}
|
161
|
+
}
|
162
|
+
|
163
|
+
// Prints the complete usage message: every help section, in order.
void Args::printHelp() {
  printBasicHelp();
  printDictionaryHelp();
  printTrainingHelp();
  printQuantizationHelp();
}
|
169
|
+
|
170
|
+
|
171
|
+
void Args::printBasicHelp() {
|
172
|
+
std::cerr
|
173
|
+
<< "\nThe following arguments are mandatory:\n"
|
174
|
+
<< " -input training file path\n"
|
175
|
+
<< " -output output file path\n"
|
176
|
+
<< "\nThe following arguments are optional:\n"
|
177
|
+
<< " -verbose verbosity level [" << verbose << "]\n";
|
178
|
+
}
|
179
|
+
|
180
|
+
void Args::printDictionaryHelp() {
|
181
|
+
std::cerr
|
182
|
+
<< "\nThe following arguments for the dictionary are optional:\n"
|
183
|
+
<< " -minCount minimal number of word occurences [" << minCount << "]\n"
|
184
|
+
<< " -minCountLabel minimal number of label occurences [" << minCountLabel << "]\n"
|
185
|
+
<< " -wordNgrams max length of word ngram [" << wordNgrams << "]\n"
|
186
|
+
<< " -bucket number of buckets [" << bucket << "]\n"
|
187
|
+
<< " -minn min length of char ngram [" << minn << "]\n"
|
188
|
+
<< " -maxn max length of char ngram [" << maxn << "]\n"
|
189
|
+
<< " -t sampling threshold [" << t << "]\n"
|
190
|
+
<< " -label labels prefix [" << label << "]\n";
|
191
|
+
}
|
192
|
+
|
193
|
+
void Args::printTrainingHelp() {
|
194
|
+
std::cerr
|
195
|
+
<< "\nThe following arguments for training are optional:\n"
|
196
|
+
<< " -lr learning rate [" << lr << "]\n"
|
197
|
+
<< " -lrUpdateRate change the rate of updates for the learning rate [" << lrUpdateRate << "]\n"
|
198
|
+
<< " -dim size of word vectors [" << dim << "]\n"
|
199
|
+
<< " -ws size of the context window [" << ws << "]\n"
|
200
|
+
<< " -epoch number of epochs [" << epoch << "]\n"
|
201
|
+
<< " -neg number of negatives sampled [" << neg << "]\n"
|
202
|
+
<< " -loss loss function {ns, hs, softmax} [" << lossToString(loss) << "]\n"
|
203
|
+
<< " -thread number of threads [" << thread << "]\n"
|
204
|
+
<< " -pretrainedVectors pretrained word vectors for supervised learning ["<< pretrainedVectors <<"]\n"
|
205
|
+
<< " -saveOutput whether output params should be saved [" << saveOutput << "]\n";
|
206
|
+
}
|
207
|
+
|
208
|
+
void Args::printQuantizationHelp() {
|
209
|
+
std::cerr
|
210
|
+
<< "\nThe following arguments for quantization are optional:\n"
|
211
|
+
<< " -cutoff number of words and ngrams to retain [" << cutoff << "]\n"
|
212
|
+
<< " -retrain finetune embeddings if a cutoff is applied [" << retrain << "]\n"
|
213
|
+
<< " -qnorm quantizing the norm separately [" << qnorm << "]\n"
|
214
|
+
<< " -qout quantizing the classifier [" << qout << "]\n"
|
215
|
+
<< " -dsub size of each sub-vector [" << dsub << "]\n";
|
216
|
+
}
|
217
|
+
|
218
|
+
// Serializes the training hyper-parameters as fixed-width binary fields.
// Field order and widths must match Args::load() exactly; this is part of
// the on-disk model format, so it must not change.
void Args::save(std::ostream& out) {
  const auto put = [&out](const void* p, size_t n) {
    out.write(static_cast<const char*>(p), n);
  };
  put(&dim, sizeof(int));
  put(&ws, sizeof(int));
  put(&epoch, sizeof(int));
  put(&minCount, sizeof(int));
  put(&neg, sizeof(int));
  put(&wordNgrams, sizeof(int));
  put(&loss, sizeof(loss_name));
  put(&model, sizeof(model_name));
  put(&bucket, sizeof(int));
  put(&minn, sizeof(int));
  put(&maxn, sizeof(int));
  put(&lrUpdateRate, sizeof(int));
  put(&t, sizeof(double));
}
|
233
|
+
|
234
|
+
// Deserializes the hyper-parameters written by Args::save(); field order and
// widths must mirror that function exactly (on-disk model format).
void Args::load(std::istream& in) {
  const auto get = [&in](void* p, size_t n) {
    in.read(static_cast<char*>(p), n);
  };
  get(&dim, sizeof(int));
  get(&ws, sizeof(int));
  get(&epoch, sizeof(int));
  get(&minCount, sizeof(int));
  get(&neg, sizeof(int));
  get(&wordNgrams, sizeof(int));
  get(&loss, sizeof(loss_name));
  get(&model, sizeof(model_name));
  get(&bucket, sizeof(int));
  get(&minn, sizeof(int));
  get(&maxn, sizeof(int));
  get(&lrUpdateRate, sizeof(int));
  get(&t, sizeof(double));
}
|
249
|
+
|
250
|
+
}
|
@@ -0,0 +1,71 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) 2016-present, Facebook, Inc.
|
3
|
+
* All rights reserved.
|
4
|
+
*
|
5
|
+
* This source code is licensed under the BSD-style license found in the
|
6
|
+
* LICENSE file in the root directory of this source tree. An additional grant
|
7
|
+
* of patent rights can be found in the PATENTS file in the same directory.
|
8
|
+
*/
|
9
|
+
|
10
|
+
#ifndef FASTTEXT_ARGS_H
#define FASTTEXT_ARGS_H

#include <istream>
#include <ostream>
#include <string>
#include <vector>

namespace fasttext {

// Model architecture and loss selection. The numeric values are written into
// saved models (see Args::save), so the enumerators must not be renumbered.
enum class model_name : int {cbow=1, sg, sup};
enum class loss_name : int {hs=1, ns, softmax};

// Holds every command-line option of the fastText tools plus binary
// (de)serialization of the subset stored inside model files.
class Args {
  private:
    // CLI spelling of a loss enumerator, used when printing help text.
    std::string lossToString(loss_name);

  public:
    Args();

    // paths
    std::string input;               // training file
    std::string test;                // evaluation file
    std::string output;              // output model prefix

    // training
    double lr;                       // learning rate
    int lrUpdateRate;                // tokens between lr updates
    int dim;                         // embedding dimensionality
    int ws;                          // context window size
    int epoch;
    int minCount;                    // min word frequency kept
    int minCountLabel;               // min label frequency kept
    int neg;                         // negatives sampled per example
    int wordNgrams;                  // max word n-gram length
    loss_name loss;
    model_name model;
    int bucket;                      // hash buckets for n-grams
    int minn;                        // min char n-gram length
    int maxn;                        // max char n-gram length
    int thread;
    double t;                        // sub-sampling threshold
    std::string label;               // label prefix, e.g. "__label__"
    int verbose;
    std::string pretrainedVectors;
    int saveOutput;

    // quantization
    bool qout;                       // quantize the classifier too
    bool retrain;                    // finetune after cutoff
    bool qnorm;                      // quantize norms separately
    size_t cutoff;                   // words/ngrams retained
    size_t dsub;                     // sub-vector size

    void parseArgs(const std::vector<std::string>& args);
    void printHelp();
    void printBasicHelp();
    void printDictionaryHelp();
    void printTrainingHelp();
    void printQuantizationHelp();
    void save(std::ostream&);
    void load(std::istream&);
};

}

#endif
|
@@ -0,0 +1,475 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) 2016-present, Facebook, Inc.
|
3
|
+
* All rights reserved.
|
4
|
+
*
|
5
|
+
* This source code is licensed under the BSD-style license found in the
|
6
|
+
* LICENSE file in the root directory of this source tree. An additional grant
|
7
|
+
* of patent rights can be found in the PATENTS file in the same directory.
|
8
|
+
*/
|
9
|
+
|
10
|
+
#include "dictionary.h"
|
11
|
+
|
12
|
+
#include <assert.h>
|
13
|
+
|
14
|
+
#include <iostream>
|
15
|
+
#include <fstream>
|
16
|
+
#include <algorithm>
|
17
|
+
#include <iterator>
|
18
|
+
#include <cmath>
|
19
|
+
|
20
|
+
namespace fasttext {
|
21
|
+
|
22
|
+
const std::string Dictionary::EOS = "</s>";
|
23
|
+
const std::string Dictionary::BOW = "<";
|
24
|
+
const std::string Dictionary::EOW = ">";
|
25
|
+
|
26
|
+
Dictionary::Dictionary(std::shared_ptr<Args> args) : args_(args),
|
27
|
+
word2int_(MAX_VOCAB_SIZE, -1), size_(0), nwords_(0), nlabels_(0),
|
28
|
+
ntokens_(0), pruneidx_size_(-1) {}
|
29
|
+
|
30
|
+
int32_t Dictionary::find(const std::string& w) const {
|
31
|
+
return find(w, hash(w));
|
32
|
+
}
|
33
|
+
|
34
|
+
int32_t Dictionary::find(const std::string& w, uint32_t h) const {
|
35
|
+
int32_t id = h % MAX_VOCAB_SIZE;
|
36
|
+
while (word2int_[id] != -1 && words_[word2int_[id]].word != w) {
|
37
|
+
id = (id + 1) % MAX_VOCAB_SIZE;
|
38
|
+
}
|
39
|
+
return id;
|
40
|
+
}
|
41
|
+
|
42
|
+
void Dictionary::add(const std::string& w) {
|
43
|
+
int32_t h = find(w);
|
44
|
+
ntokens_++;
|
45
|
+
if (word2int_[h] == -1) {
|
46
|
+
entry e;
|
47
|
+
e.word = w;
|
48
|
+
e.count = 1;
|
49
|
+
e.type = getType(w);
|
50
|
+
words_.push_back(e);
|
51
|
+
word2int_[h] = size_++;
|
52
|
+
} else {
|
53
|
+
words_[word2int_[h]].count++;
|
54
|
+
}
|
55
|
+
}
|
56
|
+
|
57
|
+
int32_t Dictionary::nwords() const {
|
58
|
+
return nwords_;
|
59
|
+
}
|
60
|
+
|
61
|
+
int32_t Dictionary::nlabels() const {
|
62
|
+
return nlabels_;
|
63
|
+
}
|
64
|
+
|
65
|
+
int64_t Dictionary::ntokens() const {
|
66
|
+
return ntokens_;
|
67
|
+
}
|
68
|
+
|
69
|
+
const std::vector<int32_t>& Dictionary::getSubwords(int32_t i) const {
|
70
|
+
assert(i >= 0);
|
71
|
+
assert(i < nwords_);
|
72
|
+
return words_[i].subwords;
|
73
|
+
}
|
74
|
+
|
75
|
+
const std::vector<int32_t> Dictionary::getSubwords(
|
76
|
+
const std::string& word) const {
|
77
|
+
int32_t i = getId(word);
|
78
|
+
if (i >= 0) {
|
79
|
+
return getSubwords(i);
|
80
|
+
}
|
81
|
+
std::vector<int32_t> ngrams;
|
82
|
+
computeSubwords(BOW + word + EOW, ngrams);
|
83
|
+
return ngrams;
|
84
|
+
}
|
85
|
+
|
86
|
+
void Dictionary::getSubwords(const std::string& word,
|
87
|
+
std::vector<int32_t>& ngrams,
|
88
|
+
std::vector<std::string>& substrings) const {
|
89
|
+
int32_t i = getId(word);
|
90
|
+
ngrams.clear();
|
91
|
+
substrings.clear();
|
92
|
+
if (i >= 0) {
|
93
|
+
ngrams.push_back(i);
|
94
|
+
substrings.push_back(words_[i].word);
|
95
|
+
} else {
|
96
|
+
ngrams.push_back(-1);
|
97
|
+
substrings.push_back(word);
|
98
|
+
}
|
99
|
+
computeSubwords(BOW + word + EOW, ngrams, substrings);
|
100
|
+
}
|
101
|
+
|
102
|
+
bool Dictionary::discard(int32_t id, real rand) const {
|
103
|
+
assert(id >= 0);
|
104
|
+
assert(id < nwords_);
|
105
|
+
if (args_->model == model_name::sup) return false;
|
106
|
+
return rand > pdiscard_[id];
|
107
|
+
}
|
108
|
+
|
109
|
+
int32_t Dictionary::getId(const std::string& w, uint32_t h) const {
|
110
|
+
int32_t id = find(w, h);
|
111
|
+
return word2int_[id];
|
112
|
+
}
|
113
|
+
|
114
|
+
int32_t Dictionary::getId(const std::string& w) const {
|
115
|
+
int32_t h = find(w);
|
116
|
+
return word2int_[h];
|
117
|
+
}
|
118
|
+
|
119
|
+
entry_type Dictionary::getType(int32_t id) const {
|
120
|
+
assert(id >= 0);
|
121
|
+
assert(id < size_);
|
122
|
+
return words_[id].type;
|
123
|
+
}
|
124
|
+
|
125
|
+
entry_type Dictionary::getType(const std::string& w) const {
|
126
|
+
return (w.find(args_->label) == 0) ? entry_type::label : entry_type::word;
|
127
|
+
}
|
128
|
+
|
129
|
+
std::string Dictionary::getWord(int32_t id) const {
|
130
|
+
assert(id >= 0);
|
131
|
+
assert(id < size_);
|
132
|
+
return words_[id].word;
|
133
|
+
}
|
134
|
+
|
135
|
+
// 32-bit FNV-1a hash over the bytes of `str`.
// NOTE(review): `char` is signed on most ABIs, so bytes >= 0x80 are
// sign-extended before the XOR. That deviates from canonical FNV-1a, but it
// must be preserved: changing it would alter bucket assignments and break
// compatibility with existing fastText models.
uint32_t Dictionary::hash(const std::string& str) const {
  uint32_t h = 2166136261;
  for (char ch : str) {
    h ^= uint32_t(ch);
    h *= 16777619;
  }
  return h;
}
|
143
|
+
|
144
|
+
void Dictionary::computeSubwords(const std::string& word,
|
145
|
+
std::vector<int32_t>& ngrams,
|
146
|
+
std::vector<std::string>& substrings) const {
|
147
|
+
for (size_t i = 0; i < word.size(); i++) {
|
148
|
+
std::string ngram;
|
149
|
+
if ((word[i] & 0xC0) == 0x80) continue;
|
150
|
+
for (size_t j = i, n = 1; j < word.size() && n <= args_->maxn; n++) {
|
151
|
+
ngram.push_back(word[j++]);
|
152
|
+
while (j < word.size() && (word[j] & 0xC0) == 0x80) {
|
153
|
+
ngram.push_back(word[j++]);
|
154
|
+
}
|
155
|
+
if (n >= args_->minn && !(n == 1 && (i == 0 || j == word.size()))) {
|
156
|
+
int32_t h = hash(ngram) % args_->bucket;
|
157
|
+
ngrams.push_back(nwords_ + h);
|
158
|
+
substrings.push_back(ngram);
|
159
|
+
}
|
160
|
+
}
|
161
|
+
}
|
162
|
+
}
|
163
|
+
|
164
|
+
void Dictionary::computeSubwords(const std::string& word,
|
165
|
+
std::vector<int32_t>& ngrams) const {
|
166
|
+
for (size_t i = 0; i < word.size(); i++) {
|
167
|
+
std::string ngram;
|
168
|
+
if ((word[i] & 0xC0) == 0x80) continue;
|
169
|
+
for (size_t j = i, n = 1; j < word.size() && n <= args_->maxn; n++) {
|
170
|
+
ngram.push_back(word[j++]);
|
171
|
+
while (j < word.size() && (word[j] & 0xC0) == 0x80) {
|
172
|
+
ngram.push_back(word[j++]);
|
173
|
+
}
|
174
|
+
if (n >= args_->minn && !(n == 1 && (i == 0 || j == word.size()))) {
|
175
|
+
int32_t h = hash(ngram) % args_->bucket;
|
176
|
+
pushHash(ngrams, h);
|
177
|
+
}
|
178
|
+
}
|
179
|
+
}
|
180
|
+
}
|
181
|
+
|
182
|
+
void Dictionary::initNgrams() {
|
183
|
+
for (size_t i = 0; i < size_; i++) {
|
184
|
+
std::string word = BOW + words_[i].word + EOW;
|
185
|
+
words_[i].subwords.clear();
|
186
|
+
words_[i].subwords.push_back(i);
|
187
|
+
if (words_[i].word != EOS) {
|
188
|
+
computeSubwords(word, words_[i].subwords);
|
189
|
+
}
|
190
|
+
}
|
191
|
+
}
|
192
|
+
|
193
|
+
bool Dictionary::readWord(std::istream& in, std::string& word) const
|
194
|
+
{
|
195
|
+
char c;
|
196
|
+
std::streambuf& sb = *in.rdbuf();
|
197
|
+
word.clear();
|
198
|
+
while ((c = sb.sbumpc()) != EOF) {
|
199
|
+
if (c == ' ' || c == '\n' || c == '\r' || c == '\t' || c == '\v' ||
|
200
|
+
c == '\f' || c == '\0') {
|
201
|
+
if (word.empty()) {
|
202
|
+
if (c == '\n') {
|
203
|
+
word += EOS;
|
204
|
+
return true;
|
205
|
+
}
|
206
|
+
continue;
|
207
|
+
} else {
|
208
|
+
if (c == '\n')
|
209
|
+
sb.sungetc();
|
210
|
+
return true;
|
211
|
+
}
|
212
|
+
}
|
213
|
+
word.push_back(c);
|
214
|
+
}
|
215
|
+
// trigger eofbit
|
216
|
+
in.get();
|
217
|
+
return !word.empty();
|
218
|
+
}
|
219
|
+
|
220
|
+
void Dictionary::readFromFile(std::istream& in) {
|
221
|
+
std::string word;
|
222
|
+
int64_t minThreshold = 1;
|
223
|
+
while (readWord(in, word)) {
|
224
|
+
add(word);
|
225
|
+
if (ntokens_ % 1000000 == 0 && args_->verbose > 1) {
|
226
|
+
std::cerr << "\rRead " << ntokens_ / 1000000 << "M words" << std::flush;
|
227
|
+
}
|
228
|
+
if (size_ > 0.75 * MAX_VOCAB_SIZE) {
|
229
|
+
minThreshold++;
|
230
|
+
threshold(minThreshold, minThreshold);
|
231
|
+
}
|
232
|
+
}
|
233
|
+
threshold(args_->minCount, args_->minCountLabel);
|
234
|
+
initTableDiscard();
|
235
|
+
initNgrams();
|
236
|
+
if (args_->verbose > 0) {
|
237
|
+
std::cerr << "\rRead " << ntokens_ / 1000000 << "M words" << std::endl;
|
238
|
+
std::cerr << "Number of words: " << nwords_ << std::endl;
|
239
|
+
std::cerr << "Number of labels: " << nlabels_ << std::endl;
|
240
|
+
}
|
241
|
+
if (size_ == 0) {
|
242
|
+
std::cerr << "Empty vocabulary. Try a smaller -minCount value."
|
243
|
+
<< std::endl;
|
244
|
+
exit(EXIT_FAILURE);
|
245
|
+
}
|
246
|
+
}
|
247
|
+
|
248
|
+
void Dictionary::threshold(int64_t t, int64_t tl) {
|
249
|
+
sort(words_.begin(), words_.end(), [](const entry& e1, const entry& e2) {
|
250
|
+
if (e1.type != e2.type) return e1.type < e2.type;
|
251
|
+
return e1.count > e2.count;
|
252
|
+
});
|
253
|
+
words_.erase(remove_if(words_.begin(), words_.end(), [&](const entry& e) {
|
254
|
+
return (e.type == entry_type::word && e.count < t) ||
|
255
|
+
(e.type == entry_type::label && e.count < tl);
|
256
|
+
}), words_.end());
|
257
|
+
words_.shrink_to_fit();
|
258
|
+
size_ = 0;
|
259
|
+
nwords_ = 0;
|
260
|
+
nlabels_ = 0;
|
261
|
+
std::fill(word2int_.begin(), word2int_.end(), -1);
|
262
|
+
for (auto it = words_.begin(); it != words_.end(); ++it) {
|
263
|
+
int32_t h = find(it->word);
|
264
|
+
word2int_[h] = size_++;
|
265
|
+
if (it->type == entry_type::word) nwords_++;
|
266
|
+
if (it->type == entry_type::label) nlabels_++;
|
267
|
+
}
|
268
|
+
}
|
269
|
+
|
270
|
+
void Dictionary::initTableDiscard() {
|
271
|
+
pdiscard_.resize(size_);
|
272
|
+
for (size_t i = 0; i < size_; i++) {
|
273
|
+
real f = real(words_[i].count) / real(ntokens_);
|
274
|
+
pdiscard_[i] = std::sqrt(args_->t / f) + args_->t / f;
|
275
|
+
}
|
276
|
+
}
|
277
|
+
|
278
|
+
std::vector<int64_t> Dictionary::getCounts(entry_type type) const {
|
279
|
+
std::vector<int64_t> counts;
|
280
|
+
for (auto& w : words_) {
|
281
|
+
if (w.type == type) counts.push_back(w.count);
|
282
|
+
}
|
283
|
+
return counts;
|
284
|
+
}
|
285
|
+
|
286
|
+
void Dictionary::addWordNgrams(std::vector<int32_t>& line,
|
287
|
+
const std::vector<int32_t>& hashes,
|
288
|
+
int32_t n) const {
|
289
|
+
for (int32_t i = 0; i < hashes.size(); i++) {
|
290
|
+
uint64_t h = hashes[i];
|
291
|
+
for (int32_t j = i + 1; j < hashes.size() && j < i + n; j++) {
|
292
|
+
h = h * 116049371 + hashes[j];
|
293
|
+
pushHash(line, h % args_->bucket);
|
294
|
+
}
|
295
|
+
}
|
296
|
+
}
|
297
|
+
|
298
|
+
void Dictionary::addSubwords(std::vector<int32_t>& line,
|
299
|
+
const std::string& token,
|
300
|
+
int32_t wid) const {
|
301
|
+
if (wid < 0) { // out of vocab
|
302
|
+
computeSubwords(BOW + token + EOW, line);
|
303
|
+
} else {
|
304
|
+
if (args_->maxn <= 0) { // in vocab w/o subwords
|
305
|
+
line.push_back(wid);
|
306
|
+
} else { // in vocab w/ subwords
|
307
|
+
const std::vector<int32_t>& ngrams = getSubwords(wid);
|
308
|
+
line.insert(line.end(), ngrams.cbegin(), ngrams.cend());
|
309
|
+
}
|
310
|
+
}
|
311
|
+
}
|
312
|
+
|
313
|
+
void Dictionary::reset(std::istream& in) const {
|
314
|
+
if (in.eof()) {
|
315
|
+
in.clear();
|
316
|
+
in.seekg(std::streampos(0));
|
317
|
+
}
|
318
|
+
}
|
319
|
+
|
320
|
+
int32_t Dictionary::getLine(std::istream& in,
|
321
|
+
std::vector<int32_t>& words,
|
322
|
+
std::minstd_rand& rng) const {
|
323
|
+
std::uniform_real_distribution<> uniform(0, 1);
|
324
|
+
std::string token;
|
325
|
+
int32_t ntokens = 0;
|
326
|
+
|
327
|
+
reset(in);
|
328
|
+
words.clear();
|
329
|
+
while (readWord(in, token)) {
|
330
|
+
int32_t h = find(token);
|
331
|
+
int32_t wid = word2int_[h];
|
332
|
+
if (wid < 0) continue;
|
333
|
+
|
334
|
+
ntokens++;
|
335
|
+
if (getType(wid) == entry_type::word && !discard(wid, uniform(rng))) {
|
336
|
+
words.push_back(wid);
|
337
|
+
}
|
338
|
+
if (ntokens > MAX_LINE_SIZE || token == EOS) break;
|
339
|
+
}
|
340
|
+
return ntokens;
|
341
|
+
}
|
342
|
+
|
343
|
+
int32_t Dictionary::getLine(std::istream& in,
|
344
|
+
std::vector<int32_t>& words,
|
345
|
+
std::vector<int32_t>& labels,
|
346
|
+
std::minstd_rand& rng) const {
|
347
|
+
std::vector<int32_t> word_hashes;
|
348
|
+
std::string token;
|
349
|
+
int32_t ntokens = 0;
|
350
|
+
|
351
|
+
reset(in);
|
352
|
+
words.clear();
|
353
|
+
labels.clear();
|
354
|
+
while (readWord(in, token)) {
|
355
|
+
uint32_t h = hash(token);
|
356
|
+
int32_t wid = getId(token, h);
|
357
|
+
entry_type type = wid < 0 ? getType(token) : getType(wid);
|
358
|
+
|
359
|
+
ntokens++;
|
360
|
+
if (type == entry_type::word) {
|
361
|
+
addSubwords(words, token, wid);
|
362
|
+
word_hashes.push_back(h);
|
363
|
+
} else if (type == entry_type::label && wid >= 0) {
|
364
|
+
labels.push_back(wid - nwords_);
|
365
|
+
}
|
366
|
+
if (token == EOS) break;
|
367
|
+
}
|
368
|
+
addWordNgrams(words, word_hashes, args_->wordNgrams);
|
369
|
+
return ntokens;
|
370
|
+
}
|
371
|
+
|
372
|
+
void Dictionary::pushHash(std::vector<int32_t>& hashes, int32_t id) const {
|
373
|
+
if (pruneidx_size_ == 0 || id < 0) return;
|
374
|
+
if (pruneidx_size_ > 0) {
|
375
|
+
if (pruneidx_.count(id)) {
|
376
|
+
id = pruneidx_.at(id);
|
377
|
+
} else {
|
378
|
+
return;
|
379
|
+
}
|
380
|
+
}
|
381
|
+
hashes.push_back(nwords_ + id);
|
382
|
+
}
|
383
|
+
|
384
|
+
std::string Dictionary::getLabel(int32_t lid) const {
|
385
|
+
assert(lid >= 0);
|
386
|
+
assert(lid < nlabels_);
|
387
|
+
return words_[lid + nwords_].word;
|
388
|
+
}
|
389
|
+
|
390
|
+
void Dictionary::save(std::ostream& out) const {
|
391
|
+
out.write((char*) &size_, sizeof(int32_t));
|
392
|
+
out.write((char*) &nwords_, sizeof(int32_t));
|
393
|
+
out.write((char*) &nlabels_, sizeof(int32_t));
|
394
|
+
out.write((char*) &ntokens_, sizeof(int64_t));
|
395
|
+
out.write((char*) &pruneidx_size_, sizeof(int64_t));
|
396
|
+
for (int32_t i = 0; i < size_; i++) {
|
397
|
+
entry e = words_[i];
|
398
|
+
out.write(e.word.data(), e.word.size() * sizeof(char));
|
399
|
+
out.put(0);
|
400
|
+
out.write((char*) &(e.count), sizeof(int64_t));
|
401
|
+
out.write((char*) &(e.type), sizeof(entry_type));
|
402
|
+
}
|
403
|
+
for (const auto pair : pruneidx_) {
|
404
|
+
out.write((char*) &(pair.first), sizeof(int32_t));
|
405
|
+
out.write((char*) &(pair.second), sizeof(int32_t));
|
406
|
+
}
|
407
|
+
}
|
408
|
+
|
409
|
+
void Dictionary::load(std::istream& in) {
|
410
|
+
words_.clear();
|
411
|
+
std::fill(word2int_.begin(), word2int_.end(), -1);
|
412
|
+
in.read((char*) &size_, sizeof(int32_t));
|
413
|
+
in.read((char*) &nwords_, sizeof(int32_t));
|
414
|
+
in.read((char*) &nlabels_, sizeof(int32_t));
|
415
|
+
in.read((char*) &ntokens_, sizeof(int64_t));
|
416
|
+
in.read((char*) &pruneidx_size_, sizeof(int64_t));
|
417
|
+
for (int32_t i = 0; i < size_; i++) {
|
418
|
+
char c;
|
419
|
+
entry e;
|
420
|
+
while ((c = in.get()) != 0) {
|
421
|
+
e.word.push_back(c);
|
422
|
+
}
|
423
|
+
in.read((char*) &e.count, sizeof(int64_t));
|
424
|
+
in.read((char*) &e.type, sizeof(entry_type));
|
425
|
+
words_.push_back(e);
|
426
|
+
word2int_[find(e.word)] = i;
|
427
|
+
}
|
428
|
+
pruneidx_.clear();
|
429
|
+
for (int32_t i = 0; i < pruneidx_size_; i++) {
|
430
|
+
int32_t first;
|
431
|
+
int32_t second;
|
432
|
+
in.read((char*) &first, sizeof(int32_t));
|
433
|
+
in.read((char*) &second, sizeof(int32_t));
|
434
|
+
pruneidx_[first] = second;
|
435
|
+
}
|
436
|
+
initTableDiscard();
|
437
|
+
initNgrams();
|
438
|
+
}
|
439
|
+
|
440
|
+
void Dictionary::prune(std::vector<int32_t>& idx) {
|
441
|
+
std::vector<int32_t> words, ngrams;
|
442
|
+
for (auto it = idx.cbegin(); it != idx.cend(); ++it) {
|
443
|
+
if (*it < nwords_) {words.push_back(*it);}
|
444
|
+
else {ngrams.push_back(*it);}
|
445
|
+
}
|
446
|
+
std::sort(words.begin(), words.end());
|
447
|
+
idx = words;
|
448
|
+
|
449
|
+
if (ngrams.size() != 0) {
|
450
|
+
int32_t j = 0;
|
451
|
+
for (const auto ngram : ngrams) {
|
452
|
+
pruneidx_[ngram - nwords_] = j;
|
453
|
+
j++;
|
454
|
+
}
|
455
|
+
idx.insert(idx.end(), ngrams.begin(), ngrams.end());
|
456
|
+
}
|
457
|
+
pruneidx_size_ = pruneidx_.size();
|
458
|
+
|
459
|
+
std::fill(word2int_.begin(), word2int_.end(), -1);
|
460
|
+
|
461
|
+
int32_t j = 0;
|
462
|
+
for (int32_t i = 0; i < words_.size(); i++) {
|
463
|
+
if (getType(i) == entry_type::label || (j < words.size() && words[j] == i)) {
|
464
|
+
words_[j] = words_[i];
|
465
|
+
word2int_[find(words_[j].word)] = j;
|
466
|
+
j++;
|
467
|
+
}
|
468
|
+
}
|
469
|
+
nwords_ = words.size();
|
470
|
+
size_ = nwords_ + nlabels_;
|
471
|
+
words_.erase(words_.begin() + size_, words_.end());
|
472
|
+
initNgrams();
|
473
|
+
}
|
474
|
+
|
475
|
+
}
|