gengeneeval-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,218 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Optional, Dict, Any
+
+import numpy as np
+import pandas as pd
+from anndata import AnnData
+import scipy.stats as sstats
+from sklearn.metrics import mean_squared_error
+
+from ..metrics.metrics import compute_metrics
+from ..utils.preprocessing import to_dense as _to_dense
+from .base_evaluator import BaseEvaluator
+from ..visualization import EvaluationPlotter
+
+if TYPE_CHECKING:
+    from ..data.gene_expression_datamodule import GeneExpressionDataModule
+
+
+class GeneExpressionEvaluator(BaseEvaluator):
+    """
+    Evaluator for gene expression data.
+    """
+
+    def __init__(self, data: "GeneExpressionDataModule", output: AnnData):
+        super().__init__(data, output)
+
+    def evaluate(
+        self,
+        delta: bool = False,
+        plot: bool = False,
+        DEG: Optional[Dict[str, Any]] = None,
+        save_dir: Optional[str] = None,
+        max_panels: int = 12,
+        dpi: int = 150,
+    ):
+        """
+        Run evaluation. If plot=True, returns and optionally saves figures.
+        """
+        data = self.data.gene_expression_dataset.adata.copy()
+        generated = self.output.copy()
+        data, generated = self._align_varnames_like(data, generated)
+
+        pert_col = self.data.perturbation_key
+        split_key = self.data.split_key
+        control = self.data.control
+
+        order_cols = []
+        if "cell_type" in data.obs.columns and "cell_type" in generated.obs.columns:
+            order_cols.append("cell_type")
+        for c in (getattr(self.data, "condition_keys", None) or []):
+            if c in data.obs.columns and c in generated.obs.columns:
+                order_cols.append(c)
+
+        # Baseline handling
+        if delta:
+            b = self._compute_control_means(data, pert_col, control, strata_cols=order_cols)
+            data.X = self._apply_baseline_per_strata(data.X, data.obs, b, strata_cols=order_cols, mode="subtract")
+        else:
+            b = self._compute_control_means(data, pert_col, control, strata_cols=order_cols)
+            generated.X = self._apply_baseline_per_strata(
+                generated.X, generated.obs, b, strata_cols=order_cols, mode="add"
+            )
+
+        is_test = (data.obs[split_key].astype(str) == "test").to_numpy()
+        test_data = data[is_test].copy()
+
+        if "perturbation" not in generated.obs.columns and pert_col not in generated.obs.columns:
+            raise KeyError("'perturbation' column not found in generated data.")
+        if pert_col not in generated.obs.columns and "perturbation" in generated.obs.columns:
+            generated.obs[pert_col] = generated.obs["perturbation"].astype(test_data.obs[pert_col].dtype)
+
+        def _means_masks(adata, cols):
+            means, masks = {}, {}
+            df = adata.obs[[pert_col] + cols].astype(str)
+            for _, row in df.drop_duplicates().iterrows():
+                pert = row[pert_col]
+                key = "####".join([pert] + [row[c] for c in cols])
+                mask = (adata.obs[pert_col].astype(str) == pert).to_numpy()
+                for c in cols:
+                    mask &= (adata.obs[c].astype(str) == str(row[c])).to_numpy()
+                if mask.any():
+                    masks[key] = mask
+                    means[key] = _to_dense(adata[mask].X).mean(axis=0)
+            return means, masks
+
+        real_means, real_masks = _means_masks(test_data, order_cols)
+        gen_means, gen_masks = _means_masks(generated, order_cols)
+        common = sorted(set(real_means).intersection(gen_means))
+        if not common:
+            raise ValueError("No common (perturbation + covariates) groups between real TEST and generated data.")
+
+        # Metric accumulators
+        w1 = []; w2 = []; mmd = []; energy = []
+        pearson_corr = []; pearson_p = []
+        spearman_corr = []; spearman_p = []
+        mse_val = []
+
+        vnames = pd.Index(test_data.var_names.astype(str))
+
+        # For plotting
+        plot_means = {}
+        residuals_per_key = {}
+        stats_per_key = {}
+        deg_map = {}
+
+        def maybe_filter(om, gm, td, gd, key):
+            if DEG is None:
+                return om, gm, td, gd
+            deg = DEG.get(key) or DEG.get(key.split("####", 1)[0])
+            if deg is None:
+                return om, gm, td, gd
+            names = None
+            if isinstance(deg, dict):
+                names = deg.get("names", None)
+            elif hasattr(deg, "columns") and "names" in deg.columns:
+                names = deg["names"]
+            else:
+                names = deg
+            if hasattr(names, "tolist"):
+                names = names.tolist()
+            if not names:
+                return om, gm, td, gd
+            mask = np.asarray(vnames.isin([str(x) for x in names]), dtype=bool)
+            if not mask.any():
+                return om, gm, td, gd
+            return om[mask], gm[mask], td[:, mask], gd[:, mask]
+
+        for key in common:
+            td = _to_dense(test_data.X[real_masks[key], :])
+            gd = _to_dense(generated.X[gen_masks[key], :])
+            om = real_means[key]; gm = gen_means[key]
+            om_f, gm_f, td_f, gd_f = maybe_filter(om, gm, td, gd, key)
+
+            # distributional metrics
+            w1.append({key: compute_metrics(td_f, gd_f, 'w1')})
+            w2.append({key: compute_metrics(td_f, gd_f, 'w2')})
+            mmd.append({key: compute_metrics(td_f, gd_f, 'mmd')})
+            energy.append({key: compute_metrics(td_f, gd_f, 'energy')})
+
+            # mean-wise metrics
+            pc, pcp = sstats.pearsonr(om_f, gm_f)
+            sc, scp = sstats.spearmanr(om_f, gm_f)
+            pearson_corr.append({key: pc}); pearson_p.append({key: pcp})
+            spearman_corr.append({key: sc}); spearman_p.append({key: scp})
+            mse = mean_squared_error(om_f, gm_f)
+            mse_val.append({key: mse})
+
+            # for plots
+            plot_means[key] = (om, gm, vnames.tolist())
+            residuals_per_key[key] = (gm - om)
+            stats_per_key[key] = {"pearson": float(pc), "spearman": float(sc), "mse": float(mse)}
+            if DEG is not None:
+                deg_map[key] = DEG.get(key) or DEG.get(key.split("####", 1)[0])
+
+        def _m(lst):
+            return float("nan") if not lst else float(np.mean([list(d.values())[0] for d in lst]))
+
+        print(f"Mean Pearson: {_m(pearson_corr):.4f} (p={_m(pearson_p):.4g})")
+        print(f"Mean Spearman: {_m(spearman_corr):.4f} (p={_m(spearman_p):.4g})")
+        print(f"Mean MSE: {_m(mse_val):.4f}")
+        print(f"Wasserstein-1: {_m(w1):.4f}")
+        print(f"Wasserstein-2: {_m(w2):.4f}")
+        print(f"MMD: {_m(mmd):.4f}")
+        print(f"Energy: {_m(energy):.4f}")
+
+        results = dict(
+            pearson_corr=pearson_corr,
+            spearman_corr=spearman_corr,
+            mse_val=mse_val,
+            w1=w1,
+            w2=w2,
+            mmd=mmd,
+            energy=energy,
+        )
+
+        # Plotting
+        figures = {}
+        if plot:
+            plotter = EvaluationPlotter()
+            # scatter grid
+            fig_scatter = plotter.scatter_means_grid(
+                data=plot_means,
+                stats=stats_per_key,
+                deg_map=deg_map if deg_map else None,
+                max_panels=max_panels,
+            )
+            figures["scatter_means"] = fig_scatter
+
+            # residual distributions
+            fig_residuals = plotter.residuals_violin(residuals=residuals_per_key)
+            figures["residuals"] = fig_residuals
+
+            # metrics bar: combine main metrics
+            metrics_pk = {}
+            for k in common:
+                metrics_pk[k] = {
+                    "pearson": stats_per_key[k]["pearson"],
+                    "spearman": stats_per_key[k]["spearman"],
+                    "MSE": stats_per_key[k]["mse"],
+                    "W1": float([d[k] for d in w1 if k in d][0]),
+                    "W2": float([d[k] for d in w2 if k in d][0]),
+                    "MMD": float([d[k] for d in mmd if k in d][0]),
+                    "Energy": float([d[k] for d in energy if k in d][0]),
+                }
+            fig_metrics = plotter.metrics_bar(metrics_per_key=metrics_pk)
+            figures["metrics_bar"] = fig_metrics
+
+            if save_dir:
+                import os
+                os.makedirs(save_dir, exist_ok=True)
+                fig_scatter.savefig(os.path.join(save_dir, "scatter_means.png"), dpi=dpi, bbox_inches="tight")
+                fig_residuals.savefig(os.path.join(save_dir, "residuals.png"), dpi=dpi, bbox_inches="tight")
+                fig_metrics.savefig(os.path.join(save_dir, "metrics_bar.png"), dpi=dpi, bbox_inches="tight")
+
+            results["figures"] = figures
+
+        return results
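For orientation, a minimal usage sketch of the evaluator added above. The datamodule setup and the generated AnnData below are assumptions for illustration, not part of this diff:

# Hypothetical usage: `dm` is an existing GeneExpressionDataModule and `generated`
# an AnnData of model output sharing var_names and carrying a "perturbation" column in .obs.
evaluator = GeneExpressionEvaluator(data=dm, output=generated)
results = evaluator.evaluate(
    delta=False,          # add control means back onto generated deltas
    plot=True,            # build scatter / residual / metrics-bar figures
    DEG=None,             # or e.g. {"<pert>": {"names": [...]}} to restrict metrics to DEGs
    save_dir="eval_out",  # hypothetical output directory; figures saved as PNGs when given
)
# results holds one list per metric, each entry keyed by "<pert>####<covariates>",
# plus results["figures"] when plot=True.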
@@ -0,0 +1,65 @@
+"""
+Metrics module for gene expression evaluation.
+
+Provides per-gene and aggregate metrics for comparing distributions:
+- Correlation metrics (Pearson, Spearman)
+- Distribution distances (Wasserstein, MMD, Energy)
+- Multivariate distances
+"""
+
+from .base_metric import (
+    BaseMetric,
+    MetricResult,
+    DistributionMetric,
+    CorrelationMetric,
+)
+from .correlation import (
+    PearsonCorrelation,
+    SpearmanCorrelation,
+    MeanPearsonCorrelation,
+    MeanSpearmanCorrelation,
+)
+from .distances import (
+    Wasserstein1Distance,
+    Wasserstein2Distance,
+    MMDDistance,
+    EnergyDistance,
+    MultivariateWasserstein,
+    MultivariateMMD,
+)
+
+# All available metrics
+ALL_METRICS = [
+    PearsonCorrelation,
+    SpearmanCorrelation,
+    MeanPearsonCorrelation,
+    MeanSpearmanCorrelation,
+    Wasserstein1Distance,
+    Wasserstein2Distance,
+    MMDDistance,
+    EnergyDistance,
+    MultivariateWasserstein,
+    MultivariateMMD,
+]
+
+__all__ = [
+    # Base classes
+    "BaseMetric",
+    "MetricResult",
+    "DistributionMetric",
+    "CorrelationMetric",
+    # Correlation metrics
+    "PearsonCorrelation",
+    "SpearmanCorrelation",
+    "MeanPearsonCorrelation",
+    "MeanSpearmanCorrelation",
+    # Distance metrics
+    "Wasserstein1Distance",
+    "Wasserstein2Distance",
+    "MMDDistance",
+    "EnergyDistance",
+    "MultivariateWasserstein",
+    "MultivariateMMD",
+    # Collections
+    "ALL_METRICS",
+]
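A sketch of how this registry could score two expression matrices. It assumes the concrete metric classes can be constructed without arguments and that the module is importable as gengeneeval.metrics; neither is shown in this diff:

import numpy as np
from gengeneeval.metrics import ALL_METRICS  # module path assumed

rng = np.random.default_rng(0)
real = rng.normal(size=(200, 50))                         # (n_cells, n_genes)
generated = real + rng.normal(scale=0.1, size=(200, 50))

for metric_cls in ALL_METRICS:
    metric = metric_cls()                            # assumed no-argument constructor
    result = metric.compute(real, generated)         # BaseMetric.compute -> MetricResult
    print(f"{result.name}: {result.aggregate_value:.4f}")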
@@ -0,0 +1,229 @@
+"""
+Base metric classes for gene expression evaluation.
+
+Provides abstract interface for all metrics with per-gene and aggregate computation.
+"""
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional, Union, Any, Callable
+import numpy as np
+
+
+@dataclass
+class MetricResult:
+    """
+    Container for metric computation results.
+
+    Stores both per-gene and aggregate values.
+    """
+    name: str
+    per_gene_values: np.ndarray  # Shape: (n_genes,)
+    gene_names: List[str]
+    aggregate_value: float
+    aggregate_method: str = "mean"  # mean, median, etc.
+    condition: Optional[str] = None
+    split: Optional[str] = None
+    metadata: Dict[str, Any] = field(default_factory=dict)
+
+    @property
+    def as_dict(self) -> Dict[str, Any]:
+        """Convert to dictionary for serialization."""
+        return {
+            "name": self.name,
+            "aggregate_value": float(self.aggregate_value),
+            "aggregate_method": self.aggregate_method,
+            "per_gene_mean": float(np.nanmean(self.per_gene_values)),
+            "per_gene_std": float(np.nanstd(self.per_gene_values)),
+            "per_gene_median": float(np.nanmedian(self.per_gene_values)),
+            "n_genes": len(self.gene_names),
+            "condition": self.condition,
+            "split": self.split,
+            **self.metadata
+        }
+
+    def top_genes(self, n: int = 10, ascending: bool = True) -> Dict[str, float]:
+        """Get top n genes by metric value."""
+        order = np.argsort(self.per_gene_values)
+        if not ascending:
+            order = order[::-1]
+        indices = order[:n]
+        return {self.gene_names[i]: float(self.per_gene_values[i]) for i in indices}
+
+
+class BaseMetric(ABC):
+    """
+    Abstract base class for all evaluation metrics.
+
+    Metrics can be computed per-gene (returning a vector) or as aggregates.
+    All metrics should inherit from this class.
+    """
+
+    def __init__(
+        self,
+        name: str,
+        description: str = "",
+        higher_is_better: bool = True,
+        requires_distribution: bool = False,
+    ):
+        """
+        Initialize metric.
+
+        Parameters
+        ----------
+        name : str
+            Unique identifier for the metric
+        description : str
+            Human-readable description
+        higher_is_better : bool
+            Whether higher values indicate better performance
+        requires_distribution : bool
+            Whether metric needs full distribution (not just means)
+        """
+        self.name = name
+        self.description = description
+        self.higher_is_better = higher_is_better
+        self.requires_distribution = requires_distribution
+
+    @abstractmethod
+    def compute_per_gene(
+        self,
+        real: np.ndarray,
+        generated: np.ndarray,
+    ) -> np.ndarray:
+        """
+        Compute metric for each gene.
+
+        Parameters
+        ----------
+        real : np.ndarray
+            Real data matrix, shape (n_samples_real, n_genes)
+        generated : np.ndarray
+            Generated data matrix, shape (n_samples_gen, n_genes)
+
+        Returns
+        -------
+        np.ndarray
+            Metric value per gene, shape (n_genes,)
+        """
+        pass
+
+    def compute_aggregate(
+        self,
+        per_gene_values: np.ndarray,
+        method: str = "mean",
+    ) -> float:
+        """
+        Aggregate per-gene values to single metric.
+
+        Parameters
+        ----------
+        per_gene_values : np.ndarray
+            Per-gene metric values
+        method : str
+            Aggregation method: "mean", "median", "std", "min", "max"
+
+        Returns
+        -------
+        float
+            Aggregated metric value
+        """
+        methods = {
+            "mean": np.nanmean,
+            "median": np.nanmedian,
+            "std": np.nanstd,
+            "min": np.nanmin,
+            "max": np.nanmax,
+        }
+        if method not in methods:
+            raise ValueError(f"Unknown aggregation method: {method}")
+        return float(methods[method](per_gene_values))
+
+    def compute(
+        self,
+        real: np.ndarray,
+        generated: np.ndarray,
+        gene_names: Optional[List[str]] = None,
+        aggregate_method: str = "mean",
+        condition: Optional[str] = None,
+        split: Optional[str] = None,
+    ) -> MetricResult:
+        """
+        Compute full metric result with per-gene and aggregate values.
+
+        Parameters
+        ----------
+        real : np.ndarray
+            Real data matrix, shape (n_samples_real, n_genes)
+        generated : np.ndarray
+            Generated data matrix, shape (n_samples_gen, n_genes)
+        gene_names : List[str], optional
+            Names of genes (columns)
+        aggregate_method : str
+            How to aggregate per-gene values
+        condition : str, optional
+            Condition identifier
+        split : str, optional
+            Split identifier (train/test)
+
+        Returns
+        -------
+        MetricResult
+            Complete metric result
+        """
+        n_genes = real.shape[1] if real.ndim > 1 else 1
+        if gene_names is None:
+            gene_names = [f"gene_{i}" for i in range(n_genes)]
+
+        per_gene = self.compute_per_gene(real, generated)
+        aggregate = self.compute_aggregate(per_gene, method=aggregate_method)
+
+        return MetricResult(
+            name=self.name,
+            per_gene_values=per_gene,
+            gene_names=gene_names,
+            aggregate_value=aggregate,
+            aggregate_method=aggregate_method,
+            condition=condition,
+            split=split,
+            metadata={
+                "higher_is_better": self.higher_is_better,
+                "description": self.description,
+            }
+        )
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(name='{self.name}')"
+
+
+class DistributionMetric(BaseMetric):
+    """
+    Base class for distribution-based metrics (Wasserstein, MMD, Energy).
+
+    These metrics require the full sample distributions, not just means.
+    """
+
+    def __init__(self, name: str, description: str = "", higher_is_better: bool = False):
+        super().__init__(
+            name=name,
+            description=description,
+            higher_is_better=higher_is_better,
+            requires_distribution=True,
+        )
+
+
+class CorrelationMetric(BaseMetric):
+    """
+    Base class for correlation-based metrics (Pearson, Spearman).
+
+    These compare mean profiles between real and generated data.
+    """
+
+    def __init__(self, name: str, description: str = ""):
+        super().__init__(
+            name=name,
+            description=description,
+            higher_is_better=True,
+            requires_distribution=False,
+        )
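To illustrate the contract these base classes define, a minimal custom-metric sketch; PerGeneMAE is illustrative only and not part of the package:

import numpy as np

class PerGeneMAE(BaseMetric):
    """Absolute difference between per-gene means of real and generated data."""

    def __init__(self):
        super().__init__(
            name="per_gene_mae",
            description="Absolute difference of per-gene means",
            higher_is_better=False,
            requires_distribution=False,
        )

    def compute_per_gene(self, real: np.ndarray, generated: np.ndarray) -> np.ndarray:
        # Must return one value per gene, shape (n_genes,), per the abstract interface.
        return np.abs(real.mean(axis=0) - generated.mean(axis=0))

rng = np.random.default_rng(0)
result = PerGeneMAE().compute(rng.normal(size=(100, 20)), rng.normal(size=(100, 20)))
print(result.aggregate_value)                   # mean over genes by default
print(result.top_genes(n=3, ascending=False))   # the three worst genes under this metric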