gengeneeval 0.2.0__tar.gz → 0.2.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {gengeneeval-0.2.0 → gengeneeval-0.2.1}/PKG-INFO +37 -4
- {gengeneeval-0.2.0 → gengeneeval-0.2.1}/README.md +34 -1
- {gengeneeval-0.2.0 → gengeneeval-0.2.1}/pyproject.toml +3 -3
- {gengeneeval-0.2.0 → gengeneeval-0.2.1}/src/geneval/__init__.py +39 -1
- {gengeneeval-0.2.0 → gengeneeval-0.2.1}/src/geneval/data/__init__.py +14 -0
- gengeneeval-0.2.1/src/geneval/data/lazy_loader.py +562 -0
- gengeneeval-0.2.1/src/geneval/lazy_evaluator.py +424 -0
- {gengeneeval-0.2.0 → gengeneeval-0.2.1}/LICENSE +0 -0
- {gengeneeval-0.2.0 → gengeneeval-0.2.1}/src/geneval/cli.py +0 -0
- {gengeneeval-0.2.0 → gengeneeval-0.2.1}/src/geneval/config.py +0 -0
- {gengeneeval-0.2.0 → gengeneeval-0.2.1}/src/geneval/core.py +0 -0
- {gengeneeval-0.2.0 → gengeneeval-0.2.1}/src/geneval/data/gene_expression_datamodule.py +0 -0
- {gengeneeval-0.2.0 → gengeneeval-0.2.1}/src/geneval/data/loader.py +0 -0
- {gengeneeval-0.2.0 → gengeneeval-0.2.1}/src/geneval/evaluator.py +0 -0
- {gengeneeval-0.2.0 → gengeneeval-0.2.1}/src/geneval/evaluators/__init__.py +0 -0
- {gengeneeval-0.2.0 → gengeneeval-0.2.1}/src/geneval/evaluators/base_evaluator.py +0 -0
- {gengeneeval-0.2.0 → gengeneeval-0.2.1}/src/geneval/evaluators/gene_expression_evaluator.py +0 -0
- {gengeneeval-0.2.0 → gengeneeval-0.2.1}/src/geneval/metrics/__init__.py +0 -0
- {gengeneeval-0.2.0 → gengeneeval-0.2.1}/src/geneval/metrics/base_metric.py +0 -0
- {gengeneeval-0.2.0 → gengeneeval-0.2.1}/src/geneval/metrics/correlation.py +0 -0
- {gengeneeval-0.2.0 → gengeneeval-0.2.1}/src/geneval/metrics/distances.py +0 -0
- {gengeneeval-0.2.0 → gengeneeval-0.2.1}/src/geneval/metrics/metrics.py +0 -0
- {gengeneeval-0.2.0 → gengeneeval-0.2.1}/src/geneval/metrics/reconstruction.py +0 -0
- {gengeneeval-0.2.0 → gengeneeval-0.2.1}/src/geneval/models/__init__.py +0 -0
- {gengeneeval-0.2.0 → gengeneeval-0.2.1}/src/geneval/models/base_model.py +0 -0
- {gengeneeval-0.2.0 → gengeneeval-0.2.1}/src/geneval/results.py +0 -0
- {gengeneeval-0.2.0 → gengeneeval-0.2.1}/src/geneval/testing.py +0 -0
- {gengeneeval-0.2.0 → gengeneeval-0.2.1}/src/geneval/utils/__init__.py +0 -0
- {gengeneeval-0.2.0 → gengeneeval-0.2.1}/src/geneval/utils/io.py +0 -0
- {gengeneeval-0.2.0 → gengeneeval-0.2.1}/src/geneval/utils/preprocessing.py +0 -0
- {gengeneeval-0.2.0 → gengeneeval-0.2.1}/src/geneval/visualization/__init__.py +0 -0
- {gengeneeval-0.2.0 → gengeneeval-0.2.1}/src/geneval/visualization/plots.py +0 -0
- {gengeneeval-0.2.0 → gengeneeval-0.2.1}/src/geneval/visualization/visualizer.py +0 -0
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: gengeneeval
|
|
3
|
-
Version: 0.2.0
|
|
4
|
-
Summary: Comprehensive evaluation of generated gene expression data. Computes metrics between real and generated datasets with support for condition matching, train/test splits, and publication-quality visualizations.
|
|
3
|
+
Version: 0.2.1
|
|
4
|
+
Summary: Comprehensive evaluation of generated gene expression data. Computes metrics between real and generated datasets with support for condition matching, train/test splits, memory-efficient lazy loading, and publication-quality visualizations.
|
|
5
5
|
License: MIT
|
|
6
6
|
License-File: LICENSE
|
|
7
|
-
Keywords: gene expression,evaluation,metrics,single-cell,generative models,benchmarking
|
|
7
|
+
Keywords: gene expression,evaluation,metrics,single-cell,generative models,benchmarking,memory-efficient
|
|
8
8
|
Author: GenEval Team
|
|
9
9
|
Author-email: geneval@example.com
|
|
10
10
|
Requires-Python: >=3.8,<4.0
|
|
@@ -46,7 +46,7 @@ Description-Content-Type: text/markdown
|
|
|
46
46
|
|
|
47
47
|
**Comprehensive evaluation of generated gene expression data against real datasets.**
|
|
48
48
|
|
|
49
|
-
GenEval is a modular, object-oriented Python framework for computing metrics between real and generated gene expression datasets stored in AnnData (h5ad) format. It supports condition-based matching, train/test splits, and generates publication-quality visualizations.
|
|
49
|
+
GenEval is a modular, object-oriented Python framework for computing metrics between real and generated gene expression datasets stored in AnnData (h5ad) format. It supports condition-based matching, train/test splits, memory-efficient lazy loading for large datasets, and generates publication-quality visualizations.
|
|
50
50
|
|
|
51
51
|
## Features
|
|
52
52
|
|
|
@@ -77,6 +77,8 @@ All metrics are computed **per-gene** (returning a vector) and **aggregated**:
|
|
|
77
77
|
- ✅ Condition-based matching (perturbation, cell type, etc.)
|
|
78
78
|
- ✅ Train/test split support
|
|
79
79
|
- ✅ Per-gene and aggregate metrics
|
|
80
|
+
- ✅ **Memory-efficient lazy loading** for large datasets
|
|
81
|
+
- ✅ **Batched evaluation** to avoid OOM errors
|
|
80
82
|
- ✅ Modular, extensible architecture
|
|
81
83
|
- ✅ Command-line interface
|
|
82
84
|
- ✅ Publication-quality visualizations
|
|
@@ -140,6 +142,37 @@ geneval --real real.h5ad --generated generated.h5ad \
|
|
|
140
142
|
--output results/
|
|
141
143
|
```
|
|
142
144
|
|
|
145
|
+
### Memory-Efficient Mode (for Large Datasets)
|
|
146
|
+
|
|
147
|
+
For datasets too large to fit in memory, use the lazy evaluation API:
|
|
148
|
+
|
|
149
|
+
```python
|
|
150
|
+
from geneval import evaluate_lazy, load_data_lazy
|
|
151
|
+
|
|
152
|
+
# Memory-efficient evaluation (streams data one condition at a time)
|
|
153
|
+
results = evaluate_lazy(
|
|
154
|
+
real_path="large_real.h5ad",
|
|
155
|
+
generated_path="large_generated.h5ad",
|
|
156
|
+
condition_columns=["perturbation"],
|
|
157
|
+
batch_size=256, # Process in batches
|
|
158
|
+
use_backed=True, # Memory-mapped file access
|
|
159
|
+
output_dir="eval_output/",
|
|
160
|
+
save_per_condition=True, # Save each condition to disk
|
|
161
|
+
)
|
|
162
|
+
|
|
163
|
+
# Get summary statistics
|
|
164
|
+
print(results.get_summary())
|
|
165
|
+
|
|
166
|
+
# Or use the lazy loader directly for custom workflows
|
|
167
|
+
with load_data_lazy("real.h5ad", "gen.h5ad", ["perturbation"]) as loader:
|
|
168
|
+
print(f"Memory estimate: {loader.estimate_memory_usage()}")
|
|
169
|
+
|
|
170
|
+
# Process one condition at a time
|
|
171
|
+
for key, real, gen, info in loader.iterate_conditions():
|
|
172
|
+
# Your custom evaluation logic
|
|
173
|
+
pass
|
|
174
|
+
```
|
|
175
|
+
|
|
143
176
|
## Expected Data Format
|
|
144
177
|
|
|
145
178
|
GenEval expects AnnData (h5ad) files with:
|
|
@@ -7,7 +7,7 @@
|
|
|
7
7
|
|
|
8
8
|
**Comprehensive evaluation of generated gene expression data against real datasets.**
|
|
9
9
|
|
|
10
|
-
GenEval is a modular, object-oriented Python framework for computing metrics between real and generated gene expression datasets stored in AnnData (h5ad) format. It supports condition-based matching, train/test splits, and generates publication-quality visualizations.
|
|
10
|
+
GenEval is a modular, object-oriented Python framework for computing metrics between real and generated gene expression datasets stored in AnnData (h5ad) format. It supports condition-based matching, train/test splits, memory-efficient lazy loading for large datasets, and generates publication-quality visualizations.
|
|
11
11
|
|
|
12
12
|
## Features
|
|
13
13
|
|
|
@@ -38,6 +38,8 @@ All metrics are computed **per-gene** (returning a vector) and **aggregated**:
|
|
|
38
38
|
- ✅ Condition-based matching (perturbation, cell type, etc.)
|
|
39
39
|
- ✅ Train/test split support
|
|
40
40
|
- ✅ Per-gene and aggregate metrics
|
|
41
|
+
- ✅ **Memory-efficient lazy loading** for large datasets
|
|
42
|
+
- ✅ **Batched evaluation** to avoid OOM errors
|
|
41
43
|
- ✅ Modular, extensible architecture
|
|
42
44
|
- ✅ Command-line interface
|
|
43
45
|
- ✅ Publication-quality visualizations
|
|
@@ -101,6 +103,37 @@ geneval --real real.h5ad --generated generated.h5ad \
|
|
|
101
103
|
--output results/
|
|
102
104
|
```
|
|
103
105
|
|
|
106
|
+
### Memory-Efficient Mode (for Large Datasets)
|
|
107
|
+
|
|
108
|
+
For datasets too large to fit in memory, use the lazy evaluation API:
|
|
109
|
+
|
|
110
|
+
```python
|
|
111
|
+
from geneval import evaluate_lazy, load_data_lazy
|
|
112
|
+
|
|
113
|
+
# Memory-efficient evaluation (streams data one condition at a time)
|
|
114
|
+
results = evaluate_lazy(
|
|
115
|
+
real_path="large_real.h5ad",
|
|
116
|
+
generated_path="large_generated.h5ad",
|
|
117
|
+
condition_columns=["perturbation"],
|
|
118
|
+
batch_size=256, # Process in batches
|
|
119
|
+
use_backed=True, # Memory-mapped file access
|
|
120
|
+
output_dir="eval_output/",
|
|
121
|
+
save_per_condition=True, # Save each condition to disk
|
|
122
|
+
)
|
|
123
|
+
|
|
124
|
+
# Get summary statistics
|
|
125
|
+
print(results.get_summary())
|
|
126
|
+
|
|
127
|
+
# Or use the lazy loader directly for custom workflows
|
|
128
|
+
with load_data_lazy("real.h5ad", "gen.h5ad", ["perturbation"]) as loader:
|
|
129
|
+
print(f"Memory estimate: {loader.estimate_memory_usage()}")
|
|
130
|
+
|
|
131
|
+
# Process one condition at a time
|
|
132
|
+
for key, real, gen, info in loader.iterate_conditions():
|
|
133
|
+
# Your custom evaluation logic
|
|
134
|
+
pass
|
|
135
|
+
```
|
|
136
|
+
|
|
104
137
|
## Expected Data Format
|
|
105
138
|
|
|
106
139
|
GenEval expects AnnData (h5ad) files with:
|
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
[tool.poetry]
|
|
2
2
|
name = "gengeneeval"
|
|
3
|
-
version = "0.2.0"
|
|
4
|
-
description = "Comprehensive evaluation of generated gene expression data. Computes metrics between real and generated datasets with support for condition matching, train/test splits, and publication-quality visualizations."
|
|
3
|
+
version = "0.2.1"
|
|
4
|
+
description = "Comprehensive evaluation of generated gene expression data. Computes metrics between real and generated datasets with support for condition matching, train/test splits, memory-efficient lazy loading, and publication-quality visualizations."
|
|
5
5
|
authors = ["GenEval Team <geneval@example.com>"]
|
|
6
6
|
license = "MIT"
|
|
7
7
|
readme = "README.md"
|
|
8
8
|
homepage = "https://github.com/AndreaRubbi/GenGeneEval"
|
|
9
9
|
repository = "https://github.com/AndreaRubbi/GenGeneEval"
|
|
10
|
-
keywords = ["gene expression", "evaluation", "metrics", "single-cell", "generative models", "benchmarking"]
|
|
10
|
+
keywords = ["gene expression", "evaluation", "metrics", "single-cell", "generative models", "benchmarking", "memory-efficient"]
|
|
11
11
|
classifiers = [
|
|
12
12
|
"Development Status :: 4 - Beta",
|
|
13
13
|
"Intended Audience :: Science/Research",
|
|
@@ -8,6 +8,7 @@ Features:
|
|
|
8
8
|
- Multiple distance and correlation metrics (per-gene and aggregate)
|
|
9
9
|
- Condition-based matching (perturbation, cell type, etc.)
|
|
10
10
|
- Train/test split support
|
|
11
|
+
- Memory-efficient lazy loading for large datasets
|
|
11
12
|
- Publication-quality visualizations
|
|
12
13
|
- Command-line interface
|
|
13
14
|
|
|
@@ -20,12 +21,22 @@ Quick Start:
|
|
|
20
21
|
... output_dir="output/"
|
|
21
22
|
... )
|
|
22
23
|
|
|
24
|
+
Memory-Efficient Mode (for large datasets):
|
|
25
|
+
>>> from geneval import evaluate_lazy
|
|
26
|
+
>>> results = evaluate_lazy(
|
|
27
|
+
... real_path="real.h5ad",
|
|
28
|
+
... generated_path="generated.h5ad",
|
|
29
|
+
... condition_columns=["perturbation"],
|
|
30
|
+
... batch_size=256,
|
|
31
|
+
... use_backed=True, # Memory-mapped access
|
|
32
|
+
... )
|
|
33
|
+
|
|
23
34
|
CLI Usage:
|
|
24
35
|
$ geneval --real real.h5ad --generated generated.h5ad \\
|
|
25
36
|
--conditions perturbation cell_type --output results/
|
|
26
37
|
"""
|
|
27
38
|
|
|
28
|
-
__version__ = "0.2.
|
|
39
|
+
__version__ = "0.2.1"
|
|
29
40
|
__author__ = "GenEval Team"
|
|
30
41
|
|
|
31
42
|
# Main evaluation interface
|
|
@@ -35,12 +46,26 @@ from .evaluator import (
|
|
|
35
46
|
MetricRegistry,
|
|
36
47
|
)
|
|
37
48
|
|
|
49
|
+
# Memory-efficient evaluation
|
|
50
|
+
from .lazy_evaluator import (
|
|
51
|
+
evaluate_lazy,
|
|
52
|
+
MemoryEfficientEvaluator,
|
|
53
|
+
StreamingEvaluationResult,
|
|
54
|
+
)
|
|
55
|
+
|
|
38
56
|
# Data loading
|
|
39
57
|
from .data.loader import (
|
|
40
58
|
GeneExpressionDataLoader,
|
|
41
59
|
load_data,
|
|
42
60
|
)
|
|
43
61
|
|
|
62
|
+
# Memory-efficient data loading
|
|
63
|
+
from .data.lazy_loader import (
|
|
64
|
+
LazyGeneExpressionDataLoader,
|
|
65
|
+
load_data_lazy,
|
|
66
|
+
ConditionBatch,
|
|
67
|
+
)
|
|
68
|
+
|
|
44
69
|
# Results
|
|
45
70
|
from .results import (
|
|
46
71
|
EvaluationResult,
|
|
@@ -99,9 +124,17 @@ __all__ = [
|
|
|
99
124
|
"evaluate",
|
|
100
125
|
"GeneEvalEvaluator",
|
|
101
126
|
"MetricRegistry",
|
|
127
|
+
# Memory-efficient evaluation
|
|
128
|
+
"evaluate_lazy",
|
|
129
|
+
"MemoryEfficientEvaluator",
|
|
130
|
+
"StreamingEvaluationResult",
|
|
102
131
|
# Data loading
|
|
103
132
|
"GeneExpressionDataLoader",
|
|
104
133
|
"load_data",
|
|
134
|
+
# Memory-efficient data loading
|
|
135
|
+
"LazyGeneExpressionDataLoader",
|
|
136
|
+
"load_data_lazy",
|
|
137
|
+
"ConditionBatch",
|
|
105
138
|
# Results
|
|
106
139
|
"EvaluationResult",
|
|
107
140
|
"SplitResult",
|
|
@@ -123,6 +156,11 @@ __all__ = [
|
|
|
123
156
|
"EnergyDistance",
|
|
124
157
|
"MultivariateWasserstein",
|
|
125
158
|
"MultivariateMMD",
|
|
159
|
+
# Reconstruction metrics
|
|
160
|
+
"MSEDistance",
|
|
161
|
+
"RMSEDistance",
|
|
162
|
+
"MAEDistance",
|
|
163
|
+
"R2Score",
|
|
126
164
|
# Visualization
|
|
127
165
|
"EvaluationVisualizer",
|
|
128
166
|
"visualize",
|
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
Data loading module for gene expression evaluation.
|
|
3
3
|
|
|
4
4
|
Provides data loaders for paired real and generated datasets.
|
|
5
|
+
Includes both standard and memory-efficient lazy loading options.
|
|
5
6
|
"""
|
|
6
7
|
|
|
7
8
|
from .loader import (
|
|
@@ -9,15 +10,28 @@ from .loader import (
|
|
|
9
10
|
load_data,
|
|
10
11
|
DataLoaderError,
|
|
11
12
|
)
|
|
13
|
+
from .lazy_loader import (
|
|
14
|
+
LazyGeneExpressionDataLoader,
|
|
15
|
+
load_data_lazy,
|
|
16
|
+
LazyDataLoaderError,
|
|
17
|
+
ConditionBatch,
|
|
18
|
+
)
|
|
12
19
|
from .gene_expression_datamodule import (
|
|
13
20
|
GeneExpressionDataModule,
|
|
14
21
|
DataModuleError,
|
|
15
22
|
)
|
|
16
23
|
|
|
17
24
|
__all__ = [
|
|
25
|
+
# Standard loader
|
|
18
26
|
"GeneExpressionDataLoader",
|
|
19
27
|
"load_data",
|
|
20
28
|
"DataLoaderError",
|
|
29
|
+
# Lazy loader (memory-efficient)
|
|
30
|
+
"LazyGeneExpressionDataLoader",
|
|
31
|
+
"load_data_lazy",
|
|
32
|
+
"LazyDataLoaderError",
|
|
33
|
+
"ConditionBatch",
|
|
34
|
+
# DataModule
|
|
21
35
|
"GeneExpressionDataModule",
|
|
22
36
|
"DataModuleError",
|
|
23
37
|
]
|
|
@@ -0,0 +1,562 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Memory-efficient lazy data loader for large-scale gene expression datasets.
|
|
3
|
+
|
|
4
|
+
Provides lazy loading and batched iteration over AnnData h5ad files without
|
|
5
|
+
loading entire datasets into memory. Supports backed mode for very large files.
|
|
6
|
+
"""
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from typing import Optional, List, Union, Dict, Tuple, Iterator, Generator
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
import warnings
|
|
12
|
+
import numpy as np
|
|
13
|
+
import pandas as pd
|
|
14
|
+
from scipy import sparse
|
|
15
|
+
from dataclasses import dataclass
|
|
16
|
+
|
|
17
|
+
try:
|
|
18
|
+
import anndata as ad
|
|
19
|
+
import scanpy as sc
|
|
20
|
+
except ImportError:
|
|
21
|
+
raise ImportError("anndata and scanpy are required. Install with: pip install anndata scanpy")
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class LazyDataLoaderError(Exception):
    """Raised by the lazy loader when inputs are unusable.

    Covers missing input files, condition columns absent from either
    dataset, and datasets that share no gene names.
    """
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass
class ConditionBatch:
    """One aligned slice of real and generated samples for a single condition.

    Instances are produced by ``iterate_conditions_batched``; both data arrays
    hold rows of samples restricted to the genes shared by the two datasets.
    """

    # Condition values joined with "####" (see _get_condition_key).
    condition_key: str
    # Condition column name -> value for this condition.
    condition_info: Dict[str, str]
    # Real expression rows in this batch (samples x common genes).
    real_data: np.ndarray
    # Generated expression rows in this batch (samples x common genes).
    generated_data: np.ndarray
    # Zero-based index of this batch within its condition.
    batch_idx: int
    # Total number of batches scheduled for the condition.
    n_batches: int
    # True only for the condition's final batch.
    is_last_batch: bool
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class LazyGeneExpressionDataLoader:
|
|
42
|
+
"""
|
|
43
|
+
Memory-efficient lazy data loader for paired gene expression datasets.
|
|
44
|
+
|
|
45
|
+
Unlike GeneExpressionDataLoader, this class:
|
|
46
|
+
- Uses backed mode for h5ad files to avoid loading entire datasets
|
|
47
|
+
- Supports batched iteration over conditions
|
|
48
|
+
- Only loads data into memory when explicitly requested
|
|
49
|
+
- Provides memory usage estimates
|
|
50
|
+
|
|
51
|
+
Parameters
|
|
52
|
+
----------
|
|
53
|
+
real_path : str or Path
|
|
54
|
+
Path to real data h5ad file
|
|
55
|
+
generated_path : str or Path
|
|
56
|
+
Path to generated data h5ad file
|
|
57
|
+
condition_columns : List[str]
|
|
58
|
+
Columns to match between datasets
|
|
59
|
+
split_column : str, optional
|
|
60
|
+
Column indicating train/test split
|
|
61
|
+
batch_size : int
|
|
62
|
+
Maximum number of samples per batch when iterating
|
|
63
|
+
use_backed : bool
|
|
64
|
+
If True, use backed mode (memory-mapped). May be slower but uses minimal memory.
|
|
65
|
+
If False, loads full file but processes in batches.
|
|
66
|
+
min_samples_per_condition : int
|
|
67
|
+
Minimum samples required per condition to include
|
|
68
|
+
"""
|
|
69
|
+
|
|
70
|
+
def __init__(
    self,
    real_path: Union[str, Path],
    generated_path: Union[str, Path],
    condition_columns: List[str],
    split_column: Optional[str] = None,
    batch_size: int = 256,
    use_backed: bool = False,
    min_samples_per_condition: int = 2,
):
    """Record configuration only; no file is opened until ``initialize()``.

    Both paths are normalised to ``pathlib.Path``; every other argument is
    stored exactly as given (see the class docstring for their meaning).
    """
    # --- user configuration -------------------------------------------------
    self.real_path, self.generated_path = Path(real_path), Path(generated_path)
    self.condition_columns = condition_columns
    self.split_column = split_column
    self.batch_size = batch_size
    self.use_backed = use_backed
    self.min_samples_per_condition = min_samples_per_condition

    # --- filled in by initialize() ------------------------------------------
    # AnnData objects: backed (memory-mapped) or fully loaded.
    self._real: Optional[ad.AnnData] = None
    self._generated: Optional[ad.AnnData] = None

    # Cached obs tables and string-typed gene names of each file.
    self._real_obs: Optional[pd.DataFrame] = None
    self._generated_obs: Optional[pd.DataFrame] = None
    self._real_var_names: Optional[pd.Index] = None
    self._generated_var_names: Optional[pd.Index] = None

    # Shared genes and their positional indices in each file's var axis.
    self._common_genes: Optional[List[str]] = None
    self._real_gene_idx: Optional[np.ndarray] = None
    self._gen_gene_idx: Optional[np.ndarray] = None

    # {"real": {key: row positions}, "generated": {key: row positions}}
    self._condition_indices: Optional[Dict[str, Dict[str, np.ndarray]]] = None

    self._is_initialized = False
|
|
108
|
+
|
|
109
|
+
def initialize(self) -> "LazyGeneExpressionDataLoader":
    """Prepare the loader for iteration; calling it again is a no-op.

    Checks that both files exist, then loads metadata, validates the
    condition columns, aligns genes and groups samples by condition. In
    backed mode only metadata is read eagerly; otherwise ``_load_metadata``
    reads each file in full.

    Returns
    -------
    self
        For method chaining.

    Raises
    ------
    LazyDataLoaderError
        If either input file does not exist (or a later validation fails).
    """
    if self._is_initialized:
        return self

    for label, path in (("Real", self.real_path), ("Generated", self.generated_path)):
        if not path.exists():
            raise LazyDataLoaderError(f"{label} data file not found: {path}")

    # Metadata must be cached first: the remaining steps read the cached
    # obs tables and var_names.
    self._load_metadata()
    self._validate_columns()
    self._compute_gene_alignment()
    self._precompute_condition_indices()

    self._is_initialized = True
    return self
|
|
145
|
+
|
|
146
|
+
def _load_metadata(self):
    """Open both h5ad files and cache their obs tables and gene names.

    With ``use_backed=True`` both files are opened read-only in backed mode
    (``backed='r'``). If that fails for either file, a warning is issued, any
    handle that was already opened in backed mode is closed, and both files
    are read normally instead. A normal read loads the full file, expression
    matrix included, into memory.

    Populates ``_real`` / ``_generated``, copies of their ``obs`` tables in
    ``_real_obs`` / ``_generated_obs``, and ``var_names`` converted to ``str``
    in ``_real_var_names`` / ``_generated_var_names``.
    """
    if self.use_backed:
        try:
            self._real = sc.read_h5ad(self.real_path, backed='r')
            self._generated = sc.read_h5ad(self.generated_path, backed='r')
        except Exception as e:
            warnings.warn(f"Backed mode failed, falling back to standard loading: {e}")
            # If the real file opened but the generated one did not, dropping
            # the reference would leave the real file handle open.
            for adata in (self._real, self._generated):
                if adata is not None and getattr(adata, "isbacked", False):
                    try:
                        adata.file.close()
                    except Exception as close_err:  # best-effort cleanup
                        warnings.warn(f"Could not close backed file: {close_err}")
            self._real = None
            self._generated = None

    if self._real is None:
        # Standard read: the entire file (including X) is loaded into memory.
        self._real = sc.read_h5ad(self.real_path)
        self._generated = sc.read_h5ad(self.generated_path)

    # Cache lightweight metadata used by validation, alignment and grouping.
    self._real_obs = self._real.obs.copy()
    self._generated_obs = self._generated.obs.copy()
    self._real_var_names = pd.Index(self._real.var_names.astype(str))
    self._generated_var_names = pd.Index(self._generated.var_names.astype(str))
|
|
172
|
+
|
|
173
|
+
def _validate_columns(self):
    """Ensure every condition column is present in both obs tables.

    Columns are checked in order, real data before generated data for each
    column; the first missing one is reported.

    Raises
    ------
    LazyDataLoaderError
        Naming the missing column and listing the available ones.
    """
    tables = (("real", self._real_obs), ("generated", self._generated_obs))
    for col in self.condition_columns:
        for label, obs in tables:
            if col in obs.columns:
                continue
            raise LazyDataLoaderError(
                f"Condition column '{col}' not found in {label} data. "
                f"Available: {list(obs.columns)}"
            )
|
|
186
|
+
|
|
187
|
+
def _compute_gene_alignment(self):
    """Compute the shared genes and their positions in each dataset.

    Sets ``_common_genes`` (the intersection of both var_names indexes),
    ``_real_gene_idx`` and ``_gen_gene_idx`` (integer positions of those
    genes in each file's var axis, in ``_common_genes`` order). Warns when
    either side has genes that are dropped.

    Raises
    ------
    LazyDataLoaderError
        If the datasets share no genes, or if either dataset has duplicate
        gene names (positional lookup would be ambiguous; pandas would
        otherwise fail with an opaque reindexing error).
    """
    common = self._real_var_names.intersection(self._generated_var_names)

    if len(common) == 0:
        raise LazyDataLoaderError(
            "No overlapping genes between real and generated data."
        )

    for label, names in (
        ("real", self._real_var_names),
        ("generated", self._generated_var_names),
    ):
        if not names.is_unique:
            examples = names[names.duplicated()].unique()[:5].tolist()
            raise LazyDataLoaderError(
                f"Duplicate gene names in {label} data (e.g. {examples}); "
                "make var_names unique (e.g. AnnData.var_names_make_unique()) "
                "before evaluation."
            )

    self._common_genes = common.tolist()
    self._real_gene_idx = self._real_var_names.get_indexer(common)
    self._gen_gene_idx = self._generated_var_names.get_indexer(common)

    n_real_only = len(self._real_var_names) - len(common)
    n_gen_only = len(self._generated_var_names) - len(common)

    if n_real_only > 0 or n_gen_only > 0:
        warnings.warn(
            f"Gene alignment: {len(common)} common genes. "
            f"Dropped {n_real_only} from real, {n_gen_only} from generated."
        )
|
|
208
|
+
|
|
209
|
+
def _get_condition_key(self, row: pd.Series) -> str:
    """Build the condition key: the row's condition values as strings, joined by "####"."""
    separator = "####"
    values = (row[column] for column in self.condition_columns)
    return separator.join(map(str, values))
|
|
212
|
+
|
|
213
|
+
def _precompute_condition_indices(self):
    """Group sample row positions by condition for both datasets.

    Populates ``self._condition_indices`` as
    ``{"real": {key: positions}, "generated": {key: positions}}``. Each
    ``key`` is the condition values cast to ``str`` and joined with "####"
    (the same format as ``_get_condition_key``). ``positions`` are ascending
    integer row positions into the matching obs table. Conditions with fewer
    than ``min_samples_per_condition`` samples are left out.

    Uses one factorize + stable argsort pass per dataset, O(n log n), instead
    of building a full boolean mask for every distinct condition, which is
    O(n * n_conditions).
    """
    self._condition_indices = {
        "real": self._indices_by_condition(self._real_obs),
        "generated": self._indices_by_condition(self._generated_obs),
    }


def _indices_by_condition(self, obs: pd.DataFrame) -> Dict[str, np.ndarray]:
    """Map each condition key in ``obs`` to its ascending row positions (filtered by min size)."""
    n_obs = len(obs)
    if self.condition_columns:
        columns = [obs[col].astype(str).to_numpy() for col in self.condition_columns]
        keys = np.array(["####".join(values) for values in zip(*columns)], dtype=object)
    else:
        # No condition columns: every row belongs to the single key "".
        keys = np.full(n_obs, "", dtype=object)

    codes, uniques = pd.factorize(keys)
    # A stable sort keeps each group's positions in ascending row order.
    order = np.argsort(codes, kind="stable")
    group_ends = np.cumsum(np.bincount(codes, minlength=len(uniques)))

    grouped: Dict[str, np.ndarray] = {}
    for key, positions in zip(uniques, np.split(order, group_ends[:-1])):
        if len(positions) >= self.min_samples_per_condition:
            grouped[key] = positions
    return grouped
|
|
240
|
+
|
|
241
|
+
def get_splits(self) -> List[str]:
    """Return the distinct split labels of the real data as strings.

    Falls back to ``["all"]`` when no split column is configured or the
    column is absent from the real obs table. Initializes lazily.
    """
    if not self._is_initialized:
        self.initialize()

    column = self.split_column
    usable = column is not None and column in self._real_obs.columns
    if not usable:
        return ["all"]

    labels = self._real_obs[column].astype(str)
    return list(labels.unique())
|
|
250
|
+
|
|
251
|
+
def get_common_conditions(self, split: Optional[str] = None) -> List[str]:
    """Return the sorted condition keys present in both datasets.

    Only conditions that survived the ``min_samples_per_condition`` filter
    are considered. When ``split`` is given (and not ``"all"``) and a split
    column is configured, a key is kept only if at least one of its real
    samples carries that split label; generated samples are not
    split-filtered.
    """
    if not self._is_initialized:
        self.initialize()

    real_index = self._condition_indices["real"]
    shared = set(real_index) & set(self._condition_indices["generated"])

    restrict = split is not None and split != "all" and self.split_column is not None
    if restrict:
        def _has_split(key):
            labels = self._real_obs.iloc[real_index[key]][self.split_column].astype(str)
            return (labels == split).any()

        shared = {key for key in shared if _has_split(key)}

    return sorted(shared)
|
|
272
|
+
|
|
273
|
+
def _extract_data_subset(
    self,
    adata: ad.AnnData,
    indices: np.ndarray,
    gene_idx: np.ndarray,
) -> np.ndarray:
    """Return selected rows and genes of ``adata.X`` as a dense float32 array.

    Rows (``indices``) are selected first and genes (``gene_idx``) second.
    The same indexing serves backed and in-memory AnnData, so no mode
    branch is needed.

    Returns
    -------
    np.ndarray
        Shape ``(len(indices), len(gene_idx))``, dtype float32.
    """
    subset = adata.X[indices][:, gene_idx]

    if sparse.issparse(subset):
        # Cast before densifying so float64 sparse input never creates a
        # temporary float64 dense copy (halves peak memory for such data).
        subset = subset.astype(np.float32).toarray()

    return np.asarray(subset, dtype=np.float32)
|
|
293
|
+
|
|
294
|
+
def iterate_conditions(
    self,
    split: Optional[str] = None,
) -> Generator[Tuple[str, np.ndarray, np.ndarray, Dict[str, str]], None, None]:
    """
    Yield the data of one shared condition at a time.

    Arrays are materialised per condition only when that condition is
    reached; the loader keeps no reference to yielded arrays.

    Parameters
    ----------
    split : str, optional
        Restrict the real samples to this split label. Generated samples are
        not split-filtered. ``None`` or ``"all"`` disables filtering.

    Yields
    ------
    Tuple[str, np.ndarray, np.ndarray, Dict[str, str]]
        ``(condition_key, real_data, generated_data, condition_info)``.
        Conditions whose real samples fall below
        ``min_samples_per_condition`` after split filtering are skipped.
    """
    if not self._is_initialized:
        self.initialize()

    restrict = split is not None and split != "all" and self.split_column is not None

    for key in self.get_common_conditions(split):
        real_rows = self._condition_indices["real"][key]
        gen_rows = self._condition_indices["generated"][key]

        if restrict:
            labels = self._real_obs.iloc[real_rows][self.split_column].astype(str)
            real_rows = real_rows[(labels == split).values]

        if len(real_rows) < self.min_samples_per_condition:
            continue

        # Recover column -> value pairs from the "####"-joined key.
        info = dict(zip(self.condition_columns, key.split("####")))
        yield (
            key,
            self._extract_data_subset(self._real, real_rows, self._real_gene_idx),
            self._extract_data_subset(self._generated, gen_rows, self._gen_gene_idx),
            info,
        )
|
|
344
|
+
|
|
345
|
+
def iterate_conditions_batched(
    self,
    split: Optional[str] = None,
    batch_size: Optional[int] = None,
) -> Generator[ConditionBatch, None, None]:
    """
    Iterate over conditions in aligned batches for memory efficiency.

    Useful when even a single condition is too large to fit in memory.

    Each condition produces ``n_batches = max(ceil(n_real / batch_size),
    ceil(n_gen / batch_size))`` batches. The side with fewer batches wraps
    around batch-wise: once its own batches are exhausted it restarts from
    its first batch. Therefore every batch pairs real with generated
    samples, no sample repeats within a batch, and every sample of the
    larger side is visited exactly once. The last batch of each condition
    has ``is_last_batch=True``.

    Parameters
    ----------
    split : str, optional
        Restrict real samples to this split label (generated samples are not
        split-filtered). ``None`` or ``"all"`` disables filtering.
    batch_size : int, optional
        Override default batch size; falsy values fall back to
        ``self.batch_size``.

    Yields
    ------
    ConditionBatch
        Batch of samples from a condition
    """
    if not self._is_initialized:
        self.initialize()

    batch_size = batch_size or self.batch_size
    filter_split = split is not None and split != "all" and self.split_column is not None

    def _batch_window(indices: np.ndarray, batch_idx: int) -> np.ndarray:
        # Batch `batch_idx` of `indices`, cycling through this side's own
        # batches once they are exhausted.
        own_batches = (len(indices) + batch_size - 1) // batch_size
        start = (batch_idx % own_batches) * batch_size
        return indices[start:start + batch_size]

    for key in self.get_common_conditions(split):
        real_indices = self._condition_indices["real"][key]
        gen_indices = self._condition_indices["generated"][key]

        # Filter by split (real data only)
        if filter_split:
            split_mask = self._real_obs.iloc[real_indices][self.split_column].astype(str) == split
            real_indices = real_indices[split_mask.values]

        if len(real_indices) < self.min_samples_per_condition:
            continue
        # Only reachable with min_samples_per_condition <= 0: nothing to pair,
        # and wrapping an empty side would divide by zero.
        if len(real_indices) == 0 or len(gen_indices) == 0:
            continue

        # Parse condition info from the "####"-joined key
        condition_info = dict(zip(self.condition_columns, key.split("####")))

        n_real = len(real_indices)
        n_gen = len(gen_indices)
        n_batches = max(
            (n_real + batch_size - 1) // batch_size,
            (n_gen + batch_size - 1) // batch_size,
        )

        for batch_idx in range(n_batches):
            real_data = self._extract_data_subset(
                self._real, _batch_window(real_indices, batch_idx), self._real_gene_idx
            )
            gen_data = self._extract_data_subset(
                self._generated, _batch_window(gen_indices, batch_idx), self._gen_gene_idx
            )

            yield ConditionBatch(
                condition_key=key,
                condition_info=condition_info,
                real_data=real_data,
                generated_data=gen_data,
                batch_idx=batch_idx,
                n_batches=n_batches,
                is_last_batch=(batch_idx == n_batches - 1),
            )
|
|
434
|
+
|
|
435
|
+
@property
def gene_names(self) -> List[str]:
    """
    Names of the genes shared by the real and generated datasets.

    The loader is initialized on first access if that has not happened yet.
    """
    if self._is_initialized:
        return self._common_genes
    self.initialize()
    return self._common_genes
|
|
441
|
+
|
|
442
|
+
@property
def n_genes(self) -> int:
    """Count of genes common to both datasets (length of ``gene_names``)."""
    common = self.gene_names
    return len(common)
|
|
446
|
+
|
|
447
|
+
def estimate_memory_usage(self) -> Dict[str, float]:
    """
    Estimate the memory footprint, in MB, of the different loading strategies.

    Element counts are converted assuming float32 storage (4 bytes each)
    and 1 MB = 1e6 bytes.

    Returns
    -------
    Dict[str, float]
        ``full_load_mb``     - both full matrices over the common genes
        ``per_condition_mb`` - average real+generated cells per common condition
        ``per_batch_mb``     - one real batch plus one generated batch
        ``metadata_mb``      - rough guess of ~100 bytes of obs metadata per cell
    """
    if not self._is_initialized:
        self.initialize()

    real_cells = len(self._real_obs)
    gen_cells = len(self._generated_obs)
    gene_count = self.n_genes
    float32_bytes = 4

    def _rows_to_mb(rows):
        # rows x common genes x 4 bytes, expressed in MB
        return rows * gene_count * float32_bytes / 1e6

    total_cells = real_cells + gen_cells
    condition_count = len(self.get_common_conditions())
    mean_cells_per_condition = total_cells / max(condition_count, 1)

    return {
        "full_load_mb": _rows_to_mb(real_cells) + _rows_to_mb(gen_cells),
        "per_condition_mb": _rows_to_mb(mean_cells_per_condition),
        "per_batch_mb": _rows_to_mb(self.batch_size * 2),
        "metadata_mb": total_cells * 100 / 1e6,  # rough obs estimate
    }
|
|
482
|
+
|
|
483
|
+
def close(self):
    """
    Close backed file handles if any.

    Best-effort: a failure while closing one handle is reported as a
    warning and does not stop the other handle from being closed. Only
    ordinary exceptions are caught, so KeyboardInterrupt / SystemExit
    still propagate (the previous bare ``except:`` swallowed them).
    """
    import warnings

    # Real data first, then generated, as before.
    for adata in (self._real, self._generated):
        if adata is None or not hasattr(adata, 'file'):
            continue
        try:
            adata.file.close()
        except Exception as exc:
            warnings.warn(f"Failed to close backed file handle: {exc}")
|
|
495
|
+
|
|
496
|
+
def __enter__(self):
    """
    Context manager entry: make sure the loader is initialized and return it.

    Initialization is skipped when it already happened. For example,
    ``load_data_lazy`` returns an initialized loader that is then used in a
    ``with`` statement. This mirrors the ``_is_initialized`` guard used by
    ``gene_names`` and ``estimate_memory_usage``, and avoids re-reading the
    input files or opening a second set of backed file handles.
    """
    if not self._is_initialized:
        self.initialize()
    return self
|
|
500
|
+
|
|
501
|
+
def __exit__(self, exc_type, exc_val, exc_tb):
    """
    Context manager exit: release any backed file handles via ``close()``.

    Returns ``None``, so an exception raised inside the ``with`` block is
    not suppressed.
    """
    self.close()
    return None
|
|
504
|
+
|
|
505
|
+
def __repr__(self) -> str:
    """Summarize dataset shapes and settings (or the backed flag before init)."""
    if self._is_initialized:
        n_real_cells, n_real_genes = len(self._real_obs), len(self._real_var_names)
        n_gen_cells, n_gen_genes = len(self._generated_obs), len(self._generated_var_names)
        return (
            "LazyGeneExpressionDataLoader("
            f"real={n_real_cells}x{n_real_genes}, "
            f"gen={n_gen_cells}x{n_gen_genes}, "
            f"common_genes={self.n_genes}, "
            f"batch_size={self.batch_size})"
        )
    return f"LazyGeneExpressionDataLoader(not initialized, backed={self.use_backed})"
|
|
516
|
+
|
|
517
|
+
|
|
518
|
+
def load_data_lazy(
    real_path: Union[str, Path],
    generated_path: Union[str, Path],
    condition_columns: List[str],
    split_column: Optional[str] = None,
    batch_size: int = 256,
    use_backed: bool = False,
    **kwargs
) -> "LazyGeneExpressionDataLoader":
    """
    Convenience function to create a lazy data loader.

    Parameters
    ----------
    real_path : str or Path
        Path to real data h5ad file
    generated_path : str or Path
        Path to generated data h5ad file
    condition_columns : List[str]
        Columns to match between datasets
    split_column : str, optional
        Column indicating train/test split
    batch_size : int
        Maximum samples per batch
    use_backed : bool
        Use memory-mapped access for very large files
    **kwargs
        Additional arguments for LazyGeneExpressionDataLoader

    Returns
    -------
    LazyGeneExpressionDataLoader
        Initialized lazy data loader

    Raises
    ------
    Exception
        Whatever ``initialize()`` raises. The loader is closed before the
        error is re-raised.
    """
    loader = LazyGeneExpressionDataLoader(
        real_path=real_path,
        generated_path=generated_path,
        condition_columns=condition_columns,
        split_column=split_column,
        batch_size=batch_size,
        use_backed=use_backed,
        **kwargs
    )
    try:
        loader.initialize()
    except BaseException:
        # initialize() may already have opened backed (memory-mapped) files
        # before failing (e.g. on a missing condition column). The caller never
        # receives the loader in that case, so release the handles here.
        loader.close()
        raise
    return loader
|
|
@@ -0,0 +1,424 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Memory-efficient evaluator for large-scale gene expression datasets.
|
|
3
|
+
|
|
4
|
+
Uses lazy loading and batched processing to minimize memory footprint.
|
|
5
|
+
"""
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from typing import Dict, List, Optional, Union, Type, Any, Generator
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
import numpy as np
|
|
11
|
+
import warnings
|
|
12
|
+
from dataclasses import dataclass, field
|
|
13
|
+
import gc
|
|
14
|
+
|
|
15
|
+
from .data.lazy_loader import (
|
|
16
|
+
LazyGeneExpressionDataLoader,
|
|
17
|
+
load_data_lazy,
|
|
18
|
+
ConditionBatch,
|
|
19
|
+
)
|
|
20
|
+
from .metrics.base_metric import BaseMetric, MetricResult
|
|
21
|
+
from .metrics.correlation import (
|
|
22
|
+
PearsonCorrelation,
|
|
23
|
+
SpearmanCorrelation,
|
|
24
|
+
MeanPearsonCorrelation,
|
|
25
|
+
MeanSpearmanCorrelation,
|
|
26
|
+
)
|
|
27
|
+
from .metrics.distances import (
|
|
28
|
+
Wasserstein1Distance,
|
|
29
|
+
Wasserstein2Distance,
|
|
30
|
+
MMDDistance,
|
|
31
|
+
EnergyDistance,
|
|
32
|
+
)
|
|
33
|
+
from .metrics.reconstruction import (
|
|
34
|
+
MSEDistance,
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
# These multivariate metrics don't support batched computation
|
|
38
|
+
from .metrics.distances import MultivariateWasserstein, MultivariateMMD
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# Metric classes whose computation is considered compatible with
# incremental / batched evaluation.
BATCHABLE_METRICS = [MSEDistance, PearsonCorrelation, SpearmanCorrelation]

# Metric classes that require the complete data of a condition at once.
# NOTE(review): MeanPearsonCorrelation / MeanSpearmanCorrelation appear in
# neither list; confirm whether that is intentional.
NON_BATCHABLE_METRICS = [
    Wasserstein1Distance,
    Wasserstein2Distance,
    MMDDistance,
    EnergyDistance,
    MultivariateWasserstein,
    MultivariateMMD,
]
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@dataclass
class StreamingMetricAccumulator:
    """
    Accumulate scalar values for streaming mean / standard-deviation computation.

    Only the running count, sum and sum of squares are stored, so memory use
    is constant no matter how many values are added. Incoming values are
    converted to Python floats. NumPy scalars such as ``np.float32`` therefore
    never leak into the totals, which keeps downstream JSON serialization
    working and avoids accumulating in reduced precision.
    """

    n: int = 0  # number of values added
    sum: float = 0.0  # running sum of the values
    sum_sq: float = 0.0  # running sum of the squared values

    def add(self, value: float, count: int = 1):
        """Add ``value`` ``count`` times (e.g. one value shared by a whole batch)."""
        value = float(value)
        self.n += count
        self.sum += value * count
        self.sum_sq += (value ** 2) * count

    def add_batch(self, values: np.ndarray):
        """
        Add every element of ``values``.

        Any array-like (list, 1-D or N-D array) is accepted and flattened, so
        the count always matches the number of summed elements.
        """
        arr = np.asarray(values, dtype=np.float64).ravel()
        self.n += int(arr.size)
        self.sum += float(arr.sum())
        self.sum_sq += float(np.square(arr).sum())

    @property
    def mean(self) -> float:
        """Mean of all values added so far (0.0 when nothing was added)."""
        return self.sum / self.n if self.n > 0 else 0.0

    @property
    def std(self) -> float:
        """Population standard deviation (ddof=0); 0.0 for fewer than two values."""
        if self.n <= 1:
            return 0.0
        variance = (self.sum_sq / self.n) - (self.mean ** 2)
        # Clamp small negative values caused by floating-point cancellation.
        return float(np.sqrt(max(0.0, variance)))
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
@dataclass
class StreamingConditionResult:
    """
    Compact evaluation outcome for one condition.

    Holds the sample counts, one aggregated value per metric name and,
    optionally, the column-wise (axis 0) mean vectors of the real and
    generated matrices.
    """

    # Identifier of the evaluated condition.
    condition_key: str
    # Number of real / generated samples (rows) that were evaluated.
    n_real_samples: int = 0
    n_generated_samples: int = 0
    # Aggregated score per metric, keyed by metric name (fresh dict per instance).
    metrics: Dict[str, float] = field(default_factory=dict)
    # Optional mean vectors of the real / generated data.
    real_mean: Optional[np.ndarray] = None
    generated_mean: Optional[np.ndarray] = None
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
@dataclass
class StreamingEvaluationResult:
    """
    Memory-efficient evaluation result.

    Per-condition results are not retained. Only the condition keys and one
    running accumulator per metric name are kept. ``save_summary`` writes the
    aggregated statistics to ``output_dir / "summary.json"``.
    """

    output_dir: Path  # directory that save_summary() writes into
    n_conditions: int = 0  # number of conditions added so far
    metric_accumulators: Dict[str, StreamingMetricAccumulator] = field(default_factory=dict)
    condition_keys: List[str] = field(default_factory=list)

    def add_condition(self, result: StreamingConditionResult):
        """Add a condition result and update the per-metric accumulators."""
        self.n_conditions += 1
        self.condition_keys.append(result.condition_key)

        for metric_name, value in result.metrics.items():
            accumulator = self.metric_accumulators.get(metric_name)
            if accumulator is None:
                accumulator = StreamingMetricAccumulator()
                self.metric_accumulators[metric_name] = accumulator
            # Metrics may return NumPy scalars (e.g. np.float32). Convert them so
            # the running totals stay Python floats and summary.json can be written.
            accumulator.add(float(value))

    def get_summary(self) -> Dict[str, Dict[str, float]]:
        """Return ``{metric_name: {"mean", "std", "n"}}`` over all added conditions."""
        summary = {}
        for name, acc in self.metric_accumulators.items():
            summary[name] = {
                "mean": acc.mean,
                "std": acc.std,
                "n": acc.n,
            }
        return summary

    def save_summary(self):
        """Write n_conditions, the metric summary and condition keys to summary.json."""
        import json

        self.output_dir.mkdir(parents=True, exist_ok=True)

        summary = {
            "n_conditions": self.n_conditions,
            "metrics": self.get_summary(),
            "condition_keys": self.condition_keys,
        }

        with open(self.output_dir / "summary.json", "w") as f:
            json.dump(summary, f, indent=2)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
class MemoryEfficientEvaluator:
    """
    Memory-efficient evaluator using lazy loading and streaming aggregation.

    Features:
    - Lazy data loading (the data loader yields one condition at a time)
    - Streaming metric accumulation (only running statistics are kept)
    - Periodic garbage collection
    - Optional summary / per-condition results written to disk

    Parameters
    ----------
    data_loader : LazyGeneExpressionDataLoader
        Lazy data loader
    metrics : List[BaseMetric], optional
        Metric instances or metric classes (classes are instantiated without
        arguments). Defaults to MSEDistance, PearsonCorrelation,
        SpearmanCorrelation, MeanPearsonCorrelation and
        MeanSpearmanCorrelation. Each metric receives the full data of a
        condition.
    batch_size : int
        Stored on the instance. The per-condition evaluation in this class
        does not split a condition into batches.
    gc_every_n_conditions : int
        Run garbage collection every N conditions (values <= 0 disable it)
    verbose : bool
        Print progress
    """

    def __init__(
        self,
        data_loader: LazyGeneExpressionDataLoader,
        metrics: Optional[List[Union[BaseMetric, Type[BaseMetric]]]] = None,
        batch_size: int = 256,
        gc_every_n_conditions: int = 10,
        verbose: bool = True,
    ):
        self.data_loader = data_loader
        self.batch_size = batch_size
        self.gc_every_n_conditions = gc_every_n_conditions
        self.verbose = verbose

        # Initialize metrics (instantiate classes, keep instances as given)
        self.metrics: List[BaseMetric] = []
        metric_classes = metrics or [
            MSEDistance,
            PearsonCorrelation,
            SpearmanCorrelation,
            MeanPearsonCorrelation,
            MeanSpearmanCorrelation,
        ]

        for m in metric_classes:
            if isinstance(m, type):
                self.metrics.append(m())
            else:
                self.metrics.append(m)

    def _log(self, msg: str):
        """Print ``msg`` when verbose."""
        if self.verbose:
            print(msg)

    def evaluate(
        self,
        split: Optional[str] = None,
        output_dir: Optional[Union[str, Path]] = None,
        save_per_condition: bool = False,
    ) -> StreamingEvaluationResult:
        """
        Run memory-efficient evaluation.

        Parameters
        ----------
        split : str, optional
            Split to evaluate
        output_dir : str or Path, optional
            Directory to save results. If provided, ``summary.json`` (and,
            with ``save_per_condition``, one JSON file per condition) is
            written there. If omitted, nothing is written to disk, and the
            returned result's ``output_dir`` is ``Path(".")``.
        save_per_condition : bool
            If True (and ``output_dir`` is given), save individual condition
            results to disk

        Returns
        -------
        StreamingEvaluationResult
            Evaluation result with aggregated metrics
        """
        # Previously the result fell back to Path(".") and was always written
        # there, which silently created files in the working directory.
        write_to_disk = output_dir is not None
        if write_to_disk:
            output_dir = Path(output_dir)
            output_dir.mkdir(parents=True, exist_ok=True)
        else:
            output_dir = Path(".")

        result = StreamingEvaluationResult(output_dir=output_dir)

        # Get conditions
        conditions = self.data_loader.get_common_conditions(split)
        n_total = len(conditions)
        self._log(f"Evaluating {n_total} conditions")
        if self.verbose:
            # Only pay for the estimate when it will actually be printed.
            self._log(f"Memory estimate: {self.data_loader.estimate_memory_usage()}")

        gc_every = self.gc_every_n_conditions

        # Iterate conditions (one at a time in memory)
        for i, (cond_key, real_data, gen_data, cond_info) in enumerate(
            self.data_loader.iterate_conditions(split)
        ):
            if self.verbose and (i + 1) % 10 == 0:
                self._log(f" Processing {i + 1}/{n_total}: {cond_key}")

            # Compute metrics for this condition
            cond_result = self._evaluate_condition(
                cond_key, real_data, gen_data, cond_info
            )

            # Add to streaming result
            result.add_condition(cond_result)

            # Optionally save per-condition result
            if save_per_condition and write_to_disk:
                self._save_condition_result(cond_result, output_dir)

            # Periodic garbage collection (disabled for non-positive intervals,
            # which previously raised ZeroDivisionError for 0)
            if gc_every and gc_every > 0 and (i + 1) % gc_every == 0:
                gc.collect()

        # Final summary
        if write_to_disk:
            result.save_summary()

        if self.verbose:
            self._print_summary(result)

        return result

    def _evaluate_condition(
        self,
        cond_key: str,
        real_data: np.ndarray,
        gen_data: np.ndarray,
        cond_info: Dict[str, str],
    ) -> StreamingConditionResult:
        """
        Evaluate a single condition.

        Rows of ``real_data`` / ``gen_data`` are samples (``shape[0]`` is used
        as the sample count). A metric that raises, or whose aggregate value
        cannot be converted to a float, is skipped with a warning.
        """
        result = StreamingConditionResult(
            condition_key=cond_key,
            n_real_samples=real_data.shape[0],
            n_generated_samples=gen_data.shape[0],
        )

        # Compute column-wise means
        result.real_mean = real_data.mean(axis=0)
        result.generated_mean = gen_data.mean(axis=0)

        # Compute metrics
        for metric in self.metrics:
            try:
                metric_result = metric.compute(
                    real=real_data,
                    generated=gen_data,
                    gene_names=self.data_loader.gene_names,
                    aggregate_method="mean",
                    condition=cond_key,
                )
                # Store a plain float. NumPy scalars (e.g. np.float32) are not
                # JSON serializable, and a None would otherwise crash the
                # aggregation outside this try block.
                result.metrics[metric.name] = float(metric_result.aggregate_value)
            except Exception as e:
                warnings.warn(f"Failed to compute {metric.name} for {cond_key}: {e}")

        return result

    def _save_condition_result(
        self,
        result: StreamingConditionResult,
        output_dir: Path,
    ):
        """Save a single condition result to ``output_dir/conditions/<key>.json``."""
        import json

        condition_dir = output_dir / "conditions"
        condition_dir.mkdir(parents=True, exist_ok=True)

        # Safe filename: path separators in the key must not create subpaths
        safe_key = result.condition_key.replace("/", "_").replace("\\", "_")

        data = {
            "condition_key": result.condition_key,
            "n_real": result.n_real_samples,
            "n_generated": result.n_generated_samples,
            # Convert possible NumPy scalars so json.dump does not fail.
            "metrics": {name: float(value) for name, value in result.metrics.items()},
        }

        with open(condition_dir / f"{safe_key}.json", "w") as f:
            json.dump(data, f, indent=2)

    def _print_summary(self, result: StreamingEvaluationResult):
        """Print the aggregated mean ± std of every metric."""
        self._log("\n" + "=" * 60)
        self._log("EVALUATION SUMMARY (Memory-Efficient)")
        self._log("=" * 60)
        self._log(f"Conditions evaluated: {result.n_conditions}")
        self._log("-" * 40)

        for name, stats in result.get_summary().items():
            self._log(f" {name}: {stats['mean']:.4f} ± {stats['std']:.4f}")

        self._log("=" * 60)
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
def evaluate_lazy(
    real_path: Union[str, Path],
    generated_path: Union[str, Path],
    condition_columns: List[str],
    split_column: Optional[str] = None,
    output_dir: Optional[Union[str, Path]] = None,
    batch_size: int = 256,
    use_backed: bool = False,
    metrics: Optional[List[Union[BaseMetric, Type[BaseMetric]]]] = None,
    verbose: bool = True,
    save_per_condition: bool = False,
    **kwargs
) -> StreamingEvaluationResult:
    """
    Memory-efficient evaluation using lazy loading.

    Use this function for large datasets that don't fit in memory.

    Parameters
    ----------
    real_path : str or Path
        Path to real data h5ad file
    generated_path : str or Path
        Path to generated data h5ad file
    condition_columns : List[str]
        Columns to match between datasets
    split_column : str, optional
        Column for train/test split
    output_dir : str or Path, optional
        Directory to save results
    batch_size : int
        Batch size for processing (passed to both the loader and the evaluator)
    use_backed : bool
        Use memory-mapped file access (for very large files)
    metrics : List, optional
        Metrics to compute
    verbose : bool
        Print progress
    save_per_condition : bool
        Save individual condition results
    **kwargs
        Additional keyword arguments forwarded to ``load_data_lazy`` (and
        from there to ``LazyGeneExpressionDataLoader``). They were previously
        accepted but silently ignored.

    Returns
    -------
    StreamingEvaluationResult
        Aggregated evaluation results

    Examples
    --------
    >>> # For large datasets that don't fit in memory
    >>> results = evaluate_lazy(
    ...     "real.h5ad",
    ...     "generated.h5ad",
    ...     condition_columns=["perturbation"],
    ...     output_dir="eval_output/",
    ...     batch_size=256,
    ...     use_backed=True,  # Memory-mapped for very large files
    ... )
    >>> print(results.get_summary())
    """
    # Create lazy loader; the with-block closes its file handles afterwards.
    with load_data_lazy(
        real_path=real_path,
        generated_path=generated_path,
        condition_columns=condition_columns,
        split_column=split_column,
        batch_size=batch_size,
        use_backed=use_backed,
        **kwargs
    ) as loader:
        # Create evaluator
        evaluator = MemoryEfficientEvaluator(
            data_loader=loader,
            metrics=metrics,
            batch_size=batch_size,
            verbose=verbose,
        )

        # Run evaluation
        return evaluator.evaluate(
            output_dir=output_dir,
            save_per_condition=save_per_condition,
        )
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|