tlearn 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,404 @@
1
#include <math.h>
#include <stdio.h>

#ifdef ibmpc
extern char far *malloc();
#else
#include <stdlib.h>
extern void *malloc();
#endif
9
+
10
+
11
/*
 * Network dimensions and options (declared here, defined elsewhere):
 *
 *   ni  number of inputs        nn  number of nodes
 *   no  number of outputs       np  ni + 1
 *   nt  nn + ni + 1
 *   ce  cross-entropy flag: comp_backprop uses cross-entropy deltas when
 *       ce > 0; comp_errors accumulates cross-entropy statistics when
 *       ce == 2
 */
extern int ni, nn, no;
extern int np, nt;
extern int ce;
17
+
18
+ struct cf {
19
+ int con; /* connection flag */
20
+ int fix; /* fixed-weight flag */
21
+ int num; /* group number */
22
+ int lim; /* weight limits */
23
+ float min; /* weight minimum */
24
+ float max; /* weight maximum */
25
+ };
26
+
27
+ extern struct nf {
28
+ int func; /* activation function type */
29
+ int dela; /* delay flag */
30
+ int targ; /* target flag */
31
+ };
32
+
33
/* Network description tables (declared here, defined elsewhere). */
extern struct cf **cinfo;	/* nn rows x nt columns, [to-node][from-unit] */
extern struct nf *ninfo;	/* nn per-node activation function records */
extern int *outputs;		/* no entries: 1-based node numbers of outputs */
extern int localist;		/* nonzero: inputs are given in localist form */
39
+
40
+ comp_errors(aold,atarget,aerror,e,ce_e)
41
+ float *aold;
42
+ float *atarget;
43
+ float *aerror;
44
+ float *e;
45
+ float *ce_e;
46
+ {
47
+ extern int ce;
48
+
49
+ register int i;
50
+ register int j;
51
+ register float *ta;
52
+ register float *te;
53
+ register float *ce_te;
54
+ register float *ee;
55
+ register int *op;
56
+
57
+ static float *terror = 0;
58
+ static float *ce_terror = 0;
59
+
60
+ if (terror == 0){
61
+ /* malloc space for local copy of error info */
62
+ terror = (float *) malloc(no * sizeof(float));
63
+ if (terror == NULL){
64
+ perror("terror malloc failed");
65
+ exit(1);
66
+ }
67
+ }
68
+ if (ce_terror == 0){
69
+ /* malloc space for local copy of cross-entropy info */
70
+ ce_terror = (float *) malloc(no * sizeof(float));
71
+ if (ce_terror == NULL){
72
+ perror("ce_terror malloc failed");
73
+ exit(1);
74
+ }
75
+ }
76
+
77
+
78
+ te = terror;
79
+ ce_te = ce_terror;
80
+ ta = atarget;
81
+ op = outputs;
82
+ for (i = 0; i < no; i++, te++, ce_te++, ta++, op++){
83
+ if (*ta != -9999.0) {
84
+ *te = *(aold + ni + *op) - *ta;
85
+ /*
86
+ * if collecting cross-entropy statistics;
87
+ */
88
+ if (ce == 2) {
89
+ *ce_te = *ta * log(*(aold+ni+ *op))/log(2.0) +
90
+ (1- *ta) * log(1- *(aold+ni+ *op))/log(2.0);
91
+ }
92
+ } else {
93
+ *te = 0.;
94
+ }
95
+ *e += *te * *te; /* cumulative ss error */
96
+ *ce_e += *ce_te; /* cumulate cross-entropy error */
97
+ }
98
+ ee = aerror;
99
+ for (i = 1; i <= nn; i++, ee++){
100
+ *ee = 0.;
101
+ te = terror;
102
+ op = outputs;
103
+ for (j = 0; j < no; j++, te++, op++){
104
+ if (*op == i){
105
+ *ee = *te;
106
+ break;
107
+ }
108
+ }
109
+ }
110
+ }
111
+
112
+
113
+ comp_deltas(apold,apnew,awt,adwt,aold,anew,aerror)
114
+ float ***apold;
115
+ float ***apnew;
116
+ float **awt;
117
+ float **adwt;
118
+ float *aold;
119
+ float *anew;
120
+ float *aerror;
121
+ {
122
+ register int i;
123
+ register int j;
124
+ register int k;
125
+ register int l;
126
+
127
+ register struct cf **cp;
128
+
129
+ register struct cf *ci;
130
+ register struct nf *n;
131
+
132
+ register float **wp;
133
+ register float *zn;
134
+ register float *pn;
135
+ register float *po;
136
+ register float **pnp;
137
+ register float **pop;
138
+ register float ***pnpp;
139
+ register float ***popp;
140
+ register float *w;
141
+
142
+ register float *sum;
143
+
144
+ register float *e;
145
+
146
+ float asum;
147
+
148
+ /* to each node */
149
+ sum = &asum;
150
+ cp = cinfo;
151
+ pnpp = apnew;
152
+ popp = apold;
153
+ for (i = 0; i < nn; i++, cp++, pnpp++, popp++){
154
+ ci = *cp;
155
+ pnp = *pnpp;
156
+ pop = *popp;
157
+ /* from each bias, input, and node */
158
+ for (j = 0; j < nt; j++, ci++, pnp++, pop++){
159
+ if (ci->con == 0)
160
+ continue;
161
+ pn = *pnp;
162
+ zn = anew + np;
163
+ n = ninfo;
164
+ /* for each node */
165
+ for (k = 0; k < nn; k++, zn++, pn++, n++){
166
+ w = *(awt + k) + np;
167
+ po = *pop;
168
+ if (i == k)
169
+ *sum = *(aold + j);
170
+ else
171
+ *sum = 0.;
172
+ /* from each node */
173
+ for (l = 0; l < nn; l++, w++, po++){
174
+ *sum += *w * *po;
175
+ }
176
+ if (n->func == 0)
177
+ *pn = *zn * (1. - *zn) * *sum;
178
+ else if (n->func == 1)
179
+ *pn = .5 * (1. + *zn)*(1. - *zn) * *sum;
180
+ else if (n->func == 2){
181
+ *pn = *sum;
182
+ }
183
+ if (n->dela == 0)
184
+ *(*(*(apold + i) + j) + k) = *pn;
185
+ }
186
+ }
187
+ }
188
+ /* to each node */
189
+ cp = cinfo;
190
+ wp = adwt;
191
+ pnpp = apnew;
192
+ popp = apold;
193
+ for (i = 0; i < nn; i++, cp++, wp++, pnpp++, popp++){
194
+ w = *wp;
195
+ ci = *cp;
196
+ pnp= *pnpp;
197
+ pop= *popp;
198
+ /* from each bias, input, and node */
199
+ for (j = 0; j < nt; j++, w++, ci++, pnp++, pop++){
200
+ if (ci->con == 0)
201
+ continue;
202
+ e = aerror;
203
+ pn = *pnp;
204
+ po = *pop;
205
+ *sum = 0.;
206
+ /* for each node */
207
+ for (k = 0; k < nn; k++, e++, po++, pn++){
208
+ *sum += *e * *po;
209
+ *po = *pn;
210
+ }
211
+ *w -= *sum;
212
+ }
213
+ }
214
+
215
+ return;
216
+ }
217
+
218
+ comp_backprop(awt,adwt,aold,amem,atarget,aerror,local)
219
+ float **awt;
220
+ float **adwt;
221
+ float *aold;
222
+ float *amem;
223
+ float *atarget;
224
+ float *aerror;
225
+ int *local;
226
+ {
227
+ register int i;
228
+ register int j;
229
+
230
+ register struct cf **cp;
231
+
232
+ register struct cf *ci;
233
+ register struct nf *n;
234
+
235
+ register float *sum;
236
+
237
+ float **wp;
238
+ float *ee;
239
+ float *e;
240
+ float *w;
241
+ float *z;
242
+ float *oz;
243
+ float *t;
244
+
245
+ int *l;
246
+ int ns;
247
+
248
+ float asum;
249
+
250
+ /* compute deltas for output units */
251
+ sum = &asum;
252
+ e = aerror;
253
+ n = ninfo;
254
+ z = aold + np;
255
+ t = atarget;
256
+ for (i = 0; i < nn; i++, e++, n++, z++){
257
+ if (n->targ == 0)
258
+ continue;
259
+ if (n->func == 0) {
260
+ if (ce > 0) { /* if cross-entropy */
261
+ /*
262
+ * note that the following collapses
263
+ * (t-a) and derivative of slope; we
264
+ * therefore ignore current contents of
265
+ * *e (which is (t-a)) and assign new
266
+ * value, whereas with sse, we multiply *e
267
+ * by deriv. of slope.
268
+ */
269
+ *e = *t - *z;
270
+ /* NOTE: this is a kludge -- only increments
271
+ * target when node is an output node. Do
272
+ * NOT move into for() control statement.
273
+ */
274
+ t++;
275
+ } else { /* otherwise normal sse-delta */
276
+ *e *= *z * (1. - *z);
277
+ }
278
+ } else if (n->func == 1)
279
+ *e *= .5 * (1. + *z) * (1. - *z);
280
+ }
281
+
282
+ n = ninfo + nn - 1;
283
+ z = aold + nt - 1;
284
+ e = aerror + nn - 1;
285
+ /* compute deltas for remaining units */
286
+ for (i = nn - 1; i >= 0; i--, z--, e--, n--){
287
+ if (n->targ == 1)
288
+ continue;
289
+ *sum = 0.;
290
+ /* ee contains a bad address for i = nn-1 */
291
+ ee = aerror + i + 1;
292
+ for (j = i + 1; j < nn; j++, ee++){
293
+ w = *(awt + j) + np + i;
294
+ ci = *(cinfo + j) + np + i;
295
+ if (ci->con)
296
+ *sum += *w * *ee;
297
+ }
298
+ if (n->func == 0)
299
+ *e = *sum * *z * (1. - *z);
300
+ else if (n->func == 1)
301
+ *e = *sum * .5 * (1. + *z) * (1. - *z);
302
+ else if (n->func == 2){
303
+ *e = *sum;
304
+ }
305
+ else if (n->func == 3)
306
+ *e = 0.;
307
+ }
308
+
309
+ /* compute weight changes for all connections */
310
+
311
+ /* to each node */
312
+ e = aerror;
313
+ cp = cinfo;
314
+ wp = adwt;
315
+ for (i = 0; i < nn; i++, e++, cp++, wp++){
316
+ if (localist){
317
+ if (ce > 0){
318
+ if ((*cp)->con)
319
+ **wp += *e;
320
+ }
321
+ else {
322
+ if ((*cp)->con)
323
+ **wp -= *e;
324
+ }
325
+ l = local;
326
+ while (*l != 0){
327
+ if (ce > 0){
328
+ if ((*cp + *l)->con)
329
+ *(*wp + *l) += *e;
330
+ }
331
+ else {
332
+ if ((*cp + *l)->con)
333
+ *(*wp + *l) -= *e;
334
+ }
335
+ l++;
336
+ }
337
+ w = *wp + np;
338
+ ci = *cp + np;
339
+ z = aold + np;
340
+ oz = amem + np;
341
+ /* from each node */
342
+ /* loop is broken into two parts:
343
+ (1) connections from nodes of lower node-number
344
+ (2) connections from nodes of = or > node-number
345
+ the latter case requires use of old z values */
346
+ if (ce > 0){
347
+ for (j = 0; j < i; j++, w++, ci++, z++, oz++){
348
+ if (ci->con)
349
+ *w += *z * *e;
350
+ }
351
+ for (j = i; j < nn; j++, w++, ci++, z++, oz++){
352
+ if (ci->con)
353
+ *w += *oz * *e;
354
+ }
355
+ }
356
+ else {
357
+ for (j = 0; j < i; j++, w++, ci++, z++, oz++){
358
+ if (ci->con)
359
+ *w -= *z * *e;
360
+ }
361
+ for (j = i; j < nn; j++, w++, ci++, z++, oz++){
362
+ if (ci->con)
363
+ *w -= *oz * *e;
364
+ }
365
+ }
366
+ }
367
+ else {
368
+ w = *wp;
369
+ ci = *cp;
370
+ z = aold;
371
+ oz = amem;
372
+ /* from each bias, input, and node */
373
+ ns = np + i;
374
+ /* loop is broken into two parts:
375
+ (1) connections from nodes of lower node-number
376
+ (2) connections from nodes of = or > node-number
377
+ the latter case requires use of old z values */
378
+ if (ce > 0){
379
+ for (j = 0; j < ns; j++, w++, ci++, z++, oz++){
380
+ if (ci->con)
381
+ *w += *z * *e;
382
+ }
383
+ for (j = ns; j < nt; j++, w++, ci++, z++, oz++){
384
+ if (ci->con)
385
+ *w += *oz * *e;
386
+ }
387
+ }
388
+ else {
389
+ for (j = 0; j < ns; j++, w++, ci++, z++, oz++){
390
+ if (ci->con)
391
+ *w -= *z * *e;
392
+ }
393
+ for (j = ns; j < nt; j++, w++, ci++, z++, oz++){
394
+ if (ci->con)
395
+ *w -= *oz * *e;
396
+ }
397
+ }
398
+ }
399
+ }
400
+
401
+ return;
402
+ }
403
+
404
+
@@ -0,0 +1,14 @@
1
require 'mkmf'

# Optimization level added to the C compiler flags.
OPTIMIZE = '-O'

# The build has always forced gcc. Keep gcc as the default, but let the
# CC environment variable override it.
RbConfig::MAKEFILE_CONFIG['CC'] = ENV.fetch('CC', 'gcc')

# Append to the flags and libraries that mkmf has already initialised from
# RbConfig. Assigning would throw Ruby's own settings away.
$CFLAGS += " #{OPTIMIZE}"
$LIBS += ' -lm' # log() from <math.h>

$DESTDIR = '.'

create_makefile('tlearn/tlearn')
@@ -0,0 +1,76 @@
1
#ifdef THINK_C
#define ibmpc
#endif /* THINK_C */
#ifdef ibmpc
/*
 * getopt -- System V style command-line option parser (S5R2-derived, see
 * sccsid).  Compiled only for ibmpc / THINK_C builds.  Presumably other
 * platforms use the C library's getopt.
 */
#ifndef lint
static char sccsid[] = "@(#)getopt.c 1.1 86/09/24 SMI"; /* from S5R2 1.5 */
#endif

/*LINTLIBRARY*/
#include <string.h>	/* strlen, strcmp, strchr, NULL */

#ifndef EOF
#define EOF (-1)
#endif

extern int write();

/*
 * Print "<argv[0]><s><c>\n" on file descriptor 2 unless opterr is 0.
 * Wrapped in do/while(0) so it is a single statement.  It expects a
 * variable named argv in scope.
 */
#define ERR(s, c) do { \
	if (opterr) { \
		char errbuf[2]; \
		errbuf[0] = (char) (c); \
		errbuf[1] = '\n'; \
		(void) write(2, argv[0], (unsigned) strlen(argv[0])); \
		(void) write(2, (s), (unsigned) strlen(s)); \
		(void) write(2, errbuf, 2); \
	} \
} while (0)

int opterr = 1;		/* nonzero: report errors through ERR */
int optind = 1;		/* index of the argv element being scanned */
int optopt;		/* option character last examined */
char *optarg;		/* argument of the last option, or NULL */

/*
 * Return the next option character listed in opts, '?' for an unknown
 * option or a missing argument, or EOF at the first non-option argument
 * or after "--".  A ':' after a letter in opts means that option takes
 * an argument, returned in optarg.
 */
int
getopt(int argc, char **argv, char *opts)
{
	static int sp = 1;	/* position within argv[optind] */
	register int c;
	register char *cp;

	if (sp == 1) {
		/* starting a new argv element */
		if (optind >= argc || argv[optind][0] != '-' ||
		    argv[optind][1] == '\0')
			return EOF;
		if (strcmp(argv[optind], "--") == 0) {
			optind++;
			return EOF;
		}
	}
	optopt = c = argv[optind][sp];
	if (c == ':' || (cp = strchr(opts, c)) == NULL) {
		ERR(": illegal option -- ", c);
		if (argv[optind][++sp] == '\0') {
			optind++;
			sp = 1;
		}
		return '?';
	}
	if (*++cp == ':') {
		/* argument is the rest of this element, or the next element */
		if (argv[optind][sp + 1] != '\0') {
			optarg = &argv[optind++][sp + 1];
		} else if (++optind >= argc) {
			ERR(": option requires an argument -- ", c);
			sp = 1;
			return '?';
		} else {
			optarg = argv[optind++];
		}
		sp = 1;
	} else {
		if (argv[optind][++sp] == '\0') {
			sp = 1;
			optind++;
		}
		optarg = NULL;
	}
	return c;
}
#endif
76
+