ffi-fasttext 0.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/.gitignore +44 -0
- data/.travis.yml +5 -0
- data/Gemfile +6 -0
- data/LICENSE.txt +21 -0
- data/README.md +59 -0
- data/Rakefile +19 -0
- data/bin/console +14 -0
- data/bin/setup +8 -0
- data/ext/ffi/fasttext/Rakefile +71 -0
- data/ffi-fasttext.gemspec +40 -0
- data/lib/ffi/fasttext.rb +108 -0
- data/lib/ffi/fasttext/version.rb +5 -0
- data/vendor/fasttext/LICENSE +30 -0
- data/vendor/fasttext/PATENTS +33 -0
- data/vendor/fasttext/args.cc +250 -0
- data/vendor/fasttext/args.h +71 -0
- data/vendor/fasttext/dictionary.cc +475 -0
- data/vendor/fasttext/dictionary.h +112 -0
- data/vendor/fasttext/fasttext.cc +693 -0
- data/vendor/fasttext/fasttext.h +97 -0
- data/vendor/fasttext/ffi_fasttext.cc +66 -0
- data/vendor/fasttext/main.cc +270 -0
- data/vendor/fasttext/matrix.cc +144 -0
- data/vendor/fasttext/matrix.h +57 -0
- data/vendor/fasttext/model.cc +341 -0
- data/vendor/fasttext/model.h +110 -0
- data/vendor/fasttext/productquantizer.cc +211 -0
- data/vendor/fasttext/productquantizer.h +67 -0
- data/vendor/fasttext/qmatrix.cc +121 -0
- data/vendor/fasttext/qmatrix.h +65 -0
- data/vendor/fasttext/real.h +19 -0
- data/vendor/fasttext/utils.cc +29 -0
- data/vendor/fasttext/utils.h +25 -0
- data/vendor/fasttext/vector.cc +137 -0
- data/vendor/fasttext/vector.h +53 -0
- metadata +151 -0
@@ -0,0 +1,97 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) 2016-present, Facebook, Inc.
|
3
|
+
* All rights reserved.
|
4
|
+
*
|
5
|
+
* This source code is licensed under the BSD-style license found in the
|
6
|
+
* LICENSE file in the root directory of this source tree. An additional grant
|
7
|
+
* of patent rights can be found in the PATENTS file in the same directory.
|
8
|
+
*/
|
9
|
+
|
10
|
+
#ifndef FASTTEXT_FASTTEXT_H
|
11
|
+
#define FASTTEXT_FASTTEXT_H
|
12
|
+
|
13
|
+
#define FASTTEXT_VERSION 12 /* Version 1b */
|
14
|
+
#define FASTTEXT_FILEFORMAT_MAGIC_INT32 793712314
|
15
|
+
|
16
|
+
#include <time.h>
|
17
|
+
|
18
|
+
#include <atomic>
|
19
|
+
#include <memory>
|
20
|
+
#include <set>
|
21
|
+
|
22
|
+
#include "args.h"
|
23
|
+
#include "dictionary.h"
|
24
|
+
#include "matrix.h"
|
25
|
+
#include "qmatrix.h"
|
26
|
+
#include "model.h"
|
27
|
+
#include "real.h"
|
28
|
+
#include "utils.h"
|
29
|
+
#include "vector.h"
|
30
|
+
|
31
|
+
namespace fasttext {
|
32
|
+
|
33
|
+
class FastText {
|
34
|
+
private:
|
35
|
+
std::shared_ptr<Args> args_;
|
36
|
+
std::shared_ptr<Dictionary> dict_;
|
37
|
+
|
38
|
+
std::shared_ptr<Matrix> input_;
|
39
|
+
std::shared_ptr<Matrix> output_;
|
40
|
+
|
41
|
+
std::shared_ptr<QMatrix> qinput_;
|
42
|
+
std::shared_ptr<QMatrix> qoutput_;
|
43
|
+
|
44
|
+
std::shared_ptr<Model> model_;
|
45
|
+
|
46
|
+
std::atomic<int64_t> tokenCount;
|
47
|
+
clock_t start;
|
48
|
+
void signModel(std::ostream&);
|
49
|
+
bool checkModel(std::istream&);
|
50
|
+
|
51
|
+
bool quant_;
|
52
|
+
int32_t version;
|
53
|
+
|
54
|
+
public:
|
55
|
+
FastText();
|
56
|
+
|
57
|
+
void getVector(Vector&, const std::string&) const;
|
58
|
+
std::shared_ptr<const Dictionary> getDictionary() const;
|
59
|
+
void saveVectors();
|
60
|
+
void saveOutput();
|
61
|
+
void saveModel();
|
62
|
+
void loadModel(std::istream&);
|
63
|
+
void loadModel(const std::string&);
|
64
|
+
void printInfo(real, real);
|
65
|
+
|
66
|
+
void supervised(Model&, real, const std::vector<int32_t>&,
|
67
|
+
const std::vector<int32_t>&);
|
68
|
+
void cbow(Model&, real, const std::vector<int32_t>&);
|
69
|
+
void skipgram(Model&, real, const std::vector<int32_t>&);
|
70
|
+
std::vector<int32_t> selectEmbeddings(int32_t) const;
|
71
|
+
void quantize(std::shared_ptr<Args>);
|
72
|
+
void test(std::istream&, int32_t);
|
73
|
+
void predict(std::istream&, int32_t, bool);
|
74
|
+
void predict(
|
75
|
+
std::istream&,
|
76
|
+
int32_t,
|
77
|
+
std::vector<std::pair<real, std::string>>&) const;
|
78
|
+
void wordVectors();
|
79
|
+
void sentenceVectors();
|
80
|
+
void ngramVectors(std::string);
|
81
|
+
void textVectors();
|
82
|
+
void printWordVectors();
|
83
|
+
void printSentenceVectors();
|
84
|
+
void precomputeWordVectors(Matrix&);
|
85
|
+
void findNN(const Matrix&, const Vector&, int32_t,
|
86
|
+
const std::set<std::string>&);
|
87
|
+
void nn(int32_t);
|
88
|
+
void analogies(int32_t);
|
89
|
+
void trainThread(int32_t);
|
90
|
+
void train(std::shared_ptr<Args>);
|
91
|
+
|
92
|
+
void loadVectors(std::string);
|
93
|
+
int getDimension() const;
|
94
|
+
};
|
95
|
+
|
96
|
+
}
|
97
|
+
#endif
|
@@ -0,0 +1,66 @@
|
|
1
|
+
#include <algorithm>
|
2
|
+
#include <iostream>
|
3
|
+
#include <cstring>
|
4
|
+
#include <math.h>
|
5
|
+
#include <sstream>
|
6
|
+
#include <string>
|
7
|
+
#include <vector>
|
8
|
+
|
9
|
+
#include "real.h"
|
10
|
+
#include "fasttext.h"
|
11
|
+
|
12
|
+
// Linkage helpers: under a C++ compiler these expand to extern "C" markers,
// otherwise they expand to nothing.
#if defined(__cplusplus)
#  define EXTERN_C extern "C"
#  define EXTERN_C_BEGIN extern "C" {
#  define EXTERN_C_END }
#else
#  define EXTERN_C
#  define EXTERN_C_BEGIN
#  define EXTERN_C_END
#endif
|
21
|
+
|
22
|
+
EXTERN_C_BEGIN
|
23
|
+
/**
 * Allocates a FastText instance and loads the model named by `model_name`.
 *
 * @param model_name NUL-terminated model filename; may be null.
 * @return A heap-allocated instance that the caller releases with destroy(),
 *         or nullptr when `model_name` is null or when construction/loading
 *         throws a C++ exception.
 */
fasttext::FastText* create(const char* model_name) {
  // Constructing std::string from a null pointer is undefined behaviour.
  if (model_name == nullptr) {
    return nullptr;
  }

  fasttext::FastText* new_fasttext = nullptr;
  try {
    new_fasttext = new fasttext::FastText();
    new_fasttext->loadModel(std::string(model_name));
    return new_fasttext;
  } catch (...) {
    // This function has C linkage: an exception must not unwind into the
    // foreign caller, and the half-built instance must not leak.
    delete new_fasttext;
    return nullptr;
  }
}
|
29
|
+
|
30
|
+
/**
 * Deletes a FastText instance. A null pointer is accepted and ignored.
 */
void destroy(fasttext::FastText* destroy_fasttext) {
  if (destroy_fasttext == nullptr) {
    return;
  }
  delete destroy_fasttext;
}
|
33
|
+
|
34
|
+
/**
 * Frees a character buffer that was allocated with new[].
 * A null pointer is accepted.
 */
void predict_string_free(const char* match) {
  // delete[] on a null pointer does nothing, so no explicit test is needed.
  delete[] match;
}
|
39
|
+
|
40
|
+
/**
 * Runs a prediction for one line of text and returns the result as a newly
 * allocated C string.
 *
 * The result has the form "<label> <value> <label> <value> " (each item is
 * followed by a single space), where <value> is std::exp() of the real
 * number reported for that label.
 *
 * @param fasttext_pointer      Instance to query; may be null.
 * @param key                   NUL-terminated input text; may be null.
 * @param number_of_predictions Forwarded unchanged to FastText::predict.
 * @return Buffer allocated with new[] (release it with predict_string_free),
 *         or nullptr when an argument is null, when there are no
 *         predictions, or when a C++ exception is raised.
 */
const char* predict(fasttext::FastText* fasttext_pointer, const char* key, int32_t number_of_predictions) {
  if (fasttext_pointer == nullptr || key == nullptr) {
    return nullptr;
  }

  try {
    // Append the line terminator to the string *before* handing it to the
    // stream. Calling str(key) and then streaming a newline would write at
    // offset 0 (str() resets the put position) and clobber the first
    // character of the key.
    std::string line(key);
    line.push_back('\n');
    std::istringstream key_stream(line);

    std::vector<std::pair<fasttext::real, std::string>> predictions;
    fasttext_pointer->predict(key_stream, number_of_predictions, predictions);

    std::ostringstream output_stream;
    for (const auto& prediction : predictions) {
      output_stream << prediction.second << " " << std::exp(prediction.first) << " ";
    }

    const std::string formatted = output_stream.str();
    if (formatted.empty()) {
      return nullptr;
    }

    // c_str() is NUL-terminated, so copying size() + 1 bytes also copies the
    // terminator.
    char* val = new char[formatted.size() + 1];
    memcpy(val, formatted.c_str(), formatted.size() + 1);
    return val;
  } catch (...) {
    // This function has C linkage: never let an exception reach the caller.
    return nullptr;
  }
}
|
65
|
+
|
66
|
+
EXTERN_C_END
|
@@ -0,0 +1,270 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) 2016-present, Facebook, Inc.
|
3
|
+
* All rights reserved.
|
4
|
+
*
|
5
|
+
* This source code is licensed under the BSD-style license found in the
|
6
|
+
* LICENSE file in the root directory of this source tree. An additional grant
|
7
|
+
* of patent rights can be found in the PATENTS file in the same directory.
|
8
|
+
*/
|
9
|
+
|
10
|
+
#include <iostream>
|
11
|
+
|
12
|
+
#include "fasttext.h"
|
13
|
+
#include "args.h"
|
14
|
+
|
15
|
+
using namespace fasttext;
|
16
|
+
|
17
|
+
// Writes the top-level command overview to standard error.
void printUsage() {
  const char* const overview =
    "usage: fasttext <command> <args>\n\n"
    "The commands supported by fasttext are:\n\n"
    " supervised train a supervised classifier\n"
    " quantize quantize a model to reduce the memory usage\n"
    " test evaluate a supervised classifier\n"
    " predict predict most likely labels\n"
    " predict-prob predict most likely labels with probabilities\n"
    " skipgram train a skipgram model\n"
    " cbow train a cbow model\n"
    " print-word-vectors print word vectors given a trained model\n"
    " print-sentence-vectors print sentence vectors given a trained model\n"
    " nn query for nearest neighbors\n"
    " analogies query for analogies\n";
  std::cerr << overview << std::endl;
}
|
34
|
+
|
35
|
+
// Writes the one-line synopsis of the quantize command to standard error.
void printQuantizeUsage() {
  const char* const synopsis = "usage: fasttext quantize <args>";
  std::cerr << synopsis << std::endl;
}
|
40
|
+
|
41
|
+
// Writes the synopsis of the test command to standard error.
void printTestUsage() {
  const char* const synopsis =
    "usage: fasttext test <model> <test-data> [<k>]\n\n"
    " <model> model filename\n"
    " <test-data> test data filename (if -, read from stdin)\n"
    " <k> (optional; 1 by default) predict top k labels\n";
  std::cerr << synopsis << std::endl;
}
|
49
|
+
|
50
|
+
// Writes the synopsis shared by predict and predict-prob to standard error.
void printPredictUsage() {
  const char* const synopsis =
    "usage: fasttext predict[-prob] <model> <test-data> [<k>]\n\n"
    " <model> model filename\n"
    " <test-data> test data filename (if -, read from stdin)\n"
    " <k> (optional; 1 by default) predict top k labels\n";
  std::cerr << synopsis << std::endl;
}
|
58
|
+
|
59
|
+
// Writes the synopsis of the print-word-vectors command to standard error.
void printPrintWordVectorsUsage() {
  const char* const synopsis =
    "usage: fasttext print-word-vectors <model>\n\n"
    " <model> model filename\n";
  std::cerr << synopsis << std::endl;
}
|
65
|
+
|
66
|
+
// Writes the synopsis of the print-sentence-vectors command to standard error.
void printPrintSentenceVectorsUsage() {
  const char* const synopsis =
    "usage: fasttext print-sentence-vectors <model>\n\n"
    " <model> model filename\n";
  std::cerr << synopsis << std::endl;
}
|
72
|
+
|
73
|
+
// Writes the synopsis of the print-ngrams command to standard error.
void printPrintNgramsUsage() {
  const char* const synopsis =
    "usage: fasttext print-ngrams <model> <word>\n\n"
    " <model> model filename\n"
    " <word> word to print\n";
  std::cerr << synopsis << std::endl;
}
|
80
|
+
|
81
|
+
void quantize(const std::vector<std::string>& args) {
|
82
|
+
std::shared_ptr<Args> a = std::make_shared<Args>();
|
83
|
+
if (args.size() < 3) {
|
84
|
+
printQuantizeUsage();
|
85
|
+
a->printHelp();
|
86
|
+
exit(EXIT_FAILURE);
|
87
|
+
}
|
88
|
+
a->parseArgs(args);
|
89
|
+
FastText fasttext;
|
90
|
+
fasttext.quantize(a);
|
91
|
+
exit(0);
|
92
|
+
}
|
93
|
+
|
94
|
+
// Writes the synopsis of the nn command to standard error.
//
// The text goes to std::cerr, like every other usage message in this file,
// because it is only shown right before the process exits with a failure
// status. <k> is written as [<k>] since the argument is optional.
void printNNUsage() {
  std::cerr
    << "usage: fasttext nn <model> [<k>]\n\n"
    << " <model> model filename\n"
    << " <k> (optional; 10 by default) predict top k labels\n"
    << std::endl;
}
|
101
|
+
|
102
|
+
// Writes the synopsis of the analogies command to standard error.
//
// The text goes to std::cerr, like every other usage message in this file,
// because it is only shown right before the process exits with a failure
// status. <k> is written as [<k>] since the argument is optional.
void printAnalogiesUsage() {
  std::cerr
    << "usage: fasttext analogies <model> [<k>]\n\n"
    << " <model> model filename\n"
    << " <k> (optional; 10 by default) predict top k labels\n"
    << std::endl;
}
|
109
|
+
|
110
|
+
void test(const std::vector<std::string>& args) {
|
111
|
+
if (args.size() < 4 || args.size() > 5) {
|
112
|
+
printTestUsage();
|
113
|
+
exit(EXIT_FAILURE);
|
114
|
+
}
|
115
|
+
int32_t k = 1;
|
116
|
+
if (args.size() >= 5) {
|
117
|
+
k = std::stoi(args[4]);
|
118
|
+
}
|
119
|
+
|
120
|
+
FastText fasttext;
|
121
|
+
fasttext.loadModel(args[2]);
|
122
|
+
|
123
|
+
std::string infile = args[3];
|
124
|
+
if (infile == "-") {
|
125
|
+
fasttext.test(std::cin, k);
|
126
|
+
} else {
|
127
|
+
std::ifstream ifs(infile);
|
128
|
+
if (!ifs.is_open()) {
|
129
|
+
std::cerr << "Test file cannot be opened!" << std::endl;
|
130
|
+
exit(EXIT_FAILURE);
|
131
|
+
}
|
132
|
+
fasttext.test(ifs, k);
|
133
|
+
ifs.close();
|
134
|
+
}
|
135
|
+
exit(0);
|
136
|
+
}
|
137
|
+
|
138
|
+
void predict(const std::vector<std::string>& args) {
|
139
|
+
if (args.size() < 4 || args.size() > 5) {
|
140
|
+
printPredictUsage();
|
141
|
+
exit(EXIT_FAILURE);
|
142
|
+
}
|
143
|
+
int32_t k = 1;
|
144
|
+
if (args.size() >= 5) {
|
145
|
+
k = std::stoi(args[4]);
|
146
|
+
}
|
147
|
+
|
148
|
+
bool print_prob = args[1] == "predict-prob";
|
149
|
+
FastText fasttext;
|
150
|
+
fasttext.loadModel(std::string(args[2]));
|
151
|
+
|
152
|
+
std::string infile(args[3]);
|
153
|
+
if (infile == "-") {
|
154
|
+
fasttext.predict(std::cin, k, print_prob);
|
155
|
+
} else {
|
156
|
+
std::ifstream ifs(infile);
|
157
|
+
if (!ifs.is_open()) {
|
158
|
+
std::cerr << "Input file cannot be opened!" << std::endl;
|
159
|
+
exit(EXIT_FAILURE);
|
160
|
+
}
|
161
|
+
fasttext.predict(ifs, k, print_prob);
|
162
|
+
ifs.close();
|
163
|
+
}
|
164
|
+
|
165
|
+
exit(0);
|
166
|
+
}
|
167
|
+
|
168
|
+
void printWordVectors(const std::vector<std::string> args) {
|
169
|
+
if (args.size() != 3) {
|
170
|
+
printPrintWordVectorsUsage();
|
171
|
+
exit(EXIT_FAILURE);
|
172
|
+
}
|
173
|
+
FastText fasttext;
|
174
|
+
fasttext.loadModel(std::string(args[2]));
|
175
|
+
fasttext.printWordVectors();
|
176
|
+
exit(0);
|
177
|
+
}
|
178
|
+
|
179
|
+
void printSentenceVectors(const std::vector<std::string> args) {
|
180
|
+
if (args.size() != 3) {
|
181
|
+
printPrintSentenceVectorsUsage();
|
182
|
+
exit(EXIT_FAILURE);
|
183
|
+
}
|
184
|
+
FastText fasttext;
|
185
|
+
fasttext.loadModel(std::string(args[2]));
|
186
|
+
fasttext.printSentenceVectors();
|
187
|
+
exit(0);
|
188
|
+
}
|
189
|
+
|
190
|
+
void printNgrams(const std::vector<std::string> args) {
|
191
|
+
if (args.size() != 4) {
|
192
|
+
printPrintNgramsUsage();
|
193
|
+
exit(EXIT_FAILURE);
|
194
|
+
}
|
195
|
+
FastText fasttext;
|
196
|
+
fasttext.loadModel(std::string(args[2]));
|
197
|
+
fasttext.ngramVectors(std::string(args[3]));
|
198
|
+
exit(0);
|
199
|
+
}
|
200
|
+
|
201
|
+
// Handler for the "nn" command.
//
// Expects args = {program, "nn", <model>[, <k>]}; k defaults to 10. Any
// other argument count prints the usage text and exits with a failure
// status. Terminates the process in every path.
//
// The argument vector is taken by const reference; the caller's vector is
// only read, so copying it was unnecessary.
void nn(const std::vector<std::string>& args) {
  int32_t k = 10;
  if (args.size() == 4) {
    k = std::stoi(args[3]);
  } else if (args.size() != 3) {
    printNNUsage();
    exit(EXIT_FAILURE);
  }
  FastText fasttext;
  fasttext.loadModel(args[2]);
  fasttext.nn(k);
  exit(0);
}
|
216
|
+
|
217
|
+
// Handler for the "analogies" command.
//
// Expects args = {program, "analogies", <model>[, <k>]}; k defaults to 10.
// Any other argument count prints the usage text and exits with a failure
// status. Terminates the process in every path.
//
// The argument vector is taken by const reference; the caller's vector is
// only read, so copying it was unnecessary.
void analogies(const std::vector<std::string>& args) {
  int32_t k = 10;
  if (args.size() == 4) {
    k = std::stoi(args[3]);
  } else if (args.size() != 3) {
    printAnalogiesUsage();
    exit(EXIT_FAILURE);
  }
  FastText fasttext;
  fasttext.loadModel(args[2]);
  fasttext.analogies(k);
  exit(0);
}
|
232
|
+
|
233
|
+
void train(const std::vector<std::string> args) {
|
234
|
+
std::shared_ptr<Args> a = std::make_shared<Args>();
|
235
|
+
a->parseArgs(args);
|
236
|
+
FastText fasttext;
|
237
|
+
fasttext.train(a);
|
238
|
+
}
|
239
|
+
|
240
|
+
int main(int argc, char** argv) {
|
241
|
+
std::vector<std::string> args(argv, argv + argc);
|
242
|
+
if (args.size() < 2) {
|
243
|
+
printUsage();
|
244
|
+
exit(EXIT_FAILURE);
|
245
|
+
}
|
246
|
+
std::string command(args[1]);
|
247
|
+
if (command == "skipgram" || command == "cbow" || command == "supervised") {
|
248
|
+
train(args);
|
249
|
+
} else if (command == "test") {
|
250
|
+
test(args);
|
251
|
+
} else if (command == "quantize") {
|
252
|
+
quantize(args);
|
253
|
+
} else if (command == "print-word-vectors") {
|
254
|
+
printWordVectors(args);
|
255
|
+
} else if (command == "print-sentence-vectors") {
|
256
|
+
printSentenceVectors(args);
|
257
|
+
} else if (command == "print-ngrams") {
|
258
|
+
printNgrams(args);
|
259
|
+
} else if (command == "nn") {
|
260
|
+
nn(args);
|
261
|
+
} else if (command == "analogies") {
|
262
|
+
analogies(args);
|
263
|
+
} else if (command == "predict" || command == "predict-prob" ) {
|
264
|
+
predict(args);
|
265
|
+
} else {
|
266
|
+
printUsage();
|
267
|
+
exit(EXIT_FAILURE);
|
268
|
+
}
|
269
|
+
return 0;
|
270
|
+
}
|
@@ -0,0 +1,144 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) 2016-present, Facebook, Inc.
|
3
|
+
* All rights reserved.
|
4
|
+
*
|
5
|
+
* This source code is licensed under the BSD-style license found in the
|
6
|
+
* LICENSE file in the root directory of this source tree. An additional grant
|
7
|
+
* of patent rights can be found in the PATENTS file in the same directory.
|
8
|
+
*/
|
9
|
+
|
10
|
+
#include "matrix.h"
|
11
|
+
|
12
|
+
#include <assert.h>
|
13
|
+
|
14
|
+
#include <random>
|
15
|
+
|
16
|
+
#include "utils.h"
|
17
|
+
#include "vector.h"
|
18
|
+
|
19
|
+
namespace fasttext {
|
20
|
+
|
21
|
+
// Default constructor: an empty 0 x 0 matrix with no buffer.
Matrix::Matrix() {
  data_ = nullptr;
  m_ = 0;
  n_ = 0;
}
|
26
|
+
|
27
|
+
// Allocates storage for an m x n matrix. The elements are left
// uninitialized.
Matrix::Matrix(int64_t m, int64_t n) {
  m_ = m;
  n_ = n;
  const int64_t elementCount = m * n;
  data_ = new real[elementCount];
}
|
32
|
+
|
33
|
+
// Copy constructor: allocates a fresh buffer of the same dimensions and
// copies every element from `other`.
Matrix::Matrix(const Matrix& other) {
  m_ = other.m_;
  n_ = other.n_;
  const int64_t elementCount = m_ * n_;
  data_ = new real[elementCount];
  int64_t idx = 0;
  while (idx < elementCount) {
    data_[idx] = other.data_[idx];
    ++idx;
  }
}
|
41
|
+
|
42
|
+
// Copy assignment: builds a full copy of `other`, then trades buffers with
// it. The previous buffer of this matrix ends up owned by the local copy and
// is freed when that copy goes out of scope.
Matrix& Matrix::operator=(const Matrix& other) {
  Matrix replacement(other);
  std::swap(data_, replacement.data_);
  m_ = replacement.m_;
  n_ = replacement.n_;
  return *this;
}
|
49
|
+
|
50
|
+
// Destructor: frees the element buffer (a null buffer is fine).
Matrix::~Matrix() {
  delete[] this->data_;
}
|
53
|
+
|
54
|
+
void Matrix::zero() {
|
55
|
+
for (int64_t i = 0; i < (m_ * n_); i++) {
|
56
|
+
data_[i] = 0.0;
|
57
|
+
}
|
58
|
+
}
|
59
|
+
|
60
|
+
// Fills the matrix with pseudo-random values drawn from a uniform
// distribution over [-a, a). The generator is always seeded with 1, so the
// sequence is the same on every call.
void Matrix::uniform(real a) {
  std::minstd_rand engine(1);
  std::uniform_real_distribution<> distribution(-a, a);
  const int64_t elementCount = m_ * n_;
  for (int64_t idx = 0; idx < elementCount; ++idx) {
    data_[idx] = distribution(engine);
  }
}
|
67
|
+
|
68
|
+
// Returns the dot product of row `i` with `vec`.
// Preconditions (checked with assert): 0 <= i < m_ and vec.size() == n_.
real Matrix::dotRow(const Vector& vec, int64_t i) const {
  assert(i >= 0);
  assert(i < m_);
  assert(vec.size() == n_);
  real sum = 0.0;
  int64_t col = 0;
  while (col < n_) {
    sum += at(i, col) * vec.data_[col];
    ++col;
  }
  return sum;
}
|
78
|
+
|
79
|
+
void Matrix::addRow(const Vector& vec, int64_t i, real a) {
|
80
|
+
assert(i >= 0);
|
81
|
+
assert(i < m_);
|
82
|
+
assert(vec.size() == n_);
|
83
|
+
for (int64_t j = 0; j < n_; j++) {
|
84
|
+
data_[i * n_ + j] += a * vec.data_[j];
|
85
|
+
}
|
86
|
+
}
|
87
|
+
|
88
|
+
void Matrix::multiplyRow(const Vector& nums, int64_t ib, int64_t ie) {
|
89
|
+
if (ie == -1) {ie = m_;}
|
90
|
+
assert(ie <= nums.size());
|
91
|
+
for (auto i = ib; i < ie; i++) {
|
92
|
+
real n = nums[i-ib];
|
93
|
+
if (n != 0) {
|
94
|
+
for (auto j = 0; j < n_; j++) {
|
95
|
+
at(i, j) *= n;
|
96
|
+
}
|
97
|
+
}
|
98
|
+
}
|
99
|
+
}
|
100
|
+
|
101
|
+
void Matrix::divideRow(const Vector& denoms, int64_t ib, int64_t ie) {
|
102
|
+
if (ie == -1) {ie = m_;}
|
103
|
+
assert(ie <= denoms.size());
|
104
|
+
for (auto i = ib; i < ie; i++) {
|
105
|
+
real n = denoms[i-ib];
|
106
|
+
if (n != 0) {
|
107
|
+
for (auto j = 0; j < n_; j++) {
|
108
|
+
at(i, j) /= n;
|
109
|
+
}
|
110
|
+
}
|
111
|
+
}
|
112
|
+
}
|
113
|
+
|
114
|
+
// Returns the Euclidean (L2) norm of row `i`. The sum of squares is
// accumulated in double precision and the square root is converted to
// `real` on return.
real Matrix::l2NormRow(int64_t i) const {
  auto norm = 0.0;
  // The column index is int64_t to match n_; `auto j = 0` would deduce int.
  for (int64_t j = 0; j < n_; j++) {
    const real v = at(i, j);
    norm += v * v;
  }
  return std::sqrt(norm);
}
|
122
|
+
|
123
|
+
// Stores the L2 norm of every row into `norms` (norms[i] is the norm of
// row i).
// Precondition (checked with assert): norms.size() == m_.
void Matrix::l2NormRow(Vector& norms) const {
  assert(norms.size() == m_);
  // The row index is int64_t to match m_; `auto i = 0` would deduce int.
  for (int64_t i = 0; i < m_; i++) {
    norms[i] = l2NormRow(i);
  }
}
|
129
|
+
|
130
|
+
// Writes the matrix to `out` as raw bytes: the row count, the column count
// (each sizeof(int64_t) bytes) and then the m_ * n_ elements.
void Matrix::save(std::ostream& out) {
  const char* const rowCountBytes = reinterpret_cast<const char*>(&m_);
  const char* const colCountBytes = reinterpret_cast<const char*>(&n_);
  const char* const elementBytes = reinterpret_cast<const char*>(data_);

  out.write(rowCountBytes, sizeof(int64_t));
  out.write(colCountBytes, sizeof(int64_t));
  out.write(elementBytes, m_ * n_ * sizeof(real));
}
|
135
|
+
|
136
|
+
// Reads a matrix from `in`, expecting raw bytes: the row count, the column
// count (each sizeof(int64_t) bytes) and then m_ * n_ elements. The current
// buffer is replaced by a freshly allocated one.
void Matrix::load(std::istream& in) {
  in.read((char*) &m_, sizeof(int64_t));
  in.read((char*) &n_, sizeof(int64_t));
  delete[] data_;
  // Clear the pointer before allocating: if new[] throws, the destructor
  // must not free the old buffer a second time.
  data_ = nullptr;
  data_ = new real[m_ * n_];
  in.read((char*) data_, m_ * n_ * sizeof(real));
}
|
143
|
+
|
144
|
+
}
|