gengeneeval 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geneval/results.py ADDED
@@ -0,0 +1,334 @@
+ """
+ Results container classes for evaluation outputs.
+
+ Provides structured storage for metrics, conditions, and visualization data.
+ """
+ from __future__ import annotations
+
+ from dataclasses import dataclass, field
+ from typing import Dict, List, Optional, Any, Union
+ import numpy as np
+ import pandas as pd
+ import json
+ from pathlib import Path
+
+
+ @dataclass
+ class ConditionResult:
+     """
+     Results for a single condition (perturbation + covariates).
+     """
+     condition_key: str
+     split: str
+     n_real_samples: int
+     n_generated_samples: int
+     n_genes: int
+     gene_names: List[str]
+     metrics: Dict[str, "MetricResult"] = field(default_factory=dict)
+
+     # Mean expression profiles
+     real_mean: Optional[np.ndarray] = None
+     generated_mean: Optional[np.ndarray] = None
+
+     # Parsed condition components
+     perturbation: Optional[str] = None
+     covariates: Dict[str, str] = field(default_factory=dict)
+
+     def add_metric(self, name: str, result: "MetricResult"):
+         """Add a metric result."""
+         self.metrics[name] = result
+
+     def get_metric_value(self, name: str) -> Optional[float]:
+         """Get aggregate value for a metric."""
+         if name in self.metrics:
+             return self.metrics[name].aggregate_value
+         return None
+
+     def get_per_gene_values(self, name: str) -> Optional[np.ndarray]:
+         """Get per-gene values for a metric."""
+         if name in self.metrics:
+             return self.metrics[name].per_gene_values
+         return None
+
+     @property
+     def summary(self) -> Dict[str, Any]:
+         """Get summary dictionary."""
+         result = {
+             "condition_key": self.condition_key,
+             "split": self.split,
+             "perturbation": self.perturbation,
+             "n_real_samples": self.n_real_samples,
+             "n_generated_samples": self.n_generated_samples,
+             "n_genes": self.n_genes,
+         }
+         result.update(self.covariates)
+
+         for name, metric in self.metrics.items():
+             result[name] = metric.aggregate_value
+
+         return result
+
+
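A minimal usage sketch for `ConditionResult` follows. `MetricResult` is defined in `geneval/metrics/base_metric.py` and is not part of this diff, so a `SimpleNamespace` stands in for it here; this module only reads the `aggregate_value` and `per_gene_values` attributes. The condition key, gene names, and numbers are hypothetical.

```python
from types import SimpleNamespace

import numpy as np

from geneval.results import ConditionResult

# Hypothetical condition: IFNG perturbation in one cell type, 3 genes.
cond = ConditionResult(
    condition_key="IFNG|cellA",
    split="test",
    n_real_samples=120,
    n_generated_samples=100,
    n_genes=3,
    gene_names=["STAT1", "IRF1", "GBP1"],
    perturbation="IFNG",
    covariates={"cell_type": "cellA"},
)

# Stand-in for MetricResult: only the attributes read by results.py are provided.
pearson = SimpleNamespace(
    aggregate_value=0.83,
    per_gene_values=np.array([0.90, 0.80, 0.79]),
)
cond.add_metric("pearson", pearson)

print(cond.get_metric_value("pearson"))   # 0.83
print(cond.summary["pearson"])            # 0.83, alongside covariates and sample counts
```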
+ @dataclass
+ class SplitResult:
+     """
+     Results for a single split (train/test/all).
+     """
+     split_name: str
+     conditions: Dict[str, ConditionResult] = field(default_factory=dict)
+     aggregate_metrics: Dict[str, float] = field(default_factory=dict)
+
+     def add_condition(self, condition: ConditionResult):
+         """Add a condition result."""
+         self.conditions[condition.condition_key] = condition
+
+     def compute_aggregates(self):
+         """Compute aggregate metrics across all conditions."""
+         if not self.conditions:
+             return
+
+         # Collect all metric names
+         metric_names = set()
+         for cond in self.conditions.values():
+             metric_names.update(cond.metrics.keys())
+
+         # Compute mean across conditions for each metric
+         for name in metric_names:
+             values = []
+             for cond in self.conditions.values():
+                 if name in cond.metrics:
+                     values.append(cond.metrics[name].aggregate_value)
+             if values:
+                 self.aggregate_metrics[f"{name}_mean"] = float(np.nanmean(values))
+                 self.aggregate_metrics[f"{name}_std"] = float(np.nanstd(values))
+                 self.aggregate_metrics[f"{name}_median"] = float(np.nanmedian(values))
+
+     @property
+     def n_conditions(self) -> int:
+         return len(self.conditions)
+
+     def to_dataframe(self) -> pd.DataFrame:
+         """Convert condition results to DataFrame."""
+         rows = [cond.summary for cond in self.conditions.values()]
+         return pd.DataFrame(rows)
+
+
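`compute_aggregates()` collapses the per-condition aggregate values into `<metric>_mean`, `<metric>_std`, and `<metric>_median` entries. A short sketch, reusing the same `SimpleNamespace` stand-in for `MetricResult` and hypothetical condition keys:

```python
from types import SimpleNamespace

import numpy as np

from geneval.results import ConditionResult, SplitResult

split = SplitResult(split_name="test")

# Two hypothetical conditions, each carrying a "pearson" metric.
for key, value in [("IFNG|cellA", 0.83), ("TNF|cellA", 0.71)]:
    cond = ConditionResult(
        condition_key=key,
        split="test",
        n_real_samples=100,
        n_generated_samples=100,
        n_genes=3,
        gene_names=["STAT1", "IRF1", "GBP1"],
        perturbation=key.split("|")[0],
    )
    cond.add_metric(
        "pearson",
        SimpleNamespace(aggregate_value=value, per_gene_values=np.full(3, value)),
    )
    split.add_condition(cond)

split.compute_aggregates()
print(split.aggregate_metrics["pearson_mean"])   # ~0.77
print(split.n_conditions)                        # 2
print(split.to_dataframe()[["condition_key", "pearson"]])
```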
+ @dataclass
+ class EvaluationResult:
+     """
+     Complete evaluation results container.
+
+     Stores results per split and provides serialization methods.
+     """
+     splits: Dict[str, SplitResult] = field(default_factory=dict)
+     gene_names: List[str] = field(default_factory=list)
+     condition_columns: List[str] = field(default_factory=list)
+     metadata: Dict[str, Any] = field(default_factory=dict)
+
+     # Paths to saved outputs
+     output_dir: Optional[Path] = None
+
+     def add_split(self, split: SplitResult):
+         """Add a split result."""
+         self.splits[split.split_name] = split
+
+     def get_split(self, name: str) -> Optional[SplitResult]:
+         """Get results for a specific split."""
+         return self.splits.get(name)
+
+     def get_all_conditions(self) -> List[ConditionResult]:
+         """Get all condition results across splits."""
+         conditions = []
+         for split in self.splits.values():
+             conditions.extend(split.conditions.values())
+         return conditions
+
+     def get_metric_summary(self, metric_name: str) -> Dict[str, Dict[str, float]]:
+         """
+         Get summary of a metric across all splits.
+
+         Returns dict: split_name -> {mean, std, median, min, max, n_conditions}
+         """
+         summary = {}
+         for split_name, split in self.splits.items():
+             values = []
+             for cond in split.conditions.values():
+                 if metric_name in cond.metrics:
+                     values.append(cond.metrics[metric_name].aggregate_value)
+             if values:
+                 summary[split_name] = {
+                     "mean": float(np.nanmean(values)),
+                     "std": float(np.nanstd(values)),
+                     "median": float(np.nanmedian(values)),
+                     "min": float(np.nanmin(values)),
+                     "max": float(np.nanmax(values)),
+                     "n_conditions": len(values),
+                 }
+         return summary
+
+     def to_dataframe(self, include_split: bool = True) -> pd.DataFrame:
+         """
+         Convert all results to a single DataFrame.
+
+         Parameters
+         ----------
+         include_split : bool
+             Whether to include split column
+
+         Returns
+         -------
+         pd.DataFrame
+             DataFrame with one row per condition
+         """
+         dfs = []
+         for split_name, split in self.splits.items():
+             df = split.to_dataframe()
+             if include_split:
+                 df["split"] = split_name
+             dfs.append(df)
+
+         if not dfs:
+             return pd.DataFrame()
+
+         return pd.concat(dfs, ignore_index=True)
+
+     def to_per_gene_dataframe(self, metric_name: str) -> pd.DataFrame:
+         """
+         Get per-gene metric values as DataFrame.
+
+         Parameters
+         ----------
+         metric_name : str
+             Name of metric to extract
+
+         Returns
+         -------
+         pd.DataFrame
+             DataFrame with genes as rows, conditions as columns
+         """
+         data = {}
+         for split in self.splits.values():
+             for cond_key, cond in split.conditions.items():
+                 if metric_name in cond.metrics:
+                     col_name = f"{split.split_name}_{cond_key}"
+                     data[col_name] = cond.metrics[metric_name].per_gene_values
+
+         if not data:
+             return pd.DataFrame()
+
+         df = pd.DataFrame(data, index=self.gene_names)
+         return df
+
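Continuing the SplitResult sketch above (reusing its `split` object), this is how the tabular views built by the methods just shown are read. The `metadata` values are arbitrary and stored verbatim; `condition_columns` is assumed here to name the columns that define a condition.

```python
from geneval.results import EvaluationResult

result = EvaluationResult(
    gene_names=["STAT1", "IRF1", "GBP1"],
    condition_columns=["perturbation", "cell_type"],  # assumed meaning
    metadata={"model": "example-run"},                # arbitrary metadata
)
result.add_split(split)  # `split` from the SplitResult sketch above

# Per-split statistics of the "pearson" aggregate values
print(result.get_metric_summary("pearson")["test"]["mean"])   # ~0.77

# One row per condition, with a "split" column appended
df = result.to_dataframe()
print(df[["split", "condition_key", "pearson"]])

# Genes as rows, one column per (split, condition) pair
per_gene = result.to_per_gene_dataframe("pearson")
print(per_gene.columns.tolist())   # ['test_IFNG|cellA', 'test_TNF|cellA']
```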
+     def summary(self) -> Dict[str, Any]:
+         """Get comprehensive summary."""
+         result = {
+             "n_splits": len(self.splits),
+             "n_genes": len(self.gene_names),
+             "condition_columns": self.condition_columns,
+             "splits": {},
+         }
+
+         for split_name, split in self.splits.items():
+             split.compute_aggregates()
+             result["splits"][split_name] = {
+                 "n_conditions": split.n_conditions,
+                 "aggregates": split.aggregate_metrics,
+             }
+
+         return result
+
+     def save(self, output_dir: Union[str, Path]):
+         """
+         Save results to directory.
+
+         Saves:
+         - summary.json: Aggregate metrics and metadata
+         - results.csv: Per-condition metrics
+         - per_gene_*.csv: Per-gene metrics for each metric type
+         """
+         output_dir = Path(output_dir)
+         output_dir.mkdir(parents=True, exist_ok=True)
+         self.output_dir = output_dir
+
+         # Save summary
+         summary = self.summary()
+         summary["metadata"] = self.metadata
+
+         with open(output_dir / "summary.json", "w") as f:
+             json.dump(summary, f, indent=2, default=str)
+
+         # Save condition-level results
+         df = self.to_dataframe()
+         if not df.empty:
+             df.to_csv(output_dir / "results.csv", index=False)
+
+         # Save per-gene metrics
+         metric_names = set()
+         for split in self.splits.values():
+             for cond in split.conditions.values():
+                 metric_names.update(cond.metrics.keys())
+
+         for metric_name in metric_names:
+             df_gene = self.to_per_gene_dataframe(metric_name)
+             if not df_gene.empty:
+                 df_gene.to_csv(output_dir / f"per_gene_{metric_name}.csv")
+
+         return output_dir
+
+     @classmethod
+     def load(cls, output_dir: Union[str, Path]) -> "EvaluationResult":
+         """
+         Load results from directory.
+
+         Note: Restores the summary and the per-condition structure from
+         results.csv; metric objects, per-gene values, and covariates are
+         not reconstructed.
+         """
+         output_dir = Path(output_dir)
+
+         with open(output_dir / "summary.json") as f:
+             summary = json.load(f)
+
+         result = cls(
+             gene_names=[],
+             condition_columns=summary.get("condition_columns", []),
+             metadata=summary.get("metadata", {}),
+         )
+         result.output_dir = output_dir
+
+         # Load condition-level results if available
+         results_path = output_dir / "results.csv"
+         if results_path.exists():
+             df = pd.read_csv(results_path)
+             # Reconstruct splits and conditions from DataFrame
+             for split_name in df["split"].unique() if "split" in df.columns else ["all"]:
+                 split_df = df[df["split"] == split_name] if "split" in df.columns else df
+                 split_result = SplitResult(split_name=split_name)
+
+                 for _, row in split_df.iterrows():
+                     cond = ConditionResult(
+                         condition_key=row.get("condition_key", ""),
+                         split=split_name,
+                         n_real_samples=row.get("n_real_samples", 0),
+                         n_generated_samples=row.get("n_generated_samples", 0),
+                         n_genes=row.get("n_genes", 0),
+                         gene_names=[],
+                         perturbation=row.get("perturbation"),
+                     )
+                     split_result.add_condition(cond)
+
+                 result.add_split(split_result)
+
+         return result
+
+     def __repr__(self) -> str:
+         n_conds = sum(s.n_conditions for s in self.splits.values())
+         return (
+             f"EvaluationResult(n_splits={len(self.splits)}, "
+             f"n_conditions={n_conds}, n_genes={len(self.gene_names)})"
+         )
+
+
+ # Import MetricResult here to avoid circular import
+ from .metrics.base_metric import MetricResult
+
+ # Update forward references
+ ConditionResult.__annotations__["metrics"] = Dict[str, MetricResult]
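Finally, a sketch of the save/load round trip, continuing from the sketches above: save() writes summary.json, results.csv, and one per_gene_<metric>.csv per metric into the target directory, and load() rebuilds only the split/condition skeleton from those files.

```python
# Continues the EvaluationResult sketch above (`result` holds one "test" split).
out_dir = result.save("eval_out")
print(sorted(p.name for p in out_dir.iterdir()))
# ['per_gene_pearson.csv', 'results.csv', 'summary.json']

# Reload: splits and conditions come back, but metric objects,
# per-gene arrays, covariates, and gene names are not reconstructed.
reloaded = EvaluationResult.load("eval_out")
print(reloaded)   # EvaluationResult(n_splits=1, n_conditions=2, n_genes=0)
print(reloaded.get_split("test").to_dataframe()[["condition_key", "perturbation"]])
```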