ruby-fann 1.2.5 → 1.2.6

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  /*
2
2
  Fast Artificial Neural Network Library (fann)
3
- Copyright (C) 2003 Steffen Nissen (lukesky@diku.dk)
3
+ Copyright (C) 2003-2012 Steffen Nissen (sn@leenissen.dk)
4
4
 
5
5
  This library is free software; you can redistribute it and/or
6
6
  modify it under the terms of the GNU Lesser General Public
@@ -1,6 +1,6 @@
1
1
  /*
2
2
  Fast Artificial Neural Network Library (fann)
3
- Copyright (C) 2003 Steffen Nissen (lukesky@diku.dk)
3
+ Copyright (C) 2003-2012 Steffen Nissen (sn@leenissen.dk)
4
4
 
5
5
  This library is free software; you can redistribute it and/or
6
6
  modify it under the terms of the GNU Lesser General Public
@@ -21,6 +21,7 @@
21
21
  #include <stdlib.h>
22
22
  #include <stdarg.h>
23
23
  #include <string.h>
24
+ #include <math.h>
24
25
 
25
26
  #include "config.h"
26
27
  #include "fann.h"
@@ -206,6 +207,7 @@ FANN_EXTERNAL unsigned int FANN_API fann_get_bit_fail(struct fann *ann)
206
207
  */
207
208
  FANN_EXTERNAL void FANN_API fann_reset_MSE(struct fann *ann)
208
209
  {
210
+ /*printf("resetMSE %d %f\n", ann->num_MSE, ann->MSE_value);*/
209
211
  ann->num_MSE = 0;
210
212
  ann->MSE_value = 0;
211
213
  ann->num_bit_fail = 0;
@@ -764,6 +766,85 @@ void fann_update_weights_irpropm(struct fann *ann, unsigned int first_weight, un
764
766
  }
765
767
  }
766
768
 
769
+ /* INTERNAL FUNCTION
770
+ The SARprop- algorithm: Rprop-style step adaptation on a slope that includes a weight decay
771
+ term, plus random noise added to shrinking steps; both extra terms are scaled by
772
+ 2^(-T * epoch + shift). Updates weights [first_weight, past_end) and zeroes their train_slopes. */
773
+ void fann_update_weights_sarprop(struct fann *ann, unsigned int epoch, unsigned int first_weight, unsigned int past_end)
774
+ {
775
+ fann_type *train_slopes = ann->train_slopes;
776
+ fann_type *weights = ann->weights;
777
+ fann_type *prev_steps = ann->prev_steps;
778
+ fann_type *prev_train_slopes = ann->prev_train_slopes;
779
+
780
+ fann_type prev_step, slope, prev_slope, next_step, same_sign;
781
+
782
+ /* Step and annealing parameters (delta_min is still hard-coded) */
783
+ float increase_factor = ann->rprop_increase_factor; /*1.2; */
784
+ float decrease_factor = ann->rprop_decrease_factor; /*0.5; */
785
+ /* TODO: why is delta_min 0.0 in iRprop? SARPROP uses 1x10^-6 (Braun and Riedmiller, 1993) */
786
+ float delta_min = 0.000001f;
787
+ float delta_max = ann->rprop_delta_max; /*50.0; */
788
+ float weight_decay_shift = ann->sarprop_weight_decay_shift; /* ld 0.01 = -6.644 */
789
+ float step_error_threshold_factor = ann->sarprop_step_error_threshold_factor; /* 0.1 */
790
+ float step_error_shift = ann->sarprop_step_error_shift; /* ld 3 = 1.585 */
791
+ float T = ann->sarprop_temperature;
792
+ float MSE = fann_get_MSE(ann);
793
+ float RMSE = (float)sqrt(MSE);
794
+
795
+ unsigned int i = first_weight;
796
+
797
+ /* for all weights; TODO: are biases included? */
798
+ for(; i != past_end; i++)
799
+ {
800
+ /* TODO: confirm whether 1x10^-6 == delta_min is really better */
801
+ prev_step = fann_max(prev_steps[i], (fann_type) 0.000001); /* prev_step may not be zero because then the training will stop */
802
+ /* calculate SARPROP slope; TODO: better as new error function? (see SARPROP paper)*/
803
+ slope = -train_slopes[i] - weights[i] * (fann_type)fann_exp2(-T * epoch + weight_decay_shift);
804
+
805
+ /* TODO: is prev_train_slopes[i] 0.0 in the beginning? */
806
+ prev_slope = prev_train_slopes[i];
807
+
808
+ same_sign = prev_slope * slope;
809
+
810
+ if(same_sign > 0.0)
811
+ {
812
+ next_step = fann_min(prev_step * increase_factor, delta_max);
813
+ /* TODO: are the signs inverted? see differences between SARPROP paper and iRprop */
814
+ if (slope < 0.0)
815
+ weights[i] += next_step;
816
+ else
817
+ weights[i] -= next_step;
818
+ }
819
+ else if(same_sign < 0.0)
820
+ {
821
+ /* Sign change: shrink the step, adding noise while it is small relative to the MSE */
822
+ if(prev_step < step_error_threshold_factor * MSE)
823
+ next_step = prev_step * decrease_factor + (float)rand() / RAND_MAX * RMSE * (fann_type)fann_exp2(-T * epoch + step_error_shift);
824
+ else
825
+ next_step = fann_max(prev_step * decrease_factor, delta_min);
826
+
827
+ /* storing a zero slope sends the next epoch into the zero-product branch below */
828
+ slope = 0.0;
829
+ }
830
+ else
831
+ {
832
+ /* Zero product: keep this weight's own step. next_step must be set here, or
833
+ prev_steps[i] would receive the step computed for the previous weight. */
834
+ next_step = prev_step;
835
+ if(slope < 0.0)
836
+ weights[i] += next_step;
837
+ else
838
+ weights[i] -= next_step;
839
+ }
840
+
841
+ /* update global data arrays */
842
+ prev_steps[i] = next_step;
843
+ prev_train_slopes[i] = slope;
844
+ train_slopes[i] = 0.0;
845
+ }
846
+ }
847
+
767
848
  #endif
768
849
 
769
850
  FANN_GET_SET(enum fann_train_enum, training_algorithm)
@@ -957,6 +1038,10 @@ FANN_GET_SET(float, rprop_decrease_factor)
957
1038
  FANN_GET_SET(float, rprop_delta_min)
958
1039
  FANN_GET_SET(float, rprop_delta_max)
959
1040
  FANN_GET_SET(float, rprop_delta_zero)
1041
+ FANN_GET_SET(float, sarprop_weight_decay_shift)
1042
+ FANN_GET_SET(float, sarprop_step_error_threshold_factor)
1043
+ FANN_GET_SET(float, sarprop_step_error_shift)
1044
+ FANN_GET_SET(float, sarprop_temperature)
960
1045
  FANN_GET_SET(enum fann_stopfunc_enum, train_stop_function)
961
1046
  FANN_GET_SET(fann_type, bit_fail_limit)
962
1047
  FANN_GET_SET(float, learning_momentum)
@@ -1,6 +1,6 @@
1
1
  /*
2
2
  Fast Artificial Neural Network Library (fann)
3
- Copyright (C) 2003 Steffen Nissen (lukesky@diku.dk)
3
+ Copyright (C) 2003-2012 Steffen Nissen (sn@leenissen.dk)
4
4
 
5
5
  This library is free software; you can redistribute it and/or
6
6
  modify it under the terms of the GNU Lesser General Public
@@ -252,6 +252,17 @@ FANN_EXTERNAL float FANN_API fann_test_data(struct fann *ann, struct fann_train_
252
252
  FANN_EXTERNAL struct fann_train_data *FANN_API fann_read_train_from_file(const char *filename);
253
253
 
254
254
 
255
+ /* Function: fann_create_train
256
+ Creates an empty training data struct.
257
+
258
+ See also:
259
+ <fann_read_train_from_file>, <fann_train_on_data>, <fann_destroy_train>,
260
+ <fann_save_train>
261
+
262
+ This function appears in FANN >= 2.2.0
263
+ */
264
+ FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train(unsigned int num_data, unsigned int num_input, unsigned int num_output);
265
+
255
266
  /* Function: fann_create_train_from_callback
256
267
  Creates the training data struct from a user supplied function.
257
268
  As the training data are numerable (data 1, data 2...), the user must write
@@ -1200,4 +1211,100 @@ FANN_EXTERNAL float FANN_API fann_get_rprop_delta_zero(struct fann *ann);
1200
1211
  */
1201
1212
  FANN_EXTERNAL void FANN_API fann_set_rprop_delta_zero(struct fann *ann, float rprop_delta_max);
1202
1213
 
1214
+ /* Method: fann_get_sarprop_weight_decay_shift
1215
+
1216
+ The sarprop weight decay shift.
1217
+
1218
+ The default sarprop weight decay shift is -6.644.
1219
+
1220
+ See also:
1221
+ <fann_set_sarprop_weight_decay_shift>
1222
+
1223
+ This function appears in FANN >= 2.1.0.
1224
+ */
1225
+ FANN_EXTERNAL float FANN_API fann_get_sarprop_weight_decay_shift(struct fann *ann);
1226
+
1227
+ /* Method: fann_set_sarprop_weight_decay_shift
1228
+
1229
+ Set the sarprop weight decay shift.
1230
+
1231
+ This function appears in FANN >= 2.1.0.
1232
+
1233
+ See also:
1234
+ <fann_get_sarprop_weight_decay_shift>
1235
+ */
1236
+ FANN_EXTERNAL void FANN_API fann_set_sarprop_weight_decay_shift(struct fann *ann, float sarprop_weight_decay_shift);
1237
+
1238
+ /* Method: fann_get_sarprop_step_error_threshold_factor
1239
+
1240
+ The sarprop step error threshold factor.
1241
+
1242
+ The default sarprop step error threshold factor is 0.1.
1243
+
1244
+ See also:
1245
+ <fann_set_sarprop_step_error_threshold_factor>
1246
+
1247
+ This function appears in FANN >= 2.1.0.
1248
+ */
1249
+ FANN_EXTERNAL float FANN_API fann_get_sarprop_step_error_threshold_factor(struct fann *ann);
1250
+
1251
+ /* Method: fann_set_sarprop_step_error_threshold_factor
1252
+
1253
+ Set the sarprop step error threshold factor.
1254
+
1255
+ This function appears in FANN >= 2.1.0.
1256
+
1257
+ See also:
1258
+ <fann_get_sarprop_step_error_threshold_factor>
1259
+ */
1260
+ FANN_EXTERNAL void FANN_API fann_set_sarprop_step_error_threshold_factor(struct fann *ann, float sarprop_step_error_threshold_factor);
1261
+
1262
+ /* Method: fann_get_sarprop_step_error_shift
1263
+
1264
+ The sarprop step error shift.
1265
+
1266
+ The default sarprop step error shift is 1.385.
1267
+
1268
+ See also:
1269
+ <fann_set_sarprop_step_error_shift>
1270
+
1271
+ This function appears in FANN >= 2.1.0.
1272
+ */
1273
+ FANN_EXTERNAL float FANN_API fann_get_sarprop_step_error_shift(struct fann *ann);
1274
+
1275
+ /* Method: fann_set_sarprop_step_error_shift
1276
+
1277
+ Set the sarprop step error shift.
1278
+
1279
+ This function appears in FANN >= 2.1.0.
1280
+
1281
+ See also:
1282
+ <fann_get_sarprop_step_error_shift>
1283
+ */
1284
+ FANN_EXTERNAL void FANN_API fann_set_sarprop_step_error_shift(struct fann *ann, float sarprop_step_error_shift);
1285
+
1286
+ /* Method: fann_get_sarprop_temperature
1287
+
1288
+ The sarprop temperature.
1289
+
1290
+ The default sarprop temperature is 0.015.
1291
+
1292
+ See also:
1293
+ <fann_set_sarprop_temperature>
1294
+
1295
+ This function appears in FANN >= 2.1.0.
1296
+ */
1297
+ FANN_EXTERNAL float FANN_API fann_get_sarprop_temperature(struct fann *ann);
1298
+
1299
+ /* Method: fann_set_sarprop_temperature
1300
+
1301
+ Set the sarprop_temperature.
1302
+
1303
+ This function appears in FANN >= 2.1.0.
1304
+
1305
+ See also:
1306
+ <fann_get_sarprop_temperature>
1307
+ */
1308
+ FANN_EXTERNAL void FANN_API fann_set_sarprop_temperature(struct fann *ann, float sarprop_temperature);
1309
+
1203
1310
  #endif
@@ -1,21 +1,21 @@
1
1
  /*
2
- * Fast Artificial Neural Network Library (fann) Copyright (C) 2003
3
- * Steffen Nissen (lukesky@diku.dk)
4
- *
5
- * This library is free software; you can redistribute it and/or modify it
6
- * under the terms of the GNU Lesser General Public License as published
7
- * by the Free Software Foundation; either version 2.1 of the License, or
8
- * (at your option) any later version.
9
- *
10
- * This library is distributed in the hope that it will be useful, but
11
- * WITHOUT ANY WARRANTY; without even the implied warranty of
12
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
13
- * Lesser General Public License for more details.
14
- *
15
- * You should have received a copy of the GNU Lesser General Public
16
- * License along with this library; if not, write to the Free Software
17
- * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
18
- */
2
+ Fast Artificial Neural Network Library (fann)
3
+ Copyright (C) 2003-2012 Steffen Nissen (sn@leenissen.dk)
4
+
5
+ This library is free software; you can redistribute it and/or
6
+ modify it under the terms of the GNU Lesser General Public
7
+ License as published by the Free Software Foundation; either
8
+ version 2.1 of the License, or (at your option) any later version.
9
+
10
+ This library is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
13
+ Lesser General Public License for more details.
14
+
15
+ You should have received a copy of the GNU Lesser General Public
16
+ License along with this library; if not, write to the Free Software
17
+ Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
18
+ */
19
19
 
20
20
  #include <stdio.h>
21
21
  #include <stdlib.h>
@@ -84,7 +84,9 @@ FANN_EXTERNAL void FANN_API fann_destroy_train(struct fann_train_data *data)
84
84
  FANN_EXTERNAL float FANN_API fann_test_data(struct fann *ann, struct fann_train_data *data)
85
85
  {
86
86
  unsigned int i;
87
-
87
+ if(fann_check_input_output_sizes(ann, data) == -1)
88
+ return 0;
89
+
88
90
  fann_reset_MSE(ann);
89
91
 
90
92
  for(i = 0; i != data->num_data; i++)
@@ -95,86 +97,38 @@ FANN_EXTERNAL float FANN_API fann_test_data(struct fann *ann, struct fann_train_
95
97
  return fann_get_MSE(ann);
96
98
  }
97
99
 
100
+ #ifndef FIXEDFANN
101
+
98
102
  /*
99
- * Creates training data from a callback function.
103
+ * Internal train function
100
104
  */
101
- FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train_from_callback(unsigned int num_data,
102
- unsigned int num_input,
103
- unsigned int num_output,
104
- void (FANN_API *user_function)( unsigned int,
105
- unsigned int,
106
- unsigned int,
107
- fann_type * ,
108
- fann_type * ))
105
+ float fann_train_epoch_quickprop(struct fann *ann, struct fann_train_data *data)
109
106
  {
110
- unsigned int i;
111
- fann_type *data_input, *data_output;
112
- struct fann_train_data *data = (struct fann_train_data *)
113
- malloc(sizeof(struct fann_train_data));
114
-
115
- if(data == NULL){
116
- fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
117
- return NULL;
118
- }
119
-
120
- fann_init_error_data((struct fann_error *) data);
121
-
122
- data->num_data = num_data;
123
- data->num_input = num_input;
124
- data->num_output = num_output;
125
-
126
- data->input = (fann_type **) calloc(num_data, sizeof(fann_type *));
127
- if(data->input == NULL)
128
- {
129
- fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
130
- fann_destroy_train(data);
131
- return NULL;
132
- }
133
-
134
- data->output = (fann_type **) calloc(num_data, sizeof(fann_type *));
135
- if(data->output == NULL)
136
- {
137
- fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
138
- fann_destroy_train(data);
139
- return NULL;
140
- }
141
-
142
- data_input = (fann_type *) calloc(num_input * num_data, sizeof(fann_type));
143
- if(data_input == NULL)
144
- {
145
- fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
146
- fann_destroy_train(data);
147
- return NULL;
148
- }
149
-
150
- data_output = (fann_type *) calloc(num_output * num_data, sizeof(fann_type));
151
- if(data_output == NULL)
152
- {
153
- fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
154
- fann_destroy_train(data);
155
- return NULL;
156
- }
157
-
158
- for( i = 0; i != num_data; i++)
159
- {
160
- data->input[i] = data_input;
161
- data_input += num_input;
107
+ unsigned int i;
162
108
 
163
- data->output[i] = data_output;
164
- data_output += num_output;
109
+ if(ann->prev_train_slopes == NULL)
110
+ {
111
+ fann_clear_train_arrays(ann);
112
+ }
165
113
 
166
- (*user_function)(i, num_input, num_output, data->input[i],data->output[i] );
167
- }
114
+ fann_reset_MSE(ann);
168
115
 
169
- return data;
170
- }
116
+ for(i = 0; i < data->num_data; i++)
117
+ {
118
+ fann_run(ann, data->input[i]);
119
+ fann_compute_MSE(ann, data->output[i]);
120
+ fann_backpropagate_MSE(ann);
121
+ fann_update_slopes_batch(ann, ann->first_layer + 1, ann->last_layer - 1);
122
+ }
123
+ fann_update_weights_quickprop(ann, data->num_data, 0, ann->total_connections);
171
124
 
172
- #ifndef FIXEDFANN
125
+ return fann_get_MSE(ann);
126
+ }
173
127
 
174
128
  /*
175
129
  * Internal train function
176
130
  */
177
- float fann_train_epoch_quickprop(struct fann *ann, struct fann_train_data *data)
131
+ float fann_train_epoch_irpropm(struct fann *ann, struct fann_train_data *data)
178
132
  {
179
133
  unsigned int i;
180
134
 
@@ -192,7 +146,8 @@ float fann_train_epoch_quickprop(struct fann *ann, struct fann_train_data *data)
192
146
  fann_backpropagate_MSE(ann);
193
147
  fann_update_slopes_batch(ann, ann->first_layer + 1, ann->last_layer - 1);
194
148
  }
195
- fann_update_weights_quickprop(ann, data->num_data, 0, ann->total_connections);
149
+
150
+ fann_update_weights_irpropm(ann, 0, ann->total_connections);
196
151
 
197
152
  return fann_get_MSE(ann);
198
153
  }
@@ -200,7 +155,7 @@ float fann_train_epoch_quickprop(struct fann *ann, struct fann_train_data *data)
200
155
  /*
201
156
  * Internal train function
202
157
  */
203
- float fann_train_epoch_irpropm(struct fann *ann, struct fann_train_data *data)
158
+ float fann_train_epoch_sarprop(struct fann *ann, struct fann_train_data *data)
204
159
  {
205
160
  unsigned int i;
206
161
 
@@ -219,7 +174,9 @@ float fann_train_epoch_irpropm(struct fann *ann, struct fann_train_data *data)
219
174
  fann_update_slopes_batch(ann, ann->first_layer + 1, ann->last_layer - 1);
220
175
  }
221
176
 
222
- fann_update_weights_irpropm(ann, 0, ann->total_connections);
177
+ fann_update_weights_sarprop(ann, ann->sarprop_epoch, 0, ann->total_connections);
178
+
179
+ ++(ann->sarprop_epoch);
223
180
 
224
181
  return fann_get_MSE(ann);
225
182
  }
@@ -268,12 +225,17 @@ float fann_train_epoch_incremental(struct fann *ann, struct fann_train_data *dat
268
225
  */
269
226
  FANN_EXTERNAL float FANN_API fann_train_epoch(struct fann *ann, struct fann_train_data *data)
270
227
  {
228
+ if(fann_check_input_output_sizes(ann, data) == -1)
229
+ return 0;
230
+
271
231
  switch (ann->training_algorithm)
272
232
  {
273
233
  case FANN_TRAIN_QUICKPROP:
274
234
  return fann_train_epoch_quickprop(ann, data);
275
235
  case FANN_TRAIN_RPROP:
276
236
  return fann_train_epoch_irpropm(ann, data);
237
+ case FANN_TRAIN_SARPROP:
238
+ return fann_train_epoch_sarprop(ann, data);
277
239
  case FANN_TRAIN_BATCH:
278
240
  return fann_train_epoch_batch(ann, data);
279
241
  case FANN_TRAIN_INCREMENTAL:
@@ -795,15 +757,13 @@ int fann_save_train_internal_fd(struct fann_train_data *data, FILE * file, const
795
757
  return retval;
796
758
  }
797
759
 
798
-
799
760
  /*
800
- * INTERNAL FUNCTION Reads training data from a file descriptor.
761
+ * Creates an empty set of training data
801
762
  */
802
- struct fann_train_data *fann_read_train_from_fd(FILE * file, const char *filename)
763
+ FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train(unsigned int num_data, unsigned int num_input, unsigned int num_output)
803
764
  {
804
- unsigned int num_input, num_output, num_data, i, j;
805
- unsigned int line = 1;
806
765
  fann_type *data_input, *data_output;
766
+ unsigned int i;
807
767
  struct fann_train_data *data =
808
768
  (struct fann_train_data *) malloc(sizeof(struct fann_train_data));
809
769
 
@@ -812,15 +772,7 @@ struct fann_train_data *fann_read_train_from_fd(FILE * file, const char *filenam
812
772
  fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
813
773
  return NULL;
814
774
  }
815
-
816
- if(fscanf(file, "%u %u %u\n", &num_data, &num_input, &num_output) != 3)
817
- {
818
- fann_error(NULL, FANN_E_CANT_READ_TD, filename, line);
819
- fann_destroy_train(data);
820
- return NULL;
821
- }
822
- line++;
823
-
775
+
824
776
  fann_init_error_data((struct fann_error *) data);
825
777
 
826
778
  data->num_data = num_data;
@@ -862,7 +814,63 @@ struct fann_train_data *fann_read_train_from_fd(FILE * file, const char *filenam
862
814
  {
863
815
  data->input[i] = data_input;
864
816
  data_input += num_input;
817
+ data->output[i] = data_output;
818
+ data_output += num_output;
819
+ }
820
+ return data;
821
+ }
822
+
823
+ /*
824
+ * Creates training data from a callback function: allocates the set with fann_create_train (returning NULL if that fails), then calls user_function(i, num_input, num_output, data->input[i], data->output[i]) once for each sample i in [0, num_data).
825
+ */
826
+ FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train_from_callback(unsigned int num_data,
827
+ unsigned int num_input,
828
+ unsigned int num_output,
829
+ void (FANN_API *user_function)( unsigned int,
830
+ unsigned int,
831
+ unsigned int,
832
+ fann_type * ,
833
+ fann_type * ))
834
+ {
835
+ unsigned int i;
836
+ struct fann_train_data *data = fann_create_train(num_data, num_input, num_output);
837
+ if(data == NULL)
838
+ {
839
+ return NULL;
840
+ }
841
+
842
+ for( i = 0; i != num_data; i++)
843
+ {
844
+ (*user_function)(i, num_input, num_output, data->input[i], data->output[i]);
845
+ }
846
+
847
+ return data;
848
+ }
865
849
 
850
+ /*
851
+ * INTERNAL FUNCTION Reads training data from a file descriptor.
852
+ */
853
+ struct fann_train_data *fann_read_train_from_fd(FILE * file, const char *filename)
854
+ {
855
+ unsigned int num_input, num_output, num_data, i, j;
856
+ unsigned int line = 1;
857
+ struct fann_train_data *data;
858
+
859
+ if(fscanf(file, "%u %u %u\n", &num_data, &num_input, &num_output) != 3)
860
+ {
861
+ fann_error(NULL, FANN_E_CANT_READ_TD, filename, line);
862
+ return NULL;
863
+ }
864
+ line++;
865
+
866
+ data = fann_create_train(num_data, num_input, num_output);
867
+ if(data == NULL)
868
+ {
869
+ return NULL;
870
+ }
871
+
872
+ for(i = 0; i != num_data; i++)
873
+ {
866
874
  for(j = 0; j != num_input; j++)
867
875
  {
868
876
  if(fscanf(file, FANNSCANF " ", &data->input[i][j]) != 1)
@@ -874,9 +882,6 @@ struct fann_train_data *fann_read_train_from_fd(FILE * file, const char *filenam
874
882
  }
875
883
  line++;
876
884
 
877
- data->output[i] = data_output;
878
- data_output += num_output;
879
-
880
885
  for(j = 0; j != num_output; j++)
881
886
  {
882
887
  if(fscanf(file, FANNSCANF " ", &data->output[i][j]) != 1)
@@ -928,7 +933,7 @@ FANN_EXTERNAL void FANN_API fann_scale_input( struct fann *ann, fann_type *input
928
933
  (
929
934
  ( input_vector[ cur_neuron ] - ann->scale_mean_in[ cur_neuron ] )
930
935
  / ann->scale_deviation_in[ cur_neuron ]
931
- - ( -1.0 ) /* This is old_min */
936
+ - ( (fann_type)-1.0 ) /* This is old_min */
932
937
  )
933
938
  * ann->scale_factor_in[ cur_neuron ]
934
939
  + ann->scale_new_min_in[ cur_neuron ];
@@ -951,7 +956,7 @@ FANN_EXTERNAL void FANN_API fann_scale_output( struct fann *ann, fann_type *outp
951
956
  (
952
957
  ( output_vector[ cur_neuron ] - ann->scale_mean_out[ cur_neuron ] )
953
958
  / ann->scale_deviation_out[ cur_neuron ]
954
- - ( -1.0 ) /* This is old_min */
959
+ - ( (fann_type)-1.0 ) /* This is old_min */
955
960
  )
956
961
  * ann->scale_factor_out[ cur_neuron ]
957
962
  + ann->scale_new_min_out[ cur_neuron ];
@@ -977,7 +982,7 @@ FANN_EXTERNAL void FANN_API fann_descale_input( struct fann *ann, fann_type *inp
977
982
  - ann->scale_new_min_in[ cur_neuron ]
978
983
  )
979
984
  / ann->scale_factor_in[ cur_neuron ]
980
- + ( -1.0 ) /* This is old_min */
985
+ + ( (fann_type)-1.0 ) /* This is old_min */
981
986
  )
982
987
  * ann->scale_deviation_in[ cur_neuron ]
983
988
  + ann->scale_mean_in[ cur_neuron ];
@@ -1003,7 +1008,7 @@ FANN_EXTERNAL void FANN_API fann_descale_output( struct fann *ann, fann_type *ou
1003
1008
  - ann->scale_new_min_out[ cur_neuron ]
1004
1009
  )
1005
1010
  / ann->scale_factor_out[ cur_neuron ]
1006
- + ( -1.0 ) /* This is old_min */
1011
+ + ( (fann_type)-1.0 ) /* This is old_min */
1007
1012
  )
1008
1013
  * ann->scale_deviation_out[ cur_neuron ]
1009
1014
  + ann->scale_mean_out[ cur_neuron ];
@@ -1021,14 +1026,8 @@ FANN_EXTERNAL void FANN_API fann_scale_train( struct fann *ann, struct fann_trai
1021
1026
  return;
1022
1027
  }
1023
1028
  /* Check that we have good training data. */
1024
- /* No need for if( !params || !ann ) */
1025
- if( data->num_input != ann->num_input
1026
- || data->num_output != ann->num_output
1027
- )
1028
- {
1029
- fann_error( (struct fann_error *) ann, FANN_E_TRAIN_DATA_MISMATCH );
1029
+ if(fann_check_input_output_sizes(ann, data) == -1)
1030
1030
  return;
1031
- }
1032
1031
 
1033
1032
  for( cur_sample = 0; cur_sample < data->num_data; cur_sample++ )
1034
1033
  {
@@ -1049,14 +1048,8 @@ FANN_EXTERNAL void FANN_API fann_descale_train( struct fann *ann, struct fann_tr
1049
1048
  return;
1050
1049
  }
1051
1050
  /* Check that we have good training data. */
1052
- /* No need for if( !params || !ann ) */
1053
- if( data->num_input != ann->num_input
1054
- || data->num_output != ann->num_output
1055
- )
1056
- {
1057
- fann_error( (struct fann_error *) ann, FANN_E_TRAIN_DATA_MISMATCH );
1051
+ if(fann_check_input_output_sizes(ann, data) == -1)
1058
1052
  return;
1059
- }
1060
1053
 
1061
1054
  for( cur_sample = 0; cur_sample < data->num_data; cur_sample++ )
1062
1055
  {
@@ -1072,38 +1065,38 @@ FANN_EXTERNAL void FANN_API fann_descale_train( struct fann *ann, struct fann_tr
1072
1065
  #define SCALE_SET_PARAM( where ) \
1073
1066
  /* Calculate mean: sum(x)/length */ \
1074
1067
  for( cur_neuron = 0; cur_neuron < ann->num_##where##put; cur_neuron++ ) \
1075
- ann->scale_mean_##where[ cur_neuron ] = 0.0; \
1068
+ ann->scale_mean_##where[ cur_neuron ] = 0.0f; \
1076
1069
  for( cur_neuron = 0; cur_neuron < ann->num_##where##put; cur_neuron++ ) \
1077
1070
  for( cur_sample = 0; cur_sample < data->num_data; cur_sample++ ) \
1078
- ann->scale_mean_##where[ cur_neuron ] += data->where##put[ cur_sample ][ cur_neuron ]; \
1071
+ ann->scale_mean_##where[ cur_neuron ] += (float)data->where##put[ cur_sample ][ cur_neuron ];\
1079
1072
  for( cur_neuron = 0; cur_neuron < ann->num_##where##put; cur_neuron++ ) \
1080
1073
  ann->scale_mean_##where[ cur_neuron ] /= (float)data->num_data; \
1081
1074
  /* Calculate deviation: sqrt(sum((x-mean)^2)/length) */ \
1082
1075
  for( cur_neuron = 0; cur_neuron < ann->num_##where##put; cur_neuron++ ) \
1083
- ann->scale_deviation_##where[ cur_neuron ] = 0.0; \
1076
+ ann->scale_deviation_##where[ cur_neuron ] = 0.0f; \
1084
1077
  for( cur_neuron = 0; cur_neuron < ann->num_##where##put; cur_neuron++ ) \
1085
1078
  for( cur_sample = 0; cur_sample < data->num_data; cur_sample++ ) \
1086
1079
  ann->scale_deviation_##where[ cur_neuron ] += \
1087
1080
  /* Another local variable in macro? Oh no! */ \
1088
1081
  ( \
1089
- data->where##put[ cur_sample ][ cur_neuron ] \
1082
+ (float)data->where##put[ cur_sample ][ cur_neuron ] \
1090
1083
  - ann->scale_mean_##where[ cur_neuron ] \
1091
1084
  ) \
1092
1085
  * \
1093
1086
  ( \
1094
- data->where##put[ cur_sample ][ cur_neuron ] \
1087
+ (float)data->where##put[ cur_sample ][ cur_neuron ] \
1095
1088
  - ann->scale_mean_##where[ cur_neuron ] \
1096
1089
  ); \
1097
1090
  for( cur_neuron = 0; cur_neuron < ann->num_##where##put; cur_neuron++ ) \
1098
1091
  ann->scale_deviation_##where[ cur_neuron ] = \
1099
- sqrt( ann->scale_deviation_##where[ cur_neuron ] / (float)data->num_data ); \
1092
+ (float)sqrt( ann->scale_deviation_##where[ cur_neuron ] / (float)data->num_data ); \
1100
1093
  /* Calculate factor: (new_max-new_min)/(old_max(1)-old_min(-1)) */ \
1101
1094
  /* Looks like we dont need whole array of factors? */ \
1102
1095
  for( cur_neuron = 0; cur_neuron < ann->num_##where##put; cur_neuron++ ) \
1103
1096
  ann->scale_factor_##where[ cur_neuron ] = \
1104
1097
  ( new_##where##put_max - new_##where##put_min ) \
1105
1098
  / \
1106
- ( 1.0 - ( -1.0 ) ); \
1099
+ ( 1.0f - ( -1.0f ) ); \
1107
1100
  /* Copy new minimum. */ \
1108
1101
  /* Looks like we dont need whole array of new minimums? */ \
1109
1102
  for( cur_neuron = 0; cur_neuron < ann->num_##where##put; cur_neuron++ ) \
@@ -1229,3 +1222,22 @@ FANN_EXTERNAL int FANN_API fann_clear_scaling_params(struct fann *ann)
1229
1222
  }
1230
1223
 
1231
1224
  #endif
1225
+ /* Returns 0 if data's num_input/num_output match ann's; otherwise reports FANN_E_INPUT_NO_MATCH or FANN_E_OUTPUT_NO_MATCH on ann (inputs are checked first) and returns -1. */
1226
+ int fann_check_input_output_sizes(struct fann *ann, struct fann_train_data *data)
1227
+ {
1228
+ if(ann->num_input != data->num_input)
1229
+ {
1230
+ fann_error((struct fann_error *) ann, FANN_E_INPUT_NO_MATCH,
1231
+ ann->num_input, data->num_input);
1232
+ return -1;
1233
+ }
1234
+
1235
+ if(ann->num_output != data->num_output)
1236
+ {
1237
+ fann_error((struct fann_error *) ann, FANN_E_OUTPUT_NO_MATCH,
1238
+ ann->num_output, data->num_output);
1239
+ return -1;
1240
+ }
1241
+
1242
+ return 0;
1243
+ }