bun-scikit 0.1.3 → 0.1.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,6 +6,7 @@ import {
6
6
  validateClassificationInputs,
7
7
  } from "../utils/validation";
8
8
  import { accuracyScore } from "../metrics/classification";
9
+ import { getZigKernels } from "../native/zigKernels";
9
10
 
10
11
  export type MaxFeaturesOption = "sqrt" | "log2" | number | null;
11
12
 
@@ -38,6 +39,11 @@ interface SplitPartition {
38
39
 
39
40
  /**
   * Maximum number of threshold bins; sizes the per-classifier `binTotals`
   * and `binPositives` scratch buffers.
   */
  const MAX_THRESHOLD_BINS = 128;
40
41
 
42
/**
 * Reports whether the native tree backend was requested through the
 * `BUN_SCIKIT_TREE_BACKEND` environment variable.
 *
 * The value is trimmed and compared case-insensitively; `"zig"` and
 * `"native"` both select the Zig kernels. An unset variable or any other
 * value keeps the JavaScript implementation.
 *
 * @returns `true` when the Zig backend should be attempted
 */
function isZigTreeBackendEnabled(): boolean {
  const requested = process.env.BUN_SCIKIT_TREE_BACKEND;
  if (requested === undefined) {
    return false;
  }
  switch (requested.trim().toLowerCase()) {
    case "zig":
    case "native":
      return true;
    default:
      return false;
  }
}
46
+
41
47
  function mulberry32(seed: number): () => number {
42
48
  let state = seed >>> 0;
43
49
  return () => {
@@ -59,6 +65,8 @@ function giniImpurity(positiveCount: number, sampleCount: number): number {
59
65
 
60
66
  export class DecisionTreeClassifier implements ClassificationModel {
61
67
  classes_: Vector = [0, 1];
68
+ fitBackend_: "zig" | "js" = "js";
69
+ fitBackendLibrary_: string | null = null;
62
70
  private readonly maxDepth: number;
63
71
  private readonly minSamplesSplit: number;
64
72
  private readonly minSamplesLeaf: number;
@@ -73,6 +81,7 @@ export class DecisionTreeClassifier implements ClassificationModel {
73
81
  private featureSelectionMarks: Uint8Array | null = null;
74
82
  private binTotals: Uint32Array = new Uint32Array(MAX_THRESHOLD_BINS);
75
83
  private binPositives: Uint32Array = new Uint32Array(MAX_THRESHOLD_BINS);
84
+ private zigModelHandle: bigint | null = null;
76
85
 
77
86
  constructor(options: DecisionTreeClassifierOptions = {}) {
78
87
  this.maxDepth = options.maxDepth ?? 12;
@@ -90,6 +99,8 @@ export class DecisionTreeClassifier implements ClassificationModel {
90
99
  flattenedXTrain?: Float64Array,
91
100
  yBinaryTrain?: Uint8Array,
92
101
  ): this {
102
+ this.destroyZigModel();
103
+
93
104
  if (!skipValidation) {
94
105
  validateClassificationInputs(X, y);
95
106
  }
@@ -103,18 +114,28 @@ export class DecisionTreeClassifier implements ClassificationModel {
103
114
  this.featureSelectionMarks = new Uint8Array(this.featureCount);
104
115
  this.random = this.randomState === undefined ? Math.random : mulberry32(this.randomState);
105
116
 
106
- let rootIndices: number[];
117
+ let validatedSampleIndices: Uint32Array | null = null;
107
118
  if (sampleIndices) {
108
119
  if (sampleIndices.length === 0) {
109
120
  throw new Error("sampleIndices must not be empty.");
110
121
  }
122
+ validatedSampleIndices = new Uint32Array(sampleIndices.length);
111
123
  for (let i = 0; i < sampleIndices.length; i += 1) {
112
124
  const index = sampleIndices[i];
113
125
  if (!Number.isInteger(index) || index < 0 || index >= X.length) {
114
126
  throw new Error(`sampleIndices contains invalid index: ${index}.`);
115
127
  }
128
+ validatedSampleIndices[i] = index;
116
129
  }
117
- rootIndices = Array.from(sampleIndices);
130
+ }
131
+
132
+ if (isZigTreeBackendEnabled() && this.tryFitWithZig(X.length, validatedSampleIndices)) {
133
+ return this;
134
+ }
135
+
136
+ let rootIndices: number[];
137
+ if (validatedSampleIndices) {
138
+ rootIndices = Array.from(validatedSampleIndices);
118
139
  } else {
119
140
  rootIndices = new Array<number>(X.length);
120
141
  for (let idx = 0; idx < X.length; idx += 1) {
@@ -123,11 +144,13 @@ export class DecisionTreeClassifier implements ClassificationModel {
123
144
  }
124
145
 
125
146
  this.root = this.buildTree(rootIndices, 0);
147
+ this.fitBackend_ = "js";
148
+ this.fitBackendLibrary_ = null;
126
149
  return this;
127
150
  }
128
151
 
129
152
  predict(X: Matrix): Vector {
130
- if (!this.root || this.featureCount === 0) {
153
+ if ((this.root === null && this.zigModelHandle === null) || this.featureCount === 0) {
131
154
  throw new Error("DecisionTreeClassifier has not been fitted.");
132
155
  }
133
156
 
@@ -140,6 +163,28 @@ export class DecisionTreeClassifier implements ClassificationModel {
140
163
  );
141
164
  }
142
165
 
166
+ if (this.zigModelHandle !== null) {
167
+ const kernels = getZigKernels();
168
+ const nativePredict = kernels?.decisionTreeModelPredict;
169
+ if (nativePredict) {
170
+ const flattenedX = this.flattenTrainingMatrix(X);
171
+ const outLabels = new Uint8Array(X.length);
172
+ const status = nativePredict(
173
+ this.zigModelHandle,
174
+ flattenedX,
175
+ X.length,
176
+ this.featureCount,
177
+ outLabels,
178
+ );
179
+ if (status === 1) {
180
+ return Array.from(outLabels);
181
+ }
182
+ }
183
+ if (!this.root) {
184
+ throw new Error("Native DecisionTree predict failed and no JS fallback tree is available.");
185
+ }
186
+ }
187
+
143
188
  return X.map((sample) => this.predictOne(sample, this.root!));
144
189
  }
145
190
 
@@ -228,9 +273,107 @@ export class DecisionTreeClassifier implements ClassificationModel {
228
273
  if (this.maxFeatures === "log2") {
229
274
  return Math.max(1, Math.floor(Math.log2(featureCount)));
230
275
  }
276
+ if (!Number.isFinite(this.maxFeatures)) {
277
+ return featureCount;
278
+ }
231
279
  return Math.max(1, Math.min(featureCount, Math.floor(this.maxFeatures)));
232
280
  }
233
281
 
282
/**
 * Encodes the `maxFeatures` option as the `(mode, value)` pair that is passed
 * to the native `decisionTreeModelCreate` kernel.
 *
 * - mode 0: `maxFeatures` is `null`/`undefined`; `value` is 0
 * - mode 1: `"sqrt"`; `value` is 0
 * - mode 2: `"log2"`; `value` is 0
 * - mode 3: any other value. A finite number is floored and clamped to
 *   `[1, featureCount]`; anything non-finite yields `featureCount`.
 *
 * @param featureCount - number of columns in the training matrix
 * @returns the native mode code and its accompanying count
 */
private resolveNativeMaxFeatures(featureCount: number): {
  mode: 0 | 1 | 2 | 3;
  value: number;
} {
  const setting = this.maxFeatures;
  if (setting == null) {
    return { mode: 0, value: 0 };
  }
  if (setting === "sqrt" || setting === "log2") {
    return { mode: setting === "sqrt" ? 1 : 2, value: 0 };
  }
  if (!Number.isFinite(setting)) {
    return { mode: 3, value: featureCount };
  }
  const floored = Math.floor(setting);
  return { mode: 3, value: Math.max(1, Math.min(featureCount, floored)) };
}
300
+
301
/**
 * Attempts to fit this tree with the native Zig kernels.
 *
 * Returns `true` only when a native model was created and fitted. In that
 * case the native handle is adopted as `zigModelHandle`, the JS tree is
 * cleared, and `fitBackend_`/`fitBackendLibrary_` are updated. Every other
 * outcome returns `false` and leaves no native handle allocated, so the
 * caller can fall back to the JS tree builder. Those outcomes are:
 * - a kernel is unavailable
 * - the training buffers are missing or too short
 * - the kernel returns a zero handle or a non-1 status
 * - the kernel throws
 *
 * @param sampleCount - number of rows described by the flattened training buffers
 * @param sampleIndices - validated row subset, or `null` to use every row
 * @returns whether the native backend produced a model
 */
private tryFitWithZig(
  sampleCount: number,
  sampleIndices: Uint32Array | null,
): boolean {
  const kernels = getZigKernels();
  if (!kernels) {
    return false;
  }
  const create = kernels.decisionTreeModelCreate;
  const fit = kernels.decisionTreeModelFit;
  const destroy = kernels.decisionTreeModelDestroy;
  if (!create || !fit || !destroy) {
    return false;
  }

  // The native fit reads raw memory from these buffers. Refuse to call it when
  // they are absent or shorter than the declared shape, rather than letting
  // native code read out of bounds.
  const flattenedX = this.flattenedXTrain;
  const yBinary = this.yBinaryTrain;
  if (
    !flattenedX ||
    !yBinary ||
    flattenedX.length < sampleCount * this.featureCount ||
    yBinary.length < sampleCount
  ) {
    return false;
  }

  const { mode, value } = this.resolveNativeMaxFeatures(this.featureCount);
  const useRandomState = this.randomState === undefined ? 0 : 1;
  const randomState = this.randomState ?? 0;

  let handle = 0n;
  let adopted = false;
  try {
    // Creation is inside the try so an FFI exception falls back to JS instead
    // of escaping from fit().
    handle = create(
      this.maxDepth,
      this.minSamplesSplit,
      this.minSamplesLeaf,
      mode,
      value,
      randomState >>> 0,
      useRandomState,
      this.featureCount,
    );
    if (handle === 0n) {
      return false;
    }

    const status = fit(
      handle,
      flattenedX,
      yBinary,
      sampleCount,
      this.featureCount,
      sampleIndices ?? new Uint32Array(0),
      sampleIndices?.length ?? 0,
    );
    if (status !== 1) {
      return false;
    }

    this.zigModelHandle = handle;
    this.root = null;
    this.fitBackend_ = "zig";
    this.fitBackendLibrary_ = kernels.libraryPath;
    adopted = true;
    return true;
  } catch {
    // Native failures are non-fatal by design: the caller builds the JS tree.
    return false;
  } finally {
    if (!adopted && handle !== 0n) {
      // Best-effort release. A throw here would otherwise override the `false`
      // return above and abort fit() instead of falling back.
      try {
        destroy(handle);
      } catch {
        // ignore: cleanup is best effort, mirroring destroyZigModel()
      }
    }
  }
}
360
+
361
/**
 * Releases the native model handle, if one is held, and clears it.
 *
 * Destruction is best effort. If the destroy kernel is unavailable or throws,
 * the error is ignored and the handle field is still reset to `null`.
 */
private destroyZigModel(): void {
  const handle = this.zigModelHandle;
  if (handle === null) {
    return;
  }
  const release = getZigKernels()?.decisionTreeModelDestroy;
  if (release) {
    try {
      release(handle);
    } catch {
      // swallow: cleanup must never make fit() fail
    }
  }
  this.zigModelHandle = null;
}
376
+
234
377
  private selectFeatureIndices(featureCount: number): number[] {
235
378
  const k = this.resolveMaxFeatures(featureCount);
236
379
  if (k >= featureCount) {
package/zig/kernels.zig CHANGED
@@ -84,6 +84,8 @@ const SplitResult = struct {
84
84
  right_indices: []usize,
85
85
  };
86
86
 
87
/// Capacity of the fixed-size per-feature histogram arrays used while
/// searching split thresholds; the dynamic bin count is clamped to this value.
const MAX_THRESHOLD_BINS = @as(usize, 128);
88
+
87
89
  const Mulberry32 = struct {
88
90
  state: u32,
89
91
 
@@ -213,55 +215,55 @@ fn findBestSplitForFeature(
213
215
  if (sample_count < 2) {
214
216
  return null;
215
217
  }
218
+ var min_value = std.math.inf(f64);
219
+ var max_value = -std.math.inf(f64);
220
+ var total_positive: usize = 0;
221
+ for (indices) |sample_index| {
222
+ const value = x_ptr[sample_index * model.n_features + feature_index];
223
+ if (value < min_value) {
224
+ min_value = value;
225
+ }
226
+ if (value > max_value) {
227
+ max_value = value;
228
+ }
229
+ total_positive += y_ptr[sample_index];
230
+ }
216
231
 
217
- const sorted_indices = try allocator.alloc(usize, sample_count);
218
- defer allocator.free(sorted_indices);
219
- @memcpy(sorted_indices, indices);
232
+ if (!std.math.isFinite(min_value) or !std.math.isFinite(max_value) or min_value == max_value) {
233
+ return null;
234
+ }
220
235
 
221
- const SortContext = struct {
222
- x_ptr: [*]const f64,
223
- n_features: usize,
224
- feature_index: usize,
225
- fn lessThan(ctx: @This(), a: usize, b: usize) bool {
226
- return ctx.x_ptr[a * ctx.n_features + ctx.feature_index] <
227
- ctx.x_ptr[b * ctx.n_features + ctx.feature_index];
228
- }
229
- };
230
- std.sort.heap(usize, sorted_indices, SortContext{
231
- .x_ptr = x_ptr,
232
- .n_features = model.n_features,
233
- .feature_index = feature_index,
234
- }, SortContext.lessThan);
236
+ const dynamic_bins = @as(usize, @intFromFloat(@floor(@sqrt(@as(f64, @floatFromInt(sample_count))))));
237
+ const bin_count = std.math.clamp(dynamic_bins, 16, MAX_THRESHOLD_BINS);
238
+ var bin_totals: [MAX_THRESHOLD_BINS]usize = [_]usize{0} ** MAX_THRESHOLD_BINS;
239
+ var bin_positives: [MAX_THRESHOLD_BINS]usize = [_]usize{0} ** MAX_THRESHOLD_BINS;
240
+ const value_range = max_value - min_value;
235
241
 
236
- var total_positive: usize = 0;
237
- for (sorted_indices) |sample_index| {
238
- total_positive += y_ptr[sample_index];
242
+ for (indices) |sample_index| {
243
+ const value = x_ptr[sample_index * model.n_features + feature_index];
244
+ var bin_index = @as(usize, @intFromFloat(@floor(((value - min_value) / value_range) * @as(f64, @floatFromInt(bin_count)))));
245
+ if (bin_index >= bin_count) {
246
+ bin_index = bin_count - 1;
247
+ }
248
+ bin_totals[bin_index] += 1;
249
+ bin_positives[bin_index] += y_ptr[sample_index];
239
250
  }
240
251
 
241
252
  var left_count: usize = 0;
242
253
  var left_positive: usize = 0;
243
254
  var best_impurity = std.math.inf(f64);
244
255
  var best_threshold: f64 = 0.0;
245
- var best_split_index: usize = 0;
246
256
  var found = false;
247
257
 
248
- var i: usize = 1;
249
- while (i < sample_count) : (i += 1) {
250
- const previous_index = sorted_indices[i - 1];
251
- left_count += 1;
252
- left_positive += y_ptr[previous_index];
258
+ var bin: usize = 0;
259
+ while (bin + 1 < bin_count) : (bin += 1) {
260
+ left_count += bin_totals[bin];
261
+ left_positive += bin_positives[bin];
253
262
  const right_count = sample_count - left_count;
254
-
255
263
  if (left_count < model.min_samples_leaf or right_count < model.min_samples_leaf) {
256
264
  continue;
257
265
  }
258
266
 
259
- const left_value = x_ptr[previous_index * model.n_features + feature_index];
260
- const right_value = x_ptr[sorted_indices[i] * model.n_features + feature_index];
261
- if (left_value == right_value) {
262
- continue;
263
- }
264
-
265
267
  const right_positive = total_positive - left_positive;
266
268
  const impurity =
267
269
  (@as(f64, @floatFromInt(left_count)) / @as(f64, @floatFromInt(sample_count))) *
@@ -271,8 +273,7 @@ fn findBestSplitForFeature(
271
273
 
272
274
  if (impurity < best_impurity) {
273
275
  best_impurity = impurity;
274
- best_threshold = (left_value + right_value) / 2.0;
275
- best_split_index = i;
276
+ best_threshold = min_value + (value_range * @as(f64, @floatFromInt(bin + 1))) / @as(f64, @floatFromInt(bin_count));
276
277
  found = true;
277
278
  }
278
279
  }
@@ -281,14 +282,36 @@ fn findBestSplitForFeature(
281
282
  return null;
282
283
  }
283
284
 
284
- const left_indices = try allocator.alloc(usize, best_split_index);
285
+ var left_partition_count: usize = 0;
286
+ for (indices) |sample_index| {
287
+ const value = x_ptr[sample_index * model.n_features + feature_index];
288
+ if (value <= best_threshold) {
289
+ left_partition_count += 1;
290
+ }
291
+ }
292
+
293
+ const right_partition_count = sample_count - left_partition_count;
294
+ if (left_partition_count < model.min_samples_leaf or right_partition_count < model.min_samples_leaf) {
295
+ return null;
296
+ }
297
+
298
+ const left_indices = try allocator.alloc(usize, left_partition_count);
285
299
  errdefer allocator.free(left_indices);
286
- const right_size = sample_count - best_split_index;
287
- const right_indices = try allocator.alloc(usize, right_size);
300
+ const right_indices = try allocator.alloc(usize, right_partition_count);
288
301
  errdefer allocator.free(right_indices);
289
302
 
290
- @memcpy(left_indices, sorted_indices[0..best_split_index]);
291
- @memcpy(right_indices, sorted_indices[best_split_index..]);
303
+ var left_write: usize = 0;
304
+ var right_write: usize = 0;
305
+ for (indices) |sample_index| {
306
+ const value = x_ptr[sample_index * model.n_features + feature_index];
307
+ if (value <= best_threshold) {
308
+ left_indices[left_write] = sample_index;
309
+ left_write += 1;
310
+ } else {
311
+ right_indices[right_write] = sample_index;
312
+ right_write += 1;
313
+ }
314
+ }
292
315
 
293
316
  return SplitResult{
294
317
  .threshold = best_threshold,