ffi-fasttext 0.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/.gitignore +44 -0
- data/.travis.yml +5 -0
- data/Gemfile +6 -0
- data/LICENSE.txt +21 -0
- data/README.md +59 -0
- data/Rakefile +19 -0
- data/bin/console +14 -0
- data/bin/setup +8 -0
- data/ext/ffi/fasttext/Rakefile +71 -0
- data/ffi-fasttext.gemspec +40 -0
- data/lib/ffi/fasttext.rb +108 -0
- data/lib/ffi/fasttext/version.rb +5 -0
- data/vendor/fasttext/LICENSE +30 -0
- data/vendor/fasttext/PATENTS +33 -0
- data/vendor/fasttext/args.cc +250 -0
- data/vendor/fasttext/args.h +71 -0
- data/vendor/fasttext/dictionary.cc +475 -0
- data/vendor/fasttext/dictionary.h +112 -0
- data/vendor/fasttext/fasttext.cc +693 -0
- data/vendor/fasttext/fasttext.h +97 -0
- data/vendor/fasttext/ffi_fasttext.cc +66 -0
- data/vendor/fasttext/main.cc +270 -0
- data/vendor/fasttext/matrix.cc +144 -0
- data/vendor/fasttext/matrix.h +57 -0
- data/vendor/fasttext/model.cc +341 -0
- data/vendor/fasttext/model.h +110 -0
- data/vendor/fasttext/productquantizer.cc +211 -0
- data/vendor/fasttext/productquantizer.h +67 -0
- data/vendor/fasttext/qmatrix.cc +121 -0
- data/vendor/fasttext/qmatrix.h +65 -0
- data/vendor/fasttext/real.h +19 -0
- data/vendor/fasttext/utils.cc +29 -0
- data/vendor/fasttext/utils.h +25 -0
- data/vendor/fasttext/vector.cc +137 -0
- data/vendor/fasttext/vector.h +53 -0
- metadata +151 -0
@@ -0,0 +1,112 @@
|
|
1
|
+
/**
 * Copyright (c) 2016-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree. An additional grant
 * of patent rights can be found in the PATENTS file in the same directory.
 */

#ifndef FASTTEXT_DICTIONARY_H
#define FASTTEXT_DICTIONARY_H

#include <vector>
#include <string>
#include <istream>
#include <ostream>
#include <random>
#include <memory>
#include <unordered_map>

#include "args.h"
#include "real.h"

namespace fasttext {

typedef int32_t id_type;
// A vocabulary entry is either a regular word or a classification label.
enum class entry_type : int8_t {word=0, label=1};

// One vocabulary entry: the token text, its corpus frequency, its kind, and
// the ids of its character ngrams (subwords).
struct entry {
  std::string word;
  int64_t count;
  entry_type type;
  std::vector<int32_t> subwords;
};

// Vocabulary over words and labels: maps tokens to ids via a hash table,
// computes character-ngram subwords, supports frequency thresholding,
// subsampling, (de)serialization, and pruning for quantized models.
class Dictionary {
 private:
  static const int32_t MAX_VOCAB_SIZE = 30000000;
  static const int32_t MAX_LINE_SIZE = 1024;

  int32_t find(const std::string&) const;
  int32_t find(const std::string&, uint32_t h) const;
  void initTableDiscard();
  void initNgrams();
  void reset(std::istream&) const;
  void pushHash(std::vector<int32_t>&, int32_t) const;
  void addSubwords(std::vector<int32_t>&, const std::string&, int32_t) const;

  std::shared_ptr<Args> args_;
  std::vector<int32_t> word2int_;
  std::vector<entry> words_;

  std::vector<real> pdiscard_;
  int32_t size_;
  int32_t nwords_;
  int32_t nlabels_;
  int64_t ntokens_;

  // pruneidx_size_ < 0 means the dictionary has never been pruned.
  int64_t pruneidx_size_;
  std::unordered_map<int32_t, int32_t> pruneidx_;
  void addWordNgrams(
      std::vector<int32_t>& line,
      const std::vector<int32_t>& hashes,
      int32_t n) const;

 public:
  static const std::string EOS;
  static const std::string BOW;
  static const std::string EOW;

  explicit Dictionary(std::shared_ptr<Args>);
  int32_t nwords() const;
  int32_t nlabels() const;
  int64_t ntokens() const;
  int32_t getId(const std::string&) const;
  int32_t getId(const std::string&, uint32_t h) const;
  entry_type getType(int32_t) const;
  entry_type getType(const std::string&) const;
  bool discard(int32_t, real) const;
  std::string getWord(int32_t) const;
  const std::vector<int32_t>& getSubwords(int32_t) const;
  const std::vector<int32_t> getSubwords(const std::string&) const;
  void computeSubwords(const std::string&, std::vector<int32_t>&) const;
  void computeSubwords(
      const std::string&,
      std::vector<int32_t>&,
      std::vector<std::string>&) const;
  void getSubwords(
      const std::string&,
      std::vector<int32_t>&,
      std::vector<std::string>&) const;
  uint32_t hash(const std::string& str) const;
  void add(const std::string&);
  bool readWord(std::istream&, std::string&) const;
  void readFromFile(std::istream&);
  std::string getLabel(int32_t) const;
  void save(std::ostream&) const;
  void load(std::istream&);
  std::vector<int64_t> getCounts(entry_type) const;
  int32_t getLine(std::istream&, std::vector<int32_t>&,
                  std::vector<int32_t>&, std::minstd_rand&) const;
  int32_t getLine(std::istream&, std::vector<int32_t>&,
                  std::minstd_rand&) const;
  void threshold(int64_t, int64_t);
  void prune(std::vector<int32_t>&);
  // const-qualified: a pure query, callable through const dictionaries.
  bool isPruned() const { return pruneidx_size_ >= 0; }
};

}  // namespace fasttext

#endif
|
@@ -0,0 +1,693 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) 2016-present, Facebook, Inc.
|
3
|
+
* All rights reserved.
|
4
|
+
*
|
5
|
+
* This source code is licensed under the BSD-style license found in the
|
6
|
+
* LICENSE file in the root directory of this source tree. An additional grant
|
7
|
+
* of patent rights can be found in the PATENTS file in the same directory.
|
8
|
+
*/
|
9
|
+
|
10
|
+
#include "fasttext.h"
|
11
|
+
|
12
|
+
#include <math.h>
|
13
|
+
|
14
|
+
#include <iostream>
|
15
|
+
#include <sstream>
|
16
|
+
#include <iomanip>
|
17
|
+
#include <thread>
|
18
|
+
#include <string>
|
19
|
+
#include <vector>
|
20
|
+
#include <queue>
|
21
|
+
#include <algorithm>
|
22
|
+
|
23
|
+
|
24
|
+
namespace fasttext {
|
25
|
+
|
26
|
+
// Construct an untrained instance; quantization stays off until a quantized
// model is loaded or quantize() is invoked.
FastText::FastText() : quant_(false) {}
|
27
|
+
|
28
|
+
// Compute the embedding of `word` as the average of its subword (ngram)
// vectors, reading rows from the quantized input matrix when applicable.
void FastText::getVector(Vector& vec, const std::string& word) const {
  const std::vector<int32_t>& subwords = dict_->getSubwords(word);
  vec.zero();
  for (int32_t id : subwords) {
    if (quant_) {
      vec.addRow(*qinput_, id);
    } else {
      vec.addRow(*input_, id);
    }
  }
  if (!subwords.empty()) {
    vec.mul(1.0 / subwords.size());
  }
}
|
42
|
+
|
43
|
+
void FastText::saveVectors() {
|
44
|
+
std::ofstream ofs(args_->output + ".vec");
|
45
|
+
if (!ofs.is_open()) {
|
46
|
+
std::cerr << "Error opening file for saving vectors." << std::endl;
|
47
|
+
exit(EXIT_FAILURE);
|
48
|
+
}
|
49
|
+
ofs << dict_->nwords() << " " << args_->dim << std::endl;
|
50
|
+
Vector vec(args_->dim);
|
51
|
+
for (int32_t i = 0; i < dict_->nwords(); i++) {
|
52
|
+
std::string word = dict_->getWord(i);
|
53
|
+
getVector(vec, word);
|
54
|
+
ofs << word << " " << vec << std::endl;
|
55
|
+
}
|
56
|
+
ofs.close();
|
57
|
+
}
|
58
|
+
|
59
|
+
void FastText::saveOutput() {
|
60
|
+
std::ofstream ofs(args_->output + ".output");
|
61
|
+
if (!ofs.is_open()) {
|
62
|
+
std::cerr << "Error opening file for saving vectors." << std::endl;
|
63
|
+
exit(EXIT_FAILURE);
|
64
|
+
}
|
65
|
+
if (quant_) {
|
66
|
+
std::cerr << "Option -saveOutput is not supported for quantized models."
|
67
|
+
<< std::endl;
|
68
|
+
return;
|
69
|
+
}
|
70
|
+
int32_t n = (args_->model == model_name::sup) ? dict_->nlabels()
|
71
|
+
: dict_->nwords();
|
72
|
+
ofs << n << " " << args_->dim << std::endl;
|
73
|
+
Vector vec(args_->dim);
|
74
|
+
for (int32_t i = 0; i < n; i++) {
|
75
|
+
std::string word = (args_->model == model_name::sup) ? dict_->getLabel(i)
|
76
|
+
: dict_->getWord(i);
|
77
|
+
vec.zero();
|
78
|
+
vec.addRow(*output_, i);
|
79
|
+
ofs << word << " " << vec << std::endl;
|
80
|
+
}
|
81
|
+
ofs.close();
|
82
|
+
}
|
83
|
+
|
84
|
+
bool FastText::checkModel(std::istream& in) {
|
85
|
+
int32_t magic;
|
86
|
+
in.read((char*)&(magic), sizeof(int32_t));
|
87
|
+
if (magic != FASTTEXT_FILEFORMAT_MAGIC_INT32) {
|
88
|
+
return false;
|
89
|
+
}
|
90
|
+
in.read((char*)&(version), sizeof(int32_t));
|
91
|
+
if (version > FASTTEXT_VERSION) {
|
92
|
+
return false;
|
93
|
+
}
|
94
|
+
return true;
|
95
|
+
}
|
96
|
+
|
97
|
+
void FastText::signModel(std::ostream& out) {
|
98
|
+
const int32_t magic = FASTTEXT_FILEFORMAT_MAGIC_INT32;
|
99
|
+
const int32_t version = FASTTEXT_VERSION;
|
100
|
+
out.write((char*)&(magic), sizeof(int32_t));
|
101
|
+
out.write((char*)&(version), sizeof(int32_t));
|
102
|
+
}
|
103
|
+
|
104
|
+
void FastText::saveModel() {
|
105
|
+
std::string fn(args_->output);
|
106
|
+
if (quant_) {
|
107
|
+
fn += ".ftz";
|
108
|
+
} else {
|
109
|
+
fn += ".bin";
|
110
|
+
}
|
111
|
+
std::ofstream ofs(fn, std::ofstream::binary);
|
112
|
+
if (!ofs.is_open()) {
|
113
|
+
std::cerr << "Model file cannot be opened for saving!" << std::endl;
|
114
|
+
exit(EXIT_FAILURE);
|
115
|
+
}
|
116
|
+
signModel(ofs);
|
117
|
+
args_->save(ofs);
|
118
|
+
dict_->save(ofs);
|
119
|
+
|
120
|
+
ofs.write((char*)&(quant_), sizeof(bool));
|
121
|
+
if (quant_) {
|
122
|
+
qinput_->save(ofs);
|
123
|
+
} else {
|
124
|
+
input_->save(ofs);
|
125
|
+
}
|
126
|
+
|
127
|
+
ofs.write((char*)&(args_->qout), sizeof(bool));
|
128
|
+
if (quant_ && args_->qout) {
|
129
|
+
qoutput_->save(ofs);
|
130
|
+
} else {
|
131
|
+
output_->save(ofs);
|
132
|
+
}
|
133
|
+
|
134
|
+
ofs.close();
|
135
|
+
}
|
136
|
+
|
137
|
+
void FastText::loadModel(const std::string& filename) {
|
138
|
+
std::ifstream ifs(filename, std::ifstream::binary);
|
139
|
+
if (!ifs.is_open()) {
|
140
|
+
std::cerr << "Model file cannot be opened for loading!" << std::endl;
|
141
|
+
exit(EXIT_FAILURE);
|
142
|
+
}
|
143
|
+
if (!checkModel(ifs)) {
|
144
|
+
std::cerr << "Model file has wrong file format!" << std::endl;
|
145
|
+
exit(EXIT_FAILURE);
|
146
|
+
}
|
147
|
+
loadModel(ifs);
|
148
|
+
ifs.close();
|
149
|
+
}
|
150
|
+
|
151
|
+
void FastText::loadModel(std::istream& in) {
|
152
|
+
args_ = std::make_shared<Args>();
|
153
|
+
dict_ = std::make_shared<Dictionary>(args_);
|
154
|
+
input_ = std::make_shared<Matrix>();
|
155
|
+
output_ = std::make_shared<Matrix>();
|
156
|
+
qinput_ = std::make_shared<QMatrix>();
|
157
|
+
qoutput_ = std::make_shared<QMatrix>();
|
158
|
+
args_->load(in);
|
159
|
+
if (version == 11 && args_->model == model_name::sup) {
|
160
|
+
// backward compatibility: old supervised models do not use char ngrams.
|
161
|
+
args_->maxn = 0;
|
162
|
+
}
|
163
|
+
dict_->load(in);
|
164
|
+
|
165
|
+
bool quant_input;
|
166
|
+
in.read((char*) &quant_input, sizeof(bool));
|
167
|
+
if (quant_input) {
|
168
|
+
quant_ = true;
|
169
|
+
qinput_->load(in);
|
170
|
+
} else {
|
171
|
+
input_->load(in);
|
172
|
+
}
|
173
|
+
|
174
|
+
if (!quant_input && dict_->isPruned()) {
|
175
|
+
std::cerr << "Invalid model file.\n"
|
176
|
+
<< "Please download the updated model from www.fasttext.cc.\n"
|
177
|
+
<< "See issue #332 on Github for more information.\n";
|
178
|
+
exit(1);
|
179
|
+
}
|
180
|
+
|
181
|
+
in.read((char*) &args_->qout, sizeof(bool));
|
182
|
+
if (quant_ && args_->qout) {
|
183
|
+
qoutput_->load(in);
|
184
|
+
} else {
|
185
|
+
output_->load(in);
|
186
|
+
}
|
187
|
+
|
188
|
+
model_ = std::make_shared<Model>(input_, output_, args_, 0);
|
189
|
+
model_->quant_ = quant_;
|
190
|
+
model_->setQuantizePointer(qinput_, qoutput_, args_->qout);
|
191
|
+
|
192
|
+
if (args_->model == model_name::sup) {
|
193
|
+
model_->setTargetCounts(dict_->getCounts(entry_type::label));
|
194
|
+
} else {
|
195
|
+
model_->setTargetCounts(dict_->getCounts(entry_type::word));
|
196
|
+
}
|
197
|
+
}
|
198
|
+
|
199
|
+
// Print a one-line training progress report to stderr (progress %, words/sec
// per thread, current lr, loss, ETA), overwriting the previous line via '\r'.
void FastText::printInfo(real progress, real loss) {
  real t = real(clock() - start) / CLOCKS_PER_SEC;
  // Guard the divisions: t is 0 immediately after start, and progress is 0
  // before any tokens are counted — the original produced inf/garbage then.
  real wst = (t > 0) ? real(tokenCount) / t : 0;
  real lr = args_->lr * (1.0 - progress);
  int eta = (progress > 0) ? int(t / progress * (1 - progress) / args_->thread)
                           : 0;
  int etah = eta / 3600;
  int etam = (eta - etah * 3600) / 60;
  std::cerr << std::fixed;
  std::cerr << "\rProgress: " << std::setprecision(1) << 100 * progress << "%";
  std::cerr << " words/sec/thread: " << std::setprecision(0) << wst;
  std::cerr << " lr: " << std::setprecision(6) << lr;
  std::cerr << " loss: " << std::setprecision(6) << loss;
  std::cerr << " eta: " << etah << "h" << etam << "m ";
  std::cerr << std::flush;
}
|
214
|
+
|
215
|
+
std::vector<int32_t> FastText::selectEmbeddings(int32_t cutoff) const {
|
216
|
+
Vector norms(input_->m_);
|
217
|
+
input_->l2NormRow(norms);
|
218
|
+
std::vector<int32_t> idx(input_->m_, 0);
|
219
|
+
std::iota(idx.begin(), idx.end(), 0);
|
220
|
+
auto eosid = dict_->getId(Dictionary::EOS);
|
221
|
+
std::sort(idx.begin(), idx.end(),
|
222
|
+
[&norms, eosid] (size_t i1, size_t i2) {
|
223
|
+
return eosid ==i1 || (eosid != i2 && norms[i1] > norms[i2]);
|
224
|
+
});
|
225
|
+
idx.erase(idx.begin() + cutoff, idx.end());
|
226
|
+
return idx;
|
227
|
+
}
|
228
|
+
|
229
|
+
void FastText::quantize(std::shared_ptr<Args> qargs) {
|
230
|
+
if (qargs->output.empty()) {
|
231
|
+
std::cerr<<"No model provided!"<<std::endl;
|
232
|
+
exit(1);
|
233
|
+
}
|
234
|
+
loadModel(qargs->output + ".bin");
|
235
|
+
|
236
|
+
args_->input = qargs->input;
|
237
|
+
args_->qout = qargs->qout;
|
238
|
+
args_->output = qargs->output;
|
239
|
+
|
240
|
+
|
241
|
+
if (qargs->cutoff > 0 && qargs->cutoff < input_->m_) {
|
242
|
+
auto idx = selectEmbeddings(qargs->cutoff);
|
243
|
+
dict_->prune(idx);
|
244
|
+
std::shared_ptr<Matrix> ninput =
|
245
|
+
std::make_shared<Matrix> (idx.size(), args_->dim);
|
246
|
+
for (auto i = 0; i < idx.size(); i++) {
|
247
|
+
for (auto j = 0; j < args_->dim; j++) {
|
248
|
+
ninput->at(i,j) = input_->at(idx[i], j);
|
249
|
+
}
|
250
|
+
}
|
251
|
+
input_ = ninput;
|
252
|
+
if (qargs->retrain) {
|
253
|
+
args_->epoch = qargs->epoch;
|
254
|
+
args_->lr = qargs->lr;
|
255
|
+
args_->thread = qargs->thread;
|
256
|
+
args_->verbose = qargs->verbose;
|
257
|
+
start = clock();
|
258
|
+
tokenCount = 0;
|
259
|
+
start = clock();
|
260
|
+
std::vector<std::thread> threads;
|
261
|
+
for (int32_t i = 0; i < args_->thread; i++) {
|
262
|
+
threads.push_back(std::thread([=]() { trainThread(i); }));
|
263
|
+
}
|
264
|
+
for (auto it = threads.begin(); it != threads.end(); ++it) {
|
265
|
+
it->join();
|
266
|
+
}
|
267
|
+
}
|
268
|
+
}
|
269
|
+
|
270
|
+
qinput_ = std::make_shared<QMatrix>(*input_, qargs->dsub, qargs->qnorm);
|
271
|
+
|
272
|
+
if (args_->qout) {
|
273
|
+
qoutput_ = std::make_shared<QMatrix>(*output_, 2, qargs->qnorm);
|
274
|
+
}
|
275
|
+
|
276
|
+
quant_ = true;
|
277
|
+
saveModel();
|
278
|
+
}
|
279
|
+
|
280
|
+
// One supervised update: choose one of the line's labels uniformly at random
// and take a gradient step toward it using the whole line as input.
// No-op when the line has no tokens or no labels.
void FastText::supervised(Model& model, real lr,
                          const std::vector<int32_t>& line,
                          const std::vector<int32_t>& labels) {
  if (labels.empty() || line.empty()) return;
  std::uniform_int_distribution<> pick(0, labels.size() - 1);
  model.update(line, labels[pick(model.rng)], lr);
}
|
288
|
+
|
289
|
+
// One CBOW pass over `line`: for each position, gather the subwords of a
// randomly sized symmetric context window (size in [1, ws]) and update the
// model toward the center word.
void FastText::cbow(Model& model, real lr,
                    const std::vector<int32_t>& line) {
  std::vector<int32_t> bow;
  std::uniform_int_distribution<> windowSize(1, args_->ws);
  for (int32_t w = 0; w < line.size(); w++) {
    int32_t boundary = windowSize(model.rng);
    bow.clear();
    for (int32_t c = -boundary; c <= boundary; c++) {
      // Skip the center word itself and out-of-range positions.
      if (c == 0 || w + c < 0 || w + c >= line.size()) continue;
      const std::vector<int32_t>& ngrams = dict_->getSubwords(line[w + c]);
      bow.insert(bow.end(), ngrams.cbegin(), ngrams.cend());
    }
    model.update(bow, line[w], lr);
  }
}
|
305
|
+
|
306
|
+
// One skipgram pass over `line`: for each position, update every word inside
// a randomly sized symmetric window (size in [1, ws]) from the center word's
// subword representation.
void FastText::skipgram(Model& model, real lr,
                        const std::vector<int32_t>& line) {
  std::uniform_int_distribution<> windowSize(1, args_->ws);
  for (int32_t w = 0; w < line.size(); w++) {
    int32_t boundary = windowSize(model.rng);
    const std::vector<int32_t>& ngrams = dict_->getSubwords(line[w]);
    for (int32_t c = -boundary; c <= boundary; c++) {
      // Skip the center word itself and out-of-range positions.
      if (c == 0 || w + c < 0 || w + c >= line.size()) continue;
      model.update(ngrams, line[w + c], lr);
    }
  }
}
|
319
|
+
|
320
|
+
void FastText::test(std::istream& in, int32_t k) {
|
321
|
+
int32_t nexamples = 0, nlabels = 0;
|
322
|
+
double precision = 0.0;
|
323
|
+
std::vector<int32_t> line, labels;
|
324
|
+
|
325
|
+
while (in.peek() != EOF) {
|
326
|
+
dict_->getLine(in, line, labels, model_->rng);
|
327
|
+
if (labels.size() > 0 && line.size() > 0) {
|
328
|
+
std::vector<std::pair<real, int32_t>> modelPredictions;
|
329
|
+
model_->predict(line, k, modelPredictions);
|
330
|
+
for (auto it = modelPredictions.cbegin(); it != modelPredictions.cend(); it++) {
|
331
|
+
if (std::find(labels.begin(), labels.end(), it->second) != labels.end()) {
|
332
|
+
precision += 1.0;
|
333
|
+
}
|
334
|
+
}
|
335
|
+
nexamples++;
|
336
|
+
nlabels += labels.size();
|
337
|
+
}
|
338
|
+
}
|
339
|
+
std::cout << "N" << "\t" << nexamples << std::endl;
|
340
|
+
std::cout << std::setprecision(3);
|
341
|
+
std::cout << "P@" << k << "\t" << precision / (k * nexamples) << std::endl;
|
342
|
+
std::cout << "R@" << k << "\t" << precision / nlabels << std::endl;
|
343
|
+
std::cerr << "Number of examples: " << nexamples << std::endl;
|
344
|
+
}
|
345
|
+
|
346
|
+
void FastText::predict(std::istream& in, int32_t k,
|
347
|
+
std::vector<std::pair<real,std::string>>& predictions) const {
|
348
|
+
std::vector<int32_t> words, labels;
|
349
|
+
predictions.clear();
|
350
|
+
dict_->getLine(in, words, labels, model_->rng);
|
351
|
+
predictions.clear();
|
352
|
+
if (words.empty()) return;
|
353
|
+
Vector hidden(args_->dim);
|
354
|
+
Vector output(dict_->nlabels());
|
355
|
+
std::vector<std::pair<real,int32_t>> modelPredictions;
|
356
|
+
model_->predict(words, k, modelPredictions, hidden, output);
|
357
|
+
for (auto it = modelPredictions.cbegin(); it != modelPredictions.cend(); it++) {
|
358
|
+
predictions.push_back(std::make_pair(it->first, dict_->getLabel(it->second)));
|
359
|
+
}
|
360
|
+
}
|
361
|
+
|
362
|
+
void FastText::predict(std::istream& in, int32_t k, bool print_prob) {
|
363
|
+
std::vector<std::pair<real,std::string>> predictions;
|
364
|
+
while (in.peek() != EOF) {
|
365
|
+
predictions.clear();
|
366
|
+
predict(in, k, predictions);
|
367
|
+
if (predictions.empty()) {
|
368
|
+
std::cout << std::endl;
|
369
|
+
continue;
|
370
|
+
}
|
371
|
+
for (auto it = predictions.cbegin(); it != predictions.cend(); it++) {
|
372
|
+
if (it != predictions.cbegin()) {
|
373
|
+
std::cout << " ";
|
374
|
+
}
|
375
|
+
std::cout << it->second;
|
376
|
+
if (print_prob) {
|
377
|
+
std::cout << " " << exp(it->first);
|
378
|
+
}
|
379
|
+
}
|
380
|
+
std::cout << std::endl;
|
381
|
+
}
|
382
|
+
}
|
383
|
+
|
384
|
+
void FastText::wordVectors() {
|
385
|
+
std::string word;
|
386
|
+
Vector vec(args_->dim);
|
387
|
+
while (std::cin >> word) {
|
388
|
+
getVector(vec, word);
|
389
|
+
std::cout << word << " " << vec << std::endl;
|
390
|
+
}
|
391
|
+
}
|
392
|
+
|
393
|
+
void FastText::sentenceVectors() {
|
394
|
+
Vector vec(args_->dim);
|
395
|
+
std::string sentence;
|
396
|
+
Vector svec(args_->dim);
|
397
|
+
std::string word;
|
398
|
+
while (std::getline(std::cin, sentence)) {
|
399
|
+
std::istringstream iss(sentence);
|
400
|
+
svec.zero();
|
401
|
+
int32_t count = 0;
|
402
|
+
while(iss >> word) {
|
403
|
+
getVector(vec, word);
|
404
|
+
real norm = vec.norm();
|
405
|
+
if (norm > 0) {
|
406
|
+
vec.mul(1.0 / norm);
|
407
|
+
svec.addVector(vec);
|
408
|
+
count++;
|
409
|
+
}
|
410
|
+
}
|
411
|
+
if (count > 0) {
|
412
|
+
svec.mul(1.0 / count);
|
413
|
+
}
|
414
|
+
std::cout << sentence << " " << svec << std::endl;
|
415
|
+
}
|
416
|
+
}
|
417
|
+
|
418
|
+
std::shared_ptr<const Dictionary> FastText::getDictionary() const {
|
419
|
+
return dict_;
|
420
|
+
}
|
421
|
+
|
422
|
+
void FastText::ngramVectors(std::string word) {
|
423
|
+
std::vector<int32_t> ngrams;
|
424
|
+
std::vector<std::string> substrings;
|
425
|
+
Vector vec(args_->dim);
|
426
|
+
dict_->getSubwords(word, ngrams, substrings);
|
427
|
+
for (int32_t i = 0; i < ngrams.size(); i++) {
|
428
|
+
vec.zero();
|
429
|
+
if (ngrams[i] >= 0) {
|
430
|
+
if (quant_) {
|
431
|
+
vec.addRow(*qinput_, ngrams[i]);
|
432
|
+
} else {
|
433
|
+
vec.addRow(*input_, ngrams[i]);
|
434
|
+
}
|
435
|
+
}
|
436
|
+
std::cout << substrings[i] << " " << vec << std::endl;
|
437
|
+
}
|
438
|
+
}
|
439
|
+
|
440
|
+
void FastText::textVectors() {
|
441
|
+
std::vector<int32_t> line, labels;
|
442
|
+
Vector vec(args_->dim);
|
443
|
+
while (std::cin.peek() != EOF) {
|
444
|
+
dict_->getLine(std::cin, line, labels, model_->rng);
|
445
|
+
vec.zero();
|
446
|
+
for (auto it = line.cbegin(); it != line.cend(); ++it) {
|
447
|
+
if (quant_) {
|
448
|
+
vec.addRow(*qinput_, *it);
|
449
|
+
} else {
|
450
|
+
vec.addRow(*input_, *it);
|
451
|
+
}
|
452
|
+
}
|
453
|
+
if (!line.empty()) {
|
454
|
+
vec.mul(1.0 / line.size());
|
455
|
+
}
|
456
|
+
std::cout << vec << std::endl;
|
457
|
+
}
|
458
|
+
}
|
459
|
+
|
460
|
+
// Public entry point for the print-word-vectors command.
void FastText::printWordVectors() {
  wordVectors();
}
|
463
|
+
|
464
|
+
void FastText::printSentenceVectors() {
|
465
|
+
if (args_->model == model_name::sup) {
|
466
|
+
textVectors();
|
467
|
+
} else {
|
468
|
+
sentenceVectors();
|
469
|
+
}
|
470
|
+
}
|
471
|
+
|
472
|
+
// Fill `wordVectors` with one L2-normalized embedding per vocabulary word,
// for nearest-neighbour queries; rows whose embedding has zero norm are
// left zeroed.
void FastText::precomputeWordVectors(Matrix& wordVectors) {
  Vector emb(args_->dim);
  wordVectors.zero();
  std::cerr << "Pre-computing word vectors...";
  for (int32_t row = 0; row < dict_->nwords(); row++) {
    std::string token = dict_->getWord(row);
    getVector(emb, token);
    real norm = emb.norm();
    if (norm > 0) {
      wordVectors.addRow(emb, row, 1.0 / norm);
    }
  }
  std::cerr << " done." << std::endl;
}
|
486
|
+
|
487
|
+
void FastText::findNN(const Matrix& wordVectors, const Vector& queryVec,
|
488
|
+
int32_t k, const std::set<std::string>& banSet) {
|
489
|
+
real queryNorm = queryVec.norm();
|
490
|
+
if (std::abs(queryNorm) < 1e-8) {
|
491
|
+
queryNorm = 1;
|
492
|
+
}
|
493
|
+
std::priority_queue<std::pair<real, std::string>> heap;
|
494
|
+
Vector vec(args_->dim);
|
495
|
+
for (int32_t i = 0; i < dict_->nwords(); i++) {
|
496
|
+
std::string word = dict_->getWord(i);
|
497
|
+
real dp = wordVectors.dotRow(queryVec, i);
|
498
|
+
heap.push(std::make_pair(dp / queryNorm, word));
|
499
|
+
}
|
500
|
+
int32_t i = 0;
|
501
|
+
while (i < k && heap.size() > 0) {
|
502
|
+
auto it = banSet.find(heap.top().second);
|
503
|
+
if (it == banSet.end()) {
|
504
|
+
std::cout << heap.top().second << " " << heap.top().first << std::endl;
|
505
|
+
i++;
|
506
|
+
}
|
507
|
+
heap.pop();
|
508
|
+
}
|
509
|
+
}
|
510
|
+
|
511
|
+
void FastText::nn(int32_t k) {
|
512
|
+
std::string queryWord;
|
513
|
+
Vector queryVec(args_->dim);
|
514
|
+
Matrix wordVectors(dict_->nwords(), args_->dim);
|
515
|
+
precomputeWordVectors(wordVectors);
|
516
|
+
std::set<std::string> banSet;
|
517
|
+
std::cout << "Query word? ";
|
518
|
+
while (std::cin >> queryWord) {
|
519
|
+
banSet.clear();
|
520
|
+
banSet.insert(queryWord);
|
521
|
+
getVector(queryVec, queryWord);
|
522
|
+
findNN(wordVectors, queryVec, k, banSet);
|
523
|
+
std::cout << "Query word? ";
|
524
|
+
}
|
525
|
+
}
|
526
|
+
|
527
|
+
void FastText::analogies(int32_t k) {
|
528
|
+
std::string word;
|
529
|
+
Vector buffer(args_->dim), query(args_->dim);
|
530
|
+
Matrix wordVectors(dict_->nwords(), args_->dim);
|
531
|
+
precomputeWordVectors(wordVectors);
|
532
|
+
std::set<std::string> banSet;
|
533
|
+
std::cout << "Query triplet (A - B + C)? ";
|
534
|
+
while (true) {
|
535
|
+
banSet.clear();
|
536
|
+
query.zero();
|
537
|
+
std::cin >> word;
|
538
|
+
banSet.insert(word);
|
539
|
+
getVector(buffer, word);
|
540
|
+
query.addVector(buffer, 1.0);
|
541
|
+
std::cin >> word;
|
542
|
+
banSet.insert(word);
|
543
|
+
getVector(buffer, word);
|
544
|
+
query.addVector(buffer, -1.0);
|
545
|
+
std::cin >> word;
|
546
|
+
banSet.insert(word);
|
547
|
+
getVector(buffer, word);
|
548
|
+
query.addVector(buffer, 1.0);
|
549
|
+
|
550
|
+
findNN(wordVectors, query, k, banSet);
|
551
|
+
std::cout << "Query triplet (A - B + C)? ";
|
552
|
+
}
|
553
|
+
}
|
554
|
+
|
555
|
+
void FastText::trainThread(int32_t threadId) {
|
556
|
+
std::ifstream ifs(args_->input);
|
557
|
+
utils::seek(ifs, threadId * utils::size(ifs) / args_->thread);
|
558
|
+
|
559
|
+
Model model(input_, output_, args_, threadId);
|
560
|
+
if (args_->model == model_name::sup) {
|
561
|
+
model.setTargetCounts(dict_->getCounts(entry_type::label));
|
562
|
+
} else {
|
563
|
+
model.setTargetCounts(dict_->getCounts(entry_type::word));
|
564
|
+
}
|
565
|
+
|
566
|
+
const int64_t ntokens = dict_->ntokens();
|
567
|
+
int64_t localTokenCount = 0;
|
568
|
+
std::vector<int32_t> line, labels;
|
569
|
+
while (tokenCount < args_->epoch * ntokens) {
|
570
|
+
real progress = real(tokenCount) / (args_->epoch * ntokens);
|
571
|
+
real lr = args_->lr * (1.0 - progress);
|
572
|
+
if (args_->model == model_name::sup) {
|
573
|
+
localTokenCount += dict_->getLine(ifs, line, labels, model.rng);
|
574
|
+
supervised(model, lr, line, labels);
|
575
|
+
} else if (args_->model == model_name::cbow) {
|
576
|
+
localTokenCount += dict_->getLine(ifs, line, model.rng);
|
577
|
+
cbow(model, lr, line);
|
578
|
+
} else if (args_->model == model_name::sg) {
|
579
|
+
localTokenCount += dict_->getLine(ifs, line, model.rng);
|
580
|
+
skipgram(model, lr, line);
|
581
|
+
}
|
582
|
+
if (localTokenCount > args_->lrUpdateRate) {
|
583
|
+
tokenCount += localTokenCount;
|
584
|
+
localTokenCount = 0;
|
585
|
+
if (threadId == 0 && args_->verbose > 1) {
|
586
|
+
printInfo(progress, model.getLoss());
|
587
|
+
}
|
588
|
+
}
|
589
|
+
}
|
590
|
+
if (threadId == 0 && args_->verbose > 0) {
|
591
|
+
printInfo(1.0, model.getLoss());
|
592
|
+
std::cerr << std::endl;
|
593
|
+
}
|
594
|
+
ifs.close();
|
595
|
+
}
|
596
|
+
|
597
|
+
void FastText::loadVectors(std::string filename) {
|
598
|
+
std::ifstream in(filename);
|
599
|
+
std::vector<std::string> words;
|
600
|
+
std::shared_ptr<Matrix> mat; // temp. matrix for pretrained vectors
|
601
|
+
int64_t n, dim;
|
602
|
+
if (!in.is_open()) {
|
603
|
+
std::cerr << "Pretrained vectors file cannot be opened!" << std::endl;
|
604
|
+
exit(EXIT_FAILURE);
|
605
|
+
}
|
606
|
+
in >> n >> dim;
|
607
|
+
if (dim != args_->dim) {
|
608
|
+
std::cerr << "Dimension of pretrained vectors does not match -dim option"
|
609
|
+
<< std::endl;
|
610
|
+
exit(EXIT_FAILURE);
|
611
|
+
}
|
612
|
+
mat = std::make_shared<Matrix>(n, dim);
|
613
|
+
for (size_t i = 0; i < n; i++) {
|
614
|
+
std::string word;
|
615
|
+
in >> word;
|
616
|
+
words.push_back(word);
|
617
|
+
dict_->add(word);
|
618
|
+
for (size_t j = 0; j < dim; j++) {
|
619
|
+
in >> mat->data_[i * dim + j];
|
620
|
+
}
|
621
|
+
}
|
622
|
+
in.close();
|
623
|
+
|
624
|
+
dict_->threshold(1, 0);
|
625
|
+
input_ = std::make_shared<Matrix>(dict_->nwords()+args_->bucket, args_->dim);
|
626
|
+
input_->uniform(1.0 / args_->dim);
|
627
|
+
|
628
|
+
for (size_t i = 0; i < n; i++) {
|
629
|
+
int32_t idx = dict_->getId(words[i]);
|
630
|
+
if (idx < 0 || idx >= dict_->nwords()) continue;
|
631
|
+
for (size_t j = 0; j < dim; j++) {
|
632
|
+
input_->data_[idx * dim + j] = mat->data_[i * dim + j];
|
633
|
+
}
|
634
|
+
}
|
635
|
+
}
|
636
|
+
|
637
|
+
void FastText::train(std::shared_ptr<Args> args) {
|
638
|
+
args_ = args;
|
639
|
+
dict_ = std::make_shared<Dictionary>(args_);
|
640
|
+
if (args_->input == "-") {
|
641
|
+
// manage expectations
|
642
|
+
std::cerr << "Cannot use stdin for training!" << std::endl;
|
643
|
+
exit(EXIT_FAILURE);
|
644
|
+
}
|
645
|
+
std::ifstream ifs(args_->input);
|
646
|
+
if (!ifs.is_open()) {
|
647
|
+
std::cerr << "Input file cannot be opened!" << std::endl;
|
648
|
+
exit(EXIT_FAILURE);
|
649
|
+
}
|
650
|
+
dict_->readFromFile(ifs);
|
651
|
+
ifs.close();
|
652
|
+
|
653
|
+
if (args_->pretrainedVectors.size() != 0) {
|
654
|
+
loadVectors(args_->pretrainedVectors);
|
655
|
+
} else {
|
656
|
+
input_ = std::make_shared<Matrix>(dict_->nwords()+args_->bucket, args_->dim);
|
657
|
+
input_->uniform(1.0 / args_->dim);
|
658
|
+
}
|
659
|
+
|
660
|
+
if (args_->model == model_name::sup) {
|
661
|
+
output_ = std::make_shared<Matrix>(dict_->nlabels(), args_->dim);
|
662
|
+
} else {
|
663
|
+
output_ = std::make_shared<Matrix>(dict_->nwords(), args_->dim);
|
664
|
+
}
|
665
|
+
output_->zero();
|
666
|
+
|
667
|
+
start = clock();
|
668
|
+
tokenCount = 0;
|
669
|
+
if (args_->thread > 1) {
|
670
|
+
std::vector<std::thread> threads;
|
671
|
+
for (int32_t i = 0; i < args_->thread; i++) {
|
672
|
+
threads.push_back(std::thread([=]() { trainThread(i); }));
|
673
|
+
}
|
674
|
+
for (auto it = threads.begin(); it != threads.end(); ++it) {
|
675
|
+
it->join();
|
676
|
+
}
|
677
|
+
} else {
|
678
|
+
trainThread(0);
|
679
|
+
}
|
680
|
+
model_ = std::make_shared<Model>(input_, output_, args_, 0);
|
681
|
+
|
682
|
+
saveModel();
|
683
|
+
saveVectors();
|
684
|
+
if (args_->saveOutput > 0) {
|
685
|
+
saveOutput();
|
686
|
+
}
|
687
|
+
}
|
688
|
+
|
689
|
+
int FastText::getDimension() const {
|
690
|
+
return args_->dim;
|
691
|
+
}
|
692
|
+
|
693
|
+
}
|