gengeneeval 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geneval/__init__.py +52 -1
- geneval/data/__init__.py +14 -0
- geneval/data/lazy_loader.py +562 -0
- geneval/evaluator.py +46 -0
- geneval/lazy_evaluator.py +424 -0
- geneval/metrics/__init__.py +25 -0
- geneval/metrics/accelerated.py +857 -0
- {gengeneeval-0.2.0.dist-info → gengeneeval-0.3.0.dist-info}/METADATA +111 -4
- {gengeneeval-0.2.0.dist-info → gengeneeval-0.3.0.dist-info}/RECORD +12 -9
- {gengeneeval-0.2.0.dist-info → gengeneeval-0.3.0.dist-info}/WHEEL +0 -0
- {gengeneeval-0.2.0.dist-info → gengeneeval-0.3.0.dist-info}/entry_points.txt +0 -0
- {gengeneeval-0.2.0.dist-info → gengeneeval-0.3.0.dist-info}/licenses/LICENSE +0 -0
geneval/evaluator.py
CHANGED

@@ -66,6 +66,10 @@ class GeneEvalEvaluator:
         Whether to include multivariate (whole-space) metrics
     verbose : bool
         Whether to print progress
+    n_jobs : int
+        Number of parallel CPU jobs. -1 uses all cores. Default is 1.
+    device : str
+        Compute device: "cpu", "cuda", "cuda:0", "auto". Default is "cpu".

     Examples
     --------
@@ -73,6 +77,10 @@ class GeneEvalEvaluator:
     >>> evaluator = GeneEvalEvaluator(loader)
     >>> results = evaluator.evaluate()
     >>> results.save("output/")
+
+    >>> # With acceleration
+    >>> evaluator = GeneEvalEvaluator(loader, n_jobs=8, device="cuda")
+    >>> results = evaluator.evaluate()
     """

     def __init__(
@@ -82,11 +90,15 @@ class GeneEvalEvaluator:
         aggregate_method: str = "mean",
         include_multivariate: bool = True,
         verbose: bool = True,
+        n_jobs: int = 1,
+        device: str = "cpu",
     ):
         self.data_loader = data_loader
         self.aggregate_method = aggregate_method
         self.include_multivariate = include_multivariate
         self.verbose = verbose
+        self.n_jobs = n_jobs
+        self.device = device

         # Initialize metrics
         self.metrics: List[BaseMetric] = []
@@ -106,6 +118,25 @@ class GeneEvalEvaluator:
             MultivariateWasserstein(),
             MultivariateMMD(),
         ])
+
+        # Initialize accelerated computer if using parallelization or GPU
+        self._parallel_computer = None
+        if n_jobs != 1 or device != "cpu":
+            try:
+                from .metrics.accelerated import ParallelMetricComputer
+                self._parallel_computer = ParallelMetricComputer(
+                    n_jobs=n_jobs,
+                    device=device,
+                    verbose=verbose,
+                )
+                if verbose:
+                    from .metrics.accelerated import get_available_backends
+                    backends = get_available_backends()
+                    self._log(f"Acceleration enabled: n_jobs={n_jobs}, device={device}")
+                    self._log(f"Available backends: {backends}")
+            except ImportError as e:
+                if verbose:
+                    self._log(f"Warning: Could not enable acceleration: {e}")

     def _log(self, msg: str):
         """Print message if verbose."""
@@ -262,6 +293,8 @@ def evaluate(
     metrics: Optional[List[Union[BaseMetric, Type[BaseMetric]]]] = None,
     include_multivariate: bool = True,
     verbose: bool = True,
+    n_jobs: int = 1,
+    device: str = "cpu",
     **loader_kwargs
 ) -> EvaluationResult:
     """
@@ -285,6 +318,10 @@ def evaluate(
         Whether to include multivariate metrics
     verbose : bool
         Print progress
+    n_jobs : int
+        Number of parallel CPU jobs. -1 uses all cores. Default is 1.
+    device : str
+        Compute device: "cpu", "cuda", "cuda:0", "auto". Default is "cpu".
     **loader_kwargs
         Additional arguments for data loader

@@ -295,6 +332,7 @@ def evaluate(

     Examples
     --------
+    >>> # Standard CPU evaluation
     >>> results = evaluate(
     ...     "real.h5ad",
     ...     "generated.h5ad",
@@ -302,6 +340,12 @@ def evaluate(
     ...     split_column="split",
     ...     output_dir="evaluation_output/"
     ... )
+
+    >>> # Parallel CPU evaluation (8 cores)
+    >>> results = evaluate(..., n_jobs=8)
+
+    >>> # GPU-accelerated evaluation
+    >>> results = evaluate(..., device="cuda")
     """
     # Load data
     loader = load_data(
@@ -318,6 +362,8 @@ def evaluate(
         metrics=metrics,
         include_multivariate=include_multivariate,
         verbose=verbose,
+        n_jobs=n_jobs,
+        device=device,
     )

     # Run evaluation
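The new `n_jobs` and `device` arguments degrade gracefully: if `geneval.metrics.accelerated` cannot be imported, the evaluator logs a warning and falls back to the single-process CPU path. Below is a minimal usage sketch (not part of the package) of the 0.3.0 arguments, assuming only what the diff above shows; loader-specific arguments such as condition columns are passed exactly as in 0.2.x and are omitted here.

# Sketch only: `evaluate` is defined in geneval/evaluator.py, and
# `get_available_backends` is re-exported from geneval.metrics (see the
# metrics/__init__.py diff further below).
from geneval.evaluator import evaluate
from geneval.metrics import get_available_backends

# Inspect which acceleration backends are importable before requesting one.
print(get_available_backends())

results = evaluate(
    "real.h5ad",
    "generated.h5ad",
    split_column="split",
    output_dir="evaluation_output/",
    n_jobs=8,        # parallel CPU workers; -1 would use all cores
    device="cpu",    # or "cuda" / "cuda:0" / "auto" when a GPU is available
)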
geneval/lazy_evaluator.py
ADDED

@@ -0,0 +1,424 @@
+"""
+Memory-efficient evaluator for large-scale gene expression datasets.
+
+Uses lazy loading and batched processing to minimize memory footprint.
+"""
+from __future__ import annotations
+
+from typing import Dict, List, Optional, Union, Type, Any, Generator
+from pathlib import Path
+import numpy as np
+import warnings
+from dataclasses import dataclass, field
+import gc
+
+from .data.lazy_loader import (
+    LazyGeneExpressionDataLoader,
+    load_data_lazy,
+    ConditionBatch,
+)
+from .metrics.base_metric import BaseMetric, MetricResult
+from .metrics.correlation import (
+    PearsonCorrelation,
+    SpearmanCorrelation,
+    MeanPearsonCorrelation,
+    MeanSpearmanCorrelation,
+)
+from .metrics.distances import (
+    Wasserstein1Distance,
+    Wasserstein2Distance,
+    MMDDistance,
+    EnergyDistance,
+)
+from .metrics.reconstruction import (
+    MSEDistance,
+)
+
+# These multivariate metrics don't support batched computation
+from .metrics.distances import MultivariateWasserstein, MultivariateMMD
+
+
+# Metrics that support incremental/batched computation
+BATCHABLE_METRICS = [
+    MSEDistance,
+    PearsonCorrelation,
+    SpearmanCorrelation,
+]
+
+# Metrics that require full data
+NON_BATCHABLE_METRICS = [
+    Wasserstein1Distance,
+    Wasserstein2Distance,
+    MMDDistance,
+    EnergyDistance,
+    MultivariateWasserstein,
+    MultivariateMMD,
+]
+
+
+@dataclass
+class StreamingMetricAccumulator:
+    """Accumulates values for streaming mean/std computation."""
+    n: int = 0
+    sum: float = 0.0
+    sum_sq: float = 0.0
+
+    def add(self, value: float, count: int = 1):
+        """Add a value (or batch of values with same value)."""
+        self.n += count
+        self.sum += value * count
+        self.sum_sq += (value ** 2) * count
+
+    def add_batch(self, values: np.ndarray):
+        """Add multiple values."""
+        self.n += len(values)
+        self.sum += np.sum(values)
+        self.sum_sq += np.sum(values ** 2)
+
+    @property
+    def mean(self) -> float:
+        return self.sum / self.n if self.n > 0 else 0.0
+
+    @property
+    def std(self) -> float:
+        if self.n <= 1:
+            return 0.0
+        variance = (self.sum_sq / self.n) - (self.mean ** 2)
+        return np.sqrt(max(0, variance))
+
+
+@dataclass
+class StreamingConditionResult:
+    """Lightweight result for a single condition."""
+    condition_key: str
+    n_real_samples: int = 0
+    n_generated_samples: int = 0
+    metrics: Dict[str, float] = field(default_factory=dict)
+    real_mean: Optional[np.ndarray] = None
+    generated_mean: Optional[np.ndarray] = None
+
+
+@dataclass
+class StreamingEvaluationResult:
+    """Memory-efficient evaluation result that streams to disk."""
+    output_dir: Path
+    n_conditions: int = 0
+    metric_accumulators: Dict[str, StreamingMetricAccumulator] = field(default_factory=dict)
+    condition_keys: List[str] = field(default_factory=list)
+
+    def add_condition(self, result: StreamingConditionResult):
+        """Add a condition result and update accumulators."""
+        self.n_conditions += 1
+        self.condition_keys.append(result.condition_key)
+
+        for metric_name, value in result.metrics.items():
+            if metric_name not in self.metric_accumulators:
+                self.metric_accumulators[metric_name] = StreamingMetricAccumulator()
+            self.metric_accumulators[metric_name].add(value)
+
+    def get_summary(self) -> Dict[str, Dict[str, float]]:
+        """Get summary statistics."""
+        summary = {}
+        for name, acc in self.metric_accumulators.items():
+            summary[name] = {
+                "mean": acc.mean,
+                "std": acc.std,
+                "n": acc.n,
+            }
+        return summary
+
+    def save_summary(self):
+        """Save summary to output directory."""
+        import json
+
+        self.output_dir.mkdir(parents=True, exist_ok=True)
+
+        summary = {
+            "n_conditions": self.n_conditions,
+            "metrics": self.get_summary(),
+            "condition_keys": self.condition_keys,
+        }
+
+        with open(self.output_dir / "summary.json", "w") as f:
+            json.dump(summary, f, indent=2)
+
+
+class MemoryEfficientEvaluator:
+    """
+    Memory-efficient evaluator using lazy loading and batched processing.
+
+    Features:
+    - Lazy data loading (one condition at a time)
+    - Batched processing within conditions
+    - Streaming metric accumulation
+    - Periodic garbage collection
+    - Progress streaming to disk
+
+    Parameters
+    ----------
+    data_loader : LazyGeneExpressionDataLoader
+        Lazy data loader
+    metrics : List[BaseMetric], optional
+        Metrics to compute. Note: Some metrics (like MMD) may not support
+        batched computation and will use full condition data.
+    batch_size : int
+        Batch size for within-condition processing
+    gc_every_n_conditions : int
+        Run garbage collection every N conditions
+    verbose : bool
+        Print progress
+    """
+
+    def __init__(
+        self,
+        data_loader: LazyGeneExpressionDataLoader,
+        metrics: Optional[List[Union[BaseMetric, Type[BaseMetric]]]] = None,
+        batch_size: int = 256,
+        gc_every_n_conditions: int = 10,
+        verbose: bool = True,
+    ):
+        self.data_loader = data_loader
+        self.batch_size = batch_size
+        self.gc_every_n_conditions = gc_every_n_conditions
+        self.verbose = verbose
+
+        # Initialize metrics
+        self.metrics: List[BaseMetric] = []
+        metric_classes = metrics or [
+            MSEDistance,
+            PearsonCorrelation,
+            SpearmanCorrelation,
+            MeanPearsonCorrelation,
+            MeanSpearmanCorrelation,
+        ]
+
+        for m in metric_classes:
+            if isinstance(m, type):
+                self.metrics.append(m())
+            else:
+                self.metrics.append(m)
+
+    def _log(self, msg: str):
+        if self.verbose:
+            print(msg)
+
+    def evaluate(
+        self,
+        split: Optional[str] = None,
+        output_dir: Optional[Union[str, Path]] = None,
+        save_per_condition: bool = False,
+    ) -> StreamingEvaluationResult:
+        """
+        Run memory-efficient evaluation.
+
+        Parameters
+        ----------
+        split : str, optional
+            Split to evaluate
+        output_dir : str or Path, optional
+            Directory to save results. If provided, results are streamed to disk.
+        save_per_condition : bool
+            If True, save individual condition results to disk
+
+        Returns
+        -------
+        StreamingEvaluationResult
+            Evaluation result with aggregated metrics
+        """
+        if output_dir is not None:
+            output_dir = Path(output_dir)
+            output_dir.mkdir(parents=True, exist_ok=True)
+        else:
+            output_dir = Path(".")
+
+        result = StreamingEvaluationResult(output_dir=output_dir)
+
+        # Get conditions
+        conditions = self.data_loader.get_common_conditions(split)
+        self._log(f"Evaluating {len(conditions)} conditions")
+        self._log(f"Memory estimate: {self.data_loader.estimate_memory_usage()}")
+
+        # Iterate conditions (one at a time in memory)
+        for i, (cond_key, real_data, gen_data, cond_info) in enumerate(
+            self.data_loader.iterate_conditions(split)
+        ):
+            if self.verbose and (i + 1) % 10 == 0:
+                self._log(f"  Processing {i + 1}/{len(conditions)}: {cond_key}")
+
+            # Compute metrics for this condition
+            cond_result = self._evaluate_condition(
+                cond_key, real_data, gen_data, cond_info
+            )
+
+            # Add to streaming result
+            result.add_condition(cond_result)
+
+            # Optionally save per-condition result
+            if save_per_condition and output_dir:
+                self._save_condition_result(cond_result, output_dir)
+
+            # Periodic garbage collection
+            if (i + 1) % self.gc_every_n_conditions == 0:
+                gc.collect()
+
+        # Final summary
+        result.save_summary()
+
+        if self.verbose:
+            self._print_summary(result)
+
+        return result
+
+    def _evaluate_condition(
+        self,
+        cond_key: str,
+        real_data: np.ndarray,
+        gen_data: np.ndarray,
+        cond_info: Dict[str, str],
+    ) -> StreamingConditionResult:
+        """Evaluate a single condition."""
+        result = StreamingConditionResult(
+            condition_key=cond_key,
+            n_real_samples=real_data.shape[0],
+            n_generated_samples=gen_data.shape[0],
+        )
+
+        # Compute means
+        result.real_mean = real_data.mean(axis=0)
+        result.generated_mean = gen_data.mean(axis=0)
+
+        # Compute metrics
+        for metric in self.metrics:
+            try:
+                metric_result = metric.compute(
+                    real=real_data,
+                    generated=gen_data,
+                    gene_names=self.data_loader.gene_names,
+                    aggregate_method="mean",
+                    condition=cond_key,
+                )
+                result.metrics[metric.name] = metric_result.aggregate_value
+            except Exception as e:
+                warnings.warn(f"Failed to compute {metric.name} for {cond_key}: {e}")
+
+        return result
+
+    def _save_condition_result(
+        self,
+        result: StreamingConditionResult,
+        output_dir: Path,
+    ):
+        """Save a single condition result to disk."""
+        import json
+
+        condition_dir = output_dir / "conditions"
+        condition_dir.mkdir(exist_ok=True)
+
+        # Safe filename
+        safe_key = result.condition_key.replace("/", "_").replace("\\", "_")
+
+        data = {
+            "condition_key": result.condition_key,
+            "n_real": result.n_real_samples,
+            "n_generated": result.n_generated_samples,
+            "metrics": result.metrics,
+        }
+
+        with open(condition_dir / f"{safe_key}.json", "w") as f:
+            json.dump(data, f, indent=2)
+
+    def _print_summary(self, result: StreamingEvaluationResult):
+        """Print summary."""
+        self._log("\n" + "=" * 60)
+        self._log("EVALUATION SUMMARY (Memory-Efficient)")
+        self._log("=" * 60)
+        self._log(f"Conditions evaluated: {result.n_conditions}")
+        self._log("-" * 40)
+
+        for name, stats in result.get_summary().items():
+            self._log(f"  {name}: {stats['mean']:.4f} ± {stats['std']:.4f}")
+
+        self._log("=" * 60)
+
+
+def evaluate_lazy(
+    real_path: Union[str, Path],
+    generated_path: Union[str, Path],
+    condition_columns: List[str],
+    split_column: Optional[str] = None,
+    output_dir: Optional[Union[str, Path]] = None,
+    batch_size: int = 256,
+    use_backed: bool = False,
+    metrics: Optional[List[Union[BaseMetric, Type[BaseMetric]]]] = None,
+    verbose: bool = True,
+    save_per_condition: bool = False,
+    **kwargs
+) -> StreamingEvaluationResult:
+    """
+    Memory-efficient evaluation using lazy loading.
+
+    Use this function for large datasets that don't fit in memory.
+
+    Parameters
+    ----------
+    real_path : str or Path
+        Path to real data h5ad file
+    generated_path : str or Path
+        Path to generated data h5ad file
+    condition_columns : List[str]
+        Columns to match between datasets
+    split_column : str, optional
+        Column for train/test split
+    output_dir : str or Path, optional
+        Directory to save results
+    batch_size : int
+        Batch size for processing
+    use_backed : bool
+        Use memory-mapped file access (for very large files)
+    metrics : List, optional
+        Metrics to compute
+    verbose : bool
+        Print progress
+    save_per_condition : bool
+        Save individual condition results
+
+    Returns
+    -------
+    StreamingEvaluationResult
+        Aggregated evaluation results
+
+    Examples
+    --------
+    >>> # For large datasets that don't fit in memory
+    >>> results = evaluate_lazy(
+    ...     "real.h5ad",
+    ...     "generated.h5ad",
+    ...     condition_columns=["perturbation"],
+    ...     output_dir="eval_output/",
+    ...     batch_size=256,
+    ...     use_backed=True,  # Memory-mapped for very large files
+    ... )
+    >>> print(results.get_summary())
+    """
+    # Create lazy loader
+    with load_data_lazy(
+        real_path=real_path,
+        generated_path=generated_path,
+        condition_columns=condition_columns,
+        split_column=split_column,
+        batch_size=batch_size,
+        use_backed=use_backed,
+    ) as loader:
+        # Create evaluator
+        evaluator = MemoryEfficientEvaluator(
+            data_loader=loader,
+            metrics=metrics,
+            batch_size=batch_size,
+            verbose=verbose,
+        )
+
+        # Run evaluation
+        return evaluator.evaluate(
+            output_dir=output_dir,
+            save_per_condition=save_per_condition,
+        )
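StreamingEvaluationResult aggregates per-condition scores with running sums rather than keeping every score in memory, using the population-variance identity Var(x) = E[x²] − E[x]². A standalone sketch (independent of the package) showing that this sum / sum-of-squares accumulation reproduces NumPy's mean and std:

# Verifies the accumulation scheme used by StreamingMetricAccumulator:
# values arrive in batches, and only n, the sum, and the sum of squares
# are retained.
import numpy as np

values = np.random.default_rng(0).normal(size=1_000)

n, s, s_sq = 0, 0.0, 0.0
for chunk in np.array_split(values, 10):   # batches, like per-condition scores
    n += len(chunk)
    s += float(chunk.sum())
    s_sq += float(np.sum(chunk ** 2))

mean = s / n
std = np.sqrt(max(0.0, s_sq / n - mean ** 2))

assert np.isclose(mean, values.mean())
assert np.isclose(std, values.std())       # ddof=0, as in the accumulator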
geneval/metrics/__init__.py
CHANGED

@@ -35,6 +35,20 @@ from .reconstruction import (
     R2Score,
 )

+# Accelerated computation
+from .accelerated import (
+    AccelerationConfig,
+    ParallelMetricComputer,
+    get_available_backends,
+    compute_metrics_accelerated,
+    GPUWasserstein1,
+    GPUWasserstein2,
+    GPUMMD,
+    GPUEnergyDistance,
+    vectorized_wasserstein1,
+    vectorized_mmd,
+)
+
 # All available metrics
 ALL_METRICS = [
     # Reconstruction
@@ -81,4 +95,15 @@ __all__ = [
     "MultivariateMMD",
     # Collections
     "ALL_METRICS",
+    # Acceleration
+    "AccelerationConfig",
+    "ParallelMetricComputer",
+    "get_available_backends",
+    "compute_metrics_accelerated",
+    "GPUWasserstein1",
+    "GPUWasserstein2",
+    "GPUMMD",
+    "GPUEnergyDistance",
+    "vectorized_wasserstein1",
+    "vectorized_mmd",
 ]
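The re-exports above make the acceleration helpers importable directly from `geneval.metrics`. A short sketch using only the constructor arguments that appear elsewhere in this diff; the signatures of `compute_metrics_accelerated` and the `GPU*` metric classes are not shown here, so they are not called:

# Sketch only: n_jobs/device/verbose are the ParallelMetricComputer arguments
# used by GeneEvalEvaluator in evaluator.py; other signatures are not visible
# in this diff.
from geneval.metrics import (
    ParallelMetricComputer,
    get_available_backends,
)

print(get_available_backends())   # report which backends can be used

computer = ParallelMetricComputer(
    n_jobs=-1,       # use all CPU cores
    device="auto",   # prefer GPU when one is usable, otherwise CPU
    verbose=True,
)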