@dniskav/neuron 0.2.2 → 0.2.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -15,6 +15,7 @@ A minimal, dependency-free neural network library built from scratch in TypeScri
15
15
  | `LSTMLayer` | Recurrent layer with persistent hidden and cell state. Learns sequences via BPTT. |
16
16
  | `NetworkLSTM` | Wraps an `LSTMLayer` + dense layers. Maintains memory across steps within an episode. |
17
17
  | `NetworkTransformer` | Full token-classification Transformer: embeddings → N blocks → per-token logits. |
18
+ | `NetworkTransformerRL` | Transformer for RL agents: continuous input projection → causal attention → Q-values. Attends over the last N steps, passed in as a sequence. |
18
19
  | `TransformerBlock` | One Transformer block: multi-head attention + FFN + LayerNorm × 2 with residuals. |
19
20
  | `MultiHeadAttention` | N parallel attention heads concatenated and projected to `d_model`. |
20
21
  | `AttentionHead` | Single scaled dot-product self-attention head (Q / K / V projections + backprop). |
@@ -269,6 +270,43 @@ const weights = net.getAttentionWeights();
269
270
  Each head in each block learns a different type of relationship (row, column,
270
271
  3×3 box). The network figures this out by itself through training.
271
272
 
273
+ ### NetworkTransformerRL — Transformer for reinforcement learning
274
+
275
+ `NetworkTransformerRL` uses causal self-attention over a sliding window of past states to output Q-values. Unlike `NetworkLSTM`, the agent attends to specific past moments rather than compressing them into a single hidden vector.
276
+
277
+ ```ts
278
+ import { NetworkTransformerRL } from "@dniskav/neuron";
279
+
280
+ // Agent sees the last 8 steps; each step is a 7-value sensor vector → 4 actions
281
+ const net = new NetworkTransformerRL(8, 7, {
282
+ d_model: 32,
283
+ nHeads: 2,
284
+ d_ff: 64,
285
+ nBlocks: 2,
286
+ nActions: 4,
287
+ });
288
+
289
+ // Each step: feed the last N states as a sequence
290
+ const sequence = getLastNStates(); // number[][] — shape: [8, 7]
291
+ const qValues = net.predict(sequence); // number[4]
292
+
293
+ // Q-learning update: train toward the Bellman target (nextSequence is the state window observed after taking the action)
294
+ const action = argmax(qValues);
295
+ const reward = env.step(action);
296
+ const targets = qValues.slice();
297
+ targets[action] = reward + 0.99 * Math.max(...net.predict(nextSequence));
298
+
299
+ const loss = net.train(sequence, targets, 0.001);
300
+ ```
301
+
302
+ When the sequence is pooled, the last step is given 2× the weight of the other steps, so the most recent state contributes more to the decision.
303
+
304
+ ```ts
305
+ // Inspect what the agent is attending to
306
+ const attnWeights = net.getAttentionWeights();
307
+ // attnWeights[blockIdx][headIdx] → seqLen × seqLen matrix (the declared return type also allows null entries)
308
+ ```
309
+
272
310
  ## Possible improvements
273
311
 
274
312
  1. **Support for batches** in training to improve efficiency and gradient stability.
package/dist/index.d.mts CHANGED
@@ -296,6 +296,34 @@ declare class NetworkTransformerRL {
296
296
  predict(sequence: number[][]): number[];
297
297
  train(sequence: number[][], target: number[], lr: number): number;
298
298
  getAttentionWeights(): (number[][] | null)[][];
299
+ getWeights(): {
300
+ inputProj: number[][];
301
+ blocks: {
302
+ attn: {
303
+ heads: {
304
+ Wq: number[][];
305
+ Wk: number[][];
306
+ Wv: number[][];
307
+ }[];
308
+ Wo: number[][];
309
+ };
310
+ norm1: {
311
+ gamma: number[];
312
+ beta: number[];
313
+ };
314
+ norm2: {
315
+ gamma: number[];
316
+ beta: number[];
317
+ };
318
+ ff1: number[][];
319
+ ff2: number[][];
320
+ b1: number[];
321
+ b2: number[];
322
+ }[];
323
+ outputProj: number[][];
324
+ outputBias: number[];
325
+ };
326
+ setWeights(data: ReturnType<NetworkTransformerRL['getWeights']>): void;
299
327
  private _forward;
300
328
  private _pool;
301
329
  }
package/dist/index.d.ts CHANGED
@@ -296,6 +296,34 @@ declare class NetworkTransformerRL {
296
296
  predict(sequence: number[][]): number[];
297
297
  train(sequence: number[][], target: number[], lr: number): number;
298
298
  getAttentionWeights(): (number[][] | null)[][];
299
+ getWeights(): {
300
+ inputProj: number[][];
301
+ blocks: {
302
+ attn: {
303
+ heads: {
304
+ Wq: number[][];
305
+ Wk: number[][];
306
+ Wv: number[][];
307
+ }[];
308
+ Wo: number[][];
309
+ };
310
+ norm1: {
311
+ gamma: number[];
312
+ beta: number[];
313
+ };
314
+ norm2: {
315
+ gamma: number[];
316
+ beta: number[];
317
+ };
318
+ ff1: number[][];
319
+ ff2: number[][];
320
+ b1: number[];
321
+ b2: number[];
322
+ }[];
323
+ outputProj: number[][];
324
+ outputBias: number[];
325
+ };
326
+ setWeights(data: ReturnType<NetworkTransformerRL['getWeights']>): void;
299
327
  private _forward;
300
328
  private _pool;
301
329
  }
package/dist/index.js CHANGED
@@ -1190,6 +1190,54 @@ var NetworkTransformerRL = class {
1190
1190
  getAttentionWeights() {
1191
1191
  return this.blocks.map((b) => b.getAttentionWeights());
1192
1192
  }
1193
+ // ── Serialization ──────────────────────────────────────────────────────────
1194
+ getWeights() {
1195
+ return {
1196
+ inputProj: this.inputProj.W.map((r) => [...r]),
1197
+ blocks: this.blocks.map((b) => ({
1198
+ attn: {
1199
+ heads: b.attn.heads.map((h) => ({
1200
+ Wq: h.Wq.W.map((r) => [...r]),
1201
+ Wk: h.Wk.W.map((r) => [...r]),
1202
+ Wv: h.Wv.W.map((r) => [...r])
1203
+ })),
1204
+ Wo: b.attn.Wo.W.map((r) => [...r])
1205
+ },
1206
+ norm1: { gamma: [...b.norm1.gamma], beta: [...b.norm1.beta] },
1207
+ norm2: { gamma: [...b.norm2.gamma], beta: [...b.norm2.beta] },
1208
+ ff1: b.ff1.W.map((r) => [...r]),
1209
+ ff2: b.ff2.W.map((r) => [...r]),
1210
+ b1: [...b.b1],
1211
+ b2: [...b.b2]
1212
+ })),
1213
+ outputProj: this.outputProj.W.map((r) => [...r]),
1214
+ outputBias: [...this.outputBias]
1215
+ };
1216
+ }
1217
+ setWeights(data) {
1218
+ data.inputProj.forEach((row, i) => {
1219
+ this.inputProj.W[i] = [...row];
1220
+ });
1221
+ data.blocks.forEach((bd, b) => {
1222
+ const blk = this.blocks[b];
1223
+ bd.attn.heads.forEach((hd, h) => {
1224
+ blk.attn.heads[h].Wq.W = hd.Wq.map((r) => [...r]);
1225
+ blk.attn.heads[h].Wk.W = hd.Wk.map((r) => [...r]);
1226
+ blk.attn.heads[h].Wv.W = hd.Wv.map((r) => [...r]);
1227
+ });
1228
+ blk.attn.Wo.W = bd.attn.Wo.map((r) => [...r]);
1229
+ blk.norm1.gamma = [...bd.norm1.gamma];
1230
+ blk.norm1.beta = [...bd.norm1.beta];
1231
+ blk.norm2.gamma = [...bd.norm2.gamma];
1232
+ blk.norm2.beta = [...bd.norm2.beta];
1233
+ blk.ff1.W = bd.ff1.map((r) => [...r]);
1234
+ blk.ff2.W = bd.ff2.map((r) => [...r]);
1235
+ blk.b1 = [...bd.b1];
1236
+ blk.b2 = [...bd.b2];
1237
+ });
1238
+ this.outputProj.W = data.outputProj.map((r) => [...r]);
1239
+ this.outputBias = [...data.outputBias];
1240
+ }
1193
1241
  // ── Internal ────────────────────────────────────────────────────────────────
1194
1242
  _forward(sequence) {
1195
1243
  let h = sequence.map(
package/dist/index.mjs CHANGED
@@ -1130,6 +1130,54 @@ var NetworkTransformerRL = class {
1130
1130
  getAttentionWeights() {
1131
1131
  return this.blocks.map((b) => b.getAttentionWeights());
1132
1132
  }
1133
+ // ── Serialization ──────────────────────────────────────────────────────────
1134
+ getWeights() {
1135
+ return {
1136
+ inputProj: this.inputProj.W.map((r) => [...r]),
1137
+ blocks: this.blocks.map((b) => ({
1138
+ attn: {
1139
+ heads: b.attn.heads.map((h) => ({
1140
+ Wq: h.Wq.W.map((r) => [...r]),
1141
+ Wk: h.Wk.W.map((r) => [...r]),
1142
+ Wv: h.Wv.W.map((r) => [...r])
1143
+ })),
1144
+ Wo: b.attn.Wo.W.map((r) => [...r])
1145
+ },
1146
+ norm1: { gamma: [...b.norm1.gamma], beta: [...b.norm1.beta] },
1147
+ norm2: { gamma: [...b.norm2.gamma], beta: [...b.norm2.beta] },
1148
+ ff1: b.ff1.W.map((r) => [...r]),
1149
+ ff2: b.ff2.W.map((r) => [...r]),
1150
+ b1: [...b.b1],
1151
+ b2: [...b.b2]
1152
+ })),
1153
+ outputProj: this.outputProj.W.map((r) => [...r]),
1154
+ outputBias: [...this.outputBias]
1155
+ };
1156
+ }
1157
+ setWeights(data) {
1158
+ data.inputProj.forEach((row, i) => {
1159
+ this.inputProj.W[i] = [...row];
1160
+ });
1161
+ data.blocks.forEach((bd, b) => {
1162
+ const blk = this.blocks[b];
1163
+ bd.attn.heads.forEach((hd, h) => {
1164
+ blk.attn.heads[h].Wq.W = hd.Wq.map((r) => [...r]);
1165
+ blk.attn.heads[h].Wk.W = hd.Wk.map((r) => [...r]);
1166
+ blk.attn.heads[h].Wv.W = hd.Wv.map((r) => [...r]);
1167
+ });
1168
+ blk.attn.Wo.W = bd.attn.Wo.map((r) => [...r]);
1169
+ blk.norm1.gamma = [...bd.norm1.gamma];
1170
+ blk.norm1.beta = [...bd.norm1.beta];
1171
+ blk.norm2.gamma = [...bd.norm2.gamma];
1172
+ blk.norm2.beta = [...bd.norm2.beta];
1173
+ blk.ff1.W = bd.ff1.map((r) => [...r]);
1174
+ blk.ff2.W = bd.ff2.map((r) => [...r]);
1175
+ blk.b1 = [...bd.b1];
1176
+ blk.b2 = [...bd.b2];
1177
+ });
1178
+ this.outputProj.W = data.outputProj.map((r) => [...r]);
1179
+ this.outputBias = [...data.outputBias];
1180
+ }
1133
1181
  // ── Internal ────────────────────────────────────────────────────────────────
1134
1182
  _forward(sequence) {
1135
1183
  let h = sequence.map(
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@dniskav/neuron",
3
- "version": "0.2.2",
3
+ "version": "0.2.3",
4
4
  "description": "Minimal neural network from scratch — neuron, layer, network, backpropagation. No dependencies.",
5
5
  "main": "dist/index.js",
6
6
  "module": "dist/index.mjs",