gengeneeval 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,376 @@
1
+ """
2
+ Visualization module for DEG evaluation results.
3
+
4
+ Provides:
5
+ - Distribution plots for metrics across contexts
6
+ - Heatmaps for context × metric results
7
+ - Comprehensive DEG reports
8
+ """
9
+ from __future__ import annotations
10
+
11
+ from typing import Optional, List, Dict, Union, Any, TYPE_CHECKING
12
+ from pathlib import Path
13
+ import numpy as np
14
+
15
+ if TYPE_CHECKING:
16
+ import matplotlib.pyplot as plt
17
+ import pandas as pd
18
+ from .evaluator import DEGEvaluationResult
19
+
20
+
21
+ def _check_matplotlib():
22
+ """Check if matplotlib is available."""
23
+ try:
24
+ import matplotlib.pyplot as plt
25
+ return plt
26
+ except ImportError:
27
+ raise ImportError("matplotlib is required for visualization. Install with: pip install matplotlib")
28
+
29
+
30
+ def _check_seaborn():
31
+ """Check if seaborn is available."""
32
+ try:
33
+ import seaborn as sns
34
+ return sns
35
+ except ImportError:
36
+ return None # Seaborn is optional
37
+
38
+
39
def plot_deg_distributions(
    results: "DEGEvaluationResult",
    metrics: Optional[List[str]] = None,
    figsize: tuple = (12, 4),
    save_path: Optional[Union[str, Path]] = None,
    show: bool = True,
) -> "plt.Figure":
    """
    Plot distribution of metrics across contexts.

    Draws one panel per metric, all in a single row. Each panel is a violin
    plot with an inner box when seaborn is installed, otherwise a matplotlib
    box plot. A dashed red line marks the mean of the metric.

    Parameters
    ----------
    results : DEGEvaluationResult
        DEG evaluation results. Reads ``results.expanded_metrics`` (one row per
        context). When ``metrics`` is None, also reads
        ``results.settings["metrics"]``.
    metrics : List[str], optional
        Metric columns to plot. If None, plots every column of
        ``expanded_metrics`` that is listed in ``results.settings["metrics"]``.
    figsize : tuple
        Overall figure size ``(width, height)``, shared by all metric panels.
    save_path : str or Path, optional
        Path to save the figure (150 dpi, tight bounding box).
    show : bool
        Whether to display the figure with ``plt.show()``.

    Returns
    -------
    plt.Figure
        Matplotlib figure

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    ValueError
        If there are no metrics to plot.
    """
    plt = _check_matplotlib()
    sns = _check_seaborn()

    df = results.expanded_metrics

    if metrics is None:
        # Keep only the columns the evaluator reports as metrics (skip ids, counts, ...).
        metrics = [c for c in df.columns if c in results.settings.get("metrics", [])]

    if not metrics:
        raise ValueError("No metrics to plot")

    # squeeze=False always returns a 2-D array of axes, so one metric needs no special case.
    fig, axes = plt.subplots(1, len(metrics), figsize=figsize, squeeze=False)

    for ax, metric in zip(axes[0], metrics):
        values = df[metric].dropna()

        if values.empty:
            # All values missing: a violin/box and a NaN mean line would be meaningless.
            ax.text(0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes)
        else:
            if sns is not None:
                sns.violinplot(y=values, ax=ax, inner="box")
            else:
                ax.boxplot(values)

            # Add mean annotation
            mean_val = values.mean()
            ax.axhline(mean_val, color='red', linestyle='--', alpha=0.5, label=f'Mean: {mean_val:.3f}')
            ax.legend(loc='upper right', fontsize=8)

        # Set labels after drawing so seaborn's automatic y-label is overridden.
        ax.set_title(metric)
        ax.set_ylabel("Value")

    fig.suptitle("Metric Distributions Across Contexts (DEGs only)", y=1.02)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')

    if show:
        plt.show()

    return fig
113
+
114
+
115
def plot_context_heatmap(
    results: "DEGEvaluationResult",
    metrics: Optional[List[str]] = None,
    figsize: tuple = (12, 8),
    cmap: str = "RdYlBu_r",
    save_path: Optional[Union[str, Path]] = None,
    show: bool = True,
) -> "plt.Figure":
    """
    Plot heatmap of metrics across contexts.

    Rows are contexts (indexed by ``context_id``) and columns are metrics. An
    annotated seaborn heatmap is drawn when seaborn is installed, otherwise
    ``Axes.imshow``. All metrics share a single colour scale.

    Parameters
    ----------
    results : DEGEvaluationResult
        DEG evaluation results. Reads ``results.expanded_metrics``, which must
        contain a ``context_id`` column. When ``metrics`` is None, also reads
        ``results.settings["metrics"]``.
    metrics : List[str], optional
        Metric columns to include. If None, uses every column of
        ``expanded_metrics`` listed in ``results.settings["metrics"]``.
    figsize : tuple
        Figure size
    cmap : str
        Colormap name
    save_path : str or Path, optional
        Path to save the figure (150 dpi, tight bounding box).
    show : bool
        Whether to display the figure with ``plt.show()``.

    Returns
    -------
    plt.Figure
        Matplotlib figure

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    ValueError
        If there are no metrics to plot.
    """
    plt = _check_matplotlib()
    sns = _check_seaborn()

    df = results.expanded_metrics

    if metrics is None:
        metrics = [c for c in df.columns if c in results.settings.get("metrics", [])]

    # Fail early with a clear message instead of an opaque error from plotting
    # an empty matrix.
    if not metrics:
        raise ValueError("No metrics to plot")

    # Contexts x metrics matrix for the heatmap
    heatmap_data = df.set_index("context_id")[metrics]

    fig, ax = plt.subplots(figsize=figsize)

    if sns is not None:
        sns.heatmap(
            heatmap_data,
            ax=ax,
            cmap=cmap,
            annot=True,
            fmt=".3f",
            cbar_kws={"label": "Metric Value"},
        )
    else:
        im = ax.imshow(heatmap_data.values, cmap=cmap, aspect='auto')
        # Attach the colorbar to this figure explicitly, not to pyplot's "current" one.
        fig.colorbar(im, ax=ax, label="Metric Value")
        ax.set_xticks(range(len(metrics)))
        ax.set_xticklabels(metrics, rotation=45, ha='right')
        ax.set_yticks(range(len(heatmap_data)))
        ax.set_yticklabels(heatmap_data.index)

    ax.set_title("Metrics per Context (DEGs only)")
    ax.set_xlabel("Metric")
    ax.set_ylabel("Context")

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')

    if show:
        plt.show()

    return fig
189
+
190
+
191
def plot_deg_counts(
    results: "DEGEvaluationResult",
    figsize: tuple = (10, 6),
    save_path: Optional[Union[str, Path]] = None,
    show: bool = True,
) -> "plt.Figure":
    """
    Plot DEG counts per context.

    If ``deg_summary`` has both ``n_upregulated`` and ``n_downregulated``
    columns, upregulated counts are drawn as red bars above zero and
    downregulated counts as blue bars below zero. Otherwise a single purple
    bar per context shows ``n_degs``.

    Parameters
    ----------
    results : DEGEvaluationResult
        DEG evaluation results. Reads ``results.deg_summary``, which must contain
        a ``context_id`` column and either the up/down columns or ``n_degs``.
    figsize : tuple
        Figure size
    save_path : str or Path, optional
        Path to save the figure (150 dpi, tight bounding box).
    show : bool
        Whether to display the figure with ``plt.show()``.

    Returns
    -------
    plt.Figure
        Matplotlib figure

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    ValueError
        If ``deg_summary`` is empty.
    """
    plt = _check_matplotlib()

    df = results.deg_summary

    if len(df) == 0:
        raise ValueError("No DEG data to plot")

    fig, ax = plt.subplots(figsize=figsize)

    x = range(len(df))
    width = 0.35

    has_direction = "n_upregulated" in df.columns and "n_downregulated" in df.columns
    if has_direction:
        ax.bar(x, df["n_upregulated"], width, label='Upregulated', color='red', alpha=0.7)
        # Negate so downregulated bars hang below the zero line.
        ax.bar(x, -df["n_downregulated"], width, label='Downregulated', color='blue', alpha=0.7)
        ax.axhline(0, color='black', linewidth=0.5)
    else:
        ax.bar(x, df["n_degs"], color='purple', alpha=0.7)
    ax.set_ylabel("Number of DEGs")

    ax.set_xlabel("Context")
    ax.set_xticks(x)
    ax.set_xticklabels(df["context_id"], rotation=45, ha='right')
    ax.set_title("DEG Counts per Context")
    # Only the up/down branch has labelled bars; a legend with no labelled
    # artists would just emit a matplotlib warning.
    if has_direction:
        ax.legend()

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')

    if show:
        plt.show()

    return fig
252
+
253
+
254
def _format_report_value(value: Any) -> str:
    """Format one aggregated-metric value for the Markdown report."""
    # np.floating also covers float32/float16, which are not subclasses of float.
    if isinstance(value, (float, np.floating)):
        return f"{value:.4f}"
    return f"{value}"


def _dataframe_to_markdown(df: "pd.DataFrame") -> str:
    """Render df as a Markdown table, or as a preformatted block if tabulate is missing."""
    try:
        return df.to_markdown(index=False)
    except ImportError:
        # DataFrame.to_markdown depends on the optional ``tabulate`` package.
        return "```\n" + df.to_string(index=False) + "\n```"


def _generate_report_plots(results: "DEGEvaluationResult", output_dir: Path) -> List[tuple]:
    """Write the report figures into output_dir; return (filename, description) for each file written."""
    import warnings

    try:
        plt = _check_matplotlib()
    except ImportError:
        return []  # Plots are optional; skip them when matplotlib is unavailable.

    has_metrics = len(results.expanded_metrics) > 0
    has_degs = len(results.deg_summary) > 0
    jobs = [
        (has_metrics, plot_deg_distributions, "deg_metric_distributions.png", "Distribution plots"),
        (has_metrics, plot_context_heatmap, "deg_context_heatmap.png", "Context × metric heatmap"),
        (has_degs, plot_deg_counts, "deg_counts.png", "DEG counts per context"),
    ]

    written = []
    for enabled, plot_fn, filename, description in jobs:
        if not enabled:
            continue
        try:
            fig = plot_fn(results, save_path=output_dir / filename, show=False)
        except ValueError as err:
            # e.g. no metrics configured: skip this figure but still build the report.
            warnings.warn(f"Skipping {filename}: {err}")
            continue
        # Close the figure we created, not whatever pyplot considers "current".
        plt.close(fig)
        written.append((filename, description))
    return written


def create_deg_report(
    results: "DEGEvaluationResult",
    output_dir: Union[str, Path],
    include_plots: bool = True,
) -> None:
    """
    Create a comprehensive DEG evaluation report.

    Creates:
    - CSV outputs written by ``results.save(output_dir)``
    - Visualization plots (PNG, if include_plots=True and matplotlib is installed)
    - Markdown report ``DEG_REPORT.md`` (UTF-8), listing only the plot files
      that were actually written

    Parameters
    ----------
    results : DEGEvaluationResult
        DEG evaluation results
    output_dir : str or Path
        Output directory (created if missing)
    include_plots : bool
        Whether to generate plots. A plot that cannot be drawn (e.g. no metrics
        configured) is skipped with a warning instead of aborting the report.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Save CSVs
    results.save(output_dir)

    # Generate plots; remember which files really exist so the report lists only those.
    generated_plots = _generate_report_plots(results, output_dir) if include_plots else []

    settings = results.settings
    report_lines = [
        "# DEG Evaluation Report",
        "",
        "## Summary",
        "",
        f"- **Number of contexts evaluated**: {len(results.context_results)}",
        f"- **DEG detection method**: {settings.get('deg_method', 'N/A')}",
        f"- **P-value threshold**: {settings.get('pval_threshold', 'N/A')}",
        f"- **Log fold change threshold**: {settings.get('lfc_threshold', 'N/A')}",
        f"- **Compute device**: {settings.get('device', 'N/A')}",
        "",
        "## Aggregated Metrics",
        "",
    ]

    aggregated = results.aggregated_metrics
    if len(aggregated) > 0:
        # Only the first row is reported. Index per column (not via a row Series)
        # so each value keeps its own column dtype.
        for col in aggregated.columns:
            value = aggregated[col].iloc[0]
            report_lines.append(f"- **{col}**: {_format_report_value(value)}")
    else:
        report_lines.append("No aggregated metrics available.")

    report_lines.extend([
        "",
        "## DEG Summary",
        "",
    ])

    if len(results.deg_summary) > 0:
        report_lines.append(_dataframe_to_markdown(results.deg_summary))
    else:
        report_lines.append("No DEG data available.")

    report_lines.extend([
        "",
        "## Files Generated",
        "",
        "- `deg_aggregated_metrics.csv`: Summary metrics across all contexts",
        "- `deg_expanded_metrics.csv`: Per-context metrics",
        "- `deg_summary.csv`: DEG detection summary",
        "- `deg_per_context/`: Individual DEG results per context",
    ])
    report_lines.extend(f"- `{name}`: {description}" for name, description in generated_plots)

    # Explicit UTF-8: the report contains non-ASCII text (e.g. "×", context ids).
    with open(output_dir / "DEG_REPORT.md", "w", encoding="utf-8") as f:
        f.write("\n".join(report_lines))

    print(f"DEG report saved to {output_dir}")
@@ -1,10 +1,10 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: gengeneeval
3
- Version: 0.3.0
4
- Summary: Comprehensive evaluation of generated gene expression data. Computes metrics between real and generated datasets with support for condition matching, train/test splits, memory-efficient lazy loading, CPU parallelization, GPU acceleration, and publication-quality visualizations.
3
+ Version: 0.4.0
4
+ Summary: Comprehensive evaluation of generated gene expression data. Computes metrics between real and generated datasets with support for condition matching, DEG-focused evaluation, per-context analysis, train/test splits, memory-efficient lazy loading, CPU parallelization, GPU acceleration, and publication-quality visualizations.
5
5
  License: MIT
6
6
  License-File: LICENSE
7
- Keywords: gene expression,evaluation,metrics,single-cell,generative models,benchmarking,memory-efficient
7
+ Keywords: gene expression,evaluation,metrics,single-cell,generative models,benchmarking,memory-efficient,DEG,perturbation
8
8
  Author: GenEval Team
9
9
  Author-email: geneval@example.com
10
10
  Requires-Python: >=3.8,<4.0
@@ -78,6 +78,8 @@ All metrics are computed **per-gene** (returning a vector) and **aggregated**:
78
78
  - ✅ Condition-based matching (perturbation, cell type, etc.)
79
79
  - ✅ Train/test split support
80
80
  - ✅ Per-gene and aggregate metrics
81
+ - ✅ **DEG-focused evaluation** with per-context (covariate × perturbation) support
82
+ - ✅ **Fast DEG detection** via vectorized Welch's t-test, Student's t-test, or Wilcoxon rank-sum test
81
83
  - ✅ **Memory-efficient lazy loading** for large datasets
82
84
  - ✅ **Batched evaluation** to avoid OOM errors
83
85
  - ✅ **CPU parallelization** via joblib (multi-core speedup)
@@ -247,6 +249,91 @@ print(f"MMD: {results['mmd'].aggregate_value:.4f}")
247
249
  | `device="mps"` | 5-20x | Apple Silicon Macs |
248
250
  | Vectorized NumPy | 2-5x | Automatic fallback |
249
251
 
252
+ ### DEG-Focused Evaluation
253
+
254
+ GenEval supports **Differentially Expressed Genes (DEG)-focused evaluation**, computing metrics only on biologically relevant DEGs rather than all genes. This provides more meaningful evaluation for perturbation prediction tasks.
255
+
256
+ #### Key Features
257
+
258
+ - **Fast DEG detection**: Vectorized Welch's t-test, Student's t-test, or Wilcoxon rank-sum
259
+ - **Per-context evaluation**: Automatically evaluates each (covariate × perturbation) combination
260
+ - **GPU acceleration**: DEG detection and metrics on GPU for large datasets
261
+ - **Comprehensive reporting**: Aggregated and expanded results with visualizations
262
+
263
+ #### Quick Start
264
+
265
+ ```python
266
+ from geneval import evaluate_degs
267
+ import pandas as pd
268
+
269
+ # Evaluate with DEG-focused metrics
270
+ results = evaluate_degs(
271
+ real_data=real_adata.X, # (n_samples, n_genes)
272
+ generated_data=gen_adata.X,
273
+ real_obs=real_adata.obs,
274
+ generated_obs=gen_adata.obs,
275
+ condition_columns=["cell_type", "perturbation"], # Context columns
276
+ control_key="control", # Value indicating control samples
277
+ perturbation_column="perturbation",
278
+ deg_method="welch", # or "student", "wilcoxon", "logfc"
279
+ pval_threshold=0.05,
280
+ lfc_threshold=0.5, # log2 fold change threshold
281
+ device="cuda", # GPU acceleration
282
+ )
283
+
284
+ # Access results
285
+ print(results.aggregated_metrics) # Summary across all contexts
286
+ print(results.expanded_metrics) # Per-context metrics
287
+ print(results.deg_summary) # DEG counts per context
288
+
289
+ # Save results (CSV files) to disk; use create_deg_report for plots and a Markdown report
290
+ results.save("deg_evaluation/")
291
+ ```
292
+
293
+ #### Per-Context Evaluation
294
+
295
+ When multiple condition columns are provided (e.g., `["cell_type", "perturbation"]`), GenEval evaluates **every combination** separately:
296
+
297
+ | Context | n_DEGs | W1 (DEGs only) | MMD (DEGs only) |
298
+ |---------|--------|----------------|-----------------|
299
+ | TypeA_drug1 | 234 | 0.42 | 0.031 |
300
+ | TypeA_drug2 | 189 | 0.38 | 0.027 |
301
+ | TypeB_drug1 | 312 | 0.51 | 0.045 |
302
+
303
+ If only the `perturbation` column is provided, evaluation is done per perturbation.
304
+
305
+ #### Available DEG Methods
306
+
307
+ | Method | Description | Speed |
308
+ |--------|-------------|-------|
309
+ | `welch` | Welch's t-test (unequal variance) | ⚡ Fast |
310
+ | `student` | Student's t-test (equal variance) | ⚡ Fast |
311
+ | `wilcoxon` | Wilcoxon rank-sum (non-parametric) | 🐢 Slower |
312
+ | `logfc` | Log fold change only (no p-value) | ⚡⚡ Fastest |
313
+
314
+ #### Visualization
315
+
316
+ ```python
317
+ from geneval.deg import (
318
+ plot_deg_distributions,
319
+ plot_context_heatmap,
320
+ plot_deg_counts,
321
+ create_deg_report,
322
+ )
323
+
324
+ # Distribution of metrics across contexts
325
+ plot_deg_distributions(results, save_path="dist.png")
326
+
327
+ # Heatmap: context × metric
328
+ plot_context_heatmap(results, save_path="heatmap.png")
329
+
330
+ # DEG counts per context (up/down regulated)
331
+ plot_deg_counts(results, save_path="deg_counts.png")
332
+
333
+ # Generate comprehensive report
334
+ create_deg_report(results, "report/", include_plots=True)
335
+ ```
336
+
250
337
  ## Expected Data Format
251
338
 
252
339
  GenEval expects AnnData (h5ad) files with:
@@ -1,4 +1,4 @@
1
- geneval/__init__.py,sha256=K0E3Jyt3l7_KxqIeI3upBBBrjRA4ASdRFugaxMVVGRM,4306
1
+ geneval/__init__.py,sha256=1ENlptAErFX1ThLDuO8J5Hs0ko5gIxGGVq7PZUhBUKY,5418
2
2
  geneval/cli.py,sha256=0ai0IGyn3SSmEnfLRJhcr0brvUxuNZHE4IXod7jvosU,9977
3
3
  geneval/config.py,sha256=gkCjs_gzPWgUZNcmSR3Y70XQCAZ1m9AKLueaM-x8bvw,3729
4
4
  geneval/core.py,sha256=No0DP8bNR6LedfCWEedY9C5r_c4M14rvSPaGZqbxc94,1155
@@ -6,6 +6,11 @@ geneval/data/__init__.py,sha256=NQUPVpUnBIabrTH5TuRk0KE9S7sVO5QetZv-MCQmZuw,827
6
6
  geneval/data/gene_expression_datamodule.py,sha256=XiBIdf68JZ-3S-FaZsrQlBJA7qL9uUXo2C8y0r4an5M,8009
7
7
  geneval/data/lazy_loader.py,sha256=5fTRVjPjcWvYXV-uPWFUF2Nn9rHRdD8lygAUkCW8wOM,20677
8
8
  geneval/data/loader.py,sha256=zpRmwGZ4PJkB3rpXXRCMFtvMi4qvUrPkKmvIlGjfRpY,14555
9
+ geneval/deg/__init__.py,sha256=joH816k_UWvu2qVhWb-fTbMQTmAhz4nUvt6yraziRek,1499
10
+ geneval/deg/context.py,sha256=_9gnWnRqqCZUDlegV2sT_rQrw8OeP1TIE9NZjNcI0ig,9069
11
+ geneval/deg/detection.py,sha256=gDdHOyFLOfl_B0xutS3KVFy53sreJ19N33B0RRI01wo,18119
12
+ geneval/deg/evaluator.py,sha256=MiBT2GOXUwq9rxHVAnJOVSbybX0rVgTsSDvOeJtnanE,18570
13
+ geneval/deg/visualization.py,sha256=9lWW9vRH_FbkIjJrf1MPobU1Yu_CAh6aw60S7g2Qe2k,10448
9
14
  geneval/evaluator.py,sha256=WgdrgqOcGYT35k1keiFEIIRIj2CQaD2DsmBpq9hcLrI,13440
10
15
  geneval/evaluators/__init__.py,sha256=i11sHvhsjEAeI3Aw9zFTPmCYuqkGxzTHggAKehe3HQ0,160
11
16
  geneval/evaluators/base_evaluator.py,sha256=yJL568HdNofIcHgNOElSQMVlG9oRPTTDIZ7CmKccRqs,5967
@@ -28,8 +33,8 @@ geneval/utils/preprocessing.py,sha256=1Cij1O2dwDR6_zh5IEgLPq3jEmV8VfIRjfQrHiKe3M
28
33
  geneval/visualization/__init__.py,sha256=LN19jl5xV4WVJTePaOUHWvKZ_pgDFp1chhcklGkNtm8,792
29
34
  geneval/visualization/plots.py,sha256=3K94r3x5NjIUZ-hYVQIivO63VkLOvDWl-BLB_qL2pSY,15008
30
35
  geneval/visualization/visualizer.py,sha256=lX7K0j20nAsgdtOOdbxLdLKYAfovEp3hNAnZOjFTCq0,36670
31
- gengeneeval-0.3.0.dist-info/METADATA,sha256=5K2bIh59OEM88dNVeUWPOevyyAbnAIyiKaZu6VmJIh0,9680
32
- gengeneeval-0.3.0.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
33
- gengeneeval-0.3.0.dist-info/entry_points.txt,sha256=xTkwnNa2fP0w1uGVsafzRTaCeuBSWLlNO-1CN8uBSK0,43
34
- gengeneeval-0.3.0.dist-info/licenses/LICENSE,sha256=RDHgHDI4rSDq35R4CAC3npy86YUnmZ81ecO7aHfmmGA,1073
35
- gengeneeval-0.3.0.dist-info/RECORD,,
36
+ gengeneeval-0.4.0.dist-info/METADATA,sha256=R3GI2E_z6qC1olM0D3aPKrJ3yjQDf_9-GncDqvNhwMY,12879
37
+ gengeneeval-0.4.0.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
38
+ gengeneeval-0.4.0.dist-info/entry_points.txt,sha256=xTkwnNa2fP0w1uGVsafzRTaCeuBSWLlNO-1CN8uBSK0,43
39
+ gengeneeval-0.4.0.dist-info/licenses/LICENSE,sha256=RDHgHDI4rSDq35R4CAC3npy86YUnmZ81ecO7aHfmmGA,1073
40
+ gengeneeval-0.4.0.dist-info/RECORD,,