bun-scikit 0.1.3 → 0.1.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +73 -137
- package/package.json +2 -2
- package/scripts/check-benchmark-health.ts +62 -1
- package/scripts/sync-benchmark-readme.ts +56 -0
- package/src/dummy/DummyClassifier.ts +190 -0
- package/src/dummy/DummyRegressor.ts +108 -0
- package/src/feature_selection/VarianceThreshold.ts +88 -0
- package/src/index.ts +23 -0
- package/src/metrics/classification.ts +30 -0
- package/src/metrics/regression.ts +40 -0
- package/src/model_selection/RandomizedSearchCV.ts +269 -0
- package/src/native/node-addon/bun_scikit_addon.cpp +149 -0
- package/src/native/zigKernels.ts +33 -4
- package/src/preprocessing/Binarizer.ts +46 -0
- package/src/preprocessing/LabelEncoder.ts +62 -0
- package/src/preprocessing/MaxAbsScaler.ts +77 -0
- package/src/preprocessing/Normalizer.ts +66 -0
- package/src/tree/DecisionTreeClassifier.ts +146 -3
- package/zig/kernels.zig +63 -40
|
@@ -6,6 +6,7 @@ import {
|
|
|
6
6
|
validateClassificationInputs,
|
|
7
7
|
} from "../utils/validation";
|
|
8
8
|
import { accuracyScore } from "../metrics/classification";
|
|
9
|
+
import { getZigKernels } from "../native/zigKernels";
|
|
9
10
|
|
|
10
11
|
export type MaxFeaturesOption = "sqrt" | "log2" | number | null;
|
|
11
12
|
|
|
@@ -38,6 +39,11 @@ interface SplitPartition {
|
|
|
38
39
|
|
|
39
40
|
const MAX_THRESHOLD_BINS = 128;
|
|
40
41
|
|
|
42
|
+
function isZigTreeBackendEnabled(): boolean {
|
|
43
|
+
const mode = process.env.BUN_SCIKIT_TREE_BACKEND?.trim().toLowerCase();
|
|
44
|
+
return mode === "zig" || mode === "native";
|
|
45
|
+
}
|
|
46
|
+
|
|
41
47
|
function mulberry32(seed: number): () => number {
|
|
42
48
|
let state = seed >>> 0;
|
|
43
49
|
return () => {
|
|
@@ -59,6 +65,8 @@ function giniImpurity(positiveCount: number, sampleCount: number): number {
|
|
|
59
65
|
|
|
60
66
|
export class DecisionTreeClassifier implements ClassificationModel {
|
|
61
67
|
classes_: Vector = [0, 1];
|
|
68
|
+
fitBackend_: "zig" | "js" = "js";
|
|
69
|
+
fitBackendLibrary_: string | null = null;
|
|
62
70
|
private readonly maxDepth: number;
|
|
63
71
|
private readonly minSamplesSplit: number;
|
|
64
72
|
private readonly minSamplesLeaf: number;
|
|
@@ -73,6 +81,7 @@ export class DecisionTreeClassifier implements ClassificationModel {
|
|
|
73
81
|
private featureSelectionMarks: Uint8Array | null = null;
|
|
74
82
|
private binTotals: Uint32Array = new Uint32Array(MAX_THRESHOLD_BINS);
|
|
75
83
|
private binPositives: Uint32Array = new Uint32Array(MAX_THRESHOLD_BINS);
|
|
84
|
+
private zigModelHandle: bigint | null = null;
|
|
76
85
|
|
|
77
86
|
constructor(options: DecisionTreeClassifierOptions = {}) {
|
|
78
87
|
this.maxDepth = options.maxDepth ?? 12;
|
|
@@ -90,6 +99,8 @@ export class DecisionTreeClassifier implements ClassificationModel {
|
|
|
90
99
|
flattenedXTrain?: Float64Array,
|
|
91
100
|
yBinaryTrain?: Uint8Array,
|
|
92
101
|
): this {
|
|
102
|
+
this.destroyZigModel();
|
|
103
|
+
|
|
93
104
|
if (!skipValidation) {
|
|
94
105
|
validateClassificationInputs(X, y);
|
|
95
106
|
}
|
|
@@ -103,18 +114,28 @@ export class DecisionTreeClassifier implements ClassificationModel {
|
|
|
103
114
|
this.featureSelectionMarks = new Uint8Array(this.featureCount);
|
|
104
115
|
this.random = this.randomState === undefined ? Math.random : mulberry32(this.randomState);
|
|
105
116
|
|
|
106
|
-
let
|
|
117
|
+
let validatedSampleIndices: Uint32Array | null = null;
|
|
107
118
|
if (sampleIndices) {
|
|
108
119
|
if (sampleIndices.length === 0) {
|
|
109
120
|
throw new Error("sampleIndices must not be empty.");
|
|
110
121
|
}
|
|
122
|
+
validatedSampleIndices = new Uint32Array(sampleIndices.length);
|
|
111
123
|
for (let i = 0; i < sampleIndices.length; i += 1) {
|
|
112
124
|
const index = sampleIndices[i];
|
|
113
125
|
if (!Number.isInteger(index) || index < 0 || index >= X.length) {
|
|
114
126
|
throw new Error(`sampleIndices contains invalid index: ${index}.`);
|
|
115
127
|
}
|
|
128
|
+
validatedSampleIndices[i] = index;
|
|
116
129
|
}
|
|
117
|
-
|
|
130
|
+
}
|
|
131
|
+
|
|
132
|
+
if (isZigTreeBackendEnabled() && this.tryFitWithZig(X.length, validatedSampleIndices)) {
|
|
133
|
+
return this;
|
|
134
|
+
}
|
|
135
|
+
|
|
136
|
+
let rootIndices: number[];
|
|
137
|
+
if (validatedSampleIndices) {
|
|
138
|
+
rootIndices = Array.from(validatedSampleIndices);
|
|
118
139
|
} else {
|
|
119
140
|
rootIndices = new Array<number>(X.length);
|
|
120
141
|
for (let idx = 0; idx < X.length; idx += 1) {
|
|
@@ -123,11 +144,13 @@ export class DecisionTreeClassifier implements ClassificationModel {
|
|
|
123
144
|
}
|
|
124
145
|
|
|
125
146
|
this.root = this.buildTree(rootIndices, 0);
|
|
147
|
+
this.fitBackend_ = "js";
|
|
148
|
+
this.fitBackendLibrary_ = null;
|
|
126
149
|
return this;
|
|
127
150
|
}
|
|
128
151
|
|
|
129
152
|
predict(X: Matrix): Vector {
|
|
130
|
-
if (
|
|
153
|
+
if ((this.root === null && this.zigModelHandle === null) || this.featureCount === 0) {
|
|
131
154
|
throw new Error("DecisionTreeClassifier has not been fitted.");
|
|
132
155
|
}
|
|
133
156
|
|
|
@@ -140,6 +163,28 @@ export class DecisionTreeClassifier implements ClassificationModel {
|
|
|
140
163
|
);
|
|
141
164
|
}
|
|
142
165
|
|
|
166
|
+
if (this.zigModelHandle !== null) {
|
|
167
|
+
const kernels = getZigKernels();
|
|
168
|
+
const nativePredict = kernels?.decisionTreeModelPredict;
|
|
169
|
+
if (nativePredict) {
|
|
170
|
+
const flattenedX = this.flattenTrainingMatrix(X);
|
|
171
|
+
const outLabels = new Uint8Array(X.length);
|
|
172
|
+
const status = nativePredict(
|
|
173
|
+
this.zigModelHandle,
|
|
174
|
+
flattenedX,
|
|
175
|
+
X.length,
|
|
176
|
+
this.featureCount,
|
|
177
|
+
outLabels,
|
|
178
|
+
);
|
|
179
|
+
if (status === 1) {
|
|
180
|
+
return Array.from(outLabels);
|
|
181
|
+
}
|
|
182
|
+
}
|
|
183
|
+
if (!this.root) {
|
|
184
|
+
throw new Error("Native DecisionTree predict failed and no JS fallback tree is available.");
|
|
185
|
+
}
|
|
186
|
+
}
|
|
187
|
+
|
|
143
188
|
return X.map((sample) => this.predictOne(sample, this.root!));
|
|
144
189
|
}
|
|
145
190
|
|
|
@@ -228,9 +273,107 @@ export class DecisionTreeClassifier implements ClassificationModel {
|
|
|
228
273
|
if (this.maxFeatures === "log2") {
|
|
229
274
|
return Math.max(1, Math.floor(Math.log2(featureCount)));
|
|
230
275
|
}
|
|
276
|
+
if (!Number.isFinite(this.maxFeatures)) {
|
|
277
|
+
return featureCount;
|
|
278
|
+
}
|
|
231
279
|
return Math.max(1, Math.min(featureCount, Math.floor(this.maxFeatures)));
|
|
232
280
|
}
|
|
233
281
|
|
|
282
|
+
private resolveNativeMaxFeatures(featureCount: number): {
|
|
283
|
+
mode: 0 | 1 | 2 | 3;
|
|
284
|
+
value: number;
|
|
285
|
+
} {
|
|
286
|
+
if (this.maxFeatures === null || this.maxFeatures === undefined) {
|
|
287
|
+
return { mode: 0, value: 0 };
|
|
288
|
+
}
|
|
289
|
+
if (this.maxFeatures === "sqrt") {
|
|
290
|
+
return { mode: 1, value: 0 };
|
|
291
|
+
}
|
|
292
|
+
if (this.maxFeatures === "log2") {
|
|
293
|
+
return { mode: 2, value: 0 };
|
|
294
|
+
}
|
|
295
|
+
const value = Number.isFinite(this.maxFeatures)
|
|
296
|
+
? Math.max(1, Math.min(featureCount, Math.floor(this.maxFeatures)))
|
|
297
|
+
: featureCount;
|
|
298
|
+
return { mode: 3, value };
|
|
299
|
+
}
|
|
300
|
+
|
|
301
|
+
private tryFitWithZig(
|
|
302
|
+
sampleCount: number,
|
|
303
|
+
sampleIndices: Uint32Array | null,
|
|
304
|
+
): boolean {
|
|
305
|
+
const kernels = getZigKernels();
|
|
306
|
+
const create = kernels?.decisionTreeModelCreate;
|
|
307
|
+
const fit = kernels?.decisionTreeModelFit;
|
|
308
|
+
const destroy = kernels?.decisionTreeModelDestroy;
|
|
309
|
+
if (!create || !fit || !destroy) {
|
|
310
|
+
return false;
|
|
311
|
+
}
|
|
312
|
+
|
|
313
|
+
const { mode, value } = this.resolveNativeMaxFeatures(this.featureCount);
|
|
314
|
+
const useRandomState = this.randomState === undefined ? 0 : 1;
|
|
315
|
+
const randomState = this.randomState ?? 0;
|
|
316
|
+
const handle = create(
|
|
317
|
+
this.maxDepth,
|
|
318
|
+
this.minSamplesSplit,
|
|
319
|
+
this.minSamplesLeaf,
|
|
320
|
+
mode,
|
|
321
|
+
value,
|
|
322
|
+
randomState >>> 0,
|
|
323
|
+
useRandomState,
|
|
324
|
+
this.featureCount,
|
|
325
|
+
);
|
|
326
|
+
if (handle === 0n) {
|
|
327
|
+
return false;
|
|
328
|
+
}
|
|
329
|
+
|
|
330
|
+
let shouldDestroy = true;
|
|
331
|
+
try {
|
|
332
|
+
const emptySampleIndices = new Uint32Array(0);
|
|
333
|
+
const status = fit(
|
|
334
|
+
handle,
|
|
335
|
+
this.flattenedXTrain!,
|
|
336
|
+
this.yBinaryTrain!,
|
|
337
|
+
sampleCount,
|
|
338
|
+
this.featureCount,
|
|
339
|
+
sampleIndices ?? emptySampleIndices,
|
|
340
|
+
sampleIndices?.length ?? 0,
|
|
341
|
+
);
|
|
342
|
+
if (status !== 1) {
|
|
343
|
+
return false;
|
|
344
|
+
}
|
|
345
|
+
|
|
346
|
+
this.zigModelHandle = handle;
|
|
347
|
+
this.root = null;
|
|
348
|
+
this.fitBackend_ = "zig";
|
|
349
|
+
this.fitBackendLibrary_ = kernels.libraryPath;
|
|
350
|
+
shouldDestroy = false;
|
|
351
|
+
return true;
|
|
352
|
+
} catch {
|
|
353
|
+
return false;
|
|
354
|
+
} finally {
|
|
355
|
+
if (shouldDestroy) {
|
|
356
|
+
destroy(handle);
|
|
357
|
+
}
|
|
358
|
+
}
|
|
359
|
+
}
|
|
360
|
+
|
|
361
|
+
private destroyZigModel(): void {
|
|
362
|
+
if (this.zigModelHandle === null) {
|
|
363
|
+
return;
|
|
364
|
+
}
|
|
365
|
+
const kernels = getZigKernels();
|
|
366
|
+
const destroy = kernels?.decisionTreeModelDestroy;
|
|
367
|
+
if (destroy) {
|
|
368
|
+
try {
|
|
369
|
+
destroy(this.zigModelHandle);
|
|
370
|
+
} catch {
|
|
371
|
+
// no-op: cleanup best effort
|
|
372
|
+
}
|
|
373
|
+
}
|
|
374
|
+
this.zigModelHandle = null;
|
|
375
|
+
}
|
|
376
|
+
|
|
234
377
|
private selectFeatureIndices(featureCount: number): number[] {
|
|
235
378
|
const k = this.resolveMaxFeatures(featureCount);
|
|
236
379
|
if (k >= featureCount) {
|
package/zig/kernels.zig
CHANGED
|
@@ -84,6 +84,8 @@ const SplitResult = struct {
|
|
|
84
84
|
right_indices: []usize,
|
|
85
85
|
};
|
|
86
86
|
|
|
87
|
+
const MAX_THRESHOLD_BINS: usize = 128;
|
|
88
|
+
|
|
87
89
|
const Mulberry32 = struct {
|
|
88
90
|
state: u32,
|
|
89
91
|
|
|
@@ -213,55 +215,55 @@ fn findBestSplitForFeature(
|
|
|
213
215
|
if (sample_count < 2) {
|
|
214
216
|
return null;
|
|
215
217
|
}
|
|
218
|
+
var min_value = std.math.inf(f64);
|
|
219
|
+
var max_value = -std.math.inf(f64);
|
|
220
|
+
var total_positive: usize = 0;
|
|
221
|
+
for (indices) |sample_index| {
|
|
222
|
+
const value = x_ptr[sample_index * model.n_features + feature_index];
|
|
223
|
+
if (value < min_value) {
|
|
224
|
+
min_value = value;
|
|
225
|
+
}
|
|
226
|
+
if (value > max_value) {
|
|
227
|
+
max_value = value;
|
|
228
|
+
}
|
|
229
|
+
total_positive += y_ptr[sample_index];
|
|
230
|
+
}
|
|
216
231
|
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
232
|
+
if (!std.math.isFinite(min_value) or !std.math.isFinite(max_value) or min_value == max_value) {
|
|
233
|
+
return null;
|
|
234
|
+
}
|
|
220
235
|
|
|
221
|
-
const
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
return ctx.x_ptr[a * ctx.n_features + ctx.feature_index] <
|
|
227
|
-
ctx.x_ptr[b * ctx.n_features + ctx.feature_index];
|
|
228
|
-
}
|
|
229
|
-
};
|
|
230
|
-
std.sort.heap(usize, sorted_indices, SortContext{
|
|
231
|
-
.x_ptr = x_ptr,
|
|
232
|
-
.n_features = model.n_features,
|
|
233
|
-
.feature_index = feature_index,
|
|
234
|
-
}, SortContext.lessThan);
|
|
236
|
+
const dynamic_bins = @as(usize, @intFromFloat(@floor(@sqrt(@as(f64, @floatFromInt(sample_count))))));
|
|
237
|
+
const bin_count = std.math.clamp(dynamic_bins, 16, MAX_THRESHOLD_BINS);
|
|
238
|
+
var bin_totals: [MAX_THRESHOLD_BINS]usize = [_]usize{0} ** MAX_THRESHOLD_BINS;
|
|
239
|
+
var bin_positives: [MAX_THRESHOLD_BINS]usize = [_]usize{0} ** MAX_THRESHOLD_BINS;
|
|
240
|
+
const value_range = max_value - min_value;
|
|
235
241
|
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
242
|
+
for (indices) |sample_index| {
|
|
243
|
+
const value = x_ptr[sample_index * model.n_features + feature_index];
|
|
244
|
+
var bin_index = @as(usize, @intFromFloat(@floor(((value - min_value) / value_range) * @as(f64, @floatFromInt(bin_count)))));
|
|
245
|
+
if (bin_index >= bin_count) {
|
|
246
|
+
bin_index = bin_count - 1;
|
|
247
|
+
}
|
|
248
|
+
bin_totals[bin_index] += 1;
|
|
249
|
+
bin_positives[bin_index] += y_ptr[sample_index];
|
|
239
250
|
}
|
|
240
251
|
|
|
241
252
|
var left_count: usize = 0;
|
|
242
253
|
var left_positive: usize = 0;
|
|
243
254
|
var best_impurity = std.math.inf(f64);
|
|
244
255
|
var best_threshold: f64 = 0.0;
|
|
245
|
-
var best_split_index: usize = 0;
|
|
246
256
|
var found = false;
|
|
247
257
|
|
|
248
|
-
var
|
|
249
|
-
while (
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
left_positive += y_ptr[previous_index];
|
|
258
|
+
var bin: usize = 0;
|
|
259
|
+
while (bin + 1 < bin_count) : (bin += 1) {
|
|
260
|
+
left_count += bin_totals[bin];
|
|
261
|
+
left_positive += bin_positives[bin];
|
|
253
262
|
const right_count = sample_count - left_count;
|
|
254
|
-
|
|
255
263
|
if (left_count < model.min_samples_leaf or right_count < model.min_samples_leaf) {
|
|
256
264
|
continue;
|
|
257
265
|
}
|
|
258
266
|
|
|
259
|
-
const left_value = x_ptr[previous_index * model.n_features + feature_index];
|
|
260
|
-
const right_value = x_ptr[sorted_indices[i] * model.n_features + feature_index];
|
|
261
|
-
if (left_value == right_value) {
|
|
262
|
-
continue;
|
|
263
|
-
}
|
|
264
|
-
|
|
265
267
|
const right_positive = total_positive - left_positive;
|
|
266
268
|
const impurity =
|
|
267
269
|
(@as(f64, @floatFromInt(left_count)) / @as(f64, @floatFromInt(sample_count))) *
|
|
@@ -271,8 +273,7 @@ fn findBestSplitForFeature(
|
|
|
271
273
|
|
|
272
274
|
if (impurity < best_impurity) {
|
|
273
275
|
best_impurity = impurity;
|
|
274
|
-
best_threshold = (
|
|
275
|
-
best_split_index = i;
|
|
276
|
+
best_threshold = min_value + (value_range * @as(f64, @floatFromInt(bin + 1))) / @as(f64, @floatFromInt(bin_count));
|
|
276
277
|
found = true;
|
|
277
278
|
}
|
|
278
279
|
}
|
|
@@ -281,14 +282,36 @@ fn findBestSplitForFeature(
|
|
|
281
282
|
return null;
|
|
282
283
|
}
|
|
283
284
|
|
|
284
|
-
|
|
285
|
+
var left_partition_count: usize = 0;
|
|
286
|
+
for (indices) |sample_index| {
|
|
287
|
+
const value = x_ptr[sample_index * model.n_features + feature_index];
|
|
288
|
+
if (value <= best_threshold) {
|
|
289
|
+
left_partition_count += 1;
|
|
290
|
+
}
|
|
291
|
+
}
|
|
292
|
+
|
|
293
|
+
const right_partition_count = sample_count - left_partition_count;
|
|
294
|
+
if (left_partition_count < model.min_samples_leaf or right_partition_count < model.min_samples_leaf) {
|
|
295
|
+
return null;
|
|
296
|
+
}
|
|
297
|
+
|
|
298
|
+
const left_indices = try allocator.alloc(usize, left_partition_count);
|
|
285
299
|
errdefer allocator.free(left_indices);
|
|
286
|
-
const
|
|
287
|
-
const right_indices = try allocator.alloc(usize, right_size);
|
|
300
|
+
const right_indices = try allocator.alloc(usize, right_partition_count);
|
|
288
301
|
errdefer allocator.free(right_indices);
|
|
289
302
|
|
|
290
|
-
|
|
291
|
-
|
|
303
|
+
var left_write: usize = 0;
|
|
304
|
+
var right_write: usize = 0;
|
|
305
|
+
for (indices) |sample_index| {
|
|
306
|
+
const value = x_ptr[sample_index * model.n_features + feature_index];
|
|
307
|
+
if (value <= best_threshold) {
|
|
308
|
+
left_indices[left_write] = sample_index;
|
|
309
|
+
left_write += 1;
|
|
310
|
+
} else {
|
|
311
|
+
right_indices[right_write] = sample_index;
|
|
312
|
+
right_write += 1;
|
|
313
|
+
}
|
|
314
|
+
}
|
|
292
315
|
|
|
293
316
|
return SplitResult{
|
|
294
317
|
.threshold = best_threshold,
|