@dniskav/neuron 0.2.2 → 0.2.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +38 -0
- package/dist/index.d.mts +291 -11
- package/dist/index.d.ts +291 -11
- package/dist/index.js +1510 -41
- package/dist/index.mjs +1495 -40
- package/package.json +7 -3
package/dist/index.mjs
CHANGED
|
@@ -1,3 +1,71 @@
|
|
|
1
|
+
// src/Validation.ts
|
|
2
|
+
/**
 * Asserts that `arr` is an array of exactly `expectedLength` finite numbers.
 * @param {unknown} arr - value to check
 * @param {number} expectedLength - required array length
 * @param {string} methodName - prefix used in every error message
 * @throws {Error} if `arr` is not an array, has the wrong length, or contains
 *   an entry that is not a finite number (holes count as invalid)
 */
function validateArray(arr, expectedLength, methodName) {
  if (!Array.isArray(arr)) {
    throw new Error(`${methodName}: expected array, got ${typeof arr}`);
  }
  if (arr.length !== expectedLength) {
    throw new Error(
      `${methodName}: expected array of length ${expectedLength}, got ${arr.length}`
    );
  }
  // findIndex visits holes too, so sparse arrays are rejected just like undefined entries.
  const badIndex = arr.findIndex((value) => !Number.isFinite(value));
  if (badIndex !== -1) {
    throw new Error(`${methodName}: invalid value at index ${badIndex}: ${arr[badIndex]}`);
  }
}
|
|
19
|
+
/**
 * Asserts that `arr` is an array with at least `minLength` entries, all finite numbers.
 * @param {unknown} arr - value to check
 * @param {number} minLength - minimum accepted length (inclusive)
 * @param {string} methodName - prefix used in every error message
 * @throws {Error} if `arr` is not an array, is too short, or holds a non-finite entry
 */
function validateArrayMinLength(arr, minLength, methodName) {
  if (!Array.isArray(arr)) {
    throw new Error(`${methodName}: expected array, got ${typeof arr}`);
  }
  if (arr.length < minLength) {
    throw new Error(
      `${methodName}: expected array of at least length ${minLength}, got ${arr.length}`
    );
  }
  // entries() yields holes as undefined, which Number.isFinite rejects.
  for (const [index, value] of arr.entries()) {
    if (!Number.isFinite(value)) {
      throw new Error(`${methodName}: invalid value at index ${index}: ${value}`);
    }
  }
}
|
|
36
|
+
/**
 * Asserts that `arr` is an `expectedRows` x `expectedCols` matrix of finite numbers.
 * Rows are checked in order; within a row, shape is checked before values.
 * @param {unknown} arr - value to check
 * @param {number} expectedRows - required number of rows
 * @param {number} expectedCols - required length of every row
 * @param {string} methodName - prefix used in every error message
 * @throws {Error} on the first structural or value problem found
 */
function validate2DArray(arr, expectedRows, expectedCols, methodName) {
  if (!Array.isArray(arr)) {
    throw new Error(`${methodName}: expected 2D array, got ${typeof arr}`);
  }
  if (arr.length !== expectedRows) {
    throw new Error(`${methodName}: expected ${expectedRows} rows, got ${arr.length}`);
  }
  for (const [r, row] of arr.entries()) {
    if (!Array.isArray(row)) {
      throw new Error(`${methodName}: row ${r} is not an array`);
    }
    if (row.length !== expectedCols) {
      throw new Error(
        `${methodName}: row ${r} expected ${expectedCols} cols, got ${row.length}`
      );
    }
    const c = row.findIndex((value) => !Number.isFinite(value));
    if (c !== -1) {
      throw new Error(`${methodName}: invalid value at [${r}][${c}]: ${row[c]}`);
    }
  }
}
|
|
63
|
+
/**
 * Asserts that `value` is a finite number primitive.
 * @param {unknown} value - value to check
 * @param {string} methodName - prefix used in the error message
 * @throws {Error} if `value` is not a finite number (NaN, ±Infinity, strings, ...)
 */
function validateNumber(value, methodName) {
  const isValid = Number.isFinite(value);
  if (isValid) return;
  throw new Error(`${methodName}: expected finite number, got ${value}`);
}
|
|
68
|
+
|
|
1
69
|
// src/Neuron.ts
|
|
2
70
|
function sigmoid(x) {
|
|
3
71
|
return 1 / (1 + Math.exp(-x));
|
|
@@ -8,13 +76,18 @@ var Neuron = class {
|
|
|
8
76
|
this.bias = Math.random() * 0.1;
|
|
9
77
|
}
|
|
10
78
|
predict(input) {
|
|
79
|
+
validateNumber(input, "Neuron.predict");
|
|
11
80
|
return sigmoid(input * this.weight + this.bias);
|
|
12
81
|
}
|
|
13
82
|
train(input, target, lr) {
|
|
83
|
+
validateNumber(input, "Neuron.train");
|
|
84
|
+
validateNumber(target, "Neuron.train");
|
|
85
|
+
validateNumber(lr, "Neuron.train");
|
|
14
86
|
const prediction = this.predict(input);
|
|
15
87
|
const error = target - prediction;
|
|
16
|
-
|
|
17
|
-
this.
|
|
88
|
+
const grad = error * prediction * (1 - prediction);
|
|
89
|
+
this.weight += lr * grad * input;
|
|
90
|
+
this.bias += lr * grad;
|
|
18
91
|
}
|
|
19
92
|
};
|
|
20
93
|
|
|
@@ -69,6 +142,19 @@ var Momentum = class {
|
|
|
69
142
|
return weight + this.v;
|
|
70
143
|
}
|
|
71
144
|
};
|
|
145
|
+
/**
 * Optimizer decorator: clamps each gradient to [-clipValue, clipValue] and then
 * delegates the actual update to the wrapped optimizer's `step`.
 */
var ClipOptimizer = class {
  /**
   * @param {{ step(weight: number, gradient: number, lr: number): number }} inner - optimizer to wrap
   * @param {number} clipValue - positive clipping bound (Infinity disables clipping)
   * @throws {Error} if clipValue is not a positive number
   */
  constructor(inner, clipValue) {
    // A non-positive or NaN bound silently corrupts training: with clipValue = -1
    // every gradient becomes +1, with NaN every gradient becomes NaN, and 0 freezes updates.
    if (typeof clipValue !== "number" || !(clipValue > 0)) {
      throw new Error(`ClipOptimizer: clipValue must be a positive number, got ${clipValue}`);
    }
    this.inner = inner;
    this.clipValue = clipValue;
  }
  /**
   * @param {number} weight - current parameter value
   * @param {number} gradient - raw gradient, clamped before use
   * @param {number} lr - learning rate forwarded unchanged
   * @returns {number} the new parameter value returned by the inner optimizer
   */
  step(weight, gradient, lr) {
    const clipped = Math.max(-this.clipValue, Math.min(this.clipValue, gradient));
    return this.inner.step(weight, clipped, lr);
  }
};
|
|
155
|
+
/**
 * Wraps an optimizer factory so that every optimizer it produces clips gradients.
 * @param {() => object} innerFactory - creates a fresh inner optimizer per call
 * @param {number} clipValue - clipping bound passed to each ClipOptimizer
 * @returns {() => ClipOptimizer} factory producing a new clipped optimizer per call
 */
function ClippedOptimizerFactory(innerFactory, clipValue) {
  const makeClipped = () => {
    const inner = innerFactory();
    return new ClipOptimizer(inner, clipValue);
  };
  return makeClipped;
}
|
|
72
158
|
var Adam = class {
|
|
73
159
|
constructor(beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8) {
|
|
74
160
|
this.beta1 = beta1;
|
|
@@ -99,6 +185,7 @@ var NeuronN = class {
|
|
|
99
185
|
this._opts = Array.from({ length: nInputs + 1 }, optimizerFactory);
|
|
100
186
|
}
|
|
101
187
|
predict(inputs) {
|
|
188
|
+
validateArray(inputs, this.weights.length, "NeuronN.predict");
|
|
102
189
|
const sum = inputs.reduce((acc, e, i) => acc + e * this.weights[i], this.bias);
|
|
103
190
|
return this.activation.fn(sum);
|
|
104
191
|
}
|
|
@@ -111,7 +198,8 @@ var NeuronN = class {
|
|
|
111
198
|
train(inputs, target, lr) {
|
|
112
199
|
const prediction = this.predict(inputs);
|
|
113
200
|
const error = target - prediction;
|
|
114
|
-
|
|
201
|
+
const grad = error * this.activation.dfn(prediction);
|
|
202
|
+
this._update(inputs.map((inp) => grad * inp), grad, lr);
|
|
115
203
|
}
|
|
116
204
|
};
|
|
117
205
|
|
|
@@ -136,29 +224,99 @@ var Network = class {
|
|
|
136
224
|
this.outputLayer = new Layer(nOutputs, nHidden);
|
|
137
225
|
}
|
|
138
226
|
predict(inputs) {
|
|
227
|
+
validateArray(inputs, this.hiddenLayer.neurons[0].weights.length, "Network.predict");
|
|
139
228
|
const hiddenOut = this.hiddenLayer.predict(inputs);
|
|
140
229
|
return this.outputLayer.predict(hiddenOut)[0];
|
|
141
230
|
}
|
|
142
231
|
// Trains on a single example. Returns the squared error.
|
|
143
232
|
train(inputs, target, lr) {
|
|
233
|
+
validateArray(inputs, this.hiddenLayer.neurons[0].weights.length, "Network.train");
|
|
234
|
+
validateNumber(target, "Network.train");
|
|
235
|
+
validateNumber(lr, "Network.train");
|
|
144
236
|
const hiddenOut = this.hiddenLayer.predict(inputs);
|
|
145
237
|
const prediction = this.outputLayer.predict(hiddenOut)[0];
|
|
146
238
|
const outputError = target - prediction;
|
|
147
239
|
const outputDelta = outputError * prediction * (1 - prediction);
|
|
148
240
|
const outputNeuron = this.outputLayer.neurons[0];
|
|
241
|
+
const hiddenDeltas = this.hiddenLayer.neurons.map((neuron, i) => {
|
|
242
|
+
const hiddenOut_i = hiddenOut[i];
|
|
243
|
+
const hiddenError = outputDelta * outputNeuron.weights[i];
|
|
244
|
+
return hiddenError * hiddenOut_i * (1 - hiddenOut_i);
|
|
245
|
+
});
|
|
246
|
+
this.hiddenLayer.neurons.forEach((neuron, i) => {
|
|
247
|
+
neuron.weights = neuron.weights.map((w, j) => w + lr * hiddenDeltas[i] * inputs[j]);
|
|
248
|
+
neuron.bias += lr * hiddenDeltas[i];
|
|
249
|
+
});
|
|
149
250
|
outputNeuron.weights = outputNeuron.weights.map(
|
|
150
251
|
(w, i) => w + lr * outputDelta * hiddenOut[i]
|
|
151
252
|
);
|
|
152
253
|
outputNeuron.bias += lr * outputDelta;
|
|
153
|
-
this.hiddenLayer.neurons.forEach((neuron, i) => {
|
|
154
|
-
const hiddenOut_i = hiddenOut[i];
|
|
155
|
-
const hiddenError = outputDelta * outputNeuron.weights[i];
|
|
156
|
-
const hiddenDelta = hiddenError * hiddenOut_i * (1 - hiddenOut_i);
|
|
157
|
-
neuron.weights = neuron.weights.map((w, j) => w + lr * hiddenDelta * inputs[j]);
|
|
158
|
-
neuron.bias += lr * hiddenDelta;
|
|
159
|
-
});
|
|
160
254
|
return outputError * outputError;
|
|
161
255
|
}
|
|
256
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
257
|
+
// Order: hidden layer (all neurons: weights then bias), then output layer.
|
|
258
|
+
getWeights() {
|
|
259
|
+
const w = [];
|
|
260
|
+
for (const n of this.hiddenLayer.neurons) {
|
|
261
|
+
w.push(...n.weights, n.bias);
|
|
262
|
+
}
|
|
263
|
+
for (const n of this.outputLayer.neurons) {
|
|
264
|
+
w.push(...n.weights, n.bias);
|
|
265
|
+
}
|
|
266
|
+
return w;
|
|
267
|
+
}
|
|
268
|
+
setWeights(weights) {
|
|
269
|
+
let idx = 0;
|
|
270
|
+
for (const n of this.hiddenLayer.neurons) {
|
|
271
|
+
for (let j = 0; j < n.weights.length; j++) n.weights[j] = weights[idx++];
|
|
272
|
+
n.bias = weights[idx++];
|
|
273
|
+
}
|
|
274
|
+
for (const n of this.outputLayer.neurons) {
|
|
275
|
+
for (let j = 0; j < n.weights.length; j++) n.weights[j] = weights[idx++];
|
|
276
|
+
n.bias = weights[idx++];
|
|
277
|
+
}
|
|
278
|
+
}
|
|
279
|
+
};
|
|
280
|
+
|
|
281
|
+
// src/Dropout.ts
|
|
282
|
+
/**
 * Inverted-dropout layer with no trainable parameters.
 * During training each activation is kept with probability (1 - rate) and scaled
 * by 1 / (1 - rate); otherwise it is zeroed. The sampled mask is remembered so
 * backward() can apply the same mask to gradients.
 */
var Dropout = class {
  /**
   * @param {number} rate - drop probability in [0, 1)
   * @throws {Error} if rate is outside [0, 1) or is NaN/undefined
   */
  constructor(rate) {
    this._mask = null;
    // Written as a negated in-range test so NaN and undefined are rejected.
    // The previous `rate < 0 || rate >= 1` accepted them, producing a NaN scale
    // and an all-zero output.
    if (!(rate >= 0 && rate < 1)) {
      throw new Error(`Dropout rate must be in [0, 1), got ${rate}`);
    }
    this.rate = rate;
  }
  // ── Forward ───────────────────────────────────────────────────────────────
  // x: number[] → number[]
  // If training, applies inverted dropout mask.
  // If not training (or rate is 0), clears the mask and returns a copy of the input.
  forward(x, training = true) {
    if (!training || this.rate === 0) {
      this._mask = null;
      return [...x];
    }
    const scale = 1 / (1 - this.rate);
    this._mask = x.map(() => Math.random() > this.rate ? scale : 0);
    return x.map((v, i) => v * this._mask[i]);
  }
  // ── Backward ──────────────────────────────────────────────────────────────
  // dOut: number[] → number[]
  // Multiplies by the mask from the last forward pass (zero where the activation
  // was dropped, scaled elsewhere); with no mask, returns a copy of dOut.
  backward(dOut) {
    if (!this._mask) return [...dOut];
    return dOut.map((d, i) => d * this._mask[i]);
  }
  // ── Reset mask between forward passes ─────────────────────────────────────
  resetMask() {
    this._mask = null;
  }
  // ── No trainable params ───────────────────────────────────────────────────
  getWeights() {
    return [];
  }
  setWeights(_weights) {
  }
};
|
|
163
321
|
|
|
164
322
|
// src/NetworkN.ts
|
|
@@ -169,30 +327,96 @@ var NetworkN = class {
|
|
|
169
327
|
const nLayers = structure.length - 1;
|
|
170
328
|
const activations = options.activations ?? Array.from({ length: nLayers }, () => sigmoid2);
|
|
171
329
|
const optimizer = options.optimizer ?? defaultOptimizer3;
|
|
330
|
+
const dropoutRate = options.dropoutRate ?? 0;
|
|
331
|
+
if (activations.length !== nLayers) {
|
|
332
|
+
throw new Error(`Expected ${nLayers} activations, got ${activations.length}`);
|
|
333
|
+
}
|
|
334
|
+
if (dropoutRate < 0 || dropoutRate >= 1) {
|
|
335
|
+
throw new Error(`Dropout rate must be in [0, 1), got ${dropoutRate}`);
|
|
336
|
+
}
|
|
337
|
+
this._residual = options.residual ?? false;
|
|
172
338
|
this.layers = [];
|
|
173
339
|
for (let i = 1; i < structure.length; i++) {
|
|
174
340
|
this.layers.push(new Layer(structure[i], structure[i - 1], activations[i - 1], optimizer));
|
|
175
341
|
}
|
|
342
|
+
this._dropouts = [];
|
|
343
|
+
if (dropoutRate > 0) {
|
|
344
|
+
for (let i = 0; i < nLayers - 1; i++) {
|
|
345
|
+
this._dropouts.push(new Dropout(dropoutRate));
|
|
346
|
+
}
|
|
347
|
+
}
|
|
348
|
+
const outputLayer = this.layers[this.layers.length - 1];
|
|
349
|
+
const outputActivation = outputLayer.neurons[0].activation;
|
|
350
|
+
for (let i = 1; i < outputLayer.neurons.length; i++) {
|
|
351
|
+
if (outputLayer.neurons[i].activation !== outputActivation) {
|
|
352
|
+
throw new Error("All output neurons must share the same activation function");
|
|
353
|
+
}
|
|
354
|
+
}
|
|
176
355
|
}
|
|
177
|
-
predict(inputs) {
|
|
178
|
-
|
|
356
|
+
predict(inputs, training = false) {
|
|
357
|
+
validateArray(inputs, this.structure[0], "NetworkN.predict");
|
|
358
|
+
let current = [...inputs];
|
|
359
|
+
for (let i = 0; i < this.layers.length; i++) {
|
|
360
|
+
const layerInput = [...current];
|
|
361
|
+
const layerOutput = this.layers[i].predict(current);
|
|
362
|
+
if (this._shouldResidual(i)) {
|
|
363
|
+
if (this.structure[i] === this.structure[i + 1]) {
|
|
364
|
+
current = layerOutput.map((v, j) => v + layerInput[j]);
|
|
365
|
+
} else {
|
|
366
|
+
current = [...layerOutput];
|
|
367
|
+
}
|
|
368
|
+
} else {
|
|
369
|
+
current = [...layerOutput];
|
|
370
|
+
}
|
|
371
|
+
if (i < this._dropouts.length) {
|
|
372
|
+
current = this._dropouts[i].forward(current, training);
|
|
373
|
+
}
|
|
374
|
+
}
|
|
375
|
+
return current;
|
|
179
376
|
}
|
|
180
377
|
// Generalized backpropagation across L layers.
|
|
181
378
|
// Returns the mean squared error for the example.
|
|
182
379
|
train(inputs, targets, lr) {
|
|
380
|
+
validateArray(inputs, this.structure[0], "NetworkN.train");
|
|
381
|
+
validateArray(targets, this.structure[this.structure.length - 1], "NetworkN.train");
|
|
183
382
|
const act = [inputs];
|
|
184
|
-
for (
|
|
383
|
+
for (let i = 0; i < this.layers.length; i++) {
|
|
384
|
+
const layerInput = act[act.length - 1];
|
|
385
|
+
const layerOutput = this.layers[i].predict(layerInput);
|
|
386
|
+
let current;
|
|
387
|
+
if (this._shouldResidual(i)) {
|
|
388
|
+
if (this.structure[i] === this.structure[i + 1]) {
|
|
389
|
+
current = layerOutput.map((v, j) => v + layerInput[j]);
|
|
390
|
+
} else {
|
|
391
|
+
current = [...layerOutput];
|
|
392
|
+
}
|
|
393
|
+
} else {
|
|
394
|
+
current = [...layerOutput];
|
|
395
|
+
}
|
|
396
|
+
if (i < this._dropouts.length) {
|
|
397
|
+
current = this._dropouts[i].forward(current, true);
|
|
398
|
+
}
|
|
399
|
+
act.push(current);
|
|
400
|
+
}
|
|
185
401
|
const pred = act[act.length - 1];
|
|
186
402
|
const outAct = this.layers[this.layers.length - 1].neurons[0].activation;
|
|
187
403
|
let deltas = pred.map((p, i) => (targets[i] - p) * outAct.dfn(p));
|
|
188
404
|
for (let l = this.layers.length - 1; l >= 0; l--) {
|
|
189
405
|
const layer = this.layers[l];
|
|
406
|
+
if (l < this._dropouts.length) {
|
|
407
|
+
deltas = this._dropouts[l].backward(deltas);
|
|
408
|
+
}
|
|
190
409
|
const layerIn = act[l];
|
|
191
410
|
const prevAct = l > 0 ? this.layers[l - 1].neurons[0].activation : null;
|
|
192
411
|
const prevDeltas = layerIn.map((out, j) => {
|
|
193
412
|
const errProp = layer.neurons.reduce((s, n, k) => s + deltas[k] * n.weights[j], 0);
|
|
194
413
|
return prevAct ? errProp * prevAct.dfn(out) : errProp;
|
|
195
414
|
});
|
|
415
|
+
if (this._shouldResidual(l) && this.structure[l] === this.structure[l + 1]) {
|
|
416
|
+
for (let j = 0; j < prevDeltas.length; j++) {
|
|
417
|
+
prevDeltas[j] += deltas[j];
|
|
418
|
+
}
|
|
419
|
+
}
|
|
196
420
|
layer.neurons.forEach((n, k) => {
|
|
197
421
|
n._update(layerIn.map((inp) => deltas[k] * inp), deltas[k], lr);
|
|
198
422
|
});
|
|
@@ -204,22 +428,74 @@ var NetworkN = class {
|
|
|
204
428
|
// Useful for custom loss functions (e.g. physics-based gradients).
|
|
205
429
|
trainWithDeltas(inputs, outputDeltas, lr) {
|
|
206
430
|
const act = [inputs];
|
|
207
|
-
for (
|
|
431
|
+
for (let i = 0; i < this.layers.length; i++) {
|
|
432
|
+
const layerInput = act[act.length - 1];
|
|
433
|
+
const layerOutput = this.layers[i].predict(layerInput);
|
|
434
|
+
let current;
|
|
435
|
+
if (this._shouldResidual(i)) {
|
|
436
|
+
if (this.structure[i] === this.structure[i + 1]) {
|
|
437
|
+
current = layerOutput.map((v, j) => v + layerInput[j]);
|
|
438
|
+
} else {
|
|
439
|
+
current = [...layerOutput];
|
|
440
|
+
}
|
|
441
|
+
} else {
|
|
442
|
+
current = [...layerOutput];
|
|
443
|
+
}
|
|
444
|
+
if (i < this._dropouts.length) {
|
|
445
|
+
current = this._dropouts[i].forward(current, true);
|
|
446
|
+
}
|
|
447
|
+
act.push(current);
|
|
448
|
+
}
|
|
208
449
|
let deltas = outputDeltas;
|
|
209
450
|
for (let l = this.layers.length - 1; l >= 0; l--) {
|
|
210
451
|
const layer = this.layers[l];
|
|
452
|
+
if (l < this._dropouts.length) {
|
|
453
|
+
deltas = this._dropouts[l].backward(deltas);
|
|
454
|
+
}
|
|
211
455
|
const layerIn = act[l];
|
|
212
456
|
const prevAct = l > 0 ? this.layers[l - 1].neurons[0].activation : null;
|
|
213
457
|
const prevDeltas = layerIn.map((out, j) => {
|
|
214
458
|
const errProp = layer.neurons.reduce((s, n, k) => s + deltas[k] * n.weights[j], 0);
|
|
215
459
|
return prevAct ? errProp * prevAct.dfn(out) : errProp;
|
|
216
460
|
});
|
|
461
|
+
if (this._shouldResidual(l) && this.structure[l] === this.structure[l + 1]) {
|
|
462
|
+
for (let j = 0; j < prevDeltas.length; j++) {
|
|
463
|
+
prevDeltas[j] += deltas[j];
|
|
464
|
+
}
|
|
465
|
+
}
|
|
217
466
|
layer.neurons.forEach((n, k) => {
|
|
218
467
|
n._update(layerIn.map((inp) => deltas[k] * inp), deltas[k], lr);
|
|
219
468
|
});
|
|
220
469
|
deltas = prevDeltas;
|
|
221
470
|
}
|
|
222
471
|
}
|
|
472
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
473
|
+
// Order: layer 0 (all neurons), layer 1, ..., layer N.
|
|
474
|
+
getWeights() {
|
|
475
|
+
for (const d of this._dropouts) d.resetMask();
|
|
476
|
+
const w = [];
|
|
477
|
+
for (const layer of this.layers) {
|
|
478
|
+
for (const n of layer.neurons) {
|
|
479
|
+
w.push(...n.weights, n.bias);
|
|
480
|
+
}
|
|
481
|
+
}
|
|
482
|
+
return w;
|
|
483
|
+
}
|
|
484
|
+
setWeights(weights) {
|
|
485
|
+
for (const d of this._dropouts) d.resetMask();
|
|
486
|
+
let idx = 0;
|
|
487
|
+
for (const layer of this.layers) {
|
|
488
|
+
for (const n of layer.neurons) {
|
|
489
|
+
for (let j = 0; j < n.weights.length; j++) n.weights[j] = weights[idx++];
|
|
490
|
+
n.bias = weights[idx++];
|
|
491
|
+
}
|
|
492
|
+
}
|
|
493
|
+
}
|
|
494
|
+
// ── Helper ───────────────────────────────────────────────────────────────
|
|
495
|
+
_shouldResidual(layerIndex) {
|
|
496
|
+
if (typeof this._residual === "function") return this._residual(layerIndex);
|
|
497
|
+
return this._residual;
|
|
498
|
+
}
|
|
223
499
|
};
|
|
224
500
|
|
|
225
501
|
// src/LSTMLayer.ts
|
|
@@ -248,8 +524,11 @@ var Gate = class {
|
|
|
248
524
|
}
|
|
249
525
|
};
|
|
250
526
|
var LSTMLayer = class {
|
|
251
|
-
constructor(inputSize, hiddenSize) {
|
|
527
|
+
constructor(inputSize, hiddenSize, optimizerFactory = () => new SGD()) {
|
|
252
528
|
this._traj = [];
|
|
529
|
+
if (inputSize <= 0 || hiddenSize <= 0) {
|
|
530
|
+
throw new Error(`LSTMLayer: inputSize and hiddenSize must be positive, got ${inputSize} and ${hiddenSize}`);
|
|
531
|
+
}
|
|
253
532
|
this.inputSize = inputSize;
|
|
254
533
|
this.hSize = hiddenSize;
|
|
255
534
|
this.h = new Array(hiddenSize).fill(0);
|
|
@@ -258,6 +537,29 @@ var LSTMLayer = class {
|
|
|
258
537
|
this.inputGate = new Gate(inputSize, hiddenSize);
|
|
259
538
|
this.cellGate = new Gate(inputSize, hiddenSize);
|
|
260
539
|
this.outputGate = new Gate(inputSize, hiddenSize);
|
|
540
|
+
const combSize = inputSize + hiddenSize;
|
|
541
|
+
this._optimizers = {
|
|
542
|
+
forgetW: Array.from(
|
|
543
|
+
{ length: hiddenSize },
|
|
544
|
+
() => Array.from({ length: combSize }, () => optimizerFactory())
|
|
545
|
+
),
|
|
546
|
+
forgetB: Array.from({ length: hiddenSize }, () => optimizerFactory()),
|
|
547
|
+
inputW: Array.from(
|
|
548
|
+
{ length: hiddenSize },
|
|
549
|
+
() => Array.from({ length: combSize }, () => optimizerFactory())
|
|
550
|
+
),
|
|
551
|
+
inputB: Array.from({ length: hiddenSize }, () => optimizerFactory()),
|
|
552
|
+
cellW: Array.from(
|
|
553
|
+
{ length: hiddenSize },
|
|
554
|
+
() => Array.from({ length: combSize }, () => optimizerFactory())
|
|
555
|
+
),
|
|
556
|
+
cellB: Array.from({ length: hiddenSize }, () => optimizerFactory()),
|
|
557
|
+
outputW: Array.from(
|
|
558
|
+
{ length: hiddenSize },
|
|
559
|
+
() => Array.from({ length: combSize }, () => optimizerFactory())
|
|
560
|
+
),
|
|
561
|
+
outputB: Array.from({ length: hiddenSize }, () => optimizerFactory())
|
|
562
|
+
};
|
|
261
563
|
}
|
|
262
564
|
// ── Reset state and trajectory (call at episode start) ────────────────────
|
|
263
565
|
reset() {
|
|
@@ -267,6 +569,9 @@ var LSTMLayer = class {
|
|
|
267
569
|
}
|
|
268
570
|
// ── Forward pass ──────────────────────────────────────────────────────────
|
|
269
571
|
predict(inputs) {
|
|
572
|
+
if (!Array.isArray(inputs) || inputs.length !== this.inputSize) {
|
|
573
|
+
throw new Error(`LSTMLayer.predict: expected array of length ${this.inputSize}, got ${inputs?.length}`);
|
|
574
|
+
}
|
|
270
575
|
const combined = [...inputs, ...this.h];
|
|
271
576
|
const c_prev = [...this.c];
|
|
272
577
|
const zf = this.forgetGate.linear(combined);
|
|
@@ -341,15 +646,15 @@ var LSTMLayer = class {
|
|
|
341
646
|
const scale = lr / T;
|
|
342
647
|
for (let k = 0; k < hSize; k++) {
|
|
343
648
|
for (let j = 0; j < combSize; j++) {
|
|
344
|
-
this.forgetGate.W[k][j]
|
|
345
|
-
this.inputGate.W[k][j]
|
|
346
|
-
this.cellGate.W[k][j]
|
|
347
|
-
this.outputGate.W[k][j]
|
|
649
|
+
this.forgetGate.W[k][j] = this._optimizers.forgetW[k][j].step(this.forgetGate.W[k][j], dWf[k][j], scale);
|
|
650
|
+
this.inputGate.W[k][j] = this._optimizers.inputW[k][j].step(this.inputGate.W[k][j], dWi[k][j], scale);
|
|
651
|
+
this.cellGate.W[k][j] = this._optimizers.cellW[k][j].step(this.cellGate.W[k][j], dWg[k][j], scale);
|
|
652
|
+
this.outputGate.W[k][j] = this._optimizers.outputW[k][j].step(this.outputGate.W[k][j], dWo[k][j], scale);
|
|
348
653
|
}
|
|
349
|
-
this.forgetGate.b[k]
|
|
350
|
-
this.inputGate.b[k]
|
|
351
|
-
this.cellGate.b[k]
|
|
352
|
-
this.outputGate.b[k]
|
|
654
|
+
this.forgetGate.b[k] = this._optimizers.forgetB[k].step(this.forgetGate.b[k], dbf[k], scale);
|
|
655
|
+
this.inputGate.b[k] = this._optimizers.inputB[k].step(this.inputGate.b[k], dbi[k], scale);
|
|
656
|
+
this.cellGate.b[k] = this._optimizers.cellB[k].step(this.cellGate.b[k], dbg[k], scale);
|
|
657
|
+
this.outputGate.b[k] = this._optimizers.outputB[k].step(this.outputGate.b[k], dbo[k], scale);
|
|
353
658
|
}
|
|
354
659
|
this._traj = [];
|
|
355
660
|
}
|
|
@@ -372,6 +677,35 @@ var LSTMLayer = class {
|
|
|
372
677
|
this.outputGate.W = data.outputGate.W;
|
|
373
678
|
this.outputGate.b = data.outputGate.b;
|
|
374
679
|
}
|
|
680
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
681
|
+
// Order: forgetGate (W, b), inputGate (W, b), cellGate (W, b), outputGate (W, b).
|
|
682
|
+
getWeightsFlat() {
|
|
683
|
+
const w = [];
|
|
684
|
+
for (const row of this.forgetGate.W) w.push(...row);
|
|
685
|
+
w.push(...this.forgetGate.b);
|
|
686
|
+
for (const row of this.inputGate.W) w.push(...row);
|
|
687
|
+
w.push(...this.inputGate.b);
|
|
688
|
+
for (const row of this.cellGate.W) w.push(...row);
|
|
689
|
+
w.push(...this.cellGate.b);
|
|
690
|
+
for (const row of this.outputGate.W) w.push(...row);
|
|
691
|
+
w.push(...this.outputGate.b);
|
|
692
|
+
return w;
|
|
693
|
+
}
|
|
694
|
+
setWeightsFlat(weights) {
|
|
695
|
+
let idx = 0;
|
|
696
|
+
for (let i = 0; i < this.forgetGate.W.length; i++)
|
|
697
|
+
for (let j = 0; j < this.forgetGate.W[i].length; j++) this.forgetGate.W[i][j] = weights[idx++];
|
|
698
|
+
for (let i = 0; i < this.forgetGate.b.length; i++) this.forgetGate.b[i] = weights[idx++];
|
|
699
|
+
for (let i = 0; i < this.inputGate.W.length; i++)
|
|
700
|
+
for (let j = 0; j < this.inputGate.W[i].length; j++) this.inputGate.W[i][j] = weights[idx++];
|
|
701
|
+
for (let i = 0; i < this.inputGate.b.length; i++) this.inputGate.b[i] = weights[idx++];
|
|
702
|
+
for (let i = 0; i < this.cellGate.W.length; i++)
|
|
703
|
+
for (let j = 0; j < this.cellGate.W[i].length; j++) this.cellGate.W[i][j] = weights[idx++];
|
|
704
|
+
for (let i = 0; i < this.cellGate.b.length; i++) this.cellGate.b[i] = weights[idx++];
|
|
705
|
+
for (let i = 0; i < this.outputGate.W.length; i++)
|
|
706
|
+
for (let j = 0; j < this.outputGate.W[i].length; j++) this.outputGate.W[i][j] = weights[idx++];
|
|
707
|
+
for (let i = 0; i < this.outputGate.b.length; i++) this.outputGate.b[i] = weights[idx++];
|
|
708
|
+
}
|
|
375
709
|
};
|
|
376
710
|
|
|
377
711
|
// src/NetworkLSTM.ts
|
|
@@ -398,6 +732,7 @@ var NetworkLSTM = class {
|
|
|
398
732
|
}
|
|
399
733
|
// ── Forward pass ──────────────────────────────────────────────────────────
|
|
400
734
|
predict(inputs) {
|
|
735
|
+
validateArray(inputs, this.inputSize, "NetworkLSTM.predict");
|
|
401
736
|
const h = this.lstm.predict(inputs);
|
|
402
737
|
const acts = [h];
|
|
403
738
|
for (const layer of this.denseLayers) {
|
|
@@ -473,6 +808,30 @@ var NetworkLSTM = class {
|
|
|
473
808
|
});
|
|
474
809
|
});
|
|
475
810
|
}
|
|
811
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
812
|
+
// Order: LSTM (flat), then dense layer 0, dense layer 1, ..., dense layer N.
|
|
813
|
+
getWeightsFlat() {
|
|
814
|
+
const w = [];
|
|
815
|
+
w.push(...this.lstm.getWeightsFlat());
|
|
816
|
+
for (const layer of this.denseLayers) {
|
|
817
|
+
for (const n of layer.neurons) {
|
|
818
|
+
w.push(...n.weights, n.bias);
|
|
819
|
+
}
|
|
820
|
+
}
|
|
821
|
+
return w;
|
|
822
|
+
}
|
|
823
|
+
setWeightsFlat(weights) {
|
|
824
|
+
let idx = 0;
|
|
825
|
+
const lstmLen = this.lstm.getWeightsFlat().length;
|
|
826
|
+
this.lstm.setWeightsFlat(weights.slice(idx, idx + lstmLen));
|
|
827
|
+
idx += lstmLen;
|
|
828
|
+
for (const layer of this.denseLayers) {
|
|
829
|
+
for (const n of layer.neurons) {
|
|
830
|
+
for (let j = 0; j < n.weights.length; j++) n.weights[j] = weights[idx++];
|
|
831
|
+
n.bias = weights[idx++];
|
|
832
|
+
}
|
|
833
|
+
}
|
|
834
|
+
}
|
|
476
835
|
};
|
|
477
836
|
|
|
478
837
|
// src/MatMul.ts
|
|
@@ -480,6 +839,9 @@ function matMul(A, B) {
|
|
|
480
839
|
const rows = A.length;
|
|
481
840
|
const inner = B.length;
|
|
482
841
|
const cols = B[0].length;
|
|
842
|
+
if (A[0].length !== B.length) {
|
|
843
|
+
throw new Error(`Incompatible dimensions for matrix multiplication: A cols (${A[0].length}) !== B rows (${B.length})`);
|
|
844
|
+
}
|
|
483
845
|
const C = Array.from({ length: rows }, () => new Array(cols).fill(0));
|
|
484
846
|
for (let i = 0; i < rows; i++)
|
|
485
847
|
for (let k = 0; k < inner; k++) {
|
|
@@ -530,6 +892,17 @@ var WeightMatrix = class {
|
|
|
530
892
|
this.W[i][j] = this.opts[i][j].step(this.W[i][j], g, lr);
|
|
531
893
|
}
|
|
532
894
|
}
|
|
895
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
896
|
+
getWeights() {
|
|
897
|
+
const w = [];
|
|
898
|
+
for (const row of this.W) w.push(...row);
|
|
899
|
+
return w;
|
|
900
|
+
}
|
|
901
|
+
setWeights(weights) {
|
|
902
|
+
let idx = 0;
|
|
903
|
+
for (let i = 0; i < this.W.length; i++)
|
|
904
|
+
for (let j = 0; j < this.W[i].length; j++) this.W[i][j] = weights[idx++];
|
|
905
|
+
}
|
|
533
906
|
};
|
|
534
907
|
var EmbeddingMatrix = class {
|
|
535
908
|
constructor(vocabSize, d_model) {
|
|
@@ -546,15 +919,29 @@ var EmbeddingMatrix = class {
|
|
|
546
919
|
for (let m = 0; m < this.W[idx].length; m++)
|
|
547
920
|
this.W[idx][m] += lr * grad[m];
|
|
548
921
|
}
|
|
922
|
+
// ── Serializable interface ─────────────────────────────────────────────────
|
|
923
|
+
// Flattened order: row 0, row 1, ... row (vocabSize-1)
|
|
924
|
+
getWeights() {
|
|
925
|
+
const w = [];
|
|
926
|
+
for (const row of this.W) w.push(...row);
|
|
927
|
+
return w;
|
|
928
|
+
}
|
|
929
|
+
setWeights(weights) {
|
|
930
|
+
let idx = 0;
|
|
931
|
+
for (let i = 0; i < this.W.length; i++)
|
|
932
|
+
for (let j = 0; j < this.W[i].length; j++)
|
|
933
|
+
this.W[i][j] = weights[idx++];
|
|
934
|
+
}
|
|
549
935
|
};
|
|
550
936
|
|
|
551
937
|
// src/AttentionHead.ts
|
|
552
938
|
var AttentionHead = class {
|
|
553
|
-
constructor(d_model, d_k, d_v) {
|
|
939
|
+
constructor(d_model, d_k, d_v, causal = false) {
|
|
554
940
|
// d_v × d_model
|
|
555
941
|
this.cache = null;
|
|
556
942
|
this.d_k = d_k;
|
|
557
943
|
this.d_v = d_v;
|
|
944
|
+
this.causal = causal;
|
|
558
945
|
this.Wq = new WeightMatrix(d_k, d_model);
|
|
559
946
|
this.Wk = new WeightMatrix(d_k, d_model);
|
|
560
947
|
this.Wv = new WeightMatrix(d_v, d_model);
|
|
@@ -575,10 +962,10 @@ var AttentionHead = class {
|
|
|
575
962
|
);
|
|
576
963
|
const scores = Array.from(
|
|
577
964
|
{ length: seqLen },
|
|
578
|
-
(_, i) => Array.from(
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
)
|
|
965
|
+
(_, i) => Array.from({ length: seqLen }, (_2, j) => {
|
|
966
|
+
if (this.causal && j > i) return -Infinity;
|
|
967
|
+
return Q[i].reduce((s, q, k) => s + q * K[j][k], 0) * scale;
|
|
968
|
+
})
|
|
582
969
|
);
|
|
583
970
|
const attn = scores.map((row) => softmax(row));
|
|
584
971
|
const out = Array.from(
|
|
@@ -674,21 +1061,40 @@ var AttentionHead = class {
|
|
|
674
1061
|
getAttentionWeights() {
|
|
675
1062
|
return this.cache ? this.cache.attn : null;
|
|
676
1063
|
}
|
|
1064
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
1065
|
+
// Order: Wq, Wk, Wv.
|
|
1066
|
+
getWeights() {
|
|
1067
|
+
const w = [];
|
|
1068
|
+
for (const row of this.Wq.W) w.push(...row);
|
|
1069
|
+
for (const row of this.Wk.W) w.push(...row);
|
|
1070
|
+
for (const row of this.Wv.W) w.push(...row);
|
|
1071
|
+
return w;
|
|
1072
|
+
}
|
|
1073
|
+
setWeights(weights) {
|
|
1074
|
+
let idx = 0;
|
|
1075
|
+
for (let i = 0; i < this.Wq.W.length; i++)
|
|
1076
|
+
for (let j = 0; j < this.Wq.W[i].length; j++) this.Wq.W[i][j] = weights[idx++];
|
|
1077
|
+
for (let i = 0; i < this.Wk.W.length; i++)
|
|
1078
|
+
for (let j = 0; j < this.Wk.W[i].length; j++) this.Wk.W[i][j] = weights[idx++];
|
|
1079
|
+
for (let i = 0; i < this.Wv.W.length; i++)
|
|
1080
|
+
for (let j = 0; j < this.Wv.W[i].length; j++) this.Wv.W[i][j] = weights[idx++];
|
|
1081
|
+
}
|
|
677
1082
|
};
|
|
678
1083
|
|
|
679
1084
|
// src/MultiHeadAttention.ts
|
|
680
1085
|
var MultiHeadAttention = class {
|
|
681
1086
|
// seqLen × (nHeads * d_k)
|
|
682
|
-
constructor(d_model, nHeads) {
|
|
1087
|
+
constructor(d_model, nHeads, causal = false) {
|
|
683
1088
|
// d_model × (nHeads * d_k)
|
|
684
1089
|
// Cached for backward
|
|
685
1090
|
this._concat = null;
|
|
686
1091
|
this.nHeads = nHeads;
|
|
687
1092
|
this.d_model = d_model;
|
|
688
1093
|
this.d_k = Math.floor(d_model / nHeads);
|
|
1094
|
+
this.causal = causal;
|
|
689
1095
|
this.heads = Array.from(
|
|
690
1096
|
{ length: nHeads },
|
|
691
|
-
() => new AttentionHead(d_model, this.d_k, this.d_k)
|
|
1097
|
+
() => new AttentionHead(d_model, this.d_k, this.d_k, causal)
|
|
692
1098
|
);
|
|
693
1099
|
this.Wo = new WeightMatrix(d_model, nHeads * this.d_k);
|
|
694
1100
|
}
|
|
@@ -747,6 +1153,31 @@ var MultiHeadAttention = class {
|
|
|
747
1153
|
getAttentionWeights() {
|
|
748
1154
|
return this.heads.map((h) => h.getAttentionWeights());
|
|
749
1155
|
}
|
|
1156
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
1157
|
+
// Order: head0 (Wq, Wk, Wv), head1, ..., headN, then Wo.
|
|
1158
|
+
getWeights() {
|
|
1159
|
+
const w = [];
|
|
1160
|
+
for (const head of this.heads) {
|
|
1161
|
+
for (const row of head.Wq.W) w.push(...row);
|
|
1162
|
+
for (const row of head.Wk.W) w.push(...row);
|
|
1163
|
+
for (const row of head.Wv.W) w.push(...row);
|
|
1164
|
+
}
|
|
1165
|
+
for (const row of this.Wo.W) w.push(...row);
|
|
1166
|
+
return w;
|
|
1167
|
+
}
|
|
1168
|
+
setWeights(weights) {
|
|
1169
|
+
let idx = 0;
|
|
1170
|
+
for (const head of this.heads) {
|
|
1171
|
+
for (let i = 0; i < head.Wq.W.length; i++)
|
|
1172
|
+
for (let j = 0; j < head.Wq.W[i].length; j++) head.Wq.W[i][j] = weights[idx++];
|
|
1173
|
+
for (let i = 0; i < head.Wk.W.length; i++)
|
|
1174
|
+
for (let j = 0; j < head.Wk.W[i].length; j++) head.Wk.W[i][j] = weights[idx++];
|
|
1175
|
+
for (let i = 0; i < head.Wv.W.length; i++)
|
|
1176
|
+
for (let j = 0; j < head.Wv.W[i].length; j++) head.Wv.W[i][j] = weights[idx++];
|
|
1177
|
+
}
|
|
1178
|
+
for (let i = 0; i < this.Wo.W.length; i++)
|
|
1179
|
+
for (let j = 0; j < this.Wo.W[i].length; j++) this.Wo.W[i][j] = weights[idx++];
|
|
1180
|
+
}
|
|
750
1181
|
};
|
|
751
1182
|
|
|
752
1183
|
// src/LayerNorm.ts
|
|
@@ -798,11 +1229,21 @@ var LayerNorm = class {
|
|
|
798
1229
|
const mDxn = D.reduce((s, d, i) => s + d * x_norm[i], 0) / N;
|
|
799
1230
|
return D.map((d, i) => (d - mD - x_norm[i] * mDxn) / std);
|
|
800
1231
|
}
|
|
1232
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
1233
|
+
// Order: gamma, beta.
|
|
1234
|
+
getWeights() {
|
|
1235
|
+
return [...this.gamma, ...this.beta];
|
|
1236
|
+
}
|
|
1237
|
+
setWeights(weights) {
|
|
1238
|
+
const dim = this.gamma.length;
|
|
1239
|
+
for (let i = 0; i < dim; i++) this.gamma[i] = weights[i];
|
|
1240
|
+
for (let i = 0; i < dim; i++) this.beta[i] = weights[dim + i];
|
|
1241
|
+
}
|
|
801
1242
|
};
|
|
802
1243
|
|
|
803
1244
|
// src/TransformerBlock.ts
|
|
804
1245
|
var TransformerBlock = class {
|
|
805
|
-
constructor({ d_model, nHeads, d_ff }) {
|
|
1246
|
+
constructor({ d_model, nHeads, d_ff, causal = false }) {
|
|
806
1247
|
// Forward caches (needed for backprop)
|
|
807
1248
|
this._X = null;
|
|
808
1249
|
this._attnOut = null;
|
|
@@ -814,7 +1255,7 @@ var TransformerBlock = class {
|
|
|
814
1255
|
this._ff2Out = null;
|
|
815
1256
|
this.d_model = d_model;
|
|
816
1257
|
this.d_ff = d_ff;
|
|
817
|
-
this.attn = new MultiHeadAttention(d_model, nHeads);
|
|
1258
|
+
this.attn = new MultiHeadAttention(d_model, nHeads, causal);
|
|
818
1259
|
this.norm1 = new LayerNorm(d_model);
|
|
819
1260
|
this.norm2 = new LayerNorm(d_model);
|
|
820
1261
|
this.ff1 = new WeightMatrix(d_ff, d_model);
|
|
@@ -927,6 +1368,35 @@ var TransformerBlock = class {
|
|
|
927
1368
|
getAttentionWeights() {
|
|
928
1369
|
return this.attn.getAttentionWeights();
|
|
929
1370
|
}
|
|
1371
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
1372
|
+
// Order: attn (MHA), norm1 (gamma, beta), ff1, b1, ff2, b2, norm2 (gamma, beta).
|
|
1373
|
+
getWeights() {
|
|
1374
|
+
const w = [];
|
|
1375
|
+
w.push(...this.attn.getWeights());
|
|
1376
|
+
w.push(...this.norm1.gamma, ...this.norm1.beta);
|
|
1377
|
+
for (const row of this.ff1.W) w.push(...row);
|
|
1378
|
+
w.push(...this.b1);
|
|
1379
|
+
for (const row of this.ff2.W) w.push(...row);
|
|
1380
|
+
w.push(...this.b2);
|
|
1381
|
+
w.push(...this.norm2.gamma, ...this.norm2.beta);
|
|
1382
|
+
return w;
|
|
1383
|
+
}
|
|
1384
|
+
setWeights(weights) {
|
|
1385
|
+
let idx = 0;
|
|
1386
|
+
const attnLen = this.attn.getWeights().length;
|
|
1387
|
+
this.attn.setWeights(weights.slice(idx, idx + attnLen));
|
|
1388
|
+
idx += attnLen;
|
|
1389
|
+
for (let i = 0; i < this.norm1.gamma.length; i++) this.norm1.gamma[i] = weights[idx++];
|
|
1390
|
+
for (let i = 0; i < this.norm1.beta.length; i++) this.norm1.beta[i] = weights[idx++];
|
|
1391
|
+
for (let i = 0; i < this.ff1.W.length; i++)
|
|
1392
|
+
for (let j = 0; j < this.ff1.W[i].length; j++) this.ff1.W[i][j] = weights[idx++];
|
|
1393
|
+
for (let i = 0; i < this.b1.length; i++) this.b1[i] = weights[idx++];
|
|
1394
|
+
for (let i = 0; i < this.ff2.W.length; i++)
|
|
1395
|
+
for (let j = 0; j < this.ff2.W[i].length; j++) this.ff2.W[i][j] = weights[idx++];
|
|
1396
|
+
for (let i = 0; i < this.b2.length; i++) this.b2[i] = weights[idx++];
|
|
1397
|
+
for (let i = 0; i < this.norm2.gamma.length; i++) this.norm2.gamma[i] = weights[idx++];
|
|
1398
|
+
for (let i = 0; i < this.norm2.beta.length; i++) this.norm2.beta[i] = weights[idx++];
|
|
1399
|
+
}
|
|
930
1400
|
};
|
|
931
1401
|
|
|
932
1402
|
// src/NetworkTransformer.ts
|
|
@@ -1025,6 +1495,32 @@ var NetworkTransformer = class {
|
|
|
1025
1495
|
getAttentionWeights() {
|
|
1026
1496
|
return this.blocks.map((b) => b.getAttentionWeights());
|
|
1027
1497
|
}
|
|
1498
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
1499
|
+
// Order: tokenEmb, posEmb, block0, block1, ..., blockN, outputProj, outputBias.
|
|
1500
|
+
getWeights() {
|
|
1501
|
+
const w = [];
|
|
1502
|
+
for (const row of this.tokenEmb.W) w.push(...row);
|
|
1503
|
+
for (const row of this.posEmb.W) w.push(...row);
|
|
1504
|
+
for (const block of this.blocks) w.push(...block.getWeights());
|
|
1505
|
+
for (const row of this.outputProj.W) w.push(...row);
|
|
1506
|
+
w.push(...this.outputBias);
|
|
1507
|
+
return w;
|
|
1508
|
+
}
|
|
1509
|
+
setWeights(weights) {
|
|
1510
|
+
let idx = 0;
|
|
1511
|
+
for (let i = 0; i < this.tokenEmb.W.length; i++)
|
|
1512
|
+
for (let j = 0; j < this.tokenEmb.W[i].length; j++) this.tokenEmb.W[i][j] = weights[idx++];
|
|
1513
|
+
for (let i = 0; i < this.posEmb.W.length; i++)
|
|
1514
|
+
for (let j = 0; j < this.posEmb.W[i].length; j++) this.posEmb.W[i][j] = weights[idx++];
|
|
1515
|
+
for (const block of this.blocks) {
|
|
1516
|
+
const blockLen = block.getWeights().length;
|
|
1517
|
+
block.setWeights(weights.slice(idx, idx + blockLen));
|
|
1518
|
+
idx += blockLen;
|
|
1519
|
+
}
|
|
1520
|
+
for (let i = 0; i < this.outputProj.W.length; i++)
|
|
1521
|
+
for (let j = 0; j < this.outputProj.W[i].length; j++) this.outputProj.W[i][j] = weights[idx++];
|
|
1522
|
+
for (let i = 0; i < this.outputBias.length; i++) this.outputBias[i] = weights[idx++];
|
|
1523
|
+
}
|
|
1028
1524
|
// ── Internal ──────────────────────────────────────────────────────────────
|
|
1029
1525
|
// Shared embedding + block forward pass.
|
|
1030
1526
|
_forward(tokens) {
|
|
@@ -1044,21 +1540,25 @@ var NetworkTransformerRL = class {
|
|
|
1044
1540
|
constructor(seqLen, inputDim, options = {}) {
|
|
1045
1541
|
// Forward caches para backprop
|
|
1046
1542
|
this._projected = null;
|
|
1543
|
+
// For max pooling backward: argmax per dimension across all positions
|
|
1544
|
+
this._argmax = null;
|
|
1047
1545
|
const {
|
|
1048
1546
|
d_model = 32,
|
|
1049
1547
|
nHeads = 2,
|
|
1050
1548
|
d_ff = 64,
|
|
1051
1549
|
nBlocks = 2,
|
|
1052
|
-
nActions = 2
|
|
1550
|
+
nActions = 2,
|
|
1551
|
+
pooling = "weighted"
|
|
1053
1552
|
} = options;
|
|
1054
1553
|
this.seqLen = seqLen;
|
|
1055
1554
|
this.inputDim = inputDim;
|
|
1056
1555
|
this.d_model = d_model;
|
|
1057
1556
|
this.nActions = nActions;
|
|
1557
|
+
this._pooling = pooling;
|
|
1058
1558
|
this.inputProj = new WeightMatrix(d_model, inputDim);
|
|
1059
1559
|
this.blocks = Array.from(
|
|
1060
1560
|
{ length: nBlocks },
|
|
1061
|
-
() => new TransformerBlock({ d_model, nHeads, d_ff })
|
|
1561
|
+
() => new TransformerBlock({ d_model, nHeads, d_ff, causal: true })
|
|
1062
1562
|
);
|
|
1063
1563
|
this.outputProj = new WeightMatrix(nActions, d_model);
|
|
1064
1564
|
this.outputBias = new Array(nActions).fill(0);
|
|
@@ -1107,11 +1607,7 @@ var NetworkTransformerRL = class {
|
|
|
1107
1607
|
this.outputProj.update(dWout, lr);
|
|
1108
1608
|
for (let c = 0; c < this.nActions; c++)
|
|
1109
1609
|
this.outputBias[c] = this.outBiasOpts[c].step(this.outputBias[c], dBout[c], lr);
|
|
1110
|
-
let dH =
|
|
1111
|
-
{ length: this.seqLen },
|
|
1112
|
-
(_, i) => dPooled.map((v) => v / this.seqLen)
|
|
1113
|
-
// Gradiente dividido entre posiciones
|
|
1114
|
-
);
|
|
1610
|
+
let dH = this._distributePoolGradient(dPooled);
|
|
1115
1611
|
for (let b = this.blocks.length - 1; b >= 0; b--)
|
|
1116
1612
|
dH = this.blocks[b].backward(dH, lr);
|
|
1117
1613
|
for (let i = 0; i < this.seqLen; i++) {
|
|
@@ -1130,6 +1626,85 @@ var NetworkTransformerRL = class {
|
|
|
1130
1626
|
getAttentionWeights() {
|
|
1131
1627
|
return this.blocks.map((b) => b.getAttentionWeights());
|
|
1132
1628
|
}
|
|
1629
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
1630
|
+
// Order: inputProj, block0, block1, ..., blockN, outputProj, outputBias.
|
|
1631
|
+
getWeightsFlat() {
|
|
1632
|
+
const w = [];
|
|
1633
|
+
for (const row of this.inputProj.W) w.push(...row);
|
|
1634
|
+
for (const block of this.blocks) w.push(...block.getWeights());
|
|
1635
|
+
for (const row of this.outputProj.W) w.push(...row);
|
|
1636
|
+
w.push(...this.outputBias);
|
|
1637
|
+
return w;
|
|
1638
|
+
}
|
|
1639
|
+
setWeightsFlat(weights) {
|
|
1640
|
+
let idx = 0;
|
|
1641
|
+
for (let i = 0; i < this.inputProj.W.length; i++)
|
|
1642
|
+
for (let j = 0; j < this.inputProj.W[i].length; j++) this.inputProj.W[i][j] = weights[idx++];
|
|
1643
|
+
for (const block of this.blocks) {
|
|
1644
|
+
const blockLen = block.getWeights().length;
|
|
1645
|
+
block.setWeights(weights.slice(idx, idx + blockLen));
|
|
1646
|
+
idx += blockLen;
|
|
1647
|
+
}
|
|
1648
|
+
for (let i = 0; i < this.outputProj.W.length; i++)
|
|
1649
|
+
for (let j = 0; j < this.outputProj.W[i].length; j++) this.outputProj.W[i][j] = weights[idx++];
|
|
1650
|
+
for (let i = 0; i < this.outputBias.length; i++) this.outputBias[i] = weights[idx++];
|
|
1651
|
+
}
|
|
1652
|
+
getWeightsStructured() {
|
|
1653
|
+
return {
|
|
1654
|
+
inputProj: this.inputProj.W.map((r) => [...r]),
|
|
1655
|
+
blocks: this.blocks.map((b) => ({
|
|
1656
|
+
attn: {
|
|
1657
|
+
heads: b.attn.heads.map((h) => ({
|
|
1658
|
+
Wq: h.Wq.W.map((r) => [...r]),
|
|
1659
|
+
Wk: h.Wk.W.map((r) => [...r]),
|
|
1660
|
+
Wv: h.Wv.W.map((r) => [...r])
|
|
1661
|
+
})),
|
|
1662
|
+
Wo: b.attn.Wo.W.map((r) => [...r])
|
|
1663
|
+
},
|
|
1664
|
+
norm1: { gamma: [...b.norm1.gamma], beta: [...b.norm1.beta] },
|
|
1665
|
+
norm2: { gamma: [...b.norm2.gamma], beta: [...b.norm2.beta] },
|
|
1666
|
+
ff1: b.ff1.W.map((r) => [...r]),
|
|
1667
|
+
ff2: b.ff2.W.map((r) => [...r]),
|
|
1668
|
+
b1: [...b.b1],
|
|
1669
|
+
b2: [...b.b2]
|
|
1670
|
+
})),
|
|
1671
|
+
outputProj: this.outputProj.W.map((r) => [...r]),
|
|
1672
|
+
outputBias: [...this.outputBias]
|
|
1673
|
+
};
|
|
1674
|
+
}
|
|
1675
|
+
setWeightsStructured(data) {
|
|
1676
|
+
data.inputProj.forEach((row, i) => {
|
|
1677
|
+
this.inputProj.W[i] = [...row];
|
|
1678
|
+
});
|
|
1679
|
+
data.blocks.forEach((bd, b) => {
|
|
1680
|
+
const blk = this.blocks[b];
|
|
1681
|
+
bd.attn.heads.forEach((hd, h) => {
|
|
1682
|
+
blk.attn.heads[h].Wq.W = hd.Wq.map((r) => [...r]);
|
|
1683
|
+
blk.attn.heads[h].Wk.W = hd.Wk.map((r) => [...r]);
|
|
1684
|
+
blk.attn.heads[h].Wv.W = hd.Wv.map((r) => [...r]);
|
|
1685
|
+
});
|
|
1686
|
+
blk.attn.Wo.W = bd.attn.Wo.map((r) => [...r]);
|
|
1687
|
+
blk.norm1.gamma = [...bd.norm1.gamma];
|
|
1688
|
+
blk.norm1.beta = [...bd.norm1.beta];
|
|
1689
|
+
blk.norm2.gamma = [...bd.norm2.gamma];
|
|
1690
|
+
blk.norm2.beta = [...bd.norm2.beta];
|
|
1691
|
+
blk.ff1.W = bd.ff1.map((r) => [...r]);
|
|
1692
|
+
blk.ff2.W = bd.ff2.map((r) => [...r]);
|
|
1693
|
+
blk.b1 = [...bd.b1];
|
|
1694
|
+
blk.b2 = [...bd.b2];
|
|
1695
|
+
});
|
|
1696
|
+
this.outputProj.W = data.outputProj.map((r) => [...r]);
|
|
1697
|
+
this.outputBias = [...data.outputBias];
|
|
1698
|
+
}
|
|
1699
|
+
// ── Serializable interface (flat array) ────────────────────────────────────
|
|
1700
|
+
// These satisfy the Serializable interface from ModelSaver, which requires
|
|
1701
|
+
// getWeights(): number[] and setWeights(weights: number[]): void.
|
|
1702
|
+
getWeights() {
|
|
1703
|
+
return this.getWeightsFlat();
|
|
1704
|
+
}
|
|
1705
|
+
setWeights(weights) {
|
|
1706
|
+
this.setWeightsFlat(weights);
|
|
1707
|
+
}
|
|
1133
1708
|
// ── Internal ────────────────────────────────────────────────────────────────
|
|
1134
1709
|
_forward(sequence) {
|
|
1135
1710
|
let h = sequence.map(
|
|
@@ -1143,6 +1718,44 @@ var NetworkTransformerRL = class {
|
|
|
1143
1718
|
return h;
|
|
1144
1719
|
}
|
|
1145
1720
|
_pool(h) {
|
|
1721
|
+
switch (this._pooling) {
|
|
1722
|
+
case "avg":
|
|
1723
|
+
return this._poolAvg(h);
|
|
1724
|
+
case "max":
|
|
1725
|
+
return this._poolMax(h);
|
|
1726
|
+
case "last":
|
|
1727
|
+
return this._poolLast(h);
|
|
1728
|
+
case "weighted":
|
|
1729
|
+
default:
|
|
1730
|
+
return this._poolWeighted(h);
|
|
1731
|
+
}
|
|
1732
|
+
}
|
|
1733
|
+
_poolAvg(h) {
|
|
1734
|
+
const n = h.length;
|
|
1735
|
+
return Array.from({ length: this.d_model }, (_, m) => {
|
|
1736
|
+
let sum = 0;
|
|
1737
|
+
for (let i = 0; i < n; i++)
|
|
1738
|
+
sum += h[i][m];
|
|
1739
|
+
return sum / n;
|
|
1740
|
+
});
|
|
1741
|
+
}
|
|
1742
|
+
_poolMax(h) {
|
|
1743
|
+
this._argmax = new Array(this.d_model).fill(0);
|
|
1744
|
+
return Array.from({ length: this.d_model }, (_, m) => {
|
|
1745
|
+
let maxVal = -Infinity;
|
|
1746
|
+
for (let i = 0; i < h.length; i++) {
|
|
1747
|
+
if (h[i][m] > maxVal) {
|
|
1748
|
+
maxVal = h[i][m];
|
|
1749
|
+
this._argmax[m] = i;
|
|
1750
|
+
}
|
|
1751
|
+
}
|
|
1752
|
+
return maxVal;
|
|
1753
|
+
});
|
|
1754
|
+
}
|
|
1755
|
+
_poolLast(h) {
|
|
1756
|
+
return [...h[h.length - 1]];
|
|
1757
|
+
}
|
|
1758
|
+
_poolWeighted(h) {
|
|
1146
1759
|
const weights = Array.from(
|
|
1147
1760
|
{ length: this.seqLen },
|
|
1148
1761
|
(_, i) => i === this.seqLen - 1 ? 2 : 1
|
|
@@ -1155,6 +1768,55 @@ var NetworkTransformerRL = class {
|
|
|
1155
1768
|
return sum / totalWeight;
|
|
1156
1769
|
});
|
|
1157
1770
|
}
|
|
1771
|
+
/** Returns the current pooling type for inspection. */
|
|
1772
|
+
getPoolingType() {
|
|
1773
|
+
return this._pooling;
|
|
1774
|
+
}
|
|
1775
|
+
// ── Helper: distribute pooled gradient back to each position ────────────────
|
|
1776
|
+
// Must match the same distribution as _pool() used during forward.
|
|
1777
|
+
_distributePoolGradient(dPooled) {
|
|
1778
|
+
switch (this._pooling) {
|
|
1779
|
+
case "avg": {
|
|
1780
|
+
const n = this.seqLen;
|
|
1781
|
+
return Array.from(
|
|
1782
|
+
{ length: n },
|
|
1783
|
+
() => dPooled.map((v) => v / n)
|
|
1784
|
+
);
|
|
1785
|
+
}
|
|
1786
|
+
case "max": {
|
|
1787
|
+
if (!this._argmax) {
|
|
1788
|
+
const n = this.seqLen;
|
|
1789
|
+
return Array.from(
|
|
1790
|
+
{ length: n },
|
|
1791
|
+
() => dPooled.map((v) => v / n)
|
|
1792
|
+
);
|
|
1793
|
+
}
|
|
1794
|
+
const argmax = this._argmax;
|
|
1795
|
+
return Array.from(
|
|
1796
|
+
{ length: this.seqLen },
|
|
1797
|
+
(_, i) => dPooled.map((v, m) => i === argmax[m] ? v : 0)
|
|
1798
|
+
);
|
|
1799
|
+
}
|
|
1800
|
+
case "last": {
|
|
1801
|
+
return Array.from(
|
|
1802
|
+
{ length: this.seqLen },
|
|
1803
|
+
(_, i) => i === this.seqLen - 1 ? [...dPooled] : new Array(this.d_model).fill(0)
|
|
1804
|
+
);
|
|
1805
|
+
}
|
|
1806
|
+
case "weighted":
|
|
1807
|
+
default: {
|
|
1808
|
+
const weights = Array.from(
|
|
1809
|
+
{ length: this.seqLen },
|
|
1810
|
+
(_, i) => i === this.seqLen - 1 ? 2 : 1
|
|
1811
|
+
);
|
|
1812
|
+
const totalWeight = weights.reduce((a, b) => a + b, 0);
|
|
1813
|
+
return Array.from(
|
|
1814
|
+
{ length: this.seqLen },
|
|
1815
|
+
(_, i) => dPooled.map((v) => v * weights[i] / totalWeight)
|
|
1816
|
+
);
|
|
1817
|
+
}
|
|
1818
|
+
}
|
|
1819
|
+
}
|
|
1158
1820
|
};
|
|
1159
1821
|
|
|
1160
1822
|
// src/losses.ts
|
|
@@ -1179,13 +1841,801 @@ function crossEntropyDeltaRaw(predicted, actual) {
|
|
|
1179
1841
|
const p = Math.max(eps, Math.min(1 - eps, predicted));
|
|
1180
1842
|
return actual / p - (1 - actual) / (1 - p);
|
|
1181
1843
|
}
|
|
1844
|
+
|
|
1845
|
+
// src/GRU.ts
|
|
1846
|
+
/** Logistic sigmoid 1 / (1 + e^(-x)); saturates to 0 / 1 for large |x|. */
function sigmoid4(x) {
  return 1 / (1 + Math.exp(-x));
}
/**
 * Hyperbolic tangent for the GRU candidate gate.
 *
 * Delegates to Math.tanh. The closed form (e^(2x) - 1) / (e^(2x) + 1)
 * overflows to Infinity / Infinity = NaN once 2x exceeds ~709.78
 * (x > ~354.9). It also loses precision near 0, where e^(2x) - 1 cancels.
 */
function tanhFn(x) {
  return Math.tanh(x);
}
/**
 * One GRU gate: a dense layer over the concatenation [input, hidden].
 * W is hSize × (inputSize + hSize), initialised uniformly in ±sqrt(2 / n).
 * b starts at initBias.
 */
var Gate2 = class {
  constructor(inputSize, hSize, initBias = 0) {
    const n = inputSize + hSize;
    const limit = Math.sqrt(2 / n);
    this.W = Array.from(
      { length: hSize },
      () => Array.from({ length: n }, () => (Math.random() * 2 - 1) * limit)
    );
    this.b = new Array(hSize).fill(initBias);
  }
  /** Returns W · combined + b, one pre-activation per hidden unit. */
  linear(combined) {
    return this.W.map(
      (row, i) => row.reduce((s, w, j) => s + w * combined[j], this.b[i])
    );
  }
};
/**
 * Single GRU layer trained by backpropagation through time.
 *
 * Per step (see predict()):
 *   r  = sigmoid(Wr·[x, h] + br)
 *   z  = sigmoid(Wz·[x, h] + bz)
 *   n  = tanh(Wn·[x, r ⊙ h] + bn)
 *   h' = (1 - z) ⊙ n + z ⊙ h
 *
 * predict() records each step's intermediates in _traj. backprop() consumes
 * one upstream gradient per recorded step, updates every weight through its
 * own optimizer with learning rate lr / T, and clears the trajectory.
 */
var GRULayer = class {
  /**
   * @param {number} inputSize - length of each input vector (must be > 0)
   * @param {number} hiddenSize - number of hidden units (must be > 0)
   * @param {() => {step(w: number, grad: number, lr: number): number}} [optimizerFactory]
   *   creates one optimizer per weight and per bias; defaults to `new SGD()`
   * @throws {Error} if either size is not positive
   */
  constructor(inputSize, hiddenSize, optimizerFactory = () => new SGD()) {
    this._traj = [];
    if (inputSize <= 0 || hiddenSize <= 0) {
      throw new Error(`GRULayer: inputSize and hiddenSize must be positive, got ${inputSize} and ${hiddenSize}`);
    }
    this.inputSize = inputSize;
    this.hSize = hiddenSize;
    this.h = new Array(hiddenSize).fill(0);
    this.resetGate = new Gate2(inputSize, hiddenSize);
    this.updateGate = new Gate2(inputSize, hiddenSize);
    this.newGate = new Gate2(inputSize, hiddenSize);
    const combSize = inputSize + hiddenSize;
    const weightOptimizers = () => Array.from(
      { length: hiddenSize },
      () => Array.from({ length: combSize }, () => optimizerFactory())
    );
    const biasOptimizers = () => Array.from({ length: hiddenSize }, () => optimizerFactory());
    this._optimizers = {
      resetW: weightOptimizers(),
      resetB: biasOptimizers(),
      updateW: weightOptimizers(),
      updateB: biasOptimizers(),
      newW: weightOptimizers(),
      newB: biasOptimizers()
    };
  }
  /** Zeroes the hidden state and discards the recorded trajectory. */
  reset() {
    this.h = new Array(this.hSize).fill(0);
    this._traj = [];
  }
  /**
   * Advances the hidden state by one step and records the intermediates for backprop().
   * @param {number[]} inputs - exactly inputSize values
   * @returns {number[]} the new hidden state (also stored in this.h)
   * @throws {Error} if inputs is not an array of length inputSize
   */
  predict(inputs) {
    if (!Array.isArray(inputs) || inputs.length !== this.inputSize) {
      throw new Error(`GRULayer.predict: expected array of length ${this.inputSize}, got ${inputs?.length}`);
    }
    const combined = [...inputs, ...this.h];
    const h_prev = [...this.h];
    const r_pre = this.resetGate.linear(combined);
    const z_pre = this.updateGate.linear(combined);
    const r_a = r_pre.map(sigmoid4);
    const z_a = z_pre.map(sigmoid4);
    // The candidate gate sees the reset-scaled previous state.
    const combined_r = [...inputs, ...r_a.map((r, i) => r * h_prev[i])];
    const n_pre = this.newGate.linear(combined_r);
    const n_a = n_pre.map(tanhFn);
    const h = n_a.map((n, i) => (1 - z_a[i]) * n + z_a[i] * h_prev[i]);
    this._traj.push({ combined, h_prev, r: r_pre, r_a, z: z_pre, z_a, combined_r, n_pre, n_a, h });
    this.h = h;
    return h;
  }
  /**
   * Backpropagation through time over the recorded trajectory.
   *
   * Gradients are accumulated with the current (pre-update) weights and then
   * applied with learning rate lr / T; the trajectory is then cleared.
   * If nothing is recorded, or dh_seq.length !== T, the call is a no-op and
   * the trajectory is kept.
   * @param {number[][]} dh_seq - upstream gradient w.r.t. each step's hidden output
   * @param {number} lr
   */
  backprop(dh_seq, lr) {
    const T = this._traj.length;
    if (T === 0 || dh_seq.length !== T) return;
    const hSize = this.hSize;
    const combSize = this.inputSize + hSize;
    const dWr = Array.from({ length: hSize }, () => new Array(combSize).fill(0));
    const dWz = Array.from({ length: hSize }, () => new Array(combSize).fill(0));
    const dWn = Array.from({ length: hSize }, () => new Array(combSize).fill(0));
    const dbr = new Array(hSize).fill(0);
    const dbz = new Array(hSize).fill(0);
    const dbn = new Array(hSize).fill(0);
    let dh_next = new Array(hSize).fill(0);
    for (let t = T - 1; t >= 0; t--) {
      const s = this._traj[t];
      const dh = dh_seq[t].map((d, i) => d + dh_next[i]);
      const dz_a = dh.map((d, i) => (s.h_prev[i] - s.n_a[i]) * d);
      const dn_a = dh.map((d, i) => (1 - s.z_a[i]) * d);
      const dn_pre = dn_a.map((d, i) => d * (1 - s.n_a[i] ** 2));
      const dz_pre = dz_a.map((d, i) => d * s.z_a[i] * (1 - s.z_a[i]));
      // Gradient w.r.t. (r ⊙ h_prev), the hidden half of the candidate gate's input.
      const dr_hprev = Array.from(
        { length: hSize },
        (_, i) => this.newGate.W.reduce((sum, row, k) => sum + dn_pre[k] * row[this.inputSize + i], 0)
      );
      const dr_a = dr_hprev.map((d, i) => d * s.h_prev[i]);
      const dr_pre = dr_a.map((d, i) => d * s.r_a[i] * (1 - s.r_a[i]));
      for (let k = 0; k < hSize; k++) {
        for (let j = 0; j < combSize; j++) {
          dWr[k][j] += dr_pre[k] * s.combined[j];
          dWz[k][j] += dz_pre[k] * s.combined[j];
          dWn[k][j] += dn_pre[k] * s.combined_r[j];
        }
        dbr[k] += dr_pre[k];
        dbz[k] += dz_pre[k];
        dbn[k] += dn_pre[k];
      }
      // dL/dh_prev: through the r and z gates' hidden inputs, the r ⊙ h_prev
      // path, and the direct z ⊙ h_prev carry.
      dh_next = new Array(hSize).fill(0);
      for (let k = 0; k < hSize; k++) {
        for (let j = this.inputSize; j < combSize; j++) {
          dh_next[j - this.inputSize] += dr_pre[k] * this.resetGate.W[k][j] + dz_pre[k] * this.updateGate.W[k][j];
        }
        dh_next[k] += dr_hprev[k] * s.r_a[k];
        dh_next[k] += dh[k] * s.z_a[k];
      }
    }
    const scale = lr / T;
    for (let k = 0; k < hSize; k++) {
      for (let j = 0; j < combSize; j++) {
        this.resetGate.W[k][j] = this._optimizers.resetW[k][j].step(this.resetGate.W[k][j], dWr[k][j], scale);
        this.updateGate.W[k][j] = this._optimizers.updateW[k][j].step(this.updateGate.W[k][j], dWz[k][j], scale);
        this.newGate.W[k][j] = this._optimizers.newW[k][j].step(this.newGate.W[k][j], dWn[k][j], scale);
      }
      this.resetGate.b[k] = this._optimizers.resetB[k].step(this.resetGate.b[k], dbr[k], scale);
      this.updateGate.b[k] = this._optimizers.updateB[k].step(this.updateGate.b[k], dbz[k], scale);
      this.newGate.b[k] = this._optimizers.newB[k].step(this.newGate.b[k], dbn[k], scale);
    }
    this._traj = [];
  }
  // ── Flat weight serialization ─────────────────────────────────────────────
  // Order: resetGate (W, b), updateGate (W, b), newGate (W, b).
  getWeightsFlat() {
    const w = [];
    for (const row of this.resetGate.W) w.push(...row);
    w.push(...this.resetGate.b);
    for (const row of this.updateGate.W) w.push(...row);
    w.push(...this.updateGate.b);
    for (const row of this.newGate.W) w.push(...row);
    w.push(...this.newGate.b);
    return w;
  }
  setWeightsFlat(weights) {
    let idx = 0;
    for (let i = 0; i < this.resetGate.W.length; i++)
      for (let j = 0; j < this.resetGate.W[i].length; j++) this.resetGate.W[i][j] = weights[idx++];
    for (let i = 0; i < this.resetGate.b.length; i++) this.resetGate.b[i] = weights[idx++];
    for (let i = 0; i < this.updateGate.W.length; i++)
      for (let j = 0; j < this.updateGate.W[i].length; j++) this.updateGate.W[i][j] = weights[idx++];
    for (let i = 0; i < this.updateGate.b.length; i++) this.updateGate.b[i] = weights[idx++];
    for (let i = 0; i < this.newGate.W.length; i++)
      for (let j = 0; j < this.newGate.W[i].length; j++) this.newGate.W[i][j] = weights[idx++];
    for (let i = 0; i < this.newGate.b.length; i++) this.newGate.b[i] = weights[idx++];
  }
  /**
   * Returns a deep copy of the gate parameters.
   * Before this fix, the live W/b arrays were returned, so a "snapshot" silently
   * followed every later training update.
   * @returns {{resetGate: {W: number[][], b: number[]}, updateGate: {W: number[][], b: number[]}, newGate: {W: number[][], b: number[]}}}
   */
  getWeights() {
    const snapshot = (gate) => ({ W: gate.W.map((row) => [...row]), b: [...gate.b] });
    return {
      resetGate: snapshot(this.resetGate),
      updateGate: snapshot(this.updateGate),
      newGate: snapshot(this.newGate)
    };
  }
  /**
   * Replaces the gate parameters with copies of `data`, so the caller's arrays
   * are never aliased.
   * Every gate is validated before any assignment, so a bad payload leaves the
   * layer unchanged.
   * @param {{resetGate: {W: number[][], b: number[]}, updateGate: {W: number[][], b: number[]}, newGate: {W: number[][], b: number[]}}} data
   * @throws {Error} if a gate's W is not hSize × (inputSize + hSize) or its b is not length hSize
   */
  setWeights(data) {
    const combSize = this.inputSize + this.hSize;
    const copyGate = (name) => {
      const gate = data?.[name];
      const shapeOk = gate?.W?.length === this.hSize &&
        Array.from(gate.W).every((row) => row?.length === combSize) &&
        gate?.b?.length === this.hSize;
      if (!shapeOk) {
        throw new Error(
          `GRULayer.setWeights: ${name} must have W of shape ${this.hSize}x${combSize} and b of length ${this.hSize}`
        );
      }
      return { W: Array.from(gate.W, (row) => Array.from(row)), b: Array.from(gate.b) };
    };
    const reset = copyGate("resetGate");
    const update = copyGate("updateGate");
    const fresh = copyGate("newGate");
    this.resetGate.W = reset.W;
    this.resetGate.b = reset.b;
    this.updateGate.W = update.W;
    this.updateGate.b = update.b;
    this.newGate.W = fresh.W;
    this.newGate.b = fresh.b;
  }
};
|
|
2019
|
+
|
|
2020
|
+
// src/BatchNorm.ts
|
|
2021
|
+
/**
 * Per-feature normalisation of a single vector against running statistics,
 * with learnable scale (gamma) and shift (beta).
 */
var BatchNorm = class {
  /**
   * @param {number} dim - feature count; forward() requires arrays of exactly this length
   * @param {number} [momentum=0.1] - weight kept on the previous running statistic in
   *   `running = momentum * running + (1 - momentum) * sample`. With the default 0.1
   *   the newest sample carries 0.9 of the weight.
   * @param {number} [eps=1e-5] - added to the running variance before the square root
   */
  constructor(dim, momentum = 0.1, eps = 1e-5) {
    this._xNorm = null;
    this._std = null;
    this.dim = dim;
    this.momentum = momentum;
    this.eps = eps;
    this.gamma = new Array(dim).fill(1);
    this.beta = new Array(dim).fill(0);
    this.runningMean = new Array(dim).fill(0);
    this.runningVar = new Array(dim).fill(1);
  }
  // ── Forward ───────────────────────────────────────────────────────────────
  /**
   * Folds x into the running mean/variance, then returns
   * gamma * (x - mean) / sqrt(var + eps) + beta using the freshly updated statistics.
   * Caches the normalised input and std for backward() / trainParams().
   * @param {number[]} x
   * @returns {number[]}
   * @throws {Error} if x.length !== dim
   */
  forward(x) {
    if (x.length !== this.dim) {
      throw new Error(`BatchNorm.forward: expected array of length ${this.dim}, got ${x.length}`);
    }
    for (let i = 0; i < this.dim; i++) {
      this.runningMean[i] = this.momentum * this.runningMean[i] + (1 - this.momentum) * x[i];
      // Variance is tracked against the mean that already includes this sample.
      const diff = x[i] - this.runningMean[i];
      this.runningVar[i] = this.momentum * this.runningVar[i] + (1 - this.momentum) * diff * diff;
    }
    this._std = this.runningVar.map((v) => Math.sqrt(v + this.eps));
    this._xNorm = x.map((v, i) => (v - this.runningMean[i]) / this._std[i]);
    return this._xNorm.map((xn, i) => this.gamma[i] * xn + this.beta[i]);
  }
  // ── Backward ──────────────────────────────────────────────────────────────
  /**
   * Returns dOut * gamma / std per feature.
   * The running statistics are treated as constants, so no gradient flows through them.
   * @param {number[]} dOut
   * @returns {number[]}
   * @throws {Error} if forward() has not been called
   */
  backward(dOut) {
    if (!this._xNorm || !this._std) {
      throw new Error("BatchNorm.backward: call forward() first");
    }
    return dOut.map((d, i) => d * this.gamma[i] / this._std[i]);
  }
  // ── Train gamma and beta (call after backward) ────────────────────────────
  /**
   * Applies gamma += lr * dOut * xNorm and beta += lr * dOut; a no-op before the first forward().
   * NOTE(review): this ADDS the update. Confirm callers pass a "delta" (negative
   * gradient); if dOut is dL/dOut this ascends the loss.
   */
  trainParams(dOut, lr) {
    if (!this._xNorm) return;
    for (let i = 0; i < this.dim; i++) {
      this.gamma[i] += lr * dOut[i] * this._xNorm[i];
      this.beta[i] += lr * dOut[i];
    }
  }
  // ── Flat weight serialization ─────────────────────────────────────────────
  // Order: gamma, beta.
  getWeights() {
    return [...this.gamma, ...this.beta];
  }
  setWeights(weights) {
    for (let i = 0; i < this.dim; i++) this.gamma[i] = weights[i];
    for (let i = 0; i < this.dim; i++) this.beta[i] = weights[this.dim + i];
  }
};
|
|
2074
|
+
|
|
2075
|
+
// src/Conv1D.ts
|
|
2076
|
+
var Conv1D = class {
|
|
2077
|
+
constructor(inputLength, kernelSize, filters, stride = 1, padding = "valid", optimizerFactory = () => new SGD(), inputChannels = 1) {
|
|
2078
|
+
// [filters]
|
|
2079
|
+
this._input = null;
|
|
2080
|
+
this._paddedInput = null;
|
|
2081
|
+
if (inputLength <= 0 || kernelSize <= 0 || filters <= 0) {
|
|
2082
|
+
throw new Error("Conv1D: inputLength, kernelSize, and filters must be positive");
|
|
2083
|
+
}
|
|
2084
|
+
if (kernelSize > inputLength && padding === "valid") {
|
|
2085
|
+
throw new Error("Conv1D: kernelSize cannot exceed inputLength with valid padding");
|
|
2086
|
+
}
|
|
2087
|
+
if (inputChannels < 1) {
|
|
2088
|
+
throw new Error("Conv1D: inputChannels must be >= 1");
|
|
2089
|
+
}
|
|
2090
|
+
this.inputLength = inputLength;
|
|
2091
|
+
this.kernelSize = kernelSize;
|
|
2092
|
+
this.filters = filters;
|
|
2093
|
+
this.stride = stride;
|
|
2094
|
+
this.padding = padding;
|
|
2095
|
+
this.inputChannels = inputChannels;
|
|
2096
|
+
const limit = Math.sqrt(2 / (kernelSize * inputChannels));
|
|
2097
|
+
this.kernels = Array.from(
|
|
2098
|
+
{ length: filters },
|
|
2099
|
+
() => Array.from(
|
|
2100
|
+
{ length: kernelSize },
|
|
2101
|
+
() => Array.from({ length: inputChannels }, () => (Math.random() * 2 - 1) * limit)
|
|
2102
|
+
)
|
|
2103
|
+
);
|
|
2104
|
+
this.biases = new Array(filters).fill(0);
|
|
2105
|
+
this._kOpts = Array.from(
|
|
2106
|
+
{ length: filters },
|
|
2107
|
+
() => Array.from(
|
|
2108
|
+
{ length: kernelSize },
|
|
2109
|
+
() => Array.from({ length: inputChannels }, () => optimizerFactory())
|
|
2110
|
+
)
|
|
2111
|
+
);
|
|
2112
|
+
this._bOpts = Array.from({ length: filters }, () => optimizerFactory());
|
|
2113
|
+
}
|
|
2114
|
+
// ── Forward ───────────────────────────────────────────────────────────────
|
|
2115
|
+
// Accepts either number[] (when inputChannels=1) or number[][] (multi-channel).
|
|
2116
|
+
forward(input) {
|
|
2117
|
+
const input2D = this._normalizeInput(input);
|
|
2118
|
+
this._input = input2D.map((row) => [...row]);
|
|
2119
|
+
let padded;
|
|
2120
|
+
if (this.padding === "same") {
|
|
2121
|
+
const padSize = Math.floor((this.kernelSize - 1) / 2);
|
|
2122
|
+
const padRow = new Array(this.inputChannels).fill(0);
|
|
2123
|
+
padded = new Array(padSize).fill(null).map(() => [...padRow]).concat(input2D).concat(new Array(padSize).fill(null).map(() => [...padRow]));
|
|
2124
|
+
} else {
|
|
2125
|
+
padded = input2D;
|
|
2126
|
+
}
|
|
2127
|
+
this._paddedInput = padded;
|
|
2128
|
+
const outputLength = Math.floor((padded.length - this.kernelSize) / this.stride) + 1;
|
|
2129
|
+
const output = Array.from(
|
|
2130
|
+
{ length: this.filters },
|
|
2131
|
+
() => new Array(outputLength).fill(0)
|
|
2132
|
+
);
|
|
2133
|
+
for (let f = 0; f < this.filters; f++) {
|
|
2134
|
+
for (let pos = 0; pos < outputLength; pos++) {
|
|
2135
|
+
const start = pos * this.stride;
|
|
2136
|
+
let sum = this.biases[f];
|
|
2137
|
+
for (let k = 0; k < this.kernelSize; k++) {
|
|
2138
|
+
for (let c = 0; c < this.inputChannels; c++) {
|
|
2139
|
+
sum += this.kernels[f][k][c] * padded[start + k][c];
|
|
2140
|
+
}
|
|
2141
|
+
}
|
|
2142
|
+
output[f][pos] = sum;
|
|
2143
|
+
}
|
|
2144
|
+
}
|
|
2145
|
+
return output;
|
|
2146
|
+
}
|
|
2147
|
+
// ── Backward ──────────────────────────────────────────────────────────────
|
|
2148
|
+
backward(dOut, lr = 1e-3) {
|
|
2149
|
+
if (!this._paddedInput || !this._input) {
|
|
2150
|
+
throw new Error("Conv1D.backward: call forward() first");
|
|
2151
|
+
}
|
|
2152
|
+
const padded = this._paddedInput;
|
|
2153
|
+
const outputLength = dOut[0].length;
|
|
2154
|
+
const dKernels = Array.from(
|
|
2155
|
+
{ length: this.filters },
|
|
2156
|
+
() => Array.from(
|
|
2157
|
+
{ length: this.kernelSize },
|
|
2158
|
+
() => new Array(this.inputChannels).fill(0)
|
|
2159
|
+
)
|
|
2160
|
+
);
|
|
2161
|
+
const dBiases = new Array(this.filters).fill(0);
|
|
2162
|
+
const dPadded = padded.map((row) => new Array(this.inputChannels).fill(0));
|
|
2163
|
+
for (let f = 0; f < this.filters; f++) {
|
|
2164
|
+
for (let pos = 0; pos < outputLength; pos++) {
|
|
2165
|
+
const start = pos * this.stride;
|
|
2166
|
+
dBiases[f] += dOut[f][pos];
|
|
2167
|
+
for (let k = 0; k < this.kernelSize; k++) {
|
|
2168
|
+
for (let c = 0; c < this.inputChannels; c++) {
|
|
2169
|
+
dKernels[f][k][c] += dOut[f][pos] * padded[start + k][c];
|
|
2170
|
+
dPadded[start + k][c] += dOut[f][pos] * this.kernels[f][k][c];
|
|
2171
|
+
}
|
|
2172
|
+
}
|
|
2173
|
+
}
|
|
2174
|
+
}
|
|
2175
|
+
for (let f = 0; f < this.filters; f++) {
|
|
2176
|
+
for (let k = 0; k < this.kernelSize; k++) {
|
|
2177
|
+
for (let c = 0; c < this.inputChannels; c++) {
|
|
2178
|
+
this.kernels[f][k][c] = this._kOpts[f][k][c].step(this.kernels[f][k][c], dKernels[f][k][c], lr);
|
|
2179
|
+
}
|
|
2180
|
+
}
|
|
2181
|
+
this.biases[f] = this._bOpts[f].step(this.biases[f], dBiases[f], lr);
|
|
2182
|
+
}
|
|
2183
|
+
if (this.padding === "same") {
|
|
2184
|
+
const padSize = Math.floor((this.kernelSize - 1) / 2);
|
|
2185
|
+
return dPadded.slice(padSize, padSize + this.inputLength);
|
|
2186
|
+
}
|
|
2187
|
+
return dPadded.slice(0, this.inputLength);
|
|
2188
|
+
}
|
|
2189
|
+
// ── Output length ─────────────────────────────────────────────────────────
|
|
2190
|
+
getOutputLength() {
|
|
2191
|
+
if (this.padding === "same") {
|
|
2192
|
+
return Math.ceil(this.inputLength / this.stride);
|
|
2193
|
+
}
|
|
2194
|
+
return Math.floor((this.inputLength - this.kernelSize) / this.stride) + 1;
|
|
2195
|
+
}
|
|
2196
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
2197
|
+
// Order: kernels (flattened), biases.
|
|
2198
|
+
getWeights() {
|
|
2199
|
+
const w = [];
|
|
2200
|
+
for (const kernel of this.kernels)
|
|
2201
|
+
for (const k of kernel)
|
|
2202
|
+
for (const c of k)
|
|
2203
|
+
w.push(c);
|
|
2204
|
+
w.push(...this.biases);
|
|
2205
|
+
return w;
|
|
2206
|
+
}
|
|
2207
|
+
setWeights(weights) {
|
|
2208
|
+
let idx = 0;
|
|
2209
|
+
for (let f = 0; f < this.filters; f++)
|
|
2210
|
+
for (let k = 0; k < this.kernelSize; k++)
|
|
2211
|
+
for (let c = 0; c < this.inputChannels; c++)
|
|
2212
|
+
this.kernels[f][k][c] = weights[idx++];
|
|
2213
|
+
for (let f = 0; f < this.filters; f++)
|
|
2214
|
+
this.biases[f] = weights[idx++];
|
|
2215
|
+
}
|
|
2216
|
+
// ── Normalize input to 2D format ─────────────────────────────────────────
|
|
2217
|
+
_normalizeInput(input) {
|
|
2218
|
+
if (input.length === 0) {
|
|
2219
|
+
throw new Error("Conv1D.forward: input cannot be empty");
|
|
2220
|
+
}
|
|
2221
|
+
if (typeof input[0] === "number") {
|
|
2222
|
+
if (this.inputChannels !== 1) {
|
|
2223
|
+
throw new Error(`Conv1D.forward: expected 2D input with ${this.inputChannels} channels, got 1D`);
|
|
2224
|
+
}
|
|
2225
|
+
const input1D = input;
|
|
2226
|
+
if (input1D.length !== this.inputLength) {
|
|
2227
|
+
throw new Error(`Conv1D.forward: expected input of length ${this.inputLength}, got ${input1D.length}`);
|
|
2228
|
+
}
|
|
2229
|
+
return input1D.map((v) => [v]);
|
|
2230
|
+
}
|
|
2231
|
+
const input2D = input;
|
|
2232
|
+
if (input2D.length !== this.inputLength) {
|
|
2233
|
+
throw new Error(`Conv1D.forward: expected input of length ${this.inputLength}, got ${input2D.length}`);
|
|
2234
|
+
}
|
|
2235
|
+
for (let i = 0; i < input2D.length; i++) {
|
|
2236
|
+
if (input2D[i].length !== this.inputChannels) {
|
|
2237
|
+
throw new Error(`Conv1D.forward: expected ${this.inputChannels} channels at position ${i}, got ${input2D[i].length}`);
|
|
2238
|
+
}
|
|
2239
|
+
}
|
|
2240
|
+
return input2D;
|
|
2241
|
+
}
|
|
2242
|
+
};
|
|
2243
|
+
|
|
2244
|
+
// src/Trainer.ts
|
|
2245
|
+
var Trainer = class {
  /**
   * Epoch-based training loop around a network's per-sample `train` method.
   *
   * @param {object} network - must provide train(input, target, lr) returning that
   *   sample's loss. Weight decay, early stopping and metrics additionally need
   *   getWeights(), setWeights() and predict().
   * @param {object} [options]
   * @param {number} [options.epochs=1000]
   * @param {number} [options.lr=0.1] - initial learning rate
   * @param {number} [options.lrDecay=1] - lr is multiplied by this after every epoch
   * @param {boolean} [options.verbose=false] - log the loss every 100 epochs
   * @param {number} [options.weightDecay=0] - before each sample, weights *= 1 - lr * weightDecay
   * @param {{patience: number, minDelta?: number}} [options.earlyStopping] - also needs setValidationData()
   * @param {boolean} [options.computeMetrics=false] - record accuracy/precision/recall/F1 per epoch
   * @param {number} [options.clipValue=0] - stored on the instance; NOTE(review): not read anywhere in Trainer
   */
  constructor(network, options = {}) {
    this._history = [];
    this._bestLoss = Infinity;
    this._patienceCounter = 0;
    this._stopReason = "maxEpochs";
    this._metrics = [];
    this.network = network;
    this.epochs = options.epochs ?? 1e3;
    this.lrInitial = options.lr ?? 0.1;
    this.lrDecay = options.lrDecay ?? 1;
    this.verbose = options.verbose ?? false;
    this.weightDecay = options.weightDecay ?? 0;
    this._earlyStopping = options.earlyStopping;
    this._computeMetrics = options.computeMetrics ?? false;
    this.clipValue = options.clipValue ?? 0;
  }

  // ── Set external validation data (for early stopping) ────────────────────
  /**
   * Stores the dataset whose loss drives early stopping.
   * @param {{inputs: Array, targets: Array}} dataset
   * @throws {Error} if inputs and targets differ in length
   */
  setValidationData(dataset) {
    if (dataset.inputs.length !== dataset.targets.length) {
      throw new Error(
        "Trainer.setValidationData: inputs and targets must have the same length"
      );
    }
    this._validationData = dataset;
  }

  /** Best validation loss seen by the last train() call, or -1 if none was recorded. */
  getBestLoss() {
    return this._bestLoss === Infinity ? -1 : this._bestLoss;
  }

  /** "maxEpochs" or "earlyStopping": how the last train() call ended. */
  getStopReason() {
    return this._stopReason;
  }

  /** Copy of the per-epoch {accuracy, precision, recall, f1} records from the last train() call. */
  getMetrics() {
    return [...this._metrics];
  }

  // ── Train on dataset ──────────────────────────────────────────────────────
  /**
   * Runs up to `epochs` shuffled passes over the dataset.
   *
   * @param {{inputs: Array, targets: Array}} dataset
   * @returns {number[]} mean training loss per completed epoch
   * @throws {Error} if inputs/targets differ in length or the dataset is empty
   */
  train(dataset) {
    const { inputs, targets } = dataset;
    if (inputs.length !== targets.length) {
      throw new Error(
        "Trainer.train: inputs and targets must have the same length"
      );
    }
    const n = inputs.length;
    // With no samples every epoch loss would be 0 / 0 = NaN.
    if (n === 0) {
      throw new Error("Trainer.train: dataset must contain at least one sample");
    }
    let lr = this.lrInitial;
    this._history = [];
    this._bestLoss = Infinity;
    this._patienceCounter = 0;
    this._stopReason = "maxEpochs";
    this._metrics = [];
    const netExt = this._hasWeights(this.network);
    if (this.weightDecay > 0 && !netExt) {
      console.warn(
        "Trainer: weightDecay requires a network with getWeights/setWeights/predict. Skipping weight decay."
      );
    }
    if (this._earlyStopping && !netExt) {
      console.warn(
        "Trainer: earlyStopping requires a network with predict(). Skipping early stopping."
      );
    }
    if (this._computeMetrics && !netExt) {
      console.warn(
        "Trainer: computeMetrics requires a network with predict(). Skipping metrics."
      );
    }
    // An empty validation set would yield a NaN loss that never "improves",
    // triggering a spurious early stop after `patience` epochs.
    const hasValidationData =
      !!this._validationData && this._validationData.inputs.length > 0;
    if (this._earlyStopping && netExt && !hasValidationData) {
      console.warn(
        "Trainer: earlyStopping requires non-empty validation data (call setValidationData). Skipping early stopping."
      );
    }
    const canDecay = this.weightDecay > 0 && !!netExt;
    const canValidate = !!this._earlyStopping && !!netExt && hasValidationData;
    const canMetric = this._computeMetrics && !!netExt;
    const isClass = canMetric && this._isClassification(targets);
    if (canMetric && !isClass) {
      console.warn(
        "Trainer: computeMetrics is set but targets do not appear to be one-hot or single-class. Metrics will be skipped."
      );
    }
    // An omitted minDelta used to make `bestLoss - undefined` NaN, so no epoch ever counted as an improvement.
    const minDelta = canValidate ? (this._earlyStopping.minDelta ?? 0) : 0;
    for (let epoch = 0; epoch < this.epochs; epoch++) {
      // Fresh Fisher–Yates permutation of the sample order each epoch.
      const indices = Array.from({ length: n }, (_, i) => i);
      for (let i = n - 1; i > 0; i--) {
        const j = Math.floor(Math.random() * (i + 1));
        [indices[i], indices[j]] = [indices[j], indices[i]];
      }
      let epochLoss = 0;
      for (const i of indices) {
        if (canDecay) {
          // Shrink all weights by (1 - lr * weightDecay) before this sample's update.
          const w = netExt.getWeights();
          for (let j = 0; j < w.length; j++) {
            w[j] *= 1 - lr * this.weightDecay;
          }
          netExt.setWeights(w);
        }
        epochLoss += this.network.train(inputs[i], targets[i], lr);
      }
      epochLoss /= n;
      this._history.push(epochLoss);
      if (canMetric && isClass) {
        this._metrics.push(this._computeMetricsArray(netExt, inputs, targets));
      }
      if (canValidate) {
        const valLoss = this._computeLoss(netExt, this._validationData);
        if (valLoss < this._bestLoss - minDelta) {
          this._bestLoss = valLoss;
          this._patienceCounter = 0;
        } else {
          this._patienceCounter++;
        }
        if (this._patienceCounter >= this._earlyStopping.patience) {
          this._stopReason = "earlyStopping";
          break;
        }
      }
      lr *= this.lrDecay;
      if (this.verbose && (epoch + 1) % 100 === 0) {
        console.log(
          `Epoch ${epoch + 1}/${this.epochs}, loss: ${epochLoss.toFixed(6)}, lr: ${lr.toFixed(6)}`
        );
      }
    }
    return this._history;
  }

  // ── Get loss history ──────────────────────────────────────────────────────
  /** Copy of the per-epoch training losses from the last train() call. */
  getHistory() {
    return [...this._history];
  }

  // ── Private helpers ───────────────────────────────────────────────────────
  /** Type guard: returns the network if it has getWeights/setWeights/predict functions, else null. */
  _hasWeights(network) {
    if ("getWeights" in network && "setWeights" in network && "predict" in network && typeof network.getWeights === "function" && typeof network.setWeights === "function" && typeof network.predict === "function") {
      return network;
    }
    return null;
  }

  /** Mean squared error (averaged over outputs, then over samples) of net.predict on `data`. */
  _computeLoss(net, data) {
    let totalLoss = 0;
    for (let i = 0; i < data.inputs.length; i++) {
      const pred = net.predict(data.inputs[i]);
      const target = data.targets[i];
      let sampleLoss = 0;
      for (let j = 0; j < pred.length; j++) {
        sampleLoss += (target[j] - pred[j]) ** 2;
      }
      totalLoss += sampleLoss / pred.length;
    }
    return totalLoss / data.inputs.length;
  }

  /**
   * Heuristic: are targets classification-style?
   * Single output: every label must be ~0 or ~1 (binary).
   * Multiple outputs: every value ~0 or ~1 and each row summing to ~1 (one-hot).
   * "~" means within 0.01; NaN is rejected.
   */
  _isClassification(targets) {
    if (targets.length === 0) return false;
    const isBinaryValue = (v) => (v >= -0.01 && v <= 0.01) || (v >= 0.99 && v <= 1.01);
    if (targets[0].length === 1) {
      // Require 0/1 labels so scalar regression targets are not mistaken for binary classes.
      return targets.every((t) => isBinaryValue(t[0]));
    }
    for (const t of targets) {
      let sum = 0;
      for (const v of t) {
        if (!isBinaryValue(v)) return false;
        sum += v;
      }
      if (Math.abs(sum - 1) > 0.01) return false;
    }
    return true;
  }

  /**
   * Builds a confusion matrix from argmax (or 0.5-threshold for single outputs)
   * predictions and returns accuracy, macro precision, macro recall, and the
   * harmonic mean of those two macro averages as `f1`.
   */
  _computeMetricsArray(net, inputs, targets) {
    const targetLen = targets[0].length;
    // A single-output target is treated as binary: classes 0 and 1.
    const nClasses = targetLen === 1 ? 2 : targetLen;
    const confusion = Array.from(
      { length: nClasses },
      () => Array(nClasses).fill(0)
    );
    for (let i = 0; i < inputs.length; i++) {
      const pred = net.predict(inputs[i]);
      const target = targets[i];
      let predClass;
      let trueClass;
      if (targetLen === 1) {
        trueClass = target[0] >= 0.5 ? 1 : 0;
        if (pred.length === 1) {
          predClass = pred[0] >= 0.5 ? 1 : 0;
        } else {
          predClass = pred.indexOf(Math.max(...pred));
        }
      } else {
        predClass = pred.indexOf(Math.max(...pred));
        trueClass = target.indexOf(Math.max(...target));
      }
      // Clamp so malformed predictions (e.g. indexOf -1 on NaN) still index the matrix.
      predClass = Math.max(0, Math.min(nClasses - 1, predClass));
      trueClass = Math.max(0, Math.min(nClasses - 1, trueClass));
      confusion[trueClass][predClass]++;
    }
    let totalCorrect = 0;
    let totalSamples = 0;
    const precisions = [];
    const recalls = [];
    for (let c = 0; c < nClasses; c++) {
      const tp = confusion[c][c];
      totalCorrect += tp;
      let colSum = 0;
      let rowSum = 0;
      for (let r = 0; r < nClasses; r++) {
        colSum += confusion[r][c];
        rowSum += confusion[c][r];
      }
      totalSamples += rowSum;
      precisions.push(colSum > 0 ? tp / colSum : 0);
      recalls.push(rowSum > 0 ? tp / rowSum : 0);
    }
    const accuracy = totalSamples > 0 ? totalCorrect / totalSamples : 0;
    const macroPrecision = precisions.reduce((a, b) => a + b, 0) / nClasses;
    const macroRecall = recalls.reduce((a, b) => a + b, 0) / nClasses;
    const f1 = macroPrecision + macroRecall > 0 ? 2 * macroPrecision * macroRecall / (macroPrecision + macroRecall) : 0;
    return {
      accuracy,
      precision: macroPrecision,
      recall: macroRecall,
      f1
    };
  }
};
|
|
2467
|
+
|
|
2468
|
+
// src/DataLoader.ts
|
|
2469
|
+
var DataLoader = class _DataLoader {
  /**
   * Mini-batch iterator over an {inputs, targets} dataset, with an optional
   * random train/validation split that is fixed at construction time.
   *
   * @param {{inputs: Array, targets: Array}} data
   * @param {number} [batchSize=1] - positive integer, or Infinity for one full batch
   * @param {number} [validationSplit=0] - fraction in [0, 1) held out for validation
   * @throws {Error} on mismatched lengths, an invalid batchSize, or an out-of-range validationSplit
   */
  constructor(data, batchSize = 1, validationSplit = 0) {
    if (data.inputs.length !== data.targets.length) {
      throw new Error("DataLoader: inputs and targets must have the same length");
    }
    // batchSize 0 made next() return empty batches without advancing, so
    // `while (loader.hasNext())` never terminated. Fractional sizes produced
    // overlapping batches because slice() truncates its bounds.
    const validBatch = batchSize === Infinity || (Number.isInteger(batchSize) && batchSize >= 1);
    if (!validBatch) {
      throw new Error(`DataLoader: batchSize must be a positive integer, got ${batchSize}`);
    }
    if (validationSplit < 0 || validationSplit >= 1) {
      throw new Error(`DataLoader: validationSplit must be in [0, 1), got ${validationSplit}`);
    }
    this.data = data;
    this.batchSize = batchSize;
    this._validationSplit = validationSplit;
    const total = data.inputs.length;
    const fullIndices = Array.from({ length: total }, (_, i) => i);
    _DataLoader._shuffleInPlace(fullIndices);
    // The last `valSize` shuffled indices become the validation set.
    const valSize = validationSplit > 0 ? Math.round(total * validationSplit) : 0;
    const trainSize = total - valSize;
    this._trainIndices = fullIndices.slice(0, trainSize);
    this._valIndices = fullIndices.slice(trainSize);
    this._indices = [...this._trainIndices];
    this._pos = 0;
  }

  /** In-place Fisher–Yates shuffle driven by Math.random(). */
  static _shuffleInPlace(arr) {
    for (let i = arr.length - 1; i > 0; i--) {
      const j = Math.floor(Math.random() * (i + 1));
      [arr[i], arr[j]] = [arr[j], arr[i]];
    }
  }

  // ── Shuffle the training data ──────────────────────────────────────────────
  /** Reshuffles the training indices and restarts iteration (validation set unchanged). */
  shuffle() {
    _DataLoader._shuffleInPlace(this._trainIndices);
    this._indices = [...this._trainIndices];
    this._pos = 0;
  }

  // ── Check if more batches are available ───────────────────────────────────
  hasNext() {
    return this._pos < this._indices.length;
  }

  // ── Get next batch ────────────────────────────────────────────────────────
  /** Returns the next batch of up to batchSize training samples and advances the cursor. */
  next() {
    const end = Math.min(this._pos + this.batchSize, this._indices.length);
    const batchIndices = this._indices.slice(this._pos, end);
    this._pos = end;
    return {
      inputs: batchIndices.map((i) => this.data.inputs[i]),
      targets: batchIndices.map((i) => this.data.targets[i])
    };
  }

  // ── Reset iteration ───────────────────────────────────────────────────────
  /** Restarts iteration from the first batch without reshuffling. */
  reset() {
    this._pos = 0;
  }

  // ── Get total number of training samples ───────────────────────────────────
  get length() {
    return this._trainIndices.length;
  }

  // ── Get validation data as a DataPair ──────────────────────────────────────
  // Returns the validation samples (inputs + targets) in their shuffled order.
  // Returns empty arrays if no validation split was configured.
  getValidationData() {
    return {
      inputs: this._valIndices.map((i) => this.data.inputs[i]),
      targets: this._valIndices.map((i) => this.data.targets[i])
    };
  }

  // ── Get number of validation samples ───────────────────────────────────────
  get validationLength() {
    return this._valIndices.length;
  }

  // ── Create sequence windows from a time series ────────────────────────────
  /**
   * Builds a loader of sliding windows: each input is `seqLen` consecutive
   * elements of `data` concatenated with flat(), and its target is the element
   * that follows the window. Assumes each element of `data` is a row array — confirm with callers.
   *
   * @throws {Error} if data.length < seqLen + 1
   */
  static sequences(data, seqLen, validationSplit = 0) {
    if (data.length < seqLen + 1) {
      throw new Error("DataLoader.sequences: data length must be >= seqLen + 1");
    }
    const inputs = [];
    const targets = [];
    for (let i = 0; i <= data.length - seqLen - 1; i++) {
      inputs.push(data.slice(i, i + seqLen).flat());
      targets.push(data[i + seqLen]);
    }
    return new _DataLoader({ inputs, targets }, 1, validationSplit);
  }
};
|
|
2555
|
+
|
|
2556
|
+
// src/LRScheduler.ts
|
|
2557
|
+
var LRScheduler = class {
  /**
   * Step decay: lr * dropRate^floor(epoch / epochsDrop).
   * The rate is multiplied by `dropRate` once every `epochsDrop` epochs.
   */
  stepDecay(lr, epoch, dropRate, epochsDrop) {
    const drops = Math.floor(epoch / epochsDrop);
    return lr * Math.pow(dropRate, drops);
  }

  /** Exponential decay: lr * decayRate^epoch. */
  exponentialDecay(lr, epoch, decayRate) {
    const scale = Math.pow(decayRate, epoch);
    return lr * scale;
  }

  /**
   * Plateau decay: returns lr * factor when `currentLoss` is not below the best
   * (minimum) of the last `patience` entries of `history`; otherwise returns lr.
   * While history has fewer than `patience` entries, lr is returned unchanged.
   *
   * Pass the losses of *previous* epochs as `history`. If it already ends with
   * `currentLoss`, the comparison always holds and lr decays on every call.
   *
   * Example:
   *   const history = [];
   *   for (let epoch = 0; epoch < 1000; epoch++) {
   *     const loss = trainOneEpoch(lr);
   *     lr = scheduler.plateauDecay(lr, loss, history, 10, 0.5);
   *     history.push(loss);
   *   }
   */
  plateauDecay(lr, currentLoss, history, patience, factor) {
    const tooShort = history.length < patience;
    if (tooShort) {
      return lr;
    }
    const bestRecent = Math.min(...history.slice(-patience));
    const stalled = currentLoss >= bestRecent;
    return stalled ? lr * factor : lr;
  }

  /**
   * Cosine annealing: minLr + 0.5 * (lr - minLr) * (1 + cos(pi * epoch / maxEpochs)).
   * Starts at `lr` for epoch 0 and reaches `minLr` at epoch === maxEpochs.
   */
  cosineAnnealing(lr, epoch, maxEpochs, minLr = 0) {
    const angle = Math.PI * epoch / maxEpochs;
    const cosineTerm = 1 + Math.cos(angle);
    const span = lr - minLr;
    return minLr + 0.5 * span * cosineTerm;
  }
};
|
|
2594
|
+
|
|
2595
|
+
// src/ModelSaver.ts
|
|
2596
|
+
var ModelSaver = class _ModelSaver {
  // ── Serialize to JSON string ──────────────────────────────────────────────
  /**
   * Serializes `model.getWeights()` plus a `timestamp` (Date.now()) to a JSON string.
   * Note: JSON.stringify writes NaN/Infinity weights as null; fromJSON rejects such data.
   */
  static toJSON(model) {
    return JSON.stringify({
      weights: model.getWeights(),
      timestamp: Date.now()
    });
  }

  // ── Deserialize from JSON string ──────────────────────────────────────────
  /**
   * Parses `json` and passes its `weights` array to `model.setWeights`.
   *
   * @throws {SyntaxError} if `json` is not valid JSON
   * @throws {Error} if the payload has no `weights` array or any weight is not a finite number
   */
  static fromJSON(model, json) {
    const data = JSON.parse(json);
    // `data` is null for the literal "null"; reading `.weights` from it would throw a TypeError.
    if (data == null || !Array.isArray(data.weights)) {
      throw new Error("ModelSaver.fromJSON: invalid model data");
    }
    // Non-finite weights were serialized as null; loading them would silently corrupt the model.
    const badIndex = data.weights.findIndex((w) => !Number.isFinite(w));
    if (badIndex !== -1) {
      throw new Error(
        `ModelSaver.fromJSON: invalid model data (weight ${badIndex} is ${data.weights[badIndex]})`
      );
    }
    model.setWeights(data.weights);
  }

  // ── Save to file (requires write function) ────────────────────────────────
  /** Serializes `model` and hands (path, json) to `writeFn`; its return value is ignored. */
  static saveToFile(model, path, writeFn) {
    const json = _ModelSaver.toJSON(model);
    writeFn(path, json);
  }

  // ── Load from file (requires read function) ───────────────────────────────
  /** Loads weights from `readFn(path)`, which must return the JSON string synchronously. */
  static loadFromFile(model, path, readFn) {
    const json = readFn(path);
    _ModelSaver.fromJSON(model, json);
  }
};
|
|
1182
2623
|
export {
|
|
1183
2624
|
Adam,
|
|
1184
2625
|
AttentionHead,
|
|
2626
|
+
BatchNorm,
|
|
2627
|
+
ClipOptimizer,
|
|
2628
|
+
ClippedOptimizerFactory,
|
|
2629
|
+
Conv1D,
|
|
2630
|
+
DataLoader,
|
|
2631
|
+
Dropout,
|
|
1185
2632
|
EmbeddingMatrix,
|
|
2633
|
+
GRULayer,
|
|
2634
|
+
LRScheduler,
|
|
1186
2635
|
LSTMLayer,
|
|
1187
2636
|
Layer,
|
|
1188
2637
|
LayerNorm,
|
|
2638
|
+
ModelSaver,
|
|
1189
2639
|
Momentum,
|
|
1190
2640
|
MultiHeadAttention,
|
|
1191
2641
|
Network,
|
|
@@ -1196,6 +2646,7 @@ export {
|
|
|
1196
2646
|
Neuron,
|
|
1197
2647
|
NeuronN,
|
|
1198
2648
|
SGD,
|
|
2649
|
+
Trainer,
|
|
1199
2650
|
TransformerBlock,
|
|
1200
2651
|
WeightMatrix,
|
|
1201
2652
|
crossEntropy,
|
|
@@ -1214,5 +2665,9 @@ export {
|
|
|
1214
2665
|
softmax,
|
|
1215
2666
|
softmaxBackward,
|
|
1216
2667
|
tanh,
|
|
1217
|
-
transpose
|
|
2668
|
+
transpose,
|
|
2669
|
+
validate2DArray,
|
|
2670
|
+
validateArray,
|
|
2671
|
+
validateArrayMinLength,
|
|
2672
|
+
validateNumber
|
|
1218
2673
|
};
|