@dniskav/neuron 0.2.3 → 0.2.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.mjs CHANGED
@@ -1,3 +1,71 @@
1
+ // src/Validation.ts
2
/**
 * Asserts that `arr` is an array of exactly `expectedLength` finite numbers.
 *
 * @param {unknown} arr - Value to check.
 * @param {number} expectedLength - Required number of elements.
 * @param {string} methodName - Caller label used as the prefix of every error message.
 * @returns {void}
 * @throws {Error} If `arr` is not an array, has a different length, or holds
 *   an element that is not a finite number.
 */
function validateArray(arr, expectedLength, methodName) {
  if (!Array.isArray(arr)) {
    throw new Error(`${methodName}: expected array, got ${typeof arr}`);
  }
  if (arr.length !== expectedLength) {
    throw new Error(
      `${methodName}: expected array of length ${expectedLength}, got ${arr.length}`
    );
  }
  for (let i = 0; i < arr.length; i++) {
    // Number.isFinite is false for every non-number, so it covers the typeof check too.
    if (!Number.isFinite(arr[i])) {
      // String() rather than bare interpolation: interpolating a Symbol would
      // raise a TypeError and hide the validation error.
      throw new Error(
        `${methodName}: invalid value at index ${i}: ${String(arr[i])}`
      );
    }
  }
}
19
/**
 * Asserts that `arr` is an array of at least `minLength` finite numbers.
 *
 * @param {unknown} arr - Value to check.
 * @param {number} minLength - Smallest accepted number of elements.
 * @param {string} methodName - Caller label used as the prefix of every error message.
 * @returns {void}
 * @throws {Error} If `arr` is not an array, is shorter than `minLength`, or
 *   holds an element that is not a finite number.
 */
function validateArrayMinLength(arr, minLength, methodName) {
  if (!Array.isArray(arr)) {
    throw new Error(`${methodName}: expected array, got ${typeof arr}`);
  }
  if (arr.length < minLength) {
    throw new Error(
      `${methodName}: expected array of at least length ${minLength}, got ${arr.length}`
    );
  }
  for (let i = 0; i < arr.length; i++) {
    // Number.isFinite is false for every non-number, so it covers the typeof check too.
    if (!Number.isFinite(arr[i])) {
      // String() rather than bare interpolation: interpolating a Symbol would
      // raise a TypeError and hide the validation error.
      throw new Error(
        `${methodName}: invalid value at index ${i}: ${String(arr[i])}`
      );
    }
  }
}
36
/**
 * Asserts that `arr` is an `expectedRows` x `expectedCols` array of arrays
 * containing only finite numbers.
 *
 * @param {unknown} arr - Value to check.
 * @param {number} expectedRows - Required number of rows.
 * @param {number} expectedCols - Required length of every row.
 * @param {string} methodName - Caller label used as the prefix of every error message.
 * @returns {void}
 * @throws {Error} If the outer value or any row is not an array, a dimension
 *   does not match, or a cell is not a finite number.
 */
function validate2DArray(arr, expectedRows, expectedCols, methodName) {
  if (!Array.isArray(arr)) {
    throw new Error(`${methodName}: expected 2D array, got ${typeof arr}`);
  }
  if (arr.length !== expectedRows) {
    throw new Error(
      `${methodName}: expected ${expectedRows} rows, got ${arr.length}`
    );
  }
  for (let i = 0; i < arr.length; i++) {
    const row = arr[i];
    if (!Array.isArray(row)) {
      throw new Error(`${methodName}: row ${i} is not an array`);
    }
    if (row.length !== expectedCols) {
      throw new Error(
        `${methodName}: row ${i} expected ${expectedCols} cols, got ${row.length}`
      );
    }
    for (let j = 0; j < row.length; j++) {
      // Number.isFinite is false for every non-number, so it covers the typeof check too.
      if (!Number.isFinite(row[j])) {
        // String() rather than bare interpolation: interpolating a Symbol
        // would raise a TypeError and hide the validation error.
        throw new Error(
          `${methodName}: invalid value at [${i}][${j}]: ${String(row[j])}`
        );
      }
    }
  }
}
63
/**
 * Asserts that `value` is a finite number (not NaN, not +/-Infinity, not a
 * non-number type).
 *
 * @param {unknown} value - Value to check.
 * @param {string} methodName - Caller label used as the prefix of the error message.
 * @returns {void}
 * @throws {Error} If `value` is not a finite number.
 */
function validateNumber(value, methodName) {
  // Number.isFinite never coerces, so it rejects strings, null, objects, etc.
  if (!Number.isFinite(value)) {
    // String() rather than bare interpolation: interpolating a Symbol would
    // raise a TypeError and hide the validation error.
    throw new Error(`${methodName}: expected finite number, got ${String(value)}`);
  }
}
68
+
1
69
  // src/Neuron.ts
2
70
  function sigmoid(x) {
3
71
  return 1 / (1 + Math.exp(-x));
@@ -8,13 +76,18 @@ var Neuron = class {
8
76
  this.bias = Math.random() * 0.1;
9
77
  }
10
78
  predict(input) {
79
+ validateNumber(input, "Neuron.predict");
11
80
  return sigmoid(input * this.weight + this.bias);
12
81
  }
13
82
  train(input, target, lr) {
83
+ validateNumber(input, "Neuron.train");
84
+ validateNumber(target, "Neuron.train");
85
+ validateNumber(lr, "Neuron.train");
14
86
  const prediction = this.predict(input);
15
87
  const error = target - prediction;
16
- this.weight += lr * error * input;
17
- this.bias += lr * error;
88
+ const grad = error * prediction * (1 - prediction);
89
+ this.weight += lr * grad * input;
90
+ this.bias += lr * grad;
18
91
  }
19
92
  };
20
93
 
@@ -69,6 +142,19 @@ var Momentum = class {
69
142
  return weight + this.v;
70
143
  }
71
144
  };
145
/**
 * Optimizer decorator that clamps each gradient to [-clipValue, clipValue]
 * before delegating the update to an inner optimizer.
 */
var ClipOptimizer = class {
  /**
   * @param {{ step(weight: number, gradient: number, lr: number): number }} inner
   *   Optimizer that performs the actual update.
   * @param {number} clipValue - Largest gradient magnitude passed to `inner`.
   *   `Infinity` disables clipping.
   * @throws {Error} If `clipValue` is not a number, is NaN, or is negative.
   */
  constructor(inner, clipValue) {
    // A negative bound would pin every gradient to -clipValue and NaN would
    // turn every gradient into NaN, so both are rejected here.
    if (typeof clipValue !== "number" || Number.isNaN(clipValue) || clipValue < 0) {
      throw new Error(
        `ClipOptimizer: clipValue must be a non-negative number, got ${String(clipValue)}`
      );
    }
    this.inner = inner;
    this.clipValue = clipValue;
  }
  /**
   * @param {number} weight - Current parameter value.
   * @param {number} gradient - Raw gradient.
   * @param {number} lr - Learning rate, forwarded unchanged.
   * @returns {number} Whatever `inner.step` returns for the clamped gradient.
   */
  step(weight, gradient, lr) {
    const clipped = Math.max(-this.clipValue, Math.min(this.clipValue, gradient));
    return this.inner.step(weight, clipped, lr);
  }
};
155
/**
 * Wraps an optimizer factory so every optimizer it creates is gradient-clipped.
 *
 * @param {() => object} innerFactory - Creates the optimizer to wrap.
 * @param {number} clipValue - Clip bound handed to each ClipOptimizer.
 * @returns {() => object} Factory producing a fresh ClipOptimizer per call.
 */
function ClippedOptimizerFactory(innerFactory, clipValue) {
  const buildClipped = () => {
    const wrapped = innerFactory();
    return new ClipOptimizer(wrapped, clipValue);
  };
  return buildClipped;
}
72
158
  var Adam = class {
73
159
  constructor(beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8) {
74
160
  this.beta1 = beta1;
@@ -99,6 +185,7 @@ var NeuronN = class {
99
185
  this._opts = Array.from({ length: nInputs + 1 }, optimizerFactory);
100
186
  }
101
187
  predict(inputs) {
188
+ validateArray(inputs, this.weights.length, "NeuronN.predict");
102
189
  const sum = inputs.reduce((acc, e, i) => acc + e * this.weights[i], this.bias);
103
190
  return this.activation.fn(sum);
104
191
  }
@@ -111,7 +198,8 @@ var NeuronN = class {
111
198
  train(inputs, target, lr) {
112
199
  const prediction = this.predict(inputs);
113
200
  const error = target - prediction;
114
- this._update(inputs.map((inp) => error * inp), error, lr);
201
+ const grad = error * this.activation.dfn(prediction);
202
+ this._update(inputs.map((inp) => grad * inp), grad, lr);
115
203
  }
116
204
  };
117
205
 
@@ -136,29 +224,99 @@ var Network = class {
136
224
  this.outputLayer = new Layer(nOutputs, nHidden);
137
225
  }
138
226
  predict(inputs) {
227
+ validateArray(inputs, this.hiddenLayer.neurons[0].weights.length, "Network.predict");
139
228
  const hiddenOut = this.hiddenLayer.predict(inputs);
140
229
  return this.outputLayer.predict(hiddenOut)[0];
141
230
  }
142
231
  // Trains on a single example. Returns the squared error.
143
232
  train(inputs, target, lr) {
233
+ validateArray(inputs, this.hiddenLayer.neurons[0].weights.length, "Network.train");
234
+ validateNumber(target, "Network.train");
235
+ validateNumber(lr, "Network.train");
144
236
  const hiddenOut = this.hiddenLayer.predict(inputs);
145
237
  const prediction = this.outputLayer.predict(hiddenOut)[0];
146
238
  const outputError = target - prediction;
147
239
  const outputDelta = outputError * prediction * (1 - prediction);
148
240
  const outputNeuron = this.outputLayer.neurons[0];
241
+ const hiddenDeltas = this.hiddenLayer.neurons.map((neuron, i) => {
242
+ const hiddenOut_i = hiddenOut[i];
243
+ const hiddenError = outputDelta * outputNeuron.weights[i];
244
+ return hiddenError * hiddenOut_i * (1 - hiddenOut_i);
245
+ });
246
+ this.hiddenLayer.neurons.forEach((neuron, i) => {
247
+ neuron.weights = neuron.weights.map((w, j) => w + lr * hiddenDeltas[i] * inputs[j]);
248
+ neuron.bias += lr * hiddenDeltas[i];
249
+ });
149
250
  outputNeuron.weights = outputNeuron.weights.map(
150
251
  (w, i) => w + lr * outputDelta * hiddenOut[i]
151
252
  );
152
253
  outputNeuron.bias += lr * outputDelta;
153
- this.hiddenLayer.neurons.forEach((neuron, i) => {
154
- const hiddenOut_i = hiddenOut[i];
155
- const hiddenError = outputDelta * outputNeuron.weights[i];
156
- const hiddenDelta = hiddenError * hiddenOut_i * (1 - hiddenOut_i);
157
- neuron.weights = neuron.weights.map((w, j) => w + lr * hiddenDelta * inputs[j]);
158
- neuron.bias += lr * hiddenDelta;
159
- });
160
254
  return outputError * outputError;
161
255
  }
256
+ // ── Flat weight serialization ─────────────────────────────────────────────
257
+ // Order: hidden layer (all neurons: weights then bias), then output layer.
258
+ getWeights() {
259
+ const w = [];
260
+ for (const n of this.hiddenLayer.neurons) {
261
+ w.push(...n.weights, n.bias);
262
+ }
263
+ for (const n of this.outputLayer.neurons) {
264
+ w.push(...n.weights, n.bias);
265
+ }
266
+ return w;
267
+ }
268
+ setWeights(weights) {
269
+ let idx = 0;
270
+ for (const n of this.hiddenLayer.neurons) {
271
+ for (let j = 0; j < n.weights.length; j++) n.weights[j] = weights[idx++];
272
+ n.bias = weights[idx++];
273
+ }
274
+ for (const n of this.outputLayer.neurons) {
275
+ for (let j = 0; j < n.weights.length; j++) n.weights[j] = weights[idx++];
276
+ n.bias = weights[idx++];
277
+ }
278
+ }
279
+ };
280
+
281
+ // src/Dropout.ts
282
/**
 * Inverted-dropout layer: during training each element is zeroed with
 * probability `rate` and the survivors are scaled by 1 / (1 - rate).
 * It has no trainable parameters.
 */
var Dropout = class {
  /**
   * @param {number} rate - Drop probability in [0, 1).
   * @throws {Error} If `rate` is not a number inside [0, 1). NaN and
   *   undefined are rejected too; a plain `<` / `>=` range test lets them through.
   */
  constructor(rate) {
    this._mask = null;
    if (typeof rate !== "number" || !(rate >= 0 && rate < 1)) {
      throw new Error(`Dropout rate must be in [0, 1), got ${String(rate)}`);
    }
    this.rate = rate;
  }
  /**
   * Forward pass.
   *
   * @param {number[]} x - Activations.
   * @param {boolean} [training=true] - When false (or rate is 0) the input is
   *   returned as an unchanged copy and any stored mask is cleared.
   * @returns {number[]} New array; `x` is never mutated.
   */
  forward(x, training = true) {
    if (!training || this.rate === 0) {
      this._mask = null;
      return [...x];
    }
    const scale = 1 / (1 - this.rate);
    // The mask is kept so backward() can apply the same pattern to the gradient.
    this._mask = x.map(() => Math.random() > this.rate ? scale : 0);
    return x.map((v, i) => v * this._mask[i]);
  }
  /**
   * Backward pass: multiplies the gradient by the mask of the last training
   * forward pass, or returns a plain copy when no mask is stored.
   *
   * @param {number[]} dOut - Upstream gradient.
   * @returns {number[]} New gradient array.
   */
  backward(dOut) {
    if (!this._mask) return [...dOut];
    return dOut.map((d, i) => d * this._mask[i]);
  }
  /** Discards the stored mask. */
  resetMask() {
    this._mask = null;
  }
  /**
   * @returns {number[]} Always an empty array (no trainable parameters).
   */
  getWeights() {
    return [];
  }
  /**
   * No-op: there is nothing to load.
   *
   * @param {number[]} _weights - Ignored.
   */
  setWeights(_weights) {
  }
};
163
321
 
164
322
  // src/NetworkN.ts
@@ -169,30 +327,96 @@ var NetworkN = class {
169
327
  const nLayers = structure.length - 1;
170
328
  const activations = options.activations ?? Array.from({ length: nLayers }, () => sigmoid2);
171
329
  const optimizer = options.optimizer ?? defaultOptimizer3;
330
+ const dropoutRate = options.dropoutRate ?? 0;
331
+ if (activations.length !== nLayers) {
332
+ throw new Error(`Expected ${nLayers} activations, got ${activations.length}`);
333
+ }
334
+ if (dropoutRate < 0 || dropoutRate >= 1) {
335
+ throw new Error(`Dropout rate must be in [0, 1), got ${dropoutRate}`);
336
+ }
337
+ this._residual = options.residual ?? false;
172
338
  this.layers = [];
173
339
  for (let i = 1; i < structure.length; i++) {
174
340
  this.layers.push(new Layer(structure[i], structure[i - 1], activations[i - 1], optimizer));
175
341
  }
342
+ this._dropouts = [];
343
+ if (dropoutRate > 0) {
344
+ for (let i = 0; i < nLayers - 1; i++) {
345
+ this._dropouts.push(new Dropout(dropoutRate));
346
+ }
347
+ }
348
+ const outputLayer = this.layers[this.layers.length - 1];
349
+ const outputActivation = outputLayer.neurons[0].activation;
350
+ for (let i = 1; i < outputLayer.neurons.length; i++) {
351
+ if (outputLayer.neurons[i].activation !== outputActivation) {
352
+ throw new Error("All output neurons must share the same activation function");
353
+ }
354
+ }
176
355
  }
177
- predict(inputs) {
178
- return this.layers.reduce((acc, layer) => layer.predict(acc), inputs);
356
+ predict(inputs, training = false) {
357
+ validateArray(inputs, this.structure[0], "NetworkN.predict");
358
+ let current = [...inputs];
359
+ for (let i = 0; i < this.layers.length; i++) {
360
+ const layerInput = [...current];
361
+ const layerOutput = this.layers[i].predict(current);
362
+ if (this._shouldResidual(i)) {
363
+ if (this.structure[i] === this.structure[i + 1]) {
364
+ current = layerOutput.map((v, j) => v + layerInput[j]);
365
+ } else {
366
+ current = [...layerOutput];
367
+ }
368
+ } else {
369
+ current = [...layerOutput];
370
+ }
371
+ if (i < this._dropouts.length) {
372
+ current = this._dropouts[i].forward(current, training);
373
+ }
374
+ }
375
+ return current;
179
376
  }
180
377
  // Generalized backpropagation across L layers.
181
378
  // Returns the mean squared error for the example.
182
379
  train(inputs, targets, lr) {
380
+ validateArray(inputs, this.structure[0], "NetworkN.train");
381
+ validateArray(targets, this.structure[this.structure.length - 1], "NetworkN.train");
183
382
  const act = [inputs];
184
- for (const layer of this.layers) act.push(layer.predict(act[act.length - 1]));
383
+ for (let i = 0; i < this.layers.length; i++) {
384
+ const layerInput = act[act.length - 1];
385
+ const layerOutput = this.layers[i].predict(layerInput);
386
+ let current;
387
+ if (this._shouldResidual(i)) {
388
+ if (this.structure[i] === this.structure[i + 1]) {
389
+ current = layerOutput.map((v, j) => v + layerInput[j]);
390
+ } else {
391
+ current = [...layerOutput];
392
+ }
393
+ } else {
394
+ current = [...layerOutput];
395
+ }
396
+ if (i < this._dropouts.length) {
397
+ current = this._dropouts[i].forward(current, true);
398
+ }
399
+ act.push(current);
400
+ }
185
401
  const pred = act[act.length - 1];
186
402
  const outAct = this.layers[this.layers.length - 1].neurons[0].activation;
187
403
  let deltas = pred.map((p, i) => (targets[i] - p) * outAct.dfn(p));
188
404
  for (let l = this.layers.length - 1; l >= 0; l--) {
189
405
  const layer = this.layers[l];
406
+ if (l < this._dropouts.length) {
407
+ deltas = this._dropouts[l].backward(deltas);
408
+ }
190
409
  const layerIn = act[l];
191
410
  const prevAct = l > 0 ? this.layers[l - 1].neurons[0].activation : null;
192
411
  const prevDeltas = layerIn.map((out, j) => {
193
412
  const errProp = layer.neurons.reduce((s, n, k) => s + deltas[k] * n.weights[j], 0);
194
413
  return prevAct ? errProp * prevAct.dfn(out) : errProp;
195
414
  });
415
+ if (this._shouldResidual(l) && this.structure[l] === this.structure[l + 1]) {
416
+ for (let j = 0; j < prevDeltas.length; j++) {
417
+ prevDeltas[j] += deltas[j];
418
+ }
419
+ }
196
420
  layer.neurons.forEach((n, k) => {
197
421
  n._update(layerIn.map((inp) => deltas[k] * inp), deltas[k], lr);
198
422
  });
@@ -204,22 +428,74 @@ var NetworkN = class {
204
428
  // Useful for custom loss functions (e.g. physics-based gradients).
205
429
  trainWithDeltas(inputs, outputDeltas, lr) {
206
430
  const act = [inputs];
207
- for (const layer of this.layers) act.push(layer.predict(act[act.length - 1]));
431
+ for (let i = 0; i < this.layers.length; i++) {
432
+ const layerInput = act[act.length - 1];
433
+ const layerOutput = this.layers[i].predict(layerInput);
434
+ let current;
435
+ if (this._shouldResidual(i)) {
436
+ if (this.structure[i] === this.structure[i + 1]) {
437
+ current = layerOutput.map((v, j) => v + layerInput[j]);
438
+ } else {
439
+ current = [...layerOutput];
440
+ }
441
+ } else {
442
+ current = [...layerOutput];
443
+ }
444
+ if (i < this._dropouts.length) {
445
+ current = this._dropouts[i].forward(current, true);
446
+ }
447
+ act.push(current);
448
+ }
208
449
  let deltas = outputDeltas;
209
450
  for (let l = this.layers.length - 1; l >= 0; l--) {
210
451
  const layer = this.layers[l];
452
+ if (l < this._dropouts.length) {
453
+ deltas = this._dropouts[l].backward(deltas);
454
+ }
211
455
  const layerIn = act[l];
212
456
  const prevAct = l > 0 ? this.layers[l - 1].neurons[0].activation : null;
213
457
  const prevDeltas = layerIn.map((out, j) => {
214
458
  const errProp = layer.neurons.reduce((s, n, k) => s + deltas[k] * n.weights[j], 0);
215
459
  return prevAct ? errProp * prevAct.dfn(out) : errProp;
216
460
  });
461
+ if (this._shouldResidual(l) && this.structure[l] === this.structure[l + 1]) {
462
+ for (let j = 0; j < prevDeltas.length; j++) {
463
+ prevDeltas[j] += deltas[j];
464
+ }
465
+ }
217
466
  layer.neurons.forEach((n, k) => {
218
467
  n._update(layerIn.map((inp) => deltas[k] * inp), deltas[k], lr);
219
468
  });
220
469
  deltas = prevDeltas;
221
470
  }
222
471
  }
472
+ // ── Flat weight serialization ─────────────────────────────────────────────
473
+ // Order: layer 0 (all neurons), layer 1, ..., layer N.
474
+ getWeights() {
475
+ for (const d of this._dropouts) d.resetMask();
476
+ const w = [];
477
+ for (const layer of this.layers) {
478
+ for (const n of layer.neurons) {
479
+ w.push(...n.weights, n.bias);
480
+ }
481
+ }
482
+ return w;
483
+ }
484
+ setWeights(weights) {
485
+ for (const d of this._dropouts) d.resetMask();
486
+ let idx = 0;
487
+ for (const layer of this.layers) {
488
+ for (const n of layer.neurons) {
489
+ for (let j = 0; j < n.weights.length; j++) n.weights[j] = weights[idx++];
490
+ n.bias = weights[idx++];
491
+ }
492
+ }
493
+ }
494
+ // ── Helper ───────────────────────────────────────────────────────────────
495
+ _shouldResidual(layerIndex) {
496
+ if (typeof this._residual === "function") return this._residual(layerIndex);
497
+ return this._residual;
498
+ }
223
499
  };
224
500
 
225
501
  // src/LSTMLayer.ts
@@ -248,8 +524,11 @@ var Gate = class {
248
524
  }
249
525
  };
250
526
  var LSTMLayer = class {
251
- constructor(inputSize, hiddenSize) {
527
+ constructor(inputSize, hiddenSize, optimizerFactory = () => new SGD()) {
252
528
  this._traj = [];
529
+ if (inputSize <= 0 || hiddenSize <= 0) {
530
+ throw new Error(`LSTMLayer: inputSize and hiddenSize must be positive, got ${inputSize} and ${hiddenSize}`);
531
+ }
253
532
  this.inputSize = inputSize;
254
533
  this.hSize = hiddenSize;
255
534
  this.h = new Array(hiddenSize).fill(0);
@@ -258,6 +537,29 @@ var LSTMLayer = class {
258
537
  this.inputGate = new Gate(inputSize, hiddenSize);
259
538
  this.cellGate = new Gate(inputSize, hiddenSize);
260
539
  this.outputGate = new Gate(inputSize, hiddenSize);
540
+ const combSize = inputSize + hiddenSize;
541
+ this._optimizers = {
542
+ forgetW: Array.from(
543
+ { length: hiddenSize },
544
+ () => Array.from({ length: combSize }, () => optimizerFactory())
545
+ ),
546
+ forgetB: Array.from({ length: hiddenSize }, () => optimizerFactory()),
547
+ inputW: Array.from(
548
+ { length: hiddenSize },
549
+ () => Array.from({ length: combSize }, () => optimizerFactory())
550
+ ),
551
+ inputB: Array.from({ length: hiddenSize }, () => optimizerFactory()),
552
+ cellW: Array.from(
553
+ { length: hiddenSize },
554
+ () => Array.from({ length: combSize }, () => optimizerFactory())
555
+ ),
556
+ cellB: Array.from({ length: hiddenSize }, () => optimizerFactory()),
557
+ outputW: Array.from(
558
+ { length: hiddenSize },
559
+ () => Array.from({ length: combSize }, () => optimizerFactory())
560
+ ),
561
+ outputB: Array.from({ length: hiddenSize }, () => optimizerFactory())
562
+ };
261
563
  }
262
564
  // ── Reset state and trajectory (call at episode start) ────────────────────
263
565
  reset() {
@@ -267,6 +569,9 @@ var LSTMLayer = class {
267
569
  }
268
570
  // ── Forward pass ──────────────────────────────────────────────────────────
269
571
  predict(inputs) {
572
+ if (!Array.isArray(inputs) || inputs.length !== this.inputSize) {
573
+ throw new Error(`LSTMLayer.predict: expected array of length ${this.inputSize}, got ${inputs?.length}`);
574
+ }
270
575
  const combined = [...inputs, ...this.h];
271
576
  const c_prev = [...this.c];
272
577
  const zf = this.forgetGate.linear(combined);
@@ -341,15 +646,15 @@ var LSTMLayer = class {
341
646
  const scale = lr / T;
342
647
  for (let k = 0; k < hSize; k++) {
343
648
  for (let j = 0; j < combSize; j++) {
344
- this.forgetGate.W[k][j] += scale * dWf[k][j];
345
- this.inputGate.W[k][j] += scale * dWi[k][j];
346
- this.cellGate.W[k][j] += scale * dWg[k][j];
347
- this.outputGate.W[k][j] += scale * dWo[k][j];
649
+ this.forgetGate.W[k][j] = this._optimizers.forgetW[k][j].step(this.forgetGate.W[k][j], dWf[k][j], scale);
650
+ this.inputGate.W[k][j] = this._optimizers.inputW[k][j].step(this.inputGate.W[k][j], dWi[k][j], scale);
651
+ this.cellGate.W[k][j] = this._optimizers.cellW[k][j].step(this.cellGate.W[k][j], dWg[k][j], scale);
652
+ this.outputGate.W[k][j] = this._optimizers.outputW[k][j].step(this.outputGate.W[k][j], dWo[k][j], scale);
348
653
  }
349
- this.forgetGate.b[k] += scale * dbf[k];
350
- this.inputGate.b[k] += scale * dbi[k];
351
- this.cellGate.b[k] += scale * dbg[k];
352
- this.outputGate.b[k] += scale * dbo[k];
654
+ this.forgetGate.b[k] = this._optimizers.forgetB[k].step(this.forgetGate.b[k], dbf[k], scale);
655
+ this.inputGate.b[k] = this._optimizers.inputB[k].step(this.inputGate.b[k], dbi[k], scale);
656
+ this.cellGate.b[k] = this._optimizers.cellB[k].step(this.cellGate.b[k], dbg[k], scale);
657
+ this.outputGate.b[k] = this._optimizers.outputB[k].step(this.outputGate.b[k], dbo[k], scale);
353
658
  }
354
659
  this._traj = [];
355
660
  }
@@ -372,6 +677,35 @@ var LSTMLayer = class {
372
677
  this.outputGate.W = data.outputGate.W;
373
678
  this.outputGate.b = data.outputGate.b;
374
679
  }
680
+ // ── Flat weight serialization ─────────────────────────────────────────────
681
+ // Order: forgetGate (W, b), inputGate (W, b), cellGate (W, b), outputGate (W, b).
682
+ getWeightsFlat() {
683
+ const w = [];
684
+ for (const row of this.forgetGate.W) w.push(...row);
685
+ w.push(...this.forgetGate.b);
686
+ for (const row of this.inputGate.W) w.push(...row);
687
+ w.push(...this.inputGate.b);
688
+ for (const row of this.cellGate.W) w.push(...row);
689
+ w.push(...this.cellGate.b);
690
+ for (const row of this.outputGate.W) w.push(...row);
691
+ w.push(...this.outputGate.b);
692
+ return w;
693
+ }
694
+ setWeightsFlat(weights) {
695
+ let idx = 0;
696
+ for (let i = 0; i < this.forgetGate.W.length; i++)
697
+ for (let j = 0; j < this.forgetGate.W[i].length; j++) this.forgetGate.W[i][j] = weights[idx++];
698
+ for (let i = 0; i < this.forgetGate.b.length; i++) this.forgetGate.b[i] = weights[idx++];
699
+ for (let i = 0; i < this.inputGate.W.length; i++)
700
+ for (let j = 0; j < this.inputGate.W[i].length; j++) this.inputGate.W[i][j] = weights[idx++];
701
+ for (let i = 0; i < this.inputGate.b.length; i++) this.inputGate.b[i] = weights[idx++];
702
+ for (let i = 0; i < this.cellGate.W.length; i++)
703
+ for (let j = 0; j < this.cellGate.W[i].length; j++) this.cellGate.W[i][j] = weights[idx++];
704
+ for (let i = 0; i < this.cellGate.b.length; i++) this.cellGate.b[i] = weights[idx++];
705
+ for (let i = 0; i < this.outputGate.W.length; i++)
706
+ for (let j = 0; j < this.outputGate.W[i].length; j++) this.outputGate.W[i][j] = weights[idx++];
707
+ for (let i = 0; i < this.outputGate.b.length; i++) this.outputGate.b[i] = weights[idx++];
708
+ }
375
709
  };
376
710
 
377
711
  // src/NetworkLSTM.ts
@@ -398,6 +732,7 @@ var NetworkLSTM = class {
398
732
  }
399
733
  // ── Forward pass ──────────────────────────────────────────────────────────
400
734
  predict(inputs) {
735
+ validateArray(inputs, this.inputSize, "NetworkLSTM.predict");
401
736
  const h = this.lstm.predict(inputs);
402
737
  const acts = [h];
403
738
  for (const layer of this.denseLayers) {
@@ -473,6 +808,30 @@ var NetworkLSTM = class {
473
808
  });
474
809
  });
475
810
  }
811
+ // ── Flat weight serialization ─────────────────────────────────────────────
812
+ // Order: LSTM (flat), then dense layer 0, dense layer 1, ..., dense layer N.
813
+ getWeightsFlat() {
814
+ const w = [];
815
+ w.push(...this.lstm.getWeightsFlat());
816
+ for (const layer of this.denseLayers) {
817
+ for (const n of layer.neurons) {
818
+ w.push(...n.weights, n.bias);
819
+ }
820
+ }
821
+ return w;
822
+ }
823
+ setWeightsFlat(weights) {
824
+ let idx = 0;
825
+ const lstmLen = this.lstm.getWeightsFlat().length;
826
+ this.lstm.setWeightsFlat(weights.slice(idx, idx + lstmLen));
827
+ idx += lstmLen;
828
+ for (const layer of this.denseLayers) {
829
+ for (const n of layer.neurons) {
830
+ for (let j = 0; j < n.weights.length; j++) n.weights[j] = weights[idx++];
831
+ n.bias = weights[idx++];
832
+ }
833
+ }
834
+ }
476
835
  };
477
836
 
478
837
  // src/MatMul.ts
@@ -480,6 +839,9 @@ function matMul(A, B) {
480
839
  const rows = A.length;
481
840
  const inner = B.length;
482
841
  const cols = B[0].length;
842
+ if (A[0].length !== B.length) {
843
+ throw new Error(`Incompatible dimensions for matrix multiplication: A cols (${A[0].length}) !== B rows (${B.length})`);
844
+ }
483
845
  const C = Array.from({ length: rows }, () => new Array(cols).fill(0));
484
846
  for (let i = 0; i < rows; i++)
485
847
  for (let k = 0; k < inner; k++) {
@@ -530,6 +892,17 @@ var WeightMatrix = class {
530
892
  this.W[i][j] = this.opts[i][j].step(this.W[i][j], g, lr);
531
893
  }
532
894
  }
895
+ // ── Flat weight serialization ─────────────────────────────────────────────
896
+ getWeights() {
897
+ const w = [];
898
+ for (const row of this.W) w.push(...row);
899
+ return w;
900
+ }
901
+ setWeights(weights) {
902
+ let idx = 0;
903
+ for (let i = 0; i < this.W.length; i++)
904
+ for (let j = 0; j < this.W[i].length; j++) this.W[i][j] = weights[idx++];
905
+ }
533
906
  };
534
907
  var EmbeddingMatrix = class {
535
908
  constructor(vocabSize, d_model) {
@@ -546,15 +919,29 @@ var EmbeddingMatrix = class {
546
919
  for (let m = 0; m < this.W[idx].length; m++)
547
920
  this.W[idx][m] += lr * grad[m];
548
921
  }
922
+ // ── Serializable interface ─────────────────────────────────────────────────
923
+ // Flattened order: row 0, row 1, ... row (vocabSize-1)
924
+ getWeights() {
925
+ const w = [];
926
+ for (const row of this.W) w.push(...row);
927
+ return w;
928
+ }
929
+ setWeights(weights) {
930
+ let idx = 0;
931
+ for (let i = 0; i < this.W.length; i++)
932
+ for (let j = 0; j < this.W[i].length; j++)
933
+ this.W[i][j] = weights[idx++];
934
+ }
549
935
  };
550
936
 
551
937
  // src/AttentionHead.ts
552
938
  var AttentionHead = class {
553
- constructor(d_model, d_k, d_v) {
939
+ constructor(d_model, d_k, d_v, causal = false) {
554
940
  // d_v × d_model
555
941
  this.cache = null;
556
942
  this.d_k = d_k;
557
943
  this.d_v = d_v;
944
+ this.causal = causal;
558
945
  this.Wq = new WeightMatrix(d_k, d_model);
559
946
  this.Wk = new WeightMatrix(d_k, d_model);
560
947
  this.Wv = new WeightMatrix(d_v, d_model);
@@ -575,10 +962,10 @@ var AttentionHead = class {
575
962
  );
576
963
  const scores = Array.from(
577
964
  { length: seqLen },
578
- (_, i) => Array.from(
579
- { length: seqLen },
580
- (_2, j) => Q[i].reduce((s, q, k) => s + q * K[j][k], 0) * scale
581
- )
965
+ (_, i) => Array.from({ length: seqLen }, (_2, j) => {
966
+ if (this.causal && j > i) return -Infinity;
967
+ return Q[i].reduce((s, q, k) => s + q * K[j][k], 0) * scale;
968
+ })
582
969
  );
583
970
  const attn = scores.map((row) => softmax(row));
584
971
  const out = Array.from(
@@ -674,21 +1061,40 @@ var AttentionHead = class {
674
1061
  getAttentionWeights() {
675
1062
  return this.cache ? this.cache.attn : null;
676
1063
  }
1064
+ // ── Flat weight serialization ─────────────────────────────────────────────
1065
+ // Order: Wq, Wk, Wv.
1066
+ getWeights() {
1067
+ const w = [];
1068
+ for (const row of this.Wq.W) w.push(...row);
1069
+ for (const row of this.Wk.W) w.push(...row);
1070
+ for (const row of this.Wv.W) w.push(...row);
1071
+ return w;
1072
+ }
1073
+ setWeights(weights) {
1074
+ let idx = 0;
1075
+ for (let i = 0; i < this.Wq.W.length; i++)
1076
+ for (let j = 0; j < this.Wq.W[i].length; j++) this.Wq.W[i][j] = weights[idx++];
1077
+ for (let i = 0; i < this.Wk.W.length; i++)
1078
+ for (let j = 0; j < this.Wk.W[i].length; j++) this.Wk.W[i][j] = weights[idx++];
1079
+ for (let i = 0; i < this.Wv.W.length; i++)
1080
+ for (let j = 0; j < this.Wv.W[i].length; j++) this.Wv.W[i][j] = weights[idx++];
1081
+ }
677
1082
  };
678
1083
 
679
1084
  // src/MultiHeadAttention.ts
680
1085
  var MultiHeadAttention = class {
681
1086
  // seqLen × (nHeads * d_k)
682
- constructor(d_model, nHeads) {
1087
+ constructor(d_model, nHeads, causal = false) {
683
1088
  // d_model × (nHeads * d_k)
684
1089
  // Cached for backward
685
1090
  this._concat = null;
686
1091
  this.nHeads = nHeads;
687
1092
  this.d_model = d_model;
688
1093
  this.d_k = Math.floor(d_model / nHeads);
1094
+ this.causal = causal;
689
1095
  this.heads = Array.from(
690
1096
  { length: nHeads },
691
- () => new AttentionHead(d_model, this.d_k, this.d_k)
1097
+ () => new AttentionHead(d_model, this.d_k, this.d_k, causal)
692
1098
  );
693
1099
  this.Wo = new WeightMatrix(d_model, nHeads * this.d_k);
694
1100
  }
@@ -747,6 +1153,31 @@ var MultiHeadAttention = class {
747
1153
  getAttentionWeights() {
748
1154
  return this.heads.map((h) => h.getAttentionWeights());
749
1155
  }
1156
+ // ── Flat weight serialization ─────────────────────────────────────────────
1157
+ // Order: head0 (Wq, Wk, Wv), head1, ..., headN, then Wo.
1158
+ getWeights() {
1159
+ const w = [];
1160
+ for (const head of this.heads) {
1161
+ for (const row of head.Wq.W) w.push(...row);
1162
+ for (const row of head.Wk.W) w.push(...row);
1163
+ for (const row of head.Wv.W) w.push(...row);
1164
+ }
1165
+ for (const row of this.Wo.W) w.push(...row);
1166
+ return w;
1167
+ }
1168
+ setWeights(weights) {
1169
+ let idx = 0;
1170
+ for (const head of this.heads) {
1171
+ for (let i = 0; i < head.Wq.W.length; i++)
1172
+ for (let j = 0; j < head.Wq.W[i].length; j++) head.Wq.W[i][j] = weights[idx++];
1173
+ for (let i = 0; i < head.Wk.W.length; i++)
1174
+ for (let j = 0; j < head.Wk.W[i].length; j++) head.Wk.W[i][j] = weights[idx++];
1175
+ for (let i = 0; i < head.Wv.W.length; i++)
1176
+ for (let j = 0; j < head.Wv.W[i].length; j++) head.Wv.W[i][j] = weights[idx++];
1177
+ }
1178
+ for (let i = 0; i < this.Wo.W.length; i++)
1179
+ for (let j = 0; j < this.Wo.W[i].length; j++) this.Wo.W[i][j] = weights[idx++];
1180
+ }
750
1181
  };
751
1182
 
752
1183
  // src/LayerNorm.ts
@@ -798,11 +1229,21 @@ var LayerNorm = class {
798
1229
  const mDxn = D.reduce((s, d, i) => s + d * x_norm[i], 0) / N;
799
1230
  return D.map((d, i) => (d - mD - x_norm[i] * mDxn) / std);
800
1231
  }
1232
+ // ── Flat weight serialization ─────────────────────────────────────────────
1233
+ // Order: gamma, beta.
1234
+ getWeights() {
1235
+ return [...this.gamma, ...this.beta];
1236
+ }
1237
+ setWeights(weights) {
1238
+ const dim = this.gamma.length;
1239
+ for (let i = 0; i < dim; i++) this.gamma[i] = weights[i];
1240
+ for (let i = 0; i < dim; i++) this.beta[i] = weights[dim + i];
1241
+ }
801
1242
  };
802
1243
 
803
1244
  // src/TransformerBlock.ts
804
1245
  var TransformerBlock = class {
805
- constructor({ d_model, nHeads, d_ff }) {
1246
+ constructor({ d_model, nHeads, d_ff, causal = false }) {
806
1247
  // Forward caches (needed for backprop)
807
1248
  this._X = null;
808
1249
  this._attnOut = null;
@@ -814,7 +1255,7 @@ var TransformerBlock = class {
814
1255
  this._ff2Out = null;
815
1256
  this.d_model = d_model;
816
1257
  this.d_ff = d_ff;
817
- this.attn = new MultiHeadAttention(d_model, nHeads);
1258
+ this.attn = new MultiHeadAttention(d_model, nHeads, causal);
818
1259
  this.norm1 = new LayerNorm(d_model);
819
1260
  this.norm2 = new LayerNorm(d_model);
820
1261
  this.ff1 = new WeightMatrix(d_ff, d_model);
@@ -927,6 +1368,35 @@ var TransformerBlock = class {
927
1368
  getAttentionWeights() {
928
1369
  return this.attn.getAttentionWeights();
929
1370
  }
1371
+ // ── Flat weight serialization ─────────────────────────────────────────────
1372
+ // Order: attn (MHA), norm1 (gamma, beta), ff1, b1, ff2, b2, norm2 (gamma, beta).
1373
+ getWeights() {
1374
+ const w = [];
1375
+ w.push(...this.attn.getWeights());
1376
+ w.push(...this.norm1.gamma, ...this.norm1.beta);
1377
+ for (const row of this.ff1.W) w.push(...row);
1378
+ w.push(...this.b1);
1379
+ for (const row of this.ff2.W) w.push(...row);
1380
+ w.push(...this.b2);
1381
+ w.push(...this.norm2.gamma, ...this.norm2.beta);
1382
+ return w;
1383
+ }
1384
+ setWeights(weights) {
1385
+ let idx = 0;
1386
+ const attnLen = this.attn.getWeights().length;
1387
+ this.attn.setWeights(weights.slice(idx, idx + attnLen));
1388
+ idx += attnLen;
1389
+ for (let i = 0; i < this.norm1.gamma.length; i++) this.norm1.gamma[i] = weights[idx++];
1390
+ for (let i = 0; i < this.norm1.beta.length; i++) this.norm1.beta[i] = weights[idx++];
1391
+ for (let i = 0; i < this.ff1.W.length; i++)
1392
+ for (let j = 0; j < this.ff1.W[i].length; j++) this.ff1.W[i][j] = weights[idx++];
1393
+ for (let i = 0; i < this.b1.length; i++) this.b1[i] = weights[idx++];
1394
+ for (let i = 0; i < this.ff2.W.length; i++)
1395
+ for (let j = 0; j < this.ff2.W[i].length; j++) this.ff2.W[i][j] = weights[idx++];
1396
+ for (let i = 0; i < this.b2.length; i++) this.b2[i] = weights[idx++];
1397
+ for (let i = 0; i < this.norm2.gamma.length; i++) this.norm2.gamma[i] = weights[idx++];
1398
+ for (let i = 0; i < this.norm2.beta.length; i++) this.norm2.beta[i] = weights[idx++];
1399
+ }
930
1400
  };
931
1401
 
932
1402
  // src/NetworkTransformer.ts
@@ -1025,6 +1495,32 @@ var NetworkTransformer = class {
1025
1495
  getAttentionWeights() {
1026
1496
  return this.blocks.map((b) => b.getAttentionWeights());
1027
1497
  }
1498
+ // ── Flat weight serialization ─────────────────────────────────────────────
1499
+ // Order: tokenEmb, posEmb, block0, block1, ..., blockN, outputProj, outputBias.
1500
+ getWeights() {
1501
+ const w = [];
1502
+ for (const row of this.tokenEmb.W) w.push(...row);
1503
+ for (const row of this.posEmb.W) w.push(...row);
1504
+ for (const block of this.blocks) w.push(...block.getWeights());
1505
+ for (const row of this.outputProj.W) w.push(...row);
1506
+ w.push(...this.outputBias);
1507
+ return w;
1508
+ }
1509
+ setWeights(weights) {
1510
+ let idx = 0;
1511
+ for (let i = 0; i < this.tokenEmb.W.length; i++)
1512
+ for (let j = 0; j < this.tokenEmb.W[i].length; j++) this.tokenEmb.W[i][j] = weights[idx++];
1513
+ for (let i = 0; i < this.posEmb.W.length; i++)
1514
+ for (let j = 0; j < this.posEmb.W[i].length; j++) this.posEmb.W[i][j] = weights[idx++];
1515
+ for (const block of this.blocks) {
1516
+ const blockLen = block.getWeights().length;
1517
+ block.setWeights(weights.slice(idx, idx + blockLen));
1518
+ idx += blockLen;
1519
+ }
1520
+ for (let i = 0; i < this.outputProj.W.length; i++)
1521
+ for (let j = 0; j < this.outputProj.W[i].length; j++) this.outputProj.W[i][j] = weights[idx++];
1522
+ for (let i = 0; i < this.outputBias.length; i++) this.outputBias[i] = weights[idx++];
1523
+ }
1028
1524
  // ── Internal ──────────────────────────────────────────────────────────────
1029
1525
  // Shared embedding + block forward pass.
1030
1526
  _forward(tokens) {
@@ -1044,21 +1540,25 @@ var NetworkTransformerRL = class {
1044
1540
  constructor(seqLen, inputDim, options = {}) {
1045
1541
  // Forward caches para backprop
1046
1542
  this._projected = null;
1543
+ // For max pooling backward: argmax per dimension across all positions
1544
+ this._argmax = null;
1047
1545
  const {
1048
1546
  d_model = 32,
1049
1547
  nHeads = 2,
1050
1548
  d_ff = 64,
1051
1549
  nBlocks = 2,
1052
- nActions = 2
1550
+ nActions = 2,
1551
+ pooling = "weighted"
1053
1552
  } = options;
1054
1553
  this.seqLen = seqLen;
1055
1554
  this.inputDim = inputDim;
1056
1555
  this.d_model = d_model;
1057
1556
  this.nActions = nActions;
1557
+ this._pooling = pooling;
1058
1558
  this.inputProj = new WeightMatrix(d_model, inputDim);
1059
1559
  this.blocks = Array.from(
1060
1560
  { length: nBlocks },
1061
- () => new TransformerBlock({ d_model, nHeads, d_ff })
1561
+ () => new TransformerBlock({ d_model, nHeads, d_ff, causal: true })
1062
1562
  );
1063
1563
  this.outputProj = new WeightMatrix(nActions, d_model);
1064
1564
  this.outputBias = new Array(nActions).fill(0);
@@ -1107,11 +1607,7 @@ var NetworkTransformerRL = class {
1107
1607
  this.outputProj.update(dWout, lr);
1108
1608
  for (let c = 0; c < this.nActions; c++)
1109
1609
  this.outputBias[c] = this.outBiasOpts[c].step(this.outputBias[c], dBout[c], lr);
1110
- let dH = Array.from(
1111
- { length: this.seqLen },
1112
- (_, i) => dPooled.map((v) => v / this.seqLen)
1113
- // Gradiente dividido entre posiciones
1114
- );
1610
+ let dH = this._distributePoolGradient(dPooled);
1115
1611
  for (let b = this.blocks.length - 1; b >= 0; b--)
1116
1612
  dH = this.blocks[b].backward(dH, lr);
1117
1613
  for (let i = 0; i < this.seqLen; i++) {
@@ -1130,8 +1626,30 @@ var NetworkTransformerRL = class {
1130
1626
  getAttentionWeights() {
1131
1627
  return this.blocks.map((b) => b.getAttentionWeights());
1132
1628
  }
1133
- // ── Serialization ──────────────────────────────────────────────────────────
1134
- getWeights() {
1629
+ // ── Flat weight serialization ─────────────────────────────────────────────
1630
+ // Order: inputProj, block0, block1, ..., blockN, outputProj, outputBias.
1631
+ getWeightsFlat() {
1632
+ const w = [];
1633
+ for (const row of this.inputProj.W) w.push(...row);
1634
+ for (const block of this.blocks) w.push(...block.getWeights());
1635
+ for (const row of this.outputProj.W) w.push(...row);
1636
+ w.push(...this.outputBias);
1637
+ return w;
1638
+ }
1639
+ setWeightsFlat(weights) {
1640
+ let idx = 0;
1641
+ for (let i = 0; i < this.inputProj.W.length; i++)
1642
+ for (let j = 0; j < this.inputProj.W[i].length; j++) this.inputProj.W[i][j] = weights[idx++];
1643
+ for (const block of this.blocks) {
1644
+ const blockLen = block.getWeights().length;
1645
+ block.setWeights(weights.slice(idx, idx + blockLen));
1646
+ idx += blockLen;
1647
+ }
1648
+ for (let i = 0; i < this.outputProj.W.length; i++)
1649
+ for (let j = 0; j < this.outputProj.W[i].length; j++) this.outputProj.W[i][j] = weights[idx++];
1650
+ for (let i = 0; i < this.outputBias.length; i++) this.outputBias[i] = weights[idx++];
1651
+ }
1652
+ getWeightsStructured() {
1135
1653
  return {
1136
1654
  inputProj: this.inputProj.W.map((r) => [...r]),
1137
1655
  blocks: this.blocks.map((b) => ({
@@ -1154,7 +1672,7 @@ var NetworkTransformerRL = class {
1154
1672
  outputBias: [...this.outputBias]
1155
1673
  };
1156
1674
  }
1157
- setWeights(data) {
1675
+ setWeightsStructured(data) {
1158
1676
  data.inputProj.forEach((row, i) => {
1159
1677
  this.inputProj.W[i] = [...row];
1160
1678
  });
@@ -1178,6 +1696,15 @@ var NetworkTransformerRL = class {
1178
1696
  this.outputProj.W = data.outputProj.map((r) => [...r]);
1179
1697
  this.outputBias = [...data.outputBias];
1180
1698
  }
1699
+ // ── Serializable interface (flat array) ────────────────────────────────────
1700
+ // These satisfy the Serializable interface from ModelSaver, which requires
1701
+ // getWeights(): number[] and setWeights(weights: number[]): void.
1702
+ getWeights() {
1703
+ return this.getWeightsFlat();
1704
+ }
1705
+ setWeights(weights) {
1706
+ this.setWeightsFlat(weights);
1707
+ }
1181
1708
  // ── Internal ────────────────────────────────────────────────────────────────
1182
1709
  _forward(sequence) {
1183
1710
  let h = sequence.map(
@@ -1191,6 +1718,44 @@ var NetworkTransformerRL = class {
1191
1718
  return h;
1192
1719
  }
1193
1720
  _pool(h) {
1721
+ switch (this._pooling) {
1722
+ case "avg":
1723
+ return this._poolAvg(h);
1724
+ case "max":
1725
+ return this._poolMax(h);
1726
+ case "last":
1727
+ return this._poolLast(h);
1728
+ case "weighted":
1729
+ default:
1730
+ return this._poolWeighted(h);
1731
+ }
1732
+ }
1733
+ _poolAvg(h) {
1734
+ const n = h.length;
1735
+ return Array.from({ length: this.d_model }, (_, m) => {
1736
+ let sum = 0;
1737
+ for (let i = 0; i < n; i++)
1738
+ sum += h[i][m];
1739
+ return sum / n;
1740
+ });
1741
+ }
1742
+ _poolMax(h) {
1743
+ this._argmax = new Array(this.d_model).fill(0);
1744
+ return Array.from({ length: this.d_model }, (_, m) => {
1745
+ let maxVal = -Infinity;
1746
+ for (let i = 0; i < h.length; i++) {
1747
+ if (h[i][m] > maxVal) {
1748
+ maxVal = h[i][m];
1749
+ this._argmax[m] = i;
1750
+ }
1751
+ }
1752
+ return maxVal;
1753
+ });
1754
+ }
1755
+ _poolLast(h) {
1756
+ return [...h[h.length - 1]];
1757
+ }
1758
+ _poolWeighted(h) {
1194
1759
  const weights = Array.from(
1195
1760
  { length: this.seqLen },
1196
1761
  (_, i) => i === this.seqLen - 1 ? 2 : 1
@@ -1203,6 +1768,55 @@ var NetworkTransformerRL = class {
1203
1768
  return sum / totalWeight;
1204
1769
  });
1205
1770
  }
1771
+ /** Returns the current pooling type for inspection. */
1772
+ getPoolingType() {
1773
+ return this._pooling;
1774
+ }
1775
+ // ── Helper: distribute pooled gradient back to each position ────────────────
1776
+ // Must match the same distribution as _pool() used during forward.
1777
+ _distributePoolGradient(dPooled) {
1778
+ switch (this._pooling) {
1779
+ case "avg": {
1780
+ const n = this.seqLen;
1781
+ return Array.from(
1782
+ { length: n },
1783
+ () => dPooled.map((v) => v / n)
1784
+ );
1785
+ }
1786
+ case "max": {
1787
+ if (!this._argmax) {
1788
+ const n = this.seqLen;
1789
+ return Array.from(
1790
+ { length: n },
1791
+ () => dPooled.map((v) => v / n)
1792
+ );
1793
+ }
1794
+ const argmax = this._argmax;
1795
+ return Array.from(
1796
+ { length: this.seqLen },
1797
+ (_, i) => dPooled.map((v, m) => i === argmax[m] ? v : 0)
1798
+ );
1799
+ }
1800
+ case "last": {
1801
+ return Array.from(
1802
+ { length: this.seqLen },
1803
+ (_, i) => i === this.seqLen - 1 ? [...dPooled] : new Array(this.d_model).fill(0)
1804
+ );
1805
+ }
1806
+ case "weighted":
1807
+ default: {
1808
+ const weights = Array.from(
1809
+ { length: this.seqLen },
1810
+ (_, i) => i === this.seqLen - 1 ? 2 : 1
1811
+ );
1812
+ const totalWeight = weights.reduce((a, b) => a + b, 0);
1813
+ return Array.from(
1814
+ { length: this.seqLen },
1815
+ (_, i) => dPooled.map((v) => v * weights[i] / totalWeight)
1816
+ );
1817
+ }
1818
+ }
1819
+ }
1206
1820
  };
1207
1821
 
1208
1822
  // src/losses.ts
@@ -1227,13 +1841,801 @@ function crossEntropyDeltaRaw(predicted, actual) {
1227
1841
  const p = Math.max(eps, Math.min(1 - eps, predicted));
1228
1842
  return actual / p - (1 - actual) / (1 - p);
1229
1843
  }
1844
+
1845
+ // src/GRU.ts
1846
/**
 * Logistic sigmoid, 1 / (1 + e^(-x)).
 * @param {number} x
 * @returns {number} value in [0, 1].
 */
function sigmoid4(x) {
  const decay = Math.exp(-x);
  return 1 / (1 + decay);
}
1849
/**
 * Hyperbolic tangent.
 *
 * Uses the built-in Math.tanh: the hand-rolled (e^(2x) - 1) / (e^(2x) + 1)
 * form overflows to Infinity / Infinity = NaN for large positive x.
 * @param {number} x
 * @returns {number} value in [-1, 1].
 */
function tanhFn(x) {
  return Math.tanh(x);
}
1853
/**
 * One gate's affine map: an `hSize x (inputSize + hSize)` weight matrix and
 * a bias vector of length `hSize`.
 */
var Gate2 = class {
  /**
   * @param {number} inputSize number of input features.
   * @param {number} hSize number of hidden units (rows of W).
   * @param {number} [initBias=0] initial value for every bias entry.
   */
  constructor(inputSize, hSize, initBias = 0) {
    const fanIn = inputSize + hSize;
    // Weights are drawn uniformly from [-limit, limit).
    const limit = Math.sqrt(2 / fanIn);
    const matrix = [];
    for (let r = 0; r < hSize; r++) {
      const row = [];
      for (let c = 0; c < fanIn; c++) {
        row.push((Math.random() * 2 - 1) * limit);
      }
      matrix.push(row);
    }
    this.W = matrix;
    this.b = new Array(hSize).fill(initBias);
  }

  /**
   * Computes W · combined + b.
   * @param {number[]} combined concatenated input vector.
   * @returns {number[]} one pre-activation per row of W.
   */
  linear(combined) {
    return this.W.map((row, i) => {
      let acc = this.b[i];
      for (let j = 0; j < row.length; j++) {
        acc += row[j] * combined[j];
      }
      return acc;
    });
  }
};
1869
/**
 * Gated recurrent unit layer with reset, update and candidate ("new") gates.
 *
 * predict() advances the hidden state one step and records a trajectory
 * entry; backprop() consumes the recorded trajectory to update the gates.
 */
var GRULayer = class {
  /**
   * @param {number} inputSize number of input features per step (> 0).
   * @param {number} hiddenSize number of hidden units (> 0).
   * @param {Function} [optimizerFactory] called once per trainable scalar;
   *   must return an object with `step(value, gradient, lr)`.
   * @throws {Error} if either size is not positive.
   */
  constructor(inputSize, hiddenSize, optimizerFactory = () => new SGD()) {
    this._traj = [];
    if (inputSize <= 0 || hiddenSize <= 0) {
      throw new Error(`GRULayer: inputSize and hiddenSize must be positive, got ${inputSize} and ${hiddenSize}`);
    }
    this.inputSize = inputSize;
    this.hSize = hiddenSize;
    this.h = new Array(hiddenSize).fill(0);
    this.resetGate = new Gate2(inputSize, hiddenSize);
    this.updateGate = new Gate2(inputSize, hiddenSize);
    this.newGate = new Gate2(inputSize, hiddenSize);
    const combSize = inputSize + hiddenSize;
    // One optimizer instance per weight and per bias.
    this._optimizers = {
      resetW: Array.from(
        { length: hiddenSize },
        () => Array.from({ length: combSize }, () => optimizerFactory())
      ),
      resetB: Array.from({ length: hiddenSize }, () => optimizerFactory()),
      updateW: Array.from(
        { length: hiddenSize },
        () => Array.from({ length: combSize }, () => optimizerFactory())
      ),
      updateB: Array.from({ length: hiddenSize }, () => optimizerFactory()),
      newW: Array.from(
        { length: hiddenSize },
        () => Array.from({ length: combSize }, () => optimizerFactory())
      ),
      newB: Array.from({ length: hiddenSize }, () => optimizerFactory())
    };
  }

  /** Zeroes the hidden state and discards the recorded trajectory. */
  reset() {
    this.h = new Array(this.hSize).fill(0);
    this._traj = [];
  }

  /**
   * Advances the layer by one time step.
   * @param {number[]} inputs vector of length `inputSize`.
   * @returns {number[]} the new hidden state (also stored in `this.h`).
   * @throws {Error} if `inputs` is not an array of length `inputSize`.
   */
  predict(inputs) {
    if (!Array.isArray(inputs) || inputs.length !== this.inputSize) {
      throw new Error(`GRULayer.predict: expected array of length ${this.inputSize}, got ${inputs?.length}`);
    }
    const combined = [...inputs, ...this.h];
    const h_prev = [...this.h];
    const r_pre = this.resetGate.linear(combined);
    const z_pre = this.updateGate.linear(combined);
    const r_a = r_pre.map(sigmoid4);
    const z_a = z_pre.map(sigmoid4);
    // The candidate gate sees the previous state scaled by the reset gate.
    const combined_r = [...inputs, ...r_a.map((r, i) => r * h_prev[i])];
    const n_pre = this.newGate.linear(combined_r);
    const n_a = n_pre.map(tanhFn);
    // h = (1 - z) * n + z * h_prev
    const h = n_a.map((n, i) => (1 - z_a[i]) * n + z_a[i] * h_prev[i]);
    this._traj.push({ combined, h_prev, r: r_pre, r_a, z: z_pre, z_a, combined_r, n_pre, n_a, h });
    this.h = h;
    return h;
  }

  /**
   * Backpropagation through time over the recorded trajectory.
   * Returns without doing anything when no steps are recorded or when
   * `dh_seq` does not contain exactly one gradient per recorded step.
   * @param {number[][]} dh_seq gradient w.r.t. the hidden state at each step.
   * @param {number} lr learning rate; each optimizer receives `lr / T`.
   */
  backprop(dh_seq, lr) {
    const T = this._traj.length;
    if (T === 0 || dh_seq.length !== T) return;
    const hSize = this.hSize;
    const combSize = this.inputSize + hSize;
    const dWr = Array.from({ length: hSize }, () => new Array(combSize).fill(0));
    const dWz = Array.from({ length: hSize }, () => new Array(combSize).fill(0));
    const dWn = Array.from({ length: hSize }, () => new Array(combSize).fill(0));
    const dbr = new Array(hSize).fill(0);
    const dbz = new Array(hSize).fill(0);
    const dbn = new Array(hSize).fill(0);
    // Gradient flowing into h_t from step t + 1.
    let dh_next = new Array(hSize).fill(0);
    for (let t = T - 1; t >= 0; t--) {
      const s = this._traj[t];
      const dh = dh_seq[t].map((d, i) => d + dh_next[i]);
      // h = (1 - z) * n + z * h_prev
      const dz_a = dh.map((d, i) => (s.h_prev[i] - s.n_a[i]) * d);
      const dn_a = dh.map((d, i) => (1 - s.z_a[i]) * d);
      const dn_pre = dn_a.map((d, i) => d * (1 - s.n_a[i] ** 2));
      const dz_pre = dz_a.map((d, i) => d * s.z_a[i] * (1 - s.z_a[i]));
      // Gradient w.r.t. the (r * h_prev) part of the candidate gate's input.
      const dr_hprev = Array.from(
        { length: hSize },
        (_, i) => this.newGate.W.reduce((sum, row, k) => sum + dn_pre[k] * row[this.inputSize + i], 0)
      );
      const dr_a = dr_hprev.map((d, i) => d * s.h_prev[i]);
      const dr_pre = dr_a.map((d, i) => d * s.r_a[i] * (1 - s.r_a[i]));
      for (let k = 0; k < hSize; k++) {
        for (let j = 0; j < combSize; j++) {
          dWr[k][j] += dr_pre[k] * s.combined[j];
          dWz[k][j] += dz_pre[k] * s.combined[j];
          dWn[k][j] += dn_pre[k] * s.combined_r[j];
        }
        dbr[k] += dr_pre[k];
        dbz[k] += dz_pre[k];
        dbn[k] += dn_pre[k];
      }
      // Accumulate every path from this step back to h_prev.
      dh_next = new Array(hSize).fill(0);
      for (let k = 0; k < hSize; k++) {
        // Through the hidden columns of the reset and update gates.
        for (let j = this.inputSize; j < combSize; j++) {
          dh_next[j - this.inputSize] += dr_pre[k] * this.resetGate.W[k][j] + dz_pre[k] * this.updateGate.W[k][j];
        }
        // Through the reset-scaled state fed to the candidate gate.
        dh_next[k] += dr_hprev[k] * s.r_a[k];
        // Through the direct z * h_prev carry term.
        dh_next[k] += dh[k] * s.z_a[k];
      }
    }
    const scale = lr / T;
    for (let k = 0; k < hSize; k++) {
      for (let j = 0; j < combSize; j++) {
        this.resetGate.W[k][j] = this._optimizers.resetW[k][j].step(this.resetGate.W[k][j], dWr[k][j], scale);
        this.updateGate.W[k][j] = this._optimizers.updateW[k][j].step(this.updateGate.W[k][j], dWz[k][j], scale);
        this.newGate.W[k][j] = this._optimizers.newW[k][j].step(this.newGate.W[k][j], dWn[k][j], scale);
      }
      this.resetGate.b[k] = this._optimizers.resetB[k].step(this.resetGate.b[k], dbr[k], scale);
      this.updateGate.b[k] = this._optimizers.updateB[k].step(this.updateGate.b[k], dbz[k], scale);
      this.newGate.b[k] = this._optimizers.newB[k].step(this.newGate.b[k], dbn[k], scale);
    }
    this._traj = [];
  }

  // ── Flat weight serialization ─────────────────────────────────────────────
  // Order: resetGate (W, b), updateGate (W, b), newGate (W, b).
  /**
   * @returns {number[]} a new flat array of all gate weights and biases.
   */
  getWeightsFlat() {
    const w = [];
    for (const gate of [this.resetGate, this.updateGate, this.newGate]) {
      for (const row of gate.W) {
        for (const v of row) w.push(v);
      }
      for (const v of gate.b) w.push(v);
    }
    return w;
  }

  /**
   * Loads weights from a flat array in getWeightsFlat() order.
   * @param {number[]} weights at least as many values as getWeightsFlat() emits.
   * @throws {Error} if `weights` is too short (it would otherwise write
   *   `undefined` into the gates).
   */
  setWeightsFlat(weights) {
    const expected = this.getWeightsFlat().length;
    if (weights.length < expected) {
      throw new Error(`GRULayer.setWeightsFlat: expected at least ${expected} weights, got ${weights.length}`);
    }
    let idx = 0;
    for (const gate of [this.resetGate, this.updateGate, this.newGate]) {
      for (let i = 0; i < gate.W.length; i++)
        for (let j = 0; j < gate.W[i].length; j++) gate.W[i][j] = weights[idx++];
      for (let i = 0; i < gate.b.length; i++) gate.b[i] = weights[idx++];
    }
  }

  /**
   * Structured snapshot of the gate parameters.
   * Returns copies, so mutating the result does not alter the layer.
   * @returns {{resetGate: {W: number[][], b: number[]}, updateGate: {W: number[][], b: number[]}, newGate: {W: number[][], b: number[]}}}
   */
  getWeights() {
    const snapshot = (gate) => ({ W: gate.W.map((row) => [...row]), b: [...gate.b] });
    return {
      resetGate: snapshot(this.resetGate),
      updateGate: snapshot(this.updateGate),
      newGate: snapshot(this.newGate)
    };
  }

  /**
   * Restores gate parameters from a structured snapshot.
   * The data is copied, so the layer never shares arrays with the caller
   * (or with another layer loaded from the same snapshot).
   * @param data object shaped like the result of getWeights().
   */
  setWeights(data) {
    const load = (gate, src) => {
      gate.W = src.W.map((row) => [...row]);
      gate.b = [...src.b];
    };
    load(this.resetGate, data.resetGate);
    load(this.updateGate, data.updateGate);
    load(this.newGate, data.newGate);
  }
};
2019
+
2020
+ // src/BatchNorm.ts
2021
/**
 * Per-feature normalization driven by running statistics.
 *
 * Every forward() call first folds the sample into the running mean and
 * variance, then normalizes the sample with the updated statistics and
 * applies the learnable scale (`gamma`) and shift (`beta`).
 *
 * NOTE(review): the running update is `momentum * old + (1 - momentum) * new`,
 * so the default momentum of 0.1 weights the newest sample at 0.9 — confirm
 * that this is the intended convention.
 */
var BatchNorm = class {
  /**
   * @param {number} dim number of features.
   * @param {number} [momentum=0.1] weight kept on the previous running statistic.
   */
  constructor(dim, momentum = 0.1) {
    this._xNorm = null;
    this._std = null;
    this.dim = dim;
    this.momentum = momentum;
    this.gamma = new Array(dim).fill(1);
    this.beta = new Array(dim).fill(0);
    this.runningMean = new Array(dim).fill(0);
    this.runningVar = new Array(dim).fill(1);
  }

  // ── Forward ───────────────────────────────────────────────────────────────
  /**
   * Updates the running statistics with `x` and returns the normalized,
   * scaled and shifted vector. Caches the normalized input and the standard
   * deviation for backward() and trainParams().
   * @param {number[]} x vector of length `dim`.
   * @returns {number[]} vector of length `dim`.
   * @throws {Error} if `x` does not have length `dim`.
   */
  forward(x) {
    if (x.length !== this.dim) {
      throw new Error(`BatchNorm.forward: expected array of length ${this.dim}, got ${x.length}`);
    }
    const eps = 1e-5;
    for (let i = 0; i < this.dim; i++) {
      this.runningMean[i] = this.momentum * this.runningMean[i] + (1 - this.momentum) * x[i];
      // The deviation is measured against the already-updated mean.
      const diff = x[i] - this.runningMean[i];
      this.runningVar[i] = this.momentum * this.runningVar[i] + (1 - this.momentum) * diff * diff;
    }
    this._std = this.runningVar.map((v) => Math.sqrt(v + eps));
    this._xNorm = x.map((v, i) => (v - this.runningMean[i]) / this._std[i]);
    return this._xNorm.map((xn, i) => this.gamma[i] * xn + this.beta[i]);
  }

  // ── Backward ──────────────────────────────────────────────────────────────
  /**
   * Gradient with respect to the input, treating the running statistics as
   * constants: dOut * gamma / std.
   * @param {number[]} dOut gradient with respect to the output of forward().
   * @returns {number[]} gradient with respect to the input.
   * @throws {Error} if forward() has not been called yet.
   */
  backward(dOut) {
    if (!this._xNorm || !this._std) {
      throw new Error("BatchNorm.backward: call forward() first");
    }
    return dOut.map((d, i) => d * this.gamma[i] / this._std[i]);
  }

  // ── Train gamma and beta (call after backward) ────────────────────────────
  /**
   * Adjusts gamma and beta in place using the cached normalized input.
   * Does nothing if forward() has not been called.
   * @param {number[]} dOut per-feature signal; it is added to the parameters,
   *   scaled by `lr` (and by the normalized input for gamma).
   * @param {number} lr learning rate.
   */
  trainParams(dOut, lr) {
    if (!this._xNorm) return;
    for (let i = 0; i < this.dim; i++) {
      this.gamma[i] += lr * dOut[i] * this._xNorm[i];
      this.beta[i] += lr * dOut[i];
    }
  }

  // ── Flat weight serialization ─────────────────────────────────────────────
  // Order: gamma, beta.
  /**
   * @returns {number[]} new array `[...gamma, ...beta]` of length `2 * dim`.
   */
  getWeights() {
    return [...this.gamma, ...this.beta];
  }

  /**
   * Loads gamma and beta from a flat array in getWeights() order.
   * @param {number[]} weights at least `2 * dim` values.
   * @throws {Error} if `weights` is too short (it would otherwise write
   *   `undefined` into gamma/beta).
   */
  setWeights(weights) {
    const expected = 2 * this.dim;
    if (weights.length < expected) {
      throw new Error(`BatchNorm.setWeights: expected at least ${expected} weights, got ${weights.length}`);
    }
    for (let i = 0; i < this.dim; i++) this.gamma[i] = weights[i];
    for (let i = 0; i < this.dim; i++) this.beta[i] = weights[this.dim + i];
  }
};
2074
+
2075
+ // src/Conv1D.ts
2076
/**
 * 1-D convolution layer over a sequence of `inputLength` positions with
 * `inputChannels` values per position.
 *
 * Kernels are stored as `kernels[filter][k][channel]`; forward() returns
 * `output[filter][position]`.
 */
var Conv1D = class {
  /**
   * @param {number} inputLength number of input positions (> 0).
   * @param {number} kernelSize kernel width (> 0).
   * @param {number} filters number of output filters (> 0).
   * @param {number} [stride=1] step between kernel placements (>= 1).
   * @param {string} [padding="valid"] "same" zero-pads the input; any other
   *   value applies no padding.
   * @param {Function} [optimizerFactory] called once per trainable scalar;
   *   must return an object with `step(value, gradient, lr)`.
   * @param {number} [inputChannels=1] values per input position (>= 1).
   * @throws {Error} on non-positive sizes, a kernel wider than the input with
   *   valid padding, fewer than one channel, or a stride below 1.
   */
  constructor(inputLength, kernelSize, filters, stride = 1, padding = "valid", optimizerFactory = () => new SGD(), inputChannels = 1) {
    this._input = null;
    this._paddedInput = null;
    if (inputLength <= 0 || kernelSize <= 0 || filters <= 0) {
      throw new Error("Conv1D: inputLength, kernelSize, and filters must be positive");
    }
    if (kernelSize > inputLength && padding === "valid") {
      throw new Error("Conv1D: kernelSize cannot exceed inputLength with valid padding");
    }
    if (inputChannels < 1) {
      throw new Error("Conv1D: inputChannels must be >= 1");
    }
    // A zero/negative/NaN stride makes the output length meaningless.
    if (!(stride >= 1)) {
      throw new Error("Conv1D: stride must be >= 1");
    }
    this.inputLength = inputLength;
    this.kernelSize = kernelSize;
    this.filters = filters;
    this.stride = stride;
    this.padding = padding;
    this.inputChannels = inputChannels;
    // Kernel weights are drawn uniformly from [-limit, limit).
    const limit = Math.sqrt(2 / (kernelSize * inputChannels));
    this.kernels = Array.from(
      { length: filters },
      () => Array.from(
        { length: kernelSize },
        () => Array.from({ length: inputChannels }, () => (Math.random() * 2 - 1) * limit)
      )
    );
    this.biases = new Array(filters).fill(0);
    // One optimizer instance per kernel weight and per bias.
    this._kOpts = Array.from(
      { length: filters },
      () => Array.from(
        { length: kernelSize },
        () => Array.from({ length: inputChannels }, () => optimizerFactory())
      )
    );
    this._bOpts = Array.from({ length: filters }, () => optimizerFactory());
  }

  // ── Forward ───────────────────────────────────────────────────────────────
  // Accepts either number[] (when inputChannels=1) or number[][] (multi-channel).
  /**
   * @param {number[]|number[][]} input sequence of `inputLength` positions.
   * @returns {number[][]} `output[filter][position]`.
   * @throws {Error} if the input shape does not match the layer.
   */
  forward(input) {
    const input2D = this._normalizeInput(input);
    this._input = input2D.map((row) => [...row]);
    let padded;
    if (this.padding === "same") {
      // Total padding of kernelSize - 1 makes the output length equal
      // ceil(inputLength / stride), as getOutputLength() reports. For even
      // kernels the extra zero goes on the right; padding each side with
      // floor((k - 1) / 2) would produce one output too few.
      const padLeft = Math.floor((this.kernelSize - 1) / 2);
      const padRight = this.kernelSize - 1 - padLeft;
      const zeroRow = () => new Array(this.inputChannels).fill(0);
      padded = [
        ...Array.from({ length: padLeft }, zeroRow),
        ...input2D,
        ...Array.from({ length: padRight }, zeroRow)
      ];
    } else {
      padded = input2D;
    }
    this._paddedInput = padded;
    const outputLength = Math.floor((padded.length - this.kernelSize) / this.stride) + 1;
    const output = Array.from(
      { length: this.filters },
      () => new Array(outputLength).fill(0)
    );
    for (let f = 0; f < this.filters; f++) {
      for (let pos = 0; pos < outputLength; pos++) {
        const start = pos * this.stride;
        let sum = this.biases[f];
        for (let k = 0; k < this.kernelSize; k++) {
          for (let c = 0; c < this.inputChannels; c++) {
            sum += this.kernels[f][k][c] * padded[start + k][c];
          }
        }
        output[f][pos] = sum;
      }
    }
    return output;
  }

  // ── Backward ──────────────────────────────────────────────────────────────
  /**
   * Accumulates kernel/bias gradients for the last forward() call, applies
   * them through the optimizers, and returns the gradient w.r.t. the input.
   * @param {number[][]} dOut gradient shaped like the forward() output.
   * @param {number} [lr=1e-3] learning rate passed to every optimizer step.
   * @returns {number[][]} `inputLength` rows of `inputChannels` gradients.
   * @throws {Error} if forward() has not been called yet.
   */
  backward(dOut, lr = 1e-3) {
    if (!this._paddedInput || !this._input) {
      throw new Error("Conv1D.backward: call forward() first");
    }
    const padded = this._paddedInput;
    const outputLength = dOut[0].length;
    const dKernels = Array.from(
      { length: this.filters },
      () => Array.from(
        { length: this.kernelSize },
        () => new Array(this.inputChannels).fill(0)
      )
    );
    const dBiases = new Array(this.filters).fill(0);
    const dPadded = padded.map(() => new Array(this.inputChannels).fill(0));
    for (let f = 0; f < this.filters; f++) {
      for (let pos = 0; pos < outputLength; pos++) {
        const start = pos * this.stride;
        dBiases[f] += dOut[f][pos];
        for (let k = 0; k < this.kernelSize; k++) {
          for (let c = 0; c < this.inputChannels; c++) {
            dKernels[f][k][c] += dOut[f][pos] * padded[start + k][c];
            // Uses the pre-update kernel: weights change only after this loop.
            dPadded[start + k][c] += dOut[f][pos] * this.kernels[f][k][c];
          }
        }
      }
    }
    for (let f = 0; f < this.filters; f++) {
      for (let k = 0; k < this.kernelSize; k++) {
        for (let c = 0; c < this.inputChannels; c++) {
          this.kernels[f][k][c] = this._kOpts[f][k][c].step(this.kernels[f][k][c], dKernels[f][k][c], lr);
        }
      }
      this.biases[f] = this._bOpts[f].step(this.biases[f], dBiases[f], lr);
    }
    if (this.padding === "same") {
      // Drop the gradient rows that belong to the zero padding.
      const padLeft = Math.floor((this.kernelSize - 1) / 2);
      return dPadded.slice(padLeft, padLeft + this.inputLength);
    }
    return dPadded.slice(0, this.inputLength);
  }

  // ── Output length ─────────────────────────────────────────────────────────
  /**
   * @returns {number} positions per filter produced by forward().
   */
  getOutputLength() {
    if (this.padding === "same") {
      return Math.ceil(this.inputLength / this.stride);
    }
    return Math.floor((this.inputLength - this.kernelSize) / this.stride) + 1;
  }

  // ── Flat weight serialization ─────────────────────────────────────────────
  // Order: kernels (flattened), biases.
  /**
   * @returns {number[]} new flat array: kernels in [filter][k][channel]
   *   order, followed by the biases.
   */
  getWeights() {
    const w = [];
    for (const kernel of this.kernels)
      for (const k of kernel)
        for (const c of k)
          w.push(c);
    for (const b of this.biases) w.push(b);
    return w;
  }

  /**
   * Loads kernels and biases from a flat array in getWeights() order.
   * @param {number[]} weights at least
   *   `filters * kernelSize * inputChannels + filters` values.
   * @throws {Error} if `weights` is too short (it would otherwise write
   *   `undefined` into the kernels/biases).
   */
  setWeights(weights) {
    const expected = this.filters * this.kernelSize * this.inputChannels + this.filters;
    if (weights.length < expected) {
      throw new Error(`Conv1D.setWeights: expected at least ${expected} weights, got ${weights.length}`);
    }
    let idx = 0;
    for (let f = 0; f < this.filters; f++)
      for (let k = 0; k < this.kernelSize; k++)
        for (let c = 0; c < this.inputChannels; c++)
          this.kernels[f][k][c] = weights[idx++];
    for (let f = 0; f < this.filters; f++)
      this.biases[f] = weights[idx++];
  }

  // ── Normalize input to 2D format ─────────────────────────────────────────
  /**
   * Validates the input and returns it as `[position][channel]`.
   * A flat number[] is accepted only when `inputChannels` is 1; each value
   * is wrapped into a one-element row.
   * @throws {Error} on empty input, wrong length, or wrong channel count.
   */
  _normalizeInput(input) {
    if (input.length === 0) {
      throw new Error("Conv1D.forward: input cannot be empty");
    }
    if (typeof input[0] === "number") {
      if (this.inputChannels !== 1) {
        throw new Error(`Conv1D.forward: expected 2D input with ${this.inputChannels} channels, got 1D`);
      }
      const input1D = input;
      if (input1D.length !== this.inputLength) {
        throw new Error(`Conv1D.forward: expected input of length ${this.inputLength}, got ${input1D.length}`);
      }
      return input1D.map((v) => [v]);
    }
    const input2D = input;
    if (input2D.length !== this.inputLength) {
      throw new Error(`Conv1D.forward: expected input of length ${this.inputLength}, got ${input2D.length}`);
    }
    for (let i = 0; i < input2D.length; i++) {
      if (input2D[i].length !== this.inputChannels) {
        throw new Error(`Conv1D.forward: expected ${this.inputChannels} channels at position ${i}, got ${input2D[i].length}`);
      }
    }
    return input2D;
  }
};
2243
+
2244
+ // src/Trainer.ts
2245
+ var Trainer = class {
2246
+ constructor(network, options = {}) {
2247
+ this._history = [];
2248
+ this._bestLoss = Infinity;
2249
+ this._patienceCounter = 0;
2250
+ this._stopReason = "maxEpochs";
2251
+ this._metrics = [];
2252
+ this.network = network;
2253
+ this.epochs = options.epochs ?? 1e3;
2254
+ this.lrInitial = options.lr ?? 0.1;
2255
+ this.lrDecay = options.lrDecay ?? 1;
2256
+ this.verbose = options.verbose ?? false;
2257
+ this.weightDecay = options.weightDecay ?? 0;
2258
+ this._earlyStopping = options.earlyStopping;
2259
+ this._computeMetrics = options.computeMetrics ?? false;
2260
+ this.clipValue = options.clipValue ?? 0;
2261
+ }
2262
+ // ── Set external validation data (for early stopping) ────────────────────
2263
+ setValidationData(dataset) {
2264
+ if (dataset.inputs.length !== dataset.targets.length) {
2265
+ throw new Error(
2266
+ "Trainer.setValidationData: inputs and targets must have the same length"
2267
+ );
2268
+ }
2269
+ this._validationData = dataset;
2270
+ }
2271
+ // ── Get best validation loss during training ─────────────────────────────
2272
+ getBestLoss() {
2273
+ return this._bestLoss === Infinity ? -1 : this._bestLoss;
2274
+ }
2275
+ // ── Why did training stop? ───────────────────────────────────────────────
2276
+ getStopReason() {
2277
+ return this._stopReason;
2278
+ }
2279
+ // ── Get per-epoch classification metrics ─────────────────────────────────
2280
+ getMetrics() {
2281
+ return [...this._metrics];
2282
+ }
2283
+ // ── Train on dataset ──────────────────────────────────────────────────────
2284
+ train(dataset) {
2285
+ const { inputs, targets } = dataset;
2286
+ if (inputs.length !== targets.length) {
2287
+ throw new Error(
2288
+ "Trainer.train: inputs and targets must have the same length"
2289
+ );
2290
+ }
2291
+ const n = inputs.length;
2292
+ let lr = this.lrInitial;
2293
+ this._history = [];
2294
+ this._bestLoss = Infinity;
2295
+ this._patienceCounter = 0;
2296
+ this._stopReason = "maxEpochs";
2297
+ this._metrics = [];
2298
+ const netExt = this._hasWeights(this.network);
2299
+ if (this.weightDecay > 0 && !netExt) {
2300
+ console.warn(
2301
+ "Trainer: weightDecay requires a network with getWeights/setWeights/predict. Skipping weight decay."
2302
+ );
2303
+ }
2304
+ if (this._earlyStopping && !netExt) {
2305
+ console.warn(
2306
+ "Trainer: earlyStopping requires a network with predict(). Skipping early stopping."
2307
+ );
2308
+ }
2309
+ if (this._computeMetrics && !netExt) {
2310
+ console.warn(
2311
+ "Trainer: computeMetrics requires a network with predict(). Skipping metrics."
2312
+ );
2313
+ }
2314
+ const canDecay = this.weightDecay > 0 && netExt;
2315
+ const canValidate = !!this._earlyStopping && netExt && !!this._validationData;
2316
+ const canMetric = this._computeMetrics && netExt;
2317
+ const isClass = canMetric && this._isClassification(targets);
2318
+ if (canMetric && !isClass) {
2319
+ console.warn(
2320
+ "Trainer: computeMetrics is set but targets do not appear to be one-hot or single-class. Metrics will be skipped."
2321
+ );
2322
+ }
2323
+ for (let epoch = 0; epoch < this.epochs; epoch++) {
2324
+ const indices = Array.from({ length: n }, (_, i) => i);
2325
+ for (let i = n - 1; i > 0; i--) {
2326
+ const j = Math.floor(Math.random() * (i + 1));
2327
+ [indices[i], indices[j]] = [indices[j], indices[i]];
2328
+ }
2329
+ let epochLoss = 0;
2330
+ for (const i of indices) {
2331
+ if (canDecay) {
2332
+ const w = netExt.getWeights();
2333
+ for (let j = 0; j < w.length; j++) {
2334
+ w[j] *= 1 - lr * this.weightDecay;
2335
+ }
2336
+ netExt.setWeights(w);
2337
+ }
2338
+ epochLoss += this.network.train(inputs[i], targets[i], lr);
2339
+ }
2340
+ epochLoss /= n;
2341
+ this._history.push(epochLoss);
2342
+ if (canMetric && isClass) {
2343
+ this._metrics.push(this._computeMetricsArray(netExt, inputs, targets));
2344
+ }
2345
+ if (canValidate && this._validationData) {
2346
+ const valLoss = this._computeLoss(netExt, this._validationData);
2347
+ const minDelta = this._earlyStopping.minDelta;
2348
+ if (valLoss < this._bestLoss - minDelta) {
2349
+ this._bestLoss = valLoss;
2350
+ this._patienceCounter = 0;
2351
+ } else {
2352
+ this._patienceCounter++;
2353
+ }
2354
+ if (this._patienceCounter >= this._earlyStopping.patience) {
2355
+ this._stopReason = "earlyStopping";
2356
+ break;
2357
+ }
2358
+ }
2359
+ lr *= this.lrDecay;
2360
+ if (this.verbose && (epoch + 1) % 100 === 0) {
2361
+ console.log(
2362
+ `Epoch ${epoch + 1}/${this.epochs}, loss: ${epochLoss.toFixed(6)}, lr: ${lr.toFixed(6)}`
2363
+ );
2364
+ }
2365
+ }
2366
+ return this._history;
2367
+ }
2368
+ // ── Get loss history ──────────────────────────────────────────────────────
2369
+ getHistory() {
2370
+ return [...this._history];
2371
+ }
2372
+ // ── Private helpers ───────────────────────────────────────────────────────
2373
+ /** Type guard: does this network support getWeights/setWeights/predict? */
2374
+ _hasWeights(network) {
2375
+ if ("getWeights" in network && "setWeights" in network && "predict" in network && typeof network.getWeights === "function" && typeof network.setWeights === "function" && typeof network.predict === "function") {
2376
+ return network;
2377
+ }
2378
+ return null;
2379
+ }
2380
+ /** Mean squared error on a dataset (used for validation loss). */
2381
+ _computeLoss(net, data) {
2382
+ let totalLoss = 0;
2383
+ for (let i = 0; i < data.inputs.length; i++) {
2384
+ const pred = net.predict(data.inputs[i]);
2385
+ const target = data.targets[i];
2386
+ let sampleLoss = 0;
2387
+ for (let j = 0; j < pred.length; j++) {
2388
+ sampleLoss += (target[j] - pred[j]) ** 2;
2389
+ }
2390
+ totalLoss += sampleLoss / pred.length;
2391
+ }
2392
+ return totalLoss / data.inputs.length;
2393
+ }
2394
+ /** Heuristic: are targets classification-style (one-hot or single-class)? */
2395
+ _isClassification(targets) {
2396
+ if (targets.length === 0) return false;
2397
+ const first = targets[0];
2398
+ if (first.length === 1) return true;
2399
+ for (const t of targets) {
2400
+ let sum = 0;
2401
+ for (const v of t) {
2402
+ sum += v;
2403
+ if (v < -0.01 || v > 0.01 && v < 0.99 && Math.abs(v - 1) > 0.01)
2404
+ return false;
2405
+ }
2406
+ if (Math.abs(sum - 1) > 0.01) return false;
2407
+ }
2408
+ return true;
2409
+ }
2410
+ /** Compute classification metrics from predictions vs targets. */
2411
+ _computeMetricsArray(net, inputs, targets) {
2412
+ const targetLen = targets[0].length;
2413
+ const nClasses = targetLen === 1 ? 2 : targetLen;
2414
+ const confusion = Array.from(
2415
+ { length: nClasses },
2416
+ () => Array(nClasses).fill(0)
2417
+ );
2418
+ for (let i = 0; i < inputs.length; i++) {
2419
+ const pred = net.predict(inputs[i]);
2420
+ const target = targets[i];
2421
+ let predClass;
2422
+ let trueClass;
2423
+ if (targetLen === 1) {
2424
+ trueClass = target[0] >= 0.5 ? 1 : 0;
2425
+ if (pred.length === 1) {
2426
+ predClass = pred[0] >= 0.5 ? 1 : 0;
2427
+ } else {
2428
+ predClass = pred.indexOf(Math.max(...pred));
2429
+ }
2430
+ } else {
2431
+ predClass = pred.indexOf(Math.max(...pred));
2432
+ trueClass = target.indexOf(Math.max(...target));
2433
+ }
2434
+ predClass = Math.max(0, Math.min(nClasses - 1, predClass));
2435
+ trueClass = Math.max(0, Math.min(nClasses - 1, trueClass));
2436
+ confusion[trueClass][predClass]++;
2437
+ }
2438
+ let totalCorrect = 0;
2439
+ let totalSamples = 0;
2440
+ const precisions = [];
2441
+ const recalls = [];
2442
+ for (let c = 0; c < nClasses; c++) {
2443
+ const tp = confusion[c][c];
2444
+ totalCorrect += tp;
2445
+ let colSum = 0;
2446
+ let rowSum = 0;
2447
+ for (let r = 0; r < nClasses; r++) {
2448
+ colSum += confusion[r][c];
2449
+ rowSum += confusion[c][r];
2450
+ }
2451
+ totalSamples += rowSum;
2452
+ precisions.push(colSum > 0 ? tp / colSum : 0);
2453
+ recalls.push(rowSum > 0 ? tp / rowSum : 0);
2454
+ }
2455
+ const accuracy = totalSamples > 0 ? totalCorrect / totalSamples : 0;
2456
+ const macroPrecision = precisions.reduce((a, b) => a + b, 0) / nClasses;
2457
+ const macroRecall = recalls.reduce((a, b) => a + b, 0) / nClasses;
2458
+ const f1 = macroPrecision + macroRecall > 0 ? 2 * macroPrecision * macroRecall / (macroPrecision + macroRecall) : 0;
2459
+ return {
2460
+ accuracy,
2461
+ precision: macroPrecision,
2462
+ recall: macroRecall,
2463
+ f1
2464
+ };
2465
+ }
2466
+ };
2467
+
2468
+ // src/DataLoader.ts
2469
/**
 * Mini-batch iterator over a dataset with an optional held-out validation split.
 *
 * Sample indices are shuffled once at construction. When `validationSplit > 0`
 * the last `Math.round(n * validationSplit)` shuffled indices are reserved for
 * validation and never returned by `next()`.
 */
var DataLoader = class _DataLoader {
  /**
   * @param {{inputs: Array, targets: Array}} data - parallel arrays of samples
   * @param {number} [batchSize=1] - samples per batch; must be a number >= 1
   * @param {number} [validationSplit=0] - fraction in [0, 1) held out for validation
   * @throws {Error} when lengths differ, or batchSize / validationSplit are out of range
   */
  constructor(data, batchSize = 1, validationSplit = 0) {
    if (data.inputs.length !== data.targets.length) {
      throw new Error("DataLoader: inputs and targets must have the same length");
    }
    // Written as a negated range test so that NaN is rejected as well.
    if (!(validationSplit >= 0 && validationSplit < 1)) {
      throw new Error(`DataLoader: validationSplit must be in [0, 1), got ${validationSplit}`);
    }
    // A batch size below 1 (or NaN) would stop next() from advancing, so a
    // `while (loader.hasNext())` loop would never terminate.
    if (typeof batchSize !== "number" || !(batchSize >= 1)) {
      throw new Error(`DataLoader: batchSize must be a number >= 1, got ${batchSize}`);
    }
    this.data = data;
    this.batchSize = batchSize;
    this._validationSplit = validationSplit;
    const fullIndices = Array.from({ length: data.inputs.length }, (_, i) => i);
    _DataLoader._shuffleInPlace(fullIndices);
    if (validationSplit > 0) {
      const valSize = Math.round(data.inputs.length * validationSplit);
      const trainSize = data.inputs.length - valSize;
      this._trainIndices = fullIndices.slice(0, trainSize);
      this._valIndices = fullIndices.slice(trainSize);
    } else {
      this._trainIndices = [...fullIndices];
      this._valIndices = [];
    }
    this._indices = [...this._trainIndices];
    this._pos = 0;
  }

  /** Fisher–Yates shuffle of `arr` in place, driven by Math.random. */
  static _shuffleInPlace(arr) {
    for (let i = arr.length - 1; i > 0; i--) {
      const j = Math.floor(Math.random() * (i + 1));
      [arr[i], arr[j]] = [arr[j], arr[i]];
    }
  }

  // ── Shuffle the training data ──────────────────────────────────────────────
  /** Reshuffle the training indices and restart iteration; validation indices are untouched. */
  shuffle() {
    _DataLoader._shuffleInPlace(this._trainIndices);
    this._indices = [...this._trainIndices];
    this._pos = 0;
  }

  // ── Check if more batches are available ───────────────────────────────────
  /** @returns {boolean} true while `next()` still has training samples to return */
  hasNext() {
    return this._pos < this._indices.length;
  }

  // ── Get next batch ────────────────────────────────────────────────────────
  /**
   * Return the next batch and advance the cursor. The final batch may be
   * shorter than `batchSize`; once exhausted, empty arrays are returned.
   *
   * @returns {{inputs: Array, targets: Array}}
   */
  next() {
    const end = Math.min(this._pos + this.batchSize, this._indices.length);
    const batchIndices = this._indices.slice(this._pos, end);
    this._pos = end;
    return {
      inputs: batchIndices.map((i) => this.data.inputs[i]),
      targets: batchIndices.map((i) => this.data.targets[i])
    };
  }

  // ── Reset iteration ───────────────────────────────────────────────────────
  /** Rewind to the first batch without reshuffling. */
  reset() {
    this._pos = 0;
  }

  // ── Get total number of training samples ───────────────────────────────────
  /** @returns {number} number of training samples (validation samples excluded) */
  get length() {
    return this._trainIndices.length;
  }

  // ── Get validation data as a DataPair ──────────────────────────────────────
  // Returns the validation samples (inputs + targets) in their shuffled order.
  // Returns empty arrays if no validation split was configured.
  /** @returns {{inputs: Array, targets: Array}} */
  getValidationData() {
    return {
      inputs: this._valIndices.map((i) => this.data.inputs[i]),
      targets: this._valIndices.map((i) => this.data.targets[i])
    };
  }

  // ── Get number of validation samples ───────────────────────────────────────
  /** @returns {number} number of held-out validation samples */
  get validationLength() {
    return this._valIndices.length;
  }

  // ── Create sequence windows from a time series ────────────────────────────
  /**
   * Build a loader (batch size 1) of sliding windows over `data`: each input is
   * `seqLen` consecutive entries flattened one level, and its target is the
   * entry that immediately follows the window.
   *
   * @param {Array} data - ordered series of entries
   * @param {number} seqLen - window length; must be an integer >= 1
   * @param {number} [validationSplit=0] - forwarded to the constructor
   * @returns {DataLoader}
   * @throws {Error} when seqLen is invalid or data is shorter than seqLen + 1
   */
  static sequences(data, seqLen, validationSplit = 0) {
    // A fractional seqLen would make data[i + seqLen] undefined; 0 would
    // produce empty input windows.
    if (!Number.isInteger(seqLen) || seqLen < 1) {
      throw new Error(`DataLoader.sequences: seqLen must be an integer >= 1, got ${seqLen}`);
    }
    if (data.length < seqLen + 1) {
      throw new Error("DataLoader.sequences: data length must be >= seqLen + 1");
    }
    const inputs = [];
    const targets = [];
    for (let i = 0; i <= data.length - seqLen - 1; i++) {
      inputs.push(data.slice(i, i + seqLen).flat());
      targets.push(data[i + seqLen]);
    }
    return new _DataLoader({ inputs, targets }, 1, validationSplit);
  }
};
2555
+
2556
+ // src/LRScheduler.ts
2557
/**
 * Learning-rate schedule helpers. Every method is a pure function of its
 * arguments: nothing is stored on the instance between calls.
 */
var LRScheduler = class {
  // ── Step Decay ────────────────────────────────────────────────────────────
  // lr = initialLr * dropRate^floor(epoch / epochsDrop)
  /**
   * @param {number} lr - initial learning rate
   * @param {number} epoch - current epoch index
   * @param {number} dropRate - multiplier applied at each drop
   * @param {number} epochsDrop - epochs between drops; must be > 0
   * @returns {number} decayed learning rate
   * @throws {Error} when epochsDrop is not > 0 (would divide by zero and yield NaN)
   */
  stepDecay(lr, epoch, dropRate, epochsDrop) {
    if (!(epochsDrop > 0)) {
      throw new Error(`LRScheduler.stepDecay: epochsDrop must be > 0, got ${epochsDrop}`);
    }
    return lr * Math.pow(dropRate, Math.floor(epoch / epochsDrop));
  }

  // ── Exponential Decay ─────────────────────────────────────────────────────
  // lr = initialLr * decayRate^epoch
  /**
   * @param {number} lr - initial learning rate
   * @param {number} epoch - current epoch index
   * @param {number} decayRate - per-epoch multiplier
   * @returns {number} decayed learning rate
   */
  exponentialDecay(lr, epoch, decayRate) {
    return lr * Math.pow(decayRate, epoch);
  }

  // ── Plateau Decay ─────────────────────────────────────────────────────────
  // If the current loss is not lower than the best of the last `patience`
  // recorded losses, multiply lr by `factor`. Returns the (possibly reduced) lr.
  //
  // `history` must contain the losses of PREVIOUS epochs only. If the current
  // loss has already been appended, the minimum can never exceed it and lr
  // would be reduced on every call. The method keeps no state, so it reduces
  // lr on every call for which the condition holds.
  //
  // Usage:
  //   const history = []
  //   for (let epoch = 0; epoch < 1000; epoch++) {
  //     const loss = train(...)
  //     lr = scheduler.plateauDecay(lr, loss, history, 10, 0.5)
  //     history.push(loss)
  //   }
  /**
   * @param {number} lr - current learning rate
   * @param {number} currentLoss - loss of the epoch that just finished
   * @param {number[]} history - losses of earlier epochs, oldest first
   * @param {number} patience - number of trailing history entries to compare against; must be >= 1
   * @param {number} factor - multiplier applied when no improvement is seen
   * @returns {number} `lr * factor` on plateau, otherwise `lr`
   * @throws {Error} when patience is not >= 1
   */
  plateauDecay(lr, currentLoss, history, patience, factor) {
    // slice(-0) is slice(0): a patience of 0 would silently compare against
    // the whole history instead of an empty window.
    if (!(patience >= 1)) {
      throw new Error(`LRScheduler.plateauDecay: patience must be >= 1, got ${patience}`);
    }
    if (history.length < patience) return lr;
    const recentLosses = history.slice(-patience);
    const minRecentLoss = Math.min(...recentLosses);
    if (currentLoss >= minRecentLoss) {
      return lr * factor;
    }
    return lr;
  }

  // ── Cosine Annealing ──────────────────────────────────────────────────────
  // lr = minLr + 0.5 * (maxLr - minLr) * (1 + cos(π * epoch / maxEpochs))
  /**
   * @param {number} lr - maximum (initial) learning rate
   * @param {number} epoch - current epoch index
   * @param {number} maxEpochs - length of the annealing period; must be > 0
   * @param {number} [minLr=0] - learning rate reached at epoch === maxEpochs
   * @returns {number} annealed learning rate
   * @throws {Error} when maxEpochs is not > 0 (would divide by zero and yield NaN)
   */
  cosineAnnealing(lr, epoch, maxEpochs, minLr = 0) {
    if (!(maxEpochs > 0)) {
      throw new Error(`LRScheduler.cosineAnnealing: maxEpochs must be > 0, got ${maxEpochs}`);
    }
    return minLr + 0.5 * (lr - minLr) * (1 + Math.cos(Math.PI * epoch / maxEpochs));
  }
};
2594
+
2595
+ // src/ModelSaver.ts
2596
/**
 * Static helpers that serialize a model's weights to JSON and restore them.
 * A "model" here is any object exposing `getWeights()` and `setWeights(weights)`.
 */
var ModelSaver = class _ModelSaver {
  // ── Serialize to JSON string ──────────────────────────────────────────────
  /**
   * @param {object} model - object exposing `getWeights()`
   * @returns {string} JSON text of the form `{"weights": ..., "timestamp": <ms since epoch>}`
   */
  static toJSON(model) {
    return JSON.stringify({
      weights: model.getWeights(),
      timestamp: Date.now()
    });
  }

  // ── Deserialize from JSON string ──────────────────────────────────────────
  /**
   * Parse `json` and pass its `weights` array to `model.setWeights`.
   *
   * @param {object} model - object exposing `setWeights(weights)`
   * @param {string} json - text previously produced by `toJSON`
   * @throws {SyntaxError} when `json` is not valid JSON (propagated from JSON.parse)
   * @throws {Error} when the parsed value has no `weights` array
   */
  static fromJSON(model, json) {
    const data = JSON.parse(json);
    // JSON.parse may legitimately return null (or a primitive); guard before
    // touching `.weights` so the caller gets the intended error, not a TypeError.
    if (data === null || typeof data !== "object" || !Array.isArray(data.weights)) {
      throw new Error("ModelSaver.fromJSON: invalid model data");
    }
    model.setWeights(data.weights);
  }

  // ── Save to file (requires write function) ────────────────────────────────
  /**
   * @param {object} model - object exposing `getWeights()`
   * @param {string} path - destination passed through to `writeFn`
   * @param {(path: string, json: string) => void} writeFn - caller-supplied writer
   */
  static saveToFile(model, path, writeFn) {
    const json = _ModelSaver.toJSON(model);
    writeFn(path, json);
  }

  // ── Load from file (requires read function) ───────────────────────────────
  /**
   * @param {object} model - object exposing `setWeights(weights)`
   * @param {string} path - source passed through to `readFn`
   * @param {(path: string) => string} readFn - caller-supplied reader; its
   *   return value is handed directly to JSON.parse, so it must be synchronous
   */
  static loadFromFile(model, path, readFn) {
    const json = readFn(path);
    _ModelSaver.fromJSON(model, json);
  }
};
1230
2623
  export {
1231
2624
  Adam,
1232
2625
  AttentionHead,
2626
+ BatchNorm,
2627
+ ClipOptimizer,
2628
+ ClippedOptimizerFactory,
2629
+ Conv1D,
2630
+ DataLoader,
2631
+ Dropout,
1233
2632
  EmbeddingMatrix,
2633
+ GRULayer,
2634
+ LRScheduler,
1234
2635
  LSTMLayer,
1235
2636
  Layer,
1236
2637
  LayerNorm,
2638
+ ModelSaver,
1237
2639
  Momentum,
1238
2640
  MultiHeadAttention,
1239
2641
  Network,
@@ -1244,6 +2646,7 @@ export {
1244
2646
  Neuron,
1245
2647
  NeuronN,
1246
2648
  SGD,
2649
+ Trainer,
1247
2650
  TransformerBlock,
1248
2651
  WeightMatrix,
1249
2652
  crossEntropy,
@@ -1262,5 +2665,9 @@ export {
1262
2665
  softmax,
1263
2666
  softmaxBackward,
1264
2667
  tanh,
1265
- transpose
2668
+ transpose,
2669
+ validate2DArray,
2670
+ validateArray,
2671
+ validateArrayMinLength,
2672
+ validateNumber
1266
2673
  };