bun-scikit 0.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. package/LICENSE +21 -0
  2. package/README.md +187 -0
  3. package/binding.gyp +21 -0
  4. package/docs/README.md +7 -0
  5. package/docs/native-abi.md +53 -0
  6. package/index.ts +1 -0
  7. package/package.json +76 -0
  8. package/scripts/build-node-addon.ts +26 -0
  9. package/scripts/build-zig-kernels.ts +50 -0
  10. package/scripts/check-api-docs-coverage.ts +52 -0
  11. package/scripts/check-benchmark-health.ts +140 -0
  12. package/scripts/install-native.ts +160 -0
  13. package/scripts/package-native-artifacts.ts +62 -0
  14. package/scripts/sync-benchmark-readme.ts +181 -0
  15. package/scripts/update-benchmark-history.ts +91 -0
  16. package/src/ensemble/RandomForestClassifier.ts +136 -0
  17. package/src/ensemble/RandomForestRegressor.ts +136 -0
  18. package/src/index.ts +32 -0
  19. package/src/linear_model/LinearRegression.ts +136 -0
  20. package/src/linear_model/LogisticRegression.ts +260 -0
  21. package/src/linear_model/SGDClassifier.ts +161 -0
  22. package/src/linear_model/SGDRegressor.ts +104 -0
  23. package/src/metrics/classification.ts +294 -0
  24. package/src/metrics/regression.ts +51 -0
  25. package/src/model_selection/GridSearchCV.ts +244 -0
  26. package/src/model_selection/KFold.ts +82 -0
  27. package/src/model_selection/RepeatedKFold.ts +49 -0
  28. package/src/model_selection/RepeatedStratifiedKFold.ts +50 -0
  29. package/src/model_selection/StratifiedKFold.ts +112 -0
  30. package/src/model_selection/StratifiedShuffleSplit.ts +211 -0
  31. package/src/model_selection/crossValScore.ts +165 -0
  32. package/src/model_selection/trainTestSplit.ts +82 -0
  33. package/src/naive_bayes/GaussianNB.ts +148 -0
  34. package/src/native/node-addon/bun_scikit_addon.cpp +450 -0
  35. package/src/native/zigKernels.ts +576 -0
  36. package/src/neighbors/KNeighborsClassifier.ts +85 -0
  37. package/src/pipeline/ColumnTransformer.ts +203 -0
  38. package/src/pipeline/FeatureUnion.ts +123 -0
  39. package/src/pipeline/Pipeline.ts +168 -0
  40. package/src/preprocessing/MinMaxScaler.ts +113 -0
  41. package/src/preprocessing/OneHotEncoder.ts +91 -0
  42. package/src/preprocessing/PolynomialFeatures.ts +158 -0
  43. package/src/preprocessing/RobustScaler.ts +149 -0
  44. package/src/preprocessing/SimpleImputer.ts +150 -0
  45. package/src/preprocessing/StandardScaler.ts +92 -0
  46. package/src/svm/LinearSVC.ts +117 -0
  47. package/src/tree/DecisionTreeClassifier.ts +394 -0
  48. package/src/tree/DecisionTreeRegressor.ts +407 -0
  49. package/src/types.ts +18 -0
  50. package/src/utils/linalg.ts +209 -0
  51. package/src/utils/validation.ts +78 -0
  52. package/zig/kernels.zig +1327 -0
@@ -0,0 +1,160 @@
1
+ import { access, mkdir, readFile, writeFile } from "node:fs/promises";
2
+ import { constants } from "node:fs";
3
+ import { resolve } from "node:path";
4
+
5
/**
 * Subset of `package.json` that the native install bootstrap reads.
 */
interface PackageJson {
  /** Package version; used to build the release tag and artifact file names. */
  version: string;
  /** Repository metadata; `url` is parsed for a GitHub `owner/repo` slug. */
  repository?: { url?: string };
}
11
+
12
/**
 * Translates a Node.js platform identifier into the OS label used in
 * prebuilt artifact names.
 *
 * @param value - Platform identifier, e.g. `process.platform`.
 * @returns `"windows"` for `win32`, `"linux"` for `linux`, otherwise `null`.
 */
function mapOs(value: NodeJS.Platform): "windows" | "linux" | null {
  switch (value) {
    case "win32":
      return "windows";
    case "linux":
      return "linux";
    default:
      return null;
  }
}
21
+
22
/**
 * Translates a CPU architecture identifier into the label used in prebuilt
 * artifact names.
 *
 * @param value - Architecture identifier, e.g. `process.arch`.
 * @returns The same label for `x64` and `arm64`, otherwise `null`.
 */
function mapArch(value: string): "x64" | "arm64" | null {
  switch (value) {
    case "x64":
      return "x64";
    case "arm64":
      return "arm64";
    default:
      return null;
  }
}
31
+
32
/**
 * Returns the shared-library file extension for the given OS label.
 *
 * @param osName - OS label as produced by `mapOs`.
 * @returns `"dll"` on Windows, `"so"` on Linux.
 */
function kernelExtension(osName: "windows" | "linux"): string {
  if (osName === "windows") {
    return "dll";
  }
  return "so";
}
35
+
36
/**
 * Extracts the GitHub `owner/repo` slug from a `package.json` repository URL.
 *
 * Accepts the common spellings of a GitHub remote: `https://github.com/o/r`,
 * `git+https://github.com/o/r.git`, `git@github.com:o/r.git`, and URLs that
 * carry a trailing path, `#fragment` or `?query`. The `.git` suffix is removed
 * from the repository name wherever the name ends, not only at the very end
 * of the URL.
 *
 * @param url - Repository URL, possibly undefined or empty.
 * @returns The `owner/repo` slug, or `null` when the URL is missing or does
 *   not point at github.com.
 */
function parseRepositorySlug(url: string | undefined): string | null {
  if (!url) {
    return null;
  }
  // `[/:]` covers both URL form (github.com/owner) and SSH form (github.com:owner).
  // The repo segment stops at `/`, `#`, `?` or whitespace so that fragments and
  // query strings never leak into the slug.
  const match = url.match(/github\.com[/:]([^/\s]+)\/([^/\s#?]+)/i);
  if (!match) {
    return null;
  }
  const owner = match[1] ?? "";
  const repo = (match[2] ?? "").replace(/\.git$/i, "");
  if (owner.length === 0 || repo.length === 0) {
    return null;
  }
  return `${owner}/${repo}`;
}
44
+
45
/**
 * Checks whether a path exists on disk.
 *
 * @param path - Path to probe.
 * @returns `true` when `access` succeeds with `F_OK`, `false` when it rejects
 *   for any reason.
 */
async function fileExists(path: string): Promise<boolean> {
  return access(path, constants.F_OK).then(
    () => true,
    () => false,
  );
}
53
+
54
/**
 * Downloads `url` and writes the response body to `path`.
 *
 * This is a best-effort helper for the install bootstrap: any failure is
 * reported as `false` instead of a rejection so that callers can fall back to
 * another strategy (for example a local build). `fetch` rejects on network
 * errors and malformed URLs rather than resolving with a non-OK response, so
 * those cases are caught here as well.
 *
 * @param url - Absolute URL of the asset to download.
 * @param path - Destination file path; its parent directory must exist.
 * @returns `true` when the asset was fetched with an OK status and written,
 *   `false` on a non-OK status, a network error, or a write failure.
 */
async function downloadToFile(url: string, path: string): Promise<boolean> {
  try {
    const response = await fetch(url);
    if (!response.ok) {
      return false;
    }
    const payload = Buffer.from(await response.arrayBuffer());
    await writeFile(path, payload);
    return true;
  } catch (error) {
    // Surface the cause but do not abort: the caller decides what to do next.
    const reason = error instanceof Error ? error.message : String(error);
    console.warn(`[bun-scikit] download failed for ${url}: ${reason}`);
    return false;
  }
}
63
+
64
/**
 * Attempts to fetch the prebuilt kernel library and Node-API addon for the
 * current platform from the GitHub release matching the package version.
 *
 * Reads `package.json` from the current working directory, derives the release
 * URL from its version and repository, and stores the artifacts under
 * `dist/native` using version-independent file names.
 *
 * @returns `true` only when both artifacts were downloaded; `false` when the
 *   platform or architecture is unsupported or either download did not succeed.
 */
async function tryDownloadPrebuilt(): Promise<boolean> {
  const manifest = JSON.parse(await readFile(resolve("package.json"), "utf-8")) as PackageJson;

  const osName = mapOs(process.platform);
  const arch = mapArch(process.arch);
  if (osName === null || arch === null) {
    console.log(
      `[bun-scikit] prebuilt binaries are unavailable for platform=${process.platform} arch=${process.arch}`,
    );
    return false;
  }

  // Fall back to the upstream repository when package.json has no usable URL.
  const repoSlug = parseRepositorySlug(manifest.repository?.url) ?? "Seyamalam/bun-scikit";
  const { version } = manifest;
  const ext = kernelExtension(osName);
  const releaseBase = `https://github.com/${repoSlug}/releases/download/v${version}`;
  const assetSuffix = `v${version}-${osName}-${arch}`;

  const outDir = resolve("dist", "native");
  await mkdir(outDir, { recursive: true });

  // [remote URL, local destination] for each artifact.
  const downloads: Array<[string, string]> = [
    [
      `${releaseBase}/bun_scikit_kernels-${assetSuffix}.${ext}`,
      resolve(outDir, `bun_scikit_kernels.${ext}`),
    ],
    [
      `${releaseBase}/bun_scikit_node_addon-${assetSuffix}.node`,
      resolve(outDir, "bun_scikit_node_addon.node"),
    ],
  ];

  const outcomes = await Promise.all(
    downloads.map(([remoteUrl, destination]) => downloadToFile(remoteUrl, destination)),
  );
  if (!outcomes.every((ok) => ok)) {
    return false;
  }

  console.log(`[bun-scikit] downloaded native prebuilt artifacts for ${osName}-${arch}`);
  return true;
}
106
+
107
/**
 * Builds the native artifacts locally by running the package's build scripts
 * through `bun run`, streaming their output to the current terminal.
 *
 * The kernel build (`native:build`) is mandatory; the Node-API addon build
 * (`native:build:node-addon`) is optional and only produces a warning when it
 * fails.
 *
 * @throws Error when `native:build` exits with a non-zero code.
 */
async function tryLocalBuild(): Promise<void> {
  console.log("[bun-scikit] prebuilt binaries not found; attempting local native build");

  // Runs one package script and resolves with its exit code.
  const runScript = (script: string): Promise<number> =>
    Bun.spawn(["bun", "run", script], {
      stdout: "inherit",
      stderr: "inherit",
    }).exited;

  const kernelExit = await runScript("native:build");
  if (kernelExit !== 0) {
    throw new Error(`native:build failed with exit code ${kernelExit}`);
  }

  const addonExit = await runScript("native:build:node-addon");
  if (addonExit !== 0) {
    console.warn(
      "[bun-scikit] Node-API addon local build failed; the Zig FFI backend can still be used.",
    );
  }
}
133
+
134
/**
 * Entry point of the native install bootstrap.
 *
 * Steps, in order:
 * 1. Do nothing when `BUN_SCIKIT_SKIP_NATIVE_INSTALL` is `"1"` or `CI` is `"true"`.
 * 2. Do nothing when a kernel library already exists under `dist/native`.
 * 3. Try to download prebuilt artifacts.
 * 4. Otherwise attempt a local build; a build failure is logged as a warning
 *    and does not reject.
 */
async function main(): Promise<void> {
  const skipRequested =
    process.env.BUN_SCIKIT_SKIP_NATIVE_INSTALL === "1" || process.env.CI === "true";
  if (skipRequested) {
    console.log("[bun-scikit] skipping native install bootstrap in CI/skip mode");
    return;
  }

  // Either platform's kernel library counts as "already installed".
  for (const kernelFile of ["bun_scikit_kernels.dll", "bun_scikit_kernels.so"]) {
    if (await fileExists(resolve("dist", "native", kernelFile))) {
      return;
    }
  }

  if (await tryDownloadPrebuilt()) {
    return;
  }

  try {
    await tryLocalBuild();
  } catch (error) {
    console.warn("[bun-scikit] native postinstall setup did not complete.", error);
  }
}

// Unexpected errors are only warned about; the script never sets a failing exit code itself.
main().catch((error) => {
  console.warn("[bun-scikit] native install script error:", error);
});
@@ -0,0 +1,62 @@
1
+ import { cp, mkdir, readFile } from "node:fs/promises";
2
+ import { resolve } from "node:path";
3
+
4
/**
 * Subset of `package.json` needed to name release assets.
 */
interface PackageJson {
  /** Package version embedded in each packaged asset file name. */
  version: string;
}
7
+
8
/**
 * Looks up the value that follows a command-line flag in `Bun.argv`.
 *
 * @param flag - Flag to search for, e.g. `"--os"`.
 * @returns The argument immediately after the first occurrence of `flag`, or
 *   `null` when the flag is absent or is the last argument.
 */
function argValue(flag: string): string | null {
  const args = Bun.argv;
  const flagIndex = args.indexOf(flag);
  const hasValue = flagIndex !== -1 && flagIndex + 1 < args.length;
  return hasValue ? args[flagIndex + 1] : null;
}
15
+
16
/**
 * Returns the shared-library file extension for a packaging target OS.
 *
 * @param osName - Target OS label; `"windows"` and `"linux"` are supported.
 * @returns `"dll"` for Windows, `"so"` for Linux.
 * @throws Error for any other OS label.
 */
function kernelExtension(osName: string): string {
  if (osName === "windows") {
    return "dll";
  }
  if (osName === "linux") {
    return "so";
  }
  throw new Error(`Unsupported OS for packaging assets: ${osName}`);
}
26
+
27
/**
 * Copies the built native artifacts from `dist/native` into
 * `dist/release-assets`, renaming them to include the package version, target
 * OS and architecture.
 *
 * Requires `--os` and `--arch` on the command line; the version is read from
 * `package.json` in the current working directory.
 *
 * @throws Error when `--os` or `--arch` is missing, or when the OS is not
 *   supported by `kernelExtension`.
 */
async function main(): Promise<void> {
  const osName = argValue("--os");
  const arch = argValue("--arch");
  if (!osName || !arch) {
    throw new Error("Usage: bun run scripts/package-native-artifacts.ts --os <linux|windows> --arch <x64|arm64>");
  }

  const { version } = JSON.parse(
    await readFile(resolve("package.json"), "utf-8"),
  ) as PackageJson;
  const extension = kernelExtension(osName);

  const nativeDir = resolve("dist", "native");
  const sourceKernel = resolve(nativeDir, `bun_scikit_kernels.${extension}`);
  const sourceAddon = resolve(nativeDir, "bun_scikit_node_addon.node");

  const outDir = resolve("dist", "release-assets");
  await mkdir(outDir, { recursive: true });

  // Shared `v<version>-<os>-<arch>` tag used in both asset names.
  const assetTag = `v${version}-${osName}-${arch}`;
  const targetKernel = resolve(outDir, `bun_scikit_kernels-${assetTag}.${extension}`);
  const targetAddon = resolve(outDir, `bun_scikit_node_addon-${assetTag}.node`);

  await cp(sourceKernel, targetKernel, { force: true });
  await cp(sourceAddon, targetAddon, { force: true });
  console.log(`Packaged release assets:\n- ${targetKernel}\n- ${targetAddon}`);
}

// Packaging failures must fail the calling workflow, hence the non-zero exit code.
main().catch((error) => {
  console.error(error);
  process.exit(1);
});
@@ -0,0 +1,181 @@
1
+ import { readFile, writeFile } from "node:fs/promises";
2
+ import { resolve } from "node:path";
3
+
4
/**
 * Fields common to every benchmark result row rendered into the README.
 */
interface SharedBenchmarkResult {
  /** Implementation label shown in the "Implementation" column. */
  implementation: string;
  /** Model label shown in the "Model" column. */
  model: string;
  /** Value rendered in the "Fit median (ms)" column. */
  fitMsMedian: number;
  /** Value rendered in the "Predict median (ms)" column. */
  predictMsMedian: number;
}

/**
 * Result row of the regression suite.
 */
interface RegressionBenchmarkResult extends SharedBenchmarkResult {
  /** Value rendered in the "MSE" column. */
  mse: number;
  /** Value rendered in the "R2" column. */
  r2: number;
}

/**
 * Result row of the classification and tree-classification suites.
 */
interface ClassificationBenchmarkResult extends SharedBenchmarkResult {
  /** Value rendered in the "Accuracy" column. */
  accuracy: number;
  /** Value rendered in the "F1" column. */
  f1: number;
}
20
+
21
/** Identifier of a tree-based model in the tree-classification suite. */
type TreeModelKey = "decision_tree" | "random_forest";

/**
 * Paired bun / scikit-learn results for one tree model, together with the
 * derived comparison metrics.
 */
interface TreeModelComparison {
  /** Which tree model this entry describes. */
  key: TreeModelKey;
  /** Result row for the bun implementation. */
  bun: ClassificationBenchmarkResult;
  /** Result row for the scikit-learn implementation. */
  sklearn: ClassificationBenchmarkResult;
  /** Derived metrics, rendered as the "speedup" and "delta (bun - sklearn)" summary lines. */
  comparison: {
    fitSpeedupVsSklearn: number;
    predictSpeedupVsSklearn: number;
    accuracyDeltaVsSklearn: number;
    f1DeltaVsSklearn: number;
  };
}
34
+
35
/**
 * Shape of the benchmark snapshot JSON that is rendered into the README.
 *
 * The JSON is cast to this type after parsing; it is not validated at runtime.
 */
interface BenchmarkSnapshot {
  /** Timestamp string rendered on the "Snapshot generated at" line. */
  generatedAt: string;
  /** Description of the benchmark dataset, rendered on the "Dataset" line. */
  dataset: {
    path: string;
    samples: number;
    features: number;
    testFraction: number;
  };
  /** Results grouped by benchmark suite. */
  suites: {
    /** Regression suite; `results` is destructured as `[bun, sklearn]` by the renderer. */
    regression: {
      results: [RegressionBenchmarkResult, RegressionBenchmarkResult];
      comparison: {
        fitSpeedupVsSklearn: number;
        predictSpeedupVsSklearn: number;
        mseDeltaVsSklearn: number;
        r2DeltaVsSklearn: number;
      };
    };
    /** Classification suite; `results` is destructured as `[bun, sklearn]` by the renderer. */
    classification: {
      results: [ClassificationBenchmarkResult, ClassificationBenchmarkResult];
      comparison: {
        fitSpeedupVsSklearn: number;
        predictSpeedupVsSklearn: number;
        accuracyDeltaVsSklearn: number;
        f1DeltaVsSklearn: number;
      };
    };
    /** Tree suite; `models` is destructured as `[decisionTree, randomForest]` by the renderer. */
    treeClassification: {
      models: [TreeModelComparison, TreeModelComparison];
    };
  };
}
67
+
68
/** HTML comment that opens the generated benchmark section in the README. */
const START_MARKER = "<!-- BENCHMARK_TABLE_START -->";
/** HTML comment that closes the generated benchmark section in the README. */
const END_MARKER = "<!-- BENCHMARK_TABLE_END -->";
/** Absolute path of the README, resolved against the current working directory. */
const README_PATH = resolve("README.md");
/** Snapshot file used when `--input` is not given. */
const DEFAULT_SNAPSHOT_PATH = resolve("bench", "results", "heart-ci-latest.json");
72
+
73
/**
 * Reads the value that follows a command-line flag in `Bun.argv`.
 *
 * @param flag - Flag to look for, e.g. `"--input"`.
 * @returns The argument right after the first occurrence of `flag`, or `null`
 *   when the flag is missing or has nothing after it.
 */
function parseArgValue(flag: string): string | null {
  const position = Bun.argv.indexOf(flag);
  if (position === -1) {
    return null;
  }
  const valuePosition = position + 1;
  return valuePosition < Bun.argv.length ? Bun.argv[valuePosition] : null;
}
80
+
81
/**
 * Converts Windows line endings to Unix line endings.
 *
 * @param content - Text that may contain `\r\n` sequences.
 * @returns `content` with every `\r\n` replaced by `\n`; lone `\r` characters
 *   are left untouched.
 */
function normalizeLineEndings(content: string): string {
  return content.split("\r\n").join("\n");
}
84
+
85
/**
 * Renders the README benchmark section, including the surrounding start and
 * end markers, from a benchmark snapshot.
 *
 * The section contains one Markdown table per suite (regression,
 * classification, tree classification), each followed by speedup and delta
 * summary lines, and ends with the snapshot generation timestamp.
 *
 * @param snapshot - Parsed benchmark snapshot.
 * @returns The section as a single string with `\n` line separators.
 */
function renderBenchmarkSection(snapshot: BenchmarkSnapshot): string {
  const { regression, classification, treeClassification } = snapshot.suites;
  const [bunReg, sklearnReg] = regression.results;
  const [bunCls, sklearnCls] = classification.results;
  const [decisionTree, randomForest] = treeClassification.models;

  // Formats cells as one Markdown table row: "| a | b | c |".
  const tableRow = (cells: string[]): string => `| ${cells.join(" | ")} |`;

  // Fit/predict timing cells shared by every table.
  const timingCells = (result: SharedBenchmarkResult): string[] => [
    result.fitMsMedian.toFixed(4),
    result.predictMsMedian.toFixed(4),
  ];

  const regressionRow = (result: RegressionBenchmarkResult): string =>
    tableRow([
      result.implementation,
      result.model,
      ...timingCells(result),
      result.mse.toFixed(6),
      result.r2.toFixed(6),
    ]);

  const scoreCells = (result: ClassificationBenchmarkResult): string[] => [
    ...timingCells(result),
    result.accuracy.toFixed(6),
    result.f1.toFixed(6),
  ];

  // Classification table lists the implementation first ...
  const classificationRow = (result: ClassificationBenchmarkResult): string =>
    tableRow([result.implementation, result.model, ...scoreCells(result)]);

  // ... while the tree table lists the model first.
  const treeRow = (result: ClassificationBenchmarkResult): string =>
    tableRow([result.model, result.implementation, ...scoreCells(result)]);

  const speedupLines = (
    label: string,
    comparison: { fitSpeedupVsSklearn: number; predictSpeedupVsSklearn: number },
  ): string[] => [
    `${label} fit speedup vs scikit-learn: ${comparison.fitSpeedupVsSklearn.toFixed(3)}x`,
    `${label} predict speedup vs scikit-learn: ${comparison.predictSpeedupVsSklearn.toFixed(3)}x`,
  ];

  const treeSummary = (label: string, model: TreeModelComparison): string[] => [
    ...speedupLines(label, model.comparison),
    `${label} accuracy delta (bun - sklearn): ${model.comparison.accuracyDeltaVsSklearn.toExponential(3)}`,
    `${label} f1 delta (bun - sklearn): ${model.comparison.f1DeltaVsSklearn.toExponential(3)}`,
  ];

  const { dataset } = snapshot;
  const lines: string[] = [
    START_MARKER,
    "Benchmark snapshot source: `bench/results/heart-ci-latest.json` (generated in CI workflow `Benchmark Snapshot`).",
    `Dataset: \`${dataset.path}\` (${dataset.samples} samples, ${dataset.features} features, test fraction ${dataset.testFraction}).`,
    "",
    "### Regression",
    "",
    "| Implementation | Model | Fit median (ms) | Predict median (ms) | MSE | R2 |",
    "|---|---|---:|---:|---:|---:|",
    regressionRow(bunReg),
    regressionRow(sklearnReg),
    "",
    ...speedupLines("Bun", regression.comparison),
    `MSE delta (bun - sklearn): ${regression.comparison.mseDeltaVsSklearn.toExponential(3)}`,
    `R2 delta (bun - sklearn): ${regression.comparison.r2DeltaVsSklearn.toExponential(3)}`,
    "",
    "### Classification",
    "",
    "| Implementation | Model | Fit median (ms) | Predict median (ms) | Accuracy | F1 |",
    "|---|---|---:|---:|---:|---:|",
    classificationRow(bunCls),
    classificationRow(sklearnCls),
    "",
    ...speedupLines("Bun", classification.comparison),
    `Accuracy delta (bun - sklearn): ${classification.comparison.accuracyDeltaVsSklearn.toExponential(3)}`,
    `F1 delta (bun - sklearn): ${classification.comparison.f1DeltaVsSklearn.toExponential(3)}`,
    "",
    "### Tree Classification",
    "",
    "| Model | Implementation | Fit median (ms) | Predict median (ms) | Accuracy | F1 |",
    "|---|---|---:|---:|---:|---:|",
    treeRow(decisionTree.bun),
    treeRow(decisionTree.sklearn),
    treeRow(randomForest.bun),
    treeRow(randomForest.sklearn),
    "",
    ...treeSummary("DecisionTree", decisionTree),
    "",
    ...treeSummary("RandomForest", randomForest),
    "",
    `Snapshot generated at: ${snapshot.generatedAt}`,
    END_MARKER,
  ];

  return lines.join("\n");
}
145
+
146
// Script body: regenerate the README benchmark section from a snapshot file.
// With `--check` the README is only compared, never written.
const inputPath = resolve(parseArgValue("--input") ?? DEFAULT_SNAPSHOT_PATH);
const checkMode = Bun.argv.includes("--check");

const [readme, snapshotRaw] = await Promise.all([
  readFile(README_PATH, "utf-8"),
  readFile(inputPath, "utf-8"),
]);
const snapshot = JSON.parse(snapshotRaw) as BenchmarkSnapshot;

// Both markers must exist, and the end marker must not precede the start marker.
const startIndex = readme.indexOf(START_MARKER);
const endIndex = readme.indexOf(END_MARKER);
const markersValid = startIndex !== -1 && endIndex !== -1 && endIndex >= startIndex;
if (!markersValid) {
  throw new Error(
    `README markers are missing or invalid. Expected markers: ${START_MARKER} and ${END_MARKER}.`,
  );
}

// Replace everything from the start marker through the end marker (inclusive).
const beforeSection = readme.slice(0, startIndex);
const afterSection = readme.slice(endIndex + END_MARKER.length);
const nextReadme = `${beforeSection}${renderBenchmarkSection(snapshot)}${afterSection}`;

if (checkMode) {
  // Line endings are normalized so a CRLF checkout does not count as drift.
  const upToDate = normalizeLineEndings(nextReadme) === normalizeLineEndings(readme);
  if (!upToDate) {
    console.error(
      "README benchmark section is out of date. Run: bun run bench:sync-readme",
    );
    process.exit(1);
  }
  console.log("README benchmark section is up to date.");
  process.exit(0);
}

await writeFile(README_PATH, nextReadme, "utf-8");
console.log(`Updated README benchmark section from ${inputPath}`);
@@ -0,0 +1,91 @@
1
+ import { mkdir, readFile, writeFile } from "node:fs/promises";
2
+ import { dirname, resolve } from "node:path";
3
+
4
/**
 * Portion of the benchmark snapshot JSON that is copied into the history file.
 *
 * The JSON is cast to this type after parsing; it is not validated at runtime.
 */
interface BenchmarkSnapshot {
  /** Snapshot timestamp string; also used to detect an already-recorded snapshot. */
  generatedAt: string;
  /** Benchmark run settings, copied into each history record unchanged. */
  benchmarkConfig: {
    iterations: number;
    warmup: number;
  };
  /** Per-suite comparison metrics against scikit-learn. */
  suites: {
    regression: {
      comparison: {
        fitSpeedupVsSklearn: number;
        predictSpeedupVsSklearn: number;
        mseDeltaVsSklearn: number;
        r2DeltaVsSklearn: number;
      };
    };
    classification: {
      comparison: {
        fitSpeedupVsSklearn: number;
        predictSpeedupVsSklearn: number;
        accuracyDeltaVsSklearn: number;
        f1DeltaVsSklearn: number;
      };
    };
    treeClassification: {
      /** One entry per tree model; `key` identifies the model. */
      models: Array<{
        key: string;
        comparison: {
          fitSpeedupVsSklearn: number;
          predictSpeedupVsSklearn: number;
          accuracyDeltaVsSklearn: number;
          f1DeltaVsSklearn: number;
        };
      }>;
    };
  };
}
40
+
41
/**
 * Returns the command-line argument that follows `flag` in `Bun.argv`.
 *
 * @param flag - Flag to look for, e.g. `"--output"`.
 * @returns The argument after the first occurrence of `flag`, or `null` when
 *   the flag is not present or is the final argument.
 */
function parseArgValue(flag: string): string | null {
  const argv = Bun.argv;
  const next = argv.indexOf(flag) + 1;
  // `next === 0` means indexOf returned -1, i.e. the flag is absent.
  if (next === 0 || next >= argv.length) {
    return null;
  }
  return argv[next];
}
48
+
49
// Script body: append the comparison metrics of a benchmark snapshot to a
// JSON Lines history file, one record per line. The snapshot is skipped when
// the last recorded line already has the same `generatedAt`.
const inputPath = resolve(
  parseArgValue("--input") ?? "bench/results/heart-ci-latest.json",
);
const outputPath = resolve(
  parseArgValue("--output") ?? "bench/results/history/heart-ci-history.jsonl",
);

const snapshot = JSON.parse(await readFile(inputPath, "utf-8")) as BenchmarkSnapshot;

// Keep only the comparison metrics; tree models are flattened to `{ key, ...comparison }`.
const historyRecord = {
  generatedAt: snapshot.generatedAt,
  benchmarkConfig: snapshot.benchmarkConfig,
  regression: snapshot.suites.regression.comparison,
  classification: snapshot.suites.classification.comparison,
  treeClassification: snapshot.suites.treeClassification.models.map((model) => ({
    key: model.key,
    ...model.comparison,
  })),
};

let existing = "";
try {
  existing = await readFile(outputPath, "utf-8");
} catch (error) {
  // Only a missing file means "no history yet". Any other read failure must
  // abort: the write below rewrites the whole file, so continuing with an
  // empty string would discard the existing history.
  const code = error instanceof Error ? (error as NodeJS.ErrnoException).code : undefined;
  if (code !== "ENOENT") {
    throw error;
  }
}

const lines = existing
  .split(/\r?\n/)
  .map((line) => line.trim())
  .filter((line) => line.length > 0);

if (lines.length > 0) {
  const last = JSON.parse(lines[lines.length - 1]) as { generatedAt?: string };
  if (last.generatedAt === historyRecord.generatedAt) {
    console.log(`History already contains snapshot ${historyRecord.generatedAt}.`);
    process.exit(0);
  }
}

// A history file without a trailing newline would otherwise get the new
// record appended to its last line, producing an unparsable JSONL entry.
const separator = existing.length > 0 && !existing.endsWith("\n") ? "\n" : "";

await mkdir(dirname(outputPath), { recursive: true });
await writeFile(
  outputPath,
  `${existing}${separator}${JSON.stringify(historyRecord)}\n`,
  "utf-8",
);
console.log(`Appended benchmark history: ${outputPath}`);
@@ -0,0 +1,136 @@
1
+ import type { ClassificationModel, Matrix, Vector } from "../types";
2
+ import { accuracyScore } from "../metrics/classification";
3
+ import { DecisionTreeClassifier, type MaxFeaturesOption } from "../tree/DecisionTreeClassifier";
4
+ import { assertFiniteVector, validateClassificationInputs } from "../utils/validation";
5
+
6
/**
 * Construction options for `RandomForestClassifier`.
 *
 * Every field is optional. The defaults quoted below are the ones applied by
 * the `RandomForestClassifier` constructor.
 */
export interface RandomForestClassifierOptions {
  /** Number of trees in the forest; must be a positive integer. Default `50`. */
  nEstimators?: number;
  /** Forwarded to each tree as `maxDepth`. Default `12`. */
  maxDepth?: number;
  /** Forwarded to each tree as `minSamplesSplit`. Default `2`. */
  minSamplesSplit?: number;
  /** Forwarded to each tree as `minSamplesLeaf`. Default `1`. */
  minSamplesLeaf?: number;
  /** Forwarded to each tree as `maxFeatures`. Default `"sqrt"`. */
  maxFeatures?: MaxFeaturesOption;
  /** When `true`, each tree is fitted on indices drawn with replacement. Default `true`. */
  bootstrap?: boolean;
  /** Seed for the forest's sampling and for the per-tree seeds; unseeded when omitted. */
  randomState?: number;
}
15
+
16
/**
 * Creates a seeded pseudo-random number generator (mulberry32).
 *
 * The internal state is wrapped to 32 bits after every step. Without the wrap
 * the state grows by `0x6d2b79f5` per call as a double, passes 2^53 after
 * roughly 4.9 million calls, and from then on the additions lose their
 * low-order bits, which corrupts the sequence. Wrapping yields exactly the
 * same outputs before that point and keeps the generator correct after it.
 *
 * @param seed - Seed value; only its low 32 bits are used.
 * @returns A function that returns the next value in `[0, 1)` on each call.
 */
function mulberry32(seed: number): () => number {
  let state = seed >>> 0;
  return () => {
    // Keep the state an exact unsigned 32-bit integer.
    state = (state + 0x6d2b79f5) >>> 0;
    let t = Math.imul(state ^ (state >>> 15), 1 | state);
    t ^= t + Math.imul(t ^ (t >>> 7), 61 | t);
    // Map the unsigned 32-bit result onto [0, 1).
    return ((t ^ (t >>> 14)) >>> 0) / 4294967296;
  };
}
25
+
26
/**
 * Random forest for binary classification built from `DecisionTreeClassifier`
 * instances.
 *
 * Targets are treated as binary: a label equal to `1` is the positive class
 * and every other label is encoded as `0`. Predictions are a majority vote
 * over the trees, with ties resolved in favour of class `1`.
 */
export class RandomForestClassifier implements ClassificationModel {
  /** Labels that `predict` can return. */
  classes_: Vector = [0, 1];
  private readonly nEstimators: number;
  private readonly maxDepth?: number;
  private readonly minSamplesSplit?: number;
  private readonly minSamplesLeaf?: number;
  private readonly maxFeatures: MaxFeaturesOption;
  private readonly bootstrap: boolean;
  private readonly randomState?: number;
  /** Fitted trees; empty until `fit` has completed successfully. */
  private trees: DecisionTreeClassifier[] = [];

  /**
   * @param options - Forest configuration; omitted fields fall back to
   *   `nEstimators = 50`, `maxDepth = 12`, `minSamplesSplit = 2`,
   *   `minSamplesLeaf = 1`, `maxFeatures = "sqrt"`, `bootstrap = true`.
   * @throws Error when `nEstimators` is not a positive integer.
   */
  constructor(options: RandomForestClassifierOptions = {}) {
    this.nEstimators = options.nEstimators ?? 50;
    this.maxDepth = options.maxDepth ?? 12;
    this.minSamplesSplit = options.minSamplesSplit ?? 2;
    this.minSamplesLeaf = options.minSamplesLeaf ?? 1;
    this.maxFeatures = options.maxFeatures ?? "sqrt";
    this.bootstrap = options.bootstrap ?? true;
    this.randomState = options.randomState;

    if (!Number.isInteger(this.nEstimators) || this.nEstimators < 1) {
      throw new Error(`nEstimators must be a positive integer. Got ${this.nEstimators}.`);
    }
  }

  /**
   * Fits `nEstimators` trees on the training data.
   *
   * With `bootstrap` enabled each tree receives `X.length` row indices drawn
   * with replacement; otherwise each tree receives every row once. When
   * `randomState` is set, sampling is seeded with it and tree `i` is seeded
   * with `randomState + i + 1`.
   *
   * The fitted trees replace the current model only after every tree has been
   * fitted, so a failure part-way through leaves the previous state intact
   * instead of a partially filled forest.
   *
   * @param X - Training feature matrix, one row per sample.
   * @param y - Training labels, one per row of `X`.
   * @returns This instance, for chaining.
   */
  fit(X: Matrix, y: Vector): this {
    validateClassificationInputs(X, y);

    const sampleCount = X.length;
    const featureCount = X[0].length;
    const random = this.randomState === undefined ? Math.random : mulberry32(this.randomState);
    // Computed once and shared by all trees.
    const flattenedX = this.flattenTrainingMatrix(X, sampleCount, featureCount);
    const yBinary = this.buildBinaryTargets(y);
    const fittedTrees: DecisionTreeClassifier[] = [];

    for (let estimatorIndex = 0; estimatorIndex < this.nEstimators; estimatorIndex += 1) {
      const sampleIndices = new Uint32Array(sampleCount);
      for (let i = 0; i < sampleCount; i += 1) {
        sampleIndices[i] = this.bootstrap ? Math.floor(random() * sampleCount) : i;
      }

      const tree = new DecisionTreeClassifier({
        maxDepth: this.maxDepth,
        minSamplesSplit: this.minSamplesSplit,
        minSamplesLeaf: this.minSamplesLeaf,
        maxFeatures: this.maxFeatures,
        randomState:
          this.randomState === undefined ? undefined : this.randomState + estimatorIndex + 1,
      });
      // NOTE(review): the meaning of the fourth argument (`true`) is defined by
      // DecisionTreeClassifier.fit, which is outside this file — confirm there.
      tree.fit(X, y, sampleIndices, true, flattenedX, yBinary);
      fittedTrees.push(tree);
    }

    this.trees = fittedTrees;
    return this;
  }

  /**
   * Predicts a label for each row of `X` by majority vote.
   *
   * @param X - Feature matrix, one row per sample.
   * @returns One label per row: `1` when at least half of the trees predict
   *   `1`, otherwise `0`.
   * @throws Error when the forest has not been fitted.
   */
  predict(X: Matrix): Vector {
    if (this.trees.length === 0) {
      throw new Error("RandomForestClassifier has not been fitted.");
    }

    const treePredictions = this.trees.map((tree) => tree.predict(X));
    const sampleCount = X.length;
    const predictions = new Array(sampleCount).fill(0);

    for (let sampleIndex = 0; sampleIndex < sampleCount; sampleIndex += 1) {
      let positiveVotes = 0;
      for (const treeOutput of treePredictions) {
        if (treeOutput[sampleIndex] === 1) {
          positiveVotes += 1;
        }
      }
      // `>=` means an exact tie is decided in favour of class 1.
      predictions[sampleIndex] = positiveVotes * 2 >= this.trees.length ? 1 : 0;
    }

    return predictions;
  }

  /**
   * Computes the accuracy of `predict(X)` against `y`.
   *
   * @param X - Feature matrix, one row per sample.
   * @param y - Expected labels; checked with `assertFiniteVector` first.
   * @returns The value returned by `accuracyScore`.
   */
  score(X: Matrix, y: Vector): number {
    assertFiniteVector(y);
    return accuracyScore(y, this.predict(X));
  }

  /**
   * Copies `X` into a single row-major `Float64Array` of length
   * `sampleCount * featureCount`.
   */
  private flattenTrainingMatrix(
    X: Matrix,
    sampleCount: number,
    featureCount: number,
  ): Float64Array {
    const flattened = new Float64Array(sampleCount * featureCount);
    for (let i = 0; i < sampleCount; i += 1) {
      const row = X[i];
      const rowOffset = i * featureCount;
      for (let j = 0; j < featureCount; j += 1) {
        flattened[rowOffset + j] = row[j];
      }
    }
    return flattened;
  }

  /** Encodes labels as bytes: `1` where the label equals `1`, `0` otherwise. */
  private buildBinaryTargets(y: Vector): Uint8Array {
    const encoded = new Uint8Array(y.length);
    for (let i = 0; i < y.length; i += 1) {
      encoded[i] = y[i] === 1 ? 1 : 0;
    }
    return encoded;
  }
}