eluka 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. data/.document +5 -0
  2. data/DOCUMENTATION_STANDARDS +39 -0
  3. data/Gemfile +13 -0
  4. data/Gemfile.lock +20 -0
  5. data/LICENSE.txt +20 -0
  6. data/README.rdoc +19 -0
  7. data/Rakefile +69 -0
  8. data/VERSION +1 -0
  9. data/examples/example.rb +59 -0
  10. data/ext/libsvm/COPYRIGHT +31 -0
  11. data/ext/libsvm/FAQ.html +1749 -0
  12. data/ext/libsvm/Makefile +25 -0
  13. data/ext/libsvm/Makefile.win +33 -0
  14. data/ext/libsvm/README +733 -0
  15. data/ext/libsvm/extconf.rb +1 -0
  16. data/ext/libsvm/heart_scale +270 -0
  17. data/ext/libsvm/java/Makefile +25 -0
  18. data/ext/libsvm/java/libsvm.jar +0 -0
  19. data/ext/libsvm/java/libsvm/svm.java +2776 -0
  20. data/ext/libsvm/java/libsvm/svm.m4 +2776 -0
  21. data/ext/libsvm/java/libsvm/svm_model.java +21 -0
  22. data/ext/libsvm/java/libsvm/svm_node.java +6 -0
  23. data/ext/libsvm/java/libsvm/svm_parameter.java +47 -0
  24. data/ext/libsvm/java/libsvm/svm_print_interface.java +5 -0
  25. data/ext/libsvm/java/libsvm/svm_problem.java +7 -0
  26. data/ext/libsvm/java/svm_predict.java +163 -0
  27. data/ext/libsvm/java/svm_scale.java +350 -0
  28. data/ext/libsvm/java/svm_toy.java +471 -0
  29. data/ext/libsvm/java/svm_train.java +318 -0
  30. data/ext/libsvm/java/test_applet.html +1 -0
  31. data/ext/libsvm/python/Makefile +4 -0
  32. data/ext/libsvm/python/README +331 -0
  33. data/ext/libsvm/python/svm.py +259 -0
  34. data/ext/libsvm/python/svmutil.py +242 -0
  35. data/ext/libsvm/svm-predict.c +226 -0
  36. data/ext/libsvm/svm-scale.c +353 -0
  37. data/ext/libsvm/svm-toy/gtk/Makefile +22 -0
  38. data/ext/libsvm/svm-toy/gtk/callbacks.cpp +423 -0
  39. data/ext/libsvm/svm-toy/gtk/callbacks.h +54 -0
  40. data/ext/libsvm/svm-toy/gtk/interface.c +164 -0
  41. data/ext/libsvm/svm-toy/gtk/interface.h +14 -0
  42. data/ext/libsvm/svm-toy/gtk/main.c +23 -0
  43. data/ext/libsvm/svm-toy/gtk/svm-toy.glade +238 -0
  44. data/ext/libsvm/svm-toy/qt/Makefile +17 -0
  45. data/ext/libsvm/svm-toy/qt/svm-toy.cpp +413 -0
  46. data/ext/libsvm/svm-toy/windows/svm-toy.cpp +456 -0
  47. data/ext/libsvm/svm-train.c +376 -0
  48. data/ext/libsvm/svm.cpp +3060 -0
  49. data/ext/libsvm/svm.def +19 -0
  50. data/ext/libsvm/svm.h +105 -0
  51. data/ext/libsvm/svm.o +0 -0
  52. data/ext/libsvm/tools/README +149 -0
  53. data/ext/libsvm/tools/checkdata.py +108 -0
  54. data/ext/libsvm/tools/easy.py +79 -0
  55. data/ext/libsvm/tools/grid.py +359 -0
  56. data/ext/libsvm/tools/subset.py +146 -0
  57. data/ext/libsvm/windows/libsvm.dll +0 -0
  58. data/ext/libsvm/windows/svm-predict.exe +0 -0
  59. data/ext/libsvm/windows/svm-scale.exe +0 -0
  60. data/ext/libsvm/windows/svm-toy.exe +0 -0
  61. data/ext/libsvm/windows/svm-train.exe +0 -0
  62. data/lib/eluka.rb +10 -0
  63. data/lib/eluka/bijection.rb +23 -0
  64. data/lib/eluka/data_point.rb +36 -0
  65. data/lib/eluka/document.rb +47 -0
  66. data/lib/eluka/feature_vector.rb +86 -0
  67. data/lib/eluka/features.rb +31 -0
  68. data/lib/eluka/model.rb +129 -0
  69. data/lib/fselect.rb +321 -0
  70. data/lib/grid.rb +25 -0
  71. data/test/helper.rb +18 -0
  72. data/test/test_eluka.rb +7 -0
  73. metadata +214 -0
@@ -0,0 +1,376 @@
1
+ #include <stdio.h>
2
+ #include <stdlib.h>
3
+ #include <string.h>
4
+ #include <ctype.h>
5
+ #include <errno.h>
6
+ #include "svm.h"
7
+ #define Malloc(type,n) (type *)malloc((n)*sizeof(type))
8
+
9
/* Output sink installed by the -q option: silently discards every message. */
void print_null(const char *s)
{
	(void)s;
}
10
+
11
/*
 * Print the svm-train command-line usage summary to stdout, then terminate
 * the process with exit status 1. Never returns.
 */
void exit_with_help()
{
	static const char usage[] =
	"Usage: svm-train [options] training_set_file [model_file]\n"
	"options:\n"
	"-s svm_type : set type of SVM (default 0)\n"
	" 0 -- C-SVC\n"
	" 1 -- nu-SVC\n"
	" 2 -- one-class SVM\n"
	" 3 -- epsilon-SVR\n"
	" 4 -- nu-SVR\n"
	"-t kernel_type : set type of kernel function (default 2)\n"
	" 0 -- linear: u'*v\n"
	" 1 -- polynomial: (gamma*u'*v + coef0)^degree\n"
	" 2 -- radial basis function: exp(-gamma*|u-v|^2)\n"
	" 3 -- sigmoid: tanh(gamma*u'*v + coef0)\n"
	" 4 -- precomputed kernel (kernel values in training_set_file)\n"
	"-d degree : set degree in kernel function (default 3)\n"
	"-g gamma : set gamma in kernel function (default 1/num_features)\n"
	"-r coef0 : set coef0 in kernel function (default 0)\n"
	"-c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR (default 1)\n"
	"-n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR (default 0.5)\n"
	"-p epsilon : set the epsilon in loss function of epsilon-SVR (default 0.1)\n"
	"-m cachesize : set cache memory size in MB (default 100)\n"
	"-e epsilon : set tolerance of termination criterion (default 0.001)\n"
	"-h shrinking : whether to use the shrinking heuristics, 0 or 1 (default 1)\n"
	"-b probability_estimates : whether to train a SVC or SVR model for probability estimates, 0 or 1 (default 0)\n"
	"-wi weight : set the parameter C of class i to weight*C, for C-SVC (default 1)\n"
	"-v n: n-fold cross validation mode\n"
	"-q : quiet mode (no outputs)\n";

	/* the text contains no '%', so writing it verbatim matches printf(usage) */
	fputs(usage, stdout);
	exit(1);
}
44
+
45
/*
 * Report a malformed line of the training file on stderr and terminate
 * with exit status 1. line_num is 1-based (callers pass i+1).
 */
void exit_input_error(int line_num)
{
	FILE *err = stderr;

	fprintf(err, "Wrong input format at line %d\n", line_num);
	exit(1);
}
50
+
51
/* Forward declarations of the helpers defined after main(). */
void parse_command_line(int argc, char **argv, char *input_file_name, char *model_file_name);
void read_problem(const char *filename);
void do_cross_validation();

/* Program-wide state shared by main() and the helpers above. */
struct svm_parameter param;		// set by parse_command_line
struct svm_problem prob;		// set by read_problem
struct svm_model *model;		// result of svm_train (training mode only, see main)
struct svm_node *x_space;		// backing storage for every prob.x row, allocated in read_problem
int cross_validation;			// nonzero when -v was given
int nr_fold;					// fold count for -v; parse_command_line rejects values < 2
61
+
62
/*
 * Line buffer shared by readline() and read_problem(). read_problem()
 * allocates it with an initial capacity of max_line_len bytes, readline()
 * grows it on demand, and main() frees it before exiting.
 */
static char *line = NULL;
static int max_line_len;

/*
 * Read one complete line (including its trailing '\n', if present) from
 * input into the global buffer `line`. The buffer is doubled until the
 * whole line fits.
 *
 * Returns `line` on success, or NULL if fgets() could not read anything
 * (EOF or read error before the first character). The returned pointer is
 * only valid until the next call, because growing the buffer may move it.
 *
 * Exits the process if the buffer cannot be grown.
 */
static char* readline(FILE *input)
{
	int len;

	if(fgets(line,max_line_len,input) == NULL)
		return NULL;

	while(strrchr(line,'\n') == NULL)
	{
		char *grown;

		max_line_len *= 2;
		// Keep the old pointer until realloc succeeds. Assigning the result
		// directly would leak the buffer and pass NULL to fgets() on failure.
		grown = (char *) realloc(line,max_line_len);
		if(grown == NULL)
		{
			fprintf(stderr,"out of memory while reading input line\n");
			exit(1);
		}
		line = grown;
		len = (int) strlen(line);
		if(fgets(line+len,max_line_len-len,input) == NULL)
			break;	// last line of the file has no '\n'
	}
	return line;
}
82
+
83
+ int main(int argc, char **argv)
84
+ {
85
+ char input_file_name[1024];
86
+ char model_file_name[1024];
87
+ const char *error_msg;
88
+
89
+ parse_command_line(argc, argv, input_file_name, model_file_name);
90
+ read_problem(input_file_name);
91
+ error_msg = svm_check_parameter(&prob,&param);
92
+
93
+ if(error_msg)
94
+ {
95
+ fprintf(stderr,"Error: %s\n",error_msg);
96
+ exit(1);
97
+ }
98
+
99
+ if(cross_validation)
100
+ {
101
+ do_cross_validation();
102
+ }
103
+ else
104
+ {
105
+ model = svm_train(&prob,&param);
106
+ if(svm_save_model(model_file_name,model))
107
+ {
108
+ fprintf(stderr, "can't save model to file %s\n", model_file_name);
109
+ exit(1);
110
+ }
111
+ svm_free_and_destroy_model(&model);
112
+ }
113
+ svm_destroy_param(&param);
114
+ free(prob.y);
115
+ free(prob.x);
116
+ free(x_space);
117
+ free(line);
118
+
119
+ return 0;
120
+ }
121
+
122
+ void do_cross_validation()
123
+ {
124
+ int i;
125
+ int total_correct = 0;
126
+ double total_error = 0;
127
+ double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
128
+ double *target = Malloc(double,prob.l);
129
+
130
+ svm_cross_validation(&prob,&param,nr_fold,target);
131
+ if(param.svm_type == EPSILON_SVR ||
132
+ param.svm_type == NU_SVR)
133
+ {
134
+ for(i=0;i<prob.l;i++)
135
+ {
136
+ double y = prob.y[i];
137
+ double v = target[i];
138
+ total_error += (v-y)*(v-y);
139
+ sumv += v;
140
+ sumy += y;
141
+ sumvv += v*v;
142
+ sumyy += y*y;
143
+ sumvy += v*y;
144
+ }
145
+ printf("Cross Validation Mean squared error = %g\n",total_error/prob.l);
146
+ printf("Cross Validation Squared correlation coefficient = %g\n",
147
+ ((prob.l*sumvy-sumv*sumy)*(prob.l*sumvy-sumv*sumy))/
148
+ ((prob.l*sumvv-sumv*sumv)*(prob.l*sumyy-sumy*sumy))
149
+ );
150
+ }
151
+ else
152
+ {
153
+ for(i=0;i<prob.l;i++)
154
+ if(target[i] == prob.y[i])
155
+ ++total_correct;
156
+ printf("Cross Validation Accuracy = %g%%\n",100.0*total_correct/prob.l);
157
+ }
158
+ free(target);
159
+ }
160
+
161
+ void parse_command_line(int argc, char **argv, char *input_file_name, char *model_file_name)
162
+ {
163
+ int i;
164
+ void (*print_func)(const char*) = NULL; // default printing to stdout
165
+
166
+ // default values
167
+ param.svm_type = C_SVC;
168
+ param.kernel_type = RBF;
169
+ param.degree = 3;
170
+ param.gamma = 0; // 1/num_features
171
+ param.coef0 = 0;
172
+ param.nu = 0.5;
173
+ param.cache_size = 100;
174
+ param.C = 1;
175
+ param.eps = 1e-3;
176
+ param.p = 0.1;
177
+ param.shrinking = 1;
178
+ param.probability = 0;
179
+ param.nr_weight = 0;
180
+ param.weight_label = NULL;
181
+ param.weight = NULL;
182
+ cross_validation = 0;
183
+
184
+ // parse options
185
+ for(i=1;i<argc;i++)
186
+ {
187
+ if(argv[i][0] != '-') break;
188
+ if(++i>=argc)
189
+ exit_with_help();
190
+ switch(argv[i-1][1])
191
+ {
192
+ case 's':
193
+ param.svm_type = atoi(argv[i]);
194
+ break;
195
+ case 't':
196
+ param.kernel_type = atoi(argv[i]);
197
+ break;
198
+ case 'd':
199
+ param.degree = atoi(argv[i]);
200
+ break;
201
+ case 'g':
202
+ param.gamma = atof(argv[i]);
203
+ break;
204
+ case 'r':
205
+ param.coef0 = atof(argv[i]);
206
+ break;
207
+ case 'n':
208
+ param.nu = atof(argv[i]);
209
+ break;
210
+ case 'm':
211
+ param.cache_size = atof(argv[i]);
212
+ break;
213
+ case 'c':
214
+ param.C = atof(argv[i]);
215
+ break;
216
+ case 'e':
217
+ param.eps = atof(argv[i]);
218
+ break;
219
+ case 'p':
220
+ param.p = atof(argv[i]);
221
+ break;
222
+ case 'h':
223
+ param.shrinking = atoi(argv[i]);
224
+ break;
225
+ case 'b':
226
+ param.probability = atoi(argv[i]);
227
+ break;
228
+ case 'q':
229
+ print_func = &print_null;
230
+ i--;
231
+ break;
232
+ case 'v':
233
+ cross_validation = 1;
234
+ nr_fold = atoi(argv[i]);
235
+ if(nr_fold < 2)
236
+ {
237
+ fprintf(stderr,"n-fold cross validation: n must >= 2\n");
238
+ exit_with_help();
239
+ }
240
+ break;
241
+ case 'w':
242
+ ++param.nr_weight;
243
+ param.weight_label = (int *)realloc(param.weight_label,sizeof(int)*param.nr_weight);
244
+ param.weight = (double *)realloc(param.weight,sizeof(double)*param.nr_weight);
245
+ param.weight_label[param.nr_weight-1] = atoi(&argv[i-1][2]);
246
+ param.weight[param.nr_weight-1] = atof(argv[i]);
247
+ break;
248
+ default:
249
+ fprintf(stderr,"Unknown option: -%c\n", argv[i-1][1]);
250
+ exit_with_help();
251
+ }
252
+ }
253
+
254
+ svm_set_print_string_function(print_func);
255
+
256
+ // determine filenames
257
+
258
+ if(i>=argc)
259
+ exit_with_help();
260
+
261
+ strcpy(input_file_name, argv[i]);
262
+
263
+ if(i<argc-1)
264
+ strcpy(model_file_name,argv[i+1]);
265
+ else
266
+ {
267
+ char *p = strrchr(argv[i],'/');
268
+ if(p==NULL)
269
+ p = argv[i];
270
+ else
271
+ ++p;
272
+ sprintf(model_file_name,"%s.model",p);
273
+ }
274
+ }
275
+
276
+ // read in a problem (in svmlight format)
277
+
278
+ void read_problem(const char *filename)
279
+ {
280
+ int elements, max_index, inst_max_index, i, j;
281
+ FILE *fp = fopen(filename,"r");
282
+ char *endptr;
283
+ char *idx, *val, *label;
284
+
285
+ if(fp == NULL)
286
+ {
287
+ fprintf(stderr,"can't open input file %s\n",filename);
288
+ exit(1);
289
+ }
290
+
291
+ prob.l = 0;
292
+ elements = 0;
293
+
294
+ max_line_len = 1024;
295
+ line = Malloc(char,max_line_len);
296
+ while(readline(fp)!=NULL)
297
+ {
298
+ char *p = strtok(line," \t"); // label
299
+
300
+ // features
301
+ while(1)
302
+ {
303
+ p = strtok(NULL," \t");
304
+ if(p == NULL || *p == '\n') // check '\n' as ' ' may be after the last feature
305
+ break;
306
+ ++elements;
307
+ }
308
+ ++elements;
309
+ ++prob.l;
310
+ }
311
+ rewind(fp);
312
+
313
+ prob.y = Malloc(double,prob.l);
314
+ prob.x = Malloc(struct svm_node *,prob.l);
315
+ x_space = Malloc(struct svm_node,elements);
316
+
317
+ max_index = 0;
318
+ j=0;
319
+ for(i=0;i<prob.l;i++)
320
+ {
321
+ inst_max_index = -1; // strtol gives 0 if wrong format, and precomputed kernel has <index> start from 0
322
+ readline(fp);
323
+ prob.x[i] = &x_space[j];
324
+ label = strtok(line," \t");
325
+ prob.y[i] = strtod(label,&endptr);
326
+ if(endptr == label)
327
+ exit_input_error(i+1);
328
+
329
+ while(1)
330
+ {
331
+ idx = strtok(NULL,":");
332
+ val = strtok(NULL," \t");
333
+
334
+ if(val == NULL)
335
+ break;
336
+
337
+ errno = 0;
338
+ x_space[j].index = (int) strtol(idx,&endptr,10);
339
+ if(endptr == idx || errno != 0 || *endptr != '\0' || x_space[j].index <= inst_max_index)
340
+ exit_input_error(i+1);
341
+ else
342
+ inst_max_index = x_space[j].index;
343
+
344
+ errno = 0;
345
+ x_space[j].value = strtod(val,&endptr);
346
+ if(endptr == val || errno != 0 || (*endptr != '\0' && !isspace(*endptr)))
347
+ exit_input_error(i+1);
348
+
349
+ ++j;
350
+ }
351
+
352
+ if(inst_max_index > max_index)
353
+ max_index = inst_max_index;
354
+ x_space[j++].index = -1;
355
+ }
356
+
357
+ if(param.gamma == 0 && max_index > 0)
358
+ param.gamma = 1.0/max_index;
359
+
360
+ if(param.kernel_type == PRECOMPUTED)
361
+ for(i=0;i<prob.l;i++)
362
+ {
363
+ if (prob.x[i][0].index != 0)
364
+ {
365
+ fprintf(stderr,"Wrong input format: first column must be 0:sample_serial_number\n");
366
+ exit(1);
367
+ }
368
+ if ((int)prob.x[i][0].value <= 0 || (int)prob.x[i][0].value > max_index)
369
+ {
370
+ fprintf(stderr,"Wrong input format: sample_serial_number out of range\n");
371
+ exit(1);
372
+ }
373
+ }
374
+
375
+ fclose(fp);
376
+ }
@@ -0,0 +1,3060 @@
1
+ #include <math.h>
2
+ #include <stdio.h>
3
+ #include <stdlib.h>
4
+ #include <ctype.h>
5
+ #include <float.h>
6
+ #include <string.h>
7
+ #include <stdarg.h>
8
+ #include "svm.h"
9
// Library version exported to callers, taken from the LIBSVM_VERSION macro.
int libsvm_version = LIBSVM_VERSION;
// Element type of cached Q-matrix columns (see Cache / QMatrix below);
// presumably single precision to reduce kernel-cache memory -- confirm.
typedef float Qfloat;
// Small signed type for the +1/-1 labels used inside Solver.
typedef signed char schar;
12
// Fallback min/max helpers. They are skipped when a macro of the same name is
// already defined. Tie-breaking (and NaN handling) follows the comparison
// exactly: min returns x only when x < y, max returns x only when x > y;
// otherwise y is returned.
#ifndef min
template <class T> static inline T min(T x,T y)
{
	if (x < y)
		return x;
	return y;
}
#endif
#ifndef max
template <class T> static inline T max(T x,T y)
{
	if (x > y)
		return x;
	return y;
}
#endif
18
// Exchange the values of x and y in place.
template <class T> static inline void swap(T& x, T& y)
{
	T tmp(x);
	x = y;
	y = tmp;
}
19
// Allocate dst = new T[n] and fill it with a bitwise copy of src[0..n).
// Uses memcpy, so T must be safe to copy bytewise; the call sites in this
// file clone arrays of scalars and pointers. The caller owns dst and
// releases it with delete[].
template <class S, class T> static inline void clone(T*& dst, S* src, int n)
{
	T *copy = new T[n];
	memcpy((void *)copy, (const void *)src, sizeof(T) * n);
	dst = copy;
}
24
// base raised to an integer power by binary exponentiation
// (square-and-multiply). Any times <= 0 yields 1.0.
static inline double powi(double base, int times)
{
	double result = 1.0;
	double square = base;
	int remaining = times;

	while (remaining > 0)
	{
		if (remaining % 2 == 1)
			result *= square;
		square *= square;
		remaining /= 2;
	}
	return result;
}
35
// Infinity sentinel used to seed the max/min searches in select_working_set.
#define INF HUGE_VAL
// Small positive value substituted for a non-positive quadratic coefficient
// in the two-variable update, keeping the step size finite.
#define TAU 1e-12
// Typed malloc shorthand: allocate n objects of `type`. The result is not checked.
#define Malloc(type,n) (type *)malloc((n)*sizeof(type))
38
+
39
// Default output sink: write s to stdout and flush immediately, so progress
// markers printed by the solver appear as soon as they are produced.
static void print_string_stdout(const char *s)
{
	fputs(s,stdout);
	fflush(stdout);
}
// Current sink for all library diagnostics emitted through info().
// NOTE(review): presumably replaced via svm_set_print_string_function
// (defined elsewhere); svm-train passes a no-op sink for -q.
static void (*svm_print_string) (const char *) = &print_string_stdout;
#if 1
// printf-style diagnostic output routed through svm_print_string.
// The message is formatted into a BUFSIZ-byte stack buffer. Longer output
// is truncated (vsnprintf) instead of overflowing the buffer, as vsprintf
// would.
static void info(const char *fmt,...)
{
	char buf[BUFSIZ];
	va_list ap;
	va_start(ap,fmt);
	if(vsnprintf(buf,sizeof(buf),fmt,ap) < 0)
		buf[0] = '\0';	// encoding error: emit an empty message rather than garbage
	va_end(ap);
	(*svm_print_string)(buf);
}
#else
static void info(const char *fmt,...) {}
#endif
58
+
59
//
// Kernel Cache
//
// l is the number of total data items
// size is the cache size limit in bytes
//
// LRU cache of Q-matrix columns. Each data item i has a head_t entry
// holding a (possibly partial) column data[0,len). Entries that hold data
// are linked into a circular list anchored at lru_head:
//   - get_data() re-inserts an entry at the tail (most recently used);
//   - get_data() evicts from lru_head.next (least recently used).
//
class Cache
{
public:
	Cache(int l,long int size);
	~Cache();

	// request data [0,len)
	// return some position p where [p,len) need to be filled
	// (p >= len if nothing needs to be filled)
	int get_data(const int index, Qfloat **data, int len);
	// Exchange the cached data of items i and j, including the i/j entries
	// inside every other cached column.
	void swap_index(int i, int j);
private:
	int l;
	// Remaining budget in Qfloat elements (the constructor converts bytes
	// to elements; get_data charges and refunds it).
	long int size;
	struct head_t
	{
		head_t *prev, *next;	// a circular list
		Qfloat *data;
		int len;		// data[0,len) is cached in this entry
	};

	head_t *head;		// l entries, one per data item (calloc'ed, so initially empty)
	head_t lru_head;	// list sentinel; never holds data itself
	void lru_delete(head_t *h);
	void lru_insert(head_t *h);
};
91
+
92
// Create an empty cache for l columns with a budget of size_ bytes.
// 1. The budget is converted to a count of Qfloat elements.
// 2. The memory of the l head_t entries is charged against it.
// 3. The result is clamped to at least two full columns.
Cache::Cache(int l_,long int size_):l(l_),size(size_)
{
	// NOTE(review): calloc result is not checked.
	head = (head_t *)calloc(l,sizeof(head_t)); // initialized to 0
	size /= sizeof(Qfloat);
	size -= l * sizeof(head_t) / sizeof(Qfloat);
	size = max(size, 2 * (long int) l); // cache must be large enough for two columns
	lru_head.next = lru_head.prev = &lru_head;	// empty circular list
}
100
+
101
+ Cache::~Cache()
102
+ {
103
+ for(head_t *h = lru_head.next; h != &lru_head; h=h->next)
104
+ free(h->data);
105
+ free(head);
106
+ }
107
+
108
+ void Cache::lru_delete(head_t *h)
109
+ {
110
+ // delete from current location
111
+ h->prev->next = h->next;
112
+ h->next->prev = h->prev;
113
+ }
114
+
115
+ void Cache::lru_insert(head_t *h)
116
+ {
117
+ // insert to last position
118
+ h->next = &lru_head;
119
+ h->prev = lru_head.prev;
120
+ h->prev->next = h;
121
+ h->next->prev = h;
122
+ }
123
+
124
// Make column `index` at least len entries long and mark it most recently
// used. *data is set to the column buffer.
//
// Returns the position from which the caller must compute entries:
//   - if the column already held >= len entries, returns len (nothing to fill);
//   - otherwise returns the previously cached length, and [that, len) must
//     be filled by the caller.
int Cache::get_data(const int index, Qfloat **data, int len)
{
	head_t *h = &head[index];
	// Only entries with data are linked. Taking h out first also guarantees
	// the eviction loop below never frees h itself.
	if(h->len) lru_delete(h);
	int more = len - h->len;

	if(more > 0)
	{
		// free old space
		// evict least-recently-used columns until the budget covers `more`
		while(size < more)
		{
			head_t *old = lru_head.next;
			lru_delete(old);
			free(old->data);
			size += old->len;
			old->data = 0;
			old->len = 0;
		}

		// allocate new space
		// NOTE(review): realloc failure is not handled here.
		h->data = (Qfloat *)realloc(h->data,sizeof(Qfloat)*len);
		size -= more;
		// h->len becomes the new length; local len now holds the old length
		// and is returned as the first position to fill
		swap(h->len,len);
	}

	lru_insert(h);
	*data = h->data;
	return len;
}
153
+
154
// Reflect a swap of data items i and j in the cache, in two steps:
// 1. Columns i and j exchange their buffers and lengths.
// 2. Every cached column swaps its entries i and j. A column long enough to
//    hold entry i but not entry j cannot be fixed in place, so it is
//    discarded and its memory returned to the budget.
void Cache::swap_index(int i, int j)
{
	if(i==j) return;

	// move whole columns; re-link each at the MRU end if it holds data
	if(head[i].len) lru_delete(&head[i]);
	if(head[j].len) lru_delete(&head[j]);
	swap(head[i].data,head[j].data);
	swap(head[i].len,head[j].len);
	if(head[i].len) lru_insert(&head[i]);
	if(head[j].len) lru_insert(&head[j]);

	// from here on i < j
	if(i>j) swap(i,j);
	for(head_t *h = lru_head.next; h!=&lru_head; h=h->next)
	{
		if(h->len > i)
		{
			if(h->len > j)
				swap(h->data[i],h->data[j]);
			else
			{
				// give up
				// Entry i would become stale and entry j is not cached.
				// lru_delete leaves h->next intact, so the loop can continue.
				lru_delete(h);
				free(h->data);
				size += h->len;
				h->data = 0;
				h->len = 0;
			}
		}
	}
}
184
+
185
//
// Kernel evaluation
//
// the static method k_function is for doing single kernel evaluation
// the constructor of Kernel prepares to calculate the l*l kernel matrix
// the member function get_Q is for getting one column from the Q Matrix
//
// Abstract view of the Q matrix as consumed by Solver.
class QMatrix {
public:
	// First len entries of column `column` of Q.
	virtual Qfloat *get_Q(int column, int len) const = 0;
	// Diagonal of Q (Solver reads QD[i] as Q_ii).
	virtual double *get_QD() const = 0;
	// Swap data items i and j. Implementations mutate internal state despite const.
	virtual void swap_index(int i, int j) const = 0;
	virtual ~QMatrix() {}
};
199
+
200
// Base for Q matrices built from a kernel over the l training rows.
// The constructor binds kernel_function to one of the kernel_* members
// below; subclasses (outside this view) implement get_Q / get_QD.
class Kernel: public QMatrix {
public:
	Kernel(int l, svm_node * const * x, const svm_parameter& param);
	virtual ~Kernel();

	// Single kernel evaluation K(x, y) for arbitrary sparse vectors.
	static double k_function(const svm_node *x, const svm_node *y,
				 const svm_parameter& param);
	virtual Qfloat *get_Q(int column, int len) const = 0;
	virtual double *get_QD() const = 0;
	// Swapping the pointed-to elements is allowed in a const method because
	// only the arrays' contents change, not the member pointers themselves.
	virtual void swap_index(int i, int j) const	// no so const...
	{
		swap(x[i],x[j]);
		if(x_square) swap(x_square[i],x_square[j]);
	}
protected:

	// Member selected by the constructor; evaluates K(x[i], x[j]).
	double (Kernel::*kernel_function)(int i, int j) const;

private:
	const svm_node **x;	// private copy of the row-pointer array (rows themselves are shared)
	double *x_square;	// x[i].x[i] for RBF, NULL for other kernels

	// svm_parameter
	const int kernel_type;
	const int degree;
	const double gamma;
	const double coef0;

	static double dot(const svm_node *px, const svm_node *py);
	double kernel_linear(int i, int j) const
	{
		return dot(x[i],x[j]);
	}
	double kernel_poly(int i, int j) const
	{
		return powi(gamma*dot(x[i],x[j])+coef0,degree);
	}
	// uses |u-v|^2 = u.u + v.v - 2 u.v with the precomputed squared norms
	double kernel_rbf(int i, int j) const
	{
		return exp(-gamma*(x_square[i]+x_square[j]-2*dot(x[i],x[j])));
	}
	double kernel_sigmoid(int i, int j) const
	{
		return tanh(gamma*dot(x[i],x[j])+coef0);
	}
	// Precomputed kernel: node 0 of row j holds its serial number, and
	// row i is indexed by node position. NOTE(review): assumes every row
	// lists all columns densely, so position == column index.
	double kernel_precomputed(int i, int j) const
	{
		return x[i][(int)(x[j][0].value)].value;
	}
};
250
+
251
+ Kernel::Kernel(int l, svm_node * const * x_, const svm_parameter& param)
252
+ :kernel_type(param.kernel_type), degree(param.degree),
253
+ gamma(param.gamma), coef0(param.coef0)
254
+ {
255
+ switch(kernel_type)
256
+ {
257
+ case LINEAR:
258
+ kernel_function = &Kernel::kernel_linear;
259
+ break;
260
+ case POLY:
261
+ kernel_function = &Kernel::kernel_poly;
262
+ break;
263
+ case RBF:
264
+ kernel_function = &Kernel::kernel_rbf;
265
+ break;
266
+ case SIGMOID:
267
+ kernel_function = &Kernel::kernel_sigmoid;
268
+ break;
269
+ case PRECOMPUTED:
270
+ kernel_function = &Kernel::kernel_precomputed;
271
+ break;
272
+ }
273
+
274
+ clone(x,x_,l);
275
+
276
+ if(kernel_type == RBF)
277
+ {
278
+ x_square = new double[l];
279
+ for(int i=0;i<l;i++)
280
+ x_square[i] = dot(x[i],x[i]);
281
+ }
282
+ else
283
+ x_square = 0;
284
+ }
285
+
286
+ Kernel::~Kernel()
287
+ {
288
+ delete[] x;
289
+ delete[] x_square;
290
+ }
291
+
292
+ double Kernel::dot(const svm_node *px, const svm_node *py)
293
+ {
294
+ double sum = 0;
295
+ while(px->index != -1 && py->index != -1)
296
+ {
297
+ if(px->index == py->index)
298
+ {
299
+ sum += px->value * py->value;
300
+ ++px;
301
+ ++py;
302
+ }
303
+ else
304
+ {
305
+ if(px->index > py->index)
306
+ ++py;
307
+ else
308
+ ++px;
309
+ }
310
+ }
311
+ return sum;
312
+ }
313
+
314
// Evaluate K(x, y) for two sparse vectors (ascending indices, terminated by
// index -1) under the kernel described by param. Unlike the per-instance
// kernel_* members, no precomputed norms are available here.
// Returns 0 for an unrecognised kernel_type.
double Kernel::k_function(const svm_node *x, const svm_node *y,
			  const svm_parameter& param)
{
	switch(param.kernel_type)
	{
		case LINEAR:
			return dot(x,y);
		case POLY:
			return powi(param.gamma*dot(x,y)+param.coef0,param.degree);
		case RBF:
		{
			// accumulate |x-y|^2 with a merge walk:
			// shared indices contribute (x_k-y_k)^2, unmatched ones their square
			double sum = 0;
			while(x->index != -1 && y->index !=-1)
			{
				if(x->index == y->index)
				{
					double d = x->value - y->value;
					sum += d*d;
					++x;
					++y;
				}
				else
				{
					if(x->index > y->index)
					{
						sum += y->value * y->value;
						++y;
					}
					else
					{
						sum += x->value * x->value;
						++x;
					}
				}
			}

			// tails: whichever vector still has entries
			while(x->index != -1)
			{
				sum += x->value * x->value;
				++x;
			}

			while(y->index != -1)
			{
				sum += y->value * y->value;
				++y;
			}

			return exp(-param.gamma*sum);
		}
		case SIGMOID:
			return tanh(param.gamma*dot(x,y)+param.coef0);
		// y's node 0 holds its serial number; x is indexed by node position,
		// which assumes x lists its kernel values densely
		case PRECOMPUTED:  //x: test (validation), y: SV
			return x[(int)(y->value)].value;
		default:
			return 0;  // Unreachable
	}
}
372
+
373
// An SMO algorithm in Fan et al., JMLR 6(2005), p. 1889--1918
// Solves:
//
//	min 0.5(\alpha^T Q \alpha) + p^T \alpha
//
//		y^T \alpha = \delta
//		y_i = +1 or -1
//		0 <= alpha_i <= Cp for y_i = 1
//		0 <= alpha_i <= Cn for y_i = -1
//
// Given:
//
//	Q, p, y, Cp, Cn, and an initial feasible point \alpha
//	l is the size of vectors and matrices
//	eps is the stopping tolerance
//
// solution will be put in \alpha, objective value will be put in obj
//
class Solver {
public:
	Solver() {};
	virtual ~Solver() {};

	struct SolutionInfo {
		double obj;
		double rho;
		double upper_bound_p;
		double upper_bound_n;
		double r;	// for Solver_NU
	};

	void Solve(int l, const QMatrix& Q, const double *p_, const schar *y_,
		   double *alpha_, double Cp, double Cn, double eps,
		   SolutionInfo* si, int shrinking);
protected:
	int active_size;	// variables [0, active_size) are optimized; the rest are shrunk
	schar *y;
	double *G;		// gradient of objective function
	enum { LOWER_BOUND, UPPER_BOUND, FREE };
	char *alpha_status;	// LOWER_BOUND, UPPER_BOUND, FREE
	double *alpha;
	const QMatrix *Q;
	const double *QD;	// diagonal of Q
	double eps;
	double Cp,Cn;
	double *p;
	int *active_set;	// active_set[k] = original index of the variable now at position k
	double *G_bar;		// gradient, if we treat free variables as 0
	int l;
	bool unshrink;	// XXX

	// upper bound for variable i: Cp for positive labels, Cn otherwise
	double get_C(int i)
	{
		return (y[i] > 0)? Cp : Cn;
	}
	// classify alpha[i] against its box [0, get_C(i)]
	void update_alpha_status(int i)
	{
		if(alpha[i] >= get_C(i))
			alpha_status[i] = UPPER_BOUND;
		else if(alpha[i] <= 0)
			alpha_status[i] = LOWER_BOUND;
		else alpha_status[i] = FREE;
	}
	bool is_upper_bound(int i) { return alpha_status[i] == UPPER_BOUND; }
	bool is_lower_bound(int i) { return alpha_status[i] == LOWER_BOUND; }
	bool is_free(int i) { return alpha_status[i] == FREE; }
	void swap_index(int i, int j);
	void reconstruct_gradient();
	// returns nonzero when the current point is already optimal
	virtual int select_working_set(int &i, int &j);
	virtual double calculate_rho();
	virtual void do_shrinking();
private:
	bool be_shrunk(int i, double Gmax1, double Gmax2);
};
447
+
448
+ void Solver::swap_index(int i, int j)
449
+ {
450
+ Q->swap_index(i,j);
451
+ swap(y[i],y[j]);
452
+ swap(G[i],G[j]);
453
+ swap(alpha_status[i],alpha_status[j]);
454
+ swap(alpha[i],alpha[j]);
455
+ swap(p[i],p[j]);
456
+ swap(active_set[i],active_set[j]);
457
+ swap(G_bar[i],G_bar[j]);
458
+ }
459
+
460
+ void Solver::reconstruct_gradient()
461
+ {
462
+ // reconstruct inactive elements of G from G_bar and free variables
463
+
464
+ if(active_size == l) return;
465
+
466
+ int i,j;
467
+ int nr_free = 0;
468
+
469
+ for(j=active_size;j<l;j++)
470
+ G[j] = G_bar[j] + p[j];
471
+
472
+ for(j=0;j<active_size;j++)
473
+ if(is_free(j))
474
+ nr_free++;
475
+
476
+ if(2*nr_free < active_size)
477
+ info("\nWarning: using -h 0 may be faster\n");
478
+
479
+ if (nr_free*l > 2*active_size*(l-active_size))
480
+ {
481
+ for(i=active_size;i<l;i++)
482
+ {
483
+ const Qfloat *Q_i = Q->get_Q(i,active_size);
484
+ for(j=0;j<active_size;j++)
485
+ if(is_free(j))
486
+ G[i] += alpha[j] * Q_i[j];
487
+ }
488
+ }
489
+ else
490
+ {
491
+ for(i=0;i<active_size;i++)
492
+ if(is_free(i))
493
+ {
494
+ const Qfloat *Q_i = Q->get_Q(i,l);
495
+ double alpha_i = alpha[i];
496
+ for(j=active_size;j<l;j++)
497
+ G[j] += alpha_i * Q_i[j];
498
+ }
499
+ }
500
+ }
501
+
502
// Run SMO on the problem described in the class header comment.
//
// Inputs are copied (p, y, alpha) so the caller's arrays are not disturbed
// while variables are reordered. On return:
//   - alpha_ holds the solution in the caller's original order;
//   - si receives rho, the objective value, and the bounds Cp/Cn.
// Shrinking (if enabled) runs every min(l,1000) iterations. The loop ends
// only when select_working_set reports optimality on the full variable set;
// there is no iteration cap in this version.
void Solver::Solve(int l, const QMatrix& Q, const double *p_, const schar *y_,
		   double *alpha_, double Cp, double Cn, double eps,
		   SolutionInfo* si, int shrinking)
{
	this->l = l;
	this->Q = &Q;
	QD=Q.get_QD();
	clone(p, p_,l);
	clone(y, y_,l);
	clone(alpha,alpha_,l);
	this->Cp = Cp;
	this->Cn = Cn;
	this->eps = eps;
	unshrink = false;

	// initialize alpha_status
	{
		alpha_status = new char[l];
		for(int i=0;i<l;i++)
			update_alpha_status(i);
	}

	// initialize active set (for shrinking)
	{
		active_set = new int[l];
		for(int i=0;i<l;i++)
			active_set[i] = i;
		active_size = l;
	}

	// initialize gradient
	// G = p + Q*alpha; G_bar = sum of C_i * Q_i over upper-bound variables
	{
		G = new double[l];
		G_bar = new double[l];
		int i;
		for(i=0;i<l;i++)
		{
			G[i] = p[i];
			G_bar[i] = 0;
		}
		for(i=0;i<l;i++)
			if(!is_lower_bound(i))
			{
				const Qfloat *Q_i = Q.get_Q(i,l);
				double alpha_i = alpha[i];
				int j;
				for(j=0;j<l;j++)
					G[j] += alpha_i*Q_i[j];
				if(is_upper_bound(i))
					for(j=0;j<l;j++)
						G_bar[j] += get_C(i) * Q_i[j];
			}
	}

	// optimization step

	int iter = 0;
	int counter = min(l,1000)+1;

	while(1)
	{
		// show progress and do shrinking

		if(--counter == 0)
		{
			counter = min(l,1000);
			if(shrinking) do_shrinking();
			info(".");
		}

		int i,j;
		// nonzero means optimal on the current active set; verify on all variables
		if(select_working_set(i,j)!=0)
		{
			// reconstruct the whole gradient
			reconstruct_gradient();
			// reset active set size and check
			active_size = l;
			info("*");
			if(select_working_set(i,j)!=0)
				break;
			else
				counter = 1;	// do shrinking next iteration
		}

		++iter;

		// update alpha[i] and alpha[j], handle bounds carefully

		const Qfloat *Q_i = Q.get_Q(i,active_size);
		const Qfloat *Q_j = Q.get_Q(j,active_size);

		double C_i = get_C(i);
		double C_j = get_C(j);

		double old_alpha_i = alpha[i];
		double old_alpha_j = alpha[j];

		if(y[i]!=y[j])
		{
			// opposite labels: the step keeps alpha_i - alpha_j constant,
			// then clips the pair back into the box
			double quad_coef = QD[i]+QD[j]+2*Q_i[j];
			if (quad_coef <= 0)
				quad_coef = TAU;
			double delta = (-G[i]-G[j])/quad_coef;
			double diff = alpha[i] - alpha[j];
			alpha[i] += delta;
			alpha[j] += delta;

			if(diff > 0)
			{
				if(alpha[j] < 0)
				{
					alpha[j] = 0;
					alpha[i] = diff;
				}
			}
			else
			{
				if(alpha[i] < 0)
				{
					alpha[i] = 0;
					alpha[j] = -diff;
				}
			}
			if(diff > C_i - C_j)
			{
				if(alpha[i] > C_i)
				{
					alpha[i] = C_i;
					alpha[j] = C_i - diff;
				}
			}
			else
			{
				if(alpha[j] > C_j)
				{
					alpha[j] = C_j;
					alpha[i] = C_j + diff;
				}
			}
		}
		else
		{
			// equal labels: the step keeps alpha_i + alpha_j constant,
			// then clips the pair back into the box
			double quad_coef = QD[i]+QD[j]-2*Q_i[j];
			if (quad_coef <= 0)
				quad_coef = TAU;
			double delta = (G[i]-G[j])/quad_coef;
			double sum = alpha[i] + alpha[j];
			alpha[i] -= delta;
			alpha[j] += delta;

			if(sum > C_i)
			{
				if(alpha[i] > C_i)
				{
					alpha[i] = C_i;
					alpha[j] = sum - C_i;
				}
			}
			else
			{
				if(alpha[j] < 0)
				{
					alpha[j] = 0;
					alpha[i] = sum;
				}
			}
			if(sum > C_j)
			{
				if(alpha[j] > C_j)
				{
					alpha[j] = C_j;
					alpha[i] = sum - C_j;
				}
			}
			else
			{
				if(alpha[i] < 0)
				{
					alpha[i] = 0;
					alpha[j] = sum;
				}
			}
		}

		// update G
		// only the active part; inactive entries are rebuilt by reconstruct_gradient

		double delta_alpha_i = alpha[i] - old_alpha_i;
		double delta_alpha_j = alpha[j] - old_alpha_j;

		for(int k=0;k<active_size;k++)
		{
			G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j;
		}

		// update alpha_status and G_bar
		// G_bar changes only when a variable enters or leaves the upper bound;
		// this needs the full row (length l)

		{
			bool ui = is_upper_bound(i);
			bool uj = is_upper_bound(j);
			update_alpha_status(i);
			update_alpha_status(j);
			int k;
			if(ui != is_upper_bound(i))
			{
				Q_i = Q.get_Q(i,l);
				if(ui)
					for(k=0;k<l;k++)
						G_bar[k] -= C_i * Q_i[k];
				else
					for(k=0;k<l;k++)
						G_bar[k] += C_i * Q_i[k];
			}

			if(uj != is_upper_bound(j))
			{
				Q_j = Q.get_Q(j,l);
				if(uj)
					for(k=0;k<l;k++)
						G_bar[k] -= C_j * Q_j[k];
				else
					for(k=0;k<l;k++)
						G_bar[k] += C_j * Q_j[k];
			}
		}
	}

	// calculate rho

	si->rho = calculate_rho();

	// calculate objective value
	// with G = Q*alpha + p: alpha^T (G + p) / 2 = 0.5 alpha^T Q alpha + p^T alpha
	{
		double v = 0;
		int i;
		for(i=0;i<l;i++)
			v += alpha[i] * (G[i] + p[i]);

		si->obj = v/2;
	}

	// put back the solution
	// active_set maps each current position to the caller's original index
	{
		for(int i=0;i<l;i++)
			alpha_[active_set[i]] = alpha[i];
	}

	// juggle everything back
	/*{
		for(int i=0;i<l;i++)
			while(active_set[i] != i)
				swap_index(i,active_set[i]);
				// or Q.swap_index(i,active_set[i]);
	}*/

	si->upper_bound_p = Cp;
	si->upper_bound_n = Cn;

	info("\noptimization finished, #iter = %d\n",iter);

	delete[] p;
	delete[] y;
	delete[] alpha;
	delete[] alpha_status;
	delete[] active_set;
	delete[] G;
	delete[] G_bar;
}
769
+
770
+ // return 1 if already optimal, return 0 otherwise
771
+ int Solver::select_working_set(int &out_i, int &out_j)
772
+ {
773
+ // return i,j such that
774
+ // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha)
775
+ // j: minimizes the decrease of obj value
776
+ // (if quadratic coefficeint <= 0, replace it with tau)
777
+ // -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha)
778
+
779
+ double Gmax = -INF;
780
+ double Gmax2 = -INF;
781
+ int Gmax_idx = -1;
782
+ int Gmin_idx = -1;
783
+ double obj_diff_min = INF;
784
+
785
+ for(int t=0;t<active_size;t++)
786
+ if(y[t]==+1)
787
+ {
788
+ if(!is_upper_bound(t))
789
+ if(-G[t] >= Gmax)
790
+ {
791
+ Gmax = -G[t];
792
+ Gmax_idx = t;
793
+ }
794
+ }
795
+ else
796
+ {
797
+ if(!is_lower_bound(t))
798
+ if(G[t] >= Gmax)
799
+ {
800
+ Gmax = G[t];
801
+ Gmax_idx = t;
802
+ }
803
+ }
804
+
805
+ int i = Gmax_idx;
806
+ const Qfloat *Q_i = NULL;
807
+ if(i != -1) // NULL Q_i not accessed: Gmax=-INF if i=-1
808
+ Q_i = Q->get_Q(i,active_size);
809
+
810
+ for(int j=0;j<active_size;j++)
811
+ {
812
+ if(y[j]==+1)
813
+ {
814
+ if (!is_lower_bound(j))
815
+ {
816
+ double grad_diff=Gmax+G[j];
817
+ if (G[j] >= Gmax2)
818
+ Gmax2 = G[j];
819
+ if (grad_diff > 0)
820
+ {
821
+ double obj_diff;
822
+ double quad_coef = QD[i]+QD[j]-2.0*y[i]*Q_i[j];
823
+ if (quad_coef > 0)
824
+ obj_diff = -(grad_diff*grad_diff)/quad_coef;
825
+ else
826
+ obj_diff = -(grad_diff*grad_diff)/TAU;
827
+
828
+ if (obj_diff <= obj_diff_min)
829
+ {
830
+ Gmin_idx=j;
831
+ obj_diff_min = obj_diff;
832
+ }
833
+ }
834
+ }
835
+ }
836
+ else
837
+ {
838
+ if (!is_upper_bound(j))
839
+ {
840
+ double grad_diff= Gmax-G[j];
841
+ if (-G[j] >= Gmax2)
842
+ Gmax2 = -G[j];
843
+ if (grad_diff > 0)
844
+ {
845
+ double obj_diff;
846
+ double quad_coef = QD[i]+QD[j]+2.0*y[i]*Q_i[j];
847
+ if (quad_coef > 0)
848
+ obj_diff = -(grad_diff*grad_diff)/quad_coef;
849
+ else
850
+ obj_diff = -(grad_diff*grad_diff)/TAU;
851
+
852
+ if (obj_diff <= obj_diff_min)
853
+ {
854
+ Gmin_idx=j;
855
+ obj_diff_min = obj_diff;
856
+ }
857
+ }
858
+ }
859
+ }
860
+ }
861
+
862
+ if(Gmax+Gmax2 < eps)
863
+ return 1;
864
+
865
+ out_i = Gmax_idx;
866
+ out_j = Gmin_idx;
867
+ return 0;
868
+ }
869
+
870
+ bool Solver::be_shrunk(int i, double Gmax1, double Gmax2)
871
+ {
872
+ if(is_upper_bound(i))
873
+ {
874
+ if(y[i]==+1)
875
+ return(-G[i] > Gmax1);
876
+ else
877
+ return(-G[i] > Gmax2);
878
+ }
879
+ else if(is_lower_bound(i))
880
+ {
881
+ if(y[i]==+1)
882
+ return(G[i] > Gmax2);
883
+ else
884
+ return(G[i] > Gmax1);
885
+ }
886
+ else
887
+ return(false);
888
+ }
889
+
890
+ void Solver::do_shrinking()
891
+ {
892
+ int i;
893
+ double Gmax1 = -INF; // max { -y_i * grad(f)_i | i in I_up(\alpha) }
894
+ double Gmax2 = -INF; // max { y_i * grad(f)_i | i in I_low(\alpha) }
895
+
896
+ // find maximal violating pair first
897
+ for(i=0;i<active_size;i++)
898
+ {
899
+ if(y[i]==+1)
900
+ {
901
+ if(!is_upper_bound(i))
902
+ {
903
+ if(-G[i] >= Gmax1)
904
+ Gmax1 = -G[i];
905
+ }
906
+ if(!is_lower_bound(i))
907
+ {
908
+ if(G[i] >= Gmax2)
909
+ Gmax2 = G[i];
910
+ }
911
+ }
912
+ else
913
+ {
914
+ if(!is_upper_bound(i))
915
+ {
916
+ if(-G[i] >= Gmax2)
917
+ Gmax2 = -G[i];
918
+ }
919
+ if(!is_lower_bound(i))
920
+ {
921
+ if(G[i] >= Gmax1)
922
+ Gmax1 = G[i];
923
+ }
924
+ }
925
+ }
926
+
927
+ if(unshrink == false && Gmax1 + Gmax2 <= eps*10)
928
+ {
929
+ unshrink = true;
930
+ reconstruct_gradient();
931
+ active_size = l;
932
+ info("*");
933
+ }
934
+
935
+ for(i=0;i<active_size;i++)
936
+ if (be_shrunk(i, Gmax1, Gmax2))
937
+ {
938
+ active_size--;
939
+ while (active_size > i)
940
+ {
941
+ if (!be_shrunk(active_size, Gmax1, Gmax2))
942
+ {
943
+ swap_index(i,active_size);
944
+ break;
945
+ }
946
+ active_size--;
947
+ }
948
+ }
949
+ }
950
+
951
+ double Solver::calculate_rho()
952
+ {
953
+ double r;
954
+ int nr_free = 0;
955
+ double ub = INF, lb = -INF, sum_free = 0;
956
+ for(int i=0;i<active_size;i++)
957
+ {
958
+ double yG = y[i]*G[i];
959
+
960
+ if(is_upper_bound(i))
961
+ {
962
+ if(y[i]==-1)
963
+ ub = min(ub,yG);
964
+ else
965
+ lb = max(lb,yG);
966
+ }
967
+ else if(is_lower_bound(i))
968
+ {
969
+ if(y[i]==+1)
970
+ ub = min(ub,yG);
971
+ else
972
+ lb = max(lb,yG);
973
+ }
974
+ else
975
+ {
976
+ ++nr_free;
977
+ sum_free += yG;
978
+ }
979
+ }
980
+
981
+ if(nr_free>0)
982
+ r = sum_free/nr_free;
983
+ else
984
+ r = (ub+lb)/2;
985
+
986
+ return r;
987
+ }
988
+
989
//
// Solver for nu-svm classification and regression
//
// additional constraint: e^T \alpha = constant
//
// Differs from Solver in working-set selection (i and j share a label),
// shrinking (bounds tracked per label) and rho computation (also reports r).
class Solver_NU : public Solver
{
public:
	Solver_NU() {}
	// Keeps the caller's SolutionInfo pointer so that calculate_rho() below
	// can store si->r, then runs the generic Solver::Solve.
	void Solve(int l, const QMatrix& Q, const double *p, const schar *y,
		   double *alpha, double Cp, double Cn, double eps,
		   SolutionInfo* si, int shrinking)
	{
		this->si = si;
		Solver::Solve(l,Q,p,y,alpha,Cp,Cn,eps,si,shrinking);
	}
private:
	SolutionInfo *si;	// output slot written by calculate_rho()
	// Replacements for the corresponding Solver routines. Whether they are
	// virtual overrides depends on the Solver declaration (not visible here).
	int select_working_set(int &i, int &j);
	double calculate_rho();
	bool be_shrunk(int i, double Gmax1, double Gmax2, double Gmax3, double Gmax4);
	void do_shrinking();
};
+
1013
+ // return 1 if already optimal, return 0 otherwise
1014
+ int Solver_NU::select_working_set(int &out_i, int &out_j)
1015
+ {
1016
+ // return i,j such that y_i = y_j and
1017
+ // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha)
1018
+ // j: minimizes the decrease of obj value
1019
+ // (if quadratic coefficeint <= 0, replace it with tau)
1020
+ // -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha)
1021
+
1022
+ double Gmaxp = -INF;
1023
+ double Gmaxp2 = -INF;
1024
+ int Gmaxp_idx = -1;
1025
+
1026
+ double Gmaxn = -INF;
1027
+ double Gmaxn2 = -INF;
1028
+ int Gmaxn_idx = -1;
1029
+
1030
+ int Gmin_idx = -1;
1031
+ double obj_diff_min = INF;
1032
+
1033
+ for(int t=0;t<active_size;t++)
1034
+ if(y[t]==+1)
1035
+ {
1036
+ if(!is_upper_bound(t))
1037
+ if(-G[t] >= Gmaxp)
1038
+ {
1039
+ Gmaxp = -G[t];
1040
+ Gmaxp_idx = t;
1041
+ }
1042
+ }
1043
+ else
1044
+ {
1045
+ if(!is_lower_bound(t))
1046
+ if(G[t] >= Gmaxn)
1047
+ {
1048
+ Gmaxn = G[t];
1049
+ Gmaxn_idx = t;
1050
+ }
1051
+ }
1052
+
1053
+ int ip = Gmaxp_idx;
1054
+ int in = Gmaxn_idx;
1055
+ const Qfloat *Q_ip = NULL;
1056
+ const Qfloat *Q_in = NULL;
1057
+ if(ip != -1) // NULL Q_ip not accessed: Gmaxp=-INF if ip=-1
1058
+ Q_ip = Q->get_Q(ip,active_size);
1059
+ if(in != -1)
1060
+ Q_in = Q->get_Q(in,active_size);
1061
+
1062
+ for(int j=0;j<active_size;j++)
1063
+ {
1064
+ if(y[j]==+1)
1065
+ {
1066
+ if (!is_lower_bound(j))
1067
+ {
1068
+ double grad_diff=Gmaxp+G[j];
1069
+ if (G[j] >= Gmaxp2)
1070
+ Gmaxp2 = G[j];
1071
+ if (grad_diff > 0)
1072
+ {
1073
+ double obj_diff;
1074
+ double quad_coef = QD[ip]+QD[j]-2*Q_ip[j];
1075
+ if (quad_coef > 0)
1076
+ obj_diff = -(grad_diff*grad_diff)/quad_coef;
1077
+ else
1078
+ obj_diff = -(grad_diff*grad_diff)/TAU;
1079
+
1080
+ if (obj_diff <= obj_diff_min)
1081
+ {
1082
+ Gmin_idx=j;
1083
+ obj_diff_min = obj_diff;
1084
+ }
1085
+ }
1086
+ }
1087
+ }
1088
+ else
1089
+ {
1090
+ if (!is_upper_bound(j))
1091
+ {
1092
+ double grad_diff=Gmaxn-G[j];
1093
+ if (-G[j] >= Gmaxn2)
1094
+ Gmaxn2 = -G[j];
1095
+ if (grad_diff > 0)
1096
+ {
1097
+ double obj_diff;
1098
+ double quad_coef = QD[in]+QD[j]-2*Q_in[j];
1099
+ if (quad_coef > 0)
1100
+ obj_diff = -(grad_diff*grad_diff)/quad_coef;
1101
+ else
1102
+ obj_diff = -(grad_diff*grad_diff)/TAU;
1103
+
1104
+ if (obj_diff <= obj_diff_min)
1105
+ {
1106
+ Gmin_idx=j;
1107
+ obj_diff_min = obj_diff;
1108
+ }
1109
+ }
1110
+ }
1111
+ }
1112
+ }
1113
+
1114
+ if(max(Gmaxp+Gmaxp2,Gmaxn+Gmaxn2) < eps)
1115
+ return 1;
1116
+
1117
+ if (y[Gmin_idx] == +1)
1118
+ out_i = Gmaxp_idx;
1119
+ else
1120
+ out_i = Gmaxn_idx;
1121
+ out_j = Gmin_idx;
1122
+
1123
+ return 0;
1124
+ }
1125
+
1126
+ bool Solver_NU::be_shrunk(int i, double Gmax1, double Gmax2, double Gmax3, double Gmax4)
1127
+ {
1128
+ if(is_upper_bound(i))
1129
+ {
1130
+ if(y[i]==+1)
1131
+ return(-G[i] > Gmax1);
1132
+ else
1133
+ return(-G[i] > Gmax4);
1134
+ }
1135
+ else if(is_lower_bound(i))
1136
+ {
1137
+ if(y[i]==+1)
1138
+ return(G[i] > Gmax2);
1139
+ else
1140
+ return(G[i] > Gmax3);
1141
+ }
1142
+ else
1143
+ return(false);
1144
+ }
1145
+
1146
+ void Solver_NU::do_shrinking()
1147
+ {
1148
+ double Gmax1 = -INF; // max { -y_i * grad(f)_i | y_i = +1, i in I_up(\alpha) }
1149
+ double Gmax2 = -INF; // max { y_i * grad(f)_i | y_i = +1, i in I_low(\alpha) }
1150
+ double Gmax3 = -INF; // max { -y_i * grad(f)_i | y_i = -1, i in I_up(\alpha) }
1151
+ double Gmax4 = -INF; // max { y_i * grad(f)_i | y_i = -1, i in I_low(\alpha) }
1152
+
1153
+ // find maximal violating pair first
1154
+ int i;
1155
+ for(i=0;i<active_size;i++)
1156
+ {
1157
+ if(!is_upper_bound(i))
1158
+ {
1159
+ if(y[i]==+1)
1160
+ {
1161
+ if(-G[i] > Gmax1) Gmax1 = -G[i];
1162
+ }
1163
+ else if(-G[i] > Gmax4) Gmax4 = -G[i];
1164
+ }
1165
+ if(!is_lower_bound(i))
1166
+ {
1167
+ if(y[i]==+1)
1168
+ {
1169
+ if(G[i] > Gmax2) Gmax2 = G[i];
1170
+ }
1171
+ else if(G[i] > Gmax3) Gmax3 = G[i];
1172
+ }
1173
+ }
1174
+
1175
+ if(unshrink == false && max(Gmax1+Gmax2,Gmax3+Gmax4) <= eps*10)
1176
+ {
1177
+ unshrink = true;
1178
+ reconstruct_gradient();
1179
+ active_size = l;
1180
+ }
1181
+
1182
+ for(i=0;i<active_size;i++)
1183
+ if (be_shrunk(i, Gmax1, Gmax2, Gmax3, Gmax4))
1184
+ {
1185
+ active_size--;
1186
+ while (active_size > i)
1187
+ {
1188
+ if (!be_shrunk(active_size, Gmax1, Gmax2, Gmax3, Gmax4))
1189
+ {
1190
+ swap_index(i,active_size);
1191
+ break;
1192
+ }
1193
+ active_size--;
1194
+ }
1195
+ }
1196
+ }
1197
+
1198
+ double Solver_NU::calculate_rho()
1199
+ {
1200
+ int nr_free1 = 0,nr_free2 = 0;
1201
+ double ub1 = INF, ub2 = INF;
1202
+ double lb1 = -INF, lb2 = -INF;
1203
+ double sum_free1 = 0, sum_free2 = 0;
1204
+
1205
+ for(int i=0;i<active_size;i++)
1206
+ {
1207
+ if(y[i]==+1)
1208
+ {
1209
+ if(is_upper_bound(i))
1210
+ lb1 = max(lb1,G[i]);
1211
+ else if(is_lower_bound(i))
1212
+ ub1 = min(ub1,G[i]);
1213
+ else
1214
+ {
1215
+ ++nr_free1;
1216
+ sum_free1 += G[i];
1217
+ }
1218
+ }
1219
+ else
1220
+ {
1221
+ if(is_upper_bound(i))
1222
+ lb2 = max(lb2,G[i]);
1223
+ else if(is_lower_bound(i))
1224
+ ub2 = min(ub2,G[i]);
1225
+ else
1226
+ {
1227
+ ++nr_free2;
1228
+ sum_free2 += G[i];
1229
+ }
1230
+ }
1231
+ }
1232
+
1233
+ double r1,r2;
1234
+ if(nr_free1 > 0)
1235
+ r1 = sum_free1/nr_free1;
1236
+ else
1237
+ r1 = (ub1+lb1)/2;
1238
+
1239
+ if(nr_free2 > 0)
1240
+ r2 = sum_free2/nr_free2;
1241
+ else
1242
+ r2 = (ub2+lb2)/2;
1243
+
1244
+ si->r = (r1+r2)/2;
1245
+ return (r1-r2)/2;
1246
+ }
1247
+
1248
+ //
1249
+ // Q matrices for various formulations
1250
+ //
1251
+ class SVC_Q: public Kernel
1252
+ {
1253
+ public:
1254
+ SVC_Q(const svm_problem& prob, const svm_parameter& param, const schar *y_)
1255
+ :Kernel(prob.l, prob.x, param)
1256
+ {
1257
+ clone(y,y_,prob.l);
1258
+ cache = new Cache(prob.l,(long int)(param.cache_size*(1<<20)));
1259
+ QD = new double[prob.l];
1260
+ for(int i=0;i<prob.l;i++)
1261
+ QD[i] = (this->*kernel_function)(i,i);
1262
+ }
1263
+
1264
+ Qfloat *get_Q(int i, int len) const
1265
+ {
1266
+ Qfloat *data;
1267
+ int start, j;
1268
+ if((start = cache->get_data(i,&data,len)) < len)
1269
+ {
1270
+ for(j=start;j<len;j++)
1271
+ data[j] = (Qfloat)(y[i]*y[j]*(this->*kernel_function)(i,j));
1272
+ }
1273
+ return data;
1274
+ }
1275
+
1276
+ double *get_QD() const
1277
+ {
1278
+ return QD;
1279
+ }
1280
+
1281
+ void swap_index(int i, int j) const
1282
+ {
1283
+ cache->swap_index(i,j);
1284
+ Kernel::swap_index(i,j);
1285
+ swap(y[i],y[j]);
1286
+ swap(QD[i],QD[j]);
1287
+ }
1288
+
1289
+ ~SVC_Q()
1290
+ {
1291
+ delete[] y;
1292
+ delete cache;
1293
+ delete[] QD;
1294
+ }
1295
+ private:
1296
+ schar *y;
1297
+ Cache *cache;
1298
+ double *QD;
1299
+ };
1300
+
1301
+ class ONE_CLASS_Q: public Kernel
1302
+ {
1303
+ public:
1304
+ ONE_CLASS_Q(const svm_problem& prob, const svm_parameter& param)
1305
+ :Kernel(prob.l, prob.x, param)
1306
+ {
1307
+ cache = new Cache(prob.l,(long int)(param.cache_size*(1<<20)));
1308
+ QD = new double[prob.l];
1309
+ for(int i=0;i<prob.l;i++)
1310
+ QD[i] = (this->*kernel_function)(i,i);
1311
+ }
1312
+
1313
+ Qfloat *get_Q(int i, int len) const
1314
+ {
1315
+ Qfloat *data;
1316
+ int start, j;
1317
+ if((start = cache->get_data(i,&data,len)) < len)
1318
+ {
1319
+ for(j=start;j<len;j++)
1320
+ data[j] = (Qfloat)(this->*kernel_function)(i,j);
1321
+ }
1322
+ return data;
1323
+ }
1324
+
1325
+ double *get_QD() const
1326
+ {
1327
+ return QD;
1328
+ }
1329
+
1330
+ void swap_index(int i, int j) const
1331
+ {
1332
+ cache->swap_index(i,j);
1333
+ Kernel::swap_index(i,j);
1334
+ swap(QD[i],QD[j]);
1335
+ }
1336
+
1337
+ ~ONE_CLASS_Q()
1338
+ {
1339
+ delete cache;
1340
+ delete[] QD;
1341
+ }
1342
+ private:
1343
+ Cache *cache;
1344
+ double *QD;
1345
+ };
1346
+
1347
+ class SVR_Q: public Kernel
1348
+ {
1349
+ public:
1350
+ SVR_Q(const svm_problem& prob, const svm_parameter& param)
1351
+ :Kernel(prob.l, prob.x, param)
1352
+ {
1353
+ l = prob.l;
1354
+ cache = new Cache(l,(long int)(param.cache_size*(1<<20)));
1355
+ QD = new double[2*l];
1356
+ sign = new schar[2*l];
1357
+ index = new int[2*l];
1358
+ for(int k=0;k<l;k++)
1359
+ {
1360
+ sign[k] = 1;
1361
+ sign[k+l] = -1;
1362
+ index[k] = k;
1363
+ index[k+l] = k;
1364
+ QD[k] = (this->*kernel_function)(k,k);
1365
+ QD[k+l] = QD[k];
1366
+ }
1367
+ buffer[0] = new Qfloat[2*l];
1368
+ buffer[1] = new Qfloat[2*l];
1369
+ next_buffer = 0;
1370
+ }
1371
+
1372
+ void swap_index(int i, int j) const
1373
+ {
1374
+ swap(sign[i],sign[j]);
1375
+ swap(index[i],index[j]);
1376
+ swap(QD[i],QD[j]);
1377
+ }
1378
+
1379
+ Qfloat *get_Q(int i, int len) const
1380
+ {
1381
+ Qfloat *data;
1382
+ int j, real_i = index[i];
1383
+ if(cache->get_data(real_i,&data,l) < l)
1384
+ {
1385
+ for(j=0;j<l;j++)
1386
+ data[j] = (Qfloat)(this->*kernel_function)(real_i,j);
1387
+ }
1388
+
1389
+ // reorder and copy
1390
+ Qfloat *buf = buffer[next_buffer];
1391
+ next_buffer = 1 - next_buffer;
1392
+ schar si = sign[i];
1393
+ for(j=0;j<len;j++)
1394
+ buf[j] = (Qfloat) si * (Qfloat) sign[j] * data[index[j]];
1395
+ return buf;
1396
+ }
1397
+
1398
+ double *get_QD() const
1399
+ {
1400
+ return QD;
1401
+ }
1402
+
1403
+ ~SVR_Q()
1404
+ {
1405
+ delete cache;
1406
+ delete[] sign;
1407
+ delete[] index;
1408
+ delete[] buffer[0];
1409
+ delete[] buffer[1];
1410
+ delete[] QD;
1411
+ }
1412
+ private:
1413
+ int l;
1414
+ Cache *cache;
1415
+ schar *sign;
1416
+ int *index;
1417
+ mutable int next_buffer;
1418
+ Qfloat *buffer[2];
1419
+ double *QD;
1420
+ };
1421
+
1422
+ //
1423
+ // construct and solve various formulations
1424
+ //
1425
+ static void solve_c_svc(
1426
+ const svm_problem *prob, const svm_parameter* param,
1427
+ double *alpha, Solver::SolutionInfo* si, double Cp, double Cn)
1428
+ {
1429
+ int l = prob->l;
1430
+ double *minus_ones = new double[l];
1431
+ schar *y = new schar[l];
1432
+
1433
+ int i;
1434
+
1435
+ for(i=0;i<l;i++)
1436
+ {
1437
+ alpha[i] = 0;
1438
+ minus_ones[i] = -1;
1439
+ if(prob->y[i] > 0) y[i] = +1; else y[i] = -1;
1440
+ }
1441
+
1442
+ Solver s;
1443
+ s.Solve(l, SVC_Q(*prob,*param,y), minus_ones, y,
1444
+ alpha, Cp, Cn, param->eps, si, param->shrinking);
1445
+
1446
+ double sum_alpha=0;
1447
+ for(i=0;i<l;i++)
1448
+ sum_alpha += alpha[i];
1449
+
1450
+ if (Cp==Cn)
1451
+ info("nu = %f\n", sum_alpha/(Cp*prob->l));
1452
+
1453
+ for(i=0;i<l;i++)
1454
+ alpha[i] *= y[i];
1455
+
1456
+ delete[] minus_ones;
1457
+ delete[] y;
1458
+ }
1459
+
1460
+ static void solve_nu_svc(
1461
+ const svm_problem *prob, const svm_parameter *param,
1462
+ double *alpha, Solver::SolutionInfo* si)
1463
+ {
1464
+ int i;
1465
+ int l = prob->l;
1466
+ double nu = param->nu;
1467
+
1468
+ schar *y = new schar[l];
1469
+
1470
+ for(i=0;i<l;i++)
1471
+ if(prob->y[i]>0)
1472
+ y[i] = +1;
1473
+ else
1474
+ y[i] = -1;
1475
+
1476
+ double sum_pos = nu*l/2;
1477
+ double sum_neg = nu*l/2;
1478
+
1479
+ for(i=0;i<l;i++)
1480
+ if(y[i] == +1)
1481
+ {
1482
+ alpha[i] = min(1.0,sum_pos);
1483
+ sum_pos -= alpha[i];
1484
+ }
1485
+ else
1486
+ {
1487
+ alpha[i] = min(1.0,sum_neg);
1488
+ sum_neg -= alpha[i];
1489
+ }
1490
+
1491
+ double *zeros = new double[l];
1492
+
1493
+ for(i=0;i<l;i++)
1494
+ zeros[i] = 0;
1495
+
1496
+ Solver_NU s;
1497
+ s.Solve(l, SVC_Q(*prob,*param,y), zeros, y,
1498
+ alpha, 1.0, 1.0, param->eps, si, param->shrinking);
1499
+ double r = si->r;
1500
+
1501
+ info("C = %f\n",1/r);
1502
+
1503
+ for(i=0;i<l;i++)
1504
+ alpha[i] *= y[i]/r;
1505
+
1506
+ si->rho /= r;
1507
+ si->obj /= (r*r);
1508
+ si->upper_bound_p = 1/r;
1509
+ si->upper_bound_n = 1/r;
1510
+
1511
+ delete[] y;
1512
+ delete[] zeros;
1513
+ }
1514
+
1515
+ static void solve_one_class(
1516
+ const svm_problem *prob, const svm_parameter *param,
1517
+ double *alpha, Solver::SolutionInfo* si)
1518
+ {
1519
+ int l = prob->l;
1520
+ double *zeros = new double[l];
1521
+ schar *ones = new schar[l];
1522
+ int i;
1523
+
1524
+ int n = (int)(param->nu*prob->l); // # of alpha's at upper bound
1525
+
1526
+ for(i=0;i<n;i++)
1527
+ alpha[i] = 1;
1528
+ if(n<prob->l)
1529
+ alpha[n] = param->nu * prob->l - n;
1530
+ for(i=n+1;i<l;i++)
1531
+ alpha[i] = 0;
1532
+
1533
+ for(i=0;i<l;i++)
1534
+ {
1535
+ zeros[i] = 0;
1536
+ ones[i] = 1;
1537
+ }
1538
+
1539
+ Solver s;
1540
+ s.Solve(l, ONE_CLASS_Q(*prob,*param), zeros, ones,
1541
+ alpha, 1.0, 1.0, param->eps, si, param->shrinking);
1542
+
1543
+ delete[] zeros;
1544
+ delete[] ones;
1545
+ }
1546
+
1547
+ static void solve_epsilon_svr(
1548
+ const svm_problem *prob, const svm_parameter *param,
1549
+ double *alpha, Solver::SolutionInfo* si)
1550
+ {
1551
+ int l = prob->l;
1552
+ double *alpha2 = new double[2*l];
1553
+ double *linear_term = new double[2*l];
1554
+ schar *y = new schar[2*l];
1555
+ int i;
1556
+
1557
+ for(i=0;i<l;i++)
1558
+ {
1559
+ alpha2[i] = 0;
1560
+ linear_term[i] = param->p - prob->y[i];
1561
+ y[i] = 1;
1562
+
1563
+ alpha2[i+l] = 0;
1564
+ linear_term[i+l] = param->p + prob->y[i];
1565
+ y[i+l] = -1;
1566
+ }
1567
+
1568
+ Solver s;
1569
+ s.Solve(2*l, SVR_Q(*prob,*param), linear_term, y,
1570
+ alpha2, param->C, param->C, param->eps, si, param->shrinking);
1571
+
1572
+ double sum_alpha = 0;
1573
+ for(i=0;i<l;i++)
1574
+ {
1575
+ alpha[i] = alpha2[i] - alpha2[i+l];
1576
+ sum_alpha += fabs(alpha[i]);
1577
+ }
1578
+ info("nu = %f\n",sum_alpha/(param->C*l));
1579
+
1580
+ delete[] alpha2;
1581
+ delete[] linear_term;
1582
+ delete[] y;
1583
+ }
1584
+
1585
+ static void solve_nu_svr(
1586
+ const svm_problem *prob, const svm_parameter *param,
1587
+ double *alpha, Solver::SolutionInfo* si)
1588
+ {
1589
+ int l = prob->l;
1590
+ double C = param->C;
1591
+ double *alpha2 = new double[2*l];
1592
+ double *linear_term = new double[2*l];
1593
+ schar *y = new schar[2*l];
1594
+ int i;
1595
+
1596
+ double sum = C * param->nu * l / 2;
1597
+ for(i=0;i<l;i++)
1598
+ {
1599
+ alpha2[i] = alpha2[i+l] = min(sum,C);
1600
+ sum -= alpha2[i];
1601
+
1602
+ linear_term[i] = - prob->y[i];
1603
+ y[i] = 1;
1604
+
1605
+ linear_term[i+l] = prob->y[i];
1606
+ y[i+l] = -1;
1607
+ }
1608
+
1609
+ Solver_NU s;
1610
+ s.Solve(2*l, SVR_Q(*prob,*param), linear_term, y,
1611
+ alpha2, C, C, param->eps, si, param->shrinking);
1612
+
1613
+ info("epsilon = %f\n",-si->r);
1614
+
1615
+ for(i=0;i<l;i++)
1616
+ alpha[i] = alpha2[i] - alpha2[i+l];
1617
+
1618
+ delete[] alpha2;
1619
+ delete[] linear_term;
1620
+ delete[] y;
1621
+ }
1622
+
1623
//
// decision_function
//
// Result of one svm_train_one() call.
struct decision_function
{
	double *alpha;	// one coefficient per training instance (Malloc'ed by svm_train_one, freed by the caller)
	double rho;	// bias term, copied from SolutionInfo::rho
};
+
1632
+ static decision_function svm_train_one(
1633
+ const svm_problem *prob, const svm_parameter *param,
1634
+ double Cp, double Cn)
1635
+ {
1636
+ double *alpha = Malloc(double,prob->l);
1637
+ Solver::SolutionInfo si;
1638
+ switch(param->svm_type)
1639
+ {
1640
+ case C_SVC:
1641
+ solve_c_svc(prob,param,alpha,&si,Cp,Cn);
1642
+ break;
1643
+ case NU_SVC:
1644
+ solve_nu_svc(prob,param,alpha,&si);
1645
+ break;
1646
+ case ONE_CLASS:
1647
+ solve_one_class(prob,param,alpha,&si);
1648
+ break;
1649
+ case EPSILON_SVR:
1650
+ solve_epsilon_svr(prob,param,alpha,&si);
1651
+ break;
1652
+ case NU_SVR:
1653
+ solve_nu_svr(prob,param,alpha,&si);
1654
+ break;
1655
+ }
1656
+
1657
+ info("obj = %f, rho = %f\n",si.obj,si.rho);
1658
+
1659
+ // output SVs
1660
+
1661
+ int nSV = 0;
1662
+ int nBSV = 0;
1663
+ for(int i=0;i<prob->l;i++)
1664
+ {
1665
+ if(fabs(alpha[i]) > 0)
1666
+ {
1667
+ ++nSV;
1668
+ if(prob->y[i] > 0)
1669
+ {
1670
+ if(fabs(alpha[i]) >= si.upper_bound_p)
1671
+ ++nBSV;
1672
+ }
1673
+ else
1674
+ {
1675
+ if(fabs(alpha[i]) >= si.upper_bound_n)
1676
+ ++nBSV;
1677
+ }
1678
+ }
1679
+ }
1680
+
1681
+ info("nSV = %d, nBSV = %d\n",nSV,nBSV);
1682
+
1683
+ decision_function f;
1684
+ f.alpha = alpha;
1685
+ f.rho = si.rho;
1686
+ return f;
1687
+ }
1688
+
1689
// Platt's binary SVM Probablistic Output: an improvement from Lin et al.
//
// Fit A and B so that P(y=1 | f) ~= 1/(1+exp(A*f+B)) for a decision value f
// (the model evaluated by sigmoid_predict). The fit minimizes the
// cross-entropy against smoothed targets t[i] using Newton's method, with a
// sigma*I Hessian regularizer and a backtracking line search.
//   l          number of samples
//   dec_values decision value of each sample
//   labels     > 0 marks the positive class, anything else the negative class
//   A, B       output: fitted sigmoid parameters
static void sigmoid_train(
	int l, const double *dec_values, const double *labels,
	double& A, double& B)
{
	double prior1=0, prior0 = 0;
	int i;

	// class counts
	for (i=0;i<l;i++)
		if (labels[i] > 0) prior1+=1;
		else prior0+=1;

	int max_iter=100;	// Maximal number of iterations
	double min_step=1e-10;	// Minimal step taken in line search
	double sigma=1e-12;	// For numerically strict PD of Hessian
	double eps=1e-5;	// gradient tolerance for the stopping test
	// Platt's smoothed targets, used instead of hard 0/1 labels
	double hiTarget=(prior1+1.0)/(prior1+2.0);
	double loTarget=1/(prior0+2.0);
	double *t=Malloc(double,l);
	double fApB,p,q,h11,h22,h21,g1,g2,det,dA,dB,gd,stepsize;
	double newA,newB,newf,d1,d2;
	int iter;

	// Initial Point and Initial Fun Value
	A=0.0; B=log((prior0+1.0)/(prior1+1.0));
	double fval = 0.0;

	for (i=0;i<l;i++)
	{
		if (labels[i]>0) t[i]=hiTarget;
		else t[i]=loTarget;
		fApB = dec_values[i]*A+B;
		// Both branches compute the same term, t*z + log(1+exp(-z)). They are
		// arranged so that exp() only ever receives a non-positive argument,
		// which avoids overflow.
		if (fApB>=0)
			fval += t[i]*fApB + log(1+exp(-fApB));
		else
			fval += (t[i] - 1)*fApB +log(1+exp(fApB));
	}
	for (iter=0;iter<max_iter;iter++)
	{
		// Update Gradient and Hessian (use H' = H + sigma I)
		h11=sigma; // numerically ensures strict PD
		h22=sigma;
		h21=0.0;g1=0.0;g2=0.0;
		for (i=0;i<l;i++)
		{
			fApB = dec_values[i]*A+B;
			// p = 1/(1+exp(fApB)) and q = 1-p, evaluated in the overflow-safe form
			if (fApB >= 0)
			{
				p=exp(-fApB)/(1.0+exp(-fApB));
				q=1.0/(1.0+exp(-fApB));
			}
			else
			{
				p=1.0/(1.0+exp(fApB));
				q=exp(fApB)/(1.0+exp(fApB));
			}
			d2=p*q;
			h11+=dec_values[i]*dec_values[i]*d2;
			h22+=d2;
			h21+=dec_values[i]*d2;
			d1=t[i]-p;
			g1+=dec_values[i]*d1;
			g2+=d1;
		}

		// Stopping Criteria
		if (fabs(g1)<eps && fabs(g2)<eps)
			break;

		// Finding Newton direction: -inv(H') * g
		det=h11*h22-h21*h21;
		dA=-(h22*g1 - h21 * g2) / det;
		dB=-(-h21*g1+ h11 * g2) / det;
		gd=g1*dA+g2*dB;


		stepsize = 1;		// Line Search
		// halve the step until the Armijo sufficient-decrease test passes
		while (stepsize >= min_step)
		{
			newA = A + stepsize * dA;
			newB = B + stepsize * dB;

			// New function value
			newf = 0.0;
			for (i=0;i<l;i++)
			{
				fApB = dec_values[i]*newA+newB;
				if (fApB >= 0)
					newf += t[i]*fApB + log(1+exp(-fApB));
				else
					newf += (t[i] - 1)*fApB +log(1+exp(fApB));
			}
			// Check sufficient decrease
			if (newf<fval+0.0001*stepsize*gd)
			{
				A=newA;B=newB;fval=newf;
				break;
			}
			else
				stepsize = stepsize / 2.0;
		}

		if (stepsize < min_step)
		{
			info("Line search fails in two-class probability estimates\n");
			break;
		}
	}

	if (iter>=max_iter)
		info("Reaching maximal iterations in two-class probability estimates\n");
	free(t);
}
+
1803
// Evaluate the fitted Platt sigmoid 1/(1+exp(A*f+B)) at f = decision_value.
// exp() only ever receives a non-positive argument, which avoids overflow.
static double sigmoid_predict(double decision_value, double A, double B)
{
	const double z = decision_value*A+B;
	if (z >= 0)
	{
		const double e = exp(-z);
		return e/(1.0+e);
	}
	return 1.0/(1+exp(z));
}
+
1812
// Method 2 from the multiclass_prob paper by Wu, Lin, and Weng
//
// Combine pairwise class probabilities into k class probabilities.
//   k  number of classes
//   r  k x k matrix of pairwise estimates. Presumably r[i][j] estimates
//      P(y=i | y is i or j); it is filled by the caller (not visible here).
//   p  output, length k. It starts uniform and is kept normalized to sum 1.
// Q is built with Q[t][t] = sum_{j!=t} r[j][t]^2 and Q[t][j] = -r[j][t]*r[t][j].
// The loop stops once every (Q p)_t is within eps of p^T Q p, or after
// max_iter sweeps.
static void multiclass_probability(int k, double **r, double *p)
{
	int t,j;
	int iter = 0, max_iter=max(100,k);
	double **Q=Malloc(double *,k);
	double *Qp=Malloc(double,k);
	double pQp, eps=0.005/k;

	for (t=0;t<k;t++)
	{
		p[t]=1.0/k;  // Valid if k = 1
		Q[t]=Malloc(double,k);
		Q[t][t]=0;
		// lower triangle mirrors rows already built (Q is symmetric)
		for (j=0;j<t;j++)
		{
			Q[t][t]+=r[j][t]*r[j][t];
			Q[t][j]=Q[j][t];
		}
		for (j=t+1;j<k;j++)
		{
			Q[t][t]+=r[j][t]*r[j][t];
			Q[t][j]=-r[j][t]*r[t][j];
		}
	}
	for (iter=0;iter<max_iter;iter++)
	{
		// stopping condition, recalculate QP,pQP for numerical accuracy
		pQp=0;
		for (t=0;t<k;t++)
		{
			Qp[t]=0;
			for (j=0;j<k;j++)
				Qp[t]+=Q[t][j]*p[j];
			pQp+=p[t]*Qp[t];
		}
		double max_error=0;
		for (t=0;t<k;t++)
		{
			double error=fabs(Qp[t]-pQp);
			if (error>max_error)
				max_error=error;
		}
		if (max_error<eps) break;

		// One coordinate sweep. p[t] moves by diff, then all of p is divided
		// by (1+diff) so that it still sums to 1. Qp and pQp are updated
		// incrementally to match. NOTE(review): this divides by Q[t][t], which
		// is 0 only if every r[j][t] is 0 for j != t; the caller presumably
		// keeps r away from 0.
		for (t=0;t<k;t++)
		{
			double diff=(-Qp[t]+pQp)/Q[t][t];
			p[t]+=diff;
			pQp=(pQp+diff*(diff*Q[t][t]+2*Qp[t]))/(1+diff)/(1+diff);
			for (j=0;j<k;j++)
			{
				Qp[j]=(Qp[j]+diff*Q[t][j])/(1+diff);
				p[j]/=(1+diff);
			}
		}
	}
	if (iter>=max_iter)
		info("Exceeds max_iter in multiclass_prob\n");
	for(t=0;t<k;t++) free(Q[t]);
	free(Q);
	free(Qp);
}
+
1876
+ // Cross-validation decision values for probability estimates
1877
+ static void svm_binary_svc_probability(
1878
+ const svm_problem *prob, const svm_parameter *param,
1879
+ double Cp, double Cn, double& probA, double& probB)
1880
+ {
1881
+ int i;
1882
+ int nr_fold = 5;
1883
+ int *perm = Malloc(int,prob->l);
1884
+ double *dec_values = Malloc(double,prob->l);
1885
+
1886
+ // random shuffle
1887
+ for(i=0;i<prob->l;i++) perm[i]=i;
1888
+ for(i=0;i<prob->l;i++)
1889
+ {
1890
+ int j = i+rand()%(prob->l-i);
1891
+ swap(perm[i],perm[j]);
1892
+ }
1893
+ for(i=0;i<nr_fold;i++)
1894
+ {
1895
+ int begin = i*prob->l/nr_fold;
1896
+ int end = (i+1)*prob->l/nr_fold;
1897
+ int j,k;
1898
+ struct svm_problem subprob;
1899
+
1900
+ subprob.l = prob->l-(end-begin);
1901
+ subprob.x = Malloc(struct svm_node*,subprob.l);
1902
+ subprob.y = Malloc(double,subprob.l);
1903
+
1904
+ k=0;
1905
+ for(j=0;j<begin;j++)
1906
+ {
1907
+ subprob.x[k] = prob->x[perm[j]];
1908
+ subprob.y[k] = prob->y[perm[j]];
1909
+ ++k;
1910
+ }
1911
+ for(j=end;j<prob->l;j++)
1912
+ {
1913
+ subprob.x[k] = prob->x[perm[j]];
1914
+ subprob.y[k] = prob->y[perm[j]];
1915
+ ++k;
1916
+ }
1917
+ int p_count=0,n_count=0;
1918
+ for(j=0;j<k;j++)
1919
+ if(subprob.y[j]>0)
1920
+ p_count++;
1921
+ else
1922
+ n_count++;
1923
+
1924
+ if(p_count==0 && n_count==0)
1925
+ for(j=begin;j<end;j++)
1926
+ dec_values[perm[j]] = 0;
1927
+ else if(p_count > 0 && n_count == 0)
1928
+ for(j=begin;j<end;j++)
1929
+ dec_values[perm[j]] = 1;
1930
+ else if(p_count == 0 && n_count > 0)
1931
+ for(j=begin;j<end;j++)
1932
+ dec_values[perm[j]] = -1;
1933
+ else
1934
+ {
1935
+ svm_parameter subparam = *param;
1936
+ subparam.probability=0;
1937
+ subparam.C=1.0;
1938
+ subparam.nr_weight=2;
1939
+ subparam.weight_label = Malloc(int,2);
1940
+ subparam.weight = Malloc(double,2);
1941
+ subparam.weight_label[0]=+1;
1942
+ subparam.weight_label[1]=-1;
1943
+ subparam.weight[0]=Cp;
1944
+ subparam.weight[1]=Cn;
1945
+ struct svm_model *submodel = svm_train(&subprob,&subparam);
1946
+ for(j=begin;j<end;j++)
1947
+ {
1948
+ svm_predict_values(submodel,prob->x[perm[j]],&(dec_values[perm[j]]));
1949
+ // ensure +1 -1 order; reason not using CV subroutine
1950
+ dec_values[perm[j]] *= submodel->label[0];
1951
+ }
1952
+ svm_free_and_destroy_model(&submodel);
1953
+ svm_destroy_param(&subparam);
1954
+ }
1955
+ free(subprob.x);
1956
+ free(subprob.y);
1957
+ }
1958
+ sigmoid_train(prob->l,dec_values,prob->y,probA,probB);
1959
+ free(dec_values);
1960
+ free(perm);
1961
+ }
1962
+
1963
+ // Return parameter of a Laplace distribution
1964
+ static double svm_svr_probability(
1965
+ const svm_problem *prob, const svm_parameter *param)
1966
+ {
1967
+ int i;
1968
+ int nr_fold = 5;
1969
+ double *ymv = Malloc(double,prob->l);
1970
+ double mae = 0;
1971
+
1972
+ svm_parameter newparam = *param;
1973
+ newparam.probability = 0;
1974
+ svm_cross_validation(prob,&newparam,nr_fold,ymv);
1975
+ for(i=0;i<prob->l;i++)
1976
+ {
1977
+ ymv[i]=prob->y[i]-ymv[i];
1978
+ mae += fabs(ymv[i]);
1979
+ }
1980
+ mae /= prob->l;
1981
+ double std=sqrt(2*mae*mae);
1982
+ int count=0;
1983
+ mae=0;
1984
+ for(i=0;i<prob->l;i++)
1985
+ if (fabs(ymv[i]) > 5*std)
1986
+ count=count+1;
1987
+ else
1988
+ mae+=fabs(ymv[i]);
1989
+ mae /= (prob->l-count);
1990
+ info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma= %g\n",mae);
1991
+ free(ymv);
1992
+ return mae;
1993
+ }
1994
+
1995
+
1996
+ // label: label name, start: begin of each class, count: #data of classes, perm: indices to the original data
1997
+ // perm, length l, must be allocated before calling this subroutine
1998
+ static void svm_group_classes(const svm_problem *prob, int *nr_class_ret, int **label_ret, int **start_ret, int **count_ret, int *perm)
1999
+ {
2000
+ int l = prob->l;
2001
+ int max_nr_class = 16;
2002
+ int nr_class = 0;
2003
+ int *label = Malloc(int,max_nr_class);
2004
+ int *count = Malloc(int,max_nr_class);
2005
+ int *data_label = Malloc(int,l);
2006
+ int i;
2007
+
2008
+ for(i=0;i<l;i++)
2009
+ {
2010
+ int this_label = (int)prob->y[i];
2011
+ int j;
2012
+ for(j=0;j<nr_class;j++)
2013
+ {
2014
+ if(this_label == label[j])
2015
+ {
2016
+ ++count[j];
2017
+ break;
2018
+ }
2019
+ }
2020
+ data_label[i] = j;
2021
+ if(j == nr_class)
2022
+ {
2023
+ if(nr_class == max_nr_class)
2024
+ {
2025
+ max_nr_class *= 2;
2026
+ label = (int *)realloc(label,max_nr_class*sizeof(int));
2027
+ count = (int *)realloc(count,max_nr_class*sizeof(int));
2028
+ }
2029
+ label[nr_class] = this_label;
2030
+ count[nr_class] = 1;
2031
+ ++nr_class;
2032
+ }
2033
+ }
2034
+
2035
+ int *start = Malloc(int,nr_class);
2036
+ start[0] = 0;
2037
+ for(i=1;i<nr_class;i++)
2038
+ start[i] = start[i-1]+count[i-1];
2039
+ for(i=0;i<l;i++)
2040
+ {
2041
+ perm[start[data_label[i]]] = i;
2042
+ ++start[data_label[i]];
2043
+ }
2044
+ start[0] = 0;
2045
+ for(i=1;i<nr_class;i++)
2046
+ start[i] = start[i-1]+count[i-1];
2047
+
2048
+ *nr_class_ret = nr_class;
2049
+ *label_ret = label;
2050
+ *start_ret = start;
2051
+ *count_ret = count;
2052
+ free(data_label);
2053
+ }
2054
+
2055
+ //
2056
+ // Interface functions
2057
+ //
2058
+ svm_model *svm_train(const svm_problem *prob, const svm_parameter *param)
2059
+ {
2060
+ svm_model *model = Malloc(svm_model,1);
2061
+ model->param = *param;
2062
+ model->free_sv = 0; // XXX
2063
+
2064
+ if(param->svm_type == ONE_CLASS ||
2065
+ param->svm_type == EPSILON_SVR ||
2066
+ param->svm_type == NU_SVR)
2067
+ {
2068
+ // regression or one-class-svm
2069
+ model->nr_class = 2;
2070
+ model->label = NULL;
2071
+ model->nSV = NULL;
2072
+ model->probA = NULL; model->probB = NULL;
2073
+ model->sv_coef = Malloc(double *,1);
2074
+
2075
+ if(param->probability &&
2076
+ (param->svm_type == EPSILON_SVR ||
2077
+ param->svm_type == NU_SVR))
2078
+ {
2079
+ model->probA = Malloc(double,1);
2080
+ model->probA[0] = svm_svr_probability(prob,param);
2081
+ }
2082
+
2083
+ decision_function f = svm_train_one(prob,param,0,0);
2084
+ model->rho = Malloc(double,1);
2085
+ model->rho[0] = f.rho;
2086
+
2087
+ int nSV = 0;
2088
+ int i;
2089
+ for(i=0;i<prob->l;i++)
2090
+ if(fabs(f.alpha[i]) > 0) ++nSV;
2091
+ model->l = nSV;
2092
+ model->SV = Malloc(svm_node *,nSV);
2093
+ model->sv_coef[0] = Malloc(double,nSV);
2094
+ int j = 0;
2095
+ for(i=0;i<prob->l;i++)
2096
+ if(fabs(f.alpha[i]) > 0)
2097
+ {
2098
+ model->SV[j] = prob->x[i];
2099
+ model->sv_coef[0][j] = f.alpha[i];
2100
+ ++j;
2101
+ }
2102
+
2103
+ free(f.alpha);
2104
+ }
2105
+ else
2106
+ {
2107
+ // classification
2108
+ int l = prob->l;
2109
+ int nr_class;
2110
+ int *label = NULL;
2111
+ int *start = NULL;
2112
+ int *count = NULL;
2113
+ int *perm = Malloc(int,l);
2114
+
2115
+ // group training data of the same class
2116
+ svm_group_classes(prob,&nr_class,&label,&start,&count,perm);
2117
+ svm_node **x = Malloc(svm_node *,l);
2118
+ int i;
2119
+ for(i=0;i<l;i++)
2120
+ x[i] = prob->x[perm[i]];
2121
+
2122
+ // calculate weighted C
2123
+
2124
+ double *weighted_C = Malloc(double, nr_class);
2125
+ for(i=0;i<nr_class;i++)
2126
+ weighted_C[i] = param->C;
2127
+ for(i=0;i<param->nr_weight;i++)
2128
+ {
2129
+ int j;
2130
+ for(j=0;j<nr_class;j++)
2131
+ if(param->weight_label[i] == label[j])
2132
+ break;
2133
+ if(j == nr_class)
2134
+ fprintf(stderr,"warning: class label %d specified in weight is not found\n", param->weight_label[i]);
2135
+ else
2136
+ weighted_C[j] *= param->weight[i];
2137
+ }
2138
+
2139
+ // train k*(k-1)/2 models
2140
+
2141
+ bool *nonzero = Malloc(bool,l);
2142
+ for(i=0;i<l;i++)
2143
+ nonzero[i] = false;
2144
+ decision_function *f = Malloc(decision_function,nr_class*(nr_class-1)/2);
2145
+
2146
+ double *probA=NULL,*probB=NULL;
2147
+ if (param->probability)
2148
+ {
2149
+ probA=Malloc(double,nr_class*(nr_class-1)/2);
2150
+ probB=Malloc(double,nr_class*(nr_class-1)/2);
2151
+ }
2152
+
2153
+ int p = 0;
2154
+ for(i=0;i<nr_class;i++)
2155
+ for(int j=i+1;j<nr_class;j++)
2156
+ {
2157
+ svm_problem sub_prob;
2158
+ int si = start[i], sj = start[j];
2159
+ int ci = count[i], cj = count[j];
2160
+ sub_prob.l = ci+cj;
2161
+ sub_prob.x = Malloc(svm_node *,sub_prob.l);
2162
+ sub_prob.y = Malloc(double,sub_prob.l);
2163
+ int k;
2164
+ for(k=0;k<ci;k++)
2165
+ {
2166
+ sub_prob.x[k] = x[si+k];
2167
+ sub_prob.y[k] = +1;
2168
+ }
2169
+ for(k=0;k<cj;k++)
2170
+ {
2171
+ sub_prob.x[ci+k] = x[sj+k];
2172
+ sub_prob.y[ci+k] = -1;
2173
+ }
2174
+
2175
+ if(param->probability)
2176
+ svm_binary_svc_probability(&sub_prob,param,weighted_C[i],weighted_C[j],probA[p],probB[p]);
2177
+
2178
+ f[p] = svm_train_one(&sub_prob,param,weighted_C[i],weighted_C[j]);
2179
+ for(k=0;k<ci;k++)
2180
+ if(!nonzero[si+k] && fabs(f[p].alpha[k]) > 0)
2181
+ nonzero[si+k] = true;
2182
+ for(k=0;k<cj;k++)
2183
+ if(!nonzero[sj+k] && fabs(f[p].alpha[ci+k]) > 0)
2184
+ nonzero[sj+k] = true;
2185
+ free(sub_prob.x);
2186
+ free(sub_prob.y);
2187
+ ++p;
2188
+ }
2189
+
2190
+ // build output
2191
+
2192
+ model->nr_class = nr_class;
2193
+
2194
+ model->label = Malloc(int,nr_class);
2195
+ for(i=0;i<nr_class;i++)
2196
+ model->label[i] = label[i];
2197
+
2198
+ model->rho = Malloc(double,nr_class*(nr_class-1)/2);
2199
+ for(i=0;i<nr_class*(nr_class-1)/2;i++)
2200
+ model->rho[i] = f[i].rho;
2201
+
2202
+ if(param->probability)
2203
+ {
2204
+ model->probA = Malloc(double,nr_class*(nr_class-1)/2);
2205
+ model->probB = Malloc(double,nr_class*(nr_class-1)/2);
2206
+ for(i=0;i<nr_class*(nr_class-1)/2;i++)
2207
+ {
2208
+ model->probA[i] = probA[i];
2209
+ model->probB[i] = probB[i];
2210
+ }
2211
+ }
2212
+ else
2213
+ {
2214
+ model->probA=NULL;
2215
+ model->probB=NULL;
2216
+ }
2217
+
2218
+ int total_sv = 0;
2219
+ int *nz_count = Malloc(int,nr_class);
2220
+ model->nSV = Malloc(int,nr_class);
2221
+ for(i=0;i<nr_class;i++)
2222
+ {
2223
+ int nSV = 0;
2224
+ for(int j=0;j<count[i];j++)
2225
+ if(nonzero[start[i]+j])
2226
+ {
2227
+ ++nSV;
2228
+ ++total_sv;
2229
+ }
2230
+ model->nSV[i] = nSV;
2231
+ nz_count[i] = nSV;
2232
+ }
2233
+
2234
+ info("Total nSV = %d\n",total_sv);
2235
+
2236
+ model->l = total_sv;
2237
+ model->SV = Malloc(svm_node *,total_sv);
2238
+ p = 0;
2239
+ for(i=0;i<l;i++)
2240
+ if(nonzero[i]) model->SV[p++] = x[i];
2241
+
2242
+ int *nz_start = Malloc(int,nr_class);
2243
+ nz_start[0] = 0;
2244
+ for(i=1;i<nr_class;i++)
2245
+ nz_start[i] = nz_start[i-1]+nz_count[i-1];
2246
+
2247
+ model->sv_coef = Malloc(double *,nr_class-1);
2248
+ for(i=0;i<nr_class-1;i++)
2249
+ model->sv_coef[i] = Malloc(double,total_sv);
2250
+
2251
+ p = 0;
2252
+ for(i=0;i<nr_class;i++)
2253
+ for(int j=i+1;j<nr_class;j++)
2254
+ {
2255
+ // classifier (i,j): coefficients with
2256
+ // i are in sv_coef[j-1][nz_start[i]...],
2257
+ // j are in sv_coef[i][nz_start[j]...]
2258
+
2259
+ int si = start[i];
2260
+ int sj = start[j];
2261
+ int ci = count[i];
2262
+ int cj = count[j];
2263
+
2264
+ int q = nz_start[i];
2265
+ int k;
2266
+ for(k=0;k<ci;k++)
2267
+ if(nonzero[si+k])
2268
+ model->sv_coef[j-1][q++] = f[p].alpha[k];
2269
+ q = nz_start[j];
2270
+ for(k=0;k<cj;k++)
2271
+ if(nonzero[sj+k])
2272
+ model->sv_coef[i][q++] = f[p].alpha[ci+k];
2273
+ ++p;
2274
+ }
2275
+
2276
+ free(label);
2277
+ free(probA);
2278
+ free(probB);
2279
+ free(count);
2280
+ free(perm);
2281
+ free(start);
2282
+ free(x);
2283
+ free(weighted_C);
2284
+ free(nonzero);
2285
+ for(i=0;i<nr_class*(nr_class-1)/2;i++)
2286
+ free(f[i].alpha);
2287
+ free(f);
2288
+ free(nz_count);
2289
+ free(nz_start);
2290
+ }
2291
+ return model;
2292
+ }
2293
+
2294
+ // Stratified cross validation
2295
+ void svm_cross_validation(const svm_problem *prob, const svm_parameter *param, int nr_fold, double *target)
2296
+ {
2297
+ int i;
2298
+ int *fold_start = Malloc(int,nr_fold+1);
2299
+ int l = prob->l;
2300
+ int *perm = Malloc(int,l);
2301
+ int nr_class;
2302
+
2303
+ // stratified cv may not give leave-one-out rate
2304
+ // Each class to l folds -> some folds may have zero elements
2305
+ if((param->svm_type == C_SVC ||
2306
+ param->svm_type == NU_SVC) && nr_fold < l)
2307
+ {
2308
+ int *start = NULL;
2309
+ int *label = NULL;
2310
+ int *count = NULL;
2311
+ svm_group_classes(prob,&nr_class,&label,&start,&count,perm);
2312
+
2313
+ // random shuffle and then data grouped by fold using the array perm
2314
+ int *fold_count = Malloc(int,nr_fold);
2315
+ int c;
2316
+ int *index = Malloc(int,l);
2317
+ for(i=0;i<l;i++)
2318
+ index[i]=perm[i];
2319
+ for (c=0; c<nr_class; c++)
2320
+ for(i=0;i<count[c];i++)
2321
+ {
2322
+ int j = i+rand()%(count[c]-i);
2323
+ swap(index[start[c]+j],index[start[c]+i]);
2324
+ }
2325
+ for(i=0;i<nr_fold;i++)
2326
+ {
2327
+ fold_count[i] = 0;
2328
+ for (c=0; c<nr_class;c++)
2329
+ fold_count[i]+=(i+1)*count[c]/nr_fold-i*count[c]/nr_fold;
2330
+ }
2331
+ fold_start[0]=0;
2332
+ for (i=1;i<=nr_fold;i++)
2333
+ fold_start[i] = fold_start[i-1]+fold_count[i-1];
2334
+ for (c=0; c<nr_class;c++)
2335
+ for(i=0;i<nr_fold;i++)
2336
+ {
2337
+ int begin = start[c]+i*count[c]/nr_fold;
2338
+ int end = start[c]+(i+1)*count[c]/nr_fold;
2339
+ for(int j=begin;j<end;j++)
2340
+ {
2341
+ perm[fold_start[i]] = index[j];
2342
+ fold_start[i]++;
2343
+ }
2344
+ }
2345
+ fold_start[0]=0;
2346
+ for (i=1;i<=nr_fold;i++)
2347
+ fold_start[i] = fold_start[i-1]+fold_count[i-1];
2348
+ free(start);
2349
+ free(label);
2350
+ free(count);
2351
+ free(index);
2352
+ free(fold_count);
2353
+ }
2354
+ else
2355
+ {
2356
+ for(i=0;i<l;i++) perm[i]=i;
2357
+ for(i=0;i<l;i++)
2358
+ {
2359
+ int j = i+rand()%(l-i);
2360
+ swap(perm[i],perm[j]);
2361
+ }
2362
+ for(i=0;i<=nr_fold;i++)
2363
+ fold_start[i]=i*l/nr_fold;
2364
+ }
2365
+
2366
+ for(i=0;i<nr_fold;i++)
2367
+ {
2368
+ int begin = fold_start[i];
2369
+ int end = fold_start[i+1];
2370
+ int j,k;
2371
+ struct svm_problem subprob;
2372
+
2373
+ subprob.l = l-(end-begin);
2374
+ subprob.x = Malloc(struct svm_node*,subprob.l);
2375
+ subprob.y = Malloc(double,subprob.l);
2376
+
2377
+ k=0;
2378
+ for(j=0;j<begin;j++)
2379
+ {
2380
+ subprob.x[k] = prob->x[perm[j]];
2381
+ subprob.y[k] = prob->y[perm[j]];
2382
+ ++k;
2383
+ }
2384
+ for(j=end;j<l;j++)
2385
+ {
2386
+ subprob.x[k] = prob->x[perm[j]];
2387
+ subprob.y[k] = prob->y[perm[j]];
2388
+ ++k;
2389
+ }
2390
+ struct svm_model *submodel = svm_train(&subprob,param);
2391
+ if(param->probability &&
2392
+ (param->svm_type == C_SVC || param->svm_type == NU_SVC))
2393
+ {
2394
+ double *prob_estimates=Malloc(double,svm_get_nr_class(submodel));
2395
+ for(j=begin;j<end;j++)
2396
+ target[perm[j]] = svm_predict_probability(submodel,prob->x[perm[j]],prob_estimates);
2397
+ free(prob_estimates);
2398
+ }
2399
+ else
2400
+ for(j=begin;j<end;j++)
2401
+ target[perm[j]] = svm_predict(submodel,prob->x[perm[j]]);
2402
+ svm_free_and_destroy_model(&submodel);
2403
+ free(subprob.x);
2404
+ free(subprob.y);
2405
+ }
2406
+ free(fold_start);
2407
+ free(perm);
2408
+ }
2409
+
2410
+
2411
+ int svm_get_svm_type(const svm_model *model)
2412
+ {
2413
+ return model->param.svm_type;
2414
+ }
2415
+
2416
+ int svm_get_nr_class(const svm_model *model)
2417
+ {
2418
+ return model->nr_class;
2419
+ }
2420
+
2421
+ void svm_get_labels(const svm_model *model, int* label)
2422
+ {
2423
+ if (model->label != NULL)
2424
+ for(int i=0;i<model->nr_class;i++)
2425
+ label[i] = model->label[i];
2426
+ }
2427
+
2428
+ double svm_get_svr_probability(const svm_model *model)
2429
+ {
2430
+ if ((model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) &&
2431
+ model->probA!=NULL)
2432
+ return model->probA[0];
2433
+ else
2434
+ {
2435
+ fprintf(stderr,"Model doesn't contain information for SVR probability inference\n");
2436
+ return 0;
2437
+ }
2438
+ }
2439
+
2440
+ double svm_predict_values(const svm_model *model, const svm_node *x, double* dec_values)
2441
+ {
2442
+ if(model->param.svm_type == ONE_CLASS ||
2443
+ model->param.svm_type == EPSILON_SVR ||
2444
+ model->param.svm_type == NU_SVR)
2445
+ {
2446
+ double *sv_coef = model->sv_coef[0];
2447
+ double sum = 0;
2448
+ for(int i=0;i<model->l;i++)
2449
+ sum += sv_coef[i] * Kernel::k_function(x,model->SV[i],model->param);
2450
+ sum -= model->rho[0];
2451
+ *dec_values = sum;
2452
+
2453
+ if(model->param.svm_type == ONE_CLASS)
2454
+ return (sum>0)?1:-1;
2455
+ else
2456
+ return sum;
2457
+ }
2458
+ else
2459
+ {
2460
+ int i;
2461
+ int nr_class = model->nr_class;
2462
+ int l = model->l;
2463
+
2464
+ double *kvalue = Malloc(double,l);
2465
+ for(i=0;i<l;i++)
2466
+ kvalue[i] = Kernel::k_function(x,model->SV[i],model->param);
2467
+
2468
+ int *start = Malloc(int,nr_class);
2469
+ start[0] = 0;
2470
+ for(i=1;i<nr_class;i++)
2471
+ start[i] = start[i-1]+model->nSV[i-1];
2472
+
2473
+ int *vote = Malloc(int,nr_class);
2474
+ for(i=0;i<nr_class;i++)
2475
+ vote[i] = 0;
2476
+
2477
+ int p=0;
2478
+ for(i=0;i<nr_class;i++)
2479
+ for(int j=i+1;j<nr_class;j++)
2480
+ {
2481
+ double sum = 0;
2482
+ int si = start[i];
2483
+ int sj = start[j];
2484
+ int ci = model->nSV[i];
2485
+ int cj = model->nSV[j];
2486
+
2487
+ int k;
2488
+ double *coef1 = model->sv_coef[j-1];
2489
+ double *coef2 = model->sv_coef[i];
2490
+ for(k=0;k<ci;k++)
2491
+ sum += coef1[si+k] * kvalue[si+k];
2492
+ for(k=0;k<cj;k++)
2493
+ sum += coef2[sj+k] * kvalue[sj+k];
2494
+ sum -= model->rho[p];
2495
+ dec_values[p] = sum;
2496
+
2497
+ if(dec_values[p] > 0)
2498
+ ++vote[i];
2499
+ else
2500
+ ++vote[j];
2501
+ p++;
2502
+ }
2503
+
2504
+ int vote_max_idx = 0;
2505
+ for(i=1;i<nr_class;i++)
2506
+ if(vote[i] > vote[vote_max_idx])
2507
+ vote_max_idx = i;
2508
+
2509
+ free(kvalue);
2510
+ free(start);
2511
+ free(vote);
2512
+ return model->label[vote_max_idx];
2513
+ }
2514
+ }
2515
+
2516
+ double svm_predict(const svm_model *model, const svm_node *x)
2517
+ {
2518
+ int nr_class = model->nr_class;
2519
+ double *dec_values;
2520
+ if(model->param.svm_type == ONE_CLASS ||
2521
+ model->param.svm_type == EPSILON_SVR ||
2522
+ model->param.svm_type == NU_SVR)
2523
+ dec_values = Malloc(double, 1);
2524
+ else
2525
+ dec_values = Malloc(double, nr_class*(nr_class-1)/2);
2526
+ double pred_result = svm_predict_values(model, x, dec_values);
2527
+ free(dec_values);
2528
+ return pred_result;
2529
+ }
2530
+
2531
+ double svm_predict_probability(
2532
+ const svm_model *model, const svm_node *x, double *prob_estimates)
2533
+ {
2534
+ if ((model->param.svm_type == C_SVC || model->param.svm_type == NU_SVC) &&
2535
+ model->probA!=NULL && model->probB!=NULL)
2536
+ {
2537
+ int i;
2538
+ int nr_class = model->nr_class;
2539
+ double *dec_values = Malloc(double, nr_class*(nr_class-1)/2);
2540
+ svm_predict_values(model, x, dec_values);
2541
+
2542
+ double min_prob=1e-7;
2543
+ double **pairwise_prob=Malloc(double *,nr_class);
2544
+ for(i=0;i<nr_class;i++)
2545
+ pairwise_prob[i]=Malloc(double,nr_class);
2546
+ int k=0;
2547
+ for(i=0;i<nr_class;i++)
2548
+ for(int j=i+1;j<nr_class;j++)
2549
+ {
2550
+ pairwise_prob[i][j]=min(max(sigmoid_predict(dec_values[k],model->probA[k],model->probB[k]),min_prob),1-min_prob);
2551
+ pairwise_prob[j][i]=1-pairwise_prob[i][j];
2552
+ k++;
2553
+ }
2554
+ multiclass_probability(nr_class,pairwise_prob,prob_estimates);
2555
+
2556
+ int prob_max_idx = 0;
2557
+ for(i=1;i<nr_class;i++)
2558
+ if(prob_estimates[i] > prob_estimates[prob_max_idx])
2559
+ prob_max_idx = i;
2560
+ for(i=0;i<nr_class;i++)
2561
+ free(pairwise_prob[i]);
2562
+ free(dec_values);
2563
+ free(pairwise_prob);
2564
+ return model->label[prob_max_idx];
2565
+ }
2566
+ else
2567
+ return svm_predict(model, x);
2568
+ }
2569
+
2570
// Names used in model files, indexed by param.svm_type / param.kernel_type
// (svm_save_model writes table[value]; svm_load_model stores the index of
// the matching name). The trailing NULL terminates the lookup loops.
// Both the arrays and the pointers they hold are const, so the tables
// cannot be modified at run time.
static const char *const svm_type_table[] =
{
	"c_svc","nu_svc","one_class","epsilon_svr","nu_svr",NULL
};

static const char *const kernel_type_table[]=
{
	"linear","polynomial","rbf","sigmoid","precomputed",NULL
};
2579
+
2580
+ int svm_save_model(const char *model_file_name, const svm_model *model)
2581
+ {
2582
+ FILE *fp = fopen(model_file_name,"w");
2583
+ if(fp==NULL) return -1;
2584
+
2585
+ const svm_parameter& param = model->param;
2586
+
2587
+ fprintf(fp,"svm_type %s\n", svm_type_table[param.svm_type]);
2588
+ fprintf(fp,"kernel_type %s\n", kernel_type_table[param.kernel_type]);
2589
+
2590
+ if(param.kernel_type == POLY)
2591
+ fprintf(fp,"degree %d\n", param.degree);
2592
+
2593
+ if(param.kernel_type == POLY || param.kernel_type == RBF || param.kernel_type == SIGMOID)
2594
+ fprintf(fp,"gamma %g\n", param.gamma);
2595
+
2596
+ if(param.kernel_type == POLY || param.kernel_type == SIGMOID)
2597
+ fprintf(fp,"coef0 %g\n", param.coef0);
2598
+
2599
+ int nr_class = model->nr_class;
2600
+ int l = model->l;
2601
+ fprintf(fp, "nr_class %d\n", nr_class);
2602
+ fprintf(fp, "total_sv %d\n",l);
2603
+
2604
+ {
2605
+ fprintf(fp, "rho");
2606
+ for(int i=0;i<nr_class*(nr_class-1)/2;i++)
2607
+ fprintf(fp," %g",model->rho[i]);
2608
+ fprintf(fp, "\n");
2609
+ }
2610
+
2611
+ if(model->label)
2612
+ {
2613
+ fprintf(fp, "label");
2614
+ for(int i=0;i<nr_class;i++)
2615
+ fprintf(fp," %d",model->label[i]);
2616
+ fprintf(fp, "\n");
2617
+ }
2618
+
2619
+ if(model->probA) // regression has probA only
2620
+ {
2621
+ fprintf(fp, "probA");
2622
+ for(int i=0;i<nr_class*(nr_class-1)/2;i++)
2623
+ fprintf(fp," %g",model->probA[i]);
2624
+ fprintf(fp, "\n");
2625
+ }
2626
+ if(model->probB)
2627
+ {
2628
+ fprintf(fp, "probB");
2629
+ for(int i=0;i<nr_class*(nr_class-1)/2;i++)
2630
+ fprintf(fp," %g",model->probB[i]);
2631
+ fprintf(fp, "\n");
2632
+ }
2633
+
2634
+ if(model->nSV)
2635
+ {
2636
+ fprintf(fp, "nr_sv");
2637
+ for(int i=0;i<nr_class;i++)
2638
+ fprintf(fp," %d",model->nSV[i]);
2639
+ fprintf(fp, "\n");
2640
+ }
2641
+
2642
+ fprintf(fp, "SV\n");
2643
+ const double * const *sv_coef = model->sv_coef;
2644
+ const svm_node * const *SV = model->SV;
2645
+
2646
+ for(int i=0;i<l;i++)
2647
+ {
2648
+ for(int j=0;j<nr_class-1;j++)
2649
+ fprintf(fp, "%.16g ",sv_coef[j][i]);
2650
+
2651
+ const svm_node *p = SV[i];
2652
+
2653
+ if(param.kernel_type == PRECOMPUTED)
2654
+ fprintf(fp,"0:%d ",(int)(p->value));
2655
+ else
2656
+ while(p->index != -1)
2657
+ {
2658
+ fprintf(fp,"%d:%.8g ",p->index,p->value);
2659
+ p++;
2660
+ }
2661
+ fprintf(fp, "\n");
2662
+ }
2663
+ if (ferror(fp) != 0 || fclose(fp) != 0) return -1;
2664
+ else return 0;
2665
+ }
2666
+
2667
// Line buffer shared with svm_load_model, which allocates it (and sets
// max_line_len) before calling readline() and frees it afterwards.
// NOTE: because this is file-scope state, model loading is not reentrant.
static char *line = NULL;
static int max_line_len;

// Reads one line (including its '\n', if any) from `input` into `line`,
// doubling the buffer until the whole line fits. A final line without
// '\n' is returned as-is.
// Returns `line`, or NULL when nothing could be read (EOF or error) or when
// growing the buffer fails. On allocation failure `line` and max_line_len
// are left valid (old buffer, old size) instead of being lost.
static char* readline(FILE *input)
{
	if(fgets(line,max_line_len,input) == NULL)
		return NULL;

	while(strrchr(line,'\n') == NULL)
	{
		// Never assign realloc's result straight to `line`: on failure
		// that would drop the only pointer to the buffer.
		char *grown = (char *) realloc(line,max_line_len*2);
		if(grown == NULL)
			return NULL;
		line = grown;
		max_line_len *= 2;

		int len = (int) strlen(line);
		if(fgets(line+len,max_line_len-len,input) == NULL)
			break;
	}
	return line;
}
2687
+
2688
+ svm_model *svm_load_model(const char *model_file_name)
2689
+ {
2690
+ FILE *fp = fopen(model_file_name,"rb");
2691
+ if(fp==NULL) return NULL;
2692
+
2693
+ // read parameters
2694
+
2695
+ svm_model *model = Malloc(svm_model,1);
2696
+ svm_parameter& param = model->param;
2697
+ model->rho = NULL;
2698
+ model->probA = NULL;
2699
+ model->probB = NULL;
2700
+ model->label = NULL;
2701
+ model->nSV = NULL;
2702
+
2703
+ char cmd[81];
2704
+ while(1)
2705
+ {
2706
+ fscanf(fp,"%80s",cmd);
2707
+
2708
+ if(strcmp(cmd,"svm_type")==0)
2709
+ {
2710
+ fscanf(fp,"%80s",cmd);
2711
+ int i;
2712
+ for(i=0;svm_type_table[i];i++)
2713
+ {
2714
+ if(strcmp(svm_type_table[i],cmd)==0)
2715
+ {
2716
+ param.svm_type=i;
2717
+ break;
2718
+ }
2719
+ }
2720
+ if(svm_type_table[i] == NULL)
2721
+ {
2722
+ fprintf(stderr,"unknown svm type.\n");
2723
+ free(model->rho);
2724
+ free(model->label);
2725
+ free(model->nSV);
2726
+ free(model);
2727
+ return NULL;
2728
+ }
2729
+ }
2730
+ else if(strcmp(cmd,"kernel_type")==0)
2731
+ {
2732
+ fscanf(fp,"%80s",cmd);
2733
+ int i;
2734
+ for(i=0;kernel_type_table[i];i++)
2735
+ {
2736
+ if(strcmp(kernel_type_table[i],cmd)==0)
2737
+ {
2738
+ param.kernel_type=i;
2739
+ break;
2740
+ }
2741
+ }
2742
+ if(kernel_type_table[i] == NULL)
2743
+ {
2744
+ fprintf(stderr,"unknown kernel function.\n");
2745
+ free(model->rho);
2746
+ free(model->label);
2747
+ free(model->nSV);
2748
+ free(model);
2749
+ return NULL;
2750
+ }
2751
+ }
2752
+ else if(strcmp(cmd,"degree")==0)
2753
+ fscanf(fp,"%d",&param.degree);
2754
+ else if(strcmp(cmd,"gamma")==0)
2755
+ fscanf(fp,"%lf",&param.gamma);
2756
+ else if(strcmp(cmd,"coef0")==0)
2757
+ fscanf(fp,"%lf",&param.coef0);
2758
+ else if(strcmp(cmd,"nr_class")==0)
2759
+ fscanf(fp,"%d",&model->nr_class);
2760
+ else if(strcmp(cmd,"total_sv")==0)
2761
+ fscanf(fp,"%d",&model->l);
2762
+ else if(strcmp(cmd,"rho")==0)
2763
+ {
2764
+ int n = model->nr_class * (model->nr_class-1)/2;
2765
+ model->rho = Malloc(double,n);
2766
+ for(int i=0;i<n;i++)
2767
+ fscanf(fp,"%lf",&model->rho[i]);
2768
+ }
2769
+ else if(strcmp(cmd,"label")==0)
2770
+ {
2771
+ int n = model->nr_class;
2772
+ model->label = Malloc(int,n);
2773
+ for(int i=0;i<n;i++)
2774
+ fscanf(fp,"%d",&model->label[i]);
2775
+ }
2776
+ else if(strcmp(cmd,"probA")==0)
2777
+ {
2778
+ int n = model->nr_class * (model->nr_class-1)/2;
2779
+ model->probA = Malloc(double,n);
2780
+ for(int i=0;i<n;i++)
2781
+ fscanf(fp,"%lf",&model->probA[i]);
2782
+ }
2783
+ else if(strcmp(cmd,"probB")==0)
2784
+ {
2785
+ int n = model->nr_class * (model->nr_class-1)/2;
2786
+ model->probB = Malloc(double,n);
2787
+ for(int i=0;i<n;i++)
2788
+ fscanf(fp,"%lf",&model->probB[i]);
2789
+ }
2790
+ else if(strcmp(cmd,"nr_sv")==0)
2791
+ {
2792
+ int n = model->nr_class;
2793
+ model->nSV = Malloc(int,n);
2794
+ for(int i=0;i<n;i++)
2795
+ fscanf(fp,"%d",&model->nSV[i]);
2796
+ }
2797
+ else if(strcmp(cmd,"SV")==0)
2798
+ {
2799
+ while(1)
2800
+ {
2801
+ int c = getc(fp);
2802
+ if(c==EOF || c=='\n') break;
2803
+ }
2804
+ break;
2805
+ }
2806
+ else
2807
+ {
2808
+ fprintf(stderr,"unknown text in model file: [%s]\n",cmd);
2809
+ free(model->rho);
2810
+ free(model->label);
2811
+ free(model->nSV);
2812
+ free(model);
2813
+ return NULL;
2814
+ }
2815
+ }
2816
+
2817
+ // read sv_coef and SV
2818
+
2819
+ int elements = 0;
2820
+ long pos = ftell(fp);
2821
+
2822
+ max_line_len = 1024;
2823
+ line = Malloc(char,max_line_len);
2824
+ char *p,*endptr,*idx,*val;
2825
+
2826
+ while(readline(fp)!=NULL)
2827
+ {
2828
+ p = strtok(line,":");
2829
+ while(1)
2830
+ {
2831
+ p = strtok(NULL,":");
2832
+ if(p == NULL)
2833
+ break;
2834
+ ++elements;
2835
+ }
2836
+ }
2837
+ elements += model->l;
2838
+
2839
+ fseek(fp,pos,SEEK_SET);
2840
+
2841
+ int m = model->nr_class - 1;
2842
+ int l = model->l;
2843
+ model->sv_coef = Malloc(double *,m);
2844
+ int i;
2845
+ for(i=0;i<m;i++)
2846
+ model->sv_coef[i] = Malloc(double,l);
2847
+ model->SV = Malloc(svm_node*,l);
2848
+ svm_node *x_space = NULL;
2849
+ if(l>0) x_space = Malloc(svm_node,elements);
2850
+
2851
+ int j=0;
2852
+ for(i=0;i<l;i++)
2853
+ {
2854
+ readline(fp);
2855
+ model->SV[i] = &x_space[j];
2856
+
2857
+ p = strtok(line, " \t");
2858
+ model->sv_coef[0][i] = strtod(p,&endptr);
2859
+ for(int k=1;k<m;k++)
2860
+ {
2861
+ p = strtok(NULL, " \t");
2862
+ model->sv_coef[k][i] = strtod(p,&endptr);
2863
+ }
2864
+
2865
+ while(1)
2866
+ {
2867
+ idx = strtok(NULL, ":");
2868
+ val = strtok(NULL, " \t");
2869
+
2870
+ if(val == NULL)
2871
+ break;
2872
+ x_space[j].index = (int) strtol(idx,&endptr,10);
2873
+ x_space[j].value = strtod(val,&endptr);
2874
+
2875
+ ++j;
2876
+ }
2877
+ x_space[j++].index = -1;
2878
+ }
2879
+ free(line);
2880
+
2881
+ if (ferror(fp) != 0 || fclose(fp) != 0)
2882
+ return NULL;
2883
+
2884
+ model->free_sv = 1; // XXX
2885
+ return model;
2886
+ }
2887
+
2888
+ void svm_free_model_content(svm_model* model_ptr)
2889
+ {
2890
+ if(model_ptr->free_sv && model_ptr->l > 0)
2891
+ free((void *)(model_ptr->SV[0]));
2892
+ for(int i=0;i<model_ptr->nr_class-1;i++)
2893
+ free(model_ptr->sv_coef[i]);
2894
+ free(model_ptr->SV);
2895
+ free(model_ptr->sv_coef);
2896
+ free(model_ptr->rho);
2897
+ free(model_ptr->label);
2898
+ free(model_ptr->probA);
2899
+ free(model_ptr->probB);
2900
+ free(model_ptr->nSV);
2901
+ }
2902
+
2903
+ void svm_free_and_destroy_model(svm_model** model_ptr_ptr)
2904
+ {
2905
+ svm_model* model_ptr = *model_ptr_ptr;
2906
+ if(model_ptr != NULL)
2907
+ {
2908
+ svm_free_model_content(model_ptr);
2909
+ free(model_ptr);
2910
+ }
2911
+ }
2912
+
2913
+ void svm_destroy_model(svm_model* model_ptr)
2914
+ {
2915
+ fprintf(stderr,"warning: svm_destroy_model is deprecated and should not be used. Please use svm_free_and_destroy_model(svm_model **model_ptr_ptr)\n");
2916
+ svm_free_and_destroy_model(&model_ptr);
2917
+ }
2918
+
2919
+ void svm_destroy_param(svm_parameter* param)
2920
+ {
2921
+ free(param->weight_label);
2922
+ free(param->weight);
2923
+ }
2924
+
2925
+ const char *svm_check_parameter(const svm_problem *prob, const svm_parameter *param)
2926
+ {
2927
+ // svm_type
2928
+
2929
+ int svm_type = param->svm_type;
2930
+ if(svm_type != C_SVC &&
2931
+ svm_type != NU_SVC &&
2932
+ svm_type != ONE_CLASS &&
2933
+ svm_type != EPSILON_SVR &&
2934
+ svm_type != NU_SVR)
2935
+ return "unknown svm type";
2936
+
2937
+ // kernel_type, degree
2938
+
2939
+ int kernel_type = param->kernel_type;
2940
+ if(kernel_type != LINEAR &&
2941
+ kernel_type != POLY &&
2942
+ kernel_type != RBF &&
2943
+ kernel_type != SIGMOID &&
2944
+ kernel_type != PRECOMPUTED)
2945
+ return "unknown kernel type";
2946
+
2947
+ if(param->gamma < 0)
2948
+ return "gamma < 0";
2949
+
2950
+ if(param->degree < 0)
2951
+ return "degree of polynomial kernel < 0";
2952
+
2953
+ // cache_size,eps,C,nu,p,shrinking
2954
+
2955
+ if(param->cache_size <= 0)
2956
+ return "cache_size <= 0";
2957
+
2958
+ if(param->eps <= 0)
2959
+ return "eps <= 0";
2960
+
2961
+ if(svm_type == C_SVC ||
2962
+ svm_type == EPSILON_SVR ||
2963
+ svm_type == NU_SVR)
2964
+ if(param->C <= 0)
2965
+ return "C <= 0";
2966
+
2967
+ if(svm_type == NU_SVC ||
2968
+ svm_type == ONE_CLASS ||
2969
+ svm_type == NU_SVR)
2970
+ if(param->nu <= 0 || param->nu > 1)
2971
+ return "nu <= 0 or nu > 1";
2972
+
2973
+ if(svm_type == EPSILON_SVR)
2974
+ if(param->p < 0)
2975
+ return "p < 0";
2976
+
2977
+ if(param->shrinking != 0 &&
2978
+ param->shrinking != 1)
2979
+ return "shrinking != 0 and shrinking != 1";
2980
+
2981
+ if(param->probability != 0 &&
2982
+ param->probability != 1)
2983
+ return "probability != 0 and probability != 1";
2984
+
2985
+ if(param->probability == 1 &&
2986
+ svm_type == ONE_CLASS)
2987
+ return "one-class SVM probability output not supported yet";
2988
+
2989
+
2990
+ // check whether nu-svc is feasible
2991
+
2992
+ if(svm_type == NU_SVC)
2993
+ {
2994
+ int l = prob->l;
2995
+ int max_nr_class = 16;
2996
+ int nr_class = 0;
2997
+ int *label = Malloc(int,max_nr_class);
2998
+ int *count = Malloc(int,max_nr_class);
2999
+
3000
+ int i;
3001
+ for(i=0;i<l;i++)
3002
+ {
3003
+ int this_label = (int)prob->y[i];
3004
+ int j;
3005
+ for(j=0;j<nr_class;j++)
3006
+ if(this_label == label[j])
3007
+ {
3008
+ ++count[j];
3009
+ break;
3010
+ }
3011
+ if(j == nr_class)
3012
+ {
3013
+ if(nr_class == max_nr_class)
3014
+ {
3015
+ max_nr_class *= 2;
3016
+ label = (int *)realloc(label,max_nr_class*sizeof(int));
3017
+ count = (int *)realloc(count,max_nr_class*sizeof(int));
3018
+ }
3019
+ label[nr_class] = this_label;
3020
+ count[nr_class] = 1;
3021
+ ++nr_class;
3022
+ }
3023
+ }
3024
+
3025
+ for(i=0;i<nr_class;i++)
3026
+ {
3027
+ int n1 = count[i];
3028
+ for(int j=i+1;j<nr_class;j++)
3029
+ {
3030
+ int n2 = count[j];
3031
+ if(param->nu*(n1+n2)/2 > min(n1,n2))
3032
+ {
3033
+ free(label);
3034
+ free(count);
3035
+ return "specified nu is infeasible";
3036
+ }
3037
+ }
3038
+ }
3039
+ free(label);
3040
+ free(count);
3041
+ }
3042
+
3043
+ return NULL;
3044
+ }
3045
+
3046
+ int svm_check_probability_model(const svm_model *model)
3047
+ {
3048
+ return ((model->param.svm_type == C_SVC || model->param.svm_type == NU_SVC) &&
3049
+ model->probA!=NULL && model->probB!=NULL) ||
3050
+ ((model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) &&
3051
+ model->probA!=NULL);
3052
+ }
3053
+
3054
+ void svm_set_print_string_function(void (*print_func)(const char *))
3055
+ {
3056
+ if(print_func == NULL)
3057
+ svm_print_string = &print_string_stdout;
3058
+ else
3059
+ svm_print_string = print_func;
3060
+ }