tlearn 0.0.5 → 0.0.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (2) hide show
  1. data/ext/tlearn/tlearn.c +525 -0
  2. metadata +2 -1
@@ -0,0 +1,525 @@
1
+
2
+ /* tlearn.c - simulator for arbitrary networks with time-ordered input */
3
+
4
+ /*------------------------------------------------------------------------
5
+
6
+ This program simulates learning in a neural network using either
7
+ the classical back-propagation learning algorithm or a slightly
8
+ modified form derived in Williams and Zipser, "A Learning Algo-
9
+ rithm for Continually Running Fully Recurrent Networks." The
10
+ input is a sequence of vectors of (ascii) floating point numbers
11
+ contained in a ".data" file. The target outputs are a set of
12
+ time-stamped vectors of (ascii) floating point numbers (including
13
+ optional "don't care" values) in a ".teach" file. The network
14
+ configuration is defined in a ".cf" file documented in tlearn.man.
15
+
16
+ ------------------------------------------------------------------------*/
17
+
18
+ #include <math.h>
19
+ #include <stdio.h>
20
+ #include <signal.h>
21
+ #ifdef ibmpc
22
+ #include "strings.h"
23
+ #include <fcntl.h>
24
+ #else
25
+ #ifndef THINK_C
26
+ #include <strings.h>
27
+ #include <sys/file.h>
28
+ #include <stdlib.h>
29
+ #else /* THINK_C */
30
+ #include <console.h>
31
+ #include <time.h>
32
+ #include <stdlib.h>
33
+ #endif /* THINK_C */
34
+ #endif
35
+ #ifdef notdef
36
+ #include <sys/types.h>
37
+ #include <sys/stat.h>
38
+ #endif /* notdef */
39
+
40
+ #ifdef ibmpc
41
+ #define random(x) rand(x)
42
+ #define srandom(x) srand(x)
43
+ #endif
44
+ #ifdef THINK_C
45
+ #define random(x) rand(x)
46
+ #define srandom(x) srand(x)
47
+ #endif /* THINK_C */
48
+
49
/*
 * File-scope state for the simulator.  These objects have external
 * linkage; main() below reads and writes them directly.
 */

/* Network dimensions. */
int	nn;		/* number of nodes */
int	ni;		/* number of inputs */
int	no;		/* number of outputs */
int	nt;		/* nn + ni + 1 */
int	np;		/* ni + 1 */

/* Per-connection configuration (one entry per weight). */
struct	cf {
	int	con;	/* connection flag */
	int	fix;	/* fixed-weight flag */
	int	num;	/* group number */
	int	lim;	/* weight-limits flag */
	float	min;	/* weight minimum */
	float	max;	/* weight maximum */
};

/* Per-node configuration. */
struct	nf {
	int	func;	/* activation function type */
	int	dela;	/* delay flag */
	int	targ;	/* target flag */
};

struct	cf	**cinfo;	/* (nn x nt) connection info */
struct	nf	*ninfo;		/* (nn) node activation function info */

int	*outputs;	/* (no) indices of output nodes */
int	*selects;	/* (nn+1) nodes selected for probe printout */
int	*linput;	/* (ni) localist input array */

float	*znew;		/* (nt) inputs and activations at time t+1 */
float	*zold;		/* (nt) inputs and activations at time t */
/*
 * NOTE(review): the original comment on zmem duplicated zold's
 * ("inputs and activations at time t").  main() only passes it to
 * act_nds() and comp_backprop(); confirm its exact role there.
 */
float	*zmem;		/* (nt) inputs and activations at time t */
float	**wt;		/* (nn x nt) weight TO node i FROM node j */
float	**dwt;		/* (nn x nt) delta weight at time t */
float	**winc;		/* (nn x nt) accumulated weight increment */
float	*target;	/* (no) output target values */
float	*error;		/* (nn) error = (output - target) values */
float	***pnew;	/* (nn x nt x nn) p-variable at time t+1 */
float	***pold;	/* (nn x nt x nn) p-variable at time t */

/* Tunable parameters; main() overrides these from the command line. */
float	rate = .1;		/* learning rate (-r) */
float	momentum = 0.;		/* momentum (-m) */
float	weight_limit = 1.;	/* bound for random weight init */
float	criterion = 0.;		/* exit program when rms error is less than this (-M) */
float	init_bias = 0.;		/* possible offset for initial output biases (-B) */

long	sweep = 0;		/* current sweep */
long	tsweeps = 0;		/* total sweeps to date */
long	rms_report = 0;		/* output rms error every "report" sweeps (-E) */

int	ngroups = 0;		/* number of groups */

int	backprop = 1;		/* flag for standard back propagation (the default) */
int	teacher = 0;		/* flag for feeding back targets (-t) */
int	localist = 0;		/* flag for speed-up with localist inputs */
int	randomly = 0;		/* flag for presenting inputs in random order (-R) */
int	limits = 0;		/* flag for limited weights */
int	ce = 0;			/* flag for cross_entropy (-H) */
#ifdef GRAPHICS
int	dsp_type = 0;		/* flag for graphics display (-Dt) */
int	dsp_freq = 0;		/* frequency of graphics display (-Df) */
int	dsp_delay = 0;		/* delay of graphics display (-Dd) */
int	dsp_print = 0;		/* frequency of graphics hardcopy (-Dp) */
#endif /* GRAPHICS */

char	root[128];		/* root filename for .cf, .data, .teach, etc. */
char	loadfile[128];		/* filename for weightfile to be read in */

FILE	*cfp;			/* file pointer for .cf file */

/* Signal handler defined at the bottom of this file. */
void	intr();

/* Defined in other translation units; declared K&R-style (no prototype). */
extern int	load_wts();
extern int	save_wts();
extern int	act_nds();
123
+
124
+
125
+ main(argc,argv)
126
+ int argc;
127
+ char **argv;
128
+ {
129
+
130
+ FILE *fopen();
131
+ FILE *fpid;
132
+ extern char *optarg;
133
+ extern float rans();
134
+ extern time_t time();
135
+
136
+
137
+ long nsweeps = 0; /* number of sweeps to run for */
138
+ long ttime = 0; /* number of sweeps since time = 0 */
139
+ long utime = 0; /* number of sweeps since last update_weights */
140
+ long tmax = 0; /* maximum number of sweeps (given in .data) */
141
+ long umax = 0; /* update weights every umax sweeps */
142
+ long rtime = 0; /* number of sweeps since last rms_report */
143
+ long check = 0; /* output weights every "check" sweeps */
144
+ long ctime = 0; /* number of sweeps since last check */
145
+
146
+ int c;
147
+ int i;
148
+ int j;
149
+ int k;
150
+ int nticks = 1; /* number of internal clock ticks per input */
151
+ int ticks = 0; /* counter for ticks */
152
+ int learning = 1; /* flag for learning */
153
+ int reset = 0; /* flag for resetting net */
154
+ int verify = 0; /* flag for printing output values */
155
+ int probe = 0; /* flag for printing selected node values */
156
+ int command = 1; /* flag for writing to .cmd file */
157
+ int loadflag = 0; /* flag for loading initial weights from file */
158
+ int iflag = 0; /* flag for -I */
159
+ int tflag = 0; /* flag for -T */
160
+ int rflag = 0; /* flag for -x */
161
+ int seed = 0; /* seed for random() */
162
+
163
+ float err = 0.; /* cumulative ss error */
164
+ float ce_err = 0.; /* cumulate cross_entropy error */
165
+
166
+ float *w;
167
+ float *wi;
168
+ float *dw;
169
+ float *pn;
170
+ float *po;
171
+
172
+ struct cf *ci;
173
+
174
+ char cmdfile[128]; /* filename for logging runs of program */
175
+ char cfile[128]; /* filename for .cf file */
176
+
177
+ FILE *cmdfp;
178
+
179
+ #ifdef THINK_C
180
+ argc = ccommand(&argv);
181
+ #endif /* THINK_C */
182
+
183
+ signal(SIGINT, intr);
184
+ #ifndef ibmpc
185
+ #ifndef THINK_C
186
+ signal(SIGHUP, intr);
187
+ signal(SIGQUIT, intr);
188
+ signal(SIGKILL, intr);
189
+ #endif /* THINK_C */
190
+ #endif
191
+
192
+ #ifndef ibmpc
193
+ exp_init();
194
+ #endif
195
+
196
+ root[0] = 0;
197
+
198
+ while ((c = getopt(argc, argv, "f:hil:m:n:r:s:tC:E:ILM:PpRS:TU:VvXB:H:D:")) != EOF) {
199
+ switch (c) {
200
+ case 'C':
201
+ check = (long) atol(optarg);
202
+ ctime = check;
203
+ break;
204
+ case 'f':
205
+ strcpy(root,optarg);
206
+ break;
207
+ case 'i':
208
+ command = 0;
209
+ break;
210
+ case 'l':
211
+ loadflag = 1;
212
+ strcpy(loadfile,optarg);
213
+ break;
214
+ case 'm':
215
+ momentum = (float) atof(optarg);
216
+ break;
217
+ case 'n':
218
+ nticks = (int) atoi(optarg);
219
+ break;
220
+ case 'P':
221
+ learning = 0;
222
+ /* drop through deliberately */
223
+ case 'p':
224
+ probe = 1;
225
+ break;
226
+ case 'r':
227
+ rate = (double) atof(optarg);
228
+ break;
229
+ case 's':
230
+ nsweeps = (long) atol(optarg);
231
+ break;
232
+ case 't':
233
+ teacher = 1;
234
+ break;
235
+ case 'L':
236
+ backprop = 0;
237
+ break;
238
+ case 'V':
239
+ learning = 0;
240
+ /* drop through deliberately */
241
+ case 'v':
242
+ verify = 1;
243
+ break;
244
+ case 'X':
245
+ rflag = 1;
246
+ break;
247
+ case 'E':
248
+ rms_report = (long) atol(optarg);
249
+ break;
250
+ case 'I':
251
+ iflag = 1;
252
+ break;
253
+ case 'M':
254
+ criterion = (float) atof(optarg);
255
+ break;
256
+ case 'R':
257
+ randomly = 1;
258
+ break;
259
+ case 'S':
260
+ seed = atoi(optarg);
261
+ break;
262
+ case 'T':
263
+ tflag = 1;
264
+ break;
265
+ case 'U':
266
+ umax = atol(optarg);
267
+ break;
268
+ case 'B':
269
+ init_bias = atof(optarg);
270
+ break;
271
+ #ifdef GRAPHICS
272
+ /*
273
+ * graphics display; dsp_type:
274
+ * 0 = no display (default)
275
+ * 1 = weights only
276
+ * 2 = activations only
277
+ * 3 = weights & activations
278
+ */
279
+ case 'D':
280
+ switch (optarg[0]) {
281
+ case 'f':
282
+ optarg++;
283
+ dsp_freq = atol(optarg);
284
+ break;
285
+ case 't':
286
+ optarg++;
287
+ dsp_type = atoi(optarg);
288
+ break;
289
+ case 'd':
290
+ dsp_delay = 1;
291
+ break;
292
+ case 'p':
293
+ optarg++;
294
+ dsp_print = atol(optarg);
295
+ break;
296
+ }
297
+ break;
298
+ #endif GRAPHICS
299
+ /*
300
+ * if == 1, use cross-entropy as error;
301
+ * if == 2, also collect cross-entropy stats.
302
+ */
303
+ case 'H':
304
+ ce = atoi(optarg);
305
+ break;
306
+ case '?':
307
+ case 'h':
308
+ default:
309
+ usage();
310
+ exit(2);
311
+ break;
312
+ }
313
+ }
314
+ if (nsweeps == 0){
315
+ perror("ERROR: No -s specified");
316
+ exit(1);
317
+ }
318
+
319
+ /* open files */
320
+
321
+ if (root[0] == 0){
322
+ perror("ERROR: No fileroot specified");
323
+ exit(1);
324
+ }
325
+
326
+ if (command){
327
+ sprintf(cmdfile, "%s.cmd", root);
328
+ cmdfp = fopen(cmdfile, "a");
329
+ if (cmdfp == NULL) {
330
+ perror("ERROR: Can't open .cmd file");
331
+ exit(1);
332
+ }
333
+ for (i = 1; i < argc; i++)
334
+ fprintf(cmdfp,"%s ",argv[i]);
335
+ fprintf(cmdfp,"\n");
336
+ fflush(cmdfp);
337
+ }
338
+
339
+ #ifndef THINK_C
340
+ sprintf(cmdfile, "%s.pid", root);
341
+ fpid = fopen(cmdfile, "w");
342
+ fprintf(fpid, "%d\n", getpid());
343
+ fclose(fpid);
344
+ #endif /* THINK_C */
345
+
346
+ sprintf(cfile, "%s.cf", root);
347
+ cfp = fopen(cfile, "r");
348
+ if (cfp == NULL) {
349
+ perror("ERROR: Can't open .cf file");
350
+ exit(1);
351
+ }
352
+
353
+ get_nodes();
354
+ make_arrays();
355
+ get_outputs();
356
+ get_connections();
357
+ get_special();
358
+ #ifdef GRAPHICS
359
+ /*
360
+ * graphics must be done after other files are opened
361
+ */
362
+ if (dsp_type != 0)
363
+ init_dsp(root);
364
+ #endif GRAPHICS
365
+ if (!seed)
366
+ seed = time((time_t *) NULL);
367
+ srandom(seed);
368
+
369
+ if (loadflag)
370
+ load_wts();
371
+ else {
372
+ for (i = 0; i < nn; i++){
373
+ w = *(wt + i);
374
+ dw = *(dwt+ i);
375
+ wi = *(winc+ i);
376
+ ci = *(cinfo+ i);
377
+ for (j = 0; j < nt; j++, ci++, w++, wi++, dw++){
378
+ if (ci->con)
379
+ *w = rans(weight_limit);
380
+ else
381
+ *w = 0.;
382
+ *wi = 0.;
383
+ *dw = 0.;
384
+ }
385
+ }
386
+ /*
387
+ * If init_bias, then we want to set initial biases
388
+ * to (*only*) output units to a random negative number.
389
+ * We index into the **wt to find the section of receiver
390
+ * weights for each output node. The first weight in each
391
+ * section is for unit 0 (bias), so no further indexing needed.
392
+ */
393
+ for (i = 0; i < no; i++){
394
+ w = *(wt + outputs[i] - 1);
395
+ ci = *(cinfo + outputs[i] - 1);
396
+ if (ci->con)
397
+ *w = init_bias + rans(.1);
398
+ else
399
+ *w = 0.;
400
+ }
401
+ }
402
+ zold[0] = znew[0] = 1.;
403
+ for (i = 1; i < nt; i++)
404
+ zold[i] = znew[i] = 0.;
405
+ if (backprop == 0){
406
+ make_parrays();
407
+ for (i = 0; i < nn; i++){
408
+ for (j = 0; j < nt; j++){
409
+ po = *(*(pold + i) + j);
410
+ pn = *(*(pnew + i) + j);
411
+ for (k = 0; k < nn; k++, po++, pn++){
412
+ *po = 0.;
413
+ *pn = 0.;
414
+ }
415
+ }
416
+ }
417
+ }
418
+
419
+
420
+ nsweeps += tsweeps;
421
+ for (sweep = tsweeps; sweep < nsweeps; sweep++){
422
+
423
+ for (ticks = 0; ticks < nticks; ticks++){
424
+
425
+ update_reset(ttime,ticks,rflag,&tmax,&reset);
426
+
427
+ if (reset){
428
+ if (backprop == 0)
429
+ reset_network(zold,znew,pold,pnew);
430
+ else
431
+ reset_bp_net(zold,znew);
432
+ }
433
+
434
+ update_inputs(zold,ticks,iflag,&tmax,&linput);
435
+
436
+ if (learning || teacher || (rms_report != 0))
437
+ update_targets(target,ttime,ticks,tflag,&tmax);
438
+
439
+ act_nds(zold,zmem,znew,wt,linput,target);
440
+
441
+ comp_errors(zold,target,error,&err,&ce_err);
442
+
443
+ if (learning && (backprop == 0))
444
+ comp_deltas(pold,pnew,wt,dwt,zold,znew,error);
445
+ if (learning && (backprop == 1))
446
+ comp_backprop(wt,dwt,zold,zmem,target,error,linput);
447
+
448
+ if (probe)
449
+ print_nodes(zold);
450
+ }
451
+ #ifdef GRAPHICS
452
+ if ((dsp_type != 0) && (sweep%dsp_freq == 0))
453
+ do_dsp();
454
+ #endif GRAPHICS
455
+ if (verify)
456
+ print_output(zold);
457
+
458
+ if (rms_report && (++rtime >= rms_report)){
459
+ rtime = 0;
460
+ if (ce == 2)
461
+ print_error(&ce_err);
462
+ else
463
+ print_error(&err);
464
+ }
465
+
466
+ if (check && (++ctime >= check)){
467
+ ctime = 0;
468
+ save_wts();
469
+ }
470
+
471
+ if (++ttime >= tmax)
472
+ ttime = 0;
473
+
474
+ if (learning && (++utime >= umax)){
475
+ utime = 0;
476
+ update_weights(wt,dwt,winc);
477
+ }
478
+
479
+ }
480
+ if (learning)
481
+ save_wts();
482
+ exit(0);
483
+
484
+ }
485
+
486
/*
 * usage - print the command-line summary to stderr.
 *
 * Returns 0.  The return type stays int because main() calls this
 * function before any declaration is visible, i.e. through an implicit
 * "int usage()" declaration; callers ignore the value.
 *
 * NOTE(review): the -Dt legend below (1=activations, 2=weights) is the
 * reverse of the comment on the 'D' option in main() (1 = weights only,
 * 2 = activations only).  Which one is right depends on the display
 * code, which is not in this file; the legend is left as it was.
 */
int
usage(void)
{
	fprintf(stderr, "\n");
	fprintf(stderr, "-f fileroot:\tspecify fileroot <always required>\n");
	fprintf(stderr, "-l weightfile:\tload in weightfile\n");
	fprintf(stderr, "\n");
	fprintf(stderr, "-s #:\trun for # sweeps <always required>\n");
	fprintf(stderr, "-r #:\tset learning rate to # (between 0. and 1.) [0.1]\n");
	fprintf(stderr, "-m #:\tset momentum to # (between 0. and 1.) [0.0]\n");
	fprintf(stderr, "-n #:\t# of clock ticks per input vector [1]\n");
	fprintf(stderr, "-t:\tfeedback teacher values in place of outputs\n");
	fprintf(stderr, "\n");
	fprintf(stderr, "-S #:\tseed for random number generator [random]\n");
	fprintf(stderr, "-U #:\tupdate weights every # sweeps [1]\n");
	fprintf(stderr, "-E #:\trecord rms error in .err file every # sweeps [0]\n");
	fprintf(stderr, "-C #:\tcheckpoint weights file every # sweeps [0]\n");
	fprintf(stderr, "-M #:\texit program when rms error is less than # [0.0]\n");
	fprintf(stderr, "-X:\tuse auxiliary .reset file\n");
	fprintf(stderr, "-P:\tprobe selected nodes on each sweep (no learning)\n");
	fprintf(stderr, "-V:\tverify outputs on each sweep (no learning)\n");
	fprintf(stderr, "-R:\tpresent input patterns in random order\n");
	fprintf(stderr, "-I:\tignore input values during extra clock ticks\n");
	fprintf(stderr, "-T:\tignore target values during extra clock ticks\n");
	fprintf(stderr, "-L:\tuse RTRL temporally recurrent learning\n");
	/* was: "offset for offset biasi initialization" */
	fprintf(stderr, "-B #:\toffset for initial output biases [0.0]\n");
	/* closing parenthesis was missing */
	fprintf(stderr, "-Dt#:\tdisplay type (0=none;1=activations;2=weights;3=both)\n");
	/* the next two lines lacked a trailing newline and ran together */
	fprintf(stderr, "-Df#:\tdisplay frequency (#cycles)\n");
	fprintf(stderr, "-Dp#:\thardcopy print frequency (#cycles)\n");
	fprintf(stderr, "-Dd:\tdelay after each display\n");
	fprintf(stderr, "\n");
	return 0;
}
516
+
517
/*
 * intr - signal handler installed by main().
 *
 * Writes the current weights out via save_wts() and then terminates
 * the process, using the signal number as the exit status.
 */
void intr(int sig)
{
	(void) save_wts();	/* return value is not used */
	exit(sig);
}
524
+
525
+
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: tlearn
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.0.5
4
+ version: 0.0.6
5
5
  prerelease:
6
6
  platform: ruby
7
7
  authors:
@@ -32,6 +32,7 @@ files:
32
32
  - ext/tlearn/getopt.c
33
33
  - ext/tlearn/parse.c
34
34
  - ext/tlearn/subs.c
35
+ - ext/tlearn/tlearn.c
35
36
  - ext/tlearn/tlearn_ext.c
36
37
  - ext/tlearn/update.c
37
38
  - ext/tlearn/weights.c