gengeneeval 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,499 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union
4
+ import warnings
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ import matplotlib.pyplot as plt
9
+ import seaborn as sns
10
+ from matplotlib.figure import Figure
11
+
12
+
13
+ # ==================== STANDALONE PLOTTING FUNCTIONS ====================
14
+
15
def create_boxplot(
    data: Dict[str, np.ndarray],
    title: str = "Boxplot",
    xlabel: str = "Group",
    ylabel: str = "Value",
    figsize: Tuple[int, int] = (10, 6),
    show_points: bool = True,
    rotation: int = 45,
) -> Figure:
    """
    Create a boxplot from dictionary of arrays.

    Parameters
    ----------
    data : Dict[str, np.ndarray]
        Dictionary mapping group names to arrays of values. Each array is
        flattened, so multi-dimensional input is accepted.
    title : str
        Plot title.
    xlabel : str
        X-axis label.
    ylabel : str
        Y-axis label.
    figsize : Tuple[int, int]
        Figure size.
    show_points : bool
        Whether to overlay individual points.
    rotation : int
        X-tick label rotation angle. A falsy value (0) leaves the tick
        labels untouched.

    Returns
    -------
    Figure
        Matplotlib figure object.
    """
    # Build the long-form frame column-wise instead of one dict per value.
    # Both columns always exist, even when ``data`` is empty, so seaborn can
    # resolve the "group"/"value" names.
    group_column: list = []
    value_chunks: list = []
    for group, values in data.items():
        flat = np.asarray(values).ravel()
        group_column.extend([group] * flat.size)
        value_chunks.append(flat)
    value_column = np.concatenate(value_chunks) if value_chunks else np.empty(0)
    df = pd.DataFrame({"group": group_column, "value": value_column})

    # The axes must be created inside the style context: seaborn's
    # axes_style only affects axes that are instantiated while it is active.
    with sns.axes_style("whitegrid"):
        fig, ax = plt.subplots(figsize=figsize)
        sns.boxplot(data=df, x="group", y="value", ax=ax, palette="Set2")

        if show_points:
            sns.stripplot(
                data=df, x="group", y="value", ax=ax,
                color="black", alpha=0.3, size=3, jitter=True,
            )

    ax.set_title(title, fontsize=12, fontweight="bold")
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    if rotation:
        # Rotate the labels of this axes explicitly rather than relying on
        # pyplot's notion of the "current" axes.
        plt.setp(ax.get_xticklabels(), rotation=rotation, ha="right")

    fig.tight_layout()
    return fig
76
+
77
+
78
def create_violin_plot(
    data: Dict[str, np.ndarray],
    title: str = "Violin Plot",
    xlabel: str = "Group",
    ylabel: str = "Value",
    figsize: Tuple[int, int] = (10, 6),
    show_box: bool = True,
    rotation: int = 45,
) -> Figure:
    """
    Create a violin plot from dictionary of arrays.

    Parameters
    ----------
    data : Dict[str, np.ndarray]
        Dictionary mapping group names to arrays of values. Each array is
        flattened, so multi-dimensional input is accepted.
    title : str
        Plot title.
    xlabel : str
        X-axis label.
    ylabel : str
        Y-axis label.
    figsize : Tuple[int, int]
        Figure size.
    show_box : bool
        Whether to overlay a boxplot inside the violin.
    rotation : int
        X-tick label rotation angle. A falsy value (0) leaves the tick
        labels untouched.

    Returns
    -------
    Figure
        Matplotlib figure object.
    """
    # Build the long-form frame column-wise instead of one dict per value.
    # Both columns always exist, even when ``data`` is empty, so seaborn can
    # resolve the "group"/"value" names.
    group_column: list = []
    value_chunks: list = []
    for group, values in data.items():
        flat = np.asarray(values).ravel()
        group_column.extend([group] * flat.size)
        value_chunks.append(flat)
    value_column = np.concatenate(value_chunks) if value_chunks else np.empty(0)
    df = pd.DataFrame({"group": group_column, "value": value_column})

    # The axes must be created inside the style context: seaborn's
    # axes_style only affects axes that are instantiated while it is active.
    with sns.axes_style("whitegrid"):
        fig, ax = plt.subplots(figsize=figsize)
        # inner=None leaves the violin interior empty so the optional narrow
        # boxplot below can be drawn on top; cut=0 keeps the density within
        # the observed data range.
        sns.violinplot(
            data=df, x="group", y="value", ax=ax,
            palette="Set2", inner=None, cut=0,
        )

        if show_box:
            sns.boxplot(
                data=df, x="group", y="value", ax=ax,
                width=0.15, showcaps=True,
                boxprops={"facecolor": "white", "edgecolor": "black"},
                whiskerprops={"color": "black"},
                medianprops={"color": "red"},
                showfliers=False,
            )

    ax.set_title(title, fontsize=12, fontweight="bold")
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    if rotation:
        # Rotate the labels of this axes explicitly rather than relying on
        # pyplot's notion of the "current" axes.
        plt.setp(ax.get_xticklabels(), rotation=rotation, ha="right")

    fig.tight_layout()
    return fig
146
+
147
+
148
def create_heatmap(
    data: np.ndarray,
    row_labels: Optional[Sequence[str]] = None,
    col_labels: Optional[Sequence[str]] = None,
    title: str = "Heatmap",
    cmap: str = "RdYlBu_r",
    figsize: Tuple[int, int] = (10, 8),
    annot: bool = False,
    center: Optional[float] = None,
) -> Figure:
    """
    Create a heatmap from a 2D array.

    Parameters
    ----------
    data : np.ndarray
        2D array of values.
    row_labels : Sequence[str], optional
        Row labels. ``None`` or an empty sequence hides the y tick labels.
    col_labels : Sequence[str], optional
        Column labels. ``None`` or an empty sequence hides the x tick labels.
    title : str
        Plot title.
    cmap : str
        Colormap name.
    figsize : Tuple[int, int]
        Figure size.
    annot : bool
        Whether to annotate cells with values.
    center : float, optional
        Center value for diverging colormaps.

    Returns
    -------
    Figure
        Matplotlib figure object.
    """
    # Decide on label visibility via None/length checks instead of
    # truthiness: NumPy arrays and pandas Index objects raise on ``bool()``.
    show_cols = col_labels is not None and len(col_labels) > 0
    show_rows = row_labels is not None and len(row_labels) > 0

    # The axes must be created inside the style context: seaborn's
    # axes_style only affects axes that are instantiated while it is active.
    with sns.axes_style("white"):
        fig, ax = plt.subplots(figsize=figsize)
        sns.heatmap(
            data, ax=ax,
            xticklabels=list(col_labels) if show_cols else False,
            yticklabels=list(row_labels) if show_rows else False,
            cmap=cmap,
            annot=annot,
            center=center,
            cbar_kws={"shrink": 0.8},
        )

    ax.set_title(title, fontsize=12, fontweight="bold")
    fig.tight_layout()
    return fig
201
+
202
+
203
def create_scatter(
    x: np.ndarray,
    y: np.ndarray,
    labels: Optional[np.ndarray] = None,
    title: str = "Scatter Plot",
    xlabel: str = "X",
    ylabel: str = "Y",
    figsize: Tuple[int, int] = (10, 8),
    alpha: float = 0.6,
    size: int = 20,
    add_diagonal: bool = False,
) -> Figure:
    """
    Create a scatter plot.

    Parameters
    ----------
    x : np.ndarray
        X values. Any array-like is accepted.
    y : np.ndarray
        Y values. Any array-like is accepted.
    labels : np.ndarray, optional
        Labels for coloring points; one scatter series is drawn per unique
        label.
    title : str
        Plot title.
    xlabel : str
        X-axis label.
    ylabel : str
        Y-axis label.
    figsize : Tuple[int, int]
        Figure size.
    alpha : float
        Point transparency.
    size : int
        Point size.
    add_diagonal : bool
        Whether to add y=x diagonal line. When set, both axes are forced to
        the same limits.

    Returns
    -------
    Figure
        Matplotlib figure object.
    """
    # Coerce to arrays so boolean-mask indexing and the element-wise label
    # comparison below also work for plain Python sequences.
    x = np.asarray(x)
    y = np.asarray(y)

    fig, ax = plt.subplots(figsize=figsize)

    if labels is not None:
        labels = np.asarray(labels)
        unique_labels = np.unique(labels)
        colors = plt.cm.Set1(np.linspace(0, 1, len(unique_labels)))

        for color, label in zip(colors, unique_labels):
            mask = labels == label
            # A single RGBA row is wrapped in a list so matplotlib treats it
            # as one colour rather than a sequence of scalar values.
            ax.scatter(x[mask], y[mask], c=[color], s=size, alpha=alpha, label=str(label))
    else:
        ax.scatter(x, y, s=size, alpha=alpha)

    if add_diagonal:
        # Use the union of the auto-scaled x/y ranges so the diagonal spans
        # the whole square plot area.
        lims = [
            min(ax.get_xlim()[0], ax.get_ylim()[0]),
            max(ax.get_xlim()[1], ax.get_ylim()[1]),
        ]
        ax.plot(lims, lims, 'k--', alpha=0.5, label='y=x')
        ax.set_xlim(lims)
        ax.set_ylim(lims)

    # Draw the legend only after every labelled artist exists; otherwise the
    # 'y=x' entry is silently missing.
    if labels is not None or add_diagonal:
        ax.legend(frameon=False, loc="best")

    ax.set_title(title, fontsize=12, fontweight="bold")
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    fig.tight_layout()
    return fig
275
+
276
+
277
def create_radar_chart(
    values: Dict[str, float],
    title: str = "Radar Chart",
    figsize: Tuple[int, int] = (8, 8),
    fill: bool = True,
    alpha: float = 0.25,
) -> Figure:
    """
    Draw a radar (spider) chart with one spoke per metric.

    Parameters
    ----------
    values : Dict[str, float]
        Metric name -> value. Values should be normalized to 0-1; no
        rescaling is applied here.
    title : str
        Plot title.
    figsize : Tuple[int, int]
        Figure size.
    fill : bool
        Whether to shade the area enclosed by the outline.
    alpha : float
        Transparency of the shaded area.

    Returns
    -------
    Figure
        Matplotlib figure object.
    """
    metric_names = list(values)

    # Repeat the first value at the end so the outline forms a closed loop.
    ring = list(values.values())
    ring += ring[:1]

    # One evenly spaced angle per metric, closed the same way as the values.
    spokes = np.linspace(0, 2 * np.pi, len(metric_names), endpoint=False).tolist()
    closed_spokes = spokes + spokes[:1]

    fig, ax = plt.subplots(figsize=figsize, subplot_kw={"polar": True})
    ax.plot(closed_spokes, ring, 'o-', linewidth=2, label="Metrics")
    if fill:
        ax.fill(closed_spokes, ring, alpha=alpha)

    # Ticks sit on the unique angles only (the closing duplicate is dropped).
    ax.set_xticks(spokes)
    ax.set_xticklabels(metric_names)
    ax.set_title(title, fontsize=12, fontweight="bold", y=1.08)

    fig.tight_layout()
    return fig
328
+
329
+
330
+ # ==================== PLOTTER CLASS ====================
331
+
332
class EvaluationPlotter:
    """
    Plotting helper for evaluation outputs.
    Produces meaningful, compact figures that summarize fit quality.

    Parameters
    ----------
    style : str
        Seaborn axes style name; every figure is created inside
        ``sns.axes_style(style)``.
    """

    def __init__(self, style: str = "whitegrid"):
        self.style = style

    @staticmethod
    def _deg_set(deg: Optional[object]) -> Optional[set]:
        """
        Normalize a DEG-like object into a set of gene-name strings.

        Parameters
        ----------
        deg : object, optional
            One of: ``None``; a dict with a ``"names"`` entry; an object with
            a ``columns`` attribute containing ``"names"`` (e.g. a
            DataFrame); a single gene name; or any iterable of names.

        Returns
        -------
        set or None
            Names converted with ``str``; ``None`` when ``deg`` is ``None``
            or a dict without a usable ``"names"`` entry.
        """
        if deg is None:
            return None
        if isinstance(deg, dict):
            names = deg.get("names", None)
        elif hasattr(deg, "columns") and "names" in getattr(deg, "columns", []):
            names = deg["names"]
        else:
            names = deg
        if names is None:
            return None
        if isinstance(names, str):
            # A bare string is one gene name, not an iterable of characters.
            names = [names]
        elif hasattr(names, "tolist"):
            names = names.tolist()
        return {str(x) for x in names}

    def scatter_means_grid(
        self,
        data: Dict[str, Tuple[np.ndarray, np.ndarray, Sequence[str]]],
        stats: Optional[Mapping[str, Dict[str, float]]] = None,
        deg_map: Optional[Mapping[str, object]] = None,
        max_panels: int = 12,
        ncols: int = 3,
        figsize: Tuple[int, int] = (15, 12),
        alpha_other: float = 0.4,
        alpha_deg: float = 0.9,
    ):
        """
        Grid of scatter plots: mean(real) vs mean(generated) per condition key.

        Parameters
        ----------
        data : dict
            key -> (real_means, gen_means, gene_names). Arrays are flattened
            before plotting.
        stats : mapping, optional
            key -> {'pearson': float, 'mse': float}; shown as text in the
            panel's top-left corner.
        deg_map : mapping, optional
            key -> DEG-like object (iterable or dict with 'names'); matching
            genes are drawn as a separate highlighted series.
        max_panels : int
            Only the first ``max_panels`` keys of ``data`` are plotted.
        ncols : int
            Maximum number of grid columns.
        figsize : Tuple[int, int]
            Figure size.
        alpha_other : float
            Transparency of non-DEG points.
        alpha_deg : float
            Transparency of DEG points.

        Returns
        -------
        Figure
            Matplotlib figure object. With no data, a single blank axes is
            returned instead of raising.
        """
        keys = list(data.keys())[:max_panels]
        n = len(keys)
        # plt.subplots needs at least one row and one column, even when there
        # is nothing to plot.
        ncols = max(1, min(ncols, n))
        nrows = max(1, int(np.ceil(n / ncols)))

        # Normalize every panel once so masks and limits see 1-D arrays.
        panels = []
        for key in keys:
            real, gen, genes = data[key]
            panels.append((
                key,
                np.asarray(real, dtype=float).ravel(),
                np.asarray(gen, dtype=float).ravel(),
                np.asarray(genes).astype(str).ravel(),
            ))

        # Shared limits for comparability across panels; percentiles keep a
        # few extreme genes from stretching every axis.
        lims = None
        if panels:
            pooled = np.concatenate(
                [arr for _, real, gen, _ in panels for arr in (real, gen)]
            )
            if pooled.size and not np.isnan(pooled).all():
                lo, hi = np.nanpercentile(pooled, [0.5, 99.5])
                pad = (hi - lo) * 0.05
                lims = (lo - pad, hi + pad)

        with sns.axes_style(self.style):
            fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)

            # axes.flat walks the grid row by row, matching panel order.
            for ax, (key, real, gen, genes) in zip(axes.flat, panels):
                # highlight DEGs if provided
                degs = self._deg_set(deg_map.get(key)) if deg_map else None
                if degs:
                    mask = np.isin(genes, list(degs))
                    ax.scatter(real[~mask], gen[~mask], s=8, alpha=alpha_other, label="Other")
                    ax.scatter(real[mask], gen[mask], s=10, alpha=alpha_deg, label="DEGs", color="#d62728")
                    ax.legend(frameon=False, fontsize=8, loc="upper left")
                else:
                    ax.scatter(real, gen, s=8, alpha=alpha_other)

                if lims is not None:
                    ax.plot(lims, lims, ls="--", c="gray", lw=1)
                    ax.set_xlim(lims)
                    ax.set_ylim(lims)
                ax.set_title(key, fontsize=10)
                ax.set_xlabel("Mean expression (real)", fontsize=9)
                ax.set_ylabel("Mean expression (generated)", fontsize=9)

                if stats and key in stats:
                    s = stats[key]
                    txt = f"r={s.get('pearson', np.nan):.2f} MSE={s.get('mse', np.nan):.2e}"
                    ax.text(0.02, 0.98, txt, transform=ax.transAxes, va="top", ha="left", fontsize=8)

            # hide empty axes
            for ax in axes.flat[n:]:
                ax.axis("off")

            fig.tight_layout()
            return fig

    def residuals_violin(
        self,
        residuals: Dict[str, np.ndarray],
        clip_percentiles: Tuple[float, float] = (1.0, 99.0),
        figsize: Tuple[int, int] = (12, 4),
        rotate_xticks: bool = True,
    ):
        """
        Violin/box overlay of residuals (generated - real), per condition key.

        Parameters
        ----------
        residuals : Dict[str, np.ndarray]
            key -> residual values. Arrays are flattened; empty arrays are
            skipped.
        clip_percentiles : Tuple[float, float]
            Lower/upper percentiles each condition's residuals are clipped
            to before plotting.
        figsize : Tuple[int, int]
            Figure size.
        rotate_xticks : bool
            Whether to rotate the x tick labels by 45 degrees.

        Returns
        -------
        Figure
            Matplotlib figure object.
        """
        conditions: list = []
        chunks: list = []
        for key, raw in residuals.items():
            vals = np.asarray(raw, dtype=float).ravel()
            if vals.size == 0:
                # Nothing to clip or draw for this condition.
                continue
            lo, hi = np.nanpercentile(vals, clip_percentiles)
            vals = np.clip(vals, lo, hi)
            conditions.extend([key] * vals.size)
            chunks.append(vals)
        df = pd.DataFrame({
            "condition": conditions,
            "residual": np.concatenate(chunks) if chunks else np.empty(0),
        })

        with sns.axes_style(self.style):
            fig, ax = plt.subplots(figsize=figsize)
            sns.violinplot(data=df, x="condition", y="residual", inner=None, cut=0, ax=ax, color="#9ecae1")
            sns.boxplot(data=df, x="condition", y="residual", ax=ax, width=0.15, showcaps=True,
                        boxprops={"facecolor": "white"}, showfliers=False)
            ax.axhline(0, ls="--", c="gray", lw=1)
            ax.set_title("Residual distributions per condition (generated − real)")
            ax.set_xlabel("Condition")
            ax.set_ylabel("Residual")
            if rotate_xticks:
                plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
            fig.tight_layout()
            return fig

    def metrics_bar(
        self,
        metrics_per_key: Mapping[str, Mapping[str, float]],
        order: Optional[List[str]] = None,
        figsize: Tuple[int, int] = (12, 4),
    ):
        """
        Grouped bar chart of metrics per condition key.

        Parameters
        ----------
        metrics_per_key : mapping
            key -> {metric_name: value}
        order : List[str], optional
            Explicit left-to-right order of conditions. When omitted,
            conditions are sorted ascending by 'pearson' if that metric is
            present, otherwise descending by the first metric encountered.
            Conditions lacking the sort metric are appended in input order.
        figsize : Tuple[int, int]
            Figure size.

        Returns
        -------
        Figure
            Matplotlib figure object.
        """
        rows = [
            (key, name, val)
            for key, metrics in metrics_per_key.items()
            for name, val in metrics.items()
        ]
        df = pd.DataFrame(rows, columns=["condition", "metric", "value"])
        metric_names = set(df["metric"])

        # NOTE(review): the sort directions below are kept exactly as the
        # code had them (ascending for pearson, descending otherwise), but the
        # earlier comment described the pearson ordering as "descending" -
        # confirm which is intended.
        if order is None and not df.empty:
            if "pearson" in metric_names:
                sort_metric, ascending = "pearson", True
            else:
                sort_metric, ascending = df["metric"].iloc[0], False
            ranked = (
                df[df["metric"] == sort_metric]
                .sort_values("value", ascending=ascending)["condition"]
                .tolist()
            )
            ranked_set = set(ranked)
            # Keep conditions that lack the sort metric instead of dropping
            # them from the plot.
            order = ranked + [c for c in df["condition"].unique() if c not in ranked_set]

        with sns.axes_style(self.style):
            fig, ax = plt.subplots(figsize=figsize)
            # Pass the order to seaborn so the bars themselves are reordered;
            # relabelling the ticks afterwards would detach labels from bars.
            sns.barplot(data=df, x="condition", y="value", hue="metric",
                        order=order or None, ax=ax)
            ax.set_title("Evaluation metrics per condition")
            ax.set_xlabel("Condition")
            ax.set_ylabel("Metric value")
            # Keep at least one legend column even when there are no metrics.
            ax.legend(frameon=False, ncols=max(1, min(4, len(metric_names))))
            plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
            fig.tight_layout()
            return fig