@dniskav/neuron 0.1.6 → 0.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -21,15 +21,22 @@ var __toCommonJS = (mod) => __copyProps(__defProp({}, "__esModule", { value: tru
21
21
  var index_exports = {};
22
22
  __export(index_exports, {
23
23
  Adam: () => Adam,
24
+ AttentionHead: () => AttentionHead,
25
+ EmbeddingMatrix: () => EmbeddingMatrix,
24
26
  LSTMLayer: () => LSTMLayer,
25
27
  Layer: () => Layer,
28
+ LayerNorm: () => LayerNorm,
26
29
  Momentum: () => Momentum,
30
+ MultiHeadAttention: () => MultiHeadAttention,
27
31
  Network: () => Network,
28
32
  NetworkLSTM: () => NetworkLSTM,
29
33
  NetworkN: () => NetworkN,
34
+ NetworkTransformer: () => NetworkTransformer,
30
35
  Neuron: () => Neuron,
31
36
  NeuronN: () => NeuronN,
32
37
  SGD: () => SGD,
38
+ TransformerBlock: () => TransformerBlock,
39
+ WeightMatrix: () => WeightMatrix,
33
40
  crossEntropy: () => crossEntropy,
34
41
  crossEntropyDelta: () => crossEntropyDelta,
35
42
  crossEntropyDeltaRaw: () => crossEntropyDeltaRaw,
@@ -38,11 +45,15 @@ __export(index_exports, {
38
45
  linear: () => linear,
39
46
  makeElu: () => makeElu,
40
47
  makeLeakyRelu: () => makeLeakyRelu,
48
+ matMul: () => matMul,
41
49
  mse: () => mse,
42
50
  mseDelta: () => mseDelta,
43
51
  relu: () => relu,
44
52
  sigmoid: () => sigmoid2,
45
- tanh: () => tanh
53
+ softmax: () => softmax,
54
+ softmaxBackward: () => softmaxBackward,
55
+ tanh: () => tanh,
56
+ transpose: () => transpose
46
57
  });
47
58
  module.exports = __toCommonJS(index_exports);
48
59
 
@@ -523,6 +534,570 @@ var NetworkLSTM = class {
523
534
  }
524
535
  };
525
536
 
537
+ // src/MatMul.ts
538
/**
 * Multiplies two dense row-major matrices: C = A · B.
 *
 * @param {number[][]} A - Left operand; every row must have B.length entries.
 * @param {number[][]} B - Right operand; its first row determines the number
 *   of output columns.
 * @returns {number[][]} A newly allocated A.length × cols matrix. Neither
 *   input is mutated.
 * @throws {RangeError} If a row of A does not have exactly B.length entries.
 */
function matMul(A, B) {
  const rows = A.length;
  const inner = B.length;
  // An empty B has no first row to measure, so treat it as an inner × 0 matrix
  // instead of dereferencing B[0].
  const cols = inner > 0 ? B[0].length : 0;
  const C = Array.from({ length: rows }, () => new Array(cols).fill(0));
  for (let i = 0; i < rows; i++) {
    const aRow = A[i];
    // A width mismatch would otherwise read undefined entries and silently
    // fill the result with NaN (or drop trailing columns of A).
    if (aRow.length !== inner) {
      throw new RangeError(
        `matMul: row ${i} of A has ${aRow.length} columns but B has ${inner} rows`
      );
    }
    const cRow = C[i];
    // i-k-j loop order: the innermost loop walks one row of B and one row of C.
    for (let k = 0; k < inner; k++) {
      const aik = aRow[k];
      const bRow = B[k];
      for (let j = 0; j < cols; j++) {
        cRow[j] += aik * bRow[j];
      }
    }
  }
  return C;
}
551
/**
 * Returns the transpose of a dense row-major matrix.
 *
 * The column count is taken from the first row. An empty matrix transposes to
 * an empty matrix.
 *
 * @param {number[][]} A - Matrix to transpose (rows × cols).
 * @returns {number[][]} A newly allocated cols × rows matrix with
 *   T[j][i] === A[i][j]. The input is not mutated.
 */
function transpose(A) {
  const rows = A.length;
  // With no rows there is no A[0] to measure; the result has no rows either.
  if (rows === 0) {
    return [];
  }
  const cols = A[0].length;
  return Array.from(
    { length: cols },
    (_, j) => Array.from({ length: rows }, (_2, i) => A[i][j])
  );
}
559
/**
 * Numerically stable softmax over one vector.
 *
 * The maximum entry is subtracted before exponentiating so large logits do
 * not overflow Math.exp.
 *
 * @param {number[]} row - Input logits.
 * @returns {number[]} A new array of the same length whose entries are
 *   exp(row[i] - max) / sum. An empty input yields an empty array.
 */
function softmax(row) {
  // Fold instead of Math.max(...row): spreading a very long array passes one
  // argument per element and can exceed the engine's argument-count limit.
  // Math.max inside the fold keeps the NaN-propagating behaviour of Math.max.
  const max = row.reduce((m, v) => Math.max(m, v), -Infinity);
  const exps = row.map((v) => Math.exp(v - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
}
565
/**
 * Backpropagates through a softmax: given the upstream gradient with respect
 * to the softmax output and the softmax output itself, returns the gradient
 * with respect to the softmax input.
 *
 *   dx_i = s_i * (dS_i - sum_j dS_j * s_j)
 *
 * @param {number[]} dS - Gradient with respect to the softmax output.
 * @param {number[]} s - Softmax output from the forward pass.
 * @returns {number[]} A new array with the gradient with respect to the
 *   softmax input, same length as s.
 * @throws {RangeError} If dS and s differ in length.
 */
function softmaxBackward(dS, s) {
  // A length mismatch would read undefined entries (NaN) or silently ignore
  // the tail of dS, so reject it up front.
  if (dS.length !== s.length) {
    throw new RangeError(
      `softmaxBackward: dS has length ${dS.length} but s has length ${s.length}`
    );
  }
  let dot = 0;
  for (let i = 0; i < s.length; i++) {
    dot += dS[i] * s[i];
  }
  return s.map((si, i) => si * (dS[i] - dot));
}
569
/**
 * Dense rows × cols weight matrix in which every element owns its own Adam
 * optimizer instance.
 *
 * Fields:
 *   W    - number[rows][cols], the weights.
 *   opts - Adam[rows][cols], one optimizer per weight.
 */
var WeightMatrix = class {
  /**
   * Initializes every weight uniformly in [-limit, limit) with
   * limit = sqrt(2 / (rows + cols)).
   *
   * @param {number} rows - Number of rows.
   * @param {number} cols - Number of columns.
   */
  constructor(rows, cols) {
    const limit = Math.sqrt(2 / (rows + cols));
    this.W = Array.from(
      { length: rows },
      () => Array.from({ length: cols }, () => (Math.random() * 2 - 1) * limit)
    );
    this.opts = Array.from(
      { length: rows },
      () => Array.from({ length: cols }, () => new Adam())
    );
  }
  /**
   * Applies a pre-computed gradient (same shape as W) through the per-element
   * optimizers.
   *
   * @param {number[][]} dW - Gradient, same shape as W.
   * @param {number} lr - Learning rate forwarded to each optimizer step.
   * @param {number} [clipValue=Infinity] - Optional per-element clipping
   *   bound applied before the optimizer step; pass e.g. 1.0 to clip to
   *   [-1, 1]. Only the magnitude is used, so a negative bound behaves like
   *   its absolute value. A non-finite bound disables clipping.
   * @returns {void}
   */
  update(dW, lr, clipValue = Infinity) {
    // Use the magnitude so a negative bound cannot invert the clamp range.
    const bound = Math.abs(clipValue);
    const shouldClip = Number.isFinite(bound);
    for (let i = 0; i < this.W.length; i++) {
      const weights = this.W[i];
      const optimizers = this.opts[i];
      const grads = dW[i];
      for (let j = 0; j < weights.length; j++) {
        let g = grads[j];
        if (shouldClip) {
          g = Math.max(-bound, Math.min(bound, g));
        }
        // Skip elements whose gradient is still NaN or infinite: stepping with
        // such a value would hand the optimizer a non-finite gradient and can
        // leave the weight itself non-finite.
        if (!Number.isFinite(g)) {
          continue;
        }
        weights[j] = optimizers[j].step(weights[j], g, lr);
      }
    }
  }
};
593
/**
 * Lookup table of embedding vectors: one row of length d_model per index.
 *
 * Fields:
 *   W - number[vocabSize][d_model], the embedding rows.
 */
var EmbeddingMatrix = class {
  /**
   * Initializes every entry uniformly in [-limit, limit) with
   * limit = sqrt(1 / d_model).
   *
   * @param {number} vocabSize - Number of rows (distinct indices).
   * @param {number} d_model - Length of each embedding vector.
   */
  constructor(vocabSize, d_model) {
    const limit = Math.sqrt(1 / d_model);
    this.W = Array.from(
      { length: vocabSize },
      () => Array.from({ length: d_model }, () => (Math.random() * 2 - 1) * limit)
    );
  }
  /**
   * Returns a copy of the embedding row for idx, so callers can mutate the
   * result without touching the table.
   *
   * @param {number} idx - Row index.
   * @returns {number[]} Copy of row idx.
   * @throws {RangeError} If there is no row at idx.
   */
  get(idx) {
    return [...this._row(idx)];
  }
  /**
   * Adds lr * grad to row idx in place (plain step, no optimizer state).
   *
   * NOTE(review): the step is additive. That is a descent step only if grad
   * carries the negative loss gradient — confirm against the sign convention
   * used by the callers and by the optimizers in this package.
   *
   * @param {number} idx - Row index.
   * @param {number[]} grad - Per-dimension update direction, length d_model.
   * @param {number} lr - Step size.
   * @returns {void}
   * @throws {RangeError} If there is no row at idx.
   */
  update(idx, grad, lr) {
    const row = this._row(idx);
    for (let m = 0; m < row.length; m++) {
      row[m] += lr * grad[m];
    }
  }
  // Fetches the live row for idx, failing with a descriptive error instead of
  // an opaque TypeError when the index does not address a row.
  _row(idx) {
    const row = this.W[idx];
    if (row === undefined) {
      throw new RangeError(
        `EmbeddingMatrix: index ${idx} is out of range [0, ${this.W.length})`
      );
    }
    return row;
  }
};
609
+
610
+ // src/AttentionHead.ts
611
/**
 * Single scaled dot-product self-attention head.
 *
 * Every position attends to every position of the same sequence; no mask is
 * applied to the scores.
 *
 * Fields:
 *   Wq, Wk - WeightMatrix, d_k × d_model (query / key projections)
 *   Wv     - WeightMatrix, d_v × d_model (value projection)
 *   cache  - tensors of the last predict() call, or null before the first one
 */
var AttentionHead = class {
  /**
   * @param {number} d_model - Width of each input vector.
   * @param {number} d_k - Width of the query/key projections.
   * @param {number} d_v - Width of the value projection (and of the output).
   */
  constructor(d_model, d_k, d_v) {
    this.cache = null;
    this.d_k = d_k;
    this.d_v = d_v;
    this.Wq = new WeightMatrix(d_k, d_model);
    this.Wk = new WeightMatrix(d_k, d_model);
    this.Wv = new WeightMatrix(d_v, d_model);
  }
  // ── Forward ───────────────────────────────────────────────────────────────
  /**
   * Runs attention over one sequence and caches the intermediates that
   * backward() needs.
   *
   * @param {number[][]} X - Input, seqLen × d_model.
   * @returns {number[][]} Output, seqLen × d_v.
   */
  predict(X) {
    const seqLen = X.length;
    const scale = 1 / Math.sqrt(this.d_k);
    // Row-wise projection x ↦ W·x for every position.
    const project = (weights) => X.map(
      (x) => weights.W.map((row) => row.reduce((s, w, m) => s + w * x[m], 0))
    );
    const Q = project(this.Wq);
    const K = project(this.Wk);
    const V = project(this.Wv);
    // scores[i][j] = (Q_i · K_j) / √d_k
    const scores = Array.from(
      { length: seqLen },
      (_, i) => Array.from(
        { length: seqLen },
        (_2, j) => Q[i].reduce((s, q, k) => s + q * K[j][k], 0) * scale
      )
    );
    const attn = scores.map((row) => softmax(row));
    // out_i = Σ_j attn[i][j] · V_j
    const out = Array.from(
      { length: seqLen },
      (_, i) => Array.from(
        { length: this.d_v },
        (_2, d) => attn[i].reduce((s, a, j) => s + a * V[j][d], 0)
      )
    );
    this.cache = { X, Q, K, V, scores, attn };
    return out;
  }
  // ── Backward ──────────────────────────────────────────────────────────────
  /**
   * Backpropagates through the last predict() call and updates Wq, Wk, Wv.
   *
   * Steps:
   *   1. dV      = attn^T @ dOut
   *   2. dAttn   = dOut @ V^T
   *   3. dScores = softmaxBackward(dAttn) / √d_k
   *   4. dQ      = dScores @ K,   dK = dScores^T @ Q
   *   5. dWq     = dQ^T @ X,      dWk = dK^T @ X,   dWv = dV^T @ X
   *   6. dX      = dQ @ Wq + dK @ Wk + dV @ Wv   (with the forward-pass weights)
   *   7. apply dWq / dWk / dWv
   *
   * @param {number[][]} dOut - Gradient w.r.t. the output, seqLen × d_v.
   * @param {number} lr - Learning rate forwarded to the weight updates.
   * @returns {number[][]} Gradient w.r.t. X, seqLen × d_model.
   * @throws {Error} If predict() has not been called yet.
   */
  backward(dOut, lr) {
    if (this.cache === null) {
      throw new Error("AttentionHead.backward() called before predict()");
    }
    const { X, Q, K, V, attn } = this.cache;
    const seqLen = X.length;
    const d_model = X[0].length;
    const scale = 1 / Math.sqrt(this.d_k);
    const dV = Array.from(
      { length: seqLen },
      (_, j) => Array.from(
        { length: this.d_v },
        (_2, d) => attn.reduce((s, a, i) => s + a[j] * dOut[i][d], 0)
      )
    );
    const dAttn = Array.from(
      { length: seqLen },
      (_, i) => Array.from(
        { length: seqLen },
        (_2, j) => dOut[i].reduce((s, d, k) => s + d * V[j][k], 0)
      )
    );
    const dScores = dAttn.map(
      (da, i) => softmaxBackward(da, attn[i]).map((v) => v * scale)
    );
    const dQ = Array.from(
      { length: seqLen },
      (_, i) => Array.from(
        { length: this.d_k },
        (_2, k) => dScores[i].reduce((s, ds, j) => s + ds * K[j][k], 0)
      )
    );
    const dK = Array.from(
      { length: seqLen },
      (_, j) => Array.from(
        { length: this.d_k },
        (_2, k) => dScores.reduce((s, ds, i) => s + ds[j] * Q[i][k], 0)
      )
    );
    // Weight gradient of a projection P = X·W^T: dW[k][m] = Σ_i dP[i][k]·X[i][m].
    const weightGrad = (dProj, width) => Array.from(
      { length: width },
      (_, k) => Array.from(
        { length: d_model },
        (_2, m) => dProj.reduce((s, row, i) => s + row[k] * X[i][m], 0)
      )
    );
    const dWq = weightGrad(dQ, this.d_k);
    const dWk = weightGrad(dK, this.d_k);
    const dWv = weightGrad(dV, this.d_v);
    // The input gradient must be taken with the weights that produced the
    // forward pass, so compute it BEFORE the optimizer steps below mutate
    // Wq / Wk / Wv.
    const dX = Array.from(
      { length: seqLen },
      (_, i) => Array.from(
        { length: d_model },
        (_2, m) => dQ[i].reduce((s, dq, k) => s + dq * this.Wq.W[k][m], 0)
          + dK[i].reduce((s, dk, k) => s + dk * this.Wk.W[k][m], 0)
          + dV[i].reduce((s, dv, k) => s + dv * this.Wv.W[k][m], 0)
      )
    );
    this.Wq.update(dWq, lr);
    this.Wk.update(dWk, lr);
    this.Wv.update(dWv, lr);
    return dX;
  }
  /**
   * Attention weights from the last predict() call — useful for visualization.
   *
   * @returns {number[][] | null} seqLen × seqLen weights, or null if predict()
   *   has not run yet.
   */
  getAttentionWeights() {
    return this.cache ? this.cache.attn : null;
  }
};
737
+
738
+ // src/MultiHeadAttention.ts
739
/**
 * Multi-head self-attention: nHeads independent AttentionHead instances whose
 * outputs are concatenated and mixed by an output projection Wo.
 *
 * Fields:
 *   d_k     - floor(d_model / nHeads), the per-head width
 *   heads   - AttentionHead[nHeads], each d_model → d_k
 *   Wo      - WeightMatrix, d_model × (nHeads * d_k)
 *   _concat - concatenated head outputs of the last predict(),
 *             seqLen × (nHeads * d_k); null before the first call
 */
var MultiHeadAttention = class {
  /**
   * @param {number} d_model - Width of each input/output vector.
   * @param {number} nHeads - Number of heads; a positive integer no larger
   *   than d_model. When d_model is not a multiple of nHeads the concatenated
   *   width nHeads * floor(d_model / nHeads) is smaller than d_model.
   * @throws {RangeError} If nHeads is not an integer in [1, d_model].
   */
  constructor(d_model, nHeads) {
    // nHeads = 0, a fractional nHeads, or nHeads > d_model would give a
    // per-head width of Infinity/0 and NaN attention scores downstream.
    if (!Number.isInteger(nHeads) || nHeads < 1 || nHeads > d_model) {
      throw new RangeError(
        `MultiHeadAttention: nHeads must be an integer in [1, d_model]; got nHeads=${nHeads}, d_model=${d_model}`
      );
    }
    this._concat = null;
    this.nHeads = nHeads;
    this.d_model = d_model;
    this.d_k = Math.floor(d_model / nHeads);
    this.heads = Array.from(
      { length: nHeads },
      () => new AttentionHead(d_model, this.d_k, this.d_k)
    );
    this.Wo = new WeightMatrix(d_model, nHeads * this.d_k);
  }
  // ── Forward ───────────────────────────────────────────────────────────────
  /**
   * @param {number[][]} X - Input, seqLen × d_model.
   * @returns {number[][]} Output, seqLen × d_model.
   */
  predict(X) {
    const seqLen = X.length;
    const headOuts = this.heads.map((h) => h.predict(X));
    // Per position, lay the head outputs side by side.
    const concat = Array.from(
      { length: seqLen },
      (_, i) => headOuts.flatMap((ho) => ho[i])
    );
    const out = concat.map(
      (c) => this.Wo.W.map((row) => row.reduce((s, w, j) => s + w * c[j], 0))
    );
    this._concat = concat;
    return out;
  }
  // ── Backward ──────────────────────────────────────────────────────────────
  /**
   * Backpropagates through the last predict() call, updating Wo and every
   * head.
   *
   * @param {number[][]} dOut - Gradient w.r.t. the output, seqLen × d_model.
   * @param {number} lr - Learning rate forwarded to all updates.
   * @returns {number[][]} Gradient w.r.t. X, seqLen × d_model (sum over heads).
   * @throws {Error} If predict() has not been called yet.
   */
  backward(dOut, lr) {
    const concat = this._concat;
    if (concat === null) {
      throw new Error("MultiHeadAttention.backward() called before predict()");
    }
    const seqLen = dOut.length;
    const concatD = this.nHeads * this.d_k;
    const d_model = this.d_model;
    // dConcat = dOut @ Wo — taken before Wo is stepped below.
    const dConcat = dOut.map(
      (do_) => Array.from(
        { length: concatD },
        (_, j) => this.Wo.W.reduce((s, row, k) => s + do_[k] * row[j], 0)
      )
    );
    // dWo = dOut^T @ concat
    const dWo = Array.from(
      { length: d_model },
      (_, k) => Array.from(
        { length: concatD },
        (_2, j) => dOut.reduce((s, row, i) => s + row[k] * concat[i][j], 0)
      )
    );
    this.Wo.update(dWo, lr);
    const dX = Array.from(
      { length: seqLen },
      () => new Array(d_model).fill(0)
    );
    for (let h = 0; h < this.nHeads; h++) {
      // Each head receives its own d_k-wide slice of the concat gradient.
      const start = h * this.d_k;
      const dHeadOut = dConcat.map((dc) => dc.slice(start, start + this.d_k));
      const dXh = this.heads[h].backward(dHeadOut, lr);
      for (let i = 0; i < seqLen; i++) {
        for (let m = 0; m < d_model; m++) {
          dX[i][m] += dXh[i][m];
        }
      }
    }
    return dX;
  }
  /**
   * Attention weights per head from the last predict() — for visualization.
   *
   * @returns {Array} One entry per head, each whatever that head's
   *   getAttentionWeights() returns (nHeads × seqLen × seqLen once run).
   */
  getAttentionWeights() {
    return this.heads.map((h) => h.getAttentionWeights());
  }
};
810
+
811
+ // src/LayerNorm.ts
812
/**
 * Layer normalization applied to one position's feature vector at a time,
 * with a learnable per-feature scale (gamma) and shift (beta).
 *
 * Fields:
 *   eps    - variance floor added before the square root (1e-5)
 *   gamma  - number[dim], initialized to 1
 *   beta   - number[dim], initialized to 0
 *   _cache - per-position { x_norm, std } from the forward pass
 */
var LayerNorm = class {
  /**
   * @param {number} dim - Length of the feature vectors to normalize.
   */
  constructor(dim) {
    this.eps = 1e-5;
    // Per-position cache populated during the forward pass.
    // resetCache() should be called before each sequence forward pass.
    this._cache = [];
    this.gamma = new Array(dim).fill(1);
    this.beta = new Array(dim).fill(0);
  }
  /**
   * Discards the cached forward state; call once before forward-passing a new
   * sequence.
   *
   * @param {number} seqLen - Number of positions in the upcoming sequence.
   * @returns {void}
   */
  resetCache(seqLen) {
    this._cache = new Array(seqLen);
  }
  /**
   * Normalizes a single position's feature vector to zero mean / unit
   * variance, then applies gamma and beta.
   *
   * @param {number[]} x - Feature vector for this position.
   * @param {number} pos - Position index; must match the index later passed
   *   to backwardOne() for the same vector.
   * @returns {number[]} gamma ⊙ x_norm + beta.
   */
  predictOne(x, pos) {
    const N = x.length;
    const mean = x.reduce((s, v) => s + v, 0) / N;
    const vari = x.reduce((s, v) => s + (v - mean) ** 2, 0) / N;
    const std = Math.sqrt(vari + this.eps);
    const x_norm = x.map((v) => (v - mean) / std);
    this._cache[pos] = { x_norm, std };
    return x_norm.map((xn, i) => this.gamma[i] * xn + this.beta[i]);
  }
  /**
   * Backprop through layer norm for one position.
   *
   * With D = dOut ⊙ γ (γ as used in the forward pass):
   *   dx_i = (1/std) * (D_i − mean(D) − x_norm_i * mean(D ⊙ x_norm))
   *
   * Also steps γ and β with a plain (non-Adam) update:
   *   γ_i += lr * dOut_i * x_norm_i
   *   β_i += lr * dOut_i
   *
   * NOTE(review): the parameter step is additive. That is a descent step only
   * if dOut carries the negative loss gradient — confirm against the sign
   * convention of the callers and of the optimizers in this package.
   *
   * @param {number[]} dOut - Gradient w.r.t. this position's output.
   * @param {number} pos - Position index used in the matching predictOne().
   * @param {number} lr - Step size for the γ/β update.
   * @returns {number[]} Gradient w.r.t. this position's input.
   * @throws {Error} If predictOne() has not been run for pos since the last
   *   resetCache().
   */
  backwardOne(dOut, pos, lr) {
    const cached = this._cache[pos];
    if (cached === undefined) {
      throw new Error(
        `LayerNorm.backwardOne(): no cached predictOne() result for position ${pos}`
      );
    }
    const { x_norm, std } = cached;
    const N = dOut.length;
    // Scale the upstream gradient with the gamma that produced the forward
    // output — this has to happen BEFORE gamma is stepped below, otherwise the
    // input gradient is computed against already-modified parameters.
    const D = dOut.map((d, i) => d * this.gamma[i]);
    for (let i = 0; i < N; i++) {
      this.gamma[i] += lr * dOut[i] * x_norm[i];
      this.beta[i] += lr * dOut[i];
    }
    const mD = D.reduce((s, v) => s + v, 0) / N;
    const mDxn = D.reduce((s, d, i) => s + d * x_norm[i], 0) / N;
    return D.map((d, i) => (d - mD - x_norm[i] * mDxn) / std);
  }
};
861
+
862
+ // src/TransformerBlock.ts
863
/**
 * One post-norm Transformer encoder block:
 *
 *   h1  = LayerNorm1(X  + MultiHeadAttention(X))
 *   out = LayerNorm2(h1 + W2 · relu(W1 · h1 + b1) + b2)
 *
 * Fields:
 *   attn           - MultiHeadAttention(d_model, nHeads)
 *   norm1, norm2   - LayerNorm(d_model)
 *   ff1 / b1       - WeightMatrix d_ff × d_model and bias of length d_ff
 *   ff2 / b2       - WeightMatrix d_model × d_ff and bias of length d_model
 *   b1Opts, b2Opts - one Adam instance per bias entry
 *   _X … _ff2Out   - forward caches of the last predict(); null before it
 */
var TransformerBlock = class {
  /**
   * @param {{ d_model: number, nHeads: number, d_ff: number }} options
   *   d_model: width of every position vector; nHeads: attention heads;
   *   d_ff: hidden width of the feed-forward sub-layer.
   */
  constructor({ d_model, nHeads, d_ff }) {
    // Forward caches (needed for backprop)
    this._X = null;
    this._attnOut = null;
    this._h1 = null;
    this._ff1Pre = null; // pre-ReLU
    this._ff1Out = null; // post-ReLU
    this._ff2Out = null;
    this.d_model = d_model;
    this.d_ff = d_ff;
    this.attn = new MultiHeadAttention(d_model, nHeads);
    this.norm1 = new LayerNorm(d_model);
    this.norm2 = new LayerNorm(d_model);
    this.ff1 = new WeightMatrix(d_ff, d_model);
    this.ff2 = new WeightMatrix(d_model, d_ff);
    this.b1 = new Array(d_ff).fill(0);
    this.b2 = new Array(d_model).fill(0);
    this.b1Opts = Array.from({ length: d_ff }, () => new Adam());
    this.b2Opts = Array.from({ length: d_model }, () => new Adam());
  }
  // ── Forward ───────────────────────────────────────────────────────────────
  /**
   * @param {number[][]} X - Input, seqLen × d_model.
   * @returns {number[][]} Output, seqLen × d_model.
   */
  predict(X) {
    const attnOut = this.attn.predict(X);
    const h1 = this._addAndNorm(this.norm1, X, attnOut);
    const ff1Pre = this._affine(this.ff1, this.b1, h1);
    const ff1Out = ff1Pre.map((pre) => pre.map((v) => Math.max(0, v)));
    const ff2Out = this._affine(this.ff2, this.b2, ff1Out);
    const out = this._addAndNorm(this.norm2, h1, ff2Out);
    this._X = X;
    this._attnOut = attnOut;
    this._h1 = h1;
    this._ff1Pre = ff1Pre;
    this._ff1Out = ff1Out;
    this._ff2Out = ff2Out;
    return out;
  }
  // ── Backward ──────────────────────────────────────────────────────────────
  /**
   * Backpropagates through the last predict() call and updates every
   * parameter of the block. Each input gradient is computed before the
   * corresponding weight matrix is stepped.
   *
   * @param {number[][]} dOut - Gradient w.r.t. the output, seqLen × d_model.
   * @param {number} lr - Learning rate forwarded to all updates.
   * @returns {number[][]} Gradient w.r.t. X, seqLen × d_model.
   * @throws {Error} If predict() has not been called yet.
   */
  backward(dOut, lr) {
    if (this._h1 === null) {
      throw new Error("TransformerBlock.backward() called before predict()");
    }
    const seqLen = dOut.length;
    const d_model = this.d_model;
    const h1 = this._h1;
    const ff1Out = this._ff1Out;
    const ff1Pre = this._ff1Pre;

    // ── LayerNorm2 → gradient of (h1 + ff2Out) ──────────────────────────────
    const dAdded2 = dOut.map((do_, i) => this.norm2.backwardOne(do_, i, lr));

    // ── Second feed-forward layer ───────────────────────────────────────────
    // dFf1Out = dAdded2 @ W2 (uses W2 before its update below)
    const dFf1Out = dAdded2.map(
      (da) => Array.from(
        { length: this.d_ff },
        (_, k) => this.ff2.W.reduce((s, row, m) => s + row[k] * da[m], 0)
      )
    );
    const dW2 = Array.from(
      { length: d_model },
      (_, m) => Array.from(
        { length: this.d_ff },
        (_2, k) => dAdded2.reduce((s, da, i) => s + da[m] * ff1Out[i][k], 0)
      )
    );
    const db2 = Array.from(
      { length: d_model },
      (_, m) => dAdded2.reduce((s, da) => s + da[m], 0)
    );
    this.ff2.update(dW2, lr);
    for (let m = 0; m < d_model; m++) {
      this.b2[m] = this.b2Opts[m].step(this.b2[m], db2[m], lr);
    }

    // ── ReLU: pass gradient only where the pre-activation was positive ──────
    const dFf1Pre = dFf1Out.map(
      (d, i) => d.map((v, k) => (ff1Pre[i][k] > 0 ? v : 0))
    );

    // ── First feed-forward layer ────────────────────────────────────────────
    // dH1_fromFf = dFf1Pre @ W1 (uses W1 before its update below)
    const dH1_fromFf = dFf1Pre.map(
      (dp) => Array.from(
        { length: d_model },
        (_, m) => this.ff1.W.reduce((s, row, k) => s + dp[k] * row[m], 0)
      )
    );
    const dW1 = Array.from(
      { length: this.d_ff },
      (_, k) => Array.from(
        { length: d_model },
        (_2, m) => dFf1Pre.reduce((s, dp, i) => s + dp[k] * h1[i][m], 0)
      )
    );
    const db1 = Array.from(
      { length: this.d_ff },
      (_, k) => dFf1Pre.reduce((s, dp) => s + dp[k], 0)
    );
    this.ff1.update(dW1, lr);
    for (let k = 0; k < this.d_ff; k++) {
      this.b1[k] = this.b1Opts[k].step(this.b1[k], db1[k], lr);
    }

    // ── Residual around the feed-forward sub-layer ──────────────────────────
    const dH1 = Array.from(
      { length: seqLen },
      (_, i) => dH1_fromFf[i].map((v, m) => v + dAdded2[i][m])
    );

    // ── LayerNorm1 → gradient of (X + attnOut) ──────────────────────────────
    const dAdded1 = dH1.map((d, i) => this.norm1.backwardOne(d, i, lr));
    // The sum X + attnOut sends the same gradient down both branches.
    const dAttnOut = dAdded1;
    const dX_skip = dAdded1;
    const dX_fromAttn = this.attn.backward(dAttnOut, lr);
    return Array.from(
      { length: seqLen },
      (_, i) => Array.from(
        { length: d_model },
        (_2, m) => dX_fromAttn[i][m] + dX_skip[i][m]
      )
    );
  }
  /**
   * Attention weights from the last predict() — for visualization.
   *
   * @returns {Array} Whatever the attention sub-layer's getAttentionWeights()
   *   returns.
   */
  getAttentionWeights() {
    return this.attn.getAttentionWeights();
  }
  // ── Internal ──────────────────────────────────────────────────────────────
  // Dense layer applied to every position: out[i][k] = bias[k] + Σ_m W[k][m]·inputs[i][m].
  _affine(weights, bias, inputs) {
    return inputs.map(
      (h) => weights.W.map((row, k) => row.reduce((s, w, m) => s + w * h[m], bias[k]))
    );
  }
  // Residual connection followed by layer norm: norm(base[i] + delta[i]) for
  // every position, after resetting the norm's per-position cache.
  _addAndNorm(norm, base, delta) {
    norm.resetCache(base.length);
    return base.map((row, i) => {
      const added = row.map((v, k) => v + delta[i][k]);
      return norm.predictOne(added, i);
    });
  }
};
990
+
991
+ // src/NetworkTransformer.ts
992
/**
 * Token-level classifier built from learned token + position embeddings, a
 * stack of TransformerBlocks and a per-position linear output projection.
 *
 * Fields:
 *   tokenEmb    - EmbeddingMatrix(vocabSize, d_model)
 *   posEmb      - EmbeddingMatrix(seqLen, d_model)
 *   blocks      - TransformerBlock[nBlocks]
 *   outputProj  - WeightMatrix, nClasses × d_model
 *   outputBias  - number[nClasses], initialized to 0
 *   outBiasOpts - one Adam instance per output bias entry
 */
var NetworkTransformer = class {
  /**
   * @param {number} seqLen - Sequence length the position table is sized for.
   * @param {object} [options]
   * @param {number} [options.vocabSize=10] - Number of distinct token ids.
   * @param {number} [options.d_model=64] - Embedding / block width.
   * @param {number} [options.nHeads=4] - Attention heads per block.
   * @param {number} [options.d_ff=128] - Feed-forward hidden width per block.
   * @param {number} [options.nBlocks=4] - Number of stacked blocks.
   * @param {number} [options.nClasses=9] - Logits produced per position.
   */
  constructor(seqLen, options = {}) {
    const {
      vocabSize = 10,
      d_model = 64,
      nHeads = 4,
      d_ff = 128,
      nBlocks = 4,
      nClasses = 9
    } = options;
    this.seqLen = seqLen;
    this.vocabSize = vocabSize;
    this.d_model = d_model;
    this.nClasses = nClasses;
    this.tokenEmb = new EmbeddingMatrix(vocabSize, d_model);
    this.posEmb = new EmbeddingMatrix(seqLen, d_model);
    this.blocks = Array.from(
      { length: nBlocks },
      () => new TransformerBlock({ d_model, nHeads, d_ff })
    );
    this.outputProj = new WeightMatrix(nClasses, d_model);
    this.outputBias = new Array(nClasses).fill(0);
    this.outBiasOpts = Array.from({ length: nClasses }, () => new Adam());
  }
  // ── Forward pass ──────────────────────────────────────────────────────────
  /**
   * @param {number[]} tokens - Integer token ids, one per position.
   * @returns {number[]} Raw logits flattened position-major:
   *   tokens.length * nClasses values (no softmax applied).
   */
  predict(tokens) {
    return this._logits(this._forward(tokens)).flat();
  }
  // ── Training step (online, one sample at a time) ───────────────────────────
  /**
   * Runs one forward/backward pass on a single sequence and updates all
   * parameters.
   *
   * The loss is softmax cross-entropy per position (−Σ_c t_c · log p_c), and
   * the logit gradient passed backward is probs − targets.
   *
   * NOTE(review): the embedding tables are stepped with
   * EmbeddingMatrix.update, which ADDS lr * gradient, while the matrices go
   * through Adam — confirm both follow the same sign convention.
   *
   * @param {number[]} tokens - seqLen integer token ids.
   * @param {number[]} targets - seqLen * nClasses values, flattened
   *   position-major (e.g. one-hot per position).
   * @param {number} lr - Learning rate.
   * @param {boolean[]} [mask] - Optional; only positions with a truthy
   *   mask[i] contribute to the loss and gradients (e.g. empty cells in
   *   Sudoku).
   * @returns {number} Mean cross-entropy over the contributing positions
   *   (0 when none contribute).
   * @throws {RangeError} If tokens or targets do not have the expected length.
   */
  train(tokens, targets, lr, mask) {
    // A shorter sequence would leave logits rows undefined and a short targets
    // array would inject NaN into every gradient, so fail early and clearly.
    if (tokens.length !== this.seqLen) {
      throw new RangeError(
        `NetworkTransformer.train(): expected ${this.seqLen} tokens, got ${tokens.length}`
      );
    }
    const expectedTargets = this.seqLen * this.nClasses;
    if (targets.length !== expectedTargets) {
      throw new RangeError(
        `NetworkTransformer.train(): expected ${expectedTargets} target values, got ${targets.length}`
      );
    }
    const h = this._forward(tokens);
    const logits = this._logits(h);
    let loss = 0;
    let count = 0;
    const dLogits = Array.from({ length: this.seqLen }, (_, i) => {
      if (mask && !mask[i]) return new Array(this.nClasses).fill(0);
      count++;
      const probs = softmax(logits[i]);
      const offset = i * this.nClasses;
      for (let c = 0; c < this.nClasses; c++) {
        const t = targets[offset + c];
        // Weight by t so the reported loss matches the probs − targets
        // gradient for soft targets too; identical for one-hot targets.
        // The floor keeps log() finite when a probability underflows to 0.
        if (t > 0) loss -= t * Math.log(Math.max(probs[c], 1e-7));
      }
      return probs.map((p, c) => p - targets[offset + c]);
    });
    if (count > 0) loss /= count;
    // Gradient w.r.t. the last block's output — taken before outputProj is stepped.
    const dH = Array.from(
      { length: this.seqLen },
      (_, i) => Array.from(
        { length: this.d_model },
        (_2, m) => dLogits[i].reduce((s, dl, c) => s + dl * this.outputProj.W[c][m], 0)
      )
    );
    const dWout = Array.from(
      { length: this.nClasses },
      (_, c) => Array.from(
        { length: this.d_model },
        (_2, m) => dLogits.reduce((s, dl, i) => s + dl[c] * h[i][m], 0)
      )
    );
    const dBout = Array.from(
      { length: this.nClasses },
      (_, c) => dLogits.reduce((s, dl) => s + dl[c], 0)
    );
    this.outputProj.update(dWout, lr);
    for (let c = 0; c < this.nClasses; c++) {
      this.outputBias[c] = this.outBiasOpts[c].step(this.outputBias[c], dBout[c], lr);
    }
    // Walk the block stack in reverse, threading the input gradient through.
    let dX = dH;
    for (let b = this.blocks.length - 1; b >= 0; b--) {
      dX = this.blocks[b].backward(dX, lr);
    }
    // The block input is tokenEmb + posEmb, so both tables receive dX[i].
    for (let i = 0; i < this.seqLen; i++) {
      this.tokenEmb.update(tokens[i], dX[i], lr);
      this.posEmb.update(i, dX[i], lr);
    }
    return loss;
  }
  /**
   * Attention weights from every block for visualization.
   *
   * @returns {Array} One entry per block, each whatever that block's
   *   getAttentionWeights() returns (nBlocks × nHeads × seqLen × seqLen once
   *   run; nulls before).
   */
  getAttentionWeights() {
    return this.blocks.map((b) => b.getAttentionWeights());
  }
  // ── Internal ──────────────────────────────────────────────────────────────
  // Shared embedding + block forward pass: returns the last block's output,
  // one d_model-wide row per token.
  _forward(tokens) {
    let h = tokens.map((id, i) => {
      const te = this.tokenEmb.get(id);
      const pe = this.posEmb.get(i);
      return te.map((v, m) => v + pe[m]);
    });
    for (const block of this.blocks) {
      h = block.predict(h);
    }
    return h;
  }
  // Output projection shared by predict() and train():
  // logits[i][c] = outputBias[c] + Σ_m outputProj.W[c][m] · h[i][m].
  _logits(h) {
    return h.map(
      (hi) => this.outputProj.W.map(
        (row, c) => row.reduce((s, w, m) => s + w * hi[m], this.outputBias[c])
      )
    );
  }
};
1100
+
526
1101
  // src/losses.ts
527
1102
  function mse(predicted, actual) {
528
1103
  return predicted.reduce((sum, p, i) => sum + (actual[i] - p) ** 2, 0) / predicted.length;
@@ -548,15 +1123,22 @@ function crossEntropyDeltaRaw(predicted, actual) {
548
1123
  // Annotate the CommonJS export names for ESM import in node:
549
1124
  0 && (module.exports = {
550
1125
  Adam,
1126
+ AttentionHead,
1127
+ EmbeddingMatrix,
551
1128
  LSTMLayer,
552
1129
  Layer,
1130
+ LayerNorm,
553
1131
  Momentum,
1132
+ MultiHeadAttention,
554
1133
  Network,
555
1134
  NetworkLSTM,
556
1135
  NetworkN,
1136
+ NetworkTransformer,
557
1137
  Neuron,
558
1138
  NeuronN,
559
1139
  SGD,
1140
+ TransformerBlock,
1141
+ WeightMatrix,
560
1142
  crossEntropy,
561
1143
  crossEntropyDelta,
562
1144
  crossEntropyDeltaRaw,
@@ -565,9 +1147,13 @@ function crossEntropyDeltaRaw(predicted, actual) {
565
1147
  linear,
566
1148
  makeElu,
567
1149
  makeLeakyRelu,
1150
+ matMul,
568
1151
  mse,
569
1152
  mseDelta,
570
1153
  relu,
571
1154
  sigmoid,
572
- tanh
1155
+ softmax,
1156
+ softmaxBackward,
1157
+ tanh,
1158
+ transpose
573
1159
  });