bun-scikit 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +187 -0
- package/binding.gyp +21 -0
- package/docs/README.md +7 -0
- package/docs/native-abi.md +53 -0
- package/index.ts +1 -0
- package/package.json +76 -0
- package/scripts/build-node-addon.ts +26 -0
- package/scripts/build-zig-kernels.ts +50 -0
- package/scripts/check-api-docs-coverage.ts +52 -0
- package/scripts/check-benchmark-health.ts +140 -0
- package/scripts/install-native.ts +160 -0
- package/scripts/package-native-artifacts.ts +62 -0
- package/scripts/sync-benchmark-readme.ts +181 -0
- package/scripts/update-benchmark-history.ts +91 -0
- package/src/ensemble/RandomForestClassifier.ts +136 -0
- package/src/ensemble/RandomForestRegressor.ts +136 -0
- package/src/index.ts +32 -0
- package/src/linear_model/LinearRegression.ts +136 -0
- package/src/linear_model/LogisticRegression.ts +260 -0
- package/src/linear_model/SGDClassifier.ts +161 -0
- package/src/linear_model/SGDRegressor.ts +104 -0
- package/src/metrics/classification.ts +294 -0
- package/src/metrics/regression.ts +51 -0
- package/src/model_selection/GridSearchCV.ts +244 -0
- package/src/model_selection/KFold.ts +82 -0
- package/src/model_selection/RepeatedKFold.ts +49 -0
- package/src/model_selection/RepeatedStratifiedKFold.ts +50 -0
- package/src/model_selection/StratifiedKFold.ts +112 -0
- package/src/model_selection/StratifiedShuffleSplit.ts +211 -0
- package/src/model_selection/crossValScore.ts +165 -0
- package/src/model_selection/trainTestSplit.ts +82 -0
- package/src/naive_bayes/GaussianNB.ts +148 -0
- package/src/native/node-addon/bun_scikit_addon.cpp +450 -0
- package/src/native/zigKernels.ts +576 -0
- package/src/neighbors/KNeighborsClassifier.ts +85 -0
- package/src/pipeline/ColumnTransformer.ts +203 -0
- package/src/pipeline/FeatureUnion.ts +123 -0
- package/src/pipeline/Pipeline.ts +168 -0
- package/src/preprocessing/MinMaxScaler.ts +113 -0
- package/src/preprocessing/OneHotEncoder.ts +91 -0
- package/src/preprocessing/PolynomialFeatures.ts +158 -0
- package/src/preprocessing/RobustScaler.ts +149 -0
- package/src/preprocessing/SimpleImputer.ts +150 -0
- package/src/preprocessing/StandardScaler.ts +92 -0
- package/src/svm/LinearSVC.ts +117 -0
- package/src/tree/DecisionTreeClassifier.ts +394 -0
- package/src/tree/DecisionTreeRegressor.ts +407 -0
- package/src/types.ts +18 -0
- package/src/utils/linalg.ts +209 -0
- package/src/utils/validation.ts +78 -0
- package/zig/kernels.zig +1327 -0
|
@@ -0,0 +1,244 @@
|
|
|
1
|
+
import type { Matrix, Vector } from "../types";
|
|
2
|
+
import { accuracyScore, f1Score, precisionScore, recallScore } from "../metrics/classification";
|
|
3
|
+
import { meanSquaredError, r2Score } from "../metrics/regression";
|
|
4
|
+
import {
|
|
5
|
+
crossValScore,
|
|
6
|
+
type BuiltInScoring,
|
|
7
|
+
type CrossValEstimator,
|
|
8
|
+
type CrossValSplitter,
|
|
9
|
+
type ScoringFn,
|
|
10
|
+
} from "./crossValScore";
|
|
11
|
+
|
|
12
|
+
/**
 * Search space for GridSearchCV: each key is a hyper-parameter name and each
 * value is the list of candidate values to try for it. Every combination of
 * values across keys is evaluated.
 */
export type ParamGrid = { [paramName: string]: readonly unknown[] };
|
|
13
|
+
|
|
14
|
+
/** Settings accepted by the GridSearchCV constructor. */
export interface GridSearchCVOptions {
  /**
   * Metric for each fold: a built-in metric name or a custom
   * `(yTrue, yPred) => number` function.
   */
  scoring?: BuiltInScoring | ScoringFn;
  /** Fold count or splitter object, forwarded unchanged to `crossValScore`. */
  cv?: number | CrossValSplitter;
  /**
   * Score recorded for a candidate whose cross-validation throws. The default,
   * "raise", rethrows the error instead.
   */
  errorScore?: number | "raise";
  /** Refit the best parameters on the full data after the search (default true). */
  refit?: boolean;
}
|
|
20
|
+
|
|
21
|
+
/** Outcome of cross-validating a single parameter combination. */
export interface GridSearchResultRow {
  /** The parameter combination that was evaluated (a copy). */
  params: Record<string, unknown>;
  /**
   * "ok" when cross-validation completed; "error" when it threw and the
   * configured numeric errorScore was recorded instead.
   */
  status: "error" | "ok";
  /** Message of the caught error; only set when `status` is "error". */
  errorMessage?: string;
  /** Per-fold scores, or `[errorScore]` for a failed candidate. */
  splitScores: number[];
  /** Mean of `splitScores`. */
  meanTestScore: number;
  /** Population standard deviation of `splitScores` (0 for failed candidates). */
  stdTestScore: number;
  /** 1-based position when candidates are ordered best-first. */
  rank: number;
}
|
|
30
|
+
|
|
31
|
+
/**
 * Arithmetic mean of `values`, summed left to right.
 * An empty array yields NaN (0 / 0).
 */
function mean(values: number[]): number {
  let total = 0;
  for (const value of values) {
    total += value;
  }
  return total / values.length;
}
|
|
38
|
+
|
|
39
|
+
/**
 * Population standard deviation of `values` (divides by n, not n - 1).
 * Fewer than two values yield 0.
 */
function std(values: number[]): number {
  const count = values.length;
  if (count >= 2) {
    const center = mean(values);
    let squaredDeviationSum = 0;
    for (const value of values) {
      const deviation = value - center;
      squaredDeviationSum += deviation * deviation;
    }
    return Math.sqrt(squaredDeviationSum / count);
  }
  return 0;
}
|
|
51
|
+
|
|
52
|
+
/**
 * Maps a built-in metric name to its scoring function.
 *
 * "neg_mean_squared_error" is the negated MSE, so that higher is better.
 *
 * @throws Error for a name that is not a known built-in metric.
 */
function resolveBuiltInScorer(scoring: BuiltInScoring): ScoringFn {
  const registry: Record<BuiltInScoring, ScoringFn> = {
    accuracy: accuracyScore,
    f1: f1Score,
    precision: precisionScore,
    recall: recallScore,
    r2: r2Score,
    mean_squared_error: meanSquaredError,
    neg_mean_squared_error: (yTrue, yPred) => -meanSquaredError(yTrue, yPred),
  };
  // Own-property check so inherited names such as "toString" are rejected too.
  if (!Object.prototype.hasOwnProperty.call(registry, scoring)) {
    throw new Error(`Unsupported scoring metric: ${scoring}`);
  }
  return registry[scoring];
}
|
|
74
|
+
|
|
75
|
+
/**
 * Reports whether `scoring` is a loss (lower is better).
 *
 * Only the raw "mean_squared_error" name qualifies. "neg_mean_squared_error"
 * is already negated, and custom functions or an unset scoring are treated as
 * higher-is-better.
 */
function isLossMetric(scoring: BuiltInScoring | ScoringFn | undefined): boolean {
  switch (scoring) {
    case "mean_squared_error":
      return true;
    default:
      return false;
  }
}
|
|
78
|
+
|
|
79
|
+
/**
 * Expands a parameter grid into every combination of its values.
 *
 * Combinations come out in nested-loop order over the grid's keys (in
 * `Object.keys` order): the first key varies slowest, the last fastest.
 * Each combination is a distinct object.
 *
 * @throws Error if the grid has no keys, or any key maps to something other
 *   than a non-empty array.
 */
function cartesianProduct(grid: ParamGrid): Record<string, unknown>[] {
  const names = Object.keys(grid);
  if (names.length === 0) {
    throw new Error("paramGrid must include at least one parameter.");
  }

  // Validate every entry before any combination is built.
  for (const name of names) {
    const choices = grid[name];
    if (!Array.isArray(choices) || choices.length === 0) {
      throw new Error(`paramGrid '${name}' must be a non-empty array.`);
    }
  }

  const lists = names.map((name) => grid[name]);
  // Odometer over value positions: the last key advances fastest and carries
  // into earlier keys, reproducing nested-loop ordering.
  const positions = lists.map(() => 0);
  const combos: Record<string, unknown>[] = [];

  for (;;) {
    const combo: Record<string, unknown> = {};
    for (let k = 0; k < names.length; k += 1) {
      combo[names[k]] = lists[k][positions[k]];
    }
    combos.push(combo);

    let slot = names.length - 1;
    while (slot >= 0 && positions[slot] + 1 === lists[slot].length) {
      positions[slot] = 0;
      slot -= 1;
    }
    if (slot < 0) {
      return combos;
    }
    positions[slot] += 1;
  }
}
|
|
112
|
+
|
|
113
|
+
/**
 * Maps a candidate's objective value to the key used for ranking.
 *
 * NaN objectives (a NaN `errorScore`, or a metric that is NaN on some fold)
 * become -Infinity, so those candidates rank last instead of corrupting the
 * sort.
 */
function rankingKey(objective: number): number {
  return Number.isNaN(objective) ? Number.NEGATIVE_INFINITY : objective;
}

/**
 * Exhaustive hyper-parameter search using cross-validation.
 *
 * Each combination expanded from `paramGrid` is turned into an estimator by
 * `estimatorFactory` and scored with `crossValScore`. Candidates are ranked
 * best-first (rank 1 = best).
 *
 * When `refit` is enabled (the default), a fresh estimator is built from the
 * best parameters, fitted on the full data, and used by `predict` and `score`.
 */
export class GridSearchCV<TEstimator extends CrossValEstimator> {
  /** Best estimator refitted on all of X/y; null before `fit` or when `refit` is false. */
  bestEstimator_: TEstimator | null = null;
  /** Parameter combination ranked first; null before `fit`. */
  bestParams_: Record<string, unknown> | null = null;
  /** Mean cross-validated score of `bestParams_` (the raw metric value); null before `fit`. */
  bestScore_: number | null = null;
  /** One row per candidate, in the order the grid was expanded. */
  cvResults_: GridSearchResultRow[] = [];

  private readonly estimatorFactory: (params: Record<string, unknown>) => TEstimator;
  private readonly paramGrid: ParamGrid;
  private readonly cv?: number | CrossValSplitter;
  private readonly scoring?: BuiltInScoring | ScoringFn;
  private readonly refit: boolean;
  private readonly errorScore: "raise" | number;
  private isFitted = false;

  /**
   * @param estimatorFactory Builds a new, unfitted estimator for one parameter combination.
   * @param paramGrid Candidate values per parameter; every combination is tried.
   * @param options Cross-validation, scoring, refit and error-handling settings.
   * @throws Error if `estimatorFactory` is not a function.
   */
  constructor(
    estimatorFactory: (params: Record<string, unknown>) => TEstimator,
    paramGrid: ParamGrid,
    options: GridSearchCVOptions = {},
  ) {
    if (typeof estimatorFactory !== "function") {
      throw new Error("estimatorFactory must be a function.");
    }
    this.estimatorFactory = estimatorFactory;
    this.paramGrid = paramGrid;
    this.cv = options.cv;
    this.scoring = options.scoring;
    this.refit = options.refit ?? true;
    this.errorScore = options.errorScore ?? "raise";
  }

  /**
   * Cross-validates every parameter combination and ranks them. If `refit` is
   * set, it then fits the best combination on the full data.
   *
   * Public results are only updated once the whole search (including the
   * refit) succeeds, so a failing call leaves earlier results intact.
   *
   * @throws The first cross-validation error when `errorScore` is "raise".
   *   Errors from expanding the grid or from the refit always propagate.
   */
  fit(X: Matrix, y: Vector): this {
    const candidates = cartesianProduct(this.paramGrid);
    // Objectives are "higher is better" for every metric: losses are negated.
    const minimize = isLossMetric(this.scoring);
    const rows: GridSearchResultRow[] = [];
    const objectiveScores: number[] = [];

    for (const params of candidates) {
      try {
        const splitScores = crossValScore(
          () => this.estimatorFactory(params),
          X,
          y,
          { cv: this.cv, scoring: this.scoring },
        );
        const meanTestScore = mean(splitScores);
        rows.push({
          params: { ...params },
          splitScores,
          meanTestScore,
          stdTestScore: std(splitScores),
          rank: 0,
          status: "ok",
        });
        objectiveScores.push(minimize ? -meanTestScore : meanTestScore);
      } catch (error) {
        const fallback = this.errorScore;
        if (fallback === "raise") {
          throw error;
        }
        rows.push({
          params: { ...params },
          splitScores: [fallback],
          meanTestScore: fallback,
          stdTestScore: 0,
          rank: 0,
          status: "error",
          errorMessage: error instanceof Error ? error.message : String(error),
        });
        objectiveScores.push(minimize ? -fallback : fallback);
      }
    }

    // Best-first order, with ties broken by grid position. Use relational
    // comparisons rather than `b - a`: subtraction yields NaN for NaN scores
    // or two equal infinities, and Array#sort treats a NaN result as "equal",
    // which can leave a worse candidate ahead of a better one.
    const order = rows
      .map((_, idx) => idx)
      .sort((a, b) => {
        const keyA = rankingKey(objectiveScores[a]);
        const keyB = rankingKey(objectiveScores[b]);
        if (keyA > keyB) {
          return -1;
        }
        if (keyA < keyB) {
          return 1;
        }
        return a - b;
      });

    order.forEach((rowIndex, position) => {
      rows[rowIndex].rank = position + 1;
    });

    const bestRow = rows[order[0]];
    const bestParams = { ...bestRow.params };
    let bestEstimator: TEstimator | null = null;
    if (this.refit) {
      bestEstimator = this.estimatorFactory(bestParams);
      bestEstimator.fit(X, y);
    }

    this.bestParams_ = bestParams;
    this.bestScore_ = bestRow.meanTestScore;
    this.cvResults_ = rows;
    this.bestEstimator_ = bestEstimator;
    this.isFitted = true;
    return this;
  }

  /**
   * Predicts with the refitted best estimator.
   *
   * @throws Error if `fit` has not run, or if `refit` is false.
   */
  predict(X: Matrix): Vector {
    if (!this.isFitted) {
      throw new Error("GridSearchCV has not been fitted.");
    }
    if (!this.refit || !this.bestEstimator_) {
      throw new Error("GridSearchCV predict is unavailable when refit=false.");
    }
    return this.bestEstimator_.predict(X);
  }

  /**
   * Scores the refitted best estimator on (X, y).
   *
   * Uses the configured `scoring` when set. Built-in names are resolved as-is,
   * so "mean_squared_error" returns the raw loss. Otherwise it falls back to
   * the estimator's own `score` method.
   *
   * @throws Error if not fitted, if `refit` is false, or if no scoring is available.
   */
  score(X: Matrix, y: Vector): number {
    if (!this.isFitted) {
      throw new Error("GridSearchCV has not been fitted.");
    }
    if (!this.refit || !this.bestEstimator_) {
      throw new Error("GridSearchCV score is unavailable when refit=false.");
    }

    if (this.scoring) {
      const scorer =
        typeof this.scoring === "function" ? this.scoring : resolveBuiltInScorer(this.scoring);
      return scorer(y, this.bestEstimator_.predict(X));
    }

    if (typeof this.bestEstimator_.score === "function") {
      return this.bestEstimator_.score(X, y);
    }

    throw new Error("No scoring function available. Provide scoring in GridSearchCV options.");
  }
}
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
/** One cross-validation split, expressed as positions into the sample array. */
export interface FoldIndices {
  /** Positions of the samples held out for evaluation in this split. */
  testIndices: number[];
  /** Positions of the samples used for fitting in this split. */
  trainIndices: number[];
}
|
|
5
|
+
|
|
6
|
+
/** Settings for KFold. */
export interface KFoldOptions {
  /** Seed for the shuffle; only used when `shuffle` is true. Defaults to 42. */
  randomState?: number;
  /** Shuffle sample positions before cutting folds. Defaults to false. */
  shuffle?: boolean;
  /** Number of folds; must be an integer >= 2. Defaults to 5. */
  nSplits?: number;
}
|
|
11
|
+
|
|
12
|
+
/**
 * Mulberry32 pseudo-random generator. Returns a function that yields a
 * deterministic stream of floats in [0, 1) for the given seed (taken modulo 2^32).
 *
 * The state is wrapped to an unsigned 32-bit integer after each step. Letting
 * it grow as a plain JS number only stays exact for about 4.9 million draws:
 * past 2^53 the additions round and the stream silently diverges from
 * Mulberry32. Wrapping produces exactly the same values up to that point.
 */
function mulberry32(seed: number): () => number {
  let state = seed >>> 0;
  return () => {
    state = (state + 0x6d2b79f5) >>> 0;
    let t = Math.imul(state ^ (state >>> 15), 1 | state);
    t ^= t + Math.imul(t ^ (t >>> 7), 61 | t);
    return ((t ^ (t >>> 14)) >>> 0) / 4294967296;
  };
}
|
|
21
|
+
|
|
22
|
+
/**
 * Fisher–Yates shuffle of `indices`, in place, driven by a Mulberry32 stream
 * seeded with `seed`. The same seed always yields the same permutation.
 */
function shuffleInPlace(indices: number[], seed: number): void {
  const nextRandom = mulberry32(seed);
  let i = indices.length - 1;
  while (i > 0) {
    const j = Math.floor(nextRandom() * (i + 1));
    [indices[i], indices[j]] = [indices[j], indices[i]];
    i -= 1;
  }
}
|
|
31
|
+
|
|
32
|
+
/**
 * Plain (non-stratified) K-fold cross-validator.
 *
 * Sample positions, optionally shuffled with a fixed seed, are cut into
 * `nSplits` contiguous test folds. The first `n % nSplits` folds hold one
 * extra sample. Each fold trains on every position outside it.
 */
export class KFold {
  private readonly nSplits: number;
  private readonly shuffle: boolean;
  private readonly randomState: number;

  /**
   * @param options Fold count (integer >= 2, default 5), shuffle flag
   *   (default false) and shuffle seed (default 42).
   * @throws Error if `nSplits` is not an integer >= 2.
   */
  constructor(options: KFoldOptions = {}) {
    this.nSplits = options.nSplits ?? 5;
    this.shuffle = options.shuffle ?? false;
    this.randomState = options.randomState ?? 42;

    const validFoldCount = Number.isInteger(this.nSplits) && this.nSplits >= 2;
    if (!validFoldCount) {
      throw new Error(`nSplits must be an integer >= 2. Got ${this.nSplits}.`);
    }
  }

  /**
   * Partitions sample positions into `nSplits` train/test folds.
   *
   * @param X Samples; only the length is used.
   * @param y Optional targets; only checked to match X's length.
   * @returns The folds, in order.
   * @throws Error if X is empty, y's length differs from X's, or there are
   *   fewer samples than folds.
   */
  split<TX>(X: TX[], y?: unknown[]): FoldIndices[] {
    if (!Array.isArray(X) || X.length === 0) {
      throw new Error("X must be a non-empty array.");
    }
    if (y && y.length !== X.length) {
      throw new Error(`X and y must have the same length. Got ${X.length} and ${y.length}.`);
    }
    const total = X.length;
    if (this.nSplits > total) {
      throw new Error(`nSplits (${this.nSplits}) cannot exceed sample count (${total}).`);
    }

    const positions = [...Array(total).keys()];
    if (this.shuffle) {
      shuffleInPlace(positions, this.randomState);
    }

    const smallFoldSize = Math.floor(total / this.nSplits);
    const largeFoldCount = total % this.nSplits;
    const folds: FoldIndices[] = [];
    let from = 0;
    for (let fold = 0; fold < this.nSplits; fold += 1) {
      const to = from + smallFoldSize + (fold < largeFoldCount ? 1 : 0);
      folds.push({
        trainIndices: [...positions.slice(0, from), ...positions.slice(to)],
        testIndices: positions.slice(from, to),
      });
      from = to;
    }
    return folds;
  }
}
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
import { KFold, type FoldIndices } from "./KFold";
|
|
2
|
+
|
|
3
|
+
/** Settings for RepeatedKFold. */
export interface RepeatedKFoldOptions {
  /** Base seed; repeat r uses `randomState + r * 104_729`. Defaults to 42. */
  randomState?: number;
  /** How many times the shuffled K-fold split is repeated; integer >= 1. Defaults to 10. */
  nRepeats?: number;
  /** Folds per repeat; integer >= 2. Defaults to 5. */
  nSplits?: number;
}
|
|
8
|
+
|
|
9
|
+
/**
 * K-fold cross-validation repeated `nRepeats` times with a different shuffle
 * each time.
 *
 * Repeat r uses a shuffled KFold seeded with `randomState + r * 104_729`.
 * The folds of all repeats are concatenated in repeat order.
 */
export class RepeatedKFold {
  private readonly nSplits: number;
  private readonly nRepeats: number;
  private readonly randomState: number;

  /**
   * @param options Folds per repeat (integer >= 2, default 5), repeat count
   *   (integer >= 1, default 10) and base seed (default 42).
   * @throws Error if `nSplits` or `nRepeats` is out of range.
   */
  constructor(options: RepeatedKFoldOptions = {}) {
    this.nSplits = options.nSplits ?? 5;
    this.nRepeats = options.nRepeats ?? 10;
    this.randomState = options.randomState ?? 42;

    if (!(Number.isInteger(this.nSplits) && this.nSplits >= 2)) {
      throw new Error(`nSplits must be an integer >= 2. Got ${this.nSplits}.`);
    }
    if (!(Number.isInteger(this.nRepeats) && this.nRepeats >= 1)) {
      throw new Error(`nRepeats must be an integer >= 1. Got ${this.nRepeats}.`);
    }
  }

  /**
   * Produces `nRepeats` rounds of shuffled K-fold splits, concatenated.
   *
   * @throws Error if X is empty or y's length differs from X's. Errors raised
   *   by KFold itself (e.g. too few samples) propagate.
   */
  split<TX>(X: TX[], y?: unknown[]): FoldIndices[] {
    if (!Array.isArray(X) || X.length === 0) {
      throw new Error("X must be a non-empty array.");
    }
    if (y && y.length !== X.length) {
      throw new Error(`X and y must have the same length. Got ${X.length} and ${y.length}.`);
    }

    return Array.from({ length: this.nRepeats }, (_, repeat) =>
      new KFold({
        nSplits: this.nSplits,
        shuffle: true,
        randomState: this.randomState + repeat * 104_729,
      }).split(X, y),
    ).flat();
  }
}
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
import type { FoldIndices } from "./KFold";
|
|
2
|
+
import { StratifiedKFold } from "./StratifiedKFold";
|
|
3
|
+
|
|
4
|
+
/** Settings for RepeatedStratifiedKFold. */
export interface RepeatedStratifiedKFoldOptions {
  /** Base seed; repeat r uses `randomState + r * 104_729`. Defaults to 42. */
  randomState?: number;
  /** How many times the shuffled stratified split is repeated; integer >= 1. Defaults to 10. */
  nRepeats?: number;
  /** Folds per repeat; integer >= 2. Defaults to 5. */
  nSplits?: number;
}
|
|
9
|
+
|
|
10
|
+
/**
 * Stratified K-fold cross-validation repeated `nRepeats` times.
 *
 * Repeat r runs a shuffled StratifiedKFold seeded with
 * `randomState + r * 104_729`. The folds of all repeats are concatenated in
 * repeat order.
 */
export class RepeatedStratifiedKFold {
  private readonly nSplits: number;
  private readonly nRepeats: number;
  private readonly randomState: number;

  /**
   * @param options Folds per repeat (integer >= 2, default 5), repeat count
   *   (integer >= 1, default 10) and base seed (default 42).
   * @throws Error if `nSplits` or `nRepeats` is out of range.
   */
  constructor(options: RepeatedStratifiedKFoldOptions = {}) {
    this.nSplits = options.nSplits ?? 5;
    this.nRepeats = options.nRepeats ?? 10;
    this.randomState = options.randomState ?? 42;

    if (!Number.isInteger(this.nSplits) || this.nSplits < 2) {
      throw new Error(`nSplits must be an integer >= 2. Got ${this.nSplits}.`);
    }
    if (!Number.isInteger(this.nRepeats) || this.nRepeats < 1) {
      throw new Error(`nRepeats must be an integer >= 1. Got ${this.nRepeats}.`);
    }
  }

  /**
   * Produces `nRepeats` rounds of shuffled stratified splits, concatenated.
   *
   * @param X Samples; only the length is used here.
   * @param y Class label per sample.
   * @throws Error if X is empty, y is not an array, or the lengths differ.
   *   Errors raised by StratifiedKFold itself propagate.
   */
  split<TX>(X: TX[], y: number[]): FoldIndices[] {
    if (!Array.isArray(X) || X.length === 0) {
      throw new Error("X must be a non-empty array.");
    }
    // Check y's type on its own: interpolating `y.length` into the
    // length-mismatch message would itself throw a TypeError for a
    // null/undefined y, hiding the real problem.
    if (!Array.isArray(y)) {
      throw new Error("y must be an array of class labels.");
    }
    if (y.length !== X.length) {
      throw new Error(`X and y must have the same length. Got ${X.length} and ${y.length}.`);
    }

    const allFolds: FoldIndices[] = [];
    for (let repeat = 0; repeat < this.nRepeats; repeat += 1) {
      const splitter = new StratifiedKFold({
        nSplits: this.nSplits,
        shuffle: true,
        randomState: this.randomState + repeat * 104_729,
      });
      allFolds.push(...splitter.split(X, y));
    }
    return allFolds;
  }
}
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
import type { FoldIndices } from "./KFold";
|
|
2
|
+
|
|
3
|
+
/** Settings for StratifiedKFold. */
export interface StratifiedKFoldOptions {
  /** Seed for the within-class shuffle; only used when `shuffle` is true. Defaults to 42. */
  randomState?: number;
  /** Shuffle each class's samples before assigning them to folds. Defaults to false. */
  shuffle?: boolean;
  /** Number of folds; integer >= 2, at most the smallest class count. Defaults to 5. */
  nSplits?: number;
}
|
|
8
|
+
|
|
9
|
+
/**
 * Mulberry32 pseudo-random generator. Returns a function that yields a
 * deterministic stream of floats in [0, 1) for the given seed (taken modulo 2^32).
 *
 * The state is wrapped to an unsigned 32-bit integer after each step. Letting
 * it grow as a plain JS number only stays exact for about 4.9 million draws:
 * past 2^53 the additions round and the stream silently diverges from
 * Mulberry32. Wrapping produces exactly the same values up to that point.
 */
function mulberry32(seed: number): () => number {
  let state = seed >>> 0;
  return () => {
    state = (state + 0x6d2b79f5) >>> 0;
    let t = Math.imul(state ^ (state >>> 15), 1 | state);
    t ^= t + Math.imul(t ^ (t >>> 7), 61 | t);
    return ((t ^ (t >>> 14)) >>> 0) / 4294967296;
  };
}
|
|
18
|
+
|
|
19
|
+
/**
 * Fisher–Yates shuffle of `indices`, in place, using `random` (expected to
 * return values in [0, 1)) as the source of randomness.
 */
function shuffleInPlace(indices: number[], random: () => number): void {
  for (let remaining = indices.length; remaining > 1; remaining -= 1) {
    const last = remaining - 1;
    const pick = Math.floor(random() * remaining);
    const held = indices[last];
    indices[last] = indices[pick];
    indices[pick] = held;
  }
}
|
|
27
|
+
|
|
28
|
+
/**
 * K-fold cross-validator that preserves class proportions in every fold.
 *
 * Samples are grouped by label and, optionally after a seeded shuffle within
 * each class, dealt round-robin across folds. The fold cursor carries over
 * from one class to the next. As a result, each fold receives the floor or
 * ceil of every class's share, and overall fold sizes differ by at most one.
 */
export class StratifiedKFold {
  private readonly nSplits: number;
  private readonly shuffle: boolean;
  private readonly randomState: number;

  /**
   * @param options Fold count (integer >= 2, default 5), within-class shuffle
   *   flag (default false) and shuffle seed (default 42).
   * @throws Error if `nSplits` is not an integer >= 2.
   */
  constructor(options: StratifiedKFoldOptions = {}) {
    this.nSplits = options.nSplits ?? 5;
    this.shuffle = options.shuffle ?? false;
    this.randomState = options.randomState ?? 42;

    if (!Number.isInteger(this.nSplits) || this.nSplits < 2) {
      throw new Error(`nSplits must be an integer >= 2. Got ${this.nSplits}.`);
    }
  }

  /**
   * Builds `nSplits` stratified train/test partitions of the sample positions.
   *
   * @param X Samples; only the length is used.
   * @param y Class label per sample.
   * @returns One entry per fold; both index lists are sorted ascending.
   * @throws Error if X is empty, y is not an array of the same length, there
   *   are fewer samples than folds, y has fewer than two distinct classes, or
   *   the smallest class has fewer members than `nSplits`.
   */
  split<TX>(X: TX[], y: number[]): FoldIndices[] {
    if (!Array.isArray(X) || X.length === 0) {
      throw new Error("X must be a non-empty array.");
    }
    // Check y's type on its own: interpolating `y.length` into the
    // length-mismatch message would itself throw a TypeError for a
    // null/undefined y, hiding the real problem.
    if (!Array.isArray(y)) {
      throw new Error("y must be an array of class labels.");
    }
    if (y.length !== X.length) {
      throw new Error(`X and y must have the same length. Got ${X.length} and ${y.length}.`);
    }
    if (this.nSplits > X.length) {
      throw new Error(`nSplits (${this.nSplits}) cannot exceed sample count (${X.length}).`);
    }

    // Group sample positions by label, in order of first appearance.
    const byClass = new Map<number, number[]>();
    for (let i = 0; i < y.length; i += 1) {
      const label = y[i];
      const bucket = byClass.get(label);
      if (bucket) {
        bucket.push(i);
      } else {
        byClass.set(label, [i]);
      }
    }

    if (byClass.size < 2) {
      throw new Error("StratifiedKFold requires at least two distinct classes.");
    }

    let minClassCount = Number.POSITIVE_INFINITY;
    for (const indices of byClass.values()) {
      if (indices.length < minClassCount) {
        minClassCount = indices.length;
      }
    }
    if (minClassCount < this.nSplits) {
      throw new Error(
        `nSplits (${this.nSplits}) cannot exceed the smallest class count (${minClassCount}).`,
      );
    }

    const foldTestIndices = Array.from({ length: this.nSplits }, () => new Array<number>());
    // One generator shared by all classes, created only when shuffling is on.
    const random = this.shuffle ? mulberry32(this.randomState) : null;

    // Deal round-robin, continuing from where the previous class stopped.
    // Restarting at fold 0 for each class would give the first folds an extra
    // sample from every class and make fold sizes uneven.
    let cursor = 0;
    for (const classIndices of byClass.values()) {
      const working = classIndices.slice();
      if (random !== null) {
        shuffleInPlace(working, random);
      }
      for (const sampleIndex of working) {
        foldTestIndices[cursor].push(sampleIndex);
        cursor = (cursor + 1) % this.nSplits;
      }
    }

    const folds: FoldIndices[] = [];
    for (let foldIdx = 0; foldIdx < this.nSplits; foldIdx += 1) {
      const testIndices = foldTestIndices[foldIdx].slice().sort((a, b) => a - b);
      const testMask = new Uint8Array(X.length);
      for (const index of testIndices) {
        testMask[index] = 1;
      }
      const trainIndices: number[] = [];
      for (let i = 0; i < X.length; i += 1) {
        if (testMask[i] === 0) {
          trainIndices.push(i);
        }
      }
      folds.push({ trainIndices, testIndices });
    }

    return folds;
  }
}
|