@dniskav/neuron 0.2.2 → 0.2.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +38 -0
- package/dist/index.d.mts +291 -11
- package/dist/index.d.ts +291 -11
- package/dist/index.js +1510 -41
- package/dist/index.mjs +1495 -40
- package/package.json +7 -3
package/dist/index.js
CHANGED
|
@@ -22,10 +22,19 @@ var index_exports = {};
|
|
|
22
22
|
__export(index_exports, {
|
|
23
23
|
Adam: () => Adam,
|
|
24
24
|
AttentionHead: () => AttentionHead,
|
|
25
|
+
BatchNorm: () => BatchNorm,
|
|
26
|
+
ClipOptimizer: () => ClipOptimizer,
|
|
27
|
+
ClippedOptimizerFactory: () => ClippedOptimizerFactory,
|
|
28
|
+
Conv1D: () => Conv1D,
|
|
29
|
+
DataLoader: () => DataLoader,
|
|
30
|
+
Dropout: () => Dropout,
|
|
25
31
|
EmbeddingMatrix: () => EmbeddingMatrix,
|
|
32
|
+
GRULayer: () => GRULayer,
|
|
33
|
+
LRScheduler: () => LRScheduler,
|
|
26
34
|
LSTMLayer: () => LSTMLayer,
|
|
27
35
|
Layer: () => Layer,
|
|
28
36
|
LayerNorm: () => LayerNorm,
|
|
37
|
+
ModelSaver: () => ModelSaver,
|
|
29
38
|
Momentum: () => Momentum,
|
|
30
39
|
MultiHeadAttention: () => MultiHeadAttention,
|
|
31
40
|
Network: () => Network,
|
|
@@ -36,6 +45,7 @@ __export(index_exports, {
|
|
|
36
45
|
Neuron: () => Neuron,
|
|
37
46
|
NeuronN: () => NeuronN,
|
|
38
47
|
SGD: () => SGD,
|
|
48
|
+
Trainer: () => Trainer,
|
|
39
49
|
TransformerBlock: () => TransformerBlock,
|
|
40
50
|
WeightMatrix: () => WeightMatrix,
|
|
41
51
|
crossEntropy: () => crossEntropy,
|
|
@@ -54,10 +64,82 @@ __export(index_exports, {
|
|
|
54
64
|
softmax: () => softmax,
|
|
55
65
|
softmaxBackward: () => softmaxBackward,
|
|
56
66
|
tanh: () => tanh,
|
|
57
|
-
transpose: () => transpose
|
|
67
|
+
transpose: () => transpose,
|
|
68
|
+
validate2DArray: () => validate2DArray,
|
|
69
|
+
validateArray: () => validateArray,
|
|
70
|
+
validateArrayMinLength: () => validateArrayMinLength,
|
|
71
|
+
validateNumber: () => validateNumber
|
|
58
72
|
});
|
|
59
73
|
module.exports = __toCommonJS(index_exports);
|
|
60
74
|
|
|
75
|
+
// src/Validation.ts
|
|
76
|
+
/**
 * Asserts that `arr` is an array of exactly `expectedLength` finite numbers.
 *
 * @param {unknown} arr - Value to check.
 * @param {number} expectedLength - Required number of elements.
 * @param {string} methodName - Caller label prefixed to every error message.
 * @throws {Error} If `arr` is not an array, has a different length, or holds
 *   an element that is not a finite number.
 */
function validateArray(arr, expectedLength, methodName) {
  if (!Array.isArray(arr)) {
    throw new Error(`${methodName}: expected array, got ${typeof arr}`);
  }
  if (arr.length !== expectedLength) {
    throw new Error(
      `${methodName}: expected array of length ${expectedLength}, got ${arr.length}`
    );
  }
  for (let i = 0; i < arr.length; i++) {
    // Number.isFinite never coerces: non-numbers, NaN and ±Infinity all fail.
    if (!Number.isFinite(arr[i])) {
      // String() can render a Symbol, whereas template interpolation would
      // throw a TypeError and hide the intended validation error.
      throw new Error(
        `${methodName}: invalid value at index ${i}: ${String(arr[i])}`
      );
    }
  }
}
|
|
93
|
+
/**
 * Asserts that `arr` is an array of at least `minLength` finite numbers.
 *
 * @param {unknown} arr - Value to check.
 * @param {number} minLength - Smallest acceptable number of elements.
 * @param {string} methodName - Caller label prefixed to every error message.
 * @throws {Error} If `arr` is not an array, is shorter than `minLength`, or
 *   holds an element that is not a finite number.
 */
function validateArrayMinLength(arr, minLength, methodName) {
  if (!Array.isArray(arr)) {
    throw new Error(`${methodName}: expected array, got ${typeof arr}`);
  }
  if (arr.length < minLength) {
    throw new Error(
      `${methodName}: expected array of at least length ${minLength}, got ${arr.length}`
    );
  }
  for (let i = 0; i < arr.length; i++) {
    // Number.isFinite never coerces: non-numbers, NaN and ±Infinity all fail.
    if (!Number.isFinite(arr[i])) {
      // String() can render a Symbol, whereas template interpolation would
      // throw a TypeError and hide the intended validation error.
      throw new Error(
        `${methodName}: invalid value at index ${i}: ${String(arr[i])}`
      );
    }
  }
}
|
|
110
|
+
/**
 * Asserts that `arr` is an `expectedRows` × `expectedCols` array of arrays
 * containing only finite numbers.
 *
 * @param {unknown} arr - Value to check.
 * @param {number} expectedRows - Required number of rows.
 * @param {number} expectedCols - Required length of every row.
 * @param {string} methodName - Caller label prefixed to every error message.
 * @throws {Error} If the outer value or any row is not an array, if a
 *   dimension differs, or if any cell is not a finite number.
 */
function validate2DArray(arr, expectedRows, expectedCols, methodName) {
  if (!Array.isArray(arr)) {
    throw new Error(`${methodName}: expected 2D array, got ${typeof arr}`);
  }
  if (arr.length !== expectedRows) {
    throw new Error(
      `${methodName}: expected ${expectedRows} rows, got ${arr.length}`
    );
  }
  for (let i = 0; i < arr.length; i++) {
    const row = arr[i];
    if (!Array.isArray(row)) {
      throw new Error(`${methodName}: row ${i} is not an array`);
    }
    if (row.length !== expectedCols) {
      throw new Error(
        `${methodName}: row ${i} expected ${expectedCols} cols, got ${row.length}`
      );
    }
    for (let j = 0; j < row.length; j++) {
      // Number.isFinite never coerces: non-numbers, NaN and ±Infinity all fail.
      if (!Number.isFinite(row[j])) {
        // String() can render a Symbol, whereas template interpolation would
        // throw a TypeError and hide the intended validation error.
        throw new Error(
          `${methodName}: invalid value at [${i}][${j}]: ${String(row[j])}`
        );
      }
    }
  }
}
|
|
137
|
+
/**
 * Asserts that `value` is a finite number.
 *
 * @param {unknown} value - Value to check.
 * @param {string} methodName - Caller label prefixed to the error message.
 * @throws {Error} If `value` is not of type number, or is NaN or ±Infinity.
 */
function validateNumber(value, methodName) {
  // Number.isFinite never coerces, so it covers the type test as well.
  if (!Number.isFinite(value)) {
    // String() can render a Symbol, whereas template interpolation would
    // throw a TypeError and hide the intended validation error.
    throw new Error(`${methodName}: expected finite number, got ${String(value)}`);
  }
}
|
|
142
|
+
|
|
61
143
|
// src/Neuron.ts
|
|
62
144
|
function sigmoid(x) {
|
|
63
145
|
return 1 / (1 + Math.exp(-x));
|
|
@@ -68,13 +150,18 @@ var Neuron = class {
|
|
|
68
150
|
this.bias = Math.random() * 0.1;
|
|
69
151
|
}
|
|
70
152
|
predict(input) {
|
|
153
|
+
validateNumber(input, "Neuron.predict");
|
|
71
154
|
return sigmoid(input * this.weight + this.bias);
|
|
72
155
|
}
|
|
73
156
|
train(input, target, lr) {
|
|
157
|
+
validateNumber(input, "Neuron.train");
|
|
158
|
+
validateNumber(target, "Neuron.train");
|
|
159
|
+
validateNumber(lr, "Neuron.train");
|
|
74
160
|
const prediction = this.predict(input);
|
|
75
161
|
const error = target - prediction;
|
|
76
|
-
|
|
77
|
-
this.
|
|
162
|
+
const grad = error * prediction * (1 - prediction);
|
|
163
|
+
this.weight += lr * grad * input;
|
|
164
|
+
this.bias += lr * grad;
|
|
78
165
|
}
|
|
79
166
|
};
|
|
80
167
|
|
|
@@ -129,6 +216,19 @@ var Momentum = class {
|
|
|
129
216
|
return weight + this.v;
|
|
130
217
|
}
|
|
131
218
|
};
|
|
219
|
+
/**
 * Optimizer decorator that clamps each gradient to [-clipValue, clipValue]
 * before handing it to the wrapped optimizer.
 */
var ClipOptimizer = class {
  /**
   * @param {{ step(weight: number, gradient: number, lr: number): number }} inner
   *   Optimizer that performs the actual update.
   * @param {number} clipValue - Largest gradient magnitude passed through;
   *   must be a non-negative number (Infinity disables clipping).
   * @throws {Error} If `clipValue` is not a number, is NaN, or is negative.
   */
  constructor(inner, clipValue) {
    // A negative bound makes max(-c, min(c, g)) return the constant -c for
    // every gradient, and NaN poisons every update, so reject both up front.
    if (typeof clipValue !== "number" || Number.isNaN(clipValue) || clipValue < 0) {
      throw new Error(
        `ClipOptimizer: clipValue must be a non-negative number, got ${String(clipValue)}`
      );
    }
    this.inner = inner;
    this.clipValue = clipValue;
  }
  /**
   * Clamps `gradient` and delegates to the wrapped optimizer.
   *
   * @param {number} weight - Current parameter value.
   * @param {number} gradient - Raw gradient for that parameter.
   * @param {number} lr - Learning rate forwarded unchanged.
   * @returns {number} Whatever the wrapped optimizer's `step` returns.
   */
  step(weight, gradient, lr) {
    const clipped = Math.max(-this.clipValue, Math.min(this.clipValue, gradient));
    return this.inner.step(weight, clipped, lr);
  }
};
|
|
229
|
+
/**
 * Builds an optimizer factory whose products are ClipOptimizer instances
 * wrapping a fresh optimizer from `innerFactory`.
 *
 * @param {() => object} innerFactory - Zero-argument factory for the wrapped optimizer.
 * @param {number} clipValue - Clip bound handed to every ClipOptimizer created.
 * @returns {() => object} Zero-argument factory; each call invokes `innerFactory` once.
 */
function ClippedOptimizerFactory(innerFactory, clipValue) {
  return () => {
    const wrapped = innerFactory();
    return new ClipOptimizer(wrapped, clipValue);
  };
}
|
|
132
232
|
var Adam = class {
|
|
133
233
|
constructor(beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8) {
|
|
134
234
|
this.beta1 = beta1;
|
|
@@ -159,6 +259,7 @@ var NeuronN = class {
|
|
|
159
259
|
this._opts = Array.from({ length: nInputs + 1 }, optimizerFactory);
|
|
160
260
|
}
|
|
161
261
|
predict(inputs) {
|
|
262
|
+
validateArray(inputs, this.weights.length, "NeuronN.predict");
|
|
162
263
|
const sum = inputs.reduce((acc, e, i) => acc + e * this.weights[i], this.bias);
|
|
163
264
|
return this.activation.fn(sum);
|
|
164
265
|
}
|
|
@@ -171,7 +272,8 @@ var NeuronN = class {
|
|
|
171
272
|
train(inputs, target, lr) {
|
|
172
273
|
const prediction = this.predict(inputs);
|
|
173
274
|
const error = target - prediction;
|
|
174
|
-
|
|
275
|
+
const grad = error * this.activation.dfn(prediction);
|
|
276
|
+
this._update(inputs.map((inp) => grad * inp), grad, lr);
|
|
175
277
|
}
|
|
176
278
|
};
|
|
177
279
|
|
|
@@ -196,29 +298,99 @@ var Network = class {
|
|
|
196
298
|
this.outputLayer = new Layer(nOutputs, nHidden);
|
|
197
299
|
}
|
|
198
300
|
predict(inputs) {
|
|
301
|
+
validateArray(inputs, this.hiddenLayer.neurons[0].weights.length, "Network.predict");
|
|
199
302
|
const hiddenOut = this.hiddenLayer.predict(inputs);
|
|
200
303
|
return this.outputLayer.predict(hiddenOut)[0];
|
|
201
304
|
}
|
|
202
305
|
// Trains on a single example. Returns the squared error.
|
|
203
306
|
train(inputs, target, lr) {
|
|
307
|
+
validateArray(inputs, this.hiddenLayer.neurons[0].weights.length, "Network.train");
|
|
308
|
+
validateNumber(target, "Network.train");
|
|
309
|
+
validateNumber(lr, "Network.train");
|
|
204
310
|
const hiddenOut = this.hiddenLayer.predict(inputs);
|
|
205
311
|
const prediction = this.outputLayer.predict(hiddenOut)[0];
|
|
206
312
|
const outputError = target - prediction;
|
|
207
313
|
const outputDelta = outputError * prediction * (1 - prediction);
|
|
208
314
|
const outputNeuron = this.outputLayer.neurons[0];
|
|
315
|
+
const hiddenDeltas = this.hiddenLayer.neurons.map((neuron, i) => {
|
|
316
|
+
const hiddenOut_i = hiddenOut[i];
|
|
317
|
+
const hiddenError = outputDelta * outputNeuron.weights[i];
|
|
318
|
+
return hiddenError * hiddenOut_i * (1 - hiddenOut_i);
|
|
319
|
+
});
|
|
320
|
+
this.hiddenLayer.neurons.forEach((neuron, i) => {
|
|
321
|
+
neuron.weights = neuron.weights.map((w, j) => w + lr * hiddenDeltas[i] * inputs[j]);
|
|
322
|
+
neuron.bias += lr * hiddenDeltas[i];
|
|
323
|
+
});
|
|
209
324
|
outputNeuron.weights = outputNeuron.weights.map(
|
|
210
325
|
(w, i) => w + lr * outputDelta * hiddenOut[i]
|
|
211
326
|
);
|
|
212
327
|
outputNeuron.bias += lr * outputDelta;
|
|
213
|
-
this.hiddenLayer.neurons.forEach((neuron, i) => {
|
|
214
|
-
const hiddenOut_i = hiddenOut[i];
|
|
215
|
-
const hiddenError = outputDelta * outputNeuron.weights[i];
|
|
216
|
-
const hiddenDelta = hiddenError * hiddenOut_i * (1 - hiddenOut_i);
|
|
217
|
-
neuron.weights = neuron.weights.map((w, j) => w + lr * hiddenDelta * inputs[j]);
|
|
218
|
-
neuron.bias += lr * hiddenDelta;
|
|
219
|
-
});
|
|
220
328
|
return outputError * outputError;
|
|
221
329
|
}
|
|
330
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
331
|
+
// Order: hidden layer (all neurons: weights then bias), then output layer.
|
|
332
|
+
getWeights() {
|
|
333
|
+
const w = [];
|
|
334
|
+
for (const n of this.hiddenLayer.neurons) {
|
|
335
|
+
w.push(...n.weights, n.bias);
|
|
336
|
+
}
|
|
337
|
+
for (const n of this.outputLayer.neurons) {
|
|
338
|
+
w.push(...n.weights, n.bias);
|
|
339
|
+
}
|
|
340
|
+
return w;
|
|
341
|
+
}
|
|
342
|
+
setWeights(weights) {
|
|
343
|
+
let idx = 0;
|
|
344
|
+
for (const n of this.hiddenLayer.neurons) {
|
|
345
|
+
for (let j = 0; j < n.weights.length; j++) n.weights[j] = weights[idx++];
|
|
346
|
+
n.bias = weights[idx++];
|
|
347
|
+
}
|
|
348
|
+
for (const n of this.outputLayer.neurons) {
|
|
349
|
+
for (let j = 0; j < n.weights.length; j++) n.weights[j] = weights[idx++];
|
|
350
|
+
n.bias = weights[idx++];
|
|
351
|
+
}
|
|
352
|
+
}
|
|
353
|
+
};
|
|
354
|
+
|
|
355
|
+
// src/Dropout.ts
|
|
356
|
+
/**
 * Inverted-dropout layer: during training each element is zeroed with
 * probability `rate` and survivors are scaled by 1 / (1 - rate). It has no
 * trainable parameters.
 */
var Dropout = class {
  /**
   * @param {number} rate - Drop probability; must satisfy 0 <= rate < 1.
   * @throws {Error} If `rate` is outside [0, 1), NaN, or undefined.
   */
  constructor(rate) {
    this._mask = null;
    // Negated range test: NaN and undefined fail every comparison, so the
    // earlier `rate < 0 || rate >= 1` form let them through.
    if (!(rate >= 0 && rate < 1)) {
      throw new Error(`Dropout rate must be in [0, 1), got ${rate}`);
    }
    this.rate = rate;
  }
  // ── Forward ───────────────────────────────────────────────────────────────
  /**
   * @param {number[]} x - Activations.
   * @param {boolean} [training=true] - When false (or rate is 0) the input is
   *   returned as a copy and any stored mask is cleared.
   * @returns {number[]} New array; in training mode each element is multiplied
   *   by its mask entry (0 or the inverted-dropout scale).
   */
  forward(x, training = true) {
    if (!training || this.rate === 0) {
      this._mask = null;
      return [...x];
    }
    const scale = 1 / (1 - this.rate);
    this._mask = x.map(() => Math.random() > this.rate ? scale : 0);
    return x.map((v, i) => v * this._mask[i]);
  }
  // ── Backward ──────────────────────────────────────────────────────────────
  /**
   * @param {number[]} dOut - Upstream gradient.
   * @returns {number[]} New array: `dOut` multiplied element-wise by the mask
   *   stored by the last training-mode forward pass, or a plain copy when no
   *   mask is stored.
   */
  backward(dOut) {
    if (!this._mask) return [...dOut];
    return dOut.map((d, i) => d * this._mask[i]);
  }
  // ── Reset mask between forward passes ─────────────────────────────────────
  /** Discards the stored mask so the next backward() is a pass-through. */
  resetMask() {
    this._mask = null;
  }
  // ── No trainable params ───────────────────────────────────────────────────
  /** @returns {number[]} Always an empty array. */
  getWeights() {
    return [];
  }
  /** Intentionally a no-op: there is nothing to load. */
  setWeights(_weights) {
  }
};
|
|
223
395
|
|
|
224
396
|
// src/NetworkN.ts
|
|
@@ -229,30 +401,96 @@ var NetworkN = class {
|
|
|
229
401
|
const nLayers = structure.length - 1;
|
|
230
402
|
const activations = options.activations ?? Array.from({ length: nLayers }, () => sigmoid2);
|
|
231
403
|
const optimizer = options.optimizer ?? defaultOptimizer3;
|
|
404
|
+
const dropoutRate = options.dropoutRate ?? 0;
|
|
405
|
+
if (activations.length !== nLayers) {
|
|
406
|
+
throw new Error(`Expected ${nLayers} activations, got ${activations.length}`);
|
|
407
|
+
}
|
|
408
|
+
if (dropoutRate < 0 || dropoutRate >= 1) {
|
|
409
|
+
throw new Error(`Dropout rate must be in [0, 1), got ${dropoutRate}`);
|
|
410
|
+
}
|
|
411
|
+
this._residual = options.residual ?? false;
|
|
232
412
|
this.layers = [];
|
|
233
413
|
for (let i = 1; i < structure.length; i++) {
|
|
234
414
|
this.layers.push(new Layer(structure[i], structure[i - 1], activations[i - 1], optimizer));
|
|
235
415
|
}
|
|
416
|
+
this._dropouts = [];
|
|
417
|
+
if (dropoutRate > 0) {
|
|
418
|
+
for (let i = 0; i < nLayers - 1; i++) {
|
|
419
|
+
this._dropouts.push(new Dropout(dropoutRate));
|
|
420
|
+
}
|
|
421
|
+
}
|
|
422
|
+
const outputLayer = this.layers[this.layers.length - 1];
|
|
423
|
+
const outputActivation = outputLayer.neurons[0].activation;
|
|
424
|
+
for (let i = 1; i < outputLayer.neurons.length; i++) {
|
|
425
|
+
if (outputLayer.neurons[i].activation !== outputActivation) {
|
|
426
|
+
throw new Error("All output neurons must share the same activation function");
|
|
427
|
+
}
|
|
428
|
+
}
|
|
236
429
|
}
|
|
237
|
-
predict(inputs) {
|
|
238
|
-
|
|
430
|
+
predict(inputs, training = false) {
|
|
431
|
+
validateArray(inputs, this.structure[0], "NetworkN.predict");
|
|
432
|
+
let current = [...inputs];
|
|
433
|
+
for (let i = 0; i < this.layers.length; i++) {
|
|
434
|
+
const layerInput = [...current];
|
|
435
|
+
const layerOutput = this.layers[i].predict(current);
|
|
436
|
+
if (this._shouldResidual(i)) {
|
|
437
|
+
if (this.structure[i] === this.structure[i + 1]) {
|
|
438
|
+
current = layerOutput.map((v, j) => v + layerInput[j]);
|
|
439
|
+
} else {
|
|
440
|
+
current = [...layerOutput];
|
|
441
|
+
}
|
|
442
|
+
} else {
|
|
443
|
+
current = [...layerOutput];
|
|
444
|
+
}
|
|
445
|
+
if (i < this._dropouts.length) {
|
|
446
|
+
current = this._dropouts[i].forward(current, training);
|
|
447
|
+
}
|
|
448
|
+
}
|
|
449
|
+
return current;
|
|
239
450
|
}
|
|
240
451
|
// Generalized backpropagation across L layers.
|
|
241
452
|
// Returns the mean squared error for the example.
|
|
242
453
|
train(inputs, targets, lr) {
|
|
454
|
+
validateArray(inputs, this.structure[0], "NetworkN.train");
|
|
455
|
+
validateArray(targets, this.structure[this.structure.length - 1], "NetworkN.train");
|
|
243
456
|
const act = [inputs];
|
|
244
|
-
for (
|
|
457
|
+
for (let i = 0; i < this.layers.length; i++) {
|
|
458
|
+
const layerInput = act[act.length - 1];
|
|
459
|
+
const layerOutput = this.layers[i].predict(layerInput);
|
|
460
|
+
let current;
|
|
461
|
+
if (this._shouldResidual(i)) {
|
|
462
|
+
if (this.structure[i] === this.structure[i + 1]) {
|
|
463
|
+
current = layerOutput.map((v, j) => v + layerInput[j]);
|
|
464
|
+
} else {
|
|
465
|
+
current = [...layerOutput];
|
|
466
|
+
}
|
|
467
|
+
} else {
|
|
468
|
+
current = [...layerOutput];
|
|
469
|
+
}
|
|
470
|
+
if (i < this._dropouts.length) {
|
|
471
|
+
current = this._dropouts[i].forward(current, true);
|
|
472
|
+
}
|
|
473
|
+
act.push(current);
|
|
474
|
+
}
|
|
245
475
|
const pred = act[act.length - 1];
|
|
246
476
|
const outAct = this.layers[this.layers.length - 1].neurons[0].activation;
|
|
247
477
|
let deltas = pred.map((p, i) => (targets[i] - p) * outAct.dfn(p));
|
|
248
478
|
for (let l = this.layers.length - 1; l >= 0; l--) {
|
|
249
479
|
const layer = this.layers[l];
|
|
480
|
+
if (l < this._dropouts.length) {
|
|
481
|
+
deltas = this._dropouts[l].backward(deltas);
|
|
482
|
+
}
|
|
250
483
|
const layerIn = act[l];
|
|
251
484
|
const prevAct = l > 0 ? this.layers[l - 1].neurons[0].activation : null;
|
|
252
485
|
const prevDeltas = layerIn.map((out, j) => {
|
|
253
486
|
const errProp = layer.neurons.reduce((s, n, k) => s + deltas[k] * n.weights[j], 0);
|
|
254
487
|
return prevAct ? errProp * prevAct.dfn(out) : errProp;
|
|
255
488
|
});
|
|
489
|
+
if (this._shouldResidual(l) && this.structure[l] === this.structure[l + 1]) {
|
|
490
|
+
for (let j = 0; j < prevDeltas.length; j++) {
|
|
491
|
+
prevDeltas[j] += deltas[j];
|
|
492
|
+
}
|
|
493
|
+
}
|
|
256
494
|
layer.neurons.forEach((n, k) => {
|
|
257
495
|
n._update(layerIn.map((inp) => deltas[k] * inp), deltas[k], lr);
|
|
258
496
|
});
|
|
@@ -264,22 +502,74 @@ var NetworkN = class {
|
|
|
264
502
|
// Useful for custom loss functions (e.g. physics-based gradients).
|
|
265
503
|
trainWithDeltas(inputs, outputDeltas, lr) {
|
|
266
504
|
const act = [inputs];
|
|
267
|
-
for (
|
|
505
|
+
for (let i = 0; i < this.layers.length; i++) {
|
|
506
|
+
const layerInput = act[act.length - 1];
|
|
507
|
+
const layerOutput = this.layers[i].predict(layerInput);
|
|
508
|
+
let current;
|
|
509
|
+
if (this._shouldResidual(i)) {
|
|
510
|
+
if (this.structure[i] === this.structure[i + 1]) {
|
|
511
|
+
current = layerOutput.map((v, j) => v + layerInput[j]);
|
|
512
|
+
} else {
|
|
513
|
+
current = [...layerOutput];
|
|
514
|
+
}
|
|
515
|
+
} else {
|
|
516
|
+
current = [...layerOutput];
|
|
517
|
+
}
|
|
518
|
+
if (i < this._dropouts.length) {
|
|
519
|
+
current = this._dropouts[i].forward(current, true);
|
|
520
|
+
}
|
|
521
|
+
act.push(current);
|
|
522
|
+
}
|
|
268
523
|
let deltas = outputDeltas;
|
|
269
524
|
for (let l = this.layers.length - 1; l >= 0; l--) {
|
|
270
525
|
const layer = this.layers[l];
|
|
526
|
+
if (l < this._dropouts.length) {
|
|
527
|
+
deltas = this._dropouts[l].backward(deltas);
|
|
528
|
+
}
|
|
271
529
|
const layerIn = act[l];
|
|
272
530
|
const prevAct = l > 0 ? this.layers[l - 1].neurons[0].activation : null;
|
|
273
531
|
const prevDeltas = layerIn.map((out, j) => {
|
|
274
532
|
const errProp = layer.neurons.reduce((s, n, k) => s + deltas[k] * n.weights[j], 0);
|
|
275
533
|
return prevAct ? errProp * prevAct.dfn(out) : errProp;
|
|
276
534
|
});
|
|
535
|
+
if (this._shouldResidual(l) && this.structure[l] === this.structure[l + 1]) {
|
|
536
|
+
for (let j = 0; j < prevDeltas.length; j++) {
|
|
537
|
+
prevDeltas[j] += deltas[j];
|
|
538
|
+
}
|
|
539
|
+
}
|
|
277
540
|
layer.neurons.forEach((n, k) => {
|
|
278
541
|
n._update(layerIn.map((inp) => deltas[k] * inp), deltas[k], lr);
|
|
279
542
|
});
|
|
280
543
|
deltas = prevDeltas;
|
|
281
544
|
}
|
|
282
545
|
}
|
|
546
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
547
|
+
// Order: layer 0 (all neurons), layer 1, ..., layer N.
|
|
548
|
+
getWeights() {
|
|
549
|
+
for (const d of this._dropouts) d.resetMask();
|
|
550
|
+
const w = [];
|
|
551
|
+
for (const layer of this.layers) {
|
|
552
|
+
for (const n of layer.neurons) {
|
|
553
|
+
w.push(...n.weights, n.bias);
|
|
554
|
+
}
|
|
555
|
+
}
|
|
556
|
+
return w;
|
|
557
|
+
}
|
|
558
|
+
setWeights(weights) {
|
|
559
|
+
for (const d of this._dropouts) d.resetMask();
|
|
560
|
+
let idx = 0;
|
|
561
|
+
for (const layer of this.layers) {
|
|
562
|
+
for (const n of layer.neurons) {
|
|
563
|
+
for (let j = 0; j < n.weights.length; j++) n.weights[j] = weights[idx++];
|
|
564
|
+
n.bias = weights[idx++];
|
|
565
|
+
}
|
|
566
|
+
}
|
|
567
|
+
}
|
|
568
|
+
// ── Helper ───────────────────────────────────────────────────────────────
|
|
569
|
+
_shouldResidual(layerIndex) {
|
|
570
|
+
if (typeof this._residual === "function") return this._residual(layerIndex);
|
|
571
|
+
return this._residual;
|
|
572
|
+
}
|
|
283
573
|
};
|
|
284
574
|
|
|
285
575
|
// src/LSTMLayer.ts
|
|
@@ -308,8 +598,11 @@ var Gate = class {
|
|
|
308
598
|
}
|
|
309
599
|
};
|
|
310
600
|
var LSTMLayer = class {
|
|
311
|
-
constructor(inputSize, hiddenSize) {
|
|
601
|
+
constructor(inputSize, hiddenSize, optimizerFactory = () => new SGD()) {
|
|
312
602
|
this._traj = [];
|
|
603
|
+
if (inputSize <= 0 || hiddenSize <= 0) {
|
|
604
|
+
throw new Error(`LSTMLayer: inputSize and hiddenSize must be positive, got ${inputSize} and ${hiddenSize}`);
|
|
605
|
+
}
|
|
313
606
|
this.inputSize = inputSize;
|
|
314
607
|
this.hSize = hiddenSize;
|
|
315
608
|
this.h = new Array(hiddenSize).fill(0);
|
|
@@ -318,6 +611,29 @@ var LSTMLayer = class {
|
|
|
318
611
|
this.inputGate = new Gate(inputSize, hiddenSize);
|
|
319
612
|
this.cellGate = new Gate(inputSize, hiddenSize);
|
|
320
613
|
this.outputGate = new Gate(inputSize, hiddenSize);
|
|
614
|
+
const combSize = inputSize + hiddenSize;
|
|
615
|
+
this._optimizers = {
|
|
616
|
+
forgetW: Array.from(
|
|
617
|
+
{ length: hiddenSize },
|
|
618
|
+
() => Array.from({ length: combSize }, () => optimizerFactory())
|
|
619
|
+
),
|
|
620
|
+
forgetB: Array.from({ length: hiddenSize }, () => optimizerFactory()),
|
|
621
|
+
inputW: Array.from(
|
|
622
|
+
{ length: hiddenSize },
|
|
623
|
+
() => Array.from({ length: combSize }, () => optimizerFactory())
|
|
624
|
+
),
|
|
625
|
+
inputB: Array.from({ length: hiddenSize }, () => optimizerFactory()),
|
|
626
|
+
cellW: Array.from(
|
|
627
|
+
{ length: hiddenSize },
|
|
628
|
+
() => Array.from({ length: combSize }, () => optimizerFactory())
|
|
629
|
+
),
|
|
630
|
+
cellB: Array.from({ length: hiddenSize }, () => optimizerFactory()),
|
|
631
|
+
outputW: Array.from(
|
|
632
|
+
{ length: hiddenSize },
|
|
633
|
+
() => Array.from({ length: combSize }, () => optimizerFactory())
|
|
634
|
+
),
|
|
635
|
+
outputB: Array.from({ length: hiddenSize }, () => optimizerFactory())
|
|
636
|
+
};
|
|
321
637
|
}
|
|
322
638
|
// ── Reset state and trajectory (call at episode start) ────────────────────
|
|
323
639
|
reset() {
|
|
@@ -327,6 +643,9 @@ var LSTMLayer = class {
|
|
|
327
643
|
}
|
|
328
644
|
// ── Forward pass ──────────────────────────────────────────────────────────
|
|
329
645
|
predict(inputs) {
|
|
646
|
+
if (!Array.isArray(inputs) || inputs.length !== this.inputSize) {
|
|
647
|
+
throw new Error(`LSTMLayer.predict: expected array of length ${this.inputSize}, got ${inputs?.length}`);
|
|
648
|
+
}
|
|
330
649
|
const combined = [...inputs, ...this.h];
|
|
331
650
|
const c_prev = [...this.c];
|
|
332
651
|
const zf = this.forgetGate.linear(combined);
|
|
@@ -401,15 +720,15 @@ var LSTMLayer = class {
|
|
|
401
720
|
const scale = lr / T;
|
|
402
721
|
for (let k = 0; k < hSize; k++) {
|
|
403
722
|
for (let j = 0; j < combSize; j++) {
|
|
404
|
-
this.forgetGate.W[k][j]
|
|
405
|
-
this.inputGate.W[k][j]
|
|
406
|
-
this.cellGate.W[k][j]
|
|
407
|
-
this.outputGate.W[k][j]
|
|
723
|
+
this.forgetGate.W[k][j] = this._optimizers.forgetW[k][j].step(this.forgetGate.W[k][j], dWf[k][j], scale);
|
|
724
|
+
this.inputGate.W[k][j] = this._optimizers.inputW[k][j].step(this.inputGate.W[k][j], dWi[k][j], scale);
|
|
725
|
+
this.cellGate.W[k][j] = this._optimizers.cellW[k][j].step(this.cellGate.W[k][j], dWg[k][j], scale);
|
|
726
|
+
this.outputGate.W[k][j] = this._optimizers.outputW[k][j].step(this.outputGate.W[k][j], dWo[k][j], scale);
|
|
408
727
|
}
|
|
409
|
-
this.forgetGate.b[k]
|
|
410
|
-
this.inputGate.b[k]
|
|
411
|
-
this.cellGate.b[k]
|
|
412
|
-
this.outputGate.b[k]
|
|
728
|
+
this.forgetGate.b[k] = this._optimizers.forgetB[k].step(this.forgetGate.b[k], dbf[k], scale);
|
|
729
|
+
this.inputGate.b[k] = this._optimizers.inputB[k].step(this.inputGate.b[k], dbi[k], scale);
|
|
730
|
+
this.cellGate.b[k] = this._optimizers.cellB[k].step(this.cellGate.b[k], dbg[k], scale);
|
|
731
|
+
this.outputGate.b[k] = this._optimizers.outputB[k].step(this.outputGate.b[k], dbo[k], scale);
|
|
413
732
|
}
|
|
414
733
|
this._traj = [];
|
|
415
734
|
}
|
|
@@ -432,6 +751,35 @@ var LSTMLayer = class {
|
|
|
432
751
|
this.outputGate.W = data.outputGate.W;
|
|
433
752
|
this.outputGate.b = data.outputGate.b;
|
|
434
753
|
}
|
|
754
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
755
|
+
// Order: forgetGate (W, b), inputGate (W, b), cellGate (W, b), outputGate (W, b).
|
|
756
|
+
getWeightsFlat() {
|
|
757
|
+
const w = [];
|
|
758
|
+
for (const row of this.forgetGate.W) w.push(...row);
|
|
759
|
+
w.push(...this.forgetGate.b);
|
|
760
|
+
for (const row of this.inputGate.W) w.push(...row);
|
|
761
|
+
w.push(...this.inputGate.b);
|
|
762
|
+
for (const row of this.cellGate.W) w.push(...row);
|
|
763
|
+
w.push(...this.cellGate.b);
|
|
764
|
+
for (const row of this.outputGate.W) w.push(...row);
|
|
765
|
+
w.push(...this.outputGate.b);
|
|
766
|
+
return w;
|
|
767
|
+
}
|
|
768
|
+
setWeightsFlat(weights) {
|
|
769
|
+
let idx = 0;
|
|
770
|
+
for (let i = 0; i < this.forgetGate.W.length; i++)
|
|
771
|
+
for (let j = 0; j < this.forgetGate.W[i].length; j++) this.forgetGate.W[i][j] = weights[idx++];
|
|
772
|
+
for (let i = 0; i < this.forgetGate.b.length; i++) this.forgetGate.b[i] = weights[idx++];
|
|
773
|
+
for (let i = 0; i < this.inputGate.W.length; i++)
|
|
774
|
+
for (let j = 0; j < this.inputGate.W[i].length; j++) this.inputGate.W[i][j] = weights[idx++];
|
|
775
|
+
for (let i = 0; i < this.inputGate.b.length; i++) this.inputGate.b[i] = weights[idx++];
|
|
776
|
+
for (let i = 0; i < this.cellGate.W.length; i++)
|
|
777
|
+
for (let j = 0; j < this.cellGate.W[i].length; j++) this.cellGate.W[i][j] = weights[idx++];
|
|
778
|
+
for (let i = 0; i < this.cellGate.b.length; i++) this.cellGate.b[i] = weights[idx++];
|
|
779
|
+
for (let i = 0; i < this.outputGate.W.length; i++)
|
|
780
|
+
for (let j = 0; j < this.outputGate.W[i].length; j++) this.outputGate.W[i][j] = weights[idx++];
|
|
781
|
+
for (let i = 0; i < this.outputGate.b.length; i++) this.outputGate.b[i] = weights[idx++];
|
|
782
|
+
}
|
|
435
783
|
};
|
|
436
784
|
|
|
437
785
|
// src/NetworkLSTM.ts
|
|
@@ -458,6 +806,7 @@ var NetworkLSTM = class {
|
|
|
458
806
|
}
|
|
459
807
|
// ── Forward pass ──────────────────────────────────────────────────────────
|
|
460
808
|
predict(inputs) {
|
|
809
|
+
validateArray(inputs, this.inputSize, "NetworkLSTM.predict");
|
|
461
810
|
const h = this.lstm.predict(inputs);
|
|
462
811
|
const acts = [h];
|
|
463
812
|
for (const layer of this.denseLayers) {
|
|
@@ -533,6 +882,30 @@ var NetworkLSTM = class {
|
|
|
533
882
|
});
|
|
534
883
|
});
|
|
535
884
|
}
|
|
885
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
886
|
+
// Order: LSTM (flat), then dense layer 0, dense layer 1, ..., dense layer N.
|
|
887
|
+
getWeightsFlat() {
|
|
888
|
+
const w = [];
|
|
889
|
+
w.push(...this.lstm.getWeightsFlat());
|
|
890
|
+
for (const layer of this.denseLayers) {
|
|
891
|
+
for (const n of layer.neurons) {
|
|
892
|
+
w.push(...n.weights, n.bias);
|
|
893
|
+
}
|
|
894
|
+
}
|
|
895
|
+
return w;
|
|
896
|
+
}
|
|
897
|
+
setWeightsFlat(weights) {
|
|
898
|
+
let idx = 0;
|
|
899
|
+
const lstmLen = this.lstm.getWeightsFlat().length;
|
|
900
|
+
this.lstm.setWeightsFlat(weights.slice(idx, idx + lstmLen));
|
|
901
|
+
idx += lstmLen;
|
|
902
|
+
for (const layer of this.denseLayers) {
|
|
903
|
+
for (const n of layer.neurons) {
|
|
904
|
+
for (let j = 0; j < n.weights.length; j++) n.weights[j] = weights[idx++];
|
|
905
|
+
n.bias = weights[idx++];
|
|
906
|
+
}
|
|
907
|
+
}
|
|
908
|
+
}
|
|
536
909
|
};
|
|
537
910
|
|
|
538
911
|
// src/MatMul.ts
|
|
@@ -540,6 +913,9 @@ function matMul(A, B) {
|
|
|
540
913
|
const rows = A.length;
|
|
541
914
|
const inner = B.length;
|
|
542
915
|
const cols = B[0].length;
|
|
916
|
+
if (A[0].length !== B.length) {
|
|
917
|
+
throw new Error(`Incompatible dimensions for matrix multiplication: A cols (${A[0].length}) !== B rows (${B.length})`);
|
|
918
|
+
}
|
|
543
919
|
const C = Array.from({ length: rows }, () => new Array(cols).fill(0));
|
|
544
920
|
for (let i = 0; i < rows; i++)
|
|
545
921
|
for (let k = 0; k < inner; k++) {
|
|
@@ -590,6 +966,17 @@ var WeightMatrix = class {
|
|
|
590
966
|
this.W[i][j] = this.opts[i][j].step(this.W[i][j], g, lr);
|
|
591
967
|
}
|
|
592
968
|
}
|
|
969
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
970
|
+
getWeights() {
|
|
971
|
+
const w = [];
|
|
972
|
+
for (const row of this.W) w.push(...row);
|
|
973
|
+
return w;
|
|
974
|
+
}
|
|
975
|
+
setWeights(weights) {
|
|
976
|
+
let idx = 0;
|
|
977
|
+
for (let i = 0; i < this.W.length; i++)
|
|
978
|
+
for (let j = 0; j < this.W[i].length; j++) this.W[i][j] = weights[idx++];
|
|
979
|
+
}
|
|
593
980
|
};
|
|
594
981
|
var EmbeddingMatrix = class {
|
|
595
982
|
constructor(vocabSize, d_model) {
|
|
@@ -606,15 +993,29 @@ var EmbeddingMatrix = class {
|
|
|
606
993
|
for (let m = 0; m < this.W[idx].length; m++)
|
|
607
994
|
this.W[idx][m] += lr * grad[m];
|
|
608
995
|
}
|
|
996
|
+
// ── Serializable interface ─────────────────────────────────────────────────
|
|
997
|
+
// Flattened order: row 0, row 1, ... row (vocabSize-1)
|
|
998
|
+
getWeights() {
|
|
999
|
+
const w = [];
|
|
1000
|
+
for (const row of this.W) w.push(...row);
|
|
1001
|
+
return w;
|
|
1002
|
+
}
|
|
1003
|
+
setWeights(weights) {
|
|
1004
|
+
let idx = 0;
|
|
1005
|
+
for (let i = 0; i < this.W.length; i++)
|
|
1006
|
+
for (let j = 0; j < this.W[i].length; j++)
|
|
1007
|
+
this.W[i][j] = weights[idx++];
|
|
1008
|
+
}
|
|
609
1009
|
};
|
|
610
1010
|
|
|
611
1011
|
// src/AttentionHead.ts
|
|
612
1012
|
var AttentionHead = class {
|
|
613
|
-
constructor(d_model, d_k, d_v) {
|
|
1013
|
+
constructor(d_model, d_k, d_v, causal = false) {
|
|
614
1014
|
// d_v × d_model
|
|
615
1015
|
this.cache = null;
|
|
616
1016
|
this.d_k = d_k;
|
|
617
1017
|
this.d_v = d_v;
|
|
1018
|
+
this.causal = causal;
|
|
618
1019
|
this.Wq = new WeightMatrix(d_k, d_model);
|
|
619
1020
|
this.Wk = new WeightMatrix(d_k, d_model);
|
|
620
1021
|
this.Wv = new WeightMatrix(d_v, d_model);
|
|
@@ -635,10 +1036,10 @@ var AttentionHead = class {
|
|
|
635
1036
|
);
|
|
636
1037
|
const scores = Array.from(
|
|
637
1038
|
{ length: seqLen },
|
|
638
|
-
(_, i) => Array.from(
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
)
|
|
1039
|
+
(_, i) => Array.from({ length: seqLen }, (_2, j) => {
|
|
1040
|
+
if (this.causal && j > i) return -Infinity;
|
|
1041
|
+
return Q[i].reduce((s, q, k) => s + q * K[j][k], 0) * scale;
|
|
1042
|
+
})
|
|
642
1043
|
);
|
|
643
1044
|
const attn = scores.map((row) => softmax(row));
|
|
644
1045
|
const out = Array.from(
|
|
@@ -734,21 +1135,40 @@ var AttentionHead = class {
|
|
|
734
1135
|
getAttentionWeights() {
|
|
735
1136
|
return this.cache ? this.cache.attn : null;
|
|
736
1137
|
}
|
|
1138
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
1139
|
+
// Order: Wq, Wk, Wv.
|
|
1140
|
+
getWeights() {
|
|
1141
|
+
const w = [];
|
|
1142
|
+
for (const row of this.Wq.W) w.push(...row);
|
|
1143
|
+
for (const row of this.Wk.W) w.push(...row);
|
|
1144
|
+
for (const row of this.Wv.W) w.push(...row);
|
|
1145
|
+
return w;
|
|
1146
|
+
}
|
|
1147
|
+
setWeights(weights) {
|
|
1148
|
+
let idx = 0;
|
|
1149
|
+
for (let i = 0; i < this.Wq.W.length; i++)
|
|
1150
|
+
for (let j = 0; j < this.Wq.W[i].length; j++) this.Wq.W[i][j] = weights[idx++];
|
|
1151
|
+
for (let i = 0; i < this.Wk.W.length; i++)
|
|
1152
|
+
for (let j = 0; j < this.Wk.W[i].length; j++) this.Wk.W[i][j] = weights[idx++];
|
|
1153
|
+
for (let i = 0; i < this.Wv.W.length; i++)
|
|
1154
|
+
for (let j = 0; j < this.Wv.W[i].length; j++) this.Wv.W[i][j] = weights[idx++];
|
|
1155
|
+
}
|
|
737
1156
|
};
|
|
738
1157
|
|
|
739
1158
|
// src/MultiHeadAttention.ts
|
|
740
1159
|
var MultiHeadAttention = class {
|
|
741
1160
|
// seqLen × (nHeads * d_k)
|
|
742
|
-
constructor(d_model, nHeads) {
|
|
1161
|
+
constructor(d_model, nHeads, causal = false) {
|
|
743
1162
|
// d_model × (nHeads * d_k)
|
|
744
1163
|
// Cached for backward
|
|
745
1164
|
this._concat = null;
|
|
746
1165
|
this.nHeads = nHeads;
|
|
747
1166
|
this.d_model = d_model;
|
|
748
1167
|
this.d_k = Math.floor(d_model / nHeads);
|
|
1168
|
+
this.causal = causal;
|
|
749
1169
|
this.heads = Array.from(
|
|
750
1170
|
{ length: nHeads },
|
|
751
|
-
() => new AttentionHead(d_model, this.d_k, this.d_k)
|
|
1171
|
+
() => new AttentionHead(d_model, this.d_k, this.d_k, causal)
|
|
752
1172
|
);
|
|
753
1173
|
this.Wo = new WeightMatrix(d_model, nHeads * this.d_k);
|
|
754
1174
|
}
|
|
@@ -807,6 +1227,31 @@ var MultiHeadAttention = class {
|
|
|
807
1227
|
getAttentionWeights() {
|
|
808
1228
|
return this.heads.map((h) => h.getAttentionWeights());
|
|
809
1229
|
}
|
|
1230
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
1231
|
+
// Order: head0 (Wq, Wk, Wv), head1, ..., headN, then Wo.
|
|
1232
|
+
getWeights() {
|
|
1233
|
+
const w = [];
|
|
1234
|
+
for (const head of this.heads) {
|
|
1235
|
+
for (const row of head.Wq.W) w.push(...row);
|
|
1236
|
+
for (const row of head.Wk.W) w.push(...row);
|
|
1237
|
+
for (const row of head.Wv.W) w.push(...row);
|
|
1238
|
+
}
|
|
1239
|
+
for (const row of this.Wo.W) w.push(...row);
|
|
1240
|
+
return w;
|
|
1241
|
+
}
|
|
1242
|
+
setWeights(weights) {
|
|
1243
|
+
let idx = 0;
|
|
1244
|
+
for (const head of this.heads) {
|
|
1245
|
+
for (let i = 0; i < head.Wq.W.length; i++)
|
|
1246
|
+
for (let j = 0; j < head.Wq.W[i].length; j++) head.Wq.W[i][j] = weights[idx++];
|
|
1247
|
+
for (let i = 0; i < head.Wk.W.length; i++)
|
|
1248
|
+
for (let j = 0; j < head.Wk.W[i].length; j++) head.Wk.W[i][j] = weights[idx++];
|
|
1249
|
+
for (let i = 0; i < head.Wv.W.length; i++)
|
|
1250
|
+
for (let j = 0; j < head.Wv.W[i].length; j++) head.Wv.W[i][j] = weights[idx++];
|
|
1251
|
+
}
|
|
1252
|
+
for (let i = 0; i < this.Wo.W.length; i++)
|
|
1253
|
+
for (let j = 0; j < this.Wo.W[i].length; j++) this.Wo.W[i][j] = weights[idx++];
|
|
1254
|
+
}
|
|
810
1255
|
};
|
|
811
1256
|
|
|
812
1257
|
// src/LayerNorm.ts
|
|
@@ -858,11 +1303,21 @@ var LayerNorm = class {
|
|
|
858
1303
|
const mDxn = D.reduce((s, d, i) => s + d * x_norm[i], 0) / N;
|
|
859
1304
|
return D.map((d, i) => (d - mD - x_norm[i] * mDxn) / std);
|
|
860
1305
|
}
|
|
1306
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
1307
|
+
// Order: gamma, beta.
|
|
1308
|
+
getWeights() {
|
|
1309
|
+
return [...this.gamma, ...this.beta];
|
|
1310
|
+
}
|
|
1311
|
+
setWeights(weights) {
|
|
1312
|
+
const dim = this.gamma.length;
|
|
1313
|
+
for (let i = 0; i < dim; i++) this.gamma[i] = weights[i];
|
|
1314
|
+
for (let i = 0; i < dim; i++) this.beta[i] = weights[dim + i];
|
|
1315
|
+
}
|
|
861
1316
|
};
|
|
862
1317
|
|
|
863
1318
|
// src/TransformerBlock.ts
|
|
864
1319
|
var TransformerBlock = class {
|
|
865
|
-
constructor({ d_model, nHeads, d_ff }) {
|
|
1320
|
+
constructor({ d_model, nHeads, d_ff, causal = false }) {
|
|
866
1321
|
// Forward caches (needed for backprop)
|
|
867
1322
|
this._X = null;
|
|
868
1323
|
this._attnOut = null;
|
|
@@ -874,7 +1329,7 @@ var TransformerBlock = class {
|
|
|
874
1329
|
this._ff2Out = null;
|
|
875
1330
|
this.d_model = d_model;
|
|
876
1331
|
this.d_ff = d_ff;
|
|
877
|
-
this.attn = new MultiHeadAttention(d_model, nHeads);
|
|
1332
|
+
this.attn = new MultiHeadAttention(d_model, nHeads, causal);
|
|
878
1333
|
this.norm1 = new LayerNorm(d_model);
|
|
879
1334
|
this.norm2 = new LayerNorm(d_model);
|
|
880
1335
|
this.ff1 = new WeightMatrix(d_ff, d_model);
|
|
@@ -987,6 +1442,35 @@ var TransformerBlock = class {
|
|
|
987
1442
|
getAttentionWeights() {
|
|
988
1443
|
return this.attn.getAttentionWeights();
|
|
989
1444
|
}
|
|
1445
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
1446
|
+
// Order: attn (MHA), norm1 (gamma, beta), ff1, b1, ff2, b2, norm2 (gamma, beta).
|
|
1447
|
+
getWeights() {
|
|
1448
|
+
const w = [];
|
|
1449
|
+
w.push(...this.attn.getWeights());
|
|
1450
|
+
w.push(...this.norm1.gamma, ...this.norm1.beta);
|
|
1451
|
+
for (const row of this.ff1.W) w.push(...row);
|
|
1452
|
+
w.push(...this.b1);
|
|
1453
|
+
for (const row of this.ff2.W) w.push(...row);
|
|
1454
|
+
w.push(...this.b2);
|
|
1455
|
+
w.push(...this.norm2.gamma, ...this.norm2.beta);
|
|
1456
|
+
return w;
|
|
1457
|
+
}
|
|
1458
|
+
setWeights(weights) {
|
|
1459
|
+
let idx = 0;
|
|
1460
|
+
const attnLen = this.attn.getWeights().length;
|
|
1461
|
+
this.attn.setWeights(weights.slice(idx, idx + attnLen));
|
|
1462
|
+
idx += attnLen;
|
|
1463
|
+
for (let i = 0; i < this.norm1.gamma.length; i++) this.norm1.gamma[i] = weights[idx++];
|
|
1464
|
+
for (let i = 0; i < this.norm1.beta.length; i++) this.norm1.beta[i] = weights[idx++];
|
|
1465
|
+
for (let i = 0; i < this.ff1.W.length; i++)
|
|
1466
|
+
for (let j = 0; j < this.ff1.W[i].length; j++) this.ff1.W[i][j] = weights[idx++];
|
|
1467
|
+
for (let i = 0; i < this.b1.length; i++) this.b1[i] = weights[idx++];
|
|
1468
|
+
for (let i = 0; i < this.ff2.W.length; i++)
|
|
1469
|
+
for (let j = 0; j < this.ff2.W[i].length; j++) this.ff2.W[i][j] = weights[idx++];
|
|
1470
|
+
for (let i = 0; i < this.b2.length; i++) this.b2[i] = weights[idx++];
|
|
1471
|
+
for (let i = 0; i < this.norm2.gamma.length; i++) this.norm2.gamma[i] = weights[idx++];
|
|
1472
|
+
for (let i = 0; i < this.norm2.beta.length; i++) this.norm2.beta[i] = weights[idx++];
|
|
1473
|
+
}
|
|
990
1474
|
};
|
|
991
1475
|
|
|
992
1476
|
// src/NetworkTransformer.ts
|
|
@@ -1085,6 +1569,32 @@ var NetworkTransformer = class {
|
|
|
1085
1569
|
getAttentionWeights() {
|
|
1086
1570
|
return this.blocks.map((b) => b.getAttentionWeights());
|
|
1087
1571
|
}
|
|
1572
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
1573
|
+
// Order: tokenEmb, posEmb, block0, block1, ..., blockN, outputProj, outputBias.
|
|
1574
|
+
getWeights() {
|
|
1575
|
+
const w = [];
|
|
1576
|
+
for (const row of this.tokenEmb.W) w.push(...row);
|
|
1577
|
+
for (const row of this.posEmb.W) w.push(...row);
|
|
1578
|
+
for (const block of this.blocks) w.push(...block.getWeights());
|
|
1579
|
+
for (const row of this.outputProj.W) w.push(...row);
|
|
1580
|
+
w.push(...this.outputBias);
|
|
1581
|
+
return w;
|
|
1582
|
+
}
|
|
1583
|
+
setWeights(weights) {
|
|
1584
|
+
let idx = 0;
|
|
1585
|
+
for (let i = 0; i < this.tokenEmb.W.length; i++)
|
|
1586
|
+
for (let j = 0; j < this.tokenEmb.W[i].length; j++) this.tokenEmb.W[i][j] = weights[idx++];
|
|
1587
|
+
for (let i = 0; i < this.posEmb.W.length; i++)
|
|
1588
|
+
for (let j = 0; j < this.posEmb.W[i].length; j++) this.posEmb.W[i][j] = weights[idx++];
|
|
1589
|
+
for (const block of this.blocks) {
|
|
1590
|
+
const blockLen = block.getWeights().length;
|
|
1591
|
+
block.setWeights(weights.slice(idx, idx + blockLen));
|
|
1592
|
+
idx += blockLen;
|
|
1593
|
+
}
|
|
1594
|
+
for (let i = 0; i < this.outputProj.W.length; i++)
|
|
1595
|
+
for (let j = 0; j < this.outputProj.W[i].length; j++) this.outputProj.W[i][j] = weights[idx++];
|
|
1596
|
+
for (let i = 0; i < this.outputBias.length; i++) this.outputBias[i] = weights[idx++];
|
|
1597
|
+
}
|
|
1088
1598
|
// ── Internal ──────────────────────────────────────────────────────────────
|
|
1089
1599
|
// Shared embedding + block forward pass.
|
|
1090
1600
|
_forward(tokens) {
|
|
@@ -1104,21 +1614,25 @@ var NetworkTransformerRL = class {
|
|
|
1104
1614
|
constructor(seqLen, inputDim, options = {}) {
|
|
1105
1615
|
// Forward caches for backprop
|
|
1106
1616
|
this._projected = null;
|
|
1617
|
+
// For max pooling backward: argmax per dimension across all positions
|
|
1618
|
+
this._argmax = null;
|
|
1107
1619
|
const {
|
|
1108
1620
|
d_model = 32,
|
|
1109
1621
|
nHeads = 2,
|
|
1110
1622
|
d_ff = 64,
|
|
1111
1623
|
nBlocks = 2,
|
|
1112
|
-
nActions = 2
|
|
1624
|
+
nActions = 2,
|
|
1625
|
+
pooling = "weighted"
|
|
1113
1626
|
} = options;
|
|
1114
1627
|
this.seqLen = seqLen;
|
|
1115
1628
|
this.inputDim = inputDim;
|
|
1116
1629
|
this.d_model = d_model;
|
|
1117
1630
|
this.nActions = nActions;
|
|
1631
|
+
this._pooling = pooling;
|
|
1118
1632
|
this.inputProj = new WeightMatrix(d_model, inputDim);
|
|
1119
1633
|
this.blocks = Array.from(
|
|
1120
1634
|
{ length: nBlocks },
|
|
1121
|
-
() => new TransformerBlock({ d_model, nHeads, d_ff })
|
|
1635
|
+
() => new TransformerBlock({ d_model, nHeads, d_ff, causal: true })
|
|
1122
1636
|
);
|
|
1123
1637
|
this.outputProj = new WeightMatrix(nActions, d_model);
|
|
1124
1638
|
this.outputBias = new Array(nActions).fill(0);
|
|
@@ -1167,11 +1681,7 @@ var NetworkTransformerRL = class {
|
|
|
1167
1681
|
this.outputProj.update(dWout, lr);
|
|
1168
1682
|
for (let c = 0; c < this.nActions; c++)
|
|
1169
1683
|
this.outputBias[c] = this.outBiasOpts[c].step(this.outputBias[c], dBout[c], lr);
|
|
1170
|
-
let dH =
|
|
1171
|
-
{ length: this.seqLen },
|
|
1172
|
-
(_, i) => dPooled.map((v) => v / this.seqLen)
|
|
1173
|
-
// Gradiente dividido entre posiciones
|
|
1174
|
-
);
|
|
1684
|
+
let dH = this._distributePoolGradient(dPooled);
|
|
1175
1685
|
for (let b = this.blocks.length - 1; b >= 0; b--)
|
|
1176
1686
|
dH = this.blocks[b].backward(dH, lr);
|
|
1177
1687
|
for (let i = 0; i < this.seqLen; i++) {
|
|
@@ -1190,6 +1700,85 @@ var NetworkTransformerRL = class {
|
|
|
1190
1700
|
getAttentionWeights() {
|
|
1191
1701
|
return this.blocks.map((b) => b.getAttentionWeights());
|
|
1192
1702
|
}
|
|
1703
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
1704
|
+
// Order: inputProj, block0, block1, ..., blockN, outputProj, outputBias.
|
|
1705
|
+
getWeightsFlat() {
|
|
1706
|
+
const w = [];
|
|
1707
|
+
for (const row of this.inputProj.W) w.push(...row);
|
|
1708
|
+
for (const block of this.blocks) w.push(...block.getWeights());
|
|
1709
|
+
for (const row of this.outputProj.W) w.push(...row);
|
|
1710
|
+
w.push(...this.outputBias);
|
|
1711
|
+
return w;
|
|
1712
|
+
}
|
|
1713
|
+
setWeightsFlat(weights) {
|
|
1714
|
+
let idx = 0;
|
|
1715
|
+
for (let i = 0; i < this.inputProj.W.length; i++)
|
|
1716
|
+
for (let j = 0; j < this.inputProj.W[i].length; j++) this.inputProj.W[i][j] = weights[idx++];
|
|
1717
|
+
for (const block of this.blocks) {
|
|
1718
|
+
const blockLen = block.getWeights().length;
|
|
1719
|
+
block.setWeights(weights.slice(idx, idx + blockLen));
|
|
1720
|
+
idx += blockLen;
|
|
1721
|
+
}
|
|
1722
|
+
for (let i = 0; i < this.outputProj.W.length; i++)
|
|
1723
|
+
for (let j = 0; j < this.outputProj.W[i].length; j++) this.outputProj.W[i][j] = weights[idx++];
|
|
1724
|
+
for (let i = 0; i < this.outputBias.length; i++) this.outputBias[i] = weights[idx++];
|
|
1725
|
+
}
|
|
1726
|
+
getWeightsStructured() {
|
|
1727
|
+
return {
|
|
1728
|
+
inputProj: this.inputProj.W.map((r) => [...r]),
|
|
1729
|
+
blocks: this.blocks.map((b) => ({
|
|
1730
|
+
attn: {
|
|
1731
|
+
heads: b.attn.heads.map((h) => ({
|
|
1732
|
+
Wq: h.Wq.W.map((r) => [...r]),
|
|
1733
|
+
Wk: h.Wk.W.map((r) => [...r]),
|
|
1734
|
+
Wv: h.Wv.W.map((r) => [...r])
|
|
1735
|
+
})),
|
|
1736
|
+
Wo: b.attn.Wo.W.map((r) => [...r])
|
|
1737
|
+
},
|
|
1738
|
+
norm1: { gamma: [...b.norm1.gamma], beta: [...b.norm1.beta] },
|
|
1739
|
+
norm2: { gamma: [...b.norm2.gamma], beta: [...b.norm2.beta] },
|
|
1740
|
+
ff1: b.ff1.W.map((r) => [...r]),
|
|
1741
|
+
ff2: b.ff2.W.map((r) => [...r]),
|
|
1742
|
+
b1: [...b.b1],
|
|
1743
|
+
b2: [...b.b2]
|
|
1744
|
+
})),
|
|
1745
|
+
outputProj: this.outputProj.W.map((r) => [...r]),
|
|
1746
|
+
outputBias: [...this.outputBias]
|
|
1747
|
+
};
|
|
1748
|
+
}
|
|
1749
|
+
setWeightsStructured(data) {
|
|
1750
|
+
data.inputProj.forEach((row, i) => {
|
|
1751
|
+
this.inputProj.W[i] = [...row];
|
|
1752
|
+
});
|
|
1753
|
+
data.blocks.forEach((bd, b) => {
|
|
1754
|
+
const blk = this.blocks[b];
|
|
1755
|
+
bd.attn.heads.forEach((hd, h) => {
|
|
1756
|
+
blk.attn.heads[h].Wq.W = hd.Wq.map((r) => [...r]);
|
|
1757
|
+
blk.attn.heads[h].Wk.W = hd.Wk.map((r) => [...r]);
|
|
1758
|
+
blk.attn.heads[h].Wv.W = hd.Wv.map((r) => [...r]);
|
|
1759
|
+
});
|
|
1760
|
+
blk.attn.Wo.W = bd.attn.Wo.map((r) => [...r]);
|
|
1761
|
+
blk.norm1.gamma = [...bd.norm1.gamma];
|
|
1762
|
+
blk.norm1.beta = [...bd.norm1.beta];
|
|
1763
|
+
blk.norm2.gamma = [...bd.norm2.gamma];
|
|
1764
|
+
blk.norm2.beta = [...bd.norm2.beta];
|
|
1765
|
+
blk.ff1.W = bd.ff1.map((r) => [...r]);
|
|
1766
|
+
blk.ff2.W = bd.ff2.map((r) => [...r]);
|
|
1767
|
+
blk.b1 = [...bd.b1];
|
|
1768
|
+
blk.b2 = [...bd.b2];
|
|
1769
|
+
});
|
|
1770
|
+
this.outputProj.W = data.outputProj.map((r) => [...r]);
|
|
1771
|
+
this.outputBias = [...data.outputBias];
|
|
1772
|
+
}
|
|
1773
|
+
// ── Serializable interface (flat array) ────────────────────────────────────
|
|
1774
|
+
// These satisfy the Serializable interface from ModelSaver, which requires
|
|
1775
|
+
// getWeights(): number[] and setWeights(weights: number[]): void.
|
|
1776
|
+
getWeights() {
|
|
1777
|
+
return this.getWeightsFlat();
|
|
1778
|
+
}
|
|
1779
|
+
setWeights(weights) {
|
|
1780
|
+
this.setWeightsFlat(weights);
|
|
1781
|
+
}
|
|
1193
1782
|
// ── Internal ────────────────────────────────────────────────────────────────
|
|
1194
1783
|
_forward(sequence) {
|
|
1195
1784
|
let h = sequence.map(
|
|
@@ -1203,6 +1792,44 @@ var NetworkTransformerRL = class {
|
|
|
1203
1792
|
return h;
|
|
1204
1793
|
}
|
|
1205
1794
|
_pool(h) {
|
|
1795
|
+
switch (this._pooling) {
|
|
1796
|
+
case "avg":
|
|
1797
|
+
return this._poolAvg(h);
|
|
1798
|
+
case "max":
|
|
1799
|
+
return this._poolMax(h);
|
|
1800
|
+
case "last":
|
|
1801
|
+
return this._poolLast(h);
|
|
1802
|
+
case "weighted":
|
|
1803
|
+
default:
|
|
1804
|
+
return this._poolWeighted(h);
|
|
1805
|
+
}
|
|
1806
|
+
}
|
|
1807
|
+
_poolAvg(h) {
|
|
1808
|
+
const n = h.length;
|
|
1809
|
+
return Array.from({ length: this.d_model }, (_, m) => {
|
|
1810
|
+
let sum = 0;
|
|
1811
|
+
for (let i = 0; i < n; i++)
|
|
1812
|
+
sum += h[i][m];
|
|
1813
|
+
return sum / n;
|
|
1814
|
+
});
|
|
1815
|
+
}
|
|
1816
|
+
_poolMax(h) {
|
|
1817
|
+
this._argmax = new Array(this.d_model).fill(0);
|
|
1818
|
+
return Array.from({ length: this.d_model }, (_, m) => {
|
|
1819
|
+
let maxVal = -Infinity;
|
|
1820
|
+
for (let i = 0; i < h.length; i++) {
|
|
1821
|
+
if (h[i][m] > maxVal) {
|
|
1822
|
+
maxVal = h[i][m];
|
|
1823
|
+
this._argmax[m] = i;
|
|
1824
|
+
}
|
|
1825
|
+
}
|
|
1826
|
+
return maxVal;
|
|
1827
|
+
});
|
|
1828
|
+
}
|
|
1829
|
+
_poolLast(h) {
|
|
1830
|
+
return [...h[h.length - 1]];
|
|
1831
|
+
}
|
|
1832
|
+
_poolWeighted(h) {
|
|
1206
1833
|
const weights = Array.from(
|
|
1207
1834
|
{ length: this.seqLen },
|
|
1208
1835
|
(_, i) => i === this.seqLen - 1 ? 2 : 1
|
|
@@ -1215,6 +1842,55 @@ var NetworkTransformerRL = class {
|
|
|
1215
1842
|
return sum / totalWeight;
|
|
1216
1843
|
});
|
|
1217
1844
|
}
|
|
1845
|
+
/** Returns the current pooling type for inspection. */
|
|
1846
|
+
getPoolingType() {
|
|
1847
|
+
return this._pooling;
|
|
1848
|
+
}
|
|
1849
|
+
// ── Helper: distribute pooled gradient back to each position ────────────────
|
|
1850
|
+
// Must match the same distribution as _pool() used during forward.
|
|
1851
|
+
_distributePoolGradient(dPooled) {
|
|
1852
|
+
switch (this._pooling) {
|
|
1853
|
+
case "avg": {
|
|
1854
|
+
const n = this.seqLen;
|
|
1855
|
+
return Array.from(
|
|
1856
|
+
{ length: n },
|
|
1857
|
+
() => dPooled.map((v) => v / n)
|
|
1858
|
+
);
|
|
1859
|
+
}
|
|
1860
|
+
case "max": {
|
|
1861
|
+
if (!this._argmax) {
|
|
1862
|
+
const n = this.seqLen;
|
|
1863
|
+
return Array.from(
|
|
1864
|
+
{ length: n },
|
|
1865
|
+
() => dPooled.map((v) => v / n)
|
|
1866
|
+
);
|
|
1867
|
+
}
|
|
1868
|
+
const argmax = this._argmax;
|
|
1869
|
+
return Array.from(
|
|
1870
|
+
{ length: this.seqLen },
|
|
1871
|
+
(_, i) => dPooled.map((v, m) => i === argmax[m] ? v : 0)
|
|
1872
|
+
);
|
|
1873
|
+
}
|
|
1874
|
+
case "last": {
|
|
1875
|
+
return Array.from(
|
|
1876
|
+
{ length: this.seqLen },
|
|
1877
|
+
(_, i) => i === this.seqLen - 1 ? [...dPooled] : new Array(this.d_model).fill(0)
|
|
1878
|
+
);
|
|
1879
|
+
}
|
|
1880
|
+
case "weighted":
|
|
1881
|
+
default: {
|
|
1882
|
+
const weights = Array.from(
|
|
1883
|
+
{ length: this.seqLen },
|
|
1884
|
+
(_, i) => i === this.seqLen - 1 ? 2 : 1
|
|
1885
|
+
);
|
|
1886
|
+
const totalWeight = weights.reduce((a, b) => a + b, 0);
|
|
1887
|
+
return Array.from(
|
|
1888
|
+
{ length: this.seqLen },
|
|
1889
|
+
(_, i) => dPooled.map((v) => v * weights[i] / totalWeight)
|
|
1890
|
+
);
|
|
1891
|
+
}
|
|
1892
|
+
}
|
|
1893
|
+
}
|
|
1218
1894
|
};
|
|
1219
1895
|
|
|
1220
1896
|
// src/losses.ts
|
|
@@ -1239,14 +1915,802 @@ function crossEntropyDeltaRaw(predicted, actual) {
|
|
|
1239
1915
|
const p = Math.max(eps, Math.min(1 - eps, predicted));
|
|
1240
1916
|
return actual / p - (1 - actual) / (1 - p);
|
|
1241
1917
|
}
|
|
1918
|
+
|
|
1919
|
+
// src/GRU.ts
|
|
1920
|
+
/**
 * Logistic sigmoid: 1 / (1 + e^(-x)).
 */
function sigmoid4(x) {
  const decay = Math.exp(-x);
  return 1 / (1 + decay);
}
|
|
1923
|
+
/**
 * Hyperbolic tangent.
 *
 * Uses the built-in Math.tanh. The former (e^(2x) - 1) / (e^(2x) + 1) form
 * overflowed for large positive x (Math.exp returns Infinity, and
 * Infinity / Infinity is NaN), which poisoned the hidden state with NaN.
 *
 * @param {number} x
 * @returns {number} tanh(x), in [-1, 1]
 */
function tanhFn(x) {
  return Math.tanh(x);
}
|
|
1927
|
+
/**
 * One affine gate: an hSize x (inputSize + hSize) weight matrix `W` plus a
 * bias vector `b`.
 */
var Gate2 = class {
  /**
   * Weights are drawn uniformly from [-limit, limit) with
   * limit = sqrt(2 / (inputSize + hSize)); every bias starts at `initBias`.
   */
  constructor(inputSize, hSize, initBias = 0) {
    const fanIn = inputSize + hSize;
    const bound = Math.sqrt(2 / fanIn);
    const randomRow = () =>
      Array.from({ length: fanIn }, () => (Math.random() * 2 - 1) * bound);
    this.W = Array.from({ length: hSize }, () => randomRow());
    this.b = new Array(hSize).fill(initBias);
  }
  /**
   * Computes W · combined + b and returns one value per row of W.
   */
  linear(combined) {
    return this.W.map((row, i) => {
      let acc = this.b[i];
      for (let j = 0; j < row.length; j++) acc += row[j] * combined[j];
      return acc;
    });
  }
};
|
|
1943
|
+
/**
 * Gated recurrent unit layer with three gates (reset, update, new/candidate).
 *
 * predict() advances the hidden state by one step and records what backprop()
 * needs; backprop() runs backpropagation through the recorded steps and hands
 * the accumulated gradients to one optimizer instance per parameter.
 */
var GRULayer = class {
  /**
   * @param inputSize length of each input vector (must be > 0)
   * @param hiddenSize length of the hidden state (must be > 0)
   * @param optimizerFactory called once per scalar parameter to create its
   *   optimizer; each optimizer must expose step(value, gradient, lr)
   * @throws {Error} if either size is not positive
   */
  constructor(inputSize, hiddenSize, optimizerFactory = () => new SGD()) {
    this._traj = [];
    if (inputSize <= 0 || hiddenSize <= 0) {
      throw new Error(`GRULayer: inputSize and hiddenSize must be positive, got ${inputSize} and ${hiddenSize}`);
    }
    this.inputSize = inputSize;
    this.hSize = hiddenSize;
    this.h = new Array(hiddenSize).fill(0);
    this.resetGate = new Gate2(inputSize, hiddenSize);
    this.updateGate = new Gate2(inputSize, hiddenSize);
    this.newGate = new Gate2(inputSize, hiddenSize);
    const combSize = inputSize + hiddenSize;
    const matrixOptimizers = () =>
      Array.from(
        { length: hiddenSize },
        () => Array.from({ length: combSize }, () => optimizerFactory())
      );
    const vectorOptimizers = () =>
      Array.from({ length: hiddenSize }, () => optimizerFactory());
    this._optimizers = {
      resetW: matrixOptimizers(),
      resetB: vectorOptimizers(),
      updateW: matrixOptimizers(),
      updateB: vectorOptimizers(),
      newW: matrixOptimizers(),
      newB: vectorOptimizers()
    };
  }

  /** Zeroes the hidden state and discards the recorded trajectory. */
  reset() {
    this.h = new Array(this.hSize).fill(0);
    this._traj = [];
  }

  /**
   * Runs one time step:
   *   r = sigmoid(Wr·[x, h_prev] + br)
   *   z = sigmoid(Wz·[x, h_prev] + bz)
   *   n = tanh(Wn·[x, r ⊙ h_prev] + bn)
   *   h = (1 - z) ⊙ n + z ⊙ h_prev
   * The step is appended to the internal trajectory; it is only cleared by
   * backprop() or reset().
   *
   * @param inputs array of length inputSize
   * @returns the new hidden state (also stored on `this.h`)
   * @throws {Error} if `inputs` is not an array of length inputSize
   */
  predict(inputs) {
    if (!Array.isArray(inputs) || inputs.length !== this.inputSize) {
      throw new Error(`GRULayer.predict: expected array of length ${this.inputSize}, got ${inputs?.length}`);
    }
    const combined = [...inputs, ...this.h];
    const h_prev = [...this.h];
    const r_pre = this.resetGate.linear(combined);
    const z_pre = this.updateGate.linear(combined);
    const r_a = r_pre.map(sigmoid4);
    const z_a = z_pre.map(sigmoid4);
    const combined_r = [...inputs, ...r_a.map((r, i) => r * h_prev[i])];
    const n_pre = this.newGate.linear(combined_r);
    const n_a = n_pre.map(tanhFn);
    const h = n_a.map((n, i) => (1 - z_a[i]) * n + z_a[i] * h_prev[i]);
    this._traj.push({ combined, h_prev, r: r_pre, r_a, z: z_pre, z_a, combined_r, n_pre, n_a, h });
    this.h = h;
    return h;
  }

  /**
   * Backpropagation through time over the recorded trajectory.
   *
   * @param dh_seq one gradient vector (length hSize) per recorded step
   * @param lr learning rate; each optimizer receives lr / T, where T is the
   *   number of recorded steps
   *
   * Returns without doing anything (and keeps the trajectory) when nothing
   * was recorded or when dh_seq.length differs from the number of steps.
   * On success the trajectory is cleared.
   */
  backprop(dh_seq, lr) {
    const T = this._traj.length;
    if (T === 0 || dh_seq.length !== T) return;
    const hSize = this.hSize;
    const combSize = this.inputSize + hSize;
    const zeroMatrix = () => Array.from({ length: hSize }, () => new Array(combSize).fill(0));
    const dWr = zeroMatrix();
    const dWz = zeroMatrix();
    const dWn = zeroMatrix();
    const dbr = new Array(hSize).fill(0);
    const dbz = new Array(hSize).fill(0);
    const dbn = new Array(hSize).fill(0);
    // Gradient flowing into h from the following time step.
    let dh_next = new Array(hSize).fill(0);
    for (let t = T - 1; t >= 0; t--) {
      const s = this._traj[t];
      const dh = dh_seq[t].map((d, i) => d + dh_next[i]);
      // h = (1 - z) n + z h_prev
      const dz_a = dh.map((d, i) => (s.h_prev[i] - s.n_a[i]) * d);
      const dn_a = dh.map((d, i) => (1 - s.z_a[i]) * d);
      const dn_pre = dn_a.map((d, i) => d * (1 - s.n_a[i] ** 2));
      const dz_pre = dz_a.map((d, i) => d * s.z_a[i] * (1 - s.z_a[i]));
      // Gradient with respect to (r ⊙ h_prev), through the hidden columns of Wn.
      const dr_hprev = Array.from(
        { length: hSize },
        (_, i) => this.newGate.W.reduce((sum, row, k) => sum + dn_pre[k] * row[this.inputSize + i], 0)
      );
      const dr_a = dr_hprev.map((d, i) => d * s.h_prev[i]);
      const dr_pre = dr_a.map((d, i) => d * s.r_a[i] * (1 - s.r_a[i]));
      for (let k = 0; k < hSize; k++) {
        for (let j = 0; j < combSize; j++) {
          dWr[k][j] += dr_pre[k] * s.combined[j];
          dWz[k][j] += dz_pre[k] * s.combined[j];
          dWn[k][j] += dn_pre[k] * s.combined_r[j];
        }
        dbr[k] += dr_pre[k];
        dbz[k] += dz_pre[k];
        dbn[k] += dn_pre[k];
      }
      // Gradient with respect to h_prev: through the reset and update gates'
      // hidden columns, through r ⊙ h_prev, and through the z ⊙ h_prev term.
      dh_next = new Array(hSize).fill(0);
      for (let k = 0; k < hSize; k++) {
        for (let j = this.inputSize; j < combSize; j++) {
          dh_next[j - this.inputSize] += dr_pre[k] * this.resetGate.W[k][j] + dz_pre[k] * this.updateGate.W[k][j];
        }
        dh_next[k] += dr_hprev[k] * s.r_a[k];
        dh_next[k] += dh[k] * s.z_a[k];
      }
    }
    const scale = lr / T;
    for (let k = 0; k < hSize; k++) {
      for (let j = 0; j < combSize; j++) {
        this.resetGate.W[k][j] = this._optimizers.resetW[k][j].step(this.resetGate.W[k][j], dWr[k][j], scale);
        this.updateGate.W[k][j] = this._optimizers.updateW[k][j].step(this.updateGate.W[k][j], dWz[k][j], scale);
        this.newGate.W[k][j] = this._optimizers.newW[k][j].step(this.newGate.W[k][j], dWn[k][j], scale);
      }
      this.resetGate.b[k] = this._optimizers.resetB[k].step(this.resetGate.b[k], dbr[k], scale);
      this.updateGate.b[k] = this._optimizers.updateB[k].step(this.updateGate.b[k], dbz[k], scale);
      this.newGate.b[k] = this._optimizers.newB[k].step(this.newGate.b[k], dbn[k], scale);
    }
    this._traj = [];
  }

  // ── Flat weight serialization ─────────────────────────────────────────────
  // Order: resetGate (W, b), updateGate (W, b), newGate (W, b).

  /** Returns all gate parameters as one newly allocated flat array. */
  getWeightsFlat() {
    const w = [];
    for (const gate of [this.resetGate, this.updateGate, this.newGate]) {
      for (const row of gate.W) {
        for (const value of row) w.push(value);
      }
      for (const value of gate.b) w.push(value);
    }
    return w;
  }

  /**
   * Loads all gate parameters in place from a flat array laid out as
   * produced by getWeightsFlat().
   *
   * @throws {RangeError} if `weights` is shorter than the parameter count;
   *   nothing is modified in that case. Previously a short array silently
   *   stored `undefined` in the trailing parameters.
   */
  setWeightsFlat(weights) {
    const gates = [this.resetGate, this.updateGate, this.newGate];
    let expected = 0;
    for (const gate of gates) {
      for (const row of gate.W) expected += row.length;
      expected += gate.b.length;
    }
    if (weights.length < expected) {
      throw new RangeError(
        `GRULayer.setWeightsFlat: expected at least ${expected} values, got ${weights.length}`
      );
    }
    let idx = 0;
    for (const gate of gates) {
      for (const row of gate.W) {
        for (let j = 0; j < row.length; j++) row[j] = weights[idx++];
      }
      for (let i = 0; i < gate.b.length; i++) gate.b[i] = weights[idx++];
    }
  }

  /**
   * Returns a structured snapshot { resetGate, updateGate, newGate }, each
   * holding { W, b }.
   *
   * The arrays are copies. Returning the live arrays (as before) let callers
   * mutate the layer by accident and, combined with setWeights(), made two
   * layers share the same weight storage.
   */
  getWeights() {
    const snapshot = (gate) => ({
      W: gate.W.map((row) => [...row]),
      b: [...gate.b]
    });
    return {
      resetGate: snapshot(this.resetGate),
      updateGate: snapshot(this.updateGate),
      newGate: snapshot(this.newGate)
    };
  }

  /**
   * Restores the gates from an object shaped like the result of getWeights().
   * Values are copied, so later changes to `data` do not affect the layer.
   */
  setWeights(data) {
    const restore = (gate, saved) => {
      gate.W = saved.W.map((row) => [...row]);
      gate.b = [...saved.b];
    };
    restore(this.resetGate, data.resetGate);
    restore(this.updateGate, data.updateGate);
    restore(this.newGate, data.newGate);
  }
};
|
|
2093
|
+
|
|
2094
|
+
// src/BatchNorm.ts
|
|
2095
|
+
/**
 * Per-feature normalisation of a single vector using exponentially averaged
 * running statistics, followed by a learned scale (gamma) and shift (beta).
 *
 * Every forward() call updates the running mean/variance and normalises with
 * the updated values; there is no separate inference mode.
 */
var BatchNorm = class {
  /**
   * @param dim number of features
   * @param momentum weight kept from the previous running statistic; the new
   *   sample receives (1 - momentum).
   *   NOTE(review): with the default 0.1 the newest sample dominates the
   *   running statistics — confirm this is the intended convention.
   */
  constructor(dim, momentum = 0.1) {
    this._xNorm = null;
    this._std = null;
    this.dim = dim;
    this.momentum = momentum;
    this.gamma = new Array(dim).fill(1);
    this.beta = new Array(dim).fill(0);
    this.runningMean = new Array(dim).fill(0);
    this.runningVar = new Array(dim).fill(1);
  }

  // ── Forward ───────────────────────────────────────────────────────────────
  /**
   * Updates the running statistics with `x`, then returns
   * gamma * (x - runningMean) / sqrt(runningVar + 1e-5) + beta.
   * The normalised input and the std are cached for backward()/trainParams().
   *
   * @throws {Error} if x.length differs from dim
   */
  forward(x) {
    if (x.length !== this.dim) {
      throw new Error(`BatchNorm.forward: expected array of length ${this.dim}, got ${x.length}`);
    }
    const eps = 1e-5;
    for (let i = 0; i < this.dim; i++) {
      this.runningMean[i] = this.momentum * this.runningMean[i] + (1 - this.momentum) * x[i];
      // The variance update uses the already-updated running mean.
      const diff = x[i] - this.runningMean[i];
      this.runningVar[i] = this.momentum * this.runningVar[i] + (1 - this.momentum) * diff * diff;
    }
    this._std = this.runningVar.map((v) => Math.sqrt(v + eps));
    this._xNorm = x.map((v, i) => (v - this.runningMean[i]) / this._std[i]);
    return this._xNorm.map((xn, i) => this.gamma[i] * xn + this.beta[i]);
  }

  // ── Backward ──────────────────────────────────────────────────────────────
  /**
   * Returns the gradient with respect to the input, treating the running
   * statistics as constants: dOut[i] * gamma[i] / std[i].
   * (An empty, no-op loop that used to sit here has been removed.)
   *
   * @throws {Error} if forward() has not been called yet
   */
  backward(dOut) {
    if (!this._xNorm || !this._std) {
      throw new Error("BatchNorm.backward: call forward() first");
    }
    return dOut.map((d, i) => d * this.gamma[i] / this._std[i]);
  }

  // ── Train gamma and beta (call after backward) ────────────────────────────
  /**
   * Adds lr * dOut[i] * xNorm[i] to gamma and lr * dOut[i] to beta.
   * Does nothing if forward() has not been called yet.
   */
  trainParams(dOut, lr) {
    if (!this._xNorm) return;
    for (let i = 0; i < this.dim; i++) {
      this.gamma[i] += lr * dOut[i] * this._xNorm[i];
      this.beta[i] += lr * dOut[i];
    }
  }

  // ── Flat weight serialization ─────────────────────────────────────────────
  // Order: gamma, beta. The running statistics are not included.
  getWeights() {
    return [...this.gamma, ...this.beta];
  }

  /**
   * Loads gamma (first dim values) and beta (next dim values) in place.
   *
   * @throws {RangeError} if fewer than 2 * dim values are supplied; nothing
   *   is modified in that case. Previously a short array silently stored
   *   `undefined` in the parameters.
   */
  setWeights(weights) {
    if (weights.length < 2 * this.dim) {
      throw new RangeError(
        `BatchNorm.setWeights: expected at least ${2 * this.dim} values, got ${weights.length}`
      );
    }
    for (let i = 0; i < this.dim; i++) this.gamma[i] = weights[i];
    for (let i = 0; i < this.dim; i++) this.beta[i] = weights[this.dim + i];
  }
};
|
|
2148
|
+
|
|
2149
|
+
// src/Conv1D.ts
|
|
2150
|
+
var Conv1D = class {
|
|
2151
|
+
constructor(inputLength, kernelSize, filters, stride = 1, padding = "valid", optimizerFactory = () => new SGD(), inputChannels = 1) {
|
|
2152
|
+
// [filters]
|
|
2153
|
+
this._input = null;
|
|
2154
|
+
this._paddedInput = null;
|
|
2155
|
+
if (inputLength <= 0 || kernelSize <= 0 || filters <= 0) {
|
|
2156
|
+
throw new Error("Conv1D: inputLength, kernelSize, and filters must be positive");
|
|
2157
|
+
}
|
|
2158
|
+
if (kernelSize > inputLength && padding === "valid") {
|
|
2159
|
+
throw new Error("Conv1D: kernelSize cannot exceed inputLength with valid padding");
|
|
2160
|
+
}
|
|
2161
|
+
if (inputChannels < 1) {
|
|
2162
|
+
throw new Error("Conv1D: inputChannels must be >= 1");
|
|
2163
|
+
}
|
|
2164
|
+
this.inputLength = inputLength;
|
|
2165
|
+
this.kernelSize = kernelSize;
|
|
2166
|
+
this.filters = filters;
|
|
2167
|
+
this.stride = stride;
|
|
2168
|
+
this.padding = padding;
|
|
2169
|
+
this.inputChannels = inputChannels;
|
|
2170
|
+
const limit = Math.sqrt(2 / (kernelSize * inputChannels));
|
|
2171
|
+
this.kernels = Array.from(
|
|
2172
|
+
{ length: filters },
|
|
2173
|
+
() => Array.from(
|
|
2174
|
+
{ length: kernelSize },
|
|
2175
|
+
() => Array.from({ length: inputChannels }, () => (Math.random() * 2 - 1) * limit)
|
|
2176
|
+
)
|
|
2177
|
+
);
|
|
2178
|
+
this.biases = new Array(filters).fill(0);
|
|
2179
|
+
this._kOpts = Array.from(
|
|
2180
|
+
{ length: filters },
|
|
2181
|
+
() => Array.from(
|
|
2182
|
+
{ length: kernelSize },
|
|
2183
|
+
() => Array.from({ length: inputChannels }, () => optimizerFactory())
|
|
2184
|
+
)
|
|
2185
|
+
);
|
|
2186
|
+
this._bOpts = Array.from({ length: filters }, () => optimizerFactory());
|
|
2187
|
+
}
|
|
2188
|
+
// ── Forward ───────────────────────────────────────────────────────────────
|
|
2189
|
+
// Accepts either number[] (when inputChannels=1) or number[][] (multi-channel).
|
|
2190
|
+
forward(input) {
|
|
2191
|
+
const input2D = this._normalizeInput(input);
|
|
2192
|
+
this._input = input2D.map((row) => [...row]);
|
|
2193
|
+
let padded;
|
|
2194
|
+
if (this.padding === "same") {
|
|
2195
|
+
const padSize = Math.floor((this.kernelSize - 1) / 2);
|
|
2196
|
+
const padRow = new Array(this.inputChannels).fill(0);
|
|
2197
|
+
padded = new Array(padSize).fill(null).map(() => [...padRow]).concat(input2D).concat(new Array(padSize).fill(null).map(() => [...padRow]));
|
|
2198
|
+
} else {
|
|
2199
|
+
padded = input2D;
|
|
2200
|
+
}
|
|
2201
|
+
this._paddedInput = padded;
|
|
2202
|
+
const outputLength = Math.floor((padded.length - this.kernelSize) / this.stride) + 1;
|
|
2203
|
+
const output = Array.from(
|
|
2204
|
+
{ length: this.filters },
|
|
2205
|
+
() => new Array(outputLength).fill(0)
|
|
2206
|
+
);
|
|
2207
|
+
for (let f = 0; f < this.filters; f++) {
|
|
2208
|
+
for (let pos = 0; pos < outputLength; pos++) {
|
|
2209
|
+
const start = pos * this.stride;
|
|
2210
|
+
let sum = this.biases[f];
|
|
2211
|
+
for (let k = 0; k < this.kernelSize; k++) {
|
|
2212
|
+
for (let c = 0; c < this.inputChannels; c++) {
|
|
2213
|
+
sum += this.kernels[f][k][c] * padded[start + k][c];
|
|
2214
|
+
}
|
|
2215
|
+
}
|
|
2216
|
+
output[f][pos] = sum;
|
|
2217
|
+
}
|
|
2218
|
+
}
|
|
2219
|
+
return output;
|
|
2220
|
+
}
|
|
2221
|
+
// ── Backward ──────────────────────────────────────────────────────────────
|
|
2222
|
+
backward(dOut, lr = 1e-3) {
|
|
2223
|
+
if (!this._paddedInput || !this._input) {
|
|
2224
|
+
throw new Error("Conv1D.backward: call forward() first");
|
|
2225
|
+
}
|
|
2226
|
+
const padded = this._paddedInput;
|
|
2227
|
+
const outputLength = dOut[0].length;
|
|
2228
|
+
const dKernels = Array.from(
|
|
2229
|
+
{ length: this.filters },
|
|
2230
|
+
() => Array.from(
|
|
2231
|
+
{ length: this.kernelSize },
|
|
2232
|
+
() => new Array(this.inputChannels).fill(0)
|
|
2233
|
+
)
|
|
2234
|
+
);
|
|
2235
|
+
const dBiases = new Array(this.filters).fill(0);
|
|
2236
|
+
const dPadded = padded.map((row) => new Array(this.inputChannels).fill(0));
|
|
2237
|
+
for (let f = 0; f < this.filters; f++) {
|
|
2238
|
+
for (let pos = 0; pos < outputLength; pos++) {
|
|
2239
|
+
const start = pos * this.stride;
|
|
2240
|
+
dBiases[f] += dOut[f][pos];
|
|
2241
|
+
for (let k = 0; k < this.kernelSize; k++) {
|
|
2242
|
+
for (let c = 0; c < this.inputChannels; c++) {
|
|
2243
|
+
dKernels[f][k][c] += dOut[f][pos] * padded[start + k][c];
|
|
2244
|
+
dPadded[start + k][c] += dOut[f][pos] * this.kernels[f][k][c];
|
|
2245
|
+
}
|
|
2246
|
+
}
|
|
2247
|
+
}
|
|
2248
|
+
}
|
|
2249
|
+
for (let f = 0; f < this.filters; f++) {
|
|
2250
|
+
for (let k = 0; k < this.kernelSize; k++) {
|
|
2251
|
+
for (let c = 0; c < this.inputChannels; c++) {
|
|
2252
|
+
this.kernels[f][k][c] = this._kOpts[f][k][c].step(this.kernels[f][k][c], dKernels[f][k][c], lr);
|
|
2253
|
+
}
|
|
2254
|
+
}
|
|
2255
|
+
this.biases[f] = this._bOpts[f].step(this.biases[f], dBiases[f], lr);
|
|
2256
|
+
}
|
|
2257
|
+
if (this.padding === "same") {
|
|
2258
|
+
const padSize = Math.floor((this.kernelSize - 1) / 2);
|
|
2259
|
+
return dPadded.slice(padSize, padSize + this.inputLength);
|
|
2260
|
+
}
|
|
2261
|
+
return dPadded.slice(0, this.inputLength);
|
|
2262
|
+
}
|
|
2263
|
+
// ── Output length ─────────────────────────────────────────────────────────
|
|
2264
|
+
getOutputLength() {
|
|
2265
|
+
if (this.padding === "same") {
|
|
2266
|
+
return Math.ceil(this.inputLength / this.stride);
|
|
2267
|
+
}
|
|
2268
|
+
return Math.floor((this.inputLength - this.kernelSize) / this.stride) + 1;
|
|
2269
|
+
}
|
|
2270
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
2271
|
+
// Order: kernels (flattened), biases.
|
|
2272
|
+
getWeights() {
|
|
2273
|
+
const w = [];
|
|
2274
|
+
for (const kernel of this.kernels)
|
|
2275
|
+
for (const k of kernel)
|
|
2276
|
+
for (const c of k)
|
|
2277
|
+
w.push(c);
|
|
2278
|
+
w.push(...this.biases);
|
|
2279
|
+
return w;
|
|
2280
|
+
}
|
|
2281
|
+
setWeights(weights) {
|
|
2282
|
+
let idx = 0;
|
|
2283
|
+
for (let f = 0; f < this.filters; f++)
|
|
2284
|
+
for (let k = 0; k < this.kernelSize; k++)
|
|
2285
|
+
for (let c = 0; c < this.inputChannels; c++)
|
|
2286
|
+
this.kernels[f][k][c] = weights[idx++];
|
|
2287
|
+
for (let f = 0; f < this.filters; f++)
|
|
2288
|
+
this.biases[f] = weights[idx++];
|
|
2289
|
+
}
|
|
2290
|
+
// ── Normalize input to 2D format ─────────────────────────────────────────
|
|
2291
|
+
_normalizeInput(input) {
|
|
2292
|
+
if (input.length === 0) {
|
|
2293
|
+
throw new Error("Conv1D.forward: input cannot be empty");
|
|
2294
|
+
}
|
|
2295
|
+
if (typeof input[0] === "number") {
|
|
2296
|
+
if (this.inputChannels !== 1) {
|
|
2297
|
+
throw new Error(`Conv1D.forward: expected 2D input with ${this.inputChannels} channels, got 1D`);
|
|
2298
|
+
}
|
|
2299
|
+
const input1D = input;
|
|
2300
|
+
if (input1D.length !== this.inputLength) {
|
|
2301
|
+
throw new Error(`Conv1D.forward: expected input of length ${this.inputLength}, got ${input1D.length}`);
|
|
2302
|
+
}
|
|
2303
|
+
return input1D.map((v) => [v]);
|
|
2304
|
+
}
|
|
2305
|
+
const input2D = input;
|
|
2306
|
+
if (input2D.length !== this.inputLength) {
|
|
2307
|
+
throw new Error(`Conv1D.forward: expected input of length ${this.inputLength}, got ${input2D.length}`);
|
|
2308
|
+
}
|
|
2309
|
+
for (let i = 0; i < input2D.length; i++) {
|
|
2310
|
+
if (input2D[i].length !== this.inputChannels) {
|
|
2311
|
+
throw new Error(`Conv1D.forward: expected ${this.inputChannels} channels at position ${i}, got ${input2D[i].length}`);
|
|
2312
|
+
}
|
|
2313
|
+
}
|
|
2314
|
+
return input2D;
|
|
2315
|
+
}
|
|
2316
|
+
};
|
|
2317
|
+
|
|
2318
|
+
// src/Trainer.ts
|
|
2319
|
+
/**
 * Epoch-based training loop around a network's per-sample `train()` method,
 * with optional weight decay, early stopping on a validation set and
 * per-epoch classification metrics.
 */
var Trainer = class {
  /**
   * @param {object} network - Model to train. Must expose
   *   `train(input, target, lr)` returning the sample loss. Weight decay,
   *   early stopping and metrics additionally require `getWeights`,
   *   `setWeights` and `predict`.
   * @param {object} [options]
   * @param {number} [options.epochs=1000] - Maximum number of epochs.
   * @param {number} [options.lr=0.1] - Initial learning rate.
   * @param {number} [options.lrDecay=1] - Factor applied to lr after each epoch.
   * @param {boolean} [options.verbose=false] - Log progress every 100 epochs.
   * @param {number} [options.weightDecay=0] - Multiplicative decay strength.
   * @param {{patience?: number, minDelta?: number}} [options.earlyStopping]
   *   Stop once the validation loss has failed to improve by more than
   *   `minDelta` (default 0) for `patience` consecutive epochs.
   * @param {boolean} [options.computeMetrics=false] - Record metrics per epoch.
   * @param {number} [options.clipValue=0] - Stored on the instance; it is not
   *   read by any method of this class.
   */
  constructor(network, options = {}) {
    this._history = [];
    this._bestLoss = Infinity;
    this._patienceCounter = 0;
    this._stopReason = "maxEpochs";
    this._metrics = [];
    this.network = network;
    this.epochs = options.epochs ?? 1e3;
    this.lrInitial = options.lr ?? 0.1;
    this.lrDecay = options.lrDecay ?? 1;
    this.verbose = options.verbose ?? false;
    this.weightDecay = options.weightDecay ?? 0;
    this._earlyStopping = options.earlyStopping;
    this._computeMetrics = options.computeMetrics ?? false;
    this.clipValue = options.clipValue ?? 0;
  }

  /**
   * Registers the held-out data used for early stopping.
   *
   * @param {{inputs: Array, targets: Array}} dataset
   * @throws {Error} When inputs and targets differ in length.
   */
  setValidationData(dataset) {
    if (dataset.inputs.length !== dataset.targets.length) {
      throw new Error(
        "Trainer.setValidationData: inputs and targets must have the same length"
      );
    }
    this._validationData = dataset;
  }

  /** @returns {number} Best validation loss of the last run, or -1 if none was recorded. */
  getBestLoss() {
    return this._bestLoss === Infinity ? -1 : this._bestLoss;
  }

  /** @returns {string} "maxEpochs" or "earlyStopping". */
  getStopReason() {
    return this._stopReason;
  }

  /** @returns {Array<{accuracy: number, precision: number, recall: number, f1: number}>} Copy of the per-epoch metrics. */
  getMetrics() {
    return [...this._metrics];
  }

  /**
   * Trains the network sample by sample, reshuffling the order every epoch.
   *
   * @param {{inputs: Array, targets: Array}} dataset
   * @returns {number[]} Mean training loss per completed epoch.
   * @throws {Error} When inputs and targets differ in length.
   */
  train(dataset) {
    const { inputs, targets } = dataset;
    if (inputs.length !== targets.length) {
      throw new Error(
        "Trainer.train: inputs and targets must have the same length"
      );
    }
    const n = inputs.length;
    let lr = this.lrInitial;
    this._history = [];
    this._bestLoss = Infinity;
    this._patienceCounter = 0;
    this._stopReason = "maxEpochs";
    this._metrics = [];

    const netExt = this._hasWeights(this.network);
    if (this.weightDecay > 0 && !netExt) {
      console.warn(
        "Trainer: weightDecay requires a network with getWeights/setWeights/predict. Skipping weight decay."
      );
    }
    if (this._earlyStopping && !netExt) {
      console.warn(
        "Trainer: earlyStopping requires a network with predict(). Skipping early stopping."
      );
    }
    if (this._computeMetrics && !netExt) {
      console.warn(
        "Trainer: computeMetrics requires a network with predict(). Skipping metrics."
      );
    }

    // An empty validation set would give a NaN loss that never "improves",
    // which used to trigger early stopping after `patience` epochs.
    const validation = this._validationData;
    const hasValidation = !!validation && validation.inputs.length > 0;
    if (this._earlyStopping && netExt && !hasValidation) {
      console.warn(
        "Trainer: earlyStopping requires non-empty validation data (see setValidationData). Skipping early stopping."
      );
    }

    const canDecay = this.weightDecay > 0 && netExt;
    const canValidate = !!this._earlyStopping && !!netExt && hasValidation;
    const canMetric = this._computeMetrics && netExt;
    const isClass = canMetric && this._isClassification(targets);
    if (canMetric && !isClass) {
      console.warn(
        "Trainer: computeMetrics is set but targets do not appear to be one-hot or single-class. Metrics will be skipped."
      );
    }

    // A missing minDelta used to make `bestLoss - undefined` NaN, so no epoch
    // ever counted as an improvement. A missing patience never stops training,
    // which matches the previous comparison against `undefined`.
    const minDelta = this._earlyStopping?.minDelta ?? 0;
    const patience = this._earlyStopping?.patience ?? Infinity;

    for (let epoch = 0; epoch < this.epochs; epoch++) {
      // Fisher-Yates shuffle of the visiting order.
      const indices = Array.from({ length: n }, (_, i) => i);
      for (let i = n - 1; i > 0; i--) {
        const j = Math.floor(Math.random() * (i + 1));
        [indices[i], indices[j]] = [indices[j], indices[i]];
      }

      let epochLoss = 0;
      for (const i of indices) {
        if (canDecay) {
          // Shrink every weight before the gradient step of this sample.
          const w = netExt.getWeights();
          for (let j = 0; j < w.length; j++) {
            w[j] *= 1 - lr * this.weightDecay;
          }
          netExt.setWeights(w);
        }
        epochLoss += this.network.train(inputs[i], targets[i], lr);
      }
      epochLoss /= n;
      this._history.push(epochLoss);

      if (canMetric && isClass) {
        this._metrics.push(this._computeMetricsArray(netExt, inputs, targets));
      }

      if (canValidate) {
        const valLoss = this._computeLoss(netExt, validation);
        if (valLoss < this._bestLoss - minDelta) {
          this._bestLoss = valLoss;
          this._patienceCounter = 0;
        } else {
          this._patienceCounter++;
        }
        if (this._patienceCounter >= patience) {
          this._stopReason = "earlyStopping";
          break;
        }
      }

      lr *= this.lrDecay;
      if (this.verbose && (epoch + 1) % 100 === 0) {
        console.log(
          `Epoch ${epoch + 1}/${this.epochs}, loss: ${epochLoss.toFixed(6)}, lr: ${lr.toFixed(6)}`
        );
      }
    }
    return this._history;
  }

  /** @returns {number[]} Copy of the per-epoch training losses of the last run. */
  getHistory() {
    return [...this._history];
  }

  /**
   * Capability check for getWeights/setWeights/predict.
   *
   * @param {object} network
   * @returns {object|null} The same network when all three methods exist, else null.
   */
  _hasWeights(network) {
    const required = ["getWeights", "setWeights", "predict"];
    const supported = required.every(
      (name) => name in network && typeof network[name] === "function"
    );
    return supported ? network : null;
  }

  /**
   * Mean squared error of `net.predict` over a dataset.
   *
   * @param {object} net - Network exposing predict(input).
   * @param {{inputs: Array, targets: Array}} data
   * @returns {number} Per-sample MSE averaged over the dataset.
   */
  _computeLoss(net, data) {
    let totalLoss = 0;
    for (let i = 0; i < data.inputs.length; i++) {
      const pred = net.predict(data.inputs[i]);
      const target = data.targets[i];
      let sampleLoss = 0;
      for (let j = 0; j < pred.length; j++) {
        sampleLoss += (target[j] - pred[j]) ** 2;
      }
      totalLoss += sampleLoss / pred.length;
    }
    return totalLoss / data.inputs.length;
  }

  /**
   * Heuristic: single-output targets, or rows whose entries are all within
   * 0.01 of 0 or 1 and sum to 1 (within 0.01), count as classification.
   *
   * @param {number[][]} targets
   * @returns {boolean}
   */
  _isClassification(targets) {
    if (targets.length === 0) return false;
    if (targets[0].length === 1) return true;
    for (const t of targets) {
      let sum = 0;
      for (const v of t) {
        sum += v;
        // The previous check had no upper bound, so any value above 0.99
        // (and NaN) was accepted as a one-hot entry.
        const nearZero = v >= -0.01 && v <= 0.01;
        const nearOne = v >= 0.99 && v <= 1.01;
        if (!nearZero && !nearOne) return false;
      }
      if (Math.abs(sum - 1) > 0.01) return false;
    }
    return true;
  }

  /**
   * Builds a confusion matrix from predictions and derives accuracy plus
   * macro-averaged precision, recall and F1.
   *
   * @param {object} net - Network exposing predict(input).
   * @param {Array} inputs
   * @param {number[][]} targets - One-hot rows, or single values thresholded at 0.5.
   * @returns {{accuracy: number, precision: number, recall: number, f1: number}}
   */
  _computeMetricsArray(net, inputs, targets) {
    const targetLen = targets[0].length;
    // A single output is treated as binary classification.
    const nClasses = targetLen === 1 ? 2 : targetLen;
    const confusion = Array.from(
      { length: nClasses },
      () => Array(nClasses).fill(0)
    );
    const argmax = (values) => values.indexOf(Math.max(...values));

    for (let i = 0; i < inputs.length; i++) {
      const pred = net.predict(inputs[i]);
      const target = targets[i];
      let predClass;
      let trueClass;
      if (targetLen === 1) {
        trueClass = target[0] >= 0.5 ? 1 : 0;
        predClass = pred.length === 1 ? (pred[0] >= 0.5 ? 1 : 0) : argmax(pred);
      } else {
        predClass = argmax(pred);
        trueClass = argmax(target);
      }
      // Clamp so unexpected output widths (or -1 from indexOf) stay in range.
      predClass = Math.max(0, Math.min(nClasses - 1, predClass));
      trueClass = Math.max(0, Math.min(nClasses - 1, trueClass));
      confusion[trueClass][predClass]++;
    }

    let totalCorrect = 0;
    let totalSamples = 0;
    const precisions = [];
    const recalls = [];
    for (let c = 0; c < nClasses; c++) {
      const tp = confusion[c][c];
      totalCorrect += tp;
      let colSum = 0;
      let rowSum = 0;
      for (let r = 0; r < nClasses; r++) {
        colSum += confusion[r][c];
        rowSum += confusion[c][r];
      }
      totalSamples += rowSum;
      precisions.push(colSum > 0 ? tp / colSum : 0);
      recalls.push(rowSum > 0 ? tp / rowSum : 0);
    }

    const accuracy = totalSamples > 0 ? totalCorrect / totalSamples : 0;
    const macroPrecision = precisions.reduce((a, b) => a + b, 0) / nClasses;
    const macroRecall = recalls.reduce((a, b) => a + b, 0) / nClasses;
    const f1 = macroPrecision + macroRecall > 0 ? 2 * macroPrecision * macroRecall / (macroPrecision + macroRecall) : 0;
    return {
      accuracy,
      precision: macroPrecision,
      recall: macroRecall,
      f1
    };
  }
};
|
|
2541
|
+
|
|
2542
|
+
// src/DataLoader.ts
|
|
2543
|
+
/**
 * Shuffled mini-batch iterator over an { inputs, targets } dataset with an
 * optional held-out validation split.
 */
var DataLoader = class _DataLoader {
  /**
   * @param {{inputs: Array, targets: Array}} data - Parallel arrays of samples.
   * @param {number} [batchSize=1] - Samples per batch; must be >= 1
   *   (fractional values are rounded down).
   * @param {number} [validationSplit=0] - Fraction in [0, 1) held out for validation.
   * @throws {Error} On mismatched lengths, an invalid batchSize, or an
   *   out-of-range validationSplit.
   */
  constructor(data, batchSize = 1, validationSplit = 0) {
    if (data.inputs.length !== data.targets.length) {
      throw new Error("DataLoader: inputs and targets must have the same length");
    }
    // A batch size of 0, a negative number or NaN means next() never advances,
    // so `while (loader.hasNext())` would spin forever.
    if (typeof batchSize !== "number" || !(batchSize >= 1)) {
      throw new Error(`DataLoader: batchSize must be a number >= 1, got ${batchSize}`);
    }
    if (validationSplit < 0 || validationSplit >= 1) {
      throw new Error(`DataLoader: validationSplit must be in [0, 1), got ${validationSplit}`);
    }
    this.data = data;
    this.batchSize = Math.floor(batchSize);
    this._validationSplit = validationSplit;

    // Shuffle once up front so the validation split is a random subset.
    const fullIndices = Array.from({ length: data.inputs.length }, (_, i) => i);
    _DataLoader._shuffleInPlace(fullIndices);

    if (validationSplit > 0) {
      const valSize = Math.round(data.inputs.length * validationSplit);
      const trainSize = data.inputs.length - valSize;
      this._trainIndices = fullIndices.slice(0, trainSize);
      this._valIndices = fullIndices.slice(trainSize);
    } else {
      this._trainIndices = [...fullIndices];
      this._valIndices = [];
    }
    this._indices = [...this._trainIndices];
    this._pos = 0;
  }

  /**
   * Fisher-Yates shuffle.
   *
   * @param {Array} items - Array reordered in place.
   */
  static _shuffleInPlace(items) {
    for (let i = items.length - 1; i > 0; i--) {
      const j = Math.floor(Math.random() * (i + 1));
      [items[i], items[j]] = [items[j], items[i]];
    }
  }

  /** Reshuffles the training samples and restarts iteration. */
  shuffle() {
    _DataLoader._shuffleInPlace(this._trainIndices);
    this._indices = [...this._trainIndices];
    this._pos = 0;
  }

  /** @returns {boolean} True while at least one training sample is left in this pass. */
  hasNext() {
    return this._pos < this._indices.length;
  }

  /**
   * Returns the next batch; the final batch may be smaller than batchSize and
   * an exhausted loader yields empty arrays.
   *
   * @returns {{inputs: Array, targets: Array}}
   */
  next() {
    const end = Math.min(this._pos + this.batchSize, this._indices.length);
    const batchIndices = this._indices.slice(this._pos, end);
    this._pos = end;
    return {
      inputs: batchIndices.map((i) => this.data.inputs[i]),
      targets: batchIndices.map((i) => this.data.targets[i])
    };
  }

  /** Restarts iteration without reshuffling. */
  reset() {
    this._pos = 0;
  }

  /** @returns {number} Number of training samples (validation samples excluded). */
  get length() {
    return this._trainIndices.length;
  }

  /**
   * Validation samples in their shuffled order.
   *
   * @returns {{inputs: Array, targets: Array}} Empty arrays when no split was configured.
   */
  getValidationData() {
    return {
      inputs: this._valIndices.map((i) => this.data.inputs[i]),
      targets: this._valIndices.map((i) => this.data.targets[i])
    };
  }

  /** @returns {number} Number of validation samples. */
  get validationLength() {
    return this._valIndices.length;
  }

  /**
   * Builds next-step prediction samples from a time series: each input is
   * `seqLen` consecutive rows flattened one level, and the target is the row
   * that follows the window.
   *
   * @param {Array} data - Time series, one entry per step.
   * @param {number} seqLen - Window length.
   * @param {number} [validationSplit=0] - Forwarded to the constructor.
   * @returns {DataLoader} Loader with batchSize 1.
   * @throws {Error} When data holds fewer than seqLen + 1 steps.
   */
  static sequences(data, seqLen, validationSplit = 0) {
    if (data.length < seqLen + 1) {
      throw new Error("DataLoader.sequences: data length must be >= seqLen + 1");
    }
    const inputs = [];
    const targets = [];
    for (let i = 0; i <= data.length - seqLen - 1; i++) {
      inputs.push(data.slice(i, i + seqLen).flat());
      targets.push(data[i + seqLen]);
    }
    return new _DataLoader({ inputs, targets }, 1, validationSplit);
  }
};
|
|
2629
|
+
|
|
2630
|
+
// src/LRScheduler.ts
|
|
2631
|
+
/**
 * Stateless learning-rate schedules. Every method takes the reference
 * learning rate plus the schedule inputs and returns the resulting rate.
 */
var LRScheduler = class {
  /**
   * Step decay: lr * dropRate^floor(epoch / epochsDrop).
   *
   * @param {number} lr - Initial learning rate.
   * @param {number} epoch - Current epoch index.
   * @param {number} dropRate - Factor applied at every drop.
   * @param {number} epochsDrop - Number of epochs between drops.
   * @returns {number}
   */
  stepDecay(lr, epoch, dropRate, epochsDrop) {
    return lr * Math.pow(dropRate, Math.floor(epoch / epochsDrop));
  }

  /**
   * Exponential decay: lr * decayRate^epoch.
   *
   * @param {number} lr - Initial learning rate.
   * @param {number} epoch - Current epoch index.
   * @param {number} decayRate - Per-epoch factor.
   * @returns {number}
   */
  exponentialDecay(lr, epoch, decayRate) {
    return lr * Math.pow(decayRate, epoch);
  }

  /**
   * Plateau decay: multiplies lr by `factor` when none of the last `patience`
   * losses (the most recent being `currentLoss`) is better than the best loss
   * seen before them. Call after each epoch.
   *
   * The previous check compared `currentLoss` with the minimum of the last
   * `patience` history entries, so it decayed after a single non-improving
   * epoch, and on every call when `history` already contained the current
   * loss.
   *
   * @param {number} lr - Current learning rate.
   * @param {number} currentLoss - Loss of the epoch that just finished.
   * @param {number[]} history - Earlier losses, oldest first. It may or may
   *   not already end with `currentLoss`.
   * @param {number} patience - Number of recent epochs without improvement
   *   that triggers a decay; values below 1 are treated as 1.
   * @param {number} factor - Multiplier applied on a plateau.
   * @returns {number} The (possibly reduced) learning rate.
   */
  plateauDecay(lr, currentLoss, history, patience, factor) {
    const windowSize = Math.max(1, Math.floor(patience));
    // Need at least one loss older than the window to compare against.
    if (history.length < windowSize) return lr;

    const series = [...history, currentLoss];
    const split = series.length - windowSize;
    // Plain loop instead of Math.min(...arr): the history can be long.
    let bestBefore = Infinity;
    let bestRecent = Infinity;
    for (let i = 0; i < series.length; i++) {
      if (i < split) {
        if (series[i] < bestBefore) bestBefore = series[i];
      } else if (series[i] < bestRecent) {
        bestRecent = series[i];
      }
    }
    return bestRecent >= bestBefore ? lr * factor : lr;
  }

  /**
   * Cosine annealing:
   * minLr + 0.5 * (lr - minLr) * (1 + cos(pi * epoch / maxEpochs)).
   *
   * @param {number} lr - Maximum (initial) learning rate.
   * @param {number} epoch - Current epoch index.
   * @param {number} maxEpochs - Epoch at which the rate reaches minLr.
   * @param {number} [minLr=0] - Lower bound of the schedule.
   * @returns {number}
   */
  cosineAnnealing(lr, epoch, maxEpochs, minLr = 0) {
    return minLr + 0.5 * (lr - minLr) * (1 + Math.cos(Math.PI * epoch / maxEpochs));
  }
};
|
|
2668
|
+
|
|
2669
|
+
// src/ModelSaver.ts
|
|
2670
|
+
/**
 * JSON (de)serialisation of a model's flat weight vector. Works with any
 * object exposing getWeights() / setWeights(weights).
 */
var ModelSaver = class _ModelSaver {
  /**
   * @param {{getWeights: Function}} model
   * @returns {string} JSON with `weights` (from model.getWeights()) and a
   *   `timestamp` (Date.now() at call time).
   */
  static toJSON(model) {
    return JSON.stringify({
      weights: model.getWeights(),
      timestamp: Date.now()
    });
  }

  /**
   * Restores weights produced by toJSON into `model`.
   *
   * @param {{setWeights: Function, getWeights?: Function}} model
   * @param {string} json
   * @throws {SyntaxError} When `json` is not valid JSON.
   * @throws {Error} When the payload has no numeric `weights` array, or when
   *   its length differs from what the model currently reports.
   */
  static fromJSON(model, json) {
    const data = JSON.parse(json);
    // `data` may legitimately parse to null or a primitive; treat both as invalid.
    const weights = data?.weights;
    const isNumeric = (w) => typeof w === "number" && Number.isFinite(w);
    // JSON.stringify writes NaN/Infinity as null, so non-numeric entries mean
    // the saved model was already corrupt.
    if (!Array.isArray(weights) || !weights.every(isNumeric)) {
      throw new Error("ModelSaver.fromJSON: invalid model data");
    }
    // A size mismatch would otherwise fill the model with `undefined` weights.
    if (typeof model.getWeights === "function") {
      const expected = model.getWeights().length;
      if (weights.length !== expected) {
        throw new Error(`ModelSaver.fromJSON: expected ${expected} weights, got ${weights.length}`);
      }
    }
    model.setWeights(weights);
  }

  /**
   * Serialises the model and hands the JSON to a caller-supplied writer.
   *
   * @param {{getWeights: Function}} model
   * @param {string} path - Passed through to `writeFn` unchanged.
   * @param {function(string, string): void} writeFn - Called as writeFn(path, json).
   */
  static saveToFile(model, path, writeFn) {
    const json = _ModelSaver.toJSON(model);
    writeFn(path, json);
  }

  /**
   * Reads JSON through a caller-supplied reader and loads it into the model.
   *
   * @param {{setWeights: Function}} model
   * @param {string} path - Passed through to `readFn` unchanged.
   * @param {function(string): string} readFn - Must return the JSON text synchronously.
   * @throws {Error} Same conditions as fromJSON.
   */
  static loadFromFile(model, path, readFn) {
    const json = readFn(path);
    _ModelSaver.fromJSON(model, json);
  }
};
|
|
1242
2697
|
// Annotate the CommonJS export names for ESM import in node:
|
|
1243
2698
|
0 && (module.exports = {
|
|
1244
2699
|
Adam,
|
|
1245
2700
|
AttentionHead,
|
|
2701
|
+
BatchNorm,
|
|
2702
|
+
ClipOptimizer,
|
|
2703
|
+
ClippedOptimizerFactory,
|
|
2704
|
+
Conv1D,
|
|
2705
|
+
DataLoader,
|
|
2706
|
+
Dropout,
|
|
1246
2707
|
EmbeddingMatrix,
|
|
2708
|
+
GRULayer,
|
|
2709
|
+
LRScheduler,
|
|
1247
2710
|
LSTMLayer,
|
|
1248
2711
|
Layer,
|
|
1249
2712
|
LayerNorm,
|
|
2713
|
+
ModelSaver,
|
|
1250
2714
|
Momentum,
|
|
1251
2715
|
MultiHeadAttention,
|
|
1252
2716
|
Network,
|
|
@@ -1257,6 +2721,7 @@ function crossEntropyDeltaRaw(predicted, actual) {
|
|
|
1257
2721
|
Neuron,
|
|
1258
2722
|
NeuronN,
|
|
1259
2723
|
SGD,
|
|
2724
|
+
Trainer,
|
|
1260
2725
|
TransformerBlock,
|
|
1261
2726
|
WeightMatrix,
|
|
1262
2727
|
crossEntropy,
|
|
@@ -1275,5 +2740,9 @@ function crossEntropyDeltaRaw(predicted, actual) {
|
|
|
1275
2740
|
softmax,
|
|
1276
2741
|
softmaxBackward,
|
|
1277
2742
|
tanh,
|
|
1278
|
-
transpose
|
|
2743
|
+
transpose,
|
|
2744
|
+
validate2DArray,
|
|
2745
|
+
validateArray,
|
|
2746
|
+
validateArrayMinLength,
|
|
2747
|
+
validateNumber
|
|
1279
2748
|
});
|