pg-sui 1.6.14.dev9__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pg_sui-1.7.0.dist-info/METADATA +288 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/RECORD +29 -33
- pgsui/__init__.py +0 -8
- pgsui/_version.py +2 -2
- pgsui/cli.py +591 -126
- pgsui/data_processing/config.py +1 -2
- pgsui/data_processing/containers.py +218 -533
- pgsui/data_processing/transformers.py +44 -20
- pgsui/impute/deterministic/imputers/mode.py +475 -182
- pgsui/impute/deterministic/imputers/ref_allele.py +454 -147
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +4 -3
- pgsui/impute/supervised/imputers/random_forest.py +3 -2
- pgsui/impute/unsupervised/base.py +1268 -530
- pgsui/impute/unsupervised/callbacks.py +28 -33
- pgsui/impute/unsupervised/imputers/autoencoder.py +869 -764
- pgsui/impute/unsupervised/imputers/vae.py +928 -696
- pgsui/impute/unsupervised/loss_functions.py +156 -202
- pgsui/impute/unsupervised/models/autoencoder_model.py +7 -49
- pgsui/impute/unsupervised/models/vae_model.py +40 -221
- pgsui/impute/unsupervised/nn_scorers.py +53 -13
- pgsui/utils/classification_viz.py +240 -97
- pgsui/utils/misc.py +201 -3
- pgsui/utils/plotting.py +73 -58
- pgsui/utils/pretty_metrics.py +2 -6
- pgsui/utils/scorers.py +39 -0
- pg_sui-1.6.14.dev9.dist-info/METADATA +0 -344
- pgsui/impute/unsupervised/imputers/nlpca.py +0 -1554
- pgsui/impute/unsupervised/imputers/ubp.py +0 -1575
- pgsui/impute/unsupervised/models/nlpca_model.py +0 -206
- pgsui/impute/unsupervised/models/ubp_model.py +0 -200
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/WHEEL +0 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/entry_points.txt +0 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/licenses/LICENSE +0 -0
- {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
# -*- coding: utf-8 -*-
|
|
2
2
|
from __future__ import annotations
|
|
3
3
|
|
|
4
|
+
import re
|
|
5
|
+
import warnings
|
|
4
6
|
from dataclasses import dataclass, field
|
|
5
7
|
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
|
6
8
|
|
|
@@ -28,6 +30,8 @@ class ClassificationReportVisualizer:
|
|
|
28
30
|
background_hex: Matplotlib/Plotly dark background.
|
|
29
31
|
grid_hex: Gridline color for dark theme.
|
|
30
32
|
reset_kwargs: Keyword args for resetting Matplotlib rcParams.
|
|
33
|
+
genotype_order: Canonical ordering for genotype/IUPAC class labels in plots.
|
|
34
|
+
avg_order: Canonical ordering for average rows (when present).
|
|
31
35
|
"""
|
|
32
36
|
|
|
33
37
|
retro_palette: List[str] = field(
|
|
@@ -46,12 +50,143 @@ class ClassificationReportVisualizer:
|
|
|
46
50
|
grid_hex: str = "#2a2a3a"
|
|
47
51
|
reset_kwargs: Dict[str, bool | str] | None = None
|
|
48
52
|
|
|
53
|
+
# Canonical label order used everywhere.
|
|
54
|
+
# Edit/extend this if you want additional IUPAC or special tokens
|
|
55
|
+
# ordered explicitly.
|
|
56
|
+
genotype_order: List[str] = field(
|
|
57
|
+
default_factory=lambda: ["A", "C", "G", "T", "K", "M", "R", "S", "W", "Y", "N"]
|
|
58
|
+
)
|
|
59
|
+
avg_order: List[str] = field(
|
|
60
|
+
default_factory=lambda: [
|
|
61
|
+
"micro avg",
|
|
62
|
+
"macro avg",
|
|
63
|
+
"weighted avg",
|
|
64
|
+
"samples avg",
|
|
65
|
+
]
|
|
66
|
+
)
|
|
67
|
+
|
|
68
|
+
# ---------- Ordering helpers ----------
|
|
69
|
+
@staticmethod
|
|
70
|
+
def _normalize_label(label: str) -> str:
|
|
71
|
+
"""Normalize class labels for ordering comparisons (case-insensitive)."""
|
|
72
|
+
return str(label).strip().upper()
|
|
73
|
+
|
|
74
|
+
@staticmethod
|
|
75
|
+
def _normalize_avg(label: str) -> str:
|
|
76
|
+
"""Normalize avg labels for ordering comparisons (case-insensitive)."""
|
|
77
|
+
return str(label).strip().lower()
|
|
78
|
+
|
|
79
|
+
@staticmethod
|
|
80
|
+
def _natural_sort_key(s: str):
|
|
81
|
+
"""Natural sort key so '10' sorts after '2'."""
|
|
82
|
+
parts = re.split(r"(\d+)", str(s))
|
|
83
|
+
key = []
|
|
84
|
+
for p in parts:
|
|
85
|
+
if p.isdigit():
|
|
86
|
+
key.append((0, int(p)))
|
|
87
|
+
else:
|
|
88
|
+
key.append((1, p.lower()))
|
|
89
|
+
return key
|
|
90
|
+
|
|
91
|
+
def _ordered_class_labels(self, labels: Union[pd.Index, List[str]]) -> List[str]:
|
|
92
|
+
"""Order non-avg class labels with genotype_order first, then natural-sorted remainder."""
|
|
93
|
+
labels_list = [str(x) for x in list(labels)]
|
|
94
|
+
if not labels_list:
|
|
95
|
+
return []
|
|
96
|
+
|
|
97
|
+
# Map normalized -> first-seen original label to preserve original
|
|
98
|
+
# formatting.
|
|
99
|
+
norm_to_orig: Dict[str, str] = {}
|
|
100
|
+
for lab in labels_list:
|
|
101
|
+
n = self._normalize_label(lab)
|
|
102
|
+
norm_to_orig.setdefault(n, lab)
|
|
103
|
+
|
|
104
|
+
desired_norm = [self._normalize_label(x) for x in self.genotype_order]
|
|
105
|
+
desired_set = set(desired_norm)
|
|
106
|
+
|
|
107
|
+
ordered = [norm_to_orig[n] for n in desired_norm if n in norm_to_orig]
|
|
108
|
+
|
|
109
|
+
# Append everything not in genotype_order
|
|
110
|
+
# (natural sort; stable + de-dup)
|
|
111
|
+
seen = set(self._normalize_label(x) for x in ordered)
|
|
112
|
+
remainder = [
|
|
113
|
+
lab
|
|
114
|
+
for lab in labels_list
|
|
115
|
+
if self._normalize_label(lab) not in desired_set
|
|
116
|
+
and self._normalize_label(lab) not in seen
|
|
117
|
+
]
|
|
118
|
+
remainder_sorted = sorted(remainder, key=self._natural_sort_key)
|
|
119
|
+
return ordered + remainder_sorted
|
|
120
|
+
|
|
121
|
+
def _ordered_avg_labels(self, labels: Union[pd.Index, List[str]]) -> List[str]:
|
|
122
|
+
"""Order avg labels with avg_order first, then alpha remainder.
|
|
123
|
+
|
|
124
|
+
Args:
|
|
125
|
+
labels (Union[pd.Index, List[str]]): List of avg labels.
|
|
126
|
+
|
|
127
|
+
Returns:
|
|
128
|
+
List[str]: Ordered list of avg labels.
|
|
129
|
+
"""
|
|
130
|
+
labels_list = [str(x) for x in list(labels)]
|
|
131
|
+
if not labels_list:
|
|
132
|
+
return []
|
|
133
|
+
|
|
134
|
+
norm_to_orig: Dict[str, str] = {}
|
|
135
|
+
for lab in labels_list:
|
|
136
|
+
n = self._normalize_avg(lab)
|
|
137
|
+
norm_to_orig.setdefault(n, lab)
|
|
138
|
+
|
|
139
|
+
preferred = []
|
|
140
|
+
preferred_set = set(self.avg_order)
|
|
141
|
+
for pref in self.avg_order:
|
|
142
|
+
if pref in norm_to_orig:
|
|
143
|
+
preferred.append(norm_to_orig[pref])
|
|
144
|
+
|
|
145
|
+
seen = set(self._normalize_avg(x) for x in preferred)
|
|
146
|
+
remainder = [
|
|
147
|
+
lab
|
|
148
|
+
for lab in labels_list
|
|
149
|
+
if self._normalize_avg(lab) not in preferred_set
|
|
150
|
+
and self._normalize_avg(lab) not in seen
|
|
151
|
+
]
|
|
152
|
+
remainder_sorted = sorted(remainder, key=lambda x: x.lower())
|
|
153
|
+
return preferred + remainder_sorted
|
|
154
|
+
|
|
155
|
+
def _ordered_report_index(self, idx: Union[pd.Index, List[str]]) -> List[str]:
|
|
156
|
+
"""Order full report index: classes first (genotype_order), avg rows last (avg_order).
|
|
157
|
+
|
|
158
|
+
Args:
|
|
159
|
+
idx (Union[pd.Index, List[str]]): Index from classification report DataFrame.
|
|
160
|
+
|
|
161
|
+
Returns:
|
|
162
|
+
List[str]: Ordered list of index labels.
|
|
163
|
+
"""
|
|
164
|
+
labels = [str(x) for x in list(idx)]
|
|
165
|
+
is_avg = [("avg" in lab.lower()) for lab in labels]
|
|
166
|
+
class_labels = [lab for lab, a in zip(labels, is_avg) if not a]
|
|
167
|
+
avg_labels = [lab for lab, a in zip(labels, is_avg) if a]
|
|
168
|
+
return self._ordered_class_labels(class_labels) + self._ordered_avg_labels(
|
|
169
|
+
avg_labels
|
|
170
|
+
)
|
|
171
|
+
|
|
172
|
+
def _apply_ordering(self, df: pd.DataFrame) -> pd.DataFrame:
|
|
173
|
+
"""Reindex df to canonical ordering (classes then avgs).
|
|
174
|
+
|
|
175
|
+
Args:
|
|
176
|
+
df (pd.DataFrame): DataFrame from classification report.
|
|
177
|
+
|
|
178
|
+
Returns:
|
|
179
|
+
pd.DataFrame: Reindexed DataFrame.
|
|
180
|
+
"""
|
|
181
|
+
ordered = self._ordered_report_index(df.index)
|
|
182
|
+
# Only keep labels that exist (avoid introducing all genotype_order labels as NaN rows)
|
|
183
|
+
ordered = [x for x in ordered if x in df.index]
|
|
184
|
+
return df.reindex(ordered)
|
|
185
|
+
|
|
49
186
|
# ---------- Core data prep ----------
|
|
50
187
|
def to_dataframe(self, report: Dict[str, Dict[str, float]]) -> pd.DataFrame:
|
|
51
188
|
"""Convert sklearn classification_report output_dict to a tidy DataFrame.
|
|
52
189
|
|
|
53
|
-
This method standardizes the output of scikit-learn's classification_report function.
|
|
54
|
-
|
|
55
190
|
Args:
|
|
56
191
|
report (Dict[str, Dict[str, float]]): Dictionary from `classification_report(..., output_dict=True)`.
|
|
57
192
|
|
|
@@ -77,17 +212,12 @@ class ClassificationReportVisualizer:
|
|
|
77
212
|
|
|
78
213
|
df.index = df.index.astype(str)
|
|
79
214
|
|
|
80
|
-
is_avg = df.index.str.contains("avg", case=False, regex=True)
|
|
81
|
-
class_df = df.loc[~is_avg].copy()
|
|
82
|
-
avg_df = df.loc[is_avg].copy()
|
|
83
|
-
|
|
84
215
|
num_cols = ["precision", "recall", "f1-score", "support"]
|
|
85
|
-
|
|
86
|
-
avg_df[num_cols] = avg_df[num_cols].apply(pd.to_numeric, errors="coerce")
|
|
216
|
+
df[num_cols] = df[num_cols].apply(pd.to_numeric, errors="coerce")
|
|
87
217
|
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
return
|
|
218
|
+
# Apply canonical ordering (classes then avg rows)
|
|
219
|
+
df = self._apply_ordering(df)
|
|
220
|
+
return df
|
|
91
221
|
|
|
92
222
|
def compute_ci(
|
|
93
223
|
self,
|
|
@@ -106,21 +236,21 @@ class ClassificationReportVisualizer:
|
|
|
106
236
|
pd.DataFrame: Multi-index columns with (metric, ["lower","upper","mean"]). Index contains any class/avg labels present in the bootstrap reports.
|
|
107
237
|
"""
|
|
108
238
|
if not boot_reports:
|
|
109
|
-
|
|
239
|
+
msg = "boot_reports is empty; provide at least one dict."
|
|
240
|
+
raise ValueError(msg)
|
|
110
241
|
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
242
|
+
frames = [self.to_dataframe(rep) for rep in boot_reports]
|
|
243
|
+
|
|
244
|
+
# Union of indices across repeats, ordered canonically.
|
|
245
|
+
union_idx = set().union(*[set(f.index) for f in frames])
|
|
246
|
+
common_index = self._ordered_report_index(list(union_idx))
|
|
247
|
+
common_index = [x for x in common_index if x in union_idx]
|
|
116
248
|
|
|
117
|
-
# Align on index, stack into 3D array (repeat x class x metric)
|
|
118
|
-
common_index = sorted(set().union(*[f.index for f in frames]))
|
|
119
249
|
arrs = []
|
|
120
250
|
for f in frames:
|
|
121
251
|
sub = f.reindex(common_index)
|
|
122
252
|
arrs.append(sub[[m for m in metrics]].to_numpy(dtype=float))
|
|
123
|
-
arr = np.stack(arrs, axis=0) #
|
|
253
|
+
arr = np.stack(arrs, axis=0) # (B, C, M)
|
|
124
254
|
|
|
125
255
|
alpha = (1 - ci) / 2
|
|
126
256
|
lower_q = 100 * alpha
|
|
@@ -145,22 +275,17 @@ class ClassificationReportVisualizer:
|
|
|
145
275
|
def _retro_cmap(self, n: int = 256) -> LinearSegmentedColormap:
|
|
146
276
|
"""Create a neon gradient colormap.
|
|
147
277
|
|
|
148
|
-
This colormap transitions through a series of bright, neon colors.
|
|
149
|
-
|
|
150
278
|
Args:
|
|
151
|
-
n (int): Number of discrete colors in the colormap.
|
|
279
|
+
n (int): Number of discrete colors in the colormap.
|
|
152
280
|
|
|
153
281
|
Returns:
|
|
154
|
-
LinearSegmentedColormap:
|
|
282
|
+
LinearSegmentedColormap: Neon-themed colormap.
|
|
155
283
|
"""
|
|
156
284
|
anchors = ["#241937", "#7d00ff", "#ff00ff", "#ff6ec7", "#00f0ff", "#00ff9f"]
|
|
157
285
|
return LinearSegmentedColormap.from_list("retro_neon", anchors, N=n)
|
|
158
286
|
|
|
159
287
|
def _set_mpl_style(self) -> None:
|
|
160
|
-
"""Apply a dark neon Matplotlib theme.
|
|
161
|
-
|
|
162
|
-
This method modifies global rcParams; call before plotting.
|
|
163
|
-
"""
|
|
288
|
+
"""Apply a dark neon Matplotlib theme."""
|
|
164
289
|
plt.rcParams.update(
|
|
165
290
|
{
|
|
166
291
|
"figure.facecolor": self.background_hex,
|
|
@@ -201,26 +326,29 @@ class ClassificationReportVisualizer:
|
|
|
201
326
|
):
|
|
202
327
|
"""Plot a per-class heatmap with an optional right-hand support strip.
|
|
203
328
|
|
|
204
|
-
This visualizes the classification metrics for each class.
|
|
205
|
-
|
|
206
329
|
Args:
|
|
207
|
-
df (pd.DataFrame): DataFrame from
|
|
330
|
+
df (pd.DataFrame): DataFrame from to_dataframe().
|
|
208
331
|
title (str): Plot title.
|
|
209
|
-
classes_only (bool):
|
|
210
|
-
figsize (Tuple[int, int]):
|
|
332
|
+
classes_only (bool): Whether to include only classes (exclude avg rows).
|
|
333
|
+
figsize (Tuple[int, int]): Figure size.
|
|
211
334
|
annot_decimals (int): Decimal places for annotations.
|
|
212
|
-
vmax (float): Max
|
|
213
|
-
vmin (float): Min
|
|
214
|
-
show_support_strip (bool):
|
|
335
|
+
vmax (float): Max value for colormap scaling.
|
|
336
|
+
vmin (float): Min value for colormap scaling.
|
|
337
|
+
show_support_strip (bool): Whether to show a support strip on the right.
|
|
215
338
|
|
|
216
339
|
Returns:
|
|
217
|
-
|
|
340
|
+
Figure: Matplotlib figure.
|
|
218
341
|
"""
|
|
219
342
|
self._set_mpl_style()
|
|
220
343
|
|
|
221
344
|
work = df.copy()
|
|
345
|
+
# Ensure canonical ordering even if caller didn't use to_dataframe().
|
|
346
|
+
work = self._apply_ordering(work)
|
|
347
|
+
|
|
222
348
|
if classes_only:
|
|
223
349
|
work = work[~work.index.str.contains("avg", case=False, regex=True)]
|
|
350
|
+
# Re-apply class ordering after filtering
|
|
351
|
+
work = work.reindex(self._ordered_class_labels(work.index))
|
|
224
352
|
|
|
225
353
|
metric_cols = ["precision", "recall", "f1-score"]
|
|
226
354
|
heat = work[metric_cols].astype(float)
|
|
@@ -243,7 +371,6 @@ class ClassificationReportVisualizer:
|
|
|
243
371
|
ax.set_xlabel("Metric")
|
|
244
372
|
ax.set_ylabel("Class")
|
|
245
373
|
|
|
246
|
-
# Optional support strip (normalized 0..1) as an inset axis
|
|
247
374
|
if show_support_strip and "support" in work.columns:
|
|
248
375
|
supports = work["support"].astype(float).fillna(0.0).to_numpy()
|
|
249
376
|
sup_norm = (supports - supports.min()) / (np.ptp(supports) + 1e-9)
|
|
@@ -258,7 +385,6 @@ class ClassificationReportVisualizer:
|
|
|
258
385
|
)
|
|
259
386
|
|
|
260
387
|
strip_data = sup_norm[:, None] # (n_classes, 1)
|
|
261
|
-
|
|
262
388
|
sns.heatmap(
|
|
263
389
|
strip_data,
|
|
264
390
|
cmap=self._retro_cmap(),
|
|
@@ -271,8 +397,6 @@ class ClassificationReportVisualizer:
|
|
|
271
397
|
linewidths=0.0,
|
|
272
398
|
ax=ax_strip,
|
|
273
399
|
)
|
|
274
|
-
|
|
275
|
-
# Align strip y-limits to main heatmap
|
|
276
400
|
ax_strip.set_ylim(ax.get_ylim())
|
|
277
401
|
|
|
278
402
|
return fig
|
|
@@ -289,23 +413,32 @@ class ClassificationReportVisualizer:
|
|
|
289
413
|
"""Plot grouped bars for P/R/F1 with support markers and optional CI.
|
|
290
414
|
|
|
291
415
|
Args:
|
|
292
|
-
df (pd.DataFrame): DataFrame from
|
|
416
|
+
df (pd.DataFrame): DataFrame from to_dataframe().
|
|
293
417
|
title (str): Plot title.
|
|
294
|
-
classes_only (bool):
|
|
418
|
+
classes_only (bool): Whether to include only classes (exclude avg rows).
|
|
295
419
|
figsize (Tuple[int, int]): Figure size.
|
|
296
|
-
bar_alpha (float):
|
|
297
|
-
ci_df (Optional[pd.DataFrame]):
|
|
420
|
+
bar_alpha (float): Alpha transparency for bars.
|
|
421
|
+
ci_df (Optional[pd.DataFrame]): DataFrame from compute_ci() for CI bars (optional).
|
|
298
422
|
|
|
299
423
|
Returns:
|
|
300
|
-
|
|
424
|
+
Figure: Matplotlib figure.
|
|
301
425
|
"""
|
|
302
426
|
self._set_mpl_style()
|
|
427
|
+
|
|
303
428
|
work = df.copy()
|
|
429
|
+
work = self._apply_ordering(work)
|
|
430
|
+
|
|
304
431
|
if classes_only:
|
|
305
432
|
work = work[~work.index.str.contains("avg", case=False, regex=True)]
|
|
433
|
+
classes = self._ordered_class_labels(work.index)
|
|
434
|
+
work = work.reindex(classes)
|
|
435
|
+
else:
|
|
436
|
+
# If including avgs, only plot classes on x-axis for bars.
|
|
437
|
+
classes = self._ordered_class_labels(
|
|
438
|
+
work.loc[~work.index.str.contains("avg", case=False, regex=True)].index
|
|
439
|
+
)
|
|
306
440
|
|
|
307
441
|
metric_cols = ["precision", "recall", "f1-score"]
|
|
308
|
-
|
|
309
442
|
lng = (
|
|
310
443
|
work[metric_cols]
|
|
311
444
|
.reset_index(names="class")
|
|
@@ -313,11 +446,6 @@ class ClassificationReportVisualizer:
|
|
|
313
446
|
.dropna(subset=["score"])
|
|
314
447
|
)
|
|
315
448
|
|
|
316
|
-
homozygote_order = ["A", "C", "G", "T"]
|
|
317
|
-
classes = homozygote_order + [
|
|
318
|
-
c for c in lng["class"].unique().tolist() if c not in homozygote_order
|
|
319
|
-
]
|
|
320
|
-
|
|
321
449
|
metrics = metric_cols
|
|
322
450
|
palette = self.retro_palette[: len(metrics)]
|
|
323
451
|
|
|
@@ -327,9 +455,12 @@ class ClassificationReportVisualizer:
|
|
|
327
455
|
|
|
328
456
|
fig, ax = plt.subplots(figsize=figsize)
|
|
329
457
|
|
|
330
|
-
# Secondary axis for support markers
|
|
331
458
|
ax2 = ax.twinx()
|
|
332
|
-
supports =
|
|
459
|
+
supports = (
|
|
460
|
+
work.reindex(classes)["support"].astype(float).fillna(0.0).to_numpy()
|
|
461
|
+
if "support" in work.columns
|
|
462
|
+
else np.zeros(len(classes), dtype=float)
|
|
463
|
+
)
|
|
333
464
|
|
|
334
465
|
ax2.plot(
|
|
335
466
|
x,
|
|
@@ -343,13 +474,12 @@ class ClassificationReportVisualizer:
|
|
|
343
474
|
label="Support",
|
|
344
475
|
)
|
|
345
476
|
|
|
346
|
-
# Plot bars with optional CI error bars
|
|
347
477
|
for i, m in enumerate(metrics):
|
|
348
478
|
vals = (
|
|
349
479
|
lng.loc[lng["metric"].eq(m)]
|
|
350
480
|
.set_index("class")
|
|
351
481
|
.reindex(classes)["score"]
|
|
352
|
-
.
|
|
482
|
+
.to_numpy(dtype=float)
|
|
353
483
|
)
|
|
354
484
|
|
|
355
485
|
yerr = None
|
|
@@ -357,14 +487,12 @@ class ClassificationReportVisualizer:
|
|
|
357
487
|
ci_reindexed = ci_df.reindex(classes)
|
|
358
488
|
lows = ci_reindexed[(m, "lower")].to_numpy(dtype=float)
|
|
359
489
|
ups = ci_reindexed[(m, "upper")].to_numpy(dtype=float)
|
|
360
|
-
|
|
361
|
-
# Convert to symmetric error around the point estimate
|
|
362
490
|
center = vals
|
|
363
491
|
yerr = np.vstack([center - lows, ups - center])
|
|
364
492
|
|
|
365
493
|
ax.bar(
|
|
366
494
|
x + offsets[i],
|
|
367
|
-
|
|
495
|
+
vals,
|
|
368
496
|
width=width * 0.95,
|
|
369
497
|
label=m.title(),
|
|
370
498
|
color=palette[i % len(palette)],
|
|
@@ -382,10 +510,9 @@ class ClassificationReportVisualizer:
|
|
|
382
510
|
ax.set_title(title, pad=12, fontweight="bold")
|
|
383
511
|
ax.legend(ncols=3, frameon=True, loc="upper left")
|
|
384
512
|
|
|
385
|
-
# Configure secondary (support) axis
|
|
386
513
|
ax2.set_ylabel("Support")
|
|
387
514
|
ax2.grid(False)
|
|
388
|
-
ax2.set_ylim(0, max(1.0, np.asarray(supports).max() * 1.15))
|
|
515
|
+
ax2.set_ylim(0, max(1.0, float(np.asarray(supports).max()) * 1.15))
|
|
389
516
|
ax2.legend(loc="upper right", frameon=True)
|
|
390
517
|
|
|
391
518
|
ax.grid(axis="y", linestyle="--", alpha=0.6)
|
|
@@ -404,34 +531,54 @@ class ClassificationReportVisualizer:
|
|
|
404
531
|
) -> go.Figure:
|
|
405
532
|
"""Interactive radar chart of averages + top-k classes; optional CI bands.
|
|
406
533
|
|
|
407
|
-
This function creates a radar chart using Plotly, displaying the specified metrics for the top-k classes.
|
|
408
|
-
|
|
409
534
|
Args:
|
|
410
|
-
df (pd.DataFrame): DataFrame from
|
|
411
|
-
title (str):
|
|
412
|
-
top_k (int):
|
|
413
|
-
include_micro (bool):
|
|
414
|
-
include_macro (bool):
|
|
415
|
-
include_weighted (bool):
|
|
416
|
-
ci_df (Optional[pd.DataFrame]):
|
|
535
|
+
df (pd.DataFrame): DataFrame from to_dataframe().
|
|
536
|
+
title (str): Plot title.
|
|
537
|
+
top_k (int): Number of top classes by support to include.
|
|
538
|
+
include_micro (bool): Whether to include micro avg.
|
|
539
|
+
include_macro (bool): Whether to include macro avg.
|
|
540
|
+
include_weighted (bool): Whether to include weighted avg.
|
|
541
|
+
ci_df (Optional[pd.DataFrame]): DataFrame from compute_ci() for CI bands (optional).
|
|
417
542
|
|
|
418
543
|
Returns:
|
|
419
|
-
|
|
544
|
+
go.Figure: Plotly radar figure.
|
|
420
545
|
"""
|
|
421
546
|
work = df.copy()
|
|
547
|
+
work = self._apply_ordering(work)
|
|
422
548
|
|
|
423
549
|
is_avg = work.index.str.contains("avg", case=False, regex=True)
|
|
424
|
-
classes = work.loc[~is_avg].copy().sort_values("support", ascending=False)
|
|
425
|
-
if top_k is not None and top_k > 0:
|
|
426
|
-
classes = classes.head(top_k)
|
|
427
550
|
|
|
551
|
+
# --- choose top-k by support, but order those chosen by canonical genotype order ---
|
|
552
|
+
class_block = work.loc[~is_avg].copy()
|
|
553
|
+
if top_k is not None and top_k > 0 and "support" in class_block.columns:
|
|
554
|
+
top_labels = (
|
|
555
|
+
(
|
|
556
|
+
class_block["support"]
|
|
557
|
+
.astype(float)
|
|
558
|
+
.fillna(0.0)
|
|
559
|
+
.sort_values(ascending=False)
|
|
560
|
+
)
|
|
561
|
+
.head(top_k)
|
|
562
|
+
.index.tolist()
|
|
563
|
+
)
|
|
564
|
+
ordered_top = self._ordered_class_labels(top_labels)
|
|
565
|
+
classes = class_block.reindex(
|
|
566
|
+
[x for x in ordered_top if x in class_block.index]
|
|
567
|
+
)
|
|
568
|
+
else:
|
|
569
|
+
classes = class_block.reindex(self._ordered_class_labels(class_block.index))
|
|
570
|
+
|
|
571
|
+
# --- averages in canonical order ---
|
|
572
|
+
include_map = {
|
|
573
|
+
"micro avg": include_micro,
|
|
574
|
+
"macro avg": include_macro,
|
|
575
|
+
"weighted avg": include_weighted,
|
|
576
|
+
"samples avg": True, # keep if present; user can ignore via flags by removing from avg_order
|
|
577
|
+
}
|
|
428
578
|
avgs = []
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
avgs.append(("weighted avg", work.loc["weighted avg"]))
|
|
433
|
-
if include_micro and "micro avg" in work.index:
|
|
434
|
-
avgs.append(("micro avg", work.loc["micro avg"]))
|
|
579
|
+
for name in self.avg_order:
|
|
580
|
+
if include_map.get(name, True) and name in work.index:
|
|
581
|
+
avgs.append((name, work.loc[name]))
|
|
435
582
|
|
|
436
583
|
metrics = ["precision", "recall", "f1-score"]
|
|
437
584
|
theta = metrics + [metrics[0]]
|
|
@@ -456,7 +603,6 @@ class ClassificationReportVisualizer:
|
|
|
456
603
|
lows.append(lows[0])
|
|
457
604
|
ups.append(ups[0])
|
|
458
605
|
|
|
459
|
-
# Plotly polar CI band: plot upper path, then lower reversed with fill
|
|
460
606
|
fig.add_trace(
|
|
461
607
|
go.Scatterpolar(
|
|
462
608
|
r=ups,
|
|
@@ -474,19 +620,14 @@ class ClassificationReportVisualizer:
|
|
|
474
620
|
mode="lines",
|
|
475
621
|
line=dict(width=0),
|
|
476
622
|
fill="toself",
|
|
477
|
-
fillcolor=(
|
|
478
|
-
color.replace("#", "rgba(") if False else None
|
|
479
|
-
), # placeholder
|
|
480
623
|
hoverinfo="skip",
|
|
481
624
|
name=f"{name} CI",
|
|
482
625
|
showlegend=False,
|
|
483
626
|
opacity=0.20,
|
|
484
627
|
)
|
|
485
628
|
)
|
|
486
|
-
|
|
487
|
-
fig.data[-1].fillcolor = f"{color}33" # add ~20% alpha
|
|
629
|
+
fig.data[-1].fillcolor = f"{color}33" # 8-digit hex w/ alpha
|
|
488
630
|
|
|
489
|
-
# Add average traces with CI first
|
|
490
631
|
for i, (name, row) in enumerate(avgs):
|
|
491
632
|
r = [float(row.get(m, np.nan)) for m in metrics]
|
|
492
633
|
r.append(r[0])
|
|
@@ -504,7 +645,6 @@ class ClassificationReportVisualizer:
|
|
|
504
645
|
)
|
|
505
646
|
)
|
|
506
647
|
|
|
507
|
-
# Add class traces (top-k) with optional CI
|
|
508
648
|
base_idx = len(avgs)
|
|
509
649
|
for i, (cls, row) in enumerate(classes[metrics].iterrows()):
|
|
510
650
|
r = [float(row.get(m, np.nan)) for m in metrics]
|
|
@@ -560,16 +700,17 @@ class ClassificationReportVisualizer:
|
|
|
560
700
|
"""Generate all visuals, with optional CI from bootstrap reports.
|
|
561
701
|
|
|
562
702
|
Args:
|
|
563
|
-
report (Dict[str, Dict[str, float]]):
|
|
564
|
-
title_prefix (str):
|
|
565
|
-
heatmap_classes_only (bool):
|
|
566
|
-
radar_top_k (int): Number of top classes
|
|
567
|
-
boot_reports (Optional[List[Dict[str, Dict[str, float]]]]): Optional list of bootstrap report dicts for CI.
|
|
568
|
-
ci (float): Confidence level
|
|
569
|
-
show (bool):
|
|
703
|
+
report (Dict[str, Dict[str, float]]): Dictionary from `classification_report(..., output_dict=True)`.
|
|
704
|
+
title_prefix (str): Prefix for plot titles.
|
|
705
|
+
heatmap_classes_only (bool): Whether to only plot classes (exclude avg rows) in heatmap.
|
|
706
|
+
radar_top_k (int): Number of top classes by support to include in radar plot.
|
|
707
|
+
boot_reports (Optional[List[Dict[str, Dict[str, float]]]]): Optional list of bootstrap report dicts for CI computation.
|
|
708
|
+
ci (float): Confidence level for CIs.
|
|
709
|
+
show (bool): Whether to display the plots via plt.show().
|
|
570
710
|
|
|
571
711
|
Returns:
|
|
572
|
-
Dict[str, Union[
|
|
712
|
+
Dict[str, Union[Figure, go.Figure]]: Dictionary with keys:
|
|
713
|
+
"heatmap_fig", "bars_fig", "radar_fig".
|
|
573
714
|
"""
|
|
574
715
|
df = self.to_dataframe(report)
|
|
575
716
|
acc = df.attrs.get("accuracy", None)
|
|
@@ -599,7 +740,9 @@ class ClassificationReportVisualizer:
|
|
|
599
740
|
)
|
|
600
741
|
|
|
601
742
|
if show:
|
|
602
|
-
|
|
743
|
+
with warnings.catch_warnings():
|
|
744
|
+
warnings.simplefilter("ignore", UserWarning)
|
|
745
|
+
plt.show()
|
|
603
746
|
|
|
604
747
|
return {
|
|
605
748
|
"heatmap_fig": heatmap_fig,
|