gengeneeval 0.2.1__tar.gz → 0.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {gengeneeval-0.2.1 → gengeneeval-0.3.0}/PKG-INFO +76 -2
- {gengeneeval-0.2.1 → gengeneeval-0.3.0}/README.md +73 -0
- {gengeneeval-0.2.1 → gengeneeval-0.3.0}/pyproject.toml +3 -2
- {gengeneeval-0.2.1 → gengeneeval-0.3.0}/src/geneval/__init__.py +14 -1
- {gengeneeval-0.2.1 → gengeneeval-0.3.0}/src/geneval/evaluator.py +46 -0
- {gengeneeval-0.2.1 → gengeneeval-0.3.0}/src/geneval/metrics/__init__.py +25 -0
- gengeneeval-0.3.0/src/geneval/metrics/accelerated.py +857 -0
- {gengeneeval-0.2.1 → gengeneeval-0.3.0}/LICENSE +0 -0
- {gengeneeval-0.2.1 → gengeneeval-0.3.0}/src/geneval/cli.py +0 -0
- {gengeneeval-0.2.1 → gengeneeval-0.3.0}/src/geneval/config.py +0 -0
- {gengeneeval-0.2.1 → gengeneeval-0.3.0}/src/geneval/core.py +0 -0
- {gengeneeval-0.2.1 → gengeneeval-0.3.0}/src/geneval/data/__init__.py +0 -0
- {gengeneeval-0.2.1 → gengeneeval-0.3.0}/src/geneval/data/gene_expression_datamodule.py +0 -0
- {gengeneeval-0.2.1 → gengeneeval-0.3.0}/src/geneval/data/lazy_loader.py +0 -0
- {gengeneeval-0.2.1 → gengeneeval-0.3.0}/src/geneval/data/loader.py +0 -0
- {gengeneeval-0.2.1 → gengeneeval-0.3.0}/src/geneval/evaluators/__init__.py +0 -0
- {gengeneeval-0.2.1 → gengeneeval-0.3.0}/src/geneval/evaluators/base_evaluator.py +0 -0
- {gengeneeval-0.2.1 → gengeneeval-0.3.0}/src/geneval/evaluators/gene_expression_evaluator.py +0 -0
- {gengeneeval-0.2.1 → gengeneeval-0.3.0}/src/geneval/lazy_evaluator.py +0 -0
- {gengeneeval-0.2.1 → gengeneeval-0.3.0}/src/geneval/metrics/base_metric.py +0 -0
- {gengeneeval-0.2.1 → gengeneeval-0.3.0}/src/geneval/metrics/correlation.py +0 -0
- {gengeneeval-0.2.1 → gengeneeval-0.3.0}/src/geneval/metrics/distances.py +0 -0
- {gengeneeval-0.2.1 → gengeneeval-0.3.0}/src/geneval/metrics/metrics.py +0 -0
- {gengeneeval-0.2.1 → gengeneeval-0.3.0}/src/geneval/metrics/reconstruction.py +0 -0
- {gengeneeval-0.2.1 → gengeneeval-0.3.0}/src/geneval/models/__init__.py +0 -0
- {gengeneeval-0.2.1 → gengeneeval-0.3.0}/src/geneval/models/base_model.py +0 -0
- {gengeneeval-0.2.1 → gengeneeval-0.3.0}/src/geneval/results.py +0 -0
- {gengeneeval-0.2.1 → gengeneeval-0.3.0}/src/geneval/testing.py +0 -0
- {gengeneeval-0.2.1 → gengeneeval-0.3.0}/src/geneval/utils/__init__.py +0 -0
- {gengeneeval-0.2.1 → gengeneeval-0.3.0}/src/geneval/utils/io.py +0 -0
- {gengeneeval-0.2.1 → gengeneeval-0.3.0}/src/geneval/utils/preprocessing.py +0 -0
- {gengeneeval-0.2.1 → gengeneeval-0.3.0}/src/geneval/visualization/__init__.py +0 -0
- {gengeneeval-0.2.1 → gengeneeval-0.3.0}/src/geneval/visualization/plots.py +0 -0
- {gengeneeval-0.2.1 → gengeneeval-0.3.0}/src/geneval/visualization/visualizer.py +0 -0
{gengeneeval-0.2.1 → gengeneeval-0.3.0}/PKG-INFO

@@ -1,7 +1,7 @@
 Metadata-Version: 2.4
 Name: gengeneeval
-Version: 0.2.1
-Summary: Comprehensive evaluation of generated gene expression data. Computes metrics between real and generated datasets with support for condition matching, train/test splits, memory-efficient lazy loading, and publication-quality visualizations.
+Version: 0.3.0
+Summary: Comprehensive evaluation of generated gene expression data. Computes metrics between real and generated datasets with support for condition matching, train/test splits, memory-efficient lazy loading, CPU parallelization, GPU acceleration, and publication-quality visualizations.
 License: MIT
 License-File: LICENSE
 Keywords: gene expression,evaluation,metrics,single-cell,generative models,benchmarking,memory-efficient

@@ -24,6 +24,7 @@ Provides-Extra: full
 Provides-Extra: gpu
 Requires-Dist: anndata (>=0.8.0)
 Requires-Dist: geomloss (>=0.2.1) ; extra == "full" or extra == "gpu"
+Requires-Dist: joblib (>=1.0.0)
 Requires-Dist: matplotlib (>=3.5.0)
 Requires-Dist: numpy (>=1.21.0)
 Requires-Dist: pandas (>=1.3.0)

@@ -79,6 +80,8 @@ All metrics are computed **per-gene** (returning a vector) and **aggregated**:
 - ✅ Per-gene and aggregate metrics
 - ✅ **Memory-efficient lazy loading** for large datasets
 - ✅ **Batched evaluation** to avoid OOM errors
+- ✅ **CPU parallelization** via joblib (multi-core speedup)
+- ✅ **GPU acceleration** via PyTorch (10-100x speedup)
 - ✅ Modular, extensible architecture
 - ✅ Command-line interface
 - ✅ Publication-quality visualizations

@@ -173,6 +176,77 @@ with load_data_lazy("real.h5ad", "gen.h5ad", ["perturbation"]) as loader:
     pass
 ```

+### Accelerated Evaluation (CPU Parallelization & GPU)
+
+GenEval supports CPU parallelization and GPU acceleration for significant speedups:
+
+```python
+from geneval import evaluate, get_available_backends
+
+# Check available backends
+print(get_available_backends())
+# {'joblib': True, 'torch': True, 'geomloss': True, 'cuda': True, 'mps': False}
+
+# Parallel CPU evaluation (use all cores)
+results = evaluate(
+    real_path="real.h5ad",
+    generated_path="generated.h5ad",
+    condition_columns=["perturbation"],
+    n_jobs=-1,  # Use all available CPU cores
+)
+
+# GPU-accelerated evaluation
+results = evaluate(
+    real_path="real.h5ad",
+    generated_path="generated.h5ad",
+    condition_columns=["perturbation"],
+    device="cuda",  # Use NVIDIA GPU
+)
+
+# Combined: parallel CPU + auto device selection
+results = evaluate(..., n_jobs=8, device="auto")
+```
+
+#### Low-level Accelerated API
+
+For custom workflows, use the accelerated metrics directly:
+
+```python
+from geneval.metrics.accelerated import (
+    compute_metrics_accelerated,
+    GPUWasserstein1,
+    GPUMMD,
+    vectorized_wasserstein1,
+)
+import numpy as np
+
+# Load your data
+real = np.random.randn(1000, 5000)  # 1000 cells, 5000 genes
+generated = np.random.randn(1000, 5000)
+
+# Compute multiple metrics with acceleration
+results = compute_metrics_accelerated(
+    real, generated,
+    metrics=["wasserstein_1", "wasserstein_2", "mmd", "energy"],
+    n_jobs=8,       # CPU parallelization
+    device="cuda",  # GPU acceleration
+    verbose=True,
+)
+
+# Access results
+print(f"W1: {results['wasserstein_1'].aggregate_value:.4f}")
+print(f"MMD: {results['mmd'].aggregate_value:.4f}")
+```
+
+#### Performance Tips
+
+| Optimization | Speedup | When to Use |
+|--------------|---------|-------------|
+| `n_jobs=-1` (all cores) | 4-16x | Always (if joblib available) |
+| `device="cuda"` | 10-100x | Large datasets, NVIDIA GPU available |
+| `device="mps"` | 5-20x | Apple Silicon Macs |
+| Vectorized NumPy | 2-5x | Automatic fallback |
+
 ## Expected Data Format

 GenEval expects AnnData (h5ad) files with:
{gengeneeval-0.2.1 → gengeneeval-0.3.0}/README.md

@@ -40,6 +40,8 @@ All metrics are computed **per-gene** (returning a vector) and **aggregated**:
 - ✅ Per-gene and aggregate metrics
 - ✅ **Memory-efficient lazy loading** for large datasets
 - ✅ **Batched evaluation** to avoid OOM errors
+- ✅ **CPU parallelization** via joblib (multi-core speedup)
+- ✅ **GPU acceleration** via PyTorch (10-100x speedup)
 - ✅ Modular, extensible architecture
 - ✅ Command-line interface
 - ✅ Publication-quality visualizations

@@ -134,6 +136,77 @@ with load_data_lazy("real.h5ad", "gen.h5ad", ["perturbation"]) as loader:
     pass
 ```

+### Accelerated Evaluation (CPU Parallelization & GPU)
+
+GenEval supports CPU parallelization and GPU acceleration for significant speedups:
+
+```python
+from geneval import evaluate, get_available_backends
+
+# Check available backends
+print(get_available_backends())
+# {'joblib': True, 'torch': True, 'geomloss': True, 'cuda': True, 'mps': False}
+
+# Parallel CPU evaluation (use all cores)
+results = evaluate(
+    real_path="real.h5ad",
+    generated_path="generated.h5ad",
+    condition_columns=["perturbation"],
+    n_jobs=-1,  # Use all available CPU cores
+)
+
+# GPU-accelerated evaluation
+results = evaluate(
+    real_path="real.h5ad",
+    generated_path="generated.h5ad",
+    condition_columns=["perturbation"],
+    device="cuda",  # Use NVIDIA GPU
+)
+
+# Combined: parallel CPU + auto device selection
+results = evaluate(..., n_jobs=8, device="auto")
+```
+
+#### Low-level Accelerated API
+
+For custom workflows, use the accelerated metrics directly:
+
+```python
+from geneval.metrics.accelerated import (
+    compute_metrics_accelerated,
+    GPUWasserstein1,
+    GPUMMD,
+    vectorized_wasserstein1,
+)
+import numpy as np
+
+# Load your data
+real = np.random.randn(1000, 5000)  # 1000 cells, 5000 genes
+generated = np.random.randn(1000, 5000)
+
+# Compute multiple metrics with acceleration
+results = compute_metrics_accelerated(
+    real, generated,
+    metrics=["wasserstein_1", "wasserstein_2", "mmd", "energy"],
+    n_jobs=8,       # CPU parallelization
+    device="cuda",  # GPU acceleration
+    verbose=True,
+)
+
+# Access results
+print(f"W1: {results['wasserstein_1'].aggregate_value:.4f}")
+print(f"MMD: {results['mmd'].aggregate_value:.4f}")
+```
+
+#### Performance Tips
+
+| Optimization | Speedup | When to Use |
+|--------------|---------|-------------|
+| `n_jobs=-1` (all cores) | 4-16x | Always (if joblib available) |
+| `device="cuda"` | 10-100x | Large datasets, NVIDIA GPU available |
+| `device="mps"` | 5-20x | Apple Silicon Macs |
+| Vectorized NumPy | 2-5x | Automatic fallback |
+
 ## Expected Data Format

 GenEval expects AnnData (h5ad) files with:
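Aside on the "Vectorized NumPy" fallback row above: the speedup comes from computing the 1-D Wasserstein-1 distance by quantile matching rather than one scipy call per gene. A minimal, self-contained sketch of that idea on synthetic data (the scipy cross-check is illustrative and not part of the package):

```python
import numpy as np
from scipy.stats import wasserstein_distance

rng = np.random.default_rng(0)
real = rng.normal(size=(200, 3))                 # 200 cells, 3 genes
generated = rng.normal(0.5, 1.0, size=(300, 3))  # shifted distribution

# Sort each gene, put both empirical quantile functions on a common grid,
# and average the absolute difference -- in one dimension this equals W1.
n = max(len(real), len(generated))
grid = np.linspace(0, 1, n)
w1 = np.empty(real.shape[1])
for g in range(real.shape[1]):
    r_q = np.interp(grid, np.linspace(0, 1, len(real)), np.sort(real[:, g]))
    g_q = np.interp(grid, np.linspace(0, 1, len(generated)), np.sort(generated[:, g]))
    w1[g] = np.mean(np.abs(r_q - g_q))

# Agrees with scipy's exact 1-D Wasserstein distance up to interpolation error
ref = np.array([wasserstein_distance(real[:, g], generated[:, g])
                for g in range(real.shape[1])])
print(np.round(w1, 3), np.round(ref, 3))
```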
{gengeneeval-0.2.1 → gengeneeval-0.3.0}/pyproject.toml

@@ -1,7 +1,7 @@
 [tool.poetry]
 name = "gengeneeval"
-version = "0.2.1"
-description = "Comprehensive evaluation of generated gene expression data. Computes metrics between real and generated datasets with support for condition matching, train/test splits, memory-efficient lazy loading, and publication-quality visualizations."
+version = "0.3.0"
+description = "Comprehensive evaluation of generated gene expression data. Computes metrics between real and generated datasets with support for condition matching, train/test splits, memory-efficient lazy loading, CPU parallelization, GPU acceleration, and publication-quality visualizations."
 authors = ["GenEval Team <geneval@example.com>"]
 license = "MIT"
 readme = "README.md"

@@ -29,6 +29,7 @@ scipy = ">=1.7.0"
 torch = ">=1.9.0"
 matplotlib = ">=3.5.0"
 seaborn = ">=0.11.0"
+joblib = ">=1.0.0"
 geomloss = {version = ">=0.2.1", optional = true}
 pykeops = {version = ">=1.4.0", optional = true}
 umap-learn = {version = ">=0.5.0", optional = true}
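Note the dependency split in the hunk above: joblib joins the required dependencies, while geomloss (and pykeops) remain optional extras. The new accelerated module, diffed further down, guards each backend with try/except, so a missing optional backend degrades gracefully instead of failing at import. A minimal reproduction of that guard pattern, with names mirroring the module:

```python
# Import-guard pattern used by geneval/metrics/accelerated.py (see below).
try:
    from joblib import Parallel, delayed  # required as of 0.3.0
    HAS_JOBLIB = True
except ImportError:
    HAS_JOBLIB = False

try:
    from geomloss import SamplesLoss  # still optional: "full" / "gpu" extras
    HAS_GEOMLOSS = True
except ImportError:
    HAS_GEOMLOSS = False

print(f"joblib={HAS_JOBLIB}, geomloss={HAS_GEOMLOSS}")
```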
{gengeneeval-0.2.1 → gengeneeval-0.3.0}/src/geneval/__init__.py

@@ -36,7 +36,7 @@ CLI Usage:
         --conditions perturbation cell_type --output results/
 """

-__version__ = "0.2.1"
+__version__ = "0.3.0"
 __author__ = "GenEval Team"

 # Main evaluation interface

@@ -101,6 +101,14 @@ from .metrics.reconstruction import (
     R2Score,
 )

+# Accelerated computation
+from .metrics.accelerated import (
+    AccelerationConfig,
+    ParallelMetricComputer,
+    get_available_backends,
+    compute_metrics_accelerated,
+)
+
 # Visualization
 from .visualization.visualizer import (
     EvaluationVisualizer,

@@ -161,6 +169,11 @@ __all__ = [
     "RMSEDistance",
     "MAEDistance",
     "R2Score",
+    # Acceleration
+    "AccelerationConfig",
+    "ParallelMetricComputer",
+    "get_available_backends",
+    "compute_metrics_accelerated",
     # Visualization
     "EvaluationVisualizer",
     "visualize",
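Since the four acceleration names are now re-exported at package level, a post-upgrade smoke test might look like this (hypothetical check, assuming gengeneeval 0.3.0 is installed):

```python
import geneval

# __version__ is bumped in the same file as the new exports
assert geneval.__version__ == "0.3.0"

# New package-level names introduced by this release
from geneval import (
    AccelerationConfig,
    ParallelMetricComputer,
    get_available_backends,
    compute_metrics_accelerated,
)

print(get_available_backends())  # e.g. {'joblib': True, 'torch': False, ...}
```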
{gengeneeval-0.2.1 → gengeneeval-0.3.0}/src/geneval/evaluator.py

@@ -66,6 +66,10 @@ class GeneEvalEvaluator:
         Whether to include multivariate (whole-space) metrics
     verbose : bool
         Whether to print progress
+    n_jobs : int
+        Number of parallel CPU jobs. -1 uses all cores. Default is 1.
+    device : str
+        Compute device: "cpu", "cuda", "cuda:0", "auto". Default is "cpu".

     Examples
     --------

@@ -73,6 +77,10 @@ class GeneEvalEvaluator:
     >>> evaluator = GeneEvalEvaluator(loader)
     >>> results = evaluator.evaluate()
     >>> results.save("output/")
+
+    >>> # With acceleration
+    >>> evaluator = GeneEvalEvaluator(loader, n_jobs=8, device="cuda")
+    >>> results = evaluator.evaluate()
     """

     def __init__(

@@ -82,11 +90,15 @@ class GeneEvalEvaluator:
         aggregate_method: str = "mean",
         include_multivariate: bool = True,
         verbose: bool = True,
+        n_jobs: int = 1,
+        device: str = "cpu",
     ):
         self.data_loader = data_loader
         self.aggregate_method = aggregate_method
         self.include_multivariate = include_multivariate
         self.verbose = verbose
+        self.n_jobs = n_jobs
+        self.device = device

         # Initialize metrics
         self.metrics: List[BaseMetric] = []

@@ -106,6 +118,25 @@ class GeneEvalEvaluator:
                 MultivariateWasserstein(),
                 MultivariateMMD(),
             ])
+
+        # Initialize accelerated computer if using parallelization or GPU
+        self._parallel_computer = None
+        if n_jobs != 1 or device != "cpu":
+            try:
+                from .metrics.accelerated import ParallelMetricComputer
+                self._parallel_computer = ParallelMetricComputer(
+                    n_jobs=n_jobs,
+                    device=device,
+                    verbose=verbose,
+                )
+                if verbose:
+                    from .metrics.accelerated import get_available_backends
+                    backends = get_available_backends()
+                    self._log(f"Acceleration enabled: n_jobs={n_jobs}, device={device}")
+                    self._log(f"Available backends: {backends}")
+            except ImportError as e:
+                if verbose:
+                    self._log(f"Warning: Could not enable acceleration: {e}")

     def _log(self, msg: str):
         """Print message if verbose."""

@@ -262,6 +293,8 @@ def evaluate(
     metrics: Optional[List[Union[BaseMetric, Type[BaseMetric]]]] = None,
     include_multivariate: bool = True,
     verbose: bool = True,
+    n_jobs: int = 1,
+    device: str = "cpu",
     **loader_kwargs
 ) -> EvaluationResult:
     """

@@ -285,6 +318,10 @@ def evaluate(
         Whether to include multivariate metrics
     verbose : bool
         Print progress
+    n_jobs : int
+        Number of parallel CPU jobs. -1 uses all cores. Default is 1.
+    device : str
+        Compute device: "cpu", "cuda", "cuda:0", "auto". Default is "cpu".
     **loader_kwargs
         Additional arguments for data loader

@@ -295,6 +332,7 @@ def evaluate(

     Examples
     --------
+    >>> # Standard CPU evaluation
     >>> results = evaluate(
     ...     "real.h5ad",
     ...     "generated.h5ad",

@@ -302,6 +340,12 @@ def evaluate(
     ...     split_column="split",
     ...     output_dir="evaluation_output/"
     ... )
+
+    >>> # Parallel CPU evaluation (8 cores)
+    >>> results = evaluate(..., n_jobs=8)
+
+    >>> # GPU-accelerated evaluation
+    >>> results = evaluate(..., device="cuda")
     """
     # Load data
     loader = load_data(

@@ -318,6 +362,8 @@ def evaluate(
         metrics=metrics,
         include_multivariate=include_multivariate,
         verbose=verbose,
+        n_jobs=n_jobs,
+        device=device,
     )

     # Run evaluation
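Worth noting from the constructor hunk above: acceleration is best-effort. If the accelerated module cannot be imported, the evaluator logs a warning and runs the standard sequential path. A hedged usage sketch of the extended `evaluate()` signature (keyword names taken from this diff; the .h5ad paths are placeholders):

```python
from geneval import evaluate

# n_jobs/device are forwarded to GeneEvalEvaluator. With device="auto",
# the accelerated module resolves cuda -> mps -> cpu in that order.
results = evaluate(
    real_path="real.h5ad",            # placeholder path
    generated_path="generated.h5ad",  # placeholder path
    condition_columns=["perturbation"],
    n_jobs=-1,      # all CPU cores
    device="auto",  # pick the best available device
)
```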
{gengeneeval-0.2.1 → gengeneeval-0.3.0}/src/geneval/metrics/__init__.py

@@ -35,6 +35,20 @@ from .reconstruction import (
     R2Score,
 )

+# Accelerated computation
+from .accelerated import (
+    AccelerationConfig,
+    ParallelMetricComputer,
+    get_available_backends,
+    compute_metrics_accelerated,
+    GPUWasserstein1,
+    GPUWasserstein2,
+    GPUMMD,
+    GPUEnergyDistance,
+    vectorized_wasserstein1,
+    vectorized_mmd,
+)
+
 # All available metrics
 ALL_METRICS = [
     # Reconstruction

@@ -81,4 +95,15 @@ __all__ = [
     "MultivariateMMD",
     # Collections
     "ALL_METRICS",
+    # Acceleration
+    "AccelerationConfig",
+    "ParallelMetricComputer",
+    "get_available_backends",
+    "compute_metrics_accelerated",
+    "GPUWasserstein1",
+    "GPUWasserstein2",
+    "GPUMMD",
+    "GPUEnergyDistance",
+    "vectorized_wasserstein1",
+    "vectorized_mmd",
 ]
gengeneeval-0.3.0/src/geneval/metrics/accelerated.py (new file, 857 added lines, shown without "+" prefixes)

@@ -0,0 +1,857 @@
"""
Accelerated metric computation with CPU parallelization and GPU support.

This module provides performance optimizations for metric computation:
- CPU parallelization via joblib for multi-core speedup
- GPU acceleration via PyTorch/geomloss for batch computation
- Vectorized operations for improved NumPy performance

Example usage:
    >>> from geneval.metrics.accelerated import ParallelMetricComputer
    >>> computer = ParallelMetricComputer(n_jobs=8, device="cuda")
    >>> results = computer.compute_all(real, generated, metrics)
"""
from __future__ import annotations

import warnings
from typing import List, Optional, Dict, Any, Union, Literal
from dataclasses import dataclass
import numpy as np

from .base_metric import BaseMetric, MetricResult


# Check for optional dependencies
try:
    from joblib import Parallel, delayed
    HAS_JOBLIB = True
except ImportError:
    HAS_JOBLIB = False

try:
    import torch
    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False

try:
    from geomloss import SamplesLoss
    HAS_GEOMLOSS = True
except ImportError:
    HAS_GEOMLOSS = False


@dataclass
class AccelerationConfig:
    """Configuration for accelerated metric computation.

    Attributes
    ----------
    n_jobs : int
        Number of CPU jobs for parallel computation.
        -1 uses all available cores. Default is 1 (no parallelization).
    device : str
        Device for computation: "cpu", "cuda", "cuda:0", etc.
        Default is "cpu".
    batch_genes : bool
        If True, batch all genes for GPU computation. Default is True.
    gene_batch_size : int or None
        If set, process genes in batches of this size to manage memory.
        None means process all genes at once.
    prefer_gpu : bool
        If True and GPU is available, prefer GPU implementations.
        Default is True.
    verbose : bool
        Print acceleration info. Default is False.
    """
    n_jobs: int = 1
    device: str = "cpu"
    batch_genes: bool = True
    gene_batch_size: Optional[int] = None
    prefer_gpu: bool = True
    verbose: bool = False


def get_available_backends() -> Dict[str, bool]:
    """Check which acceleration backends are available.

    Returns
    -------
    Dict[str, bool]
        Dictionary with backend availability.
    """
    backends = {
        "joblib": HAS_JOBLIB,
        "torch": HAS_TORCH,
        "geomloss": HAS_GEOMLOSS,
        "cuda": HAS_TORCH and torch.cuda.is_available(),
        "mps": HAS_TORCH and hasattr(torch.backends, "mps") and torch.backends.mps.is_available(),
    }
    return backends


def _get_device(device: str) -> "torch.device":
    """Get PyTorch device, handling availability checks.

    Parameters
    ----------
    device : str
        Device string ("cpu", "cuda", "cuda:0", "mps", "auto")

    Returns
    -------
    torch.device
        PyTorch device object
    """
    if not HAS_TORCH:
        raise ImportError("PyTorch is required for GPU acceleration")

    if device == "auto":
        if torch.cuda.is_available():
            return torch.device("cuda")
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return torch.device("mps")
        else:
            return torch.device("cpu")

    return torch.device(device)


class ParallelMetricComputer:
    """Parallel and GPU-accelerated metric computation.

    This class wraps metric computation with parallelization and GPU
    acceleration options for significant speedups on large datasets.

    Parameters
    ----------
    n_jobs : int
        Number of parallel jobs. -1 for all cores.
    device : str
        Compute device ("cpu", "cuda", "auto")
    batch_genes : bool
        Whether to batch genes for GPU computation.
    gene_batch_size : int, optional
        Process genes in chunks of this size.
    verbose : bool
        Print progress information.

    Examples
    --------
    >>> computer = ParallelMetricComputer(n_jobs=8)
    >>> results = computer.compute_metric(metric, real, generated)

    >>> # GPU acceleration
    >>> computer = ParallelMetricComputer(device="cuda")
    >>> results = computer.compute_metric(metric, real, generated)
    """

    def __init__(
        self,
        n_jobs: int = 1,
        device: str = "cpu",
        batch_genes: bool = True,
        gene_batch_size: Optional[int] = None,
        verbose: bool = False,
    ):
        self.n_jobs = n_jobs
        self.device = device
        self.batch_genes = batch_genes
        self.gene_batch_size = gene_batch_size
        self.verbose = verbose

        # Validate configuration
        if n_jobs != 1 and not HAS_JOBLIB:
            warnings.warn("joblib not available, falling back to sequential processing")
            self.n_jobs = 1

        if device != "cpu" and not HAS_TORCH:
            warnings.warn("PyTorch not available, falling back to CPU")
            self.device = "cpu"

        if self.verbose:
            backends = get_available_backends()
            print(f"Acceleration backends: {backends}")
            print(f"Using n_jobs={self.n_jobs}, device={self.device}")

    def compute_metric_parallel(
        self,
        metric: BaseMetric,
        real: np.ndarray,
        generated: np.ndarray,
        gene_names: Optional[List[str]] = None,
    ) -> MetricResult:
        """Compute a metric with CPU parallelization.

        Splits genes across multiple CPU cores for parallel computation.

        Parameters
        ----------
        metric : BaseMetric
            Metric to compute
        real : np.ndarray
            Real data, shape (n_samples, n_genes)
        generated : np.ndarray
            Generated data, shape (n_samples, n_genes)
        gene_names : List[str], optional
            Gene names

        Returns
        -------
        MetricResult
            Computed metric result
        """
        n_genes = real.shape[1]
        if gene_names is None:
            gene_names = [f"gene_{i}" for i in range(n_genes)]

        if self.n_jobs == 1 or not HAS_JOBLIB:
            # Sequential computation
            per_gene = metric.compute_per_gene(real, generated)
        else:
            # Parallel computation across genes
            if self.gene_batch_size:
                # Process in batches
                batches = [
                    (i, min(i + self.gene_batch_size, n_genes))
                    for i in range(0, n_genes, self.gene_batch_size)
                ]
            else:
                # Split evenly across jobs
                n_effective_jobs = min(self.n_jobs if self.n_jobs > 0 else 8, n_genes)
                batch_size = max(1, n_genes // n_effective_jobs)
                batches = [
                    (i, min(i + batch_size, n_genes))
                    for i in range(0, n_genes, batch_size)
                ]

            def compute_batch(start: int, end: int) -> np.ndarray:
                return metric.compute_per_gene(
                    real[:, start:end],
                    generated[:, start:end]
                )

            results = Parallel(n_jobs=self.n_jobs, prefer="threads")(
                delayed(compute_batch)(start, end) for start, end in batches
            )

            per_gene = np.concatenate(results)

        aggregate = metric.compute_aggregate(per_gene, method="mean")

        return MetricResult(
            name=metric.name,
            per_gene_values=per_gene,
            gene_names=gene_names,
            aggregate_value=aggregate,
            aggregate_method="mean",
            metadata={
                "higher_is_better": metric.higher_is_better,
                "accelerated": True,
                "n_jobs": self.n_jobs,
            }
        )


# =============================================================================
# GPU-Accelerated Distance Metrics
# =============================================================================

class GPUWasserstein1:
    """GPU-accelerated Wasserstein-1 distance computation.

    Computes W1 distance for all genes in parallel on GPU using
    vectorized sorting and quantile interpolation.
    """

    def __init__(self, device: str = "cuda"):
        if not HAS_TORCH:
            raise ImportError("PyTorch required for GPU acceleration")
        self.device = _get_device(device)

    def compute_batch(
        self,
        real: np.ndarray,
        generated: np.ndarray,
    ) -> np.ndarray:
        """Compute W1 for all genes in batch on GPU.

        Parameters
        ----------
        real : np.ndarray
            Real data, shape (n_samples_real, n_genes)
        generated : np.ndarray
            Generated data, shape (n_samples_gen, n_genes)

        Returns
        -------
        np.ndarray
            W1 distance per gene
        """
        # Move to GPU
        real_t = torch.tensor(real, dtype=torch.float32, device=self.device)
        gen_t = torch.tensor(generated, dtype=torch.float32, device=self.device)

        n_genes = real_t.shape[1]
        n_quantiles = max(real_t.shape[0], gen_t.shape[0])

        # Sort each gene column
        real_sorted, _ = torch.sort(real_t, dim=0)
        gen_sorted, _ = torch.sort(gen_t, dim=0)

        # Interpolate to same number of quantiles
        quantile_positions = torch.linspace(0, 1, n_quantiles, device=self.device)

        # Interpolate real
        real_indices = quantile_positions * (real_sorted.shape[0] - 1)
        real_floor = real_indices.long().clamp(0, real_sorted.shape[0] - 2)
        real_frac = (real_indices - real_floor.float()).unsqueeze(1)
        real_interp = (
            real_sorted[real_floor] * (1 - real_frac) +
            real_sorted[real_floor + 1] * real_frac
        )

        # Interpolate generated
        gen_indices = quantile_positions * (gen_sorted.shape[0] - 1)
        gen_floor = gen_indices.long().clamp(0, gen_sorted.shape[0] - 2)
        gen_frac = (gen_indices - gen_floor.float()).unsqueeze(1)
        gen_interp = (
            gen_sorted[gen_floor] * (1 - gen_frac) +
            gen_sorted[gen_floor + 1] * gen_frac
        )

        # W1 = mean absolute difference
        w1 = torch.mean(torch.abs(real_interp - gen_interp), dim=0)

        return w1.cpu().numpy()


class GPUWasserstein2:
    """GPU-accelerated Wasserstein-2 distance using geomloss.

    Batches all genes together for efficient GPU computation.
    """

    def __init__(self, device: str = "cuda", blur: float = 0.01):
        if not HAS_TORCH:
            raise ImportError("PyTorch required for GPU acceleration")
        if not HAS_GEOMLOSS:
            raise ImportError("geomloss required for Wasserstein-2 GPU acceleration")

        self.device = _get_device(device)
        self.blur = blur
        self.loss_fn = SamplesLoss(loss="sinkhorn", p=2, blur=blur, backend="tensorized")

    def compute_batch(
        self,
        real: np.ndarray,
        generated: np.ndarray,
    ) -> np.ndarray:
        """Compute W2 for all genes in batch on GPU.

        Parameters
        ----------
        real : np.ndarray
            Real data, shape (n_samples_real, n_genes)
        generated : np.ndarray
            Generated data, shape (n_samples_gen, n_genes)

        Returns
        -------
        np.ndarray
            W2 distance per gene
        """
        n_genes = real.shape[1]

        # Move to GPU
        real_t = torch.tensor(real, dtype=torch.float32, device=self.device)
        gen_t = torch.tensor(generated, dtype=torch.float32, device=self.device)

        distances = torch.zeros(n_genes, device=self.device)

        # Process each gene (geomloss requires separate calls per distribution pair)
        # But we can batch by treating genes as batch dimension
        for i in range(n_genes):
            r = real_t[:, i:i+1]  # Keep 2D
            g = gen_t[:, i:i+1]
            distances[i] = self.loss_fn(r, g)

        return distances.cpu().numpy()


class GPUMMD:
    """GPU-accelerated MMD computation with RBF kernel.

    Uses PyTorch for vectorized kernel computation across all genes.
    """

    def __init__(self, device: str = "cuda", sigma: Optional[float] = None):
        if not HAS_TORCH:
            raise ImportError("PyTorch required for GPU acceleration")

        self.device = _get_device(device)
        self.sigma = sigma

    def compute_batch(
        self,
        real: np.ndarray,
        generated: np.ndarray,
    ) -> np.ndarray:
        """Compute MMD for all genes in batch on GPU.

        Parameters
        ----------
        real : np.ndarray
            Real data, shape (n_samples_real, n_genes)
        generated : np.ndarray
            Generated data, shape (n_samples_gen, n_genes)

        Returns
        -------
        np.ndarray
            MMD per gene
        """
        real_t = torch.tensor(real, dtype=torch.float32, device=self.device)
        gen_t = torch.tensor(generated, dtype=torch.float32, device=self.device)

        n_genes = real_t.shape[1]
        n_x, n_y = real_t.shape[0], gen_t.shape[0]

        mmd_values = torch.zeros(n_genes, device=self.device)

        for g in range(n_genes):
            x = real_t[:, g:g+1]
            y = gen_t[:, g:g+1]

            # Median heuristic for sigma
            if self.sigma is None:
                combined = torch.cat([x, y], dim=0)
                pairwise = torch.abs(combined - combined.T)
                sigma = torch.median(pairwise[pairwise > 0]).item()
                if sigma == 0:
                    sigma = 1.0
            else:
                sigma = self.sigma

            # RBF kernel
            def rbf(a, b, s):
                sq_dist = (a - b.T) ** 2
                return torch.exp(-sq_dist / (2 * s ** 2))

            K_xx = rbf(x, x, sigma)
            K_yy = rbf(y, y, sigma)
            K_xy = rbf(x, y, sigma)

            # Unbiased MMD
            mmd = (
                (K_xx.sum() - K_xx.trace()) / (n_x * (n_x - 1)) +
                (K_yy.sum() - K_yy.trace()) / (n_y * (n_y - 1)) -
                2 * K_xy.sum() / (n_x * n_y)
            )

            mmd_values[g] = torch.clamp(mmd, min=0)

        return mmd_values.cpu().numpy()


class GPUEnergyDistance:
    """GPU-accelerated Energy distance computation."""

    def __init__(self, device: str = "cuda"):
        if not HAS_TORCH:
            raise ImportError("PyTorch required for GPU acceleration")

        self.device = _get_device(device)

    def compute_batch(
        self,
        real: np.ndarray,
        generated: np.ndarray,
    ) -> np.ndarray:
        """Compute Energy distance for all genes in batch on GPU.

        Parameters
        ----------
        real : np.ndarray
            Real data, shape (n_samples_real, n_genes)
        generated : np.ndarray
            Generated data, shape (n_samples_gen, n_genes)

        Returns
        -------
        np.ndarray
            Energy distance per gene
        """
        real_t = torch.tensor(real, dtype=torch.float32, device=self.device)
        gen_t = torch.tensor(generated, dtype=torch.float32, device=self.device)

        n_genes = real_t.shape[1]

        energy_values = torch.zeros(n_genes, device=self.device)

        for g in range(n_genes):
            x = real_t[:, g]
            y = gen_t[:, g]

            # E[|X - Y|]
            xy_dist = torch.mean(torch.abs(x.unsqueeze(1) - y.unsqueeze(0)))

            # E[|X - X'|]
            xx_dist = torch.mean(torch.abs(x.unsqueeze(1) - x.unsqueeze(0)))

            # E[|Y - Y'|]
            yy_dist = torch.mean(torch.abs(y.unsqueeze(1) - y.unsqueeze(0)))

            energy = 2 * xy_dist - xx_dist - yy_dist
            energy_values[g] = torch.clamp(energy, min=0)

        return energy_values.cpu().numpy()


# =============================================================================
# Vectorized NumPy Implementations (for CPU speedup without joblib)
# =============================================================================

def vectorized_wasserstein1(
    real: np.ndarray,
    generated: np.ndarray,
) -> np.ndarray:
    """Compute W1 for all genes using vectorized NumPy.

    This is faster than the loop-based scipy implementation.

    Parameters
    ----------
    real : np.ndarray
        Real data, shape (n_samples_real, n_genes)
    generated : np.ndarray
        Generated data, shape (n_samples_gen, n_genes)

    Returns
    -------
    np.ndarray
        W1 distance per gene
    """
    n_genes = real.shape[1]
    n_quantiles = max(real.shape[0], generated.shape[0])

    # Sort each column
    real_sorted = np.sort(real, axis=0)
    gen_sorted = np.sort(generated, axis=0)

    # Interpolate to same number of quantiles
    real_positions = np.linspace(0, 1, real_sorted.shape[0])
    gen_positions = np.linspace(0, 1, gen_sorted.shape[0])
    target_positions = np.linspace(0, 1, n_quantiles)

    # Interpolate each gene column
    real_interp = np.zeros((n_quantiles, n_genes))
    gen_interp = np.zeros((n_quantiles, n_genes))

    for g in range(n_genes):
        real_interp[:, g] = np.interp(target_positions, real_positions, real_sorted[:, g])
        gen_interp[:, g] = np.interp(target_positions, gen_positions, gen_sorted[:, g])

    # W1 = mean absolute difference
    return np.mean(np.abs(real_interp - gen_interp), axis=0)


def vectorized_mmd(
    real: np.ndarray,
    generated: np.ndarray,
    sigma: Optional[float] = None,
) -> np.ndarray:
    """Compute MMD for all genes using vectorized NumPy.

    Parameters
    ----------
    real : np.ndarray
        Real data, shape (n_samples_real, n_genes)
    generated : np.ndarray
        Generated data, shape (n_samples_gen, n_genes)
    sigma : float, optional
        Kernel bandwidth. Uses median heuristic if None.

    Returns
    -------
    np.ndarray
        MMD per gene
    """
    n_genes = real.shape[1]
    n_x, n_y = real.shape[0], generated.shape[0]

    mmd_values = np.zeros(n_genes)

    for g in range(n_genes):
        x = real[:, g:g+1]
        y = generated[:, g:g+1]

        # Median heuristic
        if sigma is None:
            combined = np.vstack([x, y])
            pairwise = np.abs(combined - combined.T)
            s = float(np.median(pairwise[pairwise > 0]))
            if s == 0:
                s = 1.0
        else:
            s = sigma

        # RBF kernel
        K_xx = np.exp(-(x - x.T) ** 2 / (2 * s ** 2))
        K_yy = np.exp(-(y - y.T) ** 2 / (2 * s ** 2))
        K_xy = np.exp(-(x - y.T) ** 2 / (2 * s ** 2))

        # Unbiased MMD
        mmd = (
            (np.sum(K_xx) - np.trace(K_xx)) / (n_x * (n_x - 1)) +
            (np.sum(K_yy) - np.trace(K_yy)) / (n_y * (n_y - 1)) -
            2 * np.sum(K_xy) / (n_x * n_y)
        )

        mmd_values[g] = max(0, mmd)

    return mmd_values


# =============================================================================
# High-Level Accelerated Evaluation Interface
# =============================================================================

def compute_metrics_accelerated(
    real: np.ndarray,
    generated: np.ndarray,
    metrics: List[str] = ["wasserstein_1", "wasserstein_2", "mmd", "energy"],
    n_jobs: int = 1,
    device: str = "cpu",
    gene_names: Optional[List[str]] = None,
    verbose: bool = False,
) -> Dict[str, MetricResult]:
    """Compute multiple metrics with acceleration.

    This is the main entry point for accelerated metric computation.
    Automatically selects the best available backend.

    Parameters
    ----------
    real : np.ndarray
        Real data, shape (n_samples_real, n_genes)
    generated : np.ndarray
        Generated data, shape (n_samples_gen, n_genes)
    metrics : List[str]
        Metrics to compute: "wasserstein_1", "wasserstein_2", "mmd", "energy"
    n_jobs : int
        Number of CPU jobs (-1 for all cores)
    device : str
        Compute device ("cpu", "cuda", "auto")
    gene_names : List[str], optional
        Gene names
    verbose : bool
        Print progress

    Returns
    -------
    Dict[str, MetricResult]
        Dictionary of metric results
    """
    backends = get_available_backends()

    if device == "auto":
        if backends["cuda"]:
            device = "cuda"
        elif backends["mps"]:
            device = "mps"
        else:
            device = "cpu"

    if verbose:
        print(f"Using device: {device}, n_jobs: {n_jobs}")
        print(f"Available backends: {backends}")

    n_genes = real.shape[1]
    if gene_names is None:
        gene_names = [f"gene_{i}" for i in range(n_genes)]

    results = {}

    for metric_name in metrics:
        if verbose:
            print(f"Computing {metric_name}...")

        if device != "cpu" and backends["torch"]:
            # GPU path
            if metric_name == "wasserstein_1":
                gpu_metric = GPUWasserstein1(device=device)
                per_gene = gpu_metric.compute_batch(real, generated)
            elif metric_name == "wasserstein_2" and backends["geomloss"]:
                gpu_metric = GPUWasserstein2(device=device)
                per_gene = gpu_metric.compute_batch(real, generated)
            elif metric_name == "mmd":
                gpu_metric = GPUMMD(device=device)
                per_gene = gpu_metric.compute_batch(real, generated)
            elif metric_name == "energy":
                gpu_metric = GPUEnergyDistance(device=device)
                per_gene = gpu_metric.compute_batch(real, generated)
            else:
                # Fallback to vectorized CPU
                per_gene = _compute_cpu_metric(metric_name, real, generated, n_jobs)
        else:
            # CPU path
            per_gene = _compute_cpu_metric(metric_name, real, generated, n_jobs)

        results[metric_name] = MetricResult(
            name=metric_name,
            per_gene_values=per_gene,
            gene_names=gene_names,
            aggregate_value=float(np.nanmean(per_gene)),
            aggregate_method="mean",
            metadata={
                "device": device,
                "n_jobs": n_jobs,
                "accelerated": True,
            }
        )

    return results


def _compute_cpu_metric(
    metric_name: str,
    real: np.ndarray,
    generated: np.ndarray,
    n_jobs: int,
) -> np.ndarray:
    """Compute metric on CPU with optional parallelization."""
    if metric_name == "wasserstein_1":
        if n_jobs != 1 and HAS_JOBLIB:
            return _parallel_w1(real, generated, n_jobs)
        else:
            return vectorized_wasserstein1(real, generated)
    elif metric_name == "wasserstein_2":
        return _compute_w2_cpu(real, generated, n_jobs)
    elif metric_name == "mmd":
        if n_jobs != 1 and HAS_JOBLIB:
            return _parallel_mmd(real, generated, n_jobs)
        else:
            return vectorized_mmd(real, generated)
    elif metric_name == "energy":
        return _compute_energy_cpu(real, generated, n_jobs)
    else:
        raise ValueError(f"Unknown metric: {metric_name}")


def _parallel_w1(real: np.ndarray, generated: np.ndarray, n_jobs: int) -> np.ndarray:
    """Parallel W1 computation."""
    from scipy.stats import wasserstein_distance

    n_genes = real.shape[1]

    def compute_single(g):
        r = real[:, g]
        gen = generated[:, g]
        r = r[~np.isnan(r)]
        gen = gen[~np.isnan(gen)]
        if len(r) == 0 or len(gen) == 0:
            return np.nan
        return wasserstein_distance(r, gen)

    results = Parallel(n_jobs=n_jobs)(
        delayed(compute_single)(g) for g in range(n_genes)
    )

    return np.array(results)


def _parallel_mmd(real: np.ndarray, generated: np.ndarray, n_jobs: int) -> np.ndarray:
    """Parallel MMD computation."""
    n_genes = real.shape[1]

    def compute_single(g):
        x = real[:, g:g+1]
        y = generated[:, g:g+1]

        combined = np.vstack([x, y])
        pairwise = np.abs(combined - combined.T)
        sigma = float(np.median(pairwise[pairwise > 0]))
        if sigma == 0:
            sigma = 1.0

        n_x, n_y = len(x), len(y)

        K_xx = np.exp(-(x - x.T) ** 2 / (2 * sigma ** 2))
        K_yy = np.exp(-(y - y.T) ** 2 / (2 * sigma ** 2))
        K_xy = np.exp(-(x - y.T) ** 2 / (2 * sigma ** 2))

        mmd = (
            (np.sum(K_xx) - np.trace(K_xx)) / (n_x * (n_x - 1)) +
            (np.sum(K_yy) - np.trace(K_yy)) / (n_y * (n_y - 1)) -
            2 * np.sum(K_xy) / (n_x * n_y)
        )

        return max(0, mmd)

    results = Parallel(n_jobs=n_jobs)(
        delayed(compute_single)(g) for g in range(n_genes)
    )

    return np.array(results)


def _compute_w2_cpu(real: np.ndarray, generated: np.ndarray, n_jobs: int) -> np.ndarray:
    """CPU W2 computation (quantile-based)."""
    n_genes = real.shape[1]

    def compute_single(g):
        r = real[:, g]
        gen = generated[:, g]

        r = r[~np.isnan(r)]
        gen = gen[~np.isnan(gen)]

        if len(r) == 0 or len(gen) == 0:
            return np.nan

        r_sorted = np.sort(r)
        g_sorted = np.sort(gen)

        n = max(len(r_sorted), len(g_sorted))
        r_q = np.interp(np.linspace(0, 1, n), np.linspace(0, 1, len(r_sorted)), r_sorted)
        g_q = np.interp(np.linspace(0, 1, n), np.linspace(0, 1, len(g_sorted)), g_sorted)

        return np.sqrt(np.mean((r_q - g_q) ** 2))

    if n_jobs != 1 and HAS_JOBLIB:
        results = Parallel(n_jobs=n_jobs)(
            delayed(compute_single)(g) for g in range(n_genes)
        )
        return np.array(results)
    else:
        return np.array([compute_single(g) for g in range(n_genes)])


def _compute_energy_cpu(real: np.ndarray, generated: np.ndarray, n_jobs: int) -> np.ndarray:
    """CPU Energy distance computation."""
    n_genes = real.shape[1]

    def compute_single(g):
        x = real[:, g]
        y = generated[:, g]

        x = x[~np.isnan(x)]
        y = y[~np.isnan(y)]

        if len(x) < 2 or len(y) < 2:
            return np.nan

        xy_dist = np.mean(np.abs(x[:, np.newaxis] - y[np.newaxis, :]))
        xx_dist = np.mean(np.abs(x[:, np.newaxis] - x[np.newaxis, :]))
        yy_dist = np.mean(np.abs(y[:, np.newaxis] - y[np.newaxis, :]))

        return max(0, 2 * xy_dist - xx_dist - yy_dist)

    if n_jobs != 1 and HAS_JOBLIB:
        results = Parallel(n_jobs=n_jobs)(
            delayed(compute_single)(g) for g in range(n_genes)
        )
        return np.array(results)
    else:
        return np.array([compute_single(g) for g in range(n_genes)])
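As a sanity check on the unbiased MMD estimator that both `GPUMMD` and `vectorized_mmd` implement above (our illustration, not package code): with a fixed RBF bandwidth, the estimate should sit near zero for two samples from the same distribution and turn clearly positive once the distributions differ.

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(size=(400, 1))
y_same = rng.normal(size=(400, 1))
y_shift = rng.normal(1.0, 1.0, size=(400, 1))

def unbiased_mmd(a, b, s=1.0):
    # Same estimator as in the module: RBF kernel, diagonal terms removed.
    k = lambda u, v: np.exp(-(u - v.T) ** 2 / (2 * s ** 2))
    n, m = len(a), len(b)
    return ((k(a, a).sum() - np.trace(k(a, a))) / (n * (n - 1))
            + (k(b, b).sum() - np.trace(k(b, b))) / (m * (m - 1))
            - 2 * k(a, b).sum() / (n * m))

print(f"same distribution:    {unbiased_mmd(x, y_same):+.4f}")   # ~ 0
print(f"shifted distribution: {unbiased_mmd(x, y_shift):+.4f}")  # > 0
```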
All 27 remaining files (listed above with +0 -0) are unchanged between 0.2.1 and 0.3.0.