gengeneeval 0.2.1__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geneval/__init__.py +56 -1
- geneval/deg/__init__.py +65 -0
- geneval/deg/context.py +271 -0
- geneval/deg/detection.py +578 -0
- geneval/deg/evaluator.py +538 -0
- geneval/deg/visualization.py +376 -0
- geneval/evaluator.py +46 -0
- geneval/metrics/__init__.py +25 -0
- geneval/metrics/accelerated.py +857 -0
- {gengeneeval-0.2.1.dist-info → gengeneeval-0.4.0.dist-info}/METADATA +164 -3
- {gengeneeval-0.2.1.dist-info → gengeneeval-0.4.0.dist-info}/RECORD +14 -8
- {gengeneeval-0.2.1.dist-info → gengeneeval-0.4.0.dist-info}/WHEEL +0 -0
- {gengeneeval-0.2.1.dist-info → gengeneeval-0.4.0.dist-info}/entry_points.txt +0 -0
- {gengeneeval-0.2.1.dist-info → gengeneeval-0.4.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,376 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Visualization module for DEG evaluation results.
|
|
3
|
+
|
|
4
|
+
Provides:
|
|
5
|
+
- Distribution plots for metrics across contexts
|
|
6
|
+
- Heatmaps for context × metric results
|
|
7
|
+
- Comprehensive DEG reports
|
|
8
|
+
"""
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
from typing import Optional, List, Dict, Union, Any, TYPE_CHECKING
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
import numpy as np
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
import matplotlib.pyplot as plt
|
|
17
|
+
import pandas as pd
|
|
18
|
+
from .evaluator import DEGEvaluationResult
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _check_matplotlib():
    """
    Import ``matplotlib.pyplot`` lazily and return it.

    The import happens inside the function so that the rest of the module
    can be used without matplotlib installed.

    Returns
    -------
    module
        The ``matplotlib.pyplot`` module.

    Raises
    ------
    ImportError
        If matplotlib cannot be imported. The original import failure is
        attached as the explicit cause.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError as err:
        # Chain explicitly so the underlying reason stays visible in tracebacks.
        raise ImportError(
            "matplotlib is required for visualization. Install with: pip install matplotlib"
        ) from err
    return plt
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _check_seaborn():
    """Return the seaborn module, or ``None`` when it cannot be imported."""
    module = None  # Seaborn is optional
    try:
        import seaborn
    except ImportError:
        pass
    else:
        module = seaborn
    return module
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def plot_deg_distributions(
    results: "DEGEvaluationResult",
    metrics: Optional[List[str]] = None,
    figsize: tuple = (12, 4),
    save_path: Optional[Union[str, Path]] = None,
    show: bool = True,
) -> "plt.Figure":
    """
    Plot distribution of metrics across contexts.

    Draws one panel per metric: a violin plot when seaborn can be imported,
    otherwise a matplotlib box plot. A dashed red line marks the mean of the
    non-missing values in each panel.

    Parameters
    ----------
    results : DEGEvaluationResult
        DEG evaluation results; values are read from
        ``results.expanded_metrics``.
    metrics : List[str], optional
        Metrics to plot. If None, uses the columns of
        ``results.expanded_metrics`` that are also listed in
        ``results.settings["metrics"]``.
    figsize : tuple
        Size of the whole figure (shared by all panels), as passed to
        ``plt.subplots``.
    save_path : str or Path, optional
        Path to save figure
    show : bool
        Whether to display the figure

    Returns
    -------
    plt.Figure
        Matplotlib figure

    Raises
    ------
    ValueError
        If there are no metrics to plot.
    KeyError
        If a requested metric is not a column of
        ``results.expanded_metrics``.
    ImportError
        If matplotlib is not installed.
    """
    plt = _check_matplotlib()
    sns = _check_seaborn()

    df = results.expanded_metrics

    if metrics is None:
        # Keep only the columns that were configured as metrics
        configured = results.settings.get("metrics", [])
        metrics = [c for c in df.columns if c in configured]
    else:
        metrics = list(metrics)

    if not metrics:
        raise ValueError("No metrics to plot")

    # Fail early with a readable message instead of a bare pandas KeyError.
    missing = [m for m in metrics if m not in df.columns]
    if missing:
        raise KeyError(f"Metrics not found in expanded_metrics: {missing}")

    # squeeze=False always returns a 2-D array of axes, so the single-metric
    # case needs no special handling.
    fig, axes = plt.subplots(
        1, len(metrics), figsize=(figsize[0], figsize[1]), squeeze=False
    )

    for ax, metric in zip(axes[0], metrics):
        values = df[metric].dropna()

        if values.empty:
            # Nothing to draw: label the panel instead of plotting an empty
            # distribution and a "Mean: nan" legend entry.
            ax.set_title(metric)
            ax.set_ylabel("Value")
            ax.text(
                0.5, 0.5, "No data",
                ha='center', va='center', transform=ax.transAxes,
            )
            continue

        if sns is not None:
            sns.violinplot(y=values, ax=ax, inner="box")
        else:
            ax.boxplot(values, vert=True)

        # Set labels after plotting so they override any label the plotting
        # call put on the axis.
        ax.set_title(metric)
        ax.set_ylabel("Value")

        # Add mean annotation
        mean_val = values.mean()
        ax.axhline(mean_val, color='red', linestyle='--', alpha=0.5, label=f'Mean: {mean_val:.3f}')
        ax.legend(loc='upper right', fontsize=8)

    # Operate on this figure explicitly rather than on pyplot's current figure.
    fig.suptitle("Metric Distributions Across Contexts (DEGs only)", y=1.02)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')

    if show:
        plt.show()

    return fig
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def plot_context_heatmap(
    results: "DEGEvaluationResult",
    metrics: Optional[List[str]] = None,
    figsize: tuple = (12, 8),
    cmap: str = "RdYlBu_r",
    save_path: Optional[Union[str, Path]] = None,
    show: bool = True,
) -> "plt.Figure":
    """
    Plot heatmap of metrics across contexts.

    Rows are the ``context_id`` values of ``results.expanded_metrics`` and
    columns are the selected metrics. Uses an annotated seaborn heatmap when
    seaborn can be imported, otherwise a plain ``imshow`` with a colorbar.

    Parameters
    ----------
    results : DEGEvaluationResult
        DEG evaluation results
    metrics : List[str], optional
        Metrics to include. If None, uses the columns of
        ``results.expanded_metrics`` that are also listed in
        ``results.settings["metrics"]``.
    figsize : tuple
        Figure size
    cmap : str
        Colormap name
    save_path : str or Path, optional
        Path to save figure
    show : bool
        Whether to display

    Returns
    -------
    plt.Figure
        Matplotlib figure

    Raises
    ------
    ValueError
        If there are no metrics to plot.
    ImportError
        If matplotlib is not installed.
    """
    plt = _check_matplotlib()
    sns = _check_seaborn()

    df = results.expanded_metrics

    if metrics is None:
        configured = results.settings.get("metrics", [])
        metrics = [c for c in df.columns if c in configured]
    else:
        metrics = list(metrics)

    # Same guard as plot_deg_distributions: an empty selection would otherwise
    # fail later with an opaque error from pandas or the plotting backend.
    if not metrics:
        raise ValueError("No metrics to plot")

    # Create matrix for heatmap
    heatmap_data = df.set_index("context_id")[metrics]

    fig, ax = plt.subplots(figsize=figsize)

    if sns is not None:
        sns.heatmap(
            heatmap_data,
            ax=ax,
            cmap=cmap,
            annot=True,
            fmt=".3f",
            cbar_kws={"label": "Metric Value"},
        )
    else:
        im = ax.imshow(heatmap_data.values, cmap=cmap, aspect='auto')
        fig.colorbar(im, ax=ax, label="Metric Value")
        # imshow does not label the axes; add the tick labels by hand.
        ax.set_xticks(range(len(metrics)))
        ax.set_xticklabels(metrics, rotation=45, ha='right')
        ax.set_yticks(range(len(heatmap_data)))
        ax.set_yticklabels(heatmap_data.index)

    ax.set_title("Metrics per Context (DEGs only)")
    ax.set_xlabel("Metric")
    ax.set_ylabel("Context")

    # Operate on this figure explicitly rather than on pyplot's current figure.
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')

    if show:
        plt.show()

    return fig
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def plot_deg_counts(
    results: "DEGEvaluationResult",
    figsize: tuple = (10, 6),
    save_path: Optional[Union[str, Path]] = None,
    show: bool = True,
) -> "plt.Figure":
    """
    Plot DEG counts per context.

    When ``results.deg_summary`` has both ``n_upregulated`` and
    ``n_downregulated`` columns, upregulated counts are drawn as red bars
    above zero and downregulated counts as blue bars below zero. Otherwise
    the ``n_degs`` column is drawn as a single purple bar series.

    Parameters
    ----------
    results : DEGEvaluationResult
        DEG evaluation results
    figsize : tuple
        Figure size
    save_path : str or Path, optional
        Path to save figure
    show : bool
        Whether to display

    Returns
    -------
    plt.Figure
        Matplotlib figure

    Raises
    ------
    ValueError
        If ``results.deg_summary`` is empty.
    ImportError
        If matplotlib is not installed.
    """
    plt = _check_matplotlib()

    df = results.deg_summary

    if len(df) == 0:
        raise ValueError("No DEG data to plot")

    fig, ax = plt.subplots(figsize=figsize)

    x = range(len(df))
    width = 0.35

    if "n_upregulated" in df.columns and "n_downregulated" in df.columns:
        ax.bar(x, df["n_upregulated"], width, label='Upregulated', color='red', alpha=0.7)
        # Negate so downregulated bars extend below the zero line.
        ax.bar(x, -df["n_downregulated"], width, label='Downregulated', color='blue', alpha=0.7)
        ax.axhline(0, color='black', linewidth=0.5)
        # Only this branch has labelled artists; calling legend() without any
        # makes matplotlib emit a "No artists with labels found" warning.
        ax.legend()
    else:
        ax.bar(x, df["n_degs"], color='purple', alpha=0.7)

    ax.set_ylabel("Number of DEGs")
    ax.set_xlabel("Context")
    ax.set_xticks(x)
    ax.set_xticklabels(df["context_id"], rotation=45, ha='right')
    ax.set_title("DEG Counts per Context")

    # Operate on this figure explicitly rather than on pyplot's current figure.
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')

    if show:
        plt.show()

    return fig
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
def create_deg_report(
    results: "DEGEvaluationResult",
    output_dir: Union[str, Path],
    include_plots: bool = True,
) -> None:
    """
    Create a comprehensive DEG evaluation report.

    Creates:
    - The files written by ``results.save(output_dir)``
    - Visualization plots (PNG, if include_plots=True and matplotlib is
      importable; each plot is skipped when its source table is empty)
    - Markdown report (``DEG_REPORT.md``, UTF-8)

    The "Files Generated" section of the report lists only the plots that
    were actually written.

    Parameters
    ----------
    results : DEGEvaluationResult
        DEG evaluation results
    output_dir : str or Path
        Output directory; created (with parents) if it does not exist
    include_plots : bool
        Whether to generate plots
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Save CSVs
    results.save(output_dir)

    # Markdown bullet lines for the plots that really ended up on disk.
    generated_plots: List[str] = []

    if include_plots:
        try:
            plt = _check_matplotlib()
        except ImportError:
            plt = None  # Plots are best-effort: skip them without matplotlib

        if plt is not None:
            has_metrics = len(results.expanded_metrics) > 0
            has_degs = len(results.deg_summary) > 0
            plot_jobs = [
                (has_metrics, plot_deg_distributions,
                 "deg_metric_distributions.png", "Distribution plots"),
                # The heatmap indexes expanded_metrics by "context_id", so it
                # needs the same non-empty guard as the distribution plot.
                (has_metrics, plot_context_heatmap,
                 "deg_context_heatmap.png", "Context × metric heatmap"),
                (has_degs, plot_deg_counts,
                 "deg_counts.png", "DEG counts per context"),
            ]
            for enabled, plot_fn, filename, description in plot_jobs:
                if not enabled:
                    continue
                fig = plot_fn(results, save_path=output_dir / filename, show=False)
                # Close the figure we just created, not whatever is "current".
                plt.close(fig)
                generated_plots.append(f"- `{filename}`: {description}")

    # Create markdown report
    report_lines = [
        "# DEG Evaluation Report",
        "",
        "## Summary",
        "",
        f"- **Number of contexts evaluated**: {len(results.context_results)}",
        f"- **DEG detection method**: {results.settings.get('deg_method', 'N/A')}",
        f"- **P-value threshold**: {results.settings.get('pval_threshold', 'N/A')}",
        f"- **Log fold change threshold**: {results.settings.get('lfc_threshold', 'N/A')}",
        f"- **Compute device**: {results.settings.get('device', 'N/A')}",
        "",
        "## Aggregated Metrics",
        "",
    ]

    if len(results.aggregated_metrics) > 0:
        # Only the first row of the aggregated table is reported.
        for col in results.aggregated_metrics.columns:
            val = results.aggregated_metrics[col].iloc[0]
            # np.floating covers NumPy floats that do not subclass float
            # (e.g. float32), so they get the same 4-decimal formatting.
            if isinstance(val, (float, np.floating)):
                report_lines.append(f"- **{col}**: {val:.4f}")
            else:
                report_lines.append(f"- **{col}**: {val}")
    else:
        report_lines.append("No aggregated metrics available.")

    report_lines.extend([
        "",
        "## DEG Summary",
        "",
    ])

    if len(results.deg_summary) > 0:
        try:
            summary_table = results.deg_summary.to_markdown(index=False)
        except ImportError:
            # DataFrame.to_markdown needs the optional `tabulate` package;
            # fall back to a fenced plain-text table rather than failing
            # after all other outputs have been written.
            summary_table = "\n".join([
                "```",
                results.deg_summary.to_string(index=False),
                "```",
            ])
        report_lines.append(summary_table)
    else:
        report_lines.append("No DEG data available.")

    report_lines.extend([
        "",
        "## Files Generated",
        "",
        "- `deg_aggregated_metrics.csv`: Summary metrics across all contexts",
        "- `deg_expanded_metrics.csv`: Per-context metrics",
        "- `deg_summary.csv`: DEG detection summary",
        "- `deg_per_context/`: Individual DEG results per context",
    ])
    report_lines.extend(generated_plots)

    # The report contains non-ASCII text ("×"), so do not depend on the
    # platform's default encoding.
    with open(output_dir / "DEG_REPORT.md", "w", encoding="utf-8") as f:
        f.write("\n".join(report_lines))

    print(f"DEG report saved to {output_dir}")
|
geneval/evaluator.py
CHANGED
|
@@ -66,6 +66,10 @@ class GeneEvalEvaluator:
|
|
|
66
66
|
Whether to include multivariate (whole-space) metrics
|
|
67
67
|
verbose : bool
|
|
68
68
|
Whether to print progress
|
|
69
|
+
n_jobs : int
|
|
70
|
+
Number of parallel CPU jobs. -1 uses all cores. Default is 1.
|
|
71
|
+
device : str
|
|
72
|
+
Compute device: "cpu", "cuda", "cuda:0", "auto". Default is "cpu".
|
|
69
73
|
|
|
70
74
|
Examples
|
|
71
75
|
--------
|
|
@@ -73,6 +77,10 @@ class GeneEvalEvaluator:
|
|
|
73
77
|
>>> evaluator = GeneEvalEvaluator(loader)
|
|
74
78
|
>>> results = evaluator.evaluate()
|
|
75
79
|
>>> results.save("output/")
|
|
80
|
+
|
|
81
|
+
>>> # With acceleration
|
|
82
|
+
>>> evaluator = GeneEvalEvaluator(loader, n_jobs=8, device="cuda")
|
|
83
|
+
>>> results = evaluator.evaluate()
|
|
76
84
|
"""
|
|
77
85
|
|
|
78
86
|
def __init__(
|
|
@@ -82,11 +90,15 @@ class GeneEvalEvaluator:
|
|
|
82
90
|
aggregate_method: str = "mean",
|
|
83
91
|
include_multivariate: bool = True,
|
|
84
92
|
verbose: bool = True,
|
|
93
|
+
n_jobs: int = 1,
|
|
94
|
+
device: str = "cpu",
|
|
85
95
|
):
|
|
86
96
|
self.data_loader = data_loader
|
|
87
97
|
self.aggregate_method = aggregate_method
|
|
88
98
|
self.include_multivariate = include_multivariate
|
|
89
99
|
self.verbose = verbose
|
|
100
|
+
self.n_jobs = n_jobs
|
|
101
|
+
self.device = device
|
|
90
102
|
|
|
91
103
|
# Initialize metrics
|
|
92
104
|
self.metrics: List[BaseMetric] = []
|
|
@@ -106,6 +118,25 @@ class GeneEvalEvaluator:
|
|
|
106
118
|
MultivariateWasserstein(),
|
|
107
119
|
MultivariateMMD(),
|
|
108
120
|
])
|
|
121
|
+
|
|
122
|
+
# Initialize accelerated computer if using parallelization or GPU
|
|
123
|
+
self._parallel_computer = None
|
|
124
|
+
if n_jobs != 1 or device != "cpu":
|
|
125
|
+
try:
|
|
126
|
+
from .metrics.accelerated import ParallelMetricComputer
|
|
127
|
+
self._parallel_computer = ParallelMetricComputer(
|
|
128
|
+
n_jobs=n_jobs,
|
|
129
|
+
device=device,
|
|
130
|
+
verbose=verbose,
|
|
131
|
+
)
|
|
132
|
+
if verbose:
|
|
133
|
+
from .metrics.accelerated import get_available_backends
|
|
134
|
+
backends = get_available_backends()
|
|
135
|
+
self._log(f"Acceleration enabled: n_jobs={n_jobs}, device={device}")
|
|
136
|
+
self._log(f"Available backends: {backends}")
|
|
137
|
+
except ImportError as e:
|
|
138
|
+
if verbose:
|
|
139
|
+
self._log(f"Warning: Could not enable acceleration: {e}")
|
|
109
140
|
|
|
110
141
|
def _log(self, msg: str):
|
|
111
142
|
"""Print message if verbose."""
|
|
@@ -262,6 +293,8 @@ def evaluate(
|
|
|
262
293
|
metrics: Optional[List[Union[BaseMetric, Type[BaseMetric]]]] = None,
|
|
263
294
|
include_multivariate: bool = True,
|
|
264
295
|
verbose: bool = True,
|
|
296
|
+
n_jobs: int = 1,
|
|
297
|
+
device: str = "cpu",
|
|
265
298
|
**loader_kwargs
|
|
266
299
|
) -> EvaluationResult:
|
|
267
300
|
"""
|
|
@@ -285,6 +318,10 @@ def evaluate(
|
|
|
285
318
|
Whether to include multivariate metrics
|
|
286
319
|
verbose : bool
|
|
287
320
|
Print progress
|
|
321
|
+
n_jobs : int
|
|
322
|
+
Number of parallel CPU jobs. -1 uses all cores. Default is 1.
|
|
323
|
+
device : str
|
|
324
|
+
Compute device: "cpu", "cuda", "cuda:0", "auto". Default is "cpu".
|
|
288
325
|
**loader_kwargs
|
|
289
326
|
Additional arguments for data loader
|
|
290
327
|
|
|
@@ -295,6 +332,7 @@ def evaluate(
|
|
|
295
332
|
|
|
296
333
|
Examples
|
|
297
334
|
--------
|
|
335
|
+
>>> # Standard CPU evaluation
|
|
298
336
|
>>> results = evaluate(
|
|
299
337
|
... "real.h5ad",
|
|
300
338
|
... "generated.h5ad",
|
|
@@ -302,6 +340,12 @@ def evaluate(
|
|
|
302
340
|
... split_column="split",
|
|
303
341
|
... output_dir="evaluation_output/"
|
|
304
342
|
... )
|
|
343
|
+
|
|
344
|
+
>>> # Parallel CPU evaluation (8 cores)
|
|
345
|
+
>>> results = evaluate(..., n_jobs=8)
|
|
346
|
+
|
|
347
|
+
>>> # GPU-accelerated evaluation
|
|
348
|
+
>>> results = evaluate(..., device="cuda")
|
|
305
349
|
"""
|
|
306
350
|
# Load data
|
|
307
351
|
loader = load_data(
|
|
@@ -318,6 +362,8 @@ def evaluate(
|
|
|
318
362
|
metrics=metrics,
|
|
319
363
|
include_multivariate=include_multivariate,
|
|
320
364
|
verbose=verbose,
|
|
365
|
+
n_jobs=n_jobs,
|
|
366
|
+
device=device,
|
|
321
367
|
)
|
|
322
368
|
|
|
323
369
|
# Run evaluation
|
geneval/metrics/__init__.py
CHANGED
|
@@ -35,6 +35,20 @@ from .reconstruction import (
|
|
|
35
35
|
R2Score,
|
|
36
36
|
)
|
|
37
37
|
|
|
38
|
+
# Accelerated computation
|
|
39
|
+
from .accelerated import (
|
|
40
|
+
AccelerationConfig,
|
|
41
|
+
ParallelMetricComputer,
|
|
42
|
+
get_available_backends,
|
|
43
|
+
compute_metrics_accelerated,
|
|
44
|
+
GPUWasserstein1,
|
|
45
|
+
GPUWasserstein2,
|
|
46
|
+
GPUMMD,
|
|
47
|
+
GPUEnergyDistance,
|
|
48
|
+
vectorized_wasserstein1,
|
|
49
|
+
vectorized_mmd,
|
|
50
|
+
)
|
|
51
|
+
|
|
38
52
|
# All available metrics
|
|
39
53
|
ALL_METRICS = [
|
|
40
54
|
# Reconstruction
|
|
@@ -81,4 +95,15 @@ __all__ = [
|
|
|
81
95
|
"MultivariateMMD",
|
|
82
96
|
# Collections
|
|
83
97
|
"ALL_METRICS",
|
|
98
|
+
# Acceleration
|
|
99
|
+
"AccelerationConfig",
|
|
100
|
+
"ParallelMetricComputer",
|
|
101
|
+
"get_available_backends",
|
|
102
|
+
"compute_metrics_accelerated",
|
|
103
|
+
"GPUWasserstein1",
|
|
104
|
+
"GPUWasserstein2",
|
|
105
|
+
"GPUMMD",
|
|
106
|
+
"GPUEnergyDistance",
|
|
107
|
+
"vectorized_wasserstein1",
|
|
108
|
+
"vectorized_mmd",
|
|
84
109
|
]
|