tlearn 0.0.1
- data/ext/tlearn/Exp/exp.c +13 -0
- data/ext/tlearn/activate.c +222 -0
- data/ext/tlearn/arrays.c +224 -0
- data/ext/tlearn/compute.c +404 -0
- data/ext/tlearn/extconf.rb +14 -0
- data/ext/tlearn/getopt.c +76 -0
- data/ext/tlearn/parse.c +594 -0
- data/ext/tlearn/subs.c +204 -0
- data/ext/tlearn/tlearn.c +525 -0
- data/ext/tlearn/tlearn_ext.c +587 -0
- data/ext/tlearn/update.c +577 -0
- data/ext/tlearn/weights.c +116 -0
- data/lib/tlearn.rb +17 -0
- data/lib/tlearn/config.rb +101 -0
- data/lib/tlearn/fitness_data.rb +24 -0
- data/lib/tlearn/run.rb +29 -0
- data/lib/tlearn/run_tlearn.rb +68 -0
- data/lib/tlearn/training_data.rb +41 -0
- metadata +64 -0
@@ -0,0 +1,587 @@
/* tlearn_ext.c - simulator for arbitrary networks with time-ordered input */

/*------------------------------------------------------------------------

    This program simulates learning in a neural network using either
    the classical back-propagation learning algorithm or a slightly
    modified form derived in Williams and Zipser, "A Learning
    Algorithm for Continually Running Fully Recurrent Networks."  The
    input is a sequence of vectors of (ascii) floating point numbers
    contained in a ".data" file.  The target outputs are a set of
    time-stamped vectors of (ascii) floating point numbers (including
    optional "don't care" values) in a ".teach" file.  The network
    configuration is defined in a ".cf" file documented in tlearn.man.

------------------------------------------------------------------------*/
#include <ruby.h>

#include <math.h>
#include <stdio.h>
#include <signal.h>
#ifdef ibmpc
#include "strings.h"
#include <fcntl.h>
#else
#ifndef THINK_C
#include <strings.h>
#include <sys/file.h>
#include <stdlib.h>
#else  /* THINK_C */
#include <console.h>
#include <time.h>
#include <stdlib.h>
#endif /* THINK_C */
#endif
#ifdef notdef
#include <sys/types.h>
#include <sys/stat.h>
#endif /* notdef */

#ifdef ibmpc
#define random(x) rand(x)
#define srandom(x) srand(x)
#endif
#ifdef THINK_C
#define random(x) rand(x)
#define srandom(x) srand(x)
#endif /* THINK_C */
extern int nn;              /* number of nodes */
extern int ni;              /* number of inputs */
extern int no;              /* number of outputs */
extern int nt;              /* nn + ni + 1 */
extern int np;              /* ni + 1 */

extern struct cf {
    int con;    /* connection flag */
    int fix;    /* fixed-weight flag */
    int num;    /* group number */
    int lim;    /* weight-limits flag */
    float min;  /* weight minimum */
    float max;  /* weight maximum */
};

extern struct nf {
    int func;   /* activation function type */
    int dela;   /* delay flag */
    int targ;   /* target flag */
};

extern struct cf **cinfo;   /* (nn x nt) connection info */
extern struct nf *ninfo;    /* (nn) node activation function info */

extern int *outputs;        /* (no) indices of output nodes */

extern int *selects;        /* (nn+1) nodes selected for probe printout */
extern int *linput;         /* (ni) localist input array */

extern float *znew;         /* (nt) inputs and activations at time t+1 */
extern float *zold;         /* (nt) inputs and activations at time t */
extern float *zmem;         /* (nt) inputs and activations at time t */
extern float **wt;          /* (nn x nt) weight TO node i FROM node j */
extern float **dwt;         /* (nn x nt) delta weight at time t */
extern float **winc;        /* (nn x nt) accumulated weight increment */
extern float *target;       /* (no) output target values */
extern float *error;        /* (nn) error = (output - target) values */
extern float ***pnew;       /* (nn x nt x nn) p-variable at time t+1 */
extern float ***pold;       /* (nn x nt x nn) p-variable at time t */

extern float rate;          /* learning rate */
extern float momentum;      /* momentum */
extern float weight_limit;  /* bound for random weight init */
extern float criterion;     /* exit program when rms error is less than this */
extern float init_bias;     /* possible offset for initial output biases */

extern float *data;         /* Required to reset the .data file */

extern long sweep;          /* current sweep */
extern long tsweeps;        /* total sweeps to date */
extern long rms_report;     /* output rms error every "report" sweeps */

extern int ngroups;         /* number of groups */

extern int backprop;        /* flag for standard back propagation (the default) */
extern int teacher;         /* flag for feeding back targets */
extern int localist;        /* flag for speed-up with localist inputs */
extern int randomly;        /* flag for presenting inputs in random order */
extern int limits;          /* flag for limited weights */
extern int ce;              /* flag for cross_entropy */

extern char root[128];      /* root filename for .cf, .data, .teach, etc. */
extern char loadfile[128];  /* filename for weightfile to be read in */

extern FILE *cfp;           /* file pointer for .cf file */

extern void intr();

extern int load_wts();
extern int save_wts();
extern int act_nds();

extern int optind;
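/*
 * Layout note (follows from the declarations above): znew/zold/zmem and
 * each weight row wt[i] hold nt = nn + ni + 1 entries, ordered as the
 * bias (unit 0), then the ni input units, then the nn network nodes.
 * This is why output activations are read from zold[ni + outputs[i]]
 * in the sweep loop below.
 */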
int run_training(nsweeps, file_path, current_weights_output)
long nsweeps;
char *file_path;
float *current_weights_output;
{
    int argc = 1;
    char *argv[argc];
    argv[0] = "tlearn";
    int status;

    backprop = 0;
    status = run(argc, argv, nsweeps, file_path, backprop, current_weights_output);

    return(status);
}
int run_fitness(argc, argv, nsweeps, file_path, current_weights_output)
int argc;
char **argv;
long nsweeps;
char *file_path;
float *current_weights_output;
{
    int status;
    backprop = 1;
    status = run(argc, argv, nsweeps, file_path, backprop, current_weights_output);

    return(status);
}
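/*
 * run() is the common driver behind both entry points above: it resets
 * the globals, parses the getopt-style arguments, reads the .cf network
 * configuration, then runs the sweep loop, using the Williams-Zipser
 * deltas when backprop == 0 (run_training) and standard back propagation
 * when backprop == 1 (run_fitness).
 */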
int run(argc, argv, nsweeps, file_path, backprop, current_weights_output)
int argc;
char **argv;
long nsweeps;
char *file_path;
int backprop;
float *current_weights_output;
{
    // Reset EVERYTHING. Globals, such a great idea...
    optind = 1;
    sweep = 0;
    tsweeps = 0;
    rate = .1;
    momentum = 0.;
    weight_limit = 1.;
    criterion = 0.;
    init_bias = 0.;
    rms_report = 0;
    ngroups = 0;
    teacher = 0;
    localist = 0;
    randomly = 0;
    limits = 0;
    ce = 0;
    outputs = 0;
    selects = 0;
    linput = 0;
    cinfo = 0;
    ninfo = 0;
    znew = 0;
    zold = 0;
    zmem = 0;
    pnew = 0;
    pold = 0;
    wt = 0;
    dwt = 0;
    winc = 0;
    target = 0;
    error = 0;
    cfp = 0;
    data = 0;
    ngroups = 0;
    root[0] = 0;
    loadfile[0] = 0;
    FILE *fopen();
    FILE *fpid;
    extern char *optarg;
    extern float rans();
    extern time_t time();

    long ttime = 0;     /* number of sweeps since time = 0 */
    long utime = 0;     /* number of sweeps since last update_weights */
    long tmax = 0;      /* maximum number of sweeps (given in .data) */
    long umax = 0;      /* update weights every umax sweeps */
    long rtime = 0;     /* number of sweeps since last rms_report */
    long check = 0;     /* output weights every "check" sweeps */
    long ctime = 0;     /* number of sweeps since last check */

    int c;
    int i;
    int j;
    int k;
    int nticks = 1;     /* number of internal clock ticks per input */
    int ticks = 0;      /* counter for ticks */
    int learning = 1;   /* flag for learning */
    int reset = 0;      /* flag for resetting net */
    int verify = 0;     /* flag for printing output values */
    int probe = 0;      /* flag for printing selected node values */
    int command = 1;    /* flag for writing to .cmd file */
    int loadflag = 0;   /* flag for loading initial weights from file */
    int iflag = 0;      /* flag for -I */
    int tflag = 0;      /* flag for -T */
    int rflag = 0;      /* flag for -X */
    int seed = 0;       /* seed for random() */

    float err = 0.;     /* cumulative ss error */
    float ce_err = 0.;  /* cumulative cross_entropy error */

    float *w;
    float *wi;
    float *dw;
    float *pn;
    float *po;

    struct cf *ci;

    char cmdfile[128];  /* filename for logging runs of program */
    char cfile[128];    /* filename for .cf file */

    FILE *cmdfp;

#ifdef THINK_C
    argc = ccommand(&argv);
#endif /* THINK_C */
    signal(SIGINT, intr);
#ifndef ibmpc
#ifndef THINK_C
    signal(SIGHUP, intr);
    signal(SIGQUIT, intr);
    signal(SIGKILL, intr);
#endif /* THINK_C */
#endif

#ifndef ibmpc
    exp_init();
#endif

    root[0] = 0;
    strcpy(root, file_path);
    while ((c = getopt(argc, argv, "f:hil:m:n:r:s:tC:E:ILM:PpRS:TU:VvXB:H:D:")) != EOF) {
        switch (c) {
            case 'C':
                check = (long) atol(optarg);
                ctime = check;
                break;
            case 'i':
                command = 0;
                break;
            case 'l':
                loadflag = 1;
                strcpy(loadfile, optarg);
                break;
            case 'm':
                momentum = (float) atof(optarg);
                break;
            case 'n':
                nticks = (int) atoi(optarg);
                break;
            case 'P':
                learning = 0;
                /* drop through deliberately */
            case 'p':
                probe = 1;
                break;
            case 'r':
                rate = (double) atof(optarg);
                break;
            case 't':
                teacher = 1;
                break;
            case 'V':
                learning = 0;
                /* drop through deliberately */
            case 'v':
                verify = 1;
                break;
            case 'X':
                rflag = 1;
                break;
            case 'E':
                rms_report = (long) atol(optarg);
                break;
            case 'I':
                iflag = 1;
                break;
            case 'M':
                criterion = (float) atof(optarg);
                break;
            case 'R':
                randomly = 1;
                break;
            case 'S':
                seed = atoi(optarg);
                break;
            case 'T':
                tflag = 1;
                break;
            case 'U':
                umax = atol(optarg);
                break;
            case 'B':
                init_bias = atof(optarg);
                break;
            /*
             * if == 1, use cross-entropy as error;
             * if == 2, also collect cross-entropy stats.
             */
            case 'H':
                ce = atoi(optarg);
                break;
            case '?':
            case 'h':
            default:
                usage();
                return(2);
                break;
        }
    }
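    /*
     * Of all these options, the Ruby bindings below only ever pass
     * "-l <root>.wts" and "-V" (via tlearn_fitness); the remainder are
     * kept for parity with the stand-alone tlearn command line.
     */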
    if (nsweeps == 0){
        perror("ERROR: No sweeps specified");
        return(1);
    }

    /* open files */

    if (root[0] == 0){
        perror("ERROR: No fileroot specified");
        return(1);
    }

    if (command){
        sprintf(cmdfile, "%s.cmd", root);
        cmdfp = fopen(cmdfile, "a");
        if (cmdfp == NULL) {
            perror("ERROR: Can't open .cmd file");
            return(1);
        }
        for (i = 1; i < argc; i++)
            fprintf(cmdfp, "%s ", argv[i]);
        fprintf(cmdfp, "\n");
        fflush(cmdfp);
    }

#ifndef THINK_C
    sprintf(cmdfile, "%s.pid", root);
    fpid = fopen(cmdfile, "w");
    fprintf(fpid, "%d\n", getpid());
    fclose(fpid);
#endif /* THINK_C */

    sprintf(cfile, "%s.cf", root);
    cfp = fopen(cfile, "r");
    if (cfp == NULL) {
        perror("ERROR: Can't open .cf file");
        return(1);
    }
    get_nodes();
    make_arrays();
    get_outputs();
    get_connections();
    get_special();

    if (!seed)
        seed = time((time_t *) NULL);
    srandom(seed);
    if (loadflag)
        load_wts();
    else {
        for (i = 0; i < nn; i++){
            w  = *(wt + i);
            dw = *(dwt + i);
            wi = *(winc + i);
            ci = *(cinfo + i);
            for (j = 0; j < nt; j++, ci++, w++, wi++, dw++){
                if (ci->con)
                    *w = rans(weight_limit);
                else
                    *w = 0.;
                *wi = 0.;
                *dw = 0.;
            }
        }
        /*
         * If init_bias, then we want to set initial biases
         * to (*only*) output units to a random negative number.
         * We index into the **wt to find the section of receiver
         * weights for each output node.  The first weight in each
         * section is for unit 0 (bias), so no further indexing needed.
         */
        for (i = 0; i < no; i++){
            w  = *(wt + outputs[i] - 1);
            ci = *(cinfo + outputs[i] - 1);
            if (ci->con)
                *w = init_bias + rans(.1);
            else
                *w = 0.;
        }
    }
    zold[0] = znew[0] = 1.;
    for (i = 1; i < nt; i++)
        zold[i] = znew[i] = 0.;
    if (backprop == 0){
        make_parrays();
        for (i = 0; i < nn; i++){
            for (j = 0; j < nt; j++){
                po = *(*(pold + i) + j);
                pn = *(*(pnew + i) + j);
                for (k = 0; k < nn; k++, po++, pn++){
                    *po = 0.;
                    *pn = 0.;
                }
            }
        }
    }
    data = 0;

    nsweeps += tsweeps;
    for (sweep = tsweeps; sweep < nsweeps; sweep++){

        for (ticks = 0; ticks < nticks; ticks++){

            update_reset(ttime, ticks, rflag, &tmax, &reset);

            if (reset){
                if (backprop == 0)
                    reset_network(zold, znew, pold, pnew);
                else
                    reset_bp_net(zold, znew);
            }

            update_inputs(zold, ticks, iflag, &tmax, &linput);

            if (learning || teacher || (rms_report != 0))
                update_targets(target, ttime, ticks, tflag, &tmax);

            act_nds(zold, zmem, znew, wt, linput, target);

            comp_errors(zold, target, error, &err, &ce_err);

            if (learning && (backprop == 0))
                comp_deltas(pold, pnew, wt, dwt, zold, znew, error);
            if (learning && (backprop == 1))
                comp_backprop(wt, dwt, zold, zmem, target, error, linput);

            if (probe)
                print_nodes(zold);
        }
        if (verify){
            for (i = 0; i < no; i++){
                current_weights_output[i] = zold[ni + outputs[i]];
            }

            //print_output(zold);
        }
        if (rms_report && (++rtime >= rms_report)){
            rtime = 0;
            if (ce == 2)
                print_error(&ce_err);
            else
                print_error(&err);
        }

        if (check && (++ctime >= check)){
            ctime = 0;
            save_wts();
        }

        if (++ttime >= tmax)
            ttime = 0;

        if (learning && (++utime >= umax)){
            utime = 0;
            update_weights(wt, dwt, winc);
        }

    }
    if (learning)
        save_wts();

    return(0);

}
/* -- Ruby interface -- */

int do_print(VALUE key, VALUE val, VALUE in) {
    fprintf(stderr, "Input data is %s\n", StringValueCStr(in));

    fprintf(stderr, "Key %s=>Value %s\n", StringValueCStr(key),
            StringValueCStr(val));

    return ST_CONTINUE;
}

static VALUE tlearn_train(VALUE self, VALUE config) {
    VALUE sweeps_value = rb_hash_aref(config, ID2SYM(rb_intern("sweeps")));
    long nsweeps = NUM2LONG(sweeps_value);

    VALUE file_root_value = rb_hash_aref(config, ID2SYM(rb_intern("file_root")));
    char *file_root = StringValueCStr(file_root_value);

    float current_weights_output[6];

    int result = run_training(nsweeps, file_root, current_weights_output);
    return rb_int_new(result);
}
static VALUE tlearn_fitness(VALUE self, VALUE config) {
    int tlearn_args_count = 4;
    char *tlearn_args[tlearn_args_count];

    VALUE ruby_array = rb_ary_new();
    VALUE file_root_value = rb_hash_aref(config, ID2SYM(rb_intern("file_root")));

    VALUE sweeps_value = rb_hash_aref(config, ID2SYM(rb_intern("sweeps")));
    long nsweeps = NUM2LONG(sweeps_value);

    char *file_root = StringValueCStr(file_root_value);
    char weights[strlen(file_root) + strlen(".wts") + 1]; /* +1 for the NUL terminator */

    strcpy(weights, file_root);

    tlearn_args[0] = "tlearn_fitness";
    tlearn_args[1] = "-l";
    tlearn_args[2] = strcat(weights, ".wts");
    tlearn_args[3] = "-V";

    float current_weights_output[6];

    int failure = run_fitness(tlearn_args_count, tlearn_args, nsweeps, file_root, current_weights_output);

    if(failure == 0){
        float weight;
        int result_index;
        for(result_index = 0; result_index < 6; result_index++){
            weight = current_weights_output[result_index];
            rb_ary_store(ruby_array, result_index, rb_float_new(weight));
        }
        return(ruby_array);
    }
    else{
        return(rb_int_new(failure));
    }
}
void Init_tlearn(void) {
    VALUE klass = rb_define_class("TLearnExt", rb_cObject);

    rb_define_singleton_method(klass, "train", tlearn_train, 1);

    rb_define_singleton_method(klass, "fitness", tlearn_fitness, 1);
}
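/*
 * Sketch of the Ruby-side calls these bindings expose (illustrative, not
 * taken from the gem's docs; see data/lib/tlearn/run.rb for the real
 * wrapper API):
 *
 *   TLearnExt.train(sweeps: 200, file_root: 'xor')    # => 0 on success
 *   TLearnExt.fitness(sweeps: 200, file_root: 'xor')  # => Array of outputs
 */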