ruby-fann 1.2.5 → 1.2.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  /*
2
2
  Fast Artificial Neural Network Library (fann)
3
- Copyright (C) 2003 Steffen Nissen (lukesky@diku.dk)
3
+ Copyright (C) 2003-2012 Steffen Nissen (sn@leenissen.dk)
4
4
 
5
5
  This library is free software; you can redistribute it and/or
6
6
  modify it under the terms of the GNU Lesser General Public
@@ -1,6 +1,6 @@
1
1
  /*
2
2
  Fast Artificial Neural Network Library (fann)
3
- Copyright (C) 2003 Steffen Nissen (lukesky@diku.dk)
3
+ Copyright (C) 2003-2012 Steffen Nissen (sn@leenissen.dk)
4
4
 
5
5
  This library is free software; you can redistribute it and/or
6
6
  modify it under the terms of the GNU Lesser General Public
@@ -21,6 +21,7 @@
21
21
  #include <stdlib.h>
22
22
  #include <stdarg.h>
23
23
  #include <string.h>
24
+ #include <math.h>
24
25
 
25
26
  #include "config.h"
26
27
  #include "fann.h"
@@ -206,6 +207,7 @@ FANN_EXTERNAL unsigned int FANN_API fann_get_bit_fail(struct fann *ann)
206
207
  */
207
208
  FANN_EXTERNAL void FANN_API fann_reset_MSE(struct fann *ann)
208
209
  {
210
+ /*printf("resetMSE %d %f\n", ann->num_MSE, ann->MSE_value);*/
209
211
  ann->num_MSE = 0;
210
212
  ann->MSE_value = 0;
211
213
  ann->num_bit_fail = 0;
@@ -764,6 +766,85 @@ void fann_update_weights_irpropm(struct fann *ann, unsigned int first_weight, un
764
766
  }
765
767
  }
766
768
 
769
+ /* INTERNAL FUNCTION
770
+ The SARprop- algorithm
771
+ */
772
+ void fann_update_weights_sarprop(struct fann *ann, unsigned int epoch, unsigned int first_weight, unsigned int past_end)
773
+ {
774
+ fann_type *train_slopes = ann->train_slopes;
775
+ fann_type *weights = ann->weights;
776
+ fann_type *prev_steps = ann->prev_steps;
777
+ fann_type *prev_train_slopes = ann->prev_train_slopes;
778
+
779
+ fann_type prev_step, slope, prev_slope, next_step = 0, same_sign;
780
+
781
+ /* These should be set from variables */
782
+ float increase_factor = ann->rprop_increase_factor; /*1.2; */
783
+ float decrease_factor = ann->rprop_decrease_factor; /*0.5; */
784
+ /* TODO: why is delta_min 0.0 in iRprop? SARPROP uses 1x10^-6 (Braun and Riedmiller, 1993) */
785
+ float delta_min = 0.000001f;
786
+ float delta_max = ann->rprop_delta_max; /*50.0; */
787
+ float weight_decay_shift = ann->sarprop_weight_decay_shift; /* ld 0.01 = -6.644 */
788
+ float step_error_threshold_factor = ann->sarprop_step_error_threshold_factor; /* 0.1 */
789
+ float step_error_shift = ann->sarprop_step_error_shift; /* ld 3 = 1.585 */
790
+ float T = ann->sarprop_temperature;
791
+ float MSE = fann_get_MSE(ann);
792
+ float RMSE = (float)sqrt(MSE);
793
+
794
+ unsigned int i = first_weight;
795
+
796
+
797
+ /* for all weights; TODO: are biases included? */
798
+ for(; i != past_end; i++)
799
+ {
800
+ /* TODO: confirm whether 1x10^-6 == delta_min is really better */
801
+ prev_step = fann_max(prev_steps[i], (fann_type) 0.000001); /* prev_step may not be zero because then the training will stop */
802
+ /* calculate SARPROP slope; TODO: better as new error function? (see SARPROP paper)*/
803
+ slope = -train_slopes[i] - weights[i] * (fann_type)fann_exp2(-T * epoch + weight_decay_shift);
804
+
805
+ /* TODO: is prev_train_slopes[i] 0.0 in the beginning? */
806
+ prev_slope = prev_train_slopes[i];
807
+
808
+ same_sign = prev_slope * slope;
809
+
810
+ if(same_sign > 0.0)
811
+ {
812
+ next_step = fann_min(prev_step * increase_factor, delta_max);
813
+ /* TODO: are the signs inverted? see differences between SARPROP paper and iRprop */
814
+ if (slope < 0.0)
815
+ weights[i] += next_step;
816
+ else
817
+ weights[i] -= next_step;
818
+ }
819
+ else if(same_sign < 0.0)
820
+ {
821
+ if(prev_step < step_error_threshold_factor * MSE)
822
+ next_step = prev_step * decrease_factor + (float)rand() / RAND_MAX * RMSE * (fann_type)fann_exp2(-T * epoch + step_error_shift);
823
+ else
824
+ next_step = fann_max(prev_step * decrease_factor, delta_min);
825
+
826
+ slope = 0.0;
827
+ }
828
+ else
829
+ {
830
+ if(slope < 0.0)
831
+ weights[i] += prev_step;
832
+ else
833
+ weights[i] -= prev_step;
834
+ }
835
+
836
+
837
+ /*if(i == 2){
838
+ * printf("weight=%f, slope=%f, next_step=%f, prev_step=%f\n", weights[i], slope, next_step, prev_step);
839
+ * } */
840
+
841
+ /* update global data arrays */
842
+ prev_steps[i] = next_step;
843
+ prev_train_slopes[i] = slope;
844
+ train_slopes[i] = 0.0;
845
+ }
846
+ }
847
+
767
848
  #endif
768
849
 
769
850
  FANN_GET_SET(enum fann_train_enum, training_algorithm)
@@ -957,6 +1038,10 @@ FANN_GET_SET(float, rprop_decrease_factor)
957
1038
  FANN_GET_SET(float, rprop_delta_min)
958
1039
  FANN_GET_SET(float, rprop_delta_max)
959
1040
  FANN_GET_SET(float, rprop_delta_zero)
1041
+ FANN_GET_SET(float, sarprop_weight_decay_shift)
1042
+ FANN_GET_SET(float, sarprop_step_error_threshold_factor)
1043
+ FANN_GET_SET(float, sarprop_step_error_shift)
1044
+ FANN_GET_SET(float, sarprop_temperature)
960
1045
  FANN_GET_SET(enum fann_stopfunc_enum, train_stop_function)
961
1046
  FANN_GET_SET(fann_type, bit_fail_limit)
962
1047
  FANN_GET_SET(float, learning_momentum)
@@ -1,6 +1,6 @@
1
1
  /*
2
2
  Fast Artificial Neural Network Library (fann)
3
- Copyright (C) 2003 Steffen Nissen (lukesky@diku.dk)
3
+ Copyright (C) 2003-2012 Steffen Nissen (sn@leenissen.dk)
4
4
 
5
5
  This library is free software; you can redistribute it and/or
6
6
  modify it under the terms of the GNU Lesser General Public
@@ -252,6 +252,17 @@ FANN_EXTERNAL float FANN_API fann_test_data(struct fann *ann, struct fann_train_
252
252
  FANN_EXTERNAL struct fann_train_data *FANN_API fann_read_train_from_file(const char *filename);
253
253
 
254
254
 
255
+ /* Function: fann_create_train
256
+ Creates an empty training data struct.
257
+
258
+ See also:
259
+ <fann_read_train_from_file>, <fann_train_on_data>, <fann_destroy_train>,
260
+ <fann_save_train>
261
+
262
+ This function appears in FANN >= 2.2.0
263
+ */
264
+ FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train(unsigned int num_data, unsigned int num_input, unsigned int num_output);
265
+
255
266
  /* Function: fann_create_train_from_callback
256
267
  Creates the training data struct from a user supplied function.
257
268
  As the training data are numerable (data 1, data 2...), the user must write
@@ -1200,4 +1211,100 @@ FANN_EXTERNAL float FANN_API fann_get_rprop_delta_zero(struct fann *ann);
1200
1211
  */
1201
1212
  FANN_EXTERNAL void FANN_API fann_set_rprop_delta_zero(struct fann *ann, float rprop_delta_max);
1202
1213
 
1214
+ /* Method: fann_get_sarprop_weight_decay_shift
1215
+
1216
+ The sarprop weight decay shift.
1217
+
1218
+ The default weight decay shift is -6.644.
1219
+
1220
+ See also:
1221
+ <fann_set_sarprop_weight_decay_shift>
1222
+
1223
+ This function appears in FANN >= 2.1.0.
1224
+ */
1225
+ FANN_EXTERNAL float FANN_API fann_get_sarprop_weight_decay_shift(struct fann *ann);
1226
+
1227
+ /* Method: fann_set_sarprop_weight_decay_shift
1228
+
1229
+ Set the sarprop weight decay shift.
1230
+
1231
+ This function appears in FANN >= 2.1.0.
1232
+
1233
+ See also:
1234
+ <fann_get_sarprop_weight_decay_shift>
1235
+ */
1236
+ FANN_EXTERNAL void FANN_API fann_set_sarprop_weight_decay_shift(struct fann *ann, float sarprop_weight_decay_shift);
1237
+
1238
+ /* Method: fann_get_sarprop_step_error_threshold_factor
1239
+
1240
+ The sarprop step error threshold factor.
1241
+
1242
+ The default step error threshold factor is 0.1.
1243
+
1244
+ See also:
1245
+ <fann_set_sarprop_step_error_threshold_factor>
1246
+
1247
+ This function appears in FANN >= 2.1.0.
1248
+ */
1249
+ FANN_EXTERNAL float FANN_API fann_get_sarprop_step_error_threshold_factor(struct fann *ann);
1250
+
1251
+ /* Method: fann_set_sarprop_step_error_threshold_factor
1252
+
1253
+ Set the sarprop step error threshold factor.
1254
+
1255
+ This function appears in FANN >= 2.1.0.
1256
+
1257
+ See also:
1258
+ <fann_get_sarprop_step_error_threshold_factor>
1259
+ */
1260
+ FANN_EXTERNAL void FANN_API fann_set_sarprop_step_error_threshold_factor(struct fann *ann, float sarprop_step_error_threshold_factor);
1261
+
1262
+ /* Method: fann_get_sarprop_step_error_shift
1263
+
1264
+ The sarprop step error shift.
1265
+
1266
+ The default step error shift is 1.385.
1267
+
1268
+ See also:
1269
+ <fann_set_sarprop_step_error_shift>
1270
+
1271
+ This function appears in FANN >= 2.1.0.
1272
+ */
1273
+ FANN_EXTERNAL float FANN_API fann_get_sarprop_step_error_shift(struct fann *ann);
1274
+
1275
+ /* Method: fann_set_sarprop_step_error_shift
1276
+
1277
+ Set the sarprop step error shift.
1278
+
1279
+ This function appears in FANN >= 2.1.0.
1280
+
1281
+ See also:
1282
+ <fann_get_sarprop_step_error_shift>
1283
+ */
1284
+ FANN_EXTERNAL void FANN_API fann_set_sarprop_step_error_shift(struct fann *ann, float sarprop_step_error_shift);
1285
+
1286
+ /* Method: fann_get_sarprop_temperature
1287
+
1288
+ The sarprop temperature.
1289
+
1290
+ The default temperature is 0.015.
1291
+
1292
+ See also:
1293
+ <fann_set_sarprop_temperature>
1294
+
1295
+ This function appears in FANN >= 2.1.0.
1296
+ */
1297
+ FANN_EXTERNAL float FANN_API fann_get_sarprop_temperature(struct fann *ann);
1298
+
1299
+ /* Method: fann_set_sarprop_temperature
1300
+
1301
+ Set the sarprop_temperature.
1302
+
1303
+ This function appears in FANN >= 2.1.0.
1304
+
1305
+ See also:
1306
+ <fann_get_sarprop_temperature>
1307
+ */
1308
+ FANN_EXTERNAL void FANN_API fann_set_sarprop_temperature(struct fann *ann, float sarprop_temperature);
1309
+
1203
1310
  #endif
@@ -1,21 +1,21 @@
1
1
  /*
2
- * Fast Artificial Neural Network Library (fann) Copyright (C) 2003
3
- * Steffen Nissen (lukesky@diku.dk)
4
- *
5
- * This library is free software; you can redistribute it and/or modify it
6
- * under the terms of the GNU Lesser General Public License as published
7
- * by the Free Software Foundation; either version 2.1 of the License, or
8
- * (at your option) any later version.
9
- *
10
- * This library is distributed in the hope that it will be useful, but
11
- * WITHOUT ANY WARRANTY; without even the implied warranty of
12
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
13
- * Lesser General Public License for more details.
14
- *
15
- * You should have received a copy of the GNU Lesser General Public
16
- * License along with this library; if not, write to the Free Software
17
- * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
18
- */
2
+ Fast Artificial Neural Network Library (fann)
3
+ Copyright (C) 2003-2012 Steffen Nissen (sn@leenissen.dk)
4
+
5
+ This library is free software; you can redistribute it and/or
6
+ modify it under the terms of the GNU Lesser General Public
7
+ License as published by the Free Software Foundation; either
8
+ version 2.1 of the License, or (at your option) any later version.
9
+
10
+ This library is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
13
+ Lesser General Public License for more details.
14
+
15
+ You should have received a copy of the GNU Lesser General Public
16
+ License along with this library; if not, write to the Free Software
17
+ Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
18
+ */
19
19
 
20
20
  #include <stdio.h>
21
21
  #include <stdlib.h>
@@ -84,7 +84,9 @@ FANN_EXTERNAL void FANN_API fann_destroy_train(struct fann_train_data *data)
84
84
  FANN_EXTERNAL float FANN_API fann_test_data(struct fann *ann, struct fann_train_data *data)
85
85
  {
86
86
  unsigned int i;
87
-
87
+ if(fann_check_input_output_sizes(ann, data) == -1)
88
+ return 0;
89
+
88
90
  fann_reset_MSE(ann);
89
91
 
90
92
  for(i = 0; i != data->num_data; i++)
@@ -95,86 +97,38 @@ FANN_EXTERNAL float FANN_API fann_test_data(struct fann *ann, struct fann_train_
95
97
  return fann_get_MSE(ann);
96
98
  }
97
99
 
100
+ #ifndef FIXEDFANN
101
+
98
102
  /*
99
- * Creates training data from a callback function.
103
+ * Internal train function
100
104
  */
101
- FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train_from_callback(unsigned int num_data,
102
- unsigned int num_input,
103
- unsigned int num_output,
104
- void (FANN_API *user_function)( unsigned int,
105
- unsigned int,
106
- unsigned int,
107
- fann_type * ,
108
- fann_type * ))
105
+ float fann_train_epoch_quickprop(struct fann *ann, struct fann_train_data *data)
109
106
  {
110
- unsigned int i;
111
- fann_type *data_input, *data_output;
112
- struct fann_train_data *data = (struct fann_train_data *)
113
- malloc(sizeof(struct fann_train_data));
114
-
115
- if(data == NULL){
116
- fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
117
- return NULL;
118
- }
119
-
120
- fann_init_error_data((struct fann_error *) data);
121
-
122
- data->num_data = num_data;
123
- data->num_input = num_input;
124
- data->num_output = num_output;
125
-
126
- data->input = (fann_type **) calloc(num_data, sizeof(fann_type *));
127
- if(data->input == NULL)
128
- {
129
- fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
130
- fann_destroy_train(data);
131
- return NULL;
132
- }
133
-
134
- data->output = (fann_type **) calloc(num_data, sizeof(fann_type *));
135
- if(data->output == NULL)
136
- {
137
- fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
138
- fann_destroy_train(data);
139
- return NULL;
140
- }
141
-
142
- data_input = (fann_type *) calloc(num_input * num_data, sizeof(fann_type));
143
- if(data_input == NULL)
144
- {
145
- fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
146
- fann_destroy_train(data);
147
- return NULL;
148
- }
149
-
150
- data_output = (fann_type *) calloc(num_output * num_data, sizeof(fann_type));
151
- if(data_output == NULL)
152
- {
153
- fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
154
- fann_destroy_train(data);
155
- return NULL;
156
- }
157
-
158
- for( i = 0; i != num_data; i++)
159
- {
160
- data->input[i] = data_input;
161
- data_input += num_input;
107
+ unsigned int i;
162
108
 
163
- data->output[i] = data_output;
164
- data_output += num_output;
109
+ if(ann->prev_train_slopes == NULL)
110
+ {
111
+ fann_clear_train_arrays(ann);
112
+ }
165
113
 
166
- (*user_function)(i, num_input, num_output, data->input[i],data->output[i] );
167
- }
114
+ fann_reset_MSE(ann);
168
115
 
169
- return data;
170
- }
116
+ for(i = 0; i < data->num_data; i++)
117
+ {
118
+ fann_run(ann, data->input[i]);
119
+ fann_compute_MSE(ann, data->output[i]);
120
+ fann_backpropagate_MSE(ann);
121
+ fann_update_slopes_batch(ann, ann->first_layer + 1, ann->last_layer - 1);
122
+ }
123
+ fann_update_weights_quickprop(ann, data->num_data, 0, ann->total_connections);
171
124
 
172
- #ifndef FIXEDFANN
125
+ return fann_get_MSE(ann);
126
+ }
173
127
 
174
128
  /*
175
129
  * Internal train function
176
130
  */
177
- float fann_train_epoch_quickprop(struct fann *ann, struct fann_train_data *data)
131
+ float fann_train_epoch_irpropm(struct fann *ann, struct fann_train_data *data)
178
132
  {
179
133
  unsigned int i;
180
134
 
@@ -192,7 +146,8 @@ float fann_train_epoch_quickprop(struct fann *ann, struct fann_train_data *data)
192
146
  fann_backpropagate_MSE(ann);
193
147
  fann_update_slopes_batch(ann, ann->first_layer + 1, ann->last_layer - 1);
194
148
  }
195
- fann_update_weights_quickprop(ann, data->num_data, 0, ann->total_connections);
149
+
150
+ fann_update_weights_irpropm(ann, 0, ann->total_connections);
196
151
 
197
152
  return fann_get_MSE(ann);
198
153
  }
@@ -200,7 +155,7 @@ float fann_train_epoch_quickprop(struct fann *ann, struct fann_train_data *data)
200
155
  /*
201
156
  * Internal train function
202
157
  */
203
- float fann_train_epoch_irpropm(struct fann *ann, struct fann_train_data *data)
158
+ float fann_train_epoch_sarprop(struct fann *ann, struct fann_train_data *data)
204
159
  {
205
160
  unsigned int i;
206
161
 
@@ -219,7 +174,9 @@ float fann_train_epoch_irpropm(struct fann *ann, struct fann_train_data *data)
219
174
  fann_update_slopes_batch(ann, ann->first_layer + 1, ann->last_layer - 1);
220
175
  }
221
176
 
222
- fann_update_weights_irpropm(ann, 0, ann->total_connections);
177
+ fann_update_weights_sarprop(ann, ann->sarprop_epoch, 0, ann->total_connections);
178
+
179
+ ++(ann->sarprop_epoch);
223
180
 
224
181
  return fann_get_MSE(ann);
225
182
  }
@@ -268,12 +225,17 @@ float fann_train_epoch_incremental(struct fann *ann, struct fann_train_data *dat
268
225
  */
269
226
  FANN_EXTERNAL float FANN_API fann_train_epoch(struct fann *ann, struct fann_train_data *data)
270
227
  {
228
+ if(fann_check_input_output_sizes(ann, data) == -1)
229
+ return 0;
230
+
271
231
  switch (ann->training_algorithm)
272
232
  {
273
233
  case FANN_TRAIN_QUICKPROP:
274
234
  return fann_train_epoch_quickprop(ann, data);
275
235
  case FANN_TRAIN_RPROP:
276
236
  return fann_train_epoch_irpropm(ann, data);
237
+ case FANN_TRAIN_SARPROP:
238
+ return fann_train_epoch_sarprop(ann, data);
277
239
  case FANN_TRAIN_BATCH:
278
240
  return fann_train_epoch_batch(ann, data);
279
241
  case FANN_TRAIN_INCREMENTAL:
@@ -795,15 +757,13 @@ int fann_save_train_internal_fd(struct fann_train_data *data, FILE * file, const
795
757
  return retval;
796
758
  }
797
759
 
798
-
799
760
  /*
800
- * INTERNAL FUNCTION Reads training data from a file descriptor.
761
+ * Creates an empty set of training data
801
762
  */
802
- struct fann_train_data *fann_read_train_from_fd(FILE * file, const char *filename)
763
+ FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train(unsigned int num_data, unsigned int num_input, unsigned int num_output)
803
764
  {
804
- unsigned int num_input, num_output, num_data, i, j;
805
- unsigned int line = 1;
806
765
  fann_type *data_input, *data_output;
766
+ unsigned int i;
807
767
  struct fann_train_data *data =
808
768
  (struct fann_train_data *) malloc(sizeof(struct fann_train_data));
809
769
 
@@ -812,15 +772,7 @@ struct fann_train_data *fann_read_train_from_fd(FILE * file, const char *filenam
812
772
  fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
813
773
  return NULL;
814
774
  }
815
-
816
- if(fscanf(file, "%u %u %u\n", &num_data, &num_input, &num_output) != 3)
817
- {
818
- fann_error(NULL, FANN_E_CANT_READ_TD, filename, line);
819
- fann_destroy_train(data);
820
- return NULL;
821
- }
822
- line++;
823
-
775
+
824
776
  fann_init_error_data((struct fann_error *) data);
825
777
 
826
778
  data->num_data = num_data;
@@ -862,7 +814,63 @@ struct fann_train_data *fann_read_train_from_fd(FILE * file, const char *filenam
862
814
  {
863
815
  data->input[i] = data_input;
864
816
  data_input += num_input;
817
+ data->output[i] = data_output;
818
+ data_output += num_output;
819
+ }
820
+ return data;
821
+ }
822
+
823
+ /*
824
+ * Creates training data from a callback function.
825
+ */
826
+ FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train_from_callback(unsigned int num_data,
827
+ unsigned int num_input,
828
+ unsigned int num_output,
829
+ void (FANN_API *user_function)( unsigned int,
830
+ unsigned int,
831
+ unsigned int,
832
+ fann_type * ,
833
+ fann_type * ))
834
+ {
835
+ unsigned int i;
836
+ struct fann_train_data *data = fann_create_train(num_data, num_input, num_output);
837
+ if(data == NULL)
838
+ {
839
+ return NULL;
840
+ }
841
+
842
+ for( i = 0; i != num_data; i++)
843
+ {
844
+ (*user_function)(i, num_input, num_output, data->input[i], data->output[i]);
845
+ }
846
+
847
+ return data;
848
+ }
865
849
 
850
+ /*
851
+ * INTERNAL FUNCTION Reads training data from a file descriptor.
852
+ */
853
+ struct fann_train_data *fann_read_train_from_fd(FILE * file, const char *filename)
854
+ {
855
+ unsigned int num_input, num_output, num_data, i, j;
856
+ unsigned int line = 1;
857
+ struct fann_train_data *data;
858
+
859
+ if(fscanf(file, "%u %u %u\n", &num_data, &num_input, &num_output) != 3)
860
+ {
861
+ fann_error(NULL, FANN_E_CANT_READ_TD, filename, line);
862
+ return NULL;
863
+ }
864
+ line++;
865
+
866
+ data = fann_create_train(num_data, num_input, num_output);
867
+ if(data == NULL)
868
+ {
869
+ return NULL;
870
+ }
871
+
872
+ for(i = 0; i != num_data; i++)
873
+ {
866
874
  for(j = 0; j != num_input; j++)
867
875
  {
868
876
  if(fscanf(file, FANNSCANF " ", &data->input[i][j]) != 1)
@@ -874,9 +882,6 @@ struct fann_train_data *fann_read_train_from_fd(FILE * file, const char *filenam
874
882
  }
875
883
  line++;
876
884
 
877
- data->output[i] = data_output;
878
- data_output += num_output;
879
-
880
885
  for(j = 0; j != num_output; j++)
881
886
  {
882
887
  if(fscanf(file, FANNSCANF " ", &data->output[i][j]) != 1)
@@ -928,7 +933,7 @@ FANN_EXTERNAL void FANN_API fann_scale_input( struct fann *ann, fann_type *input
928
933
  (
929
934
  ( input_vector[ cur_neuron ] - ann->scale_mean_in[ cur_neuron ] )
930
935
  / ann->scale_deviation_in[ cur_neuron ]
931
- - ( -1.0 ) /* This is old_min */
936
+ - ( (fann_type)-1.0 ) /* This is old_min */
932
937
  )
933
938
  * ann->scale_factor_in[ cur_neuron ]
934
939
  + ann->scale_new_min_in[ cur_neuron ];
@@ -951,7 +956,7 @@ FANN_EXTERNAL void FANN_API fann_scale_output( struct fann *ann, fann_type *outp
951
956
  (
952
957
  ( output_vector[ cur_neuron ] - ann->scale_mean_out[ cur_neuron ] )
953
958
  / ann->scale_deviation_out[ cur_neuron ]
954
- - ( -1.0 ) /* This is old_min */
959
+ - ( (fann_type)-1.0 ) /* This is old_min */
955
960
  )
956
961
  * ann->scale_factor_out[ cur_neuron ]
957
962
  + ann->scale_new_min_out[ cur_neuron ];
@@ -977,7 +982,7 @@ FANN_EXTERNAL void FANN_API fann_descale_input( struct fann *ann, fann_type *inp
977
982
  - ann->scale_new_min_in[ cur_neuron ]
978
983
  )
979
984
  / ann->scale_factor_in[ cur_neuron ]
980
- + ( -1.0 ) /* This is old_min */
985
+ + ( (fann_type)-1.0 ) /* This is old_min */
981
986
  )
982
987
  * ann->scale_deviation_in[ cur_neuron ]
983
988
  + ann->scale_mean_in[ cur_neuron ];
@@ -1003,7 +1008,7 @@ FANN_EXTERNAL void FANN_API fann_descale_output( struct fann *ann, fann_type *ou
1003
1008
  - ann->scale_new_min_out[ cur_neuron ]
1004
1009
  )
1005
1010
  / ann->scale_factor_out[ cur_neuron ]
1006
- + ( -1.0 ) /* This is old_min */
1011
+ + ( (fann_type)-1.0 ) /* This is old_min */
1007
1012
  )
1008
1013
  * ann->scale_deviation_out[ cur_neuron ]
1009
1014
  + ann->scale_mean_out[ cur_neuron ];
@@ -1021,14 +1026,8 @@ FANN_EXTERNAL void FANN_API fann_scale_train( struct fann *ann, struct fann_trai
1021
1026
  return;
1022
1027
  }
1023
1028
  /* Check that we have good training data. */
1024
- /* No need for if( !params || !ann ) */
1025
- if( data->num_input != ann->num_input
1026
- || data->num_output != ann->num_output
1027
- )
1028
- {
1029
- fann_error( (struct fann_error *) ann, FANN_E_TRAIN_DATA_MISMATCH );
1029
+ if(fann_check_input_output_sizes(ann, data) == -1)
1030
1030
  return;
1031
- }
1032
1031
 
1033
1032
  for( cur_sample = 0; cur_sample < data->num_data; cur_sample++ )
1034
1033
  {
@@ -1049,14 +1048,8 @@ FANN_EXTERNAL void FANN_API fann_descale_train( struct fann *ann, struct fann_tr
1049
1048
  return;
1050
1049
  }
1051
1050
  /* Check that we have good training data. */
1052
- /* No need for if( !params || !ann ) */
1053
- if( data->num_input != ann->num_input
1054
- || data->num_output != ann->num_output
1055
- )
1056
- {
1057
- fann_error( (struct fann_error *) ann, FANN_E_TRAIN_DATA_MISMATCH );
1051
+ if(fann_check_input_output_sizes(ann, data) == -1)
1058
1052
  return;
1059
- }
1060
1053
 
1061
1054
  for( cur_sample = 0; cur_sample < data->num_data; cur_sample++ )
1062
1055
  {
@@ -1072,38 +1065,38 @@ FANN_EXTERNAL void FANN_API fann_descale_train( struct fann *ann, struct fann_tr
1072
1065
  #define SCALE_SET_PARAM( where ) \
1073
1066
  /* Calculate mean: sum(x)/length */ \
1074
1067
  for( cur_neuron = 0; cur_neuron < ann->num_##where##put; cur_neuron++ ) \
1075
- ann->scale_mean_##where[ cur_neuron ] = 0.0; \
1068
+ ann->scale_mean_##where[ cur_neuron ] = 0.0f; \
1076
1069
  for( cur_neuron = 0; cur_neuron < ann->num_##where##put; cur_neuron++ ) \
1077
1070
  for( cur_sample = 0; cur_sample < data->num_data; cur_sample++ ) \
1078
- ann->scale_mean_##where[ cur_neuron ] += data->where##put[ cur_sample ][ cur_neuron ]; \
1071
+ ann->scale_mean_##where[ cur_neuron ] += (float)data->where##put[ cur_sample ][ cur_neuron ];\
1079
1072
  for( cur_neuron = 0; cur_neuron < ann->num_##where##put; cur_neuron++ ) \
1080
1073
  ann->scale_mean_##where[ cur_neuron ] /= (float)data->num_data; \
1081
1074
  /* Calculate deviation: sqrt(sum((x-mean)^2)/length) */ \
1082
1075
  for( cur_neuron = 0; cur_neuron < ann->num_##where##put; cur_neuron++ ) \
1083
- ann->scale_deviation_##where[ cur_neuron ] = 0.0; \
1076
+ ann->scale_deviation_##where[ cur_neuron ] = 0.0f; \
1084
1077
  for( cur_neuron = 0; cur_neuron < ann->num_##where##put; cur_neuron++ ) \
1085
1078
  for( cur_sample = 0; cur_sample < data->num_data; cur_sample++ ) \
1086
1079
  ann->scale_deviation_##where[ cur_neuron ] += \
1087
1080
  /* Another local variable in macro? Oh no! */ \
1088
1081
  ( \
1089
- data->where##put[ cur_sample ][ cur_neuron ] \
1082
+ (float)data->where##put[ cur_sample ][ cur_neuron ] \
1090
1083
  - ann->scale_mean_##where[ cur_neuron ] \
1091
1084
  ) \
1092
1085
  * \
1093
1086
  ( \
1094
- data->where##put[ cur_sample ][ cur_neuron ] \
1087
+ (float)data->where##put[ cur_sample ][ cur_neuron ] \
1095
1088
  - ann->scale_mean_##where[ cur_neuron ] \
1096
1089
  ); \
1097
1090
  for( cur_neuron = 0; cur_neuron < ann->num_##where##put; cur_neuron++ ) \
1098
1091
  ann->scale_deviation_##where[ cur_neuron ] = \
1099
- sqrt( ann->scale_deviation_##where[ cur_neuron ] / (float)data->num_data ); \
1092
+ (float)sqrt( ann->scale_deviation_##where[ cur_neuron ] / (float)data->num_data ); \
1100
1093
  /* Calculate factor: (new_max-new_min)/(old_max(1)-old_min(-1)) */ \
1101
1094
  /* Looks like we dont need whole array of factors? */ \
1102
1095
  for( cur_neuron = 0; cur_neuron < ann->num_##where##put; cur_neuron++ ) \
1103
1096
  ann->scale_factor_##where[ cur_neuron ] = \
1104
1097
  ( new_##where##put_max - new_##where##put_min ) \
1105
1098
  / \
1106
- ( 1.0 - ( -1.0 ) ); \
1099
+ ( 1.0f - ( -1.0f ) ); \
1107
1100
  /* Copy new minimum. */ \
1108
1101
  /* Looks like we dont need whole array of new minimums? */ \
1109
1102
  for( cur_neuron = 0; cur_neuron < ann->num_##where##put; cur_neuron++ ) \
@@ -1229,3 +1222,22 @@ FANN_EXTERNAL int FANN_API fann_clear_scaling_params(struct fann *ann)
1229
1222
  }
1230
1223
 
1231
1224
  #endif
1225
+
1226
+ int fann_check_input_output_sizes(struct fann *ann, struct fann_train_data *data)
1227
+ {
1228
+ if(ann->num_input != data->num_input)
1229
+ {
1230
+ fann_error((struct fann_error *) ann, FANN_E_INPUT_NO_MATCH,
1231
+ ann->num_input, data->num_input);
1232
+ return -1;
1233
+ }
1234
+
1235
+ if(ann->num_output != data->num_output)
1236
+ {
1237
+ fann_error((struct fann_error *) ann, FANN_E_OUTPUT_NO_MATCH,
1238
+ ann->num_output, data->num_output);
1239
+ return -1;
1240
+ }
1241
+
1242
+ return 0;
1243
+ }