bun-scikit 0.1.2 → 0.1.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +73 -137
- package/docs/README.md +1 -0
- package/package.json +2 -3
- package/scripts/check-benchmark-health.ts +162 -1
- package/scripts/sync-benchmark-readme.ts +56 -0
- package/src/dummy/DummyClassifier.ts +190 -0
- package/src/dummy/DummyRegressor.ts +108 -0
- package/src/feature_selection/VarianceThreshold.ts +88 -0
- package/src/index.ts +23 -0
- package/src/metrics/classification.ts +30 -0
- package/src/metrics/regression.ts +40 -0
- package/src/model_selection/RandomizedSearchCV.ts +269 -0
- package/src/native/node-addon/bun_scikit_addon.cpp +149 -0
- package/src/native/zigKernels.ts +33 -4
- package/src/preprocessing/Binarizer.ts +46 -0
- package/src/preprocessing/LabelEncoder.ts +62 -0
- package/src/preprocessing/MaxAbsScaler.ts +77 -0
- package/src/preprocessing/Normalizer.ts +66 -0
- package/src/tree/DecisionTreeClassifier.ts +146 -3
- package/zig/kernels.zig +63 -40
- package/binding.gyp +0 -21
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
import type { Matrix, Vector } from "../types";
|
|
2
|
+
import { accuracyScore } from "../metrics/classification";
|
|
3
|
+
import {
|
|
4
|
+
assertConsistentRowSize,
|
|
5
|
+
assertFiniteMatrix,
|
|
6
|
+
assertFiniteVector,
|
|
7
|
+
assertNonEmptyMatrix,
|
|
8
|
+
assertVectorLength,
|
|
9
|
+
} from "../utils/validation";
|
|
10
|
+
|
|
11
|
+
/**
 * Prediction strategy used by {@link DummyClassifier}.
 *
 * - `"most_frequent"` / `"prior"`: always predict the most frequent training label.
 * - `"stratified"`: sample each label from the training class frequencies.
 * - `"uniform"`: sample each label uniformly from the training classes.
 * - `"constant"`: always predict the user-supplied `constant`.
 */
export type DummyClassifierStrategy = "most_frequent" | "prior" | "stratified" | "uniform" | "constant";

/** Constructor options for {@link DummyClassifier}. */
export interface DummyClassifierOptions {
  /** Prediction strategy; defaults to `"prior"`. */
  strategy?: DummyClassifierStrategy;
  /** Label predicted by the `"constant"` strategy; must be a finite number. */
  constant?: number;
  /** Seed for the `"stratified"` and `"uniform"` samplers; defaults to 42. */
  randomState?: number;
}
|
|
23
|
+
|
|
24
|
+
/**
 * Small seedable 32-bit pseudo-random generator (the "mulberry32" mixer).
 * A given seed always yields the same sequence of floats in [0, 1).
 */
class Mulberry32 {
  private state: number;

  constructor(seed: number) {
    // Reduce the seed to an unsigned 32-bit integer.
    this.state = seed >>> 0;
  }

  /** Advances the internal state and returns the next value in [0, 1). */
  next(): number {
    this.state = (this.state + 0x6d2b79f5) >>> 0;
    const s = this.state;
    let z = Math.imul(s ^ (s >>> 15), s | 1);
    z = (z + Math.imul(z ^ (z >>> 7), z | 61)) ^ z;
    // Final mix, reinterpreted as unsigned and scaled by 2^-32.
    return ((z ^ (z >>> 14)) >>> 0) / 0x100000000;
  }
}
|
|
39
|
+
|
|
40
|
+
/**
 * Baseline classifier that ignores feature values and predicts from the
 * training label distribution alone; useful as a lower bound for real models.
 *
 * Strategies:
 * - `"most_frequent"` / `"prior"`: predict the most frequent training label
 *   (ties go to the smallest label). `predictProba` returns a one-hot row for
 *   `"most_frequent"` and the class priors for `"prior"`.
 * - `"stratified"`: sample each prediction from the class priors;
 *   `predictProba` returns the priors.
 * - `"uniform"`: sample each prediction uniformly from the training classes.
 * - `"constant"`: always predict `options.constant`, which must be one of the
 *   labels seen during `fit`.
 *
 * The random strategies re-seed from `randomState` on every call, so repeated
 * calls with the same number of rows return the same predictions.
 */
export class DummyClassifier {
  /** Sorted distinct training labels; `null` until fitted. */
  classes_: number[] | null = null;
  /** Relative frequency of each entry of `classes_` in the training labels; `null` until fitted. */
  classPrior_: number[] | null = null;
  /** Label used by the deterministic strategies; `null` until fitted. */
  constant_: number | null = null;

  private readonly strategy: DummyClassifierStrategy;
  private readonly configuredConstant?: number;
  private readonly randomState: number;
  private majorityClass: number | null = null;
  private nFeaturesIn_: number | null = null;

  constructor(options: DummyClassifierOptions = {}) {
    this.strategy = options.strategy ?? "prior";
    this.configuredConstant = options.constant;
    this.randomState = options.randomState ?? 42;
  }

  /**
   * Learns the label distribution of `y`. `X` is only validated and used to
   * record the expected feature count.
   *
   * @throws if `X`/`y` fail validation, or if the `"constant"` strategy has a
   *   missing/non-finite constant or one that never occurs in `y`.
   */
  fit(X: Matrix, y: Vector): this {
    assertNonEmptyMatrix(X);
    assertConsistentRowSize(X);
    assertFiniteMatrix(X);
    assertVectorLength(y, X.length);
    assertFiniteVector(y);

    const counts = new Map<number, number>();
    for (let i = 0; i < y.length; i += 1) {
      counts.set(y[i], (counts.get(y[i]) ?? 0) + 1);
    }

    const classes = Array.from(counts.keys()).sort((a, b) => a - b);
    const priors = classes.map((cls) => (counts.get(cls) ?? 0) / y.length);

    // Strict `>` keeps the first (smallest) label when counts tie.
    let majorityClass = classes[0];
    let majorityCount = counts.get(majorityClass) ?? 0;
    for (let i = 1; i < classes.length; i += 1) {
      const clsCount = counts.get(classes[i]) ?? 0;
      if (clsCount > majorityCount) {
        majorityClass = classes[i];
        majorityCount = clsCount;
      }
    }

    let constant = majorityClass;
    if (this.strategy === "constant") {
      const configured = this.configuredConstant;
      if (configured === undefined || !Number.isFinite(configured)) {
        throw new Error("constant strategy requires a finite constant value.");
      }
      // A label never seen in training has no column in predictProba and
      // would otherwise produce all-zero probability rows.
      if (!counts.has(configured)) {
        throw new Error(
          `The constant target value ${configured} is not present in the training target.`,
        );
      }
      constant = configured;
    }

    // Commit fitted state only after every check has passed, so a failed
    // re-fit cannot leave a mix of old and new attributes.
    this.nFeaturesIn_ = X[0].length;
    this.classes_ = classes;
    this.classPrior_ = priors;
    this.majorityClass = majorityClass;
    this.constant_ = constant;
    return this;
  }

  private ensureFitted(): void {
    if (!this.classes_ || !this.classPrior_ || this.nFeaturesIn_ === null || this.majorityClass === null) {
      throw new Error("DummyClassifier has not been fitted.");
    }
  }

  /** Rejects empty input or a first row whose width differs from the fitted feature count. */
  private assertPredictInput(X: Matrix): void {
    if (!Array.isArray(X) || X.length === 0) {
      throw new Error("X must be a non-empty 2D array.");
    }
    if (!Array.isArray(X[0]) || X[0].length !== this.nFeaturesIn_) {
      throw new Error(`Feature size mismatch. Expected ${this.nFeaturesIn_}, got ${X[0]?.length ?? 0}.`);
    }
  }

  /** Draws one label by subtracting class priors from a uniform draw until it is exhausted. */
  private sampleByPrior(rng: Mulberry32): number {
    let r = rng.next();
    for (let i = 0; i < this.classPrior_!.length; i += 1) {
      r -= this.classPrior_![i];
      if (r <= 0) {
        return this.classes_![i];
      }
    }
    // Floating-point rounding can leave r slightly positive; fall back to the last class.
    return this.classes_![this.classes_!.length - 1];
  }

  /**
   * Predicts one label per row of `X`; feature values are ignored.
   * @throws if not fitted or if `X` has the wrong shape.
   */
  predict(X: Matrix): Vector {
    this.ensureFitted();
    this.assertPredictInput(X);

    switch (this.strategy) {
      case "most_frequent":
      case "prior":
        return new Array<number>(X.length).fill(this.majorityClass!);
      case "constant":
        return new Array<number>(X.length).fill(this.constant_!);
      case "uniform": {
        const rng = new Mulberry32(this.randomState);
        const out = new Array<number>(X.length);
        for (let i = 0; i < X.length; i += 1) {
          const idx = Math.floor(rng.next() * this.classes_!.length);
          out[i] = this.classes_![idx];
        }
        return out;
      }
      case "stratified": {
        const rng = new Mulberry32(this.randomState);
        const out = new Array<number>(X.length);
        for (let i = 0; i < X.length; i += 1) {
          out[i] = this.sampleByPrior(rng);
        }
        return out;
      }
      default: {
        const exhaustive: never = this.strategy;
        throw new Error(`Unsupported strategy: ${exhaustive}`);
      }
    }
  }

  /**
   * Returns one probability row per row of `X`, with columns ordered like `classes_`.
   * @throws if not fitted or if `X` has the wrong shape.
   */
  predictProba(X: Matrix): Matrix {
    this.ensureFitted();
    this.assertPredictInput(X);

    if (this.strategy === "uniform") {
      const value = 1 / this.classes_!.length;
      return X.map(() => new Array(this.classes_!.length).fill(value));
    }

    if (this.strategy === "most_frequent" || this.strategy === "constant") {
      const oneHot = new Array<number>(this.classes_!.length).fill(0);
      const label = this.strategy === "constant" ? this.constant_! : this.majorityClass!;
      const classIndex = this.classes_!.indexOf(label);
      if (classIndex >= 0) {
        oneHot[classIndex] = 1;
      }
      return X.map(() => [...oneHot]);
    }

    // prior / stratified share prior probabilities.
    const prior = [...this.classPrior_!];
    return X.map(() => [...prior]);
  }

  /** Accuracy of `predict(X)` against `y`. */
  score(X: Matrix, y: Vector): number {
    return accuracyScore(y, this.predict(X));
  }
}
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
import type { Matrix, Vector } from "../types";
|
|
2
|
+
import { r2Score } from "../metrics/regression";
|
|
3
|
+
import { assertFiniteVector, validateRegressionInputs } from "../utils/validation";
|
|
4
|
+
|
|
5
|
+
/** How {@link DummyRegressor} derives its single constant prediction. */
export type DummyRegressorStrategy =
  | "mean"
  | "median"
  | "quantile"
  | "constant";

/** Constructor options for {@link DummyRegressor}. */
export interface DummyRegressorOptions {
  /** Strategy for the constant prediction; defaults to `"mean"`. */
  strategy?: DummyRegressorStrategy;
  /** Value predicted by the `"constant"` strategy; must be a finite number. */
  constant?: number;
  /** Quantile in [0, 1] used by the `"quantile"` strategy; defaults to 0.5. */
  quantile?: number;
}
|
|
12
|
+
|
|
13
|
+
/**
 * Median of `values`: the middle element for odd lengths, or the average of
 * the two middle elements for even lengths. The input is not mutated; an
 * empty array yields NaN.
 */
function computeMedian(values: number[]): number {
  const ordered = values.slice().sort((left, right) => left - right);
  const half = ordered.length >> 1;
  return ordered.length % 2 === 1
    ? ordered[half]
    : 0.5 * (ordered[half - 1] + ordered[half]);
}
|
|
21
|
+
|
|
22
|
+
/**
 * Quantile `q` (expected in [0, 1]) of `values`, taken at fractional
 * position `q * (n - 1)` of the sorted data with linear interpolation
 * between the two neighbouring elements. The input is not mutated.
 */
function computeQuantile(values: number[], q: number): number {
  const ordered = values.slice().sort((left, right) => left - right);
  const position = q * (ordered.length - 1);
  const below = Math.floor(position);
  const above = Math.ceil(position);
  if (below !== above) {
    const fraction = position - below;
    return ordered[below] * (1 - fraction) + ordered[above] * fraction;
  }
  return ordered[below];
}
|
|
33
|
+
|
|
34
|
+
/**
 * Baseline regressor that ignores feature values and predicts one constant
 * for every row: the training mean, median, a quantile, or a user-supplied
 * value, depending on `strategy`.
 */
export class DummyRegressor {
  /** Value returned for every row by `predict`; `null` until fitted. */
  constant_: number | null = null;

  private readonly strategy: DummyRegressorStrategy;
  private readonly constant?: number;
  private readonly quantile: number;
  private nFeaturesIn_: number | null = null;

  /**
   * Options are validated eagerly.
   * @throws if the `"constant"` strategy lacks a finite `constant`, or the
   *   `"quantile"` strategy has a `quantile` outside [0, 1].
   */
  constructor(options: DummyRegressorOptions = {}) {
    this.strategy = options.strategy ?? "mean";
    this.constant = options.constant;
    this.quantile = options.quantile ?? 0.5;

    if (this.strategy === "constant" && !Number.isFinite(this.constant)) {
      throw new Error("constant strategy requires a finite constant value.");
    }

    const q = this.quantile;
    const quantileInRange = Number.isFinite(q) && q >= 0 && q <= 1;
    if (this.strategy === "quantile" && !quantileInRange) {
      throw new Error(`quantile must be in [0, 1]. Got ${this.quantile}.`);
    }
  }

  /** Validates the data, records the feature count, and stores the constant prediction. */
  fit(X: Matrix, y: Vector): this {
    validateRegressionInputs(X, y);
    this.nFeaturesIn_ = X[0].length;
    this.constant_ = this.resolveConstant(y);
    return this;
  }

  /** Computes the constant prediction for the configured strategy from targets `y`. */
  private resolveConstant(y: Vector): number {
    switch (this.strategy) {
      case "mean": {
        let sum = 0;
        for (const value of y) {
          sum += value;
        }
        return sum / y.length;
      }
      case "median":
        return computeMedian(y);
      case "quantile":
        return computeQuantile(y, this.quantile);
      case "constant":
        return this.constant!;
      default: {
        const exhaustive: never = this.strategy;
        throw new Error(`Unsupported strategy: ${exhaustive}`);
      }
    }
  }

  /**
   * Returns the fitted constant once per row of `X`.
   * @throws if not fitted or if `X` has the wrong shape.
   */
  predict(X: Matrix): Vector {
    const fitted = this.constant_;
    if (fitted === null || this.nFeaturesIn_ === null) {
      throw new Error("DummyRegressor has not been fitted.");
    }
    if (!Array.isArray(X) || X.length === 0) {
      throw new Error("X must be a non-empty 2D array.");
    }
    const firstRow = X[0];
    if (!Array.isArray(firstRow) || firstRow.length !== this.nFeaturesIn_) {
      throw new Error(`Feature size mismatch. Expected ${this.nFeaturesIn_}, got ${firstRow?.length ?? 0}.`);
    }
    return Array.from({ length: X.length }, () => fitted);
  }

  /** R² of `predict(X)` against `y`, after checking `y` is finite. */
  score(X: Matrix, y: Vector): number {
    assertFiniteVector(y);
    return r2Score(y, this.predict(X));
  }
}
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
import type { Matrix } from "../types";
|
|
2
|
+
import {
|
|
3
|
+
assertConsistentRowSize,
|
|
4
|
+
assertFiniteMatrix,
|
|
5
|
+
assertNonEmptyMatrix,
|
|
6
|
+
} from "../utils/validation";
|
|
7
|
+
|
|
8
|
+
/** Constructor options for {@link VarianceThreshold}. */
export interface VarianceThresholdOptions {
  /**
   * Features whose training-set variance is not strictly greater than this
   * value are removed. Must be finite and >= 0; defaults to 0.
   */
  threshold?: number;
}

/**
 * Unsupervised feature selector that keeps only the columns whose population
 * variance on the training data exceeds `threshold`. With the default
 * threshold of 0 it removes constant columns.
 */
export class VarianceThreshold {
  /** Population variance of each training column; `null` until fitted. */
  variances_: number[] | null = null;
  /** Number of columns seen during `fit`; `null` until fitted. */
  nFeaturesIn_: number | null = null;
  /** Ascending indices of the columns kept by `transform`; `null` until fitted. */
  selectedFeatureIndices_: number[] | null = null;

  private readonly threshold: number;

  /** @throws if `threshold` is not finite or is negative. */
  constructor(options: VarianceThresholdOptions = {}) {
    this.threshold = options.threshold ?? 0;
    if (!Number.isFinite(this.threshold) || this.threshold < 0) {
      throw new Error(`threshold must be finite and >= 0. Got ${this.threshold}.`);
    }
  }

  /**
   * Computes per-column variances of `X` and selects the columns to keep.
   *
   * @throws if `X` fails validation, or if no column passes the threshold.
   */
  fit(X: Matrix): this {
    assertNonEmptyMatrix(X);
    assertConsistentRowSize(X);
    assertFiniteMatrix(X);

    const nSamples = X.length;
    const nFeatures = X[0].length;
    const means = new Array<number>(nFeatures).fill(0);
    const variances = new Array<number>(nFeatures).fill(0);
    // Per-column range, used to detect columns that are exactly constant.
    const mins = [...X[0]];
    const maxs = [...X[0]];

    for (let i = 0; i < nSamples; i += 1) {
      for (let j = 0; j < nFeatures; j += 1) {
        const value = X[i][j];
        means[j] += value;
        if (value < mins[j]) {
          mins[j] = value;
        }
        if (value > maxs[j]) {
          maxs[j] = value;
        }
      }
    }
    for (let j = 0; j < nFeatures; j += 1) {
      means[j] /= nSamples;
    }
    for (let i = 0; i < nSamples; i += 1) {
      for (let j = 0; j < nFeatures; j += 1) {
        const diff = X[i][j] - means[j];
        variances[j] += diff * diff;
      }
    }
    for (let j = 0; j < nFeatures; j += 1) {
      // A constant column has zero variance, but its floating-point mean can
      // differ from the value by one ulp (e.g. repeated 0.1 values), leaving
      // a tiny positive residue that would wrongly pass a threshold of 0.
      variances[j] = mins[j] === maxs[j] ? 0 : variances[j] / nSamples;
    }

    const selectedFeatureIndices: number[] = [];
    for (let j = 0; j < nFeatures; j += 1) {
      if (variances[j] > this.threshold) {
        selectedFeatureIndices.push(j);
      }
    }
    if (selectedFeatureIndices.length === 0) {
      throw new Error("No feature in X meets the variance threshold.");
    }

    this.nFeaturesIn_ = nFeatures;
    this.variances_ = variances;
    this.selectedFeatureIndices_ = selectedFeatureIndices;
    return this;
  }

  /**
   * Returns `X` restricted to the selected columns (in ascending index order).
   * @throws if not fitted, if `X` fails validation, or if its width differs from the fitted width.
   */
  transform(X: Matrix): Matrix {
    if (!this.selectedFeatureIndices_ || this.nFeaturesIn_ === null) {
      throw new Error("VarianceThreshold has not been fitted.");
    }
    assertNonEmptyMatrix(X);
    assertConsistentRowSize(X);
    assertFiniteMatrix(X);

    if (X[0].length !== this.nFeaturesIn_) {
      throw new Error(`Feature size mismatch. Expected ${this.nFeaturesIn_}, got ${X[0].length}.`);
    }

    const selected = this.selectedFeatureIndices_;
    return X.map((row) => selected.map((featureIdx) => row[featureIdx]));
  }

  /** Fits on `X`, then transforms the same `X`. */
  fitTransform(X: Matrix): Matrix {
    return this.fit(X).transform(X);
  }
}
|
package/src/index.ts
CHANGED
|
@@ -1,15 +1,28 @@
|
|
|
1
1
|
export * from "./types";
|
|
2
2
|
|
|
3
|
+
// Baselines
|
|
4
|
+
export * from "./dummy/DummyClassifier";
|
|
5
|
+
export * from "./dummy/DummyRegressor";
|
|
6
|
+
|
|
7
|
+
// Preprocessing
|
|
3
8
|
export * from "./preprocessing/StandardScaler";
|
|
4
9
|
export * from "./preprocessing/MinMaxScaler";
|
|
5
10
|
export * from "./preprocessing/RobustScaler";
|
|
11
|
+
export * from "./preprocessing/MaxAbsScaler";
|
|
12
|
+
export * from "./preprocessing/Normalizer";
|
|
13
|
+
export * from "./preprocessing/Binarizer";
|
|
14
|
+
export * from "./preprocessing/LabelEncoder";
|
|
6
15
|
export * from "./preprocessing/PolynomialFeatures";
|
|
7
16
|
export * from "./preprocessing/SimpleImputer";
|
|
8
17
|
export * from "./preprocessing/OneHotEncoder";
|
|
18
|
+
|
|
19
|
+
// Linear models
|
|
9
20
|
export * from "./linear_model/LinearRegression";
|
|
10
21
|
export * from "./linear_model/LogisticRegression";
|
|
11
22
|
export * from "./linear_model/SGDClassifier";
|
|
12
23
|
export * from "./linear_model/SGDRegressor";
|
|
24
|
+
|
|
25
|
+
// Other estimators
|
|
13
26
|
export * from "./neighbors/KNeighborsClassifier";
|
|
14
27
|
export * from "./naive_bayes/GaussianNB";
|
|
15
28
|
export * from "./svm/LinearSVC";
|
|
@@ -17,6 +30,8 @@ export * from "./tree/DecisionTreeClassifier";
|
|
|
17
30
|
export * from "./tree/DecisionTreeRegressor";
|
|
18
31
|
export * from "./ensemble/RandomForestClassifier";
|
|
19
32
|
export * from "./ensemble/RandomForestRegressor";
|
|
33
|
+
|
|
34
|
+
// Model selection
|
|
20
35
|
export * from "./model_selection/trainTestSplit";
|
|
21
36
|
export * from "./model_selection/KFold";
|
|
22
37
|
export * from "./model_selection/StratifiedKFold";
|
|
@@ -25,8 +40,16 @@ export * from "./model_selection/RepeatedKFold";
|
|
|
25
40
|
export * from "./model_selection/RepeatedStratifiedKFold";
|
|
26
41
|
export * from "./model_selection/crossValScore";
|
|
27
42
|
export * from "./model_selection/GridSearchCV";
|
|
43
|
+
export * from "./model_selection/RandomizedSearchCV";
|
|
44
|
+
|
|
45
|
+
// Feature selection
|
|
46
|
+
export * from "./feature_selection/VarianceThreshold";
|
|
47
|
+
|
|
48
|
+
// Composition
|
|
28
49
|
export * from "./pipeline/Pipeline";
|
|
29
50
|
export * from "./pipeline/ColumnTransformer";
|
|
30
51
|
export * from "./pipeline/FeatureUnion";
|
|
52
|
+
|
|
53
|
+
// Metrics
|
|
31
54
|
export * from "./metrics/regression";
|
|
32
55
|
export * from "./metrics/classification";
|
|
@@ -292,3 +292,33 @@ export function classificationReport(
|
|
|
292
292
|
},
|
|
293
293
|
};
|
|
294
294
|
}
|
|
295
|
+
|
|
296
|
+
/**
 * Balanced accuracy: the unweighted mean of per-class recall over the classes
 * that occur in `yTrue`.
 *
 * Labels that appear only in `yPred` have no true samples, so their recall is
 * undefined; they are excluded rather than averaged in as 0 (scikit-learn's
 * `balanced_accuracy_score` behaves the same way). Predicting such a label
 * still lowers the recall of the true class it replaced.
 */
export function balancedAccuracyScore(yTrue: number[], yPred: number[]): number {
  validateInputs(yTrue, yPred);

  const perClass = new Map<number, { support: number; hits: number }>();
  for (let i = 0; i < yTrue.length; i += 1) {
    const entry = perClass.get(yTrue[i]) ?? { support: 0, hits: 0 };
    entry.support += 1;
    if (yPred[i] === yTrue[i]) {
      entry.hits += 1;
    }
    perClass.set(yTrue[i], entry);
  }

  // Sum in ascending label order so the result does not depend on input order.
  const entries = Array.from(perClass.entries()).sort((a, b) => a[0] - b[0]);
  let recallSum = 0;
  for (const [, { support, hits }] of entries) {
    recallSum += hits / support;
  }
  return recallSum / entries.length;
}
|
|
300
|
+
|
|
301
|
+
/**
 * Matthews correlation coefficient from the binary confusion counts that
 * `confusionCounts` produces for `positiveLabel` (default 1).
 *
 * Returns 0 when the denominator vanishes, i.e. when any of the four marginal
 * sums (tp+fp, tp+fn, tn+fp, tn+fn) is zero.
 */
export function matthewsCorrcoef(
  yTrue: number[],
  yPred: number[],
  positiveLabel = 1,
): number {
  const counts = confusionCounts(yTrue, yPred, positiveLabel);
  const predictedPositive = counts.tp + counts.fp;
  const actualPositive = counts.tp + counts.fn;
  const actualNegative = counts.tn + counts.fp;
  const predictedNegative = counts.tn + counts.fn;
  const denominator = Math.sqrt(predictedPositive * actualPositive * actualNegative * predictedNegative);
  return denominator === 0
    ? 0
    : (counts.tp * counts.tn - counts.fp * counts.fn) / denominator;
}
|
|
314
|
+
|
|
315
|
+
/**
 * Brier score: mean squared difference between predicted probabilities and
 * the binary targets in `yTrue`. Lower is better; 0 is a perfect forecast.
 *
 * @throws if the inputs fail validation, if `yTrue` is not binary, or if any
 *   `yPredProb` entry lies outside [0, 1] (or is NaN).
 */
export function brierScoreLoss(yTrue: number[], yPredProb: number[]): number {
  validateInputs(yTrue, yPredProb);
  validateBinaryTargets(yTrue);
  let total = 0;
  for (let i = 0; i < yTrue.length; i += 1) {
    const p = yPredProb[i];
    // Values outside [0, 1] are not probabilities; squaring their error does
    // not give a Brier score, so reject them rather than report a misleading number.
    if (!(p >= 0 && p <= 1)) {
      throw new Error(`yPredProb must contain probabilities in [0, 1]. Got ${p} at index ${i}.`);
    }
    const diff = p - yTrue[i];
    total += diff * diff;
  }
  return total / yTrue.length;
}
|
|
@@ -49,3 +49,43 @@ export function r2Score(yTrue: number[], yPred: number[]): number {
|
|
|
49
49
|
|
|
50
50
|
return 1 - ssRes / ssTot;
|
|
51
51
|
}
|
|
52
|
+
|
|
53
|
+
/**
 * Mean absolute percentage error, returned as a fraction (0.1 means 10%).
 * Each term is divided by `max(|yTrue[i]|, 1e-12)`, so zero targets do not
 * cause a division by zero.
 */
export function meanAbsolutePercentageError(yTrue: number[], yPred: number[]): number {
  validateInputs(yTrue, yPred);
  const epsilon = 1e-12;
  let sum = 0;
  for (const [i, actual] of yTrue.entries()) {
    const scale = Math.max(Math.abs(actual), epsilon);
    sum += Math.abs((actual - yPred[i]) / scale);
  }
  return sum / yTrue.length;
}
|
|
62
|
+
|
|
63
|
+
/**
 * Explained variance score: `1 - Var(yTrue - yPred) / Var(yTrue)`, using
 * population (divide-by-n) variances. When `yTrue` is constant, returns 1 if
 * the residuals are also constant and 0 otherwise.
 */
export function explainedVarianceScore(yTrue: number[], yPred: number[]): number {
  validateInputs(yTrue, yPred);
  const count = yTrue.length;
  const targetMean = mean(yTrue);
  const residuals = Array.from({ length: count }, (_, i) => yTrue[i] - yPred[i]);

  let residualSum = 0;
  for (const r of residuals) {
    residualSum += r;
  }
  const residualMean = residualSum / count;

  let targetSS = 0;
  let residualSS = 0;
  for (let i = 0; i < count; i += 1) {
    const dy = yTrue[i] - targetMean;
    const dr = residuals[i] - residualMean;
    targetSS += dy * dy;
    residualSS += dr * dr;
  }

  const targetVariance = targetSS / count;
  const residualVariance = residualSS / count;
  if (targetVariance !== 0) {
    return 1 - residualVariance / targetVariance;
  }
  // Constant targets: a perfect score only if the residuals are constant too.
  return residualVariance === 0 ? 1 : 0;
}
|