numo-libsvm 0.1.0 → 0.2.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/README.md +29 -2
- data/ext/numo/libsvm/libsvmext.c +121 -92
- data/ext/numo/libsvm/libsvmext.h +1 -0
- data/ext/numo/libsvm/svm_problem.c +58 -0
- data/ext/numo/libsvm/svm_problem.h +12 -0
- data/lib/numo/libsvm/version.rb +1 -1
- metadata +4 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA1:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 92066a1f30fa986bae110cffad133765bc1a2356
|
4
|
+
data.tar.gz: db27417bc129d089716a97df86b88c3dbd935ac1
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 77064dc4a7128b69ca9130165378dd53a71d98ce8fc1cdccc1462a913d1a5f79ebb1394750cdb67122bf3e37ddfd986ade99c39cde299aef492da11b1d58f548
|
7
|
+
data.tar.gz: b37b562c48d434c21b94fff3d1718703011d69be188ed078936936d84fca3d0164cefafad1c880f25191eb72fccd2207824f3e434552b88dc51f736bea8090bb
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -2,7 +2,8 @@
|
|
2
2
|
|
3
3
|
[![Build Status](https://travis-ci.org/yoshoku/numo-libsvm.svg?branch=master)](https://travis-ci.org/yoshoku/numo-libsvm)
|
4
4
|
[![Gem Version](https://badge.fury.io/rb/numo-libsvm.svg)](https://badge.fury.io/rb/numo-libsvm)
|
5
|
-
[![BSD 3-Clause License](https://img.shields.io/badge/License-BSD%203--Clause-
|
5
|
+
[![BSD 3-Clause License](https://img.shields.io/badge/License-BSD%203--Clause-orange.svg)](https://github.com/yoshoku/numo-libsvm/blob/master/LICENSE.txt)
|
6
|
+
[![Documentation](http://img.shields.io/badge/docs-rdoc.info-blue.svg)](https://www.rubydoc.info/gems/numo-libsvm/0.2.0)
|
6
7
|
|
7
8
|
Numo::Libsvm is a Ruby gem binding to the [LIBSVM](https://github.com/cjlin1/libsvm) library.
|
8
9
|
LIBSVM is one of the famous libraries that implemented Support Vector Machines,
|
@@ -167,6 +168,32 @@ The hyperparameter of SVM is given with Ruby Hash on Numo::Libsvm.
|
|
167
168
|
The hash key of hyperparameter and its meaning match the struct svm_parameter of LIBSVM.
|
168
169
|
The svm_parameter is detailed in [LIBSVM README](https://github.com/cjlin1/libsvm/blob/master/README).
|
169
170
|
|
171
|
+
```ruby
|
172
|
+
param = {
|
173
|
+
svm_type: # [Integer] Type of SVM
|
174
|
+
Numo::Libsvm::SvmType::C_SVC,
|
175
|
+
# for kernel function
|
176
|
+
kernel_type: # [Integer] Type of kernel function
|
177
|
+
Numo::Libsvm::KernelType::RBF,
|
178
|
+
degree: 3, # [Integer] Degree in polynomial kernel function
|
179
|
+
gamma: 0.5, # [Float] Gamma in poly/rbf/sigmoid kernel function
|
180
|
+
coef0: 1.0, # [Float] Coefficient in poly/sigmoid kernel function
|
181
|
+
# for training procedure
|
182
|
+
cache_size: 100, # [Float] Cache memory size in MB
|
183
|
+
eps: 1e-3, # [Float] Tolerance of termination criterion
|
184
|
+
C: 1.0, # [Float] Parameter C of C-SVC, epsilon-SVR, and nu-SVR
|
185
|
+
nr_weight: 3, # [Integer] Number of weights for C-SVC
|
186
|
+
weight_label: # [Numo::Int32] Labels to add weight in C-SVC
|
187
|
+
Numo::Int32[0, 1, 2],
|
188
|
+
weight: # [Numo::DFloat] Weight values in C-SVC
|
189
|
+
Numo::DFloat[0.4, 0.4, 0.2],
|
190
|
+
nu: 0.5, # [Float] Parameter nu of nu-SVC, one-class SVM, and nu-SVR
|
191
|
+
p: 0.1, # [Float] Parameter epsilon in loss function of epsilon-SVR
|
192
|
+
shrinking: true, # [Boolean] Whether to use the shrinking heuristics
|
193
|
+
probability: false # [Boolean] Whether to train an SVC or SVR model for probability estimates
|
194
|
+
}
|
195
|
+
```
|
196
|
+
|
170
197
|
## Contributing
|
171
198
|
|
172
199
|
Bug reports and pull requests are welcome on GitHub at https://github.com/yoshoku/numo-libsvm. This project is intended to be a safe, welcoming space for collaboration, and contributors are expected to adhere to the [Contributor Covenant](http://contributor-covenant.org) code of conduct.
|
@@ -177,4 +204,4 @@ The gem is available as open source under the terms of the [BSD-3-Clause License
|
|
177
204
|
|
178
205
|
## Code of Conduct
|
179
206
|
|
180
|
-
Everyone interacting in the Numo::Libsvm project’s codebases, issue trackers, chat rooms and mailing lists is expected to follow the [code of conduct](https://github.com/
|
207
|
+
Everyone interacting in the Numo::Libsvm project’s codebases, issue trackers, chat rooms and mailing lists is expected to follow the [code of conduct](https://github.com/yoshoku/numo-libsvm/blob/master/CODE_OF_CONDUCT.md).
|
data/ext/numo/libsvm/libsvmext.c
CHANGED
@@ -12,10 +12,13 @@ void print_null(const char *s) {}
|
|
12
12
|
* Train the SVM model according to the given training data.
|
13
13
|
*
|
14
14
|
* @overload train(x, y, param) -> Hash
|
15
|
+
* @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to be used for training the model.
|
16
|
+
* @param y [Numo::DFloat] (shape: [n_samples]) The labels or target values for samples.
|
17
|
+
* @param param [Hash] The parameters of an SVM model.
|
15
18
|
*
|
16
|
-
* @
|
17
|
-
*
|
18
|
-
*
|
19
|
+
* @raise [ArgumentError] If the sample array is not 2-dimensional, the label array is not 1-dimensional,
|
20
|
+
* the sample array and label array do not have the same number of samples, or
|
21
|
+
* the hyperparameter has an invalid value, this error is raised.
|
19
22
|
* @return [Hash] The model obtained from the training procedure.
|
20
23
|
*/
|
21
24
|
static
|
@@ -23,16 +26,12 @@ VALUE train(VALUE self, VALUE x_val, VALUE y_val, VALUE param_hash)
|
|
23
26
|
{
|
24
27
|
struct svm_problem* problem;
|
25
28
|
struct svm_parameter* param;
|
26
|
-
narray_t* x_nary;
|
27
|
-
double* x_pt;
|
28
|
-
double* y_pt;
|
29
|
-
int i, j;
|
30
|
-
int n_samples;
|
31
|
-
int n_features;
|
32
29
|
struct svm_model* model;
|
30
|
+
narray_t* x_nary;
|
31
|
+
narray_t* y_nary;
|
32
|
+
char* err_msg;
|
33
33
|
VALUE model_hash;
|
34
34
|
|
35
|
-
/* Obtain C data structures. */
|
36
35
|
if (CLASS_OF(x_val) != numo_cDFloat) {
|
37
36
|
x_val = rb_funcall(numo_cDFloat, rb_intern("cast"), 1, x_val);
|
38
37
|
}
|
@@ -45,41 +44,39 @@ VALUE train(VALUE self, VALUE x_val, VALUE y_val, VALUE param_hash)
|
|
45
44
|
if (!RTEST(nary_check_contiguous(y_val))) {
|
46
45
|
y_val = nary_dup(y_val);
|
47
46
|
}
|
48
|
-
GetNArray(x_val, x_nary);
|
49
|
-
param = rb_hash_to_svm_parameter(param_hash);
|
50
47
|
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
48
|
+
GetNArray(x_val, x_nary);
|
49
|
+
GetNArray(y_val, y_nary);
|
50
|
+
if (NA_NDIM(x_nary) != 2) {
|
51
|
+
rb_raise(rb_eArgError, "Expect samples to be 2-D array.");
|
52
|
+
return Qnil;
|
53
|
+
}
|
54
|
+
if (NA_NDIM(y_nary) != 1) {
|
55
|
+
rb_raise(rb_eArgError, "Expect label or target values to be 1-D arrray.");
|
56
|
+
return Qnil;
|
57
|
+
}
|
58
|
+
if (NA_SHAPE(x_nary)[0] != NA_SHAPE(y_nary)[0]) {
|
59
|
+
rb_raise(rb_eArgError, "Expect to have the same number of samples for samples and labels.");
|
60
|
+
return Qnil;
|
61
|
+
}
|
56
62
|
|
57
|
-
|
58
|
-
problem =
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
problem->x[i][j].value = x_pt[i * n_features + j];
|
67
|
-
}
|
68
|
-
problem->x[i][n_features].index = -1;
|
69
|
-
problem->x[i][n_features].value = 0.0;
|
70
|
-
problem->y[i] = y_pt[i];
|
63
|
+
param = rb_hash_to_svm_parameter(param_hash);
|
64
|
+
problem = dataset_to_svm_problem(x_val, y_val);
|
65
|
+
|
66
|
+
err_msg = svm_check_parameter(problem, param);
|
67
|
+
if (err_msg) {
|
68
|
+
xfree_svm_problem(problem);
|
69
|
+
xfree_svm_parameter(param);
|
70
|
+
rb_raise(rb_eArgError, "Invalid LIBSVM parameter is given: %s", err_msg);
|
71
|
+
return Qnil;
|
71
72
|
}
|
72
73
|
|
73
|
-
/* Perform training. */
|
74
74
|
svm_set_print_string_function(print_null);
|
75
75
|
model = svm_train(problem, param);
|
76
76
|
model_hash = svm_model_to_rb_hash(model);
|
77
77
|
svm_free_and_destroy_model(&model);
|
78
78
|
|
79
|
-
|
80
|
-
xfree(problem->x);
|
81
|
-
xfree(problem->y);
|
82
|
-
xfree(problem);
|
79
|
+
xfree_svm_problem(problem);
|
83
80
|
xfree_svm_parameter(param);
|
84
81
|
|
85
82
|
return model_hash;
|
@@ -90,30 +87,29 @@ VALUE train(VALUE self, VALUE x_val, VALUE y_val, VALUE param_hash)
|
|
90
87
|
* The predicted labels or values in the validation process are returned.
|
91
88
|
*
|
92
89
|
* @overload cv(x, y, param, n_folds) -> Numo::DFloat
|
90
|
+
* @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to be used for training the model.
|
91
|
+
* @param y [Numo::DFloat] (shape: [n_samples]) The labels or target values for samples.
|
92
|
+
* @param param [Hash] The parameters of an SVM model.
|
93
|
+
* @param n_folds [Integer] The number of folds.
|
93
94
|
*
|
94
|
-
* @
|
95
|
-
*
|
96
|
-
*
|
97
|
-
* @param n_folds [Integer] The number of folds.
|
95
|
+
* @raise [ArgumentError] If the sample array is not 2-dimensional, the label array is not 1-dimensional,
|
96
|
+
* the sample array and label array do not have the same number of samples, or
|
97
|
+
* the hyperparameter has an invalid value, this error is raised.
|
98
98
|
* @return [Numo::DFloat] (shape: [n_samples]) The predicted class label or value of each sample.
|
99
99
|
*/
|
100
100
|
static
|
101
101
|
VALUE cross_validation(VALUE self, VALUE x_val, VALUE y_val, VALUE param_hash, VALUE nr_folds)
|
102
102
|
{
|
103
103
|
const int n_folds = NUM2INT(nr_folds);
|
104
|
-
struct svm_problem* problem;
|
105
|
-
struct svm_parameter* param;
|
106
|
-
narray_t* x_nary;
|
107
|
-
double* x_pt;
|
108
|
-
double* y_pt;
|
109
|
-
int i, j;
|
110
|
-
int n_samples;
|
111
|
-
int n_features;
|
112
104
|
size_t t_shape[1];
|
113
105
|
VALUE t_val;
|
114
106
|
double* t_pt;
|
107
|
+
narray_t* x_nary;
|
108
|
+
narray_t* y_nary;
|
109
|
+
char* err_msg;
|
110
|
+
struct svm_problem* problem;
|
111
|
+
struct svm_parameter* param;
|
115
112
|
|
116
|
-
/* Obtain C data structures. */
|
117
113
|
if (CLASS_OF(x_val) != numo_cDFloat) {
|
118
114
|
x_val = rb_funcall(numo_cDFloat, rb_intern("cast"), 1, x_val);
|
119
115
|
}
|
@@ -126,42 +122,41 @@ VALUE cross_validation(VALUE self, VALUE x_val, VALUE y_val, VALUE param_hash, V
|
|
126
122
|
if (!RTEST(nary_check_contiguous(y_val))) {
|
127
123
|
y_val = nary_dup(y_val);
|
128
124
|
}
|
129
|
-
GetNArray(x_val, x_nary);
|
130
|
-
param = rb_hash_to_svm_parameter(param_hash);
|
131
125
|
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
126
|
+
GetNArray(x_val, x_nary);
|
127
|
+
GetNArray(y_val, y_nary);
|
128
|
+
if (NA_NDIM(x_nary) != 2) {
|
129
|
+
rb_raise(rb_eArgError, "Expect samples to be 2-D array.");
|
130
|
+
return Qnil;
|
131
|
+
}
|
132
|
+
if (NA_NDIM(y_nary) != 1) {
|
133
|
+
rb_raise(rb_eArgError, "Expect label or target values to be 1-D arrray.");
|
134
|
+
return Qnil;
|
135
|
+
}
|
136
|
+
if (NA_SHAPE(x_nary)[0] != NA_SHAPE(y_nary)[0]) {
|
137
|
+
rb_raise(rb_eArgError, "Expect to have the same number of samples for samples and labels.");
|
138
|
+
return Qnil;
|
139
|
+
}
|
137
140
|
|
138
|
-
|
139
|
-
problem =
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
problem->x[i][j].value = x_pt[i * n_features + j];
|
148
|
-
}
|
149
|
-
problem->x[i][n_features].index = -1;
|
150
|
-
problem->x[i][n_features].value = 0.0;
|
151
|
-
problem->y[i] = y_pt[i];
|
141
|
+
param = rb_hash_to_svm_parameter(param_hash);
|
142
|
+
problem = dataset_to_svm_problem(x_val, y_val);
|
143
|
+
|
144
|
+
err_msg = svm_check_parameter(problem, param);
|
145
|
+
if (err_msg) {
|
146
|
+
xfree_svm_problem(problem);
|
147
|
+
xfree_svm_parameter(param);
|
148
|
+
rb_raise(rb_eArgError, "Invalid LIBSVM parameter is given: %s", err_msg);
|
149
|
+
return Qnil;
|
152
150
|
}
|
153
151
|
|
154
|
-
|
155
|
-
t_shape[0] = n_samples;
|
152
|
+
t_shape[0] = problem->l;
|
156
153
|
t_val = rb_narray_new(numo_cDFloat, 1, t_shape);
|
157
154
|
t_pt = (double*)na_get_pointer_for_write(t_val);
|
155
|
+
|
158
156
|
svm_set_print_string_function(print_null);
|
159
157
|
svm_cross_validation(problem, param, n_folds, t_pt);
|
160
158
|
|
161
|
-
|
162
|
-
xfree(problem->x);
|
163
|
-
xfree(problem->y);
|
164
|
-
xfree(problem);
|
159
|
+
xfree_svm_problem(problem);
|
165
160
|
xfree_svm_parameter(param);
|
166
161
|
|
167
162
|
return t_val;
|
@@ -171,10 +166,11 @@ VALUE cross_validation(VALUE self, VALUE x_val, VALUE y_val, VALUE param_hash, V
|
|
171
166
|
* Predict class labels or values for given samples.
|
172
167
|
*
|
173
168
|
* @overload predict(x, param, model) -> Numo::DFloat
|
169
|
+
* @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples for which to predict class labels or values.
|
170
|
+
* @param param [Hash] The parameters of the trained SVM model.
|
171
|
+
* @param model [Hash] The model obtained from the training procedure.
|
174
172
|
*
|
175
|
-
* @
|
176
|
-
* @param param [Hash] The parameters of the trained SVM model.
|
177
|
-
* @param model [Hash] The model obtained from the training procedure.
|
173
|
+
* @raise [ArgumentError] If the sample array is not 2-dimensional, this error is raised.
|
178
174
|
* @return [Numo::DFloat] (shape: [n_samples]) The predicted class label or value of each sample.
|
179
175
|
*/
|
180
176
|
static
|
@@ -199,7 +195,13 @@ VALUE predict(VALUE self, VALUE x_val, VALUE param_hash, VALUE model_hash)
|
|
199
195
|
if (!RTEST(nary_check_contiguous(x_val))) {
|
200
196
|
x_val = nary_dup(x_val);
|
201
197
|
}
|
198
|
+
|
202
199
|
GetNArray(x_val, x_nary);
|
200
|
+
if (NA_NDIM(x_nary) != 2) {
|
201
|
+
rb_raise(rb_eArgError, "Expect samples to be 2-D array.");
|
202
|
+
return Qnil;
|
203
|
+
}
|
204
|
+
|
203
205
|
param = rb_hash_to_svm_parameter(param_hash);
|
204
206
|
model = rb_hash_to_svm_model(model_hash);
|
205
207
|
model->param = *param;
|
@@ -235,10 +237,11 @@ VALUE predict(VALUE self, VALUE x_val, VALUE param_hash, VALUE model_hash)
|
|
235
237
|
* Calculate decision values for given samples.
|
236
238
|
*
|
237
239
|
* @overload decision_function(x, param, model) -> Numo::DFloat
|
240
|
+
* @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples for which to calculate the decision values.
|
241
|
+
* @param param [Hash] The parameters of the trained SVM model.
|
242
|
+
* @param model [Hash] The model obtained from the training procedure.
|
238
243
|
*
|
239
|
-
* @
|
240
|
-
* @param param [Hash] The parameters of the trained SVM model.
|
241
|
-
* @param model [Hash] The model obtained from the training procedure.
|
244
|
+
* @raise [ArgumentError] If the sample array is not 2-dimensional, this error is raised.
|
242
245
|
* @return [Numo::DFloat] (shape: [n_samples, n_classes * (n_classes - 1) / 2]) The decision value of each sample.
|
243
246
|
*/
|
244
247
|
static
|
@@ -265,7 +268,13 @@ VALUE decision_function(VALUE self, VALUE x_val, VALUE param_hash, VALUE model_h
|
|
265
268
|
if (!RTEST(nary_check_contiguous(x_val))) {
|
266
269
|
x_val = nary_dup(x_val);
|
267
270
|
}
|
271
|
+
|
268
272
|
GetNArray(x_val, x_nary);
|
273
|
+
if (NA_NDIM(x_nary) != 2) {
|
274
|
+
rb_raise(rb_eArgError, "Expect samples to be 2-D array.");
|
275
|
+
return Qnil;
|
276
|
+
}
|
277
|
+
|
269
278
|
param = rb_hash_to_svm_parameter(param_hash);
|
270
279
|
model = rb_hash_to_svm_model(model_hash);
|
271
280
|
model->param = *param;
|
@@ -331,10 +340,11 @@ VALUE decision_function(VALUE self, VALUE x_val, VALUE param_hash, VALUE model_h
|
|
331
340
|
* The parameter ':probability' must have been set to 1 in the training procedure.
|
332
341
|
*
|
333
342
|
* @overload predict_proba(x, param, model) -> Numo::DFloat
|
343
|
+
* @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples for which to predict the class probabilities.
|
344
|
+
* @param param [Hash] The parameters of the trained SVM model.
|
345
|
+
* @param model [Hash] The model obtained from the training procedure.
|
334
346
|
*
|
335
|
-
* @
|
336
|
-
* @param param [Hash] The parameters of the trained SVM model.
|
337
|
-
* @param model [Hash] The model obtained from the training procedure.
|
347
|
+
* @raise [ArgumentError] If the sample array is not 2-dimensional, this error is raised.
|
338
348
|
* @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
|
339
349
|
*/
|
340
350
|
static
|
@@ -353,6 +363,12 @@ VALUE predict_proba(VALUE self, VALUE x_val, VALUE param_hash, VALUE model_hash)
|
|
353
363
|
int n_samples;
|
354
364
|
int n_features;
|
355
365
|
|
366
|
+
GetNArray(x_val, x_nary);
|
367
|
+
if (NA_NDIM(x_nary) != 2) {
|
368
|
+
rb_raise(rb_eArgError, "Expect samples to be 2-D array.");
|
369
|
+
return Qnil;
|
370
|
+
}
|
371
|
+
|
356
372
|
param = rb_hash_to_svm_parameter(param_hash);
|
357
373
|
model = rb_hash_to_svm_model(model_hash);
|
358
374
|
model->param = *param;
|
@@ -365,7 +381,6 @@ VALUE predict_proba(VALUE self, VALUE x_val, VALUE param_hash, VALUE model_hash)
|
|
365
381
|
if (!RTEST(nary_check_contiguous(x_val))) {
|
366
382
|
x_val = nary_dup(x_val);
|
367
383
|
}
|
368
|
-
GetNArray(x_val, x_nary);
|
369
384
|
|
370
385
|
/* Initialize some variables. */
|
371
386
|
n_samples = (int)NA_SHAPE(x_nary)[0];
|
@@ -405,16 +420,23 @@ VALUE predict_proba(VALUE self, VALUE x_val, VALUE param_hash, VALUE model_hash)
|
|
405
420
|
* Load the SVM parameters and model from a text file with LIBSVM format.
|
406
421
|
*
|
407
422
|
* @param filename [String] The path to a file to load.
|
423
|
+
* @raise [IOError] Raised when the model file fails to load.
|
408
424
|
* @return [Array] An array containing the SVM parameters and model.
|
409
425
|
*/
|
410
426
|
static
|
411
427
|
VALUE load_svm_model(VALUE self, VALUE filename)
|
412
428
|
{
|
413
|
-
|
429
|
+
char* filename_ = StringValuePtr(filename);
|
430
|
+
struct svm_model* model = svm_load_model(filename_);
|
414
431
|
VALUE res = rb_ary_new2(2);
|
415
432
|
VALUE param_hash = Qnil;
|
416
433
|
VALUE model_hash = Qnil;
|
417
434
|
|
435
|
+
if (model == NULL) {
|
436
|
+
rb_raise(rb_eIOError, "Failed to load file '%s'", filename_);
|
437
|
+
return Qnil;
|
438
|
+
}
|
439
|
+
|
418
440
|
if (model) {
|
419
441
|
param_hash = svm_parameter_to_rb_hash(&(model->param));
|
420
442
|
model_hash = svm_model_to_rb_hash(model);
|
@@ -432,26 +454,33 @@ VALUE load_svm_model(VALUE self, VALUE filename)
|
|
432
454
|
* Note that svm_save_model saves only the parameters necessary for estimation with the trained model.
|
433
455
|
*
|
434
456
|
* @overload save_svm_model(filename, param, model) -> Boolean
|
457
|
+
* @param filename [String] The path to a file to save.
|
458
|
+
* @param param [Hash] The parameters of the trained SVM model.
|
459
|
+
* @param model [Hash] The model obtained from the training procedure.
|
435
460
|
*
|
436
|
-
* @
|
437
|
-
* @param param [Hash] The parameters of the trained SVM model.
|
438
|
-
* @param model [Hash] The model obtained from the training procedure.
|
461
|
+
* @raise [IOError] Raised when saving the model file fails.
|
439
462
|
* @return [Boolean] true on success. On failure an IOError is raised instead of returning false.
|
440
463
|
*/
|
441
464
|
static
|
442
465
|
VALUE save_svm_model(VALUE self, VALUE filename, VALUE param_hash, VALUE model_hash)
|
443
466
|
{
|
467
|
+
char* filename_ = StringValuePtr(filename);
|
444
468
|
struct svm_parameter* param = rb_hash_to_svm_parameter(param_hash);
|
445
469
|
struct svm_model* model = rb_hash_to_svm_model(model_hash);
|
446
470
|
int res;
|
447
471
|
|
448
472
|
model->param = *param;
|
449
|
-
res = svm_save_model(
|
473
|
+
res = svm_save_model(filename_, model);
|
450
474
|
|
451
475
|
xfree_svm_model(model);
|
452
476
|
xfree_svm_parameter(param);
|
453
477
|
|
454
|
-
|
478
|
+
if (res < 0) {
|
479
|
+
rb_raise(rb_eIOError, "Failed to save file '%s'", filename_);
|
480
|
+
return Qfalse;
|
481
|
+
}
|
482
|
+
|
483
|
+
return Qtrue;
|
455
484
|
}
|
456
485
|
|
457
486
|
void Init_libsvmext()
|
data/ext/numo/libsvm/libsvmext.h
CHANGED
@@ -0,0 +1,58 @@
|
|
1
|
+
#include "svm_problem.h"
|
2
|
+
|
3
|
+
void xfree_svm_problem(struct svm_problem* problem)
|
4
|
+
{
|
5
|
+
int i;
|
6
|
+
if (problem) {
|
7
|
+
if (problem->x) {
|
8
|
+
for (i = 0; i < problem->l; i++) {
|
9
|
+
if (problem->x[i]) {
|
10
|
+
xfree(problem->x[i]);
|
11
|
+
problem->x[i] = NULL;
|
12
|
+
}
|
13
|
+
}
|
14
|
+
xfree(problem->x);
|
15
|
+
problem->x = NULL;
|
16
|
+
}
|
17
|
+
if (problem->y) {
|
18
|
+
xfree(problem->y);
|
19
|
+
problem->y = NULL;
|
20
|
+
}
|
21
|
+
xfree(problem);
|
22
|
+
problem = NULL;
|
23
|
+
}
|
24
|
+
}
|
25
|
+
|
26
|
+
/*
 * Convert a dense sample matrix and a label/target vector into a LIBSVM svm_problem.
 *
 * Each row of x_val becomes a node array holding every feature, including zeros.
 * Feature indices are 1-based, and each array is terminated by a sentinel node
 * with index -1 and value 0.0. The i-th element of y_val becomes problem->y[i].
 *
 * Preconditions, which are not checked here:
 * - x_val is a contiguous 2-D Numo::DFloat.
 * - y_val is a contiguous Numo::DFloat with at least as many elements as x_val has rows.
 * The train/cross_validation callers perform these casts and checks before calling.
 *
 * @param x_val Samples, shape [n_samples, n_features].
 * @param y_val Labels or target values, one per sample.
 * @return A newly allocated problem. Release it with xfree_svm_problem.
 */
struct svm_problem* dataset_to_svm_problem(VALUE x_val, VALUE y_val)
{
  struct svm_problem* problem;
  narray_t* x_nary;
  const double* x_pt;
  const double* y_pt;
  const double* row;
  int i, j;
  int n_samples;
  int n_features;

  GetNArray(x_val, x_nary);
  n_samples = (int)NA_SHAPE(x_nary)[0];
  n_features = (int)NA_SHAPE(x_nary)[1];
  x_pt = (const double*)na_get_pointer_for_read(x_val);
  y_pt = (const double*)na_get_pointer_for_read(y_val);

  problem = ALLOC(struct svm_problem);
  problem->l = n_samples;
  problem->x = ALLOC_N(struct svm_node*, n_samples);
  problem->y = ALLOC_N(double, n_samples);
  for (i = 0; i < n_samples; i++) {
    /*
     * Compute the row offset in size_t. The previous int expression
     * i * n_features + j overflows (undefined behavior) once the matrix holds
     * more than INT_MAX elements.
     */
    row = x_pt + (size_t)i * (size_t)n_features;
    /* One extra node for LIBSVM's index == -1 terminator. */
    problem->x[i] = ALLOC_N(struct svm_node, n_features + 1);
    for (j = 0; j < n_features; j++) {
      problem->x[i][j].index = j + 1;
      problem->x[i][j].value = row[j];
    }
    problem->x[i][n_features].index = -1;
    problem->x[i][n_features].value = 0.0;
    problem->y[i] = y_pt[i];
  }

  return problem;
}
|
@@ -0,0 +1,12 @@
|
|
1
|
+
#ifndef NUMO_LIBSVM_SVM_PROBLEM_H
#define NUMO_LIBSVM_SVM_PROBLEM_H 1

#include <svm.h>
#include <ruby.h>
#include <numo/narray.h>
#include <numo/template.h>

/*
 * Free an svm_problem together with its per-sample node arrays and target array.
 * Passing NULL is allowed and does nothing.
 */
extern void xfree_svm_problem(struct svm_problem *problem);

/*
 * Build a newly allocated svm_problem from a sample matrix (x_val) and a
 * label/target vector (y_val). The result must be released with xfree_svm_problem.
 */
extern struct svm_problem *dataset_to_svm_problem(VALUE x_val, VALUE y_val);

#endif /* NUMO_LIBSVM_SVM_PROBLEM_H */
|
data/lib/numo/libsvm/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: numo-libsvm
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.2.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2019-
|
11
|
+
date: 2019-08-15 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -112,6 +112,8 @@ files:
|
|
112
112
|
- ext/numo/libsvm/svm_model.h
|
113
113
|
- ext/numo/libsvm/svm_parameter.c
|
114
114
|
- ext/numo/libsvm/svm_parameter.h
|
115
|
+
- ext/numo/libsvm/svm_problem.c
|
116
|
+
- ext/numo/libsvm/svm_problem.h
|
115
117
|
- ext/numo/libsvm/svm_type.c
|
116
118
|
- ext/numo/libsvm/svm_type.h
|
117
119
|
- lib/numo/libsvm.rb
|