eluka 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (73) hide show
  1. data/.document +5 -0
  2. data/DOCUMENTATION_STANDARDS +39 -0
  3. data/Gemfile +13 -0
  4. data/Gemfile.lock +20 -0
  5. data/LICENSE.txt +20 -0
  6. data/README.rdoc +19 -0
  7. data/Rakefile +69 -0
  8. data/VERSION +1 -0
  9. data/examples/example.rb +59 -0
  10. data/ext/libsvm/COPYRIGHT +31 -0
  11. data/ext/libsvm/FAQ.html +1749 -0
  12. data/ext/libsvm/Makefile +25 -0
  13. data/ext/libsvm/Makefile.win +33 -0
  14. data/ext/libsvm/README +733 -0
  15. data/ext/libsvm/extconf.rb +1 -0
  16. data/ext/libsvm/heart_scale +270 -0
  17. data/ext/libsvm/java/Makefile +25 -0
  18. data/ext/libsvm/java/libsvm.jar +0 -0
  19. data/ext/libsvm/java/libsvm/svm.java +2776 -0
  20. data/ext/libsvm/java/libsvm/svm.m4 +2776 -0
  21. data/ext/libsvm/java/libsvm/svm_model.java +21 -0
  22. data/ext/libsvm/java/libsvm/svm_node.java +6 -0
  23. data/ext/libsvm/java/libsvm/svm_parameter.java +47 -0
  24. data/ext/libsvm/java/libsvm/svm_print_interface.java +5 -0
  25. data/ext/libsvm/java/libsvm/svm_problem.java +7 -0
  26. data/ext/libsvm/java/svm_predict.java +163 -0
  27. data/ext/libsvm/java/svm_scale.java +350 -0
  28. data/ext/libsvm/java/svm_toy.java +471 -0
  29. data/ext/libsvm/java/svm_train.java +318 -0
  30. data/ext/libsvm/java/test_applet.html +1 -0
  31. data/ext/libsvm/python/Makefile +4 -0
  32. data/ext/libsvm/python/README +331 -0
  33. data/ext/libsvm/python/svm.py +259 -0
  34. data/ext/libsvm/python/svmutil.py +242 -0
  35. data/ext/libsvm/svm-predict.c +226 -0
  36. data/ext/libsvm/svm-scale.c +353 -0
  37. data/ext/libsvm/svm-toy/gtk/Makefile +22 -0
  38. data/ext/libsvm/svm-toy/gtk/callbacks.cpp +423 -0
  39. data/ext/libsvm/svm-toy/gtk/callbacks.h +54 -0
  40. data/ext/libsvm/svm-toy/gtk/interface.c +164 -0
  41. data/ext/libsvm/svm-toy/gtk/interface.h +14 -0
  42. data/ext/libsvm/svm-toy/gtk/main.c +23 -0
  43. data/ext/libsvm/svm-toy/gtk/svm-toy.glade +238 -0
  44. data/ext/libsvm/svm-toy/qt/Makefile +17 -0
  45. data/ext/libsvm/svm-toy/qt/svm-toy.cpp +413 -0
  46. data/ext/libsvm/svm-toy/windows/svm-toy.cpp +456 -0
  47. data/ext/libsvm/svm-train.c +376 -0
  48. data/ext/libsvm/svm.cpp +3060 -0
  49. data/ext/libsvm/svm.def +19 -0
  50. data/ext/libsvm/svm.h +105 -0
  51. data/ext/libsvm/svm.o +0 -0
  52. data/ext/libsvm/tools/README +149 -0
  53. data/ext/libsvm/tools/checkdata.py +108 -0
  54. data/ext/libsvm/tools/easy.py +79 -0
  55. data/ext/libsvm/tools/grid.py +359 -0
  56. data/ext/libsvm/tools/subset.py +146 -0
  57. data/ext/libsvm/windows/libsvm.dll +0 -0
  58. data/ext/libsvm/windows/svm-predict.exe +0 -0
  59. data/ext/libsvm/windows/svm-scale.exe +0 -0
  60. data/ext/libsvm/windows/svm-toy.exe +0 -0
  61. data/ext/libsvm/windows/svm-train.exe +0 -0
  62. data/lib/eluka.rb +10 -0
  63. data/lib/eluka/bijection.rb +23 -0
  64. data/lib/eluka/data_point.rb +36 -0
  65. data/lib/eluka/document.rb +47 -0
  66. data/lib/eluka/feature_vector.rb +86 -0
  67. data/lib/eluka/features.rb +31 -0
  68. data/lib/eluka/model.rb +129 -0
  69. data/lib/fselect.rb +321 -0
  70. data/lib/grid.rb +25 -0
  71. data/test/helper.rb +18 -0
  72. data/test/test_eluka.rb +7 -0
  73. metadata +214 -0
@@ -0,0 +1,376 @@
1
+ #include <stdio.h>
2
+ #include <stdlib.h>
3
+ #include <string.h>
4
+ #include <ctype.h>
5
+ #include <errno.h>
6
+ #include "svm.h"
7
+ #define Malloc(type,n) (type *)malloc((n)*sizeof(type))
8
+
9
/* No-op print callback: receives a message and discards it. */
void print_null(const char *s)
{
	(void)s;
}
10
+
11
/*
 * Write the svm-train usage text to stdout and terminate the process with
 * exit status 1. This function does not return.
 */
void exit_with_help()
{
	static const char usage[] =
	"Usage: svm-train [options] training_set_file [model_file]\n"
	"options:\n"
	"-s svm_type : set type of SVM (default 0)\n"
	" 0 -- C-SVC\n"
	" 1 -- nu-SVC\n"
	" 2 -- one-class SVM\n"
	" 3 -- epsilon-SVR\n"
	" 4 -- nu-SVR\n"
	"-t kernel_type : set type of kernel function (default 2)\n"
	" 0 -- linear: u'*v\n"
	" 1 -- polynomial: (gamma*u'*v + coef0)^degree\n"
	" 2 -- radial basis function: exp(-gamma*|u-v|^2)\n"
	" 3 -- sigmoid: tanh(gamma*u'*v + coef0)\n"
	" 4 -- precomputed kernel (kernel values in training_set_file)\n"
	"-d degree : set degree in kernel function (default 3)\n"
	"-g gamma : set gamma in kernel function (default 1/num_features)\n"
	"-r coef0 : set coef0 in kernel function (default 0)\n"
	"-c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR (default 1)\n"
	"-n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR (default 0.5)\n"
	"-p epsilon : set the epsilon in loss function of epsilon-SVR (default 0.1)\n"
	"-m cachesize : set cache memory size in MB (default 100)\n"
	"-e epsilon : set tolerance of termination criterion (default 0.001)\n"
	"-h shrinking : whether to use the shrinking heuristics, 0 or 1 (default 1)\n"
	"-b probability_estimates : whether to train a SVC or SVR model for probability estimates, 0 or 1 (default 0)\n"
	"-wi weight : set the parameter C of class i to weight*C, for C-SVC (default 1)\n"
	"-v n: n-fold cross validation mode\n"
	"-q : quiet mode (no outputs)\n";

	/* The text holds no conversion specifiers, so it is written verbatim. */
	fputs(usage, stdout);
	exit(1);
}
44
+
45
/*
 * Report a malformed training-file line on stderr and terminate the process
 * with exit status 1.
 *
 * line_num: the line number printed in the message.
 */
void exit_input_error(int line_num)
{
	const int bad_line = line_num;

	fprintf(stderr, "Wrong input format at line %d\n", bad_line);
	exit(1);
}
50
+
51
+ void parse_command_line(int argc, char **argv, char *input_file_name, char *model_file_name);
52
+ void read_problem(const char *filename);
53
+ void do_cross_validation();
54
+
55
+ struct svm_parameter param; // set by parse_command_line
56
+ struct svm_problem prob; // set by read_problem
57
+ struct svm_model *model;
58
+ struct svm_node *x_space;
59
+ int cross_validation;
60
+ int nr_fold;
61
+
62
static char *line = NULL;	/* shared line buffer; must be allocated before the first readline() call */
static int max_line_len;	/* current capacity of `line` in bytes */

/*
 * Read one complete line from `input` into the shared buffer `line`,
 * doubling the buffer until the trailing '\n' (or end of file) is reached.
 *
 * Returns `line` on success, or NULL when nothing could be read.
 * If the buffer cannot be grown the process prints a message and exits with
 * status 1; the previous code stored realloc's NULL result in `line` and then
 * dereferenced it.
 */
static char* readline(FILE *input)
{
	int len;

	if(fgets(line,max_line_len,input) == NULL)
		return NULL;

	while(strrchr(line,'\n') == NULL)
	{
		/* grow through a temporary so `line` stays valid if realloc fails */
		char *grown = (char *) realloc(line,(size_t)max_line_len*2);
		if(grown == NULL)
		{
			fprintf(stderr,"can't allocate memory for input line\n");
			exit(1);
		}
		line = grown;
		max_line_len *= 2;

		/* continue reading right after the part already stored */
		len = (int) strlen(line);
		if(fgets(line+len,max_line_len-len,input) == NULL)
			break;
	}
	return line;
}
82
+
83
+ int main(int argc, char **argv)
84
+ {
85
+ char input_file_name[1024];
86
+ char model_file_name[1024];
87
+ const char *error_msg;
88
+
89
+ parse_command_line(argc, argv, input_file_name, model_file_name);
90
+ read_problem(input_file_name);
91
+ error_msg = svm_check_parameter(&prob,&param);
92
+
93
+ if(error_msg)
94
+ {
95
+ fprintf(stderr,"Error: %s\n",error_msg);
96
+ exit(1);
97
+ }
98
+
99
+ if(cross_validation)
100
+ {
101
+ do_cross_validation();
102
+ }
103
+ else
104
+ {
105
+ model = svm_train(&prob,&param);
106
+ if(svm_save_model(model_file_name,model))
107
+ {
108
+ fprintf(stderr, "can't save model to file %s\n", model_file_name);
109
+ exit(1);
110
+ }
111
+ svm_free_and_destroy_model(&model);
112
+ }
113
+ svm_destroy_param(&param);
114
+ free(prob.y);
115
+ free(prob.x);
116
+ free(x_space);
117
+ free(line);
118
+
119
+ return 0;
120
+ }
121
+
122
+ void do_cross_validation()
123
+ {
124
+ int i;
125
+ int total_correct = 0;
126
+ double total_error = 0;
127
+ double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
128
+ double *target = Malloc(double,prob.l);
129
+
130
+ svm_cross_validation(&prob,&param,nr_fold,target);
131
+ if(param.svm_type == EPSILON_SVR ||
132
+ param.svm_type == NU_SVR)
133
+ {
134
+ for(i=0;i<prob.l;i++)
135
+ {
136
+ double y = prob.y[i];
137
+ double v = target[i];
138
+ total_error += (v-y)*(v-y);
139
+ sumv += v;
140
+ sumy += y;
141
+ sumvv += v*v;
142
+ sumyy += y*y;
143
+ sumvy += v*y;
144
+ }
145
+ printf("Cross Validation Mean squared error = %g\n",total_error/prob.l);
146
+ printf("Cross Validation Squared correlation coefficient = %g\n",
147
+ ((prob.l*sumvy-sumv*sumy)*(prob.l*sumvy-sumv*sumy))/
148
+ ((prob.l*sumvv-sumv*sumv)*(prob.l*sumyy-sumy*sumy))
149
+ );
150
+ }
151
+ else
152
+ {
153
+ for(i=0;i<prob.l;i++)
154
+ if(target[i] == prob.y[i])
155
+ ++total_correct;
156
+ printf("Cross Validation Accuracy = %g%%\n",100.0*total_correct/prob.l);
157
+ }
158
+ free(target);
159
+ }
160
+
161
+ void parse_command_line(int argc, char **argv, char *input_file_name, char *model_file_name)
162
+ {
163
+ int i;
164
+ void (*print_func)(const char*) = NULL; // default printing to stdout
165
+
166
+ // default values
167
+ param.svm_type = C_SVC;
168
+ param.kernel_type = RBF;
169
+ param.degree = 3;
170
+ param.gamma = 0; // 1/num_features
171
+ param.coef0 = 0;
172
+ param.nu = 0.5;
173
+ param.cache_size = 100;
174
+ param.C = 1;
175
+ param.eps = 1e-3;
176
+ param.p = 0.1;
177
+ param.shrinking = 1;
178
+ param.probability = 0;
179
+ param.nr_weight = 0;
180
+ param.weight_label = NULL;
181
+ param.weight = NULL;
182
+ cross_validation = 0;
183
+
184
+ // parse options
185
+ for(i=1;i<argc;i++)
186
+ {
187
+ if(argv[i][0] != '-') break;
188
+ if(++i>=argc)
189
+ exit_with_help();
190
+ switch(argv[i-1][1])
191
+ {
192
+ case 's':
193
+ param.svm_type = atoi(argv[i]);
194
+ break;
195
+ case 't':
196
+ param.kernel_type = atoi(argv[i]);
197
+ break;
198
+ case 'd':
199
+ param.degree = atoi(argv[i]);
200
+ break;
201
+ case 'g':
202
+ param.gamma = atof(argv[i]);
203
+ break;
204
+ case 'r':
205
+ param.coef0 = atof(argv[i]);
206
+ break;
207
+ case 'n':
208
+ param.nu = atof(argv[i]);
209
+ break;
210
+ case 'm':
211
+ param.cache_size = atof(argv[i]);
212
+ break;
213
+ case 'c':
214
+ param.C = atof(argv[i]);
215
+ break;
216
+ case 'e':
217
+ param.eps = atof(argv[i]);
218
+ break;
219
+ case 'p':
220
+ param.p = atof(argv[i]);
221
+ break;
222
+ case 'h':
223
+ param.shrinking = atoi(argv[i]);
224
+ break;
225
+ case 'b':
226
+ param.probability = atoi(argv[i]);
227
+ break;
228
+ case 'q':
229
+ print_func = &print_null;
230
+ i--;
231
+ break;
232
+ case 'v':
233
+ cross_validation = 1;
234
+ nr_fold = atoi(argv[i]);
235
+ if(nr_fold < 2)
236
+ {
237
+ fprintf(stderr,"n-fold cross validation: n must >= 2\n");
238
+ exit_with_help();
239
+ }
240
+ break;
241
+ case 'w':
242
+ ++param.nr_weight;
243
+ param.weight_label = (int *)realloc(param.weight_label,sizeof(int)*param.nr_weight);
244
+ param.weight = (double *)realloc(param.weight,sizeof(double)*param.nr_weight);
245
+ param.weight_label[param.nr_weight-1] = atoi(&argv[i-1][2]);
246
+ param.weight[param.nr_weight-1] = atof(argv[i]);
247
+ break;
248
+ default:
249
+ fprintf(stderr,"Unknown option: -%c\n", argv[i-1][1]);
250
+ exit_with_help();
251
+ }
252
+ }
253
+
254
+ svm_set_print_string_function(print_func);
255
+
256
+ // determine filenames
257
+
258
+ if(i>=argc)
259
+ exit_with_help();
260
+
261
+ strcpy(input_file_name, argv[i]);
262
+
263
+ if(i<argc-1)
264
+ strcpy(model_file_name,argv[i+1]);
265
+ else
266
+ {
267
+ char *p = strrchr(argv[i],'/');
268
+ if(p==NULL)
269
+ p = argv[i];
270
+ else
271
+ ++p;
272
+ sprintf(model_file_name,"%s.model",p);
273
+ }
274
+ }
275
+
276
+ // read in a problem (in svmlight format)
277
+
278
+ void read_problem(const char *filename)
279
+ {
280
+ int elements, max_index, inst_max_index, i, j;
281
+ FILE *fp = fopen(filename,"r");
282
+ char *endptr;
283
+ char *idx, *val, *label;
284
+
285
+ if(fp == NULL)
286
+ {
287
+ fprintf(stderr,"can't open input file %s\n",filename);
288
+ exit(1);
289
+ }
290
+
291
+ prob.l = 0;
292
+ elements = 0;
293
+
294
+ max_line_len = 1024;
295
+ line = Malloc(char,max_line_len);
296
+ while(readline(fp)!=NULL)
297
+ {
298
+ char *p = strtok(line," \t"); // label
299
+
300
+ // features
301
+ while(1)
302
+ {
303
+ p = strtok(NULL," \t");
304
+ if(p == NULL || *p == '\n') // check '\n' as ' ' may be after the last feature
305
+ break;
306
+ ++elements;
307
+ }
308
+ ++elements;
309
+ ++prob.l;
310
+ }
311
+ rewind(fp);
312
+
313
+ prob.y = Malloc(double,prob.l);
314
+ prob.x = Malloc(struct svm_node *,prob.l);
315
+ x_space = Malloc(struct svm_node,elements);
316
+
317
+ max_index = 0;
318
+ j=0;
319
+ for(i=0;i<prob.l;i++)
320
+ {
321
+ inst_max_index = -1; // strtol gives 0 if wrong format, and precomputed kernel has <index> start from 0
322
+ readline(fp);
323
+ prob.x[i] = &x_space[j];
324
+ label = strtok(line," \t");
325
+ prob.y[i] = strtod(label,&endptr);
326
+ if(endptr == label)
327
+ exit_input_error(i+1);
328
+
329
+ while(1)
330
+ {
331
+ idx = strtok(NULL,":");
332
+ val = strtok(NULL," \t");
333
+
334
+ if(val == NULL)
335
+ break;
336
+
337
+ errno = 0;
338
+ x_space[j].index = (int) strtol(idx,&endptr,10);
339
+ if(endptr == idx || errno != 0 || *endptr != '\0' || x_space[j].index <= inst_max_index)
340
+ exit_input_error(i+1);
341
+ else
342
+ inst_max_index = x_space[j].index;
343
+
344
+ errno = 0;
345
+ x_space[j].value = strtod(val,&endptr);
346
+ if(endptr == val || errno != 0 || (*endptr != '\0' && !isspace(*endptr)))
347
+ exit_input_error(i+1);
348
+
349
+ ++j;
350
+ }
351
+
352
+ if(inst_max_index > max_index)
353
+ max_index = inst_max_index;
354
+ x_space[j++].index = -1;
355
+ }
356
+
357
+ if(param.gamma == 0 && max_index > 0)
358
+ param.gamma = 1.0/max_index;
359
+
360
+ if(param.kernel_type == PRECOMPUTED)
361
+ for(i=0;i<prob.l;i++)
362
+ {
363
+ if (prob.x[i][0].index != 0)
364
+ {
365
+ fprintf(stderr,"Wrong input format: first column must be 0:sample_serial_number\n");
366
+ exit(1);
367
+ }
368
+ if ((int)prob.x[i][0].value <= 0 || (int)prob.x[i][0].value > max_index)
369
+ {
370
+ fprintf(stderr,"Wrong input format: sample_serial_number out of range\n");
371
+ exit(1);
372
+ }
373
+ }
374
+
375
+ fclose(fp);
376
+ }
@@ -0,0 +1,3060 @@
1
+ #include <math.h>
2
+ #include <stdio.h>
3
+ #include <stdlib.h>
4
+ #include <ctype.h>
5
+ #include <float.h>
6
+ #include <string.h>
7
+ #include <stdarg.h>
8
+ #include "svm.h"
9
+ int libsvm_version = LIBSVM_VERSION;
10
+ typedef float Qfloat;
11
+ typedef signed char schar;
12
+ #ifndef min
13
// Smaller of two values; returns y when the two compare equal.
template <class T> static inline T min(T x,T y)
{
	if(x<y)
		return x;
	return y;
}
14
+ #endif
15
+ #ifndef max
16
// Larger of two values; returns y when the two compare equal.
template <class T> static inline T max(T x,T y)
{
	if(x>y)
		return x;
	return y;
}
17
+ #endif
18
// Exchange the values of x and y through a temporary copy.
template <class T> static inline void swap(T& x, T& y)
{
	T saved = x;
	x = y;
	y = saved;
}
19
// Allocate a new T[n] with new[], byte-copy n elements from src into it and
// store the result in dst. The caller releases it with delete[].
template <class S, class T> static inline void clone(T*& dst, S* src, int n)
{
	T *copy = new T[n];
	memcpy((void *)copy,(void *)src,sizeof(T)*n);
	dst = copy;
}
24
// Integer power by repeated squaring: returns base^times for times > 0,
// and 1.0 when times <= 0.
static inline double powi(double base, int times)
{
	double result = 1.0;
	double square = base;
	int remaining = times;

	while(remaining > 0)
	{
		// multiply in the current power of base when the low bit is set
		if(remaining % 2 == 1)
			result *= square;
		square = square * square;
		remaining /= 2;
	}
	return result;
}
35
+ #define INF HUGE_VAL
36
+ #define TAU 1e-12
37
+ #define Malloc(type,n) (type *)malloc((n)*sizeof(type))
38
+
39
// Default print callback: write the string to stdout and flush immediately.
static void print_string_stdout(const char *s)
{
	fputs(s,stdout);
	fflush(stdout);
}
// Callback through which info() emits its text; defaults to stdout.
static void (*svm_print_string) (const char *) = &print_string_stdout;
#if 1
// printf-style progress/diagnostic output routed through svm_print_string.
// The formatted text is limited to BUFSIZ-1 characters: vsnprintf truncates
// longer output, where the previous vsprintf call could write past buf.
static void info(const char *fmt,...)
{
	char buf[BUFSIZ];
	va_list ap;
	va_start(ap,fmt);
	vsnprintf(buf,sizeof(buf),fmt,ap);
	va_end(ap);
	(*svm_print_string)(buf);
}
#else
static void info(const char *fmt,...) {}
#endif
58
+
59
+ //
60
+ // Kernel Cache
61
+ //
62
+ // l is the number of total data items
63
+ // size is the cache size limit in bytes
64
+ //
65
+ class Cache
66
+ {
67
+ public:
68
+ Cache(int l,long int size);
69
+ ~Cache();
70
+
71
+ // request data [0,len)
72
+ // return some position p where [p,len) need to be filled
73
+ // (p >= len if nothing needs to be filled)
74
+ int get_data(const int index, Qfloat **data, int len);
75
+ void swap_index(int i, int j);
76
+ private:
77
+ int l;
78
+ long int size;
79
+ struct head_t
80
+ {
81
+ head_t *prev, *next; // a circular list
82
+ Qfloat *data;
83
+ int len; // data[0,len) is cached in this entry
84
+ };
85
+
86
+ head_t *head;
87
+ head_t lru_head;
88
+ void lru_delete(head_t *h);
89
+ void lru_insert(head_t *h);
90
+ };
91
+
92
+ Cache::Cache(int l_,long int size_):l(l_),size(size_)
93
+ {
94
+ head = (head_t *)calloc(l,sizeof(head_t)); // initialized to 0
95
+ size /= sizeof(Qfloat);
96
+ size -= l * sizeof(head_t) / sizeof(Qfloat);
97
+ size = max(size, 2 * (long int) l); // cache must be large enough for two columns
98
+ lru_head.next = lru_head.prev = &lru_head;
99
+ }
100
+
101
+ Cache::~Cache()
102
+ {
103
+ for(head_t *h = lru_head.next; h != &lru_head; h=h->next)
104
+ free(h->data);
105
+ free(head);
106
+ }
107
+
108
+ void Cache::lru_delete(head_t *h)
109
+ {
110
+ // delete from current location
111
+ h->prev->next = h->next;
112
+ h->next->prev = h->prev;
113
+ }
114
+
115
+ void Cache::lru_insert(head_t *h)
116
+ {
117
+ // insert to last position
118
+ h->next = &lru_head;
119
+ h->prev = lru_head.prev;
120
+ h->prev->next = h;
121
+ h->next->prev = h;
122
+ }
123
+
124
+ int Cache::get_data(const int index, Qfloat **data, int len)
125
+ {
126
+ head_t *h = &head[index];
127
+ if(h->len) lru_delete(h);
128
+ int more = len - h->len;
129
+
130
+ if(more > 0)
131
+ {
132
+ // free old space
133
+ while(size < more)
134
+ {
135
+ head_t *old = lru_head.next;
136
+ lru_delete(old);
137
+ free(old->data);
138
+ size += old->len;
139
+ old->data = 0;
140
+ old->len = 0;
141
+ }
142
+
143
+ // allocate new space
144
+ h->data = (Qfloat *)realloc(h->data,sizeof(Qfloat)*len);
145
+ size -= more;
146
+ swap(h->len,len);
147
+ }
148
+
149
+ lru_insert(h);
150
+ *data = h->data;
151
+ return len;
152
+ }
153
+
154
+ void Cache::swap_index(int i, int j)
155
+ {
156
+ if(i==j) return;
157
+
158
+ if(head[i].len) lru_delete(&head[i]);
159
+ if(head[j].len) lru_delete(&head[j]);
160
+ swap(head[i].data,head[j].data);
161
+ swap(head[i].len,head[j].len);
162
+ if(head[i].len) lru_insert(&head[i]);
163
+ if(head[j].len) lru_insert(&head[j]);
164
+
165
+ if(i>j) swap(i,j);
166
+ for(head_t *h = lru_head.next; h!=&lru_head; h=h->next)
167
+ {
168
+ if(h->len > i)
169
+ {
170
+ if(h->len > j)
171
+ swap(h->data[i],h->data[j]);
172
+ else
173
+ {
174
+ // give up
175
+ lru_delete(h);
176
+ free(h->data);
177
+ size += h->len;
178
+ h->data = 0;
179
+ h->len = 0;
180
+ }
181
+ }
182
+ }
183
+ }
184
+
185
+ //
186
+ // Kernel evaluation
187
+ //
188
+ // the static method k_function is for doing single kernel evaluation
189
+ // the constructor of Kernel prepares to calculate the l*l kernel matrix
190
+ // the member function get_Q is for getting one column from the Q Matrix
191
+ //
192
+ class QMatrix {
193
+ public:
194
+ virtual Qfloat *get_Q(int column, int len) const = 0;
195
+ virtual double *get_QD() const = 0;
196
+ virtual void swap_index(int i, int j) const = 0;
197
+ virtual ~QMatrix() {}
198
+ };
199
+
200
+ class Kernel: public QMatrix {
201
+ public:
202
+ Kernel(int l, svm_node * const * x, const svm_parameter& param);
203
+ virtual ~Kernel();
204
+
205
+ static double k_function(const svm_node *x, const svm_node *y,
206
+ const svm_parameter& param);
207
+ virtual Qfloat *get_Q(int column, int len) const = 0;
208
+ virtual double *get_QD() const = 0;
209
+ virtual void swap_index(int i, int j) const // no so const...
210
+ {
211
+ swap(x[i],x[j]);
212
+ if(x_square) swap(x_square[i],x_square[j]);
213
+ }
214
+ protected:
215
+
216
+ double (Kernel::*kernel_function)(int i, int j) const;
217
+
218
+ private:
219
+ const svm_node **x;
220
+ double *x_square;
221
+
222
+ // svm_parameter
223
+ const int kernel_type;
224
+ const int degree;
225
+ const double gamma;
226
+ const double coef0;
227
+
228
+ static double dot(const svm_node *px, const svm_node *py);
229
+ double kernel_linear(int i, int j) const
230
+ {
231
+ return dot(x[i],x[j]);
232
+ }
233
+ double kernel_poly(int i, int j) const
234
+ {
235
+ return powi(gamma*dot(x[i],x[j])+coef0,degree);
236
+ }
237
+ double kernel_rbf(int i, int j) const
238
+ {
239
+ return exp(-gamma*(x_square[i]+x_square[j]-2*dot(x[i],x[j])));
240
+ }
241
+ double kernel_sigmoid(int i, int j) const
242
+ {
243
+ return tanh(gamma*dot(x[i],x[j])+coef0);
244
+ }
245
+ double kernel_precomputed(int i, int j) const
246
+ {
247
+ return x[i][(int)(x[j][0].value)].value;
248
+ }
249
+ };
250
+
251
+ Kernel::Kernel(int l, svm_node * const * x_, const svm_parameter& param)
252
+ :kernel_type(param.kernel_type), degree(param.degree),
253
+ gamma(param.gamma), coef0(param.coef0)
254
+ {
255
+ switch(kernel_type)
256
+ {
257
+ case LINEAR:
258
+ kernel_function = &Kernel::kernel_linear;
259
+ break;
260
+ case POLY:
261
+ kernel_function = &Kernel::kernel_poly;
262
+ break;
263
+ case RBF:
264
+ kernel_function = &Kernel::kernel_rbf;
265
+ break;
266
+ case SIGMOID:
267
+ kernel_function = &Kernel::kernel_sigmoid;
268
+ break;
269
+ case PRECOMPUTED:
270
+ kernel_function = &Kernel::kernel_precomputed;
271
+ break;
272
+ }
273
+
274
+ clone(x,x_,l);
275
+
276
+ if(kernel_type == RBF)
277
+ {
278
+ x_square = new double[l];
279
+ for(int i=0;i<l;i++)
280
+ x_square[i] = dot(x[i],x[i]);
281
+ }
282
+ else
283
+ x_square = 0;
284
+ }
285
+
286
+ Kernel::~Kernel()
287
+ {
288
+ delete[] x;
289
+ delete[] x_square;
290
+ }
291
+
292
+ double Kernel::dot(const svm_node *px, const svm_node *py)
293
+ {
294
+ double sum = 0;
295
+ while(px->index != -1 && py->index != -1)
296
+ {
297
+ if(px->index == py->index)
298
+ {
299
+ sum += px->value * py->value;
300
+ ++px;
301
+ ++py;
302
+ }
303
+ else
304
+ {
305
+ if(px->index > py->index)
306
+ ++py;
307
+ else
308
+ ++px;
309
+ }
310
+ }
311
+ return sum;
312
+ }
313
+
314
+ double Kernel::k_function(const svm_node *x, const svm_node *y,
315
+ const svm_parameter& param)
316
+ {
317
+ switch(param.kernel_type)
318
+ {
319
+ case LINEAR:
320
+ return dot(x,y);
321
+ case POLY:
322
+ return powi(param.gamma*dot(x,y)+param.coef0,param.degree);
323
+ case RBF:
324
+ {
325
+ double sum = 0;
326
+ while(x->index != -1 && y->index !=-1)
327
+ {
328
+ if(x->index == y->index)
329
+ {
330
+ double d = x->value - y->value;
331
+ sum += d*d;
332
+ ++x;
333
+ ++y;
334
+ }
335
+ else
336
+ {
337
+ if(x->index > y->index)
338
+ {
339
+ sum += y->value * y->value;
340
+ ++y;
341
+ }
342
+ else
343
+ {
344
+ sum += x->value * x->value;
345
+ ++x;
346
+ }
347
+ }
348
+ }
349
+
350
+ while(x->index != -1)
351
+ {
352
+ sum += x->value * x->value;
353
+ ++x;
354
+ }
355
+
356
+ while(y->index != -1)
357
+ {
358
+ sum += y->value * y->value;
359
+ ++y;
360
+ }
361
+
362
+ return exp(-param.gamma*sum);
363
+ }
364
+ case SIGMOID:
365
+ return tanh(param.gamma*dot(x,y)+param.coef0);
366
+ case PRECOMPUTED: //x: test (validation), y: SV
367
+ return x[(int)(y->value)].value;
368
+ default:
369
+ return 0; // Unreachable
370
+ }
371
+ }
372
+
373
+ // An SMO algorithm in Fan et al., JMLR 6(2005), p. 1889--1918
374
+ // Solves:
375
+ //
376
+ // min 0.5(\alpha^T Q \alpha) + p^T \alpha
377
+ //
378
+ // y^T \alpha = \delta
379
+ // y_i = +1 or -1
380
+ // 0 <= alpha_i <= Cp for y_i = 1
381
+ // 0 <= alpha_i <= Cn for y_i = -1
382
+ //
383
+ // Given:
384
+ //
385
+ // Q, p, y, Cp, Cn, and an initial feasible point \alpha
386
+ // l is the size of vectors and matrices
387
+ // eps is the stopping tolerance
388
+ //
389
+ // solution will be put in \alpha, objective value will be put in obj
390
+ //
391
+ class Solver {
392
+ public:
393
+ Solver() {};
394
+ virtual ~Solver() {};
395
+
396
+ struct SolutionInfo {
397
+ double obj;
398
+ double rho;
399
+ double upper_bound_p;
400
+ double upper_bound_n;
401
+ double r; // for Solver_NU
402
+ };
403
+
404
+ void Solve(int l, const QMatrix& Q, const double *p_, const schar *y_,
405
+ double *alpha_, double Cp, double Cn, double eps,
406
+ SolutionInfo* si, int shrinking);
407
+ protected:
408
+ int active_size;
409
+ schar *y;
410
+ double *G; // gradient of objective function
411
+ enum { LOWER_BOUND, UPPER_BOUND, FREE };
412
+ char *alpha_status; // LOWER_BOUND, UPPER_BOUND, FREE
413
+ double *alpha;
414
+ const QMatrix *Q;
415
+ const double *QD;
416
+ double eps;
417
+ double Cp,Cn;
418
+ double *p;
419
+ int *active_set;
420
+ double *G_bar; // gradient, if we treat free variables as 0
421
+ int l;
422
+ bool unshrink; // XXX
423
+
424
+ double get_C(int i)
425
+ {
426
+ return (y[i] > 0)? Cp : Cn;
427
+ }
428
+ void update_alpha_status(int i)
429
+ {
430
+ if(alpha[i] >= get_C(i))
431
+ alpha_status[i] = UPPER_BOUND;
432
+ else if(alpha[i] <= 0)
433
+ alpha_status[i] = LOWER_BOUND;
434
+ else alpha_status[i] = FREE;
435
+ }
436
+ bool is_upper_bound(int i) { return alpha_status[i] == UPPER_BOUND; }
437
+ bool is_lower_bound(int i) { return alpha_status[i] == LOWER_BOUND; }
438
+ bool is_free(int i) { return alpha_status[i] == FREE; }
439
+ void swap_index(int i, int j);
440
+ void reconstruct_gradient();
441
+ virtual int select_working_set(int &i, int &j);
442
+ virtual double calculate_rho();
443
+ virtual void do_shrinking();
444
+ private:
445
+ bool be_shrunk(int i, double Gmax1, double Gmax2);
446
+ };
447
+
448
+ void Solver::swap_index(int i, int j)
449
+ {
450
+ Q->swap_index(i,j);
451
+ swap(y[i],y[j]);
452
+ swap(G[i],G[j]);
453
+ swap(alpha_status[i],alpha_status[j]);
454
+ swap(alpha[i],alpha[j]);
455
+ swap(p[i],p[j]);
456
+ swap(active_set[i],active_set[j]);
457
+ swap(G_bar[i],G_bar[j]);
458
+ }
459
+
460
+ void Solver::reconstruct_gradient()
461
+ {
462
+ // reconstruct inactive elements of G from G_bar and free variables
463
+
464
+ if(active_size == l) return;
465
+
466
+ int i,j;
467
+ int nr_free = 0;
468
+
469
+ for(j=active_size;j<l;j++)
470
+ G[j] = G_bar[j] + p[j];
471
+
472
+ for(j=0;j<active_size;j++)
473
+ if(is_free(j))
474
+ nr_free++;
475
+
476
+ if(2*nr_free < active_size)
477
+ info("\nWarning: using -h 0 may be faster\n");
478
+
479
+ if (nr_free*l > 2*active_size*(l-active_size))
480
+ {
481
+ for(i=active_size;i<l;i++)
482
+ {
483
+ const Qfloat *Q_i = Q->get_Q(i,active_size);
484
+ for(j=0;j<active_size;j++)
485
+ if(is_free(j))
486
+ G[i] += alpha[j] * Q_i[j];
487
+ }
488
+ }
489
+ else
490
+ {
491
+ for(i=0;i<active_size;i++)
492
+ if(is_free(i))
493
+ {
494
+ const Qfloat *Q_i = Q->get_Q(i,l);
495
+ double alpha_i = alpha[i];
496
+ for(j=active_size;j<l;j++)
497
+ G[j] += alpha_i * Q_i[j];
498
+ }
499
+ }
500
+ }
501
+
502
+ void Solver::Solve(int l, const QMatrix& Q, const double *p_, const schar *y_,
503
+ double *alpha_, double Cp, double Cn, double eps,
504
+ SolutionInfo* si, int shrinking)
505
+ {
506
+ this->l = l;
507
+ this->Q = &Q;
508
+ QD=Q.get_QD();
509
+ clone(p, p_,l);
510
+ clone(y, y_,l);
511
+ clone(alpha,alpha_,l);
512
+ this->Cp = Cp;
513
+ this->Cn = Cn;
514
+ this->eps = eps;
515
+ unshrink = false;
516
+
517
+ // initialize alpha_status
518
+ {
519
+ alpha_status = new char[l];
520
+ for(int i=0;i<l;i++)
521
+ update_alpha_status(i);
522
+ }
523
+
524
+ // initialize active set (for shrinking)
525
+ {
526
+ active_set = new int[l];
527
+ for(int i=0;i<l;i++)
528
+ active_set[i] = i;
529
+ active_size = l;
530
+ }
531
+
532
+ // initialize gradient
533
+ {
534
+ G = new double[l];
535
+ G_bar = new double[l];
536
+ int i;
537
+ for(i=0;i<l;i++)
538
+ {
539
+ G[i] = p[i];
540
+ G_bar[i] = 0;
541
+ }
542
+ for(i=0;i<l;i++)
543
+ if(!is_lower_bound(i))
544
+ {
545
+ const Qfloat *Q_i = Q.get_Q(i,l);
546
+ double alpha_i = alpha[i];
547
+ int j;
548
+ for(j=0;j<l;j++)
549
+ G[j] += alpha_i*Q_i[j];
550
+ if(is_upper_bound(i))
551
+ for(j=0;j<l;j++)
552
+ G_bar[j] += get_C(i) * Q_i[j];
553
+ }
554
+ }
555
+
556
+ // optimization step
557
+
558
+ int iter = 0;
559
+ int counter = min(l,1000)+1;
560
+
561
+ while(1)
562
+ {
563
+ // show progress and do shrinking
564
+
565
+ if(--counter == 0)
566
+ {
567
+ counter = min(l,1000);
568
+ if(shrinking) do_shrinking();
569
+ info(".");
570
+ }
571
+
572
+ int i,j;
573
+ if(select_working_set(i,j)!=0)
574
+ {
575
+ // reconstruct the whole gradient
576
+ reconstruct_gradient();
577
+ // reset active set size and check
578
+ active_size = l;
579
+ info("*");
580
+ if(select_working_set(i,j)!=0)
581
+ break;
582
+ else
583
+ counter = 1; // do shrinking next iteration
584
+ }
585
+
586
+ ++iter;
587
+
588
+ // update alpha[i] and alpha[j], handle bounds carefully
589
+
590
+ const Qfloat *Q_i = Q.get_Q(i,active_size);
591
+ const Qfloat *Q_j = Q.get_Q(j,active_size);
592
+
593
+ double C_i = get_C(i);
594
+ double C_j = get_C(j);
595
+
596
+ double old_alpha_i = alpha[i];
597
+ double old_alpha_j = alpha[j];
598
+
599
+ if(y[i]!=y[j])
600
+ {
601
+ double quad_coef = QD[i]+QD[j]+2*Q_i[j];
602
+ if (quad_coef <= 0)
603
+ quad_coef = TAU;
604
+ double delta = (-G[i]-G[j])/quad_coef;
605
+ double diff = alpha[i] - alpha[j];
606
+ alpha[i] += delta;
607
+ alpha[j] += delta;
608
+
609
+ if(diff > 0)
610
+ {
611
+ if(alpha[j] < 0)
612
+ {
613
+ alpha[j] = 0;
614
+ alpha[i] = diff;
615
+ }
616
+ }
617
+ else
618
+ {
619
+ if(alpha[i] < 0)
620
+ {
621
+ alpha[i] = 0;
622
+ alpha[j] = -diff;
623
+ }
624
+ }
625
+ if(diff > C_i - C_j)
626
+ {
627
+ if(alpha[i] > C_i)
628
+ {
629
+ alpha[i] = C_i;
630
+ alpha[j] = C_i - diff;
631
+ }
632
+ }
633
+ else
634
+ {
635
+ if(alpha[j] > C_j)
636
+ {
637
+ alpha[j] = C_j;
638
+ alpha[i] = C_j + diff;
639
+ }
640
+ }
641
+ }
642
+ else
643
+ {
644
+ double quad_coef = QD[i]+QD[j]-2*Q_i[j];
645
+ if (quad_coef <= 0)
646
+ quad_coef = TAU;
647
+ double delta = (G[i]-G[j])/quad_coef;
648
+ double sum = alpha[i] + alpha[j];
649
+ alpha[i] -= delta;
650
+ alpha[j] += delta;
651
+
652
+ if(sum > C_i)
653
+ {
654
+ if(alpha[i] > C_i)
655
+ {
656
+ alpha[i] = C_i;
657
+ alpha[j] = sum - C_i;
658
+ }
659
+ }
660
+ else
661
+ {
662
+ if(alpha[j] < 0)
663
+ {
664
+ alpha[j] = 0;
665
+ alpha[i] = sum;
666
+ }
667
+ }
668
+ if(sum > C_j)
669
+ {
670
+ if(alpha[j] > C_j)
671
+ {
672
+ alpha[j] = C_j;
673
+ alpha[i] = sum - C_j;
674
+ }
675
+ }
676
+ else
677
+ {
678
+ if(alpha[i] < 0)
679
+ {
680
+ alpha[i] = 0;
681
+ alpha[j] = sum;
682
+ }
683
+ }
684
+ }
685
+
686
+ // update G
687
+
688
+ double delta_alpha_i = alpha[i] - old_alpha_i;
689
+ double delta_alpha_j = alpha[j] - old_alpha_j;
690
+
691
+ for(int k=0;k<active_size;k++)
692
+ {
693
+ G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j;
694
+ }
695
+
696
+ // update alpha_status and G_bar
697
+
698
+ {
699
+ bool ui = is_upper_bound(i);
700
+ bool uj = is_upper_bound(j);
701
+ update_alpha_status(i);
702
+ update_alpha_status(j);
703
+ int k;
704
+ if(ui != is_upper_bound(i))
705
+ {
706
+ Q_i = Q.get_Q(i,l);
707
+ if(ui)
708
+ for(k=0;k<l;k++)
709
+ G_bar[k] -= C_i * Q_i[k];
710
+ else
711
+ for(k=0;k<l;k++)
712
+ G_bar[k] += C_i * Q_i[k];
713
+ }
714
+
715
+ if(uj != is_upper_bound(j))
716
+ {
717
+ Q_j = Q.get_Q(j,l);
718
+ if(uj)
719
+ for(k=0;k<l;k++)
720
+ G_bar[k] -= C_j * Q_j[k];
721
+ else
722
+ for(k=0;k<l;k++)
723
+ G_bar[k] += C_j * Q_j[k];
724
+ }
725
+ }
726
+ }
727
+
728
+ // calculate rho
729
+
730
+ si->rho = calculate_rho();
731
+
732
+ // calculate objective value
733
+ {
734
+ double v = 0;
735
+ int i;
736
+ for(i=0;i<l;i++)
737
+ v += alpha[i] * (G[i] + p[i]);
738
+
739
+ si->obj = v/2;
740
+ }
741
+
742
+ // put back the solution
743
+ {
744
+ for(int i=0;i<l;i++)
745
+ alpha_[active_set[i]] = alpha[i];
746
+ }
747
+
748
+ // juggle everything back
749
+ /*{
750
+ for(int i=0;i<l;i++)
751
+ while(active_set[i] != i)
752
+ swap_index(i,active_set[i]);
753
+ // or Q.swap_index(i,active_set[i]);
754
+ }*/
755
+
756
+ si->upper_bound_p = Cp;
757
+ si->upper_bound_n = Cn;
758
+
759
+ info("\noptimization finished, #iter = %d\n",iter);
760
+
761
+ delete[] p;
762
+ delete[] y;
763
+ delete[] alpha;
764
+ delete[] alpha_status;
765
+ delete[] active_set;
766
+ delete[] G;
767
+ delete[] G_bar;
768
+ }
769
+
770
+ // return 1 if already optimal, return 0 otherwise
771
+ int Solver::select_working_set(int &out_i, int &out_j)
772
+ {
773
+ // return i,j such that
774
+ // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha)
775
+ // j: minimizes the decrease of obj value
776
+ // (if quadratic coefficeint <= 0, replace it with tau)
777
+ // -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha)
778
+
779
+ double Gmax = -INF;
780
+ double Gmax2 = -INF;
781
+ int Gmax_idx = -1;
782
+ int Gmin_idx = -1;
783
+ double obj_diff_min = INF;
784
+
785
+ for(int t=0;t<active_size;t++)
786
+ if(y[t]==+1)
787
+ {
788
+ if(!is_upper_bound(t))
789
+ if(-G[t] >= Gmax)
790
+ {
791
+ Gmax = -G[t];
792
+ Gmax_idx = t;
793
+ }
794
+ }
795
+ else
796
+ {
797
+ if(!is_lower_bound(t))
798
+ if(G[t] >= Gmax)
799
+ {
800
+ Gmax = G[t];
801
+ Gmax_idx = t;
802
+ }
803
+ }
804
+
805
+ int i = Gmax_idx;
806
+ const Qfloat *Q_i = NULL;
807
+ if(i != -1) // NULL Q_i not accessed: Gmax=-INF if i=-1
808
+ Q_i = Q->get_Q(i,active_size);
809
+
810
+ for(int j=0;j<active_size;j++)
811
+ {
812
+ if(y[j]==+1)
813
+ {
814
+ if (!is_lower_bound(j))
815
+ {
816
+ double grad_diff=Gmax+G[j];
817
+ if (G[j] >= Gmax2)
818
+ Gmax2 = G[j];
819
+ if (grad_diff > 0)
820
+ {
821
+ double obj_diff;
822
+ double quad_coef = QD[i]+QD[j]-2.0*y[i]*Q_i[j];
823
+ if (quad_coef > 0)
824
+ obj_diff = -(grad_diff*grad_diff)/quad_coef;
825
+ else
826
+ obj_diff = -(grad_diff*grad_diff)/TAU;
827
+
828
+ if (obj_diff <= obj_diff_min)
829
+ {
830
+ Gmin_idx=j;
831
+ obj_diff_min = obj_diff;
832
+ }
833
+ }
834
+ }
835
+ }
836
+ else
837
+ {
838
+ if (!is_upper_bound(j))
839
+ {
840
+ double grad_diff= Gmax-G[j];
841
+ if (-G[j] >= Gmax2)
842
+ Gmax2 = -G[j];
843
+ if (grad_diff > 0)
844
+ {
845
+ double obj_diff;
846
+ double quad_coef = QD[i]+QD[j]+2.0*y[i]*Q_i[j];
847
+ if (quad_coef > 0)
848
+ obj_diff = -(grad_diff*grad_diff)/quad_coef;
849
+ else
850
+ obj_diff = -(grad_diff*grad_diff)/TAU;
851
+
852
+ if (obj_diff <= obj_diff_min)
853
+ {
854
+ Gmin_idx=j;
855
+ obj_diff_min = obj_diff;
856
+ }
857
+ }
858
+ }
859
+ }
860
+ }
861
+
862
+ if(Gmax+Gmax2 < eps)
863
+ return 1;
864
+
865
+ out_i = Gmax_idx;
866
+ out_j = Gmin_idx;
867
+ return 0;
868
+ }
869
+
870
+ bool Solver::be_shrunk(int i, double Gmax1, double Gmax2)
871
+ {
872
+ if(is_upper_bound(i))
873
+ {
874
+ if(y[i]==+1)
875
+ return(-G[i] > Gmax1);
876
+ else
877
+ return(-G[i] > Gmax2);
878
+ }
879
+ else if(is_lower_bound(i))
880
+ {
881
+ if(y[i]==+1)
882
+ return(G[i] > Gmax2);
883
+ else
884
+ return(G[i] > Gmax1);
885
+ }
886
+ else
887
+ return(false);
888
+ }
889
+
890
+ void Solver::do_shrinking()
891
+ {
892
+ int i;
893
+ double Gmax1 = -INF; // max { -y_i * grad(f)_i | i in I_up(\alpha) }
894
+ double Gmax2 = -INF; // max { y_i * grad(f)_i | i in I_low(\alpha) }
895
+
896
+ // find maximal violating pair first
897
+ for(i=0;i<active_size;i++)
898
+ {
899
+ if(y[i]==+1)
900
+ {
901
+ if(!is_upper_bound(i))
902
+ {
903
+ if(-G[i] >= Gmax1)
904
+ Gmax1 = -G[i];
905
+ }
906
+ if(!is_lower_bound(i))
907
+ {
908
+ if(G[i] >= Gmax2)
909
+ Gmax2 = G[i];
910
+ }
911
+ }
912
+ else
913
+ {
914
+ if(!is_upper_bound(i))
915
+ {
916
+ if(-G[i] >= Gmax2)
917
+ Gmax2 = -G[i];
918
+ }
919
+ if(!is_lower_bound(i))
920
+ {
921
+ if(G[i] >= Gmax1)
922
+ Gmax1 = G[i];
923
+ }
924
+ }
925
+ }
926
+
927
+ if(unshrink == false && Gmax1 + Gmax2 <= eps*10)
928
+ {
929
+ unshrink = true;
930
+ reconstruct_gradient();
931
+ active_size = l;
932
+ info("*");
933
+ }
934
+
935
+ for(i=0;i<active_size;i++)
936
+ if (be_shrunk(i, Gmax1, Gmax2))
937
+ {
938
+ active_size--;
939
+ while (active_size > i)
940
+ {
941
+ if (!be_shrunk(active_size, Gmax1, Gmax2))
942
+ {
943
+ swap_index(i,active_size);
944
+ break;
945
+ }
946
+ active_size--;
947
+ }
948
+ }
949
+ }
950
+
951
+ double Solver::calculate_rho()
952
+ {
953
+ double r;
954
+ int nr_free = 0;
955
+ double ub = INF, lb = -INF, sum_free = 0;
956
+ for(int i=0;i<active_size;i++)
957
+ {
958
+ double yG = y[i]*G[i];
959
+
960
+ if(is_upper_bound(i))
961
+ {
962
+ if(y[i]==-1)
963
+ ub = min(ub,yG);
964
+ else
965
+ lb = max(lb,yG);
966
+ }
967
+ else if(is_lower_bound(i))
968
+ {
969
+ if(y[i]==+1)
970
+ ub = min(ub,yG);
971
+ else
972
+ lb = max(lb,yG);
973
+ }
974
+ else
975
+ {
976
+ ++nr_free;
977
+ sum_free += yG;
978
+ }
979
+ }
980
+
981
+ if(nr_free>0)
982
+ r = sum_free/nr_free;
983
+ else
984
+ r = (ub+lb)/2;
985
+
986
+ return r;
987
+ }
988
+
989
+ //
990
+ // Solver for nu-svm classification and regression
991
+ //
992
+ // additional constraint: e^T \alpha = constant
993
+ //
994
+ class Solver_NU : public Solver
995
+ {
996
+ public:
997
+ Solver_NU() {}
998
+ void Solve(int l, const QMatrix& Q, const double *p, const schar *y,
999
+ double *alpha, double Cp, double Cn, double eps,
1000
+ SolutionInfo* si, int shrinking)
1001
+ {
1002
+ this->si = si;
1003
+ Solver::Solve(l,Q,p,y,alpha,Cp,Cn,eps,si,shrinking);
1004
+ }
1005
+ private:
1006
+ SolutionInfo *si;
1007
+ int select_working_set(int &i, int &j);
1008
+ double calculate_rho();
1009
+ bool be_shrunk(int i, double Gmax1, double Gmax2, double Gmax3, double Gmax4);
1010
+ void do_shrinking();
1011
+ };
1012
+
1013
+ // return 1 if already optimal, return 0 otherwise
1014
+ int Solver_NU::select_working_set(int &out_i, int &out_j)
1015
+ {
1016
+ // return i,j such that y_i = y_j and
1017
+ // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha)
1018
+ // j: minimizes the decrease of obj value
1019
+ // (if quadratic coefficeint <= 0, replace it with tau)
1020
+ // -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha)
1021
+
1022
+ double Gmaxp = -INF;
1023
+ double Gmaxp2 = -INF;
1024
+ int Gmaxp_idx = -1;
1025
+
1026
+ double Gmaxn = -INF;
1027
+ double Gmaxn2 = -INF;
1028
+ int Gmaxn_idx = -1;
1029
+
1030
+ int Gmin_idx = -1;
1031
+ double obj_diff_min = INF;
1032
+
1033
+ for(int t=0;t<active_size;t++)
1034
+ if(y[t]==+1)
1035
+ {
1036
+ if(!is_upper_bound(t))
1037
+ if(-G[t] >= Gmaxp)
1038
+ {
1039
+ Gmaxp = -G[t];
1040
+ Gmaxp_idx = t;
1041
+ }
1042
+ }
1043
+ else
1044
+ {
1045
+ if(!is_lower_bound(t))
1046
+ if(G[t] >= Gmaxn)
1047
+ {
1048
+ Gmaxn = G[t];
1049
+ Gmaxn_idx = t;
1050
+ }
1051
+ }
1052
+
1053
+ int ip = Gmaxp_idx;
1054
+ int in = Gmaxn_idx;
1055
+ const Qfloat *Q_ip = NULL;
1056
+ const Qfloat *Q_in = NULL;
1057
+ if(ip != -1) // NULL Q_ip not accessed: Gmaxp=-INF if ip=-1
1058
+ Q_ip = Q->get_Q(ip,active_size);
1059
+ if(in != -1)
1060
+ Q_in = Q->get_Q(in,active_size);
1061
+
1062
+ for(int j=0;j<active_size;j++)
1063
+ {
1064
+ if(y[j]==+1)
1065
+ {
1066
+ if (!is_lower_bound(j))
1067
+ {
1068
+ double grad_diff=Gmaxp+G[j];
1069
+ if (G[j] >= Gmaxp2)
1070
+ Gmaxp2 = G[j];
1071
+ if (grad_diff > 0)
1072
+ {
1073
+ double obj_diff;
1074
+ double quad_coef = QD[ip]+QD[j]-2*Q_ip[j];
1075
+ if (quad_coef > 0)
1076
+ obj_diff = -(grad_diff*grad_diff)/quad_coef;
1077
+ else
1078
+ obj_diff = -(grad_diff*grad_diff)/TAU;
1079
+
1080
+ if (obj_diff <= obj_diff_min)
1081
+ {
1082
+ Gmin_idx=j;
1083
+ obj_diff_min = obj_diff;
1084
+ }
1085
+ }
1086
+ }
1087
+ }
1088
+ else
1089
+ {
1090
+ if (!is_upper_bound(j))
1091
+ {
1092
+ double grad_diff=Gmaxn-G[j];
1093
+ if (-G[j] >= Gmaxn2)
1094
+ Gmaxn2 = -G[j];
1095
+ if (grad_diff > 0)
1096
+ {
1097
+ double obj_diff;
1098
+ double quad_coef = QD[in]+QD[j]-2*Q_in[j];
1099
+ if (quad_coef > 0)
1100
+ obj_diff = -(grad_diff*grad_diff)/quad_coef;
1101
+ else
1102
+ obj_diff = -(grad_diff*grad_diff)/TAU;
1103
+
1104
+ if (obj_diff <= obj_diff_min)
1105
+ {
1106
+ Gmin_idx=j;
1107
+ obj_diff_min = obj_diff;
1108
+ }
1109
+ }
1110
+ }
1111
+ }
1112
+ }
1113
+
1114
+ if(max(Gmaxp+Gmaxp2,Gmaxn+Gmaxn2) < eps)
1115
+ return 1;
1116
+
1117
+ if (y[Gmin_idx] == +1)
1118
+ out_i = Gmaxp_idx;
1119
+ else
1120
+ out_i = Gmaxn_idx;
1121
+ out_j = Gmin_idx;
1122
+
1123
+ return 0;
1124
+ }
1125
+
1126
+ bool Solver_NU::be_shrunk(int i, double Gmax1, double Gmax2, double Gmax3, double Gmax4)
1127
+ {
1128
+ if(is_upper_bound(i))
1129
+ {
1130
+ if(y[i]==+1)
1131
+ return(-G[i] > Gmax1);
1132
+ else
1133
+ return(-G[i] > Gmax4);
1134
+ }
1135
+ else if(is_lower_bound(i))
1136
+ {
1137
+ if(y[i]==+1)
1138
+ return(G[i] > Gmax2);
1139
+ else
1140
+ return(G[i] > Gmax3);
1141
+ }
1142
+ else
1143
+ return(false);
1144
+ }
1145
+
1146
+ void Solver_NU::do_shrinking()
1147
+ {
1148
+ double Gmax1 = -INF; // max { -y_i * grad(f)_i | y_i = +1, i in I_up(\alpha) }
1149
+ double Gmax2 = -INF; // max { y_i * grad(f)_i | y_i = +1, i in I_low(\alpha) }
1150
+ double Gmax3 = -INF; // max { -y_i * grad(f)_i | y_i = -1, i in I_up(\alpha) }
1151
+ double Gmax4 = -INF; // max { y_i * grad(f)_i | y_i = -1, i in I_low(\alpha) }
1152
+
1153
+ // find maximal violating pair first
1154
+ int i;
1155
+ for(i=0;i<active_size;i++)
1156
+ {
1157
+ if(!is_upper_bound(i))
1158
+ {
1159
+ if(y[i]==+1)
1160
+ {
1161
+ if(-G[i] > Gmax1) Gmax1 = -G[i];
1162
+ }
1163
+ else if(-G[i] > Gmax4) Gmax4 = -G[i];
1164
+ }
1165
+ if(!is_lower_bound(i))
1166
+ {
1167
+ if(y[i]==+1)
1168
+ {
1169
+ if(G[i] > Gmax2) Gmax2 = G[i];
1170
+ }
1171
+ else if(G[i] > Gmax3) Gmax3 = G[i];
1172
+ }
1173
+ }
1174
+
1175
+ if(unshrink == false && max(Gmax1+Gmax2,Gmax3+Gmax4) <= eps*10)
1176
+ {
1177
+ unshrink = true;
1178
+ reconstruct_gradient();
1179
+ active_size = l;
1180
+ }
1181
+
1182
+ for(i=0;i<active_size;i++)
1183
+ if (be_shrunk(i, Gmax1, Gmax2, Gmax3, Gmax4))
1184
+ {
1185
+ active_size--;
1186
+ while (active_size > i)
1187
+ {
1188
+ if (!be_shrunk(active_size, Gmax1, Gmax2, Gmax3, Gmax4))
1189
+ {
1190
+ swap_index(i,active_size);
1191
+ break;
1192
+ }
1193
+ active_size--;
1194
+ }
1195
+ }
1196
+ }
1197
+
1198
+ double Solver_NU::calculate_rho()
1199
+ {
1200
+ int nr_free1 = 0,nr_free2 = 0;
1201
+ double ub1 = INF, ub2 = INF;
1202
+ double lb1 = -INF, lb2 = -INF;
1203
+ double sum_free1 = 0, sum_free2 = 0;
1204
+
1205
+ for(int i=0;i<active_size;i++)
1206
+ {
1207
+ if(y[i]==+1)
1208
+ {
1209
+ if(is_upper_bound(i))
1210
+ lb1 = max(lb1,G[i]);
1211
+ else if(is_lower_bound(i))
1212
+ ub1 = min(ub1,G[i]);
1213
+ else
1214
+ {
1215
+ ++nr_free1;
1216
+ sum_free1 += G[i];
1217
+ }
1218
+ }
1219
+ else
1220
+ {
1221
+ if(is_upper_bound(i))
1222
+ lb2 = max(lb2,G[i]);
1223
+ else if(is_lower_bound(i))
1224
+ ub2 = min(ub2,G[i]);
1225
+ else
1226
+ {
1227
+ ++nr_free2;
1228
+ sum_free2 += G[i];
1229
+ }
1230
+ }
1231
+ }
1232
+
1233
+ double r1,r2;
1234
+ if(nr_free1 > 0)
1235
+ r1 = sum_free1/nr_free1;
1236
+ else
1237
+ r1 = (ub1+lb1)/2;
1238
+
1239
+ if(nr_free2 > 0)
1240
+ r2 = sum_free2/nr_free2;
1241
+ else
1242
+ r2 = (ub2+lb2)/2;
1243
+
1244
+ si->r = (r1+r2)/2;
1245
+ return (r1-r2)/2;
1246
+ }
1247
+
1248
+ //
1249
+ // Q matrices for various formulations
1250
+ //
1251
+ class SVC_Q: public Kernel
1252
+ {
1253
+ public:
1254
+ SVC_Q(const svm_problem& prob, const svm_parameter& param, const schar *y_)
1255
+ :Kernel(prob.l, prob.x, param)
1256
+ {
1257
+ clone(y,y_,prob.l);
1258
+ cache = new Cache(prob.l,(long int)(param.cache_size*(1<<20)));
1259
+ QD = new double[prob.l];
1260
+ for(int i=0;i<prob.l;i++)
1261
+ QD[i] = (this->*kernel_function)(i,i);
1262
+ }
1263
+
1264
+ Qfloat *get_Q(int i, int len) const
1265
+ {
1266
+ Qfloat *data;
1267
+ int start, j;
1268
+ if((start = cache->get_data(i,&data,len)) < len)
1269
+ {
1270
+ for(j=start;j<len;j++)
1271
+ data[j] = (Qfloat)(y[i]*y[j]*(this->*kernel_function)(i,j));
1272
+ }
1273
+ return data;
1274
+ }
1275
+
1276
+ double *get_QD() const
1277
+ {
1278
+ return QD;
1279
+ }
1280
+
1281
+ void swap_index(int i, int j) const
1282
+ {
1283
+ cache->swap_index(i,j);
1284
+ Kernel::swap_index(i,j);
1285
+ swap(y[i],y[j]);
1286
+ swap(QD[i],QD[j]);
1287
+ }
1288
+
1289
+ ~SVC_Q()
1290
+ {
1291
+ delete[] y;
1292
+ delete cache;
1293
+ delete[] QD;
1294
+ }
1295
+ private:
1296
+ schar *y;
1297
+ Cache *cache;
1298
+ double *QD;
1299
+ };
1300
+
1301
+ class ONE_CLASS_Q: public Kernel
1302
+ {
1303
+ public:
1304
+ ONE_CLASS_Q(const svm_problem& prob, const svm_parameter& param)
1305
+ :Kernel(prob.l, prob.x, param)
1306
+ {
1307
+ cache = new Cache(prob.l,(long int)(param.cache_size*(1<<20)));
1308
+ QD = new double[prob.l];
1309
+ for(int i=0;i<prob.l;i++)
1310
+ QD[i] = (this->*kernel_function)(i,i);
1311
+ }
1312
+
1313
+ Qfloat *get_Q(int i, int len) const
1314
+ {
1315
+ Qfloat *data;
1316
+ int start, j;
1317
+ if((start = cache->get_data(i,&data,len)) < len)
1318
+ {
1319
+ for(j=start;j<len;j++)
1320
+ data[j] = (Qfloat)(this->*kernel_function)(i,j);
1321
+ }
1322
+ return data;
1323
+ }
1324
+
1325
+ double *get_QD() const
1326
+ {
1327
+ return QD;
1328
+ }
1329
+
1330
+ void swap_index(int i, int j) const
1331
+ {
1332
+ cache->swap_index(i,j);
1333
+ Kernel::swap_index(i,j);
1334
+ swap(QD[i],QD[j]);
1335
+ }
1336
+
1337
+ ~ONE_CLASS_Q()
1338
+ {
1339
+ delete cache;
1340
+ delete[] QD;
1341
+ }
1342
+ private:
1343
+ Cache *cache;
1344
+ double *QD;
1345
+ };
1346
+
1347
+ class SVR_Q: public Kernel
1348
+ {
1349
+ public:
1350
+ SVR_Q(const svm_problem& prob, const svm_parameter& param)
1351
+ :Kernel(prob.l, prob.x, param)
1352
+ {
1353
+ l = prob.l;
1354
+ cache = new Cache(l,(long int)(param.cache_size*(1<<20)));
1355
+ QD = new double[2*l];
1356
+ sign = new schar[2*l];
1357
+ index = new int[2*l];
1358
+ for(int k=0;k<l;k++)
1359
+ {
1360
+ sign[k] = 1;
1361
+ sign[k+l] = -1;
1362
+ index[k] = k;
1363
+ index[k+l] = k;
1364
+ QD[k] = (this->*kernel_function)(k,k);
1365
+ QD[k+l] = QD[k];
1366
+ }
1367
+ buffer[0] = new Qfloat[2*l];
1368
+ buffer[1] = new Qfloat[2*l];
1369
+ next_buffer = 0;
1370
+ }
1371
+
1372
+ void swap_index(int i, int j) const
1373
+ {
1374
+ swap(sign[i],sign[j]);
1375
+ swap(index[i],index[j]);
1376
+ swap(QD[i],QD[j]);
1377
+ }
1378
+
1379
+ Qfloat *get_Q(int i, int len) const
1380
+ {
1381
+ Qfloat *data;
1382
+ int j, real_i = index[i];
1383
+ if(cache->get_data(real_i,&data,l) < l)
1384
+ {
1385
+ for(j=0;j<l;j++)
1386
+ data[j] = (Qfloat)(this->*kernel_function)(real_i,j);
1387
+ }
1388
+
1389
+ // reorder and copy
1390
+ Qfloat *buf = buffer[next_buffer];
1391
+ next_buffer = 1 - next_buffer;
1392
+ schar si = sign[i];
1393
+ for(j=0;j<len;j++)
1394
+ buf[j] = (Qfloat) si * (Qfloat) sign[j] * data[index[j]];
1395
+ return buf;
1396
+ }
1397
+
1398
+ double *get_QD() const
1399
+ {
1400
+ return QD;
1401
+ }
1402
+
1403
+ ~SVR_Q()
1404
+ {
1405
+ delete cache;
1406
+ delete[] sign;
1407
+ delete[] index;
1408
+ delete[] buffer[0];
1409
+ delete[] buffer[1];
1410
+ delete[] QD;
1411
+ }
1412
+ private:
1413
+ int l;
1414
+ Cache *cache;
1415
+ schar *sign;
1416
+ int *index;
1417
+ mutable int next_buffer;
1418
+ Qfloat *buffer[2];
1419
+ double *QD;
1420
+ };
1421
+
1422
+ //
1423
+ // construct and solve various formulations
1424
+ //
1425
+ static void solve_c_svc(
1426
+ const svm_problem *prob, const svm_parameter* param,
1427
+ double *alpha, Solver::SolutionInfo* si, double Cp, double Cn)
1428
+ {
1429
+ int l = prob->l;
1430
+ double *minus_ones = new double[l];
1431
+ schar *y = new schar[l];
1432
+
1433
+ int i;
1434
+
1435
+ for(i=0;i<l;i++)
1436
+ {
1437
+ alpha[i] = 0;
1438
+ minus_ones[i] = -1;
1439
+ if(prob->y[i] > 0) y[i] = +1; else y[i] = -1;
1440
+ }
1441
+
1442
+ Solver s;
1443
+ s.Solve(l, SVC_Q(*prob,*param,y), minus_ones, y,
1444
+ alpha, Cp, Cn, param->eps, si, param->shrinking);
1445
+
1446
+ double sum_alpha=0;
1447
+ for(i=0;i<l;i++)
1448
+ sum_alpha += alpha[i];
1449
+
1450
+ if (Cp==Cn)
1451
+ info("nu = %f\n", sum_alpha/(Cp*prob->l));
1452
+
1453
+ for(i=0;i<l;i++)
1454
+ alpha[i] *= y[i];
1455
+
1456
+ delete[] minus_ones;
1457
+ delete[] y;
1458
+ }
1459
+
1460
+ static void solve_nu_svc(
1461
+ const svm_problem *prob, const svm_parameter *param,
1462
+ double *alpha, Solver::SolutionInfo* si)
1463
+ {
1464
+ int i;
1465
+ int l = prob->l;
1466
+ double nu = param->nu;
1467
+
1468
+ schar *y = new schar[l];
1469
+
1470
+ for(i=0;i<l;i++)
1471
+ if(prob->y[i]>0)
1472
+ y[i] = +1;
1473
+ else
1474
+ y[i] = -1;
1475
+
1476
+ double sum_pos = nu*l/2;
1477
+ double sum_neg = nu*l/2;
1478
+
1479
+ for(i=0;i<l;i++)
1480
+ if(y[i] == +1)
1481
+ {
1482
+ alpha[i] = min(1.0,sum_pos);
1483
+ sum_pos -= alpha[i];
1484
+ }
1485
+ else
1486
+ {
1487
+ alpha[i] = min(1.0,sum_neg);
1488
+ sum_neg -= alpha[i];
1489
+ }
1490
+
1491
+ double *zeros = new double[l];
1492
+
1493
+ for(i=0;i<l;i++)
1494
+ zeros[i] = 0;
1495
+
1496
+ Solver_NU s;
1497
+ s.Solve(l, SVC_Q(*prob,*param,y), zeros, y,
1498
+ alpha, 1.0, 1.0, param->eps, si, param->shrinking);
1499
+ double r = si->r;
1500
+
1501
+ info("C = %f\n",1/r);
1502
+
1503
+ for(i=0;i<l;i++)
1504
+ alpha[i] *= y[i]/r;
1505
+
1506
+ si->rho /= r;
1507
+ si->obj /= (r*r);
1508
+ si->upper_bound_p = 1/r;
1509
+ si->upper_bound_n = 1/r;
1510
+
1511
+ delete[] y;
1512
+ delete[] zeros;
1513
+ }
1514
+
1515
+ static void solve_one_class(
1516
+ const svm_problem *prob, const svm_parameter *param,
1517
+ double *alpha, Solver::SolutionInfo* si)
1518
+ {
1519
+ int l = prob->l;
1520
+ double *zeros = new double[l];
1521
+ schar *ones = new schar[l];
1522
+ int i;
1523
+
1524
+ int n = (int)(param->nu*prob->l); // # of alpha's at upper bound
1525
+
1526
+ for(i=0;i<n;i++)
1527
+ alpha[i] = 1;
1528
+ if(n<prob->l)
1529
+ alpha[n] = param->nu * prob->l - n;
1530
+ for(i=n+1;i<l;i++)
1531
+ alpha[i] = 0;
1532
+
1533
+ for(i=0;i<l;i++)
1534
+ {
1535
+ zeros[i] = 0;
1536
+ ones[i] = 1;
1537
+ }
1538
+
1539
+ Solver s;
1540
+ s.Solve(l, ONE_CLASS_Q(*prob,*param), zeros, ones,
1541
+ alpha, 1.0, 1.0, param->eps, si, param->shrinking);
1542
+
1543
+ delete[] zeros;
1544
+ delete[] ones;
1545
+ }
1546
+
1547
+ static void solve_epsilon_svr(
1548
+ const svm_problem *prob, const svm_parameter *param,
1549
+ double *alpha, Solver::SolutionInfo* si)
1550
+ {
1551
+ int l = prob->l;
1552
+ double *alpha2 = new double[2*l];
1553
+ double *linear_term = new double[2*l];
1554
+ schar *y = new schar[2*l];
1555
+ int i;
1556
+
1557
+ for(i=0;i<l;i++)
1558
+ {
1559
+ alpha2[i] = 0;
1560
+ linear_term[i] = param->p - prob->y[i];
1561
+ y[i] = 1;
1562
+
1563
+ alpha2[i+l] = 0;
1564
+ linear_term[i+l] = param->p + prob->y[i];
1565
+ y[i+l] = -1;
1566
+ }
1567
+
1568
+ Solver s;
1569
+ s.Solve(2*l, SVR_Q(*prob,*param), linear_term, y,
1570
+ alpha2, param->C, param->C, param->eps, si, param->shrinking);
1571
+
1572
+ double sum_alpha = 0;
1573
+ for(i=0;i<l;i++)
1574
+ {
1575
+ alpha[i] = alpha2[i] - alpha2[i+l];
1576
+ sum_alpha += fabs(alpha[i]);
1577
+ }
1578
+ info("nu = %f\n",sum_alpha/(param->C*l));
1579
+
1580
+ delete[] alpha2;
1581
+ delete[] linear_term;
1582
+ delete[] y;
1583
+ }
1584
+
1585
+ static void solve_nu_svr(
1586
+ const svm_problem *prob, const svm_parameter *param,
1587
+ double *alpha, Solver::SolutionInfo* si)
1588
+ {
1589
+ int l = prob->l;
1590
+ double C = param->C;
1591
+ double *alpha2 = new double[2*l];
1592
+ double *linear_term = new double[2*l];
1593
+ schar *y = new schar[2*l];
1594
+ int i;
1595
+
1596
+ double sum = C * param->nu * l / 2;
1597
+ for(i=0;i<l;i++)
1598
+ {
1599
+ alpha2[i] = alpha2[i+l] = min(sum,C);
1600
+ sum -= alpha2[i];
1601
+
1602
+ linear_term[i] = - prob->y[i];
1603
+ y[i] = 1;
1604
+
1605
+ linear_term[i+l] = prob->y[i];
1606
+ y[i+l] = -1;
1607
+ }
1608
+
1609
+ Solver_NU s;
1610
+ s.Solve(2*l, SVR_Q(*prob,*param), linear_term, y,
1611
+ alpha2, C, C, param->eps, si, param->shrinking);
1612
+
1613
+ info("epsilon = %f\n",-si->r);
1614
+
1615
+ for(i=0;i<l;i++)
1616
+ alpha[i] = alpha2[i] - alpha2[i+l];
1617
+
1618
+ delete[] alpha2;
1619
+ delete[] linear_term;
1620
+ delete[] y;
1621
+ }
1622
+
1623
+ //
1624
+ // decision_function
1625
+ //
1626
// Result of training one two-class (or regression / one-class) problem.
struct decision_function
{
	// Per-example coefficients. svm_train_one() allocates this with Malloc;
	// the receiver is responsible for free()ing it.
	double *alpha;

	// Offset term reported by the solver.
	double rho;
};
1631
+
1632
+ static decision_function svm_train_one(
1633
+ const svm_problem *prob, const svm_parameter *param,
1634
+ double Cp, double Cn)
1635
+ {
1636
+ double *alpha = Malloc(double,prob->l);
1637
+ Solver::SolutionInfo si;
1638
+ switch(param->svm_type)
1639
+ {
1640
+ case C_SVC:
1641
+ solve_c_svc(prob,param,alpha,&si,Cp,Cn);
1642
+ break;
1643
+ case NU_SVC:
1644
+ solve_nu_svc(prob,param,alpha,&si);
1645
+ break;
1646
+ case ONE_CLASS:
1647
+ solve_one_class(prob,param,alpha,&si);
1648
+ break;
1649
+ case EPSILON_SVR:
1650
+ solve_epsilon_svr(prob,param,alpha,&si);
1651
+ break;
1652
+ case NU_SVR:
1653
+ solve_nu_svr(prob,param,alpha,&si);
1654
+ break;
1655
+ }
1656
+
1657
+ info("obj = %f, rho = %f\n",si.obj,si.rho);
1658
+
1659
+ // output SVs
1660
+
1661
+ int nSV = 0;
1662
+ int nBSV = 0;
1663
+ for(int i=0;i<prob->l;i++)
1664
+ {
1665
+ if(fabs(alpha[i]) > 0)
1666
+ {
1667
+ ++nSV;
1668
+ if(prob->y[i] > 0)
1669
+ {
1670
+ if(fabs(alpha[i]) >= si.upper_bound_p)
1671
+ ++nBSV;
1672
+ }
1673
+ else
1674
+ {
1675
+ if(fabs(alpha[i]) >= si.upper_bound_n)
1676
+ ++nBSV;
1677
+ }
1678
+ }
1679
+ }
1680
+
1681
+ info("nSV = %d, nBSV = %d\n",nSV,nBSV);
1682
+
1683
+ decision_function f;
1684
+ f.alpha = alpha;
1685
+ f.rho = si.rho;
1686
+ return f;
1687
+ }
1688
+
1689
+ // Platt's binary SVM Probablistic Output: an improvement from Lin et al.
1690
+ static void sigmoid_train(
1691
+ int l, const double *dec_values, const double *labels,
1692
+ double& A, double& B)
1693
+ {
1694
+ double prior1=0, prior0 = 0;
1695
+ int i;
1696
+
1697
+ for (i=0;i<l;i++)
1698
+ if (labels[i] > 0) prior1+=1;
1699
+ else prior0+=1;
1700
+
1701
+ int max_iter=100; // Maximal number of iterations
1702
+ double min_step=1e-10; // Minimal step taken in line search
1703
+ double sigma=1e-12; // For numerically strict PD of Hessian
1704
+ double eps=1e-5;
1705
+ double hiTarget=(prior1+1.0)/(prior1+2.0);
1706
+ double loTarget=1/(prior0+2.0);
1707
+ double *t=Malloc(double,l);
1708
+ double fApB,p,q,h11,h22,h21,g1,g2,det,dA,dB,gd,stepsize;
1709
+ double newA,newB,newf,d1,d2;
1710
+ int iter;
1711
+
1712
+ // Initial Point and Initial Fun Value
1713
+ A=0.0; B=log((prior0+1.0)/(prior1+1.0));
1714
+ double fval = 0.0;
1715
+
1716
+ for (i=0;i<l;i++)
1717
+ {
1718
+ if (labels[i]>0) t[i]=hiTarget;
1719
+ else t[i]=loTarget;
1720
+ fApB = dec_values[i]*A+B;
1721
+ if (fApB>=0)
1722
+ fval += t[i]*fApB + log(1+exp(-fApB));
1723
+ else
1724
+ fval += (t[i] - 1)*fApB +log(1+exp(fApB));
1725
+ }
1726
+ for (iter=0;iter<max_iter;iter++)
1727
+ {
1728
+ // Update Gradient and Hessian (use H' = H + sigma I)
1729
+ h11=sigma; // numerically ensures strict PD
1730
+ h22=sigma;
1731
+ h21=0.0;g1=0.0;g2=0.0;
1732
+ for (i=0;i<l;i++)
1733
+ {
1734
+ fApB = dec_values[i]*A+B;
1735
+ if (fApB >= 0)
1736
+ {
1737
+ p=exp(-fApB)/(1.0+exp(-fApB));
1738
+ q=1.0/(1.0+exp(-fApB));
1739
+ }
1740
+ else
1741
+ {
1742
+ p=1.0/(1.0+exp(fApB));
1743
+ q=exp(fApB)/(1.0+exp(fApB));
1744
+ }
1745
+ d2=p*q;
1746
+ h11+=dec_values[i]*dec_values[i]*d2;
1747
+ h22+=d2;
1748
+ h21+=dec_values[i]*d2;
1749
+ d1=t[i]-p;
1750
+ g1+=dec_values[i]*d1;
1751
+ g2+=d1;
1752
+ }
1753
+
1754
+ // Stopping Criteria
1755
+ if (fabs(g1)<eps && fabs(g2)<eps)
1756
+ break;
1757
+
1758
+ // Finding Newton direction: -inv(H') * g
1759
+ det=h11*h22-h21*h21;
1760
+ dA=-(h22*g1 - h21 * g2) / det;
1761
+ dB=-(-h21*g1+ h11 * g2) / det;
1762
+ gd=g1*dA+g2*dB;
1763
+
1764
+
1765
+ stepsize = 1; // Line Search
1766
+ while (stepsize >= min_step)
1767
+ {
1768
+ newA = A + stepsize * dA;
1769
+ newB = B + stepsize * dB;
1770
+
1771
+ // New function value
1772
+ newf = 0.0;
1773
+ for (i=0;i<l;i++)
1774
+ {
1775
+ fApB = dec_values[i]*newA+newB;
1776
+ if (fApB >= 0)
1777
+ newf += t[i]*fApB + log(1+exp(-fApB));
1778
+ else
1779
+ newf += (t[i] - 1)*fApB +log(1+exp(fApB));
1780
+ }
1781
+ // Check sufficient decrease
1782
+ if (newf<fval+0.0001*stepsize*gd)
1783
+ {
1784
+ A=newA;B=newB;fval=newf;
1785
+ break;
1786
+ }
1787
+ else
1788
+ stepsize = stepsize / 2.0;
1789
+ }
1790
+
1791
+ if (stepsize < min_step)
1792
+ {
1793
+ info("Line search fails in two-class probability estimates\n");
1794
+ break;
1795
+ }
1796
+ }
1797
+
1798
+ if (iter>=max_iter)
1799
+ info("Reaching maximal iterations in two-class probability estimates\n");
1800
+ free(t);
1801
+ }
1802
+
1803
// Evaluate 1 / (1 + exp(decision_value*A + B)).
// The two branches are algebraically the same; each one exponentiates only
// a non-positive value.
static double sigmoid_predict(double decision_value, double A, double B)
{
	const double fApB = decision_value*A+B;

	if (!(fApB >= 0))
		return 1.0/(1+exp(fApB)) ;

	const double e = exp(-fApB);
	return e/(1.0+e);
}
1811
+
1812
+ // Method 2 from the multiclass_prob paper by Wu, Lin, and Weng
1813
+ static void multiclass_probability(int k, double **r, double *p)
1814
+ {
1815
+ int t,j;
1816
+ int iter = 0, max_iter=max(100,k);
1817
+ double **Q=Malloc(double *,k);
1818
+ double *Qp=Malloc(double,k);
1819
+ double pQp, eps=0.005/k;
1820
+
1821
+ for (t=0;t<k;t++)
1822
+ {
1823
+ p[t]=1.0/k; // Valid if k = 1
1824
+ Q[t]=Malloc(double,k);
1825
+ Q[t][t]=0;
1826
+ for (j=0;j<t;j++)
1827
+ {
1828
+ Q[t][t]+=r[j][t]*r[j][t];
1829
+ Q[t][j]=Q[j][t];
1830
+ }
1831
+ for (j=t+1;j<k;j++)
1832
+ {
1833
+ Q[t][t]+=r[j][t]*r[j][t];
1834
+ Q[t][j]=-r[j][t]*r[t][j];
1835
+ }
1836
+ }
1837
+ for (iter=0;iter<max_iter;iter++)
1838
+ {
1839
+ // stopping condition, recalculate QP,pQP for numerical accuracy
1840
+ pQp=0;
1841
+ for (t=0;t<k;t++)
1842
+ {
1843
+ Qp[t]=0;
1844
+ for (j=0;j<k;j++)
1845
+ Qp[t]+=Q[t][j]*p[j];
1846
+ pQp+=p[t]*Qp[t];
1847
+ }
1848
+ double max_error=0;
1849
+ for (t=0;t<k;t++)
1850
+ {
1851
+ double error=fabs(Qp[t]-pQp);
1852
+ if (error>max_error)
1853
+ max_error=error;
1854
+ }
1855
+ if (max_error<eps) break;
1856
+
1857
+ for (t=0;t<k;t++)
1858
+ {
1859
+ double diff=(-Qp[t]+pQp)/Q[t][t];
1860
+ p[t]+=diff;
1861
+ pQp=(pQp+diff*(diff*Q[t][t]+2*Qp[t]))/(1+diff)/(1+diff);
1862
+ for (j=0;j<k;j++)
1863
+ {
1864
+ Qp[j]=(Qp[j]+diff*Q[t][j])/(1+diff);
1865
+ p[j]/=(1+diff);
1866
+ }
1867
+ }
1868
+ }
1869
+ if (iter>=max_iter)
1870
+ info("Exceeds max_iter in multiclass_prob\n");
1871
+ for(t=0;t<k;t++) free(Q[t]);
1872
+ free(Q);
1873
+ free(Qp);
1874
+ }
1875
+
1876
+ // Cross-validation decision values for probability estimates
1877
+ static void svm_binary_svc_probability(
1878
+ const svm_problem *prob, const svm_parameter *param,
1879
+ double Cp, double Cn, double& probA, double& probB)
1880
+ {
1881
+ int i;
1882
+ int nr_fold = 5;
1883
+ int *perm = Malloc(int,prob->l);
1884
+ double *dec_values = Malloc(double,prob->l);
1885
+
1886
+ // random shuffle
1887
+ for(i=0;i<prob->l;i++) perm[i]=i;
1888
+ for(i=0;i<prob->l;i++)
1889
+ {
1890
+ int j = i+rand()%(prob->l-i);
1891
+ swap(perm[i],perm[j]);
1892
+ }
1893
+ for(i=0;i<nr_fold;i++)
1894
+ {
1895
+ int begin = i*prob->l/nr_fold;
1896
+ int end = (i+1)*prob->l/nr_fold;
1897
+ int j,k;
1898
+ struct svm_problem subprob;
1899
+
1900
+ subprob.l = prob->l-(end-begin);
1901
+ subprob.x = Malloc(struct svm_node*,subprob.l);
1902
+ subprob.y = Malloc(double,subprob.l);
1903
+
1904
+ k=0;
1905
+ for(j=0;j<begin;j++)
1906
+ {
1907
+ subprob.x[k] = prob->x[perm[j]];
1908
+ subprob.y[k] = prob->y[perm[j]];
1909
+ ++k;
1910
+ }
1911
+ for(j=end;j<prob->l;j++)
1912
+ {
1913
+ subprob.x[k] = prob->x[perm[j]];
1914
+ subprob.y[k] = prob->y[perm[j]];
1915
+ ++k;
1916
+ }
1917
+ int p_count=0,n_count=0;
1918
+ for(j=0;j<k;j++)
1919
+ if(subprob.y[j]>0)
1920
+ p_count++;
1921
+ else
1922
+ n_count++;
1923
+
1924
+ if(p_count==0 && n_count==0)
1925
+ for(j=begin;j<end;j++)
1926
+ dec_values[perm[j]] = 0;
1927
+ else if(p_count > 0 && n_count == 0)
1928
+ for(j=begin;j<end;j++)
1929
+ dec_values[perm[j]] = 1;
1930
+ else if(p_count == 0 && n_count > 0)
1931
+ for(j=begin;j<end;j++)
1932
+ dec_values[perm[j]] = -1;
1933
+ else
1934
+ {
1935
+ svm_parameter subparam = *param;
1936
+ subparam.probability=0;
1937
+ subparam.C=1.0;
1938
+ subparam.nr_weight=2;
1939
+ subparam.weight_label = Malloc(int,2);
1940
+ subparam.weight = Malloc(double,2);
1941
+ subparam.weight_label[0]=+1;
1942
+ subparam.weight_label[1]=-1;
1943
+ subparam.weight[0]=Cp;
1944
+ subparam.weight[1]=Cn;
1945
+ struct svm_model *submodel = svm_train(&subprob,&subparam);
1946
+ for(j=begin;j<end;j++)
1947
+ {
1948
+ svm_predict_values(submodel,prob->x[perm[j]],&(dec_values[perm[j]]));
1949
+ // ensure +1 -1 order; reason not using CV subroutine
1950
+ dec_values[perm[j]] *= submodel->label[0];
1951
+ }
1952
+ svm_free_and_destroy_model(&submodel);
1953
+ svm_destroy_param(&subparam);
1954
+ }
1955
+ free(subprob.x);
1956
+ free(subprob.y);
1957
+ }
1958
+ sigmoid_train(prob->l,dec_values,prob->y,probA,probB);
1959
+ free(dec_values);
1960
+ free(perm);
1961
+ }
1962
+
1963
+ // Return parameter of a Laplace distribution
1964
+ static double svm_svr_probability(
1965
+ const svm_problem *prob, const svm_parameter *param)
1966
+ {
1967
+ int i;
1968
+ int nr_fold = 5;
1969
+ double *ymv = Malloc(double,prob->l);
1970
+ double mae = 0;
1971
+
1972
+ svm_parameter newparam = *param;
1973
+ newparam.probability = 0;
1974
+ svm_cross_validation(prob,&newparam,nr_fold,ymv);
1975
+ for(i=0;i<prob->l;i++)
1976
+ {
1977
+ ymv[i]=prob->y[i]-ymv[i];
1978
+ mae += fabs(ymv[i]);
1979
+ }
1980
+ mae /= prob->l;
1981
+ double std=sqrt(2*mae*mae);
1982
+ int count=0;
1983
+ mae=0;
1984
+ for(i=0;i<prob->l;i++)
1985
+ if (fabs(ymv[i]) > 5*std)
1986
+ count=count+1;
1987
+ else
1988
+ mae+=fabs(ymv[i]);
1989
+ mae /= (prob->l-count);
1990
+ info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma= %g\n",mae);
1991
+ free(ymv);
1992
+ return mae;
1993
+ }
1994
+
1995
+
1996
+ // label: label name, start: begin of each class, count: #data of classes, perm: indices to the original data
1997
+ // perm, length l, must be allocated before calling this subroutine
1998
+ static void svm_group_classes(const svm_problem *prob, int *nr_class_ret, int **label_ret, int **start_ret, int **count_ret, int *perm)
1999
+ {
2000
+ int l = prob->l;
2001
+ int max_nr_class = 16;
2002
+ int nr_class = 0;
2003
+ int *label = Malloc(int,max_nr_class);
2004
+ int *count = Malloc(int,max_nr_class);
2005
+ int *data_label = Malloc(int,l);
2006
+ int i;
2007
+
2008
+ for(i=0;i<l;i++)
2009
+ {
2010
+ int this_label = (int)prob->y[i];
2011
+ int j;
2012
+ for(j=0;j<nr_class;j++)
2013
+ {
2014
+ if(this_label == label[j])
2015
+ {
2016
+ ++count[j];
2017
+ break;
2018
+ }
2019
+ }
2020
+ data_label[i] = j;
2021
+ if(j == nr_class)
2022
+ {
2023
+ if(nr_class == max_nr_class)
2024
+ {
2025
+ max_nr_class *= 2;
2026
+ label = (int *)realloc(label,max_nr_class*sizeof(int));
2027
+ count = (int *)realloc(count,max_nr_class*sizeof(int));
2028
+ }
2029
+ label[nr_class] = this_label;
2030
+ count[nr_class] = 1;
2031
+ ++nr_class;
2032
+ }
2033
+ }
2034
+
2035
+ int *start = Malloc(int,nr_class);
2036
+ start[0] = 0;
2037
+ for(i=1;i<nr_class;i++)
2038
+ start[i] = start[i-1]+count[i-1];
2039
+ for(i=0;i<l;i++)
2040
+ {
2041
+ perm[start[data_label[i]]] = i;
2042
+ ++start[data_label[i]];
2043
+ }
2044
+ start[0] = 0;
2045
+ for(i=1;i<nr_class;i++)
2046
+ start[i] = start[i-1]+count[i-1];
2047
+
2048
+ *nr_class_ret = nr_class;
2049
+ *label_ret = label;
2050
+ *start_ret = start;
2051
+ *count_ret = count;
2052
+ free(data_label);
2053
+ }
2054
+
2055
+ //
2056
+ // Interface functions
2057
+ //
2058
+ svm_model *svm_train(const svm_problem *prob, const svm_parameter *param)
2059
+ {
2060
+ svm_model *model = Malloc(svm_model,1);
2061
+ model->param = *param;
2062
+ model->free_sv = 0; // XXX
2063
+
2064
+ if(param->svm_type == ONE_CLASS ||
2065
+ param->svm_type == EPSILON_SVR ||
2066
+ param->svm_type == NU_SVR)
2067
+ {
2068
+ // regression or one-class-svm
2069
+ model->nr_class = 2;
2070
+ model->label = NULL;
2071
+ model->nSV = NULL;
2072
+ model->probA = NULL; model->probB = NULL;
2073
+ model->sv_coef = Malloc(double *,1);
2074
+
2075
+ if(param->probability &&
2076
+ (param->svm_type == EPSILON_SVR ||
2077
+ param->svm_type == NU_SVR))
2078
+ {
2079
+ model->probA = Malloc(double,1);
2080
+ model->probA[0] = svm_svr_probability(prob,param);
2081
+ }
2082
+
2083
+ decision_function f = svm_train_one(prob,param,0,0);
2084
+ model->rho = Malloc(double,1);
2085
+ model->rho[0] = f.rho;
2086
+
2087
+ int nSV = 0;
2088
+ int i;
2089
+ for(i=0;i<prob->l;i++)
2090
+ if(fabs(f.alpha[i]) > 0) ++nSV;
2091
+ model->l = nSV;
2092
+ model->SV = Malloc(svm_node *,nSV);
2093
+ model->sv_coef[0] = Malloc(double,nSV);
2094
+ int j = 0;
2095
+ for(i=0;i<prob->l;i++)
2096
+ if(fabs(f.alpha[i]) > 0)
2097
+ {
2098
+ model->SV[j] = prob->x[i];
2099
+ model->sv_coef[0][j] = f.alpha[i];
2100
+ ++j;
2101
+ }
2102
+
2103
+ free(f.alpha);
2104
+ }
2105
+ else
2106
+ {
2107
+ // classification
2108
+ int l = prob->l;
2109
+ int nr_class;
2110
+ int *label = NULL;
2111
+ int *start = NULL;
2112
+ int *count = NULL;
2113
+ int *perm = Malloc(int,l);
2114
+
2115
+ // group training data of the same class
2116
+ svm_group_classes(prob,&nr_class,&label,&start,&count,perm);
2117
+ svm_node **x = Malloc(svm_node *,l);
2118
+ int i;
2119
+ for(i=0;i<l;i++)
2120
+ x[i] = prob->x[perm[i]];
2121
+
2122
+ // calculate weighted C
2123
+
2124
+ double *weighted_C = Malloc(double, nr_class);
2125
+ for(i=0;i<nr_class;i++)
2126
+ weighted_C[i] = param->C;
2127
+ for(i=0;i<param->nr_weight;i++)
2128
+ {
2129
+ int j;
2130
+ for(j=0;j<nr_class;j++)
2131
+ if(param->weight_label[i] == label[j])
2132
+ break;
2133
+ if(j == nr_class)
2134
+ fprintf(stderr,"warning: class label %d specified in weight is not found\n", param->weight_label[i]);
2135
+ else
2136
+ weighted_C[j] *= param->weight[i];
2137
+ }
2138
+
2139
+ // train k*(k-1)/2 models
2140
+
2141
+ bool *nonzero = Malloc(bool,l);
2142
+ for(i=0;i<l;i++)
2143
+ nonzero[i] = false;
2144
+ decision_function *f = Malloc(decision_function,nr_class*(nr_class-1)/2);
2145
+
2146
+ double *probA=NULL,*probB=NULL;
2147
+ if (param->probability)
2148
+ {
2149
+ probA=Malloc(double,nr_class*(nr_class-1)/2);
2150
+ probB=Malloc(double,nr_class*(nr_class-1)/2);
2151
+ }
2152
+
2153
+ int p = 0;
2154
+ for(i=0;i<nr_class;i++)
2155
+ for(int j=i+1;j<nr_class;j++)
2156
+ {
2157
+ svm_problem sub_prob;
2158
+ int si = start[i], sj = start[j];
2159
+ int ci = count[i], cj = count[j];
2160
+ sub_prob.l = ci+cj;
2161
+ sub_prob.x = Malloc(svm_node *,sub_prob.l);
2162
+ sub_prob.y = Malloc(double,sub_prob.l);
2163
+ int k;
2164
+ for(k=0;k<ci;k++)
2165
+ {
2166
+ sub_prob.x[k] = x[si+k];
2167
+ sub_prob.y[k] = +1;
2168
+ }
2169
+ for(k=0;k<cj;k++)
2170
+ {
2171
+ sub_prob.x[ci+k] = x[sj+k];
2172
+ sub_prob.y[ci+k] = -1;
2173
+ }
2174
+
2175
+ if(param->probability)
2176
+ svm_binary_svc_probability(&sub_prob,param,weighted_C[i],weighted_C[j],probA[p],probB[p]);
2177
+
2178
+ f[p] = svm_train_one(&sub_prob,param,weighted_C[i],weighted_C[j]);
2179
+ for(k=0;k<ci;k++)
2180
+ if(!nonzero[si+k] && fabs(f[p].alpha[k]) > 0)
2181
+ nonzero[si+k] = true;
2182
+ for(k=0;k<cj;k++)
2183
+ if(!nonzero[sj+k] && fabs(f[p].alpha[ci+k]) > 0)
2184
+ nonzero[sj+k] = true;
2185
+ free(sub_prob.x);
2186
+ free(sub_prob.y);
2187
+ ++p;
2188
+ }
2189
+
2190
+ // build output
2191
+
2192
+ model->nr_class = nr_class;
2193
+
2194
+ model->label = Malloc(int,nr_class);
2195
+ for(i=0;i<nr_class;i++)
2196
+ model->label[i] = label[i];
2197
+
2198
+ model->rho = Malloc(double,nr_class*(nr_class-1)/2);
2199
+ for(i=0;i<nr_class*(nr_class-1)/2;i++)
2200
+ model->rho[i] = f[i].rho;
2201
+
2202
+ if(param->probability)
2203
+ {
2204
+ model->probA = Malloc(double,nr_class*(nr_class-1)/2);
2205
+ model->probB = Malloc(double,nr_class*(nr_class-1)/2);
2206
+ for(i=0;i<nr_class*(nr_class-1)/2;i++)
2207
+ {
2208
+ model->probA[i] = probA[i];
2209
+ model->probB[i] = probB[i];
2210
+ }
2211
+ }
2212
+ else
2213
+ {
2214
+ model->probA=NULL;
2215
+ model->probB=NULL;
2216
+ }
2217
+
2218
+ int total_sv = 0;
2219
+ int *nz_count = Malloc(int,nr_class);
2220
+ model->nSV = Malloc(int,nr_class);
2221
+ for(i=0;i<nr_class;i++)
2222
+ {
2223
+ int nSV = 0;
2224
+ for(int j=0;j<count[i];j++)
2225
+ if(nonzero[start[i]+j])
2226
+ {
2227
+ ++nSV;
2228
+ ++total_sv;
2229
+ }
2230
+ model->nSV[i] = nSV;
2231
+ nz_count[i] = nSV;
2232
+ }
2233
+
2234
+ info("Total nSV = %d\n",total_sv);
2235
+
2236
+ model->l = total_sv;
2237
+ model->SV = Malloc(svm_node *,total_sv);
2238
+ p = 0;
2239
+ for(i=0;i<l;i++)
2240
+ if(nonzero[i]) model->SV[p++] = x[i];
2241
+
2242
+ int *nz_start = Malloc(int,nr_class);
2243
+ nz_start[0] = 0;
2244
+ for(i=1;i<nr_class;i++)
2245
+ nz_start[i] = nz_start[i-1]+nz_count[i-1];
2246
+
2247
+ model->sv_coef = Malloc(double *,nr_class-1);
2248
+ for(i=0;i<nr_class-1;i++)
2249
+ model->sv_coef[i] = Malloc(double,total_sv);
2250
+
2251
+ p = 0;
2252
+ for(i=0;i<nr_class;i++)
2253
+ for(int j=i+1;j<nr_class;j++)
2254
+ {
2255
+ // classifier (i,j): coefficients with
2256
+ // i are in sv_coef[j-1][nz_start[i]...],
2257
+ // j are in sv_coef[i][nz_start[j]...]
2258
+
2259
+ int si = start[i];
2260
+ int sj = start[j];
2261
+ int ci = count[i];
2262
+ int cj = count[j];
2263
+
2264
+ int q = nz_start[i];
2265
+ int k;
2266
+ for(k=0;k<ci;k++)
2267
+ if(nonzero[si+k])
2268
+ model->sv_coef[j-1][q++] = f[p].alpha[k];
2269
+ q = nz_start[j];
2270
+ for(k=0;k<cj;k++)
2271
+ if(nonzero[sj+k])
2272
+ model->sv_coef[i][q++] = f[p].alpha[ci+k];
2273
+ ++p;
2274
+ }
2275
+
2276
+ free(label);
2277
+ free(probA);
2278
+ free(probB);
2279
+ free(count);
2280
+ free(perm);
2281
+ free(start);
2282
+ free(x);
2283
+ free(weighted_C);
2284
+ free(nonzero);
2285
+ for(i=0;i<nr_class*(nr_class-1)/2;i++)
2286
+ free(f[i].alpha);
2287
+ free(f);
2288
+ free(nz_count);
2289
+ free(nz_start);
2290
+ }
2291
+ return model;
2292
+ }
2293
+
2294
+ // Stratified cross validation
2295
+ void svm_cross_validation(const svm_problem *prob, const svm_parameter *param, int nr_fold, double *target)
2296
+ {
2297
+ int i;
2298
+ int *fold_start = Malloc(int,nr_fold+1);
2299
+ int l = prob->l;
2300
+ int *perm = Malloc(int,l);
2301
+ int nr_class;
2302
+
2303
+ // stratified cv may not give leave-one-out rate
2304
+ // Each class to l folds -> some folds may have zero elements
2305
+ if((param->svm_type == C_SVC ||
2306
+ param->svm_type == NU_SVC) && nr_fold < l)
2307
+ {
2308
+ int *start = NULL;
2309
+ int *label = NULL;
2310
+ int *count = NULL;
2311
+ svm_group_classes(prob,&nr_class,&label,&start,&count,perm);
2312
+
2313
+ // random shuffle and then data grouped by fold using the array perm
2314
+ int *fold_count = Malloc(int,nr_fold);
2315
+ int c;
2316
+ int *index = Malloc(int,l);
2317
+ for(i=0;i<l;i++)
2318
+ index[i]=perm[i];
2319
+ for (c=0; c<nr_class; c++)
2320
+ for(i=0;i<count[c];i++)
2321
+ {
2322
+ int j = i+rand()%(count[c]-i);
2323
+ swap(index[start[c]+j],index[start[c]+i]);
2324
+ }
2325
+ for(i=0;i<nr_fold;i++)
2326
+ {
2327
+ fold_count[i] = 0;
2328
+ for (c=0; c<nr_class;c++)
2329
+ fold_count[i]+=(i+1)*count[c]/nr_fold-i*count[c]/nr_fold;
2330
+ }
2331
+ fold_start[0]=0;
2332
+ for (i=1;i<=nr_fold;i++)
2333
+ fold_start[i] = fold_start[i-1]+fold_count[i-1];
2334
+ for (c=0; c<nr_class;c++)
2335
+ for(i=0;i<nr_fold;i++)
2336
+ {
2337
+ int begin = start[c]+i*count[c]/nr_fold;
2338
+ int end = start[c]+(i+1)*count[c]/nr_fold;
2339
+ for(int j=begin;j<end;j++)
2340
+ {
2341
+ perm[fold_start[i]] = index[j];
2342
+ fold_start[i]++;
2343
+ }
2344
+ }
2345
+ fold_start[0]=0;
2346
+ for (i=1;i<=nr_fold;i++)
2347
+ fold_start[i] = fold_start[i-1]+fold_count[i-1];
2348
+ free(start);
2349
+ free(label);
2350
+ free(count);
2351
+ free(index);
2352
+ free(fold_count);
2353
+ }
2354
+ else
2355
+ {
2356
+ for(i=0;i<l;i++) perm[i]=i;
2357
+ for(i=0;i<l;i++)
2358
+ {
2359
+ int j = i+rand()%(l-i);
2360
+ swap(perm[i],perm[j]);
2361
+ }
2362
+ for(i=0;i<=nr_fold;i++)
2363
+ fold_start[i]=i*l/nr_fold;
2364
+ }
2365
+
2366
+ for(i=0;i<nr_fold;i++)
2367
+ {
2368
+ int begin = fold_start[i];
2369
+ int end = fold_start[i+1];
2370
+ int j,k;
2371
+ struct svm_problem subprob;
2372
+
2373
+ subprob.l = l-(end-begin);
2374
+ subprob.x = Malloc(struct svm_node*,subprob.l);
2375
+ subprob.y = Malloc(double,subprob.l);
2376
+
2377
+ k=0;
2378
+ for(j=0;j<begin;j++)
2379
+ {
2380
+ subprob.x[k] = prob->x[perm[j]];
2381
+ subprob.y[k] = prob->y[perm[j]];
2382
+ ++k;
2383
+ }
2384
+ for(j=end;j<l;j++)
2385
+ {
2386
+ subprob.x[k] = prob->x[perm[j]];
2387
+ subprob.y[k] = prob->y[perm[j]];
2388
+ ++k;
2389
+ }
2390
+ struct svm_model *submodel = svm_train(&subprob,param);
2391
+ if(param->probability &&
2392
+ (param->svm_type == C_SVC || param->svm_type == NU_SVC))
2393
+ {
2394
+ double *prob_estimates=Malloc(double,svm_get_nr_class(submodel));
2395
+ for(j=begin;j<end;j++)
2396
+ target[perm[j]] = svm_predict_probability(submodel,prob->x[perm[j]],prob_estimates);
2397
+ free(prob_estimates);
2398
+ }
2399
+ else
2400
+ for(j=begin;j<end;j++)
2401
+ target[perm[j]] = svm_predict(submodel,prob->x[perm[j]]);
2402
+ svm_free_and_destroy_model(&submodel);
2403
+ free(subprob.x);
2404
+ free(subprob.y);
2405
+ }
2406
+ free(fold_start);
2407
+ free(perm);
2408
+ }
2409
+
2410
+
2411
+ int svm_get_svm_type(const svm_model *model)
2412
+ {
2413
+ return model->param.svm_type;
2414
+ }
2415
+
2416
+ int svm_get_nr_class(const svm_model *model)
2417
+ {
2418
+ return model->nr_class;
2419
+ }
2420
+
2421
+ void svm_get_labels(const svm_model *model, int* label)
2422
+ {
2423
+ if (model->label != NULL)
2424
+ for(int i=0;i<model->nr_class;i++)
2425
+ label[i] = model->label[i];
2426
+ }
2427
+
2428
+ double svm_get_svr_probability(const svm_model *model)
2429
+ {
2430
+ if ((model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) &&
2431
+ model->probA!=NULL)
2432
+ return model->probA[0];
2433
+ else
2434
+ {
2435
+ fprintf(stderr,"Model doesn't contain information for SVR probability inference\n");
2436
+ return 0;
2437
+ }
2438
+ }
2439
+
2440
+ double svm_predict_values(const svm_model *model, const svm_node *x, double* dec_values)
2441
+ {
2442
+ if(model->param.svm_type == ONE_CLASS ||
2443
+ model->param.svm_type == EPSILON_SVR ||
2444
+ model->param.svm_type == NU_SVR)
2445
+ {
2446
+ double *sv_coef = model->sv_coef[0];
2447
+ double sum = 0;
2448
+ for(int i=0;i<model->l;i++)
2449
+ sum += sv_coef[i] * Kernel::k_function(x,model->SV[i],model->param);
2450
+ sum -= model->rho[0];
2451
+ *dec_values = sum;
2452
+
2453
+ if(model->param.svm_type == ONE_CLASS)
2454
+ return (sum>0)?1:-1;
2455
+ else
2456
+ return sum;
2457
+ }
2458
+ else
2459
+ {
2460
+ int i;
2461
+ int nr_class = model->nr_class;
2462
+ int l = model->l;
2463
+
2464
+ double *kvalue = Malloc(double,l);
2465
+ for(i=0;i<l;i++)
2466
+ kvalue[i] = Kernel::k_function(x,model->SV[i],model->param);
2467
+
2468
+ int *start = Malloc(int,nr_class);
2469
+ start[0] = 0;
2470
+ for(i=1;i<nr_class;i++)
2471
+ start[i] = start[i-1]+model->nSV[i-1];
2472
+
2473
+ int *vote = Malloc(int,nr_class);
2474
+ for(i=0;i<nr_class;i++)
2475
+ vote[i] = 0;
2476
+
2477
+ int p=0;
2478
+ for(i=0;i<nr_class;i++)
2479
+ for(int j=i+1;j<nr_class;j++)
2480
+ {
2481
+ double sum = 0;
2482
+ int si = start[i];
2483
+ int sj = start[j];
2484
+ int ci = model->nSV[i];
2485
+ int cj = model->nSV[j];
2486
+
2487
+ int k;
2488
+ double *coef1 = model->sv_coef[j-1];
2489
+ double *coef2 = model->sv_coef[i];
2490
+ for(k=0;k<ci;k++)
2491
+ sum += coef1[si+k] * kvalue[si+k];
2492
+ for(k=0;k<cj;k++)
2493
+ sum += coef2[sj+k] * kvalue[sj+k];
2494
+ sum -= model->rho[p];
2495
+ dec_values[p] = sum;
2496
+
2497
+ if(dec_values[p] > 0)
2498
+ ++vote[i];
2499
+ else
2500
+ ++vote[j];
2501
+ p++;
2502
+ }
2503
+
2504
+ int vote_max_idx = 0;
2505
+ for(i=1;i<nr_class;i++)
2506
+ if(vote[i] > vote[vote_max_idx])
2507
+ vote_max_idx = i;
2508
+
2509
+ free(kvalue);
2510
+ free(start);
2511
+ free(vote);
2512
+ return model->label[vote_max_idx];
2513
+ }
2514
+ }
2515
+
2516
+ double svm_predict(const svm_model *model, const svm_node *x)
2517
+ {
2518
+ int nr_class = model->nr_class;
2519
+ double *dec_values;
2520
+ if(model->param.svm_type == ONE_CLASS ||
2521
+ model->param.svm_type == EPSILON_SVR ||
2522
+ model->param.svm_type == NU_SVR)
2523
+ dec_values = Malloc(double, 1);
2524
+ else
2525
+ dec_values = Malloc(double, nr_class*(nr_class-1)/2);
2526
+ double pred_result = svm_predict_values(model, x, dec_values);
2527
+ free(dec_values);
2528
+ return pred_result;
2529
+ }
2530
+
2531
+ double svm_predict_probability(
2532
+ const svm_model *model, const svm_node *x, double *prob_estimates)
2533
+ {
2534
+ if ((model->param.svm_type == C_SVC || model->param.svm_type == NU_SVC) &&
2535
+ model->probA!=NULL && model->probB!=NULL)
2536
+ {
2537
+ int i;
2538
+ int nr_class = model->nr_class;
2539
+ double *dec_values = Malloc(double, nr_class*(nr_class-1)/2);
2540
+ svm_predict_values(model, x, dec_values);
2541
+
2542
+ double min_prob=1e-7;
2543
+ double **pairwise_prob=Malloc(double *,nr_class);
2544
+ for(i=0;i<nr_class;i++)
2545
+ pairwise_prob[i]=Malloc(double,nr_class);
2546
+ int k=0;
2547
+ for(i=0;i<nr_class;i++)
2548
+ for(int j=i+1;j<nr_class;j++)
2549
+ {
2550
+ pairwise_prob[i][j]=min(max(sigmoid_predict(dec_values[k],model->probA[k],model->probB[k]),min_prob),1-min_prob);
2551
+ pairwise_prob[j][i]=1-pairwise_prob[i][j];
2552
+ k++;
2553
+ }
2554
+ multiclass_probability(nr_class,pairwise_prob,prob_estimates);
2555
+
2556
+ int prob_max_idx = 0;
2557
+ for(i=1;i<nr_class;i++)
2558
+ if(prob_estimates[i] > prob_estimates[prob_max_idx])
2559
+ prob_max_idx = i;
2560
+ for(i=0;i<nr_class;i++)
2561
+ free(pairwise_prob[i]);
2562
+ free(dec_values);
2563
+ free(pairwise_prob);
2564
+ return model->label[prob_max_idx];
2565
+ }
2566
+ else
2567
+ return svm_predict(model, x);
2568
+ }
2569
+
2570
// Names written to / read from model files. svm_save_model indexes this table
// with param.svm_type; the trailing NULL ends the lookup in svm_load_model.
static const char *svm_type_table[] =
{
	"c_svc",
	"nu_svc",
	"one_class",
	"epsilon_svr",
	"nu_svr",
	NULL
};

// Same scheme for param.kernel_type.
static const char *kernel_type_table[]=
{
	"linear",
	"polynomial",
	"rbf",
	"sigmoid",
	"precomputed",
	NULL
};
2579
+
2580
+ int svm_save_model(const char *model_file_name, const svm_model *model)
2581
+ {
2582
+ FILE *fp = fopen(model_file_name,"w");
2583
+ if(fp==NULL) return -1;
2584
+
2585
+ const svm_parameter& param = model->param;
2586
+
2587
+ fprintf(fp,"svm_type %s\n", svm_type_table[param.svm_type]);
2588
+ fprintf(fp,"kernel_type %s\n", kernel_type_table[param.kernel_type]);
2589
+
2590
+ if(param.kernel_type == POLY)
2591
+ fprintf(fp,"degree %d\n", param.degree);
2592
+
2593
+ if(param.kernel_type == POLY || param.kernel_type == RBF || param.kernel_type == SIGMOID)
2594
+ fprintf(fp,"gamma %g\n", param.gamma);
2595
+
2596
+ if(param.kernel_type == POLY || param.kernel_type == SIGMOID)
2597
+ fprintf(fp,"coef0 %g\n", param.coef0);
2598
+
2599
+ int nr_class = model->nr_class;
2600
+ int l = model->l;
2601
+ fprintf(fp, "nr_class %d\n", nr_class);
2602
+ fprintf(fp, "total_sv %d\n",l);
2603
+
2604
+ {
2605
+ fprintf(fp, "rho");
2606
+ for(int i=0;i<nr_class*(nr_class-1)/2;i++)
2607
+ fprintf(fp," %g",model->rho[i]);
2608
+ fprintf(fp, "\n");
2609
+ }
2610
+
2611
+ if(model->label)
2612
+ {
2613
+ fprintf(fp, "label");
2614
+ for(int i=0;i<nr_class;i++)
2615
+ fprintf(fp," %d",model->label[i]);
2616
+ fprintf(fp, "\n");
2617
+ }
2618
+
2619
+ if(model->probA) // regression has probA only
2620
+ {
2621
+ fprintf(fp, "probA");
2622
+ for(int i=0;i<nr_class*(nr_class-1)/2;i++)
2623
+ fprintf(fp," %g",model->probA[i]);
2624
+ fprintf(fp, "\n");
2625
+ }
2626
+ if(model->probB)
2627
+ {
2628
+ fprintf(fp, "probB");
2629
+ for(int i=0;i<nr_class*(nr_class-1)/2;i++)
2630
+ fprintf(fp," %g",model->probB[i]);
2631
+ fprintf(fp, "\n");
2632
+ }
2633
+
2634
+ if(model->nSV)
2635
+ {
2636
+ fprintf(fp, "nr_sv");
2637
+ for(int i=0;i<nr_class;i++)
2638
+ fprintf(fp," %d",model->nSV[i]);
2639
+ fprintf(fp, "\n");
2640
+ }
2641
+
2642
+ fprintf(fp, "SV\n");
2643
+ const double * const *sv_coef = model->sv_coef;
2644
+ const svm_node * const *SV = model->SV;
2645
+
2646
+ for(int i=0;i<l;i++)
2647
+ {
2648
+ for(int j=0;j<nr_class-1;j++)
2649
+ fprintf(fp, "%.16g ",sv_coef[j][i]);
2650
+
2651
+ const svm_node *p = SV[i];
2652
+
2653
+ if(param.kernel_type == PRECOMPUTED)
2654
+ fprintf(fp,"0:%d ",(int)(p->value));
2655
+ else
2656
+ while(p->index != -1)
2657
+ {
2658
+ fprintf(fp,"%d:%.8g ",p->index,p->value);
2659
+ p++;
2660
+ }
2661
+ fprintf(fp, "\n");
2662
+ }
2663
+ if (ferror(fp) != 0 || fclose(fp) != 0) return -1;
2664
+ else return 0;
2665
+ }
2666
+
2667
// Shared line buffer used by readline(); the caller allocates `line` with
// max_line_len bytes before the first call and frees it afterwards.
static char *line = NULL;
static int max_line_len;

// Read one complete line from `input` into `line`, doubling the buffer until
// the newline (or the end of the file) has been reached.
//
// Returns `line`, or NULL when nothing could be read or the buffer could not
// be enlarged. In the latter case `line` still points to the old, valid
// buffer holding the part of the line read so far.
static char* readline(FILE *input)
{
	if(fgets(line,max_line_len,input) == NULL)
		return NULL;

	while(strrchr(line,'\n') == NULL)
	{
		// keep the old buffer until realloc has succeeded
		const int grown_len = max_line_len * 2;
		char *grown = (char *) realloc(line,grown_len);
		if(grown == NULL)
			return NULL;
		line = grown;
		max_line_len = grown_len;

		const int len = (int) strlen(line);
		if(fgets(line+len,max_line_len-len,input) == NULL)
			break;	// last line of the file has no trailing newline
	}
	return line;
}
2687
+
2688
+ svm_model *svm_load_model(const char *model_file_name)
2689
+ {
2690
+ FILE *fp = fopen(model_file_name,"rb");
2691
+ if(fp==NULL) return NULL;
2692
+
2693
+ // read parameters
2694
+
2695
+ svm_model *model = Malloc(svm_model,1);
2696
+ svm_parameter& param = model->param;
2697
+ model->rho = NULL;
2698
+ model->probA = NULL;
2699
+ model->probB = NULL;
2700
+ model->label = NULL;
2701
+ model->nSV = NULL;
2702
+
2703
+ char cmd[81];
2704
+ while(1)
2705
+ {
2706
+ fscanf(fp,"%80s",cmd);
2707
+
2708
+ if(strcmp(cmd,"svm_type")==0)
2709
+ {
2710
+ fscanf(fp,"%80s",cmd);
2711
+ int i;
2712
+ for(i=0;svm_type_table[i];i++)
2713
+ {
2714
+ if(strcmp(svm_type_table[i],cmd)==0)
2715
+ {
2716
+ param.svm_type=i;
2717
+ break;
2718
+ }
2719
+ }
2720
+ if(svm_type_table[i] == NULL)
2721
+ {
2722
+ fprintf(stderr,"unknown svm type.\n");
2723
+ free(model->rho);
2724
+ free(model->label);
2725
+ free(model->nSV);
2726
+ free(model);
2727
+ return NULL;
2728
+ }
2729
+ }
2730
+ else if(strcmp(cmd,"kernel_type")==0)
2731
+ {
2732
+ fscanf(fp,"%80s",cmd);
2733
+ int i;
2734
+ for(i=0;kernel_type_table[i];i++)
2735
+ {
2736
+ if(strcmp(kernel_type_table[i],cmd)==0)
2737
+ {
2738
+ param.kernel_type=i;
2739
+ break;
2740
+ }
2741
+ }
2742
+ if(kernel_type_table[i] == NULL)
2743
+ {
2744
+ fprintf(stderr,"unknown kernel function.\n");
2745
+ free(model->rho);
2746
+ free(model->label);
2747
+ free(model->nSV);
2748
+ free(model);
2749
+ return NULL;
2750
+ }
2751
+ }
2752
+ else if(strcmp(cmd,"degree")==0)
2753
+ fscanf(fp,"%d",&param.degree);
2754
+ else if(strcmp(cmd,"gamma")==0)
2755
+ fscanf(fp,"%lf",&param.gamma);
2756
+ else if(strcmp(cmd,"coef0")==0)
2757
+ fscanf(fp,"%lf",&param.coef0);
2758
+ else if(strcmp(cmd,"nr_class")==0)
2759
+ fscanf(fp,"%d",&model->nr_class);
2760
+ else if(strcmp(cmd,"total_sv")==0)
2761
+ fscanf(fp,"%d",&model->l);
2762
+ else if(strcmp(cmd,"rho")==0)
2763
+ {
2764
+ int n = model->nr_class * (model->nr_class-1)/2;
2765
+ model->rho = Malloc(double,n);
2766
+ for(int i=0;i<n;i++)
2767
+ fscanf(fp,"%lf",&model->rho[i]);
2768
+ }
2769
+ else if(strcmp(cmd,"label")==0)
2770
+ {
2771
+ int n = model->nr_class;
2772
+ model->label = Malloc(int,n);
2773
+ for(int i=0;i<n;i++)
2774
+ fscanf(fp,"%d",&model->label[i]);
2775
+ }
2776
+ else if(strcmp(cmd,"probA")==0)
2777
+ {
2778
+ int n = model->nr_class * (model->nr_class-1)/2;
2779
+ model->probA = Malloc(double,n);
2780
+ for(int i=0;i<n;i++)
2781
+ fscanf(fp,"%lf",&model->probA[i]);
2782
+ }
2783
+ else if(strcmp(cmd,"probB")==0)
2784
+ {
2785
+ int n = model->nr_class * (model->nr_class-1)/2;
2786
+ model->probB = Malloc(double,n);
2787
+ for(int i=0;i<n;i++)
2788
+ fscanf(fp,"%lf",&model->probB[i]);
2789
+ }
2790
+ else if(strcmp(cmd,"nr_sv")==0)
2791
+ {
2792
+ int n = model->nr_class;
2793
+ model->nSV = Malloc(int,n);
2794
+ for(int i=0;i<n;i++)
2795
+ fscanf(fp,"%d",&model->nSV[i]);
2796
+ }
2797
+ else if(strcmp(cmd,"SV")==0)
2798
+ {
2799
+ while(1)
2800
+ {
2801
+ int c = getc(fp);
2802
+ if(c==EOF || c=='\n') break;
2803
+ }
2804
+ break;
2805
+ }
2806
+ else
2807
+ {
2808
+ fprintf(stderr,"unknown text in model file: [%s]\n",cmd);
2809
+ free(model->rho);
2810
+ free(model->label);
2811
+ free(model->nSV);
2812
+ free(model);
2813
+ return NULL;
2814
+ }
2815
+ }
2816
+
2817
+ // read sv_coef and SV
2818
+
2819
+ int elements = 0;
2820
+ long pos = ftell(fp);
2821
+
2822
+ max_line_len = 1024;
2823
+ line = Malloc(char,max_line_len);
2824
+ char *p,*endptr,*idx,*val;
2825
+
2826
+ while(readline(fp)!=NULL)
2827
+ {
2828
+ p = strtok(line,":");
2829
+ while(1)
2830
+ {
2831
+ p = strtok(NULL,":");
2832
+ if(p == NULL)
2833
+ break;
2834
+ ++elements;
2835
+ }
2836
+ }
2837
+ elements += model->l;
2838
+
2839
+ fseek(fp,pos,SEEK_SET);
2840
+
2841
+ int m = model->nr_class - 1;
2842
+ int l = model->l;
2843
+ model->sv_coef = Malloc(double *,m);
2844
+ int i;
2845
+ for(i=0;i<m;i++)
2846
+ model->sv_coef[i] = Malloc(double,l);
2847
+ model->SV = Malloc(svm_node*,l);
2848
+ svm_node *x_space = NULL;
2849
+ if(l>0) x_space = Malloc(svm_node,elements);
2850
+
2851
+ int j=0;
2852
+ for(i=0;i<l;i++)
2853
+ {
2854
+ readline(fp);
2855
+ model->SV[i] = &x_space[j];
2856
+
2857
+ p = strtok(line, " \t");
2858
+ model->sv_coef[0][i] = strtod(p,&endptr);
2859
+ for(int k=1;k<m;k++)
2860
+ {
2861
+ p = strtok(NULL, " \t");
2862
+ model->sv_coef[k][i] = strtod(p,&endptr);
2863
+ }
2864
+
2865
+ while(1)
2866
+ {
2867
+ idx = strtok(NULL, ":");
2868
+ val = strtok(NULL, " \t");
2869
+
2870
+ if(val == NULL)
2871
+ break;
2872
+ x_space[j].index = (int) strtol(idx,&endptr,10);
2873
+ x_space[j].value = strtod(val,&endptr);
2874
+
2875
+ ++j;
2876
+ }
2877
+ x_space[j++].index = -1;
2878
+ }
2879
+ free(line);
2880
+
2881
+ if (ferror(fp) != 0 || fclose(fp) != 0)
2882
+ return NULL;
2883
+
2884
+ model->free_sv = 1; // XXX
2885
+ return model;
2886
+ }
2887
+
2888
+ void svm_free_model_content(svm_model* model_ptr)
2889
+ {
2890
+ if(model_ptr->free_sv && model_ptr->l > 0)
2891
+ free((void *)(model_ptr->SV[0]));
2892
+ for(int i=0;i<model_ptr->nr_class-1;i++)
2893
+ free(model_ptr->sv_coef[i]);
2894
+ free(model_ptr->SV);
2895
+ free(model_ptr->sv_coef);
2896
+ free(model_ptr->rho);
2897
+ free(model_ptr->label);
2898
+ free(model_ptr->probA);
2899
+ free(model_ptr->probB);
2900
+ free(model_ptr->nSV);
2901
+ }
2902
+
2903
+ void svm_free_and_destroy_model(svm_model** model_ptr_ptr)
2904
+ {
2905
+ svm_model* model_ptr = *model_ptr_ptr;
2906
+ if(model_ptr != NULL)
2907
+ {
2908
+ svm_free_model_content(model_ptr);
2909
+ free(model_ptr);
2910
+ }
2911
+ }
2912
+
2913
+ void svm_destroy_model(svm_model* model_ptr)
2914
+ {
2915
+ fprintf(stderr,"warning: svm_destroy_model is deprecated and should not be used. Please use svm_free_and_destroy_model(svm_model **model_ptr_ptr)\n");
2916
+ svm_free_and_destroy_model(&model_ptr);
2917
+ }
2918
+
2919
+ void svm_destroy_param(svm_parameter* param)
2920
+ {
2921
+ free(param->weight_label);
2922
+ free(param->weight);
2923
+ }
2924
+
2925
+ const char *svm_check_parameter(const svm_problem *prob, const svm_parameter *param)
2926
+ {
2927
+ // svm_type
2928
+
2929
+ int svm_type = param->svm_type;
2930
+ if(svm_type != C_SVC &&
2931
+ svm_type != NU_SVC &&
2932
+ svm_type != ONE_CLASS &&
2933
+ svm_type != EPSILON_SVR &&
2934
+ svm_type != NU_SVR)
2935
+ return "unknown svm type";
2936
+
2937
+ // kernel_type, degree
2938
+
2939
+ int kernel_type = param->kernel_type;
2940
+ if(kernel_type != LINEAR &&
2941
+ kernel_type != POLY &&
2942
+ kernel_type != RBF &&
2943
+ kernel_type != SIGMOID &&
2944
+ kernel_type != PRECOMPUTED)
2945
+ return "unknown kernel type";
2946
+
2947
+ if(param->gamma < 0)
2948
+ return "gamma < 0";
2949
+
2950
+ if(param->degree < 0)
2951
+ return "degree of polynomial kernel < 0";
2952
+
2953
+ // cache_size,eps,C,nu,p,shrinking
2954
+
2955
+ if(param->cache_size <= 0)
2956
+ return "cache_size <= 0";
2957
+
2958
+ if(param->eps <= 0)
2959
+ return "eps <= 0";
2960
+
2961
+ if(svm_type == C_SVC ||
2962
+ svm_type == EPSILON_SVR ||
2963
+ svm_type == NU_SVR)
2964
+ if(param->C <= 0)
2965
+ return "C <= 0";
2966
+
2967
+ if(svm_type == NU_SVC ||
2968
+ svm_type == ONE_CLASS ||
2969
+ svm_type == NU_SVR)
2970
+ if(param->nu <= 0 || param->nu > 1)
2971
+ return "nu <= 0 or nu > 1";
2972
+
2973
+ if(svm_type == EPSILON_SVR)
2974
+ if(param->p < 0)
2975
+ return "p < 0";
2976
+
2977
+ if(param->shrinking != 0 &&
2978
+ param->shrinking != 1)
2979
+ return "shrinking != 0 and shrinking != 1";
2980
+
2981
+ if(param->probability != 0 &&
2982
+ param->probability != 1)
2983
+ return "probability != 0 and probability != 1";
2984
+
2985
+ if(param->probability == 1 &&
2986
+ svm_type == ONE_CLASS)
2987
+ return "one-class SVM probability output not supported yet";
2988
+
2989
+
2990
+ // check whether nu-svc is feasible
2991
+
2992
+ if(svm_type == NU_SVC)
2993
+ {
2994
+ int l = prob->l;
2995
+ int max_nr_class = 16;
2996
+ int nr_class = 0;
2997
+ int *label = Malloc(int,max_nr_class);
2998
+ int *count = Malloc(int,max_nr_class);
2999
+
3000
+ int i;
3001
+ for(i=0;i<l;i++)
3002
+ {
3003
+ int this_label = (int)prob->y[i];
3004
+ int j;
3005
+ for(j=0;j<nr_class;j++)
3006
+ if(this_label == label[j])
3007
+ {
3008
+ ++count[j];
3009
+ break;
3010
+ }
3011
+ if(j == nr_class)
3012
+ {
3013
+ if(nr_class == max_nr_class)
3014
+ {
3015
+ max_nr_class *= 2;
3016
+ label = (int *)realloc(label,max_nr_class*sizeof(int));
3017
+ count = (int *)realloc(count,max_nr_class*sizeof(int));
3018
+ }
3019
+ label[nr_class] = this_label;
3020
+ count[nr_class] = 1;
3021
+ ++nr_class;
3022
+ }
3023
+ }
3024
+
3025
+ for(i=0;i<nr_class;i++)
3026
+ {
3027
+ int n1 = count[i];
3028
+ for(int j=i+1;j<nr_class;j++)
3029
+ {
3030
+ int n2 = count[j];
3031
+ if(param->nu*(n1+n2)/2 > min(n1,n2))
3032
+ {
3033
+ free(label);
3034
+ free(count);
3035
+ return "specified nu is infeasible";
3036
+ }
3037
+ }
3038
+ }
3039
+ free(label);
3040
+ free(count);
3041
+ }
3042
+
3043
+ return NULL;
3044
+ }
3045
+
3046
+ int svm_check_probability_model(const svm_model *model)
3047
+ {
3048
+ return ((model->param.svm_type == C_SVC || model->param.svm_type == NU_SVC) &&
3049
+ model->probA!=NULL && model->probB!=NULL) ||
3050
+ ((model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) &&
3051
+ model->probA!=NULL);
3052
+ }
3053
+
3054
+ void svm_set_print_string_function(void (*print_func)(const char *))
3055
+ {
3056
+ if(print_func == NULL)
3057
+ svm_print_string = &print_string_stdout;
3058
+ else
3059
+ svm_print_string = print_func;
3060
+ }