bun-scikit 0.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. package/LICENSE +21 -0
  2. package/README.md +187 -0
  3. package/binding.gyp +21 -0
  4. package/docs/README.md +7 -0
  5. package/docs/native-abi.md +53 -0
  6. package/index.ts +1 -0
  7. package/package.json +76 -0
  8. package/scripts/build-node-addon.ts +26 -0
  9. package/scripts/build-zig-kernels.ts +50 -0
  10. package/scripts/check-api-docs-coverage.ts +52 -0
  11. package/scripts/check-benchmark-health.ts +140 -0
  12. package/scripts/install-native.ts +160 -0
  13. package/scripts/package-native-artifacts.ts +62 -0
  14. package/scripts/sync-benchmark-readme.ts +181 -0
  15. package/scripts/update-benchmark-history.ts +91 -0
  16. package/src/ensemble/RandomForestClassifier.ts +136 -0
  17. package/src/ensemble/RandomForestRegressor.ts +136 -0
  18. package/src/index.ts +32 -0
  19. package/src/linear_model/LinearRegression.ts +136 -0
  20. package/src/linear_model/LogisticRegression.ts +260 -0
  21. package/src/linear_model/SGDClassifier.ts +161 -0
  22. package/src/linear_model/SGDRegressor.ts +104 -0
  23. package/src/metrics/classification.ts +294 -0
  24. package/src/metrics/regression.ts +51 -0
  25. package/src/model_selection/GridSearchCV.ts +244 -0
  26. package/src/model_selection/KFold.ts +82 -0
  27. package/src/model_selection/RepeatedKFold.ts +49 -0
  28. package/src/model_selection/RepeatedStratifiedKFold.ts +50 -0
  29. package/src/model_selection/StratifiedKFold.ts +112 -0
  30. package/src/model_selection/StratifiedShuffleSplit.ts +211 -0
  31. package/src/model_selection/crossValScore.ts +165 -0
  32. package/src/model_selection/trainTestSplit.ts +82 -0
  33. package/src/naive_bayes/GaussianNB.ts +148 -0
  34. package/src/native/node-addon/bun_scikit_addon.cpp +450 -0
  35. package/src/native/zigKernels.ts +576 -0
  36. package/src/neighbors/KNeighborsClassifier.ts +85 -0
  37. package/src/pipeline/ColumnTransformer.ts +203 -0
  38. package/src/pipeline/FeatureUnion.ts +123 -0
  39. package/src/pipeline/Pipeline.ts +168 -0
  40. package/src/preprocessing/MinMaxScaler.ts +113 -0
  41. package/src/preprocessing/OneHotEncoder.ts +91 -0
  42. package/src/preprocessing/PolynomialFeatures.ts +158 -0
  43. package/src/preprocessing/RobustScaler.ts +149 -0
  44. package/src/preprocessing/SimpleImputer.ts +150 -0
  45. package/src/preprocessing/StandardScaler.ts +92 -0
  46. package/src/svm/LinearSVC.ts +117 -0
  47. package/src/tree/DecisionTreeClassifier.ts +394 -0
  48. package/src/tree/DecisionTreeRegressor.ts +407 -0
  49. package/src/types.ts +18 -0
  50. package/src/utils/linalg.ts +209 -0
  51. package/src/utils/validation.ts +78 -0
  52. package/zig/kernels.zig +1327 -0
@@ -0,0 +1,576 @@
1
+ import { dlopen, FFIType, suffix } from "bun:ffi";
2
+ import { existsSync } from "node:fs";
3
+ import { createRequire } from "node:module";
4
+ import { resolve } from "node:path";
5
+
6
/**
 * Handle value returned by the native `*_create` functions and passed back
 * to every other call on the same model.
 */
type NativeHandle = bigint;

/** Returns the ABI version number reported by the loaded native library. */
type AbiVersionFn = () => number;

/** Allocates a native linear model and returns its handle. */
type LinearModelCreateFn = (nFeatures: number, fitIntercept: number) => NativeHandle;

/** Releases a native linear model previously returned by the create function. */
type LinearModelDestroyFn = (handle: NativeHandle) => void;

/**
 * Fits a native linear model on `nSamples` rows of `x` against targets `y`.
 * The numeric result is presumably a status flag; confirm against the native ABI.
 */
type LinearModelFitFn = (
  handle: NativeHandle,
  x: Float64Array,
  y: Float64Array,
  nSamples: number,
  l2: number,
) => number;

/** Writes predictions for `nSamples` rows of `x` into the caller-provided `out` buffer. */
type LinearModelPredictFn = (
  handle: NativeHandle,
  x: Float64Array,
  nSamples: number,
  out: Float64Array,
) => number;

/** Copies the model coefficients into the caller-provided `out` buffer. */
type LinearModelCopyCoefficientsFn = (handle: NativeHandle, out: Float64Array) => number;

/** Reads the intercept of a native linear model. */
type LinearModelGetInterceptFn = (handle: NativeHandle) => number;
26
+
27
/** Allocates a native logistic model and returns its handle. */
type LogisticModelCreateFn = (nFeatures: number, fitIntercept: number) => NativeHandle;

/** Releases a native logistic model previously returned by the create function. */
type LogisticModelDestroyFn = (handle: NativeHandle) => void;

/**
 * Fits a native logistic model using a learning-rate driven solver.
 * The bigint result is presumably an iteration count; confirm against the native ABI.
 */
type LogisticModelFitFn = (
  handle: NativeHandle,
  x: Float64Array,
  y: Float64Array,
  nSamples: number,
  learningRate: number,
  l2: number,
  maxIter: number,
  tolerance: number,
) => bigint;

/**
 * Fits a native logistic model using the L-BFGS entry point.
 * Note the argument order differs from `LogisticModelFitFn`:
 * `maxIter`, `tolerance`, `l2`, then `memory`.
 */
type LogisticModelFitLbfgsFn = (
  handle: NativeHandle,
  x: Float64Array,
  y: Float64Array,
  nSamples: number,
  maxIter: number,
  tolerance: number,
  l2: number,
  memory: number,
) => bigint;

/** Writes per-row probabilities into the caller-provided `outPositive` buffer. */
type LogisticModelPredictProbaFn = (
  handle: NativeHandle,
  x: Float64Array,
  nSamples: number,
  outPositive: Float64Array,
) => number;

/** Writes per-row byte labels into the caller-provided `outLabels` buffer. */
type LogisticModelPredictFn = (
  handle: NativeHandle,
  x: Float64Array,
  nSamples: number,
  outLabels: Uint8Array,
) => number;

/** Copies the model coefficients into the caller-provided `out` buffer. */
type LogisticModelCopyCoefficientsFn = (handle: NativeHandle, out: Float64Array) => number;

/** Reads the intercept of a native logistic model. */
type LogisticModelGetInterceptFn = (handle: NativeHandle) => number;
63
+
64
/**
 * Allocates a native decision-tree model from its hyper-parameters and
 * returns its handle. All options are passed as plain numbers.
 */
type DecisionTreeModelCreateFn = (
  maxDepth: number,
  minSamplesSplit: number,
  minSamplesLeaf: number,
  maxFeaturesMode: number,
  maxFeaturesValue: number,
  randomState: number,
  useRandomState: number,
  nFeatures: number,
) => NativeHandle;

/** Releases a native decision-tree model previously returned by the create function. */
type DecisionTreeModelDestroyFn = (handle: NativeHandle) => void;

/**
 * Fits a native decision tree on byte labels `y`, using the rows listed in
 * `sampleIndices` (the first `sampleCount` entries, presumably; confirm
 * against the native ABI).
 */
type DecisionTreeModelFitFn = (
  handle: NativeHandle,
  x: Float64Array,
  y: Uint8Array,
  nSamples: number,
  nFeatures: number,
  sampleIndices: Uint32Array,
  sampleCount: number,
) => number;

/** Writes per-row byte labels into the caller-provided `outLabels` buffer. */
type DecisionTreeModelPredictFn = (
  handle: NativeHandle,
  x: Float64Array,
  nSamples: number,
  nFeatures: number,
  outLabels: Uint8Array,
) => number;
91
+
92
/**
 * Handle-free logistic training kernel that runs on caller-owned buffers
 * (`weights`, `intercept`, `gradients`) and returns a number.
 * The result is presumably the epoch's update magnitude or loss; confirm
 * against the native ABI.
 */
type LogisticTrainEpochFn = (
  x: Float64Array,
  y: Float64Array,
  nSamples: number,
  nFeatures: number,
  weights: Float64Array,
  intercept: Float64Array,
  gradients: Float64Array,
  learningRate: number,
  l2: number,
  fitIntercept: number,
) => number;

/**
 * Multi-epoch variant of `LogisticTrainEpochFn`: takes the same buffers plus
 * `maxIter` and `tolerance`, and returns a bigint (presumably the number of
 * epochs executed; confirm against the native ABI).
 */
type LogisticTrainEpochsFn = (
  x: Float64Array,
  y: Float64Array,
  nSamples: number,
  nFeatures: number,
  weights: Float64Array,
  intercept: Float64Array,
  gradients: Float64Array,
  learningRate: number,
  l2: number,
  fitIntercept: number,
  maxIter: number,
  tolerance: number,
) => bigint;
119
+
120
/**
 * Shape the result of `dlopen` is cast to. Every symbol is optional because
 * the different symbol sets passed to `dlopen` request different subsets.
 */
interface ZigKernelLibrary {
  symbols: Partial<{
    bun_scikit_abi_version: AbiVersionFn;

    // Linear model entry points.
    linear_model_create: LinearModelCreateFn;
    linear_model_destroy: LinearModelDestroyFn;
    linear_model_fit: LinearModelFitFn;
    linear_model_predict: LinearModelPredictFn;
    linear_model_copy_coefficients: LinearModelCopyCoefficientsFn;
    linear_model_get_intercept: LinearModelGetInterceptFn;

    // Logistic model entry points.
    logistic_model_create: LogisticModelCreateFn;
    logistic_model_destroy: LogisticModelDestroyFn;
    logistic_model_fit: LogisticModelFitFn;
    logistic_model_fit_lbfgs: LogisticModelFitLbfgsFn;
    logistic_model_predict_proba: LogisticModelPredictProbaFn;
    logistic_model_predict: LogisticModelPredictFn;
    logistic_model_copy_coefficients: LogisticModelCopyCoefficientsFn;
    logistic_model_get_intercept: LogisticModelGetInterceptFn;

    // Decision tree entry points.
    decision_tree_model_create: DecisionTreeModelCreateFn;
    decision_tree_model_destroy: DecisionTreeModelDestroyFn;
    decision_tree_model_fit: DecisionTreeModelFitFn;
    decision_tree_model_predict: DecisionTreeModelPredictFn;

    // Handle-free logistic training kernels.
    logistic_train_epoch: LogisticTrainEpochFn;
    logistic_train_epochs: LogisticTrainEpochsFn;
  }>;
}
145
+
146
/**
 * Native kernels resolved by `getZigKernels`. A field is `null` when the
 * bridge or library that was loaded does not provide that function.
 */
export interface ZigKernels {
  // Linear model.
  linearModelCreate: LinearModelCreateFn | null;
  linearModelDestroy: LinearModelDestroyFn | null;
  linearModelFit: LinearModelFitFn | null;
  linearModelPredict: LinearModelPredictFn | null;
  linearModelCopyCoefficients: LinearModelCopyCoefficientsFn | null;
  linearModelGetIntercept: LinearModelGetInterceptFn | null;

  // Logistic model.
  logisticModelCreate: LogisticModelCreateFn | null;
  logisticModelDestroy: LogisticModelDestroyFn | null;
  logisticModelFit: LogisticModelFitFn | null;
  logisticModelFitLbfgs: LogisticModelFitLbfgsFn | null;
  logisticModelPredictProba: LogisticModelPredictProbaFn | null;
  logisticModelPredict: LogisticModelPredictFn | null;
  logisticModelCopyCoefficients: LogisticModelCopyCoefficientsFn | null;
  logisticModelGetIntercept: LogisticModelGetInterceptFn | null;

  // Decision tree.
  decisionTreeModelCreate: DecisionTreeModelCreateFn | null;
  decisionTreeModelDestroy: DecisionTreeModelDestroyFn | null;
  decisionTreeModelFit: DecisionTreeModelFitFn | null;
  decisionTreeModelPredict: DecisionTreeModelPredictFn | null;

  // Handle-free logistic training kernels.
  logisticTrainEpoch: LogisticTrainEpochFn | null;
  logisticTrainEpochs: LogisticTrainEpochsFn | null;

  /** ABI version reported by the library, or `null` when it was not queried. */
  abiVersion: number | null;
  /** Path of the shared library the kernels were loaded from. */
  libraryPath: string;
}
170
+
171
/** ABI version a native library must report to be accepted. */
const EXPECTED_ABI_VERSION = 1;

/**
 * Memoized result of `getZigKernels`: `undefined` means no lookup has been
 * cached yet, `null` means a lookup ran and found no usable library.
 */
let cachedKernels: ZigKernels | null | undefined;
173
+
174
/**
 * Interprets an environment-style flag.
 *
 * @param value Raw flag text, possibly undefined.
 * @returns `false` for undefined, the empty string, or (after trimming and
 *   lower-casing) `"0"`, `"false"` or `"off"`; `true` for anything else.
 */
function isTruthy(value: string | undefined): boolean {
  if (!value) {
    return false;
  }
  const disabledSpellings = ["0", "false", "off"];
  const flag = value.trim().toLowerCase();
  return !disabledSpellings.includes(flag);
}
181
+
182
/**
 * Reports whether the native Zig backend may be used.
 *
 * The backend is on by default; it is turned off only when
 * `BUN_SCIKIT_ENABLE_ZIG` is set to a value that `isTruthy` rejects.
 */
export function isZigBackendEnabled(): boolean {
  const envValue = process.env.BUN_SCIKIT_ENABLE_ZIG;
  // Unset or empty means "not configured", which keeps the default of enabled.
  return envValue ? isTruthy(envValue) : true;
}
189
+
190
/**
 * Lists the locations to probe for the kernel shared library, in priority order:
 * the `BUN_SCIKIT_ZIG_LIB` override (when set and non-empty), then `dist/native`
 * and `native` under the working directory, then the same two directories
 * resolved two levels above this module.
 */
function candidateLibraryPaths(): string[] {
  const fileName = `bun_scikit_kernels.${suffix}`;
  const paths: string[] = [];

  const explicitPath = process.env.BUN_SCIKIT_ZIG_LIB;
  if (explicitPath) {
    paths.push(explicitPath);
  }

  paths.push(
    resolve(process.cwd(), "dist", "native", fileName),
    resolve(process.cwd(), "native", fileName),
    resolve(import.meta.dir, "../../dist/native", fileName),
    resolve(import.meta.dir, "../../native", fileName),
  );

  return paths;
}
205
+
206
/**
 * Lists the locations to probe for the Node-API addon, in priority order:
 * the `BUN_SCIKIT_NODE_ADDON` override (when set and non-empty), then
 * `dist/native` and `build/Release` under the working directory, then
 * `dist/native` resolved two levels above this module.
 */
function candidateAddonPaths(): string[] {
  const addonFile = "bun_scikit_node_addon.node";
  const paths: string[] = [];

  const explicitPath = process.env.BUN_SCIKIT_NODE_ADDON;
  if (explicitPath) {
    paths.push(explicitPath);
  }

  paths.push(
    resolve(process.cwd(), "dist", "native", addonFile),
    resolve(process.cwd(), "build", "Release", addonFile),
    resolve(import.meta.dir, "../../dist/native", addonFile),
  );

  return paths;
}
215
+
216
/**
 * Surface of the Node-API addon as consumed by `tryLoadNodeApiKernels`.
 * The addon is asked to load the kernel library itself and then exposes a
 * subset of the kernels (no predict, decision-tree or epoch functions).
 */
interface NodeApiAddon {
  // Library lifecycle.
  /** Attempts to load the kernel library at `path`; returns whether it succeeded. */
  loadLibrary: (path: string) => boolean;
  /** Unloads the current library; optional on the addon. */
  unloadLibrary?: () => void;
  /** Path of the library the addon currently has loaded, if it reports one. */
  loadedPath: () => string | null;
  /** ABI version of the currently loaded library. */
  abiVersion: () => number;

  // Linear model.
  linearModelCreate: LinearModelCreateFn;
  linearModelDestroy: LinearModelDestroyFn;
  linearModelFit: LinearModelFitFn;
  linearModelCopyCoefficients: LinearModelCopyCoefficientsFn;
  linearModelGetIntercept: LinearModelGetInterceptFn;

  // Logistic model.
  logisticModelCreate: LogisticModelCreateFn;
  logisticModelDestroy: LogisticModelDestroyFn;
  logisticModelFit: LogisticModelFitFn;
  logisticModelFitLbfgs: LogisticModelFitLbfgsFn;
  logisticModelCopyCoefficients: LogisticModelCopyCoefficientsFn;
  logisticModelGetIntercept: LogisticModelGetInterceptFn;
}
233
+
234
/**
 * Tries to obtain kernels through the Node-API addon bridge.
 *
 * Each existing addon candidate is `require`d and asked to load each existing
 * kernel-library candidate in turn. The first library that loads and reports
 * `EXPECTED_ABI_VERSION` wins.
 *
 * Failures are isolated per candidate: an addon that cannot be required is
 * skipped, and a library that throws while loading or being queried is
 * skipped without abandoning the remaining libraries for that addon.
 *
 * @returns The resolved kernels, or `null` when no addon/library pair works.
 *   Kernels the addon does not expose (predict, decision tree, epoch
 *   functions) are reported as `null`.
 */
function tryLoadNodeApiKernels(): ZigKernels | null {
  const require = createRequire(import.meta.url);

  for (const addonPath of candidateAddonPaths()) {
    if (!existsSync(addonPath)) {
      continue;
    }

    let addon: NodeApiAddon;
    try {
      addon = require(addonPath) as NodeApiAddon;
    } catch {
      // Best effort: an addon that fails to load just means "try the next one".
      continue;
    }

    for (const libraryPath of candidateLibraryPaths()) {
      if (!existsSync(libraryPath)) {
        continue;
      }

      try {
        if (!addon.loadLibrary(libraryPath)) {
          continue;
        }

        const abiVersion = addon.abiVersion();
        if (abiVersion !== EXPECTED_ABI_VERSION) {
          // Do not keep an incompatible library loaded inside the addon.
          addon.unloadLibrary?.();
          continue;
        }

        return {
          linearModelCreate: addon.linearModelCreate ?? null,
          linearModelDestroy: addon.linearModelDestroy ?? null,
          linearModelFit: addon.linearModelFit ?? null,
          linearModelPredict: null,
          linearModelCopyCoefficients: addon.linearModelCopyCoefficients ?? null,
          linearModelGetIntercept: addon.linearModelGetIntercept ?? null,
          logisticModelCreate: addon.logisticModelCreate ?? null,
          logisticModelDestroy: addon.logisticModelDestroy ?? null,
          logisticModelFit: addon.logisticModelFit ?? null,
          logisticModelFitLbfgs: addon.logisticModelFitLbfgs ?? null,
          logisticModelPredictProba: null,
          logisticModelPredict: null,
          logisticModelCopyCoefficients: addon.logisticModelCopyCoefficients ?? null,
          logisticModelGetIntercept: addon.logisticModelGetIntercept ?? null,
          decisionTreeModelCreate: null,
          decisionTreeModelDestroy: null,
          decisionTreeModelFit: null,
          decisionTreeModelPredict: null,
          logisticTrainEpoch: null,
          logisticTrainEpochs: null,
          abiVersion,
          libraryPath: addon.loadedPath() ?? libraryPath,
        };
      } catch {
        // Best effort: a library that throws is skipped, but the remaining
        // library candidates are still tried with this addon.
        continue;
      }
    }
  }

  return null;
}
287
+
288
/**
 * Builds a `ZigKernels` record for `libraryPath` in which every kernel is
 * unavailable and the ABI version is unknown.
 */
function unavailableKernels(libraryPath: string): ZigKernels {
  return {
    linearModelCreate: null,
    linearModelDestroy: null,
    linearModelFit: null,
    linearModelPredict: null,
    linearModelCopyCoefficients: null,
    linearModelGetIntercept: null,
    logisticModelCreate: null,
    logisticModelDestroy: null,
    logisticModelFit: null,
    logisticModelFitLbfgs: null,
    logisticModelPredictProba: null,
    logisticModelPredict: null,
    logisticModelCopyCoefficients: null,
    logisticModelGetIntercept: null,
    decisionTreeModelCreate: null,
    decisionTreeModelDestroy: null,
    decisionTreeModelFit: null,
    decisionTreeModelPredict: null,
    logisticTrainEpoch: null,
    logisticTrainEpochs: null,
    abiVersion: null,
    libraryPath,
  };
}

/**
 * Resolves the native kernels, caching the outcome after the first lookup.
 *
 * Resolution order:
 *  1. Returns `null` (without caching) when the Zig backend is disabled.
 *  2. Unless `BUN_SCIKIT_NATIVE_BRIDGE` is `"ffi"`, tries the Node-API addon.
 *  3. Otherwise opens each existing library candidate through `bun:ffi`,
 *     trying three symbol sets from richest to poorest: the full model API,
 *     then both epoch kernels, then the single-epoch kernel only.
 *
 * A library that opens with the full symbol set but reports an ABI version
 * other than `EXPECTED_ABI_VERSION` is closed and skipped entirely; the
 * smaller symbol sets are not tried for it.
 *
 * @returns The resolved kernels, or `null` when nothing usable was found.
 */
export function getZigKernels(): ZigKernels | null {
  if (!isZigBackendEnabled()) {
    return null;
  }

  if (cachedKernels !== undefined) {
    return cachedKernels;
  }

  const bridgePreference = process.env.BUN_SCIKIT_NATIVE_BRIDGE?.trim().toLowerCase();
  if (bridgePreference !== "ffi") {
    const nodeApiKernels = tryLoadNodeApiKernels();
    if (nodeApiKernels) {
      cachedKernels = nodeApiKernels;
      return cachedKernels;
    }
  }

  for (const libraryPath of candidateLibraryPaths()) {
    if (!existsSync(libraryPath)) {
      continue;
    }

    try {
      try {
        // Tier 1: full model API. dlopen throws if any requested symbol is missing.
        const handle = dlopen(libraryPath, {
          linear_model_create: {
            args: ["usize", FFIType.u8],
            returns: "usize",
          },
          bun_scikit_abi_version: {
            args: [],
            returns: FFIType.u32,
          },
          linear_model_destroy: {
            args: ["usize"],
            returns: FFIType.void,
          },
          linear_model_fit: {
            args: ["usize", FFIType.ptr, FFIType.ptr, "usize", FFIType.f64],
            returns: FFIType.u8,
          },
          linear_model_predict: {
            args: ["usize", FFIType.ptr, "usize", FFIType.ptr],
            returns: FFIType.u8,
          },
          linear_model_copy_coefficients: {
            args: ["usize", FFIType.ptr],
            returns: FFIType.u8,
          },
          linear_model_get_intercept: {
            args: ["usize"],
            returns: FFIType.f64,
          },
          logistic_model_create: {
            args: ["usize", FFIType.u8],
            returns: "usize",
          },
          logistic_model_destroy: {
            args: ["usize"],
            returns: FFIType.void,
          },
          logistic_model_fit: {
            args: [
              "usize",
              FFIType.ptr,
              FFIType.ptr,
              "usize",
              FFIType.f64,
              FFIType.f64,
              "usize",
              FFIType.f64,
            ],
            returns: "usize",
          },
          logistic_model_fit_lbfgs: {
            args: [
              "usize",
              FFIType.ptr,
              FFIType.ptr,
              "usize",
              "usize",
              FFIType.f64,
              FFIType.f64,
              "usize",
            ],
            returns: "usize",
          },
          logistic_model_predict_proba: {
            args: ["usize", FFIType.ptr, "usize", FFIType.ptr],
            returns: FFIType.u8,
          },
          logistic_model_predict: {
            args: ["usize", FFIType.ptr, "usize", FFIType.ptr],
            returns: FFIType.u8,
          },
          logistic_model_copy_coefficients: {
            args: ["usize", FFIType.ptr],
            returns: FFIType.u8,
          },
          logistic_model_get_intercept: {
            args: ["usize"],
            returns: FFIType.f64,
          },
          logistic_train_epoch: {
            args: [
              FFIType.ptr,
              FFIType.ptr,
              "usize",
              "usize",
              FFIType.ptr,
              FFIType.ptr,
              FFIType.ptr,
              FFIType.f64,
              FFIType.f64,
              FFIType.u8,
            ],
            returns: FFIType.f64,
          },
          logistic_train_epochs: {
            args: [
              FFIType.ptr,
              FFIType.ptr,
              "usize",
              "usize",
              FFIType.ptr,
              FFIType.ptr,
              FFIType.ptr,
              FFIType.f64,
              FFIType.f64,
              FFIType.u8,
              "usize",
              FFIType.f64,
            ],
            returns: "usize",
          },
        });
        const library = handle as ZigKernelLibrary;

        const abiVersion = library.symbols.bun_scikit_abi_version?.() ?? null;
        if (abiVersion !== null && abiVersion !== EXPECTED_ABI_VERSION) {
          // Release the incompatible library instead of leaking the open handle.
          // Nothing from it is retained, so closing is safe.
          try {
            handle.close();
          } catch {
            // Best effort: a failed close must not send this library to the
            // fallback tiers below.
          }
          continue;
        }

        cachedKernels = {
          linearModelCreate: library.symbols.linear_model_create ?? null,
          linearModelDestroy: library.symbols.linear_model_destroy ?? null,
          linearModelFit: library.symbols.linear_model_fit ?? null,
          linearModelPredict: library.symbols.linear_model_predict ?? null,
          linearModelCopyCoefficients:
            library.symbols.linear_model_copy_coefficients ?? null,
          linearModelGetIntercept: library.symbols.linear_model_get_intercept ?? null,
          logisticModelCreate: library.symbols.logistic_model_create ?? null,
          logisticModelDestroy: library.symbols.logistic_model_destroy ?? null,
          logisticModelFit: library.symbols.logistic_model_fit ?? null,
          logisticModelFitLbfgs: library.symbols.logistic_model_fit_lbfgs ?? null,
          logisticModelPredictProba: library.symbols.logistic_model_predict_proba ?? null,
          logisticModelPredict: library.symbols.logistic_model_predict ?? null,
          logisticModelCopyCoefficients:
            library.symbols.logistic_model_copy_coefficients ?? null,
          logisticModelGetIntercept: library.symbols.logistic_model_get_intercept ?? null,
          // NOTE(review): the symbol set above does not request any
          // decision_tree_* symbol, so these presumably always resolve to null
          // on the FFI path. Confirm the native signatures before adding them.
          decisionTreeModelCreate: library.symbols.decision_tree_model_create ?? null,
          decisionTreeModelDestroy: library.symbols.decision_tree_model_destroy ?? null,
          decisionTreeModelFit: library.symbols.decision_tree_model_fit ?? null,
          decisionTreeModelPredict: library.symbols.decision_tree_model_predict ?? null,
          logisticTrainEpoch: library.symbols.logistic_train_epoch ?? null,
          logisticTrainEpochs: library.symbols.logistic_train_epochs ?? null,
          abiVersion,
          libraryPath,
        };

        return cachedKernels;
      } catch {
        try {
          // Tier 2: only the two handle-free epoch kernels.
          const library = dlopen(libraryPath, {
            logistic_train_epoch: {
              args: [
                FFIType.ptr,
                FFIType.ptr,
                "usize",
                "usize",
                FFIType.ptr,
                FFIType.ptr,
                FFIType.ptr,
                FFIType.f64,
                FFIType.f64,
                FFIType.u8,
              ],
              returns: FFIType.f64,
            },
            logistic_train_epochs: {
              args: [
                FFIType.ptr,
                FFIType.ptr,
                "usize",
                "usize",
                FFIType.ptr,
                FFIType.ptr,
                FFIType.ptr,
                FFIType.f64,
                FFIType.f64,
                FFIType.u8,
                "usize",
                FFIType.f64,
              ],
              returns: "usize",
            },
          }) as ZigKernelLibrary;

          cachedKernels = {
            ...unavailableKernels(libraryPath),
            logisticTrainEpoch: library.symbols.logistic_train_epoch ?? null,
            logisticTrainEpochs: library.symbols.logistic_train_epochs ?? null,
          };

          return cachedKernels;
        } catch {
          // Tier 3: single-epoch kernel only. A failure here propagates to the
          // outer catch, which moves on to the next library candidate.
          const library = dlopen(libraryPath, {
            logistic_train_epoch: {
              args: [
                FFIType.ptr,
                FFIType.ptr,
                "usize",
                "usize",
                FFIType.ptr,
                FFIType.ptr,
                FFIType.ptr,
                FFIType.f64,
                FFIType.f64,
                FFIType.u8,
              ],
              returns: FFIType.f64,
            },
          }) as ZigKernelLibrary;

          cachedKernels = {
            ...unavailableKernels(libraryPath),
            logisticTrainEpoch: library.symbols.logistic_train_epoch ?? null,
          };

          return cachedKernels;
        }
      }
    } catch {
      continue;
    }
  }

  cachedKernels = null;
  return cachedKernels;
}
@@ -0,0 +1,85 @@
1
+ import type { ClassificationModel, Matrix, Vector } from "../types";
2
+ import {
3
+ assertConsistentRowSize,
4
+ assertFiniteMatrix,
5
+ assertFiniteVector,
6
+ validateClassificationInputs,
7
+ } from "../utils/validation";
8
+ import { accuracyScore } from "../metrics/classification";
9
+
10
/** Construction options for `KNeighborsClassifier`. */
export interface KNeighborsClassifierOptions {
  /** Number of nearest training samples that vote on each prediction. */
  nNeighbors?: number;
}
13
+
14
/**
 * Sums the squared component-wise differences between `a` and `b`.
 *
 * Iteration is driven by the length of `a`; extra trailing entries of `b`
 * are ignored.
 */
function squaredEuclideanDistance(a: Vector, b: Vector): number {
  let total = 0;
  for (const [position, component] of a.entries()) {
    const delta = component - b[position];
    total += delta * delta;
  }
  return total;
}
22
+
23
/**
 * Brute-force k-nearest-neighbors classifier for labels 0 and 1.
 *
 * `fit` stores a copy of the training data. `predict` ranks every training
 * row by squared Euclidean distance to the query and takes a vote among the
 * closest `nNeighbors`.
 */
export class KNeighborsClassifier implements ClassificationModel {
  /** Labels this classifier can emit. */
  classes_: Vector = [0, 1];
  private readonly nNeighbors: number;
  private XTrain: Matrix | null = null;
  private yTrain: Vector | null = null;

  /**
   * @param options Optional settings; `nNeighbors` defaults to 5.
   * @throws Error when `nNeighbors` is not an integer or is below 1.
   */
  constructor(options: KNeighborsClassifierOptions = {}) {
    const requested = options.nNeighbors ?? 5;
    const isPositiveInteger = Number.isInteger(requested) && requested >= 1;
    if (!isPositiveInteger) {
      throw new Error(`nNeighbors must be a positive integer. Got ${requested}.`);
    }
    this.nNeighbors = requested;
  }

  /**
   * Validates the inputs and stores copies of the training rows and labels.
   *
   * @throws Error when `nNeighbors` exceeds the number of training rows.
   * @returns This instance, for chaining.
   */
  fit(X: Matrix, y: Vector): this {
    validateClassificationInputs(X, y);

    const trainingSize = X.length;
    if (this.nNeighbors > trainingSize) {
      throw new Error(
        `nNeighbors (${this.nNeighbors}) cannot exceed training size (${trainingSize}).`,
      );
    }

    // Copy so that later mutation of the caller's arrays cannot affect the model.
    this.XTrain = X.map((row) => Array.from(row));
    this.yTrain = Array.from(y);
    return this;
  }

  /**
   * Predicts a 0/1 label for every row of `X`.
   *
   * @throws Error when the model is unfitted or the feature count differs
   *   from the training data.
   */
  predict(X: Matrix): Vector {
    const trainRows = this.XTrain;
    const trainLabels = this.yTrain;
    if (!trainRows || !trainLabels) {
      throw new Error("KNeighborsClassifier has not been fitted.");
    }

    assertConsistentRowSize(X);
    assertFiniteMatrix(X);

    const expectedWidth = trainRows[0].length;
    if (X[0].length !== expectedWidth) {
      throw new Error(
        `Feature size mismatch. Expected ${expectedWidth}, got ${X[0].length}.`,
      );
    }

    const k = this.nNeighbors;
    return X.map((sample) => {
      const ranked = trainRows.map((row, idx) => ({
        distance: squaredEuclideanDistance(sample, row),
        label: trainLabels[idx],
      }));
      ranked.sort((left, right) => left.distance - right.distance);

      const positiveVotes = ranked
        .slice(0, k)
        .filter((neighbor) => neighbor.label === 1).length;

      // An exact split of the vote is resolved in favor of label 1.
      if (positiveVotes * 2 >= k) {
        return 1;
      }
      return 0;
    });
  }

  /** Accuracy of `predict(X)` against the reference labels `y`. */
  score(X: Matrix, y: Vector): number {
    assertFiniteVector(y);
    const predicted = this.predict(X);
    return accuracyScore(y, predicted);
  }
}