tlearn 0.0.1

/* tlearn_ext.c - simulator for arbitrary networks with time-ordered input */

/*------------------------------------------------------------------------

   This program simulates learning in a neural network using either
   the classical back-propagation learning algorithm or a slightly
   modified form derived in Williams and Zipser, "A Learning Algorithm
   for Continually Running Fully Recurrent Neural Networks."  The
   input is a sequence of vectors of (ascii) floating point numbers
   contained in a ".data" file.  The target outputs are a set of
   time-stamped vectors of (ascii) floating point numbers (including
   optional "don't care" values) in a ".teach" file.  The network
   configuration is defined in a ".cf" file documented in tlearn.man.

------------------------------------------------------------------------*/
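
/*
 * For orientation, a sketch of the companion files for a file root
 * "xor" (illustrative only, reconstructed from memory of the tlearn
 * conventions; tlearn.man is the authoritative format reference):
 *
 *     xor.data           xor.teach
 *     distributed        distributed
 *     4                  4
 *     0 0                0 0
 *     0 1                1 1
 *     1 0                2 1
 *     1 1                3 0
 *
 * Each .teach line carries a time stamp followed by the target
 * values; a "*" target means "don't care".  xor.cf names the nodes,
 * connections, and any special options for the network.
 */
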
#include <ruby.h>

#include <math.h>
#include <stdio.h>
#include <signal.h>
#ifdef ibmpc
#include "strings.h"
#include <fcntl.h>
#else
#ifndef THINK_C
#include <strings.h>
#include <sys/file.h>
#include <stdlib.h>
#else  /* THINK_C */
#include <console.h>
#include <time.h>
#include <stdlib.h>
#endif /* THINK_C */
#endif
#ifdef notdef
#include <sys/types.h>
#include <sys/stat.h>
#endif /* notdef */

#ifdef ibmpc
#define random(x)  rand(x)
#define srandom(x) srand(x)
#endif
#ifdef THINK_C
#define random(x)  rand(x)
#define srandom(x) srand(x)
#endif /* THINK_C */

extern int nn;              /* number of nodes */
extern int ni;              /* number of inputs */
extern int no;              /* number of outputs */
extern int nt;              /* nn + ni + 1 */
extern int np;              /* ni + 1 */

extern struct cf {
    int   con;              /* connection flag */
    int   fix;              /* fixed-weight flag */
    int   num;              /* group number */
    int   lim;              /* weight-limits flag */
    float min;              /* weight minimum */
    float max;              /* weight maximum */
};

extern struct nf {
    int func;               /* activation function type */
    int dela;               /* delay flag */
    int targ;               /* target flag */
};

extern struct cf **cinfo;   /* (nn x nt) connection info */
extern struct nf *ninfo;    /* (nn) node activation function info */

extern int *outputs;        /* (no) indices of output nodes */

extern int *selects;        /* (nn+1) nodes selected for probe printout */
extern int *linput;         /* (ni) localist input array */

extern float *znew;         /* (nt) inputs and activations at time t+1 */
extern float *zold;         /* (nt) inputs and activations at time t */
extern float *zmem;         /* (nt) inputs and activations at time t */
extern float **wt;          /* (nn x nt) weight TO node i FROM node j */
extern float **dwt;         /* (nn x nt) delta weight at time t */
extern float **winc;        /* (nn x nt) accumulated weight increment */
extern float *target;       /* (no) output target values */
extern float *error;        /* (nn) error = (output - target) values */
extern float ***pnew;       /* (nn x nt x nn) p-variable at time t+1 */
extern float ***pold;       /* (nn x nt x nn) p-variable at time t */
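
/*
 * Interpretive note: the p-variables are the sensitivities of the
 * Williams-Zipser algorithm.  Consistent with the dimensions above,
 * pold[i][j][k] tracks how the activation of node k responds to a
 * change in the weight from node j to node i, and is carried forward
 * through time alongside the activations themselves.
 */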

extern float rate;          /* learning rate */
extern float momentum;      /* momentum */
extern float weight_limit;  /* bound for random weight init */
extern float criterion;     /* exit program when rms error is less than this */
extern float init_bias;     /* possible offset for initial output biases */

extern float *data;         /* required to reset the .data file */

extern long sweep;          /* current sweep */
extern long tsweeps;        /* total sweeps to date */
extern long rms_report;     /* output rms error every "report" sweeps */

extern int ngroups;         /* number of groups */

extern int backprop;        /* flag for standard back propagation (the default) */
extern int teacher;         /* flag for feeding back targets */
extern int localist;        /* flag for speed-up with localist inputs */
extern int randomly;        /* flag for presenting inputs in random order */
extern int limits;          /* flag for limited weights */
extern int ce;              /* flag for cross-entropy */

extern char root[128];      /* root filename for .cf, .data, .teach, etc. */
extern char loadfile[128];  /* filename for weightfile to be read in */

extern FILE *cfp;           /* file pointer for .cf file */

extern void intr();

extern int load_wts();
extern int save_wts();
extern int act_nds();

extern int optind;

/*
 * Train a network: backprop == 0 selects the recurrent
 * (Williams-Zipser) weight-update path in run().
 */
int run_training(nsweeps, file_path, current_weights_output)
long  nsweeps;
char *file_path;
float *current_weights_output;
{
    int argc = 1;
    char *argv[1];
    int status;

    argv[0] = "tlearn";

    backprop = 0;
    status = run(argc, argv, nsweeps, file_path, backprop, current_weights_output);

    return(status);
}

/*
 * Evaluate a trained network: backprop == 1 selects the standard
 * back-propagation path; the caller passes -l/-V style arguments so
 * saved weights are loaded and outputs are verified without learning.
 */
int run_fitness(argc, argv, nsweeps, file_path, current_weights_output)
int   argc;
char **argv;
long  nsweeps;
char *file_path;
float *current_weights_output;
{
    int status;

    backprop = 1;
    status = run(argc, argv, nsweeps, file_path, backprop, current_weights_output);

    return(status);
}

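/*
 * run() is the shared driver for both entry points: it resets every
 * piece of global state, parses the command-line style flags, reads
 * the .cf configuration, initializes (or loads) the weights, and then
 * executes the main sweep loop below.
 */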
int run(argc, argv, nsweeps, file_path, backprop, current_weights_output)
int   argc;
char **argv;
long  nsweeps;
char *file_path;
int   backprop;
float *current_weights_output;
{
    //Reset EVERYTHING. Globals, such a great idea...
    optind = 1;
    sweep = 0;
    tsweeps = 0;
    rate = .1;
    momentum = 0.;
    weight_limit = 1.;
    criterion = 0.;
    init_bias = 0.;
    rms_report = 0;
    ngroups = 0;
    teacher = 0;
    localist = 0;
    randomly = 0;
    limits = 0;
    ce = 0;
    outputs = 0;
    selects = 0;
    linput = 0;
    cinfo = 0;
    ninfo = 0;
    znew = 0;
    zold = 0;
    zmem = 0;
    pnew = 0;
    pold = 0;
    wt = 0;
    dwt = 0;
    winc = 0;
    target = 0;
    error = 0;
    cfp = 0;
    data = 0;
    root[0] = 0;
    loadfile[0] = 0;

    FILE *fopen();
    FILE *fpid;
    extern char *optarg;
    extern float rans();
    extern time_t time();

    long ttime = 0;     /* number of sweeps since time = 0 */
    long utime = 0;     /* number of sweeps since last update_weights */
    long tmax = 0;      /* maximum number of sweeps (given in .data) */
    long umax = 0;      /* update weights every umax sweeps */
    long rtime = 0;     /* number of sweeps since last rms_report */
    long check = 0;     /* output weights every "check" sweeps */
    long ctime = 0;     /* number of sweeps since last check */

    int c;
    int i;
    int j;
    int k;
    int nticks = 1;     /* number of internal clock ticks per input */
    int ticks = 0;      /* counter for ticks */
    int learning = 1;   /* flag for learning */
    int reset = 0;      /* flag for resetting net */
    int verify = 0;     /* flag for printing output values */
    int probe = 0;      /* flag for printing selected node values */
    int command = 1;    /* flag for writing to .cmd file */
    int loadflag = 0;   /* flag for loading initial weights from file */
    int iflag = 0;      /* flag for -I */
    int tflag = 0;      /* flag for -T */
    int rflag = 0;      /* flag for -X */
    int seed = 0;       /* seed for random() */

    float err = 0.;     /* cumulative sum-squared error */
    float ce_err = 0.;  /* cumulative cross-entropy error */

    float *w;
    float *wi;
    float *dw;
    float *pn;
    float *po;

    struct cf *ci;

    char cmdfile[128];  /* filename for logging runs of program */
    char cfile[128];    /* filename for .cf file */

    FILE *cmdfp;

#ifdef THINK_C
    argc = ccommand(&argv);
#endif /* THINK_C */

    signal(SIGINT, intr);
#ifndef ibmpc
#ifndef THINK_C
    signal(SIGHUP, intr);
    signal(SIGQUIT, intr);
    signal(SIGKILL, intr);
#endif /* THINK_C */
#endif

#ifndef ibmpc
    exp_init();
#endif

    root[0] = 0;
    strcpy(root, file_path);

    while ((c = getopt(argc, argv, "f:hil:m:n:r:s:tC:E:ILM:PpRS:TU:VvXB:H:D:")) != EOF) {
        switch (c) {
            case 'C':
                check = (long) atol(optarg);
                ctime = check;
                break;
            case 'i':
                command = 0;
                break;
            case 'l':
                loadflag = 1;
                strcpy(loadfile, optarg);
                break;
            case 'm':
                momentum = (float) atof(optarg);
                break;
            case 'n':
                nticks = (int) atoi(optarg);
                break;
            case 'P':
                learning = 0;
                /* drop through deliberately */
            case 'p':
                probe = 1;
                break;
            case 'r':
                rate = (float) atof(optarg);
                break;
            case 't':
                teacher = 1;
                break;
            case 'V':
                learning = 0;
                /* drop through deliberately */
            case 'v':
                verify = 1;
                break;
            case 'X':
                rflag = 1;
                break;
            case 'E':
                rms_report = (long) atol(optarg);
                break;
            case 'I':
                iflag = 1;
                break;
            case 'M':
                criterion = (float) atof(optarg);
                break;
            case 'R':
                randomly = 1;
                break;
            case 'S':
                seed = atoi(optarg);
                break;
            case 'T':
                tflag = 1;
                break;
            case 'U':
                umax = atol(optarg);
                break;
            case 'B':
                init_bias = atof(optarg);
                break;
            /*
             * if == 1, use cross-entropy as error;
             * if == 2, also collect cross-entropy stats.
             */
            case 'H':
                ce = atoi(optarg);
                break;
            case '?':
            case 'h':
            default:
                usage();
                return(2);
        }
    }
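
    /*
     * Flags handled above: -C n (checkpoint weights every n sweeps),
     * -i (suppress .cmd logging), -l file (load initial weights),
     * -m x (momentum), -n n (ticks per input), -P/-p (probe node
     * values, -P with learning off), -r x (learning rate), -t (feed
     * back targets), -V/-v (verify outputs, -V with learning off),
     * -X (reset flag), -E n (rms report interval), -I (input flag),
     * -M x (rms error criterion), -R (random input order), -S n
     * (random seed), -T (target flag), -U n (weight-update interval),
     * -B x (initial output bias), -H n (cross-entropy mode).
     */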
    if (nsweeps == 0){
        fprintf(stderr, "ERROR: No sweeps specified\n");
        return(1);
    }

    /* open files */

    if (root[0] == 0){
        fprintf(stderr, "ERROR: No fileroot specified\n");
        return(1);
    }

    if (command){
        sprintf(cmdfile, "%s.cmd", root);
        cmdfp = fopen(cmdfile, "a");
        if (cmdfp == NULL) {
            perror("ERROR: Can't open .cmd file");
            return(1);
        }
        for (i = 1; i < argc; i++)
            fprintf(cmdfp, "%s ", argv[i]);
        fprintf(cmdfp, "\n");
        fflush(cmdfp);
    }

#ifndef THINK_C
    sprintf(cmdfile, "%s.pid", root);
    fpid = fopen(cmdfile, "w");
    fprintf(fpid, "%d\n", getpid());
    fclose(fpid);
#endif /* THINK_C */

    sprintf(cfile, "%s.cf", root);
    cfp = fopen(cfile, "r");
    if (cfp == NULL) {
        perror("ERROR: Can't open .cf file");
        return(1);
    }

    get_nodes();
    make_arrays();
    get_outputs();
    get_connections();
    get_special();

    if (!seed)
        seed = time((time_t *) NULL);
    srandom(seed);

    if (loadflag)
        load_wts();
    else {
        for (i = 0; i < nn; i++){
            w  = *(wt + i);
            dw = *(dwt + i);
            wi = *(winc + i);
            ci = *(cinfo + i);
            for (j = 0; j < nt; j++, ci++, w++, wi++, dw++){
                if (ci->con)
                    *w = rans(weight_limit);
                else
                    *w = 0.;
                *wi = 0.;
                *dw = 0.;
            }
        }
        /*
         * If init_bias, then we want to set initial biases
         * to (*only*) output units to a random negative number.
         * We index into the **wt to find the section of receiver
         * weights for each output node.  The first weight in each
         * section is for unit 0 (bias), so no further indexing needed.
         */
        for (i = 0; i < no; i++){
            w  = *(wt + outputs[i] - 1);
            ci = *(cinfo + outputs[i] - 1);
            if (ci->con)
                *w = init_bias + rans(.1);
            else
                *w = 0.;
        }
    }
    zold[0] = znew[0] = 1.;
    for (i = 1; i < nt; i++)
        zold[i] = znew[i] = 0.;
    if (backprop == 0){
        make_parrays();
        for (i = 0; i < nn; i++){
            for (j = 0; j < nt; j++){
                po = *(*(pold + i) + j);
                pn = *(*(pnew + i) + j);
                for (k = 0; k < nn; k++, po++, pn++){
                    *po = 0.;
                    *pn = 0.;
                }
            }
        }
    }

    data = 0;

    nsweeps += tsweeps;
    for (sweep = tsweeps; sweep < nsweeps; sweep++){

        for (ticks = 0; ticks < nticks; ticks++){

            update_reset(ttime, ticks, rflag, &tmax, &reset);

            if (reset){
                if (backprop == 0)
                    reset_network(zold, znew, pold, pnew);
                else
                    reset_bp_net(zold, znew);
            }

            update_inputs(zold, ticks, iflag, &tmax, &linput);

            if (learning || teacher || (rms_report != 0))
                update_targets(target, ttime, ticks, tflag, &tmax);

            act_nds(zold, zmem, znew, wt, linput, target);

            comp_errors(zold, target, error, &err, &ce_err);

            if (learning && (backprop == 0))
                comp_deltas(pold, pnew, wt, dwt, zold, znew, error);
            if (learning && (backprop == 1))
                comp_backprop(wt, dwt, zold, zmem, target, error, linput);

            if (probe)
                print_nodes(zold);
        }
        if (verify){
            /* copy the current output-node activations back to the caller */
            for (i = 0; i < no; i++){
                current_weights_output[i] = zold[ni + outputs[i]];
            }

            //print_output(zold);
        }
        if (rms_report && (++rtime >= rms_report)){
            rtime = 0;
            if (ce == 2)
                print_error(&ce_err);
            else
                print_error(&err);
        }

        if (check && (++ctime >= check)){
            ctime = 0;
            save_wts();
        }

        if (++ttime >= tmax)
            ttime = 0;

        if (learning && (++utime >= umax)){
            utime = 0;
            update_weights(wt, dwt, winc);
        }

    }
    if (learning)
        save_wts();

    return(0);
}

/* -- Ruby interface -- */

/* Debugging helper: dumps one key/value pair of a config hash (unused). */
int do_print(VALUE key, VALUE val, VALUE in) {
    fprintf(stderr, "Input data is %s\n", StringValueCStr(in));

    fprintf(stderr, "Key %s=>Value %s\n", StringValueCStr(key),
            StringValueCStr(val));

    return ST_CONTINUE;
}

static VALUE tlearn_train(VALUE self, VALUE config) {
    VALUE sweeps_value = rb_hash_aref(config, ID2SYM(rb_intern("sweeps")));
    long nsweeps = NUM2LONG(sweeps_value);

    VALUE file_root_value = rb_hash_aref(config, ID2SYM(rb_intern("file_root")));
    char *file_root = StringValueCStr(file_root_value);

    /* fixed-size buffer: assumes the network has at most 6 output nodes */
    float current_weights_output[6];

    int result = run_training(nsweeps, file_root, current_weights_output);
    return rb_int_new(result);
}

static VALUE tlearn_fitness(VALUE self, VALUE config) {
    int tlearn_args_count = 4;
    char *tlearn_args[4];

    VALUE ruby_array = rb_ary_new();
    VALUE file_root_value = rb_hash_aref(config, ID2SYM(rb_intern("file_root")));

    VALUE sweeps_value = rb_hash_aref(config, ID2SYM(rb_intern("sweeps")));
    long nsweeps = NUM2LONG(sweeps_value);

    char *file_root = StringValueCStr(file_root_value);
    /* room for "<file_root>.wts" plus the terminating NUL */
    char weights[strlen(file_root) + strlen(".wts") + 1];

    strcpy(weights, file_root);

    tlearn_args[0] = "tlearn_fitness";
    tlearn_args[1] = "-l";
    tlearn_args[2] = strcat(weights, ".wts");
    tlearn_args[3] = "-V";

    float current_weights_output[6];

    int failure = run_fitness(tlearn_args_count, tlearn_args, nsweeps, file_root, current_weights_output);

    if(failure == 0){
        float weight;
        int result_index;
        for(result_index = 0; result_index < 6; result_index++){
            weight = current_weights_output[result_index];
            rb_ary_store(ruby_array, result_index, rb_float_new(weight));
        }
        return(ruby_array);
    }
    else{
        return(rb_int_new(failure));
    }
}

void Init_tlearn(void) {
    VALUE klass = rb_define_class("TLearnExt", rb_cObject);

    rb_define_singleton_method(klass, "train", tlearn_train, 1);

    rb_define_singleton_method(klass, "fitness", tlearn_fitness, 1);
}
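
/*
 * Illustrative Ruby usage, inferred from the registrations above (the
 * file root and sweep count are invented for the example; :file_root
 * and :sweeps are the hash keys the C code actually reads):
 *
 *   require 'tlearn'
 *
 *   status  = TLearnExt.train(:file_root => 'xor', :sweeps => 1000)
 *   outputs = TLearnExt.fitness(:file_root => 'xor', :sweeps => 1)
 *
 * train returns an integer status; fitness returns an Array of the
 * output activations on success, or an integer status on failure.
 */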