ffi-fasttext 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,112 @@
1
+ /**
2
+ * Copyright (c) 2016-present, Facebook, Inc.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the BSD-style license found in the
6
+ * LICENSE file in the root directory of this source tree. An additional grant
7
+ * of patent rights can be found in the PATENTS file in the same directory.
8
+ */
9
+
10
+ #ifndef FASTTEXT_DICTIONARY_H
11
+ #define FASTTEXT_DICTIONARY_H
12
+
13
+ #include <vector>
14
+ #include <string>
15
+ #include <istream>
16
+ #include <ostream>
17
+ #include <random>
18
+ #include <memory>
19
+ #include <unordered_map>
20
+
21
+ #include "args.h"
22
+ #include "real.h"
23
+
24
+ namespace fasttext {
25
+
26
typedef int32_t id_type;
// A dictionary entry is either a regular word or a classification label.
enum class entry_type : int8_t {word=0, label=1};
28
+
29
// One vocabulary entry: surface form, corpus frequency, word/label kind,
// and the ids of its character n-gram subwords (see computeSubwords).
struct entry {
  std::string word;
  int64_t count;
  entry_type type;
  std::vector<int32_t> subwords;
};
35
+
36
// Vocabulary of a fastText model: maps words/labels to dense ids, keeps
// counts, computes character n-gram subwords, and supports subsampling,
// count thresholding and pruning.
class Dictionary {
  private:
    static const int32_t MAX_VOCAB_SIZE = 30000000;  // capacity of the word2int_ table
    static const int32_t MAX_LINE_SIZE = 1024;

    // Probe word2int_ for a word's slot; the second overload takes a
    // precomputed hash of the string.
    int32_t find(const std::string&) const;
    int32_t find(const std::string&, uint32_t h) const;
    void initTableDiscard();  // presumably fills pdiscard_ — confirm in .cc
    void initNgrams();        // presumably fills entry::subwords — confirm in .cc
    void reset(std::istream&) const;
    void pushHash(std::vector<int32_t>&, int32_t) const;
    void addSubwords(std::vector<int32_t>&, const std::string&, int32_t) const;

    std::shared_ptr<Args> args_;
    std::vector<int32_t> word2int_;  // hash slot -> id into words_
    std::vector<entry> words_;       // id -> entry

    std::vector<real> pdiscard_;     // per-word subsampling probabilities
    int32_t size_;                   // total entries (words + labels)
    int32_t nwords_;
    int32_t nlabels_;
    int64_t ntokens_;                // tokens seen while reading the corpus

    int64_t pruneidx_size_;          // isPruned() reports pruneidx_size_ >= 0
    std::unordered_map<int32_t, int32_t> pruneidx_;
    void addWordNgrams(
        std::vector<int32_t>& line,
        const std::vector<int32_t>& hashes,
        int32_t n) const;


  public:
    // Special tokens: end-of-sentence, and the begin/end-of-word markers
    // wrapped around words before n-gram extraction.
    static const std::string EOS;
    static const std::string BOW;
    static const std::string EOW;

    explicit Dictionary(std::shared_ptr<Args>);
    int32_t nwords() const;
    int32_t nlabels() const;
    int64_t ntokens() const;
    // Id lookup; the uint32_t overload takes a precomputed hash.
    int32_t getId(const std::string&) const;
    int32_t getId(const std::string&, uint32_t h) const;
    entry_type getType(int32_t) const;
    entry_type getType(const std::string&) const;
    bool discard(int32_t, real) const;
    std::string getWord(int32_t) const;
    const std::vector<int32_t>& getSubwords(int32_t) const;
    const std::vector<int32_t> getSubwords(const std::string&) const;
    void computeSubwords(const std::string&, std::vector<int32_t>&) const;
    void computeSubwords(
        const std::string&,
        std::vector<int32_t>&,
        std::vector<std::string>&) const;
    // Variant that also returns the n-gram strings themselves.
    void getSubwords(
        const std::string&,
        std::vector<int32_t>&,
        std::vector<std::string>&) const;
    uint32_t hash(const std::string& str) const;
    void add(const std::string&);
    bool readWord(std::istream&, std::string&) const;
    void readFromFile(std::istream&);
    std::string getLabel(int32_t) const;
    void save(std::ostream&) const;
    void load(std::istream&);
    std::vector<int64_t> getCounts(entry_type) const;
    // Tokenize one line into word ids (and label ids in the first
    // overload); returns the number of tokens consumed.
    int32_t getLine(std::istream&, std::vector<int32_t>&,
                    std::vector<int32_t>&, std::minstd_rand&) const;
    int32_t getLine(std::istream&, std::vector<int32_t>&,
                    std::minstd_rand&) const;
    void threshold(int64_t, int64_t);
    void prune(std::vector<int32_t>&);
    // FastText::loadModel relies on this to detect pruned dictionaries.
    bool isPruned() { return pruneidx_size_ >= 0; }
};
109
+
110
+ }
111
+
112
+ #endif
@@ -0,0 +1,693 @@
1
+ /**
2
+ * Copyright (c) 2016-present, Facebook, Inc.
3
+ * All rights reserved.
4
+ *
5
+ * This source code is licensed under the BSD-style license found in the
6
+ * LICENSE file in the root directory of this source tree. An additional grant
7
+ * of patent rights can be found in the PATENTS file in the same directory.
8
+ */
9
+
10
+ #include "fasttext.h"
11
+
12
+ #include <math.h>
13
+
14
+ #include <iostream>
15
+ #include <sstream>
16
+ #include <iomanip>
17
+ #include <thread>
18
+ #include <string>
19
+ #include <vector>
20
+ #include <queue>
21
+ #include <algorithm>
22
+
23
+
24
+ namespace fasttext {
25
+
26
// A freshly constructed model starts un-quantized.
FastText::FastText() : quant_(false) {}
27
+
28
+ void FastText::getVector(Vector& vec, const std::string& word) const {
29
+ const std::vector<int32_t>& ngrams = dict_->getSubwords(word);
30
+ vec.zero();
31
+ for (auto it = ngrams.begin(); it != ngrams.end(); ++it) {
32
+ if (quant_) {
33
+ vec.addRow(*qinput_, *it);
34
+ } else {
35
+ vec.addRow(*input_, *it);
36
+ }
37
+ }
38
+ if (ngrams.size() > 0) {
39
+ vec.mul(1.0 / ngrams.size());
40
+ }
41
+ }
42
+
43
+ void FastText::saveVectors() {
44
+ std::ofstream ofs(args_->output + ".vec");
45
+ if (!ofs.is_open()) {
46
+ std::cerr << "Error opening file for saving vectors." << std::endl;
47
+ exit(EXIT_FAILURE);
48
+ }
49
+ ofs << dict_->nwords() << " " << args_->dim << std::endl;
50
+ Vector vec(args_->dim);
51
+ for (int32_t i = 0; i < dict_->nwords(); i++) {
52
+ std::string word = dict_->getWord(i);
53
+ getVector(vec, word);
54
+ ofs << word << " " << vec << std::endl;
55
+ }
56
+ ofs.close();
57
+ }
58
+
59
+ void FastText::saveOutput() {
60
+ std::ofstream ofs(args_->output + ".output");
61
+ if (!ofs.is_open()) {
62
+ std::cerr << "Error opening file for saving vectors." << std::endl;
63
+ exit(EXIT_FAILURE);
64
+ }
65
+ if (quant_) {
66
+ std::cerr << "Option -saveOutput is not supported for quantized models."
67
+ << std::endl;
68
+ return;
69
+ }
70
+ int32_t n = (args_->model == model_name::sup) ? dict_->nlabels()
71
+ : dict_->nwords();
72
+ ofs << n << " " << args_->dim << std::endl;
73
+ Vector vec(args_->dim);
74
+ for (int32_t i = 0; i < n; i++) {
75
+ std::string word = (args_->model == model_name::sup) ? dict_->getLabel(i)
76
+ : dict_->getWord(i);
77
+ vec.zero();
78
+ vec.addRow(*output_, i);
79
+ ofs << word << " " << vec << std::endl;
80
+ }
81
+ ofs.close();
82
+ }
83
+
84
// Reads and validates the binary model header: a magic number that
// identifies the fastText file format, then the format version (stored
// in the `version` member for later compatibility checks). Returns false
// for a foreign file or a version newer than this build understands.
bool FastText::checkModel(std::istream& in) {
  int32_t magic;
  in.read((char*)&(magic), sizeof(int32_t));
  if (magic != FASTTEXT_FILEFORMAT_MAGIC_INT32) {
    return false;
  }
  in.read((char*)&(version), sizeof(int32_t));
  if (version > FASTTEXT_VERSION) {
    return false;
  }
  return true;
}
96
+
97
+ void FastText::signModel(std::ostream& out) {
98
+ const int32_t magic = FASTTEXT_FILEFORMAT_MAGIC_INT32;
99
+ const int32_t version = FASTTEXT_VERSION;
100
+ out.write((char*)&(magic), sizeof(int32_t));
101
+ out.write((char*)&(version), sizeof(int32_t));
102
+ }
103
+
104
// Serializes the full model to <output>.bin (.ftz when quantized).
// On-disk order (mirrored by loadModel(std::istream&)): header, args,
// dictionary, input matrix, output matrix — each matrix preceded by a
// bool flag saying whether it is stored quantized.
void FastText::saveModel() {
  std::string fn(args_->output);
  if (quant_) {
    fn += ".ftz";
  } else {
    fn += ".bin";
  }
  std::ofstream ofs(fn, std::ofstream::binary);
  if (!ofs.is_open()) {
    std::cerr << "Model file cannot be opened for saving!" << std::endl;
    exit(EXIT_FAILURE);
  }
  signModel(ofs);
  args_->save(ofs);
  dict_->save(ofs);

  // Input matrix: quantized form iff quant_ is set.
  ofs.write((char*)&(quant_), sizeof(bool));
  if (quant_) {
    qinput_->save(ofs);
  } else {
    input_->save(ofs);
  }

  // Output matrix: quantized only when both quant_ and -qout were used.
  ofs.write((char*)&(args_->qout), sizeof(bool));
  if (quant_ && args_->qout) {
    qoutput_->save(ofs);
  } else {
    output_->save(ofs);
  }

  ofs.close();
}
136
+
137
+ void FastText::loadModel(const std::string& filename) {
138
+ std::ifstream ifs(filename, std::ifstream::binary);
139
+ if (!ifs.is_open()) {
140
+ std::cerr << "Model file cannot be opened for loading!" << std::endl;
141
+ exit(EXIT_FAILURE);
142
+ }
143
+ if (!checkModel(ifs)) {
144
+ std::cerr << "Model file has wrong file format!" << std::endl;
145
+ exit(EXIT_FAILURE);
146
+ }
147
+ loadModel(ifs);
148
+ ifs.close();
149
+ }
150
+
151
// Deserializes a model from `in` in exactly the order saveModel() wrote
// it: args, dictionary, input matrix (bool flag + plain/quantized data),
// output matrix likewise, then builds the in-memory Model. Assumes
// checkModel() already consumed the header and set `version`.
void FastText::loadModel(std::istream& in) {
  args_ = std::make_shared<Args>();
  dict_ = std::make_shared<Dictionary>(args_);
  input_ = std::make_shared<Matrix>();
  output_ = std::make_shared<Matrix>();
  qinput_ = std::make_shared<QMatrix>();
  qoutput_ = std::make_shared<QMatrix>();
  args_->load(in);
  if (version == 11 && args_->model == model_name::sup) {
    // backward compatibility: old supervised models do not use char ngrams.
    args_->maxn = 0;
  }
  dict_->load(in);

  bool quant_input;
  in.read((char*) &quant_input, sizeof(bool));
  if (quant_input) {
    quant_ = true;
    qinput_->load(in);
  } else {
    input_->load(in);
  }

  // A pruned dictionary is only ever produced alongside a quantized
  // input matrix; any other combination is a broken/stripped file.
  if (!quant_input && dict_->isPruned()) {
    std::cerr << "Invalid model file.\n"
              << "Please download the updated model from www.fasttext.cc.\n"
              << "See issue #332 on Github for more information.\n";
    exit(1);
  }

  in.read((char*) &args_->qout, sizeof(bool));
  if (quant_ && args_->qout) {
    qoutput_->load(in);
  } else {
    output_->load(in);
  }

  model_ = std::make_shared<Model>(input_, output_, args_, 0);
  model_->quant_ = quant_;
  model_->setQuantizePointer(qinput_, qoutput_, args_->qout);

  // Target distribution: labels for supervised models, words otherwise.
  if (args_->model == model_name::sup) {
    model_->setTargetCounts(dict_->getCounts(entry_type::label));
  } else {
    model_->setTargetCounts(dict_->getCounts(entry_type::word));
  }
}
198
+
199
// Prints a one-line training progress report to stderr (overwriting the
// current line via '\r'): percent done, words/sec/thread, current lr,
// loss and ETA.
// NOTE(review): progress == 0 makes the ETA division undefined, and
// clock() measures process CPU time (summed over threads on most
// platforms) — callers invoke this only after some work; confirm.
void FastText::printInfo(real progress, real loss) {
  real t = real(clock() - start) / CLOCKS_PER_SEC;
  real wst = real(tokenCount) / t;
  real lr = args_->lr * (1.0 - progress);  // same linear decay as trainThread
  int eta = int(t / progress * (1 - progress) / args_->thread);
  int etah = eta / 3600;
  int etam = (eta - etah * 3600) / 60;
  std::cerr << std::fixed;
  std::cerr << "\rProgress: " << std::setprecision(1) << 100 * progress << "%";
  std::cerr << " words/sec/thread: " << std::setprecision(0) << wst;
  std::cerr << " lr: " << std::setprecision(6) << lr;
  std::cerr << " loss: " << std::setprecision(6) << loss;
  std::cerr << " eta: " << etah << "h" << etam << "m ";
  std::cerr << std::flush;
}
214
+
215
// Picks the `cutoff` input-matrix rows to keep when pruning: rows are
// ranked by descending L2 norm, except the EOS row which is forced to
// rank first so it always survives the cutoff.
std::vector<int32_t> FastText::selectEmbeddings(int32_t cutoff) const {
  Vector norms(input_->m_);
  input_->l2NormRow(norms);
  std::vector<int32_t> idx(input_->m_, 0);
  std::iota(idx.begin(), idx.end(), 0);
  auto eosid = dict_->getId(Dictionary::EOS);
  std::sort(idx.begin(), idx.end(),
            [&norms, eosid] (size_t i1, size_t i2) {
              // EOS compares greater than every other row.
              return eosid == i1 || (eosid != i2 && norms[i1] > norms[i2]);
            });
  // Keep only the top `cutoff` ranked ids.
  idx.erase(idx.begin() + cutoff, idx.end());
  return idx;
}
228
+
229
+ void FastText::quantize(std::shared_ptr<Args> qargs) {
230
+ if (qargs->output.empty()) {
231
+ std::cerr<<"No model provided!"<<std::endl;
232
+ exit(1);
233
+ }
234
+ loadModel(qargs->output + ".bin");
235
+
236
+ args_->input = qargs->input;
237
+ args_->qout = qargs->qout;
238
+ args_->output = qargs->output;
239
+
240
+
241
+ if (qargs->cutoff > 0 && qargs->cutoff < input_->m_) {
242
+ auto idx = selectEmbeddings(qargs->cutoff);
243
+ dict_->prune(idx);
244
+ std::shared_ptr<Matrix> ninput =
245
+ std::make_shared<Matrix> (idx.size(), args_->dim);
246
+ for (auto i = 0; i < idx.size(); i++) {
247
+ for (auto j = 0; j < args_->dim; j++) {
248
+ ninput->at(i,j) = input_->at(idx[i], j);
249
+ }
250
+ }
251
+ input_ = ninput;
252
+ if (qargs->retrain) {
253
+ args_->epoch = qargs->epoch;
254
+ args_->lr = qargs->lr;
255
+ args_->thread = qargs->thread;
256
+ args_->verbose = qargs->verbose;
257
+ start = clock();
258
+ tokenCount = 0;
259
+ start = clock();
260
+ std::vector<std::thread> threads;
261
+ for (int32_t i = 0; i < args_->thread; i++) {
262
+ threads.push_back(std::thread([=]() { trainThread(i); }));
263
+ }
264
+ for (auto it = threads.begin(); it != threads.end(); ++it) {
265
+ it->join();
266
+ }
267
+ }
268
+ }
269
+
270
+ qinput_ = std::make_shared<QMatrix>(*input_, qargs->dsub, qargs->qnorm);
271
+
272
+ if (args_->qout) {
273
+ qoutput_ = std::make_shared<QMatrix>(*output_, 2, qargs->qnorm);
274
+ }
275
+
276
+ quant_ = true;
277
+ saveModel();
278
+ }
279
+
280
+ void FastText::supervised(Model& model, real lr,
281
+ const std::vector<int32_t>& line,
282
+ const std::vector<int32_t>& labels) {
283
+ if (labels.size() == 0 || line.size() == 0) return;
284
+ std::uniform_int_distribution<> uniform(0, labels.size() - 1);
285
+ int32_t i = uniform(model.rng);
286
+ model.update(line, labels[i], lr);
287
+ }
288
+
289
// CBOW step over one tokenized line: for each position, the bag of
// subwords gathered from a random-width context window predicts the
// center word.
void FastText::cbow(Model& model, real lr,
                    const std::vector<int32_t>& line) {
  std::vector<int32_t> bow;
  std::uniform_int_distribution<> uniform(1, args_->ws);
  for (int32_t w = 0; w < line.size(); w++) {
    // Window half-width is resampled per center word in [1, ws].
    int32_t boundary = uniform(model.rng);
    bow.clear();
    for (int32_t c = -boundary; c <= boundary; c++) {
      if (c != 0 && w + c >= 0 && w + c < line.size()) {
        const std::vector<int32_t>& ngrams = dict_->getSubwords(line[w + c]);
        bow.insert(bow.end(), ngrams.cbegin(), ngrams.cend());
      }
    }
    model.update(bow, line[w], lr);
  }
}
305
+
306
+ void FastText::skipgram(Model& model, real lr,
307
+ const std::vector<int32_t>& line) {
308
+ std::uniform_int_distribution<> uniform(1, args_->ws);
309
+ for (int32_t w = 0; w < line.size(); w++) {
310
+ int32_t boundary = uniform(model.rng);
311
+ const std::vector<int32_t>& ngrams = dict_->getSubwords(line[w]);
312
+ for (int32_t c = -boundary; c <= boundary; c++) {
313
+ if (c != 0 && w + c >= 0 && w + c < line.size()) {
314
+ model.update(ngrams, line[w + c], lr);
315
+ }
316
+ }
317
+ }
318
+ }
319
+
320
// Evaluates the classifier on labelled lines from `in` and prints N,
// precision@k and recall@k (micro-averaged over all examples).
void FastText::test(std::istream& in, int32_t k) {
  int32_t nexamples = 0, nlabels = 0;
  double precision = 0.0;  // running count of correct top-k predictions
  std::vector<int32_t> line, labels;

  while (in.peek() != EOF) {
    dict_->getLine(in, line, labels, model_->rng);
    // Lines lacking either tokens or labels are skipped entirely.
    if (labels.size() > 0 && line.size() > 0) {
      std::vector<std::pair<real, int32_t>> modelPredictions;
      model_->predict(line, k, modelPredictions);
      // Count how many of the top-k predictions are true labels.
      for (auto it = modelPredictions.cbegin(); it != modelPredictions.cend(); it++) {
        if (std::find(labels.begin(), labels.end(), it->second) != labels.end()) {
          precision += 1.0;
        }
      }
      nexamples++;
      nlabels += labels.size();
    }
  }
  std::cout << "N" << "\t" << nexamples << std::endl;
  std::cout << std::setprecision(3);
  std::cout << "P@" << k << "\t" << precision / (k * nexamples) << std::endl;
  std::cout << "R@" << k << "\t" << precision / nlabels << std::endl;
  std::cerr << "Number of examples: " << nexamples << std::endl;
}
345
+
346
+ void FastText::predict(std::istream& in, int32_t k,
347
+ std::vector<std::pair<real,std::string>>& predictions) const {
348
+ std::vector<int32_t> words, labels;
349
+ predictions.clear();
350
+ dict_->getLine(in, words, labels, model_->rng);
351
+ predictions.clear();
352
+ if (words.empty()) return;
353
+ Vector hidden(args_->dim);
354
+ Vector output(dict_->nlabels());
355
+ std::vector<std::pair<real,int32_t>> modelPredictions;
356
+ model_->predict(words, k, modelPredictions, hidden, output);
357
+ for (auto it = modelPredictions.cbegin(); it != modelPredictions.cend(); it++) {
358
+ predictions.push_back(std::make_pair(it->first, dict_->getLabel(it->second)));
359
+ }
360
+ }
361
+
362
+ void FastText::predict(std::istream& in, int32_t k, bool print_prob) {
363
+ std::vector<std::pair<real,std::string>> predictions;
364
+ while (in.peek() != EOF) {
365
+ predictions.clear();
366
+ predict(in, k, predictions);
367
+ if (predictions.empty()) {
368
+ std::cout << std::endl;
369
+ continue;
370
+ }
371
+ for (auto it = predictions.cbegin(); it != predictions.cend(); it++) {
372
+ if (it != predictions.cbegin()) {
373
+ std::cout << " ";
374
+ }
375
+ std::cout << it->second;
376
+ if (print_prob) {
377
+ std::cout << " " << exp(it->first);
378
+ }
379
+ }
380
+ std::cout << std::endl;
381
+ }
382
+ }
383
+
384
+ void FastText::wordVectors() {
385
+ std::string word;
386
+ Vector vec(args_->dim);
387
+ while (std::cin >> word) {
388
+ getVector(vec, word);
389
+ std::cout << word << " " << vec << std::endl;
390
+ }
391
+ }
392
+
393
// Reads one sentence per stdin line and prints the sentence followed by
// the average of its words' L2-normalized vectors (words whose vector
// has zero norm are skipped and do not count toward the average).
void FastText::sentenceVectors() {
  Vector vec(args_->dim);
  std::string sentence;
  Vector svec(args_->dim);
  std::string word;
  while (std::getline(std::cin, sentence)) {
    std::istringstream iss(sentence);
    svec.zero();
    int32_t count = 0;  // words that actually contributed
    while(iss >> word) {
      getVector(vec, word);
      real norm = vec.norm();
      if (norm > 0) {
        vec.mul(1.0 / norm);
        svec.addVector(vec);
        count++;
      }
    }
    if (count > 0) {
      svec.mul(1.0 / count);
    }
    std::cout << sentence << " " << svec << std::endl;
  }
}
417
+
418
// Exposes the model's dictionary to callers, read-only.
std::shared_ptr<const Dictionary> FastText::getDictionary() const {
  return dict_;
}
421
+
422
+ void FastText::ngramVectors(std::string word) {
423
+ std::vector<int32_t> ngrams;
424
+ std::vector<std::string> substrings;
425
+ Vector vec(args_->dim);
426
+ dict_->getSubwords(word, ngrams, substrings);
427
+ for (int32_t i = 0; i < ngrams.size(); i++) {
428
+ vec.zero();
429
+ if (ngrams[i] >= 0) {
430
+ if (quant_) {
431
+ vec.addRow(*qinput_, ngrams[i]);
432
+ } else {
433
+ vec.addRow(*input_, ngrams[i]);
434
+ }
435
+ }
436
+ std::cout << substrings[i] << " " << vec << std::endl;
437
+ }
438
+ }
439
+
440
+ void FastText::textVectors() {
441
+ std::vector<int32_t> line, labels;
442
+ Vector vec(args_->dim);
443
+ while (std::cin.peek() != EOF) {
444
+ dict_->getLine(std::cin, line, labels, model_->rng);
445
+ vec.zero();
446
+ for (auto it = line.cbegin(); it != line.cend(); ++it) {
447
+ if (quant_) {
448
+ vec.addRow(*qinput_, *it);
449
+ } else {
450
+ vec.addRow(*input_, *it);
451
+ }
452
+ }
453
+ if (!line.empty()) {
454
+ vec.mul(1.0 / line.size());
455
+ }
456
+ std::cout << vec << std::endl;
457
+ }
458
+ }
459
+
460
// Public entry point: print word vectors for words read from stdin.
void FastText::printWordVectors() {
  wordVectors();
}
463
+
464
+ void FastText::printSentenceVectors() {
465
+ if (args_->model == model_name::sup) {
466
+ textVectors();
467
+ } else {
468
+ sentenceVectors();
469
+ }
470
+ }
471
+
472
+ void FastText::precomputeWordVectors(Matrix& wordVectors) {
473
+ Vector vec(args_->dim);
474
+ wordVectors.zero();
475
+ std::cerr << "Pre-computing word vectors...";
476
+ for (int32_t i = 0; i < dict_->nwords(); i++) {
477
+ std::string word = dict_->getWord(i);
478
+ getVector(vec, word);
479
+ real norm = vec.norm();
480
+ if (norm > 0) {
481
+ wordVectors.addRow(vec, i, 1.0 / norm);
482
+ }
483
+ }
484
+ std::cerr << " done." << std::endl;
485
+ }
486
+
487
+ void FastText::findNN(const Matrix& wordVectors, const Vector& queryVec,
488
+ int32_t k, const std::set<std::string>& banSet) {
489
+ real queryNorm = queryVec.norm();
490
+ if (std::abs(queryNorm) < 1e-8) {
491
+ queryNorm = 1;
492
+ }
493
+ std::priority_queue<std::pair<real, std::string>> heap;
494
+ Vector vec(args_->dim);
495
+ for (int32_t i = 0; i < dict_->nwords(); i++) {
496
+ std::string word = dict_->getWord(i);
497
+ real dp = wordVectors.dotRow(queryVec, i);
498
+ heap.push(std::make_pair(dp / queryNorm, word));
499
+ }
500
+ int32_t i = 0;
501
+ while (i < k && heap.size() > 0) {
502
+ auto it = banSet.find(heap.top().second);
503
+ if (it == banSet.end()) {
504
+ std::cout << heap.top().second << " " << heap.top().first << std::endl;
505
+ i++;
506
+ }
507
+ heap.pop();
508
+ }
509
+ }
510
+
511
+ void FastText::nn(int32_t k) {
512
+ std::string queryWord;
513
+ Vector queryVec(args_->dim);
514
+ Matrix wordVectors(dict_->nwords(), args_->dim);
515
+ precomputeWordVectors(wordVectors);
516
+ std::set<std::string> banSet;
517
+ std::cout << "Query word? ";
518
+ while (std::cin >> queryWord) {
519
+ banSet.clear();
520
+ banSet.insert(queryWord);
521
+ getVector(queryVec, queryWord);
522
+ findNN(wordVectors, queryVec, k, banSet);
523
+ std::cout << "Query word? ";
524
+ }
525
+ }
526
+
527
+ void FastText::analogies(int32_t k) {
528
+ std::string word;
529
+ Vector buffer(args_->dim), query(args_->dim);
530
+ Matrix wordVectors(dict_->nwords(), args_->dim);
531
+ precomputeWordVectors(wordVectors);
532
+ std::set<std::string> banSet;
533
+ std::cout << "Query triplet (A - B + C)? ";
534
+ while (true) {
535
+ banSet.clear();
536
+ query.zero();
537
+ std::cin >> word;
538
+ banSet.insert(word);
539
+ getVector(buffer, word);
540
+ query.addVector(buffer, 1.0);
541
+ std::cin >> word;
542
+ banSet.insert(word);
543
+ getVector(buffer, word);
544
+ query.addVector(buffer, -1.0);
545
+ std::cin >> word;
546
+ banSet.insert(word);
547
+ getVector(buffer, word);
548
+ query.addVector(buffer, 1.0);
549
+
550
+ findNN(wordVectors, query, k, banSet);
551
+ std::cout << "Query triplet (A - B + C)? ";
552
+ }
553
+ }
554
+
555
// Worker body for one training thread: seeks to this thread's slice of
// the input file and applies SGD updates until the global token budget
// (epoch * ntokens, tracked by the shared tokenCount) is exhausted.
void FastText::trainThread(int32_t threadId) {
  std::ifstream ifs(args_->input);
  // Each thread starts reading at its own offset of the training file.
  utils::seek(ifs, threadId * utils::size(ifs) / args_->thread);

  Model model(input_, output_, args_, threadId);
  if (args_->model == model_name::sup) {
    model.setTargetCounts(dict_->getCounts(entry_type::label));
  } else {
    model.setTargetCounts(dict_->getCounts(entry_type::word));
  }

  const int64_t ntokens = dict_->ntokens();
  int64_t localTokenCount = 0;  // tokens processed since last publish
  std::vector<int32_t> line, labels;
  while (tokenCount < args_->epoch * ntokens) {
    real progress = real(tokenCount) / (args_->epoch * ntokens);
    real lr = args_->lr * (1.0 - progress);  // linear learning-rate decay
    if (args_->model == model_name::sup) {
      localTokenCount += dict_->getLine(ifs, line, labels, model.rng);
      supervised(model, lr, line, labels);
    } else if (args_->model == model_name::cbow) {
      localTokenCount += dict_->getLine(ifs, line, model.rng);
      cbow(model, lr, line);
    } else if (args_->model == model_name::sg) {
      localTokenCount += dict_->getLine(ifs, line, model.rng);
      skipgram(model, lr, line);
    }
    // Publish to the shared counter only every lrUpdateRate tokens to
    // limit cross-thread traffic; thread 0 doubles as progress reporter.
    if (localTokenCount > args_->lrUpdateRate) {
      tokenCount += localTokenCount;
      localTokenCount = 0;
      if (threadId == 0 && args_->verbose > 1) {
        printInfo(progress, model.getLoss());
      }
    }
  }
  if (threadId == 0 && args_->verbose > 0) {
    printInfo(1.0, model.getLoss());
    std::cerr << std::endl;
  }
  ifs.close();
}
596
+
597
// Loads pretrained text-format vectors ("<n> <dim>" header, then one
// "<word> <v1> ... <vdim>" line per word), adds every word to the
// dictionary, and copies the pretrained rows into a freshly initialized
// input matrix. Exits if the file is missing or dims mismatch -dim.
void FastText::loadVectors(std::string filename) {
  std::ifstream in(filename);
  std::vector<std::string> words;
  std::shared_ptr<Matrix> mat; // temp. matrix for pretrained vectors
  int64_t n, dim;
  if (!in.is_open()) {
    std::cerr << "Pretrained vectors file cannot be opened!" << std::endl;
    exit(EXIT_FAILURE);
  }
  in >> n >> dim;
  if (dim != args_->dim) {
    std::cerr << "Dimension of pretrained vectors does not match -dim option"
              << std::endl;
    exit(EXIT_FAILURE);
  }
  mat = std::make_shared<Matrix>(n, dim);
  for (size_t i = 0; i < n; i++) {
    std::string word;
    in >> word;
    words.push_back(word);
    dict_->add(word);
    for (size_t j = 0; j < dim; j++) {
      in >> mat->data_[i * dim + j];
    }
  }
  in.close();

  // Finalize the vocabulary, then size the input matrix for
  // vocabulary words plus n-gram buckets and randomize it.
  dict_->threshold(1, 0);
  input_ = std::make_shared<Matrix>(dict_->nwords()+args_->bucket, args_->dim);
  input_->uniform(1.0 / args_->dim);

  // Overwrite the rows of in-vocabulary words with pretrained values;
  // words that did not survive the dictionary keep their random init.
  for (size_t i = 0; i < n; i++) {
    int32_t idx = dict_->getId(words[i]);
    if (idx < 0 || idx >= dict_->nwords()) continue;
    for (size_t j = 0; j < dim; j++) {
      input_->data_[idx * dim + j] = mat->data_[i * dim + j];
    }
  }
}
636
+
637
// Full training driver: builds the dictionary from args->input,
// initializes the input/output matrices (optionally seeding from
// pretrained vectors), runs the worker threads, then saves the model,
// the .vec file and (optionally) the output matrix.
void FastText::train(std::shared_ptr<Args> args) {
  args_ = args;
  dict_ = std::make_shared<Dictionary>(args_);
  if (args_->input == "-") {
    // manage expectations: trainThread seeks into the input file, so a
    // non-seekable stdin stream cannot be used for training.
    std::cerr << "Cannot use stdin for training!" << std::endl;
    exit(EXIT_FAILURE);
  }
  std::ifstream ifs(args_->input);
  if (!ifs.is_open()) {
    std::cerr << "Input file cannot be opened!" << std::endl;
    exit(EXIT_FAILURE);
  }
  dict_->readFromFile(ifs);
  ifs.close();

  if (args_->pretrainedVectors.size() != 0) {
    loadVectors(args_->pretrainedVectors);
  } else {
    // Rows: vocabulary words plus the n-gram hash buckets.
    input_ = std::make_shared<Matrix>(dict_->nwords()+args_->bucket, args_->dim);
    input_->uniform(1.0 / args_->dim);
  }

  // Output layer: one row per label (supervised) or per word.
  if (args_->model == model_name::sup) {
    output_ = std::make_shared<Matrix>(dict_->nlabels(), args_->dim);
  } else {
    output_ = std::make_shared<Matrix>(dict_->nwords(), args_->dim);
  }
  output_->zero();

  start = clock();
  tokenCount = 0;
  if (args_->thread > 1) {
    std::vector<std::thread> threads;
    for (int32_t i = 0; i < args_->thread; i++) {
      threads.push_back(std::thread([=]() { trainThread(i); }));
    }
    for (auto it = threads.begin(); it != threads.end(); ++it) {
      it->join();
    }
  } else {
    trainThread(0);
  }
  model_ = std::make_shared<Model>(input_, output_, args_, 0);

  saveModel();
  saveVectors();
  if (args_->saveOutput > 0) {
    saveOutput();
  }
}
688
+
689
// Dimensionality of the embeddings (the -dim argument).
int FastText::getDimension() const {
  return args_->dim;
}
692
+
693
+ }