ffi-fasttext 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,33 @@
1
+ Additional Grant of Patent Rights Version 2
2
+
3
+ "Software" means the fastText software distributed by Facebook, Inc.
4
+
5
+ Facebook, Inc. ("Facebook") hereby grants to each recipient of the Software
6
+ ("you") a perpetual, worldwide, royalty-free, non-exclusive, irrevocable
7
+ (subject to the termination provision below) license under any Necessary
8
+ Claims, to make, have made, use, sell, offer to sell, import, and otherwise
9
+ transfer the Software. For avoidance of doubt, no license is granted under
10
+ Facebook's rights in any patent claims that are infringed by (i) modifications
11
+ to the Software made by you or any third party or (ii) the Software in
12
+ combination with any software or other technology.
13
+
14
+ The license granted hereunder will terminate, automatically and without notice,
15
+ if you (or any of your subsidiaries, corporate affiliates or agents) initiate
16
+ directly or indirectly, or take a direct financial interest in, any Patent
17
+ Assertion: (i) against Facebook or any of its subsidiaries or corporate
18
+ affiliates, (ii) against any party if such Patent Assertion arises in whole or
19
+ in part from any software, technology, product or service of Facebook or any of
20
+ its subsidiaries or corporate affiliates, or (iii) against any party relating
21
+ to the Software. Notwithstanding the foregoing, if Facebook or any of its
22
+ subsidiaries or corporate affiliates files a lawsuit alleging patent
23
+ infringement against you in the first instance, and you respond by filing a
24
+ patent infringement counterclaim in that lawsuit against that party that is
25
+ unrelated to the Software, the license granted hereunder will not terminate
26
+ under section (i) of this paragraph due to such counterclaim.
27
+
28
+ A "Necessary Claim" is a claim of a patent owned by Facebook that is
29
+ necessarily infringed by the Software standing alone.
30
+
31
+ A "Patent Assertion" is any lawsuit or other action alleging direct, indirect,
32
+ or contributory infringement or inducement to infringe any patent, including a
33
+ cross-claim or counterclaim.
@@ -0,0 +1,250 @@
1
+ /**
2
+ * Copyright (c) 2016-present, Facebook, Inc.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the BSD-style license found in the
6
+ * LICENSE file in the root directory of this source tree. An additional grant
7
+ * of patent rights can be found in the PATENTS file in the same directory.
8
+ */
9
+
10
+ #include "args.h"
11
+
12
+ #include <stdlib.h>
13
+
14
+ #include <iostream>
15
+
16
+ namespace fasttext {
17
+
18
// Sets the default value of every option; parseArgs() overrides these from
// the command line, and the "supervised" command further adjusts a few
// (lr, minCount, minn, maxn, loss).
Args::Args() {
  // Training defaults.
  lr = 0.05;
  dim = 100;
  ws = 5;
  epoch = 5;
  minCount = 5;
  minCountLabel = 0;
  neg = 5;
  wordNgrams = 1;
  loss = loss_name::ns;
  model = model_name::sg;
  bucket = 2000000;
  minn = 3;
  maxn = 6;
  thread = 12;
  lrUpdateRate = 100;
  t = 1e-4;
  label = "__label__";
  verbose = 2;
  pretrainedVectors = "";
  saveOutput = 0;

  // Quantization defaults.
  qout = false;
  retrain = false;
  qnorm = false;
  cutoff = 0;
  dsub = 2;
}
46
+
47
+ std::string Args::lossToString(loss_name ln) {
48
+ switch (ln) {
49
+ case loss_name::hs:
50
+ return "hs";
51
+ case loss_name::ns:
52
+ return "ns";
53
+ case loss_name::softmax:
54
+ return "softmax";
55
+ }
56
+ return "Unknown loss!"; // should never happen
57
+ }
58
+
59
+ void Args::parseArgs(const std::vector<std::string>& args) {
60
+ std::string command(args[1]);
61
+ if (command == "supervised") {
62
+ model = model_name::sup;
63
+ loss = loss_name::softmax;
64
+ minCount = 1;
65
+ minn = 0;
66
+ maxn = 0;
67
+ lr = 0.1;
68
+ } else if (command == "cbow") {
69
+ model = model_name::cbow;
70
+ }
71
+ int ai = 2;
72
+ while (ai < args.size()) {
73
+ if (args[ai][0] != '-') {
74
+ std::cerr << "Provided argument without a dash! Usage:" << std::endl;
75
+ printHelp();
76
+ exit(EXIT_FAILURE);
77
+ }
78
+ if (args[ai] == "-h") {
79
+ std::cerr << "Here is the help! Usage:" << std::endl;
80
+ printHelp();
81
+ exit(EXIT_FAILURE);
82
+ } else if (args[ai] == "-input") {
83
+ input = std::string(args[ai + 1]);
84
+ } else if (args[ai] == "-test") {
85
+ test = std::string(args[ai + 1]);
86
+ } else if (args[ai] == "-output") {
87
+ output = std::string(args[ai + 1]);
88
+ } else if (args[ai] == "-lr") {
89
+ lr = std::stof(args[ai + 1]);
90
+ } else if (args[ai] == "-lrUpdateRate") {
91
+ lrUpdateRate = std::stoi(args[ai + 1]);
92
+ } else if (args[ai] == "-dim") {
93
+ dim = std::stoi(args[ai + 1]);
94
+ } else if (args[ai] == "-ws") {
95
+ ws = std::stoi(args[ai + 1]);
96
+ } else if (args[ai] == "-epoch") {
97
+ epoch = std::stoi(args[ai + 1]);
98
+ } else if (args[ai] == "-minCount") {
99
+ minCount = std::stoi(args[ai + 1]);
100
+ } else if (args[ai] == "-minCountLabel") {
101
+ minCountLabel = std::stoi(args[ai + 1]);
102
+ } else if (args[ai] == "-neg") {
103
+ neg = std::stoi(args[ai + 1]);
104
+ } else if (args[ai] == "-wordNgrams") {
105
+ wordNgrams = std::stoi(args[ai + 1]);
106
+ } else if (args[ai] == "-loss") {
107
+ if (args[ai + 1] == "hs") {
108
+ loss = loss_name::hs;
109
+ } else if (args[ai + 1] == "ns") {
110
+ loss = loss_name::ns;
111
+ } else if (args[ai + 1] == "softmax") {
112
+ loss = loss_name::softmax;
113
+ } else {
114
+ std::cerr << "Unknown loss: " << args[ai + 1] << std::endl;
115
+ printHelp();
116
+ exit(EXIT_FAILURE);
117
+ }
118
+ } else if (args[ai] == "-bucket") {
119
+ bucket = std::stoi(args[ai + 1]);
120
+ } else if (args[ai] == "-minn") {
121
+ minn = std::stoi(args[ai + 1]);
122
+ } else if (args[ai] == "-maxn") {
123
+ maxn = std::stoi(args[ai + 1]);
124
+ } else if (args[ai] == "-thread") {
125
+ thread = std::stoi(args[ai + 1]);
126
+ } else if (args[ai] == "-t") {
127
+ t = std::stof(args[ai + 1]);
128
+ } else if (args[ai] == "-label") {
129
+ label = std::string(args[ai + 1]);
130
+ } else if (args[ai] == "-verbose") {
131
+ verbose = std::stoi(args[ai + 1]);
132
+ } else if (args[ai] == "-pretrainedVectors") {
133
+ pretrainedVectors = std::string(args[ai + 1]);
134
+ } else if (args[ai] == "-saveOutput") {
135
+ saveOutput = std::stoi(args[ai + 1]);
136
+ } else if (args[ai] == "-qnorm") {
137
+ qnorm = true; ai--;
138
+ } else if (args[ai] == "-retrain") {
139
+ retrain = true; ai--;
140
+ } else if (args[ai] == "-qout") {
141
+ qout = true; ai--;
142
+ } else if (args[ai] == "-cutoff") {
143
+ cutoff = std::stoi(args[ai + 1]);
144
+ } else if (args[ai] == "-dsub") {
145
+ dsub = std::stoi(args[ai + 1]);
146
+ } else {
147
+ std::cerr << "Unknown argument: " << args[ai] << std::endl;
148
+ printHelp();
149
+ exit(EXIT_FAILURE);
150
+ }
151
+ ai += 2;
152
+ }
153
+ if (input.empty() || output.empty()) {
154
+ std::cerr << "Empty input or output path." << std::endl;
155
+ printHelp();
156
+ exit(EXIT_FAILURE);
157
+ }
158
+ if (wordNgrams <= 1 && maxn == 0) {
159
+ bucket = 0;
160
+ }
161
+ }
162
+
163
// Prints the full usage text to stderr: basic, dictionary, training and
// quantization sections, in that order.
void Args::printHelp() {
  printBasicHelp();
  printDictionaryHelp();
  printTrainingHelp();
  printQuantizationHelp();
}
169
+
170
+
171
// Prints the mandatory arguments and -verbose to stderr, showing the current
// (default) value in brackets.
void Args::printBasicHelp() {
  std::cerr
    << "\nThe following arguments are mandatory:\n"
    << "  -input              training file path\n"
    << "  -output             output file path\n"
    << "\nThe following arguments are optional:\n"
    << "  -verbose            verbosity level [" << verbose << "]\n";
}
179
+
180
// Prints the dictionary-related options to stderr, showing current values
// in brackets. (The "occurences" spelling is part of the emitted output and
// is left untouched.)
void Args::printDictionaryHelp() {
  std::cerr
    << "\nThe following arguments for the dictionary are optional:\n"
    << "  -minCount           minimal number of word occurences [" << minCount << "]\n"
    << "  -minCountLabel      minimal number of label occurences [" << minCountLabel << "]\n"
    << "  -wordNgrams         max length of word ngram [" << wordNgrams << "]\n"
    << "  -bucket             number of buckets [" << bucket << "]\n"
    << "  -minn               min length of char ngram [" << minn << "]\n"
    << "  -maxn               max length of char ngram [" << maxn << "]\n"
    << "  -t                  sampling threshold [" << t << "]\n"
    << "  -label              labels prefix [" << label << "]\n";
}
192
+
193
// Prints the training-related options to stderr, showing current values
// in brackets.
void Args::printTrainingHelp() {
  std::cerr
    << "\nThe following arguments for training are optional:\n"
    << "  -lr                 learning rate [" << lr << "]\n"
    << "  -lrUpdateRate       change the rate of updates for the learning rate [" << lrUpdateRate << "]\n"
    << "  -dim                size of word vectors [" << dim << "]\n"
    << "  -ws                 size of the context window [" << ws << "]\n"
    << "  -epoch              number of epochs [" << epoch << "]\n"
    << "  -neg                number of negatives sampled [" << neg << "]\n"
    << "  -loss               loss function {ns, hs, softmax} [" << lossToString(loss) << "]\n"
    << "  -thread             number of threads [" << thread << "]\n"
    << "  -pretrainedVectors  pretrained word vectors for supervised learning ["<< pretrainedVectors <<"]\n"
    << "  -saveOutput         whether output params should be saved [" << saveOutput << "]\n";
}
207
+
208
// Prints the quantization-related options to stderr, showing current values
// in brackets.
void Args::printQuantizationHelp() {
  std::cerr
    << "\nThe following arguments for quantization are optional:\n"
    << "  -cutoff             number of words and ngrams to retain [" << cutoff << "]\n"
    << "  -retrain            finetune embeddings if a cutoff is applied [" << retrain << "]\n"
    << "  -qnorm              quantizing the norm separately [" << qnorm << "]\n"
    << "  -qout               quantizing the classifier [" << qout << "]\n"
    << "  -dsub               size of each sub-vector [" << dsub << "]\n";
}
217
+
218
// Serializes the model-relevant options as raw binary. The field order and
// widths here are the on-disk format: Args::load() must read the exact same
// sequence. (Note: writes raw host-endian bytes, so files are not portable
// across differing endianness.)
void Args::save(std::ostream& out) {
  out.write((char*) &(dim), sizeof(int));
  out.write((char*) &(ws), sizeof(int));
  out.write((char*) &(epoch), sizeof(int));
  out.write((char*) &(minCount), sizeof(int));
  out.write((char*) &(neg), sizeof(int));
  out.write((char*) &(wordNgrams), sizeof(int));
  out.write((char*) &(loss), sizeof(loss_name));
  out.write((char*) &(model), sizeof(model_name));
  out.write((char*) &(bucket), sizeof(int));
  out.write((char*) &(minn), sizeof(int));
  out.write((char*) &(maxn), sizeof(int));
  out.write((char*) &(lrUpdateRate), sizeof(int));
  out.write((char*) &(t), sizeof(double));
}
233
+
234
// Deserializes the options written by Args::save(); field order and widths
// must mirror save() exactly. Stream errors are not checked here — values
// are left unchanged on a failed read.
void Args::load(std::istream& in) {
  in.read((char*) &(dim), sizeof(int));
  in.read((char*) &(ws), sizeof(int));
  in.read((char*) &(epoch), sizeof(int));
  in.read((char*) &(minCount), sizeof(int));
  in.read((char*) &(neg), sizeof(int));
  in.read((char*) &(wordNgrams), sizeof(int));
  in.read((char*) &(loss), sizeof(loss_name));
  in.read((char*) &(model), sizeof(model_name));
  in.read((char*) &(bucket), sizeof(int));
  in.read((char*) &(minn), sizeof(int));
  in.read((char*) &(maxn), sizeof(int));
  in.read((char*) &(lrUpdateRate), sizeof(int));
  in.read((char*) &(t), sizeof(double));
}
249
+
250
+ }
@@ -0,0 +1,71 @@
1
+ /**
2
+ * Copyright (c) 2016-present, Facebook, Inc.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the BSD-style license found in the
6
+ * LICENSE file in the root directory of this source tree. An additional grant
7
+ * of patent rights can be found in the PATENTS file in the same directory.
8
+ */
9
+
10
#ifndef FASTTEXT_ARGS_H
#define FASTTEXT_ARGS_H

#include <istream>
#include <ostream>
#include <string>
#include <vector>

namespace fasttext {

// Training architectures: continuous bag-of-words, skipgram, supervised.
enum class model_name : int {cbow=1, sg, sup};
// Loss functions: hierarchical softmax, negative sampling, full softmax.
enum class loss_name : int {hs=1, ns, softmax};

// Holds every fastText command-line option. Defaults are set in the
// constructor and overridden by parseArgs(); save()/load() (de)serialize
// the subset of fields a trained model needs.
class Args {
  private:
    std::string lossToString(loss_name);

  public:
    Args();
    std::string input;              // training file path
    std::string test;               // test file path
    std::string output;             // output file path
    double lr;                      // learning rate
    int lrUpdateRate;               // rate of updates for the learning rate
    int dim;                        // size of word vectors
    int ws;                         // size of the context window
    int epoch;                      // number of epochs
    int minCount;                   // minimal number of word occurrences
    int minCountLabel;              // minimal number of label occurrences
    int neg;                        // number of negatives sampled
    int wordNgrams;                 // max length of word ngram
    loss_name loss;                 // loss function
    model_name model;               // model architecture
    int bucket;                     // number of hashing buckets for n-grams
    int minn;                      // min length of char ngram
    int maxn;                      // max length of char ngram
    int thread;                     // number of threads
    double t;                       // sampling threshold
    std::string label;              // labels prefix
    int verbose;                    // verbosity level
    std::string pretrainedVectors;  // path to pretrained word vectors
    int saveOutput;                 // whether output params should be saved

    // Quantization options.
    bool qout;                      // quantize the classifier
    bool retrain;                   // finetune embeddings if a cutoff is applied
    bool qnorm;                     // quantize the norm separately
    size_t cutoff;                  // number of words and ngrams to retain
    size_t dsub;                    // size of each sub-vector

    void parseArgs(const std::vector<std::string>& args);
    void printHelp();
    void printBasicHelp();
    void printDictionaryHelp();
    void printTrainingHelp();
    void printQuantizationHelp();
    void save(std::ostream&);
    void load(std::istream&);
};

}

#endif
@@ -0,0 +1,475 @@
1
+ /**
2
+ * Copyright (c) 2016-present, Facebook, Inc.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the BSD-style license found in the
6
+ * LICENSE file in the root directory of this source tree. An additional grant
7
+ * of patent rights can be found in the PATENTS file in the same directory.
8
+ */
9
+
10
+ #include "dictionary.h"
11
+
12
+ #include <assert.h>
13
+
14
+ #include <iostream>
15
+ #include <fstream>
16
+ #include <algorithm>
17
+ #include <iterator>
18
+ #include <cmath>
19
+
20
+ namespace fasttext {
21
+
22
// Sentinel tokens: EOS marks end of sentence; BOW/EOW bracket a word before
// its character n-grams are extracted (e.g. "where" -> "<where>").
const std::string Dictionary::EOS = "</s>";
const std::string Dictionary::BOW = "<";
const std::string Dictionary::EOW = ">";

// Builds an empty dictionary. word2int_ is an open-addressing hash table of
// MAX_VOCAB_SIZE slots where -1 means "empty"; pruneidx_size_ starts at -1,
// which pushHash() treats as "never pruned".
Dictionary::Dictionary(std::shared_ptr<Args> args) : args_(args),
  word2int_(MAX_VOCAB_SIZE, -1), size_(0), nwords_(0), nlabels_(0),
  ntokens_(0), pruneidx_size_(-1) {}
29
+
30
+ int32_t Dictionary::find(const std::string& w) const {
31
+ return find(w, hash(w));
32
+ }
33
+
34
+ int32_t Dictionary::find(const std::string& w, uint32_t h) const {
35
+ int32_t id = h % MAX_VOCAB_SIZE;
36
+ while (word2int_[id] != -1 && words_[word2int_[id]].word != w) {
37
+ id = (id + 1) % MAX_VOCAB_SIZE;
38
+ }
39
+ return id;
40
+ }
41
+
42
+ void Dictionary::add(const std::string& w) {
43
+ int32_t h = find(w);
44
+ ntokens_++;
45
+ if (word2int_[h] == -1) {
46
+ entry e;
47
+ e.word = w;
48
+ e.count = 1;
49
+ e.type = getType(w);
50
+ words_.push_back(e);
51
+ word2int_[h] = size_++;
52
+ } else {
53
+ words_[word2int_[h]].count++;
54
+ }
55
+ }
56
+
57
// Number of distinct in-vocabulary words (labels excluded).
int32_t Dictionary::nwords() const {
  return nwords_;
}

// Number of distinct labels.
int32_t Dictionary::nlabels() const {
  return nlabels_;
}

// Total number of tokens seen so far, duplicates included.
int64_t Dictionary::ntokens() const {
  return ntokens_;
}
68
+
69
// Precomputed subword ids (the word id itself plus char-ngram bucket ids,
// filled in by initNgrams()) of in-vocabulary word i.
const std::vector<int32_t>& Dictionary::getSubwords(int32_t i) const {
  assert(i >= 0);
  assert(i < nwords_);
  return words_[i].subwords;
}

// Subword ids of an arbitrary word: the cached vector for in-vocabulary
// words, otherwise char n-grams of "<word>" computed on the fly.
const std::vector<int32_t> Dictionary::getSubwords(
    const std::string& word) const {
  int32_t i = getId(word);
  if (i >= 0) {
    return getSubwords(i);
  }
  std::vector<int32_t> ngrams;
  computeSubwords(BOW + word + EOW, ngrams);
  return ngrams;
}

// Like the overload above, but also reports the substrings aligned with
// `ngrams`. The word itself comes first, with id -1 when out of vocabulary.
void Dictionary::getSubwords(const std::string& word,
                             std::vector<int32_t>& ngrams,
                             std::vector<std::string>& substrings) const {
  int32_t i = getId(word);
  ngrams.clear();
  substrings.clear();
  if (i >= 0) {
    ngrams.push_back(i);
    substrings.push_back(words_[i].word);
  } else {
    ngrams.push_back(-1);
    substrings.push_back(word);
  }
  computeSubwords(BOW + word + EOW, ngrams, substrings);
}
101
+
102
// True when word `id` should be dropped by frequency subsampling: `rand`
// (a uniform [0,1) draw) is compared against the keep-probability table
// built by initTableDiscard(). Supervised models never discard.
bool Dictionary::discard(int32_t id, real rand) const {
  assert(id >= 0);
  assert(id < nwords_);
  if (args_->model == model_name::sup) return false;
  return rand > pdiscard_[id];
}

// Word id for `w` given its precomputed hash, or -1 if unknown.
int32_t Dictionary::getId(const std::string& w, uint32_t h) const {
  int32_t id = find(w, h);
  return word2int_[id];
}

// Word id for `w`, or -1 if unknown.
int32_t Dictionary::getId(const std::string& w) const {
  int32_t h = find(w);
  return word2int_[h];
}

// Entry type (word or label) of a known id.
entry_type Dictionary::getType(int32_t id) const {
  assert(id >= 0);
  assert(id < size_);
  return words_[id].type;
}

// Entry type of a raw token: label iff it starts with the label prefix.
entry_type Dictionary::getType(const std::string& w) const {
  return (w.find(args_->label) == 0) ? entry_type::label : entry_type::word;
}

// Surface form of a known id.
std::string Dictionary::getWord(int32_t id) const {
  assert(id >= 0);
  assert(id < size_);
  return words_[id].word;
}
134
+
135
+ uint32_t Dictionary::hash(const std::string& str) const {
136
+ uint32_t h = 2166136261;
137
+ for (size_t i = 0; i < str.size(); i++) {
138
+ h = h ^ uint32_t(str[i]);
139
+ h = h * 16777619;
140
+ }
141
+ return h;
142
+ }
143
+
144
// Extracts all character n-grams of `word` whose length (in code points) is
// in [minn, maxn]. Bytes matching 0b10xxxxxx are UTF-8 continuation bytes:
// they never start an n-gram and are absorbed into the current code point,
// so `n` counts characters rather than bytes. Length-1 n-grams at the very
// start or end (i.e. the bare BOW/EOW markers) are skipped. Each n-gram is
// hashed into one of args_->bucket buckets and appended to `ngrams` offset
// by nwords_; the raw substring is appended to `substrings` in parallel.
void Dictionary::computeSubwords(const std::string& word,
                                 std::vector<int32_t>& ngrams,
                                 std::vector<std::string>& substrings) const {
  for (size_t i = 0; i < word.size(); i++) {
    std::string ngram;
    if ((word[i] & 0xC0) == 0x80) continue;  // don't start inside a code point
    for (size_t j = i, n = 1; j < word.size() && n <= args_->maxn; n++) {
      ngram.push_back(word[j++]);
      // Pull in the continuation bytes of the current code point.
      while (j < word.size() && (word[j] & 0xC0) == 0x80) {
        ngram.push_back(word[j++]);
      }
      if (n >= args_->minn && !(n == 1 && (i == 0 || j == word.size()))) {
        int32_t h = hash(ngram) % args_->bucket;
        ngrams.push_back(nwords_ + h);
        substrings.push_back(ngram);
      }
    }
  }
}
163
+
164
// Same n-gram extraction as the three-argument overload, but emits only ids
// and routes them through pushHash(), which applies the pruning index when
// the dictionary has been quantization-pruned.
void Dictionary::computeSubwords(const std::string& word,
                                 std::vector<int32_t>& ngrams) const {
  for (size_t i = 0; i < word.size(); i++) {
    std::string ngram;
    if ((word[i] & 0xC0) == 0x80) continue;  // don't start inside a code point
    for (size_t j = i, n = 1; j < word.size() && n <= args_->maxn; n++) {
      ngram.push_back(word[j++]);
      // Pull in the continuation bytes of the current code point.
      while (j < word.size() && (word[j] & 0xC0) == 0x80) {
        ngram.push_back(word[j++]);
      }
      if (n >= args_->minn && !(n == 1 && (i == 0 || j == word.size()))) {
        int32_t h = hash(ngram) % args_->bucket;
        pushHash(ngrams, h);
      }
    }
  }
}
181
+
182
+ void Dictionary::initNgrams() {
183
+ for (size_t i = 0; i < size_; i++) {
184
+ std::string word = BOW + words_[i].word + EOW;
185
+ words_[i].subwords.clear();
186
+ words_[i].subwords.push_back(i);
187
+ if (words_[i].word != EOS) {
188
+ computeSubwords(word, words_[i].subwords);
189
+ }
190
+ }
191
+ }
192
+
193
+ bool Dictionary::readWord(std::istream& in, std::string& word) const
194
+ {
195
+ char c;
196
+ std::streambuf& sb = *in.rdbuf();
197
+ word.clear();
198
+ while ((c = sb.sbumpc()) != EOF) {
199
+ if (c == ' ' || c == '\n' || c == '\r' || c == '\t' || c == '\v' ||
200
+ c == '\f' || c == '\0') {
201
+ if (word.empty()) {
202
+ if (c == '\n') {
203
+ word += EOS;
204
+ return true;
205
+ }
206
+ continue;
207
+ } else {
208
+ if (c == '\n')
209
+ sb.sungetc();
210
+ return true;
211
+ }
212
+ }
213
+ word.push_back(c);
214
+ }
215
+ // trigger eofbit
216
+ in.get();
217
+ return !word.empty();
218
+ }
219
+
220
// Builds the vocabulary from a whitespace-tokenized stream. Every token is
// counted; whenever the table exceeds 75% of MAX_VOCAB_SIZE, the rarest
// entries are evicted with a progressively larger minimum-count threshold.
// Afterwards the user-requested -minCount/-minCountLabel thresholds are
// applied and the subsampling table and subword n-grams are built.
// Exits the process if the resulting vocabulary is empty.
void Dictionary::readFromFile(std::istream& in) {
  std::string word;
  int64_t minThreshold = 1;
  while (readWord(in, word)) {
    add(word);
    // Progress log every 1M tokens.
    if (ntokens_ % 1000000 == 0 && args_->verbose > 1) {
      std::cerr << "\rRead " << ntokens_ / 1000000 << "M words" << std::flush;
    }
    if (size_ > 0.75 * MAX_VOCAB_SIZE) {
      minThreshold++;
      threshold(minThreshold, minThreshold);
    }
  }
  threshold(args_->minCount, args_->minCountLabel);
  initTableDiscard();
  initNgrams();
  if (args_->verbose > 0) {
    std::cerr << "\rRead " << ntokens_ / 1000000 << "M words" << std::endl;
    std::cerr << "Number of words:  " << nwords_ << std::endl;
    std::cerr << "Number of labels: " << nlabels_ << std::endl;
  }
  if (size_ == 0) {
    std::cerr << "Empty vocabulary. Try a smaller -minCount value."
              << std::endl;
    exit(EXIT_FAILURE);
  }
}
247
+
248
// Drops words with count < t and labels with count < tl, then rebuilds the
// hash table and the size/word/label counters. Entries end up sorted by
// type (words before labels) and by decreasing count — an ordering other
// code relies on (e.g. getLabel() indexes labels at nwords_ + lid).
void Dictionary::threshold(int64_t t, int64_t tl) {
  sort(words_.begin(), words_.end(), [](const entry& e1, const entry& e2) {
      if (e1.type != e2.type) return e1.type < e2.type;
      return e1.count > e2.count;
    });
  words_.erase(remove_if(words_.begin(), words_.end(), [&](const entry& e) {
        return (e.type == entry_type::word && e.count < t) ||
               (e.type == entry_type::label && e.count < tl);
      }), words_.end());
  words_.shrink_to_fit();
  // Re-insert every surviving entry into a cleared hash table.
  size_ = 0;
  nwords_ = 0;
  nlabels_ = 0;
  std::fill(word2int_.begin(), word2int_.end(), -1);
  for (auto it = words_.begin(); it != words_.end(); ++it) {
    int32_t h = find(it->word);
    word2int_[h] = size_++;
    if (it->type == entry_type::word) nwords_++;
    if (it->type == entry_type::label) nlabels_++;
  }
}
269
+
270
+ void Dictionary::initTableDiscard() {
271
+ pdiscard_.resize(size_);
272
+ for (size_t i = 0; i < size_; i++) {
273
+ real f = real(words_[i].count) / real(ntokens_);
274
+ pdiscard_[i] = std::sqrt(args_->t / f) + args_->t / f;
275
+ }
276
+ }
277
+
278
+ std::vector<int64_t> Dictionary::getCounts(entry_type type) const {
279
+ std::vector<int64_t> counts;
280
+ for (auto& w : words_) {
281
+ if (w.type == type) counts.push_back(w.count);
282
+ }
283
+ return counts;
284
+ }
285
+
286
// Appends word n-gram features to `line`: for every window of up to n
// consecutive token hashes, a rolling hash (x*116049371 + next) is folded
// into a bucket id and pushed through pushHash() (which applies pruning).
void Dictionary::addWordNgrams(std::vector<int32_t>& line,
                               const std::vector<int32_t>& hashes,
                               int32_t n) const {
  for (int32_t i = 0; i < hashes.size(); i++) {
    uint64_t h = hashes[i];
    for (int32_t j = i + 1; j < hashes.size() && j < i + n; j++) {
      h = h * 116049371 + hashes[j];
      pushHash(line, h % args_->bucket);
    }
  }
}
297
+
298
// Appends the feature ids of one token to `line`: out-of-vocabulary tokens
// contribute only their char n-grams; in-vocabulary tokens contribute their
// cached subword list, or just their own id when char n-grams are disabled.
void Dictionary::addSubwords(std::vector<int32_t>& line,
                             const std::string& token,
                             int32_t wid) const {
  if (wid < 0) {  // out of vocab
    computeSubwords(BOW + token + EOW, line);
  } else {
    if (args_->maxn <= 0) {  // in vocab w/o subwords
      line.push_back(wid);
    } else {  // in vocab w/ subwords
      const std::vector<int32_t>& ngrams = getSubwords(wid);
      line.insert(line.end(), ngrams.cbegin(), ngrams.cend());
    }
  }
}
312
+
313
+ void Dictionary::reset(std::istream& in) const {
314
+ if (in.eof()) {
315
+ in.clear();
316
+ in.seekg(std::streampos(0));
317
+ }
318
+ }
319
+
320
// Reads one line for unsupervised training, rewinding the stream first if
// it is at EOF. Known word tokens that survive frequency subsampling
// (discard()) are appended to `words`; unknown tokens are skipped entirely.
// Stops at EOS or after MAX_LINE_SIZE tokens; returns the number of known
// tokens consumed.
int32_t Dictionary::getLine(std::istream& in,
                            std::vector<int32_t>& words,
                            std::minstd_rand& rng) const {
  std::uniform_real_distribution<> uniform(0, 1);
  std::string token;
  int32_t ntokens = 0;

  reset(in);
  words.clear();
  while (readWord(in, token)) {
    int32_t h = find(token);
    int32_t wid = word2int_[h];
    if (wid < 0) continue;  // unknown token

    ntokens++;
    if (getType(wid) == entry_type::word && !discard(wid, uniform(rng))) {
      words.push_back(wid);
    }
    if (ntokens > MAX_LINE_SIZE || token == EOS) break;
  }
  return ntokens;
}
342
+
343
// Reads one line for supervised training. Word tokens are expanded into
// their subword ids via addSubwords() (OOV words contribute n-grams only)
// and their hashes collected for word n-grams; label tokens are stored as
// label-local ids (wid - nwords_). Word-ngram bucket ids are appended at
// the end. Returns the total number of tokens read.
int32_t Dictionary::getLine(std::istream& in,
                            std::vector<int32_t>& words,
                            std::vector<int32_t>& labels,
                            std::minstd_rand& rng) const {
  std::vector<int32_t> word_hashes;
  std::string token;
  int32_t ntokens = 0;

  reset(in);
  words.clear();
  labels.clear();
  while (readWord(in, token)) {
    uint32_t h = hash(token);
    int32_t wid = getId(token, h);
    // For unknown tokens the type is inferred from the label prefix.
    entry_type type = wid < 0 ? getType(token) : getType(wid);

    ntokens++;
    if (type == entry_type::word) {
      addSubwords(words, token, wid);
      word_hashes.push_back(h);
    } else if (type == entry_type::label && wid >= 0) {
      labels.push_back(wid - nwords_);
    }
    if (token == EOS) break;
  }
  addWordNgrams(words, word_hashes, args_->wordNgrams);
  return ntokens;
}
371
+
372
// Appends an n-gram bucket id (offset by nwords_) to `hashes`, translating
// through pruneidx_ when the dictionary has been quantization-pruned.
// pruneidx_size_ == 0 means every n-gram was pruned (drop everything);
// -1 means never pruned (pass ids through unchanged).
void Dictionary::pushHash(std::vector<int32_t>& hashes, int32_t id) const {
  if (pruneidx_size_ == 0 || id < 0) return;
  if (pruneidx_size_ > 0) {
    if (pruneidx_.count(id)) {
      id = pruneidx_.at(id);  // remap to the post-pruning index
    } else {
      return;  // this n-gram was pruned away
    }
  }
  hashes.push_back(nwords_ + id);
}
383
+
384
+ std::string Dictionary::getLabel(int32_t lid) const {
385
+ assert(lid >= 0);
386
+ assert(lid < nlabels_);
387
+ return words_[lid + nwords_].word;
388
+ }
389
+
390
// Serializes the dictionary as raw binary: the five counters, then each
// entry as a NUL-terminated word followed by its count and type, then the
// pruning index pairs. Dictionary::load() must read the same sequence.
void Dictionary::save(std::ostream& out) const {
  out.write((char*) &size_, sizeof(int32_t));
  out.write((char*) &nwords_, sizeof(int32_t));
  out.write((char*) &nlabels_, sizeof(int32_t));
  out.write((char*) &ntokens_, sizeof(int64_t));
  out.write((char*) &pruneidx_size_, sizeof(int64_t));
  for (int32_t i = 0; i < size_; i++) {
    entry e = words_[i];
    out.write(e.word.data(), e.word.size() * sizeof(char));
    out.put(0);  // NUL terminator; load() reads the word up to this byte
    out.write((char*) &(e.count), sizeof(int64_t));
    out.write((char*) &(e.type), sizeof(entry_type));
  }
  for (const auto pair : pruneidx_) {
    out.write((char*) &(pair.first), sizeof(int32_t));
    out.write((char*) &(pair.second), sizeof(int32_t));
  }
}
408
+
409
+ void Dictionary::load(std::istream& in) {
410
+ words_.clear();
411
+ std::fill(word2int_.begin(), word2int_.end(), -1);
412
+ in.read((char*) &size_, sizeof(int32_t));
413
+ in.read((char*) &nwords_, sizeof(int32_t));
414
+ in.read((char*) &nlabels_, sizeof(int32_t));
415
+ in.read((char*) &ntokens_, sizeof(int64_t));
416
+ in.read((char*) &pruneidx_size_, sizeof(int64_t));
417
+ for (int32_t i = 0; i < size_; i++) {
418
+ char c;
419
+ entry e;
420
+ while ((c = in.get()) != 0) {
421
+ e.word.push_back(c);
422
+ }
423
+ in.read((char*) &e.count, sizeof(int64_t));
424
+ in.read((char*) &e.type, sizeof(entry_type));
425
+ words_.push_back(e);
426
+ word2int_[find(e.word)] = i;
427
+ }
428
+ pruneidx_.clear();
429
+ for (int32_t i = 0; i < pruneidx_size_; i++) {
430
+ int32_t first;
431
+ int32_t second;
432
+ in.read((char*) &first, sizeof(int32_t));
433
+ in.read((char*) &second, sizeof(int32_t));
434
+ pruneidx_[first] = second;
435
+ }
436
+ initTableDiscard();
437
+ initNgrams();
438
+ }
439
+
440
// Prunes the dictionary to the feature ids in `idx` (used by quantization).
// Ids below nwords_ are words; the rest are n-gram buckets, which get
// remapped to a dense 0..k index stored in pruneidx_ (consulted later by
// pushHash()). Surviving words are compacted in place, all labels are kept,
// and `idx` is rewritten as sorted surviving words followed by the n-grams.
void Dictionary::prune(std::vector<int32_t>& idx) {
  std::vector<int32_t> words, ngrams;
  for (auto it = idx.cbegin(); it != idx.cend(); ++it) {
    if (*it < nwords_) {words.push_back(*it);}
    else {ngrams.push_back(*it);}
  }
  std::sort(words.begin(), words.end());
  idx = words;

  if (ngrams.size() != 0) {
    // Dense re-numbering of the kept n-gram buckets.
    int32_t j = 0;
    for (const auto ngram : ngrams) {
      pruneidx_[ngram - nwords_] = j;
      j++;
    }
    idx.insert(idx.end(), ngrams.begin(), ngrams.end());
  }
  pruneidx_size_ = pruneidx_.size();

  std::fill(word2int_.begin(), word2int_.end(), -1);

  // Compact words_: keep every label and every surviving word, reinserting
  // each into the cleared hash table. Relies on `words` being sorted in the
  // same order as words_ indices.
  int32_t j = 0;
  for (int32_t i = 0; i < words_.size(); i++) {
    if (getType(i) == entry_type::label || (j < words.size() && words[j] == i)) {
      words_[j] = words_[i];
      word2int_[find(words_[j].word)] = j;
      j++;
    }
  }
  nwords_ = words.size();
  size_ = nwords_ + nlabels_;
  words_.erase(words_.begin() + size_, words_.end());
  initNgrams();
}
474
+
475
+ }