ffi-fasttext 0.1.0
- checksums.yaml +7 -0
- data/.gitignore +44 -0
- data/.travis.yml +5 -0
- data/Gemfile +6 -0
- data/LICENSE.txt +21 -0
- data/README.md +59 -0
- data/Rakefile +19 -0
- data/bin/console +14 -0
- data/bin/setup +8 -0
- data/ext/ffi/fasttext/Rakefile +71 -0
- data/ffi-fasttext.gemspec +40 -0
- data/lib/ffi/fasttext.rb +108 -0
- data/lib/ffi/fasttext/version.rb +5 -0
- data/vendor/fasttext/LICENSE +30 -0
- data/vendor/fasttext/PATENTS +33 -0
- data/vendor/fasttext/args.cc +250 -0
- data/vendor/fasttext/args.h +71 -0
- data/vendor/fasttext/dictionary.cc +475 -0
- data/vendor/fasttext/dictionary.h +112 -0
- data/vendor/fasttext/fasttext.cc +693 -0
- data/vendor/fasttext/fasttext.h +97 -0
- data/vendor/fasttext/ffi_fasttext.cc +66 -0
- data/vendor/fasttext/main.cc +270 -0
- data/vendor/fasttext/matrix.cc +144 -0
- data/vendor/fasttext/matrix.h +57 -0
- data/vendor/fasttext/model.cc +341 -0
- data/vendor/fasttext/model.h +110 -0
- data/vendor/fasttext/productquantizer.cc +211 -0
- data/vendor/fasttext/productquantizer.h +67 -0
- data/vendor/fasttext/qmatrix.cc +121 -0
- data/vendor/fasttext/qmatrix.h +65 -0
- data/vendor/fasttext/real.h +19 -0
- data/vendor/fasttext/utils.cc +29 -0
- data/vendor/fasttext/utils.h +25 -0
- data/vendor/fasttext/vector.cc +137 -0
- data/vendor/fasttext/vector.h +53 -0
- metadata +151 -0
data/vendor/fasttext/matrix.h
@@ -0,0 +1,57 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ */
+
+#ifndef FASTTEXT_MATRIX_H
+#define FASTTEXT_MATRIX_H
+
+#include <cstdint>
+#include <istream>
+#include <ostream>
+
+#include "real.h"
+
+namespace fasttext {
+
+class Vector;
+
+class Matrix {
+
+  public:
+    real* data_;
+    int64_t m_;
+    int64_t n_;
+
+    Matrix();
+    Matrix(int64_t, int64_t);
+    Matrix(const Matrix&);
+    Matrix& operator=(const Matrix&);
+    ~Matrix();
+
+    inline const real& at(int64_t i, int64_t j) const { return data_[i * n_ + j]; }
+    inline real& at(int64_t i, int64_t j) { return data_[i * n_ + j]; }
+
+
+    void zero();
+    void uniform(real);
+    real dotRow(const Vector&, int64_t) const;
+    void addRow(const Vector&, int64_t, real);
+
+    void multiplyRow(const Vector& nums, int64_t ib = 0, int64_t ie = -1);
+    void divideRow(const Vector& denoms, int64_t ib = 0, int64_t ie = -1);
+
+    real l2NormRow(int64_t i) const;
+    void l2NormRow(Vector& norms) const;
+
+    void save(std::ostream&);
+    void load(std::istream&);
+};
+
+}
+
+#endif
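A note on the header above: Matrix is just a flat row-major buffer, so at(i, j) resolves to data_[i * n_ + j], and dotRow, addRow, and the l2NormRow variants all build on that layout. A minimal standalone sketch of the indexing; the typedef stands in for real.h, which in this vendor tree maps real to float:

#include <cstdint>
#include <iostream>
#include <vector>

typedef float real;  // stand-in for data/vendor/fasttext/real.h

int main() {
  // A 3x4 matrix stored as one contiguous row-major buffer, like Matrix::data_.
  const int64_t m = 3, n = 4;
  std::vector<real> data(m * n, 0.0f);

  // Matrix::at(i, j) is data_[i * n_ + j]; fill row 1, then read element (1, 2).
  for (int64_t j = 0; j < n; j++) {
    data[1 * n + j] = real(j);
  }
  std::cout << data[1 * n + 2] << std::endl;  // prints 2
  return 0;
}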
data/vendor/fasttext/model.cc
@@ -0,0 +1,341 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ */
+
+#include "model.h"
+
+#include <iostream>
+#include <assert.h>
+#include <algorithm>
+
+namespace fasttext {
+
+Model::Model(std::shared_ptr<Matrix> wi,
+             std::shared_ptr<Matrix> wo,
+             std::shared_ptr<Args> args,
+             int32_t seed)
+  : hidden_(args->dim), output_(wo->m_),
+    grad_(args->dim), rng(seed), quant_(false)
+{
+  wi_ = wi;
+  wo_ = wo;
+  args_ = args;
+  osz_ = wo->m_;
+  hsz_ = args->dim;
+  negpos = 0;
+  loss_ = 0.0;
+  nexamples_ = 1;
+  initSigmoid();
+  initLog();
+}
+
+Model::~Model() {
+  delete[] t_sigmoid;
+  delete[] t_log;
+}
+
+void Model::setQuantizePointer(std::shared_ptr<QMatrix> qwi,
+                               std::shared_ptr<QMatrix> qwo, bool qout) {
+  qwi_ = qwi;
+  qwo_ = qwo;
+  if (qout) {
+    osz_ = qwo_->getM();
+  }
+}
+
+real Model::binaryLogistic(int32_t target, bool label, real lr) {
+  real score = sigmoid(wo_->dotRow(hidden_, target));
+  real alpha = lr * (real(label) - score);
+  grad_.addRow(*wo_, target, alpha);
+  wo_->addRow(hidden_, target, alpha);
+  if (label) {
+    return -log(score);
+  } else {
+    return -log(1.0 - score);
+  }
+}
+
+real Model::negativeSampling(int32_t target, real lr) {
+  real loss = 0.0;
+  grad_.zero();
+  for (int32_t n = 0; n <= args_->neg; n++) {
+    if (n == 0) {
+      loss += binaryLogistic(target, true, lr);
+    } else {
+      loss += binaryLogistic(getNegative(target), false, lr);
+    }
+  }
+  return loss;
+}
+
+real Model::hierarchicalSoftmax(int32_t target, real lr) {
+  real loss = 0.0;
+  grad_.zero();
+  const std::vector<bool>& binaryCode = codes[target];
+  const std::vector<int32_t>& pathToRoot = paths[target];
+  for (int32_t i = 0; i < pathToRoot.size(); i++) {
+    loss += binaryLogistic(pathToRoot[i], binaryCode[i], lr);
+  }
+  return loss;
+}
+
+void Model::computeOutputSoftmax(Vector& hidden, Vector& output) const {
+  if (quant_ && args_->qout) {
+    output.mul(*qwo_, hidden);
+  } else {
+    output.mul(*wo_, hidden);
+  }
+  real max = output[0], z = 0.0;
+  for (int32_t i = 0; i < osz_; i++) {
+    max = std::max(output[i], max);
+  }
+  for (int32_t i = 0; i < osz_; i++) {
+    output[i] = exp(output[i] - max);
+    z += output[i];
+  }
+  for (int32_t i = 0; i < osz_; i++) {
+    output[i] /= z;
+  }
+}
+
+void Model::computeOutputSoftmax() {
+  computeOutputSoftmax(hidden_, output_);
+}
+
+real Model::softmax(int32_t target, real lr) {
+  grad_.zero();
+  computeOutputSoftmax();
+  for (int32_t i = 0; i < osz_; i++) {
+    real label = (i == target) ? 1.0 : 0.0;
+    real alpha = lr * (label - output_[i]);
+    grad_.addRow(*wo_, i, alpha);
+    wo_->addRow(hidden_, i, alpha);
+  }
+  return -log(output_[target]);
+}
+
+void Model::computeHidden(const std::vector<int32_t>& input, Vector& hidden) const {
+  assert(hidden.size() == hsz_);
+  hidden.zero();
+  for (auto it = input.cbegin(); it != input.cend(); ++it) {
+    if (quant_) {
+      hidden.addRow(*qwi_, *it);
+    } else {
+      hidden.addRow(*wi_, *it);
+    }
+  }
+  hidden.mul(1.0 / input.size());
+}
+
+bool Model::comparePairs(const std::pair<real, int32_t> &l,
+                         const std::pair<real, int32_t> &r) {
+  return l.first > r.first;
+}
+
+void Model::predict(const std::vector<int32_t>& input, int32_t k,
+                    std::vector<std::pair<real, int32_t>>& heap,
+                    Vector& hidden, Vector& output) const {
+  assert(k > 0);
+  heap.reserve(k + 1);
+  computeHidden(input, hidden);
+  if (args_->loss == loss_name::hs) {
+    dfs(k, 2 * osz_ - 2, 0.0, heap, hidden);
+  } else {
+    findKBest(k, heap, hidden, output);
+  }
+  std::sort_heap(heap.begin(), heap.end(), comparePairs);
+}
+
+void Model::predict(const std::vector<int32_t>& input, int32_t k,
+                    std::vector<std::pair<real, int32_t>>& heap) {
+  predict(input, k, heap, hidden_, output_);
+}
+
+void Model::findKBest(int32_t k, std::vector<std::pair<real, int32_t>>& heap,
+                      Vector& hidden, Vector& output) const {
+  computeOutputSoftmax(hidden, output);
+  for (int32_t i = 0; i < osz_; i++) {
+    if (heap.size() == k && log(output[i]) < heap.front().first) {
+      continue;
+    }
+    heap.push_back(std::make_pair(log(output[i]), i));
+    std::push_heap(heap.begin(), heap.end(), comparePairs);
+    if (heap.size() > k) {
+      std::pop_heap(heap.begin(), heap.end(), comparePairs);
+      heap.pop_back();
+    }
+  }
+}
+
+void Model::dfs(int32_t k, int32_t node, real score,
+                std::vector<std::pair<real, int32_t>>& heap,
+                Vector& hidden) const {
+  if (heap.size() == k && score < heap.front().first) {
+    return;
+  }
+
+  if (tree[node].left == -1 && tree[node].right == -1) {
+    heap.push_back(std::make_pair(score, node));
+    std::push_heap(heap.begin(), heap.end(), comparePairs);
+    if (heap.size() > k) {
+      std::pop_heap(heap.begin(), heap.end(), comparePairs);
+      heap.pop_back();
+    }
+    return;
+  }
+
+  real f;
+  if (quant_ && args_->qout) {
+    f = sigmoid(qwo_->dotRow(hidden, node - osz_));
+  } else {
+    f = sigmoid(wo_->dotRow(hidden, node - osz_));
+  }
+
+  dfs(k, tree[node].left, score + log(1.0 - f), heap, hidden);
+  dfs(k, tree[node].right, score + log(f), heap, hidden);
+}
+
+void Model::update(const std::vector<int32_t>& input, int32_t target, real lr) {
+  assert(target >= 0);
+  assert(target < osz_);
+  if (input.size() == 0) return;
+  computeHidden(input, hidden_);
+  if (args_->loss == loss_name::ns) {
+    loss_ += negativeSampling(target, lr);
+  } else if (args_->loss == loss_name::hs) {
+    loss_ += hierarchicalSoftmax(target, lr);
+  } else {
+    loss_ += softmax(target, lr);
+  }
+  nexamples_ += 1;
+
+  if (args_->model == model_name::sup) {
+    grad_.mul(1.0 / input.size());
+  }
+  for (auto it = input.cbegin(); it != input.cend(); ++it) {
+    wi_->addRow(grad_, *it, 1.0);
+  }
+}
+
+void Model::setTargetCounts(const std::vector<int64_t>& counts) {
+  assert(counts.size() == osz_);
+  if (args_->loss == loss_name::ns) {
+    initTableNegatives(counts);
+  }
+  if (args_->loss == loss_name::hs) {
+    buildTree(counts);
+  }
+}
+
+void Model::initTableNegatives(const std::vector<int64_t>& counts) {
+  real z = 0.0;
+  for (size_t i = 0; i < counts.size(); i++) {
+    z += pow(counts[i], 0.5);
+  }
+  for (size_t i = 0; i < counts.size(); i++) {
+    real c = pow(counts[i], 0.5);
+    for (size_t j = 0; j < c * NEGATIVE_TABLE_SIZE / z; j++) {
+      negatives.push_back(i);
+    }
+  }
+  std::shuffle(negatives.begin(), negatives.end(), rng);
+}
+
+int32_t Model::getNegative(int32_t target) {
+  int32_t negative;
+  do {
+    negative = negatives[negpos];
+    negpos = (negpos + 1) % negatives.size();
+  } while (target == negative);
+  return negative;
+}
+
+void Model::buildTree(const std::vector<int64_t>& counts) {
+  tree.resize(2 * osz_ - 1);
+  for (int32_t i = 0; i < 2 * osz_ - 1; i++) {
+    tree[i].parent = -1;
+    tree[i].left = -1;
+    tree[i].right = -1;
+    tree[i].count = 1e15;
+    tree[i].binary = false;
+  }
+  for (int32_t i = 0; i < osz_; i++) {
+    tree[i].count = counts[i];
+  }
+  int32_t leaf = osz_ - 1;
+  int32_t node = osz_;
+  for (int32_t i = osz_; i < 2 * osz_ - 1; i++) {
+    int32_t mini[2];
+    for (int32_t j = 0; j < 2; j++) {
+      if (leaf >= 0 && tree[leaf].count < tree[node].count) {
+        mini[j] = leaf--;
+      } else {
+        mini[j] = node++;
+      }
+    }
+    tree[i].left = mini[0];
+    tree[i].right = mini[1];
+    tree[i].count = tree[mini[0]].count + tree[mini[1]].count;
+    tree[mini[0]].parent = i;
+    tree[mini[1]].parent = i;
+    tree[mini[1]].binary = true;
+  }
+  for (int32_t i = 0; i < osz_; i++) {
+    std::vector<int32_t> path;
+    std::vector<bool> code;
+    int32_t j = i;
+    while (tree[j].parent != -1) {
+      path.push_back(tree[j].parent - osz_);
+      code.push_back(tree[j].binary);
+      j = tree[j].parent;
+    }
+    paths.push_back(path);
+    codes.push_back(code);
+  }
+}
+
+real Model::getLoss() const {
+  return loss_ / nexamples_;
+}
+
+void Model::initSigmoid() {
+  t_sigmoid = new real[SIGMOID_TABLE_SIZE + 1];
+  for (int i = 0; i < SIGMOID_TABLE_SIZE + 1; i++) {
+    real x = real(i * 2 * MAX_SIGMOID) / SIGMOID_TABLE_SIZE - MAX_SIGMOID;
+    t_sigmoid[i] = 1.0 / (1.0 + std::exp(-x));
+  }
+}
+
+void Model::initLog() {
+  t_log = new real[LOG_TABLE_SIZE + 1];
+  for (int i = 0; i < LOG_TABLE_SIZE + 1; i++) {
+    real x = (real(i) + 1e-5) / LOG_TABLE_SIZE;
+    t_log[i] = std::log(x);
+  }
+}
+
+real Model::log(real x) const {
+  if (x > 1.0) {
+    return 0.0;
+  }
+  int i = int(x * LOG_TABLE_SIZE);
+  return t_log[i];
+}
+
+real Model::sigmoid(real x) const {
+  if (x < -MAX_SIGMOID) {
+    return 0.0;
+  } else if (x > MAX_SIGMOID) {
+    return 1.0;
+  } else {
+    int i = int((x + MAX_SIGMOID) * SIGMOID_TABLE_SIZE / MAX_SIGMOID / 2);
+    return t_sigmoid[i];
+  }
+}
+
+}
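Two details in model.cc are easy to miss. First, sigmoid() and log() never call into libm on the hot path: initSigmoid() and initLog() precompute (TABLE_SIZE + 1)-entry lookup tables and the accessors just quantize their argument to a bucket. Second, initTableNegatives() implements the count^0.5-weighted sampling distribution by materializing a table of roughly NEGATIVE_TABLE_SIZE label ids, shuffling it once, and having getNegative() walk it cyclically while skipping the true target. A standalone sketch of the sigmoid-table scheme, reusing the constants from model.h:

#include <cmath>
#include <cstdio>
#include <vector>

// Constants as #defined in model.h.
const int SIGMOID_TABLE_SIZE = 512;
const int MAX_SIGMOID = 8;

int main() {
  // Precompute sigmoid over [-MAX_SIGMOID, MAX_SIGMOID], as initSigmoid() does.
  std::vector<float> table(SIGMOID_TABLE_SIZE + 1);
  for (int i = 0; i <= SIGMOID_TABLE_SIZE; i++) {
    float x = float(i * 2 * MAX_SIGMOID) / SIGMOID_TABLE_SIZE - MAX_SIGMOID;
    table[i] = 1.0f / (1.0f + std::exp(-x));
  }

  // The in-range branch of Model::sigmoid(): map x to its bucket and look it up.
  float x = 1.25f;
  int i = int((x + MAX_SIGMOID) * SIGMOID_TABLE_SIZE / MAX_SIGMOID / 2);
  std::printf("table: %.6f  exact: %.6f\n",
              table[i], 1.0f / (1.0f + std::exp(-x)));
  return 0;
}

With 512 buckets over [-8, 8] the argument is quantized to within 1/32, which is ample precision for SGD updates; out-of-range inputs are clamped to 0 or 1 by the branches in Model::sigmoid().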
data/vendor/fasttext/model.h
@@ -0,0 +1,110 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ */
+
+#ifndef FASTTEXT_MODEL_H
+#define FASTTEXT_MODEL_H
+
+#include <vector>
+#include <random>
+#include <utility>
+#include <memory>
+
+#include "args.h"
+#include "matrix.h"
+#include "vector.h"
+#include "qmatrix.h"
+#include "real.h"
+
+#define SIGMOID_TABLE_SIZE 512
+#define MAX_SIGMOID 8
+#define LOG_TABLE_SIZE 512
+
+namespace fasttext {
+
+struct Node {
+  int32_t parent;
+  int32_t left;
+  int32_t right;
+  int64_t count;
+  bool binary;
+};
+
+class Model {
+  private:
+    std::shared_ptr<Matrix> wi_;
+    std::shared_ptr<Matrix> wo_;
+    std::shared_ptr<QMatrix> qwi_;
+    std::shared_ptr<QMatrix> qwo_;
+    std::shared_ptr<Args> args_;
+    Vector hidden_;
+    Vector output_;
+    Vector grad_;
+    int32_t hsz_;
+    int32_t osz_;
+    real loss_;
+    int64_t nexamples_;
+    real* t_sigmoid;
+    real* t_log;
+    // used for negative sampling:
+    std::vector<int32_t> negatives;
+    size_t negpos;
+    // used for hierarchical softmax:
+    std::vector< std::vector<int32_t> > paths;
+    std::vector< std::vector<bool> > codes;
+    std::vector<Node> tree;
+
+    static bool comparePairs(const std::pair<real, int32_t>&,
+                             const std::pair<real, int32_t>&);
+
+    int32_t getNegative(int32_t target);
+    void initSigmoid();
+    void initLog();
+
+    static const int32_t NEGATIVE_TABLE_SIZE = 10000000;
+
+  public:
+    Model(std::shared_ptr<Matrix>, std::shared_ptr<Matrix>,
+          std::shared_ptr<Args>, int32_t);
+    ~Model();
+
+    real binaryLogistic(int32_t, bool, real);
+    real negativeSampling(int32_t, real);
+    real hierarchicalSoftmax(int32_t, real);
+    real softmax(int32_t, real);
+
+    void predict(const std::vector<int32_t>&, int32_t,
+                 std::vector<std::pair<real, int32_t>>&,
+                 Vector&, Vector&) const;
+    void predict(const std::vector<int32_t>&, int32_t,
+                 std::vector<std::pair<real, int32_t>>&);
+    void dfs(int32_t, int32_t, real,
+             std::vector<std::pair<real, int32_t>>&,
+             Vector&) const;
+    void findKBest(int32_t, std::vector<std::pair<real, int32_t>>&,
+                   Vector&, Vector&) const;
+    void update(const std::vector<int32_t>&, int32_t, real);
+    void computeHidden(const std::vector<int32_t>&, Vector&) const;
+    void computeOutputSoftmax(Vector&, Vector&) const;
+    void computeOutputSoftmax();
+
+    void setTargetCounts(const std::vector<int64_t>&);
+    void initTableNegatives(const std::vector<int64_t>&);
+    void buildTree(const std::vector<int64_t>&);
+    real getLoss() const;
+    real sigmoid(real) const;
+    real log(real) const;
+
+    std::minstd_rand rng;
+    bool quant_;
+    void setQuantizePointer(std::shared_ptr<QMatrix>, std::shared_ptr<QMatrix>, bool);
+};
+
+}
+
+#endif
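Taken together, the header shows the whole training surface: construct a Model over shared input/output matrices, call setTargetCounts() once so the chosen loss can build its negative table or Huffman tree, then call update() per example and predict() at inference time. A hypothetical smoke test, assuming the vendored sources above are compiled into the same binary and that Args exposes dim, loss, and neg as public fields (as args.h in this tree does); none of this is part of the gem's public Ruby API:

#include <memory>
#include <vector>
#include "model.h"

using namespace fasttext;

int main() {
  auto args = std::make_shared<Args>();
  args->dim = 10;              // hidden size
  args->loss = loss_name::ns;  // negative sampling
  args->neg = 5;               // negatives per positive

  const int64_t vocab = 100;
  auto wi = std::make_shared<Matrix>(vocab, args->dim);  // input embeddings
  auto wo = std::make_shared<Matrix>(vocab, args->dim);  // output weights
  wi->uniform(1.0 / args->dim);
  wo->zero();

  Model model(wi, wo, args, /*seed=*/42);
  std::vector<int64_t> counts(vocab, 1);  // uniform counts for the demo
  model.setTargetCounts(counts);          // builds the negative table

  std::vector<int32_t> input = {3, 17, 42};        // context token ids
  model.update(input, /*target=*/7, /*lr=*/0.05);
  return model.getLoss() > 0 ? 0 : 1;              // loss should be positive
}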