gengeneeval 0.2.1__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,376 @@
1
+ """
2
+ Visualization module for DEG evaluation results.
3
+
4
+ Provides:
5
+ - Distribution plots for metrics across contexts
6
+ - Heatmaps for context × metric results
7
+ - Comprehensive DEG reports
8
+ """
9
+ from __future__ import annotations
10
+
11
+ from typing import Optional, List, Dict, Union, Any, TYPE_CHECKING
12
+ from pathlib import Path
13
+ import numpy as np
14
+
15
+ if TYPE_CHECKING:
16
+ import matplotlib.pyplot as plt
17
+ import pandas as pd
18
+ from .evaluator import DEGEvaluationResult
19
+
20
+
21
+ def _check_matplotlib():
22
+ """Check if matplotlib is available."""
23
+ try:
24
+ import matplotlib.pyplot as plt
25
+ return plt
26
+ except ImportError:
27
+ raise ImportError("matplotlib is required for visualization. Install with: pip install matplotlib")
28
+
29
+
30
+ def _check_seaborn():
31
+ """Check if seaborn is available."""
32
+ try:
33
+ import seaborn as sns
34
+ return sns
35
+ except ImportError:
36
+ return None # Seaborn is optional
37
+
38
+
39
def plot_deg_distributions(
    results: "DEGEvaluationResult",
    metrics: Optional[List[str]] = None,
    figsize: tuple = (12, 4),
    save_path: Optional[Union[str, Path]] = None,
    show: bool = True,
) -> "plt.Figure":
    """
    Plot distribution of metrics across contexts.

    Creates violin plots (box plots if seaborn is unavailable) showing
    metric distributions across all contexts, one panel per metric.

    Parameters
    ----------
    results : DEGEvaluationResult
        DEG evaluation results
    metrics : List[str], optional
        Metrics to plot. If None, plots all metrics declared in
        ``results.settings["metrics"]`` that appear in the data.
    figsize : tuple
        Overall figure size (width, height) shared by all metric panels
    save_path : str or Path, optional
        Path to save figure
    show : bool
        Whether to display the figure

    Returns
    -------
    plt.Figure
        Matplotlib figure

    Raises
    ------
    ValueError
        If no metrics are available to plot.
    """
    plt = _check_matplotlib()
    sns = _check_seaborn()

    df = results.expanded_metrics

    if metrics is None:
        # Keep only columns the evaluation settings declare as metrics
        metrics = [c for c in df.columns if c in results.settings.get("metrics", [])]

    if len(metrics) == 0:
        raise ValueError("No metrics to plot")

    n_metrics = len(metrics)
    fig, axes = plt.subplots(1, n_metrics, figsize=figsize)

    # plt.subplots returns a bare Axes (not a sequence) for a single panel
    if n_metrics == 1:
        axes = [axes]

    for ax, metric in zip(axes, metrics):
        values = df[metric].dropna()

        if values.empty:
            # All values NaN: violinplot/boxplot would fail on an empty
            # series, so leave an annotated empty panel instead
            ax.set_title(metric)
            ax.text(0.5, 0.5, "no data", ha="center", va="center",
                    transform=ax.transAxes)
            continue

        if sns is not None:
            sns.violinplot(y=values, ax=ax, inner="box")
        else:
            ax.boxplot(values, vert=True)

        ax.set_title(metric)
        ax.set_ylabel("Value")

        # Annotate the mean as a dashed reference line
        mean_val = values.mean()
        ax.axhline(mean_val, color='red', linestyle='--', alpha=0.5, label=f'Mean: {mean_val:.3f}')
        ax.legend(loc='upper right', fontsize=8)

    plt.suptitle("Metric Distributions Across Contexts (DEGs only)", y=1.02)
    plt.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')

    if show:
        plt.show()

    return fig
113
+
114
+
115
def plot_context_heatmap(
    results: "DEGEvaluationResult",
    metrics: Optional[List[str]] = None,
    figsize: tuple = (12, 8),
    cmap: str = "RdYlBu_r",
    save_path: Optional[Union[str, Path]] = None,
    show: bool = True,
) -> "plt.Figure":
    """
    Plot heatmap of metrics across contexts.

    Rows are contexts (indexed by ``context_id``), columns are metrics.
    Uses seaborn's annotated heatmap when available, otherwise falls back
    to a plain matplotlib ``imshow``.

    Parameters
    ----------
    results : DEGEvaluationResult
        DEG evaluation results
    metrics : List[str], optional
        Metrics to include. If None, uses metrics declared in
        ``results.settings["metrics"]`` that appear in the data.
    figsize : tuple
        Figure size
    cmap : str
        Colormap name
    save_path : str or Path, optional
        Path to save figure
    show : bool
        Whether to display

    Returns
    -------
    plt.Figure
        Matplotlib figure

    Raises
    ------
    ValueError
        If no metrics are available to plot.
    """
    plt = _check_matplotlib()
    sns = _check_seaborn()

    df = results.expanded_metrics

    if metrics is None:
        metrics = [c for c in df.columns if c in results.settings.get("metrics", [])]

    # Guard against an empty selection (consistent with
    # plot_deg_distributions) instead of drawing a meaningless empty grid
    if len(metrics) == 0:
        raise ValueError("No metrics to plot")

    # One row per context, one column per metric
    heatmap_data = df.set_index("context_id")[metrics]

    fig, ax = plt.subplots(figsize=figsize)

    if sns is not None:
        sns.heatmap(
            heatmap_data,
            ax=ax,
            cmap=cmap,
            annot=True,
            fmt=".3f",
            cbar_kws={"label": "Metric Value"},
        )
    else:
        # Fallback without seaborn: raw imshow plus manual tick labels
        im = ax.imshow(heatmap_data.values, cmap=cmap, aspect='auto')
        plt.colorbar(im, ax=ax, label="Metric Value")
        ax.set_xticks(range(len(metrics)))
        ax.set_xticklabels(metrics, rotation=45, ha='right')
        ax.set_yticks(range(len(heatmap_data)))
        ax.set_yticklabels(heatmap_data.index)

    ax.set_title("Metrics per Context (DEGs only)")
    ax.set_xlabel("Metric")
    ax.set_ylabel("Context")

    plt.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')

    if show:
        plt.show()

    return fig
189
+
190
+
191
def plot_deg_counts(
    results: "DEGEvaluationResult",
    figsize: tuple = (10, 6),
    save_path: Optional[Union[str, Path]] = None,
    show: bool = True,
) -> "plt.Figure":
    """
    Plot DEG counts per context.

    Draws a diverging bar chart (upregulated above zero, downregulated
    below) when ``n_upregulated``/``n_downregulated`` columns are present
    in ``results.deg_summary``; otherwise a single bar per context from
    the ``n_degs`` column.

    Parameters
    ----------
    results : DEGEvaluationResult
        DEG evaluation results
    figsize : tuple
        Figure size
    save_path : str or Path, optional
        Path to save figure
    show : bool
        Whether to display

    Returns
    -------
    plt.Figure
        Matplotlib figure

    Raises
    ------
    ValueError
        If ``results.deg_summary`` is empty.
    """
    plt = _check_matplotlib()

    df = results.deg_summary

    if len(df) == 0:
        raise ValueError("No DEG data to plot")

    fig, ax = plt.subplots(figsize=figsize)

    x = range(len(df))
    width = 0.35

    # Track whether any bars carry labels so we only build a legend when
    # there is something to show (avoids matplotlib's "no artists" warning)
    has_labels = False
    if "n_upregulated" in df.columns and "n_downregulated" in df.columns:
        ax.bar(x, df["n_upregulated"], width, label='Upregulated', color='red', alpha=0.7)
        # Downregulated counts plotted as negative bars below the axis
        ax.bar(x, -df["n_downregulated"], width, label='Downregulated', color='blue', alpha=0.7)
        ax.axhline(0, color='black', linewidth=0.5)
        ax.set_ylabel("Number of DEGs")
        has_labels = True
    else:
        ax.bar(x, df["n_degs"], color='purple', alpha=0.7)
        ax.set_ylabel("Number of DEGs")

    ax.set_xlabel("Context")
    ax.set_xticks(x)
    ax.set_xticklabels(df["context_id"], rotation=45, ha='right')
    ax.set_title("DEG Counts per Context")
    if has_labels:
        ax.legend()

    plt.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')

    if show:
        plt.show()

    return fig
252
+
253
+
254
def create_deg_report(
    results: "DEGEvaluationResult",
    output_dir: Union[str, Path],
    include_plots: bool = True,
) -> None:
    """
    Create a comprehensive DEG evaluation report.

    Creates:
    - Summary statistics (CSV)
    - Aggregated metrics (CSV)
    - Expanded per-context metrics (CSV)
    - DEG summary per context (CSV)
    - Visualization plots (PNG, if include_plots=True)
    - Markdown report (``DEG_REPORT.md``)

    Plot generation is best-effort: if matplotlib is not installed the
    plots are skipped and the rest of the report is still produced.

    Parameters
    ----------
    results : DEGEvaluationResult
        DEG evaluation results
    output_dir : str or Path
        Output directory (created if it does not exist)
    include_plots : bool
        Whether to generate plots
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Save CSVs
    results.save(output_dir)

    # Generate plots (best-effort; skipped entirely without matplotlib)
    if include_plots:
        try:
            plt = _check_matplotlib()

            # Distribution plot
            if len(results.expanded_metrics) > 0:
                plot_deg_distributions(
                    results,
                    save_path=output_dir / "deg_metric_distributions.png",
                    show=False,
                )
                plt.close()

                # Heatmap
                plot_context_heatmap(
                    results,
                    save_path=output_dir / "deg_context_heatmap.png",
                    show=False,
                )
                plt.close()

            # DEG counts
            if len(results.deg_summary) > 0:
                plot_deg_counts(
                    results,
                    save_path=output_dir / "deg_counts.png",
                    show=False,
                )
                plt.close()

        except ImportError:
            pass  # Skip plots if matplotlib not available

    # Create markdown report
    report_lines = [
        "# DEG Evaluation Report",
        "",
        "## Summary",
        "",
        f"- **Number of contexts evaluated**: {len(results.context_results)}",
        f"- **DEG detection method**: {results.settings.get('deg_method', 'N/A')}",
        f"- **P-value threshold**: {results.settings.get('pval_threshold', 'N/A')}",
        f"- **Log fold change threshold**: {results.settings.get('lfc_threshold', 'N/A')}",
        f"- **Compute device**: {results.settings.get('device', 'N/A')}",
        "",
        "## Aggregated Metrics",
        "",
    ]

    if len(results.aggregated_metrics) > 0:
        # One bullet per metric column; floats rendered with 4 decimals
        for col in results.aggregated_metrics.columns:
            val = results.aggregated_metrics[col].iloc[0]
            if isinstance(val, float):
                report_lines.append(f"- **{col}**: {val:.4f}")
            else:
                report_lines.append(f"- **{col}**: {val}")
    else:
        report_lines.append("No aggregated metrics available.")

    report_lines.extend([
        "",
        "## DEG Summary",
        "",
    ])

    if len(results.deg_summary) > 0:
        try:
            report_lines.append(results.deg_summary.to_markdown(index=False))
        except ImportError:
            # to_markdown requires the optional 'tabulate' package; fall
            # back to a plain-text table so the report still builds
            report_lines.append(results.deg_summary.to_string(index=False))
    else:
        report_lines.append("No DEG data available.")

    report_lines.extend([
        "",
        "## Files Generated",
        "",
        "- `deg_aggregated_metrics.csv`: Summary metrics across all contexts",
        "- `deg_expanded_metrics.csv`: Per-context metrics",
        "- `deg_summary.csv`: DEG detection summary",
        "- `deg_per_context/`: Individual DEG results per context",
    ])

    if include_plots:
        report_lines.extend([
            "- `deg_metric_distributions.png`: Distribution plots",
            "- `deg_context_heatmap.png`: Context × metric heatmap",
            "- `deg_counts.png`: DEG counts per context",
        ])

    # Explicit UTF-8 so the report (e.g. the '×' character above) is
    # written identically on every platform
    with open(output_dir / "DEG_REPORT.md", "w", encoding="utf-8") as f:
        f.write("\n".join(report_lines))

    print(f"DEG report saved to {output_dir}")
geneval/evaluator.py CHANGED
@@ -66,6 +66,10 @@ class GeneEvalEvaluator:
66
66
  Whether to include multivariate (whole-space) metrics
67
67
  verbose : bool
68
68
  Whether to print progress
69
+ n_jobs : int
70
+ Number of parallel CPU jobs. -1 uses all cores. Default is 1.
71
+ device : str
72
+ Compute device: "cpu", "cuda", "cuda:0", "auto". Default is "cpu".
69
73
 
70
74
  Examples
71
75
  --------
@@ -73,6 +77,10 @@ class GeneEvalEvaluator:
73
77
  >>> evaluator = GeneEvalEvaluator(loader)
74
78
  >>> results = evaluator.evaluate()
75
79
  >>> results.save("output/")
80
+
81
+ >>> # With acceleration
82
+ >>> evaluator = GeneEvalEvaluator(loader, n_jobs=8, device="cuda")
83
+ >>> results = evaluator.evaluate()
76
84
  """
77
85
 
78
86
  def __init__(
@@ -82,11 +90,15 @@ class GeneEvalEvaluator:
82
90
  aggregate_method: str = "mean",
83
91
  include_multivariate: bool = True,
84
92
  verbose: bool = True,
93
+ n_jobs: int = 1,
94
+ device: str = "cpu",
85
95
  ):
86
96
  self.data_loader = data_loader
87
97
  self.aggregate_method = aggregate_method
88
98
  self.include_multivariate = include_multivariate
89
99
  self.verbose = verbose
100
+ self.n_jobs = n_jobs
101
+ self.device = device
90
102
 
91
103
  # Initialize metrics
92
104
  self.metrics: List[BaseMetric] = []
@@ -106,6 +118,25 @@ class GeneEvalEvaluator:
106
118
  MultivariateWasserstein(),
107
119
  MultivariateMMD(),
108
120
  ])
121
+
122
+ # Initialize accelerated computer if using parallelization or GPU
123
+ self._parallel_computer = None
124
+ if n_jobs != 1 or device != "cpu":
125
+ try:
126
+ from .metrics.accelerated import ParallelMetricComputer
127
+ self._parallel_computer = ParallelMetricComputer(
128
+ n_jobs=n_jobs,
129
+ device=device,
130
+ verbose=verbose,
131
+ )
132
+ if verbose:
133
+ from .metrics.accelerated import get_available_backends
134
+ backends = get_available_backends()
135
+ self._log(f"Acceleration enabled: n_jobs={n_jobs}, device={device}")
136
+ self._log(f"Available backends: {backends}")
137
+ except ImportError as e:
138
+ if verbose:
139
+ self._log(f"Warning: Could not enable acceleration: {e}")
109
140
 
110
141
  def _log(self, msg: str):
111
142
  """Print message if verbose."""
@@ -262,6 +293,8 @@ def evaluate(
262
293
  metrics: Optional[List[Union[BaseMetric, Type[BaseMetric]]]] = None,
263
294
  include_multivariate: bool = True,
264
295
  verbose: bool = True,
296
+ n_jobs: int = 1,
297
+ device: str = "cpu",
265
298
  **loader_kwargs
266
299
  ) -> EvaluationResult:
267
300
  """
@@ -285,6 +318,10 @@ def evaluate(
285
318
  Whether to include multivariate metrics
286
319
  verbose : bool
287
320
  Print progress
321
+ n_jobs : int
322
+ Number of parallel CPU jobs. -1 uses all cores. Default is 1.
323
+ device : str
324
+ Compute device: "cpu", "cuda", "cuda:0", "auto". Default is "cpu".
288
325
  **loader_kwargs
289
326
  Additional arguments for data loader
290
327
 
@@ -295,6 +332,7 @@ def evaluate(
295
332
 
296
333
  Examples
297
334
  --------
335
+ >>> # Standard CPU evaluation
298
336
  >>> results = evaluate(
299
337
  ... "real.h5ad",
300
338
  ... "generated.h5ad",
@@ -302,6 +340,12 @@ def evaluate(
302
340
  ... split_column="split",
303
341
  ... output_dir="evaluation_output/"
304
342
  ... )
343
+
344
+ >>> # Parallel CPU evaluation (8 cores)
345
+ >>> results = evaluate(..., n_jobs=8)
346
+
347
+ >>> # GPU-accelerated evaluation
348
+ >>> results = evaluate(..., device="cuda")
305
349
  """
306
350
  # Load data
307
351
  loader = load_data(
@@ -318,6 +362,8 @@ def evaluate(
318
362
  metrics=metrics,
319
363
  include_multivariate=include_multivariate,
320
364
  verbose=verbose,
365
+ n_jobs=n_jobs,
366
+ device=device,
321
367
  )
322
368
 
323
369
  # Run evaluation
@@ -35,6 +35,20 @@ from .reconstruction import (
35
35
  R2Score,
36
36
  )
37
37
 
38
+ # Accelerated computation
39
+ from .accelerated import (
40
+ AccelerationConfig,
41
+ ParallelMetricComputer,
42
+ get_available_backends,
43
+ compute_metrics_accelerated,
44
+ GPUWasserstein1,
45
+ GPUWasserstein2,
46
+ GPUMMD,
47
+ GPUEnergyDistance,
48
+ vectorized_wasserstein1,
49
+ vectorized_mmd,
50
+ )
51
+
38
52
  # All available metrics
39
53
  ALL_METRICS = [
40
54
  # Reconstruction
@@ -81,4 +95,15 @@ __all__ = [
81
95
  "MultivariateMMD",
82
96
  # Collections
83
97
  "ALL_METRICS",
98
+ # Acceleration
99
+ "AccelerationConfig",
100
+ "ParallelMetricComputer",
101
+ "get_available_backends",
102
+ "compute_metrics_accelerated",
103
+ "GPUWasserstein1",
104
+ "GPUWasserstein2",
105
+ "GPUMMD",
106
+ "GPUEnergyDistance",
107
+ "vectorized_wasserstein1",
108
+ "vectorized_mmd",
84
109
  ]