bun-scikit 0.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. package/LICENSE +21 -0
  2. package/README.md +187 -0
  3. package/binding.gyp +21 -0
  4. package/docs/README.md +7 -0
  5. package/docs/native-abi.md +53 -0
  6. package/index.ts +1 -0
  7. package/package.json +76 -0
  8. package/scripts/build-node-addon.ts +26 -0
  9. package/scripts/build-zig-kernels.ts +50 -0
  10. package/scripts/check-api-docs-coverage.ts +52 -0
  11. package/scripts/check-benchmark-health.ts +140 -0
  12. package/scripts/install-native.ts +160 -0
  13. package/scripts/package-native-artifacts.ts +62 -0
  14. package/scripts/sync-benchmark-readme.ts +181 -0
  15. package/scripts/update-benchmark-history.ts +91 -0
  16. package/src/ensemble/RandomForestClassifier.ts +136 -0
  17. package/src/ensemble/RandomForestRegressor.ts +136 -0
  18. package/src/index.ts +32 -0
  19. package/src/linear_model/LinearRegression.ts +136 -0
  20. package/src/linear_model/LogisticRegression.ts +260 -0
  21. package/src/linear_model/SGDClassifier.ts +161 -0
  22. package/src/linear_model/SGDRegressor.ts +104 -0
  23. package/src/metrics/classification.ts +294 -0
  24. package/src/metrics/regression.ts +51 -0
  25. package/src/model_selection/GridSearchCV.ts +244 -0
  26. package/src/model_selection/KFold.ts +82 -0
  27. package/src/model_selection/RepeatedKFold.ts +49 -0
  28. package/src/model_selection/RepeatedStratifiedKFold.ts +50 -0
  29. package/src/model_selection/StratifiedKFold.ts +112 -0
  30. package/src/model_selection/StratifiedShuffleSplit.ts +211 -0
  31. package/src/model_selection/crossValScore.ts +165 -0
  32. package/src/model_selection/trainTestSplit.ts +82 -0
  33. package/src/naive_bayes/GaussianNB.ts +148 -0
  34. package/src/native/node-addon/bun_scikit_addon.cpp +450 -0
  35. package/src/native/zigKernels.ts +576 -0
  36. package/src/neighbors/KNeighborsClassifier.ts +85 -0
  37. package/src/pipeline/ColumnTransformer.ts +203 -0
  38. package/src/pipeline/FeatureUnion.ts +123 -0
  39. package/src/pipeline/Pipeline.ts +168 -0
  40. package/src/preprocessing/MinMaxScaler.ts +113 -0
  41. package/src/preprocessing/OneHotEncoder.ts +91 -0
  42. package/src/preprocessing/PolynomialFeatures.ts +158 -0
  43. package/src/preprocessing/RobustScaler.ts +149 -0
  44. package/src/preprocessing/SimpleImputer.ts +150 -0
  45. package/src/preprocessing/StandardScaler.ts +92 -0
  46. package/src/svm/LinearSVC.ts +117 -0
  47. package/src/tree/DecisionTreeClassifier.ts +394 -0
  48. package/src/tree/DecisionTreeRegressor.ts +407 -0
  49. package/src/types.ts +18 -0
  50. package/src/utils/linalg.ts +209 -0
  51. package/src/utils/validation.ts +78 -0
  52. package/zig/kernels.zig +1327 -0
@@ -0,0 +1,211 @@
1
+ import type { FoldIndices } from "./KFold";
2
+
3
/** Configuration for {@link StratifiedShuffleSplit}. */
export interface StratifiedShuffleSplitOptions {
  /** Number of independent re-shuffled splits to generate. Must be an integer >= 1. Default: 10. */
  nSplits?: number;
  /** Test size: a float in (0, 1) is a fraction of the samples, an integer in [1, n-1] is an absolute count. Default: 10% of samples. */
  testSize?: number;
  /** Train size: same float/integer semantics as `testSize`. Defaults to the complement of the test size. */
  trainSize?: number;
  /** Seed for the deterministic PRNG so splits are reproducible. Default: 42. */
  randomState?: number;
}
9
+
10
+ function mulberry32(seed: number): () => number {
11
+ let state = seed >>> 0;
12
+ return () => {
13
+ state += 0x6d2b79f5;
14
+ let t = Math.imul(state ^ (state >>> 15), 1 | state);
15
+ t ^= t + Math.imul(t ^ (t >>> 7), 61 | t);
16
+ return ((t ^ (t >>> 14)) >>> 0) / 4294967296;
17
+ };
18
+ }
19
+
20
+ function shuffleInPlace(values: number[], random: () => number): void {
21
+ for (let i = values.length - 1; i > 0; i -= 1) {
22
+ const j = Math.floor(random() * (i + 1));
23
+ const tmp = values[i];
24
+ values[i] = values[j];
25
+ values[j] = tmp;
26
+ }
27
+ }
28
+
29
+ function resolveSplitCounts(
30
+ sampleCount: number,
31
+ testSize: number | undefined,
32
+ trainSize: number | undefined,
33
+ ): { trainCount: number; testCount: number } {
34
+ function resolveSize(value: number | undefined, label: string): number | null {
35
+ if (value === undefined) {
36
+ return null;
37
+ }
38
+ if (value > 0 && value < 1) {
39
+ return Math.max(1, Math.floor(sampleCount * value));
40
+ }
41
+ if (Number.isInteger(value) && value >= 1 && value < sampleCount) {
42
+ return value;
43
+ }
44
+ throw new Error(
45
+ `${label} must be a float in (0, 1) or int in [1, n-1]. Got ${value}.`,
46
+ );
47
+ }
48
+
49
+ const resolvedTest = resolveSize(testSize, "testSize");
50
+ const resolvedTrain = resolveSize(trainSize, "trainSize");
51
+
52
+ if (resolvedTest === null && resolvedTrain === null) {
53
+ const defaultTest = Math.max(1, Math.floor(sampleCount * 0.1));
54
+ return {
55
+ testCount: defaultTest,
56
+ trainCount: sampleCount - defaultTest,
57
+ };
58
+ }
59
+
60
+ if (resolvedTest !== null && resolvedTrain !== null) {
61
+ if (resolvedTest + resolvedTrain > sampleCount) {
62
+ throw new Error(
63
+ `trainSize + testSize must be <= sample count (${sampleCount}). Got ${resolvedTrain + resolvedTest}.`,
64
+ );
65
+ }
66
+ return { trainCount: resolvedTrain, testCount: resolvedTest };
67
+ }
68
+
69
+ if (resolvedTest !== null) {
70
+ return { testCount: resolvedTest, trainCount: sampleCount - resolvedTest };
71
+ }
72
+
73
+ const trainCount = resolvedTrain!;
74
+ return { trainCount, testCount: sampleCount - trainCount };
75
+ }
76
+
77
+ export class StratifiedShuffleSplit {
78
+ private readonly nSplits: number;
79
+ private readonly testSize?: number;
80
+ private readonly trainSize?: number;
81
+ private readonly randomState: number;
82
+
83
+ constructor(options: StratifiedShuffleSplitOptions = {}) {
84
+ this.nSplits = options.nSplits ?? 10;
85
+ this.testSize = options.testSize;
86
+ this.trainSize = options.trainSize;
87
+ this.randomState = options.randomState ?? 42;
88
+
89
+ if (!Number.isInteger(this.nSplits) || this.nSplits < 1) {
90
+ throw new Error(`nSplits must be an integer >= 1. Got ${this.nSplits}.`);
91
+ }
92
+ }
93
+
94
+ split<TX>(X: TX[], y: number[]): FoldIndices[] {
95
+ if (!Array.isArray(X) || X.length === 0) {
96
+ throw new Error("X must be a non-empty array.");
97
+ }
98
+ if (!Array.isArray(y) || y.length !== X.length) {
99
+ throw new Error(`X and y must have the same length. Got ${X.length} and ${y.length}.`);
100
+ }
101
+
102
+ const { trainCount, testCount } = resolveSplitCounts(X.length, this.testSize, this.trainSize);
103
+ if (trainCount < 1 || testCount < 1) {
104
+ throw new Error("Both train and test sets must have at least one sample.");
105
+ }
106
+
107
+ const byClass = new Map<number, number[]>();
108
+ for (let i = 0; i < y.length; i += 1) {
109
+ const label = y[i];
110
+ const bucket = byClass.get(label);
111
+ if (bucket) {
112
+ bucket.push(i);
113
+ } else {
114
+ byClass.set(label, [i]);
115
+ }
116
+ }
117
+
118
+ if (byClass.size < 2) {
119
+ throw new Error("StratifiedShuffleSplit requires at least two classes.");
120
+ }
121
+
122
+ const classEntries = Array.from(byClass.entries());
123
+ let minClassCount = Number.POSITIVE_INFINITY;
124
+ for (const [, indices] of classEntries) {
125
+ if (indices.length < minClassCount) {
126
+ minClassCount = indices.length;
127
+ }
128
+ }
129
+ if (minClassCount < 2) {
130
+ throw new Error(
131
+ "The least populated class has fewer than 2 members, which is not enough for stratified splitting.",
132
+ );
133
+ }
134
+
135
+ const proportions = classEntries.map(([, indices]) => (indices.length * testCount) / X.length);
136
+ const testPerClass = proportions.map((value) => Math.floor(value));
137
+
138
+ // Respect the global test count while keeping per-class allocations feasible.
139
+ let allocated = testPerClass.reduce((sum, count) => sum + count, 0);
140
+ let remaining = testCount - allocated;
141
+ if (remaining > 0) {
142
+ const classOrder = proportions
143
+ .map((target, idx) => ({ idx, frac: target - Math.floor(target) }))
144
+ .sort((a, b) => b.frac - a.frac)
145
+ .map((entry) => entry.idx);
146
+
147
+ let cursor = 0;
148
+ while (remaining > 0) {
149
+ const classIdx = classOrder[cursor % classOrder.length];
150
+ const classCount = classEntries[classIdx][1].length;
151
+ if (testPerClass[classIdx] < classCount - 1) {
152
+ testPerClass[classIdx] += 1;
153
+ remaining -= 1;
154
+ }
155
+ cursor += 1;
156
+ }
157
+ }
158
+
159
+ for (let i = 0; i < testPerClass.length; i += 1) {
160
+ const classCount = classEntries[i][1].length;
161
+ if (testPerClass[i] >= classCount) {
162
+ testPerClass[i] = classCount - 1;
163
+ }
164
+ }
165
+
166
+ allocated = testPerClass.reduce((sum, count) => sum + count, 0);
167
+ if (allocated !== testCount) {
168
+ throw new Error(
169
+ `Could not allocate exactly ${testCount} stratified test samples. Allocated ${allocated}.`,
170
+ );
171
+ }
172
+
173
+ const splits: FoldIndices[] = [];
174
+ for (let splitIndex = 0; splitIndex < this.nSplits; splitIndex += 1) {
175
+ const random = mulberry32(this.randomState + splitIndex * 104_729);
176
+ const testIndices: number[] = [];
177
+
178
+ for (let classIdx = 0; classIdx < classEntries.length; classIdx += 1) {
179
+ const classIndices = classEntries[classIdx][1].slice();
180
+ shuffleInPlace(classIndices, random);
181
+ const classTestCount = testPerClass[classIdx];
182
+ for (let i = 0; i < classTestCount; i += 1) {
183
+ testIndices.push(classIndices[i]);
184
+ }
185
+ }
186
+
187
+ testIndices.sort((a, b) => a - b);
188
+ const testMask = new Uint8Array(X.length);
189
+ for (let i = 0; i < testIndices.length; i += 1) {
190
+ testMask[testIndices[i]] = 1;
191
+ }
192
+
193
+ const trainIndices: number[] = [];
194
+ for (let i = 0; i < X.length; i += 1) {
195
+ if (testMask[i] === 0) {
196
+ trainIndices.push(i);
197
+ }
198
+ }
199
+
200
+ if (trainIndices.length !== trainCount || testIndices.length !== testCount) {
201
+ throw new Error(
202
+ `Split sizes mismatch. Expected train/test ${trainCount}/${testCount}, got ${trainIndices.length}/${testIndices.length}.`,
203
+ );
204
+ }
205
+
206
+ splits.push({ trainIndices, testIndices });
207
+ }
208
+
209
+ return splits;
210
+ }
211
+ }
@@ -0,0 +1,165 @@
1
+ import type { Matrix, Vector } from "../types";
2
+ import { accuracyScore, f1Score, precisionScore, recallScore } from "../metrics/classification";
3
+ import { meanSquaredError, r2Score } from "../metrics/regression";
4
+ import { assertFiniteMatrix, assertFiniteVector, assertVectorLength } from "../utils/validation";
5
+ import { KFold, type FoldIndices } from "./KFold";
6
+ import { StratifiedKFold } from "./StratifiedKFold";
7
+
8
/** Names of the scoring metrics bundled with `crossValScore`. */
export type BuiltInScoring =
  | "accuracy"
  | "f1"
  | "precision"
  | "recall"
  | "r2"
  | "mean_squared_error"
  | "neg_mean_squared_error";

/** Custom scoring callback: receives true and predicted targets, returns a score. */
export type ScoringFn = (yTrue: Vector, yPred: Vector) => number;

/** Minimal estimator contract required by cross-validation. */
export interface CrossValEstimator {
  /** Train on the given features/targets. The return value is ignored. */
  fit(X: Matrix, y: Vector): unknown;
  /** Predict targets for the given feature rows. */
  predict(X: Matrix): Vector;
  /** Optional self-scorer; used when no explicit `scoring` option is given. */
  score?(X: Matrix, y: Vector): number;
}

/** Any object that can produce train/test folds (e.g. KFold, StratifiedKFold). */
export type CrossValSplitter = {
  split(X: Matrix, y?: Vector): FoldIndices[];
};

/** Options for `crossValScore`. */
export interface CrossValScoreOptions {
  /** Fold count (integer >= 2) or a custom splitter. Default: 5 folds. */
  cv?: number | CrossValSplitter;
  /** Built-in metric name or custom scoring function. Defaults to the estimator's own score(). */
  scoring?: BuiltInScoring | ScoringFn;
}
+ }
33
+
34
+ function isBinaryVector(y: Vector): boolean {
35
+ for (let i = 0; i < y.length; i += 1) {
36
+ const value = y[i];
37
+ if (!(value === 0 || value === 1)) {
38
+ return false;
39
+ }
40
+ }
41
+ return true;
42
+ }
43
+
44
+ function subsetMatrix(X: Matrix, indices: number[]): Matrix {
45
+ const out = new Array(indices.length);
46
+ for (let i = 0; i < indices.length; i += 1) {
47
+ out[i] = X[indices[i]];
48
+ }
49
+ return out;
50
+ }
51
+
52
+ function subsetVector(y: Vector, indices: number[]): Vector {
53
+ const out = new Array(indices.length);
54
+ for (let i = 0; i < indices.length; i += 1) {
55
+ out[i] = y[indices[i]];
56
+ }
57
+ return out;
58
+ }
59
+
60
+ function resolveBuiltInScorer(scoring: BuiltInScoring): ScoringFn {
61
+ switch (scoring) {
62
+ case "accuracy":
63
+ return accuracyScore;
64
+ case "f1":
65
+ return f1Score;
66
+ case "precision":
67
+ return precisionScore;
68
+ case "recall":
69
+ return recallScore;
70
+ case "r2":
71
+ return r2Score;
72
+ case "mean_squared_error":
73
+ return meanSquaredError;
74
+ case "neg_mean_squared_error":
75
+ return (yTrue, yPred) => -meanSquaredError(yTrue, yPred);
76
+ default: {
77
+ const exhaustive: never = scoring;
78
+ throw new Error(`Unsupported scoring metric: ${exhaustive}`);
79
+ }
80
+ }
81
+ }
82
+
83
+ function resolveFolds(X: Matrix, y: Vector, cv: number | CrossValSplitter | undefined): FoldIndices[] {
84
+ if (typeof cv === "number") {
85
+ if (!Number.isInteger(cv) || cv < 2) {
86
+ throw new Error(`cv must be an integer >= 2. Got ${cv}.`);
87
+ }
88
+ if (isBinaryVector(y)) {
89
+ return new StratifiedKFold({ nSplits: cv, shuffle: false }).split(X, y);
90
+ }
91
+ return new KFold({ nSplits: cv, shuffle: false }).split(X, y);
92
+ }
93
+
94
+ if (cv) {
95
+ return cv.split(X, y);
96
+ }
97
+
98
+ if (isBinaryVector(y)) {
99
+ return new StratifiedKFold({ nSplits: 5, shuffle: false }).split(X, y);
100
+ }
101
+ return new KFold({ nSplits: 5, shuffle: false }).split(X, y);
102
+ }
103
+
104
+ export function crossValScore(
105
+ createEstimator: () => CrossValEstimator,
106
+ X: Matrix,
107
+ y: Vector,
108
+ options: CrossValScoreOptions = {},
109
+ ): number[] {
110
+ if (typeof createEstimator !== "function") {
111
+ throw new Error("createEstimator must be a function returning a new estimator instance.");
112
+ }
113
+
114
+ if (!Array.isArray(X) || X.length === 0) {
115
+ throw new Error("X must be a non-empty matrix.");
116
+ }
117
+ assertFiniteMatrix(X);
118
+ assertVectorLength(y, X.length);
119
+ assertFiniteVector(y);
120
+
121
+ const folds = resolveFolds(X, y, options.cv);
122
+ if (folds.length === 0) {
123
+ throw new Error("Cross-validation splitter produced no folds.");
124
+ }
125
+
126
+ const explicitScorer =
127
+ typeof options.scoring === "function"
128
+ ? options.scoring
129
+ : options.scoring
130
+ ? resolveBuiltInScorer(options.scoring)
131
+ : null;
132
+
133
+ const scores = new Array<number>(folds.length);
134
+ for (let foldIndex = 0; foldIndex < folds.length; foldIndex += 1) {
135
+ const fold = folds[foldIndex];
136
+ if (fold.trainIndices.length === 0 || fold.testIndices.length === 0) {
137
+ throw new Error(`Fold ${foldIndex} must have non-empty train and test indices.`);
138
+ }
139
+
140
+ const XTrain = subsetMatrix(X, fold.trainIndices);
141
+ const yTrain = subsetVector(y, fold.trainIndices);
142
+ const XTest = subsetMatrix(X, fold.testIndices);
143
+ const yTest = subsetVector(y, fold.testIndices);
144
+
145
+ const estimator = createEstimator();
146
+ estimator.fit(XTrain, yTrain);
147
+
148
+ if (explicitScorer) {
149
+ const yPred = estimator.predict(XTest);
150
+ scores[foldIndex] = explicitScorer(yTest, yPred);
151
+ continue;
152
+ }
153
+
154
+ if (typeof estimator.score === "function") {
155
+ scores[foldIndex] = estimator.score(XTest, yTest);
156
+ continue;
157
+ }
158
+
159
+ throw new Error(
160
+ "Estimator must implement score() when no explicit scoring function is provided.",
161
+ );
162
+ }
163
+
164
+ return scores;
165
+ }
@@ -0,0 +1,82 @@
1
/** Options for `trainTestSplit`. */
export interface TrainTestSplitOptions {
  /** Test size: a float in (0, 1) as a fraction of samples, or an integer in [1, n-1]. Default: 0.25. */
  testSize?: number;
  /** Whether to shuffle indices before splitting. Default: true. */
  shuffle?: boolean;
  /** Seed for the deterministic PRNG used when shuffling. Default: 42. */
  randomState?: number;
}

/** Result of `trainTestSplit`; within each part, rows keep their original relative order. */
export interface TrainTestSplitResult<TX, TY> {
  XTrain: TX[];
  XTest: TX[];
  yTrain: TY[];
  yTest: TY[];
}
13
+
14
+ function mulberry32(seed: number): () => number {
15
+ let state = seed >>> 0;
16
+ return () => {
17
+ state += 0x6d2b79f5;
18
+ let t = Math.imul(state ^ (state >>> 15), 1 | state);
19
+ t ^= t + Math.imul(t ^ (t >>> 7), 61 | t);
20
+ return ((t ^ (t >>> 14)) >>> 0) / 4294967296;
21
+ };
22
+ }
23
+
24
+ export function trainTestSplit<TX, TY>(
25
+ X: TX[],
26
+ y: TY[],
27
+ options: TrainTestSplitOptions = {},
28
+ ): TrainTestSplitResult<TX, TY> {
29
+ if (X.length !== y.length) {
30
+ throw new Error(`X and y must have the same length. Got ${X.length} and ${y.length}.`);
31
+ }
32
+
33
+ if (X.length < 2) {
34
+ throw new Error("At least two samples are required for train/test splitting.");
35
+ }
36
+
37
+ const shuffle = options.shuffle ?? true;
38
+ const randomState = options.randomState ?? 42;
39
+ const testSize = options.testSize ?? 0.25;
40
+ const sampleCount = X.length;
41
+
42
+ let testCount: number;
43
+ if (testSize > 0 && testSize < 1) {
44
+ testCount = Math.max(1, Math.floor(sampleCount * testSize));
45
+ } else if (Number.isInteger(testSize) && testSize >= 1 && testSize < sampleCount) {
46
+ testCount = testSize;
47
+ } else {
48
+ throw new Error(
49
+ `testSize must be a float in (0, 1) or int in [1, n-1]. Got ${testSize}.`,
50
+ );
51
+ }
52
+
53
+ const indices = Array.from({ length: sampleCount }, (_, idx) => idx);
54
+
55
+ if (shuffle) {
56
+ const random = mulberry32(randomState);
57
+ for (let i = indices.length - 1; i > 0; i -= 1) {
58
+ const j = Math.floor(random() * (i + 1));
59
+ const tmp = indices[i];
60
+ indices[i] = indices[j];
61
+ indices[j] = tmp;
62
+ }
63
+ }
64
+
65
+ const testIndices = new Set(indices.slice(0, testCount));
66
+ const XTrain: TX[] = [];
67
+ const XTest: TX[] = [];
68
+ const yTrain: TY[] = [];
69
+ const yTest: TY[] = [];
70
+
71
+ for (let i = 0; i < sampleCount; i += 1) {
72
+ if (testIndices.has(i)) {
73
+ XTest.push(X[i]);
74
+ yTest.push(y[i]);
75
+ } else {
76
+ XTrain.push(X[i]);
77
+ yTrain.push(y[i]);
78
+ }
79
+ }
80
+
81
+ return { XTrain, XTest, yTrain, yTest };
82
+ }
@@ -0,0 +1,148 @@
1
+ import type { ClassificationModel, Matrix, Vector } from "../types";
2
+ import {
3
+ assertConsistentRowSize,
4
+ assertFiniteMatrix,
5
+ assertFiniteVector,
6
+ validateClassificationInputs,
7
+ } from "../utils/validation";
8
+ import { accuracyScore } from "../metrics/classification";
9
+
10
// Fraction of the largest per-feature variance added to every variance during
// fit() for numerical stability (see GaussianNB.fit).
const DEFAULT_VAR_SMOOTHING = 1e-9;

/** Configuration for `GaussianNB`. */
export interface GaussianNBOptions {
  /** Variance-smoothing fraction; must be finite and >= 0. Default: 1e-9. */
  varSmoothing?: number;
}
+
16
+ export class GaussianNB implements ClassificationModel {
17
+ classes_: Vector = [0, 1];
18
+ classPrior_: Vector | null = null;
19
+ theta_: Matrix | null = null;
20
+ var_: Matrix | null = null;
21
+
22
+ private readonly varSmoothing: number;
23
+ private fittedFeatureCount = 0;
24
+ private epsilon = 0;
25
+
26
+ constructor(options: GaussianNBOptions = {}) {
27
+ this.varSmoothing = options.varSmoothing ?? DEFAULT_VAR_SMOOTHING;
28
+ if (!Number.isFinite(this.varSmoothing) || this.varSmoothing < 0) {
29
+ throw new Error(`varSmoothing must be finite and >= 0. Got ${this.varSmoothing}.`);
30
+ }
31
+ }
32
+
33
+ fit(X: Matrix, y: Vector): this {
34
+ validateClassificationInputs(X, y);
35
+
36
+ const sampleCount = X.length;
37
+ const featureCount = X[0].length;
38
+ this.fittedFeatureCount = featureCount;
39
+ this.classes_ = [0, 1];
40
+
41
+ let maxVariance = 0;
42
+ for (let j = 0; j < featureCount; j += 1) {
43
+ let sum = 0;
44
+ let sumSquares = 0;
45
+ for (let i = 0; i < sampleCount; i += 1) {
46
+ const value = X[i][j];
47
+ sum += value;
48
+ sumSquares += value * value;
49
+ }
50
+ const mean = sum / sampleCount;
51
+ const variance = sumSquares / sampleCount - mean * mean;
52
+ if (variance > maxVariance) {
53
+ maxVariance = variance;
54
+ }
55
+ }
56
+ this.epsilon = this.varSmoothing * maxVariance;
57
+
58
+ const priors = new Array<number>(2).fill(0);
59
+ const means = Array.from({ length: 2 }, () => new Array<number>(featureCount).fill(0));
60
+ const variances = Array.from({ length: 2 }, () => new Array<number>(featureCount).fill(0));
61
+ const counts = new Array<number>(2).fill(0);
62
+
63
+ for (let i = 0; i < sampleCount; i += 1) {
64
+ const label = y[i];
65
+ counts[label] += 1;
66
+ for (let j = 0; j < featureCount; j += 1) {
67
+ means[label][j] += X[i][j];
68
+ }
69
+ }
70
+
71
+ for (let cls = 0; cls < 2; cls += 1) {
72
+ if (counts[cls] === 0) {
73
+ throw new Error(`GaussianNB requires both classes to be present. Missing class ${cls}.`);
74
+ }
75
+ priors[cls] = counts[cls] / sampleCount;
76
+ for (let j = 0; j < featureCount; j += 1) {
77
+ means[cls][j] /= counts[cls];
78
+ }
79
+ }
80
+
81
+ for (let i = 0; i < sampleCount; i += 1) {
82
+ const label = y[i];
83
+ for (let j = 0; j < featureCount; j += 1) {
84
+ const diff = X[i][j] - means[label][j];
85
+ variances[label][j] += diff * diff;
86
+ }
87
+ }
88
+
89
+ for (let cls = 0; cls < 2; cls += 1) {
90
+ for (let j = 0; j < featureCount; j += 1) {
91
+ variances[cls][j] = variances[cls][j] / counts[cls] + this.epsilon;
92
+ }
93
+ }
94
+
95
+ this.classPrior_ = priors;
96
+ this.theta_ = means;
97
+ this.var_ = variances;
98
+ return this;
99
+ }
100
+
101
+ predictProba(X: Matrix): Matrix {
102
+ if (!this.classPrior_ || !this.theta_ || !this.var_ || this.fittedFeatureCount === 0) {
103
+ throw new Error("GaussianNB has not been fitted.");
104
+ }
105
+
106
+ assertConsistentRowSize(X);
107
+ assertFiniteMatrix(X);
108
+ if (X[0].length !== this.fittedFeatureCount) {
109
+ throw new Error(
110
+ `Feature size mismatch. Expected ${this.fittedFeatureCount}, got ${X[0].length}.`,
111
+ );
112
+ }
113
+
114
+ const outputs = new Array<number[]>(X.length);
115
+ for (let i = 0; i < X.length; i += 1) {
116
+ const row = X[i];
117
+ const logProb = new Array<number>(2).fill(0);
118
+ for (let cls = 0; cls < 2; cls += 1) {
119
+ let sum = Math.log(this.classPrior_[cls]);
120
+ for (let j = 0; j < this.fittedFeatureCount; j += 1) {
121
+ const variance = this.var_[cls][j];
122
+ const mean = this.theta_[cls][j];
123
+ const diff = row[j] - mean;
124
+ sum += -0.5 * Math.log(2 * Math.PI * variance) - (diff * diff) / (2 * variance);
125
+ }
126
+ logProb[cls] = sum;
127
+ }
128
+
129
+ const maxLog = Math.max(logProb[0], logProb[1]);
130
+ const exp0 = Math.exp(logProb[0] - maxLog);
131
+ const exp1 = Math.exp(logProb[1] - maxLog);
132
+ const denom = exp0 + exp1;
133
+ outputs[i] = [exp0 / denom, exp1 / denom];
134
+ }
135
+
136
+ return outputs;
137
+ }
138
+
139
+ predict(X: Matrix): Vector {
140
+ const probabilities = this.predictProba(X);
141
+ return probabilities.map((pair) => (pair[1] >= 0.5 ? 1 : 0));
142
+ }
143
+
144
+ score(X: Matrix, y: Vector): number {
145
+ assertFiniteVector(y);
146
+ return accuracyScore(y, this.predict(X));
147
+ }
148
+ }