@dniskav/neuron 0.2.3 → 0.2.6

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
package/dist/index.mjs CHANGED
@@ -1,3 +1,71 @@
1
+ // src/Validation.ts
2
/**
 * Asserts that `arr` is an array of exactly `expectedLength` finite numbers.
 * @param {unknown} arr - value to check
 * @param {number} expectedLength - required length
 * @param {string} methodName - prefix for the error message
 * @throws {Error} on a non-array, a wrong length, or the first non-finite / non-number element
 */
function validateArray(arr, expectedLength, methodName) {
  if (!Array.isArray(arr)) {
    throw new Error(`${methodName}: expected array, got ${typeof arr}`);
  }
  const actualLength = arr.length;
  if (actualLength !== expectedLength) {
    throw new Error(
      `${methodName}: expected array of length ${expectedLength}, got ${actualLength}`
    );
  }
  // findIndex visits holes as `undefined`, so sparse slots are reported too.
  const badIndex = arr.findIndex(
    (value) => typeof value !== "number" || !Number.isFinite(value)
  );
  if (badIndex !== -1) {
    throw new Error(`${methodName}: invalid value at index ${badIndex}: ${arr[badIndex]}`);
  }
}
19
/**
 * Asserts that `arr` is an array of at least `minLength` finite numbers.
 * @param {unknown} arr - value to check
 * @param {number} minLength - minimum accepted length
 * @param {string} methodName - prefix for the error message
 * @throws {Error} on a non-array, a too-short array, or the first non-finite / non-number element
 */
function validateArrayMinLength(arr, minLength, methodName) {
  if (!Array.isArray(arr)) {
    throw new Error(`${methodName}: expected array, got ${typeof arr}`);
  }
  const tooShort = arr.length < minLength;
  if (tooShort) {
    throw new Error(
      `${methodName}: expected array of at least length ${minLength}, got ${arr.length}`
    );
  }
  for (const [index, value] of arr.entries()) {
    const isValid = typeof value === "number" && Number.isFinite(value);
    if (!isValid) {
      throw new Error(`${methodName}: invalid value at index ${index}: ${value}`);
    }
  }
}
36
/**
 * Asserts that `arr` is an `expectedRows` × `expectedCols` matrix of finite numbers.
 * Rows are checked one at a time (shape, then values) in index order.
 * @param {unknown} arr - value to check
 * @param {number} expectedRows
 * @param {number} expectedCols
 * @param {string} methodName - prefix for the error message
 * @throws {Error} describing the first violation found
 */
function validate2DArray(arr, expectedRows, expectedCols, methodName) {
  if (!Array.isArray(arr)) {
    throw new Error(`${methodName}: expected 2D array, got ${typeof arr}`);
  }
  if (arr.length !== expectedRows) {
    throw new Error(
      `${methodName}: expected ${expectedRows} rows, got ${arr.length}`
    );
  }
  // entries() (unlike forEach) also visits holes, so a missing row is reported.
  for (const [r, row] of arr.entries()) {
    if (!Array.isArray(row)) {
      throw new Error(`${methodName}: row ${r} is not an array`);
    }
    if (row.length !== expectedCols) {
      throw new Error(
        `${methodName}: row ${r} expected ${expectedCols} cols, got ${row.length}`
      );
    }
    for (const [c, cell] of row.entries()) {
      if (!(typeof cell === "number" && Number.isFinite(cell))) {
        throw new Error(`${methodName}: invalid value at [${r}][${c}]: ${cell}`);
      }
    }
  }
}
63
/**
 * Asserts that `value` is a finite number (rejects NaN, ±Infinity and non-numbers).
 * @param {unknown} value
 * @param {string} methodName - prefix for the error message
 * @throws {Error} if the value is not a finite number
 */
function validateNumber(value, methodName) {
  const isFiniteNumber = typeof value === "number" && Number.isFinite(value);
  if (isFiniteNumber) return;
  throw new Error(`${methodName}: expected finite number, got ${value}`);
}
68
+
1
69
  // src/Neuron.ts
2
70
  function sigmoid(x) {
3
71
  return 1 / (1 + Math.exp(-x));
@@ -8,13 +76,18 @@ var Neuron = class {
8
76
  this.bias = Math.random() * 0.1;
9
77
  }
10
78
  predict(input) {
79
+ validateNumber(input, "Neuron.predict");
11
80
  return sigmoid(input * this.weight + this.bias);
12
81
  }
13
82
  train(input, target, lr) {
83
+ validateNumber(input, "Neuron.train");
84
+ validateNumber(target, "Neuron.train");
85
+ validateNumber(lr, "Neuron.train");
14
86
  const prediction = this.predict(input);
15
87
  const error = target - prediction;
16
- this.weight += lr * error * input;
17
- this.bias += lr * error;
88
+ const grad = error * prediction * (1 - prediction);
89
+ this.weight += lr * grad * input;
90
+ this.bias += lr * grad;
18
91
  }
19
92
  };
20
93
 
@@ -54,6 +127,7 @@ function makeElu(alpha = 1) {
54
127
// Default ELU activation instance, built by makeElu with alpha = 1.
var elu = makeElu(1);
55
128
 
56
129
  // src/optimizers.ts
130
// Shared default optimizer factory: a fresh SGD instance per call. Callers such as
// Array.from pass (element, index) arguments, which are ignored. `SGD` is only
// dereferenced when the factory is invoked, so declaring this before the SGD class is safe.
var defaultOptimizer = () => {
  return new SGD();
};
57
131
  var SGD = class {
58
132
  step(weight, gradient, lr) {
59
133
  return weight + lr * gradient;
@@ -69,6 +143,19 @@ var Momentum = class {
69
143
  return weight + this.v;
70
144
  }
71
145
  };
146
var ClipOptimizer = class {
  /**
   * Wraps another optimizer and clamps every gradient to [-clipValue, clipValue]
   * before delegating to it.
   * @param {{ step(weight: number, gradient: number, lr: number): number }} inner
   * @param {number} clipValue - non-negative bound; Infinity disables clipping
   * @throws {Error} if inner has no step() method or clipValue is negative / NaN
   */
  constructor(inner, clipValue) {
    if (typeof inner?.step !== "function") {
      throw new Error("ClipOptimizer: inner optimizer must implement step()");
    }
    // A negative bound inverts the clamp (every gradient becomes -clipValue > 0),
    // and NaN turns every gradient into NaN; the negated test rejects both.
    if (!(clipValue >= 0)) {
      throw new Error(`ClipOptimizer: clipValue must be a non-negative number, got ${clipValue}`);
    }
    this.inner = inner;
    this.clipValue = clipValue;
  }
  /**
   * @param {number} weight - current parameter value
   * @param {number} gradient - raw gradient (clamped before use)
   * @param {number} lr - learning rate, passed through unchanged
   * @returns {number} the updated weight produced by the inner optimizer
   */
  step(weight, gradient, lr) {
    const clipped = Math.max(-this.clipValue, Math.min(this.clipValue, gradient));
    return this.inner.step(weight, clipped, lr);
  }
};
156
/**
 * Turns an optimizer factory into one whose products clip gradients.
 * @param {() => object} innerFactory - builds the optimizer to wrap
 * @param {number} clipValue - gradient bound handed to each ClipOptimizer
 * @returns {() => ClipOptimizer} factory creating a new (inner, wrapper) pair per call
 */
function ClippedOptimizerFactory(innerFactory, clipValue) {
  const build = () => {
    const innerOptimizer = innerFactory();
    return new ClipOptimizer(innerOptimizer, clipValue);
  };
  return build;
}
72
159
  var Adam = class {
73
160
  constructor(beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8) {
74
161
  this.beta1 = beta1;
@@ -89,7 +176,6 @@ var Adam = class {
89
176
  };
90
177
 
91
178
  // src/NeuronN.ts
92
- var defaultOptimizer = () => new SGD();
93
179
  var NeuronN = class {
94
180
  constructor(nInputs, activation = sigmoid2, optimizerFactory = defaultOptimizer) {
95
181
  const limit = Math.sqrt(1 / nInputs);
@@ -99,6 +185,7 @@ var NeuronN = class {
99
185
  this._opts = Array.from({ length: nInputs + 1 }, optimizerFactory);
100
186
  }
101
187
  predict(inputs) {
188
+ validateArray(inputs, this.weights.length, "NeuronN.predict");
102
189
  const sum = inputs.reduce((acc, e, i) => acc + e * this.weights[i], this.bias);
103
190
  return this.activation.fn(sum);
104
191
  }
@@ -111,14 +198,14 @@ var NeuronN = class {
111
198
  train(inputs, target, lr) {
112
199
  const prediction = this.predict(inputs);
113
200
  const error = target - prediction;
114
- this._update(inputs.map((inp) => error * inp), error, lr);
201
+ const grad = error * this.activation.dfn(prediction);
202
+ this._update(inputs.map((inp) => grad * inp), grad, lr);
115
203
  }
116
204
  };
117
205
 
118
206
  // src/Layer.ts
119
- var defaultOptimizer2 = () => new SGD();
120
207
  var Layer = class {
121
- constructor(nNeurons, nInputs, activation = sigmoid2, optimizerFactory = defaultOptimizer2) {
208
+ constructor(nNeurons, nInputs, activation = sigmoid2, optimizerFactory = defaultOptimizer) {
122
209
  this.neurons = Array.from(
123
210
  { length: nNeurons },
124
211
  () => new NeuronN(nInputs, activation, optimizerFactory)
@@ -136,84 +223,233 @@ var Network = class {
136
223
  this.outputLayer = new Layer(nOutputs, nHidden);
137
224
  }
138
225
  predict(inputs) {
226
+ validateArray(inputs, this.hiddenLayer.neurons[0].weights.length, "Network.predict");
139
227
  const hiddenOut = this.hiddenLayer.predict(inputs);
140
- return this.outputLayer.predict(hiddenOut)[0];
228
+ return this.outputLayer.predict(hiddenOut);
141
229
  }
142
230
  // Trains on a single example. Returns the squared error.
143
231
  train(inputs, target, lr) {
232
+ validateArray(inputs, this.hiddenLayer.neurons[0].weights.length, "Network.train");
233
+ validateNumber(target, "Network.train");
234
+ validateNumber(lr, "Network.train");
144
235
  const hiddenOut = this.hiddenLayer.predict(inputs);
145
236
  const prediction = this.outputLayer.predict(hiddenOut)[0];
146
- const outputError = target - prediction;
147
- const outputDelta = outputError * prediction * (1 - prediction);
148
237
  const outputNeuron = this.outputLayer.neurons[0];
149
- outputNeuron.weights = outputNeuron.weights.map(
150
- (w, i) => w + lr * outputDelta * hiddenOut[i]
151
- );
152
- outputNeuron.bias += lr * outputDelta;
153
- this.hiddenLayer.neurons.forEach((neuron, i) => {
154
- const hiddenOut_i = hiddenOut[i];
238
+ const outputError = target - prediction;
239
+ const outputDelta = outputError * outputNeuron.activation.dfn(prediction);
240
+ const hiddenDeltas = this.hiddenLayer.neurons.map((neuron, i) => {
155
241
  const hiddenError = outputDelta * outputNeuron.weights[i];
156
- const hiddenDelta = hiddenError * hiddenOut_i * (1 - hiddenOut_i);
157
- neuron.weights = neuron.weights.map((w, j) => w + lr * hiddenDelta * inputs[j]);
158
- neuron.bias += lr * hiddenDelta;
242
+ return hiddenError * neuron.activation.dfn(hiddenOut[i]);
243
+ });
244
+ this.hiddenLayer.neurons.forEach((neuron, i) => {
245
+ neuron._update(inputs.map((inp) => hiddenDeltas[i] * inp), hiddenDeltas[i], lr);
159
246
  });
247
+ outputNeuron._update(hiddenOut.map((h) => outputDelta * h), outputDelta, lr);
160
248
  return outputError * outputError;
161
249
  }
250
+ // ── Flat weight serialization ─────────────────────────────────────────────
251
+ // Order: hidden layer (all neurons: weights then bias), then output layer.
252
+ getWeights() {
253
+ const w = [];
254
+ for (const n of this.hiddenLayer.neurons) {
255
+ w.push(...n.weights, n.bias);
256
+ }
257
+ for (const n of this.outputLayer.neurons) {
258
+ w.push(...n.weights, n.bias);
259
+ }
260
+ return w;
261
+ }
262
+ setWeights(weights) {
263
+ let idx = 0;
264
+ for (const n of this.hiddenLayer.neurons) {
265
+ for (let j = 0; j < n.weights.length; j++) n.weights[j] = weights[idx++];
266
+ n.bias = weights[idx++];
267
+ }
268
+ for (const n of this.outputLayer.neurons) {
269
+ for (let j = 0; j < n.weights.length; j++) n.weights[j] = weights[idx++];
270
+ n.bias = weights[idx++];
271
+ }
272
+ }
273
+ };
274
+
275
+ // src/Dropout.ts
276
var Dropout = class {
  /**
   * Inverted-dropout layer with no trainable parameters.
   * @param {number} rate - probability of zeroing each activation; must be in [0, 1)
   * @throws {Error} if rate is outside [0, 1), NaN or undefined
   */
  constructor(rate) {
    this._mask = null;
    // Negated range test: NaN/undefined make every comparison false, so the previous
    // `rate < 0 || rate >= 1` check silently accepted them.
    if (!(rate >= 0 && rate < 1)) {
      throw new Error(`Dropout rate must be in [0, 1), got ${rate}`);
    }
    this.rate = rate;
  }
  // ── Forward ───────────────────────────────────────────────────────────────
  // x: number[] → number[]
  // Training: each element is kept with probability (1 - rate) and scaled by
  // 1 / (1 - rate) (inverted dropout), and the mask is stored for backward().
  // Not training, or rate 0: clears the mask and returns a copy of x.
  forward(x, training = true) {
    if (!training || this.rate === 0) {
      this._mask = null;
      return [...x];
    }
    const scale = 1 / (1 - this.rate);
    this._mask = x.map(() => Math.random() > this.rate ? scale : 0);
    return x.map((v, i) => v * this._mask[i]);
  }
  // ── Backward ──────────────────────────────────────────────────────────────
  // dOut: number[] → number[]
  // Applies the stored mask (gradient is zeroed where the activation was dropped).
  // With no mask stored this returns a copy of dOut.
  // Throws if dOut's length differs from the mask's; that would otherwise yield NaN.
  backward(dOut) {
    if (!this._mask) return [...dOut];
    if (dOut.length !== this._mask.length) {
      throw new Error(
        `Dropout.backward: expected gradient of length ${this._mask.length}, got ${dOut.length}`
      );
    }
    return dOut.map((d, i) => d * this._mask[i]);
  }
  // ── Reset mask between forward passes ─────────────────────────────────────
  resetMask() {
    this._mask = null;
  }
  // ── No trainable params ───────────────────────────────────────────────────
  getWeights() {
    return [];
  }
  setWeights(_weights) {
  }
};
163
315
 
164
316
  // src/NetworkN.ts
165
- var defaultOptimizer3 = () => new SGD();
166
317
  var NetworkN = class {
167
318
  constructor(structure, options = {}) {
168
319
  this.structure = structure;
169
320
  const nLayers = structure.length - 1;
170
321
  const activations = options.activations ?? Array.from({ length: nLayers }, () => sigmoid2);
171
- const optimizer = options.optimizer ?? defaultOptimizer3;
322
+ const optimizer = options.optimizer ?? defaultOptimizer;
323
+ const dropoutRate = options.dropoutRate ?? 0;
324
+ if (activations.length !== nLayers) {
325
+ throw new Error(`Expected ${nLayers} activations, got ${activations.length}`);
326
+ }
327
+ if (dropoutRate < 0 || dropoutRate >= 1) {
328
+ throw new Error(`Dropout rate must be in [0, 1), got ${dropoutRate}`);
329
+ }
330
+ this._residual = options.residual ?? false;
172
331
  this.layers = [];
173
332
  for (let i = 1; i < structure.length; i++) {
174
333
  this.layers.push(new Layer(structure[i], structure[i - 1], activations[i - 1], optimizer));
175
334
  }
335
+ this._dropouts = [];
336
+ if (dropoutRate > 0) {
337
+ for (let i = 0; i < nLayers - 1; i++) {
338
+ this._dropouts.push(new Dropout(dropoutRate));
339
+ }
340
+ }
341
+ const outputLayer = this.layers[this.layers.length - 1];
342
+ const outputActivation = outputLayer.neurons[0].activation;
343
+ for (let i = 1; i < outputLayer.neurons.length; i++) {
344
+ if (outputLayer.neurons[i].activation !== outputActivation) {
345
+ throw new Error("All output neurons must share the same activation function");
346
+ }
347
+ }
176
348
  }
177
- predict(inputs) {
178
- return this.layers.reduce((acc, layer) => layer.predict(acc), inputs);
349
+ predict(inputs, training = false) {
350
+ validateArray(inputs, this.structure[0], "NetworkN.predict");
351
+ let current = [...inputs];
352
+ for (let i = 0; i < this.layers.length; i++) {
353
+ const layerInput = [...current];
354
+ const layerOutput = this.layers[i].predict(current);
355
+ if (this._shouldResidual(i)) {
356
+ if (this.structure[i] === this.structure[i + 1]) {
357
+ current = layerOutput.map((v, j) => v + layerInput[j]);
358
+ } else {
359
+ current = [...layerOutput];
360
+ }
361
+ } else {
362
+ current = [...layerOutput];
363
+ }
364
+ if (i < this._dropouts.length) {
365
+ current = this._dropouts[i].forward(current, training);
366
+ }
367
+ }
368
+ return current;
179
369
  }
180
370
  // Generalized backpropagation across L layers.
181
371
  // Returns the mean squared error for the example.
182
372
  train(inputs, targets, lr) {
183
- const act = [inputs];
184
- for (const layer of this.layers) act.push(layer.predict(act[act.length - 1]));
373
+ validateArray(inputs, this.structure[0], "NetworkN.train");
374
+ validateArray(targets, this.structure[this.structure.length - 1], "NetworkN.train");
375
+ const act = this._forwardAll(inputs, true);
185
376
  const pred = act[act.length - 1];
186
377
  const outAct = this.layers[this.layers.length - 1].neurons[0].activation;
187
- let deltas = pred.map((p, i) => (targets[i] - p) * outAct.dfn(p));
188
- for (let l = this.layers.length - 1; l >= 0; l--) {
189
- const layer = this.layers[l];
190
- const layerIn = act[l];
191
- const prevAct = l > 0 ? this.layers[l - 1].neurons[0].activation : null;
192
- const prevDeltas = layerIn.map((out, j) => {
193
- const errProp = layer.neurons.reduce((s, n, k) => s + deltas[k] * n.weights[j], 0);
194
- return prevAct ? errProp * prevAct.dfn(out) : errProp;
195
- });
196
- layer.neurons.forEach((n, k) => {
197
- n._update(layerIn.map((inp) => deltas[k] * inp), deltas[k], lr);
198
- });
199
- deltas = prevDeltas;
200
- }
378
+ const deltas = pred.map((p, i) => (targets[i] - p) * outAct.dfn(p));
379
+ this._backpropLayers(act, deltas, lr);
201
380
  return pred.reduce((s, p, i) => s + (targets[i] - p) ** 2, 0) / pred.length;
202
381
  }
203
382
  // Backprop with externally provided output-layer deltas.
204
383
  // Useful for custom loss functions (e.g. physics-based gradients).
205
384
  trainWithDeltas(inputs, outputDeltas, lr) {
385
+ const act = this._forwardAll(inputs, true);
386
+ this._backpropLayers(act, outputDeltas, lr);
387
+ }
388
+ // ── Flat weight serialization ─────────────────────────────────────────────
389
+ // Order: layer 0 (all neurons), layer 1, ..., layer N.
390
+ getWeights() {
391
+ for (const d of this._dropouts) d.resetMask();
392
+ const w = [];
393
+ for (const layer of this.layers) {
394
+ for (const n of layer.neurons) {
395
+ w.push(...n.weights, n.bias);
396
+ }
397
+ }
398
+ return w;
399
+ }
400
+ setWeights(weights) {
401
+ for (const d of this._dropouts) d.resetMask();
402
+ let idx = 0;
403
+ for (const layer of this.layers) {
404
+ for (const n of layer.neurons) {
405
+ for (let j = 0; j < n.weights.length; j++) n.weights[j] = weights[idx++];
406
+ n.bias = weights[idx++];
407
+ }
408
+ }
409
+ }
410
+ // ── Private helpers ──────────────────────────────────────────────────────
411
+ _shouldResidual(layerIndex) {
412
+ if (typeof this._residual === "function") return this._residual(layerIndex);
413
+ return this._residual;
414
+ }
415
+ // Forward pass storing activations at every layer boundary.
416
+ // Used by train(), trainWithDeltas(), and predict() shares the same logic.
417
+ _forwardAll(inputs, training) {
206
418
  const act = [inputs];
207
- for (const layer of this.layers) act.push(layer.predict(act[act.length - 1]));
419
+ for (let i = 0; i < this.layers.length; i++) {
420
+ const layerInput = act[act.length - 1];
421
+ const layerOutput = this.layers[i].predict(layerInput);
422
+ let current;
423
+ if (this._shouldResidual(i) && this.structure[i] === this.structure[i + 1]) {
424
+ current = layerOutput.map((v, j) => v + layerInput[j]);
425
+ } else {
426
+ current = layerOutput;
427
+ }
428
+ if (i < this._dropouts.length) {
429
+ current = this._dropouts[i].forward(current, training);
430
+ }
431
+ act.push(current);
432
+ }
433
+ return act;
434
+ }
435
+ // Backward pass: updates all layer weights given the pre-computed activations
436
+ // and the initial output-layer deltas.
437
+ _backpropLayers(act, outputDeltas, lr) {
208
438
  let deltas = outputDeltas;
209
439
  for (let l = this.layers.length - 1; l >= 0; l--) {
210
440
  const layer = this.layers[l];
441
+ if (l < this._dropouts.length) {
442
+ deltas = this._dropouts[l].backward(deltas);
443
+ }
211
444
  const layerIn = act[l];
212
445
  const prevAct = l > 0 ? this.layers[l - 1].neurons[0].activation : null;
213
446
  const prevDeltas = layerIn.map((out, j) => {
214
447
  const errProp = layer.neurons.reduce((s, n, k) => s + deltas[k] * n.weights[j], 0);
215
448
  return prevAct ? errProp * prevAct.dfn(out) : errProp;
216
449
  });
450
+ if (this._shouldResidual(l) && this.structure[l] === this.structure[l + 1]) {
451
+ for (let j = 0; j < prevDeltas.length; j++) prevDeltas[j] += deltas[j];
452
+ }
217
453
  layer.neurons.forEach((n, k) => {
218
454
  n._update(layerIn.map((inp) => deltas[k] * inp), deltas[k], lr);
219
455
  });
@@ -234,7 +470,7 @@ var Gate = class {
234
470
  // shape: [hSize]
235
471
  constructor(inputSize, hSize, initBias = 0) {
236
472
  const n = inputSize + hSize;
237
- const limit = Math.sqrt(2 / n);
473
+ const limit = Math.sqrt(2 / (n + hSize));
238
474
  this.W = Array.from(
239
475
  { length: hSize },
240
476
  () => Array.from({ length: n }, () => (Math.random() * 2 - 1) * limit)
@@ -248,8 +484,11 @@ var Gate = class {
248
484
  }
249
485
  };
250
486
  var LSTMLayer = class {
251
- constructor(inputSize, hiddenSize) {
487
+ constructor(inputSize, hiddenSize, optimizerFactory = () => new SGD()) {
252
488
  this._traj = [];
489
+ if (inputSize <= 0 || hiddenSize <= 0) {
490
+ throw new Error(`LSTMLayer: inputSize and hiddenSize must be positive, got ${inputSize} and ${hiddenSize}`);
491
+ }
253
492
  this.inputSize = inputSize;
254
493
  this.hSize = hiddenSize;
255
494
  this.h = new Array(hiddenSize).fill(0);
@@ -258,6 +497,29 @@ var LSTMLayer = class {
258
497
  this.inputGate = new Gate(inputSize, hiddenSize);
259
498
  this.cellGate = new Gate(inputSize, hiddenSize);
260
499
  this.outputGate = new Gate(inputSize, hiddenSize);
500
+ const combSize = inputSize + hiddenSize;
501
+ this._optimizers = {
502
+ forgetW: Array.from(
503
+ { length: hiddenSize },
504
+ () => Array.from({ length: combSize }, () => optimizerFactory())
505
+ ),
506
+ forgetB: Array.from({ length: hiddenSize }, () => optimizerFactory()),
507
+ inputW: Array.from(
508
+ { length: hiddenSize },
509
+ () => Array.from({ length: combSize }, () => optimizerFactory())
510
+ ),
511
+ inputB: Array.from({ length: hiddenSize }, () => optimizerFactory()),
512
+ cellW: Array.from(
513
+ { length: hiddenSize },
514
+ () => Array.from({ length: combSize }, () => optimizerFactory())
515
+ ),
516
+ cellB: Array.from({ length: hiddenSize }, () => optimizerFactory()),
517
+ outputW: Array.from(
518
+ { length: hiddenSize },
519
+ () => Array.from({ length: combSize }, () => optimizerFactory())
520
+ ),
521
+ outputB: Array.from({ length: hiddenSize }, () => optimizerFactory())
522
+ };
261
523
  }
262
524
  // ── Reset state and trajectory (call at episode start) ────────────────────
263
525
  reset() {
@@ -267,6 +529,9 @@ var LSTMLayer = class {
267
529
  }
268
530
  // ── Forward pass ──────────────────────────────────────────────────────────
269
531
  predict(inputs) {
532
+ if (!Array.isArray(inputs) || inputs.length !== this.inputSize) {
533
+ throw new Error(`LSTMLayer.predict: expected array of length ${this.inputSize}, got ${inputs?.length}`);
534
+ }
270
535
  const combined = [...inputs, ...this.h];
271
536
  const c_prev = [...this.c];
272
537
  const zf = this.forgetGate.linear(combined);
@@ -341,15 +606,15 @@ var LSTMLayer = class {
341
606
  const scale = lr / T;
342
607
  for (let k = 0; k < hSize; k++) {
343
608
  for (let j = 0; j < combSize; j++) {
344
- this.forgetGate.W[k][j] += scale * dWf[k][j];
345
- this.inputGate.W[k][j] += scale * dWi[k][j];
346
- this.cellGate.W[k][j] += scale * dWg[k][j];
347
- this.outputGate.W[k][j] += scale * dWo[k][j];
609
+ this.forgetGate.W[k][j] = this._optimizers.forgetW[k][j].step(this.forgetGate.W[k][j], dWf[k][j], scale);
610
+ this.inputGate.W[k][j] = this._optimizers.inputW[k][j].step(this.inputGate.W[k][j], dWi[k][j], scale);
611
+ this.cellGate.W[k][j] = this._optimizers.cellW[k][j].step(this.cellGate.W[k][j], dWg[k][j], scale);
612
+ this.outputGate.W[k][j] = this._optimizers.outputW[k][j].step(this.outputGate.W[k][j], dWo[k][j], scale);
348
613
  }
349
- this.forgetGate.b[k] += scale * dbf[k];
350
- this.inputGate.b[k] += scale * dbi[k];
351
- this.cellGate.b[k] += scale * dbg[k];
352
- this.outputGate.b[k] += scale * dbo[k];
614
+ this.forgetGate.b[k] = this._optimizers.forgetB[k].step(this.forgetGate.b[k], dbf[k], scale);
615
+ this.inputGate.b[k] = this._optimizers.inputB[k].step(this.inputGate.b[k], dbi[k], scale);
616
+ this.cellGate.b[k] = this._optimizers.cellB[k].step(this.cellGate.b[k], dbg[k], scale);
617
+ this.outputGate.b[k] = this._optimizers.outputB[k].step(this.outputGate.b[k], dbo[k], scale);
353
618
  }
354
619
  this._traj = [];
355
620
  }
@@ -372,10 +637,38 @@ var LSTMLayer = class {
372
637
  this.outputGate.W = data.outputGate.W;
373
638
  this.outputGate.b = data.outputGate.b;
374
639
  }
640
+ // ── Flat weight serialization ─────────────────────────────────────────────
641
+ // Order: forgetGate (W, b), inputGate (W, b), cellGate (W, b), outputGate (W, b).
642
+ getWeightsFlat() {
643
+ const w = [];
644
+ for (const row of this.forgetGate.W) w.push(...row);
645
+ w.push(...this.forgetGate.b);
646
+ for (const row of this.inputGate.W) w.push(...row);
647
+ w.push(...this.inputGate.b);
648
+ for (const row of this.cellGate.W) w.push(...row);
649
+ w.push(...this.cellGate.b);
650
+ for (const row of this.outputGate.W) w.push(...row);
651
+ w.push(...this.outputGate.b);
652
+ return w;
653
+ }
654
+ setWeightsFlat(weights) {
655
+ let idx = 0;
656
+ for (let i = 0; i < this.forgetGate.W.length; i++)
657
+ for (let j = 0; j < this.forgetGate.W[i].length; j++) this.forgetGate.W[i][j] = weights[idx++];
658
+ for (let i = 0; i < this.forgetGate.b.length; i++) this.forgetGate.b[i] = weights[idx++];
659
+ for (let i = 0; i < this.inputGate.W.length; i++)
660
+ for (let j = 0; j < this.inputGate.W[i].length; j++) this.inputGate.W[i][j] = weights[idx++];
661
+ for (let i = 0; i < this.inputGate.b.length; i++) this.inputGate.b[i] = weights[idx++];
662
+ for (let i = 0; i < this.cellGate.W.length; i++)
663
+ for (let j = 0; j < this.cellGate.W[i].length; j++) this.cellGate.W[i][j] = weights[idx++];
664
+ for (let i = 0; i < this.cellGate.b.length; i++) this.cellGate.b[i] = weights[idx++];
665
+ for (let i = 0; i < this.outputGate.W.length; i++)
666
+ for (let j = 0; j < this.outputGate.W[i].length; j++) this.outputGate.W[i][j] = weights[idx++];
667
+ for (let i = 0; i < this.outputGate.b.length; i++) this.outputGate.b[i] = weights[idx++];
668
+ }
375
669
  };
376
670
 
377
671
  // src/NetworkLSTM.ts
378
- var defaultOptimizer4 = () => new SGD();
379
672
  var NetworkLSTM = class {
380
673
  // [T][layer+1][neuron]
381
674
  constructor(inputSize, hiddenSize, denseStructure, options = {}) {
@@ -383,7 +676,7 @@ var NetworkLSTM = class {
383
676
  this.hiddenSize = hiddenSize;
384
677
  this.lstm = new LSTMLayer(inputSize, hiddenSize);
385
678
  const activation = options.denseActivation ?? sigmoid2;
386
- const optimizer = options.optimizer ?? defaultOptimizer4;
679
+ const optimizer = options.optimizer ?? defaultOptimizer;
387
680
  this.denseLayers = [];
388
681
  const sizes = [hiddenSize, ...denseStructure];
389
682
  for (let i = 1; i < sizes.length; i++) {
@@ -398,6 +691,7 @@ var NetworkLSTM = class {
398
691
  }
399
692
  // ── Forward pass ──────────────────────────────────────────────────────────
400
693
  predict(inputs) {
694
+ validateArray(inputs, this.inputSize, "NetworkLSTM.predict");
401
695
  const h = this.lstm.predict(inputs);
402
696
  const acts = [h];
403
697
  for (const layer of this.denseLayers) {
@@ -473,6 +767,30 @@ var NetworkLSTM = class {
473
767
  });
474
768
  });
475
769
  }
770
+ // ── Flat weight serialization ─────────────────────────────────────────────
771
+ // Order: LSTM (flat), then dense layer 0, dense layer 1, ..., dense layer N.
772
+ getWeightsFlat() {
773
+ const w = [];
774
+ w.push(...this.lstm.getWeightsFlat());
775
+ for (const layer of this.denseLayers) {
776
+ for (const n of layer.neurons) {
777
+ w.push(...n.weights, n.bias);
778
+ }
779
+ }
780
+ return w;
781
+ }
782
+ setWeightsFlat(weights) {
783
+ let idx = 0;
784
+ const lstmLen = this.lstm.getWeightsFlat().length;
785
+ this.lstm.setWeightsFlat(weights.slice(idx, idx + lstmLen));
786
+ idx += lstmLen;
787
+ for (const layer of this.denseLayers) {
788
+ for (const n of layer.neurons) {
789
+ for (let j = 0; j < n.weights.length; j++) n.weights[j] = weights[idx++];
790
+ n.bias = weights[idx++];
791
+ }
792
+ }
793
+ }
476
794
  };
477
795
 
478
796
  // src/MatMul.ts
@@ -480,6 +798,9 @@ function matMul(A, B) {
480
798
  const rows = A.length;
481
799
  const inner = B.length;
482
800
  const cols = B[0].length;
801
+ if (A[0].length !== B.length) {
802
+ throw new Error(`Incompatible dimensions for matrix multiplication: A cols (${A[0].length}) !== B rows (${B.length})`);
803
+ }
483
804
  const C = Array.from({ length: rows }, () => new Array(cols).fill(0));
484
805
  for (let i = 0; i < rows; i++)
485
806
  for (let k = 0; k < inner; k++) {
@@ -530,6 +851,33 @@ var WeightMatrix = class {
530
851
  this.W[i][j] = this.opts[i][j].step(this.W[i][j], g, lr);
531
852
  }
532
853
  }
854
+ // ── Flat weight serialization ─────────────────────────────────────────────
855
+ getWeights() {
856
+ const w = [];
857
+ for (const row of this.W) w.push(...row);
858
+ return w;
859
+ }
860
+ setWeights(weights) {
861
+ let idx = 0;
862
+ for (let i = 0; i < this.W.length; i++)
863
+ for (let j = 0; j < this.W[i].length; j++) this.W[i][j] = weights[idx++];
864
+ }
865
+ };
866
var BiasVector = class {
  /**
   * Trainable bias vector with one Adam optimizer per entry.
   * @param {number} size - number of entries; all start at 0
   */
  constructor(size) {
    this.values = new Array(size).fill(0);
    this.opts = Array.from({ length: size }, () => new Adam());
  }
  /**
   * Applies one optimizer step per entry.
   * @param {number[]} grad - one gradient per entry
   * @param {number} lr - learning rate
   */
  update(grad, lr) {
    for (let i = 0; i < this.values.length; i++)
      this.values[i] = this.opts[i].step(this.values[i], grad[i], lr);
  }
  /** @returns {number[]} a copy of the current values */
  getWeights() {
    return [...this.values];
  }
  /**
   * Overwrites the values in place.
   * @param {number[]} weights - exactly `values.length` finite numbers
   * @throws {Error} on a length mismatch or non-finite value. Checked before writing;
   *   a short array previously wrote `undefined` entries.
   */
  setWeights(weights) {
    validateArray(weights, this.values.length, "BiasVector.setWeights");
    for (let i = 0; i < this.values.length; i++) this.values[i] = weights[i];
  }
};
534
882
  var EmbeddingMatrix = class {
535
883
  constructor(vocabSize, d_model) {
@@ -546,15 +894,29 @@ var EmbeddingMatrix = class {
546
894
  for (let m = 0; m < this.W[idx].length; m++)
547
895
  this.W[idx][m] += lr * grad[m];
548
896
  }
897
+ // ── Serializable interface ─────────────────────────────────────────────────
898
+ // Flattened order: row 0, row 1, ... row (vocabSize-1)
899
+ getWeights() {
900
+ const w = [];
901
+ for (const row of this.W) w.push(...row);
902
+ return w;
903
+ }
904
+ setWeights(weights) {
905
+ let idx = 0;
906
+ for (let i = 0; i < this.W.length; i++)
907
+ for (let j = 0; j < this.W[i].length; j++)
908
+ this.W[i][j] = weights[idx++];
909
+ }
549
910
  };
550
911
 
551
912
  // src/AttentionHead.ts
552
913
  var AttentionHead = class {
553
- constructor(d_model, d_k, d_v) {
914
+ constructor(d_model, d_k, d_v, causal = false) {
554
915
  // d_v × d_model
555
916
  this.cache = null;
556
917
  this.d_k = d_k;
557
918
  this.d_v = d_v;
919
+ this.causal = causal;
558
920
  this.Wq = new WeightMatrix(d_k, d_model);
559
921
  this.Wk = new WeightMatrix(d_k, d_model);
560
922
  this.Wv = new WeightMatrix(d_v, d_model);
@@ -575,10 +937,10 @@ var AttentionHead = class {
575
937
  );
576
938
  const scores = Array.from(
577
939
  { length: seqLen },
578
- (_, i) => Array.from(
579
- { length: seqLen },
580
- (_2, j) => Q[i].reduce((s, q, k) => s + q * K[j][k], 0) * scale
581
- )
940
+ (_, i) => Array.from({ length: seqLen }, (_2, j) => {
941
+ if (this.causal && j > i) return -Infinity;
942
+ return Q[i].reduce((s, q, k) => s + q * K[j][k], 0) * scale;
943
+ })
582
944
  );
583
945
  const attn = scores.map((row) => softmax(row));
584
946
  const out = Array.from(
@@ -602,6 +964,7 @@ var AttentionHead = class {
602
964
  // 5. dWq = dQ^T @ X, dWk = dK^T @ X, dWv = dV^T @ X
603
965
  // 6. dX = dQ @ Wq + dK @ Wk + dV @ Wv
604
966
  backward(dOut, lr) {
967
+ if (!this.cache) throw new Error("AttentionHead.backward() called before predict()");
605
968
  const { X, Q, K, V, attn } = this.cache;
606
969
  const seqLen = X.length;
607
970
  const d_model = X[0].length;
@@ -674,21 +1037,40 @@ var AttentionHead = class {
674
1037
  getAttentionWeights() {
675
1038
  return this.cache ? this.cache.attn : null;
676
1039
  }
1040
+ // ── Flat weight serialization ─────────────────────────────────────────────
1041
+ // Order: Wq, Wk, Wv.
1042
+ getWeights() {
1043
+ const w = [];
1044
+ for (const row of this.Wq.W) w.push(...row);
1045
+ for (const row of this.Wk.W) w.push(...row);
1046
+ for (const row of this.Wv.W) w.push(...row);
1047
+ return w;
1048
+ }
1049
+ setWeights(weights) {
1050
+ let idx = 0;
1051
+ for (let i = 0; i < this.Wq.W.length; i++)
1052
+ for (let j = 0; j < this.Wq.W[i].length; j++) this.Wq.W[i][j] = weights[idx++];
1053
+ for (let i = 0; i < this.Wk.W.length; i++)
1054
+ for (let j = 0; j < this.Wk.W[i].length; j++) this.Wk.W[i][j] = weights[idx++];
1055
+ for (let i = 0; i < this.Wv.W.length; i++)
1056
+ for (let j = 0; j < this.Wv.W[i].length; j++) this.Wv.W[i][j] = weights[idx++];
1057
+ }
677
1058
  };
678
1059
 
679
1060
  // src/MultiHeadAttention.ts
680
1061
  var MultiHeadAttention = class {
681
1062
  // seqLen × (nHeads * d_k)
682
- constructor(d_model, nHeads) {
1063
+ constructor(d_model, nHeads, causal = false) {
683
1064
  // d_model × (nHeads * d_k)
684
1065
  // Cached for backward
685
1066
  this._concat = null;
686
1067
  this.nHeads = nHeads;
687
1068
  this.d_model = d_model;
688
1069
  this.d_k = Math.floor(d_model / nHeads);
1070
+ this.causal = causal;
689
1071
  this.heads = Array.from(
690
1072
  { length: nHeads },
691
- () => new AttentionHead(d_model, this.d_k, this.d_k)
1073
+ () => new AttentionHead(d_model, this.d_k, this.d_k, causal)
692
1074
  );
693
1075
  this.Wo = new WeightMatrix(d_model, nHeads * this.d_k);
694
1076
  }
@@ -710,6 +1092,7 @@ var MultiHeadAttention = class {
710
1092
  // ── Backward ──────────────────────────────────────────────────────────────
711
1093
  // dOut: seqLen × d_model → dX: seqLen × d_model
712
1094
  backward(dOut, lr) {
1095
+ if (!this._concat) throw new Error("MultiHeadAttention.backward() called before predict()");
713
1096
  const seqLen = dOut.length;
714
1097
  const concatD = this.nHeads * this.d_k;
715
1098
  const d_model = this.d_model;
@@ -747,6 +1130,31 @@ var MultiHeadAttention = class {
747
1130
  getAttentionWeights() {
748
1131
  return this.heads.map((h) => h.getAttentionWeights());
749
1132
  }
1133
+ // ── Flat weight serialization ─────────────────────────────────────────────
1134
+ // Order: head0 (Wq, Wk, Wv), head1, ..., headN, then Wo.
1135
+ getWeights() {
1136
+ const w = [];
1137
+ for (const head of this.heads) {
1138
+ for (const row of head.Wq.W) w.push(...row);
1139
+ for (const row of head.Wk.W) w.push(...row);
1140
+ for (const row of head.Wv.W) w.push(...row);
1141
+ }
1142
+ for (const row of this.Wo.W) w.push(...row);
1143
+ return w;
1144
+ }
1145
+ setWeights(weights) {
1146
+ let idx = 0;
1147
+ for (const head of this.heads) {
1148
+ for (let i = 0; i < head.Wq.W.length; i++)
1149
+ for (let j = 0; j < head.Wq.W[i].length; j++) head.Wq.W[i][j] = weights[idx++];
1150
+ for (let i = 0; i < head.Wk.W.length; i++)
1151
+ for (let j = 0; j < head.Wk.W[i].length; j++) head.Wk.W[i][j] = weights[idx++];
1152
+ for (let i = 0; i < head.Wv.W.length; i++)
1153
+ for (let j = 0; j < head.Wv.W[i].length; j++) head.Wv.W[i][j] = weights[idx++];
1154
+ }
1155
+ for (let i = 0; i < this.Wo.W.length; i++)
1156
+ for (let j = 0; j < this.Wo.W[i].length; j++) this.Wo.W[i][j] = weights[idx++];
1157
+ }
750
1158
  };
751
1159
 
752
1160
  // src/LayerNorm.ts
@@ -789,20 +1197,32 @@ var LayerNorm = class {
789
1197
  // Backward pass for the single cached position `pos`.
  // Updates gamma/beta in place (gamma += lr * dOut * x_norm, beta += lr * dOut)
  // and returns the gradient with respect to this position's input:
  //   D = dOut * gamma (pre-update values)
  //   dX = (D - mean(D) - x_norm * mean(D * x_norm)) / std
  // `std` comes from the forward cache and is used here as a scalar divisor.
  backwardOne(dOut, pos, lr) {
790
1198
  const { x_norm, std } = this._cache[pos];
791
1199
  const N = dOut.length;
1200
  // Snapshot gamma before the in-place update below so D uses the gamma that
  // was active during the forward pass.
+ const gammaOld = this.gamma.slice();
792
1201
  for (let i = 0; i < N; i++) {
793
1202
  this.gamma[i] += lr * dOut[i] * x_norm[i];
794
1203
  this.beta[i] += lr * dOut[i];
795
1204
  }
796
- const D = dOut.map((d, i) => d * this.gamma[i]);
1205
+ const D = dOut.map((d, i) => d * gammaOld[i]);
797
1206
  const mD = D.reduce((s, v) => s + v, 0) / N;
798
1207
  const mDxn = D.reduce((s, d, i) => s + d * x_norm[i], 0) / N;
799
1208
  return D.map((d, i) => (d - mD - x_norm[i] * mDxn) / std);
800
1209
  }
1210
+ // ── Flat weight serialization ─────────────────────────────────────────────
1211
+ // Order: gamma, beta.
1212
+ getWeights() {
1213
+ return [...this.gamma, ...this.beta];
1214
+ }
1215
+ setWeights(weights) {
1216
+ const dim = this.gamma.length;
1217
+ for (let i = 0; i < dim; i++) this.gamma[i] = weights[i];
1218
+ for (let i = 0; i < dim; i++) this.beta[i] = weights[dim + i];
1219
+ }
801
1220
  };
802
1221
 
803
1222
  // src/TransformerBlock.ts
804
1223
  var TransformerBlock = class {
805
- constructor({ d_model, nHeads, d_ff }) {
1224
+ constructor({ d_model, nHeads, d_ff, causal = false }) {
1225
+ // d_model
806
1226
  // Forward caches (needed for backprop)
807
1227
  this._X = null;
808
1228
  this._attnOut = null;
@@ -814,15 +1234,13 @@ var TransformerBlock = class {
814
1234
  this._ff2Out = null;
815
1235
  this.d_model = d_model;
816
1236
  this.d_ff = d_ff;
817
- this.attn = new MultiHeadAttention(d_model, nHeads);
1237
+ this.attn = new MultiHeadAttention(d_model, nHeads, causal);
818
1238
  this.norm1 = new LayerNorm(d_model);
819
1239
  this.norm2 = new LayerNorm(d_model);
820
1240
  this.ff1 = new WeightMatrix(d_ff, d_model);
821
1241
  this.ff2 = new WeightMatrix(d_model, d_ff);
822
- this.b1 = new Array(d_ff).fill(0);
823
- this.b2 = new Array(d_model).fill(0);
824
- this.b1Opts = Array.from({ length: d_ff }, () => new Adam());
825
- this.b2Opts = Array.from({ length: d_model }, () => new Adam());
1242
+ this.b1 = new BiasVector(d_ff);
1243
+ this.b2 = new BiasVector(d_model);
826
1244
  }
827
1245
  // ── Forward ───────────────────────────────────────────────────────────────
828
1246
  // X: seqLen × d_model → out: seqLen × d_model
@@ -835,11 +1253,11 @@ var TransformerBlock = class {
835
1253
  return this.norm1.predictOne(added, i);
836
1254
  });
837
1255
  const ff1Pre = h1.map(
838
- (h) => this.ff1.W.map((row, k) => row.reduce((s, w, m) => s + w * h[m], this.b1[k]))
1256
+ (h) => this.ff1.W.map((row, k) => row.reduce((s, w, m) => s + w * h[m], this.b1.values[k]))
839
1257
  );
840
1258
  const ff1Out = ff1Pre.map((pre) => pre.map((v) => Math.max(0, v)));
841
1259
  const ff2Out = ff1Out.map(
842
- (h) => this.ff2.W.map((row, k) => row.reduce((s, w, m) => s + w * h[m], this.b2[k]))
1260
+ (h) => this.ff2.W.map((row, k) => row.reduce((s, w, m) => s + w * h[m], this.b2.values[k]))
843
1261
  );
844
1262
  this.norm2.resetCache(seqLen);
845
1263
  const out = h1.map((h, i) => {
@@ -857,6 +1275,9 @@ var TransformerBlock = class {
857
1275
  // ── Backward ──────────────────────────────────────────────────────────────
858
1276
  // dOut: seqLen × d_model → dX: seqLen × d_model
859
1277
  backward(dOut, lr) {
1278
+ if (!this._h1 || !this._ff1Out || !this._ff1Pre) {
1279
+ throw new Error("TransformerBlock.backward() called before predict()");
1280
+ }
860
1281
  const seqLen = dOut.length;
861
1282
  const d_model = this.d_model;
862
1283
  const h1 = this._h1;
@@ -881,8 +1302,7 @@ var TransformerBlock = class {
881
1302
  (_, m) => dAdded2.reduce((s, da) => s + da[m], 0)
882
1303
  );
883
1304
  this.ff2.update(dW2, lr);
884
- for (let m = 0; m < d_model; m++)
885
- this.b2[m] = this.b2Opts[m].step(this.b2[m], db2[m], lr);
1305
+ this.b2.update(db2, lr);
886
1306
  const dFf1Pre = dFf1Out.map(
887
1307
  (d, i) => d.map((v, k) => ff1Pre[i][k] > 0 ? v : 0)
888
1308
  );
@@ -904,8 +1324,7 @@ var TransformerBlock = class {
904
1324
  (_, k) => dFf1Pre.reduce((s, dp) => s + dp[k], 0)
905
1325
  );
906
1326
  this.ff1.update(dW1, lr);
907
- for (let k = 0; k < this.d_ff; k++)
908
- this.b1[k] = this.b1Opts[k].step(this.b1[k], db1[k], lr);
1327
+ this.b1.update(db1, lr);
909
1328
  const dH1 = Array.from(
910
1329
  { length: seqLen },
911
1330
  (_, i) => dH1_fromFf[i].map((v, m) => v + dAdded2[i][m])
@@ -927,6 +1346,36 @@ var TransformerBlock = class {
927
1346
  getAttentionWeights() {
928
1347
  return this.attn.getAttentionWeights();
929
1348
  }
1349
+ // ── Flat weight serialization ─────────────────────────────────────────────
1350
+ // Order: attn (MHA), norm1 (gamma, beta), ff1, b1, ff2, b2, norm2 (gamma, beta).
1351
+ getWeights() {
1352
+ const w = [];
1353
+ w.push(...this.attn.getWeights());
1354
+ w.push(...this.norm1.gamma, ...this.norm1.beta);
1355
+ for (const row of this.ff1.W) w.push(...row);
1356
+ w.push(...this.b1.values);
1357
+ for (const row of this.ff2.W) w.push(...row);
1358
+ w.push(...this.b2.values);
1359
+ w.push(...this.norm2.gamma, ...this.norm2.beta);
1360
+ return w;
1361
+ }
1362
+ setWeights(weights) {
1363
+ let idx = 0;
1364
+ const attnLen = this.attn.getWeights().length;
1365
+ this.attn.setWeights(weights.slice(idx, idx + attnLen));
1366
+ idx += attnLen;
1367
+ this.norm1.setWeights(weights.slice(idx, idx + this.norm1.getWeights().length));
1368
+ idx += this.norm1.getWeights().length;
1369
+ this.ff1.setWeights(weights.slice(idx, idx + this.ff1.getWeights().length));
1370
+ idx += this.ff1.getWeights().length;
1371
+ this.b1.setWeights(weights.slice(idx, idx + this.b1.values.length));
1372
+ idx += this.b1.values.length;
1373
+ this.ff2.setWeights(weights.slice(idx, idx + this.ff2.getWeights().length));
1374
+ idx += this.ff2.getWeights().length;
1375
+ this.b2.setWeights(weights.slice(idx, idx + this.b2.values.length));
1376
+ idx += this.b2.values.length;
1377
+ this.norm2.setWeights(weights.slice(idx, idx + this.norm2.getWeights().length));
1378
+ }
930
1379
  };
931
1380
 
932
1381
  // src/NetworkTransformer.ts
@@ -951,8 +1400,7 @@ var NetworkTransformer = class {
951
1400
  () => new TransformerBlock({ d_model, nHeads, d_ff })
952
1401
  );
953
1402
  this.outputProj = new WeightMatrix(nClasses, d_model);
954
- this.outputBias = new Array(nClasses).fill(0);
955
- this.outBiasOpts = Array.from({ length: nClasses }, () => new Adam());
1403
+ this.outputBias = new BiasVector(nClasses);
956
1404
  }
957
1405
  // ── Forward pass ──────────────────────────────────────────────────────────
958
1406
  // tokens: seqLen integer ids → seqLen * nClasses logits (flattened)
@@ -960,7 +1408,7 @@ var NetworkTransformer = class {
960
1408
  const h = this._forward(tokens);
961
1409
  return h.flatMap(
962
1410
  (hi) => this.outputProj.W.map(
963
- (row, c) => row.reduce((s, w, m) => s + w * hi[m], this.outputBias[c])
1411
+ (row, c) => row.reduce((s, w, m) => s + w * hi[m], this.outputBias.values[c])
964
1412
  )
965
1413
  );
966
1414
  }
@@ -974,7 +1422,7 @@ var NetworkTransformer = class {
974
1422
  const h = this._forward(tokens);
975
1423
  const logits = h.map(
976
1424
  (hi) => this.outputProj.W.map(
977
- (row, c) => row.reduce((s, w, m) => s + w * hi[m], this.outputBias[c])
1425
+ (row, c) => row.reduce((s, w, m) => s + w * hi[m], this.outputBias.values[c])
978
1426
  )
979
1427
  );
980
1428
  let loss = 0;
@@ -1009,8 +1457,7 @@ var NetworkTransformer = class {
1009
1457
  (_, c) => dLogits.reduce((s, dl) => s + dl[c], 0)
1010
1458
  );
1011
1459
  this.outputProj.update(dWout, lr);
1012
- for (let c = 0; c < this.nClasses; c++)
1013
- this.outputBias[c] = this.outBiasOpts[c].step(this.outputBias[c], dBout[c], lr);
1460
+ this.outputBias.update(dBout, lr);
1014
1461
  let dX = dH;
1015
1462
  for (let b = this.blocks.length - 1; b >= 0; b--)
1016
1463
  dX = this.blocks[b].backward(dX, lr);
@@ -1025,6 +1472,35 @@ var NetworkTransformer = class {
1025
1472
  getAttentionWeights() {
1026
1473
  return this.blocks.map((b) => b.getAttentionWeights());
1027
1474
  }
1475
+ // ── Flat weight serialization ─────────────────────────────────────────────
1476
+ // Order: tokenEmb, posEmb, block0, block1, ..., blockN, outputProj, outputBias.
1477
+ getWeights() {
1478
+ const w = [];
1479
+ w.push(...this.tokenEmb.getWeights());
1480
+ w.push(...this.posEmb.getWeights());
1481
+ for (const block of this.blocks) w.push(...block.getWeights());
1482
+ w.push(...this.outputProj.getWeights());
1483
+ w.push(...this.outputBias.getWeights());
1484
+ return w;
1485
+ }
1486
+ setWeights(weights) {
1487
+ let idx = 0;
1488
+ const tokenEmbLen = this.tokenEmb.getWeights().length;
1489
+ this.tokenEmb.setWeights(weights.slice(idx, idx + tokenEmbLen));
1490
+ idx += tokenEmbLen;
1491
+ const posEmbLen = this.posEmb.getWeights().length;
1492
+ this.posEmb.setWeights(weights.slice(idx, idx + posEmbLen));
1493
+ idx += posEmbLen;
1494
+ for (const block of this.blocks) {
1495
+ const blockLen = block.getWeights().length;
1496
+ block.setWeights(weights.slice(idx, idx + blockLen));
1497
+ idx += blockLen;
1498
+ }
1499
+ const outProjLen = this.outputProj.getWeights().length;
1500
+ this.outputProj.setWeights(weights.slice(idx, idx + outProjLen));
1501
+ idx += outProjLen;
1502
+ this.outputBias.setWeights(weights.slice(idx, idx + this.outputBias.values.length));
1503
+ }
1028
1504
  // ── Internal ──────────────────────────────────────────────────────────────
1029
1505
  // Shared embedding + block forward pass.
1030
1506
  _forward(tokens) {
@@ -1044,25 +1520,28 @@ var NetworkTransformerRL = class {
1044
1520
  constructor(seqLen, inputDim, options = {}) {
1045
1521
  // Forward caches para backprop
1046
1522
  this._projected = null;
1523
+ // For max pooling backward: argmax per dimension across all positions
1524
+ this._argmax = null;
1047
1525
  const {
1048
1526
  d_model = 32,
1049
1527
  nHeads = 2,
1050
1528
  d_ff = 64,
1051
1529
  nBlocks = 2,
1052
- nActions = 2
1530
+ nActions = 2,
1531
+ pooling = "weighted"
1053
1532
  } = options;
1054
1533
  this.seqLen = seqLen;
1055
1534
  this.inputDim = inputDim;
1056
1535
  this.d_model = d_model;
1057
1536
  this.nActions = nActions;
1537
+ this._pooling = pooling;
1058
1538
  this.inputProj = new WeightMatrix(d_model, inputDim);
1059
1539
  this.blocks = Array.from(
1060
1540
  { length: nBlocks },
1061
- () => new TransformerBlock({ d_model, nHeads, d_ff })
1541
+ () => new TransformerBlock({ d_model, nHeads, d_ff, causal: true })
1062
1542
  );
1063
1543
  this.outputProj = new WeightMatrix(nActions, d_model);
1064
- this.outputBias = new Array(nActions).fill(0);
1065
- this.outBiasOpts = Array.from({ length: nActions }, () => new Adam());
1544
+ this.outputBias = new BiasVector(nActions);
1066
1545
  }
1067
1546
  // ── Forward ────────────────────────────────────────────────────────────────
1068
1547
  // sequence: seqLen × inputDim → nActions Q-values
@@ -1070,7 +1549,7 @@ var NetworkTransformerRL = class {
1070
1549
  const h = this._forward(sequence);
1071
1550
  const pooled = this._pool(h);
1072
1551
  return this.outputProj.W.map(
1073
- (row, c) => row.reduce((s, w, m) => s + w * pooled[m], this.outputBias[c])
1552
+ (row, c) => row.reduce((s, w, m) => s + w * pooled[m], this.outputBias.values[c])
1074
1553
  );
1075
1554
  }
1076
1555
  // ── Training ────────────────────────────────────────────────────────────────
@@ -1082,7 +1561,7 @@ var NetworkTransformerRL = class {
1082
1561
  const h = this._forward(sequence);
1083
1562
  const pooled = this._pool(h);
1084
1563
  const pred = this.outputProj.W.map(
1085
- (row, c) => row.reduce((s, w, m) => s + w * pooled[m], this.outputBias[c])
1564
+ (row, c) => row.reduce((s, w, m) => s + w * pooled[m], this.outputBias.values[c])
1086
1565
  );
1087
1566
  const n = this.nActions;
1088
1567
  let loss = 0;
@@ -1105,13 +1584,8 @@ var NetworkTransformerRL = class {
1105
1584
  );
1106
1585
  const dBout = dPred.slice();
1107
1586
  this.outputProj.update(dWout, lr);
1108
- for (let c = 0; c < this.nActions; c++)
1109
- this.outputBias[c] = this.outBiasOpts[c].step(this.outputBias[c], dBout[c], lr);
1110
- let dH = Array.from(
1111
- { length: this.seqLen },
1112
- (_, i) => dPooled.map((v) => v / this.seqLen)
1113
- // Gradiente dividido entre posiciones
1114
- );
1587
+ this.outputBias.update(dBout, lr);
1588
+ let dH = this._distributePoolGradient(dPooled);
1115
1589
  for (let b = this.blocks.length - 1; b >= 0; b--)
1116
1590
  dH = this.blocks[b].backward(dH, lr);
1117
1591
  for (let i = 0; i < this.seqLen; i++) {
@@ -1130,8 +1604,32 @@ var NetworkTransformerRL = class {
1130
1604
  getAttentionWeights() {
1131
1605
  return this.blocks.map((b) => b.getAttentionWeights());
1132
1606
  }
1133
- // ── Serialization ──────────────────────────────────────────────────────────
1134
- getWeights() {
1607
+ // ── Flat weight serialization ─────────────────────────────────────────────
1608
+ // Order: inputProj, block0, block1, ..., blockN, outputProj, outputBias.
1609
+ getWeightsFlat() {
1610
+ const w = [];
1611
+ w.push(...this.inputProj.getWeights());
1612
+ for (const block of this.blocks) w.push(...block.getWeights());
1613
+ w.push(...this.outputProj.getWeights());
1614
+ w.push(...this.outputBias.getWeights());
1615
+ return w;
1616
+ }
1617
+ setWeightsFlat(weights) {
1618
+ let idx = 0;
1619
+ const inputProjLen = this.inputProj.getWeights().length;
1620
+ this.inputProj.setWeights(weights.slice(idx, idx + inputProjLen));
1621
+ idx += inputProjLen;
1622
+ for (const block of this.blocks) {
1623
+ const blockLen = block.getWeights().length;
1624
+ block.setWeights(weights.slice(idx, idx + blockLen));
1625
+ idx += blockLen;
1626
+ }
1627
+ const outProjLen = this.outputProj.getWeights().length;
1628
+ this.outputProj.setWeights(weights.slice(idx, idx + outProjLen));
1629
+ idx += outProjLen;
1630
+ this.outputBias.setWeights(weights.slice(idx, idx + this.outputBias.values.length));
1631
+ }
1632
+ getWeightsStructured() {
1135
1633
  return {
1136
1634
  inputProj: this.inputProj.W.map((r) => [...r]),
1137
1635
  blocks: this.blocks.map((b) => ({
@@ -1147,17 +1645,15 @@ var NetworkTransformerRL = class {
1147
1645
  norm2: { gamma: [...b.norm2.gamma], beta: [...b.norm2.beta] },
1148
1646
  ff1: b.ff1.W.map((r) => [...r]),
1149
1647
  ff2: b.ff2.W.map((r) => [...r]),
1150
- b1: [...b.b1],
1151
- b2: [...b.b2]
1648
+ b1: [...b.b1.values],
1649
+ b2: [...b.b2.values]
1152
1650
  })),
1153
1651
  outputProj: this.outputProj.W.map((r) => [...r]),
1154
- outputBias: [...this.outputBias]
1652
+ outputBias: [...this.outputBias.values]
1155
1653
  };
1156
1654
  }
1157
- setWeights(data) {
1158
- data.inputProj.forEach((row, i) => {
1159
- this.inputProj.W[i] = [...row];
1160
- });
1655
+ setWeightsStructured(data) {
1656
+ this.inputProj.setWeights(data.inputProj.flat());
1161
1657
  data.blocks.forEach((bd, b) => {
1162
1658
  const blk = this.blocks[b];
1163
1659
  bd.attn.heads.forEach((hd, h) => {
@@ -1172,11 +1668,20 @@ var NetworkTransformerRL = class {
1172
1668
  blk.norm2.beta = [...bd.norm2.beta];
1173
1669
  blk.ff1.W = bd.ff1.map((r) => [...r]);
1174
1670
  blk.ff2.W = bd.ff2.map((r) => [...r]);
1175
- blk.b1 = [...bd.b1];
1176
- blk.b2 = [...bd.b2];
1671
+ blk.b1.setWeights(bd.b1);
1672
+ blk.b2.setWeights(bd.b2);
1177
1673
  });
1178
1674
  this.outputProj.W = data.outputProj.map((r) => [...r]);
1179
- this.outputBias = [...data.outputBias];
1675
+ this.outputBias.setWeights(data.outputBias);
1676
+ }
1677
+ // ── Serializable interface (flat array) ────────────────────────────────────
1678
+ // These satisfy the Serializable interface from ModelSaver, which requires
1679
+ // getWeights(): number[] and setWeights(weights: number[]): void.
1680
+ getWeights() {
1681
+ return this.getWeightsFlat();
1682
+ }
1683
+ setWeights(weights) {
1684
+ this.setWeightsFlat(weights);
1180
1685
  }
1181
1686
  // ── Internal ────────────────────────────────────────────────────────────────
1182
1687
  _forward(sequence) {
@@ -1191,6 +1696,44 @@ var NetworkTransformerRL = class {
1191
1696
  return h;
1192
1697
  }
1193
1698
  _pool(h) {
1699
+ switch (this._pooling) {
1700
+ case "avg":
1701
+ return this._poolAvg(h);
1702
+ case "max":
1703
+ return this._poolMax(h);
1704
+ case "last":
1705
+ return this._poolLast(h);
1706
+ case "weighted":
1707
+ default:
1708
+ return this._poolWeighted(h);
1709
+ }
1710
+ }
1711
+ _poolAvg(h) {
1712
+ const n = h.length;
1713
+ return Array.from({ length: this.d_model }, (_, m) => {
1714
+ let sum = 0;
1715
+ for (let i = 0; i < n; i++)
1716
+ sum += h[i][m];
1717
+ return sum / n;
1718
+ });
1719
+ }
1720
+ _poolMax(h) {
1721
+ this._argmax = new Array(this.d_model).fill(0);
1722
+ return Array.from({ length: this.d_model }, (_, m) => {
1723
+ let maxVal = -Infinity;
1724
+ for (let i = 0; i < h.length; i++) {
1725
+ if (h[i][m] > maxVal) {
1726
+ maxVal = h[i][m];
1727
+ this._argmax[m] = i;
1728
+ }
1729
+ }
1730
+ return maxVal;
1731
+ });
1732
+ }
1733
+ _poolLast(h) {
1734
+ return [...h[h.length - 1]];
1735
+ }
1736
+ _poolWeighted(h) {
1194
1737
  const weights = Array.from(
1195
1738
  { length: this.seqLen },
1196
1739
  (_, i) => i === this.seqLen - 1 ? 2 : 1
@@ -1203,6 +1746,55 @@ var NetworkTransformerRL = class {
1203
1746
  return sum / totalWeight;
1204
1747
  });
1205
1748
  }
1749
+ /** Returns the current pooling type for inspection. */
1750
+ getPoolingType() {
1751
+ return this._pooling;
1752
+ }
1753
+ // ── Helper: distribute pooled gradient back to each position ────────────────
1754
+ // Must match the same distribution as _pool() used during forward.
1755
+ _distributePoolGradient(dPooled) {
1756
+ switch (this._pooling) {
1757
+ case "avg": {
1758
+ const n = this.seqLen;
1759
+ return Array.from(
1760
+ { length: n },
1761
+ () => dPooled.map((v) => v / n)
1762
+ );
1763
+ }
1764
+ case "max": {
1765
+ if (!this._argmax) {
1766
+ const n = this.seqLen;
1767
+ return Array.from(
1768
+ { length: n },
1769
+ () => dPooled.map((v) => v / n)
1770
+ );
1771
+ }
1772
+ const argmax = this._argmax;
1773
+ return Array.from(
1774
+ { length: this.seqLen },
1775
+ (_, i) => dPooled.map((v, m) => i === argmax[m] ? v : 0)
1776
+ );
1777
+ }
1778
+ case "last": {
1779
+ return Array.from(
1780
+ { length: this.seqLen },
1781
+ (_, i) => i === this.seqLen - 1 ? [...dPooled] : new Array(this.d_model).fill(0)
1782
+ );
1783
+ }
1784
+ case "weighted":
1785
+ default: {
1786
+ const weights = Array.from(
1787
+ { length: this.seqLen },
1788
+ (_, i) => i === this.seqLen - 1 ? 2 : 1
1789
+ );
1790
+ const totalWeight = weights.reduce((a, b) => a + b, 0);
1791
+ return Array.from(
1792
+ { length: this.seqLen },
1793
+ (_, i) => dPooled.map((v) => v * weights[i] / totalWeight)
1794
+ );
1795
+ }
1796
+ }
1797
+ }
1206
1798
  };
1207
1799
 
1208
1800
  // src/losses.ts
@@ -1227,13 +1819,802 @@ function crossEntropyDeltaRaw(predicted, actual) {
1227
1819
  const p = Math.max(eps, Math.min(1 - eps, predicted));
1228
1820
  return actual / p - (1 - actual) / (1 - p);
1229
1821
  }
1822
+
1823
+ // src/GRU.ts
1824
/**
 * Logistic sigmoid σ(x) = 1 / (1 + e^(−x)).
 * @param {number} x
 * @returns {number} value in [0, 1]
 */
function sigmoid4(x) {
  const expNeg = Math.exp(-x);
  return 1 / (1 + expNeg);
}
1827
/**
 * Hyperbolic tangent.
 * Delegates to Math.tanh. The hand-rolled (e^(2x) − 1) / (e^(2x) + 1) form
 * overflows to Infinity / Infinity = NaN for x greater than about 354, which
 * poisons every downstream GRU activation.
 * @param {number} x
 * @returns {number} value in [-1, 1]
 */
function tanhFn(x) {
  return Math.tanh(x);
}
1831
var Gate2 = class {
  /**
   * Dense gate applied to the concatenation [input, hidden].
   * W has hSize rows of (inputSize + hSize) weights, drawn uniformly from
   * [-limit, limit] with limit = sqrt(2 / (inputSize + 2 * hSize)).
   * @param {number} inputSize
   * @param {number} hSize - number of gate units (rows of W)
   * @param {number} [initBias=0] - initial value for every bias entry
   */
  constructor(inputSize, hSize, initBias = 0) {
    const fanIn = inputSize + hSize;
    const bound = Math.sqrt(2 / (fanIn + hSize));
    const randomRow = () => Array.from({ length: fanIn }, () => (Math.random() * 2 - 1) * bound);
    this.W = Array.from({ length: hSize }, randomRow);
    this.b = new Array(hSize).fill(initBias);
  }
  /**
   * Affine map W · combined + b (no activation).
   * @param {number[]} combined - concatenated [input, hidden] vector
   * @returns {number[]} one pre-activation per gate unit
   */
  linear(combined) {
    const out = new Array(this.W.length);
    for (let i = 0; i < this.W.length; i++) {
      const row = this.W[i];
      let acc = this.b[i];
      for (let j = 0; j < row.length; j++) acc += row[j] * combined[j];
      out[i] = acc;
    }
    return out;
  }
};
1847
var GRULayer = class {
  /**
   * GRU layer stepped one input vector at a time.
   * predict() advances the hidden state and records a trajectory entry;
   * backprop() runs backpropagation through time over the recorded entries,
   * applies the optimizer updates and clears the trajectory.
   *
   * @param {number} inputSize - length of each vector passed to predict()
   * @param {number} hiddenSize - length of the hidden state
   * @param {Function} [optimizerFactory] - called once per weight and per bias;
   *   each result is used as `opt.step(value, grad, lr)` returning the new value
   * @throws {Error} if inputSize or hiddenSize is not positive
   */
  constructor(inputSize, hiddenSize, optimizerFactory = () => new SGD()) {
    this._traj = [];
    if (inputSize <= 0 || hiddenSize <= 0) {
      throw new Error(`GRULayer: inputSize and hiddenSize must be positive, got ${inputSize} and ${hiddenSize}`);
    }
    this.inputSize = inputSize;
    this.hSize = hiddenSize;
    this.h = new Array(hiddenSize).fill(0);
    this.resetGate = new Gate2(inputSize, hiddenSize);
    this.updateGate = new Gate2(inputSize, hiddenSize);
    this.newGate = new Gate2(inputSize, hiddenSize);
    const combSize = inputSize + hiddenSize;
    const weightOpts = () => Array.from(
      { length: hiddenSize },
      () => Array.from({ length: combSize }, () => optimizerFactory())
    );
    const biasOpts = () => Array.from({ length: hiddenSize }, () => optimizerFactory());
    // Created in order: reset (W, b), update (W, b), new (W, b).
    this._optimizers = {
      resetW: weightOpts(),
      resetB: biasOpts(),
      updateW: weightOpts(),
      updateB: biasOpts(),
      newW: weightOpts(),
      newB: biasOpts()
    };
  }
  /** Zeroes the hidden state and discards any recorded trajectory. */
  reset() {
    this.h = new Array(this.hSize).fill(0);
    this._traj = [];
  }
  /**
   * One GRU step:
   *   r = σ(Wr·[x, h]),  z = σ(Wz·[x, h]),  n = tanh(Wn·[x, r ⊙ h]),
   *   h' = (1 − z) ⊙ n + z ⊙ h
   * @param {number[]} inputs - length must equal inputSize
   * @returns {number[]} the new hidden state (also stored in this.h)
   * @throws {Error} if inputs is not an array of length inputSize
   */
  predict(inputs) {
    if (!Array.isArray(inputs) || inputs.length !== this.inputSize) {
      throw new Error(`GRULayer.predict: expected array of length ${this.inputSize}, got ${inputs?.length}`);
    }
    const combined = [...inputs, ...this.h];
    const h_prev = [...this.h];
    const r_pre = this.resetGate.linear(combined);
    const z_pre = this.updateGate.linear(combined);
    const r_a = r_pre.map(sigmoid4);
    const z_a = z_pre.map(sigmoid4);
    // The candidate gate sees the hidden state scaled by the reset gate.
    const combined_r = [...inputs, ...r_a.map((r, i) => r * h_prev[i])];
    const n_pre = this.newGate.linear(combined_r);
    const n_a = n_pre.map(tanhFn);
    const h = n_a.map((n, i) => (1 - z_a[i]) * n + z_a[i] * h_prev[i]);
    this._traj.push({ combined, h_prev, r: r_pre, r_a, z: z_pre, z_a, combined_r, n_pre, n_a, h });
    this.h = h;
    return h;
  }
  /**
   * Backpropagation through time over every step recorded since the last
   * backprop()/reset(). Gradients are summed over steps and applied with
   * learning rate lr / T; the trajectory is then cleared.
   * @param {number[][]} dh_seq - one hidden-state gradient per recorded step
   * @param {number} lr - learning rate
   * @throws {Error} if dh_seq does not contain exactly one entry per recorded
   *   step (previously this silently skipped training and kept the stale
   *   trajectory growing)
   */
  backprop(dh_seq, lr) {
    const T = this._traj.length;
    if (T === 0) return;
    if (dh_seq == null || dh_seq.length !== T) {
      throw new Error(`GRULayer.backprop: expected ${T} hidden-state gradients (one per predict() call), got ${dh_seq?.length}`);
    }
    const hSize = this.hSize;
    const combSize = this.inputSize + hSize;
    const dWr = Array.from({ length: hSize }, () => new Array(combSize).fill(0));
    const dWz = Array.from({ length: hSize }, () => new Array(combSize).fill(0));
    const dWn = Array.from({ length: hSize }, () => new Array(combSize).fill(0));
    const dbr = new Array(hSize).fill(0);
    const dbz = new Array(hSize).fill(0);
    const dbn = new Array(hSize).fill(0);
    let dh_next = new Array(hSize).fill(0);
    for (let t = T - 1; t >= 0; t--) {
      const s = this._traj[t];
      const dh = dh_seq[t].map((d, i) => d + dh_next[i]);
      // h' = (1 - z) n + z h_prev  =>  dh'/dz = h_prev - n,  dh'/dn = 1 - z
      const dz_a = dh.map((d, i) => (s.h_prev[i] - s.n_a[i]) * d);
      const dn_a = dh.map((d, i) => (1 - s.z_a[i]) * d);
      const dn_pre = dn_a.map((d, i) => d * (1 - s.n_a[i] ** 2));
      const dz_pre = dz_a.map((d, i) => d * s.z_a[i] * (1 - s.z_a[i]));
      // Gradient reaching the (r ⊙ h_prev) half of the candidate gate's input.
      const dr_hprev = Array.from(
        { length: hSize },
        (_, i) => this.newGate.W.reduce((sum, row, k) => sum + dn_pre[k] * row[this.inputSize + i], 0)
      );
      const dr_a = dr_hprev.map((d, i) => d * s.h_prev[i]);
      const dr_pre = dr_a.map((d, i) => d * s.r_a[i] * (1 - s.r_a[i]));
      for (let k = 0; k < hSize; k++) {
        for (let j = 0; j < combSize; j++) {
          dWr[k][j] += dr_pre[k] * s.combined[j];
          dWz[k][j] += dz_pre[k] * s.combined[j];
          dWn[k][j] += dn_pre[k] * s.combined_r[j];
        }
        dbr[k] += dr_pre[k];
        dbz[k] += dz_pre[k];
        dbn[k] += dn_pre[k];
      }
      // Gradient flowing into h_prev through the reset/update gates, the
      // reset-scaled candidate input, and the direct z ⊙ h_prev path.
      dh_next = new Array(hSize).fill(0);
      for (let k = 0; k < hSize; k++) {
        for (let j = this.inputSize; j < combSize; j++) {
          dh_next[j - this.inputSize] += dr_pre[k] * this.resetGate.W[k][j] + dz_pre[k] * this.updateGate.W[k][j];
        }
        dh_next[k] += dr_hprev[k] * s.r_a[k];
        dh_next[k] += dh[k] * s.z_a[k];
      }
    }
    const scale = lr / T;
    for (let k = 0; k < hSize; k++) {
      for (let j = 0; j < combSize; j++) {
        this.resetGate.W[k][j] = this._optimizers.resetW[k][j].step(this.resetGate.W[k][j], dWr[k][j], scale);
        this.updateGate.W[k][j] = this._optimizers.updateW[k][j].step(this.updateGate.W[k][j], dWz[k][j], scale);
        this.newGate.W[k][j] = this._optimizers.newW[k][j].step(this.newGate.W[k][j], dWn[k][j], scale);
      }
      this.resetGate.b[k] = this._optimizers.resetB[k].step(this.resetGate.b[k], dbr[k], scale);
      this.updateGate.b[k] = this._optimizers.updateB[k].step(this.updateGate.b[k], dbz[k], scale);
      this.newGate.b[k] = this._optimizers.newB[k].step(this.newGate.b[k], dbn[k], scale);
    }
    this._traj = [];
  }
  // ── Flat weight serialization ─────────────────────────────────────────────
  // Order: resetGate (W, b), updateGate (W, b), newGate (W, b).
  getWeightsFlat() {
    const w = [];
    for (const gate of [this.resetGate, this.updateGate, this.newGate]) {
      for (const row of gate.W) for (const v of row) w.push(v);
      for (const v of gate.b) w.push(v);
    }
    return w;
  }
  /**
   * Loads weights laid out exactly as getWeightsFlat() produces them.
   * @param {number[]} weights - length must equal getWeightsFlat().length
   * @throws {Error} if weights is missing or has the wrong length
   */
  setWeightsFlat(weights) {
    const gates = [this.resetGate, this.updateGate, this.newGate];
    const expected = gates.reduce(
      (n, gate) => n + gate.W.reduce((m, row) => m + row.length, 0) + gate.b.length,
      0
    );
    if (weights?.length !== expected) {
      throw new Error(`GRULayer.setWeightsFlat: expected ${expected} weights, got ${weights?.length}`);
    }
    let idx = 0;
    for (const gate of gates) {
      for (const row of gate.W) {
        for (let j = 0; j < row.length; j++) row[j] = weights[idx++];
      }
      for (let i = 0; i < gate.b.length; i++) gate.b[i] = weights[idx++];
    }
  }
  /**
   * Structured snapshot of every gate's weights and biases. Arrays are copied,
   * so later training does not mutate the snapshot.
   */
  getWeights() {
    const copyGate = (gate) => ({ W: gate.W.map((row) => [...row]), b: [...gate.b] });
    return {
      resetGate: copyGate(this.resetGate),
      updateGate: copyGate(this.updateGate),
      newGate: copyGate(this.newGate)
    };
  }
  /**
   * Loads a structure shaped like getWeights()'s result. Input arrays are
   * copied, so training never mutates the caller's data.
   */
  setWeights(data) {
    const loadGate = (gate, src) => {
      gate.W = src.W.map((row) => [...row]);
      gate.b = [...src.b];
    };
    loadGate(this.resetGate, data.resetGate);
    loadGate(this.updateGate, data.updateGate);
    loadGate(this.newGate, data.newGate);
  }
};
1997
+
1998
+ // src/BatchNorm.ts
1999
var BatchNorm = class {
  /**
   * Per-feature normalization of single samples. It uses exponentially
   * smoothed running statistics that are updated on every forward() call.
   *
   * @param {number} dim - number of features per sample
   * @param {number} [momentum=0.1] - weight kept from the PREVIOUS running
   *   statistic: stat = momentum * stat + (1 - momentum) * sample.
   *   NOTE(review): PyTorch uses the opposite convention (momentum weights the
   *   new observation). Confirm this is intended.
   */
  constructor(dim, momentum = 0.1) {
    this._xNorm = null; // normalized input cached by forward()
    this._std = null; // per-feature std cached by forward()
    this.dim = dim;
    this.momentum = momentum;
    this.gamma = new Array(dim).fill(1);
    this.beta = new Array(dim).fill(0);
    this.runningMean = new Array(dim).fill(0);
    this.runningVar = new Array(dim).fill(1);
  }
  // ── Forward ───────────────────────────────────────────────────────────────
  /**
   * Folds x into the running statistics, then returns gamma * x_norm + beta,
   * where x_norm = (x - runningMean) / sqrt(runningVar + 1e-5).
   * @param {number[]} x - length must equal dim
   * @returns {number[]}
   * @throws {Error} if x has the wrong length
   */
  forward(x) {
    if (x.length !== this.dim) {
      throw new Error(`BatchNorm.forward: expected array of length ${this.dim}, got ${x.length}`);
    }
    const eps = 1e-5;
    for (let i = 0; i < this.dim; i++) {
      this.runningMean[i] = this.momentum * this.runningMean[i] + (1 - this.momentum) * x[i];
      // The deviation is measured against the already-updated mean.
      const diff = x[i] - this.runningMean[i];
      this.runningVar[i] = this.momentum * this.runningVar[i] + (1 - this.momentum) * diff * diff;
    }
    this._std = this.runningVar.map((v) => Math.sqrt(v + eps));
    this._xNorm = x.map((v, i) => (v - this.runningMean[i]) / this._std[i]);
    return this._xNorm.map((xn, i) => this.gamma[i] * xn + this.beta[i]);
  }
  // ── Backward ──────────────────────────────────────────────────────────────
  /**
   * Input gradient with the running statistics held constant:
   * dX[i] = dOut[i] * gamma[i] / std[i].
   * @param {number[]} dOut - length must equal dim
   * @returns {number[]}
   * @throws {Error} if forward() has not run yet or dOut has the wrong length
   */
  backward(dOut) {
    if (!this._xNorm || !this._std) {
      throw new Error("BatchNorm.backward: call forward() first");
    }
    if (dOut.length !== this.dim) {
      throw new Error(`BatchNorm.backward: expected array of length ${this.dim}, got ${dOut.length}`);
    }
    return dOut.map((d, i) => d * this.gamma[i] / this._std[i]);
  }
  // ── Train gamma and beta (call after backward) ────────────────────────────
  /**
   * In-place update: gamma += lr * dOut * x_norm, beta += lr * dOut.
   * Does nothing if forward() has not run yet.
   * @throws {Error} if dOut has the wrong length
   */
  trainParams(dOut, lr) {
    if (!this._xNorm) return;
    if (dOut.length !== this.dim) {
      throw new Error(`BatchNorm.trainParams: expected array of length ${this.dim}, got ${dOut.length}`);
    }
    for (let i = 0; i < this.dim; i++) {
      this.gamma[i] += lr * dOut[i] * this._xNorm[i];
      this.beta[i] += lr * dOut[i];
    }
  }
  // ── Flat weight serialization ─────────────────────────────────────────────
  // Order: gamma, beta.
  getWeights() {
    return [...this.gamma, ...this.beta];
  }
  /**
   * Restores gamma and beta from [...gamma, ...beta].
   * @param {number[]} weights - length must be 2 * dim
   * @throws {Error} if weights is missing or has the wrong length
   */
  setWeights(weights) {
    if (weights?.length !== 2 * this.dim) {
      throw new Error(`BatchNorm.setWeights: expected ${2 * this.dim} weights, got ${weights?.length}`);
    }
    for (let i = 0; i < this.dim; i++) {
      this.gamma[i] = weights[i];
      this.beta[i] = weights[this.dim + i];
    }
  }
};
2052
+
2053
+ // src/Conv1D.ts
2054
+ var Conv1D = class {
2055
+ constructor(inputLength, kernelSize, filters, stride = 1, padding = "valid", optimizerFactory = () => new SGD(), inputChannels = 1) {
2056
+ // [filters]
2057
+ this._input = null;
2058
+ this._paddedInput = null;
2059
+ if (inputLength <= 0 || kernelSize <= 0 || filters <= 0) {
2060
+ throw new Error("Conv1D: inputLength, kernelSize, and filters must be positive");
2061
+ }
2062
+ if (kernelSize > inputLength && padding === "valid") {
2063
+ throw new Error("Conv1D: kernelSize cannot exceed inputLength with valid padding");
2064
+ }
2065
+ if (inputChannels < 1) {
2066
+ throw new Error("Conv1D: inputChannels must be >= 1");
2067
+ }
2068
+ this.inputLength = inputLength;
2069
+ this.kernelSize = kernelSize;
2070
+ this.filters = filters;
2071
+ this.stride = stride;
2072
+ this.padding = padding;
2073
+ this.inputChannels = inputChannels;
2074
+ const limit = Math.sqrt(2 / (kernelSize * inputChannels));
2075
+ this.kernels = Array.from(
2076
+ { length: filters },
2077
+ () => Array.from(
2078
+ { length: kernelSize },
2079
+ () => Array.from({ length: inputChannels }, () => (Math.random() * 2 - 1) * limit)
2080
+ )
2081
+ );
2082
+ this.biases = new Array(filters).fill(0);
2083
+ this._kOpts = Array.from(
2084
+ { length: filters },
2085
+ () => Array.from(
2086
+ { length: kernelSize },
2087
+ () => Array.from({ length: inputChannels }, () => optimizerFactory())
2088
+ )
2089
+ );
2090
+ this._bOpts = Array.from({ length: filters }, () => optimizerFactory());
2091
+ }
2092
+ // ── Forward ───────────────────────────────────────────────────────────────
2093
+ // Accepts either number[] (when inputChannels=1) or number[][] (multi-channel).
2094
+ forward(input) {
2095
+ const input2D = this._normalizeInput(input);
2096
+ this._input = input2D.map((row) => [...row]);
2097
+ let padded;
2098
+ if (this.padding === "same") {
2099
+ const padSize = Math.floor((this.kernelSize - 1) / 2);
2100
+ const padRow = new Array(this.inputChannels).fill(0);
2101
+ padded = new Array(padSize).fill(null).map(() => [...padRow]).concat(input2D).concat(new Array(padSize).fill(null).map(() => [...padRow]));
2102
+ } else {
2103
+ padded = input2D;
2104
+ }
2105
+ this._paddedInput = padded;
2106
+ const outputLength = Math.floor((padded.length - this.kernelSize) / this.stride) + 1;
2107
+ const output = Array.from(
2108
+ { length: this.filters },
2109
+ () => new Array(outputLength).fill(0)
2110
+ );
2111
+ for (let f = 0; f < this.filters; f++) {
2112
+ for (let pos = 0; pos < outputLength; pos++) {
2113
+ const start = pos * this.stride;
2114
+ let sum = this.biases[f];
2115
+ for (let k = 0; k < this.kernelSize; k++) {
2116
+ for (let c = 0; c < this.inputChannels; c++) {
2117
+ sum += this.kernels[f][k][c] * padded[start + k][c];
2118
+ }
2119
+ }
2120
+ output[f][pos] = sum;
2121
+ }
2122
+ }
2123
+ return output;
2124
+ }
2125
+ // ── Backward ──────────────────────────────────────────────────────────────
2126
+ backward(dOut, lr = 1e-3) {
2127
+ if (!this._paddedInput || !this._input) {
2128
+ throw new Error("Conv1D.backward: call forward() first");
2129
+ }
2130
+ const padded = this._paddedInput;
2131
+ const outputLength = dOut[0].length;
2132
+ const dKernels = Array.from(
2133
+ { length: this.filters },
2134
+ () => Array.from(
2135
+ { length: this.kernelSize },
2136
+ () => new Array(this.inputChannels).fill(0)
2137
+ )
2138
+ );
2139
+ const dBiases = new Array(this.filters).fill(0);
2140
+ const dPadded = padded.map((row) => new Array(this.inputChannels).fill(0));
2141
+ for (let f = 0; f < this.filters; f++) {
2142
+ for (let pos = 0; pos < outputLength; pos++) {
2143
+ const start = pos * this.stride;
2144
+ dBiases[f] += dOut[f][pos];
2145
+ for (let k = 0; k < this.kernelSize; k++) {
2146
+ for (let c = 0; c < this.inputChannels; c++) {
2147
+ dKernels[f][k][c] += dOut[f][pos] * padded[start + k][c];
2148
+ dPadded[start + k][c] += dOut[f][pos] * this.kernels[f][k][c];
2149
+ }
2150
+ }
2151
+ }
2152
+ }
2153
+ for (let f = 0; f < this.filters; f++) {
2154
+ for (let k = 0; k < this.kernelSize; k++) {
2155
+ for (let c = 0; c < this.inputChannels; c++) {
2156
+ this.kernels[f][k][c] = this._kOpts[f][k][c].step(this.kernels[f][k][c], dKernels[f][k][c], lr);
2157
+ }
2158
+ }
2159
+ this.biases[f] = this._bOpts[f].step(this.biases[f], dBiases[f], lr);
2160
+ }
2161
+ if (this.padding === "same") {
2162
+ const padSize = Math.floor((this.kernelSize - 1) / 2);
2163
+ return dPadded.slice(padSize, padSize + this.inputLength);
2164
+ }
2165
+ return dPadded.slice(0, this.inputLength);
2166
+ }
2167
+ // ── Output length ─────────────────────────────────────────────────────────
2168
+ getOutputLength() {
2169
+ if (this.padding === "same") {
2170
+ return Math.ceil(this.inputLength / this.stride);
2171
+ }
2172
+ return Math.floor((this.inputLength - this.kernelSize) / this.stride) + 1;
2173
+ }
2174
+ // ── Flat weight serialization ─────────────────────────────────────────────
2175
+ // Order: kernels (flattened), biases.
2176
+ getWeights() {
2177
+ const w = [];
2178
+ for (const kernel of this.kernels)
2179
+ for (const k of kernel)
2180
+ for (const c of k)
2181
+ w.push(c);
2182
+ w.push(...this.biases);
2183
+ return w;
2184
+ }
2185
+ setWeights(weights) {
2186
+ let idx = 0;
2187
+ for (let f = 0; f < this.filters; f++)
2188
+ for (let k = 0; k < this.kernelSize; k++)
2189
+ for (let c = 0; c < this.inputChannels; c++)
2190
+ this.kernels[f][k][c] = weights[idx++];
2191
+ for (let f = 0; f < this.filters; f++)
2192
+ this.biases[f] = weights[idx++];
2193
+ }
2194
+ // ── Normalize input to 2D format ─────────────────────────────────────────
2195
+ _normalizeInput(input) {
2196
+ if (input.length === 0) {
2197
+ throw new Error("Conv1D.forward: input cannot be empty");
2198
+ }
2199
+ if (typeof input[0] === "number") {
2200
+ if (this.inputChannels !== 1) {
2201
+ throw new Error(`Conv1D.forward: expected 2D input with ${this.inputChannels} channels, got 1D`);
2202
+ }
2203
+ const input1D = input;
2204
+ if (input1D.length !== this.inputLength) {
2205
+ throw new Error(`Conv1D.forward: expected input of length ${this.inputLength}, got ${input1D.length}`);
2206
+ }
2207
+ return input1D.map((v) => [v]);
2208
+ }
2209
+ const input2D = input;
2210
+ if (input2D.length !== this.inputLength) {
2211
+ throw new Error(`Conv1D.forward: expected input of length ${this.inputLength}, got ${input2D.length}`);
2212
+ }
2213
+ for (let i = 0; i < input2D.length; i++) {
2214
+ if (input2D[i].length !== this.inputChannels) {
2215
+ throw new Error(`Conv1D.forward: expected ${this.inputChannels} channels at position ${i}, got ${input2D[i].length}`);
2216
+ }
2217
+ }
2218
+ return input2D;
2219
+ }
2220
+ };
2221
+
2222
+ // src/Trainer.ts
2223
/**
 * Epoch-based training loop around a network exposing
 * `train(input, target, lr)`, which must return the per-sample loss.
 *
 * Optional features need the network to also expose `getWeights()`,
 * `setWeights(w)` and `predict(input)`:
 * - decoupled weight decay
 * - early stopping on a validation set
 * - per-epoch classification metrics
 * When those methods (or the validation data) are missing, the feature is
 * skipped with a console warning.
 */
var Trainer = class {
  /**
   * @param {object} network - model to train (see class comment for required methods).
   * @param {object} [options]
   * @param {number} [options.epochs=1000] - maximum number of passes over the data.
   * @param {number} [options.lr=0.1] - initial learning rate.
   * @param {number} [options.lrDecay=1] - multiplier applied to lr after every epoch.
   * @param {boolean} [options.verbose=false] - log loss and lr every 100 epochs.
   * @param {number} [options.weightDecay=0] - before each sample: w *= 1 - lr * weightDecay.
   * @param {{patience: number, minDelta?: number}} [options.earlyStopping] - stop once
   *   `patience` consecutive epochs fail to lower the validation loss by more than
   *   `minDelta` (default 0).
   * @param {boolean} [options.computeMetrics=false] - record accuracy/precision/recall/F1 per epoch.
   * @param {number} [options.clipValue=0] - stored on the instance.
   *   NOTE(review): not read by train(); presumably consumed elsewhere.
   */
  constructor(network, options = {}) {
    this._history = [];
    this._bestLoss = Infinity;
    this._patienceCounter = 0;
    this._stopReason = "maxEpochs";
    this._metrics = [];
    this.network = network;
    this.epochs = options.epochs ?? 1e3;
    this.lrInitial = options.lr ?? 0.1;
    this.lrDecay = options.lrDecay ?? 1;
    this.verbose = options.verbose ?? false;
    this.weightDecay = options.weightDecay ?? 0;
    this._earlyStopping = options.earlyStopping;
    this._computeMetrics = options.computeMetrics ?? false;
    this.clipValue = options.clipValue ?? 0;
  }

  /**
   * Set the held-out data used for early stopping.
   * @param {{inputs: any[], targets: number[][]}} dataset
   * @throws {Error} if inputs and targets differ in length.
   */
  setValidationData(dataset) {
    if (dataset.inputs.length !== dataset.targets.length) {
      throw new Error(
        "Trainer.setValidationData: inputs and targets must have the same length"
      );
    }
    this._validationData = dataset;
  }

  /** Best validation loss of the last train() call, or -1 if none was computed. */
  getBestLoss() {
    return this._bestLoss === Infinity ? -1 : this._bestLoss;
  }

  /** "maxEpochs" or "earlyStopping", describing how the last train() call ended. */
  getStopReason() {
    return this._stopReason;
  }

  /** Copy of the per-epoch metrics ({accuracy, precision, recall, f1}) from the last train() call. */
  getMetrics() {
    return [...this._metrics];
  }

  /**
   * Train for up to `epochs` epochs, visiting every sample once per epoch in a
   * fresh random order.
   *
   * @param {{inputs: any[], targets: number[][]}} dataset
   * @returns {number[]} mean training loss per completed epoch (the internal
   *   history array; use getHistory() for a copy).
   * @throws {Error} if inputs/targets differ in length or the dataset is empty.
   */
  train(dataset) {
    const { inputs, targets } = dataset;
    if (inputs.length !== targets.length) {
      throw new Error(
        "Trainer.train: inputs and targets must have the same length"
      );
    }
    const n = inputs.length;
    // An empty set would make every epoch's mean loss 0/0 = NaN.
    if (n === 0) {
      throw new Error("Trainer.train: dataset must contain at least one sample");
    }
    let lr = this.lrInitial;
    this._history = [];
    this._bestLoss = Infinity;
    this._patienceCounter = 0;
    this._stopReason = "maxEpochs";
    this._metrics = [];

    const netExt = this._hasWeights(this.network);
    const hasValidationData = !!this._validationData && this._validationData.inputs.length > 0;
    if (this.weightDecay > 0 && !netExt) {
      console.warn(
        "Trainer: weightDecay requires a network with getWeights/setWeights/predict. Skipping weight decay."
      );
    }
    if (this._earlyStopping && !netExt) {
      console.warn(
        "Trainer: earlyStopping requires a network with predict(). Skipping early stopping."
      );
    } else if (this._earlyStopping && !hasValidationData) {
      // Empty validation data would yield a NaN loss that never counts as an
      // improvement, forcing a spurious stop after `patience` epochs.
      console.warn(
        "Trainer: earlyStopping requires non-empty validation data (see setValidationData()). Skipping early stopping."
      );
    }
    if (this._computeMetrics && !netExt) {
      console.warn(
        "Trainer: computeMetrics requires a network with predict(). Skipping metrics."
      );
    }
    const canDecay = this.weightDecay > 0 && !!netExt;
    const canValidate = !!this._earlyStopping && !!netExt && hasValidationData;
    const canMetric = this._computeMetrics && !!netExt;
    const isClass = canMetric && this._isClassification(targets);
    if (canMetric && !isClass) {
      console.warn(
        "Trainer: computeMetrics is set but targets do not appear to be one-hot or single-class. Metrics will be skipped."
      );
    }
    // An undefined minDelta would make `bestLoss - minDelta` NaN, so no epoch
    // would ever count as an improvement.
    const minDelta = this._earlyStopping?.minDelta ?? 0;

    for (let epoch = 0; epoch < this.epochs; epoch++) {
      let epochLoss = 0;
      for (const i of this._shuffledIndices(n)) {
        if (canDecay) {
          this._applyWeightDecay(netExt, lr);
        }
        epochLoss += this.network.train(inputs[i], targets[i], lr);
      }
      epochLoss /= n;
      this._history.push(epochLoss);
      if (canMetric && isClass) {
        this._metrics.push(this._computeMetricsArray(netExt, inputs, targets));
      }
      if (canValidate) {
        const valLoss = this._computeLoss(netExt, this._validationData);
        if (valLoss < this._bestLoss - minDelta) {
          this._bestLoss = valLoss;
          this._patienceCounter = 0;
        } else {
          this._patienceCounter++;
        }
        if (this._patienceCounter >= this._earlyStopping.patience) {
          this._stopReason = "earlyStopping";
          break;
        }
      }
      // Log the learning rate actually used during this epoch, before decaying it.
      if (this.verbose && (epoch + 1) % 100 === 0) {
        console.log(
          `Epoch ${epoch + 1}/${this.epochs}, loss: ${epochLoss.toFixed(6)}, lr: ${lr.toFixed(6)}`
        );
      }
      lr *= this.lrDecay;
    }
    return this._history;
  }

  /** Copy of the per-epoch mean training losses from the last train() call. */
  getHistory() {
    return [...this._history];
  }

  // ── Private helpers ───────────────────────────────────────────────────────

  /** Return `network` if it has callable getWeights/setWeights/predict, otherwise null. */
  _hasWeights(network) {
    if ("getWeights" in network && "setWeights" in network && "predict" in network && typeof network.getWeights === "function" && typeof network.setWeights === "function" && typeof network.predict === "function") {
      return network;
    }
    return null;
  }

  /** Indices 0..n-1 in a uniformly random order (Fisher–Yates). */
  _shuffledIndices(n) {
    const indices = Array.from({ length: n }, (_, i) => i);
    for (let i = n - 1; i > 0; i--) {
      const j = Math.floor(Math.random() * (i + 1));
      [indices[i], indices[j]] = [indices[j], indices[i]];
    }
    return indices;
  }

  /** Scale every weight (biases included) by 1 - lr * weightDecay. */
  _applyWeightDecay(net, lr) {
    const factor = 1 - lr * this.weightDecay;
    const w = net.getWeights();
    for (let j = 0; j < w.length; j++) {
      w[j] *= factor;
    }
    net.setWeights(w);
  }

  /** Mean over samples of the per-sample mean squared error (used for validation loss). */
  _computeLoss(net, data) {
    let totalLoss = 0;
    for (let i = 0; i < data.inputs.length; i++) {
      const pred = net.predict(data.inputs[i]);
      const target = data.targets[i];
      let sampleLoss = 0;
      for (let j = 0; j < pred.length; j++) {
        sampleLoss += (target[j] - pred[j]) ** 2;
      }
      totalLoss += sampleLoss / pred.length;
    }
    return totalLoss / data.inputs.length;
  }

  /**
   * Heuristic check for classification-style targets. Single-column targets
   * are always accepted. Multi-column targets must have entries within 0.01 of
   * 0, or at least 0.99, and each row must sum to 1 within 0.01.
   */
  _isClassification(targets) {
    if (targets.length === 0) return false;
    const first = targets[0];
    if (first.length === 1) return true;
    for (const t of targets) {
      let sum = 0;
      for (const v of t) {
        sum += v;
        if (v < -0.01 || (v > 0.01 && v < 0.99 && Math.abs(v - 1) > 0.01))
          return false;
      }
      if (Math.abs(sum - 1) > 0.01) return false;
    }
    return true;
  }

  /**
   * Build a confusion matrix from predictions and return accuracy plus
   * macro-averaged precision, recall and F1. The classes are either
   * thresholded at 0.5 (single-column targets) or taken as the argmax.
   */
  _computeMetricsArray(net, inputs, targets) {
    const targetLen = targets[0].length;
    const nClasses = targetLen === 1 ? 2 : targetLen;
    const confusion = Array.from(
      { length: nClasses },
      () => Array(nClasses).fill(0)
    );
    for (let i = 0; i < inputs.length; i++) {
      const pred = net.predict(inputs[i]);
      const target = targets[i];
      let predClass;
      let trueClass;
      if (targetLen === 1) {
        trueClass = target[0] >= 0.5 ? 1 : 0;
        if (pred.length === 1) {
          predClass = pred[0] >= 0.5 ? 1 : 0;
        } else {
          predClass = pred.indexOf(Math.max(...pred));
        }
      } else {
        predClass = pred.indexOf(Math.max(...pred));
        trueClass = target.indexOf(Math.max(...target));
      }
      predClass = Math.max(0, Math.min(nClasses - 1, predClass));
      trueClass = Math.max(0, Math.min(nClasses - 1, trueClass));
      confusion[trueClass][predClass]++;
    }
    let totalCorrect = 0;
    let totalSamples = 0;
    const precisions = [];
    const recalls = [];
    for (let c = 0; c < nClasses; c++) {
      const tp = confusion[c][c];
      totalCorrect += tp;
      let colSum = 0;
      let rowSum = 0;
      for (let r = 0; r < nClasses; r++) {
        colSum += confusion[r][c];
        rowSum += confusion[c][r];
      }
      totalSamples += rowSum;
      precisions.push(colSum > 0 ? tp / colSum : 0);
      recalls.push(rowSum > 0 ? tp / rowSum : 0);
    }
    const accuracy = totalSamples > 0 ? totalCorrect / totalSamples : 0;
    const macroPrecision = precisions.reduce((a, b) => a + b, 0) / nClasses;
    const macroRecall = recalls.reduce((a, b) => a + b, 0) / nClasses;
    const f1 = macroPrecision + macroRecall > 0 ? 2 * macroPrecision * macroRecall / (macroPrecision + macroRecall) : 0;
    return {
      accuracy,
      precision: macroPrecision,
      recall: macroRecall,
      f1
    };
  }
};
2445
+
2446
+ // src/DataLoader.ts
2447
/**
 * Mini-batch iterator over an `{ inputs, targets }` dataset with an optional
 * held-out validation split.
 *
 * Sample order is shuffled once at construction. The first `length` shuffled
 * indices form the training set and the remainder the validation set. Iterate
 * with hasNext()/next(). Between epochs, call shuffle() to reshuffle and rewind,
 * or reset() to rewind only.
 */
var DataLoader = class _DataLoader {
  /**
   * @param {{inputs: any[], targets: any[]}} data - parallel arrays of equal length.
   * @param {number} [batchSize=1] - samples per batch: a positive integer, or Infinity for one full batch.
   * @param {number} [validationSplit=0] - fraction in [0, 1) held out for validation
   *   (rounded to the nearest whole sample).
   * @throws {Error} on mismatched lengths, an invalid batchSize or an out-of-range split.
   */
  constructor(data, batchSize = 1, validationSplit = 0) {
    if (data.inputs.length !== data.targets.length) {
      throw new Error("DataLoader: inputs and targets must have the same length");
    }
    // A zero or negative batchSize never advances the cursor past the end, so
    // hasNext() stays true forever. A fractional size gives irregular batches,
    // because slice() truncates its bounds while the cursor does not.
    if (!(batchSize === Infinity || (Number.isInteger(batchSize) && batchSize >= 1))) {
      throw new Error(`DataLoader: batchSize must be a positive integer, got ${batchSize}`);
    }
    if (validationSplit < 0 || validationSplit >= 1) {
      throw new Error(`DataLoader: validationSplit must be in [0, 1), got ${validationSplit}`);
    }
    this.data = data;
    this.batchSize = batchSize;
    this._validationSplit = validationSplit;
    const total = data.inputs.length;
    const fullIndices = _DataLoader._shuffleInPlace(Array.from({ length: total }, (_, i) => i));
    const valSize = validationSplit > 0 ? Math.round(total * validationSplit) : 0;
    const trainSize = total - valSize;
    this._trainIndices = fullIndices.slice(0, trainSize);
    this._valIndices = fullIndices.slice(trainSize);
    this._indices = [...this._trainIndices];
    this._pos = 0;
  }

  /** Fisher–Yates shuffle of `arr` in place; returns `arr`. */
  static _shuffleInPlace(arr) {
    for (let i = arr.length - 1; i > 0; i--) {
      const j = Math.floor(Math.random() * (i + 1));
      [arr[i], arr[j]] = [arr[j], arr[i]];
    }
    return arr;
  }

  /** Reshuffle the training indices and rewind iteration. Validation samples are unaffected. */
  shuffle() {
    _DataLoader._shuffleInPlace(this._trainIndices);
    this._indices = [...this._trainIndices];
    this._pos = 0;
  }

  /** True while training samples remain in the current pass. */
  hasNext() {
    return this._pos < this._indices.length;
  }

  /**
   * Next batch of up to `batchSize` training samples (the final batch may be
   * smaller; empty once exhausted).
   * @returns {{inputs: any[], targets: any[]}}
   */
  next() {
    const end = Math.min(this._pos + this.batchSize, this._indices.length);
    const batchIndices = this._indices.slice(this._pos, end);
    this._pos = end;
    return {
      inputs: batchIndices.map((i) => this.data.inputs[i]),
      targets: batchIndices.map((i) => this.data.targets[i])
    };
  }

  /** Rewind to the first batch without reshuffling. */
  reset() {
    this._pos = 0;
  }

  /** Number of training samples (excludes the validation split). */
  get length() {
    return this._trainIndices.length;
  }

  /**
   * Validation samples in their shuffled order; empty arrays when no split was
   * configured.
   * @returns {{inputs: any[], targets: any[]}}
   */
  getValidationData() {
    return {
      inputs: this._valIndices.map((i) => this.data.inputs[i]),
      targets: this._valIndices.map((i) => this.data.targets[i])
    };
  }

  /** Number of validation samples. */
  get validationLength() {
    return this._valIndices.length;
  }

  /**
   * Build next-step prediction windows from a time series. Window i uses
   * data[i .. i+seqLen-1], flattened one level, as input and data[i+seqLen]
   * as target. Returns a DataLoader with batchSize 1.
   *
   * @param {any[]} data - the series.
   * @param {number} seqLen - window length (positive integer).
   * @param {number} [validationSplit=0]
   * @throws {Error} on an invalid seqLen or a series shorter than seqLen + 1.
   */
  static sequences(data, seqLen, validationSplit = 0) {
    // A zero, negative or fractional window would yield empty windows or, via
    // data[i + seqLen], undefined targets.
    if (!Number.isInteger(seqLen) || seqLen < 1) {
      throw new Error(`DataLoader.sequences: seqLen must be a positive integer, got ${seqLen}`);
    }
    if (data.length < seqLen + 1) {
      throw new Error("DataLoader.sequences: data length must be >= seqLen + 1");
    }
    const inputs = [];
    const targets = [];
    for (let i = 0; i <= data.length - seqLen - 1; i++) {
      inputs.push(data.slice(i, i + seqLen).flat());
      targets.push(data[i + seqLen]);
    }
    return new _DataLoader({ inputs, targets }, 1, validationSplit);
  }
};
2533
+
2534
+ // src/LRScheduler.ts
2535
/**
 * Stateless learning-rate schedules. Each method receives the base (or, for
 * plateauDecay, the current) learning rate and returns the rate to use; the
 * instance stores nothing.
 */
var LRScheduler = class {
  /**
   * Step decay: multiply by `dropRate` once every `epochsDrop` epochs.
   * Result: lr * dropRate^floor(epoch / epochsDrop).
   */
  stepDecay(lr, epoch, dropRate, epochsDrop) {
    const drops = Math.floor(epoch / epochsDrop);
    return lr * Math.pow(dropRate, drops);
  }

  /** Exponential decay: lr * decayRate^epoch. */
  exponentialDecay(lr, epoch, decayRate) {
    const scale = Math.pow(decayRate, epoch);
    return lr * scale;
  }

  /**
   * Plateau decay. Returns `lr * factor` when `currentLoss` is not strictly
   * below the smallest of the last `patience` entries of `history`. Otherwise
   * returns `lr` unchanged, which also happens while `history` has fewer than
   * `patience` entries.
   *
   * Call it before appending the current loss to `history`. If the current loss
   * is already in the window, the comparison always succeeds and the rate
   * decays every epoch. Nothing resets between calls, so a sustained plateau
   * keeps decaying the rate.
   *
   * Usage:
   *   const history = [];
   *   for (let epoch = 0; epoch < 1000; epoch++) {
   *     const loss = train(...);
   *     lr = scheduler.plateauDecay(lr, loss, history, 10, 0.5);
   *     history.push(loss);
   *   }
   */
  plateauDecay(lr, currentLoss, history, patience, factor) {
    if (history.length < patience) {
      return lr;
    }
    const bestRecent = Math.min(...history.slice(-patience));
    const stalled = currentLoss >= bestRecent;
    return stalled ? lr * factor : lr;
  }

  /**
   * Cosine annealing from `lr` (epoch 0) down to `minLr` (epoch === maxEpochs):
   * minLr + 0.5 * (lr - minLr) * (1 + cos(pi * epoch / maxEpochs)).
   * Past maxEpochs the rate rises again because cosine is periodic.
   */
  cosineAnnealing(lr, epoch, maxEpochs, minLr = 0) {
    const angle = Math.PI * epoch / maxEpochs;
    const cosineTerm = 1 + Math.cos(angle);
    return minLr + 0.5 * (lr - minLr) * cosineTerm;
  }
};
2572
+
2573
+ // src/ModelSaver.ts
2574
/**
 * Save and load a model's flat weight vector as JSON.
 *
 * Saving needs `model.getWeights()`; loading needs `model.setWeights(w)`.
 * File I/O goes through caller-supplied read/write functions, so this class
 * itself does no I/O.
 */
var ModelSaver = class _ModelSaver {
  /**
   * Serialize as `{"weights": [...], "timestamp": <Date.now()>}`.
   * JSON cannot represent NaN or ±Infinity; JSON.stringify writes them as null,
   * which fromJSON() rejects.
   * @returns {string}
   */
  static toJSON(model) {
    return JSON.stringify({
      weights: model.getWeights(),
      timestamp: Date.now()
    });
  }

  /**
   * Restore weights produced by toJSON() into `model`.
   * @throws {SyntaxError} if `json` is not valid JSON.
   * @throws {Error} if the payload lacks a weights array, contains a
   *   non-finite or non-numeric entry, or, when the model exposes
   *   getWeights(), holds a different number of weights than the model.
   */
  static fromJSON(model, json) {
    const data = JSON.parse(json);
    // JSON.parse may return null or a primitive ("null", "5"); guard before reading .weights.
    if (data === null || typeof data !== "object" || !Array.isArray(data.weights)) {
      throw new Error("ModelSaver.fromJSON: invalid model data");
    }
    const { weights } = data;
    const badIndex = weights.findIndex((v) => typeof v !== "number" || !Number.isFinite(v));
    if (badIndex !== -1) {
      throw new Error(`ModelSaver.fromJSON: invalid weight at index ${badIndex}: ${weights[badIndex]}`);
    }
    // Catch loading a checkpoint into a model with a different architecture.
    if (typeof model.getWeights === "function") {
      const expected = model.getWeights().length;
      if (expected !== weights.length) {
        throw new Error(
          `ModelSaver.fromJSON: weight count mismatch (model has ${expected}, data has ${weights.length})`
        );
      }
    }
    model.setWeights(weights);
  }

  /** Serialize `model` and hand the JSON string to `writeFn(path, json)`. */
  static saveToFile(model, path, writeFn) {
    const json = _ModelSaver.toJSON(model);
    writeFn(path, json);
  }

  /** Read a JSON string via `readFn(path)` and load it into `model`. */
  static loadFromFile(model, path, readFn) {
    const json = readFn(path);
    _ModelSaver.fromJSON(model, json);
  }
};
1230
2601
  export {
1231
2602
  Adam,
1232
2603
  AttentionHead,
2604
+ BatchNorm,
2605
+ BiasVector,
2606
+ ClipOptimizer,
2607
+ ClippedOptimizerFactory,
2608
+ Conv1D,
2609
+ DataLoader,
2610
+ Dropout,
1233
2611
  EmbeddingMatrix,
2612
+ GRULayer,
2613
+ LRScheduler,
1234
2614
  LSTMLayer,
1235
2615
  Layer,
1236
2616
  LayerNorm,
2617
+ ModelSaver,
1237
2618
  Momentum,
1238
2619
  MultiHeadAttention,
1239
2620
  Network,
@@ -1244,11 +2625,13 @@ export {
1244
2625
  Neuron,
1245
2626
  NeuronN,
1246
2627
  SGD,
2628
+ Trainer,
1247
2629
  TransformerBlock,
1248
2630
  WeightMatrix,
1249
2631
  crossEntropy,
1250
2632
  crossEntropyDelta,
1251
2633
  crossEntropyDeltaRaw,
2634
+ defaultOptimizer,
1252
2635
  elu,
1253
2636
  leakyRelu,
1254
2637
  linear,
@@ -1262,5 +2645,9 @@ export {
1262
2645
  softmax,
1263
2646
  softmaxBackward,
1264
2647
  tanh,
1265
- transpose
2648
+ transpose,
2649
+ validate2DArray,
2650
+ validateArray,
2651
+ validateArrayMinLength,
2652
+ validateNumber
1266
2653
  };