tlearn 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data/ext/tlearn/Exp/exp.c +13 -0
- data/ext/tlearn/activate.c +222 -0
- data/ext/tlearn/arrays.c +224 -0
- data/ext/tlearn/compute.c +404 -0
- data/ext/tlearn/extconf.rb +14 -0
- data/ext/tlearn/getopt.c +76 -0
- data/ext/tlearn/parse.c +594 -0
- data/ext/tlearn/subs.c +204 -0
- data/ext/tlearn/tlearn.c +525 -0
- data/ext/tlearn/tlearn_ext.c +587 -0
- data/ext/tlearn/update.c +577 -0
- data/ext/tlearn/weights.c +116 -0
- data/lib/tlearn.rb +17 -0
- data/lib/tlearn/config.rb +101 -0
- data/lib/tlearn/fitness_data.rb +24 -0
- data/lib/tlearn/run.rb +29 -0
- data/lib/tlearn/run_tlearn.rb +68 -0
- data/lib/tlearn/training_data.rb +41 -0
- metadata +64 -0
@@ -0,0 +1,587 @@
|
|
1
|
+
|
2
|
+
/* tlearn_ext.c - simulator for arbitrary networks with time-ordered input */
|
3
|
+
|
4
|
+
/*------------------------------------------------------------------------
|
5
|
+
|
6
|
+
This program simulates learning in a neural network using either
|
7
|
+
the classical back-propagation learning algorithm or a slightly
|
8
|
+
modified form derived in Williams and Zipser, "A Learning Algo-
|
9
|
+
rithm for Continually Running Fully Recurrent Networks." The
|
10
|
+
input is a sequence of vectors of (ascii) floating point numbers
|
11
|
+
contained in a ".data" file. The target outputs are a set of
|
12
|
+
time-stamped vectors of (ascii) floating point numbers (including
|
13
|
+
optional "don't care" values) in a ".teach" file. The network
|
14
|
+
configuration is defined in a ".cf" file documented in tlearn.man.
|
15
|
+
|
16
|
+
------------------------------------------------------------------------*/
|
17
|
+
#include <ruby.h>

#include <math.h>
#include <stdio.h>
#include <string.h>
#include <time.h>
#include <signal.h>
#ifdef ibmpc
#include "strings.h"
#include <fcntl.h>
#else
#ifndef THINK_C
#include <strings.h>
#include <sys/file.h>
#include <stdlib.h>
#include <unistd.h>
#else /* THINK_C */
#include <console.h>
#include <time.h>
#include <stdlib.h>
#endif /* THINK_C */
#endif
#ifdef notdef
#include <sys/types.h>
#include <sys/stat.h>
#endif /* notdef */

#ifdef ibmpc
#define random(x) rand(x)
#define srandom(x) srand(x)
#endif
#ifdef THINK_C
#define random(x) rand(x)
#define srandom(x) srand(x)
#endif /* THINK_C */
|
49
|
+
|
50
|
+
/*
 * Simulator state shared with the other tlearn translation units, which
 * define these objects. run() below re-initialises them at the start of
 * every call so that repeated calls from Ruby start from the same defaults.
 */

extern int nn;              /* number of nodes */
extern int ni;              /* number of inputs */
extern int no;              /* number of outputs */
extern int nt;              /* nn + ni + 1 */
extern int np;              /* ni + 1 */

/*
 * Connection descriptor; cinfo holds one per (receiving node, sending unit)
 * pair. These are tag definitions only, so they carry no `extern`: a
 * storage-class specifier on a declaration with no declarator declares no
 * object and only triggers a compiler warning.
 */
struct cf {
    int con;                /* connection flag */
    int fix;                /* fixed-weight flag */
    int num;                /* group number */
    int lim;                /* weight-limits flag */
    float min;              /* weight minimum */
    float max;              /* weight maximum */
};

/* Per-node activation descriptor; ninfo holds one per node. */
struct nf {
    int func;               /* activation function type */
    int dela;               /* delay flag */
    int targ;               /* target flag */
};

extern struct cf **cinfo;   /* (nn x nt) connection info */
extern struct nf *ninfo;    /* (nn) node activation function info */

extern int *outputs;        /* (no) indices of output nodes */

extern int *selects;        /* (nn+1) nodes selected for probe printout */
extern int *linput;         /* (ni) localist input array */

extern float *znew;         /* (nt) inputs and activations at time t+1 */
extern float *zold;         /* (nt) inputs and activations at time t */
extern float *zmem;         /* (nt) inputs and activations at time t */
extern float **wt;          /* (nn x nt) weight TO node i FROM node j*/
extern float **dwt;         /* (nn x nt) delta weight at time t */
extern float **winc;        /* (nn x nt) accumulated weight increment*/
extern float *target;       /* (no) output target values */
extern float *error;        /* (nn) error = (output - target) values */
extern float ***pnew;       /* (nn x nt x nn) p-variable at time t+1 */
extern float ***pold;       /* (nn x nt x nn) p-variable at time t */

extern float rate;          /* learning rate */
extern float momentum;      /* momentum */
extern float weight_limit;  /* bound for random weight init */
extern float criterion;     /* exit program when rms error is less than this */
extern float init_bias;     /* possible offset for initial output biases */

extern float *data;         /* Required to reset the .data file */

extern long sweep;          /* current sweep */
extern long tsweeps;        /* total sweeps to date */
extern long rms_report;     /* output rms error every "report" sweeps */

extern int ngroups;         /* number of groups */

extern int backprop;        /* flag for standard back propagation (the default) */
extern int teacher;         /* flag for feeding back targets */
extern int localist;        /* flag for speed-up with localist inputs */
extern int randomly;        /* flag for presenting inputs in random order */
extern int limits;          /* flag for limited weights */
extern int ce;              /* flag for cross_entropy */

extern char root[128];      /* root filename for .cf, .data, .teach, etc.*/
extern char loadfile[128];  /* filename for weightfile to be read in */

extern FILE *cfp;           /* file pointer for .cf file */

extern void intr();         /* installed as the signal handler in run() */

extern int load_wts();
extern int save_wts();
extern int act_nds();

extern int optind;          /* getopt() cursor; run() resets it per call */
|
123
|
+
|
124
|
+
|
125
|
+
int run_training(nsweeps, file_path, current_weights_output)
|
126
|
+
long nsweeps;
|
127
|
+
char *file_path;
|
128
|
+
float *current_weights_output;
|
129
|
+
{
|
130
|
+
int argc = 1;
|
131
|
+
char *argv[argc];
|
132
|
+
argv[0] = "tlearn";
|
133
|
+
int status;
|
134
|
+
|
135
|
+
backprop = 0;
|
136
|
+
status = run(argc,argv, nsweeps, file_path, backprop, current_weights_output);
|
137
|
+
|
138
|
+
return(status);
|
139
|
+
}
|
140
|
+
|
141
|
+
int run_fitness(argc,argv, nsweeps, file_path, current_weights_output)
|
142
|
+
int argc;
|
143
|
+
char **argv;
|
144
|
+
long nsweeps;
|
145
|
+
char *file_path;
|
146
|
+
float *current_weights_output;
|
147
|
+
{
|
148
|
+
int status;
|
149
|
+
backprop = 1;
|
150
|
+
status = run(argc,argv, nsweeps, file_path, backprop, current_weights_output);
|
151
|
+
|
152
|
+
return(status);
|
153
|
+
}
|
154
|
+
|
155
|
+
int run(argc,argv, nsweeps, file_path, backprop, current_weights_output)
|
156
|
+
int argc;
|
157
|
+
char **argv;
|
158
|
+
long nsweeps;
|
159
|
+
char *file_path;
|
160
|
+
int backprop;
|
161
|
+
float *current_weights_output;
|
162
|
+
{
|
163
|
+
//Reset EVERYTHING. Globals, such a great idea...
|
164
|
+
optind = 1;
|
165
|
+
sweep = 0;
|
166
|
+
tsweeps = 0;
|
167
|
+
rate = .1;
|
168
|
+
momentum = 0.;
|
169
|
+
weight_limit = 1.;
|
170
|
+
criterion = 0.;
|
171
|
+
init_bias = 0.;
|
172
|
+
rms_report = 0;
|
173
|
+
ngroups = 0;
|
174
|
+
teacher = 0;
|
175
|
+
localist = 0;
|
176
|
+
randomly = 0;
|
177
|
+
limits = 0;
|
178
|
+
ce = 0;
|
179
|
+
outputs = 0;
|
180
|
+
selects = 0;
|
181
|
+
linput = 0;
|
182
|
+
cinfo = 0;
|
183
|
+
ninfo = 0;
|
184
|
+
znew = 0;
|
185
|
+
zold = 0;
|
186
|
+
zmem = 0;
|
187
|
+
pnew = 0;
|
188
|
+
pold = 0;
|
189
|
+
wt = 0;
|
190
|
+
dwt = 0;
|
191
|
+
winc = 0;
|
192
|
+
target = 0;
|
193
|
+
error = 0;
|
194
|
+
cfp = 0;
|
195
|
+
data = 0;
|
196
|
+
ngroups = 0;
|
197
|
+
root[0] = 0;
|
198
|
+
loadfile[0] = 0;
|
199
|
+
|
200
|
+
FILE *fopen();
|
201
|
+
FILE *fpid;
|
202
|
+
extern char *optarg;
|
203
|
+
extern float rans();
|
204
|
+
extern time_t time();
|
205
|
+
|
206
|
+
long ttime = 0; /* number of sweeps since time = 0 */
|
207
|
+
long utime = 0; /* number of sweeps since last update_weights */
|
208
|
+
long tmax = 0; /* maximum number of sweeps (given in .data) */
|
209
|
+
long umax = 0; /* update weights every umax sweeps */
|
210
|
+
long rtime = 0; /* number of sweeps since last rms_report */
|
211
|
+
long check = 0; /* output weights every "check" sweeps */
|
212
|
+
long ctime = 0; /* number of sweeps since last check */
|
213
|
+
|
214
|
+
int c;
|
215
|
+
int i;
|
216
|
+
int j;
|
217
|
+
int k;
|
218
|
+
int nticks = 1; /* number of internal clock ticks per input */
|
219
|
+
int ticks = 0; /* counter for ticks */
|
220
|
+
int learning = 1; /* flag for learning */
|
221
|
+
int reset = 0; /* flag for resetting net */
|
222
|
+
int verify = 0; /* flag for printing output values */
|
223
|
+
int probe = 0; /* flag for printing selected node values */
|
224
|
+
int command = 1; /* flag for writing to .cmd file */
|
225
|
+
int loadflag = 0; /* flag for loading initial weights from file */
|
226
|
+
int iflag = 0; /* flag for -I */
|
227
|
+
int tflag = 0; /* flag for -T */
|
228
|
+
int rflag = 0; /* flag for -x */
|
229
|
+
int seed = 0; /* seed for random() */
|
230
|
+
|
231
|
+
float err = 0.; /* cumulative ss error */
|
232
|
+
float ce_err = 0.; /* cumulate cross_entropy error */
|
233
|
+
|
234
|
+
float *w;
|
235
|
+
float *wi;
|
236
|
+
float *dw;
|
237
|
+
float *pn;
|
238
|
+
float *po;
|
239
|
+
|
240
|
+
struct cf *ci;
|
241
|
+
|
242
|
+
char cmdfile[128]; /* filename for logging runs of program */
|
243
|
+
char cfile[128]; /* filename for .cf file */
|
244
|
+
|
245
|
+
FILE *cmdfp;
|
246
|
+
|
247
|
+
#ifdef THINK_C
|
248
|
+
argc = ccommand(&argv);
|
249
|
+
#endif /* THINK_C */
|
250
|
+
|
251
|
+
signal(SIGINT, intr);
|
252
|
+
#ifndef ibmpc
|
253
|
+
#ifndef THINK_C
|
254
|
+
signal(SIGHUP, intr);
|
255
|
+
signal(SIGQUIT, intr);
|
256
|
+
signal(SIGKILL, intr);
|
257
|
+
#endif /* THINK_C */
|
258
|
+
#endif
|
259
|
+
|
260
|
+
#ifndef ibmpc
|
261
|
+
exp_init();
|
262
|
+
#endif
|
263
|
+
|
264
|
+
root[0] = 0;
|
265
|
+
strcpy(root, file_path);
|
266
|
+
|
267
|
+
while ((c = getopt(argc, argv, "f:hil:m:n:r:s:tC:E:ILM:PpRS:TU:VvXB:H:D:")) != EOF) {
|
268
|
+
switch (c) {
|
269
|
+
case 'C':
|
270
|
+
check = (long) atol(optarg);
|
271
|
+
ctime = check;
|
272
|
+
break;
|
273
|
+
case 'i':
|
274
|
+
command = 0;
|
275
|
+
break;
|
276
|
+
case 'l':
|
277
|
+
loadflag = 1;
|
278
|
+
strcpy(loadfile,optarg);
|
279
|
+
break;
|
280
|
+
case 'm':
|
281
|
+
momentum = (float) atof(optarg);
|
282
|
+
break;
|
283
|
+
case 'n':
|
284
|
+
nticks = (int) atoi(optarg);
|
285
|
+
break;
|
286
|
+
case 'P':
|
287
|
+
learning = 0;
|
288
|
+
/* drop through deliberately */
|
289
|
+
case 'p':
|
290
|
+
probe = 1;
|
291
|
+
break;
|
292
|
+
case 'r':
|
293
|
+
rate = (double) atof(optarg);
|
294
|
+
break;
|
295
|
+
case 't':
|
296
|
+
teacher = 1;
|
297
|
+
break;
|
298
|
+
case 'V':
|
299
|
+
learning = 0;
|
300
|
+
/* drop through deliberately */
|
301
|
+
case 'v':
|
302
|
+
verify = 1;
|
303
|
+
break;
|
304
|
+
case 'X':
|
305
|
+
rflag = 1;
|
306
|
+
break;
|
307
|
+
case 'E':
|
308
|
+
rms_report = (long) atol(optarg);
|
309
|
+
break;
|
310
|
+
case 'I':
|
311
|
+
iflag = 1;
|
312
|
+
break;
|
313
|
+
case 'M':
|
314
|
+
criterion = (float) atof(optarg);
|
315
|
+
break;
|
316
|
+
case 'R':
|
317
|
+
randomly = 1;
|
318
|
+
break;
|
319
|
+
case 'S':
|
320
|
+
seed = atoi(optarg);
|
321
|
+
break;
|
322
|
+
case 'T':
|
323
|
+
tflag = 1;
|
324
|
+
break;
|
325
|
+
case 'U':
|
326
|
+
umax = atol(optarg);
|
327
|
+
break;
|
328
|
+
case 'B':
|
329
|
+
init_bias = atof(optarg);
|
330
|
+
break;
|
331
|
+
/*
|
332
|
+
* if == 1, use cross-entropy as error;
|
333
|
+
* if == 2, also collect cross-entropy stats.
|
334
|
+
*/
|
335
|
+
case 'H':
|
336
|
+
ce = atoi(optarg);
|
337
|
+
break;
|
338
|
+
case '?':
|
339
|
+
case 'h':
|
340
|
+
default:
|
341
|
+
usage();
|
342
|
+
return(2);
|
343
|
+
break;
|
344
|
+
}
|
345
|
+
}
|
346
|
+
if (nsweeps == 0){
|
347
|
+
perror("ERROR: No -s specified");
|
348
|
+
return(1);
|
349
|
+
}
|
350
|
+
|
351
|
+
/* open files */
|
352
|
+
|
353
|
+
if (root[0] == 0){
|
354
|
+
perror("ERROR: No fileroot specified");
|
355
|
+
return(1);
|
356
|
+
}
|
357
|
+
|
358
|
+
if (command){
|
359
|
+
sprintf(cmdfile, "%s.cmd", root);
|
360
|
+
cmdfp = fopen(cmdfile, "a");
|
361
|
+
if (cmdfp == NULL) {
|
362
|
+
perror("ERROR: Can't open .cmd file");
|
363
|
+
return(1);
|
364
|
+
}
|
365
|
+
for (i = 1; i < argc; i++)
|
366
|
+
fprintf(cmdfp,"%s ",argv[i]);
|
367
|
+
fprintf(cmdfp,"\n");
|
368
|
+
fflush(cmdfp);
|
369
|
+
}
|
370
|
+
|
371
|
+
#ifndef THINK_C
|
372
|
+
sprintf(cmdfile, "%s.pid", root);
|
373
|
+
fpid = fopen(cmdfile, "w");
|
374
|
+
fprintf(fpid, "%d\n", getpid());
|
375
|
+
fclose(fpid);
|
376
|
+
#endif /* THINK_C */
|
377
|
+
|
378
|
+
sprintf(cfile, "%s.cf", root);
|
379
|
+
cfp = fopen(cfile, "r");
|
380
|
+
if (cfp == NULL) {
|
381
|
+
perror("ERROR: Can't open .cf file");
|
382
|
+
return(1);
|
383
|
+
}
|
384
|
+
|
385
|
+
get_nodes();
|
386
|
+
make_arrays();
|
387
|
+
get_outputs();
|
388
|
+
get_connections();
|
389
|
+
get_special();
|
390
|
+
|
391
|
+
if (!seed)
|
392
|
+
seed = time((time_t *) NULL);
|
393
|
+
srandom(seed);
|
394
|
+
|
395
|
+
if (loadflag)
|
396
|
+
load_wts();
|
397
|
+
else {
|
398
|
+
for (i = 0; i < nn; i++){
|
399
|
+
w = *(wt + i);
|
400
|
+
dw = *(dwt+ i);
|
401
|
+
wi = *(winc+ i);
|
402
|
+
ci = *(cinfo+ i);
|
403
|
+
for (j = 0; j < nt; j++, ci++, w++, wi++, dw++){
|
404
|
+
if (ci->con)
|
405
|
+
*w = rans(weight_limit);
|
406
|
+
else
|
407
|
+
*w = 0.;
|
408
|
+
*wi = 0.;
|
409
|
+
*dw = 0.;
|
410
|
+
}
|
411
|
+
}
|
412
|
+
/*
|
413
|
+
* If init_bias, then we want to set initial biases
|
414
|
+
* to (*only*) output units to a random negative number.
|
415
|
+
* We index into the **wt to find the section of receiver
|
416
|
+
* weights for each output node. The first weight in each
|
417
|
+
* section is for unit 0 (bias), so no further indexing needed.
|
418
|
+
*/
|
419
|
+
for (i = 0; i < no; i++){
|
420
|
+
w = *(wt + outputs[i] - 1);
|
421
|
+
ci = *(cinfo + outputs[i] - 1);
|
422
|
+
if (ci->con)
|
423
|
+
*w = init_bias + rans(.1);
|
424
|
+
else
|
425
|
+
*w = 0.;
|
426
|
+
}
|
427
|
+
}
|
428
|
+
zold[0] = znew[0] = 1.;
|
429
|
+
for (i = 1; i < nt; i++)
|
430
|
+
zold[i] = znew[i] = 0.;
|
431
|
+
if (backprop == 0){
|
432
|
+
make_parrays();
|
433
|
+
for (i = 0; i < nn; i++){
|
434
|
+
for (j = 0; j < nt; j++){
|
435
|
+
po = *(*(pold + i) + j);
|
436
|
+
pn = *(*(pnew + i) + j);
|
437
|
+
for (k = 0; k < nn; k++, po++, pn++){
|
438
|
+
*po = 0.;
|
439
|
+
*pn = 0.;
|
440
|
+
}
|
441
|
+
}
|
442
|
+
}
|
443
|
+
}
|
444
|
+
|
445
|
+
data = 0;
|
446
|
+
|
447
|
+
nsweeps += tsweeps;
|
448
|
+
for (sweep = tsweeps; sweep < nsweeps; sweep++){
|
449
|
+
|
450
|
+
for (ticks = 0; ticks < nticks; ticks++){
|
451
|
+
|
452
|
+
update_reset(ttime,ticks,rflag,&tmax,&reset);
|
453
|
+
|
454
|
+
if (reset){
|
455
|
+
if (backprop == 0)
|
456
|
+
reset_network(zold,znew,pold,pnew);
|
457
|
+
else
|
458
|
+
reset_bp_net(zold,znew);
|
459
|
+
}
|
460
|
+
|
461
|
+
update_inputs(zold,ticks,iflag,&tmax,&linput);
|
462
|
+
|
463
|
+
if (learning || teacher || (rms_report != 0))
|
464
|
+
update_targets(target,ttime,ticks,tflag,&tmax);
|
465
|
+
|
466
|
+
act_nds(zold,zmem,znew,wt,linput,target);
|
467
|
+
|
468
|
+
comp_errors(zold,target,error,&err,&ce_err);
|
469
|
+
|
470
|
+
if (learning && (backprop == 0))
|
471
|
+
comp_deltas(pold,pnew,wt,dwt,zold,znew,error);
|
472
|
+
if (learning && (backprop == 1))
|
473
|
+
comp_backprop(wt,dwt,zold,zmem,target,error,linput);
|
474
|
+
|
475
|
+
if (probe)
|
476
|
+
print_nodes(zold);
|
477
|
+
}
|
478
|
+
if (verify){
|
479
|
+
for (i = 0; i < no; i++){
|
480
|
+
current_weights_output[i] = zold[ni+outputs[i]];
|
481
|
+
}
|
482
|
+
|
483
|
+
//print_output(zold);
|
484
|
+
}
|
485
|
+
if (rms_report && (++rtime >= rms_report)){
|
486
|
+
rtime = 0;
|
487
|
+
if (ce == 2)
|
488
|
+
print_error(&ce_err);
|
489
|
+
else
|
490
|
+
print_error(&err);
|
491
|
+
}
|
492
|
+
|
493
|
+
if (check && (++ctime >= check)){
|
494
|
+
ctime = 0;
|
495
|
+
save_wts();
|
496
|
+
}
|
497
|
+
|
498
|
+
if (++ttime >= tmax)
|
499
|
+
ttime = 0;
|
500
|
+
|
501
|
+
if (learning && (++utime >= umax)){
|
502
|
+
utime = 0;
|
503
|
+
update_weights(wt,dwt,winc);
|
504
|
+
}
|
505
|
+
|
506
|
+
}
|
507
|
+
if (learning)
|
508
|
+
save_wts();
|
509
|
+
|
510
|
+
return(0);
|
511
|
+
|
512
|
+
}
|
513
|
+
|
514
|
+
/* -- Ruby interface -- */
|
515
|
+
|
516
|
+
/*
 * Debug helper shaped as an rb_hash_foreach() callback. It prints the extra
 * argument `in`, then one key/value pair, to stderr. All three values must
 * be String-convertible, because StringValueCStr raises otherwise. It always
 * returns ST_CONTINUE so that iteration proceeds.
 */
int do_print(VALUE key, VALUE val, VALUE in)
{
    const char *input_text = StringValueCStr(in);
    const char *key_text;
    const char *val_text;

    fprintf(stderr, "Input data is %s\n", input_text);

    key_text = StringValueCStr(key);
    val_text = StringValueCStr(val);
    fprintf(stderr, "Key %s=>Value %s\n", key_text, val_text);

    return ST_CONTINUE;
}
|
524
|
+
|
525
|
+
/*
 * TLearnExt.train(config) - Ruby entry point for training.
 *
 * config must be a Hash. It is read with the Symbol keys :sweeps (number of
 * sweeps, converted with NUM2LONG, which raises on non-numeric or
 * out-of-range values) and :file_root (root path of the network files, which
 * must be a String).
 *
 * Returns the Integer status from run_training(): 0 on success, non-zero on
 * failure.
 */
static VALUE tlearn_train(VALUE self, VALUE config)
{
    VALUE sweeps_value;
    VALUE file_root_value;
    long nsweeps;
    char *file_root;
    /* Unused in training (no -v/-V is passed); zeroed so it is never garbage. */
    float current_weights_output[6] = { 0 };
    int result;

    /* rb_hash_aref() assumes a real Hash; raise TypeError instead of crashing. */
    Check_Type(config, T_HASH);

    sweeps_value = rb_hash_aref(config, ID2SYM(rb_intern("sweeps")));
    /* NUM2LONG range-checks; NUM2DBL into a long was UB for huge values. */
    nsweeps = NUM2LONG(sweeps_value);

    file_root_value = rb_hash_aref(config, ID2SYM(rb_intern("file_root")));
    file_root = StringValueCStr(file_root_value);

    result = run_training(nsweeps, file_root, current_weights_output);
    return INT2NUM(result);
}
|
537
|
+
|
538
|
+
/*
 * TLearnExt.fitness(config) - Ruby entry point for evaluating a trained net.
 *
 * config must be a Hash with the Symbol keys :file_root (String) and :sweeps
 * (number). It loads the weights from "<file_root>.wts" (-l) and runs in
 * verify mode without learning (-V). run() then copies the activation of
 * every output node into the result buffer after each sweep.
 *
 * Returns an Array of 6 Floats (the final output activations, padded with
 * 0.0 when the network has fewer than 6 outputs) on success. On failure it
 * returns the Integer status from run_fitness().
 *
 * NOTE(review): run() writes one float per output node (`no`). A network
 * with more than 6 outputs would overflow the fixed buffer below. Confirm
 * that the .cf files used with this binding never exceed that.
 */
static VALUE tlearn_fitness(VALUE self, VALUE config)
{
    enum { N_RESULTS = 6 };

    VALUE file_root_value;
    VALUE sweeps_value;
    VALUE ruby_array;
    long nsweeps;
    char *file_root;
    float current_weights_output[N_RESULTS] = { 0 };
    int failure;
    int result_index;

    /* rb_hash_aref() assumes a real Hash; raise TypeError instead of crashing. */
    Check_Type(config, T_HASH);

    file_root_value = rb_hash_aref(config, ID2SYM(rb_intern("file_root")));
    sweeps_value = rb_hash_aref(config, ID2SYM(rb_intern("sweeps")));
    /* NUM2LONG range-checks; NUM2DBL into a long was UB for huge values. */
    nsweeps = NUM2LONG(sweeps_value);
    file_root = StringValueCStr(file_root_value);

    {
        /* sizeof(".wts") counts the terminating NUL; strlen() did not, so strcat overflowed. */
        char weights[strlen(file_root) + sizeof(".wts")];
        char *tlearn_args[] = { "tlearn_fitness", "-l", weights, "-V", NULL };
        int tlearn_args_count = (int) (sizeof tlearn_args / sizeof tlearn_args[0]) - 1;

        strcpy(weights, file_root);
        strcat(weights, ".wts");

        failure = run_fitness(tlearn_args_count, tlearn_args, nsweeps,
                              file_root, current_weights_output);
    }

    if (failure != 0)
        return INT2NUM(failure);

    ruby_array = rb_ary_new2(N_RESULTS);
    for (result_index = 0; result_index < N_RESULTS; result_index++)
        rb_ary_store(ruby_array, result_index,
                     rb_float_new(current_weights_output[result_index]));
    return ruby_array;
}
|
577
|
+
|
578
|
+
void Init_tlearn(void) {
|
579
|
+
VALUE klass = rb_define_class("TLearnExt",
|
580
|
+
rb_cObject);
|
581
|
+
|
582
|
+
rb_define_singleton_method(klass,
|
583
|
+
"train", tlearn_train, 1);
|
584
|
+
|
585
|
+
rb_define_singleton_method(klass,
|
586
|
+
"fitness", tlearn_fitness, 1);
|
587
|
+
}
|