numo-libsvm 1.1.2 → 2.1.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,721 @@
1
+ /**
2
+ * Copyright (c) 2019-2022 Atsushi Tatsuma
3
+ * All rights reserved.
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * * Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * * Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * * Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ */
30
+
31
+ #ifndef LIBSVMEXT_HPP
32
+ #define LIBSVMEXT_HPP 1
33
+
34
+ #include <cmath>
35
+ #include <cstring>
36
+
37
+ #include <ruby.h>
38
+
39
+ #include <numo/narray.h>
40
+ #include <numo/template.h>
41
+
42
+ #include <svm.h>
43
+
44
+ typedef struct svm_model LibSvmModel;
45
+ typedef struct svm_node LibSvmNode;
46
+ typedef struct svm_parameter LibSvmParameter;
47
+ typedef struct svm_problem LibSvmProblem;
48
+
49
/** Print callback that accepts a message and discards it without producing any output. */
void printNull(const char* s) {
  (void)s;  // the message is intentionally ignored
}
50
+
51
+ #define NR_MARKS 10
52
+
53
+ /** CONVERTERS */
54
+ VALUE convertVectorXiToNArray(const int* const arr, const int size) {
55
+ size_t shape[1] = {(size_t)size};
56
+ VALUE vec_val = rb_narray_new(numo_cInt32, 1, shape);
57
+ int32_t* vec_ptr = (int32_t*)na_get_pointer_for_write(vec_val);
58
+ for (int i = 0; i < size; i++) vec_ptr[i] = (int32_t)arr[i];
59
+ return vec_val;
60
+ }
61
+
62
+ int* convertNArrayToVectorXi(VALUE vec_val) {
63
+ if (NIL_P(vec_val)) return NULL;
64
+
65
+ narray_t* vec_nary;
66
+ GetNArray(vec_val, vec_nary);
67
+ const size_t n_elements = NA_SHAPE(vec_nary)[0];
68
+ int* arr = ALLOC_N(int, n_elements);
69
+ const int32_t* const vec_ptr = (int32_t*)na_get_pointer_for_read(vec_val);
70
+ for (size_t i = 0; i < n_elements; i++) arr[i] = (int)vec_ptr[i];
71
+
72
+ RB_GC_GUARD(vec_val);
73
+
74
+ return arr;
75
+ }
76
+
77
+ VALUE convertVectorXdToNArray(const double* const arr, const int size) {
78
+ size_t shape[1] = {(size_t)size};
79
+ VALUE vec_val = rb_narray_new(numo_cDFloat, 1, shape);
80
+ double* vec_ptr = (double*)na_get_pointer_for_write(vec_val);
81
+ memcpy(vec_ptr, arr, size * sizeof(double));
82
+ return vec_val;
83
+ }
84
+
85
+ double* convertNArrayToVectorXd(VALUE vec_val) {
86
+ if (NIL_P(vec_val)) return NULL;
87
+
88
+ narray_t* vec_nary;
89
+ GetNArray(vec_val, vec_nary);
90
+ const size_t n_elements = NA_SHAPE(vec_nary)[0];
91
+ double* arr = ALLOC_N(double, n_elements);
92
+ const double* const vec_ptr = (double*)na_get_pointer_for_read(vec_val);
93
+ memcpy(arr, vec_ptr, n_elements * sizeof(double));
94
+
95
+ RB_GC_GUARD(vec_val);
96
+
97
+ return arr;
98
+ }
99
+
100
+ VALUE convertMatrixXdToNArray(const double* const* mat, const int n_rows, const int n_cols) {
101
+ size_t shape[2] = {(size_t)n_rows, (size_t)n_cols};
102
+ VALUE mat_val = rb_narray_new(numo_cDFloat, 2, shape);
103
+ double* mat_ptr = (double*)na_get_pointer_for_write(mat_val);
104
+ for (int i = 0; i < n_rows; i++) memcpy(&mat_ptr[i * n_cols], mat[i], n_cols * sizeof(double));
105
+ return mat_val;
106
+ }
107
+
108
+ double** convertNArrayToMatrixXd(VALUE mat_val) {
109
+ if (NIL_P(mat_val)) return NULL;
110
+
111
+ narray_t* mat_nary;
112
+ GetNArray(mat_val, mat_nary);
113
+ const size_t n_rows = NA_SHAPE(mat_nary)[0];
114
+ const size_t n_cols = NA_SHAPE(mat_nary)[1];
115
+ const double* const mat_ptr = (double*)na_get_pointer_for_read(mat_val);
116
+ double** mat = ALLOC_N(double*, n_rows);
117
+ for (size_t i = 0; i < n_rows; i++) {
118
+ mat[i] = ALLOC_N(double, n_cols);
119
+ memcpy(mat[i], &mat_ptr[i * n_cols], n_cols * sizeof(double));
120
+ }
121
+
122
+ RB_GC_GUARD(mat_val);
123
+
124
+ return mat;
125
+ }
126
+
127
+ VALUE convertLibSvmNodeToNArray(const LibSvmNode* const* support_vecs, const int n_support_vecs) {
128
+ int n_dimensions = 0;
129
+ for (int i = 0; i < n_support_vecs; i++) {
130
+ for (int j = 0; support_vecs[i][j].index != -1; j++) {
131
+ if (n_dimensions < support_vecs[i][j].index) {
132
+ n_dimensions = support_vecs[i][j].index;
133
+ }
134
+ }
135
+ }
136
+
137
+ size_t shape[2] = {(size_t)n_support_vecs, (size_t)n_dimensions};
138
+ VALUE vec_val = rb_narray_new(numo_cDFloat, 2, shape);
139
+ double* vec_ptr = (double*)na_get_pointer_for_write(vec_val);
140
+ memset(vec_ptr, 0, n_support_vecs * n_dimensions * sizeof(double));
141
+ for (int i = 0; i < n_support_vecs; i++) {
142
+ for (int j = 0; support_vecs[i][j].index != -1; j++) {
143
+ vec_ptr[i * n_dimensions + support_vecs[i][j].index - 1] = support_vecs[i][j].value;
144
+ }
145
+ }
146
+
147
+ return vec_val;
148
+ }
149
+
150
+ LibSvmNode** convertNArrayToLibSvmNode(VALUE vec_val) {
151
+ if (NIL_P(vec_val)) return NULL;
152
+
153
+ narray_t* vec_nary;
154
+ GetNArray(vec_val, vec_nary);
155
+ const size_t n_rows = NA_SHAPE(vec_nary)[0];
156
+ const size_t n_cols = NA_SHAPE(vec_nary)[1];
157
+ const double* const vec_ptr = (double*)na_get_pointer_for_read(vec_val);
158
+ LibSvmNode** support_vecs = ALLOC_N(LibSvmNode*, n_rows);
159
+ for (size_t i = 0; i < n_rows; i++) {
160
+ int n_nonzero_cols = 0;
161
+ for (size_t j = 0; j < n_cols; j++) {
162
+ if (vec_ptr[i * n_cols + j] != 0) {
163
+ n_nonzero_cols++;
164
+ }
165
+ }
166
+ support_vecs[i] = ALLOC_N(LibSvmNode, n_nonzero_cols + 1);
167
+ for (size_t j = 0, k = 0; j < n_cols; j++) {
168
+ if (vec_ptr[i * n_cols + j] != 0) {
169
+ support_vecs[i][k].index = j + 1;
170
+ support_vecs[i][k].value = vec_ptr[i * n_cols + j];
171
+ k++;
172
+ }
173
+ }
174
+ support_vecs[i][n_nonzero_cols].index = -1;
175
+ support_vecs[i][n_nonzero_cols].value = 0.0;
176
+ }
177
+
178
+ RB_GC_GUARD(vec_val);
179
+
180
+ return support_vecs;
181
+ }
182
+
183
+ LibSvmNode* convertVectorXdToLibSvmNode(const double* const arr, const int size) {
184
+ int n_nonzero_elements = 0;
185
+ for (int i = 0; i < size; i++) {
186
+ if (arr[i] != 0.0) n_nonzero_elements++;
187
+ }
188
+
189
+ LibSvmNode* node = ALLOC_N(LibSvmNode, n_nonzero_elements + 1);
190
+ for (int i = 0, j = 0; i < size; i++) {
191
+ if (arr[i] != 0.0) {
192
+ node[j].index = i + 1;
193
+ node[j].value = arr[i];
194
+ j++;
195
+ }
196
+ }
197
+ node[n_nonzero_elements].index = -1;
198
+ node[n_nonzero_elements].value = 0.0;
199
+
200
+ return node;
201
+ }
202
+
203
+ LibSvmModel* convertHashToLibSvmModel(VALUE model_hash) {
204
+ LibSvmModel* model = ALLOC(LibSvmModel);
205
+ VALUE el;
206
+ el = rb_hash_aref(model_hash, ID2SYM(rb_intern("nr_class")));
207
+ model->nr_class = !NIL_P(el) ? NUM2INT(el) : 0;
208
+ el = rb_hash_aref(model_hash, ID2SYM(rb_intern("l")));
209
+ model->l = !NIL_P(el) ? NUM2INT(el) : 0;
210
+ el = rb_hash_aref(model_hash, ID2SYM(rb_intern("SV")));
211
+ model->SV = convertNArrayToLibSvmNode(el);
212
+ el = rb_hash_aref(model_hash, ID2SYM(rb_intern("sv_coef")));
213
+ model->sv_coef = convertNArrayToMatrixXd(el);
214
+ el = rb_hash_aref(model_hash, ID2SYM(rb_intern("rho")));
215
+ model->rho = convertNArrayToVectorXd(el);
216
+ el = rb_hash_aref(model_hash, ID2SYM(rb_intern("probA")));
217
+ model->probA = convertNArrayToVectorXd(el);
218
+ el = rb_hash_aref(model_hash, ID2SYM(rb_intern("probB")));
219
+ model->probB = convertNArrayToVectorXd(el);
220
+ el = rb_hash_aref(model_hash, ID2SYM(rb_intern("prob_density_marks")));
221
+ model->prob_density_marks = convertNArrayToVectorXd(el);
222
+ el = rb_hash_aref(model_hash, ID2SYM(rb_intern("sv_indices")));
223
+ model->sv_indices = convertNArrayToVectorXi(el);
224
+ el = rb_hash_aref(model_hash, ID2SYM(rb_intern("label")));
225
+ model->label = convertNArrayToVectorXi(el);
226
+ el = rb_hash_aref(model_hash, ID2SYM(rb_intern("nSV")));
227
+ model->nSV = convertNArrayToVectorXi(el);
228
+ el = rb_hash_aref(model_hash, ID2SYM(rb_intern("free_sv")));
229
+ model->free_sv = !NIL_P(el) ? NUM2INT(el) : 0;
230
+ return model;
231
+ }
232
+
233
+ VALUE convertLibSvmModelToHash(const LibSvmModel* const model) {
234
+ const int n_classes = model->nr_class;
235
+ const int n_support_vecs = model->l;
236
+ VALUE support_vecs = model->SV ? convertLibSvmNodeToNArray(model->SV, n_support_vecs) : Qnil;
237
+ VALUE coefficients = model->sv_coef ? convertMatrixXdToNArray(model->sv_coef, n_classes - 1, n_support_vecs) : Qnil;
238
+ VALUE intercepts = model->rho ? convertVectorXdToNArray(model->rho, n_classes * (n_classes - 1) / 2) : Qnil;
239
+ VALUE prob_alpha = model->probA ? convertVectorXdToNArray(model->probA, n_classes * (n_classes - 1) / 2) : Qnil;
240
+ VALUE prob_beta = model->probB ? convertVectorXdToNArray(model->probB, n_classes * (n_classes - 1) / 2) : Qnil;
241
+ VALUE prob_density_marks = model->prob_density_marks ? convertVectorXdToNArray(model->prob_density_marks, NR_MARKS) : Qnil;
242
+ VALUE sv_indices = model->sv_indices ? convertVectorXiToNArray(model->sv_indices, n_support_vecs) : Qnil;
243
+ VALUE labels = model->label ? convertVectorXiToNArray(model->label, n_classes) : Qnil;
244
+ VALUE n_support_vecs_each_class = model->nSV ? convertVectorXiToNArray(model->nSV, n_classes) : Qnil;
245
+ VALUE model_hash = rb_hash_new();
246
+ rb_hash_aset(model_hash, ID2SYM(rb_intern("nr_class")), INT2NUM(n_classes));
247
+ rb_hash_aset(model_hash, ID2SYM(rb_intern("l")), INT2NUM(n_support_vecs));
248
+ rb_hash_aset(model_hash, ID2SYM(rb_intern("SV")), support_vecs);
249
+ rb_hash_aset(model_hash, ID2SYM(rb_intern("sv_coef")), coefficients);
250
+ rb_hash_aset(model_hash, ID2SYM(rb_intern("rho")), intercepts);
251
+ rb_hash_aset(model_hash, ID2SYM(rb_intern("probA")), prob_alpha);
252
+ rb_hash_aset(model_hash, ID2SYM(rb_intern("probB")), prob_beta);
253
+ rb_hash_aset(model_hash, ID2SYM(rb_intern("prob_density_marks")), prob_density_marks);
254
+ rb_hash_aset(model_hash, ID2SYM(rb_intern("sv_indices")), sv_indices);
255
+ rb_hash_aset(model_hash, ID2SYM(rb_intern("label")), labels);
256
+ rb_hash_aset(model_hash, ID2SYM(rb_intern("nSV")), n_support_vecs_each_class);
257
+ rb_hash_aset(model_hash, ID2SYM(rb_intern("free_sv")), INT2NUM(model->free_sv));
258
+ return model_hash;
259
+ }
260
+
261
+ LibSvmParameter* convertHashToLibSvmParameter(VALUE param_hash) {
262
+ LibSvmParameter* param = ALLOC(LibSvmParameter);
263
+ VALUE el;
264
+ el = rb_hash_aref(param_hash, ID2SYM(rb_intern("svm_type")));
265
+ param->svm_type = !NIL_P(el) ? NUM2INT(el) : C_SVC;
266
+ el = rb_hash_aref(param_hash, ID2SYM(rb_intern("kernel_type")));
267
+ param->kernel_type = !NIL_P(el) ? NUM2INT(el) : RBF;
268
+ el = rb_hash_aref(param_hash, ID2SYM(rb_intern("degree")));
269
+ param->degree = !NIL_P(el) ? NUM2INT(el) : 3;
270
+ el = rb_hash_aref(param_hash, ID2SYM(rb_intern("gamma")));
271
+ param->gamma = !NIL_P(el) ? NUM2DBL(el) : 1;
272
+ el = rb_hash_aref(param_hash, ID2SYM(rb_intern("coef0")));
273
+ param->coef0 = !NIL_P(el) ? NUM2DBL(el) : 0;
274
+ el = rb_hash_aref(param_hash, ID2SYM(rb_intern("cache_size")));
275
+ param->cache_size = !NIL_P(el) ? NUM2DBL(el) : 100;
276
+ el = rb_hash_aref(param_hash, ID2SYM(rb_intern("eps")));
277
+ param->eps = !NIL_P(el) ? NUM2DBL(el) : 1e-3;
278
+ el = rb_hash_aref(param_hash, ID2SYM(rb_intern("C")));
279
+ param->C = !NIL_P(el) ? NUM2DBL(el) : 1;
280
+ el = rb_hash_aref(param_hash, ID2SYM(rb_intern("nr_weight")));
281
+ param->nr_weight = !NIL_P(el) ? NUM2INT(el) : 0;
282
+ el = rb_hash_aref(param_hash, ID2SYM(rb_intern("nu")));
283
+ param->nu = !NIL_P(el) ? NUM2DBL(el) : 0.5;
284
+ el = rb_hash_aref(param_hash, ID2SYM(rb_intern("p")));
285
+ param->p = !NIL_P(el) ? NUM2DBL(el) : 0.1;
286
+ el = rb_hash_aref(param_hash, ID2SYM(rb_intern("shrinking")));
287
+ param->shrinking = RB_TYPE_P(el, T_FALSE) ? 0 : 1;
288
+ el = rb_hash_aref(param_hash, ID2SYM(rb_intern("probability")));
289
+ param->probability = RB_TYPE_P(el, T_TRUE) ? 1 : 0;
290
+ el = rb_hash_aref(param_hash, ID2SYM(rb_intern("weight_label")));
291
+ param->weight_label = NULL;
292
+ if (!NIL_P(el)) {
293
+ param->weight_label = ALLOC_N(int, param->nr_weight);
294
+ memcpy(param->weight_label, (int32_t*)na_get_pointer_for_read(el), param->nr_weight * sizeof(int32_t));
295
+ }
296
+ el = rb_hash_aref(param_hash, ID2SYM(rb_intern("weight")));
297
+ param->weight = NULL;
298
+ if (!NIL_P(el)) {
299
+ param->weight = ALLOC_N(double, param->nr_weight);
300
+ memcpy(param->weight, (double*)na_get_pointer_for_read(el), param->nr_weight * sizeof(double));
301
+ }
302
+ return param;
303
+ }
304
+
305
+ VALUE convertLibSvmParameterToHash(const LibSvmParameter* const param) {
306
+ VALUE param_hash = rb_hash_new();
307
+ rb_hash_aset(param_hash, ID2SYM(rb_intern("svm_type")), INT2NUM(param->svm_type));
308
+ rb_hash_aset(param_hash, ID2SYM(rb_intern("kernel_type")), INT2NUM(param->kernel_type));
309
+ rb_hash_aset(param_hash, ID2SYM(rb_intern("degree")), INT2NUM(param->degree));
310
+ rb_hash_aset(param_hash, ID2SYM(rb_intern("gamma")), DBL2NUM(param->gamma));
311
+ rb_hash_aset(param_hash, ID2SYM(rb_intern("coef0")), DBL2NUM(param->coef0));
312
+ rb_hash_aset(param_hash, ID2SYM(rb_intern("cache_size")), DBL2NUM(param->cache_size));
313
+ rb_hash_aset(param_hash, ID2SYM(rb_intern("eps")), DBL2NUM(param->eps));
314
+ rb_hash_aset(param_hash, ID2SYM(rb_intern("C")), DBL2NUM(param->C));
315
+ rb_hash_aset(param_hash, ID2SYM(rb_intern("nr_weight")), INT2NUM(param->nr_weight));
316
+ rb_hash_aset(param_hash, ID2SYM(rb_intern("nu")), DBL2NUM(param->nu));
317
+ rb_hash_aset(param_hash, ID2SYM(rb_intern("p")), DBL2NUM(param->p));
318
+ rb_hash_aset(param_hash, ID2SYM(rb_intern("shrinking")), param->shrinking == 1 ? Qtrue : Qfalse);
319
+ rb_hash_aset(param_hash, ID2SYM(rb_intern("probability")), param->probability == 1 ? Qtrue : Qfalse);
320
+ rb_hash_aset(param_hash, ID2SYM(rb_intern("weight_label")),
321
+ param->weight_label ? convertVectorXiToNArray(param->weight_label, param->nr_weight) : Qnil);
322
+ rb_hash_aset(param_hash, ID2SYM(rb_intern("weight")),
323
+ param->weight ? convertVectorXdToNArray(param->weight, param->nr_weight) : Qnil);
324
+ return param_hash;
325
+ }
326
+
327
+ LibSvmProblem* convertDatasetToLibSvmProblem(VALUE x_val, VALUE y_val) {
328
+ narray_t* x_nary;
329
+ GetNArray(x_val, x_nary);
330
+ const int n_samples = (int)NA_SHAPE(x_nary)[0];
331
+ const int n_features = (int)NA_SHAPE(x_nary)[1];
332
+ const double* const x_ptr = (double*)na_get_pointer_for_read(x_val);
333
+ const double* const y_ptr = (double*)na_get_pointer_for_read(y_val);
334
+
335
+ LibSvmProblem* problem = ALLOC(LibSvmProblem);
336
+ problem->l = n_samples;
337
+ problem->x = ALLOC_N(LibSvmNode*, n_samples);
338
+ problem->y = ALLOC_N(double, n_samples);
339
+
340
+ int last_feature_id = 0;
341
+ bool is_padded = false;
342
+ for (int i = 0; i < n_samples; i++) {
343
+ int n_nonzero_features = 0;
344
+ for (int j = 0; j < n_features; j++) {
345
+ if (x_ptr[i * n_features + j] != 0.0) {
346
+ n_nonzero_features += 1;
347
+ last_feature_id = j + 1;
348
+ }
349
+ }
350
+ if (!is_padded && last_feature_id == n_features) is_padded = true;
351
+ if (is_padded) {
352
+ problem->x[i] = ALLOC_N(LibSvmNode, n_nonzero_features + 1);
353
+ } else {
354
+ problem->x[i] = ALLOC_N(LibSvmNode, n_nonzero_features + 2);
355
+ }
356
+ for (int j = 0, k = 0; j < n_features; j++) {
357
+ if (x_ptr[i * n_features + j] != 0.0) {
358
+ problem->x[i][k].index = j + 1;
359
+ problem->x[i][k].value = x_ptr[i * n_features + j];
360
+ k++;
361
+ }
362
+ }
363
+ if (is_padded) {
364
+ problem->x[i][n_nonzero_features].index = -1;
365
+ problem->x[i][n_nonzero_features].value = 0.0;
366
+ } else {
367
+ problem->x[i][n_nonzero_features].index = n_features;
368
+ problem->x[i][n_nonzero_features].value = 0.0;
369
+ problem->x[i][n_nonzero_features + 1].index = -1;
370
+ problem->x[i][n_nonzero_features + 1].value = 0.0;
371
+ }
372
+ problem->y[i] = y_ptr[i];
373
+ }
374
+
375
+ RB_GC_GUARD(x_val);
376
+ RB_GC_GUARD(y_val);
377
+
378
+ return problem;
379
+ }
380
+
381
+ /** UTILITIES */
382
+ bool isSignleOutputModel(LibSvmModel* model) {
383
+ return (model->param.svm_type == ONE_CLASS || model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR);
384
+ }
385
+
386
+ bool isProbabilisticModel(LibSvmModel* model) { return svm_check_probability_model(model) != 0; }
387
+
388
+ void deleteLibSvmModel(LibSvmModel* model) {
389
+ if (model) {
390
+ if (model->SV) {
391
+ for (int i = 0; i < model->l; i++) xfree(model->SV[i]);
392
+ xfree(model->SV);
393
+ model->SV = NULL;
394
+ }
395
+ if (model->sv_coef) {
396
+ for (int i = 0; i < model->nr_class - 1; i++) xfree(model->sv_coef[i]);
397
+ xfree(model->sv_coef);
398
+ model->sv_coef = NULL;
399
+ }
400
+ xfree(model->rho);
401
+ model->rho = NULL;
402
+ xfree(model->probA);
403
+ model->probA = NULL;
404
+ xfree(model->probB);
405
+ model->probB = NULL;
406
+ xfree(model->prob_density_marks);
407
+ model->prob_density_marks = NULL;
408
+ xfree(model->sv_indices);
409
+ model->sv_indices = NULL;
410
+ xfree(model->label);
411
+ model->label = NULL;
412
+ xfree(model->nSV);
413
+ model->nSV = NULL;
414
+ xfree(model);
415
+ model = NULL;
416
+ }
417
+ }
418
+
419
+ void deleteLibSvmParameter(LibSvmParameter* param) {
420
+ if (param) {
421
+ if (param->weight_label) {
422
+ xfree(param->weight_label);
423
+ param->weight_label = NULL;
424
+ }
425
+ if (param->weight) {
426
+ xfree(param->weight);
427
+ param->weight = NULL;
428
+ }
429
+ xfree(param);
430
+ param = NULL;
431
+ }
432
+ }
433
+
434
+ void deleteLibSvmProblem(LibSvmProblem* problem) {
435
+ if (problem) {
436
+ if (problem->x) {
437
+ for (int i = 0; i < problem->l; i++) {
438
+ if (problem->x[i]) {
439
+ xfree(problem->x[i]);
440
+ problem->x[i] = NULL;
441
+ }
442
+ }
443
+ xfree(problem->x);
444
+ problem->x = NULL;
445
+ }
446
+ if (problem->y) {
447
+ xfree(problem->y);
448
+ problem->y = NULL;
449
+ }
450
+ xfree(problem);
451
+ problem = NULL;
452
+ }
453
+ }
454
+
455
+ /** MODULE FUNCTIONS */
456
+ static VALUE numo_libsvm_train(VALUE self, VALUE x_val, VALUE y_val, VALUE param_hash) {
457
+ if (CLASS_OF(x_val) != numo_cDFloat) x_val = rb_funcall(numo_cDFloat, rb_intern("cast"), 1, x_val);
458
+ if (CLASS_OF(y_val) != numo_cDFloat) y_val = rb_funcall(numo_cDFloat, rb_intern("cast"), 1, y_val);
459
+ if (!RTEST(nary_check_contiguous(x_val))) x_val = nary_dup(x_val);
460
+ if (!RTEST(nary_check_contiguous(y_val))) y_val = nary_dup(y_val);
461
+
462
+ narray_t* x_nary;
463
+ narray_t* y_nary;
464
+ GetNArray(x_val, x_nary);
465
+ GetNArray(y_val, y_nary);
466
+ if (NA_NDIM(x_nary) != 2) {
467
+ rb_raise(rb_eArgError, "Expect samples to be 2-D array.");
468
+ return Qnil;
469
+ }
470
+ if (NA_NDIM(y_nary) != 1) {
471
+ rb_raise(rb_eArgError, "Expect label or target values to be 1-D arrray.");
472
+ return Qnil;
473
+ }
474
+ if (NA_SHAPE(x_nary)[0] != NA_SHAPE(y_nary)[0]) {
475
+ rb_raise(rb_eArgError, "Expect to have the same number of samples for samples and labels.");
476
+ return Qnil;
477
+ }
478
+
479
+ VALUE random_seed = rb_hash_aref(param_hash, ID2SYM(rb_intern("random_seed")));
480
+ if (!NIL_P(random_seed)) srand(NUM2UINT(random_seed));
481
+
482
+ LibSvmParameter* param = convertHashToLibSvmParameter(param_hash);
483
+ LibSvmProblem* problem = convertDatasetToLibSvmProblem(x_val, y_val);
484
+
485
+ const char* err_msg = svm_check_parameter(problem, param);
486
+ if (err_msg) {
487
+ deleteLibSvmProblem(problem);
488
+ deleteLibSvmParameter(param);
489
+ rb_raise(rb_eArgError, "Invalid LIBSVM parameter is given: %s", err_msg);
490
+ return Qnil;
491
+ }
492
+
493
+ VALUE verbose = rb_hash_aref(param_hash, ID2SYM(rb_intern("verbose")));
494
+ if (!RTEST(verbose)) svm_set_print_string_function(printNull);
495
+
496
+ LibSvmModel* model = svm_train(problem, param);
497
+ VALUE model_hash = convertLibSvmModelToHash(model);
498
+ svm_free_and_destroy_model(&model);
499
+
500
+ deleteLibSvmProblem(problem);
501
+ deleteLibSvmParameter(param);
502
+
503
+ RB_GC_GUARD(x_val);
504
+ RB_GC_GUARD(y_val);
505
+
506
+ return model_hash;
507
+ }
508
+
509
+ static VALUE numo_libsvm_cross_validation(VALUE self, VALUE x_val, VALUE y_val, VALUE param_hash, VALUE nr_folds) {
510
+ if (CLASS_OF(x_val) != numo_cDFloat) x_val = rb_funcall(numo_cDFloat, rb_intern("cast"), 1, x_val);
511
+ if (CLASS_OF(y_val) != numo_cDFloat) y_val = rb_funcall(numo_cDFloat, rb_intern("cast"), 1, y_val);
512
+ if (!RTEST(nary_check_contiguous(x_val))) x_val = nary_dup(x_val);
513
+ if (!RTEST(nary_check_contiguous(y_val))) y_val = nary_dup(y_val);
514
+
515
+ narray_t* x_nary;
516
+ narray_t* y_nary;
517
+ GetNArray(x_val, x_nary);
518
+ GetNArray(y_val, y_nary);
519
+ if (NA_NDIM(x_nary) != 2) {
520
+ rb_raise(rb_eArgError, "Expect samples to be 2-D array.");
521
+ return Qnil;
522
+ }
523
+ if (NA_NDIM(y_nary) != 1) {
524
+ rb_raise(rb_eArgError, "Expect label or target values to be 1-D arrray.");
525
+ return Qnil;
526
+ }
527
+ if (NA_SHAPE(x_nary)[0] != NA_SHAPE(y_nary)[0]) {
528
+ rb_raise(rb_eArgError, "Expect to have the same number of samples for samples and labels.");
529
+ return Qnil;
530
+ }
531
+
532
+ VALUE random_seed = rb_hash_aref(param_hash, ID2SYM(rb_intern("random_seed")));
533
+ if (!NIL_P(random_seed)) srand(NUM2UINT(random_seed));
534
+
535
+ LibSvmParameter* param = convertHashToLibSvmParameter(param_hash);
536
+ LibSvmProblem* problem = convertDatasetToLibSvmProblem(x_val, y_val);
537
+
538
+ const char* err_msg = svm_check_parameter(problem, param);
539
+ if (err_msg) {
540
+ deleteLibSvmProblem(problem);
541
+ deleteLibSvmParameter(param);
542
+ rb_raise(rb_eArgError, "Invalid LIBSVM parameter is given: %s", err_msg);
543
+ return Qnil;
544
+ }
545
+
546
+ size_t t_shape[1] = {(size_t)(problem->l)};
547
+ VALUE t_val = rb_narray_new(numo_cDFloat, 1, t_shape);
548
+ double* t_pt = (double*)na_get_pointer_for_write(t_val);
549
+
550
+ VALUE verbose = rb_hash_aref(param_hash, ID2SYM(rb_intern("verbose")));
551
+ if (!RTEST(verbose)) svm_set_print_string_function(printNull);
552
+
553
+ const int n_folds = NUM2INT(nr_folds);
554
+ svm_cross_validation(problem, param, n_folds, t_pt);
555
+
556
+ deleteLibSvmProblem(problem);
557
+ deleteLibSvmParameter(param);
558
+
559
+ RB_GC_GUARD(x_val);
560
+ RB_GC_GUARD(y_val);
561
+
562
+ return t_val;
563
+ }
564
+
565
+ static VALUE numo_libsvm_predict(VALUE self, VALUE x_val, VALUE param_hash, VALUE model_hash) {
566
+ if (CLASS_OF(x_val) != numo_cDFloat) x_val = rb_funcall(numo_cDFloat, rb_intern("cast"), 1, x_val);
567
+ if (!RTEST(nary_check_contiguous(x_val))) x_val = nary_dup(x_val);
568
+
569
+ narray_t* x_nary;
570
+ GetNArray(x_val, x_nary);
571
+ if (NA_NDIM(x_nary) != 2) {
572
+ rb_raise(rb_eArgError, "Expect samples to be 2-D array.");
573
+ return Qnil;
574
+ }
575
+
576
+ LibSvmParameter* param = convertHashToLibSvmParameter(param_hash);
577
+ LibSvmModel* model = convertHashToLibSvmModel(model_hash);
578
+ model->param = *param;
579
+
580
+ const int n_samples = (int)NA_SHAPE(x_nary)[0];
581
+ const int n_features = (int)NA_SHAPE(x_nary)[1];
582
+ size_t y_shape[1] = {(size_t)n_samples};
583
+ VALUE y_val = rb_narray_new(numo_cDFloat, 1, y_shape);
584
+ double* y_ptr = (double*)na_get_pointer_for_write(y_val);
585
+ const double* const x_ptr = (double*)na_get_pointer_for_read(x_val);
586
+ for (int i = 0; i < n_samples; i++) {
587
+ LibSvmNode* x_nodes = convertVectorXdToLibSvmNode(&x_ptr[i * n_features], n_features);
588
+ y_ptr[i] = svm_predict(model, x_nodes);
589
+ xfree(x_nodes);
590
+ }
591
+
592
+ deleteLibSvmModel(model);
593
+ deleteLibSvmParameter(param);
594
+
595
+ RB_GC_GUARD(x_val);
596
+
597
+ return y_val;
598
+ }
599
+
600
+ static VALUE numo_libsvm_decision_function(VALUE self, VALUE x_val, VALUE param_hash, VALUE model_hash) {
601
+ if (CLASS_OF(x_val) != numo_cDFloat) x_val = rb_funcall(numo_cDFloat, rb_intern("cast"), 1, x_val);
602
+ if (!RTEST(nary_check_contiguous(x_val))) x_val = nary_dup(x_val);
603
+
604
+ narray_t* x_nary;
605
+ GetNArray(x_val, x_nary);
606
+ if (NA_NDIM(x_nary) != 2) {
607
+ rb_raise(rb_eArgError, "Expect samples to be 2-D array.");
608
+ return Qnil;
609
+ }
610
+
611
+ LibSvmParameter* param = convertHashToLibSvmParameter(param_hash);
612
+ LibSvmModel* model = convertHashToLibSvmModel(model_hash);
613
+ model->param = *param;
614
+
615
+ const int n_samples = (int)NA_SHAPE(x_nary)[0];
616
+ const int n_features = (int)NA_SHAPE(x_nary)[1];
617
+ const int y_cols = isSignleOutputModel(model) ? 1 : model->nr_class * (model->nr_class - 1) / 2;
618
+ size_t y_shape[2] = {(size_t)n_samples, (size_t)y_cols};
619
+ const int n_dims = isSignleOutputModel(model) ? 1 : 2;
620
+ VALUE y_val = rb_narray_new(numo_cDFloat, n_dims, y_shape);
621
+ const double* const x_ptr = (double*)na_get_pointer_for_read(x_val);
622
+ double* y_ptr = (double*)na_get_pointer_for_write(y_val);
623
+
624
+ for (int i = 0; i < n_samples; i++) {
625
+ LibSvmNode* x_nodes = convertVectorXdToLibSvmNode(&x_ptr[i * n_features], n_features);
626
+ svm_predict_values(model, x_nodes, &y_ptr[i * y_cols]);
627
+ xfree(x_nodes);
628
+ }
629
+
630
+ deleteLibSvmModel(model);
631
+ deleteLibSvmParameter(param);
632
+
633
+ RB_GC_GUARD(x_val);
634
+
635
+ return y_val;
636
+ }
637
+
638
+ static VALUE numo_libsvm_predict_proba(VALUE self, VALUE x_val, VALUE param_hash, VALUE model_hash) {
639
+ narray_t* x_nary;
640
+ GetNArray(x_val, x_nary);
641
+ if (NA_NDIM(x_nary) != 2) {
642
+ rb_raise(rb_eArgError, "Expect samples to be 2-D array.");
643
+ return Qnil;
644
+ }
645
+
646
+ LibSvmParameter* param = convertHashToLibSvmParameter(param_hash);
647
+ LibSvmModel* model = convertHashToLibSvmModel(model_hash);
648
+ model->param = *param;
649
+
650
+ if (!isProbabilisticModel(model)) {
651
+ deleteLibSvmModel(model);
652
+ deleteLibSvmParameter(param);
653
+ return Qnil;
654
+ }
655
+
656
+ if (CLASS_OF(x_val) != numo_cDFloat) x_val = rb_funcall(numo_cDFloat, rb_intern("cast"), 1, x_val);
657
+ if (!RTEST(nary_check_contiguous(x_val))) x_val = nary_dup(x_val);
658
+
659
+ const int n_samples = (int)NA_SHAPE(x_nary)[0];
660
+ const int n_features = (int)NA_SHAPE(x_nary)[1];
661
+ size_t y_shape[2] = {(size_t)n_samples, (size_t)(model->nr_class)};
662
+ VALUE y_val = rb_narray_new(numo_cDFloat, 2, y_shape);
663
+ const double* const x_ptr = (double*)na_get_pointer_for_read(x_val);
664
+ double* y_ptr = (double*)na_get_pointer_for_write(y_val);
665
+ for (int i = 0; i < n_samples; i++) {
666
+ LibSvmNode* x_nodes = convertVectorXdToLibSvmNode(&x_ptr[i * n_features], n_features);
667
+ svm_predict_probability(model, x_nodes, &y_ptr[i * model->nr_class]);
668
+ xfree(x_nodes);
669
+ }
670
+
671
+ deleteLibSvmModel(model);
672
+ deleteLibSvmParameter(param);
673
+
674
+ RB_GC_GUARD(x_val);
675
+
676
+ return y_val;
677
+ }
678
+
679
+ static VALUE numo_libsvm_load_model(VALUE self, VALUE filename) {
680
+ const char* const filename_ = StringValuePtr(filename);
681
+ LibSvmModel* model = svm_load_model(filename_);
682
+ if (model == NULL) {
683
+ rb_raise(rb_eIOError, "Failed to load file '%s'", filename_);
684
+ return Qnil;
685
+ }
686
+
687
+ VALUE param_hash = convertLibSvmParameterToHash(&(model->param));
688
+ VALUE model_hash = convertLibSvmModelToHash(model);
689
+ svm_free_and_destroy_model(&model);
690
+
691
+ VALUE res = rb_ary_new2(2);
692
+ rb_ary_store(res, 0, param_hash);
693
+ rb_ary_store(res, 1, model_hash);
694
+
695
+ RB_GC_GUARD(filename);
696
+
697
+ return res;
698
+ }
699
+
700
+ static VALUE numo_libsvm_save_model(VALUE self, VALUE filename, VALUE param_hash, VALUE model_hash) {
701
+ LibSvmParameter* param = convertHashToLibSvmParameter(param_hash);
702
+ LibSvmModel* model = convertHashToLibSvmModel(model_hash);
703
+ model->param = *param;
704
+
705
+ const char* const filename_ = StringValuePtr(filename);
706
+ const int res = svm_save_model(filename_, model);
707
+
708
+ deleteLibSvmModel(model);
709
+ deleteLibSvmParameter(param);
710
+
711
+ if (res < 0) {
712
+ rb_raise(rb_eIOError, "Failed to save file '%s'", filename_);
713
+ return Qfalse;
714
+ }
715
+
716
+ RB_GC_GUARD(filename);
717
+
718
+ return Qtrue;
719
+ }
720
+
721
+ #endif /* LIBSVMEXT_HPP */