tlearn 0.0.1

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,404 @@
1
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
3
+
4
+ #ifdef ibmpc
5
+ extern char far *malloc();
6
+ #else
7
+ extern void *malloc();
8
+ #endif
9
+
10
+
11
/*
 * Network dimensions and option flags.  They are only declared in this
 * file; the definitions are expected in another translation unit.
 */
extern int nn;		/* number of nodes */
extern int ni;		/* number of inputs */
extern int no;		/* number of outputs */
extern int nt;		/* nn + ni + 1 */
extern int np;		/* ni + 1 */
extern int ce;		/* cross-entropy flag */

/* Per-connection record; cinfo holds one for each (node, source) pair. */
struct cf {
	int	con;	/* connection flag */
	int	fix;	/* fixed-weight flag */
	int	num;	/* group number */
	int	lim;	/* weight limits */
	float	min;	/* weight minimum */
	float	max;	/* weight maximum */
};

/*
 * Per-node record.  This is a pure type definition (no object is
 * declared), so it carries no storage-class specifier.
 */
struct nf {
	int	func;	/* activation function type */
	int	dela;	/* delay flag */
	int	targ;	/* target flag */
};

extern struct cf **cinfo;	/* (nn x nt) connection info */
extern struct nf *ninfo;	/* (nn) node activation function info */

extern int *outputs;		/* (no) indices of output nodes */

extern int localist;		/* flag for localist input */
39
+
40
/*
 * comp_errors -- compute the error of every output for one pattern.
 *
 *	aold	activations; output k is read from aold[ni + outputs[k]]
 *	atarget	(no) target values; -9999.0 marks an output with no target
 *	aerror	(nn) written on return: the entry for node i (nodes are
 *		numbered from 1 here) is activation minus target when i is
 *		listed in outputs[], and 0 otherwise
 *	e	running sum-squared error, incremented in place
 *	ce_e	running cross-entropy statistic, incremented in place; an
 *		output only contributes when ce == 2 and it has a target
 *
 * Returns 0.  Calls exit(1) when the scratch buffers cannot be allocated.
 */
int
comp_errors(float *aold, float *atarget, float *aerror, float *e, float *ce_e)
{
	extern int nn, ni, no, ce;
	extern int *outputs;

	/* scratch copies of the per-output terms, kept between calls */
	static float *terror = NULL;
	static float *ce_terror = NULL;
	static int nalloc = 0;		/* number of outputs the buffers hold */

	int i;
	int j;

	/* (re)allocate whenever the buffers are missing or have become too small */
	if (terror == NULL || ce_terror == NULL || no > nalloc){
		/* never ask for zero bytes: malloc(0) may legitimately return NULL */
		size_t count = (no > 0) ? (size_t) no : 1;

		free(terror);
		free(ce_terror);
		ce_terror = NULL;
		/* the casts keep this valid where malloc is declared as char * */
		terror = (float *) malloc(count * sizeof(float));
		if (terror == NULL){
			perror("terror malloc failed");
			exit(1);
		}
		ce_terror = (float *) malloc(count * sizeof(float));
		if (ce_terror == NULL){
			perror("ce_terror malloc failed");
			exit(1);
		}
		nalloc = no;
	}

	for (i = 0; i < no; i++){
		float out = aold[ni + outputs[i]];
		float targ = atarget[i];

		/*
		 * Both terms start at zero so that an output without a target,
		 * or a run that is not collecting cross-entropy statistics,
		 * adds nothing to the totals instead of stale buffer contents.
		 */
		terror[i] = 0.;
		ce_terror[i] = 0.;
		if (targ != -9999.0) {
			terror[i] = out - targ;
			/*
			 * if collecting cross-entropy statistics;
			 */
			if (ce == 2) {
				ce_terror[i] = targ * log(out)/log(2.0) +
					(1 - targ) * log(1 - out)/log(2.0);
			}
		}
		*e += terror[i] * terror[i];	/* cumulative ss error */
		*ce_e += ce_terror[i];		/* cumulate cross-entropy error */
	}

	/* scatter the per-output errors into the per-node error vector */
	for (i = 1; i <= nn; i++){
		aerror[i - 1] = 0.;
		for (j = 0; j < no; j++){
			if (outputs[j] == i){
				aerror[i - 1] = terror[j];
				break;
			}
		}
	}

	return 0;
}
111
+
112
+
113
+ comp_deltas(apold,apnew,awt,adwt,aold,anew,aerror)
114
+ float ***apold;
115
+ float ***apnew;
116
+ float **awt;
117
+ float **adwt;
118
+ float *aold;
119
+ float *anew;
120
+ float *aerror;
121
+ {
122
+ register int i;
123
+ register int j;
124
+ register int k;
125
+ register int l;
126
+
127
+ register struct cf **cp;
128
+
129
+ register struct cf *ci;
130
+ register struct nf *n;
131
+
132
+ register float **wp;
133
+ register float *zn;
134
+ register float *pn;
135
+ register float *po;
136
+ register float **pnp;
137
+ register float **pop;
138
+ register float ***pnpp;
139
+ register float ***popp;
140
+ register float *w;
141
+
142
+ register float *sum;
143
+
144
+ register float *e;
145
+
146
+ float asum;
147
+
148
+ /* to each node */
149
+ sum = &asum;
150
+ cp = cinfo;
151
+ pnpp = apnew;
152
+ popp = apold;
153
+ for (i = 0; i < nn; i++, cp++, pnpp++, popp++){
154
+ ci = *cp;
155
+ pnp = *pnpp;
156
+ pop = *popp;
157
+ /* from each bias, input, and node */
158
+ for (j = 0; j < nt; j++, ci++, pnp++, pop++){
159
+ if (ci->con == 0)
160
+ continue;
161
+ pn = *pnp;
162
+ zn = anew + np;
163
+ n = ninfo;
164
+ /* for each node */
165
+ for (k = 0; k < nn; k++, zn++, pn++, n++){
166
+ w = *(awt + k) + np;
167
+ po = *pop;
168
+ if (i == k)
169
+ *sum = *(aold + j);
170
+ else
171
+ *sum = 0.;
172
+ /* from each node */
173
+ for (l = 0; l < nn; l++, w++, po++){
174
+ *sum += *w * *po;
175
+ }
176
+ if (n->func == 0)
177
+ *pn = *zn * (1. - *zn) * *sum;
178
+ else if (n->func == 1)
179
+ *pn = .5 * (1. + *zn)*(1. - *zn) * *sum;
180
+ else if (n->func == 2){
181
+ *pn = *sum;
182
+ }
183
+ if (n->dela == 0)
184
+ *(*(*(apold + i) + j) + k) = *pn;
185
+ }
186
+ }
187
+ }
188
+ /* to each node */
189
+ cp = cinfo;
190
+ wp = adwt;
191
+ pnpp = apnew;
192
+ popp = apold;
193
+ for (i = 0; i < nn; i++, cp++, wp++, pnpp++, popp++){
194
+ w = *wp;
195
+ ci = *cp;
196
+ pnp= *pnpp;
197
+ pop= *popp;
198
+ /* from each bias, input, and node */
199
+ for (j = 0; j < nt; j++, w++, ci++, pnp++, pop++){
200
+ if (ci->con == 0)
201
+ continue;
202
+ e = aerror;
203
+ pn = *pnp;
204
+ po = *pop;
205
+ *sum = 0.;
206
+ /* for each node */
207
+ for (k = 0; k < nn; k++, e++, po++, pn++){
208
+ *sum += *e * *po;
209
+ *po = *pn;
210
+ }
211
+ *w -= *sum;
212
+ }
213
+ }
214
+
215
+ return;
216
+ }
217
+
218
+ comp_backprop(awt,adwt,aold,amem,atarget,aerror,local)
219
+ float **awt;
220
+ float **adwt;
221
+ float *aold;
222
+ float *amem;
223
+ float *atarget;
224
+ float *aerror;
225
+ int *local;
226
+ {
227
+ register int i;
228
+ register int j;
229
+
230
+ register struct cf **cp;
231
+
232
+ register struct cf *ci;
233
+ register struct nf *n;
234
+
235
+ register float *sum;
236
+
237
+ float **wp;
238
+ float *ee;
239
+ float *e;
240
+ float *w;
241
+ float *z;
242
+ float *oz;
243
+ float *t;
244
+
245
+ int *l;
246
+ int ns;
247
+
248
+ float asum;
249
+
250
+ /* compute deltas for output units */
251
+ sum = &asum;
252
+ e = aerror;
253
+ n = ninfo;
254
+ z = aold + np;
255
+ t = atarget;
256
+ for (i = 0; i < nn; i++, e++, n++, z++){
257
+ if (n->targ == 0)
258
+ continue;
259
+ if (n->func == 0) {
260
+ if (ce > 0) { /* if cross-entropy */
261
+ /*
262
+ * note that the following collapses
263
+ * (t-a) and derivative of slope; we
264
+ * therefore ignore current contents of
265
+ * *e (which is (t-a)) and assign new
266
+ * value, whereas with sse, we multiply *e
267
+ * by deriv. of slope.
268
+ */
269
+ *e = *t - *z;
270
+ /* NOTE: this is a kludge -- only increments
271
+ * target when node is an output node. Do
272
+ * NOT move into for() control statement.
273
+ */
274
+ t++;
275
+ } else { /* otherwise normal sse-delta */
276
+ *e *= *z * (1. - *z);
277
+ }
278
+ } else if (n->func == 1)
279
+ *e *= .5 * (1. + *z) * (1. - *z);
280
+ }
281
+
282
+ n = ninfo + nn - 1;
283
+ z = aold + nt - 1;
284
+ e = aerror + nn - 1;
285
+ /* compute deltas for remaining units */
286
+ for (i = nn - 1; i >= 0; i--, z--, e--, n--){
287
+ if (n->targ == 1)
288
+ continue;
289
+ *sum = 0.;
290
+ /* ee contains a bad address for i = nn-1 */
291
+ ee = aerror + i + 1;
292
+ for (j = i + 1; j < nn; j++, ee++){
293
+ w = *(awt + j) + np + i;
294
+ ci = *(cinfo + j) + np + i;
295
+ if (ci->con)
296
+ *sum += *w * *ee;
297
+ }
298
+ if (n->func == 0)
299
+ *e = *sum * *z * (1. - *z);
300
+ else if (n->func == 1)
301
+ *e = *sum * .5 * (1. + *z) * (1. - *z);
302
+ else if (n->func == 2){
303
+ *e = *sum;
304
+ }
305
+ else if (n->func == 3)
306
+ *e = 0.;
307
+ }
308
+
309
+ /* compute weight changes for all connections */
310
+
311
+ /* to each node */
312
+ e = aerror;
313
+ cp = cinfo;
314
+ wp = adwt;
315
+ for (i = 0; i < nn; i++, e++, cp++, wp++){
316
+ if (localist){
317
+ if (ce > 0){
318
+ if ((*cp)->con)
319
+ **wp += *e;
320
+ }
321
+ else {
322
+ if ((*cp)->con)
323
+ **wp -= *e;
324
+ }
325
+ l = local;
326
+ while (*l != 0){
327
+ if (ce > 0){
328
+ if ((*cp + *l)->con)
329
+ *(*wp + *l) += *e;
330
+ }
331
+ else {
332
+ if ((*cp + *l)->con)
333
+ *(*wp + *l) -= *e;
334
+ }
335
+ l++;
336
+ }
337
+ w = *wp + np;
338
+ ci = *cp + np;
339
+ z = aold + np;
340
+ oz = amem + np;
341
+ /* from each node */
342
+ /* loop is broken into two parts:
343
+ (1) connections from nodes of lower node-number
344
+ (2) connections from nodes of = or > node-number
345
+ the latter case requires use of old z values */
346
+ if (ce > 0){
347
+ for (j = 0; j < i; j++, w++, ci++, z++, oz++){
348
+ if (ci->con)
349
+ *w += *z * *e;
350
+ }
351
+ for (j = i; j < nn; j++, w++, ci++, z++, oz++){
352
+ if (ci->con)
353
+ *w += *oz * *e;
354
+ }
355
+ }
356
+ else {
357
+ for (j = 0; j < i; j++, w++, ci++, z++, oz++){
358
+ if (ci->con)
359
+ *w -= *z * *e;
360
+ }
361
+ for (j = i; j < nn; j++, w++, ci++, z++, oz++){
362
+ if (ci->con)
363
+ *w -= *oz * *e;
364
+ }
365
+ }
366
+ }
367
+ else {
368
+ w = *wp;
369
+ ci = *cp;
370
+ z = aold;
371
+ oz = amem;
372
+ /* from each bias, input, and node */
373
+ ns = np + i;
374
+ /* loop is broken into two parts:
375
+ (1) connections from nodes of lower node-number
376
+ (2) connections from nodes of = or > node-number
377
+ the latter case requires use of old z values */
378
+ if (ce > 0){
379
+ for (j = 0; j < ns; j++, w++, ci++, z++, oz++){
380
+ if (ci->con)
381
+ *w += *z * *e;
382
+ }
383
+ for (j = ns; j < nt; j++, w++, ci++, z++, oz++){
384
+ if (ci->con)
385
+ *w += *oz * *e;
386
+ }
387
+ }
388
+ else {
389
+ for (j = 0; j < ns; j++, w++, ci++, z++, oz++){
390
+ if (ci->con)
391
+ *w -= *z * *e;
392
+ }
393
+ for (j = ns; j < nt; j++, w++, ci++, z++, oz++){
394
+ if (ci->con)
395
+ *w -= *oz * *e;
396
+ }
397
+ }
398
+ }
399
+ }
400
+
401
+ return;
402
+ }
403
+
404
+
@@ -0,0 +1,14 @@
1
# Generates the Makefile for the tlearn native extension.
require 'mkmf'

# Optimization switch handed to the C compiler.
OPTIMIZE = '-O'

# Record gcc as the C compiler in the makefile configuration.
RbConfig::MAKEFILE_CONFIG.store('CC', 'gcc')

# These assignments replace the values the globals held before;
# nothing is appended to them.
$CFLAGS = OPTIMIZE.dup
$LIBS = '-lm'

# NOTE(review): presumably lint flags carried over from the original
# makefile -- confirm that mkmf actually reads $LFLAGS.
$LFLAGS = '-abcn -Dlint'

$DESTDIR = '.'

create_makefile('tlearn/tlearn')
@@ -0,0 +1,76 @@
1
+ #ifdef THINK_C
2
+ #define ibmpc
3
+ #endif /* THINK_C */
4
+ #ifdef ibmpc
5
+ #ifndef lint
6
+ static char sccsid[] = "@(#)getopt.c 1.1 86/09/24 SMI"; /* from S5R2 1.5 */
7
+ #endif
8
+
9
+ /*LINTLIBRARY*/
10
+ #ifndef THINK_C
11
+ #define ibmpc
12
+ #endif /* THINK_C */
13
#include <stdio.h>	/* EOF, NULL, fprintf, stderr */
#include <string.h>	/* strcmp, strchr */

#ifndef EOF
#define EOF (-1)
#endif

int	opterr = 1;	/* nonzero: report errors on stderr */
int	optind = 1;	/* index of the next argv element to examine */
int	optopt;		/* option character examined most recently */
char	*optarg;	/* argument of the last option that takes one */

/*
 * Report a bad option as "<prog><msg><c>" followed by a newline on
 * standard error; does nothing when opterr is zero.
 */
static void
getopt_err(const char *prog, const char *msg, int c)
{
	if (opterr)
		(void) fprintf(stderr, "%s%s%c\n", prog, msg, c);
}

/*
 * getopt -- return the next option letter found in argv.
 *
 *	argc, argv	the argument vector; argv[0] prefixes error messages
 *	opts		accepted option letters; a letter followed by ':'
 *			takes an argument
 *
 * Returns the option letter, '?' for a letter that is not in opts or for
 * a missing argument, and EOF when argv[optind] is not an option (it is
 * absent, does not start with '-', or is exactly "-") or is "--" (which
 * is skipped).  optopt receives the letter examined; optarg points at
 * the option's argument, or is NULL for an option that takes none;
 * optind is advanced past everything consumed.
 */
int
getopt(int argc, char **argv, char *opts)
{
	static int sp = 1;	/* offset of the next letter in argv[optind] */
	int c;
	char *cp;

	/* sp == 1 means a fresh argv element is about to be examined */
	if (sp == 1) {
		if (optind >= argc ||
		    argv[optind][0] != '-' || argv[optind][1] == '\0')
			return EOF;
		if (strcmp(argv[optind], "--") == 0) {
			optind++;
			return EOF;
		}
	}
	optopt = c = argv[optind][sp];
	if (c == ':' || (cp = strchr(opts, c)) == NULL) {
		getopt_err(argv[0], ": illegal option -- ", c);
		if (argv[optind][++sp] == '\0') {
			optind++;
			sp = 1;
		}
		return '?';
	}
	if (*++cp == ':') {
		if (argv[optind][sp + 1] != '\0') {
			/* argument glued to the letter, as in -ovalue */
			optarg = &argv[optind++][sp + 1];
		} else if (++optind >= argc) {
			getopt_err(argv[0], ": option requires an argument -- ", c);
			sp = 1;
			return '?';
		} else {
			/* argument is the following argv element */
			optarg = argv[optind++];
		}
		sp = 1;
	} else {
		/* plain flag: move to the next letter or the next element */
		if (argv[optind][++sp] == '\0') {
			sp = 1;
			optind++;
		}
		optarg = NULL;
	}
	return c;
}
75
+ #endif
76
+