@dniskav/neuron 0.2.0 → 0.2.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +6 -3
- package/dist/index.d.mts +28 -2
- package/dist/index.d.ts +28 -2
- package/dist/index.js +134 -9
- package/dist/index.mjs +133 -9
- package/package.json +1 -1
package/README.md
CHANGED
|
@@ -19,7 +19,7 @@ A minimal, dependency-free neural network library built from scratch in TypeScri
|
|
|
19
19
|
| `MultiHeadAttention` | N parallel attention heads concatenated and projected to `d_model`. |
|
|
20
20
|
| `AttentionHead` | Single scaled dot-product self-attention head (Q / K / V projections + backprop). |
|
|
21
21
|
| `LayerNorm` | Layer normalization with learnable γ / β per feature. |
|
|
22
|
-
| `WeightMatrix` | 2D weight matrix with per-scalar Adam optimizers. |
|
|
22
|
+
| `WeightMatrix` | 2D weight matrix with per-scalar Adam optimizers. Optional per-element gradient clipping via `update(dW, lr, clipValue)`. |
|
|
23
23
|
| `EmbeddingMatrix` | Lookup-table embedding matrix with SGD updates. |
|
|
24
24
|
| `sigmoid` `relu` `tanh` `linear` | Built-in activation functions. |
|
|
25
25
|
| `SGD` `Momentum` `Adam` | Optimizers. Each instance tracks its own state per weight. |
|
|
@@ -258,6 +258,7 @@ const targets = [...]; // 81*9 one-hot values
|
|
|
258
258
|
const mask = puzzle.map(v => v === 0); // only train on empty cells
|
|
259
259
|
|
|
260
260
|
const loss = net.train(puzzle, targets, 0.001, mask);
|
|
261
|
+
// loss is the mean cross-entropy over the masked cells (not MSE): it starts near ln 9 ≈ 2.2 (uniform guesses over 9 digits) and decreases toward 0 as training progresses
|
|
261
262
|
const logits = net.predict(puzzle); // 729 logits (81 × 9)
|
|
262
263
|
|
|
263
264
|
// Attention weights from all blocks for visualization
|
|
@@ -270,8 +271,10 @@ Each head in each block learns a different type of relationship (row, column,
|
|
|
270
271
|
|
|
271
272
|
## Possible improvements
|
|
272
273
|
|
|
273
|
-
1. **Support for batches** in training to improve efficiency.
|
|
274
|
-
2. **
|
|
274
|
+
1. **Support for batches** in training to improve efficiency and gradient stability.
|
|
275
|
+
2. **Global gradient norm clipping** — `WeightMatrix.update` only supports per-element clipping; a utility that rescales the gradients of all matrices by their combined L2 norm would preserve the update direction and is the more principled option.
|
|
276
|
+
3. **Learning rate warmup** — standard practice for Transformers; ramp LR from 0 to target over the first N steps.
|
|
277
|
+
4. **Pre-norm architecture** — apply LayerNorm to each sublayer's input inside the residual branch (instead of normalizing after the residual add); this is more stable for deep stacks.
|
|
275
278
|
|
|
276
279
|
## License
|
|
277
280
|
|
package/dist/index.d.mts
CHANGED
|
@@ -173,7 +173,7 @@ declare class WeightMatrix {
|
|
|
173
173
|
W: number[][];
|
|
174
174
|
private opts;
|
|
175
175
|
constructor(rows: number, cols: number);
|
|
176
|
-
update(dW: number[][], lr: number): void;
|
|
176
|
+
update(dW: number[][], lr: number, clipValue?: number): void;
|
|
177
177
|
}
|
|
178
178
|
declare class EmbeddingMatrix {
|
|
179
179
|
W: number[][];
|
|
@@ -274,10 +274,36 @@ declare class NetworkTransformer {
|
|
|
274
274
|
private _forward;
|
|
275
275
|
}
|
|
276
276
|
|
|
277
|
+
/** Hyper-parameters for `NetworkTransformerRL`; every field is optional. */
interface NetworkTransformerRLOptions {
    /** Width of the per-step hidden representation (default 32). */
    d_model?: number;
    /** Forwarded to every TransformerBlock as its number of attention heads (default 2). */
    nHeads?: number;
    /** Forwarded to every TransformerBlock as `d_ff` — presumably its feed-forward hidden width (default 64). */
    d_ff?: number;
    /** Number of stacked TransformerBlocks (default 2). */
    nBlocks?: number;
    /** Number of outputs, e.g. one Q-value per action (default 2). */
    nActions?: number;
}
/**
 * Transformer encoder with a linear read-out, for value regression in RL
 * (e.g. Q-learning): maps a `seqLen × inputDim` sequence to `nActions` outputs.
 *
 * Each time step is linearly projected to `d_model` (`inputProj`, no bias),
 * passed through `blocks`, averaged over time with a weighted mean in which
 * the last step counts double, then mapped through `outputProj` + `outputBias`.
 */
declare class NetworkTransformerRL {
    readonly seqLen: number;
    readonly inputDim: number;
    readonly d_model: number;
    readonly nActions: number;
    /** Per-step input projection, `d_model × inputDim`. */
    inputProj: WeightMatrix;
    blocks: TransformerBlock[];
    /** Read-out weights, `nActions × d_model`. */
    outputProj: WeightMatrix;
    /** Read-out bias, one entry per output. */
    outputBias: number[];
    private outBiasOpts;
    private _projected;
    constructor(seqLen: number, inputDim: number, options?: NetworkTransformerRLOptions);
    /**
     * Forward pass only.
     * @param sequence `seqLen × inputDim` input; not modified.
     * @returns `nActions` outputs (e.g. Q-values).
     */
    predict(sequence: ReadonlyArray<ReadonlyArray<number>>): number[];
    /**
     * One optimisation step on the mean squared error between the prediction
     * and `target`.
     * @param sequence `seqLen × inputDim` input; not modified.
     * @param target `nActions` regression targets; not modified.
     * @param lr Learning rate passed to every parameter update.
     * @returns The MSE loss measured before the update.
     */
    train(sequence: ReadonlyArray<ReadonlyArray<number>>, target: ReadonlyArray<number>, lr: number): number;
    /** Attention weights of every block (outer index = block), for visualization. */
    getAttentionWeights(): (number[][] | null)[][];
    private _forward;
    private _pool;
}
|
|
302
|
+
|
|
277
303
|
declare function mse(predicted: number[], actual: number[]): number;
|
|
278
304
|
declare function crossEntropy(predicted: number[], actual: number[]): number;
|
|
279
305
|
declare function mseDelta(predicted: number, actual: number): number;
|
|
280
306
|
declare function crossEntropyDelta(predicted: number, actual: number): number;
|
|
281
307
|
declare function crossEntropyDeltaRaw(predicted: number, actual: number): number;
|
|
282
308
|
|
|
283
|
-
export { type Activation, Adam, AttentionHead, EmbeddingMatrix, LSTMLayer, Layer, LayerNorm, Momentum, MultiHeadAttention, Network, NetworkLSTM, type NetworkLSTMOptions, NetworkN, type NetworkNOptions, NetworkTransformer, type NetworkTransformerOptions, Neuron, NeuronN, type Optimizer, type OptimizerFactory, SGD, TransformerBlock, type TransformerBlockOptions, WeightMatrix, crossEntropy, crossEntropyDelta, crossEntropyDeltaRaw, elu, leakyRelu, linear, makeElu, makeLeakyRelu, matMul, mse, mseDelta, relu, sigmoid, softmax, softmaxBackward, tanh, transpose };
|
|
309
|
+
export { type Activation, Adam, AttentionHead, EmbeddingMatrix, LSTMLayer, Layer, LayerNorm, Momentum, MultiHeadAttention, Network, NetworkLSTM, type NetworkLSTMOptions, NetworkN, type NetworkNOptions, NetworkTransformer, type NetworkTransformerOptions, NetworkTransformerRL, type NetworkTransformerRLOptions, Neuron, NeuronN, type Optimizer, type OptimizerFactory, SGD, TransformerBlock, type TransformerBlockOptions, WeightMatrix, crossEntropy, crossEntropyDelta, crossEntropyDeltaRaw, elu, leakyRelu, linear, makeElu, makeLeakyRelu, matMul, mse, mseDelta, relu, sigmoid, softmax, softmaxBackward, tanh, transpose };
|
package/dist/index.d.ts
CHANGED
|
@@ -173,7 +173,7 @@ declare class WeightMatrix {
|
|
|
173
173
|
W: number[][];
|
|
174
174
|
private opts;
|
|
175
175
|
constructor(rows: number, cols: number);
|
|
176
|
-
update(dW: number[][], lr: number): void;
|
|
176
|
+
update(dW: number[][], lr: number, clipValue?: number): void;
|
|
177
177
|
}
|
|
178
178
|
declare class EmbeddingMatrix {
|
|
179
179
|
W: number[][];
|
|
@@ -274,10 +274,36 @@ declare class NetworkTransformer {
|
|
|
274
274
|
private _forward;
|
|
275
275
|
}
|
|
276
276
|
|
|
277
|
+
/** Hyper-parameters for `NetworkTransformerRL`; every field is optional. */
interface NetworkTransformerRLOptions {
    /** Width of the per-step hidden representation (default 32). */
    d_model?: number;
    /** Forwarded to every TransformerBlock as its number of attention heads (default 2). */
    nHeads?: number;
    /** Forwarded to every TransformerBlock as `d_ff` — presumably its feed-forward hidden width (default 64). */
    d_ff?: number;
    /** Number of stacked TransformerBlocks (default 2). */
    nBlocks?: number;
    /** Number of outputs, e.g. one Q-value per action (default 2). */
    nActions?: number;
}
/**
 * Transformer encoder with a linear read-out, for value regression in RL
 * (e.g. Q-learning): maps a `seqLen × inputDim` sequence to `nActions` outputs.
 *
 * Each time step is linearly projected to `d_model` (`inputProj`, no bias),
 * passed through `blocks`, averaged over time with a weighted mean in which
 * the last step counts double, then mapped through `outputProj` + `outputBias`.
 */
declare class NetworkTransformerRL {
    readonly seqLen: number;
    readonly inputDim: number;
    readonly d_model: number;
    readonly nActions: number;
    /** Per-step input projection, `d_model × inputDim`. */
    inputProj: WeightMatrix;
    blocks: TransformerBlock[];
    /** Read-out weights, `nActions × d_model`. */
    outputProj: WeightMatrix;
    /** Read-out bias, one entry per output. */
    outputBias: number[];
    private outBiasOpts;
    private _projected;
    constructor(seqLen: number, inputDim: number, options?: NetworkTransformerRLOptions);
    /**
     * Forward pass only.
     * @param sequence `seqLen × inputDim` input; not modified.
     * @returns `nActions` outputs (e.g. Q-values).
     */
    predict(sequence: ReadonlyArray<ReadonlyArray<number>>): number[];
    /**
     * One optimisation step on the mean squared error between the prediction
     * and `target`.
     * @param sequence `seqLen × inputDim` input; not modified.
     * @param target `nActions` regression targets; not modified.
     * @param lr Learning rate passed to every parameter update.
     * @returns The MSE loss measured before the update.
     */
    train(sequence: ReadonlyArray<ReadonlyArray<number>>, target: ReadonlyArray<number>, lr: number): number;
    /** Attention weights of every block (outer index = block), for visualization. */
    getAttentionWeights(): (number[][] | null)[][];
    private _forward;
    private _pool;
}
|
|
302
|
+
|
|
277
303
|
declare function mse(predicted: number[], actual: number[]): number;
|
|
278
304
|
declare function crossEntropy(predicted: number[], actual: number[]): number;
|
|
279
305
|
declare function mseDelta(predicted: number, actual: number): number;
|
|
280
306
|
declare function crossEntropyDelta(predicted: number, actual: number): number;
|
|
281
307
|
declare function crossEntropyDeltaRaw(predicted: number, actual: number): number;
|
|
282
308
|
|
|
283
|
-
export { type Activation, Adam, AttentionHead, EmbeddingMatrix, LSTMLayer, Layer, LayerNorm, Momentum, MultiHeadAttention, Network, NetworkLSTM, type NetworkLSTMOptions, NetworkN, type NetworkNOptions, NetworkTransformer, type NetworkTransformerOptions, Neuron, NeuronN, type Optimizer, type OptimizerFactory, SGD, TransformerBlock, type TransformerBlockOptions, WeightMatrix, crossEntropy, crossEntropyDelta, crossEntropyDeltaRaw, elu, leakyRelu, linear, makeElu, makeLeakyRelu, matMul, mse, mseDelta, relu, sigmoid, softmax, softmaxBackward, tanh, transpose };
|
|
309
|
+
export { type Activation, Adam, AttentionHead, EmbeddingMatrix, LSTMLayer, Layer, LayerNorm, Momentum, MultiHeadAttention, Network, NetworkLSTM, type NetworkLSTMOptions, NetworkN, type NetworkNOptions, NetworkTransformer, type NetworkTransformerOptions, NetworkTransformerRL, type NetworkTransformerRLOptions, Neuron, NeuronN, type Optimizer, type OptimizerFactory, SGD, TransformerBlock, type TransformerBlockOptions, WeightMatrix, crossEntropy, crossEntropyDelta, crossEntropyDeltaRaw, elu, leakyRelu, linear, makeElu, makeLeakyRelu, matMul, mse, mseDelta, relu, sigmoid, softmax, softmaxBackward, tanh, transpose };
|
package/dist/index.js
CHANGED
|
@@ -32,6 +32,7 @@ __export(index_exports, {
|
|
|
32
32
|
NetworkLSTM: () => NetworkLSTM,
|
|
33
33
|
NetworkN: () => NetworkN,
|
|
34
34
|
NetworkTransformer: () => NetworkTransformer,
|
|
35
|
+
NetworkTransformerRL: () => NetworkTransformerRL,
|
|
35
36
|
Neuron: () => Neuron,
|
|
36
37
|
NeuronN: () => NeuronN,
|
|
37
38
|
SGD: () => SGD,
|
|
@@ -579,10 +580,15 @@ var WeightMatrix = class {
|
|
|
579
580
|
);
|
|
580
581
|
}
|
|
581
582
|
// Apply pre-computed gradient (same shape as W).
|
|
582
|
-
|
|
583
|
+
// clipValue: optional per-element gradient clipping before the Adam step.
|
|
584
|
+
// Prevents gradient explosion in deep networks (e.g. Transformers without
|
|
585
|
+
// global norm clipping). Pass e.g. 1.0 to clip to [-1, 1].
|
|
586
|
+
update(dW, lr, clipValue = Infinity) {
|
|
583
587
|
for (let i = 0; i < this.W.length; i++)
|
|
584
|
-
for (let j = 0; j < this.W[0].length; j++)
|
|
585
|
-
|
|
588
|
+
for (let j = 0; j < this.W[0].length; j++) {
|
|
589
|
+
const g = isFinite(clipValue) ? Math.max(-clipValue, Math.min(clipValue, dW[i][j])) : dW[i][j];
|
|
590
|
+
this.W[i][j] = this.opts[i][j].step(this.W[i][j], g, lr);
|
|
591
|
+
}
|
|
586
592
|
}
|
|
587
593
|
};
|
|
588
594
|
var EmbeddingMatrix = class {
|
|
@@ -1036,14 +1042,14 @@ var NetworkTransformer = class {
|
|
|
1036
1042
|
const dLogits = Array.from({ length: this.seqLen }, (_, i) => {
|
|
1037
1043
|
if (mask && !mask[i]) return new Array(this.nClasses).fill(0);
|
|
1038
1044
|
count++;
|
|
1039
|
-
|
|
1045
|
+
const probs = softmax(logits[i]);
|
|
1046
|
+
for (let c = 0; c < this.nClasses; c++) {
|
|
1040
1047
|
const t = targets[i * this.nClasses + c];
|
|
1041
|
-
|
|
1042
|
-
|
|
1043
|
-
|
|
1044
|
-
});
|
|
1048
|
+
if (t > 0) loss -= Math.log(Math.max(probs[c], 1e-7));
|
|
1049
|
+
}
|
|
1050
|
+
return probs.map((p, c) => p - targets[i * this.nClasses + c]);
|
|
1045
1051
|
});
|
|
1046
|
-
if (count > 0) loss /= count
|
|
1052
|
+
if (count > 0) loss /= count;
|
|
1047
1053
|
const dH = Array.from(
|
|
1048
1054
|
{ length: this.seqLen },
|
|
1049
1055
|
(_, i) => Array.from(
|
|
@@ -1093,6 +1099,124 @@ var NetworkTransformer = class {
|
|
|
1093
1099
|
}
|
|
1094
1100
|
};
|
|
1095
1101
|
|
|
1102
|
+
// src/NetworkTransformerRL.ts
|
|
1103
|
+
// Transformer encoder with a pooled linear read-out, for value regression in
// RL (e.g. Q-learning): maps a seqLen × inputDim observation sequence to
// nActions outputs and trains them with mean squared error.
//
// Pipeline: per-step linear projection (inputDim → d_model, no bias)
//   → nBlocks TransformerBlocks → weighted mean over time (the last step
//   counts double) → linear head (d_model → nActions) + bias.
var NetworkTransformerRL = class {
  // seqLen:   number of time steps every input sequence must have (>= 1)
  // inputDim: features per time step
  // options:  d_model (32), nHeads (2), d_ff (64), nBlocks (2), nActions (2)
  constructor(seqLen, inputDim, options = {}) {
    if (!Number.isInteger(seqLen) || seqLen < 1)
      throw new RangeError(
        `NetworkTransformerRL: seqLen must be a positive integer, got ${seqLen}`
      );
    // Hidden states after the last block from the most recent forward pass
    // (seqLen × d_model), kept for inspection. TransformerBlock.backward is
    // called without activations, so the blocks presumably cache their own
    // forward state.
    this._projected = null;
    const {
      d_model = 32,
      nHeads = 2,
      d_ff = 64,
      nBlocks = 2,
      nActions = 2
    } = options;
    this.seqLen = seqLen;
    this.inputDim = inputDim;
    this.d_model = d_model;
    this.nActions = nActions;
    this.inputProj = new WeightMatrix(d_model, inputDim);
    this.blocks = Array.from(
      { length: nBlocks },
      () => new TransformerBlock({ d_model, nHeads, d_ff })
    );
    this.outputProj = new WeightMatrix(nActions, d_model);
    this.outputBias = new Array(nActions).fill(0);
    this.outBiasOpts = Array.from({ length: nActions }, () => new Adam());
  }

  // ── Forward ────────────────────────────────────────────────────────────────
  // sequence: seqLen × inputDim → nActions outputs (e.g. Q-values)
  predict(sequence) {
    this._checkSequence(sequence);
    return this._head(this._pool(this._forward(sequence)));
  }

  // ── Training ───────────────────────────────────────────────────────────────
  // One gradient step on the MSE between the network output and `target`.
  //   sequence: seqLen × inputDim
  //   target:   nActions regression targets (for Q-learning, typically the
  //             current prediction with the taken action's entry replaced)
  //   lr:       learning rate passed to every parameter update
  // Returns the MSE loss measured before the update.
  train(sequence, target, lr) {
    this._checkSequence(sequence);
    if (target.length !== this.nActions)
      throw new RangeError(
        `NetworkTransformerRL: expected ${this.nActions} targets, got ${target.length}`
      );
    const h = this._forward(sequence);
    const pooled = this._pool(h);
    const pred = this._head(pooled);
    const n = this.nActions;

    // L = (1/n) Σ_c (pred_c − target_c)²  ⇒  ∂L/∂pred_c = 2 (pred_c − target_c) / n
    let loss = 0;
    const dPred = new Array(n);
    for (let c = 0; c < n; c++) {
      const diff = pred[c] - target[c];
      loss += diff * diff;
      dPred[c] = 2 * diff / n;
    }
    loss /= n;

    // ∂L/∂pooled must use the output weights that produced `pred`, so it is
    // computed before outputProj is updated.
    const dPooled = Array.from(
      { length: this.d_model },
      (_, m) => dPred.reduce((s, dp, c) => s + dp * this.outputProj.W[c][m], 0)
    );

    // Output head: dW = dPred ⊗ pooled, dBias = dPred.
    this.outputProj.update(dPred.map((dp) => pooled.map((p) => dp * p)), lr);
    for (let c = 0; c < n; c++)
      this.outputBias[c] = this.outBiasOpts[c].step(this.outputBias[c], dPred[c], lr);

    // Weighted-mean pool backward: step i contributed weights[i] / total of the
    // pooled vector, so it receives exactly that share of dPooled. This must
    // mirror _pool, where the last step counts double.
    const { weights, total } = this._poolWeights();
    let dH = weights.map((w) => dPooled.map((v) => v * w / total));
    for (let b = this.blocks.length - 1; b >= 0; b--)
      dH = this.blocks[b].backward(dH, lr);

    // inputProj is shared by all time steps: its gradient is the sum of the
    // per-step outer products dH[i] ⊗ sequence[i], applied as one optimizer
    // step (like outputProj) rather than seqLen separate steps per call.
    const dWin = Array.from(
      { length: this.d_model },
      () => new Array(this.inputDim).fill(0)
    );
    for (let i = 0; i < this.seqLen; i++)
      for (let k = 0; k < this.d_model; k++)
        for (let m = 0; m < this.inputDim; m++)
          dWin[k][m] += dH[i][k] * sequence[i][m];
    this.inputProj.update(dWin, lr);
    return loss;
  }

  // Attention weights from every block for visualization.
  getAttentionWeights() {
    return this.blocks.map((b) => b.getAttentionWeights());
  }

  // ── Internal ───────────────────────────────────────────────────────────────
  // Projects every time step to d_model and runs the Transformer blocks.
  _forward(sequence) {
    let h = sequence.map(
      (step) => this.inputProj.W.map(
        (row) => row.reduce((s, w, m) => s + w * step[m], 0)
      )
    );
    for (const block of this.blocks)
      h = block.predict(h);
    this._projected = h;
    return h;
  }

  // Pooling weights shared by _pool and its backward pass: every time step
  // counts 1 except the last one, which counts 2.
  _poolWeights() {
    const weights = Array.from(
      { length: this.seqLen },
      (_, i) => i === this.seqLen - 1 ? 2 : 1
    );
    const total = weights.reduce((a, b) => a + b, 0);
    return { weights, total };
  }

  // Weighted mean of the hidden states over time → d_model vector.
  _pool(h) {
    const { weights, total } = this._poolWeights();
    return Array.from({ length: this.d_model }, (_, m) => {
      let sum = 0;
      for (let i = 0; i < this.seqLen; i++)
        sum += weights[i] * h[i][m];
      return sum / total;
    });
  }

  // Linear read-out: outputProj · pooled + outputBias → nActions values.
  _head(pooled) {
    return this.outputProj.W.map(
      (row, c) => row.reduce((s, w, m) => s + w * pooled[m], this.outputBias[c])
    );
  }

  // Rejects inputs that do not match seqLen × inputDim; a mismatch would
  // otherwise surface as NaN outputs, silently ignored features, or an
  // opaque TypeError inside _pool.
  _checkSequence(sequence) {
    if (sequence.length !== this.seqLen)
      throw new RangeError(
        `NetworkTransformerRL: expected ${this.seqLen} time steps, got ${sequence.length}`
      );
    sequence.forEach((step, i) => {
      if (step.length !== this.inputDim)
        throw new RangeError(
          `NetworkTransformerRL: time step ${i} has ${step.length} features, expected ${this.inputDim}`
        );
    });
  }
};
|
|
1219
|
+
|
|
1096
1220
|
// src/losses.ts
|
|
1097
1221
|
function mse(predicted, actual) {
|
|
1098
1222
|
return predicted.reduce((sum, p, i) => sum + (actual[i] - p) ** 2, 0) / predicted.length;
|
|
@@ -1129,6 +1253,7 @@ function crossEntropyDeltaRaw(predicted, actual) {
|
|
|
1129
1253
|
NetworkLSTM,
|
|
1130
1254
|
NetworkN,
|
|
1131
1255
|
NetworkTransformer,
|
|
1256
|
+
NetworkTransformerRL,
|
|
1132
1257
|
Neuron,
|
|
1133
1258
|
NeuronN,
|
|
1134
1259
|
SGD,
|
package/dist/index.mjs
CHANGED
|
@@ -520,10 +520,15 @@ var WeightMatrix = class {
|
|
|
520
520
|
);
|
|
521
521
|
}
|
|
522
522
|
// Apply pre-computed gradient (same shape as W).
|
|
523
|
-
|
|
523
|
+
// clipValue: optional per-element gradient clipping before the Adam step.
|
|
524
|
+
// Prevents gradient explosion in deep networks (e.g. Transformers without
|
|
525
|
+
// global norm clipping). Pass e.g. 1.0 to clip to [-1, 1].
|
|
526
|
+
update(dW, lr, clipValue = Infinity) {
|
|
524
527
|
for (let i = 0; i < this.W.length; i++)
|
|
525
|
-
for (let j = 0; j < this.W[0].length; j++)
|
|
526
|
-
|
|
528
|
+
for (let j = 0; j < this.W[0].length; j++) {
|
|
529
|
+
const g = isFinite(clipValue) ? Math.max(-clipValue, Math.min(clipValue, dW[i][j])) : dW[i][j];
|
|
530
|
+
this.W[i][j] = this.opts[i][j].step(this.W[i][j], g, lr);
|
|
531
|
+
}
|
|
527
532
|
}
|
|
528
533
|
};
|
|
529
534
|
var EmbeddingMatrix = class {
|
|
@@ -977,14 +982,14 @@ var NetworkTransformer = class {
|
|
|
977
982
|
const dLogits = Array.from({ length: this.seqLen }, (_, i) => {
|
|
978
983
|
if (mask && !mask[i]) return new Array(this.nClasses).fill(0);
|
|
979
984
|
count++;
|
|
980
|
-
|
|
985
|
+
const probs = softmax(logits[i]);
|
|
986
|
+
for (let c = 0; c < this.nClasses; c++) {
|
|
981
987
|
const t = targets[i * this.nClasses + c];
|
|
982
|
-
|
|
983
|
-
|
|
984
|
-
|
|
985
|
-
});
|
|
988
|
+
if (t > 0) loss -= Math.log(Math.max(probs[c], 1e-7));
|
|
989
|
+
}
|
|
990
|
+
return probs.map((p, c) => p - targets[i * this.nClasses + c]);
|
|
986
991
|
});
|
|
987
|
-
if (count > 0) loss /= count
|
|
992
|
+
if (count > 0) loss /= count;
|
|
988
993
|
const dH = Array.from(
|
|
989
994
|
{ length: this.seqLen },
|
|
990
995
|
(_, i) => Array.from(
|
|
@@ -1034,6 +1039,124 @@ var NetworkTransformer = class {
|
|
|
1034
1039
|
}
|
|
1035
1040
|
};
|
|
1036
1041
|
|
|
1042
|
+
// src/NetworkTransformerRL.ts
|
|
1043
|
+
// Transformer encoder with a pooled linear read-out, for value regression in
// RL (e.g. Q-learning): maps a seqLen × inputDim observation sequence to
// nActions outputs and trains them with mean squared error.
//
// Pipeline: per-step linear projection (inputDim → d_model, no bias)
//   → nBlocks TransformerBlocks → weighted mean over time (the last step
//   counts double) → linear head (d_model → nActions) + bias.
var NetworkTransformerRL = class {
  // seqLen:   number of time steps every input sequence must have (>= 1)
  // inputDim: features per time step
  // options:  d_model (32), nHeads (2), d_ff (64), nBlocks (2), nActions (2)
  constructor(seqLen, inputDim, options = {}) {
    if (!Number.isInteger(seqLen) || seqLen < 1)
      throw new RangeError(
        `NetworkTransformerRL: seqLen must be a positive integer, got ${seqLen}`
      );
    // Hidden states after the last block from the most recent forward pass
    // (seqLen × d_model), kept for inspection. TransformerBlock.backward is
    // called without activations, so the blocks presumably cache their own
    // forward state.
    this._projected = null;
    const {
      d_model = 32,
      nHeads = 2,
      d_ff = 64,
      nBlocks = 2,
      nActions = 2
    } = options;
    this.seqLen = seqLen;
    this.inputDim = inputDim;
    this.d_model = d_model;
    this.nActions = nActions;
    this.inputProj = new WeightMatrix(d_model, inputDim);
    this.blocks = Array.from(
      { length: nBlocks },
      () => new TransformerBlock({ d_model, nHeads, d_ff })
    );
    this.outputProj = new WeightMatrix(nActions, d_model);
    this.outputBias = new Array(nActions).fill(0);
    this.outBiasOpts = Array.from({ length: nActions }, () => new Adam());
  }

  // ── Forward ────────────────────────────────────────────────────────────────
  // sequence: seqLen × inputDim → nActions outputs (e.g. Q-values)
  predict(sequence) {
    this._checkSequence(sequence);
    return this._head(this._pool(this._forward(sequence)));
  }

  // ── Training ───────────────────────────────────────────────────────────────
  // One gradient step on the MSE between the network output and `target`.
  //   sequence: seqLen × inputDim
  //   target:   nActions regression targets (for Q-learning, typically the
  //             current prediction with the taken action's entry replaced)
  //   lr:       learning rate passed to every parameter update
  // Returns the MSE loss measured before the update.
  train(sequence, target, lr) {
    this._checkSequence(sequence);
    if (target.length !== this.nActions)
      throw new RangeError(
        `NetworkTransformerRL: expected ${this.nActions} targets, got ${target.length}`
      );
    const h = this._forward(sequence);
    const pooled = this._pool(h);
    const pred = this._head(pooled);
    const n = this.nActions;

    // L = (1/n) Σ_c (pred_c − target_c)²  ⇒  ∂L/∂pred_c = 2 (pred_c − target_c) / n
    let loss = 0;
    const dPred = new Array(n);
    for (let c = 0; c < n; c++) {
      const diff = pred[c] - target[c];
      loss += diff * diff;
      dPred[c] = 2 * diff / n;
    }
    loss /= n;

    // ∂L/∂pooled must use the output weights that produced `pred`, so it is
    // computed before outputProj is updated.
    const dPooled = Array.from(
      { length: this.d_model },
      (_, m) => dPred.reduce((s, dp, c) => s + dp * this.outputProj.W[c][m], 0)
    );

    // Output head: dW = dPred ⊗ pooled, dBias = dPred.
    this.outputProj.update(dPred.map((dp) => pooled.map((p) => dp * p)), lr);
    for (let c = 0; c < n; c++)
      this.outputBias[c] = this.outBiasOpts[c].step(this.outputBias[c], dPred[c], lr);

    // Weighted-mean pool backward: step i contributed weights[i] / total of the
    // pooled vector, so it receives exactly that share of dPooled. This must
    // mirror _pool, where the last step counts double.
    const { weights, total } = this._poolWeights();
    let dH = weights.map((w) => dPooled.map((v) => v * w / total));
    for (let b = this.blocks.length - 1; b >= 0; b--)
      dH = this.blocks[b].backward(dH, lr);

    // inputProj is shared by all time steps: its gradient is the sum of the
    // per-step outer products dH[i] ⊗ sequence[i], applied as one optimizer
    // step (like outputProj) rather than seqLen separate steps per call.
    const dWin = Array.from(
      { length: this.d_model },
      () => new Array(this.inputDim).fill(0)
    );
    for (let i = 0; i < this.seqLen; i++)
      for (let k = 0; k < this.d_model; k++)
        for (let m = 0; m < this.inputDim; m++)
          dWin[k][m] += dH[i][k] * sequence[i][m];
    this.inputProj.update(dWin, lr);
    return loss;
  }

  // Attention weights from every block for visualization.
  getAttentionWeights() {
    return this.blocks.map((b) => b.getAttentionWeights());
  }

  // ── Internal ───────────────────────────────────────────────────────────────
  // Projects every time step to d_model and runs the Transformer blocks.
  _forward(sequence) {
    let h = sequence.map(
      (step) => this.inputProj.W.map(
        (row) => row.reduce((s, w, m) => s + w * step[m], 0)
      )
    );
    for (const block of this.blocks)
      h = block.predict(h);
    this._projected = h;
    return h;
  }

  // Pooling weights shared by _pool and its backward pass: every time step
  // counts 1 except the last one, which counts 2.
  _poolWeights() {
    const weights = Array.from(
      { length: this.seqLen },
      (_, i) => i === this.seqLen - 1 ? 2 : 1
    );
    const total = weights.reduce((a, b) => a + b, 0);
    return { weights, total };
  }

  // Weighted mean of the hidden states over time → d_model vector.
  _pool(h) {
    const { weights, total } = this._poolWeights();
    return Array.from({ length: this.d_model }, (_, m) => {
      let sum = 0;
      for (let i = 0; i < this.seqLen; i++)
        sum += weights[i] * h[i][m];
      return sum / total;
    });
  }

  // Linear read-out: outputProj · pooled + outputBias → nActions values.
  _head(pooled) {
    return this.outputProj.W.map(
      (row, c) => row.reduce((s, w, m) => s + w * pooled[m], this.outputBias[c])
    );
  }

  // Rejects inputs that do not match seqLen × inputDim; a mismatch would
  // otherwise surface as NaN outputs, silently ignored features, or an
  // opaque TypeError inside _pool.
  _checkSequence(sequence) {
    if (sequence.length !== this.seqLen)
      throw new RangeError(
        `NetworkTransformerRL: expected ${this.seqLen} time steps, got ${sequence.length}`
      );
    sequence.forEach((step, i) => {
      if (step.length !== this.inputDim)
        throw new RangeError(
          `NetworkTransformerRL: time step ${i} has ${step.length} features, expected ${this.inputDim}`
        );
    });
  }
};
|
|
1159
|
+
|
|
1037
1160
|
// src/losses.ts
|
|
1038
1161
|
function mse(predicted, actual) {
|
|
1039
1162
|
return predicted.reduce((sum, p, i) => sum + (actual[i] - p) ** 2, 0) / predicted.length;
|
|
@@ -1069,6 +1192,7 @@ export {
|
|
|
1069
1192
|
NetworkLSTM,
|
|
1070
1193
|
NetworkN,
|
|
1071
1194
|
NetworkTransformer,
|
|
1195
|
+
NetworkTransformerRL,
|
|
1072
1196
|
Neuron,
|
|
1073
1197
|
NeuronN,
|
|
1074
1198
|
SGD,
|
package/package.json
CHANGED