bun-scikit 0.1.3 → 0.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +73 -137
- package/package.json +3 -2
- package/scripts/build-node-addon.ts +17 -1
- package/scripts/check-benchmark-health.ts +112 -6
- package/scripts/sync-benchmark-readme.ts +56 -0
- package/src/dummy/DummyClassifier.ts +190 -0
- package/src/dummy/DummyRegressor.ts +108 -0
- package/src/ensemble/RandomForestClassifier.ts +154 -8
- package/src/ensemble/RandomForestRegressor.ts +12 -8
- package/src/feature_selection/VarianceThreshold.ts +88 -0
- package/src/index.ts +23 -0
- package/src/metrics/classification.ts +30 -0
- package/src/metrics/regression.ts +40 -0
- package/src/model_selection/RandomizedSearchCV.ts +269 -0
- package/src/native/node-addon/bun_scikit_addon.cpp +307 -0
- package/src/native/zigKernels.ts +122 -4
- package/src/preprocessing/Binarizer.ts +46 -0
- package/src/preprocessing/LabelEncoder.ts +62 -0
- package/src/preprocessing/MaxAbsScaler.ts +77 -0
- package/src/preprocessing/Normalizer.ts +66 -0
- package/src/tree/DecisionTreeClassifier.ts +159 -4
- package/zig/kernels.zig +333 -89
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
import type { Matrix, Vector } from "../types";
|
|
2
|
+
import { accuracyScore } from "../metrics/classification";
|
|
3
|
+
import {
|
|
4
|
+
assertConsistentRowSize,
|
|
5
|
+
assertFiniteMatrix,
|
|
6
|
+
assertFiniteVector,
|
|
7
|
+
assertNonEmptyMatrix,
|
|
8
|
+
assertVectorLength,
|
|
9
|
+
} from "../utils/validation";
|
|
10
|
+
|
|
11
|
+
/** Supported prediction strategies for DummyClassifier. */
export type DummyClassifierStrategy =
  | "most_frequent"
  | "prior"
  | "stratified"
  | "uniform"
  | "constant";

/** Configuration for DummyClassifier. */
export interface DummyClassifierOptions {
  /** Prediction strategy; defaults to "prior". */
  strategy?: DummyClassifierStrategy;
  /** Label always predicted by the "constant" strategy. */
  constant?: number;
  /** Seed for the "stratified" / "uniform" strategies; defaults to 42. */
  randomState?: number;
}
|
|
23
|
+
|
|
24
|
+
class Mulberry32 {
|
|
25
|
+
private state: number;
|
|
26
|
+
|
|
27
|
+
constructor(seed: number) {
|
|
28
|
+
this.state = seed >>> 0;
|
|
29
|
+
}
|
|
30
|
+
|
|
31
|
+
next(): number {
|
|
32
|
+
this.state = (this.state + 0x6d2b79f5) >>> 0;
|
|
33
|
+
let t = this.state ^ (this.state >>> 15);
|
|
34
|
+
t = Math.imul(t, this.state | 1);
|
|
35
|
+
t ^= t + Math.imul(t ^ (t >>> 7), t | 61);
|
|
36
|
+
return ((t ^ (t >>> 14)) >>> 0) / 4294967296;
|
|
37
|
+
}
|
|
38
|
+
}
|
|
39
|
+
|
|
40
|
+
export class DummyClassifier {
|
|
41
|
+
classes_: number[] | null = null;
|
|
42
|
+
classPrior_: number[] | null = null;
|
|
43
|
+
constant_: number | null = null;
|
|
44
|
+
|
|
45
|
+
private readonly strategy: DummyClassifierStrategy;
|
|
46
|
+
private readonly configuredConstant?: number;
|
|
47
|
+
private readonly randomState: number;
|
|
48
|
+
private majorityClass: number | null = null;
|
|
49
|
+
private nFeaturesIn_: number | null = null;
|
|
50
|
+
|
|
51
|
+
constructor(options: DummyClassifierOptions = {}) {
|
|
52
|
+
this.strategy = options.strategy ?? "prior";
|
|
53
|
+
this.configuredConstant = options.constant;
|
|
54
|
+
this.randomState = options.randomState ?? 42;
|
|
55
|
+
}
|
|
56
|
+
|
|
57
|
+
fit(X: Matrix, y: Vector): this {
|
|
58
|
+
assertNonEmptyMatrix(X);
|
|
59
|
+
assertConsistentRowSize(X);
|
|
60
|
+
assertFiniteMatrix(X);
|
|
61
|
+
assertVectorLength(y, X.length);
|
|
62
|
+
assertFiniteVector(y);
|
|
63
|
+
this.nFeaturesIn_ = X[0].length;
|
|
64
|
+
|
|
65
|
+
const counts = new Map<number, number>();
|
|
66
|
+
for (let i = 0; i < y.length; i += 1) {
|
|
67
|
+
counts.set(y[i], (counts.get(y[i]) ?? 0) + 1);
|
|
68
|
+
}
|
|
69
|
+
|
|
70
|
+
const classes = Array.from(counts.keys()).sort((a, b) => a - b);
|
|
71
|
+
const priors = new Array<number>(classes.length);
|
|
72
|
+
for (let i = 0; i < classes.length; i += 1) {
|
|
73
|
+
priors[i] = (counts.get(classes[i]) ?? 0) / y.length;
|
|
74
|
+
}
|
|
75
|
+
|
|
76
|
+
let majorityClass = classes[0];
|
|
77
|
+
let majorityCount = counts.get(majorityClass) ?? 0;
|
|
78
|
+
for (let i = 1; i < classes.length; i += 1) {
|
|
79
|
+
const cls = classes[i];
|
|
80
|
+
const clsCount = counts.get(cls) ?? 0;
|
|
81
|
+
if (clsCount > majorityCount) {
|
|
82
|
+
majorityClass = cls;
|
|
83
|
+
majorityCount = clsCount;
|
|
84
|
+
}
|
|
85
|
+
}
|
|
86
|
+
|
|
87
|
+
if (this.strategy === "constant") {
|
|
88
|
+
if (!Number.isFinite(this.configuredConstant)) {
|
|
89
|
+
throw new Error("constant strategy requires a finite constant value.");
|
|
90
|
+
}
|
|
91
|
+
this.constant_ = this.configuredConstant!;
|
|
92
|
+
} else {
|
|
93
|
+
this.constant_ = majorityClass;
|
|
94
|
+
}
|
|
95
|
+
|
|
96
|
+
this.classes_ = classes;
|
|
97
|
+
this.classPrior_ = priors;
|
|
98
|
+
this.majorityClass = majorityClass;
|
|
99
|
+
return this;
|
|
100
|
+
}
|
|
101
|
+
|
|
102
|
+
private ensureFitted(): void {
|
|
103
|
+
if (!this.classes_ || !this.classPrior_ || this.nFeaturesIn_ === null || this.majorityClass === null) {
|
|
104
|
+
throw new Error("DummyClassifier has not been fitted.");
|
|
105
|
+
}
|
|
106
|
+
}
|
|
107
|
+
|
|
108
|
+
private sampleByPrior(rng: Mulberry32): number {
|
|
109
|
+
let r = rng.next();
|
|
110
|
+
for (let i = 0; i < this.classPrior_!.length; i += 1) {
|
|
111
|
+
r -= this.classPrior_![i];
|
|
112
|
+
if (r <= 0) {
|
|
113
|
+
return this.classes_![i];
|
|
114
|
+
}
|
|
115
|
+
}
|
|
116
|
+
return this.classes_![this.classes_!.length - 1];
|
|
117
|
+
}
|
|
118
|
+
|
|
119
|
+
predict(X: Matrix): Vector {
|
|
120
|
+
this.ensureFitted();
|
|
121
|
+
if (!Array.isArray(X) || X.length === 0) {
|
|
122
|
+
throw new Error("X must be a non-empty 2D array.");
|
|
123
|
+
}
|
|
124
|
+
if (!Array.isArray(X[0]) || X[0].length !== this.nFeaturesIn_) {
|
|
125
|
+
throw new Error(`Feature size mismatch. Expected ${this.nFeaturesIn_}, got ${X[0]?.length ?? 0}.`);
|
|
126
|
+
}
|
|
127
|
+
|
|
128
|
+
switch (this.strategy) {
|
|
129
|
+
case "most_frequent":
|
|
130
|
+
case "prior":
|
|
131
|
+
return new Array<number>(X.length).fill(this.majorityClass!);
|
|
132
|
+
case "constant":
|
|
133
|
+
return new Array<number>(X.length).fill(this.constant_!);
|
|
134
|
+
case "uniform": {
|
|
135
|
+
const rng = new Mulberry32(this.randomState);
|
|
136
|
+
const out = new Array<number>(X.length);
|
|
137
|
+
for (let i = 0; i < X.length; i += 1) {
|
|
138
|
+
const idx = Math.floor(rng.next() * this.classes_!.length);
|
|
139
|
+
out[i] = this.classes_![idx];
|
|
140
|
+
}
|
|
141
|
+
return out;
|
|
142
|
+
}
|
|
143
|
+
case "stratified": {
|
|
144
|
+
const rng = new Mulberry32(this.randomState);
|
|
145
|
+
const out = new Array<number>(X.length);
|
|
146
|
+
for (let i = 0; i < X.length; i += 1) {
|
|
147
|
+
out[i] = this.sampleByPrior(rng);
|
|
148
|
+
}
|
|
149
|
+
return out;
|
|
150
|
+
}
|
|
151
|
+
default: {
|
|
152
|
+
const exhaustive: never = this.strategy;
|
|
153
|
+
throw new Error(`Unsupported strategy: ${exhaustive}`);
|
|
154
|
+
}
|
|
155
|
+
}
|
|
156
|
+
}
|
|
157
|
+
|
|
158
|
+
predictProba(X: Matrix): Matrix {
|
|
159
|
+
this.ensureFitted();
|
|
160
|
+
if (!Array.isArray(X) || X.length === 0) {
|
|
161
|
+
throw new Error("X must be a non-empty 2D array.");
|
|
162
|
+
}
|
|
163
|
+
if (!Array.isArray(X[0]) || X[0].length !== this.nFeaturesIn_) {
|
|
164
|
+
throw new Error(`Feature size mismatch. Expected ${this.nFeaturesIn_}, got ${X[0]?.length ?? 0}.`);
|
|
165
|
+
}
|
|
166
|
+
|
|
167
|
+
if (this.strategy === "uniform") {
|
|
168
|
+
const value = 1 / this.classes_!.length;
|
|
169
|
+
return X.map(() => new Array(this.classes_!.length).fill(value));
|
|
170
|
+
}
|
|
171
|
+
|
|
172
|
+
if (this.strategy === "most_frequent" || this.strategy === "constant") {
|
|
173
|
+
const oneHot = new Array<number>(this.classes_!.length).fill(0);
|
|
174
|
+
const label = this.strategy === "constant" ? this.constant_! : this.majorityClass!;
|
|
175
|
+
const classIndex = this.classes_!.indexOf(label);
|
|
176
|
+
if (classIndex >= 0) {
|
|
177
|
+
oneHot[classIndex] = 1;
|
|
178
|
+
}
|
|
179
|
+
return X.map(() => [...oneHot]);
|
|
180
|
+
}
|
|
181
|
+
|
|
182
|
+
// prior / stratified share prior probabilities.
|
|
183
|
+
const prior = [...this.classPrior_!];
|
|
184
|
+
return X.map(() => [...prior]);
|
|
185
|
+
}
|
|
186
|
+
|
|
187
|
+
score(X: Matrix, y: Vector): number {
|
|
188
|
+
return accuracyScore(y, this.predict(X));
|
|
189
|
+
}
|
|
190
|
+
}
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
import type { Matrix, Vector } from "../types";
|
|
2
|
+
import { r2Score } from "../metrics/regression";
|
|
3
|
+
import { assertFiniteVector, validateRegressionInputs } from "../utils/validation";
|
|
4
|
+
|
|
5
|
+
/** Supported prediction strategies for DummyRegressor. */
export type DummyRegressorStrategy = "mean" | "median" | "quantile" | "constant";

/** Configuration for DummyRegressor. */
export interface DummyRegressorOptions {
  /** Prediction strategy; defaults to "mean". */
  strategy?: DummyRegressorStrategy;
  /** Value predicted by the "constant" strategy (required for it). */
  constant?: number;
  /** Quantile in [0, 1] used by the "quantile" strategy; defaults to 0.5. */
  quantile?: number;
}
|
|
12
|
+
|
|
13
|
+
function computeMedian(values: number[]): number {
|
|
14
|
+
const sorted = [...values].sort((a, b) => a - b);
|
|
15
|
+
const mid = Math.floor(sorted.length / 2);
|
|
16
|
+
if (sorted.length % 2 === 0) {
|
|
17
|
+
return 0.5 * (sorted[mid - 1] + sorted[mid]);
|
|
18
|
+
}
|
|
19
|
+
return sorted[mid];
|
|
20
|
+
}
|
|
21
|
+
|
|
22
|
+
function computeQuantile(values: number[], q: number): number {
|
|
23
|
+
const sorted = [...values].sort((a, b) => a - b);
|
|
24
|
+
const pos = q * (sorted.length - 1);
|
|
25
|
+
const lo = Math.floor(pos);
|
|
26
|
+
const hi = Math.ceil(pos);
|
|
27
|
+
if (lo === hi) {
|
|
28
|
+
return sorted[lo];
|
|
29
|
+
}
|
|
30
|
+
const weight = pos - lo;
|
|
31
|
+
return sorted[lo] * (1 - weight) + sorted[hi] * weight;
|
|
32
|
+
}
|
|
33
|
+
|
|
34
|
+
export class DummyRegressor {
|
|
35
|
+
constant_: number | null = null;
|
|
36
|
+
|
|
37
|
+
private readonly strategy: DummyRegressorStrategy;
|
|
38
|
+
private readonly constant?: number;
|
|
39
|
+
private readonly quantile: number;
|
|
40
|
+
private nFeaturesIn_: number | null = null;
|
|
41
|
+
|
|
42
|
+
constructor(options: DummyRegressorOptions = {}) {
|
|
43
|
+
this.strategy = options.strategy ?? "mean";
|
|
44
|
+
this.constant = options.constant;
|
|
45
|
+
this.quantile = options.quantile ?? 0.5;
|
|
46
|
+
|
|
47
|
+
if (this.strategy === "constant") {
|
|
48
|
+
if (!Number.isFinite(this.constant)) {
|
|
49
|
+
throw new Error("constant strategy requires a finite constant value.");
|
|
50
|
+
}
|
|
51
|
+
}
|
|
52
|
+
|
|
53
|
+
if (this.strategy === "quantile") {
|
|
54
|
+
if (!Number.isFinite(this.quantile) || this.quantile < 0 || this.quantile > 1) {
|
|
55
|
+
throw new Error(`quantile must be in [0, 1]. Got ${this.quantile}.`);
|
|
56
|
+
}
|
|
57
|
+
}
|
|
58
|
+
}
|
|
59
|
+
|
|
60
|
+
fit(X: Matrix, y: Vector): this {
|
|
61
|
+
validateRegressionInputs(X, y);
|
|
62
|
+
this.nFeaturesIn_ = X[0].length;
|
|
63
|
+
|
|
64
|
+
switch (this.strategy) {
|
|
65
|
+
case "mean": {
|
|
66
|
+
let total = 0;
|
|
67
|
+
for (let i = 0; i < y.length; i += 1) {
|
|
68
|
+
total += y[i];
|
|
69
|
+
}
|
|
70
|
+
this.constant_ = total / y.length;
|
|
71
|
+
break;
|
|
72
|
+
}
|
|
73
|
+
case "median":
|
|
74
|
+
this.constant_ = computeMedian(y);
|
|
75
|
+
break;
|
|
76
|
+
case "quantile":
|
|
77
|
+
this.constant_ = computeQuantile(y, this.quantile);
|
|
78
|
+
break;
|
|
79
|
+
case "constant":
|
|
80
|
+
this.constant_ = this.constant!;
|
|
81
|
+
break;
|
|
82
|
+
default: {
|
|
83
|
+
const exhaustive: never = this.strategy;
|
|
84
|
+
throw new Error(`Unsupported strategy: ${exhaustive}`);
|
|
85
|
+
}
|
|
86
|
+
}
|
|
87
|
+
|
|
88
|
+
return this;
|
|
89
|
+
}
|
|
90
|
+
|
|
91
|
+
predict(X: Matrix): Vector {
|
|
92
|
+
if (this.constant_ === null || this.nFeaturesIn_ === null) {
|
|
93
|
+
throw new Error("DummyRegressor has not been fitted.");
|
|
94
|
+
}
|
|
95
|
+
if (!Array.isArray(X) || X.length === 0) {
|
|
96
|
+
throw new Error("X must be a non-empty 2D array.");
|
|
97
|
+
}
|
|
98
|
+
if (!Array.isArray(X[0]) || X[0].length !== this.nFeaturesIn_) {
|
|
99
|
+
throw new Error(`Feature size mismatch. Expected ${this.nFeaturesIn_}, got ${X[0]?.length ?? 0}.`);
|
|
100
|
+
}
|
|
101
|
+
return new Array<number>(X.length).fill(this.constant_);
|
|
102
|
+
}
|
|
103
|
+
|
|
104
|
+
score(X: Matrix, y: Vector): number {
|
|
105
|
+
assertFiniteVector(y);
|
|
106
|
+
return r2Score(y, this.predict(X));
|
|
107
|
+
}
|
|
108
|
+
}
|
|
@@ -2,6 +2,7 @@ import type { ClassificationModel, Matrix, Vector } from "../types";
|
|
|
2
2
|
import { accuracyScore } from "../metrics/classification";
|
|
3
3
|
import { DecisionTreeClassifier, type MaxFeaturesOption } from "../tree/DecisionTreeClassifier";
|
|
4
4
|
import { assertFiniteVector, validateClassificationInputs } from "../utils/validation";
|
|
5
|
+
import { getZigKernels } from "../native/zigKernels";
|
|
5
6
|
|
|
6
7
|
export interface RandomForestClassifierOptions {
|
|
7
8
|
nEstimators?: number;
|
|
@@ -23,8 +24,18 @@ function mulberry32(seed: number): () => number {
|
|
|
23
24
|
};
|
|
24
25
|
}
|
|
25
26
|
|
|
27
|
+
function isTruthy(value: string | undefined): boolean {
|
|
28
|
+
if (!value) {
|
|
29
|
+
return false;
|
|
30
|
+
}
|
|
31
|
+
const normalized = value.trim().toLowerCase();
|
|
32
|
+
return !(normalized === "0" || normalized === "false" || normalized === "off");
|
|
33
|
+
}
|
|
34
|
+
|
|
26
35
|
export class RandomForestClassifier implements ClassificationModel {
|
|
27
36
|
classes_: Vector = [0, 1];
|
|
37
|
+
fitBackend_: "zig" | "js" = "js";
|
|
38
|
+
fitBackendLibrary_: string | null = null;
|
|
28
39
|
private readonly nEstimators: number;
|
|
29
40
|
private readonly maxDepth?: number;
|
|
30
41
|
private readonly minSamplesSplit?: number;
|
|
@@ -32,6 +43,7 @@ export class RandomForestClassifier implements ClassificationModel {
|
|
|
32
43
|
private readonly maxFeatures: MaxFeaturesOption;
|
|
33
44
|
private readonly bootstrap: boolean;
|
|
34
45
|
private readonly randomState?: number;
|
|
46
|
+
private nativeModelHandle: bigint | null = null;
|
|
35
47
|
private trees: DecisionTreeClassifier[] = [];
|
|
36
48
|
|
|
37
49
|
constructor(options: RandomForestClassifierOptions = {}) {
|
|
@@ -49,6 +61,7 @@ export class RandomForestClassifier implements ClassificationModel {
|
|
|
49
61
|
}
|
|
50
62
|
|
|
51
63
|
fit(X: Matrix, y: Vector): this {
|
|
64
|
+
this.disposeNativeModel();
|
|
52
65
|
validateClassificationInputs(X, y);
|
|
53
66
|
|
|
54
67
|
const sampleCount = X.length;
|
|
@@ -56,10 +69,17 @@ export class RandomForestClassifier implements ClassificationModel {
|
|
|
56
69
|
const random = this.randomState === undefined ? Math.random : mulberry32(this.randomState);
|
|
57
70
|
const flattenedX = this.flattenTrainingMatrix(X, sampleCount, featureCount);
|
|
58
71
|
const yBinary = this.buildBinaryTargets(y);
|
|
72
|
+
const sampleIndices = new Uint32Array(sampleCount);
|
|
73
|
+
this.trees = [];
|
|
74
|
+
if (this.tryFitNativeForest(flattenedX, yBinary, sampleCount, featureCount)) {
|
|
75
|
+
this.fitBackend_ = "zig";
|
|
76
|
+
return this;
|
|
77
|
+
}
|
|
78
|
+
this.fitBackend_ = "js";
|
|
79
|
+
this.fitBackendLibrary_ = null;
|
|
59
80
|
this.trees = new Array(this.nEstimators);
|
|
60
81
|
|
|
61
82
|
for (let estimatorIndex = 0; estimatorIndex < this.nEstimators; estimatorIndex += 1) {
|
|
62
|
-
const sampleIndices = new Uint32Array(sampleCount);
|
|
63
83
|
if (this.bootstrap) {
|
|
64
84
|
for (let i = 0; i < sampleCount; i += 1) {
|
|
65
85
|
sampleIndices[i] = Math.floor(random() * sampleCount);
|
|
@@ -86,20 +106,47 @@ export class RandomForestClassifier implements ClassificationModel {
|
|
|
86
106
|
}
|
|
87
107
|
|
|
88
108
|
predict(X: Matrix): Vector {
|
|
109
|
+
if (this.nativeModelHandle !== null) {
|
|
110
|
+
const kernels = getZigKernels();
|
|
111
|
+
const predict = kernels?.randomForestClassifierModelPredict;
|
|
112
|
+
if (predict) {
|
|
113
|
+
const sampleCount = X.length;
|
|
114
|
+
const featureCount = X[0]?.length ?? 0;
|
|
115
|
+
const flattened = this.flattenTrainingMatrix(X, sampleCount, featureCount);
|
|
116
|
+
const out = new Uint8Array(sampleCount);
|
|
117
|
+
const status = predict(
|
|
118
|
+
this.nativeModelHandle,
|
|
119
|
+
flattened,
|
|
120
|
+
sampleCount,
|
|
121
|
+
featureCount,
|
|
122
|
+
out,
|
|
123
|
+
);
|
|
124
|
+
if (status === 1) {
|
|
125
|
+
return Array.from(out);
|
|
126
|
+
}
|
|
127
|
+
}
|
|
128
|
+
}
|
|
129
|
+
|
|
89
130
|
if (this.trees.length === 0) {
|
|
90
131
|
throw new Error("RandomForestClassifier has not been fitted.");
|
|
91
132
|
}
|
|
92
133
|
|
|
93
|
-
const treePredictions = this.trees.map((tree) => tree.predict(X));
|
|
94
134
|
const sampleCount = X.length;
|
|
95
|
-
const
|
|
135
|
+
const voteCounts = new Uint16Array(sampleCount);
|
|
96
136
|
|
|
97
|
-
for (let
|
|
98
|
-
|
|
99
|
-
for (let
|
|
100
|
-
|
|
137
|
+
for (let treeIndex = 0; treeIndex < this.trees.length; treeIndex += 1) {
|
|
138
|
+
const treePrediction = this.trees[treeIndex].predict(X);
|
|
139
|
+
for (let sampleIndex = 0; sampleIndex < sampleCount; sampleIndex += 1) {
|
|
140
|
+
if (treePrediction[sampleIndex] === 1) {
|
|
141
|
+
voteCounts[sampleIndex] += 1;
|
|
142
|
+
}
|
|
101
143
|
}
|
|
102
|
-
|
|
144
|
+
}
|
|
145
|
+
|
|
146
|
+
const predictions = new Array<number>(sampleCount);
|
|
147
|
+
const voteThreshold = this.trees.length;
|
|
148
|
+
for (let sampleIndex = 0; sampleIndex < sampleCount; sampleIndex += 1) {
|
|
149
|
+
predictions[sampleIndex] = voteCounts[sampleIndex] * 2 >= voteThreshold ? 1 : 0;
|
|
103
150
|
}
|
|
104
151
|
|
|
105
152
|
return predictions;
|
|
@@ -110,6 +157,105 @@ export class RandomForestClassifier implements ClassificationModel {
|
|
|
110
157
|
return accuracyScore(y, this.predict(X));
|
|
111
158
|
}
|
|
112
159
|
|
|
160
|
+
  /**
   * Releases all native resources held by this forest: the Zig-side model
   * handle (if one was created during fit) and each per-tree handle, then
   * drops the JS trees. Safe to call more than once.
   */
  dispose(): void {
    this.disposeNativeModel();
    for (let i = 0; i < this.trees.length; i += 1) {
      this.trees[i].dispose();
    }
    this.trees = [];
  }
|
|
167
|
+
|
|
168
|
+
  /**
   * Maps the configured maxFeatures option onto the (mode, value) pair the
   * native kernel expects: 0 = all features, 1 = "sqrt", 2 = "log2",
   * 3 = explicit count carried in `value` (clamped to [1, featureCount]).
   * NOTE(review): mode meanings inferred from this mapping — confirm against
   * the Zig kernel's contract.
   */
  private resolveNativeMaxFeatures(featureCount: number): {
    mode: 0 | 1 | 2 | 3;
    value: number;
  } {
    if (this.maxFeatures === null || this.maxFeatures === undefined) {
      return { mode: 0, value: 0 };
    }
    if (this.maxFeatures === "sqrt") {
      return { mode: 1, value: 0 };
    }
    if (this.maxFeatures === "log2") {
      return { mode: 2, value: 0 };
    }
    // Numeric maxFeatures is clamped to a valid feature count; non-finite
    // values fall back to using all features.
    const value = Number.isFinite(this.maxFeatures)
      ? Math.max(1, Math.min(featureCount, Math.floor(this.maxFeatures)))
      : featureCount;
    return { mode: 3, value };
  }
|
|
186
|
+
|
|
187
|
+
  /**
   * Attempts to train the forest with the native Zig kernel. Returns true
   * (and stores the native model handle) on success; returns false on any
   * failure so the caller can fall back to the JS implementation.
   *
   * Gated behind two env vars: BUN_SCIKIT_EXPERIMENTAL_NATIVE_FOREST must
   * be truthy AND BUN_SCIKIT_TREE_BACKEND must be "zig".
   */
  private tryFitNativeForest(
    flattenedX: Float64Array,
    yBinary: Uint8Array,
    sampleCount: number,
    featureCount: number,
  ): boolean {
    if (!isTruthy(process.env.BUN_SCIKIT_EXPERIMENTAL_NATIVE_FOREST)) {
      return false;
    }
    if (process.env.BUN_SCIKIT_TREE_BACKEND?.trim().toLowerCase() !== "zig") {
      return false;
    }
    // All three native entry points must be present to proceed.
    const kernels = getZigKernels();
    const create = kernels?.randomForestClassifierModelCreate;
    const fit = kernels?.randomForestClassifierModelFit;
    const destroy = kernels?.randomForestClassifierModelDestroy;
    if (!create || !fit || !destroy) {
      return false;
    }

    const { mode, value } = this.resolveNativeMaxFeatures(featureCount);
    const useRandomState = this.randomState === undefined ? 0 : 1;
    const randomState = this.randomState ?? 0;
    // NOTE(review): the 12 / 2 / 1 fallbacks presumably mirror the JS tree
    // defaults — confirm against DecisionTreeClassifier.
    const handle = create(
      this.nEstimators,
      this.maxDepth ?? 12,
      this.minSamplesSplit ?? 2,
      this.minSamplesLeaf ?? 1,
      mode,
      value,
      this.bootstrap ? 1 : 0,
      randomState >>> 0,
      useRandomState,
      featureCount,
    );
    // A 0n handle signals creation failure in the native layer.
    if (handle === 0n) {
      return false;
    }

    // Ensure the handle is destroyed on every failure path so it can't leak.
    let shouldDestroy = true;
    try {
      const status = fit(handle, flattenedX, yBinary, sampleCount, featureCount);
      if (status !== 1) {
        return false;
      }
      this.nativeModelHandle = handle;
      this.fitBackendLibrary_ = kernels.libraryPath;
      shouldDestroy = false;
      return true;
    } finally {
      if (shouldDestroy) {
        destroy(handle);
      }
    }
  }
|
|
242
|
+
|
|
243
|
+
  /**
   * Frees the native Zig model handle, if one is held. Destruction errors
   * are swallowed (best-effort cleanup), and the handle is always cleared
   * so a later dispose/fit never double-frees.
   */
  private disposeNativeModel(): void {
    if (this.nativeModelHandle === null) {
      return;
    }
    const kernels = getZigKernels();
    const destroy = kernels?.randomForestClassifierModelDestroy;
    if (destroy) {
      try {
        destroy(this.nativeModelHandle);
      } catch {
        // best effort cleanup
      }
    }
    this.nativeModelHandle = null;
  }
|
|
258
|
+
|
|
113
259
|
private flattenTrainingMatrix(
|
|
114
260
|
X: Matrix,
|
|
115
261
|
sampleCount: number,
|
|
@@ -56,10 +56,10 @@ export class RandomForestRegressor implements RegressionModel {
|
|
|
56
56
|
const random = this.randomState === undefined ? Math.random : mulberry32(this.randomState);
|
|
57
57
|
const flattenedX = this.flattenTrainingMatrix(X, sampleCount, featureCount);
|
|
58
58
|
const yValues = this.toFloat64Vector(y);
|
|
59
|
+
const sampleIndices = new Uint32Array(sampleCount);
|
|
59
60
|
this.trees = new Array(this.nEstimators);
|
|
60
61
|
|
|
61
62
|
for (let estimatorIndex = 0; estimatorIndex < this.nEstimators; estimatorIndex += 1) {
|
|
62
|
-
const sampleIndices = new Uint32Array(sampleCount);
|
|
63
63
|
if (this.bootstrap) {
|
|
64
64
|
for (let i = 0; i < sampleCount; i += 1) {
|
|
65
65
|
sampleIndices[i] = Math.floor(random() * sampleCount);
|
|
@@ -90,16 +90,20 @@ export class RandomForestRegressor implements RegressionModel {
|
|
|
90
90
|
throw new Error("RandomForestRegressor has not been fitted.");
|
|
91
91
|
}
|
|
92
92
|
|
|
93
|
-
const treePredictions = this.trees.map((tree) => tree.predict(X));
|
|
94
93
|
const sampleCount = X.length;
|
|
95
|
-
const
|
|
94
|
+
const sums = new Float64Array(sampleCount);
|
|
96
95
|
|
|
97
|
-
for (let
|
|
98
|
-
|
|
99
|
-
for (let
|
|
100
|
-
|
|
96
|
+
for (let treeIndex = 0; treeIndex < this.trees.length; treeIndex += 1) {
|
|
97
|
+
const treePrediction = this.trees[treeIndex].predict(X);
|
|
98
|
+
for (let sampleIndex = 0; sampleIndex < sampleCount; sampleIndex += 1) {
|
|
99
|
+
sums[sampleIndex] += treePrediction[sampleIndex];
|
|
101
100
|
}
|
|
102
|
-
|
|
101
|
+
}
|
|
102
|
+
|
|
103
|
+
const predictions = new Array<number>(sampleCount);
|
|
104
|
+
const denominator = this.trees.length;
|
|
105
|
+
for (let sampleIndex = 0; sampleIndex < sampleCount; sampleIndex += 1) {
|
|
106
|
+
predictions[sampleIndex] = sums[sampleIndex] / denominator;
|
|
103
107
|
}
|
|
104
108
|
|
|
105
109
|
return predictions;
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
import type { Matrix } from "../types";
|
|
2
|
+
import {
|
|
3
|
+
assertConsistentRowSize,
|
|
4
|
+
assertFiniteMatrix,
|
|
5
|
+
assertNonEmptyMatrix,
|
|
6
|
+
} from "../utils/validation";
|
|
7
|
+
|
|
8
|
+
/** Configuration for VarianceThreshold. */
export interface VarianceThresholdOptions {
  /** Variance a feature must strictly exceed to be kept; defaults to 0. */
  threshold?: number;
}
|
|
11
|
+
|
|
12
|
+
export class VarianceThreshold {
|
|
13
|
+
variances_: number[] | null = null;
|
|
14
|
+
nFeaturesIn_: number | null = null;
|
|
15
|
+
selectedFeatureIndices_: number[] | null = null;
|
|
16
|
+
|
|
17
|
+
private readonly threshold: number;
|
|
18
|
+
|
|
19
|
+
constructor(options: VarianceThresholdOptions = {}) {
|
|
20
|
+
this.threshold = options.threshold ?? 0;
|
|
21
|
+
if (!Number.isFinite(this.threshold) || this.threshold < 0) {
|
|
22
|
+
throw new Error(`threshold must be finite and >= 0. Got ${this.threshold}.`);
|
|
23
|
+
}
|
|
24
|
+
}
|
|
25
|
+
|
|
26
|
+
fit(X: Matrix): this {
|
|
27
|
+
assertNonEmptyMatrix(X);
|
|
28
|
+
assertConsistentRowSize(X);
|
|
29
|
+
assertFiniteMatrix(X);
|
|
30
|
+
|
|
31
|
+
const nSamples = X.length;
|
|
32
|
+
const nFeatures = X[0].length;
|
|
33
|
+
const means = new Array<number>(nFeatures).fill(0);
|
|
34
|
+
const variances = new Array<number>(nFeatures).fill(0);
|
|
35
|
+
|
|
36
|
+
for (let i = 0; i < nSamples; i += 1) {
|
|
37
|
+
for (let j = 0; j < nFeatures; j += 1) {
|
|
38
|
+
means[j] += X[i][j];
|
|
39
|
+
}
|
|
40
|
+
}
|
|
41
|
+
for (let j = 0; j < nFeatures; j += 1) {
|
|
42
|
+
means[j] /= nSamples;
|
|
43
|
+
}
|
|
44
|
+
for (let i = 0; i < nSamples; i += 1) {
|
|
45
|
+
for (let j = 0; j < nFeatures; j += 1) {
|
|
46
|
+
const diff = X[i][j] - means[j];
|
|
47
|
+
variances[j] += diff * diff;
|
|
48
|
+
}
|
|
49
|
+
}
|
|
50
|
+
for (let j = 0; j < nFeatures; j += 1) {
|
|
51
|
+
variances[j] /= nSamples;
|
|
52
|
+
}
|
|
53
|
+
|
|
54
|
+
const selectedFeatureIndices: number[] = [];
|
|
55
|
+
for (let j = 0; j < nFeatures; j += 1) {
|
|
56
|
+
if (variances[j] > this.threshold) {
|
|
57
|
+
selectedFeatureIndices.push(j);
|
|
58
|
+
}
|
|
59
|
+
}
|
|
60
|
+
if (selectedFeatureIndices.length === 0) {
|
|
61
|
+
throw new Error("No feature in X meets the variance threshold.");
|
|
62
|
+
}
|
|
63
|
+
|
|
64
|
+
this.nFeaturesIn_ = nFeatures;
|
|
65
|
+
this.variances_ = variances;
|
|
66
|
+
this.selectedFeatureIndices_ = selectedFeatureIndices;
|
|
67
|
+
return this;
|
|
68
|
+
}
|
|
69
|
+
|
|
70
|
+
transform(X: Matrix): Matrix {
|
|
71
|
+
if (!this.selectedFeatureIndices_ || this.nFeaturesIn_ === null) {
|
|
72
|
+
throw new Error("VarianceThreshold has not been fitted.");
|
|
73
|
+
}
|
|
74
|
+
assertNonEmptyMatrix(X);
|
|
75
|
+
assertConsistentRowSize(X);
|
|
76
|
+
assertFiniteMatrix(X);
|
|
77
|
+
|
|
78
|
+
if (X[0].length !== this.nFeaturesIn_) {
|
|
79
|
+
throw new Error(`Feature size mismatch. Expected ${this.nFeaturesIn_}, got ${X[0].length}.`);
|
|
80
|
+
}
|
|
81
|
+
|
|
82
|
+
return X.map((row) => this.selectedFeatureIndices_!.map((featureIdx) => row[featureIdx]));
|
|
83
|
+
}
|
|
84
|
+
|
|
85
|
+
fitTransform(X: Matrix): Matrix {
|
|
86
|
+
return this.fit(X).transform(X);
|
|
87
|
+
}
|
|
88
|
+
}
|
package/src/index.ts
CHANGED
|
@@ -1,15 +1,28 @@
|
|
|
1
1
|
export * from "./types";
|
|
2
2
|
|
|
3
|
+
// Baselines
|
|
4
|
+
export * from "./dummy/DummyClassifier";
|
|
5
|
+
export * from "./dummy/DummyRegressor";
|
|
6
|
+
|
|
7
|
+
// Preprocessing
|
|
3
8
|
export * from "./preprocessing/StandardScaler";
|
|
4
9
|
export * from "./preprocessing/MinMaxScaler";
|
|
5
10
|
export * from "./preprocessing/RobustScaler";
|
|
11
|
+
export * from "./preprocessing/MaxAbsScaler";
|
|
12
|
+
export * from "./preprocessing/Normalizer";
|
|
13
|
+
export * from "./preprocessing/Binarizer";
|
|
14
|
+
export * from "./preprocessing/LabelEncoder";
|
|
6
15
|
export * from "./preprocessing/PolynomialFeatures";
|
|
7
16
|
export * from "./preprocessing/SimpleImputer";
|
|
8
17
|
export * from "./preprocessing/OneHotEncoder";
|
|
18
|
+
|
|
19
|
+
// Linear models
|
|
9
20
|
export * from "./linear_model/LinearRegression";
|
|
10
21
|
export * from "./linear_model/LogisticRegression";
|
|
11
22
|
export * from "./linear_model/SGDClassifier";
|
|
12
23
|
export * from "./linear_model/SGDRegressor";
|
|
24
|
+
|
|
25
|
+
// Other estimators
|
|
13
26
|
export * from "./neighbors/KNeighborsClassifier";
|
|
14
27
|
export * from "./naive_bayes/GaussianNB";
|
|
15
28
|
export * from "./svm/LinearSVC";
|
|
@@ -17,6 +30,8 @@ export * from "./tree/DecisionTreeClassifier";
|
|
|
17
30
|
export * from "./tree/DecisionTreeRegressor";
|
|
18
31
|
export * from "./ensemble/RandomForestClassifier";
|
|
19
32
|
export * from "./ensemble/RandomForestRegressor";
|
|
33
|
+
|
|
34
|
+
// Model selection
|
|
20
35
|
export * from "./model_selection/trainTestSplit";
|
|
21
36
|
export * from "./model_selection/KFold";
|
|
22
37
|
export * from "./model_selection/StratifiedKFold";
|
|
@@ -25,8 +40,16 @@ export * from "./model_selection/RepeatedKFold";
|
|
|
25
40
|
export * from "./model_selection/RepeatedStratifiedKFold";
|
|
26
41
|
export * from "./model_selection/crossValScore";
|
|
27
42
|
export * from "./model_selection/GridSearchCV";
|
|
43
|
+
export * from "./model_selection/RandomizedSearchCV";
|
|
44
|
+
|
|
45
|
+
// Feature selection
|
|
46
|
+
export * from "./feature_selection/VarianceThreshold";
|
|
47
|
+
|
|
48
|
+
// Composition
|
|
28
49
|
export * from "./pipeline/Pipeline";
|
|
29
50
|
export * from "./pipeline/ColumnTransformer";
|
|
30
51
|
export * from "./pipeline/FeatureUnion";
|
|
52
|
+
|
|
53
|
+
// Metrics
|
|
31
54
|
export * from "./metrics/regression";
|
|
32
55
|
export * from "./metrics/classification";
|