bun-scikit 0.1.3 → 0.1.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +73 -137
- package/package.json +2 -2
- package/scripts/check-benchmark-health.ts +62 -1
- package/scripts/sync-benchmark-readme.ts +56 -0
- package/src/dummy/DummyClassifier.ts +190 -0
- package/src/dummy/DummyRegressor.ts +108 -0
- package/src/feature_selection/VarianceThreshold.ts +88 -0
- package/src/index.ts +23 -0
- package/src/metrics/classification.ts +30 -0
- package/src/metrics/regression.ts +40 -0
- package/src/model_selection/RandomizedSearchCV.ts +269 -0
- package/src/native/node-addon/bun_scikit_addon.cpp +149 -0
- package/src/native/zigKernels.ts +33 -4
- package/src/preprocessing/Binarizer.ts +46 -0
- package/src/preprocessing/LabelEncoder.ts +62 -0
- package/src/preprocessing/MaxAbsScaler.ts +77 -0
- package/src/preprocessing/Normalizer.ts +66 -0
- package/src/tree/DecisionTreeClassifier.ts +146 -3
- package/zig/kernels.zig +63 -40
|
@@ -27,6 +27,10 @@ using LogisticModelFitFn = std::size_t (*)(NativeHandle, const double*, const do
|
|
|
27
27
|
using LogisticModelFitLbfgsFn = std::size_t (*)(NativeHandle, const double*, const double*, std::size_t, std::size_t, double, double, std::size_t);
|
|
28
28
|
using LogisticModelCopyCoefficientsFn = std::uint8_t (*)(NativeHandle, double*);
|
|
29
29
|
using LogisticModelGetInterceptFn = double (*)(NativeHandle);
|
|
30
|
+
using DecisionTreeModelCreateFn = NativeHandle (*)(std::size_t, std::size_t, std::size_t, std::uint8_t, std::size_t, std::uint32_t, std::uint8_t, std::size_t);
|
|
31
|
+
using DecisionTreeModelDestroyFn = void (*)(NativeHandle);
|
|
32
|
+
using DecisionTreeModelFitFn = std::uint8_t (*)(NativeHandle, const double*, const std::uint8_t*, std::size_t, std::size_t, const std::uint32_t*, std::size_t);
|
|
33
|
+
using DecisionTreeModelPredictFn = std::uint8_t (*)(NativeHandle, const double*, std::size_t, std::size_t, std::uint8_t*);
|
|
30
34
|
|
|
31
35
|
struct KernelLibrary {
|
|
32
36
|
#if defined(_WIN32)
|
|
@@ -47,6 +51,10 @@ struct KernelLibrary {
|
|
|
47
51
|
LogisticModelFitLbfgsFn logistic_model_fit_lbfgs{nullptr};
|
|
48
52
|
LogisticModelCopyCoefficientsFn logistic_model_copy_coefficients{nullptr};
|
|
49
53
|
LogisticModelGetInterceptFn logistic_model_get_intercept{nullptr};
|
|
54
|
+
DecisionTreeModelCreateFn decision_tree_model_create{nullptr};
|
|
55
|
+
DecisionTreeModelDestroyFn decision_tree_model_destroy{nullptr};
|
|
56
|
+
DecisionTreeModelFitFn decision_tree_model_fit{nullptr};
|
|
57
|
+
DecisionTreeModelPredictFn decision_tree_model_predict{nullptr};
|
|
50
58
|
};
|
|
51
59
|
|
|
52
60
|
KernelLibrary g_library{};
|
|
@@ -138,6 +146,14 @@ Napi::Value LoadNativeLibrary(const Napi::CallbackInfo& info) {
|
|
|
138
146
|
loadSymbol<LogisticModelCopyCoefficientsFn>("logistic_model_copy_coefficients");
|
|
139
147
|
g_library.logistic_model_get_intercept =
|
|
140
148
|
loadSymbol<LogisticModelGetInterceptFn>("logistic_model_get_intercept");
|
|
149
|
+
g_library.decision_tree_model_create =
|
|
150
|
+
loadSymbol<DecisionTreeModelCreateFn>("decision_tree_model_create");
|
|
151
|
+
g_library.decision_tree_model_destroy =
|
|
152
|
+
loadSymbol<DecisionTreeModelDestroyFn>("decision_tree_model_destroy");
|
|
153
|
+
g_library.decision_tree_model_fit =
|
|
154
|
+
loadSymbol<DecisionTreeModelFitFn>("decision_tree_model_fit");
|
|
155
|
+
g_library.decision_tree_model_predict =
|
|
156
|
+
loadSymbol<DecisionTreeModelPredictFn>("decision_tree_model_predict");
|
|
141
157
|
|
|
142
158
|
return Napi::Boolean::New(env, true);
|
|
143
159
|
}
|
|
@@ -423,6 +439,134 @@ Napi::Value LogisticModelGetIntercept(const Napi::CallbackInfo& info) {
|
|
|
423
439
|
return Napi::Number::New(env, g_library.logistic_model_get_intercept(handle));
|
|
424
440
|
}
|
|
425
441
|
|
|
442
|
+
// Allocates a native decision-tree model and returns its handle as a BigInt.
// JS signature: decisionTreeModelCreate(maxDepth, minSamplesSplit, minSamplesLeaf,
//   maxFeaturesMode, maxFeaturesValue, randomState, useRandomState, nFeatures)
// All eight arguments must be JS numbers; each is narrowed to the width the
// native ABI expects before the call.
Napi::Value DecisionTreeModelCreate(const Napi::CallbackInfo& info) {
  const Napi::Env env = info.Env();
  if (!isLibraryLoaded(env)) {
    return env.Null();
  }
  if (g_library.decision_tree_model_create == nullptr) {
    throwError(env, "Symbol decision_tree_model_create is unavailable.");
    return env.Null();
  }

  bool args_ok = info.Length() == 8;
  for (std::size_t i = 0; args_ok && i < 8; ++i) {
    args_ok = info[i].IsNumber();
  }
  if (!args_ok) {
    throwTypeError(env, "decisionTreeModelCreate(maxDepth, minSamplesSplit, minSamplesLeaf, maxFeaturesMode, maxFeaturesValue, randomState, useRandomState, nFeatures) expects eight numbers.");
    return env.Null();
  }

  // Helper: read argument i as an unsigned 32-bit integer.
  const auto arg_u32 = [&info](std::size_t i) {
    return info[i].As<Napi::Number>().Uint32Value();
  };

  const NativeHandle handle = g_library.decision_tree_model_create(
      static_cast<std::size_t>(arg_u32(0)),    // max_depth
      static_cast<std::size_t>(arg_u32(1)),    // min_samples_split
      static_cast<std::size_t>(arg_u32(2)),    // min_samples_leaf
      static_cast<std::uint8_t>(arg_u32(3)),   // max_features_mode
      static_cast<std::size_t>(arg_u32(4)),    // max_features_value
      static_cast<std::uint32_t>(arg_u32(5)),  // random_state
      static_cast<std::uint8_t>(arg_u32(6)),   // use_random_state (flag)
      static_cast<std::size_t>(arg_u32(7)));   // n_features
  return Napi::BigInt::New(env, static_cast<std::uint64_t>(handle));
}
|
|
478
|
+
|
|
479
|
+
// Releases the native decision-tree model identified by `handle`.
// JS signature: decisionTreeModelDestroy(handle) -> undefined
// `handle` is the BigInt previously returned by decisionTreeModelCreate.
Napi::Value DecisionTreeModelDestroy(const Napi::CallbackInfo& info) {
  const Napi::Env env = info.Env();

  // Guard clauses: library loaded, symbol resolved, exactly one argument.
  if (!isLibraryLoaded(env)) {
    return env.Null();
  }
  if (g_library.decision_tree_model_destroy == nullptr) {
    throwError(env, "Symbol decision_tree_model_destroy is unavailable.");
    return env.Null();
  }
  if (info.Length() != 1) {
    throwTypeError(env, "decisionTreeModelDestroy(handle) expects one BigInt.");
    return env.Null();
  }

  // handleFromBigInt throws (as a pending JS exception) on a bad handle value.
  const NativeHandle handle = handleFromBigInt(info[0], env);
  if (env.IsExceptionPending()) {
    return env.Null();
  }

  g_library.decision_tree_model_destroy(handle);
  return env.Undefined();
}
|
|
499
|
+
|
|
500
|
+
// Fits the native decision-tree model on row-major feature data.
// JS signature: decisionTreeModelFit(handle, x, y, nSamples, nFeatures, sampleIndices, sampleCount)
//   handle        BigInt       native model handle from decisionTreeModelCreate
//   x             Float64Array row-major features; needs >= nSamples * nFeatures elements
//   y             Uint8Array   class labels; needs >= nSamples elements
//   sampleIndices Uint32Array  row indices to train on; needs >= sampleCount elements
// Returns the native status code (0 = success) as a Number, or null after
// throwing a JS error.
Napi::Value DecisionTreeModelFit(const Napi::CallbackInfo& info) {
  const Napi::Env env = info.Env();
  if (!isLibraryLoaded(env)) {
    return env.Null();
  }
  if (!g_library.decision_tree_model_fit) {
    throwError(env, "Symbol decision_tree_model_fit is unavailable.");
    return env.Null();
  }
  if (info.Length() != 7 || !info[1].IsTypedArray() || !info[2].IsTypedArray() ||
      !info[3].IsNumber() || !info[4].IsNumber() || !info[5].IsTypedArray() || !info[6].IsNumber()) {
    throwTypeError(env, "decisionTreeModelFit(handle, x, y, nSamples, nFeatures, sampleIndices, sampleCount) has invalid arguments.");
    return env.Null();
  }
  // As<Napi::Float64Array>() etc. do not validate the element type; a
  // mismatched typed array would have its raw bytes silently reinterpreted.
  // Reject wrong element types explicitly.
  if (info[1].As<Napi::TypedArray>().TypedArrayType() != napi_float64_array ||
      info[2].As<Napi::TypedArray>().TypedArrayType() != napi_uint8_array ||
      info[5].As<Napi::TypedArray>().TypedArrayType() != napi_uint32_array) {
    throwTypeError(env, "decisionTreeModelFit expects x: Float64Array, y: Uint8Array, sampleIndices: Uint32Array.");
    return env.Null();
  }

  const NativeHandle handle = handleFromBigInt(info[0], env);
  if (env.IsExceptionPending()) {
    return env.Null();
  }
  auto x = info[1].As<Napi::Float64Array>();
  auto y = info[2].As<Napi::Uint8Array>();
  const std::size_t n_samples = static_cast<std::size_t>(info[3].As<Napi::Number>().Uint32Value());
  const std::size_t n_features = static_cast<std::size_t>(info[4].As<Napi::Number>().Uint32Value());
  auto sample_indices = info[5].As<Napi::Uint32Array>();
  const std::size_t sample_count = static_cast<std::size_t>(info[6].As<Napi::Number>().Uint32Value());

  // Bounds-check each buffer against the caller-declared dimensions so the
  // native kernel cannot read out of bounds. The division form avoids
  // overflow in n_samples * n_features.
  if ((n_features != 0 && x.ElementLength() / n_features < n_samples) ||
      y.ElementLength() < n_samples ||
      sample_indices.ElementLength() < sample_count) {
    throwError(env, "decisionTreeModelFit: typed array is smaller than the declared dimensions.");
    return env.Null();
  }

  const std::uint8_t status = g_library.decision_tree_model_fit(
      handle,
      x.Data(),
      y.Data(),
      n_samples,
      n_features,
      sample_indices.Data(),
      sample_count);
  return Napi::Number::New(env, status);
}
|
|
536
|
+
|
|
537
|
+
// Runs inference for `nSamples` rows and writes predicted class labels into
// `outLabels`.
// JS signature: decisionTreeModelPredict(handle, x, nSamples, nFeatures, outLabels)
//   x         Float64Array row-major features; needs >= nSamples * nFeatures elements
//   outLabels Uint8Array   output buffer; needs >= nSamples elements
// Returns the native status code (0 = success) as a Number, or null after
// throwing a JS error.
Napi::Value DecisionTreeModelPredict(const Napi::CallbackInfo& info) {
  const Napi::Env env = info.Env();
  if (!isLibraryLoaded(env)) {
    return env.Null();
  }
  if (!g_library.decision_tree_model_predict) {
    throwError(env, "Symbol decision_tree_model_predict is unavailable.");
    return env.Null();
  }
  if (info.Length() != 5 || !info[1].IsTypedArray() || !info[2].IsNumber() || !info[3].IsNumber() ||
      !info[4].IsTypedArray()) {
    throwTypeError(env, "decisionTreeModelPredict(handle, x, nSamples, nFeatures, outLabels) has invalid arguments.");
    return env.Null();
  }
  // As<Napi::Float64Array>() does not validate the element type; reject
  // mismatched buffers instead of silently reinterpreting their bytes.
  if (info[1].As<Napi::TypedArray>().TypedArrayType() != napi_float64_array ||
      info[4].As<Napi::TypedArray>().TypedArrayType() != napi_uint8_array) {
    throwTypeError(env, "decisionTreeModelPredict expects x: Float64Array and outLabels: Uint8Array.");
    return env.Null();
  }

  const NativeHandle handle = handleFromBigInt(info[0], env);
  if (env.IsExceptionPending()) {
    return env.Null();
  }
  auto x = info[1].As<Napi::Float64Array>();
  const std::size_t n_samples = static_cast<std::size_t>(info[2].As<Napi::Number>().Uint32Value());
  const std::size_t n_features = static_cast<std::size_t>(info[3].As<Napi::Number>().Uint32Value());
  auto out_labels = info[4].As<Napi::Uint8Array>();

  // Bounds-check buffers against the declared dimensions so the native kernel
  // cannot read or write out of bounds. The division form avoids overflow in
  // n_samples * n_features.
  if ((n_features != 0 && x.ElementLength() / n_features < n_samples) ||
      out_labels.ElementLength() < n_samples) {
    throwError(env, "decisionTreeModelPredict: typed array is smaller than the declared dimensions.");
    return env.Null();
  }

  const std::uint8_t status = g_library.decision_tree_model_predict(
      handle,
      x.Data(),
      n_samples,
      n_features,
      out_labels.Data());
  return Napi::Number::New(env, status);
}
|
|
569
|
+
|
|
426
570
|
Napi::Object Init(Napi::Env env, Napi::Object exports) {
|
|
427
571
|
exports.Set("loadLibrary", Napi::Function::New(env, LoadNativeLibrary));
|
|
428
572
|
exports.Set("unloadLibrary", Napi::Function::New(env, UnloadLibrary));
|
|
@@ -442,6 +586,11 @@ Napi::Object Init(Napi::Env env, Napi::Object exports) {
|
|
|
442
586
|
exports.Set("logisticModelCopyCoefficients", Napi::Function::New(env, LogisticModelCopyCoefficients));
|
|
443
587
|
exports.Set("logisticModelGetIntercept", Napi::Function::New(env, LogisticModelGetIntercept));
|
|
444
588
|
|
|
589
|
+
exports.Set("decisionTreeModelCreate", Napi::Function::New(env, DecisionTreeModelCreate));
|
|
590
|
+
exports.Set("decisionTreeModelDestroy", Napi::Function::New(env, DecisionTreeModelDestroy));
|
|
591
|
+
exports.Set("decisionTreeModelFit", Napi::Function::New(env, DecisionTreeModelFit));
|
|
592
|
+
exports.Set("decisionTreeModelPredict", Napi::Function::New(env, DecisionTreeModelPredict));
|
|
593
|
+
|
|
445
594
|
return exports;
|
|
446
595
|
}
|
|
447
596
|
|
package/src/native/zigKernels.ts
CHANGED
|
@@ -243,6 +243,10 @@ interface NodeApiAddon {
|
|
|
243
243
|
logisticModelFitLbfgs: LogisticModelFitLbfgsFn;
|
|
244
244
|
logisticModelCopyCoefficients: LogisticModelCopyCoefficientsFn;
|
|
245
245
|
logisticModelGetIntercept: LogisticModelGetInterceptFn;
|
|
246
|
+
decisionTreeModelCreate?: DecisionTreeModelCreateFn;
|
|
247
|
+
decisionTreeModelDestroy?: DecisionTreeModelDestroyFn;
|
|
248
|
+
decisionTreeModelFit?: DecisionTreeModelFitFn;
|
|
249
|
+
decisionTreeModelPredict?: DecisionTreeModelPredictFn;
|
|
246
250
|
}
|
|
247
251
|
|
|
248
252
|
function tryLoadNodeApiKernels(): ZigKernels | null {
|
|
@@ -281,10 +285,10 @@ function tryLoadNodeApiKernels(): ZigKernels | null {
|
|
|
281
285
|
logisticModelPredict: null,
|
|
282
286
|
logisticModelCopyCoefficients: addon.logisticModelCopyCoefficients ?? null,
|
|
283
287
|
logisticModelGetIntercept: addon.logisticModelGetIntercept ?? null,
|
|
284
|
-
decisionTreeModelCreate: null,
|
|
285
|
-
decisionTreeModelDestroy: null,
|
|
286
|
-
decisionTreeModelFit: null,
|
|
287
|
-
decisionTreeModelPredict: null,
|
|
288
|
+
decisionTreeModelCreate: addon.decisionTreeModelCreate ?? null,
|
|
289
|
+
decisionTreeModelDestroy: addon.decisionTreeModelDestroy ?? null,
|
|
290
|
+
decisionTreeModelFit: addon.decisionTreeModelFit ?? null,
|
|
291
|
+
decisionTreeModelPredict: addon.decisionTreeModelPredict ?? null,
|
|
288
292
|
logisticTrainEpoch: null,
|
|
289
293
|
logisticTrainEpochs: null,
|
|
290
294
|
abiVersion,
|
|
@@ -403,6 +407,31 @@ export function getZigKernels(): ZigKernels | null {
|
|
|
403
407
|
args: ["usize"],
|
|
404
408
|
returns: FFIType.f64,
|
|
405
409
|
},
|
|
410
|
+
decision_tree_model_create: {
|
|
411
|
+
args: [
|
|
412
|
+
"usize",
|
|
413
|
+
"usize",
|
|
414
|
+
"usize",
|
|
415
|
+
FFIType.u8,
|
|
416
|
+
"usize",
|
|
417
|
+
FFIType.u32,
|
|
418
|
+
FFIType.u8,
|
|
419
|
+
"usize",
|
|
420
|
+
],
|
|
421
|
+
returns: "usize",
|
|
422
|
+
},
|
|
423
|
+
decision_tree_model_destroy: {
|
|
424
|
+
args: ["usize"],
|
|
425
|
+
returns: FFIType.void,
|
|
426
|
+
},
|
|
427
|
+
decision_tree_model_fit: {
|
|
428
|
+
args: ["usize", FFIType.ptr, FFIType.ptr, "usize", "usize", FFIType.ptr, "usize"],
|
|
429
|
+
returns: FFIType.u8,
|
|
430
|
+
},
|
|
431
|
+
decision_tree_model_predict: {
|
|
432
|
+
args: ["usize", FFIType.ptr, "usize", "usize", FFIType.ptr],
|
|
433
|
+
returns: FFIType.u8,
|
|
434
|
+
},
|
|
406
435
|
logistic_train_epoch: {
|
|
407
436
|
args: [
|
|
408
437
|
FFIType.ptr,
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
import type { Matrix } from "../types";
|
|
2
|
+
import {
|
|
3
|
+
assertConsistentRowSize,
|
|
4
|
+
assertFiniteMatrix,
|
|
5
|
+
assertNonEmptyMatrix,
|
|
6
|
+
} from "../utils/validation";
|
|
7
|
+
|
|
8
|
+
export interface BinarizerOptions {
|
|
9
|
+
threshold?: number;
|
|
10
|
+
}
|
|
11
|
+
|
|
12
|
+
export class Binarizer {
|
|
13
|
+
nFeaturesIn_: number | null = null;
|
|
14
|
+
private readonly threshold: number;
|
|
15
|
+
|
|
16
|
+
constructor(options: BinarizerOptions = {}) {
|
|
17
|
+
this.threshold = options.threshold ?? 0;
|
|
18
|
+
if (!Number.isFinite(this.threshold)) {
|
|
19
|
+
throw new Error(`threshold must be finite. Got ${this.threshold}.`);
|
|
20
|
+
}
|
|
21
|
+
}
|
|
22
|
+
|
|
23
|
+
fit(X: Matrix): this {
|
|
24
|
+
assertNonEmptyMatrix(X);
|
|
25
|
+
assertConsistentRowSize(X);
|
|
26
|
+
assertFiniteMatrix(X);
|
|
27
|
+
this.nFeaturesIn_ = X[0].length;
|
|
28
|
+
return this;
|
|
29
|
+
}
|
|
30
|
+
|
|
31
|
+
transform(X: Matrix): Matrix {
|
|
32
|
+
assertNonEmptyMatrix(X);
|
|
33
|
+
assertConsistentRowSize(X);
|
|
34
|
+
assertFiniteMatrix(X);
|
|
35
|
+
|
|
36
|
+
if (this.nFeaturesIn_ !== null && X[0].length !== this.nFeaturesIn_) {
|
|
37
|
+
throw new Error(`Feature size mismatch. Expected ${this.nFeaturesIn_}, got ${X[0].length}.`);
|
|
38
|
+
}
|
|
39
|
+
|
|
40
|
+
return X.map((row) => row.map((value) => (value > this.threshold ? 1 : 0)));
|
|
41
|
+
}
|
|
42
|
+
|
|
43
|
+
fitTransform(X: Matrix): Matrix {
|
|
44
|
+
return this.fit(X).transform(X);
|
|
45
|
+
}
|
|
46
|
+
}
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
import type { Vector } from "../types";
|
|
2
|
+
import { assertFiniteVector } from "../utils/validation";
|
|
3
|
+
|
|
4
|
+
export class LabelEncoder {
|
|
5
|
+
classes_: number[] | null = null;
|
|
6
|
+
private classToIndex: Map<number, number> | null = null;
|
|
7
|
+
|
|
8
|
+
fit(y: Vector): this {
|
|
9
|
+
if (!Array.isArray(y) || y.length === 0) {
|
|
10
|
+
throw new Error("y must be a non-empty array.");
|
|
11
|
+
}
|
|
12
|
+
assertFiniteVector(y);
|
|
13
|
+
|
|
14
|
+
const classes = Array.from(new Set(y)).sort((a, b) => a - b);
|
|
15
|
+
const classToIndex = new Map<number, number>();
|
|
16
|
+
for (let i = 0; i < classes.length; i += 1) {
|
|
17
|
+
classToIndex.set(classes[i], i);
|
|
18
|
+
}
|
|
19
|
+
|
|
20
|
+
this.classes_ = classes;
|
|
21
|
+
this.classToIndex = classToIndex;
|
|
22
|
+
return this;
|
|
23
|
+
}
|
|
24
|
+
|
|
25
|
+
transform(y: Vector): Vector {
|
|
26
|
+
if (!this.classToIndex) {
|
|
27
|
+
throw new Error("LabelEncoder has not been fitted.");
|
|
28
|
+
}
|
|
29
|
+
assertFiniteVector(y);
|
|
30
|
+
|
|
31
|
+
const encoded = new Array<number>(y.length);
|
|
32
|
+
for (let i = 0; i < y.length; i += 1) {
|
|
33
|
+
const idx = this.classToIndex.get(y[i]);
|
|
34
|
+
if (idx === undefined) {
|
|
35
|
+
throw new Error(`Unknown label ${y[i]} at index ${i}.`);
|
|
36
|
+
}
|
|
37
|
+
encoded[i] = idx;
|
|
38
|
+
}
|
|
39
|
+
return encoded;
|
|
40
|
+
}
|
|
41
|
+
|
|
42
|
+
fitTransform(y: Vector): Vector {
|
|
43
|
+
return this.fit(y).transform(y);
|
|
44
|
+
}
|
|
45
|
+
|
|
46
|
+
inverseTransform(y: Vector): Vector {
|
|
47
|
+
if (!this.classes_) {
|
|
48
|
+
throw new Error("LabelEncoder has not been fitted.");
|
|
49
|
+
}
|
|
50
|
+
assertFiniteVector(y);
|
|
51
|
+
|
|
52
|
+
const decoded = new Array<number>(y.length);
|
|
53
|
+
for (let i = 0; i < y.length; i += 1) {
|
|
54
|
+
const encoded = y[i];
|
|
55
|
+
if (!Number.isInteger(encoded) || encoded < 0 || encoded >= this.classes_.length) {
|
|
56
|
+
throw new Error(`Encoded label out of range at index ${i}: ${encoded}.`);
|
|
57
|
+
}
|
|
58
|
+
decoded[i] = this.classes_[encoded];
|
|
59
|
+
}
|
|
60
|
+
return decoded;
|
|
61
|
+
}
|
|
62
|
+
}
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
import type { Matrix, Vector } from "../types";
|
|
2
|
+
import {
|
|
3
|
+
assertConsistentRowSize,
|
|
4
|
+
assertFiniteMatrix,
|
|
5
|
+
assertNonEmptyMatrix,
|
|
6
|
+
} from "../utils/validation";
|
|
7
|
+
|
|
8
|
+
export class MaxAbsScaler {
|
|
9
|
+
maxAbs_: Vector | null = null;
|
|
10
|
+
|
|
11
|
+
fit(X: Matrix): this {
|
|
12
|
+
assertNonEmptyMatrix(X);
|
|
13
|
+
assertConsistentRowSize(X);
|
|
14
|
+
assertFiniteMatrix(X);
|
|
15
|
+
|
|
16
|
+
const nFeatures = X[0].length;
|
|
17
|
+
const maxAbs = new Array<number>(nFeatures).fill(0);
|
|
18
|
+
|
|
19
|
+
for (let i = 0; i < X.length; i += 1) {
|
|
20
|
+
for (let j = 0; j < nFeatures; j += 1) {
|
|
21
|
+
const abs = Math.abs(X[i][j]);
|
|
22
|
+
if (abs > maxAbs[j]) {
|
|
23
|
+
maxAbs[j] = abs;
|
|
24
|
+
}
|
|
25
|
+
}
|
|
26
|
+
}
|
|
27
|
+
|
|
28
|
+
for (let j = 0; j < nFeatures; j += 1) {
|
|
29
|
+
if (maxAbs[j] === 0) {
|
|
30
|
+
maxAbs[j] = 1;
|
|
31
|
+
}
|
|
32
|
+
}
|
|
33
|
+
|
|
34
|
+
this.maxAbs_ = maxAbs;
|
|
35
|
+
return this;
|
|
36
|
+
}
|
|
37
|
+
|
|
38
|
+
transform(X: Matrix): Matrix {
|
|
39
|
+
if (!this.maxAbs_) {
|
|
40
|
+
throw new Error("MaxAbsScaler has not been fitted.");
|
|
41
|
+
}
|
|
42
|
+
|
|
43
|
+
assertNonEmptyMatrix(X);
|
|
44
|
+
assertConsistentRowSize(X);
|
|
45
|
+
assertFiniteMatrix(X);
|
|
46
|
+
|
|
47
|
+
if (X[0].length !== this.maxAbs_.length) {
|
|
48
|
+
throw new Error(
|
|
49
|
+
`Feature size mismatch. Expected ${this.maxAbs_.length}, got ${X[0].length}.`,
|
|
50
|
+
);
|
|
51
|
+
}
|
|
52
|
+
|
|
53
|
+
return X.map((row) => row.map((value, featureIdx) => value / this.maxAbs_![featureIdx]));
|
|
54
|
+
}
|
|
55
|
+
|
|
56
|
+
fitTransform(X: Matrix): Matrix {
|
|
57
|
+
return this.fit(X).transform(X);
|
|
58
|
+
}
|
|
59
|
+
|
|
60
|
+
inverseTransform(X: Matrix): Matrix {
|
|
61
|
+
if (!this.maxAbs_) {
|
|
62
|
+
throw new Error("MaxAbsScaler has not been fitted.");
|
|
63
|
+
}
|
|
64
|
+
|
|
65
|
+
assertNonEmptyMatrix(X);
|
|
66
|
+
assertConsistentRowSize(X);
|
|
67
|
+
assertFiniteMatrix(X);
|
|
68
|
+
|
|
69
|
+
if (X[0].length !== this.maxAbs_.length) {
|
|
70
|
+
throw new Error(
|
|
71
|
+
`Feature size mismatch. Expected ${this.maxAbs_.length}, got ${X[0].length}.`,
|
|
72
|
+
);
|
|
73
|
+
}
|
|
74
|
+
|
|
75
|
+
return X.map((row) => row.map((value, featureIdx) => value * this.maxAbs_![featureIdx]));
|
|
76
|
+
}
|
|
77
|
+
}
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
import type { Matrix } from "../types";
|
|
2
|
+
import {
|
|
3
|
+
assertConsistentRowSize,
|
|
4
|
+
assertFiniteMatrix,
|
|
5
|
+
assertNonEmptyMatrix,
|
|
6
|
+
} from "../utils/validation";
|
|
7
|
+
|
|
8
|
+
export interface NormalizerOptions {
|
|
9
|
+
norm?: "l1" | "l2" | "max";
|
|
10
|
+
}
|
|
11
|
+
|
|
12
|
+
export class Normalizer {
|
|
13
|
+
private readonly norm: "l1" | "l2" | "max";
|
|
14
|
+
private nFeatures_: number | null = null;
|
|
15
|
+
|
|
16
|
+
constructor(options: NormalizerOptions = {}) {
|
|
17
|
+
this.norm = options.norm ?? "l2";
|
|
18
|
+
}
|
|
19
|
+
|
|
20
|
+
fit(X: Matrix): this {
|
|
21
|
+
assertNonEmptyMatrix(X);
|
|
22
|
+
assertConsistentRowSize(X);
|
|
23
|
+
assertFiniteMatrix(X);
|
|
24
|
+
this.nFeatures_ = X[0].length;
|
|
25
|
+
return this;
|
|
26
|
+
}
|
|
27
|
+
|
|
28
|
+
transform(X: Matrix): Matrix {
|
|
29
|
+
assertNonEmptyMatrix(X);
|
|
30
|
+
assertConsistentRowSize(X);
|
|
31
|
+
assertFiniteMatrix(X);
|
|
32
|
+
if (this.nFeatures_ !== null && X[0].length !== this.nFeatures_) {
|
|
33
|
+
throw new Error(`Feature size mismatch. Expected ${this.nFeatures_}, got ${X[0].length}.`);
|
|
34
|
+
}
|
|
35
|
+
|
|
36
|
+
return X.map((row) => {
|
|
37
|
+
let scale = 0;
|
|
38
|
+
if (this.norm === "l1") {
|
|
39
|
+
for (let i = 0; i < row.length; i += 1) {
|
|
40
|
+
scale += Math.abs(row[i]);
|
|
41
|
+
}
|
|
42
|
+
} else if (this.norm === "l2") {
|
|
43
|
+
for (let i = 0; i < row.length; i += 1) {
|
|
44
|
+
scale += row[i] * row[i];
|
|
45
|
+
}
|
|
46
|
+
scale = Math.sqrt(scale);
|
|
47
|
+
} else {
|
|
48
|
+
for (let i = 0; i < row.length; i += 1) {
|
|
49
|
+
const abs = Math.abs(row[i]);
|
|
50
|
+
if (abs > scale) {
|
|
51
|
+
scale = abs;
|
|
52
|
+
}
|
|
53
|
+
}
|
|
54
|
+
}
|
|
55
|
+
|
|
56
|
+
if (scale === 0) {
|
|
57
|
+
return [...row];
|
|
58
|
+
}
|
|
59
|
+
return row.map((value) => value / scale);
|
|
60
|
+
});
|
|
61
|
+
}
|
|
62
|
+
|
|
63
|
+
fitTransform(X: Matrix): Matrix {
|
|
64
|
+
return this.fit(X).transform(X);
|
|
65
|
+
}
|
|
66
|
+
}
|