svmredlight 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
data/ext/svmredlight.c ADDED
@@ -0,0 +1,762 @@
1
+ #include "ruby.h"
2
+ #include "svm_light/svm_common.h"
3
+ #include "string.h"
4
+
5
+ /* Helper function to determine if a model uses linear kernel, this could be a #define
6
+ * macro */
7
+ int
8
+ is_linear(MODEL *model){
9
+ return model->kernel_parm.kernel_type == 0;
10
+ }
11
+
12
+ // Modules and Classes
13
+ static VALUE rb_mSvmLight;
14
+ static VALUE rb_cModel;
15
+ static VALUE rb_cDocument;
16
+
17
+ // GC functions
18
+
19
+ /* Not using deep free anymore, let ruby call free on the documents otherwise we might end
20
+ * up having double free problems, from svm_learn_main: Warning: The model contains
21
+ * references to the original data 'docs'. If you want to free the original data, and
22
+ * only keep the model, you have to make a deep copy of 'model'.
23
+ * deep_copy_of_model=copy_model(model); */
24
+ void
25
+ model_free(MODEL *m){
26
+ if(m)
27
+ free_model(m, 0);
28
+ }
29
+
30
+ void
31
+ doc_free(DOC *d){
32
+ if(d)
33
+ free_example(d, 1);
34
+ }
35
+
36
+ /* Read a svm_light model from a file generated by svm_learn receives the filename as
37
+ * argument do make sure the file exists before calling this! otherwise exit(1) might be
38
+ * called and the ruby interpreter will die.*/
39
+ static VALUE
40
+ model_read_from_file(VALUE klass, VALUE filename){
41
+ Check_Type(filename, T_STRING);
42
+ MODEL *m;
43
+
44
+ m = read_model(StringValuePtr(filename));
45
+
46
+ if(is_linear(m))
47
+ add_weight_vector_to_linear_model(m);
48
+
49
+ return Data_Wrap_Struct(klass, 0, model_free, m);
50
+ }
51
+
52
+ /* Helper function type checks a string meant to be used as a learn_parm, in case of error
53
+ * returns 1 and sets the correct exception message in error, on success returns 0 and
54
+ * copies the c string data of new_val to target*/
55
+ int check_string_param(VALUE new_val,
56
+ const char *default_val,
57
+ char *target,
58
+ const char *name,
59
+ char *error){
60
+
61
+ if(TYPE(new_val) == T_STRING){
62
+ strlcpy(target, StringValuePtr(new_val), 199);
63
+ }else if(NIL_P(new_val)){
64
+ strlcpy(target, default_val, 199);
65
+ }else{
66
+ sprintf(error, "The value of the learning option '%s' must be a string", name);
67
+ return 1;
68
+ }
69
+
70
+ return 0;
71
+ }
72
+
73
+ /* Helper function type checks a long meant to be used as a learn_parm or kernel_parm, in
74
+ * case of error returns 1 and sets the correct exception message in error, on success
75
+ * returns 0 and copies the c string data of new_val to target*/
76
+ int check_long_param(VALUE new_val,
77
+ long default_val,
78
+ long *target,
79
+ const char *name,
80
+ char *error){
81
+ if(TYPE(new_val) == T_FLOAT || TYPE(new_val) == T_FIXNUM){
82
+ *target = FIX2LONG(new_val);
83
+ }else if(NIL_P(new_val)){
84
+ *target = default_val;
85
+ }else{
86
+ sprintf(error, "The value of the learning option '%s' must be a numeric", name);
87
+ return 1;
88
+ }
89
+
90
+ return 0;
91
+ }
92
+
93
+ /* Helper function type checks a double meant to be used as a learn_parm or kernel_parm, in
94
+ * case of error returns 1 and sets the correct exception message in error, on success
95
+ * returns 0 and copies the c string data of new_val to target*/
96
+ int check_double_param(VALUE new_val,
97
+ double default_val,
98
+ double *target,
99
+ const char *name,
100
+ char *error){
101
+ if(TYPE(new_val) == T_FLOAT || TYPE(new_val) == T_FIXNUM){
102
+ *target = NUM2DBL(new_val);
103
+ }else if(NIL_P(new_val) ){
104
+ *target = default_val;
105
+ }else{
106
+ sprintf(error, "The value of the learning option '%s' must be a numeric", name);
107
+ return 1;
108
+ }
109
+
110
+ return 0;
111
+ }
112
+
113
+ /* Helper function type checks an int meant to be used as a boolean learn_parm or
114
+ * kernel_parm, in case of error returns 1 and sets the correct exception message in
115
+ * error, on success returns 0 and copies the c string data of new_val to target*/
116
+ int check_bool_param(VALUE new_val,
117
+ long default_val,
118
+ long *target,
119
+ const char *name,
120
+ char *error){
121
+ if(TYPE(new_val) == T_TRUE){
122
+ *target = 1L;
123
+ }else if(TYPE(new_val) == T_FALSE){
124
+ *target = 0L;
125
+ }else if(NIL_P(new_val) ){
126
+ *target = default_val;
127
+ }else{
128
+ sprintf(error, "The value of the learning option '%s' must be a true or false", name);
129
+ return 1;
130
+ }
131
+
132
+ return 0;
133
+ }
134
+
135
+ /* Helper function in charge of setting up the learn parameters before they are passed to
136
+ * the svm_learn_classification copies part of the logic in svm_learn_main.c */
137
+ int setup_learn_params(LEARN_PARM *c_learn_param, VALUE r_hash, char *error_message){
138
+ // Defaults taken from from svm_learn_main
139
+ VALUE inter_val, temp_ary, svm_type, svm_type_ruby_str;
140
+ char *svm_type_str;
141
+
142
+ inter_val = rb_hash_aref(r_hash, rb_str_new2("predfile"));
143
+ if(1 == check_string_param(inter_val,
144
+ "trans_predictions",
145
+ &c_learn_param->predfile,
146
+ "predfile",
147
+ error_message)){
148
+ return 1;
149
+ }
150
+
151
+ inter_val = rb_hash_aref(r_hash, rb_str_new2("alphafile"));
152
+ if(1 == check_string_param(inter_val,
153
+ "",
154
+ &c_learn_param->alphafile,
155
+ "alphafile",
156
+ error_message)){
157
+ return 1;
158
+ }
159
+
160
+ inter_val = rb_hash_aref(r_hash, rb_str_new2("biased_hyperplane"));
161
+ if(1 == check_bool_param(inter_val,
162
+ 1L,
163
+ &(c_learn_param->biased_hyperplane),
164
+ "biased_hyperplane",
165
+ error_message)){
166
+ return 1;
167
+ }
168
+
169
+ inter_val = rb_hash_aref(r_hash, rb_str_new2("sharedslack"));
170
+ if(1 == check_bool_param(inter_val,
171
+ 0L,
172
+ &(c_learn_param->sharedslack),
173
+ "sharedslack",
174
+ error_message)){
175
+ return 1;
176
+ }
177
+
178
+ inter_val = rb_hash_aref(r_hash, rb_str_new2("remove_inconsistent"));
179
+ if(1 == check_bool_param(inter_val,
180
+ 0L,
181
+ &(c_learn_param->remove_inconsistent),
182
+ "remove_inconsistent",
183
+ error_message)){
184
+ return 1;
185
+ }
186
+
187
+ inter_val = rb_hash_aref(r_hash, rb_str_new2("skip_final_opt_check"));
188
+ if(1 == check_bool_param(inter_val,
189
+ 0L,
190
+ &(c_learn_param->skip_final_opt_check),
191
+ "skip_final_opt_check",
192
+ error_message)){
193
+ return 1;
194
+ }
195
+
196
+ inter_val = rb_hash_aref(r_hash, rb_str_new2("svm_newvarsinqp"));
197
+ if(1 == check_bool_param(inter_val,
198
+ 0L,
199
+ &(c_learn_param->svm_newvarsinqp),
200
+ "svm_newvarsinqp",
201
+ error_message)){
202
+ return 1;
203
+ }
204
+
205
+ inter_val = rb_hash_aref(r_hash, rb_str_new2("compute_loo"));
206
+ if(1 == check_bool_param(inter_val,
207
+ 0L,
208
+ &(c_learn_param->compute_loo),
209
+ "compute_loo",
210
+ error_message)){
211
+ return 1;
212
+ }
213
+
214
+
215
+ inter_val = rb_hash_aref(r_hash, rb_str_new2("svm_maxqpsize"));
216
+ if(1 == check_long_param(inter_val,
217
+ 10L,
218
+ &(c_learn_param->svm_maxqpsize),
219
+ "svm_maxqpsize",
220
+ error_message)){
221
+ return 1;
222
+ }
223
+
224
+ inter_val = rb_hash_aref(r_hash, rb_str_new2("svm_iter_to_shrink"));
225
+ if(1 == check_long_param(inter_val,
226
+ -9999,
227
+ &(c_learn_param->svm_iter_to_shrink),
228
+ "svm_iter_to_shrink",
229
+ error_message)){
230
+ return 1;
231
+ }
232
+
233
+ inter_val = rb_hash_aref(r_hash, rb_str_new2("maxiter"));
234
+ if(1 == check_long_param(inter_val,
235
+ 100000,
236
+ &(c_learn_param->maxiter),
237
+ "maxiter",
238
+ error_message)){
239
+ return 1;
240
+ }
241
+
242
+ inter_val = rb_hash_aref(r_hash, rb_str_new2("kernel_cache_size"));
243
+ if(1 == check_long_param(inter_val,
244
+ 40L,
245
+ &(c_learn_param->kernel_cache_size),
246
+ "kernel_cache_size",
247
+ error_message)){
248
+ return 1;
249
+ }
250
+
251
+ inter_val = rb_hash_aref(r_hash, rb_str_new2("xa_depth"));
252
+ if(1 == check_long_param(inter_val,
253
+ 0L,
254
+ &(c_learn_param->xa_depth),
255
+ "xa_depth",
256
+ error_message)){
257
+ return 1;
258
+ }
259
+
260
+ inter_val = rb_hash_aref(r_hash, rb_str_new2("svm_c"));
261
+ if(1 == check_double_param(inter_val,
262
+ 0.0,
263
+ &(c_learn_param->svm_c),
264
+ "svm_c",
265
+ error_message)){
266
+ return 1;
267
+ }
268
+
269
+ inter_val = rb_hash_aref(r_hash, rb_str_new2("eps"));
270
+ if(1 == check_double_param(inter_val,
271
+ 0.1,
272
+ &(c_learn_param->eps),
273
+ "eps",
274
+ error_message)){
275
+ return 1;
276
+ }
277
+
278
+ inter_val = rb_hash_aref(r_hash, rb_str_new2("transduction_posratio"));
279
+ if(1 == check_double_param(inter_val,
280
+ -1.0,
281
+ &(c_learn_param->transduction_posratio),
282
+ "transduction_posratio",
283
+ error_message)){
284
+ return 1;
285
+ }
286
+
287
+ inter_val = rb_hash_aref(r_hash, rb_str_new2("svm_costratio"));
288
+ if(1 == check_double_param(inter_val,
289
+ 1.0,
290
+ &(c_learn_param->svm_costratio),
291
+ "svm_costratio",
292
+ error_message)){
293
+ return 1;
294
+ }
295
+
296
+ inter_val = rb_hash_aref(r_hash, rb_str_new2("svm_costratio_unlab"));
297
+ if(1 == check_double_param(inter_val,
298
+ 1.0,
299
+ &(c_learn_param->svm_costratio_unlab),
300
+ "svm_costratio_unlab",
301
+ error_message)){
302
+ return 1;
303
+ }
304
+
305
+ inter_val = rb_hash_aref(r_hash, rb_str_new2("svm_unlabbound"));
306
+ if(1 == check_double_param(inter_val,
307
+ 1.0000000000000001e-05,
308
+ &(c_learn_param->svm_unlabbound),
309
+ "svm_unlabbound",
310
+ error_message)){
311
+ return 1;
312
+ }
313
+
314
+ inter_val = rb_hash_aref(r_hash, rb_str_new2("epsilon_crit"));
315
+ if(1 == check_double_param(inter_val,
316
+ 0.001,
317
+ &(c_learn_param->epsilon_crit),
318
+ "epsilon_crit",
319
+ error_message)){
320
+ return 1;
321
+ }
322
+
323
+ inter_val = rb_hash_aref(r_hash, rb_str_new2("epsilon_a"));
324
+ if(1 == check_double_param(inter_val,
325
+ 1E-15,
326
+ &(c_learn_param->epsilon_a),
327
+ "epsilon_a",
328
+ error_message)){
329
+ return 1;
330
+ }
331
+
332
+ c_learn_param->rho=1.0;
333
+ inter_val = rb_hash_aref(r_hash, rb_str_new2("rho"));
334
+ if(1 == check_double_param(inter_val,
335
+ 1.0,
336
+ &(c_learn_param->rho),
337
+ "rho",
338
+ error_message)){
339
+ return 1;
340
+ }
341
+
342
+
343
+ return 0;
344
+ }
345
+
346
+ int setup_kernel_params(KERNEL_PARM *c_kernel_param, VALUE r_hash, char *error_message){
347
+ VALUE inter_val;
348
+ inter_val = rb_hash_aref(r_hash, rb_str_new2("poly_degree"));
349
+ if(1 == check_long_param(inter_val,
350
+ 3L,
351
+ &(c_kernel_param->poly_degree),
352
+ "poly_degree",
353
+ error_message)){
354
+ return 1;
355
+ }
356
+
357
+ inter_val = rb_hash_aref(r_hash, rb_str_new2("rbf_gamma"));
358
+ if(1 == check_double_param(inter_val,
359
+ 1.0,
360
+ &(c_kernel_param->rbf_gamma),
361
+ "rbf_gamma",
362
+ error_message)){
363
+ return 1;
364
+ }
365
+
366
+ inter_val = rb_hash_aref(r_hash, rb_str_new2("coef_lin"));
367
+ if(1 == check_double_param(inter_val,
368
+ 1.0,
369
+ &(c_kernel_param->coef_lin),
370
+ "coef_lin",
371
+ error_message)){
372
+ return 1;
373
+ }
374
+
375
+ inter_val = rb_hash_aref(r_hash, rb_str_new2("coef_const"));
376
+ if(1 == check_double_param(inter_val,
377
+ 1.0,
378
+ &(c_kernel_param->coef_const),
379
+ "coef_const",
380
+ error_message)){
381
+ return 1;
382
+ }
383
+
384
+ // No support for custom kernel yet just set it to empty
385
+ strlcpy(c_kernel_param->custom,"empty", 49);
386
+ return 0;
387
+ }
388
+
389
+ /* Do logic checks for the learn and kernel params, this logic is copied from
390
+ * svm_learn_main.c */
391
+ int check_kernel_and_learn_params_logic(KERNEL_PARM *c_kernel_param,
392
+ LEARN_PARM *c_learn_param, char *error_msg){
393
+
394
+ if(c_learn_param->svm_iter_to_shrink == -9999) {
395
+ if(c_kernel_param->kernel_type == LINEAR)
396
+ c_learn_param->svm_iter_to_shrink=2;
397
+ else
398
+ c_learn_param->svm_iter_to_shrink=100;
399
+ }
400
+
401
+ //It does not make sense to skip the final optimality check for linear kernels.
402
+ if((c_learn_param->skip_final_opt_check)
403
+ && (c_kernel_param->kernel_type == LINEAR)) {
404
+ c_learn_param->skip_final_opt_check=0;
405
+ }
406
+
407
+ if((c_learn_param->skip_final_opt_check)
408
+ && (c_learn_param->remove_inconsistent)) {
409
+ strncpy(error_msg,"It is necessary to do the final optimality "
410
+ "check when removing inconsistent examples.", 300);
411
+ return 1;
412
+ }
413
+
414
+ if((c_learn_param->svm_maxqpsize<2)) {
415
+ snprintf(error_msg, 300, "Maximum size of QP-subproblems "
416
+ "not in valid range: %ld [2..]",c_learn_param->svm_maxqpsize);
417
+ return 1;
418
+ }
419
+
420
+ if((c_learn_param->svm_maxqpsize<c_learn_param->svm_newvarsinqp)) {
421
+ snprintf(error_msg, 300, "Maximum size of QP-subproblems [%ld] must be larger than the number of",
422
+ "new variables [%ld] entering the working set in each iteration.\n",c_learn_param->svm_maxqpsize
423
+ ,c_learn_param->svm_newvarsinqp);
424
+ return 1;
425
+ }
426
+
427
+ if(c_learn_param->svm_iter_to_shrink<1) {
428
+ snprintf(error_msg, 300, "Maximum number of iterations for shrinking not"
429
+ " in valid range: %ld [1,..]",c_learn_param->svm_iter_to_shrink);
430
+ return 1;
431
+ }
432
+
433
+ if(c_learn_param->svm_c<0) {
434
+ strncpy(error_msg,"The C parameter must be greater than zero", 300);
435
+ return 1;
436
+ }
437
+
438
+ if(c_learn_param->transduction_posratio>1) {
439
+ strncpy(error_msg,"The fraction of unlabeled examples to classify as positives must"
440
+ "be less than 1.0", 300);
441
+ return 1;
442
+ }
443
+
444
+ if(c_learn_param->svm_costratio<=0) {
445
+ strncpy(error_msg,"The COSTRATIO parameter must be greater than zero", 300);
446
+ return 1;
447
+ }
448
+
449
+ if(c_learn_param->epsilon_crit<=0) {
450
+ strncpy(error_msg,"The epsilon parameter must be greater than zero", 300);
451
+ return 1;
452
+ }
453
+
454
+ if(c_learn_param->rho<0) {
455
+ strncpy(error_msg, "The parameter rho for xi/alpha-estimates and leave-one-out pruning must"
456
+ "be greater than zero (typically 1.0 or 2.0, see T. Joachims, Estimating the"
457
+ "Generalization Performance of an SVM Efficiently, ICML, 2000.)!", 300);
458
+ return 1;
459
+ }
460
+
461
+ if((c_learn_param->xa_depth<0) || (c_learn_param->xa_depth>100)) {
462
+ strncpy(error_msg, "The parameter depth for ext. xi/alpha-estimates must be in [0..100] (zero) "
463
+ "for switching to the conventional xa/estimates described in T. Joachims,"
464
+ "Estimating the Generalization Performance of an SVM Efficiently, ICML, 2000.)", 300);
465
+ return 1;
466
+ }
467
+
468
+ return 0;
469
+ }
470
+
471
+ /* This function will let you train a new SVM model, for now we *only* support
472
+ * classification SVMs and linear kernels, in ruby-land the kernel and learning params
473
+ * will be represented by hashes where the keys are the name of the respective field in
474
+ * the options structure
475
+ *
476
+ * @param [Array] r_docs_and_classes is an array of arrays where each of the inner arrays must have two elements, the first a Document and the second a label (1, -1 ) for classification
477
+ * @param [Hash] learn_params the learning options, each key is the name of a filed in the LEARN_PARM struct
478
+ * @param [Hash] kernel_params the kernel options, each key is the name of a filed in the KERNEL_PARM struct
479
+ * @param [Bool] use_cache, useless for now caches cannot be set for linear kernels
480
+ * @param [Array] alpha, array of alpha values
481
+ * */
482
+ static VALUE
483
+ model_learn_classification(VALUE klass,
484
+ VALUE r_docs_and_classes, // Docs + labels array of arrays
485
+ VALUE learn_params, // Options hash with learning options
486
+ VALUE kernel_params, // Options hash with kernel options
487
+ VALUE use_cache, // If no linear
488
+ VALUE alpha
489
+ ){
490
+ int i;
491
+ double *labels = NULL, *alpha_in = NULL;
492
+ long totdocs, totwords = 0, fnum = 0;
493
+ MODEL *m = NULL;
494
+ DOC **c_docs = NULL;
495
+ LEARN_PARM c_learn_param;
496
+ KERNEL_PARM c_kernel_param;
497
+ VALUE temp_ary, exception = rb_eArgError;
498
+ char error_msg[300];
499
+
500
+ Check_Type(r_docs_and_classes, T_ARRAY);
501
+ Check_Type(learn_params, T_HASH);
502
+ Check_Type(kernel_params, T_HASH);
503
+
504
+ if(!(TYPE(alpha) == T_ARRAY || NIL_P(alpha) ))
505
+ rb_raise(rb_eTypeError, "alpha must be an numeric array or nil");
506
+
507
+ if(TYPE(alpha) == T_ARRAY){
508
+
509
+ alpha_in = my_malloc(sizeof(double) * RARRAY_LEN(alpha));
510
+
511
+ for(i=0; i < RARRAY_LEN(alpha); i++){
512
+
513
+ if(TYPE(RARRAY_PTR(alpha)[i]) != T_FLOAT &&
514
+ TYPE(RARRAY_PTR(alpha)[i]) != T_FIXNUM ){
515
+
516
+ strncpy(error_msg,"All elements of the alpha array must be numeric ", 300);
517
+ goto bail;
518
+ }
519
+
520
+ alpha_in[i] = NUM2DBL(RARRAY_PTR(alpha)[i]);
521
+ }
522
+ }
523
+
524
+ if(setup_learn_params(&c_learn_param, learn_params, error_msg) != 0){
525
+ goto bail;
526
+ }
527
+
528
+ c_learn_param.type = CLASSIFICATION;
529
+
530
+ if(setup_kernel_params(&c_kernel_param, kernel_params, error_msg) != 0){
531
+ goto bail;
532
+ }
533
+
534
+ //TODO Setup kernel cache when we support non linear kernels
535
+ c_kernel_param.kernel_type = LINEAR;
536
+
537
+ if(check_kernel_and_learn_params_logic(&c_kernel_param, &c_learn_param, error_msg) != 0){
538
+ goto bail;
539
+ }
540
+
541
+ totdocs = (long)RARRAY_LEN(r_docs_and_classes);
542
+
543
+ if (totdocs == 0){
544
+ strncpy(error_msg, "Cannot create Model from empty Documents array", 300);
545
+ goto bail;
546
+ }
547
+
548
+ c_docs = (DOC **)my_malloc(sizeof(DOC *)*(totdocs));
549
+ labels = (double*)my_malloc(sizeof(double)*totdocs);
550
+
551
+ for(i=0; i < totdocs; i++){
552
+ // Just one of the documents and classes arrays, we expect temp_ary to have a Document
553
+ // and a label (long)
554
+ temp_ary = RARRAY_PTR(r_docs_and_classes)[i] ;
555
+
556
+ if( TYPE(temp_ary) != T_ARRAY ||
557
+ RARRAY_LEN(temp_ary) < 2 ||
558
+ rb_obj_class(RARRAY_PTR(temp_ary)[0]) != rb_cDocument ||
559
+ (TYPE(RARRAY_PTR(temp_ary)[1]) != T_FLOAT && TYPE(RARRAY_PTR(temp_ary)[1]) != T_FIXNUM )){
560
+
561
+ strncpy(error_msg, "All elements of documents and labels should be arrays,"
562
+ "where the first element is a document and the second a number", 300);
563
+
564
+ goto bail;
565
+ }
566
+
567
+ Data_Get_Struct(RARRAY_PTR(temp_ary)[0], DOC, c_docs[i]);
568
+ labels[i] = NUM2DBL(RARRAY_PTR(temp_ary)[1]);
569
+
570
+ fnum = 0;
571
+
572
+ // Increase feature number while there are still words in the vector
573
+ while(c_docs[i]->fvec->words[fnum].wnum) {
574
+ fnum++;
575
+ }
576
+
577
+ if(c_docs[i]->fvec->words[fnum -1].wnum > totwords)
578
+ totwords = c_docs[i]->fvec->words[fnum-1].wnum;
579
+
580
+ if(totwords > MAXFEATNUM){
581
+ strncpy(error_msg, "The number of features exceeds MAXFEATNUM the maximun "
582
+ "number of features defined for this version of SVMLight", 300);
583
+ goto bail;
584
+ }
585
+ }
586
+
587
+ m = (MODEL *)my_malloc(sizeof(MODEL));
588
+
589
+ svm_learn_classification(c_docs, labels, totdocs, totwords,
590
+ &c_learn_param, &c_kernel_param, NULL, m, alpha_in);
591
+
592
+ free(alpha_in);
593
+ free(labels);
594
+
595
+ // If need arises to free the data do a deep copy of m and create the ruby object with
596
+ // that data.
597
+ // free(c_docs);
598
+ return Data_Wrap_Struct(klass, 0, model_free, m);
599
+
600
+ bail:
601
+ free(alpha_in);
602
+ free(labels);
603
+ free(c_docs);
604
+ rb_raise(exception, error_msg, "%s");
605
+ }
606
+
607
+ /* Classify, takes an example (instance of Document) and returns its classification */
608
+ static VALUE
609
+ model_classify_example(VALUE self, VALUE example){
610
+ DOC *ex;
611
+ MODEL *m;
612
+ double result;
613
+
614
+ Data_Get_Struct(example, DOC, ex);
615
+ Data_Get_Struct(self, MODEL, m);
616
+
617
+ /* Apparently unnecessary code
618
+
619
+ if(is_linear(m))
620
+ result = classify_example_linear(m, ex);
621
+ else
622
+ */
623
+
624
+ result = classify_example(m, ex);
625
+
626
+ return rb_float_new((float)result);
627
+ }
628
+
629
+ static VALUE
630
+ model_support_vectors_count(VALUE self){
631
+ MODEL *m;
632
+ Data_Get_Struct(self, MODEL, m);
633
+
634
+ return INT2FIX(m->sv_num);
635
+ }
636
+
637
+ static VALUE
638
+ model_total_words(VALUE self){
639
+ MODEL *m;
640
+ Data_Get_Struct(self, MODEL, m);
641
+
642
+ return INT2FIX(m->totwords);
643
+ }
644
+
645
+ static VALUE
646
+ model_totdoc(VALUE self){
647
+ MODEL *m;
648
+ Data_Get_Struct(self, MODEL, m);
649
+
650
+ return INT2FIX(m->totdoc);
651
+ }
652
+
653
+ static VALUE
654
+ model_maxdiff(VALUE self){
655
+ MODEL *m;
656
+ Data_Get_Struct(self, MODEL, m);
657
+
658
+ return DBL2NUM(m->maxdiff);
659
+ }
660
+
661
+ /* Creates a DOC from an array of words it also takes an id
662
+ * -1 is normally OK for that value when using in filtering it also takes the C (cost)
663
+ * parameter for the SVM.
664
+ * words_ary an array of arrays like this
665
+ * [[wnum, weight], [wnum, weight], ...] so we do not waste memory, defeating the svec implementation and do
666
+ * not introduce a bunch of 0's that seem to be OK when classifying but screw all up on
667
+ * training*/
668
+ static VALUE
669
+ doc_create(VALUE klass, VALUE id, VALUE cost, VALUE slackid, VALUE queryid, VALUE words_ary ){
670
+ long docnum, i, c_slackid, c_queryid;
671
+ double c;
672
+ WORD *words;
673
+ SVECTOR *vec;
674
+ DOC *d;
675
+ VALUE inner_array;
676
+
677
+ Check_Type(words_ary, T_ARRAY);
678
+ Check_Type(slackid, T_FIXNUM);
679
+ Check_Type(queryid, T_FIXNUM);
680
+
681
+ if (RARRAY_LEN(words_ary) == 0)
682
+ rb_raise(rb_eArgError, "Cannot create Document from empty arrays");
683
+
684
+ words = (WORD*) my_malloc(sizeof(WORD) * (RARRAY_LEN(words_ary) + 1));
685
+
686
+ for(i=0; i < (long)RARRAY_LEN(words_ary); i++){
687
+ inner_array = RARRAY_PTR(words_ary)[i];
688
+ Check_Type(inner_array, T_ARRAY);
689
+ Check_Type(RARRAY_PTR(inner_array)[0], T_FIXNUM);
690
+
691
+ if(!(TYPE(RARRAY_PTR(inner_array)[1]) == T_FLOAT || TYPE(RARRAY_PTR(inner_array)[1]) == T_FIXNUM ))
692
+ rb_raise(rb_eArgError, "Feature weights must be numeric");
693
+
694
+ if(FIX2LONG(RARRAY_PTR(inner_array)[0]) <= 0 )
695
+ rb_raise(rb_eArgError, "Feature number has to be greater than zero");
696
+
697
+ (words[i]).wnum = FIX2LONG(RARRAY_PTR(inner_array)[0]);
698
+ (words[i]).weight = (FVAL)(NUM2DBL(RARRAY_PTR(inner_array)[1]));
699
+ }
700
+ words[i].wnum = 0;
701
+
702
+ vec = create_svector(words, (char*)"", 1.0);
703
+ c = NUM2DBL(cost);
704
+ docnum = FIX2INT(id);
705
+
706
+ d = create_example(docnum, FIX2LONG(queryid), FIX2LONG(slackid), c, vec);
707
+
708
+ return Data_Wrap_Struct(klass, 0, doc_free, d);
709
+ }
710
+
711
+ static VALUE
712
+ doc_get_docnum(VALUE self){
713
+ DOC *d;
714
+ Data_Get_Struct(self, DOC, d);
715
+
716
+ return INT2FIX(d->docnum);
717
+ }
718
+
719
+ static VALUE
720
+ doc_get_slackid(VALUE self){
721
+ DOC *d;
722
+ Data_Get_Struct(self, DOC, d);
723
+
724
+ return INT2FIX(d->slackid);
725
+ }
726
+
727
+ static VALUE
728
+ doc_get_queryid(VALUE self){
729
+ DOC *d;
730
+ Data_Get_Struct(self, DOC, d);
731
+
732
+ return INT2FIX(d->queryid);
733
+ }
734
+
735
+ static VALUE
736
+ doc_get_costfactor(VALUE self){
737
+ DOC *d;
738
+ Data_Get_Struct(self, DOC, d);
739
+
740
+ return DBL2NUM(d->costfactor);
741
+ }
742
+
743
/* Extension entry point, called by Ruby when the library is required.
 * Defines SVMLight::Model and SVMLight::Document and wires the C functions
 * above to their Ruby method names (last argument is the Ruby-side arity). */
void
Init_svmredlight(){
  rb_mSvmLight = rb_define_module("SVMLight");
  //Model
  rb_cModel = rb_define_class_under(rb_mSvmLight, "Model", rb_cObject);
  rb_define_singleton_method(rb_cModel, "read_from_file", model_read_from_file, 1);
  // learn_classification(docs_and_labels, learn_params, kernel_params, use_cache, alpha)
  rb_define_singleton_method(rb_cModel, "learn_classification", model_learn_classification, 5);
  rb_define_method(rb_cModel, "support_vectors_count", model_support_vectors_count, 0);
  rb_define_method(rb_cModel, "total_words", model_total_words, 0);
  rb_define_method(rb_cModel, "classify", model_classify_example, 1);
  rb_define_method(rb_cModel, "totdoc", model_totdoc,0);
  rb_define_method(rb_cModel, "maxdiff", model_maxdiff,0);
  //Document
  rb_cDocument = rb_define_class_under(rb_mSvmLight, "Document", rb_cObject);
  // create(id, cost, slackid, queryid, words_ary)
  rb_define_singleton_method(rb_cDocument, "create", doc_create, 5);
  rb_define_method(rb_cDocument, "docnum", doc_get_docnum, 0);
  rb_define_method(rb_cDocument, "costfactor", doc_get_costfactor, 0);
  rb_define_method(rb_cDocument, "slackid", doc_get_slackid, 0);
  rb_define_method(rb_cDocument, "queryid", doc_get_queryid, 0);
}