ruby-fann 2.0.0 → 2.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/ext/ruby_fann/fann_augment.h +62 -62
- data/ext/ruby_fann/ruby_fann.c +362 -268
- data/lib/ruby_fann/version.rb +1 -1
- metadata +1 -1
data/ext/ruby_fann/ruby_fann.c
CHANGED
@@ -9,94 +9,131 @@ static VALUE m_rb_fann_standard_class;
|
|
9
9
|
static VALUE m_rb_fann_shortcut_class;
|
10
10
|
static VALUE m_rb_fann_train_data_class;
|
11
11
|
|
12
|
-
#define RETURN_FANN_INT(fn)
|
13
|
-
struct fann*
|
14
|
-
Data_Get_Struct
|
15
|
-
return INT2NUM(fn(f));
|
16
|
-
|
17
|
-
#define SET_FANN_INT(attr_name, fann_fn)
|
18
|
-
Check_Type(attr_name, T_FIXNUM);
|
19
|
-
struct fann*
|
20
|
-
Data_Get_Struct(self, struct fann, f); \
|
21
|
-
fann_fn(f, NUM2INT(attr_name));
|
22
|
-
return 0;
|
23
|
-
|
24
|
-
#define RETURN_FANN_UINT(fn)
|
25
|
-
struct fann*
|
26
|
-
Data_Get_Struct
|
27
|
-
return rb_int_new(fn(f));
|
28
|
-
|
29
|
-
#define SET_FANN_UINT(attr_name, fann_fn)
|
30
|
-
Check_Type(attr_name, T_FIXNUM);
|
31
|
-
struct fann*
|
32
|
-
Data_Get_Struct(self, struct fann, f); \
|
33
|
-
fann_fn(f, NUM2UINT(attr_name));
|
34
|
-
return 0;
|
12
|
+
#define RETURN_FANN_INT(fn) \
|
13
|
+
struct fann *f; \
|
14
|
+
Data_Get_Struct(self, struct fann, f); \
|
15
|
+
return INT2NUM(fn(f));
|
16
|
+
|
17
|
+
#define SET_FANN_INT(attr_name, fann_fn) \
|
18
|
+
Check_Type(attr_name, T_FIXNUM); \
|
19
|
+
struct fann *f; \
|
20
|
+
Data_Get_Struct(self, struct fann, f); \
|
21
|
+
fann_fn(f, NUM2INT(attr_name)); \
|
22
|
+
return 0;
|
23
|
+
|
24
|
+
#define RETURN_FANN_UINT(fn) \
|
25
|
+
struct fann *f; \
|
26
|
+
Data_Get_Struct(self, struct fann, f); \
|
27
|
+
return rb_int_new(fn(f));
|
28
|
+
|
29
|
+
#define SET_FANN_UINT(attr_name, fann_fn) \
|
30
|
+
Check_Type(attr_name, T_FIXNUM); \
|
31
|
+
struct fann *f; \
|
32
|
+
Data_Get_Struct(self, struct fann, f); \
|
33
|
+
fann_fn(f, NUM2UINT(attr_name)); \
|
34
|
+
return 0;
|
35
35
|
|
36
36
|
// Converts float return values to a double with same precision, avoids floating point errors.
|
37
|
-
#define RETURN_FANN_FLT(fn)
|
38
|
-
struct fann*
|
39
|
-
Data_Get_Struct
|
40
|
-
char buffy[20];
|
41
|
-
sprintf(buffy, "%0.6g", fn(f));
|
42
|
-
return rb_float_new(atof(buffy));
|
43
|
-
|
44
|
-
#define SET_FANN_FLT(attr_name, fann_fn)
|
45
|
-
Check_Type(attr_name, T_FLOAT);
|
46
|
-
struct fann*
|
47
|
-
Data_Get_Struct(self, struct fann, f); \
|
48
|
-
fann_fn(f, NUM2DBL(attr_name));
|
49
|
-
return self;
|
50
|
-
|
51
|
-
#define RETURN_FANN_DBL(fn)
|
52
|
-
struct fann*
|
53
|
-
Data_Get_Struct
|
54
|
-
return rb_float_new(fn(f));
|
37
|
+
#define RETURN_FANN_FLT(fn) \
|
38
|
+
struct fann *f; \
|
39
|
+
Data_Get_Struct(self, struct fann, f); \
|
40
|
+
char buffy[20]; \
|
41
|
+
sprintf(buffy, "%0.6g", fn(f)); \
|
42
|
+
return rb_float_new(atof(buffy));
|
43
|
+
|
44
|
+
#define SET_FANN_FLT(attr_name, fann_fn) \
|
45
|
+
Check_Type(attr_name, T_FLOAT); \
|
46
|
+
struct fann *f; \
|
47
|
+
Data_Get_Struct(self, struct fann, f); \
|
48
|
+
fann_fn(f, NUM2DBL(attr_name)); \
|
49
|
+
return self;
|
50
|
+
|
51
|
+
#define RETURN_FANN_DBL(fn) \
|
52
|
+
struct fann *f; \
|
53
|
+
Data_Get_Struct(self, struct fann, f); \
|
54
|
+
return rb_float_new(fn(f));
|
55
55
|
|
56
56
|
#define SET_FANN_DBL SET_FANN_FLT
|
57
57
|
|
58
58
|
// Convert ruby symbol to corresponding FANN enum type for activation function:
|
59
59
|
enum fann_activationfunc_enum sym_to_activation_function(VALUE activation_func)
|
60
60
|
{
|
61
|
-
ID id=SYM2ID(activation_func);
|
61
|
+
ID id = SYM2ID(activation_func);
|
62
62
|
enum fann_activationfunc_enum activation_function;
|
63
|
-
if(id==rb_intern("linear"))
|
63
|
+
if (id == rb_intern("linear"))
|
64
|
+
{
|
64
65
|
activation_function = FANN_LINEAR;
|
65
|
-
}
|
66
|
+
}
|
67
|
+
else if (id == rb_intern("threshold"))
|
68
|
+
{
|
66
69
|
activation_function = FANN_THRESHOLD;
|
67
|
-
}
|
70
|
+
}
|
71
|
+
else if (id == rb_intern("threshold_symmetric"))
|
72
|
+
{
|
68
73
|
activation_function = FANN_THRESHOLD_SYMMETRIC;
|
69
|
-
}
|
74
|
+
}
|
75
|
+
else if (id == rb_intern("sigmoid"))
|
76
|
+
{
|
70
77
|
activation_function = FANN_SIGMOID;
|
71
|
-
}
|
78
|
+
}
|
79
|
+
else if (id == rb_intern("sigmoid_stepwise"))
|
80
|
+
{
|
72
81
|
activation_function = FANN_SIGMOID_STEPWISE;
|
73
|
-
}
|
82
|
+
}
|
83
|
+
else if (id == rb_intern("sigmoid_symmetric"))
|
84
|
+
{
|
74
85
|
activation_function = FANN_SIGMOID_SYMMETRIC;
|
75
|
-
}
|
86
|
+
}
|
87
|
+
else if (id == rb_intern("sigmoid_symmetric_stepwise"))
|
88
|
+
{
|
76
89
|
activation_function = FANN_SIGMOID_SYMMETRIC_STEPWISE;
|
77
|
-
}
|
90
|
+
}
|
91
|
+
else if (id == rb_intern("gaussian"))
|
92
|
+
{
|
78
93
|
activation_function = FANN_GAUSSIAN;
|
79
|
-
}
|
94
|
+
}
|
95
|
+
else if (id == rb_intern("gaussian_symmetric"))
|
96
|
+
{
|
80
97
|
activation_function = FANN_GAUSSIAN_SYMMETRIC;
|
81
|
-
}
|
98
|
+
}
|
99
|
+
else if (id == rb_intern("gaussian_stepwise"))
|
100
|
+
{
|
82
101
|
activation_function = FANN_GAUSSIAN_STEPWISE;
|
83
|
-
}
|
102
|
+
}
|
103
|
+
else if (id == rb_intern("elliot"))
|
104
|
+
{
|
84
105
|
activation_function = FANN_ELLIOT;
|
85
|
-
}
|
106
|
+
}
|
107
|
+
else if (id == rb_intern("elliot_symmetric"))
|
108
|
+
{
|
86
109
|
activation_function = FANN_ELLIOT_SYMMETRIC;
|
87
|
-
}
|
110
|
+
}
|
111
|
+
else if (id == rb_intern("linear_piece"))
|
112
|
+
{
|
88
113
|
activation_function = FANN_LINEAR_PIECE;
|
89
|
-
}
|
114
|
+
}
|
115
|
+
else if (id == rb_intern("linear_piece_symmetric"))
|
116
|
+
{
|
90
117
|
activation_function = FANN_LINEAR_PIECE_SYMMETRIC;
|
91
|
-
}
|
118
|
+
}
|
119
|
+
else if (id == rb_intern("sin_symmetric"))
|
120
|
+
{
|
92
121
|
activation_function = FANN_SIN_SYMMETRIC;
|
93
|
-
}
|
122
|
+
}
|
123
|
+
else if (id == rb_intern("cos_symmetric"))
|
124
|
+
{
|
94
125
|
activation_function = FANN_COS_SYMMETRIC;
|
95
|
-
}
|
126
|
+
}
|
127
|
+
else if (id == rb_intern("sin"))
|
128
|
+
{
|
96
129
|
activation_function = FANN_SIN;
|
97
|
-
}
|
130
|
+
}
|
131
|
+
else if (id == rb_intern("cos"))
|
132
|
+
{
|
98
133
|
activation_function = FANN_COS;
|
99
|
-
}
|
134
|
+
}
|
135
|
+
else
|
136
|
+
{
|
100
137
|
rb_raise(rb_eRuntimeError, "Unrecognized activation function: [%s]", rb_id2name(SYM2ID(activation_func)));
|
101
138
|
}
|
102
139
|
return activation_function;
|
@@ -107,83 +144,118 @@ VALUE activation_function_to_sym(enum fann_activationfunc_enum fn)
|
|
107
144
|
{
|
108
145
|
VALUE activation_function;
|
109
146
|
|
110
|
-
if(fn==FANN_LINEAR)
|
147
|
+
if (fn == FANN_LINEAR)
|
148
|
+
{
|
111
149
|
activation_function = ID2SYM(rb_intern("linear"));
|
112
|
-
}
|
150
|
+
}
|
151
|
+
else if (fn == FANN_THRESHOLD)
|
152
|
+
{
|
113
153
|
activation_function = ID2SYM(rb_intern("threshold"));
|
114
|
-
}
|
154
|
+
}
|
155
|
+
else if (fn == FANN_THRESHOLD_SYMMETRIC)
|
156
|
+
{
|
115
157
|
activation_function = ID2SYM(rb_intern("threshold_symmetric"));
|
116
|
-
}
|
158
|
+
}
|
159
|
+
else if (fn == FANN_SIGMOID)
|
160
|
+
{
|
117
161
|
activation_function = ID2SYM(rb_intern("sigmoid"));
|
118
|
-
}
|
162
|
+
}
|
163
|
+
else if (fn == FANN_SIGMOID_STEPWISE)
|
164
|
+
{
|
119
165
|
activation_function = ID2SYM(rb_intern("sigmoid_stepwise"));
|
120
|
-
}
|
166
|
+
}
|
167
|
+
else if (fn == FANN_SIGMOID_SYMMETRIC)
|
168
|
+
{
|
121
169
|
activation_function = ID2SYM(rb_intern("sigmoid_symmetric"));
|
122
|
-
}
|
170
|
+
}
|
171
|
+
else if (fn == FANN_SIGMOID_SYMMETRIC_STEPWISE)
|
172
|
+
{
|
123
173
|
activation_function = ID2SYM(rb_intern("sigmoid_symmetric_stepwise"));
|
124
|
-
}
|
174
|
+
}
|
175
|
+
else if (fn == FANN_GAUSSIAN)
|
176
|
+
{
|
125
177
|
activation_function = ID2SYM(rb_intern("gaussian"));
|
126
|
-
}
|
178
|
+
}
|
179
|
+
else if (fn == FANN_GAUSSIAN_SYMMETRIC)
|
180
|
+
{
|
127
181
|
activation_function = ID2SYM(rb_intern("gaussian_symmetric"));
|
128
|
-
}
|
182
|
+
}
|
183
|
+
else if (fn == FANN_GAUSSIAN_STEPWISE)
|
184
|
+
{
|
129
185
|
activation_function = ID2SYM(rb_intern("gaussian_stepwise"));
|
130
|
-
}
|
186
|
+
}
|
187
|
+
else if (fn == FANN_ELLIOT)
|
188
|
+
{
|
131
189
|
activation_function = ID2SYM(rb_intern("elliot"));
|
132
|
-
}
|
190
|
+
}
|
191
|
+
else if (fn == FANN_ELLIOT_SYMMETRIC)
|
192
|
+
{
|
133
193
|
activation_function = ID2SYM(rb_intern("elliot_symmetric"));
|
134
|
-
}
|
194
|
+
}
|
195
|
+
else if (fn == FANN_LINEAR_PIECE)
|
196
|
+
{
|
135
197
|
activation_function = ID2SYM(rb_intern("linear_piece"));
|
136
|
-
}
|
198
|
+
}
|
199
|
+
else if (fn == FANN_LINEAR_PIECE_SYMMETRIC)
|
200
|
+
{
|
137
201
|
activation_function = ID2SYM(rb_intern("linear_piece_symmetric"));
|
138
|
-
}
|
202
|
+
}
|
203
|
+
else if (fn == FANN_SIN_SYMMETRIC)
|
204
|
+
{
|
139
205
|
activation_function = ID2SYM(rb_intern("sin_symmetric"));
|
140
|
-
}
|
206
|
+
}
|
207
|
+
else if (fn == FANN_COS_SYMMETRIC)
|
208
|
+
{
|
141
209
|
activation_function = ID2SYM(rb_intern("cos_symmetric"));
|
142
|
-
}
|
210
|
+
}
|
211
|
+
else if (fn == FANN_SIN)
|
212
|
+
{
|
143
213
|
activation_function = ID2SYM(rb_intern("sin"));
|
144
|
-
}
|
214
|
+
}
|
215
|
+
else if (fn == FANN_COS)
|
216
|
+
{
|
145
217
|
activation_function = ID2SYM(rb_intern("cos"));
|
146
|
-
}
|
218
|
+
}
|
219
|
+
else
|
220
|
+
{
|
147
221
|
rb_raise(rb_eRuntimeError, "Unrecognized activation function: [%d]", fn);
|
148
222
|
}
|
149
223
|
return activation_function;
|
150
224
|
}
|
151
225
|
|
152
|
-
|
153
226
|
// Unused for now:
|
154
|
-
static void fann_mark
|
227
|
+
static void fann_mark(struct fann *ann) {}
|
155
228
|
|
156
229
|
// #define DEBUG 1
|
157
230
|
|
158
231
|
// Free memory associated with FANN:
|
159
|
-
static void fann_free
|
232
|
+
static void fann_free(struct fann *ann)
|
160
233
|
{
|
161
|
-
|
234
|
+
fann_destroy(ann);
|
162
235
|
// ("Destroyed FANN network [%d].\n", ann);
|
163
236
|
}
|
164
237
|
|
165
238
|
// Free memory associated with FANN Training data:
|
166
|
-
static void fann_training_data_free
|
239
|
+
static void fann_training_data_free(struct fann_train_data *train_data)
|
167
240
|
{
|
168
|
-
|
241
|
+
fann_destroy_train(train_data);
|
169
242
|
// printf("Destroyed Training data [%d].\n", train_data);
|
170
243
|
}
|
171
244
|
|
172
245
|
// Create wrapper, but don't allocate anything...do that in
|
173
246
|
// initialize, so we can construct with args:
|
174
|
-
static VALUE fann_allocate
|
247
|
+
static VALUE fann_allocate(VALUE klass)
|
175
248
|
{
|
176
|
-
return Data_Wrap_Struct
|
249
|
+
return Data_Wrap_Struct(klass, fann_mark, fann_free, 0);
|
177
250
|
}
|
178
251
|
|
179
252
|
// Create wrapper, but don't allocate annything...do that in
|
180
253
|
// initialize, so we can construct with args:
|
181
|
-
static VALUE fann_training_data_allocate
|
254
|
+
static VALUE fann_training_data_allocate(VALUE klass)
|
182
255
|
{
|
183
|
-
return Data_Wrap_Struct
|
256
|
+
return Data_Wrap_Struct(klass, fann_mark, fann_training_data_free, 0);
|
184
257
|
}
|
185
258
|
|
186
|
-
|
187
259
|
// static VALUE invoke_training_callback(VALUE self)
|
188
260
|
// {
|
189
261
|
// VALUE callback = rb_funcall(self, rb_intern("training_callback"), 0);
|
@@ -194,8 +266,8 @@ static VALUE fann_training_data_allocate (VALUE klass)
|
|
194
266
|
// unsigned int max_epochs, unsigned int epochs_between_reports, float desired_error, unsigned int epochs)
|
195
267
|
|
196
268
|
static int FANN_API fann_training_callback(struct fann *ann, struct fann_train_data *train,
|
197
|
-
|
198
|
-
|
269
|
+
unsigned int max_epochs, unsigned int epochs_between_reports,
|
270
|
+
float desired_error, unsigned int epochs)
|
199
271
|
{
|
200
272
|
VALUE self = (VALUE)fann_get_user_data(ann);
|
201
273
|
VALUE args = rb_hash_new();
|
@@ -213,13 +285,13 @@ static int FANN_API fann_training_callback(struct fann *ann, struct fann_train_d
|
|
213
285
|
|
214
286
|
VALUE callback = rb_funcall(self, rb_intern("training_callback"), 1, args);
|
215
287
|
|
216
|
-
if (TYPE(callback)!=T_FIXNUM)
|
288
|
+
if (TYPE(callback) != T_FIXNUM)
|
217
289
|
{
|
218
|
-
rb_raise
|
290
|
+
rb_raise(rb_eRuntimeError, "Callback method must return an integer (-1 to stop training).");
|
219
291
|
}
|
220
292
|
|
221
293
|
int status = NUM2INT(callback);
|
222
|
-
if (status
|
294
|
+
if (status == -1)
|
223
295
|
{
|
224
296
|
printf("Callback method returned -1; training will stop.\n");
|
225
297
|
}
|
@@ -251,19 +323,19 @@ static VALUE fann_initialize(VALUE self, VALUE hash)
|
|
251
323
|
VALUE num_outputs = rb_hash_aref(hash, ID2SYM(rb_intern("num_outputs")));
|
252
324
|
VALUE hidden_neurons = rb_hash_aref(hash, ID2SYM(rb_intern("hidden_neurons")));
|
253
325
|
// printf("initializing\n\n\n");
|
254
|
-
struct fann*
|
255
|
-
if (TYPE(filename)==T_STRING)
|
326
|
+
struct fann *ann;
|
327
|
+
if (TYPE(filename) == T_STRING)
|
256
328
|
{
|
257
329
|
// Initialize with file:
|
258
330
|
// train_data = fann_read_train_from_file(StringValuePtr(filename));
|
259
331
|
// DATA_PTR(self) = train_data;
|
260
332
|
ann = fann_create_from_file(StringValuePtr(filename));
|
261
|
-
|
333
|
+
// printf("Created RubyFann::Standard [%d] from file [%s].\n", ann, StringValuePtr(filename));
|
262
334
|
}
|
263
|
-
else if(rb_obj_is_kind_of(self, m_rb_fann_shortcut_class))
|
335
|
+
else if (rb_obj_is_kind_of(self, m_rb_fann_shortcut_class))
|
264
336
|
{
|
265
337
|
// Initialize as shortcut, suitable for cascade training:
|
266
|
-
//ann = fann_create_shortcut_array(num_layers, layers);
|
338
|
+
// ann = fann_create_shortcut_array(num_layers, layers);
|
267
339
|
Check_Type(num_inputs, T_FIXNUM);
|
268
340
|
Check_Type(num_outputs, T_FIXNUM);
|
269
341
|
|
@@ -278,17 +350,18 @@ static VALUE fann_initialize(VALUE self, VALUE hash)
|
|
278
350
|
Check_Type(num_outputs, T_FIXNUM);
|
279
351
|
|
280
352
|
// Initialize layers:
|
281
|
-
unsigned int num_layers=RARRAY_LEN(hidden_neurons) + 2;
|
353
|
+
unsigned int num_layers = RARRAY_LEN(hidden_neurons) + 2;
|
282
354
|
unsigned int layers[num_layers];
|
283
355
|
|
284
356
|
// Input:
|
285
|
-
layers[0]=NUM2INT(num_inputs);
|
357
|
+
layers[0] = NUM2INT(num_inputs);
|
286
358
|
// Output:
|
287
|
-
layers[num_layers-1]=NUM2INT(num_outputs);
|
359
|
+
layers[num_layers - 1] = NUM2INT(num_outputs);
|
288
360
|
// Hidden:
|
289
361
|
unsigned int i;
|
290
|
-
for (i=1; i<=num_layers-2; i++)
|
291
|
-
|
362
|
+
for (i = 1; i <= num_layers - 2; i++)
|
363
|
+
{
|
364
|
+
layers[i] = NUM2INT(RARRAY_PTR(hidden_neurons)[i - 1]);
|
292
365
|
}
|
293
366
|
ann = fann_create_standard_array(num_layers, layers);
|
294
367
|
}
|
@@ -297,9 +370,9 @@ static VALUE fann_initialize(VALUE self, VALUE hash)
|
|
297
370
|
|
298
371
|
// printf("Checking for callback...");
|
299
372
|
|
300
|
-
//int callback = rb_protect(invoke_training_callback, (self), &status);
|
301
|
-
//
|
302
|
-
if(rb_respond_to(self, rb_intern("training_callback")))
|
373
|
+
// int callback = rb_protect(invoke_training_callback, (self), &status);
|
374
|
+
// VALUE callback = rb_funcall(DATA_PTR(self), "training_callback", 0);
|
375
|
+
if (rb_respond_to(self, rb_intern("training_callback")))
|
303
376
|
{
|
304
377
|
fann_set_callback(ann, &fann_training_callback);
|
305
378
|
fann_set_user_data(ann, self);
|
@@ -329,39 +402,39 @@ static VALUE fann_initialize(VALUE self, VALUE hash)
|
|
329
402
|
*/
|
330
403
|
static VALUE fann_train_data_initialize(VALUE self, VALUE hash)
|
331
404
|
{
|
332
|
-
struct fann_train_data*
|
405
|
+
struct fann_train_data *train_data;
|
333
406
|
Check_Type(hash, T_HASH);
|
334
407
|
|
335
408
|
VALUE filename = rb_hash_aref(hash, ID2SYM(rb_intern("filename")));
|
336
409
|
VALUE inputs = rb_hash_aref(hash, ID2SYM(rb_intern("inputs")));
|
337
410
|
VALUE desired_outputs = rb_hash_aref(hash, ID2SYM(rb_intern("desired_outputs")));
|
338
411
|
|
339
|
-
if (TYPE(filename)==T_STRING)
|
412
|
+
if (TYPE(filename) == T_STRING)
|
340
413
|
{
|
341
414
|
train_data = fann_read_train_from_file(StringValuePtr(filename));
|
342
415
|
DATA_PTR(self) = train_data;
|
343
416
|
}
|
344
|
-
else if (TYPE(inputs)==T_ARRAY)
|
417
|
+
else if (TYPE(inputs) == T_ARRAY)
|
345
418
|
{
|
346
|
-
if (TYPE(desired_outputs)!=T_ARRAY)
|
419
|
+
if (TYPE(desired_outputs) != T_ARRAY)
|
347
420
|
{
|
348
|
-
rb_raise
|
421
|
+
rb_raise(rb_eRuntimeError, "[desired_outputs] must be present when [inputs] used.");
|
349
422
|
}
|
350
423
|
|
351
424
|
if (RARRAY_LEN(inputs) < 1)
|
352
425
|
{
|
353
|
-
rb_raise
|
426
|
+
rb_raise(rb_eRuntimeError, "[inputs] must contain at least one value.");
|
354
427
|
}
|
355
428
|
|
356
429
|
if (RARRAY_LEN(desired_outputs) < 1)
|
357
430
|
{
|
358
|
-
rb_raise
|
431
|
+
rb_raise(rb_eRuntimeError, "[desired_outputs] must contain at least one value.");
|
359
432
|
}
|
360
433
|
|
361
434
|
// The data is here, start constructing:
|
362
|
-
if(RARRAY_LEN(inputs) != RARRAY_LEN(desired_outputs))
|
435
|
+
if (RARRAY_LEN(inputs) != RARRAY_LEN(desired_outputs))
|
363
436
|
{
|
364
|
-
rb_raise
|
437
|
+
rb_raise(
|
365
438
|
rb_eRuntimeError,
|
366
439
|
"Number of inputs must match number of outputs: (%d != %d)",
|
367
440
|
(int)RARRAY_LEN(inputs),
|
@@ -373,13 +446,12 @@ static VALUE fann_train_data_initialize(VALUE self, VALUE hash)
|
|
373
446
|
}
|
374
447
|
else
|
375
448
|
{
|
376
|
-
rb_raise
|
449
|
+
rb_raise(rb_eRuntimeError, "Must construct with a filename(string) or inputs/desired_outputs(arrays). All args passed via hash with symbols as keys.");
|
377
450
|
}
|
378
451
|
|
379
452
|
return (VALUE)train_data;
|
380
453
|
}
|
381
454
|
|
382
|
-
|
383
455
|
/** call-seq: save(filename)
|
384
456
|
|
385
457
|
Save to given filename
|
@@ -387,8 +459,8 @@ static VALUE fann_train_data_initialize(VALUE self, VALUE hash)
|
|
387
459
|
static VALUE training_save(VALUE self, VALUE filename)
|
388
460
|
{
|
389
461
|
Check_Type(filename, T_STRING);
|
390
|
-
struct fann_train_data*
|
391
|
-
Data_Get_Struct
|
462
|
+
struct fann_train_data *t;
|
463
|
+
Data_Get_Struct(self, struct fann_train_data, t);
|
392
464
|
fann_save_train(t, StringValuePtr(filename));
|
393
465
|
return self;
|
394
466
|
}
|
@@ -397,8 +469,8 @@ static VALUE training_save(VALUE self, VALUE filename)
|
|
397
469
|
This is recommended for incremental training, while it will have no influence during batch training.*/
|
398
470
|
static VALUE shuffle(VALUE self)
|
399
471
|
{
|
400
|
-
struct fann_train_data*
|
401
|
-
Data_Get_Struct
|
472
|
+
struct fann_train_data *t;
|
473
|
+
Data_Get_Struct(self, struct fann_train_data, t);
|
402
474
|
fann_shuffle_train_data(t);
|
403
475
|
return self;
|
404
476
|
}
|
@@ -406,9 +478,9 @@ static VALUE shuffle(VALUE self)
|
|
406
478
|
/** Length of training data*/
|
407
479
|
static VALUE length_train_data(VALUE self)
|
408
480
|
{
|
409
|
-
struct fann_train_data*
|
410
|
-
Data_Get_Struct
|
411
|
-
return(UINT2NUM(fann_length_train_data(t)));
|
481
|
+
struct fann_train_data *t;
|
482
|
+
Data_Get_Struct(self, struct fann_train_data, t);
|
483
|
+
return (UINT2NUM(fann_length_train_data(t)));
|
412
484
|
return self;
|
413
485
|
}
|
414
486
|
|
@@ -426,7 +498,7 @@ static VALUE set_activation_function(VALUE self, VALUE activation_func, VALUE la
|
|
426
498
|
Check_Type(layer, T_FIXNUM);
|
427
499
|
Check_Type(neuron, T_FIXNUM);
|
428
500
|
|
429
|
-
struct fann*
|
501
|
+
struct fann *f;
|
430
502
|
Data_Get_Struct(self, struct fann, f);
|
431
503
|
fann_set_activation_function(f, sym_to_activation_function(activation_func), NUM2INT(layer), NUM2INT(neuron));
|
432
504
|
return self;
|
@@ -442,7 +514,7 @@ static VALUE set_activation_function(VALUE self, VALUE activation_func, VALUE la
|
|
442
514
|
static VALUE set_activation_function_hidden(VALUE self, VALUE activation_func)
|
443
515
|
{
|
444
516
|
Check_Type(activation_func, T_SYMBOL);
|
445
|
-
struct fann*
|
517
|
+
struct fann *f;
|
446
518
|
Data_Get_Struct(self, struct fann, f);
|
447
519
|
fann_set_activation_function_hidden(f, sym_to_activation_function(activation_func));
|
448
520
|
return self;
|
@@ -463,7 +535,7 @@ static VALUE set_activation_function_layer(VALUE self, VALUE activation_func, VA
|
|
463
535
|
{
|
464
536
|
Check_Type(activation_func, T_SYMBOL);
|
465
537
|
Check_Type(layer, T_FIXNUM);
|
466
|
-
struct fann*
|
538
|
+
struct fann *f;
|
467
539
|
Data_Get_Struct(self, struct fann, f);
|
468
540
|
fann_set_activation_function_layer(f, sym_to_activation_function(activation_func), NUM2INT(layer));
|
469
541
|
return self;
|
@@ -480,7 +552,7 @@ static VALUE get_activation_function(VALUE self, VALUE layer, VALUE neuron)
|
|
480
552
|
{
|
481
553
|
Check_Type(layer, T_FIXNUM);
|
482
554
|
Check_Type(neuron, T_FIXNUM);
|
483
|
-
struct fann*
|
555
|
+
struct fann *f;
|
484
556
|
Data_Get_Struct(self, struct fann, f);
|
485
557
|
fann_type val = fann_get_activation_function(f, NUM2INT(layer), NUM2INT(neuron));
|
486
558
|
return activation_function_to_sym(val);
|
@@ -497,7 +569,7 @@ static VALUE get_activation_function(VALUE self, VALUE layer, VALUE neuron)
|
|
497
569
|
static VALUE set_activation_function_output(VALUE self, VALUE activation_func)
|
498
570
|
{
|
499
571
|
Check_Type(activation_func, T_SYMBOL);
|
500
|
-
struct fann*
|
572
|
+
struct fann *f;
|
501
573
|
Data_Get_Struct(self, struct fann, f);
|
502
574
|
fann_set_activation_function_output(f, sym_to_activation_function(activation_func));
|
503
575
|
return self;
|
@@ -511,7 +583,7 @@ static VALUE get_activation_steepness(VALUE self, VALUE layer, VALUE neuron)
|
|
511
583
|
{
|
512
584
|
Check_Type(layer, T_FIXNUM);
|
513
585
|
Check_Type(neuron, T_FIXNUM);
|
514
|
-
struct fann*
|
586
|
+
struct fann *f;
|
515
587
|
Data_Get_Struct(self, struct fann, f);
|
516
588
|
fann_type val = fann_get_activation_steepness(f, NUM2INT(layer), NUM2INT(neuron));
|
517
589
|
return rb_float_new(val);
|
@@ -527,7 +599,7 @@ static VALUE set_activation_steepness(VALUE self, VALUE steepness, VALUE layer,
|
|
527
599
|
Check_Type(layer, T_FIXNUM);
|
528
600
|
Check_Type(neuron, T_FIXNUM);
|
529
601
|
|
530
|
-
struct fann*
|
602
|
+
struct fann *f;
|
531
603
|
Data_Get_Struct(self, struct fann, f);
|
532
604
|
fann_set_activation_steepness(f, NUM2DBL(steepness), NUM2INT(layer), NUM2INT(neuron));
|
533
605
|
return self;
|
@@ -550,7 +622,7 @@ static VALUE set_activation_steepness_layer(VALUE self, VALUE steepness, VALUE l
|
|
550
622
|
Check_Type(steepness, T_FLOAT);
|
551
623
|
Check_Type(layer, T_FIXNUM);
|
552
624
|
|
553
|
-
struct fann*
|
625
|
+
struct fann *f;
|
554
626
|
Data_Get_Struct(self, struct fann, f);
|
555
627
|
fann_set_activation_steepness_layer(f, NUM2DBL(steepness), NUM2INT(layer));
|
556
628
|
return self;
|
@@ -684,9 +756,9 @@ static VALUE set_rprop_delta_zero(VALUE self, VALUE rprop_delta_zero)
|
|
684
756
|
/** Return array of bias(es)*/
|
685
757
|
static VALUE get_bias_array(VALUE self)
|
686
758
|
{
|
687
|
-
struct fann*
|
759
|
+
struct fann *f;
|
688
760
|
unsigned int num_layers;
|
689
|
-
Data_Get_Struct
|
761
|
+
Data_Get_Struct(self, struct fann, f);
|
690
762
|
num_layers = fann_get_num_layers(f);
|
691
763
|
unsigned int layers[num_layers];
|
692
764
|
fann_get_bias_array(f, layers);
|
@@ -695,7 +767,7 @@ static VALUE get_bias_array(VALUE self)
|
|
695
767
|
VALUE arr;
|
696
768
|
arr = rb_ary_new();
|
697
769
|
unsigned int i;
|
698
|
-
for (i=0; i<num_layers; i++)
|
770
|
+
for (i = 0; i < num_layers; i++)
|
699
771
|
{
|
700
772
|
rb_ary_push(arr, INT2NUM(layers[i]));
|
701
773
|
}
|
@@ -737,9 +809,9 @@ static VALUE get_neurons(VALUE self, VALUE layer)
|
|
737
809
|
struct fann_layer *layer_it;
|
738
810
|
struct fann_neuron *neuron_it;
|
739
811
|
|
740
|
-
struct fann*
|
812
|
+
struct fann *f;
|
741
813
|
unsigned int i;
|
742
|
-
Data_Get_Struct
|
814
|
+
Data_Get_Struct(self, struct fann, f);
|
743
815
|
|
744
816
|
VALUE neuron_array = rb_ary_new();
|
745
817
|
VALUE activation_function_sym = ID2SYM(rb_intern("activation_function"));
|
@@ -750,16 +822,17 @@ static VALUE get_neurons(VALUE self, VALUE layer)
|
|
750
822
|
VALUE connections_sym = ID2SYM(rb_intern("connections"));
|
751
823
|
unsigned int layer_num = 0;
|
752
824
|
|
753
|
-
|
754
|
-
|
755
|
-
for(layer_it = f->first_layer; layer_it != f->last_layer; layer_it++)
|
825
|
+
int nuke_bias_neuron = (fann_get_network_type(f) == FANN_NETTYPE_LAYER);
|
826
|
+
for (layer_it = f->first_layer; layer_it != f->last_layer; layer_it++)
|
756
827
|
{
|
757
|
-
for(neuron_it = layer_it->first_neuron; neuron_it != layer_it->last_neuron; neuron_it++)
|
828
|
+
for (neuron_it = layer_it->first_neuron; neuron_it != layer_it->last_neuron; neuron_it++)
|
758
829
|
{
|
759
|
-
if (nuke_bias_neuron && (neuron_it==(layer_it->last_neuron)-1))
|
830
|
+
if (nuke_bias_neuron && (neuron_it == (layer_it->last_neuron) - 1))
|
831
|
+
continue;
|
760
832
|
// Create array of connection indicies:
|
761
833
|
VALUE connection_array = rb_ary_new();
|
762
|
-
for (i = neuron_it->first_con; i < neuron_it->last_con; i++)
|
834
|
+
for (i = neuron_it->first_con; i < neuron_it->last_con; i++)
|
835
|
+
{
|
763
836
|
rb_ary_push(connection_array, INT2NUM(f->connections[i] - f->first_layer->first_neuron));
|
764
837
|
}
|
765
838
|
|
@@ -778,17 +851,16 @@ static VALUE get_neurons(VALUE self, VALUE layer)
|
|
778
851
|
++layer_num;
|
779
852
|
}
|
780
853
|
|
781
|
-
|
782
|
-
|
783
|
-
|
784
|
-
|
785
|
-
|
786
|
-
|
787
|
-
|
788
|
-
|
789
|
-
|
790
|
-
|
791
|
-
|
854
|
+
// switch (fann_get_network_type(ann)) {
|
855
|
+
// case FANN_NETTYPE_LAYER: {
|
856
|
+
// /* Report one bias in each layer except the last */
|
857
|
+
// if (layer_it != ann->last_layer-1)
|
858
|
+
// *bias = 1;
|
859
|
+
// else
|
860
|
+
// *bias = 0;
|
861
|
+
// break;
|
862
|
+
// }
|
863
|
+
// case FANN_NETTYPE_SHORTCUT: {
|
792
864
|
|
793
865
|
return neuron_array;
|
794
866
|
}
|
@@ -796,9 +868,9 @@ static VALUE get_neurons(VALUE self, VALUE layer)
|
|
796
868
|
/** Get list of layers in array format where each element contains number of neurons in that layer*/
|
797
869
|
static VALUE get_layer_array(VALUE self)
|
798
870
|
{
|
799
|
-
struct fann*
|
871
|
+
struct fann *f;
|
800
872
|
unsigned int num_layers;
|
801
|
-
Data_Get_Struct
|
873
|
+
Data_Get_Struct(self, struct fann, f);
|
802
874
|
num_layers = fann_get_num_layers(f);
|
803
875
|
unsigned int layers[num_layers];
|
804
876
|
fann_get_layer_array(f, layers);
|
@@ -807,12 +879,12 @@ static VALUE get_layer_array(VALUE self)
|
|
807
879
|
VALUE arr;
|
808
880
|
arr = rb_ary_new();
|
809
881
|
unsigned int i;
|
810
|
-
for (i=0; i<num_layers; i++)
|
882
|
+
for (i = 0; i < num_layers; i++)
|
811
883
|
{
|
812
884
|
rb_ary_push(arr, INT2NUM(layers[i]));
|
813
885
|
}
|
814
886
|
|
815
|
-
|
887
|
+
return arr;
|
816
888
|
}
|
817
889
|
|
818
890
|
/** Reads the mean square error from the network.*/
|
@@ -825,8 +897,8 @@ static VALUE get_MSE(VALUE self)
|
|
825
897
|
This function also resets the number of bits that fail.*/
|
826
898
|
static VALUE reset_MSE(VALUE self)
|
827
899
|
{
|
828
|
-
struct fann*
|
829
|
-
Data_Get_Struct
|
900
|
+
struct fann *f;
|
901
|
+
Data_Get_Struct(self, struct fann, f);
|
830
902
|
fann_reset_MSE(f);
|
831
903
|
return self;
|
832
904
|
}
|
@@ -834,18 +906,18 @@ static VALUE reset_MSE(VALUE self)
|
|
834
906
|
/** Get the type of network. Returns as ruby symbol (one of :shortcut, :layer)*/
|
835
907
|
static VALUE get_network_type(VALUE self)
|
836
908
|
{
|
837
|
-
struct fann*
|
909
|
+
struct fann *f;
|
838
910
|
enum fann_nettype_enum net_type;
|
839
911
|
VALUE ret_val;
|
840
|
-
Data_Get_Struct
|
912
|
+
Data_Get_Struct(self, struct fann, f);
|
841
913
|
|
842
914
|
net_type = fann_get_network_type(f);
|
843
915
|
|
844
|
-
if(net_type==FANN_NETTYPE_LAYER)
|
916
|
+
if (net_type == FANN_NETTYPE_LAYER)
|
845
917
|
{
|
846
918
|
ret_val = ID2SYM(rb_intern("layer")); // (rb_str_new2("FANN_NETTYPE_LAYER"));
|
847
919
|
}
|
848
|
-
else if(net_type==FANN_NETTYPE_SHORTCUT)
|
920
|
+
else if (net_type == FANN_NETTYPE_SHORTCUT)
|
849
921
|
{
|
850
922
|
ret_val = ID2SYM(rb_intern("shortcut")); // (rb_str_new2("FANN_NETTYPE_SHORTCUT"));
|
851
923
|
}
|
@@ -890,19 +962,24 @@ static VALUE set_train_error_function(VALUE self, VALUE train_error_function)
|
|
890
962
|
{
|
891
963
|
Check_Type(train_error_function, T_SYMBOL);
|
892
964
|
|
893
|
-
ID id=SYM2ID(train_error_function);
|
965
|
+
ID id = SYM2ID(train_error_function);
|
894
966
|
enum fann_errorfunc_enum fann_train_error_function;
|
895
967
|
|
896
|
-
if(id==rb_intern("linear"))
|
968
|
+
if (id == rb_intern("linear"))
|
969
|
+
{
|
897
970
|
fann_train_error_function = FANN_ERRORFUNC_LINEAR;
|
898
|
-
}
|
971
|
+
}
|
972
|
+
else if (id == rb_intern("tanh"))
|
973
|
+
{
|
899
974
|
fann_train_error_function = FANN_ERRORFUNC_TANH;
|
900
|
-
}
|
975
|
+
}
|
976
|
+
else
|
977
|
+
{
|
901
978
|
rb_raise(rb_eRuntimeError, "Unrecognized train error function: [%s]", rb_id2name(SYM2ID(train_error_function)));
|
902
979
|
}
|
903
980
|
|
904
|
-
struct fann*
|
905
|
-
Data_Get_Struct
|
981
|
+
struct fann *f;
|
982
|
+
Data_Get_Struct(self, struct fann, f);
|
906
983
|
fann_set_train_error_function(f, fann_train_error_function);
|
907
984
|
return self;
|
908
985
|
}
|
@@ -911,14 +988,14 @@ static VALUE set_train_error_function(VALUE self, VALUE train_error_function)
|
|
911
988
|
:linear, :tanh*/
|
912
989
|
static VALUE get_train_error_function(VALUE self)
|
913
990
|
{
|
914
|
-
struct fann*
|
991
|
+
struct fann *f;
|
915
992
|
enum fann_errorfunc_enum train_error;
|
916
993
|
VALUE ret_val;
|
917
|
-
Data_Get_Struct
|
994
|
+
Data_Get_Struct(self, struct fann, f);
|
918
995
|
|
919
996
|
train_error = fann_get_train_error_function(f);
|
920
997
|
|
921
|
-
if(train_error==FANN_ERRORFUNC_LINEAR)
|
998
|
+
if (train_error == FANN_ERRORFUNC_LINEAR)
|
922
999
|
{
|
923
1000
|
ret_val = ID2SYM(rb_intern("linear"));
|
924
1001
|
}
|
@@ -937,23 +1014,32 @@ static VALUE set_training_algorithm(VALUE self, VALUE train_error_function)
|
|
937
1014
|
{
|
938
1015
|
Check_Type(train_error_function, T_SYMBOL);
|
939
1016
|
|
940
|
-
ID id=SYM2ID(train_error_function);
|
1017
|
+
ID id = SYM2ID(train_error_function);
|
941
1018
|
enum fann_train_enum fann_train_algorithm;
|
942
1019
|
|
943
|
-
if(id==rb_intern("incremental"))
|
1020
|
+
if (id == rb_intern("incremental"))
|
1021
|
+
{
|
944
1022
|
fann_train_algorithm = FANN_TRAIN_INCREMENTAL;
|
945
|
-
}
|
1023
|
+
}
|
1024
|
+
else if (id == rb_intern("batch"))
|
1025
|
+
{
|
946
1026
|
fann_train_algorithm = FANN_TRAIN_BATCH;
|
947
|
-
}
|
1027
|
+
}
|
1028
|
+
else if (id == rb_intern("rprop"))
|
1029
|
+
{
|
948
1030
|
fann_train_algorithm = FANN_TRAIN_RPROP;
|
949
|
-
}
|
1031
|
+
}
|
1032
|
+
else if (id == rb_intern("quickprop"))
|
1033
|
+
{
|
950
1034
|
fann_train_algorithm = FANN_TRAIN_QUICKPROP;
|
951
|
-
}
|
1035
|
+
}
|
1036
|
+
else
|
1037
|
+
{
|
952
1038
|
rb_raise(rb_eRuntimeError, "Unrecognized training algorithm function: [%s]", rb_id2name(SYM2ID(train_error_function)));
|
953
1039
|
}
|
954
1040
|
|
955
|
-
struct fann*
|
956
|
-
Data_Get_Struct
|
1041
|
+
struct fann *f;
|
1042
|
+
Data_Get_Struct(self, struct fann, f);
|
957
1043
|
fann_set_training_algorithm(f, fann_train_algorithm);
|
958
1044
|
return self;
|
959
1045
|
}
|
@@ -962,20 +1048,27 @@ static VALUE set_training_algorithm(VALUE self, VALUE train_error_function)
|
|
962
1048
|
:incremental, :batch, :rprop, :quickprop */
|
963
1049
|
static VALUE get_training_algorithm(VALUE self)
|
964
1050
|
{
|
965
|
-
struct fann*
|
1051
|
+
struct fann *f;
|
966
1052
|
enum fann_train_enum fann_train_algorithm;
|
967
1053
|
VALUE ret_val;
|
968
|
-
Data_Get_Struct
|
1054
|
+
Data_Get_Struct(self, struct fann, f);
|
969
1055
|
|
970
1056
|
fann_train_algorithm = fann_get_training_algorithm(f);
|
971
1057
|
|
972
|
-
if(fann_train_algorithm==FANN_TRAIN_INCREMENTAL)
|
1058
|
+
if (fann_train_algorithm == FANN_TRAIN_INCREMENTAL)
|
1059
|
+
{
|
973
1060
|
ret_val = ID2SYM(rb_intern("incremental"));
|
974
|
-
}
|
1061
|
+
}
|
1062
|
+
else if (fann_train_algorithm == FANN_TRAIN_BATCH)
|
1063
|
+
{
|
975
1064
|
ret_val = ID2SYM(rb_intern("batch"));
|
976
|
-
}
|
1065
|
+
}
|
1066
|
+
else if (fann_train_algorithm == FANN_TRAIN_RPROP)
|
1067
|
+
{
|
977
1068
|
ret_val = ID2SYM(rb_intern("rprop"));
|
978
|
-
}
|
1069
|
+
}
|
1070
|
+
else if (fann_train_algorithm == FANN_TRAIN_QUICKPROP)
|
1071
|
+
{
|
979
1072
|
ret_val = ID2SYM(rb_intern("quickprop"));
|
980
1073
|
}
|
981
1074
|
return ret_val;
|
@@ -988,19 +1081,24 @@ static VALUE get_training_algorithm(VALUE self)
|
|
988
1081
|
static VALUE set_train_stop_function(VALUE self, VALUE train_stop_function)
|
989
1082
|
{
|
990
1083
|
Check_Type(train_stop_function, T_SYMBOL);
|
991
|
-
ID id=SYM2ID(train_stop_function);
|
1084
|
+
ID id = SYM2ID(train_stop_function);
|
992
1085
|
enum fann_stopfunc_enum fann_train_stop_function;
|
993
1086
|
|
994
|
-
if(id==rb_intern("mse"))
|
1087
|
+
if (id == rb_intern("mse"))
|
1088
|
+
{
|
995
1089
|
fann_train_stop_function = FANN_STOPFUNC_MSE;
|
996
|
-
}
|
1090
|
+
}
|
1091
|
+
else if (id == rb_intern("bit"))
|
1092
|
+
{
|
997
1093
|
fann_train_stop_function = FANN_STOPFUNC_BIT;
|
998
|
-
}
|
1094
|
+
}
|
1095
|
+
else
|
1096
|
+
{
|
999
1097
|
rb_raise(rb_eRuntimeError, "Unrecognized stop function: [%s]", rb_id2name(SYM2ID(train_stop_function)));
|
1000
1098
|
}
|
1001
1099
|
|
1002
|
-
struct fann*
|
1003
|
-
Data_Get_Struct
|
1100
|
+
struct fann *f;
|
1101
|
+
Data_Get_Struct(self, struct fann, f);
|
1004
1102
|
fann_set_train_stop_function(f, fann_train_stop_function);
|
1005
1103
|
return self;
|
1006
1104
|
}
|
@@ -1009,14 +1107,14 @@ static VALUE set_train_stop_function(VALUE self, VALUE train_stop_function)
|
|
1009
1107
|
:mse, :bit */
|
1010
1108
|
static VALUE get_train_stop_function(VALUE self)
|
1011
1109
|
{
|
1012
|
-
struct fann*
|
1110
|
+
struct fann *f;
|
1013
1111
|
enum fann_stopfunc_enum train_stop;
|
1014
1112
|
VALUE ret_val;
|
1015
|
-
Data_Get_Struct
|
1113
|
+
Data_Get_Struct(self, struct fann, f);
|
1016
1114
|
|
1017
1115
|
train_stop = fann_get_train_stop_function(f);
|
1018
1116
|
|
1019
|
-
if(train_stop==FANN_STOPFUNC_MSE)
|
1117
|
+
if (train_stop == FANN_STOPFUNC_MSE)
|
1020
1118
|
{
|
1021
1119
|
ret_val = ID2SYM(rb_intern("mse")); // (rb_str_new2("FANN_NETTYPE_LAYER"));
|
1022
1120
|
}
|
@@ -1027,13 +1125,12 @@ static VALUE get_train_stop_function(VALUE self)
|
|
1027
1125
|
return ret_val;
|
1028
1126
|
}
|
1029
1127
|
|
1030
|
-
|
1031
1128
|
/** Will print the connections of the ann in a compact matrix,
|
1032
1129
|
for easy viewing of the internals of the ann. */
|
1033
1130
|
static VALUE print_connections(VALUE self)
|
1034
1131
|
{
|
1035
|
-
struct fann*
|
1036
|
-
Data_Get_Struct
|
1132
|
+
struct fann *f;
|
1133
|
+
Data_Get_Struct(self, struct fann, f);
|
1037
1134
|
fann_print_connections(f);
|
1038
1135
|
return self;
|
1039
1136
|
}
|
@@ -1041,8 +1138,8 @@ static VALUE print_connections(VALUE self)
|
|
1041
1138
|
/** Print current NN parameters to stdout */
|
1042
1139
|
static VALUE print_parameters(VALUE self)
|
1043
1140
|
{
|
1044
|
-
struct fann*
|
1045
|
-
Data_Get_Struct
|
1141
|
+
struct fann *f;
|
1142
|
+
Data_Get_Struct(self, struct fann, f);
|
1046
1143
|
fann_print_parameters(f);
|
1047
1144
|
return Qnil;
|
1048
1145
|
}
|
@@ -1054,8 +1151,8 @@ static VALUE randomize_weights(VALUE self, VALUE min_weight, VALUE max_weight)
|
|
1054
1151
|
{
|
1055
1152
|
Check_Type(min_weight, T_FLOAT);
|
1056
1153
|
Check_Type(max_weight, T_FLOAT);
|
1057
|
-
struct fann*
|
1058
|
-
Data_Get_Struct
|
1154
|
+
struct fann *f;
|
1155
|
+
Data_Get_Struct(self, struct fann, f);
|
1059
1156
|
fann_randomize_weights(f, NUM2DBL(min_weight), NUM2DBL(max_weight));
|
1060
1157
|
return self;
|
1061
1158
|
}
|
@@ -1064,37 +1161,36 @@ static VALUE randomize_weights(VALUE self, VALUE min_weight, VALUE max_weight)
|
|
1064
1161
|
|
1065
1162
|
Run neural net on array<Float> of inputs with current parameters.
|
1066
1163
|
Returns array<Float> as output */
|
1067
|
-
static VALUE run
|
1164
|
+
static VALUE run(VALUE self, VALUE inputs)
|
1068
1165
|
{
|
1069
1166
|
Check_Type(inputs, T_ARRAY);
|
1070
1167
|
|
1071
|
-
|
1168
|
+
struct fann *f;
|
1072
1169
|
unsigned int i;
|
1073
|
-
fann_type*
|
1170
|
+
fann_type *outputs;
|
1074
1171
|
|
1075
1172
|
// Convert inputs to type needed for NN:
|
1076
1173
|
unsigned int len = RARRAY_LEN(inputs);
|
1077
1174
|
fann_type fann_inputs[len];
|
1078
|
-
for (i=0; i<len; i++)
|
1175
|
+
for (i = 0; i < len; i++)
|
1079
1176
|
{
|
1080
1177
|
fann_inputs[i] = NUM2DBL(RARRAY_PTR(inputs)[i]);
|
1081
1178
|
}
|
1082
1179
|
|
1083
|
-
|
1084
1180
|
// Obtain NN & run method:
|
1085
|
-
|
1181
|
+
Data_Get_Struct(self, struct fann, f);
|
1086
1182
|
outputs = fann_run(f, fann_inputs);
|
1087
1183
|
|
1088
1184
|
// Create ruby array & set outputs:
|
1089
1185
|
VALUE arr;
|
1090
1186
|
arr = rb_ary_new();
|
1091
|
-
unsigned int output_len=fann_get_num_output(f);
|
1092
|
-
for (i=0; i<output_len; i++)
|
1187
|
+
unsigned int output_len = fann_get_num_output(f);
|
1188
|
+
for (i = 0; i < output_len; i++)
|
1093
1189
|
{
|
1094
1190
|
rb_ary_push(arr, rb_float_new(outputs[i]));
|
1095
1191
|
}
|
1096
1192
|
|
1097
|
-
|
1193
|
+
return arr;
|
1098
1194
|
}
|
1099
1195
|
|
1100
1196
|
/** call-seq: init_weights(train_data) -> return value
|
@@ -1105,10 +1201,10 @@ static VALUE init_weights(VALUE self, VALUE train_data)
|
|
1105
1201
|
|
1106
1202
|
Check_Type(train_data, T_DATA);
|
1107
1203
|
|
1108
|
-
struct fann*
|
1109
|
-
struct fann_train_data*
|
1110
|
-
Data_Get_Struct
|
1111
|
-
Data_Get_Struct
|
1204
|
+
struct fann *f;
|
1205
|
+
struct fann_train_data *t;
|
1206
|
+
Data_Get_Struct(self, struct fann, f);
|
1207
|
+
Data_Get_Struct(train_data, struct fann_train_data, t);
|
1112
1208
|
|
1113
1209
|
fann_init_weights(f, t);
|
1114
1210
|
return self;
|
@@ -1124,7 +1220,7 @@ static VALUE train(VALUE self, VALUE input, VALUE expected_output)
|
|
1124
1220
|
Check_Type(input, T_ARRAY);
|
1125
1221
|
Check_Type(expected_output, T_ARRAY);
|
1126
1222
|
|
1127
|
-
struct fann*
|
1223
|
+
struct fann *f;
|
1128
1224
|
Data_Get_Struct(self, struct fann, f);
|
1129
1225
|
|
1130
1226
|
unsigned int num_input = RARRAY_LEN(input);
|
@@ -1134,11 +1230,13 @@ static VALUE train(VALUE self, VALUE input, VALUE expected_output)
|
|
1134
1230
|
|
1135
1231
|
unsigned int i;
|
1136
1232
|
|
1137
|
-
for (i = 0; i < num_input; i++)
|
1233
|
+
for (i = 0; i < num_input; i++)
|
1234
|
+
{
|
1138
1235
|
data_input[i] = NUM2DBL(RARRAY_PTR(input)[i]);
|
1139
1236
|
}
|
1140
1237
|
|
1141
|
-
for (i = 0; i < num_output; i++)
|
1238
|
+
for (i = 0; i < num_output; i++)
|
1239
|
+
{
|
1142
1240
|
data_output[i] = NUM2DBL(RARRAY_PTR(expected_output)[i]);
|
1143
1241
|
}
|
1144
1242
|
|
@@ -1161,10 +1259,10 @@ static VALUE train_on_data(VALUE self, VALUE train_data, VALUE max_epochs, VALUE
|
|
1161
1259
|
Check_Type(epochs_between_reports, T_FIXNUM);
|
1162
1260
|
Check_Type(desired_error, T_FLOAT);
|
1163
1261
|
|
1164
|
-
struct fann*
|
1165
|
-
struct fann_train_data*
|
1166
|
-
Data_Get_Struct
|
1167
|
-
Data_Get_Struct
|
1262
|
+
struct fann *f;
|
1263
|
+
struct fann_train_data *t;
|
1264
|
+
Data_Get_Struct(self, struct fann, f);
|
1265
|
+
Data_Get_Struct(train_data, struct fann_train_data, t);
|
1168
1266
|
|
1169
1267
|
unsigned int fann_max_epochs = NUM2INT(max_epochs);
|
1170
1268
|
unsigned int fann_epochs_between_reports = NUM2INT(epochs_between_reports);
|
@@ -1179,10 +1277,10 @@ static VALUE train_on_data(VALUE self, VALUE train_data, VALUE max_epochs, VALUE
|
|
1179
1277
|
static VALUE train_epoch(VALUE self, VALUE train_data)
|
1180
1278
|
{
|
1181
1279
|
Check_Type(train_data, T_DATA);
|
1182
|
-
struct fann*
|
1183
|
-
struct fann_train_data*
|
1184
|
-
Data_Get_Struct
|
1185
|
-
Data_Get_Struct
|
1280
|
+
struct fann *f;
|
1281
|
+
struct fann_train_data *t;
|
1282
|
+
Data_Get_Struct(self, struct fann, f);
|
1283
|
+
Data_Get_Struct(train_data, struct fann_train_data, t);
|
1186
1284
|
return rb_float_new(fann_train_epoch(f, t));
|
1187
1285
|
}
|
1188
1286
|
|
@@ -1192,10 +1290,10 @@ static VALUE train_epoch(VALUE self, VALUE train_data)
|
|
1192
1290
|
static VALUE test_data(VALUE self, VALUE train_data)
|
1193
1291
|
{
|
1194
1292
|
Check_Type(train_data, T_DATA);
|
1195
|
-
struct fann*
|
1196
|
-
struct fann_train_data*
|
1197
|
-
Data_Get_Struct
|
1198
|
-
Data_Get_Struct
|
1293
|
+
struct fann *f;
|
1294
|
+
struct fann_train_data *t;
|
1295
|
+
Data_Get_Struct(self, struct fann, f);
|
1296
|
+
Data_Get_Struct(train_data, struct fann_train_data, t);
|
1199
1297
|
return rb_float_new(fann_test_data(f, t));
|
1200
1298
|
}
|
1201
1299
|
|
@@ -1232,10 +1330,10 @@ static VALUE cascadetrain_on_data(VALUE self, VALUE train_data, VALUE max_neuron
|
|
1232
1330
|
Check_Type(neurons_between_reports, T_FIXNUM);
|
1233
1331
|
Check_Type(desired_error, T_FLOAT);
|
1234
1332
|
|
1235
|
-
struct fann*
|
1236
|
-
struct fann_train_data*
|
1237
|
-
Data_Get_Struct
|
1238
|
-
Data_Get_Struct
|
1333
|
+
struct fann *f;
|
1334
|
+
struct fann_train_data *t;
|
1335
|
+
Data_Get_Struct(self, struct fann, f);
|
1336
|
+
Data_Get_Struct(train_data, struct fann_train_data, t);
|
1239
1337
|
|
1240
1338
|
unsigned int fann_max_neurons = NUM2INT(max_neurons);
|
1241
1339
|
unsigned int fann_neurons_between_reports = NUM2INT(neurons_between_reports);
|
@@ -1305,7 +1403,6 @@ static VALUE set_cascade_candidate_stagnation_epochs(VALUE self, VALUE cascade_c
|
|
1305
1403
|
SET_FANN_UINT(cascade_candidate_stagnation_epochs, fann_set_cascade_candidate_stagnation_epochs);
|
1306
1404
|
}
|
1307
1405
|
|
1308
|
-
|
1309
1406
|
/** The weight multiplier is a parameter which is used to multiply the weights from the candidate neuron
|
1310
1407
|
before adding the neuron to the neural network. This parameter is usually between 0 and 1, and is used
|
1311
1408
|
to make the training a bit less aggressive. */
|
@@ -1426,13 +1523,13 @@ static VALUE set_learning_momentum(VALUE self, VALUE learning_momentum)
|
|
1426
1523
|
static VALUE set_cascade_activation_functions(VALUE self, VALUE cascade_activation_functions)
|
1427
1524
|
{
|
1428
1525
|
Check_Type(cascade_activation_functions, T_ARRAY);
|
1429
|
-
struct fann*
|
1430
|
-
Data_Get_Struct
|
1526
|
+
struct fann *f;
|
1527
|
+
Data_Get_Struct(self, struct fann, f);
|
1431
1528
|
|
1432
1529
|
unsigned long cnt = RARRAY_LEN(cascade_activation_functions);
|
1433
1530
|
enum fann_activationfunc_enum fann_activation_functions[cnt];
|
1434
1531
|
unsigned int i;
|
1435
|
-
for (i=0; i<cnt; i++)
|
1532
|
+
for (i = 0; i < cnt; i++)
|
1436
1533
|
{
|
1437
1534
|
fann_activation_functions[i] = sym_to_activation_function(RARRAY_PTR(cascade_activation_functions)[i]);
|
1438
1535
|
}
|
@@ -1445,16 +1542,16 @@ static VALUE set_cascade_activation_functions(VALUE self, VALUE cascade_activati
|
|
1445
1542
|
the candidates. The default is [:sigmoid, :sigmoid_symmetric, :gaussian, :gaussian_symmetric, :elliot, :elliot_symmetric] */
|
1446
1543
|
static VALUE get_cascade_activation_functions(VALUE self)
|
1447
1544
|
{
|
1448
|
-
struct fann*
|
1449
|
-
Data_Get_Struct
|
1545
|
+
struct fann *f;
|
1546
|
+
Data_Get_Struct(self, struct fann, f);
|
1450
1547
|
unsigned int cnt = fann_get_cascade_activation_functions_count(f);
|
1451
|
-
enum fann_activationfunc_enum*
|
1548
|
+
enum fann_activationfunc_enum *fann_functions = fann_get_cascade_activation_functions(f);
|
1452
1549
|
|
1453
1550
|
// Create ruby array & set outputs:
|
1454
1551
|
VALUE arr;
|
1455
1552
|
arr = rb_ary_new();
|
1456
1553
|
unsigned int i;
|
1457
|
-
for (i=0; i<cnt; i++)
|
1554
|
+
for (i = 0; i < cnt; i++)
|
1458
1555
|
{
|
1459
1556
|
rb_ary_push(arr, activation_function_to_sym(fann_functions[i]));
|
1460
1557
|
}
|
@@ -1490,13 +1587,13 @@ static VALUE set_cascade_num_candidate_groups(VALUE self, VALUE cascade_num_cand
|
|
1490
1587
|
static VALUE set_cascade_activation_steepnesses(VALUE self, VALUE cascade_activation_steepnesses)
|
1491
1588
|
{
|
1492
1589
|
Check_Type(cascade_activation_steepnesses, T_ARRAY);
|
1493
|
-
struct fann*
|
1494
|
-
Data_Get_Struct
|
1590
|
+
struct fann *f;
|
1591
|
+
Data_Get_Struct(self, struct fann, f);
|
1495
1592
|
|
1496
1593
|
unsigned int cnt = RARRAY_LEN(cascade_activation_steepnesses);
|
1497
1594
|
fann_type fann_activation_steepnesses[cnt];
|
1498
1595
|
unsigned int i;
|
1499
|
-
for (i=0; i<cnt; i++)
|
1596
|
+
for (i = 0; i < cnt; i++)
|
1500
1597
|
{
|
1501
1598
|
fann_activation_steepnesses[i] = NUM2DBL(RARRAY_PTR(cascade_activation_steepnesses)[i]);
|
1502
1599
|
}
|
@@ -1509,16 +1606,16 @@ static VALUE set_cascade_activation_steepnesses(VALUE self, VALUE cascade_activa
|
|
1509
1606
|
the candidates. */
|
1510
1607
|
static VALUE get_cascade_activation_steepnesses(VALUE self)
|
1511
1608
|
{
|
1512
|
-
struct fann*
|
1513
|
-
Data_Get_Struct
|
1514
|
-
fann_type*
|
1609
|
+
struct fann *f;
|
1610
|
+
Data_Get_Struct(self, struct fann, f);
|
1611
|
+
fann_type *fann_steepnesses = fann_get_cascade_activation_steepnesses(f);
|
1515
1612
|
unsigned int cnt = fann_get_cascade_activation_steepnesses_count(f);
|
1516
1613
|
|
1517
1614
|
// Create ruby array & set outputs:
|
1518
1615
|
VALUE arr;
|
1519
1616
|
arr = rb_ary_new();
|
1520
1617
|
unsigned int i;
|
1521
|
-
for (i=0; i<cnt; i++)
|
1618
|
+
for (i = 0; i < cnt; i++)
|
1522
1619
|
{
|
1523
1620
|
rb_ary_push(arr, rb_float_new(fann_steepnesses[i]));
|
1524
1621
|
}
|
@@ -1531,21 +1628,21 @@ static VALUE get_cascade_activation_steepnesses(VALUE self)
|
|
1531
1628
|
Save the entire network to configuration file with given name */
|
1532
1629
|
static VALUE nn_save(VALUE self, VALUE filename)
|
1533
1630
|
{
|
1534
|
-
struct fann*
|
1535
|
-
Data_Get_Struct
|
1631
|
+
struct fann *f;
|
1632
|
+
Data_Get_Struct(self, struct fann, f);
|
1536
1633
|
int status = fann_save(f, StringValuePtr(filename));
|
1537
1634
|
return INT2NUM(status);
|
1538
1635
|
}
|
1539
1636
|
|
1540
1637
|
/** Initializes class under RubyFann module/namespace. */
|
1541
|
-
void Init_ruby_fann
|
1638
|
+
void Init_ruby_fann()
|
1542
1639
|
{
|
1543
1640
|
// RubyFann module/namespace:
|
1544
|
-
m_rb_fann_module = rb_define_module
|
1641
|
+
m_rb_fann_module = rb_define_module("RubyFann");
|
1545
1642
|
|
1546
1643
|
// Standard NN class:
|
1547
|
-
m_rb_fann_standard_class = rb_define_class_under
|
1548
|
-
rb_define_alloc_func
|
1644
|
+
m_rb_fann_standard_class = rb_define_class_under(m_rb_fann_module, "Standard", rb_cObject);
|
1645
|
+
rb_define_alloc_func(m_rb_fann_standard_class, fann_allocate);
|
1549
1646
|
rb_define_method(m_rb_fann_standard_class, "initialize", fann_initialize, 1);
|
1550
1647
|
rb_define_method(m_rb_fann_standard_class, "init_weights", init_weights, 1);
|
1551
1648
|
rb_define_method(m_rb_fann_standard_class, "set_activation_function", set_activation_function, 3);
|
@@ -1608,7 +1705,6 @@ void Init_ruby_fann ()
|
|
1608
1705
|
rb_define_method(m_rb_fann_standard_class, "get_training_algorithm", get_training_algorithm, 0);
|
1609
1706
|
rb_define_method(m_rb_fann_standard_class, "set_training_algorithm", set_training_algorithm, 1);
|
1610
1707
|
|
1611
|
-
|
1612
1708
|
// Cascade functions:
|
1613
1709
|
rb_define_method(m_rb_fann_standard_class, "cascadetrain_on_data", cascadetrain_on_data, 4);
|
1614
1710
|
rb_define_method(m_rb_fann_standard_class, "get_cascade_output_change_fraction", get_cascade_output_change_fraction, 0);
|
@@ -1638,14 +1734,13 @@ void Init_ruby_fann ()
|
|
1638
1734
|
rb_define_method(m_rb_fann_standard_class, "set_cascade_num_candidate_groups", set_cascade_num_candidate_groups, 1);
|
1639
1735
|
rb_define_method(m_rb_fann_standard_class, "save", nn_save, 1);
|
1640
1736
|
|
1641
|
-
|
1642
1737
|
// Uncomment for fixed-point mode (also recompile fann). Probably not going to be needed:
|
1643
|
-
//rb_define_method(clazz, "get_decimal_point", get_decimal_point, 0);
|
1644
|
-
//rb_define_method(clazz, "get_multiplier", get_multiplier, 0);
|
1738
|
+
// rb_define_method(clazz, "get_decimal_point", get_decimal_point, 0);
|
1739
|
+
// rb_define_method(clazz, "get_multiplier", get_multiplier, 0);
|
1645
1740
|
|
1646
1741
|
// Shortcut NN class (duplicated from above so that rdoc generation tools can find the methods:):
|
1647
|
-
m_rb_fann_shortcut_class = rb_define_class_under
|
1648
|
-
rb_define_alloc_func
|
1742
|
+
m_rb_fann_shortcut_class = rb_define_class_under(m_rb_fann_module, "Shortcut", rb_cObject);
|
1743
|
+
rb_define_alloc_func(m_rb_fann_shortcut_class, fann_allocate);
|
1649
1744
|
rb_define_method(m_rb_fann_shortcut_class, "initialize", fann_initialize, 1);
|
1650
1745
|
rb_define_method(m_rb_fann_shortcut_class, "init_weights", init_weights, 1);
|
1651
1746
|
rb_define_method(m_rb_fann_shortcut_class, "set_activation_function", set_activation_function, 3);
|
@@ -1737,10 +1832,9 @@ void Init_ruby_fann ()
|
|
1737
1832
|
rb_define_method(m_rb_fann_shortcut_class, "set_cascade_num_candidate_groups", set_cascade_num_candidate_groups, 1);
|
1738
1833
|
rb_define_method(m_rb_fann_shortcut_class, "save", nn_save, 1);
|
1739
1834
|
|
1740
|
-
|
1741
1835
|
// TrainData NN class:
|
1742
|
-
m_rb_fann_train_data_class = rb_define_class_under
|
1743
|
-
rb_define_alloc_func
|
1836
|
+
m_rb_fann_train_data_class = rb_define_class_under(m_rb_fann_module, "TrainData", rb_cObject);
|
1837
|
+
rb_define_alloc_func(m_rb_fann_train_data_class, fann_training_data_allocate);
|
1744
1838
|
rb_define_method(m_rb_fann_train_data_class, "initialize", fann_train_data_initialize, 1);
|
1745
1839
|
rb_define_method(m_rb_fann_train_data_class, "length", length_train_data, 0);
|
1746
1840
|
rb_define_method(m_rb_fann_train_data_class, "shuffle", shuffle, 0);
|