ruby-fann 0.7.10 → 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data/History.txt +6 -1
- data/License.txt +1 -1
- data/Manifest.txt +22 -1
- data/README.txt +0 -1
- data/Rakefile +0 -0
- data/config/hoe.rb +0 -0
- data/config/requirements.rb +0 -0
- data/ext/ruby_fann/MANIFEST +0 -0
- data/ext/ruby_fann/Makefile +36 -28
- data/ext/ruby_fann/doublefann.c +30 -0
- data/ext/ruby_fann/doublefann.h +33 -0
- data/ext/ruby_fann/extconf.rb +9 -5
- data/ext/ruby_fann/fann.c +1552 -0
- data/ext/ruby_fann/fann_activation.h +144 -0
- data/ext/ruby_fann/fann_augment.h +0 -0
- data/ext/ruby_fann/fann_cascade.c +1031 -0
- data/ext/ruby_fann/fann_cascade.h +503 -0
- data/ext/ruby_fann/fann_data.h +799 -0
- data/ext/ruby_fann/fann_error.c +204 -0
- data/ext/ruby_fann/fann_error.h +161 -0
- data/ext/ruby_fann/fann_internal.h +148 -0
- data/ext/ruby_fann/fann_io.c +762 -0
- data/ext/ruby_fann/fann_io.h +100 -0
- data/ext/ruby_fann/fann_train.c +962 -0
- data/ext/ruby_fann/fann_train.h +1203 -0
- data/ext/ruby_fann/fann_train_data.c +1231 -0
- data/ext/ruby_fann/neural_network.c +0 -0
- data/lib/ruby_fann/neurotica.rb +0 -0
- data/lib/ruby_fann/version.rb +3 -3
- data/lib/ruby_fann.rb +0 -0
- data/neurotica1.png +0 -0
- data/neurotica2.vrml +18 -18
- data/setup.rb +0 -0
- data/tasks/deployment.rake +0 -0
- data/tasks/environment.rake +0 -0
- data/tasks/website.rake +0 -0
- data/test/test.train +0 -0
- data/test/test_helper.rb +0 -0
- data/test/test_neurotica.rb +0 -0
- data/test/test_ruby_fann.rb +0 -0
- data/test/test_ruby_fann_functional.rb +0 -0
- data/verify.train +0 -0
- data/website/index.html +42 -92
- data/website/index.txt +0 -0
- data/website/javascripts/rounded_corners_lite.inc.js +0 -0
- data/website/stylesheets/screen.css +0 -0
- data/website/template.rhtml +0 -0
- data/xor.train +0 -0
- data/xor_cascade.net +2 -2
- data/xor_float.net +1 -1
- metadata +22 -6
- data/log/debug.log +0 -0
@@ -0,0 +1,1231 @@
|
|
1
|
+
/*
|
2
|
+
* Fast Artificial Neural Network Library (fann) Copyright (C) 2003
|
3
|
+
* Steffen Nissen (lukesky@diku.dk)
|
4
|
+
*
|
5
|
+
* This library is free software; you can redistribute it and/or modify it
|
6
|
+
* under the terms of the GNU Lesser General Public License as published
|
7
|
+
* by the Free Software Foundation; either version 2.1 of the License, or
|
8
|
+
* (at your option) any later version.
|
9
|
+
*
|
10
|
+
* This library is distributed in the hope that it will be useful, but
|
11
|
+
* WITHOUT ANY WARRANTY; without even the implied warranty of
|
12
|
+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
13
|
+
* Lesser General Public License for more details.
|
14
|
+
*
|
15
|
+
* You should have received a copy of the GNU Lesser General Public
|
16
|
+
* License along with this library; if not, write to the Free Software
|
17
|
+
* Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
|
18
|
+
*/
|
19
|
+
|
20
|
+
#include <stdio.h>
|
21
|
+
#include <stdlib.h>
|
22
|
+
#include <stdarg.h>
|
23
|
+
#include <string.h>
|
24
|
+
|
25
|
+
#include "config.h"
|
26
|
+
#include "fann.h"
|
27
|
+
|
28
|
+
/*
|
29
|
+
* Reads training data from a file.
|
30
|
+
*/
|
31
|
+
FANN_EXTERNAL struct fann_train_data *FANN_API fann_read_train_from_file(const char *configuration_file)
|
32
|
+
{
|
33
|
+
struct fann_train_data *data;
|
34
|
+
FILE *file = fopen(configuration_file, "r");
|
35
|
+
|
36
|
+
if(!file)
|
37
|
+
{
|
38
|
+
fann_error(NULL, FANN_E_CANT_OPEN_CONFIG_R, configuration_file);
|
39
|
+
return NULL;
|
40
|
+
}
|
41
|
+
|
42
|
+
data = fann_read_train_from_fd(file, configuration_file);
|
43
|
+
fclose(file);
|
44
|
+
return data;
|
45
|
+
}
|
46
|
+
|
47
|
+
/*
|
48
|
+
* Save training data to a file
|
49
|
+
*/
|
50
|
+
FANN_EXTERNAL int FANN_API fann_save_train(struct fann_train_data *data, const char *filename)
|
51
|
+
{
|
52
|
+
return fann_save_train_internal(data, filename, 0, 0);
|
53
|
+
}
|
54
|
+
|
55
|
+
/*
|
56
|
+
* Save training data to a file in fixed point algebra. (Good for testing
|
57
|
+
* a network in fixed point)
|
58
|
+
*/
|
59
|
+
FANN_EXTERNAL int FANN_API fann_save_train_to_fixed(struct fann_train_data *data, const char *filename,
|
60
|
+
unsigned int decimal_point)
|
61
|
+
{
|
62
|
+
return fann_save_train_internal(data, filename, 1, decimal_point);
|
63
|
+
}
|
64
|
+
|
65
|
+
/*
|
66
|
+
* deallocate the train data structure.
|
67
|
+
*/
|
68
|
+
FANN_EXTERNAL void FANN_API fann_destroy_train(struct fann_train_data *data)
|
69
|
+
{
|
70
|
+
if(data == NULL)
|
71
|
+
return;
|
72
|
+
if(data->input != NULL)
|
73
|
+
fann_safe_free(data->input[0]);
|
74
|
+
if(data->output != NULL)
|
75
|
+
fann_safe_free(data->output[0]);
|
76
|
+
fann_safe_free(data->input);
|
77
|
+
fann_safe_free(data->output);
|
78
|
+
fann_safe_free(data);
|
79
|
+
}
|
80
|
+
|
81
|
+
/*
|
82
|
+
* Test a set of training data and calculate the MSE
|
83
|
+
*/
|
84
|
+
FANN_EXTERNAL float FANN_API fann_test_data(struct fann *ann, struct fann_train_data *data)
|
85
|
+
{
|
86
|
+
unsigned int i;
|
87
|
+
|
88
|
+
fann_reset_MSE(ann);
|
89
|
+
|
90
|
+
for(i = 0; i != data->num_data; i++)
|
91
|
+
{
|
92
|
+
fann_test(ann, data->input[i], data->output[i]);
|
93
|
+
}
|
94
|
+
|
95
|
+
return fann_get_MSE(ann);
|
96
|
+
}
|
97
|
+
|
98
|
+
/*
|
99
|
+
* Creates training data from a callback function.
|
100
|
+
*/
|
101
|
+
FANN_EXTERNAL struct fann_train_data * FANN_API fann_create_train_from_callback(unsigned int num_data,
|
102
|
+
unsigned int num_input,
|
103
|
+
unsigned int num_output,
|
104
|
+
void (FANN_API *user_function)( unsigned int,
|
105
|
+
unsigned int,
|
106
|
+
unsigned int,
|
107
|
+
fann_type * ,
|
108
|
+
fann_type * ))
|
109
|
+
{
|
110
|
+
unsigned int i;
|
111
|
+
fann_type *data_input, *data_output;
|
112
|
+
struct fann_train_data *data = (struct fann_train_data *)
|
113
|
+
malloc(sizeof(struct fann_train_data));
|
114
|
+
|
115
|
+
if(data == NULL){
|
116
|
+
fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
|
117
|
+
return NULL;
|
118
|
+
}
|
119
|
+
|
120
|
+
fann_init_error_data((struct fann_error *) data);
|
121
|
+
|
122
|
+
data->num_data = num_data;
|
123
|
+
data->num_input = num_input;
|
124
|
+
data->num_output = num_output;
|
125
|
+
|
126
|
+
data->input = (fann_type **) calloc(num_data, sizeof(fann_type *));
|
127
|
+
if(data->input == NULL)
|
128
|
+
{
|
129
|
+
fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
|
130
|
+
fann_destroy_train(data);
|
131
|
+
return NULL;
|
132
|
+
}
|
133
|
+
|
134
|
+
data->output = (fann_type **) calloc(num_data, sizeof(fann_type *));
|
135
|
+
if(data->output == NULL)
|
136
|
+
{
|
137
|
+
fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
|
138
|
+
fann_destroy_train(data);
|
139
|
+
return NULL;
|
140
|
+
}
|
141
|
+
|
142
|
+
data_input = (fann_type *) calloc(num_input * num_data, sizeof(fann_type));
|
143
|
+
if(data_input == NULL)
|
144
|
+
{
|
145
|
+
fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
|
146
|
+
fann_destroy_train(data);
|
147
|
+
return NULL;
|
148
|
+
}
|
149
|
+
|
150
|
+
data_output = (fann_type *) calloc(num_output * num_data, sizeof(fann_type));
|
151
|
+
if(data_output == NULL)
|
152
|
+
{
|
153
|
+
fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
|
154
|
+
fann_destroy_train(data);
|
155
|
+
return NULL;
|
156
|
+
}
|
157
|
+
|
158
|
+
for( i = 0; i != num_data; i++)
|
159
|
+
{
|
160
|
+
data->input[i] = data_input;
|
161
|
+
data_input += num_input;
|
162
|
+
|
163
|
+
data->output[i] = data_output;
|
164
|
+
data_output += num_output;
|
165
|
+
|
166
|
+
(*user_function)(i, num_input, num_output, data->input[i],data->output[i] );
|
167
|
+
}
|
168
|
+
|
169
|
+
return data;
|
170
|
+
}
|
171
|
+
|
172
|
+
#ifndef FIXEDFANN
|
173
|
+
|
174
|
+
/*
|
175
|
+
* Internal train function
|
176
|
+
*/
|
177
|
+
float fann_train_epoch_quickprop(struct fann *ann, struct fann_train_data *data)
|
178
|
+
{
|
179
|
+
unsigned int i;
|
180
|
+
|
181
|
+
if(ann->prev_train_slopes == NULL)
|
182
|
+
{
|
183
|
+
fann_clear_train_arrays(ann);
|
184
|
+
}
|
185
|
+
|
186
|
+
fann_reset_MSE(ann);
|
187
|
+
|
188
|
+
for(i = 0; i < data->num_data; i++)
|
189
|
+
{
|
190
|
+
fann_run(ann, data->input[i]);
|
191
|
+
fann_compute_MSE(ann, data->output[i]);
|
192
|
+
fann_backpropagate_MSE(ann);
|
193
|
+
fann_update_slopes_batch(ann, ann->first_layer + 1, ann->last_layer - 1);
|
194
|
+
}
|
195
|
+
fann_update_weights_quickprop(ann, data->num_data, 0, ann->total_connections);
|
196
|
+
|
197
|
+
return fann_get_MSE(ann);
|
198
|
+
}
|
199
|
+
|
200
|
+
/*
|
201
|
+
* Internal train function
|
202
|
+
*/
|
203
|
+
float fann_train_epoch_irpropm(struct fann *ann, struct fann_train_data *data)
|
204
|
+
{
|
205
|
+
unsigned int i;
|
206
|
+
|
207
|
+
if(ann->prev_train_slopes == NULL)
|
208
|
+
{
|
209
|
+
fann_clear_train_arrays(ann);
|
210
|
+
}
|
211
|
+
|
212
|
+
fann_reset_MSE(ann);
|
213
|
+
|
214
|
+
for(i = 0; i < data->num_data; i++)
|
215
|
+
{
|
216
|
+
fann_run(ann, data->input[i]);
|
217
|
+
fann_compute_MSE(ann, data->output[i]);
|
218
|
+
fann_backpropagate_MSE(ann);
|
219
|
+
fann_update_slopes_batch(ann, ann->first_layer + 1, ann->last_layer - 1);
|
220
|
+
}
|
221
|
+
|
222
|
+
fann_update_weights_irpropm(ann, 0, ann->total_connections);
|
223
|
+
|
224
|
+
return fann_get_MSE(ann);
|
225
|
+
}
|
226
|
+
|
227
|
+
/*
|
228
|
+
* Internal train function
|
229
|
+
*/
|
230
|
+
float fann_train_epoch_batch(struct fann *ann, struct fann_train_data *data)
|
231
|
+
{
|
232
|
+
unsigned int i;
|
233
|
+
|
234
|
+
fann_reset_MSE(ann);
|
235
|
+
|
236
|
+
for(i = 0; i < data->num_data; i++)
|
237
|
+
{
|
238
|
+
fann_run(ann, data->input[i]);
|
239
|
+
fann_compute_MSE(ann, data->output[i]);
|
240
|
+
fann_backpropagate_MSE(ann);
|
241
|
+
fann_update_slopes_batch(ann, ann->first_layer + 1, ann->last_layer - 1);
|
242
|
+
}
|
243
|
+
|
244
|
+
fann_update_weights_batch(ann, data->num_data, 0, ann->total_connections);
|
245
|
+
|
246
|
+
return fann_get_MSE(ann);
|
247
|
+
}
|
248
|
+
|
249
|
+
/*
|
250
|
+
* Internal train function
|
251
|
+
*/
|
252
|
+
float fann_train_epoch_incremental(struct fann *ann, struct fann_train_data *data)
|
253
|
+
{
|
254
|
+
unsigned int i;
|
255
|
+
|
256
|
+
fann_reset_MSE(ann);
|
257
|
+
|
258
|
+
for(i = 0; i != data->num_data; i++)
|
259
|
+
{
|
260
|
+
fann_train(ann, data->input[i], data->output[i]);
|
261
|
+
}
|
262
|
+
|
263
|
+
return fann_get_MSE(ann);
|
264
|
+
}
|
265
|
+
|
266
|
+
/*
|
267
|
+
* Train for one epoch with the selected training algorithm
|
268
|
+
*/
|
269
|
+
FANN_EXTERNAL float FANN_API fann_train_epoch(struct fann *ann, struct fann_train_data *data)
|
270
|
+
{
|
271
|
+
switch (ann->training_algorithm)
|
272
|
+
{
|
273
|
+
case FANN_TRAIN_QUICKPROP:
|
274
|
+
return fann_train_epoch_quickprop(ann, data);
|
275
|
+
case FANN_TRAIN_RPROP:
|
276
|
+
return fann_train_epoch_irpropm(ann, data);
|
277
|
+
case FANN_TRAIN_BATCH:
|
278
|
+
return fann_train_epoch_batch(ann, data);
|
279
|
+
case FANN_TRAIN_INCREMENTAL:
|
280
|
+
return fann_train_epoch_incremental(ann, data);
|
281
|
+
}
|
282
|
+
return 0;
|
283
|
+
}
|
284
|
+
|
285
|
+
/*
 * Trains ann on data for at most max_epochs epochs with fann_train_epoch.
 *
 * Training stops early once fann_desired_error_reached() reports that
 * desired_error has been reached, or when ann->callback returns -1.
 *
 * If epochs_between_reports is non-zero, progress is reported on the first
 * epoch, every epochs_between_reports epochs, on the last epoch and when the
 * desired error is reached. Reporting goes through ann->callback when it is
 * set, otherwise it is printed to stdout. With epochs_between_reports == 0
 * nothing is printed and the callback is never invoked.
 */
FANN_EXTERNAL void FANN_API fann_train_on_data(struct fann *ann, struct fann_train_data *data,
											   unsigned int max_epochs,
											   unsigned int epochs_between_reports,
											   float desired_error)
{
	float error;
	unsigned int i;
	int desired_error_reached;

#ifdef DEBUG
	printf("Training with %s\n", FANN_TRAIN_NAMES[ann->training_algorithm]);
#endif

	if(epochs_between_reports && ann->callback == NULL)
	{
		/* %u: max_epochs is unsigned, %d misprints values above INT_MAX */
		printf("Max epochs %8u. Desired error: %.10f.\n", max_epochs, desired_error);
	}

	for(i = 1; i <= max_epochs; i++)
	{
		/*
		 * train
		 */
		error = fann_train_epoch(ann, data);
		desired_error_reached = fann_desired_error_reached(ann, desired_error);

		/*
		 * print current output
		 */
		if(epochs_between_reports &&
		   (i % epochs_between_reports == 0 || i == max_epochs || i == 1 ||
			desired_error_reached == 0))
		{
			if(ann->callback == NULL)
			{
				/* both counters are unsigned (assumes num_bit_fail is unsigned int
				 * as declared in fann_data.h) -> %u */
				printf("Epochs %8u. Current error: %.10f. Bit fail %u.\n", i, error,
					   ann->num_bit_fail);
			}
			else if(((*ann->callback)(ann, data, max_epochs, epochs_between_reports,
									  desired_error, i)) == -1)
			{
				/*
				 * you can break the training by returning -1
				 */
				break;
			}
		}

		if(desired_error_reached == 0)
			break;

		/* Leave explicitly after the last epoch: with max_epochs == UINT_MAX
		 * the i++ would wrap to 0 and "i <= max_epochs" would never fail. */
		if(i == max_epochs)
			break;
	}
}
|
337
|
+
|
338
|
+
/*
 * Reads training data from filename, trains ann on it with
 * fann_train_on_data and frees the data again. Does nothing when the
 * file cannot be read (fann_read_train_from_file returns NULL).
 */
FANN_EXTERNAL void FANN_API fann_train_on_file(struct fann *ann, const char *filename,
											   unsigned int max_epochs,
											   unsigned int epochs_between_reports,
											   float desired_error)
{
	struct fann_train_data *train_set = fann_read_train_from_file(filename);

	if(train_set != NULL)
	{
		fann_train_on_data(ann, train_set, max_epochs, epochs_between_reports, desired_error);
		fann_destroy_train(train_set);
	}
}
|
352
|
+
|
353
|
+
#endif
|
354
|
+
|
355
|
+
/*
|
356
|
+
* shuffles training data, randomizing the order
|
357
|
+
*/
|
358
|
+
FANN_EXTERNAL void FANN_API fann_shuffle_train_data(struct fann_train_data *train_data)
|
359
|
+
{
|
360
|
+
unsigned int dat = 0, elem, swap;
|
361
|
+
fann_type temp;
|
362
|
+
|
363
|
+
for(; dat < train_data->num_data; dat++)
|
364
|
+
{
|
365
|
+
swap = (unsigned int) (rand() % train_data->num_data);
|
366
|
+
if(swap != dat)
|
367
|
+
{
|
368
|
+
for(elem = 0; elem < train_data->num_input; elem++)
|
369
|
+
{
|
370
|
+
temp = train_data->input[dat][elem];
|
371
|
+
train_data->input[dat][elem] = train_data->input[swap][elem];
|
372
|
+
train_data->input[swap][elem] = temp;
|
373
|
+
}
|
374
|
+
for(elem = 0; elem < train_data->num_output; elem++)
|
375
|
+
{
|
376
|
+
temp = train_data->output[dat][elem];
|
377
|
+
train_data->output[dat][elem] = train_data->output[swap][elem];
|
378
|
+
train_data->output[swap][elem] = temp;
|
379
|
+
}
|
380
|
+
}
|
381
|
+
}
|
382
|
+
}
|
383
|
+
|
384
|
+
/*
|
385
|
+
* INTERNAL FUNCTION Scales data to a specific range
|
386
|
+
*/
|
387
|
+
void fann_scale_data(fann_type ** data, unsigned int num_data, unsigned int num_elem,
|
388
|
+
fann_type new_min, fann_type new_max)
|
389
|
+
{
|
390
|
+
unsigned int dat, elem;
|
391
|
+
fann_type old_min, old_max, temp, old_span, new_span, factor;
|
392
|
+
|
393
|
+
old_min = old_max = data[0][0];
|
394
|
+
|
395
|
+
/*
|
396
|
+
* first calculate min and max
|
397
|
+
*/
|
398
|
+
for(dat = 0; dat < num_data; dat++)
|
399
|
+
{
|
400
|
+
for(elem = 0; elem < num_elem; elem++)
|
401
|
+
{
|
402
|
+
temp = data[dat][elem];
|
403
|
+
if(temp < old_min)
|
404
|
+
old_min = temp;
|
405
|
+
else if(temp > old_max)
|
406
|
+
old_max = temp;
|
407
|
+
}
|
408
|
+
}
|
409
|
+
|
410
|
+
old_span = old_max - old_min;
|
411
|
+
new_span = new_max - new_min;
|
412
|
+
factor = new_span / old_span;
|
413
|
+
/*printf("max %f, min %f, factor %f\n", old_max, old_min, factor);*/
|
414
|
+
|
415
|
+
for(dat = 0; dat < num_data; dat++)
|
416
|
+
{
|
417
|
+
for(elem = 0; elem < num_elem; elem++)
|
418
|
+
{
|
419
|
+
temp = (data[dat][elem] - old_min) * factor + new_min;
|
420
|
+
if(temp < new_min)
|
421
|
+
{
|
422
|
+
data[dat][elem] = new_min;
|
423
|
+
/*
|
424
|
+
* printf("error %f < %f\n", temp, new_min);
|
425
|
+
*/
|
426
|
+
}
|
427
|
+
else if(temp > new_max)
|
428
|
+
{
|
429
|
+
data[dat][elem] = new_max;
|
430
|
+
/*
|
431
|
+
* printf("error %f > %f\n", temp, new_max);
|
432
|
+
*/
|
433
|
+
}
|
434
|
+
else
|
435
|
+
{
|
436
|
+
data[dat][elem] = temp;
|
437
|
+
}
|
438
|
+
}
|
439
|
+
}
|
440
|
+
}
|
441
|
+
|
442
|
+
/*
|
443
|
+
* Scales the inputs in the training data to the specified range
|
444
|
+
*/
|
445
|
+
FANN_EXTERNAL void FANN_API fann_scale_input_train_data(struct fann_train_data *train_data,
|
446
|
+
fann_type new_min, fann_type new_max)
|
447
|
+
{
|
448
|
+
fann_scale_data(train_data->input, train_data->num_data, train_data->num_input, new_min,
|
449
|
+
new_max);
|
450
|
+
}
|
451
|
+
|
452
|
+
/*
|
453
|
+
* Scales the inputs in the training data to the specified range
|
454
|
+
*/
|
455
|
+
FANN_EXTERNAL void FANN_API fann_scale_output_train_data(struct fann_train_data *train_data,
|
456
|
+
fann_type new_min, fann_type new_max)
|
457
|
+
{
|
458
|
+
fann_scale_data(train_data->output, train_data->num_data, train_data->num_output, new_min,
|
459
|
+
new_max);
|
460
|
+
}
|
461
|
+
|
462
|
+
/*
|
463
|
+
* Scales the inputs in the training data to the specified range
|
464
|
+
*/
|
465
|
+
FANN_EXTERNAL void FANN_API fann_scale_train_data(struct fann_train_data *train_data,
|
466
|
+
fann_type new_min, fann_type new_max)
|
467
|
+
{
|
468
|
+
fann_scale_data(train_data->input, train_data->num_data, train_data->num_input, new_min,
|
469
|
+
new_max);
|
470
|
+
fann_scale_data(train_data->output, train_data->num_data, train_data->num_output, new_min,
|
471
|
+
new_max);
|
472
|
+
}
|
473
|
+
|
474
|
+
/*
|
475
|
+
* merges training data into a single struct.
|
476
|
+
*/
|
477
|
+
FANN_EXTERNAL struct fann_train_data *FANN_API fann_merge_train_data(struct fann_train_data *data1,
|
478
|
+
struct fann_train_data *data2)
|
479
|
+
{
|
480
|
+
unsigned int i;
|
481
|
+
fann_type *data_input, *data_output;
|
482
|
+
struct fann_train_data *dest =
|
483
|
+
(struct fann_train_data *) malloc(sizeof(struct fann_train_data));
|
484
|
+
|
485
|
+
if(dest == NULL)
|
486
|
+
{
|
487
|
+
fann_error((struct fann_error*)data1, FANN_E_CANT_ALLOCATE_MEM);
|
488
|
+
return NULL;
|
489
|
+
}
|
490
|
+
|
491
|
+
if((data1->num_input != data2->num_input) || (data1->num_output != data2->num_output))
|
492
|
+
{
|
493
|
+
fann_error((struct fann_error*)data1, FANN_E_TRAIN_DATA_MISMATCH);
|
494
|
+
return NULL;
|
495
|
+
}
|
496
|
+
|
497
|
+
fann_init_error_data((struct fann_error *) dest);
|
498
|
+
dest->error_log = data1->error_log;
|
499
|
+
|
500
|
+
dest->num_data = data1->num_data+data2->num_data;
|
501
|
+
dest->num_input = data1->num_input;
|
502
|
+
dest->num_output = data1->num_output;
|
503
|
+
dest->input = (fann_type **) calloc(dest->num_data, sizeof(fann_type *));
|
504
|
+
if(dest->input == NULL)
|
505
|
+
{
|
506
|
+
fann_error((struct fann_error*)data1, FANN_E_CANT_ALLOCATE_MEM);
|
507
|
+
fann_destroy_train(dest);
|
508
|
+
return NULL;
|
509
|
+
}
|
510
|
+
|
511
|
+
dest->output = (fann_type **) calloc(dest->num_data, sizeof(fann_type *));
|
512
|
+
if(dest->output == NULL)
|
513
|
+
{
|
514
|
+
fann_error((struct fann_error*)data1, FANN_E_CANT_ALLOCATE_MEM);
|
515
|
+
fann_destroy_train(dest);
|
516
|
+
return NULL;
|
517
|
+
}
|
518
|
+
|
519
|
+
data_input = (fann_type *) calloc(dest->num_input * dest->num_data, sizeof(fann_type));
|
520
|
+
if(data_input == NULL)
|
521
|
+
{
|
522
|
+
fann_error((struct fann_error*)data1, FANN_E_CANT_ALLOCATE_MEM);
|
523
|
+
fann_destroy_train(dest);
|
524
|
+
return NULL;
|
525
|
+
}
|
526
|
+
memcpy(data_input, data1->input[0], dest->num_input * data1->num_data * sizeof(fann_type));
|
527
|
+
memcpy(data_input + (dest->num_input*data1->num_data),
|
528
|
+
data2->input[0], dest->num_input * data2->num_data * sizeof(fann_type));
|
529
|
+
|
530
|
+
data_output = (fann_type *) calloc(dest->num_output * dest->num_data, sizeof(fann_type));
|
531
|
+
if(data_output == NULL)
|
532
|
+
{
|
533
|
+
fann_error((struct fann_error*)data1, FANN_E_CANT_ALLOCATE_MEM);
|
534
|
+
fann_destroy_train(dest);
|
535
|
+
return NULL;
|
536
|
+
}
|
537
|
+
memcpy(data_output, data1->output[0], dest->num_output * data1->num_data * sizeof(fann_type));
|
538
|
+
memcpy(data_output + (dest->num_output*data1->num_data),
|
539
|
+
data2->output[0], dest->num_output * data2->num_data * sizeof(fann_type));
|
540
|
+
|
541
|
+
for(i = 0; i != dest->num_data; i++)
|
542
|
+
{
|
543
|
+
dest->input[i] = data_input;
|
544
|
+
data_input += dest->num_input;
|
545
|
+
dest->output[i] = data_output;
|
546
|
+
data_output += dest->num_output;
|
547
|
+
}
|
548
|
+
return dest;
|
549
|
+
}
|
550
|
+
|
551
|
+
/*
|
552
|
+
* return a copy of a fann_train_data struct
|
553
|
+
*/
|
554
|
+
FANN_EXTERNAL struct fann_train_data *FANN_API fann_duplicate_train_data(struct fann_train_data
|
555
|
+
*data)
|
556
|
+
{
|
557
|
+
unsigned int i;
|
558
|
+
fann_type *data_input, *data_output;
|
559
|
+
struct fann_train_data *dest =
|
560
|
+
(struct fann_train_data *) malloc(sizeof(struct fann_train_data));
|
561
|
+
|
562
|
+
if(dest == NULL)
|
563
|
+
{
|
564
|
+
fann_error((struct fann_error*)data, FANN_E_CANT_ALLOCATE_MEM);
|
565
|
+
return NULL;
|
566
|
+
}
|
567
|
+
|
568
|
+
fann_init_error_data((struct fann_error *) dest);
|
569
|
+
dest->error_log = data->error_log;
|
570
|
+
|
571
|
+
dest->num_data = data->num_data;
|
572
|
+
dest->num_input = data->num_input;
|
573
|
+
dest->num_output = data->num_output;
|
574
|
+
dest->input = (fann_type **) calloc(dest->num_data, sizeof(fann_type *));
|
575
|
+
if(dest->input == NULL)
|
576
|
+
{
|
577
|
+
fann_error((struct fann_error*)data, FANN_E_CANT_ALLOCATE_MEM);
|
578
|
+
fann_destroy_train(dest);
|
579
|
+
return NULL;
|
580
|
+
}
|
581
|
+
|
582
|
+
dest->output = (fann_type **) calloc(dest->num_data, sizeof(fann_type *));
|
583
|
+
if(dest->output == NULL)
|
584
|
+
{
|
585
|
+
fann_error((struct fann_error*)data, FANN_E_CANT_ALLOCATE_MEM);
|
586
|
+
fann_destroy_train(dest);
|
587
|
+
return NULL;
|
588
|
+
}
|
589
|
+
|
590
|
+
data_input = (fann_type *) calloc(dest->num_input * dest->num_data, sizeof(fann_type));
|
591
|
+
if(data_input == NULL)
|
592
|
+
{
|
593
|
+
fann_error((struct fann_error*)data, FANN_E_CANT_ALLOCATE_MEM);
|
594
|
+
fann_destroy_train(dest);
|
595
|
+
return NULL;
|
596
|
+
}
|
597
|
+
memcpy(data_input, data->input[0], dest->num_input * dest->num_data * sizeof(fann_type));
|
598
|
+
|
599
|
+
data_output = (fann_type *) calloc(dest->num_output * dest->num_data, sizeof(fann_type));
|
600
|
+
if(data_output == NULL)
|
601
|
+
{
|
602
|
+
fann_error((struct fann_error*)data, FANN_E_CANT_ALLOCATE_MEM);
|
603
|
+
fann_destroy_train(dest);
|
604
|
+
return NULL;
|
605
|
+
}
|
606
|
+
memcpy(data_output, data->output[0], dest->num_output * dest->num_data * sizeof(fann_type));
|
607
|
+
|
608
|
+
for(i = 0; i != dest->num_data; i++)
|
609
|
+
{
|
610
|
+
dest->input[i] = data_input;
|
611
|
+
data_input += dest->num_input;
|
612
|
+
dest->output[i] = data_output;
|
613
|
+
data_output += dest->num_output;
|
614
|
+
}
|
615
|
+
return dest;
|
616
|
+
}
|
617
|
+
|
618
|
+
/*
 * Returns a new training set containing a copy of rows [pos, pos + length)
 * of data. An out-of-range request reports FANN_E_TRAIN_DATA_SUBSET on data;
 * allocation failures report FANN_E_CANT_ALLOCATE_MEM. Both return NULL.
 * The result uses data's error_log; free it with fann_destroy_train.
 *
 * Values are copied with one memcpy per buffer, which assumes data's rows
 * are contiguous starting at input[0]/output[0], as laid out by the
 * constructors in this file.
 */
FANN_EXTERNAL struct fann_train_data *FANN_API fann_subset_train_data(struct fann_train_data
																	  *data, unsigned int pos,
																	  unsigned int length)
{
	unsigned int i;
	fann_type *data_input, *data_output;
	struct fann_train_data *dest;

	/* Validate before allocating (so nothing leaks). "length > num_data - pos"
	 * is used instead of "pos + length > num_data" because the sum can wrap
	 * around in unsigned arithmetic and slip past the check. */
	if(pos > data->num_data || length > data->num_data - pos)
	{
		fann_error((struct fann_error*)data, FANN_E_TRAIN_DATA_SUBSET, pos, length, data->num_data);
		return NULL;
	}

	/* calloc: input/output start NULL, so fann_destroy_train is safe on the
	 * error paths below. */
	dest = (struct fann_train_data *) calloc(1, sizeof(struct fann_train_data));
	if(dest == NULL)
	{
		fann_error((struct fann_error*)data, FANN_E_CANT_ALLOCATE_MEM);
		return NULL;
	}

	fann_init_error_data((struct fann_error *) dest);
	dest->error_log = data->error_log;

	dest->num_data = length;
	dest->num_input = data->num_input;
	dest->num_output = data->num_output;
	dest->input = (fann_type **) calloc(dest->num_data, sizeof(fann_type *));
	if(dest->input == NULL)
	{
		fann_error((struct fann_error*)data, FANN_E_CANT_ALLOCATE_MEM);
		fann_destroy_train(dest);
		return NULL;
	}

	dest->output = (fann_type **) calloc(dest->num_data, sizeof(fann_type *));
	if(dest->output == NULL)
	{
		fann_error((struct fann_error*)data, FANN_E_CANT_ALLOCATE_MEM);
		fann_destroy_train(dest);
		return NULL;
	}

	data_input = (fann_type *) calloc((size_t) dest->num_input * dest->num_data, sizeof(fann_type));
	if(data_input == NULL)
	{
		fann_error((struct fann_error*)data, FANN_E_CANT_ALLOCATE_MEM);
		fann_destroy_train(dest);
		return NULL;
	}
	/* with length == 0, pos may equal num_data and input[pos] would be out of bounds */
	if(length > 0)
		memcpy(data_input, data->input[pos],
			   (size_t) dest->num_input * dest->num_data * sizeof(fann_type));

	data_output = (fann_type *) calloc((size_t) dest->num_output * dest->num_data, sizeof(fann_type));
	if(data_output == NULL)
	{
		fann_error((struct fann_error*)data, FANN_E_CANT_ALLOCATE_MEM);
		/* data_input is not reachable from dest->input[0] yet; free it here. */
		free(data_input);
		fann_destroy_train(dest);
		return NULL;
	}
	if(length > 0)
		memcpy(data_output, data->output[pos],
			   (size_t) dest->num_output * dest->num_data * sizeof(fann_type));

	for(i = 0; i != dest->num_data; i++)
	{
		dest->input[i] = data_input;
		data_input += dest->num_input;
		dest->output[i] = data_output;
		data_output += dest->num_output;
	}
	return dest;
}
|
688
|
+
|
689
|
+
/*
 * Returns the number of rows (training pairs) in data.
 */
FANN_EXTERNAL unsigned int FANN_API fann_length_train_data(struct fann_train_data *data)
{
	const unsigned int row_count = data->num_data;

	return row_count;
}
|
693
|
+
|
694
|
+
/*
 * Returns the number of input values per row in data.
 */
FANN_EXTERNAL unsigned int FANN_API fann_num_input_train_data(struct fann_train_data *data)
{
	const unsigned int inputs_per_row = data->num_input;

	return inputs_per_row;
}
|
698
|
+
|
699
|
+
/*
 * Returns the number of output values per row in data.
 */
FANN_EXTERNAL unsigned int FANN_API fann_num_output_train_data(struct fann_train_data *data)
{
	const unsigned int outputs_per_row = data->num_output;

	return outputs_per_row;
}
|
703
|
+
|
704
|
+
/* INTERNAL FUNCTION
|
705
|
+
Save the train data structure.
|
706
|
+
*/
|
707
|
+
int fann_save_train_internal(struct fann_train_data *data, const char *filename,
|
708
|
+
unsigned int save_as_fixed, unsigned int decimal_point)
|
709
|
+
{
|
710
|
+
int retval = 0;
|
711
|
+
FILE *file = fopen(filename, "w");
|
712
|
+
|
713
|
+
if(!file)
|
714
|
+
{
|
715
|
+
fann_error((struct fann_error *) data, FANN_E_CANT_OPEN_TD_W, filename);
|
716
|
+
return -1;
|
717
|
+
}
|
718
|
+
retval = fann_save_train_internal_fd(data, file, filename, save_as_fixed, decimal_point);
|
719
|
+
fclose(file);
|
720
|
+
|
721
|
+
return retval;
|
722
|
+
}
|
723
|
+
|
724
|
+
/* INTERNAL FUNCTION
|
725
|
+
Save the train data structure.
|
726
|
+
*/
|
727
|
+
int fann_save_train_internal_fd(struct fann_train_data *data, FILE * file, const char *filename,
|
728
|
+
unsigned int save_as_fixed, unsigned int decimal_point)
|
729
|
+
{
|
730
|
+
unsigned int num_data = data->num_data;
|
731
|
+
unsigned int num_input = data->num_input;
|
732
|
+
unsigned int num_output = data->num_output;
|
733
|
+
unsigned int i, j;
|
734
|
+
int retval = 0;
|
735
|
+
|
736
|
+
#ifndef FIXEDFANN
|
737
|
+
unsigned int multiplier = 1 << decimal_point;
|
738
|
+
#endif
|
739
|
+
|
740
|
+
fprintf(file, "%u %u %u\n", data->num_data, data->num_input, data->num_output);
|
741
|
+
|
742
|
+
for(i = 0; i < num_data; i++)
|
743
|
+
{
|
744
|
+
for(j = 0; j < num_input; j++)
|
745
|
+
{
|
746
|
+
#ifndef FIXEDFANN
|
747
|
+
if(save_as_fixed)
|
748
|
+
{
|
749
|
+
fprintf(file, "%d ", (int) (data->input[i][j] * multiplier));
|
750
|
+
}
|
751
|
+
else
|
752
|
+
{
|
753
|
+
if(((int) floor(data->input[i][j] + 0.5) * 1000000) ==
|
754
|
+
((int) floor(data->input[i][j] * 1000000.0 + 0.5)))
|
755
|
+
{
|
756
|
+
fprintf(file, "%d ", (int) data->input[i][j]);
|
757
|
+
}
|
758
|
+
else
|
759
|
+
{
|
760
|
+
fprintf(file, "%f ", data->input[i][j]);
|
761
|
+
}
|
762
|
+
}
|
763
|
+
#else
|
764
|
+
fprintf(file, FANNPRINTF " ", data->input[i][j]);
|
765
|
+
#endif
|
766
|
+
}
|
767
|
+
fprintf(file, "\n");
|
768
|
+
|
769
|
+
for(j = 0; j < num_output; j++)
|
770
|
+
{
|
771
|
+
#ifndef FIXEDFANN
|
772
|
+
if(save_as_fixed)
|
773
|
+
{
|
774
|
+
fprintf(file, "%d ", (int) (data->output[i][j] * multiplier));
|
775
|
+
}
|
776
|
+
else
|
777
|
+
{
|
778
|
+
if(((int) floor(data->output[i][j] + 0.5) * 1000000) ==
|
779
|
+
((int) floor(data->output[i][j] * 1000000.0 + 0.5)))
|
780
|
+
{
|
781
|
+
fprintf(file, "%d ", (int) data->output[i][j]);
|
782
|
+
}
|
783
|
+
else
|
784
|
+
{
|
785
|
+
fprintf(file, "%f ", data->output[i][j]);
|
786
|
+
}
|
787
|
+
}
|
788
|
+
#else
|
789
|
+
fprintf(file, FANNPRINTF " ", data->output[i][j]);
|
790
|
+
#endif
|
791
|
+
}
|
792
|
+
fprintf(file, "\n");
|
793
|
+
}
|
794
|
+
|
795
|
+
return retval;
|
796
|
+
}
|
797
|
+
|
798
|
+
|
799
|
+
/*
|
800
|
+
* INTERNAL FUNCTION Reads training data from a file descriptor.
|
801
|
+
*/
|
802
|
+
struct fann_train_data *fann_read_train_from_fd(FILE * file, const char *filename)
|
803
|
+
{
|
804
|
+
unsigned int num_input, num_output, num_data, i, j;
|
805
|
+
unsigned int line = 1;
|
806
|
+
fann_type *data_input, *data_output;
|
807
|
+
struct fann_train_data *data =
|
808
|
+
(struct fann_train_data *) malloc(sizeof(struct fann_train_data));
|
809
|
+
|
810
|
+
if(data == NULL)
|
811
|
+
{
|
812
|
+
fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
|
813
|
+
return NULL;
|
814
|
+
}
|
815
|
+
|
816
|
+
if(fscanf(file, "%u %u %u\n", &num_data, &num_input, &num_output) != 3)
|
817
|
+
{
|
818
|
+
fann_error(NULL, FANN_E_CANT_READ_TD, filename, line);
|
819
|
+
fann_destroy_train(data);
|
820
|
+
return NULL;
|
821
|
+
}
|
822
|
+
line++;
|
823
|
+
|
824
|
+
fann_init_error_data((struct fann_error *) data);
|
825
|
+
|
826
|
+
data->num_data = num_data;
|
827
|
+
data->num_input = num_input;
|
828
|
+
data->num_output = num_output;
|
829
|
+
data->input = (fann_type **) calloc(num_data, sizeof(fann_type *));
|
830
|
+
if(data->input == NULL)
|
831
|
+
{
|
832
|
+
fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
|
833
|
+
fann_destroy_train(data);
|
834
|
+
return NULL;
|
835
|
+
}
|
836
|
+
|
837
|
+
data->output = (fann_type **) calloc(num_data, sizeof(fann_type *));
|
838
|
+
if(data->output == NULL)
|
839
|
+
{
|
840
|
+
fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
|
841
|
+
fann_destroy_train(data);
|
842
|
+
return NULL;
|
843
|
+
}
|
844
|
+
|
845
|
+
data_input = (fann_type *) calloc(num_input * num_data, sizeof(fann_type));
|
846
|
+
if(data_input == NULL)
|
847
|
+
{
|
848
|
+
fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
|
849
|
+
fann_destroy_train(data);
|
850
|
+
return NULL;
|
851
|
+
}
|
852
|
+
|
853
|
+
data_output = (fann_type *) calloc(num_output * num_data, sizeof(fann_type));
|
854
|
+
if(data_output == NULL)
|
855
|
+
{
|
856
|
+
fann_error(NULL, FANN_E_CANT_ALLOCATE_MEM);
|
857
|
+
fann_destroy_train(data);
|
858
|
+
return NULL;
|
859
|
+
}
|
860
|
+
|
861
|
+
for(i = 0; i != num_data; i++)
|
862
|
+
{
|
863
|
+
data->input[i] = data_input;
|
864
|
+
data_input += num_input;
|
865
|
+
|
866
|
+
for(j = 0; j != num_input; j++)
|
867
|
+
{
|
868
|
+
if(fscanf(file, FANNSCANF " ", &data->input[i][j]) != 1)
|
869
|
+
{
|
870
|
+
fann_error(NULL, FANN_E_CANT_READ_TD, filename, line);
|
871
|
+
fann_destroy_train(data);
|
872
|
+
return NULL;
|
873
|
+
}
|
874
|
+
}
|
875
|
+
line++;
|
876
|
+
|
877
|
+
data->output[i] = data_output;
|
878
|
+
data_output += num_output;
|
879
|
+
|
880
|
+
for(j = 0; j != num_output; j++)
|
881
|
+
{
|
882
|
+
if(fscanf(file, FANNSCANF " ", &data->output[i][j]) != 1)
|
883
|
+
{
|
884
|
+
fann_error(NULL, FANN_E_CANT_READ_TD, filename, line);
|
885
|
+
fann_destroy_train(data);
|
886
|
+
return NULL;
|
887
|
+
}
|
888
|
+
}
|
889
|
+
line++;
|
890
|
+
}
|
891
|
+
return data;
|
892
|
+
}
|
893
|
+
|
894
|
+
/*
|
895
|
+
* INTERNAL FUNCTION returns 0 if the desired error is reached and -1 if it is not reached
|
896
|
+
*/
|
897
|
+
int fann_desired_error_reached(struct fann *ann, float desired_error)
|
898
|
+
{
|
899
|
+
switch (ann->train_stop_function)
|
900
|
+
{
|
901
|
+
case FANN_STOPFUNC_MSE:
|
902
|
+
if(fann_get_MSE(ann) <= desired_error)
|
903
|
+
return 0;
|
904
|
+
break;
|
905
|
+
case FANN_STOPFUNC_BIT:
|
906
|
+
if(ann->num_bit_fail <= (unsigned int)desired_error)
|
907
|
+
return 0;
|
908
|
+
break;
|
909
|
+
}
|
910
|
+
return -1;
|
911
|
+
}
|
912
|
+
|
913
|
+
#ifndef FIXEDFANN
|
914
|
+
/*
|
915
|
+
* Scale data in input vector before feed it to ann based on previously calculated parameters.
|
916
|
+
*/
|
917
|
+
FANN_EXTERNAL void FANN_API fann_scale_input( struct fann *ann, fann_type *input_vector )
|
918
|
+
{
|
919
|
+
unsigned cur_neuron;
|
920
|
+
if(ann->scale_mean_in == NULL)
|
921
|
+
{
|
922
|
+
fann_error( (struct fann_error *) ann, FANN_E_SCALE_NOT_PRESENT );
|
923
|
+
return;
|
924
|
+
}
|
925
|
+
|
926
|
+
for( cur_neuron = 0; cur_neuron < ann->num_input; cur_neuron++ )
|
927
|
+
input_vector[ cur_neuron ] =
|
928
|
+
(
|
929
|
+
( input_vector[ cur_neuron ] - ann->scale_mean_in[ cur_neuron ] )
|
930
|
+
/ ann->scale_deviation_in[ cur_neuron ]
|
931
|
+
- ( -1.0 ) /* This is old_min */
|
932
|
+
)
|
933
|
+
* ann->scale_factor_in[ cur_neuron ]
|
934
|
+
+ ann->scale_new_min_in[ cur_neuron ];
|
935
|
+
}
|
936
|
+
|
937
|
+
/*
|
938
|
+
* Scale data in output vector before feed it to ann based on previously calculated parameters.
|
939
|
+
*/
|
940
|
+
FANN_EXTERNAL void FANN_API fann_scale_output( struct fann *ann, fann_type *output_vector )
|
941
|
+
{
|
942
|
+
unsigned cur_neuron;
|
943
|
+
if(ann->scale_mean_in == NULL)
|
944
|
+
{
|
945
|
+
fann_error( (struct fann_error *) ann, FANN_E_SCALE_NOT_PRESENT );
|
946
|
+
return;
|
947
|
+
}
|
948
|
+
|
949
|
+
for( cur_neuron = 0; cur_neuron < ann->num_output; cur_neuron++ )
|
950
|
+
output_vector[ cur_neuron ] =
|
951
|
+
(
|
952
|
+
( output_vector[ cur_neuron ] - ann->scale_mean_out[ cur_neuron ] )
|
953
|
+
/ ann->scale_deviation_out[ cur_neuron ]
|
954
|
+
- ( -1.0 ) /* This is old_min */
|
955
|
+
)
|
956
|
+
* ann->scale_factor_out[ cur_neuron ]
|
957
|
+
+ ann->scale_new_min_out[ cur_neuron ];
|
958
|
+
}
|
959
|
+
|
960
|
+
/*
|
961
|
+
* Descale data in input vector after based on previously calculated parameters.
|
962
|
+
*/
|
963
|
+
FANN_EXTERNAL void FANN_API fann_descale_input( struct fann *ann, fann_type *input_vector )
|
964
|
+
{
|
965
|
+
unsigned cur_neuron;
|
966
|
+
if(ann->scale_mean_in == NULL)
|
967
|
+
{
|
968
|
+
fann_error( (struct fann_error *) ann, FANN_E_SCALE_NOT_PRESENT );
|
969
|
+
return;
|
970
|
+
}
|
971
|
+
|
972
|
+
for( cur_neuron = 0; cur_neuron < ann->num_input; cur_neuron++ )
|
973
|
+
input_vector[ cur_neuron ] =
|
974
|
+
(
|
975
|
+
(
|
976
|
+
input_vector[ cur_neuron ]
|
977
|
+
- ann->scale_new_min_in[ cur_neuron ]
|
978
|
+
)
|
979
|
+
/ ann->scale_factor_in[ cur_neuron ]
|
980
|
+
+ ( -1.0 ) /* This is old_min */
|
981
|
+
)
|
982
|
+
* ann->scale_deviation_in[ cur_neuron ]
|
983
|
+
+ ann->scale_mean_in[ cur_neuron ];
|
984
|
+
}
|
985
|
+
|
986
|
+
/*
|
987
|
+
* Descale data in output vector after get it from ann based on previously calculated parameters.
|
988
|
+
*/
|
989
|
+
FANN_EXTERNAL void FANN_API fann_descale_output( struct fann *ann, fann_type *output_vector )
|
990
|
+
{
|
991
|
+
unsigned cur_neuron;
|
992
|
+
if(ann->scale_mean_in == NULL)
|
993
|
+
{
|
994
|
+
fann_error( (struct fann_error *) ann, FANN_E_SCALE_NOT_PRESENT );
|
995
|
+
return;
|
996
|
+
}
|
997
|
+
|
998
|
+
for( cur_neuron = 0; cur_neuron < ann->num_output; cur_neuron++ )
|
999
|
+
output_vector[ cur_neuron ] =
|
1000
|
+
(
|
1001
|
+
(
|
1002
|
+
output_vector[ cur_neuron ]
|
1003
|
+
- ann->scale_new_min_out[ cur_neuron ]
|
1004
|
+
)
|
1005
|
+
/ ann->scale_factor_out[ cur_neuron ]
|
1006
|
+
+ ( -1.0 ) /* This is old_min */
|
1007
|
+
)
|
1008
|
+
* ann->scale_deviation_out[ cur_neuron ]
|
1009
|
+
+ ann->scale_mean_out[ cur_neuron ];
|
1010
|
+
}
|
1011
|
+
|
1012
|
+
/*
|
1013
|
+
* Scale input and output data based on previously calculated parameters.
|
1014
|
+
*/
|
1015
|
+
FANN_EXTERNAL void FANN_API fann_scale_train( struct fann *ann, struct fann_train_data *data )
|
1016
|
+
{
|
1017
|
+
unsigned cur_sample;
|
1018
|
+
if(ann->scale_mean_in == NULL)
|
1019
|
+
{
|
1020
|
+
fann_error( (struct fann_error *) ann, FANN_E_SCALE_NOT_PRESENT );
|
1021
|
+
return;
|
1022
|
+
}
|
1023
|
+
/* Check that we have good training data. */
|
1024
|
+
/* No need for if( !params || !ann ) */
|
1025
|
+
if( data->num_input != ann->num_input
|
1026
|
+
|| data->num_output != ann->num_output
|
1027
|
+
)
|
1028
|
+
{
|
1029
|
+
fann_error( (struct fann_error *) ann, FANN_E_TRAIN_DATA_MISMATCH );
|
1030
|
+
return;
|
1031
|
+
}
|
1032
|
+
|
1033
|
+
for( cur_sample = 0; cur_sample < data->num_data; cur_sample++ )
|
1034
|
+
{
|
1035
|
+
fann_scale_input( ann, data->input[ cur_sample ] );
|
1036
|
+
fann_scale_output( ann, data->output[ cur_sample ] );
|
1037
|
+
}
|
1038
|
+
}
|
1039
|
+
|
1040
|
+
/*
|
1041
|
+
* Scale input and output data based on previously calculated parameters.
|
1042
|
+
*/
|
1043
|
+
FANN_EXTERNAL void FANN_API fann_descale_train( struct fann *ann, struct fann_train_data *data )
|
1044
|
+
{
|
1045
|
+
unsigned cur_sample;
|
1046
|
+
if(ann->scale_mean_in == NULL)
|
1047
|
+
{
|
1048
|
+
fann_error( (struct fann_error *) ann, FANN_E_SCALE_NOT_PRESENT );
|
1049
|
+
return;
|
1050
|
+
}
|
1051
|
+
/* Check that we have good training data. */
|
1052
|
+
/* No need for if( !params || !ann ) */
|
1053
|
+
if( data->num_input != ann->num_input
|
1054
|
+
|| data->num_output != ann->num_output
|
1055
|
+
)
|
1056
|
+
{
|
1057
|
+
fann_error( (struct fann_error *) ann, FANN_E_TRAIN_DATA_MISMATCH );
|
1058
|
+
return;
|
1059
|
+
}
|
1060
|
+
|
1061
|
+
for( cur_sample = 0; cur_sample < data->num_data; cur_sample++ )
|
1062
|
+
{
|
1063
|
+
fann_descale_input( ann, data->input[ cur_sample ] );
|
1064
|
+
fann_descale_output( ann, data->output[ cur_sample ] );
|
1065
|
+
}
|
1066
|
+
}
|
1067
|
+
|
1068
|
+
/*
 * SCALE_RESET(what, where, default_value): store default_value into every
 * element of ann->what_where, where `where` is `in` or `out` and the array
 * length is ann->num_input or ann->num_output respectively.  Uses the
 * caller's `ann` and `cur_neuron`.  Expands to a braced compound statement;
 * the call sites below use it without a trailing semicolon.
 */
#define SCALE_RESET( what, where, default_value ) \
	{ \
		for( cur_neuron = 0; cur_neuron != ann->num_##where##put; ++cur_neuron ) \
			ann->what##_##where[ cur_neuron ] = ( default_value ); \
	}
|
1071
|
+
|
1072
|
+
/*
 * SCALE_SET_PARAM(where): derive the scaling parameters for the `in` or `out`
 * side of the network from the training data.
 *
 * Must be expanded where `ann`, `data`, `cur_neuron`, `cur_sample` and
 * new_<where>put_min / new_<where>put_max are in scope.  For every neuron it
 * stores:
 *   scale_mean      - mean of that column over all samples;
 *   scale_deviation - population standard deviation of that column, or 1.0
 *                     when the column is constant (a zero deviation would
 *                     make the scale functions divide by zero and emit
 *                     NaN/inf);
 *   scale_factor    - (new_max - new_min) / (old_max - old_min), old range [-1, 1];
 *   scale_new_min   - new_min.
 * Requires data->num_data > 0 (the callers branch on this first).
 * Samples are the outer loop so each training row is read contiguously; the
 * per-neuron summation order is still sample 0..n-1.
 */
#define SCALE_SET_PARAM( where ) \
	do { \
		/* Reset the accumulators. */ \
		for( cur_neuron = 0; cur_neuron < ann->num_##where##put; cur_neuron++ ) \
		{ \
			ann->scale_mean_##where[ cur_neuron ] = 0.0; \
			ann->scale_deviation_##where[ cur_neuron ] = 0.0; \
		} \
		/* Mean: sum(x)/length */ \
		for( cur_sample = 0; cur_sample < data->num_data; cur_sample++ ) \
			for( cur_neuron = 0; cur_neuron < ann->num_##where##put; cur_neuron++ ) \
				ann->scale_mean_##where[ cur_neuron ] += data->where##put[ cur_sample ][ cur_neuron ]; \
		for( cur_neuron = 0; cur_neuron < ann->num_##where##put; cur_neuron++ ) \
			ann->scale_mean_##where[ cur_neuron ] /= (float)data->num_data; \
		/* Deviation: sqrt(sum((x-mean)^2)/length) */ \
		for( cur_sample = 0; cur_sample < data->num_data; cur_sample++ ) \
			for( cur_neuron = 0; cur_neuron < ann->num_##where##put; cur_neuron++ ) \
			{ \
				fann_type scale_diff_ = data->where##put[ cur_sample ][ cur_neuron ] \
					- ann->scale_mean_##where[ cur_neuron ]; \
				ann->scale_deviation_##where[ cur_neuron ] += scale_diff_ * scale_diff_; \
			} \
		for( cur_neuron = 0; cur_neuron < ann->num_##where##put; cur_neuron++ ) \
		{ \
			ann->scale_deviation_##where[ cur_neuron ] = \
				sqrt( ann->scale_deviation_##where[ cur_neuron ] / (float)data->num_data ); \
			/* Constant column: use 1 so scaling never divides by zero. */ \
			if( ann->scale_deviation_##where[ cur_neuron ] == 0.0 ) \
				ann->scale_deviation_##where[ cur_neuron ] = 1.0; \
			/* Factor: (new_max-new_min)/(old_max(1)-old_min(-1)) */ \
			ann->scale_factor_##where[ cur_neuron ] = \
				( new_##where##put_max - new_##where##put_min ) / ( 1.0 - ( -1.0 ) ); \
			ann->scale_new_min_##where[ cur_neuron ] = new_##where##put_min; \
		} \
	} while( 0 )
|
1111
|
+
|
1112
|
+
/*
 * Compute the input scaling parameters of ann from the training data, so
 * that inputs can later be mapped towards [new_input_min, new_input_max].
 *
 * The scaling arrays are allocated on first use.  With an empty data set the
 * defaults mean 0, deviation 1, new_min -1 and factor 1 are stored instead.
 *
 * Returns 0 on success, or -1 when the data does not match the network's
 * input/output widths (FANN_E_TRAIN_DATA_MISMATCH is reported) or the scaling
 * arrays could not be allocated.
 */
FANN_EXTERNAL int FANN_API fann_set_input_scaling_params(
	struct fann *ann,
	const struct fann_train_data *data,
	float new_input_min,
	float new_input_max)
{
	/* Loop counters consumed by the SCALE_RESET / SCALE_SET_PARAM macros. */
	unsigned cur_neuron, cur_sample;

	if(data->num_input != ann->num_input || data->num_output != ann->num_output)
	{
		fann_error( (struct fann_error *) ann, FANN_E_TRAIN_DATA_MISMATCH );
		return -1;
	}

	if(ann->scale_mean_in == NULL)
	{
		fann_allocate_scale(ann);
		if(ann->scale_mean_in == NULL)
			return -1;
	}

	if(data->num_data > 0)
	{
		SCALE_SET_PARAM( in );
	}
	else
	{
		/* Nothing to learn from: store the default parameters. */
		SCALE_RESET( scale_mean, in, 0.0 )
		SCALE_RESET( scale_deviation, in, 1.0 )
		SCALE_RESET( scale_new_min, in, -1.0 )
		SCALE_RESET( scale_factor, in, 1.0 )
	}

	return 0;
}
|
1149
|
+
|
1150
|
+
/*
 * Compute the output scaling parameters of ann from the training data, so
 * that outputs can later be mapped towards [new_output_min, new_output_max].
 *
 * The scaling arrays are allocated on first use.  With an empty data set the
 * defaults mean 0, deviation 1, new_min -1 and factor 1 are stored instead.
 *
 * Returns 0 on success, or -1 when the data does not match the network's
 * input/output widths (FANN_E_TRAIN_DATA_MISMATCH is reported) or the scaling
 * arrays could not be allocated.
 */
FANN_EXTERNAL int FANN_API fann_set_output_scaling_params(
	struct fann *ann,
	const struct fann_train_data *data,
	float new_output_min,
	float new_output_max)
{
	/* Loop counters consumed by the SCALE_RESET / SCALE_SET_PARAM macros. */
	unsigned cur_neuron, cur_sample;

	if(data->num_input != ann->num_input || data->num_output != ann->num_output)
	{
		fann_error( (struct fann_error *) ann, FANN_E_TRAIN_DATA_MISMATCH );
		return -1;
	}

	if(ann->scale_mean_out == NULL)
	{
		fann_allocate_scale(ann);
		if(ann->scale_mean_out == NULL)
			return -1;
	}

	if(data->num_data > 0)
	{
		SCALE_SET_PARAM( out );
	}
	else
	{
		/* Nothing to learn from: store the default parameters. */
		SCALE_RESET( scale_mean, out, 0.0 )
		SCALE_RESET( scale_deviation, out, 1.0 )
		SCALE_RESET( scale_new_min, out, -1.0 )
		SCALE_RESET( scale_factor, out, 1.0 )
	}

	return 0;
}
|
1187
|
+
|
1188
|
+
/*
|
1189
|
+
* Calculate scaling parameters for future use based on training data.
|
1190
|
+
*/
|
1191
|
+
FANN_EXTERNAL int FANN_API fann_set_scaling_params(
|
1192
|
+
struct fann *ann,
|
1193
|
+
const struct fann_train_data *data,
|
1194
|
+
float new_input_min,
|
1195
|
+
float new_input_max,
|
1196
|
+
float new_output_min,
|
1197
|
+
float new_output_max)
|
1198
|
+
{
|
1199
|
+
if(fann_set_input_scaling_params(ann, data, new_input_min, new_input_max) == 0)
|
1200
|
+
return fann_set_output_scaling_params(ann, data, new_output_min, new_output_max);
|
1201
|
+
else
|
1202
|
+
return -1;
|
1203
|
+
}
|
1204
|
+
|
1205
|
+
/*
|
1206
|
+
* Clears scaling parameters.
|
1207
|
+
*/
|
1208
|
+
FANN_EXTERNAL int FANN_API fann_clear_scaling_params(struct fann *ann)
|
1209
|
+
{
|
1210
|
+
unsigned cur_neuron;
|
1211
|
+
|
1212
|
+
if(ann->scale_mean_out == NULL)
|
1213
|
+
fann_allocate_scale(ann);
|
1214
|
+
|
1215
|
+
if(ann->scale_mean_out == NULL)
|
1216
|
+
return -1;
|
1217
|
+
|
1218
|
+
SCALE_RESET( scale_mean, in, 0.0 )
|
1219
|
+
SCALE_RESET( scale_deviation, in, 1.0 )
|
1220
|
+
SCALE_RESET( scale_new_min, in, -1.0 )
|
1221
|
+
SCALE_RESET( scale_factor, in, 1.0 )
|
1222
|
+
|
1223
|
+
SCALE_RESET( scale_mean, out, 0.0 )
|
1224
|
+
SCALE_RESET( scale_deviation, out, 1.0 )
|
1225
|
+
SCALE_RESET( scale_new_min, out, -1.0 )
|
1226
|
+
SCALE_RESET( scale_factor, out, 1.0 )
|
1227
|
+
|
1228
|
+
return 0;
|
1229
|
+
}
|
1230
|
+
|
1231
|
+
#endif
|