@dniskav/neuron 0.2.3 → 0.2.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +28 -3
- package/dist/index.d.mts +281 -22
- package/dist/index.d.ts +281 -22
- package/dist/index.js +1507 -104
- package/dist/index.mjs +1490 -103
- package/package.json +7 -3
package/dist/index.mjs
CHANGED
|
@@ -1,3 +1,71 @@
|
|
|
1
|
+
// src/Validation.ts
|
|
2
|
+
/**
 * Asserts that `arr` is an array of exactly `expectedLength` finite numbers.
 *
 * @param {*} arr - Value to check.
 * @param {number} expectedLength - Required number of elements.
 * @param {string} methodName - Caller name, used as the error-message prefix.
 * @throws {Error} If `arr` is not an array, has a different length, or holds
 *   an element that is not a finite number.
 */
function validateArray(arr, expectedLength, methodName) {
  // Template interpolation throws a TypeError on symbols (and on objects with
  // no string conversion), which would mask the validation error. String()
  // yields the same text for every other value.
  const show = (value) => {
    try {
      return String(value);
    } catch {
      return Object.prototype.toString.call(value);
    }
  };
  if (!Array.isArray(arr)) {
    throw new Error(`${methodName}: expected array, got ${typeof arr}`);
  }
  if (arr.length !== expectedLength) {
    throw new Error(
      `${methodName}: expected array of length ${expectedLength}, got ${arr.length}`
    );
  }
  // Index loop (not forEach) so that holes in sparse arrays are rejected too.
  for (let i = 0; i < arr.length; i++) {
    // Number.isFinite is false for every non-number, NaN and +/-Infinity.
    if (!Number.isFinite(arr[i])) {
      throw new Error(`${methodName}: invalid value at index ${i}: ${show(arr[i])}`);
    }
  }
}
|
|
19
|
+
/**
 * Asserts that `arr` is an array of at least `minLength` finite numbers.
 *
 * @param {*} arr - Value to check.
 * @param {number} minLength - Smallest acceptable number of elements.
 * @param {string} methodName - Caller name, used as the error-message prefix.
 * @throws {Error} If `arr` is not an array, is shorter than `minLength`, or
 *   holds an element that is not a finite number.
 */
function validateArrayMinLength(arr, minLength, methodName) {
  // Template interpolation throws a TypeError on symbols (and on objects with
  // no string conversion), which would mask the validation error. String()
  // yields the same text for every other value.
  const show = (value) => {
    try {
      return String(value);
    } catch {
      return Object.prototype.toString.call(value);
    }
  };
  if (!Array.isArray(arr)) {
    throw new Error(`${methodName}: expected array, got ${typeof arr}`);
  }
  if (arr.length < minLength) {
    throw new Error(
      `${methodName}: expected array of at least length ${minLength}, got ${arr.length}`
    );
  }
  // Every element is checked, including those beyond `minLength`.
  for (let i = 0; i < arr.length; i++) {
    if (!Number.isFinite(arr[i])) {
      throw new Error(`${methodName}: invalid value at index ${i}: ${show(arr[i])}`);
    }
  }
}
|
|
36
|
+
/**
 * Asserts that `arr` is an `expectedRows` x `expectedCols` array of arrays
 * containing only finite numbers.
 *
 * @param {*} arr - Value to check.
 * @param {number} expectedRows - Required number of rows.
 * @param {number} expectedCols - Required length of every row.
 * @param {string} methodName - Caller name, used as the error-message prefix.
 * @throws {Error} On the first violation found: wrong outer type, wrong row
 *   count, non-array row, wrong row length, or non-finite cell.
 */
function validate2DArray(arr, expectedRows, expectedCols, methodName) {
  // Template interpolation throws a TypeError on symbols (and on objects with
  // no string conversion), which would mask the validation error. String()
  // yields the same text for every other value.
  const show = (value) => {
    try {
      return String(value);
    } catch {
      return Object.prototype.toString.call(value);
    }
  };
  if (!Array.isArray(arr)) {
    throw new Error(`${methodName}: expected 2D array, got ${typeof arr}`);
  }
  if (arr.length !== expectedRows) {
    throw new Error(`${methodName}: expected ${expectedRows} rows, got ${arr.length}`);
  }
  for (let i = 0; i < arr.length; i++) {
    const row = arr[i];
    if (!Array.isArray(row)) {
      throw new Error(`${methodName}: row ${i} is not an array`);
    }
    if (row.length !== expectedCols) {
      throw new Error(
        `${methodName}: row ${i} expected ${expectedCols} cols, got ${row.length}`
      );
    }
    for (let j = 0; j < row.length; j++) {
      if (!Number.isFinite(row[j])) {
        throw new Error(`${methodName}: invalid value at [${i}][${j}]: ${show(row[j])}`);
      }
    }
  }
}
|
|
63
|
+
/**
 * Asserts that `value` is a finite number (not NaN, not +/-Infinity).
 *
 * @param {*} value - Value to check.
 * @param {string} methodName - Caller name, used as the error-message prefix.
 * @throws {Error} If `value` is not a finite number.
 */
function validateNumber(value, methodName) {
  if (Number.isFinite(value)) return;
  // Template interpolation throws a TypeError on symbols (and on objects with
  // no string conversion), which would mask the validation error. String()
  // yields the same text for every other value.
  let shown;
  try {
    shown = String(value);
  } catch {
    shown = Object.prototype.toString.call(value);
  }
  throw new Error(`${methodName}: expected finite number, got ${shown}`);
}
|
|
68
|
+
|
|
1
69
|
// src/Neuron.ts
|
|
2
70
|
function sigmoid(x) {
|
|
3
71
|
return 1 / (1 + Math.exp(-x));
|
|
@@ -8,13 +76,18 @@ var Neuron = class {
|
|
|
8
76
|
this.bias = Math.random() * 0.1;
|
|
9
77
|
}
|
|
10
78
|
predict(input) {
|
|
79
|
+
validateNumber(input, "Neuron.predict");
|
|
11
80
|
return sigmoid(input * this.weight + this.bias);
|
|
12
81
|
}
|
|
13
82
|
train(input, target, lr) {
|
|
83
|
+
validateNumber(input, "Neuron.train");
|
|
84
|
+
validateNumber(target, "Neuron.train");
|
|
85
|
+
validateNumber(lr, "Neuron.train");
|
|
14
86
|
const prediction = this.predict(input);
|
|
15
87
|
const error = target - prediction;
|
|
16
|
-
|
|
17
|
-
this.
|
|
88
|
+
const grad = error * prediction * (1 - prediction);
|
|
89
|
+
this.weight += lr * grad * input;
|
|
90
|
+
this.bias += lr * grad;
|
|
18
91
|
}
|
|
19
92
|
};
|
|
20
93
|
|
|
@@ -54,6 +127,7 @@ function makeElu(alpha = 1) {
|
|
|
54
127
|
var elu = makeElu(1);
|
|
55
128
|
|
|
56
129
|
// src/optimizers.ts
|
|
130
|
+
// Default optimizer factory: every call returns a fresh SGD instance.
// It is declared ahead of the SGD class; that is safe because SGD is only
// looked up when the factory is invoked, not when it is defined.
var defaultOptimizer = () => new SGD();
|
|
57
131
|
var SGD = class {
|
|
58
132
|
step(weight, gradient, lr) {
|
|
59
133
|
return weight + lr * gradient;
|
|
@@ -69,6 +143,19 @@ var Momentum = class {
|
|
|
69
143
|
return weight + this.v;
|
|
70
144
|
}
|
|
71
145
|
};
|
|
146
|
+
/**
 * Optimizer decorator: clamps each gradient to [-clipValue, clipValue] and
 * then delegates the actual update to the wrapped optimizer.
 */
var ClipOptimizer = class {
  /**
   * @param {{ step: Function }} inner - Optimizer that performs the update.
   * @param {number} clipValue - Non-negative bound on the gradient magnitude
   *   (Infinity disables clipping).
   * @throws {Error} If `clipValue` is not a number, is NaN, or is negative.
   */
  constructor(inner, clipValue) {
    // A NaN bound turns every gradient into NaN, and a negative bound pins the
    // clamped gradient to a constant; reject both instead of corrupting weights.
    if (typeof clipValue !== "number" || !(clipValue >= 0)) {
      throw new Error(
        `ClipOptimizer: clipValue must be a non-negative number, got ${String(clipValue)}`
      );
    }
    this.inner = inner;
    this.clipValue = clipValue;
  }
  /**
   * @param {number} weight - Current parameter value.
   * @param {number} gradient - Raw gradient, clamped before use.
   * @param {number} lr - Learning rate, passed through unchanged.
   * @returns {*} Whatever the wrapped optimizer's `step` returns.
   */
  step(weight, gradient, lr) {
    const clipped = Math.max(-this.clipValue, Math.min(this.clipValue, gradient));
    return this.inner.step(weight, clipped, lr);
  }
};
|
|
156
|
+
/**
 * Wraps an optimizer factory so that each optimizer it produces is enclosed
 * in a ClipOptimizer using `clipValue`.
 *
 * @param {Function} innerFactory - Zero-argument factory for the wrapped optimizer.
 * @param {number} clipValue - Gradient bound handed to each ClipOptimizer.
 * @returns {Function} Factory that builds one new clipped optimizer per call;
 *   `innerFactory` is not invoked until that factory is called.
 */
function ClippedOptimizerFactory(innerFactory, clipValue) {
  const createClipped = () => {
    const wrapped = innerFactory();
    return new ClipOptimizer(wrapped, clipValue);
  };
  return createClipped;
}
|
|
72
159
|
var Adam = class {
|
|
73
160
|
constructor(beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8) {
|
|
74
161
|
this.beta1 = beta1;
|
|
@@ -89,7 +176,6 @@ var Adam = class {
|
|
|
89
176
|
};
|
|
90
177
|
|
|
91
178
|
// src/NeuronN.ts
|
|
92
|
-
var defaultOptimizer = () => new SGD();
|
|
93
179
|
var NeuronN = class {
|
|
94
180
|
constructor(nInputs, activation = sigmoid2, optimizerFactory = defaultOptimizer) {
|
|
95
181
|
const limit = Math.sqrt(1 / nInputs);
|
|
@@ -99,6 +185,7 @@ var NeuronN = class {
|
|
|
99
185
|
this._opts = Array.from({ length: nInputs + 1 }, optimizerFactory);
|
|
100
186
|
}
|
|
101
187
|
predict(inputs) {
|
|
188
|
+
validateArray(inputs, this.weights.length, "NeuronN.predict");
|
|
102
189
|
const sum = inputs.reduce((acc, e, i) => acc + e * this.weights[i], this.bias);
|
|
103
190
|
return this.activation.fn(sum);
|
|
104
191
|
}
|
|
@@ -111,14 +198,14 @@ var NeuronN = class {
|
|
|
111
198
|
train(inputs, target, lr) {
|
|
112
199
|
const prediction = this.predict(inputs);
|
|
113
200
|
const error = target - prediction;
|
|
114
|
-
|
|
201
|
+
const grad = error * this.activation.dfn(prediction);
|
|
202
|
+
this._update(inputs.map((inp) => grad * inp), grad, lr);
|
|
115
203
|
}
|
|
116
204
|
};
|
|
117
205
|
|
|
118
206
|
// src/Layer.ts
|
|
119
|
-
var defaultOptimizer2 = () => new SGD();
|
|
120
207
|
var Layer = class {
|
|
121
|
-
constructor(nNeurons, nInputs, activation = sigmoid2, optimizerFactory =
|
|
208
|
+
constructor(nNeurons, nInputs, activation = sigmoid2, optimizerFactory = defaultOptimizer) {
|
|
122
209
|
this.neurons = Array.from(
|
|
123
210
|
{ length: nNeurons },
|
|
124
211
|
() => new NeuronN(nInputs, activation, optimizerFactory)
|
|
@@ -136,84 +223,233 @@ var Network = class {
|
|
|
136
223
|
this.outputLayer = new Layer(nOutputs, nHidden);
|
|
137
224
|
}
|
|
138
225
|
predict(inputs) {
|
|
226
|
+
validateArray(inputs, this.hiddenLayer.neurons[0].weights.length, "Network.predict");
|
|
139
227
|
const hiddenOut = this.hiddenLayer.predict(inputs);
|
|
140
|
-
return this.outputLayer.predict(hiddenOut)
|
|
228
|
+
return this.outputLayer.predict(hiddenOut);
|
|
141
229
|
}
|
|
142
230
|
// Trains on a single example. Returns the squared error.
|
|
143
231
|
train(inputs, target, lr) {
|
|
232
|
+
validateArray(inputs, this.hiddenLayer.neurons[0].weights.length, "Network.train");
|
|
233
|
+
validateNumber(target, "Network.train");
|
|
234
|
+
validateNumber(lr, "Network.train");
|
|
144
235
|
const hiddenOut = this.hiddenLayer.predict(inputs);
|
|
145
236
|
const prediction = this.outputLayer.predict(hiddenOut)[0];
|
|
146
|
-
const outputError = target - prediction;
|
|
147
|
-
const outputDelta = outputError * prediction * (1 - prediction);
|
|
148
237
|
const outputNeuron = this.outputLayer.neurons[0];
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
)
|
|
152
|
-
outputNeuron.bias += lr * outputDelta;
|
|
153
|
-
this.hiddenLayer.neurons.forEach((neuron, i) => {
|
|
154
|
-
const hiddenOut_i = hiddenOut[i];
|
|
238
|
+
const outputError = target - prediction;
|
|
239
|
+
const outputDelta = outputError * outputNeuron.activation.dfn(prediction);
|
|
240
|
+
const hiddenDeltas = this.hiddenLayer.neurons.map((neuron, i) => {
|
|
155
241
|
const hiddenError = outputDelta * outputNeuron.weights[i];
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
242
|
+
return hiddenError * neuron.activation.dfn(hiddenOut[i]);
|
|
243
|
+
});
|
|
244
|
+
this.hiddenLayer.neurons.forEach((neuron, i) => {
|
|
245
|
+
neuron._update(inputs.map((inp) => hiddenDeltas[i] * inp), hiddenDeltas[i], lr);
|
|
159
246
|
});
|
|
247
|
+
outputNeuron._update(hiddenOut.map((h) => outputDelta * h), outputDelta, lr);
|
|
160
248
|
return outputError * outputError;
|
|
161
249
|
}
|
|
250
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
251
|
+
// Order: hidden layer (all neurons: weights then bias), then output layer.
|
|
252
|
+
getWeights() {
|
|
253
|
+
const w = [];
|
|
254
|
+
for (const n of this.hiddenLayer.neurons) {
|
|
255
|
+
w.push(...n.weights, n.bias);
|
|
256
|
+
}
|
|
257
|
+
for (const n of this.outputLayer.neurons) {
|
|
258
|
+
w.push(...n.weights, n.bias);
|
|
259
|
+
}
|
|
260
|
+
return w;
|
|
261
|
+
}
|
|
262
|
+
setWeights(weights) {
|
|
263
|
+
let idx = 0;
|
|
264
|
+
for (const n of this.hiddenLayer.neurons) {
|
|
265
|
+
for (let j = 0; j < n.weights.length; j++) n.weights[j] = weights[idx++];
|
|
266
|
+
n.bias = weights[idx++];
|
|
267
|
+
}
|
|
268
|
+
for (const n of this.outputLayer.neurons) {
|
|
269
|
+
for (let j = 0; j < n.weights.length; j++) n.weights[j] = weights[idx++];
|
|
270
|
+
n.bias = weights[idx++];
|
|
271
|
+
}
|
|
272
|
+
}
|
|
273
|
+
};
|
|
274
|
+
|
|
275
|
+
// src/Dropout.ts
|
|
276
|
+
/**
 * Inverted-dropout layer with no trainable parameters.
 * During training each element is zeroed with probability `rate` and the
 * survivors are scaled by 1 / (1 - rate); the mask is kept for backward().
 */
var Dropout = class {
  /**
   * @param {number} rate - Drop probability, in [0, 1).
   * @throws {Error} If `rate` is outside [0, 1), NaN or undefined.
   */
  constructor(rate) {
    this._mask = null;
    // Written as a negated range test so that NaN and undefined are rejected
    // as well; `rate < 0 || rate >= 1` is false for both and let them through.
    if (!(rate >= 0 && rate < 1)) {
      throw new Error(`Dropout rate must be in [0, 1), got ${rate}`);
    }
    this.rate = rate;
  }
  // ── Forward ───────────────────────────────────────────────────────────────
  // x: number[] → number[]
  // When training with rate > 0, draws a new mask and applies it.
  // Otherwise clears the mask and returns a copy of the input.
  forward(x, training = true) {
    if (!training || this.rate === 0) {
      this._mask = null;
      return [...x];
    }
    const scale = 1 / (1 - this.rate);
    this._mask = x.map(() => (Math.random() > this.rate ? scale : 0));
    return x.map((v, i) => v * this._mask[i]);
  }
  // ── Backward ──────────────────────────────────────────────────────────────
  // dOut: number[] → number[]
  // Multiplies by the mask of the last forward pass; with no mask stored the
  // gradient is returned as a copy.
  backward(dOut) {
    if (!this._mask) return [...dOut];
    return dOut.map((d, i) => d * this._mask[i]);
  }
  // ── Reset mask between forward passes ─────────────────────────────────────
  resetMask() {
    this._mask = null;
  }
  // ── No trainable params ───────────────────────────────────────────────────
  getWeights() {
    return [];
  }
  setWeights(_weights) {
  }
};
|
|
163
315
|
|
|
164
316
|
// src/NetworkN.ts
|
|
165
|
-
var defaultOptimizer3 = () => new SGD();
|
|
166
317
|
var NetworkN = class {
|
|
167
318
|
constructor(structure, options = {}) {
|
|
168
319
|
this.structure = structure;
|
|
169
320
|
const nLayers = structure.length - 1;
|
|
170
321
|
const activations = options.activations ?? Array.from({ length: nLayers }, () => sigmoid2);
|
|
171
|
-
const optimizer = options.optimizer ??
|
|
322
|
+
const optimizer = options.optimizer ?? defaultOptimizer;
|
|
323
|
+
const dropoutRate = options.dropoutRate ?? 0;
|
|
324
|
+
if (activations.length !== nLayers) {
|
|
325
|
+
throw new Error(`Expected ${nLayers} activations, got ${activations.length}`);
|
|
326
|
+
}
|
|
327
|
+
if (dropoutRate < 0 || dropoutRate >= 1) {
|
|
328
|
+
throw new Error(`Dropout rate must be in [0, 1), got ${dropoutRate}`);
|
|
329
|
+
}
|
|
330
|
+
this._residual = options.residual ?? false;
|
|
172
331
|
this.layers = [];
|
|
173
332
|
for (let i = 1; i < structure.length; i++) {
|
|
174
333
|
this.layers.push(new Layer(structure[i], structure[i - 1], activations[i - 1], optimizer));
|
|
175
334
|
}
|
|
335
|
+
this._dropouts = [];
|
|
336
|
+
if (dropoutRate > 0) {
|
|
337
|
+
for (let i = 0; i < nLayers - 1; i++) {
|
|
338
|
+
this._dropouts.push(new Dropout(dropoutRate));
|
|
339
|
+
}
|
|
340
|
+
}
|
|
341
|
+
const outputLayer = this.layers[this.layers.length - 1];
|
|
342
|
+
const outputActivation = outputLayer.neurons[0].activation;
|
|
343
|
+
for (let i = 1; i < outputLayer.neurons.length; i++) {
|
|
344
|
+
if (outputLayer.neurons[i].activation !== outputActivation) {
|
|
345
|
+
throw new Error("All output neurons must share the same activation function");
|
|
346
|
+
}
|
|
347
|
+
}
|
|
176
348
|
}
|
|
177
|
-
predict(inputs) {
|
|
178
|
-
|
|
349
|
+
predict(inputs, training = false) {
|
|
350
|
+
validateArray(inputs, this.structure[0], "NetworkN.predict");
|
|
351
|
+
let current = [...inputs];
|
|
352
|
+
for (let i = 0; i < this.layers.length; i++) {
|
|
353
|
+
const layerInput = [...current];
|
|
354
|
+
const layerOutput = this.layers[i].predict(current);
|
|
355
|
+
if (this._shouldResidual(i)) {
|
|
356
|
+
if (this.structure[i] === this.structure[i + 1]) {
|
|
357
|
+
current = layerOutput.map((v, j) => v + layerInput[j]);
|
|
358
|
+
} else {
|
|
359
|
+
current = [...layerOutput];
|
|
360
|
+
}
|
|
361
|
+
} else {
|
|
362
|
+
current = [...layerOutput];
|
|
363
|
+
}
|
|
364
|
+
if (i < this._dropouts.length) {
|
|
365
|
+
current = this._dropouts[i].forward(current, training);
|
|
366
|
+
}
|
|
367
|
+
}
|
|
368
|
+
return current;
|
|
179
369
|
}
|
|
180
370
|
// Generalized backpropagation across L layers.
|
|
181
371
|
// Returns the mean squared error for the example.
|
|
182
372
|
train(inputs, targets, lr) {
|
|
183
|
-
|
|
184
|
-
|
|
373
|
+
validateArray(inputs, this.structure[0], "NetworkN.train");
|
|
374
|
+
validateArray(targets, this.structure[this.structure.length - 1], "NetworkN.train");
|
|
375
|
+
const act = this._forwardAll(inputs, true);
|
|
185
376
|
const pred = act[act.length - 1];
|
|
186
377
|
const outAct = this.layers[this.layers.length - 1].neurons[0].activation;
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
const layer = this.layers[l];
|
|
190
|
-
const layerIn = act[l];
|
|
191
|
-
const prevAct = l > 0 ? this.layers[l - 1].neurons[0].activation : null;
|
|
192
|
-
const prevDeltas = layerIn.map((out, j) => {
|
|
193
|
-
const errProp = layer.neurons.reduce((s, n, k) => s + deltas[k] * n.weights[j], 0);
|
|
194
|
-
return prevAct ? errProp * prevAct.dfn(out) : errProp;
|
|
195
|
-
});
|
|
196
|
-
layer.neurons.forEach((n, k) => {
|
|
197
|
-
n._update(layerIn.map((inp) => deltas[k] * inp), deltas[k], lr);
|
|
198
|
-
});
|
|
199
|
-
deltas = prevDeltas;
|
|
200
|
-
}
|
|
378
|
+
const deltas = pred.map((p, i) => (targets[i] - p) * outAct.dfn(p));
|
|
379
|
+
this._backpropLayers(act, deltas, lr);
|
|
201
380
|
return pred.reduce((s, p, i) => s + (targets[i] - p) ** 2, 0) / pred.length;
|
|
202
381
|
}
|
|
203
382
|
// Backprop with externally provided output-layer deltas.
|
|
204
383
|
// Useful for custom loss functions (e.g. physics-based gradients).
|
|
205
384
|
trainWithDeltas(inputs, outputDeltas, lr) {
|
|
385
|
+
const act = this._forwardAll(inputs, true);
|
|
386
|
+
this._backpropLayers(act, outputDeltas, lr);
|
|
387
|
+
}
|
|
388
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
389
|
+
// Order: layer 0 (all neurons), layer 1, ..., layer N.
|
|
390
|
+
getWeights() {
|
|
391
|
+
for (const d of this._dropouts) d.resetMask();
|
|
392
|
+
const w = [];
|
|
393
|
+
for (const layer of this.layers) {
|
|
394
|
+
for (const n of layer.neurons) {
|
|
395
|
+
w.push(...n.weights, n.bias);
|
|
396
|
+
}
|
|
397
|
+
}
|
|
398
|
+
return w;
|
|
399
|
+
}
|
|
400
|
+
setWeights(weights) {
|
|
401
|
+
for (const d of this._dropouts) d.resetMask();
|
|
402
|
+
let idx = 0;
|
|
403
|
+
for (const layer of this.layers) {
|
|
404
|
+
for (const n of layer.neurons) {
|
|
405
|
+
for (let j = 0; j < n.weights.length; j++) n.weights[j] = weights[idx++];
|
|
406
|
+
n.bias = weights[idx++];
|
|
407
|
+
}
|
|
408
|
+
}
|
|
409
|
+
}
|
|
410
|
+
// ── Private helpers ──────────────────────────────────────────────────────
|
|
411
|
+
_shouldResidual(layerIndex) {
|
|
412
|
+
if (typeof this._residual === "function") return this._residual(layerIndex);
|
|
413
|
+
return this._residual;
|
|
414
|
+
}
|
|
415
|
+
// Forward pass storing activations at every layer boundary.
|
|
416
|
+
// Used by train() and trainWithDeltas(); predict() applies the same logic.
|
|
417
|
+
_forwardAll(inputs, training) {
|
|
206
418
|
const act = [inputs];
|
|
207
|
-
for (
|
|
419
|
+
for (let i = 0; i < this.layers.length; i++) {
|
|
420
|
+
const layerInput = act[act.length - 1];
|
|
421
|
+
const layerOutput = this.layers[i].predict(layerInput);
|
|
422
|
+
let current;
|
|
423
|
+
if (this._shouldResidual(i) && this.structure[i] === this.structure[i + 1]) {
|
|
424
|
+
current = layerOutput.map((v, j) => v + layerInput[j]);
|
|
425
|
+
} else {
|
|
426
|
+
current = layerOutput;
|
|
427
|
+
}
|
|
428
|
+
if (i < this._dropouts.length) {
|
|
429
|
+
current = this._dropouts[i].forward(current, training);
|
|
430
|
+
}
|
|
431
|
+
act.push(current);
|
|
432
|
+
}
|
|
433
|
+
return act;
|
|
434
|
+
}
|
|
435
|
+
// Backward pass: updates all layer weights given the pre-computed activations
|
|
436
|
+
// and the initial output-layer deltas.
|
|
437
|
+
_backpropLayers(act, outputDeltas, lr) {
|
|
208
438
|
let deltas = outputDeltas;
|
|
209
439
|
for (let l = this.layers.length - 1; l >= 0; l--) {
|
|
210
440
|
const layer = this.layers[l];
|
|
441
|
+
if (l < this._dropouts.length) {
|
|
442
|
+
deltas = this._dropouts[l].backward(deltas);
|
|
443
|
+
}
|
|
211
444
|
const layerIn = act[l];
|
|
212
445
|
const prevAct = l > 0 ? this.layers[l - 1].neurons[0].activation : null;
|
|
213
446
|
const prevDeltas = layerIn.map((out, j) => {
|
|
214
447
|
const errProp = layer.neurons.reduce((s, n, k) => s + deltas[k] * n.weights[j], 0);
|
|
215
448
|
return prevAct ? errProp * prevAct.dfn(out) : errProp;
|
|
216
449
|
});
|
|
450
|
+
if (this._shouldResidual(l) && this.structure[l] === this.structure[l + 1]) {
|
|
451
|
+
for (let j = 0; j < prevDeltas.length; j++) prevDeltas[j] += deltas[j];
|
|
452
|
+
}
|
|
217
453
|
layer.neurons.forEach((n, k) => {
|
|
218
454
|
n._update(layerIn.map((inp) => deltas[k] * inp), deltas[k], lr);
|
|
219
455
|
});
|
|
@@ -234,7 +470,7 @@ var Gate = class {
|
|
|
234
470
|
// shape: [hSize]
|
|
235
471
|
constructor(inputSize, hSize, initBias = 0) {
|
|
236
472
|
const n = inputSize + hSize;
|
|
237
|
-
const limit = Math.sqrt(2 / n);
|
|
473
|
+
const limit = Math.sqrt(2 / (n + hSize));
|
|
238
474
|
this.W = Array.from(
|
|
239
475
|
{ length: hSize },
|
|
240
476
|
() => Array.from({ length: n }, () => (Math.random() * 2 - 1) * limit)
|
|
@@ -248,8 +484,11 @@ var Gate = class {
|
|
|
248
484
|
}
|
|
249
485
|
};
|
|
250
486
|
var LSTMLayer = class {
|
|
251
|
-
constructor(inputSize, hiddenSize) {
|
|
487
|
+
constructor(inputSize, hiddenSize, optimizerFactory = () => new SGD()) {
|
|
252
488
|
this._traj = [];
|
|
489
|
+
if (inputSize <= 0 || hiddenSize <= 0) {
|
|
490
|
+
throw new Error(`LSTMLayer: inputSize and hiddenSize must be positive, got ${inputSize} and ${hiddenSize}`);
|
|
491
|
+
}
|
|
253
492
|
this.inputSize = inputSize;
|
|
254
493
|
this.hSize = hiddenSize;
|
|
255
494
|
this.h = new Array(hiddenSize).fill(0);
|
|
@@ -258,6 +497,29 @@ var LSTMLayer = class {
|
|
|
258
497
|
this.inputGate = new Gate(inputSize, hiddenSize);
|
|
259
498
|
this.cellGate = new Gate(inputSize, hiddenSize);
|
|
260
499
|
this.outputGate = new Gate(inputSize, hiddenSize);
|
|
500
|
+
const combSize = inputSize + hiddenSize;
|
|
501
|
+
this._optimizers = {
|
|
502
|
+
forgetW: Array.from(
|
|
503
|
+
{ length: hiddenSize },
|
|
504
|
+
() => Array.from({ length: combSize }, () => optimizerFactory())
|
|
505
|
+
),
|
|
506
|
+
forgetB: Array.from({ length: hiddenSize }, () => optimizerFactory()),
|
|
507
|
+
inputW: Array.from(
|
|
508
|
+
{ length: hiddenSize },
|
|
509
|
+
() => Array.from({ length: combSize }, () => optimizerFactory())
|
|
510
|
+
),
|
|
511
|
+
inputB: Array.from({ length: hiddenSize }, () => optimizerFactory()),
|
|
512
|
+
cellW: Array.from(
|
|
513
|
+
{ length: hiddenSize },
|
|
514
|
+
() => Array.from({ length: combSize }, () => optimizerFactory())
|
|
515
|
+
),
|
|
516
|
+
cellB: Array.from({ length: hiddenSize }, () => optimizerFactory()),
|
|
517
|
+
outputW: Array.from(
|
|
518
|
+
{ length: hiddenSize },
|
|
519
|
+
() => Array.from({ length: combSize }, () => optimizerFactory())
|
|
520
|
+
),
|
|
521
|
+
outputB: Array.from({ length: hiddenSize }, () => optimizerFactory())
|
|
522
|
+
};
|
|
261
523
|
}
|
|
262
524
|
// ── Reset state and trajectory (call at episode start) ────────────────────
|
|
263
525
|
reset() {
|
|
@@ -267,6 +529,9 @@ var LSTMLayer = class {
|
|
|
267
529
|
}
|
|
268
530
|
// ── Forward pass ──────────────────────────────────────────────────────────
|
|
269
531
|
predict(inputs) {
|
|
532
|
+
if (!Array.isArray(inputs) || inputs.length !== this.inputSize) {
|
|
533
|
+
throw new Error(`LSTMLayer.predict: expected array of length ${this.inputSize}, got ${inputs?.length}`);
|
|
534
|
+
}
|
|
270
535
|
const combined = [...inputs, ...this.h];
|
|
271
536
|
const c_prev = [...this.c];
|
|
272
537
|
const zf = this.forgetGate.linear(combined);
|
|
@@ -341,15 +606,15 @@ var LSTMLayer = class {
|
|
|
341
606
|
const scale = lr / T;
|
|
342
607
|
for (let k = 0; k < hSize; k++) {
|
|
343
608
|
for (let j = 0; j < combSize; j++) {
|
|
344
|
-
this.forgetGate.W[k][j]
|
|
345
|
-
this.inputGate.W[k][j]
|
|
346
|
-
this.cellGate.W[k][j]
|
|
347
|
-
this.outputGate.W[k][j]
|
|
609
|
+
this.forgetGate.W[k][j] = this._optimizers.forgetW[k][j].step(this.forgetGate.W[k][j], dWf[k][j], scale);
|
|
610
|
+
this.inputGate.W[k][j] = this._optimizers.inputW[k][j].step(this.inputGate.W[k][j], dWi[k][j], scale);
|
|
611
|
+
this.cellGate.W[k][j] = this._optimizers.cellW[k][j].step(this.cellGate.W[k][j], dWg[k][j], scale);
|
|
612
|
+
this.outputGate.W[k][j] = this._optimizers.outputW[k][j].step(this.outputGate.W[k][j], dWo[k][j], scale);
|
|
348
613
|
}
|
|
349
|
-
this.forgetGate.b[k]
|
|
350
|
-
this.inputGate.b[k]
|
|
351
|
-
this.cellGate.b[k]
|
|
352
|
-
this.outputGate.b[k]
|
|
614
|
+
this.forgetGate.b[k] = this._optimizers.forgetB[k].step(this.forgetGate.b[k], dbf[k], scale);
|
|
615
|
+
this.inputGate.b[k] = this._optimizers.inputB[k].step(this.inputGate.b[k], dbi[k], scale);
|
|
616
|
+
this.cellGate.b[k] = this._optimizers.cellB[k].step(this.cellGate.b[k], dbg[k], scale);
|
|
617
|
+
this.outputGate.b[k] = this._optimizers.outputB[k].step(this.outputGate.b[k], dbo[k], scale);
|
|
353
618
|
}
|
|
354
619
|
this._traj = [];
|
|
355
620
|
}
|
|
@@ -372,10 +637,38 @@ var LSTMLayer = class {
|
|
|
372
637
|
this.outputGate.W = data.outputGate.W;
|
|
373
638
|
this.outputGate.b = data.outputGate.b;
|
|
374
639
|
}
|
|
640
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
641
|
+
// Order: forgetGate (W, b), inputGate (W, b), cellGate (W, b), outputGate (W, b).
|
|
642
|
+
getWeightsFlat() {
|
|
643
|
+
const w = [];
|
|
644
|
+
for (const row of this.forgetGate.W) w.push(...row);
|
|
645
|
+
w.push(...this.forgetGate.b);
|
|
646
|
+
for (const row of this.inputGate.W) w.push(...row);
|
|
647
|
+
w.push(...this.inputGate.b);
|
|
648
|
+
for (const row of this.cellGate.W) w.push(...row);
|
|
649
|
+
w.push(...this.cellGate.b);
|
|
650
|
+
for (const row of this.outputGate.W) w.push(...row);
|
|
651
|
+
w.push(...this.outputGate.b);
|
|
652
|
+
return w;
|
|
653
|
+
}
|
|
654
|
+
setWeightsFlat(weights) {
|
|
655
|
+
let idx = 0;
|
|
656
|
+
for (let i = 0; i < this.forgetGate.W.length; i++)
|
|
657
|
+
for (let j = 0; j < this.forgetGate.W[i].length; j++) this.forgetGate.W[i][j] = weights[idx++];
|
|
658
|
+
for (let i = 0; i < this.forgetGate.b.length; i++) this.forgetGate.b[i] = weights[idx++];
|
|
659
|
+
for (let i = 0; i < this.inputGate.W.length; i++)
|
|
660
|
+
for (let j = 0; j < this.inputGate.W[i].length; j++) this.inputGate.W[i][j] = weights[idx++];
|
|
661
|
+
for (let i = 0; i < this.inputGate.b.length; i++) this.inputGate.b[i] = weights[idx++];
|
|
662
|
+
for (let i = 0; i < this.cellGate.W.length; i++)
|
|
663
|
+
for (let j = 0; j < this.cellGate.W[i].length; j++) this.cellGate.W[i][j] = weights[idx++];
|
|
664
|
+
for (let i = 0; i < this.cellGate.b.length; i++) this.cellGate.b[i] = weights[idx++];
|
|
665
|
+
for (let i = 0; i < this.outputGate.W.length; i++)
|
|
666
|
+
for (let j = 0; j < this.outputGate.W[i].length; j++) this.outputGate.W[i][j] = weights[idx++];
|
|
667
|
+
for (let i = 0; i < this.outputGate.b.length; i++) this.outputGate.b[i] = weights[idx++];
|
|
668
|
+
}
|
|
375
669
|
};
|
|
376
670
|
|
|
377
671
|
// src/NetworkLSTM.ts
|
|
378
|
-
var defaultOptimizer4 = () => new SGD();
|
|
379
672
|
var NetworkLSTM = class {
|
|
380
673
|
// [T][layer+1][neuron]
|
|
381
674
|
constructor(inputSize, hiddenSize, denseStructure, options = {}) {
|
|
@@ -383,7 +676,7 @@ var NetworkLSTM = class {
|
|
|
383
676
|
this.hiddenSize = hiddenSize;
|
|
384
677
|
this.lstm = new LSTMLayer(inputSize, hiddenSize);
|
|
385
678
|
const activation = options.denseActivation ?? sigmoid2;
|
|
386
|
-
const optimizer = options.optimizer ??
|
|
679
|
+
const optimizer = options.optimizer ?? defaultOptimizer;
|
|
387
680
|
this.denseLayers = [];
|
|
388
681
|
const sizes = [hiddenSize, ...denseStructure];
|
|
389
682
|
for (let i = 1; i < sizes.length; i++) {
|
|
@@ -398,6 +691,7 @@ var NetworkLSTM = class {
|
|
|
398
691
|
}
|
|
399
692
|
// ── Forward pass ──────────────────────────────────────────────────────────
|
|
400
693
|
predict(inputs) {
|
|
694
|
+
validateArray(inputs, this.inputSize, "NetworkLSTM.predict");
|
|
401
695
|
const h = this.lstm.predict(inputs);
|
|
402
696
|
const acts = [h];
|
|
403
697
|
for (const layer of this.denseLayers) {
|
|
@@ -473,6 +767,30 @@ var NetworkLSTM = class {
|
|
|
473
767
|
});
|
|
474
768
|
});
|
|
475
769
|
}
|
|
770
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
771
|
+
// Order: LSTM (flat), then dense layer 0, dense layer 1, ..., dense layer N.
|
|
772
|
+
getWeightsFlat() {
|
|
773
|
+
const w = [];
|
|
774
|
+
w.push(...this.lstm.getWeightsFlat());
|
|
775
|
+
for (const layer of this.denseLayers) {
|
|
776
|
+
for (const n of layer.neurons) {
|
|
777
|
+
w.push(...n.weights, n.bias);
|
|
778
|
+
}
|
|
779
|
+
}
|
|
780
|
+
return w;
|
|
781
|
+
}
|
|
782
|
+
setWeightsFlat(weights) {
|
|
783
|
+
let idx = 0;
|
|
784
|
+
const lstmLen = this.lstm.getWeightsFlat().length;
|
|
785
|
+
this.lstm.setWeightsFlat(weights.slice(idx, idx + lstmLen));
|
|
786
|
+
idx += lstmLen;
|
|
787
|
+
for (const layer of this.denseLayers) {
|
|
788
|
+
for (const n of layer.neurons) {
|
|
789
|
+
for (let j = 0; j < n.weights.length; j++) n.weights[j] = weights[idx++];
|
|
790
|
+
n.bias = weights[idx++];
|
|
791
|
+
}
|
|
792
|
+
}
|
|
793
|
+
}
|
|
476
794
|
};
|
|
477
795
|
|
|
478
796
|
// src/MatMul.ts
|
|
@@ -480,6 +798,9 @@ function matMul(A, B) {
|
|
|
480
798
|
const rows = A.length;
|
|
481
799
|
const inner = B.length;
|
|
482
800
|
const cols = B[0].length;
|
|
801
|
+
if (A[0].length !== B.length) {
|
|
802
|
+
throw new Error(`Incompatible dimensions for matrix multiplication: A cols (${A[0].length}) !== B rows (${B.length})`);
|
|
803
|
+
}
|
|
483
804
|
const C = Array.from({ length: rows }, () => new Array(cols).fill(0));
|
|
484
805
|
for (let i = 0; i < rows; i++)
|
|
485
806
|
for (let k = 0; k < inner; k++) {
|
|
@@ -530,6 +851,33 @@ var WeightMatrix = class {
|
|
|
530
851
|
this.W[i][j] = this.opts[i][j].step(this.W[i][j], g, lr);
|
|
531
852
|
}
|
|
532
853
|
}
|
|
854
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
855
|
+
getWeights() {
|
|
856
|
+
const w = [];
|
|
857
|
+
for (const row of this.W) w.push(...row);
|
|
858
|
+
return w;
|
|
859
|
+
}
|
|
860
|
+
setWeights(weights) {
|
|
861
|
+
let idx = 0;
|
|
862
|
+
for (let i = 0; i < this.W.length; i++)
|
|
863
|
+
for (let j = 0; j < this.W[i].length; j++) this.W[i][j] = weights[idx++];
|
|
864
|
+
}
|
|
865
|
+
};
|
|
866
|
+
/**
 * Trainable bias vector: `size` values initialised to 0, each with its own
 * Adam optimizer instance.
 */
var BiasVector = class {
  /** @param {number} size - Number of bias values. */
  constructor(size) {
    this.values = new Array(size).fill(0);
    this.opts = Array.from({ length: size }, () => new Adam());
  }
  /**
   * Applies one optimizer step to every bias.
   * @param {ArrayLike<number>} grad - One gradient per bias value.
   * @param {number} lr - Learning rate handed to each optimizer.
   * @throws {Error} If `grad` is missing or shorter than the vector; a missing
   *   gradient would otherwise be passed on as `undefined`.
   */
  update(grad, lr) {
    if (grad == null || !(grad.length >= this.values.length)) {
      throw new Error(
        `BiasVector.update: expected at least ${this.values.length} gradients, got ${grad?.length}`
      );
    }
    for (let i = 0; i < this.values.length; i++) {
      this.values[i] = this.opts[i].step(this.values[i], grad[i], lr);
    }
  }
  /** @returns {number[]} A copy of the bias values. */
  getWeights() {
    return [...this.values];
  }
  /**
   * Overwrites the bias values in place; extra trailing values are ignored.
   * @param {ArrayLike<number>} weights - At least one value per bias.
   * @throws {Error} If `weights` is missing or too short; previously the
   *   missing slots were silently filled with `undefined`.
   */
  setWeights(weights) {
    if (weights == null || !(weights.length >= this.values.length)) {
      throw new Error(
        `BiasVector.setWeights: expected at least ${this.values.length} values, got ${weights?.length}`
      );
    }
    for (let i = 0; i < this.values.length; i++) this.values[i] = weights[i];
  }
};
|
|
534
882
|
var EmbeddingMatrix = class {
|
|
535
883
|
constructor(vocabSize, d_model) {
|
|
@@ -546,15 +894,29 @@ var EmbeddingMatrix = class {
|
|
|
546
894
|
for (let m = 0; m < this.W[idx].length; m++)
|
|
547
895
|
this.W[idx][m] += lr * grad[m];
|
|
548
896
|
}
|
|
897
|
+
// ── Serializable interface ─────────────────────────────────────────────────
|
|
898
|
+
// Flattened order: row 0, row 1, ... row (vocabSize-1)
|
|
899
|
+
getWeights() {
|
|
900
|
+
const w = [];
|
|
901
|
+
for (const row of this.W) w.push(...row);
|
|
902
|
+
return w;
|
|
903
|
+
}
|
|
904
|
+
setWeights(weights) {
|
|
905
|
+
let idx = 0;
|
|
906
|
+
for (let i = 0; i < this.W.length; i++)
|
|
907
|
+
for (let j = 0; j < this.W[i].length; j++)
|
|
908
|
+
this.W[i][j] = weights[idx++];
|
|
909
|
+
}
|
|
549
910
|
};
|
|
550
911
|
|
|
551
912
|
// src/AttentionHead.ts
|
|
552
913
|
var AttentionHead = class {
|
|
553
|
-
constructor(d_model, d_k, d_v) {
|
|
914
|
+
constructor(d_model, d_k, d_v, causal = false) {
|
|
554
915
|
// d_v × d_model
|
|
555
916
|
this.cache = null;
|
|
556
917
|
this.d_k = d_k;
|
|
557
918
|
this.d_v = d_v;
|
|
919
|
+
this.causal = causal;
|
|
558
920
|
this.Wq = new WeightMatrix(d_k, d_model);
|
|
559
921
|
this.Wk = new WeightMatrix(d_k, d_model);
|
|
560
922
|
this.Wv = new WeightMatrix(d_v, d_model);
|
|
@@ -575,10 +937,10 @@ var AttentionHead = class {
|
|
|
575
937
|
);
|
|
576
938
|
const scores = Array.from(
|
|
577
939
|
{ length: seqLen },
|
|
578
|
-
(_, i) => Array.from(
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
)
|
|
940
|
+
(_, i) => Array.from({ length: seqLen }, (_2, j) => {
|
|
941
|
+
if (this.causal && j > i) return -Infinity;
|
|
942
|
+
return Q[i].reduce((s, q, k) => s + q * K[j][k], 0) * scale;
|
|
943
|
+
})
|
|
582
944
|
);
|
|
583
945
|
const attn = scores.map((row) => softmax(row));
|
|
584
946
|
const out = Array.from(
|
|
@@ -602,6 +964,7 @@ var AttentionHead = class {
|
|
|
602
964
|
// 5. dWq = dQ^T @ X, dWk = dK^T @ X, dWv = dV^T @ X
|
|
603
965
|
// 6. dX = dQ @ Wq + dK @ Wk + dV @ Wv
|
|
604
966
|
backward(dOut, lr) {
|
|
967
|
+
if (!this.cache) throw new Error("AttentionHead.backward() called before predict()");
|
|
605
968
|
const { X, Q, K, V, attn } = this.cache;
|
|
606
969
|
const seqLen = X.length;
|
|
607
970
|
const d_model = X[0].length;
|
|
@@ -674,21 +1037,40 @@ var AttentionHead = class {
|
|
|
674
1037
|
getAttentionWeights() {
|
|
675
1038
|
return this.cache ? this.cache.attn : null;
|
|
676
1039
|
}
|
|
1040
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
1041
|
+
// Order: Wq, Wk, Wv.
|
|
1042
|
+
getWeights() {
|
|
1043
|
+
const w = [];
|
|
1044
|
+
for (const row of this.Wq.W) w.push(...row);
|
|
1045
|
+
for (const row of this.Wk.W) w.push(...row);
|
|
1046
|
+
for (const row of this.Wv.W) w.push(...row);
|
|
1047
|
+
return w;
|
|
1048
|
+
}
|
|
1049
|
+
setWeights(weights) {
|
|
1050
|
+
let idx = 0;
|
|
1051
|
+
for (let i = 0; i < this.Wq.W.length; i++)
|
|
1052
|
+
for (let j = 0; j < this.Wq.W[i].length; j++) this.Wq.W[i][j] = weights[idx++];
|
|
1053
|
+
for (let i = 0; i < this.Wk.W.length; i++)
|
|
1054
|
+
for (let j = 0; j < this.Wk.W[i].length; j++) this.Wk.W[i][j] = weights[idx++];
|
|
1055
|
+
for (let i = 0; i < this.Wv.W.length; i++)
|
|
1056
|
+
for (let j = 0; j < this.Wv.W[i].length; j++) this.Wv.W[i][j] = weights[idx++];
|
|
1057
|
+
}
|
|
677
1058
|
};
|
|
678
1059
|
|
|
679
1060
|
// src/MultiHeadAttention.ts
|
|
680
1061
|
var MultiHeadAttention = class {
|
|
681
1062
|
// seqLen × (nHeads * d_k)
|
|
682
|
-
constructor(d_model, nHeads) {
|
|
1063
|
+
constructor(d_model, nHeads, causal = false) {
|
|
683
1064
|
// d_model × (nHeads * d_k)
|
|
684
1065
|
// Cached for backward
|
|
685
1066
|
this._concat = null;
|
|
686
1067
|
this.nHeads = nHeads;
|
|
687
1068
|
this.d_model = d_model;
|
|
688
1069
|
this.d_k = Math.floor(d_model / nHeads);
|
|
1070
|
+
this.causal = causal;
|
|
689
1071
|
this.heads = Array.from(
|
|
690
1072
|
{ length: nHeads },
|
|
691
|
-
() => new AttentionHead(d_model, this.d_k, this.d_k)
|
|
1073
|
+
() => new AttentionHead(d_model, this.d_k, this.d_k, causal)
|
|
692
1074
|
);
|
|
693
1075
|
this.Wo = new WeightMatrix(d_model, nHeads * this.d_k);
|
|
694
1076
|
}
|
|
@@ -710,6 +1092,7 @@ var MultiHeadAttention = class {
|
|
|
710
1092
|
// ── Backward ──────────────────────────────────────────────────────────────
|
|
711
1093
|
// dOut: seqLen × d_model → dX: seqLen × d_model
|
|
712
1094
|
backward(dOut, lr) {
|
|
1095
|
+
if (!this._concat) throw new Error("MultiHeadAttention.backward() called before predict()");
|
|
713
1096
|
const seqLen = dOut.length;
|
|
714
1097
|
const concatD = this.nHeads * this.d_k;
|
|
715
1098
|
const d_model = this.d_model;
|
|
@@ -747,6 +1130,31 @@ var MultiHeadAttention = class {
|
|
|
747
1130
|
getAttentionWeights() {
|
|
748
1131
|
return this.heads.map((h) => h.getAttentionWeights());
|
|
749
1132
|
}
|
|
1133
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
1134
|
+
// Order: head0 (Wq, Wk, Wv), head1, ..., headN, then Wo.
|
|
1135
|
+
getWeights() {
|
|
1136
|
+
const w = [];
|
|
1137
|
+
for (const head of this.heads) {
|
|
1138
|
+
for (const row of head.Wq.W) w.push(...row);
|
|
1139
|
+
for (const row of head.Wk.W) w.push(...row);
|
|
1140
|
+
for (const row of head.Wv.W) w.push(...row);
|
|
1141
|
+
}
|
|
1142
|
+
for (const row of this.Wo.W) w.push(...row);
|
|
1143
|
+
return w;
|
|
1144
|
+
}
|
|
1145
|
+
setWeights(weights) {
|
|
1146
|
+
let idx = 0;
|
|
1147
|
+
for (const head of this.heads) {
|
|
1148
|
+
for (let i = 0; i < head.Wq.W.length; i++)
|
|
1149
|
+
for (let j = 0; j < head.Wq.W[i].length; j++) head.Wq.W[i][j] = weights[idx++];
|
|
1150
|
+
for (let i = 0; i < head.Wk.W.length; i++)
|
|
1151
|
+
for (let j = 0; j < head.Wk.W[i].length; j++) head.Wk.W[i][j] = weights[idx++];
|
|
1152
|
+
for (let i = 0; i < head.Wv.W.length; i++)
|
|
1153
|
+
for (let j = 0; j < head.Wv.W[i].length; j++) head.Wv.W[i][j] = weights[idx++];
|
|
1154
|
+
}
|
|
1155
|
+
for (let i = 0; i < this.Wo.W.length; i++)
|
|
1156
|
+
for (let j = 0; j < this.Wo.W[i].length; j++) this.Wo.W[i][j] = weights[idx++];
|
|
1157
|
+
}
|
|
750
1158
|
};
|
|
751
1159
|
|
|
752
1160
|
// src/LayerNorm.ts
|
|
@@ -789,20 +1197,32 @@ var LayerNorm = class {
|
|
|
789
1197
|
backwardOne(dOut, pos, lr) {
|
|
790
1198
|
const { x_norm, std } = this._cache[pos];
|
|
791
1199
|
const N = dOut.length;
|
|
1200
|
+
const gammaOld = this.gamma.slice();
|
|
792
1201
|
for (let i = 0; i < N; i++) {
|
|
793
1202
|
this.gamma[i] += lr * dOut[i] * x_norm[i];
|
|
794
1203
|
this.beta[i] += lr * dOut[i];
|
|
795
1204
|
}
|
|
796
|
-
const D = dOut.map((d, i) => d *
|
|
1205
|
+
const D = dOut.map((d, i) => d * gammaOld[i]);
|
|
797
1206
|
const mD = D.reduce((s, v) => s + v, 0) / N;
|
|
798
1207
|
const mDxn = D.reduce((s, d, i) => s + d * x_norm[i], 0) / N;
|
|
799
1208
|
return D.map((d, i) => (d - mD - x_norm[i] * mDxn) / std);
|
|
800
1209
|
}
|
|
1210
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
1211
|
+
// Order: gamma, beta.
|
|
1212
|
+
getWeights() {
|
|
1213
|
+
return [...this.gamma, ...this.beta];
|
|
1214
|
+
}
|
|
1215
|
+
setWeights(weights) {
|
|
1216
|
+
const dim = this.gamma.length;
|
|
1217
|
+
for (let i = 0; i < dim; i++) this.gamma[i] = weights[i];
|
|
1218
|
+
for (let i = 0; i < dim; i++) this.beta[i] = weights[dim + i];
|
|
1219
|
+
}
|
|
801
1220
|
};
|
|
802
1221
|
|
|
803
1222
|
// src/TransformerBlock.ts
|
|
804
1223
|
var TransformerBlock = class {
|
|
805
|
-
constructor({ d_model, nHeads, d_ff }) {
|
|
1224
|
+
constructor({ d_model, nHeads, d_ff, causal = false }) {
|
|
1225
|
+
// d_model
|
|
806
1226
|
// Forward caches (needed for backprop)
|
|
807
1227
|
this._X = null;
|
|
808
1228
|
this._attnOut = null;
|
|
@@ -814,15 +1234,13 @@ var TransformerBlock = class {
|
|
|
814
1234
|
this._ff2Out = null;
|
|
815
1235
|
this.d_model = d_model;
|
|
816
1236
|
this.d_ff = d_ff;
|
|
817
|
-
this.attn = new MultiHeadAttention(d_model, nHeads);
|
|
1237
|
+
this.attn = new MultiHeadAttention(d_model, nHeads, causal);
|
|
818
1238
|
this.norm1 = new LayerNorm(d_model);
|
|
819
1239
|
this.norm2 = new LayerNorm(d_model);
|
|
820
1240
|
this.ff1 = new WeightMatrix(d_ff, d_model);
|
|
821
1241
|
this.ff2 = new WeightMatrix(d_model, d_ff);
|
|
822
|
-
this.b1 = new
|
|
823
|
-
this.b2 = new
|
|
824
|
-
this.b1Opts = Array.from({ length: d_ff }, () => new Adam());
|
|
825
|
-
this.b2Opts = Array.from({ length: d_model }, () => new Adam());
|
|
1242
|
+
this.b1 = new BiasVector(d_ff);
|
|
1243
|
+
this.b2 = new BiasVector(d_model);
|
|
826
1244
|
}
|
|
827
1245
|
// ── Forward ───────────────────────────────────────────────────────────────
|
|
828
1246
|
// X: seqLen × d_model → out: seqLen × d_model
|
|
@@ -835,11 +1253,11 @@ var TransformerBlock = class {
|
|
|
835
1253
|
return this.norm1.predictOne(added, i);
|
|
836
1254
|
});
|
|
837
1255
|
const ff1Pre = h1.map(
|
|
838
|
-
(h) => this.ff1.W.map((row, k) => row.reduce((s, w, m) => s + w * h[m], this.b1[k]))
|
|
1256
|
+
(h) => this.ff1.W.map((row, k) => row.reduce((s, w, m) => s + w * h[m], this.b1.values[k]))
|
|
839
1257
|
);
|
|
840
1258
|
const ff1Out = ff1Pre.map((pre) => pre.map((v) => Math.max(0, v)));
|
|
841
1259
|
const ff2Out = ff1Out.map(
|
|
842
|
-
(h) => this.ff2.W.map((row, k) => row.reduce((s, w, m) => s + w * h[m], this.b2[k]))
|
|
1260
|
+
(h) => this.ff2.W.map((row, k) => row.reduce((s, w, m) => s + w * h[m], this.b2.values[k]))
|
|
843
1261
|
);
|
|
844
1262
|
this.norm2.resetCache(seqLen);
|
|
845
1263
|
const out = h1.map((h, i) => {
|
|
@@ -857,6 +1275,9 @@ var TransformerBlock = class {
|
|
|
857
1275
|
// ── Backward ──────────────────────────────────────────────────────────────
|
|
858
1276
|
// dOut: seqLen × d_model → dX: seqLen × d_model
|
|
859
1277
|
backward(dOut, lr) {
|
|
1278
|
+
if (!this._h1 || !this._ff1Out || !this._ff1Pre) {
|
|
1279
|
+
throw new Error("TransformerBlock.backward() called before predict()");
|
|
1280
|
+
}
|
|
860
1281
|
const seqLen = dOut.length;
|
|
861
1282
|
const d_model = this.d_model;
|
|
862
1283
|
const h1 = this._h1;
|
|
@@ -881,8 +1302,7 @@ var TransformerBlock = class {
|
|
|
881
1302
|
(_, m) => dAdded2.reduce((s, da) => s + da[m], 0)
|
|
882
1303
|
);
|
|
883
1304
|
this.ff2.update(dW2, lr);
|
|
884
|
-
|
|
885
|
-
this.b2[m] = this.b2Opts[m].step(this.b2[m], db2[m], lr);
|
|
1305
|
+
this.b2.update(db2, lr);
|
|
886
1306
|
const dFf1Pre = dFf1Out.map(
|
|
887
1307
|
(d, i) => d.map((v, k) => ff1Pre[i][k] > 0 ? v : 0)
|
|
888
1308
|
);
|
|
@@ -904,8 +1324,7 @@ var TransformerBlock = class {
|
|
|
904
1324
|
(_, k) => dFf1Pre.reduce((s, dp) => s + dp[k], 0)
|
|
905
1325
|
);
|
|
906
1326
|
this.ff1.update(dW1, lr);
|
|
907
|
-
|
|
908
|
-
this.b1[k] = this.b1Opts[k].step(this.b1[k], db1[k], lr);
|
|
1327
|
+
this.b1.update(db1, lr);
|
|
909
1328
|
const dH1 = Array.from(
|
|
910
1329
|
{ length: seqLen },
|
|
911
1330
|
(_, i) => dH1_fromFf[i].map((v, m) => v + dAdded2[i][m])
|
|
@@ -927,6 +1346,36 @@ var TransformerBlock = class {
|
|
|
927
1346
|
getAttentionWeights() {
|
|
928
1347
|
return this.attn.getAttentionWeights();
|
|
929
1348
|
}
|
|
1349
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
1350
|
+
// Order: attn (MHA), norm1 (gamma, beta), ff1, b1, ff2, b2, norm2 (gamma, beta).
|
|
1351
|
+
getWeights() {
|
|
1352
|
+
const w = [];
|
|
1353
|
+
w.push(...this.attn.getWeights());
|
|
1354
|
+
w.push(...this.norm1.gamma, ...this.norm1.beta);
|
|
1355
|
+
for (const row of this.ff1.W) w.push(...row);
|
|
1356
|
+
w.push(...this.b1.values);
|
|
1357
|
+
for (const row of this.ff2.W) w.push(...row);
|
|
1358
|
+
w.push(...this.b2.values);
|
|
1359
|
+
w.push(...this.norm2.gamma, ...this.norm2.beta);
|
|
1360
|
+
return w;
|
|
1361
|
+
}
|
|
1362
|
+
setWeights(weights) {
|
|
1363
|
+
let idx = 0;
|
|
1364
|
+
const attnLen = this.attn.getWeights().length;
|
|
1365
|
+
this.attn.setWeights(weights.slice(idx, idx + attnLen));
|
|
1366
|
+
idx += attnLen;
|
|
1367
|
+
this.norm1.setWeights(weights.slice(idx, idx + this.norm1.getWeights().length));
|
|
1368
|
+
idx += this.norm1.getWeights().length;
|
|
1369
|
+
this.ff1.setWeights(weights.slice(idx, idx + this.ff1.getWeights().length));
|
|
1370
|
+
idx += this.ff1.getWeights().length;
|
|
1371
|
+
this.b1.setWeights(weights.slice(idx, idx + this.b1.values.length));
|
|
1372
|
+
idx += this.b1.values.length;
|
|
1373
|
+
this.ff2.setWeights(weights.slice(idx, idx + this.ff2.getWeights().length));
|
|
1374
|
+
idx += this.ff2.getWeights().length;
|
|
1375
|
+
this.b2.setWeights(weights.slice(idx, idx + this.b2.values.length));
|
|
1376
|
+
idx += this.b2.values.length;
|
|
1377
|
+
this.norm2.setWeights(weights.slice(idx, idx + this.norm2.getWeights().length));
|
|
1378
|
+
}
|
|
930
1379
|
};
|
|
931
1380
|
|
|
932
1381
|
// src/NetworkTransformer.ts
|
|
@@ -951,8 +1400,7 @@ var NetworkTransformer = class {
|
|
|
951
1400
|
() => new TransformerBlock({ d_model, nHeads, d_ff })
|
|
952
1401
|
);
|
|
953
1402
|
this.outputProj = new WeightMatrix(nClasses, d_model);
|
|
954
|
-
this.outputBias = new
|
|
955
|
-
this.outBiasOpts = Array.from({ length: nClasses }, () => new Adam());
|
|
1403
|
+
this.outputBias = new BiasVector(nClasses);
|
|
956
1404
|
}
|
|
957
1405
|
// ── Forward pass ──────────────────────────────────────────────────────────
|
|
958
1406
|
// tokens: seqLen integer ids → seqLen * nClasses logits (flattened)
|
|
@@ -960,7 +1408,7 @@ var NetworkTransformer = class {
|
|
|
960
1408
|
const h = this._forward(tokens);
|
|
961
1409
|
return h.flatMap(
|
|
962
1410
|
(hi) => this.outputProj.W.map(
|
|
963
|
-
(row, c) => row.reduce((s, w, m) => s + w * hi[m], this.outputBias[c])
|
|
1411
|
+
(row, c) => row.reduce((s, w, m) => s + w * hi[m], this.outputBias.values[c])
|
|
964
1412
|
)
|
|
965
1413
|
);
|
|
966
1414
|
}
|
|
@@ -974,7 +1422,7 @@ var NetworkTransformer = class {
|
|
|
974
1422
|
const h = this._forward(tokens);
|
|
975
1423
|
const logits = h.map(
|
|
976
1424
|
(hi) => this.outputProj.W.map(
|
|
977
|
-
(row, c) => row.reduce((s, w, m) => s + w * hi[m], this.outputBias[c])
|
|
1425
|
+
(row, c) => row.reduce((s, w, m) => s + w * hi[m], this.outputBias.values[c])
|
|
978
1426
|
)
|
|
979
1427
|
);
|
|
980
1428
|
let loss = 0;
|
|
@@ -1009,8 +1457,7 @@ var NetworkTransformer = class {
|
|
|
1009
1457
|
(_, c) => dLogits.reduce((s, dl) => s + dl[c], 0)
|
|
1010
1458
|
);
|
|
1011
1459
|
this.outputProj.update(dWout, lr);
|
|
1012
|
-
|
|
1013
|
-
this.outputBias[c] = this.outBiasOpts[c].step(this.outputBias[c], dBout[c], lr);
|
|
1460
|
+
this.outputBias.update(dBout, lr);
|
|
1014
1461
|
let dX = dH;
|
|
1015
1462
|
for (let b = this.blocks.length - 1; b >= 0; b--)
|
|
1016
1463
|
dX = this.blocks[b].backward(dX, lr);
|
|
@@ -1025,6 +1472,35 @@ var NetworkTransformer = class {
|
|
|
1025
1472
|
getAttentionWeights() {
|
|
1026
1473
|
return this.blocks.map((b) => b.getAttentionWeights());
|
|
1027
1474
|
}
|
|
1475
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
1476
|
+
// Order: tokenEmb, posEmb, block0, block1, ..., blockN, outputProj, outputBias.
|
|
1477
|
+
getWeights() {
|
|
1478
|
+
const w = [];
|
|
1479
|
+
w.push(...this.tokenEmb.getWeights());
|
|
1480
|
+
w.push(...this.posEmb.getWeights());
|
|
1481
|
+
for (const block of this.blocks) w.push(...block.getWeights());
|
|
1482
|
+
w.push(...this.outputProj.getWeights());
|
|
1483
|
+
w.push(...this.outputBias.getWeights());
|
|
1484
|
+
return w;
|
|
1485
|
+
}
|
|
1486
|
+
setWeights(weights) {
|
|
1487
|
+
let idx = 0;
|
|
1488
|
+
const tokenEmbLen = this.tokenEmb.getWeights().length;
|
|
1489
|
+
this.tokenEmb.setWeights(weights.slice(idx, idx + tokenEmbLen));
|
|
1490
|
+
idx += tokenEmbLen;
|
|
1491
|
+
const posEmbLen = this.posEmb.getWeights().length;
|
|
1492
|
+
this.posEmb.setWeights(weights.slice(idx, idx + posEmbLen));
|
|
1493
|
+
idx += posEmbLen;
|
|
1494
|
+
for (const block of this.blocks) {
|
|
1495
|
+
const blockLen = block.getWeights().length;
|
|
1496
|
+
block.setWeights(weights.slice(idx, idx + blockLen));
|
|
1497
|
+
idx += blockLen;
|
|
1498
|
+
}
|
|
1499
|
+
const outProjLen = this.outputProj.getWeights().length;
|
|
1500
|
+
this.outputProj.setWeights(weights.slice(idx, idx + outProjLen));
|
|
1501
|
+
idx += outProjLen;
|
|
1502
|
+
this.outputBias.setWeights(weights.slice(idx, idx + this.outputBias.values.length));
|
|
1503
|
+
}
|
|
1028
1504
|
// ── Internal ──────────────────────────────────────────────────────────────
|
|
1029
1505
|
// Shared embedding + block forward pass.
|
|
1030
1506
|
_forward(tokens) {
|
|
@@ -1044,25 +1520,28 @@ var NetworkTransformerRL = class {
|
|
|
1044
1520
|
constructor(seqLen, inputDim, options = {}) {
|
|
1045
1521
|
// Forward caches para backprop
|
|
1046
1522
|
this._projected = null;
|
|
1523
|
+
// For max pooling backward: argmax per dimension across all positions
|
|
1524
|
+
this._argmax = null;
|
|
1047
1525
|
const {
|
|
1048
1526
|
d_model = 32,
|
|
1049
1527
|
nHeads = 2,
|
|
1050
1528
|
d_ff = 64,
|
|
1051
1529
|
nBlocks = 2,
|
|
1052
|
-
nActions = 2
|
|
1530
|
+
nActions = 2,
|
|
1531
|
+
pooling = "weighted"
|
|
1053
1532
|
} = options;
|
|
1054
1533
|
this.seqLen = seqLen;
|
|
1055
1534
|
this.inputDim = inputDim;
|
|
1056
1535
|
this.d_model = d_model;
|
|
1057
1536
|
this.nActions = nActions;
|
|
1537
|
+
this._pooling = pooling;
|
|
1058
1538
|
this.inputProj = new WeightMatrix(d_model, inputDim);
|
|
1059
1539
|
this.blocks = Array.from(
|
|
1060
1540
|
{ length: nBlocks },
|
|
1061
|
-
() => new TransformerBlock({ d_model, nHeads, d_ff })
|
|
1541
|
+
() => new TransformerBlock({ d_model, nHeads, d_ff, causal: true })
|
|
1062
1542
|
);
|
|
1063
1543
|
this.outputProj = new WeightMatrix(nActions, d_model);
|
|
1064
|
-
this.outputBias = new
|
|
1065
|
-
this.outBiasOpts = Array.from({ length: nActions }, () => new Adam());
|
|
1544
|
+
this.outputBias = new BiasVector(nActions);
|
|
1066
1545
|
}
|
|
1067
1546
|
// ── Forward ────────────────────────────────────────────────────────────────
|
|
1068
1547
|
// sequence: seqLen × inputDim → nActions Q-values
|
|
@@ -1070,7 +1549,7 @@ var NetworkTransformerRL = class {
|
|
|
1070
1549
|
const h = this._forward(sequence);
|
|
1071
1550
|
const pooled = this._pool(h);
|
|
1072
1551
|
return this.outputProj.W.map(
|
|
1073
|
-
(row, c) => row.reduce((s, w, m) => s + w * pooled[m], this.outputBias[c])
|
|
1552
|
+
(row, c) => row.reduce((s, w, m) => s + w * pooled[m], this.outputBias.values[c])
|
|
1074
1553
|
);
|
|
1075
1554
|
}
|
|
1076
1555
|
// ── Training ────────────────────────────────────────────────────────────────
|
|
@@ -1082,7 +1561,7 @@ var NetworkTransformerRL = class {
|
|
|
1082
1561
|
const h = this._forward(sequence);
|
|
1083
1562
|
const pooled = this._pool(h);
|
|
1084
1563
|
const pred = this.outputProj.W.map(
|
|
1085
|
-
(row, c) => row.reduce((s, w, m) => s + w * pooled[m], this.outputBias[c])
|
|
1564
|
+
(row, c) => row.reduce((s, w, m) => s + w * pooled[m], this.outputBias.values[c])
|
|
1086
1565
|
);
|
|
1087
1566
|
const n = this.nActions;
|
|
1088
1567
|
let loss = 0;
|
|
@@ -1105,13 +1584,8 @@ var NetworkTransformerRL = class {
|
|
|
1105
1584
|
);
|
|
1106
1585
|
const dBout = dPred.slice();
|
|
1107
1586
|
this.outputProj.update(dWout, lr);
|
|
1108
|
-
|
|
1109
|
-
|
|
1110
|
-
let dH = Array.from(
|
|
1111
|
-
{ length: this.seqLen },
|
|
1112
|
-
(_, i) => dPooled.map((v) => v / this.seqLen)
|
|
1113
|
-
// Gradiente dividido entre posiciones
|
|
1114
|
-
);
|
|
1587
|
+
this.outputBias.update(dBout, lr);
|
|
1588
|
+
let dH = this._distributePoolGradient(dPooled);
|
|
1115
1589
|
for (let b = this.blocks.length - 1; b >= 0; b--)
|
|
1116
1590
|
dH = this.blocks[b].backward(dH, lr);
|
|
1117
1591
|
for (let i = 0; i < this.seqLen; i++) {
|
|
@@ -1130,8 +1604,32 @@ var NetworkTransformerRL = class {
|
|
|
1130
1604
|
getAttentionWeights() {
|
|
1131
1605
|
return this.blocks.map((b) => b.getAttentionWeights());
|
|
1132
1606
|
}
|
|
1133
|
-
// ──
|
|
1134
|
-
|
|
1607
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
1608
|
+
// Order: inputProj, block0, block1, ..., blockN, outputProj, outputBias.
|
|
1609
|
+
getWeightsFlat() {
|
|
1610
|
+
const w = [];
|
|
1611
|
+
w.push(...this.inputProj.getWeights());
|
|
1612
|
+
for (const block of this.blocks) w.push(...block.getWeights());
|
|
1613
|
+
w.push(...this.outputProj.getWeights());
|
|
1614
|
+
w.push(...this.outputBias.getWeights());
|
|
1615
|
+
return w;
|
|
1616
|
+
}
|
|
1617
|
+
setWeightsFlat(weights) {
|
|
1618
|
+
let idx = 0;
|
|
1619
|
+
const inputProjLen = this.inputProj.getWeights().length;
|
|
1620
|
+
this.inputProj.setWeights(weights.slice(idx, idx + inputProjLen));
|
|
1621
|
+
idx += inputProjLen;
|
|
1622
|
+
for (const block of this.blocks) {
|
|
1623
|
+
const blockLen = block.getWeights().length;
|
|
1624
|
+
block.setWeights(weights.slice(idx, idx + blockLen));
|
|
1625
|
+
idx += blockLen;
|
|
1626
|
+
}
|
|
1627
|
+
const outProjLen = this.outputProj.getWeights().length;
|
|
1628
|
+
this.outputProj.setWeights(weights.slice(idx, idx + outProjLen));
|
|
1629
|
+
idx += outProjLen;
|
|
1630
|
+
this.outputBias.setWeights(weights.slice(idx, idx + this.outputBias.values.length));
|
|
1631
|
+
}
|
|
1632
|
+
getWeightsStructured() {
|
|
1135
1633
|
return {
|
|
1136
1634
|
inputProj: this.inputProj.W.map((r) => [...r]),
|
|
1137
1635
|
blocks: this.blocks.map((b) => ({
|
|
@@ -1147,17 +1645,15 @@ var NetworkTransformerRL = class {
|
|
|
1147
1645
|
norm2: { gamma: [...b.norm2.gamma], beta: [...b.norm2.beta] },
|
|
1148
1646
|
ff1: b.ff1.W.map((r) => [...r]),
|
|
1149
1647
|
ff2: b.ff2.W.map((r) => [...r]),
|
|
1150
|
-
b1: [...b.b1],
|
|
1151
|
-
b2: [...b.b2]
|
|
1648
|
+
b1: [...b.b1.values],
|
|
1649
|
+
b2: [...b.b2.values]
|
|
1152
1650
|
})),
|
|
1153
1651
|
outputProj: this.outputProj.W.map((r) => [...r]),
|
|
1154
|
-
outputBias: [...this.outputBias]
|
|
1652
|
+
outputBias: [...this.outputBias.values]
|
|
1155
1653
|
};
|
|
1156
1654
|
}
|
|
1157
|
-
|
|
1158
|
-
data.inputProj.
|
|
1159
|
-
this.inputProj.W[i] = [...row];
|
|
1160
|
-
});
|
|
1655
|
+
setWeightsStructured(data) {
|
|
1656
|
+
this.inputProj.setWeights(data.inputProj.flat());
|
|
1161
1657
|
data.blocks.forEach((bd, b) => {
|
|
1162
1658
|
const blk = this.blocks[b];
|
|
1163
1659
|
bd.attn.heads.forEach((hd, h) => {
|
|
@@ -1172,11 +1668,20 @@ var NetworkTransformerRL = class {
|
|
|
1172
1668
|
blk.norm2.beta = [...bd.norm2.beta];
|
|
1173
1669
|
blk.ff1.W = bd.ff1.map((r) => [...r]);
|
|
1174
1670
|
blk.ff2.W = bd.ff2.map((r) => [...r]);
|
|
1175
|
-
blk.b1
|
|
1176
|
-
blk.b2
|
|
1671
|
+
blk.b1.setWeights(bd.b1);
|
|
1672
|
+
blk.b2.setWeights(bd.b2);
|
|
1177
1673
|
});
|
|
1178
1674
|
this.outputProj.W = data.outputProj.map((r) => [...r]);
|
|
1179
|
-
this.outputBias
|
|
1675
|
+
this.outputBias.setWeights(data.outputBias);
|
|
1676
|
+
}
|
|
1677
|
+
// ── Serializable interface (flat array) ────────────────────────────────────
|
|
1678
|
+
// These satisfy the Serializable interface from ModelSaver, which requires
|
|
1679
|
+
// getWeights(): number[] and setWeights(weights: number[]): void.
|
|
1680
|
+
getWeights() {
|
|
1681
|
+
return this.getWeightsFlat();
|
|
1682
|
+
}
|
|
1683
|
+
setWeights(weights) {
|
|
1684
|
+
this.setWeightsFlat(weights);
|
|
1180
1685
|
}
|
|
1181
1686
|
// ── Internal ────────────────────────────────────────────────────────────────
|
|
1182
1687
|
_forward(sequence) {
|
|
@@ -1191,6 +1696,44 @@ var NetworkTransformerRL = class {
|
|
|
1191
1696
|
return h;
|
|
1192
1697
|
}
|
|
1193
1698
|
_pool(h) {
|
|
1699
|
+
switch (this._pooling) {
|
|
1700
|
+
case "avg":
|
|
1701
|
+
return this._poolAvg(h);
|
|
1702
|
+
case "max":
|
|
1703
|
+
return this._poolMax(h);
|
|
1704
|
+
case "last":
|
|
1705
|
+
return this._poolLast(h);
|
|
1706
|
+
case "weighted":
|
|
1707
|
+
default:
|
|
1708
|
+
return this._poolWeighted(h);
|
|
1709
|
+
}
|
|
1710
|
+
}
|
|
1711
|
+
_poolAvg(h) {
|
|
1712
|
+
const n = h.length;
|
|
1713
|
+
return Array.from({ length: this.d_model }, (_, m) => {
|
|
1714
|
+
let sum = 0;
|
|
1715
|
+
for (let i = 0; i < n; i++)
|
|
1716
|
+
sum += h[i][m];
|
|
1717
|
+
return sum / n;
|
|
1718
|
+
});
|
|
1719
|
+
}
|
|
1720
|
+
_poolMax(h) {
|
|
1721
|
+
this._argmax = new Array(this.d_model).fill(0);
|
|
1722
|
+
return Array.from({ length: this.d_model }, (_, m) => {
|
|
1723
|
+
let maxVal = -Infinity;
|
|
1724
|
+
for (let i = 0; i < h.length; i++) {
|
|
1725
|
+
if (h[i][m] > maxVal) {
|
|
1726
|
+
maxVal = h[i][m];
|
|
1727
|
+
this._argmax[m] = i;
|
|
1728
|
+
}
|
|
1729
|
+
}
|
|
1730
|
+
return maxVal;
|
|
1731
|
+
});
|
|
1732
|
+
}
|
|
1733
|
+
_poolLast(h) {
|
|
1734
|
+
return [...h[h.length - 1]];
|
|
1735
|
+
}
|
|
1736
|
+
_poolWeighted(h) {
|
|
1194
1737
|
const weights = Array.from(
|
|
1195
1738
|
{ length: this.seqLen },
|
|
1196
1739
|
(_, i) => i === this.seqLen - 1 ? 2 : 1
|
|
@@ -1203,6 +1746,55 @@ var NetworkTransformerRL = class {
|
|
|
1203
1746
|
return sum / totalWeight;
|
|
1204
1747
|
});
|
|
1205
1748
|
}
|
|
1749
|
+
/** Returns the current pooling type for inspection. */
|
|
1750
|
+
getPoolingType() {
|
|
1751
|
+
return this._pooling;
|
|
1752
|
+
}
|
|
1753
|
+
// ── Helper: distribute pooled gradient back to each position ────────────────
|
|
1754
|
+
// Must match the same distribution as _pool() used during forward.
|
|
1755
|
+
_distributePoolGradient(dPooled) {
|
|
1756
|
+
switch (this._pooling) {
|
|
1757
|
+
case "avg": {
|
|
1758
|
+
const n = this.seqLen;
|
|
1759
|
+
return Array.from(
|
|
1760
|
+
{ length: n },
|
|
1761
|
+
() => dPooled.map((v) => v / n)
|
|
1762
|
+
);
|
|
1763
|
+
}
|
|
1764
|
+
case "max": {
|
|
1765
|
+
if (!this._argmax) {
|
|
1766
|
+
const n = this.seqLen;
|
|
1767
|
+
return Array.from(
|
|
1768
|
+
{ length: n },
|
|
1769
|
+
() => dPooled.map((v) => v / n)
|
|
1770
|
+
);
|
|
1771
|
+
}
|
|
1772
|
+
const argmax = this._argmax;
|
|
1773
|
+
return Array.from(
|
|
1774
|
+
{ length: this.seqLen },
|
|
1775
|
+
(_, i) => dPooled.map((v, m) => i === argmax[m] ? v : 0)
|
|
1776
|
+
);
|
|
1777
|
+
}
|
|
1778
|
+
case "last": {
|
|
1779
|
+
return Array.from(
|
|
1780
|
+
{ length: this.seqLen },
|
|
1781
|
+
(_, i) => i === this.seqLen - 1 ? [...dPooled] : new Array(this.d_model).fill(0)
|
|
1782
|
+
);
|
|
1783
|
+
}
|
|
1784
|
+
case "weighted":
|
|
1785
|
+
default: {
|
|
1786
|
+
const weights = Array.from(
|
|
1787
|
+
{ length: this.seqLen },
|
|
1788
|
+
(_, i) => i === this.seqLen - 1 ? 2 : 1
|
|
1789
|
+
);
|
|
1790
|
+
const totalWeight = weights.reduce((a, b) => a + b, 0);
|
|
1791
|
+
return Array.from(
|
|
1792
|
+
{ length: this.seqLen },
|
|
1793
|
+
(_, i) => dPooled.map((v) => v * weights[i] / totalWeight)
|
|
1794
|
+
);
|
|
1795
|
+
}
|
|
1796
|
+
}
|
|
1797
|
+
}
|
|
1206
1798
|
};
|
|
1207
1799
|
|
|
1208
1800
|
// src/losses.ts
|
|
@@ -1227,13 +1819,802 @@ function crossEntropyDeltaRaw(predicted, actual) {
|
|
|
1227
1819
|
const p = Math.max(eps, Math.min(1 - eps, predicted));
|
|
1228
1820
|
return actual / p - (1 - actual) / (1 - p);
|
|
1229
1821
|
}
|
|
1822
|
+
|
|
1823
|
+
// src/GRU.ts
|
|
1824
|
+
/**
 * Logistic sigmoid, 1 / (1 + e^-x).
 *
 * @param {number} x Input value.
 * @returns {number} A value between 0 and 1.
 */
function sigmoid4(x) {
  const decay = Math.exp(-x);
  return 1 / (decay + 1);
}
|
|
1827
|
+
/**
 * Hyperbolic tangent.
 *
 * Uses the built-in `Math.tanh`, which saturates at ±1 for large |x|. The
 * hand-rolled `(e^{2x} - 1) / (e^{2x} + 1)` form overflows `e^{2x}` to
 * Infinity for large positive x and yields Infinity / Infinity = NaN.
 *
 * @param {number} x Input value.
 * @returns {number} tanh(x), in [-1, 1].
 */
function tanhFn(x) {
  return Math.tanh(x);
}
|
|
1831
|
+
/**
 * One GRU gate: a weight matrix `W` (hSize × (inputSize + hSize)) and a bias
 * vector `b` (hSize). It only computes the linear part; activations are
 * applied by the caller.
 */
var Gate2 = class {
  /**
   * @param {number} inputSize Number of input features.
   * @param {number} hSize Number of hidden units (rows of W).
   * @param {number} [initBias=0] Initial value for every bias entry.
   */
  constructor(inputSize, hSize, initBias = 0) {
    const fanIn = inputSize + hSize;
    // Weights are drawn uniformly from [-limit, limit).
    const limit = Math.sqrt(2 / (fanIn + hSize));
    const drawWeight = () => (Math.random() * 2 - 1) * limit;
    const drawRow = () => Array.from({ length: fanIn }, drawWeight);
    this.W = Array.from({ length: hSize }, drawRow);
    this.b = new Array(hSize).fill(initBias);
  }

  /**
   * Computes W · combined + b.
   *
   * @param {number[]} combined Vector of length inputSize + hSize.
   * @returns {number[]} One pre-activation value per hidden unit.
   */
  linear(combined) {
    const out = [];
    for (let i = 0; i < this.W.length; i++) {
      const row = this.W[i];
      // Start from the bias so the summation order matches bias-first folding.
      let acc = this.b[i];
      for (let j = 0; j < row.length; j++) acc += row[j] * combined[j];
      out.push(acc);
    }
    return out;
  }
};
|
|
1847
|
+
var GRULayer = class {
|
|
1848
|
+
constructor(inputSize, hiddenSize, optimizerFactory = () => new SGD()) {
|
|
1849
|
+
this._traj = [];
|
|
1850
|
+
if (inputSize <= 0 || hiddenSize <= 0) {
|
|
1851
|
+
throw new Error(`GRULayer: inputSize and hiddenSize must be positive, got ${inputSize} and ${hiddenSize}`);
|
|
1852
|
+
}
|
|
1853
|
+
this.inputSize = inputSize;
|
|
1854
|
+
this.hSize = hiddenSize;
|
|
1855
|
+
this.h = new Array(hiddenSize).fill(0);
|
|
1856
|
+
this.resetGate = new Gate2(inputSize, hiddenSize);
|
|
1857
|
+
this.updateGate = new Gate2(inputSize, hiddenSize);
|
|
1858
|
+
this.newGate = new Gate2(inputSize, hiddenSize);
|
|
1859
|
+
const combSize = inputSize + hiddenSize;
|
|
1860
|
+
this._optimizers = {
|
|
1861
|
+
resetW: Array.from(
|
|
1862
|
+
{ length: hiddenSize },
|
|
1863
|
+
() => Array.from({ length: combSize }, () => optimizerFactory())
|
|
1864
|
+
),
|
|
1865
|
+
resetB: Array.from({ length: hiddenSize }, () => optimizerFactory()),
|
|
1866
|
+
updateW: Array.from(
|
|
1867
|
+
{ length: hiddenSize },
|
|
1868
|
+
() => Array.from({ length: combSize }, () => optimizerFactory())
|
|
1869
|
+
),
|
|
1870
|
+
updateB: Array.from({ length: hiddenSize }, () => optimizerFactory()),
|
|
1871
|
+
newW: Array.from(
|
|
1872
|
+
{ length: hiddenSize },
|
|
1873
|
+
() => Array.from({ length: combSize }, () => optimizerFactory())
|
|
1874
|
+
),
|
|
1875
|
+
newB: Array.from({ length: hiddenSize }, () => optimizerFactory())
|
|
1876
|
+
};
|
|
1877
|
+
}
|
|
1878
|
+
reset() {
|
|
1879
|
+
this.h = new Array(this.hSize).fill(0);
|
|
1880
|
+
this._traj = [];
|
|
1881
|
+
}
|
|
1882
|
+
predict(inputs) {
|
|
1883
|
+
if (!Array.isArray(inputs) || inputs.length !== this.inputSize) {
|
|
1884
|
+
throw new Error(`GRULayer.predict: expected array of length ${this.inputSize}, got ${inputs?.length}`);
|
|
1885
|
+
}
|
|
1886
|
+
const combined = [...inputs, ...this.h];
|
|
1887
|
+
const h_prev = [...this.h];
|
|
1888
|
+
const r_pre = this.resetGate.linear(combined);
|
|
1889
|
+
const z_pre = this.updateGate.linear(combined);
|
|
1890
|
+
const r_a = r_pre.map(sigmoid4);
|
|
1891
|
+
const z_a = z_pre.map(sigmoid4);
|
|
1892
|
+
const combined_r = [...inputs, ...r_a.map((r, i) => r * h_prev[i])];
|
|
1893
|
+
const n_pre = this.newGate.linear(combined_r);
|
|
1894
|
+
const n_a = n_pre.map(tanhFn);
|
|
1895
|
+
const h = n_a.map((n, i) => (1 - z_a[i]) * n + z_a[i] * h_prev[i]);
|
|
1896
|
+
this._traj.push({ combined, h_prev, r: r_pre, r_a, z: z_pre, z_a, combined_r, n_pre, n_a, h });
|
|
1897
|
+
this.h = h;
|
|
1898
|
+
return h;
|
|
1899
|
+
}
|
|
1900
|
+
backprop(dh_seq, lr) {
|
|
1901
|
+
const T = this._traj.length;
|
|
1902
|
+
if (T === 0 || dh_seq.length !== T) return;
|
|
1903
|
+
const hSize = this.hSize;
|
|
1904
|
+
const combSize = this.inputSize + hSize;
|
|
1905
|
+
const dWr = Array.from({ length: hSize }, () => new Array(combSize).fill(0));
|
|
1906
|
+
const dWz = Array.from({ length: hSize }, () => new Array(combSize).fill(0));
|
|
1907
|
+
const dWn = Array.from({ length: hSize }, () => new Array(combSize).fill(0));
|
|
1908
|
+
const dbr = new Array(hSize).fill(0);
|
|
1909
|
+
const dbz = new Array(hSize).fill(0);
|
|
1910
|
+
const dbn = new Array(hSize).fill(0);
|
|
1911
|
+
let dh_next = new Array(hSize).fill(0);
|
|
1912
|
+
for (let t = T - 1; t >= 0; t--) {
|
|
1913
|
+
const s = this._traj[t];
|
|
1914
|
+
const dh = dh_seq[t].map((d, i) => d + dh_next[i]);
|
|
1915
|
+
const dz_a = dh.map((d, i) => (s.h_prev[i] - s.n_a[i]) * d);
|
|
1916
|
+
const dn_a = dh.map((d, i) => (1 - s.z_a[i]) * d);
|
|
1917
|
+
const dn_pre = dn_a.map((d, i) => d * (1 - s.n_a[i] ** 2));
|
|
1918
|
+
const dz_pre = dz_a.map((d, i) => d * s.z_a[i] * (1 - s.z_a[i]));
|
|
1919
|
+
const dr_hprev = Array.from(
|
|
1920
|
+
{ length: hSize },
|
|
1921
|
+
(_, i) => this.newGate.W.reduce((sum, row, k) => sum + dn_pre[k] * row[this.inputSize + i], 0)
|
|
1922
|
+
);
|
|
1923
|
+
const dr_a = dr_hprev.map((d, i) => d * s.h_prev[i]);
|
|
1924
|
+
const dr_pre = dr_a.map((d, i) => d * s.r_a[i] * (1 - s.r_a[i]));
|
|
1925
|
+
for (let k = 0; k < hSize; k++) {
|
|
1926
|
+
for (let j = 0; j < combSize; j++) {
|
|
1927
|
+
dWr[k][j] += dr_pre[k] * s.combined[j];
|
|
1928
|
+
dWz[k][j] += dz_pre[k] * s.combined[j];
|
|
1929
|
+
dWn[k][j] += dn_pre[k] * s.combined_r[j];
|
|
1930
|
+
}
|
|
1931
|
+
dbr[k] += dr_pre[k];
|
|
1932
|
+
dbz[k] += dz_pre[k];
|
|
1933
|
+
dbn[k] += dn_pre[k];
|
|
1934
|
+
}
|
|
1935
|
+
dh_next = new Array(hSize).fill(0);
|
|
1936
|
+
for (let k = 0; k < hSize; k++) {
|
|
1937
|
+
for (let j = this.inputSize; j < combSize; j++) {
|
|
1938
|
+
dh_next[j - this.inputSize] += dr_pre[k] * this.resetGate.W[k][j] + dz_pre[k] * this.updateGate.W[k][j];
|
|
1939
|
+
}
|
|
1940
|
+
dh_next[k] += dr_hprev[k] * s.r_a[k];
|
|
1941
|
+
dh_next[k] += dh[k] * s.z_a[k];
|
|
1942
|
+
}
|
|
1943
|
+
}
|
|
1944
|
+
const scale = lr / T;
|
|
1945
|
+
for (let k = 0; k < hSize; k++) {
|
|
1946
|
+
for (let j = 0; j < combSize; j++) {
|
|
1947
|
+
this.resetGate.W[k][j] = this._optimizers.resetW[k][j].step(this.resetGate.W[k][j], dWr[k][j], scale);
|
|
1948
|
+
this.updateGate.W[k][j] = this._optimizers.updateW[k][j].step(this.updateGate.W[k][j], dWz[k][j], scale);
|
|
1949
|
+
this.newGate.W[k][j] = this._optimizers.newW[k][j].step(this.newGate.W[k][j], dWn[k][j], scale);
|
|
1950
|
+
}
|
|
1951
|
+
this.resetGate.b[k] = this._optimizers.resetB[k].step(this.resetGate.b[k], dbr[k], scale);
|
|
1952
|
+
this.updateGate.b[k] = this._optimizers.updateB[k].step(this.updateGate.b[k], dbz[k], scale);
|
|
1953
|
+
this.newGate.b[k] = this._optimizers.newB[k].step(this.newGate.b[k], dbn[k], scale);
|
|
1954
|
+
}
|
|
1955
|
+
this._traj = [];
|
|
1956
|
+
}
|
|
1957
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
1958
|
+
// Order: resetGate (W, b), updateGate (W, b), newGate (W, b).
|
|
1959
|
+
getWeightsFlat() {
|
|
1960
|
+
const w = [];
|
|
1961
|
+
for (const row of this.resetGate.W) w.push(...row);
|
|
1962
|
+
w.push(...this.resetGate.b);
|
|
1963
|
+
for (const row of this.updateGate.W) w.push(...row);
|
|
1964
|
+
w.push(...this.updateGate.b);
|
|
1965
|
+
for (const row of this.newGate.W) w.push(...row);
|
|
1966
|
+
w.push(...this.newGate.b);
|
|
1967
|
+
return w;
|
|
1968
|
+
}
|
|
1969
|
+
setWeightsFlat(weights) {
|
|
1970
|
+
let idx = 0;
|
|
1971
|
+
for (let i = 0; i < this.resetGate.W.length; i++)
|
|
1972
|
+
for (let j = 0; j < this.resetGate.W[i].length; j++) this.resetGate.W[i][j] = weights[idx++];
|
|
1973
|
+
for (let i = 0; i < this.resetGate.b.length; i++) this.resetGate.b[i] = weights[idx++];
|
|
1974
|
+
for (let i = 0; i < this.updateGate.W.length; i++)
|
|
1975
|
+
for (let j = 0; j < this.updateGate.W[i].length; j++) this.updateGate.W[i][j] = weights[idx++];
|
|
1976
|
+
for (let i = 0; i < this.updateGate.b.length; i++) this.updateGate.b[i] = weights[idx++];
|
|
1977
|
+
for (let i = 0; i < this.newGate.W.length; i++)
|
|
1978
|
+
for (let j = 0; j < this.newGate.W[i].length; j++) this.newGate.W[i][j] = weights[idx++];
|
|
1979
|
+
for (let i = 0; i < this.newGate.b.length; i++) this.newGate.b[i] = weights[idx++];
|
|
1980
|
+
}
|
|
1981
|
+
getWeights() {
|
|
1982
|
+
return {
|
|
1983
|
+
resetGate: { W: this.resetGate.W, b: this.resetGate.b },
|
|
1984
|
+
updateGate: { W: this.updateGate.W, b: this.updateGate.b },
|
|
1985
|
+
newGate: { W: this.newGate.W, b: this.newGate.b }
|
|
1986
|
+
};
|
|
1987
|
+
}
|
|
1988
|
+
setWeights(data) {
|
|
1989
|
+
this.resetGate.W = data.resetGate.W;
|
|
1990
|
+
this.resetGate.b = data.resetGate.b;
|
|
1991
|
+
this.updateGate.W = data.updateGate.W;
|
|
1992
|
+
this.updateGate.b = data.updateGate.b;
|
|
1993
|
+
this.newGate.W = data.newGate.W;
|
|
1994
|
+
this.newGate.b = data.newGate.b;
|
|
1995
|
+
}
|
|
1996
|
+
};
|
|
1997
|
+
|
|
1998
|
+
// src/BatchNorm.ts
|
|
1999
|
+
var BatchNorm = class {
  /**
   * Per-feature normalization layer that processes one sample at a time and
   * normalizes it against exponentially averaged running statistics.
   *
   * @param {number} dim - Number of features in every input vector.
   * @param {number} [momentum=0.1] - Weight applied to the PREVIOUS running
   *   value on each update; the incoming sample is weighted `1 - momentum`.
   */
  constructor(dim, momentum = 0.1) {
    // Caches written by forward() and read by backward() / trainParams().
    this._xNorm = null;
    this._std = null;
    this.dim = dim;
    this.momentum = momentum;
    // Learnable scale and shift.
    this.gamma = new Array(dim).fill(1);
    this.beta = new Array(dim).fill(0);
    // Running statistics, updated on every forward() call.
    this.runningMean = new Array(dim).fill(0);
    this.runningVar = new Array(dim).fill(1);
  }
  // ── Forward ───────────────────────────────────────────────────────────────
  /**
   * Updates the running mean/variance with `x`, then normalizes `x` with the
   * freshly updated statistics and applies gamma/beta.
   *
   * @param {number[]} x - Input vector of length `dim`.
   * @returns {number[]} `gamma * xNorm + beta`, element-wise.
   * @throws {Error} If `x.length` differs from `dim`.
   */
  forward(x) {
    if (x.length !== this.dim) {
      throw new Error(`BatchNorm.forward: expected array of length ${this.dim}, got ${x.length}`);
    }
    const eps = 1e-5;
    for (let i = 0; i < this.dim; i++) {
      this.runningMean[i] = this.momentum * this.runningMean[i] + (1 - this.momentum) * x[i];
      // The deviation is measured against the already-updated mean.
      const diff = x[i] - this.runningMean[i];
      this.runningVar[i] = this.momentum * this.runningVar[i] + (1 - this.momentum) * diff * diff;
    }
    // eps keeps the division finite when the running variance is zero.
    this._std = this.runningVar.map((v) => Math.sqrt(v + eps));
    this._xNorm = x.map((v, i) => (v - this.runningMean[i]) / this._std[i]);
    return this._xNorm.map((xn, i) => this.gamma[i] * xn + this.beta[i]);
  }
  // ── Backward ──────────────────────────────────────────────────────────────
  /**
   * Maps the gradient w.r.t. the output to a gradient w.r.t. the input using
   * the std cached by the last forward() call: `dOut * gamma / std`.
   *
   * @param {number[]} dOut - Gradient with respect to the layer output.
   * @returns {number[]} Gradient with respect to the layer input.
   * @throws {Error} If forward() has not been called yet.
   */
  backward(dOut) {
    if (!this._xNorm || !this._std) {
      throw new Error("BatchNorm.backward: call forward() first");
    }
    return dOut.map((d, i) => d * this.gamma[i] / this._std[i]);
  }
  // ── Train gamma and beta (call after backward) ────────────────────────────
  /**
   * Adds `lr * dOut[i] * xNorm[i]` to gamma and `lr * dOut[i]` to beta.
   * Does nothing if forward() has not been called yet.
   * NOTE(review): the update is additive, so `dOut` is presumably an error
   * signal that already points in the improvement direction — confirm
   * against callers.
   *
   * @param {number[]} dOut - Output-side signal, one value per feature.
   * @param {number} lr - Step size.
   */
  trainParams(dOut, lr) {
    if (!this._xNorm) return;
    for (let i = 0; i < this.dim; i++) {
      this.gamma[i] += lr * dOut[i] * this._xNorm[i];
      this.beta[i] += lr * dOut[i];
    }
  }
  // ── Flat weight serialization ─────────────────────────────────────────────
  // Order: gamma, beta.
  /** @returns {number[]} A new array: all gamma values followed by all beta values. */
  getWeights() {
    return [...this.gamma, ...this.beta];
  }
  /**
   * Loads gamma and beta from a flat array laid out as getWeights() returns it.
   *
   * @param {number[]} weights - At least `2 * dim` values (gamma, then beta).
   * @throws {Error} If fewer than `2 * dim` values are supplied; without this
   *   check the missing entries would be stored as `undefined`.
   */
  setWeights(weights) {
    const needed = 2 * this.dim;
    if (weights.length < needed) {
      throw new Error(`BatchNorm.setWeights: expected at least ${needed} values, got ${weights.length}`);
    }
    for (let i = 0; i < this.dim; i++) this.gamma[i] = weights[i];
    for (let i = 0; i < this.dim; i++) this.beta[i] = weights[this.dim + i];
  }
};
|
|
2052
|
+
|
|
2053
|
+
// src/Conv1D.ts
|
|
2054
|
+
var Conv1D = class {
  /**
   * 1D convolution layer over a sequence of `inputLength` positions, each
   * holding `inputChannels` values.
   *
   * @param {number} inputLength - Number of positions in the input sequence.
   * @param {number} kernelSize - Width of every kernel.
   * @param {number} filters - Number of kernels (output channels).
   * @param {number} [stride=1] - Step between kernel positions; positive integer.
   * @param {string} [padding="valid"] - "same" zero-pads the sequence; any
   *   other value applies no padding.
   * @param {Function} [optimizerFactory] - Called once per weight and per
   *   bias; must return an object with `step(value, gradient, lr)` that
   *   returns the updated value.
   * @param {number} [inputChannels=1] - Values per input position.
   * @throws {Error} On non-positive sizes, an invalid stride, a kernel wider
   *   than the input with "valid" padding, or `inputChannels < 1`.
   */
  constructor(inputLength, kernelSize, filters, stride = 1, padding = "valid", optimizerFactory = () => new SGD(), inputChannels = 1) {
    // Caches written by forward() and read by backward().
    this._input = null;
    this._paddedInput = null;
    if (inputLength <= 0 || kernelSize <= 0 || filters <= 0) {
      throw new Error("Conv1D: inputLength, kernelSize, and filters must be positive");
    }
    if (kernelSize > inputLength && padding === "valid") {
      throw new Error("Conv1D: kernelSize cannot exceed inputLength with valid padding");
    }
    if (inputChannels < 1) {
      throw new Error("Conv1D: inputChannels must be >= 1");
    }
    // A zero, negative or fractional stride would index outside the padded
    // input or produce a nonsensical output length.
    if (!Number.isInteger(stride) || stride < 1) {
      throw new Error(`Conv1D: stride must be a positive integer, got ${stride}`);
    }
    this.inputLength = inputLength;
    this.kernelSize = kernelSize;
    this.filters = filters;
    this.stride = stride;
    this.padding = padding;
    this.inputChannels = inputChannels;
    // Kernels are drawn uniformly from [-limit, limit]; layout is
    // kernels[filter][kernelPosition][channel].
    const limit = Math.sqrt(2 / (kernelSize * inputChannels));
    this.kernels = Array.from(
      { length: filters },
      () => Array.from(
        { length: kernelSize },
        () => Array.from({ length: inputChannels }, () => (Math.random() * 2 - 1) * limit)
      )
    );
    this.biases = new Array(filters).fill(0);
    // One optimizer instance per scalar weight and per bias.
    this._kOpts = Array.from(
      { length: filters },
      () => Array.from(
        { length: kernelSize },
        () => Array.from({ length: inputChannels }, () => optimizerFactory())
      )
    );
    this._bOpts = Array.from({ length: filters }, () => optimizerFactory());
  }
  // ── Forward ───────────────────────────────────────────────────────────────
  // Accepts either number[] (when inputChannels=1) or number[][] (multi-channel).
  /**
   * @param {number[]|number[][]} input - Sequence of length `inputLength`.
   * @returns {number[][]} `output[filter][position]`.
   * @throws {Error} If the input is empty or has the wrong shape.
   */
  forward(input) {
    const input2D = this._normalizeInput(input);
    this._input = input2D.map((row) => [...row]);
    let padded = input2D;
    if (this.padding === "same") {
      const { left, right } = this._samePadSizes();
      const zeroRows = (count) => Array.from(
        { length: count },
        () => new Array(this.inputChannels).fill(0)
      );
      padded = [...zeroRows(left), ...input2D, ...zeroRows(right)];
    }
    this._paddedInput = padded;
    const outputLength = Math.floor((padded.length - this.kernelSize) / this.stride) + 1;
    const output = Array.from(
      { length: this.filters },
      () => new Array(outputLength).fill(0)
    );
    for (let f = 0; f < this.filters; f++) {
      for (let pos = 0; pos < outputLength; pos++) {
        const start = pos * this.stride;
        let sum = this.biases[f];
        for (let k = 0; k < this.kernelSize; k++) {
          for (let c = 0; c < this.inputChannels; c++) {
            sum += this.kernels[f][k][c] * padded[start + k][c];
          }
        }
        output[f][pos] = sum;
      }
    }
    return output;
  }
  // ── Backward ──────────────────────────────────────────────────────────────
  /**
   * Accumulates kernel, bias and input gradients from `dOut`, applies one
   * optimizer step to every kernel weight and bias, and returns the gradient
   * with respect to the unpadded input.
   *
   * @param {number[][]} dOut - `dOut[filter][position]`, same shape forward() returned.
   * @param {number} [lr=1e-3] - Learning rate passed to each optimizer `step`.
   * @returns {number[][]} `inputLength` rows of `inputChannels` values
   *   (always 2D, even when forward() was given a 1D input).
   * @throws {Error} If forward() has not been called yet.
   */
  backward(dOut, lr = 1e-3) {
    if (!this._paddedInput || !this._input) {
      throw new Error("Conv1D.backward: call forward() first");
    }
    const padded = this._paddedInput;
    const outputLength = dOut[0].length;
    const dKernels = Array.from(
      { length: this.filters },
      () => Array.from(
        { length: this.kernelSize },
        () => new Array(this.inputChannels).fill(0)
      )
    );
    const dBiases = new Array(this.filters).fill(0);
    const dPadded = padded.map(() => new Array(this.inputChannels).fill(0));
    for (let f = 0; f < this.filters; f++) {
      for (let pos = 0; pos < outputLength; pos++) {
        const start = pos * this.stride;
        dBiases[f] += dOut[f][pos];
        for (let k = 0; k < this.kernelSize; k++) {
          for (let c = 0; c < this.inputChannels; c++) {
            dKernels[f][k][c] += dOut[f][pos] * padded[start + k][c];
            // Uses the pre-update kernel: weights are stepped only after
            // every gradient has been accumulated.
            dPadded[start + k][c] += dOut[f][pos] * this.kernels[f][k][c];
          }
        }
      }
    }
    for (let f = 0; f < this.filters; f++) {
      for (let k = 0; k < this.kernelSize; k++) {
        for (let c = 0; c < this.inputChannels; c++) {
          this.kernels[f][k][c] = this._kOpts[f][k][c].step(this.kernels[f][k][c], dKernels[f][k][c], lr);
        }
      }
      this.biases[f] = this._bOpts[f].step(this.biases[f], dBiases[f], lr);
    }
    // Drop the gradient rows that belong to the zero padding.
    const offset = this.padding === "same" ? this._samePadSizes().left : 0;
    return dPadded.slice(offset, offset + this.inputLength);
  }
  // ── Output length ─────────────────────────────────────────────────────────
  /** @returns {number} Number of output positions per filter produced by forward(). */
  getOutputLength() {
    if (this.padding === "same") {
      return Math.ceil(this.inputLength / this.stride);
    }
    return Math.floor((this.inputLength - this.kernelSize) / this.stride) + 1;
  }
  // ── Flat weight serialization ─────────────────────────────────────────────
  // Order: kernels (flattened), biases.
  /** @returns {number[]} Kernels flattened filter → position → channel, then biases. */
  getWeights() {
    const w = [];
    for (const kernel of this.kernels) {
      for (const k of kernel) {
        for (const c of k) {
          w.push(c);
        }
      }
    }
    w.push(...this.biases);
    return w;
  }
  /** @param {number[]} weights - Flat array in the order getWeights() returns. */
  setWeights(weights) {
    let idx = 0;
    for (let f = 0; f < this.filters; f++) {
      for (let k = 0; k < this.kernelSize; k++) {
        for (let c = 0; c < this.inputChannels; c++) {
          this.kernels[f][k][c] = weights[idx++];
        }
      }
    }
    for (let f = 0; f < this.filters; f++) {
      this.biases[f] = weights[idx++];
    }
  }
  // ── "same" padding sizes ─────────────────────────────────────────────────
  // Total padding is kernelSize - 1 so that the output length equals
  // ceil(inputLength / stride), matching getOutputLength(). For an even
  // kernel the extra zero row goes on the right.
  _samePadSizes() {
    const total = this.kernelSize - 1;
    const left = Math.floor(total / 2);
    return { left, right: total - left };
  }
  // ── Normalize input to 2D format ─────────────────────────────────────────
  /**
   * Validates the input shape and returns it as `[position][channel]`.
   * A 1D input is wrapped into single-value rows; a 2D input is returned as is.
   */
  _normalizeInput(input) {
    if (input.length === 0) {
      throw new Error("Conv1D.forward: input cannot be empty");
    }
    if (typeof input[0] === "number") {
      if (this.inputChannels !== 1) {
        throw new Error(`Conv1D.forward: expected 2D input with ${this.inputChannels} channels, got 1D`);
      }
      if (input.length !== this.inputLength) {
        throw new Error(`Conv1D.forward: expected input of length ${this.inputLength}, got ${input.length}`);
      }
      return input.map((v) => [v]);
    }
    if (input.length !== this.inputLength) {
      throw new Error(`Conv1D.forward: expected input of length ${this.inputLength}, got ${input.length}`);
    }
    for (let i = 0; i < input.length; i++) {
      if (input[i].length !== this.inputChannels) {
        throw new Error(`Conv1D.forward: expected ${this.inputChannels} channels at position ${i}, got ${input[i].length}`);
      }
    }
    return input;
  }
};
|
|
2221
|
+
|
|
2222
|
+
// src/Trainer.ts
|
|
2223
|
+
var Trainer = class {
  /**
   * Epoch-based training loop around a network exposing
   * `train(input, target, lr)` that returns a loss value.
   *
   * @param {object} network - Must implement `train`. Weight decay, early
   *   stopping and metrics additionally require `getWeights`, `setWeights`
   *   and `predict`.
   * @param {object} [options]
   * @param {number} [options.epochs=1000] - Maximum number of epochs.
   * @param {number} [options.lr=0.1] - Initial learning rate.
   * @param {number} [options.lrDecay=1] - lr is multiplied by this after each epoch.
   * @param {boolean} [options.verbose=false] - Log every 100th epoch.
   * @param {number} [options.weightDecay=0] - Multiplicative weight decay factor.
   * @param {object} [options.earlyStopping] - `{ patience, minDelta }`;
   *   `minDelta` defaults to 0 when omitted.
   * @param {boolean} [options.computeMetrics=false] - Record per-epoch metrics.
   * @param {number} [options.clipValue=0] - Stored on the instance; not read
   *   by this class.
   */
  constructor(network, options = {}) {
    this._history = [];
    this._bestLoss = Infinity;
    this._patienceCounter = 0;
    this._stopReason = "maxEpochs";
    this._metrics = [];
    this.network = network;
    this.epochs = options.epochs ?? 1e3;
    this.lrInitial = options.lr ?? 0.1;
    this.lrDecay = options.lrDecay ?? 1;
    this.verbose = options.verbose ?? false;
    this.weightDecay = options.weightDecay ?? 0;
    this._earlyStopping = options.earlyStopping;
    this._computeMetrics = options.computeMetrics ?? false;
    this.clipValue = options.clipValue ?? 0;
  }
  // ── Set external validation data (for early stopping) ────────────────────
  /**
   * @param {{inputs: Array, targets: Array}} dataset - Held-out samples.
   * @throws {Error} If inputs and targets differ in length.
   */
  setValidationData(dataset) {
    if (dataset.inputs.length !== dataset.targets.length) {
      throw new Error(
        "Trainer.setValidationData: inputs and targets must have the same length"
      );
    }
    this._validationData = dataset;
  }
  // ── Get best validation loss during training ─────────────────────────────
  /** @returns {number} Best validation loss seen, or -1 if none was recorded. */
  getBestLoss() {
    return this._bestLoss === Infinity ? -1 : this._bestLoss;
  }
  // ── Why did training stop? ───────────────────────────────────────────────
  /** @returns {string} "maxEpochs" or "earlyStopping". */
  getStopReason() {
    return this._stopReason;
  }
  // ── Get per-epoch classification metrics ─────────────────────────────────
  /** @returns {object[]} Copy of the per-epoch metric records. */
  getMetrics() {
    return [...this._metrics];
  }
  // ── Train on dataset ──────────────────────────────────────────────────────
  /**
   * Runs up to `epochs` passes over the dataset in a freshly shuffled order
   * each epoch, calling `network.train` once per sample.
   *
   * Early stopping only runs when `earlyStopping` was configured, the network
   * supports `predict`, and a NON-EMPTY validation set was supplied through
   * setValidationData().
   *
   * @param {{inputs: Array, targets: Array}} dataset - Training samples.
   * @returns {number[]} Mean training loss per epoch (the internal history array).
   * @throws {Error} If inputs and targets differ in length.
   */
  train(dataset) {
    const { inputs, targets } = dataset;
    if (inputs.length !== targets.length) {
      throw new Error(
        "Trainer.train: inputs and targets must have the same length"
      );
    }
    const n = inputs.length;
    let lr = this.lrInitial;
    // Reset all per-run state so the trainer can be reused.
    this._history = [];
    this._bestLoss = Infinity;
    this._patienceCounter = 0;
    this._stopReason = "maxEpochs";
    this._metrics = [];
    const netExt = this._hasWeights(this.network);
    const hasExt = netExt !== null;
    if (this.weightDecay > 0 && !hasExt) {
      console.warn(
        "Trainer: weightDecay requires a network with getWeights/setWeights/predict. Skipping weight decay."
      );
    }
    if (this._earlyStopping && !hasExt) {
      console.warn(
        "Trainer: earlyStopping requires a network with predict(). Skipping early stopping."
      );
    }
    if (this._computeMetrics && !hasExt) {
      console.warn(
        "Trainer: computeMetrics requires a network with predict(). Skipping metrics."
      );
    }
    const canDecay = this.weightDecay > 0 && hasExt;
    // An empty validation set would yield a NaN loss (0 / 0), which never
    // counts as an improvement and would trigger a spurious early stop.
    const hasValidationSamples = !!this._validationData && this._validationData.inputs.length > 0;
    const canValidate = !!this._earlyStopping && hasExt && hasValidationSamples;
    // A missing minDelta must behave as 0; `best - undefined` is NaN and
    // would make every comparison below fail.
    const minDelta = this._earlyStopping?.minDelta ?? 0;
    const canMetric = !!this._computeMetrics && hasExt;
    const isClass = canMetric && this._isClassification(targets);
    if (canMetric && !isClass) {
      console.warn(
        "Trainer: computeMetrics is set but targets do not appear to be one-hot or single-class. Metrics will be skipped."
      );
    }
    for (let epoch = 0; epoch < this.epochs; epoch++) {
      // Fisher-Yates shuffle of the visiting order.
      const indices = Array.from({ length: n }, (_, i) => i);
      for (let i = n - 1; i > 0; i--) {
        const j = Math.floor(Math.random() * (i + 1));
        [indices[i], indices[j]] = [indices[j], indices[i]];
      }
      let epochLoss = 0;
      for (const i of indices) {
        if (canDecay) {
          // Shrink every weight before the sample's training step.
          const w = netExt.getWeights();
          for (let j = 0; j < w.length; j++) {
            w[j] *= 1 - lr * this.weightDecay;
          }
          netExt.setWeights(w);
        }
        epochLoss += this.network.train(inputs[i], targets[i], lr);
      }
      epochLoss /= n;
      this._history.push(epochLoss);
      if (canMetric && isClass) {
        this._metrics.push(this._computeMetricsArray(netExt, inputs, targets));
      }
      if (canValidate) {
        const valLoss = this._computeLoss(netExt, this._validationData);
        if (valLoss < this._bestLoss - minDelta) {
          this._bestLoss = valLoss;
          this._patienceCounter = 0;
        } else {
          this._patienceCounter++;
        }
        if (this._patienceCounter >= this._earlyStopping.patience) {
          this._stopReason = "earlyStopping";
          break;
        }
      }
      lr *= this.lrDecay;
      if (this.verbose && (epoch + 1) % 100 === 0) {
        console.log(
          `Epoch ${epoch + 1}/${this.epochs}, loss: ${epochLoss.toFixed(6)}, lr: ${lr.toFixed(6)}`
        );
      }
    }
    return this._history;
  }
  // ── Get loss history ──────────────────────────────────────────────────────
  /** @returns {number[]} Copy of the per-epoch mean training loss. */
  getHistory() {
    return [...this._history];
  }
  // ── Private helpers ───────────────────────────────────────────────────────
  /** Type guard: does this network support getWeights/setWeights/predict? */
  _hasWeights(network) {
    const required = ["getWeights", "setWeights", "predict"];
    const supported = required.every(
      (name) => name in network && typeof network[name] === "function"
    );
    return supported ? network : null;
  }
  /** Mean squared error on a dataset (used for validation loss). */
  _computeLoss(net, data) {
    let totalLoss = 0;
    for (let i = 0; i < data.inputs.length; i++) {
      const pred = net.predict(data.inputs[i]);
      const target = data.targets[i];
      let sampleLoss = 0;
      for (let j = 0; j < pred.length; j++) {
        sampleLoss += (target[j] - pred[j]) ** 2;
      }
      totalLoss += sampleLoss / pred.length;
    }
    return totalLoss / data.inputs.length;
  }
  /**
   * Heuristic: are targets classification-style (one-hot or single-class)?
   * Single-value targets always qualify. Otherwise every value must lie
   * within 0.01 of 0 or of 1 and every row must sum to 1 (± 0.01).
   */
  _isClassification(targets) {
    if (targets.length === 0) return false;
    const first = targets[0];
    if (first.length === 1) return true;
    for (const t of targets) {
      let sum = 0;
      for (const v of t) {
        sum += v;
        const belowZero = v < -0.01;
        const betweenZeroAndOne = v > 0.01 && v < 0.99;
        const aboveOne = v > 1.01;
        if (belowZero || betweenZeroAndOne || aboveOne) return false;
      }
      if (Math.abs(sum - 1) > 0.01) return false;
    }
    return true;
  }
  /**
   * Compute classification metrics from predictions vs targets.
   * Builds a confusion matrix (`confusion[trueClass][predictedClass]`) and
   * returns accuracy plus macro-averaged precision/recall; f1 is the
   * harmonic mean of those two macro averages.
   */
  _computeMetricsArray(net, inputs, targets) {
    const targetLen = targets[0].length;
    // A single-value target is treated as binary with a 0.5 threshold.
    const nClasses = targetLen === 1 ? 2 : targetLen;
    const confusion = Array.from(
      { length: nClasses },
      () => Array(nClasses).fill(0)
    );
    const argmax = (values) => values.indexOf(Math.max(...values));
    const clampClass = (c) => Math.max(0, Math.min(nClasses - 1, c));
    for (let i = 0; i < inputs.length; i++) {
      const pred = net.predict(inputs[i]);
      const target = targets[i];
      let predClass;
      let trueClass;
      if (targetLen === 1) {
        trueClass = target[0] >= 0.5 ? 1 : 0;
        predClass = pred.length === 1 ? (pred[0] >= 0.5 ? 1 : 0) : argmax(pred);
      } else {
        predClass = argmax(pred);
        trueClass = argmax(target);
      }
      confusion[clampClass(trueClass)][clampClass(predClass)]++;
    }
    let totalCorrect = 0;
    let totalSamples = 0;
    const precisions = [];
    const recalls = [];
    for (let c = 0; c < nClasses; c++) {
      const tp = confusion[c][c];
      totalCorrect += tp;
      let colSum = 0;
      let rowSum = 0;
      for (let r = 0; r < nClasses; r++) {
        colSum += confusion[r][c];
        rowSum += confusion[c][r];
      }
      totalSamples += rowSum;
      // Classes that never occur contribute 0 rather than dividing by zero.
      precisions.push(colSum > 0 ? tp / colSum : 0);
      recalls.push(rowSum > 0 ? tp / rowSum : 0);
    }
    const accuracy = totalSamples > 0 ? totalCorrect / totalSamples : 0;
    const macroPrecision = precisions.reduce((a, b) => a + b, 0) / nClasses;
    const macroRecall = recalls.reduce((a, b) => a + b, 0) / nClasses;
    const f1 = macroPrecision + macroRecall > 0 ? 2 * macroPrecision * macroRecall / (macroPrecision + macroRecall) : 0;
    return {
      accuracy,
      precision: macroPrecision,
      recall: macroRecall,
      f1
    };
  }
};
|
|
2445
|
+
|
|
2446
|
+
// src/DataLoader.ts
|
|
2447
|
+
var DataLoader = class _DataLoader {
  /**
   * Batches a dataset and optionally holds out a validation split.
   * The sample order is shuffled once at construction, before the split.
   *
   * @param {{inputs: Array, targets: Array}} data - Parallel arrays of samples.
   * @param {number} [batchSize=1] - Samples per batch; must be >= 1.
   * @param {number} [validationSplit=0] - Fraction in [0, 1) held out for validation.
   * @throws {Error} If inputs/targets differ in length, validationSplit is
   *   outside [0, 1), or batchSize is below 1.
   */
  constructor(data, batchSize = 1, validationSplit = 0) {
    if (data.inputs.length !== data.targets.length) {
      throw new Error("DataLoader: inputs and targets must have the same length");
    }
    if (validationSplit < 0 || validationSplit >= 1) {
      throw new Error(`DataLoader: validationSplit must be in [0, 1), got ${validationSplit}`);
    }
    // With batchSize <= 0 next() would never advance and hasNext() would stay
    // true forever; NaN is rejected by the same comparison.
    if (!(batchSize >= 1)) {
      throw new Error(`DataLoader: batchSize must be >= 1, got ${batchSize}`);
    }
    this.data = data;
    this.batchSize = batchSize;
    this._validationSplit = validationSplit;
    const total = data.inputs.length;
    const fullIndices = Array.from({ length: total }, (_, i) => i);
    _DataLoader._shuffleInPlace(fullIndices);
    if (validationSplit > 0) {
      const valSize = Math.round(total * validationSplit);
      const trainSize = total - valSize;
      this._trainIndices = fullIndices.slice(0, trainSize);
      this._valIndices = fullIndices.slice(trainSize);
    } else {
      this._trainIndices = [...fullIndices];
      this._valIndices = [];
    }
    // _indices is the order the current pass iterates in; _pos is the cursor.
    this._indices = [...this._trainIndices];
    this._pos = 0;
  }
  // ── Shuffle the training data ──────────────────────────────────────────────
  /** Reshuffles the training indices (validation untouched) and restarts iteration. */
  shuffle() {
    _DataLoader._shuffleInPlace(this._trainIndices);
    this._indices = [...this._trainIndices];
    this._pos = 0;
  }
  // ── Check if more batches are available ───────────────────────────────────
  /** @returns {boolean} True while the current pass has samples left. */
  hasNext() {
    return this._pos < this._indices.length;
  }
  // ── Get next batch ────────────────────────────────────────────────────────
  /**
   * @returns {{inputs: Array, targets: Array}} Up to `batchSize` samples; the
   *   final batch of a pass may be smaller, and is empty once the pass is done.
   */
  next() {
    const end = Math.min(this._pos + this.batchSize, this._indices.length);
    const batchIndices = this._indices.slice(this._pos, end);
    this._pos = end;
    return {
      inputs: batchIndices.map((i) => this.data.inputs[i]),
      targets: batchIndices.map((i) => this.data.targets[i])
    };
  }
  // ── Reset iteration ───────────────────────────────────────────────────────
  /** Restarts the pass from the first batch without reshuffling. */
  reset() {
    this._pos = 0;
  }
  // ── Get total number of training samples ───────────────────────────────────
  get length() {
    return this._trainIndices.length;
  }
  // ── Get validation data as a DataPair ──────────────────────────────────────
  // Returns the validation samples (inputs + targets) in their shuffled order.
  // Returns empty arrays if no validation split was configured.
  getValidationData() {
    return {
      inputs: this._valIndices.map((i) => this.data.inputs[i]),
      targets: this._valIndices.map((i) => this.data.targets[i])
    };
  }
  // ── Get number of validation samples ───────────────────────────────────────
  get validationLength() {
    return this._valIndices.length;
  }
  // ── Create sequence windows from a time series ────────────────────────────
  /**
   * Builds a loader (batch size 1) of sliding windows: each input is `seqLen`
   * consecutive rows flattened into one array, and its target is the row
   * that immediately follows the window.
   *
   * @param {Array} data - Time series, one row per step.
   * @param {number} seqLen - Window length.
   * @param {number} [validationSplit=0] - Forwarded to the constructor.
   * @returns {DataLoader}
   * @throws {Error} If `data` has fewer than `seqLen + 1` rows.
   */
  static sequences(data, seqLen, validationSplit = 0) {
    if (data.length < seqLen + 1) {
      throw new Error("DataLoader.sequences: data length must be >= seqLen + 1");
    }
    const inputs = [];
    const targets = [];
    for (let i = 0; i <= data.length - seqLen - 1; i++) {
      inputs.push(data.slice(i, i + seqLen).flat());
      targets.push(data[i + seqLen]);
    }
    return new _DataLoader({ inputs, targets }, 1, validationSplit);
  }
  // ── Private helpers ───────────────────────────────────────────────────────
  /** Fisher-Yates shuffle of `arr` in place, driven by Math.random(). */
  static _shuffleInPlace(arr) {
    for (let i = arr.length - 1; i > 0; i--) {
      const j = Math.floor(Math.random() * (i + 1));
      [arr[i], arr[j]] = [arr[j], arr[i]];
    }
  }
};
|
|
2533
|
+
|
|
2534
|
+
// src/LRScheduler.ts
|
|
2535
|
+
var LRScheduler = class {
  // ── Step Decay ────────────────────────────────────────────────────────────
  /**
   * Scales `lr` by `dropRate` once for every completed block of
   * `epochsDrop` epochs: lr * dropRate^floor(epoch / epochsDrop).
   */
  stepDecay(lr, epoch, dropRate, epochsDrop) {
    const completedDrops = Math.floor(epoch / epochsDrop);
    const multiplier = Math.pow(dropRate, completedDrops);
    return lr * multiplier;
  }
  // ── Exponential Decay ─────────────────────────────────────────────────────
  /** Scales `lr` by `decayRate` raised to the epoch number. */
  exponentialDecay(lr, epoch, decayRate) {
    const multiplier = Math.pow(decayRate, epoch);
    return lr * multiplier;
  }
  // ── Plateau Decay ─────────────────────────────────────────────────────────
  /**
   * Returns `lr * factor` when `currentLoss` is no better than the lowest of
   * the last `patience` entries of `history`; otherwise returns `lr`
   * unchanged. While `history` holds fewer than `patience` entries the rate
   * is never reduced.
   *
   * Usage:
   *   for (let epoch = 0; epoch < 1000; epoch++) {
   *     const loss = train(...)
   *     lr = scheduler.plateauDecay(lr, loss, history, 10, 0.5)
   *   }
   */
  plateauDecay(lr, currentLoss, history, patience, factor) {
    const enoughHistory = history.length >= patience;
    if (!enoughHistory) {
      return lr;
    }
    const lowestRecent = history
      .slice(-patience)
      .reduce((lowest, loss) => Math.min(lowest, loss), Infinity);
    const improved = !(currentLoss >= lowestRecent);
    return improved ? lr : lr * factor;
  }
  // ── Cosine Annealing ──────────────────────────────────────────────────────
  /**
   * Half-cosine interpolation from `lr` (at epoch 0) down to `minLr`
   * (at epoch === maxEpochs):
   * minLr + 0.5 * (lr - minLr) * (1 + cos(π * epoch / maxEpochs)).
   */
  cosineAnnealing(lr, epoch, maxEpochs, minLr = 0) {
    const phase = Math.PI * epoch / maxEpochs;
    const halfRange = 0.5 * (lr - minLr);
    return minLr + halfRange * (1 + Math.cos(phase));
  }
};
|
|
2572
|
+
|
|
2573
|
+
// src/ModelSaver.ts
|
|
2574
|
+
var ModelSaver = class _ModelSaver {
  // ── Serialize to JSON string ──────────────────────────────────────────────
  /**
   * @param {{getWeights: Function}} model - Source of the weights.
   * @returns {string} JSON with `weights` (the value of `model.getWeights()`)
   *   and `timestamp` (Date.now() at serialization time).
   */
  static toJSON(model) {
    return JSON.stringify({
      weights: model.getWeights(),
      timestamp: Date.now()
    });
  }
  // ── Deserialize from JSON string ──────────────────────────────────────────
  /**
   * Parses `json` and passes its `weights` array to `model.setWeights`.
   *
   * @param {{setWeights: Function}} model - Receives the weights.
   * @param {string} json - String produced by toJSON().
   * @throws {SyntaxError} If `json` is not valid JSON.
   * @throws {Error} If the parsed value has no `weights` array. This includes
   *   JSON `null` and primitives, which previously surfaced as a raw
   *   TypeError from the property access.
   */
  static fromJSON(model, json) {
    const data = JSON.parse(json);
    if (!Array.isArray(data?.weights)) {
      throw new Error("ModelSaver.fromJSON: invalid model data");
    }
    model.setWeights(data.weights);
  }
  // ── Save to file (requires write function) ────────────────────────────────
  /**
   * @param {{getWeights: Function}} model - Model to serialize.
   * @param {string} path - Passed through to `writeFn` untouched.
   * @param {Function} writeFn - Called as `writeFn(path, json)`.
   */
  static saveToFile(model, path, writeFn) {
    const json = _ModelSaver.toJSON(model);
    writeFn(path, json);
  }
  // ── Load from file (requires read function) ───────────────────────────────
  /**
   * @param {{setWeights: Function}} model - Model to restore into.
   * @param {string} path - Passed through to `readFn` untouched.
   * @param {Function} readFn - Called as `readFn(path)`; must return the JSON string.
   */
  static loadFromFile(model, path, readFn) {
    const json = readFn(path);
    _ModelSaver.fromJSON(model, json);
  }
};
|
|
1230
2601
|
export {
|
|
1231
2602
|
Adam,
|
|
1232
2603
|
AttentionHead,
|
|
2604
|
+
BatchNorm,
|
|
2605
|
+
BiasVector,
|
|
2606
|
+
ClipOptimizer,
|
|
2607
|
+
ClippedOptimizerFactory,
|
|
2608
|
+
Conv1D,
|
|
2609
|
+
DataLoader,
|
|
2610
|
+
Dropout,
|
|
1233
2611
|
EmbeddingMatrix,
|
|
2612
|
+
GRULayer,
|
|
2613
|
+
LRScheduler,
|
|
1234
2614
|
LSTMLayer,
|
|
1235
2615
|
Layer,
|
|
1236
2616
|
LayerNorm,
|
|
2617
|
+
ModelSaver,
|
|
1237
2618
|
Momentum,
|
|
1238
2619
|
MultiHeadAttention,
|
|
1239
2620
|
Network,
|
|
@@ -1244,11 +2625,13 @@ export {
|
|
|
1244
2625
|
Neuron,
|
|
1245
2626
|
NeuronN,
|
|
1246
2627
|
SGD,
|
|
2628
|
+
Trainer,
|
|
1247
2629
|
TransformerBlock,
|
|
1248
2630
|
WeightMatrix,
|
|
1249
2631
|
crossEntropy,
|
|
1250
2632
|
crossEntropyDelta,
|
|
1251
2633
|
crossEntropyDeltaRaw,
|
|
2634
|
+
defaultOptimizer,
|
|
1252
2635
|
elu,
|
|
1253
2636
|
leakyRelu,
|
|
1254
2637
|
linear,
|
|
@@ -1262,5 +2645,9 @@ export {
|
|
|
1262
2645
|
softmax,
|
|
1263
2646
|
softmaxBackward,
|
|
1264
2647
|
tanh,
|
|
1265
|
-
transpose
|
|
2648
|
+
transpose,
|
|
2649
|
+
validate2DArray,
|
|
2650
|
+
validateArray,
|
|
2651
|
+
validateArrayMinLength,
|
|
2652
|
+
validateNumber
|
|
1266
2653
|
};
|