bun-scikit 0.1.3 → 0.1.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -88,6 +88,33 @@ type DecisionTreeModelPredictFn = (
88
88
  nFeatures: number,
89
89
  outLabels: Uint8Array,
90
90
  ) => number;
91
/**
 * Signature of the native constructor for a random-forest classifier.
 *
 * Every argument crosses the native boundary as a plain number; `bootstrap`
 * and `useRandomState` are numeric flags rather than booleans. The result is
 * an opaque handle that the sibling fit/predict/destroy functions accept.
 */
type RandomForestClassifierModelCreateFn = (
  nEstimators: number,
  maxDepth: number,
  minSamplesSplit: number,
  minSamplesLeaf: number,
  maxFeaturesMode: number,
  maxFeaturesValue: number,
  bootstrap: number,
  randomState: number,
  useRandomState: number,
  nFeatures: number,
) => NativeHandle;

/** Signature of the native function that releases a random-forest handle. */
type RandomForestClassifierModelDestroyFn = (handle: NativeHandle) => void;

/**
 * Signature of the native training entry point.
 *
 * Takes the feature values as one flat `Float64Array` and the labels as a
 * `Uint8Array`, together with the sample and feature counts.
 * NOTE(review): the meaning of the returned status is not visible here; the
 * decision-tree callers treat 1 as success — confirm the same holds for the forest.
 */
type RandomForestClassifierModelFitFn = (
  handle: NativeHandle,
  x: Float64Array,
  y: Uint8Array,
  nSamples: number,
  nFeatures: number,
) => number;

/**
 * Signature of the native prediction entry point.
 *
 * Predicted labels are written into the caller-supplied `outLabels` buffer;
 * the return value is a numeric status.
 */
type RandomForestClassifierModelPredictFn = (
  handle: NativeHandle,
  x: Float64Array,
  nSamples: number,
  nFeatures: number,
  outLabels: Uint8Array,
) => number;
91
118
 
92
119
  type LogisticTrainEpochFn = (
93
120
  x: Float64Array,
@@ -138,6 +165,10 @@ interface ZigKernelLibrary {
138
165
  decision_tree_model_destroy?: DecisionTreeModelDestroyFn;
139
166
  decision_tree_model_fit?: DecisionTreeModelFitFn;
140
167
  decision_tree_model_predict?: DecisionTreeModelPredictFn;
168
+ random_forest_classifier_model_create?: RandomForestClassifierModelCreateFn;
169
+ random_forest_classifier_model_destroy?: RandomForestClassifierModelDestroyFn;
170
+ random_forest_classifier_model_fit?: RandomForestClassifierModelFitFn;
171
+ random_forest_classifier_model_predict?: RandomForestClassifierModelPredictFn;
141
172
  logistic_train_epoch?: LogisticTrainEpochFn;
142
173
  logistic_train_epochs?: LogisticTrainEpochsFn;
143
174
  };
@@ -162,6 +193,10 @@ export interface ZigKernels {
162
193
  decisionTreeModelDestroy: DecisionTreeModelDestroyFn | null;
163
194
  decisionTreeModelFit: DecisionTreeModelFitFn | null;
164
195
  decisionTreeModelPredict: DecisionTreeModelPredictFn | null;
196
+ randomForestClassifierModelCreate: RandomForestClassifierModelCreateFn | null;
197
+ randomForestClassifierModelDestroy: RandomForestClassifierModelDestroyFn | null;
198
+ randomForestClassifierModelFit: RandomForestClassifierModelFitFn | null;
199
+ randomForestClassifierModelPredict: RandomForestClassifierModelPredictFn | null;
165
200
  logisticTrainEpoch: LogisticTrainEpochFn | null;
166
201
  logisticTrainEpochs: LogisticTrainEpochsFn | null;
167
202
  abiVersion: number | null;
@@ -243,6 +278,14 @@ interface NodeApiAddon {
243
278
  logisticModelFitLbfgs: LogisticModelFitLbfgsFn;
244
279
  logisticModelCopyCoefficients: LogisticModelCopyCoefficientsFn;
245
280
  logisticModelGetIntercept: LogisticModelGetInterceptFn;
281
+ decisionTreeModelCreate?: DecisionTreeModelCreateFn;
282
+ decisionTreeModelDestroy?: DecisionTreeModelDestroyFn;
283
+ decisionTreeModelFit?: DecisionTreeModelFitFn;
284
+ decisionTreeModelPredict?: DecisionTreeModelPredictFn;
285
+ randomForestClassifierModelCreate?: RandomForestClassifierModelCreateFn;
286
+ randomForestClassifierModelDestroy?: RandomForestClassifierModelDestroyFn;
287
+ randomForestClassifierModelFit?: RandomForestClassifierModelFitFn;
288
+ randomForestClassifierModelPredict?: RandomForestClassifierModelPredictFn;
246
289
  }
247
290
 
248
291
  function tryLoadNodeApiKernels(): ZigKernels | null {
@@ -281,10 +324,17 @@ function tryLoadNodeApiKernels(): ZigKernels | null {
281
324
  logisticModelPredict: null,
282
325
  logisticModelCopyCoefficients: addon.logisticModelCopyCoefficients ?? null,
283
326
  logisticModelGetIntercept: addon.logisticModelGetIntercept ?? null,
284
- decisionTreeModelCreate: null,
285
- decisionTreeModelDestroy: null,
286
- decisionTreeModelFit: null,
287
- decisionTreeModelPredict: null,
327
+ decisionTreeModelCreate: addon.decisionTreeModelCreate ?? null,
328
+ decisionTreeModelDestroy: addon.decisionTreeModelDestroy ?? null,
329
+ decisionTreeModelFit: addon.decisionTreeModelFit ?? null,
330
+ decisionTreeModelPredict: addon.decisionTreeModelPredict ?? null,
331
+ randomForestClassifierModelCreate:
332
+ addon.randomForestClassifierModelCreate ?? null,
333
+ randomForestClassifierModelDestroy:
334
+ addon.randomForestClassifierModelDestroy ?? null,
335
+ randomForestClassifierModelFit: addon.randomForestClassifierModelFit ?? null,
336
+ randomForestClassifierModelPredict:
337
+ addon.randomForestClassifierModelPredict ?? null,
288
338
  logisticTrainEpoch: null,
289
339
  logisticTrainEpochs: null,
290
340
  abiVersion,
@@ -403,6 +453,58 @@ export function getZigKernels(): ZigKernels | null {
403
453
  args: ["usize"],
404
454
  returns: FFIType.f64,
405
455
  },
456
+ decision_tree_model_create: {
457
+ args: [
458
+ "usize",
459
+ "usize",
460
+ "usize",
461
+ FFIType.u8,
462
+ "usize",
463
+ FFIType.u32,
464
+ FFIType.u8,
465
+ "usize",
466
+ ],
467
+ returns: "usize",
468
+ },
469
+ decision_tree_model_destroy: {
470
+ args: ["usize"],
471
+ returns: FFIType.void,
472
+ },
473
+ decision_tree_model_fit: {
474
+ args: ["usize", FFIType.ptr, FFIType.ptr, "usize", "usize", FFIType.ptr, "usize"],
475
+ returns: FFIType.u8,
476
+ },
477
+ decision_tree_model_predict: {
478
+ args: ["usize", FFIType.ptr, "usize", "usize", FFIType.ptr],
479
+ returns: FFIType.u8,
480
+ },
481
+ random_forest_classifier_model_create: {
482
+ args: [
483
+ "usize",
484
+ "usize",
485
+ "usize",
486
+ "usize",
487
+ FFIType.u8,
488
+ "usize",
489
+ FFIType.u8,
490
+ FFIType.u32,
491
+ FFIType.u8,
492
+ "usize",
493
+ ],
494
+ returns: "usize",
495
+ },
496
+ random_forest_classifier_model_destroy: {
497
+ args: ["usize"],
498
+ returns: FFIType.void,
499
+ },
500
+ random_forest_classifier_model_fit: {
501
+ args: ["usize", FFIType.ptr, FFIType.ptr, "usize", "usize"],
502
+ returns: FFIType.u8,
503
+ },
504
+ random_forest_classifier_model_predict: {
505
+ args: ["usize", FFIType.ptr, "usize", "usize", FFIType.ptr],
506
+ returns: FFIType.u8,
507
+ },
406
508
  logistic_train_epoch: {
407
509
  args: [
408
510
  FFIType.ptr,
@@ -463,6 +565,14 @@ export function getZigKernels(): ZigKernels | null {
463
565
  decisionTreeModelDestroy: library.symbols.decision_tree_model_destroy ?? null,
464
566
  decisionTreeModelFit: library.symbols.decision_tree_model_fit ?? null,
465
567
  decisionTreeModelPredict: library.symbols.decision_tree_model_predict ?? null,
568
+ randomForestClassifierModelCreate:
569
+ library.symbols.random_forest_classifier_model_create ?? null,
570
+ randomForestClassifierModelDestroy:
571
+ library.symbols.random_forest_classifier_model_destroy ?? null,
572
+ randomForestClassifierModelFit:
573
+ library.symbols.random_forest_classifier_model_fit ?? null,
574
+ randomForestClassifierModelPredict:
575
+ library.symbols.random_forest_classifier_model_predict ?? null,
466
576
  logisticTrainEpoch: library.symbols.logistic_train_epoch ?? null,
467
577
  logisticTrainEpochs: library.symbols.logistic_train_epochs ?? null,
468
578
  abiVersion,
@@ -526,6 +636,10 @@ export function getZigKernels(): ZigKernels | null {
526
636
  decisionTreeModelDestroy: null,
527
637
  decisionTreeModelFit: null,
528
638
  decisionTreeModelPredict: null,
639
+ randomForestClassifierModelCreate: null,
640
+ randomForestClassifierModelDestroy: null,
641
+ randomForestClassifierModelFit: null,
642
+ randomForestClassifierModelPredict: null,
529
643
  logisticTrainEpoch: library.symbols.logistic_train_epoch ?? null,
530
644
  logisticTrainEpochs: library.symbols.logistic_train_epochs ?? null,
531
645
  abiVersion: null,
@@ -571,6 +685,10 @@ export function getZigKernels(): ZigKernels | null {
571
685
  decisionTreeModelDestroy: null,
572
686
  decisionTreeModelFit: null,
573
687
  decisionTreeModelPredict: null,
688
+ randomForestClassifierModelCreate: null,
689
+ randomForestClassifierModelDestroy: null,
690
+ randomForestClassifierModelFit: null,
691
+ randomForestClassifierModelPredict: null,
574
692
  logisticTrainEpoch: library.symbols.logistic_train_epoch ?? null,
575
693
  logisticTrainEpochs: null,
576
694
  abiVersion: null,
@@ -0,0 +1,46 @@
1
+ import type { Matrix } from "../types";
2
+ import {
3
+ assertConsistentRowSize,
4
+ assertFiniteMatrix,
5
+ assertNonEmptyMatrix,
6
+ } from "../utils/validation";
7
+
8
/** Settings accepted by the {@link Binarizer} constructor. */
export interface BinarizerOptions {
  /** Cut-off value; entries strictly above it become 1. Defaults to 0. */
  threshold?: number;
}

/**
 * Converts every entry of a matrix to 0 or 1 by comparing it with a fixed
 * threshold. Fitting only records the column count so that later calls to
 * `transform` can reject matrices of a different width.
 */
export class Binarizer {
  /** Column count recorded by the most recent `fit`, or `null` before any fit. */
  nFeaturesIn_: number | null = null;
  private readonly threshold: number;

  /**
   * @param options Optional settings; see {@link BinarizerOptions}.
   * @throws Error when the resolved threshold is not a finite number.
   */
  constructor(options: BinarizerOptions = {}) {
    const cutoff = options.threshold ?? 0;
    this.threshold = cutoff;
    if (!Number.isFinite(cutoff)) {
      throw new Error(`threshold must be finite. Got ${cutoff}.`);
    }
  }

  /**
   * Validates `X` and remembers its column count.
   * @returns This instance, for chaining.
   */
  fit(X: Matrix): this {
    assertNonEmptyMatrix(X);
    assertConsistentRowSize(X);
    assertFiniteMatrix(X);
    const [firstRow] = X;
    this.nFeaturesIn_ = firstRow.length;
    return this;
  }

  /**
   * Produces a new matrix holding 1 where the input exceeds the threshold and
   * 0 elsewhere. May be called without a prior `fit`; the width check only
   * applies once a column count has been recorded.
   * @throws Error when a fitted column count exists and `X` does not match it.
   */
  transform(X: Matrix): Matrix {
    assertNonEmptyMatrix(X);
    assertConsistentRowSize(X);
    assertFiniteMatrix(X);

    const expectedWidth = this.nFeaturesIn_;
    if (expectedWidth !== null) {
      if (X[0].length !== expectedWidth) {
        throw new Error(`Feature size mismatch. Expected ${expectedWidth}, got ${X[0].length}.`);
      }
    }

    const binarizeRow = (row: number[]): number[] =>
      row.map((entry) => {
        if (entry > this.threshold) {
          return 1;
        }
        return 0;
      });
    return X.map(binarizeRow);
  }

  /** Runs `fit` followed by `transform` on the same matrix. */
  fitTransform(X: Matrix): Matrix {
    this.fit(X);
    return this.transform(X);
  }
}
@@ -0,0 +1,62 @@
1
+ import type { Vector } from "../types";
2
+ import { assertFiniteVector } from "../utils/validation";
3
+
4
/**
 * Maps numeric class labels onto the contiguous integers `0..k-1`, ordered by
 * ascending label value, and back again.
 */
export class LabelEncoder {
  /** Distinct labels seen during `fit`, sorted ascending; `null` before fitting. */
  classes_: number[] | null = null;
  /** Reverse lookup from a label to its position in `classes_`. */
  private classToIndex: Map<number, number> | null = null;

  /**
   * Learns the set of distinct labels in `y`.
   * @throws Error when `y` is not an array or is empty.
   * @returns This instance, for chaining.
   */
  fit(y: Vector): this {
    const usable = Array.isArray(y) && y.length !== 0;
    if (!usable) {
      throw new Error("y must be a non-empty array.");
    }
    assertFiniteVector(y);

    const ordered = [...new Set(y)].sort((left, right) => left - right);
    const lookup = new Map<number, number>();
    ordered.forEach((label, position) => {
      lookup.set(label, position);
    });

    this.classes_ = ordered;
    this.classToIndex = lookup;
    return this;
  }

  /**
   * Replaces every label in `y` with its learned index.
   * @throws Error when the encoder is unfitted or a label was not seen in `fit`.
   */
  transform(y: Vector): Vector {
    const lookup = this.classToIndex;
    if (!lookup) {
      throw new Error("LabelEncoder has not been fitted.");
    }
    assertFiniteVector(y);

    return Array.from({ length: y.length }, (_unused, position) => {
      const code = lookup.get(y[position]);
      if (code === undefined) {
        throw new Error(`Unknown label ${y[position]} at index ${position}.`);
      }
      return code;
    });
  }

  /** Runs `fit` followed by `transform` on the same labels. */
  fitTransform(y: Vector): Vector {
    this.fit(y);
    return this.transform(y);
  }

  /**
   * Maps encoded indices back to the original labels.
   * @throws Error when the encoder is unfitted or an entry is not an integer
   *   within `[0, classes_.length)`.
   */
  inverseTransform(y: Vector): Vector {
    const labels = this.classes_;
    if (!labels) {
      throw new Error("LabelEncoder has not been fitted.");
    }
    assertFiniteVector(y);

    return Array.from({ length: y.length }, (_unused, position) => {
      const code = y[position];
      const inRange = Number.isInteger(code) && code >= 0 && code < labels.length;
      if (!inRange) {
        throw new Error(`Encoded label out of range at index ${position}: ${code}.`);
      }
      return labels[code];
    });
  }
}
@@ -0,0 +1,77 @@
1
+ import type { Matrix, Vector } from "../types";
2
+ import {
3
+ assertConsistentRowSize,
4
+ assertFiniteMatrix,
5
+ assertNonEmptyMatrix,
6
+ } from "../utils/validation";
7
+
8
/**
 * Scales each column by the largest absolute value observed for it during
 * `fit`. Columns whose largest absolute value is 0 use a divisor of 1, so
 * they pass through unchanged.
 */
export class MaxAbsScaler {
  /** Per-column divisor learned by `fit`; `null` until the scaler is fitted. */
  maxAbs_: Vector | null = null;

  /**
   * Learns the per-column maximum absolute value of `X`.
   * @returns This instance, for chaining.
   */
  fit(X: Matrix): this {
    assertNonEmptyMatrix(X);
    assertConsistentRowSize(X);
    assertFiniteMatrix(X);

    const columnCount = X[0].length;
    const peaks = new Array<number>(columnCount).fill(0);

    for (const row of X) {
      for (let column = 0; column < columnCount; column += 1) {
        const magnitude = Math.abs(row[column]);
        if (magnitude > peaks[column]) {
          peaks[column] = magnitude;
        }
      }
    }

    // An all-zero column would divide by zero; use 1 as a neutral divisor.
    this.maxAbs_ = peaks.map((peak) => (peak === 0 ? 1 : peak));
    return this;
  }

  /**
   * Divides every entry by its column's learned divisor.
   * @throws Error when the scaler is unfitted or `X` has the wrong width.
   */
  transform(X: Matrix): Matrix {
    const divisors = this.maxAbs_;
    if (!divisors) {
      throw new Error("MaxAbsScaler has not been fitted.");
    }

    assertNonEmptyMatrix(X);
    assertConsistentRowSize(X);
    assertFiniteMatrix(X);

    const width = X[0].length;
    if (width !== divisors.length) {
      throw new Error(
        `Feature size mismatch. Expected ${divisors.length}, got ${width}.`,
      );
    }

    return X.map((row) => row.map((entry, column) => entry / divisors[column]));
  }

  /** Runs `fit` followed by `transform` on the same matrix. */
  fitTransform(X: Matrix): Matrix {
    this.fit(X);
    return this.transform(X);
  }

  /**
   * Undoes `transform` by multiplying every entry by its column's divisor.
   * @throws Error when the scaler is unfitted or `X` has the wrong width.
   */
  inverseTransform(X: Matrix): Matrix {
    const factors = this.maxAbs_;
    if (!factors) {
      throw new Error("MaxAbsScaler has not been fitted.");
    }

    assertNonEmptyMatrix(X);
    assertConsistentRowSize(X);
    assertFiniteMatrix(X);

    const width = X[0].length;
    if (width !== factors.length) {
      throw new Error(
        `Feature size mismatch. Expected ${factors.length}, got ${width}.`,
      );
    }

    return X.map((row) => row.map((entry, column) => entry * factors[column]));
  }
}
@@ -0,0 +1,66 @@
1
+ import type { Matrix } from "../types";
2
+ import {
3
+ assertConsistentRowSize,
4
+ assertFiniteMatrix,
5
+ assertNonEmptyMatrix,
6
+ } from "../utils/validation";
7
+
8
/** Settings accepted by the {@link Normalizer} constructor. */
export interface NormalizerOptions {
  /** Row norm used as the divisor. Defaults to `"l2"`. */
  norm?: "l1" | "l2" | "max";
}

/**
 * Rescales each row of a matrix independently so that the chosen norm of the
 * row equals 1. Rows whose norm is 0 are copied through unchanged.
 */
export class Normalizer {
  private readonly norm: "l1" | "l2" | "max";
  /** Column count recorded by the most recent `fit`, or `null` before any fit. */
  private nFeatures_: number | null = null;

  /**
   * @param options Optional settings; see {@link NormalizerOptions}.
   * @throws Error when `norm` is not one of `"l1"`, `"l2"` or `"max"`.
   */
  constructor(options: NormalizerOptions = {}) {
    const norm = options.norm ?? "l2";
    // The type annotation is erased at runtime, so untyped callers can pass
    // anything; without this check an unknown value silently behaved like "max".
    if (norm !== "l1" && norm !== "l2" && norm !== "max") {
      throw new Error(`norm must be one of "l1", "l2", "max". Got ${String(norm)}.`);
    }
    this.norm = norm;
  }

  /**
   * Validates `X` and remembers its column count.
   * @returns This instance, for chaining.
   */
  fit(X: Matrix): this {
    assertNonEmptyMatrix(X);
    assertConsistentRowSize(X);
    assertFiniteMatrix(X);
    this.nFeatures_ = X[0].length;
    return this;
  }

  /**
   * Returns a new matrix whose rows are divided by their own norm. May be
   * called without a prior `fit`; the width check only applies once a column
   * count has been recorded.
   * @throws Error when a fitted column count exists and `X` does not match it.
   */
  transform(X: Matrix): Matrix {
    assertNonEmptyMatrix(X);
    assertConsistentRowSize(X);
    assertFiniteMatrix(X);
    if (this.nFeatures_ !== null && X[0].length !== this.nFeatures_) {
      throw new Error(`Feature size mismatch. Expected ${this.nFeatures_}, got ${X[0].length}.`);
    }

    return X.map((row) => {
      const scale = this.rowNorm(row);
      // A zero norm means an all-zero row; dividing would yield NaN.
      if (scale === 0) {
        return [...row];
      }
      return row.map((value) => value / scale);
    });
  }

  /** Runs `fit` followed by `transform` on the same matrix. */
  fitTransform(X: Matrix): Matrix {
    return this.fit(X).transform(X);
  }

  /** Computes the configured norm (sum of |x|, Euclidean length, or max |x|) of one row. */
  private rowNorm(row: number[]): number {
    let scale = 0;
    switch (this.norm) {
      case "l1":
        for (let i = 0; i < row.length; i += 1) {
          scale += Math.abs(row[i]);
        }
        return scale;
      case "l2":
        for (let i = 0; i < row.length; i += 1) {
          scale += row[i] * row[i];
        }
        return Math.sqrt(scale);
      case "max":
        for (let i = 0; i < row.length; i += 1) {
          const abs = Math.abs(row[i]);
          if (abs > scale) {
            scale = abs;
          }
        }
        return scale;
    }
  }
}
@@ -6,6 +6,7 @@ import {
6
6
  validateClassificationInputs,
7
7
  } from "../utils/validation";
8
8
  import { accuracyScore } from "../metrics/classification";
9
+ import { getZigKernels } from "../native/zigKernels";
9
10
 
10
11
  export type MaxFeaturesOption = "sqrt" | "log2" | number | null;
11
12
 
@@ -38,6 +39,11 @@ interface SplitPartition {
38
39
 
39
40
  const MAX_THRESHOLD_BINS = 128;
40
41
 
42
/**
 * Reports whether the native tree backend was requested through the
 * `BUN_SCIKIT_TREE_BACKEND` environment variable. The value is compared
 * case-insensitively after trimming; `zig` and `native` both enable it.
 */
function isZigTreeBackendEnabled(): boolean {
  const requested = process.env.BUN_SCIKIT_TREE_BACKEND;
  if (requested == null) {
    return false;
  }
  const normalized = requested.trim().toLowerCase();
  return ["zig", "native"].includes(normalized);
}
46
+
41
47
  function mulberry32(seed: number): () => number {
42
48
  let state = seed >>> 0;
43
49
  return () => {
@@ -59,6 +65,8 @@ function giniImpurity(positiveCount: number, sampleCount: number): number {
59
65
 
60
66
  export class DecisionTreeClassifier implements ClassificationModel {
61
67
  classes_: Vector = [0, 1];
68
+ fitBackend_: "zig" | "js" = "js";
69
+ fitBackendLibrary_: string | null = null;
62
70
  private readonly maxDepth: number;
63
71
  private readonly minSamplesSplit: number;
64
72
  private readonly minSamplesLeaf: number;
@@ -73,6 +81,7 @@ export class DecisionTreeClassifier implements ClassificationModel {
73
81
  private featureSelectionMarks: Uint8Array | null = null;
74
82
  private binTotals: Uint32Array = new Uint32Array(MAX_THRESHOLD_BINS);
75
83
  private binPositives: Uint32Array = new Uint32Array(MAX_THRESHOLD_BINS);
84
+ private zigModelHandle: bigint | null = null;
76
85
 
77
86
  constructor(options: DecisionTreeClassifierOptions = {}) {
78
87
  this.maxDepth = options.maxDepth ?? 12;
@@ -90,6 +99,8 @@ export class DecisionTreeClassifier implements ClassificationModel {
90
99
  flattenedXTrain?: Float64Array,
91
100
  yBinaryTrain?: Uint8Array,
92
101
  ): this {
102
+ this.destroyZigModel();
103
+
93
104
  if (!skipValidation) {
94
105
  validateClassificationInputs(X, y);
95
106
  }
@@ -103,18 +114,28 @@ export class DecisionTreeClassifier implements ClassificationModel {
103
114
  this.featureSelectionMarks = new Uint8Array(this.featureCount);
104
115
  this.random = this.randomState === undefined ? Math.random : mulberry32(this.randomState);
105
116
 
106
- let rootIndices: number[];
117
+ let validatedSampleIndices: Uint32Array | null = null;
107
118
  if (sampleIndices) {
108
119
  if (sampleIndices.length === 0) {
109
120
  throw new Error("sampleIndices must not be empty.");
110
121
  }
122
+ validatedSampleIndices = new Uint32Array(sampleIndices.length);
111
123
  for (let i = 0; i < sampleIndices.length; i += 1) {
112
124
  const index = sampleIndices[i];
113
125
  if (!Number.isInteger(index) || index < 0 || index >= X.length) {
114
126
  throw new Error(`sampleIndices contains invalid index: ${index}.`);
115
127
  }
128
+ validatedSampleIndices[i] = index;
116
129
  }
117
- rootIndices = Array.from(sampleIndices);
130
+ }
131
+
132
+ if (isZigTreeBackendEnabled() && this.tryFitWithZig(X.length, validatedSampleIndices)) {
133
+ return this;
134
+ }
135
+
136
+ let rootIndices: number[];
137
+ if (validatedSampleIndices) {
138
+ rootIndices = Array.from(validatedSampleIndices);
118
139
  } else {
119
140
  rootIndices = new Array<number>(X.length);
120
141
  for (let idx = 0; idx < X.length; idx += 1) {
@@ -123,11 +144,13 @@ export class DecisionTreeClassifier implements ClassificationModel {
123
144
  }
124
145
 
125
146
  this.root = this.buildTree(rootIndices, 0);
147
+ this.fitBackend_ = "js";
148
+ this.fitBackendLibrary_ = null;
126
149
  return this;
127
150
  }
128
151
 
129
152
  predict(X: Matrix): Vector {
130
- if (!this.root || this.featureCount === 0) {
153
+ if ((this.root === null && this.zigModelHandle === null) || this.featureCount === 0) {
131
154
  throw new Error("DecisionTreeClassifier has not been fitted.");
132
155
  }
133
156
 
@@ -140,7 +163,34 @@ export class DecisionTreeClassifier implements ClassificationModel {
140
163
  );
141
164
  }
142
165
 
143
- return X.map((sample) => this.predictOne(sample, this.root!));
166
+ if (this.zigModelHandle !== null) {
167
+ const kernels = getZigKernels();
168
+ const nativePredict = kernels?.decisionTreeModelPredict;
169
+ if (nativePredict) {
170
+ const flattenedX = this.flattenTrainingMatrix(X);
171
+ const outLabels = new Uint8Array(X.length);
172
+ const status = nativePredict(
173
+ this.zigModelHandle,
174
+ flattenedX,
175
+ X.length,
176
+ this.featureCount,
177
+ outLabels,
178
+ );
179
+ if (status === 1) {
180
+ return Array.from(outLabels);
181
+ }
182
+ }
183
+ if (!this.root) {
184
+ throw new Error("Native DecisionTree predict failed and no JS fallback tree is available.");
185
+ }
186
+ }
187
+
188
+ const predictions = new Array<number>(X.length);
189
+ const root = this.root!;
190
+ for (let i = 0; i < X.length; i += 1) {
191
+ predictions[i] = this.predictOne(X[i], root);
192
+ }
193
+ return predictions;
144
194
  }
145
195
 
146
196
  score(X: Matrix, y: Vector): number {
@@ -148,6 +198,13 @@ export class DecisionTreeClassifier implements ClassificationModel {
148
198
  return accuracyScore(y, this.predict(X));
149
199
  }
150
200
 
201
+ dispose(): void {
202
+ this.destroyZigModel();
203
+ this.root = null;
204
+ this.flattenedXTrain = null;
205
+ this.yBinaryTrain = null;
206
+ }
207
+
151
208
  private predictOne(sample: Vector, node: TreeNode): 0 | 1 {
152
209
  let current: TreeNode = node;
153
210
  while (
@@ -228,9 +285,107 @@ export class DecisionTreeClassifier implements ClassificationModel {
228
285
  if (this.maxFeatures === "log2") {
229
286
  return Math.max(1, Math.floor(Math.log2(featureCount)));
230
287
  }
288
+ if (!Number.isFinite(this.maxFeatures)) {
289
+ return featureCount;
290
+ }
231
291
  return Math.max(1, Math.min(featureCount, Math.floor(this.maxFeatures)));
232
292
  }
233
293
 
294
+ private resolveNativeMaxFeatures(featureCount: number): {
295
+ mode: 0 | 1 | 2 | 3;
296
+ value: number;
297
+ } {
298
+ if (this.maxFeatures === null || this.maxFeatures === undefined) {
299
+ return { mode: 0, value: 0 };
300
+ }
301
+ if (this.maxFeatures === "sqrt") {
302
+ return { mode: 1, value: 0 };
303
+ }
304
+ if (this.maxFeatures === "log2") {
305
+ return { mode: 2, value: 0 };
306
+ }
307
+ const value = Number.isFinite(this.maxFeatures)
308
+ ? Math.max(1, Math.min(featureCount, Math.floor(this.maxFeatures)))
309
+ : featureCount;
310
+ return { mode: 3, value };
311
+ }
312
+
313
+ private tryFitWithZig(
314
+ sampleCount: number,
315
+ sampleIndices: Uint32Array | null,
316
+ ): boolean {
317
+ const kernels = getZigKernels();
318
+ const create = kernels?.decisionTreeModelCreate;
319
+ const fit = kernels?.decisionTreeModelFit;
320
+ const destroy = kernels?.decisionTreeModelDestroy;
321
+ if (!create || !fit || !destroy) {
322
+ return false;
323
+ }
324
+
325
+ const { mode, value } = this.resolveNativeMaxFeatures(this.featureCount);
326
+ const useRandomState = this.randomState === undefined ? 0 : 1;
327
+ const randomState = this.randomState ?? 0;
328
+ const handle = create(
329
+ this.maxDepth,
330
+ this.minSamplesSplit,
331
+ this.minSamplesLeaf,
332
+ mode,
333
+ value,
334
+ randomState >>> 0,
335
+ useRandomState,
336
+ this.featureCount,
337
+ );
338
+ if (handle === 0n) {
339
+ return false;
340
+ }
341
+
342
+ let shouldDestroy = true;
343
+ try {
344
+ const emptySampleIndices = new Uint32Array(0);
345
+ const status = fit(
346
+ handle,
347
+ this.flattenedXTrain!,
348
+ this.yBinaryTrain!,
349
+ sampleCount,
350
+ this.featureCount,
351
+ sampleIndices ?? emptySampleIndices,
352
+ sampleIndices?.length ?? 0,
353
+ );
354
+ if (status !== 1) {
355
+ return false;
356
+ }
357
+
358
+ this.zigModelHandle = handle;
359
+ this.root = null;
360
+ this.fitBackend_ = "zig";
361
+ this.fitBackendLibrary_ = kernels.libraryPath;
362
+ shouldDestroy = false;
363
+ return true;
364
+ } catch {
365
+ return false;
366
+ } finally {
367
+ if (shouldDestroy) {
368
+ destroy(handle);
369
+ }
370
+ }
371
+ }
372
+
373
+ private destroyZigModel(): void {
374
+ if (this.zigModelHandle === null) {
375
+ return;
376
+ }
377
+ const kernels = getZigKernels();
378
+ const destroy = kernels?.decisionTreeModelDestroy;
379
+ if (destroy) {
380
+ try {
381
+ destroy(this.zigModelHandle);
382
+ } catch {
383
+ // no-op: cleanup best effort
384
+ }
385
+ }
386
+ this.zigModelHandle = null;
387
+ }
388
+
234
389
  private selectFeatureIndices(featureCount: number): number[] {
235
390
  const k = this.resolveMaxFeatures(featureCount);
236
391
  if (k >= featureCount) {