model-eval-toolkit 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evalreport/__init__.py +28 -0
- evalreport/__version__.py +2 -0
- evalreport/classification/__init__.py +4 -0
- evalreport/classification/report.py +319 -0
- evalreport/clustering/__init__.py +4 -0
- evalreport/clustering/report.py +174 -0
- evalreport/core/base_report.py +479 -0
- evalreport/core/entrypoints.py +97 -0
- evalreport/core/task_inference.py +180 -0
- evalreport/nlp/__init__.py +5 -0
- evalreport/nlp/text_classification.py +21 -0
- evalreport/nlp/text_generation.py +202 -0
- evalreport/ranking/__init__.py +3 -0
- evalreport/ranking/report.py +274 -0
- evalreport/regression/__init__.py +4 -0
- evalreport/regression/report.py +173 -0
- evalreport/timeseries/__init__.py +4 -0
- evalreport/timeseries/report.py +211 -0
- evalreport/vision/__init__.py +6 -0
- evalreport/vision/detection.py +359 -0
- evalreport/vision/image_classification.py +25 -0
- evalreport/vision/segmentation.py +140 -0
- model_eval_toolkit-0.1.0.dist-info/METADATA +339 -0
- model_eval_toolkit-0.1.0.dist-info/RECORD +27 -0
- model_eval_toolkit-0.1.0.dist-info/WHEEL +5 -0
- model_eval_toolkit-0.1.0.dist-info/licenses/LICENSE +21 -0
- model_eval_toolkit-0.1.0.dist-info/top_level.txt +1 -0
evalreport/__init__.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
"""Public API for the ``evalreport`` package.

Re-exports the ``generate_report`` entrypoint and every task-specific
report class so callers can reach them directly from the package root.
"""

from .core.entrypoints import generate_report
from .classification.report import ClassificationReport
from .regression.report import RegressionReport
from .__version__ import __version__
from .clustering.report import ClusteringReport
from .timeseries.report import TimeSeriesReport
from .nlp.text_classification import TextClassificationReport
from .nlp.text_generation import TextGenerationReport
from .vision.segmentation import SegmentationReport
from .vision.detection import DetectionReport
from .vision.image_classification import ImageClassificationReport
from .ranking.report import RankingReport

# Explicit public surface: one entrypoint, the version string, and the
# per-task report classes imported above.
__all__ = [
    "generate_report",
    "__version__",
    "ClassificationReport",
    "RegressionReport",
    "ClusteringReport",
    "TimeSeriesReport",
    "TextClassificationReport",
    "TextGenerationReport",
    "SegmentationReport",
    "DetectionReport",
    "ImageClassificationReport",
    "RankingReport",
]
|
|
28
|
+
|
|
@@ -0,0 +1,319 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any, Iterable, List, Optional, Sequence
|
|
6
|
+
|
|
7
|
+
import matplotlib
|
|
8
|
+
matplotlib.use("Agg")
|
|
9
|
+
import matplotlib.pyplot as plt
|
|
10
|
+
import seaborn as sns
|
|
11
|
+
import numpy as np
|
|
12
|
+
from sklearn.metrics import (
|
|
13
|
+
accuracy_score,
|
|
14
|
+
cohen_kappa_score,
|
|
15
|
+
confusion_matrix,
|
|
16
|
+
f1_score,
|
|
17
|
+
log_loss,
|
|
18
|
+
matthews_corrcoef,
|
|
19
|
+
precision_score,
|
|
20
|
+
recall_score,
|
|
21
|
+
roc_auc_score,
|
|
22
|
+
average_precision_score,
|
|
23
|
+
roc_curve,
|
|
24
|
+
precision_recall_curve,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
from ..core.base_report import BaseReport
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _as_array(x: Optional[Iterable[Any]]) -> Optional[np.ndarray]:
    """Materialize *x* into a numpy array, passing ``None`` through unchanged."""
    return None if x is None else np.asarray(list(x))
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _safe_float(x: Any) -> Any:
    """Coerce *x* into a JSON-safe numeric value.

    numpy scalars are unwrapped to native Python numbers, non-finite floats
    (NaN/inf) become ``None`` so the metrics dict stays serializable, and
    anything that cannot be converted to ``float`` is returned unchanged.

    Bug fix: previously ``np.floating`` values returned ``x.item()`` before
    the NaN/inf check, so ``np.float64("nan")`` leaked a raw NaN into the
    metrics while a plain ``float("nan")`` was correctly mapped to ``None``.
    Unwrapping first and checking afterwards closes that gap; ``np.integer``
    keeps its original int-returning behavior.
    """
    try:
        if isinstance(x, np.integer):
            # Preserve original behavior: numpy ints come back as Python ints.
            return x.item()
        if isinstance(x, np.floating):
            # Unwrap to a Python float, then fall through to the finiteness check.
            x = x.item()
        if isinstance(x, float) and (np.isnan(x) or np.isinf(x)):
            return None
        return float(x)
    except Exception:
        # Non-numeric payloads (e.g. confusion-matrix lists) pass through as-is.
        return x
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclass
class ClassificationReport(BaseReport):
    """Evaluation report for binary or multiclass classification.

    ``_compute_metrics`` fills ``self.metrics`` with accuracy, P/R/F1 under
    micro/macro/weighted averaging, MCC, Cohen's kappa, best-effort
    probabilistic metrics (log loss, ROC AUC, PR AUC) when ``y_prob`` is
    given, and a JSON-serializable confusion matrix.  ``_generate_plots``
    saves confusion-matrix / ROC / PR figures under
    ``<output_dir>/evalreport_plots`` and records their paths in
    ``self.plots``.  ``_generate_insights`` adds heuristic text notes.

    NOTE(review): ``metrics``, ``plots``, ``insights``, ``metric_descriptions``
    and ``output_dir`` are read/written here but defined elsewhere —
    presumably on ``BaseReport``; confirm against core/base_report.py.
    """

    # Ground-truth labels.
    y_true: Optional[Iterable[Any]] = None
    # Predicted labels, same label space as y_true.
    y_pred: Optional[Iterable[Any]] = None
    # Predicted scores: shape (n,) for binary, (n, k) for k classes — TODO confirm
    # callers follow this convention.
    y_prob: Optional[Iterable[Any]] = None
    # Explicit class ordering; when given, y_prob columns are assumed to follow it.
    labels: Optional[Sequence[Any]] = None

    def _compute_metrics(self) -> None:
        """Populate ``self.metrics``.

        Raises:
            ValueError: if either ``y_true`` or ``y_pred`` is missing.
        """
        y_true = _as_array(self.y_true)
        y_pred = _as_array(self.y_pred)
        if y_true is None or y_pred is None:
            raise ValueError("ClassificationReport requires y_true and y_pred.")

        average_modes = ["micro", "macro", "weighted"]
        self.metrics["accuracy"] = _safe_float(accuracy_score(y_true, y_pred))

        # zero_division=0 keeps sklearn from warning/erroring on classes with
        # no predicted (or true) samples.
        for avg in average_modes:
            self.metrics[f"precision_{avg}"] = _safe_float(
                precision_score(y_true, y_pred, average=avg, zero_division=0)
            )
            self.metrics[f"recall_{avg}"] = _safe_float(
                recall_score(y_true, y_pred, average=avg, zero_division=0)
            )
            self.metrics[f"f1_{avg}"] = _safe_float(
                f1_score(y_true, y_pred, average=avg, zero_division=0)
            )

        # Extras
        self.metrics["mcc"] = _safe_float(matthews_corrcoef(y_true, y_pred))
        try:
            self.metrics["cohen_kappa"] = _safe_float(cohen_kappa_score(y_true, y_pred))
        except Exception:
            self.metrics["cohen_kappa"] = None

        # Probabilistic metrics (best-effort; may be None)
        y_prob = _as_array(self.y_prob)
        if y_prob is not None:
            try:
                self.metrics["log_loss"] = _safe_float(log_loss(y_true, y_prob, labels=self.labels))
            except Exception:
                self.metrics["log_loss"] = None

            # ROC/PR AUC (binary or multiclass if possible)
            try:
                # Binary case: accept shape (n,) or (n,2) and use positive class scores
                if y_prob.ndim == 1:
                    y_score = y_prob
                elif y_prob.ndim == 2 and y_prob.shape[1] == 2:
                    y_score = y_prob[:, 1]
                else:
                    y_score = y_prob

                # multi_class="ovr" only when y_score is still 2-D (true
                # multiclass); "raise" otherwise so bad inputs fall into the
                # except and yield None instead of a wrong number.
                self.metrics["roc_auc"] = _safe_float(
                    roc_auc_score(y_true, y_score, multi_class="ovr" if getattr(y_score, "ndim", 1) == 2 else "raise")
                )
            except Exception:
                self.metrics["roc_auc"] = None

            try:
                # Same score-shape normalization as for ROC AUC above.
                if y_prob.ndim == 1:
                    y_score = y_prob
                elif y_prob.ndim == 2 and y_prob.shape[1] == 2:
                    y_score = y_prob[:, 1]
                else:
                    y_score = y_prob

                # average_precision_score supports binary directly; multiclass handled as macro if possible
                # NOTE(review): both branches of the ternary below yield "macro",
                # so the conditional is redundant as written.
                self.metrics["pr_auc"] = _safe_float(
                    average_precision_score(
                        y_true,
                        y_score,
                        average="macro" if getattr(y_score, "ndim", 1) == 2 else "macro",
                    )
                )
            except Exception:
                self.metrics["pr_auc"] = None

        # Confusion matrix as a small, JSON-serializable payload
        try:
            cm = confusion_matrix(y_true, y_pred, labels=self.labels)
            self.metrics["confusion_matrix"] = cm.tolist()
        except Exception:
            self.metrics["confusion_matrix"] = None

    def _generate_plots(self) -> None:
        """Render plots to PNG files and record their paths in ``self.plots``.

        Each plot is wrapped in a broad try/except so a single rendering
        failure (e.g. odd label types) degrades to "no plot" rather than
        failing the whole report.
        """
        y_true = _as_array(self.y_true)
        y_pred = _as_array(self.y_pred)
        if y_true is None or y_pred is None:
            self.plots = {}
            return

        # Default the output root when the caller did not configure one.
        root = self.output_dir or Path("reports")
        plot_dir = root / "evalreport_plots"
        plot_dir.mkdir(parents=True, exist_ok=True)
        plots: dict[str, str] = {}

        # Confusion matrix heatmap
        try:
            labels = self.labels
            if labels is None:
                # Derive the label set from the union of truth and predictions.
                labels = list(np.unique(np.concatenate([y_true, y_pred])))
            cm = confusion_matrix(y_true, y_pred, labels=labels)
            plt.figure(figsize=(4, 3))
            sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=labels, yticklabels=labels)
            plt.xlabel("Predicted")
            plt.ylabel("True")
            plt.title("Confusion Matrix")
            path = plot_dir / "classification_confusion_matrix.png"
            plt.tight_layout()
            plt.savefig(path)
            plt.close()
            plots["confusion_matrix"] = str(path)
        except Exception:
            pass

        # ROC and PR curves with probabilities
        y_prob = _as_array(self.y_prob)
        if y_prob is not None:
            try:
                y_prob_arr = y_prob
                n_classes = len(np.unique(y_true))
                # Determine class ordering for column mapping:
                # - if `labels` provided, assume columns follow that order
                # - else assume columns follow sorted unique labels from y_true
                class_order = list(self.labels) if self.labels is not None else sorted(np.unique(y_true).tolist())

                # handle probability input shape
                y_score_for_binary = None
                if y_prob_arr.ndim == 1:
                    # binary: P(positive class)
                    y_score_for_binary = y_prob_arr
                elif y_prob_arr.ndim == 2:
                    # For binary, accept both (n,2) and (n,) variants.
                    if n_classes == 2 and y_prob_arr.shape[1] == 2:
                        # Use column 1 as "positive" by convention
                        y_score_for_binary = y_prob_arr[:, 1]

                # Binary case ------------------------------------------------
                if n_classes == 2 and y_score_for_binary is not None:
                    # ROC
                    fpr, tpr, _ = roc_curve(y_true, y_score_for_binary)
                    plt.figure(figsize=(4, 3))
                    plt.plot(fpr, tpr, label="ROC curve")
                    plt.plot([0, 1], [0, 1], "k--", label="Random")
                    plt.xlabel("False Positive Rate")
                    plt.ylabel("True Positive Rate")
                    plt.title("ROC Curve (binary)")
                    plt.legend()
                    path = plot_dir / "classification_roc_curve.png"
                    plt.tight_layout()
                    plt.savefig(path)
                    plt.close()
                    plots["roc_curve"] = str(path)

                    # PR
                    prec, rec, _ = precision_recall_curve(y_true, y_score_for_binary)
                    plt.figure(figsize=(4, 3))
                    plt.plot(rec, prec, label="PR curve")
                    plt.xlabel("Recall")
                    plt.ylabel("Precision")
                    plt.title("Precision-Recall Curve (binary)")
                    plt.legend()
                    path = plot_dir / "classification_pr_curve.png"
                    plt.tight_layout()
                    plt.savefig(path)
                    plt.close()
                    plots["pr_curve"] = str(path)

                # Multiclass case ------------------------------------------
                if n_classes > 2 and y_prob_arr.ndim == 2:
                    # Only proceed if columns match the class order we will use.
                    if y_prob_arr.shape[1] == len(class_order):
                        # One-vs-rest ROC
                        plt.figure(figsize=(5, 4))
                        for col_idx, cls in enumerate(class_order):
                            # Binarize truth for this class, score from its column.
                            y_bin = (y_true == cls).astype(int)
                            fpr, tpr, _ = roc_curve(y_bin, y_prob_arr[:, col_idx])
                            plt.plot(fpr, tpr, linewidth=2, label=str(cls))
                        plt.plot([0, 1], [0, 1], "k--", linewidth=1, label="Random")
                        plt.xlabel("False Positive Rate")
                        plt.ylabel("True Positive Rate")
                        plt.title("ROC Curve (multiclass, one-vs-rest)")
                        plt.legend(title="Class", fontsize=8)
                        path = plot_dir / "classification_roc_curve_multiclass.png"
                        plt.tight_layout()
                        plt.savefig(path)
                        plt.close()
                        plots["roc_curve_multiclass"] = str(path)

                        # One-vs-rest PR
                        plt.figure(figsize=(5, 4))
                        for col_idx, cls in enumerate(class_order):
                            y_bin = (y_true == cls).astype(int)
                            prec, rec, _ = precision_recall_curve(y_bin, y_prob_arr[:, col_idx])
                            plt.plot(rec, prec, linewidth=2, label=str(cls))
                        plt.xlabel("Recall")
                        plt.ylabel("Precision")
                        plt.title("Precision-Recall Curve (multiclass, one-vs-rest)")
                        plt.legend(title="Class", fontsize=8)
                        path = plot_dir / "classification_pr_curve_multiclass.png"
                        plt.tight_layout()
                        plt.savefig(path)
                        plt.close()
                        plots["pr_curve_multiclass"] = str(path)
            except Exception:
                pass

        self.plots = plots

    def _generate_insights(self) -> None:
        """Append heuristic text notes to ``self.insights`` and fill
        ``self.metric_descriptions``.

        Two heuristics: a class-imbalance warning when the majority/minority
        count ratio is >= 5, and the single largest off-diagonal confusion.
        """
        y_true = _as_array(self.y_true)
        y_pred = _as_array(self.y_pred)
        if y_true is None or y_pred is None:
            return

        insights: List[str] = []

        # Class imbalance detection (simple heuristic)
        try:
            values, counts = np.unique(y_true, return_counts=True)
            if len(counts) > 1:
                # max(1, ...) guards against a zero minority count.
                ratio = counts.max() / max(1, counts.min())
                if ratio >= 5:
                    minority = values[np.argmin(counts)]
                    majority = values[np.argmax(counts)]
                    insights.append(
                        f"Class imbalance detected (majority={majority!r}, minority={minority!r}, ratio≈{ratio:.1f})."
                    )
        except Exception:
            pass

        # Misclassification trends: top confusions off-diagonal
        try:
            labels = self.labels
            if labels is None:
                labels = list(np.unique(np.concatenate([y_true, y_pred])))
            cm = confusion_matrix(y_true, y_pred, labels=labels)
            # Zero the diagonal so argmax finds the biggest *mistake* cell.
            cm_off = cm.copy()
            np.fill_diagonal(cm_off, 0)
            if cm_off.sum() > 0:
                i, j = np.unravel_index(np.argmax(cm_off), cm_off.shape)
                if cm_off[i, j] > 0:
                    insights.append(
                        f"Most common confusion: true={labels[i]!r} predicted={labels[j]!r} ({int(cm_off[i, j])} samples)."
                    )
        except Exception:
            pass

        self.insights = insights

        # Descriptions for key metrics shown in HTML/PDF
        self.metric_descriptions.update(
            {
                "accuracy": "Overall fraction of correct predictions.",
                "precision_micro": "Precision aggregated over all classes (micro-average).",
                "recall_micro": "Recall aggregated over all classes (micro-average).",
                "f1_micro": "F1 score aggregated over all classes (micro-average).",
                "precision_macro": "Unweighted mean of per-class precision.",
                "recall_macro": "Unweighted mean of per-class recall.",
                "f1_macro": "Unweighted mean of per-class F1 score.",
                "precision_weighted": "Precision averaged over classes, weighted by support.",
                "recall_weighted": "Recall averaged over classes, weighted by support.",
                "f1_weighted": "F1 averaged over classes, weighted by support.",
                "mcc": "Matthews correlation coefficient; balanced measure even under class imbalance.",
                "cohen_kappa": "Cohen’s kappa; agreement between predictions and truth beyond chance.",
                "log_loss": "Logarithmic loss; lower values indicate better calibrated probabilities.",
                "roc_auc": "Area under the ROC curve; trade-off between TPR and FPR.",
                "pr_auc": "Area under the precision–recall curve; useful for imbalanced data.",
                "confusion_matrix": "Counts of predictions vs true labels for each class pair.",
            }
        )
|
|
319
|
+
|
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any, Iterable, List, Optional, Sequence
|
|
6
|
+
|
|
7
|
+
import matplotlib
|
|
8
|
+
|
|
9
|
+
matplotlib.use("Agg")
|
|
10
|
+
import matplotlib.pyplot as plt
|
|
11
|
+
import numpy as np
|
|
12
|
+
import seaborn as sns
|
|
13
|
+
from sklearn.cluster import KMeans
|
|
14
|
+
from sklearn.decomposition import PCA
|
|
15
|
+
from sklearn.metrics import (
|
|
16
|
+
calinski_harabasz_score,
|
|
17
|
+
davies_bouldin_score,
|
|
18
|
+
silhouette_score,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
from ..core.base_report import BaseReport
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _as_2d_array(x: Optional[Iterable[Any]]) -> Optional[np.ndarray]:
    """Materialize *x* as an array with at least two dimensions.

    ``None`` passes through; a 1-D input becomes a single feature column.
    """
    if x is None:
        return None
    data = np.asarray(list(x))
    return data.reshape(-1, 1) if data.ndim == 1 else data
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _as_array(x: Optional[Iterable[Any]]) -> Optional[np.ndarray]:
    """Materialize *x* into a numpy array, passing ``None`` through unchanged."""
    if x is None:
        return None
    materialized = list(x)
    return np.asarray(materialized)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@dataclass
class ClusteringReport(BaseReport):
    """Evaluation report for clustering assignments.

    Computes internal validity indices (silhouette, Davies–Bouldin,
    Calinski–Harabasz) over feature matrix ``X`` and the given cluster
    ``labels``, plus a cluster-size distribution; renders a PCA scatter and
    a size bar chart; and emits heuristic text insights.

    NOTE(review): ``metrics``, ``plots``, ``insights``, ``metric_descriptions``
    and ``output_dir`` are read/written here but defined elsewhere —
    presumably on ``BaseReport``; confirm against core/base_report.py.
    """

    # Feature matrix; 1-D input is reshaped to a single column by _as_2d_array.
    X: Optional[Iterable[Any]] = None
    # Cluster assignments (one label per row of X). Unlike ClassificationReport,
    # `labels` here means assignments, not a class ordering.
    labels: Optional[Iterable[Any]] = None
    # When provided, we will fit a basic KMeans model to derive cluster centers
    # for plotting/diagnostics.
    # NOTE(review): n_clusters (and the KMeans import) is never used in this
    # chunk — the described fitting appears unimplemented; confirm.
    n_clusters: Optional[int] = None
    # Seed forwarded to PCA for reproducible projections.
    random_state: int = 0

    def _compute_metrics(self) -> None:
        """Populate ``self.metrics``.

        Raises:
            ValueError: if either ``X`` or ``labels`` is missing.

        With fewer than two distinct clusters the validity indices are
        undefined, so they are recorded as ``None``.
        """
        X = _as_2d_array(self.X)
        labels = _as_array(self.labels)
        if X is None or labels is None:
            raise ValueError("ClusteringReport requires X and labels (cluster assignments).")

        unique = np.unique(labels)
        if unique.size < 2:
            self.metrics.update(
                {
                    "silhouette_score": None,
                    "davies_bouldin_index": None,
                    "calinski_harabasz_score": None,
                }
            )
        else:
            self.metrics["silhouette_score"] = float(silhouette_score(X, labels))
            self.metrics["davies_bouldin_index"] = float(davies_bouldin_score(X, labels))
            self.metrics["calinski_harabasz_score"] = float(calinski_harabasz_score(X, labels))

        self.metrics["num_clusters"] = int(unique.size)
        # cluster_size distribution as counts
        vals, counts = np.unique(labels, return_counts=True)
        self.metrics["cluster_sizes"] = {str(v): int(c) for v, c in zip(vals, counts)}

    def _generate_plots(self) -> None:
        """Render plots to PNG files and record their paths in ``self.plots``.

        Each plot is wrapped in try/except so a rendering failure degrades to
        "no plot" instead of failing the report.
        """
        X = _as_2d_array(self.X)
        labels = _as_array(self.labels)
        if X is None or labels is None:
            self.plots = {}
            return

        # Default the output root when the caller did not configure one.
        root = self.output_dir or Path("reports")
        plot_dir = root / "evalreport_plots"
        plot_dir.mkdir(parents=True, exist_ok=True)

        plots: dict[str, str] = {}

        # Scatter (PCA projection to 2D)
        try:
            # Fall back to 1 component when data is too narrow/short for 2;
            # also cap by n_features and n_samples-1 (PCA's hard limits).
            pca_dim = 2 if X.shape[1] >= 2 and X.shape[0] >= 3 else 1
            pca_dim = min(pca_dim, X.shape[1], max(1, X.shape[0] - 1))
            pca = PCA(n_components=pca_dim, random_state=self.random_state)
            X2 = pca.fit_transform(X)

            plt.figure(figsize=(5, 4))
            if X2.shape[1] == 1:
                # 1-D projection: spread points along x, flat at y=0.
                plt.scatter(X2[:, 0], np.zeros_like(X2[:, 0]), c=labels, cmap="tab10", alpha=0.8)
            else:
                plt.scatter(X2[:, 0], X2[:, 1], c=labels, cmap="tab10", alpha=0.8)
            plt.title("Cluster scatter (PCA)")
            plt.xlabel("PC1")
            plt.ylabel("PC2" if X2.shape[1] > 1 else "")
            path = plot_dir / "clustering_scatter_pca.png"
            plt.tight_layout()
            plt.savefig(path)
            plt.close()
            plots["cluster_scatter_pca"] = str(path)
        except Exception:
            pass

        # Cluster size distribution
        try:
            vals, counts = np.unique(labels, return_counts=True)
            plt.figure(figsize=(5, 3.5))
            sns.barplot(x=[str(v) for v in vals], y=counts, color="#4C78A8")
            plt.xlabel("Cluster")
            plt.ylabel("Count")
            plt.title("Cluster size distribution")
            # Annotate each bar with its exact count.
            for i, c in enumerate(counts):
                plt.text(i, c, str(int(c)), ha="center", va="bottom", fontsize=8)
            path = plot_dir / "clustering_cluster_sizes.png"
            plt.tight_layout()
            plt.savefig(path)
            plt.close()
            plots["cluster_size_distribution"] = str(path)
        except Exception:
            pass

        self.plots = plots

    def _generate_insights(self) -> None:
        """Append heuristic notes to ``self.insights`` and fill
        ``self.metric_descriptions``.

        Relies on ``self.metrics`` being populated first — presumably
        ``BaseReport`` calls ``_compute_metrics`` before this; confirm.
        """
        labels = _as_array(self.labels)
        if labels is None:
            self.insights = []
            return

        insights: List[str] = []
        unique, counts = np.unique(labels, return_counts=True)
        if unique.size >= 2:
            max_c = counts.max()
            min_c = counts.min()
            # max(1, ...) guards a zero minority count.
            ratio = float(max_c) / float(max(1, min_c))
            if ratio >= 5:
                # report most/least dominant clusters
                maj = unique[np.argmax(counts)]
                min_label = unique[np.argmin(counts)]
                insights.append(f"Cluster imbalance detected (majority={maj!r}, minority={min_label!r}, ratio≈{ratio:.1f}).")

        # Separability heuristics using silhouette
        sil = self.metrics.get("silhouette_score")
        # NOTE(review): the `is not None` after isinstance is redundant —
        # None never passes isinstance(..., (int, float)).
        if isinstance(sil, (int, float)) and sil is not None:
            if sil < 0.25:
                insights.append("Clusters overlap significantly (low silhouette). Consider revisiting features, scaling, or k.")
            elif sil > 0.5:
                insights.append("Clusters appear well separated (high silhouette).")

        # Davies-Bouldin lower is better
        dbi = self.metrics.get("davies_bouldin_index")
        if isinstance(dbi, (int, float)) and dbi is not None:
            if dbi > 1.0:
                insights.append("Higher Davies–Bouldin suggests clusters may be less distinct; inspect overlaps.")

        self.insights = insights
        # Human-readable explanations
        self.metric_descriptions.update(
            {
                "silhouette_score": "How well points fit their own cluster vs other clusters (higher is better).",
                "davies_bouldin_index": "Average similarity between clusters (lower is better).",
                "calinski_harabasz_score": "Variance ratio criterion (higher suggests clearer separation).",
                "num_clusters": "Number of unique clusters in the provided assignments.",
                "cluster_sizes": "Counts per cluster; helps detect cluster imbalance.",
            }
        )
|
|
174
|
+
|