numo-libsvm 1.1.2 → 2.0.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,578 +0,0 @@
1
- /**
2
- * LIBSVM interface for Numo::NArray
3
- */
4
- #include "libsvmext.h"
5
-
6
- VALUE mNumo;
7
- VALUE mLibsvm;
8
-
9
- void print_null(const char *s) {}
10
-
11
- /**
12
- * Train the SVM model according to the given training data.
13
- *
14
- * @overload train(x, y, param) -> Hash
15
- * @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to be used for training the model.
16
- * @param y [Numo::DFloat] (shape: [n_samples]) The labels or target values for samples.
17
- * @param param [Hash] The parameters of an SVM model.
18
- *
19
- * @example
20
- * require 'numo/libsvm'
21
- *
22
- * # Prepare XOR data.
23
- * x = Numo::DFloat[[-0.8, -0.7], [0.9, 0.8], [-0.7, 0.9], [0.8, -0.9]]
24
- * y = Numo::Int32[-1, -1, 1, 1]
25
- *
26
- * # Train C-Support Vector Classifier with RBF kernel.
27
- * param = {
28
- * svm_type: Numo::Libsvm::SvmType::C_SVC,
29
- * kernel_type: Numo::Libsvm::KernelType::RBF,
30
- * gamma: 2.0,
31
- * C: 1,
32
- * random_seed: 1
33
- * }
34
- * model = Numo::Libsvm.train(x, y, param)
35
- *
36
- * # Predict labels of test data.
37
- * x_test = Numo::DFloat[[-0.4, -0.5], [0.5, -0.4]]
38
- * result = Numo::Libsvm.predict(x_test, param, model)
39
- * p result
40
- * # Numo::DFloat#shape=[2]
41
- * # [-1, 1]
42
- *
43
- * @raise [ArgumentError] If the sample array is not 2-dimensional, the label array is not 1-dimensional,
44
- * the sample array and label array do not have the same number of samples, or
45
- * the hyperparameter has an invalid value, this error is raised.
46
- * @return [Hash] The model obtained from the training procedure.
47
- */
48
- static
49
- VALUE train(VALUE self, VALUE x_val, VALUE y_val, VALUE param_hash)
50
- {
51
- struct svm_problem* problem;
52
- struct svm_parameter* param;
53
- struct svm_model* model;
54
- narray_t* x_nary;
55
- narray_t* y_nary;
56
- char* err_msg;
57
- VALUE random_seed;
58
- VALUE verbose;
59
- VALUE model_hash;
60
-
61
- if (CLASS_OF(x_val) != numo_cDFloat) {
62
- x_val = rb_funcall(numo_cDFloat, rb_intern("cast"), 1, x_val);
63
- }
64
- if (CLASS_OF(y_val) != numo_cDFloat) {
65
- y_val = rb_funcall(numo_cDFloat, rb_intern("cast"), 1, y_val);
66
- }
67
- if (!RTEST(nary_check_contiguous(x_val))) {
68
- x_val = nary_dup(x_val);
69
- }
70
- if (!RTEST(nary_check_contiguous(y_val))) {
71
- y_val = nary_dup(y_val);
72
- }
73
-
74
- GetNArray(x_val, x_nary);
75
- GetNArray(y_val, y_nary);
76
- if (NA_NDIM(x_nary) != 2) {
77
- rb_raise(rb_eArgError, "Expect samples to be 2-D array.");
78
- return Qnil;
79
- }
80
- if (NA_NDIM(y_nary) != 1) {
81
- rb_raise(rb_eArgError, "Expect label or target values to be 1-D arrray.");
82
- return Qnil;
83
- }
84
- if (NA_SHAPE(x_nary)[0] != NA_SHAPE(y_nary)[0]) {
85
- rb_raise(rb_eArgError, "Expect to have the same number of samples for samples and labels.");
86
- return Qnil;
87
- }
88
-
89
- random_seed = rb_hash_aref(param_hash, ID2SYM(rb_intern("random_seed")));
90
- if (!NIL_P(random_seed)) {
91
- srand(NUM2UINT(random_seed));
92
- }
93
-
94
- param = rb_hash_to_svm_parameter(param_hash);
95
- problem = dataset_to_svm_problem(x_val, y_val);
96
-
97
- err_msg = svm_check_parameter(problem, param);
98
- if (err_msg) {
99
- xfree_svm_problem(problem);
100
- xfree_svm_parameter(param);
101
- rb_raise(rb_eArgError, "Invalid LIBSVM parameter is given: %s", err_msg);
102
- return Qnil;
103
- }
104
-
105
- verbose = rb_hash_aref(param_hash, ID2SYM(rb_intern("verbose")));
106
- if (verbose != Qtrue) {
107
- svm_set_print_string_function(print_null);
108
- }
109
-
110
- model = svm_train(problem, param);
111
- model_hash = svm_model_to_rb_hash(model);
112
- svm_free_and_destroy_model(&model);
113
-
114
- xfree_svm_problem(problem);
115
- xfree_svm_parameter(param);
116
-
117
- RB_GC_GUARD(x_val);
118
- RB_GC_GUARD(y_val);
119
-
120
- return model_hash;
121
- }
122
-
123
- /**
124
- * Perform cross validation under given parameters. The given samples are separated to n_fols folds.
125
- * The predicted labels or values in the validation process are returned.
126
- *
127
- * @overload cv(x, y, param, n_folds) -> Numo::DFloat
128
- * @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to be used for training the model.
129
- * @param y [Numo::DFloat] (shape: [n_samples]) The labels or target values for samples.
130
- * @param param [Hash] The parameters of an SVM model.
131
- * @param n_folds [Integer] The number of folds.
132
- *
133
- * @example
134
- * require 'numo/libsvm'
135
- *
136
- * # x: samples
137
- * # y: labels
138
- *
139
- * # Define parameters of C-SVC with RBF Kernel.
140
- * param = {
141
- * svm_type: Numo::Libsvm::SvmType::C_SVC,
142
- * kernel_type: Numo::Libsvm::KernelType::RBF,
143
- * gamma: 1.0,
144
- * C: 1,
145
- * random_seed: 1,
146
- * verbose: true
147
- * }
148
- *
149
- * # Perform 5-cross validation.
150
- * n_folds = 5
151
- * res = Numo::Libsvm.cv(x, y, param, n_folds)
152
- *
153
- * # Print mean accuracy.
154
- * mean_accuracy = y.eq(res).count.fdiv(y.size)
155
- * puts "Accuracy: %.1f %%" % (100 * mean_accuracy)
156
- *
157
- * @raise [ArgumentError] If the sample array is not 2-dimensional, the label array is not 1-dimensional,
158
- * the sample array and label array do not have the same number of samples, or
159
- * the hyperparameter has an invalid value, this error is raised.
160
- * @return [Numo::DFloat] (shape: [n_samples]) The predicted class label or value of each sample.
161
- */
162
- static
163
- VALUE cross_validation(VALUE self, VALUE x_val, VALUE y_val, VALUE param_hash, VALUE nr_folds)
164
- {
165
- const int n_folds = NUM2INT(nr_folds);
166
- size_t t_shape[1];
167
- VALUE t_val;
168
- double* t_pt;
169
- narray_t* x_nary;
170
- narray_t* y_nary;
171
- char* err_msg;
172
- VALUE random_seed;
173
- VALUE verbose;
174
- struct svm_problem* problem;
175
- struct svm_parameter* param;
176
-
177
- if (CLASS_OF(x_val) != numo_cDFloat) {
178
- x_val = rb_funcall(numo_cDFloat, rb_intern("cast"), 1, x_val);
179
- }
180
- if (CLASS_OF(y_val) != numo_cDFloat) {
181
- y_val = rb_funcall(numo_cDFloat, rb_intern("cast"), 1, y_val);
182
- }
183
- if (!RTEST(nary_check_contiguous(x_val))) {
184
- x_val = nary_dup(x_val);
185
- }
186
- if (!RTEST(nary_check_contiguous(y_val))) {
187
- y_val = nary_dup(y_val);
188
- }
189
-
190
- GetNArray(x_val, x_nary);
191
- GetNArray(y_val, y_nary);
192
- if (NA_NDIM(x_nary) != 2) {
193
- rb_raise(rb_eArgError, "Expect samples to be 2-D array.");
194
- return Qnil;
195
- }
196
- if (NA_NDIM(y_nary) != 1) {
197
- rb_raise(rb_eArgError, "Expect label or target values to be 1-D arrray.");
198
- return Qnil;
199
- }
200
- if (NA_SHAPE(x_nary)[0] != NA_SHAPE(y_nary)[0]) {
201
- rb_raise(rb_eArgError, "Expect to have the same number of samples for samples and labels.");
202
- return Qnil;
203
- }
204
-
205
- random_seed = rb_hash_aref(param_hash, ID2SYM(rb_intern("random_seed")));
206
- if (!NIL_P(random_seed)) {
207
- srand(NUM2UINT(random_seed));
208
- }
209
-
210
- param = rb_hash_to_svm_parameter(param_hash);
211
- problem = dataset_to_svm_problem(x_val, y_val);
212
-
213
- err_msg = svm_check_parameter(problem, param);
214
- if (err_msg) {
215
- xfree_svm_problem(problem);
216
- xfree_svm_parameter(param);
217
- rb_raise(rb_eArgError, "Invalid LIBSVM parameter is given: %s", err_msg);
218
- return Qnil;
219
- }
220
-
221
- t_shape[0] = problem->l;
222
- t_val = rb_narray_new(numo_cDFloat, 1, t_shape);
223
- t_pt = (double*)na_get_pointer_for_write(t_val);
224
-
225
- verbose = rb_hash_aref(param_hash, ID2SYM(rb_intern("verbose")));
226
- if (verbose != Qtrue) {
227
- svm_set_print_string_function(print_null);
228
- }
229
-
230
- svm_cross_validation(problem, param, n_folds, t_pt);
231
-
232
- xfree_svm_problem(problem);
233
- xfree_svm_parameter(param);
234
-
235
- RB_GC_GUARD(x_val);
236
- RB_GC_GUARD(y_val);
237
-
238
- return t_val;
239
- }
240
-
241
- /**
242
- * Predict class labels or values for given samples.
243
- *
244
- * @overload predict(x, param, model) -> Numo::DFloat
245
- * @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to calculate the scores.
246
- * @param param [Hash] The parameters of the trained SVM model.
247
- * @param model [Hash] The model obtained from the training procedure.
248
- *
249
- * @raise [ArgumentError] If the sample array is not 2-dimensional, this error is raised.
250
- * @return [Numo::DFloat] (shape: [n_samples]) The predicted class label or value of each sample.
251
- */
252
- static
253
- VALUE predict(VALUE self, VALUE x_val, VALUE param_hash, VALUE model_hash)
254
- {
255
- struct svm_parameter* param;
256
- struct svm_model* model;
257
- struct svm_node* x_nodes;
258
- narray_t* x_nary;
259
- double* x_pt;
260
- size_t y_shape[1];
261
- VALUE y_val;
262
- double* y_pt;
263
- int i, j, k;
264
- int n_samples;
265
- int n_features;
266
- int n_nonzero_features;
267
-
268
- /* Obtain C data structures. */
269
- if (CLASS_OF(x_val) != numo_cDFloat) {
270
- x_val = rb_funcall(numo_cDFloat, rb_intern("cast"), 1, x_val);
271
- }
272
- if (!RTEST(nary_check_contiguous(x_val))) {
273
- x_val = nary_dup(x_val);
274
- }
275
-
276
- GetNArray(x_val, x_nary);
277
- if (NA_NDIM(x_nary) != 2) {
278
- rb_raise(rb_eArgError, "Expect samples to be 2-D array.");
279
- return Qnil;
280
- }
281
-
282
- param = rb_hash_to_svm_parameter(param_hash);
283
- model = rb_hash_to_svm_model(model_hash);
284
- model->param = *param;
285
-
286
- /* Initialize some variables. */
287
- n_samples = (int)NA_SHAPE(x_nary)[0];
288
- n_features = (int)NA_SHAPE(x_nary)[1];
289
- y_shape[0] = n_samples;
290
- y_val = rb_narray_new(numo_cDFloat, 1, y_shape);
291
- y_pt = (double*)na_get_pointer_for_write(y_val);
292
- x_pt = (double*)na_get_pointer_for_read(x_val);
293
-
294
- /* Predict values. */
295
- for (i = 0; i < n_samples; i++) {
296
- x_nodes = dbl_vec_to_svm_node(&x_pt[i * n_features], n_features);
297
- y_pt[i] = svm_predict(model, x_nodes);
298
- xfree(x_nodes);
299
- }
300
-
301
- xfree_svm_model(model);
302
- xfree_svm_parameter(param);
303
-
304
- RB_GC_GUARD(x_val);
305
-
306
- return y_val;
307
- }
308
-
309
- /**
310
- * Calculate decision values for given samples.
311
- *
312
- * @overload decision_function(x, param, model) -> Numo::DFloat
313
- * @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to calculate the scores.
314
- * @param param [Hash] The parameters of the trained SVM model.
315
- * @param model [Hash] The model obtained from the training procedure.
316
- *
317
- * @raise [ArgumentError] If the sample array is not 2-dimensional, this error is raised.
318
- * @return [Numo::DFloat] (shape: [n_samples, n_classes * (n_classes - 1) / 2]) The decision value of each sample.
319
- */
320
- static
321
- VALUE decision_function(VALUE self, VALUE x_val, VALUE param_hash, VALUE model_hash)
322
- {
323
- struct svm_parameter* param;
324
- struct svm_model* model;
325
- struct svm_node* x_nodes;
326
- narray_t* x_nary;
327
- double* x_pt;
328
- size_t y_shape[2];
329
- VALUE y_val;
330
- double* y_pt;
331
- double* dec_values;
332
- int y_cols;
333
- int i, j;
334
- int n_samples;
335
- int n_features;
336
-
337
- /* Obtain C data structures. */
338
- if (CLASS_OF(x_val) != numo_cDFloat) {
339
- x_val = rb_funcall(numo_cDFloat, rb_intern("cast"), 1, x_val);
340
- }
341
- if (!RTEST(nary_check_contiguous(x_val))) {
342
- x_val = nary_dup(x_val);
343
- }
344
-
345
- GetNArray(x_val, x_nary);
346
- if (NA_NDIM(x_nary) != 2) {
347
- rb_raise(rb_eArgError, "Expect samples to be 2-D array.");
348
- return Qnil;
349
- }
350
-
351
- param = rb_hash_to_svm_parameter(param_hash);
352
- model = rb_hash_to_svm_model(model_hash);
353
- model->param = *param;
354
-
355
- /* Initialize some variables. */
356
- n_samples = (int)NA_SHAPE(x_nary)[0];
357
- n_features = (int)NA_SHAPE(x_nary)[1];
358
-
359
- if (model->param.svm_type == ONE_CLASS || model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) {
360
- y_shape[0] = n_samples;
361
- y_shape[1] = 1;
362
- y_val = rb_narray_new(numo_cDFloat, 1, y_shape);
363
- } else {
364
- y_shape[0] = n_samples;
365
- y_shape[1] = model->nr_class * (model->nr_class - 1) / 2;
366
- y_val = rb_narray_new(numo_cDFloat, 2, y_shape);
367
- }
368
-
369
- x_pt = (double*)na_get_pointer_for_read(x_val);
370
- y_pt = (double*)na_get_pointer_for_write(y_val);
371
-
372
- /* Predict values. */
373
- if (model->param.svm_type == ONE_CLASS || model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) {
374
- for (i = 0; i < n_samples; i++) {
375
- x_nodes = dbl_vec_to_svm_node(&x_pt[i * n_features], n_features);
376
- svm_predict_values(model, x_nodes, &y_pt[i]);
377
- xfree(x_nodes);
378
- }
379
- } else {
380
- y_cols = (int)y_shape[1];
381
- dec_values = ALLOC_N(double, y_cols);
382
- for (i = 0; i < n_samples; i++) {
383
- x_nodes = dbl_vec_to_svm_node(&x_pt[i * n_features], n_features);
384
- svm_predict_values(model, x_nodes, dec_values);
385
- xfree(x_nodes);
386
- for (j = 0; j < y_cols; j++) {
387
- y_pt[i * y_cols + j] = dec_values[j];
388
- }
389
- }
390
- xfree(dec_values);
391
- }
392
-
393
- xfree_svm_model(model);
394
- xfree_svm_parameter(param);
395
-
396
- RB_GC_GUARD(x_val);
397
-
398
- return y_val;
399
- }
400
-
401
- /**
402
- * Predict class probability for given samples. The model must have probability information calcualted in training procedure.
403
- * The parameter ':probability' set to 1 in training procedure.
404
- *
405
- * @overload predict_proba(x, param, model) -> Numo::DFloat
406
- * @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the class probabilities.
407
- * @param param [Hash] The parameters of the trained SVM model.
408
- * @param model [Hash] The model obtained from the training procedure.
409
- *
410
- * @raise [ArgumentError] If the sample array is not 2-dimensional, this error is raised.
411
- * @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probablity of each class per sample.
412
- */
413
- static
414
- VALUE predict_proba(VALUE self, VALUE x_val, VALUE param_hash, VALUE model_hash)
415
- {
416
- struct svm_parameter* param;
417
- struct svm_model* model;
418
- struct svm_node* x_nodes;
419
- narray_t* x_nary;
420
- double* x_pt;
421
- size_t y_shape[2];
422
- VALUE y_val = Qnil;
423
- double* y_pt;
424
- double* probs;
425
- int i, j;
426
- int n_samples;
427
- int n_features;
428
-
429
- GetNArray(x_val, x_nary);
430
- if (NA_NDIM(x_nary) != 2) {
431
- rb_raise(rb_eArgError, "Expect samples to be 2-D array.");
432
- return Qnil;
433
- }
434
-
435
- param = rb_hash_to_svm_parameter(param_hash);
436
- model = rb_hash_to_svm_model(model_hash);
437
- model->param = *param;
438
-
439
- if ((model->param.svm_type == C_SVC || model->param.svm_type == NU_SVC) && model->probA != NULL && model->probB != NULL) {
440
- /* Obtain C data structures. */
441
- if (CLASS_OF(x_val) != numo_cDFloat) {
442
- x_val = rb_funcall(numo_cDFloat, rb_intern("cast"), 1, x_val);
443
- }
444
- if (!RTEST(nary_check_contiguous(x_val))) {
445
- x_val = nary_dup(x_val);
446
- }
447
-
448
- /* Initialize some variables. */
449
- n_samples = (int)NA_SHAPE(x_nary)[0];
450
- n_features = (int)NA_SHAPE(x_nary)[1];
451
- y_shape[0] = n_samples;
452
- y_shape[1] = model->nr_class;
453
- y_val = rb_narray_new(numo_cDFloat, 2, y_shape);
454
- x_pt = (double*)na_get_pointer_for_read(x_val);
455
- y_pt = (double*)na_get_pointer_for_write(y_val);
456
-
457
- /* Predict values. */
458
- probs = ALLOC_N(double, model->nr_class);
459
- for (i = 0; i < n_samples; i++) {
460
- x_nodes = dbl_vec_to_svm_node(&x_pt[i * n_features], n_features);
461
- svm_predict_probability(model, x_nodes, probs);
462
- xfree(x_nodes);
463
- for (j = 0; j < model->nr_class; j++) {
464
- y_pt[i * model->nr_class + j] = probs[j];
465
- }
466
- }
467
- xfree(probs);
468
- }
469
-
470
- xfree_svm_model(model);
471
- xfree_svm_parameter(param);
472
-
473
- RB_GC_GUARD(x_val);
474
-
475
- return y_val;
476
- }
477
-
478
- /**
479
- * Load the SVM parameters and model from a text file with LIBSVM format.
480
- *
481
- * @param filename [String] The path to a file to load.
482
- * @raise [IOError] This error raises when failed to load the model file.
483
- * @return [Array] Array contains the SVM parameters and model.
484
- */
485
- static
486
- VALUE load_svm_model(VALUE self, VALUE filename)
487
- {
488
- char* filename_ = StringValuePtr(filename);
489
- struct svm_model* model = svm_load_model(filename_);
490
- VALUE res = rb_ary_new2(2);
491
- VALUE param_hash = Qnil;
492
- VALUE model_hash = Qnil;
493
-
494
- if (model == NULL) {
495
- rb_raise(rb_eIOError, "Failed to load file '%s'", filename_);
496
- return Qnil;
497
- }
498
-
499
- if (model) {
500
- param_hash = svm_parameter_to_rb_hash(&(model->param));
501
- model_hash = svm_model_to_rb_hash(model);
502
- svm_free_and_destroy_model(&model);
503
- }
504
-
505
- rb_ary_store(res, 0, param_hash);
506
- rb_ary_store(res, 1, model_hash);
507
-
508
- RB_GC_GUARD(filename);
509
-
510
- return res;
511
- }
512
-
513
- /**
514
- * Save the SVM parameters and model as a text file with LIBSVM format. The saved file can be used with the libsvm tools.
515
- * Note that the svm_save_model saves only the parameters necessary for estimation with the trained model.
516
- *
517
- * @overload save_svm_model(filename, param, model) -> Boolean
518
- * @param filename [String] The path to a file to save.
519
- * @param param [Hash] The parameters of the trained SVM model.
520
- * @param model [Hash] The model obtained from the training procedure.
521
- *
522
- * @raise [IOError] This error raises when failed to save the model file.
523
- * @return [Boolean] true on success, or false if an error occurs.
524
- */
525
- static
526
- VALUE save_svm_model(VALUE self, VALUE filename, VALUE param_hash, VALUE model_hash)
527
- {
528
- char* filename_ = StringValuePtr(filename);
529
- struct svm_parameter* param = rb_hash_to_svm_parameter(param_hash);
530
- struct svm_model* model = rb_hash_to_svm_model(model_hash);
531
- int res;
532
-
533
- model->param = *param;
534
- res = svm_save_model(filename_, model);
535
-
536
- xfree_svm_model(model);
537
- xfree_svm_parameter(param);
538
-
539
- if (res < 0) {
540
- rb_raise(rb_eIOError, "Failed to save file '%s'", filename_);
541
- return Qfalse;
542
- }
543
-
544
- RB_GC_GUARD(filename);
545
-
546
- return Qtrue;
547
- }
548
-
549
- void Init_libsvmext()
550
- {
551
- rb_require("numo/narray");
552
-
553
- /**
554
- * Document-module: Numo
555
- * Numo is the top level namespace of NUmerical MOdules for Ruby.
556
- */
557
- mNumo = rb_define_module("Numo");
558
-
559
- /**
560
- * Document-module: Numo::Libsvm
561
- * Numo::Libsvm is a binding library for LIBSVM that handles dataset with Numo::NArray.
562
- */
563
- mLibsvm = rb_define_module_under(mNumo, "Libsvm");
564
-
565
- /* The version of LIBSVM used in backgroud library. */
566
- rb_define_const(mLibsvm, "LIBSVM_VERSION", INT2NUM(LIBSVM_VERSION));
567
-
568
- rb_define_module_function(mLibsvm, "train", train, 3);
569
- rb_define_module_function(mLibsvm, "cv", cross_validation, 4);
570
- rb_define_module_function(mLibsvm, "predict", predict, 3);
571
- rb_define_module_function(mLibsvm, "decision_function", decision_function, 3);
572
- rb_define_module_function(mLibsvm, "predict_proba", predict_proba, 3);
573
- rb_define_module_function(mLibsvm, "load_svm_model", load_svm_model, 1);
574
- rb_define_module_function(mLibsvm, "save_svm_model", save_svm_model, 3);
575
-
576
- rb_init_svm_type_module();
577
- rb_init_kernel_type_module();
578
- }
@@ -1,18 +0,0 @@
1
- #ifndef NUMO_LIBSVMEXT_H
2
- #define NUMO_LIBSVMEXT_H 1
3
-
4
- #include <math.h>
5
- #include <string.h>
6
- #include <svm.h>
7
- #include <ruby.h>
8
- #include <numo/narray.h>
9
- #include <numo/template.h>
10
-
11
- #include "converter.h"
12
- #include "svm_parameter.h"
13
- #include "svm_model.h"
14
- #include "svm_problem.h"
15
- #include "svm_type.h"
16
- #include "kernel_type.h"
17
-
18
- #endif /* NUMO_LIBSVMEXT_H */
@@ -1,89 +0,0 @@
1
-
2
- #include "svm_model.h"
3
-
4
- struct svm_model* rb_hash_to_svm_model(VALUE model_hash)
5
- {
6
- VALUE el;
7
- struct svm_model* model = ALLOC(struct svm_model);
8
- el = rb_hash_aref(model_hash, ID2SYM(rb_intern("nr_class")));
9
- model->nr_class = el != Qnil ? NUM2INT(el) : 0;
10
- el = rb_hash_aref(model_hash, ID2SYM(rb_intern("l")));
11
- model->l = el != Qnil ? NUM2INT(el) : 0;
12
- el = rb_hash_aref(model_hash, ID2SYM(rb_intern("SV")));
13
- model->SV = nary_to_svm_nodes(el);
14
- el = rb_hash_aref(model_hash, ID2SYM(rb_intern("sv_coef")));
15
- model->sv_coef = nary_to_dbl_mat(el);
16
- el = rb_hash_aref(model_hash, ID2SYM(rb_intern("rho")));
17
- model->rho = nary_to_dbl_vec(el);
18
- el = rb_hash_aref(model_hash, ID2SYM(rb_intern("probA")));
19
- model->probA = nary_to_dbl_vec(el);
20
- el = rb_hash_aref(model_hash, ID2SYM(rb_intern("probB")));
21
- model->probB = nary_to_dbl_vec(el);
22
- el = rb_hash_aref(model_hash, ID2SYM(rb_intern("sv_indices")));
23
- model->sv_indices = nary_to_int_vec(el);
24
- el = rb_hash_aref(model_hash, ID2SYM(rb_intern("label")));
25
- model->label = nary_to_int_vec(el);
26
- el = rb_hash_aref(model_hash, ID2SYM(rb_intern("nSV")));
27
- model->nSV = nary_to_int_vec(el);
28
- el = rb_hash_aref(model_hash, ID2SYM(rb_intern("free_sv")));
29
- model->free_sv = el != Qnil ? NUM2INT(el) : 0;
30
- return model;
31
- }
32
-
33
- VALUE svm_model_to_rb_hash(struct svm_model* const model)
34
- {
35
- int const n_classes = model->nr_class;
36
- int const n_support_vecs = model->l;
37
- VALUE support_vecs = model->SV ? svm_nodes_to_nary(model->SV, n_support_vecs) : Qnil;
38
- VALUE coefficients = model->sv_coef ? dbl_mat_to_nary(model->sv_coef, n_classes - 1, n_support_vecs) : Qnil;
39
- VALUE intercepts = model->rho ? dbl_vec_to_nary(model->rho, n_classes * (n_classes - 1) / 2) : Qnil;
40
- VALUE prob_alpha = model->probA ? dbl_vec_to_nary(model->probA, n_classes * (n_classes - 1) / 2) : Qnil;
41
- VALUE prob_beta = model->probB ? dbl_vec_to_nary(model->probB, n_classes * (n_classes - 1) / 2) : Qnil;
42
- VALUE sv_indices = model->sv_indices ? int_vec_to_nary(model->sv_indices, n_support_vecs) : Qnil;
43
- VALUE labels = model->label ? int_vec_to_nary(model->label, n_classes) : Qnil;
44
- VALUE n_support_vecs_each_class = model->nSV ? int_vec_to_nary(model->nSV, n_classes) : Qnil;
45
- VALUE model_hash = rb_hash_new();
46
- rb_hash_aset(model_hash, ID2SYM(rb_intern("nr_class")), INT2NUM(n_classes));
47
- rb_hash_aset(model_hash, ID2SYM(rb_intern("l")), INT2NUM(n_support_vecs));
48
- rb_hash_aset(model_hash, ID2SYM(rb_intern("SV")), support_vecs);
49
- rb_hash_aset(model_hash, ID2SYM(rb_intern("sv_coef")), coefficients);
50
- rb_hash_aset(model_hash, ID2SYM(rb_intern("rho")), intercepts);
51
- rb_hash_aset(model_hash, ID2SYM(rb_intern("probA")), prob_alpha);
52
- rb_hash_aset(model_hash, ID2SYM(rb_intern("probB")), prob_beta);
53
- rb_hash_aset(model_hash, ID2SYM(rb_intern("sv_indices")), sv_indices);
54
- rb_hash_aset(model_hash, ID2SYM(rb_intern("label")), labels);
55
- rb_hash_aset(model_hash, ID2SYM(rb_intern("nSV")), n_support_vecs_each_class);
56
- rb_hash_aset(model_hash, ID2SYM(rb_intern("free_sv")), INT2NUM(model->free_sv));
57
- return model_hash;
58
- }
59
-
60
- void xfree_svm_model(struct svm_model* model)
61
- {
62
- int i;
63
- if (model) {
64
- if (model->SV) {
65
- for (i = 0; i < model->l; xfree(model->SV[i++]));
66
- xfree(model->SV);
67
- model->SV = NULL;
68
- }
69
- if (model->sv_coef) {
70
- for (i = 0; i < model->nr_class - 1; xfree(model->sv_coef[i++]));
71
- xfree(model->sv_coef);
72
- model->sv_coef = NULL;
73
- }
74
- xfree(model->rho);
75
- model->rho = NULL;
76
- xfree(model->probA);
77
- model->probA = NULL;
78
- xfree(model->probB);
79
- model->probB = NULL;
80
- xfree(model->sv_indices);
81
- model->sv_indices = NULL;
82
- xfree(model->label);
83
- model->label = NULL;
84
- xfree(model->nSV);
85
- model->nSV = NULL;
86
- xfree(model);
87
- model = NULL;
88
- }
89
- }
@@ -1,15 +0,0 @@
1
- #ifndef NUMO_LIBSVM_SVM_MODEL_H
2
- #define NUMO_LIBSVM_SVM_MODEL_H 1
3
-
4
- #include <svm.h>
5
- #include <ruby.h>
6
- #include <numo/narray.h>
7
- #include <numo/template.h>
8
-
9
- #include "converter.h"
10
-
11
- struct svm_model* rb_hash_to_svm_model(VALUE model_hash);
12
- VALUE svm_model_to_rb_hash(struct svm_model* const model);
13
- void xfree_svm_model(struct svm_model* model);
14
-
15
- #endif /* NUMO_LIBSVM_SVM_MODEL_H */