gengeneeval 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geneval/__init__.py +52 -1
- geneval/data/__init__.py +14 -0
- geneval/data/lazy_loader.py +562 -0
- geneval/evaluator.py +46 -0
- geneval/lazy_evaluator.py +424 -0
- geneval/metrics/__init__.py +25 -0
- geneval/metrics/accelerated.py +857 -0
- {gengeneeval-0.2.0.dist-info → gengeneeval-0.3.0.dist-info}/METADATA +111 -4
- {gengeneeval-0.2.0.dist-info → gengeneeval-0.3.0.dist-info}/RECORD +12 -9
- {gengeneeval-0.2.0.dist-info → gengeneeval-0.3.0.dist-info}/WHEEL +0 -0
- {gengeneeval-0.2.0.dist-info → gengeneeval-0.3.0.dist-info}/entry_points.txt +0 -0
- {gengeneeval-0.2.0.dist-info → gengeneeval-0.3.0.dist-info}/licenses/LICENSE +0 -0
geneval/__init__.py
CHANGED
|
@@ -8,6 +8,7 @@ Features:
|
|
|
8
8
|
- Multiple distance and correlation metrics (per-gene and aggregate)
|
|
9
9
|
- Condition-based matching (perturbation, cell type, etc.)
|
|
10
10
|
- Train/test split support
|
|
11
|
+
- Memory-efficient lazy loading for large datasets
|
|
11
12
|
- Publication-quality visualizations
|
|
12
13
|
- Command-line interface
|
|
13
14
|
|
|
@@ -20,12 +21,22 @@ Quick Start:
|
|
|
20
21
|
... output_dir="output/"
|
|
21
22
|
... )
|
|
22
23
|
|
|
24
|
+
Memory-Efficient Mode (for large datasets):
|
|
25
|
+
>>> from geneval import evaluate_lazy
|
|
26
|
+
>>> results = evaluate_lazy(
|
|
27
|
+
... real_path="real.h5ad",
|
|
28
|
+
... generated_path="generated.h5ad",
|
|
29
|
+
... condition_columns=["perturbation"],
|
|
30
|
+
... batch_size=256,
|
|
31
|
+
... use_backed=True, # Memory-mapped access
|
|
32
|
+
... )
|
|
33
|
+
|
|
23
34
|
CLI Usage:
|
|
24
35
|
$ geneval --real real.h5ad --generated generated.h5ad \\
|
|
25
36
|
--conditions perturbation cell_type --output results/
|
|
26
37
|
"""
|
|
27
38
|
|
|
28
|
-
__version__ = "0.
|
|
39
|
+
__version__ = "0.3.0"
|
|
29
40
|
__author__ = "GenEval Team"
|
|
30
41
|
|
|
31
42
|
# Main evaluation interface
|
|
@@ -35,12 +46,26 @@ from .evaluator import (
|
|
|
35
46
|
MetricRegistry,
|
|
36
47
|
)
|
|
37
48
|
|
|
49
|
+
# Memory-efficient evaluation
|
|
50
|
+
from .lazy_evaluator import (
|
|
51
|
+
evaluate_lazy,
|
|
52
|
+
MemoryEfficientEvaluator,
|
|
53
|
+
StreamingEvaluationResult,
|
|
54
|
+
)
|
|
55
|
+
|
|
38
56
|
# Data loading
|
|
39
57
|
from .data.loader import (
|
|
40
58
|
GeneExpressionDataLoader,
|
|
41
59
|
load_data,
|
|
42
60
|
)
|
|
43
61
|
|
|
62
|
+
# Memory-efficient data loading
|
|
63
|
+
from .data.lazy_loader import (
|
|
64
|
+
LazyGeneExpressionDataLoader,
|
|
65
|
+
load_data_lazy,
|
|
66
|
+
ConditionBatch,
|
|
67
|
+
)
|
|
68
|
+
|
|
44
69
|
# Results
|
|
45
70
|
from .results import (
|
|
46
71
|
EvaluationResult,
|
|
@@ -76,6 +101,14 @@ from .metrics.reconstruction import (
|
|
|
76
101
|
R2Score,
|
|
77
102
|
)
|
|
78
103
|
|
|
104
|
+
# Accelerated computation
|
|
105
|
+
from .metrics.accelerated import (
|
|
106
|
+
AccelerationConfig,
|
|
107
|
+
ParallelMetricComputer,
|
|
108
|
+
get_available_backends,
|
|
109
|
+
compute_metrics_accelerated,
|
|
110
|
+
)
|
|
111
|
+
|
|
79
112
|
# Visualization
|
|
80
113
|
from .visualization.visualizer import (
|
|
81
114
|
EvaluationVisualizer,
|
|
@@ -99,9 +132,17 @@ __all__ = [
|
|
|
99
132
|
"evaluate",
|
|
100
133
|
"GeneEvalEvaluator",
|
|
101
134
|
"MetricRegistry",
|
|
135
|
+
# Memory-efficient evaluation
|
|
136
|
+
"evaluate_lazy",
|
|
137
|
+
"MemoryEfficientEvaluator",
|
|
138
|
+
"StreamingEvaluationResult",
|
|
102
139
|
# Data loading
|
|
103
140
|
"GeneExpressionDataLoader",
|
|
104
141
|
"load_data",
|
|
142
|
+
# Memory-efficient data loading
|
|
143
|
+
"LazyGeneExpressionDataLoader",
|
|
144
|
+
"load_data_lazy",
|
|
145
|
+
"ConditionBatch",
|
|
105
146
|
# Results
|
|
106
147
|
"EvaluationResult",
|
|
107
148
|
"SplitResult",
|
|
@@ -123,6 +164,16 @@ __all__ = [
|
|
|
123
164
|
"EnergyDistance",
|
|
124
165
|
"MultivariateWasserstein",
|
|
125
166
|
"MultivariateMMD",
|
|
167
|
+
# Reconstruction metrics
|
|
168
|
+
"MSEDistance",
|
|
169
|
+
"RMSEDistance",
|
|
170
|
+
"MAEDistance",
|
|
171
|
+
"R2Score",
|
|
172
|
+
# Acceleration
|
|
173
|
+
"AccelerationConfig",
|
|
174
|
+
"ParallelMetricComputer",
|
|
175
|
+
"get_available_backends",
|
|
176
|
+
"compute_metrics_accelerated",
|
|
126
177
|
# Visualization
|
|
127
178
|
"EvaluationVisualizer",
|
|
128
179
|
"visualize",
|
geneval/data/__init__.py
CHANGED
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
Data loading module for gene expression evaluation.
|
|
3
3
|
|
|
4
4
|
Provides data loaders for paired real and generated datasets.
|
|
5
|
+
Includes both standard and memory-efficient lazy loading options.
|
|
5
6
|
"""
|
|
6
7
|
|
|
7
8
|
from .loader import (
|
|
@@ -9,15 +10,28 @@ from .loader import (
|
|
|
9
10
|
load_data,
|
|
10
11
|
DataLoaderError,
|
|
11
12
|
)
|
|
13
|
+
from .lazy_loader import (
|
|
14
|
+
LazyGeneExpressionDataLoader,
|
|
15
|
+
load_data_lazy,
|
|
16
|
+
LazyDataLoaderError,
|
|
17
|
+
ConditionBatch,
|
|
18
|
+
)
|
|
12
19
|
from .gene_expression_datamodule import (
|
|
13
20
|
GeneExpressionDataModule,
|
|
14
21
|
DataModuleError,
|
|
15
22
|
)
|
|
16
23
|
|
|
17
24
|
__all__ = [
|
|
25
|
+
# Standard loader
|
|
18
26
|
"GeneExpressionDataLoader",
|
|
19
27
|
"load_data",
|
|
20
28
|
"DataLoaderError",
|
|
29
|
+
# Lazy loader (memory-efficient)
|
|
30
|
+
"LazyGeneExpressionDataLoader",
|
|
31
|
+
"load_data_lazy",
|
|
32
|
+
"LazyDataLoaderError",
|
|
33
|
+
"ConditionBatch",
|
|
34
|
+
# DataModule
|
|
21
35
|
"GeneExpressionDataModule",
|
|
22
36
|
"DataModuleError",
|
|
23
37
|
]
|
|
@@ -0,0 +1,562 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Memory-efficient lazy data loader for large-scale gene expression datasets.
|
|
3
|
+
|
|
4
|
+
Provides lazy loading and batched iteration over AnnData h5ad files without
|
|
5
|
+
loading entire datasets into memory. Supports backed mode for very large files.
|
|
6
|
+
"""
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from typing import Optional, List, Union, Dict, Tuple, Iterator, Generator
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
import warnings
|
|
12
|
+
import numpy as np
|
|
13
|
+
import pandas as pd
|
|
14
|
+
from scipy import sparse
|
|
15
|
+
from dataclasses import dataclass
|
|
16
|
+
|
|
17
|
+
try:
|
|
18
|
+
import anndata as ad
|
|
19
|
+
import scanpy as sc
|
|
20
|
+
except ImportError:
|
|
21
|
+
raise ImportError("anndata and scanpy are required. Install with: pip install anndata scanpy")
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class LazyDataLoaderError(Exception):
    """Raised when lazy loading of paired gene expression data fails."""
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass
class ConditionBatch:
    """Container for a batch of samples from a condition.

    Yielded by ``LazyGeneExpressionDataLoader.iterate_conditions_batched`` so
    consumers can process one slice of a condition at a time and track batch
    progress without holding the whole condition in memory.
    """
    # Unique string key identifying the condition combination
    condition_key: str
    # Mapping of condition column name -> value for this condition
    condition_info: Dict[str, str]
    # Expression matrix slice from the real dataset (samples x genes)
    real_data: np.ndarray
    # Expression matrix slice from the generated dataset (samples x genes)
    generated_data: np.ndarray
    # Zero-based index of this batch within the condition
    batch_idx: int
    # Total number of batches for this condition
    n_batches: int
    # True when this is the final batch of the condition
    is_last_batch: bool
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class LazyGeneExpressionDataLoader:
    """
    Memory-efficient lazy data loader for paired gene expression datasets.

    Unlike GeneExpressionDataLoader, this class:
    - Uses backed mode for h5ad files to avoid loading entire datasets
    - Supports batched iteration over conditions
    - Only loads data into memory when explicitly requested
    - Provides memory usage estimates

    Parameters
    ----------
    real_path : str or Path
        Path to real data h5ad file
    generated_path : str or Path
        Path to generated data h5ad file
    condition_columns : List[str]
        Columns to match between datasets
    split_column : str, optional
        Column indicating train/test split
    batch_size : int
        Maximum number of samples per batch when iterating
    use_backed : bool
        If True, use backed mode (memory-mapped). May be slower but uses
        minimal memory. If False, loads full file but processes in batches.
    min_samples_per_condition : int
        Minimum samples required per condition to include
    """

    # Separator used to build condition keys from column values.
    # NOTE(review): a condition value containing this sequence would corrupt
    # key round-tripping in iterate_conditions* — assumed not to occur.
    _KEY_SEP = "####"

    def __init__(
        self,
        real_path: Union[str, Path],
        generated_path: Union[str, Path],
        condition_columns: List[str],
        split_column: Optional[str] = None,
        batch_size: int = 256,
        use_backed: bool = False,
        min_samples_per_condition: int = 2,
    ):
        self.real_path = Path(real_path)
        self.generated_path = Path(generated_path)
        self.condition_columns = condition_columns
        self.split_column = split_column
        self.batch_size = batch_size
        self.use_backed = use_backed
        self.min_samples_per_condition = min_samples_per_condition

        # Lazy-loaded AnnData handles (backed or fully loaded)
        self._real: Optional[ad.AnnData] = None
        self._generated: Optional[ad.AnnData] = None

        # Metadata caches (always loaded - lightweight)
        self._real_obs: Optional[pd.DataFrame] = None
        self._generated_obs: Optional[pd.DataFrame] = None
        self._real_var_names: Optional[pd.Index] = None
        self._generated_var_names: Optional[pd.Index] = None

        # Gene alignment info (common genes plus per-dataset column indices)
        self._common_genes: Optional[List[str]] = None
        self._real_gene_idx: Optional[np.ndarray] = None
        self._gen_gene_idx: Optional[np.ndarray] = None

        # Pre-computed condition indices for fast access:
        # {"real": {key: row_indices}, "generated": {key: row_indices}}
        self._condition_indices: Optional[Dict[str, Dict[str, np.ndarray]]] = None

        # State flag so initialize() is idempotent
        self._is_initialized = False

    def initialize(self) -> "LazyGeneExpressionDataLoader":
        """
        Initialize loader by reading metadata only (not expression data).

        This loads obs DataFrames and var_names to prepare for iteration,
        but does not load the expression matrices.

        Returns
        -------
        self
            For method chaining

        Raises
        ------
        LazyDataLoaderError
            If either file is missing, a condition column is absent, or the
            datasets share no genes.
        """
        if self._is_initialized:
            return self

        # Validate paths before touching the files
        if not self.real_path.exists():
            raise LazyDataLoaderError(f"Real data file not found: {self.real_path}")
        if not self.generated_path.exists():
            raise LazyDataLoaderError(f"Generated data file not found: {self.generated_path}")

        # Metadata only (obs and var_names) - much lighter than full data
        self._load_metadata()
        self._validate_columns()
        self._compute_gene_alignment()
        self._precompute_condition_indices()

        self._is_initialized = True
        return self

    def _load_metadata(self):
        """Load only metadata (obs, var_names) without expression data."""
        # Try backed (memory-mapped) mode first if requested; fall back to a
        # standard full read if the backend cannot open the file that way.
        if self.use_backed:
            try:
                self._real = sc.read_h5ad(self.real_path, backed='r')
                self._generated = sc.read_h5ad(self.generated_path, backed='r')
            except Exception as e:
                warnings.warn(f"Backed mode failed, falling back to standard loading: {e}")
                self._real = None
                self._generated = None

        if self._real is None:
            self._real = sc.read_h5ad(self.real_path)
            self._generated = sc.read_h5ad(self.generated_path)

        # Cache lightweight metadata; copies detach from backed handles
        self._real_obs = self._real.obs.copy()
        self._generated_obs = self._generated.obs.copy()
        self._real_var_names = pd.Index(self._real.var_names.astype(str))
        self._generated_var_names = pd.Index(self._generated.var_names.astype(str))

    def _validate_columns(self):
        """Validate that required condition columns exist in both datasets."""
        for col in self.condition_columns:
            if col not in self._real_obs.columns:
                raise LazyDataLoaderError(
                    f"Condition column '{col}' not found in real data. "
                    f"Available: {list(self._real_obs.columns)}"
                )
            if col not in self._generated_obs.columns:
                raise LazyDataLoaderError(
                    f"Condition column '{col}' not found in generated data. "
                    f"Available: {list(self._generated_obs.columns)}"
                )

    def _compute_gene_alignment(self):
        """Pre-compute gene alignment indices for the common gene set."""
        common = self._real_var_names.intersection(self._generated_var_names)

        if len(common) == 0:
            raise LazyDataLoaderError(
                "No overlapping genes between real and generated data."
            )

        self._common_genes = common.tolist()
        self._real_gene_idx = self._real_var_names.get_indexer(common)
        self._gen_gene_idx = self._generated_var_names.get_indexer(common)

        n_real_only = len(self._real_var_names) - len(common)
        n_gen_only = len(self._generated_var_names) - len(common)

        if n_real_only > 0 or n_gen_only > 0:
            warnings.warn(
                f"Gene alignment: {len(common)} common genes. "
                f"Dropped {n_real_only} from real, {n_gen_only} from generated."
            )

    def _get_condition_key(self, row: pd.Series) -> str:
        """Generate unique key for a condition combination."""
        return self._KEY_SEP.join(str(row[c]) for c in self.condition_columns)

    def _indices_by_condition(self, obs: pd.DataFrame) -> Dict[str, np.ndarray]:
        """Map condition key -> row indices, dropping under-populated conditions."""
        out: Dict[str, np.ndarray] = {}
        as_str = obs[self.condition_columns].astype(str)
        for _, row in as_str.drop_duplicates().iterrows():
            key = self._get_condition_key(row)
            mask = np.ones(len(obs), dtype=bool)
            for col in self.condition_columns:
                mask &= (as_str[col] == str(row[col])).values
            indices = np.where(mask)[0]
            if len(indices) >= self.min_samples_per_condition:
                out[key] = indices
        return out

    def _precompute_condition_indices(self):
        """Pre-compute sample indices for each condition (lightweight)."""
        self._condition_indices = {
            "real": self._indices_by_condition(self._real_obs),
            "generated": self._indices_by_condition(self._generated_obs),
        }

    def get_splits(self) -> List[str]:
        """Get available splits, or ["all"] when no split column applies."""
        if not self._is_initialized:
            self.initialize()

        if self.split_column is None or self.split_column not in self._real_obs.columns:
            return ["all"]

        return list(self._real_obs[self.split_column].astype(str).unique())

    def get_common_conditions(self, split: Optional[str] = None) -> List[str]:
        """Get conditions present in both real and generated data, sorted."""
        if not self._is_initialized:
            self.initialize()

        real_keys = set(self._condition_indices["real"].keys())
        gen_keys = set(self._condition_indices["generated"].keys())

        common = real_keys & gen_keys

        # Keep only conditions that have at least one real sample in the split
        if split is not None and split != "all" and self.split_column is not None:
            filtered = set()
            for key in common:
                real_idx = self._condition_indices["real"][key]
                split_vals = self._real_obs.iloc[real_idx][self.split_column].astype(str)
                if (split_vals == split).any():
                    filtered.add(key)
            common = filtered

        return sorted(common)

    def _filter_real_indices_by_split(
        self,
        real_indices: np.ndarray,
        split: Optional[str],
    ) -> np.ndarray:
        """Restrict real-data row indices to the requested split, if any."""
        if split is None or split == "all" or self.split_column is None:
            return real_indices
        split_vals = self._real_obs.iloc[real_indices][self.split_column].astype(str)
        return real_indices[(split_vals == split).values]

    def _extract_data_subset(
        self,
        adata: ad.AnnData,
        indices: np.ndarray,
        gene_idx: np.ndarray,
    ) -> np.ndarray:
        """
        Extract and align a (samples, genes) subset as dense float32.

        In backed mode the underlying HDF5 dataset requires increasing row
        indices, so rows are fetched in sorted order and then restored to the
        caller's ordering; gene columns are selected after materialization so
        gene_idx may be in any order.
        """
        if getattr(adata, "isbacked", False):
            # Fetch rows in sorted order (HDF5 fancy-indexing constraint),
            # then undo the permutation to honor the requested order.
            order = np.argsort(indices)
            X = adata.X[indices[order]]
            if sparse.issparse(X):
                X = X.toarray()
            X = np.asarray(X)
            inverse = np.empty(len(order), dtype=np.intp)
            inverse[order] = np.arange(len(order))
            X = X[inverse][:, gene_idx]
        else:
            X = adata.X[indices][:, gene_idx]
            if sparse.issparse(X):
                X = X.toarray()

        return np.asarray(X, dtype=np.float32)

    def iterate_conditions(
        self,
        split: Optional[str] = None,
    ) -> Generator[Tuple[str, np.ndarray, np.ndarray, Dict[str, str]], None, None]:
        """
        Iterate over conditions, loading one condition at a time.

        This loads data for one condition, yields it, then releases memory
        before loading the next condition.

        Parameters
        ----------
        split : str, optional
            Filter to this split only

        Yields
        ------
        Tuple[str, np.ndarray, np.ndarray, Dict[str, str]]
            (condition_key, real_data, generated_data, condition_info)
        """
        if not self._is_initialized:
            self.initialize()

        for key in self.get_common_conditions(split):
            real_indices = self._filter_real_indices_by_split(
                self._condition_indices["real"][key], split
            )
            gen_indices = self._condition_indices["generated"][key]

            if len(real_indices) < self.min_samples_per_condition:
                continue

            # Load data for this condition only
            real_data = self._extract_data_subset(
                self._real, real_indices, self._real_gene_idx
            )
            gen_data = self._extract_data_subset(
                self._generated, gen_indices, self._gen_gene_idx
            )

            condition_info = dict(
                zip(self.condition_columns, key.split(self._KEY_SEP))
            )
            yield key, real_data, gen_data, condition_info

    @staticmethod
    def _batch_slice(indices: np.ndarray, start: int, length: int) -> np.ndarray:
        """
        Take `length` entries of `indices` starting at `start`, wrapping
        modulo the array length once the array is exhausted so batches of the
        larger dataset are always paired with (resampled) data from the
        smaller one.

        This replaces the original wrap-around slice arithmetic
        (``indices[start % n:end % n + 1]``), which produced empty or
        single-element batches for the exhausted side.
        """
        n = len(indices)
        if n == 0 or length <= 0:
            return indices[:0]
        if start + length <= n:
            return indices[start:start + length]
        return indices[np.arange(start, start + length) % n]

    def iterate_conditions_batched(
        self,
        split: Optional[str] = None,
        batch_size: Optional[int] = None,
    ) -> Generator[ConditionBatch, None, None]:
        """
        Iterate over conditions in batches for memory efficiency.

        Useful when even a single condition is too large to fit in memory.

        Parameters
        ----------
        split : str, optional
            Filter to this split only
        batch_size : int, optional
            Override default batch size

        Yields
        ------
        ConditionBatch
            Batch of samples from a condition
        """
        if not self._is_initialized:
            self.initialize()

        batch_size = batch_size or self.batch_size

        for key in self.get_common_conditions(split):
            real_indices = self._filter_real_indices_by_split(
                self._condition_indices["real"][key], split
            )
            gen_indices = self._condition_indices["generated"][key]

            if len(real_indices) < self.min_samples_per_condition:
                continue

            condition_info = dict(
                zip(self.condition_columns, key.split(self._KEY_SEP))
            )

            # Batch count follows the larger dataset so neither is truncated
            n_real = len(real_indices)
            n_gen = len(gen_indices)
            n_batches = max(
                (n_real + batch_size - 1) // batch_size,
                (n_gen + batch_size - 1) // batch_size,
            )

            for batch_idx in range(n_batches):
                start = batch_idx * batch_size
                # The larger dataset's remaining samples set this batch's size
                this_len = max(
                    min(start + batch_size, n_real) - start,
                    min(start + batch_size, n_gen) - start,
                )

                batch_real_idx = self._batch_slice(real_indices, start, this_len)
                batch_gen_idx = self._batch_slice(gen_indices, start, this_len)

                if len(batch_real_idx) == 0 or len(batch_gen_idx) == 0:
                    continue

                real_data = self._extract_data_subset(
                    self._real, batch_real_idx, self._real_gene_idx
                )
                gen_data = self._extract_data_subset(
                    self._generated, batch_gen_idx, self._gen_gene_idx
                )

                yield ConditionBatch(
                    condition_key=key,
                    condition_info=condition_info,
                    real_data=real_data,
                    generated_data=gen_data,
                    batch_idx=batch_idx,
                    n_batches=n_batches,
                    is_last_batch=(batch_idx == n_batches - 1),
                )

    @property
    def gene_names(self) -> List[str]:
        """Get common gene names."""
        if not self._is_initialized:
            self.initialize()
        return self._common_genes

    @property
    def n_genes(self) -> int:
        """Number of common genes."""
        return len(self.gene_names)

    def estimate_memory_usage(self) -> Dict[str, float]:
        """
        Estimate memory usage in MB for different loading strategies.

        Returns
        -------
        Dict[str, float]
            Memory estimates in MB
        """
        if not self._is_initialized:
            self.initialize()

        n_real = len(self._real_obs)
        n_gen = len(self._generated_obs)
        n_genes = self.n_genes

        # 4 bytes per float32
        bytes_per_element = 4

        full_real = n_real * n_genes * bytes_per_element / 1e6
        full_gen = n_gen * n_genes * bytes_per_element / 1e6

        # Average condition size across all common conditions
        n_conditions = len(self.get_common_conditions())
        avg_per_condition = (n_real + n_gen) / max(n_conditions, 1)
        per_condition = avg_per_condition * n_genes * bytes_per_element / 1e6

        # A batch holds both a real and a generated slice, hence the factor 2
        per_batch = self.batch_size * 2 * n_genes * bytes_per_element / 1e6

        return {
            "full_load_mb": full_real + full_gen,
            "per_condition_mb": per_condition,
            "per_batch_mb": per_batch,
            "metadata_mb": (n_real + n_gen) * 100 / 1e6,  # rough obs estimate
        }

    def close(self):
        """Close backed file handles if any (best-effort cleanup)."""
        for handle in (self._real, self._generated):
            if handle is not None and hasattr(handle, 'file'):
                try:
                    handle.file.close()
                except Exception:
                    # Handle may already be closed or not actually backed
                    pass

    def __enter__(self):
        """Context manager entry: initialize and return the loader."""
        self.initialize()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit: release backed file handles."""
        self.close()

    def __repr__(self) -> str:
        if not self._is_initialized:
            return f"LazyGeneExpressionDataLoader(not initialized, backed={self.use_backed})"

        return (
            f"LazyGeneExpressionDataLoader("
            f"real={len(self._real_obs)}x{len(self._real_var_names)}, "
            f"gen={len(self._generated_obs)}x{len(self._generated_var_names)}, "
            f"common_genes={self.n_genes}, "
            f"batch_size={self.batch_size})"
        )
|
|
516
|
+
|
|
517
|
+
|
|
518
|
+
def load_data_lazy(
    real_path: Union[str, Path],
    generated_path: Union[str, Path],
    condition_columns: List[str],
    split_column: Optional[str] = None,
    batch_size: int = 256,
    use_backed: bool = False,
    **kwargs
) -> LazyGeneExpressionDataLoader:
    """
    Convenience function to create a lazy data loader.

    Parameters
    ----------
    real_path : str or Path
        Path to real data h5ad file
    generated_path : str or Path
        Path to generated data h5ad file
    condition_columns : List[str]
        Columns to match between datasets
    split_column : str, optional
        Column indicating train/test split
    batch_size : int
        Maximum samples per batch
    use_backed : bool
        Use memory-mapped access for very large files
    **kwargs
        Additional arguments for LazyGeneExpressionDataLoader

    Returns
    -------
    LazyGeneExpressionDataLoader
        Initialized lazy data loader
    """
    # initialize() returns the loader itself, so construction and
    # initialization chain into a single expression.
    return LazyGeneExpressionDataLoader(
        real_path=real_path,
        generated_path=generated_path,
        condition_columns=condition_columns,
        split_column=split_column,
        batch_size=batch_size,
        use_backed=use_backed,
        **kwargs
    ).initialize()
|