@dniskav/neuron 0.2.7 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +540 -192
- package/dist/index.d.mts +587 -1
- package/dist/index.d.ts +587 -1
- package/dist/index.js +3778 -2
- package/dist/index.mjs +3734 -2
- package/package.json +2 -2
package/dist/index.mjs
CHANGED
|
@@ -2430,12 +2430,12 @@ var Trainer = class {
|
|
|
2430
2430
|
precisions.push(colSum > 0 ? tp / colSum : 0);
|
|
2431
2431
|
recalls.push(rowSum > 0 ? tp / rowSum : 0);
|
|
2432
2432
|
}
|
|
2433
|
-
const
|
|
2433
|
+
const accuracy2 = totalSamples > 0 ? totalCorrect / totalSamples : 0;
|
|
2434
2434
|
const macroPrecision = precisions.reduce((a, b) => a + b, 0) / nClasses;
|
|
2435
2435
|
const macroRecall = recalls.reduce((a, b) => a + b, 0) / nClasses;
|
|
2436
2436
|
const f1 = macroPrecision + macroRecall > 0 ? 2 * macroPrecision * macroRecall / (macroPrecision + macroRecall) : 0;
|
|
2437
2437
|
return {
|
|
2438
|
-
accuracy,
|
|
2438
|
+
accuracy: accuracy2,
|
|
2439
2439
|
precision: macroPrecision,
|
|
2440
2440
|
recall: macroRecall,
|
|
2441
2441
|
f1
|
|
@@ -2598,22 +2598,3728 @@ var ModelSaver = class _ModelSaver {
|
|
|
2598
2598
|
_ModelSaver.fromJSON(model, json);
|
|
2599
2599
|
}
|
|
2600
2600
|
};
|
|
2601
|
+
|
|
2602
|
+
// src/Perceptron.ts
|
|
2603
|
+
var Perceptron = class {
  // ─── Constructor ─────────────────────────────────────────────────────────────
  // Builds a perceptron with `nInputs` weights. Every weight and the bias start
  // at 0; the first misclassified sample moves them away from zero.
  //
  // Throws if `nInputs` is not a positive integer.
  constructor(nInputs) {
    const isPositiveInteger = Number.isInteger(nInputs) && nInputs > 0;
    if (!isPositiveInteger) {
      throw new Error(
        `Perceptron: nInputs must be a positive integer, got ${nInputs}`
      );
    }
    this.weights = new Array(nInputs).fill(0);
    this.bias = 0;
  }

  // ─── Forward pass ────────────────────────────────────────────────────────────
  // Accumulates z = bias + Σ(wᵢ·xᵢ) and thresholds it with a step function:
  // returns 1 when z is strictly positive, otherwise 0.
  predict(inputs) {
    validateArray(inputs, this.weights.length, "Perceptron.predict");
    let activation = this.bias;
    for (const [idx, weight] of this.weights.entries()) {
      activation += weight * inputs[idx];
    }
    if (activation > 0) {
      return 1;
    }
    return 0;
  }

  // ─── Training step ───────────────────────────────────────────────────────────
  // One application of the perceptron learning rule to an (inputs, target) pair.
  //
  //   error = target − predict(inputs)
  //   wᵢ   ← wᵢ + lr · error · xᵢ
  //   bias ← bias + lr · error
  //
  // `target` must be exactly 0 or 1 and `lr` must be positive, otherwise an
  // Error is thrown. Returns the error (0 means the sample was already
  // classified correctly and nothing was changed).
  train(inputs, target, lr) {
    validateArray(inputs, this.weights.length, "Perceptron.train");
    validateNumber(target, "Perceptron.train");
    validateNumber(lr, "Perceptron.train");
    const isBinaryTarget = target === 0 || target === 1;
    if (!isBinaryTarget) {
      throw new Error(
        `Perceptron.train: target must be 0 or 1, got ${target}`
      );
    }
    if (lr <= 0) {
      throw new Error(
        `Perceptron.train: learning rate must be positive, got ${lr}`
      );
    }
    const error = target - this.predict(inputs);
    if (error === 0) {
      return error;
    }
    // lr · error is shared by every weight update and by the bias update.
    const step = lr * error;
    for (let idx = 0; idx < this.weights.length; idx++) {
      this.weights[idx] += step * inputs[idx];
    }
    this.bias += step;
    return error;
  }
};
|
|
2661
|
+
|
|
2662
|
+
// src/LinearRegression.ts
|
|
2663
|
+
// Dense matrix product C = A · B for row-major arrays of arrays.
// The shared inner dimension is taken from A's first row and the output width
// from B's first row; no shape checks are performed here.
function matMul2(A, B) {
  const rowCount = A.length;
  const innerDim = A[0].length;
  const colCount = B[0].length;
  const product = [];
  for (let r = 0; r < rowCount; r++) {
    const outRow = [];
    for (let c = 0; c < colCount; c++) {
      let acc = 0;
      for (let t = 0; t < innerDim; t++) {
        acc += A[r][t] * B[t][c];
      }
      outRow.push(acc);
    }
    product.push(outRow);
  }
  return product;
}
|
|
2677
|
+
// Returns the transpose of a row-major matrix as a new array of arrays.
// The column count is read from the first row of A.
function transpose2(A) {
  const rowCount = A.length;
  const colCount = A[0].length;
  return Array.from({ length: colCount }, (_col, c) =>
    Array.from({ length: rowCount }, (_row, r) => A[r][c])
  );
}
|
|
2688
|
+
// Inverts a square matrix with Gauss-Jordan elimination and partial pivoting.
// Works on an augmented copy [M | I], so M itself is never modified.
// Returns the inverse, or null when a pivot's magnitude falls below 1e-12
// (the matrix is treated as singular).
function invertMatrix(M) {
  const size = M.length;
  const width = 2 * size;

  // Build [M | I] row by row.
  const work = M.map((row, r) => {
    const identityRow = new Array(size).fill(0);
    identityRow[r] = 1;
    return [...row, ...identityRow];
  });

  for (let c = 0; c < size; c++) {
    // Partial pivoting: pick the row at or below the diagonal whose entry in
    // this column has the largest magnitude.
    let pivotRow = c;
    for (let r = c + 1; r < size; r++) {
      if (Math.abs(work[r][c]) > Math.abs(work[pivotRow][c])) {
        pivotRow = r;
      }
    }
    if (pivotRow !== c) {
      const held = work[c];
      work[c] = work[pivotRow];
      work[pivotRow] = held;
    }

    const pivotLine = work[c];
    const pivot = pivotLine[c];
    if (Math.abs(pivot) < 1e-12) {
      return null;
    }

    // Scale the pivot row so the diagonal entry becomes 1.
    for (let j = 0; j < width; j++) {
      pivotLine[j] = pivotLine[j] / pivot;
    }

    // Clear this column in every other row.
    for (let r = 0; r < size; r++) {
      if (r !== c) {
        const target = work[r];
        const factor = target[c];
        for (let j = 0; j < width; j++) {
          target[j] = target[j] - factor * pivotLine[j];
        }
      }
    }
  }

  // The right half of the augmented matrix now holds the inverse.
  return work.map((row) => row.slice(size));
}
|
|
2718
|
+
// Returns a copy of X in which every row has a trailing 1 appended, so that a
// bias term can be carried as the last weight. X itself is left untouched.
function augment(X) {
  return X.map((features) => {
    const withBias = Array.from(features);
    withBias.push(1);
    return withBias;
  });
}
|
|
2721
|
+
var LinearRegression = class {
  // Linear model ŷ = Σ wᵢ·xᵢ + bias, fitted either in closed form
  // (fitNormal) or by batch gradient descent (fitGD).
  constructor() {
    // weights = [w₁, w₂, ..., wₙ, bias]; empty until a fit succeeds.
    this.weights = [];
    this._nFeatures = 0;
  }

  // ─── Normal Equation ───────────────────────────────────────────────────────
  // W = (XᵀX)⁻¹Xᵀy, with X augmented by a column of ones so the bias is solved
  // jointly with the feature weights.
  //
  // Throws when X is empty, when X and y differ in length, when rows of X have
  // differing lengths, or when XᵀX is singular. Model state is only replaced
  // after the solve succeeds, so a failed call leaves a previously fitted model
  // fully usable.
  fitNormal(X, y) {
    if (X.length === 0) throw new Error("LinearRegression.fitNormal: X is empty");
    if (X.length !== y.length) {
      throw new Error(
        `LinearRegression.fitNormal: X has ${X.length} rows but y has ${y.length} elements`
      );
    }
    const nFeatures = X[0].length;
    // A ragged X would otherwise propagate undefined/NaN into the weights.
    const badRow = X.findIndex((row) => row.length !== nFeatures);
    if (badRow !== -1) {
      throw new Error(
        `LinearRegression.fitNormal: row ${badRow} has ${X[badRow].length} features, expected ${nFeatures}`
      );
    }
    const Xa = augment(X);
    const XaT = transpose2(Xa);
    const XaTXaInv = invertMatrix(matMul2(XaT, Xa));
    if (XaTXaInv === null) {
      throw new Error(
        "LinearRegression.fitNormal: X\u1D40X is singular \u2014 features may be linearly dependent"
      );
    }
    const yCol = y.map((v) => [v]);
    const W = matMul2(XaTXaInv, matMul2(XaT, yCol));
    // Commit both fields together so they can never disagree.
    this._nFeatures = nFeatures;
    this.weights = W.map((row) => row[0]);
  }

  // ─── Gradient Descent ──────────────────────────────────────────────────────
  // Minimises MSE = (1/m) Σ (ŷᵢ − yᵢ)² with full-batch updates, starting from
  // all-zero weights:
  //
  //   ŷ  = Xa · W
  //   dW = (2/m) · Xaᵀ · (ŷ − y)
  //   W  ← W − lr · dW
  //
  // Throws on empty/mismatched/ragged input, non-positive `lr`, or `epochs`
  // that is not a positive integer. Returns the MSE measured at the start of
  // every epoch (one entry per epoch).
  fitGD(X, y, lr, epochs) {
    if (X.length === 0) throw new Error("LinearRegression.fitGD: X is empty");
    if (X.length !== y.length) {
      throw new Error(
        `LinearRegression.fitGD: X has ${X.length} rows but y has ${y.length} elements`
      );
    }
    validateNumber(lr, "LinearRegression.fitGD");
    if (lr <= 0) throw new Error("LinearRegression.fitGD: lr must be positive");
    if (!Number.isInteger(epochs) || epochs <= 0) {
      throw new Error("LinearRegression.fitGD: epochs must be a positive integer");
    }
    const nFeatures = X[0].length;
    const badRow = X.findIndex((row) => row.length !== nFeatures);
    if (badRow !== -1) {
      throw new Error(
        `LinearRegression.fitGD: row ${badRow} has ${X[badRow].length} features, expected ${nFeatures}`
      );
    }
    const m = X.length;
    const Xa = augment(X);
    const weights = new Array(nFeatures + 1).fill(0);
    const lossHistory = [];
    for (let epoch = 0; epoch < epochs; epoch++) {
      // Residuals are computed once per epoch, before any weight is touched,
      // so every component of the gradient sees the same predictions.
      const residuals = Xa.map(
        (row, i) => row.reduce((s, x, j) => s + x * weights[j], 0) - y[i]
      );
      lossHistory.push(residuals.reduce((s, r) => s + r * r, 0) / m);
      for (let j = 0; j < weights.length; j++) {
        let grad = 0;
        for (let i = 0; i < m; i++) {
          grad += Xa[i][j] * residuals[i];
        }
        weights[j] -= lr * (2 / m) * grad;
      }
    }
    this._nFeatures = nFeatures;
    this.weights = weights;
    return lossHistory;
  }

  // ─── Inference ─────────────────────────────────────────────────────────────
  // ŷ = Σ wᵢ·xᵢ + bias   (bias is the last entry of `weights`).
  // Throws if the model is unfitted or `x` has the wrong number of features.
  predict(x) {
    if (this.weights.length === 0) {
      throw new Error("LinearRegression.predict: model has not been fitted yet");
    }
    if (x.length !== this._nFeatures) {
      throw new Error(
        `LinearRegression.predict: expected ${this._nFeatures} features, got ${x.length}`
      );
    }
    let out = this.weights[this._nFeatures];
    for (let i = 0; i < this._nFeatures; i++) {
      out += this.weights[i] * x[i];
    }
    return out;
  }

  // ─── Introspection ─────────────────────────────────────────────────────────
  // Returns { weights, bias } as separate values; `weights` is a copy.
  // Throws if the model is unfitted.
  getCoefficients() {
    if (this.weights.length === 0) {
      throw new Error(
        "LinearRegression.getCoefficients: model has not been fitted yet"
      );
    }
    return {
      weights: this.weights.slice(0, this._nFeatures),
      bias: this.weights[this._nFeatures]
    };
  }
};
|
|
2824
|
+
|
|
2825
|
+
// src/LogisticRegression.ts
|
|
2826
|
+
// Logistic function σ(z) = 1 / (1 + e^(−z)).
function sigmoid5(z) {
  const decay = Math.exp(-z);
  return 1 / (decay + 1);
}
|
|
2829
|
+
// Binary cross-entropy for one sample:
//   −( target·log(p) + (1 − target)·log(1 − p) )
// The prediction is first clamped into [1e-15, 1 − 1e-15] so that neither
// logarithm is evaluated at 0.
function bce(target, pred) {
  const eps = 1e-15;
  const upper = 1 - eps;
  let p = pred;
  if (p > upper) {
    p = upper;
  }
  if (p < eps) {
    p = eps;
  }
  const positiveTerm = target * Math.log(p);
  const negativeTerm = (1 - target) * Math.log(1 - p);
  return -(positiveTerm + negativeTerm);
}
|
|
2834
|
+
var LogisticRegression = class {
  // Binary classifier: P(y=1|x) = σ(Σ wⱼ·xⱼ + bias).
  constructor() {
    this.weights = [];
    this.bias = 0;
    this._nFeatures = 0;
  }

  // ─── Train ────────────────────────────────────────────────────────────────
  // Runs `epochs` passes of per-sample (stochastic) gradient descent over the
  // dataset, visiting samples in the order given.
  //
  // Weights are randomly initialised (and the bias reset to 0) only when the
  // current weight vector does not match the feature count of X; otherwise
  // training continues from the existing parameters.
  //
  // Returns an array with the mean BCE loss of each epoch.
  train(X, y, lr, epochs) {
    const sampleCount = X.length;
    if (sampleCount === 0) throw new Error("LogisticRegression.train: X is empty");
    if (sampleCount !== y.length) {
      throw new Error(
        `LogisticRegression.train: X has ${X.length} rows but y has ${y.length} labels`
      );
    }
    validateNumber(lr, "LogisticRegression.train");
    if (lr <= 0) throw new Error("LogisticRegression.train: lr must be positive");
    const epochsValid = Number.isInteger(epochs) && epochs > 0;
    if (!epochsValid) {
      throw new Error("LogisticRegression.train: epochs must be a positive integer");
    }

    const featureCount = X[0].length;
    this._nFeatures = featureCount;
    if (this.weights.length !== featureCount) {
      // Uniform draw in (−limit, limit), one value per feature.
      const limit = Math.sqrt(2 / featureCount);
      const fresh = [];
      for (let j = 0; j < featureCount; j++) {
        fresh.push((Math.random() * 2 - 1) * limit);
      }
      this.weights = fresh;
      this.bias = 0;
    }

    const w = this.weights;
    const lossHistory = [];
    let epoch = 0;
    while (epoch < epochs) {
      let runningLoss = 0;
      for (let i = 0; i < sampleCount; i++) {
        const sample = X[i];
        const label = y[i];

        let logit = this.bias;
        for (let j = 0; j < featureCount; j++) {
          logit += w[j] * sample[j];
        }
        const prob = sigmoid5(logit);
        runningLoss += bce(label, prob);

        // lr · (label − prob) scales every weight update and the bias update.
        const step = lr * (label - prob);
        for (let j = 0; j < featureCount; j++) {
          w[j] += step * sample[j];
        }
        this.bias += step;
      }
      lossHistory.push(runningLoss / sampleCount);
      epoch += 1;
    }
    return lossHistory;
  }

  // ─── Predict (probability) ────────────────────────────────────────────────
  // Returns σ(Σ wⱼ·xⱼ + bias). Throws if the model has no weights yet.
  predict(x) {
    if (this.weights.length === 0) {
      throw new Error("LogisticRegression.predict: model has not been trained yet");
    }
    validateArray(x, this._nFeatures, "LogisticRegression.predict");
    let logit = this.bias;
    for (let j = 0; j < this._nFeatures; j++) {
      logit += this.weights[j] * x[j];
    }
    return sigmoid5(logit);
  }

  // ─── Classify (hard label) ────────────────────────────────────────────────
  // Thresholds predict(x) at 0.5: returns 1 when the probability is at least
  // 0.5, otherwise 0.
  classify(x) {
    const probability = this.predict(x);
    if (probability >= 0.5) {
      return 1;
    }
    return 0;
  }
};
|
|
2903
|
+
var SoftmaxRegression = class {
  // Multiclass linear classifier: one weight vector and bias per class,
  // turned into probabilities with softmax.
  constructor() {
    // weights[k][j] = weight for class k, feature j
    this.weights = [];
    // biases[k] = bias for class k
    this.biases = [];
    this._nFeatures = 0;
    this._nClasses = 0;
  }

  // ─── Softmax helper ──────────────────────────────────────────────────────
  // Subtracts the largest score before exponentiating so exp() cannot
  // overflow; the resulting probabilities are unchanged by that shift.
  _softmax(scores) {
    const maxScore = Math.max(...scores);
    const exps = scores.map((s) => Math.exp(s - maxScore));
    const sum = exps.reduce((a, b) => a + b, 0);
    return exps.map((e) => e / sum);
  }

  // ─── Train ────────────────────────────────────────────────────────────────
  // Per-sample gradient descent on cross-entropy for `epochs` passes.
  // Parameters are re-initialised on every call (random weights, zero biases).
  //
  // `y` must contain integer class labels 0..K-1; the number of classes is
  // inferred as max(y) + 1 and must be at least 2.
  //
  // Throws on empty/mismatched input, non-positive `lr`, invalid `epochs`, or
  // any label that is not a non-negative integer. All validation happens
  // before the model is modified, so a rejected call leaves a previously
  // trained model intact.
  //
  // Returns the mean cross-entropy loss of each epoch.
  train(X, y, lr, epochs) {
    if (X.length === 0) throw new Error("SoftmaxRegression.train: X is empty");
    if (X.length !== y.length) {
      throw new Error(
        `SoftmaxRegression.train: X has ${X.length} rows but y has ${y.length} labels`
      );
    }
    validateNumber(lr, "SoftmaxRegression.train");
    if (lr <= 0) throw new Error("SoftmaxRegression.train: lr must be positive");
    if (!Number.isInteger(epochs) || epochs <= 0) {
      throw new Error("SoftmaxRegression.train: epochs must be a positive integer");
    }

    // One pass over the labels both validates them and finds the largest one.
    // A loop is used instead of Math.max(...y) because spreading a large label
    // array into function arguments can exceed the engine's argument limit.
    let maxLabel = -1;
    for (const label of y) {
      if (!Number.isInteger(label) || label < 0) {
        throw new Error(
          `SoftmaxRegression.train: labels must be non-negative integers, got ${label}`
        );
      }
      if (label > maxLabel) maxLabel = label;
    }
    const nClasses = maxLabel + 1;
    if (nClasses < 2) {
      throw new Error("SoftmaxRegression.train: need at least 2 classes in y");
    }
    const nFeatures = X[0].length;

    // Everything is validated: now replace the model state.
    this._nFeatures = nFeatures;
    this._nClasses = nClasses;
    const limit = Math.sqrt(2 / nFeatures);
    this.weights = Array.from(
      { length: nClasses },
      () => Array.from({ length: nFeatures }, () => (Math.random() * 2 - 1) * limit)
    );
    this.biases = new Array(nClasses).fill(0);

    const lossHistory = [];
    for (let epoch = 0; epoch < epochs; epoch++) {
      let epochLoss = 0;
      for (let i = 0; i < X.length; i++) {
        const xi = X[i];
        const trueClass = y[i];
        const scores = this.weights.map((wk, k) => {
          let s = this.biases[k];
          for (let j = 0; j < nFeatures; j++) s += wk[j] * xi[j];
          return s;
        });
        const probs = this._softmax(scores);
        // Floor the probability so log() never sees 0.
        epochLoss += -Math.log(Math.max(probs[trueClass], 1e-15));
        for (let k = 0; k < nClasses; k++) {
          // Gradient of cross-entropy w.r.t. score k: p_k − 1[k = trueClass].
          const delta = probs[k] - (k === trueClass ? 1 : 0);
          for (let j = 0; j < nFeatures; j++) {
            this.weights[k][j] -= lr * delta * xi[j];
          }
          this.biases[k] -= lr * delta;
        }
      }
      lossHistory.push(epochLoss / X.length);
    }
    return lossHistory;
  }

  // ─── Predict (class probabilities) ───────────────────────────────────────
  // Returns an array of K probabilities, index = class label.
  // Throws if the model has not been trained.
  predictProba(x) {
    if (this.weights.length === 0) {
      throw new Error("SoftmaxRegression.predictProba: model has not been trained yet");
    }
    validateArray(x, this._nFeatures, "SoftmaxRegression.predictProba");
    const scores = this.weights.map((wk, k) => {
      let s = this.biases[k];
      for (let j = 0; j < this._nFeatures; j++) s += wk[j] * x[j];
      return s;
    });
    return this._softmax(scores);
  }

  // ─── Classify (argmax) ────────────────────────────────────────────────────
  // Returns the index of the most probable class (first one on ties).
  predict(x) {
    const probs = this.predictProba(x);
    return probs.indexOf(Math.max(...probs));
  }
};
|
|
2989
|
+
|
|
2990
|
+
// src/NaiveBayes.ts
|
|
2991
|
+
var GaussianNaiveBayes = class {
  // Naive Bayes classifier that models every feature, per class, as an
  // independent Gaussian.
  constructor() {
    // Per-class, per-feature statistics (class label → array of length nFeatures)
    this._means = /* @__PURE__ */ new Map();
    this._variances = /* @__PURE__ */ new Map();
    // Log prior: log P(class)
    this._logPriors = /* @__PURE__ */ new Map();
    this._classes = [];
    this._nFeatures = 0;
  }

  // ─── Fit ───────────────────────────────────────────────────────────────────
  // Computes μ, σ² and log π for every class present in `y`.
  // Variance is floored at 1e-9 so a feature that is constant within a class
  // does not cause a division by zero later.
  //
  // Statistics are accumulated in fresh maps and swapped in at the end, so a
  // refit never keeps entries for classes that are absent from the new labels.
  //
  // Throws when X is empty or X and y differ in length.
  fit(X, y) {
    if (X.length === 0) throw new Error("GaussianNaiveBayes.fit: X is empty");
    if (X.length !== y.length) {
      throw new Error(
        `GaussianNaiveBayes.fit: X has ${X.length} rows but y has ${y.length} labels`
      );
    }
    const nFeatures = X[0].length;
    const m = X.length;

    // Group rows by label in a single pass (rows keep their original order
    // within each class).
    const rowsByClass = /* @__PURE__ */ new Map();
    for (let i = 0; i < m; i++) {
      const bucket = rowsByClass.get(y[i]);
      if (bucket === void 0) rowsByClass.set(y[i], [X[i]]);
      else bucket.push(X[i]);
    }
    const classes = [...rowsByClass.keys()].sort((a, b) => a - b);

    const allMeans = /* @__PURE__ */ new Map();
    const allVariances = /* @__PURE__ */ new Map();
    const logPriors = /* @__PURE__ */ new Map();
    for (const c of classes) {
      const rows = rowsByClass.get(c);
      const count = rows.length;
      logPriors.set(c, Math.log(count / m));

      const means = new Array(nFeatures).fill(0);
      for (const row of rows) {
        for (let j = 0; j < nFeatures; j++) means[j] += row[j];
      }
      for (let j = 0; j < nFeatures; j++) means[j] /= count;
      allMeans.set(c, means);

      // Population variance (divides by count, not count − 1).
      const variances = new Array(nFeatures).fill(0);
      for (const row of rows) {
        for (let j = 0; j < nFeatures; j++) {
          const diff = row[j] - means[j];
          variances[j] += diff * diff;
        }
      }
      for (let j = 0; j < nFeatures; j++) {
        variances[j] = Math.max(variances[j] / count, 1e-9);
      }
      allVariances.set(c, variances);
    }

    this._nFeatures = nFeatures;
    this._classes = classes;
    this._means = allMeans;
    this._variances = allVariances;
    this._logPriors = logPriors;
  }

  // ─── Log-likelihood of a single feature value under a Gaussian ─────────────
  // log P(x | μ, σ²) = −0.5·log(2πσ²) − (x−μ)² / (2σ²)
  _logGaussian(x, mean, variance) {
    return -0.5 * Math.log(2 * Math.PI * variance) - (x - mean) ** 2 / (2 * variance);
  }

  // ─── Log-scores per class ────────────────────────────────────────────────
  // log P(class|x) ∝ log P(class) + Σⱼ log P(xⱼ|class)
  // Returns a Map from class label to unnormalised log-posterior.
  // Throws if the model is unfitted or `x` has the wrong number of features.
  _logScores(x) {
    if (this._classes.length === 0) {
      throw new Error("GaussianNaiveBayes: model has not been fitted yet");
    }
    if (x.length !== this._nFeatures) {
      throw new Error(
        `GaussianNaiveBayes: expected ${this._nFeatures} features, got ${x.length}`
      );
    }
    const scores = /* @__PURE__ */ new Map();
    for (const c of this._classes) {
      const means = this._means.get(c);
      const variances = this._variances.get(c);
      const logPrior = this._logPriors.get(c);
      let logLikelihood = 0;
      for (let j = 0; j < this._nFeatures; j++) {
        logLikelihood += this._logGaussian(x[j], means[j], variances[j]);
      }
      scores.set(c, logPrior + logLikelihood);
    }
    return scores;
  }

  // ─── Predict (argmax class) ──────────────────────────────────────────────
  // Returns the class with the highest log-posterior. Comparing log-scores
  // directly is enough: exp() is monotonic, so the argmax is the same.
  predict(x) {
    const scores = this._logScores(x);
    let bestClass = this._classes[0];
    let bestScore = -Infinity;
    for (const [c, s] of scores) {
      if (s > bestScore) {
        bestScore = s;
        bestClass = c;
      }
    }
    return bestClass;
  }

  // ─── Predict probabilities ────────────────────────────────────────────────
  // Normalises the log-scores into probabilities with the log-sum-exp trick:
  //
  //   log Σₖ exp(sₖ) = maxScore + log Σₖ exp(sₖ − maxScore)
  //   P(c|x)         = exp(score[c] − log Σₖ exp(sₖ))
  //
  // Returns a Map from class label to probability.
  predictProba(x) {
    const logScores = this._logScores(x);
    const scores = [...logScores.values()];
    const maxScore = Math.max(...scores);
    const logSumExp = maxScore + Math.log(
      scores.reduce((sum, s) => sum + Math.exp(s - maxScore), 0)
    );
    const proba = /* @__PURE__ */ new Map();
    for (const [c, s] of logScores) {
      proba.set(c, Math.exp(s - logSumExp));
    }
    return proba;
  }
};
|
|
3104
|
+
|
|
3105
|
+
// src/DecisionTree.ts
|
|
3106
|
+
var DecisionTree = class {
  // Binary decision tree over numeric features, for classification (Gini
  // impurity, majority-vote leaves) or regression (MSE impurity, mean leaves).
  //
  // options.maxDepth        – maximum depth of the tree (default 10, must be > 0)
  // options.minSamplesSplit – minimum samples needed to split a node (default 2)
  // options.task            – "classification" (default) or "regression"
  //
  // Throws on a non-positive maxDepth, a minSamplesSplit below 2, or an
  // unknown task.
  constructor(options) {
    this._root = null;
    this._maxDepth = options?.maxDepth ?? 10;
    this._minSamplesSplit = options?.minSamplesSplit ?? 2;
    this._task = options?.task ?? "classification";
    if (this._maxDepth <= 0) {
      throw new Error("DecisionTree: maxDepth must be positive");
    }
    if (this._minSamplesSplit < 2) {
      throw new Error("DecisionTree: minSamplesSplit must be at least 2");
    }
    // An unrecognised task would otherwise be scored with MSE in _impurity but
    // resolved by majority vote in _leafValue, so reject it explicitly.
    if (this._task !== "classification" && this._task !== "regression") {
      throw new Error(
        `DecisionTree: task must be "classification" or "regression", got ${this._task}`
      );
    }
  }

  // ─── Gini impurity ─────────────────────────────────────────────────────────
  // G = 1 − Σₖ pₖ²
  // G = 0 when all samples share one class (perfectly pure node).
  // G = 0.5 for a binary node with equal class distribution.
  // An empty label list is treated as pure (0).
  _gini(y) {
    if (y.length === 0) return 0;
    const counts = /* @__PURE__ */ new Map();
    for (const label of y) counts.set(label, (counts.get(label) ?? 0) + 1);
    let g = 1;
    for (const count of counts.values()) {
      const p = count / y.length;
      g -= p * p;
    }
    return g;
  }

  // ─── Mean Squared Error (regression impurity) ─────────────────────────────
  // MSE = (1/n) Σ (yᵢ − ȳ)²; 0 for an empty list.
  _mse(y) {
    if (y.length === 0) return 0;
    const mean = y.reduce((a, b) => a + b, 0) / y.length;
    return y.reduce((acc, v) => acc + (v - mean) ** 2, 0) / y.length;
  }

  // ─── Impurity selector ─────────────────────────────────────────────────────
  _impurity(y) {
    return this._task === "classification" ? this._gini(y) : this._mse(y);
  }

  // ─── Leaf value ────────────────────────────────────────────────────────────
  // Classification: majority class (first one encountered wins ties).
  // Regression: mean of the targets.
  _leafValue(y) {
    if (this._task === "regression") {
      return y.reduce((a, b) => a + b, 0) / y.length;
    }
    const counts = /* @__PURE__ */ new Map();
    for (const label of y) counts.set(label, (counts.get(label) ?? 0) + 1);
    let bestClass = y[0];
    let bestCount = 0;
    for (const [cls, cnt] of counts) {
      if (cnt > bestCount) {
        bestCount = cnt;
        bestClass = cls;
      }
    }
    return bestClass;
  }

  // ─── Best split search ─────────────────────────────────────────────────────
  // Brute force: for every feature, try the midpoint between each pair of
  // consecutive distinct values as a threshold.
  // Returns { featureIndex, threshold } for the split with the lowest weighted
  // child impurity, or null when no split is strictly better than the parent.
  _bestSplit(X, y) {
    const nFeatures = X[0].length;
    const n = y.length;
    let bestImpurity = Infinity;
    let bestSplit = null;
    const parentImpurity = this._impurity(y);
    for (let j = 0; j < nFeatures; j++) {
      const values = [...new Set(X.map((row) => row[j]))].sort((a, b) => a - b);
      for (let vi = 0; vi < values.length - 1; vi++) {
        const threshold = (values[vi] + values[vi + 1]) / 2;
        const leftY = [];
        const rightY = [];
        for (let i = 0; i < n; i++) {
          if (X[i][j] <= threshold) leftY.push(y[i]);
          else rightY.push(y[i]);
        }
        // A split that leaves one side empty separates nothing.
        if (leftY.length === 0 || rightY.length === 0) continue;
        const weightedImpurity = leftY.length / n * this._impurity(leftY) + rightY.length / n * this._impurity(rightY);
        if (weightedImpurity < bestImpurity && weightedImpurity < parentImpurity) {
          bestImpurity = weightedImpurity;
          bestSplit = { featureIndex: j, threshold };
        }
      }
    }
    return bestSplit;
  }

  // ─── Recursive tree builder ────────────────────────────────────────────────
  // Produces a leaf when the depth limit is reached, the node has too few
  // samples, all targets are identical, or no split improves impurity.
  // Otherwise partitions the samples (≤ threshold goes left) and recurses.
  _buildNode(X, y, depth) {
    const allSame = y.every((v) => v === y[0]);
    if (depth >= this._maxDepth || y.length < this._minSamplesSplit || allSame) {
      return { isLeaf: true, value: this._leafValue(y) };
    }
    const split = this._bestSplit(X, y);
    if (split === null) {
      return { isLeaf: true, value: this._leafValue(y) };
    }
    const { featureIndex, threshold } = split;
    const leftX = [];
    const leftY = [];
    const rightX = [];
    const rightY = [];
    for (let i = 0; i < y.length; i++) {
      if (X[i][featureIndex] <= threshold) {
        leftX.push(X[i]);
        leftY.push(y[i]);
      } else {
        rightX.push(X[i]);
        rightY.push(y[i]);
      }
    }
    return {
      isLeaf: false,
      featureIndex,
      threshold,
      left: this._buildNode(leftX, leftY, depth + 1),
      right: this._buildNode(rightX, rightY, depth + 1)
    };
  }

  // ─── Fit ──────────────────────────────────────────────────────────────────
  // Builds the tree from scratch. Throws when X is empty or X and y differ in
  // length.
  fit(X, y) {
    if (X.length === 0) throw new Error("DecisionTree.fit: X is empty");
    if (X.length !== y.length) {
      throw new Error(
        `DecisionTree.fit: X has ${X.length} rows but y has ${y.length} labels`
      );
    }
    this._root = this._buildNode(X, y, 0);
  }

  // ─── Traverse a single sample ─────────────────────────────────────────────
  // Walks from `node` down to a leaf, going left when the tested feature is
  // ≤ the node's threshold, and returns the leaf's value.
  _traverse(node, x) {
    if (node.isLeaf) return node.value;
    if (x[node.featureIndex] <= node.threshold) {
      return this._traverse(node.left, x);
    }
    return this._traverse(node.right, x);
  }

  // ─── Predict single sample ────────────────────────────────────────────────
  // Throws if fit() has not been called.
  predict(x) {
    if (this._root === null) {
      throw new Error("DecisionTree.predict: model has not been fitted yet");
    }
    return this._traverse(this._root, x);
  }

  // ─── Predict batch ────────────────────────────────────────────────────────
  // Applies predict() to every row of X and returns the results in order.
  predictBatch(X) {
    return X.map((x) => this.predict(x));
  }
};
|
|
3254
|
+
|
|
3255
|
+
// src/KMeans.ts
|
|
3256
|
+
var KMeans = class {
  /**
   * K-Means clustering: K-Means++ seeding followed by Lloyd iterations.
   *
   * @param {number} k - Number of clusters; must be a positive integer.
   * @param {{ maxIter?: number, seed?: number }} [options]
   *   `maxIter` caps the number of Lloyd iterations (default 300).
   *   `seed`, when provided, makes every random draw reproducible.
   * @throws {Error} If `k` is not a positive integer.
   */
  constructor(k, options = {}) {
    if (!Number.isInteger(k) || k < 1) {
      throw new Error(`KMeans: k must be a positive integer, got ${k}`);
    }
    this._k = k;
    this._maxIter = options.maxIter ?? 300;
    this.centroids = [];
    if (options.seed !== void 0) {
      // Mulberry32-style generator. The state is wrapped to 32 bits on every
      // step: every later operation only looks at the value modulo 2^32, so
      // the output stream is unchanged, but the state can no longer grow into
      // the range where a double stops representing integers exactly.
      let state = options.seed >>> 0;
      this._rng = () => {
        state = (state + 1831565813) | 0;
        let t = Math.imul(state ^ (state >>> 15), 1 | state);
        t ^= t + Math.imul(t ^ (t >>> 7), 61 | t);
        return ((t ^ (t >>> 14)) >>> 0) / 4294967296;
      };
    } else {
      this._rng = () => Math.random();
    }
  }

  // ── fit ────────────────────────────────────────────────────────────────────
  /**
   * Learns `k` centroids from X and stores them in `this.centroids`.
   * Iterates until no centroid coordinate moves by more than 1e-10 or
   * `maxIter` iterations have run.
   *
   * @param {number[][]} X - Samples; every row must have the same length.
   * @throws {Error} If X is empty, if k exceeds the number of samples, or if
   *   the rows of X do not all have the same number of features.
   */
  fit(X) {
    if (!X || X.length === 0) {
      throw new Error("KMeans.fit: dataset X must be non-empty");
    }
    const n = X.length;
    const d = X[0].length;
    if (this._k > n) {
      throw new Error(`KMeans.fit: k (${this._k}) cannot exceed number of samples (${n})`);
    }
    // Ragged rows would silently turn centroid coordinates into NaN.
    for (let i = 1; i < n; i++) {
      if (X[i].length !== d) {
        throw new Error(
          `KMeans.fit: sample ${i} has ${X[i].length} features, expected ${d}`
        );
      }
    }

    this._seedCentroids(X);

    const assignments = new Int32Array(n);
    for (let iter = 0; iter < this._maxIter; iter++) {
      // Assignment step: attach every sample to its nearest centroid.
      for (let i = 0; i < n; i++) {
        assignments[i] = this._nearestCentroid(X[i]);
      }

      // Update step: accumulate per-cluster coordinate sums and sizes.
      const sums = Array.from({ length: this._k }, () => new Array(d).fill(0));
      const counts = new Int32Array(this._k);
      for (let i = 0; i < n; i++) {
        const c = assignments[i];
        counts[c]++;
        for (let j = 0; j < d; j++) sums[c][j] += X[i][j];
      }

      let moved = false;
      for (let c = 0; c < this._k; c++) {
        // An empty cluster keeps its previous centroid.
        if (counts[c] === 0) continue;
        for (let j = 0; j < d; j++) {
          const newVal = sums[c][j] / counts[c];
          if (Math.abs(newVal - this.centroids[c][j]) > 1e-10) moved = true;
          this.centroids[c][j] = newVal;
        }
      }
      if (!moved) break;
    }
  }

  // ── predict ────────────────────────────────────────────────────────────────
  /**
   * Returns the index of the centroid nearest to `x`.
   *
   * @param {number[]} x - A single sample.
   * @returns {number} Cluster index in [0, k).
   * @throws {Error} If called before fit().
   */
  predict(x) {
    if (this.centroids.length === 0) {
      throw new Error("KMeans.predict: call fit() before predict()");
    }
    return this._nearestCentroid(x);
  }

  // ── predictBatch ──────────────────────────────────────────────────────────
  /**
   * Assigns each sample in X to a cluster.
   *
   * @param {number[][]} X - Samples.
   * @returns {number[]} One cluster index per sample.
   * @throws {Error} If called before fit().
   */
  predictBatch(X) {
    return X.map((x) => this.predict(x));
  }

  // ── inertia ───────────────────────────────────────────────────────────────
  /**
   * J = Σᵢ d(xᵢ, μ_{cᵢ})² — the sum of squared distances from every sample
   * to its nearest centroid. Lower means tighter clusters.
   *
   * @param {number[][]} X - Samples.
   * @returns {number} The inertia of X under the fitted centroids.
   * @throws {Error} If called before fit().
   */
  inertia(X) {
    if (this.centroids.length === 0) {
      throw new Error("KMeans.inertia: call fit() before inertia()");
    }
    return X.reduce((sum, x) => sum + this._minDistSq(x), 0);
  }

  // ── Private helpers ────────────────────────────────────────────────────────

  // K-Means++ seeding: the first centroid is a uniformly random sample; each
  // further centroid is drawn with probability proportional to its squared
  // distance from the nearest centroid chosen so far.
  _seedCentroids(X) {
    const n = X.length;
    this.centroids = [];
    const firstIdx = Math.floor(this._rng() * n);
    this.centroids.push([...X[firstIdx]]);

    for (let c = 1; c < this._k; c++) {
      const dists = X.map((x) => this._minDistSq(x));
      const total = dists.reduce((s, v) => s + v, 0);
      let threshold = this._rng() * total;

      // `fallback` is used when the walk below never crosses the threshold:
      // either every distance is zero (index 0 is kept) or floating-point
      // residue is left over (the last positive-distance sample is kept).
      let fallback = 0;
      let chosen = -1;
      for (let i = 0; i < n; i++) {
        // A zero-distance sample coincides with an existing centroid and has
        // selection probability zero, so it must never be picked.
        if (!(dists[i] > 0)) continue;
        fallback = i;
        threshold -= dists[i];
        if (threshold <= 0) {
          chosen = i;
          break;
        }
      }
      if (chosen === -1) chosen = fallback;
      this.centroids.push([...X[chosen]]);
    }
  }

  // Squared Euclidean distance between two equally long vectors.
  _euclideanSq(a, b) {
    let s = 0;
    for (let i = 0; i < a.length; i++) s += (a[i] - b[i]) ** 2;
    return s;
  }

  // Squared distance from x to its nearest centroid (Infinity if there are none).
  _minDistSq(x) {
    let min = Infinity;
    for (const c of this.centroids) {
      const d = this._euclideanSq(x, c);
      if (d < min) min = d;
    }
    return min;
  }

  // Index of the centroid nearest to x; ties go to the lowest index.
  _nearestCentroid(x) {
    let best = 0;
    let bestDist = Infinity;
    for (let c = 0; c < this.centroids.length; c++) {
      const d = this._euclideanSq(x, this.centroids[c]);
      if (d < bestDist) {
        bestDist = d;
        best = c;
      }
    }
    return best;
  }
};
|
|
3379
|
+
|
|
3380
|
+
// src/PCA.ts
|
|
3381
|
+
var PCA = class {
  /**
   * Principal Component Analysis via power iteration with deflation.
   *
   * @param {number} nComponents - Number of principal components to keep;
   *   must be a positive integer.
   * @throws {Error} If `nComponents` is not a positive integer.
   */
  constructor(nComponents) {
    if (!Number.isInteger(nComponents) || nComponents < 1) {
      throw new Error(`PCA: nComponents must be a positive integer, got ${nComponents}`);
    }
    this._nComponents = nComponents;
    this.components = [];
    this.explainedVariance = [];
    this.mean = [];
    // Trace of the covariance matrix seen by the last fit(); undefined until then.
    this._totalVariance = void 0;
  }

  // ── fit ────────────────────────────────────────────────────────────────────
  /**
   * Computes the feature means and the top `nComponents` principal components
   * of X. Populates `mean`, `components` and `explainedVariance`.
   *
   * @param {number[][]} X - Samples, one row per observation.
   * @throws {Error} If X has fewer than 2 rows or fewer features than
   *   `nComponents`.
   */
  fit(X) {
    const n = X.length;
    if (n < 2) throw new Error("PCA.fit: need at least 2 samples");
    const p = X[0].length;
    if (this._nComponents > p) {
      throw new Error(
        `PCA: nComponents (${this._nComponents}) cannot exceed number of features (${p})`
      );
    }

    this.mean = new Array(p).fill(0);
    for (const row of X) for (let j = 0; j < p; j++) this.mean[j] += row[j];
    for (let j = 0; j < p; j++) this.mean[j] /= n;

    const Xc = X.map((row) => row.map((v, j) => v - this.mean[j]));
    const cov = this._covMatrix(Xc, n, p);

    // The trace equals the sum of ALL eigenvalues, i.e. the total variance.
    // It has to be captured before deflation starts modifying `cov`.
    let trace = 0;
    for (let i = 0; i < p; i++) trace += cov[i][i];
    this._totalVariance = trace;

    this.components = [];
    this.explainedVariance = [];
    for (let c = 0; c < this._nComponents; c++) {
      const { eigenvector, eigenvalue } = this._powerIteration(cov, p);
      this.components.push(eigenvector);
      this.explainedVariance.push(eigenvalue);
      // Deflation: subtract λ·v·vᵀ so the next power iteration converges to
      // the next-largest eigenpair.
      for (let i = 0; i < p; i++) {
        for (let j = 0; j < p; j++) {
          cov[i][j] -= eigenvalue * eigenvector[i] * eigenvector[j];
        }
      }
    }
  }

  // ── transform ──────────────────────────────────────────────────────────────
  /**
   * Z = (X - μ) · Vᵀ, shape [n × nComponents].
   *
   * @param {number[][]} X - Samples with the same feature count used in fit().
   * @returns {number[][]} Projected samples.
   * @throws {Error} If called before fit().
   */
  transform(X) {
    if (this.components.length === 0) {
      throw new Error("PCA.transform: call fit() before transform()");
    }
    return X.map((row) => {
      const centered = row.map((v, j) => v - this.mean[j]);
      return this.components.map(
        (pc) => pc.reduce((s, w, j) => s + w * centered[j], 0)
      );
    });
  }

  // ── fitTransform ───────────────────────────────────────────────────────────
  /**
   * Convenience: fit() followed by transform() on the same data.
   *
   * @param {number[][]} X - Samples.
   * @returns {number[][]} Projected samples.
   */
  fitTransform(X) {
    this.fit(X);
    return this.transform(X);
  }

  // ── inverseTransform ───────────────────────────────────────────────────────
  /**
   * X̂ = Z · V + μ, shape [n × nFeatures] (approximate reconstruction).
   *
   * @param {number[][]} Z - Projected samples, one row of nComponents values each.
   * @returns {number[][]} Reconstructed samples in the original feature space.
   * @throws {Error} If called before fit().
   */
  inverseTransform(Z) {
    if (this.components.length === 0) {
      throw new Error("PCA.inverseTransform: call fit() before inverseTransform()");
    }
    const p = this.mean.length;
    return Z.map((z) => {
      const row = new Array(p).fill(0);
      for (let c = 0; c < this._nComponents; c++) {
        for (let j = 0; j < p; j++) {
          row[j] += z[c] * this.components[c][j];
        }
      }
      return row.map((v, j) => v + this.mean[j]);
    });
  }

  // ── explainedVarianceRatio ─────────────────────────────────────────────────
  /**
   * rₖ = λₖ / Σⱼ λⱼ, where the sum runs over ALL eigenvalues of the covariance
   * matrix (its trace), not just the retained ones. The ratios therefore sum
   * to at most 1, and to 1 (up to numerical error) when nComponents equals the
   * number of features.
   *
   * If the total variance from fit() is unavailable (for example when
   * `explainedVariance` was assigned directly), the sum of the retained
   * eigenvalues is used as the denominator instead.
   *
   * @returns {number[]} One ratio per retained component; all zeros when the
   *   total variance is zero.
   */
  explainedVarianceRatio() {
    const retained = this.explainedVariance.reduce((s, v) => s + v, 0);
    const total = this._totalVariance ?? retained;
    if (!(total > 0)) return this.explainedVariance.map(() => 0);
    return this.explainedVariance.map((v) => v / total);
  }

  // ── Private helpers ────────────────────────────────────────────────────────

  // Builds the symmetric [p×p] covariance matrix (normalized by n) from a
  // centered matrix Xc. Only the upper triangle is accumulated; the lower
  // triangle is mirrored afterwards.
  _covMatrix(Xc, n, p) {
    const cov = Array.from({ length: p }, () => new Array(p).fill(0));
    for (const row of Xc) {
      for (let i = 0; i < p; i++) {
        for (let j = i; j < p; j++) {
          cov[i][j] += row[i] * row[j];
        }
      }
    }
    for (let i = 0; i < p; i++) {
      cov[i][i] /= n;
      for (let j = i + 1; j < p; j++) {
        cov[i][j] /= n;
        cov[j][i] = cov[i][j];
      }
    }
    return cov;
  }

  // Power iteration: v ← M·v / ‖M·v‖, starting from a random vector, until
  // two consecutive iterates are parallel within `tol` (sign flips allowed)
  // or `maxIter` is reached. Returns the final vector and its Rayleigh
  // quotient λ = vᵀ·M·v, clamped at 0.
  _powerIteration(M, p, maxIter = 1e3, tol = 1e-10) {
    let v = Array.from({ length: p }, () => Math.random() - 0.5);
    v = this._normalize(v);
    for (let iter = 0; iter < maxIter; iter++) {
      const vNew = this._normalize(this._matvec(M, v));
      const dot = v.reduce((s, vi, i) => s + vi * vNew[i], 0);
      v = vNew;
      if (Math.abs(Math.abs(dot) - 1) < tol) break;
    }
    const Mv = this._matvec(M, v);
    const eigenvalue = v.reduce((s, vi, i) => s + vi * Mv[i], 0);
    return { eigenvector: v, eigenvalue: Math.max(0, eigenvalue) };
  }

  // Matrix–vector product M·v.
  _matvec(M, v) {
    return M.map((row) => row.reduce((s, mij, j) => s + mij * v[j], 0));
  }

  // Scales v to unit length; a (near-)zero vector is returned unchanged.
  _normalize(v) {
    const norm = Math.sqrt(v.reduce((s, vi) => s + vi * vi, 0));
    if (norm < 1e-14) return v;
    return v.map((vi) => vi / norm);
  }
};
|
|
3510
|
+
|
|
3511
|
+
// src/SOM.ts
|
|
3512
|
+
var SOM = class {
  /**
   * Self-Organizing Map on a rows × cols grid of prototype vectors.
   * Prototype weights are initialized with Math.random().
   *
   * @param {number} rows - Grid height; positive integer.
   * @param {number} cols - Grid width; positive integer.
   * @param {number} inputSize - Length of every input vector; positive integer.
   * @param {{ initialLr?: number, finalLr?: number, initialSigma?: number, finalSigma?: number }} [options]
   *   Learning rate and neighborhood radius at the start and end of training.
   *   Defaults: 0.5, 0.01, max(rows, cols) / 2, 1. All must be positive and
   *   finite, because the decay schedule divides by the initial values and the
   *   neighborhood function divides by sigma².
   * @throws {Error} If a dimension is not a positive integer or an option is
   *   not a positive finite number.
   */
  constructor(rows, cols, inputSize, options = {}) {
    const isPositiveInt = (v) => Number.isInteger(v) && v >= 1;
    if (!isPositiveInt(rows) || !isPositiveInt(cols) || !isPositiveInt(inputSize)) {
      throw new Error(
        `SOM: rows, cols and inputSize must be positive integers, got ${rows}\xD7${cols}\xD7${inputSize}`
      );
    }
    this._rows = rows;
    this._cols = cols;
    this._inputSize = inputSize;
    this._initialLr = options.initialLr ?? 0.5;
    this._finalLr = options.finalLr ?? 0.01;
    this._initialSigma = options.initialSigma ?? Math.max(rows, cols) / 2;
    this._finalSigma = options.finalSigma ?? 1;

    // A zero or non-finite rate/radius turns the decay schedule or the
    // neighborhood function into NaN and corrupts the weights.
    const schedule = {
      initialLr: this._initialLr,
      finalLr: this._finalLr,
      initialSigma: this._initialSigma,
      finalSigma: this._finalSigma
    };
    for (const [name, value] of Object.entries(schedule)) {
      if (!(Number.isFinite(value) && value > 0)) {
        throw new Error(`SOM: ${name} must be a positive finite number, got ${value}`);
      }
    }

    this.weights = Array.from(
      { length: rows },
      () => Array.from(
        { length: cols },
        () => Array.from({ length: inputSize }, () => Math.random())
      )
    );
  }

  // ── train ──────────────────────────────────────────────────────────────────
  /**
   * Presents every sample of X `epochs` times in a freshly shuffled order.
   * For each sample the BMU is located and every neuron is pulled toward the
   * sample, weighted by a Gaussian of its grid distance to the BMU. Learning
   * rate and sigma decay exponentially from their initial to their final
   * values over the whole run.
   *
   * @param {number[][]} X - Training samples.
   * @param {number} epochs - Number of passes over X.
   * @throws {Error} If X is empty or a sample has the wrong length.
   */
  train(X, epochs) {
    if (!X || X.length === 0) {
      throw new Error("SOM.train: dataset X must be non-empty");
    }
    if (X[0].length !== this._inputSize) {
      throw new Error(
        `SOM.train: expected input size ${this._inputSize}, got ${X[0].length}`
      );
    }
    const totalSteps = epochs * X.length;
    let step = 0;
    for (let epoch = 0; epoch < epochs; epoch++) {
      for (const idx of this._shuffle(X.length)) {
        const x = X[idx];
        // t runs from 0 towards 1 across all presentations.
        const t = step / totalSteps;
        const lr = this._initialLr * Math.pow(this._finalLr / this._initialLr, t);
        const sigma = this._initialSigma * Math.pow(this._finalSigma / this._initialSigma, t);
        const sigma2 = 2 * sigma * sigma;
        const [bmuR, bmuC] = this.getBMU(x);
        for (let r = 0; r < this._rows; r++) {
          for (let c = 0; c < this._cols; c++) {
            const dr = r - bmuR;
            const dc = c - bmuC;
            const h = Math.exp(-(dr * dr + dc * dc) / sigma2);
            // Neurons with negligible neighborhood influence are skipped.
            if (h < 1e-6) continue;
            const w = this.weights[r][c];
            for (let i = 0; i < this._inputSize; i++) {
              w[i] += lr * h * (x[i] - w[i]);
            }
          }
        }
        step++;
      }
    }
  }

  // ── getBMU ─────────────────────────────────────────────────────────────────
  /**
   * Best Matching Unit: argmin_{r,c} ‖x − w[r][c]‖². Ties go to the first
   * neuron in row-major order.
   *
   * @param {number[]} x - Input vector of length inputSize.
   * @returns {[number, number]} [row, col] of the winning neuron.
   * @throws {Error} If x has the wrong length.
   */
  getBMU(x) {
    if (x.length !== this._inputSize) {
      throw new Error(
        `SOM.getBMU: expected input of length ${this._inputSize}, got ${x.length}`
      );
    }
    let bestR = 0;
    let bestC = 0;
    let bestDist = Infinity;
    for (let r = 0; r < this._rows; r++) {
      for (let c = 0; c < this._cols; c++) {
        const dist = this._distSq(x, this.weights[r][c]);
        if (dist < bestDist) {
          bestDist = dist;
          bestR = r;
          bestC = c;
        }
      }
    }
    return [bestR, bestC];
  }

  // ── predict ────────────────────────────────────────────────────────────────
  /**
   * Alias for getBMU().
   *
   * @param {number[]} x - Input vector of length inputSize.
   * @returns {[number, number]} [row, col] of the winning neuron.
   */
  predict(x) {
    return this.getBMU(x);
  }

  // ── quantizationError ─────────────────────────────────────────────────────
  /**
   * QE = (1/n) · Σᵢ ‖xᵢ − w[BMU(xᵢ)]‖ — mean Euclidean distance between each
   * sample and its BMU prototype. Lower is better.
   *
   * @param {number[][]} X - Samples.
   * @returns {number} Mean quantization error.
   * @throws {Error} If X is empty (the mean would be 0/0) or a sample has the
   *   wrong length.
   */
  quantizationError(X) {
    if (!X || X.length === 0) {
      throw new Error("SOM.quantizationError: dataset X must be non-empty");
    }
    let total = 0;
    for (const x of X) {
      const [r, c] = this.getBMU(x);
      total += Math.sqrt(this._distSq(x, this.weights[r][c]));
    }
    return total / X.length;
  }

  // ── Private helpers ────────────────────────────────────────────────────────

  // Squared Euclidean distance between two equally long vectors.
  _distSq(a, b) {
    let s = 0;
    for (let i = 0; i < a.length; i++) s += (a[i] - b[i]) ** 2;
    return s;
  }

  // Fisher-Yates shuffle — returns the indices 0..n-1 in random order.
  _shuffle(n) {
    const arr = Array.from({ length: n }, (_, i) => i);
    for (let i = n - 1; i > 0; i--) {
      const j = Math.floor(Math.random() * (i + 1));
      [arr[i], arr[j]] = [arr[j], arr[i]];
    }
    return arr;
  }
};
|
|
3630
|
+
|
|
3631
|
+
// src/HopfieldNetwork.ts
|
|
3632
|
+
var HopfieldNetwork = class {
  /**
   * Hopfield associative memory over `n` bipolar (+1 / −1) units.
   * The weight matrix starts at all zeros.
   *
   * @param {number} n - Number of units; must be a positive integer.
   * @throws {Error} If `n` is not a positive integer.
   */
  constructor(n) {
    if (!Number.isInteger(n) || n < 1) {
      throw new Error(`HopfieldNetwork: n must be a positive integer, got ${n}`);
    }
    this.n = n;
    this.storedPatterns = 0;
    this.weights = Array.from({ length: n }, () => new Array(n).fill(0));
  }

  // ── store ──────────────────────────────────────────────────────────────────
  /**
   * Adds a pattern to memory with the Hebbian rule
   *   W ← W + (1/N) · p · pᵀ   (diagonal stays 0)
   * and increments `storedPatterns`. The pattern is validated in full before
   * any weight is touched, so a rejected pattern leaves the network unchanged.
   * No capacity limit is enforced.
   *
   * @param {number[]} pattern - Bipolar pattern of length n (each value +1 or −1).
   * @throws {Error} If the length is wrong or a value is not +1 / −1.
   */
  store(pattern) {
    if (pattern.length !== this.n) {
      throw new Error(
        `HopfieldNetwork.store: pattern length ${pattern.length} does not match network size ${this.n}`
      );
    }
    for (let i = 0; i < this.n; i++) {
      if (pattern[i] !== 1 && pattern[i] !== -1) {
        throw new Error(
          `HopfieldNetwork.store: pattern values must be +1 or -1, got ${pattern[i]} at index ${i}. Use HopfieldNetwork.binarize() to convert 0/1 arrays.`
        );
      }
    }
    const scale = 1 / this.n;
    for (let i = 0; i < this.n; i++) {
      const row = this.weights[i];
      for (let j = 0; j < this.n; j++) {
        // No self-connections: the diagonal is never updated.
        if (i !== j) {
          row[j] += scale * pattern[i] * pattern[j];
        }
      }
    }
    this.storedPatterns++;
  }

  // ── recall ─────────────────────────────────────────────────────────────────
  /**
   * Starting from a copy of `input`, repeatedly sweeps over all units in a
   * freshly shuffled order and updates each one in place:
   *   hᵢ = Σⱼ Wᵢⱼ · sⱼ
   *   sᵢ ← sign(hᵢ)   (+1, −1; unchanged when hᵢ = 0)
   * Stops after the first sweep that changes nothing, or after `maxIter`
   * sweeps. `input` itself is not modified.
   *
   * @param {number[]} input - Initial state of length n.
   * @param {number} [maxIter=20 * n] - Maximum number of sweeps.
   * @returns {number[]} The final state vector.
   * @throws {Error} If `input` has the wrong length.
   */
  recall(input, maxIter = 20 * this.n) {
    if (input.length !== this.n) {
      throw new Error(
        `HopfieldNetwork.recall: input length ${input.length} does not match network size ${this.n}`
      );
    }
    const s = [...input];
    const order = Array.from({ length: this.n }, (_, i) => i);
    for (let iter = 0; iter < maxIter; iter++) {
      this._shuffleInPlace(order);
      let changed = false;
      for (const i of order) {
        let h = 0;
        const row = this.weights[i];
        for (let j = 0; j < this.n; j++) h += row[j] * s[j];
        const newSi = h > 0 ? 1 : h < 0 ? -1 : s[i];
        if (newSi !== s[i]) {
          s[i] = newSi;
          changed = true;
        }
      }
      if (!changed) break;
    }
    return s;
  }

  // ── energy ─────────────────────────────────────────────────────────────────
  /**
   * E(s) = −½ · Σᵢⱼ Wᵢⱼ · sᵢ · sⱼ
   *
   * @param {number[]} state - State vector of length n.
   * @returns {number} The energy of `state` under the current weights.
   * @throws {Error} If `state` has the wrong length.
   */
  energy(state) {
    if (state.length !== this.n) {
      throw new Error(
        `HopfieldNetwork.energy: state length ${state.length} does not match network size ${this.n}`
      );
    }
    let e = 0;
    for (let i = 0; i < this.n; i++) {
      for (let j = 0; j < this.n; j++) {
        e += this.weights[i][j] * state[i] * state[j];
      }
    }
    return -0.5 * e;
  }

  // ── binarize ───────────────────────────────────────────────────────────────
  /**
   * Converts a 0/1 array to bipolar −1/+1: 0 → −1, anything else → +1.
   *
   * @param {number[]} arr
   * @returns {number[]} A new array.
   */
  static binarize(arr) {
    return arr.map((v) => v === 0 ? -1 : 1);
  }

  // ── unbinarize ─────────────────────────────────────────────────────────────
  /**
   * Converts a bipolar −1/+1 array back to 0/1: −1 → 0, anything else → 1.
   *
   * @param {number[]} arr
   * @returns {number[]} A new array.
   */
  static unbinarize(arr) {
    return arr.map((v) => v === -1 ? 0 : 1);
  }

  // ── Private helpers ────────────────────────────────────────────────────────

  // Fisher-Yates shuffle of `arr` in place.
  _shuffleInPlace(arr) {
    for (let i = arr.length - 1; i > 0; i--) {
      const j = Math.floor(Math.random() * (i + 1));
      [arr[i], arr[j]] = [arr[j], arr[i]];
    }
  }
};
|
|
3740
|
+
|
|
3741
|
+
// src/Autoencoder.ts
|
|
3742
|
+
var Autoencoder = class {
  /**
   * Autoencoder built from two NetworkN instances:
   *   encoder: [inputSize, ...encoderHidden, latentSize]
   *   decoder: [latentSize, ...decoderHidden, inputSize]
   *
   * `options` is forwarded to both networks. When `options.activations` is
   * set it is read as one flat list: the first entries (one per encoder
   * layer) go to the encoder and the following entries (one per decoder
   * layer) go to the decoder. If the list is too short for a network, that
   * network receives `activations: undefined`.
   *
   * @param {number} inputSize - Length of the input vectors; must be ≥ 1.
   * @param {number[]} encoderHidden - Hidden layer sizes of the encoder
   *   (a nullish value is treated as no hidden layers).
   * @param {number} latentSize - Length of the latent code; must be ≥ 1.
   * @param {number[]} decoderHidden - Hidden layer sizes of the decoder
   *   (a nullish value is treated as no hidden layers).
   * @param {object} [options] - Options forwarded to NetworkN.
   * @throws {Error} If inputSize or latentSize is below 1.
   */
  constructor(inputSize, encoderHidden, latentSize, decoderHidden, options = {}) {
    if (inputSize < 1) {
      throw new Error(`Autoencoder: inputSize must be \u2265 1, got ${inputSize}`);
    }
    if (latentSize < 1) {
      throw new Error(`Autoencoder: latentSize must be \u2265 1, got ${latentSize}`);
    }
    this._inputSize = inputSize;
    this._latentSize = latentSize;

    const encoderStructure = [inputSize, ...(encoderHidden ?? []), latentSize];
    const decoderStructure = [latentSize, ...(decoderHidden ?? []), inputSize];
    const nEncoderLayers = encoderStructure.length - 1;
    const nDecoderLayers = decoderStructure.length - 1;

    const encoderOptions = { ...options };
    const decoderOptions = { ...options };
    if (options.activations) {
      const all = options.activations;
      encoderOptions.activations = all.length >= nEncoderLayers ? all.slice(0, nEncoderLayers) : void 0;
      const remaining = all.slice(nEncoderLayers);
      decoderOptions.activations = remaining.length >= nDecoderLayers ? remaining.slice(0, nDecoderLayers) : void 0;
    }

    // The encoder is constructed before the decoder, as in the original order.
    this.encoder = new NetworkN(encoderStructure, encoderOptions);
    this.decoder = new NetworkN(decoderStructure, decoderOptions);
  }

  // ── encode ─────────────────────────────────────────────────────────────────
  /**
   * z = encoder(x) ∈ ℝ^latentSize
   *
   * @param {number[]} x - Input vector of length inputSize.
   * @returns {number[]} The encoder's prediction for x.
   * @throws {Error} If x has the wrong length.
   */
  encode(x) {
    if (x.length !== this._inputSize) {
      throw new Error(
        `Autoencoder.encode: expected input of length ${this._inputSize}, got ${x.length}`
      );
    }
    return this.encoder.predict(x);
  }

  // ── decode ─────────────────────────────────────────────────────────────────
  /**
   * x̂ = decoder(z) ∈ ℝ^inputSize
   *
   * @param {number[]} z - Latent vector of length latentSize.
   * @returns {number[]} The decoder's prediction for z.
   * @throws {Error} If z has the wrong length.
   */
  decode(z) {
    if (z.length !== this._latentSize) {
      throw new Error(
        `Autoencoder.decode: expected latent vector of length ${this._latentSize}, got ${z.length}`
      );
    }
    return this.decoder.predict(z);
  }

  // ── reconstruct ───────────────────────────────────────────────────────────
  /**
   * Convenience: x̂ = decode(encode(x)).
   *
   * @param {number[]} x - Input vector of length inputSize.
   * @returns {number[]} Reconstruction of x.
   */
  reconstruct(x) {
    return this.decode(this.encode(x));
  }

  // ── train ──────────────────────────────────────────────────────────────────
  /**
   * Trains both sub-networks on a single example.
   *
   *   1. Forward:  z = encoder(x),  x̂ = decoder(z)
   *   2. Output deltas at x̂:  δᵢ = (xᵢ − x̂ᵢ) · act'(x̂ᵢ)
   *   3. Walk backward through the decoder layers to obtain the signal at z,
   *      BEFORE any decoder weight is updated
   *   4. decoder.trainWithDeltas(z, outputDeltas, lr)
   *   5. encoder.trainWithDeltas(x, dLdz, lr)
   *
   * @param {number[]} x - Input vector of length inputSize.
   * @param {number} lr - Learning rate passed to trainWithDeltas.
   * @returns {number} mse(x̂, x), computed before the weight updates.
   * @throws {Error} If x has the wrong length.
   */
  train(x, lr) {
    if (x.length !== this._inputSize) {
      throw new Error(
        `Autoencoder.train: expected input of length ${this._inputSize}, got ${x.length}`
      );
    }

    // NOTE(review): the meaning of the second predict() argument is defined
    // by NetworkN, outside this block — presumably a training-mode flag.
    const z = this.encoder.predict(x, true);
    const xHat = this.decoder.predict(z, true);
    const loss = mse(xHat, x);

    const decoderLayers = this.decoder.layers;
    // The first neuron's activation is taken as representative of its layer.
    const decoderOutAct = decoderLayers[decoderLayers.length - 1].neurons[0].activation;
    const outputDeltas = xHat.map((xh, i) => (x[i] - xh) * decoderOutAct.dfn(xh));

    // Re-run the decoder layer by layer to collect every layer's output;
    // decoderActVals[l] is the input that layer l received.
    const decoderActVals = [z];
    let cur = [...z];
    for (const layer of decoderLayers) {
      cur = layer.predict(cur);
      decoderActVals.push(cur);
    }

    // Back-propagate the output deltas through the decoder, last layer first.
    let deltas = outputDeltas;
    for (let l = decoderLayers.length - 1; l >= 0; l--) {
      const layer = decoderLayers[l];
      const prevAct = decoderActVals[l];
      // For l === 0 the "previous layer" is the latent vector itself, so no
      // activation derivative is applied there.
      const prevLayerActivation = l > 0 ? decoderLayers[l - 1].neurons[0].activation : null;
      const prevDeltas = prevAct.map((out, j) => {
        const errProp = layer.neurons.reduce((s, n, k) => s + deltas[k] * n.weights[j], 0);
        return prevLayerActivation ? errProp * prevLayerActivation.dfn(out) : errProp;
      });
      deltas = prevDeltas;
    }
    const dLdz = deltas;

    // NOTE(review): outputDeltas include the output activation's derivative,
    // whereas dLdz does not include the encoder output activation's
    // derivative — confirm this matches what trainWithDeltas expects.
    this.decoder.trainWithDeltas(z, outputDeltas, lr);
    this.encoder.trainWithDeltas(x, dLdz, lr);
    return loss;
  }

  // ── trainBatch ────────────────────────────────────────────────────────────
  /**
   * Calls train() on every example of X in order.
   *
   * @param {number[][]} X - Non-empty batch of input vectors.
   * @param {number} lr - Learning rate.
   * @returns {number} Mean of the per-example losses returned by train().
   * @throws {Error} If X is empty.
   */
  trainBatch(X, lr) {
    if (X.length === 0) {
      throw new Error("Autoencoder.trainBatch: batch X must be non-empty");
    }
    let totalLoss = 0;
    for (const x of X) totalLoss += this.train(x, lr);
    return totalLoss / X.length;
  }
};
|
|
3863
|
+
|
|
3864
|
+
// src/Conv2D.ts
|
|
3865
|
+
var Conv2D = class {
|
|
3866
|
+
constructor(inputHeight, inputWidth, channels, kernelSize, filters, options) {
|
|
3867
|
+
// [filters]
|
|
3868
|
+
this._input = null;
|
|
3869
|
+
this._padded = null;
|
|
3870
|
+
const [kH, kW] = Array.isArray(kernelSize) ? kernelSize : [kernelSize, kernelSize];
|
|
3871
|
+
if (inputHeight <= 0 || inputWidth <= 0 || channels <= 0 || filters <= 0) {
|
|
3872
|
+
throw new Error("Conv2D: dimensions and filters must be positive");
|
|
3873
|
+
}
|
|
3874
|
+
if (kH <= 0 || kW <= 0) {
|
|
3875
|
+
throw new Error("Conv2D: kernelSize must be positive");
|
|
3876
|
+
}
|
|
3877
|
+
this.inputHeight = inputHeight;
|
|
3878
|
+
this.inputWidth = inputWidth;
|
|
3879
|
+
this.channels = channels;
|
|
3880
|
+
this.kH = kH;
|
|
3881
|
+
this.kW = kW;
|
|
3882
|
+
this.filters = filters;
|
|
3883
|
+
this.stride = options?.stride ?? 1;
|
|
3884
|
+
this.padding = options?.padding ?? "valid";
|
|
3885
|
+
const optimizerFactory = options?.optimizerFactory ?? (() => new SGD());
|
|
3886
|
+
const limit = Math.sqrt(2 / (kH * kW * channels));
|
|
3887
|
+
this.kernels = Array.from(
|
|
3888
|
+
{ length: filters },
|
|
3889
|
+
() => Array.from(
|
|
3890
|
+
{ length: kH },
|
|
3891
|
+
() => Array.from(
|
|
3892
|
+
{ length: kW },
|
|
3893
|
+
() => Array.from({ length: channels }, () => (Math.random() * 2 - 1) * limit)
|
|
3894
|
+
)
|
|
3895
|
+
)
|
|
3896
|
+
);
|
|
3897
|
+
this.biases = new Array(filters).fill(0);
|
|
3898
|
+
this._kOpts = Array.from(
|
|
3899
|
+
{ length: filters },
|
|
3900
|
+
() => Array.from(
|
|
3901
|
+
{ length: kH },
|
|
3902
|
+
() => Array.from(
|
|
3903
|
+
{ length: kW },
|
|
3904
|
+
() => Array.from({ length: channels }, () => optimizerFactory())
|
|
3905
|
+
)
|
|
3906
|
+
)
|
|
3907
|
+
);
|
|
3908
|
+
this._bOpts = Array.from({ length: filters }, () => optimizerFactory());
|
|
3909
|
+
}
|
|
3910
|
+
// ── Padding helper ────────────────────────────────────────────────────────
|
|
3911
|
+
_pad(input) {
|
|
3912
|
+
if (this.padding === "valid") return input;
|
|
3913
|
+
const padH = Math.floor(this.kH / 2);
|
|
3914
|
+
const padW = Math.floor(this.kW / 2);
|
|
3915
|
+
const H = input.length;
|
|
3916
|
+
const W = input[0].length;
|
|
3917
|
+
const C = this.channels;
|
|
3918
|
+
const paddedH = H + 2 * padH;
|
|
3919
|
+
const paddedW = W + 2 * padW;
|
|
3920
|
+
const out = Array.from(
|
|
3921
|
+
{ length: paddedH },
|
|
3922
|
+
() => Array.from({ length: paddedW }, () => new Array(C).fill(0))
|
|
3923
|
+
);
|
|
3924
|
+
for (let h = 0; h < H; h++) {
|
|
3925
|
+
for (let w = 0; w < W; w++) {
|
|
3926
|
+
for (let c = 0; c < C; c++) {
|
|
3927
|
+
out[h + padH][w + padW][c] = input[h][w][c];
|
|
3928
|
+
}
|
|
3929
|
+
}
|
|
3930
|
+
}
|
|
3931
|
+
return out;
|
|
3932
|
+
}
|
|
3933
|
+
// ── Output shape ──────────────────────────────────────────────────────────
|
|
3934
|
+
outputShape() {
|
|
3935
|
+
const padH = this.padding === "same" ? Math.floor(this.kH / 2) : 0;
|
|
3936
|
+
const padW = this.padding === "same" ? Math.floor(this.kW / 2) : 0;
|
|
3937
|
+
const outH = Math.floor((this.inputHeight - this.kH + 2 * padH) / this.stride) + 1;
|
|
3938
|
+
const outW = Math.floor((this.inputWidth - this.kW + 2 * padW) / this.stride) + 1;
|
|
3939
|
+
return [outH, outW, this.filters];
|
|
3940
|
+
}
|
|
3941
|
+
// ── Forward pass ──────────────────────────────────────────────────────────
|
|
3942
|
+
// output[h][w][f] = bias[f] + Σ_{kh,kw,c} kernel[f][kh][kw][c] · input[h·s+kh][w·s+kw][c]
|
|
3943
|
+
forward(input) {
|
|
3944
|
+
if (input.length !== this.inputHeight) {
|
|
3945
|
+
throw new Error(`Conv2D.forward: expected height ${this.inputHeight}, got ${input.length}`);
|
|
3946
|
+
}
|
|
3947
|
+
if (input[0].length !== this.inputWidth) {
|
|
3948
|
+
throw new Error(`Conv2D.forward: expected width ${this.inputWidth}, got ${input[0].length}`);
|
|
3949
|
+
}
|
|
3950
|
+
this._input = input;
|
|
3951
|
+
this._padded = this._pad(input);
|
|
3952
|
+
const padded = this._padded;
|
|
3953
|
+
const padH = this.padding === "same" ? Math.floor(this.kH / 2) : 0;
|
|
3954
|
+
const padW = this.padding === "same" ? Math.floor(this.kW / 2) : 0;
|
|
3955
|
+
const outH = Math.floor((this.inputHeight - this.kH + 2 * padH) / this.stride) + 1;
|
|
3956
|
+
const outW = Math.floor((this.inputWidth - this.kW + 2 * padW) / this.stride) + 1;
|
|
3957
|
+
const output = Array.from(
|
|
3958
|
+
{ length: outH },
|
|
3959
|
+
() => Array.from({ length: outW }, () => new Array(this.filters).fill(0))
|
|
3960
|
+
);
|
|
3961
|
+
for (let f = 0; f < this.filters; f++) {
|
|
3962
|
+
for (let h = 0; h < outH; h++) {
|
|
3963
|
+
for (let w = 0; w < outW; w++) {
|
|
3964
|
+
let sum = this.biases[f];
|
|
3965
|
+
for (let kh = 0; kh < this.kH; kh++) {
|
|
3966
|
+
for (let kw = 0; kw < this.kW; kw++) {
|
|
3967
|
+
for (let c = 0; c < this.channels; c++) {
|
|
3968
|
+
sum += this.kernels[f][kh][kw][c] * padded[h * this.stride + kh][w * this.stride + kw][c];
|
|
3969
|
+
}
|
|
3970
|
+
}
|
|
3971
|
+
}
|
|
3972
|
+
output[h][w][f] = sum;
|
|
3973
|
+
}
|
|
3974
|
+
}
|
|
3975
|
+
}
|
|
3976
|
+
return output;
|
|
3977
|
+
}
|
|
3978
|
+
// ── Backward pass ─────────────────────────────────────────────────────────
|
|
3979
|
+
// dOutput: number[][][] of shape [outH][outW][filters]
|
|
3980
|
+
// Returns dInput: number[][][] of shape [H][W][channels]
|
|
3981
|
+
backward(dOutput, lr) {
|
|
3982
|
+
if (!this._padded || !this._input) {
|
|
3983
|
+
throw new Error("Conv2D.backward: call forward() first");
|
|
3984
|
+
}
|
|
3985
|
+
const padded = this._padded;
|
|
3986
|
+
const outH = dOutput.length;
|
|
3987
|
+
const outW = dOutput[0].length;
|
|
3988
|
+
const dKernels = Array.from(
|
|
3989
|
+
{ length: this.filters },
|
|
3990
|
+
() => Array.from(
|
|
3991
|
+
{ length: this.kH },
|
|
3992
|
+
() => Array.from({ length: this.kW }, () => new Array(this.channels).fill(0))
|
|
3993
|
+
)
|
|
3994
|
+
);
|
|
3995
|
+
const dBiases = new Array(this.filters).fill(0);
|
|
3996
|
+
const dPadded = Array.from(
|
|
3997
|
+
{ length: padded.length },
|
|
3998
|
+
() => Array.from({ length: padded[0].length }, () => new Array(this.channels).fill(0))
|
|
3999
|
+
);
|
|
4000
|
+
for (let f = 0; f < this.filters; f++) {
|
|
4001
|
+
for (let h = 0; h < outH; h++) {
|
|
4002
|
+
for (let w = 0; w < outW; w++) {
|
|
4003
|
+
const dv = dOutput[h][w][f];
|
|
4004
|
+
dBiases[f] += dv;
|
|
4005
|
+
for (let kh = 0; kh < this.kH; kh++) {
|
|
4006
|
+
for (let kw = 0; kw < this.kW; kw++) {
|
|
4007
|
+
for (let c = 0; c < this.channels; c++) {
|
|
4008
|
+
const ph = h * this.stride + kh;
|
|
4009
|
+
const pw = w * this.stride + kw;
|
|
4010
|
+
dKernels[f][kh][kw][c] += dv * padded[ph][pw][c];
|
|
4011
|
+
dPadded[ph][pw][c] += dv * this.kernels[f][kh][kw][c];
|
|
4012
|
+
}
|
|
4013
|
+
}
|
|
4014
|
+
}
|
|
4015
|
+
}
|
|
4016
|
+
}
|
|
4017
|
+
}
|
|
4018
|
+
for (let f = 0; f < this.filters; f++) {
|
|
4019
|
+
for (let kh = 0; kh < this.kH; kh++) {
|
|
4020
|
+
for (let kw = 0; kw < this.kW; kw++) {
|
|
4021
|
+
for (let c = 0; c < this.channels; c++) {
|
|
4022
|
+
this.kernels[f][kh][kw][c] = this._kOpts[f][kh][kw][c].step(
|
|
4023
|
+
this.kernels[f][kh][kw][c],
|
|
4024
|
+
dKernels[f][kh][kw][c],
|
|
4025
|
+
lr
|
|
4026
|
+
);
|
|
4027
|
+
}
|
|
4028
|
+
}
|
|
4029
|
+
}
|
|
4030
|
+
this.biases[f] = this._bOpts[f].step(this.biases[f], dBiases[f], lr);
|
|
4031
|
+
}
|
|
4032
|
+
if (this.padding === "same") {
|
|
4033
|
+
const padH = Math.floor(this.kH / 2);
|
|
4034
|
+
const padW = Math.floor(this.kW / 2);
|
|
4035
|
+
return dPadded.slice(padH, padH + this.inputHeight).map((row) => row.slice(padW, padW + this.inputWidth));
|
|
4036
|
+
}
|
|
4037
|
+
return dPadded.slice(0, this.inputHeight).map((row) => row.slice(0, this.inputWidth));
|
|
4038
|
+
}
|
|
4039
|
+
// ── Weight serialization ──────────────────────────────────────────────────
|
|
4040
|
+
getWeights() {
|
|
4041
|
+
const w = [];
|
|
4042
|
+
for (const kf of this.kernels)
|
|
4043
|
+
for (const kh of kf)
|
|
4044
|
+
for (const kw of kh)
|
|
4045
|
+
for (const v of kw)
|
|
4046
|
+
w.push(v);
|
|
4047
|
+
w.push(...this.biases);
|
|
4048
|
+
return w;
|
|
4049
|
+
}
|
|
4050
|
+
setWeights(weights) {
|
|
4051
|
+
let idx = 0;
|
|
4052
|
+
for (let f = 0; f < this.filters; f++)
|
|
4053
|
+
for (let kh = 0; kh < this.kH; kh++)
|
|
4054
|
+
for (let kw = 0; kw < this.kW; kw++)
|
|
4055
|
+
for (let c = 0; c < this.channels; c++)
|
|
4056
|
+
this.kernels[f][kh][kw][c] = weights[idx++];
|
|
4057
|
+
for (let f = 0; f < this.filters; f++)
|
|
4058
|
+
this.biases[f] = weights[idx++];
|
|
4059
|
+
}
|
|
4060
|
+
};
|
|
4061
|
+
|
|
4062
|
+
// src/MaxPool2D.ts
|
|
4063
|
+
/**
 * 2D max-pooling over channel-last tensors of shape [H][W][C].
 */
var MaxPool2D = class {
  /**
   * @param {number} poolSize - side length of the square pooling window
   * @param {number} [stride] - step between windows; defaults to poolSize
   * @throws {Error} when poolSize or the effective stride is not positive
   */
  constructor(poolSize, stride) {
    // _maxMask[h][w][c] = true if input[h][w][c] was the maximum of at least
    // one window in the last forward pass (kept for backward compatibility).
    this._maxMask = null;
    // _argMax[oh][ow][c] = [ih, iw] of the winning input cell for that output
    // cell. backward() routes through this, not the mask, so overlapping
    // windows never leak gradient into another window's maximum.
    this._argMax = null;
    this._inputH = 0;
    this._inputW = 0;
    this._inputC = 0;
    if (poolSize <= 0) {
      throw new Error("MaxPool2D: poolSize must be positive");
    }
    this.poolSize = poolSize;
    this.stride = stride ?? poolSize;
    // A zero stride would make outputShape() divide by zero.
    if (this.stride <= 0) {
      throw new Error("MaxPool2D: stride must be positive");
    }
  }

  /**
   * @param {number} inputH
   * @param {number} inputW
   * @param {number} channels
   * @returns {number[]} [outH, outW, channels]
   */
  outputShape(inputH, inputW, channels) {
    const outH = Math.floor((inputH - this.poolSize) / this.stride) + 1;
    const outW = Math.floor((inputW - this.poolSize) / this.stride) + 1;
    return [outH, outW, channels];
  }

  /**
   * output[oh][ow][c] = max over the poolSize x poolSize window whose top-left
   * corner is input[oh * stride][ow * stride][c]. Ties keep the first cell in
   * row-major scan order (strict `>` comparison).
   *
   * @param {number[][][]} input - shape [H][W][C]
   * @returns {number[][][]} shape [outH][outW][C]
   */
  forward(input) {
    const H = input.length;
    const W = input[0].length;
    const C = input[0][0].length;
    this._inputH = H;
    this._inputW = W;
    this._inputC = C;
    const [outH, outW] = this.outputShape(H, W, C);
    const output = Array.from(
      { length: outH },
      () => Array.from({ length: outW }, () => new Array(C).fill(-Infinity))
    );
    const maxMask = Array.from(
      { length: H },
      () => Array.from({ length: W }, () => new Array(C).fill(false))
    );
    const argMax = Array.from(
      { length: outH },
      () => Array.from({ length: outW }, () => new Array(C).fill(null))
    );
    for (let oh = 0; oh < outH; oh++) {
      for (let ow = 0; ow < outW; ow++) {
        for (let c = 0; c < C; c++) {
          let maxVal = -Infinity;
          let maxPH = 0;
          let maxPW = 0;
          for (let ph = 0; ph < this.poolSize; ph++) {
            for (let pw = 0; pw < this.poolSize; pw++) {
              const val = input[oh * this.stride + ph][ow * this.stride + pw][c];
              if (val > maxVal) {
                maxVal = val;
                maxPH = ph;
                maxPW = pw;
              }
            }
          }
          const ih = oh * this.stride + maxPH;
          const iw = ow * this.stride + maxPW;
          output[oh][ow][c] = maxVal;
          maxMask[ih][iw][c] = true;
          argMax[oh][ow][c] = [ih, iw];
        }
      }
    }
    this._maxMask = maxMask;
    this._argMax = argMax;
    return output;
  }

  /**
   * Routes each output gradient to the single input cell that produced that
   * output's maximum; every other input cell receives 0. An input cell that
   * won several overlapping windows accumulates all of their gradients.
   *
   * @param {number[][][]} dOutput - shape [outH][outW][C]
   * @returns {number[][][]} dInput of shape [H][W][C]
   * @throws {Error} when forward() has not been called yet
   */
  backward(dOutput) {
    if (!this._maxMask || !this._argMax) {
      throw new Error("MaxPool2D.backward: call forward() first");
    }
    const dInput = Array.from(
      { length: this._inputH },
      () => Array.from({ length: this._inputW }, () => new Array(this._inputC).fill(0))
    );
    const outH = dOutput.length;
    const outW = outH > 0 ? dOutput[0].length : 0;
    const C = this._inputC;
    for (let oh = 0; oh < outH; oh++) {
      for (let ow = 0; ow < outW; ow++) {
        for (let c = 0; c < C; c++) {
          const [ih, iw] = this._argMax[oh][ow][c];
          dInput[ih][iw][c] += dOutput[oh][ow][c];
        }
      }
    }
    return dInput;
  }
};
|
|
4158
|
+
|
|
4159
|
+
// src/Flatten.ts
|
|
4160
|
+
/**
 * Reshapes a channel-last [H][W][C] tensor into a flat vector and back.
 */
var Flatten = class {
  constructor() {
    // [H, W, C] of the most recent forward() input; null until forward() runs.
    this.inputShape = null;
  }

  /**
   * Flattens input[h][w][c] in row-major, channel-last order and remembers
   * the shape for backward().
   *
   * @param {number[][][]} input
   * @returns {number[]} vector of length H * W * C
   */
  forward(input) {
    const rows = input.length;
    const cols = input[0].length;
    const depth = input[0][0].length;
    this.inputShape = [rows, cols, depth];
    const vector = new Array(rows * cols * depth);
    for (let r = 0; r < rows; r++) {
      for (let col = 0; col < cols; col++) {
        const offset = (r * cols + col) * depth;
        for (let d = 0; d < depth; d++) {
          vector[offset + d] = input[r][col][d];
        }
      }
    }
    return vector;
  }

  /**
   * Folds a flat gradient vector back into the shape saved by forward().
   *
   * @param {number[]} dOutput - gradient of length H * W * C
   * @returns {number[][][]} gradient of shape [H][W][C]
   * @throws {Error} when forward() has not run or the length does not match
   */
  backward(dOutput) {
    if (!this.inputShape) {
      throw new Error("Flatten.backward: call forward() first");
    }
    const [rows, cols, depth] = this.inputShape;
    const expected = rows * cols * depth;
    if (dOutput.length !== expected) {
      throw new Error(
        `Flatten.backward: expected gradient of length ${expected}, got ${dOutput.length}`
      );
    }
    return Array.from({ length: rows }, (_r, r) =>
      Array.from({ length: cols }, (_c, col) => {
        const offset = (r * cols + col) * depth;
        return Array.from({ length: depth }, (_d, d) => dOutput[offset + d]);
      })
    );
  }
};
|
|
4210
|
+
|
|
4211
|
+
// src/RNN.ts
|
|
4212
|
+
/**
 * Hyperbolic tangent.
 *
 * Delegates to Math.tanh instead of (e^{2x} - 1) / (e^{2x} + 1): that formula
 * overflows to Infinity / Infinity = NaN once 2x exceeds the double range,
 * whereas the true value there is simply 1.
 *
 * @param {number} x
 * @returns {number} tanh(x), in [-1, 1]
 */
function tanh3(x) {
  return Math.tanh(x);
}
|
|
4216
|
+
/**
 * Single-layer recurrent network with a tanh hidden state and a linear
 * read-out at every timestep:
 *   h_t = tanh(Wxh · x_t + Whh · h_{t-1} + bh)
 *   o_t = Why · h_t + by
 * Trained with full backpropagation through time on mean squared error.
 */
var RNN = class {
  /**
   * @param {number} inputSize - length of each input vector
   * @param {number} hiddenSize - number of hidden units
   * @param {number} outputSize - length of each output vector
   * @param {Function} [optimizerFactory] - called once per scalar parameter;
   *   must return an object exposing step(value, gradient, lr) -> newValue
   * @throws {Error} when any size is not positive
   */
  constructor(inputSize, hiddenSize, outputSize, optimizerFactory = () => new SGD()) {
    // Per-timestep cache written by forward() and consumed by backward().
    this._traj = [];
    this._outputs = [];
    if (inputSize <= 0 || hiddenSize <= 0 || outputSize <= 0) {
      throw new Error("RNN: all sizes must be positive");
    }
    this.inputSize = inputSize;
    this.hiddenSize = hiddenSize;
    this.outputSize = outputSize;

    // Uniform initialisation in [-limit, limit) with limit = sqrt(2 / fanIn).
    const randomMatrix = (rows, cols) => {
      const limit = Math.sqrt(2 / cols);
      return Array.from(
        { length: rows },
        () => Array.from({ length: cols }, () => (Math.random() * 2 - 1) * limit)
      );
    };
    this.Wxh = randomMatrix(hiddenSize, inputSize);
    this.Whh = randomMatrix(hiddenSize, hiddenSize);
    this.Why = randomMatrix(outputSize, hiddenSize);
    this.bh = new Array(hiddenSize).fill(0);
    this.by = new Array(outputSize).fill(0);
    this._h = new Array(hiddenSize).fill(0);

    const optimizerMatrix = (rows, cols) => Array.from(
      { length: rows },
      () => Array.from({ length: cols }, () => optimizerFactory())
    );
    this._opts = {
      Wxh: optimizerMatrix(hiddenSize, inputSize),
      Whh: optimizerMatrix(hiddenSize, hiddenSize),
      Why: optimizerMatrix(outputSize, hiddenSize),
      bh: Array.from({ length: hiddenSize }, () => optimizerFactory()),
      by: Array.from({ length: outputSize }, () => optimizerFactory())
    };
  }

  /**
   * Zeroes the hidden state and drops the cached trajectory.
   * Call at the start of each new sequence / episode.
   */
  reset() {
    this._h = new Array(this.hiddenSize).fill(0);
    this._traj = [];
    this._outputs = [];
  }

  /**
   * Runs the network over a sequence, continuing from the current hidden state.
   *
   * @param {number[][]} sequence - shape [T][inputSize]
   * @returns {{outputs: number[][], hiddens: number[][]}} per-timestep outputs
   *   ([T][outputSize]) and hidden states ([T][hiddenSize])
   */
  forward(sequence) {
    this._traj = [];
    this._outputs = [];
    const outputs = [];
    const hiddens = [];
    let hPrev = [...this._h];
    for (const x of sequence) {
      const hRaw = this.bh.map(
        (b, i) => b + this.Wxh[i].reduce((s, w, j) => s + w * x[j], 0) + this.Whh[i].reduce((s, w, j) => s + w * hPrev[j], 0)
      );
      // Math.tanh saturates to ±1 for large pre-activations; an exp-based
      // formula would overflow to NaN and poison every later timestep.
      const h = hRaw.map((v) => Math.tanh(v));
      const o = this.by.map(
        (b, i) => b + this.Why[i].reduce((s, w, j) => s + w * h[j], 0)
      );
      this._traj.push({ x: [...x], h: [...h], hRaw: [...hRaw], hPrev: [...hPrev] });
      this._outputs.push([...o]);
      outputs.push(o);
      hiddens.push(h);
      hPrev = h;
    }
    // Copy so callers mutating the returned hiddens cannot corrupt the state.
    this._h = [...hPrev];
    return { outputs, hiddens };
  }

  /**
   * Resets the hidden state, re-runs forward() on `sequence`, backpropagates
   * the MSE through time and applies one optimizer step to every parameter.
   * Gradients are summed over timesteps and each step is taken with lr / T.
   *
   * @param {number[][]} sequence - shape [T][inputSize]
   * @param {number[][]} targets - shape [T][outputSize]
   * @param {number} lr - learning rate
   * @returns {number} mean squared error over all timesteps and outputs
   *   (0 for an empty sequence)
   * @throws {Error} when targets does not cover every timestep/output
   */
  backward(sequence, targets, lr) {
    // Validate before touching any state so a bad call leaves the model intact.
    if (!targets || targets.length < sequence.length) {
      throw new Error(
        `RNN.backward: expected ${sequence.length} targets, got ${targets ? targets.length : 0}`
      );
    }
    for (let t = 0; t < sequence.length; t++) {
      if (!targets[t] || targets[t].length < this.outputSize) {
        throw new Error(
          `RNN.backward: targets[${t}] must have ${this.outputSize} values`
        );
      }
    }
    this.reset();
    const { outputs } = this.forward(sequence);
    const T = this._traj.length;
    if (T === 0) return 0;
    let loss = 0;
    const dOutputs = outputs.map((o, t) => {
      return o.map((v, k) => {
        const diff = v - targets[t][k];
        loss += diff * diff;
        return 2 * diff / this.outputSize;
      });
    });
    loss /= T * this.outputSize;
    const dWxh = Array.from({ length: this.hiddenSize }, () => new Array(this.inputSize).fill(0));
    const dWhh = Array.from({ length: this.hiddenSize }, () => new Array(this.hiddenSize).fill(0));
    const dWhy = Array.from({ length: this.outputSize }, () => new Array(this.hiddenSize).fill(0));
    const dbh = new Array(this.hiddenSize).fill(0);
    const dby = new Array(this.outputSize).fill(0);
    // Gradient flowing into h_t from timestep t + 1.
    let dhNext = new Array(this.hiddenSize).fill(0);
    for (let t = T - 1; t >= 0; t--) {
      const s = this._traj[t];
      const do_ = dOutputs[t];
      for (let i = 0; i < this.outputSize; i++) {
        for (let j = 0; j < this.hiddenSize; j++) {
          dWhy[i][j] += do_[i] * s.h[j];
        }
        dby[i] += do_[i];
      }
      // dL/dh_t = Why^T · dO_t + contribution from the future.
      const dh = Array.from(
        { length: this.hiddenSize },
        (_, j) => this.Why.reduce((sum, row, i) => sum + row[j] * do_[i], 0) + dhNext[j]
      );
      // tanh'(z) = 1 - tanh(z)^2, with tanh(z) cached as s.h.
      const dhRaw = dh.map((d, k) => d * (1 - s.h[k] ** 2));
      for (let i = 0; i < this.hiddenSize; i++) {
        for (let j = 0; j < this.inputSize; j++) {
          dWxh[i][j] += dhRaw[i] * s.x[j];
        }
        dbh[i] += dhRaw[i];
      }
      for (let i = 0; i < this.hiddenSize; i++) {
        for (let j = 0; j < this.hiddenSize; j++) {
          dWhh[i][j] += dhRaw[i] * s.hPrev[j];
        }
      }
      dhNext = Array.from(
        { length: this.hiddenSize },
        (_, j) => this.Whh.reduce((sum, row, i) => sum + row[j] * dhRaw[i], 0)
      );
    }
    const scale = lr / T;
    for (let i = 0; i < this.hiddenSize; i++) {
      for (let j = 0; j < this.inputSize; j++) {
        this.Wxh[i][j] = this._opts.Wxh[i][j].step(this.Wxh[i][j], dWxh[i][j], scale);
      }
      for (let j = 0; j < this.hiddenSize; j++) {
        this.Whh[i][j] = this._opts.Whh[i][j].step(this.Whh[i][j], dWhh[i][j], scale);
      }
      this.bh[i] = this._opts.bh[i].step(this.bh[i], dbh[i], scale);
    }
    for (let i = 0; i < this.outputSize; i++) {
      for (let j = 0; j < this.hiddenSize; j++) {
        this.Why[i][j] = this._opts.Why[i][j].step(this.Why[i][j], dWhy[i][j], scale);
      }
      this.by[i] = this._opts.by[i].step(this.by[i], dby[i], scale);
    }
    this._traj = [];
    this._outputs = [];
    return loss;
  }
};
|
|
4369
|
+
|
|
4370
|
+
// src/Seq2Seq.ts
|
|
4371
|
+
/**
 * Encoder/decoder sequence model: two LSTMLayer instances (class defined
 * elsewhere in this file) plus a linear projection from the decoder hidden
 * state to an output vector. This class relies on LSTMLayer exposing
 * reset(), predict(x), backprop(dhSeq, lr) and the h / c state arrays.
 */
var Seq2Seq = class {
  /**
   * @param {number} inputSize - length of each encoder input vector
   * @param {number} hiddenSize - hidden size shared by encoder and decoder
   * @param {number} outputSize - length of each decoder output vector
   * @param {{optimizerFactory?: Function}} [options] - factory producing one
   *   optimizer ({ step(value, grad, lr) }) per parameter; defaults to SGD
   * @throws {Error} when any size is not positive
   */
  constructor(inputSize, hiddenSize, outputSize, options) {
    const sizes = [inputSize, hiddenSize, outputSize];
    if (sizes.some((size) => size <= 0)) {
      throw new Error("Seq2Seq: all sizes must be positive");
    }
    this.inputSize = inputSize;
    this.hiddenSize = hiddenSize;
    this.outputSize = outputSize;
    const makeOptimizer = options?.optimizerFactory ?? (() => new SGD());
    this.encoder = new LSTMLayer(inputSize, hiddenSize, makeOptimizer);
    // The decoder consumes its own (or the teacher's) previous output.
    this.decoder = new LSTMLayer(outputSize, hiddenSize, makeOptimizer);

    // Projection weights: uniform in [-bound, bound), bound = sqrt(2 / hiddenSize).
    const bound = Math.sqrt(2 / hiddenSize);
    this.W_out = [];
    for (let row = 0; row < outputSize; row++) {
      const weights = [];
      for (let col = 0; col < hiddenSize; col++) {
        weights.push((Math.random() * 2 - 1) * bound);
      }
      this.W_out.push(weights);
    }
    this.b_out = new Array(outputSize).fill(0);

    this._wOutOpts = [];
    for (let row = 0; row < outputSize; row++) {
      const optimizers = [];
      for (let col = 0; col < hiddenSize; col++) {
        optimizers.push(makeOptimizer());
      }
      this._wOutOpts.push(optimizers);
    }
    this._bOutOpts = [];
    for (let row = 0; row < outputSize; row++) {
      this._bOutOpts.push(makeOptimizer());
    }
  }

  /**
   * Linear projection W_out · h + b_out.
   *
   * @param {number[]} h - hidden vector
   * @returns {number[]} vector of length outputSize
   */
  _project(h) {
    return this.b_out.map((bias, row) => {
      let dot = 0;
      this.W_out[row].forEach((weight, col) => {
        dot += weight * h[col];
      });
      return bias + dot;
    });
  }

  /**
   * Resets the encoder, feeds it the whole input sequence and returns copies
   * of its final h and c arrays (the context vector).
   *
   * @param {number[][]} inputSequence
   * @returns {{h: number[], c: number[]}}
   */
  encode(inputSequence) {
    const { encoder } = this;
    encoder.reset();
    for (const frame of inputSequence) {
      encoder.predict(frame);
    }
    const h = [...encoder.h];
    const c = [...encoder.c];
    return { h, c };
  }

  /**
   * Generates `steps` outputs autoregressively: the decoder starts from the
   * context state and a zero vector, then is fed its own previous output.
   *
   * @param {{h: number[], c: number[]}} contextVector
   * @param {number} steps
   * @returns {number[][]} shape [steps][outputSize]
   */
  decode(contextVector, steps) {
    const { decoder } = this;
    decoder.reset();
    decoder.h = [...contextVector.h];
    decoder.c = [...contextVector.c];
    const generated = [];
    let feedback = new Array(this.outputSize).fill(0);
    for (let remaining = steps; remaining > 0; remaining--) {
      feedback = this._project(decoder.predict(feedback));
      generated.push(feedback);
    }
    return generated;
  }

  /**
   * One teacher-forced training step. The decoder input at step t is the
   * target of step t - 1 (a zero vector at t = 0).
   *
   * @param {number[][]} inputSeq - shape [T_in][inputSize]
   * @param {number[][]} targetSeq - shape [T_out][outputSize]
   * @param {number} lr - learning rate
   * @returns {number} MSE over all target steps and outputs (0 if targetSeq is empty)
   */
  trainStep(inputSeq, targetSeq, lr) {
    const steps = targetSeq.length;
    if (steps === 0) {
      return 0;
    }
    const { encoder, decoder, outputSize, hiddenSize } = this;

    // Encode, then hand the final encoder state to the decoder.
    encoder.reset();
    for (const frame of inputSeq) {
      encoder.predict(frame);
    }
    const contextH = [...encoder.h];
    const contextC = [...encoder.c];
    decoder.reset();
    decoder.h = [...contextH];
    decoder.c = [...contextC];

    // Teacher-forced decoder pass.
    const hiddenStates = [];
    const predictions = [];
    let teacherInput = new Array(outputSize).fill(0);
    for (let t = 0; t < steps; t++) {
      const hidden = decoder.predict(teacherInput);
      hiddenStates.push(hidden);
      predictions.push(this._project(hidden));
      teacherInput = targetSeq[t];
    }

    // Squared error and its gradient with respect to each projected output.
    let squaredError = 0;
    const dPredictions = [];
    for (let t = 0; t < steps; t++) {
      const row = [];
      predictions[t].forEach((value, k) => {
        const err = value - targetSeq[t][k];
        squaredError += err * err;
        row.push(2 * err / outputSize);
      });
      dPredictions.push(row);
    }
    const loss = squaredError / (steps * outputSize);

    // Projection-layer gradients, plus the gradient handed to the decoder
    // hidden states (computed with the pre-update W_out).
    const dHiddenSeq = [];
    for (let t = 0; t < steps; t++) {
      dHiddenSeq.push(new Array(hiddenSize).fill(0));
    }
    const dWeights = [];
    for (let i = 0; i < outputSize; i++) {
      dWeights.push(new Array(hiddenSize).fill(0));
    }
    const dBiases = new Array(outputSize).fill(0);
    for (let t = 0; t < steps; t++) {
      const hidden = hiddenStates[t];
      const dHidden = dHiddenSeq[t];
      for (let i = 0; i < outputSize; i++) {
        const upstream = dPredictions[t][i];
        dBiases[i] += upstream;
        const weights = this.W_out[i];
        const gradRow = dWeights[i];
        for (let j = 0; j < hiddenSize; j++) {
          gradRow[j] += upstream * hidden[j];
          dHidden[j] += upstream * weights[j];
        }
      }
    }

    // NOTE(review): the projection is stepped with lr / T while both LSTM
    // layers receive the raw lr — confirm whether LSTMLayer.backprop
    // normalises by sequence length itself.
    const stepSize = lr / steps;
    for (let i = 0; i < outputSize; i++) {
      const weights = this.W_out[i];
      for (let j = 0; j < hiddenSize; j++) {
        weights[j] = this._wOutOpts[i][j].step(weights[j], dWeights[i][j], stepSize);
      }
      this.b_out[i] = this._bOutOpts[i].step(this.b_out[i], dBiases[i], stepSize);
    }

    decoder.backprop(dHiddenSeq, lr);

    // NOTE(review): the encoder receives the projection gradient of the first
    // decoder step at its last timestep, not a gradient propagated back through
    // the decoder's initial state — presumably a deliberate simplification; verify.
    const contextGrad = dHiddenSeq[0];
    const lastIndex = inputSeq.length - 1;
    const encoderGrads = inputSeq.map((_frame, t) => {
      if (t === lastIndex) {
        return [...contextGrad];
      }
      return new Array(hiddenSize).fill(0);
    });
    encoder.backprop(encoderGrads, lr);
    return loss;
  }
};
|
|
4502
|
+
|
|
4503
|
+
// src/TCN.ts
|
|
4504
|
+
/**
 * Dilated causal 1D convolution over sequences of shape [T][inputChannels].
 * The input is left-padded with (kernelSize - 1) * dilation zero frames, so
 * output[t] only depends on input[0..t] and the output keeps length T.
 */
var CausalConv1D = class {
  /**
   * @param {number} inputChannels
   * @param {number} outputChannels
   * @param {number} kernelSize - number of taps per filter
   * @param {number} dilation - spacing between taps
   * @param {Function} [optimizerFactory] - called once per scalar parameter;
   *   must return an object exposing step(value, gradient, lr) -> newValue
   * @throws {Error} when any dimension is not positive
   */
  constructor(inputChannels, outputChannels, kernelSize, dilation, optimizerFactory = () => new SGD()) {
    // Cache written by forward() for use in backward().
    this._paddedInput = null;
    this._inputLen = 0;
    const dims = [inputChannels, outputChannels, kernelSize, dilation];
    if (dims.some((dim) => dim <= 0)) {
      throw new Error("CausalConv1D: all dimensions must be positive");
    }
    this.inputChannels = inputChannels;
    this.outputChannels = outputChannels;
    this.kernelSize = kernelSize;
    this.dilation = dilation;

    // Builds an [outputChannels][kernelSize][inputChannels] tensor, calling
    // `make` once per cell in row-major order.
    const tensor = (make) => {
      const filters = [];
      for (let f = 0; f < outputChannels; f++) {
        const taps = [];
        for (let k = 0; k < kernelSize; k++) {
          const cells = [];
          for (let c = 0; c < inputChannels; c++) {
            cells.push(make());
          }
          taps.push(cells);
        }
        filters.push(taps);
      }
      return filters;
    };

    // Uniform in [-bound, bound) with bound = sqrt(2 / fanIn).
    const bound = Math.sqrt(2 / (kernelSize * inputChannels));
    this.kernels = tensor(() => (Math.random() * 2 - 1) * bound);
    this.biases = new Array(outputChannels).fill(0);
    this._kOpts = tensor(() => optimizerFactory());
    this._bOpts = [];
    for (let f = 0; f < outputChannels; f++) {
      this._bOpts.push(optimizerFactory());
    }
  }

  /**
   * @param {number[][]} input - shape [T][inputChannels]
   * @returns {number[][]} shape [T][outputChannels]
   */
  forward(input) {
    const steps = input.length;
    const { kernelSize, dilation, inputChannels, outputChannels } = this;
    const lead = (kernelSize - 1) * dilation;

    // Zero frames first, then a defensive copy of every input frame.
    const padded = [];
    for (let i = 0; i < lead; i++) {
      padded.push(new Array(inputChannels).fill(0));
    }
    for (const frame of input) {
      padded.push([...frame]);
    }
    this._paddedInput = padded;
    this._inputLen = steps;

    return Array.from({ length: steps }, (_t, t) =>
      Array.from({ length: outputChannels }, (_f, f) => {
        let acc = this.biases[f];
        for (let k = 0; k < kernelSize; k++) {
          // The last tap (k = kernelSize - 1) lands on the current timestep.
          const frame = padded[t + k * dilation];
          const taps = this.kernels[f][k];
          for (let c = 0; c < inputChannels; c++) {
            acc += taps[c] * frame[c];
          }
        }
        return acc;
      })
    );
  }

  /**
   * Accumulates gradients over the cached input, applies one optimizer step
   * per parameter and returns the gradient with respect to the unpadded input.
   *
   * @param {number[][]} dOutput - shape [T][outputChannels]
   * @param {number} lr - learning rate handed to every optimizer step
   * @returns {number[][]} shape [T][inputChannels]
   * @throws {Error} when forward() has not been called yet
   */
  backward(dOutput, lr) {
    if (!this._paddedInput) {
      throw new Error("CausalConv1D.backward: call forward() first");
    }
    const { kernelSize, dilation, inputChannels, outputChannels } = this;
    const steps = this._inputLen;
    const lead = (kernelSize - 1) * dilation;
    const frames = this._paddedInput;

    const gradKernels = [];
    for (let f = 0; f < outputChannels; f++) {
      const taps = [];
      for (let k = 0; k < kernelSize; k++) {
        taps.push(new Array(inputChannels).fill(0));
      }
      gradKernels.push(taps);
    }
    const gradBiases = new Array(outputChannels).fill(0);
    const gradPadded = frames.map(() => new Array(inputChannels).fill(0));

    // Input gradients use the kernel values from before the update below.
    for (let t = 0; t < steps; t++) {
      for (let f = 0; f < outputChannels; f++) {
        const upstream = dOutput[t][f];
        gradBiases[f] += upstream;
        for (let k = 0; k < kernelSize; k++) {
          const pos = t + k * dilation;
          const frame = frames[pos];
          const frameGrad = gradPadded[pos];
          const taps = this.kernels[f][k];
          const tapGrads = gradKernels[f][k];
          for (let c = 0; c < inputChannels; c++) {
            tapGrads[c] += upstream * frame[c];
            frameGrad[c] += upstream * taps[c];
          }
        }
      }
    }

    for (let f = 0; f < outputChannels; f++) {
      for (let k = 0; k < kernelSize; k++) {
        const taps = this.kernels[f][k];
        for (let c = 0; c < inputChannels; c++) {
          taps[c] = this._kOpts[f][k][c].step(taps[c], gradKernels[f][k][c], lr);
        }
      }
      this.biases[f] = this._bOpts[f].step(this.biases[f], gradBiases[f], lr);
    }

    // Drop the gradient that fell on the zero padding.
    return gradPadded.slice(lead);
  }
};
|
|
4615
|
+
/**
 * Temporal convolutional network: a stack of CausalConv1D layers with
 * dilations 1, 2, 4, ... each followed by ReLU (and a residual connection
 * when the layer keeps the channel count), topped by a per-timestep linear
 * projection to `outputSize` values.
 */
var TCN = class {
  /**
   * @param {number} inputChannels - channels of each input frame
   * @param {number} channels - channels produced by every conv layer
   * @param {number} kernelSize - taps per conv filter
   * @param {number} levels - number of conv layers; layer l uses dilation 2^l
   * @param {number} outputSize - length of each output vector
   * @param {Function} [optimizerFactory] - called once per scalar parameter;
   *   must return an object exposing step(value, gradient, lr) -> newValue
   * @throws {Error} when levels or outputSize is not positive
   */
  constructor(inputChannels, channels, kernelSize, levels, outputSize, optimizerFactory = () => new SGD()) {
    // Caches written by forward() for use in train():
    this._layerInputs = [];   // input to each conv layer
    this._layerOutputs = [];  // pre-ReLU output of each conv layer
    this._lastHidden = [];    // activations fed to the linear projection
    this._finalOutputs = [];  // linear projection outputs
    if (levels <= 0) throw new Error("TCN: levels must be positive");
    if (outputSize <= 0) throw new Error("TCN: outputSize must be positive");
    this.inputChannels = inputChannels;
    this.channels = channels;
    this.kernelSize = kernelSize;
    this.levels = levels;
    this.outputSize = outputSize;
    this.layers = [];
    for (let l = 0; l < levels; l++) {
      const dilation = Math.pow(2, l);
      const inCh = l === 0 ? inputChannels : channels;
      this.layers.push(new CausalConv1D(inCh, channels, kernelSize, dilation, optimizerFactory));
    }
    // Projection weights: uniform in [-outLimit, outLimit).
    const outLimit = Math.sqrt(2 / channels);
    this._outputW = Array.from(
      { length: outputSize },
      () => Array.from({ length: channels }, () => (Math.random() * 2 - 1) * outLimit)
    );
    this._outputB = new Array(outputSize).fill(0);
    this._outOpts = Array.from(
      { length: outputSize },
      () => Array.from({ length: channels }, () => optimizerFactory())
    );
    this._bOutOpts = Array.from({ length: outputSize }, () => optimizerFactory());
  }

  /**
   * Receptive field in timesteps: (kernelSize - 1) · (2^levels - 1) + 1.
   * @returns {number}
   */
  get receptiveField() {
    return (this.kernelSize - 1) * (Math.pow(2, this.levels) - 1) + 1;
  }

  /**
   * @param {number[][]} sequence - shape [T][inputChannels]
   * @returns {number[][]} shape [T][outputSize]; [] for an empty sequence
   */
  forward(sequence) {
    this._layerInputs = [];
    this._layerOutputs = [];
    // An empty sequence has no frame to read the channel count from; return
    // early instead of failing on current[0] below.
    if (sequence.length === 0) {
      this._lastHidden = [];
      this._finalOutputs = [];
      return [];
    }
    let current = sequence;
    for (let l = 0; l < this.levels; l++) {
      this._layerInputs.push(current.map((row) => [...row]));
      const convOut = this.layers[l].forward(current);
      this._layerOutputs.push(convOut);
      const afterRelu = convOut.map((row) => row.map((v) => relu.fn(v)));
      // Residual connection only when the layer preserves the channel count.
      if (current[0].length === afterRelu[0].length) {
        const skip = current;
        current = afterRelu.map((row, t) => row.map((v, c) => v + skip[t][c]));
      } else {
        current = afterRelu;
      }
    }
    this._lastHidden = current;
    const T = current.length;
    this._finalOutputs = Array.from(
      { length: T },
      (_, t) => this._outputB.map(
        (b, i) => b + this._outputW[i].reduce((s, w, j) => s + w * current[t][j], 0)
      )
    );
    return this._finalOutputs.map((row) => [...row]);
  }

  /**
   * Runs forward(), then performs one MSE gradient step on the projection
   * and on every conv layer.
   *
   * @param {number[][]} sequence - shape [T][inputChannels]
   * @param {number[][]} targets - shape [T][outputSize]
   * @param {number} lr - learning rate
   * @returns {number} MSE over all timesteps and outputs (0 for an empty sequence)
   * @throws {Error} when targets does not cover every timestep/output
   */
  train(sequence, targets, lr) {
    // Nothing to learn from; also avoids a 0 / 0 loss.
    if (sequence.length === 0) return 0;
    // Validate before any parameter is updated.
    if (!targets || targets.length < sequence.length) {
      throw new Error(
        `TCN.train: expected ${sequence.length} targets, got ${targets ? targets.length : 0}`
      );
    }
    for (let t = 0; t < sequence.length; t++) {
      if (!targets[t] || targets[t].length < this.outputSize) {
        throw new Error(`TCN.train: targets[${t}] must have ${this.outputSize} values`);
      }
    }
    const outputs = this.forward(sequence);
    const T = outputs.length;
    let loss = 0;
    const dOut = outputs.map(
      (o, t) => o.map((v, k) => {
        const diff = v - targets[t][k];
        loss += diff * diff;
        return 2 * diff / this.outputSize;
      })
    );
    loss /= T * this.outputSize;
    const dWout = Array.from({ length: this.outputSize }, () => new Array(this.channels).fill(0));
    const dBout = new Array(this.outputSize).fill(0);
    const dHidden = Array.from({ length: T }, () => new Array(this.channels).fill(0));
    // dHidden is accumulated with the pre-update projection weights.
    for (let t = 0; t < T; t++) {
      for (let i = 0; i < this.outputSize; i++) {
        const dv = dOut[t][i];
        dBout[i] += dv;
        for (let j = 0; j < this.channels; j++) {
          dWout[i][j] += dv * this._lastHidden[t][j];
          dHidden[t][j] += dv * this._outputW[i][j];
        }
      }
    }
    const scale = lr / T;
    for (let i = 0; i < this.outputSize; i++) {
      for (let j = 0; j < this.channels; j++) {
        this._outputW[i][j] = this._outOpts[i][j].step(this._outputW[i][j], dWout[i][j], scale);
      }
      this._outputB[i] = this._bOutOpts[i].step(this._outputB[i], dBout[i], scale);
    }
    let dCurrent = dHidden;
    for (let l = this.levels - 1; l >= 0; l--) {
      const convOut = this._layerOutputs[l];
      const layerIn = this._layerInputs[l];
      // ReLU derivative: pass the gradient only where the pre-activation was positive.
      const dConvOut = dCurrent.map(
        (row, t) => row.map((d, c) => d * (convOut[t][c] > 0 ? 1 : 0))
      );
      // NOTE(review): conv layers are stepped with the raw lr while the
      // projection above uses lr / T, so their effective step is T times
      // larger relative to the mean loss — confirm this is intended.
      let dPrevLayer = this.layers[l].backward(dConvOut, lr);
      // Residual branch: the skip connection passes the gradient through unchanged.
      if (layerIn[0].length === dCurrent[0].length) {
        const dSkip = dCurrent;
        dPrevLayer = dPrevLayer.map(
          (row, t) => row.map((d, c) => d + dSkip[t][c])
        );
      }
      dCurrent = dPrevLayer;
    }
    return loss;
  }
};
|
|
4738
|
+
|
|
4739
|
+
// src/Word2Vec.ts
|
|
4740
|
+
var Word2Vec = class {
  // ── Constructor ────────────────────────────────────────────────────────────
  // embeddingDim: length of every word vector (default 50).
  // options.windowSize: context words taken on each side of the center (default 2).
  // options.model: "skipgram" (default); any other value selects the CBOW path.
  // options.minCount: words seen fewer times than this are dropped (default 1).
  //
  // Weight tables stay empty until buildVocab() (or the first train()) runs.
  constructor(embeddingDim = 50, options = {}) {
    this._trained = false;
    this.embeddingDim = embeddingDim;
    this._windowSize = options.windowSize ?? 2;
    this._model = options.model ?? "skipgram";
    this._minCount = options.minCount ?? 1;
    this.embeddings = [];
    this._W2 = [];
    this.vocab = /* @__PURE__ */ new Map();
    this._indexToWord = [];
    this.vocabSize = 0;
  }
  // ── buildVocab ─────────────────────────────────────────────────────────────
  // Scans the corpus, counts word frequencies, discards rare words (< minCount),
  // and assigns each remaining word a unique integer index (first-seen order).
  // Also (re)initializes both weight matrices with uniform random values:
  //   embeddings: [vocabSize][embeddingDim], range ±sqrt(1 / embeddingDim)
  //   _W2:        [embeddingDim][vocabSize], range ±sqrt(1 / vocabSize)
  // Throws if no word survives the minCount filter.
  buildVocab(sentences) {
    const freq = /* @__PURE__ */ new Map();
    for (const sentence of sentences) {
      for (const word of sentence) {
        freq.set(word, (freq.get(word) ?? 0) + 1);
      }
    }
    this.vocab = /* @__PURE__ */ new Map();
    this._indexToWord = [];
    for (const [word, count] of freq) {
      if (count >= this._minCount) {
        const idx = this._indexToWord.length;
        this.vocab.set(word, idx);
        this._indexToWord.push(word);
      }
    }
    this.vocabSize = this._indexToWord.length;
    if (this.vocabSize === 0) {
      throw new Error("Word2Vec.buildVocab: vocabulary is empty after applying minCount filter");
    }
    const scale1 = Math.sqrt(1 / this.embeddingDim);
    const scale2 = Math.sqrt(1 / this.vocabSize);
    this.embeddings = Array.from(
      { length: this.vocabSize },
      () => Array.from({ length: this.embeddingDim }, () => (Math.random() * 2 - 1) * scale1)
    );
    this._W2 = Array.from(
      { length: this.embeddingDim },
      () => Array.from({ length: this.vocabSize }, () => (Math.random() * 2 - 1) * scale2)
    );
    this._trained = false;
  }
  // ── tokenize ───────────────────────────────────────────────────────────────
  // Simple tokenizer: lowercase, replace every character other than a-z, 0-9,
  // whitespace, apostrophe and hyphen with a space, then split on whitespace.
  // Returns an array of non-empty tokens suitable for buildVocab / train.
  static tokenize(text) {
    return text.toLowerCase().replace(/[^a-z0-9\s'-]/g, " ").split(/\s+/).filter((t) => t.length > 0);
  }
  // ── train ──────────────────────────────────────────────────────────────────
  // Runs SGD over all (center, context) pairs in the corpus for `epochs` passes.
  // Builds the vocabulary from `sentences` first if none exists yet.
  // Returns an array with the average cross-entropy loss of each epoch
  // (0 for an epoch that produced no training pairs).
  //
  // Note: uses full-vocabulary softmax (not negative sampling) for educational
  // clarity. This is O(vocabSize) per step — for large vocabularies you would
  // normally switch to negative sampling or hierarchical softmax.
  train(sentences, lr = 0.025, epochs = 5) {
    if (this.vocabSize === 0) this.buildVocab(sentences);
    const lossHistory = [];
    for (let epoch = 0; epoch < epochs; epoch++) {
      let totalLoss = 0;
      let nPairs = 0;
      for (const sentence of sentences) {
        // Out-of-vocabulary words are removed before windows are formed, so a
        // window can reach across the gap a dropped word leaves behind.
        const indices = sentence.map((w) => this.vocab.get(w)).filter((idx) => idx !== void 0);
        for (let t = 0; t < indices.length; t++) {
          const centerIdx = indices[t];
          const contextIndices = [];
          for (let offset = -this._windowSize; offset <= this._windowSize; offset++) {
            if (offset === 0) continue;
            const pos = t + offset;
            if (pos >= 0 && pos < indices.length) {
              contextIndices.push(indices[pos]);
            }
          }
          if (contextIndices.length === 0) continue;
          if (this._model === "skipgram") {
            // One update per (center, context) pair.
            for (const contextIdx of contextIndices) {
              totalLoss += this._skipgramStep(centerIdx, contextIdx, lr);
              nPairs++;
            }
          } else {
            // One update per center word, using the averaged context.
            totalLoss += this._cbowStep(centerIdx, contextIndices, lr);
            nPairs++;
          }
        }
      }
      lossHistory.push(nPairs > 0 ? totalLoss / nPairs : 0);
    }
    this._trained = true;
    return lossHistory;
  }
  // ── getEmbedding ───────────────────────────────────────────────────────────
  // Returns the embedding vector for a word (the live internal row, not a copy).
  // Throws if the word is not in the vocabulary.
  getEmbedding(word) {
    const idx = this.vocab.get(word);
    if (idx === void 0) throw new Error(`Word2Vec: unknown word "${word}"`);
    return this.embeddings[idx];
  }
  // ── similarity ─────────────────────────────────────────────────────────────
  // Cosine similarity between two words.
  //   cos(v1, v2) = (v1 · v2) / (‖v1‖ · ‖v2‖)
  // Returns a value in [-1, 1]. Higher → more similar context usage.
  // Throws if either word is unknown.
  similarity(word1, word2) {
    const v1 = this.getEmbedding(word1);
    const v2 = this.getEmbedding(word2);
    return this._cosine(v1, v2);
  }
  // ── mostSimilar ────────────────────────────────────────────────────────────
  // Returns up to topK { word, score } entries (excluding `word` itself),
  // sorted by descending cosine similarity.
  mostSimilar(word, topK = 10) {
    const v = this.getEmbedding(word);
    return this._nearestByVector(v, topK, /* @__PURE__ */ new Set([word]));
  }
  // ── analogy ───────────────────────────────────────────────────────────────
  // Vector arithmetic analogy: positive1 - negative + positive2 ≈ result
  //
  //   analogy('king', 'man', 'woman') finds the words closest to
  //   vec('king') - vec('man') + vec('woman')  ≈  vec('queen')
  //
  // The three input words are excluded from the results so they don't pollute
  // the top-K.
  analogy(positive1, negative, positive2, topK = 5) {
    const vPos1 = this.getEmbedding(positive1);
    const vNeg = this.getEmbedding(negative);
    const vPos2 = this.getEmbedding(positive2);
    const target = vPos1.map((v, i) => v - vNeg[i] + vPos2[i]);
    const exclude = /* @__PURE__ */ new Set([positive1, negative, positive2]);
    return this._nearestByVector(target, topK, exclude);
  }
  // ── Private: skip-gram step ───────────────────────────────────────────────
  // Forward + backward for one (center, target) pair.
  // Returns the cross-entropy loss for this pair.
  _skipgramStep(centerIdx, targetIdx, lr) {
    const h = this.embeddings[centerIdx];
    const { loss, dh } = this._outputLayerStep(h, targetIdx, lr);
    for (let d = 0; d < this.embeddingDim; d++) {
      this.embeddings[centerIdx][d] -= lr * dh[d];
    }
    return loss;
  }
  // ── Private: CBOW step ────────────────────────────────────────────────────
  // Forward + backward for one (contextIndices → centerIdx) pair.
  // h is the mean of all context embeddings. The gradient is distributed
  // equally back to each context word's embedding row.
  _cbowStep(centerIdx, contextIndices, lr) {
    const k = contextIndices.length;
    const h = new Array(this.embeddingDim).fill(0);
    for (const ci of contextIndices) {
      for (let d = 0; d < this.embeddingDim; d++) {
        h[d] += this.embeddings[ci][d];
      }
    }
    for (let d = 0; d < this.embeddingDim; d++) h[d] /= k;
    const { loss, dh } = this._outputLayerStep(h, centerIdx, lr);
    for (const ci of contextIndices) {
      for (let d = 0; d < this.embeddingDim; d++) {
        this.embeddings[ci][d] -= lr * dh[d] / k;
      }
    }
    return loss;
  }
  // ── Private: shared softmax output layer ──────────────────────────────────
  // Runs the forward pass h → scores → softmax, applies the SGD update to _W2,
  // and returns { loss, dh } where dh = ∂loss/∂h.
  //
  // dh is accumulated from each weight BEFORE that weight is updated, so the
  // hidden-layer gradient and the output-layer gradient are both evaluated at
  // the same (pre-step) parameters. `h` itself is only read here.
  _outputLayerStep(h, targetIdx, lr) {
    const probs = _softmax(this._hiddenToScores(h));
    const loss = -Math.log(probs[targetIdx] + 1e-12);
    const dh = new Array(this.embeddingDim).fill(0);
    for (let d = 0; d < this.embeddingDim; d++) {
      const row = this._W2[d];
      const hd = h[d];
      let acc = 0;
      for (let j = 0; j < this.vocabSize; j++) {
        // Softmax + cross-entropy gradient w.r.t. score j: p_j - 1[j == target].
        const err = j === targetIdx ? probs[j] - 1 : probs[j];
        acc += row[j] * err;
        row[j] -= lr * hd * err;
      }
      dh[d] = acc;
    }
    return { loss, dh };
  }
  // Computes scores = h · W2  →  [vocabSize]
  _hiddenToScores(h) {
    const scores = new Array(this.vocabSize).fill(0);
    for (let d = 0; d < this.embeddingDim; d++) {
      for (let j = 0; j < this.vocabSize; j++) {
        scores[j] += h[d] * this._W2[d][j];
      }
    }
    return scores;
  }
  // Returns up to topK { word, score } entries (from all embeddings) sorted by
  // descending cosine similarity to v, skipping any word in the exclude set.
  _nearestByVector(v, topK, exclude) {
    const results = [];
    for (let i = 0; i < this.vocabSize; i++) {
      const w = this._indexToWord[i];
      if (exclude.has(w)) continue;
      results.push({ word: w, score: this._cosine(v, this.embeddings[i]) });
    }
    results.sort((a, b) => b.score - a.score);
    return results.slice(0, topK);
  }
  // Cosine similarity: (v1 · v2) / (‖v1‖ · ‖v2‖)
  // Returns 0 when either vector is (numerically) zero to avoid dividing by 0.
  _cosine(v1, v2) {
    let dot = 0, n1 = 0, n2 = 0;
    for (let i = 0; i < v1.length; i++) {
      dot += v1[i] * v2[i];
      n1 += v1[i] * v1[i];
      n2 += v2[i] * v2[i];
    }
    const denom = Math.sqrt(n1) * Math.sqrt(n2);
    return denom < 1e-12 ? 0 : dot / denom;
  }
};
|
|
4959
|
+
// Numerically stable softmax over an array of scores.
// The maximum score is subtracted before exponentiating so Math.exp cannot
// overflow. The maximum is found with a plain loop rather than
// Math.max(...scores): spreading a vocabulary-sized array into call arguments
// can exceed the engine's argument limit and throw a RangeError.
// Returns a new array of the same length; an empty input yields [].
function _softmax(scores) {
  let max = -Infinity;
  for (let i = 0; i < scores.length; i++) {
    if (scores[i] > max) max = scores[i];
  }
  const exps = scores.map((s) => Math.exp(s - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
}
|
|
4965
|
+
|
|
4966
|
+
// src/TSNE.ts
|
|
4967
|
+
var TSNE = class {
  // ── Constructor ────────────────────────────────────────────────────────────
  // options.nComponents: output dimensionality (default 2).
  // options.perplexity:  target perplexity of each conditional P(·|i) (default 30).
  // options.lr:          gradient-descent step size (default 200).
  // options.nIter:       number of gradient iterations (default 1000).
  // options.seed:        when defined, the initial layout is drawn from a
  //                      seeded generator instead of Math.random.
  constructor(options = {}) {
    // KL divergence tracked during the last fit() call.
    this._klDivergence = 0;
    // P matrix stored for kl() reporting.
    this._P = [];
    this._nComponents = options.nComponents ?? 2;
    this._perplexity = options.perplexity ?? 30;
    this._lr = options.lr ?? 200;
    this._nIter = options.nIter ?? 1e3;
    this._seed = options.seed;
    this.embedding = [];
  }
  // ── fit ────────────────────────────────────────────────────────────────────
  // Runs the full t-SNE algorithm on X (shape [n][d]).
  // Stores the result in this.embedding ([n][nComponents]) and the final
  // KL(P ‖ Q) for kl(). Throws when n < 2 or perplexity >= n.
  fit(X) {
    const n = X.length;
    if (n < 2) throw new Error("TSNE.fit: need at least 2 data points");
    if (this._perplexity >= n) {
      throw new Error(
        `TSNE.fit: perplexity (${this._perplexity}) must be less than n (${n})`
      );
    }
    const rng = this._seed !== void 0 ? _mulberry32(this._seed) : Math.random;
    const distSq = _pairwiseDistSq(X, n);
    const Pcond = this._computePcond(distSq, n);
    const P = _symmetrize(Pcond, n);
    this._P = P;
    // Initial layout: small Gaussian noise (Box-Muller, scaled by 0.01).
    let Y = Array.from({ length: n }, () => {
      return Array.from({ length: this._nComponents }, () => {
        const u1 = Math.max(rng(), 1e-12); // keep log(u1) finite
        const u2 = rng();
        const z = Math.sqrt(-2 * Math.log(u1)) * Math.cos(2 * Math.PI * u2);
        return z * 0.01;
      });
    });
    let Yprev = Y.map((row) => [...row]);
    const EXAGGERATION_ITERS = 50;
    const EXAGGERATION_FACTOR = 4;
    const MOMENTUM_SWITCH = 20;
    for (let iter = 0; iter < this._nIter; iter++) {
      const momentum = iter < MOMENTUM_SWITCH ? 0.5 : 0.8;
      // P is multiplied by EXAGGERATION_FACTOR for the first iterations.
      const pScale = iter < EXAGGERATION_ITERS ? EXAGGERATION_FACTOR : 1;
      const { Q, invDist } = _computeQ(Y, n, this._nComponents);
      const grad = Array.from(
        { length: n },
        () => new Array(this._nComponents).fill(0)
      );
      // grad_i = 4 · Σ_j (p_ij − q_ij) · (1 + ‖y_i − y_j‖²)⁻¹ · (y_i − y_j)
      for (let i = 0; i < n; i++) {
        for (let j = 0; j < n; j++) {
          if (i === j) continue;
          const pq = pScale * P[i][j] - Q[i][j];
          const c = 4 * pq * invDist[i][j];
          for (let d = 0; d < this._nComponents; d++) {
            grad[i][d] += c * (Y[i][d] - Y[j][d]);
          }
        }
      }
      // Gradient step plus a momentum term built from the previous displacement.
      const Ynext = Array.from(
        { length: n },
        (_, i) => Array.from(
          { length: this._nComponents },
          (_2, d) => Y[i][d] - this._lr * grad[i][d] + momentum * (Y[i][d] - Yprev[i][d])
        )
      );
      Yprev = Y;
      Y = Ynext;
    }
    this.embedding = Y;
    // Final KL(P ‖ Q), computed with the un-exaggerated P.
    const { Q: Qfinal } = _computeQ(Y, n, this._nComponents);
    let kl = 0;
    for (let i = 0; i < n; i++) {
      for (let j = 0; j < n; j++) {
        if (i === j) continue;
        const p = P[i][j];
        if (p > 1e-12) {
          kl += p * Math.log(p / (Qfinal[i][j] + 1e-12));
        }
      }
    }
    this._klDivergence = kl;
  }
  // ── fitTransform ───────────────────────────────────────────────────────────
  // Convenience: fit() then return this.embedding.
  fitTransform(X) {
    this.fit(X);
    return this.embedding;
  }
  // ── kl ─────────────────────────────────────────────────────────────────────
  // Returns the KL divergence KL(P ‖ Q) from the last fit() call.
  // Lower is better. Useful for comparing perplexity settings or iteration counts.
  kl() {
    return this._klDivergence;
  }
  // ── Private: binary search for σi ─────────────────────────────────────────
  // For each point i, find σi such that the Shannon entropy of P(·|i) equals
  // log₂(perplexity). We use binary search on σ² (at most 50 attempts per row).
  //
  // distSq: [n][n] squared pairwise distances. Returns Pcond: [n][n] with a
  // zero diagonal; each row holds the conditional distribution P(j|i).
  //
  // Distances are shifted by the row's smallest off-diagonal distance before
  // exponentiating. The shift cancels in the normalization, so probabilities
  // are unchanged, but the nearest neighbour always contributes exp(0) = 1.
  // Without it, widely spread data underflows every exp() to 0 at the initial
  // σ² = 1 and the row would be left all zeros.
  _computePcond(distSq, n) {
    const targetEntropy = Math.log2(this._perplexity);
    const Pcond = Array.from({ length: n }, () => new Array(n).fill(0));
    for (let i = 0; i < n; i++) {
      const dists = distSq[i];
      let minDist = Infinity;
      for (let j = 0; j < n; j++) {
        if (j !== i && dists[j] < minDist) minDist = dists[j];
      }
      let sigmaLo = 0;
      let sigmaHi = 1e10;
      let sigma2 = 1;
      for (let attempt = 0; attempt < 50; attempt++) {
        let sumExp = 0;
        const exps = new Array(n).fill(0);
        for (let j = 0; j < n; j++) {
          if (j === i) continue;
          const e = Math.exp(-(dists[j] - minDist) / (2 * sigma2));
          exps[j] = e;
          sumExp += e;
        }
        // Only reachable for degenerate input (e.g. no neighbours at all).
        if (sumExp < 1e-12) break;
        let H = 0;
        for (let j = 0; j < n; j++) {
          if (j === i) continue;
          const p = exps[j] / sumExp;
          Pcond[i][j] = p;
          if (p > 1e-12) H -= p * Math.log2(p);
        }
        const delta = H - targetEntropy;
        if (Math.abs(delta) < 1e-5) break;
        if (delta > 0) {
          // Entropy too high → distribution too flat → shrink σ².
          sigmaHi = sigma2;
          sigma2 = (sigmaLo + sigma2) / 2;
        } else {
          // Entropy too low → grow σ²: double until an upper bound has been
          // found, then bisect.
          sigmaLo = sigma2;
          sigma2 = sigmaHi < 1e9 ? (sigma2 + sigmaHi) / 2 : sigma2 * 2;
        }
      }
    }
    return Pcond;
  }
};
|
|
5104
|
+
// Builds the symmetric [n][n] matrix of squared Euclidean distances between
// the first n rows of X. Only the upper triangle is computed; each value is
// mirrored into the lower triangle. The diagonal stays 0.
function _pairwiseDistSq(X, n) {
  const table = [];
  for (let r = 0; r < n; r++) {
    table.push(new Array(n).fill(0));
  }
  for (let a = 0; a < n; a++) {
    const rowA = X[a];
    for (let b = a + 1; b < n; b++) {
      const rowB = X[b];
      let total = 0;
      let k = 0;
      while (k < rowA.length) {
        const delta = rowA[k] - rowB[k];
        total += delta * delta;
        k += 1;
      }
      table[a][b] = total;
      table[b][a] = total;
    }
  }
  return table;
}
|
|
5119
|
+
// Turns the conditional probabilities P(j|i) into the symmetric joint matrix
//   P[i][j] = (Pcond[i][j] + Pcond[j][i]) / (2n)
// and returns it as a new [n][n] array.
function _symmetrize(Pcond, n) {
  const twoN = 2 * n;
  return Array.from({ length: n }, (_row, i) => {
    const joint = new Array(n).fill(0);
    for (let j = 0; j < n; j++) {
      joint[j] = (Pcond[i][j] + Pcond[j][i]) / twoN;
    }
    return joint;
  });
}
|
|
5128
|
+
// Low-dimensional affinities from a Student-t kernel with one degree of freedom.
//   invDist[i][j] = 1 / (1 + ‖y_i − y_j‖²)   (0 on the diagonal)
//   Q[i][j]       = invDist[i][j] / Z,  Z = sum of all off-diagonal invDist
// Z is clamped to at least 1e-12 so the division is always defined.
// Returns { Q, invDist }; invDist is the un-normalized kernel matrix.
function _computeQ(Y, n, nComponents) {
  const kernel = [];
  for (let r = 0; r < n; r++) {
    kernel.push(new Array(n).fill(0));
  }
  let normalizer = 0;
  for (let a = 0; a < n; a++) {
    for (let b = a + 1; b < n; b++) {
      let sq = 0;
      for (let axis = 0; axis < nComponents; axis++) {
        const gap = Y[a][axis] - Y[b][axis];
        sq += gap * gap;
      }
      const weight = 1 / (1 + sq);
      kernel[a][b] = weight;
      kernel[b][a] = weight;
      // Each unordered pair contributes twice: (a, b) and (b, a).
      normalizer += 2 * weight;
    }
  }
  if (normalizer < 1e-12) {
    normalizer = 1e-12;
  }
  const Q = kernel.map((row) => row.map((value) => value / normalizer));
  return { Q, invDist: kernel };
}
|
|
5151
|
+
// Seeded pseudo-random generator. Returns a function that yields a number in
// [0, 1) on each call; the same seed always reproduces the same sequence.
// The seed is coerced to an unsigned 32-bit integer, and all state arithmetic
// stays in 32 bits via >>> 0 and Math.imul.
function _mulberry32(seed) {
  let state = seed >>> 0;
  const next = () => {
    state = (state + 1831565813) >>> 0;
    let mixed = state;
    mixed = Math.imul(mixed ^ (mixed >>> 15), mixed | 1);
    mixed = mixed ^ (mixed + Math.imul(mixed ^ (mixed >>> 7), mixed | 61));
    mixed = (mixed ^ (mixed >>> 14)) >>> 0;
    return mixed / 4294967296;
  };
  return next;
}
|
|
5162
|
+
|
|
5163
|
+
// src/PositionalEncoding.ts
|
|
5164
|
+
var PositionalEncoding = class _PositionalEncoding {
  // Compute the full PE vector for one token at position `pos`.
  // Returns an array of length `dModel`.
  //
  // Each pair of dimensions (2i, 2i+1) shares the same frequency 1/10000^(2i/dModel)
  // but is 90° out of phase (sin vs cos), which ensures no two positions produce
  // the identical vector. When dModel is odd, the final unpaired dimension
  // receives only the sin term.
  static encode(pos, dModel) {
    const pe = new Array(dModel);
    for (let i = 0; i < Math.floor(dModel / 2); i++) {
      const freq = Math.pow(1e4, 2 * i / dModel);
      pe[2 * i] = Math.sin(pos / freq);
      pe[2 * i + 1] = Math.cos(pos / freq);
    }
    if (dModel % 2 !== 0) {
      const i = Math.floor(dModel / 2);
      const freq = Math.pow(1e4, 2 * i / dModel);
      pe[dModel - 1] = Math.sin(pos / freq);
    }
    return pe;
  }
  // Build the full positional encoding matrix for a sequence of `seqLen` tokens.
  // Returns shape [seqLen][dModel].
  //
  // In practice this matrix is computed once and cached — it doesn't change
  // across examples, batches, or epochs.
  static encodeSequence(seqLen, dModel) {
    return Array.from(
      { length: seqLen },
      (_, pos) => _PositionalEncoding.encode(pos, dModel)
    );
  }
  // Add positional encoding to an embedding matrix and return the result as a
  // new matrix (the input is not modified).
  //
  // `embeddings` shape: [seqLen][dModel]; dModel is taken from the first row.
  // `seqLen` is optional and kept for backward compatibility. Every row of
  // `embeddings` receives the encoding for its own index, so a `seqLen`
  // smaller than embeddings.length no longer crashes on the extra rows.
  // An empty `embeddings` array yields [].
  //
  // The sum e = token_embedding + PE is what actually enters the first
  // Transformer layer. Summing (rather than concatenating) keeps the model
  // dimension fixed and lets the network distribute its capacity freely —
  // it can choose how much of each dimension to allocate to content vs. position.
  static apply(embeddings, seqLen) {
    if (embeddings.length === 0) return [];
    const dModel = embeddings[0].length;
    return embeddings.map((emb, pos) => {
      const pe = _PositionalEncoding.encode(pos, dModel);
      return emb.map((val, d) => val + pe[d]);
    });
  }
};
|
|
5214
|
+
var LearnedPositionalEncoding = class {
  // Trainable position table of shape [maxSeqLen][dModel].
  // Every weight is drawn uniformly from ±sqrt(1 / dModel).
  constructor(maxSeqLen, dModel) {
    this.maxSeqLen = maxSeqLen;
    this.dModel = dModel;
    const limit = Math.sqrt(1 / dModel);
    this.weights = Array.from(
      { length: maxSeqLen },
      () => Array.from({ length: dModel }, () => (Math.random() * 2 - 1) * limit)
    );
  }
  // Return the learned encoding for one position.
  // Returns a copy so callers cannot accidentally mutate the weight table.
  // Throws an Error when `pos` is not an integer in [0, maxSeqLen).
  getEncoding(pos) {
    if (pos >= this.maxSeqLen) {
      throw new Error(
        `Position ${pos} exceeds maxSeqLen=${this.maxSeqLen}. Learned encodings cannot generalize beyond their training length.`
      );
    }
    // Negative or fractional positions have no row in the table; fail with a
    // clear message rather than a TypeError from spreading `undefined`.
    if (!Number.isInteger(pos) || pos < 0) {
      throw new Error(
        `Position ${pos} is not a valid index; expected an integer in [0, maxSeqLen=${this.maxSeqLen}).`
      );
    }
    return [...this.weights[pos]];
  }
  // Add learned positional encodings to `embeddings` (returns a new matrix).
  // Shape: [seqLen][dModel] → [seqLen][dModel].
  //
  // `seqLen` is optional and defaults to embeddings.length. Both the declared
  // length and the actual number of rows must fit in the table, because every
  // row of `embeddings` is looked up by its index.
  apply(embeddings, seqLen) {
    const len = Math.max(seqLen ?? embeddings.length, embeddings.length);
    if (len > this.maxSeqLen) {
      throw new Error(
        `Sequence length ${len} exceeds maxSeqLen=${this.maxSeqLen}.`
      );
    }
    return embeddings.map(
      (emb, pos) => emb.map((val, d) => val + this.weights[pos][d])
    );
  }
  // Apply an update to the position encoding weights, in place:
  //   weights[pos][d] += lr * dWeights[pos][d]
  //
  // `dWeights` has the same shape as `weights`: [maxSeqLen][dModel].
  //
  // NOTE(review): the original comment describes dWeights as dL/dW (the loss
  // gradient) and the step as "simple SGD", yet the step is ADDED. That is
  // gradient descent only if callers pass the negated gradient — confirm
  // against the caller and against EmbeddingMatrix in MatMul.ts before
  // changing the sign.
  update(dWeights, lr) {
    for (let pos = 0; pos < this.maxSeqLen; pos++) {
      for (let d = 0; d < this.dModel; d++) {
        this.weights[pos][d] += lr * dWeights[pos][d];
      }
    }
  }
};
|
|
5263
|
+
|
|
5264
|
+
// src/ContrastiveLearning.ts
|
|
5265
|
+
var Augmenter = class _Augmenter {
  // Returns a copy of `x` with zero-mean Gaussian noise of standard deviation
  // `sigma` added to every feature.
  //
  // Each sample comes from the Box-Muller transform of two uniform draws:
  //   z = √(-2·ln(u₁)) · cos(2π·u₂)
  // u₁ is floored at 1e-10 so the logarithm stays finite.
  static addNoise(x, sigma = 0.05) {
    const noisy = x.map((value) => {
      const uniformA = Math.max(1e-10, Math.random());
      const uniformB = Math.random();
      const radius = Math.sqrt(-2 * Math.log(uniformA));
      const gaussian = radius * Math.cos(2 * Math.PI * uniformB);
      return value + sigma * gaussian;
    });
    return noisy;
  }
  // Returns a copy of `x` in which each feature is independently replaced by 0
  // with probability `rate`.
  //
  // Comparable to masking: the encoder has to cope with missing features
  // instead of leaning on any single dimension.
  static dropoutFeatures(x, rate = 0.1) {
    return x.map((value) => {
      if (Math.random() < rate) {
        return 0;
      }
      return value;
    });
  }
  // Noise first, then feature dropout on the noisy vector.
  //
  // Stacking augmentations makes the two views of a sample harder to match,
  // which pushes the encoder toward more robust representations.
  static augment(x, noiseStd = 0.05, dropRate = 0.1) {
    const perturbed = _Augmenter.addNoise(x, noiseStd);
    return _Augmenter.dropoutFeatures(perturbed, dropRate);
  }
  // Builds a positive pair [original, augmented_copy] using the default
  // augmentation settings. The first element is `x` itself, not a copy.
  //
  // The two views form the (i, j) positive pair for NT-Xent; the rest of the
  // batch serves as negatives.
  static makePair(x) {
    const view = _Augmenter.augment(x);
    return [x, view];
  }
};
|
|
5304
|
+
var ContrastiveLearning = class _ContrastiveLearning {
|
|
5305
|
+
// encoderHidden: hidden layer sizes for the encoder (not counting input/output).
|
|
5306
|
+
// e.g. inputSize=64, encoderHidden=[256, 128] → NetworkN([64, 256, 128])
|
|
5307
|
+
// The encoder output dimension is encoderHidden[last].
|
|
5308
|
+
//
|
|
5309
|
+
// projectionDim: dimension of the projection head output (the z space).
|
|
5310
|
+
// e.g. 64. Typically smaller than the encoder's output.
|
|
5311
|
+
//
|
|
5312
|
+
// The encoder uses ReLU activations throughout — empirically stronger than
|
|
5313
|
+
// sigmoid for representation learning because it doesn't saturate.
|
|
5314
|
+
constructor(inputSize, encoderHidden, projectionDim, options = {}) {
|
|
5315
|
+
if (encoderHidden.length === 0) {
|
|
5316
|
+
throw new Error("encoderHidden must have at least one element.");
|
|
5317
|
+
}
|
|
5318
|
+
this.temperature = options.temperature ?? 0.5;
|
|
5319
|
+
const encoderStructure = [inputSize, ...encoderHidden];
|
|
5320
|
+
const encoderActivations = encoderHidden.map(() => relu);
|
|
5321
|
+
this.encoder = new NetworkN(encoderStructure, {
|
|
5322
|
+
activations: encoderActivations,
|
|
5323
|
+
...options.encoderOptions
|
|
5324
|
+
});
|
|
5325
|
+
const encoderOut = encoderHidden[encoderHidden.length - 1];
|
|
5326
|
+
const projHidden = Math.max(projectionDim, Math.floor(encoderOut / 2));
|
|
5327
|
+
this.projectionHead = new NetworkN(
|
|
5328
|
+
[encoderOut, projHidden, projectionDim],
|
|
5329
|
+
{ activations: [relu, relu] }
|
|
5330
|
+
);
|
|
5331
|
+
}
|
|
5332
|
+
// ── Inference (downstream tasks use this, not project()) ─────────────────
|
|
5333
|
+
//
|
|
5334
|
+
// Returns h — the encoder representation before the projection head.
|
|
5335
|
+
// This is the vector to use for classification, clustering, retrieval, etc.
|
|
5336
|
+
//
|
|
5337
|
+
// The projection head is only active during training.
|
|
5338
|
+
encode(x) {
|
|
5339
|
+
return this.encoder.predict(x);
|
|
5340
|
+
}
|
|
5341
|
+
// ── Training path: encode then project ───────────────────────────────────
|
|
5342
|
+
//
|
|
5343
|
+
// Returns z — the projected representation used to compute NT-Xent.
|
|
5344
|
+
// Do NOT use this for downstream tasks (see encode() above).
|
|
5345
|
+
project(x) {
|
|
5346
|
+
const h = this.encoder.predict(x);
|
|
5347
|
+
return this.projectionHead.predict(h);
|
|
5348
|
+
}
|
|
5349
|
+
// ── Cosine similarity ─────────────────────────────────────────────────────
|
|
5350
|
+
//
|
|
5351
|
+
// sim(u, v) = uᵀv / (||u|| · ||v||)
|
|
5352
|
+
//
|
|
5353
|
+
// Range: [-1, 1]. We use cosine rather than Euclidean distance because it is
|
|
5354
|
+
// scale-invariant — only the direction of the projection matters, not its
|
|
5355
|
+
// magnitude. This prevents the trivial solution of making ||z|| → ∞.
|
|
5356
|
+
static cosineSimilarity(a, b) {
|
|
5357
|
+
let dot = 0, normA = 0, normB = 0;
|
|
5358
|
+
for (let d = 0; d < a.length; d++) {
|
|
5359
|
+
dot += a[d] * b[d];
|
|
5360
|
+
normA += a[d] * a[d];
|
|
5361
|
+
normB += b[d] * b[d];
|
|
5362
|
+
}
|
|
5363
|
+
const denom = Math.sqrt(normA) * Math.sqrt(normB);
|
|
5364
|
+
return denom < 1e-10 ? 0 : dot / denom;
|
|
5365
|
+
}
|
|
5366
|
+
// ── NT-Xent loss (no weight update) ──────────────────────────────────────
|
|
5367
|
+
//
|
|
5368
|
+
// Forward-only pass. Used for validation / monitoring during training.
|
|
5369
|
+
computeLoss(pairs) {
|
|
5370
|
+
const { projections, N } = this._forwardProjections(pairs);
|
|
5371
|
+
return this._ntXentLoss(projections, N);
|
|
5372
|
+
}
|
|
5373
|
+
// ── Training step ─────────────────────────────────────────────────────────
|
|
5374
|
+
//
|
|
5375
|
+
// Given a batch of positive pairs, compute NT-Xent loss and update weights
|
|
5376
|
+
// via finite-difference gradient approximation.
|
|
5377
|
+
//
|
|
5378
|
+
// Full analytical backprop through NT-Xent is complex to implement from
|
|
5379
|
+
// scratch without an autograd engine. Finite differences are slower but
|
|
5380
|
+
// correct and keep the implementation readable for educational purposes.
|
|
5381
|
+
// For production use, couple this with the Tape (autograd) module.
|
|
5382
|
+
//
|
|
5383
|
+
// Step-by-step:
|
|
5384
|
+
// 1. Forward all 2N inputs through encoder + projection head → { z_i }.
|
|
5385
|
+
// 2. Build the 2N×2N cosine similarity matrix (scaled by 1/τ).
|
|
5386
|
+
// 3. For each anchor i, identify its positive pair and all 2N-2 negatives.
|
|
5387
|
+
// 4. Apply softmax over the row; loss = -log(softmax at positive index).
|
|
5388
|
+
// 5. Average over all 2N anchors.
|
|
5389
|
+
// 6. Approximate ∂L/∂w per weight with finite differences and apply update.
|
|
5390
|
+
//
|
|
5391
|
+
// Returns: NT-Xent loss before the weight update.
|
|
5392
|
+
trainStep(pairs, lr) {
|
|
5393
|
+
const loss = this.computeLoss(pairs);
|
|
5394
|
+
const eps = 1e-4;
|
|
5395
|
+
for (const layer of this.encoder.layers) {
|
|
5396
|
+
for (const neuron of layer.neurons) {
|
|
5397
|
+
for (let j = 0; j < neuron.weights.length; j++) {
|
|
5398
|
+
neuron.weights[j] += eps;
|
|
5399
|
+
const lossPlus2 = this.computeLoss(pairs);
|
|
5400
|
+
neuron.weights[j] -= 2 * eps;
|
|
5401
|
+
const lossMinus2 = this.computeLoss(pairs);
|
|
5402
|
+
neuron.weights[j] += eps;
|
|
5403
|
+
const grad2 = (lossPlus2 - lossMinus2) / (2 * eps);
|
|
5404
|
+
neuron.weights[j] += lr * -grad2;
|
|
5405
|
+
}
|
|
5406
|
+
neuron.bias += eps;
|
|
5407
|
+
const lossPlus = this.computeLoss(pairs);
|
|
5408
|
+
neuron.bias -= 2 * eps;
|
|
5409
|
+
const lossMinus = this.computeLoss(pairs);
|
|
5410
|
+
neuron.bias += eps;
|
|
5411
|
+
const grad = (lossPlus - lossMinus) / (2 * eps);
|
|
5412
|
+
neuron.bias += lr * -grad;
|
|
5413
|
+
}
|
|
5414
|
+
}
|
|
5415
|
+
for (const layer of this.projectionHead.layers) {
|
|
5416
|
+
for (const neuron of layer.neurons) {
|
|
5417
|
+
for (let j = 0; j < neuron.weights.length; j++) {
|
|
5418
|
+
neuron.weights[j] += eps;
|
|
5419
|
+
const lossPlus2 = this.computeLoss(pairs);
|
|
5420
|
+
neuron.weights[j] -= 2 * eps;
|
|
5421
|
+
const lossMinus2 = this.computeLoss(pairs);
|
|
5422
|
+
neuron.weights[j] += eps;
|
|
5423
|
+
const grad2 = (lossPlus2 - lossMinus2) / (2 * eps);
|
|
5424
|
+
neuron.weights[j] += lr * -grad2;
|
|
5425
|
+
}
|
|
5426
|
+
neuron.bias += eps;
|
|
5427
|
+
const lossPlus = this.computeLoss(pairs);
|
|
5428
|
+
neuron.bias -= 2 * eps;
|
|
5429
|
+
const lossMinus = this.computeLoss(pairs);
|
|
5430
|
+
neuron.bias += eps;
|
|
5431
|
+
const grad = (lossPlus - lossMinus) / (2 * eps);
|
|
5432
|
+
neuron.bias += lr * -grad;
|
|
5433
|
+
}
|
|
5434
|
+
}
|
|
5435
|
+
return loss;
|
|
5436
|
+
}
|
|
5437
|
+
// ── Private: forward all pairs through the projection head ───────────────
|
|
5438
|
+
//
|
|
5439
|
+
// Returns a flat array of 2N projections.
|
|
5440
|
+
// Layout: [ z_0, z_0', z_1, z_1', ..., z_{N-1}, z_{N-1}' ]
|
|
5441
|
+
// Even indices 2i → original view of pair i
|
|
5442
|
+
// Odd indices 2i+1 → augmented view of pair i (the positive)
|
|
5443
|
+
_forwardProjections(pairs) {
|
|
5444
|
+
const N = pairs.length;
|
|
5445
|
+
const projections = [];
|
|
5446
|
+
for (const [x, xAug] of pairs) {
|
|
5447
|
+
projections.push(this.project(x));
|
|
5448
|
+
projections.push(this.project(xAug));
|
|
5449
|
+
}
|
|
5450
|
+
return { projections, N };
|
|
5451
|
+
}
|
|
5452
|
+
// ── Private: NT-Xent loss over a set of 2N projections ───────────────────
|
|
5453
|
+
//
|
|
5454
|
+
// pairs[2i] and pairs[2i+1] are positives.
|
|
5455
|
+
// All other 2N-2 samples are negatives for each anchor.
|
|
5456
|
+
_ntXentLoss(projections, N) {
|
|
5457
|
+
const total = 2 * N;
|
|
5458
|
+
const tau = this.temperature;
|
|
5459
|
+
const sim = Array.from(
|
|
5460
|
+
{ length: total },
|
|
5461
|
+
(_, i) => Array.from(
|
|
5462
|
+
{ length: total },
|
|
5463
|
+
(_2, j) => _ContrastiveLearning.cosineSimilarity(projections[i], projections[j]) / tau
|
|
5464
|
+
)
|
|
5465
|
+
);
|
|
5466
|
+
let totalLoss = 0;
|
|
5467
|
+
for (let i = 0; i < total; i++) {
|
|
5468
|
+
const posIdx = i % 2 === 0 ? i + 1 : i - 1;
|
|
5469
|
+
const numerator = Math.exp(sim[i][posIdx]);
|
|
5470
|
+
let denominator = 0;
|
|
5471
|
+
for (let k = 0; k < total; k++) {
|
|
5472
|
+
if (k !== i) {
|
|
5473
|
+
denominator += Math.exp(sim[i][k]);
|
|
5474
|
+
}
|
|
5475
|
+
}
|
|
5476
|
+
totalLoss += -Math.log(numerator / (denominator + 1e-10));
|
|
5477
|
+
}
|
|
5478
|
+
return totalLoss / total;
|
|
5479
|
+
}
|
|
5480
|
+
};
|
|
5481
|
+
|
|
5482
|
+
// src/GAN.ts
|
|
5483
|
+
var GAN = class {
  // Builds a generator (latentDim -> ...generatorHidden -> outputDim) and a
  // discriminator (outputDim -> ...discriminatorHidden -> 1) as NetworkN
  // instances. options.generatorOptions / options.discriminatorOptions are
  // forwarded to the respective NetworkN constructor (default: {}).
  constructor(latentDim, generatorHidden, outputDim, discriminatorHidden, options) {
    this.latentDim = latentDim;
    const gStructure = [latentDim, ...generatorHidden, outputDim];
    this.generator = new NetworkN(gStructure, options?.generatorOptions ?? {});
    const dStructure = [outputDim, ...discriminatorHidden, 1];
    this.discriminator = new NetworkN(dStructure, options?.discriminatorOptions ?? {});
  }
  // ── Public API ───────────────────────────────────────────────────────────
  // Generates a synthetic sample. When z is null/undefined a latent vector is
  // drawn with sampleLatent().
  generate(z) {
    const latent = z ?? this.sampleLatent();
    return this.generator.predict(latent);
  }
  // Returns the first (only) output of the discriminator for x.
  discriminate(x) {
    return this.discriminator.predict(x)[0];
  }
  // ── Training Step ────────────────────────────────────────────────────────
  //
  // For every real sample: one discriminator update on the real sample, one
  // discriminator update on a fresh fake, then one generator update on a
  // second fresh fake. Returns the mean per-sample losses:
  //
  //   dLoss = mean( -log D(x_real) - log(1 - D(G(z))) )
  //   gLoss = mean( -log D(G(z)) )            (non-saturating variant)
  //
  // An empty batch performs no updates and returns zero losses instead of
  // the NaN that 0 / 0 would produce.
  trainStep(realBatch, lr) {
    const n = realBatch.length;
    if (n === 0) {
      return { dLoss: 0, gLoss: 0 };
    }
    const eps = 1e-15;
    // Discriminator score clamped away from 0 and 1 so the logs stay finite.
    const clampedScore = (x) => Math.max(eps, Math.min(1 - eps, this.discriminate(x)));
    let dLossSum = 0;
    let gLossSum = 0;
    for (const xReal of realBatch) {
      // Discriminator on a real sample: target 1.
      const dReal = clampedScore(xReal);
      this.discriminator.trainWithDeltas(xReal, [1 - dReal], lr);
      dLossSum += -Math.log(dReal);

      // Discriminator on a generated sample: target 0.
      const xFake = this.generate(this.sampleLatent());
      const dFake = clampedScore(xFake);
      this.discriminator.trainWithDeltas(xFake, [0 - dFake], lr);
      dLossSum += -Math.log(1 - dFake);

      // Generator: the scalar error (1 - D(G(z))) is spread evenly over the
      // generator outputs and handed to trainWithDeltas.
      const z2 = this.sampleLatent();
      const xFake2 = this.generate(z2);
      const dScore = clampedScore(xFake2);
      const gError = 1 - dScore;
      const gDelta = xFake2.map(() => gError / xFake2.length);
      this.generator.trainWithDeltas(z2, gDelta, lr);
      gLossSum += -Math.log(dScore);
    }
    return {
      dLoss: dLossSum / n,
      gLoss: gLossSum / n
    };
  }
  // Samples z ~ N(0, 1)^latentDim with the Box-Muller transform. Each pair of
  // uniforms yields two normals; the second is dropped when latentDim is odd.
  sampleLatent() {
    const z = [];
    for (let i = 0; i < this.latentDim; i += 2) {
      const u1 = Math.random();
      const u2 = Math.random();
      // 1e-15 keeps log() finite when Math.random() returns exactly 0.
      const r = Math.sqrt(-2 * Math.log(u1 + 1e-15));
      const theta = 2 * Math.PI * u2;
      z.push(r * Math.cos(theta));
      if (i + 1 < this.latentDim) z.push(r * Math.sin(theta));
    }
    return z;
  }
};
|
|
5556
|
+
|
|
5557
|
+
// src/VAE.ts
|
|
5558
|
+
var VAE = class {
  // Encoder: inputSize -> ...encoderHidden -> 2 * latentDim (μ then logVar).
  // Decoder: latentDim -> ...decoderHidden -> inputSize.
  // The same options value (default {}) is passed to both NetworkN instances.
  constructor(inputSize, encoderHidden, latentDim, decoderHidden, options) {
    this.latentDim = latentDim;
    this.encoder = new NetworkN([inputSize, ...encoderHidden, latentDim * 2], options ?? {});
    this.decoder = new NetworkN([latentDim, ...decoderHidden, inputSize], options ?? {});
  }
  // ── Encode ───────────────────────────────────────────────────────────────
  // First latentDim encoder outputs are μ, the remainder is logVar.
  encode(x) {
    const encoded = this.encoder.predict(x);
    return {
      mu: encoded.slice(0, this.latentDim),
      logVar: encoded.slice(this.latentDim)
    };
  }
  // ── Reparametrisation Trick ──────────────────────────────────────────────
  // z = μ + σ·ε with ε ~ N(0,1) and σ = exp(0.5 · logVar).
  reparametrize(mu, logVar) {
    return mu.map((mean, idx) => {
      const stdDev = Math.exp(0.5 * logVar[idx]);
      const noise = this._sampleNormal();
      return mean + stdDev * noise;
    });
  }
  // ── Decode ───────────────────────────────────────────────────────────────
  decode(z) {
    return this.decoder.predict(z);
  }
  // ── Forward Pass ─────────────────────────────────────────────────────────
  // Encode, sample z, decode; returns every intermediate value.
  forward(x) {
    const { mu, logVar } = this.encode(x);
    const z = this.reparametrize(mu, logVar);
    return { reconstruction: this.decode(z), mu, logVar, z };
  }
  // ── Training Step ────────────────────────────────────────────────────────
  //
  // One forward pass followed by one trainWithDeltas call on the decoder and
  // one on the encoder.
  //
  //   reconLoss = mean squared error between x and the reconstruction
  //   klLoss    = -½ Σ(1 + logVarᵢ - μᵢ² - exp(logVarᵢ))
  //   totalLoss = reconLoss + klLoss
  //
  // The decoder receives (xᵢ - x̂ᵢ) / len(x); the encoder receives the
  // negated KL derivatives for μ and logVar.
  train(x, lr) {
    const { reconstruction, mu, logVar, z } = this.forward(x);

    let squaredError = 0;
    for (let i = 0; i < x.length; i++) {
      squaredError += (x[i] - reconstruction[i]) ** 2;
    }
    const reconLoss = squaredError / x.length;

    let klLoss = 0;
    for (let i = 0; i < mu.length; i++) {
      klLoss -= 0.5 * (1 + logVar[i] - mu[i] * mu[i] - Math.exp(logVar[i]));
    }

    const totalLoss = reconLoss + klLoss;

    const decoderDeltas = reconstruction.map((recon, i) => (x[i] - recon) / x.length);
    this.decoder.trainWithDeltas(z, decoderDeltas, lr);

    const muDeltas = mu.map((mean) => -mean);
    const logVarDeltas = logVar.map((lv) => -0.5 * (Math.exp(lv) - 1));
    this.encoder.trainWithDeltas(x, [...muDeltas, ...logVarDeltas], lr);

    return { totalLoss, reconLoss, klLoss };
  }
  // ── Generate ─────────────────────────────────────────────────────────────
  // Decodes z; when z is null/undefined a standard-normal latent is sampled.
  generate(z) {
    const latent = z ?? Array.from({ length: this.latentDim }, () => this._sampleNormal());
    return this.decode(latent);
  }
  // ── Private ──────────────────────────────────────────────────────────────
  // Box-Muller transform: one N(0, 1) sample from two uniforms. The 1e-15
  // offset keeps log() finite when the first uniform is exactly 0.
  _sampleNormal() {
    const u1 = Math.random();
    const u2 = Math.random();
    const radius = Math.sqrt(-2 * Math.log(u1 + 1e-15));
    return radius * Math.cos(2 * Math.PI * u2);
  }
};
|
|
5635
|
+
|
|
5636
|
+
// src/Tape.ts
|
|
5637
|
+
var Value = class _Value {
  // A scalar node in a computational graph.
  //   data     – the numeric value
  //   grad     – accumulated ∂L/∂(this), filled in by backward()
  //   _prev    – Set of the nodes this one was computed from
  //   _op      – label of the operation that produced this node
  //   _backward – closure that pushes this.grad into the nodes in _prev
  constructor(data, children = [], op = "") {
    // Leaf nodes have nothing to propagate.
    this._backward = () => {
    };
    this.data = data;
    this.grad = 0;
    this._prev = new Set(children);
    this._op = op;
  }
  // ── Arithmetic Operations ────────────────────────────────────────────────
  // Every operation accepts either a Value or a plain number for `other`.
  // z = a + b → ∂z/∂a = 1, ∂z/∂b = 1
  add(other) {
    const o = other instanceof _Value ? other : new _Value(other);
    const out = new _Value(this.data + o.data, [this, o], "+");
    out._backward = () => {
      this.grad += out.grad;
      o.grad += out.grad;
    };
    return out;
  }
  // z = a * b → ∂z/∂a = b, ∂z/∂b = a
  mul(other) {
    const o = other instanceof _Value ? other : new _Value(other);
    const out = new _Value(this.data * o.data, [this, o], "*");
    out._backward = () => {
      this.grad += o.data * out.grad;
      o.grad += this.data * out.grad;
    };
    return out;
  }
  // z = aⁿ → ∂z/∂a = n·aⁿ⁻¹   (exp is a plain number, not a Value)
  pow(exp) {
    const out = new _Value(Math.pow(this.data, exp), [this], `**${exp}`);
    out._backward = () => {
      this.grad += exp * Math.pow(this.data, exp - 1) * out.grad;
    };
    return out;
  }
  // z = max(0, a) → ∂z/∂a = a > 0 ? 1 : 0
  relu() {
    const out = new _Value(Math.max(0, this.data), [this], "ReLU");
    out._backward = () => {
      this.grad += (out.data > 0 ? 1 : 0) * out.grad;
    };
    return out;
  }
  // z = tanh(a) → ∂z/∂a = 1 - z²
  tanh() {
    const t = Math.tanh(this.data);
    const out = new _Value(t, [this], "tanh");
    out._backward = () => {
      this.grad += (1 - t * t) * out.grad;
    };
    return out;
  }
  // z = σ(a) = 1/(1+e⁻ᵃ) → ∂z/∂a = z·(1-z)
  sigmoid() {
    const s = 1 / (1 + Math.exp(-this.data));
    const out = new _Value(s, [this], "sigmoid");
    out._backward = () => {
      this.grad += s * (1 - s) * out.grad;
    };
    return out;
  }
  // z = eᵃ → ∂z/∂a = z
  exp() {
    const e = Math.exp(this.data);
    const out = new _Value(e, [this], "exp");
    out._backward = () => {
      this.grad += e * out.grad;
    };
    return out;
  }
  // ── Derived Operations (built from primitives) ───────────────────────────
  // a / b = a * b⁻¹
  div(other) {
    const o = other instanceof _Value ? other : new _Value(other);
    return this.mul(o.pow(-1));
  }
  // a - b = a + (b * -1)
  sub(other) {
    const o = other instanceof _Value ? other : new _Value(other);
    return this.add(o.mul(-1));
  }
  // -a = a * -1
  neg() {
    return this.mul(-1);
  }
  // ── Backward Pass ────────────────────────────────────────────────────────
  //
  // Treats this node as the loss L and propagates gradients to every
  // ancestor:
  //   1. Build a topological (post-order) listing of all ancestor nodes.
  //   2. Set this.grad = 1 (∂L/∂L = 1).
  //   3. Walk the listing in reverse, calling each node's _backward.
  //
  // Gradients accumulate with +=; grads are not reset here.
  //
  // The depth-first traversal uses an explicit stack instead of recursion so
  // that long operation chains (e.g. a loss summed over many terms) cannot
  // exhaust the JavaScript call stack. The visiting order is the same as the
  // recursive formulation: a node is marked on entry, its parents are
  // explored in Set order, and it is emitted once all of them are done.
  backward() {
    const topo = [];
    const visited = /* @__PURE__ */ new Set();
    visited.add(this);
    const stack = [{ node: this, parents: this._prev.values() }];
    while (stack.length > 0) {
      const frame = stack[stack.length - 1];
      const step = frame.parents.next();
      if (step.done) {
        stack.pop();
        topo.push(frame.node);
      } else if (!visited.has(step.value)) {
        visited.add(step.value);
        stack.push({ node: step.value, parents: step.value._prev.values() });
      }
    }
    this.grad = 1;
    for (let i = topo.length - 1; i >= 0; i--) {
      topo[i]._backward();
    }
  }
  // Human-readable summary with data and grad rounded to four decimals.
  toString() {
    return `Value(data=${this.data.toFixed(4)}, grad=${this.grad.toFixed(4)}, op='${this._op}')`;
  }
};
|
|
5756
|
+
|
|
5757
|
+
// src/WeightInspector.ts
|
|
5758
|
+
var WeightInspector = class _WeightInspector {
  // ── Per-layer statistics ─────────────────────────────────────────────────
  // One stats object per entry of network.layers, in order. Each layer's
  // parameter list is every neuron's weights followed by that neuron's bias.
  static inspect(network, deadThreshold = 1e-3) {
    return network.layers.map((layer) => {
      const params = layer.neurons.flatMap((neuron) => [...neuron.weights, neuron.bias]);
      return _computeStats(params, deadThreshold);
    });
  }
  // ── Global statistics ────────────────────────────────────────────────────
  // Same statistics computed over the parameters of every layer at once.
  static inspectAll(network, deadThreshold = 1e-3) {
    const params = network.layers.flatMap(
      (layer) => layer.neurons.flatMap((neuron) => [...neuron.weights, neuron.bias])
    );
    return _computeStats(params, deadThreshold);
  }
  // ── Formatted table ──────────────────────────────────────────────────────
  // Logs a table to the console: one row per layer plus a "Global" row.
  static print(network, deadThreshold = 1e-3) {
    const layerStats = _WeightInspector.inspect(network, deadThreshold);
    const overall = _WeightInspector.inspectAll(network, deadThreshold);
    const columns = [
      "Layer".padEnd(8),
      "mean".padStart(9),
      "std".padStart(9),
      "min".padStart(9),
      "max".padStart(9),
      "dead".padStart(11),
      "params".padStart(8)
    ];
    const header = columns.join(" ");
    const rule = "\u2500".repeat(header.length);
    console.log("");
    console.log("Weight Inspector:");
    console.log(rule);
    console.log(header);
    console.log(rule);
    for (let i = 0; i < layerStats.length; i++) {
      console.log(_formatRow(`Layer ${i}`, layerStats[i]));
    }
    console.log(rule);
    console.log(_formatRow("Global", overall));
    console.log("");
  }
  // ── Dead ReLU detection ──────────────────────────────────────────────────
  //
  // activations: rows = samples, columns = neurons. Counts the columns whose
  // absolute value is below `threshold` (default 1e-6) in every row. The
  // column count is taken from the first row; an empty matrix yields 0.
  static countDeadReLUs(activations, threshold = 1e-6) {
    if (activations.length === 0) return 0;
    const width = activations[0].length;
    let deadColumns = 0;
    for (let col = 0; col < width; col++) {
      // A single sample at or above the threshold keeps the neuron alive.
      const isAlive = activations.some((sample) => !(Math.abs(sample[col]) < threshold));
      if (!isAlive) deadColumns += 1;
    }
    return deadColumns;
  }
};
|
|
5829
|
+
// Summary statistics for a flat list of parameters.
//
// Returns { mean, std, min, max, deadCount, totalParams } where
//   std       – population standard deviation (divides by n)
//   deadCount – number of values with |w| < deadThreshold
// An empty list yields all zeros.
//
// The variance is computed in a second pass as the mean squared deviation
// from the mean. The one-pass form E[w²] - mean² subtracts two nearly equal
// numbers when the spread is small relative to the magnitude, which can lose
// every significant digit.
function _computeStats(weights, deadThreshold) {
  const n = weights.length;
  if (n === 0) {
    return { mean: 0, std: 0, min: 0, max: 0, deadCount: 0, totalParams: 0 };
  }
  let sum = 0;
  let min = Infinity;
  let max = -Infinity;
  let deadCount = 0;
  for (const w of weights) {
    sum += w;
    if (w < min) min = w;
    if (w > max) max = w;
    if (Math.abs(w) < deadThreshold) deadCount++;
  }
  const mean = sum / n;
  let squaredDeviation = 0;
  for (const w of weights) {
    const deviation = w - mean;
    squaredDeviation += deviation * deviation;
  }
  const std = Math.sqrt(squaredDeviation / n);
  return { mean, std, min, max, deadCount, totalParams: n };
}
|
|
5847
|
+
// Formats n with four decimals; non-negative values get a leading space so
// they line up with negative ones, whose minus sign occupies that column.
function _fmt(n) {
  const signSlot = n >= 0 ? " " : "";
  return `${signSlot}${n.toFixed(4)}`;
}
|
|
5850
|
+
// Renders one table row for the weight inspector: label, mean, std, min,
// max, "dead/total" and the parameter count, each padded to its column
// width and joined by single spaces.
function _formatRow(label, s) {
  const numericCells = [s.mean, s.std, s.min, s.max].map((value) => _fmt(value).padStart(9));
  const deadCell = `${s.deadCount}/${s.totalParams}`.padStart(11);
  const countCell = String(s.totalParams).padStart(8);
  return [label.padEnd(8), ...numericCells, deadCell, countCell].join(" ");
}
|
|
5862
|
+
|
|
5863
|
+
// src/Metrics.ts
|
|
5864
|
+
// Builds a K×K confusion matrix where matrix[actual][predicted] is a count.
//
// yTrue / yPred – class indices (non-negative integers), compared pairwise
// numClasses    – K; when null/undefined it is inferred as max label + 1
//
// The maximum label is found with a loop rather than Math.max(...spread):
// spreading passes every element as a call argument and throws a RangeError
// once the arrays exceed the engine's argument limit.
function confusionMatrix(yTrue, yPred, numClasses) {
  let K = numClasses;
  if (K === null || K === undefined) {
    let maxLabel = -1;
    for (const label of yTrue) {
      if (label > maxLabel) maxLabel = label;
    }
    for (const label of yPred) {
      if (label > maxLabel) maxLabel = label;
    }
    K = maxLabel + 1; // empty input → 0 → empty matrix
  }
  const matrix = Array.from({ length: K }, () => new Array(K).fill(0));
  for (let i = 0; i < yTrue.length; i++) {
    matrix[yTrue[i]][yPred[i]]++;
  }
  return matrix;
}
|
|
5872
|
+
// Precision = TP / (TP + FP).
//
// With positiveClass given: precision of that single class.
// Without it: macro average over classes 0..K-1, where K is the largest
// label seen in yTrue or yPred plus one. Empty input returns 0.
//
// The largest label is found with a loop; Math.max(...spread) throws a
// RangeError on arrays longer than the engine's argument limit.
function precision(yTrue, yPred, positiveClass) {
  if (positiveClass !== void 0) {
    return _binaryPrecision(yTrue, yPred, positiveClass);
  }
  let maxLabel = -1;
  for (const label of yTrue) {
    if (label > maxLabel) maxLabel = label;
  }
  for (const label of yPred) {
    if (label > maxLabel) maxLabel = label;
  }
  const K = maxLabel + 1;
  if (K <= 0) return 0;
  let sum = 0;
  for (let c = 0; c < K; c++) {
    sum += _binaryPrecision(yTrue, yPred, c);
  }
  return sum / K;
}
|
|
5881
|
+
// Recall = TP / (TP + FN).
//
// With positiveClass given: recall of that single class.
// Without it: macro average over classes 0..K-1, where K is the largest
// label seen in yTrue or yPred plus one. Empty input returns 0.
//
// The largest label is found with a loop; Math.max(...spread) throws a
// RangeError on arrays longer than the engine's argument limit.
function recall(yTrue, yPred, positiveClass) {
  if (positiveClass !== void 0) {
    return _binaryRecall(yTrue, yPred, positiveClass);
  }
  let maxLabel = -1;
  for (const label of yTrue) {
    if (label > maxLabel) maxLabel = label;
  }
  for (const label of yPred) {
    if (label > maxLabel) maxLabel = label;
  }
  const K = maxLabel + 1;
  if (K <= 0) return 0;
  let sum = 0;
  for (let c = 0; c < K; c++) {
    sum += _binaryRecall(yTrue, yPred, c);
  }
  return sum / K;
}
|
|
5890
|
+
// F1 = harmonic mean of precision and recall, both evaluated with the same
// positiveClass argument (single class when given, macro average otherwise).
// Returns 0 when precision and recall are both 0.
function f1Score(yTrue, yPred, positiveClass) {
  const p = precision(yTrue, yPred, positiveClass);
  const r = recall(yTrue, yPred, positiveClass);
  const denominator = p + r;
  return denominator === 0 ? 0 : 2 * p * r / denominator;
}
|
|
5896
|
+
// Fraction of positions where yTrue[i] === yPred[i]; 0 for empty input.
function accuracy(yTrue, yPred) {
  const total = yTrue.length;
  if (total === 0) return 0;
  let matches = 0;
  for (let i = 0; i < total; i++) {
    if (yTrue[i] === yPred[i]) matches += 1;
  }
  return matches / total;
}
|
|
5901
|
+
// ROC curve for binary labels.
//
// yTrue   – labels; 1 counts as positive, 0 as a false-positive candidate,
//           and every non-1 label contributes to the negative total N
// yScores – one score per label (expected to have the same length as yTrue)
//
// Returns [{ threshold, fpr, tpr }, ...]: first the point for
// "max score + 1" (nothing predicted positive), then one point per distinct
// score in descending order, where a sample is predicted positive when its
// score >= threshold. fpr / tpr are 0 when there are no negatives/positives.
//
// The samples are sorted once and swept with running TP/FP counters, so the
// cost is O(n log n) instead of rescanning every sample for every distinct
// threshold. Because the counters only grow, the points are already in
// non-decreasing fpr order.
function rocCurve(yTrue, yScores) {
  const n = yTrue.length;
  let P = 0;
  for (const label of yTrue) {
    if (label === 1) P++;
  }
  const N = n - P;
  const order = Array.from({ length: n }, (_, i) => i).sort(
    (a, b) => yScores[b] - yScores[a]
  );
  // With no samples there is no maximum score; the sentinel threshold is NaN.
  const sentinel = n > 0 ? yScores[order[0]] + 1 : NaN;
  const points = [{ threshold: sentinel, fpr: 0, tpr: 0 }];
  let tp = 0;
  let fp = 0;
  for (let i = 0; i < n; ) {
    const threshold = yScores[order[i]];
    // Consume every sample tied at this score before emitting the point;
    // do/while guarantees progress even for a NaN score.
    do {
      const label = yTrue[order[i]];
      if (label === 1) tp++;
      else if (label === 0) fp++;
      i++;
    } while (i < n && yScores[order[i]] === threshold);
    points.push({
      threshold,
      fpr: N > 0 ? fp / N : 0,
      tpr: P > 0 ? tp / P : 0
    });
  }
  return points;
}
|
|
5922
|
+
// Area under the ROC curve returned by rocCurve(), integrated with the
// trapezoidal rule over consecutive points. The absolute value of the
// accumulated area is returned.
function auc(yTrue, yScores) {
  const points = rocCurve(yTrue, yScores);
  let area = 0;
  let previous = points[0];
  for (const current of points.slice(1)) {
    const width = current.fpr - previous.fpr;
    const meanHeight = (current.tpr + previous.tpr) / 2;
    area += width * meanHeight;
    previous = current;
  }
  return Math.abs(area);
}
|
|
5932
|
+
// Mean absolute error: average of |yTrue[i] - yPred[i]| over yTrue.
function mae(yTrue, yPred) {
  let absoluteErrorSum = 0;
  for (let i = 0; i < yTrue.length; i++) {
    absoluteErrorSum += Math.abs(yTrue[i] - yPred[i]);
  }
  return absoluteErrorSum / yTrue.length;
}
|
|
5935
|
+
// Root mean squared error: sqrt of the average of (yTrue[i] - yPred[i])².
function rmse(yTrue, yPred) {
  let squaredErrorSum = 0;
  for (let i = 0; i < yTrue.length; i++) {
    squaredErrorSum += (yTrue[i] - yPred[i]) ** 2;
  }
  const meanSquaredError = squaredErrorSum / yTrue.length;
  return Math.sqrt(meanSquaredError);
}
|
|
5939
|
+
// Coefficient of determination: R² = 1 - SS_res / SS_tot.
//
//   SS_tot = Σ (yᵢ - mean(y))²      SS_res = Σ (yᵢ - ŷᵢ)²
//
// When the targets are constant SS_tot is 0 and the ratio is undefined. In
// that case the score is 1 only if the predictions are exact (SS_res = 0)
// and 0 otherwise, so wrong predictions on a constant target are no longer
// reported as a perfect fit.
function r2Score(yTrue, yPred) {
  const mean = yTrue.reduce((s, y) => s + y, 0) / yTrue.length;
  const ssTot = yTrue.reduce((s, y) => s + (y - mean) ** 2, 0);
  const ssRes = yTrue.reduce((s, y, i) => s + (y - yPred[i]) ** 2, 0);
  if (ssTot === 0) {
    return ssRes === 0 ? 1 : 0;
  }
  return 1 - ssRes / ssTot;
}
|
|
5946
|
+
// Perplexity = exp( -(1/T) Σ log p_t ), where p_t = probabilities[t][yTrue[t]]
// is the probability assigned to the true class at step t. Each p_t is
// floored at 1e-15 so that a zero probability does not produce log(0).
function perplexity(yTrue, probabilities) {
  const floor = 1e-15;
  const steps = yTrue.length;
  const totalLogProb = yTrue.reduce(
    (acc, target, t) => acc + Math.log(Math.max(floor, probabilities[t][target])),
    0
  );
  return Math.exp(-totalLogProb / steps);
}
|
|
5956
|
+
// Logs a confusion matrix to the console (rows = actual, cols = predicted).
// labels defaults to the row indices "0".."K-1". Every cell is right-aligned
// to the longest label, with a minimum width of 6.
function printConfusionMatrix(matrix, labels) {
  const size = matrix.length;
  const names = labels ?? Array.from({ length: size }, (_, idx) => String(idx));
  const cellWidth = Math.max(6, ...names.map((name) => name.length));
  const cell = (text) => text.padStart(cellWidth);
  const headerCells = names.map((name) => cell(name));
  const header = cell("") + " " + headerCells.join(" ");
  console.log("");
  console.log("Confusion Matrix (rows = actual, cols = predicted):");
  console.log(header);
  console.log("\u2500".repeat(header.length));
  for (let r = 0; r < size; r++) {
    const counts = matrix[r].map((count) => cell(String(count)));
    console.log(cell(names[r]) + " " + counts.join(" "));
  }
  console.log("");
}
|
|
5972
|
+
// Logs a per-class report (precision, recall, f1-score, support) followed by
// a macro-average row.
//
// yTrue / yPred – class indices; K is the largest label seen plus one
// labels        – optional display names; a missing entry falls back to
//                 `class_<index>`
//
// Changes from the previous implementation:
//   * The header row now shows the column names. It was previously built by
//     blanking the digits of a formatted zero and then searching for the
//     original text, which never matched, so no column names were printed.
//   * The largest label is found with a loop; Math.max(...spread) throws a
//     RangeError on arrays longer than the engine's argument limit.
//   * Rules are as wide as the table, and the macro averages are 0 (not a
//     division by zero) when there are no classes.
function classificationReport(yTrue, yPred, labels) {
  let maxLabel = -1;
  for (const label of yTrue) {
    if (label > maxLabel) maxLabel = label;
  }
  for (const label of yPred) {
    if (label > maxLabel) maxLabel = label;
  }
  const K = maxLabel + 1;
  const lbs = labels ?? Array.from({ length: K }, (_, i) => `class_${i}`);
  const colW = Math.max(10, ...lbs.map((l) => l.length));
  const numW = 10; // width of each numeric column
  const fmt = (n) => n.toFixed(4).padStart(numW);
  const fmtI = (n) => String(n).padStart(numW);

  const header = "label".padEnd(colW) + ["precision", "recall", "f1-score", "support"].map((title) => title.padStart(numW)).join("");
  const rule = "\u2500".repeat(header.length);
  const rows = [header, rule];

  let pSum = 0;
  let rSum = 0;
  let f1Sum = 0;
  for (let c = 0; c < K; c++) {
    const p = _binaryPrecision(yTrue, yPred, c);
    const r = _binaryRecall(yTrue, yPred, c);
    const f1 = p + r > 0 ? 2 * p * r / (p + r) : 0;
    const support = yTrue.filter((y) => y === c).length;
    pSum += p;
    rSum += r;
    f1Sum += f1;
    const name = lbs[c] ?? `class_${c}`;
    rows.push(name.padEnd(colW) + fmt(p) + fmt(r) + fmt(f1) + fmtI(support));
  }
  rows.push(rule);
  const macro = (total) => (K > 0 ? total / K : 0);
  rows.push("macro avg".padEnd(colW) + fmt(macro(pSum)) + fmt(macro(rSum)) + fmt(macro(f1Sum)) + fmtI(yTrue.length));

  console.log("");
  console.log("Classification Report:");
  rows.forEach((r) => console.log(r));
  console.log("");
}
|
|
6001
|
+
// One-vs-rest precision for class `pos`: among the positions predicted as
// `pos`, the fraction whose true label is also `pos`. Returns 0 when `pos`
// is never predicted.
function _binaryPrecision(yTrue, yPred, pos) {
  let predictedPositives = 0;
  let truePositives = 0;
  for (let i = 0; i < yTrue.length; i++) {
    if (yPred[i] !== pos) continue;
    predictedPositives += 1;
    if (yTrue[i] === pos) truePositives += 1;
  }
  return predictedPositives > 0 ? truePositives / predictedPositives : 0;
}
|
|
6009
|
+
// One-vs-rest recall for class `pos`: among the positions whose true label
// is `pos`, the fraction also predicted as `pos`. Returns 0 when `pos` never
// occurs in yTrue.
function _binaryRecall(yTrue, yPred, pos) {
  let actualPositives = 0;
  let truePositives = 0;
  for (let i = 0; i < yTrue.length; i++) {
    if (yTrue[i] !== pos) continue;
    actualPositives += 1;
    if (yPred[i] === pos) truePositives += 1;
  }
  return actualPositives > 0 ? truePositives / actualPositives : 0;
}
|
|
6017
|
+
|
|
6018
|
+
// src/EarlyStopping.ts
|
|
6019
|
+
var EarlyStopping = class {
  // options (all optional):
  //   patience    – epochs without improvement tolerated before stopping (10)
  //   minDelta    – minimum change that counts as an improvement (1e-4)
  //   mode        – "min" (lower is better, default) or any other value for
  //                 "higher is better"
  //   restoreBest – keep a snapshot of the weights of the best epoch (false)
  constructor(options) {
    this.patience = options?.patience ?? 10;
    this.minDelta = options?.minDelta ?? 1e-4;
    this.mode = options?.mode ?? "min";
    this.restoreBest = options?.restoreBest ?? false;
    this.counter = 0;
    this.stopped = false;
    this.bestEpoch = 0;
    this.bestWeights = null;
    this.bestValue = this.mode === "min" ? Infinity : -Infinity;
  }
  // ── update ───────────────────────────────────────────────────────────────
  //
  // Call once per epoch with the current metric value.
  // Returns true once patience is exhausted (and on every later call).
  //
  // When restoreBest is true and `weights` is provided, a snapshot is stored
  // in bestWeights on every improvement. Nested arrays and typed arrays
  // inside `weights` are copied as well: a shallow copy would keep pointing
  // at the live per-layer arrays, and further training would silently
  // overwrite the "best" snapshot.
  update(value, epoch, weights) {
    if (this.stopped) return true;
    const improved = this.mode === "min" ? value < this.bestValue - this.minDelta : value > this.bestValue + this.minDelta;
    if (improved) {
      this.bestValue = value;
      this.bestEpoch = epoch;
      this.counter = 0;
      if (this.restoreBest && weights !== void 0) {
        const cloneEntry = (entry) => {
          if (Array.isArray(entry)) return entry.map(cloneEntry);
          if (ArrayBuffer.isView(entry) && typeof entry.slice === "function") return entry.slice();
          return entry;
        };
        this.bestWeights = [...weights].map(cloneEntry);
      }
      return false;
    }
    this.counter++;
    if (this.counter >= this.patience) {
      this.stopped = true;
      return true;
    }
    return false;
  }
  // Resets all state — use to re-run training with a fresh early-stop monitor.
  // The configured options (patience, minDelta, mode, restoreBest) are kept.
  reset() {
    this.counter = 0;
    this.stopped = false;
    this.bestEpoch = 0;
    this.bestWeights = null;
    this.bestValue = this.mode === "min" ? Infinity : -Infinity;
  }
};
|
|
6067
|
+
|
|
6068
|
+
// src/LossPlotter.ts
|
|
6069
|
+
var LossPlotter = class {
  // options (all optional): width (60) and height (15) of the chart in
  // characters, and a title ("Loss Curve").
  constructor(options) {
    this.width = options?.width ?? 60;
    this.height = options?.height ?? 15;
    this.title = options?.title ?? "Loss Curve";
    this.losses = [];
    this.epochs = [];
  }
  // Add a single (loss, epoch) pair. When epoch is omitted the index of the
  // new entry is used.
  add(loss, epoch) {
    this.losses.push(loss);
    this.epochs.push(epoch ?? this.losses.length - 1);
  }
  // Add multiple loss values (epochs are auto-numbered).
  addMultiple(losses) {
    for (const l of losses) this.add(l);
  }
  // Returns the ASCII plot as a multi-line string.
  //
  // Non-finite losses (NaN, ±Infinity — typical of a diverged run) are kept
  // in the series and still occupy their horizontal slot, but they are left
  // out of the y-axis range and are not drawn; previously a single such
  // value made render() throw. The range is found with a loop because
  // Math.min(...losses) throws a RangeError on very long series.
  render() {
    if (this.losses.length === 0) return `[${this.title}] \u2014 no data yet`;
    const losses = this.losses;
    let minL = Infinity;
    let maxL = -Infinity;
    for (const loss of losses) {
      if (!Number.isFinite(loss)) continue;
      if (loss < minL) minL = loss;
      if (loss > maxL) maxL = loss;
    }
    if (minL === Infinity) return `[${this.title}] \u2014 no finite data`;
    // A flat curve has zero range; fall back to 1 to avoid dividing by zero.
    const range = maxL - minL || 1;
    const yLabelW = 8;
    const plotW = Math.max(4, this.width - yLabelW - 1);
    const plotH = Math.max(3, this.height);
    const grid = Array.from(
      { length: plotH },
      () => new Array(plotW).fill(" ")
    );
    const n = losses.length;
    for (let idx = 0; idx < n; idx++) {
      if (!Number.isFinite(losses[idx])) continue;
      const col = Math.round(idx / Math.max(1, n - 1) * (plotW - 1));
      const norm = (losses[idx] - minL) / range;
      // Row 0 is the top of the chart, so the largest loss maps to row 0.
      const row = Math.round((1 - norm) * (plotH - 1));
      grid[row][col] = "*";
    }
    const lines = [];
    const titlePadded = ` ${this.title} `;
    const dashCount = Math.max(0, plotW + yLabelW - titlePadded.length);
    lines.push("\u250C" + titlePadded + "\u2500".repeat(dashCount) + "\u2510");
    const tick = (value) => _fmtNum(value).padStart(yLabelW - 2) + " \u2524";
    for (let row = 0; row < plotH; row++) {
      let label;
      if (row === 0) {
        label = tick(maxL);
      } else if (row === plotH - 1) {
        label = tick(minL);
      } else if (row === Math.floor(plotH / 2)) {
        label = tick(minL + range / 2);
      } else {
        label = " ".repeat(yLabelW - 1) + "\u2502";
      }
      lines.push(label + grid[row].join("") + "\u2502");
    }
    const xAxis = " ".repeat(yLabelW - 1) + "\u2514" + "\u2500".repeat(plotW) + "\u2518";
    lines.push(xAxis);
    const firstEpoch = String(this.epochs[0]);
    const lastEpoch = String(this.epochs[this.epochs.length - 1]);
    const xLabel = " ".repeat(yLabelW) + firstEpoch + " ".repeat(Math.max(1, plotW - firstEpoch.length - lastEpoch.length)) + lastEpoch + " epoch";
    lines.push(xLabel);
    lines.push(
      ` min=${_fmtNum(minL)} max=${_fmtNum(maxL)} last=${_fmtNum(losses[losses.length - 1])} n=${n}`
    );
    return lines.join("\n");
  }
  // Prints the chart to stdout.
  print() {
    console.log(this.render());
  }
  // Clears all accumulated data.
  reset() {
    this.losses = [];
    this.epochs = [];
  }
};
|
|
6146
|
+
/**
 * Formats a number as a compact label.
 *
 * - Magnitudes >= 1e4, or non-zero magnitudes < 1e-3, use one-decimal
 *   exponential notation (e.g. "1.2e+4", "5.0e-4").
 * - Everything else uses 4 significant digits (e.g. "0.5000", "123.4").
 *
 * @param {number} n - value to format.
 * @returns {string} formatted label.
 */
function _fmtNum(n) {
  const magnitude = Math.abs(n);
  if (magnitude >= 1e4 || (magnitude < 1e-3 && n !== 0)) {
    return n.toExponential(1);
  }
  const fixed = n.toPrecision(4);
  // Values just below 1e4 (e.g. 9999.96) round up to 5 integer digits, which
  // makes toPrecision(4) switch to exponent form ("1.000e+4"). Use the same
  // short exponential format as the large-value branch in that case.
  return fixed.includes("e") ? n.toExponential(1) : fixed;
}
|
|
6152
|
+
|
|
6153
|
+
// src/DataAugmentation.ts
|
|
6154
|
+
var DataAugmentation = class _DataAugmentation {
  // ── Noise / Perturbation ─────────────────────────────────────────────────

  /**
   * Adds independent Gaussian noise to each feature: x'ᵢ = xᵢ + N(0, σ²).
   * @param {number[]} x - feature vector (not mutated).
   * @param {number} [sigma=0.01] - standard deviation of the noise.
   * @returns {number[]} new perturbed vector.
   */
  static addNoise(x, sigma = 0.01) {
    return x.map((v) => v + _sampleNormal() * sigma);
  }

  /**
   * Uniform jitter: x'ᵢ = xᵢ + U(-delta, delta).
   * @param {number[]} x - feature vector (not mutated).
   * @param {number} [delta=0.01] - half-width of the uniform perturbation.
   * @returns {number[]} new perturbed vector.
   */
  static jitter(x, delta = 0.01) {
    return x.map((v) => v + (Math.random() * 2 - 1) * delta);
  }

  /**
   * Reverses the order of the feature vector.
   * Useful when features are symmetric or represent a temporal window.
   * @param {number[]} x - feature vector (not mutated).
   * @returns {number[]} reversed copy.
   */
  static flipHorizontal(x) {
    return [...x].reverse();
  }

  // ── Normalisation ────────────────────────────────────────────────────────

  /**
   * Min-Max normalisation fitted on a dataset X.
   * The feature count is taken from the first row.
   * Use normalizePoint() to apply the same transform to new samples.
   * @param {number[][]} X - dataset, one row per sample.
   * @returns {{normalized: number[][], min: number[], max: number[]}}
   *   the normalised rows plus the per-feature min/max arrays.
   */
  static normalize(X) {
    if (X.length === 0) return { normalized: [], min: [], max: [] };
    const d = X[0].length;
    const min = new Array(d).fill(Infinity);
    const max = new Array(d).fill(-Infinity);
    for (const row of X) {
      for (let j = 0; j < d; j++) {
        if (row[j] < min[j]) min[j] = row[j];
        if (row[j] > max[j]) max[j] = row[j];
      }
    }
    const normalized = X.map((row) => _DataAugmentation.normalizePoint(row, min, max));
    return { normalized, min, max };
  }

  /**
   * Applies pre-computed min/max normalisation to a single sample.
   * Constant features (min === max) are mapped to 0.
   * @param {number[]} x - sample to transform (not mutated).
   * @param {number[]} min - per-feature minimum.
   * @param {number[]} max - per-feature maximum.
   * @returns {number[]} normalised sample.
   */
  static normalizePoint(x, min, max) {
    return x.map((v, j) => {
      const range = max[j] - min[j];
      return range === 0 ? 0 : (v - min[j]) / range;
    });
  }

  /**
   * Z-Score standardisation fitted on a dataset X (population std, divides by n).
   * The feature count is taken from the first row.
   * Use standardizePoint() to apply the same transform to new samples.
   * @param {number[][]} X - dataset, one row per sample.
   * @returns {{standardized: number[][], mean: number[], std: number[]}}
   *   the standardised rows plus per-feature mean/std arrays.
   */
  static standardize(X) {
    if (X.length === 0) return { standardized: [], mean: [], std: [] };
    const n = X.length;
    const d = X[0].length;
    const mean = new Array(d).fill(0);
    for (const row of X) {
      for (let j = 0; j < d; j++) mean[j] += row[j];
    }
    for (let j = 0; j < d; j++) mean[j] /= n;
    const variance = new Array(d).fill(0);
    for (const row of X) {
      for (let j = 0; j < d; j++) variance[j] += (row[j] - mean[j]) ** 2;
    }
    const std = variance.map((v) => Math.sqrt(v / n));
    const standardized = X.map(
      (row) => _DataAugmentation.standardizePoint(row, mean, std)
    );
    return { standardized, mean, std };
  }

  /**
   * Applies pre-computed z-score standardisation to a single sample.
   * Constant features (std === 0) are mapped to 0.
   * @param {number[]} x - sample to transform (not mutated).
   * @param {number[]} mean - per-feature mean.
   * @param {number[]} std - per-feature standard deviation.
   * @returns {number[]} standardised sample.
   */
  static standardizePoint(x, mean, std) {
    return x.map((v, j) => std[j] === 0 ? 0 : (v - mean[j]) / std[j]);
  }

  // ── Batch Augmentation ───────────────────────────────────────────────────

  /**
   * Expands a dataset so each sample appears `factor` times: one copy of the
   * original followed by `factor - 1` Gaussian-noised copies (at factor = 1
   * the output equals the input; at factor = 3 the dataset triples in size).
   * @param {number[][]} X - samples.
   * @param {Array} y - labels, indexed in parallel with X.
   * @param {number} [factor=2] - total copies per sample, original included.
   * @param {number} [sigma=0.01] - noise standard deviation.
   * @returns {{X: number[][], y: Array}} augmented samples and labels.
   */
  static augmentBatch(X, y, factor = 2, sigma = 0.01) {
    const augX = [];
    const augY = [];
    for (let i = 0; i < X.length; i++) {
      augX.push([...X[i]]);
      augY.push(y[i]);
      for (let k = 1; k < factor; k++) {
        augX.push(_DataAugmentation.addNoise(X[i], sigma));
        augY.push(y[i]);
      }
    }
    return { X: augX, y: augY };
  }

  // ── Shuffle ──────────────────────────────────────────────────────────────

  /**
   * Fisher-Yates shuffle applied to an index permutation, so X and y are
   * reordered identically. Returns new outer arrays and does not mutate the
   * inputs; the row objects themselves are shared, not copied.
   * @param {Array} X - samples.
   * @param {Array} y - labels, indexed in parallel with X.
   * @returns {{X: Array, y: Array}} shuffled samples and labels.
   */
  static shuffle(X, y) {
    const indices = Array.from({ length: X.length }, (_, i) => i);
    for (let i = indices.length - 1; i > 0; i--) {
      const j = Math.floor(Math.random() * (i + 1));
      [indices[i], indices[j]] = [indices[j], indices[i]];
    }
    return {
      X: indices.map((i) => X[i]),
      y: indices.map((i) => y[i])
    };
  }

  // ── Train / Val / Test Split ─────────────────────────────────────────────

  /**
   * Splits the dataset into three non-overlapping partitions.
   * Shuffles automatically before splitting. Partition sizes are
   * floor(n * ratio); the remainder goes to test.
   *
   * Default split: 70% / 15% / 15%.
   *
   * @param {Array} X - samples.
   * @param {Array} y - labels, indexed in parallel with X.
   * @param {number} [trainRatio=0.7] - fraction for training, in [0, 1).
   * @param {number} [valRatio=0.15] - fraction for validation, in [0, 1).
   * @returns {{trainX: Array, trainY: Array, valX: Array, valY: Array, testX: Array, testY: Array}}
   * @throws {Error} if a ratio is negative or not finite, or if
   *   trainRatio + valRatio >= 1.
   */
  static split(X, y, trainRatio = 0.7, valRatio = 0.15) {
    // Negative or NaN ratios would turn into negative/NaN slice indices and
    // silently yield wrong partitions, so reject them up front.
    if (!Number.isFinite(trainRatio) || !Number.isFinite(valRatio) || trainRatio < 0 || valRatio < 0) {
      throw new Error(
        `trainRatio (${trainRatio}) and valRatio (${valRatio}) must be finite and >= 0`
      );
    }
    if (trainRatio + valRatio >= 1) {
      throw new Error(
        `trainRatio (${trainRatio}) + valRatio (${valRatio}) must be < 1`
      );
    }
    const { X: sX, y: sY } = _DataAugmentation.shuffle(X, y);
    const n = sX.length;
    const trainEnd = Math.floor(n * trainRatio);
    const valEnd = trainEnd + Math.floor(n * valRatio);
    return {
      trainX: sX.slice(0, trainEnd),
      trainY: sY.slice(0, trainEnd),
      valX: sX.slice(trainEnd, valEnd),
      valY: sY.slice(trainEnd, valEnd),
      testX: sX.slice(valEnd),
      testY: sY.slice(valEnd)
    };
  }
};
|
|
6284
|
+
/**
 * Draws one sample from the standard normal distribution N(0, 1) using the
 * Box-Muller transform on two Math.random() draws.
 *
 * @returns {number} a finite standard-normal sample.
 */
function _sampleNormal() {
  // Math.random() is in [0, 1), so 1 - Math.random() is in (0, 1]. This keeps
  // Math.log away from 0 (no -Infinity) and never above 1. Adding an epsilon
  // to the raw draw instead can push the argument past 1, which makes the
  // value under the square root negative and the result NaN.
  const u1 = 1 - Math.random();
  const u2 = Math.random();
  return Math.sqrt(-2 * Math.log(u1)) * Math.cos(2 * Math.PI * u2);
}
|
|
2601
6289
|
export {
|
|
2602
6290
|
Adam,
|
|
2603
6291
|
AttentionHead,
|
|
6292
|
+
Augmenter,
|
|
6293
|
+
Autoencoder,
|
|
2604
6294
|
BatchNorm,
|
|
2605
6295
|
BiasVector,
|
|
6296
|
+
CausalConv1D,
|
|
2606
6297
|
ClipOptimizer,
|
|
2607
6298
|
ClippedOptimizerFactory,
|
|
6299
|
+
ContrastiveLearning,
|
|
2608
6300
|
Conv1D,
|
|
6301
|
+
Conv2D,
|
|
6302
|
+
DataAugmentation,
|
|
2609
6303
|
DataLoader,
|
|
6304
|
+
DecisionTree,
|
|
2610
6305
|
Dropout,
|
|
6306
|
+
EarlyStopping,
|
|
2611
6307
|
EmbeddingMatrix,
|
|
6308
|
+
Flatten,
|
|
6309
|
+
GAN,
|
|
2612
6310
|
GRULayer,
|
|
6311
|
+
GaussianNaiveBayes,
|
|
6312
|
+
HopfieldNetwork,
|
|
6313
|
+
KMeans,
|
|
2613
6314
|
LRScheduler,
|
|
2614
6315
|
LSTMLayer,
|
|
2615
6316
|
Layer,
|
|
2616
6317
|
LayerNorm,
|
|
6318
|
+
LearnedPositionalEncoding,
|
|
6319
|
+
LinearRegression,
|
|
6320
|
+
LogisticRegression,
|
|
6321
|
+
LossPlotter,
|
|
6322
|
+
MaxPool2D,
|
|
2617
6323
|
ModelSaver,
|
|
2618
6324
|
Momentum,
|
|
2619
6325
|
MultiHeadAttention,
|
|
@@ -2624,23 +6330,49 @@ export {
|
|
|
2624
6330
|
NetworkTransformerRL,
|
|
2625
6331
|
Neuron,
|
|
2626
6332
|
NeuronN,
|
|
6333
|
+
PCA,
|
|
6334
|
+
Perceptron,
|
|
6335
|
+
PositionalEncoding,
|
|
6336
|
+
RNN,
|
|
2627
6337
|
SGD,
|
|
6338
|
+
SOM,
|
|
6339
|
+
Seq2Seq,
|
|
6340
|
+
SoftmaxRegression,
|
|
6341
|
+
TCN,
|
|
6342
|
+
TSNE,
|
|
2628
6343
|
Trainer,
|
|
2629
6344
|
TransformerBlock,
|
|
6345
|
+
VAE,
|
|
6346
|
+
Value,
|
|
6347
|
+
WeightInspector,
|
|
2630
6348
|
WeightMatrix,
|
|
6349
|
+
Word2Vec,
|
|
6350
|
+
accuracy,
|
|
6351
|
+
auc,
|
|
6352
|
+
classificationReport,
|
|
6353
|
+
confusionMatrix,
|
|
2631
6354
|
crossEntropy,
|
|
2632
6355
|
crossEntropyDelta,
|
|
2633
6356
|
crossEntropyDeltaRaw,
|
|
2634
6357
|
defaultOptimizer,
|
|
2635
6358
|
elu,
|
|
6359
|
+
f1Score,
|
|
2636
6360
|
leakyRelu,
|
|
2637
6361
|
linear,
|
|
6362
|
+
mae,
|
|
2638
6363
|
makeElu,
|
|
2639
6364
|
makeLeakyRelu,
|
|
2640
6365
|
matMul,
|
|
2641
6366
|
mse,
|
|
2642
6367
|
mseDelta,
|
|
6368
|
+
perplexity,
|
|
6369
|
+
precision,
|
|
6370
|
+
printConfusionMatrix,
|
|
6371
|
+
r2Score,
|
|
6372
|
+
recall,
|
|
2643
6373
|
relu,
|
|
6374
|
+
rmse,
|
|
6375
|
+
rocCurve,
|
|
2644
6376
|
sigmoid2 as sigmoid,
|
|
2645
6377
|
softmax,
|
|
2646
6378
|
softmaxBackward,
|