@dniskav/neuron 0.2.3 → 0.2.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +28 -3
- package/dist/index.d.mts +281 -22
- package/dist/index.d.ts +281 -22
- package/dist/index.js +1507 -104
- package/dist/index.mjs +1490 -103
- package/package.json +7 -3
package/dist/index.js
CHANGED
|
@@ -22,10 +22,20 @@ var index_exports = {};
|
|
|
22
22
|
__export(index_exports, {
|
|
23
23
|
Adam: () => Adam,
|
|
24
24
|
AttentionHead: () => AttentionHead,
|
|
25
|
+
BatchNorm: () => BatchNorm,
|
|
26
|
+
BiasVector: () => BiasVector,
|
|
27
|
+
ClipOptimizer: () => ClipOptimizer,
|
|
28
|
+
ClippedOptimizerFactory: () => ClippedOptimizerFactory,
|
|
29
|
+
Conv1D: () => Conv1D,
|
|
30
|
+
DataLoader: () => DataLoader,
|
|
31
|
+
Dropout: () => Dropout,
|
|
25
32
|
EmbeddingMatrix: () => EmbeddingMatrix,
|
|
33
|
+
GRULayer: () => GRULayer,
|
|
34
|
+
LRScheduler: () => LRScheduler,
|
|
26
35
|
LSTMLayer: () => LSTMLayer,
|
|
27
36
|
Layer: () => Layer,
|
|
28
37
|
LayerNorm: () => LayerNorm,
|
|
38
|
+
ModelSaver: () => ModelSaver,
|
|
29
39
|
Momentum: () => Momentum,
|
|
30
40
|
MultiHeadAttention: () => MultiHeadAttention,
|
|
31
41
|
Network: () => Network,
|
|
@@ -36,11 +46,13 @@ __export(index_exports, {
|
|
|
36
46
|
Neuron: () => Neuron,
|
|
37
47
|
NeuronN: () => NeuronN,
|
|
38
48
|
SGD: () => SGD,
|
|
49
|
+
Trainer: () => Trainer,
|
|
39
50
|
TransformerBlock: () => TransformerBlock,
|
|
40
51
|
WeightMatrix: () => WeightMatrix,
|
|
41
52
|
crossEntropy: () => crossEntropy,
|
|
42
53
|
crossEntropyDelta: () => crossEntropyDelta,
|
|
43
54
|
crossEntropyDeltaRaw: () => crossEntropyDeltaRaw,
|
|
55
|
+
defaultOptimizer: () => defaultOptimizer,
|
|
44
56
|
elu: () => elu,
|
|
45
57
|
leakyRelu: () => leakyRelu,
|
|
46
58
|
linear: () => linear,
|
|
@@ -54,10 +66,82 @@ __export(index_exports, {
|
|
|
54
66
|
softmax: () => softmax,
|
|
55
67
|
softmaxBackward: () => softmaxBackward,
|
|
56
68
|
tanh: () => tanh,
|
|
57
|
-
transpose: () => transpose
|
|
69
|
+
transpose: () => transpose,
|
|
70
|
+
validate2DArray: () => validate2DArray,
|
|
71
|
+
validateArray: () => validateArray,
|
|
72
|
+
validateArrayMinLength: () => validateArrayMinLength,
|
|
73
|
+
validateNumber: () => validateNumber
|
|
58
74
|
});
|
|
59
75
|
module.exports = __toCommonJS(index_exports);
|
|
60
76
|
|
|
77
|
+
// src/Validation.ts
|
|
78
|
+
function validateArray(arr, expectedLength, methodName) {
|
|
79
|
+
if (!Array.isArray(arr)) {
|
|
80
|
+
throw new Error(`${methodName}: expected array, got ${typeof arr}`);
|
|
81
|
+
}
|
|
82
|
+
if (arr.length !== expectedLength) {
|
|
83
|
+
throw new Error(
|
|
84
|
+
`${methodName}: expected array of length ${expectedLength}, got ${arr.length}`
|
|
85
|
+
);
|
|
86
|
+
}
|
|
87
|
+
for (let i = 0; i < arr.length; i++) {
|
|
88
|
+
if (typeof arr[i] !== "number" || !isFinite(arr[i])) {
|
|
89
|
+
throw new Error(
|
|
90
|
+
`${methodName}: invalid value at index ${i}: ${arr[i]}`
|
|
91
|
+
);
|
|
92
|
+
}
|
|
93
|
+
}
|
|
94
|
+
}
|
|
95
|
+
function validateArrayMinLength(arr, minLength, methodName) {
|
|
96
|
+
if (!Array.isArray(arr)) {
|
|
97
|
+
throw new Error(`${methodName}: expected array, got ${typeof arr}`);
|
|
98
|
+
}
|
|
99
|
+
if (arr.length < minLength) {
|
|
100
|
+
throw new Error(
|
|
101
|
+
`${methodName}: expected array of at least length ${minLength}, got ${arr.length}`
|
|
102
|
+
);
|
|
103
|
+
}
|
|
104
|
+
for (let i = 0; i < arr.length; i++) {
|
|
105
|
+
if (typeof arr[i] !== "number" || !isFinite(arr[i])) {
|
|
106
|
+
throw new Error(
|
|
107
|
+
`${methodName}: invalid value at index ${i}: ${arr[i]}`
|
|
108
|
+
);
|
|
109
|
+
}
|
|
110
|
+
}
|
|
111
|
+
}
|
|
112
|
+
function validate2DArray(arr, expectedRows, expectedCols, methodName) {
|
|
113
|
+
if (!Array.isArray(arr)) {
|
|
114
|
+
throw new Error(`${methodName}: expected 2D array, got ${typeof arr}`);
|
|
115
|
+
}
|
|
116
|
+
if (arr.length !== expectedRows) {
|
|
117
|
+
throw new Error(
|
|
118
|
+
`${methodName}: expected ${expectedRows} rows, got ${arr.length}`
|
|
119
|
+
);
|
|
120
|
+
}
|
|
121
|
+
for (let i = 0; i < arr.length; i++) {
|
|
122
|
+
if (!Array.isArray(arr[i])) {
|
|
123
|
+
throw new Error(`${methodName}: row ${i} is not an array`);
|
|
124
|
+
}
|
|
125
|
+
if (arr[i].length !== expectedCols) {
|
|
126
|
+
throw new Error(
|
|
127
|
+
`${methodName}: row ${i} expected ${expectedCols} cols, got ${arr[i].length}`
|
|
128
|
+
);
|
|
129
|
+
}
|
|
130
|
+
for (let j = 0; j < arr[i].length; j++) {
|
|
131
|
+
if (typeof arr[i][j] !== "number" || !isFinite(arr[i][j])) {
|
|
132
|
+
throw new Error(
|
|
133
|
+
`${methodName}: invalid value at [${i}][${j}]: ${arr[i][j]}`
|
|
134
|
+
);
|
|
135
|
+
}
|
|
136
|
+
}
|
|
137
|
+
}
|
|
138
|
+
}
|
|
139
|
+
function validateNumber(value, methodName) {
|
|
140
|
+
if (typeof value !== "number" || !isFinite(value)) {
|
|
141
|
+
throw new Error(`${methodName}: expected finite number, got ${value}`);
|
|
142
|
+
}
|
|
143
|
+
}
|
|
144
|
+
|
|
61
145
|
// src/Neuron.ts
|
|
62
146
|
function sigmoid(x) {
|
|
63
147
|
return 1 / (1 + Math.exp(-x));
|
|
@@ -68,13 +152,18 @@ var Neuron = class {
|
|
|
68
152
|
this.bias = Math.random() * 0.1;
|
|
69
153
|
}
|
|
70
154
|
predict(input) {
|
|
155
|
+
validateNumber(input, "Neuron.predict");
|
|
71
156
|
return sigmoid(input * this.weight + this.bias);
|
|
72
157
|
}
|
|
73
158
|
train(input, target, lr) {
|
|
159
|
+
validateNumber(input, "Neuron.train");
|
|
160
|
+
validateNumber(target, "Neuron.train");
|
|
161
|
+
validateNumber(lr, "Neuron.train");
|
|
74
162
|
const prediction = this.predict(input);
|
|
75
163
|
const error = target - prediction;
|
|
76
|
-
|
|
77
|
-
this.
|
|
164
|
+
const grad = error * prediction * (1 - prediction);
|
|
165
|
+
this.weight += lr * grad * input;
|
|
166
|
+
this.bias += lr * grad;
|
|
78
167
|
}
|
|
79
168
|
};
|
|
80
169
|
|
|
@@ -114,6 +203,7 @@ function makeElu(alpha = 1) {
|
|
|
114
203
|
var elu = makeElu(1);
|
|
115
204
|
|
|
116
205
|
// src/optimizers.ts
|
|
206
|
+
var defaultOptimizer = () => new SGD();
|
|
117
207
|
var SGD = class {
|
|
118
208
|
step(weight, gradient, lr) {
|
|
119
209
|
return weight + lr * gradient;
|
|
@@ -129,6 +219,19 @@ var Momentum = class {
|
|
|
129
219
|
return weight + this.v;
|
|
130
220
|
}
|
|
131
221
|
};
|
|
222
|
+
var ClipOptimizer = class {
|
|
223
|
+
constructor(inner, clipValue) {
|
|
224
|
+
this.inner = inner;
|
|
225
|
+
this.clipValue = clipValue;
|
|
226
|
+
}
|
|
227
|
+
step(weight, gradient, lr) {
|
|
228
|
+
const clipped = Math.max(-this.clipValue, Math.min(this.clipValue, gradient));
|
|
229
|
+
return this.inner.step(weight, clipped, lr);
|
|
230
|
+
}
|
|
231
|
+
};
|
|
232
|
+
function ClippedOptimizerFactory(innerFactory, clipValue) {
|
|
233
|
+
return () => new ClipOptimizer(innerFactory(), clipValue);
|
|
234
|
+
}
|
|
132
235
|
var Adam = class {
|
|
133
236
|
constructor(beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8) {
|
|
134
237
|
this.beta1 = beta1;
|
|
@@ -149,7 +252,6 @@ var Adam = class {
|
|
|
149
252
|
};
|
|
150
253
|
|
|
151
254
|
// src/NeuronN.ts
|
|
152
|
-
var defaultOptimizer = () => new SGD();
|
|
153
255
|
var NeuronN = class {
|
|
154
256
|
constructor(nInputs, activation = sigmoid2, optimizerFactory = defaultOptimizer) {
|
|
155
257
|
const limit = Math.sqrt(1 / nInputs);
|
|
@@ -159,6 +261,7 @@ var NeuronN = class {
|
|
|
159
261
|
this._opts = Array.from({ length: nInputs + 1 }, optimizerFactory);
|
|
160
262
|
}
|
|
161
263
|
predict(inputs) {
|
|
264
|
+
validateArray(inputs, this.weights.length, "NeuronN.predict");
|
|
162
265
|
const sum = inputs.reduce((acc, e, i) => acc + e * this.weights[i], this.bias);
|
|
163
266
|
return this.activation.fn(sum);
|
|
164
267
|
}
|
|
@@ -171,14 +274,14 @@ var NeuronN = class {
|
|
|
171
274
|
train(inputs, target, lr) {
|
|
172
275
|
const prediction = this.predict(inputs);
|
|
173
276
|
const error = target - prediction;
|
|
174
|
-
|
|
277
|
+
const grad = error * this.activation.dfn(prediction);
|
|
278
|
+
this._update(inputs.map((inp) => grad * inp), grad, lr);
|
|
175
279
|
}
|
|
176
280
|
};
|
|
177
281
|
|
|
178
282
|
// src/Layer.ts
|
|
179
|
-
var defaultOptimizer2 = () => new SGD();
|
|
180
283
|
var Layer = class {
|
|
181
|
-
constructor(nNeurons, nInputs, activation = sigmoid2, optimizerFactory =
|
|
284
|
+
constructor(nNeurons, nInputs, activation = sigmoid2, optimizerFactory = defaultOptimizer) {
|
|
182
285
|
this.neurons = Array.from(
|
|
183
286
|
{ length: nNeurons },
|
|
184
287
|
() => new NeuronN(nInputs, activation, optimizerFactory)
|
|
@@ -196,84 +299,233 @@ var Network = class {
|
|
|
196
299
|
this.outputLayer = new Layer(nOutputs, nHidden);
|
|
197
300
|
}
|
|
198
301
|
predict(inputs) {
|
|
302
|
+
validateArray(inputs, this.hiddenLayer.neurons[0].weights.length, "Network.predict");
|
|
199
303
|
const hiddenOut = this.hiddenLayer.predict(inputs);
|
|
200
|
-
return this.outputLayer.predict(hiddenOut)
|
|
304
|
+
return this.outputLayer.predict(hiddenOut);
|
|
201
305
|
}
|
|
202
306
|
// Trains on a single example. Returns the squared error.
|
|
203
307
|
train(inputs, target, lr) {
|
|
308
|
+
validateArray(inputs, this.hiddenLayer.neurons[0].weights.length, "Network.train");
|
|
309
|
+
validateNumber(target, "Network.train");
|
|
310
|
+
validateNumber(lr, "Network.train");
|
|
204
311
|
const hiddenOut = this.hiddenLayer.predict(inputs);
|
|
205
312
|
const prediction = this.outputLayer.predict(hiddenOut)[0];
|
|
206
|
-
const outputError = target - prediction;
|
|
207
|
-
const outputDelta = outputError * prediction * (1 - prediction);
|
|
208
313
|
const outputNeuron = this.outputLayer.neurons[0];
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
)
|
|
212
|
-
outputNeuron.bias += lr * outputDelta;
|
|
213
|
-
this.hiddenLayer.neurons.forEach((neuron, i) => {
|
|
214
|
-
const hiddenOut_i = hiddenOut[i];
|
|
314
|
+
const outputError = target - prediction;
|
|
315
|
+
const outputDelta = outputError * outputNeuron.activation.dfn(prediction);
|
|
316
|
+
const hiddenDeltas = this.hiddenLayer.neurons.map((neuron, i) => {
|
|
215
317
|
const hiddenError = outputDelta * outputNeuron.weights[i];
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
318
|
+
return hiddenError * neuron.activation.dfn(hiddenOut[i]);
|
|
319
|
+
});
|
|
320
|
+
this.hiddenLayer.neurons.forEach((neuron, i) => {
|
|
321
|
+
neuron._update(inputs.map((inp) => hiddenDeltas[i] * inp), hiddenDeltas[i], lr);
|
|
219
322
|
});
|
|
323
|
+
outputNeuron._update(hiddenOut.map((h) => outputDelta * h), outputDelta, lr);
|
|
220
324
|
return outputError * outputError;
|
|
221
325
|
}
|
|
326
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
327
|
+
// Order: hidden layer (all neurons: weights then bias), then output layer.
|
|
328
|
+
getWeights() {
|
|
329
|
+
const w = [];
|
|
330
|
+
for (const n of this.hiddenLayer.neurons) {
|
|
331
|
+
w.push(...n.weights, n.bias);
|
|
332
|
+
}
|
|
333
|
+
for (const n of this.outputLayer.neurons) {
|
|
334
|
+
w.push(...n.weights, n.bias);
|
|
335
|
+
}
|
|
336
|
+
return w;
|
|
337
|
+
}
|
|
338
|
+
setWeights(weights) {
|
|
339
|
+
let idx = 0;
|
|
340
|
+
for (const n of this.hiddenLayer.neurons) {
|
|
341
|
+
for (let j = 0; j < n.weights.length; j++) n.weights[j] = weights[idx++];
|
|
342
|
+
n.bias = weights[idx++];
|
|
343
|
+
}
|
|
344
|
+
for (const n of this.outputLayer.neurons) {
|
|
345
|
+
for (let j = 0; j < n.weights.length; j++) n.weights[j] = weights[idx++];
|
|
346
|
+
n.bias = weights[idx++];
|
|
347
|
+
}
|
|
348
|
+
}
|
|
349
|
+
};
|
|
350
|
+
|
|
351
|
+
// src/Dropout.ts
|
|
352
|
+
var Dropout = class {
|
|
353
|
+
constructor(rate) {
|
|
354
|
+
this._mask = null;
|
|
355
|
+
if (rate < 0 || rate >= 1) {
|
|
356
|
+
throw new Error(`Dropout rate must be in [0, 1), got ${rate}`);
|
|
357
|
+
}
|
|
358
|
+
this.rate = rate;
|
|
359
|
+
}
|
|
360
|
+
// ── Forward ───────────────────────────────────────────────────────────────
|
|
361
|
+
// x: number[] → number[]
|
|
362
|
+
// If training, applies inverted dropout mask.
|
|
363
|
+
// If not training, returns input unchanged.
|
|
364
|
+
forward(x, training = true) {
|
|
365
|
+
if (!training || this.rate === 0) {
|
|
366
|
+
this._mask = null;
|
|
367
|
+
return [...x];
|
|
368
|
+
}
|
|
369
|
+
const scale = 1 / (1 - this.rate);
|
|
370
|
+
this._mask = x.map(() => Math.random() > this.rate ? scale : 0);
|
|
371
|
+
return x.map((v, i) => v * this._mask[i]);
|
|
372
|
+
}
|
|
373
|
+
// ── Backward ──────────────────────────────────────────────────────────────
|
|
374
|
+
// dOut: number[] → number[]
|
|
375
|
+
// Applies the same mask (gradient is zeroed where activation was zeroed).
|
|
376
|
+
backward(dOut) {
|
|
377
|
+
if (!this._mask) return [...dOut];
|
|
378
|
+
return dOut.map((d, i) => d * this._mask[i]);
|
|
379
|
+
}
|
|
380
|
+
// ── Reset mask between forward passes ─────────────────────────────────────
|
|
381
|
+
resetMask() {
|
|
382
|
+
this._mask = null;
|
|
383
|
+
}
|
|
384
|
+
// ── No trainable params ───────────────────────────────────────────────────
|
|
385
|
+
getWeights() {
|
|
386
|
+
return [];
|
|
387
|
+
}
|
|
388
|
+
setWeights(_weights) {
|
|
389
|
+
}
|
|
222
390
|
};
|
|
223
391
|
|
|
224
392
|
// src/NetworkN.ts
|
|
225
|
-
var defaultOptimizer3 = () => new SGD();
|
|
226
393
|
var NetworkN = class {
|
|
227
394
|
constructor(structure, options = {}) {
|
|
228
395
|
this.structure = structure;
|
|
229
396
|
const nLayers = structure.length - 1;
|
|
230
397
|
const activations = options.activations ?? Array.from({ length: nLayers }, () => sigmoid2);
|
|
231
|
-
const optimizer = options.optimizer ??
|
|
398
|
+
const optimizer = options.optimizer ?? defaultOptimizer;
|
|
399
|
+
const dropoutRate = options.dropoutRate ?? 0;
|
|
400
|
+
if (activations.length !== nLayers) {
|
|
401
|
+
throw new Error(`Expected ${nLayers} activations, got ${activations.length}`);
|
|
402
|
+
}
|
|
403
|
+
if (dropoutRate < 0 || dropoutRate >= 1) {
|
|
404
|
+
throw new Error(`Dropout rate must be in [0, 1), got ${dropoutRate}`);
|
|
405
|
+
}
|
|
406
|
+
this._residual = options.residual ?? false;
|
|
232
407
|
this.layers = [];
|
|
233
408
|
for (let i = 1; i < structure.length; i++) {
|
|
234
409
|
this.layers.push(new Layer(structure[i], structure[i - 1], activations[i - 1], optimizer));
|
|
235
410
|
}
|
|
411
|
+
this._dropouts = [];
|
|
412
|
+
if (dropoutRate > 0) {
|
|
413
|
+
for (let i = 0; i < nLayers - 1; i++) {
|
|
414
|
+
this._dropouts.push(new Dropout(dropoutRate));
|
|
415
|
+
}
|
|
416
|
+
}
|
|
417
|
+
const outputLayer = this.layers[this.layers.length - 1];
|
|
418
|
+
const outputActivation = outputLayer.neurons[0].activation;
|
|
419
|
+
for (let i = 1; i < outputLayer.neurons.length; i++) {
|
|
420
|
+
if (outputLayer.neurons[i].activation !== outputActivation) {
|
|
421
|
+
throw new Error("All output neurons must share the same activation function");
|
|
422
|
+
}
|
|
423
|
+
}
|
|
236
424
|
}
|
|
237
|
-
predict(inputs) {
|
|
238
|
-
|
|
425
|
+
predict(inputs, training = false) {
|
|
426
|
+
validateArray(inputs, this.structure[0], "NetworkN.predict");
|
|
427
|
+
let current = [...inputs];
|
|
428
|
+
for (let i = 0; i < this.layers.length; i++) {
|
|
429
|
+
const layerInput = [...current];
|
|
430
|
+
const layerOutput = this.layers[i].predict(current);
|
|
431
|
+
if (this._shouldResidual(i)) {
|
|
432
|
+
if (this.structure[i] === this.structure[i + 1]) {
|
|
433
|
+
current = layerOutput.map((v, j) => v + layerInput[j]);
|
|
434
|
+
} else {
|
|
435
|
+
current = [...layerOutput];
|
|
436
|
+
}
|
|
437
|
+
} else {
|
|
438
|
+
current = [...layerOutput];
|
|
439
|
+
}
|
|
440
|
+
if (i < this._dropouts.length) {
|
|
441
|
+
current = this._dropouts[i].forward(current, training);
|
|
442
|
+
}
|
|
443
|
+
}
|
|
444
|
+
return current;
|
|
239
445
|
}
|
|
240
446
|
// Generalized backpropagation across L layers.
|
|
241
447
|
// Returns the mean squared error for the example.
|
|
242
448
|
train(inputs, targets, lr) {
|
|
243
|
-
|
|
244
|
-
|
|
449
|
+
validateArray(inputs, this.structure[0], "NetworkN.train");
|
|
450
|
+
validateArray(targets, this.structure[this.structure.length - 1], "NetworkN.train");
|
|
451
|
+
const act = this._forwardAll(inputs, true);
|
|
245
452
|
const pred = act[act.length - 1];
|
|
246
453
|
const outAct = this.layers[this.layers.length - 1].neurons[0].activation;
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
const layer = this.layers[l];
|
|
250
|
-
const layerIn = act[l];
|
|
251
|
-
const prevAct = l > 0 ? this.layers[l - 1].neurons[0].activation : null;
|
|
252
|
-
const prevDeltas = layerIn.map((out, j) => {
|
|
253
|
-
const errProp = layer.neurons.reduce((s, n, k) => s + deltas[k] * n.weights[j], 0);
|
|
254
|
-
return prevAct ? errProp * prevAct.dfn(out) : errProp;
|
|
255
|
-
});
|
|
256
|
-
layer.neurons.forEach((n, k) => {
|
|
257
|
-
n._update(layerIn.map((inp) => deltas[k] * inp), deltas[k], lr);
|
|
258
|
-
});
|
|
259
|
-
deltas = prevDeltas;
|
|
260
|
-
}
|
|
454
|
+
const deltas = pred.map((p, i) => (targets[i] - p) * outAct.dfn(p));
|
|
455
|
+
this._backpropLayers(act, deltas, lr);
|
|
261
456
|
return pred.reduce((s, p, i) => s + (targets[i] - p) ** 2, 0) / pred.length;
|
|
262
457
|
}
|
|
263
458
|
// Backprop with externally provided output-layer deltas.
|
|
264
459
|
// Useful for custom loss functions (e.g. physics-based gradients).
|
|
265
460
|
trainWithDeltas(inputs, outputDeltas, lr) {
|
|
461
|
+
const act = this._forwardAll(inputs, true);
|
|
462
|
+
this._backpropLayers(act, outputDeltas, lr);
|
|
463
|
+
}
|
|
464
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
465
|
+
// Order: layer 0 (all neurons), layer 1, ..., layer N.
|
|
466
|
+
getWeights() {
|
|
467
|
+
for (const d of this._dropouts) d.resetMask();
|
|
468
|
+
const w = [];
|
|
469
|
+
for (const layer of this.layers) {
|
|
470
|
+
for (const n of layer.neurons) {
|
|
471
|
+
w.push(...n.weights, n.bias);
|
|
472
|
+
}
|
|
473
|
+
}
|
|
474
|
+
return w;
|
|
475
|
+
}
|
|
476
|
+
setWeights(weights) {
|
|
477
|
+
for (const d of this._dropouts) d.resetMask();
|
|
478
|
+
let idx = 0;
|
|
479
|
+
for (const layer of this.layers) {
|
|
480
|
+
for (const n of layer.neurons) {
|
|
481
|
+
for (let j = 0; j < n.weights.length; j++) n.weights[j] = weights[idx++];
|
|
482
|
+
n.bias = weights[idx++];
|
|
483
|
+
}
|
|
484
|
+
}
|
|
485
|
+
}
|
|
486
|
+
// ── Private helpers ──────────────────────────────────────────────────────
|
|
487
|
+
_shouldResidual(layerIndex) {
|
|
488
|
+
if (typeof this._residual === "function") return this._residual(layerIndex);
|
|
489
|
+
return this._residual;
|
|
490
|
+
}
|
|
491
|
+
// Forward pass storing activations at every layer boundary.
|
|
492
|
+
// Used by train(), trainWithDeltas(), and predict() shares the same logic.
|
|
493
|
+
_forwardAll(inputs, training) {
|
|
266
494
|
const act = [inputs];
|
|
267
|
-
for (
|
|
495
|
+
for (let i = 0; i < this.layers.length; i++) {
|
|
496
|
+
const layerInput = act[act.length - 1];
|
|
497
|
+
const layerOutput = this.layers[i].predict(layerInput);
|
|
498
|
+
let current;
|
|
499
|
+
if (this._shouldResidual(i) && this.structure[i] === this.structure[i + 1]) {
|
|
500
|
+
current = layerOutput.map((v, j) => v + layerInput[j]);
|
|
501
|
+
} else {
|
|
502
|
+
current = layerOutput;
|
|
503
|
+
}
|
|
504
|
+
if (i < this._dropouts.length) {
|
|
505
|
+
current = this._dropouts[i].forward(current, training);
|
|
506
|
+
}
|
|
507
|
+
act.push(current);
|
|
508
|
+
}
|
|
509
|
+
return act;
|
|
510
|
+
}
|
|
511
|
+
// Backward pass: updates all layer weights given the pre-computed activations
|
|
512
|
+
// and the initial output-layer deltas.
|
|
513
|
+
_backpropLayers(act, outputDeltas, lr) {
|
|
268
514
|
let deltas = outputDeltas;
|
|
269
515
|
for (let l = this.layers.length - 1; l >= 0; l--) {
|
|
270
516
|
const layer = this.layers[l];
|
|
517
|
+
if (l < this._dropouts.length) {
|
|
518
|
+
deltas = this._dropouts[l].backward(deltas);
|
|
519
|
+
}
|
|
271
520
|
const layerIn = act[l];
|
|
272
521
|
const prevAct = l > 0 ? this.layers[l - 1].neurons[0].activation : null;
|
|
273
522
|
const prevDeltas = layerIn.map((out, j) => {
|
|
274
523
|
const errProp = layer.neurons.reduce((s, n, k) => s + deltas[k] * n.weights[j], 0);
|
|
275
524
|
return prevAct ? errProp * prevAct.dfn(out) : errProp;
|
|
276
525
|
});
|
|
526
|
+
if (this._shouldResidual(l) && this.structure[l] === this.structure[l + 1]) {
|
|
527
|
+
for (let j = 0; j < prevDeltas.length; j++) prevDeltas[j] += deltas[j];
|
|
528
|
+
}
|
|
277
529
|
layer.neurons.forEach((n, k) => {
|
|
278
530
|
n._update(layerIn.map((inp) => deltas[k] * inp), deltas[k], lr);
|
|
279
531
|
});
|
|
@@ -294,7 +546,7 @@ var Gate = class {
|
|
|
294
546
|
// shape: [hSize]
|
|
295
547
|
constructor(inputSize, hSize, initBias = 0) {
|
|
296
548
|
const n = inputSize + hSize;
|
|
297
|
-
const limit = Math.sqrt(2 / n);
|
|
549
|
+
const limit = Math.sqrt(2 / (n + hSize));
|
|
298
550
|
this.W = Array.from(
|
|
299
551
|
{ length: hSize },
|
|
300
552
|
() => Array.from({ length: n }, () => (Math.random() * 2 - 1) * limit)
|
|
@@ -308,8 +560,11 @@ var Gate = class {
|
|
|
308
560
|
}
|
|
309
561
|
};
|
|
310
562
|
var LSTMLayer = class {
|
|
311
|
-
constructor(inputSize, hiddenSize) {
|
|
563
|
+
constructor(inputSize, hiddenSize, optimizerFactory = () => new SGD()) {
|
|
312
564
|
this._traj = [];
|
|
565
|
+
if (inputSize <= 0 || hiddenSize <= 0) {
|
|
566
|
+
throw new Error(`LSTMLayer: inputSize and hiddenSize must be positive, got ${inputSize} and ${hiddenSize}`);
|
|
567
|
+
}
|
|
313
568
|
this.inputSize = inputSize;
|
|
314
569
|
this.hSize = hiddenSize;
|
|
315
570
|
this.h = new Array(hiddenSize).fill(0);
|
|
@@ -318,6 +573,29 @@ var LSTMLayer = class {
|
|
|
318
573
|
this.inputGate = new Gate(inputSize, hiddenSize);
|
|
319
574
|
this.cellGate = new Gate(inputSize, hiddenSize);
|
|
320
575
|
this.outputGate = new Gate(inputSize, hiddenSize);
|
|
576
|
+
const combSize = inputSize + hiddenSize;
|
|
577
|
+
this._optimizers = {
|
|
578
|
+
forgetW: Array.from(
|
|
579
|
+
{ length: hiddenSize },
|
|
580
|
+
() => Array.from({ length: combSize }, () => optimizerFactory())
|
|
581
|
+
),
|
|
582
|
+
forgetB: Array.from({ length: hiddenSize }, () => optimizerFactory()),
|
|
583
|
+
inputW: Array.from(
|
|
584
|
+
{ length: hiddenSize },
|
|
585
|
+
() => Array.from({ length: combSize }, () => optimizerFactory())
|
|
586
|
+
),
|
|
587
|
+
inputB: Array.from({ length: hiddenSize }, () => optimizerFactory()),
|
|
588
|
+
cellW: Array.from(
|
|
589
|
+
{ length: hiddenSize },
|
|
590
|
+
() => Array.from({ length: combSize }, () => optimizerFactory())
|
|
591
|
+
),
|
|
592
|
+
cellB: Array.from({ length: hiddenSize }, () => optimizerFactory()),
|
|
593
|
+
outputW: Array.from(
|
|
594
|
+
{ length: hiddenSize },
|
|
595
|
+
() => Array.from({ length: combSize }, () => optimizerFactory())
|
|
596
|
+
),
|
|
597
|
+
outputB: Array.from({ length: hiddenSize }, () => optimizerFactory())
|
|
598
|
+
};
|
|
321
599
|
}
|
|
322
600
|
// ── Reset state and trajectory (call at episode start) ────────────────────
|
|
323
601
|
reset() {
|
|
@@ -327,6 +605,9 @@ var LSTMLayer = class {
|
|
|
327
605
|
}
|
|
328
606
|
// ── Forward pass ──────────────────────────────────────────────────────────
|
|
329
607
|
predict(inputs) {
|
|
608
|
+
if (!Array.isArray(inputs) || inputs.length !== this.inputSize) {
|
|
609
|
+
throw new Error(`LSTMLayer.predict: expected array of length ${this.inputSize}, got ${inputs?.length}`);
|
|
610
|
+
}
|
|
330
611
|
const combined = [...inputs, ...this.h];
|
|
331
612
|
const c_prev = [...this.c];
|
|
332
613
|
const zf = this.forgetGate.linear(combined);
|
|
@@ -401,15 +682,15 @@ var LSTMLayer = class {
|
|
|
401
682
|
const scale = lr / T;
|
|
402
683
|
for (let k = 0; k < hSize; k++) {
|
|
403
684
|
for (let j = 0; j < combSize; j++) {
|
|
404
|
-
this.forgetGate.W[k][j]
|
|
405
|
-
this.inputGate.W[k][j]
|
|
406
|
-
this.cellGate.W[k][j]
|
|
407
|
-
this.outputGate.W[k][j]
|
|
685
|
+
this.forgetGate.W[k][j] = this._optimizers.forgetW[k][j].step(this.forgetGate.W[k][j], dWf[k][j], scale);
|
|
686
|
+
this.inputGate.W[k][j] = this._optimizers.inputW[k][j].step(this.inputGate.W[k][j], dWi[k][j], scale);
|
|
687
|
+
this.cellGate.W[k][j] = this._optimizers.cellW[k][j].step(this.cellGate.W[k][j], dWg[k][j], scale);
|
|
688
|
+
this.outputGate.W[k][j] = this._optimizers.outputW[k][j].step(this.outputGate.W[k][j], dWo[k][j], scale);
|
|
408
689
|
}
|
|
409
|
-
this.forgetGate.b[k]
|
|
410
|
-
this.inputGate.b[k]
|
|
411
|
-
this.cellGate.b[k]
|
|
412
|
-
this.outputGate.b[k]
|
|
690
|
+
this.forgetGate.b[k] = this._optimizers.forgetB[k].step(this.forgetGate.b[k], dbf[k], scale);
|
|
691
|
+
this.inputGate.b[k] = this._optimizers.inputB[k].step(this.inputGate.b[k], dbi[k], scale);
|
|
692
|
+
this.cellGate.b[k] = this._optimizers.cellB[k].step(this.cellGate.b[k], dbg[k], scale);
|
|
693
|
+
this.outputGate.b[k] = this._optimizers.outputB[k].step(this.outputGate.b[k], dbo[k], scale);
|
|
413
694
|
}
|
|
414
695
|
this._traj = [];
|
|
415
696
|
}
|
|
@@ -432,10 +713,38 @@ var LSTMLayer = class {
|
|
|
432
713
|
this.outputGate.W = data.outputGate.W;
|
|
433
714
|
this.outputGate.b = data.outputGate.b;
|
|
434
715
|
}
|
|
716
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
717
|
+
// Order: forgetGate (W, b), inputGate (W, b), cellGate (W, b), outputGate (W, b).
|
|
718
|
+
getWeightsFlat() {
|
|
719
|
+
const w = [];
|
|
720
|
+
for (const row of this.forgetGate.W) w.push(...row);
|
|
721
|
+
w.push(...this.forgetGate.b);
|
|
722
|
+
for (const row of this.inputGate.W) w.push(...row);
|
|
723
|
+
w.push(...this.inputGate.b);
|
|
724
|
+
for (const row of this.cellGate.W) w.push(...row);
|
|
725
|
+
w.push(...this.cellGate.b);
|
|
726
|
+
for (const row of this.outputGate.W) w.push(...row);
|
|
727
|
+
w.push(...this.outputGate.b);
|
|
728
|
+
return w;
|
|
729
|
+
}
|
|
730
|
+
setWeightsFlat(weights) {
|
|
731
|
+
let idx = 0;
|
|
732
|
+
for (let i = 0; i < this.forgetGate.W.length; i++)
|
|
733
|
+
for (let j = 0; j < this.forgetGate.W[i].length; j++) this.forgetGate.W[i][j] = weights[idx++];
|
|
734
|
+
for (let i = 0; i < this.forgetGate.b.length; i++) this.forgetGate.b[i] = weights[idx++];
|
|
735
|
+
for (let i = 0; i < this.inputGate.W.length; i++)
|
|
736
|
+
for (let j = 0; j < this.inputGate.W[i].length; j++) this.inputGate.W[i][j] = weights[idx++];
|
|
737
|
+
for (let i = 0; i < this.inputGate.b.length; i++) this.inputGate.b[i] = weights[idx++];
|
|
738
|
+
for (let i = 0; i < this.cellGate.W.length; i++)
|
|
739
|
+
for (let j = 0; j < this.cellGate.W[i].length; j++) this.cellGate.W[i][j] = weights[idx++];
|
|
740
|
+
for (let i = 0; i < this.cellGate.b.length; i++) this.cellGate.b[i] = weights[idx++];
|
|
741
|
+
for (let i = 0; i < this.outputGate.W.length; i++)
|
|
742
|
+
for (let j = 0; j < this.outputGate.W[i].length; j++) this.outputGate.W[i][j] = weights[idx++];
|
|
743
|
+
for (let i = 0; i < this.outputGate.b.length; i++) this.outputGate.b[i] = weights[idx++];
|
|
744
|
+
}
|
|
435
745
|
};
|
|
436
746
|
|
|
437
747
|
// src/NetworkLSTM.ts
|
|
438
|
-
var defaultOptimizer4 = () => new SGD();
|
|
439
748
|
var NetworkLSTM = class {
|
|
440
749
|
// [T][layer+1][neuron]
|
|
441
750
|
constructor(inputSize, hiddenSize, denseStructure, options = {}) {
|
|
@@ -443,7 +752,7 @@ var NetworkLSTM = class {
|
|
|
443
752
|
this.hiddenSize = hiddenSize;
|
|
444
753
|
this.lstm = new LSTMLayer(inputSize, hiddenSize);
|
|
445
754
|
const activation = options.denseActivation ?? sigmoid2;
|
|
446
|
-
const optimizer = options.optimizer ??
|
|
755
|
+
const optimizer = options.optimizer ?? defaultOptimizer;
|
|
447
756
|
this.denseLayers = [];
|
|
448
757
|
const sizes = [hiddenSize, ...denseStructure];
|
|
449
758
|
for (let i = 1; i < sizes.length; i++) {
|
|
@@ -458,6 +767,7 @@ var NetworkLSTM = class {
|
|
|
458
767
|
}
|
|
459
768
|
// ── Forward pass ──────────────────────────────────────────────────────────
|
|
460
769
|
predict(inputs) {
|
|
770
|
+
validateArray(inputs, this.inputSize, "NetworkLSTM.predict");
|
|
461
771
|
const h = this.lstm.predict(inputs);
|
|
462
772
|
const acts = [h];
|
|
463
773
|
for (const layer of this.denseLayers) {
|
|
@@ -533,6 +843,30 @@ var NetworkLSTM = class {
|
|
|
533
843
|
});
|
|
534
844
|
});
|
|
535
845
|
}
|
|
846
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
847
|
+
// Order: LSTM (flat), then dense layer 0, dense layer 1, ..., dense layer N.
|
|
848
|
+
getWeightsFlat() {
|
|
849
|
+
const w = [];
|
|
850
|
+
w.push(...this.lstm.getWeightsFlat());
|
|
851
|
+
for (const layer of this.denseLayers) {
|
|
852
|
+
for (const n of layer.neurons) {
|
|
853
|
+
w.push(...n.weights, n.bias);
|
|
854
|
+
}
|
|
855
|
+
}
|
|
856
|
+
return w;
|
|
857
|
+
}
|
|
858
|
+
setWeightsFlat(weights) {
|
|
859
|
+
let idx = 0;
|
|
860
|
+
const lstmLen = this.lstm.getWeightsFlat().length;
|
|
861
|
+
this.lstm.setWeightsFlat(weights.slice(idx, idx + lstmLen));
|
|
862
|
+
idx += lstmLen;
|
|
863
|
+
for (const layer of this.denseLayers) {
|
|
864
|
+
for (const n of layer.neurons) {
|
|
865
|
+
for (let j = 0; j < n.weights.length; j++) n.weights[j] = weights[idx++];
|
|
866
|
+
n.bias = weights[idx++];
|
|
867
|
+
}
|
|
868
|
+
}
|
|
869
|
+
}
|
|
536
870
|
};
|
|
537
871
|
|
|
538
872
|
// src/MatMul.ts
|
|
@@ -540,6 +874,9 @@ function matMul(A, B) {
|
|
|
540
874
|
const rows = A.length;
|
|
541
875
|
const inner = B.length;
|
|
542
876
|
const cols = B[0].length;
|
|
877
|
+
if (A[0].length !== B.length) {
|
|
878
|
+
throw new Error(`Incompatible dimensions for matrix multiplication: A cols (${A[0].length}) !== B rows (${B.length})`);
|
|
879
|
+
}
|
|
543
880
|
const C = Array.from({ length: rows }, () => new Array(cols).fill(0));
|
|
544
881
|
for (let i = 0; i < rows; i++)
|
|
545
882
|
for (let k = 0; k < inner; k++) {
|
|
@@ -590,6 +927,33 @@ var WeightMatrix = class {
|
|
|
590
927
|
this.W[i][j] = this.opts[i][j].step(this.W[i][j], g, lr);
|
|
591
928
|
}
|
|
592
929
|
}
|
|
930
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
931
|
+
getWeights() {
|
|
932
|
+
const w = [];
|
|
933
|
+
for (const row of this.W) w.push(...row);
|
|
934
|
+
return w;
|
|
935
|
+
}
|
|
936
|
+
setWeights(weights) {
|
|
937
|
+
let idx = 0;
|
|
938
|
+
for (let i = 0; i < this.W.length; i++)
|
|
939
|
+
for (let j = 0; j < this.W[i].length; j++) this.W[i][j] = weights[idx++];
|
|
940
|
+
}
|
|
941
|
+
};
|
|
942
|
+
var BiasVector = class {
|
|
943
|
+
constructor(size) {
|
|
944
|
+
this.values = new Array(size).fill(0);
|
|
945
|
+
this.opts = Array.from({ length: size }, () => new Adam());
|
|
946
|
+
}
|
|
947
|
+
update(grad, lr) {
|
|
948
|
+
for (let i = 0; i < this.values.length; i++)
|
|
949
|
+
this.values[i] = this.opts[i].step(this.values[i], grad[i], lr);
|
|
950
|
+
}
|
|
951
|
+
getWeights() {
|
|
952
|
+
return [...this.values];
|
|
953
|
+
}
|
|
954
|
+
setWeights(weights) {
|
|
955
|
+
for (let i = 0; i < this.values.length; i++) this.values[i] = weights[i];
|
|
956
|
+
}
|
|
593
957
|
};
|
|
594
958
|
var EmbeddingMatrix = class {
|
|
595
959
|
constructor(vocabSize, d_model) {
|
|
@@ -606,15 +970,29 @@ var EmbeddingMatrix = class {
|
|
|
606
970
|
for (let m = 0; m < this.W[idx].length; m++)
|
|
607
971
|
this.W[idx][m] += lr * grad[m];
|
|
608
972
|
}
|
|
973
|
+
// ── Serializable interface ─────────────────────────────────────────────────
|
|
974
|
+
// Flattened order: row 0, row 1, ... row (vocabSize-1)
|
|
975
|
+
getWeights() {
|
|
976
|
+
const w = [];
|
|
977
|
+
for (const row of this.W) w.push(...row);
|
|
978
|
+
return w;
|
|
979
|
+
}
|
|
980
|
+
setWeights(weights) {
|
|
981
|
+
let idx = 0;
|
|
982
|
+
for (let i = 0; i < this.W.length; i++)
|
|
983
|
+
for (let j = 0; j < this.W[i].length; j++)
|
|
984
|
+
this.W[i][j] = weights[idx++];
|
|
985
|
+
}
|
|
609
986
|
};
|
|
610
987
|
|
|
611
988
|
// src/AttentionHead.ts
|
|
612
989
|
var AttentionHead = class {
|
|
613
|
-
constructor(d_model, d_k, d_v) {
|
|
990
|
+
constructor(d_model, d_k, d_v, causal = false) {
|
|
614
991
|
// d_v × d_model
|
|
615
992
|
this.cache = null;
|
|
616
993
|
this.d_k = d_k;
|
|
617
994
|
this.d_v = d_v;
|
|
995
|
+
this.causal = causal;
|
|
618
996
|
this.Wq = new WeightMatrix(d_k, d_model);
|
|
619
997
|
this.Wk = new WeightMatrix(d_k, d_model);
|
|
620
998
|
this.Wv = new WeightMatrix(d_v, d_model);
|
|
@@ -635,10 +1013,10 @@ var AttentionHead = class {
|
|
|
635
1013
|
);
|
|
636
1014
|
const scores = Array.from(
|
|
637
1015
|
{ length: seqLen },
|
|
638
|
-
(_, i) => Array.from(
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
)
|
|
1016
|
+
(_, i) => Array.from({ length: seqLen }, (_2, j) => {
|
|
1017
|
+
if (this.causal && j > i) return -Infinity;
|
|
1018
|
+
return Q[i].reduce((s, q, k) => s + q * K[j][k], 0) * scale;
|
|
1019
|
+
})
|
|
642
1020
|
);
|
|
643
1021
|
const attn = scores.map((row) => softmax(row));
|
|
644
1022
|
const out = Array.from(
|
|
@@ -662,6 +1040,7 @@ var AttentionHead = class {
|
|
|
662
1040
|
// 5. dWq = dQ^T @ X, dWk = dK^T @ X, dWv = dV^T @ X
|
|
663
1041
|
// 6. dX = dQ @ Wq + dK @ Wk + dV @ Wv
|
|
664
1042
|
backward(dOut, lr) {
|
|
1043
|
+
if (!this.cache) throw new Error("AttentionHead.backward() called before predict()");
|
|
665
1044
|
const { X, Q, K, V, attn } = this.cache;
|
|
666
1045
|
const seqLen = X.length;
|
|
667
1046
|
const d_model = X[0].length;
|
|
@@ -734,21 +1113,40 @@ var AttentionHead = class {
|
|
|
734
1113
|
getAttentionWeights() {
|
|
735
1114
|
return this.cache ? this.cache.attn : null;
|
|
736
1115
|
}
|
|
1116
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
1117
|
+
// Order: Wq, Wk, Wv.
|
|
1118
|
+
getWeights() {
|
|
1119
|
+
const w = [];
|
|
1120
|
+
for (const row of this.Wq.W) w.push(...row);
|
|
1121
|
+
for (const row of this.Wk.W) w.push(...row);
|
|
1122
|
+
for (const row of this.Wv.W) w.push(...row);
|
|
1123
|
+
return w;
|
|
1124
|
+
}
|
|
1125
|
+
setWeights(weights) {
|
|
1126
|
+
let idx = 0;
|
|
1127
|
+
for (let i = 0; i < this.Wq.W.length; i++)
|
|
1128
|
+
for (let j = 0; j < this.Wq.W[i].length; j++) this.Wq.W[i][j] = weights[idx++];
|
|
1129
|
+
for (let i = 0; i < this.Wk.W.length; i++)
|
|
1130
|
+
for (let j = 0; j < this.Wk.W[i].length; j++) this.Wk.W[i][j] = weights[idx++];
|
|
1131
|
+
for (let i = 0; i < this.Wv.W.length; i++)
|
|
1132
|
+
for (let j = 0; j < this.Wv.W[i].length; j++) this.Wv.W[i][j] = weights[idx++];
|
|
1133
|
+
}
|
|
737
1134
|
};
|
|
738
1135
|
|
|
739
1136
|
// src/MultiHeadAttention.ts
|
|
740
1137
|
var MultiHeadAttention = class {
|
|
741
1138
|
// seqLen × (nHeads * d_k)
|
|
742
|
-
constructor(d_model, nHeads) {
|
|
1139
|
+
constructor(d_model, nHeads, causal = false) {
|
|
743
1140
|
// d_model × (nHeads * d_k)
|
|
744
1141
|
// Cached for backward
|
|
745
1142
|
this._concat = null;
|
|
746
1143
|
this.nHeads = nHeads;
|
|
747
1144
|
this.d_model = d_model;
|
|
748
1145
|
this.d_k = Math.floor(d_model / nHeads);
|
|
1146
|
+
this.causal = causal;
|
|
749
1147
|
this.heads = Array.from(
|
|
750
1148
|
{ length: nHeads },
|
|
751
|
-
() => new AttentionHead(d_model, this.d_k, this.d_k)
|
|
1149
|
+
() => new AttentionHead(d_model, this.d_k, this.d_k, causal)
|
|
752
1150
|
);
|
|
753
1151
|
this.Wo = new WeightMatrix(d_model, nHeads * this.d_k);
|
|
754
1152
|
}
|
|
@@ -770,6 +1168,7 @@ var MultiHeadAttention = class {
|
|
|
770
1168
|
// ── Backward ──────────────────────────────────────────────────────────────
|
|
771
1169
|
// dOut: seqLen × d_model → dX: seqLen × d_model
|
|
772
1170
|
backward(dOut, lr) {
|
|
1171
|
+
if (!this._concat) throw new Error("MultiHeadAttention.backward() called before predict()");
|
|
773
1172
|
const seqLen = dOut.length;
|
|
774
1173
|
const concatD = this.nHeads * this.d_k;
|
|
775
1174
|
const d_model = this.d_model;
|
|
@@ -807,6 +1206,31 @@ var MultiHeadAttention = class {
|
|
|
807
1206
|
getAttentionWeights() {
|
|
808
1207
|
return this.heads.map((h) => h.getAttentionWeights());
|
|
809
1208
|
}
|
|
1209
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
1210
|
+
// Order: head0 (Wq, Wk, Wv), head1, ..., headN, then Wo.
|
|
1211
|
+
getWeights() {
|
|
1212
|
+
const w = [];
|
|
1213
|
+
for (const head of this.heads) {
|
|
1214
|
+
for (const row of head.Wq.W) w.push(...row);
|
|
1215
|
+
for (const row of head.Wk.W) w.push(...row);
|
|
1216
|
+
for (const row of head.Wv.W) w.push(...row);
|
|
1217
|
+
}
|
|
1218
|
+
for (const row of this.Wo.W) w.push(...row);
|
|
1219
|
+
return w;
|
|
1220
|
+
}
|
|
1221
|
+
setWeights(weights) {
|
|
1222
|
+
let idx = 0;
|
|
1223
|
+
for (const head of this.heads) {
|
|
1224
|
+
for (let i = 0; i < head.Wq.W.length; i++)
|
|
1225
|
+
for (let j = 0; j < head.Wq.W[i].length; j++) head.Wq.W[i][j] = weights[idx++];
|
|
1226
|
+
for (let i = 0; i < head.Wk.W.length; i++)
|
|
1227
|
+
for (let j = 0; j < head.Wk.W[i].length; j++) head.Wk.W[i][j] = weights[idx++];
|
|
1228
|
+
for (let i = 0; i < head.Wv.W.length; i++)
|
|
1229
|
+
for (let j = 0; j < head.Wv.W[i].length; j++) head.Wv.W[i][j] = weights[idx++];
|
|
1230
|
+
}
|
|
1231
|
+
for (let i = 0; i < this.Wo.W.length; i++)
|
|
1232
|
+
for (let j = 0; j < this.Wo.W[i].length; j++) this.Wo.W[i][j] = weights[idx++];
|
|
1233
|
+
}
|
|
810
1234
|
};
|
|
811
1235
|
|
|
812
1236
|
// src/LayerNorm.ts
|
|
@@ -849,20 +1273,32 @@ var LayerNorm = class {
|
|
|
849
1273
|
backwardOne(dOut, pos, lr) {
|
|
850
1274
|
const { x_norm, std } = this._cache[pos];
|
|
851
1275
|
const N = dOut.length;
|
|
1276
|
+
const gammaOld = this.gamma.slice();
|
|
852
1277
|
for (let i = 0; i < N; i++) {
|
|
853
1278
|
this.gamma[i] += lr * dOut[i] * x_norm[i];
|
|
854
1279
|
this.beta[i] += lr * dOut[i];
|
|
855
1280
|
}
|
|
856
|
-
const D = dOut.map((d, i) => d *
|
|
1281
|
+
const D = dOut.map((d, i) => d * gammaOld[i]);
|
|
857
1282
|
const mD = D.reduce((s, v) => s + v, 0) / N;
|
|
858
1283
|
const mDxn = D.reduce((s, d, i) => s + d * x_norm[i], 0) / N;
|
|
859
1284
|
return D.map((d, i) => (d - mD - x_norm[i] * mDxn) / std);
|
|
860
1285
|
}
|
|
1286
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
1287
|
+
// Order: gamma, beta.
|
|
1288
|
+
getWeights() {
|
|
1289
|
+
return [...this.gamma, ...this.beta];
|
|
1290
|
+
}
|
|
1291
|
+
setWeights(weights) {
|
|
1292
|
+
const dim = this.gamma.length;
|
|
1293
|
+
for (let i = 0; i < dim; i++) this.gamma[i] = weights[i];
|
|
1294
|
+
for (let i = 0; i < dim; i++) this.beta[i] = weights[dim + i];
|
|
1295
|
+
}
|
|
861
1296
|
};
|
|
862
1297
|
|
|
863
1298
|
// src/TransformerBlock.ts
|
|
864
1299
|
var TransformerBlock = class {
|
|
865
|
-
constructor({ d_model, nHeads, d_ff }) {
|
|
1300
|
+
constructor({ d_model, nHeads, d_ff, causal = false }) {
|
|
1301
|
+
// d_model
|
|
866
1302
|
// Forward caches (needed for backprop)
|
|
867
1303
|
this._X = null;
|
|
868
1304
|
this._attnOut = null;
|
|
@@ -874,15 +1310,13 @@ var TransformerBlock = class {
|
|
|
874
1310
|
this._ff2Out = null;
|
|
875
1311
|
this.d_model = d_model;
|
|
876
1312
|
this.d_ff = d_ff;
|
|
877
|
-
this.attn = new MultiHeadAttention(d_model, nHeads);
|
|
1313
|
+
this.attn = new MultiHeadAttention(d_model, nHeads, causal);
|
|
878
1314
|
this.norm1 = new LayerNorm(d_model);
|
|
879
1315
|
this.norm2 = new LayerNorm(d_model);
|
|
880
1316
|
this.ff1 = new WeightMatrix(d_ff, d_model);
|
|
881
1317
|
this.ff2 = new WeightMatrix(d_model, d_ff);
|
|
882
|
-
this.b1 = new
|
|
883
|
-
this.b2 = new
|
|
884
|
-
this.b1Opts = Array.from({ length: d_ff }, () => new Adam());
|
|
885
|
-
this.b2Opts = Array.from({ length: d_model }, () => new Adam());
|
|
1318
|
+
this.b1 = new BiasVector(d_ff);
|
|
1319
|
+
this.b2 = new BiasVector(d_model);
|
|
886
1320
|
}
|
|
887
1321
|
// ── Forward ───────────────────────────────────────────────────────────────
|
|
888
1322
|
// X: seqLen × d_model → out: seqLen × d_model
|
|
@@ -895,11 +1329,11 @@ var TransformerBlock = class {
|
|
|
895
1329
|
return this.norm1.predictOne(added, i);
|
|
896
1330
|
});
|
|
897
1331
|
const ff1Pre = h1.map(
|
|
898
|
-
(h) => this.ff1.W.map((row, k) => row.reduce((s, w, m) => s + w * h[m], this.b1[k]))
|
|
1332
|
+
(h) => this.ff1.W.map((row, k) => row.reduce((s, w, m) => s + w * h[m], this.b1.values[k]))
|
|
899
1333
|
);
|
|
900
1334
|
const ff1Out = ff1Pre.map((pre) => pre.map((v) => Math.max(0, v)));
|
|
901
1335
|
const ff2Out = ff1Out.map(
|
|
902
|
-
(h) => this.ff2.W.map((row, k) => row.reduce((s, w, m) => s + w * h[m], this.b2[k]))
|
|
1336
|
+
(h) => this.ff2.W.map((row, k) => row.reduce((s, w, m) => s + w * h[m], this.b2.values[k]))
|
|
903
1337
|
);
|
|
904
1338
|
this.norm2.resetCache(seqLen);
|
|
905
1339
|
const out = h1.map((h, i) => {
|
|
@@ -917,6 +1351,9 @@ var TransformerBlock = class {
|
|
|
917
1351
|
// ── Backward ──────────────────────────────────────────────────────────────
|
|
918
1352
|
// dOut: seqLen × d_model → dX: seqLen × d_model
|
|
919
1353
|
backward(dOut, lr) {
|
|
1354
|
+
if (!this._h1 || !this._ff1Out || !this._ff1Pre) {
|
|
1355
|
+
throw new Error("TransformerBlock.backward() called before predict()");
|
|
1356
|
+
}
|
|
920
1357
|
const seqLen = dOut.length;
|
|
921
1358
|
const d_model = this.d_model;
|
|
922
1359
|
const h1 = this._h1;
|
|
@@ -941,8 +1378,7 @@ var TransformerBlock = class {
|
|
|
941
1378
|
(_, m) => dAdded2.reduce((s, da) => s + da[m], 0)
|
|
942
1379
|
);
|
|
943
1380
|
this.ff2.update(dW2, lr);
|
|
944
|
-
|
|
945
|
-
this.b2[m] = this.b2Opts[m].step(this.b2[m], db2[m], lr);
|
|
1381
|
+
this.b2.update(db2, lr);
|
|
946
1382
|
const dFf1Pre = dFf1Out.map(
|
|
947
1383
|
(d, i) => d.map((v, k) => ff1Pre[i][k] > 0 ? v : 0)
|
|
948
1384
|
);
|
|
@@ -964,8 +1400,7 @@ var TransformerBlock = class {
|
|
|
964
1400
|
(_, k) => dFf1Pre.reduce((s, dp) => s + dp[k], 0)
|
|
965
1401
|
);
|
|
966
1402
|
this.ff1.update(dW1, lr);
|
|
967
|
-
|
|
968
|
-
this.b1[k] = this.b1Opts[k].step(this.b1[k], db1[k], lr);
|
|
1403
|
+
this.b1.update(db1, lr);
|
|
969
1404
|
const dH1 = Array.from(
|
|
970
1405
|
{ length: seqLen },
|
|
971
1406
|
(_, i) => dH1_fromFf[i].map((v, m) => v + dAdded2[i][m])
|
|
@@ -987,6 +1422,36 @@ var TransformerBlock = class {
|
|
|
987
1422
|
getAttentionWeights() {
|
|
988
1423
|
return this.attn.getAttentionWeights();
|
|
989
1424
|
}
|
|
1425
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
1426
|
+
// Order: attn (MHA), norm1 (gamma, beta), ff1, b1, ff2, b2, norm2 (gamma, beta).
|
|
1427
|
+
getWeights() {
|
|
1428
|
+
const w = [];
|
|
1429
|
+
w.push(...this.attn.getWeights());
|
|
1430
|
+
w.push(...this.norm1.gamma, ...this.norm1.beta);
|
|
1431
|
+
for (const row of this.ff1.W) w.push(...row);
|
|
1432
|
+
w.push(...this.b1.values);
|
|
1433
|
+
for (const row of this.ff2.W) w.push(...row);
|
|
1434
|
+
w.push(...this.b2.values);
|
|
1435
|
+
w.push(...this.norm2.gamma, ...this.norm2.beta);
|
|
1436
|
+
return w;
|
|
1437
|
+
}
|
|
1438
|
+
setWeights(weights) {
|
|
1439
|
+
let idx = 0;
|
|
1440
|
+
const attnLen = this.attn.getWeights().length;
|
|
1441
|
+
this.attn.setWeights(weights.slice(idx, idx + attnLen));
|
|
1442
|
+
idx += attnLen;
|
|
1443
|
+
this.norm1.setWeights(weights.slice(idx, idx + this.norm1.getWeights().length));
|
|
1444
|
+
idx += this.norm1.getWeights().length;
|
|
1445
|
+
this.ff1.setWeights(weights.slice(idx, idx + this.ff1.getWeights().length));
|
|
1446
|
+
idx += this.ff1.getWeights().length;
|
|
1447
|
+
this.b1.setWeights(weights.slice(idx, idx + this.b1.values.length));
|
|
1448
|
+
idx += this.b1.values.length;
|
|
1449
|
+
this.ff2.setWeights(weights.slice(idx, idx + this.ff2.getWeights().length));
|
|
1450
|
+
idx += this.ff2.getWeights().length;
|
|
1451
|
+
this.b2.setWeights(weights.slice(idx, idx + this.b2.values.length));
|
|
1452
|
+
idx += this.b2.values.length;
|
|
1453
|
+
this.norm2.setWeights(weights.slice(idx, idx + this.norm2.getWeights().length));
|
|
1454
|
+
}
|
|
990
1455
|
};
|
|
991
1456
|
|
|
992
1457
|
// src/NetworkTransformer.ts
|
|
@@ -1011,8 +1476,7 @@ var NetworkTransformer = class {
|
|
|
1011
1476
|
() => new TransformerBlock({ d_model, nHeads, d_ff })
|
|
1012
1477
|
);
|
|
1013
1478
|
this.outputProj = new WeightMatrix(nClasses, d_model);
|
|
1014
|
-
this.outputBias = new
|
|
1015
|
-
this.outBiasOpts = Array.from({ length: nClasses }, () => new Adam());
|
|
1479
|
+
this.outputBias = new BiasVector(nClasses);
|
|
1016
1480
|
}
|
|
1017
1481
|
// ── Forward pass ──────────────────────────────────────────────────────────
|
|
1018
1482
|
// tokens: seqLen integer ids → seqLen * nClasses logits (flattened)
|
|
@@ -1020,7 +1484,7 @@ var NetworkTransformer = class {
|
|
|
1020
1484
|
const h = this._forward(tokens);
|
|
1021
1485
|
return h.flatMap(
|
|
1022
1486
|
(hi) => this.outputProj.W.map(
|
|
1023
|
-
(row, c) => row.reduce((s, w, m) => s + w * hi[m], this.outputBias[c])
|
|
1487
|
+
(row, c) => row.reduce((s, w, m) => s + w * hi[m], this.outputBias.values[c])
|
|
1024
1488
|
)
|
|
1025
1489
|
);
|
|
1026
1490
|
}
|
|
@@ -1034,7 +1498,7 @@ var NetworkTransformer = class {
|
|
|
1034
1498
|
const h = this._forward(tokens);
|
|
1035
1499
|
const logits = h.map(
|
|
1036
1500
|
(hi) => this.outputProj.W.map(
|
|
1037
|
-
(row, c) => row.reduce((s, w, m) => s + w * hi[m], this.outputBias[c])
|
|
1501
|
+
(row, c) => row.reduce((s, w, m) => s + w * hi[m], this.outputBias.values[c])
|
|
1038
1502
|
)
|
|
1039
1503
|
);
|
|
1040
1504
|
let loss = 0;
|
|
@@ -1069,8 +1533,7 @@ var NetworkTransformer = class {
|
|
|
1069
1533
|
(_, c) => dLogits.reduce((s, dl) => s + dl[c], 0)
|
|
1070
1534
|
);
|
|
1071
1535
|
this.outputProj.update(dWout, lr);
|
|
1072
|
-
|
|
1073
|
-
this.outputBias[c] = this.outBiasOpts[c].step(this.outputBias[c], dBout[c], lr);
|
|
1536
|
+
this.outputBias.update(dBout, lr);
|
|
1074
1537
|
let dX = dH;
|
|
1075
1538
|
for (let b = this.blocks.length - 1; b >= 0; b--)
|
|
1076
1539
|
dX = this.blocks[b].backward(dX, lr);
|
|
@@ -1085,6 +1548,35 @@ var NetworkTransformer = class {
|
|
|
1085
1548
|
getAttentionWeights() {
|
|
1086
1549
|
return this.blocks.map((b) => b.getAttentionWeights());
|
|
1087
1550
|
}
|
|
1551
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
1552
|
+
// Order: tokenEmb, posEmb, block0, block1, ..., blockN, outputProj, outputBias.
|
|
1553
|
+
getWeights() {
|
|
1554
|
+
const w = [];
|
|
1555
|
+
w.push(...this.tokenEmb.getWeights());
|
|
1556
|
+
w.push(...this.posEmb.getWeights());
|
|
1557
|
+
for (const block of this.blocks) w.push(...block.getWeights());
|
|
1558
|
+
w.push(...this.outputProj.getWeights());
|
|
1559
|
+
w.push(...this.outputBias.getWeights());
|
|
1560
|
+
return w;
|
|
1561
|
+
}
|
|
1562
|
+
setWeights(weights) {
|
|
1563
|
+
let idx = 0;
|
|
1564
|
+
const tokenEmbLen = this.tokenEmb.getWeights().length;
|
|
1565
|
+
this.tokenEmb.setWeights(weights.slice(idx, idx + tokenEmbLen));
|
|
1566
|
+
idx += tokenEmbLen;
|
|
1567
|
+
const posEmbLen = this.posEmb.getWeights().length;
|
|
1568
|
+
this.posEmb.setWeights(weights.slice(idx, idx + posEmbLen));
|
|
1569
|
+
idx += posEmbLen;
|
|
1570
|
+
for (const block of this.blocks) {
|
|
1571
|
+
const blockLen = block.getWeights().length;
|
|
1572
|
+
block.setWeights(weights.slice(idx, idx + blockLen));
|
|
1573
|
+
idx += blockLen;
|
|
1574
|
+
}
|
|
1575
|
+
const outProjLen = this.outputProj.getWeights().length;
|
|
1576
|
+
this.outputProj.setWeights(weights.slice(idx, idx + outProjLen));
|
|
1577
|
+
idx += outProjLen;
|
|
1578
|
+
this.outputBias.setWeights(weights.slice(idx, idx + this.outputBias.values.length));
|
|
1579
|
+
}
|
|
1088
1580
|
// ── Internal ──────────────────────────────────────────────────────────────
|
|
1089
1581
|
// Shared embedding + block forward pass.
|
|
1090
1582
|
_forward(tokens) {
|
|
@@ -1104,25 +1596,28 @@ var NetworkTransformerRL = class {
|
|
|
1104
1596
|
constructor(seqLen, inputDim, options = {}) {
|
|
1105
1597
|
// Forward caches para backprop
|
|
1106
1598
|
this._projected = null;
|
|
1599
|
+
// For max pooling backward: argmax per dimension across all positions
|
|
1600
|
+
this._argmax = null;
|
|
1107
1601
|
const {
|
|
1108
1602
|
d_model = 32,
|
|
1109
1603
|
nHeads = 2,
|
|
1110
1604
|
d_ff = 64,
|
|
1111
1605
|
nBlocks = 2,
|
|
1112
|
-
nActions = 2
|
|
1606
|
+
nActions = 2,
|
|
1607
|
+
pooling = "weighted"
|
|
1113
1608
|
} = options;
|
|
1114
1609
|
this.seqLen = seqLen;
|
|
1115
1610
|
this.inputDim = inputDim;
|
|
1116
1611
|
this.d_model = d_model;
|
|
1117
1612
|
this.nActions = nActions;
|
|
1613
|
+
this._pooling = pooling;
|
|
1118
1614
|
this.inputProj = new WeightMatrix(d_model, inputDim);
|
|
1119
1615
|
this.blocks = Array.from(
|
|
1120
1616
|
{ length: nBlocks },
|
|
1121
|
-
() => new TransformerBlock({ d_model, nHeads, d_ff })
|
|
1617
|
+
() => new TransformerBlock({ d_model, nHeads, d_ff, causal: true })
|
|
1122
1618
|
);
|
|
1123
1619
|
this.outputProj = new WeightMatrix(nActions, d_model);
|
|
1124
|
-
this.outputBias = new
|
|
1125
|
-
this.outBiasOpts = Array.from({ length: nActions }, () => new Adam());
|
|
1620
|
+
this.outputBias = new BiasVector(nActions);
|
|
1126
1621
|
}
|
|
1127
1622
|
// ── Forward ────────────────────────────────────────────────────────────────
|
|
1128
1623
|
// sequence: seqLen × inputDim → nActions Q-values
|
|
@@ -1130,7 +1625,7 @@ var NetworkTransformerRL = class {
|
|
|
1130
1625
|
const h = this._forward(sequence);
|
|
1131
1626
|
const pooled = this._pool(h);
|
|
1132
1627
|
return this.outputProj.W.map(
|
|
1133
|
-
(row, c) => row.reduce((s, w, m) => s + w * pooled[m], this.outputBias[c])
|
|
1628
|
+
(row, c) => row.reduce((s, w, m) => s + w * pooled[m], this.outputBias.values[c])
|
|
1134
1629
|
);
|
|
1135
1630
|
}
|
|
1136
1631
|
// ── Training ────────────────────────────────────────────────────────────────
|
|
@@ -1142,7 +1637,7 @@ var NetworkTransformerRL = class {
|
|
|
1142
1637
|
const h = this._forward(sequence);
|
|
1143
1638
|
const pooled = this._pool(h);
|
|
1144
1639
|
const pred = this.outputProj.W.map(
|
|
1145
|
-
(row, c) => row.reduce((s, w, m) => s + w * pooled[m], this.outputBias[c])
|
|
1640
|
+
(row, c) => row.reduce((s, w, m) => s + w * pooled[m], this.outputBias.values[c])
|
|
1146
1641
|
);
|
|
1147
1642
|
const n = this.nActions;
|
|
1148
1643
|
let loss = 0;
|
|
@@ -1165,13 +1660,8 @@ var NetworkTransformerRL = class {
|
|
|
1165
1660
|
);
|
|
1166
1661
|
const dBout = dPred.slice();
|
|
1167
1662
|
this.outputProj.update(dWout, lr);
|
|
1168
|
-
|
|
1169
|
-
|
|
1170
|
-
let dH = Array.from(
|
|
1171
|
-
{ length: this.seqLen },
|
|
1172
|
-
(_, i) => dPooled.map((v) => v / this.seqLen)
|
|
1173
|
-
// Gradiente dividido entre posiciones
|
|
1174
|
-
);
|
|
1663
|
+
this.outputBias.update(dBout, lr);
|
|
1664
|
+
let dH = this._distributePoolGradient(dPooled);
|
|
1175
1665
|
for (let b = this.blocks.length - 1; b >= 0; b--)
|
|
1176
1666
|
dH = this.blocks[b].backward(dH, lr);
|
|
1177
1667
|
for (let i = 0; i < this.seqLen; i++) {
|
|
@@ -1190,8 +1680,32 @@ var NetworkTransformerRL = class {
|
|
|
1190
1680
|
getAttentionWeights() {
|
|
1191
1681
|
return this.blocks.map((b) => b.getAttentionWeights());
|
|
1192
1682
|
}
|
|
1193
|
-
// ──
|
|
1194
|
-
|
|
1683
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
1684
|
+
// Order: inputProj, block0, block1, ..., blockN, outputProj, outputBias.
|
|
1685
|
+
getWeightsFlat() {
|
|
1686
|
+
const w = [];
|
|
1687
|
+
w.push(...this.inputProj.getWeights());
|
|
1688
|
+
for (const block of this.blocks) w.push(...block.getWeights());
|
|
1689
|
+
w.push(...this.outputProj.getWeights());
|
|
1690
|
+
w.push(...this.outputBias.getWeights());
|
|
1691
|
+
return w;
|
|
1692
|
+
}
|
|
1693
|
+
setWeightsFlat(weights) {
|
|
1694
|
+
let idx = 0;
|
|
1695
|
+
const inputProjLen = this.inputProj.getWeights().length;
|
|
1696
|
+
this.inputProj.setWeights(weights.slice(idx, idx + inputProjLen));
|
|
1697
|
+
idx += inputProjLen;
|
|
1698
|
+
for (const block of this.blocks) {
|
|
1699
|
+
const blockLen = block.getWeights().length;
|
|
1700
|
+
block.setWeights(weights.slice(idx, idx + blockLen));
|
|
1701
|
+
idx += blockLen;
|
|
1702
|
+
}
|
|
1703
|
+
const outProjLen = this.outputProj.getWeights().length;
|
|
1704
|
+
this.outputProj.setWeights(weights.slice(idx, idx + outProjLen));
|
|
1705
|
+
idx += outProjLen;
|
|
1706
|
+
this.outputBias.setWeights(weights.slice(idx, idx + this.outputBias.values.length));
|
|
1707
|
+
}
|
|
1708
|
+
getWeightsStructured() {
|
|
1195
1709
|
return {
|
|
1196
1710
|
inputProj: this.inputProj.W.map((r) => [...r]),
|
|
1197
1711
|
blocks: this.blocks.map((b) => ({
|
|
@@ -1207,17 +1721,15 @@ var NetworkTransformerRL = class {
|
|
|
1207
1721
|
norm2: { gamma: [...b.norm2.gamma], beta: [...b.norm2.beta] },
|
|
1208
1722
|
ff1: b.ff1.W.map((r) => [...r]),
|
|
1209
1723
|
ff2: b.ff2.W.map((r) => [...r]),
|
|
1210
|
-
b1: [...b.b1],
|
|
1211
|
-
b2: [...b.b2]
|
|
1724
|
+
b1: [...b.b1.values],
|
|
1725
|
+
b2: [...b.b2.values]
|
|
1212
1726
|
})),
|
|
1213
1727
|
outputProj: this.outputProj.W.map((r) => [...r]),
|
|
1214
|
-
outputBias: [...this.outputBias]
|
|
1728
|
+
outputBias: [...this.outputBias.values]
|
|
1215
1729
|
};
|
|
1216
1730
|
}
|
|
1217
|
-
|
|
1218
|
-
data.inputProj.
|
|
1219
|
-
this.inputProj.W[i] = [...row];
|
|
1220
|
-
});
|
|
1731
|
+
setWeightsStructured(data) {
|
|
1732
|
+
this.inputProj.setWeights(data.inputProj.flat());
|
|
1221
1733
|
data.blocks.forEach((bd, b) => {
|
|
1222
1734
|
const blk = this.blocks[b];
|
|
1223
1735
|
bd.attn.heads.forEach((hd, h) => {
|
|
@@ -1232,11 +1744,20 @@ var NetworkTransformerRL = class {
|
|
|
1232
1744
|
blk.norm2.beta = [...bd.norm2.beta];
|
|
1233
1745
|
blk.ff1.W = bd.ff1.map((r) => [...r]);
|
|
1234
1746
|
blk.ff2.W = bd.ff2.map((r) => [...r]);
|
|
1235
|
-
blk.b1
|
|
1236
|
-
blk.b2
|
|
1747
|
+
blk.b1.setWeights(bd.b1);
|
|
1748
|
+
blk.b2.setWeights(bd.b2);
|
|
1237
1749
|
});
|
|
1238
1750
|
this.outputProj.W = data.outputProj.map((r) => [...r]);
|
|
1239
|
-
this.outputBias
|
|
1751
|
+
this.outputBias.setWeights(data.outputBias);
|
|
1752
|
+
}
|
|
1753
|
+
// ── Serializable interface (flat array) ────────────────────────────────────
|
|
1754
|
+
// These satisfy the Serializable interface from ModelSaver, which requires
|
|
1755
|
+
// getWeights(): number[] and setWeights(weights: number[]): void.
|
|
1756
|
+
getWeights() {
|
|
1757
|
+
return this.getWeightsFlat();
|
|
1758
|
+
}
|
|
1759
|
+
setWeights(weights) {
|
|
1760
|
+
this.setWeightsFlat(weights);
|
|
1240
1761
|
}
|
|
1241
1762
|
// ── Internal ────────────────────────────────────────────────────────────────
|
|
1242
1763
|
_forward(sequence) {
|
|
@@ -1251,6 +1772,44 @@ var NetworkTransformerRL = class {
|
|
|
1251
1772
|
return h;
|
|
1252
1773
|
}
|
|
1253
1774
|
_pool(h) {
|
|
1775
|
+
switch (this._pooling) {
|
|
1776
|
+
case "avg":
|
|
1777
|
+
return this._poolAvg(h);
|
|
1778
|
+
case "max":
|
|
1779
|
+
return this._poolMax(h);
|
|
1780
|
+
case "last":
|
|
1781
|
+
return this._poolLast(h);
|
|
1782
|
+
case "weighted":
|
|
1783
|
+
default:
|
|
1784
|
+
return this._poolWeighted(h);
|
|
1785
|
+
}
|
|
1786
|
+
}
|
|
1787
|
+
_poolAvg(h) {
|
|
1788
|
+
const n = h.length;
|
|
1789
|
+
return Array.from({ length: this.d_model }, (_, m) => {
|
|
1790
|
+
let sum = 0;
|
|
1791
|
+
for (let i = 0; i < n; i++)
|
|
1792
|
+
sum += h[i][m];
|
|
1793
|
+
return sum / n;
|
|
1794
|
+
});
|
|
1795
|
+
}
|
|
1796
|
+
_poolMax(h) {
|
|
1797
|
+
this._argmax = new Array(this.d_model).fill(0);
|
|
1798
|
+
return Array.from({ length: this.d_model }, (_, m) => {
|
|
1799
|
+
let maxVal = -Infinity;
|
|
1800
|
+
for (let i = 0; i < h.length; i++) {
|
|
1801
|
+
if (h[i][m] > maxVal) {
|
|
1802
|
+
maxVal = h[i][m];
|
|
1803
|
+
this._argmax[m] = i;
|
|
1804
|
+
}
|
|
1805
|
+
}
|
|
1806
|
+
return maxVal;
|
|
1807
|
+
});
|
|
1808
|
+
}
|
|
1809
|
+
_poolLast(h) {
|
|
1810
|
+
return [...h[h.length - 1]];
|
|
1811
|
+
}
|
|
1812
|
+
_poolWeighted(h) {
|
|
1254
1813
|
const weights = Array.from(
|
|
1255
1814
|
{ length: this.seqLen },
|
|
1256
1815
|
(_, i) => i === this.seqLen - 1 ? 2 : 1
|
|
@@ -1263,6 +1822,55 @@ var NetworkTransformerRL = class {
|
|
|
1263
1822
|
return sum / totalWeight;
|
|
1264
1823
|
});
|
|
1265
1824
|
}
|
|
1825
|
+
/** Returns the current pooling type for inspection. */
|
|
1826
|
+
getPoolingType() {
|
|
1827
|
+
return this._pooling;
|
|
1828
|
+
}
|
|
1829
|
+
// ── Helper: distribute pooled gradient back to each position ────────────────
|
|
1830
|
+
// Must match the same distribution as _pool() used during forward.
|
|
1831
|
+
_distributePoolGradient(dPooled) {
|
|
1832
|
+
switch (this._pooling) {
|
|
1833
|
+
case "avg": {
|
|
1834
|
+
const n = this.seqLen;
|
|
1835
|
+
return Array.from(
|
|
1836
|
+
{ length: n },
|
|
1837
|
+
() => dPooled.map((v) => v / n)
|
|
1838
|
+
);
|
|
1839
|
+
}
|
|
1840
|
+
case "max": {
|
|
1841
|
+
if (!this._argmax) {
|
|
1842
|
+
const n = this.seqLen;
|
|
1843
|
+
return Array.from(
|
|
1844
|
+
{ length: n },
|
|
1845
|
+
() => dPooled.map((v) => v / n)
|
|
1846
|
+
);
|
|
1847
|
+
}
|
|
1848
|
+
const argmax = this._argmax;
|
|
1849
|
+
return Array.from(
|
|
1850
|
+
{ length: this.seqLen },
|
|
1851
|
+
(_, i) => dPooled.map((v, m) => i === argmax[m] ? v : 0)
|
|
1852
|
+
);
|
|
1853
|
+
}
|
|
1854
|
+
case "last": {
|
|
1855
|
+
return Array.from(
|
|
1856
|
+
{ length: this.seqLen },
|
|
1857
|
+
(_, i) => i === this.seqLen - 1 ? [...dPooled] : new Array(this.d_model).fill(0)
|
|
1858
|
+
);
|
|
1859
|
+
}
|
|
1860
|
+
case "weighted":
|
|
1861
|
+
default: {
|
|
1862
|
+
const weights = Array.from(
|
|
1863
|
+
{ length: this.seqLen },
|
|
1864
|
+
(_, i) => i === this.seqLen - 1 ? 2 : 1
|
|
1865
|
+
);
|
|
1866
|
+
const totalWeight = weights.reduce((a, b) => a + b, 0);
|
|
1867
|
+
return Array.from(
|
|
1868
|
+
{ length: this.seqLen },
|
|
1869
|
+
(_, i) => dPooled.map((v) => v * weights[i] / totalWeight)
|
|
1870
|
+
);
|
|
1871
|
+
}
|
|
1872
|
+
}
|
|
1873
|
+
}
|
|
1266
1874
|
};
|
|
1267
1875
|
|
|
1268
1876
|
// src/losses.ts
|
|
@@ -1287,14 +1895,803 @@ function crossEntropyDeltaRaw(predicted, actual) {
|
|
|
1287
1895
|
const p = Math.max(eps, Math.min(1 - eps, predicted));
|
|
1288
1896
|
return actual / p - (1 - actual) / (1 - p);
|
|
1289
1897
|
}
|
|
1898
|
+
|
|
1899
|
+
// src/GRU.ts
|
|
1900
|
+
/**
 * Logistic sigmoid 1 / (1 + e^-x), used for the GRU reset and update gates.
 * Saturates to exactly 0 / 1 for large |x| (e^-x overflows to Infinity or
 * underflows to 0) without producing NaN.
 */
function sigmoid4(x) {
  const expNeg = Math.exp(-x);
  return 1 / (1 + expNeg);
}
|
|
1903
|
+
/**
 * Hyperbolic tangent, used for the GRU candidate activation.
 * Delegates to Math.tanh. The former (e^{2x} - 1) / (e^{2x} + 1) form
 * overflowed to Infinity / Infinity = NaN for x above about 355, and also
 * lost precision near 0.
 */
function tanhFn(x) {
  return Math.tanh(x);
}
|
|
1907
|
+
/**
 * One GRU gate: an affine map from a concatenated [input, hidden] vector to
 * hSize pre-activations. Weights are drawn uniformly from ±sqrt(2 / (inputSize + 2·hSize)).
 */
var Gate2 = class {
  /**
   * @param {number} inputSize - length of the input part of the concatenated vector.
   * @param {number} hSize - hidden size (number of outputs).
   * @param {number} [initBias=0] - initial value of every bias entry.
   */
  constructor(inputSize, hSize, initBias = 0) {
    const fanIn = inputSize + hSize;
    const bound = Math.sqrt(2 / (fanIn + hSize));
    const randomWeight = () => (Math.random() * 2 - 1) * bound;
    this.W = Array.from({ length: hSize }, () => Array.from({ length: fanIn }, randomWeight));
    this.b = new Array(hSize).fill(initBias);
  }
  /**
   * @param {number[]} combined - vector of length inputSize + hSize.
   * @returns {number[]} W · combined + b.
   */
  linear(combined) {
    const out = [];
    for (let i = 0; i < this.W.length; i++) {
      const row = this.W[i];
      let acc = this.b[i];
      for (let j = 0; j < row.length; j++) acc += row[j] * combined[j];
      out.push(acc);
    }
    return out;
  }
};
|
|
1923
|
+
var GRULayer = class {
|
|
1924
|
+
constructor(inputSize, hiddenSize, optimizerFactory = () => new SGD()) {
|
|
1925
|
+
this._traj = [];
|
|
1926
|
+
if (inputSize <= 0 || hiddenSize <= 0) {
|
|
1927
|
+
throw new Error(`GRULayer: inputSize and hiddenSize must be positive, got ${inputSize} and ${hiddenSize}`);
|
|
1928
|
+
}
|
|
1929
|
+
this.inputSize = inputSize;
|
|
1930
|
+
this.hSize = hiddenSize;
|
|
1931
|
+
this.h = new Array(hiddenSize).fill(0);
|
|
1932
|
+
this.resetGate = new Gate2(inputSize, hiddenSize);
|
|
1933
|
+
this.updateGate = new Gate2(inputSize, hiddenSize);
|
|
1934
|
+
this.newGate = new Gate2(inputSize, hiddenSize);
|
|
1935
|
+
const combSize = inputSize + hiddenSize;
|
|
1936
|
+
this._optimizers = {
|
|
1937
|
+
resetW: Array.from(
|
|
1938
|
+
{ length: hiddenSize },
|
|
1939
|
+
() => Array.from({ length: combSize }, () => optimizerFactory())
|
|
1940
|
+
),
|
|
1941
|
+
resetB: Array.from({ length: hiddenSize }, () => optimizerFactory()),
|
|
1942
|
+
updateW: Array.from(
|
|
1943
|
+
{ length: hiddenSize },
|
|
1944
|
+
() => Array.from({ length: combSize }, () => optimizerFactory())
|
|
1945
|
+
),
|
|
1946
|
+
updateB: Array.from({ length: hiddenSize }, () => optimizerFactory()),
|
|
1947
|
+
newW: Array.from(
|
|
1948
|
+
{ length: hiddenSize },
|
|
1949
|
+
() => Array.from({ length: combSize }, () => optimizerFactory())
|
|
1950
|
+
),
|
|
1951
|
+
newB: Array.from({ length: hiddenSize }, () => optimizerFactory())
|
|
1952
|
+
};
|
|
1953
|
+
}
|
|
1954
|
+
reset() {
|
|
1955
|
+
this.h = new Array(this.hSize).fill(0);
|
|
1956
|
+
this._traj = [];
|
|
1957
|
+
}
|
|
1958
|
+
predict(inputs) {
|
|
1959
|
+
if (!Array.isArray(inputs) || inputs.length !== this.inputSize) {
|
|
1960
|
+
throw new Error(`GRULayer.predict: expected array of length ${this.inputSize}, got ${inputs?.length}`);
|
|
1961
|
+
}
|
|
1962
|
+
const combined = [...inputs, ...this.h];
|
|
1963
|
+
const h_prev = [...this.h];
|
|
1964
|
+
const r_pre = this.resetGate.linear(combined);
|
|
1965
|
+
const z_pre = this.updateGate.linear(combined);
|
|
1966
|
+
const r_a = r_pre.map(sigmoid4);
|
|
1967
|
+
const z_a = z_pre.map(sigmoid4);
|
|
1968
|
+
const combined_r = [...inputs, ...r_a.map((r, i) => r * h_prev[i])];
|
|
1969
|
+
const n_pre = this.newGate.linear(combined_r);
|
|
1970
|
+
const n_a = n_pre.map(tanhFn);
|
|
1971
|
+
const h = n_a.map((n, i) => (1 - z_a[i]) * n + z_a[i] * h_prev[i]);
|
|
1972
|
+
this._traj.push({ combined, h_prev, r: r_pre, r_a, z: z_pre, z_a, combined_r, n_pre, n_a, h });
|
|
1973
|
+
this.h = h;
|
|
1974
|
+
return h;
|
|
1975
|
+
}
|
|
1976
|
+
backprop(dh_seq, lr) {
|
|
1977
|
+
const T = this._traj.length;
|
|
1978
|
+
if (T === 0 || dh_seq.length !== T) return;
|
|
1979
|
+
const hSize = this.hSize;
|
|
1980
|
+
const combSize = this.inputSize + hSize;
|
|
1981
|
+
const dWr = Array.from({ length: hSize }, () => new Array(combSize).fill(0));
|
|
1982
|
+
const dWz = Array.from({ length: hSize }, () => new Array(combSize).fill(0));
|
|
1983
|
+
const dWn = Array.from({ length: hSize }, () => new Array(combSize).fill(0));
|
|
1984
|
+
const dbr = new Array(hSize).fill(0);
|
|
1985
|
+
const dbz = new Array(hSize).fill(0);
|
|
1986
|
+
const dbn = new Array(hSize).fill(0);
|
|
1987
|
+
let dh_next = new Array(hSize).fill(0);
|
|
1988
|
+
for (let t = T - 1; t >= 0; t--) {
|
|
1989
|
+
const s = this._traj[t];
|
|
1990
|
+
const dh = dh_seq[t].map((d, i) => d + dh_next[i]);
|
|
1991
|
+
const dz_a = dh.map((d, i) => (s.h_prev[i] - s.n_a[i]) * d);
|
|
1992
|
+
const dn_a = dh.map((d, i) => (1 - s.z_a[i]) * d);
|
|
1993
|
+
const dn_pre = dn_a.map((d, i) => d * (1 - s.n_a[i] ** 2));
|
|
1994
|
+
const dz_pre = dz_a.map((d, i) => d * s.z_a[i] * (1 - s.z_a[i]));
|
|
1995
|
+
const dr_hprev = Array.from(
|
|
1996
|
+
{ length: hSize },
|
|
1997
|
+
(_, i) => this.newGate.W.reduce((sum, row, k) => sum + dn_pre[k] * row[this.inputSize + i], 0)
|
|
1998
|
+
);
|
|
1999
|
+
const dr_a = dr_hprev.map((d, i) => d * s.h_prev[i]);
|
|
2000
|
+
const dr_pre = dr_a.map((d, i) => d * s.r_a[i] * (1 - s.r_a[i]));
|
|
2001
|
+
for (let k = 0; k < hSize; k++) {
|
|
2002
|
+
for (let j = 0; j < combSize; j++) {
|
|
2003
|
+
dWr[k][j] += dr_pre[k] * s.combined[j];
|
|
2004
|
+
dWz[k][j] += dz_pre[k] * s.combined[j];
|
|
2005
|
+
dWn[k][j] += dn_pre[k] * s.combined_r[j];
|
|
2006
|
+
}
|
|
2007
|
+
dbr[k] += dr_pre[k];
|
|
2008
|
+
dbz[k] += dz_pre[k];
|
|
2009
|
+
dbn[k] += dn_pre[k];
|
|
2010
|
+
}
|
|
2011
|
+
dh_next = new Array(hSize).fill(0);
|
|
2012
|
+
for (let k = 0; k < hSize; k++) {
|
|
2013
|
+
for (let j = this.inputSize; j < combSize; j++) {
|
|
2014
|
+
dh_next[j - this.inputSize] += dr_pre[k] * this.resetGate.W[k][j] + dz_pre[k] * this.updateGate.W[k][j];
|
|
2015
|
+
}
|
|
2016
|
+
dh_next[k] += dr_hprev[k] * s.r_a[k];
|
|
2017
|
+
dh_next[k] += dh[k] * s.z_a[k];
|
|
2018
|
+
}
|
|
2019
|
+
}
|
|
2020
|
+
const scale = lr / T;
|
|
2021
|
+
for (let k = 0; k < hSize; k++) {
|
|
2022
|
+
for (let j = 0; j < combSize; j++) {
|
|
2023
|
+
this.resetGate.W[k][j] = this._optimizers.resetW[k][j].step(this.resetGate.W[k][j], dWr[k][j], scale);
|
|
2024
|
+
this.updateGate.W[k][j] = this._optimizers.updateW[k][j].step(this.updateGate.W[k][j], dWz[k][j], scale);
|
|
2025
|
+
this.newGate.W[k][j] = this._optimizers.newW[k][j].step(this.newGate.W[k][j], dWn[k][j], scale);
|
|
2026
|
+
}
|
|
2027
|
+
this.resetGate.b[k] = this._optimizers.resetB[k].step(this.resetGate.b[k], dbr[k], scale);
|
|
2028
|
+
this.updateGate.b[k] = this._optimizers.updateB[k].step(this.updateGate.b[k], dbz[k], scale);
|
|
2029
|
+
this.newGate.b[k] = this._optimizers.newB[k].step(this.newGate.b[k], dbn[k], scale);
|
|
2030
|
+
}
|
|
2031
|
+
this._traj = [];
|
|
2032
|
+
}
|
|
2033
|
+
// ── Flat weight serialization ─────────────────────────────────────────────
|
|
2034
|
+
// Order: resetGate (W, b), updateGate (W, b), newGate (W, b).
|
|
2035
|
+
getWeightsFlat() {
|
|
2036
|
+
const w = [];
|
|
2037
|
+
for (const row of this.resetGate.W) w.push(...row);
|
|
2038
|
+
w.push(...this.resetGate.b);
|
|
2039
|
+
for (const row of this.updateGate.W) w.push(...row);
|
|
2040
|
+
w.push(...this.updateGate.b);
|
|
2041
|
+
for (const row of this.newGate.W) w.push(...row);
|
|
2042
|
+
w.push(...this.newGate.b);
|
|
2043
|
+
return w;
|
|
2044
|
+
}
|
|
2045
|
+
setWeightsFlat(weights) {
|
|
2046
|
+
let idx = 0;
|
|
2047
|
+
for (let i = 0; i < this.resetGate.W.length; i++)
|
|
2048
|
+
for (let j = 0; j < this.resetGate.W[i].length; j++) this.resetGate.W[i][j] = weights[idx++];
|
|
2049
|
+
for (let i = 0; i < this.resetGate.b.length; i++) this.resetGate.b[i] = weights[idx++];
|
|
2050
|
+
for (let i = 0; i < this.updateGate.W.length; i++)
|
|
2051
|
+
for (let j = 0; j < this.updateGate.W[i].length; j++) this.updateGate.W[i][j] = weights[idx++];
|
|
2052
|
+
for (let i = 0; i < this.updateGate.b.length; i++) this.updateGate.b[i] = weights[idx++];
|
|
2053
|
+
for (let i = 0; i < this.newGate.W.length; i++)
|
|
2054
|
+
for (let j = 0; j < this.newGate.W[i].length; j++) this.newGate.W[i][j] = weights[idx++];
|
|
2055
|
+
for (let i = 0; i < this.newGate.b.length; i++) this.newGate.b[i] = weights[idx++];
|
|
2056
|
+
}
|
|
2057
|
+
getWeights() {
|
|
2058
|
+
return {
|
|
2059
|
+
resetGate: { W: this.resetGate.W, b: this.resetGate.b },
|
|
2060
|
+
updateGate: { W: this.updateGate.W, b: this.updateGate.b },
|
|
2061
|
+
newGate: { W: this.newGate.W, b: this.newGate.b }
|
|
2062
|
+
};
|
|
2063
|
+
}
|
|
2064
|
+
setWeights(data) {
|
|
2065
|
+
this.resetGate.W = data.resetGate.W;
|
|
2066
|
+
this.resetGate.b = data.resetGate.b;
|
|
2067
|
+
this.updateGate.W = data.updateGate.W;
|
|
2068
|
+
this.updateGate.b = data.updateGate.b;
|
|
2069
|
+
this.newGate.W = data.newGate.W;
|
|
2070
|
+
this.newGate.b = data.newGate.b;
|
|
2071
|
+
}
|
|
2072
|
+
};
|
|
2073
|
+
|
|
2074
|
+
// src/BatchNorm.ts
|
|
2075
|
+
var BatchNorm = class {
  /**
   * Per-feature normalization of a single vector, with a learnable scale
   * (`gamma`) and shift (`beta`). The mean and variance are exponential
   * running averages updated on every forward() call. There is no separate
   * train/eval mode.
   *
   * @param {number} dim - length of the vectors passed to forward().
   * @param {number} [momentum=0.1] - weight kept by the PREVIOUS running
   *   statistic on each update:
   *   stat = momentum * stat + (1 - momentum) * sample.
   *   NOTE: this is the reverse of PyTorch's BatchNorm convention, so with the
   *   default of 0.1 the newest sample dominates the running statistics.
   */
  constructor(dim, momentum = 0.1) {
    this._xNorm = null;
    this._std = null;
    this.dim = dim;
    this.momentum = momentum;
    this.gamma = new Array(dim).fill(1);
    this.beta = new Array(dim).fill(0);
    this.runningMean = new Array(dim).fill(0);
    this.runningVar = new Array(dim).fill(1);
  }

  // ── Forward ───────────────────────────────────────────────────────────────
  /**
   * Updates the running statistics with `x`, then returns
   * gamma * (x - runningMean) / sqrt(runningVar + 1e-5) + beta.
   * The normalized input and std are cached for backward()/trainParams().
   *
   * @param {number[]} x - vector of length `dim`.
   * @returns {number[]}
   * @throws {Error} if x.length !== dim.
   */
  forward(x) {
    if (x.length !== this.dim) {
      throw new Error(`BatchNorm.forward: expected array of length ${this.dim}, got ${x.length}`);
    }
    const eps = 1e-5;
    for (let i = 0; i < this.dim; i++) {
      this.runningMean[i] = this.momentum * this.runningMean[i] + (1 - this.momentum) * x[i];
      // Deviation is measured from the freshly updated mean.
      const diff = x[i] - this.runningMean[i];
      this.runningVar[i] = this.momentum * this.runningVar[i] + (1 - this.momentum) * diff * diff;
    }
    this._std = this.runningVar.map((v) => Math.sqrt(v + eps));
    this._xNorm = x.map((v, i) => (v - this.runningMean[i]) / this._std[i]);
    return this._xNorm.map((xn, i) => this.gamma[i] * xn + this.beta[i]);
  }

  // ── Backward ──────────────────────────────────────────────────────────────
  /**
   * Gradient w.r.t. the input, treating the running mean/std as constants
   * (their dependence on x is not differentiated): dOut * gamma / std.
   *
   * @param {number[]} dOut - upstream gradient of length `dim`.
   * @returns {number[]}
   * @throws {Error} if forward() has not been called, or dOut has the wrong length.
   */
  backward(dOut) {
    if (!this._xNorm || !this._std) {
      throw new Error("BatchNorm.backward: call forward() first");
    }
    if (dOut.length !== this.dim) {
      throw new Error(`BatchNorm.backward: expected array of length ${this.dim}, got ${dOut.length}`);
    }
    return dOut.map((d, i) => d * this.gamma[i] / this._std[i]);
  }

  // ── Train gamma and beta (call after backward) ────────────────────────────
  /**
   * Updates gamma and beta from the last forward() pass.
   * The update ADDS lr * signal, so `dOut` is presumably an error signal in the
   * "negative gradient" convention. NOTE(review): confirm with callers.
   * Does nothing if forward() has not been called yet.
   *
   * @param {number[]} dOut - signal of length `dim`.
   * @param {number} lr - learning rate.
   * @throws {Error} if dOut has the wrong length.
   */
  trainParams(dOut, lr) {
    if (!this._xNorm) return;
    if (dOut.length !== this.dim) {
      throw new Error(`BatchNorm.trainParams: expected array of length ${this.dim}, got ${dOut.length}`);
    }
    for (let i = 0; i < this.dim; i++) {
      this.gamma[i] += lr * dOut[i] * this._xNorm[i];
      this.beta[i] += lr * dOut[i];
    }
  }

  // ── Flat weight serialization ─────────────────────────────────────────────
  /**
   * Trainable parameters as a fresh flat array.
   * Layout: gamma, then beta. Running statistics are not included.
   * @returns {number[]}
   */
  getWeights() {
    return [...this.gamma, ...this.beta];
  }

  /**
   * Loads gamma and beta from a flat array laid out like getWeights().
   * @param {number[]} weights - at least 2 * dim values.
   * @throws {Error} if `weights` is too short. Previously a short array left
   *   `undefined` in gamma/beta.
   */
  setWeights(weights) {
    const expected = 2 * this.dim;
    if (weights.length < expected) {
      throw new Error(`BatchNorm.setWeights: expected ${expected} values, got ${weights.length}`);
    }
    for (let i = 0; i < this.dim; i++) this.gamma[i] = weights[i];
    for (let i = 0; i < this.dim; i++) this.beta[i] = weights[this.dim + i];
  }
};
|
|
2128
|
+
|
|
2129
|
+
// src/Conv1D.ts
|
|
2130
|
+
var Conv1D = class {
  /**
   * 1-D convolution over `inputLength` positions with `inputChannels` values
   * each, producing `filters` output channels.
   *
   * @param {number} inputLength - number of positions in the input sequence.
   * @param {number} kernelSize - window width.
   * @param {number} filters - number of kernels (output channels).
   * @param {number} [stride=1] - step between windows; must be a positive integer.
   * @param {string} [padding="valid"] - "same" zero-pads so the output has
   *   ceil(inputLength / stride) positions. Any other value means no padding.
   * @param {() => {step(param: number, grad: number, lr: number): number}} [optimizerFactory]
   *   - called once per kernel weight and once per bias.
   * @param {number} [inputChannels=1]
   * @throws {Error} on non-positive sizes, an invalid stride, or a kernel wider
   *   than the input with "valid" padding.
   */
  constructor(inputLength, kernelSize, filters, stride = 1, padding = "valid", optimizerFactory = () => new SGD(), inputChannels = 1) {
    this._input = null;
    this._paddedInput = null;
    if (inputLength <= 0 || kernelSize <= 0 || filters <= 0) {
      throw new Error("Conv1D: inputLength, kernelSize, and filters must be positive");
    }
    if (kernelSize > inputLength && padding === "valid") {
      throw new Error("Conv1D: kernelSize cannot exceed inputLength with valid padding");
    }
    if (inputChannels < 1) {
      throw new Error("Conv1D: inputChannels must be >= 1");
    }
    // A zero/negative stride gives an unbounded output length, and a fractional
    // one produces non-integer window offsets (NaN outputs).
    if (!Number.isInteger(stride) || stride < 1) {
      throw new Error(`Conv1D: stride must be a positive integer, got ${stride}`);
    }
    this.inputLength = inputLength;
    this.kernelSize = kernelSize;
    this.filters = filters;
    this.stride = stride;
    this.padding = padding;
    this.inputChannels = inputChannels;
    // Uniform init in [-limit, limit], scaled by the fan-in (kernelSize * inputChannels).
    const limit = Math.sqrt(2 / (kernelSize * inputChannels));
    // kernels[filter][offset][channel]
    this.kernels = Array.from(
      { length: filters },
      () => Array.from(
        { length: kernelSize },
        () => Array.from({ length: inputChannels }, () => (Math.random() * 2 - 1) * limit)
      )
    );
    this.biases = new Array(filters).fill(0);
    this._kOpts = Array.from(
      { length: filters },
      () => Array.from(
        { length: kernelSize },
        () => Array.from({ length: inputChannels }, () => optimizerFactory())
      )
    );
    this._bOpts = Array.from({ length: filters }, () => optimizerFactory());
  }

  // ── Forward ───────────────────────────────────────────────────────────────
  /**
   * @param {number[]|number[][]} input - number[] when inputChannels === 1,
   *   otherwise number[inputLength][inputChannels].
   * @returns {number[][]} output[filter][position].
   */
  forward(input) {
    const input2D = this._normalizeInput(input);
    // Private copy, so later mutation of the caller's arrays cannot change
    // what backward() differentiates against.
    this._input = input2D.map((row) => [...row]);
    const [padLeft, padRight] = this._padAmounts();
    const zeroRows = (count) => Array.from({ length: count }, () => new Array(this.inputChannels).fill(0));
    const padded = [...zeroRows(padLeft), ...this._input, ...zeroRows(padRight)];
    this._paddedInput = padded;
    const outputLength = Math.floor((padded.length - this.kernelSize) / this.stride) + 1;
    const output = Array.from(
      { length: this.filters },
      () => new Array(outputLength).fill(0)
    );
    for (let f = 0; f < this.filters; f++) {
      for (let pos = 0; pos < outputLength; pos++) {
        const start = pos * this.stride;
        let sum = this.biases[f];
        for (let k = 0; k < this.kernelSize; k++) {
          for (let c = 0; c < this.inputChannels; c++) {
            sum += this.kernels[f][k][c] * padded[start + k][c];
          }
        }
        output[f][pos] = sum;
      }
    }
    return output;
  }

  // ── Backward ──────────────────────────────────────────────────────────────
  /**
   * Accumulates kernel/bias gradients over all output positions and applies
   * one optimizer step per parameter: param = opt.step(param, grad, lr).
   *
   * @param {number[][]} dOut - upstream gradient, dOut[filter][position].
   * @param {number} [lr=1e-3]
   * @returns {number[][]} gradient w.r.t. the input as
   *   number[inputLength][inputChannels] (2-D even for 1-D input).
   * @throws {Error} if forward() has not been called.
   */
  backward(dOut, lr = 1e-3) {
    if (!this._paddedInput || !this._input) {
      throw new Error("Conv1D.backward: call forward() first");
    }
    const padded = this._paddedInput;
    const outputLength = dOut[0].length;
    const dKernels = Array.from(
      { length: this.filters },
      () => Array.from(
        { length: this.kernelSize },
        () => new Array(this.inputChannels).fill(0)
      )
    );
    const dBiases = new Array(this.filters).fill(0);
    const dPadded = padded.map(() => new Array(this.inputChannels).fill(0));
    for (let f = 0; f < this.filters; f++) {
      for (let pos = 0; pos < outputLength; pos++) {
        const start = pos * this.stride;
        dBiases[f] += dOut[f][pos];
        for (let k = 0; k < this.kernelSize; k++) {
          for (let c = 0; c < this.inputChannels; c++) {
            dKernels[f][k][c] += dOut[f][pos] * padded[start + k][c];
            dPadded[start + k][c] += dOut[f][pos] * this.kernels[f][k][c];
          }
        }
      }
    }
    for (let f = 0; f < this.filters; f++) {
      for (let k = 0; k < this.kernelSize; k++) {
        for (let c = 0; c < this.inputChannels; c++) {
          this.kernels[f][k][c] = this._kOpts[f][k][c].step(this.kernels[f][k][c], dKernels[f][k][c], lr);
        }
      }
      this.biases[f] = this._bOpts[f].step(this.biases[f], dBiases[f], lr);
    }
    // Drop the gradient that fell on the zero padding.
    const [padLeft] = this._padAmounts();
    return dPadded.slice(padLeft, padLeft + this.inputLength);
  }

  // ── Output length ─────────────────────────────────────────────────────────
  /** Number of positions per filter that forward() produces. */
  getOutputLength() {
    if (this.padding === "same") {
      return Math.ceil(this.inputLength / this.stride);
    }
    return Math.floor((this.inputLength - this.kernelSize) / this.stride) + 1;
  }

  // ── Flat weight serialization ─────────────────────────────────────────────
  /**
   * Flat copy of the parameters.
   * Layout: kernels in [filter][offset][channel] order, then biases.
   * @returns {number[]}
   */
  getWeights() {
    const w = [];
    for (const kernel of this.kernels)
      for (const k of kernel)
        for (const c of k)
          w.push(c);
    w.push(...this.biases);
    return w;
  }

  /**
   * Loads parameters from a flat array laid out like getWeights().
   * @param {number[]} weights
   * @throws {Error} if `weights` is too short. Previously missing values
   *   silently became `undefined`.
   */
  setWeights(weights) {
    const expected = this.filters * this.kernelSize * this.inputChannels + this.filters;
    if (weights.length < expected) {
      throw new Error(`Conv1D.setWeights: expected ${expected} weights, got ${weights.length}`);
    }
    let idx = 0;
    for (let f = 0; f < this.filters; f++)
      for (let k = 0; k < this.kernelSize; k++)
        for (let c = 0; c < this.inputChannels; c++)
          this.kernels[f][k][c] = weights[idx++];
    for (let f = 0; f < this.filters; f++)
      this.biases[f] = weights[idx++];
  }

  // ── Padding ───────────────────────────────────────────────────────────────
  /**
   * Zero rows to add [before, after] the input. "same" pads kernelSize - 1
   * rows in total, with the extra row on the right for even kernels. This
   * keeps forward()'s output length equal to getOutputLength().
   */
  _padAmounts() {
    if (this.padding !== "same") return [0, 0];
    const left = Math.floor((this.kernelSize - 1) / 2);
    return [left, this.kernelSize - 1 - left];
  }

  // ── Normalize input to 2D format ─────────────────────────────────────────
  /**
   * Validates `input` and returns it as number[inputLength][inputChannels].
   * 1-D input is only accepted when inputChannels === 1.
   */
  _normalizeInput(input) {
    if (input.length === 0) {
      throw new Error("Conv1D.forward: input cannot be empty");
    }
    if (typeof input[0] === "number") {
      if (this.inputChannels !== 1) {
        throw new Error(`Conv1D.forward: expected 2D input with ${this.inputChannels} channels, got 1D`);
      }
      const input1D = input;
      if (input1D.length !== this.inputLength) {
        throw new Error(`Conv1D.forward: expected input of length ${this.inputLength}, got ${input1D.length}`);
      }
      return input1D.map((v) => [v]);
    }
    const input2D = input;
    if (input2D.length !== this.inputLength) {
      throw new Error(`Conv1D.forward: expected input of length ${this.inputLength}, got ${input2D.length}`);
    }
    for (let i = 0; i < input2D.length; i++) {
      if (input2D[i].length !== this.inputChannels) {
        throw new Error(`Conv1D.forward: expected ${this.inputChannels} channels at position ${i}, got ${input2D[i].length}`);
      }
    }
    return input2D;
  }
};
|
|
2297
|
+
|
|
2298
|
+
// src/Trainer.ts
|
|
2299
|
+
var Trainer = class {
  /**
   * Epoch-based trainer. Each epoch calls `network.train(input, target, lr)`
   * once per sample, in a freshly shuffled order.
   *
   * @param {object} network - must expose `train(input, target, lr)` returning a
   *   per-sample loss. Weight decay, early stopping and metrics additionally
   *   require `getWeights()`, `setWeights(w)` and `predict(input)`.
   * @param {object} [options]
   * @param {number} [options.epochs=1000]
   * @param {number} [options.lr=0.1] - initial learning rate.
   * @param {number} [options.lrDecay=1] - multiplied into lr after every epoch.
   * @param {boolean} [options.verbose=false] - log progress every 100 epochs.
   * @param {number} [options.weightDecay=0] - before every sample, all weights
   *   are multiplied by (1 - lr * weightDecay).
   * @param {{patience: number, minDelta?: number}} [options.earlyStopping] -
   *   stop when the validation MSE has not improved by more than `minDelta`
   *   (default 0) for `patience` epochs. Needs setValidationData().
   * @param {boolean} [options.computeMetrics=false] - per-epoch classification
   *   metrics on the training set.
   * @param {number} [options.clipValue=0] - stored only; this class never reads it.
   */
  constructor(network, options = {}) {
    this._history = [];
    this._bestLoss = Infinity;
    this._patienceCounter = 0;
    this._stopReason = "maxEpochs";
    this._metrics = [];
    this.network = network;
    this.epochs = options.epochs ?? 1e3;
    this.lrInitial = options.lr ?? 0.1;
    this.lrDecay = options.lrDecay ?? 1;
    this.verbose = options.verbose ?? false;
    this.weightDecay = options.weightDecay ?? 0;
    this._earlyStopping = options.earlyStopping;
    this._computeMetrics = options.computeMetrics ?? false;
    this.clipValue = options.clipValue ?? 0;
  }

  // ── Set external validation data (for early stopping) ────────────────────
  setValidationData(dataset) {
    if (dataset.inputs.length !== dataset.targets.length) {
      throw new Error(
        "Trainer.setValidationData: inputs and targets must have the same length"
      );
    }
    this._validationData = dataset;
  }

  // ── Get best validation loss during training ─────────────────────────────
  /** Best validation loss seen by the last train() call, or -1 if none was computed. */
  getBestLoss() {
    return this._bestLoss === Infinity ? -1 : this._bestLoss;
  }

  // ── Why did training stop? ───────────────────────────────────────────────
  /** "maxEpochs" or "earlyStopping". */
  getStopReason() {
    return this._stopReason;
  }

  // ── Get per-epoch classification metrics ─────────────────────────────────
  getMetrics() {
    return [...this._metrics];
  }

  // ── Train on dataset ──────────────────────────────────────────────────────
  /**
   * Runs up to `epochs` epochs over `dataset`.
   *
   * @param {{inputs: number[][], targets: number[][]}} dataset
   * @returns {number[]} mean training loss per completed epoch.
   * @throws {Error} if inputs/targets differ in length or the dataset is empty.
   */
  train(dataset) {
    const { inputs, targets } = dataset;
    if (inputs.length !== targets.length) {
      throw new Error(
        "Trainer.train: inputs and targets must have the same length"
      );
    }
    // An empty dataset would make every epoch loss 0 / 0 = NaN.
    if (inputs.length === 0) {
      throw new Error("Trainer.train: dataset must contain at least one sample");
    }
    const n = inputs.length;
    let lr = this.lrInitial;
    this._history = [];
    this._bestLoss = Infinity;
    this._patienceCounter = 0;
    this._stopReason = "maxEpochs";
    this._metrics = [];
    const netExt = this._hasWeights(this.network);
    if (this.weightDecay > 0 && !netExt) {
      console.warn(
        "Trainer: weightDecay requires a network with getWeights/setWeights/predict. Skipping weight decay."
      );
    }
    if (this._earlyStopping && !netExt) {
      console.warn(
        "Trainer: earlyStopping requires a network with predict(). Skipping early stopping."
      );
    }
    if (this._computeMetrics && !netExt) {
      console.warn(
        "Trainer: computeMetrics requires a network with predict(). Skipping metrics."
      );
    }
    const hasValData = this._hasValidationData();
    if (this._earlyStopping && netExt && !hasValData) {
      console.warn(
        "Trainer: earlyStopping requires non-empty validation data (see setValidationData). Skipping early stopping."
      );
    }
    const canDecay = this.weightDecay > 0 && netExt !== null;
    const canValidate = !!this._earlyStopping && netExt !== null && hasValData;
    const canMetric = this._computeMetrics && netExt !== null;
    const isClass = canMetric && this._isClassification(targets);
    if (canMetric && !isClass) {
      console.warn(
        "Trainer: computeMetrics is set but targets do not appear to be one-hot or single-class. Metrics will be skipped."
      );
    }
    // Without this default, an omitted minDelta made `bestLoss - undefined` NaN,
    // so the best loss never updated and training always stopped after `patience` epochs.
    const minDelta = this._earlyStopping?.minDelta ?? 0;
    for (let epoch = 0; epoch < this.epochs; epoch++) {
      const indices = this._shuffledIndices(n);
      let epochLoss = 0;
      for (const i of indices) {
        if (canDecay) this._applyWeightDecay(netExt, lr);
        epochLoss += this.network.train(inputs[i], targets[i], lr);
      }
      epochLoss /= n;
      this._history.push(epochLoss);
      if (isClass) {
        this._metrics.push(this._computeMetricsArray(netExt, inputs, targets));
      }
      if (canValidate) {
        const valLoss = this._computeLoss(netExt, this._validationData);
        if (valLoss < this._bestLoss - minDelta) {
          this._bestLoss = valLoss;
          this._patienceCounter = 0;
        } else {
          this._patienceCounter++;
        }
        if (this._patienceCounter >= this._earlyStopping.patience) {
          this._stopReason = "earlyStopping";
          break;
        }
      }
      lr *= this.lrDecay;
      if (this.verbose && (epoch + 1) % 100 === 0) {
        // `lr` here is already the rate for the next epoch.
        console.log(
          `Epoch ${epoch + 1}/${this.epochs}, loss: ${epochLoss.toFixed(6)}, lr: ${lr.toFixed(6)}`
        );
      }
    }
    return this._history;
  }

  // ── Get loss history ──────────────────────────────────────────────────────
  getHistory() {
    return [...this._history];
  }

  // ── Private helpers ───────────────────────────────────────────────────────
  /** Type guard: does this network support getWeights/setWeights/predict? */
  _hasWeights(network) {
    if ("getWeights" in network && "setWeights" in network && "predict" in network && typeof network.getWeights === "function" && typeof network.setWeights === "function" && typeof network.predict === "function") {
      return network;
    }
    return null;
  }

  /** True when setValidationData() supplied at least one sample. */
  _hasValidationData() {
    return !!this._validationData && this._validationData.inputs.length > 0;
  }

  /** Indices 0..n-1 in random order (Fisher–Yates). */
  _shuffledIndices(n) {
    const indices = Array.from({ length: n }, (_, i) => i);
    for (let i = n - 1; i > 0; i--) {
      const j = Math.floor(Math.random() * (i + 1));
      [indices[i], indices[j]] = [indices[j], indices[i]];
    }
    return indices;
  }

  /** Multiplies every weight by (1 - lr * weightDecay). */
  _applyWeightDecay(net, lr) {
    const w = net.getWeights();
    for (let j = 0; j < w.length; j++) {
      w[j] *= 1 - lr * this.weightDecay;
    }
    net.setWeights(w);
  }

  /** Mean squared error on a dataset (used for validation loss). */
  _computeLoss(net, data) {
    let totalLoss = 0;
    for (let i = 0; i < data.inputs.length; i++) {
      const pred = net.predict(data.inputs[i]);
      const target = data.targets[i];
      let sampleLoss = 0;
      for (let j = 0; j < pred.length; j++) {
        sampleLoss += (target[j] - pred[j]) ** 2;
      }
      totalLoss += sampleLoss / pred.length;
    }
    return totalLoss / data.inputs.length;
  }

  /**
   * Heuristic: are targets classification-style?
   * - Single output: every target is a 0/1 label (±0.01). Previously any
   *   single-output target, including regression values, counted.
   * - Multi output: every target is one-hot, i.e. each value is 0/1 (±0.01),
   *   the values sum to 1 (±0.01), and all targets have the same width.
   */
  _isClassification(targets) {
    if (targets.length === 0) return false;
    const width = targets[0].length;
    const isLabelValue = (v) => v >= -0.01 && v <= 0.01 || v >= 0.99 && v <= 1.01;
    if (width === 1) {
      return targets.every((t) => t.length === 1 && isLabelValue(t[0]));
    }
    return targets.every((t) => {
      if (t.length !== width || !t.every(isLabelValue)) return false;
      const sum = t.reduce((a, b) => a + b, 0);
      return Math.abs(sum - 1) <= 0.01;
    });
  }

  /** Compute classification metrics from predictions vs targets. */
  _computeMetricsArray(net, inputs, targets) {
    const targetLen = targets[0].length;
    const nClasses = targetLen === 1 ? 2 : targetLen;
    const confusion = Array.from(
      { length: nClasses },
      () => Array(nClasses).fill(0)
    );
    for (let i = 0; i < inputs.length; i++) {
      const pred = net.predict(inputs[i]);
      const target = targets[i];
      let predClass;
      let trueClass;
      if (targetLen === 1) {
        trueClass = target[0] >= 0.5 ? 1 : 0;
        if (pred.length === 1) {
          predClass = pred[0] >= 0.5 ? 1 : 0;
        } else {
          predClass = pred.indexOf(Math.max(...pred));
        }
      } else {
        predClass = pred.indexOf(Math.max(...pred));
        trueClass = target.indexOf(Math.max(...target));
      }
      predClass = Math.max(0, Math.min(nClasses - 1, predClass));
      trueClass = Math.max(0, Math.min(nClasses - 1, trueClass));
      confusion[trueClass][predClass]++;
    }
    let totalCorrect = 0;
    let totalSamples = 0;
    const precisions = [];
    const recalls = [];
    for (let c = 0; c < nClasses; c++) {
      const tp = confusion[c][c];
      totalCorrect += tp;
      let colSum = 0;
      let rowSum = 0;
      for (let r = 0; r < nClasses; r++) {
        colSum += confusion[r][c];
        rowSum += confusion[c][r];
      }
      totalSamples += rowSum;
      precisions.push(colSum > 0 ? tp / colSum : 0);
      recalls.push(rowSum > 0 ? tp / rowSum : 0);
    }
    const accuracy = totalSamples > 0 ? totalCorrect / totalSamples : 0;
    const macroPrecision = precisions.reduce((a, b) => a + b, 0) / nClasses;
    const macroRecall = recalls.reduce((a, b) => a + b, 0) / nClasses;
    const f1 = macroPrecision + macroRecall > 0 ? 2 * macroPrecision * macroRecall / (macroPrecision + macroRecall) : 0;
    return {
      accuracy,
      precision: macroPrecision,
      recall: macroRecall,
      f1
    };
  }
};
|
|
2521
|
+
|
|
2522
|
+
// src/DataLoader.ts
|
|
2523
|
+
var DataLoader = class _DataLoader {
  /**
   * Mini-batch iterator with an optional random train/validation split.
   *
   * @param {{inputs: any[], targets: any[]}} data
   * @param {number} [batchSize=1] - positive integer.
   * @param {number} [validationSplit=0] - fraction in [0, 1) held out for
   *   validation (rounded to the nearest whole sample).
   * @throws {Error} on mismatched inputs/targets, an out-of-range split, or a
   *   non-positive / non-integer batch size. Previously a batch size of 0 made
   *   next() never advance, so hasNext() stayed true forever.
   */
  constructor(data, batchSize = 1, validationSplit = 0) {
    if (data.inputs.length !== data.targets.length) {
      throw new Error("DataLoader: inputs and targets must have the same length");
    }
    if (validationSplit < 0 || validationSplit >= 1) {
      throw new Error(`DataLoader: validationSplit must be in [0, 1), got ${validationSplit}`);
    }
    if (!Number.isInteger(batchSize) || batchSize < 1) {
      throw new Error(`DataLoader: batchSize must be a positive integer, got ${batchSize}`);
    }
    this.data = data;
    this.batchSize = batchSize;
    this._validationSplit = validationSplit;
    const total = data.inputs.length;
    const fullIndices = _DataLoader._shuffleInPlace(Array.from({ length: total }, (_, i) => i));
    const valSize = validationSplit > 0 ? Math.round(total * validationSplit) : 0;
    const trainSize = total - valSize;
    this._trainIndices = fullIndices.slice(0, trainSize);
    this._valIndices = fullIndices.slice(trainSize);
    this._indices = [...this._trainIndices];
    this._pos = 0;
  }

  // ── Shuffle the training data ──────────────────────────────────────────────
  /** Reshuffles the training samples and restarts iteration. */
  shuffle() {
    _DataLoader._shuffleInPlace(this._trainIndices);
    this._indices = [...this._trainIndices];
    this._pos = 0;
  }

  // ── Check if more batches are available ───────────────────────────────────
  hasNext() {
    return this._pos < this._indices.length;
  }

  // ── Get next batch ────────────────────────────────────────────────────────
  /**
   * Next batch of up to `batchSize` training samples (the final batch may be
   * smaller; empty once exhausted).
   */
  next() {
    const end = Math.min(this._pos + this.batchSize, this._indices.length);
    const batchIndices = this._indices.slice(this._pos, end);
    this._pos = end;
    return {
      inputs: batchIndices.map((i) => this.data.inputs[i]),
      targets: batchIndices.map((i) => this.data.targets[i])
    };
  }

  // ── Reset iteration ───────────────────────────────────────────────────────
  /** Restarts iteration in the current order (no reshuffle). */
  reset() {
    this._pos = 0;
  }

  // ── Get total number of training samples ───────────────────────────────────
  get length() {
    return this._trainIndices.length;
  }

  // ── Get validation data as a DataPair ──────────────────────────────────────
  // Returns the validation samples (inputs + targets) in their shuffled order.
  // Returns empty arrays if no validation split was configured.
  getValidationData() {
    return {
      inputs: this._valIndices.map((i) => this.data.inputs[i]),
      targets: this._valIndices.map((i) => this.data.targets[i])
    };
  }

  // ── Get number of validation samples ───────────────────────────────────────
  get validationLength() {
    return this._valIndices.length;
  }

  // ── Create sequence windows from a time series ────────────────────────────
  /**
   * Sliding windows over a time series: each input is `seqLen` consecutive
   * steps flattened with `.flat()`, and its target is the step that follows.
   * (Each step is presumably a feature vector — confirm with callers.)
   *
   * @param {any[]} data - the series.
   * @param {number} seqLen - window length; positive integer.
   * @param {number} [validationSplit=0]
   * @param {number} [batchSize=1] - batch size of the returned loader.
   * @returns {DataLoader}
   * @throws {Error} if seqLen is not a positive integer or data is shorter than seqLen + 1.
   */
  static sequences(data, seqLen, validationSplit = 0, batchSize = 1) {
    if (!Number.isInteger(seqLen) || seqLen < 1) {
      throw new Error(`DataLoader.sequences: seqLen must be a positive integer, got ${seqLen}`);
    }
    if (data.length < seqLen + 1) {
      throw new Error("DataLoader.sequences: data length must be >= seqLen + 1");
    }
    const inputs = [];
    const targets = [];
    for (let i = 0; i <= data.length - seqLen - 1; i++) {
      inputs.push(data.slice(i, i + seqLen).flat());
      targets.push(data[i + seqLen]);
    }
    return new _DataLoader({ inputs, targets }, batchSize, validationSplit);
  }

  /** In-place Fisher–Yates shuffle; returns the same array. */
  static _shuffleInPlace(arr) {
    for (let i = arr.length - 1; i > 0; i--) {
      const j = Math.floor(Math.random() * (i + 1));
      [arr[i], arr[j]] = [arr[j], arr[i]];
    }
    return arr;
  }
};
|
|
2609
|
+
|
|
2610
|
+
// src/LRScheduler.ts
|
|
2611
|
+
var LRScheduler = class {
  /**
   * Step decay: the rate drops by a factor of `dropRate` once every
   * `epochsDrop` epochs.
   *   result = lr * dropRate ^ floor(epoch / epochsDrop)
   */
  stepDecay(lr, epoch, dropRate, epochsDrop) {
    const drops = Math.floor(epoch / epochsDrop);
    return lr * Math.pow(dropRate, drops);
  }

  /**
   * Exponential decay: the rate shrinks by `decayRate` every epoch.
   *   result = lr * decayRate ^ epoch
   */
  exponentialDecay(lr, epoch, decayRate) {
    const shrink = Math.pow(decayRate, epoch);
    return lr * shrink;
  }

  /**
   * Plateau decay: returns lr * factor when `currentLoss` is not lower than
   * the best of the last `patience` entries of `history`. Otherwise, or while
   * `history` has fewer than `patience` entries, returns `lr` unchanged.
   *
   * `history` should hold the losses of PREVIOUS epochs only. If it already
   * contains `currentLoss`, the comparison always succeeds and the rate is
   * cut every call. The rate is re-evaluated (and possibly cut again) on
   * every call; there is no cooldown.
   *
   * Example:
   *   const history = [];
   *   for (let epoch = 0; epoch < 1000; epoch++) {
   *     const loss = train(...);
   *     lr = scheduler.plateauDecay(lr, loss, history, 10, 0.5);
   *     history.push(loss);
   *   }
   */
  plateauDecay(lr, currentLoss, history, patience, factor) {
    const stalled = history.length < patience
      ? false
      : currentLoss >= Math.min(...history.slice(-patience));
    return stalled ? lr * factor : lr;
  }

  /**
   * Cosine annealing from `lr` (epoch 0) down to `minLr` (epoch maxEpochs).
   *   result = minLr + 0.5 * (lr - minLr) * (1 + cos(PI * epoch / maxEpochs))
   */
  cosineAnnealing(lr, epoch, maxEpochs, minLr = 0) {
    const phase = Math.PI * epoch / maxEpochs;
    const cosineFactor = 1 + Math.cos(phase);
    return minLr + 0.5 * (lr - minLr) * cosineFactor;
  }
};
|
|
2648
|
+
|
|
2649
|
+
// src/ModelSaver.ts
|
|
2650
|
+
var ModelSaver = class _ModelSaver {
  // ── Serialize to JSON string ──────────────────────────────────────────────
  /**
   * @param {{getWeights(): any}} model
   * @returns {string} JSON of `{ weights: model.getWeights(), timestamp }`.
   */
  static toJSON(model) {
    return JSON.stringify({
      weights: model.getWeights(),
      timestamp: Date.now()
    });
  }

  // ── Deserialize from JSON string ──────────────────────────────────────────
  /**
   * Parses JSON produced by toJSON() and passes its `weights` to
   * `model.setWeights`. Both flat arrays and object-shaped weights are
   * accepted, because toJSON() stores whatever getWeights() returned (e.g. a
   * per-gate object). Previously object-shaped weights were rejected, so such
   * models could not round-trip.
   *
   * @param {{setWeights(w: any): void}} model
   * @param {string} json
   * @throws {SyntaxError} if `json` is not valid JSON.
   * @throws {Error} if the payload has no array/object `weights` field
   *   (including a JSON `null` payload, which used to raise a TypeError).
   */
  static fromJSON(model, json) {
    const data = JSON.parse(json);
    const weights = data !== null && typeof data === "object" ? data.weights : undefined;
    if (weights === null || typeof weights !== "object") {
      throw new Error("ModelSaver.fromJSON: invalid model data");
    }
    model.setWeights(weights);
  }

  // ── Save to file (requires write function) ────────────────────────────────
  /** Serializes `model` and hands `(path, json)` to the caller-supplied writer. */
  static saveToFile(model, path, writeFn) {
    const json = _ModelSaver.toJSON(model);
    writeFn(path, json);
  }

  // ── Load from file (requires read function) ───────────────────────────────
  /** Reads JSON text via `readFn(path)` and loads it into `model`. */
  static loadFromFile(model, path, readFn) {
    const json = readFn(path);
    _ModelSaver.fromJSON(model, json);
  }
};
|
|
1290
2677
|
// Annotate the CommonJS export names for ESM import in node:
|
|
1291
2678
|
0 && (module.exports = {
|
|
1292
2679
|
Adam,
|
|
1293
2680
|
AttentionHead,
|
|
2681
|
+
BatchNorm,
|
|
2682
|
+
BiasVector,
|
|
2683
|
+
ClipOptimizer,
|
|
2684
|
+
ClippedOptimizerFactory,
|
|
2685
|
+
Conv1D,
|
|
2686
|
+
DataLoader,
|
|
2687
|
+
Dropout,
|
|
1294
2688
|
EmbeddingMatrix,
|
|
2689
|
+
GRULayer,
|
|
2690
|
+
LRScheduler,
|
|
1295
2691
|
LSTMLayer,
|
|
1296
2692
|
Layer,
|
|
1297
2693
|
LayerNorm,
|
|
2694
|
+
ModelSaver,
|
|
1298
2695
|
Momentum,
|
|
1299
2696
|
MultiHeadAttention,
|
|
1300
2697
|
Network,
|
|
@@ -1305,11 +2702,13 @@ function crossEntropyDeltaRaw(predicted, actual) {
|
|
|
1305
2702
|
Neuron,
|
|
1306
2703
|
NeuronN,
|
|
1307
2704
|
SGD,
|
|
2705
|
+
Trainer,
|
|
1308
2706
|
TransformerBlock,
|
|
1309
2707
|
WeightMatrix,
|
|
1310
2708
|
crossEntropy,
|
|
1311
2709
|
crossEntropyDelta,
|
|
1312
2710
|
crossEntropyDeltaRaw,
|
|
2711
|
+
defaultOptimizer,
|
|
1313
2712
|
elu,
|
|
1314
2713
|
leakyRelu,
|
|
1315
2714
|
linear,
|
|
@@ -1323,5 +2722,9 @@ function crossEntropyDeltaRaw(predicted, actual) {
|
|
|
1323
2722
|
softmax,
|
|
1324
2723
|
softmaxBackward,
|
|
1325
2724
|
tanh,
|
|
1326
|
-
transpose
|
|
2725
|
+
transpose,
|
|
2726
|
+
validate2DArray,
|
|
2727
|
+
validateArray,
|
|
2728
|
+
validateArrayMinLength,
|
|
2729
|
+
validateNumber
|
|
1327
2730
|
});
|