numo-libsvm 1.1.2 → 2.0.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,715 @@
1
+ /**
2
+ * Copyright (c) 2019-2022 Atsushi Tatsuma
3
+ * All rights reserved.
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * * Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * * Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * * Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ */
30
+
31
+ #ifndef LIBSVMEXT_HPP
32
+ #define LIBSVMEXT_HPP 1
33
+
34
+ #include <cmath>
35
+ #include <cstring>
36
+
37
+ #include <ruby.h>
38
+
39
+ #include <numo/narray.h>
40
+ #include <numo/template.h>
41
+
42
+ #include <svm.h>
43
+
44
+ typedef struct svm_model LibSvmModel;
45
+ typedef struct svm_node LibSvmNode;
46
+ typedef struct svm_parameter LibSvmParameter;
47
+ typedef struct svm_problem LibSvmProblem;
48
+
49
/** Print callback that discards its message; installed to silence LIBSVM output. */
void printNull(const char* s) { (void)s; }
50
+
51
+ /** CONVERTERS */
52
+ VALUE convertVectorXiToNArray(const int* const arr, const int size) {
53
+ size_t shape[1] = {(size_t)size};
54
+ VALUE vec_val = rb_narray_new(numo_cInt32, 1, shape);
55
+ int32_t* vec_ptr = (int32_t*)na_get_pointer_for_write(vec_val);
56
+ for (int i = 0; i < size; i++) vec_ptr[i] = (int32_t)arr[i];
57
+ return vec_val;
58
+ }
59
+
60
+ int* convertNArrayToVectorXi(VALUE vec_val) {
61
+ if (NIL_P(vec_val)) return NULL;
62
+
63
+ narray_t* vec_nary;
64
+ GetNArray(vec_val, vec_nary);
65
+ const size_t n_elements = NA_SHAPE(vec_nary)[0];
66
+ int* arr = ALLOC_N(int, n_elements);
67
+ const int32_t* const vec_ptr = (int32_t*)na_get_pointer_for_read(vec_val);
68
+ for (size_t i = 0; i < n_elements; i++) arr[i] = (int)vec_ptr[i];
69
+
70
+ RB_GC_GUARD(vec_val);
71
+
72
+ return arr;
73
+ }
74
+
75
+ VALUE convertVectorXdToNArray(const double* const arr, const int size) {
76
+ size_t shape[1] = {(size_t)size};
77
+ VALUE vec_val = rb_narray_new(numo_cDFloat, 1, shape);
78
+ double* vec_ptr = (double*)na_get_pointer_for_write(vec_val);
79
+ for (int i = 0; i < size; i++) vec_ptr[i] = arr[i];
80
+ return vec_val;
81
+ }
82
+
83
+ double* convertNArrayToVectorXd(VALUE vec_val) {
84
+ if (NIL_P(vec_val)) return NULL;
85
+
86
+ narray_t* vec_nary;
87
+ GetNArray(vec_val, vec_nary);
88
+ const size_t n_elements = NA_SHAPE(vec_nary)[0];
89
+ double* arr = ALLOC_N(double, n_elements);
90
+ const double* const vec_ptr = (double*)na_get_pointer_for_read(vec_val);
91
+ memcpy(arr, vec_ptr, n_elements * sizeof(double));
92
+
93
+ RB_GC_GUARD(vec_val);
94
+
95
+ return arr;
96
+ }
97
+
98
+ VALUE convertMatrixXdToNArray(const double* const* mat, const int n_rows, const int n_cols) {
99
+ size_t shape[2] = {(size_t)n_rows, (size_t)n_cols};
100
+ VALUE mat_val = rb_narray_new(numo_cDFloat, 2, shape);
101
+ double* mat_ptr = (double*)na_get_pointer_for_write(mat_val);
102
+ for (int i = 0; i < n_rows; i++) memcpy(&mat_ptr[i * n_cols], mat[i], n_cols * sizeof(double));
103
+ return mat_val;
104
+ }
105
+
106
+ double** convertNArrayToMatrixXd(VALUE mat_val) {
107
+ if (NIL_P(mat_val)) return NULL;
108
+
109
+ narray_t* mat_nary;
110
+ GetNArray(mat_val, mat_nary);
111
+ const size_t n_rows = NA_SHAPE(mat_nary)[0];
112
+ const size_t n_cols = NA_SHAPE(mat_nary)[1];
113
+ const double* const mat_ptr = (double*)na_get_pointer_for_read(mat_val);
114
+ double** mat = ALLOC_N(double*, n_rows);
115
+ for (size_t i = 0; i < n_rows; i++) {
116
+ mat[i] = ALLOC_N(double, n_cols);
117
+ memcpy(mat[i], &mat_ptr[i * n_cols], n_cols * sizeof(double));
118
+ }
119
+
120
+ RB_GC_GUARD(mat_val);
121
+
122
+ return mat;
123
+ }
124
+
125
+ VALUE convertLibSvmNodeToNArray(const LibSvmNode* const* support_vecs, const int n_support_vecs) {
126
+ int n_dimensions = 0;
127
+ for (int i = 0; i < n_support_vecs; i++) {
128
+ for (int j = 0; support_vecs[i][j].index != -1; j++) {
129
+ if (n_dimensions < support_vecs[i][j].index) {
130
+ n_dimensions = support_vecs[i][j].index;
131
+ }
132
+ }
133
+ }
134
+
135
+ size_t shape[2] = {(size_t)n_support_vecs, (size_t)n_dimensions};
136
+ VALUE vec_val = rb_narray_new(numo_cDFloat, 2, shape);
137
+ double* vec_ptr = (double*)na_get_pointer_for_write(vec_val);
138
+ memset(vec_ptr, 0, n_support_vecs * n_dimensions * sizeof(double));
139
+ for (int i = 0; i < n_support_vecs; i++) {
140
+ for (int j = 0; support_vecs[i][j].index != -1; j++) {
141
+ vec_ptr[i * n_dimensions + support_vecs[i][j].index - 1] = support_vecs[i][j].value;
142
+ }
143
+ }
144
+
145
+ return vec_val;
146
+ }
147
+
148
+ LibSvmNode** convertNArrayToLibSvmNode(VALUE vec_val) {
149
+ if (NIL_P(vec_val)) return NULL;
150
+
151
+ narray_t* vec_nary;
152
+ GetNArray(vec_val, vec_nary);
153
+ const size_t n_rows = NA_SHAPE(vec_nary)[0];
154
+ const size_t n_cols = NA_SHAPE(vec_nary)[1];
155
+ const double* const vec_ptr = (double*)na_get_pointer_for_read(vec_val);
156
+ LibSvmNode** support_vecs = ALLOC_N(LibSvmNode*, n_rows);
157
+ for (size_t i = 0; i < n_rows; i++) {
158
+ int n_nonzero_cols = 0;
159
+ for (size_t j = 0; j < n_cols; j++) {
160
+ if (vec_ptr[i * n_cols + j] != 0) {
161
+ n_nonzero_cols++;
162
+ }
163
+ }
164
+ support_vecs[i] = ALLOC_N(LibSvmNode, n_nonzero_cols + 1);
165
+ for (size_t j = 0, k = 0; j < n_cols; j++) {
166
+ if (vec_ptr[i * n_cols + j] != 0) {
167
+ support_vecs[i][k].index = j + 1;
168
+ support_vecs[i][k].value = vec_ptr[i * n_cols + j];
169
+ k++;
170
+ }
171
+ }
172
+ support_vecs[i][n_nonzero_cols].index = -1;
173
+ support_vecs[i][n_nonzero_cols].value = 0.0;
174
+ }
175
+
176
+ RB_GC_GUARD(vec_val);
177
+
178
+ return support_vecs;
179
+ }
180
+
181
+ LibSvmNode* convertVectorXdToLibSvmNode(const double* const arr, const int size) {
182
+ int n_nonzero_elements = 0;
183
+ for (int i = 0; i < size; i++) {
184
+ if (arr[i] != 0.0) n_nonzero_elements++;
185
+ }
186
+
187
+ LibSvmNode* node = ALLOC_N(LibSvmNode, n_nonzero_elements + 1);
188
+ for (int i = 0, j = 0; i < size; i++) {
189
+ if (arr[i] != 0.0) {
190
+ node[j].index = i + 1;
191
+ node[j].value = arr[i];
192
+ j++;
193
+ }
194
+ }
195
+ node[n_nonzero_elements].index = -1;
196
+ node[n_nonzero_elements].value = 0.0;
197
+
198
+ return node;
199
+ }
200
+
201
+ LibSvmModel* convertHashToLibSvmModel(VALUE model_hash) {
202
+ LibSvmModel* model = ALLOC(LibSvmModel);
203
+ VALUE el;
204
+ el = rb_hash_aref(model_hash, ID2SYM(rb_intern("nr_class")));
205
+ model->nr_class = !NIL_P(el) ? NUM2INT(el) : 0;
206
+ el = rb_hash_aref(model_hash, ID2SYM(rb_intern("l")));
207
+ model->l = !NIL_P(el) ? NUM2INT(el) : 0;
208
+ el = rb_hash_aref(model_hash, ID2SYM(rb_intern("SV")));
209
+ model->SV = convertNArrayToLibSvmNode(el);
210
+ el = rb_hash_aref(model_hash, ID2SYM(rb_intern("sv_coef")));
211
+ model->sv_coef = convertNArrayToMatrixXd(el);
212
+ el = rb_hash_aref(model_hash, ID2SYM(rb_intern("rho")));
213
+ model->rho = convertNArrayToVectorXd(el);
214
+ el = rb_hash_aref(model_hash, ID2SYM(rb_intern("probA")));
215
+ model->probA = convertNArrayToVectorXd(el);
216
+ el = rb_hash_aref(model_hash, ID2SYM(rb_intern("probB")));
217
+ model->probB = convertNArrayToVectorXd(el);
218
+ el = rb_hash_aref(model_hash, ID2SYM(rb_intern("sv_indices")));
219
+ model->sv_indices = convertNArrayToVectorXi(el);
220
+ el = rb_hash_aref(model_hash, ID2SYM(rb_intern("label")));
221
+ model->label = convertNArrayToVectorXi(el);
222
+ el = rb_hash_aref(model_hash, ID2SYM(rb_intern("nSV")));
223
+ model->nSV = convertNArrayToVectorXi(el);
224
+ el = rb_hash_aref(model_hash, ID2SYM(rb_intern("free_sv")));
225
+ model->free_sv = !NIL_P(el) ? NUM2INT(el) : 0;
226
+ return model;
227
+ }
228
+
229
+ VALUE convertLibSvmModelToHash(const LibSvmModel* const model) {
230
+ const int n_classes = model->nr_class;
231
+ const int n_support_vecs = model->l;
232
+ VALUE support_vecs = model->SV ? convertLibSvmNodeToNArray(model->SV, n_support_vecs) : Qnil;
233
+ VALUE coefficients = model->sv_coef ? convertMatrixXdToNArray(model->sv_coef, n_classes - 1, n_support_vecs) : Qnil;
234
+ VALUE intercepts = model->rho ? convertVectorXdToNArray(model->rho, n_classes * (n_classes - 1) / 2) : Qnil;
235
+ VALUE prob_alpha = model->probA ? convertVectorXdToNArray(model->probA, n_classes * (n_classes - 1) / 2) : Qnil;
236
+ VALUE prob_beta = model->probB ? convertVectorXdToNArray(model->probB, n_classes * (n_classes - 1) / 2) : Qnil;
237
+ VALUE sv_indices = model->sv_indices ? convertVectorXiToNArray(model->sv_indices, n_support_vecs) : Qnil;
238
+ VALUE labels = model->label ? convertVectorXiToNArray(model->label, n_classes) : Qnil;
239
+ VALUE n_support_vecs_each_class = model->nSV ? convertVectorXiToNArray(model->nSV, n_classes) : Qnil;
240
+ VALUE model_hash = rb_hash_new();
241
+ rb_hash_aset(model_hash, ID2SYM(rb_intern("nr_class")), INT2NUM(n_classes));
242
+ rb_hash_aset(model_hash, ID2SYM(rb_intern("l")), INT2NUM(n_support_vecs));
243
+ rb_hash_aset(model_hash, ID2SYM(rb_intern("SV")), support_vecs);
244
+ rb_hash_aset(model_hash, ID2SYM(rb_intern("sv_coef")), coefficients);
245
+ rb_hash_aset(model_hash, ID2SYM(rb_intern("rho")), intercepts);
246
+ rb_hash_aset(model_hash, ID2SYM(rb_intern("probA")), prob_alpha);
247
+ rb_hash_aset(model_hash, ID2SYM(rb_intern("probB")), prob_beta);
248
+ rb_hash_aset(model_hash, ID2SYM(rb_intern("sv_indices")), sv_indices);
249
+ rb_hash_aset(model_hash, ID2SYM(rb_intern("label")), labels);
250
+ rb_hash_aset(model_hash, ID2SYM(rb_intern("nSV")), n_support_vecs_each_class);
251
+ rb_hash_aset(model_hash, ID2SYM(rb_intern("free_sv")), INT2NUM(model->free_sv));
252
+ return model_hash;
253
+ }
254
+
255
+ LibSvmParameter* convertHashToLibSvmParameter(VALUE param_hash) {
256
+ LibSvmParameter* param = ALLOC(LibSvmParameter);
257
+ VALUE el;
258
+ el = rb_hash_aref(param_hash, ID2SYM(rb_intern("svm_type")));
259
+ param->svm_type = !NIL_P(el) ? NUM2INT(el) : C_SVC;
260
+ el = rb_hash_aref(param_hash, ID2SYM(rb_intern("kernel_type")));
261
+ param->kernel_type = !NIL_P(el) ? NUM2INT(el) : RBF;
262
+ el = rb_hash_aref(param_hash, ID2SYM(rb_intern("degree")));
263
+ param->degree = !NIL_P(el) ? NUM2INT(el) : 3;
264
+ el = rb_hash_aref(param_hash, ID2SYM(rb_intern("gamma")));
265
+ param->gamma = !NIL_P(el) ? NUM2DBL(el) : 1;
266
+ el = rb_hash_aref(param_hash, ID2SYM(rb_intern("coef0")));
267
+ param->coef0 = !NIL_P(el) ? NUM2DBL(el) : 0;
268
+ el = rb_hash_aref(param_hash, ID2SYM(rb_intern("cache_size")));
269
+ param->cache_size = !NIL_P(el) ? NUM2DBL(el) : 100;
270
+ el = rb_hash_aref(param_hash, ID2SYM(rb_intern("eps")));
271
+ param->eps = !NIL_P(el) ? NUM2DBL(el) : 1e-3;
272
+ el = rb_hash_aref(param_hash, ID2SYM(rb_intern("C")));
273
+ param->C = !NIL_P(el) ? NUM2DBL(el) : 1;
274
+ el = rb_hash_aref(param_hash, ID2SYM(rb_intern("nr_weight")));
275
+ param->nr_weight = !NIL_P(el) ? NUM2INT(el) : 0;
276
+ el = rb_hash_aref(param_hash, ID2SYM(rb_intern("nu")));
277
+ param->nu = !NIL_P(el) ? NUM2DBL(el) : 0.5;
278
+ el = rb_hash_aref(param_hash, ID2SYM(rb_intern("p")));
279
+ param->p = !NIL_P(el) ? NUM2DBL(el) : 0.1;
280
+ el = rb_hash_aref(param_hash, ID2SYM(rb_intern("shrinking")));
281
+ param->shrinking = RB_TYPE_P(el, T_FALSE) ? 0 : 1;
282
+ el = rb_hash_aref(param_hash, ID2SYM(rb_intern("probability")));
283
+ param->probability = RB_TYPE_P(el, T_TRUE) ? 1 : 0;
284
+ el = rb_hash_aref(param_hash, ID2SYM(rb_intern("weight_label")));
285
+ param->weight_label = NULL;
286
+ if (!NIL_P(el)) {
287
+ param->weight_label = ALLOC_N(int, param->nr_weight);
288
+ memcpy(param->weight_label, (int32_t*)na_get_pointer_for_read(el), param->nr_weight * sizeof(int32_t));
289
+ }
290
+ el = rb_hash_aref(param_hash, ID2SYM(rb_intern("weight")));
291
+ param->weight = NULL;
292
+ if (!NIL_P(el)) {
293
+ param->weight = ALLOC_N(double, param->nr_weight);
294
+ memcpy(param->weight, (double*)na_get_pointer_for_read(el), param->nr_weight * sizeof(double));
295
+ }
296
+ return param;
297
+ }
298
+
299
+ VALUE convertLibSvmParameterToHash(const LibSvmParameter* const param) {
300
+ VALUE param_hash = rb_hash_new();
301
+ rb_hash_aset(param_hash, ID2SYM(rb_intern("svm_type")), INT2NUM(param->svm_type));
302
+ rb_hash_aset(param_hash, ID2SYM(rb_intern("kernel_type")), INT2NUM(param->kernel_type));
303
+ rb_hash_aset(param_hash, ID2SYM(rb_intern("degree")), INT2NUM(param->degree));
304
+ rb_hash_aset(param_hash, ID2SYM(rb_intern("gamma")), DBL2NUM(param->gamma));
305
+ rb_hash_aset(param_hash, ID2SYM(rb_intern("coef0")), DBL2NUM(param->coef0));
306
+ rb_hash_aset(param_hash, ID2SYM(rb_intern("cache_size")), DBL2NUM(param->cache_size));
307
+ rb_hash_aset(param_hash, ID2SYM(rb_intern("eps")), DBL2NUM(param->eps));
308
+ rb_hash_aset(param_hash, ID2SYM(rb_intern("C")), DBL2NUM(param->C));
309
+ rb_hash_aset(param_hash, ID2SYM(rb_intern("nr_weight")), INT2NUM(param->nr_weight));
310
+ rb_hash_aset(param_hash, ID2SYM(rb_intern("nu")), DBL2NUM(param->nu));
311
+ rb_hash_aset(param_hash, ID2SYM(rb_intern("p")), DBL2NUM(param->p));
312
+ rb_hash_aset(param_hash, ID2SYM(rb_intern("shrinking")), param->shrinking == 1 ? Qtrue : Qfalse);
313
+ rb_hash_aset(param_hash, ID2SYM(rb_intern("probability")), param->probability == 1 ? Qtrue : Qfalse);
314
+ rb_hash_aset(param_hash, ID2SYM(rb_intern("weight_label")),
315
+ param->weight_label ? convertVectorXiToNArray(param->weight_label, param->nr_weight) : Qnil);
316
+ rb_hash_aset(param_hash, ID2SYM(rb_intern("weight")),
317
+ param->weight ? convertVectorXdToNArray(param->weight, param->nr_weight) : Qnil);
318
+ return param_hash;
319
+ }
320
+
321
+ LibSvmProblem* convertDatasetToLibSvmProblem(VALUE x_val, VALUE y_val) {
322
+ narray_t* x_nary;
323
+ GetNArray(x_val, x_nary);
324
+ const int n_samples = (int)NA_SHAPE(x_nary)[0];
325
+ const int n_features = (int)NA_SHAPE(x_nary)[1];
326
+ const double* const x_ptr = (double*)na_get_pointer_for_read(x_val);
327
+ const double* const y_ptr = (double*)na_get_pointer_for_read(y_val);
328
+
329
+ LibSvmProblem* problem = ALLOC(LibSvmProblem);
330
+ problem->l = n_samples;
331
+ problem->x = ALLOC_N(LibSvmNode*, n_samples);
332
+ problem->y = ALLOC_N(double, n_samples);
333
+
334
+ int last_feature_id = 0;
335
+ bool is_padded = false;
336
+ for (int i = 0; i < n_samples; i++) {
337
+ int n_nonzero_features = 0;
338
+ for (int j = 0; j < n_features; j++) {
339
+ if (x_ptr[i * n_features + j] != 0.0) {
340
+ n_nonzero_features += 1;
341
+ last_feature_id = j + 1;
342
+ }
343
+ }
344
+ if (!is_padded && last_feature_id == n_features) is_padded = true;
345
+ if (is_padded) {
346
+ problem->x[i] = ALLOC_N(LibSvmNode, n_nonzero_features + 1);
347
+ } else {
348
+ problem->x[i] = ALLOC_N(LibSvmNode, n_nonzero_features + 2);
349
+ }
350
+ for (int j = 0, k = 0; j < n_features; j++) {
351
+ if (x_ptr[i * n_features + j] != 0.0) {
352
+ problem->x[i][k].index = j + 1;
353
+ problem->x[i][k].value = x_ptr[i * n_features + j];
354
+ k++;
355
+ }
356
+ }
357
+ if (is_padded) {
358
+ problem->x[i][n_nonzero_features].index = -1;
359
+ problem->x[i][n_nonzero_features].value = 0.0;
360
+ } else {
361
+ problem->x[i][n_nonzero_features].index = n_features;
362
+ problem->x[i][n_nonzero_features].value = 0.0;
363
+ problem->x[i][n_nonzero_features + 1].index = -1;
364
+ problem->x[i][n_nonzero_features + 1].value = 0.0;
365
+ }
366
+ problem->y[i] = y_ptr[i];
367
+ }
368
+
369
+ RB_GC_GUARD(x_val);
370
+ RB_GC_GUARD(y_val);
371
+
372
+ return problem;
373
+ }
374
+
375
+ /** UTILITIES */
376
+ bool isSignleOutputModel(LibSvmModel* model) {
377
+ return (model->param.svm_type == ONE_CLASS || model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR);
378
+ }
379
+
380
+ bool isProbabilisticModel(LibSvmModel* model) {
381
+ return ((model->param.svm_type == C_SVC || model->param.svm_type == NU_SVC) && model->probA != NULL && model->probB != NULL);
382
+ }
383
+
384
+ void deleteLibSvmModel(LibSvmModel* model) {
385
+ if (model) {
386
+ if (model->SV) {
387
+ for (int i = 0; i < model->l; i++) xfree(model->SV[i]);
388
+ xfree(model->SV);
389
+ model->SV = NULL;
390
+ }
391
+ if (model->sv_coef) {
392
+ for (int i = 0; i < model->nr_class - 1; i++) xfree(model->sv_coef[i]);
393
+ xfree(model->sv_coef);
394
+ model->sv_coef = NULL;
395
+ }
396
+ xfree(model->rho);
397
+ model->rho = NULL;
398
+ xfree(model->probA);
399
+ model->probA = NULL;
400
+ xfree(model->probB);
401
+ model->probB = NULL;
402
+ xfree(model->sv_indices);
403
+ model->sv_indices = NULL;
404
+ xfree(model->label);
405
+ model->label = NULL;
406
+ xfree(model->nSV);
407
+ model->nSV = NULL;
408
+ xfree(model);
409
+ model = NULL;
410
+ }
411
+ }
412
+
413
+ void deleteLibSvmParameter(LibSvmParameter* param) {
414
+ if (param) {
415
+ if (param->weight_label) {
416
+ xfree(param->weight_label);
417
+ param->weight_label = NULL;
418
+ }
419
+ if (param->weight) {
420
+ xfree(param->weight);
421
+ param->weight = NULL;
422
+ }
423
+ xfree(param);
424
+ param = NULL;
425
+ }
426
+ }
427
+
428
+ void deleteLibSvmProblem(LibSvmProblem* problem) {
429
+ if (problem) {
430
+ if (problem->x) {
431
+ for (int i = 0; i < problem->l; i++) {
432
+ if (problem->x[i]) {
433
+ xfree(problem->x[i]);
434
+ problem->x[i] = NULL;
435
+ }
436
+ }
437
+ xfree(problem->x);
438
+ problem->x = NULL;
439
+ }
440
+ if (problem->y) {
441
+ xfree(problem->y);
442
+ problem->y = NULL;
443
+ }
444
+ xfree(problem);
445
+ problem = NULL;
446
+ }
447
+ }
448
+
449
+ /** MODULE FUNCTIONS */
450
+ static VALUE numo_libsvm_train(VALUE self, VALUE x_val, VALUE y_val, VALUE param_hash) {
451
+ if (CLASS_OF(x_val) != numo_cDFloat) x_val = rb_funcall(numo_cDFloat, rb_intern("cast"), 1, x_val);
452
+ if (CLASS_OF(y_val) != numo_cDFloat) y_val = rb_funcall(numo_cDFloat, rb_intern("cast"), 1, y_val);
453
+ if (!RTEST(nary_check_contiguous(x_val))) x_val = nary_dup(x_val);
454
+ if (!RTEST(nary_check_contiguous(y_val))) y_val = nary_dup(y_val);
455
+
456
+ narray_t* x_nary;
457
+ narray_t* y_nary;
458
+ GetNArray(x_val, x_nary);
459
+ GetNArray(y_val, y_nary);
460
+ if (NA_NDIM(x_nary) != 2) {
461
+ rb_raise(rb_eArgError, "Expect samples to be 2-D array.");
462
+ return Qnil;
463
+ }
464
+ if (NA_NDIM(y_nary) != 1) {
465
+ rb_raise(rb_eArgError, "Expect label or target values to be 1-D arrray.");
466
+ return Qnil;
467
+ }
468
+ if (NA_SHAPE(x_nary)[0] != NA_SHAPE(y_nary)[0]) {
469
+ rb_raise(rb_eArgError, "Expect to have the same number of samples for samples and labels.");
470
+ return Qnil;
471
+ }
472
+
473
+ VALUE random_seed = rb_hash_aref(param_hash, ID2SYM(rb_intern("random_seed")));
474
+ if (!NIL_P(random_seed)) srand(NUM2UINT(random_seed));
475
+
476
+ LibSvmParameter* param = convertHashToLibSvmParameter(param_hash);
477
+ LibSvmProblem* problem = convertDatasetToLibSvmProblem(x_val, y_val);
478
+
479
+ const char* err_msg = svm_check_parameter(problem, param);
480
+ if (err_msg) {
481
+ deleteLibSvmProblem(problem);
482
+ deleteLibSvmParameter(param);
483
+ rb_raise(rb_eArgError, "Invalid LIBSVM parameter is given: %s", err_msg);
484
+ return Qnil;
485
+ }
486
+
487
+ VALUE verbose = rb_hash_aref(param_hash, ID2SYM(rb_intern("verbose")));
488
+ if (!RTEST(verbose)) svm_set_print_string_function(printNull);
489
+
490
+ LibSvmModel* model = svm_train(problem, param);
491
+ VALUE model_hash = convertLibSvmModelToHash(model);
492
+ svm_free_and_destroy_model(&model);
493
+
494
+ deleteLibSvmProblem(problem);
495
+ deleteLibSvmParameter(param);
496
+
497
+ RB_GC_GUARD(x_val);
498
+ RB_GC_GUARD(y_val);
499
+
500
+ return model_hash;
501
+ }
502
+
503
+ static VALUE numo_libsvm_cross_validation(VALUE self, VALUE x_val, VALUE y_val, VALUE param_hash, VALUE nr_folds) {
504
+ if (CLASS_OF(x_val) != numo_cDFloat) x_val = rb_funcall(numo_cDFloat, rb_intern("cast"), 1, x_val);
505
+ if (CLASS_OF(y_val) != numo_cDFloat) y_val = rb_funcall(numo_cDFloat, rb_intern("cast"), 1, y_val);
506
+ if (!RTEST(nary_check_contiguous(x_val))) x_val = nary_dup(x_val);
507
+ if (!RTEST(nary_check_contiguous(y_val))) y_val = nary_dup(y_val);
508
+
509
+ narray_t* x_nary;
510
+ narray_t* y_nary;
511
+ GetNArray(x_val, x_nary);
512
+ GetNArray(y_val, y_nary);
513
+ if (NA_NDIM(x_nary) != 2) {
514
+ rb_raise(rb_eArgError, "Expect samples to be 2-D array.");
515
+ return Qnil;
516
+ }
517
+ if (NA_NDIM(y_nary) != 1) {
518
+ rb_raise(rb_eArgError, "Expect label or target values to be 1-D arrray.");
519
+ return Qnil;
520
+ }
521
+ if (NA_SHAPE(x_nary)[0] != NA_SHAPE(y_nary)[0]) {
522
+ rb_raise(rb_eArgError, "Expect to have the same number of samples for samples and labels.");
523
+ return Qnil;
524
+ }
525
+
526
+ VALUE random_seed = rb_hash_aref(param_hash, ID2SYM(rb_intern("random_seed")));
527
+ if (!NIL_P(random_seed)) srand(NUM2UINT(random_seed));
528
+
529
+ LibSvmParameter* param = convertHashToLibSvmParameter(param_hash);
530
+ LibSvmProblem* problem = convertDatasetToLibSvmProblem(x_val, y_val);
531
+
532
+ const char* err_msg = svm_check_parameter(problem, param);
533
+ if (err_msg) {
534
+ deleteLibSvmProblem(problem);
535
+ deleteLibSvmParameter(param);
536
+ rb_raise(rb_eArgError, "Invalid LIBSVM parameter is given: %s", err_msg);
537
+ return Qnil;
538
+ }
539
+
540
+ size_t t_shape[1] = {(size_t)(problem->l)};
541
+ VALUE t_val = rb_narray_new(numo_cDFloat, 1, t_shape);
542
+ double* t_pt = (double*)na_get_pointer_for_write(t_val);
543
+
544
+ VALUE verbose = rb_hash_aref(param_hash, ID2SYM(rb_intern("verbose")));
545
+ if (!RTEST(verbose)) svm_set_print_string_function(printNull);
546
+
547
+ const int n_folds = NUM2INT(nr_folds);
548
+ svm_cross_validation(problem, param, n_folds, t_pt);
549
+
550
+ deleteLibSvmProblem(problem);
551
+ deleteLibSvmParameter(param);
552
+
553
+ RB_GC_GUARD(x_val);
554
+ RB_GC_GUARD(y_val);
555
+
556
+ return t_val;
557
+ }
558
+
559
+ static VALUE numo_libsvm_predict(VALUE self, VALUE x_val, VALUE param_hash, VALUE model_hash) {
560
+ if (CLASS_OF(x_val) != numo_cDFloat) x_val = rb_funcall(numo_cDFloat, rb_intern("cast"), 1, x_val);
561
+ if (!RTEST(nary_check_contiguous(x_val))) x_val = nary_dup(x_val);
562
+
563
+ narray_t* x_nary;
564
+ GetNArray(x_val, x_nary);
565
+ if (NA_NDIM(x_nary) != 2) {
566
+ rb_raise(rb_eArgError, "Expect samples to be 2-D array.");
567
+ return Qnil;
568
+ }
569
+
570
+ LibSvmParameter* param = convertHashToLibSvmParameter(param_hash);
571
+ LibSvmModel* model = convertHashToLibSvmModel(model_hash);
572
+ model->param = *param;
573
+
574
+ const int n_samples = (int)NA_SHAPE(x_nary)[0];
575
+ const int n_features = (int)NA_SHAPE(x_nary)[1];
576
+ size_t y_shape[1] = {(size_t)n_samples};
577
+ VALUE y_val = rb_narray_new(numo_cDFloat, 1, y_shape);
578
+ double* y_ptr = (double*)na_get_pointer_for_write(y_val);
579
+ const double* const x_ptr = (double*)na_get_pointer_for_read(x_val);
580
+ for (int i = 0; i < n_samples; i++) {
581
+ LibSvmNode* x_nodes = convertVectorXdToLibSvmNode(&x_ptr[i * n_features], n_features);
582
+ y_ptr[i] = svm_predict(model, x_nodes);
583
+ xfree(x_nodes);
584
+ }
585
+
586
+ deleteLibSvmModel(model);
587
+ deleteLibSvmParameter(param);
588
+
589
+ RB_GC_GUARD(x_val);
590
+
591
+ return y_val;
592
+ }
593
+
594
+ static VALUE numo_libsvm_decision_function(VALUE self, VALUE x_val, VALUE param_hash, VALUE model_hash) {
595
+ if (CLASS_OF(x_val) != numo_cDFloat) x_val = rb_funcall(numo_cDFloat, rb_intern("cast"), 1, x_val);
596
+ if (!RTEST(nary_check_contiguous(x_val))) x_val = nary_dup(x_val);
597
+
598
+ narray_t* x_nary;
599
+ GetNArray(x_val, x_nary);
600
+ if (NA_NDIM(x_nary) != 2) {
601
+ rb_raise(rb_eArgError, "Expect samples to be 2-D array.");
602
+ return Qnil;
603
+ }
604
+
605
+ LibSvmParameter* param = convertHashToLibSvmParameter(param_hash);
606
+ LibSvmModel* model = convertHashToLibSvmModel(model_hash);
607
+ model->param = *param;
608
+
609
+ const int n_samples = (int)NA_SHAPE(x_nary)[0];
610
+ const int n_features = (int)NA_SHAPE(x_nary)[1];
611
+ const int y_cols = isSignleOutputModel(model) ? 1 : model->nr_class * (model->nr_class - 1) / 2;
612
+ size_t y_shape[2] = {(size_t)n_samples, (size_t)y_cols};
613
+ const int n_dims = isSignleOutputModel(model) ? 1 : 2;
614
+ VALUE y_val = rb_narray_new(numo_cDFloat, n_dims, y_shape);
615
+ const double* const x_ptr = (double*)na_get_pointer_for_read(x_val);
616
+ double* y_ptr = (double*)na_get_pointer_for_write(y_val);
617
+
618
+ for (int i = 0; i < n_samples; i++) {
619
+ LibSvmNode* x_nodes = convertVectorXdToLibSvmNode(&x_ptr[i * n_features], n_features);
620
+ svm_predict_values(model, x_nodes, &y_ptr[i * y_cols]);
621
+ xfree(x_nodes);
622
+ }
623
+
624
+ deleteLibSvmModel(model);
625
+ deleteLibSvmParameter(param);
626
+
627
+ RB_GC_GUARD(x_val);
628
+
629
+ return y_val;
630
+ }
631
+
632
+ static VALUE numo_libsvm_predict_proba(VALUE self, VALUE x_val, VALUE param_hash, VALUE model_hash) {
633
+ narray_t* x_nary;
634
+ GetNArray(x_val, x_nary);
635
+ if (NA_NDIM(x_nary) != 2) {
636
+ rb_raise(rb_eArgError, "Expect samples to be 2-D array.");
637
+ return Qnil;
638
+ }
639
+
640
+ LibSvmParameter* param = convertHashToLibSvmParameter(param_hash);
641
+ LibSvmModel* model = convertHashToLibSvmModel(model_hash);
642
+ model->param = *param;
643
+
644
+ if (!isProbabilisticModel(model)) {
645
+ deleteLibSvmModel(model);
646
+ deleteLibSvmParameter(param);
647
+ return Qnil;
648
+ }
649
+
650
+ if (CLASS_OF(x_val) != numo_cDFloat) x_val = rb_funcall(numo_cDFloat, rb_intern("cast"), 1, x_val);
651
+ if (!RTEST(nary_check_contiguous(x_val))) x_val = nary_dup(x_val);
652
+
653
+ const int n_samples = (int)NA_SHAPE(x_nary)[0];
654
+ const int n_features = (int)NA_SHAPE(x_nary)[1];
655
+ size_t y_shape[2] = {(size_t)n_samples, (size_t)(model->nr_class)};
656
+ VALUE y_val = rb_narray_new(numo_cDFloat, 2, y_shape);
657
+ const double* const x_ptr = (double*)na_get_pointer_for_read(x_val);
658
+ double* y_ptr = (double*)na_get_pointer_for_write(y_val);
659
+ for (int i = 0; i < n_samples; i++) {
660
+ LibSvmNode* x_nodes = convertVectorXdToLibSvmNode(&x_ptr[i * n_features], n_features);
661
+ svm_predict_probability(model, x_nodes, &y_ptr[i * model->nr_class]);
662
+ xfree(x_nodes);
663
+ }
664
+
665
+ deleteLibSvmModel(model);
666
+ deleteLibSvmParameter(param);
667
+
668
+ RB_GC_GUARD(x_val);
669
+
670
+ return y_val;
671
+ }
672
+
673
+ static VALUE numo_libsvm_load_model(VALUE self, VALUE filename) {
674
+ const char* const filename_ = StringValuePtr(filename);
675
+ LibSvmModel* model = svm_load_model(filename_);
676
+ if (model == NULL) {
677
+ rb_raise(rb_eIOError, "Failed to load file '%s'", filename_);
678
+ return Qnil;
679
+ }
680
+
681
+ VALUE param_hash = convertLibSvmParameterToHash(&(model->param));
682
+ VALUE model_hash = convertLibSvmModelToHash(model);
683
+ svm_free_and_destroy_model(&model);
684
+
685
+ VALUE res = rb_ary_new2(2);
686
+ rb_ary_store(res, 0, param_hash);
687
+ rb_ary_store(res, 1, model_hash);
688
+
689
+ RB_GC_GUARD(filename);
690
+
691
+ return res;
692
+ }
693
+
694
+ static VALUE numo_libsvm_save_model(VALUE self, VALUE filename, VALUE param_hash, VALUE model_hash) {
695
+ LibSvmParameter* param = convertHashToLibSvmParameter(param_hash);
696
+ LibSvmModel* model = convertHashToLibSvmModel(model_hash);
697
+ model->param = *param;
698
+
699
+ const char* const filename_ = StringValuePtr(filename);
700
+ const int res = svm_save_model(filename_, model);
701
+
702
+ deleteLibSvmModel(model);
703
+ deleteLibSvmParameter(param);
704
+
705
+ if (res < 0) {
706
+ rb_raise(rb_eIOError, "Failed to save file '%s'", filename_);
707
+ return Qfalse;
708
+ }
709
+
710
+ RB_GC_GUARD(filename);
711
+
712
+ return Qtrue;
713
+ }
714
+
715
+ #endif /* LIBSVMEXT_HPP */