@dniskav/neuron 0.2.2 → 0.2.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -15,6 +15,7 @@ A minimal, dependency-free neural network library built from scratch in TypeScri
15
15
  | `LSTMLayer` | Recurrent layer with persistent hidden and cell state. Learns sequences via BPTT. |
16
16
  | `NetworkLSTM` | Wraps an `LSTMLayer` + dense layers. Maintains memory across steps within an episode. |
17
17
  | `NetworkTransformer` | Full token-classification Transformer: embeddings → N blocks → per-token logits. |
18
+ | `NetworkTransformerRL` | Transformer for RL agents: continuous input projection → causal attention → Q-values. Remembers the last N steps. |
18
19
  | `TransformerBlock` | One Transformer block: multi-head attention + FFN + LayerNorm × 2 with residuals. |
19
20
  | `MultiHeadAttention` | N parallel attention heads concatenated and projected to `d_model`. |
20
21
  | `AttentionHead` | Single scaled dot-product self-attention head (Q / K / V projections + backprop). |
@@ -269,6 +270,43 @@ const weights = net.getAttentionWeights();
269
270
  Each head in each block learns a different type of relationship (row, column,
270
271
  3×3 box). The network figures this out by itself through training.
271
272
 
273
+ ### NetworkTransformerRL — Transformer for reinforcement learning
274
+
275
+ `NetworkTransformerRL` uses causal self-attention over a sliding window of past states to output Q-values. Unlike `NetworkLSTM`, the agent attends to specific past moments rather than compressing them into a single hidden vector.
276
+
277
+ ```ts
278
+ import { NetworkTransformerRL } from "@dniskav/neuron";
279
+
280
+ // Agent sees the last 8 steps, each step is a 7-value sensor vector → 4 actions
281
+ const net = new NetworkTransformerRL(8, 7, {
282
+ d_model: 32,
283
+ nHeads: 2,
284
+ d_ff: 64,
285
+ nBlocks: 2,
286
+ nActions: 4,
287
+ });
288
+
289
+ // Each step: feed the last N states as a sequence
290
+ const sequence = getLastNStates(); // number[][] — shape: [8, 7]
291
+ const qValues = net.predict(sequence); // number[4]
292
+
293
+ // Q-learning update: train toward Bellman target
294
+ const action = argmax(qValues);
295
+ const reward = env.step(action);
296
+ const targets = qValues.slice();
297
+ targets[action] = reward + 0.99 * Math.max(...net.predict(nextSequence));
298
+
299
+ const loss = net.train(sequence, targets, 0.001);
300
+ ```
301
+
302
+ The last step in the sequence gets 2× pooling weight — the most recent state contributes more to the decision.
303
+
304
+ ```ts
305
+ // Inspect what the agent is attending to
306
+ const attnWeights = net.getAttentionWeights();
307
+ // attnWeights[blockIdx][headIdx] → seqLen × seqLen matrix
308
+ ```
309
+
272
310
  ## Possible improvements
273
311
 
274
312
  1. **Support for batches** in training to improve efficiency and gradient stability.
package/dist/index.d.mts CHANGED
@@ -34,6 +34,13 @@ declare class Momentum implements Optimizer {
34
34
  constructor(beta?: number);
35
35
  step(weight: number, gradient: number, lr: number): number;
36
36
  }
37
+ declare class ClipOptimizer implements Optimizer {
38
+ readonly inner: Optimizer;
39
+ readonly clipValue: number;
40
+ constructor(inner: Optimizer, clipValue: number);
41
+ step(weight: number, gradient: number, lr: number): number;
42
+ }
43
+ declare function ClippedOptimizerFactory(innerFactory: OptimizerFactory, clipValue: number): OptimizerFactory;
37
44
  declare class Adam implements Optimizer {
38
45
  readonly beta1: number;
39
46
  readonly beta2: number;
@@ -68,22 +75,31 @@ declare class Network {
68
75
  constructor(nInputs: number, nHidden: number, nOutputs: number);
69
76
  predict(inputs: number[]): number;
70
77
  train(inputs: number[], target: number, lr: number): number;
78
+ getWeights(): number[];
79
+ setWeights(weights: number[]): void;
71
80
  }
72
81
 
73
82
  interface NetworkNOptions {
74
83
  activations?: Activation[];
75
84
  optimizer?: OptimizerFactory;
85
+ residual?: boolean | ((layerIndex: number) => boolean);
86
+ dropoutRate?: number;
76
87
  }
77
88
  declare class NetworkN {
78
89
  readonly structure: number[];
79
90
  layers: Layer[];
91
+ private _dropouts;
92
+ private _residual;
80
93
  constructor(structure: number[], options?: NetworkNOptions);
81
- predict(inputs: number[]): number[];
94
+ predict(inputs: number[], training?: boolean): number[];
82
95
  train(inputs: number[], targets: number[], lr: number): number;
83
96
  trainWithDeltas(inputs: number[], outputDeltas: number[], lr: number): void;
97
+ getWeights(): number[];
98
+ setWeights(weights: number[]): void;
99
+ private _shouldResidual;
84
100
  }
85
101
 
86
- declare class Gate {
102
+ declare class Gate$1 {
87
103
  W: number[][];
88
104
  b: number[];
89
105
  constructor(inputSize: number, hSize: number, initBias?: number);
@@ -94,12 +110,13 @@ declare class LSTMLayer {
94
110
  readonly hSize: number;
95
111
  h: number[];
96
112
  c: number[];
97
- forgetGate: Gate;
98
- inputGate: Gate;
99
- cellGate: Gate;
100
- outputGate: Gate;
113
+ forgetGate: Gate$1;
114
+ inputGate: Gate$1;
115
+ cellGate: Gate$1;
116
+ outputGate: Gate$1;
117
+ private _optimizers;
101
118
  private _traj;
102
- constructor(inputSize: number, hiddenSize: number);
119
+ constructor(inputSize: number, hiddenSize: number, optimizerFactory?: OptimizerFactory);
103
120
  reset(): void;
104
121
  predict(inputs: number[]): number[];
105
122
  backprop(dh_seq: number[][], lr: number): void;
@@ -122,6 +139,8 @@ declare class LSTMLayer {
122
139
  };
123
140
  };
124
141
  setWeights(data: ReturnType<LSTMLayer["getWeights"]>): void;
142
+ getWeightsFlat(): number[];
143
+ setWeightsFlat(weights: number[]): void;
125
144
  }
126
145
 
127
146
  interface NetworkLSTMOptions {
@@ -163,6 +182,8 @@ declare class NetworkLSTM {
163
182
  }[][];
164
183
  };
165
184
  setWeights(data: ReturnType<NetworkLSTM["getWeights"]>): void;
185
+ getWeightsFlat(): number[];
186
+ setWeightsFlat(weights: number[]): void;
166
187
  }
167
188
 
168
189
  declare function matMul(A: number[][], B: number[][]): number[][];
@@ -174,38 +195,48 @@ declare class WeightMatrix {
174
195
  private opts;
175
196
  constructor(rows: number, cols: number);
176
197
  update(dW: number[][], lr: number, clipValue?: number): void;
198
+ getWeights(): number[];
199
+ setWeights(weights: number[]): void;
177
200
  }
178
201
  declare class EmbeddingMatrix {
179
202
  W: number[][];
180
203
  constructor(vocabSize: number, d_model: number);
181
204
  get(idx: number): number[];
182
205
  update(idx: number, grad: number[], lr: number): void;
206
+ getWeights(): number[];
207
+ setWeights(weights: number[]): void;
183
208
  }
184
209
 
185
210
  declare class AttentionHead {
186
211
  readonly d_k: number;
187
212
  readonly d_v: number;
213
+ readonly causal: boolean;
188
214
  Wq: WeightMatrix;
189
215
  Wk: WeightMatrix;
190
216
  Wv: WeightMatrix;
191
217
  private cache;
192
- constructor(d_model: number, d_k: number, d_v: number);
218
+ constructor(d_model: number, d_k: number, d_v: number, causal?: boolean);
193
219
  predict(X: number[][]): number[][];
194
220
  backward(dOut: number[][], lr: number): number[][];
195
221
  getAttentionWeights(): number[][] | null;
222
+ getWeights(): number[];
223
+ setWeights(weights: number[]): void;
196
224
  }
197
225
 
198
226
  declare class MultiHeadAttention {
199
227
  readonly nHeads: number;
200
228
  readonly d_model: number;
201
229
  readonly d_k: number;
230
+ readonly causal: boolean;
202
231
  heads: AttentionHead[];
203
232
  Wo: WeightMatrix;
204
233
  private _concat;
205
- constructor(d_model: number, nHeads: number);
234
+ constructor(d_model: number, nHeads: number, causal?: boolean);
206
235
  predict(X: number[][]): number[][];
207
236
  backward(dOut: number[][], lr: number): number[][];
208
237
  getAttentionWeights(): (number[][] | null)[];
238
+ getWeights(): number[];
239
+ setWeights(weights: number[]): void;
209
240
  }
210
241
 
211
242
  declare class LayerNorm {
@@ -217,12 +248,15 @@ declare class LayerNorm {
217
248
  resetCache(seqLen: number): void;
218
249
  predictOne(x: number[], pos: number): number[];
219
250
  backwardOne(dOut: number[], pos: number, lr: number): number[];
251
+ getWeights(): number[];
252
+ setWeights(weights: number[]): void;
220
253
  }
221
254
 
222
255
  interface TransformerBlockOptions {
223
256
  d_model: number;
224
257
  nHeads: number;
225
258
  d_ff: number;
259
+ causal?: boolean;
226
260
  }
227
261
  declare class TransformerBlock {
228
262
  readonly d_model: number;
@@ -242,10 +276,12 @@ declare class TransformerBlock {
242
276
  private _ff1Pre;
243
277
  private _ff1Out;
244
278
  private _ff2Out;
245
- constructor({ d_model, nHeads, d_ff }: TransformerBlockOptions);
279
+ constructor({ d_model, nHeads, d_ff, causal }: TransformerBlockOptions);
246
280
  predict(X: number[][]): number[][];
247
281
  backward(dOut: number[][], lr: number): number[][];
248
282
  getAttentionWeights(): (number[][] | null)[];
283
+ getWeights(): number[];
284
+ setWeights(weights: number[]): void;
249
285
  }
250
286
 
251
287
  interface NetworkTransformerOptions {
@@ -271,6 +307,8 @@ declare class NetworkTransformer {
271
307
  predict(tokens: number[]): number[];
272
308
  train(tokens: number[], targets: number[], lr: number, mask?: boolean[]): number;
273
309
  getAttentionWeights(): (number[][] | null)[][];
310
+ getWeights(): number[];
311
+ setWeights(weights: number[]): void;
274
312
  private _forward;
275
313
  }
276
314
 
@@ -280,6 +318,7 @@ interface NetworkTransformerRLOptions {
280
318
  d_ff?: number;
281
319
  nBlocks?: number;
282
320
  nActions?: number;
321
+ pooling?: 'avg' | 'max' | 'last' | 'weighted';
283
322
  }
284
323
  declare class NetworkTransformerRL {
285
324
  readonly seqLen: number;
@@ -292,12 +331,53 @@ declare class NetworkTransformerRL {
292
331
  outputBias: number[];
293
332
  private outBiasOpts;
294
333
  private _projected;
334
+ private _pooling;
335
+ private _argmax;
295
336
  constructor(seqLen: number, inputDim: number, options?: NetworkTransformerRLOptions);
296
337
  predict(sequence: number[][]): number[];
297
338
  train(sequence: number[][], target: number[], lr: number): number;
298
339
  getAttentionWeights(): (number[][] | null)[][];
340
+ getWeightsFlat(): number[];
341
+ setWeightsFlat(weights: number[]): void;
342
+ getWeightsStructured(): {
343
+ inputProj: number[][];
344
+ blocks: {
345
+ attn: {
346
+ heads: {
347
+ Wq: number[][];
348
+ Wk: number[][];
349
+ Wv: number[][];
350
+ }[];
351
+ Wo: number[][];
352
+ };
353
+ norm1: {
354
+ gamma: number[];
355
+ beta: number[];
356
+ };
357
+ norm2: {
358
+ gamma: number[];
359
+ beta: number[];
360
+ };
361
+ ff1: number[][];
362
+ ff2: number[][];
363
+ b1: number[];
364
+ b2: number[];
365
+ }[];
366
+ outputProj: number[][];
367
+ outputBias: number[];
368
+ };
369
+ setWeightsStructured(data: ReturnType<NetworkTransformerRL['getWeightsStructured']>): void;
370
+ getWeights(): number[];
371
+ setWeights(weights: number[]): void;
299
372
  private _forward;
300
373
  private _pool;
374
+ private _poolAvg;
375
+ private _poolMax;
376
+ private _poolLast;
377
+ private _poolWeighted;
378
+ /** Returns the current pooling type for inspection. */
379
+ getPoolingType(): string;
380
+ private _distributePoolGradient;
301
381
  }
302
382
 
303
383
  declare function mse(predicted: number[], actual: number[]): number;
@@ -306,4 +386,204 @@ declare function mseDelta(predicted: number, actual: number): number;
306
386
  declare function crossEntropyDelta(predicted: number, actual: number): number;
307
387
  declare function crossEntropyDeltaRaw(predicted: number, actual: number): number;
308
388
 
309
- export { type Activation, Adam, AttentionHead, EmbeddingMatrix, LSTMLayer, Layer, LayerNorm, Momentum, MultiHeadAttention, Network, NetworkLSTM, type NetworkLSTMOptions, NetworkN, type NetworkNOptions, NetworkTransformer, type NetworkTransformerOptions, NetworkTransformerRL, type NetworkTransformerRLOptions, Neuron, NeuronN, type Optimizer, type OptimizerFactory, SGD, TransformerBlock, type TransformerBlockOptions, WeightMatrix, crossEntropy, crossEntropyDelta, crossEntropyDeltaRaw, elu, leakyRelu, linear, makeElu, makeLeakyRelu, matMul, mse, mseDelta, relu, sigmoid, softmax, softmaxBackward, tanh, transpose };
389
+ declare class Dropout {
390
+ readonly rate: number;
391
+ private _mask;
392
+ constructor(rate: number);
393
+ forward(x: number[], training?: boolean): number[];
394
+ backward(dOut: number[]): number[];
395
+ resetMask(): void;
396
+ getWeights(): number[];
397
+ setWeights(_weights: number[]): void;
398
+ }
399
+
400
+ declare class Gate {
401
+ W: number[][];
402
+ b: number[];
403
+ constructor(inputSize: number, hSize: number, initBias?: number);
404
+ linear(combined: number[]): number[];
405
+ }
406
+ declare class GRULayer {
407
+ readonly inputSize: number;
408
+ readonly hSize: number;
409
+ h: number[];
410
+ resetGate: Gate;
411
+ updateGate: Gate;
412
+ newGate: Gate;
413
+ private _optimizers;
414
+ private _traj;
415
+ constructor(inputSize: number, hiddenSize: number, optimizerFactory?: OptimizerFactory);
416
+ reset(): void;
417
+ predict(inputs: number[]): number[];
418
+ backprop(dh_seq: number[][], lr: number): void;
419
+ getWeightsFlat(): number[];
420
+ setWeightsFlat(weights: number[]): void;
421
+ getWeights(): {
422
+ resetGate: {
423
+ W: number[][];
424
+ b: number[];
425
+ };
426
+ updateGate: {
427
+ W: number[][];
428
+ b: number[];
429
+ };
430
+ newGate: {
431
+ W: number[][];
432
+ b: number[];
433
+ };
434
+ };
435
+ setWeights(data: ReturnType<GRULayer["getWeights"]>): void;
436
+ }
437
+
438
+ declare class BatchNorm {
439
+ readonly dim: number;
440
+ readonly momentum: number;
441
+ gamma: number[];
442
+ beta: number[];
443
+ runningMean: number[];
444
+ runningVar: number[];
445
+ private _xNorm;
446
+ private _std;
447
+ constructor(dim: number, momentum?: number);
448
+ forward(x: number[]): number[];
449
+ backward(dOut: number[]): number[];
450
+ trainParams(dOut: number[], lr: number): void;
451
+ getWeights(): number[];
452
+ setWeights(weights: number[]): void;
453
+ }
454
+
455
+ declare class Conv1D {
456
+ readonly inputLength: number;
457
+ readonly kernelSize: number;
458
+ readonly filters: number;
459
+ readonly stride: number;
460
+ readonly padding: 'valid' | 'same';
461
+ readonly inputChannels: number;
462
+ kernels: number[][][];
463
+ biases: number[];
464
+ private _kOpts;
465
+ private _bOpts;
466
+ private _input;
467
+ private _paddedInput;
468
+ constructor(inputLength: number, kernelSize: number, filters: number, stride?: number, padding?: 'valid' | 'same', optimizerFactory?: OptimizerFactory, inputChannels?: number);
469
+ forward(input: number[] | number[][]): number[][];
470
+ backward(dOut: number[][], lr?: number): number[][];
471
+ getOutputLength(): number;
472
+ getWeights(): number[];
473
+ setWeights(weights: number[]): void;
474
+ private _normalizeInput;
475
+ }
476
+
477
+ interface DataPair {
478
+ inputs: number[][];
479
+ targets: number[][];
480
+ }
481
+ declare class DataLoader {
482
+ readonly data: DataPair;
483
+ readonly batchSize: number;
484
+ private _indices;
485
+ private _trainIndices;
486
+ private _valIndices;
487
+ private _pos;
488
+ private _validationSplit;
489
+ constructor(data: DataPair, batchSize?: number, validationSplit?: number);
490
+ shuffle(): void;
491
+ hasNext(): boolean;
492
+ next(): DataPair;
493
+ reset(): void;
494
+ get length(): number;
495
+ getValidationData(): DataPair;
496
+ get validationLength(): number;
497
+ static sequences(data: number[][], seqLen: number, validationSplit?: number): DataLoader;
498
+ }
499
+
500
+ interface TrainMetrics {
501
+ accuracy: number;
502
+ precision: number;
503
+ recall: number;
504
+ f1: number;
505
+ }
506
+ interface TrainerOptions {
507
+ epochs?: number;
508
+ lr?: number;
509
+ lrDecay?: number;
510
+ verbose?: boolean;
511
+ weightDecay?: number;
512
+ earlyStopping?: {
513
+ patience: number;
514
+ minDelta: number;
515
+ };
516
+ computeMetrics?: boolean;
517
+ clipValue?: number;
518
+ }
519
+ interface TrainDataset {
520
+ inputs: number[][];
521
+ targets: number[][];
522
+ }
523
+ interface TrainableNetwork {
524
+ train(inputs: number[], targets: number[], lr: number): number;
525
+ }
526
+ /** Extended interface for networks that support weight access and prediction.
527
+ * Required for weightDecay, earlyStopping, and computeMetrics features. */
528
+ interface TrainableNetworkWithWeights extends TrainableNetwork {
529
+ predict(inputs: number[]): number[];
530
+ getWeights(): number[];
531
+ setWeights(weights: number[]): void;
532
+ }
533
+ declare class Trainer {
534
+ readonly network: TrainableNetwork;
535
+ readonly epochs: number;
536
+ readonly lrInitial: number;
537
+ readonly lrDecay: number;
538
+ readonly verbose: boolean;
539
+ readonly weightDecay: number;
540
+ readonly clipValue: number;
541
+ private _history;
542
+ private _validationData?;
543
+ private _earlyStopping?;
544
+ private _bestLoss;
545
+ private _patienceCounter;
546
+ private _stopReason;
547
+ private _computeMetrics;
548
+ private _metrics;
549
+ constructor(network: TrainableNetwork, options?: TrainerOptions);
550
+ setValidationData(dataset: DataPair): void;
551
+ getBestLoss(): number;
552
+ getStopReason(): string;
553
+ getMetrics(): TrainMetrics[];
554
+ train(dataset: TrainDataset): number[];
555
+ getHistory(): number[];
556
+ /** Type guard: does this network support getWeights/setWeights/predict? */
557
+ private _hasWeights;
558
+ /** Mean squared error on a dataset (used for validation loss). */
559
+ private _computeLoss;
560
+ /** Heuristic: are targets classification-style (one-hot or single-class)? */
561
+ private _isClassification;
562
+ /** Compute classification metrics from predictions vs targets. */
563
+ private _computeMetricsArray;
564
+ }
565
+
566
+ declare class LRScheduler {
567
+ stepDecay(lr: number, epoch: number, dropRate: number, epochsDrop: number): number;
568
+ exponentialDecay(lr: number, epoch: number, decayRate: number): number;
569
+ plateauDecay(lr: number, currentLoss: number, history: number[], patience: number, factor: number): number;
570
+ cosineAnnealing(lr: number, epoch: number, maxEpochs: number, minLr?: number): number;
571
+ }
572
+
573
+ interface Serializable {
574
+ getWeights(): number[];
575
+ setWeights(weights: number[]): void;
576
+ }
577
+ declare class ModelSaver {
578
+ static toJSON(model: Serializable): string;
579
+ static fromJSON(model: Serializable, json: string): void;
580
+ static saveToFile(model: Serializable, path: string, writeFn: (path: string, data: string) => void): void;
581
+ static loadFromFile(model: Serializable, path: string, readFn: (path: string) => string): void;
582
+ }
583
+
584
+ declare function validateArray(arr: unknown, expectedLength: number, methodName: string): asserts arr is number[];
585
+ declare function validateArrayMinLength(arr: unknown, minLength: number, methodName: string): asserts arr is number[];
586
+ declare function validate2DArray(arr: unknown, expectedRows: number, expectedCols: number, methodName: string): asserts arr is number[][];
587
+ declare function validateNumber(value: unknown, methodName: string): asserts value is number;
588
+
589
+ export { type Activation, Adam, AttentionHead, BatchNorm, ClipOptimizer, ClippedOptimizerFactory, Conv1D, DataLoader, type DataPair, Dropout, EmbeddingMatrix, GRULayer, LRScheduler, LSTMLayer, Layer, LayerNorm, ModelSaver, Momentum, MultiHeadAttention, Network, NetworkLSTM, type NetworkLSTMOptions, NetworkN, type NetworkNOptions, NetworkTransformer, type NetworkTransformerOptions, NetworkTransformerRL, type NetworkTransformerRLOptions, Neuron, NeuronN, type Optimizer, type OptimizerFactory, SGD, type Serializable, type TrainDataset, type TrainMetrics, type TrainableNetwork, type TrainableNetworkWithWeights, Trainer, type TrainerOptions, TransformerBlock, type TransformerBlockOptions, WeightMatrix, crossEntropy, crossEntropyDelta, crossEntropyDeltaRaw, elu, leakyRelu, linear, makeElu, makeLeakyRelu, matMul, mse, mseDelta, relu, sigmoid, softmax, softmaxBackward, tanh, transpose, validate2DArray, validateArray, validateArrayMinLength, validateNumber };