bun-scikit 0.1.3 → 0.1.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +73 -137
- package/package.json +2 -2
- package/scripts/check-benchmark-health.ts +62 -1
- package/scripts/sync-benchmark-readme.ts +56 -0
- package/src/dummy/DummyClassifier.ts +190 -0
- package/src/dummy/DummyRegressor.ts +108 -0
- package/src/feature_selection/VarianceThreshold.ts +88 -0
- package/src/index.ts +23 -0
- package/src/metrics/classification.ts +30 -0
- package/src/metrics/regression.ts +40 -0
- package/src/model_selection/RandomizedSearchCV.ts +269 -0
- package/src/native/node-addon/bun_scikit_addon.cpp +149 -0
- package/src/native/zigKernels.ts +33 -4
- package/src/preprocessing/Binarizer.ts +46 -0
- package/src/preprocessing/LabelEncoder.ts +62 -0
- package/src/preprocessing/MaxAbsScaler.ts +77 -0
- package/src/preprocessing/Normalizer.ts +66 -0
- package/src/tree/DecisionTreeClassifier.ts +146 -3
- package/zig/kernels.zig +63 -40
package/README.md
CHANGED
|
@@ -3,185 +3,121 @@
|
|
|
3
3
|
[](https://github.com/Seyamalam/bun-scikit/actions/workflows/ci.yml)
|
|
4
4
|
[](https://github.com/Seyamalam/bun-scikit/actions/workflows/benchmark-snapshot.yml)
|
|
5
5
|
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
## Features
|
|
9
|
-
|
|
10
|
-
- `StandardScaler`
|
|
11
|
-
- `LinearRegression` (native Zig `normal` solver)
|
|
12
|
-
- `LogisticRegression` (binary classification, native Zig)
|
|
13
|
-
- `KNeighborsClassifier`
|
|
14
|
-
- `DecisionTreeClassifier`
|
|
15
|
-
- `RandomForestClassifier`
|
|
16
|
-
- `trainTestSplit`
|
|
17
|
-
- Regression metrics: `meanSquaredError`, `meanAbsoluteError`, `r2Score`
|
|
18
|
-
- Classification metrics: `accuracyScore`, `precisionScore`, `recallScore`, `f1Score`
|
|
19
|
-
- Dataset-driven benchmark and CI comparison against Python `scikit-learn`
|
|
20
|
-
|
|
21
|
-
`test_data/heart.csv` is used for integration testing and benchmark comparison.
|
|
22
|
-
|
|
23
|
-
## Native Zig Backend
|
|
24
|
-
|
|
25
|
-
`LinearRegression` (`solver: "normal"`) and `LogisticRegression` require native Zig kernels.
|
|
26
|
-
|
|
27
|
-
```bash
|
|
28
|
-
bun run native:build
|
|
29
|
-
```
|
|
30
|
-
|
|
31
|
-
Optional Node-API bridge (experimental):
|
|
32
|
-
|
|
33
|
-
```bash
|
|
34
|
-
bun run native:build:node-addon
|
|
35
|
-
```
|
|
36
|
-
|
|
37
|
-
```ts
|
|
38
|
-
const linear = new LinearRegression({ solver: "normal" });
|
|
39
|
-
const logistic = new LogisticRegression();
|
|
40
|
-
|
|
41
|
-
linear.fit(XTrain, yTrain);
|
|
42
|
-
logistic.fit(XTrain, yTrain);
|
|
43
|
-
console.log(linear.fitBackend_, linear.fitBackendLibrary_);
|
|
44
|
-
console.log(logistic.fitBackend_, logistic.fitBackendLibrary_);
|
|
45
|
-
```
|
|
46
|
-
|
|
47
|
-
If native kernels are missing, `fit()` throws with guidance to run `bun run native:build`.
|
|
48
|
-
|
|
49
|
-
Bridge selection:
|
|
50
|
-
|
|
51
|
-
- `BUN_SCIKIT_NATIVE_BRIDGE=node-api|ffi` (`node-api` is attempted first when available)
|
|
52
|
-
- `BUN_SCIKIT_NODE_ADDON=/absolute/path/to/bun_scikit_node_addon.node`
|
|
53
|
-
- `BUN_SCIKIT_ZIG_LIB=/absolute/path/to/bun_scikit_kernels.<ext>`
|
|
54
|
-
|
|
55
|
-
Native ABI contract: `docs/native-abi.md`
|
|
6
|
+
Scikit-learn-inspired machine learning for Bun + TypeScript, with native Zig acceleration for core training paths.
|
|
56
7
|
|
|
57
8
|
## Install
|
|
58
9
|
|
|
59
10
|
```bash
|
|
60
|
-
bun
|
|
11
|
+
bun add bun-scikit
|
|
61
12
|
```
|
|
62
13
|
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
- Prebuilt native binaries for `linux-x64` and `windows-x64` are bundled in the npm package.
|
|
66
|
-
- No `bun pm trust` step is required for normal install/use.
|
|
67
|
-
- macOS prebuilt binaries are currently not published.
|
|
68
|
-
|
|
69
|
-
## Usage
|
|
14
|
+
## Quick Start
|
|
70
15
|
|
|
71
16
|
```ts
|
|
72
17
|
import {
|
|
73
18
|
LinearRegression,
|
|
19
|
+
LogisticRegression,
|
|
74
20
|
StandardScaler,
|
|
75
|
-
meanSquaredError,
|
|
76
21
|
trainTestSplit,
|
|
22
|
+
meanSquaredError,
|
|
23
|
+
accuracyScore,
|
|
77
24
|
} from "bun-scikit";
|
|
78
25
|
|
|
79
|
-
const X = [
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
[3, 4],
|
|
83
|
-
[4, 5],
|
|
84
|
-
];
|
|
85
|
-
const y = [5, 7, 9, 11];
|
|
26
|
+
const X = [[1], [2], [3], [4], [5], [6]];
|
|
27
|
+
const yReg = [3, 5, 7, 9, 11, 13];
|
|
28
|
+
const yCls = [0, 0, 0, 1, 1, 1];
|
|
86
29
|
|
|
87
30
|
const scaler = new StandardScaler();
|
|
88
|
-
const
|
|
89
|
-
|
|
90
|
-
|
|
31
|
+
const Xs = scaler.fitTransform(X);
|
|
32
|
+
|
|
33
|
+
const { XTrain, XTest, yTrain, yTest } = trainTestSplit(Xs, yReg, {
|
|
34
|
+
testSize: 0.33,
|
|
91
35
|
randomState: 42,
|
|
92
36
|
});
|
|
93
37
|
|
|
94
|
-
const
|
|
95
|
-
|
|
96
|
-
|
|
38
|
+
const reg = new LinearRegression({ solver: "normal" });
|
|
39
|
+
reg.fit(XTrain, yTrain);
|
|
40
|
+
console.log("MSE:", meanSquaredError(yTest, reg.predict(XTest)));
|
|
97
41
|
|
|
98
|
-
|
|
42
|
+
const clf = new LogisticRegression({
|
|
43
|
+
solver: "gd",
|
|
44
|
+
learningRate: 0.8,
|
|
45
|
+
maxIter: 100,
|
|
46
|
+
tolerance: 1e-5,
|
|
47
|
+
});
|
|
48
|
+
clf.fit(Xs, yCls);
|
|
49
|
+
console.log("Accuracy:", accuracyScore(yCls, clf.predict(Xs)));
|
|
99
50
|
```
|
|
100
51
|
|
|
101
|
-
##
|
|
52
|
+
## Included APIs
|
|
102
53
|
|
|
103
|
-
|
|
104
|
-
|
|
54
|
+
- Models: `LinearRegression`, `LogisticRegression`, `KNeighborsClassifier`, `DecisionTreeClassifier`, `RandomForestClassifier`, plus additional parity models (`LinearSVC`, `GaussianNB`, `SGDClassifier`, `SGDRegressor`, regressors for tree/forest).
|
|
55
|
+
- Baselines: `DummyClassifier`, `DummyRegressor`.
|
|
56
|
+
- Preprocessing: `StandardScaler`, `MinMaxScaler`, `RobustScaler`, `MaxAbsScaler`, `Normalizer`, `Binarizer`, `LabelEncoder`, `PolynomialFeatures`, `SimpleImputer`, `OneHotEncoder`.
|
|
57
|
+
- Composition: `Pipeline`, `ColumnTransformer`, `FeatureUnion`.
|
|
58
|
+
- Feature selection: `VarianceThreshold`.
|
|
59
|
+
- Model selection: `trainTestSplit`, `KFold`, stratified/repeated splitters, `crossValScore`, `GridSearchCV`, `RandomizedSearchCV`.
|
|
60
|
+
- Metrics: regression and classification metrics, including `logLoss`, `rocAucScore`, `confusionMatrix`, `classificationReport`, `balancedAccuracyScore`, `matthewsCorrcoef`, `brierScoreLoss`, `meanAbsolutePercentageError`, and `explainedVarianceScore`.
|
|
105
61
|
|
|
106
|
-
|
|
107
|
-
Benchmark snapshot source: `bench/results/heart-ci-latest.json` (generated in CI workflow `Benchmark Snapshot`).
|
|
108
|
-
Dataset: `test_data/heart.csv` (1025 samples, 13 features, test fraction 0.2).
|
|
62
|
+
## Scikit Parity Matrix
|
|
109
63
|
|
|
110
|
-
|
|
64
|
+
| Area | Status |
|
|
65
|
+
| --- | --- |
|
|
66
|
+
| Linear models | `LinearRegression`, `LogisticRegression`, `SGDClassifier`, `SGDRegressor`, `LinearSVC` |
|
|
67
|
+
| Tree/ensemble | `DecisionTreeClassifier`, `DecisionTreeRegressor`, `RandomForestClassifier`, `RandomForestRegressor` |
|
|
68
|
+
| Neighbors / Bayes | `KNeighborsClassifier`, `GaussianNB` |
|
|
69
|
+
| Baselines | `DummyClassifier`, `DummyRegressor` |
|
|
70
|
+
| Preprocessing | `StandardScaler`, `MinMaxScaler`, `RobustScaler`, `MaxAbsScaler`, `Normalizer`, `Binarizer`, `LabelEncoder`, `PolynomialFeatures`, `SimpleImputer`, `OneHotEncoder` |
|
|
71
|
+
| Feature selection | `VarianceThreshold` |
|
|
72
|
+
| Model selection | `trainTestSplit`, `KFold`, `StratifiedKFold`, `StratifiedShuffleSplit`, `RepeatedKFold`, `RepeatedStratifiedKFold`, `crossValScore`, `GridSearchCV`, `RandomizedSearchCV` |
|
|
73
|
+
| Metrics (regression) | `meanSquaredError`, `meanAbsoluteError`, `r2Score`, `meanAbsolutePercentageError`, `explainedVarianceScore` |
|
|
74
|
+
| Metrics (classification) | `accuracyScore`, `precisionScore`, `recallScore`, `f1Score`, `balancedAccuracyScore`, `matthewsCorrcoef`, `logLoss`, `brierScoreLoss`, `rocAucScore`, `confusionMatrix`, `classificationReport` |
|
|
111
75
|
|
|
112
|
-
|
|
113
|
-
|---|---|---:|---:|---:|---:|
|
|
114
|
-
| bun-scikit | StandardScaler + LinearRegression(normal) | 0.2103 | 0.0216 | 0.117545 | 0.529539 |
|
|
115
|
-
| python-scikit-learn | StandardScaler + LinearRegression | 0.3201 | 0.0365 | 0.117545 | 0.529539 |
|
|
76
|
+
Near-term parity gaps vs scikit-learn include clustering, decomposition, advanced feature selection, and probability calibration/meta-estimators.
|
|
116
77
|
|
|
117
|
-
|
|
118
|
-
Bun predict speedup vs scikit-learn: 1.684x
|
|
119
|
-
MSE delta (bun - sklearn): 6.362e-14
|
|
120
|
-
R2 delta (bun - sklearn): -2.539e-13
|
|
78
|
+
## Native Runtime
|
|
121
79
|
|
|
122
|
-
|
|
80
|
+
- Prebuilt binaries are bundled in the npm package for:
|
|
81
|
+
- `linux-x64`
|
|
82
|
+
- `windows-x64`
|
|
83
|
+
- No `bun pm trust` step is required for standard install/use.
|
|
84
|
+
- macOS prebuilt binaries are not published yet.
|
|
123
85
|
|
|
124
|
-
|
|
125
|
-
|---|---|---:|---:|---:|---:|
|
|
126
|
-
| bun-scikit | StandardScaler + LogisticRegression(gd,zig) | 0.4868 | 0.0282 | 0.863415 | 0.876106 |
|
|
127
|
-
| python-scikit-learn | StandardScaler + LogisticRegression(lbfgs) | 1.1246 | 0.0724 | 0.863415 | 0.875000 |
|
|
86
|
+
Optional env vars:
|
|
128
87
|
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
88
|
+
- `BUN_SCIKIT_NATIVE_BRIDGE=node-api|ffi`
|
|
89
|
+
- `BUN_SCIKIT_NODE_ADDON=/absolute/path/to/bun_scikit_node_addon.node`
|
|
90
|
+
- `BUN_SCIKIT_ZIG_LIB=/absolute/path/to/bun_scikit_kernels.<ext>`
|
|
91
|
+
- `BUN_SCIKIT_TREE_BACKEND=zig` (opt-in native tree/forest training path; default keeps JS-fast tree splitter)
|
|
133
92
|
|
|
134
|
-
|
|
93
|
+
## Performance Snapshot
|
|
135
94
|
|
|
136
|
-
|
|
137
|
-
|---|---|---:|---:|---:|---:|
|
|
138
|
-
| DecisionTreeClassifier(maxDepth=8) | bun-scikit | 0.8062 | 0.0190 | 0.946341 | 0.948837 |
|
|
139
|
-
| DecisionTreeClassifier | python-scikit-learn | 1.4781 | 0.0999 | 0.931707 | 0.933962 |
|
|
140
|
-
| RandomForestClassifier(nEstimators=80,maxDepth=8) | bun-scikit | 27.6225 | 1.8535 | 0.990244 | 0.990566 |
|
|
141
|
-
| RandomForestClassifier | python-scikit-learn | 172.9585 | 6.4850 | 0.995122 | 0.995261 |
|
|
95
|
+
Latest CI snapshot on `test_data/heart.csv` vs Python scikit-learn:
|
|
142
96
|
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
DecisionTree
|
|
146
|
-
|
|
97
|
+
- Regression: fit `1.67x`, predict `1.84x`
|
|
98
|
+
- Classification: fit `1.78x`, predict `2.66x`
|
|
99
|
+
- DecisionTree (`js-fast`): fit `1.54x`, predict `4.06x`
|
|
100
|
+
- RandomForest (`js-fast`): fit `2.59x`, predict `1.29x`
|
|
101
|
+
- Tree backend matrix (`js-fast` vs `zig-tree` vs `sklearn`) is included in `bench/results/heart-ci-latest.md`
|
|
147
102
|
|
|
148
|
-
|
|
149
|
-
RandomForest predict speedup vs scikit-learn: 3.499x
|
|
150
|
-
RandomForest accuracy delta (bun - sklearn): -4.878e-3
|
|
151
|
-
RandomForest f1 delta (bun - sklearn): -4.695e-3
|
|
103
|
+
Raw benchmark artifacts:
|
|
152
104
|
|
|
153
|
-
|
|
154
|
-
|
|
105
|
+
- `bench/results/heart-ci-latest.json`
|
|
106
|
+
- `bench/results/heart-ci-latest.md`
|
|
155
107
|
|
|
156
108
|
## Documentation
|
|
157
109
|
|
|
158
|
-
- Docs index: `docs/README.md`
|
|
159
110
|
- Getting started: `docs/getting-started.md`
|
|
160
111
|
- API reference: `docs/api.md`
|
|
161
|
-
- Benchmarking
|
|
112
|
+
- Benchmarking: `docs/benchmarking.md`
|
|
162
113
|
- Zig acceleration: `docs/zig-acceleration.md`
|
|
114
|
+
- Native ABI: `docs/native-abi.md`
|
|
115
|
+
- Release checklist: `docs/release-checklist.md`
|
|
163
116
|
|
|
164
|
-
##
|
|
117
|
+
## Contributing / Project Files
|
|
165
118
|
|
|
166
119
|
- Changelog: `CHANGELOG.md`
|
|
167
|
-
- Contributing
|
|
120
|
+
- Contributing: `CONTRIBUTING.md`
|
|
121
|
+
- Security: `SECURITY.md`
|
|
168
122
|
- Code of Conduct: `CODE_OF_CONDUCT.md`
|
|
169
|
-
-
|
|
170
|
-
- Support policy: `SUPPORT.md`
|
|
171
|
-
- License: `LICENSE`
|
|
172
|
-
|
|
173
|
-
## Local Commands
|
|
174
|
-
|
|
175
|
-
```bash
|
|
176
|
-
bun run test
|
|
177
|
-
bun run typecheck
|
|
178
|
-
bun run docs:api:generate
|
|
179
|
-
bun run docs:coverage:check
|
|
180
|
-
bun run bench
|
|
181
|
-
bun run bench:heart:classification
|
|
182
|
-
bun run bench:heart:tree
|
|
183
|
-
bun run bench:ci
|
|
184
|
-
bun run bench:ci:native
|
|
185
|
-
bun run bench:snapshot
|
|
186
|
-
bun run native:build
|
|
187
|
-
```
|
|
123
|
+
- Support: `SUPPORT.md`
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "bun-scikit",
|
|
3
|
-
"version": "0.1.3",
|
|
3
|
+
"version": "0.1.4",
|
|
4
4
|
"description": "A scikit-learn-inspired machine learning library for Bun/TypeScript.",
|
|
5
5
|
"license": "MIT",
|
|
6
6
|
"module": "index.ts",
|
|
@@ -52,7 +52,7 @@
|
|
|
52
52
|
"bench:synthetic": "bun run bench/linear-regression.bench.ts",
|
|
53
53
|
"bench:ci": "bun run bench/run-ci-benchmarks.ts --output bench/results/heart-ci-current.json",
|
|
54
54
|
"bench:ci:native": "bun run native:build && bun run bench:ci",
|
|
55
|
-
"bench:snapshot": "bun run bench/run-ci-benchmarks.ts --output bench/results/heart-ci-latest.json && bun run bench:
|
|
55
|
+
"bench:snapshot": "bun run bench/run-ci-benchmarks.ts --output bench/results/heart-ci-latest.json && bun run bench:history:update",
|
|
56
56
|
"bench:sync-readme": "bun run scripts/sync-benchmark-readme.ts",
|
|
57
57
|
"bench:readme:check": "bun run scripts/sync-benchmark-readme.ts --check",
|
|
58
58
|
"bench:health": "bun run scripts/check-benchmark-health.ts",
|
|
@@ -26,6 +26,17 @@ interface TreeModelComparison {
|
|
|
26
26
|
};
|
|
27
27
|
}
|
|
28
28
|
|
|
29
|
+
interface TreeBackendModeComparison {
|
|
30
|
+
comparison: {
|
|
31
|
+
zigFitSpeedupVsJs: number;
|
|
32
|
+
zigPredictSpeedupVsJs: number;
|
|
33
|
+
jsFitSpeedupVsSklearn: number;
|
|
34
|
+
jsPredictSpeedupVsSklearn: number;
|
|
35
|
+
zigFitSpeedupVsSklearn: number;
|
|
36
|
+
zigPredictSpeedupVsSklearn: number;
|
|
37
|
+
};
|
|
38
|
+
}
|
|
39
|
+
|
|
29
40
|
interface BenchmarkSnapshot {
|
|
30
41
|
suites: {
|
|
31
42
|
regression: {
|
|
@@ -62,6 +73,10 @@ interface BenchmarkSnapshot {
|
|
|
62
73
|
},
|
|
63
74
|
];
|
|
64
75
|
};
|
|
76
|
+
treeBackendModes: {
|
|
77
|
+
enabled: boolean;
|
|
78
|
+
models: [TreeBackendModeComparison, TreeBackendModeComparison] | [];
|
|
79
|
+
};
|
|
65
80
|
};
|
|
66
81
|
}
|
|
67
82
|
|
|
@@ -106,7 +121,20 @@ const minDecisionTreePredictSpeedup = speedupThreshold(
|
|
|
106
121
|
const minRandomForestFitSpeedup = speedupThreshold("BENCH_MIN_RANDOM_FOREST_FIT_SPEEDUP", 2.0);
|
|
107
122
|
const minRandomForestPredictSpeedup = speedupThreshold(
|
|
108
123
|
"BENCH_MIN_RANDOM_FOREST_PREDICT_SPEEDUP",
|
|
109
|
-
2
|
|
124
|
+
1.2,
|
|
125
|
+
);
|
|
126
|
+
const maxZigTreeFitSlowdownVsJs = speedupThreshold("BENCH_MAX_ZIG_TREE_FIT_SLOWDOWN_VS_JS", 20);
|
|
127
|
+
const maxZigTreePredictSlowdownVsJs = speedupThreshold(
|
|
128
|
+
"BENCH_MAX_ZIG_TREE_PREDICT_SLOWDOWN_VS_JS",
|
|
129
|
+
20,
|
|
130
|
+
);
|
|
131
|
+
const maxZigForestFitSlowdownVsJs = speedupThreshold(
|
|
132
|
+
"BENCH_MAX_ZIG_FOREST_FIT_SLOWDOWN_VS_JS",
|
|
133
|
+
20,
|
|
134
|
+
);
|
|
135
|
+
const maxZigForestPredictSlowdownVsJs = speedupThreshold(
|
|
136
|
+
"BENCH_MAX_ZIG_FOREST_PREDICT_SLOWDOWN_VS_JS",
|
|
137
|
+
20,
|
|
110
138
|
);
|
|
111
139
|
|
|
112
140
|
for (const result of [
|
|
@@ -237,4 +265,37 @@ if (randomForest.comparison.predictSpeedupVsSklearn < minRandomForestPredictSpee
|
|
|
237
265
|
);
|
|
238
266
|
}
|
|
239
267
|
|
|
268
|
+
if (snapshot.suites.treeBackendModes.enabled) {
|
|
269
|
+
const [decisionTreeModes, randomForestModes] = snapshot.suites.treeBackendModes.models;
|
|
270
|
+
if (!decisionTreeModes || !randomForestModes) {
|
|
271
|
+
throw new Error("Tree backend mode suite is enabled but missing model comparisons.");
|
|
272
|
+
}
|
|
273
|
+
|
|
274
|
+
const decisionTreeFitSlowdown = 1 / decisionTreeModes.comparison.zigFitSpeedupVsJs;
|
|
275
|
+
const decisionTreePredictSlowdown = 1 / decisionTreeModes.comparison.zigPredictSpeedupVsJs;
|
|
276
|
+
const randomForestFitSlowdown = 1 / randomForestModes.comparison.zigFitSpeedupVsJs;
|
|
277
|
+
const randomForestPredictSlowdown = 1 / randomForestModes.comparison.zigPredictSpeedupVsJs;
|
|
278
|
+
|
|
279
|
+
if (decisionTreeFitSlowdown > maxZigTreeFitSlowdownVsJs) {
|
|
280
|
+
throw new Error(
|
|
281
|
+
`DecisionTree zig fit slowdown too large vs js-fast: ${decisionTreeFitSlowdown} > ${maxZigTreeFitSlowdownVsJs}.`,
|
|
282
|
+
);
|
|
283
|
+
}
|
|
284
|
+
if (decisionTreePredictSlowdown > maxZigTreePredictSlowdownVsJs) {
|
|
285
|
+
throw new Error(
|
|
286
|
+
`DecisionTree zig predict slowdown too large vs js-fast: ${decisionTreePredictSlowdown} > ${maxZigTreePredictSlowdownVsJs}.`,
|
|
287
|
+
);
|
|
288
|
+
}
|
|
289
|
+
if (randomForestFitSlowdown > maxZigForestFitSlowdownVsJs) {
|
|
290
|
+
throw new Error(
|
|
291
|
+
`RandomForest zig fit slowdown too large vs js-fast: ${randomForestFitSlowdown} > ${maxZigForestFitSlowdownVsJs}.`,
|
|
292
|
+
);
|
|
293
|
+
}
|
|
294
|
+
if (randomForestPredictSlowdown > maxZigForestPredictSlowdownVsJs) {
|
|
295
|
+
throw new Error(
|
|
296
|
+
`RandomForest zig predict slowdown too large vs js-fast: ${randomForestPredictSlowdown} > ${maxZigForestPredictSlowdownVsJs}.`,
|
|
297
|
+
);
|
|
298
|
+
}
|
|
299
|
+
}
|
|
300
|
+
|
|
240
301
|
console.log("Benchmark comparison health checks passed.");
|
|
@@ -62,6 +62,19 @@ interface BenchmarkSnapshot {
|
|
|
62
62
|
treeClassification: {
|
|
63
63
|
models: [TreeModelComparison, TreeModelComparison];
|
|
64
64
|
};
|
|
65
|
+
treeBackendModes?: {
|
|
66
|
+
enabled: boolean;
|
|
67
|
+
models: Array<{
|
|
68
|
+
key: TreeModelKey;
|
|
69
|
+
jsFast: ClassificationBenchmarkResult;
|
|
70
|
+
zigTree: ClassificationBenchmarkResult;
|
|
71
|
+
sklearn: ClassificationBenchmarkResult;
|
|
72
|
+
comparison: {
|
|
73
|
+
zigFitSpeedupVsJs: number;
|
|
74
|
+
zigPredictSpeedupVsJs: number;
|
|
75
|
+
};
|
|
76
|
+
}>;
|
|
77
|
+
};
|
|
65
78
|
};
|
|
66
79
|
}
|
|
67
80
|
|
|
@@ -89,6 +102,11 @@ function renderBenchmarkSection(snapshot: BenchmarkSnapshot): string {
|
|
|
89
102
|
const [bunReg, sklearnReg] = regression.results;
|
|
90
103
|
const [bunCls, sklearnCls] = classification.results;
|
|
91
104
|
const [decisionTree, randomForest] = treeClassification.models;
|
|
105
|
+
const treeBackendModes = snapshot.suites.treeBackendModes;
|
|
106
|
+
const hasTreeBackendModes =
|
|
107
|
+
treeBackendModes?.enabled === true && Array.isArray(treeBackendModes.models) && treeBackendModes.models.length === 2;
|
|
108
|
+
const decisionTreeModes = hasTreeBackendModes ? treeBackendModes.models[0] : null;
|
|
109
|
+
const randomForestModes = hasTreeBackendModes ? treeBackendModes.models[1] : null;
|
|
92
110
|
|
|
93
111
|
return [
|
|
94
112
|
START_MARKER,
|
|
@@ -138,6 +156,44 @@ function renderBenchmarkSection(snapshot: BenchmarkSnapshot): string {
|
|
|
138
156
|
`RandomForest accuracy delta (bun - sklearn): ${randomForest.comparison.accuracyDeltaVsSklearn.toExponential(3)}`,
|
|
139
157
|
`RandomForest f1 delta (bun - sklearn): ${randomForest.comparison.f1DeltaVsSklearn.toExponential(3)}`,
|
|
140
158
|
"",
|
|
159
|
+
"### Tree Backend Modes (Bun vs Bun vs sklearn)",
|
|
160
|
+
"",
|
|
161
|
+
hasTreeBackendModes
|
|
162
|
+
? "| Model | Backend | Fit median (ms) | Predict median (ms) | Accuracy | F1 |"
|
|
163
|
+
: "Tree backend mode matrix disabled (`BENCH_TREE_BACKEND_MATRIX=0`).",
|
|
164
|
+
hasTreeBackendModes ? "|---|---|---:|---:|---:|---:|" : "",
|
|
165
|
+
hasTreeBackendModes
|
|
166
|
+
? `| DecisionTreeClassifier(maxDepth=8) | js-fast | ${decisionTreeModes!.jsFast.fitMsMedian.toFixed(4)} | ${decisionTreeModes!.jsFast.predictMsMedian.toFixed(4)} | ${decisionTreeModes!.jsFast.accuracy.toFixed(6)} | ${decisionTreeModes!.jsFast.f1.toFixed(6)} |`
|
|
167
|
+
: "",
|
|
168
|
+
hasTreeBackendModes
|
|
169
|
+
? `| DecisionTreeClassifier(maxDepth=8) | zig-tree | ${decisionTreeModes!.zigTree.fitMsMedian.toFixed(4)} | ${decisionTreeModes!.zigTree.predictMsMedian.toFixed(4)} | ${decisionTreeModes!.zigTree.accuracy.toFixed(6)} | ${decisionTreeModes!.zigTree.f1.toFixed(6)} |`
|
|
170
|
+
: "",
|
|
171
|
+
hasTreeBackendModes
|
|
172
|
+
? `| DecisionTreeClassifier | python-scikit-learn | ${decisionTreeModes!.sklearn.fitMsMedian.toFixed(4)} | ${decisionTreeModes!.sklearn.predictMsMedian.toFixed(4)} | ${decisionTreeModes!.sklearn.accuracy.toFixed(6)} | ${decisionTreeModes!.sklearn.f1.toFixed(6)} |`
|
|
173
|
+
: "",
|
|
174
|
+
hasTreeBackendModes
|
|
175
|
+
? `| RandomForestClassifier(nEstimators=80,maxDepth=8) | js-fast | ${randomForestModes!.jsFast.fitMsMedian.toFixed(4)} | ${randomForestModes!.jsFast.predictMsMedian.toFixed(4)} | ${randomForestModes!.jsFast.accuracy.toFixed(6)} | ${randomForestModes!.jsFast.f1.toFixed(6)} |`
|
|
176
|
+
: "",
|
|
177
|
+
hasTreeBackendModes
|
|
178
|
+
? `| RandomForestClassifier(nEstimators=80,maxDepth=8) | zig-tree | ${randomForestModes!.zigTree.fitMsMedian.toFixed(4)} | ${randomForestModes!.zigTree.predictMsMedian.toFixed(4)} | ${randomForestModes!.zigTree.accuracy.toFixed(6)} | ${randomForestModes!.zigTree.f1.toFixed(6)} |`
|
|
179
|
+
: "",
|
|
180
|
+
hasTreeBackendModes
|
|
181
|
+
? `| RandomForestClassifier | python-scikit-learn | ${randomForestModes!.sklearn.fitMsMedian.toFixed(4)} | ${randomForestModes!.sklearn.predictMsMedian.toFixed(4)} | ${randomForestModes!.sklearn.accuracy.toFixed(6)} | ${randomForestModes!.sklearn.f1.toFixed(6)} |`
|
|
182
|
+
: "",
|
|
183
|
+
"",
|
|
184
|
+
hasTreeBackendModes
|
|
185
|
+
? `DecisionTree zig/js fit speedup: ${decisionTreeModes!.comparison.zigFitSpeedupVsJs.toFixed(3)}x`
|
|
186
|
+
: "",
|
|
187
|
+
hasTreeBackendModes
|
|
188
|
+
? `DecisionTree zig/js predict speedup: ${decisionTreeModes!.comparison.zigPredictSpeedupVsJs.toFixed(3)}x`
|
|
189
|
+
: "",
|
|
190
|
+
hasTreeBackendModes
|
|
191
|
+
? `RandomForest zig/js fit speedup: ${randomForestModes!.comparison.zigFitSpeedupVsJs.toFixed(3)}x`
|
|
192
|
+
: "",
|
|
193
|
+
hasTreeBackendModes
|
|
194
|
+
? `RandomForest zig/js predict speedup: ${randomForestModes!.comparison.zigPredictSpeedupVsJs.toFixed(3)}x`
|
|
195
|
+
: "",
|
|
196
|
+
"",
|
|
141
197
|
`Snapshot generated at: ${snapshot.generatedAt}`,
|
|
142
198
|
END_MARKER,
|
|
143
199
|
].join("\n");
|
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
import type { Matrix, Vector } from "../types";
|
|
2
|
+
import { accuracyScore } from "../metrics/classification";
|
|
3
|
+
import {
|
|
4
|
+
assertConsistentRowSize,
|
|
5
|
+
assertFiniteMatrix,
|
|
6
|
+
assertFiniteVector,
|
|
7
|
+
assertNonEmptyMatrix,
|
|
8
|
+
assertVectorLength,
|
|
9
|
+
} from "../utils/validation";
|
|
10
|
+
|
|
11
|
+
export type DummyClassifierStrategy =
|
|
12
|
+
| "most_frequent"
|
|
13
|
+
| "prior"
|
|
14
|
+
| "stratified"
|
|
15
|
+
| "uniform"
|
|
16
|
+
| "constant";
|
|
17
|
+
|
|
18
|
+
export interface DummyClassifierOptions {
|
|
19
|
+
strategy?: DummyClassifierStrategy;
|
|
20
|
+
constant?: number;
|
|
21
|
+
randomState?: number;
|
|
22
|
+
}
|
|
23
|
+
|
|
24
|
+
/**
 * Small seeded pseudo-random number generator used to make the random
 * prediction strategies reproducible.
 *
 * The seed is coerced to an unsigned 32-bit integer, so negative or
 * fractional seeds are wrapped/truncated rather than rejected.
 */
class Mulberry32 {
  private state: number;

  /**
   * @param seed Initial state; reduced to its unsigned 32-bit value.
   */
  constructor(seed: number) {
    this.state = seed >>> 0;
  }

  /**
   * Advances the generator by one step.
   *
   * @returns A number in the half-open interval [0, 1).
   */
  next(): number {
    // Step the 32-bit counter first; every mixing stage below reads the new value.
    const advanced = (this.state + 0x6d2b79f5) >>> 0;
    this.state = advanced;

    // Integer mixing in 32-bit space (Math.imul keeps the products 32-bit).
    let mixed = Math.imul(advanced ^ (advanced >>> 15), advanced | 1);
    mixed ^= mixed + Math.imul(mixed ^ (mixed >>> 7), mixed | 61);

    // Reinterpret as unsigned and scale by 2^32 so the result stays below 1.
    const unsigned = (mixed ^ (mixed >>> 14)) >>> 0;
    return unsigned / 4294967296;
  }
}
|
|
39
|
+
|
|
40
|
+
/**
 * Baseline classifier that ignores feature values and predicts from the label
 * distribution observed during `fit` (or from a configured constant).
 *
 * Strategies:
 * - `"most_frequent"` / `"prior"`: always predict the most frequent training
 *   label (ties resolve to the smallest label). They differ only in
 *   `predictProba`: one-hot row vs. empirical class priors.
 * - `"stratified"`: sample labels according to the empirical class priors.
 * - `"uniform"`: sample labels uniformly from the observed classes.
 * - `"constant"`: always predict the configured `constant`, which must be one
 *   of the training labels.
 *
 * The random strategies build a fresh generator seeded with `randomState` on
 * every `predict` call, so repeated calls on same-sized input return the same
 * labels.
 */
export class DummyClassifier {
  /** Sorted distinct training labels; `null` until fitted. */
  classes_: number[] | null = null;
  /** Empirical frequency of each entry of `classes_`; `null` until fitted. */
  classPrior_: number[] | null = null;
  /** Label used by the "constant" strategy (the majority label otherwise). */
  constant_: number | null = null;

  private readonly strategy: DummyClassifierStrategy;
  private readonly configuredConstant?: number;
  private readonly randomState: number;
  private majorityClass: number | null = null;
  private nFeaturesIn_: number | null = null;

  /**
   * @param options Strategy (default `"prior"`), constant label for the
   *   "constant" strategy, and seed for the random strategies (default 42).
   */
  constructor(options: DummyClassifierOptions = {}) {
    this.strategy = options.strategy ?? "prior";
    this.configuredConstant = options.constant;
    this.randomState = options.randomState ?? 42;
  }

  /**
   * Learns the class set, class priors, and majority label from `y`.
   *
   * @param X Training matrix; only validated and used for its column count.
   * @param y Training labels, one per row of `X`.
   * @returns This estimator, for chaining.
   * @throws Error if validation fails, or if the "constant" strategy is used
   *   with a missing/non-finite constant or a constant that does not occur in
   *   `y`.
   */
  fit(X: Matrix, y: Vector): this {
    assertNonEmptyMatrix(X);
    assertConsistentRowSize(X);
    assertFiniteMatrix(X);
    assertVectorLength(y, X.length);
    assertFiniteVector(y);

    const counts = new Map<number, number>();
    for (let i = 0; i < y.length; i += 1) {
      counts.set(y[i], (counts.get(y[i]) ?? 0) + 1);
    }

    const classes = Array.from(counts.keys()).sort((a, b) => a - b);
    const priors = classes.map((cls) => (counts.get(cls) ?? 0) / y.length);

    // Classes are visited in ascending order and only a strictly larger count
    // replaces the current winner, so ties resolve to the smallest label.
    let majorityClass = classes[0];
    let majorityCount = counts.get(majorityClass) ?? 0;
    for (let i = 1; i < classes.length; i += 1) {
      const cls = classes[i];
      const clsCount = counts.get(cls) ?? 0;
      if (clsCount > majorityCount) {
        majorityClass = cls;
        majorityCount = clsCount;
      }
    }

    let constant = majorityClass;
    if (this.strategy === "constant") {
      const configured = this.configuredConstant;
      if (configured === undefined || !Number.isFinite(configured)) {
        throw new Error("constant strategy requires a finite constant value.");
      }
      // A constant outside the training labels has no column in predictProba,
      // which would otherwise yield all-zero (invalid) probability rows.
      if (!counts.has(configured)) {
        throw new Error("constant strategy requires the constant value to be present in the training labels.");
      }
      constant = configured;
    }

    // Commit state only after every check passed so a failed re-fit leaves the
    // previously fitted state untouched.
    this.nFeaturesIn_ = X[0].length;
    this.constant_ = constant;
    this.classes_ = classes;
    this.classPrior_ = priors;
    this.majorityClass = majorityClass;
    return this;
  }

  /**
   * Returns the fitted state with nullability removed.
   *
   * @throws Error if the estimator has not been fitted.
   */
  private ensureFitted(): {
    classes: number[];
    priors: number[];
    majorityClass: number;
    constant: number;
    nFeatures: number;
  } {
    const { classes_, classPrior_, majorityClass, constant_, nFeaturesIn_ } = this;
    if (
      !classes_ ||
      !classPrior_ ||
      nFeaturesIn_ === null ||
      majorityClass === null ||
      constant_ === null
    ) {
      throw new Error("DummyClassifier has not been fitted.");
    }
    return {
      classes: classes_,
      priors: classPrior_,
      majorityClass,
      constant: constant_,
      nFeatures: nFeaturesIn_,
    };
  }

  /**
   * Shared input check for `predict` and `predictProba`: `X` must be a
   * non-empty array whose first row has the fitted number of features.
   */
  private assertPredictInput(X: Matrix, nFeatures: number): void {
    if (!Array.isArray(X) || X.length === 0) {
      throw new Error("X must be a non-empty 2D array.");
    }
    if (!Array.isArray(X[0]) || X[0].length !== nFeatures) {
      throw new Error(`Feature size mismatch. Expected ${nFeatures}, got ${X[0]?.length ?? 0}.`);
    }
  }

  /**
   * Draws one label by walking the priors until the random draw is used up.
   * Falls back to the last class if rounding leaves a remainder.
   */
  private sampleByPrior(rng: Mulberry32, classes: number[], priors: number[]): number {
    let remaining = rng.next();
    for (let i = 0; i < priors.length; i += 1) {
      remaining -= priors[i];
      if (remaining <= 0) {
        return classes[i];
      }
    }
    return classes[classes.length - 1];
  }

  /**
   * Predicts one label per row of `X` according to the configured strategy.
   *
   * @param X Samples to label; feature values are not inspected.
   * @returns One label per row.
   * @throws Error if not fitted, or if `X` is empty or has the wrong width.
   */
  predict(X: Matrix): Vector {
    const { classes, priors, majorityClass, constant, nFeatures } = this.ensureFitted();
    this.assertPredictInput(X, nFeatures);
    const nSamples = X.length;

    switch (this.strategy) {
      case "most_frequent":
      case "prior":
        return new Array<number>(nSamples).fill(majorityClass);
      case "constant":
        return new Array<number>(nSamples).fill(constant);
      case "uniform": {
        const rng = new Mulberry32(this.randomState);
        const out = new Array<number>(nSamples);
        for (let i = 0; i < nSamples; i += 1) {
          out[i] = classes[Math.floor(rng.next() * classes.length)];
        }
        return out;
      }
      case "stratified": {
        const rng = new Mulberry32(this.randomState);
        const out = new Array<number>(nSamples);
        for (let i = 0; i < nSamples; i += 1) {
          out[i] = this.sampleByPrior(rng, classes, priors);
        }
        return out;
      }
      default: {
        const exhaustive: never = this.strategy;
        throw new Error(`Unsupported strategy: ${exhaustive}`);
      }
    }
  }

  /**
   * Returns per-class probabilities for each row of `X`, with columns ordered
   * like `classes_`.
   *
   * - "uniform": equal probability for every class.
   * - "most_frequent" / "constant": one-hot row for the predicted label.
   * - "prior" / "stratified": the empirical class priors.
   *
   * @throws Error if not fitted, or if `X` is empty or has the wrong width.
   */
  predictProba(X: Matrix): Matrix {
    const { classes, priors, majorityClass, constant, nFeatures } = this.ensureFitted();
    this.assertPredictInput(X, nFeatures);

    if (this.strategy === "uniform") {
      const value = 1 / classes.length;
      return X.map(() => new Array<number>(classes.length).fill(value));
    }

    if (this.strategy === "most_frequent" || this.strategy === "constant") {
      const label = this.strategy === "constant" ? constant : majorityClass;
      // `fit` guarantees the label is one of the classes, so exactly one
      // column is set.
      const oneHot = classes.map((cls) => (cls === label ? 1 : 0));
      return X.map(() => [...oneHot]);
    }

    // prior / stratified share prior probabilities.
    return X.map(() => [...priors]);
  }

  /**
   * Accuracy of `predict(X)` against the true labels `y`.
   */
  score(X: Matrix, y: Vector): number {
    return accuracyScore(y, this.predict(X));
  }
}
|