gengeneeval 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1096 @@
1
+ """
2
+ Comprehensive visualization module for gene expression evaluation.
3
+
4
+ Provides publication-quality plots for evaluation results:
5
+ - Boxplots and violin plots for metric distributions
6
+ - Radar plots for multi-metric comparison
7
+ - Scatter plots for real vs generated expression
8
+ - Embedding plots (PCA, UMAP) for data visualization
9
+ - Heatmaps for per-gene metrics
10
+ """
11
+ from __future__ import annotations
12
+
13
+ from typing import Dict, List, Optional, Tuple, Union, Any
14
+ from pathlib import Path
15
+ import numpy as np
16
+ import pandas as pd
17
+ import warnings
18
+
19
+ try:
20
+ import matplotlib.pyplot as plt
21
+ import matplotlib.patches as mpatches
22
+ from matplotlib.figure import Figure
23
+ import seaborn as sns
24
+ except ImportError:
25
+ raise ImportError(
26
+ "matplotlib and seaborn are required for visualization. "
27
+ "Install with: pip install matplotlib seaborn"
28
+ )
29
+
30
+
31
class PlotStyle:
    """Namespace of styling constants shared by the plots in this module."""

    # Seaborn theme defaults; EvaluationVisualizer uses these as its
    # default ``style`` / ``context`` / ``font_scale`` arguments.
    STYLE = "whitegrid"
    CONTEXT = "paper"
    FONT_SCALE = 1.2

    # Preset sizes handed to matplotlib as ``figsize``.
    FIGURE_SMALL = (8, 6)
    FIGURE_MEDIUM = (12, 8)
    FIGURE_LARGE = (16, 12)
    FIGURE_WIDE = (16, 6)

    # Colors distinguishing the two data sources.
    REAL_COLOR = "#1f77b4"       # blue
    GENERATED_COLOR = "#ff7f0e"  # orange

    # One color per known metric name; plotting code falls back to a
    # neutral color for names that are not listed here.
    METRIC_PALETTE = dict(
        pearson="#2ecc71",
        spearman="#27ae60",
        mean_pearson="#3498db",
        mean_spearman="#2980b9",
        wasserstein_1="#e74c3c",
        wasserstein_2="#c0392b",
        mmd="#9b59b6",
        energy="#8e44ad",
        multivariate_wasserstein="#f39c12",
        multivariate_mmd="#d35400",
    )
61
+
62
+
63
+ class EvaluationVisualizer:
64
+ """
65
+ Comprehensive visualizer for evaluation results.
66
+
67
+ Generates all plots from EvaluationResult objects.
68
+
69
+ Parameters
70
+ ----------
71
+ results : EvaluationResult
72
+ Evaluation results to visualize
73
+ style : str
74
+ Seaborn style
75
+ context : str
76
+ Seaborn context
77
+ font_scale : float
78
+ Font scale multiplier
79
+ dpi : int
80
+ Resolution for saved figures
81
+ """
82
+
83
def __init__(
    self,
    results: "EvaluationResult",
    style: str = PlotStyle.STYLE,
    context: str = PlotStyle.CONTEXT,
    font_scale: float = PlotStyle.FONT_SCALE,
    dpi: int = 150,
):
    """
    Remember the results and plot settings, then apply the seaborn theme.

    Parameters
    ----------
    results : EvaluationResult
        Evaluation results to visualize
    style : str
        Passed to ``sns.set_style``
    context : str
        Passed to ``sns.set_context``
    font_scale : float
        Passed to ``sns.set_context`` as ``font_scale``
    dpi : int
        Stored on the instance; used as the resolution when saving figures
    """
    self.results, self.dpi = results, dpi
    self.style, self.context, self.font_scale = style, context, font_scale

    # These seaborn calls are issued once, at construction time.
    sns.set_style(self.style)
    sns.set_context(self.context, font_scale=self.font_scale)
100
+
101
def _get_metric_data(
    self,
    metric_name: str,
    split: Optional[str] = None
) -> pd.DataFrame:
    """
    Collect the aggregate value of one metric for every condition.

    Parameters
    ----------
    metric_name : str
        Key looked up in each condition's ``metrics`` mapping; conditions
        without it are skipped.
    split : str, optional
        When given, only the split with this name is read.

    Returns
    -------
    pd.DataFrame
        One row per matching condition with columns ``split``,
        ``condition``, ``perturbation``, ``value`` and ``metric``.
        Empty when nothing matches.
    """
    records = [
        {
            "split": split_name,
            "condition": cond_key,
            # Fall back to the condition key when no perturbation is set.
            "perturbation": cond.perturbation or cond_key,
            "value": cond.metrics[metric_name].aggregate_value,
            "metric": metric_name,
        }
        for split_name, split_result in self.results.splits.items()
        if split is None or split_name == split
        for cond_key, cond in split_result.conditions.items()
        if metric_name in cond.metrics
    ]
    return pd.DataFrame(records)
125
+
126
def _get_all_metrics_data(
    self,
    split: Optional[str] = None
) -> pd.DataFrame:
    """
    Build a wide table with one row per condition and one column per metric.

    Parameters
    ----------
    split : str, optional
        When given, only the split with this name is read.

    Returns
    -------
    pd.DataFrame
        Columns ``split``, ``condition`` and ``perturbation`` followed by
        one column per metric name holding its ``aggregate_value``.
        Empty when there are no conditions.
    """
    records = []

    for split_name, split_result in self.results.splits.items():
        if split is None or split_name == split:
            for cond_key, cond in split_result.conditions.items():
                record = {
                    "split": split_name,
                    "condition": cond_key,
                    # Fall back to the condition key when no perturbation is set.
                    "perturbation": cond.perturbation or cond_key,
                }
                record.update(
                    (name, result.aggregate_value)
                    for name, result in cond.metrics.items()
                )
                records.append(record)

    return pd.DataFrame(records)
148
+
149
+ # ==================== BOXPLOTS ====================
150
+
151
def boxplot_metrics(
    self,
    metrics: Optional[List[str]] = None,
    split: Optional[str] = None,
    figsize: Tuple[int, int] = PlotStyle.FIGURE_WIDE,
    palette: Optional[Dict[str, str]] = None,
) -> Figure:
    """
    Create boxplot of metric values across conditions.

    Each box summarizes the aggregate values of one metric over all
    selected conditions.

    Parameters
    ----------
    metrics : List[str], optional
        Metrics to include. If None, uses all available. Names that are
        not present in the results are ignored; duplicates are dropped.
    split : str, optional
        Filter to specific split
    figsize : Tuple[int, int]
        Figure size
    palette : Dict[str, str], optional
        Color mapping for metrics. Defaults to
        ``PlotStyle.METRIC_PALETTE``; unlisted metrics are drawn grey.

    Returns
    -------
    Figure
        Matplotlib figure. When there are no conditions or none of the
        requested metrics exist, the figure only contains the text
        "No data available".
    """
    df = self._get_all_metrics_data(split)

    # Everything that is not a metadata column is a metric column.
    meta_cols = ["split", "condition", "perturbation"]
    available_metrics = [c for c in df.columns if c not in meta_cols]

    if metrics is not None:
        # dict.fromkeys drops repeated names while keeping the caller's
        # order, so a metric listed twice is not melted (and drawn) twice.
        available_metrics = [
            m for m in dict.fromkeys(metrics) if m in available_metrics
        ]

    fig, ax = plt.subplots(figsize=figsize)

    # Guard both "no conditions" and "filter matched no metric": without
    # it the melt below yields an empty frame that is handed to seaborn.
    if df.empty or not available_metrics:
        ax.text(0.5, 0.5, "No data available", ha='center', va='center')
        return fig

    # Melt to long format: one row per (condition, metric) pair.
    df_long = df.melt(
        id_vars=meta_cols,
        value_vars=available_metrics,
        var_name="metric",
        value_name="value"
    )

    colors = palette or PlotStyle.METRIC_PALETTE
    color_list = [colors.get(m, "#95a5a6") for m in available_metrics]

    sns.boxplot(
        data=df_long,
        x="metric",
        y="value",
        # Pin the category order so color_list lines up with the boxes.
        order=available_metrics,
        palette=color_list,
        ax=ax,
    )

    title_suffix = f" ({split})" if split else ""
    ax.set_xlabel("Metric")
    ax.set_ylabel("Value")
    ax.set_title(f"Metric Distributions{title_suffix}")
    # Rotate the labels of this axes explicitly instead of relying on
    # pyplot's notion of the "current" axes.
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

    fig.tight_layout()
    return fig
220
+
221
def boxplot_by_condition(
    self,
    metric_name: str,
    split: Optional[str] = None,
    max_conditions: int = 20,
    figsize: Tuple[int, int] = PlotStyle.FIGURE_WIDE,
) -> Figure:
    """
    Draw one box per condition from the per-gene values of a metric.

    Parameters
    ----------
    metric_name : str
        Key looked up in each condition's ``metrics`` mapping
    split : str, optional
        When given, only the split with this name is used
    max_conditions : int
        Number of conditions kept, chosen by largest median value
    figsize : Tuple[int, int]
        Figure size

    Returns
    -------
    Figure
        The boxplot, or a figure showing a "No data" message when no
        condition carries ``metric_name``.
    """
    fig, ax = plt.subplots(figsize=figsize)

    records = []
    for split_name, split_result in self.results.splits.items():
        if split is None or split_name == split:
            for cond_key, cond in split_result.conditions.items():
                if metric_name not in cond.metrics:
                    continue
                # Only the first 1000 per-gene values are used, to keep
                # the plot cheap to build.
                gene_values = cond.metrics[metric_name].per_gene_values[:1000]
                records.extend(
                    {
                        "condition": cond.perturbation or cond_key[:20],
                        "value": gene_value,
                    }
                    for gene_value in gene_values
                )

    frame = pd.DataFrame(records)

    if frame.empty:
        ax.text(0.5, 0.5, f"No data for {metric_name}", ha='center', va='center')
        return fig

    # Keep the conditions with the largest median value.
    medians = frame.groupby("condition")["value"].median()
    kept = medians.nlargest(max_conditions).index
    frame = frame[frame["condition"].isin(kept)]

    box_color = PlotStyle.METRIC_PALETTE.get(metric_name, "#3498db")
    sns.boxplot(data=frame, x="condition", y="value", ax=ax, color=box_color)

    ax.set_xlabel("Condition")
    ax.set_ylabel(metric_name)
    ax.set_title(f"Per-Gene {metric_name} by Condition")
    plt.xticks(rotation=45, ha='right')

    fig.tight_layout()
    return fig
274
+
275
+ # ==================== VIOLIN PLOTS ====================
276
+
277
def violin_metrics(
    self,
    metrics: Optional[List[str]] = None,
    split: Optional[str] = None,
    figsize: Tuple[int, int] = PlotStyle.FIGURE_WIDE,
    palette: Optional[Dict[str, str]] = None,
) -> Figure:
    """
    Create violin plot of metric values across conditions.

    Each violin summarizes the aggregate values of one metric over all
    selected conditions.

    Parameters
    ----------
    metrics : List[str], optional
        Metrics to include. If None, uses all available. Names that are
        not present in the results are ignored; duplicates are dropped.
    split : str, optional
        Filter to specific split
    figsize : Tuple[int, int]
        Figure size
    palette : Dict[str, str], optional
        Color mapping for metrics. Defaults to
        ``PlotStyle.METRIC_PALETTE``; unlisted metrics are drawn grey.

    Returns
    -------
    Figure
        Matplotlib figure. When there are no conditions or none of the
        requested metrics exist, the figure only contains the text
        "No data available".
    """
    df = self._get_all_metrics_data(split)

    # Everything that is not a metadata column is a metric column.
    meta_cols = ["split", "condition", "perturbation"]
    available_metrics = [c for c in df.columns if c not in meta_cols]

    if metrics is not None:
        # Drop repeated names (keeping order) so no metric is drawn twice.
        available_metrics = [
            m for m in dict.fromkeys(metrics) if m in available_metrics
        ]

    fig, ax = plt.subplots(figsize=figsize)

    # Guard both "no conditions" and "filter matched no metric": without
    # it the melt below yields an empty frame that is handed to seaborn.
    if df.empty or not available_metrics:
        ax.text(0.5, 0.5, "No data available", ha='center', va='center')
        return fig

    # Long format: one row per (condition, metric) pair.
    df_long = df.melt(
        id_vars=meta_cols,
        value_vars=available_metrics,
        var_name="metric",
        value_name="value"
    )

    color_source = palette or PlotStyle.METRIC_PALETTE
    colors = [color_source.get(m, "#95a5a6") for m in available_metrics]

    sns.violinplot(
        data=df_long,
        x="metric",
        y="value",
        # Pin the category order so `colors` lines up with the violins.
        order=available_metrics,
        palette=colors,
        ax=ax,
        inner="box",
    )

    title_suffix = f" ({split})" if split else ""
    ax.set_xlabel("Metric")
    ax.set_ylabel("Value")
    ax.set_title(f"Metric Distributions (Violin){title_suffix}")
    # Rotate the labels of this axes explicitly instead of relying on
    # pyplot's notion of the "current" axes.
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

    fig.tight_layout()
    return fig
326
+
327
def violin_per_gene(
    self,
    metric_name: str,
    split: Optional[str] = None,
    max_conditions: int = 10,
    figsize: Tuple[int, int] = PlotStyle.FIGURE_MEDIUM,
) -> Figure:
    """
    Draw one violin per condition from the per-gene values of a metric.

    Parameters
    ----------
    metric_name : str
        Key looked up in each condition's ``metrics`` mapping
    split : str, optional
        When given, only the split with this name is used
    max_conditions : int
        Number of conditions kept, chosen by largest median value
    figsize : Tuple[int, int]
        Figure size

    Returns
    -------
    Figure
        The violin plot, or a figure showing a "No data" message when no
        condition carries ``metric_name``.
    """
    fig, ax = plt.subplots(figsize=figsize)

    records = []
    for split_name, split_result in self.results.splits.items():
        if split is None or split_name == split:
            for cond_key, cond in split_result.conditions.items():
                if metric_name not in cond.metrics:
                    continue
                gene_values = cond.metrics[metric_name].per_gene_values
                records.extend(
                    {
                        "condition": cond.perturbation or cond_key[:15],
                        "value": gene_value,
                    }
                    for gene_value in gene_values
                )

    frame = pd.DataFrame(records)

    if frame.empty:
        ax.text(0.5, 0.5, f"No data for {metric_name}", ha='center', va='center')
        return fig

    # Keep the conditions with the largest median value.
    medians = frame.groupby("condition")["value"].median()
    kept = medians.nlargest(max_conditions).index
    frame = frame[frame["condition"].isin(kept)]

    violin_color = PlotStyle.METRIC_PALETTE.get(metric_name, "#3498db")
    sns.violinplot(
        data=frame,
        x="condition",
        y="value",
        ax=ax,
        color=violin_color,
        inner="quartile",
    )

    ax.set_xlabel("Condition")
    ax.set_ylabel(metric_name)
    ax.set_title(f"Per-Gene {metric_name} Distribution")
    plt.xticks(rotation=45, ha='right')

    fig.tight_layout()
    return fig
379
+
380
+ # ==================== RADAR PLOTS ====================
381
+
382
def radar_plot(
    self,
    metrics: Optional[List[str]] = None,
    conditions: Optional[List[str]] = None,
    split: Optional[str] = None,
    max_conditions: int = 6,
    normalize: bool = True,
    figsize: Tuple[int, int] = PlotStyle.FIGURE_MEDIUM,
) -> Figure:
    """
    Create radar plot comparing multiple metrics across conditions.

    Parameters
    ----------
    metrics : List[str], optional
        Metrics to include (should be 3+)
    conditions : List[str], optional
        Perturbation labels to compare. When given, exactly these are
        plotted and ``max_conditions`` is not applied.
    split : str, optional
        Filter to specific split
    max_conditions : int
        When ``conditions`` is None, the number of top rows (by mean of
        the selected metrics) used to pick the perturbations to show
    normalize : bool
        Whether to min-max scale each metric over the plotted rows.
        A metric whose values are all equal is left unscaled.
    figsize : Tuple[int, int]
        Figure size

    Returns
    -------
    Figure
        Matplotlib figure with radar plot, or a figure with an
        explanatory message when there is nothing to draw or fewer than
        three metrics are available.
    """
    df = self._get_all_metrics_data(split)

    if df.empty:
        fig, ax = plt.subplots(figsize=figsize)
        ax.text(0.5, 0.5, "No data available", ha='center', va='center')
        return fig

    meta_cols = ["split", "condition", "perturbation"]
    available_metrics = [c for c in df.columns if c not in meta_cols]

    if metrics is not None:
        available_metrics = [m for m in metrics if m in available_metrics]

    if len(available_metrics) < 3:
        fig, ax = plt.subplots(figsize=figsize)
        ax.text(0.5, 0.5, "Need at least 3 metrics for radar plot", ha='center', va='center')
        return fig

    # Select conditions. `.copy()` makes the selection an independent
    # frame so the normalization below can assign columns safely.
    if conditions is not None:
        df = df[df["perturbation"].isin(conditions)].copy()
    else:
        # Rank on a temporary column so the helper score never ends up
        # in the plotted frame.
        ranked = df.assign(_mean=df[available_metrics].mean(axis=1))
        top = ranked.nlargest(max_conditions, "_mean")["perturbation"].unique()
        df = df[df["perturbation"].isin(top)].copy()

    if df.empty:
        # e.g. `conditions` named no existing perturbation.
        fig, ax = plt.subplots(figsize=figsize)
        ax.text(0.5, 0.5, "No data available", ha='center', va='center')
        return fig

    # Normalize if requested
    if normalize:
        for m in available_metrics:
            col = df[m]
            min_val, max_val = col.min(), col.max()
            # NOTE(review): constant columns keep their raw values here,
            # which may fall outside [0, 1] - confirm this is intended.
            if max_val > min_val:
                df[m] = (col - min_val) / (max_val - min_val)

    # One spoke per metric; repeat the first angle to close the polygon.
    n_metrics = len(available_metrics)
    angles = np.linspace(0, 2 * np.pi, n_metrics, endpoint=False).tolist()
    angles += angles[:1]

    fig, ax = plt.subplots(figsize=figsize, subplot_kw=dict(polar=True))

    # One color per plotted row: the same perturbation can appear once
    # per split, and each of those polygons needs its own color.
    colors = sns.color_palette("husl", len(df))
    label_with_split = df["split"].nunique() > 1

    for color, (_, row) in zip(colors, df.iterrows()):
        values = [row[m] for m in available_metrics]
        values += values[:1]  # Close the loop

        # str() so non-string perturbation labels do not break slicing.
        label = str(row["perturbation"])[:20]
        if label_with_split:
            label = f"{label} [{row['split']}]"

        ax.plot(angles, values, 'o-', linewidth=2, label=label, color=color)
        ax.fill(angles, values, alpha=0.1, color=color)

    # Set labels
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(available_metrics, size=10)

    ax.set_title("Multi-Metric Comparison", size=14, y=1.1)
    ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.0))

    fig.tight_layout()
    return fig
476
+
477
def radar_split_comparison(
    self,
    metrics: Optional[List[str]] = None,
    normalize: bool = True,
    figsize: Tuple[int, int] = PlotStyle.FIGURE_MEDIUM,
) -> Figure:
    """
    Radar plot with one polygon per split, built from aggregate metrics.

    Calls ``compute_aggregates()`` on every split and uses the entries of
    ``aggregate_metrics`` whose key ends in ``_mean`` (with that suffix
    removed) as the metric values.

    Parameters
    ----------
    metrics : List[str], optional
        Metrics to include; names that are not available are ignored
    normalize : bool
        Whether to min-max scale each metric across splits
    figsize : Tuple[int, int]
        Figure size

    Returns
    -------
    Figure
        The radar plot, or a figure with an explanatory message when
        there are no splits or fewer than three metrics.
    """
    suffix = "_mean"
    per_split = {}
    for split_name, split_result in self.results.splits.items():
        split_result.compute_aggregates()
        per_split[split_name] = {
            key[:-len(suffix)]: value
            for key, value in split_result.aggregate_metrics.items()
            if key.endswith(suffix)
        }

    if not per_split:
        fig, ax = plt.subplots(figsize=figsize)
        ax.text(0.5, 0.5, "No data available", ha='center', va='center')
        return fig

    # Rows are splits, columns are metrics.
    table = pd.DataFrame(per_split).T

    if metrics is not None:
        wanted = [m for m in metrics if m in table.columns]
        table = table[wanted]

    available_metrics = table.columns.tolist()

    if len(available_metrics) < 3:
        fig, ax = plt.subplots(figsize=figsize)
        ax.text(0.5, 0.5, "Need at least 3 metrics for radar plot", ha='center', va='center')
        return fig

    if normalize:
        for name in table.columns:
            lowest, highest = table[name].min(), table[name].max()
            # Columns without spread are left untouched.
            if highest > lowest:
                table[name] = (table[name] - lowest) / (highest - lowest)

    # One spoke per metric; the first angle is repeated to close the shape.
    spoke_count = len(available_metrics)
    angles = np.linspace(0, 2 * np.pi, spoke_count, endpoint=False).tolist()
    angles += angles[:1]

    fig, ax = plt.subplots(figsize=figsize, subplot_kw=dict(polar=True))

    # Colors are reused cyclically when there are more than four splits.
    colors = [PlotStyle.REAL_COLOR, PlotStyle.GENERATED_COLOR, "#2ecc71", "#9b59b6"]

    for position, (split_name, row) in enumerate(table.iterrows()):
        color = colors[position % len(colors)]
        values = [row[m] for m in available_metrics]
        values += values[:1]

        ax.plot(angles, values, 'o-', linewidth=2, label=split_name, color=color)
        ax.fill(angles, values, alpha=0.15, color=color)

    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(available_metrics, size=10)
    ax.set_title("Split Comparison", size=14, y=1.1)
    ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.0))

    fig.tight_layout()
    return fig
544
+
545
+ # ==================== SCATTER PLOTS ====================
546
+
547
def scatter_real_vs_generated(
    self,
    condition: str,
    split: Optional[str] = None,
    figsize: Tuple[int, int] = PlotStyle.FIGURE_MEDIUM,
    alpha: float = 0.5,
) -> Figure:
    """
    Scatter plot of real vs generated mean expression.

    Parameters
    ----------
    condition : str
        Condition key or perturbation name. The first match, in split
        and condition iteration order, is plotted.
    split : str, optional
        Filter to specific split
    figsize : Tuple[int, int]
        Figure size
    alpha : float
        Point transparency

    Returns
    -------
    Figure
        Scatter plot figure, or a figure with a "No data" message when
        the condition is not found or lacks real/generated means.
    """
    # Find the condition. Taking the first item of a lazy generator stops
    # the search at the first hit across *all* splits; a plain `break`
    # in nested loops would only leave the inner loop and let a later
    # split overwrite the match.
    matches = (
        cond
        for split_name, split_result in self.results.splits.items()
        if split is None or split_name == split
        for cond_key, cond in split_result.conditions.items()
        if cond_key == condition or cond.perturbation == condition
    )
    cond_result = next(matches, None)

    fig, ax = plt.subplots(figsize=figsize)

    # Both mean vectors are needed; either one missing means no plot.
    if (
        cond_result is None
        or cond_result.real_mean is None
        or cond_result.generated_mean is None
    ):
        ax.text(0.5, 0.5, f"No data for condition: {condition}", ha='center', va='center')
        return fig

    real_mean = cond_result.real_mean
    gen_mean = cond_result.generated_mean

    ax.scatter(real_mean, gen_mean, alpha=alpha, s=10, c=PlotStyle.REAL_COLOR)

    # Add diagonal line spanning the joint value range of both axes.
    lims = [
        min(real_mean.min(), gen_mean.min()),
        max(real_mean.max(), gen_mean.max()),
    ]
    ax.plot(lims, lims, 'k--', alpha=0.5, label='y=x')

    # Add correlation
    if "pearson" in cond_result.metrics:
        r = cond_result.metrics["pearson"].aggregate_value
        ax.text(0.05, 0.95, f'r = {r:.3f}', transform=ax.transAxes,
                fontsize=12, verticalalignment='top')

    ax.set_xlabel("Real Mean Expression")
    ax.set_ylabel("Generated Mean Expression")
    ax.set_title(f"Real vs Generated: {cond_result.perturbation or condition}")
    # The diagonal carries a label, so draw the legend that shows it;
    # lower right keeps it clear of the correlation text at top left.
    ax.legend(loc='lower right')

    fig.tight_layout()
    return fig
615
+
616
def scatter_grid(
    self,
    split: Optional[str] = None,
    max_conditions: int = 12,
    ncols: int = 4,
    figsize_per_panel: Tuple[float, float] = (4, 4),
) -> Figure:
    """
    Grid of real-vs-generated scatter plots, one panel per condition.

    Parameters
    ----------
    split : str, optional
        When given, only the split with this name is used
    max_conditions : int
        Only the first ``max_conditions`` conditions that have a
        ``real_mean`` are plotted
    ncols : int
        Maximum number of panel columns
    figsize_per_panel : Tuple[float, float]
        Size of one panel; the figure size is this multiplied by the
        number of columns and rows

    Returns
    -------
    Figure
        The grid figure, or a figure with "No data available" when no
        condition has a ``real_mean``.
    """
    selected = [
        cond
        for split_name, split_result in self.results.splits.items()
        if split is None or split_name == split
        for cond in split_result.conditions.values()
        if cond.real_mean is not None
    ][:max_conditions]
    n = len(selected)

    if n == 0:
        fig, ax = plt.subplots()
        ax.text(0.5, 0.5, "No data available", ha='center', va='center')
        return fig

    ncols = min(ncols, n)
    nrows = int(np.ceil(n / ncols))
    panel_w, panel_h = figsize_per_panel[0], figsize_per_panel[1]

    # squeeze=False keeps `axes` two-dimensional even for a single row
    # or column.
    fig, axes = plt.subplots(
        nrows, ncols,
        figsize=(panel_w * ncols, panel_h * nrows),
        squeeze=False
    )
    panels = axes.flat  # row-major walk over the grid

    for ax, cond in zip(panels, selected):
        real_mean, gen_mean = cond.real_mean, cond.generated_mean

        ax.scatter(real_mean, gen_mean, alpha=0.4, s=5, c=PlotStyle.REAL_COLOR)

        # Identity line across the joint range of both vectors.
        low = min(real_mean.min(), gen_mean.min())
        high = max(real_mean.max(), gen_mean.max())
        ax.plot([low, high], [low, high], 'k--', alpha=0.3, linewidth=0.5)

        if "pearson" in cond.metrics:
            r = cond.metrics["pearson"].aggregate_value
            ax.text(0.05, 0.95, f'r={r:.2f}', transform=ax.transAxes,
                    fontsize=8, verticalalignment='top')

        ax.set_title(cond.perturbation or cond.condition_key[:20], fontsize=9)
        ax.tick_params(labelsize=7)

    # Switch off the unused panels at the end of the grid.
    for spare in axes.flat[n:]:
        spare.axis('off')

    fig.suptitle("Real vs Generated Expression", fontsize=12, y=1.02)
    fig.tight_layout()
    return fig
682
+
683
+ # ==================== HEATMAPS ====================
684
+
685
def heatmap_per_gene(
    self,
    metric_name: str,
    split: Optional[str] = None,
    max_genes: int = 50,
    max_conditions: int = 20,
    figsize: Tuple[int, int] = PlotStyle.FIGURE_LARGE,
    cmap: str = "RdYlBu_r",
) -> Figure:
    """
    Heatmap of per-gene metric values (genes as rows, conditions as columns).

    Parameters
    ----------
    metric_name : str
        Metric to visualize
    split : str, optional
        Filter to specific split
    max_genes : int
        Maximum genes to show; when exceeded, the genes with the highest
        variance across the shown conditions are kept
    max_conditions : int
        Maximum conditions to show; when exceeded, the conditions with
        the highest variance across genes are kept
    figsize : Tuple[int, int]
        Figure size
    cmap : str
        Colormap name

    Returns
    -------
    Figure
        Heatmap figure, or a figure with a "No data" message when no
        condition carries ``metric_name``.
    """
    # Column label -> per-gene values. Conditions that resolve to the
    # same label overwrite each other (the last one wins).
    columns = {}
    gene_names = None

    for split_name, split_result in self.results.splits.items():
        if split is not None and split_name != split:
            continue

        for cond_key, cond in split_result.conditions.items():
            if metric_name not in cond.metrics:
                continue
            label = cond.perturbation or cond_key[:15]
            columns[label] = cond.metrics[metric_name].per_gene_values
            # Gene names are taken from the first condition found.
            if gene_names is None:
                gene_names = cond.metrics[metric_name].gene_names

    if not columns:
        fig, ax = plt.subplots(figsize=figsize)
        ax.text(0.5, 0.5, f"No data for {metric_name}", ha='center', va='center')
        return fig

    table = pd.DataFrame(columns, index=gene_names)
    n_genes, n_conditions = table.shape

    # Trim conditions first, then genes, each by highest variance.
    if n_conditions > max_conditions:
        condition_variance = table.var()
        table = table[condition_variance.nlargest(max_conditions).index]

    if n_genes > max_genes:
        gene_variance = table.var(axis=1)
        table = table.loc[gene_variance.nlargest(max_genes).index]

    fig, ax = plt.subplots(figsize=figsize)

    sns.heatmap(
        table,
        ax=ax,
        cmap=cmap,
        xticklabels=True,
        yticklabels=True,
        cbar_kws={"label": metric_name},
    )

    ax.set_xlabel("Condition")
    ax.set_ylabel("Gene")
    ax.set_title(f"Per-Gene {metric_name}")

    plt.xticks(rotation=45, ha='right')
    plt.yticks(fontsize=8)

    fig.tight_layout()
    return fig
772
+
773
def heatmap_metrics_summary(
    self,
    split: Optional[str] = None,
    figsize: Tuple[int, int] = PlotStyle.FIGURE_MEDIUM,
    cmap: str = "RdYlBu",
) -> Figure:
    """
    Heatmap summarizing all metrics across conditions.

    Every metric column is z-scored over the conditions so metrics on
    different scales share one color range. Metrics without spread
    (identical values, or only one condition) are shown as 0 instead of
    NaN.

    Parameters
    ----------
    split : str, optional
        Filter to specific split
    figsize : Tuple[int, int]
        Figure size
    cmap : str
        Colormap name

    Returns
    -------
    Figure
        Heatmap figure (conditions as rows, metrics as columns), or a
        figure with "No data available" when there is nothing to show.
    """
    df = self._get_all_metrics_data(split)

    meta_cols = ["split", "condition", "perturbation"]
    metric_cols = [c for c in df.columns if c not in meta_cols]

    if df.empty or not metric_cols:
        fig, ax = plt.subplots(figsize=figsize)
        ax.text(0.5, 0.5, "No data available", ha='center', va='center')
        return fig

    # Pivot for heatmap
    df_pivot = df.set_index("perturbation")[metric_cols]

    # Normalize columns for visualization. The standard deviation is 0
    # for a constant column and NaN when there is a single row; dividing
    # by either would turn the column into NaN/inf. `where` keeps only
    # strictly positive spreads and substitutes 1, so those columns come
    # out as 0 (their centered value).
    spread = df_pivot.std()
    spread = spread.where(spread > 0, 1.0)
    df_norm = (df_pivot - df_pivot.mean()) / spread

    fig, ax = plt.subplots(figsize=figsize)

    sns.heatmap(
        df_norm,
        ax=ax,
        cmap=cmap,
        xticklabels=True,
        yticklabels=True,
        center=0,
        cbar_kws={"label": "Z-score"},
    )

    ax.set_xlabel("Metric")
    ax.set_ylabel("Condition")
    ax.set_title("Metrics Summary (Z-scored)")

    # Rotate the labels of the heatmap axes itself rather than whatever
    # pyplot considers the current axes.
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

    fig.tight_layout()
    return fig
818
+
819
+ # ==================== EMBEDDING PLOTS ====================
820
+
821
def embedding_plot(
    self,
    data_loader: "GeneExpressionDataLoader",
    method: str = "pca",
    split: Optional[str] = None,
    max_samples: int = 5000,
    figsize: Tuple[int, int] = PlotStyle.FIGURE_MEDIUM,
    alpha: float = 0.6,
) -> Figure:
    """
    Plot embedded data (PCA or UMAP) comparing real and generated.

    Parameters
    ----------
    data_loader : GeneExpressionDataLoader
        Data loader with real and generated data
    method : str
        Embedding method: "umap" (case-insensitive) for UMAP; any other
        value plots the first two principal components
    split : str, optional
        Filter to specific split. Applied to the real and the generated
        data wherever the loader's split column is present in ``obs``.
    max_samples : int
        Maximum samples to plot; each source is randomly subsampled to
        at most half of this
    figsize : Tuple[int, int]
        Figure size
    alpha : float
        Point transparency

    Returns
    -------
    Figure
        Embedding plot figure

    Raises
    ------
    ImportError
        If scanpy is not installed
    ValueError
        If the combined data is too small for a two-dimensional PCA
    """
    try:
        import scanpy as sc
    except ImportError as err:
        raise ImportError("scanpy is required for embedding plots") from err

    # Work on copies so the loader's data is never modified.
    real = data_loader.real.copy()
    gen = data_loader.generated.copy()

    # Apply split filter if needed. Both sources are filtered so the
    # plot does not compare one split of real data against generated
    # data from every split.
    split_column = data_loader.split_column
    if split is not None and split_column is not None:
        if split_column in real.obs.columns:
            mask = real.obs[split_column].astype(str) == split
            real = real[mask].copy()
        if split_column in gen.obs.columns:
            mask = gen.obs[split_column].astype(str) == split
            gen = gen[mask].copy()

    # Subsample if needed
    per_source = max_samples // 2
    if real.n_obs > per_source:
        idx = np.random.choice(real.n_obs, per_source, replace=False)
        real = real[idx].copy()
    if gen.n_obs > per_source:
        idx = np.random.choice(gen.n_obs, per_source, replace=False)
        gen = gen[idx].copy()

    # Add source label
    real.obs["_source"] = "Real"
    gen.obs["_source"] = "Generated"

    # Concatenate
    combined = real.concatenate(gen, batch_key="_batch")

    # PCA cannot extract as many components as the smaller data
    # dimension, so cap the usual 50 for small inputs.
    n_comps = min(50, combined.n_obs - 1, combined.n_vars - 1)
    if n_comps < 2:
        raise ValueError(
            "Need at least 3 samples and 3 genes to compute a 2D embedding"
        )
    sc.pp.pca(combined, n_comps=n_comps)

    if method.lower() == "umap":
        sc.pp.neighbors(combined)
        sc.tl.umap(combined)
        x_key = "X_umap"
        x_label, y_label = "UMAP1", "UMAP2"
    else:
        x_key = "X_pca"
        x_label, y_label = "PC1", "PC2"

    # Plot
    fig, ax = plt.subplots(figsize=figsize)

    coords = combined.obsm[x_key][:, :2]
    source = combined.obs["_source"]

    for label, color in [("Real", PlotStyle.REAL_COLOR), ("Generated", PlotStyle.GENERATED_COLOR)]:
        mask = source == label
        ax.scatter(
            coords[mask, 0], coords[mask, 1],
            c=color, label=label, alpha=alpha, s=10
        )

    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_title(f"{method.upper()} Embedding: Real vs Generated")
    ax.legend()

    fig.tight_layout()
    return fig
915
+
916
def embedding_by_condition(
    self,
    data_loader: "GeneExpressionDataLoader",
    method: str = "pca",
    condition_column: Optional[str] = None,
    max_samples: int = 5000,
    figsize: Tuple[int, int] = PlotStyle.FIGURE_LARGE,
) -> Figure:
    """
    Embedding plot colored by condition.

    Draws two panels (real, generated) of the same joint embedding. Each
    condition gets one color that is identical in both panels, and the
    shared legend lists every condition shown in either panel.

    Parameters
    ----------
    data_loader : GeneExpressionDataLoader
        Data loader with real and generated data
    method : str
        "umap" (case-insensitive) for UMAP; any other value plots the
        first two principal components
    condition_column : str, optional
        ``obs`` column used for coloring. Defaults to the first entry of
        ``data_loader.condition_columns``.
    max_samples : int
        Maximum samples to plot; each source is randomly subsampled to
        at most half of this
    figsize : Tuple[int, int]
        Figure size

    Returns
    -------
    Figure
        Figure with the two embedding panels

    Raises
    ------
    ImportError
        If scanpy is not installed
    ValueError
        If the combined data is too small for a two-dimensional PCA
    """
    try:
        import scanpy as sc
    except ImportError as err:
        raise ImportError("scanpy is required for embedding plots") from err

    # Use first condition column if not specified
    if condition_column is None:
        condition_column = data_loader.condition_columns[0]

    # Work on copies so the loader's data is never modified.
    real = data_loader.real.copy()
    gen = data_loader.generated.copy()

    real.obs["_source"] = "Real"
    gen.obs["_source"] = "Generated"

    # Subsample
    per_source = max_samples // 2
    if real.n_obs > per_source:
        idx = np.random.choice(real.n_obs, per_source, replace=False)
        real = real[idx].copy()
    if gen.n_obs > per_source:
        idx = np.random.choice(gen.n_obs, per_source, replace=False)
        gen = gen[idx].copy()

    combined = real.concatenate(gen, batch_key="_batch")

    # PCA cannot extract as many components as the smaller data
    # dimension, so cap the usual 50 for small inputs.
    n_comps = min(50, combined.n_obs - 1, combined.n_vars - 1)
    if n_comps < 2:
        raise ValueError(
            "Need at least 3 samples and 3 genes to compute a 2D embedding"
        )
    sc.pp.pca(combined, n_comps=n_comps)
    if method.lower() == "umap":
        sc.pp.neighbors(combined)
        sc.tl.umap(combined)
        x_key = "X_umap"
    else:
        x_key = "X_pca"

    # Create side-by-side plot
    fig, axes = plt.subplots(1, 2, figsize=figsize)

    coords = combined.obsm[x_key][:, :2]
    source = combined.obs["_source"].to_numpy()
    condition_labels = combined.obs[condition_column].astype(str)

    # Build ONE color map over all conditions so a condition has the
    # same color in both panels and the shared legend is truthful.
    unique_conds = list(condition_labels.unique())
    color_map = dict(zip(unique_conds, sns.color_palette("husl", len(unique_conds))))
    condition_values = condition_labels.to_numpy()

    # Legend label -> first artist drawn with it, collected across both
    # panels so conditions that only occur in one source are not lost.
    legend_entries = {}

    for ax, panel_source in zip(axes, ["Real", "Generated"]):
        in_panel = source == panel_source

        for cond in unique_conds:
            cond_mask = in_panel & (condition_values == cond)
            if not cond_mask.any():
                continue
            artist = ax.scatter(
                coords[cond_mask, 0], coords[cond_mask, 1],
                c=[color_map[cond]], label=cond[:15], alpha=0.6, s=10
            )
            legend_entries.setdefault(cond[:15], artist)

        ax.set_title(f"{panel_source}")
        ax.set_xlabel("Dim 1")
        ax.set_ylabel("Dim 2")

    # Shared legend
    fig.legend(
        list(legend_entries.values()), list(legend_entries.keys()),
        loc='center right', bbox_to_anchor=(1.15, 0.5)
    )

    fig.suptitle(f"{method.upper()} Embedding by {condition_column}", fontsize=12)
    fig.tight_layout()
    return fig
993
+
994
+ # ==================== SAVE ALL ====================
995
+
996
def save_all(
    self,
    output_dir: Union[str, Path],
    formats: Optional[List[str]] = None,
    data_loader: Optional["GeneExpressionDataLoader"] = None,
):
    """
    Generate and save all plots.

    A plot that fails to generate is skipped with a warning; the
    remaining plots are still produced. Every generated figure is closed
    afterwards, even if saving raises.

    Parameters
    ----------
    output_dir : str or Path
        Directory to save plots (created if missing)
    formats : List[str], optional
        Image formats to save, used as file extensions. Defaults to
        ``["png", "pdf"]``. A single string such as ``"png"`` is
        accepted as one format.
    data_loader : GeneExpressionDataLoader, optional
        If provided, also generates embedding plots
    """
    # None instead of a list literal avoids a shared mutable default.
    if formats is None:
        formats = ["png", "pdf"]
    elif isinstance(formats, str):
        # Iterating a bare string would yield one "format" per character.
        formats = [formats]

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # (file stem, name used in the warning text, zero-argument factory)
    jobs = [
        ("boxplot_metrics", "boxplot_metrics", self.boxplot_metrics),
        ("violin_metrics", "violin_metrics", self.violin_metrics),
        ("radar_split", "radar_split", self.radar_split_comparison),
        ("scatter_grid", "scatter_grid", self.scatter_grid),
        ("heatmap_summary", "heatmap_summary", self.heatmap_metrics_summary),
    ]

    # Per-metric plots. `m=metric_name` binds the current value; a plain
    # closure would see only the last metric of the loop.
    for metric_name in ("pearson", "wasserstein_1", "mmd"):
        stem = f"violin_{metric_name}"
        jobs.append((stem, stem, lambda m=metric_name: self.violin_per_gene(m)))

    # Embedding plots if data loader provided
    if data_loader is not None:
        jobs.append((
            "embedding_pca", "PCA embedding",
            lambda: self.embedding_plot(data_loader, method="pca"),
        ))
        jobs.append((
            "embedding_umap", "UMAP embedding",
            lambda: self.embedding_plot(data_loader, method="umap"),
        ))

    # Generate all plots. Best effort by design: one failing plot must
    # not prevent the others, but every failure is reported.
    plots = {}
    for stem, description, make_figure in jobs:
        try:
            plots[stem] = make_figure()
        except Exception as e:
            warnings.warn(f"Failed to generate {description}: {e}")

    # Save all plots; the finally block releases the figures even when
    # writing a file fails.
    try:
        for name, fig in plots.items():
            for fmt in formats:
                path = output_dir / f"{name}.{fmt}"
                fig.savefig(path, dpi=self.dpi, bbox_inches='tight')
    finally:
        for fig in plots.values():
            plt.close(fig)

    print(f"Saved {len(plots)} plots to {output_dir}")
1072
+
1073
+
1074
+ # Convenience function
1075
def visualize(
    results: "EvaluationResult",
    output_dir: Union[str, Path],
    data_loader: Optional["GeneExpressionDataLoader"] = None,
    **kwargs
):
    """
    Build an EvaluationVisualizer and save all of its plots in one call.

    Parameters
    ----------
    results : EvaluationResult
        Evaluation results to visualize
    output_dir : str or Path
        Directory handed to ``EvaluationVisualizer.save_all``
    data_loader : GeneExpressionDataLoader, optional
        Forwarded to ``save_all``; if provided, generates embedding plots
    **kwargs
        Forwarded unchanged to the EvaluationVisualizer constructor
    """
    visualizer = EvaluationVisualizer(results, **kwargs)
    visualizer.save_all(output_dir, data_loader=data_loader)