pg-sui 0.2.3__py3-none-any.whl → 1.6.14.dev9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (127)
  1. {pg_sui-0.2.3.dist-info → pg_sui-1.6.14.dev9.dist-info}/METADATA +99 -77
  2. pg_sui-1.6.14.dev9.dist-info/RECORD +81 -0
  3. {pg_sui-0.2.3.dist-info → pg_sui-1.6.14.dev9.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.14.dev9.dist-info/entry_points.txt +4 -0
  5. {pg_sui-0.2.3.dist-info → pg_sui-1.6.14.dev9.dist-info/licenses}/LICENSE +0 -0
  6. pg_sui-1.6.14.dev9.dist-info/top_level.txt +1 -0
  7. pgsui/__init__.py +35 -54
  8. pgsui/_version.py +34 -0
  9. pgsui/cli.py +909 -0
  10. pgsui/data_processing/__init__.py +0 -0
  11. pgsui/data_processing/config.py +565 -0
  12. pgsui/data_processing/containers.py +1424 -0
  13. pgsui/data_processing/transformers.py +557 -907
  14. pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  15. pgsui/electron/app/__main__.py +5 -0
  16. pgsui/electron/app/extra-resources/.gitkeep +1 -0
  17. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  18. pgsui/electron/app/icons/icons/128x128.png +0 -0
  19. pgsui/electron/app/icons/icons/16x16.png +0 -0
  20. pgsui/electron/app/icons/icons/24x24.png +0 -0
  21. pgsui/electron/app/icons/icons/256x256.png +0 -0
  22. pgsui/electron/app/icons/icons/32x32.png +0 -0
  23. pgsui/electron/app/icons/icons/48x48.png +0 -0
  24. pgsui/electron/app/icons/icons/512x512.png +0 -0
  25. pgsui/electron/app/icons/icons/64x64.png +0 -0
  26. pgsui/electron/app/icons/icons/icon.icns +0 -0
  27. pgsui/electron/app/icons/icons/icon.ico +0 -0
  28. pgsui/electron/app/main.js +227 -0
  29. pgsui/electron/app/package-lock.json +6894 -0
  30. pgsui/electron/app/package.json +51 -0
  31. pgsui/electron/app/preload.js +15 -0
  32. pgsui/electron/app/server.py +157 -0
  33. pgsui/electron/app/ui/logo.png +0 -0
  34. pgsui/electron/app/ui/renderer.js +131 -0
  35. pgsui/electron/app/ui/styles.css +59 -0
  36. pgsui/electron/app/ui/ui_shim.js +72 -0
  37. pgsui/electron/bootstrap.py +43 -0
  38. pgsui/electron/launch.py +57 -0
  39. pgsui/electron/package.json +14 -0
  40. pgsui/example_data/__init__.py +0 -0
  41. pgsui/example_data/phylip_files/__init__.py +0 -0
  42. pgsui/example_data/phylip_files/test.phy +0 -0
  43. pgsui/example_data/popmaps/__init__.py +0 -0
  44. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  45. pgsui/example_data/structure_files/__init__.py +0 -0
  46. pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
  47. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  48. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  49. pgsui/impute/__init__.py +0 -0
  50. pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
  51. pgsui/impute/deterministic/imputers/mode.py +844 -0
  52. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  53. pgsui/impute/deterministic/imputers/phylo.py +973 -0
  54. pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
  55. pgsui/impute/supervised/__init__.py +0 -0
  56. pgsui/impute/supervised/base.py +343 -0
  57. pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  58. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
  59. pgsui/impute/supervised/imputers/random_forest.py +291 -0
  60. pgsui/impute/unsupervised/__init__.py +0 -0
  61. pgsui/impute/unsupervised/base.py +1118 -0
  62. pgsui/impute/unsupervised/callbacks.py +92 -262
  63. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
  64. pgsui/impute/unsupervised/imputers/autoencoder.py +1285 -0
  65. pgsui/impute/unsupervised/imputers/nlpca.py +1554 -0
  66. pgsui/impute/unsupervised/imputers/ubp.py +1575 -0
  67. pgsui/impute/unsupervised/imputers/vae.py +1228 -0
  68. pgsui/impute/unsupervised/loss_functions.py +261 -0
  69. pgsui/impute/unsupervised/models/__init__.py +0 -0
  70. pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
  71. pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
  72. pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
  73. pgsui/impute/unsupervised/models/vae_model.py +269 -630
  74. pgsui/impute/unsupervised/nn_scorers.py +255 -0
  75. pgsui/utils/__init__.py +0 -0
  76. pgsui/utils/classification_viz.py +608 -0
  77. pgsui/utils/logging_utils.py +22 -0
  78. pgsui/utils/misc.py +35 -480
  79. pgsui/utils/plotting.py +996 -829
  80. pgsui/utils/pretty_metrics.py +290 -0
  81. pgsui/utils/scorers.py +213 -666
  82. pg_sui-0.2.3.dist-info/RECORD +0 -75
  83. pg_sui-0.2.3.dist-info/top_level.txt +0 -3
  84. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  85. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  86. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  87. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  88. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  89. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  90. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  91. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  92. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  93. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  94. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  95. pgsui/example_data/trees/test.iqtree +0 -376
  96. pgsui/example_data/trees/test.qmat +0 -5
  97. pgsui/example_data/trees/test.rate +0 -2033
  98. pgsui/example_data/trees/test.tre +0 -1
  99. pgsui/example_data/trees/test_n10.rate +0 -19
  100. pgsui/example_data/trees/test_n100.rate +0 -109
  101. pgsui/example_data/trees/test_n500.rate +0 -509
  102. pgsui/example_data/trees/test_siterates.txt +0 -2024
  103. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  104. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  105. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  106. pgsui/example_data/vcf_files/test.vcf +0 -244
  107. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  108. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  109. pgsui/impute/estimators.py +0 -1268
  110. pgsui/impute/impute.py +0 -1463
  111. pgsui/impute/simple_imputers.py +0 -1431
  112. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
  113. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
  114. pgsui/impute/unsupervised/keras_classifiers.py +0 -697
  115. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  116. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
  117. pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
  118. pgsui/pg_sui.py +0 -261
  119. pgsui/utils/sequence_tools.py +0 -407
  120. simulation/sim_benchmarks.py +0 -333
  121. simulation/sim_treeparams.py +0 -475
  122. test/__init__.py +0 -0
  123. test/pg_sui_simtest.py +0 -215
  124. test/pg_sui_testing.py +0 -523
  125. test/test.py +0 -151
  126. test/test_pgsui.py +0 -374
  127. test/test_tkc.py +0 -185
pgsui/utils/classification_viz.py (new file)
@@ -0,0 +1,608 @@
+# -*- coding: utf-8 -*-
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import plotly.graph_objects as go
+import seaborn as sns
+from matplotlib.colors import LinearSegmentedColormap
+from mpl_toolkits.axes_grid1.inset_locator import inset_axes
+
+if TYPE_CHECKING:
+    from matplotlib.figure import Figure
+
+
+@dataclass
+class ClassificationReportVisualizer:
+    """Pretty plotting for scikit-learn classification reports (output_dict=True).
+
+    Adds neon cyberpunk aesthetics, a per-class support overlay, and optional bootstrap confidence intervals.
+
+    Attributes:
+        retro_palette: Hex colors for neon vibe.
+        background_hex: Matplotlib/Plotly dark background.
+        grid_hex: Gridline color for dark theme.
+        reset_kwargs: Keyword args for resetting Matplotlib rcParams.
+    """
+
+    retro_palette: List[str] = field(
+        default_factory=lambda: [
+            "#ff00ff",
+            "#9400ff",
+            "#00f0ff",
+            "#00ff9f",
+            "#ff6ec7",
+            "#7d00ff",
+            "#39ff14",
+            "#00bcd4",
+        ]
+    )
+    background_hex: str = "#0a0a15"
+    grid_hex: str = "#2a2a3a"
+    reset_kwargs: Dict[str, bool | str] | None = None
+
+    # ---------- Core data prep ----------
+    def to_dataframe(self, report: Dict[str, Dict[str, float]]) -> pd.DataFrame:
+        """Convert sklearn classification_report output_dict to a tidy DataFrame.
+
+        This method standardizes the output of scikit-learn's classification_report function.
+
+        Args:
+            report (Dict[str, Dict[str, float]]): Dictionary from `classification_report(..., output_dict=True)`.
+
+        Returns:
+            pd.DataFrame: Index are classes/avg rows (str). Columns include ["precision", "recall", "f1-score", "support"]. The "accuracy" scalar (if present) is stored in df.attrs["accuracy"], and the row is removed.
+        """
+        df = pd.DataFrame(report).T
+        for col in ["precision", "recall", "f1-score", "support"]:
+            if col not in df.columns:
+                df[col] = np.nan
+
+        if "accuracy" in df.index:
+            # sklearn puts accuracy scalar in "accuracy" row, usually in 'precision'
+            try:
+                acc_val = df.loc["accuracy", "precision"]
+                if pd.api.types.is_number(acc_val):
+                    df.attrs["accuracy"] = float(str(acc_val))
+            except Exception:
+                squeezed_val = df.loc["accuracy"].squeeze()
+                if pd.api.types.is_number(squeezed_val):
+                    df.attrs["accuracy"] = float(str(squeezed_val))
+            df = df.drop(index="accuracy", errors="ignore")
+
+        df.index = df.index.astype(str)
+
+        is_avg = df.index.str.contains("avg", case=False, regex=True)
+        class_df = df.loc[~is_avg].copy()
+        avg_df = df.loc[is_avg].copy()
+
+        num_cols = ["precision", "recall", "f1-score", "support"]
+        class_df[num_cols] = class_df[num_cols].apply(pd.to_numeric, errors="coerce")
+        avg_df[num_cols] = avg_df[num_cols].apply(pd.to_numeric, errors="coerce")
+
+        class_df = class_df.sort_index()
+        tidy = pd.concat([class_df, avg_df], axis=0)
+        return tidy
+
+    def compute_ci(
+        self,
+        boot_reports: List[Dict[str, Dict[str, float]]],
+        ci: float = 0.95,
+        metrics: Tuple[str, ...] = ("precision", "recall", "f1-score"),
+    ) -> pd.DataFrame:
+        """Compute per-class bootstrap CIs from multiple report dicts.
+
+        Args:
+            boot_reports (List[Dict[str, Dict[str, float]]]): List of `output_dict=True` results over bootstrap repeats.
+            ci (float): Confidence level (e.g., 0.95 for 95%).
+            metrics (Tuple[str, ...]): Metrics to compute bounds for.
+
+        Returns:
+            pd.DataFrame: Multi-index columns with (metric, ["lower","upper","mean"]). Index contains any class/avg labels present in the bootstrap reports.
+        """
+        if not boot_reports:
+            raise ValueError("boot_reports is empty; provide at least one dict.")
+
+        # Gather frames; union of indices (classes/avg rows) across repeats
+        frames = []
+        for rep in boot_reports:
+            df = self.to_dataframe(rep)
+            frames.append(df)
+
+        # Align on index, stack into 3D array (repeat x class x metric)
+        common_index = sorted(set().union(*[f.index for f in frames]))
+        arrs = []
+        for f in frames:
+            sub = f.reindex(common_index)
+            arrs.append(sub[[m for m in metrics]].to_numpy(dtype=float))
+        arr = np.stack(arrs, axis=0)  # shape: (B, C, M)
+
+        alpha = (1 - ci) / 2
+        lower_q = 100 * alpha
+        upper_q = 100 * (1 - alpha)
+
+        lower = np.nanpercentile(arr, lower_q, axis=0)  # (C, M)
+        upper = np.nanpercentile(arr, upper_q, axis=0)  # (C, M)
+        mean = np.nanmean(arr, axis=0)  # (C, M)
+
+        out = pd.DataFrame(index=common_index)
+        column_tuples = []
+        for j, m in enumerate(metrics):
+            out[(m, "lower")] = lower[:, j]
+            out[(m, "upper")] = upper[:, j]
+            out[(m, "mean")] = mean[:, j]
+            column_tuples.extend([(m, "lower"), (m, "upper"), (m, "mean")])
+
+        out.columns = pd.MultiIndex.from_tuples(column_tuples)
+        return out
+
+    # ---------- Palettes & styles ----------
+    def _retro_cmap(self, n: int = 256) -> LinearSegmentedColormap:
+        """Create a neon gradient colormap.
+
+        This colormap transitions through a series of bright, neon colors.
+
+        Args:
+            n (int): Number of discrete colors in the colormap. Defaults to 256.
+
+        Returns:
+            LinearSegmentedColormap: The generated colormap.
+        """
+        anchors = ["#241937", "#7d00ff", "#ff00ff", "#ff6ec7", "#00f0ff", "#00ff9f"]
+        return LinearSegmentedColormap.from_list("retro_neon", anchors, N=n)
+
+    def _set_mpl_style(self) -> None:
+        """Apply a dark neon Matplotlib theme.
+
+        This method modifies global rcParams; call before plotting.
+        """
+        plt.rcParams.update(
+            {
+                "figure.facecolor": self.background_hex,
+                "axes.facecolor": self.background_hex,
+                "axes.edgecolor": self.grid_hex,
+                "axes.labelcolor": "#e8e8ff",
+                "xtick.color": "#d7d7ff",
+                "ytick.color": "#d7d7ff",
+                "grid.color": self.grid_hex,
+                "text.color": "#f7f7ff",
+                "axes.grid": True,
+                "grid.linestyle": "--",
+                "grid.linewidth": 0.5,
+                "legend.facecolor": "#121222",
+                "legend.edgecolor": self.grid_hex,
+            }
+        )
+
+    def _reset_mpl_style(self) -> None:
+        """Reset Matplotlib rcParams to default."""
+        plt.rcParams.update(mpl.rcParamsDefault)
+        mpl.rcParams.update(mpl.rcParamsDefault)
+
+        if self.reset_kwargs is not None:
+            plt.rcParams.update(self.reset_kwargs)
+            mpl.rcParams.update(self.reset_kwargs)
+
+    def plot_heatmap(
+        self,
+        df: pd.DataFrame,
+        title: str = "Classification Report — Per-Class Metrics",
+        classes_only: bool = True,
+        figsize: Tuple[int, int] = (12, 6),
+        annot_decimals: int = 3,
+        vmax: float = 1.0,
+        vmin: float = 0.0,
+        show_support_strip: bool = False,
+    ):
+        """Plot a per-class heatmap with an optional right-hand support strip.
+
+        This visualizes the classification metrics for each class.
+
+        Args:
+            df (pd.DataFrame): DataFrame from `to_dataframe()`.
+            title (str): Plot title.
+            classes_only (bool): If True, exclude avg rows.
+            figsize (Tuple[int, int]): Matplotlib figure size.
+            annot_decimals (int): Decimal places for annotations.
+            vmax (float): Max heatmap value.
+            vmin (float): Min heatmap value.
+            show_support_strip (bool): If True, draw normalized support strip at right.
+
+        Returns:
+            matplotlib.figure.Figure: The created figure.
+        """
+        self._set_mpl_style()
+
+        work = df.copy()
+        if classes_only:
+            work = work[~work.index.str.contains("avg", case=False, regex=True)]
+
+        metric_cols = ["precision", "recall", "f1-score"]
+        heat = work[metric_cols].astype(float)
+
+        fig, ax = plt.subplots(figsize=figsize)
+        cmap = self._retro_cmap()
+        sns.heatmap(
+            heat,
+            annot=True,
+            fmt=f".{annot_decimals}f",
+            cmap=cmap,
+            vmin=vmin,
+            vmax=vmax,
+            linewidths=0.5,
+            linecolor=self.grid_hex,
+            cbar_kws={"label": "Score"},
+            ax=ax,
+        )
+        ax.set_title(title, pad=12, fontweight="bold")
+        ax.set_xlabel("Metric")
+        ax.set_ylabel("Class")
+
+        # Optional support strip (normalized 0..1) as an inset axis
+        if show_support_strip and "support" in work.columns:
+            supports = work["support"].astype(float).fillna(0.0).to_numpy()
+            sup_norm = (supports - supports.min()) / (np.ptp(supports) + 1e-9)
+            ax_strip = inset_axes(
+                ax,
+                width="2%",
+                height="100%",
+                loc="right",
+                bbox_to_anchor=(0.03, 0.0, 1, 1),
+                bbox_transform=ax.transAxes,
+                borderpad=0,
+            )
+
+            strip_data = sup_norm[:, None]  # (n_classes, 1)
+
+            sns.heatmap(
+                strip_data,
+                cmap=self._retro_cmap(),
+                cbar=True,
+                cbar_kws={"label": "Support (normalized)"},
+                xticklabels=False,
+                yticklabels=False,
+                vmin=0.0,
+                vmax=1.0,
+                linewidths=0.0,
+                ax=ax_strip,
+            )
+
+            # Align strip y-limits to main heatmap
+            ax_strip.set_ylim(ax.get_ylim())
+
+        return fig
+
+    def plot_grouped_bars(
+        self,
+        df: pd.DataFrame,
+        title: str = "Per-Class Metrics (Grouped Bars)",
+        classes_only: bool = True,
+        figsize: Tuple[int, int] = (14, 7),
+        bar_alpha: float = 0.9,
+        ci_df: Optional[pd.DataFrame] = None,
+    ):
+        """Plot grouped bars for P/R/F1 with support markers and optional CI.
+
+        Args:
+            df (pd.DataFrame): DataFrame from `to_dataframe()`.
+            title (str): Plot title.
+            classes_only (bool): If True, exclude avg rows.
+            figsize (Tuple[int, int]): Figure size.
+            bar_alpha (float): Bar alpha.
+            ci_df (Optional[pd.DataFrame]): Output of `compute_ci()`; adds error bars if provided.
+
+        Returns:
+            matplotlib.figure.Figure: The created figure.
+        """
+        self._set_mpl_style()
+        work = df.copy()
+        if classes_only:
+            work = work[~work.index.str.contains("avg", case=False, regex=True)]
+
+        metric_cols = ["precision", "recall", "f1-score"]
+
+        lng = (
+            work[metric_cols]
+            .reset_index(names="class")
+            .melt(id_vars="class", var_name="metric", value_name="score")
+            .dropna(subset=["score"])
+        )
+
+        homozygote_order = ["A", "C", "G", "T"]
+        classes = homozygote_order + [
+            c for c in lng["class"].unique().tolist() if c not in homozygote_order
+        ]
+
+        metrics = metric_cols
+        palette = self.retro_palette[: len(metrics)]
+
+        x = np.arange(len(classes))
+        width = 0.25
+        offsets = np.linspace(-width, width, num=len(metrics))
+
+        fig, ax = plt.subplots(figsize=figsize)
+
+        # Secondary axis for support markers
+        ax2 = ax.twinx()
+        supports = work.reindex(classes)["support"].astype(float).fillna(0.0).values
+
+        ax2.plot(
+            x,
+            np.asarray(supports),
+            linestyle="None",
+            marker="o",
+            markersize=6,
+            markerfacecolor="#39ff14",
+            markeredgecolor="#ffffff",
+            alpha=0.9,
+            label="Support",
+        )
+
+        # Plot bars with optional CI error bars
+        for i, m in enumerate(metrics):
+            vals = (
+                lng.loc[lng["metric"].eq(m)]
+                .set_index("class")
+                .reindex(classes)["score"]
+                .values
+            )
+
+            yerr = None
+            if ci_df is not None and (m, "lower") in ci_df.columns:
+                ci_reindexed = ci_df.reindex(classes)
+                lows = ci_reindexed[(m, "lower")].to_numpy(dtype=float)
+                ups = ci_reindexed[(m, "upper")].to_numpy(dtype=float)
+
+                # Convert to symmetric error around the point estimate
+                center = vals
+                yerr = np.vstack([center - lows, ups - center])
+
+            ax.bar(
+                x + offsets[i],
+                np.asarray(vals),
+                width=width * 0.95,
+                label=m.title(),
+                color=palette[i % len(palette)],
+                alpha=bar_alpha,
+                edgecolor="#ffffff",
+                linewidth=0.4,
+                yerr=yerr,
+                error_kw=dict(ecolor="#ffffff", elinewidth=0.9, capsize=3),
+            )
+
+        ax.set_xticks(x)
+        ax.set_xticklabels(classes, rotation=45, ha="right")
+        ax.set_ylim(0, 1.05)
+        ax.set_ylabel("Score")
+        ax.set_title(title, pad=12, fontweight="bold")
+        ax.legend(ncols=3, frameon=True, loc="upper left")
+
+        # Configure secondary (support) axis
+        ax2.set_ylabel("Support")
+        ax2.grid(False)
+        ax2.set_ylim(0, max(1.0, np.asarray(supports).max() * 1.15))
+        ax2.legend(loc="upper right", frameon=True)
+
+        ax.grid(axis="y", linestyle="--", alpha=0.6)
+        plt.tight_layout()
+        return fig
+
+    def plot_radar(
+        self,
+        df: pd.DataFrame,
+        title: str = "Macro/Weighted Averages & Top-K Class Radar",
+        top_k: int = 5,
+        include_micro: bool = True,
+        include_macro: bool = True,
+        include_weighted: bool = True,
+        ci_df: Optional[pd.DataFrame] = None,
+    ) -> go.Figure:
+        """Interactive radar chart of averages + top-k classes; optional CI bands.
+
+        This function creates a radar chart using Plotly, displaying the specified metrics for the top-k classes.
+
+        Args:
+            df (pd.DataFrame): DataFrame from `to_dataframe()`.
+            title (str): Figure title.
+            top_k (int): Include up to top_k classes by support (descending).
+            include_micro (bool): Include micro avg trace if available.
+            include_macro (bool): Include macro avg trace.
+            include_weighted (bool): Include weighted avg trace.
+            ci_df (Optional[pd.DataFrame]): Output of `compute_ci()`; draws semi-transparent CI bands.
+
+        Returns:
+            plotly.graph_objects.Figure: The interactive radar chart.
+        """
+        work = df.copy()
+
+        is_avg = work.index.str.contains("avg", case=False, regex=True)
+        classes = work.loc[~is_avg].copy().sort_values("support", ascending=False)
+        if top_k is not None and top_k > 0:
+            classes = classes.head(top_k)
+
+        avgs = []
+        if include_macro and "macro avg" in work.index:
+            avgs.append(("macro avg", work.loc["macro avg"]))
+        if include_weighted and "weighted avg" in work.index:
+            avgs.append(("weighted avg", work.loc["weighted avg"]))
+        if include_micro and "micro avg" in work.index:
+            avgs.append(("micro avg", work.loc["micro avg"]))
+
+        metrics = ["precision", "recall", "f1-score"]
+        theta = metrics + [metrics[0]]
+
+        fig = go.Figure()
+
+        def _add_ci_band(name: str, color: str):
+            if ci_df is None:
+                return
+            if not all([(m, "lower") in ci_df.columns for m in metrics]):
+                return
+            if name not in ci_df.index:
+                return
+            lows = [
+                float(pd.to_numeric(ci_df.loc[name, (m, "lower")], errors="coerce"))
+                for m in metrics
+            ]
+            ups = [
+                float(pd.to_numeric(ci_df.loc[name, (m, "upper")], errors="coerce"))
+                for m in metrics
+            ]
+            lows.append(lows[0])
+            ups.append(ups[0])
+
+            # Plotly polar CI band: plot upper path, then lower reversed with fill
+            fig.add_trace(
+                go.Scatterpolar(
+                    r=ups,
+                    theta=theta,
+                    mode="lines",
+                    line=dict(width=0),
+                    hoverinfo="skip",
+                    showlegend=False,
+                )
+            )
+            fig.add_trace(
+                go.Scatterpolar(
+                    r=lows[::-1],
+                    theta=theta[::-1],
+                    mode="lines",
+                    line=dict(width=0),
+                    fill="toself",
+                    fillcolor=(
+                        color.replace("#", "rgba(") if False else None
+                    ),  # placeholder
+                    hoverinfo="skip",
+                    name=f"{name} CI",
+                    showlegend=False,
+                    opacity=0.20,
+                )
+            )
+            # Workaround: directly set fillcolor via marker color on last trace
+            fig.data[-1].fillcolor = f"{color}33"  # add ~20% alpha
+
+        # Add average traces with CI first
+        for i, (name, row) in enumerate(avgs):
+            r = [float(row.get(m, np.nan)) for m in metrics]
+            r.append(r[0])
+            color = self.retro_palette[i % len(self.retro_palette)]
+            _add_ci_band(name, color)
+            fig.add_trace(
+                go.Scatterpolar(
+                    r=r,
+                    theta=theta,
+                    name=name.title(),
+                    mode="lines+markers",
+                    line=dict(width=3, color=color),
+                    marker=dict(size=7, color=color),
+                    opacity=0.95,
+                )
+            )
+
+        # Add class traces (top-k) with optional CI
+        base_idx = len(avgs)
+        for i, (cls, row) in enumerate(classes[metrics].iterrows()):
+            r = [float(row.get(m, np.nan)) for m in metrics]
+            r.append(r[0])
+            color = self.retro_palette[(base_idx + i) % len(self.retro_palette)]
+            _add_ci_band(str(cls), color)
+            fig.add_trace(
+                go.Scatterpolar(
+                    r=r,
+                    theta=theta,
+                    name=str(cls),
+                    mode="lines+markers",
+                    line=dict(width=2, color=color),
+                    marker=dict(size=6, color=color),
+                    opacity=0.85,
+                )
+            )
+
+        fig.update_layout(
+            title=title,
+            template="plotly_dark",
+            paper_bgcolor=self.background_hex,
+            plot_bgcolor=self.background_hex,
+            polar=dict(
+                bgcolor="#111122",
+                radialaxis=dict(range=[0, 1.05], showline=True, gridcolor="#33334d"),
+                angularaxis=dict(gridcolor="#33334d"),
+            ),
+            legend=dict(
+                bgcolor="#121222",
+                bordercolor="#2a2a3a",
+                borderwidth=1,
+                orientation="h",
+                yanchor="bottom",
+                y=-0.15,
+                x=0.5,
+                xanchor="center",
+            ),
+        )
+
+        return fig
+
+    def plot_all(
+        self,
+        report: Dict[str, Dict[str, float]],
+        title_prefix: str = "Classification Report",
+        heatmap_classes_only: bool = True,
+        radar_top_k: int = 10,
+        boot_reports: Optional[List[Dict[str, Dict[str, float]]]] = None,
+        ci: float = 0.95,
+        show: bool = True,
+    ) -> Dict[str, Union["Figure", go.Figure]]:
+        """Generate all visuals, with optional CI from bootstrap reports.
+
+        Args:
+            report (Dict[str, Dict[str, float]]): The `output_dict=True` classification report (single run).
+            title_prefix (str): Common prefix for titles.
+            heatmap_classes_only (bool): Exclude averages in heatmap if True.
+            radar_top_k (int): Number of top classes (by support) on radar.
+            boot_reports (Optional[List[Dict[str, Dict[str, float]]]]): Optional list of bootstrap report dicts for CI.
+            ci (float): Confidence level (e.g., 0.95).
+            show (bool): If True, call plt.show() for Matplotlib figures.
+
+        Returns:
+            Dict[str, Union[matplotlib.figure.Figure, plotly.graph_objects.Figure]]: Keys: {"heatmap_fig", "bars_fig", "radar_fig"}.
+        """
+        df = self.to_dataframe(report)
+        acc = df.attrs.get("accuracy", None)
+        acc_str = f" (Accuracy: {acc:.3f})" if isinstance(acc, float) else ""
+
+        ci_df = None
+        if boot_reports:
+            ci_df = self.compute_ci(boot_reports, ci=ci)
+
+        heatmap_fig = self.plot_heatmap(
+            df,
+            title=f"{title_prefix} — Heatmap{acc_str}",
+            classes_only=heatmap_classes_only,
+            show_support_strip=False,
+        )
+        bars_fig = self.plot_grouped_bars(
+            df,
+            title=f"{title_prefix} — Grouped Bars{acc_str}",
+            classes_only=True,
+            ci_df=ci_df,
+        )
+        radar_fig = self.plot_radar(
+            df,
+            title=f"{title_prefix} — Averages & Top-{radar_top_k} Classes",
+            top_k=radar_top_k,
+            ci_df=ci_df,
+        )
+
+        if show:
+            plt.show()
+
+        return {
+            "heatmap_fig": heatmap_fig,
+            "bars_fig": bars_fig,
+            "radar_fig": radar_fig,
+        }
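
Usage note: the sketch below shows one way the new ClassificationReportVisualizer could be driven from scikit-learn output. It is an illustration added for this review, not code shipped in the wheel; the genotype labels, the bootstrap loop, and the zero_division=0 choice are hypothetical.

import numpy as np
from sklearn.metrics import classification_report

from pgsui.utils.classification_viz import ClassificationReportVisualizer

# Hypothetical true/predicted genotype labels; substitute real model output.
y_true = np.array(["A", "C", "G", "T", "A/C", "G/T"] * 50)
y_pred = np.array(["A", "C", "G", "T", "A/C", "A/C"] * 50)

report = classification_report(y_true, y_pred, output_dict=True, zero_division=0)

# Optional bootstrap repeats feed compute_ci() via plot_all(boot_reports=...).
rng = np.random.default_rng(42)
boot_reports = []
for _ in range(100):
    idx = rng.integers(0, len(y_true), size=len(y_true))
    boot_reports.append(
        classification_report(y_true[idx], y_pred[idx], output_dict=True, zero_division=0)
    )

viz = ClassificationReportVisualizer()
figs = viz.plot_all(report, boot_reports=boot_reports, ci=0.95, show=False)
figs["radar_fig"].show()  # Plotly figure; the heatmap and bars are Matplotlib figures
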
pgsui/utils/logging_utils.py (new file)
@@ -0,0 +1,22 @@
+import logging
+
+
+def configure_logger(
+    logger: logging.Logger,
+    *,
+    verbose: bool = False,
+    debug: bool = False,
+    quiet_level: int = logging.ERROR,
+) -> logging.Logger:
+    """Force a logger and its handlers to respect verbose/debug controls."""
+    if debug:
+        level = logging.DEBUG
+    elif verbose:
+        level = logging.INFO
+    else:
+        level = quiet_level
+
+    logger.setLevel(level)
+    for handler in getattr(logger, "handlers", ()):
+        handler.setLevel(level)
+    return logger
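
Usage note: a minimal sketch of applying configure_logger to an existing logger; the logger name and handler below are illustrative, not dictated by the package.

import logging

from pgsui.utils.logging_utils import configure_logger

# Hypothetical logger; any logging.Logger with handlers is adjusted the same way.
logger = logging.getLogger("pgsui.example")
logger.addHandler(logging.StreamHandler())

configure_logger(logger)                # quiet: ERROR and above (default)
configure_logger(logger, debug=True)    # DEBUG and above on logger and handlers
configure_logger(logger, verbose=True)  # INFO and above
logger.info("visible because verbose=True was applied last")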