bun-scikit 0.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. package/LICENSE +21 -0
  2. package/README.md +187 -0
  3. package/binding.gyp +21 -0
  4. package/docs/README.md +7 -0
  5. package/docs/native-abi.md +53 -0
  6. package/index.ts +1 -0
  7. package/package.json +76 -0
  8. package/scripts/build-node-addon.ts +26 -0
  9. package/scripts/build-zig-kernels.ts +50 -0
  10. package/scripts/check-api-docs-coverage.ts +52 -0
  11. package/scripts/check-benchmark-health.ts +140 -0
  12. package/scripts/install-native.ts +160 -0
  13. package/scripts/package-native-artifacts.ts +62 -0
  14. package/scripts/sync-benchmark-readme.ts +181 -0
  15. package/scripts/update-benchmark-history.ts +91 -0
  16. package/src/ensemble/RandomForestClassifier.ts +136 -0
  17. package/src/ensemble/RandomForestRegressor.ts +136 -0
  18. package/src/index.ts +32 -0
  19. package/src/linear_model/LinearRegression.ts +136 -0
  20. package/src/linear_model/LogisticRegression.ts +260 -0
  21. package/src/linear_model/SGDClassifier.ts +161 -0
  22. package/src/linear_model/SGDRegressor.ts +104 -0
  23. package/src/metrics/classification.ts +294 -0
  24. package/src/metrics/regression.ts +51 -0
  25. package/src/model_selection/GridSearchCV.ts +244 -0
  26. package/src/model_selection/KFold.ts +82 -0
  27. package/src/model_selection/RepeatedKFold.ts +49 -0
  28. package/src/model_selection/RepeatedStratifiedKFold.ts +50 -0
  29. package/src/model_selection/StratifiedKFold.ts +112 -0
  30. package/src/model_selection/StratifiedShuffleSplit.ts +211 -0
  31. package/src/model_selection/crossValScore.ts +165 -0
  32. package/src/model_selection/trainTestSplit.ts +82 -0
  33. package/src/naive_bayes/GaussianNB.ts +148 -0
  34. package/src/native/node-addon/bun_scikit_addon.cpp +450 -0
  35. package/src/native/zigKernels.ts +576 -0
  36. package/src/neighbors/KNeighborsClassifier.ts +85 -0
  37. package/src/pipeline/ColumnTransformer.ts +203 -0
  38. package/src/pipeline/FeatureUnion.ts +123 -0
  39. package/src/pipeline/Pipeline.ts +168 -0
  40. package/src/preprocessing/MinMaxScaler.ts +113 -0
  41. package/src/preprocessing/OneHotEncoder.ts +91 -0
  42. package/src/preprocessing/PolynomialFeatures.ts +158 -0
  43. package/src/preprocessing/RobustScaler.ts +149 -0
  44. package/src/preprocessing/SimpleImputer.ts +150 -0
  45. package/src/preprocessing/StandardScaler.ts +92 -0
  46. package/src/svm/LinearSVC.ts +117 -0
  47. package/src/tree/DecisionTreeClassifier.ts +394 -0
  48. package/src/tree/DecisionTreeRegressor.ts +407 -0
  49. package/src/types.ts +18 -0
  50. package/src/utils/linalg.ts +209 -0
  51. package/src/utils/validation.ts +78 -0
  52. package/zig/kernels.zig +1327 -0
@@ -0,0 +1,158 @@
1
+ import type { Matrix } from "../types";
2
+ import {
3
+ assertConsistentRowSize,
4
+ assertFiniteMatrix,
5
+ assertNonEmptyMatrix,
6
+ } from "../utils/validation";
7
+
8
/** Constructor options for the `PolynomialFeatures` transformer. */
export interface PolynomialFeaturesOptions {
  /** When true, a term never repeats a feature (no squares, cubes, ...). Defaults to `false`. */
  interactionOnly?: boolean;
  /** When true, the constant all-ones column is emitted first. Defaults to `true`. */
  includeBias?: boolean;
  /** Highest total degree of generated terms; must be an integer >= 0. Defaults to `2`. */
  degree?: number;
}
13
+
14
/** Exponent vector for one output term: entry `j` is the power applied to input feature `j`. */
type TermPower = Array<number>;
15
+
16
/**
 * Appends to `out` every way of extending `current` to length `degree` with
 * feature indices drawn from `[start, nFeatures)` in non-decreasing order.
 * Indices may repeat (e.g. `[0, 0]` encodes x0^2), and sequences are emitted
 * in lexicographic order.
 *
 * `current` is used as a scratch stack; every push is matched by a pop, so it
 * holds its original contents again on return. Callers are expected to pass
 * `current.length <= degree`.
 */
function generateCombinationsWithReplacement(
  nFeatures: number,
  degree: number,
  start: number,
  current: number[],
  out: number[][],
): void {
  if (current.length !== degree) {
    for (let next = start; next < nFeatures; next += 1) {
      current.push(next);
      // Same lower bound on recursion, so `next` may be chosen again.
      generateCombinationsWithReplacement(nFeatures, degree, next, current, out);
      current.pop();
    }
    return;
  }

  // Complete sequence: store a snapshot, since `current` keeps mutating.
  out.push([...current]);
}
40
+
41
/**
 * Appends to `out` every way of extending `current` to length `degree` with
 * strictly increasing feature indices drawn from `[start, nFeatures)`.
 * No index repeats, so each sequence encodes a pure interaction term
 * (e.g. `[0, 2]` encodes x0 * x2). Sequences are emitted in lexicographic order.
 *
 * `current` is used as a scratch stack and holds its original contents
 * again on return.
 */
function generateCombinationsWithoutReplacement(
  nFeatures: number,
  degree: number,
  start: number,
  current: number[],
  out: number[][],
): void {
  if (current.length !== degree) {
    for (let next = start; next < nFeatures; next += 1) {
      current.push(next);
      // Lower bound moves past `next`, so it cannot be chosen twice.
      generateCombinationsWithoutReplacement(nFeatures, degree, next + 1, current, out);
      current.pop();
    }
    return;
  }

  out.push([...current]);
}
65
+
66
/**
 * Expands each sample into polynomial and interaction terms of its features.
 *
 * `fit` records the input width and enumerates the output terms; it learns
 * nothing else from the data. Each output column is described by one
 * exponent vector in `powers_`: the column value is the product over `j` of
 * `row[j] ** powers_[t][j]`. Columns are ordered by total degree (the bias
 * column first when enabled). Within one degree they follow the order of the
 * combination generators (lexicographic over sorted feature-index tuples).
 */
export class PolynomialFeatures {
  /** Number of input columns seen by `fit`; `null` until fitted. */
  nFeaturesIn_: number | null = null;
  /** Number of columns `transform` produces; `null` until fitted. */
  nOutputFeatures_: number | null = null;
  /** One exponent vector (length `nFeaturesIn_`) per output column; `null` until fitted. */
  powers_: number[][] | null = null;

  private readonly degree: number;
  private readonly includeBias: boolean;
  private readonly interactionOnly: boolean;

  /**
   * @param options.degree Highest total degree to generate (default `2`).
   * @param options.includeBias Emit the all-ones column first (default `true`).
   * @param options.interactionOnly Never repeat a feature within a term (default `false`).
   * @throws Error if `degree` is not an integer >= 0, or if `degree` is 0 while
   *   `includeBias` is false (the transform would have no output columns).
   */
  constructor(options: PolynomialFeaturesOptions = {}) {
    this.degree = options.degree ?? 2;
    this.includeBias = options.includeBias ?? true;
    this.interactionOnly = options.interactionOnly ?? false;

    if (!Number.isInteger(this.degree) || this.degree < 0) {
      throw new Error(`degree must be an integer >= 0. Got ${this.degree}.`);
    }
    // Without the bias column, degree 0 yields no terms at all: every transformed row would be [].
    if (this.degree === 0 && !this.includeBias) {
      throw new Error(
        "degree = 0 with includeBias = false would produce zero output features.",
      );
    }
  }

  /**
   * Validates `X` and enumerates the output terms for its column count.
   *
   * @param X Non-empty, rectangular matrix of finite values.
   * @returns `this`, for chaining.
   */
  fit(X: Matrix): this {
    assertNonEmptyMatrix(X);
    assertConsistentRowSize(X);
    assertFiniteMatrix(X);

    const nFeatures = X[0].length;
    const powers: TermPower[] = [];

    if (this.includeBias) {
      powers.push(new Array<number>(nFeatures).fill(0));
    }

    const generate = this.interactionOnly
      ? generateCombinationsWithoutReplacement
      : generateCombinationsWithReplacement;

    for (let d = 1; d <= this.degree; d += 1) {
      const combinations: number[][] = [];
      generate(nFeatures, d, 0, [], combinations);

      // Turn each feature-index tuple (e.g. [0, 0, 2]) into an exponent vector (e.g. [2, 0, 1]).
      for (const combination of combinations) {
        const power = new Array<number>(nFeatures).fill(0);
        for (const featureIndex of combination) {
          power[featureIndex] += 1;
        }
        powers.push(power);
      }
    }

    this.nFeaturesIn_ = nFeatures;
    this.nOutputFeatures_ = powers.length;
    this.powers_ = powers;
    return this;
  }

  /**
   * Maps each row of `X` to its polynomial terms.
   *
   * @param X Non-empty, rectangular, finite matrix with `nFeaturesIn_` columns.
   * @returns A new matrix with `nOutputFeatures_` columns per row.
   * @throws Error if not fitted or if the column count differs from `fit`.
   */
  transform(X: Matrix): Matrix {
    const nFeaturesIn = this.nFeaturesIn_;
    const powers = this.powers_;
    const nOutputFeatures = this.nOutputFeatures_;
    if (nFeaturesIn === null || powers === null || nOutputFeatures === null) {
      throw new Error("PolynomialFeatures has not been fitted.");
    }

    assertNonEmptyMatrix(X);
    assertConsistentRowSize(X);
    assertFiniteMatrix(X);

    if (X[0].length !== nFeaturesIn) {
      throw new Error(`Feature size mismatch. Expected ${nFeaturesIn}, got ${X[0].length}.`);
    }

    return X.map((row) => {
      const outRow = new Array<number>(nOutputFeatures);
      for (let termIndex = 0; termIndex < powers.length; termIndex += 1) {
        const power = powers[termIndex];
        let value = 1;
        for (let featureIndex = 0; featureIndex < nFeaturesIn; featureIndex += 1) {
          const exponent = power[featureIndex];
          // Exponent 0 contributes a factor of 1; exponent 1 avoids the `**` call.
          if (exponent === 1) {
            value *= row[featureIndex];
          } else if (exponent > 1) {
            value *= row[featureIndex] ** exponent;
          }
        }
        outRow[termIndex] = value;
      }
      return outRow;
    });
  }

  /** Equivalent to `fit(X)` followed by `transform(X)`. */
  fitTransform(X: Matrix): Matrix {
    return this.fit(X).transform(X);
  }
}
@@ -0,0 +1,149 @@
1
+ import type { Matrix, Vector } from "../types";
2
+ import {
3
+ assertConsistentRowSize,
4
+ assertFiniteMatrix,
5
+ assertNonEmptyMatrix,
6
+ } from "../utils/validation";
7
+
8
/** Constructor options for the `RobustScaler` transformer. */
export interface RobustScalerOptions {
  /**
   * Percentile pair `[qMin, qMax]` whose spread is used as the per-feature
   * scale; must satisfy `0 <= qMin < qMax <= 100`. Defaults to `[25, 75]`.
   */
  quantileRange?: [number, number];
  /** Divide by the inter-quantile spread. Defaults to `true`. */
  withScaling?: boolean;
  /** Subtract the per-feature median. Defaults to `true`. */
  withCentering?: boolean;
}
13
+
14
/**
 * Linear-interpolation percentile of an already ascending-sorted array.
 *
 * The rank is `q / 100 * (n - 1)`; a fractional rank blends its two
 * neighbouring values. `q` values at or beyond the ends clamp to the
 * first or last element.
 *
 * @param sortedValues Values sorted ascending (not checked).
 * @param q Percentile in percent.
 * @throws Error if `sortedValues` is empty.
 */
function percentile(sortedValues: number[], q: number): number {
  const count = sortedValues.length;
  if (count === 0) {
    throw new Error("Cannot compute percentile of an empty array.");
  }
  if (q <= 0) return sortedValues[0];
  if (q >= 100) return sortedValues[count - 1];

  const rank = (q / 100) * (count - 1);
  const below = Math.floor(rank);
  const above = Math.ceil(rank);
  const fraction = rank - below;

  return below === above
    ? sortedValues[below]
    : sortedValues[below] * (1 - fraction) + sortedValues[above] * fraction;
}
35
+
36
/**
 * Scales features using statistics that are robust to outliers.
 *
 * Each feature is centered on its median and divided by the spread between
 * the two configured percentiles (the inter-quartile range by default).
 * Features whose spread is zero are left unscaled (scale 1).
 */
export class RobustScaler {
  /** Per-feature medians learned by `fit`; `null` until fitted. */
  center_: Vector | null = null;
  /** Per-feature percentile spreads (1 where the spread is zero); `null` until fitted. */
  scale_: Vector | null = null;

  private readonly withCentering: boolean;
  private readonly withScaling: boolean;
  private readonly quantileRange: [number, number];

  /**
   * @param options.withCentering Subtract the median (default `true`).
   * @param options.withScaling Divide by the percentile spread (default `true`).
   * @param options.quantileRange `[qMin, qMax]` percentiles (default `[25, 75]`).
   * @throws Error unless `0 <= qMin < qMax <= 100` with finite bounds.
   */
  constructor(options: RobustScalerOptions = {}) {
    this.withCentering = options.withCentering ?? true;
    this.withScaling = options.withScaling ?? true;

    const [qMin, qMax] = options.quantileRange ?? [25, 75];
    if (
      !Number.isFinite(qMin) ||
      !Number.isFinite(qMax) ||
      qMin < 0 ||
      qMax > 100 ||
      qMin >= qMax
    ) {
      throw new Error(
        `quantileRange must satisfy 0 <= qMin < qMax <= 100. Got [${qMin}, ${qMax}].`,
      );
    }
    // Store a private copy: keeping the caller's array would let later mutation bypass validation.
    this.quantileRange = [qMin, qMax];
  }

  /**
   * Learns per-feature medians and percentile spreads from `X`.
   *
   * @param X Non-empty, rectangular matrix of finite values.
   * @returns `this`, for chaining.
   */
  fit(X: Matrix): this {
    assertNonEmptyMatrix(X);
    assertConsistentRowSize(X);
    assertFiniteMatrix(X);

    const nFeatures = X[0].length;
    const centers = new Array<number>(nFeatures).fill(0);
    const scales = new Array<number>(nFeatures).fill(1);
    const [qMin, qMax] = this.quantileRange;

    for (let featureIndex = 0; featureIndex < nFeatures; featureIndex += 1) {
      const values = X.map((row) => row[featureIndex]).sort((a, b) => a - b);

      centers[featureIndex] = percentile(values, 50);
      const lower = percentile(values, qMin);
      const upper = percentile(values, qMax);
      const spread = upper - lower;
      // A zero spread (e.g. a constant column) would divide by zero; leave it unscaled.
      scales[featureIndex] = spread === 0 ? 1 : spread;
    }

    this.center_ = centers;
    this.scale_ = scales;
    return this;
  }

  /** Applies `(x - center) / scale` per feature, honoring the centering/scaling flags. */
  transform(X: Matrix): Matrix {
    const { center, scale } = this.checkFittedInput(X);
    return X.map((row) =>
      row.map((value, featureIndex) => {
        let out = value;
        if (this.withCentering) {
          out -= center[featureIndex];
        }
        if (this.withScaling) {
          out /= scale[featureIndex];
        }
        return out;
      }),
    );
  }

  /** Equivalent to `fit(X)` followed by `transform(X)`. */
  fitTransform(X: Matrix): Matrix {
    return this.fit(X).transform(X);
  }

  /** Undoes `transform`: `x * scale + center` per feature, honoring the same flags. */
  inverseTransform(X: Matrix): Matrix {
    const { center, scale } = this.checkFittedInput(X);
    return X.map((row) =>
      row.map((value, featureIndex) => {
        let out = value;
        if (this.withScaling) {
          out *= scale[featureIndex];
        }
        if (this.withCentering) {
          out += center[featureIndex];
        }
        return out;
      }),
    );
  }

  /**
   * Shared guard for `transform`/`inverseTransform`.
   *
   * @returns The fitted statistics, narrowed to non-null.
   * @throws Error if not fitted, if `X` is invalid, or if its width differs from `fit`.
   */
  private checkFittedInput(X: Matrix): { center: Vector; scale: Vector } {
    const center = this.center_;
    const scale = this.scale_;
    if (!center || !scale) {
      throw new Error("RobustScaler has not been fitted.");
    }

    assertNonEmptyMatrix(X);
    assertConsistentRowSize(X);
    assertFiniteMatrix(X);

    if (X[0].length !== center.length) {
      throw new Error(`Feature size mismatch. Expected ${center.length}, got ${X[0].length}.`);
    }
    return { center, scale };
  }
}
@@ -0,0 +1,150 @@
1
+ import type { Matrix, Vector } from "../types";
2
+ import { assertConsistentRowSize, assertNonEmptyMatrix } from "../utils/validation";
3
+
4
/** Statistic a `SimpleImputer` uses to fill missing (`NaN`) entries of a column. */
export type ImputerStrategy =
  | "constant"
  | "mean"
  | "median"
  | "most_frequent";
5
+
6
/** Constructor options for the `SimpleImputer` transformer. */
export interface SimpleImputerOptions {
  /** Replacement used by the `"constant"` strategy; must be finite. Defaults to `0`. */
  fillValue?: number;
  /** How each column's replacement for `NaN` entries is computed. Defaults to `"mean"`. */
  strategy?: ImputerStrategy;
}
10
+
11
/** Missing entries are encoded as `NaN`; every other number (including ±Infinity) is present. */
function isMissing(value: number): boolean {
  return Object.is(value, NaN);
}
14
+
15
/**
 * Checks that every entry of `X` is either finite or missing (`NaN`).
 *
 * @param X Matrix scanned row by row.
 * @param label Name used in the error message (defaults to `"X"`).
 * @throws Error naming the first offending `[row, column]` (e.g. an ±Infinity entry).
 */
function assertFiniteOrMissing(X: Matrix, label = "X"): void {
  for (const [i, row] of X.entries()) {
    for (const [j, value] of row.entries()) {
      const acceptable = Number.isFinite(value) || isMissing(value);
      if (!acceptable) {
        throw new Error(`${label} contains non-finite non-missing value at [${i}, ${j}].`);
      }
    }
  }
}
25
+
26
/**
 * Median of `values`; for an even count, the mean of the two middle values.
 * The input array is not modified (a sorted copy is used).
 *
 * @throws Error if `values` is empty.
 */
function median(values: number[]): number {
  if (values.length === 0) {
    throw new Error("Cannot compute median of an empty array.");
  }
  const ordered = [...values].sort((a, b) => a - b);
  const mid = ordered.length >>> 1;
  return ordered.length % 2 === 1
    ? ordered[mid]
    : 0.5 * (ordered[mid - 1] + ordered[mid]);
}
37
+
38
/**
 * Most common value in `values`; ties are broken in favor of the smallest value.
 *
 * @throws Error if `values` is empty.
 */
function mostFrequent(values: number[]): number {
  if (values.length === 0) {
    throw new Error("Cannot compute most frequent of an empty array.");
  }

  const tally = new Map<number, number>();
  for (const v of values) {
    tally.set(v, (tally.get(v) ?? 0) + 1);
  }

  let winner = values[0];
  let winnerCount = tally.get(winner)!;
  tally.forEach((count, candidate) => {
    const beatsWinner = count > winnerCount || (count === winnerCount && candidate < winner);
    if (beatsWinner) {
      winner = candidate;
      winnerCount = count;
    }
  });
  return winner;
}
57
+
58
/**
 * Replaces missing entries (encoded as `NaN`) with a per-column statistic
 * learned during `fit`.
 */
export class SimpleImputer {
  /** Per-column replacement values; `null` until fitted. */
  statistics_: Vector | null = null;
  private readonly strategy: ImputerStrategy;
  private readonly fillValue?: number;

  /**
   * @param options.strategy Statistic used for replacement (default `"mean"`).
   * @param options.fillValue Replacement for the `"constant"` strategy (default `0`).
   * @throws Error if `strategy` is not a supported value (reachable from untyped
   *   JavaScript callers), or if the constant fill value is not finite.
   */
  constructor(options: SimpleImputerOptions = {}) {
    this.strategy = options.strategy ?? "mean";
    this.fillValue = options.fillValue;

    // The type system cannot protect plain-JS callers; an unknown strategy would
    // otherwise leave holes in `statistics_` and make `transform` emit `undefined`.
    const supported: readonly string[] = ["mean", "median", "most_frequent", "constant"];
    if (!supported.includes(this.strategy)) {
      throw new Error(
        `strategy must be one of ${supported.join(", ")}. Got ${String(this.strategy)}.`,
      );
    }

    if (this.strategy === "constant") {
      const value = this.fillValue ?? 0;
      if (!Number.isFinite(value)) {
        throw new Error(`fillValue must be finite for constant strategy. Got ${value}.`);
      }
    }
  }

  /**
   * Computes one replacement value per column from the column's non-missing entries.
   *
   * @param X Non-empty, rectangular matrix whose entries are finite or `NaN`.
   * @returns `this`, for chaining.
   * @throws Error if a column is entirely missing and the strategy is not `"constant"`.
   */
  fit(X: Matrix): this {
    assertNonEmptyMatrix(X);
    assertConsistentRowSize(X);
    assertFiniteOrMissing(X);

    const nFeatures = X[0].length;
    const stats = new Array<number>(nFeatures);

    for (let featureIndex = 0; featureIndex < nFeatures; featureIndex += 1) {
      const observed = X.map((row) => row[featureIndex]).filter((value) => !isMissing(value));

      // Only the constant strategy can produce a value without observations.
      if (observed.length === 0 && this.strategy !== "constant") {
        throw new Error(
          `Feature at index ${featureIndex} has only missing values. Use strategy='constant' or provide observed values.`,
        );
      }
      stats[featureIndex] = this.computeStatistic(observed);
    }

    this.statistics_ = stats;
    return this;
  }

  /**
   * Returns a copy of `X` with each `NaN` replaced by its column's statistic.
   *
   * @throws Error if not fitted, if `X` is invalid, or if its width differs from `fit`.
   */
  transform(X: Matrix): Matrix {
    const statistics = this.statistics_;
    if (!statistics) {
      throw new Error("SimpleImputer has not been fitted.");
    }

    assertNonEmptyMatrix(X);
    assertConsistentRowSize(X);
    assertFiniteOrMissing(X);

    if (X[0].length !== statistics.length) {
      throw new Error(
        `Feature size mismatch. Expected ${statistics.length}, got ${X[0].length}.`,
      );
    }

    return X.map((row) =>
      row.map((value, featureIndex) => (isMissing(value) ? statistics[featureIndex] : value)),
    );
  }

  /** Equivalent to `fit(X)` followed by `transform(X)`. */
  fitTransform(X: Matrix): Matrix {
    return this.fit(X).transform(X);
  }

  /**
   * Reduces one column's observed (non-missing) values to its replacement.
   * `observed` is non-empty for every strategy except `"constant"`.
   */
  private computeStatistic(observed: number[]): number {
    switch (this.strategy) {
      case "mean":
        return observed.reduce((sum, value) => sum + value, 0) / observed.length;
      case "median":
        return median(observed);
      case "most_frequent":
        return mostFrequent(observed);
      case "constant":
        return this.fillValue ?? 0;
      default: {
        const unsupported: never = this.strategy;
        throw new Error(`Unsupported strategy: ${String(unsupported)}.`);
      }
    }
  }
}
@@ -0,0 +1,92 @@
1
+ import type { Matrix, Vector } from "../types";
2
+ import {
3
+ assertConsistentRowSize,
4
+ assertFiniteMatrix,
5
+ assertNonEmptyMatrix,
6
+ } from "../utils/validation";
7
+
8
/**
 * Standardizes features by removing the per-column mean and dividing by the
 * per-column population standard deviation (divisor `nSamples`).
 *
 * Columns that are constant are left unscaled (scale 1). This includes
 * columns whose computed variance is only floating-point rounding noise.
 */
export class StandardScaler {
  /** Per-column means learned by `fit`; `null` until fitted. */
  mean_: Vector | null = null;
  /** Per-column divisors learned by `fit` (1 for constant columns); `null` until fitted. */
  scale_: Vector | null = null;

  /**
   * Learns per-column means and standard deviations from `X`.
   *
   * @param X Non-empty, rectangular matrix of finite values.
   * @returns `this`, for chaining.
   */
  fit(X: Matrix): this {
    assertNonEmptyMatrix(X);
    assertConsistentRowSize(X);
    assertFiniteMatrix(X);

    const nSamples = X.length;
    const nFeatures = X[0].length;
    const means = new Array<number>(nFeatures).fill(0);
    const sumsOfSquares = new Array<number>(nFeatures).fill(0);

    for (const row of X) {
      for (let j = 0; j < nFeatures; j += 1) {
        means[j] += row[j];
      }
    }
    for (let j = 0; j < nFeatures; j += 1) {
      means[j] /= nSamples;
    }

    // Second pass over centered values (more accurate than E[x^2] - E[x]^2).
    for (const row of X) {
      for (let j = 0; j < nFeatures; j += 1) {
        const diff = row[j] - means[j];
        sumsOfSquares[j] += diff * diff;
      }
    }

    const scales = sumsOfSquares.map((sumOfSquares, j) => {
      const variance = sumOfSquares / nSamples;
      // e.g. [0.1, 0.1, 0.1] has mean 0.10000000000000002, so its variance is tiny but
      // non-zero. Dividing by that would map a constant column to about -1 instead of 0.
      return StandardScaler.isConstantFeature(variance, means[j], nSamples)
        ? 1
        : Math.sqrt(variance);
    });

    this.mean_ = means;
    this.scale_ = scales;
    return this;
  }

  /** Returns `(x - mean) / scale` per feature. */
  transform(X: Matrix): Matrix {
    const { mean, scale } = this.checkFittedInput(X);
    return X.map((row) =>
      row.map((value, featureIdx) => (value - mean[featureIdx]) / scale[featureIdx]),
    );
  }

  /** Equivalent to `fit(X)` followed by `transform(X)`. */
  fitTransform(X: Matrix): Matrix {
    return this.fit(X).transform(X);
  }

  /** Undoes `transform`: `x * scale + mean` per feature. */
  inverseTransform(X: Matrix): Matrix {
    const { mean, scale } = this.checkFittedInput(X);
    return X.map((row) =>
      row.map((value, featureIdx) => value * scale[featureIdx] + mean[featureIdx]),
    );
  }

  /**
   * Shared guard for `transform`/`inverseTransform`.
   *
   * @returns The fitted statistics, narrowed to non-null.
   * @throws Error if not fitted, if `X` is invalid, or if its width differs from `fit`.
   */
  private checkFittedInput(X: Matrix): { mean: Vector; scale: Vector } {
    const mean = this.mean_;
    const scale = this.scale_;
    if (!mean || !scale) {
      throw new Error("StandardScaler has not been fitted.");
    }

    assertNonEmptyMatrix(X);
    assertConsistentRowSize(X);
    assertFiniteMatrix(X);

    if (X[0].length !== mean.length) {
      throw new Error(
        `Feature size mismatch. Expected ${mean.length}, got ${X[0].length}.`,
      );
    }
    return { mean, scale };
  }

  /**
   * Whether a column's variance is indistinguishable from accumulated
   * floating-point rounding error. Uses the same upper bound as
   * scikit-learn's `_is_constant_feature`. Exactly zero variance always
   * qualifies.
   */
  private static isConstantFeature(variance: number, mean: number, nSamples: number): boolean {
    const eps = Number.EPSILON;
    const upperBound = nSamples * eps * variance + (nSamples * mean * eps) ** 2;
    return variance <= upperBound;
  }
}
@@ -0,0 +1,117 @@
1
+ import type { ClassificationModel, Matrix, Vector } from "../types";
2
+ import { accuracyScore } from "../metrics/classification";
3
+ import { dot } from "../utils/linalg";
4
+ import {
5
+ assertConsistentRowSize,
6
+ assertFiniteMatrix,
7
+ assertFiniteVector,
8
+ validateClassificationInputs,
9
+ } from "../utils/validation";
10
+
11
/** Constructor options for the `LinearSVC` classifier. */
export interface LinearSVCOptions {
  /** Stop once the largest absolute parameter update in an iteration falls below this. Defaults to `1e-6`. */
  tolerance?: number;
  /** Maximum number of full-batch gradient iterations. Defaults to `10_000`. */
  maxIter?: number;
  /** Gradient-descent step size. Defaults to `0.05`. */
  learningRate?: number;
  /** Weight of the hinge-loss term relative to the `0.5 * ||w||^2` penalty; must be > 0. Defaults to `1.0`. */
  C?: number;
  /** Whether to learn an intercept term. Defaults to `true`. */
  fitIntercept?: boolean;
}
18
+
19
/**
 * Binary linear support-vector classifier.
 *
 * Training runs full-batch (sub)gradient descent on the primal objective
 * `0.5 * ||w||^2 + C * sum_i max(0, 1 - y_i * (w . x_i + b))`, with each step
 * scaled by `learningRate / nSamples`. The intercept `b` is not regularized.
 *
 * Labels equal to `1` form the positive class; every other label is treated as
 * negative. `predict` returns `1` for a non-negative decision value, else `0`.
 */
export class LinearSVC implements ClassificationModel {
  /** Learned weight per feature. */
  coef_: Vector = [];
  /** Learned bias (stays 0 when `fitIntercept` is false). */
  intercept_ = 0;
  /** Output labels; fixed to `[0, 1]`. */
  classes_: Vector = [0, 1];

  private readonly fitIntercept: boolean;
  private readonly C: number;
  private readonly learningRate: number;
  private readonly maxIter: number;
  private readonly tolerance: number;
  private isFitted = false;

  /**
   * @throws Error if `C` or `learningRate` is not a finite positive number,
   *   `maxIter` is not a non-negative integer, or `tolerance` is not a finite
   *   non-negative number. Invalid values would otherwise silently produce NaN
   *   or divergent coefficients.
   */
  constructor(options: LinearSVCOptions = {}) {
    this.fitIntercept = options.fitIntercept ?? true;
    this.C = options.C ?? 1.0;
    this.learningRate = options.learningRate ?? 0.05;
    this.maxIter = options.maxIter ?? 10_000;
    this.tolerance = options.tolerance ?? 1e-6;

    if (!Number.isFinite(this.C) || this.C <= 0) {
      throw new Error(`C must be > 0. Got ${this.C}.`);
    }
    if (!Number.isFinite(this.learningRate) || this.learningRate <= 0) {
      throw new Error(`learningRate must be > 0. Got ${this.learningRate}.`);
    }
    if (!Number.isInteger(this.maxIter) || this.maxIter < 0) {
      throw new Error(`maxIter must be an integer >= 0. Got ${this.maxIter}.`);
    }
    if (!Number.isFinite(this.tolerance) || this.tolerance < 0) {
      throw new Error(`tolerance must be a finite number >= 0. Got ${this.tolerance}.`);
    }
  }

  /**
   * Trains the classifier from scratch (previous coefficients are discarded).
   *
   * @param X Training samples.
   * @param y Labels; `1` is positive, anything else negative.
   * @returns `this`, for chaining.
   */
  fit(X: Matrix, y: Vector): this {
    validateClassificationInputs(X, y);
    const nSamples = X.length;
    const nFeatures = X[0].length;

    this.coef_ = new Array<number>(nFeatures).fill(0);
    this.intercept_ = 0;
    const ySigned = y.map((value) => (value === 1 ? 1 : -1));

    for (let iter = 0; iter < this.maxIter; iter += 1) {
      // Gradient of 0.5 * ||w||^2 is w itself; hinge terms are added below.
      const gradients = this.coef_.slice();
      let interceptGradient = 0;

      for (let i = 0; i < nSamples; i += 1) {
        const margin = ySigned[i] * (dot(X[i], this.coef_) + this.intercept_);
        // Only samples inside the margin contribute a hinge subgradient.
        if (margin < 1) {
          const factor = -this.C * ySigned[i];
          for (let j = 0; j < nFeatures; j += 1) {
            gradients[j] += factor * X[i][j];
          }
          if (this.fitIntercept) {
            interceptGradient += factor;
          }
        }
      }

      let maxUpdate = 0;
      for (let j = 0; j < nFeatures; j += 1) {
        const delta = this.learningRate * (gradients[j] / nSamples);
        this.coef_[j] -= delta;
        maxUpdate = Math.max(maxUpdate, Math.abs(delta));
      }

      if (this.fitIntercept) {
        const interceptDelta = this.learningRate * (interceptGradient / nSamples);
        this.intercept_ -= interceptDelta;
        maxUpdate = Math.max(maxUpdate, Math.abs(interceptDelta));
      }

      if (maxUpdate < this.tolerance) {
        break;
      }
    }

    this.isFitted = true;
    return this;
  }

  /**
   * Signed distance-like score `w . x + b` for each row.
   *
   * @throws Error if not fitted, if `X` is empty, or if its width differs from training.
   */
  decisionFunction(X: Matrix): Vector {
    if (!this.isFitted) {
      throw new Error("LinearSVC has not been fitted.");
    }
    // Guard before reading X[0] so callers get a clear error instead of a TypeError.
    if (X.length === 0) {
      throw new Error("X must be a non-empty matrix.");
    }
    assertConsistentRowSize(X);
    assertFiniteMatrix(X);
    if (X[0].length !== this.coef_.length) {
      throw new Error(`Feature size mismatch. Expected ${this.coef_.length}, got ${X[0].length}.`);
    }
    return X.map((row) => dot(row, this.coef_) + this.intercept_);
  }

  /** Predicts `1` where the decision value is >= 0, otherwise `0`. */
  predict(X: Matrix): Vector {
    return this.decisionFunction(X).map((score) => (score >= 0 ? 1 : 0));
  }

  /** Mean accuracy of `predict(X)` against `y`. */
  score(X: Matrix, y: Vector): number {
    assertFiniteVector(y);
    return accuracyScore(y, this.predict(X));
  }
}