liblinear-ruby 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80) hide show
  1. checksums.yaml +7 -0
  2. data/.gitignore +19 -0
  3. data/Gemfile +4 -0
  4. data/LICENSE.txt +22 -0
  5. data/README.md +46 -0
  6. data/Rakefile +1 -0
  7. data/ext/Makefile +237 -0
  8. data/ext/blas.h +25 -0
  9. data/ext/blasp.h +430 -0
  10. data/ext/daxpy.c +49 -0
  11. data/ext/ddot.c +50 -0
  12. data/ext/dnrm2.c +62 -0
  13. data/ext/dscal.c +44 -0
  14. data/ext/extconf.rb +12 -0
  15. data/ext/liblinear_wrap.cxx +4646 -0
  16. data/ext/linear.cpp +2811 -0
  17. data/ext/linear.h +74 -0
  18. data/ext/linear.rb +357 -0
  19. data/ext/tron.cpp +235 -0
  20. data/ext/tron.h +34 -0
  21. data/lib/liblinear.rb +89 -0
  22. data/lib/liblinear/error.rb +4 -0
  23. data/lib/liblinear/model.rb +66 -0
  24. data/lib/liblinear/parameter.rb +42 -0
  25. data/lib/liblinear/problem.rb +55 -0
  26. data/lib/liblinear/version.rb +3 -0
  27. data/liblinear-1.93/COPYRIGHT +31 -0
  28. data/liblinear-1.93/Makefile +37 -0
  29. data/liblinear-1.93/Makefile.win +30 -0
  30. data/liblinear-1.93/README +531 -0
  31. data/liblinear-1.93/blas/Makefile +22 -0
  32. data/liblinear-1.93/blas/blas.a +0 -0
  33. data/liblinear-1.93/blas/blas.h +25 -0
  34. data/liblinear-1.93/blas/blasp.h +430 -0
  35. data/liblinear-1.93/blas/daxpy.c +49 -0
  36. data/liblinear-1.93/blas/daxpy.o +0 -0
  37. data/liblinear-1.93/blas/ddot.c +50 -0
  38. data/liblinear-1.93/blas/ddot.o +0 -0
  39. data/liblinear-1.93/blas/dnrm2.c +62 -0
  40. data/liblinear-1.93/blas/dnrm2.o +0 -0
  41. data/liblinear-1.93/blas/dscal.c +44 -0
  42. data/liblinear-1.93/blas/dscal.o +0 -0
  43. data/liblinear-1.93/heart_scale +270 -0
  44. data/liblinear-1.93/linear.cpp +2811 -0
  45. data/liblinear-1.93/linear.def +18 -0
  46. data/liblinear-1.93/linear.h +74 -0
  47. data/liblinear-1.93/linear.o +0 -0
  48. data/liblinear-1.93/matlab/Makefile +58 -0
  49. data/liblinear-1.93/matlab/README +197 -0
  50. data/liblinear-1.93/matlab/libsvmread.c +212 -0
  51. data/liblinear-1.93/matlab/libsvmwrite.c +106 -0
  52. data/liblinear-1.93/matlab/linear_model_matlab.c +176 -0
  53. data/liblinear-1.93/matlab/linear_model_matlab.h +2 -0
  54. data/liblinear-1.93/matlab/make.m +21 -0
  55. data/liblinear-1.93/matlab/predict.c +331 -0
  56. data/liblinear-1.93/matlab/train.c +418 -0
  57. data/liblinear-1.93/predict +0 -0
  58. data/liblinear-1.93/predict.c +245 -0
  59. data/liblinear-1.93/python/Makefile +4 -0
  60. data/liblinear-1.93/python/README +343 -0
  61. data/liblinear-1.93/python/liblinear.py +277 -0
  62. data/liblinear-1.93/python/liblinearutil.py +250 -0
  63. data/liblinear-1.93/ruby/liblinear.i +41 -0
  64. data/liblinear-1.93/ruby/liblinear_wrap.cxx +4646 -0
  65. data/liblinear-1.93/ruby/linear.h +74 -0
  66. data/liblinear-1.93/ruby/linear.o +0 -0
  67. data/liblinear-1.93/train +0 -0
  68. data/liblinear-1.93/train.c +399 -0
  69. data/liblinear-1.93/tron.cpp +235 -0
  70. data/liblinear-1.93/tron.h +34 -0
  71. data/liblinear-1.93/tron.o +0 -0
  72. data/liblinear-1.93/windows/liblinear.dll +0 -0
  73. data/liblinear-1.93/windows/libsvmread.mexw64 +0 -0
  74. data/liblinear-1.93/windows/libsvmwrite.mexw64 +0 -0
  75. data/liblinear-1.93/windows/predict.exe +0 -0
  76. data/liblinear-1.93/windows/predict.mexw64 +0 -0
  77. data/liblinear-1.93/windows/train.exe +0 -0
  78. data/liblinear-1.93/windows/train.mexw64 +0 -0
  79. data/liblinear-ruby.gemspec +24 -0
  80. metadata +152 -0
@@ -0,0 +1,418 @@
1
+ #include <stdio.h>
2
+ #include <math.h>
3
+ #include <stdlib.h>
4
+ #include <string.h>
5
+ #include <ctype.h>
6
+ #include "../linear.h"
7
+
8
+ #include "mex.h"
9
+ #include "linear_model_matlab.h"
10
+
11
+ #ifdef MX_API_VER
12
+ #if MX_API_VER < 0x07030000
13
+ typedef int mwIndex;
14
+ #endif
15
+ #endif
16
+
17
+ #define CMD_LEN 2048
18
+ #define Malloc(type,n) (type *)malloc((n)*sizeof(type))
19
+ #define INF HUGE_VAL
20
+
21
/* liblinear print callback installed by -q: discards every message. */
void print_null(const char *s) {}

/*
 * Default liblinear print callback: forwards the message to the MATLAB
 * command window. The message is passed as an argument rather than as the
 * format, so a '%' inside it cannot be taken as a conversion specification.
 */
void print_string_matlab(const char *s) {mexPrintf("%s", s);}
23
+
24
/* Print the MATLAB-side usage of train() and its options to the command window. */
void exit_with_help()
{
	static const char usage[] =
	"Usage: model = train(training_label_vector, training_instance_matrix, 'liblinear_options', 'col');\n"
	"liblinear_options:\n"
	"-s type : set type of solver (default 1)\n"
	" for multi-class classification\n"
	" 0 -- L2-regularized logistic regression (primal)\n"
	" 1 -- L2-regularized L2-loss support vector classification (dual)\n"
	" 2 -- L2-regularized L2-loss support vector classification (primal)\n"
	" 3 -- L2-regularized L1-loss support vector classification (dual)\n"
	" 4 -- support vector classification by Crammer and Singer\n"
	" 5 -- L1-regularized L2-loss support vector classification\n"
	" 6 -- L1-regularized logistic regression\n"
	" 7 -- L2-regularized logistic regression (dual)\n"
	" for regression\n"
	" 11 -- L2-regularized L2-loss support vector regression (primal)\n"
	" 12 -- L2-regularized L2-loss support vector regression (dual)\n"
	" 13 -- L2-regularized L1-loss support vector regression (dual)\n"
	"-c cost : set the parameter C (default 1)\n"
	"-p epsilon : set the epsilon in loss function of SVR (default 0.1)\n"
	"-e epsilon : set tolerance of termination criterion\n"
	" -s 0 and 2\n"
	" |f'(w)|_2 <= eps*min(pos,neg)/l*|f'(w0)|_2,\n"
	" where f is the primal function and pos/neg are # of\n"
	" positive/negative data (default 0.01)\n"
	" -s 11\n"
	" |f'(w)|_2 <= eps*|f'(w0)|_2 (default 0.001)\n"
	" -s 1, 3, 4 and 7\n"
	" Dual maximal violation <= eps; similar to libsvm (default 0.1)\n"
	" -s 5 and 6\n"
	" |f'(w)|_1 <= eps*min(pos,neg)/l*|f'(w0)|_1,\n"
	" where f is the primal function (default 0.01)\n"
	" -s 12 and 13\n"
	" |f'(alpha)|_1 <= eps |f'(alpha0)|,\n"
	" where f is the dual function (default 0.1)\n"
	"-B bias : if bias >= 0, instance x becomes [x; bias]; if < 0, no bias term added (default -1)\n"
	"-wi weight: weights adjust the parameter C of different classes (see README for details)\n"
	"-v n: n-fold cross validation mode\n"
	"-q : quiet mode (no outputs)\n"
	"col:\n"
	" if 'col' is setted, training_instance_matrix is parsed in column format, otherwise is in row format\n"
	;

	/* The text contains no '%', so printing it through "%s" is output-identical. */
	mexPrintf("%s", usage);
}
68
+
69
+ // liblinear arguments
70
+ struct parameter param; // set by parse_command_line
71
+ struct problem prob; // set by read_problem
72
+ struct model *model_;
73
+ struct feature_node *x_space;
74
+ int cross_validation_flag;
75
+ int col_format_flag;
76
+ int nr_fold;
77
+ double bias;
78
+
79
+ double do_cross_validation()
80
+ {
81
+ int i;
82
+ int total_correct = 0;
83
+ double total_error = 0;
84
+ double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
85
+ double *target = Malloc(double, prob.l);
86
+ double retval = 0.0;
87
+
88
+ cross_validation(&prob,&param,nr_fold,target);
89
+ if(param.solver_type == L2R_L2LOSS_SVR ||
90
+ param.solver_type == L2R_L1LOSS_SVR_DUAL ||
91
+ param.solver_type == L2R_L2LOSS_SVR_DUAL)
92
+ {
93
+ for(i=0;i<prob.l;i++)
94
+ {
95
+ double y = prob.y[i];
96
+ double v = target[i];
97
+ total_error += (v-y)*(v-y);
98
+ sumv += v;
99
+ sumy += y;
100
+ sumvv += v*v;
101
+ sumyy += y*y;
102
+ sumvy += v*y;
103
+ }
104
+ printf("Cross Validation Mean squared error = %g\n",total_error/prob.l);
105
+ printf("Cross Validation Squared correlation coefficient = %g\n",
106
+ ((prob.l*sumvy-sumv*sumy)*(prob.l*sumvy-sumv*sumy))/
107
+ ((prob.l*sumvv-sumv*sumv)*(prob.l*sumyy-sumy*sumy))
108
+ );
109
+ retval = total_error/prob.l;
110
+ }
111
+ else
112
+ {
113
+ for(i=0;i<prob.l;i++)
114
+ if(target[i] == prob.y[i])
115
+ ++total_correct;
116
+ printf("Cross Validation Accuracy = %g%%\n",100.0*total_correct/prob.l);
117
+ retval = 100.0*total_correct/prob.l;
118
+ }
119
+
120
+ free(target);
121
+ return retval;
122
+ }
123
+
124
+ // nrhs should be 3
125
+ int parse_command_line(int nrhs, const mxArray *prhs[], char *model_file_name)
126
+ {
127
+ int i, argc = 1;
128
+ char cmd[CMD_LEN];
129
+ char *argv[CMD_LEN/2];
130
+ void (*print_func)(const char *) = print_string_matlab; // default printing to matlab display
131
+
132
+ // default values
133
+ param.solver_type = L2R_L2LOSS_SVC_DUAL;
134
+ param.C = 1;
135
+ param.eps = INF; // see setting below
136
+ param.p = 0.1;
137
+ param.nr_weight = 0;
138
+ param.weight_label = NULL;
139
+ param.weight = NULL;
140
+ cross_validation_flag = 0;
141
+ col_format_flag = 0;
142
+ bias = -1;
143
+
144
+
145
+ if(nrhs <= 1)
146
+ return 1;
147
+
148
+ if(nrhs == 4)
149
+ {
150
+ mxGetString(prhs[3], cmd, mxGetN(prhs[3])+1);
151
+ if(strcmp(cmd, "col") == 0)
152
+ col_format_flag = 1;
153
+ }
154
+
155
+ // put options in argv[]
156
+ if(nrhs > 2)
157
+ {
158
+ mxGetString(prhs[2], cmd, mxGetN(prhs[2]) + 1);
159
+ if((argv[argc] = strtok(cmd, " ")) != NULL)
160
+ while((argv[++argc] = strtok(NULL, " ")) != NULL)
161
+ ;
162
+ }
163
+
164
+ // parse options
165
+ for(i=1;i<argc;i++)
166
+ {
167
+ if(argv[i][0] != '-') break;
168
+ ++i;
169
+ if(i>=argc && argv[i-1][1] != 'q') // since option -q has no parameter
170
+ return 1;
171
+ switch(argv[i-1][1])
172
+ {
173
+ case 's':
174
+ param.solver_type = atoi(argv[i]);
175
+ break;
176
+ case 'c':
177
+ param.C = atof(argv[i]);
178
+ break;
179
+ case 'p':
180
+ param.p = atof(argv[i]);
181
+ break;
182
+ case 'e':
183
+ param.eps = atof(argv[i]);
184
+ break;
185
+ case 'B':
186
+ bias = atof(argv[i]);
187
+ break;
188
+ case 'v':
189
+ cross_validation_flag = 1;
190
+ nr_fold = atoi(argv[i]);
191
+ if(nr_fold < 2)
192
+ {
193
+ mexPrintf("n-fold cross validation: n must >= 2\n");
194
+ return 1;
195
+ }
196
+ break;
197
+ case 'w':
198
+ ++param.nr_weight;
199
+ param.weight_label = (int *) realloc(param.weight_label,sizeof(int)*param.nr_weight);
200
+ param.weight = (double *) realloc(param.weight,sizeof(double)*param.nr_weight);
201
+ param.weight_label[param.nr_weight-1] = atoi(&argv[i-1][2]);
202
+ param.weight[param.nr_weight-1] = atof(argv[i]);
203
+ break;
204
+ case 'q':
205
+ print_func = &print_null;
206
+ i--;
207
+ break;
208
+ default:
209
+ mexPrintf("unknown option\n");
210
+ return 1;
211
+ }
212
+ }
213
+
214
+ set_print_string_function(print_func);
215
+
216
+ if(param.eps == INF)
217
+ {
218
+ switch(param.solver_type)
219
+ {
220
+ case L2R_LR:
221
+ case L2R_L2LOSS_SVC:
222
+ param.eps = 0.01;
223
+ break;
224
+ case L2R_L2LOSS_SVR:
225
+ param.eps = 0.001;
226
+ break;
227
+ case L2R_L2LOSS_SVC_DUAL:
228
+ case L2R_L1LOSS_SVC_DUAL:
229
+ case MCSVM_CS:
230
+ case L2R_LR_DUAL:
231
+ param.eps = 0.1;
232
+ break;
233
+ case L1R_L2LOSS_SVC:
234
+ case L1R_LR:
235
+ param.eps = 0.01;
236
+ break;
237
+ case L2R_L1LOSS_SVR_DUAL:
238
+ case L2R_L2LOSS_SVR_DUAL:
239
+ param.eps = 0.1;
240
+ break;
241
+ }
242
+ }
243
+ return 0;
244
+ }
245
+
246
+ static void fake_answer(mxArray *plhs[])
247
+ {
248
+ plhs[0] = mxCreateDoubleMatrix(0, 0, mxREAL);
249
+ }
250
+
251
+ int read_problem_sparse(const mxArray *label_vec, const mxArray *instance_mat)
252
+ {
253
+ int i, j, k, low, high;
254
+ mwIndex *ir, *jc;
255
+ int elements, max_index, num_samples, label_vector_row_num;
256
+ double *samples, *labels;
257
+ mxArray *instance_mat_col; // instance sparse matrix in column format
258
+
259
+ prob.x = NULL;
260
+ prob.y = NULL;
261
+ x_space = NULL;
262
+
263
+ if(col_format_flag)
264
+ instance_mat_col = (mxArray *)instance_mat;
265
+ else
266
+ {
267
+ // transpose instance matrix
268
+ mxArray *prhs[1], *plhs[1];
269
+ prhs[0] = mxDuplicateArray(instance_mat);
270
+ if(mexCallMATLAB(1, plhs, 1, prhs, "transpose"))
271
+ {
272
+ mexPrintf("Error: cannot transpose training instance matrix\n");
273
+ return -1;
274
+ }
275
+ instance_mat_col = plhs[0];
276
+ mxDestroyArray(prhs[0]);
277
+ }
278
+
279
+ // the number of instance
280
+ prob.l = (int) mxGetN(instance_mat_col);
281
+ label_vector_row_num = (int) mxGetM(label_vec);
282
+
283
+ if(label_vector_row_num!=prob.l)
284
+ {
285
+ mexPrintf("Length of label vector does not match # of instances.\n");
286
+ return -1;
287
+ }
288
+
289
+ // each column is one instance
290
+ labels = mxGetPr(label_vec);
291
+ samples = mxGetPr(instance_mat_col);
292
+ ir = mxGetIr(instance_mat_col);
293
+ jc = mxGetJc(instance_mat_col);
294
+
295
+ num_samples = (int) mxGetNzmax(instance_mat_col);
296
+
297
+ elements = num_samples + prob.l*2;
298
+ max_index = (int) mxGetM(instance_mat_col);
299
+
300
+ prob.y = Malloc(double, prob.l);
301
+ prob.x = Malloc(struct feature_node*, prob.l);
302
+ x_space = Malloc(struct feature_node, elements);
303
+
304
+ prob.bias=bias;
305
+
306
+ j = 0;
307
+ for(i=0;i<prob.l;i++)
308
+ {
309
+ prob.x[i] = &x_space[j];
310
+ prob.y[i] = labels[i];
311
+ low = (int) jc[i], high = (int) jc[i+1];
312
+ for(k=low;k<high;k++)
313
+ {
314
+ x_space[j].index = (int) ir[k]+1;
315
+ x_space[j].value = samples[k];
316
+ j++;
317
+ }
318
+ if(prob.bias>=0)
319
+ {
320
+ x_space[j].index = max_index+1;
321
+ x_space[j].value = prob.bias;
322
+ j++;
323
+ }
324
+ x_space[j++].index = -1;
325
+ }
326
+
327
+ if(prob.bias>=0)
328
+ prob.n = max_index+1;
329
+ else
330
+ prob.n = max_index;
331
+
332
+ return 0;
333
+ }
334
+
335
+ // Interface function of matlab
336
+ // now assume prhs[0]: label prhs[1]: features
337
+ void mexFunction( int nlhs, mxArray *plhs[],
338
+ int nrhs, const mxArray *prhs[] )
339
+ {
340
+ const char *error_msg;
341
+ // fix random seed to have same results for each run
342
+ // (for cross validation)
343
+ srand(1);
344
+
345
+ // Transform the input Matrix to libsvm format
346
+ if(nrhs > 1 && nrhs < 5)
347
+ {
348
+ int err=0;
349
+
350
+ if(!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1])) {
351
+ mexPrintf("Error: label vector and instance matrix must be double\n");
352
+ fake_answer(plhs);
353
+ return;
354
+ }
355
+
356
+ if(parse_command_line(nrhs, prhs, NULL))
357
+ {
358
+ exit_with_help();
359
+ destroy_param(&param);
360
+ fake_answer(plhs);
361
+ return;
362
+ }
363
+
364
+ if(mxIsSparse(prhs[1]))
365
+ err = read_problem_sparse(prhs[0], prhs[1]);
366
+ else
367
+ {
368
+ mexPrintf("Training_instance_matrix must be sparse; "
369
+ "use sparse(Training_instance_matrix) first\n");
370
+ destroy_param(&param);
371
+ fake_answer(plhs);
372
+ return;
373
+ }
374
+
375
+ // train's original code
376
+ error_msg = check_parameter(&prob, &param);
377
+
378
+ if(err || error_msg)
379
+ {
380
+ if (error_msg != NULL)
381
+ mexPrintf("Error: %s\n", error_msg);
382
+ destroy_param(&param);
383
+ free(prob.y);
384
+ free(prob.x);
385
+ free(x_space);
386
+ fake_answer(plhs);
387
+ return;
388
+ }
389
+
390
+ if(cross_validation_flag)
391
+ {
392
+ double *ptr;
393
+ plhs[0] = mxCreateDoubleMatrix(1, 1, mxREAL);
394
+ ptr = mxGetPr(plhs[0]);
395
+ ptr[0] = do_cross_validation();
396
+ }
397
+ else
398
+ {
399
+ const char *error_msg;
400
+
401
+ model_ = train(&prob, &param);
402
+ error_msg = model_to_matlab_structure(plhs, model_);
403
+ if(error_msg)
404
+ mexPrintf("Error: can't convert libsvm model to matrix structure: %s\n", error_msg);
405
+ free_and_destroy_model(&model_);
406
+ }
407
+ destroy_param(&param);
408
+ free(prob.y);
409
+ free(prob.x);
410
+ free(x_space);
411
+ }
412
+ else
413
+ {
414
+ exit_with_help();
415
+ fake_answer(plhs);
416
+ return;
417
+ }
418
+ }
Binary file
@@ -0,0 +1,245 @@
1
+ #include <stdio.h>
2
+ #include <ctype.h>
3
+ #include <stdlib.h>
4
+ #include <string.h>
5
+ #include <errno.h>
6
+ #include "linear.h"
7
+
8
/* printf-compatible sink installed by -q: accepts any arguments, prints nothing. */
int print_null(const char *s,...)
{
	(void) s;
	return 0;
}

/* Destination of the summary lines printed by do_predict; -q swaps in print_null. */
static int (*info)(const char *fmt,...) = printf;

/* Feature buffer for the current test instance; grown on demand in do_predict. */
struct feature_node *x;
int max_nr_attr = 64;              /* current capacity of x, in feature_node entries */

struct model *model_;              /* set by main from load_model */
int flag_predict_probability = 0;  /* set by -b */
17
+
18
/* Report a malformed line (1-based line_num) of the test file on stderr and
 * terminate the program with exit status 1. */
void exit_input_error(int line_num)
{
	(void) fprintf(stderr, "Wrong input format at line %d\n", line_num);
	exit(1);
}
23
+
24
/* Line buffer shared by readline and do_predict (do_predict allocates it, main frees it). */
static char *line = NULL;
/* Current capacity of `line`, in bytes. */
static int max_line_len;

/*
 * Read one whole line (including its '\n', if present) from `input` into
 * `line`, doubling the buffer until the line fits. `line` must already point
 * to a buffer of max_line_len bytes.
 * Returns `line` (which may have moved), or NULL at end of file.
 * Exits with status 1 if the buffer cannot be grown.
 */
static char* readline(FILE *input)
{
	int len;

	if(fgets(line,max_line_len,input) == NULL)
		return NULL;

	while(strrchr(line,'\n') == NULL)
	{
		char *grown;

		max_line_len *= 2;
		// Keep the old pointer until realloc succeeds; assigning directly
		// would leak the buffer and then dereference NULL.
		grown = (char *) realloc(line,max_line_len);
		if(grown == NULL)
		{
			fprintf(stderr,"can't allocate memory for input line\n");
			exit(1);
		}
		line = grown;
		len = (int) strlen(line);
		if(fgets(line+len,max_line_len-len,input) == NULL)
			break;
	}
	return line;
}
44
+
45
+ void do_predict(FILE *input, FILE *output)
46
+ {
47
+ int correct = 0;
48
+ int total = 0;
49
+ double error = 0;
50
+ double sump = 0, sumt = 0, sumpp = 0, sumtt = 0, sumpt = 0;
51
+
52
+ int nr_class=get_nr_class(model_);
53
+ double *prob_estimates=NULL;
54
+ int j, n;
55
+ int nr_feature=get_nr_feature(model_);
56
+ if(model_->bias>=0)
57
+ n=nr_feature+1;
58
+ else
59
+ n=nr_feature;
60
+
61
+ if(flag_predict_probability)
62
+ {
63
+ int *labels;
64
+
65
+ if(!check_probability_model(model_))
66
+ {
67
+ fprintf(stderr, "probability output is only supported for logistic regression\n");
68
+ exit(1);
69
+ }
70
+
71
+ labels=(int *) malloc(nr_class*sizeof(int));
72
+ get_labels(model_,labels);
73
+ prob_estimates = (double *) malloc(nr_class*sizeof(double));
74
+ fprintf(output,"labels");
75
+ for(j=0;j<nr_class;j++)
76
+ fprintf(output," %d",labels[j]);
77
+ fprintf(output,"\n");
78
+ free(labels);
79
+ }
80
+
81
+ max_line_len = 1024;
82
+ line = (char *)malloc(max_line_len*sizeof(char));
83
+ while(readline(input) != NULL)
84
+ {
85
+ int i = 0;
86
+ double target_label, predict_label;
87
+ char *idx, *val, *label, *endptr;
88
+ int inst_max_index = 0; // strtol gives 0 if wrong format
89
+
90
+ label = strtok(line," \t\n");
91
+ if(label == NULL) // empty line
92
+ exit_input_error(total+1);
93
+
94
+ target_label = strtod(label,&endptr);
95
+ if(endptr == label || *endptr != '\0')
96
+ exit_input_error(total+1);
97
+
98
+ while(1)
99
+ {
100
+ if(i>=max_nr_attr-2) // need one more for index = -1
101
+ {
102
+ max_nr_attr *= 2;
103
+ x = (struct feature_node *) realloc(x,max_nr_attr*sizeof(struct feature_node));
104
+ }
105
+
106
+ idx = strtok(NULL,":");
107
+ val = strtok(NULL," \t");
108
+
109
+ if(val == NULL)
110
+ break;
111
+ errno = 0;
112
+ x[i].index = (int) strtol(idx,&endptr,10);
113
+ if(endptr == idx || errno != 0 || *endptr != '\0' || x[i].index <= inst_max_index)
114
+ exit_input_error(total+1);
115
+ else
116
+ inst_max_index = x[i].index;
117
+
118
+ errno = 0;
119
+ x[i].value = strtod(val,&endptr);
120
+ if(endptr == val || errno != 0 || (*endptr != '\0' && !isspace(*endptr)))
121
+ exit_input_error(total+1);
122
+
123
+ // feature indices larger than those in training are not used
124
+ if(x[i].index <= nr_feature)
125
+ ++i;
126
+ }
127
+
128
+ if(model_->bias>=0)
129
+ {
130
+ x[i].index = n;
131
+ x[i].value = model_->bias;
132
+ i++;
133
+ }
134
+ x[i].index = -1;
135
+
136
+ if(flag_predict_probability)
137
+ {
138
+ int j;
139
+ predict_label = predict_probability(model_,x,prob_estimates);
140
+ fprintf(output,"%g",predict_label);
141
+ for(j=0;j<model_->nr_class;j++)
142
+ fprintf(output," %g",prob_estimates[j]);
143
+ fprintf(output,"\n");
144
+ }
145
+ else
146
+ {
147
+ predict_label = predict(model_,x);
148
+ fprintf(output,"%g\n",predict_label);
149
+ }
150
+
151
+ if(predict_label == target_label)
152
+ ++correct;
153
+ error += (predict_label-target_label)*(predict_label-target_label);
154
+ sump += predict_label;
155
+ sumt += target_label;
156
+ sumpp += predict_label*predict_label;
157
+ sumtt += target_label*target_label;
158
+ sumpt += predict_label*target_label;
159
+ ++total;
160
+ }
161
+ if(model_->param.solver_type==L2R_L2LOSS_SVR ||
162
+ model_->param.solver_type==L2R_L1LOSS_SVR_DUAL ||
163
+ model_->param.solver_type==L2R_L2LOSS_SVR_DUAL)
164
+ {
165
+ info("Mean squared error = %g (regression)\n",error/total);
166
+ info("Squared correlation coefficient = %g (regression)\n",
167
+ ((total*sumpt-sump*sumt)*(total*sumpt-sump*sumt))/
168
+ ((total*sumpp-sump*sump)*(total*sumtt-sumt*sumt))
169
+ );
170
+ }
171
+ else
172
+ info("Accuracy = %g%% (%d/%d)\n",(double) correct/total*100,correct,total);
173
+ if(flag_predict_probability)
174
+ free(prob_estimates);
175
+ }
176
+
177
/* Print predict's command-line usage to stdout and terminate with status 1. */
void exit_with_help()
{
	static const char usage[] =
		"Usage: predict [options] test_file model_file output_file\n"
		"options:\n"
		"-b probability_estimates: whether to output probability estimates, 0 or 1 (default 0); currently for logistic regression only\n"
		"-q : quiet mode (no outputs)\n";

	/* The text contains no '%', so fputs emits exactly what printf would. */
	fputs(usage, stdout);
	exit(1);
}
187
+
188
+ int main(int argc, char **argv)
189
+ {
190
+ FILE *input, *output;
191
+ int i;
192
+
193
+ // parse options
194
+ for(i=1;i<argc;i++)
195
+ {
196
+ if(argv[i][0] != '-') break;
197
+ ++i;
198
+ switch(argv[i-1][1])
199
+ {
200
+ case 'b':
201
+ flag_predict_probability = atoi(argv[i]);
202
+ break;
203
+ case 'q':
204
+ info = &print_null;
205
+ i--;
206
+ break;
207
+ default:
208
+ fprintf(stderr,"unknown option: -%c\n", argv[i-1][1]);
209
+ exit_with_help();
210
+ break;
211
+ }
212
+ }
213
+ if(i>=argc)
214
+ exit_with_help();
215
+
216
+ input = fopen(argv[i],"r");
217
+ if(input == NULL)
218
+ {
219
+ fprintf(stderr,"can't open input file %s\n",argv[i]);
220
+ exit(1);
221
+ }
222
+
223
+ output = fopen(argv[i+2],"w");
224
+ if(output == NULL)
225
+ {
226
+ fprintf(stderr,"can't open output file %s\n",argv[i+2]);
227
+ exit(1);
228
+ }
229
+
230
+ if((model_=load_model(argv[i+1]))==0)
231
+ {
232
+ fprintf(stderr,"can't open model file %s\n",argv[i+1]);
233
+ exit(1);
234
+ }
235
+
236
+ x = (struct feature_node *) malloc(max_nr_attr*sizeof(struct feature_node));
237
+ do_predict(input, output);
238
+ free_and_destroy_model(&model_);
239
+ free(line);
240
+ free(x);
241
+ fclose(input);
242
+ fclose(output);
243
+ return 0;
244
+ }
245
+