gengeneeval 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geneval/__init__.py +129 -0
- geneval/cli.py +333 -0
- geneval/config.py +141 -0
- geneval/core.py +41 -0
- geneval/data/__init__.py +23 -0
- geneval/data/gene_expression_datamodule.py +211 -0
- geneval/data/loader.py +437 -0
- geneval/evaluator.py +359 -0
- geneval/evaluators/__init__.py +4 -0
- geneval/evaluators/base_evaluator.py +178 -0
- geneval/evaluators/gene_expression_evaluator.py +218 -0
- geneval/metrics/__init__.py +65 -0
- geneval/metrics/base_metric.py +229 -0
- geneval/metrics/correlation.py +232 -0
- geneval/metrics/distances.py +516 -0
- geneval/metrics/metrics.py +134 -0
- geneval/models/__init__.py +1 -0
- geneval/models/base_model.py +53 -0
- geneval/results.py +334 -0
- geneval/testing.py +393 -0
- geneval/utils/__init__.py +1 -0
- geneval/utils/io.py +27 -0
- geneval/utils/preprocessing.py +82 -0
- geneval/visualization/__init__.py +38 -0
- geneval/visualization/plots.py +499 -0
- geneval/visualization/visualizer.py +1096 -0
- gengeneeval-0.1.0.dist-info/METADATA +172 -0
- gengeneeval-0.1.0.dist-info/RECORD +31 -0
- gengeneeval-0.1.0.dist-info/WHEEL +4 -0
- gengeneeval-0.1.0.dist-info/entry_points.txt +3 -0
- gengeneeval-0.1.0.dist-info/licenses/LICENSE +9 -0
|
@@ -0,0 +1,218 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, Optional, Dict, Any
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pandas as pd
|
|
7
|
+
from anndata import AnnData
|
|
8
|
+
import scipy.stats as sstats
|
|
9
|
+
from sklearn.metrics import mean_squared_error
|
|
10
|
+
|
|
11
|
+
from ..metrics.metrics import compute_metrics
|
|
12
|
+
from ..utils.preprocessing import to_dense as _to_dense
|
|
13
|
+
from .base_evaluator import BaseEvaluator
|
|
14
|
+
from ..visualization import EvaluationPlotter
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from ..data.gene_expression_datamodule import GeneExpressionDataModule
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class GeneExpressionEvaluator(BaseEvaluator):
    """
    Evaluator for gene expression data.

    Compares generated expression against the TEST split of the real data held
    by the data module. Comparisons are stratified by perturbation plus any
    covariates present in both AnnData objects (``cell_type`` and the data
    module's ``condition_keys``).
    """

    def __init__(self, data: "GeneExpressionDataModule", output: AnnData):
        super().__init__(data, output)

    @staticmethod
    def _lookup_deg(DEG: Dict[str, Any], key: str) -> Any:
        """Return the DEG entry for ``key``, falling back to its perturbation name.

        The full composite key is tried first. When it is missing or empty, the
        text before the first ``"####"`` (the perturbation) is tried.

        Emptiness is tested with ``len()`` rather than truthiness. Entries may be
        pandas DataFrames/Series or NumPy arrays, whose truth value is ambiguous
        and raises ``ValueError`` in a plain ``or`` expression.
        """
        deg = DEG.get(key)
        if deg is None or (hasattr(deg, "__len__") and len(deg) == 0):
            deg = DEG.get(key.split("####", 1)[0])
        return deg

    def evaluate(
        self,
        delta: bool = False,
        plot: bool = False,
        DEG: Optional[Dict[str, Any]] = None,
        save_dir: Optional[str] = None,
        max_panels: int = 12,
        dpi: int = 150,
    ):
        """
        Run evaluation. If plot=True, returns and optionally saves figures.

        Parameters
        ----------
        delta : bool
            If True, per-stratum control means are subtracted from the real
            data. Otherwise they are added to the generated data. (This
            presumably means the generated output is expressed relative to
            control; confirm against the model.)
        plot : bool
            Build scatter, residual and metric-bar figures.
        DEG : dict, optional
            Maps a composite key ``"pert####cov..."`` or a bare perturbation
            name to a gene list, a dict with a ``"names"`` entry, or a
            table with a ``"names"`` column. When given, all metrics except
            the plotted means/residuals are restricted to those genes.
        save_dir : str, optional
            Directory where figures are written as PNG (only when ``plot``).
        max_panels : int
            Maximum number of panels in the scatter grid.
        dpi : int
            Resolution of saved figures.

        Returns
        -------
        dict
            The keys ``pearson_corr``, ``spearman_corr``, ``mse_val``, ``w1``,
            ``w2``, ``mmd`` and ``energy`` each map to a list of one-entry
            ``{key: value}`` dicts, in sorted key order. ``figures`` holds
            the figures dict, which is populated only when ``plot`` is True.

        Raises
        ------
        KeyError
            If the generated data has neither a ``perturbation`` column nor
            the data module's perturbation column.
        ValueError
            If no (perturbation + covariates) stratum is shared between the
            real TEST split and the generated data.
        """
        data = self.data.gene_expression_dataset.adata.copy()
        generated = self.output.copy()
        data, generated = self._align_varnames_like(data, generated)

        pert_col = self.data.perturbation_key
        split_key = self.data.split_key
        control = self.data.control

        # Stratify only by covariates that both objects carry.
        order_cols = []
        if "cell_type" in data.obs.columns and "cell_type" in generated.obs.columns:
            order_cols.append("cell_type")
        for c in (getattr(self.data, "condition_keys", None) or []):
            if c in data.obs.columns and c in generated.obs.columns:
                order_cols.append(c)

        # Baseline handling. Control means come from the full real data (all
        # splits), computed once before either matrix is modified.
        baseline = self._compute_control_means(data, pert_col, control, strata_cols=order_cols)
        if delta:
            data.X = self._apply_baseline_per_strata(
                data.X, data.obs, baseline, strata_cols=order_cols, mode="subtract"
            )
        else:
            generated.X = self._apply_baseline_per_strata(
                generated.X, generated.obs, baseline, strata_cols=order_cols, mode="add"
            )

        is_test = (data.obs[split_key].astype(str) == "test").to_numpy()
        test_data = data[is_test].copy()

        if "perturbation" not in generated.obs.columns and pert_col not in generated.obs.columns:
            raise KeyError("'perturbation' column not found in generated data.")
        if pert_col not in generated.obs.columns:
            # Only reachable when a generic "perturbation" column exists (checked above).
            generated.obs[pert_col] = generated.obs["perturbation"].astype(test_data.obs[pert_col].dtype)

        group_cols = [pert_col] + order_cols

        def _means_masks(adata):
            """Per-stratum boolean row masks and mean expression vectors."""
            means, masks = {}, {}
            obs_str = adata.obs[group_cols].astype(str)
            # Convert each grouping column once instead of once per stratum.
            col_values = {c: obs_str[c].to_numpy() for c in group_cols}
            for _, row in obs_str.drop_duplicates().iterrows():
                key = "####".join(row[c] for c in group_cols)
                mask = col_values[pert_col] == row[pert_col]
                for c in order_cols:
                    mask &= col_values[c] == row[c]
                if mask.any():
                    masks[key] = mask
                    means[key] = _to_dense(adata[mask].X).mean(axis=0)
            return means, masks

        real_means, real_masks = _means_masks(test_data)
        gen_means, gen_masks = _means_masks(generated)
        common = sorted(set(real_means).intersection(gen_means))
        if not common:
            raise ValueError("No common (pert + covariates) between real TEST and generated.")

        # Metric accumulators: lists of one-entry {key: value} dicts (public result format).
        w1, w2, mmd, energy = [], [], [], []
        pearson_corr, pearson_p = [], []
        spearman_corr, spearman_p = [], []
        mse_val = []
        # Raw distributional metric values per key, used by the metrics bar plot.
        dist_per_key = {}

        vnames = pd.Index(test_data.var_names.astype(str))

        # For plotting
        plot_means = {}
        residuals_per_key = {}
        stats_per_key = {}
        deg_map = {}

        def maybe_filter(om, gm, td, gd, key):
            """Restrict means/matrices to the DEG genes for ``key`` when any match."""
            if DEG is None:
                return om, gm, td, gd
            deg = self._lookup_deg(DEG, key)
            if deg is None:
                return om, gm, td, gd
            if isinstance(deg, dict):
                names = deg.get("names", None)
            elif hasattr(deg, "columns") and "names" in deg.columns:
                names = deg["names"]
            else:
                names = deg
            if hasattr(names, "tolist"):
                names = names.tolist()
            if names is None or len(names) == 0:
                return om, gm, td, gd
            mask = np.asarray(vnames.isin([str(x) for x in names]), dtype=bool)
            if not mask.any():
                return om, gm, td, gd
            return om[mask], gm[mask], td[:, mask], gd[:, mask]

        for key in common:
            td = _to_dense(test_data.X[real_masks[key], :])
            gd = _to_dense(generated.X[gen_masks[key], :])
            om, gm = real_means[key], gen_means[key]
            om_f, gm_f, td_f, gd_f = maybe_filter(om, gm, td, gd, key)

            # Distributional metrics on the (optionally DEG-filtered) cell matrices.
            dist = {
                name: compute_metrics(td_f, gd_f, name)
                for name in ("w1", "w2", "mmd", "energy")
            }
            dist_per_key[key] = dist
            w1.append({key: dist["w1"]})
            w2.append({key: dist["w2"]})
            mmd.append({key: dist["mmd"]})
            energy.append({key: dist["energy"]})

            # Mean-wise metrics on the (optionally DEG-filtered) mean profiles.
            pc, pcp = sstats.pearsonr(om_f, gm_f)
            sc, scp = sstats.spearmanr(om_f, gm_f)
            pearson_corr.append({key: pc})
            pearson_p.append({key: pcp})
            spearman_corr.append({key: sc})
            spearman_p.append({key: scp})
            mse = mean_squared_error(om_f, gm_f)
            mse_val.append({key: mse})

            # Plots show the unfiltered means; DEGs are passed separately for highlighting.
            plot_means[key] = (om, gm, vnames.tolist())
            residuals_per_key[key] = gm - om
            stats_per_key[key] = {"pearson": float(pc), "spearman": float(sc), "mse": float(mse)}
            if DEG is not None:
                deg_map[key] = self._lookup_deg(DEG, key)

        def _m(lst):
            """Mean of the values in a list of one-entry dicts (NaN if empty)."""
            return float("nan") if not lst else float(np.mean([list(d.values())[0] for d in lst]))

        print(f"Mean Pearson: {_m(pearson_corr):.4f} (p={_m(pearson_p):.4g})")
        print(f"Mean Spearman: {_m(spearman_corr):.4f} (p={_m(spearman_p):.4g})")
        print(f"Mean MSE: {_m(mse_val):.4f}")
        print(f"Wasserstein-1: {_m(w1):.4f}")
        print(f"Wasserstein-2: {_m(w2):.4f}")
        print(f"MMD: {_m(mmd):.4f}")
        print(f"Energy: {_m(energy):.4f}")

        results = dict(
            pearson_corr=pearson_corr,
            spearman_corr=spearman_corr,
            mse_val=mse_val,
            w1=w1,
            w2=w2,
            mmd=mmd,
            energy=energy,
        )

        # Plotting
        figures = {}
        if plot:
            plotter = EvaluationPlotter()
            # scatter grid
            fig_scatter = plotter.scatter_means_grid(
                data=plot_means,
                stats=stats_per_key,
                deg_map=deg_map if deg_map else None,
                max_panels=max_panels,
            )
            figures["scatter_means"] = fig_scatter

            # residual distributions
            fig_residuals = plotter.residuals_violin(residuals=residuals_per_key)
            figures["residuals"] = fig_residuals

            # metrics bar: combine main metrics (direct per-key lookup, no list scan)
            metrics_pk = {
                k: {
                    "pearson": stats_per_key[k]["pearson"],
                    "spearman": stats_per_key[k]["spearman"],
                    "MSE": stats_per_key[k]["mse"],
                    "W1": float(dist_per_key[k]["w1"]),
                    "W2": float(dist_per_key[k]["w2"]),
                    "MMD": float(dist_per_key[k]["mmd"]),
                    "Energy": float(dist_per_key[k]["energy"]),
                }
                for k in common
            }
            fig_metrics = plotter.metrics_bar(metrics_per_key=metrics_pk)
            figures["metrics_bar"] = fig_metrics

            if save_dir:
                import os
                os.makedirs(save_dir, exist_ok=True)
                fig_scatter.savefig(os.path.join(save_dir, "scatter_means.png"), dpi=dpi, bbox_inches="tight")
                fig_residuals.savefig(os.path.join(save_dir, "residuals.png"), dpi=dpi, bbox_inches="tight")
                fig_metrics.savefig(os.path.join(save_dir, "metrics_bar.png"), dpi=dpi, bbox_inches="tight")

        results["figures"] = figures

        return results
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Metrics module for gene expression evaluation.
|
|
3
|
+
|
|
4
|
+
Provides per-gene and aggregate metrics for comparing distributions:
|
|
5
|
+
- Correlation metrics (Pearson, Spearman)
|
|
6
|
+
- Distribution distances (Wasserstein, MMD, Energy)
|
|
7
|
+
- Multivariate distances
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from .base_metric import (
|
|
11
|
+
BaseMetric,
|
|
12
|
+
MetricResult,
|
|
13
|
+
DistributionMetric,
|
|
14
|
+
CorrelationMetric,
|
|
15
|
+
)
|
|
16
|
+
from .correlation import (
|
|
17
|
+
PearsonCorrelation,
|
|
18
|
+
SpearmanCorrelation,
|
|
19
|
+
MeanPearsonCorrelation,
|
|
20
|
+
MeanSpearmanCorrelation,
|
|
21
|
+
)
|
|
22
|
+
from .distances import (
|
|
23
|
+
Wasserstein1Distance,
|
|
24
|
+
Wasserstein2Distance,
|
|
25
|
+
MMDDistance,
|
|
26
|
+
EnergyDistance,
|
|
27
|
+
MultivariateWasserstein,
|
|
28
|
+
MultivariateMMD,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
# All available metrics
|
|
32
|
+
ALL_METRICS = [
    # correlation of mean profiles
    PearsonCorrelation,
    SpearmanCorrelation,
    MeanPearsonCorrelation,
    MeanSpearmanCorrelation,
    # distribution distances
    Wasserstein1Distance,
    Wasserstein2Distance,
    MMDDistance,
    EnergyDistance,
    # multivariate distances
    MultivariateWasserstein,
    MultivariateMMD,
]

# Public API of this package. It is kept as a literal list so that static
# analysers and IDEs can read it without executing the module.
__all__ = [
    # base classes
    "BaseMetric", "MetricResult", "DistributionMetric", "CorrelationMetric",
    # correlation metrics
    "PearsonCorrelation", "SpearmanCorrelation",
    "MeanPearsonCorrelation", "MeanSpearmanCorrelation",
    # distance metrics
    "Wasserstein1Distance", "Wasserstein2Distance", "MMDDistance", "EnergyDistance",
    "MultivariateWasserstein", "MultivariateMMD",
    # collections
    "ALL_METRICS",
]
|
|
@@ -0,0 +1,229 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Base metric classes for gene expression evaluation.
|
|
3
|
+
|
|
4
|
+
Provides abstract interface for all metrics with per-gene and aggregate computation.
|
|
5
|
+
"""
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from abc import ABC, abstractmethod
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from typing import Dict, List, Optional, Union, Any, Callable
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class MetricResult:
    """
    Container for metric computation results.

    Stores both per-gene and aggregate values, plus optional context
    (condition, split) and free-form metadata.
    """
    name: str
    per_gene_values: np.ndarray  # Shape: (n_genes,)
    gene_names: List[str]
    aggregate_value: float
    aggregate_method: str = "mean"  # mean, median, etc.
    condition: Optional[str] = None
    split: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def as_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for serialization.

        This is a property, so access it as ``result.as_dict`` without calling
        it. Per-gene summaries ignore NaN values. Metadata entries are merged
        last, so a metadata key overrides a base key with the same name.
        """
        return {
            "name": self.name,
            "aggregate_value": float(self.aggregate_value),
            "aggregate_method": self.aggregate_method,
            "per_gene_mean": float(np.nanmean(self.per_gene_values)),
            "per_gene_std": float(np.nanstd(self.per_gene_values)),
            "per_gene_median": float(np.nanmedian(self.per_gene_values)),
            "n_genes": len(self.gene_names),
            "condition": self.condition,
            "split": self.split,
            **self.metadata
        }

    def top_genes(self, n: int = 10, ascending: bool = True) -> Dict[str, float]:
        """Get top n genes by metric value.

        Parameters
        ----------
        n : int
            Number of genes to return.
        ascending : bool
            If True, return the lowest values; if False, return the highest.

        Returns
        -------
        dict
            Gene name -> metric value, in ranked order.

        Genes whose value is NaN are skipped. ``np.argsort`` puts NaNs last, so
        reversing that order for ``ascending=False`` would otherwise rank NaN
        genes as the best ones.
        """
        values = np.asarray(self.per_gene_values, dtype=float).ravel()
        valid = np.flatnonzero(~np.isnan(values))
        order = valid[np.argsort(values[valid])]
        if not ascending:
            order = order[::-1]
        return {self.gene_names[i]: float(values[i]) for i in order[:n]}
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class BaseMetric(ABC):
    """
    Abstract interface shared by every evaluation metric.

    Subclasses implement ``compute_per_gene``, which yields one value per gene.
    ``compute`` reduces that vector with ``compute_aggregate`` and packages the
    result into a ``MetricResult``.
    """

    def __init__(
        self,
        name: str,
        description: str = "",
        higher_is_better: bool = True,
        requires_distribution: bool = False,
    ):
        """
        Store the metric's descriptive attributes.

        Parameters
        ----------
        name : str
            Unique identifier for the metric
        description : str
            Human-readable description
        higher_is_better : bool
            Whether larger values mean better agreement
        requires_distribution : bool
            Whether the metric needs full sample distributions (not only means)
        """
        self.name, self.description = name, description
        self.higher_is_better = higher_is_better
        self.requires_distribution = requires_distribution

    @abstractmethod
    def compute_per_gene(
        self,
        real: np.ndarray,
        generated: np.ndarray,
    ) -> np.ndarray:
        """
        Evaluate the metric independently for every gene.

        Parameters
        ----------
        real : np.ndarray
            Real data matrix, shape (n_samples_real, n_genes)
        generated : np.ndarray
            Generated data matrix, shape (n_samples_gen, n_genes)

        Returns
        -------
        np.ndarray
            One metric value per gene, shape (n_genes,)
        """
        ...

    def compute_aggregate(
        self,
        per_gene_values: np.ndarray,
        method: str = "mean",
    ) -> float:
        """
        Reduce per-gene values to a single NaN-ignoring summary number.

        Parameters
        ----------
        per_gene_values : np.ndarray
            Per-gene metric values
        method : str
            One of "mean", "median", "std", "min", "max"

        Returns
        -------
        float
            The reduced value

        Raises
        ------
        ValueError
            If ``method`` is not one of the supported names.
        """
        reducer = {
            "mean": np.nanmean,
            "median": np.nanmedian,
            "std": np.nanstd,
            "min": np.nanmin,
            "max": np.nanmax,
        }.get(method)
        if reducer is None:
            raise ValueError(f"Unknown aggregation method: {method}")
        return float(reducer(per_gene_values))

    def compute(
        self,
        real: np.ndarray,
        generated: np.ndarray,
        gene_names: Optional[List[str]] = None,
        aggregate_method: str = "mean",
        condition: Optional[str] = None,
        split: Optional[str] = None,
    ) -> MetricResult:
        """
        Run the per-gene computation, aggregate it, and wrap everything up.

        Parameters
        ----------
        real : np.ndarray
            Real data matrix, shape (n_samples_real, n_genes)
        generated : np.ndarray
            Generated data matrix, shape (n_samples_gen, n_genes)
        gene_names : List[str], optional
            Column labels. Defaults to ``gene_0``, ``gene_1``, and so on.
        aggregate_method : str
            Reduction passed to ``compute_aggregate``
        condition : str, optional
            Condition identifier stored on the result
        split : str, optional
            Split identifier (train/test) stored on the result

        Returns
        -------
        MetricResult
            Per-gene values, the aggregate, and descriptive metadata
        """
        # A 1-D input is treated as a single gene.
        n_genes = 1 if real.ndim <= 1 else real.shape[1]
        if gene_names is None:
            gene_names = ["gene_{}".format(idx) for idx in range(n_genes)]

        values = self.compute_per_gene(real, generated)
        summary = self.compute_aggregate(values, method=aggregate_method)

        meta = {
            "higher_is_better": self.higher_is_better,
            "description": self.description,
        }
        return MetricResult(
            name=self.name,
            per_gene_values=values,
            gene_names=gene_names,
            aggregate_value=summary,
            aggregate_method=aggregate_method,
            condition=condition,
            split=split,
            metadata=meta,
        )

    def __repr__(self) -> str:
        return "{}(name='{}')".format(self.__class__.__name__, self.name)
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
class DistributionMetric(BaseMetric):
    """
    Common parent for distribution-based metrics (Wasserstein, MMD, Energy).

    These compare whole sample distributions rather than mean profiles, so
    ``requires_distribution`` is always True. ``higher_is_better`` defaults
    to False.
    """

    def __init__(self, name: str, description: str = "", higher_is_better: bool = False):
        super().__init__(
            requires_distribution=True,
            higher_is_better=higher_is_better,
            description=description,
            name=name,
        )
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
class CorrelationMetric(BaseMetric):
    """
    Common parent for correlation-based metrics (Pearson, Spearman).

    These compare mean profiles between real and generated data. They never
    need full distributions, and larger values are always better.
    """

    def __init__(self, name: str, description: str = ""):
        super().__init__(
            requires_distribution=False,
            higher_is_better=True,
            description=description,
            name=name,
        )
|