svmredlight 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data/.document +5 -0
- data/Gemfile +14 -0
- data/Gemfile.lock +35 -0
- data/LICENSE.txt +20 -0
- data/README.rdoc +44 -0
- data/Rakefile +46 -0
- data/VERSION +1 -0
- data/examples/example1/example.rb +76 -0
- data/examples/example1/test.dat +601 -0
- data/examples/example1/train.dat +2001 -0
- data/examples/example1/words +9947 -0
- data/ext/extconf.rb +6 -0
- data/ext/svmredlight.c +762 -0
- data/lib/svmredlight/document.rb +23 -0
- data/lib/svmredlight/model.rb +22 -0
- data/lib/svmredlight.rb +4 -0
- data/svmredlight.gemspec +73 -0
- data/test/assets/model +3888 -0
- data/test/helper.rb +19 -0
- data/test/test_document.rb +55 -0
- data/test/test_model.rb +114 -0
- metadata +134 -0
data/ext/svmredlight.c
ADDED
|
@@ -0,0 +1,762 @@
|
|
|
1
|
+
#include "ruby.h"
#include "svm_light/svm_common.h"
#include <stdio.h>
#include "string.h"
|
|
4
|
+
|
|
5
|
+
/* Helper function to determine if a model uses linear kernel, this could be a #define
|
|
6
|
+
* macro */
|
|
7
|
+
int
|
|
8
|
+
is_linear(MODEL *model){
|
|
9
|
+
return model->kernel_parm.kernel_type == 0;
|
|
10
|
+
}
|
|
11
|
+
|
|
12
|
+
// Modules and Classes
|
|
13
|
+
static VALUE rb_mSvmLight;
|
|
14
|
+
static VALUE rb_cModel;
|
|
15
|
+
static VALUE rb_cDocument;
|
|
16
|
+
|
|
17
|
+
// GC functions
|
|
18
|
+
|
|
19
|
+
/* Not using deep free anymore, let ruby call free on the documents otherwise we might end
|
|
20
|
+
* up having double free problems, from svm_learn_main: Warning: The model contains
|
|
21
|
+
* references to the original data 'docs'. If you want to free the original data, and
|
|
22
|
+
* only keep the model, you have to make a deep copy of 'model'.
|
|
23
|
+
* deep_copy_of_model=copy_model(model); */
|
|
24
|
+
void
|
|
25
|
+
model_free(MODEL *m){
|
|
26
|
+
if(m)
|
|
27
|
+
free_model(m, 0);
|
|
28
|
+
}
|
|
29
|
+
|
|
30
|
+
void
|
|
31
|
+
doc_free(DOC *d){
|
|
32
|
+
if(d)
|
|
33
|
+
free_example(d, 1);
|
|
34
|
+
}
|
|
35
|
+
|
|
36
|
+
/* Read a svm_light model from a file generated by svm_learn receives the filename as
|
|
37
|
+
* argument do make sure the file exists before calling this! otherwise exit(1) might be
|
|
38
|
+
* called and the ruby interpreter will die.*/
|
|
39
|
+
static VALUE
|
|
40
|
+
model_read_from_file(VALUE klass, VALUE filename){
|
|
41
|
+
Check_Type(filename, T_STRING);
|
|
42
|
+
MODEL *m;
|
|
43
|
+
|
|
44
|
+
m = read_model(StringValuePtr(filename));
|
|
45
|
+
|
|
46
|
+
if(is_linear(m))
|
|
47
|
+
add_weight_vector_to_linear_model(m);
|
|
48
|
+
|
|
49
|
+
return Data_Wrap_Struct(klass, 0, model_free, m);
|
|
50
|
+
}
|
|
51
|
+
|
|
52
|
+
/* Helper function type checks a string meant to be used as a learn_parm, in case of error
|
|
53
|
+
* returns 1 and sets the correct exception message in error, on success returns 0 and
|
|
54
|
+
* copies the c string data of new_val to target*/
|
|
55
|
+
int check_string_param(VALUE new_val,
|
|
56
|
+
const char *default_val,
|
|
57
|
+
char *target,
|
|
58
|
+
const char *name,
|
|
59
|
+
char *error){
|
|
60
|
+
|
|
61
|
+
if(TYPE(new_val) == T_STRING){
|
|
62
|
+
strlcpy(target, StringValuePtr(new_val), 199);
|
|
63
|
+
}else if(NIL_P(new_val)){
|
|
64
|
+
strlcpy(target, default_val, 199);
|
|
65
|
+
}else{
|
|
66
|
+
sprintf(error, "The value of the learning option '%s' must be a string", name);
|
|
67
|
+
return 1;
|
|
68
|
+
}
|
|
69
|
+
|
|
70
|
+
return 0;
|
|
71
|
+
}
|
|
72
|
+
|
|
73
|
+
/* Helper function type checks a long meant to be used as a learn_parm or kernel_parm, in
|
|
74
|
+
* case of error returns 1 and sets the correct exception message in error, on success
|
|
75
|
+
* returns 0 and copies the c string data of new_val to target*/
|
|
76
|
+
int check_long_param(VALUE new_val,
|
|
77
|
+
long default_val,
|
|
78
|
+
long *target,
|
|
79
|
+
const char *name,
|
|
80
|
+
char *error){
|
|
81
|
+
if(TYPE(new_val) == T_FLOAT || TYPE(new_val) == T_FIXNUM){
|
|
82
|
+
*target = FIX2LONG(new_val);
|
|
83
|
+
}else if(NIL_P(new_val)){
|
|
84
|
+
*target = default_val;
|
|
85
|
+
}else{
|
|
86
|
+
sprintf(error, "The value of the learning option '%s' must be a numeric", name);
|
|
87
|
+
return 1;
|
|
88
|
+
}
|
|
89
|
+
|
|
90
|
+
return 0;
|
|
91
|
+
}
|
|
92
|
+
|
|
93
|
+
/* Helper function type checks a double meant to be used as a learn_parm or kernel_parm, in
|
|
94
|
+
* case of error returns 1 and sets the correct exception message in error, on success
|
|
95
|
+
* returns 0 and copies the c string data of new_val to target*/
|
|
96
|
+
int check_double_param(VALUE new_val,
|
|
97
|
+
double default_val,
|
|
98
|
+
double *target,
|
|
99
|
+
const char *name,
|
|
100
|
+
char *error){
|
|
101
|
+
if(TYPE(new_val) == T_FLOAT || TYPE(new_val) == T_FIXNUM){
|
|
102
|
+
*target = NUM2DBL(new_val);
|
|
103
|
+
}else if(NIL_P(new_val) ){
|
|
104
|
+
*target = default_val;
|
|
105
|
+
}else{
|
|
106
|
+
sprintf(error, "The value of the learning option '%s' must be a numeric", name);
|
|
107
|
+
return 1;
|
|
108
|
+
}
|
|
109
|
+
|
|
110
|
+
return 0;
|
|
111
|
+
}
|
|
112
|
+
|
|
113
|
+
/* Helper function type checks an int meant to be used as a boolean learn_parm or
|
|
114
|
+
* kernel_parm, in case of error returns 1 and sets the correct exception message in
|
|
115
|
+
* error, on success returns 0 and copies the c string data of new_val to target*/
|
|
116
|
+
int check_bool_param(VALUE new_val,
|
|
117
|
+
long default_val,
|
|
118
|
+
long *target,
|
|
119
|
+
const char *name,
|
|
120
|
+
char *error){
|
|
121
|
+
if(TYPE(new_val) == T_TRUE){
|
|
122
|
+
*target = 1L;
|
|
123
|
+
}else if(TYPE(new_val) == T_FALSE){
|
|
124
|
+
*target = 0L;
|
|
125
|
+
}else if(NIL_P(new_val) ){
|
|
126
|
+
*target = default_val;
|
|
127
|
+
}else{
|
|
128
|
+
sprintf(error, "The value of the learning option '%s' must be a true or false", name);
|
|
129
|
+
return 1;
|
|
130
|
+
}
|
|
131
|
+
|
|
132
|
+
return 0;
|
|
133
|
+
}
|
|
134
|
+
|
|
135
|
+
/* Helper function in charge of setting up the learn parameters before they are passed to
|
|
136
|
+
* the svm_learn_classification copies part of the logic in svm_learn_main.c */
|
|
137
|
+
int setup_learn_params(LEARN_PARM *c_learn_param, VALUE r_hash, char *error_message){
|
|
138
|
+
// Defaults taken from from svm_learn_main
|
|
139
|
+
VALUE inter_val, temp_ary, svm_type, svm_type_ruby_str;
|
|
140
|
+
char *svm_type_str;
|
|
141
|
+
|
|
142
|
+
inter_val = rb_hash_aref(r_hash, rb_str_new2("predfile"));
|
|
143
|
+
if(1 == check_string_param(inter_val,
|
|
144
|
+
"trans_predictions",
|
|
145
|
+
&c_learn_param->predfile,
|
|
146
|
+
"predfile",
|
|
147
|
+
error_message)){
|
|
148
|
+
return 1;
|
|
149
|
+
}
|
|
150
|
+
|
|
151
|
+
inter_val = rb_hash_aref(r_hash, rb_str_new2("alphafile"));
|
|
152
|
+
if(1 == check_string_param(inter_val,
|
|
153
|
+
"",
|
|
154
|
+
&c_learn_param->alphafile,
|
|
155
|
+
"alphafile",
|
|
156
|
+
error_message)){
|
|
157
|
+
return 1;
|
|
158
|
+
}
|
|
159
|
+
|
|
160
|
+
inter_val = rb_hash_aref(r_hash, rb_str_new2("biased_hyperplane"));
|
|
161
|
+
if(1 == check_bool_param(inter_val,
|
|
162
|
+
1L,
|
|
163
|
+
&(c_learn_param->biased_hyperplane),
|
|
164
|
+
"biased_hyperplane",
|
|
165
|
+
error_message)){
|
|
166
|
+
return 1;
|
|
167
|
+
}
|
|
168
|
+
|
|
169
|
+
inter_val = rb_hash_aref(r_hash, rb_str_new2("sharedslack"));
|
|
170
|
+
if(1 == check_bool_param(inter_val,
|
|
171
|
+
0L,
|
|
172
|
+
&(c_learn_param->sharedslack),
|
|
173
|
+
"sharedslack",
|
|
174
|
+
error_message)){
|
|
175
|
+
return 1;
|
|
176
|
+
}
|
|
177
|
+
|
|
178
|
+
inter_val = rb_hash_aref(r_hash, rb_str_new2("remove_inconsistent"));
|
|
179
|
+
if(1 == check_bool_param(inter_val,
|
|
180
|
+
0L,
|
|
181
|
+
&(c_learn_param->remove_inconsistent),
|
|
182
|
+
"remove_inconsistent",
|
|
183
|
+
error_message)){
|
|
184
|
+
return 1;
|
|
185
|
+
}
|
|
186
|
+
|
|
187
|
+
inter_val = rb_hash_aref(r_hash, rb_str_new2("skip_final_opt_check"));
|
|
188
|
+
if(1 == check_bool_param(inter_val,
|
|
189
|
+
0L,
|
|
190
|
+
&(c_learn_param->skip_final_opt_check),
|
|
191
|
+
"skip_final_opt_check",
|
|
192
|
+
error_message)){
|
|
193
|
+
return 1;
|
|
194
|
+
}
|
|
195
|
+
|
|
196
|
+
inter_val = rb_hash_aref(r_hash, rb_str_new2("svm_newvarsinqp"));
|
|
197
|
+
if(1 == check_bool_param(inter_val,
|
|
198
|
+
0L,
|
|
199
|
+
&(c_learn_param->svm_newvarsinqp),
|
|
200
|
+
"svm_newvarsinqp",
|
|
201
|
+
error_message)){
|
|
202
|
+
return 1;
|
|
203
|
+
}
|
|
204
|
+
|
|
205
|
+
inter_val = rb_hash_aref(r_hash, rb_str_new2("compute_loo"));
|
|
206
|
+
if(1 == check_bool_param(inter_val,
|
|
207
|
+
0L,
|
|
208
|
+
&(c_learn_param->compute_loo),
|
|
209
|
+
"compute_loo",
|
|
210
|
+
error_message)){
|
|
211
|
+
return 1;
|
|
212
|
+
}
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
inter_val = rb_hash_aref(r_hash, rb_str_new2("svm_maxqpsize"));
|
|
216
|
+
if(1 == check_long_param(inter_val,
|
|
217
|
+
10L,
|
|
218
|
+
&(c_learn_param->svm_maxqpsize),
|
|
219
|
+
"svm_maxqpsize",
|
|
220
|
+
error_message)){
|
|
221
|
+
return 1;
|
|
222
|
+
}
|
|
223
|
+
|
|
224
|
+
inter_val = rb_hash_aref(r_hash, rb_str_new2("svm_iter_to_shrink"));
|
|
225
|
+
if(1 == check_long_param(inter_val,
|
|
226
|
+
-9999,
|
|
227
|
+
&(c_learn_param->svm_iter_to_shrink),
|
|
228
|
+
"svm_iter_to_shrink",
|
|
229
|
+
error_message)){
|
|
230
|
+
return 1;
|
|
231
|
+
}
|
|
232
|
+
|
|
233
|
+
inter_val = rb_hash_aref(r_hash, rb_str_new2("maxiter"));
|
|
234
|
+
if(1 == check_long_param(inter_val,
|
|
235
|
+
100000,
|
|
236
|
+
&(c_learn_param->maxiter),
|
|
237
|
+
"maxiter",
|
|
238
|
+
error_message)){
|
|
239
|
+
return 1;
|
|
240
|
+
}
|
|
241
|
+
|
|
242
|
+
inter_val = rb_hash_aref(r_hash, rb_str_new2("kernel_cache_size"));
|
|
243
|
+
if(1 == check_long_param(inter_val,
|
|
244
|
+
40L,
|
|
245
|
+
&(c_learn_param->kernel_cache_size),
|
|
246
|
+
"kernel_cache_size",
|
|
247
|
+
error_message)){
|
|
248
|
+
return 1;
|
|
249
|
+
}
|
|
250
|
+
|
|
251
|
+
inter_val = rb_hash_aref(r_hash, rb_str_new2("xa_depth"));
|
|
252
|
+
if(1 == check_long_param(inter_val,
|
|
253
|
+
0L,
|
|
254
|
+
&(c_learn_param->xa_depth),
|
|
255
|
+
"xa_depth",
|
|
256
|
+
error_message)){
|
|
257
|
+
return 1;
|
|
258
|
+
}
|
|
259
|
+
|
|
260
|
+
inter_val = rb_hash_aref(r_hash, rb_str_new2("svm_c"));
|
|
261
|
+
if(1 == check_double_param(inter_val,
|
|
262
|
+
0.0,
|
|
263
|
+
&(c_learn_param->svm_c),
|
|
264
|
+
"svm_c",
|
|
265
|
+
error_message)){
|
|
266
|
+
return 1;
|
|
267
|
+
}
|
|
268
|
+
|
|
269
|
+
inter_val = rb_hash_aref(r_hash, rb_str_new2("eps"));
|
|
270
|
+
if(1 == check_double_param(inter_val,
|
|
271
|
+
0.1,
|
|
272
|
+
&(c_learn_param->eps),
|
|
273
|
+
"eps",
|
|
274
|
+
error_message)){
|
|
275
|
+
return 1;
|
|
276
|
+
}
|
|
277
|
+
|
|
278
|
+
inter_val = rb_hash_aref(r_hash, rb_str_new2("transduction_posratio"));
|
|
279
|
+
if(1 == check_double_param(inter_val,
|
|
280
|
+
-1.0,
|
|
281
|
+
&(c_learn_param->transduction_posratio),
|
|
282
|
+
"transduction_posratio",
|
|
283
|
+
error_message)){
|
|
284
|
+
return 1;
|
|
285
|
+
}
|
|
286
|
+
|
|
287
|
+
inter_val = rb_hash_aref(r_hash, rb_str_new2("svm_costratio"));
|
|
288
|
+
if(1 == check_double_param(inter_val,
|
|
289
|
+
1.0,
|
|
290
|
+
&(c_learn_param->svm_costratio),
|
|
291
|
+
"svm_costratio",
|
|
292
|
+
error_message)){
|
|
293
|
+
return 1;
|
|
294
|
+
}
|
|
295
|
+
|
|
296
|
+
inter_val = rb_hash_aref(r_hash, rb_str_new2("svm_costratio_unlab"));
|
|
297
|
+
if(1 == check_double_param(inter_val,
|
|
298
|
+
1.0,
|
|
299
|
+
&(c_learn_param->svm_costratio_unlab),
|
|
300
|
+
"svm_costratio_unlab",
|
|
301
|
+
error_message)){
|
|
302
|
+
return 1;
|
|
303
|
+
}
|
|
304
|
+
|
|
305
|
+
inter_val = rb_hash_aref(r_hash, rb_str_new2("svm_unlabbound"));
|
|
306
|
+
if(1 == check_double_param(inter_val,
|
|
307
|
+
1.0000000000000001e-05,
|
|
308
|
+
&(c_learn_param->svm_unlabbound),
|
|
309
|
+
"svm_unlabbound",
|
|
310
|
+
error_message)){
|
|
311
|
+
return 1;
|
|
312
|
+
}
|
|
313
|
+
|
|
314
|
+
inter_val = rb_hash_aref(r_hash, rb_str_new2("epsilon_crit"));
|
|
315
|
+
if(1 == check_double_param(inter_val,
|
|
316
|
+
0.001,
|
|
317
|
+
&(c_learn_param->epsilon_crit),
|
|
318
|
+
"epsilon_crit",
|
|
319
|
+
error_message)){
|
|
320
|
+
return 1;
|
|
321
|
+
}
|
|
322
|
+
|
|
323
|
+
inter_val = rb_hash_aref(r_hash, rb_str_new2("epsilon_a"));
|
|
324
|
+
if(1 == check_double_param(inter_val,
|
|
325
|
+
1E-15,
|
|
326
|
+
&(c_learn_param->epsilon_a),
|
|
327
|
+
"epsilon_a",
|
|
328
|
+
error_message)){
|
|
329
|
+
return 1;
|
|
330
|
+
}
|
|
331
|
+
|
|
332
|
+
c_learn_param->rho=1.0;
|
|
333
|
+
inter_val = rb_hash_aref(r_hash, rb_str_new2("rho"));
|
|
334
|
+
if(1 == check_double_param(inter_val,
|
|
335
|
+
1.0,
|
|
336
|
+
&(c_learn_param->rho),
|
|
337
|
+
"rho",
|
|
338
|
+
error_message)){
|
|
339
|
+
return 1;
|
|
340
|
+
}
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
return 0;
|
|
344
|
+
}
|
|
345
|
+
|
|
346
|
+
/* Populate *c_kernel_param from the ruby options hash r_hash; each key is the
 * name of the corresponding KERNEL_PARM field.  On a badly typed value the
 * check_* helpers write an exception message into error_message (a 300 byte
 * buffer) and we return 1; returns 0 on success.
 * Note: kernel_type is NOT set here; the caller forces it to LINEAR. */
int setup_kernel_params(KERNEL_PARM *c_kernel_param, VALUE r_hash, char *error_message){
  size_t i;
  struct dbl_opt { const char *name; double def; double *target; };
  const struct dbl_opt dbls[] = {
    { "rbf_gamma",  1.0, &c_kernel_param->rbf_gamma  },
    { "coef_lin",   1.0, &c_kernel_param->coef_lin   },
    { "coef_const", 1.0, &c_kernel_param->coef_const },
  };

  if(check_long_param(rb_hash_aref(r_hash, rb_str_new2("poly_degree")),
                      3L,
                      &(c_kernel_param->poly_degree),
                      "poly_degree",
                      error_message)){
    return 1;
  }

  for(i = 0; i < sizeof(dbls)/sizeof(dbls[0]); i++){
    if(check_double_param(rb_hash_aref(r_hash, rb_str_new2(dbls[i].name)),
                          dbls[i].def, dbls[i].target, dbls[i].name,
                          error_message)){
      return 1;
    }
  }

  /* No support for custom kernels yet, just set it to a placeholder.
   * snprintf is portable and always NUL-terminates (strlcpy is BSD-only). */
  snprintf(c_kernel_param->custom, 50, "empty");
  return 0;
}
|
|
388
|
+
|
|
389
|
+
/* Do logic checks for the learn and kernel params, this logic is copied from
|
|
390
|
+
* svm_learn_main.c */
|
|
391
|
+
int check_kernel_and_learn_params_logic(KERNEL_PARM *c_kernel_param,
|
|
392
|
+
LEARN_PARM *c_learn_param, char *error_msg){
|
|
393
|
+
|
|
394
|
+
if(c_learn_param->svm_iter_to_shrink == -9999) {
|
|
395
|
+
if(c_kernel_param->kernel_type == LINEAR)
|
|
396
|
+
c_learn_param->svm_iter_to_shrink=2;
|
|
397
|
+
else
|
|
398
|
+
c_learn_param->svm_iter_to_shrink=100;
|
|
399
|
+
}
|
|
400
|
+
|
|
401
|
+
//It does not make sense to skip the final optimality check for linear kernels.
|
|
402
|
+
if((c_learn_param->skip_final_opt_check)
|
|
403
|
+
&& (c_kernel_param->kernel_type == LINEAR)) {
|
|
404
|
+
c_learn_param->skip_final_opt_check=0;
|
|
405
|
+
}
|
|
406
|
+
|
|
407
|
+
if((c_learn_param->skip_final_opt_check)
|
|
408
|
+
&& (c_learn_param->remove_inconsistent)) {
|
|
409
|
+
strncpy(error_msg,"It is necessary to do the final optimality "
|
|
410
|
+
"check when removing inconsistent examples.", 300);
|
|
411
|
+
return 1;
|
|
412
|
+
}
|
|
413
|
+
|
|
414
|
+
if((c_learn_param->svm_maxqpsize<2)) {
|
|
415
|
+
snprintf(error_msg, 300, "Maximum size of QP-subproblems "
|
|
416
|
+
"not in valid range: %ld [2..]",c_learn_param->svm_maxqpsize);
|
|
417
|
+
return 1;
|
|
418
|
+
}
|
|
419
|
+
|
|
420
|
+
if((c_learn_param->svm_maxqpsize<c_learn_param->svm_newvarsinqp)) {
|
|
421
|
+
snprintf(error_msg, 300, "Maximum size of QP-subproblems [%ld] must be larger than the number of",
|
|
422
|
+
"new variables [%ld] entering the working set in each iteration.\n",c_learn_param->svm_maxqpsize
|
|
423
|
+
,c_learn_param->svm_newvarsinqp);
|
|
424
|
+
return 1;
|
|
425
|
+
}
|
|
426
|
+
|
|
427
|
+
if(c_learn_param->svm_iter_to_shrink<1) {
|
|
428
|
+
snprintf(error_msg, 300, "Maximum number of iterations for shrinking not"
|
|
429
|
+
" in valid range: %ld [1,..]",c_learn_param->svm_iter_to_shrink);
|
|
430
|
+
return 1;
|
|
431
|
+
}
|
|
432
|
+
|
|
433
|
+
if(c_learn_param->svm_c<0) {
|
|
434
|
+
strncpy(error_msg,"The C parameter must be greater than zero", 300);
|
|
435
|
+
return 1;
|
|
436
|
+
}
|
|
437
|
+
|
|
438
|
+
if(c_learn_param->transduction_posratio>1) {
|
|
439
|
+
strncpy(error_msg,"The fraction of unlabeled examples to classify as positives must"
|
|
440
|
+
"be less than 1.0", 300);
|
|
441
|
+
return 1;
|
|
442
|
+
}
|
|
443
|
+
|
|
444
|
+
if(c_learn_param->svm_costratio<=0) {
|
|
445
|
+
strncpy(error_msg,"The COSTRATIO parameter must be greater than zero", 300);
|
|
446
|
+
return 1;
|
|
447
|
+
}
|
|
448
|
+
|
|
449
|
+
if(c_learn_param->epsilon_crit<=0) {
|
|
450
|
+
strncpy(error_msg,"The epsilon parameter must be greater than zero", 300);
|
|
451
|
+
return 1;
|
|
452
|
+
}
|
|
453
|
+
|
|
454
|
+
if(c_learn_param->rho<0) {
|
|
455
|
+
strncpy(error_msg, "The parameter rho for xi/alpha-estimates and leave-one-out pruning must"
|
|
456
|
+
"be greater than zero (typically 1.0 or 2.0, see T. Joachims, Estimating the"
|
|
457
|
+
"Generalization Performance of an SVM Efficiently, ICML, 2000.)!", 300);
|
|
458
|
+
return 1;
|
|
459
|
+
}
|
|
460
|
+
|
|
461
|
+
if((c_learn_param->xa_depth<0) || (c_learn_param->xa_depth>100)) {
|
|
462
|
+
strncpy(error_msg, "The parameter depth for ext. xi/alpha-estimates must be in [0..100] (zero) "
|
|
463
|
+
"for switching to the conventional xa/estimates described in T. Joachims,"
|
|
464
|
+
"Estimating the Generalization Performance of an SVM Efficiently, ICML, 2000.)", 300);
|
|
465
|
+
return 1;
|
|
466
|
+
}
|
|
467
|
+
|
|
468
|
+
return 0;
|
|
469
|
+
}
|
|
470
|
+
|
|
471
|
+
/* This function will let you train a new SVM model, for now we *only* support
|
|
472
|
+
* classification SVMs and linear kernels, in ruby-land the kernel and learning params
|
|
473
|
+
* will be represented by hashes where the keys are the name of the respective field in
|
|
474
|
+
* the options structure
|
|
475
|
+
*
|
|
476
|
+
* @param [Array] r_docs_and_classes is an array of arrays where each of the inner arrays must have two elements, the first a Document and the second a label (1, -1 ) for classification
|
|
477
|
+
* @param [Hash] learn_params the learning options, each key is the name of a filed in the LEARN_PARM struct
|
|
478
|
+
* @param [Hash] kernel_params the kernel options, each key is the name of a filed in the KERNEL_PARM struct
|
|
479
|
+
* @param [Bool] use_cache, useless for now caches cannot be set for linear kernels
|
|
480
|
+
* @param [Array] alpha, array of alpha values
|
|
481
|
+
* */
|
|
482
|
+
static VALUE
|
|
483
|
+
model_learn_classification(VALUE klass,
|
|
484
|
+
VALUE r_docs_and_classes, // Docs + labels array of arrays
|
|
485
|
+
VALUE learn_params, // Options hash with learning options
|
|
486
|
+
VALUE kernel_params, // Options hash with kernel options
|
|
487
|
+
VALUE use_cache, // If no linear
|
|
488
|
+
VALUE alpha
|
|
489
|
+
){
|
|
490
|
+
int i;
|
|
491
|
+
double *labels = NULL, *alpha_in = NULL;
|
|
492
|
+
long totdocs, totwords = 0, fnum = 0;
|
|
493
|
+
MODEL *m = NULL;
|
|
494
|
+
DOC **c_docs = NULL;
|
|
495
|
+
LEARN_PARM c_learn_param;
|
|
496
|
+
KERNEL_PARM c_kernel_param;
|
|
497
|
+
VALUE temp_ary, exception = rb_eArgError;
|
|
498
|
+
char error_msg[300];
|
|
499
|
+
|
|
500
|
+
Check_Type(r_docs_and_classes, T_ARRAY);
|
|
501
|
+
Check_Type(learn_params, T_HASH);
|
|
502
|
+
Check_Type(kernel_params, T_HASH);
|
|
503
|
+
|
|
504
|
+
if(!(TYPE(alpha) == T_ARRAY || NIL_P(alpha) ))
|
|
505
|
+
rb_raise(rb_eTypeError, "alpha must be an numeric array or nil");
|
|
506
|
+
|
|
507
|
+
if(TYPE(alpha) == T_ARRAY){
|
|
508
|
+
|
|
509
|
+
alpha_in = my_malloc(sizeof(double) * RARRAY_LEN(alpha));
|
|
510
|
+
|
|
511
|
+
for(i=0; i < RARRAY_LEN(alpha); i++){
|
|
512
|
+
|
|
513
|
+
if(TYPE(RARRAY_PTR(alpha)[i]) != T_FLOAT &&
|
|
514
|
+
TYPE(RARRAY_PTR(alpha)[i]) != T_FIXNUM ){
|
|
515
|
+
|
|
516
|
+
strncpy(error_msg,"All elements of the alpha array must be numeric ", 300);
|
|
517
|
+
goto bail;
|
|
518
|
+
}
|
|
519
|
+
|
|
520
|
+
alpha_in[i] = NUM2DBL(RARRAY_PTR(alpha)[i]);
|
|
521
|
+
}
|
|
522
|
+
}
|
|
523
|
+
|
|
524
|
+
if(setup_learn_params(&c_learn_param, learn_params, error_msg) != 0){
|
|
525
|
+
goto bail;
|
|
526
|
+
}
|
|
527
|
+
|
|
528
|
+
c_learn_param.type = CLASSIFICATION;
|
|
529
|
+
|
|
530
|
+
if(setup_kernel_params(&c_kernel_param, kernel_params, error_msg) != 0){
|
|
531
|
+
goto bail;
|
|
532
|
+
}
|
|
533
|
+
|
|
534
|
+
//TODO Setup kernel cache when we support non linear kernels
|
|
535
|
+
c_kernel_param.kernel_type = LINEAR;
|
|
536
|
+
|
|
537
|
+
if(check_kernel_and_learn_params_logic(&c_kernel_param, &c_learn_param, error_msg) != 0){
|
|
538
|
+
goto bail;
|
|
539
|
+
}
|
|
540
|
+
|
|
541
|
+
totdocs = (long)RARRAY_LEN(r_docs_and_classes);
|
|
542
|
+
|
|
543
|
+
if (totdocs == 0){
|
|
544
|
+
strncpy(error_msg, "Cannot create Model from empty Documents array", 300);
|
|
545
|
+
goto bail;
|
|
546
|
+
}
|
|
547
|
+
|
|
548
|
+
c_docs = (DOC **)my_malloc(sizeof(DOC *)*(totdocs));
|
|
549
|
+
labels = (double*)my_malloc(sizeof(double)*totdocs);
|
|
550
|
+
|
|
551
|
+
for(i=0; i < totdocs; i++){
|
|
552
|
+
// Just one of the documents and classes arrays, we expect temp_ary to have a Document
|
|
553
|
+
// and a label (long)
|
|
554
|
+
temp_ary = RARRAY_PTR(r_docs_and_classes)[i] ;
|
|
555
|
+
|
|
556
|
+
if( TYPE(temp_ary) != T_ARRAY ||
|
|
557
|
+
RARRAY_LEN(temp_ary) < 2 ||
|
|
558
|
+
rb_obj_class(RARRAY_PTR(temp_ary)[0]) != rb_cDocument ||
|
|
559
|
+
(TYPE(RARRAY_PTR(temp_ary)[1]) != T_FLOAT && TYPE(RARRAY_PTR(temp_ary)[1]) != T_FIXNUM )){
|
|
560
|
+
|
|
561
|
+
strncpy(error_msg, "All elements of documents and labels should be arrays,"
|
|
562
|
+
"where the first element is a document and the second a number", 300);
|
|
563
|
+
|
|
564
|
+
goto bail;
|
|
565
|
+
}
|
|
566
|
+
|
|
567
|
+
Data_Get_Struct(RARRAY_PTR(temp_ary)[0], DOC, c_docs[i]);
|
|
568
|
+
labels[i] = NUM2DBL(RARRAY_PTR(temp_ary)[1]);
|
|
569
|
+
|
|
570
|
+
fnum = 0;
|
|
571
|
+
|
|
572
|
+
// Increase feature number while there are still words in the vector
|
|
573
|
+
while(c_docs[i]->fvec->words[fnum].wnum) {
|
|
574
|
+
fnum++;
|
|
575
|
+
}
|
|
576
|
+
|
|
577
|
+
if(c_docs[i]->fvec->words[fnum -1].wnum > totwords)
|
|
578
|
+
totwords = c_docs[i]->fvec->words[fnum-1].wnum;
|
|
579
|
+
|
|
580
|
+
if(totwords > MAXFEATNUM){
|
|
581
|
+
strncpy(error_msg, "The number of features exceeds MAXFEATNUM the maximun "
|
|
582
|
+
"number of features defined for this version of SVMLight", 300);
|
|
583
|
+
goto bail;
|
|
584
|
+
}
|
|
585
|
+
}
|
|
586
|
+
|
|
587
|
+
m = (MODEL *)my_malloc(sizeof(MODEL));
|
|
588
|
+
|
|
589
|
+
svm_learn_classification(c_docs, labels, totdocs, totwords,
|
|
590
|
+
&c_learn_param, &c_kernel_param, NULL, m, alpha_in);
|
|
591
|
+
|
|
592
|
+
free(alpha_in);
|
|
593
|
+
free(labels);
|
|
594
|
+
|
|
595
|
+
// If need arises to free the data do a deep copy of m and create the ruby object with
|
|
596
|
+
// that data.
|
|
597
|
+
// free(c_docs);
|
|
598
|
+
return Data_Wrap_Struct(klass, 0, model_free, m);
|
|
599
|
+
|
|
600
|
+
bail:
|
|
601
|
+
free(alpha_in);
|
|
602
|
+
free(labels);
|
|
603
|
+
free(c_docs);
|
|
604
|
+
rb_raise(exception, error_msg, "%s");
|
|
605
|
+
}
|
|
606
|
+
|
|
607
|
+
/* Classify, takes an example (instance of Document) and returns its classification */
|
|
608
|
+
static VALUE
|
|
609
|
+
model_classify_example(VALUE self, VALUE example){
|
|
610
|
+
DOC *ex;
|
|
611
|
+
MODEL *m;
|
|
612
|
+
double result;
|
|
613
|
+
|
|
614
|
+
Data_Get_Struct(example, DOC, ex);
|
|
615
|
+
Data_Get_Struct(self, MODEL, m);
|
|
616
|
+
|
|
617
|
+
/* Apparently unnecessary code
|
|
618
|
+
|
|
619
|
+
if(is_linear(m))
|
|
620
|
+
result = classify_example_linear(m, ex);
|
|
621
|
+
else
|
|
622
|
+
*/
|
|
623
|
+
|
|
624
|
+
result = classify_example(m, ex);
|
|
625
|
+
|
|
626
|
+
return rb_float_new((float)result);
|
|
627
|
+
}
|
|
628
|
+
|
|
629
|
+
static VALUE
|
|
630
|
+
model_support_vectors_count(VALUE self){
|
|
631
|
+
MODEL *m;
|
|
632
|
+
Data_Get_Struct(self, MODEL, m);
|
|
633
|
+
|
|
634
|
+
return INT2FIX(m->sv_num);
|
|
635
|
+
}
|
|
636
|
+
|
|
637
|
+
static VALUE
|
|
638
|
+
model_total_words(VALUE self){
|
|
639
|
+
MODEL *m;
|
|
640
|
+
Data_Get_Struct(self, MODEL, m);
|
|
641
|
+
|
|
642
|
+
return INT2FIX(m->totwords);
|
|
643
|
+
}
|
|
644
|
+
|
|
645
|
+
static VALUE
|
|
646
|
+
model_totdoc(VALUE self){
|
|
647
|
+
MODEL *m;
|
|
648
|
+
Data_Get_Struct(self, MODEL, m);
|
|
649
|
+
|
|
650
|
+
return INT2FIX(m->totdoc);
|
|
651
|
+
}
|
|
652
|
+
|
|
653
|
+
static VALUE
|
|
654
|
+
model_maxdiff(VALUE self){
|
|
655
|
+
MODEL *m;
|
|
656
|
+
Data_Get_Struct(self, MODEL, m);
|
|
657
|
+
|
|
658
|
+
return DBL2NUM(m->maxdiff);
|
|
659
|
+
}
|
|
660
|
+
|
|
661
|
+
/* Creates a DOC from an array of words it also takes an id
|
|
662
|
+
* -1 is normally OK for that value when using in filtering it also takes the C (cost)
|
|
663
|
+
* parameter for the SVM.
|
|
664
|
+
* words_ary an array of arrays like this
|
|
665
|
+
* [[wnum, weight], [wnum, weight], ...] so we do not waste memory, defeating the svec implementation and do
|
|
666
|
+
* not introduce a bunch of 0's that seem to be OK when classifying but screw all up on
|
|
667
|
+
* training*/
|
|
668
|
+
static VALUE
|
|
669
|
+
doc_create(VALUE klass, VALUE id, VALUE cost, VALUE slackid, VALUE queryid, VALUE words_ary ){
|
|
670
|
+
long docnum, i, c_slackid, c_queryid;
|
|
671
|
+
double c;
|
|
672
|
+
WORD *words;
|
|
673
|
+
SVECTOR *vec;
|
|
674
|
+
DOC *d;
|
|
675
|
+
VALUE inner_array;
|
|
676
|
+
|
|
677
|
+
Check_Type(words_ary, T_ARRAY);
|
|
678
|
+
Check_Type(slackid, T_FIXNUM);
|
|
679
|
+
Check_Type(queryid, T_FIXNUM);
|
|
680
|
+
|
|
681
|
+
if (RARRAY_LEN(words_ary) == 0)
|
|
682
|
+
rb_raise(rb_eArgError, "Cannot create Document from empty arrays");
|
|
683
|
+
|
|
684
|
+
words = (WORD*) my_malloc(sizeof(WORD) * (RARRAY_LEN(words_ary) + 1));
|
|
685
|
+
|
|
686
|
+
for(i=0; i < (long)RARRAY_LEN(words_ary); i++){
|
|
687
|
+
inner_array = RARRAY_PTR(words_ary)[i];
|
|
688
|
+
Check_Type(inner_array, T_ARRAY);
|
|
689
|
+
Check_Type(RARRAY_PTR(inner_array)[0], T_FIXNUM);
|
|
690
|
+
|
|
691
|
+
if(!(TYPE(RARRAY_PTR(inner_array)[1]) == T_FLOAT || TYPE(RARRAY_PTR(inner_array)[1]) == T_FIXNUM ))
|
|
692
|
+
rb_raise(rb_eArgError, "Feature weights must be numeric");
|
|
693
|
+
|
|
694
|
+
if(FIX2LONG(RARRAY_PTR(inner_array)[0]) <= 0 )
|
|
695
|
+
rb_raise(rb_eArgError, "Feature number has to be greater than zero");
|
|
696
|
+
|
|
697
|
+
(words[i]).wnum = FIX2LONG(RARRAY_PTR(inner_array)[0]);
|
|
698
|
+
(words[i]).weight = (FVAL)(NUM2DBL(RARRAY_PTR(inner_array)[1]));
|
|
699
|
+
}
|
|
700
|
+
words[i].wnum = 0;
|
|
701
|
+
|
|
702
|
+
vec = create_svector(words, (char*)"", 1.0);
|
|
703
|
+
c = NUM2DBL(cost);
|
|
704
|
+
docnum = FIX2INT(id);
|
|
705
|
+
|
|
706
|
+
d = create_example(docnum, FIX2LONG(queryid), FIX2LONG(slackid), c, vec);
|
|
707
|
+
|
|
708
|
+
return Data_Wrap_Struct(klass, 0, doc_free, d);
|
|
709
|
+
}
|
|
710
|
+
|
|
711
|
+
static VALUE
|
|
712
|
+
doc_get_docnum(VALUE self){
|
|
713
|
+
DOC *d;
|
|
714
|
+
Data_Get_Struct(self, DOC, d);
|
|
715
|
+
|
|
716
|
+
return INT2FIX(d->docnum);
|
|
717
|
+
}
|
|
718
|
+
|
|
719
|
+
static VALUE
|
|
720
|
+
doc_get_slackid(VALUE self){
|
|
721
|
+
DOC *d;
|
|
722
|
+
Data_Get_Struct(self, DOC, d);
|
|
723
|
+
|
|
724
|
+
return INT2FIX(d->slackid);
|
|
725
|
+
}
|
|
726
|
+
|
|
727
|
+
static VALUE
|
|
728
|
+
doc_get_queryid(VALUE self){
|
|
729
|
+
DOC *d;
|
|
730
|
+
Data_Get_Struct(self, DOC, d);
|
|
731
|
+
|
|
732
|
+
return INT2FIX(d->queryid);
|
|
733
|
+
}
|
|
734
|
+
|
|
735
|
+
static VALUE
|
|
736
|
+
doc_get_costfactor(VALUE self){
|
|
737
|
+
DOC *d;
|
|
738
|
+
Data_Get_Struct(self, DOC, d);
|
|
739
|
+
|
|
740
|
+
return DBL2NUM(d->costfactor);
|
|
741
|
+
}
|
|
742
|
+
|
|
743
|
+
/* Extension entry point: ruby runs this on `require "svmredlight"`.
 * Registers the SVMLight module with its Model and Document classes and
 * binds the C functions above to ruby methods. */
void
Init_svmredlight(){
  rb_mSvmLight = rb_define_module("SVMLight");
  //Model
  rb_cModel = rb_define_class_under(rb_mSvmLight, "Model", rb_cObject);
  // Class-level constructors: load a trained model, or train a new one.
  rb_define_singleton_method(rb_cModel, "read_from_file", model_read_from_file, 1);
  rb_define_singleton_method(rb_cModel, "learn_classification", model_learn_classification, 5);
  rb_define_method(rb_cModel, "support_vectors_count", model_support_vectors_count, 0);
  rb_define_method(rb_cModel, "total_words", model_total_words, 0);
  rb_define_method(rb_cModel, "classify", model_classify_example, 1);
  rb_define_method(rb_cModel, "totdoc", model_totdoc,0);
  rb_define_method(rb_cModel, "maxdiff", model_maxdiff,0);
  //Document
  rb_cDocument = rb_define_class_under(rb_mSvmLight, "Document", rb_cObject);
  // Document.create(id, cost, slackid, queryid, words_ary) — 5 arguments.
  rb_define_singleton_method(rb_cDocument, "create", doc_create, 5);
  rb_define_method(rb_cDocument, "docnum", doc_get_docnum, 0);
  rb_define_method(rb_cDocument, "costfactor", doc_get_costfactor, 0);
  rb_define_method(rb_cDocument, "slackid", doc_get_slackid, 0);
  rb_define_method(rb_cDocument, "queryid", doc_get_queryid, 0);
}
|