tlearn 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data/ext/tlearn/Exp/exp.c +13 -0
- data/ext/tlearn/activate.c +222 -0
- data/ext/tlearn/arrays.c +224 -0
- data/ext/tlearn/compute.c +404 -0
- data/ext/tlearn/extconf.rb +14 -0
- data/ext/tlearn/getopt.c +76 -0
- data/ext/tlearn/parse.c +594 -0
- data/ext/tlearn/subs.c +204 -0
- data/ext/tlearn/tlearn.c +525 -0
- data/ext/tlearn/tlearn_ext.c +587 -0
- data/ext/tlearn/update.c +577 -0
- data/ext/tlearn/weights.c +116 -0
- data/lib/tlearn.rb +17 -0
- data/lib/tlearn/config.rb +101 -0
- data/lib/tlearn/fitness_data.rb +24 -0
- data/lib/tlearn/run.rb +29 -0
- data/lib/tlearn/run_tlearn.rb +68 -0
- data/lib/tlearn/training_data.rb +41 -0
- metadata +64 -0
@@ -0,0 +1,404 @@
|
|
1
|
+
#include <math.h>
|
2
|
+
#include <stdio.h>
|
3
|
+
|
4
|
+
#ifdef ibmpc
|
5
|
+
extern char far *malloc();
|
6
|
+
#else
|
7
|
+
extern void *malloc();
|
8
|
+
#endif
|
9
|
+
|
10
|
+
|
11
|
+
extern int nn; /* number of nodes */
|
12
|
+
extern int ni; /* number of inputs */
|
13
|
+
extern int no; /* number of outputs */
|
14
|
+
extern int nt; /* nn + ni + 1 */
|
15
|
+
extern int np; /* ni + 1 */
|
16
|
+
extern int ce; /* cross-entropy flag */
|
17
|
+
|
18
|
+
/*
 * Per-connection record; cinfo holds one of these for every
 * (destination node, source) pair.
 */
struct cf {
	int	con;	/* connection flag; 0 makes the update loops skip the entry */
	int	fix;	/* fixed-weight flag */
	int	num;	/* group number */
	int	lim;	/* weight limits */
	float	min;	/* weight minimum */
	float	max;	/* weight maximum */
};
|
26
|
+
|
27
|
+
/*
 * Per-node record; ninfo holds one of these for every node.
 *
 * The original wrote "extern struct nf { ... };".  A storage class on a
 * declaration that only introduces a struct tag declares no object, so
 * it had no effect and only produced a compiler diagnostic; it has been
 * dropped.  The layout is unchanged.
 */
struct nf {
	int	func;	/* activation function type (the code tests 0, 1, 2 and 3) */
	int	dela;	/* delay flag */
	int	targ;	/* target flag */
};
|
32
|
+
|
33
|
+
extern struct cf **cinfo; /* (nn x nt) connection info */
|
34
|
+
extern struct nf *ninfo; /* (nn) node activation function info */
|
35
|
+
|
36
|
+
extern int *outputs; /* (no) indices of output nodes */
|
37
|
+
|
38
|
+
extern int localist; /* flag for localist input */
|
39
|
+
|
40
|
+
/*
 * comp_errors -- compute the error of every node for the current pattern
 * and add this pattern's contribution to the running error totals.
 *
 *	aold	activation vector; the activation belonging to outputs[k]
 *		is read from aold[ni + outputs[k]]
 *	atarget	(no) target values; a target equal to -9999.0 contributes
 *		zero error for that output
 *	aerror	(nn) written on every call: activation - target for each
 *		node number listed in outputs[] (node numbers are 1-based
 *		here), 0 for every other node
 *	e	sum-squared error accumulator; incremented, never reset
 *	ce_e	cross-entropy accumulator; incremented, never reset.  Only
 *		outputs with a target contribute, and only when ce == 2.
 *
 * Exits the process with status 1 if the scratch buffer cannot be
 * allocated.
 *
 * Changes relative to the original:
 *   - the cross-entropy term is now 0 unless it was actually computed;
 *     previously an uninitialised (or stale) scratch value was added to
 *     *ce_e whenever ce != 2 or the target was -9999.0
 *   - the static scratch buffer is re-allocated when `no' has grown
 *     since it was first sized, instead of being overrun
 *   - the per-node search through outputs[] is replaced by one pass
 */
int
comp_errors(aold,atarget,aerror,e,ce_e)
	float	*aold;
	float	*atarget;
	float	*aerror;
	float	*e;
	float	*ce_e;
{
	extern	int	ce;
	extern	int	nn;
	extern	int	ni;
	extern	int	no;
	extern	int	*outputs;
	extern	void	free();

	static	float	*terror = 0;	/* scratch: one error value per output */
	static	int	terror_len = 0;	/* number of floats terror can hold */

	register int	i;
	register int	j;
	int	node;
	float	act;
	float	ce_term;

	if (terror == 0 || terror_len < no){
		/* (re)allocate space for local copy of error info */
		free((void *) terror);
		terror_len = (no > 0) ? no : 1;
		terror = (float *) malloc(terror_len * sizeof(float));
		if (terror == NULL){
			terror_len = 0;
			perror("terror malloc failed");
			exit(1);
		}
	}

	for (i = 0; i < no; i++){
		ce_term = 0.;
		if (atarget[i] != -9999.0) {
			act = aold[ni + outputs[i]];
			terror[i] = act - atarget[i];
			/*
			 * if collecting cross-entropy statistics;
			 */
			if (ce == 2) {
				ce_term = atarget[i] * log(act)/log(2.0) +
				    (1 - atarget[i]) * log(1 - act)/log(2.0);
			}
		} else {
			terror[i] = 0.;
		}
		*e += terror[i] * terror[i];	/* cumulative ss error */
		*ce_e += ce_term;		/* cumulate cross-entropy error */
	}

	/* every node starts with zero error ... */
	for (i = 0; i < nn; i++)
		aerror[i] = 0.;

	/*
	 * ... and output nodes then receive their error.  Walking the
	 * output list backwards makes the lowest-numbered entry win if a
	 * node is listed twice, which is what the original first-match
	 * search produced.
	 */
	for (j = no - 1; j >= 0; j--){
		node = outputs[j];
		if (node >= 1 && node <= nn)
			aerror[node - 1] = terror[j];
	}

	return 0;
}
|
111
|
+
|
112
|
+
|
113
|
+
comp_deltas(apold,apnew,awt,adwt,aold,anew,aerror)
|
114
|
+
float ***apold;
|
115
|
+
float ***apnew;
|
116
|
+
float **awt;
|
117
|
+
float **adwt;
|
118
|
+
float *aold;
|
119
|
+
float *anew;
|
120
|
+
float *aerror;
|
121
|
+
{
|
122
|
+
register int i;
|
123
|
+
register int j;
|
124
|
+
register int k;
|
125
|
+
register int l;
|
126
|
+
|
127
|
+
register struct cf **cp;
|
128
|
+
|
129
|
+
register struct cf *ci;
|
130
|
+
register struct nf *n;
|
131
|
+
|
132
|
+
register float **wp;
|
133
|
+
register float *zn;
|
134
|
+
register float *pn;
|
135
|
+
register float *po;
|
136
|
+
register float **pnp;
|
137
|
+
register float **pop;
|
138
|
+
register float ***pnpp;
|
139
|
+
register float ***popp;
|
140
|
+
register float *w;
|
141
|
+
|
142
|
+
register float *sum;
|
143
|
+
|
144
|
+
register float *e;
|
145
|
+
|
146
|
+
float asum;
|
147
|
+
|
148
|
+
/* to each node */
|
149
|
+
sum = &asum;
|
150
|
+
cp = cinfo;
|
151
|
+
pnpp = apnew;
|
152
|
+
popp = apold;
|
153
|
+
for (i = 0; i < nn; i++, cp++, pnpp++, popp++){
|
154
|
+
ci = *cp;
|
155
|
+
pnp = *pnpp;
|
156
|
+
pop = *popp;
|
157
|
+
/* from each bias, input, and node */
|
158
|
+
for (j = 0; j < nt; j++, ci++, pnp++, pop++){
|
159
|
+
if (ci->con == 0)
|
160
|
+
continue;
|
161
|
+
pn = *pnp;
|
162
|
+
zn = anew + np;
|
163
|
+
n = ninfo;
|
164
|
+
/* for each node */
|
165
|
+
for (k = 0; k < nn; k++, zn++, pn++, n++){
|
166
|
+
w = *(awt + k) + np;
|
167
|
+
po = *pop;
|
168
|
+
if (i == k)
|
169
|
+
*sum = *(aold + j);
|
170
|
+
else
|
171
|
+
*sum = 0.;
|
172
|
+
/* from each node */
|
173
|
+
for (l = 0; l < nn; l++, w++, po++){
|
174
|
+
*sum += *w * *po;
|
175
|
+
}
|
176
|
+
if (n->func == 0)
|
177
|
+
*pn = *zn * (1. - *zn) * *sum;
|
178
|
+
else if (n->func == 1)
|
179
|
+
*pn = .5 * (1. + *zn)*(1. - *zn) * *sum;
|
180
|
+
else if (n->func == 2){
|
181
|
+
*pn = *sum;
|
182
|
+
}
|
183
|
+
if (n->dela == 0)
|
184
|
+
*(*(*(apold + i) + j) + k) = *pn;
|
185
|
+
}
|
186
|
+
}
|
187
|
+
}
|
188
|
+
/* to each node */
|
189
|
+
cp = cinfo;
|
190
|
+
wp = adwt;
|
191
|
+
pnpp = apnew;
|
192
|
+
popp = apold;
|
193
|
+
for (i = 0; i < nn; i++, cp++, wp++, pnpp++, popp++){
|
194
|
+
w = *wp;
|
195
|
+
ci = *cp;
|
196
|
+
pnp= *pnpp;
|
197
|
+
pop= *popp;
|
198
|
+
/* from each bias, input, and node */
|
199
|
+
for (j = 0; j < nt; j++, w++, ci++, pnp++, pop++){
|
200
|
+
if (ci->con == 0)
|
201
|
+
continue;
|
202
|
+
e = aerror;
|
203
|
+
pn = *pnp;
|
204
|
+
po = *pop;
|
205
|
+
*sum = 0.;
|
206
|
+
/* for each node */
|
207
|
+
for (k = 0; k < nn; k++, e++, po++, pn++){
|
208
|
+
*sum += *e * *po;
|
209
|
+
*po = *pn;
|
210
|
+
}
|
211
|
+
*w -= *sum;
|
212
|
+
}
|
213
|
+
}
|
214
|
+
|
215
|
+
return;
|
216
|
+
}
|
217
|
+
|
218
|
+
comp_backprop(awt,adwt,aold,amem,atarget,aerror,local)
|
219
|
+
float **awt;
|
220
|
+
float **adwt;
|
221
|
+
float *aold;
|
222
|
+
float *amem;
|
223
|
+
float *atarget;
|
224
|
+
float *aerror;
|
225
|
+
int *local;
|
226
|
+
{
|
227
|
+
register int i;
|
228
|
+
register int j;
|
229
|
+
|
230
|
+
register struct cf **cp;
|
231
|
+
|
232
|
+
register struct cf *ci;
|
233
|
+
register struct nf *n;
|
234
|
+
|
235
|
+
register float *sum;
|
236
|
+
|
237
|
+
float **wp;
|
238
|
+
float *ee;
|
239
|
+
float *e;
|
240
|
+
float *w;
|
241
|
+
float *z;
|
242
|
+
float *oz;
|
243
|
+
float *t;
|
244
|
+
|
245
|
+
int *l;
|
246
|
+
int ns;
|
247
|
+
|
248
|
+
float asum;
|
249
|
+
|
250
|
+
/* compute deltas for output units */
|
251
|
+
sum = &asum;
|
252
|
+
e = aerror;
|
253
|
+
n = ninfo;
|
254
|
+
z = aold + np;
|
255
|
+
t = atarget;
|
256
|
+
for (i = 0; i < nn; i++, e++, n++, z++){
|
257
|
+
if (n->targ == 0)
|
258
|
+
continue;
|
259
|
+
if (n->func == 0) {
|
260
|
+
if (ce > 0) { /* if cross-entropy */
|
261
|
+
/*
|
262
|
+
* note that the following collapses
|
263
|
+
* (t-a) and derivative of slope; we
|
264
|
+
* therefore ignore current contents of
|
265
|
+
* *e (which is (t-a)) and assign new
|
266
|
+
* value, whereas with sse, we multiply *e
|
267
|
+
* by deriv. of slope.
|
268
|
+
*/
|
269
|
+
*e = *t - *z;
|
270
|
+
/* NOTE: this is a kludge -- only increments
|
271
|
+
* target when node is an output node. Do
|
272
|
+
* NOT move into for() control statement.
|
273
|
+
*/
|
274
|
+
t++;
|
275
|
+
} else { /* otherwise normal sse-delta */
|
276
|
+
*e *= *z * (1. - *z);
|
277
|
+
}
|
278
|
+
} else if (n->func == 1)
|
279
|
+
*e *= .5 * (1. + *z) * (1. - *z);
|
280
|
+
}
|
281
|
+
|
282
|
+
n = ninfo + nn - 1;
|
283
|
+
z = aold + nt - 1;
|
284
|
+
e = aerror + nn - 1;
|
285
|
+
/* compute deltas for remaining units */
|
286
|
+
for (i = nn - 1; i >= 0; i--, z--, e--, n--){
|
287
|
+
if (n->targ == 1)
|
288
|
+
continue;
|
289
|
+
*sum = 0.;
|
290
|
+
/* ee contains a bad address for i = nn-1 */
|
291
|
+
ee = aerror + i + 1;
|
292
|
+
for (j = i + 1; j < nn; j++, ee++){
|
293
|
+
w = *(awt + j) + np + i;
|
294
|
+
ci = *(cinfo + j) + np + i;
|
295
|
+
if (ci->con)
|
296
|
+
*sum += *w * *ee;
|
297
|
+
}
|
298
|
+
if (n->func == 0)
|
299
|
+
*e = *sum * *z * (1. - *z);
|
300
|
+
else if (n->func == 1)
|
301
|
+
*e = *sum * .5 * (1. + *z) * (1. - *z);
|
302
|
+
else if (n->func == 2){
|
303
|
+
*e = *sum;
|
304
|
+
}
|
305
|
+
else if (n->func == 3)
|
306
|
+
*e = 0.;
|
307
|
+
}
|
308
|
+
|
309
|
+
/* compute weight changes for all connections */
|
310
|
+
|
311
|
+
/* to each node */
|
312
|
+
e = aerror;
|
313
|
+
cp = cinfo;
|
314
|
+
wp = adwt;
|
315
|
+
for (i = 0; i < nn; i++, e++, cp++, wp++){
|
316
|
+
if (localist){
|
317
|
+
if (ce > 0){
|
318
|
+
if ((*cp)->con)
|
319
|
+
**wp += *e;
|
320
|
+
}
|
321
|
+
else {
|
322
|
+
if ((*cp)->con)
|
323
|
+
**wp -= *e;
|
324
|
+
}
|
325
|
+
l = local;
|
326
|
+
while (*l != 0){
|
327
|
+
if (ce > 0){
|
328
|
+
if ((*cp + *l)->con)
|
329
|
+
*(*wp + *l) += *e;
|
330
|
+
}
|
331
|
+
else {
|
332
|
+
if ((*cp + *l)->con)
|
333
|
+
*(*wp + *l) -= *e;
|
334
|
+
}
|
335
|
+
l++;
|
336
|
+
}
|
337
|
+
w = *wp + np;
|
338
|
+
ci = *cp + np;
|
339
|
+
z = aold + np;
|
340
|
+
oz = amem + np;
|
341
|
+
/* from each node */
|
342
|
+
/* loop is broken into two parts:
|
343
|
+
(1) connections from nodes of lower node-number
|
344
|
+
(2) connections from nodes of = or > node-number
|
345
|
+
the latter case requires use of old z values */
|
346
|
+
if (ce > 0){
|
347
|
+
for (j = 0; j < i; j++, w++, ci++, z++, oz++){
|
348
|
+
if (ci->con)
|
349
|
+
*w += *z * *e;
|
350
|
+
}
|
351
|
+
for (j = i; j < nn; j++, w++, ci++, z++, oz++){
|
352
|
+
if (ci->con)
|
353
|
+
*w += *oz * *e;
|
354
|
+
}
|
355
|
+
}
|
356
|
+
else {
|
357
|
+
for (j = 0; j < i; j++, w++, ci++, z++, oz++){
|
358
|
+
if (ci->con)
|
359
|
+
*w -= *z * *e;
|
360
|
+
}
|
361
|
+
for (j = i; j < nn; j++, w++, ci++, z++, oz++){
|
362
|
+
if (ci->con)
|
363
|
+
*w -= *oz * *e;
|
364
|
+
}
|
365
|
+
}
|
366
|
+
}
|
367
|
+
else {
|
368
|
+
w = *wp;
|
369
|
+
ci = *cp;
|
370
|
+
z = aold;
|
371
|
+
oz = amem;
|
372
|
+
/* from each bias, input, and node */
|
373
|
+
ns = np + i;
|
374
|
+
/* loop is broken into two parts:
|
375
|
+
(1) connections from nodes of lower node-number
|
376
|
+
(2) connections from nodes of = or > node-number
|
377
|
+
the latter case requires use of old z values */
|
378
|
+
if (ce > 0){
|
379
|
+
for (j = 0; j < ns; j++, w++, ci++, z++, oz++){
|
380
|
+
if (ci->con)
|
381
|
+
*w += *z * *e;
|
382
|
+
}
|
383
|
+
for (j = ns; j < nt; j++, w++, ci++, z++, oz++){
|
384
|
+
if (ci->con)
|
385
|
+
*w += *oz * *e;
|
386
|
+
}
|
387
|
+
}
|
388
|
+
else {
|
389
|
+
for (j = 0; j < ns; j++, w++, ci++, z++, oz++){
|
390
|
+
if (ci->con)
|
391
|
+
*w -= *z * *e;
|
392
|
+
}
|
393
|
+
for (j = ns; j < nt; j++, w++, ci++, z++, oz++){
|
394
|
+
if (ci->con)
|
395
|
+
*w -= *oz * *e;
|
396
|
+
}
|
397
|
+
}
|
398
|
+
}
|
399
|
+
}
|
400
|
+
|
401
|
+
return;
|
402
|
+
}
|
403
|
+
|
404
|
+
|
data/ext/tlearn/getopt.c
ADDED
@@ -0,0 +1,76 @@
|
|
1
|
+
#ifdef THINK_C
|
2
|
+
#define ibmpc
|
3
|
+
#endif /* THINK_C */
|
4
|
+
#ifdef ibmpc
|
5
|
+
#ifndef lint
|
6
|
+
static char sccsid[] = "@(#)getopt.c 1.1 86/09/24 SMI"; /* from S5R2 1.5 */
|
7
|
+
#endif
|
8
|
+
|
9
|
+
/*LINTLIBRARY*/
|
10
|
+
#ifndef THINK_C
|
11
|
+
#define ibmpc
|
12
|
+
#endif /* THINK_C */
|
13
|
+
#define EOF (-1)
|
14
|
+
/*
 * ERR(s, c) -- when opterr is nonzero, write argv[0], the message s,
 * the character c and a newline to file descriptor 2.
 *
 * Relies on `opterr' and `argv' being in scope at the point of use.
 * Wrapped in do { } while (0) so that "ERR(...);" is exactly one
 * statement (the bare if-block form broke when followed by `else');
 * the arguments are parenthesised in the expansion.
 */
#define ERR(s, c)	do { if(opterr){\
	extern int strlen(), write();\
	char errbuf[2];\
	errbuf[0] = (c); errbuf[1] = '\n';\
	(void) write(2, argv[0], (unsigned)strlen(argv[0]));\
	(void) write(2, (s), (unsigned)strlen(s));\
	(void) write(2, errbuf, 2);} } while (0)
|
21
|
+
|
22
|
+
extern int strcmp();
|
23
|
+
extern char *strchr();
|
24
|
+
|
25
|
+
int opterr = 1; /* nonzero: ERR() writes diagnostics to file descriptor 2 */
|
26
|
+
int optind = 1; /* index into argv of the word getopt() examines next */
|
27
|
+
int optopt; /* option character most recently examined by getopt() */
|
28
|
+
char *optarg; /* argument of the last option, or NULL if it takes none */
|
29
|
+
|
30
|
+
int
|
31
|
+
getopt(argc, argv, opts)
|
32
|
+
int argc;
|
33
|
+
char **argv, *opts;
|
34
|
+
{
|
35
|
+
static int sp = 1;
|
36
|
+
register int c;
|
37
|
+
register char *cp;
|
38
|
+
|
39
|
+
if(sp == 1)
|
40
|
+
if(optind >= argc ||
|
41
|
+
argv[optind][0] != '-' || argv[optind][1] == '\0')
|
42
|
+
return(EOF);
|
43
|
+
else if(strcmp(argv[optind], "--") == 0) {
|
44
|
+
optind++;
|
45
|
+
return(EOF);
|
46
|
+
}
|
47
|
+
optopt = c = argv[optind][sp];
|
48
|
+
if(c == ':' || (cp=strchr(opts, c)) == NULL) {
|
49
|
+
ERR(": illegal option -- ", c);
|
50
|
+
if(argv[optind][++sp] == '\0') {
|
51
|
+
optind++;
|
52
|
+
sp = 1;
|
53
|
+
}
|
54
|
+
return('?');
|
55
|
+
}
|
56
|
+
if(*++cp == ':') {
|
57
|
+
if(argv[optind][sp+1] != '\0')
|
58
|
+
optarg = &argv[optind++][sp+1];
|
59
|
+
else if(++optind >= argc) {
|
60
|
+
ERR(": option requires an argument -- ", c);
|
61
|
+
sp = 1;
|
62
|
+
return('?');
|
63
|
+
} else
|
64
|
+
optarg = argv[optind++];
|
65
|
+
sp = 1;
|
66
|
+
} else {
|
67
|
+
if(argv[optind][++sp] == '\0') {
|
68
|
+
sp = 1;
|
69
|
+
optind++;
|
70
|
+
}
|
71
|
+
optarg = NULL;
|
72
|
+
}
|
73
|
+
return(c);
|
74
|
+
}
|
75
|
+
#endif
|
76
|
+
|