@dniskav/neuron 0.2.3 → 0.2.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -22,10 +22,20 @@ var index_exports = {};
22
22
  __export(index_exports, {
23
23
  Adam: () => Adam,
24
24
  AttentionHead: () => AttentionHead,
25
+ BatchNorm: () => BatchNorm,
26
+ BiasVector: () => BiasVector,
27
+ ClipOptimizer: () => ClipOptimizer,
28
+ ClippedOptimizerFactory: () => ClippedOptimizerFactory,
29
+ Conv1D: () => Conv1D,
30
+ DataLoader: () => DataLoader,
31
+ Dropout: () => Dropout,
25
32
  EmbeddingMatrix: () => EmbeddingMatrix,
33
+ GRULayer: () => GRULayer,
34
+ LRScheduler: () => LRScheduler,
26
35
  LSTMLayer: () => LSTMLayer,
27
36
  Layer: () => Layer,
28
37
  LayerNorm: () => LayerNorm,
38
+ ModelSaver: () => ModelSaver,
29
39
  Momentum: () => Momentum,
30
40
  MultiHeadAttention: () => MultiHeadAttention,
31
41
  Network: () => Network,
@@ -36,11 +46,13 @@ __export(index_exports, {
36
46
  Neuron: () => Neuron,
37
47
  NeuronN: () => NeuronN,
38
48
  SGD: () => SGD,
49
+ Trainer: () => Trainer,
39
50
  TransformerBlock: () => TransformerBlock,
40
51
  WeightMatrix: () => WeightMatrix,
41
52
  crossEntropy: () => crossEntropy,
42
53
  crossEntropyDelta: () => crossEntropyDelta,
43
54
  crossEntropyDeltaRaw: () => crossEntropyDeltaRaw,
55
+ defaultOptimizer: () => defaultOptimizer,
44
56
  elu: () => elu,
45
57
  leakyRelu: () => leakyRelu,
46
58
  linear: () => linear,
@@ -54,10 +66,82 @@ __export(index_exports, {
54
66
  softmax: () => softmax,
55
67
  softmaxBackward: () => softmaxBackward,
56
68
  tanh: () => tanh,
57
- transpose: () => transpose
69
+ transpose: () => transpose,
70
+ validate2DArray: () => validate2DArray,
71
+ validateArray: () => validateArray,
72
+ validateArrayMinLength: () => validateArrayMinLength,
73
+ validateNumber: () => validateNumber
58
74
  });
59
75
  module.exports = __toCommonJS(index_exports);
60
76
 
77
+ // src/Validation.ts
78
+ function validateArray(arr, expectedLength, methodName) {
79
+ if (!Array.isArray(arr)) {
80
+ throw new Error(`${methodName}: expected array, got ${typeof arr}`);
81
+ }
82
+ if (arr.length !== expectedLength) {
83
+ throw new Error(
84
+ `${methodName}: expected array of length ${expectedLength}, got ${arr.length}`
85
+ );
86
+ }
87
+ for (let i = 0; i < arr.length; i++) {
88
+ if (typeof arr[i] !== "number" || !isFinite(arr[i])) {
89
+ throw new Error(
90
+ `${methodName}: invalid value at index ${i}: ${arr[i]}`
91
+ );
92
+ }
93
+ }
94
+ }
95
+ function validateArrayMinLength(arr, minLength, methodName) {
96
+ if (!Array.isArray(arr)) {
97
+ throw new Error(`${methodName}: expected array, got ${typeof arr}`);
98
+ }
99
+ if (arr.length < minLength) {
100
+ throw new Error(
101
+ `${methodName}: expected array of at least length ${minLength}, got ${arr.length}`
102
+ );
103
+ }
104
+ for (let i = 0; i < arr.length; i++) {
105
+ if (typeof arr[i] !== "number" || !isFinite(arr[i])) {
106
+ throw new Error(
107
+ `${methodName}: invalid value at index ${i}: ${arr[i]}`
108
+ );
109
+ }
110
+ }
111
+ }
112
+ function validate2DArray(arr, expectedRows, expectedCols, methodName) {
113
+ if (!Array.isArray(arr)) {
114
+ throw new Error(`${methodName}: expected 2D array, got ${typeof arr}`);
115
+ }
116
+ if (arr.length !== expectedRows) {
117
+ throw new Error(
118
+ `${methodName}: expected ${expectedRows} rows, got ${arr.length}`
119
+ );
120
+ }
121
+ for (let i = 0; i < arr.length; i++) {
122
+ if (!Array.isArray(arr[i])) {
123
+ throw new Error(`${methodName}: row ${i} is not an array`);
124
+ }
125
+ if (arr[i].length !== expectedCols) {
126
+ throw new Error(
127
+ `${methodName}: row ${i} expected ${expectedCols} cols, got ${arr[i].length}`
128
+ );
129
+ }
130
+ for (let j = 0; j < arr[i].length; j++) {
131
+ if (typeof arr[i][j] !== "number" || !isFinite(arr[i][j])) {
132
+ throw new Error(
133
+ `${methodName}: invalid value at [${i}][${j}]: ${arr[i][j]}`
134
+ );
135
+ }
136
+ }
137
+ }
138
+ }
139
+ function validateNumber(value, methodName) {
140
+ if (typeof value !== "number" || !isFinite(value)) {
141
+ throw new Error(`${methodName}: expected finite number, got ${value}`);
142
+ }
143
+ }
144
+
61
145
  // src/Neuron.ts
62
146
  function sigmoid(x) {
63
147
  return 1 / (1 + Math.exp(-x));
@@ -68,13 +152,18 @@ var Neuron = class {
68
152
  this.bias = Math.random() * 0.1;
69
153
  }
70
154
  predict(input) {
155
+ validateNumber(input, "Neuron.predict");
71
156
  return sigmoid(input * this.weight + this.bias);
72
157
  }
73
158
  train(input, target, lr) {
159
+ validateNumber(input, "Neuron.train");
160
+ validateNumber(target, "Neuron.train");
161
+ validateNumber(lr, "Neuron.train");
74
162
  const prediction = this.predict(input);
75
163
  const error = target - prediction;
76
- this.weight += lr * error * input;
77
- this.bias += lr * error;
164
+ const grad = error * prediction * (1 - prediction);
165
+ this.weight += lr * grad * input;
166
+ this.bias += lr * grad;
78
167
  }
79
168
  };
80
169
 
@@ -114,6 +203,7 @@ function makeElu(alpha = 1) {
114
203
  var elu = makeElu(1);
115
204
 
116
205
  // src/optimizers.ts
206
+ var defaultOptimizer = () => new SGD();
117
207
  var SGD = class {
118
208
  step(weight, gradient, lr) {
119
209
  return weight + lr * gradient;
@@ -129,6 +219,19 @@ var Momentum = class {
129
219
  return weight + this.v;
130
220
  }
131
221
  };
222
+ var ClipOptimizer = class {
223
+ constructor(inner, clipValue) {
224
+ this.inner = inner;
225
+ this.clipValue = clipValue;
226
+ }
227
+ step(weight, gradient, lr) {
228
+ const clipped = Math.max(-this.clipValue, Math.min(this.clipValue, gradient));
229
+ return this.inner.step(weight, clipped, lr);
230
+ }
231
+ };
232
+ function ClippedOptimizerFactory(innerFactory, clipValue) {
233
+ return () => new ClipOptimizer(innerFactory(), clipValue);
234
+ }
132
235
  var Adam = class {
133
236
  constructor(beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8) {
134
237
  this.beta1 = beta1;
@@ -149,7 +252,6 @@ var Adam = class {
149
252
  };
150
253
 
151
254
  // src/NeuronN.ts
152
- var defaultOptimizer = () => new SGD();
153
255
  var NeuronN = class {
154
256
  constructor(nInputs, activation = sigmoid2, optimizerFactory = defaultOptimizer) {
155
257
  const limit = Math.sqrt(1 / nInputs);
@@ -159,6 +261,7 @@ var NeuronN = class {
159
261
  this._opts = Array.from({ length: nInputs + 1 }, optimizerFactory);
160
262
  }
161
263
  predict(inputs) {
264
+ validateArray(inputs, this.weights.length, "NeuronN.predict");
162
265
  const sum = inputs.reduce((acc, e, i) => acc + e * this.weights[i], this.bias);
163
266
  return this.activation.fn(sum);
164
267
  }
@@ -171,14 +274,14 @@ var NeuronN = class {
171
274
  train(inputs, target, lr) {
172
275
  const prediction = this.predict(inputs);
173
276
  const error = target - prediction;
174
- this._update(inputs.map((inp) => error * inp), error, lr);
277
+ const grad = error * this.activation.dfn(prediction);
278
+ this._update(inputs.map((inp) => grad * inp), grad, lr);
175
279
  }
176
280
  };
177
281
 
178
282
  // src/Layer.ts
179
- var defaultOptimizer2 = () => new SGD();
180
283
  var Layer = class {
181
- constructor(nNeurons, nInputs, activation = sigmoid2, optimizerFactory = defaultOptimizer2) {
284
+ constructor(nNeurons, nInputs, activation = sigmoid2, optimizerFactory = defaultOptimizer) {
182
285
  this.neurons = Array.from(
183
286
  { length: nNeurons },
184
287
  () => new NeuronN(nInputs, activation, optimizerFactory)
@@ -196,84 +299,233 @@ var Network = class {
196
299
  this.outputLayer = new Layer(nOutputs, nHidden);
197
300
  }
198
301
  predict(inputs) {
302
+ validateArray(inputs, this.hiddenLayer.neurons[0].weights.length, "Network.predict");
199
303
  const hiddenOut = this.hiddenLayer.predict(inputs);
200
- return this.outputLayer.predict(hiddenOut)[0];
304
+ return this.outputLayer.predict(hiddenOut);
201
305
  }
202
306
  // Trains on a single example. Returns the squared error.
203
307
  train(inputs, target, lr) {
308
+ validateArray(inputs, this.hiddenLayer.neurons[0].weights.length, "Network.train");
309
+ validateNumber(target, "Network.train");
310
+ validateNumber(lr, "Network.train");
204
311
  const hiddenOut = this.hiddenLayer.predict(inputs);
205
312
  const prediction = this.outputLayer.predict(hiddenOut)[0];
206
- const outputError = target - prediction;
207
- const outputDelta = outputError * prediction * (1 - prediction);
208
313
  const outputNeuron = this.outputLayer.neurons[0];
209
- outputNeuron.weights = outputNeuron.weights.map(
210
- (w, i) => w + lr * outputDelta * hiddenOut[i]
211
- );
212
- outputNeuron.bias += lr * outputDelta;
213
- this.hiddenLayer.neurons.forEach((neuron, i) => {
214
- const hiddenOut_i = hiddenOut[i];
314
+ const outputError = target - prediction;
315
+ const outputDelta = outputError * outputNeuron.activation.dfn(prediction);
316
+ const hiddenDeltas = this.hiddenLayer.neurons.map((neuron, i) => {
215
317
  const hiddenError = outputDelta * outputNeuron.weights[i];
216
- const hiddenDelta = hiddenError * hiddenOut_i * (1 - hiddenOut_i);
217
- neuron.weights = neuron.weights.map((w, j) => w + lr * hiddenDelta * inputs[j]);
218
- neuron.bias += lr * hiddenDelta;
318
+ return hiddenError * neuron.activation.dfn(hiddenOut[i]);
319
+ });
320
+ this.hiddenLayer.neurons.forEach((neuron, i) => {
321
+ neuron._update(inputs.map((inp) => hiddenDeltas[i] * inp), hiddenDeltas[i], lr);
219
322
  });
323
+ outputNeuron._update(hiddenOut.map((h) => outputDelta * h), outputDelta, lr);
220
324
  return outputError * outputError;
221
325
  }
326
+ // ── Flat weight serialization ─────────────────────────────────────────────
327
+ // Order: hidden layer (all neurons: weights then bias), then output layer.
328
+ getWeights() {
329
+ const w = [];
330
+ for (const n of this.hiddenLayer.neurons) {
331
+ w.push(...n.weights, n.bias);
332
+ }
333
+ for (const n of this.outputLayer.neurons) {
334
+ w.push(...n.weights, n.bias);
335
+ }
336
+ return w;
337
+ }
338
+ setWeights(weights) {
339
+ let idx = 0;
340
+ for (const n of this.hiddenLayer.neurons) {
341
+ for (let j = 0; j < n.weights.length; j++) n.weights[j] = weights[idx++];
342
+ n.bias = weights[idx++];
343
+ }
344
+ for (const n of this.outputLayer.neurons) {
345
+ for (let j = 0; j < n.weights.length; j++) n.weights[j] = weights[idx++];
346
+ n.bias = weights[idx++];
347
+ }
348
+ }
349
+ };
350
+
351
+ // src/Dropout.ts
352
+ var Dropout = class {
353
+ constructor(rate) {
354
+ this._mask = null;
355
+ if (rate < 0 || rate >= 1) {
356
+ throw new Error(`Dropout rate must be in [0, 1), got ${rate}`);
357
+ }
358
+ this.rate = rate;
359
+ }
360
+ // ── Forward ───────────────────────────────────────────────────────────────
361
+ // x: number[] → number[]
362
+ // If training, applies inverted dropout mask.
363
+ // If not training, returns input unchanged.
364
+ forward(x, training = true) {
365
+ if (!training || this.rate === 0) {
366
+ this._mask = null;
367
+ return [...x];
368
+ }
369
+ const scale = 1 / (1 - this.rate);
370
+ this._mask = x.map(() => Math.random() > this.rate ? scale : 0);
371
+ return x.map((v, i) => v * this._mask[i]);
372
+ }
373
+ // ── Backward ──────────────────────────────────────────────────────────────
374
+ // dOut: number[] → number[]
375
+ // Applies the same mask (gradient is zeroed where activation was zeroed).
376
+ backward(dOut) {
377
+ if (!this._mask) return [...dOut];
378
+ return dOut.map((d, i) => d * this._mask[i]);
379
+ }
380
+ // ── Reset mask between forward passes ─────────────────────────────────────
381
+ resetMask() {
382
+ this._mask = null;
383
+ }
384
+ // ── No trainable params ───────────────────────────────────────────────────
385
+ getWeights() {
386
+ return [];
387
+ }
388
+ setWeights(_weights) {
389
+ }
222
390
  };
223
391
 
224
392
  // src/NetworkN.ts
225
- var defaultOptimizer3 = () => new SGD();
226
393
  var NetworkN = class {
227
394
  constructor(structure, options = {}) {
228
395
  this.structure = structure;
229
396
  const nLayers = structure.length - 1;
230
397
  const activations = options.activations ?? Array.from({ length: nLayers }, () => sigmoid2);
231
- const optimizer = options.optimizer ?? defaultOptimizer3;
398
+ const optimizer = options.optimizer ?? defaultOptimizer;
399
+ const dropoutRate = options.dropoutRate ?? 0;
400
+ if (activations.length !== nLayers) {
401
+ throw new Error(`Expected ${nLayers} activations, got ${activations.length}`);
402
+ }
403
+ if (dropoutRate < 0 || dropoutRate >= 1) {
404
+ throw new Error(`Dropout rate must be in [0, 1), got ${dropoutRate}`);
405
+ }
406
+ this._residual = options.residual ?? false;
232
407
  this.layers = [];
233
408
  for (let i = 1; i < structure.length; i++) {
234
409
  this.layers.push(new Layer(structure[i], structure[i - 1], activations[i - 1], optimizer));
235
410
  }
411
+ this._dropouts = [];
412
+ if (dropoutRate > 0) {
413
+ for (let i = 0; i < nLayers - 1; i++) {
414
+ this._dropouts.push(new Dropout(dropoutRate));
415
+ }
416
+ }
417
+ const outputLayer = this.layers[this.layers.length - 1];
418
+ const outputActivation = outputLayer.neurons[0].activation;
419
+ for (let i = 1; i < outputLayer.neurons.length; i++) {
420
+ if (outputLayer.neurons[i].activation !== outputActivation) {
421
+ throw new Error("All output neurons must share the same activation function");
422
+ }
423
+ }
236
424
  }
237
- predict(inputs) {
238
- return this.layers.reduce((acc, layer) => layer.predict(acc), inputs);
425
+ predict(inputs, training = false) {
426
+ validateArray(inputs, this.structure[0], "NetworkN.predict");
427
+ let current = [...inputs];
428
+ for (let i = 0; i < this.layers.length; i++) {
429
+ const layerInput = [...current];
430
+ const layerOutput = this.layers[i].predict(current);
431
+ if (this._shouldResidual(i)) {
432
+ if (this.structure[i] === this.structure[i + 1]) {
433
+ current = layerOutput.map((v, j) => v + layerInput[j]);
434
+ } else {
435
+ current = [...layerOutput];
436
+ }
437
+ } else {
438
+ current = [...layerOutput];
439
+ }
440
+ if (i < this._dropouts.length) {
441
+ current = this._dropouts[i].forward(current, training);
442
+ }
443
+ }
444
+ return current;
239
445
  }
240
446
  // Generalized backpropagation across L layers.
241
447
  // Returns the mean squared error for the example.
242
448
  train(inputs, targets, lr) {
243
- const act = [inputs];
244
- for (const layer of this.layers) act.push(layer.predict(act[act.length - 1]));
449
+ validateArray(inputs, this.structure[0], "NetworkN.train");
450
+ validateArray(targets, this.structure[this.structure.length - 1], "NetworkN.train");
451
+ const act = this._forwardAll(inputs, true);
245
452
  const pred = act[act.length - 1];
246
453
  const outAct = this.layers[this.layers.length - 1].neurons[0].activation;
247
- let deltas = pred.map((p, i) => (targets[i] - p) * outAct.dfn(p));
248
- for (let l = this.layers.length - 1; l >= 0; l--) {
249
- const layer = this.layers[l];
250
- const layerIn = act[l];
251
- const prevAct = l > 0 ? this.layers[l - 1].neurons[0].activation : null;
252
- const prevDeltas = layerIn.map((out, j) => {
253
- const errProp = layer.neurons.reduce((s, n, k) => s + deltas[k] * n.weights[j], 0);
254
- return prevAct ? errProp * prevAct.dfn(out) : errProp;
255
- });
256
- layer.neurons.forEach((n, k) => {
257
- n._update(layerIn.map((inp) => deltas[k] * inp), deltas[k], lr);
258
- });
259
- deltas = prevDeltas;
260
- }
454
+ const deltas = pred.map((p, i) => (targets[i] - p) * outAct.dfn(p));
455
+ this._backpropLayers(act, deltas, lr);
261
456
  return pred.reduce((s, p, i) => s + (targets[i] - p) ** 2, 0) / pred.length;
262
457
  }
263
458
  // Backprop with externally provided output-layer deltas.
264
459
  // Useful for custom loss functions (e.g. physics-based gradients).
265
460
  trainWithDeltas(inputs, outputDeltas, lr) {
461
+ const act = this._forwardAll(inputs, true);
462
+ this._backpropLayers(act, outputDeltas, lr);
463
+ }
464
+ // ── Flat weight serialization ─────────────────────────────────────────────
465
+ // Order: layer 0 (all neurons), layer 1, ..., layer N.
466
+ getWeights() {
467
+ for (const d of this._dropouts) d.resetMask();
468
+ const w = [];
469
+ for (const layer of this.layers) {
470
+ for (const n of layer.neurons) {
471
+ w.push(...n.weights, n.bias);
472
+ }
473
+ }
474
+ return w;
475
+ }
476
+ setWeights(weights) {
477
+ for (const d of this._dropouts) d.resetMask();
478
+ let idx = 0;
479
+ for (const layer of this.layers) {
480
+ for (const n of layer.neurons) {
481
+ for (let j = 0; j < n.weights.length; j++) n.weights[j] = weights[idx++];
482
+ n.bias = weights[idx++];
483
+ }
484
+ }
485
+ }
486
+ // ── Private helpers ──────────────────────────────────────────────────────
487
+ _shouldResidual(layerIndex) {
488
+ if (typeof this._residual === "function") return this._residual(layerIndex);
489
+ return this._residual;
490
+ }
491
+ // Forward pass storing activations at every layer boundary.
492
+ // Used by train(), trainWithDeltas(), and predict() shares the same logic.
493
+ _forwardAll(inputs, training) {
266
494
  const act = [inputs];
267
- for (const layer of this.layers) act.push(layer.predict(act[act.length - 1]));
495
+ for (let i = 0; i < this.layers.length; i++) {
496
+ const layerInput = act[act.length - 1];
497
+ const layerOutput = this.layers[i].predict(layerInput);
498
+ let current;
499
+ if (this._shouldResidual(i) && this.structure[i] === this.structure[i + 1]) {
500
+ current = layerOutput.map((v, j) => v + layerInput[j]);
501
+ } else {
502
+ current = layerOutput;
503
+ }
504
+ if (i < this._dropouts.length) {
505
+ current = this._dropouts[i].forward(current, training);
506
+ }
507
+ act.push(current);
508
+ }
509
+ return act;
510
+ }
511
+ // Backward pass: updates all layer weights given the pre-computed activations
512
+ // and the initial output-layer deltas.
513
+ _backpropLayers(act, outputDeltas, lr) {
268
514
  let deltas = outputDeltas;
269
515
  for (let l = this.layers.length - 1; l >= 0; l--) {
270
516
  const layer = this.layers[l];
517
+ if (l < this._dropouts.length) {
518
+ deltas = this._dropouts[l].backward(deltas);
519
+ }
271
520
  const layerIn = act[l];
272
521
  const prevAct = l > 0 ? this.layers[l - 1].neurons[0].activation : null;
273
522
  const prevDeltas = layerIn.map((out, j) => {
274
523
  const errProp = layer.neurons.reduce((s, n, k) => s + deltas[k] * n.weights[j], 0);
275
524
  return prevAct ? errProp * prevAct.dfn(out) : errProp;
276
525
  });
526
+ if (this._shouldResidual(l) && this.structure[l] === this.structure[l + 1]) {
527
+ for (let j = 0; j < prevDeltas.length; j++) prevDeltas[j] += deltas[j];
528
+ }
277
529
  layer.neurons.forEach((n, k) => {
278
530
  n._update(layerIn.map((inp) => deltas[k] * inp), deltas[k], lr);
279
531
  });
@@ -294,7 +546,7 @@ var Gate = class {
294
546
  // shape: [hSize]
295
547
  constructor(inputSize, hSize, initBias = 0) {
296
548
  const n = inputSize + hSize;
297
- const limit = Math.sqrt(2 / n);
549
+ const limit = Math.sqrt(2 / (n + hSize));
298
550
  this.W = Array.from(
299
551
  { length: hSize },
300
552
  () => Array.from({ length: n }, () => (Math.random() * 2 - 1) * limit)
@@ -308,8 +560,11 @@ var Gate = class {
308
560
  }
309
561
  };
310
562
  var LSTMLayer = class {
311
- constructor(inputSize, hiddenSize) {
563
+ constructor(inputSize, hiddenSize, optimizerFactory = () => new SGD()) {
312
564
  this._traj = [];
565
+ if (inputSize <= 0 || hiddenSize <= 0) {
566
+ throw new Error(`LSTMLayer: inputSize and hiddenSize must be positive, got ${inputSize} and ${hiddenSize}`);
567
+ }
313
568
  this.inputSize = inputSize;
314
569
  this.hSize = hiddenSize;
315
570
  this.h = new Array(hiddenSize).fill(0);
@@ -318,6 +573,29 @@ var LSTMLayer = class {
318
573
  this.inputGate = new Gate(inputSize, hiddenSize);
319
574
  this.cellGate = new Gate(inputSize, hiddenSize);
320
575
  this.outputGate = new Gate(inputSize, hiddenSize);
576
+ const combSize = inputSize + hiddenSize;
577
+ this._optimizers = {
578
+ forgetW: Array.from(
579
+ { length: hiddenSize },
580
+ () => Array.from({ length: combSize }, () => optimizerFactory())
581
+ ),
582
+ forgetB: Array.from({ length: hiddenSize }, () => optimizerFactory()),
583
+ inputW: Array.from(
584
+ { length: hiddenSize },
585
+ () => Array.from({ length: combSize }, () => optimizerFactory())
586
+ ),
587
+ inputB: Array.from({ length: hiddenSize }, () => optimizerFactory()),
588
+ cellW: Array.from(
589
+ { length: hiddenSize },
590
+ () => Array.from({ length: combSize }, () => optimizerFactory())
591
+ ),
592
+ cellB: Array.from({ length: hiddenSize }, () => optimizerFactory()),
593
+ outputW: Array.from(
594
+ { length: hiddenSize },
595
+ () => Array.from({ length: combSize }, () => optimizerFactory())
596
+ ),
597
+ outputB: Array.from({ length: hiddenSize }, () => optimizerFactory())
598
+ };
321
599
  }
322
600
  // ── Reset state and trajectory (call at episode start) ────────────────────
323
601
  reset() {
@@ -327,6 +605,9 @@ var LSTMLayer = class {
327
605
  }
328
606
  // ── Forward pass ──────────────────────────────────────────────────────────
329
607
  predict(inputs) {
608
+ if (!Array.isArray(inputs) || inputs.length !== this.inputSize) {
609
+ throw new Error(`LSTMLayer.predict: expected array of length ${this.inputSize}, got ${inputs?.length}`);
610
+ }
330
611
  const combined = [...inputs, ...this.h];
331
612
  const c_prev = [...this.c];
332
613
  const zf = this.forgetGate.linear(combined);
@@ -401,15 +682,15 @@ var LSTMLayer = class {
401
682
  const scale = lr / T;
402
683
  for (let k = 0; k < hSize; k++) {
403
684
  for (let j = 0; j < combSize; j++) {
404
- this.forgetGate.W[k][j] += scale * dWf[k][j];
405
- this.inputGate.W[k][j] += scale * dWi[k][j];
406
- this.cellGate.W[k][j] += scale * dWg[k][j];
407
- this.outputGate.W[k][j] += scale * dWo[k][j];
685
+ this.forgetGate.W[k][j] = this._optimizers.forgetW[k][j].step(this.forgetGate.W[k][j], dWf[k][j], scale);
686
+ this.inputGate.W[k][j] = this._optimizers.inputW[k][j].step(this.inputGate.W[k][j], dWi[k][j], scale);
687
+ this.cellGate.W[k][j] = this._optimizers.cellW[k][j].step(this.cellGate.W[k][j], dWg[k][j], scale);
688
+ this.outputGate.W[k][j] = this._optimizers.outputW[k][j].step(this.outputGate.W[k][j], dWo[k][j], scale);
408
689
  }
409
- this.forgetGate.b[k] += scale * dbf[k];
410
- this.inputGate.b[k] += scale * dbi[k];
411
- this.cellGate.b[k] += scale * dbg[k];
412
- this.outputGate.b[k] += scale * dbo[k];
690
+ this.forgetGate.b[k] = this._optimizers.forgetB[k].step(this.forgetGate.b[k], dbf[k], scale);
691
+ this.inputGate.b[k] = this._optimizers.inputB[k].step(this.inputGate.b[k], dbi[k], scale);
692
+ this.cellGate.b[k] = this._optimizers.cellB[k].step(this.cellGate.b[k], dbg[k], scale);
693
+ this.outputGate.b[k] = this._optimizers.outputB[k].step(this.outputGate.b[k], dbo[k], scale);
413
694
  }
414
695
  this._traj = [];
415
696
  }
@@ -432,10 +713,38 @@ var LSTMLayer = class {
432
713
  this.outputGate.W = data.outputGate.W;
433
714
  this.outputGate.b = data.outputGate.b;
434
715
  }
716
+ // ── Flat weight serialization ─────────────────────────────────────────────
717
+ // Order: forgetGate (W, b), inputGate (W, b), cellGate (W, b), outputGate (W, b).
718
+ getWeightsFlat() {
719
+ const w = [];
720
+ for (const row of this.forgetGate.W) w.push(...row);
721
+ w.push(...this.forgetGate.b);
722
+ for (const row of this.inputGate.W) w.push(...row);
723
+ w.push(...this.inputGate.b);
724
+ for (const row of this.cellGate.W) w.push(...row);
725
+ w.push(...this.cellGate.b);
726
+ for (const row of this.outputGate.W) w.push(...row);
727
+ w.push(...this.outputGate.b);
728
+ return w;
729
+ }
730
+ setWeightsFlat(weights) {
731
+ let idx = 0;
732
+ for (let i = 0; i < this.forgetGate.W.length; i++)
733
+ for (let j = 0; j < this.forgetGate.W[i].length; j++) this.forgetGate.W[i][j] = weights[idx++];
734
+ for (let i = 0; i < this.forgetGate.b.length; i++) this.forgetGate.b[i] = weights[idx++];
735
+ for (let i = 0; i < this.inputGate.W.length; i++)
736
+ for (let j = 0; j < this.inputGate.W[i].length; j++) this.inputGate.W[i][j] = weights[idx++];
737
+ for (let i = 0; i < this.inputGate.b.length; i++) this.inputGate.b[i] = weights[idx++];
738
+ for (let i = 0; i < this.cellGate.W.length; i++)
739
+ for (let j = 0; j < this.cellGate.W[i].length; j++) this.cellGate.W[i][j] = weights[idx++];
740
+ for (let i = 0; i < this.cellGate.b.length; i++) this.cellGate.b[i] = weights[idx++];
741
+ for (let i = 0; i < this.outputGate.W.length; i++)
742
+ for (let j = 0; j < this.outputGate.W[i].length; j++) this.outputGate.W[i][j] = weights[idx++];
743
+ for (let i = 0; i < this.outputGate.b.length; i++) this.outputGate.b[i] = weights[idx++];
744
+ }
435
745
  };
436
746
 
437
747
  // src/NetworkLSTM.ts
438
- var defaultOptimizer4 = () => new SGD();
439
748
  var NetworkLSTM = class {
440
749
  // [T][layer+1][neuron]
441
750
  constructor(inputSize, hiddenSize, denseStructure, options = {}) {
@@ -443,7 +752,7 @@ var NetworkLSTM = class {
443
752
  this.hiddenSize = hiddenSize;
444
753
  this.lstm = new LSTMLayer(inputSize, hiddenSize);
445
754
  const activation = options.denseActivation ?? sigmoid2;
446
- const optimizer = options.optimizer ?? defaultOptimizer4;
755
+ const optimizer = options.optimizer ?? defaultOptimizer;
447
756
  this.denseLayers = [];
448
757
  const sizes = [hiddenSize, ...denseStructure];
449
758
  for (let i = 1; i < sizes.length; i++) {
@@ -458,6 +767,7 @@ var NetworkLSTM = class {
458
767
  }
459
768
  // ── Forward pass ──────────────────────────────────────────────────────────
460
769
  predict(inputs) {
770
+ validateArray(inputs, this.inputSize, "NetworkLSTM.predict");
461
771
  const h = this.lstm.predict(inputs);
462
772
  const acts = [h];
463
773
  for (const layer of this.denseLayers) {
@@ -533,6 +843,30 @@ var NetworkLSTM = class {
533
843
  });
534
844
  });
535
845
  }
846
+ // ── Flat weight serialization ─────────────────────────────────────────────
847
+ // Order: LSTM (flat), then dense layer 0, dense layer 1, ..., dense layer N.
848
+ getWeightsFlat() {
849
+ const w = [];
850
+ w.push(...this.lstm.getWeightsFlat());
851
+ for (const layer of this.denseLayers) {
852
+ for (const n of layer.neurons) {
853
+ w.push(...n.weights, n.bias);
854
+ }
855
+ }
856
+ return w;
857
+ }
858
+ setWeightsFlat(weights) {
859
+ let idx = 0;
860
+ const lstmLen = this.lstm.getWeightsFlat().length;
861
+ this.lstm.setWeightsFlat(weights.slice(idx, idx + lstmLen));
862
+ idx += lstmLen;
863
+ for (const layer of this.denseLayers) {
864
+ for (const n of layer.neurons) {
865
+ for (let j = 0; j < n.weights.length; j++) n.weights[j] = weights[idx++];
866
+ n.bias = weights[idx++];
867
+ }
868
+ }
869
+ }
536
870
  };
537
871
 
538
872
  // src/MatMul.ts
@@ -540,6 +874,9 @@ function matMul(A, B) {
540
874
  const rows = A.length;
541
875
  const inner = B.length;
542
876
  const cols = B[0].length;
877
+ if (A[0].length !== B.length) {
878
+ throw new Error(`Incompatible dimensions for matrix multiplication: A cols (${A[0].length}) !== B rows (${B.length})`);
879
+ }
543
880
  const C = Array.from({ length: rows }, () => new Array(cols).fill(0));
544
881
  for (let i = 0; i < rows; i++)
545
882
  for (let k = 0; k < inner; k++) {
@@ -590,6 +927,33 @@ var WeightMatrix = class {
590
927
  this.W[i][j] = this.opts[i][j].step(this.W[i][j], g, lr);
591
928
  }
592
929
  }
930
+ // ── Flat weight serialization ─────────────────────────────────────────────
931
+ getWeights() {
932
+ const w = [];
933
+ for (const row of this.W) w.push(...row);
934
+ return w;
935
+ }
936
+ setWeights(weights) {
937
+ let idx = 0;
938
+ for (let i = 0; i < this.W.length; i++)
939
+ for (let j = 0; j < this.W[i].length; j++) this.W[i][j] = weights[idx++];
940
+ }
941
+ };
942
+ var BiasVector = class {
943
+ constructor(size) {
944
+ this.values = new Array(size).fill(0);
945
+ this.opts = Array.from({ length: size }, () => new Adam());
946
+ }
947
+ update(grad, lr) {
948
+ for (let i = 0; i < this.values.length; i++)
949
+ this.values[i] = this.opts[i].step(this.values[i], grad[i], lr);
950
+ }
951
+ getWeights() {
952
+ return [...this.values];
953
+ }
954
+ setWeights(weights) {
955
+ for (let i = 0; i < this.values.length; i++) this.values[i] = weights[i];
956
+ }
593
957
  };
594
958
  var EmbeddingMatrix = class {
595
959
  constructor(vocabSize, d_model) {
@@ -606,15 +970,29 @@ var EmbeddingMatrix = class {
606
970
  for (let m = 0; m < this.W[idx].length; m++)
607
971
  this.W[idx][m] += lr * grad[m];
608
972
  }
973
+ // ── Serializable interface ─────────────────────────────────────────────────
974
+ // Flattened order: row 0, row 1, ... row (vocabSize-1)
975
+ getWeights() {
976
+ const w = [];
977
+ for (const row of this.W) w.push(...row);
978
+ return w;
979
+ }
980
+ setWeights(weights) {
981
+ let idx = 0;
982
+ for (let i = 0; i < this.W.length; i++)
983
+ for (let j = 0; j < this.W[i].length; j++)
984
+ this.W[i][j] = weights[idx++];
985
+ }
609
986
  };
610
987
 
611
988
  // src/AttentionHead.ts
612
989
  var AttentionHead = class {
613
- constructor(d_model, d_k, d_v) {
990
+ constructor(d_model, d_k, d_v, causal = false) {
614
991
  // d_v × d_model
615
992
  this.cache = null;
616
993
  this.d_k = d_k;
617
994
  this.d_v = d_v;
995
+ this.causal = causal;
618
996
  this.Wq = new WeightMatrix(d_k, d_model);
619
997
  this.Wk = new WeightMatrix(d_k, d_model);
620
998
  this.Wv = new WeightMatrix(d_v, d_model);
@@ -635,10 +1013,10 @@ var AttentionHead = class {
635
1013
  );
636
1014
  const scores = Array.from(
637
1015
  { length: seqLen },
638
- (_, i) => Array.from(
639
- { length: seqLen },
640
- (_2, j) => Q[i].reduce((s, q, k) => s + q * K[j][k], 0) * scale
641
- )
1016
+ (_, i) => Array.from({ length: seqLen }, (_2, j) => {
1017
+ if (this.causal && j > i) return -Infinity;
1018
+ return Q[i].reduce((s, q, k) => s + q * K[j][k], 0) * scale;
1019
+ })
642
1020
  );
643
1021
  const attn = scores.map((row) => softmax(row));
644
1022
  const out = Array.from(
@@ -662,6 +1040,7 @@ var AttentionHead = class {
662
1040
  // 5. dWq = dQ^T @ X, dWk = dK^T @ X, dWv = dV^T @ X
663
1041
  // 6. dX = dQ @ Wq + dK @ Wk + dV @ Wv
664
1042
  backward(dOut, lr) {
1043
+ if (!this.cache) throw new Error("AttentionHead.backward() called before predict()");
665
1044
  const { X, Q, K, V, attn } = this.cache;
666
1045
  const seqLen = X.length;
667
1046
  const d_model = X[0].length;
@@ -734,21 +1113,40 @@ var AttentionHead = class {
734
1113
  getAttentionWeights() {
735
1114
  return this.cache ? this.cache.attn : null;
736
1115
  }
1116
+ // ── Flat weight serialization ─────────────────────────────────────────────
1117
+ // Order: Wq, Wk, Wv.
1118
+ getWeights() {
1119
+ const w = [];
1120
+ for (const row of this.Wq.W) w.push(...row);
1121
+ for (const row of this.Wk.W) w.push(...row);
1122
+ for (const row of this.Wv.W) w.push(...row);
1123
+ return w;
1124
+ }
1125
+ setWeights(weights) {
1126
+ let idx = 0;
1127
+ for (let i = 0; i < this.Wq.W.length; i++)
1128
+ for (let j = 0; j < this.Wq.W[i].length; j++) this.Wq.W[i][j] = weights[idx++];
1129
+ for (let i = 0; i < this.Wk.W.length; i++)
1130
+ for (let j = 0; j < this.Wk.W[i].length; j++) this.Wk.W[i][j] = weights[idx++];
1131
+ for (let i = 0; i < this.Wv.W.length; i++)
1132
+ for (let j = 0; j < this.Wv.W[i].length; j++) this.Wv.W[i][j] = weights[idx++];
1133
+ }
737
1134
  };
738
1135
 
739
1136
  // src/MultiHeadAttention.ts
740
1137
  var MultiHeadAttention = class {
741
1138
  // seqLen × (nHeads * d_k)
742
- constructor(d_model, nHeads) {
1139
+ constructor(d_model, nHeads, causal = false) {
743
1140
  // d_model × (nHeads * d_k)
744
1141
  // Cached for backward
745
1142
  this._concat = null;
746
1143
  this.nHeads = nHeads;
747
1144
  this.d_model = d_model;
748
1145
  this.d_k = Math.floor(d_model / nHeads);
1146
+ this.causal = causal;
749
1147
  this.heads = Array.from(
750
1148
  { length: nHeads },
751
- () => new AttentionHead(d_model, this.d_k, this.d_k)
1149
+ () => new AttentionHead(d_model, this.d_k, this.d_k, causal)
752
1150
  );
753
1151
  this.Wo = new WeightMatrix(d_model, nHeads * this.d_k);
754
1152
  }
@@ -770,6 +1168,7 @@ var MultiHeadAttention = class {
770
1168
  // ── Backward ──────────────────────────────────────────────────────────────
771
1169
  // dOut: seqLen × d_model → dX: seqLen × d_model
772
1170
  backward(dOut, lr) {
1171
+ if (!this._concat) throw new Error("MultiHeadAttention.backward() called before predict()");
773
1172
  const seqLen = dOut.length;
774
1173
  const concatD = this.nHeads * this.d_k;
775
1174
  const d_model = this.d_model;
@@ -807,6 +1206,31 @@ var MultiHeadAttention = class {
807
1206
  getAttentionWeights() {
808
1207
  return this.heads.map((h) => h.getAttentionWeights());
809
1208
  }
1209
+ // ── Flat weight serialization ─────────────────────────────────────────────
1210
+ // Order: head0 (Wq, Wk, Wv), head1, ..., headN, then Wo.
1211
+ getWeights() {
1212
+ const w = [];
1213
+ for (const head of this.heads) {
1214
+ for (const row of head.Wq.W) w.push(...row);
1215
+ for (const row of head.Wk.W) w.push(...row);
1216
+ for (const row of head.Wv.W) w.push(...row);
1217
+ }
1218
+ for (const row of this.Wo.W) w.push(...row);
1219
+ return w;
1220
+ }
1221
+ setWeights(weights) {
1222
+ let idx = 0;
1223
+ for (const head of this.heads) {
1224
+ for (let i = 0; i < head.Wq.W.length; i++)
1225
+ for (let j = 0; j < head.Wq.W[i].length; j++) head.Wq.W[i][j] = weights[idx++];
1226
+ for (let i = 0; i < head.Wk.W.length; i++)
1227
+ for (let j = 0; j < head.Wk.W[i].length; j++) head.Wk.W[i][j] = weights[idx++];
1228
+ for (let i = 0; i < head.Wv.W.length; i++)
1229
+ for (let j = 0; j < head.Wv.W[i].length; j++) head.Wv.W[i][j] = weights[idx++];
1230
+ }
1231
+ for (let i = 0; i < this.Wo.W.length; i++)
1232
+ for (let j = 0; j < this.Wo.W[i].length; j++) this.Wo.W[i][j] = weights[idx++];
1233
+ }
810
1234
  };
811
1235
 
812
1236
  // src/LayerNorm.ts
@@ -849,20 +1273,32 @@ var LayerNorm = class {
849
1273
  backwardOne(dOut, pos, lr) {
850
1274
  const { x_norm, std } = this._cache[pos];
851
1275
  const N = dOut.length;
1276
+ const gammaOld = this.gamma.slice();
852
1277
  for (let i = 0; i < N; i++) {
853
1278
  this.gamma[i] += lr * dOut[i] * x_norm[i];
854
1279
  this.beta[i] += lr * dOut[i];
855
1280
  }
856
- const D = dOut.map((d, i) => d * this.gamma[i]);
1281
+ const D = dOut.map((d, i) => d * gammaOld[i]);
857
1282
  const mD = D.reduce((s, v) => s + v, 0) / N;
858
1283
  const mDxn = D.reduce((s, d, i) => s + d * x_norm[i], 0) / N;
859
1284
  return D.map((d, i) => (d - mD - x_norm[i] * mDxn) / std);
860
1285
  }
1286
+ // ── Flat weight serialization ─────────────────────────────────────────────
1287
+ // Order: gamma, beta.
1288
+ getWeights() {
1289
+ return [...this.gamma, ...this.beta];
1290
+ }
1291
+ setWeights(weights) {
1292
+ const dim = this.gamma.length;
1293
+ for (let i = 0; i < dim; i++) this.gamma[i] = weights[i];
1294
+ for (let i = 0; i < dim; i++) this.beta[i] = weights[dim + i];
1295
+ }
861
1296
  };
862
1297
 
863
1298
  // src/TransformerBlock.ts
864
1299
  var TransformerBlock = class {
865
- constructor({ d_model, nHeads, d_ff }) {
1300
+ constructor({ d_model, nHeads, d_ff, causal = false }) {
1301
+ // d_model
866
1302
  // Forward caches (needed for backprop)
867
1303
  this._X = null;
868
1304
  this._attnOut = null;
@@ -874,15 +1310,13 @@ var TransformerBlock = class {
874
1310
  this._ff2Out = null;
875
1311
  this.d_model = d_model;
876
1312
  this.d_ff = d_ff;
877
- this.attn = new MultiHeadAttention(d_model, nHeads);
1313
+ this.attn = new MultiHeadAttention(d_model, nHeads, causal);
878
1314
  this.norm1 = new LayerNorm(d_model);
879
1315
  this.norm2 = new LayerNorm(d_model);
880
1316
  this.ff1 = new WeightMatrix(d_ff, d_model);
881
1317
  this.ff2 = new WeightMatrix(d_model, d_ff);
882
- this.b1 = new Array(d_ff).fill(0);
883
- this.b2 = new Array(d_model).fill(0);
884
- this.b1Opts = Array.from({ length: d_ff }, () => new Adam());
885
- this.b2Opts = Array.from({ length: d_model }, () => new Adam());
1318
+ this.b1 = new BiasVector(d_ff);
1319
+ this.b2 = new BiasVector(d_model);
886
1320
  }
887
1321
  // ── Forward ───────────────────────────────────────────────────────────────
888
1322
  // X: seqLen × d_model → out: seqLen × d_model
@@ -895,11 +1329,11 @@ var TransformerBlock = class {
895
1329
  return this.norm1.predictOne(added, i);
896
1330
  });
897
1331
  const ff1Pre = h1.map(
898
- (h) => this.ff1.W.map((row, k) => row.reduce((s, w, m) => s + w * h[m], this.b1[k]))
1332
+ (h) => this.ff1.W.map((row, k) => row.reduce((s, w, m) => s + w * h[m], this.b1.values[k]))
899
1333
  );
900
1334
  const ff1Out = ff1Pre.map((pre) => pre.map((v) => Math.max(0, v)));
901
1335
  const ff2Out = ff1Out.map(
902
- (h) => this.ff2.W.map((row, k) => row.reduce((s, w, m) => s + w * h[m], this.b2[k]))
1336
+ (h) => this.ff2.W.map((row, k) => row.reduce((s, w, m) => s + w * h[m], this.b2.values[k]))
903
1337
  );
904
1338
  this.norm2.resetCache(seqLen);
905
1339
  const out = h1.map((h, i) => {
@@ -917,6 +1351,9 @@ var TransformerBlock = class {
917
1351
  // ── Backward ──────────────────────────────────────────────────────────────
918
1352
  // dOut: seqLen × d_model → dX: seqLen × d_model
919
1353
  backward(dOut, lr) {
1354
+ if (!this._h1 || !this._ff1Out || !this._ff1Pre) {
1355
+ throw new Error("TransformerBlock.backward() called before predict()");
1356
+ }
920
1357
  const seqLen = dOut.length;
921
1358
  const d_model = this.d_model;
922
1359
  const h1 = this._h1;
@@ -941,8 +1378,7 @@ var TransformerBlock = class {
941
1378
  (_, m) => dAdded2.reduce((s, da) => s + da[m], 0)
942
1379
  );
943
1380
  this.ff2.update(dW2, lr);
944
- for (let m = 0; m < d_model; m++)
945
- this.b2[m] = this.b2Opts[m].step(this.b2[m], db2[m], lr);
1381
+ this.b2.update(db2, lr);
946
1382
  const dFf1Pre = dFf1Out.map(
947
1383
  (d, i) => d.map((v, k) => ff1Pre[i][k] > 0 ? v : 0)
948
1384
  );
@@ -964,8 +1400,7 @@ var TransformerBlock = class {
964
1400
  (_, k) => dFf1Pre.reduce((s, dp) => s + dp[k], 0)
965
1401
  );
966
1402
  this.ff1.update(dW1, lr);
967
- for (let k = 0; k < this.d_ff; k++)
968
- this.b1[k] = this.b1Opts[k].step(this.b1[k], db1[k], lr);
1403
+ this.b1.update(db1, lr);
969
1404
  const dH1 = Array.from(
970
1405
  { length: seqLen },
971
1406
  (_, i) => dH1_fromFf[i].map((v, m) => v + dAdded2[i][m])
@@ -987,6 +1422,36 @@ var TransformerBlock = class {
987
1422
  getAttentionWeights() {
988
1423
  return this.attn.getAttentionWeights();
989
1424
  }
1425
+ // ── Flat weight serialization ─────────────────────────────────────────────
1426
+ // Order: attn (MHA), norm1 (gamma, beta), ff1, b1, ff2, b2, norm2 (gamma, beta).
1427
+ getWeights() {
1428
+ const w = [];
1429
+ w.push(...this.attn.getWeights());
1430
+ w.push(...this.norm1.gamma, ...this.norm1.beta);
1431
+ for (const row of this.ff1.W) w.push(...row);
1432
+ w.push(...this.b1.values);
1433
+ for (const row of this.ff2.W) w.push(...row);
1434
+ w.push(...this.b2.values);
1435
+ w.push(...this.norm2.gamma, ...this.norm2.beta);
1436
+ return w;
1437
+ }
1438
+ setWeights(weights) {
1439
+ let idx = 0;
1440
+ const attnLen = this.attn.getWeights().length;
1441
+ this.attn.setWeights(weights.slice(idx, idx + attnLen));
1442
+ idx += attnLen;
1443
+ this.norm1.setWeights(weights.slice(idx, idx + this.norm1.getWeights().length));
1444
+ idx += this.norm1.getWeights().length;
1445
+ this.ff1.setWeights(weights.slice(idx, idx + this.ff1.getWeights().length));
1446
+ idx += this.ff1.getWeights().length;
1447
+ this.b1.setWeights(weights.slice(idx, idx + this.b1.values.length));
1448
+ idx += this.b1.values.length;
1449
+ this.ff2.setWeights(weights.slice(idx, idx + this.ff2.getWeights().length));
1450
+ idx += this.ff2.getWeights().length;
1451
+ this.b2.setWeights(weights.slice(idx, idx + this.b2.values.length));
1452
+ idx += this.b2.values.length;
1453
+ this.norm2.setWeights(weights.slice(idx, idx + this.norm2.getWeights().length));
1454
+ }
990
1455
  };
991
1456
 
992
1457
  // src/NetworkTransformer.ts
@@ -1011,8 +1476,7 @@ var NetworkTransformer = class {
1011
1476
  () => new TransformerBlock({ d_model, nHeads, d_ff })
1012
1477
  );
1013
1478
  this.outputProj = new WeightMatrix(nClasses, d_model);
1014
- this.outputBias = new Array(nClasses).fill(0);
1015
- this.outBiasOpts = Array.from({ length: nClasses }, () => new Adam());
1479
+ this.outputBias = new BiasVector(nClasses);
1016
1480
  }
1017
1481
  // ── Forward pass ──────────────────────────────────────────────────────────
1018
1482
  // tokens: seqLen integer ids → seqLen * nClasses logits (flattened)
@@ -1020,7 +1484,7 @@ var NetworkTransformer = class {
1020
1484
  const h = this._forward(tokens);
1021
1485
  return h.flatMap(
1022
1486
  (hi) => this.outputProj.W.map(
1023
- (row, c) => row.reduce((s, w, m) => s + w * hi[m], this.outputBias[c])
1487
+ (row, c) => row.reduce((s, w, m) => s + w * hi[m], this.outputBias.values[c])
1024
1488
  )
1025
1489
  );
1026
1490
  }
@@ -1034,7 +1498,7 @@ var NetworkTransformer = class {
1034
1498
  const h = this._forward(tokens);
1035
1499
  const logits = h.map(
1036
1500
  (hi) => this.outputProj.W.map(
1037
- (row, c) => row.reduce((s, w, m) => s + w * hi[m], this.outputBias[c])
1501
+ (row, c) => row.reduce((s, w, m) => s + w * hi[m], this.outputBias.values[c])
1038
1502
  )
1039
1503
  );
1040
1504
  let loss = 0;
@@ -1069,8 +1533,7 @@ var NetworkTransformer = class {
1069
1533
  (_, c) => dLogits.reduce((s, dl) => s + dl[c], 0)
1070
1534
  );
1071
1535
  this.outputProj.update(dWout, lr);
1072
- for (let c = 0; c < this.nClasses; c++)
1073
- this.outputBias[c] = this.outBiasOpts[c].step(this.outputBias[c], dBout[c], lr);
1536
+ this.outputBias.update(dBout, lr);
1074
1537
  let dX = dH;
1075
1538
  for (let b = this.blocks.length - 1; b >= 0; b--)
1076
1539
  dX = this.blocks[b].backward(dX, lr);
@@ -1085,6 +1548,35 @@ var NetworkTransformer = class {
1085
1548
  getAttentionWeights() {
1086
1549
  return this.blocks.map((b) => b.getAttentionWeights());
1087
1550
  }
1551
+ // ── Flat weight serialization ─────────────────────────────────────────────
1552
+ // Order: tokenEmb, posEmb, block0, block1, ..., blockN, outputProj, outputBias.
1553
+ getWeights() {
1554
+ const w = [];
1555
+ w.push(...this.tokenEmb.getWeights());
1556
+ w.push(...this.posEmb.getWeights());
1557
+ for (const block of this.blocks) w.push(...block.getWeights());
1558
+ w.push(...this.outputProj.getWeights());
1559
+ w.push(...this.outputBias.getWeights());
1560
+ return w;
1561
+ }
1562
+ setWeights(weights) {
1563
+ let idx = 0;
1564
+ const tokenEmbLen = this.tokenEmb.getWeights().length;
1565
+ this.tokenEmb.setWeights(weights.slice(idx, idx + tokenEmbLen));
1566
+ idx += tokenEmbLen;
1567
+ const posEmbLen = this.posEmb.getWeights().length;
1568
+ this.posEmb.setWeights(weights.slice(idx, idx + posEmbLen));
1569
+ idx += posEmbLen;
1570
+ for (const block of this.blocks) {
1571
+ const blockLen = block.getWeights().length;
1572
+ block.setWeights(weights.slice(idx, idx + blockLen));
1573
+ idx += blockLen;
1574
+ }
1575
+ const outProjLen = this.outputProj.getWeights().length;
1576
+ this.outputProj.setWeights(weights.slice(idx, idx + outProjLen));
1577
+ idx += outProjLen;
1578
+ this.outputBias.setWeights(weights.slice(idx, idx + this.outputBias.values.length));
1579
+ }
1088
1580
  // ── Internal ──────────────────────────────────────────────────────────────
1089
1581
  // Shared embedding + block forward pass.
1090
1582
  _forward(tokens) {
@@ -1104,25 +1596,28 @@ var NetworkTransformerRL = class {
1104
1596
  constructor(seqLen, inputDim, options = {}) {
1105
1597
  // Forward caches para backprop
1106
1598
  this._projected = null;
1599
+ // For max pooling backward: argmax per dimension across all positions
1600
+ this._argmax = null;
1107
1601
  const {
1108
1602
  d_model = 32,
1109
1603
  nHeads = 2,
1110
1604
  d_ff = 64,
1111
1605
  nBlocks = 2,
1112
- nActions = 2
1606
+ nActions = 2,
1607
+ pooling = "weighted"
1113
1608
  } = options;
1114
1609
  this.seqLen = seqLen;
1115
1610
  this.inputDim = inputDim;
1116
1611
  this.d_model = d_model;
1117
1612
  this.nActions = nActions;
1613
+ this._pooling = pooling;
1118
1614
  this.inputProj = new WeightMatrix(d_model, inputDim);
1119
1615
  this.blocks = Array.from(
1120
1616
  { length: nBlocks },
1121
- () => new TransformerBlock({ d_model, nHeads, d_ff })
1617
+ () => new TransformerBlock({ d_model, nHeads, d_ff, causal: true })
1122
1618
  );
1123
1619
  this.outputProj = new WeightMatrix(nActions, d_model);
1124
- this.outputBias = new Array(nActions).fill(0);
1125
- this.outBiasOpts = Array.from({ length: nActions }, () => new Adam());
1620
+ this.outputBias = new BiasVector(nActions);
1126
1621
  }
1127
1622
  // ── Forward ────────────────────────────────────────────────────────────────
1128
1623
  // sequence: seqLen × inputDim → nActions Q-values
@@ -1130,7 +1625,7 @@ var NetworkTransformerRL = class {
1130
1625
  const h = this._forward(sequence);
1131
1626
  const pooled = this._pool(h);
1132
1627
  return this.outputProj.W.map(
1133
- (row, c) => row.reduce((s, w, m) => s + w * pooled[m], this.outputBias[c])
1628
+ (row, c) => row.reduce((s, w, m) => s + w * pooled[m], this.outputBias.values[c])
1134
1629
  );
1135
1630
  }
1136
1631
  // ── Training ────────────────────────────────────────────────────────────────
@@ -1142,7 +1637,7 @@ var NetworkTransformerRL = class {
1142
1637
  const h = this._forward(sequence);
1143
1638
  const pooled = this._pool(h);
1144
1639
  const pred = this.outputProj.W.map(
1145
- (row, c) => row.reduce((s, w, m) => s + w * pooled[m], this.outputBias[c])
1640
+ (row, c) => row.reduce((s, w, m) => s + w * pooled[m], this.outputBias.values[c])
1146
1641
  );
1147
1642
  const n = this.nActions;
1148
1643
  let loss = 0;
@@ -1165,13 +1660,8 @@ var NetworkTransformerRL = class {
1165
1660
  );
1166
1661
  const dBout = dPred.slice();
1167
1662
  this.outputProj.update(dWout, lr);
1168
- for (let c = 0; c < this.nActions; c++)
1169
- this.outputBias[c] = this.outBiasOpts[c].step(this.outputBias[c], dBout[c], lr);
1170
- let dH = Array.from(
1171
- { length: this.seqLen },
1172
- (_, i) => dPooled.map((v) => v / this.seqLen)
1173
- // Gradiente dividido entre posiciones
1174
- );
1663
+ this.outputBias.update(dBout, lr);
1664
+ let dH = this._distributePoolGradient(dPooled);
1175
1665
  for (let b = this.blocks.length - 1; b >= 0; b--)
1176
1666
  dH = this.blocks[b].backward(dH, lr);
1177
1667
  for (let i = 0; i < this.seqLen; i++) {
@@ -1190,8 +1680,32 @@ var NetworkTransformerRL = class {
1190
1680
  getAttentionWeights() {
1191
1681
  return this.blocks.map((b) => b.getAttentionWeights());
1192
1682
  }
1193
- // ── Serialization ──────────────────────────────────────────────────────────
1194
- getWeights() {
1683
+ // ── Flat weight serialization ─────────────────────────────────────────────
1684
+ // Order: inputProj, block0, block1, ..., blockN, outputProj, outputBias.
1685
+ getWeightsFlat() {
1686
+ const w = [];
1687
+ w.push(...this.inputProj.getWeights());
1688
+ for (const block of this.blocks) w.push(...block.getWeights());
1689
+ w.push(...this.outputProj.getWeights());
1690
+ w.push(...this.outputBias.getWeights());
1691
+ return w;
1692
+ }
1693
+ setWeightsFlat(weights) {
1694
+ let idx = 0;
1695
+ const inputProjLen = this.inputProj.getWeights().length;
1696
+ this.inputProj.setWeights(weights.slice(idx, idx + inputProjLen));
1697
+ idx += inputProjLen;
1698
+ for (const block of this.blocks) {
1699
+ const blockLen = block.getWeights().length;
1700
+ block.setWeights(weights.slice(idx, idx + blockLen));
1701
+ idx += blockLen;
1702
+ }
1703
+ const outProjLen = this.outputProj.getWeights().length;
1704
+ this.outputProj.setWeights(weights.slice(idx, idx + outProjLen));
1705
+ idx += outProjLen;
1706
+ this.outputBias.setWeights(weights.slice(idx, idx + this.outputBias.values.length));
1707
+ }
1708
+ getWeightsStructured() {
1195
1709
  return {
1196
1710
  inputProj: this.inputProj.W.map((r) => [...r]),
1197
1711
  blocks: this.blocks.map((b) => ({
@@ -1207,17 +1721,15 @@ var NetworkTransformerRL = class {
1207
1721
  norm2: { gamma: [...b.norm2.gamma], beta: [...b.norm2.beta] },
1208
1722
  ff1: b.ff1.W.map((r) => [...r]),
1209
1723
  ff2: b.ff2.W.map((r) => [...r]),
1210
- b1: [...b.b1],
1211
- b2: [...b.b2]
1724
+ b1: [...b.b1.values],
1725
+ b2: [...b.b2.values]
1212
1726
  })),
1213
1727
  outputProj: this.outputProj.W.map((r) => [...r]),
1214
- outputBias: [...this.outputBias]
1728
+ outputBias: [...this.outputBias.values]
1215
1729
  };
1216
1730
  }
1217
- setWeights(data) {
1218
- data.inputProj.forEach((row, i) => {
1219
- this.inputProj.W[i] = [...row];
1220
- });
1731
+ setWeightsStructured(data) {
1732
+ this.inputProj.setWeights(data.inputProj.flat());
1221
1733
  data.blocks.forEach((bd, b) => {
1222
1734
  const blk = this.blocks[b];
1223
1735
  bd.attn.heads.forEach((hd, h) => {
@@ -1232,11 +1744,20 @@ var NetworkTransformerRL = class {
1232
1744
  blk.norm2.beta = [...bd.norm2.beta];
1233
1745
  blk.ff1.W = bd.ff1.map((r) => [...r]);
1234
1746
  blk.ff2.W = bd.ff2.map((r) => [...r]);
1235
- blk.b1 = [...bd.b1];
1236
- blk.b2 = [...bd.b2];
1747
+ blk.b1.setWeights(bd.b1);
1748
+ blk.b2.setWeights(bd.b2);
1237
1749
  });
1238
1750
  this.outputProj.W = data.outputProj.map((r) => [...r]);
1239
- this.outputBias = [...data.outputBias];
1751
+ this.outputBias.setWeights(data.outputBias);
1752
+ }
1753
+ // ── Serializable interface (flat array) ────────────────────────────────────
1754
+ // These satisfy the Serializable interface from ModelSaver, which requires
1755
+ // getWeights(): number[] and setWeights(weights: number[]): void.
1756
+ getWeights() {
1757
+ return this.getWeightsFlat();
1758
+ }
1759
+ setWeights(weights) {
1760
+ this.setWeightsFlat(weights);
1240
1761
  }
1241
1762
  // ── Internal ────────────────────────────────────────────────────────────────
1242
1763
  _forward(sequence) {
@@ -1251,6 +1772,44 @@ var NetworkTransformerRL = class {
1251
1772
  return h;
1252
1773
  }
1253
1774
  _pool(h) {
1775
+ switch (this._pooling) {
1776
+ case "avg":
1777
+ return this._poolAvg(h);
1778
+ case "max":
1779
+ return this._poolMax(h);
1780
+ case "last":
1781
+ return this._poolLast(h);
1782
+ case "weighted":
1783
+ default:
1784
+ return this._poolWeighted(h);
1785
+ }
1786
+ }
1787
+ _poolAvg(h) {
1788
+ const n = h.length;
1789
+ return Array.from({ length: this.d_model }, (_, m) => {
1790
+ let sum = 0;
1791
+ for (let i = 0; i < n; i++)
1792
+ sum += h[i][m];
1793
+ return sum / n;
1794
+ });
1795
+ }
1796
+ _poolMax(h) {
1797
+ this._argmax = new Array(this.d_model).fill(0);
1798
+ return Array.from({ length: this.d_model }, (_, m) => {
1799
+ let maxVal = -Infinity;
1800
+ for (let i = 0; i < h.length; i++) {
1801
+ if (h[i][m] > maxVal) {
1802
+ maxVal = h[i][m];
1803
+ this._argmax[m] = i;
1804
+ }
1805
+ }
1806
+ return maxVal;
1807
+ });
1808
+ }
1809
+ _poolLast(h) {
1810
+ return [...h[h.length - 1]];
1811
+ }
1812
+ _poolWeighted(h) {
1254
1813
  const weights = Array.from(
1255
1814
  { length: this.seqLen },
1256
1815
  (_, i) => i === this.seqLen - 1 ? 2 : 1
@@ -1263,6 +1822,55 @@ var NetworkTransformerRL = class {
1263
1822
  return sum / totalWeight;
1264
1823
  });
1265
1824
  }
1825
+ /** Returns the current pooling type for inspection. */
1826
+ getPoolingType() {
1827
+ return this._pooling;
1828
+ }
1829
+ // ── Helper: distribute pooled gradient back to each position ────────────────
1830
+ // Must match the same distribution as _pool() used during forward.
1831
+ _distributePoolGradient(dPooled) {
1832
+ switch (this._pooling) {
1833
+ case "avg": {
1834
+ const n = this.seqLen;
1835
+ return Array.from(
1836
+ { length: n },
1837
+ () => dPooled.map((v) => v / n)
1838
+ );
1839
+ }
1840
+ case "max": {
1841
+ if (!this._argmax) {
1842
+ const n = this.seqLen;
1843
+ return Array.from(
1844
+ { length: n },
1845
+ () => dPooled.map((v) => v / n)
1846
+ );
1847
+ }
1848
+ const argmax = this._argmax;
1849
+ return Array.from(
1850
+ { length: this.seqLen },
1851
+ (_, i) => dPooled.map((v, m) => i === argmax[m] ? v : 0)
1852
+ );
1853
+ }
1854
+ case "last": {
1855
+ return Array.from(
1856
+ { length: this.seqLen },
1857
+ (_, i) => i === this.seqLen - 1 ? [...dPooled] : new Array(this.d_model).fill(0)
1858
+ );
1859
+ }
1860
+ case "weighted":
1861
+ default: {
1862
+ const weights = Array.from(
1863
+ { length: this.seqLen },
1864
+ (_, i) => i === this.seqLen - 1 ? 2 : 1
1865
+ );
1866
+ const totalWeight = weights.reduce((a, b) => a + b, 0);
1867
+ return Array.from(
1868
+ { length: this.seqLen },
1869
+ (_, i) => dPooled.map((v) => v * weights[i] / totalWeight)
1870
+ );
1871
+ }
1872
+ }
1873
+ }
1266
1874
  };
1267
1875
 
1268
1876
  // src/losses.ts
@@ -1287,14 +1895,803 @@ function crossEntropyDeltaRaw(predicted, actual) {
1287
1895
  const p = Math.max(eps, Math.min(1 - eps, predicted));
1288
1896
  return actual / p - (1 - actual) / (1 - p);
1289
1897
  }
1898
+
1899
+ // src/GRU.ts
1900
// Logistic sigmoid 1 / (1 + e^-x); used for the GRU reset and update gates.
function sigmoid4(x) {
  const denominator = 1 + Math.exp(-x);
  return 1 / denominator;
}
1903
// Hyperbolic tangent for the GRU candidate state.
// Math.tanh is used instead of (e^{2x} - 1) / (e^{2x} + 1). That form evaluates to
// Infinity / Infinity = NaN once e^{2x} overflows (x > ~355), and a single NaN
// would then propagate through the hidden state.
function tanhFn(x) {
  return Math.tanh(x);
}
1907
// Affine map used by each GRU gate: out[i] = b[i] + sum_j W[i][j] * combined[j].
// W has hSize rows of (inputSize + hSize) entries, drawn uniformly from
// ±sqrt(2 / (inputSize + 2 * hSize)). Every bias starts at initBias.
var Gate2 = class {
  constructor(inputSize, hSize, initBias = 0) {
    const fanIn = inputSize + hSize;
    const bound = Math.sqrt(2 / (fanIn + hSize));
    const randomWeight = () => (Math.random() * 2 - 1) * bound;
    const makeRow = () => Array.from({ length: fanIn }, randomWeight);
    this.W = Array.from({ length: hSize }, makeRow);
    this.b = Array(hSize).fill(initBias);
  }
  linear(combined) {
    const out = [];
    for (let i = 0; i < this.W.length; i++) {
      const row = this.W[i];
      let acc = this.b[i];
      for (let j = 0; j < row.length; j++) acc += row[j] * combined[j];
      out.push(acc);
    }
    return out;
  }
};
1923
/**
 * Single-layer GRU. Each predict() call advances one time step and records the
 * intermediate values. backprop() runs backpropagation through time over all
 * steps recorded since the last reset()/backprop(), then clears the record.
 */
var GRULayer = class {
  /**
   * @param {number} inputSize - length of each input vector
   * @param {number} hiddenSize - length of the hidden state
   * @param {Function} [optimizerFactory] - returns a fresh optimizer exposing
   *   step(weight, grad, lr) -> newWeight; one instance is created per scalar
   *   parameter (defaults to SGD)
   * @throws {Error} if inputSize or hiddenSize is not positive
   */
  constructor(inputSize, hiddenSize, optimizerFactory = () => new SGD()) {
    // Per-step forward records consumed (and cleared) by backprop().
    this._traj = [];
    if (inputSize <= 0 || hiddenSize <= 0) {
      throw new Error(`GRULayer: inputSize and hiddenSize must be positive, got ${inputSize} and ${hiddenSize}`);
    }
    this.inputSize = inputSize;
    this.hSize = hiddenSize;
    this.h = new Array(hiddenSize).fill(0);
    this.resetGate = new Gate2(inputSize, hiddenSize);
    this.updateGate = new Gate2(inputSize, hiddenSize);
    this.newGate = new Gate2(inputSize, hiddenSize);
    const combSize = inputSize + hiddenSize;
    const matrixOpts = () => Array.from(
      { length: hiddenSize },
      () => Array.from({ length: combSize }, () => optimizerFactory()),
    );
    const vectorOpts = () => Array.from({ length: hiddenSize }, () => optimizerFactory());
    this._optimizers = {
      resetW: matrixOpts(),
      resetB: vectorOpts(),
      updateW: matrixOpts(),
      updateB: vectorOpts(),
      newW: matrixOpts(),
      newB: vectorOpts(),
    };
  }

  // Zeroes the hidden state and discards any recorded (un-backpropagated) steps.
  reset() {
    this.h = new Array(this.hSize).fill(0);
    this._traj = [];
  }

  /**
   * Advances one time step.
   * @param {number[]} inputs - array of length inputSize
   * @returns {number[]} the new hidden state (also stored as this.h)
   * @throws {Error} if inputs is not an array of length inputSize
   */
  predict(inputs) {
    if (!Array.isArray(inputs) || inputs.length !== this.inputSize) {
      throw new Error(`GRULayer.predict: expected array of length ${this.inputSize}, got ${inputs?.length}`);
    }
    const combined = [...inputs, ...this.h];
    const h_prev = [...this.h];
    const r_pre = this.resetGate.linear(combined);
    const z_pre = this.updateGate.linear(combined);
    const r_a = r_pre.map(sigmoid4);
    const z_a = z_pre.map(sigmoid4);
    // The candidate gate sees the input plus the reset-gated previous state.
    const combined_r = [...inputs, ...r_a.map((r, i) => r * h_prev[i])];
    const n_pre = this.newGate.linear(combined_r);
    const n_a = n_pre.map(tanhFn);
    // h = (1 - z) * n + z * h_prev
    const h = n_a.map((n, i) => (1 - z_a[i]) * n + z_a[i] * h_prev[i]);
    this._traj.push({ combined, h_prev, r: r_pre, r_a, z: z_pre, z_a, combined_r, n_pre, n_a, h });
    this.h = h;
    return h;
  }

  /**
   * Backpropagation through time over the recorded steps.
   * The accumulated gradients are applied once through the per-parameter
   * optimizers, with lr scaled by 1 / T.
   * @param {number[][]} dh_seq - one gradient row (length hiddenSize) per recorded step
   * @param {number} lr - learning rate
   * NOTE(review): returns silently, without clearing the recorded steps, when nothing is
   * recorded or dh_seq.length differs from the number of recorded steps.
   */
  backprop(dh_seq, lr) {
    const T = this._traj.length;
    if (T === 0 || dh_seq.length !== T) return;
    const hSize = this.hSize;
    const combSize = this.inputSize + hSize;
    const dWr = Array.from({ length: hSize }, () => new Array(combSize).fill(0));
    const dWz = Array.from({ length: hSize }, () => new Array(combSize).fill(0));
    const dWn = Array.from({ length: hSize }, () => new Array(combSize).fill(0));
    const dbr = new Array(hSize).fill(0);
    const dbz = new Array(hSize).fill(0);
    const dbn = new Array(hSize).fill(0);
    let dh_next = new Array(hSize).fill(0);
    for (let t = T - 1; t >= 0; t--) {
      const s = this._traj[t];
      const dh = dh_seq[t].map((d, i) => d + dh_next[i]);
      const dz_a = dh.map((d, i) => (s.h_prev[i] - s.n_a[i]) * d);
      const dn_a = dh.map((d, i) => (1 - s.z_a[i]) * d);
      const dn_pre = dn_a.map((d, i) => d * (1 - s.n_a[i] ** 2));
      const dz_pre = dz_a.map((d, i) => d * s.z_a[i] * (1 - s.z_a[i]));
      // Gradient reaching (r ⊙ h_prev) through the hidden-state columns of the candidate gate.
      const dr_hprev = Array.from(
        { length: hSize },
        (_, i) => this.newGate.W.reduce((sum, row, k) => sum + dn_pre[k] * row[this.inputSize + i], 0),
      );
      const dr_a = dr_hprev.map((d, i) => d * s.h_prev[i]);
      const dr_pre = dr_a.map((d, i) => d * s.r_a[i] * (1 - s.r_a[i]));
      for (let k = 0; k < hSize; k++) {
        for (let j = 0; j < combSize; j++) {
          dWr[k][j] += dr_pre[k] * s.combined[j];
          dWz[k][j] += dz_pre[k] * s.combined[j];
          dWn[k][j] += dn_pre[k] * s.combined_r[j];
        }
        dbr[k] += dr_pre[k];
        dbz[k] += dz_pre[k];
        dbn[k] += dn_pre[k];
      }
      dh_next = new Array(hSize).fill(0);
      for (let k = 0; k < hSize; k++) {
        for (let j = this.inputSize; j < combSize; j++) {
          dh_next[j - this.inputSize] += dr_pre[k] * this.resetGate.W[k][j] + dz_pre[k] * this.updateGate.W[k][j];
        }
        dh_next[k] += dr_hprev[k] * s.r_a[k];
        dh_next[k] += dh[k] * s.z_a[k];
      }
    }
    const scale = lr / T;
    for (let k = 0; k < hSize; k++) {
      for (let j = 0; j < combSize; j++) {
        this.resetGate.W[k][j] = this._optimizers.resetW[k][j].step(this.resetGate.W[k][j], dWr[k][j], scale);
        this.updateGate.W[k][j] = this._optimizers.updateW[k][j].step(this.updateGate.W[k][j], dWz[k][j], scale);
        this.newGate.W[k][j] = this._optimizers.newW[k][j].step(this.newGate.W[k][j], dWn[k][j], scale);
      }
      this.resetGate.b[k] = this._optimizers.resetB[k].step(this.resetGate.b[k], dbr[k], scale);
      this.updateGate.b[k] = this._optimizers.updateB[k].step(this.updateGate.b[k], dbz[k], scale);
      this.newGate.b[k] = this._optimizers.newB[k].step(this.newGate.b[k], dbn[k], scale);
    }
    this._traj = [];
  }

  // ── Flat weight serialization ─────────────────────────────────────────────
  // Order: resetGate (W, b), updateGate (W, b), newGate (W, b).
  getWeightsFlat() {
    const w = [];
    for (const gate of [this.resetGate, this.updateGate, this.newGate]) {
      for (const row of gate.W) w.push(...row);
      w.push(...gate.b);
    }
    return w;
  }

  /**
   * Inverse of getWeightsFlat(); writes into the existing arrays in place.
   * @throws {Error} if weights.length does not equal the current parameter count
   */
  setWeightsFlat(weights) {
    const gates = [this.resetGate, this.updateGate, this.newGate];
    const expected = gates.reduce(
      (n, g) => n + g.W.reduce((k, row) => k + row.length, 0) + g.b.length,
      0,
    );
    if (weights?.length !== expected) {
      throw new Error(`GRULayer.setWeightsFlat: expected ${expected} weights, got ${weights?.length}`);
    }
    let idx = 0;
    for (const gate of gates) {
      for (const row of gate.W) {
        for (let j = 0; j < row.length; j++) row[j] = weights[idx++];
      }
      for (let i = 0; i < gate.b.length; i++) gate.b[i] = weights[idx++];
    }
  }

  // Structured snapshot { resetGate, updateGate, newGate }, each holding { W, b }.
  // Returns deep copies. Handing out live references let a later setWeights() on
  // another layer share (and co-train) this layer's arrays.
  getWeights() {
    const snapshot = (gate) => ({ W: gate.W.map((row) => [...row]), b: [...gate.b] });
    return {
      resetGate: snapshot(this.resetGate),
      updateGate: snapshot(this.updateGate),
      newGate: snapshot(this.newGate),
    };
  }

  // Loads a structured snapshot (the same shape getWeights() returns), copying it
  // so the caller's arrays are never aliased.
  // NOTE(review): shapes are not checked against the per-parameter optimizers
  // created in the constructor.
  setWeights(data) {
    const restore = (gate, src) => {
      gate.W = src.W.map((row) => [...row]);
      gate.b = [...src.b];
    };
    restore(this.resetGate, data.resetGate);
    restore(this.updateGate, data.updateGate);
    restore(this.newGate, data.newGate);
  }
};
2073
+
2074
+ // src/BatchNorm.ts
2075
/**
 * Per-feature normalization with a learnable scale (gamma) and shift (beta),
 * applied to one sample of length `dim` at a time. Every forward() call first
 * updates the running mean and variance, then normalizes using those running
 * statistics.
 */
var BatchNorm = class {
  /**
   * @param {number} dim - number of features per sample
   * @param {number} [momentum=0.1] - weight of the PREVIOUS running statistic in each update:
   *   running = momentum * running + (1 - momentum) * observation.
   *   NOTE(review): PyTorch uses the opposite convention (momentum weights the new
   *   observation). With this formula and the default 0.1, the running stats mostly
   *   track the latest sample. Confirm that this is intended.
   */
  constructor(dim, momentum = 0.1) {
    this._xNorm = null;
    this._std = null;
    this.dim = dim;
    this.momentum = momentum;
    this.gamma = new Array(dim).fill(1);
    this.beta = new Array(dim).fill(0);
    this.runningMean = new Array(dim).fill(0);
    this.runningVar = new Array(dim).fill(1);
  }

  // ── Forward ───────────────────────────────────────────────────────────────
  // Updates the running statistics with x, then returns gamma * x_norm + beta.
  // @throws {Error} if x.length !== dim
  forward(x) {
    this._assertLength(x, "forward");
    const eps = 1e-5;
    for (let i = 0; i < this.dim; i++) {
      this.runningMean[i] = this.momentum * this.runningMean[i] + (1 - this.momentum) * x[i];
      const diff = x[i] - this.runningMean[i];
      this.runningVar[i] = this.momentum * this.runningVar[i] + (1 - this.momentum) * diff * diff;
    }
    this._std = this.runningVar.map((v) => Math.sqrt(v + eps));
    this._xNorm = x.map((v, i) => (v - this.runningMean[i]) / this._std[i]);
    return this._xNorm.map((xn, i) => this.gamma[i] * xn + this.beta[i]);
  }

  // ── Backward ──────────────────────────────────────────────────────────────
  // Input gradient dOut * gamma / std. The running statistics are treated as
  // constants, so their dependence on x is ignored.
  // @throws {Error} if forward() has not run yet, or dOut.length !== dim
  backward(dOut) {
    if (!this._xNorm || !this._std) {
      throw new Error("BatchNorm.backward: call forward() first");
    }
    this._assertLength(dOut, "backward");
    return dOut.map((d, i) => d * this.gamma[i] / this._std[i]);
  }

  // ── Train gamma and beta (call after backward) ────────────────────────────
  // Adds lr-scaled terms to gamma and beta, so dOut is presumably the negative
  // loss gradient (confirm against callers). No-op before the first forward().
  // @throws {Error} if dOut.length !== dim
  trainParams(dOut, lr) {
    if (!this._xNorm) return;
    this._assertLength(dOut, "trainParams");
    for (let i = 0; i < this.dim; i++) {
      this.gamma[i] += lr * dOut[i] * this._xNorm[i];
      this.beta[i] += lr * dOut[i];
    }
  }

  // ── Flat weight serialization ─────────────────────────────────────────────
  // Order: gamma, beta. Running statistics are not included.
  getWeights() {
    return [...this.gamma, ...this.beta];
  }

  // Inverse of getWeights(). @throws {Error} if weights.length !== 2 * dim
  setWeights(weights) {
    if (weights?.length !== 2 * this.dim) {
      throw new Error(`BatchNorm.setWeights: expected array of length ${2 * this.dim}, got ${weights?.length}`);
    }
    for (let i = 0; i < this.dim; i++) this.gamma[i] = weights[i];
    for (let i = 0; i < this.dim; i++) this.beta[i] = weights[this.dim + i];
  }

  // Throws a descriptive error unless arr has exactly `dim` entries.
  _assertLength(arr, method) {
    if (arr.length !== this.dim) {
      throw new Error(`BatchNorm.${method}: expected array of length ${this.dim}, got ${arr.length}`);
    }
  }
};
2128
+
2129
+ // src/Conv1D.ts
2130
var Conv1D = class {
  /**
   * 1D convolution over a sequence of `inputLength` positions.
   *
   * Kernels are stored as kernels[filter][offset][channel]; the output of
   * forward() is output[filter][position].
   *
   * @param {number} inputLength - number of input positions.
   * @param {number} kernelSize - window width in positions.
   * @param {number} filters - number of kernels / output rows.
   * @param {number} [stride=1] - positive integer step between windows.
   * @param {string} [padding="valid"] - "same" zero-pads so the output has
   *   ceil(inputLength / stride) positions; any other value means no padding.
   * @param {() => {step: (w: number, g: number, lr: number) => number}} [optimizerFactory]
   *   - called once per trainable scalar to create its optimizer.
   * @param {number} [inputChannels=1] - channels per input position.
   * @throws {Error} on non-positive sizes, a kernel wider than an unpadded
   *   input, fewer than one channel, or a non-positive / non-integer stride.
   */
  constructor(inputLength, kernelSize, filters, stride = 1, padding = "valid", optimizerFactory = () => new SGD(), inputChannels = 1) {
    this._input = null;
    this._paddedInput = null;
    if (inputLength <= 0 || kernelSize <= 0 || filters <= 0) {
      throw new Error("Conv1D: inputLength, kernelSize, and filters must be positive");
    }
    // Every non-"same" padding mode behaves as "valid" in forward(), so guard them all.
    if (kernelSize > inputLength && padding !== "same") {
      throw new Error("Conv1D: kernelSize cannot exceed inputLength with valid padding");
    }
    if (inputChannels < 1) {
      throw new Error("Conv1D: inputChannels must be >= 1");
    }
    // A zero, negative or fractional stride produces an invalid output length
    // or fractional indices further down; fail early with a clear message.
    if (!Number.isInteger(stride) || stride < 1) {
      throw new Error(`Conv1D: stride must be a positive integer, got ${stride}`);
    }
    this.inputLength = inputLength;
    this.kernelSize = kernelSize;
    this.filters = filters;
    this.stride = stride;
    this.padding = padding;
    this.inputChannels = inputChannels;
    // Uniform init in [-limit, limit], scaled by fan-in (kernelSize * inputChannels).
    const limit = Math.sqrt(2 / (kernelSize * inputChannels));
    this.kernels = Array.from(
      { length: filters },
      () => Array.from(
        { length: kernelSize },
        () => Array.from({ length: inputChannels }, () => (Math.random() * 2 - 1) * limit)
      )
    );
    this.biases = new Array(filters).fill(0);
    this._kOpts = Array.from(
      { length: filters },
      () => Array.from(
        { length: kernelSize },
        () => Array.from({ length: inputChannels }, () => optimizerFactory())
      )
    );
    this._bOpts = Array.from({ length: filters }, () => optimizerFactory());
  }
  // ── Forward ───────────────────────────────────────────────────────────────
  /**
   * Runs the convolution.
   * @param {number[] | number[][]} input - number[] when inputChannels = 1,
   *   otherwise input[position][channel].
   * @returns {number[][]} output[filter][position].
   */
  forward(input) {
    const input2D = this._normalizeInput(input);
    // Keep a private copy so that mutating the caller's array after forward()
    // cannot change what backward() differentiates against.
    this._input = input2D.map((row) => [...row]);
    const [padLeft, padRight] = this._padAmounts();
    const zeroRows = (count) => Array.from({ length: count }, () => new Array(this.inputChannels).fill(0));
    const padded = [...zeroRows(padLeft), ...this._input, ...zeroRows(padRight)];
    this._paddedInput = padded;
    const outputLength = Math.floor((padded.length - this.kernelSize) / this.stride) + 1;
    const output = Array.from(
      { length: this.filters },
      () => new Array(outputLength).fill(0)
    );
    for (let f = 0; f < this.filters; f++) {
      for (let pos = 0; pos < outputLength; pos++) {
        const start = pos * this.stride;
        let sum = this.biases[f];
        for (let k = 0; k < this.kernelSize; k++) {
          for (let c = 0; c < this.inputChannels; c++) {
            sum += this.kernels[f][k][c] * padded[start + k][c];
          }
        }
        output[f][pos] = sum;
      }
    }
    return output;
  }
  // ── Backward ──────────────────────────────────────────────────────────────
  /**
   * Accumulates kernel/bias gradients from dOut, applies one optimizer step
   * per parameter, and returns the gradient w.r.t. the (unpadded) input.
   * The input gradient is computed with the pre-update kernels.
   *
   * @param {number[][]} dOut - dOut[filter][position], same shape as forward()'s output.
   * @param {number} [lr=1e-3]
   * @returns {number[][]} dInput[position][channel] (always 2D, even for 1D input).
   * @throws {Error} if forward() has not been called.
   */
  backward(dOut, lr = 1e-3) {
    if (!this._paddedInput || !this._input) {
      throw new Error("Conv1D.backward: call forward() first");
    }
    const padded = this._paddedInput;
    const outputLength = dOut[0].length;
    const dKernels = Array.from(
      { length: this.filters },
      () => Array.from(
        { length: this.kernelSize },
        () => new Array(this.inputChannels).fill(0)
      )
    );
    const dBiases = new Array(this.filters).fill(0);
    const dPadded = padded.map(() => new Array(this.inputChannels).fill(0));
    for (let f = 0; f < this.filters; f++) {
      for (let pos = 0; pos < outputLength; pos++) {
        const start = pos * this.stride;
        dBiases[f] += dOut[f][pos];
        for (let k = 0; k < this.kernelSize; k++) {
          for (let c = 0; c < this.inputChannels; c++) {
            dKernels[f][k][c] += dOut[f][pos] * padded[start + k][c];
            dPadded[start + k][c] += dOut[f][pos] * this.kernels[f][k][c];
          }
        }
      }
    }
    for (let f = 0; f < this.filters; f++) {
      for (let k = 0; k < this.kernelSize; k++) {
        for (let c = 0; c < this.inputChannels; c++) {
          this.kernels[f][k][c] = this._kOpts[f][k][c].step(this.kernels[f][k][c], dKernels[f][k][c], lr);
        }
      }
      this.biases[f] = this._bOpts[f].step(this.biases[f], dBiases[f], lr);
    }
    // Drop the rows that correspond to zero padding.
    const [padLeft] = this._padAmounts();
    return dPadded.slice(padLeft, padLeft + this.inputLength);
  }
  // ── Output length ─────────────────────────────────────────────────────────
  /** Number of output positions per filter produced by forward(). */
  getOutputLength() {
    if (this.padding === "same") {
      return Math.ceil(this.inputLength / this.stride);
    }
    return Math.floor((this.inputLength - this.kernelSize) / this.stride) + 1;
  }
  // ── Flat weight serialization ─────────────────────────────────────────────
  /** Flat parameter vector: kernels in [filter][offset][channel] order, then biases. */
  getWeights() {
    const w = [];
    for (const kernel of this.kernels)
      for (const k of kernel)
        for (const c of k)
          w.push(c);
    w.push(...this.biases);
    return w;
  }
  /** Inverse of getWeights(): reads kernels then biases from a flat array. */
  setWeights(weights) {
    let idx = 0;
    for (let f = 0; f < this.filters; f++)
      for (let k = 0; k < this.kernelSize; k++)
        for (let c = 0; c < this.inputChannels; c++)
          this.kernels[f][k][c] = weights[idx++];
    for (let f = 0; f < this.filters; f++)
      this.biases[f] = weights[idx++];
  }
  // ── Padding ───────────────────────────────────────────────────────────────
  /**
   * Zero rows to prepend / append to the input.
   *
   * "same": the left pad is floor((kernelSize - 1) / 2), and the right pad is
   * sized so the padded input yields exactly ceil(inputLength / stride) windows,
   * matching getOutputLength(). For even kernel sizes this makes the right pad
   * one larger than the left. Symmetric padding would lose one output position
   * there.
   * Any other mode: no padding.
   * @returns {[number, number]} [left, right]
   */
  _padAmounts() {
    if (this.padding !== "same") return [0, 0];
    const left = Math.floor((this.kernelSize - 1) / 2);
    const outLen = Math.ceil(this.inputLength / this.stride);
    const needed = (outLen - 1) * this.stride + this.kernelSize - this.inputLength;
    return [left, Math.max(needed - left, 0)];
  }
  // ── Normalize input to 2D format ─────────────────────────────────────────
  /**
   * Validates the input shape and returns it as input[position][channel].
   * A 1D number[] is only accepted when inputChannels === 1.
   */
  _normalizeInput(input) {
    if (input.length === 0) {
      throw new Error("Conv1D.forward: input cannot be empty");
    }
    if (typeof input[0] === "number") {
      if (this.inputChannels !== 1) {
        throw new Error(`Conv1D.forward: expected 2D input with ${this.inputChannels} channels, got 1D`);
      }
      const input1D = input;
      if (input1D.length !== this.inputLength) {
        throw new Error(`Conv1D.forward: expected input of length ${this.inputLength}, got ${input1D.length}`);
      }
      return input1D.map((v) => [v]);
    }
    const input2D = input;
    if (input2D.length !== this.inputLength) {
      throw new Error(`Conv1D.forward: expected input of length ${this.inputLength}, got ${input2D.length}`);
    }
    for (let i = 0; i < input2D.length; i++) {
      if (input2D[i].length !== this.inputChannels) {
        throw new Error(`Conv1D.forward: expected ${this.inputChannels} channels at position ${i}, got ${input2D[i].length}`);
      }
    }
    return input2D;
  }
};
2297
+
2298
+ // src/Trainer.ts
2299
var Trainer = class {
  /**
   * Epoch-based training loop. Each epoch shuffles the samples and feeds them
   * one at a time to `network.train(input, target, lr)`.
   *
   * @param {object} network - must implement train(input, target, lr) and
   *   return that sample's loss. Weight decay, early stopping and metrics also
   *   need getWeights(), setWeights(weights) and predict(input).
   * @param {object} [options]
   * @param {number} [options.epochs=1000]
   * @param {number} [options.lr=0.1] - initial learning rate.
   * @param {number} [options.lrDecay=1] - lr is multiplied by this after every epoch.
   * @param {boolean} [options.verbose=false] - log progress every 100 epochs.
   * @param {number} [options.weightDecay=0] - before every sample, each weight
   *   is scaled by (1 - lr * weightDecay).
   * @param {{patience: number, minDelta?: number}} [options.earlyStopping] -
   *   stop once the validation MSE has not improved by more than minDelta
   *   (default 0) for `patience` consecutive epochs. Requires setValidationData().
   * @param {boolean} [options.computeMetrics=false] - record accuracy,
   *   precision, recall and F1 on the training set after every epoch.
   * @param {number} [options.clipValue=0] - stored on the instance; not read by Trainer itself.
   */
  constructor(network, options = {}) {
    this._history = [];
    this._bestLoss = Infinity;
    this._patienceCounter = 0;
    this._stopReason = "maxEpochs";
    this._metrics = [];
    this.network = network;
    this.epochs = options.epochs ?? 1e3;
    this.lrInitial = options.lr ?? 0.1;
    this.lrDecay = options.lrDecay ?? 1;
    this.verbose = options.verbose ?? false;
    this.weightDecay = options.weightDecay ?? 0;
    this._earlyStopping = options.earlyStopping;
    this._computeMetrics = options.computeMetrics ?? false;
    this.clipValue = options.clipValue ?? 0;
  }
  // ── Set external validation data (for early stopping) ────────────────────
  setValidationData(dataset) {
    if (dataset.inputs.length !== dataset.targets.length) {
      throw new Error(
        "Trainer.setValidationData: inputs and targets must have the same length"
      );
    }
    this._validationData = dataset;
  }
  // ── Get best validation loss during training ─────────────────────────────
  /** Best validation MSE seen by the last train() call, or -1 if none was computed. */
  getBestLoss() {
    return this._bestLoss === Infinity ? -1 : this._bestLoss;
  }
  // ── Why did training stop? ───────────────────────────────────────────────
  /** "maxEpochs" or "earlyStopping". */
  getStopReason() {
    return this._stopReason;
  }
  // ── Get per-epoch classification metrics ─────────────────────────────────
  getMetrics() {
    return [...this._metrics];
  }
  // ── Train on dataset ──────────────────────────────────────────────────────
  /**
   * Trains for up to `epochs` epochs and returns the per-epoch mean training loss.
   * @param {{inputs: any[], targets: number[][]}} dataset
   * @returns {number[]} loss history (same array as the internal history).
   * @throws {Error} if inputs and targets differ in length.
   */
  train(dataset) {
    const { inputs, targets } = dataset;
    if (inputs.length !== targets.length) {
      throw new Error(
        "Trainer.train: inputs and targets must have the same length"
      );
    }
    let lr = this.lrInitial;
    this._history = [];
    this._bestLoss = Infinity;
    this._patienceCounter = 0;
    this._stopReason = "maxEpochs";
    this._metrics = [];
    const netExt = this._hasWeights(this.network);
    if (this.weightDecay > 0 && !netExt) {
      console.warn(
        "Trainer: weightDecay requires a network with getWeights/setWeights/predict. Skipping weight decay."
      );
    }
    if (this._earlyStopping && !netExt) {
      console.warn(
        "Trainer: earlyStopping requires a network with predict(). Skipping early stopping."
      );
    }
    if (this._computeMetrics && !netExt) {
      console.warn(
        "Trainer: computeMetrics requires a network with predict(). Skipping metrics."
      );
    }
    // Early stopping used to be silently disabled without validation data, and
    // an empty validation set produced a NaN loss; both are now surfaced.
    const hasValidationData = !!this._validationData && this._validationData.inputs.length > 0;
    if (this._earlyStopping && netExt && !hasValidationData) {
      console.warn(
        "Trainer: earlyStopping requires non-empty validation data (see setValidationData()). Skipping early stopping."
      );
    }
    const decayNet = this.weightDecay > 0 ? netExt : null;
    const canValidate = !!this._earlyStopping && !!netExt && hasValidationData;
    // minDelta is optional. Left undefined, `best - minDelta` would be NaN, no
    // epoch would count as an improvement, and training would stop after `patience` epochs.
    const minDelta = this._earlyStopping?.minDelta ?? 0;
    const canMetric = this._computeMetrics && !!netExt;
    const isClass = canMetric && this._isClassification(targets);
    if (canMetric && !isClass) {
      console.warn(
        "Trainer: computeMetrics is set but targets do not appear to be one-hot or single-class. Metrics will be skipped."
      );
    }
    for (let epoch = 0; epoch < this.epochs; epoch++) {
      const epochLoss = this._runEpoch(inputs, targets, lr, decayNet);
      this._history.push(epochLoss);
      if (isClass) {
        this._metrics.push(this._computeMetricsArray(netExt, inputs, targets));
      }
      if (canValidate && this._updateEarlyStopping(netExt, minDelta)) {
        this._stopReason = "earlyStopping";
        break;
      }
      if (this.verbose && (epoch + 1) % 100 === 0) {
        // Report the learning rate actually used for this epoch (before decay).
        console.log(
          `Epoch ${epoch + 1}/${this.epochs}, loss: ${epochLoss.toFixed(6)}, lr: ${lr.toFixed(6)}`
        );
      }
      lr *= this.lrDecay;
    }
    return this._history;
  }
  // ── Get loss history ──────────────────────────────────────────────────────
  getHistory() {
    return [...this._history];
  }
  // ── Private helpers ───────────────────────────────────────────────────────
  /** One Fisher–Yates-shuffled pass over the data; returns the mean per-sample loss. */
  _runEpoch(inputs, targets, lr, decayNet) {
    const n = inputs.length;
    const order = Array.from({ length: n }, (_, i) => i);
    for (let i = n - 1; i > 0; i--) {
      const j = Math.floor(Math.random() * (i + 1));
      [order[i], order[j]] = [order[j], order[i]];
    }
    const shrink = 1 - lr * this.weightDecay;
    let total = 0;
    for (const i of order) {
      if (decayNet) {
        const w = decayNet.getWeights();
        for (let j = 0; j < w.length; j++) {
          w[j] *= shrink;
        }
        decayNet.setWeights(w);
      }
      total += this.network.train(inputs[i], targets[i], lr);
    }
    return total / n;
  }
  /** Updates best-loss / patience state from the validation set; returns true when training should stop. */
  _updateEarlyStopping(net, minDelta) {
    const valLoss = this._computeLoss(net, this._validationData);
    if (valLoss < this._bestLoss - minDelta) {
      this._bestLoss = valLoss;
      this._patienceCounter = 0;
    } else {
      this._patienceCounter++;
    }
    return this._patienceCounter >= this._earlyStopping.patience;
  }
  /** Type guard: does this network support getWeights/setWeights/predict? */
  _hasWeights(network) {
    if ("getWeights" in network && "setWeights" in network && "predict" in network && typeof network.getWeights === "function" && typeof network.setWeights === "function" && typeof network.predict === "function") {
      return network;
    }
    return null;
  }
  /** Mean squared error on a dataset (used for validation loss). */
  _computeLoss(net, data) {
    let totalLoss = 0;
    for (let i = 0; i < data.inputs.length; i++) {
      const pred = net.predict(data.inputs[i]);
      const target = data.targets[i];
      let sampleLoss = 0;
      for (let j = 0; j < pred.length; j++) {
        sampleLoss += (target[j] - pred[j]) ** 2;
      }
      totalLoss += sampleLoss / pred.length;
    }
    return totalLoss / data.inputs.length;
  }
  /**
   * Heuristic: are targets classification-style?
   * Scalar targets must all be ~0 or ~1 (binary labels). Any real-valued
   * scalar used to be accepted, so regression targets were scored as classes.
   * Vector targets must be one-hot-like: entries ~0 or ~1, summing to ~1.
   */
  _isClassification(targets) {
    if (targets.length === 0) return false;
    const near = (v, c) => Math.abs(v - c) <= 0.01;
    if (targets[0].length === 1) {
      return targets.every((t) => near(t[0], 0) || near(t[0], 1));
    }
    for (const t of targets) {
      let sum = 0;
      for (const v of t) {
        sum += v;
        if (v < -0.01 || (v > 0.01 && v < 0.99)) return false;
      }
      if (Math.abs(sum - 1) > 0.01) return false;
    }
    return true;
  }
  /**
   * Compute classification metrics from predictions vs targets.
   * Builds a confusion matrix (rows = true class, cols = predicted) and
   * returns accuracy plus macro-averaged precision, recall and their F1.
   */
  _computeMetricsArray(net, inputs, targets) {
    const targetLen = targets[0].length;
    const nClasses = targetLen === 1 ? 2 : targetLen;
    const confusion = Array.from(
      { length: nClasses },
      () => Array(nClasses).fill(0)
    );
    for (let i = 0; i < inputs.length; i++) {
      const pred = net.predict(inputs[i]);
      const target = targets[i];
      let predClass;
      let trueClass;
      if (targetLen === 1) {
        trueClass = target[0] >= 0.5 ? 1 : 0;
        if (pred.length === 1) {
          predClass = pred[0] >= 0.5 ? 1 : 0;
        } else {
          predClass = pred.indexOf(Math.max(...pred));
        }
      } else {
        predClass = pred.indexOf(Math.max(...pred));
        trueClass = target.indexOf(Math.max(...target));
      }
      predClass = Math.max(0, Math.min(nClasses - 1, predClass));
      trueClass = Math.max(0, Math.min(nClasses - 1, trueClass));
      confusion[trueClass][predClass]++;
    }
    let totalCorrect = 0;
    let totalSamples = 0;
    const precisions = [];
    const recalls = [];
    for (let c = 0; c < nClasses; c++) {
      const tp = confusion[c][c];
      totalCorrect += tp;
      let colSum = 0;
      let rowSum = 0;
      for (let r = 0; r < nClasses; r++) {
        colSum += confusion[r][c];
        rowSum += confusion[c][r];
      }
      totalSamples += rowSum;
      precisions.push(colSum > 0 ? tp / colSum : 0);
      recalls.push(rowSum > 0 ? tp / rowSum : 0);
    }
    const accuracy = totalSamples > 0 ? totalCorrect / totalSamples : 0;
    const macroPrecision = precisions.reduce((a, b) => a + b, 0) / nClasses;
    const macroRecall = recalls.reduce((a, b) => a + b, 0) / nClasses;
    const f1 = macroPrecision + macroRecall > 0 ? 2 * macroPrecision * macroRecall / (macroPrecision + macroRecall) : 0;
    return {
      accuracy,
      precision: macroPrecision,
      recall: macroRecall,
      f1
    };
  }
};
2521
+
2522
+ // src/DataLoader.ts
2523
var DataLoader = class _DataLoader {
  /**
   * Mini-batch iterator over a paired { inputs, targets } dataset, with an
   * optional validation split that is fixed at construction time.
   *
   * @param {{inputs: any[], targets: any[]}} data
   * @param {number} [batchSize=1] - positive integer; the final batch may be smaller.
   * @param {number} [validationSplit=0] - fraction in [0, 1) held out for validation.
   * @throws {Error} on mismatched lengths, an out-of-range split, or a
   *   batchSize that is not a positive integer.
   */
  constructor(data, batchSize = 1, validationSplit = 0) {
    if (data.inputs.length !== data.targets.length) {
      throw new Error("DataLoader: inputs and targets must have the same length");
    }
    if (validationSplit < 0 || validationSplit >= 1) {
      throw new Error(`DataLoader: validationSplit must be in [0, 1), got ${validationSplit}`);
    }
    // With batchSize < 1, next() never advances the cursor, so a
    // `while (hasNext())` loop would spin forever.
    if (!Number.isInteger(batchSize) || batchSize < 1) {
      throw new Error(`DataLoader: batchSize must be a positive integer, got ${batchSize}`);
    }
    this.data = data;
    this.batchSize = batchSize;
    this._validationSplit = validationSplit;
    const fullIndices = Array.from({ length: data.inputs.length }, (_, i) => i);
    for (let i = fullIndices.length - 1; i > 0; i--) {
      const j = Math.floor(Math.random() * (i + 1));
      [fullIndices[i], fullIndices[j]] = [fullIndices[j], fullIndices[i]];
    }
    if (validationSplit > 0) {
      const valSize = Math.round(data.inputs.length * validationSplit);
      const trainSize = data.inputs.length - valSize;
      this._trainIndices = fullIndices.slice(0, trainSize);
      this._valIndices = fullIndices.slice(trainSize);
    } else {
      this._trainIndices = [...fullIndices];
      this._valIndices = [];
    }
    this._indices = [...this._trainIndices];
    this._pos = 0;
  }
  // ── Shuffle the training data ──────────────────────────────────────────────
  /** Reshuffles the training indices in place and rewinds the cursor. */
  shuffle() {
    for (let i = this._trainIndices.length - 1; i > 0; i--) {
      const j = Math.floor(Math.random() * (i + 1));
      [this._trainIndices[i], this._trainIndices[j]] = [this._trainIndices[j], this._trainIndices[i]];
    }
    this._indices = [...this._trainIndices];
    this._pos = 0;
  }
  // ── Check if more batches are available ───────────────────────────────────
  hasNext() {
    return this._pos < this._indices.length;
  }
  // ── Get next batch ────────────────────────────────────────────────────────
  /** Returns up to batchSize training samples and advances the cursor. */
  next() {
    const end = Math.min(this._pos + this.batchSize, this._indices.length);
    const batchIndices = this._indices.slice(this._pos, end);
    this._pos = end;
    return {
      inputs: batchIndices.map((i) => this.data.inputs[i]),
      targets: batchIndices.map((i) => this.data.targets[i])
    };
  }
  // ── Reset iteration ───────────────────────────────────────────────────────
  reset() {
    this._pos = 0;
  }
  // ── Get total number of training samples ───────────────────────────────────
  get length() {
    return this._trainIndices.length;
  }
  // ── Get validation data as a DataPair ──────────────────────────────────────
  // Returns the validation samples (inputs + targets) in their shuffled order.
  // Returns empty arrays if no validation split was configured.
  getValidationData() {
    return {
      inputs: this._valIndices.map((i) => this.data.inputs[i]),
      targets: this._valIndices.map((i) => this.data.targets[i])
    };
  }
  // ── Get number of validation samples ───────────────────────────────────────
  get validationLength() {
    return this._valIndices.length;
  }
  // ── Create sequence windows from a time series ────────────────────────────
  /**
   * Builds sliding windows: input i is data[i .. i+seqLen) flattened and its
   * target is data[i + seqLen]. The resulting loader has batchSize 1.
   * @param {Array} data - time series (scalars or per-step vectors).
   * @param {number} seqLen - positive integer window length.
   * @param {number} [validationSplit=0]
   * @throws {Error} if seqLen is not a positive integer or data is too short.
   */
  static sequences(data, seqLen, validationSplit = 0) {
    if (!Number.isInteger(seqLen) || seqLen < 1) {
      throw new Error(`DataLoader.sequences: seqLen must be a positive integer, got ${seqLen}`);
    }
    if (data.length < seqLen + 1) {
      throw new Error("DataLoader.sequences: data length must be >= seqLen + 1");
    }
    const inputs = [];
    const targets = [];
    for (let i = 0; i <= data.length - seqLen - 1; i++) {
      inputs.push(data.slice(i, i + seqLen).flat());
      targets.push(data[i + seqLen]);
    }
    return new _DataLoader({ inputs, targets }, 1, validationSplit);
  }
};
2609
+
2610
+ // src/LRScheduler.ts
2611
var LRScheduler = class {
  // ── Step Decay ────────────────────────────────────────────────────────────
  /**
   * lr * dropRate^floor(epoch / epochsDrop).
   * `lr` is the *initial* learning rate. Passing an already-decayed rate each
   * epoch compounds the decay.
   * @throws {Error} if epochsDrop is not > 0 (it would produce NaN / 0 rates).
   */
  stepDecay(lr, epoch, dropRate, epochsDrop) {
    if (!(epochsDrop > 0)) {
      throw new Error(`LRScheduler.stepDecay: epochsDrop must be > 0, got ${epochsDrop}`);
    }
    return lr * Math.pow(dropRate, Math.floor(epoch / epochsDrop));
  }
  // ── Exponential Decay ─────────────────────────────────────────────────────
  /** lr * decayRate^epoch, where `lr` is the initial learning rate. */
  exponentialDecay(lr, epoch, decayRate) {
    return lr * Math.pow(decayRate, epoch);
  }
  // ── Plateau Decay ─────────────────────────────────────────────────────────
  /**
   * Returns lr * factor when `currentLoss` is no better than the best of the
   * last `patience` entries of `history`, otherwise returns lr unchanged.
   * Also returns lr unchanged while history has fewer than `patience` entries.
   *
   * `history` must NOT already contain `currentLoss`. If it did, it would be
   * among the recent losses, the comparison would always hold, and the rate
   * would decay every call.
   *
   * Usage:
   *   const history = []
   *   for (let epoch = 0; epoch < 1000; epoch++) {
   *     const loss = train(...)
   *     lr = scheduler.plateauDecay(lr, loss, history, 10, 0.5)
   *     history.push(loss)
   *   }
   *
   * @throws {Error} if patience < 1; history.slice(-0) would otherwise return the whole history.
   */
  plateauDecay(lr, currentLoss, history, patience, factor) {
    if (!(patience >= 1)) {
      throw new Error(`LRScheduler.plateauDecay: patience must be >= 1, got ${patience}`);
    }
    if (history.length < patience) return lr;
    const recentBest = Math.min(...history.slice(-patience));
    return currentLoss >= recentBest ? lr * factor : lr;
  }
  // ── Cosine Annealing ──────────────────────────────────────────────────────
  /**
   * minLr + 0.5 * (lr - minLr) * (1 + cos(π * epoch / maxEpochs)), with `lr`
   * the starting (maximum) rate.
   * @throws {Error} if maxEpochs is not > 0 (the ratio would be NaN/Infinity).
   */
  cosineAnnealing(lr, epoch, maxEpochs, minLr = 0) {
    if (!(maxEpochs > 0)) {
      throw new Error(`LRScheduler.cosineAnnealing: maxEpochs must be > 0, got ${maxEpochs}`);
    }
    return minLr + 0.5 * (lr - minLr) * (1 + Math.cos(Math.PI * epoch / maxEpochs));
  }
};
2648
+
2649
+ // src/ModelSaver.ts
2650
var ModelSaver = class _ModelSaver {
  // ── Serialize to JSON string ──────────────────────────────────────────────
  /**
   * Serialises a model's flat weights plus a save timestamp (ms since epoch).
   * Note: JSON cannot represent NaN/Infinity; such weights are written as null.
   * @param {{getWeights: () => number[]}} model
   * @returns {string}
   */
  static toJSON(model) {
    return JSON.stringify({
      weights: model.getWeights(),
      timestamp: Date.now()
    });
  }
  // ── Deserialize from JSON string ──────────────────────────────────────────
  /**
   * Parses `json` and loads its weights into `model`.
   * @param {{getWeights: () => number[], setWeights: (w: number[]) => void}} model
   * @param {string} json
   * @throws {SyntaxError} if `json` is not valid JSON.
   * @throws {Error} if the payload has no numeric `weights` array, or if its
   *   length differs from model.getWeights().length. Previously such data
   *   reached setWeights and left undefined/null weights in the model.
   */
  static fromJSON(model, json) {
    const data = JSON.parse(json);
    const weights = data?.weights;
    if (!Array.isArray(weights) || !weights.every((w) => typeof w === "number")) {
      throw new Error("ModelSaver.fromJSON: invalid model data");
    }
    const expected = model.getWeights().length;
    if (weights.length !== expected) {
      throw new Error(`ModelSaver.fromJSON: expected ${expected} weights, got ${weights.length}`);
    }
    model.setWeights(weights);
  }
  // ── Save to file (requires write function) ────────────────────────────────
  /** Serialises `model` and hands (path, json) to the caller-supplied writer. */
  static saveToFile(model, path, writeFn) {
    const json = _ModelSaver.toJSON(model);
    writeFn(path, json);
  }
  // ── Load from file (requires read function) ───────────────────────────────
  /** Reads JSON text via the caller-supplied reader and loads it into `model`. */
  static loadFromFile(model, path, readFn) {
    const json = readFn(path);
    _ModelSaver.fromJSON(model, json);
  }
};
1290
2677
  // Annotate the CommonJS export names for ESM import in node:
1291
2678
  0 && (module.exports = {
1292
2679
  Adam,
1293
2680
  AttentionHead,
2681
+ BatchNorm,
2682
+ BiasVector,
2683
+ ClipOptimizer,
2684
+ ClippedOptimizerFactory,
2685
+ Conv1D,
2686
+ DataLoader,
2687
+ Dropout,
1294
2688
  EmbeddingMatrix,
2689
+ GRULayer,
2690
+ LRScheduler,
1295
2691
  LSTMLayer,
1296
2692
  Layer,
1297
2693
  LayerNorm,
2694
+ ModelSaver,
1298
2695
  Momentum,
1299
2696
  MultiHeadAttention,
1300
2697
  Network,
@@ -1305,11 +2702,13 @@ function crossEntropyDeltaRaw(predicted, actual) {
1305
2702
  Neuron,
1306
2703
  NeuronN,
1307
2704
  SGD,
2705
+ Trainer,
1308
2706
  TransformerBlock,
1309
2707
  WeightMatrix,
1310
2708
  crossEntropy,
1311
2709
  crossEntropyDelta,
1312
2710
  crossEntropyDeltaRaw,
2711
+ defaultOptimizer,
1313
2712
  elu,
1314
2713
  leakyRelu,
1315
2714
  linear,
@@ -1323,5 +2722,9 @@ function crossEntropyDeltaRaw(predicted, actual) {
1323
2722
  softmax,
1324
2723
  softmaxBackward,
1325
2724
  tanh,
1326
- transpose
2725
+ transpose,
2726
+ validate2DArray,
2727
+ validateArray,
2728
+ validateArrayMinLength,
2729
+ validateNumber
1327
2730
  });