gengeneeval 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of gengeneeval might be problematic. Click here for more details.
- geneval/__init__.py +129 -0
- geneval/cli.py +333 -0
- geneval/config.py +141 -0
- geneval/core.py +41 -0
- geneval/data/__init__.py +23 -0
- geneval/data/gene_expression_datamodule.py +211 -0
- geneval/data/loader.py +437 -0
- geneval/evaluator.py +359 -0
- geneval/evaluators/__init__.py +4 -0
- geneval/evaluators/base_evaluator.py +178 -0
- geneval/evaluators/gene_expression_evaluator.py +218 -0
- geneval/metrics/__init__.py +65 -0
- geneval/metrics/base_metric.py +229 -0
- geneval/metrics/correlation.py +232 -0
- geneval/metrics/distances.py +516 -0
- geneval/metrics/metrics.py +134 -0
- geneval/models/__init__.py +1 -0
- geneval/models/base_model.py +53 -0
- geneval/results.py +334 -0
- geneval/testing.py +393 -0
- geneval/utils/__init__.py +1 -0
- geneval/utils/io.py +27 -0
- geneval/utils/preprocessing.py +82 -0
- geneval/visualization/__init__.py +38 -0
- geneval/visualization/plots.py +499 -0
- geneval/visualization/visualizer.py +1096 -0
- gengeneeval-0.1.0.dist-info/METADATA +172 -0
- gengeneeval-0.1.0.dist-info/RECORD +31 -0
- gengeneeval-0.1.0.dist-info/WHEEL +4 -0
- gengeneeval-0.1.0.dist-info/entry_points.txt +3 -0
- gengeneeval-0.1.0.dist-info/licenses/LICENSE +9 -0
|
@@ -0,0 +1,211 @@
|
|
|
1
|
+
from typing import Optional, List, Union
|
|
2
|
+
import warnings
|
|
3
|
+
import numpy as np
|
|
4
|
+
import anndata as ad
|
|
5
|
+
import pandas as pd
|
|
6
|
+
import scanpy as sc
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class DataModuleError(Exception):
    """Raised when an AnnData object fails the data-module validation checks."""
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class GeneExpressionDataModule:
    """
    Safe data module for gene expression datasets.

    Adds robust validation on construction:
    - Checks required obs columns (perturbation_key, split_key)
    - Validates control value presence (if provided)
    - Ensures minimum cells/genes
    - Detects duplicate gene names
    - Flags sparsity, normalization, log state
    - Prevents negative counts before preprocessing
    """

    def __init__(
        self,
        adata: ad.AnnData,
        perturbation_key: str,
        split_key: str,
        control: Optional[str] = None,
        condition_keys: Optional[List[str]] = None,
        min_cells: int = 10,
        min_genes: int = 50,
        allow_float_counts: bool = True,
        enforce_unique_var_names: bool = True,
    ):
        """
        Wrap *adata* and immediately run the full validation suite.

        Parameters
        ----------
        adata : ad.AnnData
            Expression matrix plus metadata; must not be None.
        perturbation_key : str
            obs column holding the perturbation label for each cell (hard requirement).
        split_key : str
            obs column holding the split label (hard requirement).
        control : str, optional
            Value in the perturbation column marking control cells; if given,
            it must actually occur in that column.
        condition_keys : list of str, optional
            Extra obs columns; a missing one only triggers a warning.
        min_cells : int
            Hard lower bound on n_obs (violation raises).
        min_genes : int
            Soft lower bound on n_vars (violation warns).
        allow_float_counts : bool
            NOTE(review): stored on the instance but not read by any method in
            this class — appears to be dead; confirm before relying on it.
        enforce_unique_var_names : bool
            If True, duplicate gene names raise; otherwise they only warn.

        Raises
        ------
        DataModuleError
            On any hard validation failure (see _validate_adata).
        """
        if adata is None:
            raise DataModuleError("AnnData object cannot be None.")
        self.adata = adata
        self.perturbation_key = perturbation_key
        self.split_key = split_key
        self.control = control
        # Normalize None -> [] so iteration in _validate_condition_keys is safe.
        self.condition_keys = condition_keys or []
        self.min_cells = int(min_cells)
        self.min_genes = int(min_genes)
        self.allow_float_counts = allow_float_counts
        self.enforce_unique_var_names = enforce_unique_var_names

        # State flags: preprocessing state starts False and may be flipped by
        # the heuristics in _detect_logged() or explicitly by preprocess_data().
        self.is_normalized: bool = False
        self.is_logged: bool = False
        self.is_sparse: bool = self._is_sparse(self.adata.X)

        # Fail fast: the constructor never returns an unvalidated instance.
        self._validate_adata()

    # ----------------- validation helpers -----------------

    @staticmethod
    def _is_sparse(X) -> bool:
        # Returns False (rather than raising) when scipy is unavailable.
        try:
            from scipy import sparse
            return sparse.issparse(X)
        except ImportError:
            return False

    def _validate_obs_column(self, key: str):
        # Hard requirement: the column must exist in obs.
        if key not in self.adata.obs.columns:
            raise DataModuleError(f"Required obs column '{key}' not found in AnnData.obs.")

    def _validate_control(self):
        # No control configured -> nothing to check.
        if self.control is None:
            return
        col = self.perturbation_key
        # Compare as strings so categorical/str label mismatches still match.
        vals = self.adata.obs[col].astype(str)
        if str(self.control) not in set(vals):
            raise DataModuleError(f"Control value '{self.control}' not present in '{col}' column.")

    def _validate_sizes(self):
        # Too few cells is fatal; too few genes is only a warning.
        if self.adata.n_obs < self.min_cells:
            raise DataModuleError(
                f"Too few cells ({self.adata.n_obs}). Minimum required: {self.min_cells}."
            )
        if self.adata.n_vars < self.min_genes:
            warnings.warn(
                f"Low gene count ({self.adata.n_vars} < {self.min_genes}). Evaluation may be unstable.",
                RuntimeWarning,
            )

    def _validate_var_names(self):
        # Duplicate gene names break downstream per-gene alignment; severity
        # depends on enforce_unique_var_names.
        v = pd.Index(self.adata.var_names.astype(str))
        if v.has_duplicates:
            if self.enforce_unique_var_names:
                raise DataModuleError("Duplicate gene names detected in var_names.")
            else:
                warnings.warn("Duplicate gene names detected; downstream alignment may fail.", RuntimeWarning)

    def _detect_logged(self):
        """Heuristically set is_normalized / is_logged from the matrix values."""
        X = self.adata.X
        # Heuristic: if many values < 0 or max < 50 maybe already logged.
        # NOTE(review): toarray() densifies the whole matrix; for large sparse
        # datasets this is memory-hungry — consider sampling instead.
        arr = X.toarray() if self._is_sparse(X) else np.asarray(X)
        if np.any(arr < 0):
            warnings.warn("Negative values detected in expression matrix.", RuntimeWarning)
        # Fraction of integer entries
        finite = np.isfinite(arr)
        sample = arr[finite]
        if sample.size == 0:
            # All-NaN/inf matrix: leave flags untouched; the hard failure comes
            # later from _validate_preprocessing_state().
            return
        frac_int = np.mean(np.isclose(sample, np.round(sample)))
        if frac_int < 0.7:
            # likely normalized/logged: mostly non-integer values
            self.is_normalized = True
            # check for log transform: typical upper bound after log1p ~ ~15
            if np.nanmax(sample) < 25:
                self.is_logged = True

    def _validate_preprocessing_state(self):
        # If counts are integers and large, warn if not normalized/logged
        X = self.adata.X
        # NOTE(review): second full densification of X in a row (see
        # _detect_logged) — the dense array could be computed once and shared.
        arr = X.toarray() if self._is_sparse(X) else np.asarray(X)
        finite = arr[np.isfinite(arr)]
        if finite.size == 0:
            raise DataModuleError("Expression matrix contains no finite values.")
        frac_int = np.mean(np.isclose(finite, np.round(finite)))
        if frac_int > 0.95 and np.nanmax(finite) > 50 and not self.is_normalized:
            warnings.warn(
                "Data appears to be raw counts (mostly integers, high max). "
                "Run preprocess_data() before evaluation.",
                RuntimeWarning,
            )
        # NOTE(review): this raise is unconditional, so any negative entry is
        # fatal here even though _detect_logged only warned about negatives —
        # confirm whether scaled data with negatives should be rejected.
        if np.nanmin(finite) < 0:
            raise DataModuleError("Negative values found in raw counts; data corruption suspected.")

    def _validate_condition_keys(self):
        # Missing optional condition columns are tolerated with a warning.
        for c in self.condition_keys:
            if c not in self.adata.obs.columns:
                warnings.warn(f"Condition key '{c}' not found in obs; it will be ignored.", RuntimeWarning)

    def _validate_split_column(self):
        self._validate_obs_column(self.split_key)
        splits = set(self.adata.obs[self.split_key].astype(str))
        # Soft check: warn when none of the conventional split names appear.
        if not splits.intersection({"test", "train", "val", "validation"}):
            warnings.warn(
                f"Split column '{self.split_key}' lacks standard split labels (e.g., 'test').",
                RuntimeWarning,
            )

    def _validate_adata(self):
        # Order matters: _detect_logged sets is_normalized/is_logged, which
        # _validate_preprocessing_state reads.
        self._validate_obs_column(self.perturbation_key)
        self._validate_split_column()
        self._validate_control()
        self._validate_sizes()
        self._validate_var_names()
        self._detect_logged()
        self._validate_preprocessing_state()
        self._validate_condition_keys()

    # ----------------- public API -----------------

    def load_data(self, filepath: str):
        """Load AnnData from file and re-run validation."""
        self.adata = sc.read(filepath)
        # Reset state flags before re-validating; _detect_logged may set them.
        self.is_sparse = self._is_sparse(self.adata.X)
        self.is_normalized = False
        self.is_logged = False
        self._validate_adata()

    def preprocess_data(
        self,
        filter_min_cells: int = 1,
        target_sum: float = 1e4,
        log_base: Union[int, float] = np.e,
    ):
        """
        Apply basic preprocessing: gene filtering, total count normalization, log1p.
        Sets flags accordingly.

        NOTE(review): *log_base* is accepted but never forwarded — sc.pp.log1p
        is called with defaults, so the base is always e. Confirm intent.
        """
        sc.pp.filter_genes(self.adata, min_cells=filter_min_cells)
        sc.pp.normalize_total(self.adata, target_sum=target_sum)
        sc.pp.log1p(self.adata)
        self.is_normalized = True
        self.is_logged = True

    def get_data(self) -> ad.AnnData:
        """Return AnnData (post any preprocessing)."""
        return self.adata

    def get_conditions(self) -> pd.Series:
        """Return unique perturbation conditions."""
        return pd.Series(self.adata.obs[self.perturbation_key].unique(), name="condition")

    def summary(self) -> dict:
        """Structured summary of current dataset and preprocessing state."""
        return {
            "n_cells": int(self.adata.n_obs),
            "n_genes": int(self.adata.n_vars),
            "is_sparse": bool(self.is_sparse),
            "is_normalized": bool(self.is_normalized),
            "is_logged": bool(self.is_logged),
            "perturbation_key": self.perturbation_key,
            "split_key": self.split_key,
            "control": self.control,
            "condition_keys_present": [c for c in self.condition_keys if c in self.adata.obs.columns],
        }

    def assert_ready_for_evaluation(self):
        """Raise error if dataset appears unprocessed."""
        if not self.is_logged or not self.is_normalized:
            raise DataModuleError(
                "Dataset not preprocessed (normalization/log). Call preprocess_data() before evaluation."
            )
|
geneval/data/loader.py
ADDED
|
@@ -0,0 +1,437 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Data loader module for paired real and generated datasets.
|
|
3
|
+
|
|
4
|
+
Provides loading, validation, and alignment of AnnData objects for evaluation.
|
|
5
|
+
"""
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from typing import Optional, List, Union, Dict, Tuple, Iterator
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
import warnings
|
|
11
|
+
import numpy as np
|
|
12
|
+
import pandas as pd
|
|
13
|
+
from scipy import sparse
|
|
14
|
+
|
|
15
|
+
try:
|
|
16
|
+
import anndata as ad
|
|
17
|
+
import scanpy as sc
|
|
18
|
+
except ImportError:
|
|
19
|
+
raise ImportError("anndata and scanpy are required. Install with: pip install anndata scanpy")
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class DataLoaderError(Exception):
    """Raised when loading, validating, or aligning the paired datasets fails."""
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class GeneExpressionDataLoader:
    """
    Data loader for paired real and generated gene expression datasets.

    Handles:
    - Loading AnnData files (h5ad format)
    - Validation of required columns
    - Alignment of gene names between datasets
    - Matching samples by condition columns
    - Split handling (train/test/all)

    Parameters
    ----------
    real_path : str or Path
        Path to real data h5ad file
    generated_path : str or Path
        Path to generated data h5ad file
    condition_columns : List[str]
        Columns to match between datasets (e.g., ['perturbation', 'cell_type'])
    split_column : str, optional
        Column indicating train/test split. If None, all data treated as single split.
    min_samples_per_condition : int
        Minimum samples required per condition to include
    """

    def __init__(
        self,
        real_path: Union[str, Path],
        generated_path: Union[str, Path],
        condition_columns: List[str],
        split_column: Optional[str] = None,
        min_samples_per_condition: int = 2,
    ):
        self.real_path = Path(real_path)
        self.generated_path = Path(generated_path)
        self.condition_columns = condition_columns
        self.split_column = split_column
        self.min_samples_per_condition = min_samples_per_condition

        # Loaded data (raw, as read from disk)
        self._real: Optional[ad.AnnData] = None
        self._generated: Optional[ad.AnnData] = None

        # Aligned data (subset to common genes, identical gene order)
        self._real_aligned: Optional[ad.AnnData] = None
        self._generated_aligned: Optional[ad.AnnData] = None

        # Common genes and conditions
        self._common_genes: Optional[List[str]] = None
        # NOTE(review): _common_conditions is never written by any method in
        # this class — apparently dead state; confirm before removing.
        self._common_conditions: Optional[Dict[str, List[str]]] = None

        # Cached condition masks
        # NOTE(review): also never read or written here — masks are rebuilt on
        # every get_common_conditions()/iterate_conditions() call.
        self._condition_cache: Dict[str, Dict[str, np.ndarray]] = {}

        # Validation state
        self._is_loaded = False
        self._is_aligned = False

    def load(self) -> "GeneExpressionDataLoader":
        """
        Load both datasets from disk.

        Returns
        -------
        self
            For method chaining

        Raises
        ------
        DataLoaderError
            If either file is missing, unreadable, or lacks required columns.
        """
        # Load real data
        if not self.real_path.exists():
            raise DataLoaderError(f"Real data file not found: {self.real_path}")

        try:
            self._real = sc.read_h5ad(self.real_path)
        except Exception as e:
            raise DataLoaderError(f"Failed to load real data: {e}")

        # Load generated data
        if not self.generated_path.exists():
            raise DataLoaderError(f"Generated data file not found: {self.generated_path}")

        try:
            self._generated = sc.read_h5ad(self.generated_path)
        except Exception as e:
            raise DataLoaderError(f"Failed to load generated data: {e}")

        # Validate columns
        self._validate_columns()

        self._is_loaded = True
        return self

    def _validate_columns(self):
        """Validate that required columns exist in both datasets."""
        for col in self.condition_columns:
            if col not in self._real.obs.columns:
                raise DataLoaderError(
                    f"Condition column '{col}' not found in real data. "
                    f"Available columns: {list(self._real.obs.columns)}"
                )
            if col not in self._generated.obs.columns:
                raise DataLoaderError(
                    f"Condition column '{col}' not found in generated data. "
                    f"Available columns: {list(self._generated.obs.columns)}"
                )

        if self.split_column is not None:
            # Split column is mandatory only on the real side.
            if self.split_column not in self._real.obs.columns:
                raise DataLoaderError(
                    f"Split column '{self.split_column}' not found in real data."
                )
            # Generated data may not have split column - that's OK
            if self.split_column not in self._generated.obs.columns:
                warnings.warn(
                    f"Split column '{self.split_column}' not in generated data. "
                    "Generated data will be matched to real data by conditions only."
                )

    def align_genes(self) -> "GeneExpressionDataLoader":
        """
        Align gene names between real and generated datasets.

        Keeps only genes present in both datasets in the same order.

        Returns
        -------
        self
            For method chaining

        Raises
        ------
        DataLoaderError
            If load() has not been called or the datasets share no genes.
        """
        if not self._is_loaded:
            raise DataLoaderError("Data not loaded. Call load() first.")

        real_genes = pd.Index(self._real.var_names.astype(str))
        gen_genes = pd.Index(self._generated.var_names.astype(str))

        # Find common genes
        common = real_genes.intersection(gen_genes)

        if len(common) == 0:
            raise DataLoaderError(
                "No overlapping genes between real and generated data."
            )

        # Warn about dropped genes
        n_real_only = len(real_genes) - len(common)
        n_gen_only = len(gen_genes) - len(common)

        if n_real_only > 0 or n_gen_only > 0:
            warnings.warn(
                f"Gene alignment: keeping {len(common)} common genes. "
                f"Dropped {n_real_only} from real, {n_gen_only} from generated."
            )

        # Subset and order genes
        self._common_genes = common.tolist()

        # Create aligned copies
        # get_indexer maps each common gene to its positional index in the
        # source dataset, so both subsets end up in the same gene order.
        real_idx = real_genes.get_indexer(common)
        gen_idx = gen_genes.get_indexer(common)

        # .copy() materializes the AnnData views so later edits are safe.
        self._real_aligned = self._real[:, real_idx].copy()
        self._generated_aligned = self._generated[:, gen_idx].copy()

        # Ensure var_names match
        self._real_aligned.var_names = common
        self._generated_aligned.var_names = common

        self._is_aligned = True
        return self

    def _get_condition_key(self, row: pd.Series) -> str:
        """Generate unique key for a condition combination."""
        # "####" is assumed not to occur in condition values; iterate_conditions
        # splits on it to recover the parts.
        return "####".join([str(row[c]) for c in self.condition_columns])

    def _build_condition_masks(
        self,
        adata: ad.AnnData,
        split: Optional[str] = None
    ) -> Dict[str, np.ndarray]:
        """
        Build boolean masks for each unique condition.

        Masks are full-length (adata.n_obs); the split filter is applied both
        when enumerating conditions and inside each mask. Conditions with fewer
        than min_samples_per_condition matching rows are dropped.
        """
        obs = adata.obs.copy()

        # Apply split filter if specified
        if split is not None and self.split_column is not None:
            if self.split_column in obs.columns:
                split_mask = obs[self.split_column].astype(str) == split
                obs = obs[split_mask]

        # Get unique condition combinations
        conditions = obs[self.condition_columns].astype(str).drop_duplicates()

        masks = {}
        for _, row in conditions.iterrows():
            key = self._get_condition_key(row)

            # Build mask
            mask = np.ones(adata.n_obs, dtype=bool)
            for col in self.condition_columns:
                mask &= (adata.obs[col].astype(str) == str(row[col])).values

            # Re-apply the split constraint on the full-length mask.
            if split is not None and self.split_column is not None:
                if self.split_column in adata.obs.columns:
                    mask &= (adata.obs[self.split_column].astype(str) == split).values

            if mask.sum() >= self.min_samples_per_condition:
                masks[key] = mask

        return masks

    def get_splits(self) -> List[str]:
        """
        Get list of available splits.

        Returns
        -------
        List[str]
            Split names (e.g., ['train', 'test'] or ['all'])
        """
        if not self._is_loaded:
            raise DataLoaderError("Data not loaded. Call load() first.")

        if self.split_column is None:
            return ["all"]

        if self.split_column not in self._real.obs.columns:
            return ["all"]

        # Splits are taken from the real dataset only.
        return list(self._real.obs[self.split_column].astype(str).unique())

    def get_common_conditions(
        self,
        split: Optional[str] = None
    ) -> List[str]:
        """
        Get conditions present in both real and generated data.

        Parameters
        ----------
        split : str, optional
            If specified, only return conditions in this split

        Returns
        -------
        List[str]
            Condition keys present in both datasets
        """
        if not self._is_aligned:
            self.align_genes()

        real_masks = self._build_condition_masks(self._real_aligned, split)
        # Generated data is intentionally not split-filtered (see
        # _validate_columns: it may lack the split column entirely).
        gen_masks = self._build_condition_masks(self._generated_aligned, None)

        # Find intersection
        common = sorted(set(real_masks.keys()) & set(gen_masks.keys()))

        return common

    def iterate_conditions(
        self,
        split: Optional[str] = None
    ) -> Iterator[Tuple[str, np.ndarray, np.ndarray, Dict[str, str]]]:
        """
        Iterate over matched conditions yielding aligned data.

        Parameters
        ----------
        split : str, optional
            If specified, only iterate conditions in this split

        Yields
        ------
        Tuple[str, np.ndarray, np.ndarray, Dict[str, str]]
            (condition_key, real_data, generated_data, condition_info)
            where condition_info contains the parsed condition values
        """
        if not self._is_aligned:
            self.align_genes()

        real_masks = self._build_condition_masks(self._real_aligned, split)
        gen_masks = self._build_condition_masks(self._generated_aligned, None)

        common = sorted(set(real_masks.keys()) & set(gen_masks.keys()))

        for key in common:
            real_mask = real_masks[key]
            gen_mask = gen_masks[key]

            # Extract data matrices (dense copies, cells x common genes)
            real_data = self._to_dense(self._real_aligned.X[real_mask])
            gen_data = self._to_dense(self._generated_aligned.X[gen_mask])

            # Parse condition info (inverse of _get_condition_key)
            parts = key.split("####")
            condition_info = dict(zip(self.condition_columns, parts))

            yield key, real_data, gen_data, condition_info

    @staticmethod
    def _to_dense(X) -> np.ndarray:
        """Convert matrix to dense numpy array."""
        if sparse.issparse(X):
            return X.toarray()
        return np.asarray(X)

    @property
    def real(self) -> ad.AnnData:
        """Get aligned real data (aligns lazily on first access)."""
        if not self._is_aligned:
            self.align_genes()
        return self._real_aligned

    @property
    def generated(self) -> ad.AnnData:
        """Get aligned generated data (aligns lazily on first access)."""
        if not self._is_aligned:
            self.align_genes()
        return self._generated_aligned

    @property
    def gene_names(self) -> List[str]:
        """Get common gene names (aligns lazily on first access)."""
        if not self._is_aligned:
            self.align_genes()
        return self._common_genes

    @property
    def n_genes(self) -> int:
        """Number of common genes."""
        return len(self.gene_names)

    def summary(self) -> Dict[str, object]:
        """Get summary of loaded data."""
        if not self._is_loaded:
            return {"loaded": False}

        result = {
            "loaded": True,
            "aligned": self._is_aligned,
            "real": {
                "n_samples": self._real.n_obs,
                "n_genes": self._real.n_vars,
                "path": str(self.real_path),
            },
            "generated": {
                "n_samples": self._generated.n_obs,
                "n_genes": self._generated.n_vars,
                "path": str(self.generated_path),
            },
            "condition_columns": self.condition_columns,
            "split_column": self.split_column,
        }

        if self._is_aligned:
            result["n_common_genes"] = len(self._common_genes)
            result["splits"] = self.get_splits()

            # Per-split matched-condition counts; "all" means no split filter.
            for split in result["splits"]:
                s = split if split != "all" else None
                result[f"n_conditions_{split}"] = len(self.get_common_conditions(s))

        return result

    def __repr__(self) -> str:
        if not self._is_loaded:
            return "GeneExpressionDataLoader(not loaded)"

        return (
            f"GeneExpressionDataLoader("
            f"real={self._real.n_obs}x{self._real.n_vars}, "
            f"gen={self._generated.n_obs}x{self._generated.n_vars}, "
            f"aligned={self._is_aligned})"
        )
|
|
398
|
+
|
|
399
|
+
|
|
400
|
+
def load_data(
    real_path: Union[str, Path],
    generated_path: Union[str, Path],
    condition_columns: List[str],
    split_column: Optional[str] = None,
    **kwargs
) -> GeneExpressionDataLoader:
    """
    Convenience function to load and gene-align a paired dataset.

    Parameters
    ----------
    real_path : str or Path
        Path to real data h5ad file
    generated_path : str or Path
        Path to generated data h5ad file
    condition_columns : List[str]
        Columns to match between datasets
    split_column : str, optional
        Column indicating train/test split
    **kwargs
        Additional arguments for GeneExpressionDataLoader

    Returns
    -------
    GeneExpressionDataLoader
        Loaded and aligned data loader
    """
    # Both load() and align_genes() return self, so the setup chains fluently.
    return GeneExpressionDataLoader(
        real_path=real_path,
        generated_path=generated_path,
        condition_columns=condition_columns,
        split_column=split_column,
        **kwargs,
    ).load().align_genes()
|