liblinear-ruby 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. checksums.yaml +7 -0
  2. data/.gitignore +19 -0
  3. data/Gemfile +4 -0
  4. data/LICENSE.txt +22 -0
  5. data/README.md +46 -0
  6. data/Rakefile +1 -0
  7. data/ext/Makefile +237 -0
  8. data/ext/blas.h +25 -0
  9. data/ext/blasp.h +430 -0
  10. data/ext/daxpy.c +49 -0
  11. data/ext/ddot.c +50 -0
  12. data/ext/dnrm2.c +62 -0
  13. data/ext/dscal.c +44 -0
  14. data/ext/extconf.rb +12 -0
  15. data/ext/liblinear_wrap.cxx +4646 -0
  16. data/ext/linear.cpp +2811 -0
  17. data/ext/linear.h +74 -0
  18. data/ext/linear.rb +357 -0
  19. data/ext/tron.cpp +235 -0
  20. data/ext/tron.h +34 -0
  21. data/lib/liblinear.rb +89 -0
  22. data/lib/liblinear/error.rb +4 -0
  23. data/lib/liblinear/model.rb +66 -0
  24. data/lib/liblinear/parameter.rb +42 -0
  25. data/lib/liblinear/problem.rb +55 -0
  26. data/lib/liblinear/version.rb +3 -0
  27. data/liblinear-1.93/COPYRIGHT +31 -0
  28. data/liblinear-1.93/Makefile +37 -0
  29. data/liblinear-1.93/Makefile.win +30 -0
  30. data/liblinear-1.93/README +531 -0
  31. data/liblinear-1.93/blas/Makefile +22 -0
  32. data/liblinear-1.93/blas/blas.a +0 -0
  33. data/liblinear-1.93/blas/blas.h +25 -0
  34. data/liblinear-1.93/blas/blasp.h +430 -0
  35. data/liblinear-1.93/blas/daxpy.c +49 -0
  36. data/liblinear-1.93/blas/daxpy.o +0 -0
  37. data/liblinear-1.93/blas/ddot.c +50 -0
  38. data/liblinear-1.93/blas/ddot.o +0 -0
  39. data/liblinear-1.93/blas/dnrm2.c +62 -0
  40. data/liblinear-1.93/blas/dnrm2.o +0 -0
  41. data/liblinear-1.93/blas/dscal.c +44 -0
  42. data/liblinear-1.93/blas/dscal.o +0 -0
  43. data/liblinear-1.93/heart_scale +270 -0
  44. data/liblinear-1.93/linear.cpp +2811 -0
  45. data/liblinear-1.93/linear.def +18 -0
  46. data/liblinear-1.93/linear.h +74 -0
  47. data/liblinear-1.93/linear.o +0 -0
  48. data/liblinear-1.93/matlab/Makefile +58 -0
  49. data/liblinear-1.93/matlab/README +197 -0
  50. data/liblinear-1.93/matlab/libsvmread.c +212 -0
  51. data/liblinear-1.93/matlab/libsvmwrite.c +106 -0
  52. data/liblinear-1.93/matlab/linear_model_matlab.c +176 -0
  53. data/liblinear-1.93/matlab/linear_model_matlab.h +2 -0
  54. data/liblinear-1.93/matlab/make.m +21 -0
  55. data/liblinear-1.93/matlab/predict.c +331 -0
  56. data/liblinear-1.93/matlab/train.c +418 -0
  57. data/liblinear-1.93/predict +0 -0
  58. data/liblinear-1.93/predict.c +245 -0
  59. data/liblinear-1.93/python/Makefile +4 -0
  60. data/liblinear-1.93/python/README +343 -0
  61. data/liblinear-1.93/python/liblinear.py +277 -0
  62. data/liblinear-1.93/python/liblinearutil.py +250 -0
  63. data/liblinear-1.93/ruby/liblinear.i +41 -0
  64. data/liblinear-1.93/ruby/liblinear_wrap.cxx +4646 -0
  65. data/liblinear-1.93/ruby/linear.h +74 -0
  66. data/liblinear-1.93/ruby/linear.o +0 -0
  67. data/liblinear-1.93/train +0 -0
  68. data/liblinear-1.93/train.c +399 -0
  69. data/liblinear-1.93/tron.cpp +235 -0
  70. data/liblinear-1.93/tron.h +34 -0
  71. data/liblinear-1.93/tron.o +0 -0
  72. data/liblinear-1.93/windows/liblinear.dll +0 -0
  73. data/liblinear-1.93/windows/libsvmread.mexw64 +0 -0
  74. data/liblinear-1.93/windows/libsvmwrite.mexw64 +0 -0
  75. data/liblinear-1.93/windows/predict.exe +0 -0
  76. data/liblinear-1.93/windows/predict.mexw64 +0 -0
  77. data/liblinear-1.93/windows/train.exe +0 -0
  78. data/liblinear-1.93/windows/train.mexw64 +0 -0
  79. data/liblinear-ruby.gemspec +24 -0
  80. metadata +152 -0
@@ -0,0 +1,106 @@
1
+ #include <stdio.h>
2
+ #include <stdlib.h>
3
+ #include <string.h>
4
+ #include "mex.h"
5
+
6
+ #ifdef MX_API_VER
7
+ #if MX_API_VER < 0x07030000
8
+ typedef int mwIndex;
9
+ #endif
10
+ #endif
11
+
12
/* Print the calling convention of libsvmwrite to the MATLAB console. */
void exit_with_help()
{
	static const char usage[] =
		"Usage: libsvmwrite('filename', label_vector, instance_matrix);\n";

	mexPrintf("%s", usage);
}
18
+
19
+ void libsvmwrite(const char *filename, const mxArray *label_vec, const mxArray *instance_mat)
20
+ {
21
+ FILE *fp = fopen(filename,"w");
22
+ int i, k, low, high, l;
23
+ mwIndex *ir, *jc;
24
+ int label_vector_row_num;
25
+ double *samples, *labels;
26
+ mxArray *instance_mat_col; // instance sparse matrix in column format
27
+
28
+ if(fp ==NULL)
29
+ {
30
+ mexPrintf("can't open output file %s\n",filename);
31
+ return;
32
+ }
33
+
34
+ // transpose instance matrix
35
+ {
36
+ mxArray *prhs[1], *plhs[1];
37
+ prhs[0] = mxDuplicateArray(instance_mat);
38
+ if(mexCallMATLAB(1, plhs, 1, prhs, "transpose"))
39
+ {
40
+ mexPrintf("Error: cannot transpose instance matrix\n");
41
+ return;
42
+ }
43
+ instance_mat_col = plhs[0];
44
+ mxDestroyArray(prhs[0]);
45
+ }
46
+
47
+ // the number of instance
48
+ l = (int) mxGetN(instance_mat_col);
49
+ label_vector_row_num = (int) mxGetM(label_vec);
50
+
51
+ if(label_vector_row_num!=l)
52
+ {
53
+ mexPrintf("Length of label vector does not match # of instances.\n");
54
+ return;
55
+ }
56
+
57
+ // each column is one instance
58
+ labels = mxGetPr(label_vec);
59
+ samples = mxGetPr(instance_mat_col);
60
+ ir = mxGetIr(instance_mat_col);
61
+ jc = mxGetJc(instance_mat_col);
62
+
63
+ for(i=0;i<l;i++)
64
+ {
65
+ fprintf(fp,"%g", labels[i]);
66
+
67
+ low = (int) jc[i], high = (int) jc[i+1];
68
+ for(k=low;k<high;k++)
69
+ fprintf(fp," %ld:%g", ir[k]+1, samples[k]);
70
+
71
+ fprintf(fp,"\n");
72
+ }
73
+
74
+ fclose(fp);
75
+ return;
76
+ }
77
+
78
+ void mexFunction( int nlhs, mxArray *plhs[],
79
+ int nrhs, const mxArray *prhs[] )
80
+ {
81
+ // Transform the input Matrix to libsvm format
82
+ if(nrhs == 3)
83
+ {
84
+ char filename[256];
85
+ if(!mxIsDouble(prhs[1]) || !mxIsDouble(prhs[2]))
86
+ {
87
+ mexPrintf("Error: label vector and instance matrix must be double\n");
88
+ return;
89
+ }
90
+
91
+ mxGetString(prhs[0], filename, mxGetN(prhs[0])+1);
92
+
93
+ if(mxIsSparse(prhs[2]))
94
+ libsvmwrite(filename, prhs[1], prhs[2]);
95
+ else
96
+ {
97
+ mexPrintf("Instance_matrix must be sparse\n");
98
+ return;
99
+ }
100
+ }
101
+ else
102
+ {
103
+ exit_with_help();
104
+ return;
105
+ }
106
+ }
@@ -0,0 +1,176 @@
1
+ #include <stdlib.h>
2
+ #include <string.h>
3
+ #include "../linear.h"
4
+
5
+ #include "mex.h"
6
+
7
+ #ifdef MX_API_VER
8
+ #if MX_API_VER < 0x07030000
9
+ typedef int mwIndex;
10
+ #endif
11
+ #endif
12
+
13
/* malloc room for n objects of `type`.  The whole expansion is parenthesized
   so the macro is safe inside larger expressions (e.g. Malloc(int,n)[0]). */
#define Malloc(type,n) ((type *)malloc((n)*sizeof(type)))

/* Number of fields in the MATLAB model struct; must match field_names. */
#define NUM_OF_RETURN_FIELD 6

/* Field names of the MATLAB model struct, in storage order. */
static const char *field_names[] = {
	"Parameters",
	"nr_class",
	"nr_feature",
	"bias",
	"Label",
	"w",
};
25
+
26
+ const char *model_to_matlab_structure(mxArray *plhs[], struct model *model_)
27
+ {
28
+ int i;
29
+ int nr_w;
30
+ double *ptr;
31
+ mxArray *return_model, **rhs;
32
+ int out_id = 0;
33
+ int n, w_size;
34
+
35
+ rhs = (mxArray **)mxMalloc(sizeof(mxArray *)*NUM_OF_RETURN_FIELD);
36
+
37
+ // Parameters
38
+ // for now, only solver_type is needed
39
+ rhs[out_id] = mxCreateDoubleMatrix(1, 1, mxREAL);
40
+ ptr = mxGetPr(rhs[out_id]);
41
+ ptr[0] = model_->param.solver_type;
42
+ out_id++;
43
+
44
+ // nr_class
45
+ rhs[out_id] = mxCreateDoubleMatrix(1, 1, mxREAL);
46
+ ptr = mxGetPr(rhs[out_id]);
47
+ ptr[0] = model_->nr_class;
48
+ out_id++;
49
+
50
+ if(model_->nr_class==2 && model_->param.solver_type != MCSVM_CS)
51
+ nr_w=1;
52
+ else
53
+ nr_w=model_->nr_class;
54
+
55
+ // nr_feature
56
+ rhs[out_id] = mxCreateDoubleMatrix(1, 1, mxREAL);
57
+ ptr = mxGetPr(rhs[out_id]);
58
+ ptr[0] = model_->nr_feature;
59
+ out_id++;
60
+
61
+ // bias
62
+ rhs[out_id] = mxCreateDoubleMatrix(1, 1, mxREAL);
63
+ ptr = mxGetPr(rhs[out_id]);
64
+ ptr[0] = model_->bias;
65
+ out_id++;
66
+
67
+ if(model_->bias>=0)
68
+ n=model_->nr_feature+1;
69
+ else
70
+ n=model_->nr_feature;
71
+
72
+ w_size = n;
73
+ // Label
74
+ if(model_->label)
75
+ {
76
+ rhs[out_id] = mxCreateDoubleMatrix(model_->nr_class, 1, mxREAL);
77
+ ptr = mxGetPr(rhs[out_id]);
78
+ for(i = 0; i < model_->nr_class; i++)
79
+ ptr[i] = model_->label[i];
80
+ }
81
+ else
82
+ rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL);
83
+ out_id++;
84
+
85
+ // w
86
+ rhs[out_id] = mxCreateDoubleMatrix(nr_w, w_size, mxREAL);
87
+ ptr = mxGetPr(rhs[out_id]);
88
+ for(i = 0; i < w_size*nr_w; i++)
89
+ ptr[i]=model_->w[i];
90
+ out_id++;
91
+
92
+ /* Create a struct matrix contains NUM_OF_RETURN_FIELD fields */
93
+ return_model = mxCreateStructMatrix(1, 1, NUM_OF_RETURN_FIELD, field_names);
94
+
95
+ /* Fill struct matrix with input arguments */
96
+ for(i = 0; i < NUM_OF_RETURN_FIELD; i++)
97
+ mxSetField(return_model,0,field_names[i],mxDuplicateArray(rhs[i]));
98
+ /* return */
99
+ plhs[0] = return_model;
100
+ mxFree(rhs);
101
+
102
+ return NULL;
103
+ }
104
+
105
+ const char *matlab_matrix_to_model(struct model *model_, const mxArray *matlab_struct)
106
+ {
107
+ int i, num_of_fields;
108
+ int nr_w;
109
+ double *ptr;
110
+ int id = 0;
111
+ int n, w_size;
112
+ mxArray **rhs;
113
+
114
+ num_of_fields = mxGetNumberOfFields(matlab_struct);
115
+ rhs = (mxArray **) mxMalloc(sizeof(mxArray *)*num_of_fields);
116
+
117
+ for(i=0;i<num_of_fields;i++)
118
+ rhs[i] = mxGetFieldByNumber(matlab_struct, 0, i);
119
+
120
+ model_->nr_class=0;
121
+ nr_w=0;
122
+ model_->nr_feature=0;
123
+ model_->w=NULL;
124
+ model_->label=NULL;
125
+
126
+ // Parameters
127
+ ptr = mxGetPr(rhs[id]);
128
+ model_->param.solver_type = (int)ptr[0];
129
+ id++;
130
+
131
+ // nr_class
132
+ ptr = mxGetPr(rhs[id]);
133
+ model_->nr_class = (int)ptr[0];
134
+ id++;
135
+
136
+ if(model_->nr_class==2 && model_->param.solver_type != MCSVM_CS)
137
+ nr_w=1;
138
+ else
139
+ nr_w=model_->nr_class;
140
+
141
+ // nr_feature
142
+ ptr = mxGetPr(rhs[id]);
143
+ model_->nr_feature = (int)ptr[0];
144
+ id++;
145
+
146
+ // bias
147
+ ptr = mxGetPr(rhs[id]);
148
+ model_->bias = (int)ptr[0];
149
+ id++;
150
+
151
+ if(model_->bias>=0)
152
+ n=model_->nr_feature+1;
153
+ else
154
+ n=model_->nr_feature;
155
+ w_size = n;
156
+
157
+ // Label
158
+ if(mxIsEmpty(rhs[id]) == 0)
159
+ {
160
+ model_->label = Malloc(int, model_->nr_class);
161
+ ptr = mxGetPr(rhs[id]);
162
+ for(i=0;i<model_->nr_class;i++)
163
+ model_->label[i] = (int)ptr[i];
164
+ }
165
+ id++;
166
+
167
+ ptr = mxGetPr(rhs[id]);
168
+ model_->w=Malloc(double, w_size*nr_w);
169
+ for(i = 0; i < w_size*nr_w; i++)
170
+ model_->w[i]=ptr[i];
171
+ id++;
172
+ mxFree(rhs);
173
+
174
+ return NULL;
175
+ }
176
+
@@ -0,0 +1,2 @@
1
+ const char *model_to_matlab_structure(mxArray *plhs[], struct model *model_);
2
+ const char *matlab_matrix_to_model(struct model *model_, const mxArray *matlab_struct);
@@ -0,0 +1,21 @@
1
% This make.m is for MATLAB and OCTAVE under Windows, Mac, and Unix.
% It builds the libsvmread, libsvmwrite, train and predict MEX files.

try
	% This part is for OCTAVE
	% (OCTAVE_VERSION is a builtin only in Octave, so this test does not
	% depend on the order of the entries returned by ver)
	if(exist('OCTAVE_VERSION', 'builtin'))
		mex libsvmread.c
		mex libsvmwrite.c
		mex train.c linear_model_matlab.c ../linear.cpp ../tron.cpp ../blas/*.c
		mex predict.c linear_model_matlab.c ../linear.cpp ../tron.cpp ../blas/*.c
	% This part is for MATLAB
	% Add -largeArrayDims on 64-bit machines of MATLAB
	else
		mex CFLAGS="\$CFLAGS -std=c99" -largeArrayDims libsvmread.c
		mex CFLAGS="\$CFLAGS -std=c99" -largeArrayDims libsvmwrite.c
		mex CFLAGS="\$CFLAGS -std=c99" -largeArrayDims train.c linear_model_matlab.c ../linear.cpp ../tron.cpp "../blas/*.c"
		mex CFLAGS="\$CFLAGS -std=c99" -largeArrayDims predict.c linear_model_matlab.c ../linear.cpp ../tron.cpp "../blas/*.c"
	end
catch err
	% report the actual compiler/mex error instead of silently discarding it
	fprintf('Error: %s\n', err.message);
	fprintf('If make.m fails, please check README about detailed instructions.\n');
end
@@ -0,0 +1,331 @@
1
+ #include <stdio.h>
2
+ #include <stdlib.h>
3
+ #include <string.h>
4
+ #include "../linear.h"
5
+
6
+ #include "mex.h"
7
+ #include "linear_model_matlab.h"
8
+
9
+ #ifdef MX_API_VER
10
+ #if MX_API_VER < 0x07030000
11
+ typedef int mwIndex;
12
+ #endif
13
+ #endif
14
+
15
/* Size of the buffers used to read string arguments (options, 'col'). */
#define CMD_LEN 2048

/* malloc room for n objects of `type`.  The whole expansion is parenthesized
   so the macro is safe inside larger expressions. */
#define Malloc(type,n) ((type *)malloc((n)*sizeof(type)))

/* printf-compatible sink used for quiet mode (-q): discards its arguments.
   It must return a value because it is called through `info`, whose result
   type is int. */
int print_null(const char *s,...) { (void) s; return 0; }

/* Printer for prediction summaries; set to mexPrintf or print_null by mexFunction. */
int (*info)(const char *fmt,...);

/* Nonzero when the testing instance matrix holds one instance per column ('col'). */
int col_format_flag;
23
+
24
+ void read_sparse_instance(const mxArray *prhs, int index, struct feature_node *x, int feature_number, double bias)
25
+ {
26
+ int i, j, low, high;
27
+ mwIndex *ir, *jc;
28
+ double *samples;
29
+
30
+ ir = mxGetIr(prhs);
31
+ jc = mxGetJc(prhs);
32
+ samples = mxGetPr(prhs);
33
+
34
+ // each column is one instance
35
+ j = 0;
36
+ low = (int) jc[index], high = (int) jc[index+1];
37
+ for(i=low; i<high && (int) (ir[i])<feature_number; i++)
38
+ {
39
+ x[j].index = (int) ir[i]+1;
40
+ x[j].value = samples[i];
41
+ j++;
42
+ }
43
+ if(bias>=0)
44
+ {
45
+ x[j].index = feature_number+1;
46
+ x[j].value = bias;
47
+ j++;
48
+ }
49
+ x[j].index = -1;
50
+ }
51
+
52
+ static void fake_answer(mxArray *plhs[])
53
+ {
54
+ plhs[0] = mxCreateDoubleMatrix(0, 0, mxREAL);
55
+ plhs[1] = mxCreateDoubleMatrix(0, 0, mxREAL);
56
+ plhs[2] = mxCreateDoubleMatrix(0, 0, mxREAL);
57
+ }
58
+
59
+ void do_predict(mxArray *plhs[], const mxArray *prhs[], struct model *model_, const int predict_probability_flag)
60
+ {
61
+ int label_vector_row_num, label_vector_col_num;
62
+ int feature_number, testing_instance_number;
63
+ int instance_index;
64
+ double *ptr_label, *ptr_predict_label;
65
+ double *ptr_prob_estimates, *ptr_dec_values, *ptr;
66
+ struct feature_node *x;
67
+ mxArray *pplhs[1]; // instance sparse matrix in row format
68
+
69
+ int correct = 0;
70
+ int total = 0;
71
+ double error = 0;
72
+ double sump = 0, sumt = 0, sumpp = 0, sumtt = 0, sumpt = 0;
73
+
74
+ int nr_class=get_nr_class(model_);
75
+ int nr_w;
76
+ double *prob_estimates=NULL;
77
+
78
+ if(nr_class==2 && model_->param.solver_type!=MCSVM_CS)
79
+ nr_w=1;
80
+ else
81
+ nr_w=nr_class;
82
+
83
+ // prhs[1] = testing instance matrix
84
+ feature_number = get_nr_feature(model_);
85
+ testing_instance_number = (int) mxGetM(prhs[1]);
86
+ if(col_format_flag)
87
+ {
88
+ feature_number = (int) mxGetM(prhs[1]);
89
+ testing_instance_number = (int) mxGetN(prhs[1]);
90
+ }
91
+
92
+ label_vector_row_num = (int) mxGetM(prhs[0]);
93
+ label_vector_col_num = (int) mxGetN(prhs[0]);
94
+
95
+ if(label_vector_row_num!=testing_instance_number)
96
+ {
97
+ mexPrintf("Length of label vector does not match # of instances.\n");
98
+ fake_answer(plhs);
99
+ return;
100
+ }
101
+ if(label_vector_col_num!=1)
102
+ {
103
+ mexPrintf("label (1st argument) should be a vector (# of column is 1).\n");
104
+ fake_answer(plhs);
105
+ return;
106
+ }
107
+
108
+ ptr_label = mxGetPr(prhs[0]);
109
+
110
+ // transpose instance matrix
111
+ if(col_format_flag)
112
+ pplhs[0] = (mxArray *)prhs[1];
113
+ else
114
+ {
115
+ mxArray *pprhs[1];
116
+ pprhs[0] = mxDuplicateArray(prhs[1]);
117
+ if(mexCallMATLAB(1, pplhs, 1, pprhs, "transpose"))
118
+ {
119
+ mexPrintf("Error: cannot transpose testing instance matrix\n");
120
+ fake_answer(plhs);
121
+ return;
122
+ }
123
+ }
124
+
125
+
126
+ prob_estimates = Malloc(double, nr_class);
127
+
128
+ plhs[0] = mxCreateDoubleMatrix(testing_instance_number, 1, mxREAL);
129
+ if(predict_probability_flag)
130
+ plhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_class, mxREAL);
131
+ else
132
+ plhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_w, mxREAL);
133
+
134
+ ptr_predict_label = mxGetPr(plhs[0]);
135
+ ptr_prob_estimates = mxGetPr(plhs[2]);
136
+ ptr_dec_values = mxGetPr(plhs[2]);
137
+ x = Malloc(struct feature_node, feature_number+2);
138
+ for(instance_index=0;instance_index<testing_instance_number;instance_index++)
139
+ {
140
+ int i;
141
+ double target_label, predict_label;
142
+
143
+ target_label = ptr_label[instance_index];
144
+
145
+ // prhs[1] and prhs[1]^T are sparse
146
+ read_sparse_instance(pplhs[0], instance_index, x, feature_number, model_->bias);
147
+
148
+ if(predict_probability_flag)
149
+ {
150
+ predict_label = predict_probability(model_, x, prob_estimates);
151
+ ptr_predict_label[instance_index] = predict_label;
152
+ for(i=0;i<nr_class;i++)
153
+ ptr_prob_estimates[instance_index + i * testing_instance_number] = prob_estimates[i];
154
+ }
155
+ else
156
+ {
157
+ double *dec_values = Malloc(double, nr_class);
158
+ predict_label = predict_values(model_, x, dec_values);
159
+ ptr_predict_label[instance_index] = predict_label;
160
+
161
+ for(i=0;i<nr_w;i++)
162
+ ptr_dec_values[instance_index + i * testing_instance_number] = dec_values[i];
163
+ free(dec_values);
164
+ }
165
+
166
+ if(predict_label == target_label)
167
+ ++correct;
168
+ error += (predict_label-target_label)*(predict_label-target_label);
169
+ sump += predict_label;
170
+ sumt += target_label;
171
+ sumpp += predict_label*predict_label;
172
+ sumtt += target_label*target_label;
173
+ sumpt += predict_label*target_label;
174
+
175
+ ++total;
176
+ }
177
+
178
+ if(model_->param.solver_type==L2R_L2LOSS_SVR ||
179
+ model_->param.solver_type==L2R_L1LOSS_SVR_DUAL ||
180
+ model_->param.solver_type==L2R_L2LOSS_SVR_DUAL)
181
+ {
182
+ info("Mean squared error = %g (regression)\n",error/total);
183
+ info("Squared correlation coefficient = %g (regression)\n",
184
+ ((total*sumpt-sump*sumt)*(total*sumpt-sump*sumt))/
185
+ ((total*sumpp-sump*sump)*(total*sumtt-sumt*sumt))
186
+ );
187
+ }
188
+ else
189
+ info("Accuracy = %g%% (%d/%d)\n", (double) correct/total*100,correct,total);
190
+
191
+ // return accuracy, mean squared error, squared correlation coefficient
192
+ plhs[1] = mxCreateDoubleMatrix(3, 1, mxREAL);
193
+ ptr = mxGetPr(plhs[1]);
194
+ ptr[0] = (double)correct/total*100;
195
+ ptr[1] = error/total;
196
+ ptr[2] = ((total*sumpt-sump*sumt)*(total*sumpt-sump*sumt))/
197
+ ((total*sumpp-sump*sump)*(total*sumtt-sumt*sumt));
198
+
199
+ free(x);
200
+ if(prob_estimates != NULL)
201
+ free(prob_estimates);
202
+ }
203
+
204
/* Print the usage text of predict(): arguments, options and return values. */
void exit_with_help()
{
	mexPrintf(
	"Usage: [predicted_label, accuracy, decision_values/prob_estimates] = predict(testing_label_vector, testing_instance_matrix, model, 'liblinear_options','col')\n"
	"liblinear_options:\n"
	"-b probability_estimates: whether to output probability estimates, 0 or 1 (default 0); currently for logistic regression only\n"
	"-q quiet mode (no outputs)\n"
	"col: if 'col' is set, testing_instance_matrix is parsed in column format, otherwise it is in row format\n"
	"Returns:\n"
	" predicted_label: prediction output vector.\n"
	" accuracy: a vector with accuracy, mean squared error, squared correlation coefficient.\n"
	" decision_values/prob_estimates: decision values, or probability estimates if -b 1 is given (one row per instance).\n"
	);
}
218
+
219
+ void mexFunction( int nlhs, mxArray *plhs[],
220
+ int nrhs, const mxArray *prhs[] )
221
+ {
222
+ int prob_estimate_flag = 0;
223
+ struct model *model_;
224
+ char cmd[CMD_LEN];
225
+ info = &mexPrintf;
226
+ col_format_flag = 0;
227
+
228
+ if(nrhs > 5 || nrhs < 3)
229
+ {
230
+ exit_with_help();
231
+ fake_answer(plhs);
232
+ return;
233
+ }
234
+ if(nrhs == 5)
235
+ {
236
+ mxGetString(prhs[4], cmd, mxGetN(prhs[4])+1);
237
+ if(strcmp(cmd, "col") == 0)
238
+ {
239
+ col_format_flag = 1;
240
+ }
241
+ }
242
+
243
+ if(!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1])) {
244
+ mexPrintf("Error: label vector and instance matrix must be double\n");
245
+ fake_answer(plhs);
246
+ return;
247
+ }
248
+
249
+ if(mxIsStruct(prhs[2]))
250
+ {
251
+ const char *error_msg;
252
+
253
+ // parse options
254
+ if(nrhs>=4)
255
+ {
256
+ int i, argc = 1;
257
+ char *argv[CMD_LEN/2];
258
+
259
+ // put options in argv[]
260
+ mxGetString(prhs[3], cmd, mxGetN(prhs[3]) + 1);
261
+ if((argv[argc] = strtok(cmd, " ")) != NULL)
262
+ while((argv[++argc] = strtok(NULL, " ")) != NULL)
263
+ ;
264
+
265
+ for(i=1;i<argc;i++)
266
+ {
267
+ if(argv[i][0] != '-') break;
268
+ ++i;
269
+ if(i>=argc && argv[i-1][1] != 'q')
270
+ {
271
+ exit_with_help();
272
+ fake_answer(plhs);
273
+ return;
274
+ }
275
+ switch(argv[i-1][1])
276
+ {
277
+ case 'b':
278
+ prob_estimate_flag = atoi(argv[i]);
279
+ break;
280
+ case 'q':
281
+ info = &print_null;
282
+ i--;
283
+ break;
284
+ default:
285
+ mexPrintf("unknown option\n");
286
+ exit_with_help();
287
+ fake_answer(plhs);
288
+ return;
289
+ }
290
+ }
291
+ }
292
+
293
+ model_ = Malloc(struct model, 1);
294
+ error_msg = matlab_matrix_to_model(model_, prhs[2]);
295
+ if(error_msg)
296
+ {
297
+ mexPrintf("Error: can't read model: %s\n", error_msg);
298
+ free_and_destroy_model(&model_);
299
+ fake_answer(plhs);
300
+ return;
301
+ }
302
+
303
+ if(prob_estimate_flag)
304
+ {
305
+ if(!check_probability_model(model_))
306
+ {
307
+ mexPrintf("probability output is only supported for logistic regression\n");
308
+ prob_estimate_flag=0;
309
+ }
310
+ }
311
+
312
+ if(mxIsSparse(prhs[1]))
313
+ do_predict(plhs, prhs, model_, prob_estimate_flag);
314
+ else
315
+ {
316
+ mexPrintf("Testing_instance_matrix must be sparse; "
317
+ "use sparse(Testing_instance_matrix) first\n");
318
+ fake_answer(plhs);
319
+ }
320
+
321
+ // destroy model_
322
+ free_and_destroy_model(&model_);
323
+ }
324
+ else
325
+ {
326
+ mexPrintf("model file should be a struct array\n");
327
+ fake_answer(plhs);
328
+ }
329
+
330
+ return;
331
+ }