pg-sui 0.2.0__py3-none-any.whl → 1.6.14.dev9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info}/METADATA +101 -79
- pg_sui-1.6.14.dev9.dist-info/RECORD +81 -0
- {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info}/WHEEL +1 -1
- pg_sui-1.6.14.dev9.dist-info/entry_points.txt +4 -0
- {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info/licenses}/LICENSE +0 -0
- pg_sui-1.6.14.dev9.dist-info/top_level.txt +1 -0
- pgsui/__init__.py +35 -54
- pgsui/_version.py +34 -0
- pgsui/cli.py +909 -0
- pgsui/data_processing/__init__.py +0 -0
- pgsui/data_processing/config.py +565 -0
- pgsui/data_processing/containers.py +1424 -0
- pgsui/data_processing/transformers.py +557 -907
- pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
- pgsui/electron/app/__main__.py +5 -0
- pgsui/electron/app/extra-resources/.gitkeep +1 -0
- pgsui/electron/app/icons/icons/1024x1024.png +0 -0
- pgsui/electron/app/icons/icons/128x128.png +0 -0
- pgsui/electron/app/icons/icons/16x16.png +0 -0
- pgsui/electron/app/icons/icons/24x24.png +0 -0
- pgsui/electron/app/icons/icons/256x256.png +0 -0
- pgsui/electron/app/icons/icons/32x32.png +0 -0
- pgsui/electron/app/icons/icons/48x48.png +0 -0
- pgsui/electron/app/icons/icons/512x512.png +0 -0
- pgsui/electron/app/icons/icons/64x64.png +0 -0
- pgsui/electron/app/icons/icons/icon.icns +0 -0
- pgsui/electron/app/icons/icons/icon.ico +0 -0
- pgsui/electron/app/main.js +227 -0
- pgsui/electron/app/package-lock.json +6894 -0
- pgsui/electron/app/package.json +51 -0
- pgsui/electron/app/preload.js +15 -0
- pgsui/electron/app/server.py +157 -0
- pgsui/electron/app/ui/logo.png +0 -0
- pgsui/electron/app/ui/renderer.js +131 -0
- pgsui/electron/app/ui/styles.css +59 -0
- pgsui/electron/app/ui/ui_shim.js +72 -0
- pgsui/electron/bootstrap.py +43 -0
- pgsui/electron/launch.py +57 -0
- pgsui/electron/package.json +14 -0
- pgsui/example_data/__init__.py +0 -0
- pgsui/example_data/phylip_files/__init__.py +0 -0
- pgsui/example_data/phylip_files/test.phy +0 -0
- pgsui/example_data/popmaps/__init__.py +0 -0
- pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
- pgsui/example_data/structure_files/__init__.py +0 -0
- pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
- pgsui/impute/__init__.py +0 -0
- pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
- pgsui/impute/deterministic/imputers/mode.py +844 -0
- pgsui/impute/deterministic/imputers/nmf.py +221 -0
- pgsui/impute/deterministic/imputers/phylo.py +973 -0
- pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
- pgsui/impute/supervised/__init__.py +0 -0
- pgsui/impute/supervised/base.py +343 -0
- pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
- pgsui/impute/supervised/imputers/random_forest.py +291 -0
- pgsui/impute/unsupervised/__init__.py +0 -0
- pgsui/impute/unsupervised/base.py +1118 -0
- pgsui/impute/unsupervised/callbacks.py +92 -262
- {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
- pgsui/impute/unsupervised/imputers/autoencoder.py +1285 -0
- pgsui/impute/unsupervised/imputers/nlpca.py +1554 -0
- pgsui/impute/unsupervised/imputers/ubp.py +1575 -0
- pgsui/impute/unsupervised/imputers/vae.py +1228 -0
- pgsui/impute/unsupervised/loss_functions.py +261 -0
- pgsui/impute/unsupervised/models/__init__.py +0 -0
- pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
- pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
- pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
- pgsui/impute/unsupervised/models/vae_model.py +269 -630
- pgsui/impute/unsupervised/nn_scorers.py +255 -0
- pgsui/utils/__init__.py +0 -0
- pgsui/utils/classification_viz.py +608 -0
- pgsui/utils/logging_utils.py +22 -0
- pgsui/utils/misc.py +35 -480
- pgsui/utils/plotting.py +996 -829
- pgsui/utils/pretty_metrics.py +290 -0
- pgsui/utils/scorers.py +213 -666
- pg_sui-0.2.0.dist-info/RECORD +0 -75
- pg_sui-0.2.0.dist-info/top_level.txt +0 -3
- pgsui/example_data/phylip_files/test_n10.phy +0 -118
- pgsui/example_data/phylip_files/test_n100.phy +0 -118
- pgsui/example_data/phylip_files/test_n2.phy +0 -118
- pgsui/example_data/phylip_files/test_n500.phy +0 -118
- pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
- pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
- pgsui/example_data/trees/test.iqtree +0 -376
- pgsui/example_data/trees/test.qmat +0 -5
- pgsui/example_data/trees/test.rate +0 -2033
- pgsui/example_data/trees/test.tre +0 -1
- pgsui/example_data/trees/test_n10.rate +0 -19
- pgsui/example_data/trees/test_n100.rate +0 -109
- pgsui/example_data/trees/test_n500.rate +0 -509
- pgsui/example_data/trees/test_siterates.txt +0 -2024
- pgsui/example_data/trees/test_siterates_n10.txt +0 -10
- pgsui/example_data/trees/test_siterates_n100.txt +0 -100
- pgsui/example_data/trees/test_siterates_n500.txt +0 -500
- pgsui/example_data/vcf_files/test.vcf +0 -244
- pgsui/example_data/vcf_files/test.vcf.gz +0 -0
- pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
- pgsui/impute/estimators.py +0 -1268
- pgsui/impute/impute.py +0 -1463
- pgsui/impute/simple_imputers.py +0 -1431
- pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
- pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
- pgsui/impute/unsupervised/keras_classifiers.py +0 -697
- pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
- pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
- pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
- pgsui/pg_sui.py +0 -261
- pgsui/utils/sequence_tools.py +0 -407
- simulation/sim_benchmarks.py +0 -333
- simulation/sim_treeparams.py +0 -475
- test/__init__.py +0 -0
- test/pg_sui_simtest.py +0 -215
- test/pg_sui_testing.py +0 -523
- test/test.py +0 -151
- test/test_pgsui.py +0 -374
- test/test_tkc.py +0 -185
|
@@ -0,0 +1,608 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
|
6
|
+
|
|
7
|
+
import matplotlib as mpl
|
|
8
|
+
import matplotlib.pyplot as plt
|
|
9
|
+
import numpy as np
|
|
10
|
+
import pandas as pd
|
|
11
|
+
import plotly.graph_objects as go
|
|
12
|
+
import seaborn as sns
|
|
13
|
+
from matplotlib.colors import LinearSegmentedColormap
|
|
14
|
+
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from matplotlib.figure import Figure
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class ClassificationReportVisualizer:
    """Pretty plotting for scikit-learn classification reports (output_dict=True).

    Adds neon cyberpunk aesthetics, a per-class support overlay, and optional bootstrap confidence intervals.

    Attributes:
        retro_palette: Hex colors for neon vibe.
        background_hex: Matplotlib/Plotly dark background.
        grid_hex: Gridline color for dark theme.
        reset_kwargs: Optional rcParams overrides applied by `_reset_mpl_style()` after the defaults are restored.
    """

    retro_palette: List[str] = field(
        default_factory=lambda: [
            "#ff00ff",
            "#9400ff",
            "#00f0ff",
            "#00ff9f",
            "#ff6ec7",
            "#7d00ff",
            "#39ff14",
            "#00bcd4",
        ]
    )
    background_hex: str = "#0a0a15"
    grid_hex: str = "#2a2a3a"
    reset_kwargs: Dict[str, bool | str] | None = None

    # Un-annotated class constants (therefore not dataclass fields).
    _METRIC_COLS = ("precision", "recall", "f1-score")
    _NUM_COLS = ("precision", "recall", "f1-score", "support")

    # ---------- Shared helpers ----------
    @staticmethod
    def _avg_mask(index: pd.Index) -> np.ndarray:
        """Return a boolean mask that is True for average rows.

        A row counts as an average row when its label contains "avg" (case-insensitive), e.g. "macro avg" or "weighted avg".

        Args:
            index (pd.Index): Row labels; they are cast to str before matching.

        Returns:
            np.ndarray: Boolean array with one entry per label.
        """
        labels = pd.Index(index).astype(str)
        return np.asarray(labels.str.contains("avg", case=False, regex=True), dtype=bool)

    # ---------- Core data prep ----------
    def to_dataframe(self, report: Dict[str, Dict[str, float]]) -> pd.DataFrame:
        """Convert sklearn classification_report output_dict to a tidy DataFrame.

        Class rows are sorted by label and placed before the average rows.

        Args:
            report (Dict[str, Dict[str, float]]): Dictionary from `classification_report(..., output_dict=True)`.

        Returns:
            pd.DataFrame: Index are classes/avg rows (str). Columns include ["precision", "recall", "f1-score", "support"]. The "accuracy" scalar (if present and numeric) is stored in df.attrs["accuracy"], and the row is removed.
        """
        df = pd.DataFrame(report).T
        for col in self._NUM_COLS:
            if col not in df.columns:
                df[col] = np.nan

        accuracy: Optional[float] = None
        if "accuracy" in df.index:
            # The scalar accuracy is broadcast across the whole row; prefer the
            # "precision" cell and fall back to the first numeric cell.
            acc_row = pd.to_numeric(df.loc["accuracy"], errors="coerce").dropna()
            if not acc_row.empty:
                accuracy = float(acc_row.get("precision", acc_row.iloc[0]))
            df = df.drop(index="accuracy")

        df.index = df.index.astype(str)

        is_avg = self._avg_mask(df.index)
        class_df = df.loc[~is_avg].copy()
        avg_df = df.loc[is_avg].copy()

        num_cols = list(self._NUM_COLS)
        class_df[num_cols] = class_df[num_cols].apply(pd.to_numeric, errors="coerce")
        avg_df[num_cols] = avg_df[num_cols].apply(pd.to_numeric, errors="coerce")

        tidy = pd.concat([class_df.sort_index(), avg_df], axis=0)

        # Set attrs on the returned frame itself rather than relying on
        # attrs being propagated through drop/loc/copy/concat.
        if accuracy is not None:
            tidy.attrs["accuracy"] = accuracy
        return tidy

    def compute_ci(
        self,
        boot_reports: List[Dict[str, Dict[str, float]]],
        ci: float = 0.95,
        metrics: Tuple[str, ...] = ("precision", "recall", "f1-score"),
    ) -> pd.DataFrame:
        """Compute per-class bootstrap CIs from multiple report dicts.

        Bounds are NaN-aware percentiles across the bootstrap repeats; labels missing from a repeat contribute NaN for that repeat.

        Args:
            boot_reports (List[Dict[str, Dict[str, float]]]): List of `output_dict=True` results over bootstrap repeats.
            ci (float): Confidence level as a fraction in [0, 1] (e.g., 0.95 for 95%).
            metrics (Tuple[str, ...]): Metrics to compute bounds for.

        Returns:
            pd.DataFrame: Multi-index columns with (metric, ["lower","upper","mean"]). Index contains any class/avg labels present in the bootstrap reports.

        Raises:
            ValueError: If `boot_reports` is empty or `ci` is outside [0, 1].
        """
        if not boot_reports:
            raise ValueError("boot_reports is empty; provide at least one dict.")
        if not 0.0 <= ci <= 1.0:
            raise ValueError(
                f"ci must be a fraction in [0, 1] (e.g. 0.95 for 95%); got {ci!r}."
            )

        metric_list = list(metrics)
        frames = [self.to_dataframe(rep) for rep in boot_reports]

        # Align every repeat on the union of labels, then stack to (B, C, M).
        common_index = sorted(set().union(*(f.index for f in frames)))
        arr = np.stack(
            [
                f.reindex(common_index)[metric_list].to_numpy(dtype=float)
                for f in frames
            ],
            axis=0,
        )

        alpha = (1 - ci) / 2
        lower = np.nanpercentile(arr, 100 * alpha, axis=0)  # (C, M)
        upper = np.nanpercentile(arr, 100 * (1 - alpha), axis=0)  # (C, M)
        mean = np.nanmean(arr, axis=0)  # (C, M)

        columns = pd.MultiIndex.from_tuples(
            [(m, stat) for m in metric_list for stat in ("lower", "upper", "mean")]
        )
        # (C, M, 3) -> (C, 3M): per metric the order is lower, upper, mean.
        values = np.stack([lower, upper, mean], axis=2).reshape(len(common_index), -1)
        return pd.DataFrame(values, index=common_index, columns=columns)

    # ---------- Palettes & styles ----------
    def _retro_cmap(self, n: int = 256) -> LinearSegmentedColormap:
        """Create a neon gradient colormap.

        Args:
            n (int): Number of discrete colors in the colormap. Defaults to 256.

        Returns:
            LinearSegmentedColormap: The generated colormap, named "retro_neon".
        """
        anchors = ["#241937", "#7d00ff", "#ff00ff", "#ff6ec7", "#00f0ff", "#00ff9f"]
        return LinearSegmentedColormap.from_list("retro_neon", anchors, N=n)

    def _set_mpl_style(self) -> None:
        """Apply a dark neon Matplotlib theme.

        This method modifies global rcParams; call before plotting. Use `_reset_mpl_style()` to restore the defaults.
        """
        plt.rcParams.update(
            {
                "figure.facecolor": self.background_hex,
                "axes.facecolor": self.background_hex,
                "axes.edgecolor": self.grid_hex,
                "axes.labelcolor": "#e8e8ff",
                "xtick.color": "#d7d7ff",
                "ytick.color": "#d7d7ff",
                "grid.color": self.grid_hex,
                "text.color": "#f7f7ff",
                "axes.grid": True,
                "grid.linestyle": "--",
                "grid.linewidth": 0.5,
                "legend.facecolor": "#121222",
                "legend.edgecolor": self.grid_hex,
            }
        )

    def _reset_mpl_style(self) -> None:
        """Reset Matplotlib rcParams to default, then apply `reset_kwargs` if set."""
        # plt.rcParams and mpl.rcParams refer to the same global object, so a
        # single update is sufficient.
        plt.rcParams.update(mpl.rcParamsDefault)
        if self.reset_kwargs is not None:
            plt.rcParams.update(self.reset_kwargs)

    def plot_heatmap(
        self,
        df: pd.DataFrame,
        title: str = "Classification Report — Per-Class Metrics",
        classes_only: bool = True,
        figsize: Tuple[int, int] = (12, 6),
        annot_decimals: int = 3,
        vmax: float = 1.0,
        vmin: float = 0.0,
        show_support_strip: bool = False,
    ) -> "Figure":
        """Plot a per-class heatmap with an optional right-hand support strip.

        Args:
            df (pd.DataFrame): DataFrame from `to_dataframe()`.
            title (str): Plot title.
            classes_only (bool): If True, exclude avg rows.
            figsize (Tuple[int, int]): Matplotlib figure size.
            annot_decimals (int): Decimal places for annotations.
            vmax (float): Max heatmap value.
            vmin (float): Min heatmap value.
            show_support_strip (bool): If True, draw normalized support strip at right.

        Returns:
            matplotlib.figure.Figure: The created figure.
        """
        self._set_mpl_style()

        work = df.copy()
        if classes_only:
            work = work[~self._avg_mask(work.index)]

        heat = work[list(self._METRIC_COLS)].astype(float)

        fig, ax = plt.subplots(figsize=figsize)
        sns.heatmap(
            heat,
            annot=True,
            fmt=f".{annot_decimals}f",
            cmap=self._retro_cmap(),
            vmin=vmin,
            vmax=vmax,
            linewidths=0.5,
            linecolor=self.grid_hex,
            cbar_kws={"label": "Score"},
            ax=ax,
        )
        ax.set_title(title, pad=12, fontweight="bold")
        ax.set_xlabel("Metric")
        ax.set_ylabel("Class")

        # Optional support strip (normalized 0..1) as an inset axis. Skipped
        # when there are no rows, because min() of an empty array raises.
        if show_support_strip and "support" in work.columns and len(work) > 0:
            supports = work["support"].astype(float).fillna(0.0).to_numpy()
            # The epsilon keeps the division finite when all supports are equal.
            sup_norm = (supports - supports.min()) / (np.ptp(supports) + 1e-9)
            ax_strip = inset_axes(
                ax,
                width="2%",
                height="100%",
                loc="right",
                bbox_to_anchor=(0.03, 0.0, 1, 1),
                bbox_transform=ax.transAxes,
                borderpad=0,
            )

            sns.heatmap(
                sup_norm[:, None],  # (n_classes, 1)
                cmap=self._retro_cmap(),
                cbar=True,
                cbar_kws={"label": "Support (normalized)"},
                xticklabels=False,
                yticklabels=False,
                vmin=0.0,
                vmax=1.0,
                linewidths=0.0,
                ax=ax_strip,
            )

            # Align strip y-limits to main heatmap
            ax_strip.set_ylim(ax.get_ylim())

        return fig

    def plot_grouped_bars(
        self,
        df: pd.DataFrame,
        title: str = "Per-Class Metrics (Grouped Bars)",
        classes_only: bool = True,
        figsize: Tuple[int, int] = (14, 7),
        bar_alpha: float = 0.9,
        ci_df: Optional[pd.DataFrame] = None,
    ) -> "Figure":
        """Plot grouped bars for P/R/F1 with support markers and optional CI.

        Classes "A", "C", "G", "T" are placed first (in that order) when they are present in `df`; the remaining classes follow in the order they appear.

        Args:
            df (pd.DataFrame): DataFrame from `to_dataframe()`.
            title (str): Plot title.
            classes_only (bool): If True, exclude avg rows.
            figsize (Tuple[int, int]): Figure size.
            bar_alpha (float): Bar alpha.
            ci_df (Optional[pd.DataFrame]): Output of `compute_ci()`; adds error bars if provided.

        Returns:
            matplotlib.figure.Figure: The created figure.
        """
        self._set_mpl_style()
        work = df.copy()
        if classes_only:
            work = work[~self._avg_mask(work.index)]

        metrics = list(self._METRIC_COLS)

        lng = (
            work[metrics]
            .reset_index(names="class")
            .melt(id_vars="class", var_name="metric", value_name="score")
            .dropna(subset=["score"])
        )

        # Homozygotes lead the axis, but only those actually in the report;
        # unconditionally adding them would create empty phantom categories.
        homozygote_order = ["A", "C", "G", "T"]
        present = set(work.index)
        classes = [c for c in homozygote_order if c in present] + [
            c for c in lng["class"].unique().tolist() if c not in homozygote_order
        ]

        palette = self.retro_palette[: len(metrics)]

        x = np.arange(len(classes))
        width = 0.25
        offsets = np.linspace(-width, width, num=len(metrics))

        fig, ax = plt.subplots(figsize=figsize)

        # Secondary axis for support markers
        ax2 = ax.twinx()
        supports = (
            work.reindex(classes)["support"].astype(float).fillna(0.0).to_numpy()
        )

        ax2.plot(
            x,
            supports,
            linestyle="None",
            marker="o",
            markersize=6,
            markerfacecolor="#39ff14",
            markeredgecolor="#ffffff",
            alpha=0.9,
            label="Support",
        )

        # Plot bars with optional CI error bars
        for i, m in enumerate(metrics):
            vals = (
                lng.loc[lng["metric"].eq(m)]
                .set_index("class")
                .reindex(classes)["score"]
                .to_numpy(dtype=float)
            )

            yerr = None
            if (
                ci_df is not None
                and (m, "lower") in ci_df.columns
                and (m, "upper") in ci_df.columns
            ):
                ci_reindexed = ci_df.reindex(classes)
                lows = ci_reindexed[(m, "lower")].to_numpy(dtype=float)
                ups = ci_reindexed[(m, "upper")].to_numpy(dtype=float)

                # Asymmetric error bars around the point estimate. The point
                # estimate can fall outside the bootstrap interval, which
                # would give a negative length (rejected by Matplotlib), so
                # lengths are clipped at zero. NaN entries stay NaN.
                yerr = np.clip(np.vstack([vals - lows, ups - vals]), 0.0, None)

            ax.bar(
                x + offsets[i],
                vals,
                width=width * 0.95,
                label=m.title(),
                color=palette[i % len(palette)],
                alpha=bar_alpha,
                edgecolor="#ffffff",
                linewidth=0.4,
                yerr=yerr,
                error_kw=dict(ecolor="#ffffff", elinewidth=0.9, capsize=3),
            )

        ax.set_xticks(x)
        ax.set_xticklabels(classes, rotation=45, ha="right")
        ax.set_ylim(0, 1.05)
        ax.set_ylabel("Score")
        ax.set_title(title, pad=12, fontweight="bold")
        ax.legend(ncols=3, frameon=True, loc="upper left")

        # Configure secondary (support) axis
        max_support = float(supports.max()) if supports.size else 0.0
        ax2.set_ylabel("Support")
        ax2.grid(False)
        ax2.set_ylim(0, max(1.0, max_support * 1.15))
        ax2.legend(loc="upper right", frameon=True)

        ax.grid(axis="y", linestyle="--", alpha=0.6)
        fig.tight_layout()
        return fig

    def plot_radar(
        self,
        df: pd.DataFrame,
        title: str = "Macro/Weighted Averages & Top-K Class Radar",
        top_k: int = 5,
        include_micro: bool = True,
        include_macro: bool = True,
        include_weighted: bool = True,
        ci_df: Optional[pd.DataFrame] = None,
    ) -> go.Figure:
        """Interactive radar chart of averages + top-k classes; optional CI bands.

        Args:
            df (pd.DataFrame): DataFrame from `to_dataframe()`.
            title (str): Figure title.
            top_k (int): Include up to top_k classes by support (descending). None or a non-positive value keeps all classes.
            include_micro (bool): Include micro avg trace if available.
            include_macro (bool): Include macro avg trace if available.
            include_weighted (bool): Include weighted avg trace if available.
            ci_df (Optional[pd.DataFrame]): Output of `compute_ci()`; draws semi-transparent CI bands.

        Returns:
            plotly.graph_objects.Figure: The interactive radar chart.
        """
        work = df.copy()

        is_avg = self._avg_mask(work.index)
        classes = work.loc[~is_avg].copy().sort_values("support", ascending=False)
        if top_k is not None and top_k > 0:
            classes = classes.head(top_k)

        wanted_avgs = (
            (include_macro, "macro avg"),
            (include_weighted, "weighted avg"),
            (include_micro, "micro avg"),
        )
        avgs = [
            (label, work.loc[label])
            for enabled, label in wanted_avgs
            if enabled and label in work.index
        ]

        metrics = list(self._METRIC_COLS)
        # Repeat the first axis so every polygon closes on itself.
        theta = metrics + [metrics[0]]

        fig = go.Figure()

        def _add_ci_band(name: str, color: str) -> None:
            """Add a filled band between the lower and upper bounds for `name`."""
            if ci_df is None or name not in ci_df.index:
                return
            needed = [(m, b) for m in metrics for b in ("lower", "upper")]
            if not all(col in ci_df.columns for col in needed):
                return

            lows = [float(ci_df.loc[name, (m, "lower")]) for m in metrics]
            ups = [float(ci_df.loc[name, (m, "upper")]) for m in metrics]
            if np.isnan(lows + ups).any():
                return  # an incomplete band cannot be drawn as a closed shape
            lows.append(lows[0])
            ups.append(ups[0])

            # One closed path: around the upper bound, then back around the
            # lower bound in reverse. fill="toself" then shades the region
            # between the two bounds. Transparency comes from the trace
            # opacity, so fillcolor stays a plain color string.
            fig.add_trace(
                go.Scatterpolar(
                    r=ups + lows[::-1],
                    theta=theta + theta[::-1],
                    mode="lines",
                    line=dict(width=0),
                    fill="toself",
                    fillcolor=color,
                    hoverinfo="skip",
                    name=f"{name} CI",
                    showlegend=False,
                    opacity=0.20,
                )
            )

        def _add_line(label, row, color, line_width, marker_size, opacity) -> None:
            """Add one closed metric polygon for a class or average row."""
            r = [float(row.get(m, np.nan)) for m in metrics]
            r.append(r[0])
            fig.add_trace(
                go.Scatterpolar(
                    r=r,
                    theta=theta,
                    name=label,
                    mode="lines+markers",
                    line=dict(width=line_width, color=color),
                    marker=dict(size=marker_size, color=color),
                    opacity=opacity,
                )
            )

        n_colors = len(self.retro_palette)

        # Add average traces with CI first
        for i, (name, row) in enumerate(avgs):
            color = self.retro_palette[i % n_colors]
            _add_ci_band(name, color)
            _add_line(name.title(), row, color, 3, 7, 0.95)

        # Add class traces (top-k) with optional CI; colors continue after
        # those used by the averages.
        base_idx = len(avgs)
        for i, (cls, row) in enumerate(classes[metrics].iterrows()):
            color = self.retro_palette[(base_idx + i) % n_colors]
            _add_ci_band(str(cls), color)
            _add_line(str(cls), row, color, 2, 6, 0.85)

        fig.update_layout(
            title=title,
            template="plotly_dark",
            paper_bgcolor=self.background_hex,
            plot_bgcolor=self.background_hex,
            polar=dict(
                bgcolor="#111122",
                radialaxis=dict(range=[0, 1.05], showline=True, gridcolor="#33334d"),
                angularaxis=dict(gridcolor="#33334d"),
            ),
            legend=dict(
                bgcolor="#121222",
                bordercolor="#2a2a3a",
                borderwidth=1,
                orientation="h",
                yanchor="bottom",
                y=-0.15,
                x=0.5,
                xanchor="center",
            ),
        )

        return fig

    def plot_all(
        self,
        report: Dict[str, Dict[str, float]],
        title_prefix: str = "Classification Report",
        heatmap_classes_only: bool = True,
        radar_top_k: int = 10,
        boot_reports: Optional[List[Dict[str, Dict[str, float]]]] = None,
        ci: float = 0.95,
        show: bool = True,
    ) -> Dict[str, Union["Figure", go.Figure]]:
        """Generate all visuals, with optional CI from bootstrap reports.

        Args:
            report (Dict[str, Dict[str, float]]): The `output_dict=True` classification report (single run).
            title_prefix (str): Common prefix for titles.
            heatmap_classes_only (bool): Exclude averages in heatmap if True.
            radar_top_k (int): Number of top classes (by support) on radar.
            boot_reports (Optional[List[Dict[str, Dict[str, float]]]]): Optional list of bootstrap report dicts for CI.
            ci (float): Confidence level (e.g., 0.95).
            show (bool): If True, call plt.show() for Matplotlib figures. The Plotly radar figure is only returned, not displayed.

        Returns:
            Dict[str, Union[matplotlib.figure.Figure, plotly.graph_objects.Figure]]: Keys: {"heatmap_fig", "bars_fig", "radar_fig"}.
        """
        df = self.to_dataframe(report)
        acc = df.attrs.get("accuracy", None)
        acc_str = f" (Accuracy: {acc:.3f})" if isinstance(acc, float) else ""

        ci_df = self.compute_ci(boot_reports, ci=ci) if boot_reports else None

        heatmap_fig = self.plot_heatmap(
            df,
            title=f"{title_prefix} — Heatmap{acc_str}",
            classes_only=heatmap_classes_only,
            show_support_strip=False,
        )
        bars_fig = self.plot_grouped_bars(
            df,
            title=f"{title_prefix} — Grouped Bars{acc_str}",
            classes_only=True,
            ci_df=ci_df,
        )
        radar_fig = self.plot_radar(
            df,
            title=f"{title_prefix} — Averages & Top-{radar_top_k} Classes",
            top_k=radar_top_k,
            ci_df=ci_df,
        )

        if show:
            plt.show()

        return {
            "heatmap_fig": heatmap_fig,
            "bars_fig": bars_fig,
            "radar_fig": radar_fig,
        }
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def configure_logger(
    logger: logging.Logger,
    *,
    verbose: bool = False,
    debug: bool = False,
    quiet_level: int = logging.ERROR,
) -> logging.Logger:
    """Set one level on a logger and on every handler attached to it.

    Args:
        logger: Logger to configure (modified in place).
        verbose: Select ``logging.INFO`` when True and ``debug`` is False.
        debug: Select ``logging.DEBUG`` when True; takes precedence over ``verbose``.
        quiet_level: Level used when neither flag is set.

    Returns:
        The same ``logger`` object, to allow call chaining.
    """
    # debug wins over verbose; quiet_level is the fallback.
    chosen = logging.DEBUG if debug else (logging.INFO if verbose else quiet_level)

    logger.setLevel(chosen)
    # getattr tolerates logger-like objects that expose no ``handlers`` attribute.
    attached = getattr(logger, "handlers", ())
    for hdlr in attached:
        hdlr.setLevel(chosen)
    return logger
|