bun-scikit 0.1.3 → 0.1.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -292,3 +292,33 @@ export function classificationReport(
292
292
  },
293
293
  };
294
294
  }
295
+
296
+ export function balancedAccuracyScore(yTrue: number[], yPred: number[]): number {
297
+ const report = classificationReport(yTrue, yPred);
298
+ return report.macroAvg.recall;
299
+ }
300
+
301
+ export function matthewsCorrcoef(
302
+ yTrue: number[],
303
+ yPred: number[],
304
+ positiveLabel = 1,
305
+ ): number {
306
+ const { tp, fp, fn, tn } = confusionCounts(yTrue, yPred, positiveLabel);
307
+ const numerator = tp * tn - fp * fn;
308
+ const denominator = Math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn));
309
+ if (denominator === 0) {
310
+ return 0;
311
+ }
312
+ return numerator / denominator;
313
+ }
314
+
315
+ export function brierScoreLoss(yTrue: number[], yPredProb: number[]): number {
316
+ validateInputs(yTrue, yPredProb);
317
+ validateBinaryTargets(yTrue);
318
+ let total = 0;
319
+ for (let i = 0; i < yTrue.length; i += 1) {
320
+ const diff = yPredProb[i] - yTrue[i];
321
+ total += diff * diff;
322
+ }
323
+ return total / yTrue.length;
324
+ }
@@ -49,3 +49,43 @@ export function r2Score(yTrue: number[], yPred: number[]): number {
49
49
 
50
50
  return 1 - ssRes / ssTot;
51
51
  }
52
+
53
+ export function meanAbsolutePercentageError(yTrue: number[], yPred: number[]): number {
54
+ validateInputs(yTrue, yPred);
55
+ let total = 0;
56
+ for (let i = 0; i < yTrue.length; i += 1) {
57
+ const denom = Math.max(Math.abs(yTrue[i]), 1e-12);
58
+ total += Math.abs((yTrue[i] - yPred[i]) / denom);
59
+ }
60
+ return total / yTrue.length;
61
+ }
62
+
63
+ export function explainedVarianceScore(yTrue: number[], yPred: number[]): number {
64
+ validateInputs(yTrue, yPred);
65
+ const n = yTrue.length;
66
+ const yTrueMean = mean(yTrue);
67
+ const residuals = new Array<number>(n);
68
+ let residualMean = 0;
69
+ for (let i = 0; i < n; i += 1) {
70
+ const r = yTrue[i] - yPred[i];
71
+ residuals[i] = r;
72
+ residualMean += r;
73
+ }
74
+ residualMean /= n;
75
+
76
+ let varTrue = 0;
77
+ let varResidual = 0;
78
+ for (let i = 0; i < n; i += 1) {
79
+ const centeredY = yTrue[i] - yTrueMean;
80
+ const centeredR = residuals[i] - residualMean;
81
+ varTrue += centeredY * centeredY;
82
+ varResidual += centeredR * centeredR;
83
+ }
84
+
85
+ varTrue /= n;
86
+ varResidual /= n;
87
+ if (varTrue === 0) {
88
+ return varResidual === 0 ? 1 : 0;
89
+ }
90
+ return 1 - varResidual / varTrue;
91
+ }
@@ -0,0 +1,269 @@
1
+ import type { Matrix, Vector } from "../types";
2
+ import { accuracyScore, f1Score, precisionScore, recallScore } from "../metrics/classification";
3
+ import { meanSquaredError, r2Score } from "../metrics/regression";
4
+ import {
5
+ crossValScore,
6
+ type BuiltInScoring,
7
+ type CrossValEstimator,
8
+ type CrossValSplitter,
9
+ type ScoringFn,
10
+ } from "./crossValScore";
11
+
12
/** Candidate values for each hyper-parameter; one value is drawn per iteration. */
export type ParamDistributions = Record<string, readonly unknown[]>;

/** Configuration for {@link RandomizedSearchCV}. */
export interface RandomizedSearchCVOptions {
  /** Fold count, or a custom cross-validation splitter. */
  cv?: number | CrossValSplitter;
  /** Built-in metric name or a custom scoring function. */
  scoring?: BuiltInScoring | ScoringFn;
  /** Refit the best candidate on the full data after the search. Default: true. */
  refit?: boolean;
  /** "raise" to propagate candidate failures, or a numeric fallback score. Default: "raise". */
  errorScore?: "raise" | number;
  /** Number of parameter settings sampled. Default: 10. */
  nIter?: number;
  /** Seed for the parameter sampler. Default: 42. */
  randomState?: number;
}

/** One row of cvResults_: the outcome of evaluating a single sampled candidate. */
export interface RandomizedSearchResultRow {
  /** The sampled parameter setting (shallow copy). */
  params: Record<string, unknown>;
  /** Per-fold scores; a single-element [errorScore] array for failed candidates. */
  splitScores: number[];
  meanTestScore: number;
  stdTestScore: number;
  /** 1-based rank; 1 is the best candidate. */
  rank: number;
  status: "ok" | "error";
  /** Present only when status is "error". */
  errorMessage?: string;
}
32
+
33
+ function mean(values: number[]): number {
34
+ let sum = 0;
35
+ for (let i = 0; i < values.length; i += 1) {
36
+ sum += values[i];
37
+ }
38
+ return sum / values.length;
39
+ }
40
+
41
+ function std(values: number[]): number {
42
+ if (values.length < 2) {
43
+ return 0;
44
+ }
45
+ const avg = mean(values);
46
+ let sum = 0;
47
+ for (let i = 0; i < values.length; i += 1) {
48
+ const diff = values[i] - avg;
49
+ sum += diff * diff;
50
+ }
51
+ return Math.sqrt(sum / values.length);
52
+ }
53
+
54
+ function resolveBuiltInScorer(scoring: BuiltInScoring): ScoringFn {
55
+ switch (scoring) {
56
+ case "accuracy":
57
+ return accuracyScore;
58
+ case "f1":
59
+ return f1Score;
60
+ case "precision":
61
+ return precisionScore;
62
+ case "recall":
63
+ return recallScore;
64
+ case "r2":
65
+ return r2Score;
66
+ case "mean_squared_error":
67
+ return meanSquaredError;
68
+ case "neg_mean_squared_error":
69
+ return (yTrue, yPred) => -meanSquaredError(yTrue, yPred);
70
+ default: {
71
+ const exhaustive: never = scoring;
72
+ throw new Error(`Unsupported scoring metric: ${exhaustive}`);
73
+ }
74
+ }
75
+ }
76
+
77
+ function isLossMetric(scoring: BuiltInScoring | ScoringFn | undefined): boolean {
78
+ return scoring === "mean_squared_error";
79
+ }
80
+
81
+ class Mulberry32 {
82
+ private state: number;
83
+
84
+ constructor(seed: number) {
85
+ this.state = seed >>> 0;
86
+ }
87
+
88
+ next(): number {
89
+ this.state = (this.state + 0x6d2b79f5) >>> 0;
90
+ let t = this.state ^ (this.state >>> 15);
91
+ t = Math.imul(t, this.state | 1);
92
+ t ^= t + Math.imul(t ^ (t >>> 7), t | 61);
93
+ return ((t ^ (t >>> 14)) >>> 0) / 4294967296;
94
+ }
95
+
96
+ nextInt(maxExclusive: number): number {
97
+ return Math.floor(this.next() * maxExclusive);
98
+ }
99
+ }
100
+
101
+ function sampleParams(
102
+ distributions: ParamDistributions,
103
+ nIter: number,
104
+ randomState: number,
105
+ ): Record<string, unknown>[] {
106
+ const keys = Object.keys(distributions);
107
+ if (keys.length === 0) {
108
+ throw new Error("paramDistributions must include at least one parameter.");
109
+ }
110
+ for (let i = 0; i < keys.length; i += 1) {
111
+ const values = distributions[keys[i]];
112
+ if (!Array.isArray(values) || values.length === 0) {
113
+ throw new Error(`paramDistributions '${keys[i]}' must be a non-empty array.`);
114
+ }
115
+ }
116
+
117
+ const rng = new Mulberry32(randomState);
118
+ const out: Record<string, unknown>[] = [];
119
+ for (let i = 0; i < nIter; i += 1) {
120
+ const params: Record<string, unknown> = {};
121
+ for (let k = 0; k < keys.length; k += 1) {
122
+ const key = keys[k];
123
+ const values = distributions[key];
124
+ params[key] = values[rng.nextInt(values.length)];
125
+ }
126
+ out.push(params);
127
+ }
128
+ return out;
129
+ }
130
+
131
/**
 * Randomized hyper-parameter search with cross-validation, modelled on
 * scikit-learn's RandomizedSearchCV. Samples `nIter` candidate parameter
 * settings, scores each with crossValScore, ranks them, and (when refit=true)
 * retrains the best candidate on the full data.
 */
export class RandomizedSearchCV<TEstimator extends CrossValEstimator> {
  // Populated by fit(); trailing underscores mirror scikit-learn naming.
  bestEstimator_: TEstimator | null = null;
  bestParams_: Record<string, unknown> | null = null;
  bestScore_: number | null = null;
  cvResults_: RandomizedSearchResultRow[] = [];

  private readonly estimatorFactory: (params: Record<string, unknown>) => TEstimator;
  private readonly paramDistributions: ParamDistributions;
  private readonly cv?: number | CrossValSplitter;
  private readonly scoring?: BuiltInScoring | ScoringFn;
  private readonly refit: boolean;
  private readonly errorScore: "raise" | number;
  private readonly nIter: number;
  private readonly randomState: number;
  private isFitted = false;

  /**
   * @param estimatorFactory builds a fresh estimator for a given parameter set
   * @param paramDistributions candidate values per hyper-parameter
   * @param options search configuration; see RandomizedSearchCVOptions
   * @throws when estimatorFactory is not a function or nIter is not an integer >= 1
   */
  constructor(
    estimatorFactory: (params: Record<string, unknown>) => TEstimator,
    paramDistributions: ParamDistributions,
    options: RandomizedSearchCVOptions = {},
  ) {
    if (typeof estimatorFactory !== "function") {
      throw new Error("estimatorFactory must be a function.");
    }
    this.estimatorFactory = estimatorFactory;
    this.paramDistributions = paramDistributions;
    this.cv = options.cv;
    this.scoring = options.scoring;
    this.refit = options.refit ?? true;
    this.errorScore = options.errorScore ?? "raise";
    this.nIter = options.nIter ?? 10;
    this.randomState = options.randomState ?? 42;
    if (!Number.isInteger(this.nIter) || this.nIter < 1) {
      throw new Error(`nIter must be an integer >= 1. Got ${this.nIter}.`);
    }
  }

  /**
   * Cross-validate every sampled candidate, rank the results into cvResults_,
   * record bestParams_/bestScore_, and optionally refit the winner on (X, y).
   * @returns this, for chaining
   */
  fit(X: Matrix, y: Vector): this {
    const candidates = sampleParams(this.paramDistributions, this.nIter, this.randomState);
    // Loss metrics are minimized: negate them so "larger objective is
    // better" holds uniformly in the ranking below.
    const minimize = isLossMetric(this.scoring);
    const rows: RandomizedSearchResultRow[] = [];
    const objectiveScores: number[] = [];

    for (let candidateIndex = 0; candidateIndex < candidates.length; candidateIndex += 1) {
      const params = candidates[candidateIndex];
      try {
        const splitScores = crossValScore(
          () => this.estimatorFactory(params),
          X,
          y,
          { cv: this.cv, scoring: this.scoring },
        );
        const meanTestScore = mean(splitScores);
        rows.push({
          params: { ...params },
          splitScores,
          meanTestScore,
          stdTestScore: std(splitScores),
          rank: 0, // assigned after all candidates are scored
          status: "ok",
        });
        objectiveScores.push(minimize ? -meanTestScore : meanTestScore);
      } catch (error) {
        if (this.errorScore === "raise") {
          throw error;
        }
        // errorScore is numeric here: record the failure with the fallback
        // score so the candidate still participates in the ranking.
        rows.push({
          params: { ...params },
          splitScores: [this.errorScore],
          meanTestScore: this.errorScore,
          stdTestScore: 0,
          rank: 0,
          status: "error",
          errorMessage: error instanceof Error ? error.message : String(error),
        });
        objectiveScores.push(minimize ? -this.errorScore : this.errorScore);
      }
    }

    // Sort candidate indices by objective, descending; ties keep sampling
    // order. NOTE(review): a NaN mean score makes the comparator return NaN,
    // leaving the order unspecified — confirm scorers cannot return NaN.
    const order = Array.from({ length: rows.length }, (_, idx) => idx).sort((a, b) => {
      const delta = objectiveScores[b] - objectiveScores[a];
      if (delta !== 0) {
        return delta;
      }
      return a - b;
    });

    // rank is 1-based: order[0] is the best candidate.
    for (let rank = 0; rank < order.length; rank += 1) {
      rows[order[rank]].rank = rank + 1;
    }

    const bestIndex = order[0];
    this.bestParams_ = { ...rows[bestIndex].params };
    this.bestScore_ = rows[bestIndex].meanTestScore;
    this.cvResults_ = rows;

    if (this.refit) {
      const estimator = this.estimatorFactory(this.bestParams_);
      estimator.fit(X, y);
      this.bestEstimator_ = estimator;
    } else {
      this.bestEstimator_ = null;
    }

    this.isFitted = true;
    return this;
  }

  /**
   * Predict with the refitted best estimator.
   * @throws before fit(), or when refit was disabled
   */
  predict(X: Matrix): Vector {
    if (!this.isFitted) {
      throw new Error("RandomizedSearchCV has not been fitted.");
    }
    if (!this.refit || !this.bestEstimator_) {
      throw new Error("RandomizedSearchCV predict is unavailable when refit=false.");
    }
    return this.bestEstimator_.predict(X);
  }

  /**
   * Score the refitted best estimator on (X, y): uses the configured scoring
   * when present, otherwise the estimator's own score().
   * @throws before fit(), when refit was disabled, or when no scorer exists
   */
  score(X: Matrix, y: Vector): number {
    if (!this.isFitted) {
      throw new Error("RandomizedSearchCV has not been fitted.");
    }
    if (!this.refit || !this.bestEstimator_) {
      throw new Error("RandomizedSearchCV score is unavailable when refit=false.");
    }

    if (this.scoring) {
      const scorer =
        typeof this.scoring === "function" ? this.scoring : resolveBuiltInScorer(this.scoring);
      return scorer(y, this.bestEstimator_.predict(X));
    }

    if (typeof this.bestEstimator_.score === "function") {
      return this.bestEstimator_.score(X, y);
    }

    throw new Error("No scoring function available. Provide scoring in RandomizedSearchCV options.");
  }
}
@@ -27,6 +27,24 @@ using LogisticModelFitFn = std::size_t (*)(NativeHandle, const double*, const do
27
27
  using LogisticModelFitLbfgsFn = std::size_t (*)(NativeHandle, const double*, const double*, std::size_t, std::size_t, double, double, std::size_t);
28
28
  using LogisticModelCopyCoefficientsFn = std::uint8_t (*)(NativeHandle, double*);
29
29
  using LogisticModelGetInterceptFn = double (*)(NativeHandle);
30
// Function-pointer types for the decision-tree / random-forest symbols that
// are resolved at runtime from the native kernel library.
// create(max_depth, min_samples_split, min_samples_leaf, max_features_mode,
//        max_features_value, random_state, use_random_state, n_features) -> handle
using DecisionTreeModelCreateFn = NativeHandle (*)(std::size_t, std::size_t, std::size_t, std::uint8_t, std::size_t, std::uint32_t, std::uint8_t, std::size_t);
using DecisionTreeModelDestroyFn = void (*)(NativeHandle);
// fit(handle, x, y, n_samples, n_features, sample_indices, sample_count) -> status
using DecisionTreeModelFitFn = std::uint8_t (*)(NativeHandle, const double*, const std::uint8_t*, std::size_t, std::size_t, const std::uint32_t*, std::size_t);
// predict(handle, x, n_samples, n_features, out_labels) -> status
using DecisionTreeModelPredictFn = std::uint8_t (*)(NativeHandle, const double*, std::size_t, std::size_t, std::uint8_t*);
// create(n_estimators, max_depth, min_samples_split, min_samples_leaf,
//        max_features_mode, max_features_value, bootstrap, random_state,
//        use_random_state, n_features) -> handle
using RandomForestClassifierModelCreateFn = NativeHandle (*)(
    std::size_t,
    std::size_t,
    std::size_t,
    std::size_t,
    std::uint8_t,
    std::size_t,
    std::uint8_t,
    std::uint32_t,
    std::uint8_t,
    std::size_t);
using RandomForestClassifierModelDestroyFn = void (*)(NativeHandle);
// fit(handle, x, y, n_samples, n_features) -> status
using RandomForestClassifierModelFitFn = std::uint8_t (*)(NativeHandle, const double*, const std::uint8_t*, std::size_t, std::size_t);
// predict(handle, x, n_samples, n_features, out_labels) -> status
using RandomForestClassifierModelPredictFn = std::uint8_t (*)(NativeHandle, const double*, std::size_t, std::size_t, std::uint8_t*);
30
48
 
31
49
  struct KernelLibrary {
32
50
  #if defined(_WIN32)
@@ -47,6 +65,14 @@ struct KernelLibrary {
47
65
  LogisticModelFitLbfgsFn logistic_model_fit_lbfgs{nullptr};
48
66
  LogisticModelCopyCoefficientsFn logistic_model_copy_coefficients{nullptr};
49
67
  LogisticModelGetInterceptFn logistic_model_get_intercept{nullptr};
68
+ DecisionTreeModelCreateFn decision_tree_model_create{nullptr};
69
+ DecisionTreeModelDestroyFn decision_tree_model_destroy{nullptr};
70
+ DecisionTreeModelFitFn decision_tree_model_fit{nullptr};
71
+ DecisionTreeModelPredictFn decision_tree_model_predict{nullptr};
72
+ RandomForestClassifierModelCreateFn random_forest_classifier_model_create{nullptr};
73
+ RandomForestClassifierModelDestroyFn random_forest_classifier_model_destroy{nullptr};
74
+ RandomForestClassifierModelFitFn random_forest_classifier_model_fit{nullptr};
75
+ RandomForestClassifierModelPredictFn random_forest_classifier_model_predict{nullptr};
50
76
  };
51
77
 
52
78
  KernelLibrary g_library{};
@@ -138,6 +164,22 @@ Napi::Value LoadNativeLibrary(const Napi::CallbackInfo& info) {
138
164
  loadSymbol<LogisticModelCopyCoefficientsFn>("logistic_model_copy_coefficients");
139
165
  g_library.logistic_model_get_intercept =
140
166
  loadSymbol<LogisticModelGetInterceptFn>("logistic_model_get_intercept");
167
+ g_library.decision_tree_model_create =
168
+ loadSymbol<DecisionTreeModelCreateFn>("decision_tree_model_create");
169
+ g_library.decision_tree_model_destroy =
170
+ loadSymbol<DecisionTreeModelDestroyFn>("decision_tree_model_destroy");
171
+ g_library.decision_tree_model_fit =
172
+ loadSymbol<DecisionTreeModelFitFn>("decision_tree_model_fit");
173
+ g_library.decision_tree_model_predict =
174
+ loadSymbol<DecisionTreeModelPredictFn>("decision_tree_model_predict");
175
+ g_library.random_forest_classifier_model_create =
176
+ loadSymbol<RandomForestClassifierModelCreateFn>("random_forest_classifier_model_create");
177
+ g_library.random_forest_classifier_model_destroy =
178
+ loadSymbol<RandomForestClassifierModelDestroyFn>("random_forest_classifier_model_destroy");
179
+ g_library.random_forest_classifier_model_fit =
180
+ loadSymbol<RandomForestClassifierModelFitFn>("random_forest_classifier_model_fit");
181
+ g_library.random_forest_classifier_model_predict =
182
+ loadSymbol<RandomForestClassifierModelPredictFn>("random_forest_classifier_model_predict");
141
183
 
142
184
  return Napi::Boolean::New(env, true);
143
185
  }
@@ -423,6 +465,262 @@ Napi::Value LogisticModelGetIntercept(const Napi::CallbackInfo& info) {
423
465
  return Napi::Number::New(env, g_library.logistic_model_get_intercept(handle));
424
466
  }
425
467
 
468
// JS binding: constructs a native decision-tree model from eight numeric
// hyper-parameters and returns an opaque handle as a BigInt. Returns null
// (with a pending JS exception) on validation failure.
Napi::Value DecisionTreeModelCreate(const Napi::CallbackInfo& info) {
  const Napi::Env env = info.Env();
  if (!isLibraryLoaded(env)) {
    return env.Null();
  }
  if (!g_library.decision_tree_model_create) {
    throwError(env, "Symbol decision_tree_model_create is unavailable.");
    return env.Null();
  }
  if (info.Length() != 8 || !info[0].IsNumber() || !info[1].IsNumber() || !info[2].IsNumber() ||
      !info[3].IsNumber() || !info[4].IsNumber() || !info[5].IsNumber() || !info[6].IsNumber() ||
      !info[7].IsNumber()) {
    throwTypeError(env, "decisionTreeModelCreate(maxDepth, minSamplesSplit, minSamplesLeaf, maxFeaturesMode, maxFeaturesValue, randomState, useRandomState, nFeatures) expects eight numbers.");
    return env.Null();
  }

  // NOTE(review): Uint32Value() truncates to 32 bits before widening to
  // size_t, so values above 2^32-1 are not representable here.
  const std::size_t max_depth = static_cast<std::size_t>(info[0].As<Napi::Number>().Uint32Value());
  const std::size_t min_samples_split = static_cast<std::size_t>(info[1].As<Napi::Number>().Uint32Value());
  const std::size_t min_samples_leaf = static_cast<std::size_t>(info[2].As<Napi::Number>().Uint32Value());
  const std::uint8_t max_features_mode = static_cast<std::uint8_t>(info[3].As<Napi::Number>().Uint32Value());
  const std::size_t max_features_value = static_cast<std::size_t>(info[4].As<Napi::Number>().Uint32Value());
  const std::uint32_t random_state = static_cast<std::uint32_t>(info[5].As<Napi::Number>().Uint32Value());
  const std::uint8_t use_random_state = static_cast<std::uint8_t>(info[6].As<Napi::Number>().Uint32Value());
  const std::size_t n_features = static_cast<std::size_t>(info[7].As<Napi::Number>().Uint32Value());

  const NativeHandle handle = g_library.decision_tree_model_create(
      max_depth,
      min_samples_split,
      min_samples_leaf,
      max_features_mode,
      max_features_value,
      random_state,
      use_random_state,
      n_features);
  return Napi::BigInt::New(env, static_cast<std::uint64_t>(handle));
}
504
+
505
// JS binding: releases a decision-tree handle previously returned by
// decisionTreeModelCreate. The handle must not be used afterwards.
Napi::Value DecisionTreeModelDestroy(const Napi::CallbackInfo& info) {
  const Napi::Env env = info.Env();
  if (!isLibraryLoaded(env)) {
    return env.Null();
  }
  if (!g_library.decision_tree_model_destroy) {
    throwError(env, "Symbol decision_tree_model_destroy is unavailable.");
    return env.Null();
  }
  if (info.Length() != 1) {
    throwTypeError(env, "decisionTreeModelDestroy(handle) expects one BigInt.");
    return env.Null();
  }
  // handleFromBigInt raises a JS exception for non-BigInt input.
  const NativeHandle handle = handleFromBigInt(info[0], env);
  if (env.IsExceptionPending()) {
    return env.Null();
  }
  g_library.decision_tree_model_destroy(handle);
  return env.Undefined();
}
525
+
526
// JS binding: trains the decision tree on row-major x (nSamples x nFeatures)
// with uint8 labels y, restricted to the rows listed in
// sampleIndices[0..sampleCount). Returns the native status code as a number.
Napi::Value DecisionTreeModelFit(const Napi::CallbackInfo& info) {
  const Napi::Env env = info.Env();
  if (!isLibraryLoaded(env)) {
    return env.Null();
  }
  if (!g_library.decision_tree_model_fit) {
    throwError(env, "Symbol decision_tree_model_fit is unavailable.");
    return env.Null();
  }
  if (info.Length() != 7 || !info[1].IsTypedArray() || !info[2].IsTypedArray() ||
      !info[3].IsNumber() || !info[4].IsNumber() || !info[5].IsTypedArray() || !info[6].IsNumber()) {
    throwTypeError(env, "decisionTreeModelFit(handle, x, y, nSamples, nFeatures, sampleIndices, sampleCount) has invalid arguments.");
    return env.Null();
  }

  const NativeHandle handle = handleFromBigInt(info[0], env);
  if (env.IsExceptionPending()) {
    return env.Null();
  }
  // NOTE(review): typed-array lengths are not checked against
  // nSamples/nFeatures/sampleCount — presumably the TS wrapper guarantees
  // them; confirm, since a mismatch would over-read the buffers natively.
  auto x = info[1].As<Napi::Float64Array>();
  auto y = info[2].As<Napi::Uint8Array>();
  const std::size_t n_samples = static_cast<std::size_t>(info[3].As<Napi::Number>().Uint32Value());
  const std::size_t n_features = static_cast<std::size_t>(info[4].As<Napi::Number>().Uint32Value());
  auto sample_indices = info[5].As<Napi::Uint32Array>();
  const std::size_t sample_count = static_cast<std::size_t>(info[6].As<Napi::Number>().Uint32Value());

  const std::uint8_t status = g_library.decision_tree_model_fit(
      handle,
      x.Data(),
      y.Data(),
      n_samples,
      n_features,
      sample_indices.Data(),
      sample_count);
  return Napi::Number::New(env, status);
}
562
+
563
// JS binding: predicts uint8 labels for row-major x (nSamples x nFeatures),
// writing them into the caller-provided outLabels buffer. Returns the native
// status code as a number.
Napi::Value DecisionTreeModelPredict(const Napi::CallbackInfo& info) {
  const Napi::Env env = info.Env();
  if (!isLibraryLoaded(env)) {
    return env.Null();
  }
  if (!g_library.decision_tree_model_predict) {
    throwError(env, "Symbol decision_tree_model_predict is unavailable.");
    return env.Null();
  }
  if (info.Length() != 5 || !info[1].IsTypedArray() || !info[2].IsNumber() || !info[3].IsNumber() ||
      !info[4].IsTypedArray()) {
    throwTypeError(env, "decisionTreeModelPredict(handle, x, nSamples, nFeatures, outLabels) has invalid arguments.");
    return env.Null();
  }

  const NativeHandle handle = handleFromBigInt(info[0], env);
  if (env.IsExceptionPending()) {
    return env.Null();
  }
  // NOTE(review): outLabels must hold at least nSamples entries — not
  // validated here; confirm the TS wrapper sizes it.
  auto x = info[1].As<Napi::Float64Array>();
  const std::size_t n_samples = static_cast<std::size_t>(info[2].As<Napi::Number>().Uint32Value());
  const std::size_t n_features = static_cast<std::size_t>(info[3].As<Napi::Number>().Uint32Value());
  auto out_labels = info[4].As<Napi::Uint8Array>();

  const std::uint8_t status = g_library.decision_tree_model_predict(
      handle,
      x.Data(),
      n_samples,
      n_features,
      out_labels.Data());
  return Napi::Number::New(env, status);
}
595
+
596
// JS binding: constructs a native random-forest classifier from ten numeric
// hyper-parameters and returns an opaque handle as a BigInt.
Napi::Value RandomForestClassifierModelCreate(const Napi::CallbackInfo& info) {
  const Napi::Env env = info.Env();
  if (!isLibraryLoaded(env)) {
    return env.Null();
  }
  if (!g_library.random_forest_classifier_model_create) {
    throwError(env, "Symbol random_forest_classifier_model_create is unavailable.");
    return env.Null();
  }
  if (info.Length() != 10 || !info[0].IsNumber() || !info[1].IsNumber() || !info[2].IsNumber() ||
      !info[3].IsNumber() || !info[4].IsNumber() || !info[5].IsNumber() || !info[6].IsNumber() ||
      !info[7].IsNumber() || !info[8].IsNumber() || !info[9].IsNumber()) {
    throwTypeError(env, "randomForestClassifierModelCreate(nEstimators, maxDepth, minSamplesSplit, minSamplesLeaf, maxFeaturesMode, maxFeaturesValue, bootstrap, randomState, useRandomState, nFeatures) expects ten numbers.");
    return env.Null();
  }

  // NOTE(review): Uint32Value() truncates to 32 bits before widening to
  // size_t, matching the decision-tree binding above.
  const std::size_t n_estimators = static_cast<std::size_t>(info[0].As<Napi::Number>().Uint32Value());
  const std::size_t max_depth = static_cast<std::size_t>(info[1].As<Napi::Number>().Uint32Value());
  const std::size_t min_samples_split = static_cast<std::size_t>(info[2].As<Napi::Number>().Uint32Value());
  const std::size_t min_samples_leaf = static_cast<std::size_t>(info[3].As<Napi::Number>().Uint32Value());
  const std::uint8_t max_features_mode = static_cast<std::uint8_t>(info[4].As<Napi::Number>().Uint32Value());
  const std::size_t max_features_value = static_cast<std::size_t>(info[5].As<Napi::Number>().Uint32Value());
  const std::uint8_t bootstrap = static_cast<std::uint8_t>(info[6].As<Napi::Number>().Uint32Value());
  const std::uint32_t random_state = static_cast<std::uint32_t>(info[7].As<Napi::Number>().Uint32Value());
  const std::uint8_t use_random_state = static_cast<std::uint8_t>(info[8].As<Napi::Number>().Uint32Value());
  const std::size_t n_features = static_cast<std::size_t>(info[9].As<Napi::Number>().Uint32Value());

  const NativeHandle handle = g_library.random_forest_classifier_model_create(
      n_estimators,
      max_depth,
      min_samples_split,
      min_samples_leaf,
      max_features_mode,
      max_features_value,
      bootstrap,
      random_state,
      use_random_state,
      n_features);
  return Napi::BigInt::New(env, static_cast<std::uint64_t>(handle));
}
636
+
637
// JS binding: releases a random-forest handle previously returned by
// randomForestClassifierModelCreate. The handle must not be used afterwards.
Napi::Value RandomForestClassifierModelDestroy(const Napi::CallbackInfo& info) {
  const Napi::Env env = info.Env();
  if (!isLibraryLoaded(env)) {
    return env.Null();
  }
  if (!g_library.random_forest_classifier_model_destroy) {
    throwError(env, "Symbol random_forest_classifier_model_destroy is unavailable.");
    return env.Null();
  }
  if (info.Length() != 1) {
    throwTypeError(env, "randomForestClassifierModelDestroy(handle) expects one BigInt.");
    return env.Null();
  }
  // handleFromBigInt raises a JS exception for non-BigInt input.
  const NativeHandle handle = handleFromBigInt(info[0], env);
  if (env.IsExceptionPending()) {
    return env.Null();
  }
  g_library.random_forest_classifier_model_destroy(handle);
  return env.Undefined();
}
657
+
658
// JS binding: trains the random forest on row-major x (nSamples x nFeatures)
// with uint8 labels y. Returns the native status code as a number.
Napi::Value RandomForestClassifierModelFit(const Napi::CallbackInfo& info) {
  const Napi::Env env = info.Env();
  if (!isLibraryLoaded(env)) {
    return env.Null();
  }
  if (!g_library.random_forest_classifier_model_fit) {
    throwError(env, "Symbol random_forest_classifier_model_fit is unavailable.");
    return env.Null();
  }
  if (info.Length() != 5 || !info[1].IsTypedArray() || !info[2].IsTypedArray() ||
      !info[3].IsNumber() || !info[4].IsNumber()) {
    throwTypeError(env, "randomForestClassifierModelFit(handle, x, y, nSamples, nFeatures) has invalid arguments.");
    return env.Null();
  }

  const NativeHandle handle = handleFromBigInt(info[0], env);
  if (env.IsExceptionPending()) {
    return env.Null();
  }
  // NOTE(review): typed-array lengths are not checked against
  // nSamples/nFeatures — presumably the TS wrapper guarantees them; confirm.
  auto x = info[1].As<Napi::Float64Array>();
  auto y = info[2].As<Napi::Uint8Array>();
  const std::size_t n_samples = static_cast<std::size_t>(info[3].As<Napi::Number>().Uint32Value());
  const std::size_t n_features = static_cast<std::size_t>(info[4].As<Napi::Number>().Uint32Value());

  const std::uint8_t status = g_library.random_forest_classifier_model_fit(
      handle,
      x.Data(),
      y.Data(),
      n_samples,
      n_features);
  return Napi::Number::New(env, status);
}
690
+
691
// JS binding: predicts uint8 labels for row-major x (nSamples x nFeatures),
// writing them into the caller-provided outLabels buffer. Returns the native
// status code as a number.
Napi::Value RandomForestClassifierModelPredict(const Napi::CallbackInfo& info) {
  const Napi::Env env = info.Env();
  if (!isLibraryLoaded(env)) {
    return env.Null();
  }
  if (!g_library.random_forest_classifier_model_predict) {
    throwError(env, "Symbol random_forest_classifier_model_predict is unavailable.");
    return env.Null();
  }
  if (info.Length() != 5 || !info[1].IsTypedArray() || !info[2].IsNumber() || !info[3].IsNumber() ||
      !info[4].IsTypedArray()) {
    throwTypeError(env, "randomForestClassifierModelPredict(handle, x, nSamples, nFeatures, outLabels) has invalid arguments.");
    return env.Null();
  }

  const NativeHandle handle = handleFromBigInt(info[0], env);
  if (env.IsExceptionPending()) {
    return env.Null();
  }
  // NOTE(review): outLabels must hold at least nSamples entries — not
  // validated here; confirm the TS wrapper sizes it.
  auto x = info[1].As<Napi::Float64Array>();
  const std::size_t n_samples = static_cast<std::size_t>(info[2].As<Napi::Number>().Uint32Value());
  const std::size_t n_features = static_cast<std::size_t>(info[3].As<Napi::Number>().Uint32Value());
  auto out_labels = info[4].As<Napi::Uint8Array>();

  const std::uint8_t status = g_library.random_forest_classifier_model_predict(
      handle,
      x.Data(),
      n_samples,
      n_features,
      out_labels.Data());
  return Napi::Number::New(env, status);
}
723
+
426
724
  Napi::Object Init(Napi::Env env, Napi::Object exports) {
427
725
  exports.Set("loadLibrary", Napi::Function::New(env, LoadNativeLibrary));
428
726
  exports.Set("unloadLibrary", Napi::Function::New(env, UnloadLibrary));
@@ -442,6 +740,15 @@ Napi::Object Init(Napi::Env env, Napi::Object exports) {
442
740
  exports.Set("logisticModelCopyCoefficients", Napi::Function::New(env, LogisticModelCopyCoefficients));
443
741
  exports.Set("logisticModelGetIntercept", Napi::Function::New(env, LogisticModelGetIntercept));
444
742
 
743
+ exports.Set("decisionTreeModelCreate", Napi::Function::New(env, DecisionTreeModelCreate));
744
+ exports.Set("decisionTreeModelDestroy", Napi::Function::New(env, DecisionTreeModelDestroy));
745
+ exports.Set("decisionTreeModelFit", Napi::Function::New(env, DecisionTreeModelFit));
746
+ exports.Set("decisionTreeModelPredict", Napi::Function::New(env, DecisionTreeModelPredict));
747
+ exports.Set("randomForestClassifierModelCreate", Napi::Function::New(env, RandomForestClassifierModelCreate));
748
+ exports.Set("randomForestClassifierModelDestroy", Napi::Function::New(env, RandomForestClassifierModelDestroy));
749
+ exports.Set("randomForestClassifierModelFit", Napi::Function::New(env, RandomForestClassifierModelFit));
750
+ exports.Set("randomForestClassifierModelPredict", Napi::Function::New(env, RandomForestClassifierModelPredict));
751
+
445
752
  return exports;
446
753
  }
447
754