liblinear-ruby 0.0.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (80) hide show
  1. checksums.yaml +7 -0
  2. data/.gitignore +19 -0
  3. data/Gemfile +4 -0
  4. data/LICENSE.txt +22 -0
  5. data/README.md +46 -0
  6. data/Rakefile +1 -0
  7. data/ext/Makefile +237 -0
  8. data/ext/blas.h +25 -0
  9. data/ext/blasp.h +430 -0
  10. data/ext/daxpy.c +49 -0
  11. data/ext/ddot.c +50 -0
  12. data/ext/dnrm2.c +62 -0
  13. data/ext/dscal.c +44 -0
  14. data/ext/extconf.rb +12 -0
  15. data/ext/liblinear_wrap.cxx +4646 -0
  16. data/ext/linear.cpp +2811 -0
  17. data/ext/linear.h +74 -0
  18. data/ext/linear.rb +357 -0
  19. data/ext/tron.cpp +235 -0
  20. data/ext/tron.h +34 -0
  21. data/lib/liblinear.rb +89 -0
  22. data/lib/liblinear/error.rb +4 -0
  23. data/lib/liblinear/model.rb +66 -0
  24. data/lib/liblinear/parameter.rb +42 -0
  25. data/lib/liblinear/problem.rb +55 -0
  26. data/lib/liblinear/version.rb +3 -0
  27. data/liblinear-1.93/COPYRIGHT +31 -0
  28. data/liblinear-1.93/Makefile +37 -0
  29. data/liblinear-1.93/Makefile.win +30 -0
  30. data/liblinear-1.93/README +531 -0
  31. data/liblinear-1.93/blas/Makefile +22 -0
  32. data/liblinear-1.93/blas/blas.a +0 -0
  33. data/liblinear-1.93/blas/blas.h +25 -0
  34. data/liblinear-1.93/blas/blasp.h +430 -0
  35. data/liblinear-1.93/blas/daxpy.c +49 -0
  36. data/liblinear-1.93/blas/daxpy.o +0 -0
  37. data/liblinear-1.93/blas/ddot.c +50 -0
  38. data/liblinear-1.93/blas/ddot.o +0 -0
  39. data/liblinear-1.93/blas/dnrm2.c +62 -0
  40. data/liblinear-1.93/blas/dnrm2.o +0 -0
  41. data/liblinear-1.93/blas/dscal.c +44 -0
  42. data/liblinear-1.93/blas/dscal.o +0 -0
  43. data/liblinear-1.93/heart_scale +270 -0
  44. data/liblinear-1.93/linear.cpp +2811 -0
  45. data/liblinear-1.93/linear.def +18 -0
  46. data/liblinear-1.93/linear.h +74 -0
  47. data/liblinear-1.93/linear.o +0 -0
  48. data/liblinear-1.93/matlab/Makefile +58 -0
  49. data/liblinear-1.93/matlab/README +197 -0
  50. data/liblinear-1.93/matlab/libsvmread.c +212 -0
  51. data/liblinear-1.93/matlab/libsvmwrite.c +106 -0
  52. data/liblinear-1.93/matlab/linear_model_matlab.c +176 -0
  53. data/liblinear-1.93/matlab/linear_model_matlab.h +2 -0
  54. data/liblinear-1.93/matlab/make.m +21 -0
  55. data/liblinear-1.93/matlab/predict.c +331 -0
  56. data/liblinear-1.93/matlab/train.c +418 -0
  57. data/liblinear-1.93/predict +0 -0
  58. data/liblinear-1.93/predict.c +245 -0
  59. data/liblinear-1.93/python/Makefile +4 -0
  60. data/liblinear-1.93/python/README +343 -0
  61. data/liblinear-1.93/python/liblinear.py +277 -0
  62. data/liblinear-1.93/python/liblinearutil.py +250 -0
  63. data/liblinear-1.93/ruby/liblinear.i +41 -0
  64. data/liblinear-1.93/ruby/liblinear_wrap.cxx +4646 -0
  65. data/liblinear-1.93/ruby/linear.h +74 -0
  66. data/liblinear-1.93/ruby/linear.o +0 -0
  67. data/liblinear-1.93/train +0 -0
  68. data/liblinear-1.93/train.c +399 -0
  69. data/liblinear-1.93/tron.cpp +235 -0
  70. data/liblinear-1.93/tron.h +34 -0
  71. data/liblinear-1.93/tron.o +0 -0
  72. data/liblinear-1.93/windows/liblinear.dll +0 -0
  73. data/liblinear-1.93/windows/libsvmread.mexw64 +0 -0
  74. data/liblinear-1.93/windows/libsvmwrite.mexw64 +0 -0
  75. data/liblinear-1.93/windows/predict.exe +0 -0
  76. data/liblinear-1.93/windows/predict.mexw64 +0 -0
  77. data/liblinear-1.93/windows/train.exe +0 -0
  78. data/liblinear-1.93/windows/train.mexw64 +0 -0
  79. data/liblinear-ruby.gemspec +24 -0
  80. metadata +152 -0
@@ -0,0 +1,418 @@
1
+ #include <stdio.h>
2
+ #include <math.h>
3
+ #include <stdlib.h>
4
+ #include <string.h>
5
+ #include <ctype.h>
6
+ #include "../linear.h"
7
+
8
+ #include "mex.h"
9
+ #include "linear_model_matlab.h"
10
+
11
+ #ifdef MX_API_VER
12
+ #if MX_API_VER < 0x07030000
13
+ typedef int mwIndex;
14
+ #endif
15
+ #endif
16
+
17
+ #define CMD_LEN 2048
18
+ #define Malloc(type,n) (type *)malloc((n)*sizeof(type))
19
+ #define INF HUGE_VAL
20
+
21
+ void print_null(const char *s) {}
22
+ void print_string_matlab(const char *s) {mexPrintf(s);}
23
+
24
// Print the usage string for the MATLAB `train` MEX interface.
// Called whenever argument parsing fails or the argument count is wrong.
// FIX: "is setted" -> "is set" in the user-facing help text.
void exit_with_help()
{
	mexPrintf(
	"Usage: model = train(training_label_vector, training_instance_matrix, 'liblinear_options', 'col');\n"
	"liblinear_options:\n"
	"-s type : set type of solver (default 1)\n"
	"  for multi-class classification\n"
	"	 0 -- L2-regularized logistic regression (primal)\n"
	"	 1 -- L2-regularized L2-loss support vector classification (dual)\n"
	"	 2 -- L2-regularized L2-loss support vector classification (primal)\n"
	"	 3 -- L2-regularized L1-loss support vector classification (dual)\n"
	"	 4 -- support vector classification by Crammer and Singer\n"
	"	 5 -- L1-regularized L2-loss support vector classification\n"
	"	 6 -- L1-regularized logistic regression\n"
	"	 7 -- L2-regularized logistic regression (dual)\n"
	"  for regression\n"
	"	11 -- L2-regularized L2-loss support vector regression (primal)\n"
	"	12 -- L2-regularized L2-loss support vector regression (dual)\n"
	"	13 -- L2-regularized L1-loss support vector regression (dual)\n"
	"-c cost : set the parameter C (default 1)\n"
	"-p epsilon : set the epsilon in loss function of SVR (default 0.1)\n"
	"-e epsilon : set tolerance of termination criterion\n"
	"	-s 0 and 2\n"
	"		|f'(w)|_2 <= eps*min(pos,neg)/l*|f'(w0)|_2,\n"
	"		where f is the primal function and pos/neg are # of\n"
	"		positive/negative data (default 0.01)\n"
	"	-s 11\n"
	"		|f'(w)|_2 <= eps*|f'(w0)|_2 (default 0.001)\n"
	"	-s 1, 3, 4 and 7\n"
	"		Dual maximal violation <= eps; similar to libsvm (default 0.1)\n"
	"	-s 5 and 6\n"
	"		|f'(w)|_1 <= eps*min(pos,neg)/l*|f'(w0)|_1,\n"
	"		where f is the primal function (default 0.01)\n"
	"	-s 12 and 13\n"
	"		|f'(alpha)|_1 <= eps |f'(alpha0)|,\n"
	"		where f is the dual function (default 0.1)\n"
	"-B bias : if bias >= 0, instance x becomes [x; bias]; if < 0, no bias term added (default -1)\n"
	"-wi weight: weights adjust the parameter C of different classes (see README for details)\n"
	"-v n: n-fold cross validation mode\n"
	"-q : quiet mode (no outputs)\n"
	"col:\n"
	"	if 'col' is set, training_instance_matrix is parsed in column format, otherwise is in row format\n"
	);
}
68
+
69
+ // liblinear arguments
70
+ struct parameter param; // set by parse_command_line
71
+ struct problem prob; // set by read_problem
72
+ struct model *model_;
73
+ struct feature_node *x_space;
74
+ int cross_validation_flag;
75
+ int col_format_flag;
76
+ int nr_fold;
77
+ double bias;
78
+
79
+ double do_cross_validation()
80
+ {
81
+ int i;
82
+ int total_correct = 0;
83
+ double total_error = 0;
84
+ double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
85
+ double *target = Malloc(double, prob.l);
86
+ double retval = 0.0;
87
+
88
+ cross_validation(&prob,&param,nr_fold,target);
89
+ if(param.solver_type == L2R_L2LOSS_SVR ||
90
+ param.solver_type == L2R_L1LOSS_SVR_DUAL ||
91
+ param.solver_type == L2R_L2LOSS_SVR_DUAL)
92
+ {
93
+ for(i=0;i<prob.l;i++)
94
+ {
95
+ double y = prob.y[i];
96
+ double v = target[i];
97
+ total_error += (v-y)*(v-y);
98
+ sumv += v;
99
+ sumy += y;
100
+ sumvv += v*v;
101
+ sumyy += y*y;
102
+ sumvy += v*y;
103
+ }
104
+ printf("Cross Validation Mean squared error = %g\n",total_error/prob.l);
105
+ printf("Cross Validation Squared correlation coefficient = %g\n",
106
+ ((prob.l*sumvy-sumv*sumy)*(prob.l*sumvy-sumv*sumy))/
107
+ ((prob.l*sumvv-sumv*sumv)*(prob.l*sumyy-sumy*sumy))
108
+ );
109
+ retval = total_error/prob.l;
110
+ }
111
+ else
112
+ {
113
+ for(i=0;i<prob.l;i++)
114
+ if(target[i] == prob.y[i])
115
+ ++total_correct;
116
+ printf("Cross Validation Accuracy = %g%%\n",100.0*total_correct/prob.l);
117
+ retval = 100.0*total_correct/prob.l;
118
+ }
119
+
120
+ free(target);
121
+ return retval;
122
+ }
123
+
124
+ // nrhs should be 3
125
+ int parse_command_line(int nrhs, const mxArray *prhs[], char *model_file_name)
126
+ {
127
+ int i, argc = 1;
128
+ char cmd[CMD_LEN];
129
+ char *argv[CMD_LEN/2];
130
+ void (*print_func)(const char *) = print_string_matlab; // default printing to matlab display
131
+
132
+ // default values
133
+ param.solver_type = L2R_L2LOSS_SVC_DUAL;
134
+ param.C = 1;
135
+ param.eps = INF; // see setting below
136
+ param.p = 0.1;
137
+ param.nr_weight = 0;
138
+ param.weight_label = NULL;
139
+ param.weight = NULL;
140
+ cross_validation_flag = 0;
141
+ col_format_flag = 0;
142
+ bias = -1;
143
+
144
+
145
+ if(nrhs <= 1)
146
+ return 1;
147
+
148
+ if(nrhs == 4)
149
+ {
150
+ mxGetString(prhs[3], cmd, mxGetN(prhs[3])+1);
151
+ if(strcmp(cmd, "col") == 0)
152
+ col_format_flag = 1;
153
+ }
154
+
155
+ // put options in argv[]
156
+ if(nrhs > 2)
157
+ {
158
+ mxGetString(prhs[2], cmd, mxGetN(prhs[2]) + 1);
159
+ if((argv[argc] = strtok(cmd, " ")) != NULL)
160
+ while((argv[++argc] = strtok(NULL, " ")) != NULL)
161
+ ;
162
+ }
163
+
164
+ // parse options
165
+ for(i=1;i<argc;i++)
166
+ {
167
+ if(argv[i][0] != '-') break;
168
+ ++i;
169
+ if(i>=argc && argv[i-1][1] != 'q') // since option -q has no parameter
170
+ return 1;
171
+ switch(argv[i-1][1])
172
+ {
173
+ case 's':
174
+ param.solver_type = atoi(argv[i]);
175
+ break;
176
+ case 'c':
177
+ param.C = atof(argv[i]);
178
+ break;
179
+ case 'p':
180
+ param.p = atof(argv[i]);
181
+ break;
182
+ case 'e':
183
+ param.eps = atof(argv[i]);
184
+ break;
185
+ case 'B':
186
+ bias = atof(argv[i]);
187
+ break;
188
+ case 'v':
189
+ cross_validation_flag = 1;
190
+ nr_fold = atoi(argv[i]);
191
+ if(nr_fold < 2)
192
+ {
193
+ mexPrintf("n-fold cross validation: n must >= 2\n");
194
+ return 1;
195
+ }
196
+ break;
197
+ case 'w':
198
+ ++param.nr_weight;
199
+ param.weight_label = (int *) realloc(param.weight_label,sizeof(int)*param.nr_weight);
200
+ param.weight = (double *) realloc(param.weight,sizeof(double)*param.nr_weight);
201
+ param.weight_label[param.nr_weight-1] = atoi(&argv[i-1][2]);
202
+ param.weight[param.nr_weight-1] = atof(argv[i]);
203
+ break;
204
+ case 'q':
205
+ print_func = &print_null;
206
+ i--;
207
+ break;
208
+ default:
209
+ mexPrintf("unknown option\n");
210
+ return 1;
211
+ }
212
+ }
213
+
214
+ set_print_string_function(print_func);
215
+
216
+ if(param.eps == INF)
217
+ {
218
+ switch(param.solver_type)
219
+ {
220
+ case L2R_LR:
221
+ case L2R_L2LOSS_SVC:
222
+ param.eps = 0.01;
223
+ break;
224
+ case L2R_L2LOSS_SVR:
225
+ param.eps = 0.001;
226
+ break;
227
+ case L2R_L2LOSS_SVC_DUAL:
228
+ case L2R_L1LOSS_SVC_DUAL:
229
+ case MCSVM_CS:
230
+ case L2R_LR_DUAL:
231
+ param.eps = 0.1;
232
+ break;
233
+ case L1R_L2LOSS_SVC:
234
+ case L1R_LR:
235
+ param.eps = 0.01;
236
+ break;
237
+ case L2R_L1LOSS_SVR_DUAL:
238
+ case L2R_L2LOSS_SVR_DUAL:
239
+ param.eps = 0.1;
240
+ break;
241
+ }
242
+ }
243
+ return 0;
244
+ }
245
+
246
+ static void fake_answer(mxArray *plhs[])
247
+ {
248
+ plhs[0] = mxCreateDoubleMatrix(0, 0, mxREAL);
249
+ }
250
+
251
+ int read_problem_sparse(const mxArray *label_vec, const mxArray *instance_mat)
252
+ {
253
+ int i, j, k, low, high;
254
+ mwIndex *ir, *jc;
255
+ int elements, max_index, num_samples, label_vector_row_num;
256
+ double *samples, *labels;
257
+ mxArray *instance_mat_col; // instance sparse matrix in column format
258
+
259
+ prob.x = NULL;
260
+ prob.y = NULL;
261
+ x_space = NULL;
262
+
263
+ if(col_format_flag)
264
+ instance_mat_col = (mxArray *)instance_mat;
265
+ else
266
+ {
267
+ // transpose instance matrix
268
+ mxArray *prhs[1], *plhs[1];
269
+ prhs[0] = mxDuplicateArray(instance_mat);
270
+ if(mexCallMATLAB(1, plhs, 1, prhs, "transpose"))
271
+ {
272
+ mexPrintf("Error: cannot transpose training instance matrix\n");
273
+ return -1;
274
+ }
275
+ instance_mat_col = plhs[0];
276
+ mxDestroyArray(prhs[0]);
277
+ }
278
+
279
+ // the number of instance
280
+ prob.l = (int) mxGetN(instance_mat_col);
281
+ label_vector_row_num = (int) mxGetM(label_vec);
282
+
283
+ if(label_vector_row_num!=prob.l)
284
+ {
285
+ mexPrintf("Length of label vector does not match # of instances.\n");
286
+ return -1;
287
+ }
288
+
289
+ // each column is one instance
290
+ labels = mxGetPr(label_vec);
291
+ samples = mxGetPr(instance_mat_col);
292
+ ir = mxGetIr(instance_mat_col);
293
+ jc = mxGetJc(instance_mat_col);
294
+
295
+ num_samples = (int) mxGetNzmax(instance_mat_col);
296
+
297
+ elements = num_samples + prob.l*2;
298
+ max_index = (int) mxGetM(instance_mat_col);
299
+
300
+ prob.y = Malloc(double, prob.l);
301
+ prob.x = Malloc(struct feature_node*, prob.l);
302
+ x_space = Malloc(struct feature_node, elements);
303
+
304
+ prob.bias=bias;
305
+
306
+ j = 0;
307
+ for(i=0;i<prob.l;i++)
308
+ {
309
+ prob.x[i] = &x_space[j];
310
+ prob.y[i] = labels[i];
311
+ low = (int) jc[i], high = (int) jc[i+1];
312
+ for(k=low;k<high;k++)
313
+ {
314
+ x_space[j].index = (int) ir[k]+1;
315
+ x_space[j].value = samples[k];
316
+ j++;
317
+ }
318
+ if(prob.bias>=0)
319
+ {
320
+ x_space[j].index = max_index+1;
321
+ x_space[j].value = prob.bias;
322
+ j++;
323
+ }
324
+ x_space[j++].index = -1;
325
+ }
326
+
327
+ if(prob.bias>=0)
328
+ prob.n = max_index+1;
329
+ else
330
+ prob.n = max_index;
331
+
332
+ return 0;
333
+ }
334
+
335
+ // Interface function of matlab
336
+ // now assume prhs[0]: label prhs[1]: features
337
+ void mexFunction( int nlhs, mxArray *plhs[],
338
+ int nrhs, const mxArray *prhs[] )
339
+ {
340
+ const char *error_msg;
341
+ // fix random seed to have same results for each run
342
+ // (for cross validation)
343
+ srand(1);
344
+
345
+ // Transform the input Matrix to libsvm format
346
+ if(nrhs > 1 && nrhs < 5)
347
+ {
348
+ int err=0;
349
+
350
+ if(!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1])) {
351
+ mexPrintf("Error: label vector and instance matrix must be double\n");
352
+ fake_answer(plhs);
353
+ return;
354
+ }
355
+
356
+ if(parse_command_line(nrhs, prhs, NULL))
357
+ {
358
+ exit_with_help();
359
+ destroy_param(&param);
360
+ fake_answer(plhs);
361
+ return;
362
+ }
363
+
364
+ if(mxIsSparse(prhs[1]))
365
+ err = read_problem_sparse(prhs[0], prhs[1]);
366
+ else
367
+ {
368
+ mexPrintf("Training_instance_matrix must be sparse; "
369
+ "use sparse(Training_instance_matrix) first\n");
370
+ destroy_param(&param);
371
+ fake_answer(plhs);
372
+ return;
373
+ }
374
+
375
+ // train's original code
376
+ error_msg = check_parameter(&prob, &param);
377
+
378
+ if(err || error_msg)
379
+ {
380
+ if (error_msg != NULL)
381
+ mexPrintf("Error: %s\n", error_msg);
382
+ destroy_param(&param);
383
+ free(prob.y);
384
+ free(prob.x);
385
+ free(x_space);
386
+ fake_answer(plhs);
387
+ return;
388
+ }
389
+
390
+ if(cross_validation_flag)
391
+ {
392
+ double *ptr;
393
+ plhs[0] = mxCreateDoubleMatrix(1, 1, mxREAL);
394
+ ptr = mxGetPr(plhs[0]);
395
+ ptr[0] = do_cross_validation();
396
+ }
397
+ else
398
+ {
399
+ const char *error_msg;
400
+
401
+ model_ = train(&prob, &param);
402
+ error_msg = model_to_matlab_structure(plhs, model_);
403
+ if(error_msg)
404
+ mexPrintf("Error: can't convert libsvm model to matrix structure: %s\n", error_msg);
405
+ free_and_destroy_model(&model_);
406
+ }
407
+ destroy_param(&param);
408
+ free(prob.y);
409
+ free(prob.x);
410
+ free(x_space);
411
+ }
412
+ else
413
+ {
414
+ exit_with_help();
415
+ fake_answer(plhs);
416
+ return;
417
+ }
418
+ }
Binary file
@@ -0,0 +1,245 @@
1
+ #include <stdio.h>
2
+ #include <ctype.h>
3
+ #include <stdlib.h>
4
+ #include <string.h>
5
+ #include <errno.h>
6
+ #include "linear.h"
7
+
8
// Swallows all output in quiet mode (-q); signature matches printf's.
int print_null(const char *s,...) {return 0;}

// Output hook: printf normally, print_null under -q.
static int (*info)(const char *fmt,...) = &printf;

struct feature_node *x;		// feature buffer for the instance being parsed
int max_nr_attr = 64;		// current capacity of x; doubled on demand

struct model* model_;		// model loaded from the model file
int flag_predict_probability=0;	// -b option: emit probability estimates
17
+
18
// Report a malformed line of the test file (1-based) and abort.
void exit_input_error(int line_num)
{
	fprintf(stderr,"Wrong input format at line %d\n", line_num);
	exit(1);
}
23
+
24
static char *line = NULL;	// grows to hold the longest line seen
static int max_line_len;	// current capacity of `line`

// Read one '\n'-terminated line of `input` into the global buffer `line`,
// doubling the buffer as needed. Returns `line`, or NULL at end of file.
// A final line with no trailing newline is still returned in full.
// FIX: the realloc result was previously assigned straight back to `line`,
// so an out-of-memory failure dereferenced NULL inside strlen/fgets.
static char* readline(FILE *input)
{
	int len;

	if(fgets(line,max_line_len,input) == NULL)
		return NULL;

	while(strrchr(line,'\n') == NULL)
	{
		char *bigger;
		max_line_len *= 2;
		bigger = (char *) realloc(line,max_line_len);
		if(bigger == NULL)
		{
			// old buffer is still valid and owned by `line`; bail out
			fprintf(stderr,"out of memory\n");
			exit(1);
		}
		line = bigger;
		len = (int) strlen(line);
		if(fgets(line+len,max_line_len-len,input) == NULL)
			break;
	}
	return line;
}
44
+
45
+ void do_predict(FILE *input, FILE *output)
46
+ {
47
+ int correct = 0;
48
+ int total = 0;
49
+ double error = 0;
50
+ double sump = 0, sumt = 0, sumpp = 0, sumtt = 0, sumpt = 0;
51
+
52
+ int nr_class=get_nr_class(model_);
53
+ double *prob_estimates=NULL;
54
+ int j, n;
55
+ int nr_feature=get_nr_feature(model_);
56
+ if(model_->bias>=0)
57
+ n=nr_feature+1;
58
+ else
59
+ n=nr_feature;
60
+
61
+ if(flag_predict_probability)
62
+ {
63
+ int *labels;
64
+
65
+ if(!check_probability_model(model_))
66
+ {
67
+ fprintf(stderr, "probability output is only supported for logistic regression\n");
68
+ exit(1);
69
+ }
70
+
71
+ labels=(int *) malloc(nr_class*sizeof(int));
72
+ get_labels(model_,labels);
73
+ prob_estimates = (double *) malloc(nr_class*sizeof(double));
74
+ fprintf(output,"labels");
75
+ for(j=0;j<nr_class;j++)
76
+ fprintf(output," %d",labels[j]);
77
+ fprintf(output,"\n");
78
+ free(labels);
79
+ }
80
+
81
+ max_line_len = 1024;
82
+ line = (char *)malloc(max_line_len*sizeof(char));
83
+ while(readline(input) != NULL)
84
+ {
85
+ int i = 0;
86
+ double target_label, predict_label;
87
+ char *idx, *val, *label, *endptr;
88
+ int inst_max_index = 0; // strtol gives 0 if wrong format
89
+
90
+ label = strtok(line," \t\n");
91
+ if(label == NULL) // empty line
92
+ exit_input_error(total+1);
93
+
94
+ target_label = strtod(label,&endptr);
95
+ if(endptr == label || *endptr != '\0')
96
+ exit_input_error(total+1);
97
+
98
+ while(1)
99
+ {
100
+ if(i>=max_nr_attr-2) // need one more for index = -1
101
+ {
102
+ max_nr_attr *= 2;
103
+ x = (struct feature_node *) realloc(x,max_nr_attr*sizeof(struct feature_node));
104
+ }
105
+
106
+ idx = strtok(NULL,":");
107
+ val = strtok(NULL," \t");
108
+
109
+ if(val == NULL)
110
+ break;
111
+ errno = 0;
112
+ x[i].index = (int) strtol(idx,&endptr,10);
113
+ if(endptr == idx || errno != 0 || *endptr != '\0' || x[i].index <= inst_max_index)
114
+ exit_input_error(total+1);
115
+ else
116
+ inst_max_index = x[i].index;
117
+
118
+ errno = 0;
119
+ x[i].value = strtod(val,&endptr);
120
+ if(endptr == val || errno != 0 || (*endptr != '\0' && !isspace(*endptr)))
121
+ exit_input_error(total+1);
122
+
123
+ // feature indices larger than those in training are not used
124
+ if(x[i].index <= nr_feature)
125
+ ++i;
126
+ }
127
+
128
+ if(model_->bias>=0)
129
+ {
130
+ x[i].index = n;
131
+ x[i].value = model_->bias;
132
+ i++;
133
+ }
134
+ x[i].index = -1;
135
+
136
+ if(flag_predict_probability)
137
+ {
138
+ int j;
139
+ predict_label = predict_probability(model_,x,prob_estimates);
140
+ fprintf(output,"%g",predict_label);
141
+ for(j=0;j<model_->nr_class;j++)
142
+ fprintf(output," %g",prob_estimates[j]);
143
+ fprintf(output,"\n");
144
+ }
145
+ else
146
+ {
147
+ predict_label = predict(model_,x);
148
+ fprintf(output,"%g\n",predict_label);
149
+ }
150
+
151
+ if(predict_label == target_label)
152
+ ++correct;
153
+ error += (predict_label-target_label)*(predict_label-target_label);
154
+ sump += predict_label;
155
+ sumt += target_label;
156
+ sumpp += predict_label*predict_label;
157
+ sumtt += target_label*target_label;
158
+ sumpt += predict_label*target_label;
159
+ ++total;
160
+ }
161
+ if(model_->param.solver_type==L2R_L2LOSS_SVR ||
162
+ model_->param.solver_type==L2R_L1LOSS_SVR_DUAL ||
163
+ model_->param.solver_type==L2R_L2LOSS_SVR_DUAL)
164
+ {
165
+ info("Mean squared error = %g (regression)\n",error/total);
166
+ info("Squared correlation coefficient = %g (regression)\n",
167
+ ((total*sumpt-sump*sumt)*(total*sumpt-sump*sumt))/
168
+ ((total*sumpp-sump*sump)*(total*sumtt-sumt*sumt))
169
+ );
170
+ }
171
+ else
172
+ info("Accuracy = %g%% (%d/%d)\n",(double) correct/total*100,correct,total);
173
+ if(flag_predict_probability)
174
+ free(prob_estimates);
175
+ }
176
+
177
// Print usage for the `predict` command and terminate with status 1.
void exit_with_help()
{
	printf(
	"Usage: predict [options] test_file model_file output_file\n"
	"options:\n"
	"-b probability_estimates: whether to output probability estimates, 0 or 1 (default 0); currently for logistic regression only\n"
	"-q : quiet mode (no outputs)\n"
	);
	exit(1);
}
187
+
188
+ int main(int argc, char **argv)
189
+ {
190
+ FILE *input, *output;
191
+ int i;
192
+
193
+ // parse options
194
+ for(i=1;i<argc;i++)
195
+ {
196
+ if(argv[i][0] != '-') break;
197
+ ++i;
198
+ switch(argv[i-1][1])
199
+ {
200
+ case 'b':
201
+ flag_predict_probability = atoi(argv[i]);
202
+ break;
203
+ case 'q':
204
+ info = &print_null;
205
+ i--;
206
+ break;
207
+ default:
208
+ fprintf(stderr,"unknown option: -%c\n", argv[i-1][1]);
209
+ exit_with_help();
210
+ break;
211
+ }
212
+ }
213
+ if(i>=argc)
214
+ exit_with_help();
215
+
216
+ input = fopen(argv[i],"r");
217
+ if(input == NULL)
218
+ {
219
+ fprintf(stderr,"can't open input file %s\n",argv[i]);
220
+ exit(1);
221
+ }
222
+
223
+ output = fopen(argv[i+2],"w");
224
+ if(output == NULL)
225
+ {
226
+ fprintf(stderr,"can't open output file %s\n",argv[i+2]);
227
+ exit(1);
228
+ }
229
+
230
+ if((model_=load_model(argv[i+1]))==0)
231
+ {
232
+ fprintf(stderr,"can't open model file %s\n",argv[i+1]);
233
+ exit(1);
234
+ }
235
+
236
+ x = (struct feature_node *) malloc(max_nr_attr*sizeof(struct feature_node));
237
+ do_predict(input, output);
238
+ free_and_destroy_model(&model_);
239
+ free(line);
240
+ free(x);
241
+ fclose(input);
242
+ fclose(output);
243
+ return 0;
244
+ }
245
+