@dniskav/neuron 0.2.2 → 0.2.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -22,10 +22,19 @@ var index_exports = {};
22
22
  __export(index_exports, {
23
23
  Adam: () => Adam,
24
24
  AttentionHead: () => AttentionHead,
25
+ BatchNorm: () => BatchNorm,
26
+ ClipOptimizer: () => ClipOptimizer,
27
+ ClippedOptimizerFactory: () => ClippedOptimizerFactory,
28
+ Conv1D: () => Conv1D,
29
+ DataLoader: () => DataLoader,
30
+ Dropout: () => Dropout,
25
31
  EmbeddingMatrix: () => EmbeddingMatrix,
32
+ GRULayer: () => GRULayer,
33
+ LRScheduler: () => LRScheduler,
26
34
  LSTMLayer: () => LSTMLayer,
27
35
  Layer: () => Layer,
28
36
  LayerNorm: () => LayerNorm,
37
+ ModelSaver: () => ModelSaver,
29
38
  Momentum: () => Momentum,
30
39
  MultiHeadAttention: () => MultiHeadAttention,
31
40
  Network: () => Network,
@@ -36,6 +45,7 @@ __export(index_exports, {
36
45
  Neuron: () => Neuron,
37
46
  NeuronN: () => NeuronN,
38
47
  SGD: () => SGD,
48
+ Trainer: () => Trainer,
39
49
  TransformerBlock: () => TransformerBlock,
40
50
  WeightMatrix: () => WeightMatrix,
41
51
  crossEntropy: () => crossEntropy,
@@ -54,10 +64,82 @@ __export(index_exports, {
54
64
  softmax: () => softmax,
55
65
  softmaxBackward: () => softmaxBackward,
56
66
  tanh: () => tanh,
57
- transpose: () => transpose
67
+ transpose: () => transpose,
68
+ validate2DArray: () => validate2DArray,
69
+ validateArray: () => validateArray,
70
+ validateArrayMinLength: () => validateArrayMinLength,
71
+ validateNumber: () => validateNumber
58
72
  });
59
73
  module.exports = __toCommonJS(index_exports);
60
74
 
75
+ // src/Validation.ts
76
+ function validateArray(arr, expectedLength, methodName) {
77
+ if (!Array.isArray(arr)) {
78
+ throw new Error(`${methodName}: expected array, got ${typeof arr}`);
79
+ }
80
+ if (arr.length !== expectedLength) {
81
+ throw new Error(
82
+ `${methodName}: expected array of length ${expectedLength}, got ${arr.length}`
83
+ );
84
+ }
85
+ for (let i = 0; i < arr.length; i++) {
86
+ if (typeof arr[i] !== "number" || !isFinite(arr[i])) {
87
+ throw new Error(
88
+ `${methodName}: invalid value at index ${i}: ${arr[i]}`
89
+ );
90
+ }
91
+ }
92
+ }
93
+ function validateArrayMinLength(arr, minLength, methodName) {
94
+ if (!Array.isArray(arr)) {
95
+ throw new Error(`${methodName}: expected array, got ${typeof arr}`);
96
+ }
97
+ if (arr.length < minLength) {
98
+ throw new Error(
99
+ `${methodName}: expected array of at least length ${minLength}, got ${arr.length}`
100
+ );
101
+ }
102
+ for (let i = 0; i < arr.length; i++) {
103
+ if (typeof arr[i] !== "number" || !isFinite(arr[i])) {
104
+ throw new Error(
105
+ `${methodName}: invalid value at index ${i}: ${arr[i]}`
106
+ );
107
+ }
108
+ }
109
+ }
110
+ function validate2DArray(arr, expectedRows, expectedCols, methodName) {
111
+ if (!Array.isArray(arr)) {
112
+ throw new Error(`${methodName}: expected 2D array, got ${typeof arr}`);
113
+ }
114
+ if (arr.length !== expectedRows) {
115
+ throw new Error(
116
+ `${methodName}: expected ${expectedRows} rows, got ${arr.length}`
117
+ );
118
+ }
119
+ for (let i = 0; i < arr.length; i++) {
120
+ if (!Array.isArray(arr[i])) {
121
+ throw new Error(`${methodName}: row ${i} is not an array`);
122
+ }
123
+ if (arr[i].length !== expectedCols) {
124
+ throw new Error(
125
+ `${methodName}: row ${i} expected ${expectedCols} cols, got ${arr[i].length}`
126
+ );
127
+ }
128
+ for (let j = 0; j < arr[i].length; j++) {
129
+ if (typeof arr[i][j] !== "number" || !isFinite(arr[i][j])) {
130
+ throw new Error(
131
+ `${methodName}: invalid value at [${i}][${j}]: ${arr[i][j]}`
132
+ );
133
+ }
134
+ }
135
+ }
136
+ }
137
+ function validateNumber(value, methodName) {
138
+ if (typeof value !== "number" || !isFinite(value)) {
139
+ throw new Error(`${methodName}: expected finite number, got ${value}`);
140
+ }
141
+ }
142
+
61
143
  // src/Neuron.ts
62
144
  function sigmoid(x) {
63
145
  return 1 / (1 + Math.exp(-x));
@@ -68,13 +150,18 @@ var Neuron = class {
68
150
  this.bias = Math.random() * 0.1;
69
151
  }
70
152
  predict(input) {
153
+ validateNumber(input, "Neuron.predict");
71
154
  return sigmoid(input * this.weight + this.bias);
72
155
  }
73
156
  train(input, target, lr) {
157
+ validateNumber(input, "Neuron.train");
158
+ validateNumber(target, "Neuron.train");
159
+ validateNumber(lr, "Neuron.train");
74
160
  const prediction = this.predict(input);
75
161
  const error = target - prediction;
76
- this.weight += lr * error * input;
77
- this.bias += lr * error;
162
+ const grad = error * prediction * (1 - prediction);
163
+ this.weight += lr * grad * input;
164
+ this.bias += lr * grad;
78
165
  }
79
166
  };
80
167
 
@@ -129,6 +216,19 @@ var Momentum = class {
129
216
  return weight + this.v;
130
217
  }
131
218
  };
219
+ var ClipOptimizer = class {
220
+ constructor(inner, clipValue) {
221
+ this.inner = inner;
222
+ this.clipValue = clipValue;
223
+ }
224
+ step(weight, gradient, lr) {
225
+ const clipped = Math.max(-this.clipValue, Math.min(this.clipValue, gradient));
226
+ return this.inner.step(weight, clipped, lr);
227
+ }
228
+ };
229
+ function ClippedOptimizerFactory(innerFactory, clipValue) {
230
+ return () => new ClipOptimizer(innerFactory(), clipValue);
231
+ }
132
232
  var Adam = class {
133
233
  constructor(beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8) {
134
234
  this.beta1 = beta1;
@@ -159,6 +259,7 @@ var NeuronN = class {
159
259
  this._opts = Array.from({ length: nInputs + 1 }, optimizerFactory);
160
260
  }
161
261
  predict(inputs) {
262
+ validateArray(inputs, this.weights.length, "NeuronN.predict");
162
263
  const sum = inputs.reduce((acc, e, i) => acc + e * this.weights[i], this.bias);
163
264
  return this.activation.fn(sum);
164
265
  }
@@ -171,7 +272,8 @@ var NeuronN = class {
171
272
  train(inputs, target, lr) {
172
273
  const prediction = this.predict(inputs);
173
274
  const error = target - prediction;
174
- this._update(inputs.map((inp) => error * inp), error, lr);
275
+ const grad = error * this.activation.dfn(prediction);
276
+ this._update(inputs.map((inp) => grad * inp), grad, lr);
175
277
  }
176
278
  };
177
279
 
@@ -196,29 +298,99 @@ var Network = class {
196
298
  this.outputLayer = new Layer(nOutputs, nHidden);
197
299
  }
198
300
  predict(inputs) {
301
+ validateArray(inputs, this.hiddenLayer.neurons[0].weights.length, "Network.predict");
199
302
  const hiddenOut = this.hiddenLayer.predict(inputs);
200
303
  return this.outputLayer.predict(hiddenOut)[0];
201
304
  }
202
305
  // Trains on a single example. Returns the squared error.
203
306
  train(inputs, target, lr) {
307
+ validateArray(inputs, this.hiddenLayer.neurons[0].weights.length, "Network.train");
308
+ validateNumber(target, "Network.train");
309
+ validateNumber(lr, "Network.train");
204
310
  const hiddenOut = this.hiddenLayer.predict(inputs);
205
311
  const prediction = this.outputLayer.predict(hiddenOut)[0];
206
312
  const outputError = target - prediction;
207
313
  const outputDelta = outputError * prediction * (1 - prediction);
208
314
  const outputNeuron = this.outputLayer.neurons[0];
315
+ const hiddenDeltas = this.hiddenLayer.neurons.map((neuron, i) => {
316
+ const hiddenOut_i = hiddenOut[i];
317
+ const hiddenError = outputDelta * outputNeuron.weights[i];
318
+ return hiddenError * hiddenOut_i * (1 - hiddenOut_i);
319
+ });
320
+ this.hiddenLayer.neurons.forEach((neuron, i) => {
321
+ neuron.weights = neuron.weights.map((w, j) => w + lr * hiddenDeltas[i] * inputs[j]);
322
+ neuron.bias += lr * hiddenDeltas[i];
323
+ });
209
324
  outputNeuron.weights = outputNeuron.weights.map(
210
325
  (w, i) => w + lr * outputDelta * hiddenOut[i]
211
326
  );
212
327
  outputNeuron.bias += lr * outputDelta;
213
- this.hiddenLayer.neurons.forEach((neuron, i) => {
214
- const hiddenOut_i = hiddenOut[i];
215
- const hiddenError = outputDelta * outputNeuron.weights[i];
216
- const hiddenDelta = hiddenError * hiddenOut_i * (1 - hiddenOut_i);
217
- neuron.weights = neuron.weights.map((w, j) => w + lr * hiddenDelta * inputs[j]);
218
- neuron.bias += lr * hiddenDelta;
219
- });
220
328
  return outputError * outputError;
221
329
  }
330
+ // ── Flat weight serialization ─────────────────────────────────────────────
331
+ // Order: hidden layer (all neurons: weights then bias), then output layer.
332
+ getWeights() {
333
+ const w = [];
334
+ for (const n of this.hiddenLayer.neurons) {
335
+ w.push(...n.weights, n.bias);
336
+ }
337
+ for (const n of this.outputLayer.neurons) {
338
+ w.push(...n.weights, n.bias);
339
+ }
340
+ return w;
341
+ }
342
+ setWeights(weights) {
343
+ let idx = 0;
344
+ for (const n of this.hiddenLayer.neurons) {
345
+ for (let j = 0; j < n.weights.length; j++) n.weights[j] = weights[idx++];
346
+ n.bias = weights[idx++];
347
+ }
348
+ for (const n of this.outputLayer.neurons) {
349
+ for (let j = 0; j < n.weights.length; j++) n.weights[j] = weights[idx++];
350
+ n.bias = weights[idx++];
351
+ }
352
+ }
353
+ };
354
+
355
+ // src/Dropout.ts
356
+ var Dropout = class {
357
+ constructor(rate) {
358
+ this._mask = null;
359
+ if (rate < 0 || rate >= 1) {
360
+ throw new Error(`Dropout rate must be in [0, 1), got ${rate}`);
361
+ }
362
+ this.rate = rate;
363
+ }
364
+ // ── Forward ───────────────────────────────────────────────────────────────
365
+ // x: number[] → number[]
366
+ // If training, applies inverted dropout mask.
367
+ // If not training, returns input unchanged.
368
+ forward(x, training = true) {
369
+ if (!training || this.rate === 0) {
370
+ this._mask = null;
371
+ return [...x];
372
+ }
373
+ const scale = 1 / (1 - this.rate);
374
+ this._mask = x.map(() => Math.random() > this.rate ? scale : 0);
375
+ return x.map((v, i) => v * this._mask[i]);
376
+ }
377
+ // ── Backward ──────────────────────────────────────────────────────────────
378
+ // dOut: number[] → number[]
379
+ // Applies the same mask (gradient is zeroed where activation was zeroed).
380
+ backward(dOut) {
381
+ if (!this._mask) return [...dOut];
382
+ return dOut.map((d, i) => d * this._mask[i]);
383
+ }
384
+ // ── Reset mask between forward passes ─────────────────────────────────────
385
+ resetMask() {
386
+ this._mask = null;
387
+ }
388
+ // ── No trainable params ───────────────────────────────────────────────────
389
+ getWeights() {
390
+ return [];
391
+ }
392
+ setWeights(_weights) {
393
+ }
222
394
  };
223
395
 
224
396
  // src/NetworkN.ts
@@ -229,30 +401,96 @@ var NetworkN = class {
229
401
  const nLayers = structure.length - 1;
230
402
  const activations = options.activations ?? Array.from({ length: nLayers }, () => sigmoid2);
231
403
  const optimizer = options.optimizer ?? defaultOptimizer3;
404
+ const dropoutRate = options.dropoutRate ?? 0;
405
+ if (activations.length !== nLayers) {
406
+ throw new Error(`Expected ${nLayers} activations, got ${activations.length}`);
407
+ }
408
+ if (dropoutRate < 0 || dropoutRate >= 1) {
409
+ throw new Error(`Dropout rate must be in [0, 1), got ${dropoutRate}`);
410
+ }
411
+ this._residual = options.residual ?? false;
232
412
  this.layers = [];
233
413
  for (let i = 1; i < structure.length; i++) {
234
414
  this.layers.push(new Layer(structure[i], structure[i - 1], activations[i - 1], optimizer));
235
415
  }
416
+ this._dropouts = [];
417
+ if (dropoutRate > 0) {
418
+ for (let i = 0; i < nLayers - 1; i++) {
419
+ this._dropouts.push(new Dropout(dropoutRate));
420
+ }
421
+ }
422
+ const outputLayer = this.layers[this.layers.length - 1];
423
+ const outputActivation = outputLayer.neurons[0].activation;
424
+ for (let i = 1; i < outputLayer.neurons.length; i++) {
425
+ if (outputLayer.neurons[i].activation !== outputActivation) {
426
+ throw new Error("All output neurons must share the same activation function");
427
+ }
428
+ }
236
429
  }
237
- predict(inputs) {
238
- return this.layers.reduce((acc, layer) => layer.predict(acc), inputs);
430
+ predict(inputs, training = false) {
431
+ validateArray(inputs, this.structure[0], "NetworkN.predict");
432
+ let current = [...inputs];
433
+ for (let i = 0; i < this.layers.length; i++) {
434
+ const layerInput = [...current];
435
+ const layerOutput = this.layers[i].predict(current);
436
+ if (this._shouldResidual(i)) {
437
+ if (this.structure[i] === this.structure[i + 1]) {
438
+ current = layerOutput.map((v, j) => v + layerInput[j]);
439
+ } else {
440
+ current = [...layerOutput];
441
+ }
442
+ } else {
443
+ current = [...layerOutput];
444
+ }
445
+ if (i < this._dropouts.length) {
446
+ current = this._dropouts[i].forward(current, training);
447
+ }
448
+ }
449
+ return current;
239
450
  }
240
451
  // Generalized backpropagation across L layers.
241
452
  // Returns the mean squared error for the example.
242
453
  train(inputs, targets, lr) {
454
+ validateArray(inputs, this.structure[0], "NetworkN.train");
455
+ validateArray(targets, this.structure[this.structure.length - 1], "NetworkN.train");
243
456
  const act = [inputs];
244
- for (const layer of this.layers) act.push(layer.predict(act[act.length - 1]));
457
+ for (let i = 0; i < this.layers.length; i++) {
458
+ const layerInput = act[act.length - 1];
459
+ const layerOutput = this.layers[i].predict(layerInput);
460
+ let current;
461
+ if (this._shouldResidual(i)) {
462
+ if (this.structure[i] === this.structure[i + 1]) {
463
+ current = layerOutput.map((v, j) => v + layerInput[j]);
464
+ } else {
465
+ current = [...layerOutput];
466
+ }
467
+ } else {
468
+ current = [...layerOutput];
469
+ }
470
+ if (i < this._dropouts.length) {
471
+ current = this._dropouts[i].forward(current, true);
472
+ }
473
+ act.push(current);
474
+ }
245
475
  const pred = act[act.length - 1];
246
476
  const outAct = this.layers[this.layers.length - 1].neurons[0].activation;
247
477
  let deltas = pred.map((p, i) => (targets[i] - p) * outAct.dfn(p));
248
478
  for (let l = this.layers.length - 1; l >= 0; l--) {
249
479
  const layer = this.layers[l];
480
+ if (l < this._dropouts.length) {
481
+ deltas = this._dropouts[l].backward(deltas);
482
+ }
250
483
  const layerIn = act[l];
251
484
  const prevAct = l > 0 ? this.layers[l - 1].neurons[0].activation : null;
252
485
  const prevDeltas = layerIn.map((out, j) => {
253
486
  const errProp = layer.neurons.reduce((s, n, k) => s + deltas[k] * n.weights[j], 0);
254
487
  return prevAct ? errProp * prevAct.dfn(out) : errProp;
255
488
  });
489
+ if (this._shouldResidual(l) && this.structure[l] === this.structure[l + 1]) {
490
+ for (let j = 0; j < prevDeltas.length; j++) {
491
+ prevDeltas[j] += deltas[j];
492
+ }
493
+ }
256
494
  layer.neurons.forEach((n, k) => {
257
495
  n._update(layerIn.map((inp) => deltas[k] * inp), deltas[k], lr);
258
496
  });
@@ -264,22 +502,74 @@ var NetworkN = class {
264
502
  // Useful for custom loss functions (e.g. physics-based gradients).
265
503
  trainWithDeltas(inputs, outputDeltas, lr) {
266
504
  const act = [inputs];
267
- for (const layer of this.layers) act.push(layer.predict(act[act.length - 1]));
505
+ for (let i = 0; i < this.layers.length; i++) {
506
+ const layerInput = act[act.length - 1];
507
+ const layerOutput = this.layers[i].predict(layerInput);
508
+ let current;
509
+ if (this._shouldResidual(i)) {
510
+ if (this.structure[i] === this.structure[i + 1]) {
511
+ current = layerOutput.map((v, j) => v + layerInput[j]);
512
+ } else {
513
+ current = [...layerOutput];
514
+ }
515
+ } else {
516
+ current = [...layerOutput];
517
+ }
518
+ if (i < this._dropouts.length) {
519
+ current = this._dropouts[i].forward(current, true);
520
+ }
521
+ act.push(current);
522
+ }
268
523
  let deltas = outputDeltas;
269
524
  for (let l = this.layers.length - 1; l >= 0; l--) {
270
525
  const layer = this.layers[l];
526
+ if (l < this._dropouts.length) {
527
+ deltas = this._dropouts[l].backward(deltas);
528
+ }
271
529
  const layerIn = act[l];
272
530
  const prevAct = l > 0 ? this.layers[l - 1].neurons[0].activation : null;
273
531
  const prevDeltas = layerIn.map((out, j) => {
274
532
  const errProp = layer.neurons.reduce((s, n, k) => s + deltas[k] * n.weights[j], 0);
275
533
  return prevAct ? errProp * prevAct.dfn(out) : errProp;
276
534
  });
535
+ if (this._shouldResidual(l) && this.structure[l] === this.structure[l + 1]) {
536
+ for (let j = 0; j < prevDeltas.length; j++) {
537
+ prevDeltas[j] += deltas[j];
538
+ }
539
+ }
277
540
  layer.neurons.forEach((n, k) => {
278
541
  n._update(layerIn.map((inp) => deltas[k] * inp), deltas[k], lr);
279
542
  });
280
543
  deltas = prevDeltas;
281
544
  }
282
545
  }
546
+ // ── Flat weight serialization ─────────────────────────────────────────────
547
+ // Order: layer 0 (all neurons), layer 1, ..., layer N.
548
+ getWeights() {
549
+ for (const d of this._dropouts) d.resetMask();
550
+ const w = [];
551
+ for (const layer of this.layers) {
552
+ for (const n of layer.neurons) {
553
+ w.push(...n.weights, n.bias);
554
+ }
555
+ }
556
+ return w;
557
+ }
558
+ setWeights(weights) {
559
+ for (const d of this._dropouts) d.resetMask();
560
+ let idx = 0;
561
+ for (const layer of this.layers) {
562
+ for (const n of layer.neurons) {
563
+ for (let j = 0; j < n.weights.length; j++) n.weights[j] = weights[idx++];
564
+ n.bias = weights[idx++];
565
+ }
566
+ }
567
+ }
568
+ // ── Helper ───────────────────────────────────────────────────────────────
569
+ _shouldResidual(layerIndex) {
570
+ if (typeof this._residual === "function") return this._residual(layerIndex);
571
+ return this._residual;
572
+ }
283
573
  };
284
574
 
285
575
  // src/LSTMLayer.ts
@@ -308,8 +598,11 @@ var Gate = class {
308
598
  }
309
599
  };
310
600
  var LSTMLayer = class {
311
- constructor(inputSize, hiddenSize) {
601
+ constructor(inputSize, hiddenSize, optimizerFactory = () => new SGD()) {
312
602
  this._traj = [];
603
+ if (inputSize <= 0 || hiddenSize <= 0) {
604
+ throw new Error(`LSTMLayer: inputSize and hiddenSize must be positive, got ${inputSize} and ${hiddenSize}`);
605
+ }
313
606
  this.inputSize = inputSize;
314
607
  this.hSize = hiddenSize;
315
608
  this.h = new Array(hiddenSize).fill(0);
@@ -318,6 +611,29 @@ var LSTMLayer = class {
318
611
  this.inputGate = new Gate(inputSize, hiddenSize);
319
612
  this.cellGate = new Gate(inputSize, hiddenSize);
320
613
  this.outputGate = new Gate(inputSize, hiddenSize);
614
+ const combSize = inputSize + hiddenSize;
615
+ this._optimizers = {
616
+ forgetW: Array.from(
617
+ { length: hiddenSize },
618
+ () => Array.from({ length: combSize }, () => optimizerFactory())
619
+ ),
620
+ forgetB: Array.from({ length: hiddenSize }, () => optimizerFactory()),
621
+ inputW: Array.from(
622
+ { length: hiddenSize },
623
+ () => Array.from({ length: combSize }, () => optimizerFactory())
624
+ ),
625
+ inputB: Array.from({ length: hiddenSize }, () => optimizerFactory()),
626
+ cellW: Array.from(
627
+ { length: hiddenSize },
628
+ () => Array.from({ length: combSize }, () => optimizerFactory())
629
+ ),
630
+ cellB: Array.from({ length: hiddenSize }, () => optimizerFactory()),
631
+ outputW: Array.from(
632
+ { length: hiddenSize },
633
+ () => Array.from({ length: combSize }, () => optimizerFactory())
634
+ ),
635
+ outputB: Array.from({ length: hiddenSize }, () => optimizerFactory())
636
+ };
321
637
  }
322
638
  // ── Reset state and trajectory (call at episode start) ────────────────────
323
639
  reset() {
@@ -327,6 +643,9 @@ var LSTMLayer = class {
327
643
  }
328
644
  // ── Forward pass ──────────────────────────────────────────────────────────
329
645
  predict(inputs) {
646
+ if (!Array.isArray(inputs) || inputs.length !== this.inputSize) {
647
+ throw new Error(`LSTMLayer.predict: expected array of length ${this.inputSize}, got ${inputs?.length}`);
648
+ }
330
649
  const combined = [...inputs, ...this.h];
331
650
  const c_prev = [...this.c];
332
651
  const zf = this.forgetGate.linear(combined);
@@ -401,15 +720,15 @@ var LSTMLayer = class {
401
720
  const scale = lr / T;
402
721
  for (let k = 0; k < hSize; k++) {
403
722
  for (let j = 0; j < combSize; j++) {
404
- this.forgetGate.W[k][j] += scale * dWf[k][j];
405
- this.inputGate.W[k][j] += scale * dWi[k][j];
406
- this.cellGate.W[k][j] += scale * dWg[k][j];
407
- this.outputGate.W[k][j] += scale * dWo[k][j];
723
+ this.forgetGate.W[k][j] = this._optimizers.forgetW[k][j].step(this.forgetGate.W[k][j], dWf[k][j], scale);
724
+ this.inputGate.W[k][j] = this._optimizers.inputW[k][j].step(this.inputGate.W[k][j], dWi[k][j], scale);
725
+ this.cellGate.W[k][j] = this._optimizers.cellW[k][j].step(this.cellGate.W[k][j], dWg[k][j], scale);
726
+ this.outputGate.W[k][j] = this._optimizers.outputW[k][j].step(this.outputGate.W[k][j], dWo[k][j], scale);
408
727
  }
409
- this.forgetGate.b[k] += scale * dbf[k];
410
- this.inputGate.b[k] += scale * dbi[k];
411
- this.cellGate.b[k] += scale * dbg[k];
412
- this.outputGate.b[k] += scale * dbo[k];
728
+ this.forgetGate.b[k] = this._optimizers.forgetB[k].step(this.forgetGate.b[k], dbf[k], scale);
729
+ this.inputGate.b[k] = this._optimizers.inputB[k].step(this.inputGate.b[k], dbi[k], scale);
730
+ this.cellGate.b[k] = this._optimizers.cellB[k].step(this.cellGate.b[k], dbg[k], scale);
731
+ this.outputGate.b[k] = this._optimizers.outputB[k].step(this.outputGate.b[k], dbo[k], scale);
413
732
  }
414
733
  this._traj = [];
415
734
  }
@@ -432,6 +751,35 @@ var LSTMLayer = class {
432
751
  this.outputGate.W = data.outputGate.W;
433
752
  this.outputGate.b = data.outputGate.b;
434
753
  }
754
+ // ── Flat weight serialization ─────────────────────────────────────────────
755
+ // Order: forgetGate (W, b), inputGate (W, b), cellGate (W, b), outputGate (W, b).
756
+ getWeightsFlat() {
757
+ const w = [];
758
+ for (const row of this.forgetGate.W) w.push(...row);
759
+ w.push(...this.forgetGate.b);
760
+ for (const row of this.inputGate.W) w.push(...row);
761
+ w.push(...this.inputGate.b);
762
+ for (const row of this.cellGate.W) w.push(...row);
763
+ w.push(...this.cellGate.b);
764
+ for (const row of this.outputGate.W) w.push(...row);
765
+ w.push(...this.outputGate.b);
766
+ return w;
767
+ }
768
+ setWeightsFlat(weights) {
769
+ let idx = 0;
770
+ for (let i = 0; i < this.forgetGate.W.length; i++)
771
+ for (let j = 0; j < this.forgetGate.W[i].length; j++) this.forgetGate.W[i][j] = weights[idx++];
772
+ for (let i = 0; i < this.forgetGate.b.length; i++) this.forgetGate.b[i] = weights[idx++];
773
+ for (let i = 0; i < this.inputGate.W.length; i++)
774
+ for (let j = 0; j < this.inputGate.W[i].length; j++) this.inputGate.W[i][j] = weights[idx++];
775
+ for (let i = 0; i < this.inputGate.b.length; i++) this.inputGate.b[i] = weights[idx++];
776
+ for (let i = 0; i < this.cellGate.W.length; i++)
777
+ for (let j = 0; j < this.cellGate.W[i].length; j++) this.cellGate.W[i][j] = weights[idx++];
778
+ for (let i = 0; i < this.cellGate.b.length; i++) this.cellGate.b[i] = weights[idx++];
779
+ for (let i = 0; i < this.outputGate.W.length; i++)
780
+ for (let j = 0; j < this.outputGate.W[i].length; j++) this.outputGate.W[i][j] = weights[idx++];
781
+ for (let i = 0; i < this.outputGate.b.length; i++) this.outputGate.b[i] = weights[idx++];
782
+ }
435
783
  };
436
784
 
437
785
  // src/NetworkLSTM.ts
@@ -458,6 +806,7 @@ var NetworkLSTM = class {
458
806
  }
459
807
  // ── Forward pass ──────────────────────────────────────────────────────────
460
808
  predict(inputs) {
809
+ validateArray(inputs, this.inputSize, "NetworkLSTM.predict");
461
810
  const h = this.lstm.predict(inputs);
462
811
  const acts = [h];
463
812
  for (const layer of this.denseLayers) {
@@ -533,6 +882,30 @@ var NetworkLSTM = class {
533
882
  });
534
883
  });
535
884
  }
885
+ // ── Flat weight serialization ─────────────────────────────────────────────
886
+ // Order: LSTM (flat), then dense layer 0, dense layer 1, ..., dense layer N.
887
+ getWeightsFlat() {
888
+ const w = [];
889
+ w.push(...this.lstm.getWeightsFlat());
890
+ for (const layer of this.denseLayers) {
891
+ for (const n of layer.neurons) {
892
+ w.push(...n.weights, n.bias);
893
+ }
894
+ }
895
+ return w;
896
+ }
897
+ setWeightsFlat(weights) {
898
+ let idx = 0;
899
+ const lstmLen = this.lstm.getWeightsFlat().length;
900
+ this.lstm.setWeightsFlat(weights.slice(idx, idx + lstmLen));
901
+ idx += lstmLen;
902
+ for (const layer of this.denseLayers) {
903
+ for (const n of layer.neurons) {
904
+ for (let j = 0; j < n.weights.length; j++) n.weights[j] = weights[idx++];
905
+ n.bias = weights[idx++];
906
+ }
907
+ }
908
+ }
536
909
  };
537
910
 
538
911
  // src/MatMul.ts
@@ -540,6 +913,9 @@ function matMul(A, B) {
540
913
  const rows = A.length;
541
914
  const inner = B.length;
542
915
  const cols = B[0].length;
916
+ if (A[0].length !== B.length) {
917
+ throw new Error(`Incompatible dimensions for matrix multiplication: A cols (${A[0].length}) !== B rows (${B.length})`);
918
+ }
543
919
  const C = Array.from({ length: rows }, () => new Array(cols).fill(0));
544
920
  for (let i = 0; i < rows; i++)
545
921
  for (let k = 0; k < inner; k++) {
@@ -590,6 +966,17 @@ var WeightMatrix = class {
590
966
  this.W[i][j] = this.opts[i][j].step(this.W[i][j], g, lr);
591
967
  }
592
968
  }
969
+ // ── Flat weight serialization ─────────────────────────────────────────────
970
+ getWeights() {
971
+ const w = [];
972
+ for (const row of this.W) w.push(...row);
973
+ return w;
974
+ }
975
+ setWeights(weights) {
976
+ let idx = 0;
977
+ for (let i = 0; i < this.W.length; i++)
978
+ for (let j = 0; j < this.W[i].length; j++) this.W[i][j] = weights[idx++];
979
+ }
593
980
  };
594
981
  var EmbeddingMatrix = class {
595
982
  constructor(vocabSize, d_model) {
@@ -606,15 +993,29 @@ var EmbeddingMatrix = class {
606
993
  for (let m = 0; m < this.W[idx].length; m++)
607
994
  this.W[idx][m] += lr * grad[m];
608
995
  }
996
+ // ── Serializable interface ─────────────────────────────────────────────────
997
+ // Flattened order: row 0, row 1, ... row (vocabSize-1)
998
+ getWeights() {
999
+ const w = [];
1000
+ for (const row of this.W) w.push(...row);
1001
+ return w;
1002
+ }
1003
+ setWeights(weights) {
1004
+ let idx = 0;
1005
+ for (let i = 0; i < this.W.length; i++)
1006
+ for (let j = 0; j < this.W[i].length; j++)
1007
+ this.W[i][j] = weights[idx++];
1008
+ }
609
1009
  };
610
1010
 
611
1011
  // src/AttentionHead.ts
612
1012
  var AttentionHead = class {
613
- constructor(d_model, d_k, d_v) {
1013
+ constructor(d_model, d_k, d_v, causal = false) {
614
1014
  // d_v × d_model
615
1015
  this.cache = null;
616
1016
  this.d_k = d_k;
617
1017
  this.d_v = d_v;
1018
+ this.causal = causal;
618
1019
  this.Wq = new WeightMatrix(d_k, d_model);
619
1020
  this.Wk = new WeightMatrix(d_k, d_model);
620
1021
  this.Wv = new WeightMatrix(d_v, d_model);
@@ -635,10 +1036,10 @@ var AttentionHead = class {
635
1036
  );
636
1037
  const scores = Array.from(
637
1038
  { length: seqLen },
638
- (_, i) => Array.from(
639
- { length: seqLen },
640
- (_2, j) => Q[i].reduce((s, q, k) => s + q * K[j][k], 0) * scale
641
- )
1039
+ (_, i) => Array.from({ length: seqLen }, (_2, j) => {
1040
+ if (this.causal && j > i) return -Infinity;
1041
+ return Q[i].reduce((s, q, k) => s + q * K[j][k], 0) * scale;
1042
+ })
642
1043
  );
643
1044
  const attn = scores.map((row) => softmax(row));
644
1045
  const out = Array.from(
@@ -734,21 +1135,40 @@ var AttentionHead = class {
734
1135
  getAttentionWeights() {
735
1136
  return this.cache ? this.cache.attn : null;
736
1137
  }
1138
+ // ── Flat weight serialization ─────────────────────────────────────────────
1139
+ // Order: Wq, Wk, Wv.
1140
+ getWeights() {
1141
+ const w = [];
1142
+ for (const row of this.Wq.W) w.push(...row);
1143
+ for (const row of this.Wk.W) w.push(...row);
1144
+ for (const row of this.Wv.W) w.push(...row);
1145
+ return w;
1146
+ }
1147
+ setWeights(weights) {
1148
+ let idx = 0;
1149
+ for (let i = 0; i < this.Wq.W.length; i++)
1150
+ for (let j = 0; j < this.Wq.W[i].length; j++) this.Wq.W[i][j] = weights[idx++];
1151
+ for (let i = 0; i < this.Wk.W.length; i++)
1152
+ for (let j = 0; j < this.Wk.W[i].length; j++) this.Wk.W[i][j] = weights[idx++];
1153
+ for (let i = 0; i < this.Wv.W.length; i++)
1154
+ for (let j = 0; j < this.Wv.W[i].length; j++) this.Wv.W[i][j] = weights[idx++];
1155
+ }
737
1156
  };
738
1157
 
739
1158
  // src/MultiHeadAttention.ts
740
1159
  var MultiHeadAttention = class {
741
1160
  // seqLen × (nHeads * d_k)
742
- constructor(d_model, nHeads) {
1161
+ constructor(d_model, nHeads, causal = false) {
743
1162
  // d_model × (nHeads * d_k)
744
1163
  // Cached for backward
745
1164
  this._concat = null;
746
1165
  this.nHeads = nHeads;
747
1166
  this.d_model = d_model;
748
1167
  this.d_k = Math.floor(d_model / nHeads);
1168
+ this.causal = causal;
749
1169
  this.heads = Array.from(
750
1170
  { length: nHeads },
751
- () => new AttentionHead(d_model, this.d_k, this.d_k)
1171
+ () => new AttentionHead(d_model, this.d_k, this.d_k, causal)
752
1172
  );
753
1173
  this.Wo = new WeightMatrix(d_model, nHeads * this.d_k);
754
1174
  }
@@ -807,6 +1227,31 @@ var MultiHeadAttention = class {
807
1227
  getAttentionWeights() {
808
1228
  return this.heads.map((h) => h.getAttentionWeights());
809
1229
  }
1230
+ // ── Flat weight serialization ─────────────────────────────────────────────
1231
+ // Order: head0 (Wq, Wk, Wv), head1, ..., headN, then Wo.
1232
+ getWeights() {
1233
+ const w = [];
1234
+ for (const head of this.heads) {
1235
+ for (const row of head.Wq.W) w.push(...row);
1236
+ for (const row of head.Wk.W) w.push(...row);
1237
+ for (const row of head.Wv.W) w.push(...row);
1238
+ }
1239
+ for (const row of this.Wo.W) w.push(...row);
1240
+ return w;
1241
+ }
1242
+ setWeights(weights) {
1243
+ let idx = 0;
1244
+ for (const head of this.heads) {
1245
+ for (let i = 0; i < head.Wq.W.length; i++)
1246
+ for (let j = 0; j < head.Wq.W[i].length; j++) head.Wq.W[i][j] = weights[idx++];
1247
+ for (let i = 0; i < head.Wk.W.length; i++)
1248
+ for (let j = 0; j < head.Wk.W[i].length; j++) head.Wk.W[i][j] = weights[idx++];
1249
+ for (let i = 0; i < head.Wv.W.length; i++)
1250
+ for (let j = 0; j < head.Wv.W[i].length; j++) head.Wv.W[i][j] = weights[idx++];
1251
+ }
1252
+ for (let i = 0; i < this.Wo.W.length; i++)
1253
+ for (let j = 0; j < this.Wo.W[i].length; j++) this.Wo.W[i][j] = weights[idx++];
1254
+ }
810
1255
  };
811
1256
 
812
1257
  // src/LayerNorm.ts
@@ -858,11 +1303,21 @@ var LayerNorm = class {
858
1303
  const mDxn = D.reduce((s, d, i) => s + d * x_norm[i], 0) / N;
859
1304
  return D.map((d, i) => (d - mD - x_norm[i] * mDxn) / std);
860
1305
  }
1306
+ // ── Flat weight serialization ─────────────────────────────────────────────
1307
+ // Order: gamma, beta.
1308
+ getWeights() {
1309
+ return [...this.gamma, ...this.beta];
1310
+ }
1311
+ setWeights(weights) {
1312
+ const dim = this.gamma.length;
1313
+ for (let i = 0; i < dim; i++) this.gamma[i] = weights[i];
1314
+ for (let i = 0; i < dim; i++) this.beta[i] = weights[dim + i];
1315
+ }
861
1316
  };
862
1317
 
863
1318
  // src/TransformerBlock.ts
864
1319
  var TransformerBlock = class {
865
- constructor({ d_model, nHeads, d_ff }) {
1320
+ constructor({ d_model, nHeads, d_ff, causal = false }) {
866
1321
  // Forward caches (needed for backprop)
867
1322
  this._X = null;
868
1323
  this._attnOut = null;
@@ -874,7 +1329,7 @@ var TransformerBlock = class {
874
1329
  this._ff2Out = null;
875
1330
  this.d_model = d_model;
876
1331
  this.d_ff = d_ff;
877
- this.attn = new MultiHeadAttention(d_model, nHeads);
1332
+ this.attn = new MultiHeadAttention(d_model, nHeads, causal);
878
1333
  this.norm1 = new LayerNorm(d_model);
879
1334
  this.norm2 = new LayerNorm(d_model);
880
1335
  this.ff1 = new WeightMatrix(d_ff, d_model);
@@ -987,6 +1442,35 @@ var TransformerBlock = class {
987
1442
  getAttentionWeights() {
988
1443
  return this.attn.getAttentionWeights();
989
1444
  }
1445
+ // ── Flat weight serialization ─────────────────────────────────────────────
1446
+ // Order: attn (MHA), norm1 (gamma, beta), ff1, b1, ff2, b2, norm2 (gamma, beta).
1447
+ getWeights() {
1448
+ const w = [];
1449
+ w.push(...this.attn.getWeights());
1450
+ w.push(...this.norm1.gamma, ...this.norm1.beta);
1451
+ for (const row of this.ff1.W) w.push(...row);
1452
+ w.push(...this.b1);
1453
+ for (const row of this.ff2.W) w.push(...row);
1454
+ w.push(...this.b2);
1455
+ w.push(...this.norm2.gamma, ...this.norm2.beta);
1456
+ return w;
1457
+ }
1458
+ setWeights(weights) {
1459
+ let idx = 0;
1460
+ const attnLen = this.attn.getWeights().length;
1461
+ this.attn.setWeights(weights.slice(idx, idx + attnLen));
1462
+ idx += attnLen;
1463
+ for (let i = 0; i < this.norm1.gamma.length; i++) this.norm1.gamma[i] = weights[idx++];
1464
+ for (let i = 0; i < this.norm1.beta.length; i++) this.norm1.beta[i] = weights[idx++];
1465
+ for (let i = 0; i < this.ff1.W.length; i++)
1466
+ for (let j = 0; j < this.ff1.W[i].length; j++) this.ff1.W[i][j] = weights[idx++];
1467
+ for (let i = 0; i < this.b1.length; i++) this.b1[i] = weights[idx++];
1468
+ for (let i = 0; i < this.ff2.W.length; i++)
1469
+ for (let j = 0; j < this.ff2.W[i].length; j++) this.ff2.W[i][j] = weights[idx++];
1470
+ for (let i = 0; i < this.b2.length; i++) this.b2[i] = weights[idx++];
1471
+ for (let i = 0; i < this.norm2.gamma.length; i++) this.norm2.gamma[i] = weights[idx++];
1472
+ for (let i = 0; i < this.norm2.beta.length; i++) this.norm2.beta[i] = weights[idx++];
1473
+ }
990
1474
  };
991
1475
 
992
1476
  // src/NetworkTransformer.ts
@@ -1085,6 +1569,32 @@ var NetworkTransformer = class {
1085
1569
  getAttentionWeights() {
1086
1570
  return this.blocks.map((b) => b.getAttentionWeights());
1087
1571
  }
1572
+ // ── Flat weight serialization ─────────────────────────────────────────────
1573
+ // Order: tokenEmb, posEmb, block0, block1, ..., blockN, outputProj, outputBias.
1574
+ getWeights() {
1575
+ const w = [];
1576
+ for (const row of this.tokenEmb.W) w.push(...row);
1577
+ for (const row of this.posEmb.W) w.push(...row);
1578
+ for (const block of this.blocks) w.push(...block.getWeights());
1579
+ for (const row of this.outputProj.W) w.push(...row);
1580
+ w.push(...this.outputBias);
1581
+ return w;
1582
+ }
1583
+ setWeights(weights) {
1584
+ let idx = 0;
1585
+ for (let i = 0; i < this.tokenEmb.W.length; i++)
1586
+ for (let j = 0; j < this.tokenEmb.W[i].length; j++) this.tokenEmb.W[i][j] = weights[idx++];
1587
+ for (let i = 0; i < this.posEmb.W.length; i++)
1588
+ for (let j = 0; j < this.posEmb.W[i].length; j++) this.posEmb.W[i][j] = weights[idx++];
1589
+ for (const block of this.blocks) {
1590
+ const blockLen = block.getWeights().length;
1591
+ block.setWeights(weights.slice(idx, idx + blockLen));
1592
+ idx += blockLen;
1593
+ }
1594
+ for (let i = 0; i < this.outputProj.W.length; i++)
1595
+ for (let j = 0; j < this.outputProj.W[i].length; j++) this.outputProj.W[i][j] = weights[idx++];
1596
+ for (let i = 0; i < this.outputBias.length; i++) this.outputBias[i] = weights[idx++];
1597
+ }
1088
1598
  // ── Internal ──────────────────────────────────────────────────────────────
1089
1599
  // Shared embedding + block forward pass.
1090
1600
  _forward(tokens) {
@@ -1104,21 +1614,25 @@ var NetworkTransformerRL = class {
1104
1614
  constructor(seqLen, inputDim, options = {}) {
1105
1615
  // Forward caches para backprop
1106
1616
  this._projected = null;
1617
+ // For max pooling backward: argmax per dimension across all positions
1618
+ this._argmax = null;
1107
1619
  const {
1108
1620
  d_model = 32,
1109
1621
  nHeads = 2,
1110
1622
  d_ff = 64,
1111
1623
  nBlocks = 2,
1112
- nActions = 2
1624
+ nActions = 2,
1625
+ pooling = "weighted"
1113
1626
  } = options;
1114
1627
  this.seqLen = seqLen;
1115
1628
  this.inputDim = inputDim;
1116
1629
  this.d_model = d_model;
1117
1630
  this.nActions = nActions;
1631
+ this._pooling = pooling;
1118
1632
  this.inputProj = new WeightMatrix(d_model, inputDim);
1119
1633
  this.blocks = Array.from(
1120
1634
  { length: nBlocks },
1121
- () => new TransformerBlock({ d_model, nHeads, d_ff })
1635
+ () => new TransformerBlock({ d_model, nHeads, d_ff, causal: true })
1122
1636
  );
1123
1637
  this.outputProj = new WeightMatrix(nActions, d_model);
1124
1638
  this.outputBias = new Array(nActions).fill(0);
@@ -1167,11 +1681,7 @@ var NetworkTransformerRL = class {
1167
1681
  this.outputProj.update(dWout, lr);
1168
1682
  for (let c = 0; c < this.nActions; c++)
1169
1683
  this.outputBias[c] = this.outBiasOpts[c].step(this.outputBias[c], dBout[c], lr);
1170
- let dH = Array.from(
1171
- { length: this.seqLen },
1172
- (_, i) => dPooled.map((v) => v / this.seqLen)
1173
- // Gradiente dividido entre posiciones
1174
- );
1684
+ let dH = this._distributePoolGradient(dPooled);
1175
1685
  for (let b = this.blocks.length - 1; b >= 0; b--)
1176
1686
  dH = this.blocks[b].backward(dH, lr);
1177
1687
  for (let i = 0; i < this.seqLen; i++) {
@@ -1190,6 +1700,85 @@ var NetworkTransformerRL = class {
1190
1700
  getAttentionWeights() {
1191
1701
  return this.blocks.map((b) => b.getAttentionWeights());
1192
1702
  }
1703
+ // ── Flat weight serialization ─────────────────────────────────────────────
1704
+ // Order: inputProj, block0, block1, ..., blockN, outputProj, outputBias.
1705
+ getWeightsFlat() {
1706
+ const w = [];
1707
+ for (const row of this.inputProj.W) w.push(...row);
1708
+ for (const block of this.blocks) w.push(...block.getWeights());
1709
+ for (const row of this.outputProj.W) w.push(...row);
1710
+ w.push(...this.outputBias);
1711
+ return w;
1712
+ }
1713
+ setWeightsFlat(weights) {
1714
+ let idx = 0;
1715
+ for (let i = 0; i < this.inputProj.W.length; i++)
1716
+ for (let j = 0; j < this.inputProj.W[i].length; j++) this.inputProj.W[i][j] = weights[idx++];
1717
+ for (const block of this.blocks) {
1718
+ const blockLen = block.getWeights().length;
1719
+ block.setWeights(weights.slice(idx, idx + blockLen));
1720
+ idx += blockLen;
1721
+ }
1722
+ for (let i = 0; i < this.outputProj.W.length; i++)
1723
+ for (let j = 0; j < this.outputProj.W[i].length; j++) this.outputProj.W[i][j] = weights[idx++];
1724
+ for (let i = 0; i < this.outputBias.length; i++) this.outputBias[i] = weights[idx++];
1725
+ }
1726
+ getWeightsStructured() {
1727
+ return {
1728
+ inputProj: this.inputProj.W.map((r) => [...r]),
1729
+ blocks: this.blocks.map((b) => ({
1730
+ attn: {
1731
+ heads: b.attn.heads.map((h) => ({
1732
+ Wq: h.Wq.W.map((r) => [...r]),
1733
+ Wk: h.Wk.W.map((r) => [...r]),
1734
+ Wv: h.Wv.W.map((r) => [...r])
1735
+ })),
1736
+ Wo: b.attn.Wo.W.map((r) => [...r])
1737
+ },
1738
+ norm1: { gamma: [...b.norm1.gamma], beta: [...b.norm1.beta] },
1739
+ norm2: { gamma: [...b.norm2.gamma], beta: [...b.norm2.beta] },
1740
+ ff1: b.ff1.W.map((r) => [...r]),
1741
+ ff2: b.ff2.W.map((r) => [...r]),
1742
+ b1: [...b.b1],
1743
+ b2: [...b.b2]
1744
+ })),
1745
+ outputProj: this.outputProj.W.map((r) => [...r]),
1746
+ outputBias: [...this.outputBias]
1747
+ };
1748
+ }
1749
+ setWeightsStructured(data) {
1750
+ data.inputProj.forEach((row, i) => {
1751
+ this.inputProj.W[i] = [...row];
1752
+ });
1753
+ data.blocks.forEach((bd, b) => {
1754
+ const blk = this.blocks[b];
1755
+ bd.attn.heads.forEach((hd, h) => {
1756
+ blk.attn.heads[h].Wq.W = hd.Wq.map((r) => [...r]);
1757
+ blk.attn.heads[h].Wk.W = hd.Wk.map((r) => [...r]);
1758
+ blk.attn.heads[h].Wv.W = hd.Wv.map((r) => [...r]);
1759
+ });
1760
+ blk.attn.Wo.W = bd.attn.Wo.map((r) => [...r]);
1761
+ blk.norm1.gamma = [...bd.norm1.gamma];
1762
+ blk.norm1.beta = [...bd.norm1.beta];
1763
+ blk.norm2.gamma = [...bd.norm2.gamma];
1764
+ blk.norm2.beta = [...bd.norm2.beta];
1765
+ blk.ff1.W = bd.ff1.map((r) => [...r]);
1766
+ blk.ff2.W = bd.ff2.map((r) => [...r]);
1767
+ blk.b1 = [...bd.b1];
1768
+ blk.b2 = [...bd.b2];
1769
+ });
1770
+ this.outputProj.W = data.outputProj.map((r) => [...r]);
1771
+ this.outputBias = [...data.outputBias];
1772
+ }
1773
+ // ── Serializable interface (flat array) ────────────────────────────────────
1774
+ // These satisfy the Serializable interface from ModelSaver, which requires
1775
+ // getWeights(): number[] and setWeights(weights: number[]): void.
1776
+ getWeights() {
1777
+ return this.getWeightsFlat();
1778
+ }
1779
+ setWeights(weights) {
1780
+ this.setWeightsFlat(weights);
1781
+ }
1193
1782
  // ── Internal ────────────────────────────────────────────────────────────────
1194
1783
  _forward(sequence) {
1195
1784
  let h = sequence.map(
@@ -1203,6 +1792,44 @@ var NetworkTransformerRL = class {
1203
1792
  return h;
1204
1793
  }
1205
1794
  _pool(h) {
1795
+ switch (this._pooling) {
1796
+ case "avg":
1797
+ return this._poolAvg(h);
1798
+ case "max":
1799
+ return this._poolMax(h);
1800
+ case "last":
1801
+ return this._poolLast(h);
1802
+ case "weighted":
1803
+ default:
1804
+ return this._poolWeighted(h);
1805
+ }
1806
+ }
1807
+ _poolAvg(h) {
1808
+ const n = h.length;
1809
+ return Array.from({ length: this.d_model }, (_, m) => {
1810
+ let sum = 0;
1811
+ for (let i = 0; i < n; i++)
1812
+ sum += h[i][m];
1813
+ return sum / n;
1814
+ });
1815
+ }
1816
+ _poolMax(h) {
1817
+ this._argmax = new Array(this.d_model).fill(0);
1818
+ return Array.from({ length: this.d_model }, (_, m) => {
1819
+ let maxVal = -Infinity;
1820
+ for (let i = 0; i < h.length; i++) {
1821
+ if (h[i][m] > maxVal) {
1822
+ maxVal = h[i][m];
1823
+ this._argmax[m] = i;
1824
+ }
1825
+ }
1826
+ return maxVal;
1827
+ });
1828
+ }
1829
+ _poolLast(h) {
1830
+ return [...h[h.length - 1]];
1831
+ }
1832
+ _poolWeighted(h) {
1206
1833
  const weights = Array.from(
1207
1834
  { length: this.seqLen },
1208
1835
  (_, i) => i === this.seqLen - 1 ? 2 : 1
@@ -1215,6 +1842,55 @@ var NetworkTransformerRL = class {
1215
1842
  return sum / totalWeight;
1216
1843
  });
1217
1844
  }
1845
+ /** Returns the current pooling type for inspection. */
1846
+ getPoolingType() {
1847
+ return this._pooling;
1848
+ }
1849
+ // ── Helper: distribute pooled gradient back to each position ────────────────
1850
+ // Must match the same distribution as _pool() used during forward.
1851
+ _distributePoolGradient(dPooled) {
1852
+ switch (this._pooling) {
1853
+ case "avg": {
1854
+ const n = this.seqLen;
1855
+ return Array.from(
1856
+ { length: n },
1857
+ () => dPooled.map((v) => v / n)
1858
+ );
1859
+ }
1860
+ case "max": {
1861
+ if (!this._argmax) {
1862
+ const n = this.seqLen;
1863
+ return Array.from(
1864
+ { length: n },
1865
+ () => dPooled.map((v) => v / n)
1866
+ );
1867
+ }
1868
+ const argmax = this._argmax;
1869
+ return Array.from(
1870
+ { length: this.seqLen },
1871
+ (_, i) => dPooled.map((v, m) => i === argmax[m] ? v : 0)
1872
+ );
1873
+ }
1874
+ case "last": {
1875
+ return Array.from(
1876
+ { length: this.seqLen },
1877
+ (_, i) => i === this.seqLen - 1 ? [...dPooled] : new Array(this.d_model).fill(0)
1878
+ );
1879
+ }
1880
+ case "weighted":
1881
+ default: {
1882
+ const weights = Array.from(
1883
+ { length: this.seqLen },
1884
+ (_, i) => i === this.seqLen - 1 ? 2 : 1
1885
+ );
1886
+ const totalWeight = weights.reduce((a, b) => a + b, 0);
1887
+ return Array.from(
1888
+ { length: this.seqLen },
1889
+ (_, i) => dPooled.map((v) => v * weights[i] / totalWeight)
1890
+ );
1891
+ }
1892
+ }
1893
+ }
1218
1894
  };
1219
1895
 
1220
1896
  // src/losses.ts
@@ -1239,14 +1915,802 @@ function crossEntropyDeltaRaw(predicted, actual) {
1239
1915
  const p = Math.max(eps, Math.min(1 - eps, predicted));
1240
1916
  return actual / p - (1 - actual) / (1 - p);
1241
1917
  }
1918
+
1919
+ // src/GRU.ts
1920
function sigmoid4(x) {
  // Logistic function 1 / (1 + e^(-x)). For very negative x the exponential
  // overflows to Infinity and the result correctly collapses to 0.
  const denominator = Math.exp(-x) + 1;
  return 1 / denominator;
}
1923
/**
 * Hyperbolic tangent.
 *
 * Uses Math.tanh instead of the (e^(2x) - 1) / (e^(2x) + 1) formula. That
 * formula overflows to Infinity / Infinity = NaN once x exceeds about 355,
 * which would poison the GRU hidden state when pre-activations grow large.
 *
 * @param {number} x
 * @returns {number} Value in [-1, 1].
 */
function tanhFn(x) {
  return Math.tanh(x);
}
1927
var Gate2 = class {
  /**
   * One GRU gate: an affine map over the concatenation [input, hidden].
   *
   * @param {number} inputSize - Length of the input part of the concatenation.
   * @param {number} hSize - Hidden size; also the number of gate outputs.
   * @param {number} [initBias=0] - Initial value for every bias entry.
   */
  constructor(inputSize, hSize, initBias = 0) {
    const fanIn = inputSize + hSize;
    // Uniform init in [-sqrt(2/fanIn), sqrt(2/fanIn)].
    const scale = Math.sqrt(2 / fanIn);
    const randomWeight = () => (Math.random() * 2 - 1) * scale;
    this.W = Array.from({ length: hSize }, () => Array.from({ length: fanIn }, randomWeight));
    this.b = new Array(hSize).fill(initBias);
  }

  /**
   * Computes W · combined + b.
   * @param {number[]} combined - Concatenated [input, hidden] vector.
   * @returns {number[]} One pre-activation per gate unit.
   */
  linear(combined) {
    const out = [];
    for (let i = 0; i < this.W.length; i++) {
      const row = this.W[i];
      let acc = this.b[i];
      for (let j = 0; j < row.length; j++) acc += row[j] * combined[j];
      out.push(acc);
    }
    return out;
  }
};
1943
var GRULayer = class {
  /**
   * Gated recurrent unit layer. Each predict() call advances one time step and
   * records the intermediates that backprop() needs for backpropagation
   * through time.
   *
   * @param {number} inputSize - Length of each input vector.
   * @param {number} hiddenSize - Length of the hidden state.
   * @param {Function} [optimizerFactory] - Called once per scalar parameter to
   *   create its optimizer (an object exposing step(value, grad, lr)).
   *   Defaults to SGD.
   * @throws {Error} If either size is not positive.
   */
  constructor(inputSize, hiddenSize, optimizerFactory = () => new SGD()) {
    // Per-step forward caches, consumed and cleared by backprop().
    this._traj = [];
    if (inputSize <= 0 || hiddenSize <= 0) {
      throw new Error(`GRULayer: inputSize and hiddenSize must be positive, got ${inputSize} and ${hiddenSize}`);
    }
    this.inputSize = inputSize;
    this.hSize = hiddenSize;
    this.h = new Array(hiddenSize).fill(0);
    this.resetGate = new Gate2(inputSize, hiddenSize);
    this.updateGate = new Gate2(inputSize, hiddenSize);
    this.newGate = new Gate2(inputSize, hiddenSize);
    const combSize = inputSize + hiddenSize;
    this._optimizers = {
      resetW: Array.from(
        { length: hiddenSize },
        () => Array.from({ length: combSize }, () => optimizerFactory())
      ),
      resetB: Array.from({ length: hiddenSize }, () => optimizerFactory()),
      updateW: Array.from(
        { length: hiddenSize },
        () => Array.from({ length: combSize }, () => optimizerFactory())
      ),
      updateB: Array.from({ length: hiddenSize }, () => optimizerFactory()),
      newW: Array.from(
        { length: hiddenSize },
        () => Array.from({ length: combSize }, () => optimizerFactory())
      ),
      newB: Array.from({ length: hiddenSize }, () => optimizerFactory())
    };
  }

  /** Zeroes the hidden state and discards any recorded trajectory. */
  reset() {
    this.h = new Array(this.hSize).fill(0);
    this._traj = [];
  }

  /**
   * Advances one time step.
   *   r = σ(Wr·[x, h]),  z = σ(Wz·[x, h]),
   *   n = tanh(Wn·[x, r⊙h]),  h' = (1 − z)⊙n + z⊙h
   *
   * @param {number[]} inputs - Vector of length inputSize.
   * @returns {number[]} The new hidden state.
   * @throws {Error} If inputs is not an array of length inputSize.
   */
  predict(inputs) {
    if (!Array.isArray(inputs) || inputs.length !== this.inputSize) {
      throw new Error(`GRULayer.predict: expected array of length ${this.inputSize}, got ${inputs?.length}`);
    }
    const combined = [...inputs, ...this.h];
    const h_prev = [...this.h];
    const r_pre = this.resetGate.linear(combined);
    const z_pre = this.updateGate.linear(combined);
    const r_a = r_pre.map(sigmoid4);
    const z_a = z_pre.map(sigmoid4);
    const combined_r = [...inputs, ...r_a.map((r, i) => r * h_prev[i])];
    const n_pre = this.newGate.linear(combined_r);
    const n_a = n_pre.map(tanhFn);
    const h = n_a.map((n, i) => (1 - z_a[i]) * n + z_a[i] * h_prev[i]);
    this._traj.push({ combined, h_prev, r: r_pre, r_a, z: z_pre, z_a, combined_r, n_pre, n_a, h });
    this.h = h;
    return h;
  }

  /**
   * Backpropagation through time over every step recorded since the last
   * reset/backprop. Gradients are accumulated, then applied once through the
   * per-parameter optimizers with learning rate lr / T.
   *
   * @param {number[][]} dh_seq - One hidden-state gradient per recorded step.
   * @param {number} lr - Learning rate.
   * NOTE(review): when dh_seq.length differs from the number of recorded steps
   * this returns without updating and keeps the trajectory; confirm callers
   * rely on that.
   */
  backprop(dh_seq, lr) {
    const T = this._traj.length;
    if (T === 0 || dh_seq.length !== T) return;
    const hSize = this.hSize;
    const combSize = this.inputSize + hSize;
    const dWr = Array.from({ length: hSize }, () => new Array(combSize).fill(0));
    const dWz = Array.from({ length: hSize }, () => new Array(combSize).fill(0));
    const dWn = Array.from({ length: hSize }, () => new Array(combSize).fill(0));
    const dbr = new Array(hSize).fill(0);
    const dbz = new Array(hSize).fill(0);
    const dbn = new Array(hSize).fill(0);
    let dh_next = new Array(hSize).fill(0);
    for (let t = T - 1; t >= 0; t--) {
      const s = this._traj[t];
      // Total gradient on h_t: external signal plus what flows back from t+1.
      const dh = dh_seq[t].map((d, i) => d + dh_next[i]);
      const dz_a = dh.map((d, i) => (s.h_prev[i] - s.n_a[i]) * d);
      const dn_a = dh.map((d, i) => (1 - s.z_a[i]) * d);
      const dn_pre = dn_a.map((d, i) => d * (1 - s.n_a[i] ** 2));
      const dz_pre = dz_a.map((d, i) => d * s.z_a[i] * (1 - s.z_a[i]));
      // Gradient w.r.t. (r ⊙ h_prev), the hidden half of the new gate's input.
      const dr_hprev = Array.from(
        { length: hSize },
        (_, i) => this.newGate.W.reduce((sum, row, k) => sum + dn_pre[k] * row[this.inputSize + i], 0)
      );
      const dr_a = dr_hprev.map((d, i) => d * s.h_prev[i]);
      const dr_pre = dr_a.map((d, i) => d * s.r_a[i] * (1 - s.r_a[i]));
      for (let k = 0; k < hSize; k++) {
        for (let j = 0; j < combSize; j++) {
          dWr[k][j] += dr_pre[k] * s.combined[j];
          dWz[k][j] += dz_pre[k] * s.combined[j];
          dWn[k][j] += dn_pre[k] * s.combined_r[j];
        }
        dbr[k] += dr_pre[k];
        dbz[k] += dz_pre[k];
        dbn[k] += dn_pre[k];
      }
      // Gradient reaching h_prev through the r and z gates, through r ⊙ h_prev,
      // and through the direct z ⊙ h_prev path.
      dh_next = new Array(hSize).fill(0);
      for (let k = 0; k < hSize; k++) {
        for (let j = this.inputSize; j < combSize; j++) {
          dh_next[j - this.inputSize] += dr_pre[k] * this.resetGate.W[k][j] + dz_pre[k] * this.updateGate.W[k][j];
        }
        dh_next[k] += dr_hprev[k] * s.r_a[k];
        dh_next[k] += dh[k] * s.z_a[k];
      }
    }
    // Weights are updated only after the full sweep, so every step above used
    // the same (pre-update) parameters.
    const scale = lr / T;
    for (let k = 0; k < hSize; k++) {
      for (let j = 0; j < combSize; j++) {
        this.resetGate.W[k][j] = this._optimizers.resetW[k][j].step(this.resetGate.W[k][j], dWr[k][j], scale);
        this.updateGate.W[k][j] = this._optimizers.updateW[k][j].step(this.updateGate.W[k][j], dWz[k][j], scale);
        this.newGate.W[k][j] = this._optimizers.newW[k][j].step(this.newGate.W[k][j], dWn[k][j], scale);
      }
      this.resetGate.b[k] = this._optimizers.resetB[k].step(this.resetGate.b[k], dbr[k], scale);
      this.updateGate.b[k] = this._optimizers.updateB[k].step(this.updateGate.b[k], dbz[k], scale);
      this.newGate.b[k] = this._optimizers.newB[k].step(this.newGate.b[k], dbn[k], scale);
    }
    this._traj = [];
  }

  // ── Flat weight serialization ─────────────────────────────────────────────
  // Order: resetGate (W, b), updateGate (W, b), newGate (W, b).
  getWeightsFlat() {
    const w = [];
    for (const row of this.resetGate.W) w.push(...row);
    w.push(...this.resetGate.b);
    for (const row of this.updateGate.W) w.push(...row);
    w.push(...this.updateGate.b);
    for (const row of this.newGate.W) w.push(...row);
    w.push(...this.newGate.b);
    return w;
  }

  setWeightsFlat(weights) {
    let idx = 0;
    for (let i = 0; i < this.resetGate.W.length; i++)
      for (let j = 0; j < this.resetGate.W[i].length; j++) this.resetGate.W[i][j] = weights[idx++];
    for (let i = 0; i < this.resetGate.b.length; i++) this.resetGate.b[i] = weights[idx++];
    for (let i = 0; i < this.updateGate.W.length; i++)
      for (let j = 0; j < this.updateGate.W[i].length; j++) this.updateGate.W[i][j] = weights[idx++];
    for (let i = 0; i < this.updateGate.b.length; i++) this.updateGate.b[i] = weights[idx++];
    for (let i = 0; i < this.newGate.W.length; i++)
      for (let j = 0; j < this.newGate.W[i].length; j++) this.newGate.W[i][j] = weights[idx++];
    for (let i = 0; i < this.newGate.b.length; i++) this.newGate.b[i] = weights[idx++];
  }

  /**
   * Structured snapshot of all gate parameters.
   * Returns deep copies. Previously the live arrays were returned, so a saved
   * "best weights" snapshot silently changed as training continued.
   * @returns {{resetGate: {W: number[][], b: number[]}, updateGate: {W: number[][], b: number[]}, newGate: {W: number[][], b: number[]}}}
   */
  getWeights() {
    const copyGate = (gate) => ({ W: gate.W.map((row) => [...row]), b: [...gate.b] });
    return {
      resetGate: copyGate(this.resetGate),
      updateGate: copyGate(this.updateGate),
      newGate: copyGate(this.newGate)
    };
  }

  /**
   * Loads a structure shaped like getWeights() output. Arrays are copied, so
   * later changes to `data` do not leak into the layer.
   * @param {object} data
   * @throws {Error} If any gate is missing or has the wrong shape. All gates
   *   are validated before any is assigned, so the layer stays untouched on failure.
   */
  setWeights(data) {
    const gateNames = ["resetGate", "updateGate", "newGate"];
    const combSize = this.inputSize + this.hSize;
    for (const name of gateNames) {
      const gate = data?.[name];
      const validShape = !!gate &&
        Array.isArray(gate.W) &&
        Array.isArray(gate.b) &&
        gate.W.length === this.hSize &&
        gate.b.length === this.hSize &&
        gate.W.every((row) => Array.isArray(row) && row.length === combSize);
      if (!validShape) {
        throw new Error(`GRULayer.setWeights: ${name} must have W of shape [${this.hSize}][${combSize}] and b of length ${this.hSize}`);
      }
    }
    for (const name of gateNames) {
      this[name].W = data[name].W.map((row) => [...row]);
      this[name].b = [...data[name].b];
    }
  }
};
2093
+
2094
+ // src/BatchNorm.ts
2095
var BatchNorm = class {
  /**
   * Per-feature normalization of a single sample vector, using exponential
   * running estimates of mean and variance that are updated on every forward().
   *
   * @param {number} dim - Number of features in each input vector.
   * @param {number} [momentum=0.1] - Weight kept on the PREVIOUS running
   *   statistic; the new sample contributes (1 - momentum).
   *   NOTE(review): this is the opposite of the PyTorch convention, where
   *   momentum weights the new observation. Confirm this is intended.
   */
  constructor(dim, momentum = 0.1) {
    this._xNorm = null;
    this._std = null;
    this.dim = dim;
    this.momentum = momentum;
    this.gamma = new Array(dim).fill(1);
    this.beta = new Array(dim).fill(0);
    this.runningMean = new Array(dim).fill(0);
    this.runningVar = new Array(dim).fill(1);
  }

  // ── Forward ───────────────────────────────────────────────────────────────
  /**
   * Updates the running statistics with x, then returns
   * gamma * (x - runningMean) / sqrt(runningVar + eps) + beta.
   * @param {number[]} x - Vector of length dim.
   * @returns {number[]}
   * @throws {Error} If x has the wrong length.
   */
  forward(x) {
    if (x.length !== this.dim) {
      throw new Error(`BatchNorm.forward: expected array of length ${this.dim}, got ${x.length}`);
    }
    const eps = 1e-5;
    for (let i = 0; i < this.dim; i++) {
      this.runningMean[i] = this.momentum * this.runningMean[i] + (1 - this.momentum) * x[i];
      const diff = x[i] - this.runningMean[i];
      this.runningVar[i] = this.momentum * this.runningVar[i] + (1 - this.momentum) * diff * diff;
    }
    this._std = this.runningVar.map((v) => Math.sqrt(v + eps));
    this._xNorm = x.map((v, i) => (v - this.runningMean[i]) / this._std[i]);
    return this._xNorm.map((xn, i) => this.gamma[i] * xn + this.beta[i]);
  }

  // ── Backward ──────────────────────────────────────────────────────────────
  /**
   * Input gradient, treating the running statistics as constants.
   * @param {number[]} dOut - Gradient w.r.t. the output, length dim.
   * @returns {number[]} dOut * gamma / std, element-wise.
   * @throws {Error} If forward() has not been called, or dOut has the wrong length.
   */
  backward(dOut) {
    if (!this._xNorm || !this._std) {
      throw new Error("BatchNorm.backward: call forward() first");
    }
    this._assertGradLength(dOut, "backward");
    return dOut.map((d, i) => d * this.gamma[i] / this._std[i]);
  }

  // ── Train gamma and beta (call after backward) ────────────────────────────
  /**
   * Moves gamma by lr * dOut * xNorm and beta by lr * dOut. No-op before the
   * first forward().
   * @param {number[]} dOut - Output gradient/delta, length dim.
   * @param {number} lr - Learning rate.
   * @throws {Error} If dOut has the wrong length.
   */
  trainParams(dOut, lr) {
    if (!this._xNorm) return;
    this._assertGradLength(dOut, "trainParams");
    for (let i = 0; i < this.dim; i++) {
      this.gamma[i] += lr * dOut[i] * this._xNorm[i];
      this.beta[i] += lr * dOut[i];
    }
  }

  // ── Flat weight serialization ─────────────────────────────────────────────
  // Order: gamma, beta.
  getWeights() {
    return [...this.gamma, ...this.beta];
  }

  setWeights(weights) {
    for (let i = 0; i < this.dim; i++) this.gamma[i] = weights[i];
    for (let i = 0; i < this.dim; i++) this.beta[i] = weights[this.dim + i];
  }

  /**
   * Mirrors forward()'s length check for gradients. A mismatched gradient
   * would otherwise yield silent NaNs or a short result.
   */
  _assertGradLength(dOut, method) {
    if (!Array.isArray(dOut) || dOut.length !== this.dim) {
      throw new Error(`BatchNorm.${method}: expected gradient of length ${this.dim}, got ${dOut?.length}`);
    }
  }
};
2148
+
2149
+ // src/Conv1D.ts
2150
var Conv1D = class {
  /**
   * 1-D convolution over a sequence of positions, each `inputChannels` wide.
   *
   * @param {number} inputLength - Number of positions in each input.
   * @param {number} kernelSize - Number of positions each filter spans.
   * @param {number} filters - Number of output channels.
   * @param {number} [stride=1] - Step between filter applications; must be a positive integer.
   * @param {string} [padding="valid"] - "same" zero-pads (kernelSize - 1 in
   *   total, extra pad on the right for even kernels) so the output has
   *   ceil(inputLength / stride) positions. Any other value applies no padding.
   * @param {Function} [optimizerFactory] - Called once per parameter to create
   *   its optimizer (an object exposing step(value, grad, lr)). Defaults to SGD.
   * @param {number} [inputChannels=1] - Width of each input position.
   * @throws {Error} On non-positive sizes, a kernel wider than an unpadded input,
   *   fewer than one input channel, or a stride that is not a positive integer.
   */
  constructor(inputLength, kernelSize, filters, stride = 1, padding = "valid", optimizerFactory = () => new SGD(), inputChannels = 1) {
    this._input = null;
    this._paddedInput = null;
    if (inputLength <= 0 || kernelSize <= 0 || filters <= 0) {
      throw new Error("Conv1D: inputLength, kernelSize, and filters must be positive");
    }
    if (kernelSize > inputLength && padding === "valid") {
      throw new Error("Conv1D: kernelSize cannot exceed inputLength with valid padding");
    }
    if (inputChannels < 1) {
      throw new Error("Conv1D: inputChannels must be >= 1");
    }
    // A zero stride divides by zero and a fractional one indexes between
    // positions; both would only fail later, obscurely, inside forward().
    if (!Number.isInteger(stride) || stride < 1) {
      throw new Error(`Conv1D: stride must be a positive integer, got ${stride}`);
    }
    this.inputLength = inputLength;
    this.kernelSize = kernelSize;
    this.filters = filters;
    this.stride = stride;
    this.padding = padding;
    this.inputChannels = inputChannels;
    const limit = Math.sqrt(2 / (kernelSize * inputChannels));
    // kernels[filter][offset][channel]
    this.kernels = Array.from(
      { length: filters },
      () => Array.from(
        { length: kernelSize },
        () => Array.from({ length: inputChannels }, () => (Math.random() * 2 - 1) * limit)
      )
    );
    this.biases = new Array(filters).fill(0);
    this._kOpts = Array.from(
      { length: filters },
      () => Array.from(
        { length: kernelSize },
        () => Array.from({ length: inputChannels }, () => optimizerFactory())
      )
    );
    this._bOpts = Array.from({ length: filters }, () => optimizerFactory());
  }

  // ── Forward ───────────────────────────────────────────────────────────────
  /**
   * @param {number[] | number[][]} input - number[] when inputChannels is 1,
   *   otherwise number[inputLength][inputChannels].
   * @returns {number[][]} output[filter][position].
   */
  forward(input) {
    const input2D = this._normalizeInput(input);
    // Private copy: backward() must see exactly what forward() used, even if
    // the caller mutates its rows in between.
    this._input = input2D.map((row) => [...row]);
    const { left, right } = this._padAmounts();
    const zeroRow = () => new Array(this.inputChannels).fill(0);
    const padded = [
      ...Array.from({ length: left }, zeroRow),
      ...this._input,
      ...Array.from({ length: right }, zeroRow)
    ];
    this._paddedInput = padded;
    const outputLength = this._outputLengthFor(padded.length);
    const output = Array.from(
      { length: this.filters },
      () => new Array(outputLength).fill(0)
    );
    for (let f = 0; f < this.filters; f++) {
      for (let pos = 0; pos < outputLength; pos++) {
        const start = pos * this.stride;
        let sum = this.biases[f];
        for (let k = 0; k < this.kernelSize; k++) {
          for (let c = 0; c < this.inputChannels; c++) {
            sum += this.kernels[f][k][c] * padded[start + k][c];
          }
        }
        output[f][pos] = sum;
      }
    }
    return output;
  }

  // ── Backward ──────────────────────────────────────────────────────────────
  /**
   * Accumulates kernel and bias gradients, applies them through the
   * optimizers, and returns the input gradient (padding stripped).
   * @param {number[][]} dOut - Gradient shaped like forward()'s output.
   * @param {number} [lr=1e-3] - Learning rate.
   * @returns {number[][]} Gradient of shape [inputLength][inputChannels].
   * @throws {Error} If forward() was not called or dOut has the wrong shape.
   */
  backward(dOut, lr = 1e-3) {
    if (!this._paddedInput || !this._input) {
      throw new Error("Conv1D.backward: call forward() first");
    }
    const padded = this._paddedInput;
    const outputLength = this._outputLengthFor(padded.length);
    if (!Array.isArray(dOut) || dOut.length !== this.filters ||
        dOut.some((row) => !Array.isArray(row) || row.length !== outputLength)) {
      throw new Error(`Conv1D.backward: expected dOut of shape [${this.filters}][${outputLength}]`);
    }
    const dKernels = Array.from(
      { length: this.filters },
      () => Array.from(
        { length: this.kernelSize },
        () => new Array(this.inputChannels).fill(0)
      )
    );
    const dBiases = new Array(this.filters).fill(0);
    const dPadded = padded.map(() => new Array(this.inputChannels).fill(0));
    for (let f = 0; f < this.filters; f++) {
      for (let pos = 0; pos < outputLength; pos++) {
        const start = pos * this.stride;
        dBiases[f] += dOut[f][pos];
        for (let k = 0; k < this.kernelSize; k++) {
          for (let c = 0; c < this.inputChannels; c++) {
            dKernels[f][k][c] += dOut[f][pos] * padded[start + k][c];
            dPadded[start + k][c] += dOut[f][pos] * this.kernels[f][k][c];
          }
        }
      }
    }
    for (let f = 0; f < this.filters; f++) {
      for (let k = 0; k < this.kernelSize; k++) {
        for (let c = 0; c < this.inputChannels; c++) {
          this.kernels[f][k][c] = this._kOpts[f][k][c].step(this.kernels[f][k][c], dKernels[f][k][c], lr);
        }
      }
      this.biases[f] = this._bOpts[f].step(this.biases[f], dBiases[f], lr);
    }
    const { left } = this._padAmounts();
    return dPadded.slice(left, left + this.inputLength);
  }

  // ── Output length ─────────────────────────────────────────────────────────
  /** Number of positions forward() produces per filter. */
  getOutputLength() {
    if (this.padding === "same") {
      return Math.ceil(this.inputLength / this.stride);
    }
    return Math.floor((this.inputLength - this.kernelSize) / this.stride) + 1;
  }

  // ── Flat weight serialization ─────────────────────────────────────────────
  // Order: kernels (flattened filter → offset → channel), biases.
  getWeights() {
    const w = [];
    for (const kernel of this.kernels)
      for (const k of kernel)
        for (const c of k)
          w.push(c);
    w.push(...this.biases);
    return w;
  }

  setWeights(weights) {
    let idx = 0;
    for (let f = 0; f < this.filters; f++)
      for (let k = 0; k < this.kernelSize; k++)
        for (let c = 0; c < this.inputChannels; c++)
          this.kernels[f][k][c] = weights[idx++];
    for (let f = 0; f < this.filters; f++)
      this.biases[f] = weights[idx++];
  }

  /**
   * Zero-padding on each side. "same" pads kernelSize - 1 positions in total,
   * with the extra one on the right for even kernels. This makes forward()
   * agree with getOutputLength(); symmetric floor((k-1)/2) padding lost one
   * output position for even kernel sizes.
   */
  _padAmounts() {
    if (this.padding !== "same") return { left: 0, right: 0 };
    const total = this.kernelSize - 1;
    const left = Math.floor(total / 2);
    return { left, right: total - left };
  }

  /** Number of filter applications that fit in a (padded) sequence. */
  _outputLengthFor(paddedLength) {
    return Math.floor((paddedLength - this.kernelSize) / this.stride) + 1;
  }

  // ── Normalize input to 2D format ─────────────────────────────────────────
  _normalizeInput(input) {
    if (input.length === 0) {
      throw new Error("Conv1D.forward: input cannot be empty");
    }
    if (typeof input[0] === "number") {
      if (this.inputChannels !== 1) {
        throw new Error(`Conv1D.forward: expected 2D input with ${this.inputChannels} channels, got 1D`);
      }
      const input1D = input;
      if (input1D.length !== this.inputLength) {
        throw new Error(`Conv1D.forward: expected input of length ${this.inputLength}, got ${input1D.length}`);
      }
      return input1D.map((v) => [v]);
    }
    const input2D = input;
    if (input2D.length !== this.inputLength) {
      throw new Error(`Conv1D.forward: expected input of length ${this.inputLength}, got ${input2D.length}`);
    }
    for (let i = 0; i < input2D.length; i++) {
      if (input2D[i].length !== this.inputChannels) {
        throw new Error(`Conv1D.forward: expected ${this.inputChannels} channels at position ${i}, got ${input2D[i].length}`);
      }
    }
    return input2D;
  }
};
2317
+
2318
+ // src/Trainer.ts
2319
+ var Trainer = class {
2320
+ constructor(network, options = {}) {
2321
+ this._history = [];
2322
+ this._bestLoss = Infinity;
2323
+ this._patienceCounter = 0;
2324
+ this._stopReason = "maxEpochs";
2325
+ this._metrics = [];
2326
+ this.network = network;
2327
+ this.epochs = options.epochs ?? 1e3;
2328
+ this.lrInitial = options.lr ?? 0.1;
2329
+ this.lrDecay = options.lrDecay ?? 1;
2330
+ this.verbose = options.verbose ?? false;
2331
+ this.weightDecay = options.weightDecay ?? 0;
2332
+ this._earlyStopping = options.earlyStopping;
2333
+ this._computeMetrics = options.computeMetrics ?? false;
2334
+ this.clipValue = options.clipValue ?? 0;
2335
+ }
2336
+ // ── Set external validation data (for early stopping) ────────────────────
2337
+ setValidationData(dataset) {
2338
+ if (dataset.inputs.length !== dataset.targets.length) {
2339
+ throw new Error(
2340
+ "Trainer.setValidationData: inputs and targets must have the same length"
2341
+ );
2342
+ }
2343
+ this._validationData = dataset;
2344
+ }
2345
+ // ── Get best validation loss during training ─────────────────────────────
2346
+ getBestLoss() {
2347
+ return this._bestLoss === Infinity ? -1 : this._bestLoss;
2348
+ }
2349
+ // ── Why did training stop? ───────────────────────────────────────────────
2350
+ getStopReason() {
2351
+ return this._stopReason;
2352
+ }
2353
+ // ── Get per-epoch classification metrics ─────────────────────────────────
2354
+ getMetrics() {
2355
+ return [...this._metrics];
2356
+ }
2357
+ // ── Train on dataset ──────────────────────────────────────────────────────
2358
+ train(dataset) {
2359
+ const { inputs, targets } = dataset;
2360
+ if (inputs.length !== targets.length) {
2361
+ throw new Error(
2362
+ "Trainer.train: inputs and targets must have the same length"
2363
+ );
2364
+ }
2365
+ const n = inputs.length;
2366
+ let lr = this.lrInitial;
2367
+ this._history = [];
2368
+ this._bestLoss = Infinity;
2369
+ this._patienceCounter = 0;
2370
+ this._stopReason = "maxEpochs";
2371
+ this._metrics = [];
2372
+ const netExt = this._hasWeights(this.network);
2373
+ if (this.weightDecay > 0 && !netExt) {
2374
+ console.warn(
2375
+ "Trainer: weightDecay requires a network with getWeights/setWeights/predict. Skipping weight decay."
2376
+ );
2377
+ }
2378
+ if (this._earlyStopping && !netExt) {
2379
+ console.warn(
2380
+ "Trainer: earlyStopping requires a network with predict(). Skipping early stopping."
2381
+ );
2382
+ }
2383
+ if (this._computeMetrics && !netExt) {
2384
+ console.warn(
2385
+ "Trainer: computeMetrics requires a network with predict(). Skipping metrics."
2386
+ );
2387
+ }
2388
+ const canDecay = this.weightDecay > 0 && netExt;
2389
+ const canValidate = !!this._earlyStopping && netExt && !!this._validationData;
2390
+ const canMetric = this._computeMetrics && netExt;
2391
+ const isClass = canMetric && this._isClassification(targets);
2392
+ if (canMetric && !isClass) {
2393
+ console.warn(
2394
+ "Trainer: computeMetrics is set but targets do not appear to be one-hot or single-class. Metrics will be skipped."
2395
+ );
2396
+ }
2397
+ for (let epoch = 0; epoch < this.epochs; epoch++) {
2398
+ const indices = Array.from({ length: n }, (_, i) => i);
2399
+ for (let i = n - 1; i > 0; i--) {
2400
+ const j = Math.floor(Math.random() * (i + 1));
2401
+ [indices[i], indices[j]] = [indices[j], indices[i]];
2402
+ }
2403
+ let epochLoss = 0;
2404
+ for (const i of indices) {
2405
+ if (canDecay) {
2406
+ const w = netExt.getWeights();
2407
+ for (let j = 0; j < w.length; j++) {
2408
+ w[j] *= 1 - lr * this.weightDecay;
2409
+ }
2410
+ netExt.setWeights(w);
2411
+ }
2412
+ epochLoss += this.network.train(inputs[i], targets[i], lr);
2413
+ }
2414
+ epochLoss /= n;
2415
+ this._history.push(epochLoss);
2416
+ if (canMetric && isClass) {
2417
+ this._metrics.push(this._computeMetricsArray(netExt, inputs, targets));
2418
+ }
2419
+ if (canValidate && this._validationData) {
2420
+ const valLoss = this._computeLoss(netExt, this._validationData);
2421
+ const minDelta = this._earlyStopping.minDelta;
2422
+ if (valLoss < this._bestLoss - minDelta) {
2423
+ this._bestLoss = valLoss;
2424
+ this._patienceCounter = 0;
2425
+ } else {
2426
+ this._patienceCounter++;
2427
+ }
2428
+ if (this._patienceCounter >= this._earlyStopping.patience) {
2429
+ this._stopReason = "earlyStopping";
2430
+ break;
2431
+ }
2432
+ }
2433
+ lr *= this.lrDecay;
2434
+ if (this.verbose && (epoch + 1) % 100 === 0) {
2435
+ console.log(
2436
+ `Epoch ${epoch + 1}/${this.epochs}, loss: ${epochLoss.toFixed(6)}, lr: ${lr.toFixed(6)}`
2437
+ );
2438
+ }
2439
+ }
2440
+ return this._history;
2441
+ }
2442
+ // ── Get loss history ──────────────────────────────────────────────────────
2443
+ getHistory() {
2444
+ return [...this._history];
2445
+ }
2446
+ // ── Private helpers ───────────────────────────────────────────────────────
2447
+ /** Type guard: does this network support getWeights/setWeights/predict? */
2448
+ _hasWeights(network) {
2449
+ if ("getWeights" in network && "setWeights" in network && "predict" in network && typeof network.getWeights === "function" && typeof network.setWeights === "function" && typeof network.predict === "function") {
2450
+ return network;
2451
+ }
2452
+ return null;
2453
+ }
2454
+ /** Mean squared error on a dataset (used for validation loss). */
2455
+ _computeLoss(net, data) {
2456
+ let totalLoss = 0;
2457
+ for (let i = 0; i < data.inputs.length; i++) {
2458
+ const pred = net.predict(data.inputs[i]);
2459
+ const target = data.targets[i];
2460
+ let sampleLoss = 0;
2461
+ for (let j = 0; j < pred.length; j++) {
2462
+ sampleLoss += (target[j] - pred[j]) ** 2;
2463
+ }
2464
+ totalLoss += sampleLoss / pred.length;
2465
+ }
2466
+ return totalLoss / data.inputs.length;
2467
+ }
2468
+ /** Heuristic: are targets classification-style (one-hot or single-class)? */
2469
+ _isClassification(targets) {
2470
+ if (targets.length === 0) return false;
2471
+ const first = targets[0];
2472
+ if (first.length === 1) return true;
2473
+ for (const t of targets) {
2474
+ let sum = 0;
2475
+ for (const v of t) {
2476
+ sum += v;
2477
+ if (v < -0.01 || v > 0.01 && v < 0.99 && Math.abs(v - 1) > 0.01)
2478
+ return false;
2479
+ }
2480
+ if (Math.abs(sum - 1) > 0.01) return false;
2481
+ }
2482
+ return true;
2483
+ }
2484
+ /** Compute classification metrics from predictions vs targets. */
2485
+ _computeMetricsArray(net, inputs, targets) {
2486
+ const targetLen = targets[0].length;
2487
+ const nClasses = targetLen === 1 ? 2 : targetLen;
2488
+ const confusion = Array.from(
2489
+ { length: nClasses },
2490
+ () => Array(nClasses).fill(0)
2491
+ );
2492
+ for (let i = 0; i < inputs.length; i++) {
2493
+ const pred = net.predict(inputs[i]);
2494
+ const target = targets[i];
2495
+ let predClass;
2496
+ let trueClass;
2497
+ if (targetLen === 1) {
2498
+ trueClass = target[0] >= 0.5 ? 1 : 0;
2499
+ if (pred.length === 1) {
2500
+ predClass = pred[0] >= 0.5 ? 1 : 0;
2501
+ } else {
2502
+ predClass = pred.indexOf(Math.max(...pred));
2503
+ }
2504
+ } else {
2505
+ predClass = pred.indexOf(Math.max(...pred));
2506
+ trueClass = target.indexOf(Math.max(...target));
2507
+ }
2508
+ predClass = Math.max(0, Math.min(nClasses - 1, predClass));
2509
+ trueClass = Math.max(0, Math.min(nClasses - 1, trueClass));
2510
+ confusion[trueClass][predClass]++;
2511
+ }
2512
+ let totalCorrect = 0;
2513
+ let totalSamples = 0;
2514
+ const precisions = [];
2515
+ const recalls = [];
2516
+ for (let c = 0; c < nClasses; c++) {
2517
+ const tp = confusion[c][c];
2518
+ totalCorrect += tp;
2519
+ let colSum = 0;
2520
+ let rowSum = 0;
2521
+ for (let r = 0; r < nClasses; r++) {
2522
+ colSum += confusion[r][c];
2523
+ rowSum += confusion[c][r];
2524
+ }
2525
+ totalSamples += rowSum;
2526
+ precisions.push(colSum > 0 ? tp / colSum : 0);
2527
+ recalls.push(rowSum > 0 ? tp / rowSum : 0);
2528
+ }
2529
+ const accuracy = totalSamples > 0 ? totalCorrect / totalSamples : 0;
2530
+ const macroPrecision = precisions.reduce((a, b) => a + b, 0) / nClasses;
2531
+ const macroRecall = recalls.reduce((a, b) => a + b, 0) / nClasses;
2532
+ const f1 = macroPrecision + macroRecall > 0 ? 2 * macroPrecision * macroRecall / (macroPrecision + macroRecall) : 0;
2533
+ return {
2534
+ accuracy,
2535
+ precision: macroPrecision,
2536
+ recall: macroRecall,
2537
+ f1
2538
+ };
2539
+ }
2540
+ };
2541
+
2542
// src/DataLoader.ts
/**
 * Mini-batch iterator over a dataset of parallel `inputs` / `targets` arrays,
 * with an optional random hold-out split for validation.
 *
 * The constructor shuffles all sample indices once. The first part becomes the
 * training set and the remainder the validation set. `shuffle()` later
 * reshuffles only the training indices.
 */
var DataLoader = class _DataLoader {
  /**
   * @param {{inputs: Array, targets: Array}} data - parallel sample arrays
   * @param {number} [batchSize=1] - samples per batch: a positive integer, or
   *   Infinity for one full batch
   * @param {number} [validationSplit=0] - fraction of samples in [0, 1) to hold
   *   out; at least one sample is always kept for training
   * @throws {Error} if the arrays differ in length, batchSize is invalid, or
   *   validationSplit is outside [0, 1) (NaN included)
   */
  constructor(data, batchSize = 1, validationSplit = 0) {
    if (data.inputs.length !== data.targets.length) {
      throw new Error("DataLoader: inputs and targets must have the same length");
    }
    // batchSize <= 0 keeps hasNext() true forever (next() never advances past
    // the end); a fractional size yields irregular batches.
    const batchOk =
      batchSize === Infinity || (Number.isInteger(batchSize) && batchSize >= 1);
    if (!batchOk) {
      throw new Error(`DataLoader: batchSize must be a positive integer, got ${batchSize}`);
    }
    // Negated range test so NaN is rejected as well.
    if (!(validationSplit >= 0 && validationSplit < 1)) {
      throw new Error(`DataLoader: validationSplit must be in [0, 1), got ${validationSplit}`);
    }
    this.data = data;
    this.batchSize = batchSize;
    this._validationSplit = validationSplit;

    const total = data.inputs.length;
    const order = _DataLoader._shuffleInPlace(
      Array.from({ length: total }, (_, i) => i)
    );
    // Round the hold-out size, but never let it consume the whole training set
    // (e.g. 1 sample with split 0.5 would otherwise round to 1 validation sample).
    const valSize = Math.min(
      Math.round(total * validationSplit),
      Math.max(total - 1, 0)
    );
    const trainSize = total - valSize;
    this._trainIndices = order.slice(0, trainSize);
    this._valIndices = order.slice(trainSize);
    this._indices = [...this._trainIndices];
    this._pos = 0;
  }

  // ── Shuffle the training data ──────────────────────────────────────────────
  /** Reshuffle the training indices and restart iteration from the first batch. */
  shuffle() {
    _DataLoader._shuffleInPlace(this._trainIndices);
    this._indices = [...this._trainIndices];
    this._pos = 0;
  }

  // ── Check if more batches are available ───────────────────────────────────
  hasNext() {
    return this._pos < this._indices.length;
  }

  // ── Get next batch ────────────────────────────────────────────────────────
  /**
   * Returns the next batch of up to `batchSize` training samples. The last
   * batch may be smaller; once exhausted, empty arrays are returned.
   */
  next() {
    const end = Math.min(this._pos + this.batchSize, this._indices.length);
    const batchIndices = this._indices.slice(this._pos, end);
    this._pos = end;
    return {
      inputs: batchIndices.map((i) => this.data.inputs[i]),
      targets: batchIndices.map((i) => this.data.targets[i])
    };
  }

  // ── Reset iteration ───────────────────────────────────────────────────────
  /** Restart iteration without reshuffling. */
  reset() {
    this._pos = 0;
  }

  // ── Get total number of training samples ───────────────────────────────────
  get length() {
    return this._trainIndices.length;
  }

  // ── Get validation data as a DataPair ──────────────────────────────────────
  // Returns the validation samples (inputs + targets) in their shuffled order.
  // Returns empty arrays if no validation split was configured.
  getValidationData() {
    return {
      inputs: this._valIndices.map((i) => this.data.inputs[i]),
      targets: this._valIndices.map((i) => this.data.targets[i])
    };
  }

  // ── Get number of validation samples ───────────────────────────────────────
  get validationLength() {
    return this._valIndices.length;
  }

  // ── Create sequence windows from a time series ────────────────────────────
  /**
   * Builds sliding windows over `data`. Each input is `data[i .. i+seqLen-1]`
   * flattened one level, and its target is `data[i + seqLen]`. Elements are
   * presumably per-step feature vectors (arrays); confirm with callers.
   *
   * @param {Array} data - time series, at least seqLen + 1 steps long
   * @param {number} seqLen - window length, a positive integer
   * @param {number} [validationSplit=0] - forwarded to the constructor
   * @param {number} [batchSize=1] - forwarded to the constructor
   * @throws {Error} if seqLen is not a positive integer or data is too short
   */
  static sequences(data, seqLen, validationSplit = 0, batchSize = 1) {
    // A fractional or non-positive seqLen would produce misaligned windows and
    // undefined targets (data[i + 2.5] is undefined).
    if (!Number.isInteger(seqLen) || seqLen < 1) {
      throw new Error(`DataLoader.sequences: seqLen must be a positive integer, got ${seqLen}`);
    }
    if (data.length < seqLen + 1) {
      throw new Error("DataLoader.sequences: data length must be >= seqLen + 1");
    }
    const inputs = [];
    const targets = [];
    for (let i = 0; i + seqLen < data.length; i++) {
      inputs.push(data.slice(i, i + seqLen).flat());
      targets.push(data[i + seqLen]);
    }
    return new _DataLoader({ inputs, targets }, batchSize, validationSplit);
  }

  /** Fisher–Yates shuffle of `arr` in place; returns `arr` for chaining. */
  static _shuffleInPlace(arr) {
    for (let i = arr.length - 1; i > 0; i--) {
      const j = Math.floor(Math.random() * (i + 1));
      [arr[i], arr[j]] = [arr[j], arr[i]];
    }
    return arr;
  }
};
2629
+
2630
// src/LRScheduler.ts
/**
 * Stateless learning-rate schedules. Every method takes a base learning rate
 * and returns the rate to use; nothing is stored on the instance.
 */
var LRScheduler = class {
  // ── Step Decay ────────────────────────────────────────────────────────────
  // lr = initialLr * dropRate^floor(epoch / epochsDrop)
  /**
   * @throws {Error} if epochsDrop is not > 0. A zero divisor would yield NaN
   *   at epoch 0 and a learning rate of 0 afterwards.
   */
  stepDecay(lr, epoch, dropRate, epochsDrop) {
    if (!(epochsDrop > 0)) {
      throw new Error(`LRScheduler.stepDecay: epochsDrop must be > 0, got ${epochsDrop}`);
    }
    return lr * Math.pow(dropRate, Math.floor(epoch / epochsDrop));
  }

  // ── Exponential Decay ─────────────────────────────────────────────────────
  // lr = initialLr * decayRate^epoch
  exponentialDecay(lr, epoch, decayRate) {
    return lr * Math.pow(decayRate, epoch);
  }

  // ── Plateau Decay ─────────────────────────────────────────────────────────
  // Returns lr * factor when `currentLoss` is not strictly lower than the
  // minimum of the last `patience` entries of `history`; otherwise returns lr
  // unchanged. Until `history` holds `patience` entries, lr is returned as is.
  //
  // Being stateless, it keeps multiplying by `factor` on every call while the
  // condition holds (there is no cooldown).
  //
  // `history` must NOT already contain `currentLoss`. If it does, the window
  // minimum is <= currentLoss and lr decays on every call.
  // Note: patience = 0 makes `history.slice(-0)` cover the entire history.
  //
  // Usage:
  //   const history = []
  //   for (let epoch = 0; epoch < 1000; epoch++) {
  //     const loss = train(...)
  //     lr = scheduler.plateauDecay(lr, loss, history, 10, 0.5)
  //     history.push(loss)
  //   }
  plateauDecay(lr, currentLoss, history, patience, factor) {
    if (history.length < patience) return lr;
    const recentLosses = history.slice(-patience);
    const minRecentLoss = Math.min(...recentLosses);
    return currentLoss >= minRecentLoss ? lr * factor : lr;
  }

  // ── Cosine Annealing ──────────────────────────────────────────────────────
  // lr = minLr + 0.5 * (maxLr - minLr) * (1 + cos(π * epoch / maxEpochs))
  // Past `maxEpochs` the cosine rises again, so the schedule is periodic.
  /**
   * @throws {Error} if maxEpochs is not > 0 (division by zero would yield NaN).
   */
  cosineAnnealing(lr, epoch, maxEpochs, minLr = 0) {
    if (!(maxEpochs > 0)) {
      throw new Error(`LRScheduler.cosineAnnealing: maxEpochs must be > 0, got ${maxEpochs}`);
    }
    return minLr + 0.5 * (lr - minLr) * (1 + Math.cos(Math.PI * epoch / maxEpochs));
  }
};
2668
+
2669
// src/ModelSaver.ts
/**
 * JSON (de)serialisation of a model's weights. Works with any object exposing
 * getWeights() and setWeights(weights). Only the weights and a save timestamp
 * are stored; no architecture information is saved.
 */
var ModelSaver = class _ModelSaver {
  // ── Serialize to JSON string ──────────────────────────────────────────────
  /**
   * @returns {string} JSON of the form {"weights": [...], "timestamp": <Date.now()>}
   * Note: JSON.stringify writes NaN / Infinity weights as null.
   */
  static toJSON(model) {
    return JSON.stringify({
      weights: model.getWeights(),
      timestamp: Date.now()
    });
  }

  // ── Deserialize from JSON string ──────────────────────────────────────────
  /**
   * Parses `json` and passes its `weights` array to `model.setWeights`.
   * @throws {SyntaxError} if `json` is not valid JSON
   * @throws {Error} if the parsed value has no `weights` array
   */
  static fromJSON(model, json) {
    const data = JSON.parse(json);
    // JSON.parse may return null or a primitive (e.g. "null"). Reject those with
    // the intended error instead of crashing on `null.weights`.
    if (data === null || typeof data !== "object" || !Array.isArray(data.weights)) {
      throw new Error("ModelSaver.fromJSON: invalid model data");
    }
    model.setWeights(data.weights);
  }

  // ── Save to file (requires write function) ────────────────────────────────
  /** Serialises `model` and hands the string to the synchronous `writeFn(path, json)`. */
  static saveToFile(model, path, writeFn) {
    const json = _ModelSaver.toJSON(model);
    writeFn(path, json);
  }

  // ── Load from file (requires read function) ───────────────────────────────
  /** Reads a JSON string via the synchronous `readFn(path)` and loads it into `model`. */
  static loadFromFile(model, path, readFn) {
    const json = readFn(path);
    _ModelSaver.fromJSON(model, json);
  }
};
1242
2697
  // Annotate the CommonJS export names for ESM import in node:
1243
2698
  0 && (module.exports = {
1244
2699
  Adam,
1245
2700
  AttentionHead,
2701
+ BatchNorm,
2702
+ ClipOptimizer,
2703
+ ClippedOptimizerFactory,
2704
+ Conv1D,
2705
+ DataLoader,
2706
+ Dropout,
1246
2707
  EmbeddingMatrix,
2708
+ GRULayer,
2709
+ LRScheduler,
1247
2710
  LSTMLayer,
1248
2711
  Layer,
1249
2712
  LayerNorm,
2713
+ ModelSaver,
1250
2714
  Momentum,
1251
2715
  MultiHeadAttention,
1252
2716
  Network,
@@ -1257,6 +2721,7 @@ function crossEntropyDeltaRaw(predicted, actual) {
1257
2721
  Neuron,
1258
2722
  NeuronN,
1259
2723
  SGD,
2724
+ Trainer,
1260
2725
  TransformerBlock,
1261
2726
  WeightMatrix,
1262
2727
  crossEntropy,
@@ -1275,5 +2740,9 @@ function crossEntropyDeltaRaw(predicted, actual) {
1275
2740
  softmax,
1276
2741
  softmaxBackward,
1277
2742
  tanh,
1278
- transpose
2743
+ transpose,
2744
+ validate2DArray,
2745
+ validateArray,
2746
+ validateArrayMinLength,
2747
+ validateNumber
1279
2748
  });