bun-scikit 0.1.4 → 0.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/package.json +2 -1
- package/scripts/build-node-addon.ts +17 -1
- package/scripts/check-benchmark-health.ts +50 -5
- package/src/ensemble/RandomForestClassifier.ts +154 -8
- package/src/ensemble/RandomForestRegressor.ts +12 -8
- package/src/native/node-addon/bun_scikit_addon.cpp +158 -0
- package/src/native/zigKernels.ts +89 -0
- package/src/tree/DecisionTreeClassifier.ts +13 -1
- package/zig/kernels.zig +278 -57
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "bun-scikit",
|
|
3
|
-
"version": "0.1.4",
|
|
3
|
+
"version": "0.1.5",
|
|
4
4
|
"description": "A scikit-learn-inspired machine learning library for Bun/TypeScript.",
|
|
5
5
|
"license": "MIT",
|
|
6
6
|
"module": "index.ts",
|
|
@@ -69,6 +69,7 @@
|
|
|
69
69
|
"devDependencies": {
|
|
70
70
|
"@types/bun": "latest",
|
|
71
71
|
"node-addon-api": "^8.3.1",
|
|
72
|
+
"node-gyp": "^12.2.0",
|
|
72
73
|
"typedoc": "^0.28.14",
|
|
73
74
|
"typescript": "^5.9.2"
|
|
74
75
|
}
|
|
@@ -1,8 +1,24 @@
|
|
|
1
1
|
import { cp, mkdir } from "node:fs/promises";
|
|
2
|
+
import { createRequire } from "node:module";
|
|
2
3
|
import { resolve } from "node:path";
|
|
3
4
|
|
|
5
|
+
/**
 * Builds the argv used to run `node-gyp rebuild`.
 *
 * Resolution order:
 * 1. the script path found in `npm_config_node_gyp` (trimmed, when non-empty),
 * 2. `node-gyp/bin/node-gyp.js` resolved relative to this module,
 * 3. the plain `node-gyp` command name when that resolution throws.
 *
 * @returns The command followed by its arguments, ready to hand to a spawn call.
 */
function resolveNodeGypCommand(): string[] {
  let scriptPath = process.env.npm_config_node_gyp?.trim();

  if (!scriptPath) {
    try {
      scriptPath = createRequire(import.meta.url).resolve("node-gyp/bin/node-gyp.js");
    } catch {
      // No resolvable node-gyp script: fall back to the bare command name.
      return ["node-gyp", "rebuild"];
    }
  }

  return ["node", scriptPath, "rebuild"];
}
|
|
19
|
+
|
|
4
20
|
async function main(): Promise<void> {
|
|
5
|
-
const child = Bun.spawn(
|
|
21
|
+
const child = Bun.spawn(resolveNodeGypCommand(), {
|
|
6
22
|
stdout: "inherit",
|
|
7
23
|
stderr: "inherit",
|
|
8
24
|
});
|
|
@@ -80,6 +80,14 @@ interface BenchmarkSnapshot {
|
|
|
80
80
|
};
|
|
81
81
|
}
|
|
82
82
|
|
|
83
|
+
/**
 * Looks up the value supplied for a command-line flag in `Bun.argv`.
 *
 * Both `--flag value` and `--flag=value` spellings are accepted; the first
 * occurrence of either form wins.
 *
 * @param flag - Full flag name including its leading dashes, e.g. `--input`.
 * @returns The flag's value, or `null` when the flag is absent, is the last
 *   argument, has an empty inline value, or is immediately followed by
 *   another `--option` (so an option name is never mistaken for a value).
 */
function parseArgValue(flag: string): string | null {
  const inlinePrefix = `${flag}=`;
  const args = Bun.argv;

  for (let i = 0; i < args.length; i += 1) {
    const arg = args[i];
    if (arg === undefined) {
      continue;
    }

    if (arg === flag) {
      const next = args[i + 1];
      // A missing follower or another option means the flag carried no value.
      return next === undefined || next.startsWith("--") ? null : next;
    }

    if (arg.startsWith(inlinePrefix)) {
      const inlineValue = arg.slice(inlinePrefix.length);
      return inlineValue === "" ? null : inlineValue;
    }
  }

  return null;
}
|
|
90
|
+
|
|
83
91
|
function speedupThreshold(
|
|
84
92
|
envName: string,
|
|
85
93
|
defaultValue: number,
|
|
@@ -95,13 +103,18 @@ function speedupThreshold(
|
|
|
95
103
|
return parsed;
|
|
96
104
|
}
|
|
97
105
|
|
|
98
|
-
const
|
|
99
|
-
const
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
106
|
+
const inputPath = resolve(parseArgValue("--input") ?? "bench/results/heart-ci-current.json");
|
|
107
|
+
const baselinePath = resolve(
|
|
108
|
+
parseArgValue("--baseline") ?? process.env.BENCH_BASELINE_INPUT ?? "bench/results/heart-ci-latest.json",
|
|
109
|
+
);
|
|
110
|
+
const baselineInputEnabled = inputPath !== baselinePath;
|
|
103
111
|
|
|
104
112
|
const snapshot = JSON.parse(await readFile(inputPath, "utf-8")) as BenchmarkSnapshot;
|
|
113
|
+
const baselineSnapshot = baselineInputEnabled
|
|
114
|
+
? ((await readFile(baselinePath, "utf-8").then((raw) => JSON.parse(raw) as BenchmarkSnapshot).catch(
|
|
115
|
+
() => null,
|
|
116
|
+
)) as BenchmarkSnapshot | null)
|
|
117
|
+
: null;
|
|
105
118
|
|
|
106
119
|
const [bunRegression, sklearnRegression] = snapshot.suites.regression.results;
|
|
107
120
|
const [bunClassification, sklearnClassification] = snapshot.suites.classification.results;
|
|
@@ -136,6 +149,14 @@ const maxZigForestPredictSlowdownVsJs = speedupThreshold(
|
|
|
136
149
|
"BENCH_MAX_ZIG_FOREST_PREDICT_SLOWDOWN_VS_JS",
|
|
137
150
|
20,
|
|
138
151
|
);
|
|
152
|
+
const minZigTreeFitRetentionVsBaseline = speedupThreshold(
|
|
153
|
+
"BENCH_MIN_ZIG_TREE_FIT_RETENTION_VS_BASELINE",
|
|
154
|
+
0.9,
|
|
155
|
+
);
|
|
156
|
+
const minZigForestFitRetentionVsBaseline = speedupThreshold(
|
|
157
|
+
"BENCH_MIN_ZIG_FOREST_FIT_RETENTION_VS_BASELINE",
|
|
158
|
+
0.9,
|
|
159
|
+
);
|
|
139
160
|
|
|
140
161
|
for (const result of [
|
|
141
162
|
bunRegression,
|
|
@@ -296,6 +317,30 @@ if (snapshot.suites.treeBackendModes.enabled) {
|
|
|
296
317
|
`RandomForest zig predict slowdown too large vs js-fast: ${randomForestPredictSlowdown} > ${maxZigForestPredictSlowdownVsJs}.`,
|
|
297
318
|
);
|
|
298
319
|
}
|
|
320
|
+
|
|
321
|
+
if (baselineSnapshot?.suites?.treeBackendModes?.enabled) {
|
|
322
|
+
const [baselineDecisionTreeModes, baselineRandomForestModes] =
|
|
323
|
+
baselineSnapshot.suites.treeBackendModes.models;
|
|
324
|
+
if (baselineDecisionTreeModes && baselineRandomForestModes) {
|
|
325
|
+
const decisionTreeFitRetention =
|
|
326
|
+
decisionTreeModes.comparison.zigFitSpeedupVsJs /
|
|
327
|
+
baselineDecisionTreeModes.comparison.zigFitSpeedupVsJs;
|
|
328
|
+
const randomForestFitRetention =
|
|
329
|
+
randomForestModes.comparison.zigFitSpeedupVsJs /
|
|
330
|
+
baselineRandomForestModes.comparison.zigFitSpeedupVsJs;
|
|
331
|
+
|
|
332
|
+
if (decisionTreeFitRetention < minZigTreeFitRetentionVsBaseline) {
|
|
333
|
+
throw new Error(
|
|
334
|
+
`DecisionTree zig/js fit retention too low vs baseline: ${decisionTreeFitRetention} < ${minZigTreeFitRetentionVsBaseline}.`,
|
|
335
|
+
);
|
|
336
|
+
}
|
|
337
|
+
if (randomForestFitRetention < minZigForestFitRetentionVsBaseline) {
|
|
338
|
+
throw new Error(
|
|
339
|
+
`RandomForest zig/js fit retention too low vs baseline: ${randomForestFitRetention} < ${minZigForestFitRetentionVsBaseline}.`,
|
|
340
|
+
);
|
|
341
|
+
}
|
|
342
|
+
}
|
|
343
|
+
}
|
|
299
344
|
}
|
|
300
345
|
|
|
301
346
|
console.log("Benchmark comparison health checks passed.");
|
|
@@ -2,6 +2,7 @@ import type { ClassificationModel, Matrix, Vector } from "../types";
|
|
|
2
2
|
import { accuracyScore } from "../metrics/classification";
|
|
3
3
|
import { DecisionTreeClassifier, type MaxFeaturesOption } from "../tree/DecisionTreeClassifier";
|
|
4
4
|
import { assertFiniteVector, validateClassificationInputs } from "../utils/validation";
|
|
5
|
+
import { getZigKernels } from "../native/zigKernels";
|
|
5
6
|
|
|
6
7
|
export interface RandomForestClassifierOptions {
|
|
7
8
|
nEstimators?: number;
|
|
@@ -23,8 +24,18 @@ function mulberry32(seed: number): () => number {
|
|
|
23
24
|
};
|
|
24
25
|
}
|
|
25
26
|
|
|
27
|
+
/**
 * Interprets an environment-style string as a boolean switch.
 *
 * @param value - Raw value, typically read from `process.env`.
 * @returns `false` when the value is missing, blank (empty or whitespace
 *   only), or one of `0`, `false`, `off` (case-insensitive, surrounding
 *   whitespace ignored); `true` for anything else.
 */
function isTruthy(value: string | undefined): boolean {
  const normalized = value?.trim().toLowerCase() ?? "";
  // A blank value must not count as opting in.
  if (normalized === "") {
    return false;
  }
  return normalized !== "0" && normalized !== "false" && normalized !== "off";
}
|
|
34
|
+
|
|
26
35
|
export class RandomForestClassifier implements ClassificationModel {
|
|
27
36
|
classes_: Vector = [0, 1];
|
|
37
|
+
fitBackend_: "zig" | "js" = "js";
|
|
38
|
+
fitBackendLibrary_: string | null = null;
|
|
28
39
|
private readonly nEstimators: number;
|
|
29
40
|
private readonly maxDepth?: number;
|
|
30
41
|
private readonly minSamplesSplit?: number;
|
|
@@ -32,6 +43,7 @@ export class RandomForestClassifier implements ClassificationModel {
|
|
|
32
43
|
private readonly maxFeatures: MaxFeaturesOption;
|
|
33
44
|
private readonly bootstrap: boolean;
|
|
34
45
|
private readonly randomState?: number;
|
|
46
|
+
private nativeModelHandle: bigint | null = null;
|
|
35
47
|
private trees: DecisionTreeClassifier[] = [];
|
|
36
48
|
|
|
37
49
|
constructor(options: RandomForestClassifierOptions = {}) {
|
|
@@ -49,6 +61,7 @@ export class RandomForestClassifier implements ClassificationModel {
|
|
|
49
61
|
}
|
|
50
62
|
|
|
51
63
|
fit(X: Matrix, y: Vector): this {
|
|
64
|
+
this.disposeNativeModel();
|
|
52
65
|
validateClassificationInputs(X, y);
|
|
53
66
|
|
|
54
67
|
const sampleCount = X.length;
|
|
@@ -56,10 +69,17 @@ export class RandomForestClassifier implements ClassificationModel {
|
|
|
56
69
|
const random = this.randomState === undefined ? Math.random : mulberry32(this.randomState);
|
|
57
70
|
const flattenedX = this.flattenTrainingMatrix(X, sampleCount, featureCount);
|
|
58
71
|
const yBinary = this.buildBinaryTargets(y);
|
|
72
|
+
const sampleIndices = new Uint32Array(sampleCount);
|
|
73
|
+
this.trees = [];
|
|
74
|
+
if (this.tryFitNativeForest(flattenedX, yBinary, sampleCount, featureCount)) {
|
|
75
|
+
this.fitBackend_ = "zig";
|
|
76
|
+
return this;
|
|
77
|
+
}
|
|
78
|
+
this.fitBackend_ = "js";
|
|
79
|
+
this.fitBackendLibrary_ = null;
|
|
59
80
|
this.trees = new Array(this.nEstimators);
|
|
60
81
|
|
|
61
82
|
for (let estimatorIndex = 0; estimatorIndex < this.nEstimators; estimatorIndex += 1) {
|
|
62
|
-
const sampleIndices = new Uint32Array(sampleCount);
|
|
63
83
|
if (this.bootstrap) {
|
|
64
84
|
for (let i = 0; i < sampleCount; i += 1) {
|
|
65
85
|
sampleIndices[i] = Math.floor(random() * sampleCount);
|
|
@@ -86,20 +106,47 @@ export class RandomForestClassifier implements ClassificationModel {
|
|
|
86
106
|
}
|
|
87
107
|
|
|
88
108
|
predict(X: Matrix): Vector {
|
|
109
|
+
if (this.nativeModelHandle !== null) {
|
|
110
|
+
const kernels = getZigKernels();
|
|
111
|
+
const predict = kernels?.randomForestClassifierModelPredict;
|
|
112
|
+
if (predict) {
|
|
113
|
+
const sampleCount = X.length;
|
|
114
|
+
const featureCount = X[0]?.length ?? 0;
|
|
115
|
+
const flattened = this.flattenTrainingMatrix(X, sampleCount, featureCount);
|
|
116
|
+
const out = new Uint8Array(sampleCount);
|
|
117
|
+
const status = predict(
|
|
118
|
+
this.nativeModelHandle,
|
|
119
|
+
flattened,
|
|
120
|
+
sampleCount,
|
|
121
|
+
featureCount,
|
|
122
|
+
out,
|
|
123
|
+
);
|
|
124
|
+
if (status === 1) {
|
|
125
|
+
return Array.from(out);
|
|
126
|
+
}
|
|
127
|
+
}
|
|
128
|
+
}
|
|
129
|
+
|
|
89
130
|
if (this.trees.length === 0) {
|
|
90
131
|
throw new Error("RandomForestClassifier has not been fitted.");
|
|
91
132
|
}
|
|
92
133
|
|
|
93
|
-
const treePredictions = this.trees.map((tree) => tree.predict(X));
|
|
94
134
|
const sampleCount = X.length;
|
|
95
|
-
const
|
|
135
|
+
const voteCounts = new Uint16Array(sampleCount);
|
|
96
136
|
|
|
97
|
-
for (let
|
|
98
|
-
|
|
99
|
-
for (let
|
|
100
|
-
|
|
137
|
+
for (let treeIndex = 0; treeIndex < this.trees.length; treeIndex += 1) {
|
|
138
|
+
const treePrediction = this.trees[treeIndex].predict(X);
|
|
139
|
+
for (let sampleIndex = 0; sampleIndex < sampleCount; sampleIndex += 1) {
|
|
140
|
+
if (treePrediction[sampleIndex] === 1) {
|
|
141
|
+
voteCounts[sampleIndex] += 1;
|
|
142
|
+
}
|
|
101
143
|
}
|
|
102
|
-
|
|
144
|
+
}
|
|
145
|
+
|
|
146
|
+
const predictions = new Array<number>(sampleCount);
|
|
147
|
+
const voteThreshold = this.trees.length;
|
|
148
|
+
for (let sampleIndex = 0; sampleIndex < sampleCount; sampleIndex += 1) {
|
|
149
|
+
predictions[sampleIndex] = voteCounts[sampleIndex] * 2 >= voteThreshold ? 1 : 0;
|
|
103
150
|
}
|
|
104
151
|
|
|
105
152
|
return predictions;
|
|
@@ -110,6 +157,105 @@ export class RandomForestClassifier implements ClassificationModel {
|
|
|
110
157
|
return accuracyScore(y, this.predict(X));
|
|
111
158
|
}
|
|
112
159
|
|
|
160
|
+
/**
 * Releases every resource held by this forest: the native model handle (if
 * any) and each per-tree estimator. Afterwards the forest holds no trees.
 */
dispose(): void {
  this.disposeNativeModel();
  // `fit` pre-sizes `trees` with `new Array(nEstimators)`, so slots may still
  // be empty if fitting did not complete; skip those instead of throwing.
  for (const tree of this.trees) {
    tree?.dispose();
  }
  this.trees = [];
}
|
|
167
|
+
|
|
168
|
+
/**
 * Translates the `maxFeatures` option into the `(mode, value)` pair handed to
 * the native model constructor.
 *
 * - `null` / `undefined` -> mode 0, value 0
 * - `"sqrt"`             -> mode 1, value 0
 * - `"log2"`             -> mode 2, value 0
 * - anything else        -> mode 3; a finite number is floored and clamped to
 *   `[1, featureCount]`, a non-finite one becomes `featureCount`.
 *
 * @param featureCount - Number of feature columns in the training matrix.
 */
private resolveNativeMaxFeatures(featureCount: number): {
  mode: 0 | 1 | 2 | 3;
  value: number;
} {
  const requested = this.maxFeatures;

  switch (requested) {
    case null:
    case undefined:
      return { mode: 0, value: 0 };
    case "sqrt":
      return { mode: 1, value: 0 };
    case "log2":
      return { mode: 2, value: 0 };
    default: {
      if (!Number.isFinite(requested)) {
        return { mode: 3, value: featureCount };
      }
      const capped = Math.min(featureCount, Math.floor(requested));
      return { mode: 3, value: Math.max(1, capped) };
    }
  }
}
|
|
186
|
+
|
|
187
|
+
/**
 * Attempts to train the whole forest through the native kernels.
 *
 * The native path is used only when `BUN_SCIKIT_EXPERIMENTAL_NATIVE_FOREST`
 * is truthy, `BUN_SCIKIT_TREE_BACKEND` is `zig`, and the loaded kernels expose
 * the create/fit/destroy entry points. On success the handle is stored in
 * `nativeModelHandle` and `fitBackendLibrary_` records the library path.
 *
 * @param flattenedX - Training matrix passed to the native fit call.
 * @param yBinary - Labels passed to the native fit call.
 * @param sampleCount - Number of training rows.
 * @param featureCount - Number of feature columns.
 * @returns `true` when the native model was fitted and adopted, `false` when
 *   the native path is unavailable or reported failure.
 */
private tryFitNativeForest(
  flattenedX: Float64Array,
  yBinary: Uint8Array,
  sampleCount: number,
  featureCount: number,
): boolean {
  const optedIn = isTruthy(process.env.BUN_SCIKIT_EXPERIMENTAL_NATIVE_FOREST);
  const backend = process.env.BUN_SCIKIT_TREE_BACKEND?.trim().toLowerCase();
  if (!optedIn || backend !== "zig") {
    return false;
  }

  const kernels = getZigKernels();
  if (!kernels) {
    return false;
  }
  const {
    randomForestClassifierModelCreate: createModel,
    randomForestClassifierModelFit: fitModel,
    randomForestClassifierModelDestroy: destroyModel,
  } = kernels;
  if (!createModel || !fitModel || !destroyModel) {
    return false;
  }

  const maxFeatures = this.resolveNativeMaxFeatures(featureCount);
  const seed = (this.randomState ?? 0) >>> 0;
  const handle = createModel(
    this.nEstimators,
    this.maxDepth ?? 12,
    this.minSamplesSplit ?? 2,
    this.minSamplesLeaf ?? 1,
    maxFeatures.mode,
    maxFeatures.value,
    this.bootstrap ? 1 : 0,
    seed,
    this.randomState === undefined ? 0 : 1,
    featureCount,
  );
  if (handle === 0n) {
    return false;
  }

  // Until the model is adopted, the handle must be released on every exit path.
  let adopted = false;
  try {
    if (fitModel(handle, flattenedX, yBinary, sampleCount, featureCount) === 1) {
      this.nativeModelHandle = handle;
      this.fitBackendLibrary_ = kernels.libraryPath;
      adopted = true;
    }
    return adopted;
  } finally {
    if (!adopted) {
      destroyModel(handle);
    }
  }
}
|
|
242
|
+
|
|
243
|
+
/**
 * Releases the native forest model, if one is held, and clears the stored
 * handle. Errors raised by the native destroy call are ignored.
 */
private disposeNativeModel(): void {
  const handle = this.nativeModelHandle;
  if (handle !== null) {
    const release = getZigKernels()?.randomForestClassifierModelDestroy;
    if (release) {
      try {
        release(handle);
      } catch {
        // best effort cleanup
      }
    }
    this.nativeModelHandle = null;
  }
}
|
|
258
|
+
|
|
113
259
|
private flattenTrainingMatrix(
|
|
114
260
|
X: Matrix,
|
|
115
261
|
sampleCount: number,
|
|
@@ -56,10 +56,10 @@ export class RandomForestRegressor implements RegressionModel {
|
|
|
56
56
|
const random = this.randomState === undefined ? Math.random : mulberry32(this.randomState);
|
|
57
57
|
const flattenedX = this.flattenTrainingMatrix(X, sampleCount, featureCount);
|
|
58
58
|
const yValues = this.toFloat64Vector(y);
|
|
59
|
+
const sampleIndices = new Uint32Array(sampleCount);
|
|
59
60
|
this.trees = new Array(this.nEstimators);
|
|
60
61
|
|
|
61
62
|
for (let estimatorIndex = 0; estimatorIndex < this.nEstimators; estimatorIndex += 1) {
|
|
62
|
-
const sampleIndices = new Uint32Array(sampleCount);
|
|
63
63
|
if (this.bootstrap) {
|
|
64
64
|
for (let i = 0; i < sampleCount; i += 1) {
|
|
65
65
|
sampleIndices[i] = Math.floor(random() * sampleCount);
|
|
@@ -90,16 +90,20 @@ export class RandomForestRegressor implements RegressionModel {
|
|
|
90
90
|
throw new Error("RandomForestRegressor has not been fitted.");
|
|
91
91
|
}
|
|
92
92
|
|
|
93
|
-
const treePredictions = this.trees.map((tree) => tree.predict(X));
|
|
94
93
|
const sampleCount = X.length;
|
|
95
|
-
const
|
|
94
|
+
const sums = new Float64Array(sampleCount);
|
|
96
95
|
|
|
97
|
-
for (let
|
|
98
|
-
|
|
99
|
-
for (let
|
|
100
|
-
|
|
96
|
+
for (let treeIndex = 0; treeIndex < this.trees.length; treeIndex += 1) {
|
|
97
|
+
const treePrediction = this.trees[treeIndex].predict(X);
|
|
98
|
+
for (let sampleIndex = 0; sampleIndex < sampleCount; sampleIndex += 1) {
|
|
99
|
+
sums[sampleIndex] += treePrediction[sampleIndex];
|
|
101
100
|
}
|
|
102
|
-
|
|
101
|
+
}
|
|
102
|
+
|
|
103
|
+
const predictions = new Array<number>(sampleCount);
|
|
104
|
+
const denominator = this.trees.length;
|
|
105
|
+
for (let sampleIndex = 0; sampleIndex < sampleCount; sampleIndex += 1) {
|
|
106
|
+
predictions[sampleIndex] = sums[sampleIndex] / denominator;
|
|
103
107
|
}
|
|
104
108
|
|
|
105
109
|
return predictions;
|
|
@@ -31,6 +31,20 @@ using DecisionTreeModelCreateFn = NativeHandle (*)(std::size_t, std::size_t, std
|
|
|
31
31
|
using DecisionTreeModelDestroyFn = void (*)(NativeHandle);
|
|
32
32
|
using DecisionTreeModelFitFn = std::uint8_t (*)(NativeHandle, const double*, const std::uint8_t*, std::size_t, std::size_t, const std::uint32_t*, std::size_t);
|
|
33
33
|
using DecisionTreeModelPredictFn = std::uint8_t (*)(NativeHandle, const double*, std::size_t, std::size_t, std::uint8_t*);
|
|
34
|
+
using RandomForestClassifierModelCreateFn = NativeHandle (*)(
|
|
35
|
+
std::size_t,
|
|
36
|
+
std::size_t,
|
|
37
|
+
std::size_t,
|
|
38
|
+
std::size_t,
|
|
39
|
+
std::uint8_t,
|
|
40
|
+
std::size_t,
|
|
41
|
+
std::uint8_t,
|
|
42
|
+
std::uint32_t,
|
|
43
|
+
std::uint8_t,
|
|
44
|
+
std::size_t);
|
|
45
|
+
using RandomForestClassifierModelDestroyFn = void (*)(NativeHandle);
|
|
46
|
+
using RandomForestClassifierModelFitFn = std::uint8_t (*)(NativeHandle, const double*, const std::uint8_t*, std::size_t, std::size_t);
|
|
47
|
+
using RandomForestClassifierModelPredictFn = std::uint8_t (*)(NativeHandle, const double*, std::size_t, std::size_t, std::uint8_t*);
|
|
34
48
|
|
|
35
49
|
struct KernelLibrary {
|
|
36
50
|
#if defined(_WIN32)
|
|
@@ -55,6 +69,10 @@ struct KernelLibrary {
|
|
|
55
69
|
DecisionTreeModelDestroyFn decision_tree_model_destroy{nullptr};
|
|
56
70
|
DecisionTreeModelFitFn decision_tree_model_fit{nullptr};
|
|
57
71
|
DecisionTreeModelPredictFn decision_tree_model_predict{nullptr};
|
|
72
|
+
RandomForestClassifierModelCreateFn random_forest_classifier_model_create{nullptr};
|
|
73
|
+
RandomForestClassifierModelDestroyFn random_forest_classifier_model_destroy{nullptr};
|
|
74
|
+
RandomForestClassifierModelFitFn random_forest_classifier_model_fit{nullptr};
|
|
75
|
+
RandomForestClassifierModelPredictFn random_forest_classifier_model_predict{nullptr};
|
|
58
76
|
};
|
|
59
77
|
|
|
60
78
|
KernelLibrary g_library{};
|
|
@@ -154,6 +172,14 @@ Napi::Value LoadNativeLibrary(const Napi::CallbackInfo& info) {
|
|
|
154
172
|
loadSymbol<DecisionTreeModelFitFn>("decision_tree_model_fit");
|
|
155
173
|
g_library.decision_tree_model_predict =
|
|
156
174
|
loadSymbol<DecisionTreeModelPredictFn>("decision_tree_model_predict");
|
|
175
|
+
g_library.random_forest_classifier_model_create =
|
|
176
|
+
loadSymbol<RandomForestClassifierModelCreateFn>("random_forest_classifier_model_create");
|
|
177
|
+
g_library.random_forest_classifier_model_destroy =
|
|
178
|
+
loadSymbol<RandomForestClassifierModelDestroyFn>("random_forest_classifier_model_destroy");
|
|
179
|
+
g_library.random_forest_classifier_model_fit =
|
|
180
|
+
loadSymbol<RandomForestClassifierModelFitFn>("random_forest_classifier_model_fit");
|
|
181
|
+
g_library.random_forest_classifier_model_predict =
|
|
182
|
+
loadSymbol<RandomForestClassifierModelPredictFn>("random_forest_classifier_model_predict");
|
|
157
183
|
|
|
158
184
|
return Napi::Boolean::New(env, true);
|
|
159
185
|
}
|
|
@@ -567,6 +593,134 @@ Napi::Value DecisionTreeModelPredict(const Napi::CallbackInfo& info) {
|
|
|
567
593
|
return Napi::Number::New(env, status);
|
|
568
594
|
}
|
|
569
595
|
|
|
596
|
+
// N-API binding behind `randomForestClassifierModelCreate(...)`.
//
// Expects exactly ten JS numbers (nEstimators, maxDepth, minSamplesSplit,
// minSamplesLeaf, maxFeaturesMode, maxFeaturesValue, bootstrap, randomState,
// useRandomState, nFeatures). Each is read as a uint32 and narrowed to the
// parameter width of the loaded `random_forest_classifier_model_create`
// symbol. Returns the native handle as a BigInt; returns null after reporting
// through throwError/throwTypeError when the symbol is missing or the
// arguments are malformed, and null when isLibraryLoaded() fails.
Napi::Value RandomForestClassifierModelCreate(const Napi::CallbackInfo& info) {
  const Napi::Env env = info.Env();
  if (!isLibraryLoaded(env)) {
    return env.Null();
  }

  const RandomForestClassifierModelCreateFn create_fn =
      g_library.random_forest_classifier_model_create;
  if (create_fn == nullptr) {
    throwError(env, "Symbol random_forest_classifier_model_create is unavailable.");
    return env.Null();
  }

  constexpr std::size_t kArgCount = 10;
  bool args_ok = info.Length() == kArgCount;
  for (std::size_t i = 0; args_ok && i < kArgCount; ++i) {
    args_ok = info[i].IsNumber();
  }
  if (!args_ok) {
    throwTypeError(env, "randomForestClassifierModelCreate(nEstimators, maxDepth, minSamplesSplit, minSamplesLeaf, maxFeaturesMode, maxFeaturesValue, bootstrap, randomState, useRandomState, nFeatures) expects ten numbers.");
    return env.Null();
  }

  // Every argument crosses the boundary as a uint32 and is narrowed below.
  const auto arg_u32 = [&info](std::size_t index) -> std::uint32_t {
    return info[index].As<Napi::Number>().Uint32Value();
  };

  const NativeHandle handle = create_fn(
      static_cast<std::size_t>(arg_u32(0)),   // n_estimators
      static_cast<std::size_t>(arg_u32(1)),   // max_depth
      static_cast<std::size_t>(arg_u32(2)),   // min_samples_split
      static_cast<std::size_t>(arg_u32(3)),   // min_samples_leaf
      static_cast<std::uint8_t>(arg_u32(4)),  // max_features_mode
      static_cast<std::size_t>(arg_u32(5)),   // max_features_value
      static_cast<std::uint8_t>(arg_u32(6)),  // bootstrap
      static_cast<std::uint32_t>(arg_u32(7)), // random_state
      static_cast<std::uint8_t>(arg_u32(8)),  // use_random_state
      static_cast<std::size_t>(arg_u32(9)));  // n_features
  return Napi::BigInt::New(env, static_cast<std::uint64_t>(handle));
}
|
|
636
|
+
|
|
637
|
+
// N-API binding behind `randomForestClassifierModelDestroy(handle)`.
//
// Converts the single argument to a native handle via handleFromBigInt() and
// passes it to the loaded `random_forest_classifier_model_destroy` symbol.
// Returns undefined after the native call; returns null on every early-exit
// path (library check failed, symbol missing, wrong argument count, or an
// exception pending after handle conversion).
Napi::Value RandomForestClassifierModelDestroy(const Napi::CallbackInfo& info) {
  const Napi::Env env = info.Env();
  if (!isLibraryLoaded(env)) {
    return env.Null();
  }

  const RandomForestClassifierModelDestroyFn destroy_fn =
      g_library.random_forest_classifier_model_destroy;
  if (destroy_fn == nullptr) {
    throwError(env, "Symbol random_forest_classifier_model_destroy is unavailable.");
    return env.Null();
  }

  if (info.Length() != 1) {
    throwTypeError(env, "randomForestClassifierModelDestroy(handle) expects one BigInt.");
    return env.Null();
  }

  const NativeHandle handle = handleFromBigInt(info[0], env);
  if (!env.IsExceptionPending()) {
    destroy_fn(handle);
    return env.Undefined();
  }
  return env.Null();
}
|
|
657
|
+
|
|
658
|
+
// N-API binding behind `randomForestClassifierModelFit(handle, x, y, nSamples, nFeatures)`.
//
// Arguments:
//   handle    - BigInt native model handle (converted by handleFromBigInt()).
//   x         - Float64Array holding at least nSamples * nFeatures values.
//   y         - Uint8Array holding at least nSamples labels.
//   nSamples  - number, read as uint32.
//   nFeatures - number, read as uint32.
//
// Returns the status byte produced by the loaded
// `random_forest_classifier_model_fit` symbol as a JS number. Returns null
// after reporting through throwError/throwTypeError when the symbol is
// missing, the arguments have the wrong kind, or the buffers are too small
// for the declared dimensions.
Napi::Value RandomForestClassifierModelFit(const Napi::CallbackInfo& info) {
  const Napi::Env env = info.Env();
  if (!isLibraryLoaded(env)) {
    return env.Null();
  }
  if (!g_library.random_forest_classifier_model_fit) {
    throwError(env, "Symbol random_forest_classifier_model_fit is unavailable.");
    return env.Null();
  }
  if (info.Length() != 5 || !info[1].IsTypedArray() || !info[2].IsTypedArray() ||
      !info[3].IsNumber() || !info[4].IsNumber()) {
    throwTypeError(env, "randomForestClassifierModelFit(handle, x, y, nSamples, nFeatures) has invalid arguments.");
    return env.Null();
  }

  // IsTypedArray() accepts any element type; the native routine reinterprets
  // the raw buffers, so the exact element types must be enforced here.
  if (info[1].As<Napi::TypedArray>().TypedArrayType() != napi_float64_array ||
      info[2].As<Napi::TypedArray>().TypedArrayType() != napi_uint8_array) {
    throwTypeError(env, "randomForestClassifierModelFit(handle, x, y, nSamples, nFeatures) expects x as Float64Array and y as Uint8Array.");
    return env.Null();
  }

  const NativeHandle handle = handleFromBigInt(info[0], env);
  if (env.IsExceptionPending()) {
    return env.Null();
  }
  auto x = info[1].As<Napi::Float64Array>();
  auto y = info[2].As<Napi::Uint8Array>();
  const std::size_t n_samples = static_cast<std::size_t>(info[3].As<Napi::Number>().Uint32Value());
  const std::size_t n_features = static_cast<std::size_t>(info[4].As<Napi::Number>().Uint32Value());

  // Reject buffers shorter than the declared dimensions so the native code
  // cannot read past their end. Division keeps the comparison overflow-free:
  // floor(len / f) < s  <=>  len < s * f.
  const bool x_too_short = n_features != 0 && x.ElementLength() / n_features < n_samples;
  if (x_too_short || y.ElementLength() < n_samples) {
    throwError(env, "randomForestClassifierModelFit: x must hold nSamples * nFeatures values and y must hold nSamples labels.");
    return env.Null();
  }

  const std::uint8_t status = g_library.random_forest_classifier_model_fit(
      handle,
      x.Data(),
      y.Data(),
      n_samples,
      n_features);
  return Napi::Number::New(env, status);
}
|
|
690
|
+
|
|
691
|
+
// N-API binding behind `randomForestClassifierModelPredict(handle, x, nSamples, nFeatures, outLabels)`.
//
// Arguments:
//   handle    - BigInt native model handle (converted by handleFromBigInt()).
//   x         - Float64Array holding at least nSamples * nFeatures values.
//   nSamples  - number, read as uint32.
//   nFeatures - number, read as uint32.
//   outLabels - Uint8Array with room for at least nSamples labels; its buffer
//               is handed to the native routine as the output pointer.
//
// Returns the status byte produced by the loaded
// `random_forest_classifier_model_predict` symbol as a JS number. Returns null
// after reporting through throwError/throwTypeError when the symbol is
// missing, the arguments have the wrong kind, or the buffers are too small
// for the declared dimensions.
Napi::Value RandomForestClassifierModelPredict(const Napi::CallbackInfo& info) {
  const Napi::Env env = info.Env();
  if (!isLibraryLoaded(env)) {
    return env.Null();
  }
  if (!g_library.random_forest_classifier_model_predict) {
    throwError(env, "Symbol random_forest_classifier_model_predict is unavailable.");
    return env.Null();
  }
  if (info.Length() != 5 || !info[1].IsTypedArray() || !info[2].IsNumber() || !info[3].IsNumber() ||
      !info[4].IsTypedArray()) {
    throwTypeError(env, "randomForestClassifierModelPredict(handle, x, nSamples, nFeatures, outLabels) has invalid arguments.");
    return env.Null();
  }

  // IsTypedArray() accepts any element type; the native routine reinterprets
  // the raw buffers, so the exact element types must be enforced here.
  if (info[1].As<Napi::TypedArray>().TypedArrayType() != napi_float64_array ||
      info[4].As<Napi::TypedArray>().TypedArrayType() != napi_uint8_array) {
    throwTypeError(env, "randomForestClassifierModelPredict(handle, x, nSamples, nFeatures, outLabels) expects x as Float64Array and outLabels as Uint8Array.");
    return env.Null();
  }

  const NativeHandle handle = handleFromBigInt(info[0], env);
  if (env.IsExceptionPending()) {
    return env.Null();
  }
  auto x = info[1].As<Napi::Float64Array>();
  const std::size_t n_samples = static_cast<std::size_t>(info[2].As<Napi::Number>().Uint32Value());
  const std::size_t n_features = static_cast<std::size_t>(info[3].As<Napi::Number>().Uint32Value());
  auto out_labels = info[4].As<Napi::Uint8Array>();

  // Reject undersized buffers so the native code can neither read past `x`
  // nor write past `outLabels`. Division keeps the comparison overflow-free:
  // floor(len / f) < s  <=>  len < s * f.
  const bool x_too_short = n_features != 0 && x.ElementLength() / n_features < n_samples;
  if (x_too_short || out_labels.ElementLength() < n_samples) {
    throwError(env, "randomForestClassifierModelPredict: x must hold nSamples * nFeatures values and outLabels must hold nSamples entries.");
    return env.Null();
  }

  const std::uint8_t status = g_library.random_forest_classifier_model_predict(
      handle,
      x.Data(),
      n_samples,
      n_features,
      out_labels.Data());
  return Napi::Number::New(env, status);
}
|
|
723
|
+
|
|
570
724
|
Napi::Object Init(Napi::Env env, Napi::Object exports) {
|
|
571
725
|
exports.Set("loadLibrary", Napi::Function::New(env, LoadNativeLibrary));
|
|
572
726
|
exports.Set("unloadLibrary", Napi::Function::New(env, UnloadLibrary));
|
|
@@ -590,6 +744,10 @@ Napi::Object Init(Napi::Env env, Napi::Object exports) {
|
|
|
590
744
|
exports.Set("decisionTreeModelDestroy", Napi::Function::New(env, DecisionTreeModelDestroy));
|
|
591
745
|
exports.Set("decisionTreeModelFit", Napi::Function::New(env, DecisionTreeModelFit));
|
|
592
746
|
exports.Set("decisionTreeModelPredict", Napi::Function::New(env, DecisionTreeModelPredict));
|
|
747
|
+
exports.Set("randomForestClassifierModelCreate", Napi::Function::New(env, RandomForestClassifierModelCreate));
|
|
748
|
+
exports.Set("randomForestClassifierModelDestroy", Napi::Function::New(env, RandomForestClassifierModelDestroy));
|
|
749
|
+
exports.Set("randomForestClassifierModelFit", Napi::Function::New(env, RandomForestClassifierModelFit));
|
|
750
|
+
exports.Set("randomForestClassifierModelPredict", Napi::Function::New(env, RandomForestClassifierModelPredict));
|
|
593
751
|
|
|
594
752
|
return exports;
|
|
595
753
|
}
|
package/src/native/zigKernels.ts
CHANGED
|
@@ -88,6 +88,33 @@ type DecisionTreeModelPredictFn = (
|
|
|
88
88
|
nFeatures: number,
|
|
89
89
|
outLabels: Uint8Array,
|
|
90
90
|
) => number;
|
|
91
|
+
/**
 * Creates a native random-forest classifier model and returns its handle.
 * `bootstrap` and `useRandomState` are passed as 0/1 flags, and
 * `maxFeaturesMode`/`maxFeaturesValue` carry the encoded `maxFeatures` option.
 * `RandomForestClassifier` treats a `0n` handle as a failed creation.
 */
type RandomForestClassifierModelCreateFn = (
  nEstimators: number,
  maxDepth: number,
  minSamplesSplit: number,
  minSamplesLeaf: number,
  maxFeaturesMode: number,
  maxFeaturesValue: number,
  bootstrap: number,
  randomState: number,
  useRandomState: number,
  nFeatures: number,
) => NativeHandle;

/** Releases a native model previously returned by the create function. */
type RandomForestClassifierModelDestroyFn = (handle: NativeHandle) => void;

/**
 * Trains the native model on `x` and `y` for `nSamples` rows of `nFeatures`
 * columns. Returns a status code; `RandomForestClassifier` treats `1` as
 * success.
 */
type RandomForestClassifierModelFitFn = (
  handle: NativeHandle,
  x: Float64Array,
  y: Uint8Array,
  nSamples: number,
  nFeatures: number,
) => number;

/**
 * Writes predicted labels for `nSamples` rows of `x` into `outLabels`.
 * Returns a status code; `RandomForestClassifier` treats `1` as success.
 */
type RandomForestClassifierModelPredictFn = (
  handle: NativeHandle,
  x: Float64Array,
  nSamples: number,
  nFeatures: number,
  outLabels: Uint8Array,
) => number;
|
|
91
118
|
|
|
92
119
|
type LogisticTrainEpochFn = (
|
|
93
120
|
x: Float64Array,
|
|
@@ -138,6 +165,10 @@ interface ZigKernelLibrary {
|
|
|
138
165
|
decision_tree_model_destroy?: DecisionTreeModelDestroyFn;
|
|
139
166
|
decision_tree_model_fit?: DecisionTreeModelFitFn;
|
|
140
167
|
decision_tree_model_predict?: DecisionTreeModelPredictFn;
|
|
168
|
+
random_forest_classifier_model_create?: RandomForestClassifierModelCreateFn;
|
|
169
|
+
random_forest_classifier_model_destroy?: RandomForestClassifierModelDestroyFn;
|
|
170
|
+
random_forest_classifier_model_fit?: RandomForestClassifierModelFitFn;
|
|
171
|
+
random_forest_classifier_model_predict?: RandomForestClassifierModelPredictFn;
|
|
141
172
|
logistic_train_epoch?: LogisticTrainEpochFn;
|
|
142
173
|
logistic_train_epochs?: LogisticTrainEpochsFn;
|
|
143
174
|
};
|
|
@@ -162,6 +193,10 @@ export interface ZigKernels {
|
|
|
162
193
|
decisionTreeModelDestroy: DecisionTreeModelDestroyFn | null;
|
|
163
194
|
decisionTreeModelFit: DecisionTreeModelFitFn | null;
|
|
164
195
|
decisionTreeModelPredict: DecisionTreeModelPredictFn | null;
|
|
196
|
+
randomForestClassifierModelCreate: RandomForestClassifierModelCreateFn | null;
|
|
197
|
+
randomForestClassifierModelDestroy: RandomForestClassifierModelDestroyFn | null;
|
|
198
|
+
randomForestClassifierModelFit: RandomForestClassifierModelFitFn | null;
|
|
199
|
+
randomForestClassifierModelPredict: RandomForestClassifierModelPredictFn | null;
|
|
165
200
|
logisticTrainEpoch: LogisticTrainEpochFn | null;
|
|
166
201
|
logisticTrainEpochs: LogisticTrainEpochsFn | null;
|
|
167
202
|
abiVersion: number | null;
|
|
@@ -247,6 +282,10 @@ interface NodeApiAddon {
|
|
|
247
282
|
decisionTreeModelDestroy?: DecisionTreeModelDestroyFn;
|
|
248
283
|
decisionTreeModelFit?: DecisionTreeModelFitFn;
|
|
249
284
|
decisionTreeModelPredict?: DecisionTreeModelPredictFn;
|
|
285
|
+
randomForestClassifierModelCreate?: RandomForestClassifierModelCreateFn;
|
|
286
|
+
randomForestClassifierModelDestroy?: RandomForestClassifierModelDestroyFn;
|
|
287
|
+
randomForestClassifierModelFit?: RandomForestClassifierModelFitFn;
|
|
288
|
+
randomForestClassifierModelPredict?: RandomForestClassifierModelPredictFn;
|
|
250
289
|
}
|
|
251
290
|
|
|
252
291
|
function tryLoadNodeApiKernels(): ZigKernels | null {
|
|
@@ -289,6 +328,13 @@ function tryLoadNodeApiKernels(): ZigKernels | null {
|
|
|
289
328
|
decisionTreeModelDestroy: addon.decisionTreeModelDestroy ?? null,
|
|
290
329
|
decisionTreeModelFit: addon.decisionTreeModelFit ?? null,
|
|
291
330
|
decisionTreeModelPredict: addon.decisionTreeModelPredict ?? null,
|
|
331
|
+
randomForestClassifierModelCreate:
|
|
332
|
+
addon.randomForestClassifierModelCreate ?? null,
|
|
333
|
+
randomForestClassifierModelDestroy:
|
|
334
|
+
addon.randomForestClassifierModelDestroy ?? null,
|
|
335
|
+
randomForestClassifierModelFit: addon.randomForestClassifierModelFit ?? null,
|
|
336
|
+
randomForestClassifierModelPredict:
|
|
337
|
+
addon.randomForestClassifierModelPredict ?? null,
|
|
292
338
|
logisticTrainEpoch: null,
|
|
293
339
|
logisticTrainEpochs: null,
|
|
294
340
|
abiVersion,
|
|
@@ -432,6 +478,33 @@ export function getZigKernels(): ZigKernels | null {
|
|
|
432
478
|
args: ["usize", FFIType.ptr, "usize", "usize", FFIType.ptr],
|
|
433
479
|
returns: FFIType.u8,
|
|
434
480
|
},
|
|
481
|
+
random_forest_classifier_model_create: {
|
|
482
|
+
args: [
|
|
483
|
+
"usize",
|
|
484
|
+
"usize",
|
|
485
|
+
"usize",
|
|
486
|
+
"usize",
|
|
487
|
+
FFIType.u8,
|
|
488
|
+
"usize",
|
|
489
|
+
FFIType.u8,
|
|
490
|
+
FFIType.u32,
|
|
491
|
+
FFIType.u8,
|
|
492
|
+
"usize",
|
|
493
|
+
],
|
|
494
|
+
returns: "usize",
|
|
495
|
+
},
|
|
496
|
+
random_forest_classifier_model_destroy: {
|
|
497
|
+
args: ["usize"],
|
|
498
|
+
returns: FFIType.void,
|
|
499
|
+
},
|
|
500
|
+
random_forest_classifier_model_fit: {
|
|
501
|
+
args: ["usize", FFIType.ptr, FFIType.ptr, "usize", "usize"],
|
|
502
|
+
returns: FFIType.u8,
|
|
503
|
+
},
|
|
504
|
+
random_forest_classifier_model_predict: {
|
|
505
|
+
args: ["usize", FFIType.ptr, "usize", "usize", FFIType.ptr],
|
|
506
|
+
returns: FFIType.u8,
|
|
507
|
+
},
|
|
435
508
|
logistic_train_epoch: {
|
|
436
509
|
args: [
|
|
437
510
|
FFIType.ptr,
|
|
@@ -492,6 +565,14 @@ export function getZigKernels(): ZigKernels | null {
|
|
|
492
565
|
decisionTreeModelDestroy: library.symbols.decision_tree_model_destroy ?? null,
|
|
493
566
|
decisionTreeModelFit: library.symbols.decision_tree_model_fit ?? null,
|
|
494
567
|
decisionTreeModelPredict: library.symbols.decision_tree_model_predict ?? null,
|
|
568
|
+
randomForestClassifierModelCreate:
|
|
569
|
+
library.symbols.random_forest_classifier_model_create ?? null,
|
|
570
|
+
randomForestClassifierModelDestroy:
|
|
571
|
+
library.symbols.random_forest_classifier_model_destroy ?? null,
|
|
572
|
+
randomForestClassifierModelFit:
|
|
573
|
+
library.symbols.random_forest_classifier_model_fit ?? null,
|
|
574
|
+
randomForestClassifierModelPredict:
|
|
575
|
+
library.symbols.random_forest_classifier_model_predict ?? null,
|
|
495
576
|
logisticTrainEpoch: library.symbols.logistic_train_epoch ?? null,
|
|
496
577
|
logisticTrainEpochs: library.symbols.logistic_train_epochs ?? null,
|
|
497
578
|
abiVersion,
|
|
@@ -555,6 +636,10 @@ export function getZigKernels(): ZigKernels | null {
|
|
|
555
636
|
decisionTreeModelDestroy: null,
|
|
556
637
|
decisionTreeModelFit: null,
|
|
557
638
|
decisionTreeModelPredict: null,
|
|
639
|
+
randomForestClassifierModelCreate: null,
|
|
640
|
+
randomForestClassifierModelDestroy: null,
|
|
641
|
+
randomForestClassifierModelFit: null,
|
|
642
|
+
randomForestClassifierModelPredict: null,
|
|
558
643
|
logisticTrainEpoch: library.symbols.logistic_train_epoch ?? null,
|
|
559
644
|
logisticTrainEpochs: library.symbols.logistic_train_epochs ?? null,
|
|
560
645
|
abiVersion: null,
|
|
@@ -600,6 +685,10 @@ export function getZigKernels(): ZigKernels | null {
|
|
|
600
685
|
decisionTreeModelDestroy: null,
|
|
601
686
|
decisionTreeModelFit: null,
|
|
602
687
|
decisionTreeModelPredict: null,
|
|
688
|
+
randomForestClassifierModelCreate: null,
|
|
689
|
+
randomForestClassifierModelDestroy: null,
|
|
690
|
+
randomForestClassifierModelFit: null,
|
|
691
|
+
randomForestClassifierModelPredict: null,
|
|
603
692
|
logisticTrainEpoch: library.symbols.logistic_train_epoch ?? null,
|
|
604
693
|
logisticTrainEpochs: null,
|
|
605
694
|
abiVersion: null,
|
|
@@ -185,7 +185,12 @@ export class DecisionTreeClassifier implements ClassificationModel {
|
|
|
185
185
|
}
|
|
186
186
|
}
|
|
187
187
|
|
|
188
|
-
|
|
188
|
+
const predictions = new Array<number>(X.length);
|
|
189
|
+
const root = this.root!;
|
|
190
|
+
for (let i = 0; i < X.length; i += 1) {
|
|
191
|
+
predictions[i] = this.predictOne(X[i], root);
|
|
192
|
+
}
|
|
193
|
+
return predictions;
|
|
189
194
|
}
|
|
190
195
|
|
|
191
196
|
score(X: Matrix, y: Vector): number {
|
|
@@ -193,6 +198,13 @@ export class DecisionTreeClassifier implements ClassificationModel {
|
|
|
193
198
|
return accuracyScore(y, this.predict(X));
|
|
194
199
|
}
|
|
195
200
|
|
|
201
|
+
/**
 * Releases what this classifier holds after fitting: calls `destroyZigModel()`
 * first, then clears the fitted tree root and the cached training buffers.
 */
dispose(): void {
  this.destroyZigModel();

  this.yBinaryTrain = null;
  this.flattenedXTrain = null;
  this.root = null;
}
|
|
207
|
+
|
|
196
208
|
private predictOne(sample: Vector, node: TreeNode): 0 | 1 {
|
|
197
209
|
let current: TreeNode = node;
|
|
198
210
|
while (
|
package/zig/kernels.zig
CHANGED
|
@@ -74,12 +74,31 @@ const DecisionTreeModel = struct {
|
|
|
74
74
|
use_random_state: bool,
|
|
75
75
|
root_index: usize,
|
|
76
76
|
has_root: bool,
|
|
77
|
+
feature_scratch: []usize,
|
|
77
78
|
nodes: std.ArrayListUnmanaged(TreeNode),
|
|
78
79
|
};
|
|
79
80
|
|
|
80
|
-
const
|
|
81
|
+
/// State of a random forest classifier; instances are created by
/// `random_forest_classifier_model_create` and referenced through an integer handle.
const RandomForestClassifierModel = struct {
    /// Feature count per sample; fit and predict reject calls whose `n_features` differs.
    n_features: usize,
    /// Number of trees built by fit; also the length of `tree_handles`.
    n_estimators: usize,
    /// Forwarded unchanged to every tree created during fit.
    max_depth: usize,
    /// Forwarded to every tree; create raises values below 2 to 2.
    min_samples_split: usize,
    /// Forwarded to every tree; create raises values below 1 to 1.
    min_samples_leaf: usize,
    /// Forwarded unchanged to every tree together with `max_features_value`.
    max_features_mode: u8,
    max_features_value: usize,
    /// When true, fit fills each tree's sample indices with `n_samples` random draws;
    /// otherwise every tree receives the identity index list.
    bootstrap: bool,
    /// Seed used by fit when `use_random_state` is true.
    random_state: u32,
    /// When false, fit seeds its generator from the current timestamp instead.
    use_random_state: bool,
    /// Decision-tree handles, one slot per estimator; 0 marks an empty slot.
    tree_handles: []usize,
    /// Number of leading `tree_handles` slots that currently hold a fitted tree.
    fitted_estimators: usize,
};
|
|
95
|
+
|
|
96
|
+
/// Result of scoring one feature as a split candidate (see `findBestSplitForFeature`).
const SplitEvaluation = struct {
    /// Threshold chosen for the feature; samples with value <= threshold go left.
    threshold: f64,
    /// Impurity of the split; candidates with lower impurity are preferred.
    impurity: f64,
};
|
|
100
|
+
|
|
101
|
+
/// Sample indices on each side of a chosen split (see `partitionIndicesForThreshold`).
const SplitPartition = struct {
    /// Indices of samples whose feature value is <= the threshold.
    left_indices: []usize,
    /// Indices of the remaining samples.
    right_indices: []usize,
};
|
|
@@ -167,41 +186,33 @@ fn resolveMaxFeatures(model: *const DecisionTreeModel) usize {
|
|
|
167
186
|
}
|
|
168
187
|
}
|
|
169
188
|
|
|
170
|
-
fn
|
|
171
|
-
|
|
172
|
-
|
|
189
|
+
/// Converts an opaque handle back into a forest pointer; the zero handle maps to null.
inline fn asRandomForestClassifierModel(handle: usize) ?*RandomForestClassifierModel {
    return if (handle != 0)
        @as(*RandomForestClassifierModel, @ptrFromInt(handle))
    else
        null;
}
|
|
174
195
|
|
|
175
|
-
fn selectCandidateFeatures(model: *
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
const all_features = try allocator.alloc(usize, model.n_features);
|
|
179
|
-
errdefer allocator.free(all_features);
|
|
180
|
-
for (all_features, 0..) |*entry, idx| {
|
|
181
|
-
entry.* = idx;
|
|
182
|
-
}
|
|
183
|
-
return all_features;
|
|
196
|
+
/// Chooses the feature indices to evaluate for one split.
///
/// Works in place on `model.feature_scratch`, so no allocation happens per node.
/// The returned slice aliases that buffer and is overwritten by the next call.
/// When `resolveMaxFeatures` asks for at least `n_features`, all features are
/// returned in ascending order; otherwise the first `k` slots are filled by a
/// partial Fisher-Yates shuffle driven by `rng`.
fn selectCandidateFeatures(model: *DecisionTreeModel, rng: *Mulberry32) []const usize {
    const scratch = model.feature_scratch;

    // Start every call from the identity permutation.
    for (scratch, 0..) |*slot, feature| {
        slot.* = feature;
    }

    const wanted = resolveMaxFeatures(model);
    if (wanted >= model.n_features) {
        return scratch[0..model.n_features];
    }

    for (0..wanted) |pos| {
        // Pick uniformly among the not-yet-fixed slots [pos, n_features).
        const pick = pos + rng.nextIndex(model.n_features - pos);
        std.mem.swap(usize, &scratch[pos], &scratch[pick]);
    }

    return scratch[0..wanted];
}
|
|
206
217
|
|
|
207
218
|
fn findBestSplitForFeature(
|
|
@@ -210,7 +221,7 @@ fn findBestSplitForFeature(
|
|
|
210
221
|
y_ptr: [*]const u8,
|
|
211
222
|
indices: []const usize,
|
|
212
223
|
feature_index: usize,
|
|
213
|
-
)
|
|
224
|
+
) ?SplitEvaluation {
|
|
214
225
|
const sample_count = indices.len;
|
|
215
226
|
if (sample_count < 2) {
|
|
216
227
|
return null;
|
|
@@ -282,29 +293,41 @@ fn findBestSplitForFeature(
|
|
|
282
293
|
return null;
|
|
283
294
|
}
|
|
284
295
|
|
|
285
|
-
|
|
296
|
+
return SplitEvaluation{
|
|
297
|
+
.threshold = best_threshold,
|
|
298
|
+
.impurity = best_impurity,
|
|
299
|
+
};
|
|
300
|
+
}
|
|
301
|
+
|
|
302
|
+
fn partitionIndicesForThreshold(
|
|
303
|
+
model: *const DecisionTreeModel,
|
|
304
|
+
workspace: std.mem.Allocator,
|
|
305
|
+
x_ptr: [*]const f64,
|
|
306
|
+
indices: []const usize,
|
|
307
|
+
feature_index: usize,
|
|
308
|
+
threshold: f64,
|
|
309
|
+
) !?SplitPartition {
|
|
310
|
+
var left_count: usize = 0;
|
|
286
311
|
for (indices) |sample_index| {
|
|
287
312
|
const value = x_ptr[sample_index * model.n_features + feature_index];
|
|
288
|
-
if (value <=
|
|
289
|
-
|
|
313
|
+
if (value <= threshold) {
|
|
314
|
+
left_count += 1;
|
|
290
315
|
}
|
|
291
316
|
}
|
|
292
317
|
|
|
293
|
-
const
|
|
294
|
-
if (
|
|
318
|
+
const right_count = indices.len - left_count;
|
|
319
|
+
if (left_count < model.min_samples_leaf or right_count < model.min_samples_leaf) {
|
|
295
320
|
return null;
|
|
296
321
|
}
|
|
297
322
|
|
|
298
|
-
const left_indices = try
|
|
299
|
-
|
|
300
|
-
const right_indices = try allocator.alloc(usize, right_partition_count);
|
|
301
|
-
errdefer allocator.free(right_indices);
|
|
323
|
+
const left_indices = try workspace.alloc(usize, left_count);
|
|
324
|
+
const right_indices = try workspace.alloc(usize, right_count);
|
|
302
325
|
|
|
303
326
|
var left_write: usize = 0;
|
|
304
327
|
var right_write: usize = 0;
|
|
305
328
|
for (indices) |sample_index| {
|
|
306
329
|
const value = x_ptr[sample_index * model.n_features + feature_index];
|
|
307
|
-
if (value <=
|
|
330
|
+
if (value <= threshold) {
|
|
308
331
|
left_indices[left_write] = sample_index;
|
|
309
332
|
left_write += 1;
|
|
310
333
|
} else {
|
|
@@ -313,9 +336,7 @@ fn findBestSplitForFeature(
|
|
|
313
336
|
}
|
|
314
337
|
}
|
|
315
338
|
|
|
316
|
-
return
|
|
317
|
-
.threshold = best_threshold,
|
|
318
|
-
.impurity = best_impurity,
|
|
339
|
+
return SplitPartition{
|
|
319
340
|
.left_indices = left_indices,
|
|
320
341
|
.right_indices = right_indices,
|
|
321
342
|
};
|
|
@@ -323,6 +344,7 @@ fn findBestSplitForFeature(
|
|
|
323
344
|
|
|
324
345
|
fn buildDecisionTreeNode(
|
|
325
346
|
model: *DecisionTreeModel,
|
|
347
|
+
workspace: std.mem.Allocator,
|
|
326
348
|
x_ptr: [*]const f64,
|
|
327
349
|
y_ptr: [*]const u8,
|
|
328
350
|
indices: []const usize,
|
|
@@ -353,25 +375,19 @@ fn buildDecisionTreeNode(
|
|
|
353
375
|
}
|
|
354
376
|
|
|
355
377
|
const parent_impurity = giniImpurity(positive_count, sample_count);
|
|
356
|
-
const candidate_features =
|
|
357
|
-
defer allocator.free(candidate_features);
|
|
378
|
+
const candidate_features = selectCandidateFeatures(model, rng);
|
|
358
379
|
|
|
359
380
|
var best_feature: usize = 0;
|
|
360
|
-
var best_split: ?
|
|
381
|
+
var best_split: ?SplitEvaluation = null;
|
|
361
382
|
var best_found = false;
|
|
362
383
|
|
|
363
384
|
for (candidate_features) |feature_index| {
|
|
364
|
-
const split_opt =
|
|
385
|
+
const split_opt = findBestSplitForFeature(model, x_ptr, y_ptr, indices, feature_index);
|
|
365
386
|
if (split_opt) |split| {
|
|
366
387
|
if (!best_found or split.impurity < best_split.?.impurity) {
|
|
367
|
-
if (best_split) |previous| {
|
|
368
|
-
freeSplit(previous);
|
|
369
|
-
}
|
|
370
388
|
best_split = split;
|
|
371
389
|
best_feature = feature_index;
|
|
372
390
|
best_found = true;
|
|
373
|
-
} else {
|
|
374
|
-
freeSplit(split);
|
|
375
391
|
}
|
|
376
392
|
}
|
|
377
393
|
}
|
|
@@ -390,7 +406,6 @@ fn buildDecisionTreeNode(
|
|
|
390
406
|
}
|
|
391
407
|
|
|
392
408
|
const split = best_split.?;
|
|
393
|
-
defer freeSplit(split);
|
|
394
409
|
if (split.impurity >= parent_impurity - 1e-12) {
|
|
395
410
|
const node_index = model.nodes.items.len;
|
|
396
411
|
try model.nodes.append(allocator, TreeNode{
|
|
@@ -404,6 +419,25 @@ fn buildDecisionTreeNode(
|
|
|
404
419
|
return node_index;
|
|
405
420
|
}
|
|
406
421
|
|
|
422
|
+
const partition = (try partitionIndicesForThreshold(
|
|
423
|
+
model,
|
|
424
|
+
workspace,
|
|
425
|
+
x_ptr,
|
|
426
|
+
indices,
|
|
427
|
+
best_feature,
|
|
428
|
+
split.threshold,
|
|
429
|
+
)) orelse {
|
|
430
|
+
const node_index = model.nodes.items.len;
|
|
431
|
+
try model.nodes.append(allocator, TreeNode{
|
|
432
|
+
.prediction = prediction,
|
|
433
|
+
.feature_index = 0,
|
|
434
|
+
.threshold = 0.0,
|
|
435
|
+
.left_index = 0,
|
|
436
|
+
.right_index = 0,
|
|
437
|
+
.is_leaf = true,
|
|
438
|
+
});
|
|
439
|
+
return node_index;
|
|
440
|
+
};
|
|
407
441
|
const node_index = model.nodes.items.len;
|
|
408
442
|
try model.nodes.append(allocator, TreeNode{
|
|
409
443
|
.prediction = prediction,
|
|
@@ -416,17 +450,19 @@ fn buildDecisionTreeNode(
|
|
|
416
450
|
|
|
417
451
|
const left_index = try buildDecisionTreeNode(
|
|
418
452
|
model,
|
|
453
|
+
workspace,
|
|
419
454
|
x_ptr,
|
|
420
455
|
y_ptr,
|
|
421
|
-
|
|
456
|
+
partition.left_indices,
|
|
422
457
|
depth + 1,
|
|
423
458
|
rng,
|
|
424
459
|
);
|
|
425
460
|
const right_index = try buildDecisionTreeNode(
|
|
426
461
|
model,
|
|
462
|
+
workspace,
|
|
427
463
|
x_ptr,
|
|
428
464
|
y_ptr,
|
|
429
|
-
|
|
465
|
+
partition.right_indices,
|
|
430
466
|
depth + 1,
|
|
431
467
|
rng,
|
|
432
468
|
);
|
|
@@ -1136,6 +1172,11 @@ pub export fn decision_tree_model_create(
|
|
|
1136
1172
|
|
|
1137
1173
|
const model = allocator.create(DecisionTreeModel) catch return 0;
|
|
1138
1174
|
errdefer allocator.destroy(model);
|
|
1175
|
+
const feature_scratch = allocator.alloc(usize, n_features) catch return 0;
|
|
1176
|
+
errdefer allocator.free(feature_scratch);
|
|
1177
|
+
for (feature_scratch, 0..) |*entry, idx| {
|
|
1178
|
+
entry.* = idx;
|
|
1179
|
+
}
|
|
1139
1180
|
model.* = .{
|
|
1140
1181
|
.n_features = n_features,
|
|
1141
1182
|
.max_depth = max_depth,
|
|
@@ -1147,6 +1188,7 @@ pub export fn decision_tree_model_create(
|
|
|
1147
1188
|
.use_random_state = use_random_state != 0,
|
|
1148
1189
|
.root_index = 0,
|
|
1149
1190
|
.has_root = false,
|
|
1191
|
+
.feature_scratch = feature_scratch,
|
|
1150
1192
|
.nodes = .empty,
|
|
1151
1193
|
};
|
|
1152
1194
|
return @intFromPtr(model);
|
|
@@ -1154,6 +1196,7 @@ pub export fn decision_tree_model_create(
|
|
|
1154
1196
|
|
|
1155
1197
|
/// Frees a decision tree created by `decision_tree_model_create`: its feature
/// scratch buffer, its node list, and the model itself. A zero handle is ignored.
pub export fn decision_tree_model_destroy(handle: usize) void {
    if (asDecisionTreeModel(handle)) |tree| {
        allocator.free(tree.feature_scratch);
        tree.nodes.deinit(allocator);
        allocator.destroy(tree);
    }
}
|
|
@@ -1180,8 +1223,11 @@ pub export fn decision_tree_model_fit(
|
|
|
1180
1223
|
return 0;
|
|
1181
1224
|
}
|
|
1182
1225
|
|
|
1183
|
-
|
|
1184
|
-
defer
|
|
1226
|
+
var arena = std.heap.ArenaAllocator.init(allocator);
|
|
1227
|
+
defer arena.deinit();
|
|
1228
|
+
const workspace = arena.allocator();
|
|
1229
|
+
|
|
1230
|
+
const root_indices = workspace.alloc(usize, root_size) catch return 0;
|
|
1185
1231
|
|
|
1186
1232
|
if (sample_count == 0) {
|
|
1187
1233
|
for (root_indices, 0..) |*entry, idx| {
|
|
@@ -1202,7 +1248,7 @@ pub export fn decision_tree_model_fit(
|
|
|
1202
1248
|
else
|
|
1203
1249
|
@as(u32, @truncate(@as(u64, @bitCast(std.time.microTimestamp()))));
|
|
1204
1250
|
var rng = Mulberry32.init(rng_seed);
|
|
1205
|
-
const root_index = buildDecisionTreeNode(model, x_ptr, y_ptr, root_indices, 0, &rng) catch {
|
|
1251
|
+
const root_index = buildDecisionTreeNode(model, workspace, x_ptr, y_ptr, root_indices, 0, &rng) catch {
|
|
1206
1252
|
model.nodes.clearRetainingCapacity();
|
|
1207
1253
|
model.has_root = false;
|
|
1208
1254
|
return 0;
|
|
@@ -1243,6 +1289,181 @@ pub export fn decision_tree_model_predict(
|
|
|
1243
1289
|
return 1;
|
|
1244
1290
|
}
|
|
1245
1291
|
|
|
1292
|
+
/// Destroys every tree the forest has fitted so far and marks the forest as unfitted.
/// Only the first `fitted_estimators` slots are visited; each destroyed slot is zeroed.
fn resetRandomForestClassifierModel(model: *RandomForestClassifierModel) void {
    for (model.tree_handles[0..model.fitted_estimators]) |*slot| {
        if (slot.* == 0) continue;
        decision_tree_model_destroy(slot.*);
        slot.* = 0;
    }
    model.fitted_estimators = 0;
}
|
|
1303
|
+
|
|
1304
|
+
/// Allocates an unfitted random forest classifier and returns it as an opaque handle.
///
/// Returns 0 when `n_features`, `max_depth` or `n_estimators` is zero, or when an
/// allocation fails. `min_samples_split` is raised to at least 2 and
/// `min_samples_leaf` to at least 1; `bootstrap` and `use_random_state` are treated
/// as booleans (non-zero means true). Release the handle with
/// `random_forest_classifier_model_destroy`.
pub export fn random_forest_classifier_model_create(
    n_estimators: usize,
    max_depth: usize,
    min_samples_split: usize,
    min_samples_leaf: usize,
    max_features_mode: u8,
    max_features_value: usize,
    bootstrap: u8,
    random_state: u32,
    use_random_state: u8,
    n_features: usize,
) usize {
    if (n_features == 0 or max_depth == 0 or n_estimators == 0) {
        return 0;
    }

    const model = allocator.create(RandomForestClassifierModel) catch return 0;
    // Failure is reported by returning 0, not by returning an error, so `errdefer`
    // would never run here. Free the model explicitly if the second allocation fails.
    const tree_handles = allocator.alloc(usize, n_estimators) catch {
        allocator.destroy(model);
        return 0;
    };
    // 0 marks an empty slot until fit stores a tree handle there.
    @memset(tree_handles, 0);

    model.* = .{
        .n_features = n_features,
        .n_estimators = n_estimators,
        .max_depth = max_depth,
        .min_samples_split = if (min_samples_split < 2) 2 else min_samples_split,
        .min_samples_leaf = if (min_samples_leaf < 1) 1 else min_samples_leaf,
        .max_features_mode = max_features_mode,
        .max_features_value = max_features_value,
        .bootstrap = bootstrap != 0,
        .random_state = random_state,
        .use_random_state = use_random_state != 0,
        .tree_handles = tree_handles,
        .fitted_estimators = 0,
    };
    return @intFromPtr(model);
}
|
|
1342
|
+
|
|
1343
|
+
/// Frees a forest created by `random_forest_classifier_model_create`: all fitted
/// trees, the handle array, and the model itself. A zero handle is ignored.
pub export fn random_forest_classifier_model_destroy(handle: usize) void {
    if (asRandomForestClassifierModel(handle)) |forest| {
        resetRandomForestClassifierModel(forest);
        allocator.free(forest.tree_handles);
        allocator.destroy(forest);
    }
}
|
|
1349
|
+
|
|
1350
|
+
/// Fits the forest by discarding any previously fitted trees and training
/// `n_estimators` new decision trees on the given data.
///
/// `x_ptr` and `y_ptr` are passed through unchanged to `decision_tree_model_fit`
/// together with `n_samples` and `n_features`. With `bootstrap` enabled, each tree
/// receives `n_samples` randomly drawn row indices; otherwise it receives the
/// identity index list.
///
/// Returns 1 on success. Returns 0 for a zero handle, for invalid dimensions, when
/// `n_samples` does not fit in a u32, or when creating or fitting any tree fails.
/// Once the arguments have been validated the forest is reset, so a later failure
/// leaves it unfitted.
pub export fn random_forest_classifier_model_fit(
    handle: usize,
    x_ptr: [*]const f64,
    y_ptr: [*]const u8,
    n_samples: usize,
    n_features: usize,
) u8 {
    const model = asRandomForestClassifierModel(handle) orelse return 0;
    if (n_samples == 0 or n_features == 0 or n_features != model.n_features) {
        return 0;
    }
    // Row indices are stored as u32 below. Reject inputs whose indices would be
    // silently truncated, which would train the trees on the wrong rows.
    if (n_samples > std.math.maxInt(u32)) {
        return 0;
    }

    resetRandomForestClassifierModel(model);

    const sample_indices = allocator.alloc(u32, n_samples) catch return 0;
    defer allocator.free(sample_indices);

    const rng_seed: u32 = if (model.use_random_state)
        model.random_state
    else
        @as(u32, @truncate(@as(u64, @bitCast(std.time.microTimestamp()))));
    var rng = Mulberry32.init(rng_seed);

    var estimator_index: usize = 0;
    while (estimator_index < model.n_estimators) : (estimator_index += 1) {
        // Derive a distinct seed per tree so the trees do not share one random stream.
        const tree_seed: u32 = if (model.use_random_state)
            model.random_state +% @as(u32, @truncate(estimator_index + 1))
        else
            rng.state +% @as(u32, @truncate(estimator_index + 1));
        const tree_handle = decision_tree_model_create(
            model.max_depth,
            model.min_samples_split,
            model.min_samples_leaf,
            model.max_features_mode,
            model.max_features_value,
            tree_seed,
            // Always ask the tree to use `tree_seed`. Passing 0 for an unseeded forest
            // would discard the derived seed and let each tree pick its own
            // timestamp-based seed.
            1,
            model.n_features,
        );
        if (tree_handle == 0) {
            resetRandomForestClassifierModel(model);
            return 0;
        }

        if (model.bootstrap) {
            var i: usize = 0;
            while (i < n_samples) : (i += 1) {
                sample_indices[i] = @as(u32, @truncate(rng.nextIndex(n_samples)));
            }
        } else {
            for (sample_indices, 0..) |*entry, idx| {
                entry.* = @as(u32, @truncate(idx));
            }
        }

        const fit_status = decision_tree_model_fit(
            tree_handle,
            x_ptr,
            y_ptr,
            n_samples,
            n_features,
            sample_indices.ptr,
            n_samples,
        );
        if (fit_status != 1) {
            // This tree is not yet stored in tree_handles, so the reset below would
            // not reach it; destroy it explicitly.
            decision_tree_model_destroy(tree_handle);
            resetRandomForestClassifierModel(model);
            return 0;
        }

        model.tree_handles[estimator_index] = tree_handle;
        model.fitted_estimators = estimator_index + 1;
    }

    return 1;
}
|
|
1426
|
+
|
|
1427
|
+
/// Predicts a binary label for each of `n_samples` rows by majority vote of the
/// fitted trees and writes the labels to `out_labels_ptr`.
///
/// Returns 0 for a zero handle, an unfitted forest, `n_samples == 0`, or a feature
/// count that differs from the model's; returns 1 otherwise. A row is labelled 1
/// when at least half of `fitted_estimators` vote 1, so ties resolve to 1. Trees
/// without a root cast no vote but still count towards that total.
pub export fn random_forest_classifier_model_predict(
    handle: usize,
    x_ptr: [*]const f64,
    n_samples: usize,
    n_features: usize,
    out_labels_ptr: [*]u8,
) u8 {
    const forest = asRandomForestClassifierModel(handle) orelse return 0;
    if (forest.fitted_estimators == 0) return 0;
    if (n_samples == 0) return 0;
    if (n_features != forest.n_features) return 0;

    for (0..n_samples) |sample| {
        const row_start = sample * forest.n_features;
        var votes_for_one: usize = 0;

        for (0..forest.fitted_estimators) |slot| {
            const tree = asDecisionTreeModel(forest.tree_handles[slot]) orelse continue;
            if (!tree.has_root) continue;

            // Walk from the root until a leaf is reached.
            var cursor = tree.root_index;
            while (true) {
                const node = tree.nodes.items[cursor];
                if (node.is_leaf) {
                    if (node.prediction == 1) {
                        votes_for_one += 1;
                    }
                    break;
                }
                const feature_value = x_ptr[row_start + node.feature_index];
                cursor = if (feature_value <= node.threshold) node.left_index else node.right_index;
            }
        }

        out_labels_ptr[sample] = if (votes_for_one * 2 >= forest.fitted_estimators) 1 else 0;
    }

    return 1;
}
|
|
1466
|
+
|
|
1246
1467
|
pub export fn logistic_train_epoch(
|
|
1247
1468
|
x_ptr: [*]const f64,
|
|
1248
1469
|
y_ptr: [*]const f64,
|