ffi-fasttext 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,112 @@
1
+ /**
2
+ * Copyright (c) 2016-present, Facebook, Inc.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the BSD-style license found in the
6
+ * LICENSE file in the root directory of this source tree. An additional grant
7
+ * of patent rights can be found in the PATENTS file in the same directory.
8
+ */
9
+
10
+ #ifndef FASTTEXT_DICTIONARY_H
11
+ #define FASTTEXT_DICTIONARY_H
12
+
13
+ #include <vector>
14
+ #include <string>
15
+ #include <istream>
16
+ #include <ostream>
17
+ #include <random>
18
+ #include <memory>
19
+ #include <unordered_map>
20
+
21
+ #include "args.h"
22
+ #include "real.h"
23
+
24
+ namespace fasttext {
25
+
26
// Identifier type for dictionary entries.
typedef int32_t id_type;
// A dictionary entry is either a regular word or a classification label.
enum class entry_type : int8_t {word=0, label=1};
28
+
29
// One vocabulary entry: the surface form, its corpus frequency, whether it
// is a word or a label, and the ids of its character-ngram subwords.
struct entry {
  std::string word;
  int64_t count;
  entry_type type;
  std::vector<int32_t> subwords;
};
35
+
36
// Vocabulary manager: maps words and labels to dense integer ids, tracks
// their counts, computes character-ngram subwords, and provides the
// token streams (getLine) used during training.
class Dictionary {
  private:
    // Hard caps on the hash table used for word lookup and on the number
    // of tokens read per line.
    static const int32_t MAX_VOCAB_SIZE = 30000000;
    static const int32_t MAX_LINE_SIZE = 1024;

    // Locate a word's slot in word2int_ (optionally with a precomputed hash).
    int32_t find(const std::string&) const;
    int32_t find(const std::string&, uint32_t h) const;
    // Build the sub-sampling (discard) probability table.
    void initTableDiscard();
    // Precompute the subword ids of every vocabulary word.
    void initNgrams();
    // Rewind a stream (used when a reader thread hits EOF mid-epoch).
    void reset(std::istream&) const;
    void pushHash(std::vector<int32_t>&, int32_t) const;
    void addSubwords(std::vector<int32_t>&, const std::string&, int32_t) const;

    std::shared_ptr<Args> args_;
    // Open-addressed map from word hash slot to id in words_.
    std::vector<int32_t> word2int_;
    std::vector<entry> words_;

    // Per-word probability of keeping an occurrence (sub-sampling).
    std::vector<real> pdiscard_;
    int32_t size_;     // total entries (words + labels)
    int32_t nwords_;   // word entries only
    int32_t nlabels_;  // label entries only
    int64_t ntokens_;  // total tokens seen in the training corpus

    // Pruning state for quantized models; presumably -1 (or negative)
    // when the dictionary was never pruned — see isPruned().
    int64_t pruneidx_size_;
    std::unordered_map<int32_t, int32_t> pruneidx_;
    // Append the hashes of word n-grams (n = 2..n) of `hashes` to `line`.
    void addWordNgrams(
        std::vector<int32_t>& line,
        const std::vector<int32_t>& hashes,
        int32_t n) const;


  public:
    // Special tokens: end-of-sentence, begin-of-word, end-of-word markers.
    static const std::string EOS;
    static const std::string BOW;
    static const std::string EOW;

    explicit Dictionary(std::shared_ptr<Args>);
    int32_t nwords() const;
    int32_t nlabels() const;
    int64_t ntokens() const;
    // Id lookup; returns a negative id when the word is unknown.
    int32_t getId(const std::string&) const;
    int32_t getId(const std::string&, uint32_t h) const;
    entry_type getType(int32_t) const;
    entry_type getType(const std::string&) const;
    // Whether to drop this occurrence given a uniform random draw `rand`.
    bool discard(int32_t, real) const;
    std::string getWord(int32_t) const;
    // Subword ids of a known word id (reference into precomputed storage)
    // vs. of an arbitrary string (computed, returned by value).
    const std::vector<int32_t>& getSubwords(int32_t) const;
    const std::vector<int32_t> getSubwords(const std::string&) const;
    void computeSubwords(const std::string&, std::vector<int32_t>&) const;
    void computeSubwords(
        const std::string&,
        std::vector<int32_t>&,
        std::vector<std::string>&) const;
    // Also returns the ngram strings themselves (for inspection tools).
    void getSubwords(
        const std::string&,
        std::vector<int32_t>&,
        std::vector<std::string>&) const;
    uint32_t hash(const std::string& str) const;
    void add(const std::string&);
    // Read one whitespace-delimited token; returns false at end of stream.
    bool readWord(std::istream&, std::string&) const;
    void readFromFile(std::istream&);
    std::string getLabel(int32_t) const;
    // Binary (de)serialization of the dictionary state.
    void save(std::ostream&) const;
    void load(std::istream&);
    std::vector<int64_t> getCounts(entry_type) const;
    // Read one line into word ids (and labels, for the first overload);
    // returns the number of raw tokens consumed.
    int32_t getLine(std::istream&, std::vector<int32_t>&,
                    std::vector<int32_t>&, std::minstd_rand&) const;
    int32_t getLine(std::istream&, std::vector<int32_t>&,
                    std::minstd_rand&) const;
    // Drop entries below the given word/label count thresholds.
    void threshold(int64_t, int64_t);
    // Restrict the dictionary to the given row indices (quantization).
    void prune(std::vector<int32_t>&);
    bool isPruned() { return pruneidx_size_ >= 0; }
};
109
+
110
+ }
111
+
112
+ #endif
@@ -0,0 +1,693 @@
1
+ /**
2
+ * Copyright (c) 2016-present, Facebook, Inc.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the BSD-style license found in the
6
+ * LICENSE file in the root directory of this source tree. An additional grant
7
+ * of patent rights can be found in the PATENTS file in the same directory.
8
+ */
9
+
10
#include "fasttext.h"

#include <math.h>

#include <algorithm>
#include <ctime>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <numeric>
#include <queue>
#include <random>
#include <set>
#include <sstream>
#include <string>
#include <thread>
#include <utility>
#include <vector>
22
+
23
+
24
+ namespace fasttext {
25
+
26
// A freshly constructed model is unquantized until loadModel()/quantize()
// says otherwise.
FastText::FastText() : quant_(false) {}
27
+
28
+ void FastText::getVector(Vector& vec, const std::string& word) const {
29
+ const std::vector<int32_t>& ngrams = dict_->getSubwords(word);
30
+ vec.zero();
31
+ for (auto it = ngrams.begin(); it != ngrams.end(); ++it) {
32
+ if (quant_) {
33
+ vec.addRow(*qinput_, *it);
34
+ } else {
35
+ vec.addRow(*input_, *it);
36
+ }
37
+ }
38
+ if (ngrams.size() > 0) {
39
+ vec.mul(1.0 / ngrams.size());
40
+ }
41
+ }
42
+
43
+ void FastText::saveVectors() {
44
+ std::ofstream ofs(args_->output + ".vec");
45
+ if (!ofs.is_open()) {
46
+ std::cerr << "Error opening file for saving vectors." << std::endl;
47
+ exit(EXIT_FAILURE);
48
+ }
49
+ ofs << dict_->nwords() << " " << args_->dim << std::endl;
50
+ Vector vec(args_->dim);
51
+ for (int32_t i = 0; i < dict_->nwords(); i++) {
52
+ std::string word = dict_->getWord(i);
53
+ getVector(vec, word);
54
+ ofs << word << " " << vec << std::endl;
55
+ }
56
+ ofs.close();
57
+ }
58
+
59
+ void FastText::saveOutput() {
60
+ std::ofstream ofs(args_->output + ".output");
61
+ if (!ofs.is_open()) {
62
+ std::cerr << "Error opening file for saving vectors." << std::endl;
63
+ exit(EXIT_FAILURE);
64
+ }
65
+ if (quant_) {
66
+ std::cerr << "Option -saveOutput is not supported for quantized models."
67
+ << std::endl;
68
+ return;
69
+ }
70
+ int32_t n = (args_->model == model_name::sup) ? dict_->nlabels()
71
+ : dict_->nwords();
72
+ ofs << n << " " << args_->dim << std::endl;
73
+ Vector vec(args_->dim);
74
+ for (int32_t i = 0; i < n; i++) {
75
+ std::string word = (args_->model == model_name::sup) ? dict_->getLabel(i)
76
+ : dict_->getWord(i);
77
+ vec.zero();
78
+ vec.addRow(*output_, i);
79
+ ofs << word << " " << vec << std::endl;
80
+ }
81
+ ofs.close();
82
+ }
83
+
84
// Validate the binary model signature: a magic int32 followed by a format
// version. Returns false on an unknown magic number or a version newer
// than this binary supports.
// NOTE(review): `version` is not declared locally, so it is presumably a
// member kept for later compatibility checks (loadModel tests
// version == 11) — confirm against fasttext.h.
bool FastText::checkModel(std::istream& in) {
  int32_t magic;
  in.read((char*)&(magic), sizeof(int32_t));
  if (magic != FASTTEXT_FILEFORMAT_MAGIC_INT32) {
    return false;
  }
  in.read((char*)&(version), sizeof(int32_t));
  if (version > FASTTEXT_VERSION) {
    return false;
  }
  return true;
}
96
+
97
+ void FastText::signModel(std::ostream& out) {
98
+ const int32_t magic = FASTTEXT_FILEFORMAT_MAGIC_INT32;
99
+ const int32_t version = FASTTEXT_VERSION;
100
+ out.write((char*)&(magic), sizeof(int32_t));
101
+ out.write((char*)&(version), sizeof(int32_t));
102
+ }
103
+
104
// Serialize the full model to "<output>.bin" (or ".ftz" when quantized):
// signature, args, dictionary, then the input and output matrices, each
// preceded by a bool flag saying whether the quantized variant follows.
// The write order here must mirror the read order in loadModel(std::istream&).
void FastText::saveModel() {
  std::string fn(args_->output);
  if (quant_) {
    fn += ".ftz";
  } else {
    fn += ".bin";
  }
  std::ofstream ofs(fn, std::ofstream::binary);
  if (!ofs.is_open()) {
    std::cerr << "Model file cannot be opened for saving!" << std::endl;
    exit(EXIT_FAILURE);
  }
  signModel(ofs);
  args_->save(ofs);
  dict_->save(ofs);

  // Input matrix: quantized or plain, announced by a leading bool.
  ofs.write((char*)&(quant_), sizeof(bool));
  if (quant_) {
    qinput_->save(ofs);
  } else {
    input_->save(ofs);
  }

  // Output matrix: quantized only when both quant_ and -qout are set.
  ofs.write((char*)&(args_->qout), sizeof(bool));
  if (quant_ && args_->qout) {
    qoutput_->save(ofs);
  } else {
    output_->save(ofs);
  }

  ofs.close();
}
136
+
137
+ void FastText::loadModel(const std::string& filename) {
138
+ std::ifstream ifs(filename, std::ifstream::binary);
139
+ if (!ifs.is_open()) {
140
+ std::cerr << "Model file cannot be opened for loading!" << std::endl;
141
+ exit(EXIT_FAILURE);
142
+ }
143
+ if (!checkModel(ifs)) {
144
+ std::cerr << "Model file has wrong file format!" << std::endl;
145
+ exit(EXIT_FAILURE);
146
+ }
147
+ loadModel(ifs);
148
+ ifs.close();
149
+ }
150
+
151
// Deserialize a model (args, dictionary, matrices) from `in`. The stream
// must already be positioned past the signature (see checkModel, which
// also records the file's `version`). Read order mirrors saveModel.
void FastText::loadModel(std::istream& in) {
  args_ = std::make_shared<Args>();
  dict_ = std::make_shared<Dictionary>(args_);
  input_ = std::make_shared<Matrix>();
  output_ = std::make_shared<Matrix>();
  qinput_ = std::make_shared<QMatrix>();
  qoutput_ = std::make_shared<QMatrix>();
  args_->load(in);
  if (version == 11 && args_->model == model_name::sup) {
    // backward compatibility: old supervised models do not use char ngrams.
    args_->maxn = 0;
  }
  dict_->load(in);

  // A leading bool selects between a quantized and a plain input matrix.
  bool quant_input;
  in.read((char*) &quant_input, sizeof(bool));
  if (quant_input) {
    quant_ = true;
    qinput_->load(in);
  } else {
    input_->load(in);
  }

  // A pruned dictionary is only produced by quantization, so a plain input
  // matrix together with a pruned dictionary means a corrupt/mismatched file.
  if (!quant_input && dict_->isPruned()) {
    std::cerr << "Invalid model file.\n"
              << "Please download the updated model from www.fasttext.cc.\n"
              << "See issue #332 on Github for more information.\n";
    exit(1);
  }

  // Output matrix: quantized only when the file was saved with -qout.
  in.read((char*) &args_->qout, sizeof(bool));
  if (quant_ && args_->qout) {
    qoutput_->load(in);
  } else {
    output_->load(in);
  }

  model_ = std::make_shared<Model>(input_, output_, args_, 0);
  model_->quant_ = quant_;
  model_->setQuantizePointer(qinput_, qoutput_, args_->qout);

  // Sampling tables are built from label counts for supervised models,
  // word counts otherwise.
  if (args_->model == model_name::sup) {
    model_->setTargetCounts(dict_->getCounts(entry_type::label));
  } else {
    model_->setTargetCounts(dict_->getCounts(entry_type::word));
  }
}
198
+
199
+ void FastText::printInfo(real progress, real loss) {
200
+ real t = real(clock() - start) / CLOCKS_PER_SEC;
201
+ real wst = real(tokenCount) / t;
202
+ real lr = args_->lr * (1.0 - progress);
203
+ int eta = int(t / progress * (1 - progress) / args_->thread);
204
+ int etah = eta / 3600;
205
+ int etam = (eta - etah * 3600) / 60;
206
+ std::cerr << std::fixed;
207
+ std::cerr << "\rProgress: " << std::setprecision(1) << 100 * progress << "%";
208
+ std::cerr << " words/sec/thread: " << std::setprecision(0) << wst;
209
+ std::cerr << " lr: " << std::setprecision(6) << lr;
210
+ std::cerr << " loss: " << std::setprecision(6) << loss;
211
+ std::cerr << " eta: " << etah << "h" << etam << "m ";
212
+ std::cerr << std::flush;
213
+ }
214
+
215
// Choose the `cutoff` input rows to keep when pruning for quantization:
// the rows with the largest L2 norm, with the EOS row forced to the front
// so it always survives. Returns the selected row indices in rank order.
// NOTE(review): std::iota requires <numeric>, which is not among this
// file's visible includes — presumably pulled in via fasttext.h; confirm.
std::vector<int32_t> FastText::selectEmbeddings(int32_t cutoff) const {
  Vector norms(input_->m_);
  input_->l2NormRow(norms);
  std::vector<int32_t> idx(input_->m_, 0);
  std::iota(idx.begin(), idx.end(), 0);
  auto eosid = dict_->getId(Dictionary::EOS);
  // Comparator: EOS sorts before everything; otherwise order by norm,
  // descending. (eosid is signed, i1/i2 unsigned — a negative eosid never
  // matches, which is the desired "EOS absent" behavior.)
  std::sort(idx.begin(), idx.end(),
      [&norms, eosid] (size_t i1, size_t i2) {
      return eosid ==i1 || (eosid != i2 && norms[i1] > norms[i2]);
      });
  idx.erase(idx.begin() + cutoff, idx.end());
  return idx;
}
228
+
229
+ void FastText::quantize(std::shared_ptr<Args> qargs) {
230
+ if (qargs->output.empty()) {
231
+ std::cerr<<"No model provided!"<<std::endl;
232
+ exit(1);
233
+ }
234
+ loadModel(qargs->output + ".bin");
235
+
236
+ args_->input = qargs->input;
237
+ args_->qout = qargs->qout;
238
+ args_->output = qargs->output;
239
+
240
+
241
+ if (qargs->cutoff > 0 && qargs->cutoff < input_->m_) {
242
+ auto idx = selectEmbeddings(qargs->cutoff);
243
+ dict_->prune(idx);
244
+ std::shared_ptr<Matrix> ninput =
245
+ std::make_shared<Matrix> (idx.size(), args_->dim);
246
+ for (auto i = 0; i < idx.size(); i++) {
247
+ for (auto j = 0; j < args_->dim; j++) {
248
+ ninput->at(i,j) = input_->at(idx[i], j);
249
+ }
250
+ }
251
+ input_ = ninput;
252
+ if (qargs->retrain) {
253
+ args_->epoch = qargs->epoch;
254
+ args_->lr = qargs->lr;
255
+ args_->thread = qargs->thread;
256
+ args_->verbose = qargs->verbose;
257
+ start = clock();
258
+ tokenCount = 0;
259
+ start = clock();
260
+ std::vector<std::thread> threads;
261
+ for (int32_t i = 0; i < args_->thread; i++) {
262
+ threads.push_back(std::thread([=]() { trainThread(i); }));
263
+ }
264
+ for (auto it = threads.begin(); it != threads.end(); ++it) {
265
+ it->join();
266
+ }
267
+ }
268
+ }
269
+
270
+ qinput_ = std::make_shared<QMatrix>(*input_, qargs->dsub, qargs->qnorm);
271
+
272
+ if (args_->qout) {
273
+ qoutput_ = std::make_shared<QMatrix>(*output_, 2, qargs->qnorm);
274
+ }
275
+
276
+ quant_ = true;
277
+ saveModel();
278
+ }
279
+
280
+ void FastText::supervised(Model& model, real lr,
281
+ const std::vector<int32_t>& line,
282
+ const std::vector<int32_t>& labels) {
283
+ if (labels.size() == 0 || line.size() == 0) return;
284
+ std::uniform_int_distribution<> uniform(0, labels.size() - 1);
285
+ int32_t i = uniform(model.rng);
286
+ model.update(line, labels[i], lr);
287
+ }
288
+
289
+ void FastText::cbow(Model& model, real lr,
290
+ const std::vector<int32_t>& line) {
291
+ std::vector<int32_t> bow;
292
+ std::uniform_int_distribution<> uniform(1, args_->ws);
293
+ for (int32_t w = 0; w < line.size(); w++) {
294
+ int32_t boundary = uniform(model.rng);
295
+ bow.clear();
296
+ for (int32_t c = -boundary; c <= boundary; c++) {
297
+ if (c != 0 && w + c >= 0 && w + c < line.size()) {
298
+ const std::vector<int32_t>& ngrams = dict_->getSubwords(line[w + c]);
299
+ bow.insert(bow.end(), ngrams.cbegin(), ngrams.cend());
300
+ }
301
+ }
302
+ model.update(bow, line[w], lr);
303
+ }
304
+ }
305
+
306
+ void FastText::skipgram(Model& model, real lr,
307
+ const std::vector<int32_t>& line) {
308
+ std::uniform_int_distribution<> uniform(1, args_->ws);
309
+ for (int32_t w = 0; w < line.size(); w++) {
310
+ int32_t boundary = uniform(model.rng);
311
+ const std::vector<int32_t>& ngrams = dict_->getSubwords(line[w]);
312
+ for (int32_t c = -boundary; c <= boundary; c++) {
313
+ if (c != 0 && w + c >= 0 && w + c < line.size()) {
314
+ model.update(ngrams, line[w + c], lr);
315
+ }
316
+ }
317
+ }
318
+ }
319
+
320
// Evaluate the supervised model on labeled lines from `in` and print
// precision@k and recall@k. `precision` counts predicted labels that
// appear in the example's gold labels; P@k divides by k predictions per
// example, R@k by the total number of gold labels.
// NOTE(review): when no line has both words and labels, nexamples and
// nlabels stay 0 and both ratios divide by zero (prints nan/inf) —
// consider guarding upstream.
void FastText::test(std::istream& in, int32_t k) {
  int32_t nexamples = 0, nlabels = 0;
  double precision = 0.0;
  std::vector<int32_t> line, labels;

  while (in.peek() != EOF) {
    dict_->getLine(in, line, labels, model_->rng);
    // Skip lines that lack either tokens or labels.
    if (labels.size() > 0 && line.size() > 0) {
      std::vector<std::pair<real, int32_t>> modelPredictions;
      model_->predict(line, k, modelPredictions);
      for (auto it = modelPredictions.cbegin(); it != modelPredictions.cend(); it++) {
        if (std::find(labels.begin(), labels.end(), it->second) != labels.end()) {
          precision += 1.0;
        }
      }
      nexamples++;
      nlabels += labels.size();
    }
  }
  std::cout << "N" << "\t" << nexamples << std::endl;
  std::cout << std::setprecision(3);
  std::cout << "P@" << k << "\t" << precision / (k * nexamples) << std::endl;
  std::cout << "R@" << k << "\t" << precision / nlabels << std::endl;
  std::cerr << "Number of examples: " << nexamples << std::endl;
}
345
+
346
+ void FastText::predict(std::istream& in, int32_t k,
347
+ std::vector<std::pair<real,std::string>>& predictions) const {
348
+ std::vector<int32_t> words, labels;
349
+ predictions.clear();
350
+ dict_->getLine(in, words, labels, model_->rng);
351
+ predictions.clear();
352
+ if (words.empty()) return;
353
+ Vector hidden(args_->dim);
354
+ Vector output(dict_->nlabels());
355
+ std::vector<std::pair<real,int32_t>> modelPredictions;
356
+ model_->predict(words, k, modelPredictions, hidden, output);
357
+ for (auto it = modelPredictions.cbegin(); it != modelPredictions.cend(); it++) {
358
+ predictions.push_back(std::make_pair(it->first, dict_->getLabel(it->second)));
359
+ }
360
+ }
361
+
362
+ void FastText::predict(std::istream& in, int32_t k, bool print_prob) {
363
+ std::vector<std::pair<real,std::string>> predictions;
364
+ while (in.peek() != EOF) {
365
+ predictions.clear();
366
+ predict(in, k, predictions);
367
+ if (predictions.empty()) {
368
+ std::cout << std::endl;
369
+ continue;
370
+ }
371
+ for (auto it = predictions.cbegin(); it != predictions.cend(); it++) {
372
+ if (it != predictions.cbegin()) {
373
+ std::cout << " ";
374
+ }
375
+ std::cout << it->second;
376
+ if (print_prob) {
377
+ std::cout << " " << exp(it->first);
378
+ }
379
+ }
380
+ std::cout << std::endl;
381
+ }
382
+ }
383
+
384
+ void FastText::wordVectors() {
385
+ std::string word;
386
+ Vector vec(args_->dim);
387
+ while (std::cin >> word) {
388
+ getVector(vec, word);
389
+ std::cout << word << " " << vec << std::endl;
390
+ }
391
+ }
392
+
393
+ void FastText::sentenceVectors() {
394
+ Vector vec(args_->dim);
395
+ std::string sentence;
396
+ Vector svec(args_->dim);
397
+ std::string word;
398
+ while (std::getline(std::cin, sentence)) {
399
+ std::istringstream iss(sentence);
400
+ svec.zero();
401
+ int32_t count = 0;
402
+ while(iss >> word) {
403
+ getVector(vec, word);
404
+ real norm = vec.norm();
405
+ if (norm > 0) {
406
+ vec.mul(1.0 / norm);
407
+ svec.addVector(vec);
408
+ count++;
409
+ }
410
+ }
411
+ if (count > 0) {
412
+ svec.mul(1.0 / count);
413
+ }
414
+ std::cout << sentence << " " << svec << std::endl;
415
+ }
416
+ }
417
+
418
// Expose the model's dictionary to callers, read-only.
std::shared_ptr<const Dictionary> FastText::getDictionary() const {
  return dict_;
}
421
+
422
+ void FastText::ngramVectors(std::string word) {
423
+ std::vector<int32_t> ngrams;
424
+ std::vector<std::string> substrings;
425
+ Vector vec(args_->dim);
426
+ dict_->getSubwords(word, ngrams, substrings);
427
+ for (int32_t i = 0; i < ngrams.size(); i++) {
428
+ vec.zero();
429
+ if (ngrams[i] >= 0) {
430
+ if (quant_) {
431
+ vec.addRow(*qinput_, ngrams[i]);
432
+ } else {
433
+ vec.addRow(*input_, ngrams[i]);
434
+ }
435
+ }
436
+ std::cout << substrings[i] << " " << vec << std::endl;
437
+ }
438
+ }
439
+
440
+ void FastText::textVectors() {
441
+ std::vector<int32_t> line, labels;
442
+ Vector vec(args_->dim);
443
+ while (std::cin.peek() != EOF) {
444
+ dict_->getLine(std::cin, line, labels, model_->rng);
445
+ vec.zero();
446
+ for (auto it = line.cbegin(); it != line.cend(); ++it) {
447
+ if (quant_) {
448
+ vec.addRow(*qinput_, *it);
449
+ } else {
450
+ vec.addRow(*input_, *it);
451
+ }
452
+ }
453
+ if (!line.empty()) {
454
+ vec.mul(1.0 / line.size());
455
+ }
456
+ std::cout << vec << std::endl;
457
+ }
458
+ }
459
+
460
// Entry point for the print-word-vectors command; delegates to wordVectors().
void FastText::printWordVectors() {
  wordVectors();
}
463
+
464
+ void FastText::printSentenceVectors() {
465
+ if (args_->model == model_name::sup) {
466
+ textVectors();
467
+ } else {
468
+ sentenceVectors();
469
+ }
470
+ }
471
+
472
+ void FastText::precomputeWordVectors(Matrix& wordVectors) {
473
+ Vector vec(args_->dim);
474
+ wordVectors.zero();
475
+ std::cerr << "Pre-computing word vectors...";
476
+ for (int32_t i = 0; i < dict_->nwords(); i++) {
477
+ std::string word = dict_->getWord(i);
478
+ getVector(vec, word);
479
+ real norm = vec.norm();
480
+ if (norm > 0) {
481
+ wordVectors.addRow(vec, i, 1.0 / norm);
482
+ }
483
+ }
484
+ std::cerr << " done." << std::endl;
485
+ }
486
+
487
// Print the k words most similar to `queryVec` (cosine similarity, given
// that `wordVectors` rows were pre-normalized by precomputeWordVectors),
// skipping any word in `banSet`.
void FastText::findNN(const Matrix& wordVectors, const Vector& queryVec,
                      int32_t k, const std::set<std::string>& banSet) {
  real queryNorm = queryVec.norm();
  // Avoid dividing by ~0 for a zero query vector.
  if (std::abs(queryNorm) < 1e-8) {
    queryNorm = 1;
  }
  // Max-heap over (similarity, word); built over the whole vocabulary.
  std::priority_queue<std::pair<real, std::string>> heap;
  Vector vec(args_->dim);  // NOTE(review): unused — candidate for removal
  for (int32_t i = 0; i < dict_->nwords(); i++) {
    std::string word = dict_->getWord(i);
    real dp = wordVectors.dotRow(queryVec, i);
    heap.push(std::make_pair(dp / queryNorm, word));
  }
  // Pop until k non-banned words have been printed (or the heap empties).
  int32_t i = 0;
  while (i < k && heap.size() > 0) {
    auto it = banSet.find(heap.top().second);
    if (it == banSet.end()) {
      std::cout << heap.top().second << " " << heap.top().first << std::endl;
      i++;
    }
    heap.pop();
  }
}
510
+
511
+ void FastText::nn(int32_t k) {
512
+ std::string queryWord;
513
+ Vector queryVec(args_->dim);
514
+ Matrix wordVectors(dict_->nwords(), args_->dim);
515
+ precomputeWordVectors(wordVectors);
516
+ std::set<std::string> banSet;
517
+ std::cout << "Query word? ";
518
+ while (std::cin >> queryWord) {
519
+ banSet.clear();
520
+ banSet.insert(queryWord);
521
+ getVector(queryVec, queryWord);
522
+ findNN(wordVectors, queryVec, k, banSet);
523
+ std::cout << "Query word? ";
524
+ }
525
+ }
526
+
527
+ void FastText::analogies(int32_t k) {
528
+ std::string word;
529
+ Vector buffer(args_->dim), query(args_->dim);
530
+ Matrix wordVectors(dict_->nwords(), args_->dim);
531
+ precomputeWordVectors(wordVectors);
532
+ std::set<std::string> banSet;
533
+ std::cout << "Query triplet (A - B + C)? ";
534
+ while (true) {
535
+ banSet.clear();
536
+ query.zero();
537
+ std::cin >> word;
538
+ banSet.insert(word);
539
+ getVector(buffer, word);
540
+ query.addVector(buffer, 1.0);
541
+ std::cin >> word;
542
+ banSet.insert(word);
543
+ getVector(buffer, word);
544
+ query.addVector(buffer, -1.0);
545
+ std::cin >> word;
546
+ banSet.insert(word);
547
+ getVector(buffer, word);
548
+ query.addVector(buffer, 1.0);
549
+
550
+ findNN(wordVectors, query, k, banSet);
551
+ std::cout << "Query triplet (A - B + C)? ";
552
+ }
553
+ }
554
+
555
// Body of one training worker. Each thread opens its own handle on the
// input file, seeks to its slice, and runs updates until the shared
// tokenCount reaches epoch * ntokens. Thread 0 also reports progress.
void FastText::trainThread(int32_t threadId) {
  std::ifstream ifs(args_->input);
  // Partition the file by byte offset; each thread starts at its slice.
  utils::seek(ifs, threadId * utils::size(ifs) / args_->thread);

  // Each thread gets its own Model (seeded with threadId) sharing the
  // global input_/output_ matrices.
  Model model(input_, output_, args_, threadId);
  if (args_->model == model_name::sup) {
    model.setTargetCounts(dict_->getCounts(entry_type::label));
  } else {
    model.setTargetCounts(dict_->getCounts(entry_type::word));
  }

  const int64_t ntokens = dict_->ntokens();
  int64_t localTokenCount = 0;
  std::vector<int32_t> line, labels;
  // tokenCount is shared across threads; progress and the decaying
  // learning rate are derived from it.
  while (tokenCount < args_->epoch * ntokens) {
    real progress = real(tokenCount) / (args_->epoch * ntokens);
    real lr = args_->lr * (1.0 - progress);
    if (args_->model == model_name::sup) {
      localTokenCount += dict_->getLine(ifs, line, labels, model.rng);
      supervised(model, lr, line, labels);
    } else if (args_->model == model_name::cbow) {
      localTokenCount += dict_->getLine(ifs, line, model.rng);
      cbow(model, lr, line);
    } else if (args_->model == model_name::sg) {
      localTokenCount += dict_->getLine(ifs, line, model.rng);
      skipgram(model, lr, line);
    }
    // Flush the local count into the shared counter only every
    // lrUpdateRate tokens to limit contention.
    if (localTokenCount > args_->lrUpdateRate) {
      tokenCount += localTokenCount;
      localTokenCount = 0;
      if (threadId == 0 && args_->verbose > 1) {
        printInfo(progress, model.getLoss());
      }
    }
  }
  if (threadId == 0 && args_->verbose > 0) {
    printInfo(1.0, model.getLoss());
    std::cerr << std::endl;
  }
  ifs.close();
}
596
+
597
// Initialize the input matrix from a pretrained .vec file (word2vec text
// format: "<n> <dim>" header, then one word + dim floats per line). Words
// from the file are added to the dictionary, the dictionary is re-built,
// and matching rows are copied into the freshly allocated input matrix.
void FastText::loadVectors(std::string filename) {
  std::ifstream in(filename);
  std::vector<std::string> words;
  std::shared_ptr<Matrix> mat; // temp. matrix for pretrained vectors
  int64_t n, dim;
  if (!in.is_open()) {
    std::cerr << "Pretrained vectors file cannot be opened!" << std::endl;
    exit(EXIT_FAILURE);
  }
  in >> n >> dim;
  if (dim != args_->dim) {
    std::cerr << "Dimension of pretrained vectors does not match -dim option"
              << std::endl;
    exit(EXIT_FAILURE);
  }
  mat = std::make_shared<Matrix>(n, dim);
  // NOTE(review): size_t loop indices vs int64_t bounds mix signedness;
  // harmless for non-negative n/dim but worth normalizing.
  for (size_t i = 0; i < n; i++) {
    std::string word;
    in >> word;
    words.push_back(word);
    dict_->add(word);
    for (size_t j = 0; j < dim; j++) {
      in >> mat->data_[i * dim + j];
    }
  }
  in.close();

  // Rebuild dictionary tables (threshold(1, 0) keeps everything), then
  // allocate the real input matrix with ngram bucket rows appended.
  dict_->threshold(1, 0);
  input_ = std::make_shared<Matrix>(dict_->nwords()+args_->bucket, args_->dim);
  input_->uniform(1.0 / args_->dim);

  // Copy each pretrained row into the slot of its (possibly re-indexed)
  // dictionary id; words that did not survive the rebuild are skipped.
  for (size_t i = 0; i < n; i++) {
    int32_t idx = dict_->getId(words[i]);
    if (idx < 0 || idx >= dict_->nwords()) continue;
    for (size_t j = 0; j < dim; j++) {
      input_->data_[idx * dim + j] = mat->data_[i * dim + j];
    }
  }
}
636
+
637
// Full training pipeline: build the dictionary from the input file,
// initialize the matrices (optionally from pretrained vectors), run the
// worker threads, then save the model and its vectors.
void FastText::train(std::shared_ptr<Args> args) {
  args_ = args;
  dict_ = std::make_shared<Dictionary>(args_);
  // Training needs to seek/re-read the input, so a pipe won't do.
  if (args_->input == "-") {
    // manage expectations
    std::cerr << "Cannot use stdin for training!" << std::endl;
    exit(EXIT_FAILURE);
  }
  std::ifstream ifs(args_->input);
  if (!ifs.is_open()) {
    std::cerr << "Input file cannot be opened!" << std::endl;
    exit(EXIT_FAILURE);
  }
  dict_->readFromFile(ifs);
  ifs.close();

  // Input matrix: word rows + ngram bucket rows, uniformly initialized,
  // unless pretrained vectors are supplied.
  if (args_->pretrainedVectors.size() != 0) {
    loadVectors(args_->pretrainedVectors);
  } else {
    input_ = std::make_shared<Matrix>(dict_->nwords()+args_->bucket, args_->dim);
    input_->uniform(1.0 / args_->dim);
  }

  // Output matrix: one row per label (supervised) or per word, zeroed.
  if (args_->model == model_name::sup) {
    output_ = std::make_shared<Matrix>(dict_->nlabels(), args_->dim);
  } else {
    output_ = std::make_shared<Matrix>(dict_->nwords(), args_->dim);
  }
  output_->zero();

  start = clock();
  tokenCount = 0;
  if (args_->thread > 1) {
    std::vector<std::thread> threads;
    for (int32_t i = 0; i < args_->thread; i++) {
      threads.push_back(std::thread([=]() { trainThread(i); }));
    }
    for (auto it = threads.begin(); it != threads.end(); ++it) {
      it->join();
    }
  } else {
    trainThread(0);
  }
  // Rebuild model_ over the trained matrices for subsequent predict/test.
  model_ = std::make_shared<Model>(input_, output_, args_, 0);

  saveModel();
  saveVectors();
  if (args_->saveOutput > 0) {
    saveOutput();
  }
}
688
+
689
// Dimensionality of the embeddings, as configured by -dim.
int FastText::getDimension() const {
  return args_->dim;
}
692
+
693
+ }