gengeneeval 0.2.1__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,538 @@
1
+ """
2
+ DEG-focused evaluator for GenGeneEval.
3
+
4
+ Computes metrics only on differentially expressed genes, with support for:
5
+ - Per-context evaluation (covariates × perturbations)
6
+ - Fast DEG detection with GPU acceleration
7
+ - Aggregated and expanded result reporting
8
+ """
9
+ from __future__ import annotations
10
+
11
+ from typing import Optional, List, Dict, Union, Any, Literal
12
+ from dataclasses import dataclass, field
13
+ from pathlib import Path
14
+ import numpy as np
15
+ import pandas as pd
16
+ import warnings
17
+
18
+ from .detection import (
19
+ compute_degs_fast,
20
+ compute_degs_gpu,
21
+ compute_degs_auto,
22
+ DEGResult,
23
+ DEGMethod,
24
+ )
25
+ from .context import (
26
+ ContextEvaluator,
27
+ ContextResult,
28
+ get_context_id,
29
+ get_contexts,
30
+ )
31
+
32
+ # Import metrics
33
+ from ..metrics.base_metric import BaseMetric
34
+ from ..metrics.correlation import PearsonCorrelation, SpearmanCorrelation
35
+ from ..metrics.distances import (
36
+ Wasserstein1Distance,
37
+ Wasserstein2Distance,
38
+ MMDDistance,
39
+ EnergyDistance,
40
+ )
41
+ from ..metrics.accelerated import (
42
+ get_available_backends,
43
+ vectorized_wasserstein1,
44
+ vectorized_mmd,
45
+ )
46
+
47
+
48
@dataclass
class DEGEvaluationResult:
    """Complete DEG evaluation results.

    Attributes
    ----------
    context_results : List[ContextResult]
        Results for each context
    aggregated_metrics : pd.DataFrame
        Aggregated metrics across contexts
    expanded_metrics : pd.DataFrame
        Per-context expanded metrics
    deg_summary : pd.DataFrame
        Summary of DEG detection per context
    gene_names : np.ndarray
        All gene names
    settings : Dict
        Evaluation settings
    """
    context_results: List[ContextResult]
    aggregated_metrics: pd.DataFrame
    expanded_metrics: pd.DataFrame
    deg_summary: pd.DataFrame
    gene_names: np.ndarray
    settings: Dict[str, Any]

    def save(self, output_dir: Union[str, Path]) -> None:
        """Write all result tables to *output_dir* as CSV files."""
        out = Path(output_dir)
        out.mkdir(parents=True, exist_ok=True)

        # Top-level summary tables.
        for frame, filename in (
            (self.aggregated_metrics, "deg_aggregated_metrics.csv"),
            (self.expanded_metrics, "deg_expanded_metrics.csv"),
            (self.deg_summary, "deg_summary.csv"),
        ):
            frame.to_csv(out / filename)

        # One CSV of DEG statistics per evaluated context.
        deg_dir = out / "deg_per_context"
        deg_dir.mkdir(exist_ok=True)
        for ctx in self.context_results:
            if ctx.deg_result is None:
                continue
            ctx.deg_result.to_dataframe().to_csv(
                deg_dir / f"{ctx.context_id}_degs.csv"
            )

    def __repr__(self) -> str:
        n_contexts = len(self.context_results)
        metric_cols = list(self.aggregated_metrics.columns)
        return f"DEGEvaluationResult(n_contexts={n_contexts}, metrics={metric_cols})"
97
+
98
+
99
class DEGEvaluator:
    """
    Evaluator that computes metrics on DEGs only.

    This evaluator:
    1. Detects DEGs for each perturbation context (real control vs. real
       perturbed samples)
    2. Computes distributional metrics only on DEG genes
    3. Reports per-context and aggregated results

    Parameters
    ----------
    real_data : np.ndarray
        Real expression matrix (n_samples, n_genes)
    generated_data : np.ndarray
        Generated expression matrix (n_samples, n_genes)
    real_obs : pd.DataFrame
        Real data observation metadata
    generated_obs : pd.DataFrame
        Generated data observation metadata
    condition_columns : List[str]
        Columns defining contexts (e.g., ["cell_type", "perturbation"])
    gene_names : np.ndarray, optional
        Gene names; defaults to "Gene_0", "Gene_1", ...
    control_key : str
        Value indicating control samples (default: "control")
    perturbation_column : str, optional
        Column containing perturbation info. If None, uses first condition column.
    deg_method : str
        DEG detection method: "welch", "student", "wilcoxon", "logfc"
    pval_threshold : float
        P-value threshold for DEG significance
    lfc_threshold : float
        Log2 fold change threshold
    min_degs : int
        Minimum DEGs required to compute metrics; contexts with fewer DEGs
        report NaN for every metric
    metrics : List[str], optional
        Metrics to compute. Default: all supported metrics. Names not in
        SUPPORTED_METRICS are dropped with a warning.
    n_jobs : int
        Number of parallel CPU jobs
    device : str
        Compute device: "cpu", "cuda", "mps", "auto"
    verbose : bool
        Print progress

    Raises
    ------
    ValueError
        If real and generated matrices disagree on the number of genes.

    Examples
    --------
    >>> evaluator = DEGEvaluator(
    ...     real_data, generated_data,
    ...     real_obs, generated_obs,
    ...     condition_columns=["perturbation"],
    ...     deg_method="welch",
    ...     device="cuda",
    ... )
    >>> results = evaluator.evaluate()
    >>> results.save("output/")
    """

    # Supported metrics
    SUPPORTED_METRICS = [
        "wasserstein_1",
        "wasserstein_2",
        "mmd",
        "energy",
        "pearson",
        "spearman",
    ]

    def __init__(
        self,
        real_data: np.ndarray,
        generated_data: np.ndarray,
        real_obs: pd.DataFrame,
        generated_obs: pd.DataFrame,
        condition_columns: List[str],
        gene_names: Optional[np.ndarray] = None,
        control_key: str = "control",
        perturbation_column: Optional[str] = None,
        deg_method: DEGMethod = "welch",
        pval_threshold: float = 0.05,
        lfc_threshold: float = 0.5,
        min_degs: int = 5,
        metrics: Optional[List[str]] = None,
        n_jobs: int = 1,
        device: str = "cpu",
        verbose: bool = True,
    ):
        real_data = np.asarray(real_data, dtype=np.float32)
        generated_data = np.asarray(generated_data, dtype=np.float32)

        # Fail fast on mismatched gene dimensions; DEG indices computed on
        # real data are later used to slice the generated matrix.
        if real_data.shape[1] != generated_data.shape[1]:
            raise ValueError(
                "real_data and generated_data must have the same number of "
                f"genes, got {real_data.shape[1]} and {generated_data.shape[1]}"
            )

        self.real_data = real_data
        self.generated_data = generated_data
        self.real_obs = real_obs.reset_index(drop=True)
        self.generated_obs = generated_obs.reset_index(drop=True)
        self.condition_columns = condition_columns
        self.gene_names = gene_names if gene_names is not None else np.array(
            [f"Gene_{i}" for i in range(real_data.shape[1])]
        )
        self.control_key = control_key
        self.perturbation_column = perturbation_column or condition_columns[0]
        self.deg_method = deg_method
        self.pval_threshold = pval_threshold
        self.lfc_threshold = lfc_threshold
        self.min_degs = min_degs

        # Copy rather than alias SUPPORTED_METRICS so mutating self.metrics
        # cannot corrupt the class-level constant. Unknown metric names are
        # dropped up front (with a warning) so result DataFrames have
        # consistent columns whether or not a context clears min_degs.
        if metrics:
            unknown = [m for m in metrics if m not in self.SUPPORTED_METRICS]
            if unknown:
                warnings.warn(f"Ignoring unsupported metrics: {unknown}")
            self.metrics = [m for m in metrics if m in self.SUPPORTED_METRICS]
        else:
            self.metrics = list(self.SUPPORTED_METRICS)

        self.n_jobs = n_jobs
        self.device = device
        self.verbose = verbose

        # Create context evaluator (handles slicing by covariate/perturbation)
        self.context_evaluator = ContextEvaluator(
            real_data=self.real_data,
            generated_data=self.generated_data,
            real_obs=self.real_obs,
            generated_obs=self.generated_obs,
            condition_columns=condition_columns,
            gene_names=self.gene_names,
            control_key=control_key,
            perturbation_column=self.perturbation_column,
        )

        # Initialize metric objects; looked up by name via self.metrics
        self._metric_objects = {
            "wasserstein_1": Wasserstein1Distance(),
            "wasserstein_2": Wasserstein2Distance(),
            "mmd": MMDDistance(),
            "energy": EnergyDistance(),
            "pearson": PearsonCorrelation(),
            "spearman": SpearmanCorrelation(),
        }

        self._log(f"DEGEvaluator initialized with {len(self.context_evaluator)} contexts")
        self._log(f"Perturbation contexts: {len(self.context_evaluator.get_perturbation_contexts())}")

    def _log(self, msg: str) -> None:
        """Print if verbose."""
        if self.verbose:
            print(msg)

    def _compute_degs(
        self,
        control: np.ndarray,
        perturbed: np.ndarray,
    ) -> DEGResult:
        """Compute DEGs using configured method and device."""
        return compute_degs_auto(
            control=control,
            perturbed=perturbed,
            gene_names=self.gene_names,
            method=self.deg_method,
            pval_threshold=self.pval_threshold,
            lfc_threshold=self.lfc_threshold,
            n_jobs=self.n_jobs,
            device=self.device,
        )

    def _compute_metrics_on_degs(
        self,
        real: np.ndarray,
        generated: np.ndarray,
        deg_indices: np.ndarray,
    ) -> Dict[str, float]:
        """Compute metrics on DEG genes only.

        Returns NaN for every metric when fewer than min_degs genes are
        differentially expressed, or for a metric whose computation raises.
        """
        if len(deg_indices) < self.min_degs:
            return {m: np.nan for m in self.metrics}

        # Slice to DEGs only
        real_degs = real[:, deg_indices]
        gen_degs = generated[:, deg_indices]

        results = {}

        for metric_name in self.metrics:
            if metric_name not in self._metric_objects:
                continue

            metric = self._metric_objects[metric_name]

            try:
                # Compute per-gene and aggregate (NaN-safe mean)
                per_gene = metric.compute_per_gene(real_degs, gen_degs)
                results[metric_name] = float(np.nanmean(per_gene))
            except Exception as e:
                # _log already checks self.verbose; no extra guard needed.
                self._log(f"Warning: {metric_name} failed: {e}")
                results[metric_name] = np.nan

        return results

    def _compute_metrics_accelerated(
        self,
        real: np.ndarray,
        generated: np.ndarray,
        deg_indices: np.ndarray,
    ) -> Dict[str, float]:
        """Compute metrics using accelerated implementations where available,
        falling back to the standard per-gene metric objects otherwise."""
        if len(deg_indices) < self.min_degs:
            return {m: np.nan for m in self.metrics}

        # Slice to DEGs only
        real_degs = real[:, deg_indices]
        gen_degs = generated[:, deg_indices]

        results = {}
        backends = get_available_backends()

        # Use vectorized implementations where available
        if "wasserstein_1" in self.metrics:
            try:
                w1_per_gene = vectorized_wasserstein1(real_degs, gen_degs)
                results["wasserstein_1"] = float(np.nanmean(w1_per_gene))
            except Exception:
                results["wasserstein_1"] = np.nan

        if "mmd" in self.metrics:
            try:
                mmd_per_gene = vectorized_mmd(real_degs, gen_degs)
                results["mmd"] = float(np.nanmean(mmd_per_gene))
            except Exception:
                results["mmd"] = np.nan

        # Fall back to standard computation for other metrics
        for metric_name in self.metrics:
            if metric_name in results:
                continue
            if metric_name not in self._metric_objects:
                continue

            metric = self._metric_objects[metric_name]
            try:
                per_gene = metric.compute_per_gene(real_degs, gen_degs)
                results[metric_name] = float(np.nanmean(per_gene))
            except Exception:
                results[metric_name] = np.nan

        return results

    def evaluate(self) -> DEGEvaluationResult:
        """
        Run DEG-focused evaluation on all contexts.

        Returns
        -------
        DEGEvaluationResult
            Complete evaluation results with per-context and aggregated metrics.
        """
        context_results = []

        perturbation_contexts = self.context_evaluator.get_perturbation_contexts()
        n_contexts = len(perturbation_contexts)

        self._log(f"Evaluating {n_contexts} perturbation contexts...")

        for i, context in enumerate(perturbation_contexts):
            context_id = get_context_id(context)

            if self.verbose:
                print(f"  [{i+1}/{n_contexts}] {context_id}", end="... ")

            try:
                # Get perturbed data
                real_pert, gen_pert = self.context_evaluator.get_context_data(context)

                # Get control data (gen_ctrl unused: DEGs come from real data)
                real_ctrl, gen_ctrl = self.context_evaluator.get_control_data(context)

                # Also require >=2 generated samples: distributional metrics
                # on 0-1 generated samples are meaningless.
                if len(real_ctrl) < 2 or len(real_pert) < 2 or len(gen_pert) < 2:
                    if self.verbose:
                        print("skipped (insufficient samples)")
                    continue

                # Compute DEGs using real data (control vs perturbed)
                deg_result = self._compute_degs(real_ctrl, real_pert)

                if self.verbose:
                    print(f"{deg_result.n_degs} DEGs", end="... ")

                # Compute metrics on DEGs
                metrics = self._compute_metrics_accelerated(
                    real_pert, gen_pert, deg_result.deg_indices
                )

                ctx_result = ContextResult(
                    context_id=context_id,
                    context_values=context,
                    n_samples_real=len(real_pert),
                    n_samples_gen=len(gen_pert),
                    deg_result=deg_result,
                    metrics=metrics,
                )
                context_results.append(ctx_result)

                if self.verbose:
                    print("done")

            except Exception as e:
                # Best-effort per context: one failing context must not
                # abort the whole evaluation.
                if self.verbose:
                    print(f"error: {e}")
                continue

        # Build result DataFrames
        expanded_data = []
        for ctx_result in context_results:
            row = {
                "context_id": ctx_result.context_id,
                **ctx_result.context_values,
                "n_samples_real": ctx_result.n_samples_real,
                "n_samples_gen": ctx_result.n_samples_gen,
                "n_degs": ctx_result.deg_result.n_degs if ctx_result.deg_result else 0,
                **ctx_result.metrics,
            }
            expanded_data.append(row)

        expanded_metrics = pd.DataFrame(expanded_data)

        # Aggregated metrics
        if len(expanded_metrics) > 0:
            agg_data = {
                "n_contexts": len(context_results),
                "total_samples_real": expanded_metrics["n_samples_real"].sum(),
                "total_samples_gen": expanded_metrics["n_samples_gen"].sum(),
                "mean_n_degs": expanded_metrics["n_degs"].mean(),
                "median_n_degs": expanded_metrics["n_degs"].median(),
            }
            for metric in self.metrics:
                if metric in expanded_metrics.columns:
                    agg_data[f"{metric}_mean"] = expanded_metrics[metric].mean()
                    agg_data[f"{metric}_std"] = expanded_metrics[metric].std()
                    agg_data[f"{metric}_median"] = expanded_metrics[metric].median()

            aggregated_metrics = pd.DataFrame([agg_data])
        else:
            aggregated_metrics = pd.DataFrame()

        # DEG summary (up/down counts from the sign of DEG log fold changes)
        deg_summary_data = []
        for ctx_result in context_results:
            dr = ctx_result.deg_result
            if dr is None:
                continue
            deg_lfcs = dr.log_fold_changes[dr.is_deg]
            deg_summary_data.append({
                "context_id": ctx_result.context_id,
                **ctx_result.context_values,
                "n_degs": dr.n_degs,
                "n_upregulated": (deg_lfcs > 0).sum(),
                "n_downregulated": (deg_lfcs < 0).sum(),
                "mean_abs_lfc": np.abs(deg_lfcs).mean() if dr.n_degs > 0 else np.nan,
            })
        deg_summary = pd.DataFrame(deg_summary_data)

        self._log(f"\nEvaluation complete: {len(context_results)} contexts evaluated")

        return DEGEvaluationResult(
            context_results=context_results,
            aggregated_metrics=aggregated_metrics,
            expanded_metrics=expanded_metrics,
            deg_summary=deg_summary,
            gene_names=self.gene_names,
            settings={
                "deg_method": self.deg_method,
                "pval_threshold": self.pval_threshold,
                "lfc_threshold": self.lfc_threshold,
                "min_degs": self.min_degs,
                "metrics": self.metrics,
                "device": self.device,
                "n_jobs": self.n_jobs,
            },
        )
461
+
462
+
463
def evaluate_degs(
    real_data: np.ndarray,
    generated_data: np.ndarray,
    real_obs: pd.DataFrame,
    generated_obs: pd.DataFrame,
    condition_columns: List[str],
    gene_names: Optional[np.ndarray] = None,
    control_key: str = "control",
    perturbation_column: Optional[str] = None,
    deg_method: DEGMethod = "welch",
    pval_threshold: float = 0.05,
    lfc_threshold: float = 0.5,
    metrics: Optional[List[str]] = None,
    n_jobs: int = 1,
    device: str = "auto",
    verbose: bool = True,
    min_degs: int = 5,
) -> DEGEvaluationResult:
    """
    Convenience function for DEG-focused evaluation.

    Builds a DEGEvaluator with the given settings and runs it.

    Parameters
    ----------
    real_data : np.ndarray
        Real expression matrix
    generated_data : np.ndarray
        Generated expression matrix
    real_obs : pd.DataFrame
        Real data metadata
    generated_obs : pd.DataFrame
        Generated data metadata
    condition_columns : List[str]
        Columns defining contexts
    gene_names : np.ndarray, optional
        Gene names
    control_key : str
        Control condition identifier
    perturbation_column : str, optional
        Column containing perturbation info. If None, uses first condition column.
    deg_method : str
        DEG detection method
    pval_threshold : float
        P-value threshold
    lfc_threshold : float
        Log fold change threshold
    metrics : List[str], optional
        Metrics to compute
    n_jobs : int
        Parallel CPU jobs
    device : str
        Compute device (default "auto")
    verbose : bool
        Print progress
    min_degs : int
        Minimum DEGs required to compute metrics (default 5, matching
        DEGEvaluator). Added keyword-only in practice; kept last for
        backward compatibility with positional callers.

    Returns
    -------
    DEGEvaluationResult
        Evaluation results
    """
    evaluator = DEGEvaluator(
        real_data=real_data,
        generated_data=generated_data,
        real_obs=real_obs,
        generated_obs=generated_obs,
        condition_columns=condition_columns,
        gene_names=gene_names,
        control_key=control_key,
        perturbation_column=perturbation_column,
        deg_method=deg_method,
        pval_threshold=pval_threshold,
        lfc_threshold=lfc_threshold,
        min_degs=min_degs,
        metrics=metrics,
        n_jobs=n_jobs,
        device=device,
        verbose=verbose,
    )
    return evaluator.evaluate()