gengeneeval 0.3.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,821 @@
1
+ """
2
+ DEG-focused evaluator for GenGeneEval.
3
+
4
+ Computes metrics on differentially expressed genes with full control:
5
+ - Comparison of DEG-only vs all-genes metrics
6
+ - Configurable DEG selection (top N, p-value, log fold change thresholds)
7
+ - Per-context evaluation (covariates × perturbations)
8
+ - Fast DEG detection with GPU acceleration
9
+ - Aggregated and expanded result reporting
10
+ """
11
+ from __future__ import annotations
12
+
13
+ from typing import Optional, List, Dict, Union, Any, Literal
14
+ from dataclasses import dataclass, field
15
+ from pathlib import Path
16
+ import numpy as np
17
+ import pandas as pd
18
+ import warnings
19
+
20
+ from .detection import (
21
+ compute_degs_fast,
22
+ compute_degs_gpu,
23
+ compute_degs_auto,
24
+ DEGResult,
25
+ DEGMethod,
26
+ )
27
+ from .context import (
28
+ ContextEvaluator,
29
+ ContextResult,
30
+ get_context_id,
31
+ get_contexts,
32
+ )
33
+
34
+ # Import metrics
35
+ from ..metrics.base_metric import BaseMetric
36
+ from ..metrics.correlation import PearsonCorrelation, SpearmanCorrelation
37
+ from ..metrics.distances import (
38
+ Wasserstein1Distance,
39
+ Wasserstein2Distance,
40
+ MMDDistance,
41
+ EnergyDistance,
42
+ )
43
+ from ..metrics.accelerated import (
44
+ get_available_backends,
45
+ vectorized_wasserstein1,
46
+ vectorized_mmd,
47
+ )
48
+
49
+
50
@dataclass
class DEGSettings:
    """Configuration bundle for DEG detection and filtering.

    Attributes
    ----------
    method : str
        DEG detection method: "welch", "student", "wilcoxon", "logfc"
    pval_threshold : float
        P-value threshold for significance
    lfc_threshold : float
        Absolute log2 fold change threshold
    n_top_degs : int, optional
        If set, use only top N DEGs by significance (overrides threshold filtering)
    min_degs : int
        Minimum number of DEGs required to compute metrics
    """
    method: str = "welch"
    pval_threshold: float = 0.05
    lfc_threshold: float = 0.5
    n_top_degs: Optional[int] = None
    min_degs: int = 5

    def to_dict(self) -> Dict[str, Any]:
        """Return the settings as a plain dictionary.

        The detection method is exported under the key ``"deg_method"``;
        all other fields keep their attribute names.
        """
        out: Dict[str, Any] = {"deg_method": self.method}
        for key in ("pval_threshold", "lfc_threshold", "n_top_degs", "min_degs"):
            out[key] = getattr(self, key)
        return out
82
+
83
+
84
@dataclass
class ContextMetrics:
    """Per-context metric bundle comparing DEG-only against all-genes results.

    Attributes
    ----------
    context_id : str
        Context identifier
    context_values : Dict[str, Any]
        Values of the context-defining columns
    n_samples_real : int
        Number of real samples in the context
    n_samples_gen : int
        Number of generated samples in the context
    n_genes_total : int
        Total number of genes in the expression matrices
    deg_result : DEGResult, optional
        DEG detection output (None when detection was skipped)
    deg_metrics : Dict[str, float]
        Metrics computed on the DEG subset only
    all_genes_metrics : Dict[str, float]
        Metrics computed on every gene
    """
    context_id: str
    context_values: Dict[str, Any]
    n_samples_real: int
    n_samples_gen: int
    n_genes_total: int
    deg_result: Optional[DEGResult]
    deg_metrics: Dict[str, float]
    all_genes_metrics: Dict[str, float]

    @property
    def n_degs(self) -> int:
        """Number of detected DEGs (0 when no DEG result is attached)."""
        if self.deg_result:
            return self.deg_result.n_degs
        return 0

    @property
    def deg_indices_used(self) -> np.ndarray:
        """Gene indices the DEG metrics were computed on (empty when absent)."""
        if self.deg_result:
            return self.deg_result.deg_indices
        return np.array([])
125
+
126
+
127
@dataclass
class DEGEvaluationResult:
    """Container for a full DEG evaluation run.

    Holds both DEG-only and all-genes metrics so the two views can be
    compared directly.

    Attributes
    ----------
    context_results : List[ContextMetrics]
        Results for each context with both DEG and all-gene metrics
    aggregated_metrics : pd.DataFrame
        Aggregated metrics across contexts (both DEG and all-genes)
    expanded_metrics : pd.DataFrame
        Per-context expanded metrics (both DEG and all-genes)
    deg_summary : pd.DataFrame
        Summary of DEG detection per context
    comparison_summary : pd.DataFrame
        Comparison between DEG-only and all-genes metrics
    gene_names : np.ndarray
        All gene names
    settings : Dict
        Evaluation settings including DEG parameters
    """
    context_results: List[ContextMetrics]
    aggregated_metrics: pd.DataFrame
    expanded_metrics: pd.DataFrame
    deg_summary: pd.DataFrame
    comparison_summary: pd.DataFrame
    gene_names: np.ndarray
    settings: Dict[str, Any]

    def save(self, output_dir: Union[str, Path]) -> None:
        """Write all result tables, settings, and per-context DEG lists to disk."""
        out = Path(output_dir)
        out.mkdir(parents=True, exist_ok=True)

        # Main result tables, one CSV each.
        tables = {
            "deg_aggregated_metrics.csv": self.aggregated_metrics,
            "deg_expanded_metrics.csv": self.expanded_metrics,
            "deg_summary.csv": self.deg_summary,
            "deg_vs_all_comparison.csv": self.comparison_summary,
        }
        for filename, frame in tables.items():
            frame.to_csv(out / filename, index=False)

        # Settings as JSON; default=str keeps non-serializable values readable.
        import json
        with open(out / "settings.json", "w") as fh:
            json.dump(self.settings, fh, indent=2, default=str)

        # One DEG table per context that actually has a detection result.
        deg_dir = out / "deg_per_context"
        deg_dir.mkdir(exist_ok=True)
        for ctx in self.context_results:
            if ctx.deg_result is not None:
                frame = ctx.deg_result.to_dataframe()
                frame.to_csv(deg_dir / f"{ctx.context_id}_degs.csv", index=False)

    def _subset(self, prefix: str) -> pd.DataFrame:
        """Return the shared context columns plus all columns with *prefix*."""
        base_cols = ["context_id", "n_samples_real", "n_samples_gen", "n_degs", "n_genes_total"]
        present = [c for c in base_cols if c in self.expanded_metrics.columns]
        prefixed = [c for c in self.expanded_metrics.columns if c.startswith(prefix)]
        return self.expanded_metrics[present + prefixed].copy()

    def get_deg_only_metrics(self) -> pd.DataFrame:
        """Get expanded metrics for DEGs only."""
        return self._subset("deg_")

    def get_all_genes_metrics(self) -> pd.DataFrame:
        """Get expanded metrics for all genes."""
        return self._subset("all_")

    def __repr__(self) -> str:
        avg_degs = self.deg_summary["n_degs"].mean() if len(self.deg_summary) > 0 else 0
        method = self.settings.get("deg_method", "unknown")
        return (
            f"DEGEvaluationResult(n_contexts={len(self.context_results)}, "
            f"avg_degs={avg_degs:.1f}, "
            f"settings={method})"
        )
201
+
202
+
203
class DEGEvaluator:
    """
    Evaluator that computes metrics on DEGs with comparison to all genes.

    This evaluator:
    1. Detects DEGs for each perturbation context (on the REAL data:
       control vs perturbed)
    2. Computes distributional metrics on BOTH DEG genes AND all genes
    3. Provides comparison between DEG-focused and all-genes evaluation
    4. Reports per-context and aggregated results

    Parameters
    ----------
    real_data : np.ndarray
        Real expression matrix (n_samples, n_genes)
    generated_data : np.ndarray
        Generated expression matrix (n_samples, n_genes)
    real_obs : pd.DataFrame
        Real data observation metadata
    generated_obs : pd.DataFrame
        Generated data observation metadata
    condition_columns : List[str]
        Columns defining contexts (e.g., ["cell_type", "perturbation"])
    gene_names : np.ndarray, optional
        Gene names; defaults to "Gene_0", "Gene_1", ...
    control_key : str
        Value indicating control samples (default: "control")
    perturbation_column : str, optional
        Column containing perturbation info. If None, uses first condition column.
    deg_method : str
        DEG detection method: "welch", "student", "wilcoxon", "logfc"
    pval_threshold : float
        P-value threshold for DEG significance (default: 0.05)
    lfc_threshold : float
        Log2 fold change threshold (default: 0.5)
    n_top_degs : int, optional
        If set, use only top N DEGs by significance instead of thresholds
    min_degs : int
        Minimum DEGs required to compute DEG-specific metrics (default: 5)
    compute_all_genes : bool
        Whether to also compute metrics on all genes (default: True)
    metrics : List[str], optional
        Metrics to compute. Default: all supported metrics.
    n_jobs : int
        Number of parallel CPU jobs
    device : str
        Compute device: "cpu", "cuda", "mps", "auto"
    verbose : bool
        Print progress

    Examples
    --------
    >>> # Basic usage - computes both DEG and all-genes metrics
    >>> evaluator = DEGEvaluator(
    ...     real_data, generated_data,
    ...     real_obs, generated_obs,
    ...     condition_columns=["perturbation"],
    ... )
    >>> results = evaluator.evaluate()
    >>> print(results.comparison_summary)  # DEG vs all-genes comparison

    >>> # Use top 100 DEGs only
    >>> evaluator = DEGEvaluator(
    ...     real_data, generated_data,
    ...     real_obs, generated_obs,
    ...     condition_columns=["perturbation"],
    ...     n_top_degs=100,  # Use top 100 most significant DEGs
    ... )

    >>> # Stricter thresholds
    >>> evaluator = DEGEvaluator(
    ...     real_data, generated_data,
    ...     real_obs, generated_obs,
    ...     condition_columns=["perturbation"],
    ...     pval_threshold=0.01,  # More stringent p-value
    ...     lfc_threshold=1.0,  # log2 FC > 1 (2-fold change)
    ... )

    >>> # DEGs only (no all-genes metrics for speed)
    >>> evaluator = DEGEvaluator(
    ...     real_data, generated_data,
    ...     real_obs, generated_obs,
    ...     condition_columns=["perturbation"],
    ...     compute_all_genes=False,
    ... )
    """

    # Supported metric names; each key must have an entry in _metric_objects.
    SUPPORTED_METRICS = [
        "wasserstein_1",
        "wasserstein_2",
        "mmd",
        "energy",
        "pearson",
        "spearman",
    ]

    def __init__(
        self,
        real_data: np.ndarray,
        generated_data: np.ndarray,
        real_obs: pd.DataFrame,
        generated_obs: pd.DataFrame,
        condition_columns: List[str],
        gene_names: Optional[np.ndarray] = None,
        control_key: str = "control",
        perturbation_column: Optional[str] = None,
        deg_method: DEGMethod = "welch",
        pval_threshold: float = 0.05,
        lfc_threshold: float = 0.5,
        n_top_degs: Optional[int] = None,
        min_degs: int = 5,
        compute_all_genes: bool = True,
        metrics: Optional[List[str]] = None,
        n_jobs: int = 1,
        device: str = "cpu",
        verbose: bool = True,
    ):
        # Cast to float32 up front so downstream metric kernels see one dtype.
        self.real_data = np.asarray(real_data, dtype=np.float32)
        self.generated_data = np.asarray(generated_data, dtype=np.float32)
        # Reset indices so positional masks from ContextEvaluator align with rows.
        self.real_obs = real_obs.reset_index(drop=True)
        self.generated_obs = generated_obs.reset_index(drop=True)
        self.condition_columns = condition_columns
        self.n_genes = real_data.shape[1]
        # Synthesize placeholder gene names when none are supplied.
        self.gene_names = gene_names if gene_names is not None else np.array(
            [f"Gene_{i}" for i in range(self.n_genes)]
        )
        self.control_key = control_key
        self.perturbation_column = perturbation_column or condition_columns[0]

        # DEG settings bundled into a single dataclass for to_dict() export.
        self.deg_settings = DEGSettings(
            method=deg_method,
            pval_threshold=pval_threshold,
            lfc_threshold=lfc_threshold,
            n_top_degs=n_top_degs,
            min_degs=min_degs,
        )

        self.compute_all_genes = compute_all_genes
        # NOTE: an explicit empty list also falls back to SUPPORTED_METRICS here.
        self.metrics = metrics or self.SUPPORTED_METRICS
        self.n_jobs = n_jobs
        self.device = device
        self.verbose = verbose

        # Create context evaluator (handles grouping rows into contexts and
        # slicing out control/perturbed subsets).
        self.context_evaluator = ContextEvaluator(
            real_data=self.real_data,
            generated_data=self.generated_data,
            real_obs=self.real_obs,
            generated_obs=self.generated_obs,
            condition_columns=condition_columns,
            gene_names=self.gene_names,
            control_key=control_key,
            perturbation_column=self.perturbation_column,
        )

        # Initialize metric objects (one instance per supported metric name).
        self._metric_objects = {
            "wasserstein_1": Wasserstein1Distance(),
            "wasserstein_2": Wasserstein2Distance(),
            "mmd": MMDDistance(),
            "energy": EnergyDistance(),
            "pearson": PearsonCorrelation(),
            "spearman": SpearmanCorrelation(),
        }

        self._log(f"DEGEvaluator initialized with {len(self.context_evaluator)} contexts")
        self._log(f"DEG settings: method={deg_method}, pval<{pval_threshold}, |lfc|>{lfc_threshold}")
        if n_top_degs is not None:
            self._log(f" Using top {n_top_degs} DEGs by significance")
        if compute_all_genes:
            self._log("Will compute metrics on BOTH DEGs and all genes")

    def _log(self, msg: str) -> None:
        """Print *msg* if verbose mode is on."""
        if self.verbose:
            print(msg)

    def _compute_degs(
        self,
        control: np.ndarray,
        perturbed: np.ndarray,
    ) -> DEGResult:
        """Compute DEGs (control vs perturbed) using the configured method and device."""
        deg_result = compute_degs_auto(
            control=control,
            perturbed=perturbed,
            gene_names=self.gene_names,
            method=self.deg_settings.method,
            pval_threshold=self.deg_settings.pval_threshold,
            lfc_threshold=self.deg_settings.lfc_threshold,
            n_jobs=self.n_jobs,
            device=self.device,
        )

        # If n_top_degs is set, limit to top N by significance
        if self.deg_settings.n_top_degs is not None:
            deg_result = self._filter_top_degs(deg_result, self.deg_settings.n_top_degs)

        return deg_result

    def _filter_top_degs(self, deg_result: DEGResult, n_top: int) -> DEGResult:
        """Filter DEG result to keep only the top N most significant DEGs.

        Returns a new DEGResult with the is_deg mask, n_degs, and deg_indices
        narrowed to the N smallest adjusted p-values; all per-gene arrays
        (pvalues, fold changes, means) are carried over unchanged.
        """
        if deg_result.n_degs <= n_top:
            return deg_result  # Already fewer than n_top

        # Sort DEGs by adjusted p-value (lower is more significant)
        deg_pvals = deg_result.pvalues_adj[deg_result.is_deg]
        deg_indices = deg_result.deg_indices

        # Get indices of top N most significant
        top_n_order = np.argsort(deg_pvals)[:n_top]
        top_deg_indices = deg_indices[top_n_order]

        # Create new is_deg mask
        new_is_deg = np.zeros(len(deg_result.is_deg), dtype=bool)
        new_is_deg[top_deg_indices] = True

        # Create modified DEGResult with all required fields
        return DEGResult(
            gene_names=deg_result.gene_names,
            pvalues=deg_result.pvalues,
            pvalues_adj=deg_result.pvalues_adj,
            log_fold_changes=deg_result.log_fold_changes,
            mean_control=deg_result.mean_control,
            mean_perturbed=deg_result.mean_perturbed,
            is_deg=new_is_deg,
            n_degs=n_top,
            method=deg_result.method,
            pval_threshold=deg_result.pval_threshold,
            lfc_threshold=deg_result.lfc_threshold,
            deg_indices=top_deg_indices,
        )

    def _compute_metrics_on_genes(
        self,
        real: np.ndarray,
        generated: np.ndarray,
        gene_indices: Optional[np.ndarray] = None,
        min_genes: int = 1,
    ) -> Dict[str, float]:
        """Compute the configured metrics on specified genes (or all if indices is None).

        Returns a {metric_name: value} dict; every value is the nan-mean of the
        per-gene metric. If fewer than *min_genes* indices are given, all
        metrics are NaN. Individual metric failures also yield NaN rather than
        aborting the whole context.
        """
        # Slice to selected genes
        if gene_indices is not None:
            if len(gene_indices) < min_genes:
                return {m: np.nan for m in self.metrics}
            real_subset = real[:, gene_indices]
            gen_subset = generated[:, gene_indices]
        else:
            real_subset = real
            gen_subset = generated

        results = {}

        # Use vectorized implementations where available
        if "wasserstein_1" in self.metrics:
            try:
                w1_per_gene = vectorized_wasserstein1(real_subset, gen_subset)
                results["wasserstein_1"] = float(np.nanmean(w1_per_gene))
            except Exception:
                results["wasserstein_1"] = np.nan

        if "mmd" in self.metrics:
            try:
                mmd_per_gene = vectorized_mmd(real_subset, gen_subset)
                results["mmd"] = float(np.nanmean(mmd_per_gene))
            except Exception:
                results["mmd"] = np.nan

        # Fall back to standard computation for other metrics
        for metric_name in self.metrics:
            if metric_name in results:
                continue  # already handled by a vectorized path above
            if metric_name not in self._metric_objects:
                continue  # unknown metric names are silently skipped

            metric = self._metric_objects[metric_name]
            try:
                per_gene = metric.compute_per_gene(real_subset, gen_subset)
                results[metric_name] = float(np.nanmean(per_gene))
            except Exception:
                results[metric_name] = np.nan

        return results

    def evaluate(self) -> DEGEvaluationResult:
        """
        Run DEG-focused evaluation on all contexts.

        Returns both DEG-only and all-genes metrics for comparison. Contexts
        with fewer than 2 control or 2 perturbed real samples are skipped, and
        any per-context exception is caught and reported (when verbose) so a
        single bad context cannot abort the run.

        Returns
        -------
        DEGEvaluationResult
            Complete evaluation results with:
            - Per-context DEG and all-genes metrics
            - Aggregated metrics
            - DEG summary
            - Comparison between DEG and all-genes evaluation
        """
        context_results: List[ContextMetrics] = []

        perturbation_contexts = self.context_evaluator.get_perturbation_contexts()
        n_contexts = len(perturbation_contexts)

        self._log(f"\nEvaluating {n_contexts} perturbation contexts...")

        for i, context in enumerate(perturbation_contexts):
            context_id = get_context_id(context)

            if self.verbose:
                print(f" [{i+1}/{n_contexts}] {context_id}", end="... ")

            try:
                # Get perturbed data
                real_pert, gen_pert = self.context_evaluator.get_context_data(context)

                # Get control data (gen_ctrl is fetched but not used below)
                real_ctrl, gen_ctrl = self.context_evaluator.get_control_data(context)

                if len(real_ctrl) < 2 or len(real_pert) < 2:
                    if self.verbose:
                        print("skipped (insufficient samples)")
                    continue

                # Compute DEGs using real data (control vs perturbed)
                deg_result = self._compute_degs(real_ctrl, real_pert)

                if self.verbose:
                    print(f"{deg_result.n_degs} DEGs", end="... ")

                # Compute metrics on DEGs
                deg_metrics = self._compute_metrics_on_genes(
                    real_pert, gen_pert,
                    gene_indices=deg_result.deg_indices,
                    min_genes=self.deg_settings.min_degs,
                )

                # Compute metrics on all genes if requested
                if self.compute_all_genes:
                    all_genes_metrics = self._compute_metrics_on_genes(
                        real_pert, gen_pert,
                        gene_indices=None,  # All genes
                    )
                else:
                    all_genes_metrics = {}

                ctx_result = ContextMetrics(
                    context_id=context_id,
                    context_values=context,
                    n_samples_real=len(real_pert),
                    n_samples_gen=len(gen_pert),
                    n_genes_total=self.n_genes,
                    deg_result=deg_result,
                    deg_metrics=deg_metrics,
                    all_genes_metrics=all_genes_metrics,
                )
                context_results.append(ctx_result)

                if self.verbose:
                    print("done")

            except Exception as e:
                # Best-effort: report and move on to the next context.
                if self.verbose:
                    print(f"error: {e}")
                continue

        # Build result DataFrames
        expanded_metrics = self._build_expanded_metrics(context_results)
        aggregated_metrics = self._build_aggregated_metrics(expanded_metrics)
        deg_summary = self._build_deg_summary(context_results)
        comparison_summary = self._build_comparison_summary(expanded_metrics)

        self._log(f"\nEvaluation complete: {len(context_results)} contexts evaluated")

        return DEGEvaluationResult(
            context_results=context_results,
            aggregated_metrics=aggregated_metrics,
            expanded_metrics=expanded_metrics,
            deg_summary=deg_summary,
            comparison_summary=comparison_summary,
            gene_names=self.gene_names,
            settings={
                **self.deg_settings.to_dict(),
                "compute_all_genes": self.compute_all_genes,
                "metrics": self.metrics,
                "device": self.device,
                "n_jobs": self.n_jobs,
            },
        )

    def _build_expanded_metrics(self, context_results: List[ContextMetrics]) -> pd.DataFrame:
        """Build one row per context, with DEG metrics under ``deg_*`` and
        all-genes metrics under ``all_*`` column prefixes."""
        data = []
        for ctx in context_results:
            # NOTE(review): assumes context_values keys don't collide with the
            # fixed columns below (context_id, n_samples_*, ...) — confirm upstream.
            row = {
                "context_id": ctx.context_id,
                **ctx.context_values,
                "n_samples_real": ctx.n_samples_real,
                "n_samples_gen": ctx.n_samples_gen,
                "n_genes_total": ctx.n_genes_total,
                "n_degs": ctx.n_degs,
            }

            # DEG metrics with prefix
            for metric_name, value in ctx.deg_metrics.items():
                row[f"deg_{metric_name}"] = value

            # All-genes metrics with prefix
            for metric_name, value in ctx.all_genes_metrics.items():
                row[f"all_{metric_name}"] = value

            data.append(row)

        return pd.DataFrame(data)

    def _build_aggregated_metrics(self, expanded_metrics: pd.DataFrame) -> pd.DataFrame:
        """Build a single-row DataFrame of cross-context summary statistics."""
        if len(expanded_metrics) == 0:
            return pd.DataFrame()

        agg_data = {
            "n_contexts": len(expanded_metrics),
            "total_samples_real": expanded_metrics["n_samples_real"].sum(),
            "total_samples_gen": expanded_metrics["n_samples_gen"].sum(),
            "mean_n_degs": expanded_metrics["n_degs"].mean(),
            "median_n_degs": expanded_metrics["n_degs"].median(),
            "min_n_degs": expanded_metrics["n_degs"].min(),
            "max_n_degs": expanded_metrics["n_degs"].max(),
        }

        # Aggregate DEG metrics (mean/std over contexts)
        for metric in self.metrics:
            col = f"deg_{metric}"
            if col in expanded_metrics.columns:
                agg_data[f"deg_{metric}_mean"] = expanded_metrics[col].mean()
                agg_data[f"deg_{metric}_std"] = expanded_metrics[col].std()

        # Aggregate all-genes metrics (mean/std over contexts)
        for metric in self.metrics:
            col = f"all_{metric}"
            if col in expanded_metrics.columns:
                agg_data[f"all_{metric}_mean"] = expanded_metrics[col].mean()
                agg_data[f"all_{metric}_std"] = expanded_metrics[col].std()

        return pd.DataFrame([agg_data])

    def _build_deg_summary(self, context_results: List[ContextMetrics]) -> pd.DataFrame:
        """Build a per-context DEG detection summary (counts, direction, LFC stats)."""
        data = []
        for ctx in context_results:
            if ctx.deg_result is not None:
                # Log fold changes of the significant genes only.
                deg_lfcs = ctx.deg_result.log_fold_changes[ctx.deg_result.is_deg]
                data.append({
                    "context_id": ctx.context_id,
                    **ctx.context_values,
                    "n_degs": ctx.n_degs,
                    "n_genes_total": ctx.n_genes_total,
                    "deg_fraction": ctx.n_degs / ctx.n_genes_total if ctx.n_genes_total > 0 else 0,
                    "n_upregulated": (deg_lfcs > 0).sum(),
                    "n_downregulated": (deg_lfcs < 0).sum(),
                    "mean_abs_lfc": float(np.abs(deg_lfcs).mean()) if len(deg_lfcs) > 0 else np.nan,
                    "max_abs_lfc": float(np.abs(deg_lfcs).max()) if len(deg_lfcs) > 0 else np.nan,
                })
        return pd.DataFrame(data)

    def _build_comparison_summary(self, expanded_metrics: pd.DataFrame) -> pd.DataFrame:
        """Build a per-metric comparison between DEG-only and all-genes means.

        NaN context values are dropped per metric before averaging, so
        n_contexts may differ between metrics.
        """
        if len(expanded_metrics) == 0:
            return pd.DataFrame()

        comparison_data = []
        for metric in self.metrics:
            deg_col = f"deg_{metric}"
            all_col = f"all_{metric}"

            if deg_col in expanded_metrics.columns and all_col in expanded_metrics.columns:
                deg_values = expanded_metrics[deg_col].dropna()
                all_values = expanded_metrics[all_col].dropna()

                if len(deg_values) > 0 and len(all_values) > 0:
                    deg_mean = deg_values.mean()
                    all_mean = all_values.mean()

                    comparison_data.append({
                        "metric": metric,
                        "deg_mean": deg_mean,
                        "deg_std": deg_values.std(),
                        "all_mean": all_mean,
                        "all_std": all_values.std(),
                        "difference": deg_mean - all_mean,
                        "ratio": deg_mean / all_mean if all_mean != 0 else np.nan,
                        "n_contexts": len(deg_values),
                    })

        return pd.DataFrame(comparison_data)
699
+
700
+
701
def evaluate_degs(
    real_data: np.ndarray,
    generated_data: np.ndarray,
    real_obs: pd.DataFrame,
    generated_obs: pd.DataFrame,
    condition_columns: List[str],
    gene_names: Optional[np.ndarray] = None,
    control_key: str = "control",
    perturbation_column: Optional[str] = None,
    deg_method: DEGMethod = "welch",
    pval_threshold: float = 0.05,
    lfc_threshold: float = 0.5,
    n_top_degs: Optional[int] = None,
    min_degs: int = 5,
    compute_all_genes: bool = True,
    metrics: Optional[List[str]] = None,
    n_jobs: int = 1,
    device: str = "auto",
    verbose: bool = True,
) -> DEGEvaluationResult:
    """
    Run a DEG-focused evaluation in one call.

    Thin convenience wrapper around :class:`DEGEvaluator`: builds the
    evaluator from the given arguments and immediately runs
    :meth:`DEGEvaluator.evaluate`, computing metrics on both DEGs and
    (optionally) all genes for comparison.

    Parameters
    ----------
    real_data : np.ndarray
        Real expression matrix (n_samples, n_genes)
    generated_data : np.ndarray
        Generated expression matrix (n_samples, n_genes)
    real_obs : pd.DataFrame
        Real data metadata with condition columns
    generated_obs : pd.DataFrame
        Generated data metadata with condition columns
    condition_columns : List[str]
        Columns defining contexts (e.g., ["cell_type", "perturbation"])
    gene_names : np.ndarray, optional
        Gene names for output
    control_key : str
        Control condition identifier (default: "control")
    perturbation_column : str, optional
        Column containing perturbation info. If None, uses first condition column.
    deg_method : str
        DEG detection method: "welch", "student", "wilcoxon", "logfc"
    pval_threshold : float
        Adjusted p-value threshold (default: 0.05)
    lfc_threshold : float
        Absolute log2 fold change threshold (default: 0.5)
    n_top_degs : int, optional
        If set, use only top N DEGs by significance (overrides thresholds)
    min_degs : int
        Minimum DEGs to compute DEG metrics (default: 5)
    compute_all_genes : bool
        Also compute metrics on all genes for comparison (default: True)
    metrics : List[str], optional
        Metrics to compute. Default: all supported.
    n_jobs : int
        Parallel CPU jobs (default: 1)
    device : str
        Compute device: "cpu", "cuda", "mps", "auto" (default: "auto")
    verbose : bool
        Print progress (default: True)

    Returns
    -------
    DEGEvaluationResult
        Complete evaluation results including:
        - expanded_metrics: Per-context metrics for DEGs and all genes
        - aggregated_metrics: Summary statistics
        - deg_summary: DEG detection summary
        - comparison_summary: DEG vs all-genes comparison

    Examples
    --------
    >>> # Basic usage with default thresholds
    >>> results = evaluate_degs(
    ...     real_data, generated_data,
    ...     real_obs, generated_obs,
    ...     condition_columns=["perturbation"],
    ... )
    >>> print(results.comparison_summary)

    >>> # Top 50 DEGs only
    >>> results = evaluate_degs(
    ...     real_data, generated_data,
    ...     real_obs, generated_obs,
    ...     condition_columns=["perturbation"],
    ...     n_top_degs=50,
    ... )

    >>> # Strict thresholds
    >>> results = evaluate_degs(
    ...     real_data, generated_data,
    ...     real_obs, generated_obs,
    ...     condition_columns=["perturbation"],
    ...     pval_threshold=0.01,
    ...     lfc_threshold=1.0,  # 2-fold change
    ... )
    """
    # Forward everything verbatim to the evaluator; this function adds no
    # behavior of its own beyond the "auto" device default.
    evaluator_kwargs = dict(
        real_data=real_data,
        generated_data=generated_data,
        real_obs=real_obs,
        generated_obs=generated_obs,
        condition_columns=condition_columns,
        gene_names=gene_names,
        control_key=control_key,
        perturbation_column=perturbation_column,
        deg_method=deg_method,
        pval_threshold=pval_threshold,
        lfc_threshold=lfc_threshold,
        n_top_degs=n_top_degs,
        min_degs=min_degs,
        compute_all_genes=compute_all_genes,
        metrics=metrics,
        n_jobs=n_jobs,
        device=device,
        verbose=verbose,
    )
    return DEGEvaluator(**evaluator_kwargs).evaluate()