pg-sui 1.6.14.dev9__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. pg_sui-1.7.0.dist-info/METADATA +288 -0
  2. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/RECORD +29 -33
  3. pgsui/__init__.py +0 -8
  4. pgsui/_version.py +2 -2
  5. pgsui/cli.py +591 -126
  6. pgsui/data_processing/config.py +1 -2
  7. pgsui/data_processing/containers.py +218 -533
  8. pgsui/data_processing/transformers.py +44 -20
  9. pgsui/impute/deterministic/imputers/mode.py +475 -182
  10. pgsui/impute/deterministic/imputers/ref_allele.py +454 -147
  11. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +4 -3
  12. pgsui/impute/supervised/imputers/random_forest.py +3 -2
  13. pgsui/impute/unsupervised/base.py +1268 -530
  14. pgsui/impute/unsupervised/callbacks.py +28 -33
  15. pgsui/impute/unsupervised/imputers/autoencoder.py +869 -764
  16. pgsui/impute/unsupervised/imputers/vae.py +928 -696
  17. pgsui/impute/unsupervised/loss_functions.py +156 -202
  18. pgsui/impute/unsupervised/models/autoencoder_model.py +7 -49
  19. pgsui/impute/unsupervised/models/vae_model.py +40 -221
  20. pgsui/impute/unsupervised/nn_scorers.py +53 -13
  21. pgsui/utils/classification_viz.py +240 -97
  22. pgsui/utils/misc.py +201 -3
  23. pgsui/utils/plotting.py +73 -58
  24. pgsui/utils/pretty_metrics.py +2 -6
  25. pgsui/utils/scorers.py +39 -0
  26. pg_sui-1.6.14.dev9.dist-info/METADATA +0 -344
  27. pgsui/impute/unsupervised/imputers/nlpca.py +0 -1554
  28. pgsui/impute/unsupervised/imputers/ubp.py +0 -1575
  29. pgsui/impute/unsupervised/models/nlpca_model.py +0 -206
  30. pgsui/impute/unsupervised/models/ubp_model.py +0 -200
  31. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/WHEEL +0 -0
  32. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/entry_points.txt +0 -0
  33. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/licenses/LICENSE +0 -0
  34. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,8 @@
1
1
  # -*- coding: utf-8 -*-
2
2
  from __future__ import annotations
3
3
 
4
+ import re
5
+ import warnings
4
6
  from dataclasses import dataclass, field
5
7
  from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
6
8
 
@@ -28,6 +30,8 @@ class ClassificationReportVisualizer:
28
30
  background_hex: Matplotlib/Plotly dark background.
29
31
  grid_hex: Gridline color for dark theme.
30
32
  reset_kwargs: Keyword args for resetting Matplotlib rcParams.
33
+ genotype_order: Canonical ordering for genotype/IUPAC class labels in plots.
34
+ avg_order: Canonical ordering for average rows (when present).
31
35
  """
32
36
 
33
37
  retro_palette: List[str] = field(
@@ -46,12 +50,143 @@ class ClassificationReportVisualizer:
46
50
  grid_hex: str = "#2a2a3a"
47
51
  reset_kwargs: Dict[str, bool | str] | None = None
48
52
 
53
# Canonical ordering for genotype/IUPAC class labels in plots. Extend this
# list if additional IUPAC or special tokens should be ordered explicitly.
genotype_order: List[str] = field(default_factory=lambda: list("ACGTKMRSWYN"))
# Canonical ordering for average rows (applied only to rows that are present).
avg_order: List[str] = field(
    default_factory=lambda: [
        f"{kind} avg" for kind in ("micro", "macro", "weighted", "samples")
    ]
)
67
+
68
+ # ---------- Ordering helpers ----------
69
@staticmethod
def _normalize_label(label: str) -> str:
    """Return ``label`` as a whitespace-trimmed, upper-case string.

    Used as a case-insensitive comparison key when ordering class labels
    against ``genotype_order``.
    """
    as_text = str(label)
    return as_text.strip().upper()
73
+
74
@staticmethod
def _normalize_avg(label: str) -> str:
    """Return ``label`` as a whitespace-trimmed, lower-case string.

    Used as a case-insensitive comparison key when ordering average rows
    against ``avg_order``.
    """
    as_text = str(label)
    return as_text.strip().lower()
78
+
79
@staticmethod
def _natural_sort_key(s: str):
    """Build a natural-order sort key so that e.g. ``"2"`` sorts before ``"10"``.

    The string is split into digit runs and non-digit text. Digit runs become
    ``(0, int_value)`` and text becomes ``(1, lowercased_text)``, so numeric
    chunks compare numerically and text chunks compare case-insensitively.
    """
    return [
        (0, int(chunk)) if chunk.isdigit() else (1, chunk.lower())
        for chunk in re.split(r"(\d+)", str(s))
    ]
90
+
91
def _ordered_class_labels(self, labels: Union[pd.Index, List[str]]) -> List[str]:
    """Order non-avg class labels: ``genotype_order`` first, then the rest.

    Labels matching an entry of ``genotype_order`` come first, in
    ``genotype_order`` order. Matching is case-insensitive and ignores
    surrounding whitespace, via ``_normalize_label``. All other labels follow
    in natural sort order (``_natural_sort_key``).

    Exact duplicates are removed. Labels that differ only by case are all kept
    and placed next to each other, so callers that reindex a DataFrame with
    the result never silently lose rows. Ties are broken by the raw label, so
    the output does not depend on input order.

    Args:
        labels (Union[pd.Index, List[str]]): Class labels; converted to str.

    Returns:
        List[str]: Ordered, de-duplicated labels.
    """
    # Exact-string de-duplication, keeping first occurrence.
    unique_labels = list(dict.fromkeys(str(x) for x in labels))
    if not unique_labels:
        return []

    # Position of each canonical label. If two entries of genotype_order
    # normalize to the same key, the earlier one wins.
    rank: Dict[str, int] = {}
    for pos, canon in enumerate(self.genotype_order):
        rank.setdefault(self._normalize_label(canon), pos)

    def _sort_key(lab: str):
        # Group 0: canonical genotype labels by rank. Group 1: natural sort.
        # The first tuple element separates the groups, so the differently
        # typed second elements are never compared against each other.
        norm = self._normalize_label(lab)
        if norm in rank:
            return (0, rank[norm], lab)
        return (1, self._natural_sort_key(lab), lab)

    return sorted(unique_labels, key=_sort_key)
120
+
121
def _ordered_avg_labels(self, labels: Union[pd.Index, List[str]]) -> List[str]:
    """Order avg labels with ``avg_order`` first, then alphabetical remainder.

    Both the incoming labels and the ``avg_order`` entries are normalized with
    ``_normalize_avg`` before comparison. A customized ``avg_order`` such as
    ``["Macro Avg", ...]`` therefore matches ``"macro avg"`` rows. Exact
    duplicates are removed. Labels that differ only by case are all kept.

    Args:
        labels (Union[pd.Index, List[str]]): List of avg labels.

    Returns:
        List[str]: Ordered list of avg labels.
    """
    unique_labels = list(dict.fromkeys(str(x) for x in labels))
    if not unique_labels:
        return []

    # Normalize avg_order too; comparing normalized labels against raw
    # avg_order entries would ignore any non-lowercase configuration.
    rank: Dict[str, int] = {}
    for pos, pref in enumerate(self.avg_order):
        rank.setdefault(self._normalize_avg(pref), pos)

    def _sort_key(lab: str):
        # Group 0: preferred avg rows by rank. Group 1: case-insensitive alpha.
        norm = self._normalize_avg(lab)
        if norm in rank:
            return (0, rank[norm], lab)
        return (1, lab.lower(), lab)

    return sorted(unique_labels, key=_sort_key)
154
+
155
def _ordered_report_index(self, idx: Union[pd.Index, List[str]]) -> List[str]:
    """Order a full report index: class rows first, avg rows last.

    A label is treated as an avg row when its lower-cased string form
    contains ``"avg"``. Class rows are ordered by ``_ordered_class_labels``
    and avg rows by ``_ordered_avg_labels``.

    Args:
        idx (Union[pd.Index, List[str]]): Index from classification report DataFrame.

    Returns:
        List[str]: Ordered list of index labels (as strings).
    """
    class_part: List[str] = []
    avg_part: List[str] = []
    for raw in list(idx):
        name = str(raw)
        target = avg_part if "avg" in name.lower() else class_part
        target.append(name)
    head = self._ordered_class_labels(class_part)
    tail = self._ordered_avg_labels(avg_part)
    return head + tail
171
+
172
def _apply_ordering(self, df: pd.DataFrame) -> pd.DataFrame:
    """Reindex ``df`` to canonical ordering (classes then avgs).

    The ordering helpers work on string labels, so each string form is mapped
    back to the DataFrame's real index label before reindexing. This keeps
    rows whose labels are not strings (e.g. integer class codes); a plain
    ``str`` membership test against ``df.index`` would drop them. Labels the
    ordering step does not return are appended at the end in their original
    order, so rows are not silently discarded. Only labels already in the
    index are used, so no all-NaN rows are introduced.

    Args:
        df (pd.DataFrame): DataFrame from classification report.

    Returns:
        pd.DataFrame: Reindexed DataFrame.
    """
    # String form -> actual index label (first occurrence wins if two
    # distinct labels share a string form).
    str_to_label: Dict[str, object] = {}
    for label in df.index:
        str_to_label.setdefault(str(label), label)

    ordered = self._ordered_report_index(list(str_to_label))
    # De-duplicate and keep only labels that actually exist in df.
    placed = [s for s in dict.fromkeys(ordered) if s in str_to_label]
    placed_set = set(placed)
    leftovers = [s for s in str_to_label if s not in placed_set]
    return df.reindex([str_to_label[s] for s in placed + leftovers])
185
+
49
186
  # ---------- Core data prep ----------
50
187
  def to_dataframe(self, report: Dict[str, Dict[str, float]]) -> pd.DataFrame:
51
188
  """Convert sklearn classification_report output_dict to a tidy DataFrame.
52
189
 
53
- This method standardizes the output of scikit-learn's classification_report function.
54
-
55
190
  Args:
56
191
  report (Dict[str, Dict[str, float]]): Dictionary from `classification_report(..., output_dict=True)`.
57
192
 
@@ -77,17 +212,12 @@ class ClassificationReportVisualizer:
77
212
 
78
213
  df.index = df.index.astype(str)
79
214
 
80
- is_avg = df.index.str.contains("avg", case=False, regex=True)
81
- class_df = df.loc[~is_avg].copy()
82
- avg_df = df.loc[is_avg].copy()
83
-
84
215
  num_cols = ["precision", "recall", "f1-score", "support"]
85
- class_df[num_cols] = class_df[num_cols].apply(pd.to_numeric, errors="coerce")
86
- avg_df[num_cols] = avg_df[num_cols].apply(pd.to_numeric, errors="coerce")
216
+ df[num_cols] = df[num_cols].apply(pd.to_numeric, errors="coerce")
87
217
 
88
- class_df = class_df.sort_index()
89
- tidy = pd.concat([class_df, avg_df], axis=0)
90
- return tidy
218
+ # Apply canonical ordering (classes then avg rows)
219
+ df = self._apply_ordering(df)
220
+ return df
91
221
 
92
222
  def compute_ci(
93
223
  self,
@@ -106,21 +236,21 @@ class ClassificationReportVisualizer:
106
236
  pd.DataFrame: Multi-index columns with (metric, ["lower","upper","mean"]). Index contains any class/avg labels present in the bootstrap reports.
107
237
  """
108
238
  if not boot_reports:
109
- raise ValueError("boot_reports is empty; provide at least one dict.")
239
+ msg = "boot_reports is empty; provide at least one dict."
240
+ raise ValueError(msg)
110
241
 
111
- # Gather frames; union of indices (classes/avg rows) across repeats
112
- frames = []
113
- for rep in boot_reports:
114
- df = self.to_dataframe(rep)
115
- frames.append(df)
242
+ frames = [self.to_dataframe(rep) for rep in boot_reports]
243
+
244
+ # Union of indices across repeats, ordered canonically.
245
+ union_idx = set().union(*[set(f.index) for f in frames])
246
+ common_index = self._ordered_report_index(list(union_idx))
247
+ common_index = [x for x in common_index if x in union_idx]
116
248
 
117
- # Align on index, stack into 3D array (repeat x class x metric)
118
- common_index = sorted(set().union(*[f.index for f in frames]))
119
249
  arrs = []
120
250
  for f in frames:
121
251
  sub = f.reindex(common_index)
122
252
  arrs.append(sub[[m for m in metrics]].to_numpy(dtype=float))
123
- arr = np.stack(arrs, axis=0) # shape: (B, C, M)
253
+ arr = np.stack(arrs, axis=0) # (B, C, M)
124
254
 
125
255
  alpha = (1 - ci) / 2
126
256
  lower_q = 100 * alpha
@@ -145,22 +275,17 @@ class ClassificationReportVisualizer:
145
275
  def _retro_cmap(self, n: int = 256) -> LinearSegmentedColormap:
146
276
  """Create a neon gradient colormap.
147
277
 
148
- This colormap transitions through a series of bright, neon colors.
149
-
150
278
  Args:
151
- n (int): Number of discrete colors in the colormap. Defaults to 256.
279
+ n (int): Number of discrete colors in the colormap.
152
280
 
153
281
  Returns:
154
- LinearSegmentedColormap: The generated colormap.
282
+ LinearSegmentedColormap: Neon-themed colormap.
155
283
  """
156
284
  anchors = ["#241937", "#7d00ff", "#ff00ff", "#ff6ec7", "#00f0ff", "#00ff9f"]
157
285
  return LinearSegmentedColormap.from_list("retro_neon", anchors, N=n)
158
286
 
159
287
  def _set_mpl_style(self) -> None:
160
- """Apply a dark neon Matplotlib theme.
161
-
162
- This method modifies global rcParams; call before plotting.
163
- """
288
+ """Apply a dark neon Matplotlib theme."""
164
289
  plt.rcParams.update(
165
290
  {
166
291
  "figure.facecolor": self.background_hex,
@@ -201,26 +326,29 @@ class ClassificationReportVisualizer:
201
326
  ):
202
327
  """Plot a per-class heatmap with an optional right-hand support strip.
203
328
 
204
- This visualizes the classification metrics for each class.
205
-
206
329
  Args:
207
- df (pd.DataFrame): DataFrame from `to_dataframe()`.
330
+ df (pd.DataFrame): DataFrame from to_dataframe().
208
331
  title (str): Plot title.
209
- classes_only (bool): If True, exclude avg rows.
210
- figsize (Tuple[int, int]): Matplotlib figure size.
332
+ classes_only (bool): Whether to include only classes (exclude avg rows).
333
+ figsize (Tuple[int, int]): Figure size.
211
334
  annot_decimals (int): Decimal places for annotations.
212
- vmax (float): Max heatmap value.
213
- vmin (float): Min heatmap value.
214
- show_support_strip (bool): If True, draw normalized support strip at right.
335
+ vmax (float): Max value for colormap scaling.
336
+ vmin (float): Min value for colormap scaling.
337
+ show_support_strip (bool): Whether to show a support strip on the right.
215
338
 
216
339
  Returns:
217
- matplotlib.figure.Figure: The created figure.
340
+ Figure: Matplotlib figure.
218
341
  """
219
342
  self._set_mpl_style()
220
343
 
221
344
  work = df.copy()
345
+ # Ensure canonical ordering even if caller didn't use to_dataframe().
346
+ work = self._apply_ordering(work)
347
+
222
348
  if classes_only:
223
349
  work = work[~work.index.str.contains("avg", case=False, regex=True)]
350
+ # Re-apply class ordering after filtering
351
+ work = work.reindex(self._ordered_class_labels(work.index))
224
352
 
225
353
  metric_cols = ["precision", "recall", "f1-score"]
226
354
  heat = work[metric_cols].astype(float)
@@ -243,7 +371,6 @@ class ClassificationReportVisualizer:
243
371
  ax.set_xlabel("Metric")
244
372
  ax.set_ylabel("Class")
245
373
 
246
- # Optional support strip (normalized 0..1) as an inset axis
247
374
  if show_support_strip and "support" in work.columns:
248
375
  supports = work["support"].astype(float).fillna(0.0).to_numpy()
249
376
  sup_norm = (supports - supports.min()) / (np.ptp(supports) + 1e-9)
@@ -258,7 +385,6 @@ class ClassificationReportVisualizer:
258
385
  )
259
386
 
260
387
  strip_data = sup_norm[:, None] # (n_classes, 1)
261
-
262
388
  sns.heatmap(
263
389
  strip_data,
264
390
  cmap=self._retro_cmap(),
@@ -271,8 +397,6 @@ class ClassificationReportVisualizer:
271
397
  linewidths=0.0,
272
398
  ax=ax_strip,
273
399
  )
274
-
275
- # Align strip y-limits to main heatmap
276
400
  ax_strip.set_ylim(ax.get_ylim())
277
401
 
278
402
  return fig
@@ -289,23 +413,32 @@ class ClassificationReportVisualizer:
289
413
  """Plot grouped bars for P/R/F1 with support markers and optional CI.
290
414
 
291
415
  Args:
292
- df (pd.DataFrame): DataFrame from `to_dataframe()`.
416
+ df (pd.DataFrame): DataFrame from to_dataframe().
293
417
  title (str): Plot title.
294
- classes_only (bool): If True, exclude avg rows.
418
+ classes_only (bool): Whether to include only classes (exclude avg rows).
295
419
  figsize (Tuple[int, int]): Figure size.
296
- bar_alpha (float): Bar alpha.
297
- ci_df (Optional[pd.DataFrame]): Output of `compute_ci()`; adds error bars if provided.
420
+ bar_alpha (float): Alpha transparency for bars.
421
+ ci_df (Optional[pd.DataFrame]): DataFrame from compute_ci() for CI bars (optional).
298
422
 
299
423
  Returns:
300
- matplotlib.figure.Figure: The created figure.
424
+ Figure: Matplotlib figure.
301
425
  """
302
426
  self._set_mpl_style()
427
+
303
428
  work = df.copy()
429
+ work = self._apply_ordering(work)
430
+
304
431
  if classes_only:
305
432
  work = work[~work.index.str.contains("avg", case=False, regex=True)]
433
+ classes = self._ordered_class_labels(work.index)
434
+ work = work.reindex(classes)
435
+ else:
436
+ # If including avgs, only plot classes on x-axis for bars.
437
+ classes = self._ordered_class_labels(
438
+ work.loc[~work.index.str.contains("avg", case=False, regex=True)].index
439
+ )
306
440
 
307
441
  metric_cols = ["precision", "recall", "f1-score"]
308
-
309
442
  lng = (
310
443
  work[metric_cols]
311
444
  .reset_index(names="class")
@@ -313,11 +446,6 @@ class ClassificationReportVisualizer:
313
446
  .dropna(subset=["score"])
314
447
  )
315
448
 
316
- homozygote_order = ["A", "C", "G", "T"]
317
- classes = homozygote_order + [
318
- c for c in lng["class"].unique().tolist() if c not in homozygote_order
319
- ]
320
-
321
449
  metrics = metric_cols
322
450
  palette = self.retro_palette[: len(metrics)]
323
451
 
@@ -327,9 +455,12 @@ class ClassificationReportVisualizer:
327
455
 
328
456
  fig, ax = plt.subplots(figsize=figsize)
329
457
 
330
- # Secondary axis for support markers
331
458
  ax2 = ax.twinx()
332
- supports = work.reindex(classes)["support"].astype(float).fillna(0.0).values
459
+ supports = (
460
+ work.reindex(classes)["support"].astype(float).fillna(0.0).to_numpy()
461
+ if "support" in work.columns
462
+ else np.zeros(len(classes), dtype=float)
463
+ )
333
464
 
334
465
  ax2.plot(
335
466
  x,
@@ -343,13 +474,12 @@ class ClassificationReportVisualizer:
343
474
  label="Support",
344
475
  )
345
476
 
346
- # Plot bars with optional CI error bars
347
477
  for i, m in enumerate(metrics):
348
478
  vals = (
349
479
  lng.loc[lng["metric"].eq(m)]
350
480
  .set_index("class")
351
481
  .reindex(classes)["score"]
352
- .values
482
+ .to_numpy(dtype=float)
353
483
  )
354
484
 
355
485
  yerr = None
@@ -357,14 +487,12 @@ class ClassificationReportVisualizer:
357
487
  ci_reindexed = ci_df.reindex(classes)
358
488
  lows = ci_reindexed[(m, "lower")].to_numpy(dtype=float)
359
489
  ups = ci_reindexed[(m, "upper")].to_numpy(dtype=float)
360
-
361
- # Convert to symmetric error around the point estimate
362
490
  center = vals
363
491
  yerr = np.vstack([center - lows, ups - center])
364
492
 
365
493
  ax.bar(
366
494
  x + offsets[i],
367
- np.asarray(vals),
495
+ vals,
368
496
  width=width * 0.95,
369
497
  label=m.title(),
370
498
  color=palette[i % len(palette)],
@@ -382,10 +510,9 @@ class ClassificationReportVisualizer:
382
510
  ax.set_title(title, pad=12, fontweight="bold")
383
511
  ax.legend(ncols=3, frameon=True, loc="upper left")
384
512
 
385
- # Configure secondary (support) axis
386
513
  ax2.set_ylabel("Support")
387
514
  ax2.grid(False)
388
- ax2.set_ylim(0, max(1.0, np.asarray(supports).max() * 1.15))
515
+ ax2.set_ylim(0, max(1.0, float(np.asarray(supports).max()) * 1.15))
389
516
  ax2.legend(loc="upper right", frameon=True)
390
517
 
391
518
  ax.grid(axis="y", linestyle="--", alpha=0.6)
@@ -404,34 +531,54 @@ class ClassificationReportVisualizer:
404
531
  ) -> go.Figure:
405
532
  """Interactive radar chart of averages + top-k classes; optional CI bands.
406
533
 
407
- This function creates a radar chart using Plotly, displaying the specified metrics for the top-k classes.
408
-
409
534
  Args:
410
- df (pd.DataFrame): DataFrame from `to_dataframe()`.
411
- title (str): Figure title.
412
- top_k (int): Include up to top_k classes by support (descending).
413
- include_micro (bool): Include micro avg trace if available.
414
- include_macro (bool): Include macro avg trace.
415
- include_weighted (bool): Include weighted avg trace.
416
- ci_df (Optional[pd.DataFrame]): Output of `compute_ci()`; draws semi-transparent CI bands.
535
+ df (pd.DataFrame): DataFrame from to_dataframe().
536
+ title (str): Plot title.
537
+ top_k (int): Number of top classes by support to include.
538
+ include_micro (bool): Whether to include micro avg.
539
+ include_macro (bool): Whether to include macro avg.
540
+ include_weighted (bool): Whether to include weighted avg.
541
+ ci_df (Optional[pd.DataFrame]): DataFrame from compute_ci() for CI bands (optional).
417
542
 
418
543
  Returns:
419
- plotly.graph_objects.Figure: The interactive radar chart.
544
+ go.Figure: Plotly radar figure.
420
545
  """
421
546
  work = df.copy()
547
+ work = self._apply_ordering(work)
422
548
 
423
549
  is_avg = work.index.str.contains("avg", case=False, regex=True)
424
- classes = work.loc[~is_avg].copy().sort_values("support", ascending=False)
425
- if top_k is not None and top_k > 0:
426
- classes = classes.head(top_k)
427
550
 
551
+ # --- choose top-k by support, but order those chosen by canonical genotype order ---
552
+ class_block = work.loc[~is_avg].copy()
553
+ if top_k is not None and top_k > 0 and "support" in class_block.columns:
554
+ top_labels = (
555
+ (
556
+ class_block["support"]
557
+ .astype(float)
558
+ .fillna(0.0)
559
+ .sort_values(ascending=False)
560
+ )
561
+ .head(top_k)
562
+ .index.tolist()
563
+ )
564
+ ordered_top = self._ordered_class_labels(top_labels)
565
+ classes = class_block.reindex(
566
+ [x for x in ordered_top if x in class_block.index]
567
+ )
568
+ else:
569
+ classes = class_block.reindex(self._ordered_class_labels(class_block.index))
570
+
571
+ # --- averages in canonical order ---
572
+ include_map = {
573
+ "micro avg": include_micro,
574
+ "macro avg": include_macro,
575
+ "weighted avg": include_weighted,
576
+ "samples avg": True, # keep if present; user can ignore via flags by removing from avg_order
577
+ }
428
578
  avgs = []
429
- if include_macro and "macro avg" in work.index:
430
- avgs.append(("macro avg", work.loc["macro avg"]))
431
- if include_weighted and "weighted avg" in work.index:
432
- avgs.append(("weighted avg", work.loc["weighted avg"]))
433
- if include_micro and "micro avg" in work.index:
434
- avgs.append(("micro avg", work.loc["micro avg"]))
579
+ for name in self.avg_order:
580
+ if include_map.get(name, True) and name in work.index:
581
+ avgs.append((name, work.loc[name]))
435
582
 
436
583
  metrics = ["precision", "recall", "f1-score"]
437
584
  theta = metrics + [metrics[0]]
@@ -456,7 +603,6 @@ class ClassificationReportVisualizer:
456
603
  lows.append(lows[0])
457
604
  ups.append(ups[0])
458
605
 
459
- # Plotly polar CI band: plot upper path, then lower reversed with fill
460
606
  fig.add_trace(
461
607
  go.Scatterpolar(
462
608
  r=ups,
@@ -474,19 +620,14 @@ class ClassificationReportVisualizer:
474
620
  mode="lines",
475
621
  line=dict(width=0),
476
622
  fill="toself",
477
- fillcolor=(
478
- color.replace("#", "rgba(") if False else None
479
- ), # placeholder
480
623
  hoverinfo="skip",
481
624
  name=f"{name} CI",
482
625
  showlegend=False,
483
626
  opacity=0.20,
484
627
  )
485
628
  )
486
- # Workaround: directly set fillcolor via marker color on last trace
487
- fig.data[-1].fillcolor = f"{color}33" # add ~20% alpha
629
+ fig.data[-1].fillcolor = f"{color}33" # 8-digit hex w/ alpha
488
630
 
489
- # Add average traces with CI first
490
631
  for i, (name, row) in enumerate(avgs):
491
632
  r = [float(row.get(m, np.nan)) for m in metrics]
492
633
  r.append(r[0])
@@ -504,7 +645,6 @@ class ClassificationReportVisualizer:
504
645
  )
505
646
  )
506
647
 
507
- # Add class traces (top-k) with optional CI
508
648
  base_idx = len(avgs)
509
649
  for i, (cls, row) in enumerate(classes[metrics].iterrows()):
510
650
  r = [float(row.get(m, np.nan)) for m in metrics]
@@ -560,16 +700,17 @@ class ClassificationReportVisualizer:
560
700
  """Generate all visuals, with optional CI from bootstrap reports.
561
701
 
562
702
  Args:
563
- report (Dict[str, Dict[str, float]]): The `output_dict=True` classification report (single run).
564
- title_prefix (str): Common prefix for titles.
565
- heatmap_classes_only (bool): Exclude averages in heatmap if True.
566
- radar_top_k (int): Number of top classes (by support) on radar.
567
- boot_reports (Optional[List[Dict[str, Dict[str, float]]]]): Optional list of bootstrap report dicts for CI.
568
- ci (float): Confidence level (e.g., 0.95).
569
- show (bool): If True, call plt.show() for Matplotlib figures.
703
+ report (Dict[str, Dict[str, float]]): Dictionary from `classification_report(..., output_dict=True)`.
704
+ title_prefix (str): Prefix for plot titles.
705
+ heatmap_classes_only (bool): Whether to only plot classes (exclude avg rows) in heatmap.
706
+ radar_top_k (int): Number of top classes by support to include in radar plot.
707
+ boot_reports (Optional[List[Dict[str, Dict[str, float]]]]): Optional list of bootstrap report dicts for CI computation.
708
+ ci (float): Confidence level for CIs.
709
+ show (bool): Whether to display the plots via plt.show().
570
710
 
571
711
  Returns:
572
- Dict[str, Union[matplotlib.figure.Figure, plotly.graph_objects.Figure]]: Keys: {"heatmap_fig", "bars_fig", "radar_fig"}.
712
+ Dict[str, Union[Figure, go.Figure]]: Dictionary with keys:
713
+ "heatmap_fig", "bars_fig", "radar_fig".
573
714
  """
574
715
  df = self.to_dataframe(report)
575
716
  acc = df.attrs.get("accuracy", None)
@@ -599,7 +740,9 @@ class ClassificationReportVisualizer:
599
740
  )
600
741
 
601
742
  if show:
602
- plt.show()
743
+ with warnings.catch_warnings():
744
+ warnings.simplefilter("ignore", UserWarning)
745
+ plt.show()
603
746
 
604
747
  return {
605
748
  "heatmap_fig": heatmap_fig,