gengeneeval 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of gengeneeval might be problematic. Click here for more details.

@@ -0,0 +1,211 @@
1
+ from typing import Optional, List, Union
2
+ import warnings
3
+ import numpy as np
4
+ import anndata as ad
5
+ import pandas as pd
6
+ import scanpy as sc
7
+
8
+
9
class DataModuleError(Exception):
    """Raised when AnnData validation in GeneExpressionDataModule fails."""
11
+
12
+
13
class GeneExpressionDataModule:
    """
    Safe data module for gene expression datasets.

    Adds robust validation on construction:
    - Checks required obs columns (perturbation_key, split_key)
    - Validates control value presence (if provided)
    - Ensures minimum cells/genes
    - Detects duplicate gene names
    - Flags sparsity, normalization, log state
    - Prevents negative counts before preprocessing

    Parameters
    ----------
    adata : anndata.AnnData
        Expression matrix with cell metadata in ``.obs``.
    perturbation_key : str
        ``obs`` column holding the perturbation label per cell.
    split_key : str
        ``obs`` column holding the dataset split label.
    control : str, optional
        Value of ``perturbation_key`` identifying control cells; validated
        for presence when given.
    condition_keys : list of str, optional
        Additional ``obs`` columns of interest; missing ones only warn.
    min_cells : int
        Fewer cells than this raises :class:`DataModuleError`.
    min_genes : int
        Fewer genes than this emits a ``RuntimeWarning``.
    allow_float_counts : bool
        Reserved; float counts are currently always tolerated.
    enforce_unique_var_names : bool
        If True, duplicate ``var_names`` raise instead of warn.
    """

    def __init__(
        self,
        adata: "ad.AnnData",
        perturbation_key: str,
        split_key: str,
        control: Optional[str] = None,
        condition_keys: Optional[List[str]] = None,
        min_cells: int = 10,
        min_genes: int = 50,
        allow_float_counts: bool = True,
        enforce_unique_var_names: bool = True,
    ):
        if adata is None:
            raise DataModuleError("AnnData object cannot be None.")
        self.adata = adata
        self.perturbation_key = perturbation_key
        self.split_key = split_key
        self.control = control
        self.condition_keys = condition_keys or []
        self.min_cells = int(min_cells)
        self.min_genes = int(min_genes)
        self.allow_float_counts = allow_float_counts
        self.enforce_unique_var_names = enforce_unique_var_names

        # State flags
        self.is_normalized: bool = False
        self.is_logged: bool = False
        self.is_sparse: bool = self._is_sparse(self.adata.X)

        self._validate_adata()

    # ----------------- validation helpers -----------------

    @staticmethod
    def _is_sparse(X) -> bool:
        """Return True if X is a scipy sparse matrix (False when scipy is absent)."""
        try:
            from scipy import sparse
        except ImportError:
            return False
        return sparse.issparse(X)

    def _dense_matrix(self) -> np.ndarray:
        """Densify .X once so the validators can share a single copy.

        Previously both _detect_logged and _validate_preprocessing_state
        called toarray() independently, doubling peak memory.
        """
        X = self.adata.X
        return X.toarray() if self._is_sparse(X) else np.asarray(X)

    def _validate_obs_column(self, key: str):
        """Raise DataModuleError if the given obs column is missing."""
        if key not in self.adata.obs.columns:
            raise DataModuleError(f"Required obs column '{key}' not found in AnnData.obs.")

    def _validate_control(self):
        """Raise if the configured control label is absent from the perturbation column."""
        if self.control is None:
            return
        col = self.perturbation_key
        vals = self.adata.obs[col].astype(str)
        if str(self.control) not in set(vals):
            raise DataModuleError(f"Control value '{self.control}' not present in '{col}' column.")

    def _validate_sizes(self):
        """Enforce minimum cell count (hard error) and minimum gene count (warning)."""
        if self.adata.n_obs < self.min_cells:
            raise DataModuleError(
                f"Too few cells ({self.adata.n_obs}). Minimum required: {self.min_cells}."
            )
        if self.adata.n_vars < self.min_genes:
            warnings.warn(
                f"Low gene count ({self.adata.n_vars} < {self.min_genes}). Evaluation may be unstable.",
                RuntimeWarning,
            )

    def _validate_var_names(self):
        """Detect duplicate gene names; raise or warn per enforce_unique_var_names."""
        v = pd.Index(self.adata.var_names.astype(str))
        if v.has_duplicates:
            if self.enforce_unique_var_names:
                raise DataModuleError("Duplicate gene names detected in var_names.")
            else:
                warnings.warn("Duplicate gene names detected; downstream alignment may fail.", RuntimeWarning)

    def _detect_logged(self, arr: np.ndarray):
        """Heuristically set is_normalized / is_logged from the value distribution.

        ``arr`` is the dense expression matrix produced by _dense_matrix().
        """
        if np.any(arr < 0):
            warnings.warn("Negative values detected in expression matrix.", RuntimeWarning)
        finite = arr[np.isfinite(arr)]
        if finite.size == 0:
            return
        # A low fraction of integer-valued entries suggests the data was normalized.
        frac_int = np.mean(np.isclose(finite, np.round(finite)))
        if frac_int < 0.7:
            self.is_normalized = True
            # Typical upper bound after log1p is well below ~25.
            if np.nanmax(finite) < 25:
                self.is_logged = True

    def _validate_preprocessing_state(self, arr: np.ndarray):
        """Warn on apparent raw counts; reject all-NaN or negative matrices.

        ``arr`` is the dense expression matrix produced by _dense_matrix().
        """
        finite = arr[np.isfinite(arr)]
        if finite.size == 0:
            raise DataModuleError("Expression matrix contains no finite values.")
        frac_int = np.mean(np.isclose(finite, np.round(finite)))
        if frac_int > 0.95 and np.nanmax(finite) > 50 and not self.is_normalized:
            warnings.warn(
                "Data appears to be raw counts (mostly integers, high max). "
                "Run preprocess_data() before evaluation.",
                RuntimeWarning,
            )
        if np.nanmin(finite) < 0:
            raise DataModuleError("Negative values found in raw counts; data corruption suspected.")

    def _validate_condition_keys(self):
        """Warn (not raise) about condition keys missing from obs."""
        for c in self.condition_keys:
            if c not in self.adata.obs.columns:
                warnings.warn(f"Condition key '{c}' not found in obs; it will be ignored.", RuntimeWarning)

    def _validate_split_column(self):
        """Require the split column to exist; warn if no standard split labels appear."""
        self._validate_obs_column(self.split_key)
        splits = set(self.adata.obs[self.split_key].astype(str))
        if not splits.intersection({"test", "train", "val", "validation"}):
            warnings.warn(
                f"Split column '{self.split_key}' lacks standard split labels (e.g., 'test').",
                RuntimeWarning,
            )

    def _validate_adata(self):
        """Run the full validation suite, densifying .X only once."""
        self._validate_obs_column(self.perturbation_key)
        self._validate_split_column()
        self._validate_control()
        self._validate_sizes()
        self._validate_var_names()
        arr = self._dense_matrix()
        self._detect_logged(arr)
        self._validate_preprocessing_state(arr)
        self._validate_condition_keys()

    # ----------------- public API -----------------

    def load_data(self, filepath: str):
        """Load AnnData from file, reset state flags, and re-run validation."""
        self.adata = sc.read(filepath)
        self.is_sparse = self._is_sparse(self.adata.X)
        self.is_normalized = False
        self.is_logged = False
        self._validate_adata()

    def preprocess_data(
        self,
        filter_min_cells: int = 1,
        target_sum: float = 1e4,
        log_base: Union[int, float] = np.e,
    ):
        """
        Apply basic preprocessing: gene filtering, total count normalization, log1p.
        Sets is_normalized / is_logged accordingly.

        Parameters
        ----------
        filter_min_cells : int
            Genes detected in fewer cells are removed.
        target_sum : float
            Per-cell total after normalization.
        log_base : int or float
            Base of the log1p transform; natural log by default.
        """
        sc.pp.filter_genes(self.adata, min_cells=filter_min_cells)
        sc.pp.normalize_total(self.adata, target_sum=target_sum)
        # Bug fix: log_base was previously accepted but never forwarded, so
        # non-default bases were silently ignored. scanpy's base=None means ln.
        sc.pp.log1p(self.adata, base=None if log_base == np.e else log_base)
        self.is_normalized = True
        self.is_logged = True

    def get_data(self) -> "ad.AnnData":
        """Return AnnData (post any preprocessing)."""
        return self.adata

    def get_conditions(self) -> pd.Series:
        """Return unique perturbation conditions."""
        return pd.Series(self.adata.obs[self.perturbation_key].unique(), name="condition")

    def summary(self) -> dict:
        """Structured summary of current dataset and preprocessing state."""
        return {
            "n_cells": int(self.adata.n_obs),
            "n_genes": int(self.adata.n_vars),
            "is_sparse": bool(self.is_sparse),
            "is_normalized": bool(self.is_normalized),
            "is_logged": bool(self.is_logged),
            "perturbation_key": self.perturbation_key,
            "split_key": self.split_key,
            "control": self.control,
            "condition_keys_present": [c for c in self.condition_keys if c in self.adata.obs.columns],
        }

    def assert_ready_for_evaluation(self):
        """Raise DataModuleError if dataset appears unprocessed."""
        if not self.is_logged or not self.is_normalized:
            raise DataModuleError(
                "Dataset not preprocessed (normalization/log). Call preprocess_data() before evaluation."
            )
geneval/data/loader.py ADDED
@@ -0,0 +1,437 @@
1
+ """
2
+ Data loader module for paired real and generated datasets.
3
+
4
+ Provides loading, validation, and alignment of AnnData objects for evaluation.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ from typing import Optional, List, Union, Dict, Tuple, Iterator
9
+ from pathlib import Path
10
+ import warnings
11
+ import numpy as np
12
+ import pandas as pd
13
+ from scipy import sparse
14
+
15
+ try:
16
+ import anndata as ad
17
+ import scanpy as sc
18
+ except ImportError:
19
+ raise ImportError("anndata and scanpy are required. Install with: pip install anndata scanpy")
20
+
21
+
22
class DataLoaderError(Exception):
    """Raised when loading, validating, or aligning paired datasets fails."""
    pass
25
+
26
+
27
class GeneExpressionDataLoader:
    """
    Data loader for paired real and generated gene expression datasets.

    Handles:
    - Loading AnnData files (h5ad format)
    - Validation of required columns
    - Alignment of gene names between datasets
    - Matching samples by condition columns
    - Split handling (train/test/all)

    Parameters
    ----------
    real_path : str or Path
        Path to real data h5ad file
    generated_path : str or Path
        Path to generated data h5ad file
    condition_columns : List[str]
        Columns to match between datasets (e.g., ['perturbation', 'cell_type'])
    split_column : str, optional
        Column indicating train/test split. If None, all data treated as single split.
    min_samples_per_condition : int
        Minimum samples required per condition to include
    """

    def __init__(
        self,
        real_path: Union[str, Path],
        generated_path: Union[str, Path],
        condition_columns: List[str],
        split_column: Optional[str] = None,
        min_samples_per_condition: int = 2,
    ):
        self.real_path = Path(real_path)
        self.generated_path = Path(generated_path)
        self.condition_columns = condition_columns
        self.split_column = split_column
        self.min_samples_per_condition = min_samples_per_condition

        # Loaded data
        self._real: Optional["ad.AnnData"] = None
        self._generated: Optional["ad.AnnData"] = None

        # Aligned data
        self._real_aligned: Optional["ad.AnnData"] = None
        self._generated_aligned: Optional["ad.AnnData"] = None

        # Common genes and conditions
        self._common_genes: Optional[List[str]] = None
        self._common_conditions: Optional[Dict[str, List[str]]] = None

        # Condition-mask cache keyed by (dataset, split). Previously declared
        # but never consulted, so masks were rebuilt on every query; now
        # populated by _masks_for() and invalidated by load()/align_genes().
        self._condition_cache: Dict[Tuple[str, Optional[str]], Dict[str, np.ndarray]] = {}

        # Validation state
        self._is_loaded = False
        self._is_aligned = False

    def load(self) -> "GeneExpressionDataLoader":
        """
        Load both datasets from disk.

        Reloading invalidates any previous alignment and cached masks.

        Returns
        -------
        self
            For method chaining

        Raises
        ------
        DataLoaderError
            If either file is missing or cannot be parsed.
        """
        # Load real data
        if not self.real_path.exists():
            raise DataLoaderError(f"Real data file not found: {self.real_path}")
        try:
            self._real = sc.read_h5ad(self.real_path)
        except Exception as e:
            # Chain the original exception so the root cause stays visible.
            raise DataLoaderError(f"Failed to load real data: {e}") from e

        # Load generated data
        if not self.generated_path.exists():
            raise DataLoaderError(f"Generated data file not found: {self.generated_path}")
        try:
            self._generated = sc.read_h5ad(self.generated_path)
        except Exception as e:
            raise DataLoaderError(f"Failed to load generated data: {e}") from e

        # Validate columns
        self._validate_columns()

        # Drop stale alignment state (the previous implementation kept the
        # old aligned copies alive after a second load()).
        self._real_aligned = None
        self._generated_aligned = None
        self._common_genes = None
        self._condition_cache.clear()
        self._is_aligned = False

        self._is_loaded = True
        return self

    def _validate_columns(self):
        """Validate that required columns exist in both datasets."""
        for col in self.condition_columns:
            if col not in self._real.obs.columns:
                raise DataLoaderError(
                    f"Condition column '{col}' not found in real data. "
                    f"Available columns: {list(self._real.obs.columns)}"
                )
            if col not in self._generated.obs.columns:
                raise DataLoaderError(
                    f"Condition column '{col}' not found in generated data. "
                    f"Available columns: {list(self._generated.obs.columns)}"
                )

        if self.split_column is not None:
            if self.split_column not in self._real.obs.columns:
                raise DataLoaderError(
                    f"Split column '{self.split_column}' not found in real data."
                )
            # Generated data may not have split column - that's OK
            if self.split_column not in self._generated.obs.columns:
                warnings.warn(
                    f"Split column '{self.split_column}' not in generated data. "
                    "Generated data will be matched to real data by conditions only."
                )

    def align_genes(self) -> "GeneExpressionDataLoader":
        """
        Align gene names between real and generated datasets.

        Keeps only genes present in both datasets in the same order.

        Returns
        -------
        self
            For method chaining
        """
        if not self._is_loaded:
            raise DataLoaderError("Data not loaded. Call load() first.")

        real_genes = pd.Index(self._real.var_names.astype(str))
        gen_genes = pd.Index(self._generated.var_names.astype(str))

        # Find common genes
        common = real_genes.intersection(gen_genes)

        if len(common) == 0:
            raise DataLoaderError(
                "No overlapping genes between real and generated data."
            )

        # Warn about dropped genes
        n_real_only = len(real_genes) - len(common)
        n_gen_only = len(gen_genes) - len(common)

        if n_real_only > 0 or n_gen_only > 0:
            warnings.warn(
                f"Gene alignment: keeping {len(common)} common genes. "
                f"Dropped {n_real_only} from real, {n_gen_only} from generated."
            )

        # Subset and order genes
        self._common_genes = common.tolist()

        # Create aligned copies
        real_idx = real_genes.get_indexer(common)
        gen_idx = gen_genes.get_indexer(common)

        self._real_aligned = self._real[:, real_idx].copy()
        self._generated_aligned = self._generated[:, gen_idx].copy()

        # Ensure var_names match
        self._real_aligned.var_names = common
        self._generated_aligned.var_names = common

        # Masks were computed against the previous aligned objects (if any).
        self._condition_cache.clear()

        self._is_aligned = True
        return self

    def _get_condition_key(self, row: pd.Series) -> str:
        """Generate unique key for a condition combination."""
        return "####".join([str(row[c]) for c in self.condition_columns])

    def _build_condition_masks(
        self,
        adata: "ad.AnnData",
        split: Optional[str] = None
    ) -> Dict[str, np.ndarray]:
        """Build one boolean sample mask per unique condition combination.

        Conditions with fewer than ``min_samples_per_condition`` samples
        (after the optional split filter) are dropped.

        A single combined key is computed per sample; the previous version
        re-scanned every obs column once per condition (O(conditions x cells)).
        """
        obs = adata.obs

        # Optional split restriction (the split column may be absent in
        # generated data; then no filtering is applied).
        in_split = np.ones(adata.n_obs, dtype=bool)
        if split is not None and self.split_column is not None:
            if self.split_column in obs.columns:
                in_split = (obs[self.split_column].astype(str) == split).values

        if not self.condition_columns:
            # Degenerate case: everything is one condition.
            return {"": in_split} if in_split.sum() >= self.min_samples_per_condition else {}

        cols = [obs[c].astype(str) for c in self.condition_columns]
        keys = cols[0].str.cat(cols[1:], sep="####") if len(cols) > 1 else cols[0]
        keys = keys.to_numpy()

        masks: Dict[str, np.ndarray] = {}
        for key in pd.unique(keys[in_split]):
            mask = (keys == key) & in_split
            if mask.sum() >= self.min_samples_per_condition:
                masks[key] = mask
        return masks

    def _masks_for(self, which: str, split: Optional[str]) -> Dict[str, np.ndarray]:
        """Return (and cache) condition masks for 'real' or 'generated' aligned data."""
        cache_key = (which, split)
        if cache_key not in self._condition_cache:
            adata = self._real_aligned if which == "real" else self._generated_aligned
            self._condition_cache[cache_key] = self._build_condition_masks(adata, split)
        return self._condition_cache[cache_key]

    def get_splits(self) -> List[str]:
        """
        Get list of available splits.

        Returns
        -------
        List[str]
            Split names (e.g., ['train', 'test'] or ['all'])
        """
        if not self._is_loaded:
            raise DataLoaderError("Data not loaded. Call load() first.")

        if self.split_column is None:
            return ["all"]

        if self.split_column not in self._real.obs.columns:
            return ["all"]

        return list(self._real.obs[self.split_column].astype(str).unique())

    def get_common_conditions(
        self,
        split: Optional[str] = None
    ) -> List[str]:
        """
        Get conditions present in both real and generated data.

        Parameters
        ----------
        split : str, optional
            If specified, only return conditions in this split

        Returns
        -------
        List[str]
            Condition keys present in both datasets
        """
        if not self._is_aligned:
            self.align_genes()

        real_masks = self._masks_for("real", split)
        # Generated data is never split-filtered; it matches by condition only.
        gen_masks = self._masks_for("generated", None)

        # Find intersection
        return sorted(set(real_masks.keys()) & set(gen_masks.keys()))

    def iterate_conditions(
        self,
        split: Optional[str] = None
    ) -> Iterator[Tuple[str, np.ndarray, np.ndarray, Dict[str, str]]]:
        """
        Iterate over matched conditions yielding aligned data.

        Parameters
        ----------
        split : str, optional
            If specified, only iterate conditions in this split

        Yields
        ------
        Tuple[str, np.ndarray, np.ndarray, Dict[str, str]]
            (condition_key, real_data, generated_data, condition_info)
            where condition_info contains the parsed condition values
        """
        if not self._is_aligned:
            self.align_genes()

        real_masks = self._masks_for("real", split)
        gen_masks = self._masks_for("generated", None)

        common = sorted(set(real_masks.keys()) & set(gen_masks.keys()))

        for key in common:
            real_mask = real_masks[key]
            gen_mask = gen_masks[key]

            # Extract data matrices
            real_data = self._to_dense(self._real_aligned.X[real_mask])
            gen_data = self._to_dense(self._generated_aligned.X[gen_mask])

            # Parse condition info
            parts = key.split("####")
            condition_info = dict(zip(self.condition_columns, parts))

            yield key, real_data, gen_data, condition_info

    @staticmethod
    def _to_dense(X) -> np.ndarray:
        """Convert matrix to dense numpy array."""
        if sparse.issparse(X):
            return X.toarray()
        return np.asarray(X)

    @property
    def real(self) -> "ad.AnnData":
        """Get aligned real data."""
        if not self._is_aligned:
            self.align_genes()
        return self._real_aligned

    @property
    def generated(self) -> "ad.AnnData":
        """Get aligned generated data."""
        if not self._is_aligned:
            self.align_genes()
        return self._generated_aligned

    @property
    def gene_names(self) -> List[str]:
        """Get common gene names."""
        if not self._is_aligned:
            self.align_genes()
        return self._common_genes

    @property
    def n_genes(self) -> int:
        """Number of common genes."""
        return len(self.gene_names)

    def summary(self) -> Dict[str, object]:
        """Get summary of loaded data."""
        if not self._is_loaded:
            return {"loaded": False}

        result = {
            "loaded": True,
            "aligned": self._is_aligned,
            "real": {
                "n_samples": self._real.n_obs,
                "n_genes": self._real.n_vars,
                "path": str(self.real_path),
            },
            "generated": {
                "n_samples": self._generated.n_obs,
                "n_genes": self._generated.n_vars,
                "path": str(self.generated_path),
            },
            "condition_columns": self.condition_columns,
            "split_column": self.split_column,
        }

        if self._is_aligned:
            result["n_common_genes"] = len(self._common_genes)
            result["splits"] = self.get_splits()

            for split in result["splits"]:
                s = split if split != "all" else None
                result[f"n_conditions_{split}"] = len(self.get_common_conditions(s))

        return result

    def __repr__(self) -> str:
        if not self._is_loaded:
            return "GeneExpressionDataLoader(not loaded)"

        return (
            f"GeneExpressionDataLoader("
            f"real={self._real.n_obs}x{self._real.n_vars}, "
            f"gen={self._generated.n_obs}x{self._generated.n_vars}, "
            f"aligned={self._is_aligned})"
        )
398
+
399
+
400
def load_data(
    real_path: Union[str, Path],
    generated_path: Union[str, Path],
    condition_columns: List[str],
    split_column: Optional[str] = None,
    **kwargs
) -> "GeneExpressionDataLoader":
    """
    Convenience wrapper: build a loader, load both files, and align genes.

    Parameters
    ----------
    real_path : str or Path
        Path to real data h5ad file
    generated_path : str or Path
        Path to generated data h5ad file
    condition_columns : List[str]
        Columns to match between datasets
    split_column : str, optional
        Column indicating train/test split
    **kwargs
        Additional arguments for GeneExpressionDataLoader

    Returns
    -------
    GeneExpressionDataLoader
        Loaded and aligned data loader
    """
    loader = GeneExpressionDataLoader(
        real_path,
        generated_path,
        condition_columns,
        split_column=split_column,
        **kwargs,
    )
    # load() and align_genes() both return self, so the calls chain.
    return loader.load().align_genes()