gengeneeval 0.1.0__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geneval/__init__.py +129 -0
- geneval/cli.py +333 -0
- geneval/config.py +141 -0
- geneval/core.py +41 -0
- geneval/data/__init__.py +23 -0
- geneval/data/gene_expression_datamodule.py +211 -0
- geneval/data/loader.py +437 -0
- geneval/evaluator.py +359 -0
- geneval/evaluators/__init__.py +4 -0
- geneval/evaluators/base_evaluator.py +178 -0
- geneval/evaluators/gene_expression_evaluator.py +218 -0
- geneval/metrics/__init__.py +65 -0
- geneval/metrics/base_metric.py +229 -0
- geneval/metrics/correlation.py +232 -0
- geneval/metrics/distances.py +516 -0
- geneval/metrics/metrics.py +134 -0
- geneval/models/__init__.py +1 -0
- geneval/models/base_model.py +53 -0
- geneval/results.py +334 -0
- geneval/testing.py +393 -0
- geneval/utils/__init__.py +1 -0
- geneval/utils/io.py +27 -0
- geneval/utils/preprocessing.py +82 -0
- geneval/visualization/__init__.py +38 -0
- geneval/visualization/plots.py +499 -0
- geneval/visualization/visualizer.py +1096 -0
- gengeneeval-0.1.0.dist-info/METADATA +172 -0
- gengeneeval-0.1.0.dist-info/RECORD +31 -0
- gengeneeval-0.1.0.dist-info/WHEEL +4 -0
- gengeneeval-0.1.0.dist-info/entry_points.txt +3 -0
- gengeneeval-0.1.0.dist-info/licenses/LICENSE +9 -0
geneval/results.py
ADDED
@@ -0,0 +1,334 @@
"""
Results container classes for evaluation outputs.

Provides structured storage for metrics, conditions, and visualization data.
"""
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any, Union
import numpy as np
import pandas as pd
import json
from pathlib import Path


@dataclass
class ConditionResult:
    """
    Results for a single condition (perturbation + covariates).
    """
    condition_key: str
    split: str
    n_real_samples: int
    n_generated_samples: int
    n_genes: int
    gene_names: List[str]
    metrics: Dict[str, "MetricResult"] = field(default_factory=dict)

    # Mean expression profiles
    real_mean: Optional[np.ndarray] = None
    generated_mean: Optional[np.ndarray] = None

    # Parsed condition components
    perturbation: Optional[str] = None
    covariates: Dict[str, str] = field(default_factory=dict)

    def add_metric(self, name: str, result: "MetricResult"):
        """Add a metric result."""
        self.metrics[name] = result

    def get_metric_value(self, name: str) -> Optional[float]:
        """Get aggregate value for a metric."""
        if name in self.metrics:
            return self.metrics[name].aggregate_value
        return None

    def get_per_gene_values(self, name: str) -> Optional[np.ndarray]:
        """Get per-gene values for a metric."""
        if name in self.metrics:
            return self.metrics[name].per_gene_values
        return None

    @property
    def summary(self) -> Dict[str, Any]:
        """Get summary dictionary."""
        result = {
            "condition_key": self.condition_key,
            "split": self.split,
            "perturbation": self.perturbation,
            "n_real_samples": self.n_real_samples,
            "n_generated_samples": self.n_generated_samples,
            "n_genes": self.n_genes,
        }
        result.update(self.covariates)

        for name, metric in self.metrics.items():
            result[name] = metric.aggregate_value

        return result


@dataclass
class SplitResult:
    """
    Results for a single split (train/test/all).
    """
    split_name: str
    conditions: Dict[str, ConditionResult] = field(default_factory=dict)
    aggregate_metrics: Dict[str, float] = field(default_factory=dict)

    def add_condition(self, condition: ConditionResult):
        """Add a condition result."""
        self.conditions[condition.condition_key] = condition

    def compute_aggregates(self):
        """Compute aggregate metrics across all conditions."""
        if not self.conditions:
            return

        # Collect all metric names
        metric_names = set()
        for cond in self.conditions.values():
            metric_names.update(cond.metrics.keys())

        # Compute mean across conditions for each metric
        for name in metric_names:
            values = []
            for cond in self.conditions.values():
                if name in cond.metrics:
                    values.append(cond.metrics[name].aggregate_value)
            if values:
                self.aggregate_metrics[f"{name}_mean"] = float(np.nanmean(values))
                self.aggregate_metrics[f"{name}_std"] = float(np.nanstd(values))
                self.aggregate_metrics[f"{name}_median"] = float(np.nanmedian(values))

    @property
    def n_conditions(self) -> int:
        return len(self.conditions)

    def to_dataframe(self) -> pd.DataFrame:
        """Convert condition results to DataFrame."""
        rows = [cond.summary for cond in self.conditions.values()]
        return pd.DataFrame(rows)


@dataclass
class EvaluationResult:
    """
    Complete evaluation results container.

    Stores results per split and provides serialization methods.
    """
    splits: Dict[str, SplitResult] = field(default_factory=dict)
    gene_names: List[str] = field(default_factory=list)
    condition_columns: List[str] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)

    # Paths to saved outputs
    output_dir: Optional[Path] = None

    def add_split(self, split: SplitResult):
        """Add a split result."""
        self.splits[split.split_name] = split

    def get_split(self, name: str) -> Optional[SplitResult]:
        """Get results for a specific split."""
        return self.splits.get(name)

    def get_all_conditions(self) -> List[ConditionResult]:
        """Get all condition results across splits."""
        conditions = []
        for split in self.splits.values():
            conditions.extend(split.conditions.values())
        return conditions

    def get_metric_summary(self, metric_name: str) -> Dict[str, Dict[str, float]]:
        """
        Get summary of a metric across all splits.

        Returns dict: split_name -> {mean, std, median}
        """
        summary = {}
        for split_name, split in self.splits.items():
            values = []
            for cond in split.conditions.values():
                if metric_name in cond.metrics:
                    values.append(cond.metrics[metric_name].aggregate_value)
            if values:
                summary[split_name] = {
                    "mean": float(np.nanmean(values)),
                    "std": float(np.nanstd(values)),
                    "median": float(np.nanmedian(values)),
                    "min": float(np.nanmin(values)),
                    "max": float(np.nanmax(values)),
                    "n_conditions": len(values),
                }
        return summary

    def to_dataframe(self, include_split: bool = True) -> pd.DataFrame:
        """
        Convert all results to a single DataFrame.

        Parameters
        ----------
        include_split : bool
            Whether to include split column

        Returns
        -------
        pd.DataFrame
            DataFrame with one row per condition
        """
        dfs = []
        for split_name, split in self.splits.items():
            df = split.to_dataframe()
            if include_split:
                df["split"] = split_name
            dfs.append(df)

        if not dfs:
            return pd.DataFrame()

        return pd.concat(dfs, ignore_index=True)

    def to_per_gene_dataframe(self, metric_name: str) -> pd.DataFrame:
        """
        Get per-gene metric values as DataFrame.

        Parameters
        ----------
        metric_name : str
            Name of metric to extract

        Returns
        -------
        pd.DataFrame
            DataFrame with genes as rows, conditions as columns
        """
        data = {}
        for split in self.splits.values():
            for cond_key, cond in split.conditions.items():
                if metric_name in cond.metrics:
                    col_name = f"{split.split_name}_{cond_key}"
                    data[col_name] = cond.metrics[metric_name].per_gene_values

        if not data:
            return pd.DataFrame()

        df = pd.DataFrame(data, index=self.gene_names)
        return df

    def summary(self) -> Dict[str, Any]:
        """Get comprehensive summary."""
        result = {
            "n_splits": len(self.splits),
            "n_genes": len(self.gene_names),
            "condition_columns": self.condition_columns,
            "splits": {},
        }

        for split_name, split in self.splits.items():
            split.compute_aggregates()
            result["splits"][split_name] = {
                "n_conditions": split.n_conditions,
                "aggregates": split.aggregate_metrics,
            }

        return result

    def save(self, output_dir: Union[str, Path]):
        """
        Save results to directory.

        Saves:
        - summary.json: Aggregate metrics and metadata
        - results.csv: Per-condition metrics
        - per_gene_*.csv: Per-gene metrics for each metric type
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        self.output_dir = output_dir

        # Save summary
        summary = self.summary()
        summary["metadata"] = self.metadata

        with open(output_dir / "summary.json", "w") as f:
            json.dump(summary, f, indent=2, default=str)

        # Save condition-level results
        df = self.to_dataframe()
        if not df.empty:
            df.to_csv(output_dir / "results.csv", index=False)

        # Save per-gene metrics
        metric_names = set()
        for split in self.splits.values():
            for cond in split.conditions.values():
                metric_names.update(cond.metrics.keys())

        for metric_name in metric_names:
            df_gene = self.to_per_gene_dataframe(metric_name)
            if not df_gene.empty:
                df_gene.to_csv(output_dir / f"per_gene_{metric_name}.csv")

        return output_dir

    @classmethod
    def load(cls, output_dir: Union[str, Path]) -> "EvaluationResult":
        """
        Load results from directory.

        Note: Currently loads summary only, not full per-gene data.
        """
        output_dir = Path(output_dir)

        with open(output_dir / "summary.json") as f:
            summary = json.load(f)

        result = cls(
            gene_names=[],
            condition_columns=summary.get("condition_columns", []),
            metadata=summary.get("metadata", {}),
        )
        result.output_dir = output_dir

        # Load condition-level results if available
        results_path = output_dir / "results.csv"
        if results_path.exists():
            df = pd.read_csv(results_path)
            # Reconstruct splits and conditions from DataFrame
            for split_name in df["split"].unique() if "split" in df.columns else ["all"]:
                split_df = df[df["split"] == split_name] if "split" in df.columns else df
                split_result = SplitResult(split_name=split_name)

                for _, row in split_df.iterrows():
                    cond = ConditionResult(
                        condition_key=row.get("condition_key", ""),
                        split=split_name,
                        n_real_samples=row.get("n_real_samples", 0),
                        n_generated_samples=row.get("n_generated_samples", 0),
                        n_genes=row.get("n_genes", 0),
                        gene_names=[],
                        perturbation=row.get("perturbation"),
                    )
                    split_result.add_condition(cond)

                result.add_split(split_result)

        return result

    def __repr__(self) -> str:
        n_conds = sum(s.n_conditions for s in self.splits.values())
        return (
            f"EvaluationResult(n_splits={len(self.splits)}, "
            f"n_conditions={n_conds}, n_genes={len(self.gene_names)})"
        )


# Import MetricResult here to avoid circular import
from .metrics.base_metric import MetricResult

# Update forward references
ConditionResult.__annotations__["metrics"] = Dict[str, MetricResult]
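The classes above are plain data containers, so a short usage sketch may help show how an evaluator would be expected to wire them together and serialize the results. This is an illustration, not code shipped in the package: the MetricResult constructor arguments are assumptions (results.py relies only on its aggregate_value and per_gene_values attributes), the metric values are placeholders, and "eval_outputs" is an arbitrary output path.

import numpy as np
from geneval.results import ConditionResult, SplitResult, EvaluationResult
from geneval.metrics.base_metric import MetricResult  # constructor keywords below are assumed

genes = ["GATA1", "TP53", "MYC"]

# One evaluated condition: per-gene Pearson r plus its mean across genes.
# NOTE: MetricResult(...) keyword names are illustrative only.
per_gene_r = np.array([0.91, 0.88, 0.95])
pearson = MetricResult(
    name="pearson",
    aggregate_value=float(per_gene_r.mean()),
    per_gene_values=per_gene_r,
)

cond = ConditionResult(
    condition_key="drugA_cellB",
    split="test",
    n_real_samples=120,
    n_generated_samples=100,
    n_genes=len(genes),
    gene_names=genes,
    perturbation="drugA",
    covariates={"cell_type": "cellB"},
)
cond.add_metric("pearson", pearson)

split = SplitResult(split_name="test")
split.add_condition(cond)
split.compute_aggregates()  # fills pearson_mean / pearson_std / pearson_median

results = EvaluationResult(
    gene_names=genes,
    condition_columns=["perturbation", "cell_type"],
)
results.add_split(split)

print(results)                 # EvaluationResult(n_splits=1, n_conditions=1, n_genes=3)
print(results.to_dataframe())  # one row per condition, including a "pearson" column
results.save("eval_outputs")   # writes summary.json, results.csv, per_gene_pearson.csv

A later session could then call EvaluationResult.load("eval_outputs") to recover the condition-level table, with the caveat noted in the load() docstring that per-gene values are not restored.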