gengeneeval 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geneval/__init__.py +129 -0
- geneval/cli.py +333 -0
- geneval/config.py +141 -0
- geneval/core.py +41 -0
- geneval/data/__init__.py +23 -0
- geneval/data/gene_expression_datamodule.py +211 -0
- geneval/data/loader.py +437 -0
- geneval/evaluator.py +359 -0
- geneval/evaluators/__init__.py +4 -0
- geneval/evaluators/base_evaluator.py +178 -0
- geneval/evaluators/gene_expression_evaluator.py +218 -0
- geneval/metrics/__init__.py +65 -0
- geneval/metrics/base_metric.py +229 -0
- geneval/metrics/correlation.py +232 -0
- geneval/metrics/distances.py +516 -0
- geneval/metrics/metrics.py +134 -0
- geneval/models/__init__.py +1 -0
- geneval/models/base_model.py +53 -0
- geneval/results.py +334 -0
- geneval/testing.py +393 -0
- geneval/utils/__init__.py +1 -0
- geneval/utils/io.py +27 -0
- geneval/utils/preprocessing.py +82 -0
- geneval/visualization/__init__.py +38 -0
- geneval/visualization/plots.py +499 -0
- geneval/visualization/visualizer.py +1096 -0
- gengeneeval-0.1.0.dist-info/METADATA +172 -0
- gengeneeval-0.1.0.dist-info/RECORD +31 -0
- gengeneeval-0.1.0.dist-info/WHEEL +4 -0
- gengeneeval-0.1.0.dist-info/entry_points.txt +3 -0
- gengeneeval-0.1.0.dist-info/licenses/LICENSE +9 -0
|
@@ -0,0 +1,499 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union
|
|
4
|
+
import warnings
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import pandas as pd
|
|
8
|
+
import matplotlib.pyplot as plt
|
|
9
|
+
import seaborn as sns
|
|
10
|
+
from matplotlib.figure import Figure
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# ==================== STANDALONE PLOTTING FUNCTIONS ====================
|
|
14
|
+
|
|
15
|
+
def create_boxplot(
    data: Dict[str, np.ndarray],
    title: str = "Boxplot",
    xlabel: str = "Group",
    ylabel: str = "Value",
    figsize: Tuple[int, int] = (10, 6),
    show_points: bool = True,
    rotation: int = 45,
) -> Figure:
    """
    Create a boxplot from dictionary of arrays.

    Parameters
    ----------
    data : Dict[str, np.ndarray]
        Dictionary mapping group names to arrays of values. Arrays of any
        shape are flattened; groups appear in the dictionary's order.
    title : str
        Plot title.
    xlabel : str
        X-axis label.
    ylabel : str
        Y-axis label.
    figsize : Tuple[int, int]
        Figure size.
    show_points : bool
        Whether to overlay individual points (semi-transparent strip plot).
    rotation : int
        X-tick label rotation angle in degrees; ``0`` leaves labels unrotated.

    Returns
    -------
    Figure
        Matplotlib figure object.
    """
    # Build the long-format frame with vectorised numpy/pandas operations
    # instead of one dict per value, which is slow and memory-hungry for
    # large expression arrays.
    arrays = [np.asarray(values).ravel() for values in data.values()]
    df = pd.DataFrame(
        {
            "group": pd.Series(list(data.keys()), dtype=object)
            .repeat([arr.size for arr in arrays])
            .to_numpy(),
            "value": np.concatenate(arrays) if arrays else np.empty(0),
        }
    )

    # The seaborn style only applies to axes created while it is active, so
    # the figure must be created inside the context.
    with sns.axes_style("whitegrid"):
        fig, ax = plt.subplots(figsize=figsize)
        sns.boxplot(data=df, x="group", y="value", ax=ax, palette="Set2")

        if show_points:
            sns.stripplot(
                data=df, x="group", y="value", ax=ax,
                color="black", alpha=0.3, size=3, jitter=True,
            )

    ax.set_title(title, fontsize=12, fontweight="bold")
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    if rotation:
        # Target this figure's axes explicitly rather than pyplot's implicit
        # "current axes" state.
        plt.setp(ax.get_xticklabels(), rotation=rotation, ha="right")

    fig.tight_layout()
    return fig
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def create_violin_plot(
    data: Dict[str, np.ndarray],
    title: str = "Violin Plot",
    xlabel: str = "Group",
    ylabel: str = "Value",
    figsize: Tuple[int, int] = (10, 6),
    show_box: bool = True,
    rotation: int = 45,
) -> Figure:
    """
    Create a violin plot from dictionary of arrays.

    Parameters
    ----------
    data : Dict[str, np.ndarray]
        Dictionary mapping group names to arrays of values. Arrays of any
        shape are flattened; groups appear in the dictionary's order.
    title : str
        Plot title.
    xlabel : str
        X-axis label.
    ylabel : str
        Y-axis label.
    figsize : Tuple[int, int]
        Figure size.
    show_box : bool
        Whether to overlay a narrow boxplot (no fliers) inside each violin.
    rotation : int
        X-tick label rotation angle in degrees; ``0`` leaves labels unrotated.

    Returns
    -------
    Figure
        Matplotlib figure object.
    """
    # Vectorised long-format construction (one dict per value does not
    # scale to expression-sized arrays).
    arrays = [np.asarray(values).ravel() for values in data.values()]
    df = pd.DataFrame(
        {
            "group": pd.Series(list(data.keys()), dtype=object)
            .repeat([arr.size for arr in arrays])
            .to_numpy(),
            "value": np.concatenate(arrays) if arrays else np.empty(0),
        }
    )

    # The style must be active when the axes are created to take effect.
    with sns.axes_style("whitegrid"):
        fig, ax = plt.subplots(figsize=figsize)
        # cut=0 keeps the density estimate within the observed data range.
        sns.violinplot(
            data=df, x="group", y="value", ax=ax,
            palette="Set2", inner=None, cut=0,
        )

        if show_box:
            sns.boxplot(
                data=df, x="group", y="value", ax=ax,
                width=0.15, showcaps=True,
                boxprops={"facecolor": "white", "edgecolor": "black"},
                whiskerprops={"color": "black"},
                medianprops={"color": "red"},
                showfliers=False,
            )

    ax.set_title(title, fontsize=12, fontweight="bold")
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    if rotation:
        # Rotate this axes' labels rather than pyplot's "current axes".
        plt.setp(ax.get_xticklabels(), rotation=rotation, ha="right")

    fig.tight_layout()
    return fig
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def create_heatmap(
    data: np.ndarray,
    row_labels: Optional[Sequence[str]] = None,
    col_labels: Optional[Sequence[str]] = None,
    title: str = "Heatmap",
    cmap: str = "RdYlBu_r",
    figsize: Tuple[int, int] = (10, 8),
    annot: bool = False,
    center: Optional[float] = None,
) -> Figure:
    """
    Create a heatmap from a 2D array.

    Parameters
    ----------
    data : np.ndarray
        2D array of values.
    row_labels : Sequence[str], optional
        Row labels (list, tuple, numpy array or pandas Index). When ``None``
        or empty, no row tick labels are drawn.
    col_labels : Sequence[str], optional
        Column labels, same conventions as ``row_labels``.
    title : str
        Plot title.
    cmap : str
        Colormap name.
    figsize : Tuple[int, int]
        Figure size.
    annot : bool
        Whether to annotate cells with values.
    center : float, optional
        Center value for diverging colormaps.

    Returns
    -------
    Figure
        Matplotlib figure object.
    """

    def _ticklabels(labels: Optional[Sequence[str]]):
        # Explicit None/length checks: truth-testing a numpy array or pandas
        # Index (typical for gene names) raises "truth value is ambiguous".
        if labels is None or len(labels) == 0:
            return False
        return list(labels)

    # The style must be active when the axes are created to take effect.
    with sns.axes_style("white"):
        fig, ax = plt.subplots(figsize=figsize)
        sns.heatmap(
            data,
            ax=ax,
            xticklabels=_ticklabels(col_labels),
            yticklabels=_ticklabels(row_labels),
            cmap=cmap,
            annot=annot,
            center=center,
            cbar_kws={"shrink": 0.8},
        )

    ax.set_title(title, fontsize=12, fontweight="bold")
    fig.tight_layout()
    return fig
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def create_scatter(
    x: np.ndarray,
    y: np.ndarray,
    labels: Optional[np.ndarray] = None,
    title: str = "Scatter Plot",
    xlabel: str = "X",
    ylabel: str = "Y",
    figsize: Tuple[int, int] = (10, 8),
    alpha: float = 0.6,
    size: int = 20,
    add_diagonal: bool = False,
) -> Figure:
    """
    Create a scatter plot.

    Parameters
    ----------
    x : np.ndarray
        X values (any array-like; converted with ``np.asarray``).
    y : np.ndarray
        Y values, same length as ``x``.
    labels : np.ndarray, optional
        Per-point labels; each unique label gets its own color (sampled from
        the ``Set1`` colormap) and a legend entry.
    title : str
        Plot title.
    xlabel : str
        X-axis label.
    ylabel : str
        Y-axis label.
    figsize : Tuple[int, int]
        Figure size.
    alpha : float
        Point transparency.
    size : int
        Point size.
    add_diagonal : bool
        Whether to add a dashed y=x diagonal line and square the axis limits
        to span both current x and y ranges.

    Returns
    -------
    Figure
        Matplotlib figure object.
    """
    # Coerce inputs so positional boolean masking works for lists and
    # pandas objects; comparing a plain list to a label yields a single bool
    # instead of an element-wise mask.
    x = np.asarray(x)
    y = np.asarray(y)

    fig, ax = plt.subplots(figsize=figsize)

    if labels is not None:
        labels = np.asarray(labels)
        unique_labels = np.unique(labels)
        colors = plt.cm.Set1(np.linspace(0, 1, len(unique_labels)))

        for color, label in zip(colors, unique_labels):
            mask = labels == label
            ax.scatter(x[mask], y[mask], c=[color], s=size, alpha=alpha, label=str(label))
    else:
        ax.scatter(x, y, s=size, alpha=alpha)

    if add_diagonal:
        lims = [
            min(ax.get_xlim()[0], ax.get_ylim()[0]),
            max(ax.get_xlim()[1], ax.get_ylim()[1]),
        ]
        ax.plot(lims, lims, 'k--', alpha=0.5, label='y=x')
        ax.set_xlim(lims)
        ax.set_ylim(lims)

    if labels is not None:
        # Drawn after the diagonal so its "y=x" entry is included.
        ax.legend(frameon=False, loc="best")

    ax.set_title(title, fontsize=12, fontweight="bold")
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    fig.tight_layout()
    return fig
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def create_radar_chart(
    values: Dict[str, float],
    title: str = "Radar Chart",
    figsize: Tuple[int, int] = (8, 8),
    fill: bool = True,
    alpha: float = 0.25,
) -> Figure:
    """
    Draw a radar (spider) chart with one spoke per metric.

    Parameters
    ----------
    values : Dict[str, float]
        Metric name -> value, ideally normalized to the 0-1 range. Spokes
        follow the dictionary's insertion order, evenly spaced around the
        circle.
    title : str
        Plot title.
    figsize : Tuple[int, int]
        Figure size.
    fill : bool
        If True, shade the polygon enclosed by the metric values.
    alpha : float
        Opacity of the shaded area.

    Returns
    -------
    Figure
        Matplotlib figure holding a single polar axes.
    """
    spoke_names = list(values)
    radii = [values[name] for name in spoke_names]
    theta = np.linspace(0, 2 * np.pi, len(spoke_names), endpoint=False).tolist()

    # Repeat the first vertex so the outline closes back on itself.
    closed_theta = [*theta, *theta[:1]]
    closed_radii = [*radii, *radii[:1]]

    fig, ax = plt.subplots(figsize=figsize, subplot_kw={"polar": True})
    ax.plot(closed_theta, closed_radii, "o-", linewidth=2, label="Metrics")
    if fill:
        ax.fill(closed_theta, closed_radii, alpha=alpha)

    # One tick per spoke (the closing duplicate is excluded).
    ax.set_xticks(theta)
    ax.set_xticklabels(spoke_names)
    ax.set_title(title, fontsize=12, fontweight="bold", y=1.08)

    fig.tight_layout()
    return fig
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
# ==================== PLOTTER CLASS ====================
|
|
331
|
+
|
|
332
|
+
class EvaluationPlotter:
    """
    Plotting helper for evaluation outputs.

    Produces meaningful, compact figures that summarize fit quality. Every
    figure is created inside ``sns.axes_style(self.style)`` so the configured
    seaborn style applies to its axes.

    Parameters
    ----------
    style : str
        Seaborn axes style name (e.g. ``"whitegrid"``).
    """

    def __init__(self, style: str = "whitegrid"):
        self.style = style

    # ------------------------------------------------------------ helpers

    @staticmethod
    def _deg_set(deg: Optional[object]) -> Optional[set]:
        """
        Normalise a DEG-like object into a set of gene-name strings.

        Parameters
        ----------
        deg : object, optional
            ``None``; a dict with a ``"names"`` entry; a table-like object
            (anything with ``columns``) holding a ``"names"`` column; a single
            gene name as ``str``; or any other iterable of names.

        Returns
        -------
        set of str or None
            The names converted with ``str``, or ``None`` when ``deg`` is
            ``None`` or a dict without a ``"names"`` entry.
        """
        if deg is None:
            return None
        if isinstance(deg, dict):
            names = deg.get("names")
        elif hasattr(deg, "columns") and "names" in getattr(deg, "columns", []):
            names = deg["names"]
        else:
            names = deg
        if names is None:
            return None
        # A bare string is one name, not an iterable of characters.
        if isinstance(names, str):
            return {names}
        if hasattr(names, "tolist"):
            names = names.tolist()
        return {str(x) for x in names}

    @staticmethod
    def _grid_shape(n: int, ncols: int) -> Tuple[int, int]:
        """Return ``(nrows, ncols)`` for ``n`` panels; always at least 1x1."""
        # Clamp so empty input never asks matplotlib for a 0-column grid.
        ncols = max(1, min(ncols, n))
        nrows = max(1, -(-n // ncols))  # integer ceil(n / ncols)
        return nrows, ncols

    @staticmethod
    def _shared_limits(
        arrays: Sequence[np.ndarray],
        percentiles: Tuple[float, float] = (0.5, 99.5),
        pad_frac: float = 0.05,
    ) -> Optional[Tuple[float, float]]:
        """
        Common axis limits spanning the central ``percentiles`` of all values.

        The range is padded by ``pad_frac`` of its width on each side.
        Non-finite values are ignored. Returns ``None`` when no finite value
        exists, so callers can skip limits instead of passing NaN to
        matplotlib (which raises).
        """
        if len(arrays) == 0:
            return None
        values = np.concatenate([np.ravel(np.asarray(a, dtype=float)) for a in arrays])
        values = values[np.isfinite(values)]
        if values.size == 0:
            return None
        lo, hi = np.percentile(values, percentiles)
        pad = (hi - lo) * pad_frac
        if pad == 0:
            # Constant data: give the axis a non-zero width.
            pad = abs(lo) * pad_frac or 0.5
        return float(lo - pad), float(hi + pad)

    @staticmethod
    def _residuals_frame(
        residuals: Mapping[str, np.ndarray],
        clip_percentiles: Tuple[float, float],
    ) -> pd.DataFrame:
        """
        Long-format ``condition``/``residual`` frame with per-key clipping.

        Each key's values are flattened and clipped to that key's own
        ``clip_percentiles``. Keys without any finite value are left
        unclipped, and empty keys contribute no rows; this avoids a crash in
        ``np.nanpercentile`` on empty input.
        """
        keys: List[object] = []
        chunks: List[np.ndarray] = []
        for key, vals in residuals.items():
            arr = np.asarray(vals, dtype=float).ravel()
            if np.isfinite(arr).any():
                lo, hi = np.nanpercentile(arr, clip_percentiles)
                arr = np.clip(arr, lo, hi)
            keys.append(key)
            chunks.append(arr)
        condition = (
            pd.Series(keys, dtype=object)
            .repeat([chunk.size for chunk in chunks])
            .to_numpy()
        )
        residual = np.concatenate(chunks) if chunks else np.empty(0, dtype=float)
        return pd.DataFrame({"condition": condition, "residual": residual})

    @staticmethod
    def _default_condition_order(df: pd.DataFrame) -> Optional[List[object]]:
        """
        Default x-axis order for :meth:`metrics_bar`.

        Conditions are sorted by descending ``"pearson"`` when that metric is
        present, otherwise by descending value of the first metric in ``df``.
        Conditions lacking the sort metric follow in order of appearance, so
        no bar is dropped. Returns ``None`` for an empty frame.
        """
        if df.empty:
            return None
        metric_col = df["metric"]
        sort_metric = "pearson" if (metric_col == "pearson").any() else metric_col.iloc[0]
        ranked = df[metric_col == sort_metric].sort_values(
            "value", ascending=False, kind="mergesort"
        )
        order = list(dict.fromkeys(ranked["condition"].tolist()))
        seen = set(order)
        order.extend(c for c in df["condition"].unique() if c not in seen)
        return order

    # ------------------------------------------------------------ figures

    def scatter_means_grid(
        self,
        data: Dict[str, Tuple[np.ndarray, np.ndarray, Sequence[str]]],
        stats: Optional[Mapping[str, Dict[str, float]]] = None,
        deg_map: Optional[Mapping[str, object]] = None,
        max_panels: int = 12,
        ncols: int = 3,
        figsize: Tuple[int, int] = (15, 12),
        alpha_other: float = 0.4,
        alpha_deg: float = 0.9,
    ):
        """
        Grid of scatter plots: mean(real) vs mean(generated) per condition key.

        Parameters
        ----------
        data : dict
            key -> (real_means, gen_means, gene_names). At most
            ``max_panels`` keys are drawn, in dictionary order.
        stats : mapping, optional
            key -> {'pearson': float, 'mse': float}; shown as panel text.
        deg_map : mapping, optional
            key -> DEG-like object (iterable, or dict/table with 'names');
            matching genes are highlighted in red.
        max_panels, ncols, figsize, alpha_other, alpha_deg
            Layout and styling options.

        Returns
        -------
        Figure
            All panels share axis limits (0.5-99.5 percentile range of all
            plotted means) and show a dashed y=x reference line.
        """
        keys = list(data.keys())[:max_panels]
        n = len(keys)
        nrows, ncols = self._grid_shape(n, ncols)

        with sns.axes_style(self.style):
            fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
            # Shared limits make panels directly comparable.
            lims = self._shared_limits(
                [arr for k in keys for arr in (data[k][0], data[k][1])]
            )

            for i, k in enumerate(keys):
                real, gen, genes = data[k]
                real = np.asarray(real)
                gen = np.asarray(gen)
                ax = axes[i // ncols, i % ncols]

                # Highlight DEGs if provided.
                degs = self._deg_set(deg_map.get(k)) if deg_map else None
                if degs:
                    mask = np.isin(np.asarray(genes).astype(str), list(degs))
                    ax.scatter(real[~mask], gen[~mask], s=8, alpha=alpha_other, label="Other")
                    ax.scatter(real[mask], gen[mask], s=10, alpha=alpha_deg, label="DEGs", color="#d62728")
                    # Lower right: the stats text occupies the upper-left corner.
                    ax.legend(frameon=False, fontsize=8, loc="lower right")
                else:
                    ax.scatter(real, gen, s=8, alpha=alpha_other)

                if lims is not None:
                    ax.plot(lims, lims, ls="--", c="gray", lw=1)
                    ax.set_xlim(lims)
                    ax.set_ylim(lims)
                ax.set_title(k, fontsize=10)
                ax.set_xlabel("Mean expression (real)", fontsize=9)
                ax.set_ylabel("Mean expression (generated)", fontsize=9)

                if stats and k in stats:
                    s = stats[k]
                    txt = f"r={s.get('pearson', np.nan):.2f} MSE={s.get('mse', np.nan):.2e}"
                    ax.text(0.02, 0.98, txt, transform=ax.transAxes, va="top", ha="left", fontsize=8)

            # Hide unused grid cells.
            for j in range(n, nrows * ncols):
                axes[j // ncols, j % ncols].axis("off")

            fig.tight_layout()
        return fig

    def residuals_violin(
        self,
        residuals: Dict[str, np.ndarray],
        clip_percentiles: Tuple[float, float] = (1.0, 99.0),
        figsize: Tuple[int, int] = (12, 4),
        rotate_xticks: bool = True,
    ):
        """
        Violin/box overlay of residuals (generated - real), per condition key.

        Each key's residuals are flattened and clipped to its own
        ``clip_percentiles`` before plotting; a dashed line marks zero.
        """
        df = self._residuals_frame(residuals, clip_percentiles)

        with sns.axes_style(self.style):
            fig, ax = plt.subplots(figsize=figsize)
            sns.violinplot(data=df, x="condition", y="residual", inner=None, cut=0, ax=ax, color="#9ecae1")
            sns.boxplot(data=df, x="condition", y="residual", ax=ax, width=0.15, showcaps=True,
                        boxprops={"facecolor": "white"}, showfliers=False)
            ax.axhline(0, ls="--", c="gray", lw=1)
            ax.set_title("Residual distributions per condition (generated − real)")
            ax.set_xlabel("Condition")
            ax.set_ylabel("Residual")
            if rotate_xticks:
                # Restyle the existing labels instead of replacing them.
                plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
            fig.tight_layout()
        return fig

    def metrics_bar(
        self,
        metrics_per_key: Mapping[str, Mapping[str, float]],
        order: Optional[List[str]] = None,
        figsize: Tuple[int, int] = (12, 4),
    ):
        """
        Grouped bar chart of metrics per condition key.

        Parameters
        ----------
        metrics_per_key : mapping
            key -> {metric_name: value}.
        order : list of str, optional
            Explicit x-axis order of conditions. By default conditions are
            sorted by descending ``"pearson"`` (or the first metric when
            pearson is absent).
        figsize : tuple
            Figure size.
        """
        rows = [
            (k, name, val)
            for k, m in metrics_per_key.items()
            for name, val in m.items()
        ]
        df = pd.DataFrame(rows, columns=["condition", "metric", "value"])
        n_metrics = df["metric"].nunique()

        if order is None:
            order = self._default_condition_order(df)

        with sns.axes_style(self.style):
            fig, ax = plt.subplots(figsize=figsize)
            # Pass the order to seaborn so the bars themselves are reordered;
            # relabelling ticks afterwards would mislabel bars.
            sns.barplot(data=df, x="condition", y="value", hue="metric", order=order or None, ax=ax)
            ax.set_title("Evaluation metrics per condition")
            ax.set_xlabel("Condition")
            ax.set_ylabel("Metric value")
            # `ncol` is accepted by every matplotlib version (`ncols` is 3.6+).
            ax.legend(frameon=False, ncol=max(1, min(4, n_metrics)))
            plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
            fig.tight_layout()
        return fig
|