gengeneeval 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geneval/__init__.py CHANGED
@@ -8,6 +8,7 @@ Features:
  - Multiple distance and correlation metrics (per-gene and aggregate)
  - Condition-based matching (perturbation, cell type, etc.)
  - Train/test split support
+ - Memory-efficient lazy loading for large datasets
  - Publication-quality visualizations
  - Command-line interface
 
@@ -20,12 +21,22 @@ Quick Start:
  ... output_dir="output/"
  ... )
 
+ Memory-Efficient Mode (for large datasets):
+ >>> from geneval import evaluate_lazy
+ >>> results = evaluate_lazy(
+ ...     real_path="real.h5ad",
+ ...     generated_path="generated.h5ad",
+ ...     condition_columns=["perturbation"],
+ ...     batch_size=256,
+ ...     use_backed=True,  # Memory-mapped access
+ ... )
+
  CLI Usage:
  $ geneval --real real.h5ad --generated generated.h5ad \\
        --conditions perturbation cell_type --output results/
  """
 
- __version__ = "0.2.0"
+ __version__ = "0.2.1"
  __author__ = "GenEval Team"
 
  # Main evaluation interface
@@ -35,12 +46,26 @@ from .evaluator import (
      MetricRegistry,
  )
 
+ # Memory-efficient evaluation
+ from .lazy_evaluator import (
+     evaluate_lazy,
+     MemoryEfficientEvaluator,
+     StreamingEvaluationResult,
+ )
+
  # Data loading
  from .data.loader import (
      GeneExpressionDataLoader,
      load_data,
  )
 
+ # Memory-efficient data loading
+ from .data.lazy_loader import (
+     LazyGeneExpressionDataLoader,
+     load_data_lazy,
+     ConditionBatch,
+ )
+
  # Results
  from .results import (
      EvaluationResult,
@@ -99,9 +124,17 @@ __all__ = [
      "evaluate",
      "GeneEvalEvaluator",
      "MetricRegistry",
+     # Memory-efficient evaluation
+     "evaluate_lazy",
+     "MemoryEfficientEvaluator",
+     "StreamingEvaluationResult",
      # Data loading
      "GeneExpressionDataLoader",
      "load_data",
+     # Memory-efficient data loading
+     "LazyGeneExpressionDataLoader",
+     "load_data_lazy",
+     "ConditionBatch",
      # Results
      "EvaluationResult",
      "SplitResult",
@@ -123,6 +156,11 @@ __all__ = [
      "EnergyDistance",
      "MultivariateWasserstein",
      "MultivariateMMD",
+     # Reconstruction metrics
+     "MSEDistance",
+     "RMSEDistance",
+     "MAEDistance",
+     "R2Score",
      # Visualization
      "EvaluationVisualizer",
      "visualize",
geneval/data/__init__.py CHANGED
@@ -2,6 +2,7 @@
  Data loading module for gene expression evaluation.
 
  Provides data loaders for paired real and generated datasets.
+ Includes both standard and memory-efficient lazy loading options.
  """
 
  from .loader import (
@@ -9,15 +10,28 @@ from .loader import (
      load_data,
      DataLoaderError,
  )
+ from .lazy_loader import (
+     LazyGeneExpressionDataLoader,
+     load_data_lazy,
+     LazyDataLoaderError,
+     ConditionBatch,
+ )
  from .gene_expression_datamodule import (
      GeneExpressionDataModule,
      DataModuleError,
  )
 
  __all__ = [
+     # Standard loader
      "GeneExpressionDataLoader",
      "load_data",
      "DataLoaderError",
+     # Lazy loader (memory-efficient)
+     "LazyGeneExpressionDataLoader",
+     "load_data_lazy",
+     "LazyDataLoaderError",
+     "ConditionBatch",
+     # DataModule
      "GeneExpressionDataModule",
      "DataModuleError",
  ]
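
The subpackage now exports the lazy loader alongside the standard one, including its dedicated error type. A brief illustrative sketch (hypothetical file paths; the failure modes in the comment come from the loader code below) of catching the newly exported `LazyDataLoaderError`:

```python
from geneval.data import load_data_lazy, LazyDataLoaderError

try:
    loader = load_data_lazy(
        "real.h5ad",
        "generated.h5ad",
        condition_columns=["perturbation"],
    )
except LazyDataLoaderError as err:
    # Raised for missing files, absent condition columns,
    # or zero overlapping genes between the two datasets
    print(f"Lazy loading failed: {err}")
```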
geneval/data/lazy_loader.py ADDED
@@ -0,0 +1,562 @@
+ """
+ Memory-efficient lazy data loader for large-scale gene expression datasets.
+
+ Provides lazy loading and batched iteration over AnnData h5ad files without
+ loading entire datasets into memory. Supports backed mode for very large files.
+ """
+ from __future__ import annotations
+
+ from typing import Optional, List, Union, Dict, Tuple, Iterator, Generator
+ from pathlib import Path
+ import warnings
+ import numpy as np
+ import pandas as pd
+ from scipy import sparse
+ from dataclasses import dataclass
+
+ try:
+     import anndata as ad
+     import scanpy as sc
+ except ImportError:
+     raise ImportError("anndata and scanpy are required. Install with: pip install anndata scanpy")
+
+
+ class LazyDataLoaderError(Exception):
+     """Custom exception for lazy data loading errors."""
+     pass
+
+
+ @dataclass
+ class ConditionBatch:
+     """Container for a batch of samples from a condition."""
+     condition_key: str
+     condition_info: Dict[str, str]
+     real_data: np.ndarray
+     generated_data: np.ndarray
+     batch_idx: int
+     n_batches: int
+     is_last_batch: bool
+
+
+ class LazyGeneExpressionDataLoader:
+     """
+     Memory-efficient lazy data loader for paired gene expression datasets.
+
+     Unlike GeneExpressionDataLoader, this class:
+     - Uses backed mode for h5ad files to avoid loading entire datasets
+     - Supports batched iteration over conditions
+     - Only loads data into memory when explicitly requested
+     - Provides memory usage estimates
+
+     Parameters
+     ----------
+     real_path : str or Path
+         Path to real data h5ad file
+     generated_path : str or Path
+         Path to generated data h5ad file
+     condition_columns : List[str]
+         Columns to match between datasets
+     split_column : str, optional
+         Column indicating train/test split
+     batch_size : int
+         Maximum number of samples per batch when iterating
+     use_backed : bool
+         If True, use backed mode (memory-mapped). May be slower but uses minimal memory.
+         If False, loads the full file but processes it in batches.
+     min_samples_per_condition : int
+         Minimum number of samples a condition must have to be included
+     """
+
+     def __init__(
+         self,
+         real_path: Union[str, Path],
+         generated_path: Union[str, Path],
+         condition_columns: List[str],
+         split_column: Optional[str] = None,
+         batch_size: int = 256,
+         use_backed: bool = False,
+         min_samples_per_condition: int = 2,
+     ):
+         self.real_path = Path(real_path)
+         self.generated_path = Path(generated_path)
+         self.condition_columns = condition_columns
+         self.split_column = split_column
+         self.batch_size = batch_size
+         self.use_backed = use_backed
+         self.min_samples_per_condition = min_samples_per_condition
+
+         # Lazy-loaded references (backed or full)
+         self._real: Optional[ad.AnnData] = None
+         self._generated: Optional[ad.AnnData] = None
+
+         # Metadata (always loaded - lightweight)
+         self._real_obs: Optional[pd.DataFrame] = None
+         self._generated_obs: Optional[pd.DataFrame] = None
+         self._real_var_names: Optional[pd.Index] = None
+         self._generated_var_names: Optional[pd.Index] = None
+
+         # Gene alignment info
+         self._common_genes: Optional[List[str]] = None
+         self._real_gene_idx: Optional[np.ndarray] = None
+         self._gen_gene_idx: Optional[np.ndarray] = None
+
+         # Pre-computed condition indices for fast access
+         self._condition_indices: Optional[Dict[str, Dict[str, np.ndarray]]] = None
+
+         # State
+         self._is_initialized = False
+
+     def initialize(self) -> "LazyGeneExpressionDataLoader":
+         """
+         Initialize the loader by reading metadata only (not expression data).
+
+         This loads obs DataFrames and var_names to prepare for iteration,
+         but does not load the expression matrices.
+
+         Returns
+         -------
+         self
+             For method chaining
+         """
+         if self._is_initialized:
+             return self
+
+         # Validate paths
+         if not self.real_path.exists():
+             raise LazyDataLoaderError(f"Real data file not found: {self.real_path}")
+         if not self.generated_path.exists():
+             raise LazyDataLoaderError(f"Generated data file not found: {self.generated_path}")
+
+         # Load metadata only (obs and var_names)
+         # This is much faster and lighter than loading full data
+         self._load_metadata()
+
+         # Validate columns
+         self._validate_columns()
+
+         # Compute gene alignment indices
+         self._compute_gene_alignment()
+
+         # Pre-compute condition indices
+         self._precompute_condition_indices()
+
+         self._is_initialized = True
+         return self
+
+     def _load_metadata(self):
+         """Load only metadata (obs, var_names) without expression data."""
+         # In backed mode the file is opened but X stays on disk.
+         # In non-backed mode the full file is read; downstream iteration
+         # still processes it condition-by-condition.
+
+         # Try backed mode if requested
+         if self.use_backed:
+             try:
+                 self._real = sc.read_h5ad(self.real_path, backed='r')
+                 self._generated = sc.read_h5ad(self.generated_path, backed='r')
+             except Exception as e:
+                 warnings.warn(f"Backed mode failed, falling back to standard loading: {e}")
+                 self._real = None
+                 self._generated = None
+
+         if self._real is None:
+             # Standard load: reads the full files into memory
+             self._real = sc.read_h5ad(self.real_path)
+             self._generated = sc.read_h5ad(self.generated_path)
+
+         # Cache lightweight metadata
+         self._real_obs = self._real.obs.copy()
+         self._generated_obs = self._generated.obs.copy()
+         self._real_var_names = pd.Index(self._real.var_names.astype(str))
+         self._generated_var_names = pd.Index(self._generated.var_names.astype(str))
+
+     def _validate_columns(self):
+         """Validate that required columns exist."""
+         for col in self.condition_columns:
+             if col not in self._real_obs.columns:
+                 raise LazyDataLoaderError(
+                     f"Condition column '{col}' not found in real data. "
+                     f"Available: {list(self._real_obs.columns)}"
+                 )
+             if col not in self._generated_obs.columns:
+                 raise LazyDataLoaderError(
+                     f"Condition column '{col}' not found in generated data. "
+                     f"Available: {list(self._generated_obs.columns)}"
+                 )
+
+     def _compute_gene_alignment(self):
+         """Pre-compute gene alignment indices."""
+         common = self._real_var_names.intersection(self._generated_var_names)
+
+         if len(common) == 0:
+             raise LazyDataLoaderError(
+                 "No overlapping genes between real and generated data."
+             )
+
+         self._common_genes = common.tolist()
+         self._real_gene_idx = self._real_var_names.get_indexer(common)
+         self._gen_gene_idx = self._generated_var_names.get_indexer(common)
+
+         n_real_only = len(self._real_var_names) - len(common)
+         n_gen_only = len(self._generated_var_names) - len(common)
+
+         if n_real_only > 0 or n_gen_only > 0:
+             warnings.warn(
+                 f"Gene alignment: {len(common)} common genes. "
+                 f"Dropped {n_real_only} from real, {n_gen_only} from generated."
+             )
+
+     def _get_condition_key(self, row: pd.Series) -> str:
+         """Generate a unique key for a condition combination."""
+         return "####".join([str(row[c]) for c in self.condition_columns])
+
+     def _precompute_condition_indices(self):
+         """Pre-compute sample indices for each condition (lightweight)."""
+         self._condition_indices = {"real": {}, "generated": {}}
+
+         # Real data conditions
+         real_conditions = self._real_obs[self.condition_columns].astype(str).drop_duplicates()
+         for _, row in real_conditions.iterrows():
+             key = self._get_condition_key(row)
+             mask = np.ones(len(self._real_obs), dtype=bool)
+             for col in self.condition_columns:
+                 mask &= (self._real_obs[col].astype(str) == str(row[col])).values
+
+             indices = np.where(mask)[0]
+             if len(indices) >= self.min_samples_per_condition:
+                 self._condition_indices["real"][key] = indices
+
+         # Generated data conditions
+         gen_conditions = self._generated_obs[self.condition_columns].astype(str).drop_duplicates()
+         for _, row in gen_conditions.iterrows():
+             key = self._get_condition_key(row)
+             mask = np.ones(len(self._generated_obs), dtype=bool)
+             for col in self.condition_columns:
+                 mask &= (self._generated_obs[col].astype(str) == str(row[col])).values
+
+             indices = np.where(mask)[0]
+             if len(indices) >= self.min_samples_per_condition:
+                 self._condition_indices["generated"][key] = indices
+
+     def get_splits(self) -> List[str]:
+         """Get available splits."""
+         if not self._is_initialized:
+             self.initialize()
+
+         if self.split_column is None or self.split_column not in self._real_obs.columns:
+             return ["all"]
+
+         return list(self._real_obs[self.split_column].astype(str).unique())
+
+     def get_common_conditions(self, split: Optional[str] = None) -> List[str]:
+         """Get conditions present in both real and generated data."""
+         if not self._is_initialized:
+             self.initialize()
+
+         real_keys = set(self._condition_indices["real"].keys())
+         gen_keys = set(self._condition_indices["generated"].keys())
+
+         common = real_keys & gen_keys
+
+         # Filter by split if specified
+         if split is not None and split != "all" and self.split_column is not None:
+             filtered = set()
+             for key in common:
+                 real_idx = self._condition_indices["real"][key]
+                 split_vals = self._real_obs.iloc[real_idx][self.split_column].astype(str)
+                 if (split_vals == split).any():
+                     filtered.add(key)
+             common = filtered
+
+         return sorted(common)
+
+     def _extract_data_subset(
+         self,
+         adata: ad.AnnData,
+         indices: np.ndarray,
+         gene_idx: np.ndarray,
+     ) -> np.ndarray:
+         """Extract and align a subset of the data."""
+         # Row-then-column indexing works for both backed and in-memory
+         # AnnData; in backed mode only the requested rows are read from disk
+         X = adata.X[indices][:, gene_idx]
+
+         # Convert to dense if sparse
+         if sparse.issparse(X):
+             X = X.toarray()
+
+         return np.asarray(X, dtype=np.float32)
+
+     def iterate_conditions(
+         self,
+         split: Optional[str] = None,
+     ) -> Generator[Tuple[str, np.ndarray, np.ndarray, Dict[str, str]], None, None]:
+         """
+         Iterate over conditions, loading one condition at a time.
+
+         This loads the data for one condition, yields it, then releases the
+         memory before loading the next condition.
+
+         Parameters
+         ----------
+         split : str, optional
+             Filter to this split only
+
+         Yields
+         ------
+         Tuple[str, np.ndarray, np.ndarray, Dict[str, str]]
+             (condition_key, real_data, generated_data, condition_info)
+         """
+         if not self._is_initialized:
+             self.initialize()
+
+         common_conditions = self.get_common_conditions(split)
+
+         for key in common_conditions:
+             real_indices = self._condition_indices["real"][key]
+             gen_indices = self._condition_indices["generated"][key]
+
+             # Filter by split if needed
+             if split is not None and split != "all" and self.split_column is not None:
+                 split_mask = self._real_obs.iloc[real_indices][self.split_column].astype(str) == split
+                 real_indices = real_indices[split_mask.values]
+
+             if len(real_indices) < self.min_samples_per_condition:
+                 continue
+
+             # Load the data for this condition only
+             real_data = self._extract_data_subset(
+                 self._real, real_indices, self._real_gene_idx
+             )
+             gen_data = self._extract_data_subset(
+                 self._generated, gen_indices, self._gen_gene_idx
+             )
+
+             # Parse condition info
+             parts = key.split("####")
+             condition_info = dict(zip(self.condition_columns, parts))
+
+             yield key, real_data, gen_data, condition_info
+
+     def iterate_conditions_batched(
+         self,
+         split: Optional[str] = None,
+         batch_size: Optional[int] = None,
+     ) -> Generator[ConditionBatch, None, None]:
+         """
+         Iterate over conditions in batches for memory efficiency.
+
+         Useful when even a single condition is too large to fit in memory.
+
+         Parameters
+         ----------
+         split : str, optional
+             Filter to this split only
+         batch_size : int, optional
+             Override the default batch size
+
+         Yields
+         ------
+         ConditionBatch
+             Batch of samples from a condition
+         """
+         if not self._is_initialized:
+             self.initialize()
+
+         batch_size = batch_size or self.batch_size
+         common_conditions = self.get_common_conditions(split)
+
+         for key in common_conditions:
+             real_indices = self._condition_indices["real"][key]
+             gen_indices = self._condition_indices["generated"][key]
+
+             # Filter by split
+             if split is not None and split != "all" and self.split_column is not None:
+                 split_mask = self._real_obs.iloc[real_indices][self.split_column].astype(str) == split
+                 real_indices = real_indices[split_mask.values]
+
+             if len(real_indices) < self.min_samples_per_condition:
+                 continue
+
+             # Parse condition info
+             parts = key.split("####")
+             condition_info = dict(zip(self.condition_columns, parts))
+
+             # Calculate the number of batches (use max of real/gen for alignment)
+             n_real = len(real_indices)
+             n_gen = len(gen_indices)
+             n_batches = max(
+                 (n_real + batch_size - 1) // batch_size,
+                 (n_gen + batch_size - 1) // batch_size
+             )
+
+             for batch_idx in range(n_batches):
+                 start_real = batch_idx * batch_size
+                 end_real = min(start_real + batch_size, n_real)
+                 start_gen = batch_idx * batch_size
+                 end_gen = min(start_gen + batch_size, n_gen)
+
+                 # Handle the case where one dataset is smaller: wrap around
+                 # so every batch still carries data from both datasets
+                 if start_real >= n_real:
+                     wrapped = np.arange(start_real, start_real + batch_size) % n_real
+                     batch_real_idx = real_indices[wrapped]
+                 else:
+                     batch_real_idx = real_indices[start_real:end_real]
+
+                 if start_gen >= n_gen:
+                     wrapped = np.arange(start_gen, start_gen + batch_size) % n_gen
+                     batch_gen_idx = gen_indices[wrapped]
+                 else:
+                     batch_gen_idx = gen_indices[start_gen:end_gen]
+
+                 if len(batch_real_idx) == 0 or len(batch_gen_idx) == 0:
+                     continue
+
+                 real_data = self._extract_data_subset(
+                     self._real, batch_real_idx, self._real_gene_idx
+                 )
+                 gen_data = self._extract_data_subset(
+                     self._generated, batch_gen_idx, self._gen_gene_idx
+                 )
+
+                 yield ConditionBatch(
+                     condition_key=key,
+                     condition_info=condition_info,
+                     real_data=real_data,
+                     generated_data=gen_data,
+                     batch_idx=batch_idx,
+                     n_batches=n_batches,
+                     is_last_batch=(batch_idx == n_batches - 1),
+                 )
+
+     @property
+     def gene_names(self) -> List[str]:
+         """Get common gene names."""
+         if not self._is_initialized:
+             self.initialize()
+         return self._common_genes
+
+     @property
+     def n_genes(self) -> int:
+         """Number of common genes."""
+         return len(self.gene_names)
+
+     def estimate_memory_usage(self) -> Dict[str, float]:
+         """
+         Estimate memory usage in MB for different loading strategies.
+
+         Returns
+         -------
+         Dict[str, float]
+             Memory estimates in MB
+         """
+         if not self._is_initialized:
+             self.initialize()
+
+         n_real = len(self._real_obs)
+         n_gen = len(self._generated_obs)
+         n_genes = self.n_genes
+
+         # 4 bytes per float32 element
+         bytes_per_element = 4
+
+         full_real = n_real * n_genes * bytes_per_element / 1e6
+         full_gen = n_gen * n_genes * bytes_per_element / 1e6
+
+         # Average condition size
+         n_conditions = len(self.get_common_conditions())
+         avg_per_condition = (n_real + n_gen) / max(n_conditions, 1)
+         per_condition = avg_per_condition * n_genes * bytes_per_element / 1e6
+
+         per_batch = self.batch_size * 2 * n_genes * bytes_per_element / 1e6
+
+         return {
+             "full_load_mb": full_real + full_gen,
+             "per_condition_mb": per_condition,
+             "per_batch_mb": per_batch,
+             "metadata_mb": (n_real + n_gen) * 100 / 1e6,  # rough obs estimate
+         }
+
+     def close(self):
+         """Close backed file handles, if any."""
+         for adata in (self._real, self._generated):
+             if adata is not None and hasattr(adata, 'file'):
+                 try:
+                     adata.file.close()
+                 except Exception:
+                     pass
+
+     def __enter__(self):
+         """Context manager entry."""
+         self.initialize()
+         return self
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         """Context manager exit."""
+         self.close()
+
+     def __repr__(self) -> str:
+         if not self._is_initialized:
+             return f"LazyGeneExpressionDataLoader(not initialized, backed={self.use_backed})"
+
+         return (
+             f"LazyGeneExpressionDataLoader("
+             f"real={len(self._real_obs)}x{len(self._real_var_names)}, "
+             f"gen={len(self._generated_obs)}x{len(self._generated_var_names)}, "
+             f"common_genes={self.n_genes}, "
+             f"batch_size={self.batch_size})"
+         )
+
+
+ def load_data_lazy(
+     real_path: Union[str, Path],
+     generated_path: Union[str, Path],
+     condition_columns: List[str],
+     split_column: Optional[str] = None,
+     batch_size: int = 256,
+     use_backed: bool = False,
+     **kwargs
+ ) -> LazyGeneExpressionDataLoader:
+     """
+     Convenience function to create a lazy data loader.
+
+     Parameters
+     ----------
+     real_path : str or Path
+         Path to real data h5ad file
+     generated_path : str or Path
+         Path to generated data h5ad file
+     condition_columns : List[str]
+         Columns to match between datasets
+     split_column : str, optional
+         Column indicating train/test split
+     batch_size : int
+         Maximum samples per batch
+     use_backed : bool
+         Use memory-mapped access for very large files
+     **kwargs
+         Additional arguments for LazyGeneExpressionDataLoader
+
+     Returns
+     -------
+     LazyGeneExpressionDataLoader
+         Initialized lazy data loader
+     """
+     loader = LazyGeneExpressionDataLoader(
+         real_path=real_path,
+         generated_path=generated_path,
+         condition_columns=condition_columns,
+         split_column=split_column,
+         batch_size=batch_size,
+         use_backed=use_backed,
+         **kwargs
+     )
+     loader.initialize()
+     return loader
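
`iterate_conditions_batched` pairs real and generated batches by index rather than sample-by-sample, so per-batch statistics should be computed distributionally. A usage sketch (hypothetical paths; the running mean-profile comparison is illustrative, not a package metric) showing how `is_last_batch` supports cross-batch accumulation:

```python
import numpy as np
from geneval.data.lazy_loader import load_data_lazy

with load_data_lazy(
    "real.h5ad",
    "generated.h5ad",
    condition_columns=["perturbation"],
    batch_size=128,
) as loader:
    sq_sum, n = 0.0, 0
    for batch in loader.iterate_conditions_batched():
        # Real and generated rows are not paired one-to-one, so compare
        # batch-level mean expression profiles instead of sample-wise errors
        diff = batch.real_data.mean(axis=0) - batch.generated_data.mean(axis=0)
        sq_sum += float(np.sum(diff ** 2))
        n += diff.size
        if batch.is_last_batch:
            print(f"{batch.condition_key}: mean squared diff = {sq_sum / n:.4f}")
            sq_sum, n = 0.0, 0
```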
geneval/lazy_evaluator.py ADDED
@@ -0,0 +1,424 @@
+ """
+ Memory-efficient evaluator for large-scale gene expression datasets.
+
+ Uses lazy loading and batched processing to minimize memory footprint.
+ """
+ from __future__ import annotations
+
+ from typing import Dict, List, Optional, Union, Type, Any, Generator
+ from pathlib import Path
+ import numpy as np
+ import warnings
+ from dataclasses import dataclass, field
+ import gc
+
+ from .data.lazy_loader import (
+     LazyGeneExpressionDataLoader,
+     load_data_lazy,
+     ConditionBatch,
+ )
+ from .metrics.base_metric import BaseMetric, MetricResult
+ from .metrics.correlation import (
+     PearsonCorrelation,
+     SpearmanCorrelation,
+     MeanPearsonCorrelation,
+     MeanSpearmanCorrelation,
+ )
+ from .metrics.distances import (
+     Wasserstein1Distance,
+     Wasserstein2Distance,
+     MMDDistance,
+     EnergyDistance,
+ )
+ from .metrics.reconstruction import (
+     MSEDistance,
+ )
+
+ # These multivariate metrics don't support batched computation
+ from .metrics.distances import MultivariateWasserstein, MultivariateMMD
+
+
+ # Metrics that support incremental/batched computation
+ BATCHABLE_METRICS = [
+     MSEDistance,
+     PearsonCorrelation,
+     SpearmanCorrelation,
+ ]
+
+ # Metrics that require the full data for a condition
+ NON_BATCHABLE_METRICS = [
+     Wasserstein1Distance,
+     Wasserstein2Distance,
+     MMDDistance,
+     EnergyDistance,
+     MultivariateWasserstein,
+     MultivariateMMD,
+ ]
+
+
+ @dataclass
+ class StreamingMetricAccumulator:
+     """Accumulates values for streaming mean/std computation."""
+     n: int = 0
+     sum: float = 0.0
+     sum_sq: float = 0.0
+
+     def add(self, value: float, count: int = 1):
+         """Add a value (or a batch of `count` identical values)."""
+         self.n += count
+         self.sum += value * count
+         self.sum_sq += (value ** 2) * count
+
+     def add_batch(self, values: np.ndarray):
+         """Add multiple values."""
+         self.n += len(values)
+         self.sum += np.sum(values)
+         self.sum_sq += np.sum(values ** 2)
+
+     @property
+     def mean(self) -> float:
+         return self.sum / self.n if self.n > 0 else 0.0
+
+     @property
+     def std(self) -> float:
+         if self.n <= 1:
+             return 0.0
+         variance = (self.sum_sq / self.n) - (self.mean ** 2)
+         return np.sqrt(max(0, variance))
+
+
+ @dataclass
+ class StreamingConditionResult:
+     """Lightweight result for a single condition."""
+     condition_key: str
+     n_real_samples: int = 0
+     n_generated_samples: int = 0
+     metrics: Dict[str, float] = field(default_factory=dict)
+     real_mean: Optional[np.ndarray] = None
+     generated_mean: Optional[np.ndarray] = None
+
+
+ @dataclass
+ class StreamingEvaluationResult:
+     """Memory-efficient evaluation result that streams summaries to disk."""
+     output_dir: Path
+     n_conditions: int = 0
+     metric_accumulators: Dict[str, StreamingMetricAccumulator] = field(default_factory=dict)
+     condition_keys: List[str] = field(default_factory=list)
+
+     def add_condition(self, result: StreamingConditionResult):
+         """Add a condition result and update the accumulators."""
+         self.n_conditions += 1
+         self.condition_keys.append(result.condition_key)
+
+         for metric_name, value in result.metrics.items():
+             if metric_name not in self.metric_accumulators:
+                 self.metric_accumulators[metric_name] = StreamingMetricAccumulator()
+             self.metric_accumulators[metric_name].add(value)
+
+     def get_summary(self) -> Dict[str, Dict[str, float]]:
+         """Get summary statistics."""
+         summary = {}
+         for name, acc in self.metric_accumulators.items():
+             summary[name] = {
+                 "mean": acc.mean,
+                 "std": acc.std,
+                 "n": acc.n,
+             }
+         return summary
+
+     def save_summary(self):
+         """Save the summary to the output directory."""
+         import json
+
+         self.output_dir.mkdir(parents=True, exist_ok=True)
+
+         summary = {
+             "n_conditions": self.n_conditions,
+             "metrics": self.get_summary(),
+             "condition_keys": self.condition_keys,
+         }
+
+         with open(self.output_dir / "summary.json", "w") as f:
+             json.dump(summary, f, indent=2)
+
+
+ class MemoryEfficientEvaluator:
+     """
+     Memory-efficient evaluator using lazy loading and batched processing.
+
+     Features:
+     - Lazy data loading (one condition at a time)
+     - Batched processing within conditions
+     - Streaming metric accumulation
+     - Periodic garbage collection
+     - Progress streaming to disk
+
+     Parameters
+     ----------
+     data_loader : LazyGeneExpressionDataLoader
+         Lazy data loader
+     metrics : List[BaseMetric], optional
+         Metrics to compute. Note: some metrics (like MMD) may not support
+         batched computation and will use the full condition data.
+     batch_size : int
+         Batch size for within-condition processing
+     gc_every_n_conditions : int
+         Run garbage collection every N conditions
+     verbose : bool
+         Print progress
+     """
+
+     def __init__(
+         self,
+         data_loader: LazyGeneExpressionDataLoader,
+         metrics: Optional[List[Union[BaseMetric, Type[BaseMetric]]]] = None,
+         batch_size: int = 256,
+         gc_every_n_conditions: int = 10,
+         verbose: bool = True,
+     ):
+         self.data_loader = data_loader
+         self.batch_size = batch_size
+         self.gc_every_n_conditions = gc_every_n_conditions
+         self.verbose = verbose
+
+         # Initialize metrics (accept both classes and instances)
+         self.metrics: List[BaseMetric] = []
+         metric_classes = metrics or [
+             MSEDistance,
+             PearsonCorrelation,
+             SpearmanCorrelation,
+             MeanPearsonCorrelation,
+             MeanSpearmanCorrelation,
+         ]
+
+         for m in metric_classes:
+             if isinstance(m, type):
+                 self.metrics.append(m())
+             else:
+                 self.metrics.append(m)
+
+     def _log(self, msg: str):
+         if self.verbose:
+             print(msg)
+
+     def evaluate(
+         self,
+         split: Optional[str] = None,
+         output_dir: Optional[Union[str, Path]] = None,
+         save_per_condition: bool = False,
+     ) -> StreamingEvaluationResult:
+         """
+         Run memory-efficient evaluation.
+
+         Parameters
+         ----------
+         split : str, optional
+             Split to evaluate
+         output_dir : str or Path, optional
+             Directory to save results. If provided, results are streamed to disk.
+         save_per_condition : bool
+             If True, save individual condition results to disk
+
+         Returns
+         -------
+         StreamingEvaluationResult
+             Evaluation result with aggregated metrics
+         """
+         if output_dir is not None:
+             output_dir = Path(output_dir)
+             output_dir.mkdir(parents=True, exist_ok=True)
+         else:
+             output_dir = Path(".")
+
+         result = StreamingEvaluationResult(output_dir=output_dir)
+
+         # Get conditions
+         conditions = self.data_loader.get_common_conditions(split)
+         self._log(f"Evaluating {len(conditions)} conditions")
+         self._log(f"Memory estimate: {self.data_loader.estimate_memory_usage()}")
+
+         # Iterate conditions (one at a time in memory)
+         for i, (cond_key, real_data, gen_data, cond_info) in enumerate(
+             self.data_loader.iterate_conditions(split)
+         ):
+             if self.verbose and (i + 1) % 10 == 0:
+                 self._log(f"  Processing {i + 1}/{len(conditions)}: {cond_key}")
+
+             # Compute metrics for this condition
+             cond_result = self._evaluate_condition(
+                 cond_key, real_data, gen_data, cond_info
+             )
+
+             # Add to the streaming result
+             result.add_condition(cond_result)
+
+             # Optionally save the per-condition result
+             if save_per_condition and output_dir:
+                 self._save_condition_result(cond_result, output_dir)
+
+             # Periodic garbage collection
+             if (i + 1) % self.gc_every_n_conditions == 0:
+                 gc.collect()
+
+         # Final summary
+         result.save_summary()
+
+         if self.verbose:
+             self._print_summary(result)
+
+         return result
+
+     def _evaluate_condition(
+         self,
+         cond_key: str,
+         real_data: np.ndarray,
+         gen_data: np.ndarray,
+         cond_info: Dict[str, str],
+     ) -> StreamingConditionResult:
+         """Evaluate a single condition."""
+         result = StreamingConditionResult(
+             condition_key=cond_key,
+             n_real_samples=real_data.shape[0],
+             n_generated_samples=gen_data.shape[0],
+         )
+
+         # Compute mean expression profiles
+         result.real_mean = real_data.mean(axis=0)
+         result.generated_mean = gen_data.mean(axis=0)
+
+         # Compute metrics
+         for metric in self.metrics:
+             try:
+                 metric_result = metric.compute(
+                     real=real_data,
+                     generated=gen_data,
+                     gene_names=self.data_loader.gene_names,
+                     aggregate_method="mean",
+                     condition=cond_key,
+                 )
+                 result.metrics[metric.name] = metric_result.aggregate_value
+             except Exception as e:
+                 warnings.warn(f"Failed to compute {metric.name} for {cond_key}: {e}")
+
+         return result
+
+     def _save_condition_result(
+         self,
+         result: StreamingConditionResult,
+         output_dir: Path,
+     ):
+         """Save a single condition result to disk."""
+         import json
+
+         condition_dir = output_dir / "conditions"
+         condition_dir.mkdir(exist_ok=True)
+
+         # Build a filesystem-safe filename
+         safe_key = result.condition_key.replace("/", "_").replace("\\", "_")
+
+         data = {
+             "condition_key": result.condition_key,
+             "n_real": result.n_real_samples,
+             "n_generated": result.n_generated_samples,
+             "metrics": result.metrics,
+         }
+
+         with open(condition_dir / f"{safe_key}.json", "w") as f:
+             json.dump(data, f, indent=2)
+
+     def _print_summary(self, result: StreamingEvaluationResult):
+         """Print a summary of the evaluation."""
+         self._log("\n" + "=" * 60)
+         self._log("EVALUATION SUMMARY (Memory-Efficient)")
+         self._log("=" * 60)
+         self._log(f"Conditions evaluated: {result.n_conditions}")
+         self._log("-" * 40)
+
+         for name, stats in result.get_summary().items():
+             self._log(f"  {name}: {stats['mean']:.4f} ± {stats['std']:.4f}")
+
+         self._log("=" * 60)
+
+
+ def evaluate_lazy(
+     real_path: Union[str, Path],
+     generated_path: Union[str, Path],
+     condition_columns: List[str],
+     split_column: Optional[str] = None,
+     output_dir: Optional[Union[str, Path]] = None,
+     batch_size: int = 256,
+     use_backed: bool = False,
+     metrics: Optional[List[Union[BaseMetric, Type[BaseMetric]]]] = None,
+     verbose: bool = True,
+     save_per_condition: bool = False,
+     **kwargs
+ ) -> StreamingEvaluationResult:
+     """
+     Memory-efficient evaluation using lazy loading.
+
+     Use this function for large datasets that don't fit in memory.
+
+     Parameters
+     ----------
+     real_path : str or Path
+         Path to real data h5ad file
+     generated_path : str or Path
+         Path to generated data h5ad file
+     condition_columns : List[str]
+         Columns to match between datasets
+     split_column : str, optional
+         Column for train/test split
+     output_dir : str or Path, optional
+         Directory to save results
+     batch_size : int
+         Batch size for processing
+     use_backed : bool
+         Use memory-mapped file access (for very large files)
+     metrics : List, optional
+         Metrics to compute
+     verbose : bool
+         Print progress
+     save_per_condition : bool
+         Save individual condition results
+     **kwargs
+         Additional arguments forwarded to the lazy data loader
+
+     Returns
+     -------
+     StreamingEvaluationResult
+         Aggregated evaluation results
+
+     Examples
+     --------
+     >>> # For large datasets that don't fit in memory
+     >>> results = evaluate_lazy(
+     ...     "real.h5ad",
+     ...     "generated.h5ad",
+     ...     condition_columns=["perturbation"],
+     ...     output_dir="eval_output/",
+     ...     batch_size=256,
+     ...     use_backed=True,  # Memory-mapped for very large files
+     ... )
+     >>> print(results.get_summary())
+     """
+     # Create the lazy loader (closed automatically on exit)
+     with load_data_lazy(
+         real_path=real_path,
+         generated_path=generated_path,
+         condition_columns=condition_columns,
+         split_column=split_column,
+         batch_size=batch_size,
+         use_backed=use_backed,
+         **kwargs,
+     ) as loader:
+         # Create the evaluator
+         evaluator = MemoryEfficientEvaluator(
+             data_loader=loader,
+             metrics=metrics,
+             batch_size=batch_size,
+             verbose=verbose,
+         )
+
+         # Run the evaluation
+         return evaluator.evaluate(
+             output_dir=output_dir,
+             save_per_condition=save_per_condition,
+         )
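
`StreamingMetricAccumulator` keeps only three scalars (n, Σx, Σx²) and derives mean = Σx/n and variance = Σx²/n − (Σx/n)², i.e. the one-pass population formulas. A small self-check against NumPy (illustrative only; this formula can lose precision when values have a large mean relative to their spread):

```python
import numpy as np
from geneval.lazy_evaluator import StreamingMetricAccumulator

rng = np.random.default_rng(0)
values = rng.normal(loc=2.0, scale=0.5, size=1000)

acc = StreamingMetricAccumulator()
for chunk in np.array_split(values, 10):  # feed the values in ten chunks
    acc.add_batch(chunk)

# The one-pass identities recover the full-batch statistics
assert np.isclose(acc.mean, values.mean())
assert np.isclose(acc.std, values.std())  # population std, NumPy's default
```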
gengeneeval-0.2.0.dist-info/METADATA → gengeneeval-0.2.1.dist-info/METADATA RENAMED
@@ -1,10 +1,10 @@
  Metadata-Version: 2.4
  Name: gengeneeval
- Version: 0.2.0
- Summary: Comprehensive evaluation of generated gene expression data. Computes metrics between real and generated datasets with support for condition matching, train/test splits, and publication-quality visualizations.
+ Version: 0.2.1
+ Summary: Comprehensive evaluation of generated gene expression data. Computes metrics between real and generated datasets with support for condition matching, train/test splits, memory-efficient lazy loading, and publication-quality visualizations.
  License: MIT
  License-File: LICENSE
- Keywords: gene expression,evaluation,metrics,single-cell,generative models,benchmarking
+ Keywords: gene expression,evaluation,metrics,single-cell,generative models,benchmarking,memory-efficient
  Author: GenEval Team
  Author-email: geneval@example.com
  Requires-Python: >=3.8,<4.0
@@ -46,7 +46,7 @@ Description-Content-Type: text/markdown
 
  **Comprehensive evaluation of generated gene expression data against real datasets.**
 
- GenEval is a modular, object-oriented Python framework for computing metrics between real and generated gene expression datasets stored in AnnData (h5ad) format. It supports condition-based matching, train/test splits, and generates publication-quality visualizations.
+ GenEval is a modular, object-oriented Python framework for computing metrics between real and generated gene expression datasets stored in AnnData (h5ad) format. It supports condition-based matching, train/test splits, memory-efficient lazy loading for large datasets, and generates publication-quality visualizations.
 
  ## Features
 
@@ -77,6 +77,8 @@ All metrics are computed **per-gene** (returning a vector) and **aggregated**:
  - ✅ Condition-based matching (perturbation, cell type, etc.)
  - ✅ Train/test split support
  - ✅ Per-gene and aggregate metrics
+ - ✅ **Memory-efficient lazy loading** for large datasets
+ - ✅ **Batched evaluation** to avoid OOM errors
  - ✅ Modular, extensible architecture
  - ✅ Command-line interface
  - ✅ Publication-quality visualizations
@@ -140,6 +142,37 @@ geneval --real real.h5ad --generated generated.h5ad \
      --output results/
  ```
 
+ ### Memory-Efficient Mode (for Large Datasets)
+
+ For datasets too large to fit in memory, use the lazy evaluation API:
+
+ ```python
+ from geneval import evaluate_lazy, load_data_lazy
+
+ # Memory-efficient evaluation (streams data one condition at a time)
+ results = evaluate_lazy(
+     real_path="large_real.h5ad",
+     generated_path="large_generated.h5ad",
+     condition_columns=["perturbation"],
+     batch_size=256,           # Process in batches
+     use_backed=True,          # Memory-mapped file access
+     output_dir="eval_output/",
+     save_per_condition=True,  # Save each condition to disk
+ )
+
+ # Get summary statistics
+ print(results.get_summary())
+
+ # Or use the lazy loader directly for custom workflows
+ with load_data_lazy("real.h5ad", "gen.h5ad", ["perturbation"]) as loader:
+     print(f"Memory estimate: {loader.estimate_memory_usage()}")
+
+     # Process one condition at a time
+     for key, real, gen, info in loader.iterate_conditions():
+         # Your custom evaluation logic
+         pass
+ ```
+
  ## Expected Data Format
 
  GenEval expects AnnData (h5ad) files with:
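
The README leaves the `use_backed` choice to the reader. One way to make that decision programmatically is to open the files in backed mode first (which is cheap) and consult `estimate_memory_usage()`; the 4 GB threshold below is an arbitrary illustration, not a package default:

```python
from geneval import evaluate_lazy, load_data_lazy

# Probe in backed mode so the estimate itself stays cheap
with load_data_lazy(
    "real.h5ad",
    "generated.h5ad",
    condition_columns=["perturbation"],
    use_backed=True,
) as probe:
    est = probe.estimate_memory_usage()

results = evaluate_lazy(
    "real.h5ad",
    "generated.h5ad",
    condition_columns=["perturbation"],
    use_backed=est["full_load_mb"] > 4000,  # stay memory-mapped for big inputs
    output_dir="eval_output/",
)
```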
gengeneeval-0.2.0.dist-info/RECORD → gengeneeval-0.2.1.dist-info/RECORD RENAMED
@@ -1,14 +1,16 @@
- geneval/__init__.py,sha256=WB9yj2OLpY3tjw545G0xzSP8iP_fn4vOIWA2WaraEWk,2977
+ geneval/__init__.py,sha256=_WxX5Kjk7y3u7mBZ5cf6ficy9SIT2FutZNcMe1fr9Ro,3989
  geneval/cli.py,sha256=0ai0IGyn3SSmEnfLRJhcr0brvUxuNZHE4IXod7jvosU,9977
  geneval/config.py,sha256=gkCjs_gzPWgUZNcmSR3Y70XQCAZ1m9AKLueaM-x8bvw,3729
  geneval/core.py,sha256=No0DP8bNR6LedfCWEedY9C5r_c4M14rvSPaGZqbxc94,1155
- geneval/data/__init__.py,sha256=nD3uWostZbYD3Yj_TOE44LvPDen-Vm3gN8ZH0QptPGw,450
+ geneval/data/__init__.py,sha256=NQUPVpUnBIabrTH5TuRk0KE9S7sVO5QetZv-MCQmZuw,827
  geneval/data/gene_expression_datamodule.py,sha256=XiBIdf68JZ-3S-FaZsrQlBJA7qL9uUXo2C8y0r4an5M,8009
+ geneval/data/lazy_loader.py,sha256=5fTRVjPjcWvYXV-uPWFUF2Nn9rHRdD8lygAUkCW8wOM,20677
  geneval/data/loader.py,sha256=zpRmwGZ4PJkB3rpXXRCMFtvMi4qvUrPkKmvIlGjfRpY,14555
  geneval/evaluator.py,sha256=wZFzLo2PLHanjA-9L6C3xJBjMWXxPM63kU6usU4P7bs,11619
  geneval/evaluators/__init__.py,sha256=i11sHvhsjEAeI3Aw9zFTPmCYuqkGxzTHggAKehe3HQ0,160
  geneval/evaluators/base_evaluator.py,sha256=yJL568HdNofIcHgNOElSQMVlG9oRPTTDIZ7CmKccRqs,5967
  geneval/evaluators/gene_expression_evaluator.py,sha256=v8QL6tzOQ3QVXdPMM8tFHTTviZC3WsPRX4G0ShgeDUw,8743
+ geneval/lazy_evaluator.py,sha256=I_VvDolxPFGiW38eGPrjSoBOKICKyYN3GHbjJBAe5tg,13200
  geneval/metrics/__init__.py,sha256=H5IXTKR-zoP_pGht6ioJfhLU7IHrSDQElMk0Cp4-JTw,1786
  geneval/metrics/base_metric.py,sha256=prbnB-Ap-P64m-2_TUrHxO3NFQaw-obVg1Tw4pjC5EY,6961
  geneval/metrics/correlation.py,sha256=jpYmaihWK89J1E5yQinGUJeB6pTZ21xPNHJi3XYyXJE,6987
@@ -25,8 +27,8 @@ geneval/utils/preprocessing.py,sha256=1Cij1O2dwDR6_zh5IEgLPq3jEmV8VfIRjfQrHiKe3M
  geneval/visualization/__init__.py,sha256=LN19jl5xV4WVJTePaOUHWvKZ_pgDFp1chhcklGkNtm8,792
  geneval/visualization/plots.py,sha256=3K94r3x5NjIUZ-hYVQIivO63VkLOvDWl-BLB_qL2pSY,15008
  geneval/visualization/visualizer.py,sha256=lX7K0j20nAsgdtOOdbxLdLKYAfovEp3hNAnZOjFTCq0,36670
- gengeneeval-0.2.0.dist-info/METADATA,sha256=LyE9iQQUOMakv91vebWNBcg10quJatQwYHGucqri3Rw,6262
- gengeneeval-0.2.0.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
- gengeneeval-0.2.0.dist-info/entry_points.txt,sha256=xTkwnNa2fP0w1uGVsafzRTaCeuBSWLlNO-1CN8uBSK0,43
- gengeneeval-0.2.0.dist-info/licenses/LICENSE,sha256=RDHgHDI4rSDq35R4CAC3npy86YUnmZ81ecO7aHfmmGA,1073
- gengeneeval-0.2.0.dist-info/RECORD,,
+ gengeneeval-0.2.1.dist-info/METADATA,sha256=aRjsh5JUIcH8huIngsCi14eyE3-Vl_AFv9Uo1j5mciw,7497
+ gengeneeval-0.2.1.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
+ gengeneeval-0.2.1.dist-info/entry_points.txt,sha256=xTkwnNa2fP0w1uGVsafzRTaCeuBSWLlNO-1CN8uBSK0,43
+ gengeneeval-0.2.1.dist-info/licenses/LICENSE,sha256=RDHgHDI4rSDq35R4CAC3npy86YUnmZ81ecO7aHfmmGA,1073
+ gengeneeval-0.2.1.dist-info/RECORD,,