bun-scikit 0.1.3 → 0.1.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +73 -137
- package/package.json +2 -2
- package/scripts/check-benchmark-health.ts +62 -1
- package/scripts/sync-benchmark-readme.ts +56 -0
- package/src/dummy/DummyClassifier.ts +190 -0
- package/src/dummy/DummyRegressor.ts +108 -0
- package/src/feature_selection/VarianceThreshold.ts +88 -0
- package/src/index.ts +23 -0
- package/src/metrics/classification.ts +30 -0
- package/src/metrics/regression.ts +40 -0
- package/src/model_selection/RandomizedSearchCV.ts +269 -0
- package/src/native/node-addon/bun_scikit_addon.cpp +149 -0
- package/src/native/zigKernels.ts +33 -4
- package/src/preprocessing/Binarizer.ts +46 -0
- package/src/preprocessing/LabelEncoder.ts +62 -0
- package/src/preprocessing/MaxAbsScaler.ts +77 -0
- package/src/preprocessing/Normalizer.ts +66 -0
- package/src/tree/DecisionTreeClassifier.ts +146 -3
- package/zig/kernels.zig +63 -40
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
import type { Matrix, Vector } from "../types";
|
|
2
|
+
import { r2Score } from "../metrics/regression";
|
|
3
|
+
import { assertFiniteVector, validateRegressionInputs } from "../utils/validation";
|
|
4
|
+
|
|
5
|
+
/** Rule `DummyRegressor` uses to pick the single value it predicts for every row. */
export type DummyRegressorStrategy =
  | "mean"
  | "median"
  | "quantile"
  | "constant";

/** Constructor options accepted by `DummyRegressor`. */
export interface DummyRegressorOptions {
  /** Prediction rule; the constructor falls back to "mean" when this is omitted. */
  strategy?: DummyRegressorStrategy;
  /** Value predicted by the "constant" strategy; the constructor rejects non-finite values for that strategy. */
  constant?: number;
  /** Quantile in [0, 1] used by the "quantile" strategy; the constructor falls back to 0.5 when omitted. */
  quantile?: number;
}
|
|
12
|
+
|
|
13
|
+
/**
 * Median of `values`, computed on a sorted copy so the caller's array is not
 * reordered. For an even count the two middle elements are averaged.
 */
function computeMedian(values: number[]): number {
  const ordered = Array.from(values);
  ordered.sort((left, right) => left - right);

  const count = ordered.length;
  const upperMiddle = Math.floor(count / 2);
  const hasSingleMiddle = count % 2 !== 0;

  return hasSingleMiddle
    ? ordered[upperMiddle]
    : 0.5 * (ordered[upperMiddle - 1] + ordered[upperMiddle]);
}
|
|
21
|
+
|
|
22
|
+
/**
 * Quantile `q` of `values` using linear interpolation between the two nearest
 * order statistics. Works on a sorted copy; the input array is left untouched.
 *
 * @param values Sample to summarise.
 * @param q Position within the sorted sample, where 0 maps to the first
 *   element and 1 to the last.
 */
function computeQuantile(values: number[], q: number): number {
  const ordered = Array.from(values).sort((left, right) => left - right);
  const rank = q * (ordered.length - 1);
  const below = Math.floor(rank);
  const above = Math.ceil(rank);

  if (below !== above) {
    // Fractional rank: blend the neighbours by their distance to `rank`.
    const fraction = rank - below;
    return ordered[below] * (1 - fraction) + ordered[above] * fraction;
  }

  // Rank landed exactly on an element.
  return ordered[below];
}
|
|
33
|
+
|
|
34
|
+
/**
 * Baseline regressor that predicts one constant for every row.
 *
 * The constant is chosen in `fit` from the training targets (mean, median or a
 * quantile) or taken from the `constant` option. Feature values are only used
 * to remember and later check the number of columns.
 */
export class DummyRegressor {
  /** Value returned for every prediction; `null` until `fit` has run. */
  constant_: number | null = null;

  private readonly strategy: DummyRegressorStrategy;
  private readonly constant?: number;
  private readonly quantile: number;
  /** Column count of the training matrix; `null` until `fit` has run. */
  private nFeaturesIn_: number | null = null;

  /**
   * @param options Strategy selection; defaults to the "mean" strategy and a
   *   quantile of 0.5.
   * @throws Error when the "constant" strategy lacks a finite `constant`, or
   *   when the "quantile" strategy is given a quantile outside [0, 1].
   */
  constructor(options: DummyRegressorOptions = {}) {
    this.strategy = options.strategy ?? "mean";
    this.constant = options.constant;
    this.quantile = options.quantile ?? 0.5;

    // Each option is validated only when the selected strategy reads it.
    if (this.strategy === "constant" && !Number.isFinite(this.constant)) {
      throw new Error("constant strategy requires a finite constant value.");
    }

    const quantileInvalid =
      !Number.isFinite(this.quantile) || this.quantile < 0 || this.quantile > 1;
    if (this.strategy === "quantile" && quantileInvalid) {
      throw new Error(`quantile must be in [0, 1]. Got ${this.quantile}.`);
    }
  }

  /**
   * Validates the inputs via `validateRegressionInputs`, records the feature
   * count and derives the constant to predict.
   *
   * @returns This instance, to allow chaining.
   */
  fit(X: Matrix, y: Vector): this {
    validateRegressionInputs(X, y);
    this.nFeaturesIn_ = X[0].length;

    const strategy = this.strategy;
    if (strategy === "mean") {
      let sum = 0;
      for (const target of y) {
        sum += target;
      }
      this.constant_ = sum / y.length;
    } else if (strategy === "median") {
      this.constant_ = computeMedian(y);
    } else if (strategy === "quantile") {
      this.constant_ = computeQuantile(y, this.quantile);
    } else if (strategy === "constant") {
      // The constructor already rejected a missing or non-finite constant.
      this.constant_ = this.constant!;
    } else {
      const exhaustive: never = strategy;
      throw new Error(`Unsupported strategy: ${exhaustive}`);
    }

    return this;
  }

  /**
   * Returns the fitted constant once per row of `X`.
   *
   * Only the first row's length is compared with the training feature count.
   *
   * @throws Error when called before `fit`, when `X` is empty or not an array,
   *   or when the first row does not have the expected number of features.
   */
  predict(X: Matrix): Vector {
    const fitted = this.constant_;
    const expectedWidth = this.nFeaturesIn_;
    if (fitted === null || expectedWidth === null) {
      throw new Error("DummyRegressor has not been fitted.");
    }
    if (!Array.isArray(X) || X.length === 0) {
      throw new Error("X must be a non-empty 2D array.");
    }
    if (!Array.isArray(X[0]) || X[0].length !== expectedWidth) {
      throw new Error(`Feature size mismatch. Expected ${expectedWidth}, got ${X[0]?.length ?? 0}.`);
    }

    return Array.from({ length: X.length }, () => fitted);
  }

  /**
   * Scores the constant predictions against `y` with `r2Score`, after checking
   * `y` with `assertFiniteVector`.
   */
  score(X: Matrix, y: Vector): number {
    assertFiniteVector(y);
    const predictions = this.predict(X);
    return r2Score(y, predictions);
  }
}
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
import type { Matrix } from "../types";
|
|
2
|
+
import {
|
|
3
|
+
assertConsistentRowSize,
|
|
4
|
+
assertFiniteMatrix,
|
|
5
|
+
assertNonEmptyMatrix,
|
|
6
|
+
} from "../utils/validation";
|
|
7
|
+
|
|
8
|
+
/** Constructor options accepted by `VarianceThreshold`. */
export interface VarianceThresholdOptions {
  /**
   * Features whose variance is not strictly greater than this value are
   * dropped. Must be finite and >= 0; the constructor falls back to 0 when
   * omitted.
   */
  threshold?: number;
}
|
|
11
|
+
|
|
12
|
+
/**
 * Feature selector that keeps only the columns whose variance over the
 * training rows is strictly greater than a threshold.
 *
 * The variance divides by the number of samples (not samples - 1).
 */
export class VarianceThreshold {
  /** Per-feature variance measured during `fit`; `null` before fitting. */
  variances_: number[] | null = null;
  /** Column count of the training matrix; `null` before fitting. */
  nFeaturesIn_: number | null = null;
  /** Indices of the retained columns, in ascending order; `null` before fitting. */
  selectedFeatureIndices_: number[] | null = null;

  private readonly threshold: number;

  /**
   * @param options Optional threshold; defaults to 0.
   * @throws Error when the threshold is not finite or is negative.
   */
  constructor(options: VarianceThresholdOptions = {}) {
    this.threshold = options.threshold ?? 0;
    const thresholdUsable = Number.isFinite(this.threshold) && this.threshold >= 0;
    if (!thresholdUsable) {
      throw new Error(`threshold must be finite and >= 0. Got ${this.threshold}.`);
    }
  }

  /**
   * Measures each column's variance and records which columns exceed the
   * threshold. Fitted state is only written once a column has qualified.
   *
   * @throws Error when no column has variance above the threshold (in addition
   *   to whatever the input assertions raise).
   */
  fit(X: Matrix): this {
    assertNonEmptyMatrix(X);
    assertConsistentRowSize(X);
    assertFiniteMatrix(X);

    const rowCount = X.length;
    const columnCount = X[0].length;
    const columnVariances: number[] = [];
    const retained: number[] = [];

    for (let column = 0; column < columnCount; column += 1) {
      // First pass over the column: mean.
      let columnSum = 0;
      for (const row of X) {
        columnSum += row[column];
      }
      const columnMean = columnSum / rowCount;

      // Second pass: mean squared deviation from that mean.
      let squaredDeviation = 0;
      for (const row of X) {
        const offset = row[column] - columnMean;
        squaredDeviation += offset * offset;
      }
      const variance = squaredDeviation / rowCount;

      columnVariances.push(variance);
      if (variance > this.threshold) {
        retained.push(column);
      }
    }

    if (retained.length === 0) {
      throw new Error("No feature in X meets the variance threshold.");
    }

    this.nFeaturesIn_ = columnCount;
    this.variances_ = columnVariances;
    this.selectedFeatureIndices_ = retained;
    return this;
  }

  /**
   * Projects every row of `X` onto the columns selected during `fit`.
   *
   * @throws Error when called before `fit` or when the first row's length
   *   differs from the training feature count.
   */
  transform(X: Matrix): Matrix {
    const retained = this.selectedFeatureIndices_;
    const expectedWidth = this.nFeaturesIn_;
    if (!retained || expectedWidth === null) {
      throw new Error("VarianceThreshold has not been fitted.");
    }

    assertNonEmptyMatrix(X);
    assertConsistentRowSize(X);
    assertFiniteMatrix(X);

    const actualWidth = X[0].length;
    if (actualWidth !== expectedWidth) {
      throw new Error(`Feature size mismatch. Expected ${expectedWidth}, got ${actualWidth}.`);
    }

    return X.map((row) => retained.map((column) => row[column]));
  }

  /** Fits on `X` and returns the transformed `X` in a single call. */
  fitTransform(X: Matrix): Matrix {
    this.fit(X);
    return this.transform(X);
  }
}
|
package/src/index.ts
CHANGED
|
@@ -1,15 +1,28 @@
|
|
|
1
1
|
export * from "./types";
|
|
2
2
|
|
|
3
|
+
// Baselines
|
|
4
|
+
export * from "./dummy/DummyClassifier";
|
|
5
|
+
export * from "./dummy/DummyRegressor";
|
|
6
|
+
|
|
7
|
+
// Preprocessing
|
|
3
8
|
export * from "./preprocessing/StandardScaler";
|
|
4
9
|
export * from "./preprocessing/MinMaxScaler";
|
|
5
10
|
export * from "./preprocessing/RobustScaler";
|
|
11
|
+
export * from "./preprocessing/MaxAbsScaler";
|
|
12
|
+
export * from "./preprocessing/Normalizer";
|
|
13
|
+
export * from "./preprocessing/Binarizer";
|
|
14
|
+
export * from "./preprocessing/LabelEncoder";
|
|
6
15
|
export * from "./preprocessing/PolynomialFeatures";
|
|
7
16
|
export * from "./preprocessing/SimpleImputer";
|
|
8
17
|
export * from "./preprocessing/OneHotEncoder";
|
|
18
|
+
|
|
19
|
+
// Linear models
|
|
9
20
|
export * from "./linear_model/LinearRegression";
|
|
10
21
|
export * from "./linear_model/LogisticRegression";
|
|
11
22
|
export * from "./linear_model/SGDClassifier";
|
|
12
23
|
export * from "./linear_model/SGDRegressor";
|
|
24
|
+
|
|
25
|
+
// Other estimators
|
|
13
26
|
export * from "./neighbors/KNeighborsClassifier";
|
|
14
27
|
export * from "./naive_bayes/GaussianNB";
|
|
15
28
|
export * from "./svm/LinearSVC";
|
|
@@ -17,6 +30,8 @@ export * from "./tree/DecisionTreeClassifier";
|
|
|
17
30
|
export * from "./tree/DecisionTreeRegressor";
|
|
18
31
|
export * from "./ensemble/RandomForestClassifier";
|
|
19
32
|
export * from "./ensemble/RandomForestRegressor";
|
|
33
|
+
|
|
34
|
+
// Model selection
|
|
20
35
|
export * from "./model_selection/trainTestSplit";
|
|
21
36
|
export * from "./model_selection/KFold";
|
|
22
37
|
export * from "./model_selection/StratifiedKFold";
|
|
@@ -25,8 +40,16 @@ export * from "./model_selection/RepeatedKFold";
|
|
|
25
40
|
export * from "./model_selection/RepeatedStratifiedKFold";
|
|
26
41
|
export * from "./model_selection/crossValScore";
|
|
27
42
|
export * from "./model_selection/GridSearchCV";
|
|
43
|
+
export * from "./model_selection/RandomizedSearchCV";
|
|
44
|
+
|
|
45
|
+
// Feature selection
|
|
46
|
+
export * from "./feature_selection/VarianceThreshold";
|
|
47
|
+
|
|
48
|
+
// Composition
|
|
28
49
|
export * from "./pipeline/Pipeline";
|
|
29
50
|
export * from "./pipeline/ColumnTransformer";
|
|
30
51
|
export * from "./pipeline/FeatureUnion";
|
|
52
|
+
|
|
53
|
+
// Metrics
|
|
31
54
|
export * from "./metrics/regression";
|
|
32
55
|
export * from "./metrics/classification";
|
|
@@ -292,3 +292,33 @@ export function classificationReport(
|
|
|
292
292
|
},
|
|
293
293
|
};
|
|
294
294
|
}
|
|
295
|
+
|
|
296
|
+
/**
 * Balanced accuracy, reported as the macro-averaged recall taken from
 * `classificationReport(yTrue, yPred)`.
 *
 * NOTE(review): which labels contribute to the macro average is decided inside
 * `classificationReport`, which is not visible from here — confirm how labels
 * that appear only in `yPred` are treated.
 */
export function balancedAccuracyScore(yTrue: number[], yPred: number[]): number {
  const { macroAvg } = classificationReport(yTrue, yPred);
  return macroAvg.recall;
}
|
|
300
|
+
|
|
301
|
+
/**
 * Matthews correlation coefficient for a binary problem, computed from the
 * tp/fp/fn/tn counts returned by `confusionCounts`.
 *
 * @param yTrue Ground-truth labels.
 * @param yPred Predicted labels.
 * @param positiveLabel Label forwarded to `confusionCounts` as the positive class (default 1).
 * @returns `(tp*tn - fp*fn) / sqrt((tp+fp)(tp+fn)(tn+fp)(tn+fn))`, or 0 when
 *   the denominator is 0.
 */
export function matthewsCorrcoef(
  yTrue: number[],
  yPred: number[],
  positiveLabel = 1,
): number {
  const counts = confusionCounts(yTrue, yPred, positiveLabel);
  const { tp, fp, fn, tn } = counts;

  const marginalProduct = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn);
  const scale = Math.sqrt(marginalProduct);

  // A zero marginal leaves the coefficient undefined; report 0 instead of NaN.
  return scale === 0 ? 0 : (tp * tn - fp * fn) / scale;
}
|
|
314
|
+
|
|
315
|
+
/**
 * Brier score: the mean squared difference between predicted probabilities
 * and the observed targets.
 *
 * Inputs are first passed to `validateInputs` and `validateBinaryTargets`.
 * NOTE(review): this function itself does not check that `yPredProb` lies in
 * [0, 1]; confirm whether the validators do.
 */
export function brierScoreLoss(yTrue: number[], yPredProb: number[]): number {
  validateInputs(yTrue, yPredProb);
  validateBinaryTargets(yTrue);

  const sampleCount = yTrue.length;
  let squaredErrorSum = 0;
  for (let idx = 0; idx < sampleCount; idx += 1) {
    const gap = yPredProb[idx] - yTrue[idx];
    squaredErrorSum += gap * gap;
  }
  return squaredErrorSum / sampleCount;
}
|
|
@@ -49,3 +49,43 @@ export function r2Score(yTrue: number[], yPred: number[]): number {
|
|
|
49
49
|
|
|
50
50
|
return 1 - ssRes / ssTot;
|
|
51
51
|
}
|
|
52
|
+
|
|
53
|
+
/**
 * Mean absolute percentage error, returned as a fraction (it is not
 * multiplied by 100).
 *
 * Each absolute target is floored at 1e-12 before dividing, so a zero target
 * produces a very large but finite term rather than a division by zero.
 */
export function meanAbsolutePercentageError(yTrue: number[], yPred: number[]): number {
  validateInputs(yTrue, yPred);

  const sampleCount = yTrue.length;
  let relativeErrorSum = 0;
  for (let idx = 0; idx < sampleCount; idx += 1) {
    const actual = yTrue[idx];
    const safeMagnitude = Math.max(Math.abs(actual), 1e-12);
    relativeErrorSum += Math.abs((actual - yPred[idx]) / safeMagnitude);
  }
  return relativeErrorSum / sampleCount;
}
|
|
62
|
+
|
|
63
|
+
/**
 * Explained variance score: `1 - Var(yTrue - yPred) / Var(yTrue)`, with both
 * variances divided by the number of samples.
 *
 * When the targets have zero variance the ratio is undefined; the function
 * then returns 1 if the residuals also have zero variance and 0 otherwise.
 */
export function explainedVarianceScore(yTrue: number[], yPred: number[]): number {
  validateInputs(yTrue, yPred);

  const sampleCount = yTrue.length;
  const targetMean = mean(yTrue);

  // Residuals are kept so they can be centred in the second pass.
  const errors = Array.from({ length: sampleCount }, (_, idx) => yTrue[idx] - yPred[idx]);
  let errorSum = 0;
  for (const error of errors) {
    errorSum += error;
  }
  const errorMean = errorSum / sampleCount;

  let targetSpread = 0;
  let errorSpread = 0;
  for (let idx = 0; idx < sampleCount; idx += 1) {
    const targetOffset = yTrue[idx] - targetMean;
    const errorOffset = errors[idx] - errorMean;
    targetSpread += targetOffset * targetOffset;
    errorSpread += errorOffset * errorOffset;
  }

  const targetVariance = targetSpread / sampleCount;
  const errorVariance = errorSpread / sampleCount;

  if (targetVariance !== 0) {
    return 1 - errorVariance / targetVariance;
  }
  // Constant targets: perfect only if the residuals are constant too.
  return errorVariance === 0 ? 1 : 0;
}
|
|
@@ -0,0 +1,269 @@
|
|
|
1
|
+
import type { Matrix, Vector } from "../types";
|
|
2
|
+
import { accuracyScore, f1Score, precisionScore, recallScore } from "../metrics/classification";
|
|
3
|
+
import { meanSquaredError, r2Score } from "../metrics/regression";
|
|
4
|
+
import {
|
|
5
|
+
crossValScore,
|
|
6
|
+
type BuiltInScoring,
|
|
7
|
+
type CrossValEstimator,
|
|
8
|
+
type CrossValSplitter,
|
|
9
|
+
type ScoringFn,
|
|
10
|
+
} from "./crossValScore";
|
|
11
|
+
|
|
12
|
+
/**
 * Candidate values per hyper-parameter name. `RandomizedSearchCV` draws one
 * entry from each list for every sampled candidate.
 */
export type ParamDistributions = Record<string, readonly unknown[]>;
|
|
13
|
+
|
|
14
|
+
/** Options accepted by the `RandomizedSearchCV` constructor. */
export interface RandomizedSearchCVOptions {
  /** Forwarded unchanged to `crossValScore` as its `cv` option. */
  cv?: number | CrossValSplitter;
  /** Built-in metric name or custom scorer; forwarded to `crossValScore` and reused by `score`. */
  scoring?: BuiltInScoring | ScoringFn;
  /** Whether to refit an estimator with the best parameters on the full data; defaults to true. */
  refit?: boolean;
  /** "raise" rethrows a failing candidate's error; a number is recorded as that candidate's score. Defaults to "raise". */
  errorScore?: "raise" | number;
  /** Number of parameter combinations to sample; must be an integer >= 1. Defaults to 10. */
  nIter?: number;
  /** Seed for the parameter sampler; defaults to 42. */
  randomState?: number;
}
|
|
22
|
+
|
|
23
|
+
/** One evaluated candidate in `RandomizedSearchCV.cvResults_`. */
export interface RandomizedSearchResultRow {
  /** Copy of the sampled parameter values for this candidate. */
  params: Record<string, unknown>;
  /** Scores returned by cross-validation; a single `errorScore` entry when the candidate failed. */
  splitScores: number[];
  /** Mean of `splitScores`, or `errorScore` when the candidate failed. */
  meanTestScore: number;
  /** Standard deviation of `splitScores`; 0 when the candidate failed. */
  stdTestScore: number;
  /** 1-based position after ranking; 1 is the best candidate. */
  rank: number;
  /** "ok" when cross-validation completed, "error" when it threw. */
  status: "ok" | "error";
  /** Message of the caught error; only set when `status` is "error". */
  errorMessage?: string;
}
|
|
32
|
+
|
|
33
|
+
/** Arithmetic mean of `values`; an empty array yields NaN (0 / 0). */
function mean(values: number[]): number {
  let running = 0;
  for (const value of values) {
    running += value;
  }
  return running / values.length;
}
|
|
40
|
+
|
|
41
|
+
/**
 * Standard deviation of `values`, dividing by the number of values rather
 * than by one less. Fewer than two values yields 0.
 */
function std(values: number[]): number {
  const count = values.length;
  if (count < 2) {
    return 0;
  }

  // Average, accumulated front to back.
  let total = 0;
  for (const value of values) {
    total += value;
  }
  const center = total / count;

  let squaredDeviation = 0;
  for (const value of values) {
    const offset = value - center;
    squaredDeviation += offset * offset;
  }
  return Math.sqrt(squaredDeviation / count);
}
|
|
53
|
+
|
|
54
|
+
/**
 * Maps a built-in scoring name to the metric function that implements it.
 *
 * "neg_mean_squared_error" wraps `meanSquaredError` and negates its result.
 *
 * @throws Error when the name is not one of the handled metrics.
 */
function resolveBuiltInScorer(scoring: BuiltInScoring): ScoringFn {
  if (scoring === "accuracy") {
    return accuracyScore;
  }
  if (scoring === "f1") {
    return f1Score;
  }
  if (scoring === "precision") {
    return precisionScore;
  }
  if (scoring === "recall") {
    return recallScore;
  }
  if (scoring === "r2") {
    return r2Score;
  }
  if (scoring === "mean_squared_error") {
    return meanSquaredError;
  }
  if (scoring === "neg_mean_squared_error") {
    return (truth, predicted) => -meanSquaredError(truth, predicted);
  }

  // Compile-time exhaustiveness check; only reached at runtime with an untyped value.
  const exhaustive: never = scoring;
  throw new Error(`Unsupported scoring metric: ${exhaustive}`);
}
|
|
76
|
+
|
|
77
|
+
/**
 * Reports whether lower scores are better for the given scoring option.
 * Only the built-in "mean_squared_error" counts as a loss; custom scorer
 * functions and `undefined` are treated as higher-is-better.
 */
function isLossMetric(scoring: BuiltInScoring | ScoringFn | undefined): boolean {
  const lossMetricName = "mean_squared_error";
  return scoring === lossMetricName;
}
|
|
80
|
+
|
|
81
|
+
/**
 * Small seeded pseudo-random generator holding a single unsigned 32-bit
 * state word. Two instances built from the same seed yield the same sequence.
 */
class Mulberry32 {
  private state: number;

  /** @param seed Coerced to an unsigned 32-bit integer. */
  constructor(seed: number) {
    this.state = seed >>> 0;
  }

  /** Advances the state and returns a value in [0, 1). */
  next(): number {
    const advanced = (this.state + 0x6d2b79f5) >>> 0;
    this.state = advanced;

    let mixed = Math.imul(advanced ^ (advanced >>> 15), advanced | 1);
    mixed ^= mixed + Math.imul(mixed ^ (mixed >>> 7), mixed | 61);

    // Reinterpret as unsigned and scale by 2^32.
    const unsigned = (mixed ^ (mixed >>> 14)) >>> 0;
    return unsigned / 4294967296;
  }

  /** Returns `floor(next() * maxExclusive)`: an integer in [0, maxExclusive) for a positive bound. */
  nextInt(maxExclusive: number): number {
    const scaled = this.next() * maxExclusive;
    return Math.floor(scaled);
  }
}
|
|
100
|
+
|
|
101
|
+
/**
 * Draws `nIter` parameter combinations by picking one value per key from
 * `distributions`, using a `Mulberry32` generator seeded with `randomState`.
 *
 * Every pick is independent, so the same combination can be drawn more than once.
 *
 * @throws Error when `distributions` has no keys or any entry is not a
 *   non-empty array.
 */
function sampleParams(
  distributions: ParamDistributions,
  nIter: number,
  randomState: number,
): Record<string, unknown>[] {
  const names = Object.keys(distributions);
  if (names.length === 0) {
    throw new Error("paramDistributions must include at least one parameter.");
  }

  // Validate every entry up front so no random draws happen on bad input.
  for (const name of names) {
    const choices = distributions[name];
    const usable = Array.isArray(choices) && choices.length > 0;
    if (!usable) {
      throw new Error(`paramDistributions '${name}' must be a non-empty array.`);
    }
  }

  const generator = new Mulberry32(randomState);
  const draws: Record<string, unknown>[] = [];
  for (let drawn = 0; drawn < nIter; drawn += 1) {
    const combination: Record<string, unknown> = {};
    for (const name of names) {
      const choices = distributions[name];
      combination[name] = choices[generator.nextInt(choices.length)];
    }
    draws.push(combination);
  }
  return draws;
}
|
|
130
|
+
|
|
131
|
+
/**
 * Randomized hyper-parameter search with cross-validation.
 *
 * `fit` samples `nIter` parameter combinations from `paramDistributions`,
 * scores each with `crossValScore`, ranks them, and (when `refit` is true)
 * trains a final estimator with the best combination on the full data.
 */
export class RandomizedSearchCV<TEstimator extends CrossValEstimator> {
  /** Estimator refitted with the best parameters; `null` before `fit` or when `refit` is false. */
  bestEstimator_: TEstimator | null = null;
  /** Parameters of the top-ranked candidate; `null` before `fit`. */
  bestParams_: Record<string, unknown> | null = null;
  /** Mean test score of the top-ranked candidate, as reported by the scorer (not sign-flipped). */
  bestScore_: number | null = null;
  /** One row per sampled candidate, in sampling order. */
  cvResults_: RandomizedSearchResultRow[] = [];

  private readonly estimatorFactory: (params: Record<string, unknown>) => TEstimator;
  private readonly paramDistributions: ParamDistributions;
  private readonly cv?: number | CrossValSplitter;
  private readonly scoring?: BuiltInScoring | ScoringFn;
  private readonly refit: boolean;
  private readonly errorScore: "raise" | number;
  private readonly nIter: number;
  private readonly randomState: number;
  private isFitted = false;

  /**
   * @param estimatorFactory Builds a fresh estimator from a parameter combination.
   * @param paramDistributions Candidate values per parameter name.
   * @param options Search options; defaults are refit=true, errorScore="raise",
   *   nIter=10 and randomState=42.
   * @throws Error when `estimatorFactory` is not a function or `nIter` is not
   *   an integer >= 1.
   */
  constructor(
    estimatorFactory: (params: Record<string, unknown>) => TEstimator,
    paramDistributions: ParamDistributions,
    options: RandomizedSearchCVOptions = {},
  ) {
    if (typeof estimatorFactory !== "function") {
      throw new Error("estimatorFactory must be a function.");
    }
    this.estimatorFactory = estimatorFactory;
    this.paramDistributions = paramDistributions;
    this.cv = options.cv;
    this.scoring = options.scoring;
    this.refit = options.refit ?? true;
    this.errorScore = options.errorScore ?? "raise";
    this.nIter = options.nIter ?? 10;
    this.randomState = options.randomState ?? 42;
    if (!Number.isInteger(this.nIter) || this.nIter < 1) {
      throw new Error(`nIter must be an integer >= 1. Got ${this.nIter}.`);
    }
  }

  /**
   * Runs the search.
   *
   * A candidate whose cross-validation throws is rethrown when `errorScore`
   * is "raise"; otherwise it is recorded with `errorScore` as its score and
   * status "error". Candidates are ranked best-first (rank 1); ties keep
   * sampling order. Candidates whose score is NaN are ranked last.
   *
   * @returns This instance, to allow chaining.
   */
  fit(X: Matrix, y: Vector): this {
    const candidates = sampleParams(this.paramDistributions, this.nIter, this.randomState);
    // For a loss metric the sign is flipped so "larger objective is better" holds uniformly.
    const minimize = isLossMetric(this.scoring);
    const rows: RandomizedSearchResultRow[] = [];
    const objectiveScores: number[] = [];

    for (let candidateIndex = 0; candidateIndex < candidates.length; candidateIndex += 1) {
      const params = candidates[candidateIndex];
      try {
        const splitScores = crossValScore(
          () => this.estimatorFactory(params),
          X,
          y,
          { cv: this.cv, scoring: this.scoring },
        );
        const meanTestScore = mean(splitScores);
        rows.push({
          params: { ...params },
          splitScores,
          meanTestScore,
          stdTestScore: std(splitScores),
          rank: 0,
          status: "ok",
        });
        objectiveScores.push(minimize ? -meanTestScore : meanTestScore);
      } catch (error) {
        if (this.errorScore === "raise") {
          throw error;
        }
        rows.push({
          params: { ...params },
          splitScores: [this.errorScore],
          meanTestScore: this.errorScore,
          stdTestScore: 0,
          rank: 0,
          status: "error",
          errorMessage: error instanceof Error ? error.message : String(error),
        });
        objectiveScores.push(minimize ? -this.errorScore : this.errorScore);
      }
    }

    // Rank on a NaN-free copy. A NaN objective (errorScore: NaN, or a NaN CV
    // mean) would otherwise make the comparator return NaN, which leaves the
    // sort order implementation-defined. NaN is treated as the worst score.
    const rankable = objectiveScores.map((value) =>
      Number.isNaN(value) ? Number.NEGATIVE_INFINITY : value,
    );
    const order = Array.from({ length: rows.length }, (_, idx) => idx).sort((a, b) => {
      // Compare rather than subtract: Infinity - Infinity is NaN as well.
      if (rankable[a] !== rankable[b]) {
        return rankable[a] > rankable[b] ? -1 : 1;
      }
      // Equal objectives keep sampling order.
      return a - b;
    });

    for (let rank = 0; rank < order.length; rank += 1) {
      rows[order[rank]].rank = rank + 1;
    }

    const bestIndex = order[0];
    this.bestParams_ = { ...rows[bestIndex].params };
    this.bestScore_ = rows[bestIndex].meanTestScore;
    this.cvResults_ = rows;

    if (this.refit) {
      const estimator = this.estimatorFactory(this.bestParams_);
      estimator.fit(X, y);
      this.bestEstimator_ = estimator;
    } else {
      this.bestEstimator_ = null;
    }

    this.isFitted = true;
    return this;
  }

  /**
   * Predicts with the refitted best estimator.
   *
   * @throws Error when called before `fit` or when `refit` is false.
   */
  predict(X: Matrix): Vector {
    if (!this.isFitted) {
      throw new Error("RandomizedSearchCV has not been fitted.");
    }
    if (!this.refit || !this.bestEstimator_) {
      throw new Error("RandomizedSearchCV predict is unavailable when refit=false.");
    }
    return this.bestEstimator_.predict(X);
  }

  /**
   * Scores the refitted best estimator on `X`/`y`.
   *
   * Uses the configured `scoring` when one was given (a function is called
   * directly, a name is resolved to its built-in metric); otherwise falls
   * back to the estimator's own `score` method when it has one.
   *
   * @throws Error when called before `fit`, when `refit` is false, or when
   *   neither `scoring` nor an estimator `score` method is available.
   */
  score(X: Matrix, y: Vector): number {
    if (!this.isFitted) {
      throw new Error("RandomizedSearchCV has not been fitted.");
    }
    if (!this.refit || !this.bestEstimator_) {
      throw new Error("RandomizedSearchCV score is unavailable when refit=false.");
    }

    if (this.scoring) {
      const scorer =
        typeof this.scoring === "function" ? this.scoring : resolveBuiltInScorer(this.scoring);
      return scorer(y, this.bestEstimator_.predict(X));
    }

    if (typeof this.bestEstimator_.score === "function") {
      return this.bestEstimator_.score(X, y);
    }

    throw new Error("No scoring function available. Provide scoring in RandomizedSearchCV options.");
  }
}
|