moo_fann 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1768 @@
+ /*
+ * Copyright 2018 Maxine Michalski <maxine@furfind.net>
+ * Copyright 2013 ruby-fann contributors
+ * <https://github.com/tangledpath/ruby-fann#contributors>
+ *
+ * This file is part of moo_fann.
+ *
+ * moo_fann is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * moo_fann is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with moo_fann. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+ #include "ruby.h"
+ #include "ruby_compat.h"
+ #include "doublefann.h"
+ #include "fann_data.h"
+ #include "fann_augment.h"
+
+ static VALUE m_rb_fann_module;
+ static VALUE m_rb_fann_standard_class;
+ static VALUE m_rb_fann_shortcut_class;
+ static VALUE m_rb_fann_train_data_class;
+
+ #define RETURN_FANN_INT(fn) \
+   struct fann* f; \
+   Data_Get_Struct(self, struct fann, f); \
+   return INT2NUM(fn(f));
+
+ #define SET_FANN_INT(attr_name, fann_fn) \
+   Check_Type(attr_name, T_FIXNUM); \
+   struct fann* f; \
+   Data_Get_Struct(self, struct fann, f); \
+   fann_fn(f, NUM2INT(attr_name)); \
+   return 0;
+
+ #define RETURN_FANN_UINT(fn) \
+   struct fann* f; \
+   Data_Get_Struct(self, struct fann, f); \
+   return UINT2NUM(fn(f));
+
+ #define SET_FANN_UINT(attr_name, fann_fn) \
+   Check_Type(attr_name, T_FIXNUM); \
+   struct fann* f; \
+   Data_Get_Struct(self, struct fann, f); \
+   fann_fn(f, NUM2UINT(attr_name)); \
+   return 0;
+
+ // Converts a float return value to a double with the same printed
+ // precision, avoiding spurious floating-point noise.
+ #define RETURN_FANN_FLT(fn) \
+   struct fann* f; \
+   Data_Get_Struct(self, struct fann, f); \
+   char buffy[20]; \
+   sprintf(buffy, "%0.6g", fn(f)); \
+   return rb_float_new(atof(buffy));
+
+ #define SET_FANN_FLT(attr_name, fann_fn) \
+   Check_Type(attr_name, T_FLOAT); \
+   struct fann* f; \
+   Data_Get_Struct(self, struct fann, f); \
+   fann_fn(f, NUM2DBL(attr_name)); \
+   return self;
+
+ #define RETURN_FANN_DBL(fn) \
+   struct fann* f; \
+   Data_Get_Struct(self, struct fann, f); \
+   return rb_float_new(fn(f));
+
+ #define SET_FANN_DBL SET_FANN_FLT
+
+ // Convert ruby symbol to corresponding FANN enum type for activation function:
+ enum fann_activationfunc_enum sym_to_activation_function(VALUE activation_func)
+ {
+   ID id = SYM2ID(activation_func);
+   enum fann_activationfunc_enum activation_function;
+   if(id==rb_intern("linear")) {
+     activation_function = FANN_LINEAR;
+   } else if(id==rb_intern("threshold")) {
+     activation_function = FANN_THRESHOLD;
+   } else if(id==rb_intern("threshold_symmetric")) {
+     activation_function = FANN_THRESHOLD_SYMMETRIC;
+   } else if(id==rb_intern("sigmoid")) {
+     activation_function = FANN_SIGMOID;
+   } else if(id==rb_intern("sigmoid_stepwise")) {
+     activation_function = FANN_SIGMOID_STEPWISE;
+   } else if(id==rb_intern("sigmoid_symmetric")) {
+     activation_function = FANN_SIGMOID_SYMMETRIC;
+   } else if(id==rb_intern("sigmoid_symmetric_stepwise")) {
+     activation_function = FANN_SIGMOID_SYMMETRIC_STEPWISE;
+   } else if(id==rb_intern("gaussian")) {
+     activation_function = FANN_GAUSSIAN;
+   } else if(id==rb_intern("gaussian_symmetric")) {
+     activation_function = FANN_GAUSSIAN_SYMMETRIC;
+   } else if(id==rb_intern("gaussian_stepwise")) {
+     activation_function = FANN_GAUSSIAN_STEPWISE;
+   } else if(id==rb_intern("elliot")) {
+     activation_function = FANN_ELLIOT;
+   } else if(id==rb_intern("elliot_symmetric")) {
+     activation_function = FANN_ELLIOT_SYMMETRIC;
+   } else if(id==rb_intern("linear_piece")) {
+     activation_function = FANN_LINEAR_PIECE;
+   } else if(id==rb_intern("linear_piece_symmetric")) {
+     activation_function = FANN_LINEAR_PIECE_SYMMETRIC;
+   } else if(id==rb_intern("sin_symmetric")) {
+     activation_function = FANN_SIN_SYMMETRIC;
+   } else if(id==rb_intern("cos_symmetric")) {
+     activation_function = FANN_COS_SYMMETRIC;
+   } else if(id==rb_intern("sin")) {
+     activation_function = FANN_SIN;
+   } else if(id==rb_intern("cos")) {
+     activation_function = FANN_COS;
+   } else {
+     rb_raise(rb_eRuntimeError, "Unrecognized activation function: [%s]", rb_id2name(id));
+   }
+   return activation_function;
+ }
+
+ // Convert FANN enum type for activation function to corresponding ruby symbol:
+ VALUE activation_function_to_sym(enum fann_activationfunc_enum fn)
+ {
+   VALUE activation_function;
+
+   if(fn==FANN_LINEAR) {
+     activation_function = ID2SYM(rb_intern("linear"));
+   } else if(fn==FANN_THRESHOLD) {
+     activation_function = ID2SYM(rb_intern("threshold"));
+   } else if(fn==FANN_THRESHOLD_SYMMETRIC) {
+     activation_function = ID2SYM(rb_intern("threshold_symmetric"));
+   } else if(fn==FANN_SIGMOID) {
+     activation_function = ID2SYM(rb_intern("sigmoid"));
+   } else if(fn==FANN_SIGMOID_STEPWISE) {
+     activation_function = ID2SYM(rb_intern("sigmoid_stepwise"));
+   } else if(fn==FANN_SIGMOID_SYMMETRIC) {
+     activation_function = ID2SYM(rb_intern("sigmoid_symmetric"));
+   } else if(fn==FANN_SIGMOID_SYMMETRIC_STEPWISE) {
+     activation_function = ID2SYM(rb_intern("sigmoid_symmetric_stepwise"));
+   } else if(fn==FANN_GAUSSIAN) {
+     activation_function = ID2SYM(rb_intern("gaussian"));
+   } else if(fn==FANN_GAUSSIAN_SYMMETRIC) {
+     activation_function = ID2SYM(rb_intern("gaussian_symmetric"));
+   } else if(fn==FANN_GAUSSIAN_STEPWISE) {
+     activation_function = ID2SYM(rb_intern("gaussian_stepwise"));
+   } else if(fn==FANN_ELLIOT) {
+     activation_function = ID2SYM(rb_intern("elliot"));
+   } else if(fn==FANN_ELLIOT_SYMMETRIC) {
+     activation_function = ID2SYM(rb_intern("elliot_symmetric"));
+   } else if(fn==FANN_LINEAR_PIECE) {
+     activation_function = ID2SYM(rb_intern("linear_piece"));
+   } else if(fn==FANN_LINEAR_PIECE_SYMMETRIC) {
+     activation_function = ID2SYM(rb_intern("linear_piece_symmetric"));
+   } else if(fn==FANN_SIN_SYMMETRIC) {
+     activation_function = ID2SYM(rb_intern("sin_symmetric"));
+   } else if(fn==FANN_COS_SYMMETRIC) {
+     activation_function = ID2SYM(rb_intern("cos_symmetric"));
+   } else if(fn==FANN_SIN) {
+     activation_function = ID2SYM(rb_intern("sin"));
+   } else if(fn==FANN_COS) {
+     activation_function = ID2SYM(rb_intern("cos"));
+   } else {
+     rb_raise(rb_eRuntimeError, "Unrecognized activation function: [%d]", fn);
+   }
+   return activation_function;
+ }
+
+ // Unused for now:
+ static void fann_mark (struct fann* ann){}
+
+ // #define DEBUG 1
+
+ // Free memory associated with FANN:
+ static void fann_free (struct fann* ann)
+ {
+   fann_destroy(ann);
+   // printf("Destroyed FANN network [%d].\n", ann);
+ }
+
+ // Free memory associated with FANN training data:
+ static void fann_training_data_free (struct fann_train_data* train_data)
+ {
+   fann_destroy_train(train_data);
+   // printf("Destroyed training data [%d].\n", train_data);
+ }
+
+ // Create wrapper, but don't allocate anything...do that in
+ // initialize, so we can construct with args:
+ static VALUE fann_allocate (VALUE klass)
+ {
+   return Data_Wrap_Struct(klass, fann_mark, fann_free, 0);
+ }
+
+ // Create wrapper, but don't allocate anything...do that in
+ // initialize, so we can construct with args:
+ static VALUE fann_training_data_allocate (VALUE klass)
+ {
+   return Data_Wrap_Struct(klass, fann_mark, fann_training_data_free, 0);
+ }
+
+ // static VALUE invoke_training_callback(VALUE self)
+ // {
+ //   VALUE callback = rb_funcall(self, rb_intern("training_callback"), 0);
+ //   return callback;
+ // }
+
+ // static int FANN_API internal_callback(struct fann *ann, struct fann_train_data *train,
+ //   unsigned int max_epochs, unsigned int epochs_between_reports, float desired_error, unsigned int epochs)
+
+ static int FANN_API fann_training_callback(struct fann *ann, struct fann_train_data *train,
+                                            unsigned int max_epochs, unsigned int epochs_between_reports,
+                                            float desired_error, unsigned int epochs)
+ {
+   VALUE self = (VALUE)fann_get_user_data(ann);
+   VALUE args = rb_hash_new();
+
+   // Set attributes on hash & push on array:
+   VALUE max_epochs_sym = ID2SYM(rb_intern("max_epochs"));
+   VALUE epochs_between_reports_sym = ID2SYM(rb_intern("epochs_between_reports"));
+   VALUE desired_error_sym = ID2SYM(rb_intern("desired_error"));
+   VALUE epochs_sym = ID2SYM(rb_intern("epochs"));
+
+   rb_hash_aset(args, max_epochs_sym, INT2NUM(max_epochs));
+   rb_hash_aset(args, epochs_between_reports_sym, INT2NUM(epochs_between_reports));
+   rb_hash_aset(args, desired_error_sym, rb_float_new(desired_error));
+   rb_hash_aset(args, epochs_sym, INT2NUM(epochs));
+
+   VALUE callback = rb_funcall(self, rb_intern("training_callback"), 1, args);
+
+   if (TYPE(callback)!=T_FIXNUM)
+   {
+     rb_raise(rb_eRuntimeError, "Callback method must return an integer (-1 to stop training).");
+   }
+
+   int status = NUM2INT(callback);
+   if (status==-1)
+   {
+     printf("Callback method returned -1; training will stop.\n");
+   }
+
+   return status;
+ }
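+
+ /* Illustrative sketch of the Ruby-side hook invoked above. The args hash
+  * keys and the integer return contract (-1 stops training) come from the
+  * callback code; the subclass and body are hypothetical.
+  *
+  *   class MyFann < MooFann::Standard
+  *     def training_callback(args)
+  *       puts "epoch #{args[:epochs]} of #{args[:max_epochs]}"
+  *       0  # return -1 to stop training early
+  *     end
+  *   end
+  */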
+
+ /** call-seq: new(hash) -> new ruby-fann neural network object
+
+ Initialization routine for the standard, shortcut & filename forms of FANN:
+
+ Standard Initialization:
+   MooFann::Standard.new(:num_inputs=>1, :hidden_neurons=>[3, 4, 3, 4], :num_outputs=>1)
+
+ Shortcut Initialization (e.g., for use in cascade training):
+   MooFann::Shortcut.new(:num_inputs=>5, :num_outputs=>1)
+
+ File Initialization:
+   MooFann::Standard.new(:filename=>'xor_float.net')
+ */
+ static VALUE fann_initialize(VALUE self, VALUE hash)
+ {
+   // Get args:
+   VALUE filename = rb_hash_aref(hash, ID2SYM(rb_intern("filename")));
+   VALUE num_inputs = rb_hash_aref(hash, ID2SYM(rb_intern("num_inputs")));
+   VALUE num_outputs = rb_hash_aref(hash, ID2SYM(rb_intern("num_outputs")));
+   VALUE hidden_neurons = rb_hash_aref(hash, ID2SYM(rb_intern("hidden_neurons")));
+
+   struct fann* ann;
+   if (TYPE(filename)==T_STRING)
+   {
+     // Initialize with file:
+     ann = fann_create_from_file(StringValuePtr(filename));
+     // printf("Created MooFann::Standard [%d] from file [%s].\n", ann, StringValuePtr(filename));
+   }
+   else if(rb_obj_is_kind_of(self, m_rb_fann_shortcut_class))
+   {
+     // Initialize as shortcut, suitable for cascade training:
+     // ann = fann_create_shortcut_array(num_layers, layers);
+     Check_Type(num_inputs, T_FIXNUM);
+     Check_Type(num_outputs, T_FIXNUM);
+
+     ann = fann_create_shortcut(2, NUM2INT(num_inputs), NUM2INT(num_outputs));
+     // printf("Created MooFann::Shortcut [%d].\n", ann);
+   }
+   else
+   {
+     // Initialize as standard:
+     Check_Type(num_inputs, T_FIXNUM);
+     Check_Type(hidden_neurons, T_ARRAY);
+     Check_Type(num_outputs, T_FIXNUM);
+
+     // Initialize layers: input + hidden layers + output.
+     unsigned int num_layers = RARRAY_LEN(hidden_neurons) + 2;
+     unsigned int layers[num_layers];
+
+     // Input:
+     layers[0] = NUM2INT(num_inputs);
+     // Output:
+     layers[num_layers-1] = NUM2INT(num_outputs);
+     // Hidden:
+     int i;
+     for (i=1; i<=num_layers-2; i++) {
+       layers[i] = NUM2UINT(RARRAY_PTR(hidden_neurons)[i-1]);
+     }
+
+     ann = fann_create_standard_array(num_layers, layers);
+     // printf("Created MooFann::Standard [%d].\n", ann);
+   }
+
+   DATA_PTR(self) = ann;
+
+   // If the (sub)class defines a training_callback method, hook it up:
+   // int callback = rb_protect(invoke_training_callback, (self), &status);
+   // VALUE callback = rb_funcall(DATA_PTR(self), "training_callback", 0);
+   if(rb_respond_to(self, rb_intern("training_callback")))
+   {
+     fann_set_callback(ann, &fann_training_callback);
+     fann_set_user_data(ann, self);
+     // printf("found(%d).\n", ann->callback);
+   }
+
+   return (VALUE)ann;
+ }
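+
+ /* Worked example of the layer array assembled above (values illustrative):
+  * MooFann::Standard.new(:num_inputs=>2, :hidden_neurons=>[3, 4], :num_outputs=>1)
+  * gives num_layers = 4 and layers[] = {2, 3, 4, 1} before the call to
+  * fann_create_standard_array. */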
+
+ /** call-seq: new(hash) -> new ruby-fann training data object (MooFann::TrainData)
+
+ Initialize in one of the following forms:
+
+ # This is a flat file with training data as described in FANN docs.
+ MooFann::TrainData.new(:filename => 'path/to/training_file.train')
+ OR
+ # Train with inputs (array of arrays) & desired_outputs (array of arrays)
+ # inputs & desired_outputs should be of the same length
+ # All sub-arrays of inputs should be of the same length
+ # All sub-arrays of desired_outputs should be of the same length
+ # Sub-arrays of inputs & desired_outputs can be different sizes from one another
+ MooFann::TrainData.new(:inputs=>[[0.2, 0.3, 0.4], [0.8, 0.9, 0.7]], :desired_outputs=>[[3.14], [6.33]])
+ */
+ static VALUE fann_train_data_initialize(VALUE self, VALUE hash)
+ {
+   struct fann_train_data* train_data;
+   Check_Type(hash, T_HASH);
+
+   VALUE filename = rb_hash_aref(hash, ID2SYM(rb_intern("filename")));
+   VALUE inputs = rb_hash_aref(hash, ID2SYM(rb_intern("inputs")));
+   VALUE desired_outputs = rb_hash_aref(hash, ID2SYM(rb_intern("desired_outputs")));
+
+   if (TYPE(filename)==T_STRING)
+   {
+     train_data = fann_read_train_from_file(StringValuePtr(filename));
+     DATA_PTR(self) = train_data;
+   }
+   else if (TYPE(inputs)==T_ARRAY)
+   {
+     if (TYPE(desired_outputs)!=T_ARRAY)
+     {
+       rb_raise(rb_eRuntimeError, "[desired_outputs] must be present when [inputs] used.");
+     }
+
+     if (RARRAY_LEN(inputs) < 1)
+     {
+       rb_raise(rb_eRuntimeError, "[inputs/desired_outputs] must contain at least one value.");
+     }
+
+     // The data is here, start constructing:
+     if(RARRAY_LEN(inputs) != RARRAY_LEN(desired_outputs))
+     {
+       rb_raise(
+         rb_eRuntimeError,
+         "Number of inputs must match number of outputs: (%d != %d)",
+         (int)RARRAY_LEN(inputs),
+         (int)RARRAY_LEN(desired_outputs));
+     }
+
+     train_data = fann_create_train_from_rb_ary(inputs, desired_outputs);
+     DATA_PTR(self) = train_data;
+   }
+   else
+   {
+     rb_raise(rb_eRuntimeError, "Must construct with a filename(string) or inputs/desired_outputs(arrays). All args passed via hash with symbols as keys.");
+   }
+
+   return (VALUE)train_data;
+ }
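+
+ /* For the :filename form, FANN expects its plain-text training format:
+  * a header line "num_pairs num_inputs num_outputs", then alternating
+  * input and output lines. A hypothetical xor.train:
+  *
+  *   4 2 1
+  *   0 0
+  *   0
+  *   0 1
+  *   1
+  *   1 0
+  *   1
+  *   1 1
+  *   0
+  */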
+
+ /** call-seq: save(filename)
+
+ Save training data to the given filename.
+ */
+ static VALUE training_save(VALUE self, VALUE filename)
+ {
+   Check_Type(filename, T_STRING);
+   struct fann_train_data* t;
+   Data_Get_Struct(self, struct fann_train_data, t);
+   fann_save_train(t, StringValuePtr(filename));
+   return self;
+ }
+
+ /** Shuffles training data, randomizing the order.
+ This is recommended for incremental training, while it has no influence on batch training. */
+ static VALUE shuffle(VALUE self)
+ {
+   struct fann_train_data* t;
+   Data_Get_Struct(self, struct fann_train_data, t);
+   fann_shuffle_train_data(t);
+   return self;
+ }
+
+ /** Length of training data */
+ static VALUE length_train_data(VALUE self)
+ {
+   struct fann_train_data* t;
+   Data_Get_Struct(self, struct fann_train_data, t);
+   return UINT2NUM(fann_length_train_data(t));
+ }
+
+ /** call-seq: set_activation_function(activation_func, layer, neuron)
+
+ Set the activation function for neuron number *neuron* in layer number *layer*,
+ counting the input layer as layer 0. activation_func must be one of the following symbols:
+ :linear, :threshold, :threshold_symmetric, :sigmoid, :sigmoid_stepwise, :sigmoid_symmetric,
+ :sigmoid_symmetric_stepwise, :gaussian, :gaussian_symmetric, :gaussian_stepwise, :elliot,
+ :elliot_symmetric, :linear_piece, :linear_piece_symmetric, :sin_symmetric, :cos_symmetric,
+ :sin, :cos */
+ static VALUE set_activation_function(VALUE self, VALUE activation_func, VALUE layer, VALUE neuron)
+ {
+   Check_Type(activation_func, T_SYMBOL);
+   Check_Type(layer, T_FIXNUM);
+   Check_Type(neuron, T_FIXNUM);
+
+   struct fann* f;
+   Data_Get_Struct(self, struct fann, f);
+   fann_set_activation_function(f, sym_to_activation_function(activation_func), NUM2INT(layer), NUM2INT(neuron));
+   return self;
+ }
+
+ /** call-seq: set_activation_function_hidden(activation_func)
+
+ Set the activation function for all of the hidden layers. activation_func must be one of the following symbols:
+ :linear, :threshold, :threshold_symmetric, :sigmoid, :sigmoid_stepwise, :sigmoid_symmetric,
+ :sigmoid_symmetric_stepwise, :gaussian, :gaussian_symmetric, :gaussian_stepwise, :elliot,
+ :elliot_symmetric, :linear_piece, :linear_piece_symmetric, :sin_symmetric, :cos_symmetric,
+ :sin, :cos */
+ static VALUE set_activation_function_hidden(VALUE self, VALUE activation_func)
+ {
+   Check_Type(activation_func, T_SYMBOL);
+   struct fann* f;
+   Data_Get_Struct(self, struct fann, f);
+   fann_set_activation_function_hidden(f, sym_to_activation_function(activation_func));
+   return self;
+ }
+
+ /** call-seq: set_activation_function_layer(activation_func, layer)
+
+ Set the activation function for all the neurons in the layer number *layer*,
+ counting the input layer as layer 0. activation_func must be one of the following symbols:
+ :linear, :threshold, :threshold_symmetric, :sigmoid, :sigmoid_stepwise, :sigmoid_symmetric,
+ :sigmoid_symmetric_stepwise, :gaussian, :gaussian_symmetric, :gaussian_stepwise, :elliot,
+ :elliot_symmetric, :linear_piece, :linear_piece_symmetric, :sin_symmetric, :cos_symmetric,
+ :sin, :cos
+
+ It is not possible to set activation functions for the neurons in the input layer.
+ */
+ static VALUE set_activation_function_layer(VALUE self, VALUE activation_func, VALUE layer)
+ {
+   Check_Type(activation_func, T_SYMBOL);
+   Check_Type(layer, T_FIXNUM);
+   struct fann* f;
+   Data_Get_Struct(self, struct fann, f);
+   fann_set_activation_function_layer(f, sym_to_activation_function(activation_func), NUM2INT(layer));
+   return self;
+ }
+
+ /** call-seq: get_activation_function(layer, neuron) -> return value
+
+ Get the activation function for neuron number *neuron* in layer number *layer*,
+ counting the input layer as layer 0.
+
+ It is not possible to get activation functions for the neurons in the input layer.
+ */
+ static VALUE get_activation_function(VALUE self, VALUE layer, VALUE neuron)
+ {
+   Check_Type(layer, T_FIXNUM);
+   Check_Type(neuron, T_FIXNUM);
+   struct fann* f;
+   Data_Get_Struct(self, struct fann, f);
+   enum fann_activationfunc_enum val = fann_get_activation_function(f, NUM2INT(layer), NUM2INT(neuron));
+   return activation_function_to_sym(val);
+ }
+
+ /** call-seq: set_activation_function_output(activation_func)
+
+ Set the activation function for the output layer. activation_func must be one of the following symbols:
+ :linear, :threshold, :threshold_symmetric, :sigmoid, :sigmoid_stepwise, :sigmoid_symmetric,
+ :sigmoid_symmetric_stepwise, :gaussian, :gaussian_symmetric, :gaussian_stepwise, :elliot,
+ :elliot_symmetric, :linear_piece, :linear_piece_symmetric, :sin_symmetric, :cos_symmetric,
+ :sin, :cos */
+ static VALUE set_activation_function_output(VALUE self, VALUE activation_func)
+ {
+   Check_Type(activation_func, T_SYMBOL);
+   struct fann* f;
+   Data_Get_Struct(self, struct fann, f);
+   fann_set_activation_function_output(f, sym_to_activation_function(activation_func));
+   return self;
+ }
+
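+ /* Illustrative usage of the activation-function setters above (network
+  * shape and symbols are example values):
+  *
+  *   ann = MooFann::Standard.new(:num_inputs=>2, :hidden_neurons=>[3], :num_outputs=>1)
+  *   ann.set_activation_function_hidden(:sigmoid_symmetric)
+  *   ann.set_activation_function_output(:linear)
+  *   ann.set_activation_function(:gaussian, 1, 0)   # layer 1, neuron 0
+  */
+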
+ /** call-seq: get_activation_steepness(layer, neuron) -> return value
+
+ Get the activation steepness for neuron number *neuron* in layer number *layer*, counting the input layer as layer 0.
+ */
+ static VALUE get_activation_steepness(VALUE self, VALUE layer, VALUE neuron)
+ {
+   Check_Type(layer, T_FIXNUM);
+   Check_Type(neuron, T_FIXNUM);
+   struct fann* f;
+   Data_Get_Struct(self, struct fann, f);
+   fann_type val = fann_get_activation_steepness(f, NUM2INT(layer), NUM2INT(neuron));
+   return rb_float_new(val);
+ }
+
+ /** call-seq: set_activation_steepness(steepness, layer, neuron)
+
+ Set the activation steepness for neuron number *neuron* in layer number *layer*,
+ counting the input layer as layer 0. */
+ static VALUE set_activation_steepness(VALUE self, VALUE steepness, VALUE layer, VALUE neuron)
+ {
+   Check_Type(steepness, T_FLOAT);
+   Check_Type(layer, T_FIXNUM);
+   Check_Type(neuron, T_FIXNUM);
+
+   struct fann* f;
+   Data_Get_Struct(self, struct fann, f);
+   fann_set_activation_steepness(f, NUM2DBL(steepness), NUM2INT(layer), NUM2INT(neuron));
+   return self;
+ }
+
+ /** call-seq: set_activation_steepness_hidden(steepness)
+
+ Set the activation steepness in all of the hidden layers. */
+ static VALUE set_activation_steepness_hidden(VALUE self, VALUE steepness)
+ {
+   SET_FANN_FLT(steepness, fann_set_activation_steepness_hidden);
+ }
+
+ /** call-seq: set_activation_steepness_layer(steepness, layer)
+
+ Set the activation steepness of all the neurons in layer number *layer*,
+ counting the input layer as layer 0. */
+ static VALUE set_activation_steepness_layer(VALUE self, VALUE steepness, VALUE layer)
+ {
+   Check_Type(steepness, T_FLOAT);
+   Check_Type(layer, T_FIXNUM);
+
+   struct fann* f;
+   Data_Get_Struct(self, struct fann, f);
+   fann_set_activation_steepness_layer(f, NUM2DBL(steepness), NUM2INT(layer));
+   return self;
+ }
+
+ /** call-seq: set_activation_steepness_output(steepness)
+
+ Set the activation steepness in the output layer. */
+ static VALUE set_activation_steepness_output(VALUE self, VALUE steepness)
+ {
+   SET_FANN_FLT(steepness, fann_set_activation_steepness_output);
+ }
+
+ /** Returns the bit fail limit used during training. */
+ static VALUE get_bit_fail_limit(VALUE self)
+ {
+   RETURN_FANN_DBL(fann_get_bit_fail_limit);
+ }
+
+ /** call-seq: set_bit_fail_limit(bit_fail_limit)
+
+ Sets the bit fail limit used during training. */
+ static VALUE set_bit_fail_limit(VALUE self, VALUE bit_fail_limit)
+ {
+   SET_FANN_FLT(bit_fail_limit, fann_set_bit_fail_limit);
+ }
+
+ /** The decay is a small negative-valued number which is the factor that the weights
+ should become smaller by in each iteration during quickprop training. This is used
+ to make sure that the weights do not become too high during training. */
+ static VALUE get_quickprop_decay(VALUE self)
+ {
+   RETURN_FANN_FLT(fann_get_quickprop_decay);
+ }
+
+ /** call-seq: set_quickprop_decay(quickprop_decay)
+
+ Sets the quickprop decay factor. */
+ static VALUE set_quickprop_decay(VALUE self, VALUE quickprop_decay)
+ {
+   SET_FANN_FLT(quickprop_decay, fann_set_quickprop_decay);
+ }
+
+ /** The mu factor is used to increase and decrease the step-size during quickprop training.
+ The mu factor should always be above 1, since it would otherwise decrease the step-size
+ when it was supposed to increase it. */
+ static VALUE get_quickprop_mu(VALUE self)
+ {
+   RETURN_FANN_FLT(fann_get_quickprop_mu);
+ }
+
+ /** call-seq: set_quickprop_mu(quickprop_mu)
+
+ Sets the quickprop mu factor. */
+ static VALUE set_quickprop_mu(VALUE self, VALUE quickprop_mu)
+ {
+   SET_FANN_FLT(quickprop_mu, fann_set_quickprop_mu);
+ }
+
+ /** The increase factor is a value larger than 1, which is used to
+ increase the step-size during RPROP training. */
+ static VALUE get_rprop_increase_factor(VALUE self)
+ {
+   RETURN_FANN_FLT(fann_get_rprop_increase_factor);
+ }
+
+ /** call-seq: set_rprop_increase_factor(rprop_increase_factor)
+
+ The increase factor used during RPROP training. */
+ static VALUE set_rprop_increase_factor(VALUE self, VALUE rprop_increase_factor)
+ {
+   SET_FANN_FLT(rprop_increase_factor, fann_set_rprop_increase_factor);
+ }
+
+ /** The decrease factor is a value smaller than 1, which is used to decrease the step-size during RPROP training. */
+ static VALUE get_rprop_decrease_factor(VALUE self)
+ {
+   RETURN_FANN_FLT(fann_get_rprop_decrease_factor);
+ }
+
+ /** call-seq: set_rprop_decrease_factor(rprop_decrease_factor)
+
+ The decrease factor is a value smaller than 1, which is used to decrease the step-size during RPROP training. */
+ static VALUE set_rprop_decrease_factor(VALUE self, VALUE rprop_decrease_factor)
+ {
+   SET_FANN_FLT(rprop_decrease_factor, fann_set_rprop_decrease_factor);
+ }
+
+ /** The minimum step-size is a small positive number determining how small the step-size may become. */
+ static VALUE get_rprop_delta_min(VALUE self)
+ {
+   RETURN_FANN_FLT(fann_get_rprop_delta_min);
+ }
+
+ /** call-seq: set_rprop_delta_min(rprop_delta_min)
+
+ The minimum step-size is a small positive number determining how small the step-size may become. */
+ static VALUE set_rprop_delta_min(VALUE self, VALUE rprop_delta_min)
+ {
+   SET_FANN_FLT(rprop_delta_min, fann_set_rprop_delta_min);
+ }
+
+ /** The maximum step-size is a positive number determining how large the step-size may become. */
+ static VALUE get_rprop_delta_max(VALUE self)
+ {
+   RETURN_FANN_FLT(fann_get_rprop_delta_max);
+ }
+
+ /** call-seq: set_rprop_delta_max(rprop_delta_max)
+
+ The maximum step-size is a positive number determining how large the step-size may become. */
+ static VALUE set_rprop_delta_max(VALUE self, VALUE rprop_delta_max)
+ {
+   SET_FANN_FLT(rprop_delta_max, fann_set_rprop_delta_max);
+ }
+
+ /** The initial step-size is a positive number determining the initial step size. */
+ static VALUE get_rprop_delta_zero(VALUE self)
+ {
+   RETURN_FANN_FLT(fann_get_rprop_delta_zero);
+ }
+
+ /** call-seq: set_rprop_delta_zero(rprop_delta_zero)
+
+ The initial step-size is a positive number determining the initial step size. */
+ static VALUE set_rprop_delta_zero(VALUE self, VALUE rprop_delta_zero)
+ {
+   SET_FANN_FLT(rprop_delta_zero, fann_set_rprop_delta_zero);
+ }
+
+ /** Return array of bias(es) */
+ static VALUE get_bias_array(VALUE self)
+ {
+   struct fann* f;
+   unsigned int num_layers;
+   Data_Get_Struct(self, struct fann, f);
+   num_layers = fann_get_num_layers(f);
+   unsigned int layers[num_layers];
+   fann_get_bias_array(f, layers);
+
+   // Create ruby array & set outputs:
+   VALUE arr;
+   arr = rb_ary_new();
+   int i;
+   for (i=0; i<num_layers; i++)
+   {
+     rb_ary_push(arr, INT2NUM(layers[i]));
+   }
+
+   return arr;
+ }
+
+ /** The number of fail bits, i.e. the number of output neurons that differ by more
+ than the bit fail limit (see <fann_get_bit_fail_limit>, <fann_set_bit_fail_limit>).
+ The bits are counted over all of the training data, so this number can be higher than
+ the number of training patterns. */
+ static VALUE get_bit_fail(VALUE self)
+ {
+   RETURN_FANN_INT(fann_get_bit_fail);
+ }
+
+ /** Get the connection rate used when the network was created. */
+ static VALUE get_connection_rate(VALUE self)
+ {
+   RETURN_FANN_INT(fann_get_connection_rate);
+ }
+
+ /** call-seq: get_neurons(layer) -> return value
+
+ Return array<hash> where each array element is a hash
+ representing a neuron. It contains the following keys:
+ :activation_function, symbol -- the activation function
+ :activation_steepness, float -- the steepness of the activation function
+ :sum, float -- the sum of the inputs multiplied with the weights
+ :value, float -- the value of the activation function applied to the sum
+ :connections, array<int> -- indices of connected neurons (inputs)
+
+ This could be done more elegantly (e.g., defining more ruby ext classes).
+ This method does not directly correlate to anything in FANN, and accesses
+ structs that are not guaranteed to remain stable.
+ */
+ static VALUE get_neurons(VALUE self, VALUE layer)
+ {
+   struct fann_layer *layer_it;
+   struct fann_neuron *neuron_it;
+
+   struct fann* f;
+   unsigned int i;
+   Data_Get_Struct(self, struct fann, f);
+
+   VALUE neuron_array = rb_ary_new();
+   VALUE activation_function_sym = ID2SYM(rb_intern("activation_function"));
+   VALUE activation_steepness_sym = ID2SYM(rb_intern("activation_steepness"));
+   VALUE layer_sym = ID2SYM(rb_intern("layer"));
+   VALUE sum_sym = ID2SYM(rb_intern("sum"));
+   VALUE value_sym = ID2SYM(rb_intern("value"));
+   VALUE connections_sym = ID2SYM(rb_intern("connections"));
+   unsigned int layer_num = 0;
+
+   int nuke_bias_neuron = (fann_get_network_type(f)==FANN_NETTYPE_LAYER);
+   for(layer_it = f->first_layer; layer_it != f->last_layer; layer_it++)
+   {
+     for(neuron_it = layer_it->first_neuron; neuron_it != layer_it->last_neuron; neuron_it++)
+     {
+       if (nuke_bias_neuron && (neuron_it==(layer_it->last_neuron)-1)) continue;
+       // Create array of connection indices:
+       VALUE connection_array = rb_ary_new();
+       for (i = neuron_it->first_con; i < neuron_it->last_con; i++) {
+         rb_ary_push(connection_array, INT2NUM(f->connections[i] - f->first_layer->first_neuron));
+       }
+
+       VALUE neuron = rb_hash_new();
+
+       // Set attributes on hash & push on array:
+       rb_hash_aset(neuron, activation_function_sym, activation_function_to_sym(neuron_it->activation_function));
+       rb_hash_aset(neuron, activation_steepness_sym, rb_float_new(neuron_it->activation_steepness));
+       rb_hash_aset(neuron, layer_sym, INT2NUM(layer_num));
+       rb_hash_aset(neuron, sum_sym, rb_float_new(neuron_it->sum));
+       rb_hash_aset(neuron, value_sym, rb_float_new(neuron_it->value));
+       rb_hash_aset(neuron, connections_sym, connection_array);
+
+       rb_ary_push(neuron_array, neuron);
+     }
+     ++layer_num;
+   }
+
+   // switch (fann_get_network_type(ann)) {
+   //   case FANN_NETTYPE_LAYER: {
+   //     /* Report one bias in each layer except the last */
+   //     if (layer_it != ann->last_layer-1)
+   //       *bias = 1;
+   //     else
+   //       *bias = 0;
+   //     break;
+   //   }
+   //   case FANN_NETTYPE_SHORTCUT: {
+
+   return neuron_array;
+ }
+
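+ /* Example of a single element returned by get_neurons (all values
+  * hypothetical; the keys are the ones set above):
+  *
+  *   { :layer => 1, :activation_function => :sigmoid,
+  *     :activation_steepness => 0.5, :sum => 1.2, :value => 0.77,
+  *     :connections => [0, 1, 2] }
+  */
+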
+ /** Get list of layers in array format where each element contains number of neurons in that layer */
+ static VALUE get_layer_array(VALUE self)
+ {
+   struct fann* f;
+   unsigned int num_layers;
+   Data_Get_Struct(self, struct fann, f);
+   num_layers = fann_get_num_layers(f);
+   unsigned int layers[num_layers];
+   fann_get_layer_array(f, layers);
+
+   // Create ruby array & set outputs:
+   VALUE arr;
+   arr = rb_ary_new();
+   int i;
+   for (i=0; i<num_layers; i++)
+   {
+     rb_ary_push(arr, INT2NUM(layers[i]));
+   }
+
+   return arr;
+ }
+
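+ /* Example: a net built with :num_inputs=>2, :hidden_neurons=>[3],
+  * :num_outputs=>1 yields get_layer_array => [2, 3, 1]. */
+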
+ /** Reads the mean square error from the network. */
+ static VALUE get_MSE(VALUE self)
+ {
+   RETURN_FANN_DBL(fann_get_MSE);
+ }
+
+ /** Resets the mean square error from the network.
+ This function also resets the number of bits that fail. */
+ static VALUE reset_MSE(VALUE self)
+ {
+   struct fann* f;
+   Data_Get_Struct(self, struct fann, f);
+   fann_reset_MSE(f);
+   return self;
+ }
+
+ /** Get the type of network. Returns as ruby symbol (one of :shortcut, :layer) */
+ static VALUE get_network_type(VALUE self)
+ {
+   struct fann* f;
+   enum fann_nettype_enum net_type;
+   VALUE ret_val;
+   Data_Get_Struct(self, struct fann, f);
+
+   net_type = fann_get_network_type(f);
+
+   if(net_type==FANN_NETTYPE_LAYER)
+   {
+     ret_val = ID2SYM(rb_intern("layer"));
+   }
+   else if(net_type==FANN_NETTYPE_SHORTCUT)
+   {
+     ret_val = ID2SYM(rb_intern("shortcut"));
+   }
+   return ret_val;
+ }
+
+ /** Get the number of input neurons. */
+ static VALUE get_num_input(VALUE self)
+ {
+   RETURN_FANN_INT(fann_get_num_input);
+ }
+
+ /** Get the number of layers in the network. */
+ static VALUE get_num_layers(VALUE self)
+ {
+   RETURN_FANN_INT(fann_get_num_layers);
+ }
+
+ /** Get the number of output neurons. */
+ static VALUE get_num_output(VALUE self)
+ {
+   RETURN_FANN_INT(fann_get_num_output);
+ }
+
+ /** Get the total number of connections in the entire network. */
+ static VALUE get_total_connections(VALUE self)
+ {
+   RETURN_FANN_INT(fann_get_total_connections);
+ }
+
+ /** Get the total number of neurons in the entire network. */
+ static VALUE get_total_neurons(VALUE self)
+ {
+   RETURN_FANN_INT(fann_get_total_neurons);
+ }
+
+ /** call-seq: set_train_error_function(train_error_function)
+
+ Sets the error function used during training. One of the following symbols:
+ :linear, :tanh */
+ static VALUE set_train_error_function(VALUE self, VALUE train_error_function)
+ {
+   Check_Type(train_error_function, T_SYMBOL);
+
+   ID id = SYM2ID(train_error_function);
+   enum fann_errorfunc_enum fann_train_error_function;
+
+   if(id==rb_intern("linear")) {
+     fann_train_error_function = FANN_ERRORFUNC_LINEAR;
+   } else if(id==rb_intern("tanh")) {
+     fann_train_error_function = FANN_ERRORFUNC_TANH;
+   } else {
+     rb_raise(rb_eRuntimeError, "Unrecognized train error function: [%s]", rb_id2name(id));
+   }
+
+   struct fann* f;
+   Data_Get_Struct(self, struct fann, f);
+   fann_set_train_error_function(f, fann_train_error_function);
+   return self;
+ }
+
+ /** Returns the error function used during training. One of the following symbols:
+ :linear, :tanh */
+ static VALUE get_train_error_function(VALUE self)
+ {
+   struct fann* f;
+   enum fann_errorfunc_enum train_error;
+   VALUE ret_val;
+   Data_Get_Struct(self, struct fann, f);
+
+   train_error = fann_get_train_error_function(f);
+
+   if(train_error==FANN_ERRORFUNC_LINEAR)
+   {
+     ret_val = ID2SYM(rb_intern("linear"));
+   }
+   else if(train_error==FANN_ERRORFUNC_TANH)
+   {
+     ret_val = ID2SYM(rb_intern("tanh"));
+   }
+   return ret_val;
+ }
+
+ /** call-seq: set_training_algorithm(training_algorithm)
+
+ Set the training algorithm. One of the following symbols:
+ :incremental, :batch, :rprop, :quickprop */
+ static VALUE set_training_algorithm(VALUE self, VALUE training_algorithm)
+ {
+   Check_Type(training_algorithm, T_SYMBOL);
+
+   ID id = SYM2ID(training_algorithm);
+   enum fann_train_enum fann_train_algorithm;
+
+   if(id==rb_intern("incremental")) {
+     fann_train_algorithm = FANN_TRAIN_INCREMENTAL;
+   } else if(id==rb_intern("batch")) {
+     fann_train_algorithm = FANN_TRAIN_BATCH;
+   } else if(id==rb_intern("rprop")) {
+     fann_train_algorithm = FANN_TRAIN_RPROP;
+   } else if(id==rb_intern("quickprop")) {
+     fann_train_algorithm = FANN_TRAIN_QUICKPROP;
+   } else {
+     rb_raise(rb_eRuntimeError, "Unrecognized training algorithm: [%s]", rb_id2name(id));
+   }
+
+   struct fann* f;
+   Data_Get_Struct(self, struct fann, f);
+   fann_set_training_algorithm(f, fann_train_algorithm);
+   return self;
+ }
+
+ /** Returns the training algorithm. One of the following symbols:
+ :incremental, :batch, :rprop, :quickprop */
+ static VALUE get_training_algorithm(VALUE self)
+ {
+   struct fann* f;
+   enum fann_train_enum fann_train_algorithm;
+   VALUE ret_val;
+   Data_Get_Struct(self, struct fann, f);
+
+   fann_train_algorithm = fann_get_training_algorithm(f);
+
+   if(fann_train_algorithm==FANN_TRAIN_INCREMENTAL) {
+     ret_val = ID2SYM(rb_intern("incremental"));
+   } else if(fann_train_algorithm==FANN_TRAIN_BATCH) {
+     ret_val = ID2SYM(rb_intern("batch"));
+   } else if(fann_train_algorithm==FANN_TRAIN_RPROP) {
+     ret_val = ID2SYM(rb_intern("rprop"));
+   } else if(fann_train_algorithm==FANN_TRAIN_QUICKPROP) {
+     ret_val = ID2SYM(rb_intern("quickprop"));
+   }
+   return ret_val;
+ }
+
+ /** call-seq: set_train_stop_function(train_stop_function) -> return value
+
+ Set the training stop function. One of the following symbols:
+ :mse, :bit */
+ static VALUE set_train_stop_function(VALUE self, VALUE train_stop_function)
+ {
+   Check_Type(train_stop_function, T_SYMBOL);
+   ID id = SYM2ID(train_stop_function);
+   enum fann_stopfunc_enum fann_train_stop_function;
+
+   if(id==rb_intern("mse")) {
+     fann_train_stop_function = FANN_STOPFUNC_MSE;
+   } else if(id==rb_intern("bit")) {
+     fann_train_stop_function = FANN_STOPFUNC_BIT;
+   } else {
+     rb_raise(rb_eRuntimeError, "Unrecognized stop function: [%s]", rb_id2name(id));
+   }
+
+   struct fann* f;
+   Data_Get_Struct(self, struct fann, f);
+   fann_set_train_stop_function(f, fann_train_stop_function);
+   return self;
+ }
+
+ /** Returns the training stop function. One of the following symbols:
+ :mse, :bit */
+ static VALUE get_train_stop_function(VALUE self)
+ {
+   struct fann* f;
+   enum fann_stopfunc_enum train_stop;
+   VALUE ret_val;
+   Data_Get_Struct(self, struct fann, f);
+
+   train_stop = fann_get_train_stop_function(f);
+
+   if(train_stop==FANN_STOPFUNC_MSE)
+   {
+     ret_val = ID2SYM(rb_intern("mse"));
+   }
+   else if(train_stop==FANN_STOPFUNC_BIT)
+   {
+     ret_val = ID2SYM(rb_intern("bit"));
+   }
+   return ret_val;
+ }
+
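+ /* Illustrative training configuration using the setters above
+  * (parameter values are examples):
+  *
+  *   ann.set_training_algorithm(:rprop)
+  *   ann.set_train_error_function(:tanh)
+  *   ann.set_train_stop_function(:bit)
+  *   ann.set_bit_fail_limit(0.05)
+  */
+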
+ /** Will print the connections of the ann in a compact matrix,
+ for easy viewing of the internals of the ann. */
+ static VALUE print_connections(VALUE self)
+ {
+   struct fann* f;
+   Data_Get_Struct(self, struct fann, f);
+   fann_print_connections(f);
+   return self;
+ }
+
+ /** Print current NN parameters to stdout */
+ static VALUE print_parameters(VALUE self)
+ {
+   struct fann* f;
+   Data_Get_Struct(self, struct fann, f);
+   fann_print_parameters(f);
+   return Qnil;
+ }
+
+ /** call-seq: randomize_weights(min_weight, max_weight)
+
+ Give each connection a random weight between *min_weight* and *max_weight* */
+ static VALUE randomize_weights(VALUE self, VALUE min_weight, VALUE max_weight)
+ {
+   Check_Type(min_weight, T_FLOAT);
+   Check_Type(max_weight, T_FLOAT);
+   struct fann* f;
+   Data_Get_Struct(self, struct fann, f);
+   fann_randomize_weights(f, NUM2DBL(min_weight), NUM2DBL(max_weight));
+   return self;
+ }
+
+ /** call-seq: run(inputs) -> return value
+
+ Run neural net on array<Float> of inputs with current parameters.
+ Returns array<Float> as output */
+ static VALUE run (VALUE self, VALUE inputs)
+ {
+   Check_Type(inputs, T_ARRAY);
+
+   struct fann* f;
+   int i;
+   fann_type* outputs;
+
+   // Convert inputs to type needed for NN:
+   unsigned int len = RARRAY_LEN(inputs);
+   fann_type fann_inputs[len];
+   for (i=0; i<len; i++)
+   {
+     fann_inputs[i] = NUM2DBL(RARRAY_PTR(inputs)[i]);
+   }
+
+   // Obtain NN & run method:
+   Data_Get_Struct(self, struct fann, f);
+   outputs = fann_run(f, fann_inputs);
+
+   // Create ruby array & set outputs:
+   VALUE arr;
+   arr = rb_ary_new();
+   unsigned int output_len = fann_get_num_output(f);
+   for (i=0; i<output_len; i++)
+   {
+     rb_ary_push(arr, rb_float_new(outputs[i]));
+   }
+
+   return arr;
+ }
+
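+ /* Example (input values illustrative; the output shown is hypothetical):
+  *
+  *   outputs = ann.run([0.3, 0.7])   # => e.g. [0.81] for a 1-output net
+  */
+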
+ /** call-seq: init_weights(train_data) -> return value
+
+ Initialize the weights using Widrow + Nguyen's algorithm. */
+ static VALUE init_weights(VALUE self, VALUE train_data)
+ {
+   Check_Type(train_data, T_DATA);
+
+   struct fann* f;
+   struct fann_train_data* t;
+   Data_Get_Struct(self, struct fann, f);
+   Data_Get_Struct(train_data, struct fann_train_data, t);
+
+   fann_init_weights(f, t);
+   return self;
+ }
+
+ /** call-seq: train(input, expected_output)
+
+ Train with a single input-output pair.
+ input - The inputs given to the network
+ expected_output - The outputs expected. */
+ static VALUE train(VALUE self, VALUE input, VALUE expected_output)
+ {
+   Check_Type(input, T_ARRAY);
+   Check_Type(expected_output, T_ARRAY);
+
+   struct fann* f;
+   Data_Get_Struct(self, struct fann, f);
+
+   unsigned int num_input = RARRAY_LEN(input);
+   unsigned int num_output = RARRAY_LEN(expected_output);
+
+   fann_type data_input[num_input], data_output[num_output];
+
+   int i;
+
+   for (i = 0; i < num_input; i++) {
+     data_input[i] = NUM2DBL(RARRAY_PTR(input)[i]);
+   }
+
+   for (i = 0; i < num_output; i++) {
+     data_output[i] = NUM2DBL(RARRAY_PTR(expected_output)[i]);
+   }
+
+   fann_train(f, data_input, data_output);
+
+   return rb_int_new(0);
+ }
+
+ /** call-seq: train_on_data(train_data, max_epochs, epochs_between_reports, desired_error)
+
+ Train with training data created with MooFann::TrainData.new
+ max_epochs - The maximum number of epochs the training should continue
+ epochs_between_reports - The number of epochs between printing a status report to stdout.
+ desired_error - The desired <get_MSE> or <get_bit_fail>, depending on which stop function
+ is chosen by <set_train_stop_function>. */
+ static VALUE train_on_data(VALUE self, VALUE train_data, VALUE max_epochs, VALUE epochs_between_reports, VALUE desired_error)
+ {
+   Check_Type(train_data, T_DATA);
+   Check_Type(max_epochs, T_FIXNUM);
+   Check_Type(epochs_between_reports, T_FIXNUM);
+   Check_Type(desired_error, T_FLOAT);
+
+   struct fann* f;
+   struct fann_train_data* t;
+   Data_Get_Struct(self, struct fann, f);
+   Data_Get_Struct(train_data, struct fann_train_data, t);
+
+   unsigned int fann_max_epochs = NUM2INT(max_epochs);
+   unsigned int fann_epochs_between_reports = NUM2INT(epochs_between_reports);
+   float fann_desired_error = NUM2DBL(desired_error);
+   fann_train_on_data(f, t, fann_max_epochs, fann_epochs_between_reports, fann_desired_error);
+   return rb_int_new(0);
+ }
+
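+ /* A minimal end-to-end sketch of the loop this wraps (file name and
+  * parameters are examples):
+  *
+  *   train = MooFann::TrainData.new(:filename => 'xor.train')
+  *   ann = MooFann::Standard.new(:num_inputs=>2, :hidden_neurons=>[3], :num_outputs=>1)
+  *   ann.train_on_data(train, 1000, 100, 0.001)
+  *   ann.run([1.0, 0.0])
+  */
+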
+ /** call-seq: train_epoch(train_data) -> return value
+
+ Train one epoch with a set of training data, created with MooFann::TrainData.new */
+ static VALUE train_epoch(VALUE self, VALUE train_data)
+ {
+   Check_Type(train_data, T_DATA);
+   struct fann* f;
+   struct fann_train_data* t;
+   Data_Get_Struct(self, struct fann, f);
+   Data_Get_Struct(train_data, struct fann_train_data, t);
+   return rb_float_new(fann_train_epoch(f, t));
+ }
+
+ /** call-seq: test_data(train_data) -> return value
+
+ Test a set of training data and calculate the MSE for it. */
+ static VALUE test_data(VALUE self, VALUE train_data)
+ {
+   Check_Type(train_data, T_DATA);
+   struct fann* f;
+   struct fann_train_data* t;
+   Data_Get_Struct(self, struct fann, f);
+   Data_Get_Struct(train_data, struct fann_train_data, t);
+   return rb_float_new(fann_test_data(f, t));
+ }
+
+ // Returns the position of the decimal point in the ann.
+ // Only available in fixed-point mode, which we don't need:
+ // static VALUE get_decimal_point(VALUE self)
+ // {
+ //   struct fann* f;
+ //   Data_Get_Struct(self, struct fann, f);
+ //   return INT2NUM(fann_get_decimal_point(f));
+ // }
+
+ // Returns the multiplier that fixed-point data is multiplied with.
+ // Only available in fixed-point mode, which we don't need:
+ // static VALUE get_multiplier(VALUE self)
+ // {
+ //   struct fann* f;
+ //   Data_Get_Struct(self, struct fann, f);
+ //   return INT2NUM(fann_get_multiplier(f));
+ // }
+
+ /** call-seq: cascadetrain_on_data(train_data, max_neurons, neurons_between_reports, desired_error)
+
+ Perform cascade training with training data created with MooFann::TrainData.new
+ max_neurons - The maximum number of neurons in the trained network
+ neurons_between_reports - The number of neurons between printing a status report to stdout.
+ desired_error - The desired <get_MSE> or <get_bit_fail>, depending on which stop function
+ is chosen by <set_train_stop_function>. */
+ static VALUE cascadetrain_on_data(VALUE self, VALUE train_data, VALUE max_neurons, VALUE neurons_between_reports, VALUE desired_error)
+ {
+   Check_Type(train_data, T_DATA);
+   Check_Type(max_neurons, T_FIXNUM);
+   Check_Type(neurons_between_reports, T_FIXNUM);
+   Check_Type(desired_error, T_FLOAT);
+
+   struct fann* f;
+   struct fann_train_data* t;
+   Data_Get_Struct(self, struct fann, f);
+   Data_Get_Struct(train_data, struct fann_train_data, t);
+
+   unsigned int fann_max_neurons = NUM2INT(max_neurons);
+   unsigned int fann_neurons_between_reports = NUM2INT(neurons_between_reports);
+   float fann_desired_error = NUM2DBL(desired_error);
+
+   fann_cascadetrain_on_data(f, t, fann_max_neurons, fann_neurons_between_reports, fann_desired_error);
+   return self;
+ }
+
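+ /* Cascade training operates on a Shortcut network; a hedged sketch
+  * (data and parameters are examples):
+  *
+  *   ann = MooFann::Shortcut.new(:num_inputs=>2, :num_outputs=>1)
+  *   train = MooFann::TrainData.new(:inputs=>[[0.0, 0.0], [1.0, 1.0]],
+  *                                  :desired_outputs=>[[0.0], [1.0]])
+  *   ann.cascadetrain_on_data(train, 30, 5, 0.001)   # max_neurons, reports, error
+  */
+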
+ /** The cascade output change fraction is a number between 0 and 1. */
+ static VALUE get_cascade_output_change_fraction(VALUE self)
+ {
+   RETURN_FANN_FLT(fann_get_cascade_output_change_fraction);
+ }
+
+ /** call-seq: set_cascade_output_change_fraction(cascade_output_change_fraction)
+
+ The cascade output change fraction is a number between 0 and 1. */
+ static VALUE set_cascade_output_change_fraction(VALUE self, VALUE cascade_output_change_fraction)
+ {
+   SET_FANN_FLT(cascade_output_change_fraction, fann_set_cascade_output_change_fraction);
+ }
+
+ /** The number of cascade output stagnation epochs determines the number of epochs training is allowed to
+ continue without changing the MSE by a fraction of <get_cascade_output_change_fraction>. */
+ static VALUE get_cascade_output_stagnation_epochs(VALUE self)
+ {
+   RETURN_FANN_INT(fann_get_cascade_output_stagnation_epochs);
+ }
+
+ /** call-seq: set_cascade_output_stagnation_epochs(cascade_output_stagnation_epochs)
+
+ The number of cascade output stagnation epochs determines the number of epochs training is allowed to
+ continue without changing the MSE by a fraction of <get_cascade_output_change_fraction>. */
+ static VALUE set_cascade_output_stagnation_epochs(VALUE self, VALUE cascade_output_stagnation_epochs)
+ {
+   SET_FANN_INT(cascade_output_stagnation_epochs, fann_set_cascade_output_stagnation_epochs);
+ }
+
+ /** The cascade candidate change fraction is a number between 0 and 1. */
+ static VALUE get_cascade_candidate_change_fraction(VALUE self)
+ {
+   RETURN_FANN_FLT(fann_get_cascade_candidate_change_fraction);
+ }
+
+ /** call-seq: set_cascade_candidate_change_fraction(cascade_candidate_change_fraction)
+
+ The cascade candidate change fraction is a number between 0 and 1. */
+ static VALUE set_cascade_candidate_change_fraction(VALUE self, VALUE cascade_candidate_change_fraction)
+ {
+   SET_FANN_FLT(cascade_candidate_change_fraction, fann_set_cascade_candidate_change_fraction);
+ }
+
+ /** The number of cascade candidate stagnation epochs determines the number of epochs training is allowed to
+ continue without changing the MSE by a fraction of <get_cascade_candidate_change_fraction>. */
+ static VALUE get_cascade_candidate_stagnation_epochs(VALUE self)
+ {
+   RETURN_FANN_UINT(fann_get_cascade_candidate_stagnation_epochs);
+ }
+
+ /** call-seq: set_cascade_candidate_stagnation_epochs(cascade_candidate_stagnation_epochs)
+
+ The number of cascade candidate stagnation epochs determines the number of epochs training is allowed to
+ continue without changing the MSE by a fraction of <get_cascade_candidate_change_fraction>. */
+ static VALUE set_cascade_candidate_stagnation_epochs(VALUE self, VALUE cascade_candidate_stagnation_epochs)
+ {
+   SET_FANN_UINT(cascade_candidate_stagnation_epochs, fann_set_cascade_candidate_stagnation_epochs);
+ }
+
+ /** The weight multiplier is a parameter which is used to multiply the weights from the candidate neuron
+ before adding the neuron to the neural network. This parameter is usually between 0 and 1, and is used
+ to make the training a bit less aggressive. */
+ static VALUE get_cascade_weight_multiplier(VALUE self)
+ {
+   RETURN_FANN_DBL(fann_get_cascade_weight_multiplier);
+ }
+
+ /** call-seq: set_cascade_weight_multiplier(cascade_weight_multiplier)
+
+ The weight multiplier is a parameter which is used to multiply the weights from the candidate neuron
+ before adding the neuron to the neural network. This parameter is usually between 0 and 1, and is used
+ to make the training a bit less aggressive. */
+ static VALUE set_cascade_weight_multiplier(VALUE self, VALUE cascade_weight_multiplier)
+ {
+   SET_FANN_DBL(cascade_weight_multiplier, fann_set_cascade_weight_multiplier);
+ }
+
+ /** The candidate limit is a limit for how much the candidate neuron may be trained.
+ The limit is a limit on the proportion between the MSE and candidate score. */
+ static VALUE get_cascade_candidate_limit(VALUE self)
+ {
+   RETURN_FANN_DBL(fann_get_cascade_candidate_limit);
+ }
+
+ /** call-seq: set_cascade_candidate_limit(cascade_candidate_limit)
+
+ The candidate limit is a limit for how much the candidate neuron may be trained.
+ The limit is a limit on the proportion between the MSE and candidate score. */
+ static VALUE set_cascade_candidate_limit(VALUE self, VALUE cascade_candidate_limit)
+ {
+   SET_FANN_DBL(cascade_candidate_limit, fann_set_cascade_candidate_limit);
+ }
+
+ /** The maximum out epochs determines the maximum number of epochs the output connections
+ may be trained after adding a new candidate neuron. */
+ static VALUE get_cascade_max_out_epochs(VALUE self)
+ {
+   RETURN_FANN_UINT(fann_get_cascade_max_out_epochs);
+ }
+
+ /** call-seq: set_cascade_max_out_epochs(cascade_max_out_epochs)
+
+ The maximum out epochs determines the maximum number of epochs the output connections
+ may be trained after adding a new candidate neuron. */
+ static VALUE set_cascade_max_out_epochs(VALUE self, VALUE cascade_max_out_epochs)
+ {
+   SET_FANN_UINT(cascade_max_out_epochs, fann_set_cascade_max_out_epochs);
+ }
+
+ /** The maximum candidate epochs determines the maximum number of epochs the input
+ connections to the candidates may be trained before adding a new candidate neuron. */
+ static VALUE get_cascade_max_cand_epochs(VALUE self)
+ {
+   RETURN_FANN_UINT(fann_get_cascade_max_cand_epochs);
+ }
+
+ /** call-seq: set_cascade_max_cand_epochs(cascade_max_cand_epochs)
+
+ The maximum candidate epochs determines the maximum number of epochs the input
+ connections to the candidates may be trained before adding a new candidate neuron. */
+ static VALUE set_cascade_max_cand_epochs(VALUE self, VALUE cascade_max_cand_epochs)
+ {
+   SET_FANN_UINT(cascade_max_cand_epochs, fann_set_cascade_max_cand_epochs);
+ }
+
+ /** The number of candidates used during training (calculated by multiplying <get_cascade_activation_functions_count>,
+ <get_cascade_activation_steepnesses_count> and <get_cascade_num_candidate_groups>). */
+ static VALUE get_cascade_num_candidates(VALUE self)
+ {
+   RETURN_FANN_UINT(fann_get_cascade_num_candidates);
+ }
+
+ /** The number of activation functions in the <get_cascade_activation_functions> array */
+ static VALUE get_cascade_activation_functions_count(VALUE self)
+ {
+   RETURN_FANN_UINT(fann_get_cascade_activation_functions_count);
+ }
+
+ /** The learning rate is used to determine how aggressive training should be for some of the
+ training algorithms (:incremental, :batch, :quickprop).
+ Note, however, that it is not used in :rprop.
+ The default learning rate is 0.7. */
+ static VALUE get_learning_rate(VALUE self)
+ {
+   RETURN_FANN_FLT(fann_get_learning_rate);
+ }
+
+ /** call-seq: set_learning_rate(learning_rate) -> return value
+
+ The learning rate is used to determine how aggressive training should be for some of the
+ training algorithms (:incremental, :batch, :quickprop).
+ Note, however, that it is not used in :rprop.
+ The default learning rate is 0.7. */
+ static VALUE set_learning_rate(VALUE self, VALUE learning_rate)
+ {
+   SET_FANN_FLT(learning_rate, fann_set_learning_rate);
+ }
+
+ /** Get the learning momentum. */
+ static VALUE get_learning_momentum(VALUE self)
+ {
+   RETURN_FANN_FLT(fann_get_learning_momentum);
+ }
+
+ /** call-seq: set_learning_momentum(learning_momentum) -> return value
+
+ Set the learning momentum. */
+ static VALUE set_learning_momentum(VALUE self, VALUE learning_momentum)
+ {
+   SET_FANN_FLT(learning_momentum, fann_set_learning_momentum);
+ }
+
1439
+ /** call-seq: set_cascade_activation_functions(cascade_activation_functions)
+
+ The cascade activation functions array holds the different activation functions used by
+ the candidates. The default is [:sigmoid, :sigmoid_symmetric, :gaussian, :gaussian_symmetric, :elliot, :elliot_symmetric] */
+ static VALUE set_cascade_activation_functions(VALUE self, VALUE cascade_activation_functions)
+ {
+   Check_Type(cascade_activation_functions, T_ARRAY);
+   struct fann* f;
+   Data_Get_Struct(self, struct fann, f);
+
+   unsigned int cnt = RARRAY_LEN(cascade_activation_functions);
+   enum fann_activationfunc_enum fann_activation_functions[cnt];
+   unsigned int i;
+   for (i = 0; i < cnt; i++)
+   {
+     fann_activation_functions[i] = sym_to_activation_function(RARRAY_PTR(cascade_activation_functions)[i]);
+   }
+
+   fann_set_cascade_activation_functions(f, fann_activation_functions, cnt);
+   return self;
+ }
+
+ /** The cascade activation functions array holds the different activation functions used by
+ the candidates. The default is [:sigmoid, :sigmoid_symmetric, :gaussian, :gaussian_symmetric, :elliot, :elliot_symmetric] */
+ static VALUE get_cascade_activation_functions(VALUE self)
+ {
+   struct fann* f;
+   Data_Get_Struct(self, struct fann, f);
+   unsigned int cnt = fann_get_cascade_activation_functions_count(f);
+   enum fann_activationfunc_enum* fann_functions = fann_get_cascade_activation_functions(f);
+
+   // Create ruby array & set outputs:
+   VALUE arr = rb_ary_new();
+   unsigned int i;
+   for (i = 0; i < cnt; i++)
+   {
+     rb_ary_push(arr, activation_function_to_sym(fann_functions[i]));
+   }
+
+   return arr;
+ }
+
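+ /* Hedged usage sketch (added for illustration): restricting the candidate
+  * activation functions to a subset of the default list quoted above; `nn` is
+  * a hypothetical network instance.
+  *
+  *   nn.set_cascade_activation_functions([:sigmoid, :sigmoid_symmetric])
+  *   nn.get_cascade_activation_functions       # => [:sigmoid, :sigmoid_symmetric]
+  *   nn.get_cascade_activation_functions_count # => 2
+  */
+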
+ /** The number of activation steepnesses in the <get_cascade_activation_steepnesses> array. */
+ static VALUE get_cascade_activation_steepnesses_count(VALUE self)
+ {
+   RETURN_FANN_UINT(fann_get_cascade_activation_steepnesses_count);
+ }
+
+ /** The number of candidate groups is the number of groups of identical candidates which will be used
+ during training. */
+ static VALUE get_cascade_num_candidate_groups(VALUE self)
+ {
+   RETURN_FANN_UINT(fann_get_cascade_num_candidate_groups);
+ }
+
+ /** call-seq: set_cascade_num_candidate_groups(cascade_num_candidate_groups)
+
+ The number of candidate groups is the number of groups of identical candidates which will be used
+ during training. */
+ static VALUE set_cascade_num_candidate_groups(VALUE self, VALUE cascade_num_candidate_groups)
+ {
+   // SET_FANN_UINT expands to its own return statement, so nothing may follow it:
+   SET_FANN_UINT(cascade_num_candidate_groups, fann_set_cascade_num_candidate_groups);
+ }
+
+ /** call-seq: set_cascade_activation_steepnesses(cascade_activation_steepnesses)
+
+ The cascade activation steepnesses array holds the different activation steepnesses used by
+ the candidates. */
+ static VALUE set_cascade_activation_steepnesses(VALUE self, VALUE cascade_activation_steepnesses)
+ {
+   Check_Type(cascade_activation_steepnesses, T_ARRAY);
+   struct fann* f;
+   Data_Get_Struct(self, struct fann, f);
+
+   unsigned int cnt = RARRAY_LEN(cascade_activation_steepnesses);
+   fann_type fann_activation_steepnesses[cnt];
+   unsigned int i;
+   for (i = 0; i < cnt; i++)
+   {
+     fann_activation_steepnesses[i] = NUM2DBL(RARRAY_PTR(cascade_activation_steepnesses)[i]);
+   }
+
+   fann_set_cascade_activation_steepnesses(f, fann_activation_steepnesses, cnt);
+   return self;
+ }
+
+ /** The cascade activation steepnesses array holds the different activation steepnesses used by
+ the candidates. */
+ static VALUE get_cascade_activation_steepnesses(VALUE self)
+ {
+   struct fann* f;
+   Data_Get_Struct(self, struct fann, f);
+   fann_type* fann_steepnesses = fann_get_cascade_activation_steepnesses(f);
+   unsigned int cnt = fann_get_cascade_activation_steepnesses_count(f);
+
+   // Create ruby array & set outputs:
+   VALUE arr = rb_ary_new();
+   unsigned int i;
+   for (i = 0; i < cnt; i++)
+   {
+     rb_ary_push(arr, rb_float_new(fann_steepnesses[i]));
+   }
+
+   return arr;
+ }
+
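+ /* Illustrative sketch (not in the original source): steepnesses are plain
+  * floats (converted via NUM2DBL above), so the round trip through the two
+  * functions above looks like this for a hypothetical instance `nn`:
+  *
+  *   nn.set_cascade_activation_steepnesses([0.25, 0.5, 0.75, 1.0])
+  *   nn.get_cascade_activation_steepnesses # => [0.25, 0.5, 0.75, 1.0]
+  */
+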
+ /** call-seq: save(filename) -> status
+
+ Save the entire network to a configuration file with the given name.
+ Returns the status reported by fann_save (0 on success). */
+ static VALUE nn_save(VALUE self, VALUE filename)
+ {
+   struct fann* f;
+   Data_Get_Struct(self, struct fann, f);
+   int status = fann_save(f, StringValuePtr(filename));
+   return INT2NUM(status);
+ }
+
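+ /* Hedged sketch (added for illustration): persisting a trained network via the
+  * method above; the filename is arbitrary and `nn` is a hypothetical instance.
+  *
+  *   status = nn.save("xor.net")
+  *   raise "could not save network" unless status.zero?
+  */
+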
+ /** Initializes the classes under the MooFann module/namespace. */
+ void Init_moo_fann()
+ {
+   // MooFann module/namespace:
+   m_rb_fann_module = rb_define_module("MooFann");
+
+   // Standard NN class:
+   m_rb_fann_standard_class = rb_define_class_under(m_rb_fann_module, "Standard", rb_cObject);
+   rb_define_alloc_func(m_rb_fann_standard_class, fann_allocate);
+   rb_define_method(m_rb_fann_standard_class, "initialize", fann_initialize, 1);
+   rb_define_method(m_rb_fann_standard_class, "init_weights", init_weights, 1);
+   rb_define_method(m_rb_fann_standard_class, "set_activation_function", set_activation_function, 3);
+   rb_define_method(m_rb_fann_standard_class, "set_activation_function_hidden", set_activation_function_hidden, 1);
+   rb_define_method(m_rb_fann_standard_class, "set_activation_function_layer", set_activation_function_layer, 2);
+   rb_define_method(m_rb_fann_standard_class, "get_activation_function", get_activation_function, 2);
+   rb_define_method(m_rb_fann_standard_class, "set_activation_function_output", set_activation_function_output, 1);
+   rb_define_method(m_rb_fann_standard_class, "get_activation_steepness", get_activation_steepness, 2);
+   rb_define_method(m_rb_fann_standard_class, "set_activation_steepness", set_activation_steepness, 3);
+   rb_define_method(m_rb_fann_standard_class, "set_activation_steepness_hidden", set_activation_steepness_hidden, 1);
+   rb_define_method(m_rb_fann_standard_class, "set_activation_steepness_layer", set_activation_steepness_layer, 2);
+   rb_define_method(m_rb_fann_standard_class, "set_activation_steepness_output", set_activation_steepness_output, 1);
+   rb_define_method(m_rb_fann_standard_class, "get_train_error_function", get_train_error_function, 0);
+   rb_define_method(m_rb_fann_standard_class, "set_train_error_function", set_train_error_function, 1);
+   rb_define_method(m_rb_fann_standard_class, "get_train_stop_function", get_train_stop_function, 0);
+   rb_define_method(m_rb_fann_standard_class, "set_train_stop_function", set_train_stop_function, 1);
+   rb_define_method(m_rb_fann_standard_class, "get_bit_fail_limit", get_bit_fail_limit, 0);
+   rb_define_method(m_rb_fann_standard_class, "set_bit_fail_limit", set_bit_fail_limit, 1);
+   rb_define_method(m_rb_fann_standard_class, "get_quickprop_decay", get_quickprop_decay, 0);
+   rb_define_method(m_rb_fann_standard_class, "set_quickprop_decay", set_quickprop_decay, 1);
+   rb_define_method(m_rb_fann_standard_class, "get_quickprop_mu", get_quickprop_mu, 0);
+   rb_define_method(m_rb_fann_standard_class, "set_quickprop_mu", set_quickprop_mu, 1);
+   rb_define_method(m_rb_fann_standard_class, "get_rprop_increase_factor", get_rprop_increase_factor, 0);
+   rb_define_method(m_rb_fann_standard_class, "set_rprop_increase_factor", set_rprop_increase_factor, 1);
+   rb_define_method(m_rb_fann_standard_class, "get_rprop_decrease_factor", get_rprop_decrease_factor, 0);
+   rb_define_method(m_rb_fann_standard_class, "set_rprop_decrease_factor", set_rprop_decrease_factor, 1);
+   rb_define_method(m_rb_fann_standard_class, "get_rprop_delta_max", get_rprop_delta_max, 0);
+   rb_define_method(m_rb_fann_standard_class, "set_rprop_delta_max", set_rprop_delta_max, 1);
+   rb_define_method(m_rb_fann_standard_class, "get_rprop_delta_min", get_rprop_delta_min, 0);
+   rb_define_method(m_rb_fann_standard_class, "set_rprop_delta_min", set_rprop_delta_min, 1);
+   rb_define_method(m_rb_fann_standard_class, "get_rprop_delta_zero", get_rprop_delta_zero, 0);
+   rb_define_method(m_rb_fann_standard_class, "set_rprop_delta_zero", set_rprop_delta_zero, 1);
+   rb_define_method(m_rb_fann_standard_class, "get_bias_array", get_bias_array, 0);
+   rb_define_method(m_rb_fann_standard_class, "get_connection_rate", get_connection_rate, 0);
+   rb_define_method(m_rb_fann_standard_class, "get_layer_array", get_layer_array, 0);
+   rb_define_method(m_rb_fann_standard_class, "get_network_type", get_network_type, 0);
+   rb_define_method(m_rb_fann_standard_class, "get_neurons", get_neurons, 0);
+   rb_define_method(m_rb_fann_standard_class, "get_num_input", get_num_input, 0);
+   rb_define_method(m_rb_fann_standard_class, "get_num_layers", get_num_layers, 0);
+   rb_define_method(m_rb_fann_standard_class, "get_num_output", get_num_output, 0);
+   rb_define_method(m_rb_fann_standard_class, "get_total_connections", get_total_connections, 0);
+   rb_define_method(m_rb_fann_standard_class, "get_total_neurons", get_total_neurons, 0);
+   // rb_define_method(m_rb_fann_standard_class, "get_train_error_function", get_train_error_function, 0);
+   // rb_define_method(m_rb_fann_standard_class, "set_train_error_function", set_train_error_function, 1);
+   rb_define_method(m_rb_fann_standard_class, "print_connections", print_connections, 0);
+   rb_define_method(m_rb_fann_standard_class, "print_parameters", print_parameters, 0);
+   rb_define_method(m_rb_fann_standard_class, "randomize_weights", randomize_weights, 2);
+   rb_define_method(m_rb_fann_standard_class, "run", run, 1);
+   rb_define_method(m_rb_fann_standard_class, "train", train, 2);
+   rb_define_method(m_rb_fann_standard_class, "train_on_data", train_on_data, 4);
+   rb_define_method(m_rb_fann_standard_class, "train_epoch", train_epoch, 1);
+   rb_define_method(m_rb_fann_standard_class, "test_data", test_data, 1);
+   rb_define_method(m_rb_fann_standard_class, "get_MSE", get_MSE, 0);
+   rb_define_method(m_rb_fann_standard_class, "get_bit_fail", get_bit_fail, 0);
+   rb_define_method(m_rb_fann_standard_class, "reset_MSE", reset_MSE, 0);
+   rb_define_method(m_rb_fann_standard_class, "get_learning_rate", get_learning_rate, 0);
+   rb_define_method(m_rb_fann_standard_class, "set_learning_rate", set_learning_rate, 1);
+   rb_define_method(m_rb_fann_standard_class, "get_learning_momentum", get_learning_momentum, 0);
+   rb_define_method(m_rb_fann_standard_class, "set_learning_momentum", set_learning_momentum, 1);
+   rb_define_method(m_rb_fann_standard_class, "get_training_algorithm", get_training_algorithm, 0);
+   rb_define_method(m_rb_fann_standard_class, "set_training_algorithm", set_training_algorithm, 1);
+
+   // Cascade functions:
+   rb_define_method(m_rb_fann_standard_class, "cascadetrain_on_data", cascadetrain_on_data, 4);
+   rb_define_method(m_rb_fann_standard_class, "get_cascade_output_change_fraction", get_cascade_output_change_fraction, 0);
+   rb_define_method(m_rb_fann_standard_class, "set_cascade_output_change_fraction", set_cascade_output_change_fraction, 1);
+   rb_define_method(m_rb_fann_standard_class, "get_cascade_output_stagnation_epochs", get_cascade_output_stagnation_epochs, 0);
+   rb_define_method(m_rb_fann_standard_class, "set_cascade_output_stagnation_epochs", set_cascade_output_stagnation_epochs, 1);
+   rb_define_method(m_rb_fann_standard_class, "get_cascade_candidate_change_fraction", get_cascade_candidate_change_fraction, 0);
+   rb_define_method(m_rb_fann_standard_class, "set_cascade_candidate_change_fraction", set_cascade_candidate_change_fraction, 1);
+   rb_define_method(m_rb_fann_standard_class, "get_cascade_candidate_stagnation_epochs", get_cascade_candidate_stagnation_epochs, 0);
+   rb_define_method(m_rb_fann_standard_class, "set_cascade_candidate_stagnation_epochs", set_cascade_candidate_stagnation_epochs, 1);
+   rb_define_method(m_rb_fann_standard_class, "get_cascade_weight_multiplier", get_cascade_weight_multiplier, 0);
+   rb_define_method(m_rb_fann_standard_class, "set_cascade_weight_multiplier", set_cascade_weight_multiplier, 1);
+   rb_define_method(m_rb_fann_standard_class, "get_cascade_candidate_limit", get_cascade_candidate_limit, 0);
+   rb_define_method(m_rb_fann_standard_class, "set_cascade_candidate_limit", set_cascade_candidate_limit, 1);
+   rb_define_method(m_rb_fann_standard_class, "get_cascade_max_out_epochs", get_cascade_max_out_epochs, 0);
+   rb_define_method(m_rb_fann_standard_class, "set_cascade_max_out_epochs", set_cascade_max_out_epochs, 1);
+   rb_define_method(m_rb_fann_standard_class, "get_cascade_max_cand_epochs", get_cascade_max_cand_epochs, 0);
+   rb_define_method(m_rb_fann_standard_class, "set_cascade_max_cand_epochs", set_cascade_max_cand_epochs, 1);
+   rb_define_method(m_rb_fann_standard_class, "get_cascade_num_candidates", get_cascade_num_candidates, 0);
+   rb_define_method(m_rb_fann_standard_class, "get_cascade_activation_functions_count", get_cascade_activation_functions_count, 0);
+   rb_define_method(m_rb_fann_standard_class, "get_cascade_activation_functions", get_cascade_activation_functions, 0);
+   rb_define_method(m_rb_fann_standard_class, "set_cascade_activation_functions", set_cascade_activation_functions, 1);
+   rb_define_method(m_rb_fann_standard_class, "get_cascade_activation_steepnesses_count", get_cascade_activation_steepnesses_count, 0);
+   rb_define_method(m_rb_fann_standard_class, "get_cascade_activation_steepnesses", get_cascade_activation_steepnesses, 0);
+   rb_define_method(m_rb_fann_standard_class, "set_cascade_activation_steepnesses", set_cascade_activation_steepnesses, 1);
+   rb_define_method(m_rb_fann_standard_class, "get_cascade_num_candidate_groups", get_cascade_num_candidate_groups, 0);
+   rb_define_method(m_rb_fann_standard_class, "set_cascade_num_candidate_groups", set_cascade_num_candidate_groups, 1);
+   rb_define_method(m_rb_fann_standard_class, "save", nn_save, 1);
+
+   // Uncomment for fixed-point mode (also recompile fann). Probably not going to be needed:
+   //rb_define_method(clazz, "get_decimal_point", get_decimal_point, 0);
+   //rb_define_method(clazz, "get_multiplier", get_multiplier, 0);
+
+   // Shortcut NN class (duplicated from above so that rdoc generation tools can find the methods):
+   m_rb_fann_shortcut_class = rb_define_class_under(m_rb_fann_module, "Shortcut", rb_cObject);
+   rb_define_alloc_func(m_rb_fann_shortcut_class, fann_allocate);
+   rb_define_method(m_rb_fann_shortcut_class, "initialize", fann_initialize, 1);
+   rb_define_method(m_rb_fann_shortcut_class, "init_weights", init_weights, 1);
+   rb_define_method(m_rb_fann_shortcut_class, "set_activation_function", set_activation_function, 3);
+   rb_define_method(m_rb_fann_shortcut_class, "set_activation_function_hidden", set_activation_function_hidden, 1);
+   rb_define_method(m_rb_fann_shortcut_class, "set_activation_function_layer", set_activation_function_layer, 2);
+   rb_define_method(m_rb_fann_shortcut_class, "get_activation_function", get_activation_function, 2);
+   rb_define_method(m_rb_fann_shortcut_class, "set_activation_function_output", set_activation_function_output, 1);
+   rb_define_method(m_rb_fann_shortcut_class, "get_activation_steepness", get_activation_steepness, 2);
+   rb_define_method(m_rb_fann_shortcut_class, "set_activation_steepness", set_activation_steepness, 3);
+   rb_define_method(m_rb_fann_shortcut_class, "set_activation_steepness_hidden", set_activation_steepness_hidden, 1);
+   rb_define_method(m_rb_fann_shortcut_class, "set_activation_steepness_layer", set_activation_steepness_layer, 2);
+   rb_define_method(m_rb_fann_shortcut_class, "set_activation_steepness_output", set_activation_steepness_output, 1);
+   rb_define_method(m_rb_fann_shortcut_class, "get_train_error_function", get_train_error_function, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "set_train_error_function", set_train_error_function, 1);
+   rb_define_method(m_rb_fann_shortcut_class, "get_train_stop_function", get_train_stop_function, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "set_train_stop_function", set_train_stop_function, 1);
+   rb_define_method(m_rb_fann_shortcut_class, "get_bit_fail_limit", get_bit_fail_limit, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "set_bit_fail_limit", set_bit_fail_limit, 1);
+   rb_define_method(m_rb_fann_shortcut_class, "get_quickprop_decay", get_quickprop_decay, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "set_quickprop_decay", set_quickprop_decay, 1);
+   rb_define_method(m_rb_fann_shortcut_class, "get_quickprop_mu", get_quickprop_mu, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "set_quickprop_mu", set_quickprop_mu, 1);
+   rb_define_method(m_rb_fann_shortcut_class, "get_rprop_increase_factor", get_rprop_increase_factor, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "set_rprop_increase_factor", set_rprop_increase_factor, 1);
+   rb_define_method(m_rb_fann_shortcut_class, "get_rprop_decrease_factor", get_rprop_decrease_factor, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "set_rprop_decrease_factor", set_rprop_decrease_factor, 1);
+   rb_define_method(m_rb_fann_shortcut_class, "get_rprop_delta_max", get_rprop_delta_max, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "set_rprop_delta_max", set_rprop_delta_max, 1);
+   rb_define_method(m_rb_fann_shortcut_class, "get_rprop_delta_min", get_rprop_delta_min, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "set_rprop_delta_min", set_rprop_delta_min, 1);
+   rb_define_method(m_rb_fann_shortcut_class, "get_rprop_delta_zero", get_rprop_delta_zero, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "set_rprop_delta_zero", set_rprop_delta_zero, 1);
+   rb_define_method(m_rb_fann_shortcut_class, "get_bias_array", get_bias_array, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "get_connection_rate", get_connection_rate, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "get_layer_array", get_layer_array, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "get_network_type", get_network_type, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "get_neurons", get_neurons, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "get_num_input", get_num_input, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "get_num_layers", get_num_layers, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "get_num_output", get_num_output, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "get_total_connections", get_total_connections, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "get_total_neurons", get_total_neurons, 0);
+   // rb_define_method(m_rb_fann_shortcut_class, "get_train_error_function", get_train_error_function, 0);
+   // rb_define_method(m_rb_fann_shortcut_class, "set_train_error_function", set_train_error_function, 1);
+   rb_define_method(m_rb_fann_shortcut_class, "print_connections", print_connections, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "print_parameters", print_parameters, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "randomize_weights", randomize_weights, 2);
+   rb_define_method(m_rb_fann_shortcut_class, "run", run, 1);
+   rb_define_method(m_rb_fann_shortcut_class, "train", train, 2);
+   rb_define_method(m_rb_fann_shortcut_class, "train_on_data", train_on_data, 4);
+   rb_define_method(m_rb_fann_shortcut_class, "train_epoch", train_epoch, 1);
+   rb_define_method(m_rb_fann_shortcut_class, "test_data", test_data, 1);
+   rb_define_method(m_rb_fann_shortcut_class, "get_MSE", get_MSE, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "get_bit_fail", get_bit_fail, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "reset_MSE", reset_MSE, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "get_learning_rate", get_learning_rate, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "set_learning_rate", set_learning_rate, 1);
+   rb_define_method(m_rb_fann_shortcut_class, "get_learning_momentum", get_learning_momentum, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "set_learning_momentum", set_learning_momentum, 1);
+   rb_define_method(m_rb_fann_shortcut_class, "get_training_algorithm", get_training_algorithm, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "set_training_algorithm", set_training_algorithm, 1);
+
+   // Cascade functions:
+   rb_define_method(m_rb_fann_shortcut_class, "cascadetrain_on_data", cascadetrain_on_data, 4);
+   rb_define_method(m_rb_fann_shortcut_class, "get_cascade_output_change_fraction", get_cascade_output_change_fraction, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "set_cascade_output_change_fraction", set_cascade_output_change_fraction, 1);
+   rb_define_method(m_rb_fann_shortcut_class, "get_cascade_output_stagnation_epochs", get_cascade_output_stagnation_epochs, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "set_cascade_output_stagnation_epochs", set_cascade_output_stagnation_epochs, 1);
+   rb_define_method(m_rb_fann_shortcut_class, "get_cascade_candidate_change_fraction", get_cascade_candidate_change_fraction, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "set_cascade_candidate_change_fraction", set_cascade_candidate_change_fraction, 1);
+   rb_define_method(m_rb_fann_shortcut_class, "get_cascade_candidate_stagnation_epochs", get_cascade_candidate_stagnation_epochs, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "set_cascade_candidate_stagnation_epochs", set_cascade_candidate_stagnation_epochs, 1);
+   rb_define_method(m_rb_fann_shortcut_class, "get_cascade_weight_multiplier", get_cascade_weight_multiplier, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "set_cascade_weight_multiplier", set_cascade_weight_multiplier, 1);
+   rb_define_method(m_rb_fann_shortcut_class, "get_cascade_candidate_limit", get_cascade_candidate_limit, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "set_cascade_candidate_limit", set_cascade_candidate_limit, 1);
+   rb_define_method(m_rb_fann_shortcut_class, "get_cascade_max_out_epochs", get_cascade_max_out_epochs, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "set_cascade_max_out_epochs", set_cascade_max_out_epochs, 1);
+   rb_define_method(m_rb_fann_shortcut_class, "get_cascade_max_cand_epochs", get_cascade_max_cand_epochs, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "set_cascade_max_cand_epochs", set_cascade_max_cand_epochs, 1);
+   rb_define_method(m_rb_fann_shortcut_class, "get_cascade_num_candidates", get_cascade_num_candidates, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "get_cascade_activation_functions_count", get_cascade_activation_functions_count, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "get_cascade_activation_functions", get_cascade_activation_functions, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "set_cascade_activation_functions", set_cascade_activation_functions, 1);
+   rb_define_method(m_rb_fann_shortcut_class, "get_cascade_activation_steepnesses_count", get_cascade_activation_steepnesses_count, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "get_cascade_activation_steepnesses", get_cascade_activation_steepnesses, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "set_cascade_activation_steepnesses", set_cascade_activation_steepnesses, 1);
+   rb_define_method(m_rb_fann_shortcut_class, "get_cascade_num_candidate_groups", get_cascade_num_candidate_groups, 0);
+   rb_define_method(m_rb_fann_shortcut_class, "set_cascade_num_candidate_groups", set_cascade_num_candidate_groups, 1);
+   rb_define_method(m_rb_fann_shortcut_class, "save", nn_save, 1);
+
+   // TrainData class (training data for the NN classes above):
+   m_rb_fann_train_data_class = rb_define_class_under(m_rb_fann_module, "TrainData", rb_cObject);
+   rb_define_alloc_func(m_rb_fann_train_data_class, fann_training_data_allocate);
+   rb_define_method(m_rb_fann_train_data_class, "initialize", fann_train_data_initialize, 1);
+   rb_define_method(m_rb_fann_train_data_class, "length", length_train_data, 0);
+   rb_define_method(m_rb_fann_train_data_class, "shuffle", shuffle, 0);
+   rb_define_method(m_rb_fann_train_data_class, "save", training_save, 1);
+
+   // printf("Initialized Ruby Bindings for FANN.\n");
+ }
+
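+ /* Illustrative end-to-end sketch (not part of the original source). The class
+  * and method names below are registered above; the constructor hash keys and
+  * the require path follow ruby-fann conventions and are assumptions here.
+  *
+  *   require 'moo_fann'
+  *
+  *   train = MooFann::TrainData.new(inputs: [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]],
+  *                                  desired_outputs: [[0.0], [1.0], [1.0], [0.0]])
+  *   nn = MooFann::Standard.new(num_inputs: 2, hidden_neurons: [3], num_outputs: 1)
+  *   nn.train_on_data(train, 1000, 10, 0.001) # data, max epochs, report interval, desired MSE
+  *   nn.run([1.0, 0.0])                       # => a one-element array near 1.0
+  */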