numo-libsvm 0.3.0 → 1.0.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +5 -5
- data/.github/workflows/build.yml +27 -0
- data/.gitmodules +3 -0
- data/CHANGELOG.md +20 -0
- data/LICENSE.txt +1 -1
- data/README.md +7 -14
- data/ext/numo/libsvm/converter.c +57 -15
- data/ext/numo/libsvm/converter.h +2 -1
- data/ext/numo/libsvm/extconf.rb +7 -11
- data/ext/numo/libsvm/libsvm/svm.cpp +3182 -0
- data/ext/numo/libsvm/libsvm/svm.h +104 -0
- data/ext/numo/libsvm/libsvmext.c +62 -35
- data/ext/numo/libsvm/svm_parameter.c +2 -2
- data/ext/numo/libsvm/svm_problem.c +38 -6
- data/lib/numo/libsvm/version.rb +1 -1
- data/numo-libsvm.gemspec +15 -1
- metadata +16 -11
- data/.travis.yml +0 -14
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
|
-
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
2
|
+
SHA256:
|
3
|
+
metadata.gz: d72a8a48f52a71000c0ff4b1202684a8929a8e34b3004f5f142cccbd9af1e034
|
4
|
+
data.tar.gz: 7c8f38da980376b9ca84191235e10cf98bd9ff44ab546ee3b52d105b57cb4ec7
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 795830bf885ba8164bb95619eb713cd3b893c61c0c3b1f43de66afb7d3632383054e1ca6870d48ca8a69f0b1fc2007e42b5e37ab67ef14848dd7a471c4d195d2
|
7
|
+
data.tar.gz: 41388d1b428d3e4e7a181149cc53a7b1654af3b61961dce67948b86ef54fb96eb24e626d1ffe2645fc59886e37d44ec3ba20333c2e17cab5446513d65f205e8c
|
@@ -0,0 +1,27 @@
|
|
1
|
+
name: build
|
2
|
+
|
3
|
+
on: [push, pull_request]
|
4
|
+
|
5
|
+
jobs:
|
6
|
+
build:
|
7
|
+
runs-on: ubuntu-latest
|
8
|
+
strategy:
|
9
|
+
matrix:
|
10
|
+
ruby: [ '2.5', '2.6', '2.7' ]
|
11
|
+
steps:
|
12
|
+
- uses: actions/checkout@v2
|
13
|
+
- name: Checkout submodule
|
14
|
+
shell: bash
|
15
|
+
run: |
|
16
|
+
auth_header="$(git config --local --get http.https://github.com/.extraheader)"
|
17
|
+
git submodule sync --recursive
|
18
|
+
git -c "http.extraheader=$auth_header" -c protocol.version=2 submodule update --init --force --recursive --depth=1
|
19
|
+
- name: Set up Ruby ${{ matrix.ruby }}
|
20
|
+
uses: actions/setup-ruby@v1
|
21
|
+
with:
|
22
|
+
ruby-version: ${{ matrix.ruby }}
|
23
|
+
- name: Build and test with Rake
|
24
|
+
run: |
|
25
|
+
gem install bundler
|
26
|
+
bundle install --jobs 4 --retry 3
|
27
|
+
bundle exec rake
|
data/.gitmodules
ADDED
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,23 @@
|
|
1
|
+
# 1.0.2
|
2
|
+
- Add GC guard to model saving and loading methods.
|
3
|
+
- Fix size specification to memcpy function.
|
4
|
+
|
5
|
+
# 1.0.1
|
6
|
+
- Add GC guard code.
|
7
|
+
- Fix some configuration files.
|
8
|
+
|
9
|
+
# 1.0.0
|
10
|
+
## Breaking change
|
11
|
+
- For easy installation, Numo::LIBSVM bundles the LIBSVM source code.
|
12
|
+
There is no need to install LIBSVM in advance to use Numo::LIBSVM.
|
13
|
+
|
14
|
+
# 0.5.0
|
15
|
+
- Fix to use LIBSVM sparse vector representation for internal processing.
|
16
|
+
|
17
|
+
# 0.4.0
|
18
|
+
- Add verbose parameter to output learning process messages.
|
19
|
+
- Several documentation improvements.
|
20
|
+
|
1
21
|
# 0.3.0
|
2
22
|
- Add random_seed parameter for specifying seed to give to srand function.
|
3
23
|
- Several documentation improvements.
|
data/LICENSE.txt
CHANGED
data/README.md
CHANGED
@@ -1,9 +1,9 @@
|
|
1
1
|
# Numo::Libsvm
|
2
2
|
|
3
|
-
[![Build Status](https://
|
3
|
+
[![Build Status](https://github.com/yoshoku/numo-libsvm/workflows/build/badge.svg)](https://github.com/yoshoku/numo-libsvm/actions?query=workflow%3Abuild)
|
4
4
|
[![Gem Version](https://badge.fury.io/rb/numo-libsvm.svg)](https://badge.fury.io/rb/numo-libsvm)
|
5
|
-
[![BSD 3-Clause License](https://img.shields.io/badge/License-BSD%203--Clause-orange.svg)](https://github.com/yoshoku/numo-libsvm/blob/
|
6
|
-
[![Documentation](
|
5
|
+
[![BSD 3-Clause License](https://img.shields.io/badge/License-BSD%203--Clause-orange.svg)](https://github.com/yoshoku/numo-libsvm/blob/main/LICENSE.txt)
|
6
|
+
[![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://yoshoku.github.io/numo-libsvm/doc/)
|
7
7
|
|
8
8
|
Numo::Libsvm is a Ruby gem binding to the [LIBSVM](https://github.com/cjlin1/libsvm) library.
|
9
9
|
LIBSVM is one of the famous libraries that implemented Support Vector Machines,
|
@@ -16,15 +16,7 @@ Note: There are other useful Ruby gems binding to LIBSVM:
|
|
16
16
|
and [jrb-libsvm](https://github.com/andreaseger/jrb-libsvm) by Andreas Eger.
|
17
17
|
|
18
18
|
## Installation
|
19
|
-
Numo::Libsvm
|
20
|
-
|
21
|
-
macOS:
|
22
|
-
|
23
|
-
$ brew install libsvm
|
24
|
-
|
25
|
-
Ubuntu:
|
26
|
-
|
27
|
-
$ sudo apt-get install libsvm-dev
|
19
|
+
Numo::Libsvm bundles LIBSVM. There is no need to install LIBSVM in advance.
|
28
20
|
|
29
21
|
Add this line to your application's Gemfile:
|
30
22
|
|
@@ -166,7 +158,7 @@ Accuracy: 98.3 %
|
|
166
158
|
### Note
|
167
159
|
The hyperparameter of SVM is given with Ruby Hash on Numo::Libsvm.
|
168
160
|
The hash key of hyperparameter and its meaning match the struct svm_parameter of LIBSVM.
|
169
|
-
The svm_parameter is detailed in [LIBSVM README](https://github.com/cjlin1/libsvm/blob/
|
161
|
+
The svm_parameter is detailed in [LIBSVM README](https://github.com/cjlin1/libsvm/blob/main/README).
|
170
162
|
|
171
163
|
```ruby
|
172
164
|
param = {
|
@@ -190,6 +182,7 @@ param = {
|
|
190
182
|
p: 0.1, # [Float] Parameter epsilon in loss function of epsilon-SVR
|
191
183
|
shrinking: true, # [Boolean] Whether to use the shrinking heuristics
|
192
184
|
probability: false, # [Boolean] Whether to train a SVC or SVR model for probability estimates
|
185
|
+
verbose: false, # [Boolean] Whether to output learning process message
|
193
186
|
random_seed: 1 # [Integer/Nil] Random seed
|
194
187
|
}
|
195
188
|
```
|
@@ -204,4 +197,4 @@ The gem is available as open source under the terms of the [BSD-3-Clause License
|
|
204
197
|
|
205
198
|
## Code of Conduct
|
206
199
|
|
207
|
-
Everyone interacting in the Numo::Libsvm project’s codebases, issue trackers, chat rooms and mailing lists is expected to follow the [code of conduct](https://github.com/yoshoku/numo-libsvm/blob/
|
200
|
+
Everyone interacting in the Numo::Libsvm project’s codebases, issue trackers, chat rooms and mailing lists is expected to follow the [code of conduct](https://github.com/yoshoku/numo-libsvm/blob/main/CODE_OF_CONDUCT.md).
|
data/ext/numo/libsvm/converter.c
CHANGED
@@ -28,6 +28,8 @@ int* nary_to_int_vec(VALUE vec_val)
|
|
28
28
|
vec_pt = (int32_t*)na_get_pointer_for_read(vec_val);
|
29
29
|
for (i = 0; i < n_elements; i++) { vec[i] = (int)vec_pt[i]; }
|
30
30
|
|
31
|
+
RB_GC_GUARD(vec_val);
|
32
|
+
|
31
33
|
return vec;
|
32
34
|
}
|
33
35
|
|
@@ -57,6 +59,8 @@ double* nary_to_dbl_vec(VALUE vec_val)
|
|
57
59
|
vec_pt = (double*)na_get_pointer_for_read(vec_val);
|
58
60
|
memcpy(vec, vec_pt, n_elements * sizeof(double));
|
59
61
|
|
62
|
+
RB_GC_GUARD(vec_val);
|
63
|
+
|
60
64
|
return vec;
|
61
65
|
}
|
62
66
|
|
@@ -99,6 +103,8 @@ double** nary_to_dbl_mat(VALUE mat_val)
|
|
99
103
|
}
|
100
104
|
}
|
101
105
|
|
106
|
+
RB_GC_GUARD(mat_val);
|
107
|
+
|
102
108
|
return mat;
|
103
109
|
}
|
104
110
|
|
@@ -132,31 +138,67 @@ VALUE svm_nodes_to_nary(struct svm_node** const support_vecs, const int n_suppor
|
|
132
138
|
return v;
|
133
139
|
}
|
134
140
|
|
135
|
-
struct svm_node** nary_to_svm_nodes(VALUE
|
141
|
+
struct svm_node** nary_to_svm_nodes(VALUE nary_val)
|
136
142
|
{
|
137
|
-
int i, j;
|
138
|
-
int n_rows, n_cols;
|
139
|
-
narray_t*
|
140
|
-
double*
|
143
|
+
int i, j, k;
|
144
|
+
int n_rows, n_cols, n_nonzero_cols;
|
145
|
+
narray_t* nary;
|
146
|
+
double* nary_pt;
|
141
147
|
struct svm_node** support_vecs;
|
142
148
|
|
143
|
-
if (
|
149
|
+
if (nary_val == Qnil) return NULL;
|
144
150
|
|
145
|
-
GetNArray(
|
146
|
-
n_rows = (int)NA_SHAPE(
|
147
|
-
n_cols = (int)NA_SHAPE(
|
151
|
+
GetNArray(nary_val, nary);
|
152
|
+
n_rows = (int)NA_SHAPE(nary)[0];
|
153
|
+
n_cols = (int)NA_SHAPE(nary)[1];
|
148
154
|
|
149
|
-
|
155
|
+
nary_pt = (double*)na_get_pointer_for_read(nary_val);
|
150
156
|
support_vecs = ALLOC_N(struct svm_node*, n_rows);
|
151
157
|
for (i = 0; i < n_rows; i++) {
|
152
|
-
|
158
|
+
n_nonzero_cols = 0;
|
153
159
|
for (j = 0; j < n_cols; j++) {
|
154
|
-
|
155
|
-
|
160
|
+
if (nary_pt[i * n_cols + j] != 0) {
|
161
|
+
n_nonzero_cols++;
|
162
|
+
}
|
163
|
+
}
|
164
|
+
support_vecs[i] = ALLOC_N(struct svm_node, n_nonzero_cols + 1);
|
165
|
+
for (j = 0, k = 0; j < n_cols; j++) {
|
166
|
+
if (nary_pt[i * n_cols + j] != 0) {
|
167
|
+
support_vecs[i][k].index = j + 1;
|
168
|
+
support_vecs[i][k].value = nary_pt[i * n_cols + j];
|
169
|
+
k++;
|
170
|
+
}
|
156
171
|
}
|
157
|
-
support_vecs[i][
|
158
|
-
support_vecs[i][
|
172
|
+
support_vecs[i][n_nonzero_cols].index = -1;
|
173
|
+
support_vecs[i][n_nonzero_cols].value = 0.0;
|
159
174
|
}
|
160
175
|
|
176
|
+
RB_GC_GUARD(nary_val);
|
177
|
+
|
161
178
|
return support_vecs;
|
162
179
|
}
|
180
|
+
|
181
|
+
struct svm_node* dbl_vec_to_svm_node(double* const arr, int const size)
|
182
|
+
{
|
183
|
+
int i, j;
|
184
|
+
int n_nonzero_elements;
|
185
|
+
struct svm_node* node;
|
186
|
+
|
187
|
+
n_nonzero_elements = 0;
|
188
|
+
for (i = 0; i < size; i++) {
|
189
|
+
if (arr[i] != 0.0) n_nonzero_elements++;
|
190
|
+
}
|
191
|
+
|
192
|
+
node = ALLOC_N(struct svm_node, n_nonzero_elements + 1);
|
193
|
+
for (i = 0, j = 0; i < size; i++) {
|
194
|
+
if (arr[i] != 0.0) {
|
195
|
+
node[j].index = i + 1;
|
196
|
+
node[j].value = arr[i];
|
197
|
+
j++;
|
198
|
+
}
|
199
|
+
}
|
200
|
+
node[n_nonzero_elements].index = -1;
|
201
|
+
node[n_nonzero_elements].value = 0.0;
|
202
|
+
|
203
|
+
return node;
|
204
|
+
}
|
data/ext/numo/libsvm/converter.h
CHANGED
@@ -14,6 +14,7 @@ double* nary_to_dbl_vec(VALUE vec_val);
|
|
14
14
|
VALUE dbl_mat_to_nary(double** const mat, int const n_rows, int const n_cols);
|
15
15
|
double** nary_to_dbl_mat(VALUE mat_val);
|
16
16
|
VALUE svm_nodes_to_nary(struct svm_node** const support_vecs, const int n_support_vecs);
|
17
|
-
struct svm_node** nary_to_svm_nodes(VALUE
|
17
|
+
struct svm_node** nary_to_svm_nodes(VALUE nary_val);
|
18
|
+
struct svm_node* dbl_vec_to_svm_node(double* const arr, int const size);
|
18
19
|
|
19
20
|
#endif /* NUMO_LIBSVM_CONVERTER_H */
|
data/ext/numo/libsvm/extconf.rb
CHANGED
@@ -26,18 +26,14 @@ if RUBY_PLATFORM =~ /mswin|cygwin|mingw/
|
|
26
26
|
end
|
27
27
|
end
|
28
28
|
|
29
|
-
|
30
|
-
$INCFLAGS = "-I/usr/include/libsvm #{$INCFLAGS}"
|
31
|
-
end
|
29
|
+
$LDFLAGS << ' -lstdc++ '
|
32
30
|
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
puts 'libsvm not found.'
|
40
|
-
exit(1)
|
31
|
+
$srcs = Dir.glob("#{$srcdir}/*.c").map { |path| File.basename(path) }
|
32
|
+
$srcs << 'svm.cpp'
|
33
|
+
Dir.glob("#{$srcdir}/*/") do |path|
|
34
|
+
dir = File.basename(path)
|
35
|
+
$INCFLAGS << " -I$(srcdir)/#{dir}"
|
36
|
+
$VPATH << "$(srcdir)/#{dir}"
|
41
37
|
end
|
42
38
|
|
43
39
|
create_makefile('numo/libsvm/libsvmext')
|
@@ -0,0 +1,3182 @@
|
|
1
|
+
#include <math.h>
|
2
|
+
#include <stdio.h>
|
3
|
+
#include <stdlib.h>
|
4
|
+
#include <ctype.h>
|
5
|
+
#include <float.h>
|
6
|
+
#include <string.h>
|
7
|
+
#include <stdarg.h>
|
8
|
+
#include <limits.h>
|
9
|
+
#include <locale.h>
|
10
|
+
#include "svm.h"
|
11
|
+
int libsvm_version = LIBSVM_VERSION;
|
12
|
+
typedef float Qfloat;
|
13
|
+
typedef signed char schar;
|
14
|
+
#ifndef min
|
15
|
+
template <class T> static inline T min(T x,T y) { return (x<y)?x:y; }
|
16
|
+
#endif
|
17
|
+
#ifndef max
|
18
|
+
template <class T> static inline T max(T x,T y) { return (x>y)?x:y; }
|
19
|
+
#endif
|
20
|
+
template <class T> static inline void swap(T& x, T& y) { T t=x; x=y; y=t; }
|
21
|
+
template <class S, class T> static inline void clone(T*& dst, S* src, int n)
|
22
|
+
{
|
23
|
+
dst = new T[n];
|
24
|
+
memcpy((void *)dst,(void *)src,sizeof(T)*n);
|
25
|
+
}
|
26
|
+
static inline double powi(double base, int times)
|
27
|
+
{
|
28
|
+
double tmp = base, ret = 1.0;
|
29
|
+
|
30
|
+
for(int t=times; t>0; t/=2)
|
31
|
+
{
|
32
|
+
if(t%2==1) ret*=tmp;
|
33
|
+
tmp = tmp * tmp;
|
34
|
+
}
|
35
|
+
return ret;
|
36
|
+
}
|
37
|
+
#define INF HUGE_VAL
|
38
|
+
#define TAU 1e-12
|
39
|
+
#define Malloc(type,n) (type *)malloc((n)*sizeof(type))
|
40
|
+
|
41
|
+
static void print_string_stdout(const char *s)
|
42
|
+
{
|
43
|
+
fputs(s,stdout);
|
44
|
+
fflush(stdout);
|
45
|
+
}
|
46
|
+
static void (*svm_print_string) (const char *) = &print_string_stdout;
|
47
|
+
#if 1
|
48
|
+
static void info(const char *fmt,...)
|
49
|
+
{
|
50
|
+
char buf[BUFSIZ];
|
51
|
+
va_list ap;
|
52
|
+
va_start(ap,fmt);
|
53
|
+
vsprintf(buf,fmt,ap);
|
54
|
+
va_end(ap);
|
55
|
+
(*svm_print_string)(buf);
|
56
|
+
}
|
57
|
+
#else
|
58
|
+
static void info(const char *fmt,...) {}
|
59
|
+
#endif
|
60
|
+
|
61
|
+
//
|
62
|
+
// Kernel Cache
|
63
|
+
//
|
64
|
+
// l is the number of total data items
|
65
|
+
// size is the cache size limit in bytes
|
66
|
+
//
|
67
|
+
class Cache
|
68
|
+
{
|
69
|
+
public:
|
70
|
+
Cache(int l,long int size);
|
71
|
+
~Cache();
|
72
|
+
|
73
|
+
// request data [0,len)
|
74
|
+
// return some position p where [p,len) need to be filled
|
75
|
+
// (p >= len if nothing needs to be filled)
|
76
|
+
int get_data(const int index, Qfloat **data, int len);
|
77
|
+
void swap_index(int i, int j);
|
78
|
+
private:
|
79
|
+
int l;
|
80
|
+
long int size;
|
81
|
+
struct head_t
|
82
|
+
{
|
83
|
+
head_t *prev, *next; // a circular list
|
84
|
+
Qfloat *data;
|
85
|
+
int len; // data[0,len) is cached in this entry
|
86
|
+
};
|
87
|
+
|
88
|
+
head_t *head;
|
89
|
+
head_t lru_head;
|
90
|
+
void lru_delete(head_t *h);
|
91
|
+
void lru_insert(head_t *h);
|
92
|
+
};
|
93
|
+
|
94
|
+
Cache::Cache(int l_,long int size_):l(l_),size(size_)
|
95
|
+
{
|
96
|
+
head = (head_t *)calloc(l,sizeof(head_t)); // initialized to 0
|
97
|
+
size /= sizeof(Qfloat);
|
98
|
+
size -= l * sizeof(head_t) / sizeof(Qfloat);
|
99
|
+
size = max(size, 2 * (long int) l); // cache must be large enough for two columns
|
100
|
+
lru_head.next = lru_head.prev = &lru_head;
|
101
|
+
}
|
102
|
+
|
103
|
+
Cache::~Cache()
|
104
|
+
{
|
105
|
+
for(head_t *h = lru_head.next; h != &lru_head; h=h->next)
|
106
|
+
free(h->data);
|
107
|
+
free(head);
|
108
|
+
}
|
109
|
+
|
110
|
+
void Cache::lru_delete(head_t *h)
|
111
|
+
{
|
112
|
+
// delete from current location
|
113
|
+
h->prev->next = h->next;
|
114
|
+
h->next->prev = h->prev;
|
115
|
+
}
|
116
|
+
|
117
|
+
void Cache::lru_insert(head_t *h)
|
118
|
+
{
|
119
|
+
// insert to last position
|
120
|
+
h->next = &lru_head;
|
121
|
+
h->prev = lru_head.prev;
|
122
|
+
h->prev->next = h;
|
123
|
+
h->next->prev = h;
|
124
|
+
}
|
125
|
+
|
126
|
+
int Cache::get_data(const int index, Qfloat **data, int len)
|
127
|
+
{
|
128
|
+
head_t *h = &head[index];
|
129
|
+
if(h->len) lru_delete(h);
|
130
|
+
int more = len - h->len;
|
131
|
+
|
132
|
+
if(more > 0)
|
133
|
+
{
|
134
|
+
// free old space
|
135
|
+
while(size < more)
|
136
|
+
{
|
137
|
+
head_t *old = lru_head.next;
|
138
|
+
lru_delete(old);
|
139
|
+
free(old->data);
|
140
|
+
size += old->len;
|
141
|
+
old->data = 0;
|
142
|
+
old->len = 0;
|
143
|
+
}
|
144
|
+
|
145
|
+
// allocate new space
|
146
|
+
h->data = (Qfloat *)realloc(h->data,sizeof(Qfloat)*len);
|
147
|
+
size -= more;
|
148
|
+
swap(h->len,len);
|
149
|
+
}
|
150
|
+
|
151
|
+
lru_insert(h);
|
152
|
+
*data = h->data;
|
153
|
+
return len;
|
154
|
+
}
|
155
|
+
|
156
|
+
void Cache::swap_index(int i, int j)
|
157
|
+
{
|
158
|
+
if(i==j) return;
|
159
|
+
|
160
|
+
if(head[i].len) lru_delete(&head[i]);
|
161
|
+
if(head[j].len) lru_delete(&head[j]);
|
162
|
+
swap(head[i].data,head[j].data);
|
163
|
+
swap(head[i].len,head[j].len);
|
164
|
+
if(head[i].len) lru_insert(&head[i]);
|
165
|
+
if(head[j].len) lru_insert(&head[j]);
|
166
|
+
|
167
|
+
if(i>j) swap(i,j);
|
168
|
+
for(head_t *h = lru_head.next; h!=&lru_head; h=h->next)
|
169
|
+
{
|
170
|
+
if(h->len > i)
|
171
|
+
{
|
172
|
+
if(h->len > j)
|
173
|
+
swap(h->data[i],h->data[j]);
|
174
|
+
else
|
175
|
+
{
|
176
|
+
// give up
|
177
|
+
lru_delete(h);
|
178
|
+
free(h->data);
|
179
|
+
size += h->len;
|
180
|
+
h->data = 0;
|
181
|
+
h->len = 0;
|
182
|
+
}
|
183
|
+
}
|
184
|
+
}
|
185
|
+
}
|
186
|
+
|
187
|
+
//
|
188
|
+
// Kernel evaluation
|
189
|
+
//
|
190
|
+
// the static method k_function is for doing single kernel evaluation
|
191
|
+
// the constructor of Kernel prepares to calculate the l*l kernel matrix
|
192
|
+
// the member function get_Q is for getting one column from the Q Matrix
|
193
|
+
//
|
194
|
+
class QMatrix {
|
195
|
+
public:
|
196
|
+
virtual Qfloat *get_Q(int column, int len) const = 0;
|
197
|
+
virtual double *get_QD() const = 0;
|
198
|
+
virtual void swap_index(int i, int j) const = 0;
|
199
|
+
virtual ~QMatrix() {}
|
200
|
+
};
|
201
|
+
|
202
|
+
class Kernel: public QMatrix {
|
203
|
+
public:
|
204
|
+
Kernel(int l, svm_node * const * x, const svm_parameter& param);
|
205
|
+
virtual ~Kernel();
|
206
|
+
|
207
|
+
static double k_function(const svm_node *x, const svm_node *y,
|
208
|
+
const svm_parameter& param);
|
209
|
+
virtual Qfloat *get_Q(int column, int len) const = 0;
|
210
|
+
virtual double *get_QD() const = 0;
|
211
|
+
virtual void swap_index(int i, int j) const // no so const...
|
212
|
+
{
|
213
|
+
swap(x[i],x[j]);
|
214
|
+
if(x_square) swap(x_square[i],x_square[j]);
|
215
|
+
}
|
216
|
+
protected:
|
217
|
+
|
218
|
+
double (Kernel::*kernel_function)(int i, int j) const;
|
219
|
+
|
220
|
+
private:
|
221
|
+
const svm_node **x;
|
222
|
+
double *x_square;
|
223
|
+
|
224
|
+
// svm_parameter
|
225
|
+
const int kernel_type;
|
226
|
+
const int degree;
|
227
|
+
const double gamma;
|
228
|
+
const double coef0;
|
229
|
+
|
230
|
+
static double dot(const svm_node *px, const svm_node *py);
|
231
|
+
double kernel_linear(int i, int j) const
|
232
|
+
{
|
233
|
+
return dot(x[i],x[j]);
|
234
|
+
}
|
235
|
+
double kernel_poly(int i, int j) const
|
236
|
+
{
|
237
|
+
return powi(gamma*dot(x[i],x[j])+coef0,degree);
|
238
|
+
}
|
239
|
+
double kernel_rbf(int i, int j) const
|
240
|
+
{
|
241
|
+
return exp(-gamma*(x_square[i]+x_square[j]-2*dot(x[i],x[j])));
|
242
|
+
}
|
243
|
+
double kernel_sigmoid(int i, int j) const
|
244
|
+
{
|
245
|
+
return tanh(gamma*dot(x[i],x[j])+coef0);
|
246
|
+
}
|
247
|
+
double kernel_precomputed(int i, int j) const
|
248
|
+
{
|
249
|
+
return x[i][(int)(x[j][0].value)].value;
|
250
|
+
}
|
251
|
+
};
|
252
|
+
|
253
|
+
Kernel::Kernel(int l, svm_node * const * x_, const svm_parameter& param)
|
254
|
+
:kernel_type(param.kernel_type), degree(param.degree),
|
255
|
+
gamma(param.gamma), coef0(param.coef0)
|
256
|
+
{
|
257
|
+
switch(kernel_type)
|
258
|
+
{
|
259
|
+
case LINEAR:
|
260
|
+
kernel_function = &Kernel::kernel_linear;
|
261
|
+
break;
|
262
|
+
case POLY:
|
263
|
+
kernel_function = &Kernel::kernel_poly;
|
264
|
+
break;
|
265
|
+
case RBF:
|
266
|
+
kernel_function = &Kernel::kernel_rbf;
|
267
|
+
break;
|
268
|
+
case SIGMOID:
|
269
|
+
kernel_function = &Kernel::kernel_sigmoid;
|
270
|
+
break;
|
271
|
+
case PRECOMPUTED:
|
272
|
+
kernel_function = &Kernel::kernel_precomputed;
|
273
|
+
break;
|
274
|
+
}
|
275
|
+
|
276
|
+
clone(x,x_,l);
|
277
|
+
|
278
|
+
if(kernel_type == RBF)
|
279
|
+
{
|
280
|
+
x_square = new double[l];
|
281
|
+
for(int i=0;i<l;i++)
|
282
|
+
x_square[i] = dot(x[i],x[i]);
|
283
|
+
}
|
284
|
+
else
|
285
|
+
x_square = 0;
|
286
|
+
}
|
287
|
+
|
288
|
+
Kernel::~Kernel()
|
289
|
+
{
|
290
|
+
delete[] x;
|
291
|
+
delete[] x_square;
|
292
|
+
}
|
293
|
+
|
294
|
+
double Kernel::dot(const svm_node *px, const svm_node *py)
|
295
|
+
{
|
296
|
+
double sum = 0;
|
297
|
+
while(px->index != -1 && py->index != -1)
|
298
|
+
{
|
299
|
+
if(px->index == py->index)
|
300
|
+
{
|
301
|
+
sum += px->value * py->value;
|
302
|
+
++px;
|
303
|
+
++py;
|
304
|
+
}
|
305
|
+
else
|
306
|
+
{
|
307
|
+
if(px->index > py->index)
|
308
|
+
++py;
|
309
|
+
else
|
310
|
+
++px;
|
311
|
+
}
|
312
|
+
}
|
313
|
+
return sum;
|
314
|
+
}
|
315
|
+
|
316
|
+
double Kernel::k_function(const svm_node *x, const svm_node *y,
|
317
|
+
const svm_parameter& param)
|
318
|
+
{
|
319
|
+
switch(param.kernel_type)
|
320
|
+
{
|
321
|
+
case LINEAR:
|
322
|
+
return dot(x,y);
|
323
|
+
case POLY:
|
324
|
+
return powi(param.gamma*dot(x,y)+param.coef0,param.degree);
|
325
|
+
case RBF:
|
326
|
+
{
|
327
|
+
double sum = 0;
|
328
|
+
while(x->index != -1 && y->index !=-1)
|
329
|
+
{
|
330
|
+
if(x->index == y->index)
|
331
|
+
{
|
332
|
+
double d = x->value - y->value;
|
333
|
+
sum += d*d;
|
334
|
+
++x;
|
335
|
+
++y;
|
336
|
+
}
|
337
|
+
else
|
338
|
+
{
|
339
|
+
if(x->index > y->index)
|
340
|
+
{
|
341
|
+
sum += y->value * y->value;
|
342
|
+
++y;
|
343
|
+
}
|
344
|
+
else
|
345
|
+
{
|
346
|
+
sum += x->value * x->value;
|
347
|
+
++x;
|
348
|
+
}
|
349
|
+
}
|
350
|
+
}
|
351
|
+
|
352
|
+
while(x->index != -1)
|
353
|
+
{
|
354
|
+
sum += x->value * x->value;
|
355
|
+
++x;
|
356
|
+
}
|
357
|
+
|
358
|
+
while(y->index != -1)
|
359
|
+
{
|
360
|
+
sum += y->value * y->value;
|
361
|
+
++y;
|
362
|
+
}
|
363
|
+
|
364
|
+
return exp(-param.gamma*sum);
|
365
|
+
}
|
366
|
+
case SIGMOID:
|
367
|
+
return tanh(param.gamma*dot(x,y)+param.coef0);
|
368
|
+
case PRECOMPUTED: //x: test (validation), y: SV
|
369
|
+
return x[(int)(y->value)].value;
|
370
|
+
default:
|
371
|
+
return 0; // Unreachable
|
372
|
+
}
|
373
|
+
}
|
374
|
+
|
375
|
+
// An SMO algorithm in Fan et al., JMLR 6(2005), p. 1889--1918
|
376
|
+
// Solves:
|
377
|
+
//
|
378
|
+
// min 0.5(\alpha^T Q \alpha) + p^T \alpha
|
379
|
+
//
|
380
|
+
// y^T \alpha = \delta
|
381
|
+
// y_i = +1 or -1
|
382
|
+
// 0 <= alpha_i <= Cp for y_i = 1
|
383
|
+
// 0 <= alpha_i <= Cn for y_i = -1
|
384
|
+
//
|
385
|
+
// Given:
|
386
|
+
//
|
387
|
+
// Q, p, y, Cp, Cn, and an initial feasible point \alpha
|
388
|
+
// l is the size of vectors and matrices
|
389
|
+
// eps is the stopping tolerance
|
390
|
+
//
|
391
|
+
// solution will be put in \alpha, objective value will be put in obj
|
392
|
+
//
|
393
|
+
class Solver {
|
394
|
+
public:
|
395
|
+
Solver() {};
|
396
|
+
virtual ~Solver() {};
|
397
|
+
|
398
|
+
struct SolutionInfo {
|
399
|
+
double obj;
|
400
|
+
double rho;
|
401
|
+
double upper_bound_p;
|
402
|
+
double upper_bound_n;
|
403
|
+
double r; // for Solver_NU
|
404
|
+
};
|
405
|
+
|
406
|
+
void Solve(int l, const QMatrix& Q, const double *p_, const schar *y_,
|
407
|
+
double *alpha_, double Cp, double Cn, double eps,
|
408
|
+
SolutionInfo* si, int shrinking);
|
409
|
+
protected:
|
410
|
+
int active_size;
|
411
|
+
schar *y;
|
412
|
+
double *G; // gradient of objective function
|
413
|
+
enum { LOWER_BOUND, UPPER_BOUND, FREE };
|
414
|
+
char *alpha_status; // LOWER_BOUND, UPPER_BOUND, FREE
|
415
|
+
double *alpha;
|
416
|
+
const QMatrix *Q;
|
417
|
+
const double *QD;
|
418
|
+
double eps;
|
419
|
+
double Cp,Cn;
|
420
|
+
double *p;
|
421
|
+
int *active_set;
|
422
|
+
double *G_bar; // gradient, if we treat free variables as 0
|
423
|
+
int l;
|
424
|
+
bool unshrink; // XXX
|
425
|
+
|
426
|
+
double get_C(int i)
|
427
|
+
{
|
428
|
+
return (y[i] > 0)? Cp : Cn;
|
429
|
+
}
|
430
|
+
void update_alpha_status(int i)
|
431
|
+
{
|
432
|
+
if(alpha[i] >= get_C(i))
|
433
|
+
alpha_status[i] = UPPER_BOUND;
|
434
|
+
else if(alpha[i] <= 0)
|
435
|
+
alpha_status[i] = LOWER_BOUND;
|
436
|
+
else alpha_status[i] = FREE;
|
437
|
+
}
|
438
|
+
bool is_upper_bound(int i) { return alpha_status[i] == UPPER_BOUND; }
|
439
|
+
bool is_lower_bound(int i) { return alpha_status[i] == LOWER_BOUND; }
|
440
|
+
bool is_free(int i) { return alpha_status[i] == FREE; }
|
441
|
+
void swap_index(int i, int j);
|
442
|
+
void reconstruct_gradient();
|
443
|
+
virtual int select_working_set(int &i, int &j);
|
444
|
+
virtual double calculate_rho();
|
445
|
+
virtual void do_shrinking();
|
446
|
+
private:
|
447
|
+
bool be_shrunk(int i, double Gmax1, double Gmax2);
|
448
|
+
};
|
449
|
+
|
450
|
+
void Solver::swap_index(int i, int j)
|
451
|
+
{
|
452
|
+
Q->swap_index(i,j);
|
453
|
+
swap(y[i],y[j]);
|
454
|
+
swap(G[i],G[j]);
|
455
|
+
swap(alpha_status[i],alpha_status[j]);
|
456
|
+
swap(alpha[i],alpha[j]);
|
457
|
+
swap(p[i],p[j]);
|
458
|
+
swap(active_set[i],active_set[j]);
|
459
|
+
swap(G_bar[i],G_bar[j]);
|
460
|
+
}
|
461
|
+
|
462
|
+
void Solver::reconstruct_gradient()
|
463
|
+
{
|
464
|
+
// reconstruct inactive elements of G from G_bar and free variables
|
465
|
+
|
466
|
+
if(active_size == l) return;
|
467
|
+
|
468
|
+
int i,j;
|
469
|
+
int nr_free = 0;
|
470
|
+
|
471
|
+
for(j=active_size;j<l;j++)
|
472
|
+
G[j] = G_bar[j] + p[j];
|
473
|
+
|
474
|
+
for(j=0;j<active_size;j++)
|
475
|
+
if(is_free(j))
|
476
|
+
nr_free++;
|
477
|
+
|
478
|
+
if(2*nr_free < active_size)
|
479
|
+
info("\nWARNING: using -h 0 may be faster\n");
|
480
|
+
|
481
|
+
if (nr_free*l > 2*active_size*(l-active_size))
|
482
|
+
{
|
483
|
+
for(i=active_size;i<l;i++)
|
484
|
+
{
|
485
|
+
const Qfloat *Q_i = Q->get_Q(i,active_size);
|
486
|
+
for(j=0;j<active_size;j++)
|
487
|
+
if(is_free(j))
|
488
|
+
G[i] += alpha[j] * Q_i[j];
|
489
|
+
}
|
490
|
+
}
|
491
|
+
else
|
492
|
+
{
|
493
|
+
for(i=0;i<active_size;i++)
|
494
|
+
if(is_free(i))
|
495
|
+
{
|
496
|
+
const Qfloat *Q_i = Q->get_Q(i,l);
|
497
|
+
double alpha_i = alpha[i];
|
498
|
+
for(j=active_size;j<l;j++)
|
499
|
+
G[j] += alpha_i * Q_i[j];
|
500
|
+
}
|
501
|
+
}
|
502
|
+
}
|
503
|
+
|
504
|
+
void Solver::Solve(int l, const QMatrix& Q, const double *p_, const schar *y_,
|
505
|
+
double *alpha_, double Cp, double Cn, double eps,
|
506
|
+
SolutionInfo* si, int shrinking)
|
507
|
+
{
|
508
|
+
this->l = l;
|
509
|
+
this->Q = &Q;
|
510
|
+
QD=Q.get_QD();
|
511
|
+
clone(p, p_,l);
|
512
|
+
clone(y, y_,l);
|
513
|
+
clone(alpha,alpha_,l);
|
514
|
+
this->Cp = Cp;
|
515
|
+
this->Cn = Cn;
|
516
|
+
this->eps = eps;
|
517
|
+
unshrink = false;
|
518
|
+
|
519
|
+
// initialize alpha_status
|
520
|
+
{
|
521
|
+
alpha_status = new char[l];
|
522
|
+
for(int i=0;i<l;i++)
|
523
|
+
update_alpha_status(i);
|
524
|
+
}
|
525
|
+
|
526
|
+
// initialize active set (for shrinking)
|
527
|
+
{
|
528
|
+
active_set = new int[l];
|
529
|
+
for(int i=0;i<l;i++)
|
530
|
+
active_set[i] = i;
|
531
|
+
active_size = l;
|
532
|
+
}
|
533
|
+
|
534
|
+
// initialize gradient
|
535
|
+
{
|
536
|
+
G = new double[l];
|
537
|
+
G_bar = new double[l];
|
538
|
+
int i;
|
539
|
+
for(i=0;i<l;i++)
|
540
|
+
{
|
541
|
+
G[i] = p[i];
|
542
|
+
G_bar[i] = 0;
|
543
|
+
}
|
544
|
+
for(i=0;i<l;i++)
|
545
|
+
if(!is_lower_bound(i))
|
546
|
+
{
|
547
|
+
const Qfloat *Q_i = Q.get_Q(i,l);
|
548
|
+
double alpha_i = alpha[i];
|
549
|
+
int j;
|
550
|
+
for(j=0;j<l;j++)
|
551
|
+
G[j] += alpha_i*Q_i[j];
|
552
|
+
if(is_upper_bound(i))
|
553
|
+
for(j=0;j<l;j++)
|
554
|
+
G_bar[j] += get_C(i) * Q_i[j];
|
555
|
+
}
|
556
|
+
}
|
557
|
+
|
558
|
+
// optimization step
|
559
|
+
|
560
|
+
int iter = 0;
|
561
|
+
int max_iter = max(10000000, l>INT_MAX/100 ? INT_MAX : 100*l);
|
562
|
+
int counter = min(l,1000)+1;
|
563
|
+
|
564
|
+
while(iter < max_iter)
|
565
|
+
{
|
566
|
+
// show progress and do shrinking
|
567
|
+
|
568
|
+
if(--counter == 0)
|
569
|
+
{
|
570
|
+
counter = min(l,1000);
|
571
|
+
if(shrinking) do_shrinking();
|
572
|
+
info(".");
|
573
|
+
}
|
574
|
+
|
575
|
+
int i,j;
|
576
|
+
if(select_working_set(i,j)!=0)
|
577
|
+
{
|
578
|
+
// reconstruct the whole gradient
|
579
|
+
reconstruct_gradient();
|
580
|
+
// reset active set size and check
|
581
|
+
active_size = l;
|
582
|
+
info("*");
|
583
|
+
if(select_working_set(i,j)!=0)
|
584
|
+
break;
|
585
|
+
else
|
586
|
+
counter = 1; // do shrinking next iteration
|
587
|
+
}
|
588
|
+
|
589
|
+
++iter;
|
590
|
+
|
591
|
+
// update alpha[i] and alpha[j], handle bounds carefully
|
592
|
+
|
593
|
+
const Qfloat *Q_i = Q.get_Q(i,active_size);
|
594
|
+
const Qfloat *Q_j = Q.get_Q(j,active_size);
|
595
|
+
|
596
|
+
double C_i = get_C(i);
|
597
|
+
double C_j = get_C(j);
|
598
|
+
|
599
|
+
double old_alpha_i = alpha[i];
|
600
|
+
double old_alpha_j = alpha[j];
|
601
|
+
|
602
|
+
if(y[i]!=y[j])
|
603
|
+
{
|
604
|
+
double quad_coef = QD[i]+QD[j]+2*Q_i[j];
|
605
|
+
if (quad_coef <= 0)
|
606
|
+
quad_coef = TAU;
|
607
|
+
double delta = (-G[i]-G[j])/quad_coef;
|
608
|
+
double diff = alpha[i] - alpha[j];
|
609
|
+
alpha[i] += delta;
|
610
|
+
alpha[j] += delta;
|
611
|
+
|
612
|
+
if(diff > 0)
|
613
|
+
{
|
614
|
+
if(alpha[j] < 0)
|
615
|
+
{
|
616
|
+
alpha[j] = 0;
|
617
|
+
alpha[i] = diff;
|
618
|
+
}
|
619
|
+
}
|
620
|
+
else
|
621
|
+
{
|
622
|
+
if(alpha[i] < 0)
|
623
|
+
{
|
624
|
+
alpha[i] = 0;
|
625
|
+
alpha[j] = -diff;
|
626
|
+
}
|
627
|
+
}
|
628
|
+
if(diff > C_i - C_j)
|
629
|
+
{
|
630
|
+
if(alpha[i] > C_i)
|
631
|
+
{
|
632
|
+
alpha[i] = C_i;
|
633
|
+
alpha[j] = C_i - diff;
|
634
|
+
}
|
635
|
+
}
|
636
|
+
else
|
637
|
+
{
|
638
|
+
if(alpha[j] > C_j)
|
639
|
+
{
|
640
|
+
alpha[j] = C_j;
|
641
|
+
alpha[i] = C_j + diff;
|
642
|
+
}
|
643
|
+
}
|
644
|
+
}
|
645
|
+
else
|
646
|
+
{
|
647
|
+
double quad_coef = QD[i]+QD[j]-2*Q_i[j];
|
648
|
+
if (quad_coef <= 0)
|
649
|
+
quad_coef = TAU;
|
650
|
+
double delta = (G[i]-G[j])/quad_coef;
|
651
|
+
double sum = alpha[i] + alpha[j];
|
652
|
+
alpha[i] -= delta;
|
653
|
+
alpha[j] += delta;
|
654
|
+
|
655
|
+
if(sum > C_i)
|
656
|
+
{
|
657
|
+
if(alpha[i] > C_i)
|
658
|
+
{
|
659
|
+
alpha[i] = C_i;
|
660
|
+
alpha[j] = sum - C_i;
|
661
|
+
}
|
662
|
+
}
|
663
|
+
else
|
664
|
+
{
|
665
|
+
if(alpha[j] < 0)
|
666
|
+
{
|
667
|
+
alpha[j] = 0;
|
668
|
+
alpha[i] = sum;
|
669
|
+
}
|
670
|
+
}
|
671
|
+
if(sum > C_j)
|
672
|
+
{
|
673
|
+
if(alpha[j] > C_j)
|
674
|
+
{
|
675
|
+
alpha[j] = C_j;
|
676
|
+
alpha[i] = sum - C_j;
|
677
|
+
}
|
678
|
+
}
|
679
|
+
else
|
680
|
+
{
|
681
|
+
if(alpha[i] < 0)
|
682
|
+
{
|
683
|
+
alpha[i] = 0;
|
684
|
+
alpha[j] = sum;
|
685
|
+
}
|
686
|
+
}
|
687
|
+
}
|
688
|
+
|
689
|
+
// update G
|
690
|
+
|
691
|
+
double delta_alpha_i = alpha[i] - old_alpha_i;
|
692
|
+
double delta_alpha_j = alpha[j] - old_alpha_j;
|
693
|
+
|
694
|
+
for(int k=0;k<active_size;k++)
|
695
|
+
{
|
696
|
+
G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j;
|
697
|
+
}
|
698
|
+
|
699
|
+
// update alpha_status and G_bar
|
700
|
+
|
701
|
+
{
|
702
|
+
bool ui = is_upper_bound(i);
|
703
|
+
bool uj = is_upper_bound(j);
|
704
|
+
update_alpha_status(i);
|
705
|
+
update_alpha_status(j);
|
706
|
+
int k;
|
707
|
+
if(ui != is_upper_bound(i))
|
708
|
+
{
|
709
|
+
Q_i = Q.get_Q(i,l);
|
710
|
+
if(ui)
|
711
|
+
for(k=0;k<l;k++)
|
712
|
+
G_bar[k] -= C_i * Q_i[k];
|
713
|
+
else
|
714
|
+
for(k=0;k<l;k++)
|
715
|
+
G_bar[k] += C_i * Q_i[k];
|
716
|
+
}
|
717
|
+
|
718
|
+
if(uj != is_upper_bound(j))
|
719
|
+
{
|
720
|
+
Q_j = Q.get_Q(j,l);
|
721
|
+
if(uj)
|
722
|
+
for(k=0;k<l;k++)
|
723
|
+
G_bar[k] -= C_j * Q_j[k];
|
724
|
+
else
|
725
|
+
for(k=0;k<l;k++)
|
726
|
+
G_bar[k] += C_j * Q_j[k];
|
727
|
+
}
|
728
|
+
}
|
729
|
+
}
|
730
|
+
|
731
|
+
if(iter >= max_iter)
|
732
|
+
{
|
733
|
+
if(active_size < l)
|
734
|
+
{
|
735
|
+
// reconstruct the whole gradient to calculate objective value
|
736
|
+
reconstruct_gradient();
|
737
|
+
active_size = l;
|
738
|
+
info("*");
|
739
|
+
}
|
740
|
+
fprintf(stderr,"\nWARNING: reaching max number of iterations\n");
|
741
|
+
}
|
742
|
+
|
743
|
+
// calculate rho
|
744
|
+
|
745
|
+
si->rho = calculate_rho();
|
746
|
+
|
747
|
+
// calculate objective value
|
748
|
+
{
|
749
|
+
double v = 0;
|
750
|
+
int i;
|
751
|
+
for(i=0;i<l;i++)
|
752
|
+
v += alpha[i] * (G[i] + p[i]);
|
753
|
+
|
754
|
+
si->obj = v/2;
|
755
|
+
}
|
756
|
+
|
757
|
+
// put back the solution
|
758
|
+
{
|
759
|
+
for(int i=0;i<l;i++)
|
760
|
+
alpha_[active_set[i]] = alpha[i];
|
761
|
+
}
|
762
|
+
|
763
|
+
// juggle everything back
|
764
|
+
/*{
|
765
|
+
for(int i=0;i<l;i++)
|
766
|
+
while(active_set[i] != i)
|
767
|
+
swap_index(i,active_set[i]);
|
768
|
+
// or Q.swap_index(i,active_set[i]);
|
769
|
+
}*/
|
770
|
+
|
771
|
+
si->upper_bound_p = Cp;
|
772
|
+
si->upper_bound_n = Cn;
|
773
|
+
|
774
|
+
info("\noptimization finished, #iter = %d\n",iter);
|
775
|
+
|
776
|
+
delete[] p;
|
777
|
+
delete[] y;
|
778
|
+
delete[] alpha;
|
779
|
+
delete[] alpha_status;
|
780
|
+
delete[] active_set;
|
781
|
+
delete[] G;
|
782
|
+
delete[] G_bar;
|
783
|
+
}
|
784
|
+
|
785
|
+
// return 1 if already optimal, return 0 otherwise
|
786
|
+
int Solver::select_working_set(int &out_i, int &out_j)
|
787
|
+
{
|
788
|
+
// return i,j such that
|
789
|
+
// i: maximizes -y_i * grad(f)_i, i in I_up(\alpha)
|
790
|
+
// j: minimizes the decrease of obj value
|
791
|
+
// (if quadratic coefficeint <= 0, replace it with tau)
|
792
|
+
// -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha)
|
793
|
+
|
794
|
+
double Gmax = -INF;
|
795
|
+
double Gmax2 = -INF;
|
796
|
+
int Gmax_idx = -1;
|
797
|
+
int Gmin_idx = -1;
|
798
|
+
double obj_diff_min = INF;
|
799
|
+
|
800
|
+
for(int t=0;t<active_size;t++)
|
801
|
+
if(y[t]==+1)
|
802
|
+
{
|
803
|
+
if(!is_upper_bound(t))
|
804
|
+
if(-G[t] >= Gmax)
|
805
|
+
{
|
806
|
+
Gmax = -G[t];
|
807
|
+
Gmax_idx = t;
|
808
|
+
}
|
809
|
+
}
|
810
|
+
else
|
811
|
+
{
|
812
|
+
if(!is_lower_bound(t))
|
813
|
+
if(G[t] >= Gmax)
|
814
|
+
{
|
815
|
+
Gmax = G[t];
|
816
|
+
Gmax_idx = t;
|
817
|
+
}
|
818
|
+
}
|
819
|
+
|
820
|
+
int i = Gmax_idx;
|
821
|
+
const Qfloat *Q_i = NULL;
|
822
|
+
if(i != -1) // NULL Q_i not accessed: Gmax=-INF if i=-1
|
823
|
+
Q_i = Q->get_Q(i,active_size);
|
824
|
+
|
825
|
+
for(int j=0;j<active_size;j++)
|
826
|
+
{
|
827
|
+
if(y[j]==+1)
|
828
|
+
{
|
829
|
+
if (!is_lower_bound(j))
|
830
|
+
{
|
831
|
+
double grad_diff=Gmax+G[j];
|
832
|
+
if (G[j] >= Gmax2)
|
833
|
+
Gmax2 = G[j];
|
834
|
+
if (grad_diff > 0)
|
835
|
+
{
|
836
|
+
double obj_diff;
|
837
|
+
double quad_coef = QD[i]+QD[j]-2.0*y[i]*Q_i[j];
|
838
|
+
if (quad_coef > 0)
|
839
|
+
obj_diff = -(grad_diff*grad_diff)/quad_coef;
|
840
|
+
else
|
841
|
+
obj_diff = -(grad_diff*grad_diff)/TAU;
|
842
|
+
|
843
|
+
if (obj_diff <= obj_diff_min)
|
844
|
+
{
|
845
|
+
Gmin_idx=j;
|
846
|
+
obj_diff_min = obj_diff;
|
847
|
+
}
|
848
|
+
}
|
849
|
+
}
|
850
|
+
}
|
851
|
+
else
|
852
|
+
{
|
853
|
+
if (!is_upper_bound(j))
|
854
|
+
{
|
855
|
+
double grad_diff= Gmax-G[j];
|
856
|
+
if (-G[j] >= Gmax2)
|
857
|
+
Gmax2 = -G[j];
|
858
|
+
if (grad_diff > 0)
|
859
|
+
{
|
860
|
+
double obj_diff;
|
861
|
+
double quad_coef = QD[i]+QD[j]+2.0*y[i]*Q_i[j];
|
862
|
+
if (quad_coef > 0)
|
863
|
+
obj_diff = -(grad_diff*grad_diff)/quad_coef;
|
864
|
+
else
|
865
|
+
obj_diff = -(grad_diff*grad_diff)/TAU;
|
866
|
+
|
867
|
+
if (obj_diff <= obj_diff_min)
|
868
|
+
{
|
869
|
+
Gmin_idx=j;
|
870
|
+
obj_diff_min = obj_diff;
|
871
|
+
}
|
872
|
+
}
|
873
|
+
}
|
874
|
+
}
|
875
|
+
}
|
876
|
+
|
877
|
+
if(Gmax+Gmax2 < eps || Gmin_idx == -1)
|
878
|
+
return 1;
|
879
|
+
|
880
|
+
out_i = Gmax_idx;
|
881
|
+
out_j = Gmin_idx;
|
882
|
+
return 0;
|
883
|
+
}
|
884
|
+
|
885
|
+
bool Solver::be_shrunk(int i, double Gmax1, double Gmax2)
|
886
|
+
{
|
887
|
+
if(is_upper_bound(i))
|
888
|
+
{
|
889
|
+
if(y[i]==+1)
|
890
|
+
return(-G[i] > Gmax1);
|
891
|
+
else
|
892
|
+
return(-G[i] > Gmax2);
|
893
|
+
}
|
894
|
+
else if(is_lower_bound(i))
|
895
|
+
{
|
896
|
+
if(y[i]==+1)
|
897
|
+
return(G[i] > Gmax2);
|
898
|
+
else
|
899
|
+
return(G[i] > Gmax1);
|
900
|
+
}
|
901
|
+
else
|
902
|
+
return(false);
|
903
|
+
}
|
904
|
+
|
905
|
+
void Solver::do_shrinking()
|
906
|
+
{
|
907
|
+
int i;
|
908
|
+
double Gmax1 = -INF; // max { -y_i * grad(f)_i | i in I_up(\alpha) }
|
909
|
+
double Gmax2 = -INF; // max { y_i * grad(f)_i | i in I_low(\alpha) }
|
910
|
+
|
911
|
+
// find maximal violating pair first
|
912
|
+
for(i=0;i<active_size;i++)
|
913
|
+
{
|
914
|
+
if(y[i]==+1)
|
915
|
+
{
|
916
|
+
if(!is_upper_bound(i))
|
917
|
+
{
|
918
|
+
if(-G[i] >= Gmax1)
|
919
|
+
Gmax1 = -G[i];
|
920
|
+
}
|
921
|
+
if(!is_lower_bound(i))
|
922
|
+
{
|
923
|
+
if(G[i] >= Gmax2)
|
924
|
+
Gmax2 = G[i];
|
925
|
+
}
|
926
|
+
}
|
927
|
+
else
|
928
|
+
{
|
929
|
+
if(!is_upper_bound(i))
|
930
|
+
{
|
931
|
+
if(-G[i] >= Gmax2)
|
932
|
+
Gmax2 = -G[i];
|
933
|
+
}
|
934
|
+
if(!is_lower_bound(i))
|
935
|
+
{
|
936
|
+
if(G[i] >= Gmax1)
|
937
|
+
Gmax1 = G[i];
|
938
|
+
}
|
939
|
+
}
|
940
|
+
}
|
941
|
+
|
942
|
+
if(unshrink == false && Gmax1 + Gmax2 <= eps*10)
|
943
|
+
{
|
944
|
+
unshrink = true;
|
945
|
+
reconstruct_gradient();
|
946
|
+
active_size = l;
|
947
|
+
info("*");
|
948
|
+
}
|
949
|
+
|
950
|
+
for(i=0;i<active_size;i++)
|
951
|
+
if (be_shrunk(i, Gmax1, Gmax2))
|
952
|
+
{
|
953
|
+
active_size--;
|
954
|
+
while (active_size > i)
|
955
|
+
{
|
956
|
+
if (!be_shrunk(active_size, Gmax1, Gmax2))
|
957
|
+
{
|
958
|
+
swap_index(i,active_size);
|
959
|
+
break;
|
960
|
+
}
|
961
|
+
active_size--;
|
962
|
+
}
|
963
|
+
}
|
964
|
+
}
|
965
|
+
|
966
|
+
double Solver::calculate_rho()
|
967
|
+
{
|
968
|
+
double r;
|
969
|
+
int nr_free = 0;
|
970
|
+
double ub = INF, lb = -INF, sum_free = 0;
|
971
|
+
for(int i=0;i<active_size;i++)
|
972
|
+
{
|
973
|
+
double yG = y[i]*G[i];
|
974
|
+
|
975
|
+
if(is_upper_bound(i))
|
976
|
+
{
|
977
|
+
if(y[i]==-1)
|
978
|
+
ub = min(ub,yG);
|
979
|
+
else
|
980
|
+
lb = max(lb,yG);
|
981
|
+
}
|
982
|
+
else if(is_lower_bound(i))
|
983
|
+
{
|
984
|
+
if(y[i]==+1)
|
985
|
+
ub = min(ub,yG);
|
986
|
+
else
|
987
|
+
lb = max(lb,yG);
|
988
|
+
}
|
989
|
+
else
|
990
|
+
{
|
991
|
+
++nr_free;
|
992
|
+
sum_free += yG;
|
993
|
+
}
|
994
|
+
}
|
995
|
+
|
996
|
+
if(nr_free>0)
|
997
|
+
r = sum_free/nr_free;
|
998
|
+
else
|
999
|
+
r = (ub+lb)/2;
|
1000
|
+
|
1001
|
+
return r;
|
1002
|
+
}
|
1003
|
+
|
1004
|
+
//
|
1005
|
+
// Solver for nu-svm classification and regression
|
1006
|
+
//
|
1007
|
+
// additional constraint: e^T \alpha = constant
|
1008
|
+
//
|
1009
|
+
class Solver_NU: public Solver
|
1010
|
+
{
|
1011
|
+
public:
|
1012
|
+
Solver_NU() {}
|
1013
|
+
void Solve(int l, const QMatrix& Q, const double *p, const schar *y,
|
1014
|
+
double *alpha, double Cp, double Cn, double eps,
|
1015
|
+
SolutionInfo* si, int shrinking)
|
1016
|
+
{
|
1017
|
+
this->si = si;
|
1018
|
+
Solver::Solve(l,Q,p,y,alpha,Cp,Cn,eps,si,shrinking);
|
1019
|
+
}
|
1020
|
+
private:
|
1021
|
+
SolutionInfo *si;
|
1022
|
+
int select_working_set(int &i, int &j);
|
1023
|
+
double calculate_rho();
|
1024
|
+
bool be_shrunk(int i, double Gmax1, double Gmax2, double Gmax3, double Gmax4);
|
1025
|
+
void do_shrinking();
|
1026
|
+
};
|
1027
|
+
|
1028
|
+
// return 1 if already optimal, return 0 otherwise
|
1029
|
+
int Solver_NU::select_working_set(int &out_i, int &out_j)
|
1030
|
+
{
|
1031
|
+
// return i,j such that y_i = y_j and
|
1032
|
+
// i: maximizes -y_i * grad(f)_i, i in I_up(\alpha)
|
1033
|
+
// j: minimizes the decrease of obj value
|
1034
|
+
// (if quadratic coefficeint <= 0, replace it with tau)
|
1035
|
+
// -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha)
|
1036
|
+
|
1037
|
+
double Gmaxp = -INF;
|
1038
|
+
double Gmaxp2 = -INF;
|
1039
|
+
int Gmaxp_idx = -1;
|
1040
|
+
|
1041
|
+
double Gmaxn = -INF;
|
1042
|
+
double Gmaxn2 = -INF;
|
1043
|
+
int Gmaxn_idx = -1;
|
1044
|
+
|
1045
|
+
int Gmin_idx = -1;
|
1046
|
+
double obj_diff_min = INF;
|
1047
|
+
|
1048
|
+
for(int t=0;t<active_size;t++)
|
1049
|
+
if(y[t]==+1)
|
1050
|
+
{
|
1051
|
+
if(!is_upper_bound(t))
|
1052
|
+
if(-G[t] >= Gmaxp)
|
1053
|
+
{
|
1054
|
+
Gmaxp = -G[t];
|
1055
|
+
Gmaxp_idx = t;
|
1056
|
+
}
|
1057
|
+
}
|
1058
|
+
else
|
1059
|
+
{
|
1060
|
+
if(!is_lower_bound(t))
|
1061
|
+
if(G[t] >= Gmaxn)
|
1062
|
+
{
|
1063
|
+
Gmaxn = G[t];
|
1064
|
+
Gmaxn_idx = t;
|
1065
|
+
}
|
1066
|
+
}
|
1067
|
+
|
1068
|
+
int ip = Gmaxp_idx;
|
1069
|
+
int in = Gmaxn_idx;
|
1070
|
+
const Qfloat *Q_ip = NULL;
|
1071
|
+
const Qfloat *Q_in = NULL;
|
1072
|
+
if(ip != -1) // NULL Q_ip not accessed: Gmaxp=-INF if ip=-1
|
1073
|
+
Q_ip = Q->get_Q(ip,active_size);
|
1074
|
+
if(in != -1)
|
1075
|
+
Q_in = Q->get_Q(in,active_size);
|
1076
|
+
|
1077
|
+
for(int j=0;j<active_size;j++)
|
1078
|
+
{
|
1079
|
+
if(y[j]==+1)
|
1080
|
+
{
|
1081
|
+
if (!is_lower_bound(j))
|
1082
|
+
{
|
1083
|
+
double grad_diff=Gmaxp+G[j];
|
1084
|
+
if (G[j] >= Gmaxp2)
|
1085
|
+
Gmaxp2 = G[j];
|
1086
|
+
if (grad_diff > 0)
|
1087
|
+
{
|
1088
|
+
double obj_diff;
|
1089
|
+
double quad_coef = QD[ip]+QD[j]-2*Q_ip[j];
|
1090
|
+
if (quad_coef > 0)
|
1091
|
+
obj_diff = -(grad_diff*grad_diff)/quad_coef;
|
1092
|
+
else
|
1093
|
+
obj_diff = -(grad_diff*grad_diff)/TAU;
|
1094
|
+
|
1095
|
+
if (obj_diff <= obj_diff_min)
|
1096
|
+
{
|
1097
|
+
Gmin_idx=j;
|
1098
|
+
obj_diff_min = obj_diff;
|
1099
|
+
}
|
1100
|
+
}
|
1101
|
+
}
|
1102
|
+
}
|
1103
|
+
else
|
1104
|
+
{
|
1105
|
+
if (!is_upper_bound(j))
|
1106
|
+
{
|
1107
|
+
double grad_diff=Gmaxn-G[j];
|
1108
|
+
if (-G[j] >= Gmaxn2)
|
1109
|
+
Gmaxn2 = -G[j];
|
1110
|
+
if (grad_diff > 0)
|
1111
|
+
{
|
1112
|
+
double obj_diff;
|
1113
|
+
double quad_coef = QD[in]+QD[j]-2*Q_in[j];
|
1114
|
+
if (quad_coef > 0)
|
1115
|
+
obj_diff = -(grad_diff*grad_diff)/quad_coef;
|
1116
|
+
else
|
1117
|
+
obj_diff = -(grad_diff*grad_diff)/TAU;
|
1118
|
+
|
1119
|
+
if (obj_diff <= obj_diff_min)
|
1120
|
+
{
|
1121
|
+
Gmin_idx=j;
|
1122
|
+
obj_diff_min = obj_diff;
|
1123
|
+
}
|
1124
|
+
}
|
1125
|
+
}
|
1126
|
+
}
|
1127
|
+
}
|
1128
|
+
|
1129
|
+
if(max(Gmaxp+Gmaxp2,Gmaxn+Gmaxn2) < eps || Gmin_idx == -1)
|
1130
|
+
return 1;
|
1131
|
+
|
1132
|
+
if (y[Gmin_idx] == +1)
|
1133
|
+
out_i = Gmaxp_idx;
|
1134
|
+
else
|
1135
|
+
out_i = Gmaxn_idx;
|
1136
|
+
out_j = Gmin_idx;
|
1137
|
+
|
1138
|
+
return 0;
|
1139
|
+
}
|
1140
|
+
|
1141
|
+
bool Solver_NU::be_shrunk(int i, double Gmax1, double Gmax2, double Gmax3, double Gmax4)
|
1142
|
+
{
|
1143
|
+
if(is_upper_bound(i))
|
1144
|
+
{
|
1145
|
+
if(y[i]==+1)
|
1146
|
+
return(-G[i] > Gmax1);
|
1147
|
+
else
|
1148
|
+
return(-G[i] > Gmax4);
|
1149
|
+
}
|
1150
|
+
else if(is_lower_bound(i))
|
1151
|
+
{
|
1152
|
+
if(y[i]==+1)
|
1153
|
+
return(G[i] > Gmax2);
|
1154
|
+
else
|
1155
|
+
return(G[i] > Gmax3);
|
1156
|
+
}
|
1157
|
+
else
|
1158
|
+
return(false);
|
1159
|
+
}
|
1160
|
+
|
1161
|
+
void Solver_NU::do_shrinking()
|
1162
|
+
{
|
1163
|
+
double Gmax1 = -INF; // max { -y_i * grad(f)_i | y_i = +1, i in I_up(\alpha) }
|
1164
|
+
double Gmax2 = -INF; // max { y_i * grad(f)_i | y_i = +1, i in I_low(\alpha) }
|
1165
|
+
double Gmax3 = -INF; // max { -y_i * grad(f)_i | y_i = -1, i in I_up(\alpha) }
|
1166
|
+
double Gmax4 = -INF; // max { y_i * grad(f)_i | y_i = -1, i in I_low(\alpha) }
|
1167
|
+
|
1168
|
+
// find maximal violating pair first
|
1169
|
+
int i;
|
1170
|
+
for(i=0;i<active_size;i++)
|
1171
|
+
{
|
1172
|
+
if(!is_upper_bound(i))
|
1173
|
+
{
|
1174
|
+
if(y[i]==+1)
|
1175
|
+
{
|
1176
|
+
if(-G[i] > Gmax1) Gmax1 = -G[i];
|
1177
|
+
}
|
1178
|
+
else if(-G[i] > Gmax4) Gmax4 = -G[i];
|
1179
|
+
}
|
1180
|
+
if(!is_lower_bound(i))
|
1181
|
+
{
|
1182
|
+
if(y[i]==+1)
|
1183
|
+
{
|
1184
|
+
if(G[i] > Gmax2) Gmax2 = G[i];
|
1185
|
+
}
|
1186
|
+
else if(G[i] > Gmax3) Gmax3 = G[i];
|
1187
|
+
}
|
1188
|
+
}
|
1189
|
+
|
1190
|
+
if(unshrink == false && max(Gmax1+Gmax2,Gmax3+Gmax4) <= eps*10)
|
1191
|
+
{
|
1192
|
+
unshrink = true;
|
1193
|
+
reconstruct_gradient();
|
1194
|
+
active_size = l;
|
1195
|
+
}
|
1196
|
+
|
1197
|
+
for(i=0;i<active_size;i++)
|
1198
|
+
if (be_shrunk(i, Gmax1, Gmax2, Gmax3, Gmax4))
|
1199
|
+
{
|
1200
|
+
active_size--;
|
1201
|
+
while (active_size > i)
|
1202
|
+
{
|
1203
|
+
if (!be_shrunk(active_size, Gmax1, Gmax2, Gmax3, Gmax4))
|
1204
|
+
{
|
1205
|
+
swap_index(i,active_size);
|
1206
|
+
break;
|
1207
|
+
}
|
1208
|
+
active_size--;
|
1209
|
+
}
|
1210
|
+
}
|
1211
|
+
}
|
1212
|
+
|
1213
|
+
double Solver_NU::calculate_rho()
|
1214
|
+
{
|
1215
|
+
int nr_free1 = 0,nr_free2 = 0;
|
1216
|
+
double ub1 = INF, ub2 = INF;
|
1217
|
+
double lb1 = -INF, lb2 = -INF;
|
1218
|
+
double sum_free1 = 0, sum_free2 = 0;
|
1219
|
+
|
1220
|
+
for(int i=0;i<active_size;i++)
|
1221
|
+
{
|
1222
|
+
if(y[i]==+1)
|
1223
|
+
{
|
1224
|
+
if(is_upper_bound(i))
|
1225
|
+
lb1 = max(lb1,G[i]);
|
1226
|
+
else if(is_lower_bound(i))
|
1227
|
+
ub1 = min(ub1,G[i]);
|
1228
|
+
else
|
1229
|
+
{
|
1230
|
+
++nr_free1;
|
1231
|
+
sum_free1 += G[i];
|
1232
|
+
}
|
1233
|
+
}
|
1234
|
+
else
|
1235
|
+
{
|
1236
|
+
if(is_upper_bound(i))
|
1237
|
+
lb2 = max(lb2,G[i]);
|
1238
|
+
else if(is_lower_bound(i))
|
1239
|
+
ub2 = min(ub2,G[i]);
|
1240
|
+
else
|
1241
|
+
{
|
1242
|
+
++nr_free2;
|
1243
|
+
sum_free2 += G[i];
|
1244
|
+
}
|
1245
|
+
}
|
1246
|
+
}
|
1247
|
+
|
1248
|
+
double r1,r2;
|
1249
|
+
if(nr_free1 > 0)
|
1250
|
+
r1 = sum_free1/nr_free1;
|
1251
|
+
else
|
1252
|
+
r1 = (ub1+lb1)/2;
|
1253
|
+
|
1254
|
+
if(nr_free2 > 0)
|
1255
|
+
r2 = sum_free2/nr_free2;
|
1256
|
+
else
|
1257
|
+
r2 = (ub2+lb2)/2;
|
1258
|
+
|
1259
|
+
si->r = (r1+r2)/2;
|
1260
|
+
return (r1-r2)/2;
|
1261
|
+
}
|
1262
|
+
|
1263
|
+
//
|
1264
|
+
// Q matrices for various formulations
|
1265
|
+
//
|
1266
|
+
class SVC_Q: public Kernel
|
1267
|
+
{
|
1268
|
+
public:
|
1269
|
+
SVC_Q(const svm_problem& prob, const svm_parameter& param, const schar *y_)
|
1270
|
+
:Kernel(prob.l, prob.x, param)
|
1271
|
+
{
|
1272
|
+
clone(y,y_,prob.l);
|
1273
|
+
cache = new Cache(prob.l,(long int)(param.cache_size*(1<<20)));
|
1274
|
+
QD = new double[prob.l];
|
1275
|
+
for(int i=0;i<prob.l;i++)
|
1276
|
+
QD[i] = (this->*kernel_function)(i,i);
|
1277
|
+
}
|
1278
|
+
|
1279
|
+
Qfloat *get_Q(int i, int len) const
|
1280
|
+
{
|
1281
|
+
Qfloat *data;
|
1282
|
+
int start, j;
|
1283
|
+
if((start = cache->get_data(i,&data,len)) < len)
|
1284
|
+
{
|
1285
|
+
for(j=start;j<len;j++)
|
1286
|
+
data[j] = (Qfloat)(y[i]*y[j]*(this->*kernel_function)(i,j));
|
1287
|
+
}
|
1288
|
+
return data;
|
1289
|
+
}
|
1290
|
+
|
1291
|
+
double *get_QD() const
|
1292
|
+
{
|
1293
|
+
return QD;
|
1294
|
+
}
|
1295
|
+
|
1296
|
+
void swap_index(int i, int j) const
|
1297
|
+
{
|
1298
|
+
cache->swap_index(i,j);
|
1299
|
+
Kernel::swap_index(i,j);
|
1300
|
+
swap(y[i],y[j]);
|
1301
|
+
swap(QD[i],QD[j]);
|
1302
|
+
}
|
1303
|
+
|
1304
|
+
~SVC_Q()
|
1305
|
+
{
|
1306
|
+
delete[] y;
|
1307
|
+
delete cache;
|
1308
|
+
delete[] QD;
|
1309
|
+
}
|
1310
|
+
private:
|
1311
|
+
schar *y;
|
1312
|
+
Cache *cache;
|
1313
|
+
double *QD;
|
1314
|
+
};
|
1315
|
+
|
1316
|
+
class ONE_CLASS_Q: public Kernel
|
1317
|
+
{
|
1318
|
+
public:
|
1319
|
+
ONE_CLASS_Q(const svm_problem& prob, const svm_parameter& param)
|
1320
|
+
:Kernel(prob.l, prob.x, param)
|
1321
|
+
{
|
1322
|
+
cache = new Cache(prob.l,(long int)(param.cache_size*(1<<20)));
|
1323
|
+
QD = new double[prob.l];
|
1324
|
+
for(int i=0;i<prob.l;i++)
|
1325
|
+
QD[i] = (this->*kernel_function)(i,i);
|
1326
|
+
}
|
1327
|
+
|
1328
|
+
Qfloat *get_Q(int i, int len) const
|
1329
|
+
{
|
1330
|
+
Qfloat *data;
|
1331
|
+
int start, j;
|
1332
|
+
if((start = cache->get_data(i,&data,len)) < len)
|
1333
|
+
{
|
1334
|
+
for(j=start;j<len;j++)
|
1335
|
+
data[j] = (Qfloat)(this->*kernel_function)(i,j);
|
1336
|
+
}
|
1337
|
+
return data;
|
1338
|
+
}
|
1339
|
+
|
1340
|
+
double *get_QD() const
|
1341
|
+
{
|
1342
|
+
return QD;
|
1343
|
+
}
|
1344
|
+
|
1345
|
+
void swap_index(int i, int j) const
|
1346
|
+
{
|
1347
|
+
cache->swap_index(i,j);
|
1348
|
+
Kernel::swap_index(i,j);
|
1349
|
+
swap(QD[i],QD[j]);
|
1350
|
+
}
|
1351
|
+
|
1352
|
+
~ONE_CLASS_Q()
|
1353
|
+
{
|
1354
|
+
delete cache;
|
1355
|
+
delete[] QD;
|
1356
|
+
}
|
1357
|
+
private:
|
1358
|
+
Cache *cache;
|
1359
|
+
double *QD;
|
1360
|
+
};
|
1361
|
+
|
1362
|
+
class SVR_Q: public Kernel
|
1363
|
+
{
|
1364
|
+
public:
|
1365
|
+
SVR_Q(const svm_problem& prob, const svm_parameter& param)
|
1366
|
+
:Kernel(prob.l, prob.x, param)
|
1367
|
+
{
|
1368
|
+
l = prob.l;
|
1369
|
+
cache = new Cache(l,(long int)(param.cache_size*(1<<20)));
|
1370
|
+
QD = new double[2*l];
|
1371
|
+
sign = new schar[2*l];
|
1372
|
+
index = new int[2*l];
|
1373
|
+
for(int k=0;k<l;k++)
|
1374
|
+
{
|
1375
|
+
sign[k] = 1;
|
1376
|
+
sign[k+l] = -1;
|
1377
|
+
index[k] = k;
|
1378
|
+
index[k+l] = k;
|
1379
|
+
QD[k] = (this->*kernel_function)(k,k);
|
1380
|
+
QD[k+l] = QD[k];
|
1381
|
+
}
|
1382
|
+
buffer[0] = new Qfloat[2*l];
|
1383
|
+
buffer[1] = new Qfloat[2*l];
|
1384
|
+
next_buffer = 0;
|
1385
|
+
}
|
1386
|
+
|
1387
|
+
void swap_index(int i, int j) const
|
1388
|
+
{
|
1389
|
+
swap(sign[i],sign[j]);
|
1390
|
+
swap(index[i],index[j]);
|
1391
|
+
swap(QD[i],QD[j]);
|
1392
|
+
}
|
1393
|
+
|
1394
|
+
Qfloat *get_Q(int i, int len) const
|
1395
|
+
{
|
1396
|
+
Qfloat *data;
|
1397
|
+
int j, real_i = index[i];
|
1398
|
+
if(cache->get_data(real_i,&data,l) < l)
|
1399
|
+
{
|
1400
|
+
for(j=0;j<l;j++)
|
1401
|
+
data[j] = (Qfloat)(this->*kernel_function)(real_i,j);
|
1402
|
+
}
|
1403
|
+
|
1404
|
+
// reorder and copy
|
1405
|
+
Qfloat *buf = buffer[next_buffer];
|
1406
|
+
next_buffer = 1 - next_buffer;
|
1407
|
+
schar si = sign[i];
|
1408
|
+
for(j=0;j<len;j++)
|
1409
|
+
buf[j] = (Qfloat) si * (Qfloat) sign[j] * data[index[j]];
|
1410
|
+
return buf;
|
1411
|
+
}
|
1412
|
+
|
1413
|
+
double *get_QD() const
|
1414
|
+
{
|
1415
|
+
return QD;
|
1416
|
+
}
|
1417
|
+
|
1418
|
+
~SVR_Q()
|
1419
|
+
{
|
1420
|
+
delete cache;
|
1421
|
+
delete[] sign;
|
1422
|
+
delete[] index;
|
1423
|
+
delete[] buffer[0];
|
1424
|
+
delete[] buffer[1];
|
1425
|
+
delete[] QD;
|
1426
|
+
}
|
1427
|
+
private:
|
1428
|
+
int l;
|
1429
|
+
Cache *cache;
|
1430
|
+
schar *sign;
|
1431
|
+
int *index;
|
1432
|
+
mutable int next_buffer;
|
1433
|
+
Qfloat *buffer[2];
|
1434
|
+
double *QD;
|
1435
|
+
};
|
1436
|
+
|
1437
|
+
//
|
1438
|
+
// construct and solve various formulations
|
1439
|
+
//
|
1440
|
+
static void solve_c_svc(
|
1441
|
+
const svm_problem *prob, const svm_parameter* param,
|
1442
|
+
double *alpha, Solver::SolutionInfo* si, double Cp, double Cn)
|
1443
|
+
{
|
1444
|
+
int l = prob->l;
|
1445
|
+
double *minus_ones = new double[l];
|
1446
|
+
schar *y = new schar[l];
|
1447
|
+
|
1448
|
+
int i;
|
1449
|
+
|
1450
|
+
for(i=0;i<l;i++)
|
1451
|
+
{
|
1452
|
+
alpha[i] = 0;
|
1453
|
+
minus_ones[i] = -1;
|
1454
|
+
if(prob->y[i] > 0) y[i] = +1; else y[i] = -1;
|
1455
|
+
}
|
1456
|
+
|
1457
|
+
Solver s;
|
1458
|
+
s.Solve(l, SVC_Q(*prob,*param,y), minus_ones, y,
|
1459
|
+
alpha, Cp, Cn, param->eps, si, param->shrinking);
|
1460
|
+
|
1461
|
+
double sum_alpha=0;
|
1462
|
+
for(i=0;i<l;i++)
|
1463
|
+
sum_alpha += alpha[i];
|
1464
|
+
|
1465
|
+
if (Cp==Cn)
|
1466
|
+
info("nu = %f\n", sum_alpha/(Cp*prob->l));
|
1467
|
+
|
1468
|
+
for(i=0;i<l;i++)
|
1469
|
+
alpha[i] *= y[i];
|
1470
|
+
|
1471
|
+
delete[] minus_ones;
|
1472
|
+
delete[] y;
|
1473
|
+
}
|
1474
|
+
|
1475
|
+
static void solve_nu_svc(
|
1476
|
+
const svm_problem *prob, const svm_parameter *param,
|
1477
|
+
double *alpha, Solver::SolutionInfo* si)
|
1478
|
+
{
|
1479
|
+
int i;
|
1480
|
+
int l = prob->l;
|
1481
|
+
double nu = param->nu;
|
1482
|
+
|
1483
|
+
schar *y = new schar[l];
|
1484
|
+
|
1485
|
+
for(i=0;i<l;i++)
|
1486
|
+
if(prob->y[i]>0)
|
1487
|
+
y[i] = +1;
|
1488
|
+
else
|
1489
|
+
y[i] = -1;
|
1490
|
+
|
1491
|
+
double sum_pos = nu*l/2;
|
1492
|
+
double sum_neg = nu*l/2;
|
1493
|
+
|
1494
|
+
for(i=0;i<l;i++)
|
1495
|
+
if(y[i] == +1)
|
1496
|
+
{
|
1497
|
+
alpha[i] = min(1.0,sum_pos);
|
1498
|
+
sum_pos -= alpha[i];
|
1499
|
+
}
|
1500
|
+
else
|
1501
|
+
{
|
1502
|
+
alpha[i] = min(1.0,sum_neg);
|
1503
|
+
sum_neg -= alpha[i];
|
1504
|
+
}
|
1505
|
+
|
1506
|
+
double *zeros = new double[l];
|
1507
|
+
|
1508
|
+
for(i=0;i<l;i++)
|
1509
|
+
zeros[i] = 0;
|
1510
|
+
|
1511
|
+
Solver_NU s;
|
1512
|
+
s.Solve(l, SVC_Q(*prob,*param,y), zeros, y,
|
1513
|
+
alpha, 1.0, 1.0, param->eps, si, param->shrinking);
|
1514
|
+
double r = si->r;
|
1515
|
+
|
1516
|
+
info("C = %f\n",1/r);
|
1517
|
+
|
1518
|
+
for(i=0;i<l;i++)
|
1519
|
+
alpha[i] *= y[i]/r;
|
1520
|
+
|
1521
|
+
si->rho /= r;
|
1522
|
+
si->obj /= (r*r);
|
1523
|
+
si->upper_bound_p = 1/r;
|
1524
|
+
si->upper_bound_n = 1/r;
|
1525
|
+
|
1526
|
+
delete[] y;
|
1527
|
+
delete[] zeros;
|
1528
|
+
}
|
1529
|
+
|
1530
|
+
static void solve_one_class(
|
1531
|
+
const svm_problem *prob, const svm_parameter *param,
|
1532
|
+
double *alpha, Solver::SolutionInfo* si)
|
1533
|
+
{
|
1534
|
+
int l = prob->l;
|
1535
|
+
double *zeros = new double[l];
|
1536
|
+
schar *ones = new schar[l];
|
1537
|
+
int i;
|
1538
|
+
|
1539
|
+
int n = (int)(param->nu*prob->l); // # of alpha's at upper bound
|
1540
|
+
|
1541
|
+
for(i=0;i<n;i++)
|
1542
|
+
alpha[i] = 1;
|
1543
|
+
if(n<prob->l)
|
1544
|
+
alpha[n] = param->nu * prob->l - n;
|
1545
|
+
for(i=n+1;i<l;i++)
|
1546
|
+
alpha[i] = 0;
|
1547
|
+
|
1548
|
+
for(i=0;i<l;i++)
|
1549
|
+
{
|
1550
|
+
zeros[i] = 0;
|
1551
|
+
ones[i] = 1;
|
1552
|
+
}
|
1553
|
+
|
1554
|
+
Solver s;
|
1555
|
+
s.Solve(l, ONE_CLASS_Q(*prob,*param), zeros, ones,
|
1556
|
+
alpha, 1.0, 1.0, param->eps, si, param->shrinking);
|
1557
|
+
|
1558
|
+
delete[] zeros;
|
1559
|
+
delete[] ones;
|
1560
|
+
}
|
1561
|
+
|
1562
|
+
static void solve_epsilon_svr(
|
1563
|
+
const svm_problem *prob, const svm_parameter *param,
|
1564
|
+
double *alpha, Solver::SolutionInfo* si)
|
1565
|
+
{
|
1566
|
+
int l = prob->l;
|
1567
|
+
double *alpha2 = new double[2*l];
|
1568
|
+
double *linear_term = new double[2*l];
|
1569
|
+
schar *y = new schar[2*l];
|
1570
|
+
int i;
|
1571
|
+
|
1572
|
+
for(i=0;i<l;i++)
|
1573
|
+
{
|
1574
|
+
alpha2[i] = 0;
|
1575
|
+
linear_term[i] = param->p - prob->y[i];
|
1576
|
+
y[i] = 1;
|
1577
|
+
|
1578
|
+
alpha2[i+l] = 0;
|
1579
|
+
linear_term[i+l] = param->p + prob->y[i];
|
1580
|
+
y[i+l] = -1;
|
1581
|
+
}
|
1582
|
+
|
1583
|
+
Solver s;
|
1584
|
+
s.Solve(2*l, SVR_Q(*prob,*param), linear_term, y,
|
1585
|
+
alpha2, param->C, param->C, param->eps, si, param->shrinking);
|
1586
|
+
|
1587
|
+
double sum_alpha = 0;
|
1588
|
+
for(i=0;i<l;i++)
|
1589
|
+
{
|
1590
|
+
alpha[i] = alpha2[i] - alpha2[i+l];
|
1591
|
+
sum_alpha += fabs(alpha[i]);
|
1592
|
+
}
|
1593
|
+
info("nu = %f\n",sum_alpha/(param->C*l));
|
1594
|
+
|
1595
|
+
delete[] alpha2;
|
1596
|
+
delete[] linear_term;
|
1597
|
+
delete[] y;
|
1598
|
+
}
|
1599
|
+
|
1600
|
+
static void solve_nu_svr(
|
1601
|
+
const svm_problem *prob, const svm_parameter *param,
|
1602
|
+
double *alpha, Solver::SolutionInfo* si)
|
1603
|
+
{
|
1604
|
+
int l = prob->l;
|
1605
|
+
double C = param->C;
|
1606
|
+
double *alpha2 = new double[2*l];
|
1607
|
+
double *linear_term = new double[2*l];
|
1608
|
+
schar *y = new schar[2*l];
|
1609
|
+
int i;
|
1610
|
+
|
1611
|
+
double sum = C * param->nu * l / 2;
|
1612
|
+
for(i=0;i<l;i++)
|
1613
|
+
{
|
1614
|
+
alpha2[i] = alpha2[i+l] = min(sum,C);
|
1615
|
+
sum -= alpha2[i];
|
1616
|
+
|
1617
|
+
linear_term[i] = - prob->y[i];
|
1618
|
+
y[i] = 1;
|
1619
|
+
|
1620
|
+
linear_term[i+l] = prob->y[i];
|
1621
|
+
y[i+l] = -1;
|
1622
|
+
}
|
1623
|
+
|
1624
|
+
Solver_NU s;
|
1625
|
+
s.Solve(2*l, SVR_Q(*prob,*param), linear_term, y,
|
1626
|
+
alpha2, C, C, param->eps, si, param->shrinking);
|
1627
|
+
|
1628
|
+
info("epsilon = %f\n",-si->r);
|
1629
|
+
|
1630
|
+
for(i=0;i<l;i++)
|
1631
|
+
alpha[i] = alpha2[i] - alpha2[i+l];
|
1632
|
+
|
1633
|
+
delete[] alpha2;
|
1634
|
+
delete[] linear_term;
|
1635
|
+
delete[] y;
|
1636
|
+
}
|
1637
|
+
|
1638
|
+
//
|
1639
|
+
// decision_function
|
1640
|
+
//
|
1641
|
+
// Result of training one binary/regression problem.
struct decision_function
{
	double *alpha;	// one coefficient per training point
	double rho;	// value computed by the solver's calculate_rho()
};
|
1646
|
+
|
1647
|
+
// Train a single problem with the solver selected by param->svm_type.
//
// Cp and Cn are only used by C_SVC. The returned alpha array is allocated
// with Malloc and has prob->l entries; the caller must free() it.
// An unrecognised svm_type yields an all-zero alpha and rho = 0 (with a
// warning on stderr) instead of reading uninitialized memory.
static decision_function svm_train_one(
	const svm_problem *prob, const svm_parameter *param,
	double Cp, double Cn)
{
	double *alpha = Malloc(double,prob->l);
	Solver::SolutionInfo si;
	switch(param->svm_type)
	{
		case C_SVC:
			solve_c_svc(prob,param,alpha,&si,Cp,Cn);
			break;
		case NU_SVC:
			solve_nu_svc(prob,param,alpha,&si);
			break;
		case ONE_CLASS:
			solve_one_class(prob,param,alpha,&si);
			break;
		case EPSILON_SVR:
			solve_epsilon_svr(prob,param,alpha,&si);
			break;
		case NU_SVR:
			solve_nu_svr(prob,param,alpha,&si);
			break;
		default:
			// No solver ran: give alpha and every si field read below
			// a defined value.
			fprintf(stderr,"\nWARNING: unknown svm_type %d, returning an empty decision function\n",
				(int)param->svm_type);
			for(int i=0;i<prob->l;i++)
				alpha[i] = 0;
			si.obj = 0;
			si.rho = 0;
			si.upper_bound_p = 0;
			si.upper_bound_n = 0;
			break;
	}

	info("obj = %f, rho = %f\n",si.obj,si.rho);

	// output SVs

	int nSV = 0;	// non-zero coefficients
	int nBSV = 0;	// coefficients at (or above) their class's upper bound
	for(int i=0;i<prob->l;i++)
	{
		if(!(fabs(alpha[i]) > 0))
			continue;

		++nSV;
		const double bound = (prob->y[i] > 0) ? si.upper_bound_p : si.upper_bound_n;
		if(fabs(alpha[i]) >= bound)
			++nBSV;
	}

	info("nSV = %d, nBSV = %d\n",nSV,nBSV);

	decision_function f;
	f.alpha = alpha;
	f.rho = si.rho;
	return f;
}
|
1703
|
+
|
1704
|
+
// Platt's binary SVM Probablistic Output: an improvement from Lin et al.
|
1705
|
+
static void sigmoid_train(
|
1706
|
+
int l, const double *dec_values, const double *labels,
|
1707
|
+
double& A, double& B)
|
1708
|
+
{
|
1709
|
+
double prior1=0, prior0 = 0;
|
1710
|
+
int i;
|
1711
|
+
|
1712
|
+
for (i=0;i<l;i++)
|
1713
|
+
if (labels[i] > 0) prior1+=1;
|
1714
|
+
else prior0+=1;
|
1715
|
+
|
1716
|
+
int max_iter=100; // Maximal number of iterations
|
1717
|
+
double min_step=1e-10; // Minimal step taken in line search
|
1718
|
+
double sigma=1e-12; // For numerically strict PD of Hessian
|
1719
|
+
double eps=1e-5;
|
1720
|
+
double hiTarget=(prior1+1.0)/(prior1+2.0);
|
1721
|
+
double loTarget=1/(prior0+2.0);
|
1722
|
+
double *t=Malloc(double,l);
|
1723
|
+
double fApB,p,q,h11,h22,h21,g1,g2,det,dA,dB,gd,stepsize;
|
1724
|
+
double newA,newB,newf,d1,d2;
|
1725
|
+
int iter;
|
1726
|
+
|
1727
|
+
// Initial Point and Initial Fun Value
|
1728
|
+
A=0.0; B=log((prior0+1.0)/(prior1+1.0));
|
1729
|
+
double fval = 0.0;
|
1730
|
+
|
1731
|
+
for (i=0;i<l;i++)
|
1732
|
+
{
|
1733
|
+
if (labels[i]>0) t[i]=hiTarget;
|
1734
|
+
else t[i]=loTarget;
|
1735
|
+
fApB = dec_values[i]*A+B;
|
1736
|
+
if (fApB>=0)
|
1737
|
+
fval += t[i]*fApB + log(1+exp(-fApB));
|
1738
|
+
else
|
1739
|
+
fval += (t[i] - 1)*fApB +log(1+exp(fApB));
|
1740
|
+
}
|
1741
|
+
for (iter=0;iter<max_iter;iter++)
|
1742
|
+
{
|
1743
|
+
// Update Gradient and Hessian (use H' = H + sigma I)
|
1744
|
+
h11=sigma; // numerically ensures strict PD
|
1745
|
+
h22=sigma;
|
1746
|
+
h21=0.0;g1=0.0;g2=0.0;
|
1747
|
+
for (i=0;i<l;i++)
|
1748
|
+
{
|
1749
|
+
fApB = dec_values[i]*A+B;
|
1750
|
+
if (fApB >= 0)
|
1751
|
+
{
|
1752
|
+
p=exp(-fApB)/(1.0+exp(-fApB));
|
1753
|
+
q=1.0/(1.0+exp(-fApB));
|
1754
|
+
}
|
1755
|
+
else
|
1756
|
+
{
|
1757
|
+
p=1.0/(1.0+exp(fApB));
|
1758
|
+
q=exp(fApB)/(1.0+exp(fApB));
|
1759
|
+
}
|
1760
|
+
d2=p*q;
|
1761
|
+
h11+=dec_values[i]*dec_values[i]*d2;
|
1762
|
+
h22+=d2;
|
1763
|
+
h21+=dec_values[i]*d2;
|
1764
|
+
d1=t[i]-p;
|
1765
|
+
g1+=dec_values[i]*d1;
|
1766
|
+
g2+=d1;
|
1767
|
+
}
|
1768
|
+
|
1769
|
+
// Stopping Criteria
|
1770
|
+
if (fabs(g1)<eps && fabs(g2)<eps)
|
1771
|
+
break;
|
1772
|
+
|
1773
|
+
// Finding Newton direction: -inv(H') * g
|
1774
|
+
det=h11*h22-h21*h21;
|
1775
|
+
dA=-(h22*g1 - h21 * g2) / det;
|
1776
|
+
dB=-(-h21*g1+ h11 * g2) / det;
|
1777
|
+
gd=g1*dA+g2*dB;
|
1778
|
+
|
1779
|
+
|
1780
|
+
stepsize = 1; // Line Search
|
1781
|
+
while (stepsize >= min_step)
|
1782
|
+
{
|
1783
|
+
newA = A + stepsize * dA;
|
1784
|
+
newB = B + stepsize * dB;
|
1785
|
+
|
1786
|
+
// New function value
|
1787
|
+
newf = 0.0;
|
1788
|
+
for (i=0;i<l;i++)
|
1789
|
+
{
|
1790
|
+
fApB = dec_values[i]*newA+newB;
|
1791
|
+
if (fApB >= 0)
|
1792
|
+
newf += t[i]*fApB + log(1+exp(-fApB));
|
1793
|
+
else
|
1794
|
+
newf += (t[i] - 1)*fApB +log(1+exp(fApB));
|
1795
|
+
}
|
1796
|
+
// Check sufficient decrease
|
1797
|
+
if (newf<fval+0.0001*stepsize*gd)
|
1798
|
+
{
|
1799
|
+
A=newA;B=newB;fval=newf;
|
1800
|
+
break;
|
1801
|
+
}
|
1802
|
+
else
|
1803
|
+
stepsize = stepsize / 2.0;
|
1804
|
+
}
|
1805
|
+
|
1806
|
+
if (stepsize < min_step)
|
1807
|
+
{
|
1808
|
+
info("Line search fails in two-class probability estimates\n");
|
1809
|
+
break;
|
1810
|
+
}
|
1811
|
+
}
|
1812
|
+
|
1813
|
+
if (iter>=max_iter)
|
1814
|
+
info("Reaching maximal iterations in two-class probability estimates\n");
|
1815
|
+
free(t);
|
1816
|
+
}
|
1817
|
+
|
1818
|
+
// Evaluate 1/(1+exp(decision_value*A+B)).
static double sigmoid_predict(double decision_value, double A, double B)
{
	const double fApB = decision_value*A+B;

	// 1-p used later; avoid catastrophic cancellation
	if (fApB < 0)
		return 1.0/(1+exp(fApB));

	// fApB >= 0 (or NaN): work with exp(-fApB) so exp() cannot overflow
	const double decay = exp(-fApB);
	return decay/(1.0+decay);
}
|
1827
|
+
|
1828
|
+
// Method 2 from the multiclass_prob paper by Wu, Lin, and Weng
|
1829
|
+
static void multiclass_probability(int k, double **r, double *p)
|
1830
|
+
{
|
1831
|
+
int t,j;
|
1832
|
+
int iter = 0, max_iter=max(100,k);
|
1833
|
+
double **Q=Malloc(double *,k);
|
1834
|
+
double *Qp=Malloc(double,k);
|
1835
|
+
double pQp, eps=0.005/k;
|
1836
|
+
|
1837
|
+
for (t=0;t<k;t++)
|
1838
|
+
{
|
1839
|
+
p[t]=1.0/k; // Valid if k = 1
|
1840
|
+
Q[t]=Malloc(double,k);
|
1841
|
+
Q[t][t]=0;
|
1842
|
+
for (j=0;j<t;j++)
|
1843
|
+
{
|
1844
|
+
Q[t][t]+=r[j][t]*r[j][t];
|
1845
|
+
Q[t][j]=Q[j][t];
|
1846
|
+
}
|
1847
|
+
for (j=t+1;j<k;j++)
|
1848
|
+
{
|
1849
|
+
Q[t][t]+=r[j][t]*r[j][t];
|
1850
|
+
Q[t][j]=-r[j][t]*r[t][j];
|
1851
|
+
}
|
1852
|
+
}
|
1853
|
+
for (iter=0;iter<max_iter;iter++)
|
1854
|
+
{
|
1855
|
+
// stopping condition, recalculate QP,pQP for numerical accuracy
|
1856
|
+
pQp=0;
|
1857
|
+
for (t=0;t<k;t++)
|
1858
|
+
{
|
1859
|
+
Qp[t]=0;
|
1860
|
+
for (j=0;j<k;j++)
|
1861
|
+
Qp[t]+=Q[t][j]*p[j];
|
1862
|
+
pQp+=p[t]*Qp[t];
|
1863
|
+
}
|
1864
|
+
double max_error=0;
|
1865
|
+
for (t=0;t<k;t++)
|
1866
|
+
{
|
1867
|
+
double error=fabs(Qp[t]-pQp);
|
1868
|
+
if (error>max_error)
|
1869
|
+
max_error=error;
|
1870
|
+
}
|
1871
|
+
if (max_error<eps) break;
|
1872
|
+
|
1873
|
+
for (t=0;t<k;t++)
|
1874
|
+
{
|
1875
|
+
double diff=(-Qp[t]+pQp)/Q[t][t];
|
1876
|
+
p[t]+=diff;
|
1877
|
+
pQp=(pQp+diff*(diff*Q[t][t]+2*Qp[t]))/(1+diff)/(1+diff);
|
1878
|
+
for (j=0;j<k;j++)
|
1879
|
+
{
|
1880
|
+
Qp[j]=(Qp[j]+diff*Q[t][j])/(1+diff);
|
1881
|
+
p[j]/=(1+diff);
|
1882
|
+
}
|
1883
|
+
}
|
1884
|
+
}
|
1885
|
+
if (iter>=max_iter)
|
1886
|
+
info("Exceeds max_iter in multiclass_prob\n");
|
1887
|
+
for(t=0;t<k;t++) free(Q[t]);
|
1888
|
+
free(Q);
|
1889
|
+
free(Qp);
|
1890
|
+
}
|
1891
|
+
|
1892
|
+
// Cross-validation decision values for probability estimates
|
1893
|
+
static void svm_binary_svc_probability(
|
1894
|
+
const svm_problem *prob, const svm_parameter *param,
|
1895
|
+
double Cp, double Cn, double& probA, double& probB)
|
1896
|
+
{
|
1897
|
+
int i;
|
1898
|
+
int nr_fold = 5;
|
1899
|
+
int *perm = Malloc(int,prob->l);
|
1900
|
+
double *dec_values = Malloc(double,prob->l);
|
1901
|
+
|
1902
|
+
// random shuffle
|
1903
|
+
for(i=0;i<prob->l;i++) perm[i]=i;
|
1904
|
+
for(i=0;i<prob->l;i++)
|
1905
|
+
{
|
1906
|
+
int j = i+rand()%(prob->l-i);
|
1907
|
+
swap(perm[i],perm[j]);
|
1908
|
+
}
|
1909
|
+
for(i=0;i<nr_fold;i++)
|
1910
|
+
{
|
1911
|
+
int begin = i*prob->l/nr_fold;
|
1912
|
+
int end = (i+1)*prob->l/nr_fold;
|
1913
|
+
int j,k;
|
1914
|
+
struct svm_problem subprob;
|
1915
|
+
|
1916
|
+
subprob.l = prob->l-(end-begin);
|
1917
|
+
subprob.x = Malloc(struct svm_node*,subprob.l);
|
1918
|
+
subprob.y = Malloc(double,subprob.l);
|
1919
|
+
|
1920
|
+
k=0;
|
1921
|
+
for(j=0;j<begin;j++)
|
1922
|
+
{
|
1923
|
+
subprob.x[k] = prob->x[perm[j]];
|
1924
|
+
subprob.y[k] = prob->y[perm[j]];
|
1925
|
+
++k;
|
1926
|
+
}
|
1927
|
+
for(j=end;j<prob->l;j++)
|
1928
|
+
{
|
1929
|
+
subprob.x[k] = prob->x[perm[j]];
|
1930
|
+
subprob.y[k] = prob->y[perm[j]];
|
1931
|
+
++k;
|
1932
|
+
}
|
1933
|
+
int p_count=0,n_count=0;
|
1934
|
+
for(j=0;j<k;j++)
|
1935
|
+
if(subprob.y[j]>0)
|
1936
|
+
p_count++;
|
1937
|
+
else
|
1938
|
+
n_count++;
|
1939
|
+
|
1940
|
+
if(p_count==0 && n_count==0)
|
1941
|
+
for(j=begin;j<end;j++)
|
1942
|
+
dec_values[perm[j]] = 0;
|
1943
|
+
else if(p_count > 0 && n_count == 0)
|
1944
|
+
for(j=begin;j<end;j++)
|
1945
|
+
dec_values[perm[j]] = 1;
|
1946
|
+
else if(p_count == 0 && n_count > 0)
|
1947
|
+
for(j=begin;j<end;j++)
|
1948
|
+
dec_values[perm[j]] = -1;
|
1949
|
+
else
|
1950
|
+
{
|
1951
|
+
svm_parameter subparam = *param;
|
1952
|
+
subparam.probability=0;
|
1953
|
+
subparam.C=1.0;
|
1954
|
+
subparam.nr_weight=2;
|
1955
|
+
subparam.weight_label = Malloc(int,2);
|
1956
|
+
subparam.weight = Malloc(double,2);
|
1957
|
+
subparam.weight_label[0]=+1;
|
1958
|
+
subparam.weight_label[1]=-1;
|
1959
|
+
subparam.weight[0]=Cp;
|
1960
|
+
subparam.weight[1]=Cn;
|
1961
|
+
struct svm_model *submodel = svm_train(&subprob,&subparam);
|
1962
|
+
for(j=begin;j<end;j++)
|
1963
|
+
{
|
1964
|
+
svm_predict_values(submodel,prob->x[perm[j]],&(dec_values[perm[j]]));
|
1965
|
+
// ensure +1 -1 order; reason not using CV subroutine
|
1966
|
+
dec_values[perm[j]] *= submodel->label[0];
|
1967
|
+
}
|
1968
|
+
svm_free_and_destroy_model(&submodel);
|
1969
|
+
svm_destroy_param(&subparam);
|
1970
|
+
}
|
1971
|
+
free(subprob.x);
|
1972
|
+
free(subprob.y);
|
1973
|
+
}
|
1974
|
+
sigmoid_train(prob->l,dec_values,prob->y,probA,probB);
|
1975
|
+
free(dec_values);
|
1976
|
+
free(perm);
|
1977
|
+
}
|
1978
|
+
|
1979
|
+
// Return parameter of a Laplace distribution
|
1980
|
+
static double svm_svr_probability(
|
1981
|
+
const svm_problem *prob, const svm_parameter *param)
|
1982
|
+
{
|
1983
|
+
int i;
|
1984
|
+
int nr_fold = 5;
|
1985
|
+
double *ymv = Malloc(double,prob->l);
|
1986
|
+
double mae = 0;
|
1987
|
+
|
1988
|
+
svm_parameter newparam = *param;
|
1989
|
+
newparam.probability = 0;
|
1990
|
+
svm_cross_validation(prob,&newparam,nr_fold,ymv);
|
1991
|
+
for(i=0;i<prob->l;i++)
|
1992
|
+
{
|
1993
|
+
ymv[i]=prob->y[i]-ymv[i];
|
1994
|
+
mae += fabs(ymv[i]);
|
1995
|
+
}
|
1996
|
+
mae /= prob->l;
|
1997
|
+
double std=sqrt(2*mae*mae);
|
1998
|
+
int count=0;
|
1999
|
+
mae=0;
|
2000
|
+
for(i=0;i<prob->l;i++)
|
2001
|
+
if (fabs(ymv[i]) > 5*std)
|
2002
|
+
count=count+1;
|
2003
|
+
else
|
2004
|
+
mae+=fabs(ymv[i]);
|
2005
|
+
mae /= (prob->l-count);
|
2006
|
+
info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma= %g\n",mae);
|
2007
|
+
free(ymv);
|
2008
|
+
return mae;
|
2009
|
+
}
|
2010
|
+
|
2011
|
+
|
2012
|
+
// label: label name, start: begin of each class, count: #data of classes, perm: indices to the original data
|
2013
|
+
// perm, length l, must be allocated before calling this subroutine
|
2014
|
+
static void svm_group_classes(const svm_problem *prob, int *nr_class_ret, int **label_ret, int **start_ret, int **count_ret, int *perm)
|
2015
|
+
{
|
2016
|
+
int l = prob->l;
|
2017
|
+
int max_nr_class = 16;
|
2018
|
+
int nr_class = 0;
|
2019
|
+
int *label = Malloc(int,max_nr_class);
|
2020
|
+
int *count = Malloc(int,max_nr_class);
|
2021
|
+
int *data_label = Malloc(int,l);
|
2022
|
+
int i;
|
2023
|
+
|
2024
|
+
for(i=0;i<l;i++)
|
2025
|
+
{
|
2026
|
+
int this_label = (int)prob->y[i];
|
2027
|
+
int j;
|
2028
|
+
for(j=0;j<nr_class;j++)
|
2029
|
+
{
|
2030
|
+
if(this_label == label[j])
|
2031
|
+
{
|
2032
|
+
++count[j];
|
2033
|
+
break;
|
2034
|
+
}
|
2035
|
+
}
|
2036
|
+
data_label[i] = j;
|
2037
|
+
if(j == nr_class)
|
2038
|
+
{
|
2039
|
+
if(nr_class == max_nr_class)
|
2040
|
+
{
|
2041
|
+
max_nr_class *= 2;
|
2042
|
+
label = (int *)realloc(label,max_nr_class*sizeof(int));
|
2043
|
+
count = (int *)realloc(count,max_nr_class*sizeof(int));
|
2044
|
+
}
|
2045
|
+
label[nr_class] = this_label;
|
2046
|
+
count[nr_class] = 1;
|
2047
|
+
++nr_class;
|
2048
|
+
}
|
2049
|
+
}
|
2050
|
+
|
2051
|
+
//
|
2052
|
+
// Labels are ordered by their first occurrence in the training set.
|
2053
|
+
// However, for two-class sets with -1/+1 labels and -1 appears first,
|
2054
|
+
// we swap labels to ensure that internally the binary SVM has positive data corresponding to the +1 instances.
|
2055
|
+
//
|
2056
|
+
if (nr_class == 2 && label[0] == -1 && label[1] == 1)
|
2057
|
+
{
|
2058
|
+
swap(label[0],label[1]);
|
2059
|
+
swap(count[0],count[1]);
|
2060
|
+
for(i=0;i<l;i++)
|
2061
|
+
{
|
2062
|
+
if(data_label[i] == 0)
|
2063
|
+
data_label[i] = 1;
|
2064
|
+
else
|
2065
|
+
data_label[i] = 0;
|
2066
|
+
}
|
2067
|
+
}
|
2068
|
+
|
2069
|
+
int *start = Malloc(int,nr_class);
|
2070
|
+
start[0] = 0;
|
2071
|
+
for(i=1;i<nr_class;i++)
|
2072
|
+
start[i] = start[i-1]+count[i-1];
|
2073
|
+
for(i=0;i<l;i++)
|
2074
|
+
{
|
2075
|
+
perm[start[data_label[i]]] = i;
|
2076
|
+
++start[data_label[i]];
|
2077
|
+
}
|
2078
|
+
start[0] = 0;
|
2079
|
+
for(i=1;i<nr_class;i++)
|
2080
|
+
start[i] = start[i-1]+count[i-1];
|
2081
|
+
|
2082
|
+
*nr_class_ret = nr_class;
|
2083
|
+
*label_ret = label;
|
2084
|
+
*start_ret = start;
|
2085
|
+
*count_ret = count;
|
2086
|
+
free(data_label);
|
2087
|
+
}
|
2088
|
+
|
2089
|
+
//
|
2090
|
+
// Interface functions
|
2091
|
+
//
|
2092
|
+
// Trains a model for the given problem and parameters and returns it.
// The model stores pointers into prob->x (free_sv is 0), it does not copy
// the support vectors.
//  - ONE_CLASS / EPSILON_SVR / NU_SVR: a single decision function is trained.
//  - otherwise (classification): one decision function per pair of classes.
svm_model *svm_train(const svm_problem *prob, const svm_parameter *param)
{
	svm_model *model = Malloc(svm_model,1);
	model->param = *param;
	model->free_sv = 0;	// XXX

	const bool is_regression =
		param->svm_type == EPSILON_SVR || param->svm_type == NU_SVR;

	if(param->svm_type == ONE_CLASS || is_regression)
	{
		// regression or one-class-svm
		model->nr_class = 2;
		model->label = NULL;
		model->nSV = NULL;
		model->probA = NULL;
		model->probB = NULL;
		model->sv_coef = Malloc(double *,1);

		if(param->probability && is_regression)
		{
			model->probA = Malloc(double,1);
			model->probA[0] = svm_svr_probability(prob,param);
		}

		decision_function f = svm_train_one(prob,param,0,0);
		model->rho = Malloc(double,1);
		model->rho[0] = f.rho;

		// support vectors are the points with a nonzero coefficient
		int nSV = 0;
		for(int i=0;i<prob->l;i++)
			if(fabs(f.alpha[i]) > 0) ++nSV;

		model->l = nSV;
		model->SV = Malloc(svm_node *,nSV);
		model->sv_coef[0] = Malloc(double,nSV);
		model->sv_indices = Malloc(int,nSV);

		int out = 0;
		for(int i=0;i<prob->l;i++)
		{
			if(fabs(f.alpha[i]) > 0)
			{
				model->SV[out] = prob->x[i];
				model->sv_coef[0][out] = f.alpha[i];
				model->sv_indices[out] = i+1;	// stored 1-based
				++out;
			}
		}

		free(f.alpha);
	}
	else
	{
		// classification
		const int l = prob->l;
		int nr_class;
		int *label = NULL;
		int *start = NULL;
		int *count = NULL;
		int *perm = Malloc(int,l);

		// group training data of the same class
		svm_group_classes(prob,&nr_class,&label,&start,&count,perm);
		if(nr_class == 1)
			info("WARNING: training data in only one class. See README for details.\n");

		svm_node **x = Malloc(svm_node *,l);
		for(int i=0;i<l;i++)
			x[i] = prob->x[perm[i]];

		// calculate weighted C
		double *weighted_C = Malloc(double, nr_class);
		for(int c=0;c<nr_class;c++)
			weighted_C[c] = param->C;
		for(int w=0;w<param->nr_weight;w++)
		{
			int c = 0;
			while(c<nr_class && param->weight_label[w] != label[c])
				++c;
			if(c == nr_class)
				fprintf(stderr,"WARNING: class label %d specified in weight is not found\n", param->weight_label[w]);
			else
				weighted_C[c] *= param->weight[w];
		}

		// train k*(k-1)/2 models
		const int nr_pairs = nr_class*(nr_class-1)/2;

		bool *nonzero = Malloc(bool,l);
		for(int i=0;i<l;i++)
			nonzero[i] = false;
		decision_function *f = Malloc(decision_function,nr_pairs);

		double *probA = NULL;
		double *probB = NULL;
		if (param->probability)
		{
			probA = Malloc(double,nr_pairs);
			probB = Malloc(double,nr_pairs);
		}

		int p = 0;
		for(int i=0;i<nr_class;i++)
		{
			for(int j=i+1;j<nr_class;j++)
			{
				// class i becomes +1, class j becomes -1
				const int si = start[i], sj = start[j];
				const int ci = count[i], cj = count[j];

				svm_problem sub_prob;
				sub_prob.l = ci+cj;
				sub_prob.x = Malloc(svm_node *,sub_prob.l);
				sub_prob.y = Malloc(double,sub_prob.l);
				for(int k=0;k<ci;k++)
				{
					sub_prob.x[k] = x[si+k];
					sub_prob.y[k] = +1;
				}
				for(int k=0;k<cj;k++)
				{
					sub_prob.x[ci+k] = x[sj+k];
					sub_prob.y[ci+k] = -1;
				}

				if(param->probability)
					svm_binary_svc_probability(&sub_prob,param,weighted_C[i],weighted_C[j],probA[p],probB[p]);

				f[p] = svm_train_one(&sub_prob,param,weighted_C[i],weighted_C[j]);

				// a point is a support vector if any pairwise model uses it
				for(int k=0;k<ci;k++)
					if(!nonzero[si+k] && fabs(f[p].alpha[k]) > 0)
						nonzero[si+k] = true;
				for(int k=0;k<cj;k++)
					if(!nonzero[sj+k] && fabs(f[p].alpha[ci+k]) > 0)
						nonzero[sj+k] = true;

				free(sub_prob.x);
				free(sub_prob.y);
				++p;
			}
		}

		// build output
		model->nr_class = nr_class;

		model->label = Malloc(int,nr_class);
		for(int c=0;c<nr_class;c++)
			model->label[c] = label[c];

		model->rho = Malloc(double,nr_pairs);
		for(int m=0;m<nr_pairs;m++)
			model->rho[m] = f[m].rho;

		if(param->probability)
		{
			model->probA = Malloc(double,nr_pairs);
			model->probB = Malloc(double,nr_pairs);
			for(int m=0;m<nr_pairs;m++)
			{
				model->probA[m] = probA[m];
				model->probB[m] = probB[m];
			}
		}
		else
		{
			model->probA = NULL;
			model->probB = NULL;
		}

		int total_sv = 0;
		int *nz_count = Malloc(int,nr_class);
		model->nSV = Malloc(int,nr_class);
		for(int c=0;c<nr_class;c++)
		{
			int nSV = 0;
			for(int k=0;k<count[c];k++)
			{
				if(nonzero[start[c]+k])
				{
					++nSV;
					++total_sv;
				}
			}
			model->nSV[c] = nSV;
			nz_count[c] = nSV;
		}

		info("Total nSV = %d\n",total_sv);

		model->l = total_sv;
		model->SV = Malloc(svm_node *,total_sv);
		model->sv_indices = Malloc(int,total_sv);
		p = 0;
		for(int i=0;i<l;i++)
		{
			if(nonzero[i])
			{
				model->SV[p] = x[i];
				model->sv_indices[p] = perm[i] + 1;	// stored 1-based
				++p;
			}
		}

		int *nz_start = Malloc(int,nr_class);
		nz_start[0] = 0;
		for(int c=1;c<nr_class;c++)
			nz_start[c] = nz_start[c-1]+nz_count[c-1];

		model->sv_coef = Malloc(double *,nr_class-1);
		for(int r=0;r<nr_class-1;r++)
			model->sv_coef[r] = Malloc(double,total_sv);

		p = 0;
		for(int i=0;i<nr_class;i++)
		{
			for(int j=i+1;j<nr_class;j++)
			{
				// classifier (i,j): coefficients with
				// i are in sv_coef[j-1][nz_start[i]...],
				// j are in sv_coef[i][nz_start[j]...]
				const int si = start[i];
				const int sj = start[j];
				const int ci = count[i];
				const int cj = count[j];

				int q = nz_start[i];
				for(int k=0;k<ci;k++)
					if(nonzero[si+k])
						model->sv_coef[j-1][q++] = f[p].alpha[k];
				q = nz_start[j];
				for(int k=0;k<cj;k++)
					if(nonzero[sj+k])
						model->sv_coef[i][q++] = f[p].alpha[ci+k];
				++p;
			}
		}

		free(label);
		free(probA);
		free(probB);
		free(count);
		free(perm);
		free(start);
		free(x);
		free(weighted_C);
		free(nonzero);
		for(int m=0;m<nr_pairs;m++)
			free(f[m].alpha);
		free(f);
		free(nz_count);
		free(nz_start);
	}
	return model;
}
|
2337
|
+
|
2338
|
+
// Stratified cross validation
|
2339
|
+
void svm_cross_validation(const svm_problem *prob, const svm_parameter *param, int nr_fold, double *target)
|
2340
|
+
{
|
2341
|
+
int i;
|
2342
|
+
int *fold_start;
|
2343
|
+
int l = prob->l;
|
2344
|
+
int *perm = Malloc(int,l);
|
2345
|
+
int nr_class;
|
2346
|
+
if (nr_fold > l)
|
2347
|
+
{
|
2348
|
+
nr_fold = l;
|
2349
|
+
fprintf(stderr,"WARNING: # folds > # data. Will use # folds = # data instead (i.e., leave-one-out cross validation)\n");
|
2350
|
+
}
|
2351
|
+
fold_start = Malloc(int,nr_fold+1);
|
2352
|
+
// stratified cv may not give leave-one-out rate
|
2353
|
+
// Each class to l folds -> some folds may have zero elements
|
2354
|
+
if((param->svm_type == C_SVC ||
|
2355
|
+
param->svm_type == NU_SVC) && nr_fold < l)
|
2356
|
+
{
|
2357
|
+
int *start = NULL;
|
2358
|
+
int *label = NULL;
|
2359
|
+
int *count = NULL;
|
2360
|
+
svm_group_classes(prob,&nr_class,&label,&start,&count,perm);
|
2361
|
+
|
2362
|
+
// random shuffle and then data grouped by fold using the array perm
|
2363
|
+
int *fold_count = Malloc(int,nr_fold);
|
2364
|
+
int c;
|
2365
|
+
int *index = Malloc(int,l);
|
2366
|
+
for(i=0;i<l;i++)
|
2367
|
+
index[i]=perm[i];
|
2368
|
+
for (c=0; c<nr_class; c++)
|
2369
|
+
for(i=0;i<count[c];i++)
|
2370
|
+
{
|
2371
|
+
int j = i+rand()%(count[c]-i);
|
2372
|
+
swap(index[start[c]+j],index[start[c]+i]);
|
2373
|
+
}
|
2374
|
+
for(i=0;i<nr_fold;i++)
|
2375
|
+
{
|
2376
|
+
fold_count[i] = 0;
|
2377
|
+
for (c=0; c<nr_class;c++)
|
2378
|
+
fold_count[i]+=(i+1)*count[c]/nr_fold-i*count[c]/nr_fold;
|
2379
|
+
}
|
2380
|
+
fold_start[0]=0;
|
2381
|
+
for (i=1;i<=nr_fold;i++)
|
2382
|
+
fold_start[i] = fold_start[i-1]+fold_count[i-1];
|
2383
|
+
for (c=0; c<nr_class;c++)
|
2384
|
+
for(i=0;i<nr_fold;i++)
|
2385
|
+
{
|
2386
|
+
int begin = start[c]+i*count[c]/nr_fold;
|
2387
|
+
int end = start[c]+(i+1)*count[c]/nr_fold;
|
2388
|
+
for(int j=begin;j<end;j++)
|
2389
|
+
{
|
2390
|
+
perm[fold_start[i]] = index[j];
|
2391
|
+
fold_start[i]++;
|
2392
|
+
}
|
2393
|
+
}
|
2394
|
+
fold_start[0]=0;
|
2395
|
+
for (i=1;i<=nr_fold;i++)
|
2396
|
+
fold_start[i] = fold_start[i-1]+fold_count[i-1];
|
2397
|
+
free(start);
|
2398
|
+
free(label);
|
2399
|
+
free(count);
|
2400
|
+
free(index);
|
2401
|
+
free(fold_count);
|
2402
|
+
}
|
2403
|
+
else
|
2404
|
+
{
|
2405
|
+
for(i=0;i<l;i++) perm[i]=i;
|
2406
|
+
for(i=0;i<l;i++)
|
2407
|
+
{
|
2408
|
+
int j = i+rand()%(l-i);
|
2409
|
+
swap(perm[i],perm[j]);
|
2410
|
+
}
|
2411
|
+
for(i=0;i<=nr_fold;i++)
|
2412
|
+
fold_start[i]=i*l/nr_fold;
|
2413
|
+
}
|
2414
|
+
|
2415
|
+
for(i=0;i<nr_fold;i++)
|
2416
|
+
{
|
2417
|
+
int begin = fold_start[i];
|
2418
|
+
int end = fold_start[i+1];
|
2419
|
+
int j,k;
|
2420
|
+
struct svm_problem subprob;
|
2421
|
+
|
2422
|
+
subprob.l = l-(end-begin);
|
2423
|
+
subprob.x = Malloc(struct svm_node*,subprob.l);
|
2424
|
+
subprob.y = Malloc(double,subprob.l);
|
2425
|
+
|
2426
|
+
k=0;
|
2427
|
+
for(j=0;j<begin;j++)
|
2428
|
+
{
|
2429
|
+
subprob.x[k] = prob->x[perm[j]];
|
2430
|
+
subprob.y[k] = prob->y[perm[j]];
|
2431
|
+
++k;
|
2432
|
+
}
|
2433
|
+
for(j=end;j<l;j++)
|
2434
|
+
{
|
2435
|
+
subprob.x[k] = prob->x[perm[j]];
|
2436
|
+
subprob.y[k] = prob->y[perm[j]];
|
2437
|
+
++k;
|
2438
|
+
}
|
2439
|
+
struct svm_model *submodel = svm_train(&subprob,param);
|
2440
|
+
if(param->probability &&
|
2441
|
+
(param->svm_type == C_SVC || param->svm_type == NU_SVC))
|
2442
|
+
{
|
2443
|
+
double *prob_estimates=Malloc(double,svm_get_nr_class(submodel));
|
2444
|
+
for(j=begin;j<end;j++)
|
2445
|
+
target[perm[j]] = svm_predict_probability(submodel,prob->x[perm[j]],prob_estimates);
|
2446
|
+
free(prob_estimates);
|
2447
|
+
}
|
2448
|
+
else
|
2449
|
+
for(j=begin;j<end;j++)
|
2450
|
+
target[perm[j]] = svm_predict(submodel,prob->x[perm[j]]);
|
2451
|
+
svm_free_and_destroy_model(&submodel);
|
2452
|
+
free(subprob.x);
|
2453
|
+
free(subprob.y);
|
2454
|
+
}
|
2455
|
+
free(fold_start);
|
2456
|
+
free(perm);
|
2457
|
+
}
|
2458
|
+
|
2459
|
+
|
2460
|
+
int svm_get_svm_type(const svm_model *model)
|
2461
|
+
{
|
2462
|
+
return model->param.svm_type;
|
2463
|
+
}
|
2464
|
+
|
2465
|
+
int svm_get_nr_class(const svm_model *model)
|
2466
|
+
{
|
2467
|
+
return model->nr_class;
|
2468
|
+
}
|
2469
|
+
|
2470
|
+
void svm_get_labels(const svm_model *model, int* label)
|
2471
|
+
{
|
2472
|
+
if (model->label != NULL)
|
2473
|
+
for(int i=0;i<model->nr_class;i++)
|
2474
|
+
label[i] = model->label[i];
|
2475
|
+
}
|
2476
|
+
|
2477
|
+
void svm_get_sv_indices(const svm_model *model, int* indices)
|
2478
|
+
{
|
2479
|
+
if (model->sv_indices != NULL)
|
2480
|
+
for(int i=0;i<model->l;i++)
|
2481
|
+
indices[i] = model->sv_indices[i];
|
2482
|
+
}
|
2483
|
+
|
2484
|
+
int svm_get_nr_sv(const svm_model *model)
|
2485
|
+
{
|
2486
|
+
return model->l;
|
2487
|
+
}
|
2488
|
+
|
2489
|
+
double svm_get_svr_probability(const svm_model *model)
|
2490
|
+
{
|
2491
|
+
if ((model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) &&
|
2492
|
+
model->probA!=NULL)
|
2493
|
+
return model->probA[0];
|
2494
|
+
else
|
2495
|
+
{
|
2496
|
+
fprintf(stderr,"Model doesn't contain information for SVR probability inference\n");
|
2497
|
+
return 0;
|
2498
|
+
}
|
2499
|
+
}
|
2500
|
+
|
2501
|
+
double svm_predict_values(const svm_model *model, const svm_node *x, double* dec_values)
|
2502
|
+
{
|
2503
|
+
int i;
|
2504
|
+
if(model->param.svm_type == ONE_CLASS ||
|
2505
|
+
model->param.svm_type == EPSILON_SVR ||
|
2506
|
+
model->param.svm_type == NU_SVR)
|
2507
|
+
{
|
2508
|
+
double *sv_coef = model->sv_coef[0];
|
2509
|
+
double sum = 0;
|
2510
|
+
for(i=0;i<model->l;i++)
|
2511
|
+
sum += sv_coef[i] * Kernel::k_function(x,model->SV[i],model->param);
|
2512
|
+
sum -= model->rho[0];
|
2513
|
+
*dec_values = sum;
|
2514
|
+
|
2515
|
+
if(model->param.svm_type == ONE_CLASS)
|
2516
|
+
return (sum>0)?1:-1;
|
2517
|
+
else
|
2518
|
+
return sum;
|
2519
|
+
}
|
2520
|
+
else
|
2521
|
+
{
|
2522
|
+
int nr_class = model->nr_class;
|
2523
|
+
int l = model->l;
|
2524
|
+
|
2525
|
+
double *kvalue = Malloc(double,l);
|
2526
|
+
for(i=0;i<l;i++)
|
2527
|
+
kvalue[i] = Kernel::k_function(x,model->SV[i],model->param);
|
2528
|
+
|
2529
|
+
int *start = Malloc(int,nr_class);
|
2530
|
+
start[0] = 0;
|
2531
|
+
for(i=1;i<nr_class;i++)
|
2532
|
+
start[i] = start[i-1]+model->nSV[i-1];
|
2533
|
+
|
2534
|
+
int *vote = Malloc(int,nr_class);
|
2535
|
+
for(i=0;i<nr_class;i++)
|
2536
|
+
vote[i] = 0;
|
2537
|
+
|
2538
|
+
int p=0;
|
2539
|
+
for(i=0;i<nr_class;i++)
|
2540
|
+
for(int j=i+1;j<nr_class;j++)
|
2541
|
+
{
|
2542
|
+
double sum = 0;
|
2543
|
+
int si = start[i];
|
2544
|
+
int sj = start[j];
|
2545
|
+
int ci = model->nSV[i];
|
2546
|
+
int cj = model->nSV[j];
|
2547
|
+
|
2548
|
+
int k;
|
2549
|
+
double *coef1 = model->sv_coef[j-1];
|
2550
|
+
double *coef2 = model->sv_coef[i];
|
2551
|
+
for(k=0;k<ci;k++)
|
2552
|
+
sum += coef1[si+k] * kvalue[si+k];
|
2553
|
+
for(k=0;k<cj;k++)
|
2554
|
+
sum += coef2[sj+k] * kvalue[sj+k];
|
2555
|
+
sum -= model->rho[p];
|
2556
|
+
dec_values[p] = sum;
|
2557
|
+
|
2558
|
+
if(dec_values[p] > 0)
|
2559
|
+
++vote[i];
|
2560
|
+
else
|
2561
|
+
++vote[j];
|
2562
|
+
p++;
|
2563
|
+
}
|
2564
|
+
|
2565
|
+
int vote_max_idx = 0;
|
2566
|
+
for(i=1;i<nr_class;i++)
|
2567
|
+
if(vote[i] > vote[vote_max_idx])
|
2568
|
+
vote_max_idx = i;
|
2569
|
+
|
2570
|
+
free(kvalue);
|
2571
|
+
free(start);
|
2572
|
+
free(vote);
|
2573
|
+
return model->label[vote_max_idx];
|
2574
|
+
}
|
2575
|
+
}
|
2576
|
+
|
2577
|
+
double svm_predict(const svm_model *model, const svm_node *x)
|
2578
|
+
{
|
2579
|
+
int nr_class = model->nr_class;
|
2580
|
+
double *dec_values;
|
2581
|
+
if(model->param.svm_type == ONE_CLASS ||
|
2582
|
+
model->param.svm_type == EPSILON_SVR ||
|
2583
|
+
model->param.svm_type == NU_SVR)
|
2584
|
+
dec_values = Malloc(double, 1);
|
2585
|
+
else
|
2586
|
+
dec_values = Malloc(double, nr_class*(nr_class-1)/2);
|
2587
|
+
double pred_result = svm_predict_values(model, x, dec_values);
|
2588
|
+
free(dec_values);
|
2589
|
+
return pred_result;
|
2590
|
+
}
|
2591
|
+
|
2592
|
+
double svm_predict_probability(
|
2593
|
+
const svm_model *model, const svm_node *x, double *prob_estimates)
|
2594
|
+
{
|
2595
|
+
if ((model->param.svm_type == C_SVC || model->param.svm_type == NU_SVC) &&
|
2596
|
+
model->probA!=NULL && model->probB!=NULL)
|
2597
|
+
{
|
2598
|
+
int i;
|
2599
|
+
int nr_class = model->nr_class;
|
2600
|
+
double *dec_values = Malloc(double, nr_class*(nr_class-1)/2);
|
2601
|
+
svm_predict_values(model, x, dec_values);
|
2602
|
+
|
2603
|
+
double min_prob=1e-7;
|
2604
|
+
double **pairwise_prob=Malloc(double *,nr_class);
|
2605
|
+
for(i=0;i<nr_class;i++)
|
2606
|
+
pairwise_prob[i]=Malloc(double,nr_class);
|
2607
|
+
int k=0;
|
2608
|
+
for(i=0;i<nr_class;i++)
|
2609
|
+
for(int j=i+1;j<nr_class;j++)
|
2610
|
+
{
|
2611
|
+
pairwise_prob[i][j]=min(max(sigmoid_predict(dec_values[k],model->probA[k],model->probB[k]),min_prob),1-min_prob);
|
2612
|
+
pairwise_prob[j][i]=1-pairwise_prob[i][j];
|
2613
|
+
k++;
|
2614
|
+
}
|
2615
|
+
if (nr_class == 2)
|
2616
|
+
{
|
2617
|
+
prob_estimates[0] = pairwise_prob[0][1];
|
2618
|
+
prob_estimates[1] = pairwise_prob[1][0];
|
2619
|
+
}
|
2620
|
+
else
|
2621
|
+
multiclass_probability(nr_class,pairwise_prob,prob_estimates);
|
2622
|
+
|
2623
|
+
int prob_max_idx = 0;
|
2624
|
+
for(i=1;i<nr_class;i++)
|
2625
|
+
if(prob_estimates[i] > prob_estimates[prob_max_idx])
|
2626
|
+
prob_max_idx = i;
|
2627
|
+
for(i=0;i<nr_class;i++)
|
2628
|
+
free(pairwise_prob[i]);
|
2629
|
+
free(dec_values);
|
2630
|
+
free(pairwise_prob);
|
2631
|
+
return model->label[prob_max_idx];
|
2632
|
+
}
|
2633
|
+
else
|
2634
|
+
return svm_predict(model, x);
|
2635
|
+
}
|
2636
|
+
|
2637
|
+
// Names written to / parsed from model files for each svm_type value;
// the array index is the svm_type, and NULL terminates the table.
static const char *svm_type_table[] =
{
	"c_svc",
	"nu_svc",
	"one_class",
	"epsilon_svr",
	"nu_svr",
	NULL
};
|
2641
|
+
|
2642
|
+
// Names written to / parsed from model files for each kernel_type value;
// the array index is the kernel_type, and NULL terminates the table.
static const char *kernel_type_table[] =
{
	"linear",
	"polynomial",
	"rbf",
	"sigmoid",
	"precomputed",
	NULL
};
|
2646
|
+
|
2647
|
+
int svm_save_model(const char *model_file_name, const svm_model *model)
|
2648
|
+
{
|
2649
|
+
FILE *fp = fopen(model_file_name,"w");
|
2650
|
+
if(fp==NULL) return -1;
|
2651
|
+
|
2652
|
+
char *old_locale = setlocale(LC_ALL, NULL);
|
2653
|
+
if (old_locale) {
|
2654
|
+
old_locale = strdup(old_locale);
|
2655
|
+
}
|
2656
|
+
setlocale(LC_ALL, "C");
|
2657
|
+
|
2658
|
+
const svm_parameter& param = model->param;
|
2659
|
+
|
2660
|
+
fprintf(fp,"svm_type %s\n", svm_type_table[param.svm_type]);
|
2661
|
+
fprintf(fp,"kernel_type %s\n", kernel_type_table[param.kernel_type]);
|
2662
|
+
|
2663
|
+
if(param.kernel_type == POLY)
|
2664
|
+
fprintf(fp,"degree %d\n", param.degree);
|
2665
|
+
|
2666
|
+
if(param.kernel_type == POLY || param.kernel_type == RBF || param.kernel_type == SIGMOID)
|
2667
|
+
fprintf(fp,"gamma %.17g\n", param.gamma);
|
2668
|
+
|
2669
|
+
if(param.kernel_type == POLY || param.kernel_type == SIGMOID)
|
2670
|
+
fprintf(fp,"coef0 %.17g\n", param.coef0);
|
2671
|
+
|
2672
|
+
int nr_class = model->nr_class;
|
2673
|
+
int l = model->l;
|
2674
|
+
fprintf(fp, "nr_class %d\n", nr_class);
|
2675
|
+
fprintf(fp, "total_sv %d\n",l);
|
2676
|
+
|
2677
|
+
{
|
2678
|
+
fprintf(fp, "rho");
|
2679
|
+
for(int i=0;i<nr_class*(nr_class-1)/2;i++)
|
2680
|
+
fprintf(fp," %.17g",model->rho[i]);
|
2681
|
+
fprintf(fp, "\n");
|
2682
|
+
}
|
2683
|
+
|
2684
|
+
if(model->label)
|
2685
|
+
{
|
2686
|
+
fprintf(fp, "label");
|
2687
|
+
for(int i=0;i<nr_class;i++)
|
2688
|
+
fprintf(fp," %d",model->label[i]);
|
2689
|
+
fprintf(fp, "\n");
|
2690
|
+
}
|
2691
|
+
|
2692
|
+
if(model->probA) // regression has probA only
|
2693
|
+
{
|
2694
|
+
fprintf(fp, "probA");
|
2695
|
+
for(int i=0;i<nr_class*(nr_class-1)/2;i++)
|
2696
|
+
fprintf(fp," %.17g",model->probA[i]);
|
2697
|
+
fprintf(fp, "\n");
|
2698
|
+
}
|
2699
|
+
if(model->probB)
|
2700
|
+
{
|
2701
|
+
fprintf(fp, "probB");
|
2702
|
+
for(int i=0;i<nr_class*(nr_class-1)/2;i++)
|
2703
|
+
fprintf(fp," %.17g",model->probB[i]);
|
2704
|
+
fprintf(fp, "\n");
|
2705
|
+
}
|
2706
|
+
|
2707
|
+
if(model->nSV)
|
2708
|
+
{
|
2709
|
+
fprintf(fp, "nr_sv");
|
2710
|
+
for(int i=0;i<nr_class;i++)
|
2711
|
+
fprintf(fp," %d",model->nSV[i]);
|
2712
|
+
fprintf(fp, "\n");
|
2713
|
+
}
|
2714
|
+
|
2715
|
+
fprintf(fp, "SV\n");
|
2716
|
+
const double * const *sv_coef = model->sv_coef;
|
2717
|
+
const svm_node * const *SV = model->SV;
|
2718
|
+
|
2719
|
+
for(int i=0;i<l;i++)
|
2720
|
+
{
|
2721
|
+
for(int j=0;j<nr_class-1;j++)
|
2722
|
+
fprintf(fp, "%.17g ",sv_coef[j][i]);
|
2723
|
+
|
2724
|
+
const svm_node *p = SV[i];
|
2725
|
+
|
2726
|
+
if(param.kernel_type == PRECOMPUTED)
|
2727
|
+
fprintf(fp,"0:%d ",(int)(p->value));
|
2728
|
+
else
|
2729
|
+
while(p->index != -1)
|
2730
|
+
{
|
2731
|
+
fprintf(fp,"%d:%.8g ",p->index,p->value);
|
2732
|
+
p++;
|
2733
|
+
}
|
2734
|
+
fprintf(fp, "\n");
|
2735
|
+
}
|
2736
|
+
|
2737
|
+
setlocale(LC_ALL, old_locale);
|
2738
|
+
free(old_locale);
|
2739
|
+
|
2740
|
+
if (ferror(fp) != 0 || fclose(fp) != 0) return -1;
|
2741
|
+
else return 0;
|
2742
|
+
}
|
2743
|
+
|
2744
|
+
// Shared line buffer used by readline(); the caller allocates `line` with
// max_line_len bytes before the first call and frees it when done.
static char *line = NULL;
static int max_line_len;

// Reads one complete line (of any length) from input into `line`, doubling
// the buffer until the newline has been read or the input ends.
// Returns `line`, or NULL when nothing could be read. NULL is also returned
// if the buffer cannot be enlarged; in that case `line` and max_line_len
// keep their previous, still valid, values.
static char* readline(FILE *input)
{
	if(fgets(line,max_line_len,input) == NULL)
		return NULL;

	while(strrchr(line,'\n') == NULL)
	{
		// grow through a temporary so a failed realloc neither leaks the
		// old buffer nor leaves `line` NULL for the strlen below
		int grown_len = max_line_len*2;
		char *grown = (char *) realloc(line,grown_len);
		if(grown == NULL)
			return NULL;
		line = grown;
		max_line_len = grown_len;

		// continue reading right after what is already in the buffer
		int len = (int) strlen(line);
		if(fgets(line+len,max_line_len-len,input) == NULL)
			break;
	}
	return line;
}
|
2764
|
+
|
2765
|
+
//
|
2766
|
+
// FSCANF helps to handle fscanf failures.
|
2767
|
+
// Its do-while block avoids the ambiguity when
|
2768
|
+
// if (...)
|
2769
|
+
// FSCANF();
|
2770
|
+
// is used
|
2771
|
+
//
|
2772
|
+
#define FSCANF(_stream, _format, _var) do{ if (fscanf(_stream, _format, _var) != 1) return false; }while(0)
|
2773
|
+
bool read_model_header(FILE *fp, svm_model* model)
|
2774
|
+
{
|
2775
|
+
svm_parameter& param = model->param;
|
2776
|
+
// parameters for training only won't be assigned, but arrays are assigned as NULL for safety
|
2777
|
+
param.nr_weight = 0;
|
2778
|
+
param.weight_label = NULL;
|
2779
|
+
param.weight = NULL;
|
2780
|
+
|
2781
|
+
char cmd[81];
|
2782
|
+
while(1)
|
2783
|
+
{
|
2784
|
+
FSCANF(fp,"%80s",cmd);
|
2785
|
+
|
2786
|
+
if(strcmp(cmd,"svm_type")==0)
|
2787
|
+
{
|
2788
|
+
FSCANF(fp,"%80s",cmd);
|
2789
|
+
int i;
|
2790
|
+
for(i=0;svm_type_table[i];i++)
|
2791
|
+
{
|
2792
|
+
if(strcmp(svm_type_table[i],cmd)==0)
|
2793
|
+
{
|
2794
|
+
param.svm_type=i;
|
2795
|
+
break;
|
2796
|
+
}
|
2797
|
+
}
|
2798
|
+
if(svm_type_table[i] == NULL)
|
2799
|
+
{
|
2800
|
+
fprintf(stderr,"unknown svm type.\n");
|
2801
|
+
return false;
|
2802
|
+
}
|
2803
|
+
}
|
2804
|
+
else if(strcmp(cmd,"kernel_type")==0)
|
2805
|
+
{
|
2806
|
+
FSCANF(fp,"%80s",cmd);
|
2807
|
+
int i;
|
2808
|
+
for(i=0;kernel_type_table[i];i++)
|
2809
|
+
{
|
2810
|
+
if(strcmp(kernel_type_table[i],cmd)==0)
|
2811
|
+
{
|
2812
|
+
param.kernel_type=i;
|
2813
|
+
break;
|
2814
|
+
}
|
2815
|
+
}
|
2816
|
+
if(kernel_type_table[i] == NULL)
|
2817
|
+
{
|
2818
|
+
fprintf(stderr,"unknown kernel function.\n");
|
2819
|
+
return false;
|
2820
|
+
}
|
2821
|
+
}
|
2822
|
+
else if(strcmp(cmd,"degree")==0)
|
2823
|
+
FSCANF(fp,"%d",¶m.degree);
|
2824
|
+
else if(strcmp(cmd,"gamma")==0)
|
2825
|
+
FSCANF(fp,"%lf",¶m.gamma);
|
2826
|
+
else if(strcmp(cmd,"coef0")==0)
|
2827
|
+
FSCANF(fp,"%lf",¶m.coef0);
|
2828
|
+
else if(strcmp(cmd,"nr_class")==0)
|
2829
|
+
FSCANF(fp,"%d",&model->nr_class);
|
2830
|
+
else if(strcmp(cmd,"total_sv")==0)
|
2831
|
+
FSCANF(fp,"%d",&model->l);
|
2832
|
+
else if(strcmp(cmd,"rho")==0)
|
2833
|
+
{
|
2834
|
+
int n = model->nr_class * (model->nr_class-1)/2;
|
2835
|
+
model->rho = Malloc(double,n);
|
2836
|
+
for(int i=0;i<n;i++)
|
2837
|
+
FSCANF(fp,"%lf",&model->rho[i]);
|
2838
|
+
}
|
2839
|
+
else if(strcmp(cmd,"label")==0)
|
2840
|
+
{
|
2841
|
+
int n = model->nr_class;
|
2842
|
+
model->label = Malloc(int,n);
|
2843
|
+
for(int i=0;i<n;i++)
|
2844
|
+
FSCANF(fp,"%d",&model->label[i]);
|
2845
|
+
}
|
2846
|
+
else if(strcmp(cmd,"probA")==0)
|
2847
|
+
{
|
2848
|
+
int n = model->nr_class * (model->nr_class-1)/2;
|
2849
|
+
model->probA = Malloc(double,n);
|
2850
|
+
for(int i=0;i<n;i++)
|
2851
|
+
FSCANF(fp,"%lf",&model->probA[i]);
|
2852
|
+
}
|
2853
|
+
else if(strcmp(cmd,"probB")==0)
|
2854
|
+
{
|
2855
|
+
int n = model->nr_class * (model->nr_class-1)/2;
|
2856
|
+
model->probB = Malloc(double,n);
|
2857
|
+
for(int i=0;i<n;i++)
|
2858
|
+
FSCANF(fp,"%lf",&model->probB[i]);
|
2859
|
+
}
|
2860
|
+
else if(strcmp(cmd,"nr_sv")==0)
|
2861
|
+
{
|
2862
|
+
int n = model->nr_class;
|
2863
|
+
model->nSV = Malloc(int,n);
|
2864
|
+
for(int i=0;i<n;i++)
|
2865
|
+
FSCANF(fp,"%d",&model->nSV[i]);
|
2866
|
+
}
|
2867
|
+
else if(strcmp(cmd,"SV")==0)
|
2868
|
+
{
|
2869
|
+
while(1)
|
2870
|
+
{
|
2871
|
+
int c = getc(fp);
|
2872
|
+
if(c==EOF || c=='\n') break;
|
2873
|
+
}
|
2874
|
+
break;
|
2875
|
+
}
|
2876
|
+
else
|
2877
|
+
{
|
2878
|
+
fprintf(stderr,"unknown text in model file: [%s]\n",cmd);
|
2879
|
+
return false;
|
2880
|
+
}
|
2881
|
+
}
|
2882
|
+
|
2883
|
+
return true;
|
2884
|
+
|
2885
|
+
}
|
2886
|
+
|
2887
|
+
// Load an SVM model from the text file `model_file_name`.
//
// The header is parsed by read_model_header(). The remaining lines are read
// twice: a first pass counts ':'-separated tokens to size one contiguous
// svm_node buffer, and a second pass fills sv_coef and SV from model->l lines.
// Parsing runs under the "C" locale; the previous locale is restored before
// returning.
//
// Returns a heap-allocated model with free_sv set to 1, or NULL when
//   - the file cannot be opened,
//   - read_model_header() rejects the header,
//   - a support-vector line is missing or lacks a coefficient token,
//   - support vectors are announced but nr_class leaves no coefficient row,
//   - the stream reports an I/O error or cannot be closed.
// On every failure path the stream is closed and all storage allocated for
// the model so far is released.
svm_model *svm_load_model(const char *model_file_name)
{
	FILE *fp = fopen(model_file_name,"rb");
	if(fp==NULL) return NULL;

	// Keep a private copy of the current locale name so it can be restored
	// after the temporary switch to "C".
	char *old_locale = setlocale(LC_ALL, NULL);
	if (old_locale) {
		old_locale = strdup(old_locale);
	}
	setlocale(LC_ALL, "C");

	// read parameters

	svm_model *model = Malloc(svm_model,1);
	model->rho = NULL;
	model->probA = NULL;
	model->probB = NULL;
	model->sv_indices = NULL;
	model->label = NULL;
	model->nSV = NULL;

	// read header
	if (!read_model_header(fp, model))
	{
		fprintf(stderr, "ERROR: fscanf failed to read model\n");
		setlocale(LC_ALL, old_locale);
		free(old_locale);
		free(model->rho);
		free(model->label);
		free(model->nSV);
		// An earlier header entry may already have allocated these before
		// a later entry was rejected.
		free(model->probA);
		free(model->probB);
		free(model);
		// The stream was previously leaked on this path.
		fclose(fp);
		return NULL;
	}

	// read sv_coef and SV

	int elements = 0;
	long pos = ftell(fp);

	max_line_len = 1024;
	line = Malloc(char,max_line_len);
	char *p,*endptr,*idx,*val;

	// First pass: count the tokens that follow a ':' on every remaining line.
	// This is an upper bound on the number of index:value nodes needed.
	while(readline(fp)!=NULL)
	{
		p = strtok(line,":");
		while(1)
		{
			p = strtok(NULL,":");
			if(p == NULL)
				break;
			++elements;
		}
	}
	// One extra node per support vector for the index == -1 terminator.
	elements += model->l;

	// Rewind to the first support-vector line for the second pass.
	fseek(fp,pos,SEEK_SET);

	int m = model->nr_class - 1;
	int l = model->l;
	model->sv_coef = Malloc(double *,m);
	int i;
	for(i=0;i<m;i++)
		model->sv_coef[i] = Malloc(double,l);
	model->SV = Malloc(svm_node*,l);
	svm_node *x_space = NULL;
	if(l>0) x_space = Malloc(svm_node,elements);

	// Row 0 of sv_coef is written unconditionally below, so support vectors
	// without at least one coefficient row cannot be stored.
	bool parsed = !(l > 0 && m < 1);

	int j=0;
	for(i=0;parsed && i<l;i++)
	{
		// Fewer lines than the header announced: stop instead of re-parsing
		// whatever is left in the line buffer.
		if(readline(fp) == NULL)
		{
			parsed = false;
			break;
		}
		model->SV[i] = &x_space[j];

		p = strtok(line, " \t");
		if(p == NULL)
		{
			parsed = false;
			break;
		}
		model->sv_coef[0][i] = strtod(p,&endptr);
		for(int k=1;k<m;k++)
		{
			p = strtok(NULL, " \t");
			if(p == NULL)
			{
				parsed = false;
				break;
			}
			model->sv_coef[k][i] = strtod(p,&endptr);
		}
		if(!parsed)
			break;

		// Remaining tokens are index:value pairs.
		while(1)
		{
			idx = strtok(NULL, ":");
			val = strtok(NULL, " \t");

			if(val == NULL)
				break;
			x_space[j].index = (int) strtol(idx,&endptr,10);
			x_space[j].value = strtod(val,&endptr);

			++j;
		}
		x_space[j++].index = -1;
	}
	free(line);

	setlocale(LC_ALL, old_locale);
	free(old_locale);

	// Always close the stream, even when it already reports an error.
	bool io_failed = (ferror(fp) != 0);
	if(fclose(fp) != 0)
		io_failed = true;

	if(!parsed || io_failed)
	{
		// SV[0] may not have been assigned yet, so release the node buffer
		// directly and tell the generic teardown not to free it again.
		free(x_space);
		model->free_sv = 0;
		svm_free_and_destroy_model(&model);
		return NULL;
	}

	model->free_sv = 1;	// XXX
	return model;
}
|
2994
|
+
|
2995
|
+
// Release every heap buffer referenced by *model_ptr and reset those pointers
// to NULL. The svm_model struct itself is left allocated.
void svm_free_model_content(svm_model* model_ptr)
{
	// The node storage is released through SV[0], and only when the model is
	// flagged as owning it and actually has support vectors.
	const bool release_nodes =
		model_ptr->free_sv && model_ptr->l > 0 && model_ptr->SV != NULL;
	if(release_nodes)
		free((void *)(model_ptr->SV[0]));

	// Free each coefficient row before the row table itself.
	double **coef_rows = model_ptr->sv_coef;
	if(coef_rows)
	{
		const int row_count = model_ptr->nr_class-1;
		for(int row=0; row<row_count; ++row)
			free(coef_rows[row]);
	}

	free(model_ptr->SV);
	free(model_ptr->sv_coef);
	free(model_ptr->rho);
	free(model_ptr->label);
	free(model_ptr->probA);
	free(model_ptr->probB);
	free(model_ptr->sv_indices);
	free(model_ptr->nSV);

	// Clear the pointers so a second call does not free them again.
	model_ptr->SV = NULL;
	model_ptr->sv_coef = NULL;
	model_ptr->rho = NULL;
	model_ptr->label = NULL;
	model_ptr->probA = NULL;
	model_ptr->probB = NULL;
	model_ptr->sv_indices = NULL;
	model_ptr->nSV = NULL;
}
|
3029
|
+
|
3030
|
+
// Free the model referenced by *model_ptr_ptr (contents and struct) and set
// *model_ptr_ptr to NULL. Does nothing when either pointer is NULL.
void svm_free_and_destroy_model(svm_model** model_ptr_ptr)
{
	if(model_ptr_ptr == NULL)
		return;

	svm_model *doomed = *model_ptr_ptr;
	if(doomed == NULL)
		return;

	svm_free_model_content(doomed);
	free(doomed);
	*model_ptr_ptr = NULL;
}
|
3039
|
+
|
3040
|
+
// Free the weight and weight_label arrays referenced by *param.
// The pointers are not reset and the struct itself is not freed.
void svm_destroy_param(svm_parameter* param)
{
	free(param->weight);
	free(param->weight_label);
}
|
3045
|
+
|
3046
|
+
const char *svm_check_parameter(const svm_problem *prob, const svm_parameter *param)
|
3047
|
+
{
|
3048
|
+
// svm_type
|
3049
|
+
|
3050
|
+
int svm_type = param->svm_type;
|
3051
|
+
if(svm_type != C_SVC &&
|
3052
|
+
svm_type != NU_SVC &&
|
3053
|
+
svm_type != ONE_CLASS &&
|
3054
|
+
svm_type != EPSILON_SVR &&
|
3055
|
+
svm_type != NU_SVR)
|
3056
|
+
return "unknown svm type";
|
3057
|
+
|
3058
|
+
// kernel_type, degree
|
3059
|
+
|
3060
|
+
int kernel_type = param->kernel_type;
|
3061
|
+
if(kernel_type != LINEAR &&
|
3062
|
+
kernel_type != POLY &&
|
3063
|
+
kernel_type != RBF &&
|
3064
|
+
kernel_type != SIGMOID &&
|
3065
|
+
kernel_type != PRECOMPUTED)
|
3066
|
+
return "unknown kernel type";
|
3067
|
+
|
3068
|
+
if((kernel_type == POLY || kernel_type == RBF || kernel_type == SIGMOID) &&
|
3069
|
+
param->gamma < 0)
|
3070
|
+
return "gamma < 0";
|
3071
|
+
|
3072
|
+
if(kernel_type == POLY && param->degree < 0)
|
3073
|
+
return "degree of polynomial kernel < 0";
|
3074
|
+
|
3075
|
+
// cache_size,eps,C,nu,p,shrinking
|
3076
|
+
|
3077
|
+
if(param->cache_size <= 0)
|
3078
|
+
return "cache_size <= 0";
|
3079
|
+
|
3080
|
+
if(param->eps <= 0)
|
3081
|
+
return "eps <= 0";
|
3082
|
+
|
3083
|
+
if(svm_type == C_SVC ||
|
3084
|
+
svm_type == EPSILON_SVR ||
|
3085
|
+
svm_type == NU_SVR)
|
3086
|
+
if(param->C <= 0)
|
3087
|
+
return "C <= 0";
|
3088
|
+
|
3089
|
+
if(svm_type == NU_SVC ||
|
3090
|
+
svm_type == ONE_CLASS ||
|
3091
|
+
svm_type == NU_SVR)
|
3092
|
+
if(param->nu <= 0 || param->nu > 1)
|
3093
|
+
return "nu <= 0 or nu > 1";
|
3094
|
+
|
3095
|
+
if(svm_type == EPSILON_SVR)
|
3096
|
+
if(param->p < 0)
|
3097
|
+
return "p < 0";
|
3098
|
+
|
3099
|
+
if(param->shrinking != 0 &&
|
3100
|
+
param->shrinking != 1)
|
3101
|
+
return "shrinking != 0 and shrinking != 1";
|
3102
|
+
|
3103
|
+
if(param->probability != 0 &&
|
3104
|
+
param->probability != 1)
|
3105
|
+
return "probability != 0 and probability != 1";
|
3106
|
+
|
3107
|
+
if(param->probability == 1 &&
|
3108
|
+
svm_type == ONE_CLASS)
|
3109
|
+
return "one-class SVM probability output not supported yet";
|
3110
|
+
|
3111
|
+
|
3112
|
+
// check whether nu-svc is feasible
|
3113
|
+
|
3114
|
+
if(svm_type == NU_SVC)
|
3115
|
+
{
|
3116
|
+
int l = prob->l;
|
3117
|
+
int max_nr_class = 16;
|
3118
|
+
int nr_class = 0;
|
3119
|
+
int *label = Malloc(int,max_nr_class);
|
3120
|
+
int *count = Malloc(int,max_nr_class);
|
3121
|
+
|
3122
|
+
int i;
|
3123
|
+
for(i=0;i<l;i++)
|
3124
|
+
{
|
3125
|
+
int this_label = (int)prob->y[i];
|
3126
|
+
int j;
|
3127
|
+
for(j=0;j<nr_class;j++)
|
3128
|
+
if(this_label == label[j])
|
3129
|
+
{
|
3130
|
+
++count[j];
|
3131
|
+
break;
|
3132
|
+
}
|
3133
|
+
if(j == nr_class)
|
3134
|
+
{
|
3135
|
+
if(nr_class == max_nr_class)
|
3136
|
+
{
|
3137
|
+
max_nr_class *= 2;
|
3138
|
+
label = (int *)realloc(label,max_nr_class*sizeof(int));
|
3139
|
+
count = (int *)realloc(count,max_nr_class*sizeof(int));
|
3140
|
+
}
|
3141
|
+
label[nr_class] = this_label;
|
3142
|
+
count[nr_class] = 1;
|
3143
|
+
++nr_class;
|
3144
|
+
}
|
3145
|
+
}
|
3146
|
+
|
3147
|
+
for(i=0;i<nr_class;i++)
|
3148
|
+
{
|
3149
|
+
int n1 = count[i];
|
3150
|
+
for(int j=i+1;j<nr_class;j++)
|
3151
|
+
{
|
3152
|
+
int n2 = count[j];
|
3153
|
+
if(param->nu*(n1+n2)/2 > min(n1,n2))
|
3154
|
+
{
|
3155
|
+
free(label);
|
3156
|
+
free(count);
|
3157
|
+
return "specified nu is infeasible";
|
3158
|
+
}
|
3159
|
+
}
|
3160
|
+
}
|
3161
|
+
free(label);
|
3162
|
+
free(count);
|
3163
|
+
}
|
3164
|
+
|
3165
|
+
return NULL;
|
3166
|
+
}
|
3167
|
+
|
3168
|
+
int svm_check_probability_model(const svm_model *model)
|
3169
|
+
{
|
3170
|
+
return ((model->param.svm_type == C_SVC || model->param.svm_type == NU_SVC) &&
|
3171
|
+
model->probA!=NULL && model->probB!=NULL) ||
|
3172
|
+
((model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) &&
|
3173
|
+
model->probA!=NULL);
|
3174
|
+
}
|
3175
|
+
|
3176
|
+
void svm_set_print_string_function(void (*print_func)(const char *))
|
3177
|
+
{
|
3178
|
+
if(print_func == NULL)
|
3179
|
+
svm_print_string = &print_string_stdout;
|
3180
|
+
else
|
3181
|
+
svm_print_string = print_func;
|
3182
|
+
}
|