gengeneeval 0.3.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geneval/__init__.py +43 -1
- geneval/deg/__init__.py +69 -0
- geneval/deg/context.py +271 -0
- geneval/deg/detection.py +578 -0
- geneval/deg/evaluator.py +821 -0
- geneval/deg/visualization.py +376 -0
- {gengeneeval-0.3.0.dist-info → gengeneeval-0.4.1.dist-info}/METADATA +125 -3
- {gengeneeval-0.3.0.dist-info → gengeneeval-0.4.1.dist-info}/RECORD +11 -6
- {gengeneeval-0.3.0.dist-info → gengeneeval-0.4.1.dist-info}/WHEEL +0 -0
- {gengeneeval-0.3.0.dist-info → gengeneeval-0.4.1.dist-info}/entry_points.txt +0 -0
- {gengeneeval-0.3.0.dist-info → gengeneeval-0.4.1.dist-info}/licenses/LICENSE +0 -0
geneval/deg/evaluator.py
ADDED
|
@@ -0,0 +1,821 @@
|
|
|
1
|
+
"""
|
|
2
|
+
DEG-focused evaluator for GenGeneEval.
|
|
3
|
+
|
|
4
|
+
Computes metrics on differentially expressed genes with full control:
|
|
5
|
+
- Comparison of DEG-only vs all-genes metrics
|
|
6
|
+
- Configurable DEG selection (top N, p-value, log fold change thresholds)
|
|
7
|
+
- Per-context evaluation (covariates × perturbations)
|
|
8
|
+
- Fast DEG detection with GPU acceleration
|
|
9
|
+
- Aggregated and expanded result reporting
|
|
10
|
+
"""
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
from typing import Optional, List, Dict, Union, Any, Literal
|
|
14
|
+
from dataclasses import dataclass, field
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
import numpy as np
|
|
17
|
+
import pandas as pd
|
|
18
|
+
import warnings
|
|
19
|
+
|
|
20
|
+
from .detection import (
|
|
21
|
+
compute_degs_fast,
|
|
22
|
+
compute_degs_gpu,
|
|
23
|
+
compute_degs_auto,
|
|
24
|
+
DEGResult,
|
|
25
|
+
DEGMethod,
|
|
26
|
+
)
|
|
27
|
+
from .context import (
|
|
28
|
+
ContextEvaluator,
|
|
29
|
+
ContextResult,
|
|
30
|
+
get_context_id,
|
|
31
|
+
get_contexts,
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
# Import metrics
|
|
35
|
+
from ..metrics.base_metric import BaseMetric
|
|
36
|
+
from ..metrics.correlation import PearsonCorrelation, SpearmanCorrelation
|
|
37
|
+
from ..metrics.distances import (
|
|
38
|
+
Wasserstein1Distance,
|
|
39
|
+
Wasserstein2Distance,
|
|
40
|
+
MMDDistance,
|
|
41
|
+
EnergyDistance,
|
|
42
|
+
)
|
|
43
|
+
from ..metrics.accelerated import (
|
|
44
|
+
get_available_backends,
|
|
45
|
+
vectorized_wasserstein1,
|
|
46
|
+
vectorized_mmd,
|
|
47
|
+
)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@dataclass
class DEGSettings:
    """Bundle of options that steer DEG detection and DEG selection.

    Attributes
    ----------
    method : str
        Name of the DEG test. Documented choices are "welch", "student",
        "wilcoxon" and "logfc".
    pval_threshold : float
        Significance cut-off applied to p-values.
    lfc_threshold : float
        Cut-off on the absolute log2 fold change.
    n_top_degs : int, optional
        When given, an upper bound on how many of the most significant
        DEGs are used.
    min_degs : int
        Smallest number of DEGs for which DEG-only metrics are computed.
    """

    method: str = "welch"
    pval_threshold: float = 0.05
    lfc_threshold: float = 0.5
    n_top_degs: Optional[int] = None
    min_degs: int = 5

    def to_dict(self) -> Dict[str, Any]:
        """Export the settings as a plain dict.

        Note that ``method`` is exported under the key ``"deg_method"``;
        every other field keeps its attribute name.
        """
        exported: Dict[str, Any] = {}
        exported["deg_method"] = self.method
        for attr in ("pval_threshold", "lfc_threshold", "n_top_degs", "min_degs"):
            exported[attr] = getattr(self, attr)
        return exported
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@dataclass
class ContextMetrics:
    """Metrics for a single context, comparing DEG-only vs all genes.

    Attributes
    ----------
    context_id : str
        Context identifier
    context_values : Dict
        Context column values
    n_samples_real : int
        Number of real samples
    n_samples_gen : int
        Number of generated samples
    n_genes_total : int
        Total number of genes
    deg_result : DEGResult, optional
        DEG detection results, or None when no detection result is attached
    deg_metrics : Dict[str, float]
        Metrics computed on DEGs only
    all_genes_metrics : Dict[str, float]
        Metrics computed on all genes
    """

    context_id: str
    context_values: Dict[str, Any]
    n_samples_real: int
    n_samples_gen: int
    n_genes_total: int
    deg_result: Optional[DEGResult]
    deg_metrics: Dict[str, float]
    all_genes_metrics: Dict[str, float]

    @property
    def n_degs(self) -> int:
        """Number of DEGs (0 when no DEG result is attached)."""
        # Explicit None check: the field is Optional, and relying on the
        # truthiness of the result object would be fragile.
        if self.deg_result is None:
            return 0
        return self.deg_result.n_degs

    @property
    def deg_indices_used(self) -> np.ndarray:
        """DEG indices actually used for metrics.

        Returns an empty *integer* array when no DEG result is attached, so
        the value can always be used directly to index a gene axis
        (a default float ``np.array([])`` would raise ``IndexError``).
        """
        if self.deg_result is None:
            return np.array([], dtype=np.intp)
        return self.deg_result.deg_indices
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
@dataclass
class DEGEvaluationResult:
    """Complete DEG evaluation results with comparison to all-genes metrics.

    Attributes
    ----------
    context_results : List[ContextMetrics]
        Results for each context with both DEG and all-gene metrics
    aggregated_metrics : pd.DataFrame
        Aggregated metrics across contexts (both DEG and all-genes)
    expanded_metrics : pd.DataFrame
        Per-context expanded metrics (both DEG and all-genes)
    deg_summary : pd.DataFrame
        Summary of DEG detection per context
    comparison_summary : pd.DataFrame
        Comparison between DEG-only and all-genes metrics
    gene_names : np.ndarray
        All gene names
    settings : Dict
        Evaluation settings including DEG parameters
    """

    context_results: List[ContextMetrics]
    aggregated_metrics: pd.DataFrame
    expanded_metrics: pd.DataFrame
    deg_summary: pd.DataFrame
    comparison_summary: pd.DataFrame
    gene_names: np.ndarray
    settings: Dict[str, Any]

    # Columns that describe a context rather than a metric; kept in front of
    # the metric columns by the get_*_metrics accessors.
    _BASE_COLUMNS = (
        "context_id",
        "n_samples_real",
        "n_samples_gen",
        "n_degs",
        "n_genes_total",
    )

    @staticmethod
    def _safe_filename(context_id: Any) -> str:
        """Turn a context id into a single path component.

        Path separators are replaced with ``_`` so that an id such as
        ``"drugA/drugB"`` cannot point into a non-existent sub-directory or
        outside the output directory. Ids without separators are unchanged.
        """
        return str(context_id).replace("/", "_").replace("\\", "_")

    def save(self, output_dir: Union[str, Path]) -> None:
        """Write all result tables, the settings and per-context DEG tables.

        Parameters
        ----------
        output_dir : str or Path
            Target directory; created (including parents) when missing.

        Files written: ``deg_aggregated_metrics.csv``,
        ``deg_expanded_metrics.csv``, ``deg_summary.csv``,
        ``deg_vs_all_comparison.csv``, ``settings.json`` and one
        ``deg_per_context/<context_id>_degs.csv`` per context that has a
        DEG result.
        """
        import json

        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        tables = {
            "deg_aggregated_metrics.csv": self.aggregated_metrics,
            "deg_expanded_metrics.csv": self.expanded_metrics,
            "deg_summary.csv": self.deg_summary,
            "deg_vs_all_comparison.csv": self.comparison_summary,
        }
        for file_name, table in tables.items():
            table.to_csv(output_dir / file_name, index=False)

        # Save settings. default=str stringifies anything json cannot encode
        # natively; explicit UTF-8 keeps the file platform independent.
        with open(output_dir / "settings.json", "w", encoding="utf-8") as f:
            json.dump(self.settings, f, indent=2, default=str)

        # Save per-context DEG results
        deg_dir = output_dir / "deg_per_context"
        deg_dir.mkdir(exist_ok=True)
        for ctx_result in self.context_results:
            if ctx_result.deg_result is None:
                continue
            file_name = f"{self._safe_filename(ctx_result.context_id)}_degs.csv"
            ctx_result.deg_result.to_dataframe().to_csv(
                deg_dir / file_name, index=False
            )

    def _select_metrics(self, prefix: str) -> pd.DataFrame:
        """Copy of the base columns plus every column starting with *prefix*."""
        available = self.expanded_metrics.columns
        base = [c for c in self._BASE_COLUMNS if c in available]
        prefixed = [c for c in available if c.startswith(prefix)]
        return self.expanded_metrics[base + prefixed].copy()

    def get_deg_only_metrics(self) -> pd.DataFrame:
        """Get expanded metrics for DEGs only (columns prefixed ``deg_``)."""
        return self._select_metrics("deg_")

    def get_all_genes_metrics(self) -> pd.DataFrame:
        """Get expanded metrics for all genes (columns prefixed ``all_``)."""
        return self._select_metrics("all_")

    def __repr__(self) -> str:
        n_degs_avg = self.deg_summary["n_degs"].mean() if len(self.deg_summary) > 0 else 0
        return (
            f"DEGEvaluationResult(n_contexts={len(self.context_results)}, "
            f"avg_degs={n_degs_avg:.1f}, "
            f"settings={self.settings.get('deg_method', 'unknown')})"
        )
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
class DEGEvaluator:
    """
    Evaluator that computes metrics on DEGs with comparison to all genes.

    This evaluator:
    1. Detects DEGs for each perturbation context (real control vs real perturbed)
    2. Computes distributional metrics on BOTH DEG genes AND all genes
    3. Provides comparison between DEG-focused and all-genes evaluation
    4. Reports per-context and aggregated results

    Parameters
    ----------
    real_data : np.ndarray
        Real expression matrix (n_samples, n_genes)
    generated_data : np.ndarray
        Generated expression matrix (n_samples, n_genes)
    real_obs : pd.DataFrame
        Real data observation metadata (one row per real sample)
    generated_obs : pd.DataFrame
        Generated data observation metadata (one row per generated sample)
    condition_columns : List[str]
        Columns defining contexts (e.g., ["cell_type", "perturbation"])
    gene_names : np.ndarray, optional
        Gene names. Defaults to "Gene_0", "Gene_1", ...
    control_key : str
        Value indicating control samples (default: "control")
    perturbation_column : str, optional
        Column containing perturbation info. If None, uses first condition column.
    deg_method : str
        DEG detection method: "welch", "student", "wilcoxon", "logfc"
    pval_threshold : float
        P-value threshold for DEG significance (default: 0.05)
    lfc_threshold : float
        Log2 fold change threshold (default: 0.5)
    n_top_degs : int, optional
        If set, keep at most the N most significant genes (ranked by adjusted
        p-value) among the DEGs returned by detection. Must be non-negative.
    min_degs : int
        Minimum DEGs required to compute DEG-specific metrics (default: 5)
    compute_all_genes : bool
        Whether to also compute metrics on all genes (default: True)
    metrics : List[str], optional
        Metrics to compute. Default: all supported metrics.
    n_jobs : int
        Number of parallel CPU jobs
    device : str
        Compute device: "cpu", "cuda", "mps", "auto" (default: "cpu")
    verbose : bool
        Print progress

    Raises
    ------
    ValueError
        If the matrices are not 2-D, disagree on the number of genes, do not
        match their metadata / gene names in length, if ``condition_columns``
        is empty, or if ``n_top_degs`` is negative.

    Examples
    --------
    >>> # Basic usage - computes both DEG and all-genes metrics
    >>> evaluator = DEGEvaluator(
    ...     real_data, generated_data,
    ...     real_obs, generated_obs,
    ...     condition_columns=["perturbation"],
    ... )
    >>> results = evaluator.evaluate()
    >>> print(results.comparison_summary)  # DEG vs all-genes comparison

    >>> # Use top 100 DEGs only
    >>> evaluator = DEGEvaluator(
    ...     real_data, generated_data,
    ...     real_obs, generated_obs,
    ...     condition_columns=["perturbation"],
    ...     n_top_degs=100,  # Use top 100 most significant DEGs
    ... )

    >>> # Stricter thresholds
    >>> evaluator = DEGEvaluator(
    ...     real_data, generated_data,
    ...     real_obs, generated_obs,
    ...     condition_columns=["perturbation"],
    ...     pval_threshold=0.01,  # More stringent p-value
    ...     lfc_threshold=1.0,    # log2 FC > 1 (2-fold change)
    ... )

    >>> # DEGs only (no all-genes metrics for speed)
    >>> evaluator = DEGEvaluator(
    ...     real_data, generated_data,
    ...     real_obs, generated_obs,
    ...     condition_columns=["perturbation"],
    ...     compute_all_genes=False,
    ... )
    """

    # Supported metrics
    SUPPORTED_METRICS = [
        "wasserstein_1",
        "wasserstein_2",
        "mmd",
        "energy",
        "pearson",
        "spearman",
    ]

    def __init__(
        self,
        real_data: np.ndarray,
        generated_data: np.ndarray,
        real_obs: pd.DataFrame,
        generated_obs: pd.DataFrame,
        condition_columns: List[str],
        gene_names: Optional[np.ndarray] = None,
        control_key: str = "control",
        perturbation_column: Optional[str] = None,
        deg_method: DEGMethod = "welch",
        pval_threshold: float = 0.05,
        lfc_threshold: float = 0.5,
        n_top_degs: Optional[int] = None,
        min_degs: int = 5,
        compute_all_genes: bool = True,
        metrics: Optional[List[str]] = None,
        n_jobs: int = 1,
        device: str = "cpu",
        verbose: bool = True,
    ):
        self.real_data = np.asarray(real_data, dtype=np.float32)
        self.generated_data = np.asarray(generated_data, dtype=np.float32)
        if self.real_data.ndim != 2 or self.generated_data.ndim != 2:
            raise ValueError(
                "real_data and generated_data must be 2-D (n_samples, n_genes); "
                f"got shapes {self.real_data.shape} and {self.generated_data.shape}"
            )
        if self.real_data.shape[1] != self.generated_data.shape[1]:
            raise ValueError(
                "real_data and generated_data must have the same number of genes; "
                f"got {self.real_data.shape[1]} and {self.generated_data.shape[1]}"
            )
        if not condition_columns:
            raise ValueError("condition_columns must contain at least one column name")
        if n_top_degs is not None and n_top_degs < 0:
            # A negative value would make the top-N slice drop from the end.
            raise ValueError(f"n_top_degs must be non-negative, got {n_top_degs}")

        self.real_obs = real_obs.reset_index(drop=True)
        self.generated_obs = generated_obs.reset_index(drop=True)
        if len(self.real_obs) != self.real_data.shape[0]:
            raise ValueError(
                f"real_obs has {len(self.real_obs)} rows but real_data has "
                f"{self.real_data.shape[0]} samples"
            )
        if len(self.generated_obs) != self.generated_data.shape[0]:
            raise ValueError(
                f"generated_obs has {len(self.generated_obs)} rows but generated_data has "
                f"{self.generated_data.shape[0]} samples"
            )

        self.condition_columns = condition_columns
        # Read the shape from the converted array so list-like input works too.
        self.n_genes = self.real_data.shape[1]
        if gene_names is None:
            self.gene_names = np.array([f"Gene_{i}" for i in range(self.n_genes)])
        else:
            if len(gene_names) != self.n_genes:
                raise ValueError(
                    f"gene_names has {len(gene_names)} entries but the data has "
                    f"{self.n_genes} genes"
                )
            self.gene_names = gene_names
        self.control_key = control_key
        self.perturbation_column = perturbation_column or condition_columns[0]

        # DEG settings
        self.deg_settings = DEGSettings(
            method=deg_method,
            pval_threshold=pval_threshold,
            lfc_threshold=lfc_threshold,
            n_top_degs=n_top_degs,
            min_degs=min_degs,
        )

        self.compute_all_genes = compute_all_genes
        # Copy so that neither the caller's list nor the class-level default
        # is aliased by the instance (it is later exposed via result settings).
        self.metrics = list(metrics) if metrics else list(self.SUPPORTED_METRICS)
        unknown = [m for m in self.metrics if m not in self.SUPPORTED_METRICS]
        if unknown:
            warnings.warn(
                f"Unsupported metrics will be ignored: {unknown}. "
                f"Supported metrics: {self.SUPPORTED_METRICS}",
                UserWarning,
                stacklevel=2,
            )
        self.n_jobs = n_jobs
        self.device = device
        self.verbose = verbose

        # Create context evaluator
        self.context_evaluator = ContextEvaluator(
            real_data=self.real_data,
            generated_data=self.generated_data,
            real_obs=self.real_obs,
            generated_obs=self.generated_obs,
            condition_columns=condition_columns,
            gene_names=self.gene_names,
            control_key=control_key,
            perturbation_column=self.perturbation_column,
        )

        # Initialize metric objects
        self._metric_objects = {
            "wasserstein_1": Wasserstein1Distance(),
            "wasserstein_2": Wasserstein2Distance(),
            "mmd": MMDDistance(),
            "energy": EnergyDistance(),
            "pearson": PearsonCorrelation(),
            "spearman": SpearmanCorrelation(),
        }

        self._log(f"DEGEvaluator initialized with {len(self.context_evaluator)} contexts")
        self._log(f"DEG settings: method={deg_method}, pval<{pval_threshold}, |lfc|>{lfc_threshold}")
        if n_top_degs is not None:
            self._log(f"  Using top {n_top_degs} DEGs by significance")
        if compute_all_genes:
            self._log("Will compute metrics on BOTH DEGs and all genes")

    def _log(self, msg: str) -> None:
        """Print *msg* when verbose output is enabled."""
        if self.verbose:
            print(msg)

    def _compute_degs(
        self,
        control: np.ndarray,
        perturbed: np.ndarray,
    ) -> DEGResult:
        """Compute DEGs using the configured method and device.

        The detection result is additionally capped to the most significant
        ``n_top_degs`` genes when that setting is not None.
        """
        deg_result = compute_degs_auto(
            control=control,
            perturbed=perturbed,
            gene_names=self.gene_names,
            method=self.deg_settings.method,
            pval_threshold=self.deg_settings.pval_threshold,
            lfc_threshold=self.deg_settings.lfc_threshold,
            n_jobs=self.n_jobs,
            device=self.device,
        )

        # If n_top_degs is set, limit to top N by significance
        if self.deg_settings.n_top_degs is not None:
            deg_result = self._filter_top_degs(deg_result, self.deg_settings.n_top_degs)

        return deg_result

    def _filter_top_degs(self, deg_result: DEGResult, n_top: int) -> DEGResult:
        """Keep only the *n_top* most significant genes among the detected DEGs.

        Genes are ranked by adjusted p-value (ascending). The selection is
        made among genes already flagged in ``deg_result.is_deg``; genes that
        did not pass detection are never added. When the result already
        holds at most *n_top* DEGs it is returned unchanged.

        Raises
        ------
        ValueError
            If *n_top* is negative.
        """
        if n_top < 0:
            raise ValueError(f"n_top must be non-negative, got {n_top}")
        if deg_result.n_degs <= n_top:
            return deg_result  # Already fewer than n_top

        deg_pvals = deg_result.pvalues_adj[deg_result.is_deg]
        deg_indices = deg_result.deg_indices

        # Stable sort: genes with identical adjusted p-values (common after
        # multiple-testing correction) keep their original relative order,
        # making the selection deterministic.
        top_n_order = np.argsort(deg_pvals, kind="stable")[:n_top]
        top_deg_indices = deg_indices[top_n_order]

        # Create new is_deg mask
        new_is_deg = np.zeros(len(deg_result.is_deg), dtype=bool)
        new_is_deg[top_deg_indices] = True

        # Create modified DEGResult with all required fields
        return DEGResult(
            gene_names=deg_result.gene_names,
            pvalues=deg_result.pvalues,
            pvalues_adj=deg_result.pvalues_adj,
            log_fold_changes=deg_result.log_fold_changes,
            mean_control=deg_result.mean_control,
            mean_perturbed=deg_result.mean_perturbed,
            is_deg=new_is_deg,
            n_degs=len(top_deg_indices),
            method=deg_result.method,
            pval_threshold=deg_result.pval_threshold,
            lfc_threshold=deg_result.lfc_threshold,
            deg_indices=top_deg_indices,
        )

    @staticmethod
    def _warn_metric_failure(metric_name: str, exc: Exception) -> None:
        """Report a metric that failed and is being recorded as NaN."""
        warnings.warn(
            f"Metric '{metric_name}' could not be computed and is reported as NaN: {exc}",
            RuntimeWarning,
            stacklevel=3,
        )

    def _compute_metrics_on_genes(
        self,
        real: np.ndarray,
        generated: np.ndarray,
        gene_indices: Optional[np.ndarray] = None,
        min_genes: int = 1,
    ) -> Dict[str, float]:
        """Compute the configured metrics on selected genes (or all genes).

        Parameters
        ----------
        real, generated : np.ndarray
            Expression matrices indexed as ``[:, gene]``.
        gene_indices : np.ndarray, optional
            Gene (column) indices to evaluate. None means all genes.
        min_genes : int
            When *gene_indices* holds fewer entries than this, every
            configured metric is reported as NaN without being computed.

        Returns
        -------
        Dict[str, float]
            Mean of the per-gene values (NaNs ignored) for each metric. A
            metric that raises is reported as NaN and a ``RuntimeWarning``
            is emitted; metric names without an implementation are omitted.
        """
        # Slice to selected genes
        if gene_indices is not None:
            if len(gene_indices) < min_genes:
                return {m: np.nan for m in self.metrics}
            real_subset = real[:, gene_indices]
            gen_subset = generated[:, gene_indices]
        else:
            real_subset = real
            gen_subset = generated

        results: Dict[str, float] = {}

        # Use vectorized implementations where available
        if "wasserstein_1" in self.metrics:
            try:
                w1_per_gene = vectorized_wasserstein1(real_subset, gen_subset)
                results["wasserstein_1"] = float(np.nanmean(w1_per_gene))
            except Exception as exc:  # best effort: keep the other metrics
                self._warn_metric_failure("wasserstein_1", exc)
                results["wasserstein_1"] = np.nan

        if "mmd" in self.metrics:
            try:
                mmd_per_gene = vectorized_mmd(real_subset, gen_subset)
                results["mmd"] = float(np.nanmean(mmd_per_gene))
            except Exception as exc:  # best effort: keep the other metrics
                self._warn_metric_failure("mmd", exc)
                results["mmd"] = np.nan

        # Fall back to standard computation for other metrics
        for metric_name in self.metrics:
            if metric_name in results:
                continue
            metric = self._metric_objects.get(metric_name)
            if metric is None:
                continue

            try:
                per_gene = metric.compute_per_gene(real_subset, gen_subset)
                results[metric_name] = float(np.nanmean(per_gene))
            except Exception as exc:  # best effort: keep the other metrics
                self._warn_metric_failure(metric_name, exc)
                results[metric_name] = np.nan

        return results

    def evaluate(self) -> DEGEvaluationResult:
        """
        Run DEG-focused evaluation on all contexts.

        For every perturbation context, DEGs are detected on the real data
        (control vs perturbed) and metrics between real and generated
        perturbed samples are computed on the DEGs and, when
        ``compute_all_genes`` is set, on all genes. Contexts with fewer than
        two real control or real perturbed samples are skipped. A context
        that raises is skipped as well and reported through a
        ``RuntimeWarning`` (and on stdout when verbose).

        Returns
        -------
        DEGEvaluationResult
            Complete evaluation results with:
            - Per-context DEG and all-genes metrics
            - Aggregated metrics
            - DEG summary
            - Comparison between DEG and all-genes evaluation
        """
        context_results: List[ContextMetrics] = []

        perturbation_contexts = self.context_evaluator.get_perturbation_contexts()
        n_contexts = len(perturbation_contexts)

        self._log(f"\nEvaluating {n_contexts} perturbation contexts...")

        for i, context in enumerate(perturbation_contexts):
            context_id = get_context_id(context)

            if self.verbose:
                print(f"  [{i+1}/{n_contexts}] {context_id}", end="... ")

            try:
                # Get perturbed data
                real_pert, gen_pert = self.context_evaluator.get_context_data(context)

                # Get control data
                real_ctrl, gen_ctrl = self.context_evaluator.get_control_data(context)

                if len(real_ctrl) < 2 or len(real_pert) < 2:
                    if self.verbose:
                        print("skipped (insufficient samples)")
                    continue

                # Compute DEGs using real data (control vs perturbed)
                deg_result = self._compute_degs(real_ctrl, real_pert)

                if self.verbose:
                    print(f"{deg_result.n_degs} DEGs", end="... ")

                # Compute metrics on DEGs
                deg_metrics = self._compute_metrics_on_genes(
                    real_pert, gen_pert,
                    gene_indices=deg_result.deg_indices,
                    min_genes=self.deg_settings.min_degs,
                )

                # Compute metrics on all genes if requested
                if self.compute_all_genes:
                    all_genes_metrics = self._compute_metrics_on_genes(
                        real_pert, gen_pert,
                        gene_indices=None,  # All genes
                    )
                else:
                    all_genes_metrics = {}

                ctx_result = ContextMetrics(
                    context_id=context_id,
                    context_values=context,
                    n_samples_real=len(real_pert),
                    n_samples_gen=len(gen_pert),
                    n_genes_total=self.n_genes,
                    deg_result=deg_result,
                    deg_metrics=deg_metrics,
                    all_genes_metrics=all_genes_metrics,
                )
                context_results.append(ctx_result)

                if self.verbose:
                    print("done")

            except Exception as e:
                # Best effort per context, but never fail silently: surface
                # the problem even when verbose printing is disabled.
                warnings.warn(
                    f"Context '{context_id}' failed and was skipped: {e}",
                    RuntimeWarning,
                    stacklevel=2,
                )
                if self.verbose:
                    print(f"error: {e}")
                continue

        # Build result DataFrames
        expanded_metrics = self._build_expanded_metrics(context_results)
        aggregated_metrics = self._build_aggregated_metrics(expanded_metrics)
        deg_summary = self._build_deg_summary(context_results)
        comparison_summary = self._build_comparison_summary(expanded_metrics)

        self._log(f"\nEvaluation complete: {len(context_results)} contexts evaluated")

        return DEGEvaluationResult(
            context_results=context_results,
            aggregated_metrics=aggregated_metrics,
            expanded_metrics=expanded_metrics,
            deg_summary=deg_summary,
            comparison_summary=comparison_summary,
            gene_names=self.gene_names,
            settings={
                **self.deg_settings.to_dict(),
                "compute_all_genes": self.compute_all_genes,
                "metrics": self.metrics,
                "device": self.device,
                "n_jobs": self.n_jobs,
            },
        )

    def _build_expanded_metrics(self, context_results: List[ContextMetrics]) -> pd.DataFrame:
        """Build the per-context table with ``deg_*`` and ``all_*`` metric columns."""
        data = []
        for ctx in context_results:
            row = {
                "context_id": ctx.context_id,
                **ctx.context_values,
                "n_samples_real": ctx.n_samples_real,
                "n_samples_gen": ctx.n_samples_gen,
                "n_genes_total": ctx.n_genes_total,
                "n_degs": ctx.n_degs,
            }
            # DEG metrics with prefix
            row.update({f"deg_{name}": value for name, value in ctx.deg_metrics.items()})
            # All-genes metrics with prefix
            row.update({f"all_{name}": value for name, value in ctx.all_genes_metrics.items()})
            data.append(row)

        return pd.DataFrame(data)

    def _build_aggregated_metrics(self, expanded_metrics: pd.DataFrame) -> pd.DataFrame:
        """Build a one-row table of statistics across contexts.

        Returns an empty DataFrame when there are no contexts.
        """
        if len(expanded_metrics) == 0:
            return pd.DataFrame()

        n_degs = expanded_metrics["n_degs"]
        agg_data = {
            "n_contexts": len(expanded_metrics),
            "total_samples_real": expanded_metrics["n_samples_real"].sum(),
            "total_samples_gen": expanded_metrics["n_samples_gen"].sum(),
            "mean_n_degs": n_degs.mean(),
            "median_n_degs": n_degs.median(),
            "min_n_degs": n_degs.min(),
            "max_n_degs": n_degs.max(),
        }

        # Aggregate DEG metrics first, then all-genes metrics (column order
        # of the output follows this insertion order).
        for prefix in ("deg", "all"):
            for metric in self.metrics:
                col = f"{prefix}_{metric}"
                if col in expanded_metrics.columns:
                    agg_data[f"{col}_mean"] = expanded_metrics[col].mean()
                    agg_data[f"{col}_std"] = expanded_metrics[col].std()

        return pd.DataFrame([agg_data])

    def _build_deg_summary(self, context_results: List[ContextMetrics]) -> pd.DataFrame:
        """Build the per-context DEG detection summary.

        Contexts without a DEG result are omitted.
        """
        data = []
        for ctx in context_results:
            if ctx.deg_result is None:
                continue
            deg_lfcs = ctx.deg_result.log_fold_changes[ctx.deg_result.is_deg]
            has_degs = len(deg_lfcs) > 0
            abs_lfcs = np.abs(deg_lfcs)
            data.append({
                "context_id": ctx.context_id,
                **ctx.context_values,
                "n_degs": ctx.n_degs,
                "n_genes_total": ctx.n_genes_total,
                "deg_fraction": ctx.n_degs / ctx.n_genes_total if ctx.n_genes_total > 0 else 0,
                "n_upregulated": (deg_lfcs > 0).sum(),
                "n_downregulated": (deg_lfcs < 0).sum(),
                "mean_abs_lfc": float(abs_lfcs.mean()) if has_degs else np.nan,
                "max_abs_lfc": float(abs_lfcs.max()) if has_degs else np.nan,
            })
        return pd.DataFrame(data)

    def _build_comparison_summary(self, expanded_metrics: pd.DataFrame) -> pd.DataFrame:
        """Build the DEG vs all-genes comparison, one row per metric.

        Statistics are computed on *paired* contexts only, i.e. contexts in
        which both the DEG value and the all-genes value of a metric are
        available. Otherwise the all-genes mean would include contexts whose
        DEG metric is NaN (for example too few DEGs), and ``difference`` and
        ``ratio`` would compare two different sets of contexts. Metrics
        without any paired context are omitted.
        """
        if len(expanded_metrics) == 0:
            return pd.DataFrame()

        comparison_data = []
        for metric in self.metrics:
            deg_col = f"deg_{metric}"
            all_col = f"all_{metric}"
            if deg_col not in expanded_metrics.columns or all_col not in expanded_metrics.columns:
                continue

            paired = expanded_metrics[[deg_col, all_col]].dropna()
            if paired.empty:
                continue

            deg_values = paired[deg_col]
            all_values = paired[all_col]
            deg_mean = deg_values.mean()
            all_mean = all_values.mean()

            comparison_data.append({
                "metric": metric,
                "deg_mean": deg_mean,
                "deg_std": deg_values.std(),
                "all_mean": all_mean,
                "all_std": all_values.std(),
                "difference": deg_mean - all_mean,
                "ratio": deg_mean / all_mean if all_mean != 0 else np.nan,
                "n_contexts": len(paired),
            })

        return pd.DataFrame(comparison_data)
|
|
699
|
+
|
|
700
|
+
|
|
701
|
+
def evaluate_degs(
    real_data: np.ndarray,
    generated_data: np.ndarray,
    real_obs: pd.DataFrame,
    generated_obs: pd.DataFrame,
    condition_columns: List[str],
    gene_names: Optional[np.ndarray] = None,
    control_key: str = "control",
    perturbation_column: Optional[str] = None,
    deg_method: DEGMethod = "welch",
    pval_threshold: float = 0.05,
    lfc_threshold: float = 0.5,
    n_top_degs: Optional[int] = None,
    min_degs: int = 5,
    compute_all_genes: bool = True,
    metrics: Optional[List[str]] = None,
    n_jobs: int = 1,
    device: str = "auto",
    verbose: bool = True,
) -> DEGEvaluationResult:
    """
    One-call DEG-focused evaluation.

    Builds a ``DEGEvaluator`` from the given arguments (all of them are
    forwarded unchanged) and immediately runs its ``evaluate()`` method.

    Parameters
    ----------
    real_data : np.ndarray
        Real expression matrix (n_samples, n_genes)
    generated_data : np.ndarray
        Generated expression matrix (n_samples, n_genes)
    real_obs : pd.DataFrame
        Metadata of the real samples, holding the condition columns
    generated_obs : pd.DataFrame
        Metadata of the generated samples, holding the condition columns
    condition_columns : List[str]
        Columns that define a context (e.g., ["cell_type", "perturbation"])
    gene_names : np.ndarray, optional
        Gene names used in the output
    control_key : str
        Identifier of the control condition (default: "control")
    perturbation_column : str, optional
        Column with the perturbation label. None selects the first
        condition column.
    deg_method : str
        DEG detection method: "welch", "student", "wilcoxon", "logfc"
    pval_threshold : float
        Adjusted p-value threshold (default: 0.05)
    lfc_threshold : float
        Absolute log2 fold change threshold (default: 0.5)
    n_top_degs : int, optional
        If set, restrict the evaluation to the top N DEGs by significance
    min_degs : int
        Minimum number of DEGs for DEG metrics to be computed (default: 5)
    compute_all_genes : bool
        Also compute the metrics on all genes for comparison (default: True)
    metrics : List[str], optional
        Metrics to compute. None selects all supported metrics.
    n_jobs : int
        Parallel CPU jobs (default: 1)
    device : str
        Compute device: "cpu", "cuda", "mps", "auto" (default: "auto")
    verbose : bool
        Print progress (default: True)

    Returns
    -------
    DEGEvaluationResult
        Evaluation results, including:
        - expanded_metrics: per-context metrics for DEGs and all genes
        - aggregated_metrics: summary statistics
        - deg_summary: DEG detection summary
        - comparison_summary: DEG vs all-genes comparison

    Examples
    --------
    >>> # Basic usage with default thresholds
    >>> results = evaluate_degs(
    ...     real_data, generated_data,
    ...     real_obs, generated_obs,
    ...     condition_columns=["perturbation"],
    ... )
    >>> print(results.comparison_summary)

    >>> # Top 50 DEGs only
    >>> results = evaluate_degs(
    ...     real_data, generated_data,
    ...     real_obs, generated_obs,
    ...     condition_columns=["perturbation"],
    ...     n_top_degs=50,
    ... )

    >>> # Strict thresholds
    >>> results = evaluate_degs(
    ...     real_data, generated_data,
    ...     real_obs, generated_obs,
    ...     condition_columns=["perturbation"],
    ...     pval_threshold=0.01,
    ...     lfc_threshold=1.0,  # 2-fold change
    ... )
    """
    # Group the keyword arguments by concern; they are merged into a single
    # DEGEvaluator(...) call below.
    data_kwargs = dict(
        real_data=real_data,
        generated_data=generated_data,
        real_obs=real_obs,
        generated_obs=generated_obs,
        condition_columns=condition_columns,
        gene_names=gene_names,
        control_key=control_key,
        perturbation_column=perturbation_column,
    )
    deg_kwargs = dict(
        deg_method=deg_method,
        pval_threshold=pval_threshold,
        lfc_threshold=lfc_threshold,
        n_top_degs=n_top_degs,
        min_degs=min_degs,
    )
    run_kwargs = dict(
        compute_all_genes=compute_all_genes,
        metrics=metrics,
        n_jobs=n_jobs,
        device=device,
        verbose=verbose,
    )
    return DEGEvaluator(**data_kwargs, **deg_kwargs, **run_kwargs).evaluate()
|