gengeneeval 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geneval/evaluator.py CHANGED
@@ -66,6 +66,10 @@ class GeneEvalEvaluator:
         Whether to include multivariate (whole-space) metrics
     verbose : bool
         Whether to print progress
+    n_jobs : int
+        Number of parallel CPU jobs. -1 uses all cores. Default is 1.
+    device : str
+        Compute device: "cpu", "cuda", "cuda:0", "auto". Default is "cpu".
 
     Examples
     --------
@@ -73,6 +77,10 @@ class GeneEvalEvaluator:
     >>> evaluator = GeneEvalEvaluator(loader)
     >>> results = evaluator.evaluate()
     >>> results.save("output/")
+
+    >>> # With acceleration
+    >>> evaluator = GeneEvalEvaluator(loader, n_jobs=8, device="cuda")
+    >>> results = evaluator.evaluate()
     """
 
     def __init__(
@@ -82,11 +90,15 @@ class GeneEvalEvaluator:
         aggregate_method: str = "mean",
         include_multivariate: bool = True,
         verbose: bool = True,
+        n_jobs: int = 1,
+        device: str = "cpu",
     ):
         self.data_loader = data_loader
         self.aggregate_method = aggregate_method
         self.include_multivariate = include_multivariate
         self.verbose = verbose
+        self.n_jobs = n_jobs
+        self.device = device
 
         # Initialize metrics
         self.metrics: List[BaseMetric] = []
@@ -106,6 +118,25 @@ class GeneEvalEvaluator:
             MultivariateWasserstein(),
             MultivariateMMD(),
         ])
+
+        # Initialize accelerated computer if using parallelization or GPU
+        self._parallel_computer = None
+        if n_jobs != 1 or device != "cpu":
+            try:
+                from .metrics.accelerated import ParallelMetricComputer
+                self._parallel_computer = ParallelMetricComputer(
+                    n_jobs=n_jobs,
+                    device=device,
+                    verbose=verbose,
+                )
+                if verbose:
+                    from .metrics.accelerated import get_available_backends
+                    backends = get_available_backends()
+                    self._log(f"Acceleration enabled: n_jobs={n_jobs}, device={device}")
+                    self._log(f"Available backends: {backends}")
+            except ImportError as e:
+                if verbose:
+                    self._log(f"Warning: Could not enable acceleration: {e}")
 
     def _log(self, msg: str):
         """Print message if verbose."""
@@ -262,6 +293,8 @@ def evaluate(
     metrics: Optional[List[Union[BaseMetric, Type[BaseMetric]]]] = None,
     include_multivariate: bool = True,
     verbose: bool = True,
+    n_jobs: int = 1,
+    device: str = "cpu",
     **loader_kwargs
 ) -> EvaluationResult:
     """
@@ -285,6 +318,10 @@ def evaluate(
         Whether to include multivariate metrics
     verbose : bool
         Print progress
+    n_jobs : int
+        Number of parallel CPU jobs. -1 uses all cores. Default is 1.
+    device : str
+        Compute device: "cpu", "cuda", "cuda:0", "auto". Default is "cpu".
     **loader_kwargs
         Additional arguments for data loader
 
@@ -295,6 +332,7 @@ def evaluate(
 
     Examples
     --------
+    >>> # Standard CPU evaluation
     >>> results = evaluate(
     ...     "real.h5ad",
     ...     "generated.h5ad",
@@ -302,6 +340,12 @@ def evaluate(
     ...     split_column="split",
     ...     output_dir="evaluation_output/"
     ... )
+
+    >>> # Parallel CPU evaluation (8 cores)
+    >>> results = evaluate(..., n_jobs=8)
+
+    >>> # GPU-accelerated evaluation
+    >>> results = evaluate(..., device="cuda")
     """
     # Load data
     loader = load_data(
@@ -318,6 +362,8 @@ def evaluate(
         metrics=metrics,
         include_multivariate=include_multivariate,
         verbose=verbose,
+        n_jobs=n_jobs,
+        device=device,
     )
 
     # Run evaluation
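Combining the new keyword arguments with the existing docstring example, a full call might look like the sketch below; the top-level geneval.evaluate import and the condition_columns argument are assumptions (the diff elides that part of the example), and the file names are placeholders:

from geneval import evaluate  # assumed re-export; the diff only shows geneval/evaluator.py

results = evaluate(
    "real.h5ad",
    "generated.h5ad",
    condition_columns=["perturbation"],  # assumed, mirroring evaluate_lazy further below
    split_column="split",
    output_dir="evaluation_output/",
    n_jobs=8,        # 8 parallel CPU workers
    device="cuda",   # or "cpu" (default) / "auto"
)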
@@ -0,0 +1,424 @@
+"""
+Memory-efficient evaluator for large-scale gene expression datasets.
+
+Uses lazy loading and batched processing to minimize memory footprint.
+"""
+from __future__ import annotations
+
+from typing import Dict, List, Optional, Union, Type, Any, Generator
+from pathlib import Path
+import numpy as np
+import warnings
+from dataclasses import dataclass, field
+import gc
+
+from .data.lazy_loader import (
+    LazyGeneExpressionDataLoader,
+    load_data_lazy,
+    ConditionBatch,
+)
+from .metrics.base_metric import BaseMetric, MetricResult
+from .metrics.correlation import (
+    PearsonCorrelation,
+    SpearmanCorrelation,
+    MeanPearsonCorrelation,
+    MeanSpearmanCorrelation,
+)
+from .metrics.distances import (
+    Wasserstein1Distance,
+    Wasserstein2Distance,
+    MMDDistance,
+    EnergyDistance,
+)
+from .metrics.reconstruction import (
+    MSEDistance,
+)
+
+# These multivariate metrics don't support batched computation
+from .metrics.distances import MultivariateWasserstein, MultivariateMMD
+
+
+# Metrics that support incremental/batched computation
+BATCHABLE_METRICS = [
+    MSEDistance,
+    PearsonCorrelation,
+    SpearmanCorrelation,
+]
+
+# Metrics that require full data
+NON_BATCHABLE_METRICS = [
+    Wasserstein1Distance,
+    Wasserstein2Distance,
+    MMDDistance,
+    EnergyDistance,
+    MultivariateWasserstein,
+    MultivariateMMD,
+]
+
+
+@dataclass
+class StreamingMetricAccumulator:
+    """Accumulates values for streaming mean/std computation."""
+    n: int = 0
+    sum: float = 0.0
+    sum_sq: float = 0.0
+
+    def add(self, value: float, count: int = 1):
+        """Add a value (or batch of values with same value)."""
+        self.n += count
+        self.sum += value * count
+        self.sum_sq += (value ** 2) * count
+
+    def add_batch(self, values: np.ndarray):
+        """Add multiple values."""
+        self.n += len(values)
+        self.sum += np.sum(values)
+        self.sum_sq += np.sum(values ** 2)
+
+    @property
+    def mean(self) -> float:
+        return self.sum / self.n if self.n > 0 else 0.0
+
+    @property
+    def std(self) -> float:
+        if self.n <= 1:
+            return 0.0
+        variance = (self.sum_sq / self.n) - (self.mean ** 2)
+        return np.sqrt(max(0, variance))
+
+
+@dataclass
+class StreamingConditionResult:
+    """Lightweight result for a single condition."""
+    condition_key: str
+    n_real_samples: int = 0
+    n_generated_samples: int = 0
+    metrics: Dict[str, float] = field(default_factory=dict)
+    real_mean: Optional[np.ndarray] = None
+    generated_mean: Optional[np.ndarray] = None
+
+
+@dataclass
+class StreamingEvaluationResult:
+    """Memory-efficient evaluation result that streams to disk."""
+    output_dir: Path
+    n_conditions: int = 0
+    metric_accumulators: Dict[str, StreamingMetricAccumulator] = field(default_factory=dict)
+    condition_keys: List[str] = field(default_factory=list)
+
+    def add_condition(self, result: StreamingConditionResult):
+        """Add a condition result and update accumulators."""
+        self.n_conditions += 1
+        self.condition_keys.append(result.condition_key)
+
+        for metric_name, value in result.metrics.items():
+            if metric_name not in self.metric_accumulators:
+                self.metric_accumulators[metric_name] = StreamingMetricAccumulator()
+            self.metric_accumulators[metric_name].add(value)
+
+    def get_summary(self) -> Dict[str, Dict[str, float]]:
+        """Get summary statistics."""
+        summary = {}
+        for name, acc in self.metric_accumulators.items():
+            summary[name] = {
+                "mean": acc.mean,
+                "std": acc.std,
+                "n": acc.n,
+            }
+        return summary
+
+    def save_summary(self):
+        """Save summary to output directory."""
+        import json
+
+        self.output_dir.mkdir(parents=True, exist_ok=True)
+
+        summary = {
+            "n_conditions": self.n_conditions,
+            "metrics": self.get_summary(),
+            "condition_keys": self.condition_keys,
+        }
+
+        with open(self.output_dir / "summary.json", "w") as f:
+            json.dump(summary, f, indent=2)
+
+
+class MemoryEfficientEvaluator:
+    """
+    Memory-efficient evaluator using lazy loading and batched processing.
+
+    Features:
+    - Lazy data loading (one condition at a time)
+    - Batched processing within conditions
+    - Streaming metric accumulation
+    - Periodic garbage collection
+    - Progress streaming to disk
+
+    Parameters
+    ----------
+    data_loader : LazyGeneExpressionDataLoader
+        Lazy data loader
+    metrics : List[BaseMetric], optional
+        Metrics to compute. Note: Some metrics (like MMD) may not support
+        batched computation and will use full condition data.
+    batch_size : int
+        Batch size for within-condition processing
+    gc_every_n_conditions : int
+        Run garbage collection every N conditions
+    verbose : bool
+        Print progress
+    """
+
+    def __init__(
+        self,
+        data_loader: LazyGeneExpressionDataLoader,
+        metrics: Optional[List[Union[BaseMetric, Type[BaseMetric]]]] = None,
+        batch_size: int = 256,
+        gc_every_n_conditions: int = 10,
+        verbose: bool = True,
+    ):
+        self.data_loader = data_loader
+        self.batch_size = batch_size
+        self.gc_every_n_conditions = gc_every_n_conditions
+        self.verbose = verbose
+
+        # Initialize metrics
+        self.metrics: List[BaseMetric] = []
+        metric_classes = metrics or [
+            MSEDistance,
+            PearsonCorrelation,
+            SpearmanCorrelation,
+            MeanPearsonCorrelation,
+            MeanSpearmanCorrelation,
+        ]
+
+        for m in metric_classes:
+            if isinstance(m, type):
+                self.metrics.append(m())
+            else:
+                self.metrics.append(m)
+
+    def _log(self, msg: str):
+        if self.verbose:
+            print(msg)
+
+    def evaluate(
+        self,
+        split: Optional[str] = None,
+        output_dir: Optional[Union[str, Path]] = None,
+        save_per_condition: bool = False,
+    ) -> StreamingEvaluationResult:
+        """
+        Run memory-efficient evaluation.
+
+        Parameters
+        ----------
+        split : str, optional
+            Split to evaluate
+        output_dir : str or Path, optional
+            Directory to save results. If provided, results are streamed to disk.
+        save_per_condition : bool
+            If True, save individual condition results to disk
+
+        Returns
+        -------
+        StreamingEvaluationResult
+            Evaluation result with aggregated metrics
+        """
+        if output_dir is not None:
+            output_dir = Path(output_dir)
+            output_dir.mkdir(parents=True, exist_ok=True)
+        else:
+            output_dir = Path(".")
+
+        result = StreamingEvaluationResult(output_dir=output_dir)
+
+        # Get conditions
+        conditions = self.data_loader.get_common_conditions(split)
+        self._log(f"Evaluating {len(conditions)} conditions")
+        self._log(f"Memory estimate: {self.data_loader.estimate_memory_usage()}")
+
+        # Iterate conditions (one at a time in memory)
+        for i, (cond_key, real_data, gen_data, cond_info) in enumerate(
+            self.data_loader.iterate_conditions(split)
+        ):
+            if self.verbose and (i + 1) % 10 == 0:
+                self._log(f"  Processing {i + 1}/{len(conditions)}: {cond_key}")
+
+            # Compute metrics for this condition
+            cond_result = self._evaluate_condition(
+                cond_key, real_data, gen_data, cond_info
+            )
+
+            # Add to streaming result
+            result.add_condition(cond_result)
+
+            # Optionally save per-condition result
+            if save_per_condition and output_dir:
+                self._save_condition_result(cond_result, output_dir)
+
+            # Periodic garbage collection
+            if (i + 1) % self.gc_every_n_conditions == 0:
+                gc.collect()
+
+        # Final summary
+        result.save_summary()
+
+        if self.verbose:
+            self._print_summary(result)
+
+        return result
+
+    def _evaluate_condition(
+        self,
+        cond_key: str,
+        real_data: np.ndarray,
+        gen_data: np.ndarray,
+        cond_info: Dict[str, str],
+    ) -> StreamingConditionResult:
+        """Evaluate a single condition."""
+        result = StreamingConditionResult(
+            condition_key=cond_key,
+            n_real_samples=real_data.shape[0],
+            n_generated_samples=gen_data.shape[0],
+        )
+
+        # Compute means
+        result.real_mean = real_data.mean(axis=0)
+        result.generated_mean = gen_data.mean(axis=0)
+
+        # Compute metrics
+        for metric in self.metrics:
+            try:
+                metric_result = metric.compute(
+                    real=real_data,
+                    generated=gen_data,
+                    gene_names=self.data_loader.gene_names,
+                    aggregate_method="mean",
+                    condition=cond_key,
+                )
+                result.metrics[metric.name] = metric_result.aggregate_value
+            except Exception as e:
+                warnings.warn(f"Failed to compute {metric.name} for {cond_key}: {e}")
+
+        return result
+
+    def _save_condition_result(
+        self,
+        result: StreamingConditionResult,
+        output_dir: Path,
+    ):
+        """Save a single condition result to disk."""
+        import json
+
+        condition_dir = output_dir / "conditions"
+        condition_dir.mkdir(exist_ok=True)
+
+        # Safe filename
+        safe_key = result.condition_key.replace("/", "_").replace("\\", "_")
+
+        data = {
+            "condition_key": result.condition_key,
+            "n_real": result.n_real_samples,
+            "n_generated": result.n_generated_samples,
+            "metrics": result.metrics,
+        }
+
+        with open(condition_dir / f"{safe_key}.json", "w") as f:
+            json.dump(data, f, indent=2)
+
+    def _print_summary(self, result: StreamingEvaluationResult):
+        """Print summary."""
+        self._log("\n" + "=" * 60)
+        self._log("EVALUATION SUMMARY (Memory-Efficient)")
+        self._log("=" * 60)
+        self._log(f"Conditions evaluated: {result.n_conditions}")
+        self._log("-" * 40)
+
+        for name, stats in result.get_summary().items():
+            self._log(f"  {name}: {stats['mean']:.4f} ± {stats['std']:.4f}")
+
+        self._log("=" * 60)
+
+
+def evaluate_lazy(
+    real_path: Union[str, Path],
+    generated_path: Union[str, Path],
+    condition_columns: List[str],
+    split_column: Optional[str] = None,
+    output_dir: Optional[Union[str, Path]] = None,
+    batch_size: int = 256,
+    use_backed: bool = False,
+    metrics: Optional[List[Union[BaseMetric, Type[BaseMetric]]]] = None,
+    verbose: bool = True,
+    save_per_condition: bool = False,
+    **kwargs
+) -> StreamingEvaluationResult:
+    """
+    Memory-efficient evaluation using lazy loading.
+
+    Use this function for large datasets that don't fit in memory.
+
+    Parameters
+    ----------
+    real_path : str or Path
+        Path to real data h5ad file
+    generated_path : str or Path
+        Path to generated data h5ad file
+    condition_columns : List[str]
+        Columns to match between datasets
+    split_column : str, optional
+        Column for train/test split
+    output_dir : str or Path, optional
+        Directory to save results
+    batch_size : int
+        Batch size for processing
+    use_backed : bool
+        Use memory-mapped file access (for very large files)
+    metrics : List, optional
+        Metrics to compute
+    verbose : bool
+        Print progress
+    save_per_condition : bool
+        Save individual condition results
+
+    Returns
+    -------
+    StreamingEvaluationResult
+        Aggregated evaluation results
+
+    Examples
+    --------
+    >>> # For large datasets that don't fit in memory
+    >>> results = evaluate_lazy(
+    ...     "real.h5ad",
+    ...     "generated.h5ad",
+    ...     condition_columns=["perturbation"],
+    ...     output_dir="eval_output/",
+    ...     batch_size=256,
+    ...     use_backed=True,  # Memory-mapped for very large files
+    ... )
+    >>> print(results.get_summary())
+    """
+    # Create lazy loader
+    with load_data_lazy(
+        real_path=real_path,
+        generated_path=generated_path,
+        condition_columns=condition_columns,
+        split_column=split_column,
+        batch_size=batch_size,
+        use_backed=use_backed,
+    ) as loader:
+        # Create evaluator
+        evaluator = MemoryEfficientEvaluator(
+            data_loader=loader,
+            metrics=metrics,
+            batch_size=batch_size,
+            verbose=verbose,
+        )
+
+        # Run evaluation
+        return evaluator.evaluate(
+            output_dir=output_dir,
+            save_per_condition=save_per_condition,
+        )
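StreamingMetricAccumulator keeps only n, sum(x), and sum(x^2), and recovers the standard deviation from the identity var(x) = E[x^2] - E[x]^2. A small self-contained check of that identity against NumPy (the accumulator logic is re-implemented inline here rather than imported, since the new module's path is not shown in this diff):

import numpy as np

# Mirrors StreamingMetricAccumulator.add / .mean / .std, for verification only.
n, s, s2 = 0, 0.0, 0.0
values = np.random.default_rng(0).normal(loc=2.0, size=1000)
for v in values:
    n += 1
    s += v
    s2 += v ** 2

mean = s / n
std = np.sqrt(max(0.0, s2 / n - mean ** 2))  # population variance, as in the dataclass above
assert np.isclose(mean, values.mean())
assert np.isclose(std, values.std())  # np.std defaults to ddof=0, matching the formula

Note that this one-pass formula can lose precision when the variance is tiny relative to the mean; Welford's online algorithm is the usual alternative if that ever matters for per-condition metric values.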
@@ -35,6 +35,20 @@ from .reconstruction import (
     R2Score,
 )
 
+# Accelerated computation
+from .accelerated import (
+    AccelerationConfig,
+    ParallelMetricComputer,
+    get_available_backends,
+    compute_metrics_accelerated,
+    GPUWasserstein1,
+    GPUWasserstein2,
+    GPUMMD,
+    GPUEnergyDistance,
+    vectorized_wasserstein1,
+    vectorized_mmd,
+)
+
 # All available metrics
 ALL_METRICS = [
     # Reconstruction
@@ -81,4 +95,15 @@ __all__ = [
     "MultivariateMMD",
     # Collections
     "ALL_METRICS",
+    # Acceleration
+    "AccelerationConfig",
+    "ParallelMetricComputer",
+    "get_available_backends",
+    "compute_metrics_accelerated",
+    "GPUWasserstein1",
+    "GPUWasserstein2",
+    "GPUMMD",
+    "GPUEnergyDistance",
+    "vectorized_wasserstein1",
+    "vectorized_mmd",
 ]
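With these re-exports in place, a caller can check which accelerated backends are importable before requesting them. A short sketch; the top-level package name geneval is inferred from the file paths in this diff, and the return format of get_available_backends is not shown, so it is only printed:

from geneval.metrics import get_available_backends

backends = get_available_backends()
print(backends)  # inspect what is available before passing n_jobs/device to evaluate()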