bun-scikit 0.1.3 → 0.1.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,190 @@
1
+ import type { Matrix, Vector } from "../types";
2
+ import { accuracyScore } from "../metrics/classification";
3
+ import {
4
+ assertConsistentRowSize,
5
+ assertFiniteMatrix,
6
+ assertFiniteVector,
7
+ assertNonEmptyMatrix,
8
+ assertVectorLength,
9
+ } from "../utils/validation";
10
+
11
/**
 * Prediction strategies supported by DummyClassifier, mirroring
 * scikit-learn's `DummyClassifier` strategy names:
 * - "most_frequent": always predict the majority class.
 * - "prior": same predictions as "most_frequent"; predictProba returns class priors.
 * - "stratified": sample labels from the empirical class distribution.
 * - "uniform": sample labels uniformly at random from the observed classes.
 * - "constant": always predict a user-supplied label.
 */
export type DummyClassifierStrategy =
  | "most_frequent"
  | "prior"
  | "stratified"
  | "uniform"
  | "constant";

/** Construction options for DummyClassifier. */
export interface DummyClassifierOptions {
  /** Prediction strategy; defaults to "prior". */
  strategy?: DummyClassifierStrategy;
  /** Label used by the "constant" strategy; must be finite. */
  constant?: number;
  /** Seed for the deterministic PRNG used by "stratified"/"uniform"; defaults to 42. */
  randomState?: number;
}
23
+
24
+ class Mulberry32 {
25
+ private state: number;
26
+
27
+ constructor(seed: number) {
28
+ this.state = seed >>> 0;
29
+ }
30
+
31
+ next(): number {
32
+ this.state = (this.state + 0x6d2b79f5) >>> 0;
33
+ let t = this.state ^ (this.state >>> 15);
34
+ t = Math.imul(t, this.state | 1);
35
+ t ^= t + Math.imul(t ^ (t >>> 7), t | 61);
36
+ return ((t ^ (t >>> 14)) >>> 0) / 4294967296;
37
+ }
38
+ }
39
+
40
+ export class DummyClassifier {
41
+ classes_: number[] | null = null;
42
+ classPrior_: number[] | null = null;
43
+ constant_: number | null = null;
44
+
45
+ private readonly strategy: DummyClassifierStrategy;
46
+ private readonly configuredConstant?: number;
47
+ private readonly randomState: number;
48
+ private majorityClass: number | null = null;
49
+ private nFeaturesIn_: number | null = null;
50
+
51
+ constructor(options: DummyClassifierOptions = {}) {
52
+ this.strategy = options.strategy ?? "prior";
53
+ this.configuredConstant = options.constant;
54
+ this.randomState = options.randomState ?? 42;
55
+ }
56
+
57
+ fit(X: Matrix, y: Vector): this {
58
+ assertNonEmptyMatrix(X);
59
+ assertConsistentRowSize(X);
60
+ assertFiniteMatrix(X);
61
+ assertVectorLength(y, X.length);
62
+ assertFiniteVector(y);
63
+ this.nFeaturesIn_ = X[0].length;
64
+
65
+ const counts = new Map<number, number>();
66
+ for (let i = 0; i < y.length; i += 1) {
67
+ counts.set(y[i], (counts.get(y[i]) ?? 0) + 1);
68
+ }
69
+
70
+ const classes = Array.from(counts.keys()).sort((a, b) => a - b);
71
+ const priors = new Array<number>(classes.length);
72
+ for (let i = 0; i < classes.length; i += 1) {
73
+ priors[i] = (counts.get(classes[i]) ?? 0) / y.length;
74
+ }
75
+
76
+ let majorityClass = classes[0];
77
+ let majorityCount = counts.get(majorityClass) ?? 0;
78
+ for (let i = 1; i < classes.length; i += 1) {
79
+ const cls = classes[i];
80
+ const clsCount = counts.get(cls) ?? 0;
81
+ if (clsCount > majorityCount) {
82
+ majorityClass = cls;
83
+ majorityCount = clsCount;
84
+ }
85
+ }
86
+
87
+ if (this.strategy === "constant") {
88
+ if (!Number.isFinite(this.configuredConstant)) {
89
+ throw new Error("constant strategy requires a finite constant value.");
90
+ }
91
+ this.constant_ = this.configuredConstant!;
92
+ } else {
93
+ this.constant_ = majorityClass;
94
+ }
95
+
96
+ this.classes_ = classes;
97
+ this.classPrior_ = priors;
98
+ this.majorityClass = majorityClass;
99
+ return this;
100
+ }
101
+
102
+ private ensureFitted(): void {
103
+ if (!this.classes_ || !this.classPrior_ || this.nFeaturesIn_ === null || this.majorityClass === null) {
104
+ throw new Error("DummyClassifier has not been fitted.");
105
+ }
106
+ }
107
+
108
+ private sampleByPrior(rng: Mulberry32): number {
109
+ let r = rng.next();
110
+ for (let i = 0; i < this.classPrior_!.length; i += 1) {
111
+ r -= this.classPrior_![i];
112
+ if (r <= 0) {
113
+ return this.classes_![i];
114
+ }
115
+ }
116
+ return this.classes_![this.classes_!.length - 1];
117
+ }
118
+
119
+ predict(X: Matrix): Vector {
120
+ this.ensureFitted();
121
+ if (!Array.isArray(X) || X.length === 0) {
122
+ throw new Error("X must be a non-empty 2D array.");
123
+ }
124
+ if (!Array.isArray(X[0]) || X[0].length !== this.nFeaturesIn_) {
125
+ throw new Error(`Feature size mismatch. Expected ${this.nFeaturesIn_}, got ${X[0]?.length ?? 0}.`);
126
+ }
127
+
128
+ switch (this.strategy) {
129
+ case "most_frequent":
130
+ case "prior":
131
+ return new Array<number>(X.length).fill(this.majorityClass!);
132
+ case "constant":
133
+ return new Array<number>(X.length).fill(this.constant_!);
134
+ case "uniform": {
135
+ const rng = new Mulberry32(this.randomState);
136
+ const out = new Array<number>(X.length);
137
+ for (let i = 0; i < X.length; i += 1) {
138
+ const idx = Math.floor(rng.next() * this.classes_!.length);
139
+ out[i] = this.classes_![idx];
140
+ }
141
+ return out;
142
+ }
143
+ case "stratified": {
144
+ const rng = new Mulberry32(this.randomState);
145
+ const out = new Array<number>(X.length);
146
+ for (let i = 0; i < X.length; i += 1) {
147
+ out[i] = this.sampleByPrior(rng);
148
+ }
149
+ return out;
150
+ }
151
+ default: {
152
+ const exhaustive: never = this.strategy;
153
+ throw new Error(`Unsupported strategy: ${exhaustive}`);
154
+ }
155
+ }
156
+ }
157
+
158
+ predictProba(X: Matrix): Matrix {
159
+ this.ensureFitted();
160
+ if (!Array.isArray(X) || X.length === 0) {
161
+ throw new Error("X must be a non-empty 2D array.");
162
+ }
163
+ if (!Array.isArray(X[0]) || X[0].length !== this.nFeaturesIn_) {
164
+ throw new Error(`Feature size mismatch. Expected ${this.nFeaturesIn_}, got ${X[0]?.length ?? 0}.`);
165
+ }
166
+
167
+ if (this.strategy === "uniform") {
168
+ const value = 1 / this.classes_!.length;
169
+ return X.map(() => new Array(this.classes_!.length).fill(value));
170
+ }
171
+
172
+ if (this.strategy === "most_frequent" || this.strategy === "constant") {
173
+ const oneHot = new Array<number>(this.classes_!.length).fill(0);
174
+ const label = this.strategy === "constant" ? this.constant_! : this.majorityClass!;
175
+ const classIndex = this.classes_!.indexOf(label);
176
+ if (classIndex >= 0) {
177
+ oneHot[classIndex] = 1;
178
+ }
179
+ return X.map(() => [...oneHot]);
180
+ }
181
+
182
+ // prior / stratified share prior probabilities.
183
+ const prior = [...this.classPrior_!];
184
+ return X.map(() => [...prior]);
185
+ }
186
+
187
+ score(X: Matrix, y: Vector): number {
188
+ return accuracyScore(y, this.predict(X));
189
+ }
190
+ }
@@ -0,0 +1,108 @@
1
+ import type { Matrix, Vector } from "../types";
2
+ import { r2Score } from "../metrics/regression";
3
+ import { assertFiniteVector, validateRegressionInputs } from "../utils/validation";
4
+
5
/**
 * Prediction strategies supported by DummyRegressor: predict the training
 * mean, median, an arbitrary quantile, or a user-supplied constant.
 */
export type DummyRegressorStrategy = "mean" | "median" | "quantile" | "constant";

/** Construction options for DummyRegressor. */
export interface DummyRegressorOptions {
  /** Strategy used to derive the constant prediction; defaults to "mean". */
  strategy?: DummyRegressorStrategy;
  /** Value returned by the "constant" strategy; must be finite. */
  constant?: number;
  /** Quantile in [0, 1] for the "quantile" strategy; defaults to 0.5 (median). */
  quantile?: number;
}
12
+
13
+ function computeMedian(values: number[]): number {
14
+ const sorted = [...values].sort((a, b) => a - b);
15
+ const mid = Math.floor(sorted.length / 2);
16
+ if (sorted.length % 2 === 0) {
17
+ return 0.5 * (sorted[mid - 1] + sorted[mid]);
18
+ }
19
+ return sorted[mid];
20
+ }
21
+
22
+ function computeQuantile(values: number[], q: number): number {
23
+ const sorted = [...values].sort((a, b) => a - b);
24
+ const pos = q * (sorted.length - 1);
25
+ const lo = Math.floor(pos);
26
+ const hi = Math.ceil(pos);
27
+ if (lo === hi) {
28
+ return sorted[lo];
29
+ }
30
+ const weight = pos - lo;
31
+ return sorted[lo] * (1 - weight) + sorted[hi] * weight;
32
+ }
33
+
34
+ export class DummyRegressor {
35
+ constant_: number | null = null;
36
+
37
+ private readonly strategy: DummyRegressorStrategy;
38
+ private readonly constant?: number;
39
+ private readonly quantile: number;
40
+ private nFeaturesIn_: number | null = null;
41
+
42
+ constructor(options: DummyRegressorOptions = {}) {
43
+ this.strategy = options.strategy ?? "mean";
44
+ this.constant = options.constant;
45
+ this.quantile = options.quantile ?? 0.5;
46
+
47
+ if (this.strategy === "constant") {
48
+ if (!Number.isFinite(this.constant)) {
49
+ throw new Error("constant strategy requires a finite constant value.");
50
+ }
51
+ }
52
+
53
+ if (this.strategy === "quantile") {
54
+ if (!Number.isFinite(this.quantile) || this.quantile < 0 || this.quantile > 1) {
55
+ throw new Error(`quantile must be in [0, 1]. Got ${this.quantile}.`);
56
+ }
57
+ }
58
+ }
59
+
60
+ fit(X: Matrix, y: Vector): this {
61
+ validateRegressionInputs(X, y);
62
+ this.nFeaturesIn_ = X[0].length;
63
+
64
+ switch (this.strategy) {
65
+ case "mean": {
66
+ let total = 0;
67
+ for (let i = 0; i < y.length; i += 1) {
68
+ total += y[i];
69
+ }
70
+ this.constant_ = total / y.length;
71
+ break;
72
+ }
73
+ case "median":
74
+ this.constant_ = computeMedian(y);
75
+ break;
76
+ case "quantile":
77
+ this.constant_ = computeQuantile(y, this.quantile);
78
+ break;
79
+ case "constant":
80
+ this.constant_ = this.constant!;
81
+ break;
82
+ default: {
83
+ const exhaustive: never = this.strategy;
84
+ throw new Error(`Unsupported strategy: ${exhaustive}`);
85
+ }
86
+ }
87
+
88
+ return this;
89
+ }
90
+
91
+ predict(X: Matrix): Vector {
92
+ if (this.constant_ === null || this.nFeaturesIn_ === null) {
93
+ throw new Error("DummyRegressor has not been fitted.");
94
+ }
95
+ if (!Array.isArray(X) || X.length === 0) {
96
+ throw new Error("X must be a non-empty 2D array.");
97
+ }
98
+ if (!Array.isArray(X[0]) || X[0].length !== this.nFeaturesIn_) {
99
+ throw new Error(`Feature size mismatch. Expected ${this.nFeaturesIn_}, got ${X[0]?.length ?? 0}.`);
100
+ }
101
+ return new Array<number>(X.length).fill(this.constant_);
102
+ }
103
+
104
+ score(X: Matrix, y: Vector): number {
105
+ assertFiniteVector(y);
106
+ return r2Score(y, this.predict(X));
107
+ }
108
+ }
@@ -2,6 +2,7 @@ import type { ClassificationModel, Matrix, Vector } from "../types";
2
2
  import { accuracyScore } from "../metrics/classification";
3
3
  import { DecisionTreeClassifier, type MaxFeaturesOption } from "../tree/DecisionTreeClassifier";
4
4
  import { assertFiniteVector, validateClassificationInputs } from "../utils/validation";
5
+ import { getZigKernels } from "../native/zigKernels";
5
6
 
6
7
  export interface RandomForestClassifierOptions {
7
8
  nEstimators?: number;
@@ -23,8 +24,18 @@ function mulberry32(seed: number): () => number {
23
24
  };
24
25
  }
25
26
 
27
+ function isTruthy(value: string | undefined): boolean {
28
+ if (!value) {
29
+ return false;
30
+ }
31
+ const normalized = value.trim().toLowerCase();
32
+ return !(normalized === "0" || normalized === "false" || normalized === "off");
33
+ }
34
+
26
35
  export class RandomForestClassifier implements ClassificationModel {
27
36
  classes_: Vector = [0, 1];
37
+ fitBackend_: "zig" | "js" = "js";
38
+ fitBackendLibrary_: string | null = null;
28
39
  private readonly nEstimators: number;
29
40
  private readonly maxDepth?: number;
30
41
  private readonly minSamplesSplit?: number;
@@ -32,6 +43,7 @@ export class RandomForestClassifier implements ClassificationModel {
32
43
  private readonly maxFeatures: MaxFeaturesOption;
33
44
  private readonly bootstrap: boolean;
34
45
  private readonly randomState?: number;
46
+ private nativeModelHandle: bigint | null = null;
35
47
  private trees: DecisionTreeClassifier[] = [];
36
48
 
37
49
  constructor(options: RandomForestClassifierOptions = {}) {
@@ -49,6 +61,7 @@ export class RandomForestClassifier implements ClassificationModel {
49
61
  }
50
62
 
51
63
  fit(X: Matrix, y: Vector): this {
64
+ this.disposeNativeModel();
52
65
  validateClassificationInputs(X, y);
53
66
 
54
67
  const sampleCount = X.length;
@@ -56,10 +69,17 @@ export class RandomForestClassifier implements ClassificationModel {
56
69
  const random = this.randomState === undefined ? Math.random : mulberry32(this.randomState);
57
70
  const flattenedX = this.flattenTrainingMatrix(X, sampleCount, featureCount);
58
71
  const yBinary = this.buildBinaryTargets(y);
72
+ const sampleIndices = new Uint32Array(sampleCount);
73
+ this.trees = [];
74
+ if (this.tryFitNativeForest(flattenedX, yBinary, sampleCount, featureCount)) {
75
+ this.fitBackend_ = "zig";
76
+ return this;
77
+ }
78
+ this.fitBackend_ = "js";
79
+ this.fitBackendLibrary_ = null;
59
80
  this.trees = new Array(this.nEstimators);
60
81
 
61
82
  for (let estimatorIndex = 0; estimatorIndex < this.nEstimators; estimatorIndex += 1) {
62
- const sampleIndices = new Uint32Array(sampleCount);
63
83
  if (this.bootstrap) {
64
84
  for (let i = 0; i < sampleCount; i += 1) {
65
85
  sampleIndices[i] = Math.floor(random() * sampleCount);
@@ -86,20 +106,47 @@ export class RandomForestClassifier implements ClassificationModel {
86
106
  }
87
107
 
88
108
  predict(X: Matrix): Vector {
109
+ if (this.nativeModelHandle !== null) {
110
+ const kernels = getZigKernels();
111
+ const predict = kernels?.randomForestClassifierModelPredict;
112
+ if (predict) {
113
+ const sampleCount = X.length;
114
+ const featureCount = X[0]?.length ?? 0;
115
+ const flattened = this.flattenTrainingMatrix(X, sampleCount, featureCount);
116
+ const out = new Uint8Array(sampleCount);
117
+ const status = predict(
118
+ this.nativeModelHandle,
119
+ flattened,
120
+ sampleCount,
121
+ featureCount,
122
+ out,
123
+ );
124
+ if (status === 1) {
125
+ return Array.from(out);
126
+ }
127
+ }
128
+ }
129
+
89
130
  if (this.trees.length === 0) {
90
131
  throw new Error("RandomForestClassifier has not been fitted.");
91
132
  }
92
133
 
93
- const treePredictions = this.trees.map((tree) => tree.predict(X));
94
134
  const sampleCount = X.length;
95
- const predictions = new Array(sampleCount).fill(0);
135
+ const voteCounts = new Uint16Array(sampleCount);
96
136
 
97
- for (let sampleIndex = 0; sampleIndex < sampleCount; sampleIndex += 1) {
98
- let positiveVotes = 0;
99
- for (let treeIndex = 0; treeIndex < treePredictions.length; treeIndex += 1) {
100
- positiveVotes += treePredictions[treeIndex][sampleIndex] === 1 ? 1 : 0;
137
+ for (let treeIndex = 0; treeIndex < this.trees.length; treeIndex += 1) {
138
+ const treePrediction = this.trees[treeIndex].predict(X);
139
+ for (let sampleIndex = 0; sampleIndex < sampleCount; sampleIndex += 1) {
140
+ if (treePrediction[sampleIndex] === 1) {
141
+ voteCounts[sampleIndex] += 1;
142
+ }
101
143
  }
102
- predictions[sampleIndex] = positiveVotes * 2 >= this.trees.length ? 1 : 0;
144
+ }
145
+
146
+ const predictions = new Array<number>(sampleCount);
147
+ const voteThreshold = this.trees.length;
148
+ for (let sampleIndex = 0; sampleIndex < sampleCount; sampleIndex += 1) {
149
+ predictions[sampleIndex] = voteCounts[sampleIndex] * 2 >= voteThreshold ? 1 : 0;
103
150
  }
104
151
 
105
152
  return predictions;
@@ -110,6 +157,105 @@ export class RandomForestClassifier implements ClassificationModel {
110
157
  return accuracyScore(y, this.predict(X));
111
158
  }
112
159
 
160
+ dispose(): void {
161
+ this.disposeNativeModel();
162
+ for (let i = 0; i < this.trees.length; i += 1) {
163
+ this.trees[i].dispose();
164
+ }
165
+ this.trees = [];
166
+ }
167
+
168
+ private resolveNativeMaxFeatures(featureCount: number): {
169
+ mode: 0 | 1 | 2 | 3;
170
+ value: number;
171
+ } {
172
+ if (this.maxFeatures === null || this.maxFeatures === undefined) {
173
+ return { mode: 0, value: 0 };
174
+ }
175
+ if (this.maxFeatures === "sqrt") {
176
+ return { mode: 1, value: 0 };
177
+ }
178
+ if (this.maxFeatures === "log2") {
179
+ return { mode: 2, value: 0 };
180
+ }
181
+ const value = Number.isFinite(this.maxFeatures)
182
+ ? Math.max(1, Math.min(featureCount, Math.floor(this.maxFeatures)))
183
+ : featureCount;
184
+ return { mode: 3, value };
185
+ }
186
+
187
  /**
   * Attempts to train via the experimental native (Zig) forest kernel.
   * Returns true on success, after storing the native model handle and the
   * kernel library path; returns false whenever the native path is
   * unavailable or fails, so the caller can fall back to the pure-JS path.
   *
   * Gated behind two environment variables: BUN_SCIKIT_EXPERIMENTAL_NATIVE_FOREST
   * must be truthy AND BUN_SCIKIT_TREE_BACKEND must equal "zig".
   */
  private tryFitNativeForest(
    flattenedX: Float64Array,
    yBinary: Uint8Array,
    sampleCount: number,
    featureCount: number,
  ): boolean {
    if (!isTruthy(process.env.BUN_SCIKIT_EXPERIMENTAL_NATIVE_FOREST)) {
      return false;
    }
    if (process.env.BUN_SCIKIT_TREE_BACKEND?.trim().toLowerCase() !== "zig") {
      return false;
    }
    // All three kernel entry points must be present; otherwise fall back.
    const kernels = getZigKernels();
    const create = kernels?.randomForestClassifierModelCreate;
    const fit = kernels?.randomForestClassifierModelFit;
    const destroy = kernels?.randomForestClassifierModelDestroy;
    if (!create || !fit || !destroy) {
      return false;
    }

    const { mode, value } = this.resolveNativeMaxFeatures(featureCount);
    // useRandomState = 0 presumably lets the kernel choose its own seed when
    // no randomState was configured — TODO confirm against the native ABI.
    const useRandomState = this.randomState === undefined ? 0 : 1;
    const randomState = this.randomState ?? 0;
    // NOTE(review): maxDepth defaults to 12 on the native path here —
    // confirm the JS fallback uses the same effective default.
    const handle = create(
      this.nEstimators,
      this.maxDepth ?? 12,
      this.minSamplesSplit ?? 2,
      this.minSamplesLeaf ?? 1,
      mode,
      value,
      this.bootstrap ? 1 : 0,
      randomState >>> 0,
      useRandomState,
      featureCount,
    );
    // A zero handle means the kernel could not allocate a model.
    if (handle === 0n) {
      return false;
    }

    // Destroy the native handle on every failure path; ownership transfers to
    // this.nativeModelHandle only after a fully successful fit.
    let shouldDestroy = true;
    try {
      const status = fit(handle, flattenedX, yBinary, sampleCount, featureCount);
      // Any status other than 1 is treated as a native-fit failure.
      if (status !== 1) {
        return false;
      }
      this.nativeModelHandle = handle;
      this.fitBackendLibrary_ = kernels.libraryPath;
      shouldDestroy = false;
      return true;
    } finally {
      if (shouldDestroy) {
        destroy(handle);
      }
    }
  }
241
+ }
242
+
243
+ private disposeNativeModel(): void {
244
+ if (this.nativeModelHandle === null) {
245
+ return;
246
+ }
247
+ const kernels = getZigKernels();
248
+ const destroy = kernels?.randomForestClassifierModelDestroy;
249
+ if (destroy) {
250
+ try {
251
+ destroy(this.nativeModelHandle);
252
+ } catch {
253
+ // best effort cleanup
254
+ }
255
+ }
256
+ this.nativeModelHandle = null;
257
+ }
258
+
113
259
  private flattenTrainingMatrix(
114
260
  X: Matrix,
115
261
  sampleCount: number,
@@ -56,10 +56,10 @@ export class RandomForestRegressor implements RegressionModel {
56
56
  const random = this.randomState === undefined ? Math.random : mulberry32(this.randomState);
57
57
  const flattenedX = this.flattenTrainingMatrix(X, sampleCount, featureCount);
58
58
  const yValues = this.toFloat64Vector(y);
59
+ const sampleIndices = new Uint32Array(sampleCount);
59
60
  this.trees = new Array(this.nEstimators);
60
61
 
61
62
  for (let estimatorIndex = 0; estimatorIndex < this.nEstimators; estimatorIndex += 1) {
62
- const sampleIndices = new Uint32Array(sampleCount);
63
63
  if (this.bootstrap) {
64
64
  for (let i = 0; i < sampleCount; i += 1) {
65
65
  sampleIndices[i] = Math.floor(random() * sampleCount);
@@ -90,16 +90,20 @@ export class RandomForestRegressor implements RegressionModel {
90
90
  throw new Error("RandomForestRegressor has not been fitted.");
91
91
  }
92
92
 
93
- const treePredictions = this.trees.map((tree) => tree.predict(X));
94
93
  const sampleCount = X.length;
95
- const predictions = new Array<number>(sampleCount).fill(0);
94
+ const sums = new Float64Array(sampleCount);
96
95
 
97
- for (let sampleIndex = 0; sampleIndex < sampleCount; sampleIndex += 1) {
98
- let sum = 0;
99
- for (let treeIndex = 0; treeIndex < treePredictions.length; treeIndex += 1) {
100
- sum += treePredictions[treeIndex][sampleIndex];
96
+ for (let treeIndex = 0; treeIndex < this.trees.length; treeIndex += 1) {
97
+ const treePrediction = this.trees[treeIndex].predict(X);
98
+ for (let sampleIndex = 0; sampleIndex < sampleCount; sampleIndex += 1) {
99
+ sums[sampleIndex] += treePrediction[sampleIndex];
101
100
  }
102
- predictions[sampleIndex] = sum / this.trees.length;
101
+ }
102
+
103
+ const predictions = new Array<number>(sampleCount);
104
+ const denominator = this.trees.length;
105
+ for (let sampleIndex = 0; sampleIndex < sampleCount; sampleIndex += 1) {
106
+ predictions[sampleIndex] = sums[sampleIndex] / denominator;
103
107
  }
104
108
 
105
109
  return predictions;
@@ -0,0 +1,88 @@
1
+ import type { Matrix } from "../types";
2
+ import {
3
+ assertConsistentRowSize,
4
+ assertFiniteMatrix,
5
+ assertNonEmptyMatrix,
6
+ } from "../utils/validation";
7
+
8
/** Construction options for VarianceThreshold. */
export interface VarianceThresholdOptions {
  /**
   * Features with variance strictly greater than this value are kept.
   * Must be finite and >= 0; defaults to 0 (drop only constant features).
   */
  threshold?: number;
}
11
+
12
+ export class VarianceThreshold {
13
+ variances_: number[] | null = null;
14
+ nFeaturesIn_: number | null = null;
15
+ selectedFeatureIndices_: number[] | null = null;
16
+
17
+ private readonly threshold: number;
18
+
19
+ constructor(options: VarianceThresholdOptions = {}) {
20
+ this.threshold = options.threshold ?? 0;
21
+ if (!Number.isFinite(this.threshold) || this.threshold < 0) {
22
+ throw new Error(`threshold must be finite and >= 0. Got ${this.threshold}.`);
23
+ }
24
+ }
25
+
26
+ fit(X: Matrix): this {
27
+ assertNonEmptyMatrix(X);
28
+ assertConsistentRowSize(X);
29
+ assertFiniteMatrix(X);
30
+
31
+ const nSamples = X.length;
32
+ const nFeatures = X[0].length;
33
+ const means = new Array<number>(nFeatures).fill(0);
34
+ const variances = new Array<number>(nFeatures).fill(0);
35
+
36
+ for (let i = 0; i < nSamples; i += 1) {
37
+ for (let j = 0; j < nFeatures; j += 1) {
38
+ means[j] += X[i][j];
39
+ }
40
+ }
41
+ for (let j = 0; j < nFeatures; j += 1) {
42
+ means[j] /= nSamples;
43
+ }
44
+ for (let i = 0; i < nSamples; i += 1) {
45
+ for (let j = 0; j < nFeatures; j += 1) {
46
+ const diff = X[i][j] - means[j];
47
+ variances[j] += diff * diff;
48
+ }
49
+ }
50
+ for (let j = 0; j < nFeatures; j += 1) {
51
+ variances[j] /= nSamples;
52
+ }
53
+
54
+ const selectedFeatureIndices: number[] = [];
55
+ for (let j = 0; j < nFeatures; j += 1) {
56
+ if (variances[j] > this.threshold) {
57
+ selectedFeatureIndices.push(j);
58
+ }
59
+ }
60
+ if (selectedFeatureIndices.length === 0) {
61
+ throw new Error("No feature in X meets the variance threshold.");
62
+ }
63
+
64
+ this.nFeaturesIn_ = nFeatures;
65
+ this.variances_ = variances;
66
+ this.selectedFeatureIndices_ = selectedFeatureIndices;
67
+ return this;
68
+ }
69
+
70
+ transform(X: Matrix): Matrix {
71
+ if (!this.selectedFeatureIndices_ || this.nFeaturesIn_ === null) {
72
+ throw new Error("VarianceThreshold has not been fitted.");
73
+ }
74
+ assertNonEmptyMatrix(X);
75
+ assertConsistentRowSize(X);
76
+ assertFiniteMatrix(X);
77
+
78
+ if (X[0].length !== this.nFeaturesIn_) {
79
+ throw new Error(`Feature size mismatch. Expected ${this.nFeaturesIn_}, got ${X[0].length}.`);
80
+ }
81
+
82
+ return X.map((row) => this.selectedFeatureIndices_!.map((featureIdx) => row[featureIdx]));
83
+ }
84
+
85
+ fitTransform(X: Matrix): Matrix {
86
+ return this.fit(X).transform(X);
87
+ }
88
+ }
package/src/index.ts CHANGED
@@ -1,15 +1,28 @@
1
1
  export * from "./types";
2
2
 
3
+ // Baselines
4
+ export * from "./dummy/DummyClassifier";
5
+ export * from "./dummy/DummyRegressor";
6
+
7
+ // Preprocessing
3
8
  export * from "./preprocessing/StandardScaler";
4
9
  export * from "./preprocessing/MinMaxScaler";
5
10
  export * from "./preprocessing/RobustScaler";
11
+ export * from "./preprocessing/MaxAbsScaler";
12
+ export * from "./preprocessing/Normalizer";
13
+ export * from "./preprocessing/Binarizer";
14
+ export * from "./preprocessing/LabelEncoder";
6
15
  export * from "./preprocessing/PolynomialFeatures";
7
16
  export * from "./preprocessing/SimpleImputer";
8
17
  export * from "./preprocessing/OneHotEncoder";
18
+
19
+ // Linear models
9
20
  export * from "./linear_model/LinearRegression";
10
21
  export * from "./linear_model/LogisticRegression";
11
22
  export * from "./linear_model/SGDClassifier";
12
23
  export * from "./linear_model/SGDRegressor";
24
+
25
+ // Other estimators
13
26
  export * from "./neighbors/KNeighborsClassifier";
14
27
  export * from "./naive_bayes/GaussianNB";
15
28
  export * from "./svm/LinearSVC";
@@ -17,6 +30,8 @@ export * from "./tree/DecisionTreeClassifier";
17
30
  export * from "./tree/DecisionTreeRegressor";
18
31
  export * from "./ensemble/RandomForestClassifier";
19
32
  export * from "./ensemble/RandomForestRegressor";
33
+
34
+ // Model selection
20
35
  export * from "./model_selection/trainTestSplit";
21
36
  export * from "./model_selection/KFold";
22
37
  export * from "./model_selection/StratifiedKFold";
@@ -25,8 +40,16 @@ export * from "./model_selection/RepeatedKFold";
25
40
  export * from "./model_selection/RepeatedStratifiedKFold";
26
41
  export * from "./model_selection/crossValScore";
27
42
  export * from "./model_selection/GridSearchCV";
43
+ export * from "./model_selection/RandomizedSearchCV";
44
+
45
+ // Feature selection
46
+ export * from "./feature_selection/VarianceThreshold";
47
+
48
+ // Composition
28
49
  export * from "./pipeline/Pipeline";
29
50
  export * from "./pipeline/ColumnTransformer";
30
51
  export * from "./pipeline/FeatureUnion";
52
+
53
+ // Metrics
31
54
  export * from "./metrics/regression";
32
55
  export * from "./metrics/classification";