pg-sui 1.0.2.1__py3-none-any.whl → 1.6.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pg-sui might be problematic. Click here for more details.

Files changed (112)
  1. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/METADATA +51 -70
  2. pg_sui-1.6.8.dist-info/RECORD +78 -0
  3. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.8.dist-info/entry_points.txt +4 -0
  5. pg_sui-1.6.8.dist-info/top_level.txt +1 -0
  6. pgsui/__init__.py +35 -54
  7. pgsui/_version.py +34 -0
  8. pgsui/cli.py +635 -0
  9. pgsui/data_processing/config.py +576 -0
  10. pgsui/data_processing/containers.py +1782 -0
  11. pgsui/data_processing/transformers.py +121 -1103
  12. pgsui/electron/app/__main__.py +5 -0
  13. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  14. pgsui/electron/app/icons/icons/128x128.png +0 -0
  15. pgsui/electron/app/icons/icons/16x16.png +0 -0
  16. pgsui/electron/app/icons/icons/24x24.png +0 -0
  17. pgsui/electron/app/icons/icons/256x256.png +0 -0
  18. pgsui/electron/app/icons/icons/32x32.png +0 -0
  19. pgsui/electron/app/icons/icons/48x48.png +0 -0
  20. pgsui/electron/app/icons/icons/512x512.png +0 -0
  21. pgsui/electron/app/icons/icons/64x64.png +0 -0
  22. pgsui/electron/app/icons/icons/icon.icns +0 -0
  23. pgsui/electron/app/icons/icons/icon.ico +0 -0
  24. pgsui/electron/app/main.js +189 -0
  25. pgsui/electron/app/package-lock.json +6893 -0
  26. pgsui/electron/app/package.json +50 -0
  27. pgsui/electron/app/preload.js +15 -0
  28. pgsui/electron/app/server.py +146 -0
  29. pgsui/electron/app/ui/logo.png +0 -0
  30. pgsui/electron/app/ui/renderer.js +130 -0
  31. pgsui/electron/app/ui/styles.css +59 -0
  32. pgsui/electron/app/ui/ui_shim.js +72 -0
  33. pgsui/electron/bootstrap.py +43 -0
  34. pgsui/electron/launch.py +59 -0
  35. pgsui/electron/package.json +14 -0
  36. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  37. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  38. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  39. pgsui/impute/deterministic/imputers/allele_freq.py +691 -0
  40. pgsui/impute/deterministic/imputers/mode.py +679 -0
  41. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  42. pgsui/impute/deterministic/imputers/phylo.py +971 -0
  43. pgsui/impute/deterministic/imputers/ref_allele.py +530 -0
  44. pgsui/impute/supervised/base.py +339 -0
  45. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +293 -0
  46. pgsui/impute/supervised/imputers/random_forest.py +287 -0
  47. pgsui/impute/unsupervised/base.py +924 -0
  48. pgsui/impute/unsupervised/callbacks.py +89 -263
  49. pgsui/impute/unsupervised/imputers/autoencoder.py +972 -0
  50. pgsui/impute/unsupervised/imputers/nlpca.py +1264 -0
  51. pgsui/impute/unsupervised/imputers/ubp.py +1288 -0
  52. pgsui/impute/unsupervised/imputers/vae.py +957 -0
  53. pgsui/impute/unsupervised/loss_functions.py +158 -0
  54. pgsui/impute/unsupervised/models/autoencoder_model.py +208 -558
  55. pgsui/impute/unsupervised/models/nlpca_model.py +149 -468
  56. pgsui/impute/unsupervised/models/ubp_model.py +198 -1317
  57. pgsui/impute/unsupervised/models/vae_model.py +259 -618
  58. pgsui/impute/unsupervised/nn_scorers.py +215 -0
  59. pgsui/utils/classification_viz.py +591 -0
  60. pgsui/utils/misc.py +35 -480
  61. pgsui/utils/plotting.py +514 -824
  62. pgsui/utils/scorers.py +212 -438
  63. pg_sui-1.0.2.1.dist-info/RECORD +0 -75
  64. pg_sui-1.0.2.1.dist-info/top_level.txt +0 -3
  65. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  66. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  67. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  68. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  69. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  70. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  71. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  72. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  73. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  74. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  75. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  76. pgsui/example_data/trees/test.iqtree +0 -376
  77. pgsui/example_data/trees/test.qmat +0 -5
  78. pgsui/example_data/trees/test.rate +0 -2033
  79. pgsui/example_data/trees/test.tre +0 -1
  80. pgsui/example_data/trees/test_n10.rate +0 -19
  81. pgsui/example_data/trees/test_n100.rate +0 -109
  82. pgsui/example_data/trees/test_n500.rate +0 -509
  83. pgsui/example_data/trees/test_siterates.txt +0 -2024
  84. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  85. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  86. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  87. pgsui/example_data/vcf_files/test.vcf +0 -244
  88. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  89. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  90. pgsui/impute/estimators.py +0 -735
  91. pgsui/impute/impute.py +0 -1486
  92. pgsui/impute/simple_imputers.py +0 -1439
  93. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -785
  94. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1027
  95. pgsui/impute/unsupervised/keras_classifiers.py +0 -702
  96. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  97. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1424
  98. pgsui/impute/unsupervised/neural_network_methods.py +0 -1549
  99. pgsui/pg_sui.py +0 -261
  100. pgsui/utils/sequence_tools.py +0 -407
  101. simulation/sim_benchmarks.py +0 -333
  102. simulation/sim_treeparams.py +0 -475
  103. test/__init__.py +0 -0
  104. test/pg_sui_simtest.py +0 -215
  105. test/pg_sui_testing.py +0 -523
  106. test/test.py +0 -297
  107. test/test_pgsui.py +0 -374
  108. test/test_tkc.py +0 -214
  109. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info/licenses}/LICENSE +0 -0
  110. /pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  111. /pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  112. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
@@ -0,0 +1,591 @@
1
+ # -*- coding: utf-8 -*-
2
+ from __future__ import annotations
3
+
4
+ from dataclasses import dataclass, field
5
+ from typing import Dict, List, Optional, Tuple, Union
6
+
7
+ import matplotlib as mpl
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+ import pandas as pd
11
+ import plotly.graph_objects as go
12
+ import seaborn as sns
13
+ from matplotlib.colors import LinearSegmentedColormap
14
+ from mpl_toolkits.axes_grid1.inset_locator import inset_axes
15
+
16
+
17
@dataclass
class ClassificationReportVisualizer:
    """Pretty plotting for scikit-learn classification reports (output_dict=True).

    Adds neon cyberpunk aesthetics, a per-class support overlay, and optional
    bootstrap confidence intervals.

    Attributes:
        retro_palette: Hex colors (``#rrggbb``) for the neon vibe.
        background_hex: Matplotlib/Plotly dark background.
        grid_hex: Gridline color for dark theme.
        reset_kwargs: Extra rcParams applied after Matplotlib defaults are
            restored by ``_reset_mpl_style``.
    """

    retro_palette: List[str] = field(
        default_factory=lambda: [
            "#ff00ff",
            "#9400ff",
            "#00f0ff",
            "#00ff9f",
            "#ff6ec7",
            "#7d00ff",
            "#39ff14",
            "#00bcd4",
        ]
    )
    background_hex: str = "#0a0a15"
    grid_hex: str = "#2a2a3a"
    reset_kwargs: Dict[str, bool | str] | None = None

    # ---------- Core data prep ----------
    def to_dataframe(self, report: Dict[str, Dict[str, float]]) -> pd.DataFrame:
        """Convert sklearn classification_report output_dict to a tidy DataFrame.

        This method standardizes the output of scikit-learn's
        classification_report function.

        Args:
            report (Dict[str, Dict[str, float]]): Dictionary from
                `classification_report(..., output_dict=True)`.

        Returns:
            pd.DataFrame: The index holds class labels (sorted), followed by the
            average rows, all as str. Columns include ["precision", "recall",
            "f1-score", "support"], coerced to numeric. The "accuracy" scalar
            (if present) is stored in df.attrs["accuracy"], and its row is
            removed.
        """
        num_cols = ["precision", "recall", "f1-score", "support"]

        df = pd.DataFrame(report).T
        for col in num_cols:
            if col not in df.columns:
                df[col] = np.nan

        accuracy: Optional[float] = None
        if "accuracy" in df.index:
            # sklearn stores accuracy as a bare scalar. pandas broadcasts it
            # across the whole "accuracy" row, so the "precision" cell holds it.
            accuracy = float(
                pd.to_numeric(df.loc["accuracy", "precision"], errors="coerce")
            )
            df = df.drop(index="accuracy")

        df.index = df.index.astype(str)

        is_avg = df.index.str.contains("avg", case=False, regex=True)
        class_df = df.loc[~is_avg].copy()
        avg_df = df.loc[is_avg].copy()

        class_df[num_cols] = class_df[num_cols].apply(pd.to_numeric, errors="coerce")
        avg_df[num_cols] = avg_df[num_cols].apply(pd.to_numeric, errors="coerce")

        tidy = pd.concat([class_df.sort_index(), avg_df], axis=0)

        # Set attrs on the final frame explicitly. Not every pandas version
        # propagates ``attrs`` through drop/loc/concat.
        if accuracy is not None:
            tidy.attrs["accuracy"] = accuracy
        return tidy

    def compute_ci(
        self,
        boot_reports: List[Dict[str, Dict[str, float]]],
        ci: float = 0.95,
        metrics: Tuple[str, ...] = ("precision", "recall", "f1-score"),
    ) -> pd.DataFrame:
        """Compute per-class bootstrap CIs from multiple report dicts.

        Args:
            boot_reports (List[Dict[str, Dict[str, float]]]): List of
                `output_dict=True` results over bootstrap repeats.
            ci (float): Confidence level in (0, 1] (e.g., 0.95 for 95%).
            metrics (Tuple[str, ...]): Metrics to compute bounds for.

        Returns:
            pd.DataFrame: Multi-index columns with (metric, ["lower", "upper",
            "mean"]). The index is the sorted union of class/avg labels present
            in any bootstrap report.

        Raises:
            ValueError: If ``boot_reports`` is empty or ``ci`` is outside (0, 1].
        """
        if not boot_reports:
            raise ValueError("boot_reports is empty; provide at least one dict.")
        if not 0.0 < ci <= 1.0:
            raise ValueError(f"ci must be in the interval (0, 1]; got {ci!r}.")

        frames = [self.to_dataframe(rep) for rep in boot_reports]

        # Use the union of row labels across repeats. A repeat that lacks a
        # label contributes NaN, which the nan-aware reductions below skip.
        common_index = sorted(set().union(*(f.index for f in frames)))
        cols = list(metrics)
        arr = np.stack(
            [f.reindex(common_index)[cols].to_numpy(dtype=float) for f in frames],
            axis=0,
        )  # shape: (B, C, M)

        alpha = (1.0 - ci) / 2.0
        lower = np.nanpercentile(arr, 100.0 * alpha, axis=0)  # (C, M)
        upper = np.nanpercentile(arr, 100.0 * (1.0 - alpha), axis=0)  # (C, M)
        mean = np.nanmean(arr, axis=0)  # (C, M)

        out = pd.DataFrame(index=common_index)
        for j, m in enumerate(metrics):
            out[(m, "lower")] = lower[:, j]
            out[(m, "upper")] = upper[:, j]
            out[(m, "mean")] = mean[:, j]

        out.columns = pd.MultiIndex.from_tuples(out.columns)
        return out

    # ---------- Palettes & styles ----------
    def _retro_cmap(self, n: int = 256) -> LinearSegmentedColormap:
        """Create a neon gradient colormap.

        This colormap transitions through a series of bright, neon colors.

        Args:
            n (int): Number of discrete colors in the colormap. Defaults to 256.

        Returns:
            LinearSegmentedColormap: The generated colormap.
        """
        anchors = ["#241937", "#7d00ff", "#ff00ff", "#ff6ec7", "#00f0ff", "#00ff9f"]
        return LinearSegmentedColormap.from_list("retro_neon", anchors, N=n)

    @staticmethod
    def _hex_to_rgba(hex_color: str, alpha: float) -> str:
        """Convert ``#rgb``/``#rrggbb`` plus an alpha into a CSS ``rgba()`` string.

        Plotly accepts ``rgba(r, g, b, a)`` everywhere a color is allowed. This
        avoids relying on 8-digit hex colors.

        Args:
            hex_color (str): Color such as "#ff00ff" or "#f0f".
            alpha (float): Opacity in [0, 1].

        Returns:
            str: e.g. "rgba(255, 0, 255, 0.2)".

        Raises:
            ValueError: If ``hex_color`` is not a 3- or 6-digit hex color.
        """
        digits = hex_color.lstrip("#")
        if len(digits) == 3:
            digits = "".join(ch * 2 for ch in digits)
        if len(digits) != 6:
            raise ValueError(f"Expected a #rgb or #rrggbb color; got {hex_color!r}.")
        r, g, b = (int(digits[k : k + 2], 16) for k in (0, 2, 4))
        return f"rgba({r}, {g}, {b}, {alpha})"

    @staticmethod
    def _order_classes(labels: List[str]) -> List[str]:
        """Order class labels with homozygotes (A, C, G, T) first.

        Only homozygote labels actually present in ``labels`` are included, so
        reports without nucleotide classes keep their original order. The
        remaining labels follow in their first-seen order.

        Args:
            labels (List[str]): Class labels (duplicates are collapsed).

        Returns:
            List[str]: Ordered, de-duplicated labels.
        """
        homozygote_order = ["A", "C", "G", "T"]
        present = list(dict.fromkeys(labels))
        head = [c for c in homozygote_order if c in present]
        tail = [c for c in present if c not in homozygote_order]
        return head + tail

    def _set_mpl_style(self) -> None:
        """Apply a dark neon Matplotlib theme.

        This method modifies global rcParams; call before plotting.
        """
        plt.rcParams.update(
            {
                "figure.facecolor": self.background_hex,
                "axes.facecolor": self.background_hex,
                "axes.edgecolor": self.grid_hex,
                "axes.labelcolor": "#e8e8ff",
                "xtick.color": "#d7d7ff",
                "ytick.color": "#d7d7ff",
                "grid.color": self.grid_hex,
                "text.color": "#f7f7ff",
                "axes.grid": True,
                "grid.linestyle": "--",
                "grid.linewidth": 0.5,
                "legend.facecolor": "#121222",
                "legend.edgecolor": self.grid_hex,
            }
        )

    def _reset_mpl_style(self) -> None:
        """Restore Matplotlib rcParams to defaults, then apply ``reset_kwargs``.

        Uses ``mpl.rcdefaults()``, which leaves style-blacklisted keys such as
        ``backend`` untouched. A raw ``rcParams.update(rcParamsDefault)`` would
        also overwrite the active backend setting. ``plt.rcParams`` and
        ``mpl.rcParams`` are the same object, so one update suffices.
        """
        mpl.rcdefaults()
        if self.reset_kwargs is not None:
            mpl.rcParams.update(self.reset_kwargs)

    def plot_heatmap(
        self,
        df: pd.DataFrame,
        title: str = "Classification Report — Per-Class Metrics",
        classes_only: bool = True,
        figsize: Tuple[int, int] = (12, 6),
        annot_decimals: int = 3,
        vmax: float = 1.0,
        vmin: float = 0.0,
        show_support_strip: bool = False,
    ):
        """Plot a per-class heatmap with an optional right-hand support strip.

        This visualizes the classification metrics for each class.

        Args:
            df (pd.DataFrame): DataFrame from `to_dataframe()`.
            title (str): Plot title.
            classes_only (bool): If True, exclude avg rows.
            figsize (Tuple[int, int]): Matplotlib figure size.
            annot_decimals (int): Decimal places for annotations.
            vmax (float): Max heatmap value.
            vmin (float): Min heatmap value.
            show_support_strip (bool): If True, draw normalized support strip at right.

        Returns:
            matplotlib.figure.Figure: The created figure.
        """
        self._set_mpl_style()

        work = df.copy()
        if classes_only:
            work = work[~work.index.str.contains("avg", case=False, regex=True)]

        metric_cols = ["precision", "recall", "f1-score"]
        heat = work[metric_cols].astype(float)

        fig, ax = plt.subplots(figsize=figsize)
        cmap = self._retro_cmap()
        sns.heatmap(
            heat,
            annot=True,
            fmt=f".{annot_decimals}f",
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            linewidths=0.5,
            linecolor=self.grid_hex,
            cbar_kws={"label": "Score"},
            ax=ax,
        )
        ax.set_title(title, pad=12, fontweight="bold")
        ax.set_xlabel("Metric")
        ax.set_ylabel("Class")

        # Optional support strip (min-max normalized to 0..1) as an inset axis.
        if show_support_strip and "support" in work.columns:
            supports = work["support"].astype(float).fillna(0.0).values
            # The epsilon keeps an all-equal support vector from dividing by zero.
            sup_norm = (supports - supports.min()) / (np.ptp(supports) + 1e-9)
            ax_strip = inset_axes(
                ax,
                width="2%",
                height="100%",
                loc="right",
                bbox_to_anchor=(0.03, 0.0, 1, 1),
                bbox_transform=ax.transAxes,
                borderpad=0,
            )

            strip_data = sup_norm[:, None]  # (n_classes, 1)

            sns.heatmap(
                strip_data,
                cmap=self._retro_cmap(),
                cbar=True,
                cbar_kws={"label": "Support (normalized)"},
                xticklabels=False,
                yticklabels=False,
                vmin=0.0,
                vmax=1.0,
                linewidths=0.0,
                ax=ax_strip,
            )

            # Align strip y-limits to main heatmap so rows line up.
            ax_strip.set_ylim(ax.get_ylim())

        return fig

    def plot_grouped_bars(
        self,
        df: pd.DataFrame,
        title: str = "Per-Class Metrics (Grouped Bars)",
        classes_only: bool = True,
        figsize: Tuple[int, int] = (14, 7),
        bar_alpha: float = 0.9,
        ci_df: Optional[pd.DataFrame] = None,
    ):
        """Plot grouped bars for P/R/F1 with support markers and optional CI.

        Classes are ordered with any homozygote labels (A, C, G, T) present in
        the report first, followed by the remaining classes.

        Args:
            df (pd.DataFrame): DataFrame from `to_dataframe()`.
            title (str): Plot title.
            classes_only (bool): If True, exclude avg rows.
            figsize (Tuple[int, int]): Figure size.
            bar_alpha (float): Bar alpha.
            ci_df (Optional[pd.DataFrame]): Output of `compute_ci()`; adds error
                bars if provided. Classes missing from ``ci_df`` get no error bar.

        Returns:
            matplotlib.figure.Figure: The created figure.
        """
        self._set_mpl_style()
        work = df.copy()
        if classes_only:
            work = work[~work.index.str.contains("avg", case=False, regex=True)]

        metrics = ["precision", "recall", "f1-score"]

        lng = (
            work[metrics]
            .reset_index(names="class")
            .melt(id_vars="class", var_name="metric", value_name="score")
            .dropna(subset=["score"])
        )

        # Only labels present in the report; absent homozygotes are not padded in.
        classes = self._order_classes(lng["class"].unique().tolist())
        palette = self.retro_palette[: len(metrics)]

        x = np.arange(len(classes))
        width = 0.25
        offsets = np.linspace(-width, width, num=len(metrics))

        fig, ax = plt.subplots(figsize=figsize)

        # Secondary axis for support markers
        ax2 = ax.twinx()
        supports = (
            work.reindex(classes)["support"].astype(float).fillna(0.0).to_numpy()
        )

        ax2.plot(
            x,
            supports,
            linestyle="None",
            marker="o",
            markersize=6,
            markerfacecolor="#39ff14",
            markeredgecolor="#ffffff",
            alpha=0.9,
            label="Support",
        )

        # reindex (not .loc) so labels missing from ci_df become NaN rows
        # instead of raising KeyError.
        ci_rows = ci_df.reindex(classes) if ci_df is not None else None

        # Plot bars with optional CI error bars
        for i, m in enumerate(metrics):
            vals = (
                lng.loc[lng["metric"].eq(m)]
                .set_index("class")
                .reindex(classes)["score"]
                .to_numpy(dtype=float)
            )

            yerr = None
            if (
                ci_rows is not None
                and (m, "lower") in ci_rows.columns
                and (m, "upper") in ci_rows.columns
            ):
                lows = ci_rows[(m, "lower")].to_numpy(dtype=float)
                ups = ci_rows[(m, "upper")].to_numpy(dtype=float)

                # Asymmetric error lengths around the point estimate. The
                # single-run estimate can fall outside the bootstrap interval,
                # and Matplotlib rejects negative error lengths, so clip at 0
                # (NaN is preserved and simply draws no bar).
                yerr = np.clip(np.vstack([vals - lows, ups - vals]), 0.0, None)

            ax.bar(
                x + offsets[i],
                vals,
                width=width * 0.95,
                label=m.title(),
                color=palette[i % len(palette)],
                alpha=bar_alpha,
                edgecolor="#ffffff",
                linewidth=0.4,
                yerr=yerr,
                error_kw=dict(ecolor="#ffffff", elinewidth=0.9, capsize=3),
            )

        ax.set_xticks(x)
        ax.set_xticklabels(classes, rotation=45, ha="right")
        ax.set_ylim(0, 1.05)
        ax.set_ylabel("Score")
        ax.set_title(title, pad=12, fontweight="bold")
        ax.legend(ncols=3, frameon=True, loc="upper left")

        # Configure secondary (support) axis; guard the empty-class case.
        max_support = float(supports.max()) if supports.size else 0.0
        ax2.set_ylabel("Support")
        ax2.grid(False)
        ax2.set_ylim(0, max(1.0, max_support * 1.15))
        ax2.legend(loc="upper right", frameon=True)

        ax.grid(axis="y", linestyle="--", alpha=0.6)
        fig.tight_layout()
        return fig

    def plot_radar(
        self,
        df: pd.DataFrame,
        title: str = "Macro/Weighted Averages & Top-K Class Radar",
        top_k: int = 5,
        include_micro: bool = True,
        include_macro: bool = True,
        include_weighted: bool = True,
        ci_df: Optional[pd.DataFrame] = None,
    ) -> go.Figure:
        """Interactive radar chart of averages + top-k classes; optional CI bands.

        This function creates a radar chart using Plotly, displaying the
        specified metrics for the top-k classes.

        Args:
            df (pd.DataFrame): DataFrame from `to_dataframe()`.
            title (str): Figure title.
            top_k (int): Include up to top_k classes by support (descending).
            include_micro (bool): Include micro avg trace if available.
            include_macro (bool): Include macro avg trace.
            include_weighted (bool): Include weighted avg trace.
            ci_df (Optional[pd.DataFrame]): Output of `compute_ci()`; draws
                semi-transparent CI bands.

        Returns:
            plotly.graph_objects.Figure: The interactive radar chart.
        """
        work = df.copy()

        is_avg = work.index.str.contains("avg", case=False, regex=True)
        classes = work.loc[~is_avg].copy().sort_values("support", ascending=False)
        if top_k is not None and top_k > 0:
            classes = classes.head(top_k)

        avgs = []
        if include_macro and "macro avg" in work.index:
            avgs.append(("macro avg", work.loc["macro avg"]))
        if include_weighted and "weighted avg" in work.index:
            avgs.append(("weighted avg", work.loc["weighted avg"]))
        if include_micro and "micro avg" in work.index:
            avgs.append(("micro avg", work.loc["micro avg"]))

        metrics = ["precision", "recall", "f1-score"]
        theta = metrics + [metrics[0]]  # repeat first axis to close the polygon

        fig = go.Figure()

        def _add_ci_band(name: str, color: str) -> None:
            """Add a filled band between the lower and upper CI polygons of ``name``."""
            if ci_df is None or name not in ci_df.index:
                return
            if not all(
                (m, "lower") in ci_df.columns and (m, "upper") in ci_df.columns
                for m in metrics
            ):
                return
            lows = [float(ci_df.loc[name, (m, "lower")]) for m in metrics]
            ups = [float(ci_df.loc[name, (m, "upper")]) for m in metrics]
            lows.append(lows[0])
            ups.append(ups[0])

            # One closed path: trace the upper polygon forward, then the lower
            # polygon backward. fill="toself" then fills only the ring between
            # them. Filling the lower trace on its own would shade the whole
            # lower polygon.
            fig.add_trace(
                go.Scatterpolar(
                    r=ups + lows[::-1],
                    theta=theta + theta[::-1],
                    mode="lines",
                    line=dict(width=0),
                    fill="toself",
                    fillcolor=self._hex_to_rgba(color, 0.2),  # ~20% alpha
                    hoverinfo="skip",
                    name=f"{name} CI",
                    showlegend=False,
                )
            )

        # Add average traces, each preceded by its CI band so the line draws on top.
        for i, (name, row) in enumerate(avgs):
            r = [float(row.get(m, np.nan)) for m in metrics]
            r.append(r[0])
            color = self.retro_palette[i % len(self.retro_palette)]
            _add_ci_band(name, color)
            fig.add_trace(
                go.Scatterpolar(
                    r=r,
                    theta=theta,
                    name=name.title(),
                    mode="lines+markers",
                    line=dict(width=3, color=color),
                    marker=dict(size=7, color=color),
                    opacity=0.95,
                )
            )

        # Add class traces (top-k) with optional CI; colors continue after averages.
        base_idx = len(avgs)
        for i, (cls, row) in enumerate(classes[metrics].iterrows()):
            r = [float(row.get(m, np.nan)) for m in metrics]
            r.append(r[0])
            color = self.retro_palette[(base_idx + i) % len(self.retro_palette)]
            _add_ci_band(str(cls), color)
            fig.add_trace(
                go.Scatterpolar(
                    r=r,
                    theta=theta,
                    name=str(cls),
                    mode="lines+markers",
                    line=dict(width=2, color=color),
                    marker=dict(size=6, color=color),
                    opacity=0.85,
                )
            )

        fig.update_layout(
            title=title,
            template="plotly_dark",
            paper_bgcolor=self.background_hex,
            plot_bgcolor=self.background_hex,
            polar=dict(
                bgcolor="#111122",
                radialaxis=dict(range=[0, 1.05], showline=True, gridcolor="#33334d"),
                angularaxis=dict(gridcolor="#33334d"),
            ),
            legend=dict(
                bgcolor="#121222",
                bordercolor="#2a2a3a",
                borderwidth=1,
                orientation="h",
                yanchor="bottom",
                y=-0.15,
                x=0.5,
                xanchor="center",
            ),
        )
        return fig

    def plot_all(
        self,
        report: Dict[str, Dict[str, float]],
        title_prefix: str = "Classification Report",
        heatmap_classes_only: bool = True,
        radar_top_k: int = 5,
        boot_reports: Optional[List[Dict[str, Dict[str, float]]]] = None,
        ci: float = 0.95,
        show: bool = True,
    ) -> Dict[str, Union[plt.Figure, go.Figure]]:
        """Generate all visuals, with optional CI from bootstrap reports.

        Args:
            report (Dict[str, Dict[str, float]]): The `output_dict=True`
                classification report (single run).
            title_prefix (str): Common prefix for titles.
            heatmap_classes_only (bool): Exclude averages in heatmap if True.
            radar_top_k (int): Number of top classes (by support) on radar.
            boot_reports (Optional[List[Dict[str, Dict[str, float]]]]): Optional
                list of bootstrap report dicts for CI.
            ci (float): Confidence level (e.g., 0.95).
            show (bool): If True, call plt.show() for Matplotlib figures (the
                Plotly radar figure is only returned).

        Returns:
            Dict[str, Union[matplotlib.figure.Figure, plotly.graph_objects.Figure]]:
            Keys: {"heatmap_fig", "bars_fig", "radar_fig"}.
        """
        df = self.to_dataframe(report)
        acc = df.attrs.get("accuracy", None)
        # Only annotate titles with a real (finite) accuracy value.
        acc_str = (
            f" (Accuracy: {acc:.3f})"
            if isinstance(acc, float) and np.isfinite(acc)
            else ""
        )

        ci_df = None
        if boot_reports:
            ci_df = self.compute_ci(boot_reports, ci=ci)

        heatmap_fig = self.plot_heatmap(
            df,
            title=f"{title_prefix} — Heatmap{acc_str}",
            classes_only=heatmap_classes_only,
            show_support_strip=False,
        )
        bars_fig = self.plot_grouped_bars(
            df,
            title=f"{title_prefix} — Grouped Bars{acc_str}",
            classes_only=True,
            ci_df=ci_df,
        )
        radar_fig = self.plot_radar(
            df,
            title=f"{title_prefix} — Averages & Top-{radar_top_k} Classes",
            top_k=radar_top_k,
            ci_df=ci_df,
        )

        if show:
            plt.show()

        return {
            "heatmap_fig": heatmap_fig,
            "bars_fig": bars_fig,
            "radar_fig": radar_fig,
        }