gengeneeval 0.2.1__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geneval/__init__.py +56 -1
- geneval/deg/__init__.py +65 -0
- geneval/deg/context.py +271 -0
- geneval/deg/detection.py +578 -0
- geneval/deg/evaluator.py +538 -0
- geneval/deg/visualization.py +376 -0
- geneval/evaluator.py +46 -0
- geneval/metrics/__init__.py +25 -0
- geneval/metrics/accelerated.py +857 -0
- {gengeneeval-0.2.1.dist-info → gengeneeval-0.4.0.dist-info}/METADATA +164 -3
- {gengeneeval-0.2.1.dist-info → gengeneeval-0.4.0.dist-info}/RECORD +14 -8
- {gengeneeval-0.2.1.dist-info → gengeneeval-0.4.0.dist-info}/WHEEL +0 -0
- {gengeneeval-0.2.1.dist-info → gengeneeval-0.4.0.dist-info}/entry_points.txt +0 -0
- {gengeneeval-0.2.1.dist-info → gengeneeval-0.4.0.dist-info}/licenses/LICENSE +0 -0
geneval/deg/evaluator.py
ADDED
|
@@ -0,0 +1,538 @@
|
|
|
1
|
+
"""
|
|
2
|
+
DEG-focused evaluator for GenGeneEval.
|
|
3
|
+
|
|
4
|
+
Computes metrics only on differentially expressed genes, with support for:
|
|
5
|
+
- Per-context evaluation (covariates × perturbations)
|
|
6
|
+
- Fast DEG detection with GPU acceleration
|
|
7
|
+
- Aggregated and expanded result reporting
|
|
8
|
+
"""
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
from typing import Optional, List, Dict, Union, Any, Literal
|
|
12
|
+
from dataclasses import dataclass, field
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
import numpy as np
|
|
15
|
+
import pandas as pd
|
|
16
|
+
import warnings
|
|
17
|
+
|
|
18
|
+
from .detection import (
|
|
19
|
+
compute_degs_fast,
|
|
20
|
+
compute_degs_gpu,
|
|
21
|
+
compute_degs_auto,
|
|
22
|
+
DEGResult,
|
|
23
|
+
DEGMethod,
|
|
24
|
+
)
|
|
25
|
+
from .context import (
|
|
26
|
+
ContextEvaluator,
|
|
27
|
+
ContextResult,
|
|
28
|
+
get_context_id,
|
|
29
|
+
get_contexts,
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
# Import metrics
|
|
33
|
+
from ..metrics.base_metric import BaseMetric
|
|
34
|
+
from ..metrics.correlation import PearsonCorrelation, SpearmanCorrelation
|
|
35
|
+
from ..metrics.distances import (
|
|
36
|
+
Wasserstein1Distance,
|
|
37
|
+
Wasserstein2Distance,
|
|
38
|
+
MMDDistance,
|
|
39
|
+
EnergyDistance,
|
|
40
|
+
)
|
|
41
|
+
from ..metrics.accelerated import (
|
|
42
|
+
get_available_backends,
|
|
43
|
+
vectorized_wasserstein1,
|
|
44
|
+
vectorized_mmd,
|
|
45
|
+
)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclass
class DEGEvaluationResult:
    """Complete DEG evaluation results.

    Attributes
    ----------
    context_results : List[ContextResult]
        Results for each context
    aggregated_metrics : pd.DataFrame
        Aggregated metrics across contexts
    expanded_metrics : pd.DataFrame
        Per-context expanded metrics
    deg_summary : pd.DataFrame
        Summary of DEG detection per context
    gene_names : np.ndarray
        All gene names
    settings : Dict
        Evaluation settings
    """
    context_results: List[ContextResult]
    aggregated_metrics: pd.DataFrame
    expanded_metrics: pd.DataFrame
    deg_summary: pd.DataFrame
    gene_names: np.ndarray
    settings: Dict[str, Any]

    def save(self, output_dir: Union[str, Path]) -> None:
        """Write all result tables as CSV files under ``output_dir``."""
        out = Path(output_dir)
        out.mkdir(parents=True, exist_ok=True)

        # Top-level summary tables, one CSV each.
        tables = {
            "deg_aggregated_metrics.csv": self.aggregated_metrics,
            "deg_expanded_metrics.csv": self.expanded_metrics,
            "deg_summary.csv": self.deg_summary,
        }
        for filename, frame in tables.items():
            frame.to_csv(out / filename)

        # One CSV of detected DEGs per context that carries a DEG result.
        deg_dir = out / "deg_per_context"
        deg_dir.mkdir(exist_ok=True)
        for ctx in self.context_results:
            if ctx.deg_result is None:
                continue
            target = deg_dir / f"{ctx.context_id}_degs.csv"
            ctx.deg_result.to_dataframe().to_csv(target)

    def __repr__(self) -> str:
        n = len(self.context_results)
        cols = list(self.aggregated_metrics.columns)
        return f"DEGEvaluationResult(n_contexts={n}, metrics={cols})"
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class DEGEvaluator:
    """
    Evaluator that computes metrics on DEGs only.

    This evaluator:
    1. Detects DEGs for each perturbation context
    2. Computes distributional metrics only on DEG genes
    3. Reports per-context and aggregated results

    Parameters
    ----------
    real_data : np.ndarray
        Real expression matrix (n_samples, n_genes)
    generated_data : np.ndarray
        Generated expression matrix (n_samples, n_genes)
    real_obs : pd.DataFrame
        Real data observation metadata
    generated_obs : pd.DataFrame
        Generated data observation metadata
    condition_columns : List[str]
        Columns defining contexts (e.g., ["cell_type", "perturbation"])
    gene_names : np.ndarray, optional
        Gene names. Defaults to "Gene_0" ... "Gene_{n_genes-1}".
    control_key : str
        Value indicating control samples (default: "control")
    perturbation_column : str, optional
        Column containing perturbation info. If None, uses first condition column.
    deg_method : str
        DEG detection method: "welch", "student", "wilcoxon", "logfc"
    pval_threshold : float
        P-value threshold for DEG significance
    lfc_threshold : float
        Log2 fold change threshold
    min_degs : int
        Minimum DEGs required to compute metrics
    metrics : List[str], optional
        Metrics to compute. Default: all supported metrics.
    n_jobs : int
        Number of parallel CPU jobs
    device : str
        Compute device: "cpu", "cuda", "mps", "auto"
    verbose : bool
        Print progress

    Examples
    --------
    >>> evaluator = DEGEvaluator(
    ...     real_data, generated_data,
    ...     real_obs, generated_obs,
    ...     condition_columns=["perturbation"],
    ...     deg_method="welch",
    ...     device="cuda",
    ... )
    >>> results = evaluator.evaluate()
    >>> results.save("output/")
    """

    # Metric names this evaluator knows how to compute.
    SUPPORTED_METRICS = [
        "wasserstein_1",
        "wasserstein_2",
        "mmd",
        "energy",
        "pearson",
        "spearman",
    ]

    def __init__(
        self,
        real_data: np.ndarray,
        generated_data: np.ndarray,
        real_obs: pd.DataFrame,
        generated_obs: pd.DataFrame,
        condition_columns: List[str],
        gene_names: Optional[np.ndarray] = None,
        control_key: str = "control",
        perturbation_column: Optional[str] = None,
        deg_method: DEGMethod = "welch",
        pval_threshold: float = 0.05,
        lfc_threshold: float = 0.5,
        min_degs: int = 5,
        metrics: Optional[List[str]] = None,
        n_jobs: int = 1,
        device: str = "cpu",
        verbose: bool = True,
    ):
        self.real_data = np.asarray(real_data, dtype=np.float32)
        self.generated_data = np.asarray(generated_data, dtype=np.float32)
        self.real_obs = real_obs.reset_index(drop=True)
        self.generated_obs = generated_obs.reset_index(drop=True)
        self.condition_columns = condition_columns
        if gene_names is not None:
            # Normalize to an ndarray so downstream fancy indexing works even
            # when a plain Python list is passed.
            self.gene_names = np.asarray(gene_names)
        else:
            self.gene_names = np.array(
                [f"Gene_{i}" for i in range(self.real_data.shape[1])]
            )
        self.control_key = control_key
        self.perturbation_column = perturbation_column or condition_columns[0]
        self.deg_method = deg_method
        self.pval_threshold = pval_threshold
        self.lfc_threshold = lfc_threshold
        self.min_degs = min_degs
        # Copy the lists: assigning the class-level SUPPORTED_METRICS directly
        # would alias it, so a later mutation of self.metrics could corrupt
        # the shared default for every instance.
        self.metrics = list(metrics) if metrics else list(self.SUPPORTED_METRICS)
        self.n_jobs = n_jobs
        self.device = device
        self.verbose = verbose

        # The context evaluator handles slicing real/generated data per
        # (covariates x perturbation) context and locating matched controls.
        self.context_evaluator = ContextEvaluator(
            real_data=self.real_data,
            generated_data=self.generated_data,
            real_obs=self.real_obs,
            generated_obs=self.generated_obs,
            condition_columns=condition_columns,
            gene_names=self.gene_names,
            control_key=control_key,
            perturbation_column=self.perturbation_column,
        )

        # One metric object per supported metric name; used as the generic
        # (non-accelerated) computation path.
        self._metric_objects = {
            "wasserstein_1": Wasserstein1Distance(),
            "wasserstein_2": Wasserstein2Distance(),
            "mmd": MMDDistance(),
            "energy": EnergyDistance(),
            "pearson": PearsonCorrelation(),
            "spearman": SpearmanCorrelation(),
        }

        self._log(f"DEGEvaluator initialized with {len(self.context_evaluator)} contexts")
        self._log(f"Perturbation contexts: {len(self.context_evaluator.get_perturbation_contexts())}")

    def _log(self, msg: str) -> None:
        """Print if verbose."""
        if self.verbose:
            print(msg)

    def _compute_degs(
        self,
        control: np.ndarray,
        perturbed: np.ndarray,
    ) -> DEGResult:
        """Compute DEGs using the configured method, thresholds and device."""
        return compute_degs_auto(
            control=control,
            perturbed=perturbed,
            gene_names=self.gene_names,
            method=self.deg_method,
            pval_threshold=self.pval_threshold,
            lfc_threshold=self.lfc_threshold,
            n_jobs=self.n_jobs,
            device=self.device,
        )

    def _object_metric(
        self,
        metric_name: str,
        real_degs: np.ndarray,
        gen_degs: np.ndarray,
    ) -> float:
        """Compute one metric via its BaseMetric object; NaN on failure.

        Per-gene values are averaged with nanmean so a few degenerate genes
        do not poison the aggregate.
        """
        metric = self._metric_objects[metric_name]
        try:
            per_gene = metric.compute_per_gene(real_degs, gen_degs)
            return float(np.nanmean(per_gene))
        except Exception as e:
            # _log already guards on self.verbose; no need to check twice.
            self._log(f"Warning: {metric_name} failed: {e}")
            return np.nan

    def _compute_metrics_on_degs(
        self,
        real: np.ndarray,
        generated: np.ndarray,
        deg_indices: np.ndarray,
    ) -> Dict[str, float]:
        """Compute all configured metrics on DEG genes only.

        Returns NaN for every metric when fewer than ``min_degs`` DEGs are
        available; silently skips metric names with no registered object.
        """
        if len(deg_indices) < self.min_degs:
            return {m: np.nan for m in self.metrics}

        # Slice both matrices down to the DEG columns.
        real_degs = real[:, deg_indices]
        gen_degs = generated[:, deg_indices]

        return {
            m: self._object_metric(m, real_degs, gen_degs)
            for m in self.metrics
            if m in self._metric_objects
        }

    def _compute_metrics_accelerated(
        self,
        real: np.ndarray,
        generated: np.ndarray,
        deg_indices: np.ndarray,
    ) -> Dict[str, float]:
        """Compute metrics, preferring vectorized implementations.

        Wasserstein-1 and MMD use the vectorized kernels from
        ``metrics.accelerated``; any remaining configured metric falls back
        to the generic per-gene object computation.
        """
        if len(deg_indices) < self.min_degs:
            return {m: np.nan for m in self.metrics}

        # Slice both matrices down to the DEG columns.
        real_degs = real[:, deg_indices]
        gen_degs = generated[:, deg_indices]

        results: Dict[str, float] = {}

        if "wasserstein_1" in self.metrics:
            try:
                w1_per_gene = vectorized_wasserstein1(real_degs, gen_degs)
                results["wasserstein_1"] = float(np.nanmean(w1_per_gene))
            except Exception:
                results["wasserstein_1"] = np.nan

        if "mmd" in self.metrics:
            try:
                mmd_per_gene = vectorized_mmd(real_degs, gen_degs)
                results["mmd"] = float(np.nanmean(mmd_per_gene))
            except Exception:
                results["mmd"] = np.nan

        # Fall back to the generic computation for everything not covered by
        # a vectorized implementation above.
        for metric_name in self.metrics:
            if metric_name in results or metric_name not in self._metric_objects:
                continue
            results[metric_name] = self._object_metric(
                metric_name, real_degs, gen_degs
            )

        return results

    def _evaluate_single_context(
        self,
        context: Dict[str, Any],
        index: int,
        n_contexts: int,
    ) -> Optional[ContextResult]:
        """Evaluate one perturbation context.

        Returns None when the context is skipped (too few samples) or an
        error occurs; errors are reported in verbose mode but never raised,
        so one bad context cannot abort the whole evaluation.
        """
        context_id = get_context_id(context)

        if self.verbose:
            print(f" [{index+1}/{n_contexts}] {context_id}", end="... ")

        try:
            # Perturbed and matched-control slices of both datasets.
            real_pert, gen_pert = self.context_evaluator.get_context_data(context)
            real_ctrl, gen_ctrl = self.context_evaluator.get_control_data(context)

            if len(real_ctrl) < 2 or len(real_pert) < 2:
                if self.verbose:
                    print("skipped (insufficient samples)")
                return None

            # DEGs are determined from the REAL data (control vs perturbed);
            # the generated data is then scored on exactly those genes.
            deg_result = self._compute_degs(real_ctrl, real_pert)

            if self.verbose:
                print(f"{deg_result.n_degs} DEGs", end="... ")

            metrics = self._compute_metrics_accelerated(
                real_pert, gen_pert, deg_result.deg_indices
            )

            if self.verbose:
                print("done")

            return ContextResult(
                context_id=context_id,
                context_values=context,
                n_samples_real=len(real_pert),
                n_samples_gen=len(gen_pert),
                deg_result=deg_result,
                metrics=metrics,
            )
        except Exception as e:
            if self.verbose:
                print(f"error: {e}")
            return None

    def _build_expanded_metrics(
        self, context_results: List[ContextResult]
    ) -> pd.DataFrame:
        """One row per evaluated context: sizes, DEG count, metric values."""
        rows = []
        for ctx_result in context_results:
            rows.append({
                "context_id": ctx_result.context_id,
                **ctx_result.context_values,
                "n_samples_real": ctx_result.n_samples_real,
                "n_samples_gen": ctx_result.n_samples_gen,
                "n_degs": ctx_result.deg_result.n_degs if ctx_result.deg_result else 0,
                **ctx_result.metrics,
            })
        return pd.DataFrame(rows)

    def _build_aggregated_metrics(
        self,
        context_results: List[ContextResult],
        expanded_metrics: pd.DataFrame,
    ) -> pd.DataFrame:
        """Single-row summary (mean/std/median per metric) across contexts."""
        if len(expanded_metrics) == 0:
            return pd.DataFrame()

        agg_data: Dict[str, Any] = {
            "n_contexts": len(context_results),
            "total_samples_real": expanded_metrics["n_samples_real"].sum(),
            "total_samples_gen": expanded_metrics["n_samples_gen"].sum(),
            "mean_n_degs": expanded_metrics["n_degs"].mean(),
            "median_n_degs": expanded_metrics["n_degs"].median(),
        }
        for metric in self.metrics:
            if metric in expanded_metrics.columns:
                col = expanded_metrics[metric]
                agg_data[f"{metric}_mean"] = col.mean()
                agg_data[f"{metric}_std"] = col.std()
                agg_data[f"{metric}_median"] = col.median()

        return pd.DataFrame([agg_data])

    def _build_deg_summary(
        self, context_results: List[ContextResult]
    ) -> pd.DataFrame:
        """Per-context DEG detection summary (counts and mean |log2 FC|)."""
        rows = []
        for ctx_result in context_results:
            deg = ctx_result.deg_result
            if deg is None:
                continue
            # Log fold changes restricted to the genes flagged as DEGs.
            deg_lfc = deg.log_fold_changes[deg.is_deg]
            rows.append({
                "context_id": ctx_result.context_id,
                **ctx_result.context_values,
                "n_degs": deg.n_degs,
                "n_upregulated": (deg_lfc > 0).sum(),
                "n_downregulated": (deg_lfc < 0).sum(),
                "mean_abs_lfc": np.abs(deg_lfc).mean() if deg.n_degs > 0 else np.nan,
            })
        return pd.DataFrame(rows)

    def evaluate(self) -> DEGEvaluationResult:
        """
        Run DEG-focused evaluation on all contexts.

        Returns
        -------
        DEGEvaluationResult
            Complete evaluation results with per-context and aggregated metrics.
        """
        perturbation_contexts = self.context_evaluator.get_perturbation_contexts()
        n_contexts = len(perturbation_contexts)

        self._log(f"Evaluating {n_contexts} perturbation contexts...")

        context_results: List[ContextResult] = []
        for i, context in enumerate(perturbation_contexts):
            ctx_result = self._evaluate_single_context(context, i, n_contexts)
            if ctx_result is not None:
                context_results.append(ctx_result)

        expanded_metrics = self._build_expanded_metrics(context_results)
        aggregated_metrics = self._build_aggregated_metrics(
            context_results, expanded_metrics
        )
        deg_summary = self._build_deg_summary(context_results)

        self._log(f"\nEvaluation complete: {len(context_results)} contexts evaluated")

        return DEGEvaluationResult(
            context_results=context_results,
            aggregated_metrics=aggregated_metrics,
            expanded_metrics=expanded_metrics,
            deg_summary=deg_summary,
            gene_names=self.gene_names,
            settings={
                "deg_method": self.deg_method,
                "pval_threshold": self.pval_threshold,
                "lfc_threshold": self.lfc_threshold,
                "min_degs": self.min_degs,
                "metrics": self.metrics,
                "device": self.device,
                "n_jobs": self.n_jobs,
            },
        )
|
|
461
|
+
|
|
462
|
+
|
|
463
|
+
def evaluate_degs(
    real_data: np.ndarray,
    generated_data: np.ndarray,
    real_obs: pd.DataFrame,
    generated_obs: pd.DataFrame,
    condition_columns: List[str],
    gene_names: Optional[np.ndarray] = None,
    control_key: str = "control",
    perturbation_column: Optional[str] = None,
    deg_method: DEGMethod = "welch",
    pval_threshold: float = 0.05,
    lfc_threshold: float = 0.5,
    metrics: Optional[List[str]] = None,
    n_jobs: int = 1,
    device: str = "auto",
    verbose: bool = True,
    min_degs: int = 5,
) -> DEGEvaluationResult:
    """
    Convenience function for DEG-focused evaluation.

    Parameters
    ----------
    real_data : np.ndarray
        Real expression matrix
    generated_data : np.ndarray
        Generated expression matrix
    real_obs : pd.DataFrame
        Real data metadata
    generated_obs : pd.DataFrame
        Generated data metadata
    condition_columns : List[str]
        Columns defining contexts
    gene_names : np.ndarray, optional
        Gene names
    control_key : str
        Control condition identifier
    perturbation_column : str, optional
        Column containing perturbation info. If None, uses first condition column.
    deg_method : str
        DEG detection method
    pval_threshold : float
        P-value threshold
    lfc_threshold : float
        Log fold change threshold
    metrics : List[str], optional
        Metrics to compute
    n_jobs : int
        Parallel CPU jobs
    device : str
        Compute device
    verbose : bool
        Print progress
    min_degs : int
        Minimum DEGs required to compute metrics for a context (default: 5).
        Appended after ``verbose`` to keep positional callers compatible.

    Returns
    -------
    DEGEvaluationResult
        Evaluation results
    """
    # Thin wrapper: build the evaluator with every knob forwarded (including
    # min_degs, which was previously not configurable here) and run it.
    evaluator = DEGEvaluator(
        real_data=real_data,
        generated_data=generated_data,
        real_obs=real_obs,
        generated_obs=generated_obs,
        condition_columns=condition_columns,
        gene_names=gene_names,
        control_key=control_key,
        perturbation_column=perturbation_column,
        deg_method=deg_method,
        pval_threshold=pval_threshold,
        lfc_threshold=lfc_threshold,
        min_degs=min_degs,
        metrics=metrics,
        n_jobs=n_jobs,
        device=device,
        verbose=verbose,
    )
    return evaluator.evaluate()
|