gengeneeval 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geneval/evaluator.py ADDED
@@ -0,0 +1,359 @@
1
+ """
2
+ Comprehensive evaluator for gene expression data.
3
+
4
+ Computes all metrics between real and generated data, organized by conditions and splits.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ from typing import Dict, List, Optional, Union, Type, Any
9
+ from pathlib import Path
10
+ import numpy as np
11
+ import warnings
12
+ from dataclasses import dataclass
13
+
14
+ from .data.loader import GeneExpressionDataLoader, load_data
15
+ from .metrics.base_metric import BaseMetric, MetricResult
16
+ from .metrics.correlation import (
17
+ PearsonCorrelation,
18
+ SpearmanCorrelation,
19
+ MeanPearsonCorrelation,
20
+ MeanSpearmanCorrelation,
21
+ )
22
+ from .metrics.distances import (
23
+ Wasserstein1Distance,
24
+ Wasserstein2Distance,
25
+ MMDDistance,
26
+ EnergyDistance,
27
+ MultivariateWasserstein,
28
+ MultivariateMMD,
29
+ )
30
+ from .results import EvaluationResult, SplitResult, ConditionResult
31
+
32
+
33
# Default metric suite used when the caller does not supply an explicit
# `metrics` list: per-gene correlation metrics plus univariate
# distribution-distance metrics. Multivariate (whole-space) metrics are
# appended separately by the evaluator when `include_multivariate` is True.
DEFAULT_METRICS = [
    PearsonCorrelation,
    SpearmanCorrelation,
    MeanPearsonCorrelation,
    MeanSpearmanCorrelation,
    Wasserstein1Distance,
    Wasserstein2Distance,
    MMDDistance,
    EnergyDistance,
]
44
+
45
+
46
class GeneEvalEvaluator:
    """
    Main evaluator class for gene expression data.

    Computes comprehensive metrics between real and generated datasets,
    supporting multiple conditions, splits, and metric types.

    Parameters
    ----------
    data_loader : GeneExpressionDataLoader
        Loaded and aligned data loader
    metrics : List[BaseMetric or Type[BaseMetric]], optional
        Metrics to compute. If None, uses the default set. An explicitly
        empty list is honored (only multivariate metrics are added, and
        only if requested).
    aggregate_method : str
        How to aggregate per-gene values (mean, median, etc.)
    include_multivariate : bool
        Whether to include multivariate (whole-space) metrics
    verbose : bool
        Whether to print progress

    Examples
    --------
    >>> loader = load_data("real.h5ad", "generated.h5ad", ["perturbation"])
    >>> evaluator = GeneEvalEvaluator(loader)
    >>> results = evaluator.evaluate()
    >>> results.save("output/")
    """

    def __init__(
        self,
        data_loader: GeneExpressionDataLoader,
        metrics: Optional[List[Union[BaseMetric, Type[BaseMetric]]]] = None,
        aggregate_method: str = "mean",
        include_multivariate: bool = True,
        verbose: bool = True,
    ):
        self.data_loader = data_loader
        self.aggregate_method = aggregate_method
        self.include_multivariate = include_multivariate
        self.verbose = verbose

        # BUG FIX: compare against None explicitly. The previous
        # `metrics or DEFAULT_METRICS` silently replaced a caller-supplied
        # empty list with the full default suite.
        metric_classes = DEFAULT_METRICS if metrics is None else metrics

        # Accept both metric classes (instantiated here) and ready instances.
        self.metrics: List[BaseMetric] = [
            m() if isinstance(m, type) else m for m in metric_classes
        ]

        # Optionally append whole-space (multivariate) metrics.
        if include_multivariate:
            self.metrics.extend([
                MultivariateWasserstein(),
                MultivariateMMD(),
            ])

    def _log(self, msg: str):
        """Print message if verbose."""
        if self.verbose:
            print(msg)

    def evaluate(
        self,
        splits: Optional[List[str]] = None,
        save_dir: Optional[Union[str, Path]] = None,
    ) -> EvaluationResult:
        """
        Run full evaluation on all conditions and splits.

        Parameters
        ----------
        splits : List[str], optional
            Splits to evaluate. If None, evaluates all available splits.
            Unknown split names are dropped with a warning.
        save_dir : str or Path, optional
            If provided, save results to this directory

        Returns
        -------
        EvaluationResult
            Complete evaluation results
        """
        # Determine which splits to run, validating user-requested ones.
        available_splits = self.data_loader.get_splits()

        if splits is None:
            splits = available_splits
        else:
            invalid = set(splits) - set(available_splits)
            if invalid:
                warnings.warn(f"Requested splits not found: {invalid}")
            splits = [s for s in splits if s in available_splits]

        self._log(f"Evaluating {len(splits)} splits: {splits}")
        self._log(f"Using {len(self.metrics)} metrics: {[m.name for m in self.metrics]}")

        # Container carrying provenance metadata alongside the results.
        result = EvaluationResult(
            gene_names=self.data_loader.gene_names,
            condition_columns=self.data_loader.condition_columns,
            metadata={
                "real_path": str(self.data_loader.real_path),
                "generated_path": str(self.data_loader.generated_path),
                "aggregate_method": self.aggregate_method,
                "metric_names": [m.name for m in self.metrics],
            }
        )

        # Evaluate each split; the pseudo-split "all" means "no filter".
        for split in splits:
            split_key = split if split != "all" else None
            split_result = self._evaluate_split(split, split_key)
            result.add_split(split_result)

        # Aggregate per-condition metrics within each split.
        for split_result in result.splits.values():
            split_result.compute_aggregates()

        if self.verbose:
            self._print_summary(result)

        if save_dir is not None:
            result.save(save_dir)
            self._log(f"Results saved to: {save_dir}")

        return result

    def _evaluate_split(
        self,
        split_name: str,
        split_filter: Optional[str]
    ) -> SplitResult:
        """Evaluate every condition of a single split and collect results.

        A metric failure for one condition is reported as a warning and
        skipped rather than aborting the whole evaluation.
        """
        split_result = SplitResult(split_name=split_name)

        conditions = list(self.data_loader.iterate_conditions(split_filter))
        self._log(f"\n  Split '{split_name}': {len(conditions)} conditions")

        for i, (cond_key, real_data, gen_data, cond_info) in enumerate(conditions):
            # Progress heartbeat every 10 conditions.
            if self.verbose and (i + 1) % 10 == 0:
                self._log(f"    Processing condition {i + 1}/{len(conditions)}")

            cond_result = ConditionResult(
                condition_key=cond_key,
                split=split_name,
                n_real_samples=real_data.shape[0],
                n_generated_samples=gen_data.shape[0],
                n_genes=real_data.shape[1],
                gene_names=self.data_loader.gene_names,
                # First condition column is treated as the perturbation label.
                perturbation=cond_info.get(self.data_loader.condition_columns[0]),
                covariates=cond_info,
            )

            # Store mean expression profiles for downstream inspection.
            cond_result.real_mean = real_data.mean(axis=0)
            cond_result.generated_mean = gen_data.mean(axis=0)

            # Compute all metrics; isolate failures per metric/condition.
            for metric in self.metrics:
                try:
                    metric_result = metric.compute(
                        real=real_data,
                        generated=gen_data,
                        gene_names=self.data_loader.gene_names,
                        aggregate_method=self.aggregate_method,
                        condition=cond_key,
                        split=split_name,
                    )
                    cond_result.add_metric(metric.name, metric_result)
                except Exception as e:
                    warnings.warn(
                        f"Failed to compute {metric.name} for {cond_key}: {e}"
                    )

            split_result.add_condition(cond_result)

        return split_result

    def _print_summary(self, result: EvaluationResult):
        """Print a human-readable summary of aggregate metrics per split."""
        self._log("\n" + "=" * 60)
        self._log("EVALUATION SUMMARY")
        self._log("=" * 60)

        for split_name, split in result.splits.items():
            self._log(f"\nSplit: {split_name} ({split.n_conditions} conditions)")
            self._log("-" * 40)

            # Aggregates are stored as "<metric>_mean" / "<metric>_std" pairs;
            # print "mean ± std" for each metric.
            for key, value in sorted(split.aggregate_metrics.items()):
                if key.endswith("_mean"):
                    metric_name = key[:-5]
                    std_key = f"{metric_name}_std"
                    std = split.aggregate_metrics.get(std_key, 0)
                    self._log(f"  {metric_name}: {value:.4f} ± {std:.4f}")

        self._log("=" * 60)
250
+
251
+
252
def evaluate(
    real_path: Union[str, Path],
    generated_path: Union[str, Path],
    condition_columns: List[str],
    split_column: Optional[str] = None,
    output_dir: Optional[Union[str, Path]] = None,
    metrics: Optional[List[Union[BaseMetric, Type[BaseMetric]]]] = None,
    include_multivariate: bool = True,
    verbose: bool = True,
    aggregate_method: str = "mean",
    **loader_kwargs
) -> EvaluationResult:
    """
    Convenience function to run full evaluation.

    Parameters
    ----------
    real_path : str or Path
        Path to real data h5ad file
    generated_path : str or Path
        Path to generated data h5ad file
    condition_columns : List[str]
        Columns to match between datasets
    split_column : str, optional
        Column indicating train/test split
    output_dir : str or Path, optional
        Directory to save results
    metrics : List, optional
        Metrics to compute
    include_multivariate : bool
        Whether to include multivariate metrics
    verbose : bool
        Print progress
    aggregate_method : str
        How to aggregate per-gene values (forwarded to the evaluator;
        defaults to "mean", matching previous behavior)
    **loader_kwargs
        Additional arguments for data loader

    Returns
    -------
    EvaluationResult
        Complete evaluation results

    Examples
    --------
    >>> results = evaluate(
    ...     "real.h5ad",
    ...     "generated.h5ad",
    ...     condition_columns=["perturbation", "cell_type"],
    ...     split_column="split",
    ...     output_dir="evaluation_output/"
    ... )
    """
    # Load and align both datasets.
    loader = load_data(
        real_path=real_path,
        generated_path=generated_path,
        condition_columns=condition_columns,
        split_column=split_column,
        **loader_kwargs
    )

    # Build the evaluator. `aggregate_method` was previously not exposed
    # here even though the evaluator supports it; it is now forwarded.
    evaluator = GeneEvalEvaluator(
        data_loader=loader,
        metrics=metrics,
        aggregate_method=aggregate_method,
        include_multivariate=include_multivariate,
        verbose=verbose,
    )

    # Run evaluation (and save, if output_dir was given).
    return evaluator.evaluate(save_dir=output_dir)
321
+
322
+
323
class MetricRegistry:
    """
    Registry of available metrics.

    Allows registration of custom metrics and retrieval by name.
    `register` returns the class it was given, so it can also be used as
    a class decorator:

    >>> @MetricRegistry.register
    ... class MyMetric(BaseMetric):
    ...     ...
    """

    # Maps metric name -> metric class; shared, process-wide state.
    _metrics: Dict[str, Type[BaseMetric]] = {}

    @classmethod
    def register(cls, metric_class: Type[BaseMetric]) -> Type[BaseMetric]:
        """Register a metric class under its instance `name` and return it.

        The class is instantiated once to read its `name`, so it must be
        constructible with no arguments. Returning the class (instead of
        None) makes `register` usable as a decorator; existing callers
        that ignore the return value are unaffected.
        """
        instance = metric_class()
        cls._metrics[instance.name] = metric_class
        return metric_class

    @classmethod
    def get(cls, name: str) -> Optional[Type[BaseMetric]]:
        """Get a registered metric class by name, or None if unknown."""
        return cls._metrics.get(name)

    @classmethod
    def list_all(cls) -> List[str]:
        """List all registered metric names."""
        return list(cls._metrics.keys())

    @classmethod
    def get_all(cls) -> List[Type[BaseMetric]]:
        """Get all registered metric classes."""
        return list(cls._metrics.values())
352
+
353
+
354
# Populate the registry with the built-in metrics at import time so they
# can be looked up by name (e.g. from configuration or CLI code).
for metric_class in DEFAULT_METRICS:
    MetricRegistry.register(metric_class)

# The multivariate metrics are not part of DEFAULT_METRICS but are still
# made available through the registry.
MetricRegistry.register(MultivariateWasserstein)
MetricRegistry.register(MultivariateMMD)
@@ -0,0 +1,4 @@
1
+ from .base_evaluator import BaseEvaluator
2
+ from .gene_expression_evaluator import GeneExpressionEvaluator
3
+
4
+ __all__ = ["BaseEvaluator", "GeneExpressionEvaluator"]
@@ -0,0 +1,178 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC
4
+ from typing import Dict, Iterable, List, Optional, Tuple
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ from anndata import AnnData
9
+ from scipy import sparse
10
+
11
+ from ..utils.preprocessing import to_dense
12
+
13
+
14
class BaseEvaluator(ABC):
    """
    Base class for evaluation of generated data against real datasets.

    Provides:
    - Variable/gene alignment between real and generated AnnData objects
    - Computation and application of control baselines per strata
    """

    def __init__(self, data, output: AnnData):
        """
        Parameters
        ----------
        data : object
            An object providing at least:
            - gene_expression_dataset.adata: AnnData
            - perturbation_key: str
            - split_key: str
            - control: str
            - condition_keys: Optional[List[str]]
        output : AnnData
            Generated data to evaluate.
        """
        self.data = data
        self.output = output

    # ---------- alignment utilities ----------

    def _align_varnames_like(self, real: AnnData, generated: AnnData) -> Tuple[AnnData, AnnData]:
        """
        Align real and generated AnnData to the common set of var_names (genes),
        preserving order based on the real AnnData.

        Raises
        ------
        ValueError
            If the two objects share no genes.
        """
        real_genes = pd.Index(real.var_names.astype(str))
        gen_genes = pd.Index(generated.var_names.astype(str))
        common = real_genes.intersection(gen_genes)
        if len(common) == 0:
            raise ValueError("No overlapping genes between real and generated AnnData.")

        # NOTE(review): pd.Index.intersection is not documented to preserve
        # `real`'s ordering in every case -- if strict real-order matters,
        # confirm, or consider real_genes[real_genes.isin(gen_genes)].
        # Reindex both adatas to the common genes
        real = real[:, real_genes.get_indexer(common)].copy()
        generated = generated[:, generated.var_names.astype(str).isin(common)].copy()
        # Reorder generated columns to match `common`
        generated = generated[:, pd.Index(generated.var_names.astype(str)).get_indexer(common)].copy()

        real.var_names = common
        generated.var_names = common
        return real, generated

    # ---------- baseline utilities ----------

    @staticmethod
    def _key_from_values(values: Iterable[object]) -> str:
        # Stable string key for strata-tuples; "####" is the field separator
        # and must match the decoding in _apply_baseline_per_strata.
        return "####".join([str(v) for v in values])

    def _compute_control_means(
        self,
        adata: AnnData,
        perturbation_col: str,
        control_value: str,
        strata_cols: Optional[List[str]] = None,
    ) -> Dict[str, np.ndarray]:
        """
        Compute per-strata control means across genes.

        Parameters
        ----------
        adata : AnnData
            Dataset containing both control and perturbed observations.
        perturbation_col : str
            obs column holding the perturbation label.
        control_value : str
            Label identifying control observations in `perturbation_col`.
        strata_cols : List[str], optional
            obs columns defining strata; empty/None means one global mean
            stored under the empty key.

        Returns
        -------
        Dict[str, np.ndarray]
            Strata-key -> mean vector (n_genes,). Empty dict when there are
            no control observations.

        Raises
        ------
        KeyError
            If `perturbation_col` is missing from adata.obs.
        """
        strata_cols = strata_cols or []
        obs = adata.obs

        if perturbation_col not in obs.columns:
            raise KeyError(f"'{perturbation_col}' not found in adata.obs.")

        # Compare as strings so categorical/object columns behave alike.
        is_control = (obs[perturbation_col].astype(str) == str(control_value)).to_numpy()
        if not is_control.any():
            # No controls; callers treat an empty map as "no baseline".
            return {}

        ctrl = adata[is_control]
        if not strata_cols:
            # Single global mean keyed by the empty string.
            return {self._key_from_values([]): to_dense(ctrl.X).mean(axis=0)}

        # Group by strata columns (as strings to be robust).
        df = ctrl.obs[strata_cols].astype(str)
        means: Dict[str, np.ndarray] = {}
        # One mean per unique strata combination present among controls.
        for _, row in df.drop_duplicates().iterrows():
            mask = np.ones(ctrl.n_obs, dtype=bool)
            for c in strata_cols:
                mask &= (df[c].to_numpy() == str(row[c]))
            if not mask.any():
                continue
            key = self._key_from_values([row[c] for c in strata_cols])
            means[key] = to_dense(ctrl.X[mask]).mean(axis=0)
        return means

    def _apply_baseline_per_strata(
        self,
        X,
        obs: pd.DataFrame,
        baseline: Dict[str, np.ndarray],
        strata_cols: Optional[List[str]] = None,
        mode: str = "subtract",
    ):
        """
        Apply per-strata baseline vectors to rows in X based on obs[strata_cols].

        Parameters
        ----------
        X : ndarray or sparse matrix, shape (n_obs, n_genes)
            Expression matrix. A copy is modified and returned; the caller's
            matrix is never mutated.
        obs : pd.DataFrame
            Per-row metadata containing `strata_cols`.
        baseline : Dict[str, np.ndarray]
            Strata-key -> baseline vector, as built by _compute_control_means.
            Keys that do not decode to len(strata_cols) parts are skipped.
        strata_cols : List[str], optional
            Columns defining the strata; empty/None means one global baseline
            under the empty key.
        mode : str
            'subtract' or 'add'.

        Returns
        -------
        Matrix of the same shape. A sparse input with no strata is densified
        before the baseline is applied.

        Raises
        ------
        ValueError
            If mode is not 'subtract' or 'add'.
        """
        strata_cols = strata_cols or []
        if mode not in ("subtract", "add"):
            raise ValueError("mode must be 'subtract' or 'add'.")

        # Always work on a copy. (The previous version also set an unused
        # `to_dense_first` flag here; that dead local has been removed.)
        if sparse.issparse(X):
            X = X.tocsr(copy=True)
        else:
            # NOTE(review): in-place -=/+= below assumes X's dtype can hold
            # the baseline's (float) values -- confirm inputs are float.
            X = np.array(X, copy=True)

        if not strata_cols:
            # Single global baseline stored under the empty key.
            key = self._key_from_values([])
            b = baseline.get(key, None)
            if b is None:
                return X
            if sparse.issparse(X):
                # operate dense for simplicity
                X = X.toarray()
            if mode == "subtract":
                X -= b
            else:
                X += b
            return X

        # Apply per group, iterating the baseline entries for efficiency.
        df = obs[strata_cols].astype(str)
        for key, b in baseline.items():
            # Decode the key back into per-column values.
            parts = key.split("####") if key else []
            if len(parts) != len(strata_cols):
                # Skip keys built for a different strata layout.
                continue
            mask = np.ones(df.shape[0], dtype=bool)
            for col, val in zip(strata_cols, parts):
                mask &= (df[col].to_numpy() == val)
            if not mask.any():
                continue

            if sparse.issparse(X):
                # Densify just the affected rows, adjust, then write back.
                block = X[mask].toarray()
                if mode == "subtract":
                    block -= b
                else:
                    block += b
                X[mask] = sparse.csr_matrix(block)
            else:
                if mode == "subtract":
                    X[mask] -= b
                else:
                    X[mask] += b

        return X