datly 0.0.1 → 0.0.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,412 @@
1
+ import BaseModel from './baseModel.js';
2
+
3
/**
 * Binary decision tree for classification or regression.
 *
 * Every internal node tests `x[featureIdx] <= threshold`; splits are chosen
 * greedily by impurity reduction (classification) or variance reduction
 * (regression). `validateTrainingData` and `validatePredictionData` are
 * inherited from BaseModel, which is not visible in this file.
 */
class DecisionTree extends BaseModel {
  /**
   * @param {number} [maxDepth=10] - Depth at which a node is forced to be a leaf.
   * @param {number} [minSamplesSplit=2] - Nodes with fewer samples are never split.
   * @param {number} [minSamplesLeaf=1] - Minimum samples each side of a split must keep.
   * @param {string} [criterion='gini'] - 'gini' or 'entropy'. Only consulted for
   *   classification; regression always scores splits by variance reduction.
   */
  constructor(maxDepth = 10, minSamplesSplit = 2, minSamplesLeaf = 1, criterion = 'gini') {
    super();
    this.maxDepth = maxDepth;
    this.minSamplesSplit = minSamplesSplit;
    this.minSamplesLeaf = minSamplesLeaf;
    this.criterion = criterion; // 'gini', 'entropy' for classification; 'mse' for regression
    this.tree = null;
    this.taskType = null;
    this.classes = null;
  }

  /**
   * Returns a sorted copy of the given labels. All-numeric label sets are
   * ordered numerically; the default sort would order them as strings
   * (e.g. [1, 10, 2]).
   *
   * @param {Iterable<*>} labels
   * @returns {Array<*>}
   */
  sortClassLabels(labels) {
    const sorted = [...labels];
    if (sorted.every((label) => typeof label === 'number')) {
      return sorted.sort((a, b) => a - b);
    }
    return sorted.sort();
  }

  /**
   * Builds the tree from training data and records size metrics.
   *
   * @param {Array<Array<number>|number>} X - Samples; scalar rows are wrapped as one-feature rows.
   * @param {Array<*>} y - Targets (class labels or numbers), one per sample.
   * @param {string} [taskType='classification'] - Any value other than
   *   'classification' is treated as regression.
   * @returns {DecisionTree} this, for chaining.
   */
  fit(X, y, taskType = 'classification') {
    this.validateTrainingData(X, y);

    this.taskType = taskType;
    const X_train = X.map((row) => (Array.isArray(row) ? row : [row]));

    if (taskType === 'classification') {
      this.classes = this.sortClassLabels(new Set(y));
    }

    this.tree = this.buildTree(X_train, y, 0);
    this.trained = true;

    this.trainingMetrics = {
      treeDepth: this.getTreeDepth(this.tree),
      leafCount: this.getLeafCount(this.tree),
      nodeCount: this.getNodeCount(this.tree),
      taskType: this.taskType
    };

    return this;
  }

  /**
   * Recursively grows a subtree for the given samples.
   *
   * @param {Array<Array<number>>} X - Non-empty sample rows.
   * @param {Array<*>} y - Targets aligned with X.
   * @param {number} depth - Depth of the node being built (root is 0).
   * @returns {Object} A leaf node or an internal node `{featureIdx, threshold, left, right, isLeaf}`.
   */
  buildTree(X, y, depth) {
    const nSamples = X.length;
    const nFeatures = X[0].length;

    // Stopping criteria
    if (depth >= this.maxDepth || nSamples < this.minSamplesSplit || this.isPure(y)) {
      return this.createLeaf(y);
    }

    // Exhaustive search over every feature and every candidate threshold.
    let bestSplit = null;
    let bestScore = -Infinity;

    for (let featureIdx = 0; featureIdx < nFeatures; featureIdx++) {
      for (const threshold of this.getThresholds(X, featureIdx)) {
        const { left, right } = this.splitData(X, y, featureIdx, threshold);

        if (left.y.length < this.minSamplesLeaf || right.y.length < this.minSamplesLeaf) {
          continue;
        }

        const score = this.calculateSplitScore(y, left.y, right.y);
        if (score > bestScore) {
          bestScore = score;
          bestSplit = { featureIdx, threshold, left, right };
        }
      }
    }

    // No admissible split (e.g. all feature values identical).
    if (!bestSplit) {
      return this.createLeaf(y);
    }

    return {
      featureIdx: bestSplit.featureIdx,
      threshold: bestSplit.threshold,
      left: this.buildTree(bestSplit.left.X, bestSplit.left.y, depth + 1),
      right: this.buildTree(bestSplit.right.X, bestSplit.right.y, depth + 1),
      isLeaf: false
    };
  }

  /**
   * Candidate thresholds for one feature: midpoints between consecutive
   * distinct values, in ascending order.
   *
   * @param {Array<Array<number>>} X
   * @param {number} featureIdx
   * @returns {number[]}
   */
  getThresholds(X, featureIdx) {
    const values = [...new Set(X.map((row) => row[featureIdx]))].sort((a, b) => a - b);
    const thresholds = [];

    for (let i = 0; i < values.length - 1; i++) {
      thresholds.push((values[i] + values[i + 1]) / 2);
    }

    return thresholds;
  }

  /**
   * Partitions samples into `feature <= threshold` (left) and the rest (right).
   *
   * @returns {{left: {X: Array, y: Array}, right: {X: Array, y: Array}}}
   */
  splitData(X, y, featureIdx, threshold) {
    const left = { X: [], y: [] };
    const right = { X: [], y: [] };

    for (let i = 0; i < X.length; i++) {
      const side = X[i][featureIdx] <= threshold ? left : right;
      side.X.push(X[i]);
      side.y.push(y[i]);
    }

    return { left, right };
  }

  /** @returns {boolean} True when every target in `y` is the same value. */
  isPure(y) {
    return new Set(y).size === 1;
  }

  /**
   * Creates a leaf node: majority class (classification) or mean target
   * (regression).
   *
   * @param {Array<*>} y - Non-empty targets that reached this leaf.
   * @returns {Object} `{isLeaf, prediction, samples[, distribution]}`.
   */
  createLeaf(y) {
    if (this.taskType === 'classification') {
      const counts = {};
      // Object keys are always strings, so remember the original label for
      // each key; otherwise numeric labels would be predicted as strings.
      const labelByKey = new Map();
      y.forEach((label) => {
        const key = String(label);
        counts[key] = (counts[key] || 0) + 1;
        if (!labelByKey.has(key)) labelByKey.set(key, label);
      });

      // On ties the later key wins, matching the historical tie-break.
      const winningKey = Object.keys(counts).reduce((best, key) =>
        counts[best] > counts[key] ? best : key
      );

      return {
        isLeaf: true,
        prediction: labelByKey.get(winningKey),
        samples: y.length,
        distribution: counts
      };
    }

    const mean = y.reduce((sum, val) => sum + val, 0) / y.length;
    return {
      isLeaf: true,
      prediction: mean,
      samples: y.length
    };
  }

  /**
   * Scores a candidate split; larger is better.
   *
   * @returns {number} Impurity decrease (classification) or variance
   *   decrease (regression) from parent to the size-weighted children.
   */
  calculateSplitScore(parentY, leftY, rightY) {
    const measure =
      this.taskType === 'classification'
        ? (values) => this.calculateImpurity(values)
        : (values) => this.calculateVariance(values);

    const n = parentY.length;
    const weightedChildren =
      (leftY.length / n) * measure(leftY) + (rightY.length / n) * measure(rightY);

    return measure(parentY) - weightedChildren;
  }

  /**
   * Gini impurity or Shannon entropy (base 2) of a label list.
   *
   * @param {Array<*>} y
   * @returns {number}
   * @throws {Error} When `criterion` is neither 'gini' nor 'entropy'. Without
   *   this check the result was `undefined`, every split score became NaN,
   *   and the tree silently degenerated into a single leaf.
   */
  calculateImpurity(y) {
    const counts = {};
    y.forEach((label) => {
      counts[label] = (counts[label] || 0) + 1;
    });

    const n = y.length;
    const probabilities = Object.values(counts).map((count) => count / n);

    if (this.criterion === 'gini') {
      return 1 - probabilities.reduce((sum, p) => sum + p * p, 0);
    }
    if (this.criterion === 'entropy') {
      return -probabilities.reduce((sum, p) => sum + p * Math.log2(p), 0);
    }
    throw new Error(
      `Unsupported classification criterion '${this.criterion}'; expected 'gini' or 'entropy'`
    );
  }

  /** @returns {number} Population variance of `y`, or 0 for an empty list. */
  calculateVariance(y) {
    if (y.length === 0) return 0;
    const mean = y.reduce((sum, val) => sum + val, 0) / y.length;
    return y.reduce((sum, val) => sum + Math.pow(val - mean, 2), 0) / y.length;
  }

  /**
   * Predicts a single sample by walking from `node` down to a leaf.
   *
   * @param {Array<number>} x - Feature row.
   * @param {Object} [node=this.tree] - Subtree root to start from.
   * @returns {*} The leaf's prediction.
   */
  predictSingle(x, node = this.tree) {
    return this.findLeaf(x, node).prediction;
  }

  /**
   * Predicts every row of X.
   *
   * @param {Array<Array<number>|number>} X
   * @returns {Array<*>} Class labels (in their original type) or numbers.
   */
  predict(X) {
    this.validatePredictionData(X);

    const X_test = X.map((row) => (Array.isArray(row) ? row : [row]));
    return X_test.map((x) => this.predictSingle(x));
  }

  /**
   * Class-probability estimates taken from the training label distribution
   * of the leaf each sample falls into.
   *
   * @param {Array<Array<number>|number>} X
   * @returns {Array<Object>} One `{label: probability}` object per sample.
   * @throws {Error} If the model was not fitted for classification.
   */
  predictProba(X) {
    if (this.taskType !== 'classification') {
      throw new Error('predictProba is only available for classification tasks');
    }

    this.validatePredictionData(X);

    const X_test = X.map((row) => (Array.isArray(row) ? row : [row]));

    return X_test.map((x) => {
      const leaf = this.findLeaf(x);
      const probas = {};

      this.classes.forEach((cls) => {
        probas[cls] = (leaf.distribution[cls] || 0) / leaf.samples;
      });

      return probas;
    });
  }

  /**
   * Returns the leaf node reached by `x`, starting from `node`.
   *
   * @param {Array<number>} x
   * @param {Object} [node=this.tree]
   * @returns {Object} Leaf node.
   */
  findLeaf(x, node = this.tree) {
    let current = node;
    while (!current.isLeaf) {
      current = x[current.featureIdx] <= current.threshold ? current.left : current.right;
    }
    return current;
  }

  /**
   * Evaluates the model on labelled data.
   *
   * @param {Array<Array<number>|number>} X
   * @param {Array<*>} y - True targets.
   * @returns {Object} Classification: accuracy, confusionMatrix, classMetrics,
   *   predictions. Regression: r2Score, mse, rmse, mae, predictions, residuals.
   */
  score(X, y) {
    const predictions = this.predict(X);

    if (this.taskType === 'classification') {
      let correct = 0;
      for (let i = 0; i < y.length; i++) {
        if (predictions[i] === y[i]) correct++;
      }

      const cm = this.confusionMatrix(y, predictions);

      return {
        accuracy: correct / y.length,
        confusionMatrix: cm,
        classMetrics: this.calculateClassMetrics(cm),
        predictions: predictions
      };
    }

    const yMean = y.reduce((sum, val) => sum + val, 0) / y.length;
    const ssRes = predictions.reduce((sum, pred, i) => sum + Math.pow(y[i] - pred, 2), 0);
    const ssTot = y.reduce((sum, val) => sum + Math.pow(val - yMean, 2), 0);

    // Constant targets make ssTot zero; report 1 for a perfect fit and 0
    // otherwise instead of NaN / -Infinity.
    let r2;
    if (ssTot === 0) {
      r2 = ssRes === 0 ? 1 : 0;
    } else {
      r2 = 1 - ssRes / ssTot;
    }

    const mse = ssRes / y.length;
    const mae =
      predictions.reduce((sum, pred, i) => sum + Math.abs(y[i] - pred), 0) / y.length;

    return {
      r2Score: r2,
      mse: mse,
      rmse: Math.sqrt(mse),
      mae: mae,
      predictions: predictions,
      residuals: predictions.map((pred, i) => y[i] - pred)
    };
  }

  /**
   * Training classes followed by any labels in the given lists that were not
   * seen during training (in encounter order). Returns `this.classes` itself
   * when nothing new is found.
   *
   * @param {...Array<*>} labelLists
   * @returns {Array<*>}
   */
  collectClasses(...labelLists) {
    const known = new Set(this.classes);
    const unseen = [];
    for (const labels of labelLists) {
      for (const label of labels) {
        if (!known.has(label)) {
          known.add(label);
          unseen.push(label);
        }
      }
    }
    return unseen.length === 0 ? this.classes : [...this.classes, ...unseen];
  }

  /**
   * Builds a confusion matrix (rows: actual, columns: predicted).
   *
   * @param {Array<*>} yTrue
   * @param {Array<*>} yPred
   * @param {Array<*>} [classes] - Row/column labels. Defaults to the training
   *   classes extended with any unseen labels, so evaluation data containing
   *   a new class no longer crashes.
   * @returns {{matrix: number[][], classes: Array<*>, display: string}}
   * @throws {Error} If a label is missing from an explicitly supplied `classes`.
   */
  confusionMatrix(yTrue, yPred, classes = this.collectClasses(yTrue, yPred)) {
    const n = classes.length;
    const matrix = Array(n).fill(0).map(() => Array(n).fill(0));
    const indexOfClass = new Map(classes.map((cls, i) => [cls, i]));

    for (let i = 0; i < yTrue.length; i++) {
      const trueIdx = indexOfClass.get(yTrue[i]);
      const predIdx = indexOfClass.get(yPred[i]);
      if (trueIdx === undefined || predIdx === undefined) {
        throw new Error(`Label at index ${i} is not present in the supplied classes`);
      }
      matrix[trueIdx][predIdx]++;
    }

    return {
      matrix: matrix,
      classes: classes,
      display: this.formatConfusionMatrix(matrix, classes)
    };
  }

  /**
   * Renders a confusion matrix as text.
   *
   * NOTE(review): the first row is prefixed with 'Actual ' and later rows with
   * a single space, so rows do not line up; padding may have been lost
   * upstream. Confirm the intended layout before changing these strings.
   *
   * @param {number[][]} matrix
   * @param {Array<*>} [classes=this.classes] - Labels for rows and columns.
   * @returns {string}
   */
  formatConfusionMatrix(matrix, classes = this.classes) {
    const maxLen = Math.max(...matrix.flat().map((v) => v.toString().length), 8);
    const pad = (str) => str.toString().padStart(maxLen);

    let output = '\n' + ' '.repeat(maxLen + 2) + 'Predicted\n';
    output += ' '.repeat(maxLen + 2) + classes.map((c) => pad(c)).join(' ') + '\n';

    for (let i = 0; i < matrix.length; i++) {
      if (i === 0) output += 'Actual ';
      else output += ' ';
      output += pad(classes[i]) + ' ';
      output += matrix[i].map((v) => pad(v)).join(' ') + '\n';
    }

    return output;
  }

  /**
   * Per-class precision, recall, F1 and support from a confusion matrix.
   *
   * @param {{matrix: number[][], classes?: Array<*>}} cm
   * @param {Array<*>} [classes] - Defaults to `cm.classes`, then the training classes.
   * @returns {Object} `{label: {precision, recall, f1Score, support}}`.
   */
  calculateClassMetrics(cm, classes = cm.classes || this.classes) {
    const matrix = cm.matrix;
    const metrics = {};

    classes.forEach((cls, i) => {
      const tp = matrix[i][i];
      const fn = matrix[i].reduce((sum, val) => sum + val, 0) - tp;
      const fp = matrix.reduce((sum, row) => sum + row[i], 0) - tp;

      const precision = tp + fp > 0 ? tp / (tp + fp) : 0;
      const recall = tp + fn > 0 ? tp / (tp + fn) : 0;
      const f1 = precision + recall > 0 ? (2 * (precision * recall)) / (precision + recall) : 0;

      metrics[cls] = {
        precision: precision,
        recall: recall,
        f1Score: f1,
        support: tp + fn
      };
    });

    return metrics;
  }

  /** @returns {number} Number of edges on the longest root-to-leaf path. */
  getTreeDepth(node) {
    if (node.isLeaf) return 0;
    return 1 + Math.max(this.getTreeDepth(node.left), this.getTreeDepth(node.right));
  }

  /** @returns {number} Number of leaves under `node`. */
  getLeafCount(node) {
    if (node.isLeaf) return 1;
    return this.getLeafCount(node.left) + this.getLeafCount(node.right);
  }

  /** @returns {number} Number of nodes (internal and leaf) under `node`. */
  getNodeCount(node) {
    if (node.isLeaf) return 1;
    return 1 + this.getNodeCount(node.left) + this.getNodeCount(node.right);
  }

  /**
   * Feature importance as the share of internal nodes that split on each
   * feature. Features never used for a split are absent from the result.
   *
   * @returns {Object} `{feature_<idx>: fraction}`; empty for a single-leaf tree.
   */
  getFeatureImportance() {
    const importance = {};
    this.calculateImportance(this.tree, importance);

    const total = Object.values(importance).reduce((sum, val) => sum + val, 0);
    Object.keys(importance).forEach((key) => {
      importance[key] /= total;
    });

    return importance;
  }

  /**
   * Accumulates split counts per feature into `importance` (mutated in place).
   *
   * @param {Object} node
   * @param {Object} importance
   */
  calculateImportance(node, importance) {
    if (node.isLeaf) return;

    const featureName = `feature_${node.featureIdx}`;
    importance[featureName] = (importance[featureName] || 0) + 1;

    this.calculateImportance(node.left, importance);
    this.calculateImportance(node.right, importance);
  }

  /**
   * @returns {Object} Model description: type, task, training metrics,
   *   feature importance and hyperparameters.
   * @throws {Error} If the model has not been trained.
   */
  summary() {
    if (!this.trained) {
      throw new Error('Model must be trained first');
    }

    return {
      modelType: 'Decision Tree',
      taskType: this.taskType,
      trainingMetrics: this.trainingMetrics,
      featureImportance: this.getFeatureImportance(),
      hyperparameters: {
        maxDepth: this.maxDepth,
        minSamplesSplit: this.minSamplesSplit,
        minSamplesLeaf: this.minSamplesLeaf,
        criterion: this.criterion
      }
    };
  }
}
411
+
412
+ export default DecisionTree;
@@ -0,0 +1,317 @@
1
+ import BaseModel from './baseModel.js';
2
+
3
/**
 * K-nearest-neighbours model for classification or regression.
 *
 * Training stores the (optionally normalised) samples; prediction ranks all
 * stored samples by distance to the query. `validateTrainingData`,
 * `validatePredictionData` and `normalizeFeatures` are inherited from
 * BaseModel, which is not visible in this file.
 */
class KNearestNeighbors extends BaseModel {
  /**
   * @param {number} [k=5] - Number of neighbours consulted per prediction.
   * @param {string} [metric='euclidean'] - 'euclidean', 'manhattan' or
   *   'minkowski'; unknown values fall back to euclidean.
   * @param {string} [weights='uniform'] - 'uniform' for equal votes; any other
   *   value weights each neighbour by inverse distance.
   */
  constructor(k = 5, metric = 'euclidean', weights = 'uniform') {
    super();
    this.k = k;
    this.metric = metric; // 'euclidean', 'manhattan', 'minkowski'
    this.weights = weights; // 'uniform' or 'distance'
    this.X_train = null;
    this.y_train = null;
    this.normParams = null;
    this.taskType = null; // 'classification' or 'regression'
  }

  /** @returns {number} L2 distance between two equal-length vectors. */
  euclideanDistance(x1, x2) {
    return Math.sqrt(x1.reduce((sum, val, i) => sum + Math.pow(val - x2[i], 2), 0));
  }

  /** @returns {number} L1 distance between two equal-length vectors. */
  manhattanDistance(x1, x2) {
    return x1.reduce((sum, val, i) => sum + Math.abs(val - x2[i]), 0);
  }

  /**
   * @param {number[]} x1
   * @param {number[]} x2
   * @param {number} [p=3] - Order of the norm.
   * @returns {number} Minkowski distance of order `p`.
   */
  minkowskiDistance(x1, x2, p = 3) {
    const total = x1.reduce((sum, val, i) => sum + Math.pow(Math.abs(val - x2[i]), p), 0);
    return Math.pow(total, 1 / p);
  }

  /** @returns {number} Distance between `x1` and `x2` under the configured metric. */
  calculateDistance(x1, x2) {
    switch (this.metric) {
      case 'manhattan':
        return this.manhattanDistance(x1, x2);
      case 'minkowski':
        return this.minkowskiDistance(x1, x2);
      case 'euclidean':
      default:
        return this.euclideanDistance(x1, x2);
    }
  }

  /**
   * Stores the training set, optionally normalising features first.
   *
   * @param {Array<Array<number>|number>} X - Samples; scalar rows are wrapped as one-feature rows.
   * @param {Array<*>} y - Targets, one per sample (copied).
   * @param {boolean} [normalize=true] - Normalise via the inherited `normalizeFeatures`.
   * @param {string} [taskType='classification'] - Any value other than
   *   'classification' is treated as regression.
   * @returns {KNearestNeighbors} this, for chaining.
   */
  fit(X, y, normalize = true, taskType = 'classification') {
    this.validateTrainingData(X, y);

    this.taskType = taskType;
    const X_train = X.map((row) => (Array.isArray(row) ? row : [row]));

    if (normalize) {
      const { normalized, means, stds } = this.normalizeFeatures(X_train);
      this.X_train = normalized;
      this.normParams = { means, stds };
    } else {
      this.X_train = X_train;
      // Drop parameters from an earlier normalised fit; otherwise predict()
      // would keep rescaling inputs against obsolete statistics.
      this.normParams = null;
    }

    this.y_train = [...y];
    this.trained = true;

    this.trainingMetrics = {
      samples: this.X_train.length,
      features: this.X_train[0].length,
      taskType: this.taskType
    };

    return this;
  }

  /**
   * Wraps scalar rows and applies the stored normalisation, if any.
   *
   * NOTE(review): a zero standard deviation is treated as 1 to avoid dividing
   * by zero. This assumes `normalizeFeatures` handles constant columns the
   * same way during training; confirm against BaseModel.
   *
   * @param {Array<Array<number>|number>} X
   * @returns {Array<Array<number>>}
   */
  prepareInputs(X) {
    const rows = X.map((row) => (Array.isArray(row) ? row : [row]));
    if (!this.normParams) return rows;

    const { means, stds } = this.normParams;
    return rows.map((row) => row.map((val, j) => (val - means[j]) / (stds[j] || 1)));
  }

  /**
   * Returns a sorted copy of the given labels; all-numeric sets are ordered
   * numerically rather than as strings.
   *
   * @param {Iterable<*>} labels
   * @returns {Array<*>}
   */
  sortClassLabels(labels) {
    const sorted = [...labels];
    if (sorted.every((label) => typeof label === 'number')) {
      return sorted.sort((a, b) => a - b);
    }
    return sorted.sort();
  }

  /**
   * Finds the `k` stored samples closest to `x`.
   *
   * @param {number[]} x - Query row, already in the training feature space.
   * @returns {Array<{distance: number, index: number, label: *}>} Up to `k`
   *   neighbours, nearest first.
   */
  findKNearest(x) {
    const distances = this.X_train.map((trainPoint, idx) => ({
      distance: this.calculateDistance(x, trainPoint),
      index: idx,
      label: this.y_train[idx]
    }));

    distances.sort((a, b) => a.distance - b.distance);
    return distances.slice(0, this.k);
  }

  /**
   * Vote weight of a neighbour at the given distance: 1 for uniform
   * weighting, otherwise inverse distance (1e10 for an exact match).
   *
   * @param {number} distance
   * @returns {number}
   */
  neighborWeight(distance) {
    if (this.weights === 'uniform') return 1;
    return distance === 0 ? 1e10 : 1 / distance;
  }

  /**
   * Sums vote weights per class label.
   *
   * @param {Array<{distance: number, label: *}>} neighbors
   * @returns {{totals: Object, labelByKey: Map<string, *>, totalWeight: number}}
   *   `totals` is keyed by the stringified label; `labelByKey` maps each key
   *   back to the original label so its type is preserved.
   */
  tallyVotes(neighbors) {
    const totals = {};
    const labelByKey = new Map();
    let totalWeight = 0;

    neighbors.forEach((neighbor) => {
      const key = String(neighbor.label);
      const weight = this.neighborWeight(neighbor.distance);
      totals[key] = (totals[key] || 0) + weight;
      totalWeight += weight;
      if (!labelByKey.has(key)) labelByKey.set(key, neighbor.label);
    });

    return { totals, labelByKey, totalWeight };
  }

  /**
   * Predicts the class of one sample by (weighted) majority vote.
   *
   * @param {number[]} x
   * @returns {*} The winning label in its original type (numeric labels are
   *   no longer returned as strings).
   */
  predictSingleClassification(x) {
    const { totals, labelByKey } = this.tallyVotes(this.findKNearest(x));

    // On ties the later key wins, matching the historical tie-break.
    const winningKey = Object.keys(totals).reduce((best, key) =>
      totals[best] > totals[key] ? best : key
    );

    return labelByKey.get(winningKey);
  }

  /**
   * Predicts the target of one sample as the (weighted) mean of its neighbours.
   *
   * @param {number[]} x
   * @returns {number}
   */
  predictSingleRegression(x) {
    const neighbors = this.findKNearest(x);

    let weightedSum = 0;
    let totalWeight = 0;
    neighbors.forEach((neighbor) => {
      const weight = this.neighborWeight(neighbor.distance);
      weightedSum += neighbor.label * weight;
      totalWeight += weight;
    });

    return weightedSum / totalWeight;
  }

  /**
   * Predicts every row of X.
   *
   * @param {Array<Array<number>|number>} X
   * @returns {Array<*>} Class labels or numbers, depending on the task type.
   */
  predict(X) {
    this.validatePredictionData(X);

    const X_test = this.prepareInputs(X);

    if (this.taskType === 'classification') {
      return X_test.map((x) => this.predictSingleClassification(x));
    }
    return X_test.map((x) => this.predictSingleRegression(x));
  }

  /**
   * Class-probability estimates: each class's share of the neighbours' total
   * vote weight. Shares are normalised by the neighbours actually found, so
   * they sum to 1 even when `k` exceeds the number of training samples.
   *
   * @param {Array<Array<number>|number>} X
   * @returns {Array<Object>} One `{label: probability}` object per sample.
   * @throws {Error} If the model was not fitted for classification.
   */
  predictProba(X) {
    if (this.taskType !== 'classification') {
      throw new Error('predictProba is only available for classification tasks');
    }

    this.validatePredictionData(X);

    const X_test = this.prepareInputs(X);
    const classes = [...new Set(this.y_train)].sort();

    return X_test.map((x) => {
      const { totals, totalWeight } = this.tallyVotes(this.findKNearest(x));
      const probas = {};

      classes.forEach((cls) => {
        probas[cls] = 0;
      });

      if (totalWeight > 0) {
        Object.keys(totals).forEach((key) => {
          probas[key] = totals[key] / totalWeight;
        });
      }

      return probas;
    });
  }

  /**
   * Evaluates the model on labelled data.
   *
   * @param {Array<Array<number>|number>} X
   * @param {Array<*>} y - True targets.
   * @returns {Object} Classification: accuracy, confusionMatrix, classMetrics,
   *   predictions. Regression: r2Score, mse, rmse, mae, predictions, residuals.
   */
  score(X, y) {
    const predictions = this.predict(X);

    if (this.taskType === 'classification') {
      let correct = 0;
      for (let i = 0; i < y.length; i++) {
        if (predictions[i] === y[i]) correct++;
      }

      const classes = this.sortClassLabels(new Set([...y, ...predictions]));
      const cm = this.confusionMatrix(y, predictions, classes);

      return {
        accuracy: correct / y.length,
        confusionMatrix: cm,
        classMetrics: this.calculateClassMetrics(cm, classes),
        predictions: predictions
      };
    }

    // Regression metrics
    const yMean = y.reduce((sum, val) => sum + val, 0) / y.length;
    const ssRes = predictions.reduce((sum, pred, i) => sum + Math.pow(y[i] - pred, 2), 0);
    const ssTot = y.reduce((sum, val) => sum + Math.pow(val - yMean, 2), 0);

    // Constant targets make ssTot zero; report 1 for a perfect fit and 0
    // otherwise instead of NaN / -Infinity.
    let r2;
    if (ssTot === 0) {
      r2 = ssRes === 0 ? 1 : 0;
    } else {
      r2 = 1 - ssRes / ssTot;
    }

    const mse = ssRes / y.length;
    const mae =
      predictions.reduce((sum, pred, i) => sum + Math.abs(y[i] - pred), 0) / y.length;

    return {
      r2Score: r2,
      mse: mse,
      rmse: Math.sqrt(mse),
      mae: mae,
      predictions: predictions,
      residuals: predictions.map((pred, i) => y[i] - pred)
    };
  }

  /**
   * Builds a confusion matrix (rows: actual, columns: predicted).
   *
   * @param {Array<*>} yTrue
   * @param {Array<*>} yPred
   * @param {Array<*>} classes - Row/column labels.
   * @returns {{matrix: number[][], classes: Array<*>, display: string}}
   * @throws {Error} If a label is missing from `classes`.
   */
  confusionMatrix(yTrue, yPred, classes) {
    const n = classes.length;
    const matrix = Array(n).fill(0).map(() => Array(n).fill(0));
    const indexOfClass = new Map(classes.map((cls, i) => [cls, i]));

    for (let i = 0; i < yTrue.length; i++) {
      const trueIdx = indexOfClass.get(yTrue[i]);
      const predIdx = indexOfClass.get(yPred[i]);
      if (trueIdx === undefined || predIdx === undefined) {
        throw new Error(`Label at index ${i} is not present in the supplied classes`);
      }
      matrix[trueIdx][predIdx]++;
    }

    return {
      matrix: matrix,
      classes: classes,
      display: this.formatConfusionMatrix(matrix, classes)
    };
  }

  /**
   * Renders a confusion matrix as text.
   *
   * NOTE(review): the first row is prefixed with 'Actual ' and later rows with
   * a single space, so rows do not line up; padding may have been lost
   * upstream. Confirm the intended layout before changing these strings.
   *
   * @param {number[][]} matrix
   * @param {Array<*>} classes - Labels for rows and columns.
   * @returns {string}
   */
  formatConfusionMatrix(matrix, classes) {
    const maxLen = Math.max(...matrix.flat().map((v) => v.toString().length), 8);
    const pad = (str) => str.toString().padStart(maxLen);

    let output = '\n' + ' '.repeat(maxLen + 2) + 'Predicted\n';
    output += ' '.repeat(maxLen + 2) + classes.map((c) => pad(c)).join(' ') + '\n';

    for (let i = 0; i < matrix.length; i++) {
      if (i === 0) output += 'Actual ';
      else output += ' ';
      output += pad(classes[i]) + ' ';
      output += matrix[i].map((v) => pad(v)).join(' ') + '\n';
    }

    return output;
  }

  /**
   * Per-class precision, recall, F1 and support from a confusion matrix.
   *
   * @param {{matrix: number[][]}} cm
   * @param {Array<*>} classes - Labels in matrix order.
   * @returns {Object} `{label: {precision, recall, f1Score, support}}`.
   */
  calculateClassMetrics(cm, classes) {
    const matrix = cm.matrix;
    const metrics = {};

    classes.forEach((cls, i) => {
      const tp = matrix[i][i];
      const fn = matrix[i].reduce((sum, val) => sum + val, 0) - tp;
      const fp = matrix.reduce((sum, row) => sum + row[i], 0) - tp;

      const precision = tp + fp > 0 ? tp / (tp + fp) : 0;
      const recall = tp + fn > 0 ? tp / (tp + fn) : 0;
      const f1 = precision + recall > 0 ? (2 * (precision * recall)) / (precision + recall) : 0;

      metrics[cls] = {
        precision: precision,
        recall: recall,
        f1Score: f1,
        support: tp + fn
      };
    });

    return metrics;
  }

  /**
   * @returns {Object} Model description: type, task, training metrics and hyperparameters.
   * @throws {Error} If the model has not been trained.
   */
  summary() {
    if (!this.trained) {
      throw new Error('Model must be trained first');
    }

    return {
      modelType: 'K-Nearest Neighbors',
      taskType: this.taskType,
      trainingMetrics: this.trainingMetrics,
      hyperparameters: {
        k: this.k,
        metric: this.metric,
        weights: this.weights
      }
    };
  }
}
316
+
317
+ export default KNearestNeighbors;