bun-scikit 0.1.4 → 0.1.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "bun-scikit",
3
- "version": "0.1.4",
3
+ "version": "0.1.5",
4
4
  "description": "A scikit-learn-inspired machine learning library for Bun/TypeScript.",
5
5
  "license": "MIT",
6
6
  "module": "index.ts",
@@ -69,6 +69,7 @@
69
69
  "devDependencies": {
70
70
  "@types/bun": "latest",
71
71
  "node-addon-api": "^8.3.1",
72
+ "node-gyp": "^12.2.0",
72
73
  "typedoc": "^0.28.14",
73
74
  "typescript": "^5.9.2"
74
75
  }
@@ -1,8 +1,24 @@
1
1
  import { cp, mkdir } from "node:fs/promises";
2
+ import { createRequire } from "node:module";
2
3
  import { resolve } from "node:path";
3
4
 
5
+ function resolveNodeGypCommand(): string[] {
6
+ const npmNodeGyp = process.env.npm_config_node_gyp?.trim();
7
+ if (npmNodeGyp) {
8
+ return ["node", npmNodeGyp, "rebuild"];
9
+ }
10
+
11
+ try {
12
+ const require = createRequire(import.meta.url);
13
+ const nodeGypScript = require.resolve("node-gyp/bin/node-gyp.js");
14
+ return ["node", nodeGypScript, "rebuild"];
15
+ } catch {
16
+ return ["node-gyp", "rebuild"];
17
+ }
18
+ }
19
+
4
20
  async function main(): Promise<void> {
5
- const child = Bun.spawn(["bunx", "node-gyp", "rebuild"], {
21
+ const child = Bun.spawn(resolveNodeGypCommand(), {
6
22
  stdout: "inherit",
7
23
  stderr: "inherit",
8
24
  });
@@ -80,6 +80,14 @@ interface BenchmarkSnapshot {
80
80
  };
81
81
  }
82
82
 
83
+ function parseArgValue(flag: string): string | null {
84
+ const index = Bun.argv.indexOf(flag);
85
+ if (index === -1 || index + 1 >= Bun.argv.length) {
86
+ return null;
87
+ }
88
+ return Bun.argv[index + 1];
89
+ }
90
+
83
91
  function speedupThreshold(
84
92
  envName: string,
85
93
  defaultValue: number,
@@ -95,13 +103,18 @@ function speedupThreshold(
95
103
  return parsed;
96
104
  }
97
105
 
98
- const pathArgIndex = Bun.argv.indexOf("--input");
99
- const inputPath =
100
- pathArgIndex !== -1 && pathArgIndex + 1 < Bun.argv.length
101
- ? resolve(Bun.argv[pathArgIndex + 1])
102
- : resolve("bench/results/heart-ci-current.json");
106
+ const inputPath = resolve(parseArgValue("--input") ?? "bench/results/heart-ci-current.json");
107
+ const baselinePath = resolve(
108
+ parseArgValue("--baseline") ?? process.env.BENCH_BASELINE_INPUT ?? "bench/results/heart-ci-latest.json",
109
+ );
110
+ const baselineInputEnabled = inputPath !== baselinePath;
103
111
 
104
112
  const snapshot = JSON.parse(await readFile(inputPath, "utf-8")) as BenchmarkSnapshot;
113
+ const baselineSnapshot = baselineInputEnabled
114
+ ? ((await readFile(baselinePath, "utf-8").then((raw) => JSON.parse(raw) as BenchmarkSnapshot).catch(
115
+ () => null,
116
+ )) as BenchmarkSnapshot | null)
117
+ : null;
105
118
 
106
119
  const [bunRegression, sklearnRegression] = snapshot.suites.regression.results;
107
120
  const [bunClassification, sklearnClassification] = snapshot.suites.classification.results;
@@ -136,6 +149,14 @@ const maxZigForestPredictSlowdownVsJs = speedupThreshold(
136
149
  "BENCH_MAX_ZIG_FOREST_PREDICT_SLOWDOWN_VS_JS",
137
150
  20,
138
151
  );
152
+ const minZigTreeFitRetentionVsBaseline = speedupThreshold(
153
+ "BENCH_MIN_ZIG_TREE_FIT_RETENTION_VS_BASELINE",
154
+ 0.9,
155
+ );
156
+ const minZigForestFitRetentionVsBaseline = speedupThreshold(
157
+ "BENCH_MIN_ZIG_FOREST_FIT_RETENTION_VS_BASELINE",
158
+ 0.9,
159
+ );
139
160
 
140
161
  for (const result of [
141
162
  bunRegression,
@@ -296,6 +317,30 @@ if (snapshot.suites.treeBackendModes.enabled) {
296
317
  `RandomForest zig predict slowdown too large vs js-fast: ${randomForestPredictSlowdown} > ${maxZigForestPredictSlowdownVsJs}.`,
297
318
  );
298
319
  }
320
+
321
+ if (baselineSnapshot?.suites?.treeBackendModes?.enabled) {
322
+ const [baselineDecisionTreeModes, baselineRandomForestModes] =
323
+ baselineSnapshot.suites.treeBackendModes.models;
324
+ if (baselineDecisionTreeModes && baselineRandomForestModes) {
325
+ const decisionTreeFitRetention =
326
+ decisionTreeModes.comparison.zigFitSpeedupVsJs /
327
+ baselineDecisionTreeModes.comparison.zigFitSpeedupVsJs;
328
+ const randomForestFitRetention =
329
+ randomForestModes.comparison.zigFitSpeedupVsJs /
330
+ baselineRandomForestModes.comparison.zigFitSpeedupVsJs;
331
+
332
+ if (decisionTreeFitRetention < minZigTreeFitRetentionVsBaseline) {
333
+ throw new Error(
334
+ `DecisionTree zig/js fit retention too low vs baseline: ${decisionTreeFitRetention} < ${minZigTreeFitRetentionVsBaseline}.`,
335
+ );
336
+ }
337
+ if (randomForestFitRetention < minZigForestFitRetentionVsBaseline) {
338
+ throw new Error(
339
+ `RandomForest zig/js fit retention too low vs baseline: ${randomForestFitRetention} < ${minZigForestFitRetentionVsBaseline}.`,
340
+ );
341
+ }
342
+ }
343
+ }
299
344
  }
300
345
 
301
346
  console.log("Benchmark comparison health checks passed.");
@@ -2,6 +2,7 @@ import type { ClassificationModel, Matrix, Vector } from "../types";
2
2
  import { accuracyScore } from "../metrics/classification";
3
3
  import { DecisionTreeClassifier, type MaxFeaturesOption } from "../tree/DecisionTreeClassifier";
4
4
  import { assertFiniteVector, validateClassificationInputs } from "../utils/validation";
5
+ import { getZigKernels } from "../native/zigKernels";
5
6
 
6
7
  export interface RandomForestClassifierOptions {
7
8
  nEstimators?: number;
@@ -23,8 +24,18 @@ function mulberry32(seed: number): () => number {
23
24
  };
24
25
  }
25
26
 
27
+ function isTruthy(value: string | undefined): boolean {
28
+ if (!value) {
29
+ return false;
30
+ }
31
+ const normalized = value.trim().toLowerCase();
32
+ return !(normalized === "0" || normalized === "false" || normalized === "off");
33
+ }
34
+
26
35
  export class RandomForestClassifier implements ClassificationModel {
27
36
  classes_: Vector = [0, 1];
37
+ fitBackend_: "zig" | "js" = "js";
38
+ fitBackendLibrary_: string | null = null;
28
39
  private readonly nEstimators: number;
29
40
  private readonly maxDepth?: number;
30
41
  private readonly minSamplesSplit?: number;
@@ -32,6 +43,7 @@ export class RandomForestClassifier implements ClassificationModel {
32
43
  private readonly maxFeatures: MaxFeaturesOption;
33
44
  private readonly bootstrap: boolean;
34
45
  private readonly randomState?: number;
46
+ private nativeModelHandle: bigint | null = null;
35
47
  private trees: DecisionTreeClassifier[] = [];
36
48
 
37
49
  constructor(options: RandomForestClassifierOptions = {}) {
@@ -49,6 +61,7 @@ export class RandomForestClassifier implements ClassificationModel {
49
61
  }
50
62
 
51
63
  fit(X: Matrix, y: Vector): this {
64
+ this.disposeNativeModel();
52
65
  validateClassificationInputs(X, y);
53
66
 
54
67
  const sampleCount = X.length;
@@ -56,10 +69,17 @@ export class RandomForestClassifier implements ClassificationModel {
56
69
  const random = this.randomState === undefined ? Math.random : mulberry32(this.randomState);
57
70
  const flattenedX = this.flattenTrainingMatrix(X, sampleCount, featureCount);
58
71
  const yBinary = this.buildBinaryTargets(y);
72
+ const sampleIndices = new Uint32Array(sampleCount);
73
+ this.trees = [];
74
+ if (this.tryFitNativeForest(flattenedX, yBinary, sampleCount, featureCount)) {
75
+ this.fitBackend_ = "zig";
76
+ return this;
77
+ }
78
+ this.fitBackend_ = "js";
79
+ this.fitBackendLibrary_ = null;
59
80
  this.trees = new Array(this.nEstimators);
60
81
 
61
82
  for (let estimatorIndex = 0; estimatorIndex < this.nEstimators; estimatorIndex += 1) {
62
- const sampleIndices = new Uint32Array(sampleCount);
63
83
  if (this.bootstrap) {
64
84
  for (let i = 0; i < sampleCount; i += 1) {
65
85
  sampleIndices[i] = Math.floor(random() * sampleCount);
@@ -86,20 +106,47 @@ export class RandomForestClassifier implements ClassificationModel {
86
106
  }
87
107
 
88
108
  predict(X: Matrix): Vector {
109
+ if (this.nativeModelHandle !== null) {
110
+ const kernels = getZigKernels();
111
+ const predict = kernels?.randomForestClassifierModelPredict;
112
+ if (predict) {
113
+ const sampleCount = X.length;
114
+ const featureCount = X[0]?.length ?? 0;
115
+ const flattened = this.flattenTrainingMatrix(X, sampleCount, featureCount);
116
+ const out = new Uint8Array(sampleCount);
117
+ const status = predict(
118
+ this.nativeModelHandle,
119
+ flattened,
120
+ sampleCount,
121
+ featureCount,
122
+ out,
123
+ );
124
+ if (status === 1) {
125
+ return Array.from(out);
126
+ }
127
+ }
128
+ }
129
+
89
130
  if (this.trees.length === 0) {
90
131
  throw new Error("RandomForestClassifier has not been fitted.");
91
132
  }
92
133
 
93
- const treePredictions = this.trees.map((tree) => tree.predict(X));
94
134
  const sampleCount = X.length;
95
- const predictions = new Array(sampleCount).fill(0);
135
+ const voteCounts = new Uint16Array(sampleCount);
96
136
 
97
- for (let sampleIndex = 0; sampleIndex < sampleCount; sampleIndex += 1) {
98
- let positiveVotes = 0;
99
- for (let treeIndex = 0; treeIndex < treePredictions.length; treeIndex += 1) {
100
- positiveVotes += treePredictions[treeIndex][sampleIndex] === 1 ? 1 : 0;
137
+ for (let treeIndex = 0; treeIndex < this.trees.length; treeIndex += 1) {
138
+ const treePrediction = this.trees[treeIndex].predict(X);
139
+ for (let sampleIndex = 0; sampleIndex < sampleCount; sampleIndex += 1) {
140
+ if (treePrediction[sampleIndex] === 1) {
141
+ voteCounts[sampleIndex] += 1;
142
+ }
101
143
  }
102
- predictions[sampleIndex] = positiveVotes * 2 >= this.trees.length ? 1 : 0;
144
+ }
145
+
146
+ const predictions = new Array<number>(sampleCount);
147
+ const voteThreshold = this.trees.length;
148
+ for (let sampleIndex = 0; sampleIndex < sampleCount; sampleIndex += 1) {
149
+ predictions[sampleIndex] = voteCounts[sampleIndex] * 2 >= voteThreshold ? 1 : 0;
103
150
  }
104
151
 
105
152
  return predictions;
@@ -110,6 +157,105 @@ export class RandomForestClassifier implements ClassificationModel {
110
157
  return accuracyScore(y, this.predict(X));
111
158
  }
112
159
 
160
+ dispose(): void {
161
+ this.disposeNativeModel();
162
+ for (let i = 0; i < this.trees.length; i += 1) {
163
+ this.trees[i].dispose();
164
+ }
165
+ this.trees = [];
166
+ }
167
+
168
+ private resolveNativeMaxFeatures(featureCount: number): {
169
+ mode: 0 | 1 | 2 | 3;
170
+ value: number;
171
+ } {
172
+ if (this.maxFeatures === null || this.maxFeatures === undefined) {
173
+ return { mode: 0, value: 0 };
174
+ }
175
+ if (this.maxFeatures === "sqrt") {
176
+ return { mode: 1, value: 0 };
177
+ }
178
+ if (this.maxFeatures === "log2") {
179
+ return { mode: 2, value: 0 };
180
+ }
181
+ const value = Number.isFinite(this.maxFeatures)
182
+ ? Math.max(1, Math.min(featureCount, Math.floor(this.maxFeatures)))
183
+ : featureCount;
184
+ return { mode: 3, value };
185
+ }
186
+
187
+ private tryFitNativeForest(
188
+ flattenedX: Float64Array,
189
+ yBinary: Uint8Array,
190
+ sampleCount: number,
191
+ featureCount: number,
192
+ ): boolean {
193
+ if (!isTruthy(process.env.BUN_SCIKIT_EXPERIMENTAL_NATIVE_FOREST)) {
194
+ return false;
195
+ }
196
+ if (process.env.BUN_SCIKIT_TREE_BACKEND?.trim().toLowerCase() !== "zig") {
197
+ return false;
198
+ }
199
+ const kernels = getZigKernels();
200
+ const create = kernels?.randomForestClassifierModelCreate;
201
+ const fit = kernels?.randomForestClassifierModelFit;
202
+ const destroy = kernels?.randomForestClassifierModelDestroy;
203
+ if (!create || !fit || !destroy) {
204
+ return false;
205
+ }
206
+
207
+ const { mode, value } = this.resolveNativeMaxFeatures(featureCount);
208
+ const useRandomState = this.randomState === undefined ? 0 : 1;
209
+ const randomState = this.randomState ?? 0;
210
+ const handle = create(
211
+ this.nEstimators,
212
+ this.maxDepth ?? 12,
213
+ this.minSamplesSplit ?? 2,
214
+ this.minSamplesLeaf ?? 1,
215
+ mode,
216
+ value,
217
+ this.bootstrap ? 1 : 0,
218
+ randomState >>> 0,
219
+ useRandomState,
220
+ featureCount,
221
+ );
222
+ if (handle === 0n) {
223
+ return false;
224
+ }
225
+
226
+ let shouldDestroy = true;
227
+ try {
228
+ const status = fit(handle, flattenedX, yBinary, sampleCount, featureCount);
229
+ if (status !== 1) {
230
+ return false;
231
+ }
232
+ this.nativeModelHandle = handle;
233
+ this.fitBackendLibrary_ = kernels.libraryPath;
234
+ shouldDestroy = false;
235
+ return true;
236
+ } finally {
237
+ if (shouldDestroy) {
238
+ destroy(handle);
239
+ }
240
+ }
241
+ }
242
+
243
+ private disposeNativeModel(): void {
244
+ if (this.nativeModelHandle === null) {
245
+ return;
246
+ }
247
+ const kernels = getZigKernels();
248
+ const destroy = kernels?.randomForestClassifierModelDestroy;
249
+ if (destroy) {
250
+ try {
251
+ destroy(this.nativeModelHandle);
252
+ } catch {
253
+ // best effort cleanup
254
+ }
255
+ }
256
+ this.nativeModelHandle = null;
257
+ }
258
+
113
259
  private flattenTrainingMatrix(
114
260
  X: Matrix,
115
261
  sampleCount: number,
@@ -56,10 +56,10 @@ export class RandomForestRegressor implements RegressionModel {
56
56
  const random = this.randomState === undefined ? Math.random : mulberry32(this.randomState);
57
57
  const flattenedX = this.flattenTrainingMatrix(X, sampleCount, featureCount);
58
58
  const yValues = this.toFloat64Vector(y);
59
+ const sampleIndices = new Uint32Array(sampleCount);
59
60
  this.trees = new Array(this.nEstimators);
60
61
 
61
62
  for (let estimatorIndex = 0; estimatorIndex < this.nEstimators; estimatorIndex += 1) {
62
- const sampleIndices = new Uint32Array(sampleCount);
63
63
  if (this.bootstrap) {
64
64
  for (let i = 0; i < sampleCount; i += 1) {
65
65
  sampleIndices[i] = Math.floor(random() * sampleCount);
@@ -90,16 +90,20 @@ export class RandomForestRegressor implements RegressionModel {
90
90
  throw new Error("RandomForestRegressor has not been fitted.");
91
91
  }
92
92
 
93
- const treePredictions = this.trees.map((tree) => tree.predict(X));
94
93
  const sampleCount = X.length;
95
- const predictions = new Array<number>(sampleCount).fill(0);
94
+ const sums = new Float64Array(sampleCount);
96
95
 
97
- for (let sampleIndex = 0; sampleIndex < sampleCount; sampleIndex += 1) {
98
- let sum = 0;
99
- for (let treeIndex = 0; treeIndex < treePredictions.length; treeIndex += 1) {
100
- sum += treePredictions[treeIndex][sampleIndex];
96
+ for (let treeIndex = 0; treeIndex < this.trees.length; treeIndex += 1) {
97
+ const treePrediction = this.trees[treeIndex].predict(X);
98
+ for (let sampleIndex = 0; sampleIndex < sampleCount; sampleIndex += 1) {
99
+ sums[sampleIndex] += treePrediction[sampleIndex];
101
100
  }
102
- predictions[sampleIndex] = sum / this.trees.length;
101
+ }
102
+
103
+ const predictions = new Array<number>(sampleCount);
104
+ const denominator = this.trees.length;
105
+ for (let sampleIndex = 0; sampleIndex < sampleCount; sampleIndex += 1) {
106
+ predictions[sampleIndex] = sums[sampleIndex] / denominator;
103
107
  }
104
108
 
105
109
  return predictions;
@@ -31,6 +31,20 @@ using DecisionTreeModelCreateFn = NativeHandle (*)(std::size_t, std::size_t, std
31
31
  using DecisionTreeModelDestroyFn = void (*)(NativeHandle);
32
32
  using DecisionTreeModelFitFn = std::uint8_t (*)(NativeHandle, const double*, const std::uint8_t*, std::size_t, std::size_t, const std::uint32_t*, std::size_t);
33
33
  using DecisionTreeModelPredictFn = std::uint8_t (*)(NativeHandle, const double*, std::size_t, std::size_t, std::uint8_t*);
34
+ using RandomForestClassifierModelCreateFn = NativeHandle (*)(
35
+ std::size_t,
36
+ std::size_t,
37
+ std::size_t,
38
+ std::size_t,
39
+ std::uint8_t,
40
+ std::size_t,
41
+ std::uint8_t,
42
+ std::uint32_t,
43
+ std::uint8_t,
44
+ std::size_t);
45
+ using RandomForestClassifierModelDestroyFn = void (*)(NativeHandle);
46
+ using RandomForestClassifierModelFitFn = std::uint8_t (*)(NativeHandle, const double*, const std::uint8_t*, std::size_t, std::size_t);
47
+ using RandomForestClassifierModelPredictFn = std::uint8_t (*)(NativeHandle, const double*, std::size_t, std::size_t, std::uint8_t*);
34
48
 
35
49
  struct KernelLibrary {
36
50
  #if defined(_WIN32)
@@ -55,6 +69,10 @@ struct KernelLibrary {
55
69
  DecisionTreeModelDestroyFn decision_tree_model_destroy{nullptr};
56
70
  DecisionTreeModelFitFn decision_tree_model_fit{nullptr};
57
71
  DecisionTreeModelPredictFn decision_tree_model_predict{nullptr};
72
+ RandomForestClassifierModelCreateFn random_forest_classifier_model_create{nullptr};
73
+ RandomForestClassifierModelDestroyFn random_forest_classifier_model_destroy{nullptr};
74
+ RandomForestClassifierModelFitFn random_forest_classifier_model_fit{nullptr};
75
+ RandomForestClassifierModelPredictFn random_forest_classifier_model_predict{nullptr};
58
76
  };
59
77
 
60
78
  KernelLibrary g_library{};
@@ -154,6 +172,14 @@ Napi::Value LoadNativeLibrary(const Napi::CallbackInfo& info) {
154
172
  loadSymbol<DecisionTreeModelFitFn>("decision_tree_model_fit");
155
173
  g_library.decision_tree_model_predict =
156
174
  loadSymbol<DecisionTreeModelPredictFn>("decision_tree_model_predict");
175
+ g_library.random_forest_classifier_model_create =
176
+ loadSymbol<RandomForestClassifierModelCreateFn>("random_forest_classifier_model_create");
177
+ g_library.random_forest_classifier_model_destroy =
178
+ loadSymbol<RandomForestClassifierModelDestroyFn>("random_forest_classifier_model_destroy");
179
+ g_library.random_forest_classifier_model_fit =
180
+ loadSymbol<RandomForestClassifierModelFitFn>("random_forest_classifier_model_fit");
181
+ g_library.random_forest_classifier_model_predict =
182
+ loadSymbol<RandomForestClassifierModelPredictFn>("random_forest_classifier_model_predict");
157
183
 
158
184
  return Napi::Boolean::New(env, true);
159
185
  }
@@ -567,6 +593,134 @@ Napi::Value DecisionTreeModelPredict(const Napi::CallbackInfo& info) {
567
593
  return Napi::Number::New(env, status);
568
594
  }
569
595
 
596
+ Napi::Value RandomForestClassifierModelCreate(const Napi::CallbackInfo& info) {
597
+ const Napi::Env env = info.Env();
598
+ if (!isLibraryLoaded(env)) {
599
+ return env.Null();
600
+ }
601
+ if (!g_library.random_forest_classifier_model_create) {
602
+ throwError(env, "Symbol random_forest_classifier_model_create is unavailable.");
603
+ return env.Null();
604
+ }
605
+ if (info.Length() != 10 || !info[0].IsNumber() || !info[1].IsNumber() || !info[2].IsNumber() ||
606
+ !info[3].IsNumber() || !info[4].IsNumber() || !info[5].IsNumber() || !info[6].IsNumber() ||
607
+ !info[7].IsNumber() || !info[8].IsNumber() || !info[9].IsNumber()) {
608
+ throwTypeError(env, "randomForestClassifierModelCreate(nEstimators, maxDepth, minSamplesSplit, minSamplesLeaf, maxFeaturesMode, maxFeaturesValue, bootstrap, randomState, useRandomState, nFeatures) expects ten numbers.");
609
+ return env.Null();
610
+ }
611
+
612
+ const std::size_t n_estimators = static_cast<std::size_t>(info[0].As<Napi::Number>().Uint32Value());
613
+ const std::size_t max_depth = static_cast<std::size_t>(info[1].As<Napi::Number>().Uint32Value());
614
+ const std::size_t min_samples_split = static_cast<std::size_t>(info[2].As<Napi::Number>().Uint32Value());
615
+ const std::size_t min_samples_leaf = static_cast<std::size_t>(info[3].As<Napi::Number>().Uint32Value());
616
+ const std::uint8_t max_features_mode = static_cast<std::uint8_t>(info[4].As<Napi::Number>().Uint32Value());
617
+ const std::size_t max_features_value = static_cast<std::size_t>(info[5].As<Napi::Number>().Uint32Value());
618
+ const std::uint8_t bootstrap = static_cast<std::uint8_t>(info[6].As<Napi::Number>().Uint32Value());
619
+ const std::uint32_t random_state = static_cast<std::uint32_t>(info[7].As<Napi::Number>().Uint32Value());
620
+ const std::uint8_t use_random_state = static_cast<std::uint8_t>(info[8].As<Napi::Number>().Uint32Value());
621
+ const std::size_t n_features = static_cast<std::size_t>(info[9].As<Napi::Number>().Uint32Value());
622
+
623
+ const NativeHandle handle = g_library.random_forest_classifier_model_create(
624
+ n_estimators,
625
+ max_depth,
626
+ min_samples_split,
627
+ min_samples_leaf,
628
+ max_features_mode,
629
+ max_features_value,
630
+ bootstrap,
631
+ random_state,
632
+ use_random_state,
633
+ n_features);
634
+ return Napi::BigInt::New(env, static_cast<std::uint64_t>(handle));
635
+ }
636
+
637
+ Napi::Value RandomForestClassifierModelDestroy(const Napi::CallbackInfo& info) {
638
+ const Napi::Env env = info.Env();
639
+ if (!isLibraryLoaded(env)) {
640
+ return env.Null();
641
+ }
642
+ if (!g_library.random_forest_classifier_model_destroy) {
643
+ throwError(env, "Symbol random_forest_classifier_model_destroy is unavailable.");
644
+ return env.Null();
645
+ }
646
+ if (info.Length() != 1) {
647
+ throwTypeError(env, "randomForestClassifierModelDestroy(handle) expects one BigInt.");
648
+ return env.Null();
649
+ }
650
+ const NativeHandle handle = handleFromBigInt(info[0], env);
651
+ if (env.IsExceptionPending()) {
652
+ return env.Null();
653
+ }
654
+ g_library.random_forest_classifier_model_destroy(handle);
655
+ return env.Undefined();
656
+ }
657
+
658
+ Napi::Value RandomForestClassifierModelFit(const Napi::CallbackInfo& info) {
659
+ const Napi::Env env = info.Env();
660
+ if (!isLibraryLoaded(env)) {
661
+ return env.Null();
662
+ }
663
+ if (!g_library.random_forest_classifier_model_fit) {
664
+ throwError(env, "Symbol random_forest_classifier_model_fit is unavailable.");
665
+ return env.Null();
666
+ }
667
+ if (info.Length() != 5 || !info[1].IsTypedArray() || !info[2].IsTypedArray() ||
668
+ !info[3].IsNumber() || !info[4].IsNumber()) {
669
+ throwTypeError(env, "randomForestClassifierModelFit(handle, x, y, nSamples, nFeatures) has invalid arguments.");
670
+ return env.Null();
671
+ }
672
+
673
+ const NativeHandle handle = handleFromBigInt(info[0], env);
674
+ if (env.IsExceptionPending()) {
675
+ return env.Null();
676
+ }
677
+ auto x = info[1].As<Napi::Float64Array>();
678
+ auto y = info[2].As<Napi::Uint8Array>();
679
+ const std::size_t n_samples = static_cast<std::size_t>(info[3].As<Napi::Number>().Uint32Value());
680
+ const std::size_t n_features = static_cast<std::size_t>(info[4].As<Napi::Number>().Uint32Value());
681
+
682
+ const std::uint8_t status = g_library.random_forest_classifier_model_fit(
683
+ handle,
684
+ x.Data(),
685
+ y.Data(),
686
+ n_samples,
687
+ n_features);
688
+ return Napi::Number::New(env, status);
689
+ }
690
+
691
+ Napi::Value RandomForestClassifierModelPredict(const Napi::CallbackInfo& info) {
692
+ const Napi::Env env = info.Env();
693
+ if (!isLibraryLoaded(env)) {
694
+ return env.Null();
695
+ }
696
+ if (!g_library.random_forest_classifier_model_predict) {
697
+ throwError(env, "Symbol random_forest_classifier_model_predict is unavailable.");
698
+ return env.Null();
699
+ }
700
+ if (info.Length() != 5 || !info[1].IsTypedArray() || !info[2].IsNumber() || !info[3].IsNumber() ||
701
+ !info[4].IsTypedArray()) {
702
+ throwTypeError(env, "randomForestClassifierModelPredict(handle, x, nSamples, nFeatures, outLabels) has invalid arguments.");
703
+ return env.Null();
704
+ }
705
+
706
+ const NativeHandle handle = handleFromBigInt(info[0], env);
707
+ if (env.IsExceptionPending()) {
708
+ return env.Null();
709
+ }
710
+ auto x = info[1].As<Napi::Float64Array>();
711
+ const std::size_t n_samples = static_cast<std::size_t>(info[2].As<Napi::Number>().Uint32Value());
712
+ const std::size_t n_features = static_cast<std::size_t>(info[3].As<Napi::Number>().Uint32Value());
713
+ auto out_labels = info[4].As<Napi::Uint8Array>();
714
+
715
+ const std::uint8_t status = g_library.random_forest_classifier_model_predict(
716
+ handle,
717
+ x.Data(),
718
+ n_samples,
719
+ n_features,
720
+ out_labels.Data());
721
+ return Napi::Number::New(env, status);
722
+ }
723
+
570
724
  Napi::Object Init(Napi::Env env, Napi::Object exports) {
571
725
  exports.Set("loadLibrary", Napi::Function::New(env, LoadNativeLibrary));
572
726
  exports.Set("unloadLibrary", Napi::Function::New(env, UnloadLibrary));
@@ -590,6 +744,10 @@ Napi::Object Init(Napi::Env env, Napi::Object exports) {
590
744
  exports.Set("decisionTreeModelDestroy", Napi::Function::New(env, DecisionTreeModelDestroy));
591
745
  exports.Set("decisionTreeModelFit", Napi::Function::New(env, DecisionTreeModelFit));
592
746
  exports.Set("decisionTreeModelPredict", Napi::Function::New(env, DecisionTreeModelPredict));
747
+ exports.Set("randomForestClassifierModelCreate", Napi::Function::New(env, RandomForestClassifierModelCreate));
748
+ exports.Set("randomForestClassifierModelDestroy", Napi::Function::New(env, RandomForestClassifierModelDestroy));
749
+ exports.Set("randomForestClassifierModelFit", Napi::Function::New(env, RandomForestClassifierModelFit));
750
+ exports.Set("randomForestClassifierModelPredict", Napi::Function::New(env, RandomForestClassifierModelPredict));
593
751
 
594
752
  return exports;
595
753
  }
@@ -88,6 +88,33 @@ type DecisionTreeModelPredictFn = (
88
88
  nFeatures: number,
89
89
  outLabels: Uint8Array,
90
90
  ) => number;
91
+ type RandomForestClassifierModelCreateFn = (
92
+ nEstimators: number,
93
+ maxDepth: number,
94
+ minSamplesSplit: number,
95
+ minSamplesLeaf: number,
96
+ maxFeaturesMode: number,
97
+ maxFeaturesValue: number,
98
+ bootstrap: number,
99
+ randomState: number,
100
+ useRandomState: number,
101
+ nFeatures: number,
102
+ ) => NativeHandle;
103
+ type RandomForestClassifierModelDestroyFn = (handle: NativeHandle) => void;
104
+ type RandomForestClassifierModelFitFn = (
105
+ handle: NativeHandle,
106
+ x: Float64Array,
107
+ y: Uint8Array,
108
+ nSamples: number,
109
+ nFeatures: number,
110
+ ) => number;
111
+ type RandomForestClassifierModelPredictFn = (
112
+ handle: NativeHandle,
113
+ x: Float64Array,
114
+ nSamples: number,
115
+ nFeatures: number,
116
+ outLabels: Uint8Array,
117
+ ) => number;
91
118
 
92
119
  type LogisticTrainEpochFn = (
93
120
  x: Float64Array,
@@ -138,6 +165,10 @@ interface ZigKernelLibrary {
138
165
  decision_tree_model_destroy?: DecisionTreeModelDestroyFn;
139
166
  decision_tree_model_fit?: DecisionTreeModelFitFn;
140
167
  decision_tree_model_predict?: DecisionTreeModelPredictFn;
168
+ random_forest_classifier_model_create?: RandomForestClassifierModelCreateFn;
169
+ random_forest_classifier_model_destroy?: RandomForestClassifierModelDestroyFn;
170
+ random_forest_classifier_model_fit?: RandomForestClassifierModelFitFn;
171
+ random_forest_classifier_model_predict?: RandomForestClassifierModelPredictFn;
141
172
  logistic_train_epoch?: LogisticTrainEpochFn;
142
173
  logistic_train_epochs?: LogisticTrainEpochsFn;
143
174
  };
@@ -162,6 +193,10 @@ export interface ZigKernels {
162
193
  decisionTreeModelDestroy: DecisionTreeModelDestroyFn | null;
163
194
  decisionTreeModelFit: DecisionTreeModelFitFn | null;
164
195
  decisionTreeModelPredict: DecisionTreeModelPredictFn | null;
196
+ randomForestClassifierModelCreate: RandomForestClassifierModelCreateFn | null;
197
+ randomForestClassifierModelDestroy: RandomForestClassifierModelDestroyFn | null;
198
+ randomForestClassifierModelFit: RandomForestClassifierModelFitFn | null;
199
+ randomForestClassifierModelPredict: RandomForestClassifierModelPredictFn | null;
165
200
  logisticTrainEpoch: LogisticTrainEpochFn | null;
166
201
  logisticTrainEpochs: LogisticTrainEpochsFn | null;
167
202
  abiVersion: number | null;
@@ -247,6 +282,10 @@ interface NodeApiAddon {
247
282
  decisionTreeModelDestroy?: DecisionTreeModelDestroyFn;
248
283
  decisionTreeModelFit?: DecisionTreeModelFitFn;
249
284
  decisionTreeModelPredict?: DecisionTreeModelPredictFn;
285
+ randomForestClassifierModelCreate?: RandomForestClassifierModelCreateFn;
286
+ randomForestClassifierModelDestroy?: RandomForestClassifierModelDestroyFn;
287
+ randomForestClassifierModelFit?: RandomForestClassifierModelFitFn;
288
+ randomForestClassifierModelPredict?: RandomForestClassifierModelPredictFn;
250
289
  }
251
290
 
252
291
  function tryLoadNodeApiKernels(): ZigKernels | null {
@@ -289,6 +328,13 @@ function tryLoadNodeApiKernels(): ZigKernels | null {
289
328
  decisionTreeModelDestroy: addon.decisionTreeModelDestroy ?? null,
290
329
  decisionTreeModelFit: addon.decisionTreeModelFit ?? null,
291
330
  decisionTreeModelPredict: addon.decisionTreeModelPredict ?? null,
331
+ randomForestClassifierModelCreate:
332
+ addon.randomForestClassifierModelCreate ?? null,
333
+ randomForestClassifierModelDestroy:
334
+ addon.randomForestClassifierModelDestroy ?? null,
335
+ randomForestClassifierModelFit: addon.randomForestClassifierModelFit ?? null,
336
+ randomForestClassifierModelPredict:
337
+ addon.randomForestClassifierModelPredict ?? null,
292
338
  logisticTrainEpoch: null,
293
339
  logisticTrainEpochs: null,
294
340
  abiVersion,
@@ -432,6 +478,33 @@ export function getZigKernels(): ZigKernels | null {
432
478
  args: ["usize", FFIType.ptr, "usize", "usize", FFIType.ptr],
433
479
  returns: FFIType.u8,
434
480
  },
481
+ random_forest_classifier_model_create: {
482
+ args: [
483
+ "usize",
484
+ "usize",
485
+ "usize",
486
+ "usize",
487
+ FFIType.u8,
488
+ "usize",
489
+ FFIType.u8,
490
+ FFIType.u32,
491
+ FFIType.u8,
492
+ "usize",
493
+ ],
494
+ returns: "usize",
495
+ },
496
+ random_forest_classifier_model_destroy: {
497
+ args: ["usize"],
498
+ returns: FFIType.void,
499
+ },
500
+ random_forest_classifier_model_fit: {
501
+ args: ["usize", FFIType.ptr, FFIType.ptr, "usize", "usize"],
502
+ returns: FFIType.u8,
503
+ },
504
+ random_forest_classifier_model_predict: {
505
+ args: ["usize", FFIType.ptr, "usize", "usize", FFIType.ptr],
506
+ returns: FFIType.u8,
507
+ },
435
508
  logistic_train_epoch: {
436
509
  args: [
437
510
  FFIType.ptr,
@@ -492,6 +565,14 @@ export function getZigKernels(): ZigKernels | null {
492
565
  decisionTreeModelDestroy: library.symbols.decision_tree_model_destroy ?? null,
493
566
  decisionTreeModelFit: library.symbols.decision_tree_model_fit ?? null,
494
567
  decisionTreeModelPredict: library.symbols.decision_tree_model_predict ?? null,
568
+ randomForestClassifierModelCreate:
569
+ library.symbols.random_forest_classifier_model_create ?? null,
570
+ randomForestClassifierModelDestroy:
571
+ library.symbols.random_forest_classifier_model_destroy ?? null,
572
+ randomForestClassifierModelFit:
573
+ library.symbols.random_forest_classifier_model_fit ?? null,
574
+ randomForestClassifierModelPredict:
575
+ library.symbols.random_forest_classifier_model_predict ?? null,
495
576
  logisticTrainEpoch: library.symbols.logistic_train_epoch ?? null,
496
577
  logisticTrainEpochs: library.symbols.logistic_train_epochs ?? null,
497
578
  abiVersion,
@@ -555,6 +636,10 @@ export function getZigKernels(): ZigKernels | null {
555
636
  decisionTreeModelDestroy: null,
556
637
  decisionTreeModelFit: null,
557
638
  decisionTreeModelPredict: null,
639
+ randomForestClassifierModelCreate: null,
640
+ randomForestClassifierModelDestroy: null,
641
+ randomForestClassifierModelFit: null,
642
+ randomForestClassifierModelPredict: null,
558
643
  logisticTrainEpoch: library.symbols.logistic_train_epoch ?? null,
559
644
  logisticTrainEpochs: library.symbols.logistic_train_epochs ?? null,
560
645
  abiVersion: null,
@@ -600,6 +685,10 @@ export function getZigKernels(): ZigKernels | null {
600
685
  decisionTreeModelDestroy: null,
601
686
  decisionTreeModelFit: null,
602
687
  decisionTreeModelPredict: null,
688
+ randomForestClassifierModelCreate: null,
689
+ randomForestClassifierModelDestroy: null,
690
+ randomForestClassifierModelFit: null,
691
+ randomForestClassifierModelPredict: null,
603
692
  logisticTrainEpoch: library.symbols.logistic_train_epoch ?? null,
604
693
  logisticTrainEpochs: null,
605
694
  abiVersion: null,
@@ -185,7 +185,12 @@ export class DecisionTreeClassifier implements ClassificationModel {
185
185
  }
186
186
  }
187
187
 
188
- return X.map((sample) => this.predictOne(sample, this.root!));
188
+ const predictions = new Array<number>(X.length);
189
+ const root = this.root!;
190
+ for (let i = 0; i < X.length; i += 1) {
191
+ predictions[i] = this.predictOne(X[i], root);
192
+ }
193
+ return predictions;
189
194
  }
190
195
 
191
196
  score(X: Matrix, y: Vector): number {
@@ -193,6 +198,13 @@ export class DecisionTreeClassifier implements ClassificationModel {
193
198
  return accuracyScore(y, this.predict(X));
194
199
  }
195
200
 
201
+ dispose(): void {
202
+ this.destroyZigModel();
203
+ this.root = null;
204
+ this.flattenedXTrain = null;
205
+ this.yBinaryTrain = null;
206
+ }
207
+
196
208
  private predictOne(sample: Vector, node: TreeNode): 0 | 1 {
197
209
  let current: TreeNode = node;
198
210
  while (
package/zig/kernels.zig CHANGED
@@ -74,12 +74,31 @@ const DecisionTreeModel = struct {
74
74
  use_random_state: bool,
75
75
  root_index: usize,
76
76
  has_root: bool,
77
+ feature_scratch: []usize,
77
78
  nodes: std.ArrayListUnmanaged(TreeNode),
78
79
  };
79
80
 
80
- const SplitResult = struct {
81
+ const RandomForestClassifierModel = struct {
82
+ n_features: usize,
83
+ n_estimators: usize,
84
+ max_depth: usize,
85
+ min_samples_split: usize,
86
+ min_samples_leaf: usize,
87
+ max_features_mode: u8,
88
+ max_features_value: usize,
89
+ bootstrap: bool,
90
+ random_state: u32,
91
+ use_random_state: bool,
92
+ tree_handles: []usize,
93
+ fitted_estimators: usize,
94
+ };
95
+
96
+ const SplitEvaluation = struct {
81
97
  threshold: f64,
82
98
  impurity: f64,
99
+ };
100
+
101
+ const SplitPartition = struct {
83
102
  left_indices: []usize,
84
103
  right_indices: []usize,
85
104
  };
@@ -167,41 +186,33 @@ fn resolveMaxFeatures(model: *const DecisionTreeModel) usize {
167
186
  }
168
187
  }
169
188
 
170
- fn freeSplit(split: SplitResult) void {
171
- allocator.free(split.left_indices);
172
- allocator.free(split.right_indices);
189
+ inline fn asRandomForestClassifierModel(handle: usize) ?*RandomForestClassifierModel {
190
+ if (handle == 0) {
191
+ return null;
192
+ }
193
+ return @as(*RandomForestClassifierModel, @ptrFromInt(handle));
173
194
  }
174
195
 
175
- fn selectCandidateFeatures(model: *const DecisionTreeModel, rng: *Mulberry32) ![]usize {
176
- const k = resolveMaxFeatures(model);
177
- if (k >= model.n_features) {
178
- const all_features = try allocator.alloc(usize, model.n_features);
179
- errdefer allocator.free(all_features);
180
- for (all_features, 0..) |*entry, idx| {
181
- entry.* = idx;
182
- }
183
- return all_features;
196
+ fn selectCandidateFeatures(model: *DecisionTreeModel, rng: *Mulberry32) []const usize {
197
+ for (model.feature_scratch, 0..) |*entry, idx| {
198
+ entry.* = idx;
184
199
  }
185
200
 
186
- const shuffled = try allocator.alloc(usize, model.n_features);
187
- errdefer allocator.free(shuffled);
188
- for (shuffled, 0..) |*entry, idx| {
189
- entry.* = idx;
201
+ const k = resolveMaxFeatures(model);
202
+ if (k >= model.n_features) {
203
+ return model.feature_scratch[0..model.n_features];
190
204
  }
191
205
 
192
- var i = model.n_features;
193
- while (i > 1) {
194
- i -= 1;
195
- const j = rng.nextIndex(i + 1);
196
- const tmp = shuffled[i];
197
- shuffled[i] = shuffled[j];
198
- shuffled[j] = tmp;
206
+ var i: usize = 0;
207
+ while (i < k) : (i += 1) {
208
+ const remaining = model.n_features - i;
209
+ const j = i + rng.nextIndex(remaining);
210
+ const tmp = model.feature_scratch[i];
211
+ model.feature_scratch[i] = model.feature_scratch[j];
212
+ model.feature_scratch[j] = tmp;
199
213
  }
200
214
 
201
- const selected = try allocator.alloc(usize, k);
202
- @memcpy(selected, shuffled[0..k]);
203
- allocator.free(shuffled);
204
- return selected;
215
+ return model.feature_scratch[0..k];
205
216
  }
206
217
 
207
218
  fn findBestSplitForFeature(
@@ -210,7 +221,7 @@ fn findBestSplitForFeature(
210
221
  y_ptr: [*]const u8,
211
222
  indices: []const usize,
212
223
  feature_index: usize,
213
- ) !?SplitResult {
224
+ ) ?SplitEvaluation {
214
225
  const sample_count = indices.len;
215
226
  if (sample_count < 2) {
216
227
  return null;
@@ -282,29 +293,41 @@ fn findBestSplitForFeature(
282
293
  return null;
283
294
  }
284
295
 
285
- var left_partition_count: usize = 0;
296
+ return SplitEvaluation{
297
+ .threshold = best_threshold,
298
+ .impurity = best_impurity,
299
+ };
300
+ }
301
+
302
+ fn partitionIndicesForThreshold(
303
+ model: *const DecisionTreeModel,
304
+ workspace: std.mem.Allocator,
305
+ x_ptr: [*]const f64,
306
+ indices: []const usize,
307
+ feature_index: usize,
308
+ threshold: f64,
309
+ ) !?SplitPartition {
310
+ var left_count: usize = 0;
286
311
  for (indices) |sample_index| {
287
312
  const value = x_ptr[sample_index * model.n_features + feature_index];
288
- if (value <= best_threshold) {
289
- left_partition_count += 1;
313
+ if (value <= threshold) {
314
+ left_count += 1;
290
315
  }
291
316
  }
292
317
 
293
- const right_partition_count = sample_count - left_partition_count;
294
- if (left_partition_count < model.min_samples_leaf or right_partition_count < model.min_samples_leaf) {
318
+ const right_count = indices.len - left_count;
319
+ if (left_count < model.min_samples_leaf or right_count < model.min_samples_leaf) {
295
320
  return null;
296
321
  }
297
322
 
298
- const left_indices = try allocator.alloc(usize, left_partition_count);
299
- errdefer allocator.free(left_indices);
300
- const right_indices = try allocator.alloc(usize, right_partition_count);
301
- errdefer allocator.free(right_indices);
323
+ const left_indices = try workspace.alloc(usize, left_count);
324
+ const right_indices = try workspace.alloc(usize, right_count);
302
325
 
303
326
  var left_write: usize = 0;
304
327
  var right_write: usize = 0;
305
328
  for (indices) |sample_index| {
306
329
  const value = x_ptr[sample_index * model.n_features + feature_index];
307
- if (value <= best_threshold) {
330
+ if (value <= threshold) {
308
331
  left_indices[left_write] = sample_index;
309
332
  left_write += 1;
310
333
  } else {
@@ -313,9 +336,7 @@ fn findBestSplitForFeature(
313
336
  }
314
337
  }
315
338
 
316
- return SplitResult{
317
- .threshold = best_threshold,
318
- .impurity = best_impurity,
339
+ return SplitPartition{
319
340
  .left_indices = left_indices,
320
341
  .right_indices = right_indices,
321
342
  };
@@ -323,6 +344,7 @@ fn findBestSplitForFeature(
323
344
 
324
345
  fn buildDecisionTreeNode(
325
346
  model: *DecisionTreeModel,
347
+ workspace: std.mem.Allocator,
326
348
  x_ptr: [*]const f64,
327
349
  y_ptr: [*]const u8,
328
350
  indices: []const usize,
@@ -353,25 +375,19 @@ fn buildDecisionTreeNode(
353
375
  }
354
376
 
355
377
  const parent_impurity = giniImpurity(positive_count, sample_count);
356
- const candidate_features = try selectCandidateFeatures(model, rng);
357
- defer allocator.free(candidate_features);
378
+ const candidate_features = selectCandidateFeatures(model, rng);
358
379
 
359
380
  var best_feature: usize = 0;
360
- var best_split: ?SplitResult = null;
381
+ var best_split: ?SplitEvaluation = null;
361
382
  var best_found = false;
362
383
 
363
384
  for (candidate_features) |feature_index| {
364
- const split_opt = try findBestSplitForFeature(model, x_ptr, y_ptr, indices, feature_index);
385
+ const split_opt = findBestSplitForFeature(model, x_ptr, y_ptr, indices, feature_index);
365
386
  if (split_opt) |split| {
366
387
  if (!best_found or split.impurity < best_split.?.impurity) {
367
- if (best_split) |previous| {
368
- freeSplit(previous);
369
- }
370
388
  best_split = split;
371
389
  best_feature = feature_index;
372
390
  best_found = true;
373
- } else {
374
- freeSplit(split);
375
391
  }
376
392
  }
377
393
  }
@@ -390,7 +406,6 @@ fn buildDecisionTreeNode(
390
406
  }
391
407
 
392
408
  const split = best_split.?;
393
- defer freeSplit(split);
394
409
  if (split.impurity >= parent_impurity - 1e-12) {
395
410
  const node_index = model.nodes.items.len;
396
411
  try model.nodes.append(allocator, TreeNode{
@@ -404,6 +419,25 @@ fn buildDecisionTreeNode(
404
419
  return node_index;
405
420
  }
406
421
 
422
+ const partition = (try partitionIndicesForThreshold(
423
+ model,
424
+ workspace,
425
+ x_ptr,
426
+ indices,
427
+ best_feature,
428
+ split.threshold,
429
+ )) orelse {
430
+ const node_index = model.nodes.items.len;
431
+ try model.nodes.append(allocator, TreeNode{
432
+ .prediction = prediction,
433
+ .feature_index = 0,
434
+ .threshold = 0.0,
435
+ .left_index = 0,
436
+ .right_index = 0,
437
+ .is_leaf = true,
438
+ });
439
+ return node_index;
440
+ };
407
441
  const node_index = model.nodes.items.len;
408
442
  try model.nodes.append(allocator, TreeNode{
409
443
  .prediction = prediction,
@@ -416,17 +450,19 @@ fn buildDecisionTreeNode(
416
450
 
417
451
  const left_index = try buildDecisionTreeNode(
418
452
  model,
453
+ workspace,
419
454
  x_ptr,
420
455
  y_ptr,
421
- split.left_indices,
456
+ partition.left_indices,
422
457
  depth + 1,
423
458
  rng,
424
459
  );
425
460
  const right_index = try buildDecisionTreeNode(
426
461
  model,
462
+ workspace,
427
463
  x_ptr,
428
464
  y_ptr,
429
- split.right_indices,
465
+ partition.right_indices,
430
466
  depth + 1,
431
467
  rng,
432
468
  );
@@ -1136,6 +1172,11 @@ pub export fn decision_tree_model_create(
1136
1172
 
1137
1173
  const model = allocator.create(DecisionTreeModel) catch return 0;
1138
1174
  errdefer allocator.destroy(model);
1175
+ const feature_scratch = allocator.alloc(usize, n_features) catch return 0;
1176
+ errdefer allocator.free(feature_scratch);
1177
+ for (feature_scratch, 0..) |*entry, idx| {
1178
+ entry.* = idx;
1179
+ }
1139
1180
  model.* = .{
1140
1181
  .n_features = n_features,
1141
1182
  .max_depth = max_depth,
@@ -1147,6 +1188,7 @@ pub export fn decision_tree_model_create(
1147
1188
  .use_random_state = use_random_state != 0,
1148
1189
  .root_index = 0,
1149
1190
  .has_root = false,
1191
+ .feature_scratch = feature_scratch,
1150
1192
  .nodes = .empty,
1151
1193
  };
1152
1194
  return @intFromPtr(model);
@@ -1154,6 +1196,7 @@ pub export fn decision_tree_model_create(
1154
1196
 
1155
1197
  pub export fn decision_tree_model_destroy(handle: usize) void {
1156
1198
  const model = asDecisionTreeModel(handle) orelse return;
1199
+ allocator.free(model.feature_scratch);
1157
1200
  model.nodes.deinit(allocator);
1158
1201
  allocator.destroy(model);
1159
1202
  }
@@ -1180,8 +1223,11 @@ pub export fn decision_tree_model_fit(
1180
1223
  return 0;
1181
1224
  }
1182
1225
 
1183
- const root_indices = allocator.alloc(usize, root_size) catch return 0;
1184
- defer allocator.free(root_indices);
1226
+ var arena = std.heap.ArenaAllocator.init(allocator);
1227
+ defer arena.deinit();
1228
+ const workspace = arena.allocator();
1229
+
1230
+ const root_indices = workspace.alloc(usize, root_size) catch return 0;
1185
1231
 
1186
1232
  if (sample_count == 0) {
1187
1233
  for (root_indices, 0..) |*entry, idx| {
@@ -1202,7 +1248,7 @@ pub export fn decision_tree_model_fit(
1202
1248
  else
1203
1249
  @as(u32, @truncate(@as(u64, @bitCast(std.time.microTimestamp()))));
1204
1250
  var rng = Mulberry32.init(rng_seed);
1205
- const root_index = buildDecisionTreeNode(model, x_ptr, y_ptr, root_indices, 0, &rng) catch {
1251
+ const root_index = buildDecisionTreeNode(model, workspace, x_ptr, y_ptr, root_indices, 0, &rng) catch {
1206
1252
  model.nodes.clearRetainingCapacity();
1207
1253
  model.has_root = false;
1208
1254
  return 0;
@@ -1243,6 +1289,181 @@ pub export fn decision_tree_model_predict(
1243
1289
  return 1;
1244
1290
  }
1245
1291
 
1292
+ fn resetRandomForestClassifierModel(model: *RandomForestClassifierModel) void {
1293
+ var i: usize = 0;
1294
+ while (i < model.fitted_estimators) : (i += 1) {
1295
+ const tree_handle = model.tree_handles[i];
1296
+ if (tree_handle != 0) {
1297
+ decision_tree_model_destroy(tree_handle);
1298
+ model.tree_handles[i] = 0;
1299
+ }
1300
+ }
1301
+ model.fitted_estimators = 0;
1302
+ }
1303
+
1304
+ pub export fn random_forest_classifier_model_create(
1305
+ n_estimators: usize,
1306
+ max_depth: usize,
1307
+ min_samples_split: usize,
1308
+ min_samples_leaf: usize,
1309
+ max_features_mode: u8,
1310
+ max_features_value: usize,
1311
+ bootstrap: u8,
1312
+ random_state: u32,
1313
+ use_random_state: u8,
1314
+ n_features: usize,
1315
+ ) usize {
1316
+ if (n_features == 0 or max_depth == 0 or n_estimators == 0) {
1317
+ return 0;
1318
+ }
1319
+
1320
+ const model = allocator.create(RandomForestClassifierModel) catch return 0;
1321
+ errdefer allocator.destroy(model);
1322
+ const tree_handles = allocator.alloc(usize, n_estimators) catch return 0;
1323
+ errdefer allocator.free(tree_handles);
1324
+ @memset(tree_handles, 0);
1325
+
1326
+ model.* = .{
1327
+ .n_features = n_features,
1328
+ .n_estimators = n_estimators,
1329
+ .max_depth = max_depth,
1330
+ .min_samples_split = if (min_samples_split < 2) 2 else min_samples_split,
1331
+ .min_samples_leaf = if (min_samples_leaf < 1) 1 else min_samples_leaf,
1332
+ .max_features_mode = max_features_mode,
1333
+ .max_features_value = max_features_value,
1334
+ .bootstrap = bootstrap != 0,
1335
+ .random_state = random_state,
1336
+ .use_random_state = use_random_state != 0,
1337
+ .tree_handles = tree_handles,
1338
+ .fitted_estimators = 0,
1339
+ };
1340
+ return @intFromPtr(model);
1341
+ }
1342
+
1343
+ pub export fn random_forest_classifier_model_destroy(handle: usize) void {
1344
+ const model = asRandomForestClassifierModel(handle) orelse return;
1345
+ resetRandomForestClassifierModel(model);
1346
+ allocator.free(model.tree_handles);
1347
+ allocator.destroy(model);
1348
+ }
1349
+
1350
+ pub export fn random_forest_classifier_model_fit(
1351
+ handle: usize,
1352
+ x_ptr: [*]const f64,
1353
+ y_ptr: [*]const u8,
1354
+ n_samples: usize,
1355
+ n_features: usize,
1356
+ ) u8 {
1357
+ const model = asRandomForestClassifierModel(handle) orelse return 0;
1358
+ if (n_samples == 0 or n_features == 0 or n_features != model.n_features) {
1359
+ return 0;
1360
+ }
1361
+
1362
+ resetRandomForestClassifierModel(model);
1363
+
1364
+ const sample_indices = allocator.alloc(u32, n_samples) catch return 0;
1365
+ defer allocator.free(sample_indices);
1366
+
1367
+ const rng_seed: u32 = if (model.use_random_state)
1368
+ model.random_state
1369
+ else
1370
+ @as(u32, @truncate(@as(u64, @bitCast(std.time.microTimestamp()))));
1371
+ var rng = Mulberry32.init(rng_seed);
1372
+
1373
+ var estimator_index: usize = 0;
1374
+ while (estimator_index < model.n_estimators) : (estimator_index += 1) {
1375
+ const tree_seed: u32 = if (model.use_random_state)
1376
+ model.random_state +% @as(u32, @truncate(estimator_index + 1))
1377
+ else
1378
+ rng.state +% @as(u32, @truncate(estimator_index + 1));
1379
+ const tree_handle = decision_tree_model_create(
1380
+ model.max_depth,
1381
+ model.min_samples_split,
1382
+ model.min_samples_leaf,
1383
+ model.max_features_mode,
1384
+ model.max_features_value,
1385
+ tree_seed,
1386
+ if (model.use_random_state) 1 else 0,
1387
+ model.n_features,
1388
+ );
1389
+ if (tree_handle == 0) {
1390
+ resetRandomForestClassifierModel(model);
1391
+ return 0;
1392
+ }
1393
+
1394
+ if (model.bootstrap) {
1395
+ var i: usize = 0;
1396
+ while (i < n_samples) : (i += 1) {
1397
+ sample_indices[i] = @as(u32, @truncate(rng.nextIndex(n_samples)));
1398
+ }
1399
+ } else {
1400
+ for (sample_indices, 0..) |*entry, idx| {
1401
+ entry.* = @as(u32, @truncate(idx));
1402
+ }
1403
+ }
1404
+
1405
+ const fit_status = decision_tree_model_fit(
1406
+ tree_handle,
1407
+ x_ptr,
1408
+ y_ptr,
1409
+ n_samples,
1410
+ n_features,
1411
+ sample_indices.ptr,
1412
+ n_samples,
1413
+ );
1414
+ if (fit_status != 1) {
1415
+ decision_tree_model_destroy(tree_handle);
1416
+ resetRandomForestClassifierModel(model);
1417
+ return 0;
1418
+ }
1419
+
1420
+ model.tree_handles[estimator_index] = tree_handle;
1421
+ model.fitted_estimators = estimator_index + 1;
1422
+ }
1423
+
1424
+ return 1;
1425
+ }
1426
+
1427
+ pub export fn random_forest_classifier_model_predict(
1428
+ handle: usize,
1429
+ x_ptr: [*]const f64,
1430
+ n_samples: usize,
1431
+ n_features: usize,
1432
+ out_labels_ptr: [*]u8,
1433
+ ) u8 {
1434
+ const model = asRandomForestClassifierModel(handle) orelse return 0;
1435
+ if (model.fitted_estimators == 0 or n_samples == 0 or n_features != model.n_features) {
1436
+ return 0;
1437
+ }
1438
+
1439
+ var i: usize = 0;
1440
+ while (i < n_samples) : (i += 1) {
1441
+ const row_offset = i * model.n_features;
1442
+ var positive_votes: usize = 0;
1443
+ var tree_index: usize = 0;
1444
+ while (tree_index < model.fitted_estimators) : (tree_index += 1) {
1445
+ const tree = asDecisionTreeModel(model.tree_handles[tree_index]) orelse continue;
1446
+ if (!tree.has_root) {
1447
+ continue;
1448
+ }
1449
+
1450
+ var node_index = tree.root_index;
1451
+ while (true) {
1452
+ const node = tree.nodes.items[node_index];
1453
+ if (node.is_leaf) {
1454
+ positive_votes += if (node.prediction == 1) 1 else 0;
1455
+ break;
1456
+ }
1457
+ const value = x_ptr[row_offset + node.feature_index];
1458
+ node_index = if (value <= node.threshold) node.left_index else node.right_index;
1459
+ }
1460
+ }
1461
+ out_labels_ptr[i] = if (positive_votes * 2 >= model.fitted_estimators) 1 else 0;
1462
+ }
1463
+
1464
+ return 1;
1465
+ }
1466
+
1246
1467
  pub export fn logistic_train_epoch(
1247
1468
  x_ptr: [*]const f64,
1248
1469
  y_ptr: [*]const f64,