@dniskav/neuron 0.2.2 → 0.2.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +38 -0
- package/dist/index.d.mts +28 -0
- package/dist/index.d.ts +28 -0
- package/dist/index.js +48 -0
- package/dist/index.mjs +48 -0
- package/package.json +1 -1
package/README.md
CHANGED
|
@@ -15,6 +15,7 @@ A minimal, dependency-free neural network library built from scratch in TypeScri
|
|
|
15
15
|
| `LSTMLayer` | Recurrent layer with persistent hidden and cell state. Learns sequences via BPTT. |
|
|
16
16
|
| `NetworkLSTM` | Wraps an `LSTMLayer` + dense layers. Maintains memory across steps within an episode. |
|
|
17
17
|
| `NetworkTransformer` | Full token-classification Transformer: embeddings → N blocks → per-token logits. |
|
|
18
|
+
| `NetworkTransformerRL` | Transformer for RL agents: continuous input projection → causal attention → Q-values. Remembers the last N steps. |
|
|
18
19
|
| `TransformerBlock` | One Transformer block: multi-head attention + FFN + LayerNorm × 2 with residuals. |
|
|
19
20
|
| `MultiHeadAttention` | N parallel attention heads concatenated and projected to `d_model`. |
|
|
20
21
|
| `AttentionHead` | Single scaled dot-product self-attention head (Q / K / V projections + backprop). |
|
|
@@ -269,6 +270,43 @@ const weights = net.getAttentionWeights();
|
|
|
269
270
|
Each head in each block learns a different type of relationship (row, column,
|
|
270
271
|
3×3 box). The network figures this out by itself through training.
|
|
271
272
|
|
|
273
|
+
### NetworkTransformerRL — Transformer for reinforcement learning
|
|
274
|
+
|
|
275
|
+
`NetworkTransformerRL` uses causal self-attention over a sliding window of past states to output Q-values. Unlike `NetworkLSTM`, the agent attends to specific past moments rather than compressing them into a single hidden vector.
|
|
276
|
+
|
|
277
|
+
```ts
|
|
278
|
+
import { NetworkTransformerRL } from "@dniskav/neuron";
|
|
279
|
+
|
|
280
|
+
// Agent sees the last 8 steps, each step is a 7-value sensor vector → 4 actions
|
|
281
|
+
const net = new NetworkTransformerRL(8, 7, {
|
|
282
|
+
d_model: 32,
|
|
283
|
+
nHeads: 2,
|
|
284
|
+
d_ff: 64,
|
|
285
|
+
nBlocks: 2,
|
|
286
|
+
nActions: 4,
|
|
287
|
+
});
|
|
288
|
+
|
|
289
|
+
// Each step: feed the last N states as a sequence
|
|
290
|
+
const sequence = getLastNStates(); // number[][] — shape: [8, 7]
|
|
291
|
+
const qValues = net.predict(sequence); // number[4]
|
|
292
|
+
|
|
293
|
+
// Q-learning update: train toward the Bellman target (nextSequence = the state window observed after env.step)
|
|
294
|
+
const action = argmax(qValues);
|
|
295
|
+
const reward = env.step(action);
|
|
296
|
+
const targets = qValues.slice();
|
|
297
|
+
targets[action] = reward + 0.99 * Math.max(...net.predict(nextSequence));
|
|
298
|
+
|
|
299
|
+
const loss = net.train(sequence, targets, 0.001);
|
|
300
|
+
```
|
|
301
|
+
|
|
302
|
+
During pooling, the last step in the sequence is given 2× weight, so the most recent state contributes more to the decision.
|
|
303
|
+
|
|
304
|
+
```ts
|
|
305
|
+
// Inspect what the agent is attending to
|
|
306
|
+
const attnWeights = net.getAttentionWeights();
|
|
307
|
+
// attnWeights[blockIdx][headIdx] → seqLen × seqLen matrix (typed as possibly null — check before use)
|
|
308
|
+
```
|
|
309
|
+
|
|
272
310
|
## Possible improvements
|
|
273
311
|
|
|
274
312
|
1. **Support for batches** in training to improve efficiency and gradient stability.
|
package/dist/index.d.mts
CHANGED
|
@@ -296,6 +296,34 @@ declare class NetworkTransformerRL {
|
|
|
296
296
|
predict(sequence: number[][]): number[];
|
|
297
297
|
train(sequence: number[][], target: number[], lr: number): number;
|
|
298
298
|
getAttentionWeights(): (number[][] | null)[][];
|
|
299
|
+
/**
 * Returns the network's weights as plain nested arrays.
 *
 * The returned object has exactly the shape that `setWeights()` accepts, so
 * it can be stored and later fed back in.
 */
getWeights(): {
  /** Weight matrix of the input projection. */
  inputProj: number[][];
  /** One entry per Transformer block, in block order. */
  blocks: Array<{
    attn: {
      /** Q / K / V weight matrices, one entry per attention head. */
      heads: Array<{
        Wq: number[][];
        Wk: number[][];
        Wv: number[][];
      }>;
      /** Weight matrix of the attention module's `Wo` layer. */
      Wo: number[][];
    };
    /** `gamma` / `beta` vectors of the block's first norm layer. */
    norm1: {
      gamma: number[];
      beta: number[];
    };
    /** `gamma` / `beta` vectors of the block's second norm layer. */
    norm2: {
      gamma: number[];
      beta: number[];
    };
    /** Weight matrices of the block's `ff1` and `ff2` layers. */
    ff1: number[][];
    ff2: number[][];
    /** The block's `b1` and `b2` vectors (presumably the biases paired with ff1 / ff2). */
    b1: number[];
    b2: number[];
  }>;
  /** Weight matrix of the output projection. */
  outputProj: number[][];
  /** The network's `outputBias` vector. */
  outputBias: number[];
};
|
|
326
|
+
/**
 * Loads a weight snapshot into the network.
 *
 * @param data Object with the same shape that `getWeights()` returns.
 */
setWeights(data: ReturnType<NetworkTransformerRL['getWeights']>): void;
|
|
299
327
|
private _forward;
|
|
300
328
|
private _pool;
|
|
301
329
|
}
|
package/dist/index.d.ts
CHANGED
|
@@ -296,6 +296,34 @@ declare class NetworkTransformerRL {
|
|
|
296
296
|
predict(sequence: number[][]): number[];
|
|
297
297
|
train(sequence: number[][], target: number[], lr: number): number;
|
|
298
298
|
getAttentionWeights(): (number[][] | null)[][];
|
|
299
|
+
/**
 * Returns the network's weights as plain nested arrays.
 *
 * The returned object has exactly the shape that `setWeights()` accepts, so
 * it can be stored and later fed back in.
 */
getWeights(): {
  /** Weight matrix of the input projection. */
  inputProj: number[][];
  /** One entry per Transformer block, in block order. */
  blocks: Array<{
    attn: {
      /** Q / K / V weight matrices, one entry per attention head. */
      heads: Array<{
        Wq: number[][];
        Wk: number[][];
        Wv: number[][];
      }>;
      /** Weight matrix of the attention module's `Wo` layer. */
      Wo: number[][];
    };
    /** `gamma` / `beta` vectors of the block's first norm layer. */
    norm1: {
      gamma: number[];
      beta: number[];
    };
    /** `gamma` / `beta` vectors of the block's second norm layer. */
    norm2: {
      gamma: number[];
      beta: number[];
    };
    /** Weight matrices of the block's `ff1` and `ff2` layers. */
    ff1: number[][];
    ff2: number[][];
    /** The block's `b1` and `b2` vectors (presumably the biases paired with ff1 / ff2). */
    b1: number[];
    b2: number[];
  }>;
  /** Weight matrix of the output projection. */
  outputProj: number[][];
  /** The network's `outputBias` vector. */
  outputBias: number[];
};
|
|
326
|
+
/**
 * Loads a weight snapshot into the network.
 *
 * @param data Object with the same shape that `getWeights()` returns.
 */
setWeights(data: ReturnType<NetworkTransformerRL['getWeights']>): void;
|
|
299
327
|
private _forward;
|
|
300
328
|
private _pool;
|
|
301
329
|
}
|
package/dist/index.js
CHANGED
|
@@ -1190,6 +1190,54 @@ var NetworkTransformerRL = class {
|
|
|
1190
1190
|
getAttentionWeights() {
|
|
1191
1191
|
return this.blocks.map((b) => b.getAttentionWeights());
|
|
1192
1192
|
}
|
|
1193
|
+
// ── Serialization ──────────────────────────────────────────────────────────
|
|
1194
|
+
getWeights() {
|
|
1195
|
+
return {
|
|
1196
|
+
inputProj: this.inputProj.W.map((r) => [...r]),
|
|
1197
|
+
blocks: this.blocks.map((b) => ({
|
|
1198
|
+
attn: {
|
|
1199
|
+
heads: b.attn.heads.map((h) => ({
|
|
1200
|
+
Wq: h.Wq.W.map((r) => [...r]),
|
|
1201
|
+
Wk: h.Wk.W.map((r) => [...r]),
|
|
1202
|
+
Wv: h.Wv.W.map((r) => [...r])
|
|
1203
|
+
})),
|
|
1204
|
+
Wo: b.attn.Wo.W.map((r) => [...r])
|
|
1205
|
+
},
|
|
1206
|
+
norm1: { gamma: [...b.norm1.gamma], beta: [...b.norm1.beta] },
|
|
1207
|
+
norm2: { gamma: [...b.norm2.gamma], beta: [...b.norm2.beta] },
|
|
1208
|
+
ff1: b.ff1.W.map((r) => [...r]),
|
|
1209
|
+
ff2: b.ff2.W.map((r) => [...r]),
|
|
1210
|
+
b1: [...b.b1],
|
|
1211
|
+
b2: [...b.b2]
|
|
1212
|
+
})),
|
|
1213
|
+
outputProj: this.outputProj.W.map((r) => [...r]),
|
|
1214
|
+
outputBias: [...this.outputBias]
|
|
1215
|
+
};
|
|
1216
|
+
}
|
|
1217
|
+
setWeights(data) {
|
|
1218
|
+
data.inputProj.forEach((row, i) => {
|
|
1219
|
+
this.inputProj.W[i] = [...row];
|
|
1220
|
+
});
|
|
1221
|
+
data.blocks.forEach((bd, b) => {
|
|
1222
|
+
const blk = this.blocks[b];
|
|
1223
|
+
bd.attn.heads.forEach((hd, h) => {
|
|
1224
|
+
blk.attn.heads[h].Wq.W = hd.Wq.map((r) => [...r]);
|
|
1225
|
+
blk.attn.heads[h].Wk.W = hd.Wk.map((r) => [...r]);
|
|
1226
|
+
blk.attn.heads[h].Wv.W = hd.Wv.map((r) => [...r]);
|
|
1227
|
+
});
|
|
1228
|
+
blk.attn.Wo.W = bd.attn.Wo.map((r) => [...r]);
|
|
1229
|
+
blk.norm1.gamma = [...bd.norm1.gamma];
|
|
1230
|
+
blk.norm1.beta = [...bd.norm1.beta];
|
|
1231
|
+
blk.norm2.gamma = [...bd.norm2.gamma];
|
|
1232
|
+
blk.norm2.beta = [...bd.norm2.beta];
|
|
1233
|
+
blk.ff1.W = bd.ff1.map((r) => [...r]);
|
|
1234
|
+
blk.ff2.W = bd.ff2.map((r) => [...r]);
|
|
1235
|
+
blk.b1 = [...bd.b1];
|
|
1236
|
+
blk.b2 = [...bd.b2];
|
|
1237
|
+
});
|
|
1238
|
+
this.outputProj.W = data.outputProj.map((r) => [...r]);
|
|
1239
|
+
this.outputBias = [...data.outputBias];
|
|
1240
|
+
}
|
|
1193
1241
|
// ── Internal ────────────────────────────────────────────────────────────────
|
|
1194
1242
|
_forward(sequence) {
|
|
1195
1243
|
let h = sequence.map(
|
package/dist/index.mjs
CHANGED
|
@@ -1130,6 +1130,54 @@ var NetworkTransformerRL = class {
|
|
|
1130
1130
|
getAttentionWeights() {
|
|
1131
1131
|
return this.blocks.map((b) => b.getAttentionWeights());
|
|
1132
1132
|
}
|
|
1133
|
+
// ── Serialization ──────────────────────────────────────────────────────────
|
|
1134
|
+
getWeights() {
|
|
1135
|
+
return {
|
|
1136
|
+
inputProj: this.inputProj.W.map((r) => [...r]),
|
|
1137
|
+
blocks: this.blocks.map((b) => ({
|
|
1138
|
+
attn: {
|
|
1139
|
+
heads: b.attn.heads.map((h) => ({
|
|
1140
|
+
Wq: h.Wq.W.map((r) => [...r]),
|
|
1141
|
+
Wk: h.Wk.W.map((r) => [...r]),
|
|
1142
|
+
Wv: h.Wv.W.map((r) => [...r])
|
|
1143
|
+
})),
|
|
1144
|
+
Wo: b.attn.Wo.W.map((r) => [...r])
|
|
1145
|
+
},
|
|
1146
|
+
norm1: { gamma: [...b.norm1.gamma], beta: [...b.norm1.beta] },
|
|
1147
|
+
norm2: { gamma: [...b.norm2.gamma], beta: [...b.norm2.beta] },
|
|
1148
|
+
ff1: b.ff1.W.map((r) => [...r]),
|
|
1149
|
+
ff2: b.ff2.W.map((r) => [...r]),
|
|
1150
|
+
b1: [...b.b1],
|
|
1151
|
+
b2: [...b.b2]
|
|
1152
|
+
})),
|
|
1153
|
+
outputProj: this.outputProj.W.map((r) => [...r]),
|
|
1154
|
+
outputBias: [...this.outputBias]
|
|
1155
|
+
};
|
|
1156
|
+
}
|
|
1157
|
+
setWeights(data) {
|
|
1158
|
+
data.inputProj.forEach((row, i) => {
|
|
1159
|
+
this.inputProj.W[i] = [...row];
|
|
1160
|
+
});
|
|
1161
|
+
data.blocks.forEach((bd, b) => {
|
|
1162
|
+
const blk = this.blocks[b];
|
|
1163
|
+
bd.attn.heads.forEach((hd, h) => {
|
|
1164
|
+
blk.attn.heads[h].Wq.W = hd.Wq.map((r) => [...r]);
|
|
1165
|
+
blk.attn.heads[h].Wk.W = hd.Wk.map((r) => [...r]);
|
|
1166
|
+
blk.attn.heads[h].Wv.W = hd.Wv.map((r) => [...r]);
|
|
1167
|
+
});
|
|
1168
|
+
blk.attn.Wo.W = bd.attn.Wo.map((r) => [...r]);
|
|
1169
|
+
blk.norm1.gamma = [...bd.norm1.gamma];
|
|
1170
|
+
blk.norm1.beta = [...bd.norm1.beta];
|
|
1171
|
+
blk.norm2.gamma = [...bd.norm2.gamma];
|
|
1172
|
+
blk.norm2.beta = [...bd.norm2.beta];
|
|
1173
|
+
blk.ff1.W = bd.ff1.map((r) => [...r]);
|
|
1174
|
+
blk.ff2.W = bd.ff2.map((r) => [...r]);
|
|
1175
|
+
blk.b1 = [...bd.b1];
|
|
1176
|
+
blk.b2 = [...bd.b2];
|
|
1177
|
+
});
|
|
1178
|
+
this.outputProj.W = data.outputProj.map((r) => [...r]);
|
|
1179
|
+
this.outputBias = [...data.outputBias];
|
|
1180
|
+
}
|
|
1133
1181
|
// ── Internal ────────────────────────────────────────────────────────────────
|
|
1134
1182
|
_forward(sequence) {
|
|
1135
1183
|
let h = sequence.map(
|
package/package.json
CHANGED