tlearn 0.0.5 → 0.0.6

Sign up to get free protection for your applications and to get access to all the features.
Files changed (2) hide show
  1. data/ext/tlearn/tlearn.c +525 -0
  2. metadata +2 -1
@@ -0,0 +1,525 @@
1
+
2
+ /* tlearn.c - simulator for arbitrary networks with time-ordered input */
3
+
4
+ /*------------------------------------------------------------------------
5
+
6
+ This program simulates learning in a neural network using either
7
+ the classical back-propagation learning algorithm or a slightly
8
+ modified form derived in Williams and Zipser, "A Learning Algo-
9
+ rithm for Continually Running Fully Recurrent Networks." The
10
+ input is a sequence of vectors of (ascii) floating point numbers
11
+ contained in a ".data" file. The target outputs are a set of
12
+ time-stamped vectors of (ascii) floating point numbers (including
13
+ optional "don't care" values) in a ".teach" file. The network
14
+ configuration is defined in a ".cf" file documented in tlearn.man.
15
+
16
+ ------------------------------------------------------------------------*/
17
+
18
+ #include <math.h>
19
+ #include <stdio.h>
20
+ #include <signal.h>
21
+ #ifdef ibmpc
22
+ #include "strings.h"
23
+ #include <fcntl.h>
24
+ #else
25
+ #ifndef THINK_C
26
+ #include <strings.h>
27
+ #include <sys/file.h>
28
+ #include <stdlib.h>
29
+ #else /* THINK_C */
30
+ #include <console.h>
31
+ #include <time.h>
32
+ #include <stdlib.h>
33
+ #endif /* THINK_C */
34
+ #endif
35
+ #ifdef notdef
36
+ #include <sys/types.h>
37
+ #include <sys/stat.h>
38
+ #endif /* notdef */
39
+
40
+ #ifdef ibmpc
41
+ #define random(x) rand(x)
42
+ #define srandom(x) srand(x)
43
+ #endif
44
+ #ifdef THINK_C
45
+ #define random(x) rand(x)
46
+ #define srandom(x) srand(x)
47
+ #endif /* THINK_C */
48
+
49
/* ---- network dimensions (filled in while reading the .cf file) ---- */
int nn;		/* number of nodes */
int ni;		/* number of inputs */
int no;		/* number of outputs */
int nt;		/* nn + ni + 1 (unit 0 is the constant bias unit) */
int np;		/* ni + 1 */

/* Per-connection settings from the .cf file: one entry per (to, from) pair. */
struct cf {
	int con;	/* connection flag: nonzero if the connection exists */
	int fix;	/* fixed-weight flag */
	int num;	/* group number */
	int lim;	/* weight-limits flag */
	float min;	/* weight minimum */
	float max;	/* weight maximum */
};

/* Per-node settings from the .cf file. */
struct nf {
	int func;	/* activation function type */
	int dela;	/* delay flag */
	int targ;	/* target flag */
};

struct cf **cinfo;	/* (nn x nt) connection info */
struct nf *ninfo;	/* (nn) node activation function info */

int *outputs;		/* (no) indices of output nodes (used 1-based: main() subtracts 1) */
int *selects;		/* (nn+1) nodes selected for probe printout */
int *linput;		/* (ni) localist input array */

float *znew;		/* (nt) inputs and activations at time t+1 */
float *zold;		/* (nt) inputs and activations at time t */
float *zmem;		/* (nt) NOTE(review): original comment duplicated zold's; passed to
			   act_nds() and comp_backprop() -- confirm its exact role */
float **wt;		/* (nn x nt) weight TO node i FROM node j */
float **dwt;		/* (nn x nt) delta weight at time t */
float **winc;		/* (nn x nt) accumulated weight increment */
float *target;		/* (no) output target values */
float *error;		/* (nn) error = (output - target) values */
float ***pnew;		/* (nn x nt x nn) RTRL p-variable at time t+1 (only allocated for -L) */
float ***pold;		/* (nn x nt x nn) RTRL p-variable at time t (only allocated for -L) */

/* ---- training parameters (defaults; most are overridable on the command line) ---- */
float rate = .1;		/* learning rate (-r) */
float momentum = 0.;		/* momentum (-m) */
float weight_limit = 1.;	/* bound for random weight init */
float criterion = 0.;		/* exit program when rms error is less than this (-M);
				   not read in this file -- presumably used by the error-reporting code */
float init_bias = 0.;		/* possible offset for initial output biases (-B) */

long sweep = 0;			/* current sweep */
long tsweeps = 0;		/* total sweeps to date */
long rms_report = 0;		/* output rms error every "report" sweeps (-E) */

int ngroups = 0;		/* number of groups */

/* ---- mode flags ---- */
int backprop = 1;	/* flag for standard back propagation (the default; -L clears it) */
int teacher = 0;	/* flag for feeding back targets (-t) */
int localist = 0;	/* flag for speed-up with localist inputs */
int randomly = 0;	/* flag for presenting inputs in random order (-R) */
int limits = 0;		/* flag for limited weights */
int ce = 0;		/* cross-entropy mode (-H): 1 = use as error, 2 = also collect stats */
#ifdef GRAPHICS
int dsp_type = 0;	/* flag for graphics display */
int dsp_freq = 0;	/* frequency of graphics display */
int dsp_delay = 0;	/* delay of graphics display */
int dsp_print = 0;	/* frequency of graphics hardcopy */
#endif /* GRAPHICS */

char root[128];		/* root filename for .cf, .data, .teach, etc. */
char loadfile[128];	/* filename for weightfile to be read in */

FILE *cfp;		/* file pointer for .cf file */

void intr();		/* signal handler: saves weights and exits */

/* Defined in other source files of this program. */
extern int load_wts();
extern int save_wts();
extern int act_nds();
123
+
124
+
125
+ main(argc,argv)
126
+ int argc;
127
+ char **argv;
128
+ {
129
+
130
+ FILE *fopen();
131
+ FILE *fpid;
132
+ extern char *optarg;
133
+ extern float rans();
134
+ extern time_t time();
135
+
136
+
137
+ long nsweeps = 0; /* number of sweeps to run for */
138
+ long ttime = 0; /* number of sweeps since time = 0 */
139
+ long utime = 0; /* number of sweeps since last update_weights */
140
+ long tmax = 0; /* maximum number of sweeps (given in .data) */
141
+ long umax = 0; /* update weights every umax sweeps */
142
+ long rtime = 0; /* number of sweeps since last rms_report */
143
+ long check = 0; /* output weights every "check" sweeps */
144
+ long ctime = 0; /* number of sweeps since last check */
145
+
146
+ int c;
147
+ int i;
148
+ int j;
149
+ int k;
150
+ int nticks = 1; /* number of internal clock ticks per input */
151
+ int ticks = 0; /* counter for ticks */
152
+ int learning = 1; /* flag for learning */
153
+ int reset = 0; /* flag for resetting net */
154
+ int verify = 0; /* flag for printing output values */
155
+ int probe = 0; /* flag for printing selected node values */
156
+ int command = 1; /* flag for writing to .cmd file */
157
+ int loadflag = 0; /* flag for loading initial weights from file */
158
+ int iflag = 0; /* flag for -I */
159
+ int tflag = 0; /* flag for -T */
160
+ int rflag = 0; /* flag for -x */
161
+ int seed = 0; /* seed for random() */
162
+
163
+ float err = 0.; /* cumulative ss error */
164
+ float ce_err = 0.; /* cumulate cross_entropy error */
165
+
166
+ float *w;
167
+ float *wi;
168
+ float *dw;
169
+ float *pn;
170
+ float *po;
171
+
172
+ struct cf *ci;
173
+
174
+ char cmdfile[128]; /* filename for logging runs of program */
175
+ char cfile[128]; /* filename for .cf file */
176
+
177
+ FILE *cmdfp;
178
+
179
+ #ifdef THINK_C
180
+ argc = ccommand(&argv);
181
+ #endif /* THINK_C */
182
+
183
+ signal(SIGINT, intr);
184
+ #ifndef ibmpc
185
+ #ifndef THINK_C
186
+ signal(SIGHUP, intr);
187
+ signal(SIGQUIT, intr);
188
+ signal(SIGKILL, intr);
189
+ #endif /* THINK_C */
190
+ #endif
191
+
192
+ #ifndef ibmpc
193
+ exp_init();
194
+ #endif
195
+
196
+ root[0] = 0;
197
+
198
+ while ((c = getopt(argc, argv, "f:hil:m:n:r:s:tC:E:ILM:PpRS:TU:VvXB:H:D:")) != EOF) {
199
+ switch (c) {
200
+ case 'C':
201
+ check = (long) atol(optarg);
202
+ ctime = check;
203
+ break;
204
+ case 'f':
205
+ strcpy(root,optarg);
206
+ break;
207
+ case 'i':
208
+ command = 0;
209
+ break;
210
+ case 'l':
211
+ loadflag = 1;
212
+ strcpy(loadfile,optarg);
213
+ break;
214
+ case 'm':
215
+ momentum = (float) atof(optarg);
216
+ break;
217
+ case 'n':
218
+ nticks = (int) atoi(optarg);
219
+ break;
220
+ case 'P':
221
+ learning = 0;
222
+ /* drop through deliberately */
223
+ case 'p':
224
+ probe = 1;
225
+ break;
226
+ case 'r':
227
+ rate = (double) atof(optarg);
228
+ break;
229
+ case 's':
230
+ nsweeps = (long) atol(optarg);
231
+ break;
232
+ case 't':
233
+ teacher = 1;
234
+ break;
235
+ case 'L':
236
+ backprop = 0;
237
+ break;
238
+ case 'V':
239
+ learning = 0;
240
+ /* drop through deliberately */
241
+ case 'v':
242
+ verify = 1;
243
+ break;
244
+ case 'X':
245
+ rflag = 1;
246
+ break;
247
+ case 'E':
248
+ rms_report = (long) atol(optarg);
249
+ break;
250
+ case 'I':
251
+ iflag = 1;
252
+ break;
253
+ case 'M':
254
+ criterion = (float) atof(optarg);
255
+ break;
256
+ case 'R':
257
+ randomly = 1;
258
+ break;
259
+ case 'S':
260
+ seed = atoi(optarg);
261
+ break;
262
+ case 'T':
263
+ tflag = 1;
264
+ break;
265
+ case 'U':
266
+ umax = atol(optarg);
267
+ break;
268
+ case 'B':
269
+ init_bias = atof(optarg);
270
+ break;
271
+ #ifdef GRAPHICS
272
+ /*
273
+ * graphics display; dsp_type:
274
+ * 0 = no display (default)
275
+ * 1 = weights only
276
+ * 2 = activations only
277
+ * 3 = weights & activations
278
+ */
279
+ case 'D':
280
+ switch (optarg[0]) {
281
+ case 'f':
282
+ optarg++;
283
+ dsp_freq = atol(optarg);
284
+ break;
285
+ case 't':
286
+ optarg++;
287
+ dsp_type = atoi(optarg);
288
+ break;
289
+ case 'd':
290
+ dsp_delay = 1;
291
+ break;
292
+ case 'p':
293
+ optarg++;
294
+ dsp_print = atol(optarg);
295
+ break;
296
+ }
297
+ break;
298
+ #endif GRAPHICS
299
+ /*
300
+ * if == 1, use cross-entropy as error;
301
+ * if == 2, also collect cross-entropy stats.
302
+ */
303
+ case 'H':
304
+ ce = atoi(optarg);
305
+ break;
306
+ case '?':
307
+ case 'h':
308
+ default:
309
+ usage();
310
+ exit(2);
311
+ break;
312
+ }
313
+ }
314
+ if (nsweeps == 0){
315
+ perror("ERROR: No -s specified");
316
+ exit(1);
317
+ }
318
+
319
+ /* open files */
320
+
321
+ if (root[0] == 0){
322
+ perror("ERROR: No fileroot specified");
323
+ exit(1);
324
+ }
325
+
326
+ if (command){
327
+ sprintf(cmdfile, "%s.cmd", root);
328
+ cmdfp = fopen(cmdfile, "a");
329
+ if (cmdfp == NULL) {
330
+ perror("ERROR: Can't open .cmd file");
331
+ exit(1);
332
+ }
333
+ for (i = 1; i < argc; i++)
334
+ fprintf(cmdfp,"%s ",argv[i]);
335
+ fprintf(cmdfp,"\n");
336
+ fflush(cmdfp);
337
+ }
338
+
339
+ #ifndef THINK_C
340
+ sprintf(cmdfile, "%s.pid", root);
341
+ fpid = fopen(cmdfile, "w");
342
+ fprintf(fpid, "%d\n", getpid());
343
+ fclose(fpid);
344
+ #endif /* THINK_C */
345
+
346
+ sprintf(cfile, "%s.cf", root);
347
+ cfp = fopen(cfile, "r");
348
+ if (cfp == NULL) {
349
+ perror("ERROR: Can't open .cf file");
350
+ exit(1);
351
+ }
352
+
353
+ get_nodes();
354
+ make_arrays();
355
+ get_outputs();
356
+ get_connections();
357
+ get_special();
358
+ #ifdef GRAPHICS
359
+ /*
360
+ * graphics must be done after other files are opened
361
+ */
362
+ if (dsp_type != 0)
363
+ init_dsp(root);
364
+ #endif GRAPHICS
365
+ if (!seed)
366
+ seed = time((time_t *) NULL);
367
+ srandom(seed);
368
+
369
+ if (loadflag)
370
+ load_wts();
371
+ else {
372
+ for (i = 0; i < nn; i++){
373
+ w = *(wt + i);
374
+ dw = *(dwt+ i);
375
+ wi = *(winc+ i);
376
+ ci = *(cinfo+ i);
377
+ for (j = 0; j < nt; j++, ci++, w++, wi++, dw++){
378
+ if (ci->con)
379
+ *w = rans(weight_limit);
380
+ else
381
+ *w = 0.;
382
+ *wi = 0.;
383
+ *dw = 0.;
384
+ }
385
+ }
386
+ /*
387
+ * If init_bias, then we want to set initial biases
388
+ * to (*only*) output units to a random negative number.
389
+ * We index into the **wt to find the section of receiver
390
+ * weights for each output node. The first weight in each
391
+ * section is for unit 0 (bias), so no further indexing needed.
392
+ */
393
+ for (i = 0; i < no; i++){
394
+ w = *(wt + outputs[i] - 1);
395
+ ci = *(cinfo + outputs[i] - 1);
396
+ if (ci->con)
397
+ *w = init_bias + rans(.1);
398
+ else
399
+ *w = 0.;
400
+ }
401
+ }
402
+ zold[0] = znew[0] = 1.;
403
+ for (i = 1; i < nt; i++)
404
+ zold[i] = znew[i] = 0.;
405
+ if (backprop == 0){
406
+ make_parrays();
407
+ for (i = 0; i < nn; i++){
408
+ for (j = 0; j < nt; j++){
409
+ po = *(*(pold + i) + j);
410
+ pn = *(*(pnew + i) + j);
411
+ for (k = 0; k < nn; k++, po++, pn++){
412
+ *po = 0.;
413
+ *pn = 0.;
414
+ }
415
+ }
416
+ }
417
+ }
418
+
419
+
420
+ nsweeps += tsweeps;
421
+ for (sweep = tsweeps; sweep < nsweeps; sweep++){
422
+
423
+ for (ticks = 0; ticks < nticks; ticks++){
424
+
425
+ update_reset(ttime,ticks,rflag,&tmax,&reset);
426
+
427
+ if (reset){
428
+ if (backprop == 0)
429
+ reset_network(zold,znew,pold,pnew);
430
+ else
431
+ reset_bp_net(zold,znew);
432
+ }
433
+
434
+ update_inputs(zold,ticks,iflag,&tmax,&linput);
435
+
436
+ if (learning || teacher || (rms_report != 0))
437
+ update_targets(target,ttime,ticks,tflag,&tmax);
438
+
439
+ act_nds(zold,zmem,znew,wt,linput,target);
440
+
441
+ comp_errors(zold,target,error,&err,&ce_err);
442
+
443
+ if (learning && (backprop == 0))
444
+ comp_deltas(pold,pnew,wt,dwt,zold,znew,error);
445
+ if (learning && (backprop == 1))
446
+ comp_backprop(wt,dwt,zold,zmem,target,error,linput);
447
+
448
+ if (probe)
449
+ print_nodes(zold);
450
+ }
451
+ #ifdef GRAPHICS
452
+ if ((dsp_type != 0) && (sweep%dsp_freq == 0))
453
+ do_dsp();
454
+ #endif GRAPHICS
455
+ if (verify)
456
+ print_output(zold);
457
+
458
+ if (rms_report && (++rtime >= rms_report)){
459
+ rtime = 0;
460
+ if (ce == 2)
461
+ print_error(&ce_err);
462
+ else
463
+ print_error(&err);
464
+ }
465
+
466
+ if (check && (++ctime >= check)){
467
+ ctime = 0;
468
+ save_wts();
469
+ }
470
+
471
+ if (++ttime >= tmax)
472
+ ttime = 0;
473
+
474
+ if (learning && (++utime >= umax)){
475
+ utime = 0;
476
+ update_weights(wt,dwt,winc);
477
+ }
478
+
479
+ }
480
+ if (learning)
481
+ save_wts();
482
+ exit(0);
483
+
484
+ }
485
+
486
/*
 * usage - print a summary of the command-line options to stderr.
 *
 * main() calls this for -h or an unrecognized option and then exits with
 * status 2. It is declared int (and returns 0) because main() calls it
 * before this definition, i.e. through an implicit int declaration; a
 * void definition would conflict with that.
 */
int
usage(void)
{
	fprintf(stderr, "\n");
	fprintf(stderr, "-f fileroot:\tspecify fileroot <always required>\n");
	fprintf(stderr, "-l weightfile:\tload in weightfile\n");
	fprintf(stderr, "\n");
	fprintf(stderr, "-s #:\trun for # sweeps <always required>\n");
	fprintf(stderr, "-r #:\tset learning rate to # (between 0. and 1.) [0.1]\n");
	fprintf(stderr, "-m #:\tset momentum to # (between 0. and 1.) [0.0]\n");
	fprintf(stderr, "-n #:\t# of clock ticks per input vector [1]\n");
	fprintf(stderr, "-t:\tfeedback teacher values in place of outputs\n");
	fprintf(stderr, "\n");
	fprintf(stderr, "-S #:\tseed for random number generator [random]\n");
	fprintf(stderr, "-U #:\tupdate weights every # sweeps [1]\n");
	fprintf(stderr, "-E #:\trecord rms error in .err file every # sweeps [0]\n");
	fprintf(stderr, "-C #:\tcheckpoint weights file every # sweeps [0]\n");
	fprintf(stderr, "-M #:\texit program when rms error is less than # [0.0]\n");
	fprintf(stderr, "-X:\tuse auxiliary .reset file\n");
	fprintf(stderr, "-P:\tprobe selected nodes on each sweep (no learning)\n");
	fprintf(stderr, "-p:\tprobe selected nodes on each sweep (learning continues)\n");
	fprintf(stderr, "-V:\tverify outputs on each sweep (no learning)\n");
	fprintf(stderr, "-v:\tverify outputs on each sweep (learning continues)\n");
	fprintf(stderr, "-R:\tpresent input patterns in random order\n");
	fprintf(stderr, "-I:\tignore input values during extra clock ticks\n");
	fprintf(stderr, "-T:\tignore target values during extra clock ticks\n");
	fprintf(stderr, "-L:\tuse RTRL temporally recurrent learning\n");
	fprintf(stderr, "-B #:\toffset for output bias initialization [0.0]\n");
	fprintf(stderr, "-H #:\tcross-entropy (1=use as error;2=also collect stats) [0]\n");
	fprintf(stderr, "-i:\tdo not log the command line to the .cmd file\n");
	fprintf(stderr, "-h:\tprint this message\n");
#ifdef GRAPHICS
	/*
	 * -D options are only parsed when built with GRAPHICS.
	 * NOTE(review): the comment on main()'s -D parsing says 1 = weights only
	 * and 2 = activations only, the reverse of the text below; confirm
	 * against the display code.
	 */
	fprintf(stderr, "-Dt#:\tdisplay type (0=none;1=activations;2=weights;3=both)\n");
	fprintf(stderr, "-Df#:\tdisplay frequency (#cycles)\n");
	fprintf(stderr, "-Dp#:\thardcopy print frequency (#cycles)\n");
	fprintf(stderr, "-Dd:\tdelay after each display\n");
#endif /* GRAPHICS */
	fprintf(stderr, "\n");
	return 0;
}
516
+
517
/*
 * intr - signal handler installed by main() for SIGINT (and, on builds
 * other than ibmpc/THINK_C, further termination signals). It calls
 * save_wts() and then terminates the process, using the signal number
 * as the exit status.
 *
 * NOTE(review): exit() is not async-signal-safe, and save_wts()
 * presumably does stdio file I/O; a signal arriving mid-I/O could
 * corrupt output. A flag checked by the main loop would be safer.
 */
void
intr(int sig)
{
	int status = sig;	/* exit status mirrors the signal number */

	save_wts();
	exit(status);
}
524
+
525
+
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: tlearn
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.0.5
4
+ version: 0.0.6
5
5
  prerelease:
6
6
  platform: ruby
7
7
  authors:
@@ -32,6 +32,7 @@ files:
32
32
  - ext/tlearn/getopt.c
33
33
  - ext/tlearn/parse.c
34
34
  - ext/tlearn/subs.c
35
+ - ext/tlearn/tlearn.c
35
36
  - ext/tlearn/tlearn_ext.c
36
37
  - ext/tlearn/update.c
37
38
  - ext/tlearn/weights.c