numo-libsvm 0.1.0 → 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/README.md +29 -2
- data/ext/numo/libsvm/libsvmext.c +121 -92
- data/ext/numo/libsvm/libsvmext.h +1 -0
- data/ext/numo/libsvm/svm_problem.c +58 -0
- data/ext/numo/libsvm/svm_problem.h +12 -0
- data/lib/numo/libsvm/version.rb +1 -1
- metadata +4 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA1:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 92066a1f30fa986bae110cffad133765bc1a2356
|
4
|
+
data.tar.gz: db27417bc129d089716a97df86b88c3dbd935ac1
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 77064dc4a7128b69ca9130165378dd53a71d98ce8fc1cdccc1462a913d1a5f79ebb1394750cdb67122bf3e37ddfd986ade99c39cde299aef492da11b1d58f548
|
7
|
+
data.tar.gz: b37b562c48d434c21b94fff3d1718703011d69be188ed078936936d84fca3d0164cefafad1c880f25191eb72fccd2207824f3e434552b88dc51f736bea8090bb
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -2,7 +2,8 @@
|
|
2
2
|
|
3
3
|
[](https://travis-ci.org/yoshoku/numo-libsvm)
|
4
4
|
[](https://badge.fury.io/rb/numo-libsvm)
|
5
|
-
[](https://github.com/yoshoku/numo-libsvm/blob/master/LICENSE.txt)
|
6
|
+
[](https://www.rubydoc.info/gems/numo-libsvm/0.2.0)
|
6
7
|
|
7
8
|
Numo::Libsvm is a Ruby gem binding to the [LIBSVM](https://github.com/cjlin1/libsvm) library.
|
8
9
|
LIBSVM is one of the famous libraries that implemented Support Vector Machines,
|
@@ -167,6 +168,32 @@ The hyperparameter of SVM is given with Ruby Hash on Numo::Libsvm.
|
|
167
168
|
The hash key of hyperparameter and its meaning match the struct svm_parameter of LIBSVM.
|
168
169
|
The svm_parameter is detailed in [LIBSVM README](https://github.com/cjlin1/libsvm/blob/master/README).
|
169
170
|
|
171
|
+
```ruby
|
172
|
+
param = {
|
173
|
+
svm_type: # [Integer] Type of SVM
|
174
|
+
Numo::Libsvm::SvmType::C_SVC,
|
175
|
+
# for kernel function
|
176
|
+
kernel_type: # [Integer] Type of kernel function
|
177
|
+
Numo::Libsvm::KernelType::RBF,
|
178
|
+
degree: 3, # [Integer] Degree in polynomial kernel function
|
179
|
+
gamma: 0.5, # [Float] Gamma in poly/rbf/sigmoid kernel function
|
180
|
+
coef0: 1.0, # [Float] Coefficient in poly/sigmoid kernel function
|
181
|
+
# for training procedure
|
182
|
+
cache_size: 100, # [Float] Cache memory size in MB
|
183
|
+
eps: 1e-3, # [Float] Tolerance of termination criterion
|
184
|
+
C: 1.0, # [Float] Parameter C of C-SVC, epsilon-SVR, and nu-SVR
|
185
|
+
nr_weight: 3, # [Integer] Number of weights for C-SVC
|
186
|
+
weight_label: # [Numo::Int32] Labels to add weight in C-SVC
|
187
|
+
Numo::Int32[0, 1, 2],
|
188
|
+
weight: # [Numo::DFloat] Weight values in C-SVC
|
189
|
+
Numo::DFloat[0.4, 0.4, 0.2],
|
190
|
+
nu: 0.5, # [Float] Parameter nu of nu-SVC, one-class SVM, and nu-SVR
|
191
|
+
p: 0.1, # [Float] Parameter epsilon in loss function of epsilon-SVR
|
192
|
+
shrinking: true, # [Boolean] Whether to use the shrinking heuristics
|
193
|
+
probability: false # [Boolean] Whether to train a SVC or SVR model for probability estimates
|
194
|
+
}
|
195
|
+
```
|
196
|
+
|
170
197
|
## Contributing
|
171
198
|
|
172
199
|
Bug reports and pull requests are welcome on GitHub at https://github.com/yoshoku/numo-libsvm. This project is intended to be a safe, welcoming space for collaboration, and contributors are expected to adhere to the [Contributor Covenant](http://contributor-covenant.org) code of conduct.
|
@@ -177,4 +204,4 @@ The gem is available as open source under the terms of the [BSD-3-Clause License
|
|
177
204
|
|
178
205
|
## Code of Conduct
|
179
206
|
|
180
|
-
Everyone interacting in the Numo::Libsvm project’s codebases, issue trackers, chat rooms and mailing lists is expected to follow the [code of conduct](https://github.com/
|
207
|
+
Everyone interacting in the Numo::Libsvm project’s codebases, issue trackers, chat rooms and mailing lists is expected to follow the [code of conduct](https://github.com/yoshoku/numo-libsvm/blob/master/CODE_OF_CONDUCT.md).
|
data/ext/numo/libsvm/libsvmext.c
CHANGED
@@ -12,10 +12,13 @@ void print_null(const char *s) {}
|
|
12
12
|
* Train the SVM model according to the given training data.
|
13
13
|
*
|
14
14
|
* @overload train(x, y, param) -> Hash
|
15
|
+
* @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to be used for training the model.
|
16
|
+
* @param y [Numo::DFloat] (shape: [n_samples]) The labels or target values for samples.
|
17
|
+
* @param param [Hash] The parameters of an SVM model.
|
15
18
|
*
|
16
|
-
* @
|
17
|
-
*
|
18
|
-
*
|
19
|
+
* @raise [ArgumentError] If the sample array is not 2-dimensional, the label array is not 1-dimensional,
|
20
|
+
* the sample array and label array do not have the same number of samples, or
|
21
|
+
* the hyperparameter has an invalid value, this error is raised.
|
19
22
|
* @return [Hash] The model obtained from the training procedure.
|
20
23
|
*/
|
21
24
|
static
|
@@ -23,16 +26,12 @@ VALUE train(VALUE self, VALUE x_val, VALUE y_val, VALUE param_hash)
|
|
23
26
|
{
|
24
27
|
struct svm_problem* problem;
|
25
28
|
struct svm_parameter* param;
|
26
|
-
narray_t* x_nary;
|
27
|
-
double* x_pt;
|
28
|
-
double* y_pt;
|
29
|
-
int i, j;
|
30
|
-
int n_samples;
|
31
|
-
int n_features;
|
32
29
|
struct svm_model* model;
|
30
|
+
narray_t* x_nary;
|
31
|
+
narray_t* y_nary;
|
32
|
+
char* err_msg;
|
33
33
|
VALUE model_hash;
|
34
34
|
|
35
|
-
/* Obtain C data structures. */
|
36
35
|
if (CLASS_OF(x_val) != numo_cDFloat) {
|
37
36
|
x_val = rb_funcall(numo_cDFloat, rb_intern("cast"), 1, x_val);
|
38
37
|
}
|
@@ -45,41 +44,39 @@ VALUE train(VALUE self, VALUE x_val, VALUE y_val, VALUE param_hash)
|
|
45
44
|
if (!RTEST(nary_check_contiguous(y_val))) {
|
46
45
|
y_val = nary_dup(y_val);
|
47
46
|
}
|
48
|
-
GetNArray(x_val, x_nary);
|
49
|
-
param = rb_hash_to_svm_parameter(param_hash);
|
50
47
|
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
48
|
+
GetNArray(x_val, x_nary);
|
49
|
+
GetNArray(y_val, y_nary);
|
50
|
+
if (NA_NDIM(x_nary) != 2) {
|
51
|
+
rb_raise(rb_eArgError, "Expect samples to be 2-D array.");
|
52
|
+
return Qnil;
|
53
|
+
}
|
54
|
+
if (NA_NDIM(y_nary) != 1) {
|
55
|
+
rb_raise(rb_eArgError, "Expect label or target values to be 1-D arrray.");
|
56
|
+
return Qnil;
|
57
|
+
}
|
58
|
+
if (NA_SHAPE(x_nary)[0] != NA_SHAPE(y_nary)[0]) {
|
59
|
+
rb_raise(rb_eArgError, "Expect to have the same number of samples for samples and labels.");
|
60
|
+
return Qnil;
|
61
|
+
}
|
56
62
|
|
57
|
-
|
58
|
-
problem =
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
problem->x[i][j].value = x_pt[i * n_features + j];
|
67
|
-
}
|
68
|
-
problem->x[i][n_features].index = -1;
|
69
|
-
problem->x[i][n_features].value = 0.0;
|
70
|
-
problem->y[i] = y_pt[i];
|
63
|
+
param = rb_hash_to_svm_parameter(param_hash);
|
64
|
+
problem = dataset_to_svm_problem(x_val, y_val);
|
65
|
+
|
66
|
+
err_msg = svm_check_parameter(problem, param);
|
67
|
+
if (err_msg) {
|
68
|
+
xfree_svm_problem(problem);
|
69
|
+
xfree_svm_parameter(param);
|
70
|
+
rb_raise(rb_eArgError, "Invalid LIBSVM parameter is given: %s", err_msg);
|
71
|
+
return Qnil;
|
71
72
|
}
|
72
73
|
|
73
|
-
/* Perform training. */
|
74
74
|
svm_set_print_string_function(print_null);
|
75
75
|
model = svm_train(problem, param);
|
76
76
|
model_hash = svm_model_to_rb_hash(model);
|
77
77
|
svm_free_and_destroy_model(&model);
|
78
78
|
|
79
|
-
|
80
|
-
xfree(problem->x);
|
81
|
-
xfree(problem->y);
|
82
|
-
xfree(problem);
|
79
|
+
xfree_svm_problem(problem);
|
83
80
|
xfree_svm_parameter(param);
|
84
81
|
|
85
82
|
return model_hash;
|
@@ -90,30 +87,29 @@ VALUE train(VALUE self, VALUE x_val, VALUE y_val, VALUE param_hash)
|
|
90
87
|
* The predicted labels or values in the validation process are returned.
|
91
88
|
*
|
92
89
|
* @overload cv(x, y, param, n_folds) -> Numo::DFloat
|
90
|
+
* @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to be used for training the model.
|
91
|
+
* @param y [Numo::DFloat] (shape: [n_samples]) The labels or target values for samples.
|
92
|
+
* @param param [Hash] The parameters of an SVM model.
|
93
|
+
* @param n_folds [Integer] The number of folds.
|
93
94
|
*
|
94
|
-
* @
|
95
|
-
*
|
96
|
-
*
|
97
|
-
* @param n_folds [Integer] The number of folds.
|
95
|
+
* @raise [ArgumentError] If the sample array is not 2-dimensional, the label array is not 1-dimensional,
|
96
|
+
* the sample array and label array do not have the same number of samples, or
|
97
|
+
* the hyperparameter has an invalid value, this error is raised.
|
98
98
|
* @return [Numo::DFloat] (shape: [n_samples]) The predicted class label or value of each sample.
|
99
99
|
*/
|
100
100
|
static
|
101
101
|
VALUE cross_validation(VALUE self, VALUE x_val, VALUE y_val, VALUE param_hash, VALUE nr_folds)
|
102
102
|
{
|
103
103
|
const int n_folds = NUM2INT(nr_folds);
|
104
|
-
struct svm_problem* problem;
|
105
|
-
struct svm_parameter* param;
|
106
|
-
narray_t* x_nary;
|
107
|
-
double* x_pt;
|
108
|
-
double* y_pt;
|
109
|
-
int i, j;
|
110
|
-
int n_samples;
|
111
|
-
int n_features;
|
112
104
|
size_t t_shape[1];
|
113
105
|
VALUE t_val;
|
114
106
|
double* t_pt;
|
107
|
+
narray_t* x_nary;
|
108
|
+
narray_t* y_nary;
|
109
|
+
char* err_msg;
|
110
|
+
struct svm_problem* problem;
|
111
|
+
struct svm_parameter* param;
|
115
112
|
|
116
|
-
/* Obtain C data structures. */
|
117
113
|
if (CLASS_OF(x_val) != numo_cDFloat) {
|
118
114
|
x_val = rb_funcall(numo_cDFloat, rb_intern("cast"), 1, x_val);
|
119
115
|
}
|
@@ -126,42 +122,41 @@ VALUE cross_validation(VALUE self, VALUE x_val, VALUE y_val, VALUE param_hash, V
|
|
126
122
|
if (!RTEST(nary_check_contiguous(y_val))) {
|
127
123
|
y_val = nary_dup(y_val);
|
128
124
|
}
|
129
|
-
GetNArray(x_val, x_nary);
|
130
|
-
param = rb_hash_to_svm_parameter(param_hash);
|
131
125
|
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
126
|
+
GetNArray(x_val, x_nary);
|
127
|
+
GetNArray(y_val, y_nary);
|
128
|
+
if (NA_NDIM(x_nary) != 2) {
|
129
|
+
rb_raise(rb_eArgError, "Expect samples to be 2-D array.");
|
130
|
+
return Qnil;
|
131
|
+
}
|
132
|
+
if (NA_NDIM(y_nary) != 1) {
|
133
|
+
rb_raise(rb_eArgError, "Expect label or target values to be 1-D arrray.");
|
134
|
+
return Qnil;
|
135
|
+
}
|
136
|
+
if (NA_SHAPE(x_nary)[0] != NA_SHAPE(y_nary)[0]) {
|
137
|
+
rb_raise(rb_eArgError, "Expect to have the same number of samples for samples and labels.");
|
138
|
+
return Qnil;
|
139
|
+
}
|
137
140
|
|
138
|
-
|
139
|
-
problem =
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
problem->x[i][j].value = x_pt[i * n_features + j];
|
148
|
-
}
|
149
|
-
problem->x[i][n_features].index = -1;
|
150
|
-
problem->x[i][n_features].value = 0.0;
|
151
|
-
problem->y[i] = y_pt[i];
|
141
|
+
param = rb_hash_to_svm_parameter(param_hash);
|
142
|
+
problem = dataset_to_svm_problem(x_val, y_val);
|
143
|
+
|
144
|
+
err_msg = svm_check_parameter(problem, param);
|
145
|
+
if (err_msg) {
|
146
|
+
xfree_svm_problem(problem);
|
147
|
+
xfree_svm_parameter(param);
|
148
|
+
rb_raise(rb_eArgError, "Invalid LIBSVM parameter is given: %s", err_msg);
|
149
|
+
return Qnil;
|
152
150
|
}
|
153
151
|
|
154
|
-
|
155
|
-
t_shape[0] = n_samples;
|
152
|
+
t_shape[0] = problem->l;
|
156
153
|
t_val = rb_narray_new(numo_cDFloat, 1, t_shape);
|
157
154
|
t_pt = (double*)na_get_pointer_for_write(t_val);
|
155
|
+
|
158
156
|
svm_set_print_string_function(print_null);
|
159
157
|
svm_cross_validation(problem, param, n_folds, t_pt);
|
160
158
|
|
161
|
-
|
162
|
-
xfree(problem->x);
|
163
|
-
xfree(problem->y);
|
164
|
-
xfree(problem);
|
159
|
+
xfree_svm_problem(problem);
|
165
160
|
xfree_svm_parameter(param);
|
166
161
|
|
167
162
|
return t_val;
|
@@ -171,10 +166,11 @@ VALUE cross_validation(VALUE self, VALUE x_val, VALUE y_val, VALUE param_hash, V
|
|
171
166
|
* Predict class labels or values for given samples.
|
172
167
|
*
|
173
168
|
* @overload predict(x, param, model) -> Numo::DFloat
|
169
|
+
* @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to calculate the scores.
|
170
|
+
* @param param [Hash] The parameters of the trained SVM model.
|
171
|
+
* @param model [Hash] The model obtained from the training procedure.
|
174
172
|
*
|
175
|
-
* @
|
176
|
-
* @param param [Hash] The parameters of the trained SVM model.
|
177
|
-
* @param model [Hash] The model obtained from the training procedure.
|
173
|
+
* @raise [ArgumentError] If the sample array is not 2-dimensional, this error is raised.
|
178
174
|
* @return [Numo::DFloat] (shape: [n_samples]) The predicted class label or value of each sample.
|
179
175
|
*/
|
180
176
|
static
|
@@ -199,7 +195,13 @@ VALUE predict(VALUE self, VALUE x_val, VALUE param_hash, VALUE model_hash)
|
|
199
195
|
if (!RTEST(nary_check_contiguous(x_val))) {
|
200
196
|
x_val = nary_dup(x_val);
|
201
197
|
}
|
198
|
+
|
202
199
|
GetNArray(x_val, x_nary);
|
200
|
+
if (NA_NDIM(x_nary) != 2) {
|
201
|
+
rb_raise(rb_eArgError, "Expect samples to be 2-D array.");
|
202
|
+
return Qnil;
|
203
|
+
}
|
204
|
+
|
203
205
|
param = rb_hash_to_svm_parameter(param_hash);
|
204
206
|
model = rb_hash_to_svm_model(model_hash);
|
205
207
|
model->param = *param;
|
@@ -235,10 +237,11 @@ VALUE predict(VALUE self, VALUE x_val, VALUE param_hash, VALUE model_hash)
|
|
235
237
|
* Calculate decision values for given samples.
|
236
238
|
*
|
237
239
|
* @overload decision_function(x, param, model) -> Numo::DFloat
|
240
|
+
* @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to calculate the scores.
|
241
|
+
* @param param [Hash] The parameters of the trained SVM model.
|
242
|
+
* @param model [Hash] The model obtained from the training procedure.
|
238
243
|
*
|
239
|
-
* @
|
240
|
-
* @param param [Hash] The parameters of the trained SVM model.
|
241
|
-
* @param model [Hash] The model obtained from the training procedure.
|
244
|
+
* @raise [ArgumentError] If the sample array is not 2-dimensional, this error is raised.
|
242
245
|
* @return [Numo::DFloat] (shape: [n_samples, n_classes * (n_classes - 1) / 2]) The decision value of each sample.
|
243
246
|
*/
|
244
247
|
static
|
@@ -265,7 +268,13 @@ VALUE decision_function(VALUE self, VALUE x_val, VALUE param_hash, VALUE model_h
|
|
265
268
|
if (!RTEST(nary_check_contiguous(x_val))) {
|
266
269
|
x_val = nary_dup(x_val);
|
267
270
|
}
|
271
|
+
|
268
272
|
GetNArray(x_val, x_nary);
|
273
|
+
if (NA_NDIM(x_nary) != 2) {
|
274
|
+
rb_raise(rb_eArgError, "Expect samples to be 2-D array.");
|
275
|
+
return Qnil;
|
276
|
+
}
|
277
|
+
|
269
278
|
param = rb_hash_to_svm_parameter(param_hash);
|
270
279
|
model = rb_hash_to_svm_model(model_hash);
|
271
280
|
model->param = *param;
|
@@ -331,10 +340,11 @@ VALUE decision_function(VALUE self, VALUE x_val, VALUE param_hash, VALUE model_h
|
|
331
340
|
* The parameter ':probability' set to 1 in training procedure.
|
332
341
|
*
|
333
342
|
* @overload predict_proba(x, param, model) -> Numo::DFloat
|
343
|
+
* @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the class probabilities.
|
344
|
+
* @param param [Hash] The parameters of the trained SVM model.
|
345
|
+
* @param model [Hash] The model obtained from the training procedure.
|
334
346
|
*
|
335
|
-
* @
|
336
|
-
* @param param [Hash] The parameters of the trained SVM model.
|
337
|
-
* @param model [Hash] The model obtained from the training procedure.
|
347
|
+
* @raise [ArgumentError] If the sample array is not 2-dimensional, this error is raised.
|
338
348
|
* @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probablity of each class per sample.
|
339
349
|
*/
|
340
350
|
static
|
@@ -353,6 +363,12 @@ VALUE predict_proba(VALUE self, VALUE x_val, VALUE param_hash, VALUE model_hash)
|
|
353
363
|
int n_samples;
|
354
364
|
int n_features;
|
355
365
|
|
366
|
+
GetNArray(x_val, x_nary);
|
367
|
+
if (NA_NDIM(x_nary) != 2) {
|
368
|
+
rb_raise(rb_eArgError, "Expect samples to be 2-D array.");
|
369
|
+
return Qnil;
|
370
|
+
}
|
371
|
+
|
356
372
|
param = rb_hash_to_svm_parameter(param_hash);
|
357
373
|
model = rb_hash_to_svm_model(model_hash);
|
358
374
|
model->param = *param;
|
@@ -365,7 +381,6 @@ VALUE predict_proba(VALUE self, VALUE x_val, VALUE param_hash, VALUE model_hash)
|
|
365
381
|
if (!RTEST(nary_check_contiguous(x_val))) {
|
366
382
|
x_val = nary_dup(x_val);
|
367
383
|
}
|
368
|
-
GetNArray(x_val, x_nary);
|
369
384
|
|
370
385
|
/* Initialize some variables. */
|
371
386
|
n_samples = (int)NA_SHAPE(x_nary)[0];
|
@@ -405,16 +420,23 @@ VALUE predict_proba(VALUE self, VALUE x_val, VALUE param_hash, VALUE model_hash)
|
|
405
420
|
* Load the SVM parameters and model from a text file with LIBSVM format.
|
406
421
|
*
|
407
422
|
* @param filename [String] The path to a file to load.
|
423
|
+
* @raise [IOError] This error raises when failed to load the model file.
|
408
424
|
* @return [Array] Array contains the SVM parameters and model.
|
409
425
|
*/
|
410
426
|
static
|
411
427
|
VALUE load_svm_model(VALUE self, VALUE filename)
|
412
428
|
{
|
413
|
-
|
429
|
+
char* filename_ = StringValuePtr(filename);
|
430
|
+
struct svm_model* model = svm_load_model(filename_);
|
414
431
|
VALUE res = rb_ary_new2(2);
|
415
432
|
VALUE param_hash = Qnil;
|
416
433
|
VALUE model_hash = Qnil;
|
417
434
|
|
435
|
+
if (model == NULL) {
|
436
|
+
rb_raise(rb_eIOError, "Failed to load file '%s'", filename_);
|
437
|
+
return Qnil;
|
438
|
+
}
|
439
|
+
|
418
440
|
if (model) {
|
419
441
|
param_hash = svm_parameter_to_rb_hash(&(model->param));
|
420
442
|
model_hash = svm_model_to_rb_hash(model);
|
@@ -432,26 +454,33 @@ VALUE load_svm_model(VALUE self, VALUE filename)
|
|
432
454
|
* Note that the svm_save_model saves only the parameters necessary for estimation with the trained model.
|
433
455
|
*
|
434
456
|
* @overload save_svm_model(filename, param, model) -> Boolean
|
457
|
+
* @param filename [String] The path to a file to save.
|
458
|
+
* @param param [Hash] The parameters of the trained SVM model.
|
459
|
+
* @param model [Hash] The model obtained from the training procedure.
|
435
460
|
*
|
436
|
-
* @
|
437
|
-
* @param param [Hash] The parameters of the trained SVM model.
|
438
|
-
* @param model [Hash] The model obtained from the training procedure.
|
461
|
+
* @raise [IOError] This error raises when failed to save the model file.
|
439
462
|
* @return [Boolean] true on success, or false if an error occurs.
|
440
463
|
*/
|
441
464
|
static
|
442
465
|
VALUE save_svm_model(VALUE self, VALUE filename, VALUE param_hash, VALUE model_hash)
|
443
466
|
{
|
467
|
+
char* filename_ = StringValuePtr(filename);
|
444
468
|
struct svm_parameter* param = rb_hash_to_svm_parameter(param_hash);
|
445
469
|
struct svm_model* model = rb_hash_to_svm_model(model_hash);
|
446
470
|
int res;
|
447
471
|
|
448
472
|
model->param = *param;
|
449
|
-
res = svm_save_model(
|
473
|
+
res = svm_save_model(filename_, model);
|
450
474
|
|
451
475
|
xfree_svm_model(model);
|
452
476
|
xfree_svm_parameter(param);
|
453
477
|
|
454
|
-
|
478
|
+
if (res < 0) {
|
479
|
+
rb_raise(rb_eIOError, "Failed to save file '%s'", filename_);
|
480
|
+
return Qfalse;
|
481
|
+
}
|
482
|
+
|
483
|
+
return Qtrue;
|
455
484
|
}
|
456
485
|
|
457
486
|
void Init_libsvmext()
|
data/ext/numo/libsvm/libsvmext.h
CHANGED
@@ -0,0 +1,58 @@
|
|
1
|
+
#include "svm_problem.h"
|
2
|
+
|
3
|
+
void xfree_svm_problem(struct svm_problem* problem)
|
4
|
+
{
|
5
|
+
int i;
|
6
|
+
if (problem) {
|
7
|
+
if (problem->x) {
|
8
|
+
for (i = 0; i < problem->l; i++) {
|
9
|
+
if (problem->x[i]) {
|
10
|
+
xfree(problem->x[i]);
|
11
|
+
problem->x[i] = NULL;
|
12
|
+
}
|
13
|
+
}
|
14
|
+
xfree(problem->x);
|
15
|
+
problem->x = NULL;
|
16
|
+
}
|
17
|
+
if (problem->y) {
|
18
|
+
xfree(problem->y);
|
19
|
+
problem->y = NULL;
|
20
|
+
}
|
21
|
+
xfree(problem);
|
22
|
+
problem = NULL;
|
23
|
+
}
|
24
|
+
}
|
25
|
+
|
26
|
+
struct svm_problem* dataset_to_svm_problem(VALUE x_val, VALUE y_val)
|
27
|
+
{
|
28
|
+
struct svm_problem* problem;
|
29
|
+
narray_t* x_nary;
|
30
|
+
double* x_pt;
|
31
|
+
double* y_pt;
|
32
|
+
int i, j;
|
33
|
+
int n_samples;
|
34
|
+
int n_features;
|
35
|
+
|
36
|
+
GetNArray(x_val, x_nary);
|
37
|
+
n_samples = (int)NA_SHAPE(x_nary)[0];
|
38
|
+
n_features = (int)NA_SHAPE(x_nary)[1];
|
39
|
+
x_pt = (double*)na_get_pointer_for_read(x_val);
|
40
|
+
y_pt = (double*)na_get_pointer_for_read(y_val);
|
41
|
+
|
42
|
+
problem = ALLOC(struct svm_problem);
|
43
|
+
problem->l = n_samples;
|
44
|
+
problem->x = ALLOC_N(struct svm_node*, n_samples);
|
45
|
+
problem->y = ALLOC_N(double, n_samples);
|
46
|
+
for (i = 0; i < n_samples; i++) {
|
47
|
+
problem->x[i] = ALLOC_N(struct svm_node, n_features + 1);
|
48
|
+
for (j = 0; j < n_features; j++) {
|
49
|
+
problem->x[i][j].index = j + 1;
|
50
|
+
problem->x[i][j].value = x_pt[i * n_features + j];
|
51
|
+
}
|
52
|
+
problem->x[i][n_features].index = -1;
|
53
|
+
problem->x[i][n_features].value = 0.0;
|
54
|
+
problem->y[i] = y_pt[i];
|
55
|
+
}
|
56
|
+
|
57
|
+
return problem;
|
58
|
+
}
|
@@ -0,0 +1,12 @@
|
|
1
|
+
#ifndef NUMO_LIBSVM_SVM_PROBLEM_H
|
2
|
+
#define NUMO_LIBSVM_SVM_PROBLEM_H 1
|
3
|
+
|
4
|
+
#include <svm.h>
|
5
|
+
#include <ruby.h>
|
6
|
+
#include <numo/narray.h>
|
7
|
+
#include <numo/template.h>
|
8
|
+
|
9
|
+
void xfree_svm_problem(struct svm_problem* problem);
|
10
|
+
struct svm_problem* dataset_to_svm_problem(VALUE x_val, VALUE y_val);
|
11
|
+
|
12
|
+
#endif /* NUMO_LIBSVM_SVM_PROBLEM_H */
|
data/lib/numo/libsvm/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: numo-libsvm
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.2.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2019-
|
11
|
+
date: 2019-08-15 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -112,6 +112,8 @@ files:
|
|
112
112
|
- ext/numo/libsvm/svm_model.h
|
113
113
|
- ext/numo/libsvm/svm_parameter.c
|
114
114
|
- ext/numo/libsvm/svm_parameter.h
|
115
|
+
- ext/numo/libsvm/svm_problem.c
|
116
|
+
- ext/numo/libsvm/svm_problem.h
|
115
117
|
- ext/numo/libsvm/svm_type.c
|
116
118
|
- ext/numo/libsvm/svm_type.h
|
117
119
|
- lib/numo/libsvm.rb
|