gengeneeval 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geneval/__init__.py +129 -0
- geneval/cli.py +333 -0
- geneval/config.py +141 -0
- geneval/core.py +41 -0
- geneval/data/__init__.py +23 -0
- geneval/data/gene_expression_datamodule.py +211 -0
- geneval/data/loader.py +437 -0
- geneval/evaluator.py +359 -0
- geneval/evaluators/__init__.py +4 -0
- geneval/evaluators/base_evaluator.py +178 -0
- geneval/evaluators/gene_expression_evaluator.py +218 -0
- geneval/metrics/__init__.py +65 -0
- geneval/metrics/base_metric.py +229 -0
- geneval/metrics/correlation.py +232 -0
- geneval/metrics/distances.py +516 -0
- geneval/metrics/metrics.py +134 -0
- geneval/models/__init__.py +1 -0
- geneval/models/base_model.py +53 -0
- geneval/results.py +334 -0
- geneval/testing.py +393 -0
- geneval/utils/__init__.py +1 -0
- geneval/utils/io.py +27 -0
- geneval/utils/preprocessing.py +82 -0
- geneval/visualization/__init__.py +38 -0
- geneval/visualization/plots.py +499 -0
- geneval/visualization/visualizer.py +1096 -0
- gengeneeval-0.1.0.dist-info/METADATA +172 -0
- gengeneeval-0.1.0.dist-info/RECORD +31 -0
- gengeneeval-0.1.0.dist-info/WHEEL +4 -0
- gengeneeval-0.1.0.dist-info/entry_points.txt +3 -0
- gengeneeval-0.1.0.dist-info/licenses/LICENSE +9 -0
|
@@ -0,0 +1,1096 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Comprehensive visualization module for gene expression evaluation.
|
|
3
|
+
|
|
4
|
+
Provides publication-quality plots for evaluation results:
|
|
5
|
+
- Boxplots and violin plots for metric distributions
|
|
6
|
+
- Radar plots for multi-metric comparison
|
|
7
|
+
- Scatter plots for real vs generated expression
|
|
8
|
+
- Embedding plots (PCA, UMAP) for data visualization
|
|
9
|
+
- Heatmaps for per-gene metrics
|
|
10
|
+
"""
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
from typing import Dict, List, Optional, Tuple, Union, Any
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
import numpy as np
|
|
16
|
+
import pandas as pd
|
|
17
|
+
import warnings
|
|
18
|
+
|
|
19
|
+
try:
|
|
20
|
+
import matplotlib.pyplot as plt
|
|
21
|
+
import matplotlib.patches as mpatches
|
|
22
|
+
from matplotlib.figure import Figure
|
|
23
|
+
import seaborn as sns
|
|
24
|
+
except ImportError:
|
|
25
|
+
raise ImportError(
|
|
26
|
+
"matplotlib and seaborn are required for visualization. "
|
|
27
|
+
"Install with: pip install matplotlib seaborn"
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class PlotStyle:
    """Shared plotting defaults: colours, figure sizes and seaborn theme.

    Every attribute is a plain class-level constant. The visualizer uses
    them as default argument values and as fallback colours.
    """

    # Colours distinguishing the two data sources.
    REAL_COLOR = "#1f77b4"        # blue
    GENERATED_COLOR = "#ff7f0e"   # orange

    # Metric name -> hex colour. Related metrics share a hue family:
    # greens/blues for correlations, reds for Wasserstein, purples for
    # MMD/energy, oranges for the multivariate variants.
    METRIC_PALETTE = dict(
        pearson="#2ecc71",
        spearman="#27ae60",
        mean_pearson="#3498db",
        mean_spearman="#2980b9",
        wasserstein_1="#e74c3c",
        wasserstein_2="#c0392b",
        mmd="#9b59b6",
        energy="#8e44ad",
        multivariate_wasserstein="#f39c12",
        multivariate_mmd="#d35400",
    )

    # Figure sizes passed as matplotlib ``figsize`` (width, height).
    FIGURE_SMALL = (8, 6)
    FIGURE_MEDIUM = (12, 8)
    FIGURE_LARGE = (16, 12)
    FIGURE_WIDE = (16, 6)

    # Seaborn theme settings.
    STYLE = "whitegrid"
    CONTEXT = "paper"
    FONT_SCALE = 1.2
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class EvaluationVisualizer:
|
|
64
|
+
"""
|
|
65
|
+
Comprehensive visualizer for evaluation results.
|
|
66
|
+
|
|
67
|
+
Generates all plots from EvaluationResult objects.
|
|
68
|
+
|
|
69
|
+
Parameters
|
|
70
|
+
----------
|
|
71
|
+
results : EvaluationResult
|
|
72
|
+
Evaluation results to visualize
|
|
73
|
+
style : str
|
|
74
|
+
Seaborn style
|
|
75
|
+
context : str
|
|
76
|
+
Seaborn context
|
|
77
|
+
font_scale : float
|
|
78
|
+
Font scale multiplier
|
|
79
|
+
dpi : int
|
|
80
|
+
Resolution for saved figures
|
|
81
|
+
"""
|
|
82
|
+
|
|
83
|
+
def __init__(
    self,
    results: "EvaluationResult",
    style: str = PlotStyle.STYLE,
    context: str = PlotStyle.CONTEXT,
    font_scale: float = PlotStyle.FONT_SCALE,
    dpi: int = 150,
):
    """Store the results and display settings, then activate the seaborn theme.

    Parameters
    ----------
    results : EvaluationResult
        Results whose ``splits`` the plotting methods read.
    style, context, font_scale
        Forwarded to ``sns.set_style`` / ``sns.set_context``. Those seaborn
        calls update matplotlib's global rcParams, so they also affect
        figures created elsewhere in the process.
    dpi : int
        Stored as ``self.dpi``. It is not read in this section; presumably
        it is used when saving figures.
    """
    self.results = results
    self.style, self.context = style, context
    self.font_scale, self.dpi = font_scale, dpi

    # Apply the theme from the freshly stored settings.
    sns.set_style(self.style)
    sns.set_context(self.context, font_scale=self.font_scale)
|
|
100
|
+
|
|
101
|
+
def _get_metric_data(
|
|
102
|
+
self,
|
|
103
|
+
metric_name: str,
|
|
104
|
+
split: Optional[str] = None
|
|
105
|
+
) -> pd.DataFrame:
|
|
106
|
+
"""Extract metric data as DataFrame."""
|
|
107
|
+
rows = []
|
|
108
|
+
|
|
109
|
+
for split_name, split_result in self.results.splits.items():
|
|
110
|
+
if split is not None and split_name != split:
|
|
111
|
+
continue
|
|
112
|
+
|
|
113
|
+
for cond_key, cond in split_result.conditions.items():
|
|
114
|
+
if metric_name in cond.metrics:
|
|
115
|
+
value = cond.metrics[metric_name].aggregate_value
|
|
116
|
+
rows.append({
|
|
117
|
+
"split": split_name,
|
|
118
|
+
"condition": cond_key,
|
|
119
|
+
"perturbation": cond.perturbation or cond_key,
|
|
120
|
+
"value": value,
|
|
121
|
+
"metric": metric_name,
|
|
122
|
+
})
|
|
123
|
+
|
|
124
|
+
return pd.DataFrame(rows)
|
|
125
|
+
|
|
126
|
+
def _get_all_metrics_data(
|
|
127
|
+
self,
|
|
128
|
+
split: Optional[str] = None
|
|
129
|
+
) -> pd.DataFrame:
|
|
130
|
+
"""Extract all metrics as DataFrame."""
|
|
131
|
+
rows = []
|
|
132
|
+
|
|
133
|
+
for split_name, split_result in self.results.splits.items():
|
|
134
|
+
if split is not None and split_name != split:
|
|
135
|
+
continue
|
|
136
|
+
|
|
137
|
+
for cond_key, cond in split_result.conditions.items():
|
|
138
|
+
row = {
|
|
139
|
+
"split": split_name,
|
|
140
|
+
"condition": cond_key,
|
|
141
|
+
"perturbation": cond.perturbation or cond_key,
|
|
142
|
+
}
|
|
143
|
+
for metric_name, metric_result in cond.metrics.items():
|
|
144
|
+
row[metric_name] = metric_result.aggregate_value
|
|
145
|
+
rows.append(row)
|
|
146
|
+
|
|
147
|
+
return pd.DataFrame(rows)
|
|
148
|
+
|
|
149
|
+
# ==================== BOXPLOTS ====================
|
|
150
|
+
|
|
151
|
+
def boxplot_metrics(
    self,
    metrics: Optional[List[str]] = None,
    split: Optional[str] = None,
    figsize: Tuple[int, int] = PlotStyle.FIGURE_WIDE,
    palette: Optional[Dict[str, str]] = None,
) -> Figure:
    """
    Create boxplot of metric values across conditions.

    Each box summarizes one metric's aggregate value over all conditions,
    with one data point per condition.

    Parameters
    ----------
    metrics : List[str], optional
        Metrics to include, drawn in the given order. If None, uses all
        available. Names not present in the results are dropped.
    split : str, optional
        Filter to specific split
    figsize : Tuple[int, int]
        Figure size
    palette : Dict[str, str], optional
        Color mapping for metrics. When given it replaces
        ``PlotStyle.METRIC_PALETTE``; metrics missing from the mapping are
        drawn in grey (``#95a5a6``).

    Returns
    -------
    Figure
        Matplotlib figure. If there is nothing to plot (no results, or no
        requested metric is available), the figure only contains a
        "No data available" message.
    """
    df = self._get_all_metrics_data(split)

    # Every non-metadata column of the wide table is a metric.
    meta_cols = ["split", "condition", "perturbation"]
    available_metrics = [c for c in df.columns if c not in meta_cols]
    if metrics is not None:
        # Keep the caller's order; dict.fromkeys drops repeated names.
        available_metrics = [
            m for m in dict.fromkeys(metrics) if m in available_metrics
        ]

    fig, ax = plt.subplots(figsize=figsize)

    # Without this guard, a ``metrics`` filter that matched nothing would
    # hand seaborn an empty long-format frame.
    if df.empty or not available_metrics:
        ax.text(0.5, 0.5, "No data available", ha='center', va='center')
        return fig

    # Melt to long format: one row per (condition, metric).
    df_long = df.melt(
        id_vars=meta_cols,
        value_vars=available_metrics,
        var_name="metric",
        value_name="value",
    )

    colors = palette or PlotStyle.METRIC_PALETTE
    color_list = [colors.get(m, "#95a5a6") for m in available_metrics]

    # ``order`` pins the box order to ``available_metrics``, so each box
    # receives the colour computed for it above instead of depending on
    # seaborn's order-of-appearance.
    # NOTE(review): seaborn >= 0.13 warns about ``palette`` without ``hue``.
    sns.boxplot(
        data=df_long,
        x="metric",
        y="value",
        order=available_metrics,
        palette=color_list,
        ax=ax,
    )

    ax.set_xlabel("Metric")
    ax.set_ylabel("Value")
    ax.set_title(f"Metric Distributions{' (' + split + ')' if split else ''}")
    # Rotate this figure's labels rather than pyplot's "current" axes.
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

    fig.tight_layout()
    return fig
|
|
220
|
+
|
|
221
|
+
def boxplot_by_condition(
    self,
    metric_name: str,
    split: Optional[str] = None,
    max_conditions: int = 20,
    figsize: Tuple[int, int] = PlotStyle.FIGURE_WIDE,
) -> Figure:
    """
    Create boxplot of a single metric across conditions.

    Shows the per-gene distribution for each condition. Each box holds the
    per-gene values of ``metric_name`` for one condition; only the first
    1000 values per condition are used, to keep plotting fast.

    Conditions are labelled by perturbation name, falling back to the first
    20 characters of the condition key. Conditions that share a label (for
    example the same perturbation in several splits) are pooled into one
    box.

    Parameters
    ----------
    metric_name : str
        Metric whose per-gene values are plotted.
    split : str, optional
        Filter to specific split.
    max_conditions : int
        Only the labels with the largest median value are shown, drawn in
        decreasing order of that median.
    figsize : Tuple[int, int]
        Figure size.

    Returns
    -------
    Figure
        Matplotlib figure. It contains a "No data for <metric>" message when
        no condition provides per-gene values for ``metric_name``.
    """
    fig, ax = plt.subplots(figsize=figsize)

    rows = []
    for split_name, split_result in self.results.splits.items():
        if split is not None and split_name != split:
            continue

        for cond_key, cond in split_result.conditions.items():
            if metric_name not in cond.metrics:
                continue
            per_gene = cond.metrics[metric_name].per_gene_values
            # NOTE(review): assumes a metric may carry no per-gene
            # breakdown (None); skip it instead of crashing on the slice.
            if per_gene is None:
                continue
            label = cond.perturbation or cond_key[:20]
            rows.extend(
                {"condition": label, "value": val}
                for val in per_gene[:1000]  # Limit for performance
            )

    df = pd.DataFrame(rows)

    if df.empty:
        ax.text(0.5, 0.5, f"No data for {metric_name}", ha='center', va='center')
        return fig

    # Rank labels by median, keep the top ``max_conditions``, and reuse the
    # same ranking as the left-to-right box order.
    top_conditions = (
        df.groupby("condition")["value"]
        .median()
        .nlargest(max_conditions)
        .index.tolist()
    )
    if not top_conditions:
        # Every median was NaN, so there is nothing meaningful to draw.
        ax.text(0.5, 0.5, f"No data for {metric_name}", ha='center', va='center')
        return fig
    df = df[df["condition"].isin(top_conditions)]

    sns.boxplot(
        data=df,
        x="condition",
        y="value",
        order=top_conditions,
        ax=ax,
        color=PlotStyle.METRIC_PALETTE.get(metric_name, "#3498db"),
    )

    ax.set_xlabel("Condition")
    ax.set_ylabel(metric_name)
    ax.set_title(f"Per-Gene {metric_name} by Condition")
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

    fig.tight_layout()
    return fig
|
|
274
|
+
|
|
275
|
+
# ==================== VIOLIN PLOTS ====================
|
|
276
|
+
|
|
277
|
+
def violin_metrics(
    self,
    metrics: Optional[List[str]] = None,
    split: Optional[str] = None,
    figsize: Tuple[int, int] = PlotStyle.FIGURE_WIDE,
) -> Figure:
    """
    Create violin plot of metric values across conditions.

    Each violin shows one metric's aggregate value over all conditions, with
    one data point per condition and an inner box summary.

    Parameters
    ----------
    metrics : List[str], optional
        Metrics to include, drawn in the given order. If None, uses all
        available. Names not present in the results are dropped.
    split : str, optional
        Filter to specific split.
    figsize : Tuple[int, int]
        Figure size.

    Returns
    -------
    Figure
        Matplotlib figure, or a "No data available" message figure when
        there is nothing to plot.
    """
    df = self._get_all_metrics_data(split)

    meta_cols = ["split", "condition", "perturbation"]
    available_metrics = [c for c in df.columns if c not in meta_cols]
    if metrics is not None:
        # Keep the caller's order; dict.fromkeys drops repeated names.
        available_metrics = [
            m for m in dict.fromkeys(metrics) if m in available_metrics
        ]

    fig, ax = plt.subplots(figsize=figsize)

    # Also covers a ``metrics`` filter that matched nothing.
    if df.empty or not available_metrics:
        ax.text(0.5, 0.5, "No data available", ha='center', va='center')
        return fig

    df_long = df.melt(
        id_vars=meta_cols,
        value_vars=available_metrics,
        var_name="metric",
        value_name="value",
    )

    colors = [PlotStyle.METRIC_PALETTE.get(m, "#95a5a6") for m in available_metrics]

    # ``order`` keeps the violins aligned with ``colors`` (built in
    # ``available_metrics`` order).
    sns.violinplot(
        data=df_long,
        x="metric",
        y="value",
        order=available_metrics,
        palette=colors,
        ax=ax,
        inner="box",
    )

    ax.set_xlabel("Metric")
    ax.set_ylabel("Value")
    ax.set_title(f"Metric Distributions (Violin){' (' + split + ')' if split else ''}")
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

    fig.tight_layout()
    return fig
|
|
326
|
+
|
|
327
|
+
def violin_per_gene(
    self,
    metric_name: str,
    split: Optional[str] = None,
    max_conditions: int = 10,
    figsize: Tuple[int, int] = PlotStyle.FIGURE_MEDIUM,
) -> Figure:
    """
    Create violin plot showing per-gene metric distributions.

    Each violin holds all per-gene values of ``metric_name`` for one
    condition label (perturbation name, falling back to the first 15
    characters of the condition key). Conditions sharing a label are pooled.

    Parameters
    ----------
    metric_name : str
        Metric whose per-gene values are plotted.
    split : str, optional
        Filter to specific split.
    max_conditions : int
        Only the labels with the largest median value are shown, drawn in
        decreasing order of that median.
    figsize : Tuple[int, int]
        Figure size.

    Returns
    -------
    Figure
        Matplotlib figure, or a "No data for <metric>" message figure.
    """
    fig, ax = plt.subplots(figsize=figsize)

    rows = []
    for split_name, split_result in self.results.splits.items():
        if split is not None and split_name != split:
            continue

        for cond_key, cond in split_result.conditions.items():
            if metric_name not in cond.metrics:
                continue
            per_gene = cond.metrics[metric_name].per_gene_values
            # NOTE(review): assumes per-gene values may be absent (None).
            if per_gene is None:
                continue
            label = cond.perturbation or cond_key[:15]
            rows.extend({"condition": label, "value": val} for val in per_gene)

    df = pd.DataFrame(rows)

    if df.empty:
        ax.text(0.5, 0.5, f"No data for {metric_name}", ha='center', va='center')
        return fig

    # Limit to top conditions by median; the ranking also sets the x order.
    top_conditions = (
        df.groupby("condition")["value"]
        .median()
        .nlargest(max_conditions)
        .index.tolist()
    )
    if not top_conditions:
        ax.text(0.5, 0.5, f"No data for {metric_name}", ha='center', va='center')
        return fig
    df = df[df["condition"].isin(top_conditions)]

    sns.violinplot(
        data=df,
        x="condition",
        y="value",
        order=top_conditions,
        ax=ax,
        color=PlotStyle.METRIC_PALETTE.get(metric_name, "#3498db"),
        inner="quartile",
    )

    ax.set_xlabel("Condition")
    ax.set_ylabel(metric_name)
    ax.set_title(f"Per-Gene {metric_name} Distribution")
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

    fig.tight_layout()
    return fig
|
|
379
|
+
|
|
380
|
+
# ==================== RADAR PLOTS ====================
|
|
381
|
+
|
|
382
|
+
def radar_plot(
    self,
    metrics: Optional[List[str]] = None,
    conditions: Optional[List[str]] = None,
    split: Optional[str] = None,
    max_conditions: int = 6,
    normalize: bool = True,
    figsize: Tuple[int, int] = PlotStyle.FIGURE_MEDIUM,
) -> Figure:
    """
    Create radar plot comparing multiple metrics across conditions.

    Each plotted row of the per-condition metric table becomes one closed
    polygon with one spoke per metric.

    Parameters
    ----------
    metrics : List[str], optional
        Metrics to include (at least 3 must be available).
    conditions : List[str], optional
        Conditions to compare, given as perturbation names or condition
        keys. If None, conditions are chosen via ``max_conditions``.
    split : str, optional
        Filter to specific split.
    max_conditions : int
        Used only when ``conditions`` is None. It keeps the perturbations of
        the ``max_conditions`` rows with the highest unweighted mean of the
        raw metric values.
        NOTE(review): that mean mixes similarity metrics with distance-type
        metrics (presumably lower-is-better), so "highest" is not
        necessarily "best".
    normalize : bool
        Whether to min-max scale each metric to [0, 1] over the plotted
        rows. Metrics that are constant across those rows are left as-is.
    figsize : Tuple[int, int]
        Figure size.

    Returns
    -------
    Figure
        Matplotlib figure with radar plot, or a message figure when there
        is nothing to plot.
    """

    def _message(text):
        # Plain (non-polar) figure carrying a centred explanation.
        fig, ax = plt.subplots(figsize=figsize)
        ax.text(0.5, 0.5, text, ha='center', va='center')
        return fig

    df = self._get_all_metrics_data(split)
    if df.empty:
        return _message("No data available")

    meta_cols = ["split", "condition", "perturbation"]
    available_metrics = [c for c in df.columns if c not in meta_cols]
    if metrics is not None:
        available_metrics = [m for m in metrics if m in available_metrics]

    if len(available_metrics) < 3:
        return _message("Need at least 3 metrics for radar plot")

    # Select conditions
    if conditions is not None:
        # Accept perturbation names and condition keys alike.
        wanted = df["perturbation"].isin(conditions) | df["condition"].isin(conditions)
    else:
        mean_score = df[available_metrics].mean(axis=1)
        top = df.loc[mean_score.nlargest(max_conditions).index, "perturbation"].unique()
        wanted = df["perturbation"].isin(top)
    # .copy(): the normalisation below writes into ``df``; writing into a
    # filtered slice triggers pandas' SettingWithCopyWarning.
    df = df[wanted].copy()

    if df.empty:
        return _message("No data available")

    if normalize:
        for m in available_metrics:
            col = df[m]
            min_val, max_val = col.min(), col.max()
            if max_val > min_val:
                df[m] = (col - min_val) / (max_val - min_val)

    # Set up radar chart: one angle per metric, first repeated to close it.
    n_metrics = len(available_metrics)
    angles = np.linspace(0, 2 * np.pi, n_metrics, endpoint=False).tolist()
    angles += angles[:1]

    fig, ax = plt.subplots(figsize=figsize, subplot_kw=dict(polar=True))

    # One distinct colour per plotted row. Several rows can share a
    # perturbation name (e.g. across splits), so size by rows, not names.
    colors = sns.color_palette("husl", len(df))

    for color, (_, row) in zip(colors, df.iterrows()):
        values = [row[m] for m in available_metrics]
        values += values[:1]  # Close the loop
        label = str(row["perturbation"])[:20]
        ax.plot(angles, values, 'o-', linewidth=2, label=label, color=color)
        ax.fill(angles, values, alpha=0.1, color=color)

    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(available_metrics, size=10)

    ax.set_title("Multi-Metric Comparison", size=14, y=1.1)
    ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.0))

    fig.tight_layout()
    return fig
|
|
476
|
+
|
|
477
|
+
def radar_split_comparison(
    self,
    metrics: Optional[List[str]] = None,
    normalize: bool = True,
    figsize: Tuple[int, int] = PlotStyle.FIGURE_MEDIUM,
) -> Figure:
    """
    Radar plot comparing aggregate metrics across splits.

    ``compute_aggregates()`` is called on every split result, which mutates
    it. Each aggregate entry whose key ends in ``"_mean"`` becomes one spoke,
    named by the key without that suffix.

    Parameters
    ----------
    metrics : List[str], optional
        Metrics to include, in the given order (at least 3 must be
        available). Unknown and repeated names are dropped.
    normalize : bool
        Min-max scale each metric to [0, 1] across splits. Metrics with the
        same value in every split (always the case with a single split) are
        left unscaled.
    figsize : Tuple[int, int]
        Figure size.

    Returns
    -------
    Figure
        Radar figure with one polygon per split, or a message figure.
    """
    # Collect aggregate "<metric>_mean" values per split.
    data = {}
    for split_name, split_result in self.results.splits.items():
        split_result.compute_aggregates()
        data[split_name] = {
            key[:-len("_mean")]: value
            for key, value in split_result.aggregate_metrics.items()
            if key.endswith("_mean")
        }

    if not data:
        fig, ax = plt.subplots(figsize=figsize)
        ax.text(0.5, 0.5, "No data available", ha='center', va='center')
        return fig

    # Rows = splits, columns = metrics.
    df = pd.DataFrame(data).T

    if metrics is not None:
        # .copy() so the in-place normalisation below works on an
        # independent frame (avoids SettingWithCopyWarning).
        df = df[[m for m in dict.fromkeys(metrics) if m in df.columns]].copy()

    available_metrics = df.columns.tolist()

    if len(available_metrics) < 3:
        fig, ax = plt.subplots(figsize=figsize)
        ax.text(0.5, 0.5, "Need at least 3 metrics for radar plot", ha='center', va='center')
        return fig

    if normalize:
        for col in available_metrics:
            min_val, max_val = df[col].min(), df[col].max()
            if max_val > min_val:
                df[col] = (df[col] - min_val) / (max_val - min_val)

    n_metrics = len(available_metrics)
    angles = np.linspace(0, 2 * np.pi, n_metrics, endpoint=False).tolist()
    angles += angles[:1]  # Close the loop

    fig, ax = plt.subplots(figsize=figsize, subplot_kw=dict(polar=True))

    # Fixed colours for the first four splits. Beyond that, extend with a
    # husl palette instead of cycling, so every split stays distinguishable.
    colors = [PlotStyle.REAL_COLOR, PlotStyle.GENERATED_COLOR, "#2ecc71", "#9b59b6"]
    if len(df) > len(colors):
        colors += list(sns.color_palette("husl", len(df) - len(colors)))

    for color, (split_name, row) in zip(colors, df.iterrows()):
        values = [row[m] for m in available_metrics]
        values += values[:1]
        ax.plot(angles, values, 'o-', linewidth=2, label=split_name, color=color)
        ax.fill(angles, values, alpha=0.15, color=color)

    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(available_metrics, size=10)
    ax.set_title("Split Comparison", size=14, y=1.1)
    ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.0))

    fig.tight_layout()
    return fig
|
|
544
|
+
|
|
545
|
+
# ==================== SCATTER PLOTS ====================
|
|
546
|
+
|
|
547
|
+
def scatter_real_vs_generated(
    self,
    condition: str,
    split: Optional[str] = None,
    figsize: Tuple[int, int] = PlotStyle.FIGURE_MEDIUM,
    alpha: float = 0.5,
) -> Figure:
    """
    Scatter plot of real vs generated mean expression.

    Draws one point per entry of the condition's ``real_mean`` /
    ``generated_mean`` arrays, a dashed y=x reference line, and the
    condition's Pearson r when that metric is present.

    Parameters
    ----------
    condition : str
        Condition key or perturbation name. Splits are searched in order and
        the first matching condition is used.
    split : str, optional
        Filter to specific split.
    figsize : Tuple[int, int]
        Figure size.
    alpha : float
        Point transparency.

    Returns
    -------
    Figure
        Scatter plot figure, or a message figure when the condition is not
        found or lacks real/generated means.
    """
    # Take the first match across *all* splits. A bare ``break`` in the
    # inner loop would let a later split overwrite the match.
    cond_result = next(
        (
            cond
            for split_name, split_result in self.results.splits.items()
            if split is None or split_name == split
            for cond_key, cond in split_result.conditions.items()
            if cond_key == condition or cond.perturbation == condition
        ),
        None,
    )

    fig, ax = plt.subplots(figsize=figsize)

    if (
        cond_result is None
        or cond_result.real_mean is None
        or cond_result.generated_mean is None
    ):
        ax.text(0.5, 0.5, f"No data for condition: {condition}", ha='center', va='center')
        return fig

    real_mean = np.asarray(cond_result.real_mean, dtype=float).ravel()
    gen_mean = np.asarray(cond_result.generated_mean, dtype=float).ravel()

    ax.scatter(real_mean, gen_mean, alpha=alpha, s=10, c=PlotStyle.REAL_COLOR)

    # Diagonal spanning the finite data range. NaN values would otherwise
    # make the limits NaN and hide the line.
    both = np.concatenate([real_mean, gen_mean])
    finite = both[np.isfinite(both)]
    if finite.size:
        lims = [finite.min(), finite.max()]
        ax.plot(lims, lims, 'k--', alpha=0.5, label='y=x')
        ax.legend(loc='lower right')

    if "pearson" in cond_result.metrics:
        r = cond_result.metrics["pearson"].aggregate_value
        ax.text(0.05, 0.95, f'r = {r:.3f}', transform=ax.transAxes,
                fontsize=12, verticalalignment='top')

    ax.set_xlabel("Real Mean Expression")
    ax.set_ylabel("Generated Mean Expression")
    ax.set_title(f"Real vs Generated: {cond_result.perturbation or condition}")

    fig.tight_layout()
    return fig
|
|
615
|
+
|
|
616
|
+
def scatter_grid(
    self,
    split: Optional[str] = None,
    max_conditions: int = 12,
    ncols: int = 4,
    figsize_per_panel: Tuple[float, float] = (4, 4),
) -> Figure:
    """
    Grid of scatter plots for multiple conditions.

    Each panel plots real vs generated mean expression for one condition,
    with a dashed y=x line and the Pearson r when available. Conditions are
    taken in split/condition iteration order, keeping only those with both
    ``real_mean`` and ``generated_mean``.

    Parameters
    ----------
    split : str, optional
        Filter to specific split.
    max_conditions : int
        Maximum number of panels.
    ncols : int
        Panels per row (clamped to at least 1 and at most the panel count).
    figsize_per_panel : Tuple[float, float]
        Size of one panel; the figure size scales with the grid.

    Returns
    -------
    Figure
        Grid figure, or a "No data available" message figure.
    """
    # (condition key, condition) pairs that have both mean vectors.
    panels = [
        (cond_key, cond)
        for split_name, split_result in self.results.splits.items()
        if split is None or split_name == split
        for cond_key, cond in split_result.conditions.items()
        if cond.real_mean is not None and cond.generated_mean is not None
    ][:max_conditions]
    n = len(panels)

    if n == 0:
        fig, ax = plt.subplots()
        ax.text(0.5, 0.5, "No data available", ha='center', va='center')
        return fig

    ncols = max(1, min(ncols, n))
    nrows = int(np.ceil(n / ncols))

    fig, axes = plt.subplots(
        nrows, ncols,
        figsize=(figsize_per_panel[0] * ncols, figsize_per_panel[1] * nrows),
        squeeze=False,
    )

    for i, (cond_key, cond) in enumerate(panels):
        ax = axes[i // ncols, i % ncols]

        real_mean = np.asarray(cond.real_mean, dtype=float).ravel()
        gen_mean = np.asarray(cond.generated_mean, dtype=float).ravel()

        ax.scatter(real_mean, gen_mean, alpha=0.4, s=5, c=PlotStyle.REAL_COLOR)

        # NaN-safe diagonal over the finite data range.
        both = np.concatenate([real_mean, gen_mean])
        finite = both[np.isfinite(both)]
        if finite.size:
            lims = [finite.min(), finite.max()]
            ax.plot(lims, lims, 'k--', alpha=0.3, linewidth=0.5)

        if "pearson" in cond.metrics:
            r = cond.metrics["pearson"].aggregate_value
            ax.text(0.05, 0.95, f'r={r:.2f}', transform=ax.transAxes,
                    fontsize=8, verticalalignment='top')

        # Label like the other methods: perturbation, else truncated key.
        ax.set_title(cond.perturbation or cond_key[:20], fontsize=9)
        ax.tick_params(labelsize=7)

    # Hide empty panels
    for j in range(n, nrows * ncols):
        axes[j // ncols, j % ncols].axis('off')

    fig.suptitle("Real vs Generated Expression", fontsize=12, y=1.02)
    fig.tight_layout()
    return fig
|
|
682
|
+
|
|
683
|
+
# ==================== HEATMAPS ====================
|
|
684
|
+
|
|
685
|
+
def heatmap_per_gene(
    self,
    metric_name: str,
    split: Optional[str] = None,
    max_genes: int = 50,
    max_conditions: int = 20,
    figsize: Tuple[int, int] = PlotStyle.FIGURE_LARGE,
    cmap: str = "RdYlBu_r",
) -> Figure:
    """
    Heatmap of per-gene metric values.

    Rows are genes and columns are conditions. Columns are labelled by
    perturbation name, falling back to the first 15 characters of the
    condition key. If two conditions would get the same label (e.g. the
    same perturbation in two splits), the later one is suffixed with its
    split name instead of silently overwriting the earlier column.

    Row labels come from the first matching condition's ``gene_names``.
    NOTE(review): assumes every condition reports per-gene values for the
    same genes in the same order; pandas raises ValueError if lengths differ.

    Parameters
    ----------
    metric_name : str
        Metric to visualize.
    split : str, optional
        Filter to specific split.
    max_genes : int
        Maximum genes to show (keeps the genes whose values vary most across
        the displayed conditions).
    max_conditions : int
        Maximum conditions to show (keeps the conditions whose values vary
        most across genes).
    figsize : Tuple[int, int]
        Figure size.
    cmap : str
        Colormap name.

    Returns
    -------
    Figure
        Heatmap figure, or a "No data for <metric>" message figure.
    """
    data = {}
    gene_names = None

    for split_name, split_result in self.results.splits.items():
        if split is not None and split_name != split:
            continue

        for cond_key, cond in split_result.conditions.items():
            if metric_name not in cond.metrics:
                continue
            metric_result = cond.metrics[metric_name]

            base_label = cond.perturbation or cond_key[:15]
            label = base_label
            if label in data:
                label = f"{base_label} ({split_name})"
            suffix = 2
            while label in data:
                label = f"{base_label} ({split_name}) #{suffix}"
                suffix += 1

            data[label] = metric_result.per_gene_values
            if gene_names is None:
                gene_names = metric_result.gene_names

    if not data:
        fig, ax = plt.subplots(figsize=figsize)
        ax.text(0.5, 0.5, f"No data for {metric_name}", ha='center', va='center')
        return fig

    df = pd.DataFrame(data, index=gene_names)

    # Keep the conditions with the highest variance across genes.
    if df.shape[1] > max_conditions:
        top_conds = df.var().nlargest(max_conditions).index
        df = df[top_conds]

    # Keep the genes with the highest variance across the kept conditions.
    if df.shape[0] > max_genes:
        top_genes = df.var(axis=1).nlargest(max_genes).index
        df = df.loc[top_genes]

    fig, ax = plt.subplots(figsize=figsize)

    sns.heatmap(
        df,
        ax=ax,
        cmap=cmap,
        xticklabels=True,
        yticklabels=True,
        cbar_kws={"label": metric_name},
    )

    ax.set_xlabel("Condition")
    ax.set_ylabel("Gene")
    ax.set_title(f"Per-Gene {metric_name}")

    # Style this heatmap's ticks directly instead of pyplot's current axes.
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    ax.tick_params(axis='y', labelsize=8)

    fig.tight_layout()
    return fig
|
|
772
|
+
|
|
773
|
+
def heatmap_metrics_summary(
    self,
    split: Optional[str] = None,
    figsize: Tuple[int, int] = PlotStyle.FIGURE_MEDIUM,
    cmap: str = "RdYlBu",
) -> Figure:
    """
    Heatmap summarizing all metrics across conditions.

    Rows are conditions (labelled by perturbation) and columns are metrics.
    Each metric column is z-scored across conditions so that metrics on
    different scales share one colour map centred at 0. A column with zero
    spread (constant values) or undefined spread (a single condition) is
    shown as 0 rather than inf/NaN.

    Parameters
    ----------
    split : str, optional
        Filter to specific split.
    figsize : Tuple[int, int]
        Figure size.
    cmap : str
        Colormap name.

    Returns
    -------
    Figure
        Heatmap figure, or a "No data available" message figure.
    """
    df = self._get_all_metrics_data(split)

    meta_cols = ["split", "condition", "perturbation"]
    metric_cols = [c for c in df.columns if c not in meta_cols]

    fig, ax = plt.subplots(figsize=figsize)

    if df.empty or not metric_cols:
        ax.text(0.5, 0.5, "No data available", ha='center', va='center')
        return fig

    df_pivot = df.set_index("perturbation")[metric_cols]

    # Z-score per column. A std of 0 (constant column) or NaN (single row)
    # is replaced by 1, so those columns become 0 instead of inf/NaN.
    std = df_pivot.std().replace(0, 1).fillna(1)
    df_norm = (df_pivot - df_pivot.mean()) / std

    sns.heatmap(
        df_norm,
        ax=ax,
        cmap=cmap,
        xticklabels=True,
        yticklabels=True,
        center=0,
        cbar_kws={"label": "Z-score"},
    )

    ax.set_xlabel("Metric")
    ax.set_ylabel("Condition")
    ax.set_title("Metrics Summary (Z-scored)")

    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

    fig.tight_layout()
    return fig
|
|
818
|
+
|
|
819
|
+
# ==================== EMBEDDING PLOTS ====================
|
|
820
|
+
|
|
821
|
+
def embedding_plot(
    self,
    data_loader: "GeneExpressionDataLoader",
    method: str = "pca",
    split: Optional[str] = None,
    max_samples: int = 5000,
    figsize: Tuple[int, int] = PlotStyle.FIGURE_MEDIUM,
    alpha: float = 0.6,
    random_state: Optional[int] = None,
) -> Figure:
    """
    Plot embedded data (PCA or UMAP) comparing real and generated.

    Real and generated cells are pooled, embedded jointly, and drawn in
    two colours so their overlap in the shared embedding is visible.

    Parameters
    ----------
    data_loader : GeneExpressionDataLoader
        Data loader exposing ``real`` and ``generated`` AnnData objects.
    method : str
        Embedding method: "pca" or "umap" (case-insensitive).
    split : str, optional
        If given and ``data_loader.split_column`` is set, keep only cells
        whose split value (compared as a string) equals ``split``. The
        filter is applied to each dataset that has the split column.
    max_samples : int
        Maximum total cells to plot. Each source contributes at most
        ``max_samples // 2`` randomly chosen cells.
    figsize : Tuple[int, int]
        Figure size
    alpha : float
        Point transparency
    random_state : int, optional
        Seed for the subsampling. ``None`` draws from NumPy's global
        random state, as before.

    Returns
    -------
    Figure
        Embedding plot figure

    Raises
    ------
    ImportError
        If scanpy is not installed.
    ValueError
        If ``method`` is unknown, no cells remain after filtering, or too
        few cells/genes are available for a 2-D embedding.
    """
    try:
        import scanpy as sc
    except ImportError:
        raise ImportError("scanpy is required for embedding plots")

    method_key = method.lower()
    if method_key not in ("pca", "umap"):
        raise ValueError(
            f"Unknown embedding method {method!r}; expected 'pca' or 'umap'"
        )

    # Work on (views of) the loader's data and copy only after filtering and
    # subsampling, so the full datasets are never duplicated in memory.
    real = data_loader.real
    gen = data_loader.generated

    # Apply the split filter to every dataset that carries the split column,
    # so real and generated cells come from the same split.
    split_col = data_loader.split_column
    if split is not None and split_col is not None:
        if split_col in real.obs.columns:
            real = real[(real.obs[split_col].astype(str) == split).to_numpy()]
        if split_col in gen.obs.columns:
            gen = gen[(gen.obs[split_col].astype(str) == split).to_numpy()]

    if real.n_obs == 0 or gen.n_obs == 0:
        raise ValueError(
            f"No cells left to embed (real={real.n_obs}, generated={gen.n_obs}); "
            "check the split filter"
        )

    # Subsample each source to at most half of the budget.
    rng = np.random if random_state is None else np.random.default_rng(random_state)
    per_source = max(1, max_samples // 2)
    if real.n_obs > per_source:
        real = real[np.sort(rng.choice(real.n_obs, per_source, replace=False))]
    if gen.n_obs > per_source:
        gen = gen[np.sort(rng.choice(gen.n_obs, per_source, replace=False))]
    real = real.copy()
    gen = gen.copy()

    # Add source label
    real.obs["_source"] = "Real"
    gen.obs["_source"] = "Generated"

    combined = real.concatenate(gen, batch_key="_batch")

    # scanpy's PCA requires n_comps < min(n_obs, n_vars), so the default of
    # 50 is capped for small inputs. At least 2 components are needed to plot.
    n_comps = min(50, combined.n_obs - 1, combined.n_vars - 1)
    if n_comps < 2:
        raise ValueError(
            "Need at least 3 cells and 3 genes for a 2-D embedding "
            f"(got {combined.n_obs} cells, {combined.n_vars} genes)"
        )
    sc.pp.pca(combined, n_comps=n_comps)

    if method_key == "umap":
        sc.pp.neighbors(combined)
        sc.tl.umap(combined)
        x_key = "X_umap"
        x_label, y_label = "UMAP1", "UMAP2"
    else:
        x_key = "X_pca"
        x_label, y_label = "PC1", "PC2"

    fig, ax = plt.subplots(figsize=figsize)

    coords = combined.obsm[x_key][:, :2]
    source = combined.obs["_source"].to_numpy()

    for label, color in [("Real", PlotStyle.REAL_COLOR), ("Generated", PlotStyle.GENERATED_COLOR)]:
        mask = source == label
        # `color=` (not `c=`) so an RGB tuple is never mistaken for
        # per-point values when a group happens to have 3 or 4 points.
        ax.scatter(
            coords[mask, 0], coords[mask, 1],
            color=color, label=label, alpha=alpha, s=10
        )

    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_title(f"{method_key.upper()} Embedding: Real vs Generated")
    ax.legend()

    fig.tight_layout()
    return fig
|
|
915
|
+
|
|
916
|
+
def embedding_by_condition(
    self,
    data_loader: "GeneExpressionDataLoader",
    method: str = "pca",
    condition_column: Optional[str] = None,
    max_samples: int = 5000,
    figsize: Tuple[int, int] = PlotStyle.FIGURE_LARGE,
    alpha: float = 0.6,
    random_state: Optional[int] = None,
) -> Figure:
    """
    Embedding plot colored by condition, real and generated side by side.

    Real and generated cells are pooled and embedded jointly, so both
    panels share one coordinate system. Each condition is assigned a
    single colour used in both panels. The shared legend lists every
    condition that appears in either panel.

    Parameters
    ----------
    data_loader : GeneExpressionDataLoader
        Data loader exposing ``real`` and ``generated`` AnnData objects.
    method : str
        Embedding method: "pca" or "umap" (case-insensitive).
    condition_column : str, optional
        ``obs`` column used for colouring. Defaults to the first entry of
        ``data_loader.condition_columns``.
    max_samples : int
        Maximum total cells to plot. Each source contributes at most
        ``max_samples // 2`` randomly chosen cells.
    figsize : Tuple[int, int]
        Figure size
    alpha : float
        Point transparency
    random_state : int, optional
        Seed for the subsampling. ``None`` draws from NumPy's global
        random state, as before.

    Returns
    -------
    Figure
        Figure with one panel for real and one for generated data.

    Raises
    ------
    ImportError
        If scanpy is not installed.
    ValueError
        If ``method`` is unknown, no condition column is available, or too
        few cells/genes are available for a 2-D embedding.
    KeyError
        If ``condition_column`` is missing from both datasets.
    """
    try:
        import scanpy as sc
    except ImportError:
        raise ImportError("scanpy is required for embedding plots")

    method_key = method.lower()
    if method_key not in ("pca", "umap"):
        raise ValueError(
            f"Unknown embedding method {method!r}; expected 'pca' or 'umap'"
        )

    # Use first condition column if not specified
    if condition_column is None:
        if not data_loader.condition_columns:
            raise ValueError(
                "No condition_column given and data_loader has no condition_columns"
            )
        condition_column = data_loader.condition_columns[0]

    real = data_loader.real
    gen = data_loader.generated

    # Fail early, before the expensive embedding, if the column is absent
    # everywhere. A column present in only one dataset is kept: the other
    # side shows as "nan" after concatenation.
    if condition_column not in real.obs.columns and condition_column not in gen.obs.columns:
        raise KeyError(
            f"Condition column {condition_column!r} not found in real or generated data"
        )

    # Subsample on views and copy afterwards to avoid duplicating full datasets.
    rng = np.random if random_state is None else np.random.default_rng(random_state)
    per_source = max(1, max_samples // 2)
    if real.n_obs > per_source:
        real = real[np.sort(rng.choice(real.n_obs, per_source, replace=False))]
    if gen.n_obs > per_source:
        gen = gen[np.sort(rng.choice(gen.n_obs, per_source, replace=False))]
    real = real.copy()
    gen = gen.copy()

    real.obs["_source"] = "Real"
    gen.obs["_source"] = "Generated"

    combined = real.concatenate(gen, batch_key="_batch")

    # scanpy's PCA requires n_comps < min(n_obs, n_vars), so the default of
    # 50 is capped for small inputs.
    n_comps = min(50, combined.n_obs - 1, combined.n_vars - 1)
    if n_comps < 2:
        raise ValueError(
            "Need at least 3 cells and 3 genes for a 2-D embedding "
            f"(got {combined.n_obs} cells, {combined.n_vars} genes)"
        )
    sc.pp.pca(combined, n_comps=n_comps)

    if method_key == "umap":
        sc.pp.neighbors(combined)
        sc.tl.umap(combined)
        x_key = "X_umap"
        x_label, y_label = "UMAP1", "UMAP2"
    else:
        x_key = "X_pca"
        x_label, y_label = "PC1", "PC2"

    coords = combined.obsm[x_key][:, :2]
    sources = combined.obs["_source"].to_numpy()
    conditions = combined.obs[condition_column].astype(str).to_numpy()

    # Build ONE colour map over the union of conditions (in order of first
    # appearance), so a condition has the same colour in both panels.
    unique_conds = list(dict.fromkeys(conditions))
    palette = sns.color_palette("husl", len(unique_conds))
    color_map = dict(zip(unique_conds, palette))

    # Create side-by-side plot
    fig, axes = plt.subplots(1, 2, figsize=figsize)

    # First scatter handle seen for each condition, across both panels.
    legend_handles = {}
    for ax, source in zip(axes, ["Real", "Generated"]):
        in_source = sources == source
        for cond in unique_conds:
            cond_mask = in_source & (conditions == cond)
            if not cond_mask.any():
                continue
            handle = ax.scatter(
                coords[cond_mask, 0], coords[cond_mask, 1],
                c=[color_map[cond]], label=cond[:15], alpha=alpha, s=10
            )
            legend_handles.setdefault(cond, handle)

        ax.set_title(f"{source}")
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)

    # The shared legend covers conditions from both panels, not just the first.
    ordered = [c for c in unique_conds if c in legend_handles]
    fig.legend(
        [legend_handles[c] for c in ordered],
        [c[:15] for c in ordered],
        loc='center right',
        bbox_to_anchor=(1.15, 0.5),
    )

    fig.suptitle(f"{method.upper()} Embedding by {condition_column}", fontsize=12)
    fig.tight_layout()
    return fig
|
|
993
|
+
|
|
994
|
+
# ==================== SAVE ALL ====================
|
|
995
|
+
|
|
996
|
+
def save_all(
    self,
    output_dir: Union[str, Path],
    formats: Optional[List[str]] = None,
    data_loader: Optional["GeneExpressionDataLoader"] = None,
):
    """
    Generate and save all plots.

    Each plot is attempted independently.
    - A summary plot that fails to build is reported via ``warnings.warn``
      and the others are still produced.
    - Per-gene violin plots are skipped silently, since those metrics may
      not have been computed.
    - Every generated figure is closed afterwards, even if saving one of
      them raises.

    Parameters
    ----------
    output_dir : str or Path
        Directory to save plots; created (with parents) if missing.
    formats : List[str], optional
        Image formats (file extensions) to save. Defaults to
        ``["png", "pdf"]``. A single string is treated as one format.
    data_loader : GeneExpressionDataLoader, optional
        If provided, also generates PCA and UMAP embedding plots
    """
    # None default instead of a shared mutable list default.
    if formats is None:
        formats = ["png", "pdf"]
    elif isinstance(formats, str):
        formats = [formats]

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    plots = {}

    def _attempt(key, label, build):
        """Run one plot builder, storing its figure or warning on failure."""
        try:
            plots[key] = build()
        except Exception as e:  # best-effort: one broken plot must not abort the rest
            warnings.warn(f"Failed to generate {label}: {e}")

    _attempt("boxplot_metrics", "boxplot_metrics", lambda: self.boxplot_metrics())
    _attempt("violin_metrics", "violin_metrics", lambda: self.violin_metrics())
    _attempt("radar_split", "radar_split", lambda: self.radar_split_comparison())
    _attempt("scatter_grid", "scatter_grid", lambda: self.scatter_grid())
    _attempt("heatmap_summary", "heatmap_summary", lambda: self.heatmap_metrics_summary())

    # Per-metric plots. A metric may simply be absent from the results, so
    # failures here are skipped without a warning.
    for metric_name in ["pearson", "wasserstein_1", "mmd"]:
        try:
            plots[f"violin_{metric_name}"] = self.violin_per_gene(metric_name)
        except Exception:
            continue

    # Embedding plots if data loader provided
    if data_loader is not None:
        _attempt(
            "embedding_pca", "PCA embedding",
            lambda: self.embedding_plot(data_loader, method="pca"),
        )
        _attempt(
            "embedding_umap", "UMAP embedding",
            lambda: self.embedding_plot(data_loader, method="umap"),
        )

    # Save all plots, and always release the figures, even if a save fails part-way.
    try:
        for name, fig in plots.items():
            for fmt in formats:
                fig.savefig(output_dir / f"{name}.{fmt}", dpi=self.dpi, bbox_inches='tight')
    finally:
        for fig in plots.values():
            plt.close(fig)

    print(f"Saved {len(plots)} plots to {output_dir}")
|
|
1072
|
+
|
|
1073
|
+
|
|
1074
|
+
# Convenience function
|
|
1075
|
+
def visualize(
    results: "EvaluationResult",
    output_dir: Union[str, Path],
    data_loader: Optional["GeneExpressionDataLoader"] = None,
    **kwargs
):
    """
    One-call helper: render every evaluation plot and write it to disk.

    Builds an ``EvaluationVisualizer`` from ``results``, forwarding any extra
    keyword arguments to its constructor, and then calls its ``save_all``
    method.

    Parameters
    ----------
    results : EvaluationResult
        Evaluation results to visualize
    output_dir : str or Path
        Destination directory for the saved figures
    data_loader : GeneExpressionDataLoader, optional
        When given, embedding plots are produced as well
    **kwargs
        Forwarded unchanged to ``EvaluationVisualizer``
    """
    EvaluationVisualizer(results, **kwargs).save_all(
        output_dir,
        data_loader=data_loader,
    )
|