numo-libsvm 0.1.0
- checksums.yaml +7 -0
- data/.gitignore +20 -0
- data/.rspec +3 -0
- data/.travis.yml +14 -0
- data/CHANGELOG.md +2 -0
- data/CODE_OF_CONDUCT.md +74 -0
- data/Gemfile +4 -0
- data/LICENSE.txt +27 -0
- data/README.md +180 -0
- data/Rakefile +15 -0
- data/ext/numo/libsvm/converter.c +162 -0
- data/ext/numo/libsvm/converter.h +19 -0
- data/ext/numo/libsvm/extconf.rb +43 -0
- data/ext/numo/libsvm/kernel_type.c +22 -0
- data/ext/numo/libsvm/kernel_type.h +9 -0
- data/ext/numo/libsvm/libsvmext.c +486 -0
- data/ext/numo/libsvm/libsvmext.h +17 -0
- data/ext/numo/libsvm/svm_model.c +89 -0
- data/ext/numo/libsvm/svm_model.h +15 -0
- data/ext/numo/libsvm/svm_parameter.c +88 -0
- data/ext/numo/libsvm/svm_parameter.h +15 -0
- data/ext/numo/libsvm/svm_type.c +22 -0
- data/ext/numo/libsvm/svm_type.h +9 -0
- data/lib/numo/libsvm.rb +5 -0
- data/lib/numo/libsvm/version.rb +8 -0
- data/numo-libsvm.gemspec +41 -0
- metadata +145 -0
data/ext/numo/libsvm/converter.h
@@ -0,0 +1,19 @@
+#ifndef NUMO_LIBSVM_CONVERTER_H
+#define NUMO_LIBSVM_CONVERTER_H 1
+
+#include <string.h>
+#include <svm.h>
+#include <ruby.h>
+#include <numo/narray.h>
+#include <numo/template.h>
+
+VALUE int_vec_to_nary(int* const arr, int const size);
+int* nary_to_int_vec(VALUE vec_val);
+VALUE dbl_vec_to_nary(double* const arr, int const size);
+double* nary_to_dbl_vec(VALUE vec_val);
+VALUE dbl_mat_to_nary(double** const mat, int const n_rows, int const n_cols);
+double** nary_to_dbl_mat(VALUE mat_val);
+VALUE svm_nodes_to_nary(struct svm_node** const support_vecs, const int n_support_vecs);
+struct svm_node** nary_to_svm_nodes(VALUE model_val);
+
+#endif /* NUMO_LIBSVM_CONVERTER_H */
data/ext/numo/libsvm/extconf.rb
@@ -0,0 +1,43 @@
+require 'mkmf'
+require 'numo/narray'
+
+$LOAD_PATH.each do |lp|
+  if File.exist?(File.join(lp, 'numo/numo/narray.h'))
+    $INCFLAGS = "-I#{lp}/numo #{$INCFLAGS}"
+    break
+  end
+end
+
+unless have_header('numo/narray.h')
+  puts 'numo/narray.h not found.'
+  exit(1)
+end
+
+if RUBY_PLATFORM =~ /mswin|cygwin|mingw/
+  $LOAD_PATH.each do |lp|
+    if File.exist?(File.join(lp, 'numo/libnarray.a'))
+      $LDFLAGS = "-L#{lp}/numo #{$LDFLAGS}"
+      break
+    end
+  end
+  unless have_library('narray', 'nary_new')
+    puts 'libnarray.a not found.'
+    exit(1)
+  end
+end
+
+if RUBY_PLATFORM =~ /linux/
+  $INCFLAGS = "-I/usr/include/libsvm #{$INCFLAGS}"
+end
+
+unless have_header('svm.h')
+  puts 'svm.h not found.'
+  exit(1)
+end
+
+unless have_library('svm')
+  puts 'libsvm not found.'
+  exit(1)
+end
+
+create_makefile('numo/libsvm/libsvmext')
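The extconf.rb above requires the numo-narray headers (located through the gem's $LOAD_PATH entry) and a system LIBSVM installation that provides svm.h and libsvm; on Linux it additionally searches /usr/include/libsvm. Before building, a quick pre-flight check along the following lines can confirm both are present. This is a minimal sketch that mirrors the lookups in extconf.rb; the list of candidate LIBSVM include directories is an illustrative assumption, not something the gem itself defines.

# check_build_deps.rb: rough pre-flight check mirroring the searches in extconf.rb.
begin
  # Activating numo-narray puts its lib directory on $LOAD_PATH, as extconf.rb expects.
  require 'numo/narray'
  narray_dir = $LOAD_PATH.find { |lp| File.exist?(File.join(lp, 'numo/numo/narray.h')) }
  puts(narray_dir ? "numo/narray.h found under #{narray_dir}/numo" : 'numo/narray.h not found')
rescue LoadError
  puts 'numo-narray gem is not installed'
end

# Candidate directories are assumptions; adjust for your system's LIBSVM package.
svm_dir = ['/usr/include', '/usr/include/libsvm', '/usr/local/include'].find do |dir|
  File.exist?(File.join(dir, 'svm.h'))
end
puts(svm_dir ? "svm.h found under #{svm_dir}" : 'svm.h not found (install a LIBSVM development package)')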
data/ext/numo/libsvm/kernel_type.c
@@ -0,0 +1,22 @@
+#include "kernel_type.h"
+
+RUBY_EXTERN VALUE mLibsvm;
+
+void rb_init_kernel_type_module()
+{
+  /**
+   * Document-module: Numo::Libsvm::KernelType
+   * The module that provides the kernel type constants used in the LIBSVM parameters.
+   */
+  VALUE mKernelType = rb_define_module_under(mLibsvm, "KernelType");
+  /* Linear kernel; u' * v */
+  rb_define_const(mKernelType, "LINEAR", INT2NUM(LINEAR));
+  /* Polynomial kernel; (gamma * u' * v + coef0)^degree */
+  rb_define_const(mKernelType, "POLY", INT2NUM(POLY));
+  /* RBF kernel; exp(-gamma * ||u - v||^2) */
+  rb_define_const(mKernelType, "RBF", INT2NUM(RBF));
+  /* Sigmoid kernel; tanh(gamma * u' * v + coef0) */
+  rb_define_const(mKernelType, "SIGMOID", INT2NUM(SIGMOID));
+  /* Precomputed kernel */
+  rb_define_const(mKernelType, "PRECOMPUTED", INT2NUM(PRECOMPUTED));
+}
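The constants above mirror LIBSVM's kernel enumeration, so a parameter hash can reference kernels by name rather than by magic number. A minimal sketch of such a hash follows; the hash keys and the Numo::Libsvm::SvmType constant are assumptions based on LIBSVM's svm_parameter fields and the svm_type.c file listed above, whose contents are not shown in this diff excerpt.

# A minimal parameter hash sketch. The keys are assumed to mirror LIBSVM's
# svm_parameter fields (defined in svm_parameter.c, not shown here), and
# SvmType::C_SVC is assumed to be registered by svm_type.c.
require 'numo/libsvm'

param = {
  svm_type: Numo::Libsvm::SvmType::C_SVC,
  kernel_type: Numo::Libsvm::KernelType::RBF, # exp(-gamma * ||u - v||^2)
  gamma: 0.5,
  C: 1.0
}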
data/ext/numo/libsvm/libsvmext.c
@@ -0,0 +1,486 @@
+/**
+ * LIBSVM interface for Numo::NArray
+ */
+#include "libsvmext.h"
+
+VALUE mNumo;
+VALUE mLibsvm;
+
+void print_null(const char *s) {}
+
+/**
+ * Train the SVM model according to the given training data.
+ *
+ * @overload train(x, y, param) -> Hash
+ *
+ * @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to be used for training the model.
+ * @param y [Numo::DFloat] (shape: [n_samples]) The labels or target values for samples.
+ * @param param [Hash] The parameters of an SVM model.
+ * @return [Hash] The model obtained from the training procedure.
+ */
+static
+VALUE train(VALUE self, VALUE x_val, VALUE y_val, VALUE param_hash)
+{
+  struct svm_problem* problem;
+  struct svm_parameter* param;
+  narray_t* x_nary;
+  double* x_pt;
+  double* y_pt;
+  int i, j;
+  int n_samples;
+  int n_features;
+  struct svm_model* model;
+  VALUE model_hash;
+
+  /* Obtain C data structures. */
+  if (CLASS_OF(x_val) != numo_cDFloat) {
+    x_val = rb_funcall(numo_cDFloat, rb_intern("cast"), 1, x_val);
+  }
+  if (CLASS_OF(y_val) != numo_cDFloat) {
+    y_val = rb_funcall(numo_cDFloat, rb_intern("cast"), 1, y_val);
+  }
+  if (!RTEST(nary_check_contiguous(x_val))) {
+    x_val = nary_dup(x_val);
+  }
+  if (!RTEST(nary_check_contiguous(y_val))) {
+    y_val = nary_dup(y_val);
+  }
+  GetNArray(x_val, x_nary);
+  param = rb_hash_to_svm_parameter(param_hash);
+
+  /* Initialize some variables. */
+  n_samples = (int)NA_SHAPE(x_nary)[0];
+  n_features = (int)NA_SHAPE(x_nary)[1];
+  x_pt = (double*)na_get_pointer_for_read(x_val);
+  y_pt = (double*)na_get_pointer_for_read(y_val);
+
+  /* Prepare LIBSVM problem. */
+  problem = ALLOC(struct svm_problem);
+  problem->l = n_samples;
+  problem->x = ALLOC_N(struct svm_node*, n_samples);
+  problem->y = ALLOC_N(double, n_samples);
+  for (i = 0; i < n_samples; i++) {
+    problem->x[i] = ALLOC_N(struct svm_node, n_features + 1);
+    for (j = 0; j < n_features; j++) {
+      problem->x[i][j].index = j + 1;
+      problem->x[i][j].value = x_pt[i * n_features + j];
+    }
+    problem->x[i][n_features].index = -1;
+    problem->x[i][n_features].value = 0.0;
+    problem->y[i] = y_pt[i];
+  }
+
+  /* Perform training. */
+  svm_set_print_string_function(print_null);
+  model = svm_train(problem, param);
+  model_hash = svm_model_to_rb_hash(model);
+  svm_free_and_destroy_model(&model);
+
+  for (i = 0; i < n_samples; xfree(problem->x[i++]));
+  xfree(problem->x);
+  xfree(problem->y);
+  xfree(problem);
+  xfree_svm_parameter(param);
+
+  return model_hash;
+}
+
+/**
+ * Perform cross validation under the given parameters. The given samples are split into n_folds folds.
+ * The predicted labels or values in the validation process are returned.
+ *
+ * @overload cv(x, y, param, n_folds) -> Numo::DFloat
+ *
+ * @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to be used for training the model.
+ * @param y [Numo::DFloat] (shape: [n_samples]) The labels or target values for samples.
+ * @param param [Hash] The parameters of an SVM model.
+ * @param n_folds [Integer] The number of folds.
+ * @return [Numo::DFloat] (shape: [n_samples]) The predicted class label or value of each sample.
+ */
+static
+VALUE cross_validation(VALUE self, VALUE x_val, VALUE y_val, VALUE param_hash, VALUE nr_folds)
+{
+  const int n_folds = NUM2INT(nr_folds);
+  struct svm_problem* problem;
+  struct svm_parameter* param;
+  narray_t* x_nary;
+  double* x_pt;
+  double* y_pt;
+  int i, j;
+  int n_samples;
+  int n_features;
+  size_t t_shape[1];
+  VALUE t_val;
+  double* t_pt;
+
+  /* Obtain C data structures. */
+  if (CLASS_OF(x_val) != numo_cDFloat) {
+    x_val = rb_funcall(numo_cDFloat, rb_intern("cast"), 1, x_val);
+  }
+  if (CLASS_OF(y_val) != numo_cDFloat) {
+    y_val = rb_funcall(numo_cDFloat, rb_intern("cast"), 1, y_val);
+  }
+  if (!RTEST(nary_check_contiguous(x_val))) {
+    x_val = nary_dup(x_val);
+  }
+  if (!RTEST(nary_check_contiguous(y_val))) {
+    y_val = nary_dup(y_val);
+  }
+  GetNArray(x_val, x_nary);
+  param = rb_hash_to_svm_parameter(param_hash);
+
+  /* Initialize some variables. */
+  n_samples = (int)NA_SHAPE(x_nary)[0];
+  n_features = (int)NA_SHAPE(x_nary)[1];
+  x_pt = (double*)na_get_pointer_for_read(x_val);
+  y_pt = (double*)na_get_pointer_for_read(y_val);
+
+  /* Prepare LIBSVM problem. */
+  problem = ALLOC(struct svm_problem);
+  problem->l = n_samples;
+  problem->x = ALLOC_N(struct svm_node*, n_samples);
+  problem->y = ALLOC_N(double, n_samples);
+  for (i = 0; i < n_samples; i++) {
+    problem->x[i] = ALLOC_N(struct svm_node, n_features + 1);
+    for (j = 0; j < n_features; j++) {
+      problem->x[i][j].index = j + 1;
+      problem->x[i][j].value = x_pt[i * n_features + j];
+    }
+    problem->x[i][n_features].index = -1;
+    problem->x[i][n_features].value = 0.0;
+    problem->y[i] = y_pt[i];
+  }
+
+  /* Perform cross validation. */
+  t_shape[0] = n_samples;
+  t_val = rb_narray_new(numo_cDFloat, 1, t_shape);
+  t_pt = (double*)na_get_pointer_for_write(t_val);
+  svm_set_print_string_function(print_null);
+  svm_cross_validation(problem, param, n_folds, t_pt);
+
+  for (i = 0; i < n_samples; xfree(problem->x[i++]));
+  xfree(problem->x);
+  xfree(problem->y);
+  xfree(problem);
+  xfree_svm_parameter(param);
+
+  return t_val;
+}
+
+/**
+ * Predict class labels or values for given samples.
+ *
+ * @overload predict(x, param, model) -> Numo::DFloat
+ *
+ * @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to calculate the scores.
+ * @param param [Hash] The parameters of the trained SVM model.
+ * @param model [Hash] The model obtained from the training procedure.
+ * @return [Numo::DFloat] (shape: [n_samples]) The predicted class label or value of each sample.
+ */
+static
+VALUE predict(VALUE self, VALUE x_val, VALUE param_hash, VALUE model_hash)
+{
+  struct svm_parameter* param;
+  struct svm_model* model;
+  struct svm_node* x_nodes;
+  narray_t* x_nary;
+  double* x_pt;
+  size_t y_shape[1];
+  VALUE y_val;
+  double* y_pt;
+  int i, j;
+  int n_samples;
+  int n_features;
+
+  /* Obtain C data structures. */
+  if (CLASS_OF(x_val) != numo_cDFloat) {
+    x_val = rb_funcall(numo_cDFloat, rb_intern("cast"), 1, x_val);
+  }
+  if (!RTEST(nary_check_contiguous(x_val))) {
+    x_val = nary_dup(x_val);
+  }
+  GetNArray(x_val, x_nary);
+  param = rb_hash_to_svm_parameter(param_hash);
+  model = rb_hash_to_svm_model(model_hash);
+  model->param = *param;
+
+  /* Initialize some variables. */
+  n_samples = (int)NA_SHAPE(x_nary)[0];
+  n_features = (int)NA_SHAPE(x_nary)[1];
+  y_shape[0] = n_samples;
+  y_val = rb_narray_new(numo_cDFloat, 1, y_shape);
+  y_pt = (double*)na_get_pointer_for_write(y_val);
+  x_pt = (double*)na_get_pointer_for_read(x_val);
+
+  /* Predict values. */
+  x_nodes = ALLOC_N(struct svm_node, n_features + 1);
+  x_nodes[n_features].index = -1;
+  x_nodes[n_features].value = 0.0;
+  for (i = 0; i < n_samples; i++) {
+    for (j = 0; j < n_features; j++) {
+      x_nodes[j].index = j + 1;
+      x_nodes[j].value = (double)x_pt[i * n_features + j];
+    }
+    y_pt[i] = svm_predict(model, x_nodes);
+  }
+
+  xfree(x_nodes);
+  xfree_svm_model(model);
+  xfree_svm_parameter(param);
+
+  return y_val;
+}
+
+/**
+ * Calculate decision values for given samples.
+ *
+ * @overload decision_function(x, param, model) -> Numo::DFloat
+ *
+ * @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to calculate the scores.
+ * @param param [Hash] The parameters of the trained SVM model.
+ * @param model [Hash] The model obtained from the training procedure.
+ * @return [Numo::DFloat] (shape: [n_samples, n_classes * (n_classes - 1) / 2]) The decision value of each sample.
+ */
+static
+VALUE decision_function(VALUE self, VALUE x_val, VALUE param_hash, VALUE model_hash)
+{
+  struct svm_parameter* param;
+  struct svm_model* model;
+  struct svm_node* x_nodes;
+  narray_t* x_nary;
+  double* x_pt;
+  size_t y_shape[2];
+  VALUE y_val;
+  double* y_pt;
+  double* dec_values;
+  int y_cols;
+  int i, j;
+  int n_samples;
+  int n_features;
+
+  /* Obtain C data structures. */
+  if (CLASS_OF(x_val) != numo_cDFloat) {
+    x_val = rb_funcall(numo_cDFloat, rb_intern("cast"), 1, x_val);
+  }
+  if (!RTEST(nary_check_contiguous(x_val))) {
+    x_val = nary_dup(x_val);
+  }
+  GetNArray(x_val, x_nary);
+  param = rb_hash_to_svm_parameter(param_hash);
+  model = rb_hash_to_svm_model(model_hash);
+  model->param = *param;
+
+  /* Initialize some variables. */
+  n_samples = (int)NA_SHAPE(x_nary)[0];
+  n_features = (int)NA_SHAPE(x_nary)[1];
+
+  if (model->param.svm_type == ONE_CLASS || model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) {
+    y_shape[0] = n_samples;
+    y_shape[1] = 1;
+    y_val = rb_narray_new(numo_cDFloat, 1, y_shape);
+  } else {
+    y_shape[0] = n_samples;
+    y_shape[1] = model->nr_class * (model->nr_class - 1) / 2;
+    y_val = rb_narray_new(numo_cDFloat, 2, y_shape);
+  }
+
+  x_pt = (double*)na_get_pointer_for_read(x_val);
+  y_pt = (double*)na_get_pointer_for_write(y_val);
+
+  /* Predict values. */
+  if (model->param.svm_type == ONE_CLASS || model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) {
+    x_nodes = ALLOC_N(struct svm_node, n_features + 1);
+    x_nodes[n_features].index = -1;
+    x_nodes[n_features].value = 0.0;
+    for (i = 0; i < n_samples; i++) {
+      for (j = 0; j < n_features; j++) {
+        x_nodes[j].index = j + 1;
+        x_nodes[j].value = (double)x_pt[i * n_features + j];
+      }
+      svm_predict_values(model, x_nodes, &y_pt[i]);
+    }
+    xfree(x_nodes);
+  } else {
+    y_cols = (int)y_shape[1];
+    dec_values = ALLOC_N(double, y_cols);
+    x_nodes = ALLOC_N(struct svm_node, n_features + 1);
+    x_nodes[n_features].index = -1;
+    x_nodes[n_features].value = 0.0;
+    for (i = 0; i < n_samples; i++) {
+      for (j = 0; j < n_features; j++) {
+        x_nodes[j].index = j + 1;
+        x_nodes[j].value = (double)x_pt[i * n_features + j];
+      }
+      svm_predict_values(model, x_nodes, dec_values);
+      for (j = 0; j < y_cols; j++) {
+        y_pt[i * y_cols + j] = dec_values[j];
+      }
+    }
+    xfree(x_nodes);
+    xfree(dec_values);
+  }
+
+  xfree_svm_model(model);
+  xfree_svm_parameter(param);
+
+  return y_val;
+}
+
+/**
+ * Predict class probabilities for given samples. The model must have probability information calculated in the training procedure;
+ * that is, the parameter ':probability' must be set to 1 in the training procedure.
+ *
+ * @overload predict_proba(x, param, model) -> Numo::DFloat
+ *
+ * @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the class probabilities.
+ * @param param [Hash] The parameters of the trained SVM model.
+ * @param model [Hash] The model obtained from the training procedure.
+ * @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
+ */
+static
+VALUE predict_proba(VALUE self, VALUE x_val, VALUE param_hash, VALUE model_hash)
+{
+  struct svm_parameter* param;
+  struct svm_model* model;
+  struct svm_node* x_nodes;
+  narray_t* x_nary;
+  double* x_pt;
+  size_t y_shape[2];
+  VALUE y_val = Qnil;
+  double* y_pt;
+  double* probs;
+  int i, j;
+  int n_samples;
+  int n_features;
+
+  param = rb_hash_to_svm_parameter(param_hash);
+  model = rb_hash_to_svm_model(model_hash);
+  model->param = *param;
+
+  if ((model->param.svm_type == C_SVC || model->param.svm_type == NU_SVC) && model->probA != NULL && model->probB != NULL) {
+    /* Obtain C data structures. */
+    if (CLASS_OF(x_val) != numo_cDFloat) {
+      x_val = rb_funcall(numo_cDFloat, rb_intern("cast"), 1, x_val);
+    }
+    if (!RTEST(nary_check_contiguous(x_val))) {
+      x_val = nary_dup(x_val);
+    }
+    GetNArray(x_val, x_nary);
+
+    /* Initialize some variables. */
+    n_samples = (int)NA_SHAPE(x_nary)[0];
+    n_features = (int)NA_SHAPE(x_nary)[1];
+    y_shape[0] = n_samples;
+    y_shape[1] = model->nr_class;
+    y_val = rb_narray_new(numo_cDFloat, 2, y_shape);
+    x_pt = (double*)na_get_pointer_for_read(x_val);
+    y_pt = (double*)na_get_pointer_for_write(y_val);
+
+    /* Predict values. */
+    probs = ALLOC_N(double, model->nr_class);
+    x_nodes = ALLOC_N(struct svm_node, n_features + 1);
+    x_nodes[n_features].index = -1;
+    x_nodes[n_features].value = 0.0;
+    for (i = 0; i < n_samples; i++) {
+      for (j = 0; j < n_features; j++) {
+        x_nodes[j].index = j + 1;
+        x_nodes[j].value = (double)x_pt[i * n_features + j];
+      }
+      svm_predict_probability(model, x_nodes, probs);
+      for (j = 0; j < model->nr_class; j++) {
+        y_pt[i * model->nr_class + j] = probs[j];
+      }
+    }
+    xfree(x_nodes);
+    xfree(probs);
+  }
+
+  xfree_svm_model(model);
+  xfree_svm_parameter(param);
+
+  return y_val;
+}
+
+/**
+ * Load the SVM parameters and model from a text file with LIBSVM format.
+ *
+ * @param filename [String] The path to a file to load.
+ * @return [Array] An array containing the SVM parameters and model.
+ */
+static
+VALUE load_svm_model(VALUE self, VALUE filename)
+{
+  struct svm_model* model = svm_load_model(StringValuePtr(filename));
+  VALUE res = rb_ary_new2(2);
+  VALUE param_hash = Qnil;
+  VALUE model_hash = Qnil;
+
+  if (model) {
+    param_hash = svm_parameter_to_rb_hash(&(model->param));
+    model_hash = svm_model_to_rb_hash(model);
+    svm_free_and_destroy_model(&model);
+  }
+
+  rb_ary_store(res, 0, param_hash);
+  rb_ary_store(res, 1, model_hash);
+
+  return res;
+}
+
+/**
+ * Save the SVM parameters and model as a text file with LIBSVM format. The saved file can be used with the libsvm tools.
+ * Note that svm_save_model saves only the parameters necessary for estimation with the trained model.
+ *
+ * @overload save_svm_model(filename, param, model) -> Boolean
+ *
+ * @param filename [String] The path to a file to save.
+ * @param param [Hash] The parameters of the trained SVM model.
+ * @param model [Hash] The model obtained from the training procedure.
+ * @return [Boolean] true on success, or false if an error occurs.
+ */
+static
+VALUE save_svm_model(VALUE self, VALUE filename, VALUE param_hash, VALUE model_hash)
+{
+  struct svm_parameter* param = rb_hash_to_svm_parameter(param_hash);
+  struct svm_model* model = rb_hash_to_svm_model(model_hash);
+  int res;
+
+  model->param = *param;
+  res = svm_save_model(StringValuePtr(filename), model);
+
+  xfree_svm_model(model);
+  xfree_svm_parameter(param);
+
+  return res < 0 ? Qfalse : Qtrue;
+}
+
+void Init_libsvmext()
+{
+  rb_require("numo/narray");
+
+  /**
+   * Document-module: Numo
+   * Numo is the top level namespace of NUmerical MOdules for Ruby.
+   */
+  mNumo = rb_define_module("Numo");
+
+  /**
+   * Document-module: Numo::Libsvm
+   * Numo::Libsvm is a binding library for LIBSVM that handles datasets with Numo::NArray.
+   */
+  mLibsvm = rb_define_module_under(mNumo, "Libsvm");
+
+  /* The version of LIBSVM used in the background library. */
+  rb_define_const(mLibsvm, "LIBSVM_VERSION", INT2NUM(LIBSVM_VERSION));
+
+  rb_define_module_function(mLibsvm, "train", train, 3);
+  rb_define_module_function(mLibsvm, "cv", cross_validation, 4);
+  rb_define_module_function(mLibsvm, "predict", predict, 3);
+  rb_define_module_function(mLibsvm, "decision_function", decision_function, 3);
+  rb_define_module_function(mLibsvm, "predict_proba", predict_proba, 3);
+  rb_define_module_function(mLibsvm, "load_svm_model", load_svm_model, 1);
+  rb_define_module_function(mLibsvm, "save_svm_model", save_svm_model, 3);
+
+  rb_init_svm_type_module();
+  rb_init_kernel_type_module();
+}
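Taken together, the module functions registered in Init_libsvmext form a small procedural API: train and cv build or evaluate a model from Numo::DFloat data, predict, decision_function, and predict_proba score new samples, and save_svm_model and load_svm_model round-trip the model through the LIBSVM text format. A minimal end-to-end sketch follows; the parameter hash keys are assumed to mirror LIBSVM's svm_parameter fields, and the SvmType constant is assumed to come from svm_type.c, which is not shown in this diff.

# End-to-end sketch of the Numo::Libsvm module functions defined above.
require 'numo/narray'
require 'numo/libsvm'

# Two well-separated clusters as a toy binary classification problem.
x = Numo::DFloat[[-2.0, -1.5], [-1.8, -2.2], [-2.2, -1.9],
                 [ 2.1,  1.9], [ 1.7,  2.3], [ 2.0,  2.2]]
y = Numo::DFloat[1, 1, 1, 2, 2, 2]

# Assumed parameter keys; probability: 1 is needed later for predict_proba.
param = { svm_type: Numo::Libsvm::SvmType::C_SVC,
          kernel_type: Numo::Libsvm::KernelType::RBF,
          gamma: 1.0, C: 1.0, probability: 1 }

model  = Numo::Libsvm.train(x, y, param)              # model returned as a Hash
labels = Numo::Libsvm.predict(x, param, model)        # shape: [n_samples]
cv_out = Numo::Libsvm.cv(x, y, param, 3)              # 3-fold cross-validated predictions
scores = Numo::Libsvm.decision_function(x, param, model)
probs  = Numo::Libsvm.predict_proba(x, param, model)  # shape: [n_samples, n_classes]

Numo::Libsvm.save_svm_model('svm.model', param, model)
loaded_param, loaded_model = Numo::Libsvm.load_svm_model('svm.model')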