gengeneeval 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geneval/__init__.py +1 -1
- geneval/deg/__init__.py +4 -0
- geneval/deg/evaluator.py +429 -146
- {gengeneeval-0.4.0.dist-info → gengeneeval-0.4.1.dist-info}/METADATA +46 -11
- {gengeneeval-0.4.0.dist-info → gengeneeval-0.4.1.dist-info}/RECORD +8 -8
- {gengeneeval-0.4.0.dist-info → gengeneeval-0.4.1.dist-info}/WHEEL +0 -0
- {gengeneeval-0.4.0.dist-info → gengeneeval-0.4.1.dist-info}/entry_points.txt +0 -0
- {gengeneeval-0.4.0.dist-info → gengeneeval-0.4.1.dist-info}/licenses/LICENSE +0 -0
geneval/__init__.py
CHANGED
geneval/deg/__init__.py
CHANGED
@@ -31,6 +31,8 @@ from .context import (
 from .evaluator import (
     DEGEvaluator,
     DEGEvaluationResult,
+    DEGSettings,
+    ContextMetrics,
     evaluate_degs,
 )
 from .visualization import (
@@ -56,6 +58,8 @@ __all__ = [
     # Evaluator
     "DEGEvaluator",
     "DEGEvaluationResult",
+    "DEGSettings",
+    "ContextMetrics",
     "evaluate_degs",
     # Visualization
     "plot_deg_distributions",
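With these exports, the new helper types are importable directly from `geneval.deg`. A minimal sketch of the new import surface, assuming the gengeneeval 0.4.1 wheel is installed:

```python
# Sketch: names added to geneval.deg's public API in 0.4.1
# (assumes the gengeneeval 0.4.1 wheel is installed).
from geneval.deg import DEGSettings, ContextMetrics  # ContextMetrics holds per-context results

settings = DEGSettings()  # defaults shown in the diff: welch, p<0.05, |log2FC|>0.5, min 5 DEGs
print(settings.to_dict())
# {'deg_method': 'welch', 'pval_threshold': 0.05, 'lfc_threshold': 0.5,
#  'n_top_degs': None, 'min_degs': 5}

# Switch to top-N selection instead of threshold filtering
settings = DEGSettings(method="wilcoxon", n_top_degs=100)
```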
geneval/deg/evaluator.py
CHANGED
@@ -1,7 +1,9 @@
 """
 DEG-focused evaluator for GenGeneEval.
 
-Computes metrics
+Computes metrics on differentially expressed genes with full control:
+- Comparison of DEG-only vs all-genes metrics
+- Configurable DEG selection (top N, p-value, log fold change thresholds)
 - Per-context evaluation (covariates × perturbations)
 - Fast DEG detection with GPU acceleration
 - Aggregated and expanded result reporting
@@ -45,29 +47,109 @@ from ..metrics.accelerated import (
 )
 
 
+@dataclass
+class DEGSettings:
+    """Settings for DEG detection and filtering.
+
+    Attributes
+    ----------
+    method : str
+        DEG detection method: "welch", "student", "wilcoxon", "logfc"
+    pval_threshold : float
+        P-value threshold for significance
+    lfc_threshold : float
+        Absolute log2 fold change threshold
+    n_top_degs : int, optional
+        If set, use only top N DEGs by significance (overrides threshold filtering)
+    min_degs : int
+        Minimum number of DEGs required to compute metrics
+    """
+    method: str = "welch"
+    pval_threshold: float = 0.05
+    lfc_threshold: float = 0.5
+    n_top_degs: Optional[int] = None
+    min_degs: int = 5
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to dictionary."""
+        return {
+            "deg_method": self.method,
+            "pval_threshold": self.pval_threshold,
+            "lfc_threshold": self.lfc_threshold,
+            "n_top_degs": self.n_top_degs,
+            "min_degs": self.min_degs,
+        }
+
+
+@dataclass
+class ContextMetrics:
+    """Metrics for a single context, comparing DEG-only vs all genes.
+
+    Attributes
+    ----------
+    context_id : str
+        Context identifier
+    context_values : Dict
+        Context column values
+    n_samples_real : int
+        Number of real samples
+    n_samples_gen : int
+        Number of generated samples
+    n_genes_total : int
+        Total number of genes
+    deg_result : DEGResult
+        DEG detection results
+    deg_metrics : Dict[str, float]
+        Metrics computed on DEGs only
+    all_genes_metrics : Dict[str, float]
+        Metrics computed on all genes
+    """
+    context_id: str
+    context_values: Dict[str, Any]
+    n_samples_real: int
+    n_samples_gen: int
+    n_genes_total: int
+    deg_result: Optional[DEGResult]
+    deg_metrics: Dict[str, float]
+    all_genes_metrics: Dict[str, float]
+
+    @property
+    def n_degs(self) -> int:
+        """Number of DEGs."""
+        return self.deg_result.n_degs if self.deg_result else 0
+
+    @property
+    def deg_indices_used(self) -> np.ndarray:
+        """DEG indices actually used for metrics."""
+        return self.deg_result.deg_indices if self.deg_result else np.array([])
+
+
 @dataclass
 class DEGEvaluationResult:
-    """Complete DEG evaluation results.
+    """Complete DEG evaluation results with comparison to all-genes metrics.
 
     Attributes
     ----------
-    context_results : List[
-        Results for each context
+    context_results : List[ContextMetrics]
+        Results for each context with both DEG and all-gene metrics
     aggregated_metrics : pd.DataFrame
-        Aggregated metrics across contexts
+        Aggregated metrics across contexts (both DEG and all-genes)
     expanded_metrics : pd.DataFrame
-        Per-context expanded metrics
+        Per-context expanded metrics (both DEG and all-genes)
     deg_summary : pd.DataFrame
         Summary of DEG detection per context
+    comparison_summary : pd.DataFrame
+        Comparison between DEG-only and all-genes metrics
     gene_names : np.ndarray
        All gene names
    settings : Dict
-        Evaluation settings
+        Evaluation settings including DEG parameters
    """
-    context_results: List[
+    context_results: List[ContextMetrics]
    aggregated_metrics: pd.DataFrame
    expanded_metrics: pd.DataFrame
    deg_summary: pd.DataFrame
+    comparison_summary: pd.DataFrame
    gene_names: np.ndarray
    settings: Dict[str, Any]
 
@@ -76,9 +158,15 @@ class DEGEvaluationResult:
         output_dir = Path(output_dir)
         output_dir.mkdir(parents=True, exist_ok=True)
 
-        self.aggregated_metrics.to_csv(output_dir / "deg_aggregated_metrics.csv")
-        self.expanded_metrics.to_csv(output_dir / "deg_expanded_metrics.csv")
-        self.deg_summary.to_csv(output_dir / "deg_summary.csv")
+        self.aggregated_metrics.to_csv(output_dir / "deg_aggregated_metrics.csv", index=False)
+        self.expanded_metrics.to_csv(output_dir / "deg_expanded_metrics.csv", index=False)
+        self.deg_summary.to_csv(output_dir / "deg_summary.csv", index=False)
+        self.comparison_summary.to_csv(output_dir / "deg_vs_all_comparison.csv", index=False)
+
+        # Save settings
+        import json
+        with open(output_dir / "settings.json", "w") as f:
+            json.dump(self.settings, f, indent=2, default=str)
 
         # Save per-context DEG results
         deg_dir = output_dir / "deg_per_context"
@@ -86,24 +174,41 @@ class DEGEvaluationResult:
         for ctx_result in self.context_results:
             if ctx_result.deg_result is not None:
                 ctx_result.deg_result.to_dataframe().to_csv(
-                    deg_dir / f"{ctx_result.context_id}_degs.csv"
+                    deg_dir / f"{ctx_result.context_id}_degs.csv", index=False
                 )
 
+    def get_deg_only_metrics(self) -> pd.DataFrame:
+        """Get expanded metrics for DEGs only."""
+        base_cols = ["context_id", "n_samples_real", "n_samples_gen", "n_degs", "n_genes_total"]
+        deg_cols = [c for c in self.expanded_metrics.columns if c.startswith("deg_")]
+        cols = [c for c in base_cols if c in self.expanded_metrics.columns] + deg_cols
+        return self.expanded_metrics[cols].copy()
+
+    def get_all_genes_metrics(self) -> pd.DataFrame:
+        """Get expanded metrics for all genes."""
+        base_cols = ["context_id", "n_samples_real", "n_samples_gen", "n_degs", "n_genes_total"]
+        all_cols = [c for c in self.expanded_metrics.columns if c.startswith("all_")]
+        cols = [c for c in base_cols if c in self.expanded_metrics.columns] + all_cols
+        return self.expanded_metrics[cols].copy()
+
     def __repr__(self) -> str:
+        n_degs_avg = self.deg_summary["n_degs"].mean() if len(self.deg_summary) > 0 else 0
         return (
             f"DEGEvaluationResult(n_contexts={len(self.context_results)}, "
-            f"
+            f"avg_degs={n_degs_avg:.1f}, "
+            f"settings={self.settings.get('deg_method', 'unknown')})"
         )
 
 
 class DEGEvaluator:
     """
-    Evaluator that computes metrics on DEGs
+    Evaluator that computes metrics on DEGs with comparison to all genes.
 
     This evaluator:
     1. Detects DEGs for each perturbation context
-    2. Computes distributional metrics
-    3.
+    2. Computes distributional metrics on BOTH DEG genes AND all genes
+    3. Provides comparison between DEG-focused and all-genes evaluation
+    4. Reports per-context and aggregated results
 
     Parameters
     ----------
@@ -126,11 +231,15 @@ class DEGEvaluator:
     deg_method : str
         DEG detection method: "welch", "student", "wilcoxon", "logfc"
     pval_threshold : float
-        P-value threshold for DEG significance
+        P-value threshold for DEG significance (default: 0.05)
     lfc_threshold : float
-        Log2 fold change threshold
+        Log2 fold change threshold (default: 0.5)
+    n_top_degs : int, optional
+        If set, use only top N DEGs by significance instead of thresholds
     min_degs : int
-        Minimum DEGs required to compute metrics
+        Minimum DEGs required to compute DEG-specific metrics (default: 5)
+    compute_all_genes : bool
+        Whether to also compute metrics on all genes (default: True)
     metrics : List[str], optional
         Metrics to compute. Default: all supported metrics.
     n_jobs : int
@@ -142,15 +251,39 @@ class DEGEvaluator:
 
     Examples
     --------
+    >>> # Basic usage - computes both DEG and all-genes metrics
     >>> evaluator = DEGEvaluator(
     ...     real_data, generated_data,
     ...     real_obs, generated_obs,
     ...     condition_columns=["perturbation"],
-    ...     deg_method="welch",
-    ...     device="cuda",
     ... )
     >>> results = evaluator.evaluate()
-    >>> results.
+    >>> print(results.comparison_summary)  # DEG vs all-genes comparison
+
+    >>> # Use top 100 DEGs only
+    >>> evaluator = DEGEvaluator(
+    ...     real_data, generated_data,
+    ...     real_obs, generated_obs,
+    ...     condition_columns=["perturbation"],
+    ...     n_top_degs=100,  # Use top 100 most significant DEGs
+    ... )
+
+    >>> # Stricter thresholds
+    >>> evaluator = DEGEvaluator(
+    ...     real_data, generated_data,
+    ...     real_obs, generated_obs,
+    ...     condition_columns=["perturbation"],
+    ...     pval_threshold=0.01,  # More stringent p-value
+    ...     lfc_threshold=1.0,    # log2 FC > 1 (2-fold change)
+    ... )
+
+    >>> # DEGs only (no all-genes metrics for speed)
+    >>> evaluator = DEGEvaluator(
+    ...     real_data, generated_data,
+    ...     real_obs, generated_obs,
+    ...     condition_columns=["perturbation"],
+    ...     compute_all_genes=False,
+    ... )
     """
 
     # Supported metrics
@@ -176,7 +309,9 @@ class DEGEvaluator:
         deg_method: DEGMethod = "welch",
         pval_threshold: float = 0.05,
         lfc_threshold: float = 0.5,
+        n_top_degs: Optional[int] = None,
         min_degs: int = 5,
+        compute_all_genes: bool = True,
         metrics: Optional[List[str]] = None,
         n_jobs: int = 1,
         device: str = "cpu",
@@ -187,15 +322,23 @@ class DEGEvaluator:
         self.real_obs = real_obs.reset_index(drop=True)
         self.generated_obs = generated_obs.reset_index(drop=True)
         self.condition_columns = condition_columns
+        self.n_genes = real_data.shape[1]
         self.gene_names = gene_names if gene_names is not None else np.array(
-            [f"Gene_{i}" for i in range(
+            [f"Gene_{i}" for i in range(self.n_genes)]
         )
         self.control_key = control_key
         self.perturbation_column = perturbation_column or condition_columns[0]
-
-
-        self.
-
+
+        # DEG settings
+        self.deg_settings = DEGSettings(
+            method=deg_method,
+            pval_threshold=pval_threshold,
+            lfc_threshold=lfc_threshold,
+            n_top_degs=n_top_degs,
+            min_degs=min_degs,
+        )
+
+        self.compute_all_genes = compute_all_genes
         self.metrics = metrics or self.SUPPORTED_METRICS
         self.n_jobs = n_jobs
         self.device = device
@@ -224,7 +367,11 @@ class DEGEvaluator:
         }
 
         self._log(f"DEGEvaluator initialized with {len(self.context_evaluator)} contexts")
-        self._log(f"
+        self._log(f"DEG settings: method={deg_method}, pval<{pval_threshold}, |lfc|>{lfc_threshold}")
+        if n_top_degs is not None:
+            self._log(f"  Using top {n_top_degs} DEGs by significance")
+        if compute_all_genes:
+            self._log("Will compute metrics on BOTH DEGs and all genes")
 
     def _log(self, msg: str) -> None:
         """Print if verbose."""
@@ -237,78 +384,87 @@ class DEGEvaluator:
         perturbed: np.ndarray,
     ) -> DEGResult:
         """Compute DEGs using configured method and device."""
-
+        deg_result = compute_degs_auto(
             control=control,
             perturbed=perturbed,
             gene_names=self.gene_names,
-            method=self.
-            pval_threshold=self.pval_threshold,
-            lfc_threshold=self.lfc_threshold,
+            method=self.deg_settings.method,
+            pval_threshold=self.deg_settings.pval_threshold,
+            lfc_threshold=self.deg_settings.lfc_threshold,
             n_jobs=self.n_jobs,
             device=self.device,
         )
+
+        # If n_top_degs is set, limit to top N by significance
+        if self.deg_settings.n_top_degs is not None:
+            deg_result = self._filter_top_degs(deg_result, self.deg_settings.n_top_degs)
+
+        return deg_result
 
-    def
-
-
-
-        deg_indices: np.ndarray,
-    ) -> Dict[str, float]:
-        """Compute metrics on DEG genes only."""
-        if len(deg_indices) < self.min_degs:
-            return {m: np.nan for m in self.metrics}
+    def _filter_top_degs(self, deg_result: DEGResult, n_top: int) -> DEGResult:
+        """Filter DEG result to keep only top N most significant DEGs."""
+        if deg_result.n_degs <= n_top:
+            return deg_result  # Already fewer than n_top
 
-        #
-
-
+        # Sort DEGs by adjusted p-value (lower is more significant)
+        deg_pvals = deg_result.pvalues_adj[deg_result.is_deg]
+        deg_indices = deg_result.deg_indices
 
-
+        # Get indices of top N most significant
+        top_n_order = np.argsort(deg_pvals)[:n_top]
+        top_deg_indices = deg_indices[top_n_order]
 
-
-
-
-
-            metric = self._metric_objects[metric_name]
-
-            try:
-                # Compute per-gene and aggregate
-                per_gene = metric.compute_per_gene(real_degs, gen_degs)
-                results[metric_name] = float(np.nanmean(per_gene))
-            except Exception as e:
-                if self.verbose:
-                    self._log(f"Warning: {metric_name} failed: {e}")
-                results[metric_name] = np.nan
+        # Create new is_deg mask
+        new_is_deg = np.zeros(len(deg_result.is_deg), dtype=bool)
+        new_is_deg[top_deg_indices] = True
 
-
+        # Create modified DEGResult with all required fields
+        return DEGResult(
+            gene_names=deg_result.gene_names,
+            pvalues=deg_result.pvalues,
+            pvalues_adj=deg_result.pvalues_adj,
+            log_fold_changes=deg_result.log_fold_changes,
+            mean_control=deg_result.mean_control,
+            mean_perturbed=deg_result.mean_perturbed,
+            is_deg=new_is_deg,
+            n_degs=n_top,
+            method=deg_result.method,
+            pval_threshold=deg_result.pval_threshold,
+            lfc_threshold=deg_result.lfc_threshold,
+            deg_indices=top_deg_indices,
+        )
 
-    def
+    def _compute_metrics_on_genes(
         self,
         real: np.ndarray,
         generated: np.ndarray,
-
+        gene_indices: Optional[np.ndarray] = None,
+        min_genes: int = 1,
     ) -> Dict[str, float]:
-        """Compute metrics
-
-
-
-
-
+        """Compute metrics on specified genes (or all if indices is None)."""
+        # Slice to selected genes
+        if gene_indices is not None:
+            if len(gene_indices) < min_genes:
+                return {m: np.nan for m in self.metrics}
+            real_subset = real[:, gene_indices]
+            gen_subset = generated[:, gene_indices]
+        else:
+            real_subset = real
+            gen_subset = generated
 
         results = {}
-        backends = get_available_backends()
 
         # Use vectorized implementations where available
        if "wasserstein_1" in self.metrics:
            try:
-                w1_per_gene = vectorized_wasserstein1(
+                w1_per_gene = vectorized_wasserstein1(real_subset, gen_subset)
                results["wasserstein_1"] = float(np.nanmean(w1_per_gene))
            except Exception:
                results["wasserstein_1"] = np.nan
 
        if "mmd" in self.metrics:
            try:
-                mmd_per_gene = vectorized_mmd(
+                mmd_per_gene = vectorized_mmd(real_subset, gen_subset)
                results["mmd"] = float(np.nanmean(mmd_per_gene))
            except Exception:
                results["mmd"] = np.nan
@@ -322,7 +478,7 @@ class DEGEvaluator:
 
             metric = self._metric_objects[metric_name]
             try:
-                per_gene = metric.compute_per_gene(
+                per_gene = metric.compute_per_gene(real_subset, gen_subset)
                 results[metric_name] = float(np.nanmean(per_gene))
             except Exception:
                 results[metric_name] = np.nan
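For reference, the top-N selection added in `_filter_top_degs` reduces to ranking the threshold-passing gene indices by adjusted p-value and keeping the N smallest. A standalone numpy sketch of that selection step, with toy values chosen for illustration:

```python
import numpy as np

# Toy adjusted p-values for 6 genes (hypothetical values)
pvalues_adj = np.array([0.001, 0.20, 0.0004, 0.03, 0.01, 0.75])

is_deg = pvalues_adj < 0.05             # threshold-based DEG mask
deg_indices = np.flatnonzero(is_deg)    # -> [0 2 3 4]
deg_pvals = pvalues_adj[is_deg]

n_top = 2
top_n_order = np.argsort(deg_pvals)[:n_top]  # most significant first
top_deg_indices = deg_indices[top_n_order]

print(top_deg_indices)  # [2 0] -- the two smallest adjusted p-values
```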
@@ -333,17 +489,23 @@ class DEGEvaluator:
         """
         Run DEG-focused evaluation on all contexts.
 
+        Returns both DEG-only and all-genes metrics for comparison.
+
         Returns
         -------
         DEGEvaluationResult
-            Complete evaluation results with
+            Complete evaluation results with:
+            - Per-context DEG and all-genes metrics
+            - Aggregated metrics
+            - DEG summary
+            - Comparison between DEG and all-genes evaluation
         """
-        context_results = []
+        context_results: List[ContextMetrics] = []
 
         perturbation_contexts = self.context_evaluator.get_perturbation_contexts()
         n_contexts = len(perturbation_contexts)
 
-        self._log(f"
+        self._log(f"\nEvaluating {n_contexts} perturbation contexts...")
 
         for i, context in enumerate(perturbation_contexts):
             context_id = get_context_id(context)
@@ -370,17 +532,30 @@ class DEGEvaluator:
                 print(f"{deg_result.n_degs} DEGs", end="... ")
 
             # Compute metrics on DEGs
-
-                real_pert, gen_pert,
+            deg_metrics = self._compute_metrics_on_genes(
+                real_pert, gen_pert,
+                gene_indices=deg_result.deg_indices,
+                min_genes=self.deg_settings.min_degs,
             )
 
-
+            # Compute metrics on all genes if requested
+            if self.compute_all_genes:
+                all_genes_metrics = self._compute_metrics_on_genes(
+                    real_pert, gen_pert,
+                    gene_indices=None,  # All genes
+                )
+            else:
+                all_genes_metrics = {}
+
+            ctx_result = ContextMetrics(
                 context_id=context_id,
                 context_values=context,
                 n_samples_real=len(real_pert),
                 n_samples_gen=len(gen_pert),
+                n_genes_total=self.n_genes,
                 deg_result=deg_result,
-
+                deg_metrics=deg_metrics,
+                all_genes_metrics=all_genes_metrics,
             )
             context_results.append(ctx_result)
 
@@ -393,52 +568,10 @@ class DEGEvaluator:
                 continue
 
         # Build result DataFrames
-
-
-
-
-                **ctx_result.context_values,
-                "n_samples_real": ctx_result.n_samples_real,
-                "n_samples_gen": ctx_result.n_samples_gen,
-                "n_degs": ctx_result.deg_result.n_degs if ctx_result.deg_result else 0,
-                **ctx_result.metrics,
-            }
-            expanded_data.append(row)
-
-        expanded_metrics = pd.DataFrame(expanded_data)
-
-        # Aggregated metrics
-        if len(expanded_metrics) > 0:
-            agg_data = {
-                "n_contexts": len(context_results),
-                "total_samples_real": expanded_metrics["n_samples_real"].sum(),
-                "total_samples_gen": expanded_metrics["n_samples_gen"].sum(),
-                "mean_n_degs": expanded_metrics["n_degs"].mean(),
-                "median_n_degs": expanded_metrics["n_degs"].median(),
-            }
-            for metric in self.metrics:
-                if metric in expanded_metrics.columns:
-                    agg_data[f"{metric}_mean"] = expanded_metrics[metric].mean()
-                    agg_data[f"{metric}_std"] = expanded_metrics[metric].std()
-                    agg_data[f"{metric}_median"] = expanded_metrics[metric].median()
-
-            aggregated_metrics = pd.DataFrame([agg_data])
-        else:
-            aggregated_metrics = pd.DataFrame()
-
-        # DEG summary
-        deg_summary_data = []
-        for ctx_result in context_results:
-            if ctx_result.deg_result is not None:
-                deg_summary_data.append({
-                    "context_id": ctx_result.context_id,
-                    **ctx_result.context_values,
-                    "n_degs": ctx_result.deg_result.n_degs,
-                    "n_upregulated": (ctx_result.deg_result.log_fold_changes[ctx_result.deg_result.is_deg] > 0).sum(),
-                    "n_downregulated": (ctx_result.deg_result.log_fold_changes[ctx_result.deg_result.is_deg] < 0).sum(),
-                    "mean_abs_lfc": np.abs(ctx_result.deg_result.log_fold_changes[ctx_result.deg_result.is_deg]).mean() if ctx_result.deg_result.n_degs > 0 else np.nan,
-                })
-        deg_summary = pd.DataFrame(deg_summary_data)
+        expanded_metrics = self._build_expanded_metrics(context_results)
+        aggregated_metrics = self._build_aggregated_metrics(expanded_metrics)
+        deg_summary = self._build_deg_summary(context_results)
+        comparison_summary = self._build_comparison_summary(expanded_metrics)
 
         self._log(f"\nEvaluation complete: {len(context_results)} contexts evaluated")
 
|
@@ -447,17 +580,122 @@ class DEGEvaluator:
|
|
|
447
580
|
aggregated_metrics=aggregated_metrics,
|
|
448
581
|
expanded_metrics=expanded_metrics,
|
|
449
582
|
deg_summary=deg_summary,
|
|
583
|
+
comparison_summary=comparison_summary,
|
|
450
584
|
gene_names=self.gene_names,
|
|
451
585
|
settings={
|
|
452
|
-
|
|
453
|
-
"
|
|
454
|
-
"lfc_threshold": self.lfc_threshold,
|
|
455
|
-
"min_degs": self.min_degs,
|
|
586
|
+
**self.deg_settings.to_dict(),
|
|
587
|
+
"compute_all_genes": self.compute_all_genes,
|
|
456
588
|
"metrics": self.metrics,
|
|
457
589
|
"device": self.device,
|
|
458
590
|
"n_jobs": self.n_jobs,
|
|
459
591
|
},
|
|
460
592
|
)
|
|
593
|
+
|
|
594
|
+
def _build_expanded_metrics(self, context_results: List[ContextMetrics]) -> pd.DataFrame:
|
|
595
|
+
"""Build expanded metrics DataFrame with both DEG and all-genes columns."""
|
|
596
|
+
data = []
|
|
597
|
+
for ctx in context_results:
|
|
598
|
+
row = {
|
|
599
|
+
"context_id": ctx.context_id,
|
|
600
|
+
**ctx.context_values,
|
|
601
|
+
"n_samples_real": ctx.n_samples_real,
|
|
602
|
+
"n_samples_gen": ctx.n_samples_gen,
|
|
603
|
+
"n_genes_total": ctx.n_genes_total,
|
|
604
|
+
"n_degs": ctx.n_degs,
|
|
605
|
+
}
|
|
606
|
+
|
|
607
|
+
# DEG metrics with prefix
|
|
608
|
+
for metric_name, value in ctx.deg_metrics.items():
|
|
609
|
+
row[f"deg_{metric_name}"] = value
|
|
610
|
+
|
|
611
|
+
# All-genes metrics with prefix
|
|
612
|
+
for metric_name, value in ctx.all_genes_metrics.items():
|
|
613
|
+
row[f"all_{metric_name}"] = value
|
|
614
|
+
|
|
615
|
+
data.append(row)
|
|
616
|
+
|
|
617
|
+
return pd.DataFrame(data)
|
|
618
|
+
|
|
619
|
+
def _build_aggregated_metrics(self, expanded_metrics: pd.DataFrame) -> pd.DataFrame:
|
|
620
|
+
"""Build aggregated metrics DataFrame."""
|
|
621
|
+
if len(expanded_metrics) == 0:
|
|
622
|
+
return pd.DataFrame()
|
|
623
|
+
|
|
624
|
+
agg_data = {
|
|
625
|
+
"n_contexts": len(expanded_metrics),
|
|
626
|
+
"total_samples_real": expanded_metrics["n_samples_real"].sum(),
|
|
627
|
+
"total_samples_gen": expanded_metrics["n_samples_gen"].sum(),
|
|
628
|
+
"mean_n_degs": expanded_metrics["n_degs"].mean(),
|
|
629
|
+
"median_n_degs": expanded_metrics["n_degs"].median(),
|
|
630
|
+
"min_n_degs": expanded_metrics["n_degs"].min(),
|
|
631
|
+
"max_n_degs": expanded_metrics["n_degs"].max(),
|
|
632
|
+
}
|
|
633
|
+
|
|
634
|
+
# Aggregate DEG metrics
|
|
635
|
+
for metric in self.metrics:
|
|
636
|
+
col = f"deg_{metric}"
|
|
637
|
+
if col in expanded_metrics.columns:
|
|
638
|
+
agg_data[f"deg_{metric}_mean"] = expanded_metrics[col].mean()
|
|
639
|
+
agg_data[f"deg_{metric}_std"] = expanded_metrics[col].std()
|
|
640
|
+
|
|
641
|
+
# Aggregate all-genes metrics
|
|
642
|
+
for metric in self.metrics:
|
|
643
|
+
col = f"all_{metric}"
|
|
644
|
+
if col in expanded_metrics.columns:
|
|
645
|
+
agg_data[f"all_{metric}_mean"] = expanded_metrics[col].mean()
|
|
646
|
+
agg_data[f"all_{metric}_std"] = expanded_metrics[col].std()
|
|
647
|
+
|
|
648
|
+
return pd.DataFrame([agg_data])
|
|
649
|
+
|
|
650
|
+
def _build_deg_summary(self, context_results: List[ContextMetrics]) -> pd.DataFrame:
|
|
651
|
+
"""Build DEG summary DataFrame."""
|
|
652
|
+
data = []
|
|
653
|
+
for ctx in context_results:
|
|
654
|
+
if ctx.deg_result is not None:
|
|
655
|
+
deg_lfcs = ctx.deg_result.log_fold_changes[ctx.deg_result.is_deg]
|
|
656
|
+
data.append({
|
|
657
|
+
"context_id": ctx.context_id,
|
|
658
|
+
**ctx.context_values,
|
|
659
|
+
"n_degs": ctx.n_degs,
|
|
660
|
+
"n_genes_total": ctx.n_genes_total,
|
|
661
|
+
"deg_fraction": ctx.n_degs / ctx.n_genes_total if ctx.n_genes_total > 0 else 0,
|
|
662
|
+
"n_upregulated": (deg_lfcs > 0).sum(),
|
|
663
|
+
"n_downregulated": (deg_lfcs < 0).sum(),
|
|
664
|
+
"mean_abs_lfc": float(np.abs(deg_lfcs).mean()) if len(deg_lfcs) > 0 else np.nan,
|
|
665
|
+
"max_abs_lfc": float(np.abs(deg_lfcs).max()) if len(deg_lfcs) > 0 else np.nan,
|
|
666
|
+
})
|
|
667
|
+
return pd.DataFrame(data)
|
|
668
|
+
|
|
669
|
+
def _build_comparison_summary(self, expanded_metrics: pd.DataFrame) -> pd.DataFrame:
|
|
670
|
+
"""Build comparison summary between DEG and all-genes metrics."""
|
|
671
|
+
if len(expanded_metrics) == 0:
|
|
672
|
+
return pd.DataFrame()
|
|
673
|
+
|
|
674
|
+
comparison_data = []
|
|
675
|
+
for metric in self.metrics:
|
|
676
|
+
deg_col = f"deg_{metric}"
|
|
677
|
+
all_col = f"all_{metric}"
|
|
678
|
+
|
|
679
|
+
if deg_col in expanded_metrics.columns and all_col in expanded_metrics.columns:
|
|
680
|
+
deg_values = expanded_metrics[deg_col].dropna()
|
|
681
|
+
all_values = expanded_metrics[all_col].dropna()
|
|
682
|
+
|
|
683
|
+
if len(deg_values) > 0 and len(all_values) > 0:
|
|
684
|
+
deg_mean = deg_values.mean()
|
|
685
|
+
all_mean = all_values.mean()
|
|
686
|
+
|
|
687
|
+
comparison_data.append({
|
|
688
|
+
"metric": metric,
|
|
689
|
+
"deg_mean": deg_mean,
|
|
690
|
+
"deg_std": deg_values.std(),
|
|
691
|
+
"all_mean": all_mean,
|
|
692
|
+
"all_std": all_values.std(),
|
|
693
|
+
"difference": deg_mean - all_mean,
|
|
694
|
+
"ratio": deg_mean / all_mean if all_mean != 0 else np.nan,
|
|
695
|
+
"n_contexts": len(deg_values),
|
|
696
|
+
})
|
|
697
|
+
|
|
698
|
+
return pd.DataFrame(comparison_data)
|
|
461
699
|
|
|
462
700
|
|
|
463
701
|
def evaluate_degs(
|
|
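The new comparison summary reduces each metric to its mean over contexts for the deg_* and all_* columns, then reports their difference and ratio. A small pandas sketch of that reduction on a hypothetical expanded-metrics table (values taken from the README example for illustration only):

```python
import numpy as np
import pandas as pd

# Hypothetical per-context table with deg_*/all_* metric columns
expanded = pd.DataFrame({
    "context_id": ["TypeA_drug1", "TypeA_drug2"],
    "deg_wasserstein_1": [5.42, 4.38],
    "all_wasserstein_1": [0.69, 0.71],
})

deg_vals = expanded["deg_wasserstein_1"].dropna()
all_vals = expanded["all_wasserstein_1"].dropna()
deg_mean, all_mean = deg_vals.mean(), all_vals.mean()

# One comparison row per metric: absolute gap and ratio of the two means
summary = pd.DataFrame([{
    "metric": "wasserstein_1",
    "deg_mean": deg_mean,
    "all_mean": all_mean,
    "difference": deg_mean - all_mean,
    "ratio": deg_mean / all_mean if all_mean != 0 else np.nan,
    "n_contexts": len(deg_vals),
}])
print(summary)
```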
@@ -472,51 +710,93 @@ def evaluate_degs(
     deg_method: DEGMethod = "welch",
     pval_threshold: float = 0.05,
     lfc_threshold: float = 0.5,
+    n_top_degs: Optional[int] = None,
+    min_degs: int = 5,
+    compute_all_genes: bool = True,
     metrics: Optional[List[str]] = None,
     n_jobs: int = 1,
     device: str = "auto",
     verbose: bool = True,
 ) -> DEGEvaluationResult:
     """
-    Convenience function for DEG-focused evaluation.
+    Convenience function for DEG-focused evaluation with full control.
+
+    Computes metrics on both DEGs and all genes for comparison.
 
     Parameters
     ----------
     real_data : np.ndarray
-        Real expression matrix
+        Real expression matrix (n_samples, n_genes)
     generated_data : np.ndarray
-        Generated expression matrix
+        Generated expression matrix (n_samples, n_genes)
     real_obs : pd.DataFrame
-        Real data metadata
+        Real data metadata with condition columns
     generated_obs : pd.DataFrame
-        Generated data metadata
+        Generated data metadata with condition columns
     condition_columns : List[str]
-        Columns defining contexts
+        Columns defining contexts (e.g., ["cell_type", "perturbation"])
     gene_names : np.ndarray, optional
-        Gene names
+        Gene names for output
     control_key : str
-        Control condition identifier
+        Control condition identifier (default: "control")
     perturbation_column : str, optional
         Column containing perturbation info. If None, uses first condition column.
     deg_method : str
-        DEG detection method
+        DEG detection method: "welch", "student", "wilcoxon", "logfc"
     pval_threshold : float
-
+        Adjusted p-value threshold (default: 0.05)
     lfc_threshold : float
-
+        Absolute log2 fold change threshold (default: 0.5)
+    n_top_degs : int, optional
+        If set, use only top N DEGs by significance (overrides thresholds)
+    min_degs : int
+        Minimum DEGs to compute DEG metrics (default: 5)
+    compute_all_genes : bool
+        Also compute metrics on all genes for comparison (default: True)
     metrics : List[str], optional
-        Metrics to compute
+        Metrics to compute. Default: all supported.
     n_jobs : int
-        Parallel CPU jobs
+        Parallel CPU jobs (default: 1)
     device : str
-        Compute device
+        Compute device: "cpu", "cuda", "mps", "auto" (default: "auto")
     verbose : bool
-        Print progress
+        Print progress (default: True)
 
     Returns
     -------
     DEGEvaluationResult
-
+        Complete evaluation results including:
+        - expanded_metrics: Per-context metrics for DEGs and all genes
+        - aggregated_metrics: Summary statistics
+        - deg_summary: DEG detection summary
+        - comparison_summary: DEG vs all-genes comparison
+
+    Examples
+    --------
+    >>> # Basic usage with default thresholds
+    >>> results = evaluate_degs(
+    ...     real_data, generated_data,
+    ...     real_obs, generated_obs,
+    ...     condition_columns=["perturbation"],
+    ... )
+    >>> print(results.comparison_summary)
+
+    >>> # Top 50 DEGs only
+    >>> results = evaluate_degs(
+    ...     real_data, generated_data,
+    ...     real_obs, generated_obs,
+    ...     condition_columns=["perturbation"],
+    ...     n_top_degs=50,
+    ... )
+
+    >>> # Strict thresholds
+    >>> results = evaluate_degs(
+    ...     real_data, generated_data,
+    ...     real_obs, generated_obs,
+    ...     condition_columns=["perturbation"],
+    ...     pval_threshold=0.01,
+    ...     lfc_threshold=1.0,  # 2-fold change
+    ... )
     """
     evaluator = DEGEvaluator(
         real_data=real_data,
@@ -530,6 +810,9 @@ def evaluate_degs(
         deg_method=deg_method,
         pval_threshold=pval_threshold,
         lfc_threshold=lfc_threshold,
+        n_top_degs=n_top_degs,
+        min_degs=min_degs,
+        compute_all_genes=compute_all_genes,
         metrics=metrics,
         n_jobs=n_jobs,
         device=device,
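Putting the new parameters together, `evaluate_degs` can be exercised as in the following hedged sketch. The synthetic data and the "perturbation"/"drug1" labels are chosen for illustration, and the snippet assumes the gengeneeval 0.4.1 wheel is installed; it has not been verified against the released package.

```python
import numpy as np
import pandas as pd
from geneval import evaluate_degs  # assumption: gengeneeval 0.4.1 installed

rng = np.random.default_rng(0)
n_samples, n_genes = 300, 200
real = rng.normal(size=(n_samples, n_genes))   # stand-in for real expression
gen = rng.normal(size=(n_samples, n_genes))    # stand-in for generated expression

# Half control, half perturbed; "perturbation" is an illustrative column name
obs = pd.DataFrame({"perturbation": ["control"] * 150 + ["drug1"] * 150})

results = evaluate_degs(
    real_data=real, generated_data=gen,
    real_obs=obs, generated_obs=obs.copy(),
    condition_columns=["perturbation"],
    control_key="control",
    n_top_degs=50,           # new in 0.4.1: keep at most the 50 most significant DEGs
    compute_all_genes=True,  # also score all genes so comparison_summary is populated
    device="cpu",
)
print(results.comparison_summary)
```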
{gengeneeval-0.4.0.dist-info → gengeneeval-0.4.1.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: gengeneeval
-Version: 0.4.0
+Version: 0.4.1
 Summary: Comprehensive evaluation of generated gene expression data. Computes metrics between real and generated datasets with support for condition matching, DEG-focused evaluation, per-context analysis, train/test splits, memory-efficient lazy loading, CPU parallelization, GPU acceleration, and publication-quality visualizations.
 License: MIT
 License-File: LICENSE
@@ -256,6 +256,8 @@ GenEval supports **Differentially Expressed Genes (DEG)-focused evaluation**, co
 #### Key Features
 
 - **Fast DEG detection**: Vectorized Welch's t-test, Student's t-test, or Wilcoxon rank-sum
+- **DEG vs all-genes comparison**: Compute metrics on both and compare
+- **Flexible DEG selection**: Top N by significance, or threshold-based filtering
 - **Per-context evaluation**: Automatically evaluates each (covariate × perturbation) combination
 - **GPU acceleration**: DEG detection and metrics on GPU for large datasets
 - **Comprehensive reporting**: Aggregated and expanded results with visualizations
@@ -266,7 +268,7 @@ GenEval supports **Differentially Expressed Genes (DEG)-focused evaluation**, co
 from geneval import evaluate_degs
 import pandas as pd
 
-# Evaluate with DEG-focused metrics
+# Evaluate with DEG-focused metrics (computes both DEG and all-genes by default)
 results = evaluate_degs(
     real_data=real_adata.X,  # (n_samples, n_genes)
     generated_data=gen_adata.X,
@@ -276,29 +278,62 @@ results = evaluate_degs(
     control_key="control",  # Value indicating control samples
     perturbation_column="perturbation",
     deg_method="welch",  # or "student", "wilcoxon", "logfc"
-    pval_threshold=0.05,
+    pval_threshold=0.05,  # Significance threshold
     lfc_threshold=0.5,  # log2 fold change threshold
+    compute_all_genes=True,  # Also compute metrics on all genes
     device="cuda",  # GPU acceleration
 )
 
-#
-print(results.
-
+# Compare DEG-only vs all-genes metrics
+print(results.comparison_summary)
+# metric         deg_mean  all_mean  difference  ratio
+# wasserstein_1  5.34      0.69      4.65        7.74
+# mmd            1.14      0.13      1.02        9.00
+
+# Access per-context results
+print(results.expanded_metrics)  # Has deg_* and all_* columns
 print(results.deg_summary)  # DEG counts per context
 
 # Save results with plots
 results.save("deg_evaluation/")
 ```
 
+#### DEG Selection Control
+
+```python
+# Option 1: Top N most significant DEGs
+results = evaluate_degs(
+    ...,
+    n_top_degs=50,  # Use only top 50 DEGs by adjusted p-value
+)
+
+# Option 2: Stricter thresholds
+results = evaluate_degs(
+    ...,
+    pval_threshold=0.01,  # More stringent p-value
+    lfc_threshold=1.0,    # 2-fold change minimum
+)
+
+# Option 3: DEGs only (skip all-genes metrics for speed)
+results = evaluate_degs(
+    ...,
+    compute_all_genes=False,
+)
+
+# Get DEG-only or all-genes metrics separately
+deg_only = results.get_deg_only_metrics()
+all_genes = results.get_all_genes_metrics()
+```
+
 #### Per-Context Evaluation
 
 When multiple condition columns are provided (e.g., `["cell_type", "perturbation"]`), GenEval evaluates **every combination** separately:
 
-| Context | n_DEGs |
-
-| TypeA_drug1 | 234 |
-| TypeA_drug2 | 189 |
-| TypeB_drug1 | 312 |
+| Context | n_DEGs | deg_W1 | all_W1 | deg_MMD | all_MMD |
+|---------|--------|--------|--------|---------|---------|
+| TypeA_drug1 | 234 | 5.42 | 0.69 | 1.03 | 0.13 |
+| TypeA_drug2 | 189 | 4.38 | 0.71 | 0.92 | 0.12 |
+| TypeB_drug1 | 312 | 6.51 | 0.68 | 1.21 | 0.14 |
 
 If only `perturbation` column is provided, evaluation is done per-perturbation.
 
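The per-context table above is driven by the unique combinations of the condition columns: each distinct (cell_type, perturbation) pair becomes one evaluation context. An illustrative pandas sketch of that enumeration (toy data; not the library's internal code):

```python
import pandas as pd

# Toy metadata with two condition columns
obs = pd.DataFrame({
    "cell_type":    ["TypeA", "TypeA", "TypeA", "TypeB"],
    "perturbation": ["drug1", "drug2", "drug1", "drug1"],
})

# Each unique (cell_type, perturbation) pair is one evaluation context
contexts = obs[["cell_type", "perturbation"]].drop_duplicates()
for _, ctx in contexts.iterrows():
    print(f"{ctx.cell_type}_{ctx.perturbation}")
# TypeA_drug1
# TypeA_drug2
# TypeB_drug1
```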
{gengeneeval-0.4.0.dist-info → gengeneeval-0.4.1.dist-info}/RECORD
CHANGED
@@ -1,4 +1,4 @@
-geneval/__init__.py,sha256=
+geneval/__init__.py,sha256=UD-fl1x0J0VUTyktgvUzCqaU1kLaU2vmYALfdzm-TzQ,5418
 geneval/cli.py,sha256=0ai0IGyn3SSmEnfLRJhcr0brvUxuNZHE4IXod7jvosU,9977
 geneval/config.py,sha256=gkCjs_gzPWgUZNcmSR3Y70XQCAZ1m9AKLueaM-x8bvw,3729
 geneval/core.py,sha256=No0DP8bNR6LedfCWEedY9C5r_c4M14rvSPaGZqbxc94,1155
@@ -6,10 +6,10 @@ geneval/data/__init__.py,sha256=NQUPVpUnBIabrTH5TuRk0KE9S7sVO5QetZv-MCQmZuw,827
 geneval/data/gene_expression_datamodule.py,sha256=XiBIdf68JZ-3S-FaZsrQlBJA7qL9uUXo2C8y0r4an5M,8009
 geneval/data/lazy_loader.py,sha256=5fTRVjPjcWvYXV-uPWFUF2Nn9rHRdD8lygAUkCW8wOM,20677
 geneval/data/loader.py,sha256=zpRmwGZ4PJkB3rpXXRCMFtvMi4qvUrPkKmvIlGjfRpY,14555
-geneval/deg/__init__.py,sha256=
+geneval/deg/__init__.py,sha256=iNKvtbumTA-A1usWhHIP1rbRVNkje5tN5x81FzD6CbI,1577
 geneval/deg/context.py,sha256=_9gnWnRqqCZUDlegV2sT_rQrw8OeP1TIE9NZjNcI0ig,9069
 geneval/deg/detection.py,sha256=gDdHOyFLOfl_B0xutS3KVFy53sreJ19N33B0RRI01wo,18119
-geneval/deg/evaluator.py,sha256=
+geneval/deg/evaluator.py,sha256=uPduuWovUD6B_vie4RomH-F9MgrtaqQbtjmJlvEeDYM,30493
 geneval/deg/visualization.py,sha256=9lWW9vRH_FbkIjJrf1MPobU1Yu_CAh6aw60S7g2Qe2k,10448
 geneval/evaluator.py,sha256=WgdrgqOcGYT35k1keiFEIIRIj2CQaD2DsmBpq9hcLrI,13440
 geneval/evaluators/__init__.py,sha256=i11sHvhsjEAeI3Aw9zFTPmCYuqkGxzTHggAKehe3HQ0,160
@@ -33,8 +33,8 @@ geneval/utils/preprocessing.py,sha256=1Cij1O2dwDR6_zh5IEgLPq3jEmV8VfIRjfQrHiKe3M
 geneval/visualization/__init__.py,sha256=LN19jl5xV4WVJTePaOUHWvKZ_pgDFp1chhcklGkNtm8,792
 geneval/visualization/plots.py,sha256=3K94r3x5NjIUZ-hYVQIivO63VkLOvDWl-BLB_qL2pSY,15008
 geneval/visualization/visualizer.py,sha256=lX7K0j20nAsgdtOOdbxLdLKYAfovEp3hNAnZOjFTCq0,36670
-gengeneeval-0.4.
-gengeneeval-0.4.
-gengeneeval-0.4.
-gengeneeval-0.4.
-gengeneeval-0.4.
+gengeneeval-0.4.1.dist-info/METADATA,sha256=1xoULyzbHjzOKmTEQcx5fRv7DXEMJwe2XqqKxQhCi2Q,14041
+gengeneeval-0.4.1.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
+gengeneeval-0.4.1.dist-info/entry_points.txt,sha256=xTkwnNa2fP0w1uGVsafzRTaCeuBSWLlNO-1CN8uBSK0,43
+gengeneeval-0.4.1.dist-info/licenses/LICENSE,sha256=RDHgHDI4rSDq35R4CAC3npy86YUnmZ81ecO7aHfmmGA,1073
+gengeneeval-0.4.1.dist-info/RECORD,,
{gengeneeval-0.4.0.dist-info → gengeneeval-0.4.1.dist-info}/WHEEL
File without changes
{gengeneeval-0.4.0.dist-info → gengeneeval-0.4.1.dist-info}/entry_points.txt
File without changes
{gengeneeval-0.4.0.dist-info → gengeneeval-0.4.1.dist-info}/licenses/LICENSE
File without changes