bun-scikit 0.1.3 → 0.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +73 -137
- package/package.json +3 -2
- package/scripts/build-node-addon.ts +17 -1
- package/scripts/check-benchmark-health.ts +112 -6
- package/scripts/sync-benchmark-readme.ts +56 -0
- package/src/dummy/DummyClassifier.ts +190 -0
- package/src/dummy/DummyRegressor.ts +108 -0
- package/src/ensemble/RandomForestClassifier.ts +154 -8
- package/src/ensemble/RandomForestRegressor.ts +12 -8
- package/src/feature_selection/VarianceThreshold.ts +88 -0
- package/src/index.ts +23 -0
- package/src/metrics/classification.ts +30 -0
- package/src/metrics/regression.ts +40 -0
- package/src/model_selection/RandomizedSearchCV.ts +269 -0
- package/src/native/node-addon/bun_scikit_addon.cpp +307 -0
- package/src/native/zigKernels.ts +122 -4
- package/src/preprocessing/Binarizer.ts +46 -0
- package/src/preprocessing/LabelEncoder.ts +62 -0
- package/src/preprocessing/MaxAbsScaler.ts +77 -0
- package/src/preprocessing/Normalizer.ts +66 -0
- package/src/tree/DecisionTreeClassifier.ts +159 -4
- package/zig/kernels.zig +333 -89
package/src/native/zigKernels.ts
CHANGED
|
@@ -88,6 +88,33 @@ type DecisionTreeModelPredictFn = (
|
|
|
88
88
|
nFeatures: number,
|
|
89
89
|
outLabels: Uint8Array,
|
|
90
90
|
) => number;
|
|
91
|
+
/**
 * Native constructor for a random-forest classifier model.
 *
 * Boolean-like flags (`bootstrap`, `useRandomState`) are passed as numbers.
 * NOTE(review): `maxFeaturesMode` / `maxFeaturesValue` presumably use the same
 * (mode, value) encoding as the decision-tree binding (0 unset, 1 sqrt,
 * 2 log2, 3 explicit count) — confirm against the native side.
 */
type RandomForestClassifierModelCreateFn = (
  nEstimators: number,
  maxDepth: number,
  minSamplesSplit: number,
  minSamplesLeaf: number,
  maxFeaturesMode: number,
  maxFeaturesValue: number,
  bootstrap: number,
  randomState: number,
  useRandomState: number,
  nFeatures: number,
) => NativeHandle;

/** Releases a handle previously returned by the random-forest create function. */
type RandomForestClassifierModelDestroyFn = (handle: NativeHandle) => void;

/**
 * Trains the native model on the flattened feature buffer `x`
 * (`nSamples * nFeatures` values) and labels `y`, returning a status code.
 * NOTE(review): status 1 presumably means success, as in the decision-tree binding.
 */
type RandomForestClassifierModelFitFn = (
  handle: NativeHandle,
  x: Float64Array,
  y: Uint8Array,
  nSamples: number,
  nFeatures: number,
) => number;

/**
 * Predicts labels for `nSamples` rows of the flattened buffer `x`, writing
 * them into `outLabels`, and returns a status code.
 */
type RandomForestClassifierModelPredictFn = (
  handle: NativeHandle,
  x: Float64Array,
  nSamples: number,
  nFeatures: number,
  outLabels: Uint8Array,
) => number;
|
|
91
118
|
|
|
92
119
|
type LogisticTrainEpochFn = (
|
|
93
120
|
x: Float64Array,
|
|
@@ -138,6 +165,10 @@ interface ZigKernelLibrary {
|
|
|
138
165
|
decision_tree_model_destroy?: DecisionTreeModelDestroyFn;
|
|
139
166
|
decision_tree_model_fit?: DecisionTreeModelFitFn;
|
|
140
167
|
decision_tree_model_predict?: DecisionTreeModelPredictFn;
|
|
168
|
+
random_forest_classifier_model_create?: RandomForestClassifierModelCreateFn;
|
|
169
|
+
random_forest_classifier_model_destroy?: RandomForestClassifierModelDestroyFn;
|
|
170
|
+
random_forest_classifier_model_fit?: RandomForestClassifierModelFitFn;
|
|
171
|
+
random_forest_classifier_model_predict?: RandomForestClassifierModelPredictFn;
|
|
141
172
|
logistic_train_epoch?: LogisticTrainEpochFn;
|
|
142
173
|
logistic_train_epochs?: LogisticTrainEpochsFn;
|
|
143
174
|
};
|
|
@@ -162,6 +193,10 @@ export interface ZigKernels {
|
|
|
162
193
|
decisionTreeModelDestroy: DecisionTreeModelDestroyFn | null;
|
|
163
194
|
decisionTreeModelFit: DecisionTreeModelFitFn | null;
|
|
164
195
|
decisionTreeModelPredict: DecisionTreeModelPredictFn | null;
|
|
196
|
+
randomForestClassifierModelCreate: RandomForestClassifierModelCreateFn | null;
|
|
197
|
+
randomForestClassifierModelDestroy: RandomForestClassifierModelDestroyFn | null;
|
|
198
|
+
randomForestClassifierModelFit: RandomForestClassifierModelFitFn | null;
|
|
199
|
+
randomForestClassifierModelPredict: RandomForestClassifierModelPredictFn | null;
|
|
165
200
|
logisticTrainEpoch: LogisticTrainEpochFn | null;
|
|
166
201
|
logisticTrainEpochs: LogisticTrainEpochsFn | null;
|
|
167
202
|
abiVersion: number | null;
|
|
@@ -243,6 +278,14 @@ interface NodeApiAddon {
|
|
|
243
278
|
logisticModelFitLbfgs: LogisticModelFitLbfgsFn;
|
|
244
279
|
logisticModelCopyCoefficients: LogisticModelCopyCoefficientsFn;
|
|
245
280
|
logisticModelGetIntercept: LogisticModelGetInterceptFn;
|
|
281
|
+
decisionTreeModelCreate?: DecisionTreeModelCreateFn;
|
|
282
|
+
decisionTreeModelDestroy?: DecisionTreeModelDestroyFn;
|
|
283
|
+
decisionTreeModelFit?: DecisionTreeModelFitFn;
|
|
284
|
+
decisionTreeModelPredict?: DecisionTreeModelPredictFn;
|
|
285
|
+
randomForestClassifierModelCreate?: RandomForestClassifierModelCreateFn;
|
|
286
|
+
randomForestClassifierModelDestroy?: RandomForestClassifierModelDestroyFn;
|
|
287
|
+
randomForestClassifierModelFit?: RandomForestClassifierModelFitFn;
|
|
288
|
+
randomForestClassifierModelPredict?: RandomForestClassifierModelPredictFn;
|
|
246
289
|
}
|
|
247
290
|
|
|
248
291
|
function tryLoadNodeApiKernels(): ZigKernels | null {
|
|
@@ -281,10 +324,17 @@ function tryLoadNodeApiKernels(): ZigKernels | null {
|
|
|
281
324
|
logisticModelPredict: null,
|
|
282
325
|
logisticModelCopyCoefficients: addon.logisticModelCopyCoefficients ?? null,
|
|
283
326
|
logisticModelGetIntercept: addon.logisticModelGetIntercept ?? null,
|
|
284
|
-
decisionTreeModelCreate: null,
|
|
285
|
-
decisionTreeModelDestroy: null,
|
|
286
|
-
decisionTreeModelFit: null,
|
|
287
|
-
decisionTreeModelPredict: null,
|
|
327
|
+
decisionTreeModelCreate: addon.decisionTreeModelCreate ?? null,
|
|
328
|
+
decisionTreeModelDestroy: addon.decisionTreeModelDestroy ?? null,
|
|
329
|
+
decisionTreeModelFit: addon.decisionTreeModelFit ?? null,
|
|
330
|
+
decisionTreeModelPredict: addon.decisionTreeModelPredict ?? null,
|
|
331
|
+
randomForestClassifierModelCreate:
|
|
332
|
+
addon.randomForestClassifierModelCreate ?? null,
|
|
333
|
+
randomForestClassifierModelDestroy:
|
|
334
|
+
addon.randomForestClassifierModelDestroy ?? null,
|
|
335
|
+
randomForestClassifierModelFit: addon.randomForestClassifierModelFit ?? null,
|
|
336
|
+
randomForestClassifierModelPredict:
|
|
337
|
+
addon.randomForestClassifierModelPredict ?? null,
|
|
288
338
|
logisticTrainEpoch: null,
|
|
289
339
|
logisticTrainEpochs: null,
|
|
290
340
|
abiVersion,
|
|
@@ -403,6 +453,58 @@ export function getZigKernels(): ZigKernels | null {
|
|
|
403
453
|
args: ["usize"],
|
|
404
454
|
returns: FFIType.f64,
|
|
405
455
|
},
|
|
456
|
+
decision_tree_model_create: {
|
|
457
|
+
args: [
|
|
458
|
+
"usize",
|
|
459
|
+
"usize",
|
|
460
|
+
"usize",
|
|
461
|
+
FFIType.u8,
|
|
462
|
+
"usize",
|
|
463
|
+
FFIType.u32,
|
|
464
|
+
FFIType.u8,
|
|
465
|
+
"usize",
|
|
466
|
+
],
|
|
467
|
+
returns: "usize",
|
|
468
|
+
},
|
|
469
|
+
decision_tree_model_destroy: {
|
|
470
|
+
args: ["usize"],
|
|
471
|
+
returns: FFIType.void,
|
|
472
|
+
},
|
|
473
|
+
decision_tree_model_fit: {
|
|
474
|
+
args: ["usize", FFIType.ptr, FFIType.ptr, "usize", "usize", FFIType.ptr, "usize"],
|
|
475
|
+
returns: FFIType.u8,
|
|
476
|
+
},
|
|
477
|
+
decision_tree_model_predict: {
|
|
478
|
+
args: ["usize", FFIType.ptr, "usize", "usize", FFIType.ptr],
|
|
479
|
+
returns: FFIType.u8,
|
|
480
|
+
},
|
|
481
|
+
random_forest_classifier_model_create: {
|
|
482
|
+
args: [
|
|
483
|
+
"usize",
|
|
484
|
+
"usize",
|
|
485
|
+
"usize",
|
|
486
|
+
"usize",
|
|
487
|
+
FFIType.u8,
|
|
488
|
+
"usize",
|
|
489
|
+
FFIType.u8,
|
|
490
|
+
FFIType.u32,
|
|
491
|
+
FFIType.u8,
|
|
492
|
+
"usize",
|
|
493
|
+
],
|
|
494
|
+
returns: "usize",
|
|
495
|
+
},
|
|
496
|
+
random_forest_classifier_model_destroy: {
|
|
497
|
+
args: ["usize"],
|
|
498
|
+
returns: FFIType.void,
|
|
499
|
+
},
|
|
500
|
+
random_forest_classifier_model_fit: {
|
|
501
|
+
args: ["usize", FFIType.ptr, FFIType.ptr, "usize", "usize"],
|
|
502
|
+
returns: FFIType.u8,
|
|
503
|
+
},
|
|
504
|
+
random_forest_classifier_model_predict: {
|
|
505
|
+
args: ["usize", FFIType.ptr, "usize", "usize", FFIType.ptr],
|
|
506
|
+
returns: FFIType.u8,
|
|
507
|
+
},
|
|
406
508
|
logistic_train_epoch: {
|
|
407
509
|
args: [
|
|
408
510
|
FFIType.ptr,
|
|
@@ -463,6 +565,14 @@ export function getZigKernels(): ZigKernels | null {
|
|
|
463
565
|
decisionTreeModelDestroy: library.symbols.decision_tree_model_destroy ?? null,
|
|
464
566
|
decisionTreeModelFit: library.symbols.decision_tree_model_fit ?? null,
|
|
465
567
|
decisionTreeModelPredict: library.symbols.decision_tree_model_predict ?? null,
|
|
568
|
+
randomForestClassifierModelCreate:
|
|
569
|
+
library.symbols.random_forest_classifier_model_create ?? null,
|
|
570
|
+
randomForestClassifierModelDestroy:
|
|
571
|
+
library.symbols.random_forest_classifier_model_destroy ?? null,
|
|
572
|
+
randomForestClassifierModelFit:
|
|
573
|
+
library.symbols.random_forest_classifier_model_fit ?? null,
|
|
574
|
+
randomForestClassifierModelPredict:
|
|
575
|
+
library.symbols.random_forest_classifier_model_predict ?? null,
|
|
466
576
|
logisticTrainEpoch: library.symbols.logistic_train_epoch ?? null,
|
|
467
577
|
logisticTrainEpochs: library.symbols.logistic_train_epochs ?? null,
|
|
468
578
|
abiVersion,
|
|
@@ -526,6 +636,10 @@ export function getZigKernels(): ZigKernels | null {
|
|
|
526
636
|
decisionTreeModelDestroy: null,
|
|
527
637
|
decisionTreeModelFit: null,
|
|
528
638
|
decisionTreeModelPredict: null,
|
|
639
|
+
randomForestClassifierModelCreate: null,
|
|
640
|
+
randomForestClassifierModelDestroy: null,
|
|
641
|
+
randomForestClassifierModelFit: null,
|
|
642
|
+
randomForestClassifierModelPredict: null,
|
|
529
643
|
logisticTrainEpoch: library.symbols.logistic_train_epoch ?? null,
|
|
530
644
|
logisticTrainEpochs: library.symbols.logistic_train_epochs ?? null,
|
|
531
645
|
abiVersion: null,
|
|
@@ -571,6 +685,10 @@ export function getZigKernels(): ZigKernels | null {
|
|
|
571
685
|
decisionTreeModelDestroy: null,
|
|
572
686
|
decisionTreeModelFit: null,
|
|
573
687
|
decisionTreeModelPredict: null,
|
|
688
|
+
randomForestClassifierModelCreate: null,
|
|
689
|
+
randomForestClassifierModelDestroy: null,
|
|
690
|
+
randomForestClassifierModelFit: null,
|
|
691
|
+
randomForestClassifierModelPredict: null,
|
|
574
692
|
logisticTrainEpoch: library.symbols.logistic_train_epoch ?? null,
|
|
575
693
|
logisticTrainEpochs: null,
|
|
576
694
|
abiVersion: null,
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
import type { Matrix } from "../types";
|
|
2
|
+
import {
|
|
3
|
+
assertConsistentRowSize,
|
|
4
|
+
assertFiniteMatrix,
|
|
5
|
+
assertNonEmptyMatrix,
|
|
6
|
+
} from "../utils/validation";
|
|
7
|
+
|
|
8
|
+
export interface BinarizerOptions {
  /** Values strictly greater than this map to 1, all others to 0. Defaults to 0. */
  threshold?: number;
}

/**
 * Thresholds every feature value: `1` when the value exceeds `threshold`,
 * otherwise `0`.
 *
 * Apart from remembering the fitted column count, the transform is stateless.
 * `transform` may be called without `fit`, in which case any width is accepted.
 */
export class Binarizer {
  /** Column count seen by `fit`, or `null` before fitting. */
  nFeaturesIn_: number | null = null;
  private readonly threshold: number;

  /**
   * @param options Optional `threshold` (default 0).
   * @throws Error when `threshold` is not a finite number.
   */
  constructor(options: BinarizerOptions = {}) {
    const threshold = options.threshold ?? 0;
    if (!Number.isFinite(threshold)) {
      throw new Error(`threshold must be finite. Got ${threshold}.`);
    }
    this.threshold = threshold;
  }

  /** Validates `X` and records its column count. */
  fit(X: Matrix): this {
    this.checkInput(X);
    this.nFeaturesIn_ = X[0].length;
    return this;
  }

  /**
   * Returns a new 0/1 matrix with the same shape as `X`.
   * @throws Error when fitted and the column count differs from the fitted one.
   */
  transform(X: Matrix): Matrix {
    this.checkInput(X);

    const width = X[0].length;
    const expected = this.nFeaturesIn_;
    if (expected !== null && width !== expected) {
      throw new Error(`Feature size mismatch. Expected ${expected}, got ${width}.`);
    }

    const { threshold } = this;
    return X.map((row) => row.map((value) => Number(value > threshold)));
  }

  /** Equivalent to `fit(X)` followed by `transform(X)`. */
  fitTransform(X: Matrix): Matrix {
    return this.fit(X).transform(X);
  }

  /** Shared shape/finiteness validation used by `fit` and `transform`. */
  private checkInput(X: Matrix): void {
    assertNonEmptyMatrix(X);
    assertConsistentRowSize(X);
    assertFiniteMatrix(X);
  }
}
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
import type { Vector } from "../types";
|
|
2
|
+
import { assertFiniteVector } from "../utils/validation";
|
|
3
|
+
|
|
4
|
+
/**
 * Encodes numeric labels as contiguous integer codes `0..k-1`, where the code
 * is the label's position among the distinct labels sorted ascending.
 */
export class LabelEncoder {
  /** Distinct labels seen by `fit`, sorted ascending; `null` before fitting. */
  classes_: number[] | null = null;
  /** Reverse lookup: label value -> position in `classes_`. */
  private classToIndex: Map<number, number> | null = null;

  /**
   * Learns the sorted set of distinct labels in `y`.
   * @throws Error when `y` is not a non-empty array or fails finiteness validation.
   */
  fit(y: Vector): this {
    if (!Array.isArray(y) || y.length === 0) {
      throw new Error("y must be a non-empty array.");
    }
    assertFiniteVector(y);

    const sortedClasses = [...new Set(y)].sort((lhs, rhs) => lhs - rhs);
    const lookup = new Map<number, number>();
    sortedClasses.forEach((label, position) => {
      lookup.set(label, position);
    });

    this.classes_ = sortedClasses;
    this.classToIndex = lookup;
    return this;
  }

  /**
   * Maps each label to its integer code.
   * @throws Error when not fitted or when a label was not seen during `fit`.
   */
  transform(y: Vector): Vector {
    const lookup = this.classToIndex;
    if (!lookup) {
      throw new Error("LabelEncoder has not been fitted.");
    }
    assertFiniteVector(y);

    return Array.from(y, (label, i) => {
      const code = lookup.get(label);
      if (code === undefined) {
        throw new Error(`Unknown label ${label} at index ${i}.`);
      }
      return code;
    });
  }

  /** Equivalent to `fit(y)` followed by `transform(y)`. */
  fitTransform(y: Vector): Vector {
    return this.fit(y).transform(y);
  }

  /**
   * Maps integer codes back to the original labels.
   * @throws Error when not fitted or when a code is not an integer in `[0, k)`.
   */
  inverseTransform(y: Vector): Vector {
    const classes = this.classes_;
    if (!classes) {
      throw new Error("LabelEncoder has not been fitted.");
    }
    assertFiniteVector(y);

    return Array.from(y, (code, i) => {
      const inRange = Number.isInteger(code) && code >= 0 && code < classes.length;
      if (!inRange) {
        throw new Error(`Encoded label out of range at index ${i}: ${code}.`);
      }
      return classes[code];
    });
  }
}
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
import type { Matrix, Vector } from "../types";
|
|
2
|
+
import {
|
|
3
|
+
assertConsistentRowSize,
|
|
4
|
+
assertFiniteMatrix,
|
|
5
|
+
assertNonEmptyMatrix,
|
|
6
|
+
} from "../utils/validation";
|
|
7
|
+
|
|
8
|
+
/**
 * Scales each feature column by the largest absolute value seen for it during
 * `fit`. Values are divided without centering, so zeros stay zero and fitted
 * data lands in [-1, 1].
 */
export class MaxAbsScaler {
  /**
   * Per-column divisor learned by `fit`: the column's maximum |x|, or 1 for
   * all-zero columns. `null` before fitting.
   */
  maxAbs_: Vector | null = null;

  /** Learns the per-column maximum absolute value of `X`. */
  fit(X: Matrix): this {
    assertNonEmptyMatrix(X);
    assertConsistentRowSize(X);
    assertFiniteMatrix(X);

    const width = X[0].length;
    const peaks = new Array<number>(width).fill(0);
    for (const row of X) {
      for (let col = 0; col < width; col += 1) {
        peaks[col] = Math.max(peaks[col], Math.abs(row[col]));
      }
    }

    // An all-zero column would otherwise cause a division by zero; leave it unscaled.
    this.maxAbs_ = peaks.map((peak) => (peak === 0 ? 1 : peak));
    return this;
  }

  /**
   * Divides every column by its fitted divisor.
   * @throws Error when not fitted or when the column count differs from the fit.
   */
  transform(X: Matrix): Matrix {
    const scales = this.requireScalesFor(X);
    return X.map((row) => row.map((value, col) => value / scales[col]));
  }

  /** Equivalent to `fit(X)` followed by `transform(X)`. */
  fitTransform(X: Matrix): Matrix {
    return this.fit(X).transform(X);
  }

  /**
   * Undoes `transform` by multiplying every column by its fitted divisor.
   * @throws Error when not fitted or when the column count differs from the fit.
   */
  inverseTransform(X: Matrix): Matrix {
    const scales = this.requireScalesFor(X);
    return X.map((row) => row.map((value, col) => value * scales[col]));
  }

  /** Ensures the scaler is fitted and `X` is valid and matches its width; returns the divisors. */
  private requireScalesFor(X: Matrix): Vector {
    const scales = this.maxAbs_;
    if (!scales) {
      throw new Error("MaxAbsScaler has not been fitted.");
    }

    assertNonEmptyMatrix(X);
    assertConsistentRowSize(X);
    assertFiniteMatrix(X);

    if (X[0].length !== scales.length) {
      throw new Error(
        `Feature size mismatch. Expected ${scales.length}, got ${X[0].length}.`,
      );
    }
    return scales;
  }
}
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
import type { Matrix } from "../types";
|
|
2
|
+
import {
|
|
3
|
+
assertConsistentRowSize,
|
|
4
|
+
assertFiniteMatrix,
|
|
5
|
+
assertNonEmptyMatrix,
|
|
6
|
+
} from "../utils/validation";
|
|
7
|
+
|
|
8
|
+
export interface NormalizerOptions {
  /**
   * Norm each row is scaled to unit length under. Defaults to `"l2"`.
   * Checked at construction time because plain-JS callers bypass the type.
   */
  norm?: "l1" | "l2" | "max";
}

/** Largest absolute value in `row` (0 for an empty or all-zero row). */
function maxAbsolute(row: readonly number[]): number {
  let peak = 0;
  for (const value of row) {
    peak = Math.max(peak, Math.abs(value));
  }
  return peak;
}

/**
 * Rescales every row (sample) independently so that its chosen norm is 1.
 *
 * - `"l1"`: divide by the sum of absolute values.
 * - `"l2"`: divide by the Euclidean length.
 * - `"max"`: divide by the largest absolute value.
 *
 * Rows whose norm is 0 are returned as unchanged copies. `fit` only records the
 * column count; `transform` may be called without fitting, in which case any
 * row width is accepted.
 */
export class Normalizer {
  private readonly norm: "l1" | "l2" | "max";
  private nFeatures_: number | null = null;

  /**
   * @param options Optional `norm` (default `"l2"`).
   * @throws Error when `norm` is not `"l1"`, `"l2"` or `"max"`.
   */
  constructor(options: NormalizerOptions = {}) {
    const norm = options.norm ?? "l2";
    // Without this check an unrecognised value (e.g. "L2" from plain JS) would
    // silently be treated as the max norm by `transform`.
    if (norm !== "l1" && norm !== "l2" && norm !== "max") {
      throw new Error(`norm must be one of "l1", "l2" or "max". Got ${String(norm)}.`);
    }
    this.norm = norm;
  }

  /** Validates `X` and records its column count. */
  fit(X: Matrix): this {
    assertNonEmptyMatrix(X);
    assertConsistentRowSize(X);
    assertFiniteMatrix(X);
    this.nFeatures_ = X[0].length;
    return this;
  }

  /**
   * Returns a new matrix whose rows are scaled to unit norm.
   * @throws Error when fitted and the column count differs from the fitted one.
   */
  transform(X: Matrix): Matrix {
    assertNonEmptyMatrix(X);
    assertConsistentRowSize(X);
    assertFiniteMatrix(X);
    if (this.nFeatures_ !== null && X[0].length !== this.nFeatures_) {
      throw new Error(`Feature size mismatch. Expected ${this.nFeatures_}, got ${X[0].length}.`);
    }

    return X.map((row) => {
      const scale = this.rowNorm(row);
      // A zero row has no direction to normalise; return it unchanged (as a copy).
      return scale === 0 ? [...row] : row.map((value) => value / scale);
    });
  }

  /** Equivalent to `fit(X)` followed by `transform(X)`. */
  fitTransform(X: Matrix): Matrix {
    return this.fit(X).transform(X);
  }

  /** Computes the configured norm of a single row. */
  private rowNorm(row: readonly number[]): number {
    switch (this.norm) {
      case "l1": {
        let total = 0;
        for (const value of row) {
          total += Math.abs(value);
        }
        return total;
      }
      case "l2": {
        // Divide by the peak magnitude before squaring. Otherwise entries
        // above ~1e154 overflow to Infinity (normalising the row to all zeros)
        // and entries below ~1e-154 lose precision or underflow to 0 (leaving
        // the row un-normalised).
        const peak = maxAbsolute(row);
        if (peak === 0) {
          return 0;
        }
        let sumOfSquares = 0;
        for (const value of row) {
          const ratio = value / peak;
          sumOfSquares += ratio * ratio;
        }
        return peak * Math.sqrt(sumOfSquares);
      }
      case "max":
        return maxAbsolute(row);
    }
  }
}
|
|
@@ -6,6 +6,7 @@ import {
|
|
|
6
6
|
validateClassificationInputs,
|
|
7
7
|
} from "../utils/validation";
|
|
8
8
|
import { accuracyScore } from "../metrics/classification";
|
|
9
|
+
import { getZigKernels } from "../native/zigKernels";
|
|
9
10
|
|
|
10
11
|
export type MaxFeaturesOption = "sqrt" | "log2" | number | null;
|
|
11
12
|
|
|
@@ -38,6 +39,11 @@ interface SplitPartition {
|
|
|
38
39
|
|
|
39
40
|
const MAX_THRESHOLD_BINS = 128;
|
|
40
41
|
|
|
42
|
+
/**
 * Reports whether the native (Zig) decision-tree backend was requested through
 * the `BUN_SCIKIT_TREE_BACKEND` environment variable. The value is trimmed and
 * compared case-insensitively, and `zig` and `native` both enable it.
 */
function isZigTreeBackendEnabled(): boolean {
  const requested = process.env.BUN_SCIKIT_TREE_BACKEND;
  if (requested === undefined) {
    return false;
  }
  switch (requested.trim().toLowerCase()) {
    case "zig":
    case "native":
      return true;
    default:
      return false;
  }
}
|
|
46
|
+
|
|
41
47
|
function mulberry32(seed: number): () => number {
|
|
42
48
|
let state = seed >>> 0;
|
|
43
49
|
return () => {
|
|
@@ -59,6 +65,8 @@ function giniImpurity(positiveCount: number, sampleCount: number): number {
|
|
|
59
65
|
|
|
60
66
|
export class DecisionTreeClassifier implements ClassificationModel {
|
|
61
67
|
classes_: Vector = [0, 1];
|
|
68
|
+
fitBackend_: "zig" | "js" = "js";
|
|
69
|
+
fitBackendLibrary_: string | null = null;
|
|
62
70
|
private readonly maxDepth: number;
|
|
63
71
|
private readonly minSamplesSplit: number;
|
|
64
72
|
private readonly minSamplesLeaf: number;
|
|
@@ -73,6 +81,7 @@ export class DecisionTreeClassifier implements ClassificationModel {
|
|
|
73
81
|
private featureSelectionMarks: Uint8Array | null = null;
|
|
74
82
|
private binTotals: Uint32Array = new Uint32Array(MAX_THRESHOLD_BINS);
|
|
75
83
|
private binPositives: Uint32Array = new Uint32Array(MAX_THRESHOLD_BINS);
|
|
84
|
+
private zigModelHandle: bigint | null = null;
|
|
76
85
|
|
|
77
86
|
constructor(options: DecisionTreeClassifierOptions = {}) {
|
|
78
87
|
this.maxDepth = options.maxDepth ?? 12;
|
|
@@ -90,6 +99,8 @@ export class DecisionTreeClassifier implements ClassificationModel {
|
|
|
90
99
|
flattenedXTrain?: Float64Array,
|
|
91
100
|
yBinaryTrain?: Uint8Array,
|
|
92
101
|
): this {
|
|
102
|
+
this.destroyZigModel();
|
|
103
|
+
|
|
93
104
|
if (!skipValidation) {
|
|
94
105
|
validateClassificationInputs(X, y);
|
|
95
106
|
}
|
|
@@ -103,18 +114,28 @@ export class DecisionTreeClassifier implements ClassificationModel {
|
|
|
103
114
|
this.featureSelectionMarks = new Uint8Array(this.featureCount);
|
|
104
115
|
this.random = this.randomState === undefined ? Math.random : mulberry32(this.randomState);
|
|
105
116
|
|
|
106
|
-
let
|
|
117
|
+
let validatedSampleIndices: Uint32Array | null = null;
|
|
107
118
|
if (sampleIndices) {
|
|
108
119
|
if (sampleIndices.length === 0) {
|
|
109
120
|
throw new Error("sampleIndices must not be empty.");
|
|
110
121
|
}
|
|
122
|
+
validatedSampleIndices = new Uint32Array(sampleIndices.length);
|
|
111
123
|
for (let i = 0; i < sampleIndices.length; i += 1) {
|
|
112
124
|
const index = sampleIndices[i];
|
|
113
125
|
if (!Number.isInteger(index) || index < 0 || index >= X.length) {
|
|
114
126
|
throw new Error(`sampleIndices contains invalid index: ${index}.`);
|
|
115
127
|
}
|
|
128
|
+
validatedSampleIndices[i] = index;
|
|
116
129
|
}
|
|
117
|
-
|
|
130
|
+
}
|
|
131
|
+
|
|
132
|
+
if (isZigTreeBackendEnabled() && this.tryFitWithZig(X.length, validatedSampleIndices)) {
|
|
133
|
+
return this;
|
|
134
|
+
}
|
|
135
|
+
|
|
136
|
+
let rootIndices: number[];
|
|
137
|
+
if (validatedSampleIndices) {
|
|
138
|
+
rootIndices = Array.from(validatedSampleIndices);
|
|
118
139
|
} else {
|
|
119
140
|
rootIndices = new Array<number>(X.length);
|
|
120
141
|
for (let idx = 0; idx < X.length; idx += 1) {
|
|
@@ -123,11 +144,13 @@ export class DecisionTreeClassifier implements ClassificationModel {
|
|
|
123
144
|
}
|
|
124
145
|
|
|
125
146
|
this.root = this.buildTree(rootIndices, 0);
|
|
147
|
+
this.fitBackend_ = "js";
|
|
148
|
+
this.fitBackendLibrary_ = null;
|
|
126
149
|
return this;
|
|
127
150
|
}
|
|
128
151
|
|
|
129
152
|
predict(X: Matrix): Vector {
|
|
130
|
-
if (
|
|
153
|
+
if ((this.root === null && this.zigModelHandle === null) || this.featureCount === 0) {
|
|
131
154
|
throw new Error("DecisionTreeClassifier has not been fitted.");
|
|
132
155
|
}
|
|
133
156
|
|
|
@@ -140,7 +163,34 @@ export class DecisionTreeClassifier implements ClassificationModel {
|
|
|
140
163
|
);
|
|
141
164
|
}
|
|
142
165
|
|
|
143
|
-
|
|
166
|
+
if (this.zigModelHandle !== null) {
|
|
167
|
+
const kernels = getZigKernels();
|
|
168
|
+
const nativePredict = kernels?.decisionTreeModelPredict;
|
|
169
|
+
if (nativePredict) {
|
|
170
|
+
const flattenedX = this.flattenTrainingMatrix(X);
|
|
171
|
+
const outLabels = new Uint8Array(X.length);
|
|
172
|
+
const status = nativePredict(
|
|
173
|
+
this.zigModelHandle,
|
|
174
|
+
flattenedX,
|
|
175
|
+
X.length,
|
|
176
|
+
this.featureCount,
|
|
177
|
+
outLabels,
|
|
178
|
+
);
|
|
179
|
+
if (status === 1) {
|
|
180
|
+
return Array.from(outLabels);
|
|
181
|
+
}
|
|
182
|
+
}
|
|
183
|
+
if (!this.root) {
|
|
184
|
+
throw new Error("Native DecisionTree predict failed and no JS fallback tree is available.");
|
|
185
|
+
}
|
|
186
|
+
}
|
|
187
|
+
|
|
188
|
+
const predictions = new Array<number>(X.length);
|
|
189
|
+
const root = this.root!;
|
|
190
|
+
for (let i = 0; i < X.length; i += 1) {
|
|
191
|
+
predictions[i] = this.predictOne(X[i], root);
|
|
192
|
+
}
|
|
193
|
+
return predictions;
|
|
144
194
|
}
|
|
145
195
|
|
|
146
196
|
score(X: Matrix, y: Vector): number {
|
|
@@ -148,6 +198,13 @@ export class DecisionTreeClassifier implements ClassificationModel {
|
|
|
148
198
|
return accuracyScore(y, this.predict(X));
|
|
149
199
|
}
|
|
150
200
|
|
|
201
|
+
/**
 * Releases everything a fitted model holds: the native handle (when the Zig
 * backend was used), the JS tree, and the cached training buffers. Afterwards
 * `predict` reports the model as not fitted until `fit` is called again.
 */
dispose(): void {
  this.destroyZigModel();
  this.yBinaryTrain = null;
  this.flattenedXTrain = null;
  this.root = null;
}
|
|
207
|
+
|
|
151
208
|
private predictOne(sample: Vector, node: TreeNode): 0 | 1 {
|
|
152
209
|
let current: TreeNode = node;
|
|
153
210
|
while (
|
|
@@ -228,9 +285,107 @@ export class DecisionTreeClassifier implements ClassificationModel {
|
|
|
228
285
|
if (this.maxFeatures === "log2") {
|
|
229
286
|
return Math.max(1, Math.floor(Math.log2(featureCount)));
|
|
230
287
|
}
|
|
288
|
+
if (!Number.isFinite(this.maxFeatures)) {
|
|
289
|
+
return featureCount;
|
|
290
|
+
}
|
|
231
291
|
return Math.max(1, Math.min(featureCount, Math.floor(this.maxFeatures)));
|
|
232
292
|
}
|
|
233
293
|
|
|
294
|
+
/**
 * Encodes `maxFeatures` for the native tree kernel as a `(mode, value)` pair:
 * mode 0 when unset (`null`/`undefined`), 1 for `"sqrt"`, 2 for `"log2"`, and
 * 3 for a numeric setting whose clamped count (1..featureCount) is in `value`.
 * A non-finite numeric setting falls back to `featureCount`.
 */
private resolveNativeMaxFeatures(featureCount: number): {
  mode: 0 | 1 | 2 | 3;
  value: number;
} {
  const setting = this.maxFeatures;
  if (setting == null) {
    return { mode: 0, value: 0 };
  }
  switch (setting) {
    case "sqrt":
      return { mode: 1, value: 0 };
    case "log2":
      return { mode: 2, value: 0 };
    default: {
      const value = Number.isFinite(setting)
        ? Math.max(1, Math.min(featureCount, Math.floor(setting)))
        : featureCount;
      return { mode: 3, value };
    }
  }
}
|
|
312
|
+
|
|
313
|
+
/**
 * Tries to fit the tree with the native decision-tree kernel.
 *
 * On success it stores the fitted native handle in `zigModelHandle`, clears
 * the JS tree, records the backend in `fitBackend_` / `fitBackendLibrary_`,
 * and returns `true`.
 *
 * It returns `false` so the caller can fall back to the JS implementation when:
 * - the kernels or training buffers are unavailable,
 * - handle creation yields 0,
 * - the fit status is not 1, or
 * - a kernel call throws.
 *
 * Any handle created along a failing path is released on a best-effort basis.
 *
 * @param sampleCount Row count of the flattened training matrix.
 * @param sampleIndices Validated subset of rows to train on, or `null` for all rows.
 */
private tryFitWithZig(
  sampleCount: number,
  sampleIndices: Uint32Array | null,
): boolean {
  // Narrow `kernels` itself; it is read again below for `libraryPath`.
  const kernels = getZigKernels();
  if (!kernels) {
    return false;
  }
  const create = kernels.decisionTreeModelCreate;
  const fit = kernels.decisionTreeModelFit;
  const destroy = kernels.decisionTreeModelDestroy;
  if (!create || !fit || !destroy) {
    return false;
  }

  // Check the training buffers explicitly instead of `!`-asserting them;
  // handing a null buffer across the FFI boundary would not fail cleanly.
  const xTrain = this.flattenedXTrain;
  const yTrain = this.yBinaryTrain;
  if (!xTrain || !yTrain) {
    return false;
  }

  const { mode, value } = this.resolveNativeMaxFeatures(this.featureCount);
  const useRandomState = this.randomState === undefined ? 0 : 1;
  const randomState = this.randomState ?? 0;
  const handle = create(
    this.maxDepth,
    this.minSamplesSplit,
    this.minSamplesLeaf,
    mode,
    value,
    randomState >>> 0,
    useRandomState,
    this.featureCount,
  );
  if (handle === 0n) {
    return false;
  }

  let keepHandle = false;
  try {
    // The kernel always receives a buffer. Without a subset it gets an empty
    // one with length 0, presumably meaning "all rows" as in the JS path.
    const indices = sampleIndices ?? new Uint32Array(0);
    const status = fit(
      handle,
      xTrain,
      yTrain,
      sampleCount,
      this.featureCount,
      indices,
      indices.length,
    );
    if (status !== 1) {
      return false;
    }

    this.zigModelHandle = handle;
    this.root = null;
    this.fitBackend_ = "zig";
    this.fitBackendLibrary_ = kernels.libraryPath;
    keepHandle = true;
    return true;
  } catch {
    return false;
  } finally {
    if (!keepHandle) {
      try {
        destroy(handle);
      } catch {
        // Best-effort cleanup (as in destroyZigModel): a failing destroy must
        // not turn a recoverable fallback into an exception.
      }
    }
  }
}
|
|
372
|
+
|
|
373
|
+
/**
 * Frees the native model handle, if one is held, and clears `zigModelHandle`.
 * Cleanup is best effort. A missing destroy kernel or a throwing destroy call
 * is ignored, and the handle field is cleared in either case.
 */
private destroyZigModel(): void {
  const handle = this.zigModelHandle;
  if (handle === null) {
    return;
  }

  const release = getZigKernels()?.decisionTreeModelDestroy;
  try {
    release?.(handle);
  } catch {
    // no-op: cleanup best effort
  }
  this.zigModelHandle = null;
}
|
|
388
|
+
|
|
234
389
|
private selectFeatureIndices(featureCount: number): number[] {
|
|
235
390
|
const k = this.resolveMaxFeatures(featureCount);
|
|
236
391
|
if (k >= featureCount) {
|