@dniskav/neuron 0.1.5 → 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +38 -0
- package/dist/index.d.mts +116 -1
- package/dist/index.d.ts +116 -1
- package/dist/index.js +605 -2
- package/dist/index.mjs +589 -1
- package/package.json +1 -1
package/dist/index.mjs
CHANGED
|
@@ -38,6 +38,20 @@ var linear = {
|
|
|
38
38
|
fn: (x) => x,
|
|
39
39
|
dfn: () => 1
|
|
40
40
|
};
|
|
41
|
+
/**
 * Build a leaky-ReLU activation pair.
 *
 * @param {number} [alpha=0.01] - Slope applied to non-positive inputs.
 * @returns {{fn: function(number): number, dfn: function(number): number}}
 *   `fn` maps a pre-activation value to its activation; `dfn` returns the
 *   slope (1 or `alpha`) chosen from the sign of the value it is given.
 */
function makeLeakyRelu(alpha = 0.01) {
  const fn = (x) => {
    if (x > 0) {
      return x;
    }
    return alpha * x;
  };
  const dfn = (out) => {
    if (out > 0) {
      return 1;
    }
    return alpha;
  };
  return { fn, dfn };
}

// Ready-made leaky ReLU with a slope of 0.01.
var leakyRelu = makeLeakyRelu(0.01);
|
|
48
|
+
/**
 * Build an ELU activation pair.
 *
 * @param {number} [alpha=1] - Scale of the exponential branch used for
 *   non-positive inputs.
 * @returns {{fn: function(number): number, dfn: function(number): number}}
 *   `fn` is the activation. `dfn` takes the value it is given (named `out`)
 *   and returns 1 when it is positive, otherwise `out + alpha`.
 */
function makeElu(alpha = 1) {
  const fn = (x) => {
    if (x > 0) {
      return x;
    }
    return alpha * (Math.exp(x) - 1);
  };
  const dfn = (out) => {
    if (out > 0) {
      return 1;
    }
    return out + alpha;
  };
  return { fn, dfn };
}

// Ready-made ELU with alpha = 1.
var elu = makeElu(1);
|
|
41
55
|
|
|
42
56
|
// src/optimizers.ts
|
|
43
57
|
var SGD = class {
|
|
@@ -461,6 +475,565 @@ var NetworkLSTM = class {
|
|
|
461
475
|
}
|
|
462
476
|
};
|
|
463
477
|
|
|
478
|
+
// src/MatMul.ts
|
|
479
|
+
/**
 * Multiply two dense matrices stored as arrays of rows.
 *
 * @param {number[][]} A - Left operand (rows × inner).
 * @param {number[][]} B - Right operand (inner × cols).
 * @returns {number[][]} A freshly allocated rows × cols product.
 * @throws {RangeError} If a row of A does not have `B.length` entries, or
 *   the rows of B do not all have the same length.
 */
function matMul(A, B) {
  const rows = A.length;
  const inner = B.length;
  // An empty B has no first row to measure; treat it as having zero columns.
  const cols = inner === 0 ? 0 : B[0].length;

  // Reject shape mismatches up front instead of silently producing NaN
  // (short rows) or dropping trailing columns (long rows).
  A.forEach((row, i) => {
    if (row.length !== inner) {
      throw new RangeError(
        `matMul: row ${i} of A has ${row.length} entries, expected ${inner}`
      );
    }
  });
  B.forEach((row, k) => {
    if (row.length !== cols) {
      throw new RangeError(
        `matMul: row ${k} of B has ${row.length} entries, expected ${cols}`
      );
    }
  });

  const C = Array.from({ length: rows }, () => new Array(cols).fill(0));
  for (let i = 0; i < rows; i++) {
    for (let k = 0; k < inner; k++) {
      // Hoisted so the innermost loop only walks row k of B.
      const aik = A[i][k];
      for (let j = 0; j < cols; j++) {
        C[i][j] += aik * B[k][j];
      }
    }
  }
  return C;
}
|
|
492
|
+
/**
 * Transpose a matrix stored as an array of rows.
 *
 * The column count is taken from the first row. An empty matrix (no rows)
 * transposes to an empty matrix instead of throwing.
 *
 * @param {number[][]} A - rows × cols input; not modified.
 * @returns {number[][]} A new cols × rows matrix.
 */
function transpose(A) {
  if (A.length === 0) {
    return [];
  }
  const cols = A[0].length;
  return Array.from({ length: cols }, (_, j) => A.map((row) => row[j]));
}
|
|
500
|
+
/**
 * Softmax over one row of scores.
 *
 * The row maximum is subtracted before exponentiating so large scores do
 * not overflow `Math.exp`.
 *
 * @param {number[]} row - Raw scores.
 * @returns {number[]} Exponentials divided by their sum; `[]` for an empty row.
 */
function softmax(row) {
  // Fold with Math.max rather than spreading the row into a single call:
  // spreading a very long array exceeds the engine's argument limit.
  const max = row.reduce((m, v) => Math.max(m, v), -Infinity);
  const exps = row.map((v) => Math.exp(v - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
}
|
|
506
|
+
/**
 * Back-propagate through a softmax row.
 *
 * @param {number[]} dS - Gradient with respect to the softmax output.
 * @param {number[]} s - The softmax output itself.
 * @returns {number[]} `s[i] * (dS[i] - Σ_j dS[j] * s[j])` for every i.
 */
function softmaxBackward(dS, s) {
  // Weighted sum of the incoming gradient, shared by every component.
  let weighted = 0;
  for (let idx = 0; idx < s.length; idx++) {
    weighted += dS[idx] * s[idx];
  }
  return s.map((prob, idx) => prob * (dS[idx] - weighted));
}
|
|
510
|
+
/**
 * Dense rows × cols weight matrix with one optimizer instance per weight.
 *
 * `W` holds the weights; `opts` holds an `Adam` (defined elsewhere in this
 * file) for each entry of `W`.
 */
var WeightMatrix = class {
  /**
   * @param {number} rows - Number of rows in `W`.
   * @param {number} cols - Number of columns in `W`.
   */
  constructor(rows, cols) {
    // Weights are drawn uniformly from [-limit, limit).
    const limit = Math.sqrt(2 / (rows + cols));
    this.W = Array.from({ length: rows }, () =>
      Array.from({ length: cols }, () => (Math.random() * 2 - 1) * limit)
    );
    this.opts = Array.from({ length: rows }, () =>
      Array.from({ length: cols }, () => new Adam())
    );
  }

  /**
   * Apply a pre-computed gradient (same shape as `W`).
   *
   * Each weight is replaced by whatever its optimizer's
   * `step(weight, gradient, lr)` returns.
   *
   * @param {number[][]} dW - Gradient, one entry per weight.
   * @param {number} lr - Learning rate forwarded to every optimizer.
   */
  update(dW, lr) {
    for (let i = 0; i < this.W.length; i++) {
      const weights = this.W[i];
      const optimizers = this.opts[i];
      // Use each row's own length so a matrix with zero rows is a no-op
      // instead of a crash on `W[0]`.
      for (let j = 0; j < weights.length; j++) {
        weights[j] = optimizers[j].step(weights[j], dW[i][j], lr);
      }
    }
  }
};
|
|
529
|
+
/**
 * Lookup table of `vocabSize` rows, each a `d_model`-long vector.
 */
var EmbeddingMatrix = class {
  /**
   * @param {number} vocabSize - Number of rows (valid indices are 0 … vocabSize-1).
   * @param {number} d_model - Length of each row.
   */
  constructor(vocabSize, d_model) {
    // Entries are drawn uniformly from [-limit, limit).
    const limit = Math.sqrt(1 / d_model);
    this.W = Array.from({ length: vocabSize }, () =>
      Array.from({ length: d_model }, () => (Math.random() * 2 - 1) * limit)
    );
  }

  /**
   * Fetch a copy of one row, so callers cannot mutate the table through it.
   *
   * @param {number} idx - Row index.
   * @returns {number[]} Copy of row `idx`.
   * @throws {RangeError} If `idx` does not address an existing row.
   */
  get(idx) {
    const row = this.W[idx];
    if (row === undefined) {
      throw new RangeError(`EmbeddingMatrix.get: index ${idx} is out of range`);
    }
    return [...row];
  }

  /**
   * Add `lr * grad` to one row, in place.
   *
   * NOTE(review): this is an additive step; whether it moves the row downhill
   * depends on the sign convention of `grad` supplied by the caller — confirm.
   *
   * @param {number} idx - Row index.
   * @param {number[]} grad - One value per component of the row.
   * @param {number} lr - Step size.
   * @throws {RangeError} If `idx` does not address an existing row.
   */
  update(idx, grad, lr) {
    const row = this.W[idx];
    if (row === undefined) {
      throw new RangeError(`EmbeddingMatrix.update: index ${idx} is out of range`);
    }
    for (let m = 0; m < row.length; m++) {
      row[m] += lr * grad[m];
    }
  }
};
|
|
545
|
+
|
|
546
|
+
// src/AttentionHead.ts
|
|
547
|
+
/**
 * Single scaled dot-product self-attention head.
 *
 * Holds three `WeightMatrix` projections (query, key, value) and caches the
 * intermediates of the most recent `predict()` so `backward()` can reuse them.
 */
var AttentionHead = class {
  /**
   * @param {number} d_model - Length of each input vector.
   * @param {number} d_k - Length of each query/key vector.
   * @param {number} d_v - Length of each value vector (and of each output row).
   */
  constructor(d_model, d_k, d_v) {
    // Intermediates of the last predict(); null until predict() has run.
    this.cache = null;
    this.d_k = d_k;
    this.d_v = d_v;
    this.Wq = new WeightMatrix(d_k, d_model); // d_k × d_model
    this.Wk = new WeightMatrix(d_k, d_model); // d_k × d_model
    this.Wv = new WeightMatrix(d_v, d_model); // d_v × d_model
  }

  // ── Forward ───────────────────────────────────────────────────────────────
  /**
   * @param {number[][]} X - seqLen × d_model inputs.
   * @returns {number[][]} seqLen × d_v outputs.
   */
  predict(X) {
    const seqLen = X.length;
    const scale = 1 / Math.sqrt(this.d_k);
    // Project every position through one weight matrix: (X @ Wᵀ).
    const project = (weights) =>
      X.map((x) => weights.W.map((row) => row.reduce((s, w, m) => s + w * x[m], 0)));
    const Q = project(this.Wq);
    const K = project(this.Wk);
    const V = project(this.Wv);
    // scores[i][j] = (Q_i · K_j) / √d_k
    const scores = Array.from({ length: seqLen }, (_, i) =>
      Array.from(
        { length: seqLen },
        (_2, j) => Q[i].reduce((s, q, k) => s + q * K[j][k], 0) * scale
      )
    );
    const attn = scores.map((row) => softmax(row));
    // out[i] = Σ_j attn[i][j] · V_j
    const out = Array.from({ length: seqLen }, (_, i) =>
      Array.from({ length: this.d_v }, (_2, d) =>
        attn[i].reduce((s, a, j) => s + a * V[j][d], 0)
      )
    );
    this.cache = { X, Q, K, V, scores, attn };
    return out;
  }

  // ── Backward ──────────────────────────────────────────────────────────────
  /**
   * Back-propagate through the last `predict()` call and update the weights.
   *
   * Steps:
   *   1. dV      = attnᵀ @ dOut
   *   2. dAttn   = dOut @ Vᵀ            (attention weight gradients)
   *   3. dScores = softmaxBwd(dAttn) / √d_k
   *   4. dQ = dScores @ K,   dK = dScoresᵀ @ Q
   *   5. dWq = dQᵀ @ X,  dWk = dKᵀ @ X,  dWv = dVᵀ @ X
   *   6. dX  = dQ @ Wq + dK @ Wk + dV @ Wv
   *   7. apply dWq / dWk / dWv
   *
   * @param {number[][]} dOut - seqLen × d_v gradient of the output.
   * @param {number} lr - Learning rate forwarded to the weight updates.
   * @returns {number[][]} seqLen × d_model gradient with respect to X.
   * @throws {Error} If `predict()` has not been called yet.
   */
  backward(dOut, lr) {
    if (this.cache === null) {
      throw new Error("AttentionHead.backward() called before predict()");
    }
    const { X, Q, K, V, attn } = this.cache;
    const seqLen = X.length;
    const d_model = X[0].length;
    const scale = 1 / Math.sqrt(this.d_k);
    const dV = Array.from({ length: seqLen }, (_, j) =>
      Array.from({ length: this.d_v }, (_2, d) =>
        attn.reduce((s, a, i) => s + a[j] * dOut[i][d], 0)
      )
    );
    const dAttn = Array.from({ length: seqLen }, (_, i) =>
      Array.from({ length: seqLen }, (_2, j) =>
        dOut[i].reduce((s, d, k) => s + d * V[j][k], 0)
      )
    );
    const dScores = dAttn.map((da, i) =>
      softmaxBackward(da, attn[i]).map((v) => v * scale)
    );
    const dQ = Array.from({ length: seqLen }, (_, i) =>
      Array.from({ length: this.d_k }, (_2, k) =>
        dScores[i].reduce((s, ds, j) => s + ds * K[j][k], 0)
      )
    );
    const dK = Array.from({ length: seqLen }, (_, j) =>
      Array.from({ length: this.d_k }, (_2, k) =>
        dScores.reduce((s, ds, i) => s + ds[j] * Q[i][k], 0)
      )
    );
    // Gradient of a projection matrix: (dProjected)ᵀ @ X.
    const weightGrad = (dProjected, outDim) =>
      Array.from({ length: outDim }, (_, k) =>
        Array.from({ length: d_model }, (_2, m) =>
          dProjected.reduce((s, dp, i) => s + dp[k] * X[i][m], 0)
        )
      );
    const dWq = weightGrad(dQ, this.d_k);
    const dWk = weightGrad(dK, this.d_k);
    const dWv = weightGrad(dV, this.d_v);
    // dX must be formed with the weights that produced the forward pass, so
    // it is computed BEFORE the matrices are updated below.
    const dX = Array.from({ length: seqLen }, (_, i) =>
      Array.from(
        { length: d_model },
        (_2, m) =>
          dQ[i].reduce((s, dq, k) => s + dq * this.Wq.W[k][m], 0) +
          dK[i].reduce((s, dk, k) => s + dk * this.Wk.W[k][m], 0) +
          dV[i].reduce((s, dv, k) => s + dv * this.Wv.W[k][m], 0)
      )
    );
    this.Wq.update(dWq, lr);
    this.Wk.update(dWk, lr);
    this.Wv.update(dWv, lr);
    return dX;
  }

  /**
   * Attention weights from the last predict() call — useful for visualization.
   * @returns {number[][] | null} seqLen × seqLen weights, or null before predict().
   */
  getAttentionWeights() {
    return this.cache ? this.cache.attn : null;
  }
};
|
|
673
|
+
|
|
674
|
+
// src/MultiHeadAttention.ts
|
|
675
|
+
/**
 * Multi-head self-attention: `nHeads` independent `AttentionHead`s whose
 * outputs are concatenated and mixed by an output projection `Wo`.
 */
var MultiHeadAttention = class {
  /**
   * @param {number} d_model - Length of each input/output vector.
   * @param {number} nHeads - Number of heads; each gets
   *   `Math.floor(d_model / nHeads)` query/key/value dimensions.
   * @throws {RangeError} If `nHeads` is not an integer in 1 … d_model
   *   (anything else leaves heads with no dimensions at all).
   */
  constructor(d_model, nHeads) {
    if (!Number.isInteger(nHeads) || nHeads < 1 || nHeads > d_model) {
      throw new RangeError(
        `MultiHeadAttention: nHeads must be an integer between 1 and d_model (${d_model}), got ${nHeads}`
      );
    }
    // Concatenated head outputs, seqLen × (nHeads * d_k); cached for backward.
    this._concat = null;
    this.nHeads = nHeads;
    this.d_model = d_model;
    this.d_k = Math.floor(d_model / nHeads);
    this.heads = Array.from(
      { length: nHeads },
      () => new AttentionHead(d_model, this.d_k, this.d_k)
    );
    // Output projection, d_model × (nHeads * d_k).
    this.Wo = new WeightMatrix(d_model, nHeads * this.d_k);
  }

  // ── Forward ───────────────────────────────────────────────────────────────
  /**
   * @param {number[][]} X - seqLen × d_model inputs.
   * @returns {number[][]} seqLen × d_model outputs.
   */
  predict(X) {
    const seqLen = X.length;
    const headOuts = this.heads.map((h) => h.predict(X));
    // Row i of concat is every head's row i, laid end to end.
    const concat = Array.from({ length: seqLen }, (_, i) =>
      headOuts.flatMap((ho) => ho[i])
    );
    const out = concat.map((c) =>
      this.Wo.W.map((row) => row.reduce((s, w, j) => s + w * c[j], 0))
    );
    this._concat = concat;
    return out;
  }

  // ── Backward ──────────────────────────────────────────────────────────────
  /**
   * Back-propagate through the last `predict()` call, updating `Wo` and
   * every head.
   *
   * @param {number[][]} dOut - seqLen × d_model gradient of the output.
   * @param {number} lr - Learning rate forwarded to all updates.
   * @returns {number[][]} seqLen × d_model gradient with respect to X
   *   (sum of the gradients returned by the heads).
   * @throws {Error} If `predict()` has not been called yet.
   */
  backward(dOut, lr) {
    if (this._concat === null) {
      throw new Error("MultiHeadAttention.backward() called before predict()");
    }
    const seqLen = dOut.length;
    const concatD = this.nHeads * this.d_k;
    const d_model = this.d_model;
    const concat = this._concat;
    // dConcat = dOut @ Wo — computed before Wo is updated.
    const dConcat = dOut.map((do_) =>
      Array.from({ length: concatD }, (_, j) =>
        this.Wo.W.reduce((s, row, k) => s + do_[k] * row[j], 0)
      )
    );
    // dWo = dOutᵀ @ concat
    const dWo = Array.from({ length: d_model }, (_, k) =>
      Array.from({ length: concatD }, (_2, j) =>
        dOut.reduce((s, row, i) => s + row[k] * concat[i][j], 0)
      )
    );
    this.Wo.update(dWo, lr);
    const dX = Array.from({ length: seqLen }, () => new Array(d_model).fill(0));
    for (let h = 0; h < this.nHeads; h++) {
      // Each head receives its own d_k-wide slice of the concat gradient.
      const start = h * this.d_k;
      const dHeadOut = dConcat.map((dc) => dc.slice(start, start + this.d_k));
      const dXh = this.heads[h].backward(dHeadOut, lr);
      for (let i = 0; i < seqLen; i++) {
        for (let m = 0; m < d_model; m++) {
          dX[i][m] += dXh[i][m];
        }
      }
    }
    return dX;
  }

  /**
   * Attention weights per head from the last predict() — for visualization.
   * @returns {Array} One entry per head, each whatever that head's
   *   `getAttentionWeights()` returns.
   */
  getAttentionWeights() {
    return this.heads.map((h) => h.getAttentionWeights());
  }
};
|
|
746
|
+
|
|
747
|
+
// src/LayerNorm.ts
|
|
748
|
+
/**
 * Layer normalization over one feature vector at a time, with a learned
 * per-feature scale (`gamma`) and shift (`beta`).
 */
var LayerNorm = class {
  /**
   * @param {number} dim - Length of the vectors to normalize.
   */
  constructor(dim) {
    this.eps = 1e-5;
    // Per-position cache populated during the forward pass.
    // resetCache() must be called before each sequence forward pass.
    this._cache = [];
    this.gamma = new Array(dim).fill(1);
    this.beta = new Array(dim).fill(0);
  }

  /**
   * Call once before forward-passing a new sequence.
   * @param {number} seqLen - Number of positions the cache should hold.
   */
  resetCache(seqLen) {
    this._cache = new Array(seqLen);
  }

  /**
   * Normalize a single position's feature vector.
   *
   * @param {number[]} x - Feature vector.
   * @param {number} pos - Cache slot; must match the index later passed to
   *   `backwardOne` for the same vector.
   * @returns {number[]} `gamma ⊙ x_norm + beta`.
   */
  predictOne(x, pos) {
    const N = x.length;
    const mean = x.reduce((s, v) => s + v, 0) / N;
    const vari = x.reduce((s, v) => s + (v - mean) ** 2, 0) / N;
    // eps keeps the divisor away from zero for constant vectors.
    const std = Math.sqrt(vari + this.eps);
    const x_norm = x.map((v) => (v - mean) / std);
    this._cache[pos] = { x_norm, std };
    return x_norm.map((xn, i) => this.gamma[i] * xn + this.beta[i]);
  }

  /**
   * Backprop through layer norm for one position.
   *
   * Given dOut (gradient of the output), computes the gradient of the input:
   *   Let D = dOut ⊙ γ
   *   dx_i = (1/std) * (D_i − mean(D) − x_norm_i * mean(D ⊙ x_norm))
   *
   * Then updates γ and β with a plain additive step:
   *   γ_i += lr * dOut_i * x_norm_i
   *   β_i += lr * dOut_i
   *
   * NOTE(review): the step is additive; whether that reduces the loss depends
   * on the sign convention of `dOut` supplied by the caller — confirm.
   *
   * @param {number[]} dOut - Gradient of this position's output.
   * @param {number} pos - Cache slot used by the matching `predictOne` call.
   * @param {number} lr - Step size for the γ/β update.
   * @returns {number[]} Gradient with respect to the input vector.
   * @throws {Error} If no forward pass has been cached for `pos`.
   */
  backwardOne(dOut, pos, lr) {
    const cached = this._cache[pos];
    if (cached === undefined) {
      throw new Error(
        `LayerNorm.backwardOne: no cached forward pass for position ${pos}; call predictOne first`
      );
    }
    const { x_norm, std } = cached;
    const N = dOut.length;
    // The input gradient must use the γ that produced the forward output, so
    // it is computed before γ is modified below.
    const D = dOut.map((d, i) => d * this.gamma[i]);
    const mD = D.reduce((s, v) => s + v, 0) / N;
    const mDxn = D.reduce((s, d, i) => s + d * x_norm[i], 0) / N;
    const dX = D.map((d, i) => (d - mD - x_norm[i] * mDxn) / std);
    for (let i = 0; i < N; i++) {
      this.gamma[i] += lr * dOut[i] * x_norm[i];
      this.beta[i] += lr * dOut[i];
    }
    return dX;
  }
};
|
|
797
|
+
|
|
798
|
+
// src/TransformerBlock.ts
|
|
799
|
+
/**
 * One transformer encoder block:
 *   h1  = norm1(X + attention(X))
 *   out = norm2(h1 + ff2(relu(ff1(h1))))
 */
var TransformerBlock = class {
  /**
   * @param {{d_model: number, nHeads: number, d_ff: number}} options
   *   `d_model` is the vector length, `nHeads` is forwarded to the attention
   *   layer, `d_ff` is the hidden width of the feed-forward layers.
   */
  constructor({ d_model, nHeads, d_ff }) {
    // Forward caches (needed for backprop); null until predict() has run.
    this._X = null;
    this._attnOut = null;
    this._h1 = null;
    this._ff1Pre = null; // pre-ReLU
    this._ff1Out = null; // post-ReLU
    this._ff2Out = null;
    this.d_model = d_model;
    this.d_ff = d_ff;
    this.attn = new MultiHeadAttention(d_model, nHeads);
    this.norm1 = new LayerNorm(d_model);
    this.norm2 = new LayerNorm(d_model);
    this.ff1 = new WeightMatrix(d_ff, d_model);
    this.ff2 = new WeightMatrix(d_model, d_ff);
    this.b1 = new Array(d_ff).fill(0);
    this.b2 = new Array(d_model).fill(0);
    // One optimizer per bias entry.
    this.b1Opts = Array.from({ length: d_ff }, () => new Adam());
    this.b2Opts = Array.from({ length: d_model }, () => new Adam());
  }

  // ── Forward ───────────────────────────────────────────────────────────────
  /**
   * @param {number[][]} X - seqLen × d_model inputs.
   * @returns {number[][]} seqLen × d_model outputs.
   */
  predict(X) {
    const seqLen = X.length;
    const attnOut = this.attn.predict(X);
    // Residual connection around attention, then norm1.
    this.norm1.resetCache(seqLen);
    const h1 = X.map((x, i) => {
      const added = x.map((v, k) => v + attnOut[i][k]);
      return this.norm1.predictOne(added, i);
    });
    // The reduce starts from the bias, so each entry is W·h + b.
    const ff1Pre = h1.map((h) =>
      this.ff1.W.map((row, k) => row.reduce((s, w, m) => s + w * h[m], this.b1[k]))
    );
    const ff1Out = ff1Pre.map((pre) => pre.map((v) => Math.max(0, v)));
    const ff2Out = ff1Out.map((h) =>
      this.ff2.W.map((row, k) => row.reduce((s, w, m) => s + w * h[m], this.b2[k]))
    );
    // Residual connection around the feed-forward stack, then norm2.
    this.norm2.resetCache(seqLen);
    const out = h1.map((h, i) => {
      const added = h.map((v, k) => v + ff2Out[i][k]);
      return this.norm2.predictOne(added, i);
    });
    this._X = X;
    this._attnOut = attnOut;
    this._h1 = h1;
    this._ff1Pre = ff1Pre;
    this._ff1Out = ff1Out;
    this._ff2Out = ff2Out;
    return out;
  }

  // ── Backward ──────────────────────────────────────────────────────────────
  /**
   * Back-propagate through the last `predict()` call, updating every
   * parameter of the block along the way.
   *
   * @param {number[][]} dOut - seqLen × d_model gradient of the output.
   * @param {number} lr - Learning rate forwarded to all updates.
   * @returns {number[][]} seqLen × d_model gradient with respect to X.
   * @throws {Error} If `predict()` has not been called yet.
   */
  backward(dOut, lr) {
    if (this._h1 === null) {
      throw new Error("TransformerBlock.backward() called before predict()");
    }
    const seqLen = dOut.length;
    const d_model = this.d_model;
    const h1 = this._h1;
    const ff1Out = this._ff1Out;
    const ff1Pre = this._ff1Pre;

    // Through norm2: gradient of (h1 + ff2Out).
    const dAdded2 = dOut.map((do_, i) => this.norm2.backwardOne(do_, i, lr));

    // Second feed-forward layer. dFf1Out is taken before ff2 is updated.
    const dFf1Out = dAdded2.map((da) =>
      Array.from({ length: this.d_ff }, (_, k) =>
        this.ff2.W.reduce((s, row, m) => s + row[k] * da[m], 0)
      )
    );
    const dW2 = Array.from({ length: d_model }, (_, m) =>
      Array.from({ length: this.d_ff }, (_2, k) =>
        dAdded2.reduce((s, da, i) => s + da[m] * ff1Out[i][k], 0)
      )
    );
    const db2 = Array.from({ length: d_model }, (_, m) =>
      dAdded2.reduce((s, da) => s + da[m], 0)
    );
    this.ff2.update(dW2, lr);
    for (let m = 0; m < d_model; m++) {
      this.b2[m] = this.b2Opts[m].step(this.b2[m], db2[m], lr);
    }

    // ReLU gate: gradient passes only where the pre-activation was positive.
    const dFf1Pre = dFf1Out.map((d, i) => d.map((v, k) => (ff1Pre[i][k] > 0 ? v : 0)));

    // First feed-forward layer. dH1_fromFf is taken before ff1 is updated.
    const dH1_fromFf = dFf1Pre.map((dp) =>
      Array.from({ length: d_model }, (_, m) =>
        this.ff1.W.reduce((s, row, k) => s + dp[k] * row[m], 0)
      )
    );
    const dW1 = Array.from({ length: this.d_ff }, (_, k) =>
      Array.from({ length: d_model }, (_2, m) =>
        dFf1Pre.reduce((s, dp, i) => s + dp[k] * h1[i][m], 0)
      )
    );
    const db1 = Array.from({ length: this.d_ff }, (_, k) =>
      dFf1Pre.reduce((s, dp) => s + dp[k], 0)
    );
    this.ff1.update(dW1, lr);
    for (let k = 0; k < this.d_ff; k++) {
      this.b1[k] = this.b1Opts[k].step(this.b1[k], db1[k], lr);
    }

    // h1 feeds both the feed-forward path and the second residual.
    const dH1 = Array.from({ length: seqLen }, (_, i) =>
      dH1_fromFf[i].map((v, m) => v + dAdded2[i][m])
    );

    // Through norm1: gradient of (X + attnOut); both summands receive it.
    const dAdded1 = dH1.map((d, i) => this.norm1.backwardOne(d, i, lr));
    const dAttnOut = dAdded1;
    const dX_skip = dAdded1;
    const dX_fromAttn = this.attn.backward(dAttnOut, lr);
    const dX = Array.from({ length: seqLen }, (_, i) =>
      Array.from({ length: d_model }, (_2, m) => dX_fromAttn[i][m] + dX_skip[i][m])
    );
    return dX;
  }

  /**
   * Attention weights from the last predict() — for visualization.
   * @returns {*} Whatever the attention layer's `getAttentionWeights()` returns.
   */
  getAttentionWeights() {
    return this.attn.getAttentionWeights();
  }
};
|
|
926
|
+
|
|
927
|
+
// src/NetworkTransformer.ts
|
|
928
|
+
/**
 * Token-sequence model: token + position embeddings, a stack of
 * `TransformerBlock`s, and a per-position linear projection to `nClasses`
 * logits.
 */
var NetworkTransformer = class {
  /**
   * @param {number} seqLen - Number of positions (rows of the position table).
   * @param {object} [options]
   * @param {number} [options.vocabSize=10] - Rows of the token table.
   * @param {number} [options.d_model=64] - Embedding / hidden vector length.
   * @param {number} [options.nHeads=4] - Forwarded to every block.
   * @param {number} [options.d_ff=128] - Forwarded to every block.
   * @param {number} [options.nBlocks=4] - Number of blocks in the stack.
   * @param {number} [options.nClasses=9] - Logits produced per position.
   */
  constructor(seqLen, options = {}) {
    const {
      vocabSize = 10,
      d_model = 64,
      nHeads = 4,
      d_ff = 128,
      nBlocks = 4,
      nClasses = 9
    } = options;
    this.seqLen = seqLen;
    this.vocabSize = vocabSize;
    this.d_model = d_model;
    this.nClasses = nClasses;
    this.tokenEmb = new EmbeddingMatrix(vocabSize, d_model);
    this.posEmb = new EmbeddingMatrix(seqLen, d_model);
    this.blocks = Array.from(
      { length: nBlocks },
      () => new TransformerBlock({ d_model, nHeads, d_ff })
    );
    this.outputProj = new WeightMatrix(nClasses, d_model);
    this.outputBias = new Array(nClasses).fill(0);
    this.outBiasOpts = Array.from({ length: nClasses }, () => new Adam());
  }

  // ── Forward pass ──────────────────────────────────────────────────────────
  /**
   * @param {number[]} tokens - Integer token ids, one per position.
   * @returns {number[]} Logits flattened position-major:
   *   index `i * nClasses + c` is class c at position i.
   */
  predict(tokens) {
    return this._logits(this._forward(tokens)).flat();
  }

  // ── Training step (online, one sample at a time) ───────────────────────────
  /**
   * Run one forward/backward pass and update every parameter.
   *
   * NOTE(review): the loss gradient is formed as `2 * (p - t)` and handed to
   * optimizer `step()` calls and to additive embedding updates; whether all
   * of these move downhill depends on sign conventions defined outside this
   * class — confirm.
   *
   * @param {number[]} tokens - `seqLen` integer token ids.
   * @param {number[]} targets - `seqLen * nClasses` values (e.g. one-hot per
   *   cell), laid out like the result of `predict`.
   * @param {number} lr - Learning rate forwarded to all updates.
   * @param {boolean[]} [mask] - Optional; only positions where `mask[i]` is
   *   truthy contribute to the loss and gradients (e.g. empty cells in Sudoku).
   * @returns {number} Mean squared error over the unmasked positions
   *   (0 when every position is masked out).
   * @throws {RangeError} If `tokens.length !== seqLen`, or `targets` has no
   *   value for an unmasked cell. Both are checked before any parameter is
   *   modified.
   */
  train(tokens, targets, lr, mask) {
    if (tokens.length !== this.seqLen) {
      throw new RangeError(
        `NetworkTransformer.train: expected ${this.seqLen} tokens, got ${tokens.length}`
      );
    }
    const h = this._forward(tokens);
    const logits = this._logits(h);
    let loss = 0;
    let count = 0;
    const dLogits = Array.from({ length: this.seqLen }, (_, i) => {
      // Masked-out positions contribute a zero gradient and no loss.
      if (mask && !mask[i]) return new Array(this.nClasses).fill(0);
      count++;
      return Array.from({ length: this.nClasses }, (_2, c) => {
        const t = targets[i * this.nClasses + c];
        if (t === undefined) {
          // A missing target would turn the gradient, and then every
          // weight it touches, into NaN.
          throw new RangeError(
            `NetworkTransformer.train: targets has no value for position ${i}, class ${c}`
          );
        }
        const p = logits[i][c];
        loss += (p - t) ** 2;
        return 2 * (p - t);
      });
    });
    if (count > 0) loss /= count * this.nClasses;

    // Output layer: dH is taken before outputProj is updated.
    const dH = Array.from({ length: this.seqLen }, (_, i) =>
      Array.from({ length: this.d_model }, (_2, m) =>
        dLogits[i].reduce((s, dl, c) => s + dl * this.outputProj.W[c][m], 0)
      )
    );
    const dWout = Array.from({ length: this.nClasses }, (_, c) =>
      Array.from({ length: this.d_model }, (_2, m) =>
        dLogits.reduce((s, dl, i) => s + dl[c] * h[i][m], 0)
      )
    );
    const dBout = Array.from({ length: this.nClasses }, (_, c) =>
      dLogits.reduce((s, dl) => s + dl[c], 0)
    );
    this.outputProj.update(dWout, lr);
    for (let c = 0; c < this.nClasses; c++) {
      this.outputBias[c] = this.outBiasOpts[c].step(this.outputBias[c], dBout[c], lr);
    }

    // Blocks are walked in reverse of the forward order.
    let dX = dH;
    for (let b = this.blocks.length - 1; b >= 0; b--) {
      dX = this.blocks[b].backward(dX, lr);
    }

    // The input was tokenEmb + posEmb, so both tables receive the same gradient.
    for (let i = 0; i < this.seqLen; i++) {
      this.tokenEmb.update(tokens[i], dX[i], lr);
      this.posEmb.update(i, dX[i], lr);
    }
    return loss;
  }

  /**
   * Attention weights from every block for visualization.
   * @returns {Array} One entry per block, each whatever that block's
   *   `getAttentionWeights()` returns.
   */
  getAttentionWeights() {
    return this.blocks.map((b) => b.getAttentionWeights());
  }

  // ── Internal ──────────────────────────────────────────────────────────────
  /**
   * Shared embedding + block forward pass.
   * @param {number[]} tokens - Integer token ids, one per position.
   * @returns {number[][]} One d_model-long hidden vector per token.
   */
  _forward(tokens) {
    let h = tokens.map((id, i) => {
      const te = this.tokenEmb.get(id);
      const pe = this.posEmb.get(i);
      return te.map((v, m) => v + pe[m]);
    });
    for (const block of this.blocks) {
      h = block.predict(h);
    }
    return h;
  }

  /**
   * Project hidden vectors to per-position logits (W·h + bias).
   * Shared by `predict` and `train` so both use the same projection.
   * @param {number[][]} h - One hidden vector per position.
   * @returns {number[][]} One nClasses-long logit vector per position.
   */
  _logits(h) {
    return h.map((hi) =>
      this.outputProj.W.map((row, c) =>
        row.reduce((s, w, m) => s + w * hi[m], this.outputBias[c])
      )
    );
  }
};
|
|
1036
|
+
|
|
464
1037
|
// src/losses.ts
|
|
465
1038
|
function mse(predicted, actual) {
|
|
466
1039
|
return predicted.reduce((sum, p, i) => sum + (actual[i] - p) ** 2, 0) / predicted.length;
|
|
@@ -485,22 +1058,37 @@ function crossEntropyDeltaRaw(predicted, actual) {
|
|
|
485
1058
|
}
|
|
486
1059
|
export {
|
|
487
1060
|
Adam,
|
|
1061
|
+
AttentionHead,
|
|
1062
|
+
EmbeddingMatrix,
|
|
488
1063
|
LSTMLayer,
|
|
489
1064
|
Layer,
|
|
1065
|
+
LayerNorm,
|
|
490
1066
|
Momentum,
|
|
1067
|
+
MultiHeadAttention,
|
|
491
1068
|
Network,
|
|
492
1069
|
NetworkLSTM,
|
|
493
1070
|
NetworkN,
|
|
1071
|
+
NetworkTransformer,
|
|
494
1072
|
Neuron,
|
|
495
1073
|
NeuronN,
|
|
496
1074
|
SGD,
|
|
1075
|
+
TransformerBlock,
|
|
1076
|
+
WeightMatrix,
|
|
497
1077
|
crossEntropy,
|
|
498
1078
|
crossEntropyDelta,
|
|
499
1079
|
crossEntropyDeltaRaw,
|
|
1080
|
+
elu,
|
|
1081
|
+
leakyRelu,
|
|
500
1082
|
linear,
|
|
1083
|
+
makeElu,
|
|
1084
|
+
makeLeakyRelu,
|
|
1085
|
+
matMul,
|
|
501
1086
|
mse,
|
|
502
1087
|
mseDelta,
|
|
503
1088
|
relu,
|
|
504
1089
|
sigmoid2 as sigmoid,
|
|
505
|
-
|
|
1090
|
+
softmax,
|
|
1091
|
+
softmaxBackward,
|
|
1092
|
+
tanh,
|
|
1093
|
+
transpose
|
|
506
1094
|
};
|
package/package.json
CHANGED