liblinear-ruby 0.0.1

Sign up to get free protection for your applications and to get access to all the features.
  Files changed (80)
  1. checksums.yaml +7 -0
  2. data/.gitignore +19 -0
  3. data/Gemfile +4 -0
  4. data/LICENSE.txt +22 -0
  5. data/README.md +46 -0
  6. data/Rakefile +1 -0
  7. data/ext/Makefile +237 -0
  8. data/ext/blas.h +25 -0
  9. data/ext/blasp.h +430 -0
  10. data/ext/daxpy.c +49 -0
  11. data/ext/ddot.c +50 -0
  12. data/ext/dnrm2.c +62 -0
  13. data/ext/dscal.c +44 -0
  14. data/ext/extconf.rb +12 -0
  15. data/ext/liblinear_wrap.cxx +4646 -0
  16. data/ext/linear.cpp +2811 -0
  17. data/ext/linear.h +74 -0
  18. data/ext/linear.rb +357 -0
  19. data/ext/tron.cpp +235 -0
  20. data/ext/tron.h +34 -0
  21. data/lib/liblinear.rb +89 -0
  22. data/lib/liblinear/error.rb +4 -0
  23. data/lib/liblinear/model.rb +66 -0
  24. data/lib/liblinear/parameter.rb +42 -0
  25. data/lib/liblinear/problem.rb +55 -0
  26. data/lib/liblinear/version.rb +3 -0
  27. data/liblinear-1.93/COPYRIGHT +31 -0
  28. data/liblinear-1.93/Makefile +37 -0
  29. data/liblinear-1.93/Makefile.win +30 -0
  30. data/liblinear-1.93/README +531 -0
  31. data/liblinear-1.93/blas/Makefile +22 -0
  32. data/liblinear-1.93/blas/blas.a +0 -0
  33. data/liblinear-1.93/blas/blas.h +25 -0
  34. data/liblinear-1.93/blas/blasp.h +430 -0
  35. data/liblinear-1.93/blas/daxpy.c +49 -0
  36. data/liblinear-1.93/blas/daxpy.o +0 -0
  37. data/liblinear-1.93/blas/ddot.c +50 -0
  38. data/liblinear-1.93/blas/ddot.o +0 -0
  39. data/liblinear-1.93/blas/dnrm2.c +62 -0
  40. data/liblinear-1.93/blas/dnrm2.o +0 -0
  41. data/liblinear-1.93/blas/dscal.c +44 -0
  42. data/liblinear-1.93/blas/dscal.o +0 -0
  43. data/liblinear-1.93/heart_scale +270 -0
  44. data/liblinear-1.93/linear.cpp +2811 -0
  45. data/liblinear-1.93/linear.def +18 -0
  46. data/liblinear-1.93/linear.h +74 -0
  47. data/liblinear-1.93/linear.o +0 -0
  48. data/liblinear-1.93/matlab/Makefile +58 -0
  49. data/liblinear-1.93/matlab/README +197 -0
  50. data/liblinear-1.93/matlab/libsvmread.c +212 -0
  51. data/liblinear-1.93/matlab/libsvmwrite.c +106 -0
  52. data/liblinear-1.93/matlab/linear_model_matlab.c +176 -0
  53. data/liblinear-1.93/matlab/linear_model_matlab.h +2 -0
  54. data/liblinear-1.93/matlab/make.m +21 -0
  55. data/liblinear-1.93/matlab/predict.c +331 -0
  56. data/liblinear-1.93/matlab/train.c +418 -0
  57. data/liblinear-1.93/predict +0 -0
  58. data/liblinear-1.93/predict.c +245 -0
  59. data/liblinear-1.93/python/Makefile +4 -0
  60. data/liblinear-1.93/python/README +343 -0
  61. data/liblinear-1.93/python/liblinear.py +277 -0
  62. data/liblinear-1.93/python/liblinearutil.py +250 -0
  63. data/liblinear-1.93/ruby/liblinear.i +41 -0
  64. data/liblinear-1.93/ruby/liblinear_wrap.cxx +4646 -0
  65. data/liblinear-1.93/ruby/linear.h +74 -0
  66. data/liblinear-1.93/ruby/linear.o +0 -0
  67. data/liblinear-1.93/train +0 -0
  68. data/liblinear-1.93/train.c +399 -0
  69. data/liblinear-1.93/tron.cpp +235 -0
  70. data/liblinear-1.93/tron.h +34 -0
  71. data/liblinear-1.93/tron.o +0 -0
  72. data/liblinear-1.93/windows/liblinear.dll +0 -0
  73. data/liblinear-1.93/windows/libsvmread.mexw64 +0 -0
  74. data/liblinear-1.93/windows/libsvmwrite.mexw64 +0 -0
  75. data/liblinear-1.93/windows/predict.exe +0 -0
  76. data/liblinear-1.93/windows/predict.mexw64 +0 -0
  77. data/liblinear-1.93/windows/train.exe +0 -0
  78. data/liblinear-1.93/windows/train.mexw64 +0 -0
  79. data/liblinear-ruby.gemspec +24 -0
  80. metadata +152 -0
@@ -0,0 +1,106 @@
1
+ #include <stdio.h>
2
+ #include <stdlib.h>
3
+ #include <string.h>
4
+ #include "mex.h"
5
+
6
+ #ifdef MX_API_VER
7
+ #if MX_API_VER < 0x07030000
8
+ typedef int mwIndex;
9
+ #endif
10
+ #endif
11
+
12
/* Show how libsvmwrite is meant to be called. */
void exit_with_help()
{
	static const char usage[] =
		"Usage: libsvmwrite('filename', label_vector, instance_matrix);\n";

	mexPrintf("%s", usage);
}
18
+
19
+ void libsvmwrite(const char *filename, const mxArray *label_vec, const mxArray *instance_mat)
20
+ {
21
+ FILE *fp = fopen(filename,"w");
22
+ int i, k, low, high, l;
23
+ mwIndex *ir, *jc;
24
+ int label_vector_row_num;
25
+ double *samples, *labels;
26
+ mxArray *instance_mat_col; // instance sparse matrix in column format
27
+
28
+ if(fp ==NULL)
29
+ {
30
+ mexPrintf("can't open output file %s\n",filename);
31
+ return;
32
+ }
33
+
34
+ // transpose instance matrix
35
+ {
36
+ mxArray *prhs[1], *plhs[1];
37
+ prhs[0] = mxDuplicateArray(instance_mat);
38
+ if(mexCallMATLAB(1, plhs, 1, prhs, "transpose"))
39
+ {
40
+ mexPrintf("Error: cannot transpose instance matrix\n");
41
+ return;
42
+ }
43
+ instance_mat_col = plhs[0];
44
+ mxDestroyArray(prhs[0]);
45
+ }
46
+
47
+ // the number of instance
48
+ l = (int) mxGetN(instance_mat_col);
49
+ label_vector_row_num = (int) mxGetM(label_vec);
50
+
51
+ if(label_vector_row_num!=l)
52
+ {
53
+ mexPrintf("Length of label vector does not match # of instances.\n");
54
+ return;
55
+ }
56
+
57
+ // each column is one instance
58
+ labels = mxGetPr(label_vec);
59
+ samples = mxGetPr(instance_mat_col);
60
+ ir = mxGetIr(instance_mat_col);
61
+ jc = mxGetJc(instance_mat_col);
62
+
63
+ for(i=0;i<l;i++)
64
+ {
65
+ fprintf(fp,"%g", labels[i]);
66
+
67
+ low = (int) jc[i], high = (int) jc[i+1];
68
+ for(k=low;k<high;k++)
69
+ fprintf(fp," %ld:%g", ir[k]+1, samples[k]);
70
+
71
+ fprintf(fp,"\n");
72
+ }
73
+
74
+ fclose(fp);
75
+ return;
76
+ }
77
+
78
+ void mexFunction( int nlhs, mxArray *plhs[],
79
+ int nrhs, const mxArray *prhs[] )
80
+ {
81
+ // Transform the input Matrix to libsvm format
82
+ if(nrhs == 3)
83
+ {
84
+ char filename[256];
85
+ if(!mxIsDouble(prhs[1]) || !mxIsDouble(prhs[2]))
86
+ {
87
+ mexPrintf("Error: label vector and instance matrix must be double\n");
88
+ return;
89
+ }
90
+
91
+ mxGetString(prhs[0], filename, mxGetN(prhs[0])+1);
92
+
93
+ if(mxIsSparse(prhs[2]))
94
+ libsvmwrite(filename, prhs[1], prhs[2]);
95
+ else
96
+ {
97
+ mexPrintf("Instance_matrix must be sparse\n");
98
+ return;
99
+ }
100
+ }
101
+ else
102
+ {
103
+ exit_with_help();
104
+ return;
105
+ }
106
+ }
@@ -0,0 +1,176 @@
1
+ #include <stdlib.h>
2
+ #include <string.h>
3
+ #include "../linear.h"
4
+
5
+ #include "mex.h"
6
+
7
+ #ifdef MX_API_VER
8
+ #if MX_API_VER < 0x07030000
9
+ typedef int mwIndex;
10
+ #endif
11
+ #endif
12
+
13
#define Malloc(type,n) (type *)malloc((n)*sizeof(type))

#define NUM_OF_RETURN_FIELD 6

/*
 * Field names of the MATLAB model struct. The converters in this file
 * create and read the fields by position in exactly this order, so the
 * count here must stay equal to NUM_OF_RETURN_FIELD.
 */
static const char *field_names[] = {
	"Parameters", "nr_class", "nr_feature",
	"bias", "Label", "w",
};
25
+
26
+ const char *model_to_matlab_structure(mxArray *plhs[], struct model *model_)
27
+ {
28
+ int i;
29
+ int nr_w;
30
+ double *ptr;
31
+ mxArray *return_model, **rhs;
32
+ int out_id = 0;
33
+ int n, w_size;
34
+
35
+ rhs = (mxArray **)mxMalloc(sizeof(mxArray *)*NUM_OF_RETURN_FIELD);
36
+
37
+ // Parameters
38
+ // for now, only solver_type is needed
39
+ rhs[out_id] = mxCreateDoubleMatrix(1, 1, mxREAL);
40
+ ptr = mxGetPr(rhs[out_id]);
41
+ ptr[0] = model_->param.solver_type;
42
+ out_id++;
43
+
44
+ // nr_class
45
+ rhs[out_id] = mxCreateDoubleMatrix(1, 1, mxREAL);
46
+ ptr = mxGetPr(rhs[out_id]);
47
+ ptr[0] = model_->nr_class;
48
+ out_id++;
49
+
50
+ if(model_->nr_class==2 && model_->param.solver_type != MCSVM_CS)
51
+ nr_w=1;
52
+ else
53
+ nr_w=model_->nr_class;
54
+
55
+ // nr_feature
56
+ rhs[out_id] = mxCreateDoubleMatrix(1, 1, mxREAL);
57
+ ptr = mxGetPr(rhs[out_id]);
58
+ ptr[0] = model_->nr_feature;
59
+ out_id++;
60
+
61
+ // bias
62
+ rhs[out_id] = mxCreateDoubleMatrix(1, 1, mxREAL);
63
+ ptr = mxGetPr(rhs[out_id]);
64
+ ptr[0] = model_->bias;
65
+ out_id++;
66
+
67
+ if(model_->bias>=0)
68
+ n=model_->nr_feature+1;
69
+ else
70
+ n=model_->nr_feature;
71
+
72
+ w_size = n;
73
+ // Label
74
+ if(model_->label)
75
+ {
76
+ rhs[out_id] = mxCreateDoubleMatrix(model_->nr_class, 1, mxREAL);
77
+ ptr = mxGetPr(rhs[out_id]);
78
+ for(i = 0; i < model_->nr_class; i++)
79
+ ptr[i] = model_->label[i];
80
+ }
81
+ else
82
+ rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL);
83
+ out_id++;
84
+
85
+ // w
86
+ rhs[out_id] = mxCreateDoubleMatrix(nr_w, w_size, mxREAL);
87
+ ptr = mxGetPr(rhs[out_id]);
88
+ for(i = 0; i < w_size*nr_w; i++)
89
+ ptr[i]=model_->w[i];
90
+ out_id++;
91
+
92
+ /* Create a struct matrix contains NUM_OF_RETURN_FIELD fields */
93
+ return_model = mxCreateStructMatrix(1, 1, NUM_OF_RETURN_FIELD, field_names);
94
+
95
+ /* Fill struct matrix with input arguments */
96
+ for(i = 0; i < NUM_OF_RETURN_FIELD; i++)
97
+ mxSetField(return_model,0,field_names[i],mxDuplicateArray(rhs[i]));
98
+ /* return */
99
+ plhs[0] = return_model;
100
+ mxFree(rhs);
101
+
102
+ return NULL;
103
+ }
104
+
105
+ const char *matlab_matrix_to_model(struct model *model_, const mxArray *matlab_struct)
106
+ {
107
+ int i, num_of_fields;
108
+ int nr_w;
109
+ double *ptr;
110
+ int id = 0;
111
+ int n, w_size;
112
+ mxArray **rhs;
113
+
114
+ num_of_fields = mxGetNumberOfFields(matlab_struct);
115
+ rhs = (mxArray **) mxMalloc(sizeof(mxArray *)*num_of_fields);
116
+
117
+ for(i=0;i<num_of_fields;i++)
118
+ rhs[i] = mxGetFieldByNumber(matlab_struct, 0, i);
119
+
120
+ model_->nr_class=0;
121
+ nr_w=0;
122
+ model_->nr_feature=0;
123
+ model_->w=NULL;
124
+ model_->label=NULL;
125
+
126
+ // Parameters
127
+ ptr = mxGetPr(rhs[id]);
128
+ model_->param.solver_type = (int)ptr[0];
129
+ id++;
130
+
131
+ // nr_class
132
+ ptr = mxGetPr(rhs[id]);
133
+ model_->nr_class = (int)ptr[0];
134
+ id++;
135
+
136
+ if(model_->nr_class==2 && model_->param.solver_type != MCSVM_CS)
137
+ nr_w=1;
138
+ else
139
+ nr_w=model_->nr_class;
140
+
141
+ // nr_feature
142
+ ptr = mxGetPr(rhs[id]);
143
+ model_->nr_feature = (int)ptr[0];
144
+ id++;
145
+
146
+ // bias
147
+ ptr = mxGetPr(rhs[id]);
148
+ model_->bias = (int)ptr[0];
149
+ id++;
150
+
151
+ if(model_->bias>=0)
152
+ n=model_->nr_feature+1;
153
+ else
154
+ n=model_->nr_feature;
155
+ w_size = n;
156
+
157
+ // Label
158
+ if(mxIsEmpty(rhs[id]) == 0)
159
+ {
160
+ model_->label = Malloc(int, model_->nr_class);
161
+ ptr = mxGetPr(rhs[id]);
162
+ for(i=0;i<model_->nr_class;i++)
163
+ model_->label[i] = (int)ptr[i];
164
+ }
165
+ id++;
166
+
167
+ ptr = mxGetPr(rhs[id]);
168
+ model_->w=Malloc(double, w_size*nr_w);
169
+ for(i = 0; i < w_size*nr_w; i++)
170
+ model_->w[i]=ptr[i];
171
+ id++;
172
+ mxFree(rhs);
173
+
174
+ return NULL;
175
+ }
176
+
@@ -0,0 +1,2 @@
1
+ const char *model_to_matlab_structure(mxArray *plhs[], struct model *model_);
2
+ const char *matlab_matrix_to_model(struct model *model_, const mxArray *matlab_struct);
@@ -0,0 +1,21 @@
1
% This make.m is for MATLAB and OCTAVE under Windows, Mac, and Unix.
% It builds the libsvmread, libsvmwrite, train and predict MEX files.

try
	Type = ver;
	if(strcmp(Type(1).Name, 'Octave') == 1)
		% This part is for OCTAVE
		mex libsvmread.c
		mex libsvmwrite.c
		mex train.c linear_model_matlab.c ../linear.cpp ../tron.cpp ../blas/*.c
		mex predict.c linear_model_matlab.c ../linear.cpp ../tron.cpp ../blas/*.c
	else
		% This part is for MATLAB
		% Add -largeArrayDims on 64-bit machines of MATLAB
		mex CFLAGS="\$CFLAGS -std=c99" -largeArrayDims libsvmread.c
		mex CFLAGS="\$CFLAGS -std=c99" -largeArrayDims libsvmwrite.c
		mex CFLAGS="\$CFLAGS -std=c99" -largeArrayDims train.c linear_model_matlab.c ../linear.cpp ../tron.cpp "../blas/*.c"
		mex CFLAGS="\$CFLAGS -std=c99" -largeArrayDims predict.c linear_model_matlab.c ../linear.cpp ../tron.cpp "../blas/*.c"
	end
catch err
	% Report the actual build failure instead of swallowing it.
	fprintf('Error: %s\n', err.message);
	fprintf('If make.m fails, please check README about detailed instructions.\n');
end
@@ -0,0 +1,331 @@
1
+ #include <stdio.h>
2
+ #include <stdlib.h>
3
+ #include <string.h>
4
+ #include "../linear.h"
5
+
6
+ #include "mex.h"
7
+ #include "linear_model_matlab.h"
8
+
9
+ #ifdef MX_API_VER
10
+ #if MX_API_VER < 0x07030000
11
+ typedef int mwIndex;
12
+ #endif
13
+ #endif
14
+
15
+ #define CMD_LEN 2048
16
+
17
+ #define Malloc(type,n) (type *)malloc((n)*sizeof(type))
18
+
19
/*
 * printf-compatible sink used in quiet mode (-q): discards its arguments.
 * It now returns 0 explicitly. The old empty body fell off the end of a
 * non-void function, which is undefined if the result is ever used.
 */
int print_null(const char *s,...) { (void) s; return 0; }

/* Message printer used for results; mexPrintf by default, print_null under -q. */
int (*info)(const char *fmt,...);

/* Non-zero when the trailing 'col' argument says instances are columns. */
int col_format_flag;
23
+
24
+ void read_sparse_instance(const mxArray *prhs, int index, struct feature_node *x, int feature_number, double bias)
25
+ {
26
+ int i, j, low, high;
27
+ mwIndex *ir, *jc;
28
+ double *samples;
29
+
30
+ ir = mxGetIr(prhs);
31
+ jc = mxGetJc(prhs);
32
+ samples = mxGetPr(prhs);
33
+
34
+ // each column is one instance
35
+ j = 0;
36
+ low = (int) jc[index], high = (int) jc[index+1];
37
+ for(i=low; i<high && (int) (ir[i])<feature_number; i++)
38
+ {
39
+ x[j].index = (int) ir[i]+1;
40
+ x[j].value = samples[i];
41
+ j++;
42
+ }
43
+ if(bias>=0)
44
+ {
45
+ x[j].index = feature_number+1;
46
+ x[j].value = bias;
47
+ j++;
48
+ }
49
+ x[j].index = -1;
50
+ }
51
+
52
+ static void fake_answer(mxArray *plhs[])
53
+ {
54
+ plhs[0] = mxCreateDoubleMatrix(0, 0, mxREAL);
55
+ plhs[1] = mxCreateDoubleMatrix(0, 0, mxREAL);
56
+ plhs[2] = mxCreateDoubleMatrix(0, 0, mxREAL);
57
+ }
58
+
59
+ void do_predict(mxArray *plhs[], const mxArray *prhs[], struct model *model_, const int predict_probability_flag)
60
+ {
61
+ int label_vector_row_num, label_vector_col_num;
62
+ int feature_number, testing_instance_number;
63
+ int instance_index;
64
+ double *ptr_label, *ptr_predict_label;
65
+ double *ptr_prob_estimates, *ptr_dec_values, *ptr;
66
+ struct feature_node *x;
67
+ mxArray *pplhs[1]; // instance sparse matrix in row format
68
+
69
+ int correct = 0;
70
+ int total = 0;
71
+ double error = 0;
72
+ double sump = 0, sumt = 0, sumpp = 0, sumtt = 0, sumpt = 0;
73
+
74
+ int nr_class=get_nr_class(model_);
75
+ int nr_w;
76
+ double *prob_estimates=NULL;
77
+
78
+ if(nr_class==2 && model_->param.solver_type!=MCSVM_CS)
79
+ nr_w=1;
80
+ else
81
+ nr_w=nr_class;
82
+
83
+ // prhs[1] = testing instance matrix
84
+ feature_number = get_nr_feature(model_);
85
+ testing_instance_number = (int) mxGetM(prhs[1]);
86
+ if(col_format_flag)
87
+ {
88
+ feature_number = (int) mxGetM(prhs[1]);
89
+ testing_instance_number = (int) mxGetN(prhs[1]);
90
+ }
91
+
92
+ label_vector_row_num = (int) mxGetM(prhs[0]);
93
+ label_vector_col_num = (int) mxGetN(prhs[0]);
94
+
95
+ if(label_vector_row_num!=testing_instance_number)
96
+ {
97
+ mexPrintf("Length of label vector does not match # of instances.\n");
98
+ fake_answer(plhs);
99
+ return;
100
+ }
101
+ if(label_vector_col_num!=1)
102
+ {
103
+ mexPrintf("label (1st argument) should be a vector (# of column is 1).\n");
104
+ fake_answer(plhs);
105
+ return;
106
+ }
107
+
108
+ ptr_label = mxGetPr(prhs[0]);
109
+
110
+ // transpose instance matrix
111
+ if(col_format_flag)
112
+ pplhs[0] = (mxArray *)prhs[1];
113
+ else
114
+ {
115
+ mxArray *pprhs[1];
116
+ pprhs[0] = mxDuplicateArray(prhs[1]);
117
+ if(mexCallMATLAB(1, pplhs, 1, pprhs, "transpose"))
118
+ {
119
+ mexPrintf("Error: cannot transpose testing instance matrix\n");
120
+ fake_answer(plhs);
121
+ return;
122
+ }
123
+ }
124
+
125
+
126
+ prob_estimates = Malloc(double, nr_class);
127
+
128
+ plhs[0] = mxCreateDoubleMatrix(testing_instance_number, 1, mxREAL);
129
+ if(predict_probability_flag)
130
+ plhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_class, mxREAL);
131
+ else
132
+ plhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_w, mxREAL);
133
+
134
+ ptr_predict_label = mxGetPr(plhs[0]);
135
+ ptr_prob_estimates = mxGetPr(plhs[2]);
136
+ ptr_dec_values = mxGetPr(plhs[2]);
137
+ x = Malloc(struct feature_node, feature_number+2);
138
+ for(instance_index=0;instance_index<testing_instance_number;instance_index++)
139
+ {
140
+ int i;
141
+ double target_label, predict_label;
142
+
143
+ target_label = ptr_label[instance_index];
144
+
145
+ // prhs[1] and prhs[1]^T are sparse
146
+ read_sparse_instance(pplhs[0], instance_index, x, feature_number, model_->bias);
147
+
148
+ if(predict_probability_flag)
149
+ {
150
+ predict_label = predict_probability(model_, x, prob_estimates);
151
+ ptr_predict_label[instance_index] = predict_label;
152
+ for(i=0;i<nr_class;i++)
153
+ ptr_prob_estimates[instance_index + i * testing_instance_number] = prob_estimates[i];
154
+ }
155
+ else
156
+ {
157
+ double *dec_values = Malloc(double, nr_class);
158
+ predict_label = predict_values(model_, x, dec_values);
159
+ ptr_predict_label[instance_index] = predict_label;
160
+
161
+ for(i=0;i<nr_w;i++)
162
+ ptr_dec_values[instance_index + i * testing_instance_number] = dec_values[i];
163
+ free(dec_values);
164
+ }
165
+
166
+ if(predict_label == target_label)
167
+ ++correct;
168
+ error += (predict_label-target_label)*(predict_label-target_label);
169
+ sump += predict_label;
170
+ sumt += target_label;
171
+ sumpp += predict_label*predict_label;
172
+ sumtt += target_label*target_label;
173
+ sumpt += predict_label*target_label;
174
+
175
+ ++total;
176
+ }
177
+
178
+ if(model_->param.solver_type==L2R_L2LOSS_SVR ||
179
+ model_->param.solver_type==L2R_L1LOSS_SVR_DUAL ||
180
+ model_->param.solver_type==L2R_L2LOSS_SVR_DUAL)
181
+ {
182
+ info("Mean squared error = %g (regression)\n",error/total);
183
+ info("Squared correlation coefficient = %g (regression)\n",
184
+ ((total*sumpt-sump*sumt)*(total*sumpt-sump*sumt))/
185
+ ((total*sumpp-sump*sump)*(total*sumtt-sumt*sumt))
186
+ );
187
+ }
188
+ else
189
+ info("Accuracy = %g%% (%d/%d)\n", (double) correct/total*100,correct,total);
190
+
191
+ // return accuracy, mean squared error, squared correlation coefficient
192
+ plhs[1] = mxCreateDoubleMatrix(3, 1, mxREAL);
193
+ ptr = mxGetPr(plhs[1]);
194
+ ptr[0] = (double)correct/total*100;
195
+ ptr[1] = error/total;
196
+ ptr[2] = ((total*sumpt-sump*sumt)*(total*sumpt-sump*sumt))/
197
+ ((total*sumpp-sump*sump)*(total*sumtt-sumt*sumt));
198
+
199
+ free(x);
200
+ if(prob_estimates != NULL)
201
+ free(prob_estimates);
202
+ }
203
+
204
/* Print the calling convention, options and outputs of predict. */
void exit_with_help()
{
	mexPrintf(
	"Usage: [predicted_label, accuracy, decision_values/prob_estimates] = predict(testing_label_vector, testing_instance_matrix, model, 'liblinear_options','col')\n"
	"liblinear_options:\n"
	"-b probability_estimates: whether to output probability estimates, 0 or 1 (default 0); currently for logistic regression only\n"
	"-q quiet mode (no outputs)\n"
	"col: if 'col' is set testing_instance_matrix is parsed in column format, otherwise is in row format\n"
	"Returns:\n"
	" predicted_label: prediction output vector.\n"
	" accuracy: a vector with accuracy, mean squared error, squared correlation coefficient.\n"
	" decision_values/prob_estimates: decision values, or probability estimates if -b 1 is specified.\n"
	);
}
218
+
219
+ void mexFunction( int nlhs, mxArray *plhs[],
220
+ int nrhs, const mxArray *prhs[] )
221
+ {
222
+ int prob_estimate_flag = 0;
223
+ struct model *model_;
224
+ char cmd[CMD_LEN];
225
+ info = &mexPrintf;
226
+ col_format_flag = 0;
227
+
228
+ if(nrhs > 5 || nrhs < 3)
229
+ {
230
+ exit_with_help();
231
+ fake_answer(plhs);
232
+ return;
233
+ }
234
+ if(nrhs == 5)
235
+ {
236
+ mxGetString(prhs[4], cmd, mxGetN(prhs[4])+1);
237
+ if(strcmp(cmd, "col") == 0)
238
+ {
239
+ col_format_flag = 1;
240
+ }
241
+ }
242
+
243
+ if(!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1])) {
244
+ mexPrintf("Error: label vector and instance matrix must be double\n");
245
+ fake_answer(plhs);
246
+ return;
247
+ }
248
+
249
+ if(mxIsStruct(prhs[2]))
250
+ {
251
+ const char *error_msg;
252
+
253
+ // parse options
254
+ if(nrhs>=4)
255
+ {
256
+ int i, argc = 1;
257
+ char *argv[CMD_LEN/2];
258
+
259
+ // put options in argv[]
260
+ mxGetString(prhs[3], cmd, mxGetN(prhs[3]) + 1);
261
+ if((argv[argc] = strtok(cmd, " ")) != NULL)
262
+ while((argv[++argc] = strtok(NULL, " ")) != NULL)
263
+ ;
264
+
265
+ for(i=1;i<argc;i++)
266
+ {
267
+ if(argv[i][0] != '-') break;
268
+ ++i;
269
+ if(i>=argc && argv[i-1][1] != 'q')
270
+ {
271
+ exit_with_help();
272
+ fake_answer(plhs);
273
+ return;
274
+ }
275
+ switch(argv[i-1][1])
276
+ {
277
+ case 'b':
278
+ prob_estimate_flag = atoi(argv[i]);
279
+ break;
280
+ case 'q':
281
+ info = &print_null;
282
+ i--;
283
+ break;
284
+ default:
285
+ mexPrintf("unknown option\n");
286
+ exit_with_help();
287
+ fake_answer(plhs);
288
+ return;
289
+ }
290
+ }
291
+ }
292
+
293
+ model_ = Malloc(struct model, 1);
294
+ error_msg = matlab_matrix_to_model(model_, prhs[2]);
295
+ if(error_msg)
296
+ {
297
+ mexPrintf("Error: can't read model: %s\n", error_msg);
298
+ free_and_destroy_model(&model_);
299
+ fake_answer(plhs);
300
+ return;
301
+ }
302
+
303
+ if(prob_estimate_flag)
304
+ {
305
+ if(!check_probability_model(model_))
306
+ {
307
+ mexPrintf("probability output is only supported for logistic regression\n");
308
+ prob_estimate_flag=0;
309
+ }
310
+ }
311
+
312
+ if(mxIsSparse(prhs[1]))
313
+ do_predict(plhs, prhs, model_, prob_estimate_flag);
314
+ else
315
+ {
316
+ mexPrintf("Testing_instance_matrix must be sparse; "
317
+ "use sparse(Testing_instance_matrix) first\n");
318
+ fake_answer(plhs);
319
+ }
320
+
321
+ // destroy model_
322
+ free_and_destroy_model(&model_);
323
+ }
324
+ else
325
+ {
326
+ mexPrintf("model file should be a struct array\n");
327
+ fake_answer(plhs);
328
+ }
329
+
330
+ return;
331
+ }