tlearn 0.0.1
Sign up to get free protection for your applications and to get access to all the features.
- data/ext/tlearn/Exp/exp.c +13 -0
- data/ext/tlearn/activate.c +222 -0
- data/ext/tlearn/arrays.c +224 -0
- data/ext/tlearn/compute.c +404 -0
- data/ext/tlearn/extconf.rb +14 -0
- data/ext/tlearn/getopt.c +76 -0
- data/ext/tlearn/parse.c +594 -0
- data/ext/tlearn/subs.c +204 -0
- data/ext/tlearn/tlearn.c +525 -0
- data/ext/tlearn/tlearn_ext.c +587 -0
- data/ext/tlearn/update.c +577 -0
- data/ext/tlearn/weights.c +116 -0
- data/lib/tlearn.rb +17 -0
- data/lib/tlearn/config.rb +101 -0
- data/lib/tlearn/fitness_data.rb +24 -0
- data/lib/tlearn/run.rb +29 -0
- data/lib/tlearn/run_tlearn.rb +68 -0
- data/lib/tlearn/training_data.rb +41 -0
- metadata +64 -0
@@ -0,0 +1,404 @@
|
|
1
|
+
#include <math.h>
#include <stdio.h>

#ifdef ibmpc
extern char far *malloc();
#else
#include <stdlib.h>	/* prototypes for exit(), free() and malloc() */
extern void *malloc();
#endif
|
9
|
+
|
10
|
+
|
11
|
+
/*
 * Network dimensions and options.  These are declared extern here, so they
 * are defined in another translation unit.
 */
extern int nn;	/* number of nodes */
extern int ni;	/* number of inputs */
extern int no;	/* number of outputs */
extern int nt;	/* nn + ni + 1 */
extern int np;	/* ni + 1 */
extern int ce;	/* cross-entropy flag */

/*
 * Per-connection information, one entry per (destination node, source)
 * pair.  The cinfo array below is shared with other translation units, so
 * this member list must stay identical to the definition they use.
 */
struct cf {
	int con;	/* connection flag */
	int fix;	/* fixed-weight flag */
	int num;	/* group number */
	int lim;	/* weight limits */
	float min;	/* weight minimum */
	float max;	/* weight maximum */
};

/*
 * Per-node information.  In this file func selects the derivative used:
 * 0 -> z(1-z), 1 -> .5(1+z)(1-z), 2 -> linear, 3 -> no error (backprop).
 * No 'extern' here: a storage class on a declaration that declares no
 * object is meaningless and only draws a compiler diagnostic.
 */
struct nf {
	int func;	/* activation function type */
	int dela;	/* delay flag */
	int targ;	/* target flag */
};

extern struct cf **cinfo;	/* (nn x nt) connection info */
extern struct nf *ninfo;	/* (nn) node activation function info */

extern int *outputs;	/* (no) 1-based node numbers of the output nodes */

extern int localist;	/* flag for localist input */
|
39
|
+
|
40
|
+
comp_errors(aold,atarget,aerror,e,ce_e)
|
41
|
+
float *aold;
|
42
|
+
float *atarget;
|
43
|
+
float *aerror;
|
44
|
+
float *e;
|
45
|
+
float *ce_e;
|
46
|
+
{
|
47
|
+
extern int ce;
|
48
|
+
|
49
|
+
register int i;
|
50
|
+
register int j;
|
51
|
+
register float *ta;
|
52
|
+
register float *te;
|
53
|
+
register float *ce_te;
|
54
|
+
register float *ee;
|
55
|
+
register int *op;
|
56
|
+
|
57
|
+
static float *terror = 0;
|
58
|
+
static float *ce_terror = 0;
|
59
|
+
|
60
|
+
if (terror == 0){
|
61
|
+
/* malloc space for local copy of error info */
|
62
|
+
terror = (float *) malloc(no * sizeof(float));
|
63
|
+
if (terror == NULL){
|
64
|
+
perror("terror malloc failed");
|
65
|
+
exit(1);
|
66
|
+
}
|
67
|
+
}
|
68
|
+
if (ce_terror == 0){
|
69
|
+
/* malloc space for local copy of cross-entropy info */
|
70
|
+
ce_terror = (float *) malloc(no * sizeof(float));
|
71
|
+
if (ce_terror == NULL){
|
72
|
+
perror("ce_terror malloc failed");
|
73
|
+
exit(1);
|
74
|
+
}
|
75
|
+
}
|
76
|
+
|
77
|
+
|
78
|
+
te = terror;
|
79
|
+
ce_te = ce_terror;
|
80
|
+
ta = atarget;
|
81
|
+
op = outputs;
|
82
|
+
for (i = 0; i < no; i++, te++, ce_te++, ta++, op++){
|
83
|
+
if (*ta != -9999.0) {
|
84
|
+
*te = *(aold + ni + *op) - *ta;
|
85
|
+
/*
|
86
|
+
* if collecting cross-entropy statistics;
|
87
|
+
*/
|
88
|
+
if (ce == 2) {
|
89
|
+
*ce_te = *ta * log(*(aold+ni+ *op))/log(2.0) +
|
90
|
+
(1- *ta) * log(1- *(aold+ni+ *op))/log(2.0);
|
91
|
+
}
|
92
|
+
} else {
|
93
|
+
*te = 0.;
|
94
|
+
}
|
95
|
+
*e += *te * *te; /* cumulative ss error */
|
96
|
+
*ce_e += *ce_te; /* cumulate cross-entropy error */
|
97
|
+
}
|
98
|
+
ee = aerror;
|
99
|
+
for (i = 1; i <= nn; i++, ee++){
|
100
|
+
*ee = 0.;
|
101
|
+
te = terror;
|
102
|
+
op = outputs;
|
103
|
+
for (j = 0; j < no; j++, te++, op++){
|
104
|
+
if (*op == i){
|
105
|
+
*ee = *te;
|
106
|
+
break;
|
107
|
+
}
|
108
|
+
}
|
109
|
+
}
|
110
|
+
}
|
111
|
+
|
112
|
+
|
113
|
+
comp_deltas(apold,apnew,awt,adwt,aold,anew,aerror)
|
114
|
+
float ***apold;
|
115
|
+
float ***apnew;
|
116
|
+
float **awt;
|
117
|
+
float **adwt;
|
118
|
+
float *aold;
|
119
|
+
float *anew;
|
120
|
+
float *aerror;
|
121
|
+
{
|
122
|
+
register int i;
|
123
|
+
register int j;
|
124
|
+
register int k;
|
125
|
+
register int l;
|
126
|
+
|
127
|
+
register struct cf **cp;
|
128
|
+
|
129
|
+
register struct cf *ci;
|
130
|
+
register struct nf *n;
|
131
|
+
|
132
|
+
register float **wp;
|
133
|
+
register float *zn;
|
134
|
+
register float *pn;
|
135
|
+
register float *po;
|
136
|
+
register float **pnp;
|
137
|
+
register float **pop;
|
138
|
+
register float ***pnpp;
|
139
|
+
register float ***popp;
|
140
|
+
register float *w;
|
141
|
+
|
142
|
+
register float *sum;
|
143
|
+
|
144
|
+
register float *e;
|
145
|
+
|
146
|
+
float asum;
|
147
|
+
|
148
|
+
/* to each node */
|
149
|
+
sum = &asum;
|
150
|
+
cp = cinfo;
|
151
|
+
pnpp = apnew;
|
152
|
+
popp = apold;
|
153
|
+
for (i = 0; i < nn; i++, cp++, pnpp++, popp++){
|
154
|
+
ci = *cp;
|
155
|
+
pnp = *pnpp;
|
156
|
+
pop = *popp;
|
157
|
+
/* from each bias, input, and node */
|
158
|
+
for (j = 0; j < nt; j++, ci++, pnp++, pop++){
|
159
|
+
if (ci->con == 0)
|
160
|
+
continue;
|
161
|
+
pn = *pnp;
|
162
|
+
zn = anew + np;
|
163
|
+
n = ninfo;
|
164
|
+
/* for each node */
|
165
|
+
for (k = 0; k < nn; k++, zn++, pn++, n++){
|
166
|
+
w = *(awt + k) + np;
|
167
|
+
po = *pop;
|
168
|
+
if (i == k)
|
169
|
+
*sum = *(aold + j);
|
170
|
+
else
|
171
|
+
*sum = 0.;
|
172
|
+
/* from each node */
|
173
|
+
for (l = 0; l < nn; l++, w++, po++){
|
174
|
+
*sum += *w * *po;
|
175
|
+
}
|
176
|
+
if (n->func == 0)
|
177
|
+
*pn = *zn * (1. - *zn) * *sum;
|
178
|
+
else if (n->func == 1)
|
179
|
+
*pn = .5 * (1. + *zn)*(1. - *zn) * *sum;
|
180
|
+
else if (n->func == 2){
|
181
|
+
*pn = *sum;
|
182
|
+
}
|
183
|
+
if (n->dela == 0)
|
184
|
+
*(*(*(apold + i) + j) + k) = *pn;
|
185
|
+
}
|
186
|
+
}
|
187
|
+
}
|
188
|
+
/* to each node */
|
189
|
+
cp = cinfo;
|
190
|
+
wp = adwt;
|
191
|
+
pnpp = apnew;
|
192
|
+
popp = apold;
|
193
|
+
for (i = 0; i < nn; i++, cp++, wp++, pnpp++, popp++){
|
194
|
+
w = *wp;
|
195
|
+
ci = *cp;
|
196
|
+
pnp= *pnpp;
|
197
|
+
pop= *popp;
|
198
|
+
/* from each bias, input, and node */
|
199
|
+
for (j = 0; j < nt; j++, w++, ci++, pnp++, pop++){
|
200
|
+
if (ci->con == 0)
|
201
|
+
continue;
|
202
|
+
e = aerror;
|
203
|
+
pn = *pnp;
|
204
|
+
po = *pop;
|
205
|
+
*sum = 0.;
|
206
|
+
/* for each node */
|
207
|
+
for (k = 0; k < nn; k++, e++, po++, pn++){
|
208
|
+
*sum += *e * *po;
|
209
|
+
*po = *pn;
|
210
|
+
}
|
211
|
+
*w -= *sum;
|
212
|
+
}
|
213
|
+
}
|
214
|
+
|
215
|
+
return;
|
216
|
+
}
|
217
|
+
|
218
|
+
comp_backprop(awt,adwt,aold,amem,atarget,aerror,local)
|
219
|
+
float **awt;
|
220
|
+
float **adwt;
|
221
|
+
float *aold;
|
222
|
+
float *amem;
|
223
|
+
float *atarget;
|
224
|
+
float *aerror;
|
225
|
+
int *local;
|
226
|
+
{
|
227
|
+
register int i;
|
228
|
+
register int j;
|
229
|
+
|
230
|
+
register struct cf **cp;
|
231
|
+
|
232
|
+
register struct cf *ci;
|
233
|
+
register struct nf *n;
|
234
|
+
|
235
|
+
register float *sum;
|
236
|
+
|
237
|
+
float **wp;
|
238
|
+
float *ee;
|
239
|
+
float *e;
|
240
|
+
float *w;
|
241
|
+
float *z;
|
242
|
+
float *oz;
|
243
|
+
float *t;
|
244
|
+
|
245
|
+
int *l;
|
246
|
+
int ns;
|
247
|
+
|
248
|
+
float asum;
|
249
|
+
|
250
|
+
/* compute deltas for output units */
|
251
|
+
sum = &asum;
|
252
|
+
e = aerror;
|
253
|
+
n = ninfo;
|
254
|
+
z = aold + np;
|
255
|
+
t = atarget;
|
256
|
+
for (i = 0; i < nn; i++, e++, n++, z++){
|
257
|
+
if (n->targ == 0)
|
258
|
+
continue;
|
259
|
+
if (n->func == 0) {
|
260
|
+
if (ce > 0) { /* if cross-entropy */
|
261
|
+
/*
|
262
|
+
* note that the following collapses
|
263
|
+
* (t-a) and derivative of slope; we
|
264
|
+
* therefore ignore current contents of
|
265
|
+
* *e (which is (t-a)) and assign new
|
266
|
+
* value, whereas with sse, we multiply *e
|
267
|
+
* by deriv. of slope.
|
268
|
+
*/
|
269
|
+
*e = *t - *z;
|
270
|
+
/* NOTE: this is a kludge -- only increments
|
271
|
+
* target when node is an output node. Do
|
272
|
+
* NOT move into for() control statement.
|
273
|
+
*/
|
274
|
+
t++;
|
275
|
+
} else { /* otherwise normal sse-delta */
|
276
|
+
*e *= *z * (1. - *z);
|
277
|
+
}
|
278
|
+
} else if (n->func == 1)
|
279
|
+
*e *= .5 * (1. + *z) * (1. - *z);
|
280
|
+
}
|
281
|
+
|
282
|
+
n = ninfo + nn - 1;
|
283
|
+
z = aold + nt - 1;
|
284
|
+
e = aerror + nn - 1;
|
285
|
+
/* compute deltas for remaining units */
|
286
|
+
for (i = nn - 1; i >= 0; i--, z--, e--, n--){
|
287
|
+
if (n->targ == 1)
|
288
|
+
continue;
|
289
|
+
*sum = 0.;
|
290
|
+
/* ee contains a bad address for i = nn-1 */
|
291
|
+
ee = aerror + i + 1;
|
292
|
+
for (j = i + 1; j < nn; j++, ee++){
|
293
|
+
w = *(awt + j) + np + i;
|
294
|
+
ci = *(cinfo + j) + np + i;
|
295
|
+
if (ci->con)
|
296
|
+
*sum += *w * *ee;
|
297
|
+
}
|
298
|
+
if (n->func == 0)
|
299
|
+
*e = *sum * *z * (1. - *z);
|
300
|
+
else if (n->func == 1)
|
301
|
+
*e = *sum * .5 * (1. + *z) * (1. - *z);
|
302
|
+
else if (n->func == 2){
|
303
|
+
*e = *sum;
|
304
|
+
}
|
305
|
+
else if (n->func == 3)
|
306
|
+
*e = 0.;
|
307
|
+
}
|
308
|
+
|
309
|
+
/* compute weight changes for all connections */
|
310
|
+
|
311
|
+
/* to each node */
|
312
|
+
e = aerror;
|
313
|
+
cp = cinfo;
|
314
|
+
wp = adwt;
|
315
|
+
for (i = 0; i < nn; i++, e++, cp++, wp++){
|
316
|
+
if (localist){
|
317
|
+
if (ce > 0){
|
318
|
+
if ((*cp)->con)
|
319
|
+
**wp += *e;
|
320
|
+
}
|
321
|
+
else {
|
322
|
+
if ((*cp)->con)
|
323
|
+
**wp -= *e;
|
324
|
+
}
|
325
|
+
l = local;
|
326
|
+
while (*l != 0){
|
327
|
+
if (ce > 0){
|
328
|
+
if ((*cp + *l)->con)
|
329
|
+
*(*wp + *l) += *e;
|
330
|
+
}
|
331
|
+
else {
|
332
|
+
if ((*cp + *l)->con)
|
333
|
+
*(*wp + *l) -= *e;
|
334
|
+
}
|
335
|
+
l++;
|
336
|
+
}
|
337
|
+
w = *wp + np;
|
338
|
+
ci = *cp + np;
|
339
|
+
z = aold + np;
|
340
|
+
oz = amem + np;
|
341
|
+
/* from each node */
|
342
|
+
/* loop is broken into two parts:
|
343
|
+
(1) connections from nodes of lower node-number
|
344
|
+
(2) connections from nodes of = or > node-number
|
345
|
+
the latter case requires use of old z values */
|
346
|
+
if (ce > 0){
|
347
|
+
for (j = 0; j < i; j++, w++, ci++, z++, oz++){
|
348
|
+
if (ci->con)
|
349
|
+
*w += *z * *e;
|
350
|
+
}
|
351
|
+
for (j = i; j < nn; j++, w++, ci++, z++, oz++){
|
352
|
+
if (ci->con)
|
353
|
+
*w += *oz * *e;
|
354
|
+
}
|
355
|
+
}
|
356
|
+
else {
|
357
|
+
for (j = 0; j < i; j++, w++, ci++, z++, oz++){
|
358
|
+
if (ci->con)
|
359
|
+
*w -= *z * *e;
|
360
|
+
}
|
361
|
+
for (j = i; j < nn; j++, w++, ci++, z++, oz++){
|
362
|
+
if (ci->con)
|
363
|
+
*w -= *oz * *e;
|
364
|
+
}
|
365
|
+
}
|
366
|
+
}
|
367
|
+
else {
|
368
|
+
w = *wp;
|
369
|
+
ci = *cp;
|
370
|
+
z = aold;
|
371
|
+
oz = amem;
|
372
|
+
/* from each bias, input, and node */
|
373
|
+
ns = np + i;
|
374
|
+
/* loop is broken into two parts:
|
375
|
+
(1) connections from nodes of lower node-number
|
376
|
+
(2) connections from nodes of = or > node-number
|
377
|
+
the latter case requires use of old z values */
|
378
|
+
if (ce > 0){
|
379
|
+
for (j = 0; j < ns; j++, w++, ci++, z++, oz++){
|
380
|
+
if (ci->con)
|
381
|
+
*w += *z * *e;
|
382
|
+
}
|
383
|
+
for (j = ns; j < nt; j++, w++, ci++, z++, oz++){
|
384
|
+
if (ci->con)
|
385
|
+
*w += *oz * *e;
|
386
|
+
}
|
387
|
+
}
|
388
|
+
else {
|
389
|
+
for (j = 0; j < ns; j++, w++, ci++, z++, oz++){
|
390
|
+
if (ci->con)
|
391
|
+
*w -= *z * *e;
|
392
|
+
}
|
393
|
+
for (j = ns; j < nt; j++, w++, ci++, z++, oz++){
|
394
|
+
if (ci->con)
|
395
|
+
*w -= *oz * *e;
|
396
|
+
}
|
397
|
+
}
|
398
|
+
}
|
399
|
+
}
|
400
|
+
|
401
|
+
return;
|
402
|
+
}
|
403
|
+
|
404
|
+
|
data/ext/tlearn/getopt.c
ADDED
@@ -0,0 +1,76 @@
|
|
1
|
+
#ifdef THINK_C
#define ibmpc
#endif /* THINK_C */
/*
 * Everything from here to the matching #endif at the end of this file is
 * compiled only when ibmpc is defined (THINK_C builds define it above).
 */
#ifdef ibmpc
#ifndef lint
static char sccsid[] = "@(#)getopt.c 1.1 86/09/24 SMI"; /* from S5R2 1.5 */
#endif

/*LINTLIBRARY*/
#ifndef EOF
#define EOF (-1)
#endif

/*
 * ERR: when opterr is set, write "<argv[0]><s><c>\n" to file descriptor 2.
 * Wrapped in do/while(0) so it behaves as one statement (safe before else).
 */
#define ERR(s, c) do { \
	if (opterr) { \
		extern int strlen(), write(); \
		char errbuf[2]; \
		errbuf[0] = (c); errbuf[1] = '\n'; \
		(void) write(2, argv[0], (unsigned)strlen(argv[0])); \
		(void) write(2, (s), (unsigned)strlen(s)); \
		(void) write(2, errbuf, 2); \
	} \
} while (0)

extern int strcmp();
extern char *strchr();

int opterr = 1;	/* nonzero: report errors on stderr */
int optind = 1;	/* index of the next argv element to scan */
int optopt;	/* option character most recently examined */
char *optarg;	/* argument of the current option, if any */
|
29
|
+
|
30
|
+
int
|
31
|
+
getopt(argc, argv, opts)
|
32
|
+
int argc;
|
33
|
+
char **argv, *opts;
|
34
|
+
{
|
35
|
+
static int sp = 1;
|
36
|
+
register int c;
|
37
|
+
register char *cp;
|
38
|
+
|
39
|
+
if(sp == 1)
|
40
|
+
if(optind >= argc ||
|
41
|
+
argv[optind][0] != '-' || argv[optind][1] == '\0')
|
42
|
+
return(EOF);
|
43
|
+
else if(strcmp(argv[optind], "--") == 0) {
|
44
|
+
optind++;
|
45
|
+
return(EOF);
|
46
|
+
}
|
47
|
+
optopt = c = argv[optind][sp];
|
48
|
+
if(c == ':' || (cp=strchr(opts, c)) == NULL) {
|
49
|
+
ERR(": illegal option -- ", c);
|
50
|
+
if(argv[optind][++sp] == '\0') {
|
51
|
+
optind++;
|
52
|
+
sp = 1;
|
53
|
+
}
|
54
|
+
return('?');
|
55
|
+
}
|
56
|
+
if(*++cp == ':') {
|
57
|
+
if(argv[optind][sp+1] != '\0')
|
58
|
+
optarg = &argv[optind++][sp+1];
|
59
|
+
else if(++optind >= argc) {
|
60
|
+
ERR(": option requires an argument -- ", c);
|
61
|
+
sp = 1;
|
62
|
+
return('?');
|
63
|
+
} else
|
64
|
+
optarg = argv[optind++];
|
65
|
+
sp = 1;
|
66
|
+
} else {
|
67
|
+
if(argv[optind][++sp] == '\0') {
|
68
|
+
sp = 1;
|
69
|
+
optind++;
|
70
|
+
}
|
71
|
+
optarg = NULL;
|
72
|
+
}
|
73
|
+
return(c);
|
74
|
+
}
|
75
|
+
#endif
|
76
|
+
|