ruby-fann 1.2.5 → 1.2.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +8 -8
- data/README.md +6 -1
- data/ext/ruby_fann/config.h +4 -61
- data/ext/ruby_fann/doublefann.c +1 -1
- data/ext/ruby_fann/doublefann.h +1 -1
- data/ext/ruby_fann/fann.c +279 -28
- data/ext/ruby_fann/fann.h +11 -1
- data/ext/ruby_fann/fann_activation.h +1 -1
- data/ext/ruby_fann/fann_cascade.c +27 -10
- data/ext/ruby_fann/fann_cascade.h +55 -1
- data/ext/ruby_fann/fann_data.h +28 -3
- data/ext/ruby_fann/fann_error.c +7 -1
- data/ext/ruby_fann/fann_error.h +6 -2
- data/ext/ruby_fann/fann_internal.h +7 -3
- data/ext/ruby_fann/fann_io.c +67 -27
- data/ext/ruby_fann/fann_io.h +1 -1
- data/ext/ruby_fann/fann_train.c +86 -1
- data/ext/ruby_fann/fann_train.h +108 -1
- data/ext/ruby_fann/fann_train_data.c +144 -132
- data/lib/ruby_fann/version.rb +1 -1
- metadata +2 -2
data/ext/ruby_fann/fann_io.h
CHANGED
@@ -1,6 +1,6 @@
 /*
   Fast Artificial Neural Network Library (fann)
-  Copyright (C) 2003 Steffen Nissen (
+  Copyright (C) 2003-2012 Steffen Nissen (sn@leenissen.dk)
 
   This library is free software; you can redistribute it and/or
   modify it under the terms of the GNU Lesser General Public
data/ext/ruby_fann/fann_train.c
CHANGED
@@ -1,6 +1,6 @@
 /*
   Fast Artificial Neural Network Library (fann)
-  Copyright (C) 2003 Steffen Nissen (
+  Copyright (C) 2003-2012 Steffen Nissen (sn@leenissen.dk)
 
   This library is free software; you can redistribute it and/or
   modify it under the terms of the GNU Lesser General Public
@@ -21,6 +21,7 @@
 #include <stdlib.h>
 #include <stdarg.h>
 #include <string.h>
+#include <math.h>
 
 #include "config.h"
 #include "fann.h"
@@ -206,6 +207,7 @@ FANN_EXTERNAL unsigned int FANN_API fann_get_bit_fail(struct fann *ann)
 */
 FANN_EXTERNAL void FANN_API fann_reset_MSE(struct fann *ann)
 {
+/*printf("resetMSE %d %f\n", ann->num_MSE, ann->MSE_value);*/
    ann->num_MSE = 0;
    ann->MSE_value = 0;
    ann->num_bit_fail = 0;
@@ -764,6 +766,85 @@ void fann_update_weights_irpropm(struct fann *ann, unsigned int first_weight, un
    }
 }
 
+/* INTERNAL FUNCTION
+   The SARprop- algorithm
+*/
+void fann_update_weights_sarprop(struct fann *ann, unsigned int epoch, unsigned int first_weight, unsigned int past_end)
+{
+   fann_type *train_slopes = ann->train_slopes;
+   fann_type *weights = ann->weights;
+   fann_type *prev_steps = ann->prev_steps;
+   fann_type *prev_train_slopes = ann->prev_train_slopes;
+
+   fann_type prev_step, slope, prev_slope, next_step = 0, same_sign;
+
+   /* These should be set from variables */
+   float increase_factor = ann->rprop_increase_factor; /*1.2; */
+   float decrease_factor = ann->rprop_decrease_factor; /*0.5; */
+   /* TODO: why is delta_min 0.0 in iRprop? SARPROP uses 1x10^-6 (Braun and Riedmiller, 1993) */
+   float delta_min = 0.000001f;
+   float delta_max = ann->rprop_delta_max; /*50.0; */
+   float weight_decay_shift = ann->sarprop_weight_decay_shift; /* ld 0.01 = -6.644 */
+   float step_error_threshold_factor = ann->sarprop_step_error_threshold_factor; /* 0.1 */
+   float step_error_shift = ann->sarprop_step_error_shift; /* ld 3 = 1.585 */
+   float T = ann->sarprop_temperature;
+   float MSE = fann_get_MSE(ann);
+   float RMSE = (float)sqrt(MSE);
+
+   unsigned int i = first_weight;
+
+
+   /* for all weights; TODO: are biases included? */
+   for(; i != past_end; i++)
+   {
+      /* TODO: confirm whether 1x10^-6 == delta_min is really better */
+      prev_step = fann_max(prev_steps[i], (fann_type) 0.000001); /* prev_step may not be zero because then the training will stop */
+      /* calculate SARPROP slope; TODO: better as new error function? (see SARPROP paper)*/
+      slope = -train_slopes[i] - weights[i] * (fann_type)fann_exp2(-T * epoch + weight_decay_shift);
+
+      /* TODO: is prev_train_slopes[i] 0.0 in the beginning? */
+      prev_slope = prev_train_slopes[i];
+
+      same_sign = prev_slope * slope;
+
+      if(same_sign > 0.0)
+      {
+         next_step = fann_min(prev_step * increase_factor, delta_max);
+         /* TODO: are the signs inverted? see differences between SARPROP paper and iRprop */
+         if (slope < 0.0)
+            weights[i] += next_step;
+         else
+            weights[i] -= next_step;
+      }
+      else if(same_sign < 0.0)
+      {
+         if(prev_step < step_error_threshold_factor * MSE)
+            next_step = prev_step * decrease_factor + (float)rand() / RAND_MAX * RMSE * (fann_type)fann_exp2(-T * epoch + step_error_shift);
+         else
+            next_step = fann_max(prev_step * decrease_factor, delta_min);
+
+         slope = 0.0;
+      }
+      else
+      {
+         if(slope < 0.0)
+            weights[i] += prev_step;
+         else
+            weights[i] -= prev_step;
+      }
+
+
+      /*if(i == 2){
+       * printf("weight=%f, slope=%f, next_step=%f, prev_step=%f\n", weights[i], slope, next_step, prev_step);
+       * } */
+
+      /* update global data arrays */
+      prev_steps[i] = next_step;
+      prev_train_slopes[i] = slope;
+      train_slopes[i] = 0.0;
+   }
+}
+
 #endif
 
 FANN_GET_SET(enum fann_train_enum, training_algorithm)
@@ -957,6 +1038,10 @@ FANN_GET_SET(float, rprop_decrease_factor)
 FANN_GET_SET(float, rprop_delta_min)
 FANN_GET_SET(float, rprop_delta_max)
 FANN_GET_SET(float, rprop_delta_zero)
+FANN_GET_SET(float, sarprop_weight_decay_shift)
+FANN_GET_SET(float, sarprop_step_error_threshold_factor)
+FANN_GET_SET(float, sarprop_step_error_shift)
+FANN_GET_SET(float, sarprop_temperature)
 FANN_GET_SET(enum fann_stopfunc_enum, train_stop_function)
 FANN_GET_SET(fann_type, bit_fail_limit)
 FANN_GET_SET(float, learning_momentum)
data/ext/ruby_fann/fann_train.h
CHANGED
@@ -1,6 +1,6 @@
 /*
   Fast Artificial Neural Network Library (fann)
-  Copyright (C) 2003 Steffen Nissen (
+  Copyright (C) 2003-2012 Steffen Nissen (sn@leenissen.dk)
 
   This library is free software; you can redistribute it and/or
   modify it under the terms of the GNU Lesser General Public
@@ -252,6 +252,17 @@ FANN_EXTERNAL float FANN_API fann_test_data(struct fann *ann, struct fann_train_
 FANN_EXTERNAL struct fann_train_data *FANN_API fann_read_train_from_file(const char *filename);
 
 
+/* Function: fann_create_train
+   Creates an empty training data struct.
+
+   See also:
+     <fann_read_train_from_file>, <fann_train_on_data>, <fann_destroy_train>,
+     <fann_save_train>
+
+    This function appears in FANN >= 2.2.0
+*/
+FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train(unsigned int num_data, unsigned int num_input, unsigned int num_output);
+
 /* Function: fann_create_train_from_callback
    Creates the training data struct from a user supplied function.
    As the training data are numerable (data 1, data 2...), the user must write
@@ -1200,4 +1211,100 @@ FANN_EXTERNAL float FANN_API fann_get_rprop_delta_zero(struct fann *ann);
 */
 FANN_EXTERNAL void FANN_API fann_set_rprop_delta_zero(struct fann *ann, float rprop_delta_max);
 
+/* Method: fann_get_sarprop_weight_decay_shift
+
+   The sarprop weight decay shift.
+
+   The default delta max is -6.644.
+
+   See also:
+      <fann fann_set_sarprop_weight_decay_shift>
+
+   This function appears in FANN >= 2.1.0.
+ */
+FANN_EXTERNAL float FANN_API fann_get_sarprop_weight_decay_shift(struct fann *ann);
+
+/* Method: fann_set_sarprop_weight_decay_shift
+
+   Set the sarprop weight decay shift.
+
+   This function appears in FANN >= 2.1.0.
+
+   See also:
+      <fann_set_sarprop_weight_decay_shift>
+*/
+FANN_EXTERNAL void FANN_API fann_set_sarprop_weight_decay_shift(struct fann *ann, float sarprop_weight_decay_shift);
+
+/* Method: fann_get_sarprop_step_error_threshold_factor
+
+   The sarprop step error threshold factor.
+
+   The default delta max is 0.1.
+
+   See also:
+      <fann fann_get_sarprop_step_error_threshold_factor>
+
+   This function appears in FANN >= 2.1.0.
+ */
+FANN_EXTERNAL float FANN_API fann_get_sarprop_step_error_threshold_factor(struct fann *ann);
+
+/* Method: fann_set_sarprop_step_error_threshold_factor
+
+   Set the sarprop step error threshold factor.
+
+   This function appears in FANN >= 2.1.0.
+
+   See also:
+      <fann_get_sarprop_step_error_threshold_factor>
+*/
+FANN_EXTERNAL void FANN_API fann_set_sarprop_step_error_threshold_factor(struct fann *ann, float sarprop_step_error_threshold_factor);
+
+/* Method: fann_get_sarprop_step_error_shift
+
+   The get sarprop step error shift.
+
+   The default delta max is 1.385.
+
+   See also:
+      <fann_set_sarprop_step_error_shift>
+
+   This function appears in FANN >= 2.1.0.
+ */
+FANN_EXTERNAL float FANN_API fann_get_sarprop_step_error_shift(struct fann *ann);
+
+/* Method: fann_set_sarprop_step_error_shift
+
+   Set the sarprop step error shift.
+
+   This function appears in FANN >= 2.1.0.
+
+   See also:
+      <fann_get_sarprop_step_error_shift>
+*/
+FANN_EXTERNAL void FANN_API fann_set_sarprop_step_error_shift(struct fann *ann, float sarprop_step_error_shift);
+
+/* Method: fann_get_sarprop_temperature
+
+   The sarprop weight decay shift.
+
+   The default delta max is 0.015.
+
+   See also:
+      <fann_set_sarprop_temperature>
+
+   This function appears in FANN >= 2.1.0.
+ */
+FANN_EXTERNAL float FANN_API fann_get_sarprop_temperature(struct fann *ann);
+
+/* Method: fann_set_sarprop_temperature
+
+   Set the sarprop_temperature.
+
+   This function appears in FANN >= 2.1.0.
+
+   See also:
+      <fann_get_sarprop_temperature>
+*/
+FANN_EXTERNAL void FANN_API fann_set_sarprop_temperature(struct fann *ann, float sarprop_temperature);
+
 #endif
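fann_create_train, declared above, allocates an empty training set whose input and output rows can then be filled in directly before training or saving. A short sketch under those assumptions; the two XOR-style samples and the output file name are illustrative only:

    #include "fann.h"

    int main(void)
    {
        /* Two samples with 2 inputs and 1 output each (illustrative values) */
        struct fann_train_data *data = fann_create_train(2, 2, 1);
        if(data == NULL)
            return 1;

        data->input[0][0] = 0; data->input[0][1] = 1; data->output[0][0] = 1;
        data->input[1][0] = 1; data->input[1][1] = 1; data->output[1][0] = 0;

        /* Write it out in the usual FANN training-data format */
        fann_save_train(data, "tiny.data");

        fann_destroy_train(data);
        return 0;
    }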
data/ext/ruby_fann/fann_train_data.c
CHANGED
@@ -1,21 +1,21 @@
 /*
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+Fast Artificial Neural Network Library (fann)
+Copyright (C) 2003-2012 Steffen Nissen (sn@leenissen.dk)
+
+This library is free software; you can redistribute it and/or
+modify it under the terms of the GNU Lesser General Public
+License as published by the Free Software Foundation; either
+version 2.1 of the License, or (at your option) any later version.
+
+This library is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+Lesser General Public License for more details.
+
+You should have received a copy of the GNU Lesser General Public
+License along with this library; if not, write to the Free Software
+Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
+*/
 
 #include <stdio.h>
 #include <stdlib.h>
@@ -84,7 +84,9 @@ FANN_EXTERNAL void FANN_API fann_destroy_train(struct fann_train_data *data)
 FANN_EXTERNAL float FANN_API fann_test_data(struct fann *ann, struct fann_train_data *data)
 {
    unsigned int i;
-
+   if(fann_check_input_output_sizes(ann, data) == -1)
+      return 0;
+
    fann_reset_MSE(ann);
 
    for(i = 0; i != data->num_data; i++)
@@ -95,86 +97,38 @@ FANN_EXTERNAL float FANN_API fann_test_data(struct fann *ann, struct fann_train_
    return fann_get_MSE(ann);
 }
 
+#ifndef FIXEDFANN
+
 /*
- *
+ * Internal train function
  */
-
-   unsigned int num_input,
-   unsigned int num_output,
-   void (FANN_API *user_function)( unsigned int,
-   unsigned int,
-   unsigned int,
-   fann_type * ,
-   fann_type * ))
+float fann_train_epoch_quickprop(struct fann *ann, struct fann_train_data *data)
 {
-
-   fann_type *data_input, *data_output;
-   struct fann_train_data *data = (struct fann_train_data *)
-      malloc(sizeof(struct fann_train_data));
-
-   if(data == NULL){
-      fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
-      return NULL;
-   }
-
-   fann_init_error_data((struct fann_error *) data);
-
-   data->num_data = num_data;
-   data->num_input = num_input;
-   data->num_output = num_output;
-
-   data->input = (fann_type **) calloc(num_data, sizeof(fann_type *));
-   if(data->input == NULL)
-   {
-      fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
-      fann_destroy_train(data);
-      return NULL;
-   }
-
-   data->output = (fann_type **) calloc(num_data, sizeof(fann_type *));
-   if(data->output == NULL)
-   {
-      fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
-      fann_destroy_train(data);
-      return NULL;
-   }
-
-   data_input = (fann_type *) calloc(num_input * num_data, sizeof(fann_type));
-   if(data_input == NULL)
-   {
-      fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
-      fann_destroy_train(data);
-      return NULL;
-   }
-
-   data_output = (fann_type *) calloc(num_output * num_data, sizeof(fann_type));
-   if(data_output == NULL)
-   {
-      fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
-      fann_destroy_train(data);
-      return NULL;
-   }
-
-   for( i = 0; i != num_data; i++)
-   {
-      data->input[i] = data_input;
-      data_input += num_input;
+   unsigned int i;
 
-
-
+   if(ann->prev_train_slopes == NULL)
+   {
+      fann_clear_train_arrays(ann);
+   }
 
-
-   }
+   fann_reset_MSE(ann);
 
-
-
+   for(i = 0; i < data->num_data; i++)
+   {
+      fann_run(ann, data->input[i]);
+      fann_compute_MSE(ann, data->output[i]);
+      fann_backpropagate_MSE(ann);
+      fann_update_slopes_batch(ann, ann->first_layer + 1, ann->last_layer - 1);
+   }
+   fann_update_weights_quickprop(ann, data->num_data, 0, ann->total_connections);
 
-
+   return fann_get_MSE(ann);
+}
 
 /*
  * Internal train function
  */
-float
+float fann_train_epoch_irpropm(struct fann *ann, struct fann_train_data *data)
 {
    unsigned int i;
 
@@ -192,7 +146,8 @@ float fann_train_epoch_quickprop(struct fann *ann, struct fann_train_data *data)
       fann_backpropagate_MSE(ann);
      fann_update_slopes_batch(ann, ann->first_layer + 1, ann->last_layer - 1);
    }
-
+
+   fann_update_weights_irpropm(ann, 0, ann->total_connections);
 
    return fann_get_MSE(ann);
 }
@@ -200,7 +155,7 @@ float fann_train_epoch_quickprop(struct fann *ann, struct fann_train_data *data)
 /*
  * Internal train function
  */
-float
+float fann_train_epoch_sarprop(struct fann *ann, struct fann_train_data *data)
 {
    unsigned int i;
 
@@ -219,7 +174,9 @@ float fann_train_epoch_irpropm(struct fann *ann, struct fann_train_data *data)
      fann_update_slopes_batch(ann, ann->first_layer + 1, ann->last_layer - 1);
    }
 
-
+   fann_update_weights_sarprop(ann, ann->sarprop_epoch, 0, ann->total_connections);
+
+   ++(ann->sarprop_epoch);
 
    return fann_get_MSE(ann);
 }
@@ -268,12 +225,17 @@ float fann_train_epoch_incremental(struct fann *ann, struct fann_train_data *dat
 */
 FANN_EXTERNAL float FANN_API fann_train_epoch(struct fann *ann, struct fann_train_data *data)
 {
+   if(fann_check_input_output_sizes(ann, data) == -1)
+      return 0;
+
    switch (ann->training_algorithm)
    {
    case FANN_TRAIN_QUICKPROP:
      return fann_train_epoch_quickprop(ann, data);
    case FANN_TRAIN_RPROP:
      return fann_train_epoch_irpropm(ann, data);
+   case FANN_TRAIN_SARPROP:
+     return fann_train_epoch_sarprop(ann, data);
    case FANN_TRAIN_BATCH:
      return fann_train_epoch_batch(ann, data);
    case FANN_TRAIN_INCREMENTAL:
@@ -795,15 +757,13 @@ int fann_save_train_internal_fd(struct fann_train_data *data, FILE * file, const
    return retval;
 }
 
-
 /*
- *
+ * Creates an empty set of training data
 */
-struct fann_train_data *
+FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train(unsigned int num_data, unsigned int num_input, unsigned int num_output)
 {
-   unsigned int num_input, num_output, num_data, i, j;
-   unsigned int line = 1;
    fann_type *data_input, *data_output;
+   unsigned int i;
    struct fann_train_data *data =
      (struct fann_train_data *) malloc(sizeof(struct fann_train_data));
 
@@ -812,15 +772,7 @@ struct fann_train_data *fann_read_train_from_fd(FILE * file, const char *filenam
      fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
      return NULL;
    }
-
-   if(fscanf(file, "%u %u %u\n", &num_data, &num_input, &num_output) != 3)
-   {
-      fann_error(NULL, FANN_E_CANT_READ_TD, filename, line);
-      fann_destroy_train(data);
-      return NULL;
-   }
-   line++;
-
+
    fann_init_error_data((struct fann_error *) data);
 
    data->num_data = num_data;
@@ -862,7 +814,63 @@ struct fann_train_data *fann_read_train_from_fd(FILE * file, const char *filenam
    {
      data->input[i] = data_input;
      data_input += num_input;
+     data->output[i] = data_output;
+     data_output += num_output;
+   }
+   return data;
+}
+
+/*
+ * Creates training data from a callback function.
+ */
+FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train_from_callback(unsigned int num_data,
+                          unsigned int num_input,
+                          unsigned int num_output,
+                          void (FANN_API *user_function)( unsigned int,
+                                 unsigned int,
+                                 unsigned int,
+                                 fann_type * ,
+                                 fann_type * ))
+{
+   unsigned int i;
+   struct fann_train_data *data = fann_create_train(num_data, num_input, num_output);
+   if(data == NULL)
+   {
+      return NULL;
+   }
+
+   for( i = 0; i != num_data; i++)
+   {
+      (*user_function)(i, num_input, num_output, data->input[i], data->output[i]);
+   }
+
+   return data;
+}
 
+/*
+ * INTERNAL FUNCTION Reads training data from a file descriptor.
+ */
+struct fann_train_data *fann_read_train_from_fd(FILE * file, const char *filename)
+{
+   unsigned int num_input, num_output, num_data, i, j;
+   unsigned int line = 1;
+   struct fann_train_data *data;
+
+   if(fscanf(file, "%u %u %u\n", &num_data, &num_input, &num_output) != 3)
+   {
+      fann_error(NULL, FANN_E_CANT_READ_TD, filename, line);
+      return NULL;
+   }
+   line++;
+
+   data = fann_create_train(num_data, num_input, num_output);
+   if(data == NULL)
+   {
+      return NULL;
+   }
+
+   for(i = 0; i != num_data; i++)
+   {
      for(j = 0; j != num_input; j++)
      {
        if(fscanf(file, FANNSCANF " ", &data->input[i][j]) != 1)
@@ -874,9 +882,6 @@ struct fann_train_data *fann_read_train_from_fd(FILE * file, const char *filenam
      }
      line++;
 
-     data->output[i] = data_output;
-     data_output += num_output;
-
      for(j = 0; j != num_output; j++)
      {
        if(fscanf(file, FANNSCANF " ", &data->output[i][j]) != 1)
@@ -928,7 +933,7 @@ FANN_EXTERNAL void FANN_API fann_scale_input( struct fann *ann, fann_type *input
      (
        ( input_vector[ cur_neuron ] - ann->scale_mean_in[ cur_neuron ] )
        / ann->scale_deviation_in[ cur_neuron ]
-       - ( -1.0 ) /* This is old_min */
+       - ( (fann_type)-1.0 ) /* This is old_min */
      )
      * ann->scale_factor_in[ cur_neuron ]
      + ann->scale_new_min_in[ cur_neuron ];
@@ -951,7 +956,7 @@ FANN_EXTERNAL void FANN_API fann_scale_output( struct fann *ann, fann_type *outp
      (
        ( output_vector[ cur_neuron ] - ann->scale_mean_out[ cur_neuron ] )
        / ann->scale_deviation_out[ cur_neuron ]
-       - ( -1.0 ) /* This is old_min */
+       - ( (fann_type)-1.0 ) /* This is old_min */
      )
      * ann->scale_factor_out[ cur_neuron ]
      + ann->scale_new_min_out[ cur_neuron ];
@@ -977,7 +982,7 @@ FANN_EXTERNAL void FANN_API fann_descale_input( struct fann *ann, fann_type *inp
        - ann->scale_new_min_in[ cur_neuron ]
      )
      / ann->scale_factor_in[ cur_neuron ]
-     + ( -1.0 ) /* This is old_min */
+     + ( (fann_type)-1.0 ) /* This is old_min */
      )
      * ann->scale_deviation_in[ cur_neuron ]
      + ann->scale_mean_in[ cur_neuron ];
@@ -1003,7 +1008,7 @@ FANN_EXTERNAL void FANN_API fann_descale_output( struct fann *ann, fann_type *ou
        - ann->scale_new_min_out[ cur_neuron ]
      )
      / ann->scale_factor_out[ cur_neuron ]
-     + ( -1.0 ) /* This is old_min */
+     + ( (fann_type)-1.0 ) /* This is old_min */
      )
      * ann->scale_deviation_out[ cur_neuron ]
      + ann->scale_mean_out[ cur_neuron ];
@@ -1021,14 +1026,8 @@ FANN_EXTERNAL void FANN_API fann_scale_train( struct fann *ann, struct fann_trai
      return;
    }
    /* Check that we have good training data. */
-
-   if( data->num_input != ann->num_input
-      || data->num_output != ann->num_output
-      )
-   {
-      fann_error( (struct fann_error *) ann, FANN_E_TRAIN_DATA_MISMATCH );
+   if(fann_check_input_output_sizes(ann, data) == -1)
      return;
-   }
 
    for( cur_sample = 0; cur_sample < data->num_data; cur_sample++ )
    {
@@ -1049,14 +1048,8 @@ FANN_EXTERNAL void FANN_API fann_descale_train( struct fann *ann, struct fann_tr
      return;
    }
    /* Check that we have good training data. */
-
-   if( data->num_input != ann->num_input
-      || data->num_output != ann->num_output
-      )
-   {
-      fann_error( (struct fann_error *) ann, FANN_E_TRAIN_DATA_MISMATCH );
+   if(fann_check_input_output_sizes(ann, data) == -1)
      return;
-   }
 
    for( cur_sample = 0; cur_sample < data->num_data; cur_sample++ )
    {
@@ -1072,38 +1065,38 @@ FANN_EXTERNAL void FANN_API fann_descale_train( struct fann *ann, struct fann_tr
 #define SCALE_SET_PARAM( where ) \
    /* Calculate mean: sum(x)/length */ \
    for( cur_neuron = 0; cur_neuron < ann->num_##where##put; cur_neuron++ ) \
-      ann->scale_mean_##where[ cur_neuron ] = 0.
+      ann->scale_mean_##where[ cur_neuron ] = 0.0f; \
    for( cur_neuron = 0; cur_neuron < ann->num_##where##put; cur_neuron++ ) \
      for( cur_sample = 0; cur_sample < data->num_data; cur_sample++ ) \
-        ann->scale_mean_##where[ cur_neuron ] += data->where##put[ cur_sample ][ cur_neuron ]
+        ann->scale_mean_##where[ cur_neuron ] += (float)data->where##put[ cur_sample ][ cur_neuron ];\
    for( cur_neuron = 0; cur_neuron < ann->num_##where##put; cur_neuron++ ) \
      ann->scale_mean_##where[ cur_neuron ] /= (float)data->num_data; \
    /* Calculate deviation: sqrt(sum((x-mean)^2)/length) */ \
    for( cur_neuron = 0; cur_neuron < ann->num_##where##put; cur_neuron++ ) \
-      ann->scale_deviation_##where[ cur_neuron ] = 0.
+      ann->scale_deviation_##where[ cur_neuron ] = 0.0f; \
    for( cur_neuron = 0; cur_neuron < ann->num_##where##put; cur_neuron++ ) \
      for( cur_sample = 0; cur_sample < data->num_data; cur_sample++ ) \
        ann->scale_deviation_##where[ cur_neuron ] += \
          /* Another local variable in macro? Oh no! */ \
          ( \
-            data->where##put[ cur_sample ][ cur_neuron ]
+            (float)data->where##put[ cur_sample ][ cur_neuron ] \
            - ann->scale_mean_##where[ cur_neuron ] \
          ) \
          * \
          ( \
-            data->where##put[ cur_sample ][ cur_neuron ]
+            (float)data->where##put[ cur_sample ][ cur_neuron ] \
            - ann->scale_mean_##where[ cur_neuron ] \
          ); \
    for( cur_neuron = 0; cur_neuron < ann->num_##where##put; cur_neuron++ ) \
      ann->scale_deviation_##where[ cur_neuron ] = \
-        sqrt( ann->scale_deviation_##where[ cur_neuron ] / (float)data->num_data );
+        (float)sqrt( ann->scale_deviation_##where[ cur_neuron ] / (float)data->num_data ); \
    /* Calculate factor: (new_max-new_min)/(old_max(1)-old_min(-1)) */ \
    /* Looks like we dont need whole array of factors? */ \
    for( cur_neuron = 0; cur_neuron < ann->num_##where##put; cur_neuron++ ) \
      ann->scale_factor_##where[ cur_neuron ] = \
        ( new_##where##put_max - new_##where##put_min ) \
        / \
-       ( 1.
+       ( 1.0f - ( -1.0f ) ); \
    /* Copy new minimum. */ \
    /* Looks like we dont need whole array of new minimums? */ \
    for( cur_neuron = 0; cur_neuron < ann->num_##where##put; cur_neuron++ ) \
@@ -1229,3 +1222,22 @@ FANN_EXTERNAL int FANN_API fann_clear_scaling_params(struct fann *ann)
 }
 
 #endif
+
+int fann_check_input_output_sizes(struct fann *ann, struct fann_train_data *data)
+{
+   if(ann->num_input != data->num_input)
+   {
+      fann_error((struct fann_error *) ann, FANN_E_INPUT_NO_MATCH,
+            ann->num_input, data->num_input);
+      return -1;
+   }
+
+   if(ann->num_output != data->num_output)
+   {
+      fann_error((struct fann_error *) ann, FANN_E_OUTPUT_NO_MATCH,
+            ann->num_output, data->num_output);
+      return -1;
+   }
+
+   return 0;
+}