libmf 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,207 @@
1
+ #include <cstring>
2
+ #include <cstdlib>
3
+ #include <fstream>
4
+ #include <iostream>
5
+ #include <string>
6
+ #include <iomanip>
7
+ #include <stdexcept>
8
+ #include <vector>
9
+
10
+ #include "mf.h"
11
+
12
+ using namespace std;
13
+ using namespace mf;
14
+
15
+ struct Option
16
+ {
17
+ Option() : eval(RMSE) {}
18
+ string test_path, model_path, output_path;
19
+ mf_int eval;
20
+ };
21
+
22
// Builds the usage text for mf-predict, including the numeric codes accepted
// by -e. parse_option throws it as the message when no arguments are given.
string predict_help()
{
    string text;
    text += "usage: mf-predict [options] test_file model_file [output_file]\n";
    text += "\n";
    text += "options:\n";
    text += "-e <eval>: specify the evaluation criterion (default 0)\n";
    text += "\t 0 -- root mean square error\n";
    text += "\t 1 -- mean absolute error\n";
    text += "\t 2 -- generalized KL-divergence\n";
    text += "\t 5 -- logarithmic error\n";
    text += "\t 6 -- accuracy\n";
    text += "\t10 -- row-wise mean percentile rank\n";
    text += "\t11 -- column-wise mean percentile rank\n";
    text += "\t12 -- row-wise area under the curve\n";
    text += "\t13 -- column-wise area under the curve\n";
    return text;
}
39
+
40
+ Option parse_option(int argc, char **argv)
41
+ {
42
+ vector<string> args;
43
+ for(int i = 0; i < argc; i++)
44
+ args.push_back(string(argv[i]));
45
+
46
+ if(argc == 1)
47
+ throw invalid_argument(predict_help());
48
+
49
+ Option option;
50
+
51
+ mf_int i;
52
+ for(i = 1; i < argc; i++)
53
+ {
54
+ if(args[i].compare("-e") == 0)
55
+ {
56
+ if((i+1) >= argc)
57
+ throw invalid_argument("need to specify evaluation criterion after -e");
58
+ i++;
59
+ option.eval = atoi(argv[i]);
60
+ if(option.eval != RMSE &&
61
+ option.eval != MAE &&
62
+ option.eval != GKL &&
63
+ option.eval != LOGLOSS &&
64
+ option.eval != ACC &&
65
+ option.eval != ROW_AUC &&
66
+ option.eval != COL_AUC &&
67
+ option.eval != ROW_MPR &&
68
+ option.eval != COL_MPR)
69
+ throw invalid_argument("unknown evaluation criterion");
70
+ }
71
+ else
72
+ break;
73
+ }
74
+ if(i >= argc-1)
75
+ throw invalid_argument("testing data and model file not specified");
76
+ option.test_path = string(args[i++]);
77
+ option.model_path = string(args[i++]);
78
+
79
+ if(i < argc)
80
+ {
81
+ option.output_path = string(args[i]);
82
+ }
83
+ else if(i == argc)
84
+ {
85
+ const char *ptr = strrchr(&*option.test_path.begin(), '/');
86
+ if(!ptr)
87
+ ptr = option.test_path.c_str();
88
+ else
89
+ ++ptr;
90
+ option.output_path = string(ptr) + ".out";
91
+ }
92
+ else
93
+ {
94
+ throw invalid_argument("invalid argument");
95
+ }
96
+
97
+ return option;
98
+ }
99
+
100
+ void predict(string test_path, string model_path, string output_path, mf_int eval)
101
+ {
102
+ mf_problem prob = read_problem(test_path);
103
+
104
+ ofstream f_out(output_path);
105
+ if(!f_out.is_open())
106
+ throw runtime_error("cannot open " + output_path);
107
+
108
+ mf_model *model = mf_load_model(model_path.c_str());
109
+ if(model == nullptr)
110
+ throw runtime_error("cannot load model from " + model_path);
111
+
112
+ for(mf_int i = 0; i < prob.nnz; i++)
113
+ {
114
+ mf_float r = mf_predict(model, prob.R[i].u, prob.R[i].v);
115
+ f_out << r << endl;
116
+ }
117
+
118
+ switch(eval)
119
+ {
120
+ case RMSE:
121
+ {
122
+ auto rmse = calc_rmse(&prob, model);
123
+ cout << fixed << setprecision(4) << "RMSE = " << rmse << endl;
124
+ break;
125
+ }
126
+ case MAE:
127
+ {
128
+ auto mae = calc_mae(&prob, model);
129
+ cout << fixed << setprecision(4) << "MAE = " << mae << endl;
130
+ break;
131
+ }
132
+ case GKL:
133
+ {
134
+ auto gkl = calc_gkl(&prob, model);
135
+ cout << fixed << setprecision(4) << "GKL = " << gkl << endl;
136
+ break;
137
+ }
138
+ case LOGLOSS:
139
+ {
140
+ auto logloss = calc_logloss(&prob, model);
141
+ cout << fixed << setprecision(4) << "LOGLOSS = " << logloss << endl;
142
+ break;
143
+ }
144
+ case ACC:
145
+ {
146
+ auto acc = calc_accuracy(&prob, model);
147
+ cout << fixed << setprecision(4) << "ACCURACY = " << acc << endl;
148
+ break;
149
+ }
150
+ case ROW_AUC:
151
+ {
152
+ auto row_wise_auc = calc_auc(&prob, model, false);
153
+ cout << fixed << setprecision(4) << "Row-wise AUC = " << row_wise_auc << endl;
154
+ break;
155
+ }
156
+ case COL_AUC:
157
+ {
158
+ auto col_wise_auc = calc_auc(&prob, model, true);
159
+ cout << fixed << setprecision(4) << "Colmn-wise AUC = " << col_wise_auc << endl;
160
+ break;
161
+ }
162
+ case ROW_MPR:
163
+ {
164
+ auto row_wise_mpr = calc_mpr(&prob, model, false);
165
+ cout << fixed << setprecision(4) << "Row-wise MPR = " << row_wise_mpr << endl;
166
+ break;
167
+ }
168
+ case COL_MPR:
169
+ {
170
+ auto col_wise_mpr = calc_mpr(&prob, model, true);
171
+ cout << fixed << setprecision(4) << "Column-wise MPR = " << col_wise_mpr << endl;
172
+ break;
173
+ }
174
+ default:
175
+ {
176
+ throw invalid_argument("unknown evaluation criterion");
177
+ break;
178
+ }
179
+ }
180
+ mf_destroy_model(&model);
181
+ }
182
+
183
+ int main(int argc, char **argv)
184
+ {
185
+ Option option;
186
+ try
187
+ {
188
+ option = parse_option(argc, argv);
189
+ }
190
+ catch(invalid_argument &e)
191
+ {
192
+ cout << e.what() << endl;
193
+ return 1;
194
+ }
195
+
196
+ try
197
+ {
198
+ predict(option.test_path, option.model_path, option.output_path, option.eval);
199
+ }
200
+ catch(runtime_error &e)
201
+ {
202
+ cout << e.what() << endl;
203
+ return 1;
204
+ }
205
+
206
+ return 0;
207
+ }
@@ -0,0 +1,378 @@
1
+ #include <algorithm>
2
+ #include <cctype>
3
+ #include <cmath>
4
+ #include <cstring>
5
+ #include <cstdlib>
6
+ #include <fstream>
7
+ #include <iostream>
8
+ #include <stdexcept>
9
+ #include <string>
10
+ #include <vector>
11
+
12
+ #include "mf.h"
13
+
14
+ using namespace std;
15
+ using namespace mf;
16
+
17
+ struct Option
18
+ {
19
+ Option() : param(mf_get_default_param()), nr_folds(1), on_disk(false), do_cv(false) {}
20
+ string tr_path, va_path, model_path;
21
+ mf_parameter param;
22
+ mf_int nr_folds;
23
+ bool on_disk;
24
+ bool do_cv;
25
+ };
26
+
27
// Builds the usage text for mf-train, listing every option and the numeric
// loss-function codes accepted by -f. parse_option throws it as the message
// when no arguments are given.
string train_help()
{
    // One entry per output line; a newline is appended to each.
    static const char *const lines[] = {
        "usage: mf-train [options] training_set_file [model_file]",
        "",
        "options:",
        "-l1 <lambda>,<lambda>: set L1-regularization parameters for P and Q (default 0)",
        " P and Q share the same lambda if only one lambda is specified",
        "-l2 <lambda>,<lambda>: set L2-regularization parameters for P and Q (default 0.1)",
        " P and Q share the same lambda if only one lambda is specified",
        "-f <loss>: set loss function (default 0)",
        " for real-valued matrix factorization",
        "\t 0 -- squared error (L2-norm)",
        "\t 1 -- absolute error (L1-norm)",
        "\t 2 -- generalized KL-divergence",
        " for binary matrix factorization",
        "\t 5 -- logarithmic loss",
        "\t 6 -- squared hinge loss",
        "\t 7 -- hinge loss",
        " for one-class matrix factorization",
        "\t10 -- row-oriented pairwise logarithmic loss",
        "\t11 -- column-oriented pairwise logarithmic loss",
        "\t12 -- squared error (L2-norm)",
        "-k <dimensions>: set number of dimensions (default 8)",
        "-t <iter>: set number of iterations (default 20)",
        "-r <eta>: set learning rate (default 0.1)",
        "-a <alpha>: set coefficient of negative entries' loss (default 1)",
        "-c <c>: set value of negative entries (default 0.0001). Positive entry is always 1.",
        "-s <threads>: set number of threads (default 12)",
        "-n <bins>: set number of bins (may be adjusted by LIBMF)",
        "-p <path>: set path to the validation set",
        "-v <fold>: set number of folds for cross validation",
        "--quiet: quiet mode (no outputs)",
        "--nmf: perform non-negative matrix factorization",
        "--disk: perform disk-level training (will generate a buffer file)",
    };

    string usage;
    for(const char *line : lines)
    {
        usage += line;
        usage += '\n';
    }
    return usage;
}
63
+
64
// Loose check used by the option parser: returns true when `str` contains at
// least one decimal digit anywhere, so "-0.5" and "1e-3" pass (and so does
// "a1" — this is a heuristic, not a full number parser).
// A null pointer, which strtok yields for an empty or all-comma argument, is
// reported as non-numerical instead of being dereferenced.
bool is_numerical(const char *str)
{
    if(str == nullptr)
        return false;

    for(; *str != '\0'; ++str)
    {
        // isdigit requires a value representable as unsigned char; passing a
        // negative plain char (e.g. a UTF-8 byte) is undefined behavior.
        if(isdigit(static_cast<unsigned char>(*str)))
            return true;
    }
    return false;
}
75
+
76
+ Option parse_option(int argc, char **argv)
77
+ {
78
+ vector<string> args;
79
+ for(int i = 0; i < argc; i++)
80
+ args.push_back(string(argv[i]));
81
+
82
+ if(argc == 1)
83
+ throw invalid_argument(train_help());
84
+
85
+ Option option;
86
+
87
+ mf_int i;
88
+ for(i = 1; i < argc; i++)
89
+ {
90
+ if(args[i].compare("-l1") == 0)
91
+ {
92
+ if((i+1) >= argc)
93
+ throw invalid_argument("need to specify lambda after -l1");
94
+ i++;
95
+
96
+ char *pch = strtok(argv[i], ",");
97
+ if(!is_numerical(pch))
98
+ throw invalid_argument("regularization coefficient\
99
+ should be a number");
100
+ option.param.lambda_p1 = (mf_float)strtod(pch, NULL);
101
+ option.param.lambda_q1 = (mf_float)strtod(pch, NULL);
102
+ pch = strtok(NULL, ",");
103
+ if(pch != NULL)
104
+ {
105
+ if(!is_numerical(pch))
106
+ throw invalid_argument("regularization coefficient\
107
+ should be a number");
108
+ option.param.lambda_q1 = (mf_float)strtod(pch, NULL);
109
+ }
110
+ }
111
+ else if(args[i].compare("-l2") == 0)
112
+ {
113
+ if((i+1) >= argc)
114
+ throw invalid_argument("need to specify lambda after -l2");
115
+ i++;
116
+
117
+ char *pch = strtok(argv[i], ",");
118
+ if(!is_numerical(pch))
119
+ throw invalid_argument("regularization coefficient\
120
+ should be a number");
121
+ option.param.lambda_p2 = (mf_float)strtod(pch, NULL);
122
+ option.param.lambda_q2 = (mf_float)strtod(pch, NULL);
123
+ pch = strtok(NULL, ",");
124
+ if(pch != NULL)
125
+ {
126
+ if(!is_numerical(pch))
127
+ throw invalid_argument("regularization coefficient\
128
+ should be a number");
129
+ option.param.lambda_q2 = (mf_float)strtod(pch, NULL);
130
+ }
131
+ }
132
+ else if(args[i].compare("-k") == 0)
133
+ {
134
+ if((i+1) >= argc)
135
+ throw invalid_argument("need to specify number of factors\
136
+ after -k");
137
+ i++;
138
+
139
+ if(!is_numerical(argv[i]))
140
+ throw invalid_argument("-k should be followed by a number");
141
+ option.param.k = atoi(argv[i]);
142
+ }
143
+ else if(args[i].compare("-t") == 0)
144
+ {
145
+ if((i+1) >= argc)
146
+ throw invalid_argument("need to specify number of iterations\
147
+ after -t");
148
+ i++;
149
+
150
+ if(!is_numerical(argv[i]))
151
+ throw invalid_argument("-i should be followed by a number");
152
+ option.param.nr_iters = atoi(argv[i]);
153
+ }
154
+ else if(args[i].compare("-r") == 0)
155
+ {
156
+ if((i+1) >= argc)
157
+ throw invalid_argument("need to specify eta after -r");
158
+ i++;
159
+
160
+ if(!is_numerical(argv[i]))
161
+ throw invalid_argument("-r should be followed by a number");
162
+ option.param.eta = (mf_float)atof(argv[i]);
163
+ }
164
+ else if(args[i].compare("-s") == 0)
165
+ {
166
+ if((i+1) >= argc)
167
+ throw invalid_argument("need to specify number of threads\
168
+ after -s");
169
+ i++;
170
+
171
+ if(!is_numerical(argv[i]))
172
+ throw invalid_argument("-s should be followed by a number");
173
+ option.param.nr_threads = atoi(argv[i]);
174
+ }
175
+ else if(args[i].compare("-a") == 0)
176
+ {
177
+ if((i+1) >= argc)
178
+ throw invalid_argument("need to specify negative weight\
179
+ after -a");
180
+ i++;
181
+
182
+ if(!is_numerical(argv[i]))
183
+ throw invalid_argument("-a should be followed by a number");
184
+ option.param.alpha = static_cast<mf_float>(atof(argv[i]));
185
+ }
186
+ else if(args[i].compare("-c") == 0)
187
+ {
188
+ if((i+1) >= argc)
189
+ throw invalid_argument("need to specify negative rating\
190
+ after -c");
191
+ i++;
192
+
193
+ if(!is_numerical(argv[i]))
194
+ throw invalid_argument("-c should be followed by a number");
195
+
196
+ if (argv[i][0] == '-')
197
+ // Negative number starts with - but atof only recognize numbers.
198
+ // Thus, we pass all but the first symbol to atof.
199
+ option.param.c = -static_cast<mf_float>(atof(argv[i] + 1));
200
+ else
201
+ // Non-negative numbers such as 0 and 0.5 can be handled by atof.
202
+ option.param.c = static_cast<mf_float>(atof(argv[i]));
203
+ }
204
+ else if(args[i].compare("-p") == 0)
205
+ {
206
+ if(i == argc-1)
207
+ throw invalid_argument("need to specify path after -p");
208
+ i++;
209
+
210
+ option.va_path = string(args[i]);
211
+ }
212
+ else if(args[i].compare("-v") == 0)
213
+ {
214
+ if(i == argc-1)
215
+ throw invalid_argument("need to specify number of folds\
216
+ after -v");
217
+ i++;
218
+
219
+ if(!is_numerical(argv[i]))
220
+ throw invalid_argument("-v should be followed by a number");
221
+ option.nr_folds = atoi(argv[i]);
222
+
223
+ if(option.nr_folds < 2)
224
+ throw invalid_argument("number of folds\
225
+ must be greater than one");
226
+ option.do_cv = true;
227
+ }
228
+ else if(args[i].compare("-f") == 0)
229
+ {
230
+ if(i == argc-1)
231
+ throw invalid_argument("need to specify loss function\
232
+ after -f");
233
+ i++;
234
+
235
+ if(!is_numerical(argv[i]))
236
+ throw invalid_argument("-f should be followed by a number");
237
+ option.param.fun = atoi(argv[i]);
238
+ }
239
+ else if(args[i].compare("-n") == 0)
240
+ {
241
+ if(i == argc-1)
242
+ throw invalid_argument("need to specify the number of blocks\
243
+ after -n");
244
+ i++;
245
+
246
+ if(!is_numerical(argv[i]))
247
+ throw invalid_argument("-n should be followed by a number");
248
+ option.param.nr_bins = atoi(argv[i]);
249
+ }
250
+ else if(args[i].compare("--nmf") == 0)
251
+ {
252
+ option.param.do_nmf = true;
253
+ }
254
+ else if(args[i].compare("--quiet") == 0)
255
+ {
256
+ option.param.quiet = true;
257
+ }
258
+ else if(args[i].compare("--disk") == 0)
259
+ {
260
+ option.on_disk = true;
261
+ }
262
+ else
263
+ {
264
+ break;
265
+ }
266
+ }
267
+
268
+ if(option.nr_folds > 1 && !option.va_path.empty())
269
+ throw invalid_argument("cannot specify both -p and -v");
270
+
271
+ if(i >= argc)
272
+ throw invalid_argument("training data not specified");
273
+
274
+ option.tr_path = string(args[i++]);
275
+
276
+ if(i < argc)
277
+ {
278
+ option.model_path = string(args[i]);
279
+ }
280
+ else if(i == argc)
281
+ {
282
+ const char *ptr = strrchr(&*option.tr_path.begin(), '/');
283
+ if(!ptr)
284
+ ptr = option.tr_path.c_str();
285
+ else
286
+ ++ptr;
287
+ option.model_path = string(ptr) + ".model";
288
+ }
289
+ else
290
+ {
291
+ throw invalid_argument("invalid argument");
292
+ }
293
+
294
+ option.param.nr_bins = max(option.param.nr_bins,
295
+ 2*option.param.nr_threads+1);
296
+ option.param.copy_data = false;
297
+
298
+ return option;
299
+ }
300
+
301
+ int main(int argc, char **argv)
302
+ {
303
+ Option option;
304
+ try
305
+ {
306
+ option = parse_option(argc, argv);
307
+ }
308
+ catch(invalid_argument &e)
309
+ {
310
+ cout << e.what() << endl;
311
+ return 1;
312
+ }
313
+
314
+ mf_problem tr = {};
315
+ mf_problem va = {};
316
+ if(!option.on_disk)
317
+ {
318
+ try
319
+ {
320
+ tr = read_problem(option.tr_path);
321
+ va = read_problem(option.va_path);
322
+ }
323
+ catch(runtime_error &e)
324
+ {
325
+ cout << e.what() << endl;
326
+ return 1;
327
+ }
328
+ }
329
+
330
+ if(option.do_cv)
331
+ {
332
+ if(!option.on_disk)
333
+ mf_cross_validation(&tr, option.nr_folds, option.param);
334
+ else
335
+ mf_cross_validation_on_disk(
336
+ option.tr_path.c_str(), option.nr_folds, option.param);
337
+ }
338
+ else
339
+ {
340
+ mf_model *model;
341
+ if(!option.on_disk)
342
+ model = mf_train_with_validation(&tr, &va, option.param);
343
+ else
344
+ model = mf_train_with_validation_on_disk(option.tr_path.c_str(),
345
+ option.va_path.c_str(),
346
+ option.param);
347
+
348
+ // use the following function if you do not have a validation set
349
+
350
+ // mf_model model =
351
+ // mf_train_with_validation(&tr, option.param);
352
+
353
+ mf_int status = mf_save_model(model, option.model_path.c_str());
354
+
355
+ mf_destroy_model(&model);
356
+
357
+ if(status != 0)
358
+ {
359
+ cout << "cannot save model to " << option.model_path << endl;
360
+
361
+ if(!option.on_disk)
362
+ {
363
+ delete[] tr.R;
364
+ delete[] va.R;
365
+ }
366
+
367
+ return 1;
368
+ }
369
+ }
370
+
371
+ if(!option.on_disk)
372
+ {
373
+ delete[] tr.R;
374
+ delete[] va.R;
375
+ }
376
+
377
+ return 0;
378
+ }