gengeneeval 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geneval/__init__.py CHANGED
@@ -49,7 +49,7 @@ CLI Usage:
          --conditions perturbation cell_type --output results/
  """

- __version__ = "0.4.0"
+ __version__ = "0.4.1"
  __author__ = "GenEval Team"

  # Main evaluation interface
geneval/deg/__init__.py CHANGED
@@ -31,6 +31,8 @@ from .context import (
  from .evaluator import (
      DEGEvaluator,
      DEGEvaluationResult,
+     DEGSettings,
+     ContextMetrics,
      evaluate_degs,
  )
  from .visualization import (
@@ -56,6 +58,8 @@ __all__ = [
      # Evaluator
      "DEGEvaluator",
      "DEGEvaluationResult",
+     "DEGSettings",
+     "ContextMetrics",
      "evaluate_degs",
      # Visualization
      "plot_deg_distributions",
geneval/deg/evaluator.py CHANGED
@@ -1,7 +1,9 @@
  """
  DEG-focused evaluator for GenGeneEval.

- Computes metrics only on differentially expressed genes, with support for:
+ Computes metrics on differentially expressed genes with full control:
+ - Comparison of DEG-only vs all-genes metrics
+ - Configurable DEG selection (top N, p-value, log fold change thresholds)
  - Per-context evaluation (covariates × perturbations)
  - Fast DEG detection with GPU acceleration
  - Aggregated and expanded result reporting
@@ -45,29 +47,109 @@ from ..metrics.accelerated import (
  )


+ @dataclass
+ class DEGSettings:
+     """Settings for DEG detection and filtering.
+ 
+     Attributes
+     ----------
+     method : str
+         DEG detection method: "welch", "student", "wilcoxon", "logfc"
+     pval_threshold : float
+         P-value threshold for significance
+     lfc_threshold : float
+         Absolute log2 fold change threshold
+     n_top_degs : int, optional
+         If set, use only top N DEGs by significance (overrides threshold filtering)
+     min_degs : int
+         Minimum number of DEGs required to compute metrics
+     """
+     method: str = "welch"
+     pval_threshold: float = 0.05
+     lfc_threshold: float = 0.5
+     n_top_degs: Optional[int] = None
+     min_degs: int = 5
+ 
+     def to_dict(self) -> Dict[str, Any]:
+         """Convert to dictionary."""
+         return {
+             "deg_method": self.method,
+             "pval_threshold": self.pval_threshold,
+             "lfc_threshold": self.lfc_threshold,
+             "n_top_degs": self.n_top_degs,
+             "min_degs": self.min_degs,
+         }
+ 
+ 
+ @dataclass
+ class ContextMetrics:
+     """Metrics for a single context, comparing DEG-only vs all genes.
+ 
+     Attributes
+     ----------
+     context_id : str
+         Context identifier
+     context_values : Dict
+         Context column values
+     n_samples_real : int
+         Number of real samples
+     n_samples_gen : int
+         Number of generated samples
+     n_genes_total : int
+         Total number of genes
+     deg_result : DEGResult
+         DEG detection results
+     deg_metrics : Dict[str, float]
+         Metrics computed on DEGs only
+     all_genes_metrics : Dict[str, float]
+         Metrics computed on all genes
+     """
+     context_id: str
+     context_values: Dict[str, Any]
+     n_samples_real: int
+     n_samples_gen: int
+     n_genes_total: int
+     deg_result: Optional[DEGResult]
+     deg_metrics: Dict[str, float]
+     all_genes_metrics: Dict[str, float]
+ 
+     @property
+     def n_degs(self) -> int:
+         """Number of DEGs."""
+         return self.deg_result.n_degs if self.deg_result else 0
+ 
+     @property
+     def deg_indices_used(self) -> np.ndarray:
+         """DEG indices actually used for metrics."""
+         return self.deg_result.deg_indices if self.deg_result else np.array([])
+ 
+ 
  @dataclass
  class DEGEvaluationResult:
-     """Complete DEG evaluation results.
+     """Complete DEG evaluation results with comparison to all-genes metrics.

      Attributes
      ----------
-     context_results : List[ContextResult]
-         Results for each context
+     context_results : List[ContextMetrics]
+         Results for each context with both DEG and all-gene metrics
      aggregated_metrics : pd.DataFrame
-         Aggregated metrics across contexts
+         Aggregated metrics across contexts (both DEG and all-genes)
      expanded_metrics : pd.DataFrame
-         Per-context expanded metrics
+         Per-context expanded metrics (both DEG and all-genes)
      deg_summary : pd.DataFrame
          Summary of DEG detection per context
+     comparison_summary : pd.DataFrame
+         Comparison between DEG-only and all-genes metrics
      gene_names : np.ndarray
          All gene names
      settings : Dict
-         Evaluation settings
+         Evaluation settings including DEG parameters
      """
-     context_results: List[ContextResult]
+     context_results: List[ContextMetrics]
      aggregated_metrics: pd.DataFrame
      expanded_metrics: pd.DataFrame
      deg_summary: pd.DataFrame
+     comparison_summary: pd.DataFrame
      gene_names: np.ndarray
      settings: Dict[str, Any]

@@ -76,9 +158,15 @@ class DEGEvaluationResult:
          output_dir = Path(output_dir)
          output_dir.mkdir(parents=True, exist_ok=True)

-         self.aggregated_metrics.to_csv(output_dir / "deg_aggregated_metrics.csv")
-         self.expanded_metrics.to_csv(output_dir / "deg_expanded_metrics.csv")
-         self.deg_summary.to_csv(output_dir / "deg_summary.csv")
+         self.aggregated_metrics.to_csv(output_dir / "deg_aggregated_metrics.csv", index=False)
+         self.expanded_metrics.to_csv(output_dir / "deg_expanded_metrics.csv", index=False)
+         self.deg_summary.to_csv(output_dir / "deg_summary.csv", index=False)
+         self.comparison_summary.to_csv(output_dir / "deg_vs_all_comparison.csv", index=False)
+ 
+         # Save settings
+         import json
+         with open(output_dir / "settings.json", "w") as f:
+             json.dump(self.settings, f, indent=2, default=str)

          # Save per-context DEG results
          deg_dir = output_dir / "deg_per_context"
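
Together with the `DEGSettings.to_dict()` helper added earlier in this file, the `save()` changes mean a results directory now carries its own provenance. A hedged sketch of the resulting layout and of reading the settings back, assuming `results` is a `DEGEvaluationResult` returned by `DEGEvaluator.evaluate()` or `evaluate_degs()` (both appear later in this diff); the directory name is just an example:

```python
# Expected files after results.save("deg_evaluation/"), per the to_csv/json.dump calls above:
#   deg_evaluation/deg_aggregated_metrics.csv
#   deg_evaluation/deg_expanded_metrics.csv
#   deg_evaluation/deg_summary.csv
#   deg_evaluation/deg_vs_all_comparison.csv
#   deg_evaluation/settings.json
#   deg_evaluation/deg_per_context/<context_id>_degs.csv
import json

results.save("deg_evaluation/")  # `results` assumed to exist, see note above
with open("deg_evaluation/settings.json") as f:
    settings = json.load(f)
print(settings["deg_method"], settings["pval_threshold"], settings["n_top_degs"])
```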
@@ -86,24 +174,41 @@ class DEGEvaluationResult:
          for ctx_result in self.context_results:
              if ctx_result.deg_result is not None:
                  ctx_result.deg_result.to_dataframe().to_csv(
-                     deg_dir / f"{ctx_result.context_id}_degs.csv"
+                     deg_dir / f"{ctx_result.context_id}_degs.csv", index=False
                  )

+     def get_deg_only_metrics(self) -> pd.DataFrame:
+         """Get expanded metrics for DEGs only."""
+         base_cols = ["context_id", "n_samples_real", "n_samples_gen", "n_degs", "n_genes_total"]
+         deg_cols = [c for c in self.expanded_metrics.columns if c.startswith("deg_")]
+         cols = [c for c in base_cols if c in self.expanded_metrics.columns] + deg_cols
+         return self.expanded_metrics[cols].copy()
+ 
+     def get_all_genes_metrics(self) -> pd.DataFrame:
+         """Get expanded metrics for all genes."""
+         base_cols = ["context_id", "n_samples_real", "n_samples_gen", "n_degs", "n_genes_total"]
+         all_cols = [c for c in self.expanded_metrics.columns if c.startswith("all_")]
+         cols = [c for c in base_cols if c in self.expanded_metrics.columns] + all_cols
+         return self.expanded_metrics[cols].copy()
+ 
      def __repr__(self) -> str:
+         n_degs_avg = self.deg_summary["n_degs"].mean() if len(self.deg_summary) > 0 else 0
          return (
              f"DEGEvaluationResult(n_contexts={len(self.context_results)}, "
-             f"metrics={list(self.aggregated_metrics.columns)})"
+             f"avg_degs={n_degs_avg:.1f}, "
+             f"settings={self.settings.get('deg_method', 'unknown')})"
          )


  class DEGEvaluator:
      """
-     Evaluator that computes metrics on DEGs only.
+     Evaluator that computes metrics on DEGs with comparison to all genes.

      This evaluator:
      1. Detects DEGs for each perturbation context
-     2. Computes distributional metrics only on DEG genes
-     3. Reports per-context and aggregated results
+     2. Computes distributional metrics on BOTH DEG genes AND all genes
+     3. Provides comparison between DEG-focused and all-genes evaluation
+     4. Reports per-context and aggregated results

      Parameters
      ----------
@@ -126,11 +231,15 @@ class DEGEvaluator:
      deg_method : str
          DEG detection method: "welch", "student", "wilcoxon", "logfc"
      pval_threshold : float
-         P-value threshold for DEG significance
+         P-value threshold for DEG significance (default: 0.05)
      lfc_threshold : float
-         Log2 fold change threshold
+         Log2 fold change threshold (default: 0.5)
+     n_top_degs : int, optional
+         If set, use only top N DEGs by significance instead of thresholds
      min_degs : int
-         Minimum DEGs required to compute metrics
+         Minimum DEGs required to compute DEG-specific metrics (default: 5)
+     compute_all_genes : bool
+         Whether to also compute metrics on all genes (default: True)
      metrics : List[str], optional
          Metrics to compute. Default: all supported metrics.
      n_jobs : int
@@ -142,15 +251,39 @@ class DEGEvaluator:

      Examples
      --------
+     >>> # Basic usage - computes both DEG and all-genes metrics
      >>> evaluator = DEGEvaluator(
      ...     real_data, generated_data,
      ...     real_obs, generated_obs,
      ...     condition_columns=["perturbation"],
-     ...     deg_method="welch",
-     ...     device="cuda",
      ... )
      >>> results = evaluator.evaluate()
-     >>> results.save("output/")
+     >>> print(results.comparison_summary)  # DEG vs all-genes comparison
+ 
+     >>> # Use top 100 DEGs only
+     >>> evaluator = DEGEvaluator(
+     ...     real_data, generated_data,
+     ...     real_obs, generated_obs,
+     ...     condition_columns=["perturbation"],
+     ...     n_top_degs=100,  # Use top 100 most significant DEGs
+     ... )
+ 
+     >>> # Stricter thresholds
+     >>> evaluator = DEGEvaluator(
+     ...     real_data, generated_data,
+     ...     real_obs, generated_obs,
+     ...     condition_columns=["perturbation"],
+     ...     pval_threshold=0.01,  # More stringent p-value
+     ...     lfc_threshold=1.0,    # log2 FC > 1 (2-fold change)
+     ... )
+ 
+     >>> # DEGs only (no all-genes metrics for speed)
+     >>> evaluator = DEGEvaluator(
+     ...     real_data, generated_data,
+     ...     real_obs, generated_obs,
+     ...     condition_columns=["perturbation"],
+     ...     compute_all_genes=False,
+     ... )
      """

      # Supported metrics
@@ -176,7 +309,9 @@ class DEGEvaluator:
          deg_method: DEGMethod = "welch",
          pval_threshold: float = 0.05,
          lfc_threshold: float = 0.5,
+         n_top_degs: Optional[int] = None,
          min_degs: int = 5,
+         compute_all_genes: bool = True,
          metrics: Optional[List[str]] = None,
          n_jobs: int = 1,
          device: str = "cpu",
@@ -187,15 +322,23 @@ class DEGEvaluator:
          self.real_obs = real_obs.reset_index(drop=True)
          self.generated_obs = generated_obs.reset_index(drop=True)
          self.condition_columns = condition_columns
+         self.n_genes = real_data.shape[1]
          self.gene_names = gene_names if gene_names is not None else np.array(
-             [f"Gene_{i}" for i in range(real_data.shape[1])]
+             [f"Gene_{i}" for i in range(self.n_genes)]
          )
          self.control_key = control_key
          self.perturbation_column = perturbation_column or condition_columns[0]
-         self.deg_method = deg_method
-         self.pval_threshold = pval_threshold
-         self.lfc_threshold = lfc_threshold
-         self.min_degs = min_degs
+ 
+         # DEG settings
+         self.deg_settings = DEGSettings(
+             method=deg_method,
+             pval_threshold=pval_threshold,
+             lfc_threshold=lfc_threshold,
+             n_top_degs=n_top_degs,
+             min_degs=min_degs,
+         )
+ 
+         self.compute_all_genes = compute_all_genes
          self.metrics = metrics or self.SUPPORTED_METRICS
          self.n_jobs = n_jobs
          self.device = device
@@ -224,7 +367,11 @@ class DEGEvaluator:
          }

          self._log(f"DEGEvaluator initialized with {len(self.context_evaluator)} contexts")
-         self._log(f"Perturbation contexts: {len(self.context_evaluator.get_perturbation_contexts())}")
+         self._log(f"DEG settings: method={deg_method}, pval<{pval_threshold}, |lfc|>{lfc_threshold}")
+         if n_top_degs is not None:
+             self._log(f"  Using top {n_top_degs} DEGs by significance")
+         if compute_all_genes:
+             self._log("Will compute metrics on BOTH DEGs and all genes")

      def _log(self, msg: str) -> None:
          """Print if verbose."""
@@ -237,78 +384,87 @@ class DEGEvaluator:
          perturbed: np.ndarray,
      ) -> DEGResult:
          """Compute DEGs using configured method and device."""
-         return compute_degs_auto(
+         deg_result = compute_degs_auto(
              control=control,
              perturbed=perturbed,
              gene_names=self.gene_names,
-             method=self.deg_method,
-             pval_threshold=self.pval_threshold,
-             lfc_threshold=self.lfc_threshold,
+             method=self.deg_settings.method,
+             pval_threshold=self.deg_settings.pval_threshold,
+             lfc_threshold=self.deg_settings.lfc_threshold,
              n_jobs=self.n_jobs,
              device=self.device,
          )
+ 
+         # If n_top_degs is set, limit to top N by significance
+         if self.deg_settings.n_top_degs is not None:
+             deg_result = self._filter_top_degs(deg_result, self.deg_settings.n_top_degs)
+ 
+         return deg_result

-     def _compute_metrics_on_degs(
-         self,
-         real: np.ndarray,
-         generated: np.ndarray,
-         deg_indices: np.ndarray,
-     ) -> Dict[str, float]:
-         """Compute metrics on DEG genes only."""
-         if len(deg_indices) < self.min_degs:
-             return {m: np.nan for m in self.metrics}
+     def _filter_top_degs(self, deg_result: DEGResult, n_top: int) -> DEGResult:
+         """Filter DEG result to keep only top N most significant DEGs."""
+         if deg_result.n_degs <= n_top:
+             return deg_result  # Already fewer than n_top

-         # Slice to DEGs only
-         real_degs = real[:, deg_indices]
-         gen_degs = generated[:, deg_indices]
+         # Sort DEGs by adjusted p-value (lower is more significant)
+         deg_pvals = deg_result.pvalues_adj[deg_result.is_deg]
+         deg_indices = deg_result.deg_indices

-         results = {}
+         # Get indices of top N most significant
+         top_n_order = np.argsort(deg_pvals)[:n_top]
+         top_deg_indices = deg_indices[top_n_order]

-         for metric_name in self.metrics:
-             if metric_name not in self._metric_objects:
-                 continue
-
-             metric = self._metric_objects[metric_name]
-
-             try:
-                 # Compute per-gene and aggregate
-                 per_gene = metric.compute_per_gene(real_degs, gen_degs)
-                 results[metric_name] = float(np.nanmean(per_gene))
-             except Exception as e:
-                 if self.verbose:
-                     self._log(f"Warning: {metric_name} failed: {e}")
-                 results[metric_name] = np.nan
+         # Create new is_deg mask
+         new_is_deg = np.zeros(len(deg_result.is_deg), dtype=bool)
+         new_is_deg[top_deg_indices] = True

-         return results
+         # Create modified DEGResult with all required fields
+         return DEGResult(
+             gene_names=deg_result.gene_names,
+             pvalues=deg_result.pvalues,
+             pvalues_adj=deg_result.pvalues_adj,
+             log_fold_changes=deg_result.log_fold_changes,
+             mean_control=deg_result.mean_control,
+             mean_perturbed=deg_result.mean_perturbed,
+             is_deg=new_is_deg,
+             n_degs=n_top,
+             method=deg_result.method,
+             pval_threshold=deg_result.pval_threshold,
+             lfc_threshold=deg_result.lfc_threshold,
+             deg_indices=top_deg_indices,
+         )

-     def _compute_metrics_accelerated(
+     def _compute_metrics_on_genes(
          self,
          real: np.ndarray,
          generated: np.ndarray,
-         deg_indices: np.ndarray,
+         gene_indices: Optional[np.ndarray] = None,
+         min_genes: int = 1,
      ) -> Dict[str, float]:
-         """Compute metrics using accelerated implementations."""
-         if len(deg_indices) < self.min_degs:
-             return {m: np.nan for m in self.metrics}
-
-         # Slice to DEGs only
-         real_degs = real[:, deg_indices]
-         gen_degs = generated[:, deg_indices]
+         """Compute metrics on specified genes (or all if indices is None)."""
+         # Slice to selected genes
+         if gene_indices is not None:
+             if len(gene_indices) < min_genes:
+                 return {m: np.nan for m in self.metrics}
+             real_subset = real[:, gene_indices]
+             gen_subset = generated[:, gene_indices]
+         else:
+             real_subset = real
+             gen_subset = generated

          results = {}
-         backends = get_available_backends()

          # Use vectorized implementations where available
          if "wasserstein_1" in self.metrics:
              try:
-                 w1_per_gene = vectorized_wasserstein1(real_degs, gen_degs)
+                 w1_per_gene = vectorized_wasserstein1(real_subset, gen_subset)
                  results["wasserstein_1"] = float(np.nanmean(w1_per_gene))
              except Exception:
                  results["wasserstein_1"] = np.nan

          if "mmd" in self.metrics:
              try:
-                 mmd_per_gene = vectorized_mmd(real_degs, gen_degs)
+                 mmd_per_gene = vectorized_mmd(real_subset, gen_subset)
                  results["mmd"] = float(np.nanmean(mmd_per_gene))
              except Exception:
                  results["mmd"] = np.nan
@@ -322,7 +478,7 @@ class DEGEvaluator:

              metric = self._metric_objects[metric_name]
              try:
-                 per_gene = metric.compute_per_gene(real_degs, gen_degs)
+                 per_gene = metric.compute_per_gene(real_subset, gen_subset)
                  results[metric_name] = float(np.nanmean(per_gene))
              except Exception:
                  results[metric_name] = np.nan
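
The `_filter_top_degs` helper above keeps the `n_top` most significant DEGs by sorting adjusted p-values, and the selected indices are what `_compute_metrics_on_genes` slices on. A standalone sketch of that selection logic with toy numbers (pure numpy, not package code):

```python
# Keep the n_top DEG indices with the smallest adjusted p-values.
import numpy as np

pvalues_adj = np.array([0.001, 0.20, 0.0004, 0.03, 0.8, 0.002])
is_deg = pvalues_adj < 0.05                       # boolean DEG mask
deg_indices = np.flatnonzero(is_deg)              # [0, 2, 3, 5]

n_top = 2
order = np.argsort(pvalues_adj[is_deg])[:n_top]   # positions of the most significant DEGs
top_deg_indices = deg_indices[order]              # [2, 0]

new_is_deg = np.zeros(pvalues_adj.size, dtype=bool)
new_is_deg[top_deg_indices] = True
print(top_deg_indices, new_is_deg)
```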
@@ -333,17 +489,23 @@ class DEGEvaluator:
          """
          Run DEG-focused evaluation on all contexts.

+         Returns both DEG-only and all-genes metrics for comparison.
+ 
          Returns
          -------
          DEGEvaluationResult
-             Complete evaluation results with per-context and aggregated metrics.
+             Complete evaluation results with:
+             - Per-context DEG and all-genes metrics
+             - Aggregated metrics
+             - DEG summary
+             - Comparison between DEG and all-genes evaluation
          """
-         context_results = []
+         context_results: List[ContextMetrics] = []

          perturbation_contexts = self.context_evaluator.get_perturbation_contexts()
          n_contexts = len(perturbation_contexts)

-         self._log(f"Evaluating {n_contexts} perturbation contexts...")
+         self._log(f"\nEvaluating {n_contexts} perturbation contexts...")

          for i, context in enumerate(perturbation_contexts):
              context_id = get_context_id(context)
@@ -370,17 +532,30 @@ class DEGEvaluator:
                  print(f"{deg_result.n_degs} DEGs", end="... ")

              # Compute metrics on DEGs
-             metrics = self._compute_metrics_accelerated(
-                 real_pert, gen_pert, deg_result.deg_indices
+             deg_metrics = self._compute_metrics_on_genes(
+                 real_pert, gen_pert,
+                 gene_indices=deg_result.deg_indices,
+                 min_genes=self.deg_settings.min_degs,
              )

-             ctx_result = ContextResult(
+             # Compute metrics on all genes if requested
+             if self.compute_all_genes:
+                 all_genes_metrics = self._compute_metrics_on_genes(
+                     real_pert, gen_pert,
+                     gene_indices=None,  # All genes
+                 )
+             else:
+                 all_genes_metrics = {}
+ 
+             ctx_result = ContextMetrics(
                  context_id=context_id,
                  context_values=context,
                  n_samples_real=len(real_pert),
                  n_samples_gen=len(gen_pert),
+                 n_genes_total=self.n_genes,
                  deg_result=deg_result,
-                 metrics=metrics,
+                 deg_metrics=deg_metrics,
+                 all_genes_metrics=all_genes_metrics,
              )
              context_results.append(ctx_result)

@@ -393,52 +568,10 @@ class DEGEvaluator:
                  continue

          # Build result DataFrames
-         expanded_data = []
-         for ctx_result in context_results:
-             row = {
-                 "context_id": ctx_result.context_id,
-                 **ctx_result.context_values,
-                 "n_samples_real": ctx_result.n_samples_real,
-                 "n_samples_gen": ctx_result.n_samples_gen,
-                 "n_degs": ctx_result.deg_result.n_degs if ctx_result.deg_result else 0,
-                 **ctx_result.metrics,
-             }
-             expanded_data.append(row)
-
-         expanded_metrics = pd.DataFrame(expanded_data)
-
-         # Aggregated metrics
-         if len(expanded_metrics) > 0:
-             agg_data = {
-                 "n_contexts": len(context_results),
-                 "total_samples_real": expanded_metrics["n_samples_real"].sum(),
-                 "total_samples_gen": expanded_metrics["n_samples_gen"].sum(),
-                 "mean_n_degs": expanded_metrics["n_degs"].mean(),
-                 "median_n_degs": expanded_metrics["n_degs"].median(),
-             }
-             for metric in self.metrics:
-                 if metric in expanded_metrics.columns:
-                     agg_data[f"{metric}_mean"] = expanded_metrics[metric].mean()
-                     agg_data[f"{metric}_std"] = expanded_metrics[metric].std()
-                     agg_data[f"{metric}_median"] = expanded_metrics[metric].median()
-
-             aggregated_metrics = pd.DataFrame([agg_data])
-         else:
-             aggregated_metrics = pd.DataFrame()
-
-         # DEG summary
-         deg_summary_data = []
-         for ctx_result in context_results:
-             if ctx_result.deg_result is not None:
-                 deg_summary_data.append({
-                     "context_id": ctx_result.context_id,
-                     **ctx_result.context_values,
-                     "n_degs": ctx_result.deg_result.n_degs,
-                     "n_upregulated": (ctx_result.deg_result.log_fold_changes[ctx_result.deg_result.is_deg] > 0).sum(),
-                     "n_downregulated": (ctx_result.deg_result.log_fold_changes[ctx_result.deg_result.is_deg] < 0).sum(),
-                     "mean_abs_lfc": np.abs(ctx_result.deg_result.log_fold_changes[ctx_result.deg_result.is_deg]).mean() if ctx_result.deg_result.n_degs > 0 else np.nan,
-                 })
-         deg_summary = pd.DataFrame(deg_summary_data)
+         expanded_metrics = self._build_expanded_metrics(context_results)
+         aggregated_metrics = self._build_aggregated_metrics(expanded_metrics)
+         deg_summary = self._build_deg_summary(context_results)
+         comparison_summary = self._build_comparison_summary(expanded_metrics)

          self._log(f"\nEvaluation complete: {len(context_results)} contexts evaluated")

@@ -447,17 +580,122 @@ class DEGEvaluator:
              aggregated_metrics=aggregated_metrics,
              expanded_metrics=expanded_metrics,
              deg_summary=deg_summary,
+             comparison_summary=comparison_summary,
              gene_names=self.gene_names,
              settings={
-                 "deg_method": self.deg_method,
-                 "pval_threshold": self.pval_threshold,
-                 "lfc_threshold": self.lfc_threshold,
-                 "min_degs": self.min_degs,
+                 **self.deg_settings.to_dict(),
+                 "compute_all_genes": self.compute_all_genes,
                  "metrics": self.metrics,
                  "device": self.device,
                  "n_jobs": self.n_jobs,
              },
          )
+ 
+     def _build_expanded_metrics(self, context_results: List[ContextMetrics]) -> pd.DataFrame:
+         """Build expanded metrics DataFrame with both DEG and all-genes columns."""
+         data = []
+         for ctx in context_results:
+             row = {
+                 "context_id": ctx.context_id,
+                 **ctx.context_values,
+                 "n_samples_real": ctx.n_samples_real,
+                 "n_samples_gen": ctx.n_samples_gen,
+                 "n_genes_total": ctx.n_genes_total,
+                 "n_degs": ctx.n_degs,
+             }
+ 
+             # DEG metrics with prefix
+             for metric_name, value in ctx.deg_metrics.items():
+                 row[f"deg_{metric_name}"] = value
+ 
+             # All-genes metrics with prefix
+             for metric_name, value in ctx.all_genes_metrics.items():
+                 row[f"all_{metric_name}"] = value
+ 
+             data.append(row)
+ 
+         return pd.DataFrame(data)
+ 
+     def _build_aggregated_metrics(self, expanded_metrics: pd.DataFrame) -> pd.DataFrame:
+         """Build aggregated metrics DataFrame."""
+         if len(expanded_metrics) == 0:
+             return pd.DataFrame()
+ 
+         agg_data = {
+             "n_contexts": len(expanded_metrics),
+             "total_samples_real": expanded_metrics["n_samples_real"].sum(),
+             "total_samples_gen": expanded_metrics["n_samples_gen"].sum(),
+             "mean_n_degs": expanded_metrics["n_degs"].mean(),
+             "median_n_degs": expanded_metrics["n_degs"].median(),
+             "min_n_degs": expanded_metrics["n_degs"].min(),
+             "max_n_degs": expanded_metrics["n_degs"].max(),
+         }
+ 
+         # Aggregate DEG metrics
+         for metric in self.metrics:
+             col = f"deg_{metric}"
+             if col in expanded_metrics.columns:
+                 agg_data[f"deg_{metric}_mean"] = expanded_metrics[col].mean()
+                 agg_data[f"deg_{metric}_std"] = expanded_metrics[col].std()
+ 
+         # Aggregate all-genes metrics
+         for metric in self.metrics:
+             col = f"all_{metric}"
+             if col in expanded_metrics.columns:
+                 agg_data[f"all_{metric}_mean"] = expanded_metrics[col].mean()
+                 agg_data[f"all_{metric}_std"] = expanded_metrics[col].std()
+ 
+         return pd.DataFrame([agg_data])
+ 
+     def _build_deg_summary(self, context_results: List[ContextMetrics]) -> pd.DataFrame:
+         """Build DEG summary DataFrame."""
+         data = []
+         for ctx in context_results:
+             if ctx.deg_result is not None:
+                 deg_lfcs = ctx.deg_result.log_fold_changes[ctx.deg_result.is_deg]
+                 data.append({
+                     "context_id": ctx.context_id,
+                     **ctx.context_values,
+                     "n_degs": ctx.n_degs,
+                     "n_genes_total": ctx.n_genes_total,
+                     "deg_fraction": ctx.n_degs / ctx.n_genes_total if ctx.n_genes_total > 0 else 0,
+                     "n_upregulated": (deg_lfcs > 0).sum(),
+                     "n_downregulated": (deg_lfcs < 0).sum(),
+                     "mean_abs_lfc": float(np.abs(deg_lfcs).mean()) if len(deg_lfcs) > 0 else np.nan,
+                     "max_abs_lfc": float(np.abs(deg_lfcs).max()) if len(deg_lfcs) > 0 else np.nan,
+                 })
+         return pd.DataFrame(data)
+ 
+     def _build_comparison_summary(self, expanded_metrics: pd.DataFrame) -> pd.DataFrame:
+         """Build comparison summary between DEG and all-genes metrics."""
+         if len(expanded_metrics) == 0:
+             return pd.DataFrame()
+ 
+         comparison_data = []
+         for metric in self.metrics:
+             deg_col = f"deg_{metric}"
+             all_col = f"all_{metric}"
+ 
+             if deg_col in expanded_metrics.columns and all_col in expanded_metrics.columns:
+                 deg_values = expanded_metrics[deg_col].dropna()
+                 all_values = expanded_metrics[all_col].dropna()
+ 
+                 if len(deg_values) > 0 and len(all_values) > 0:
+                     deg_mean = deg_values.mean()
+                     all_mean = all_values.mean()
+ 
+                     comparison_data.append({
+                         "metric": metric,
+                         "deg_mean": deg_mean,
+                         "deg_std": deg_values.std(),
+                         "all_mean": all_mean,
+                         "all_std": all_values.std(),
+                         "difference": deg_mean - all_mean,
+                         "ratio": deg_mean / all_mean if all_mean != 0 else np.nan,
+                         "n_contexts": len(deg_values),
+                     })
+ 
+         return pd.DataFrame(comparison_data)


  def evaluate_degs(
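
`_build_comparison_summary` above reduces each metric to a difference and a ratio between the DEG-only column mean and the all-genes column mean across contexts. A small pandas sketch of that reduction with made-up numbers (real values come from `expanded_metrics`):

```python
# Toy version of the deg_* vs all_* comparison; column names follow the prefixes built above.
import numpy as np
import pandas as pd

expanded = pd.DataFrame({
    "deg_wasserstein_1": [5.4, 4.4, 6.5],
    "all_wasserstein_1": [0.69, 0.71, 0.68],
})

deg_vals = expanded["deg_wasserstein_1"].dropna()
all_vals = expanded["all_wasserstein_1"].dropna()
row = {
    "metric": "wasserstein_1",
    "deg_mean": deg_vals.mean(),
    "all_mean": all_vals.mean(),
    "difference": deg_vals.mean() - all_vals.mean(),
    "ratio": deg_vals.mean() / all_vals.mean() if all_vals.mean() != 0 else np.nan,
}
print(pd.DataFrame([row]))
```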
@@ -472,51 +710,93 @@ def evaluate_degs(
      deg_method: DEGMethod = "welch",
      pval_threshold: float = 0.05,
      lfc_threshold: float = 0.5,
+     n_top_degs: Optional[int] = None,
+     min_degs: int = 5,
+     compute_all_genes: bool = True,
      metrics: Optional[List[str]] = None,
      n_jobs: int = 1,
      device: str = "auto",
      verbose: bool = True,
  ) -> DEGEvaluationResult:
      """
-     Convenience function for DEG-focused evaluation.
+     Convenience function for DEG-focused evaluation with full control.
+ 
+     Computes metrics on both DEGs and all genes for comparison.

      Parameters
      ----------
      real_data : np.ndarray
-         Real expression matrix
+         Real expression matrix (n_samples, n_genes)
      generated_data : np.ndarray
-         Generated expression matrix
+         Generated expression matrix (n_samples, n_genes)
      real_obs : pd.DataFrame
-         Real data metadata
+         Real data metadata with condition columns
      generated_obs : pd.DataFrame
-         Generated data metadata
+         Generated data metadata with condition columns
      condition_columns : List[str]
-         Columns defining contexts
+         Columns defining contexts (e.g., ["cell_type", "perturbation"])
      gene_names : np.ndarray, optional
-         Gene names
+         Gene names for output
      control_key : str
-         Control condition identifier
+         Control condition identifier (default: "control")
      perturbation_column : str, optional
          Column containing perturbation info. If None, uses first condition column.
      deg_method : str
-         DEG detection method
+         DEG detection method: "welch", "student", "wilcoxon", "logfc"
      pval_threshold : float
-         P-value threshold
+         Adjusted p-value threshold (default: 0.05)
      lfc_threshold : float
-         Log fold change threshold
+         Absolute log2 fold change threshold (default: 0.5)
+     n_top_degs : int, optional
+         If set, use only top N DEGs by significance (overrides thresholds)
+     min_degs : int
+         Minimum DEGs to compute DEG metrics (default: 5)
+     compute_all_genes : bool
+         Also compute metrics on all genes for comparison (default: True)
      metrics : List[str], optional
-         Metrics to compute
+         Metrics to compute. Default: all supported.
      n_jobs : int
-         Parallel CPU jobs
+         Parallel CPU jobs (default: 1)
      device : str
-         Compute device
+         Compute device: "cpu", "cuda", "mps", "auto" (default: "auto")
      verbose : bool
-         Print progress
+         Print progress (default: True)

      Returns
      -------
      DEGEvaluationResult
-         Evaluation results
+         Complete evaluation results including:
+         - expanded_metrics: Per-context metrics for DEGs and all genes
+         - aggregated_metrics: Summary statistics
+         - deg_summary: DEG detection summary
+         - comparison_summary: DEG vs all-genes comparison
+ 
+     Examples
+     --------
+     >>> # Basic usage with default thresholds
+     >>> results = evaluate_degs(
+     ...     real_data, generated_data,
+     ...     real_obs, generated_obs,
+     ...     condition_columns=["perturbation"],
+     ... )
+     >>> print(results.comparison_summary)
+ 
+     >>> # Top 50 DEGs only
+     >>> results = evaluate_degs(
+     ...     real_data, generated_data,
+     ...     real_obs, generated_obs,
+     ...     condition_columns=["perturbation"],
+     ...     n_top_degs=50,
+     ... )
+ 
+     >>> # Strict thresholds
+     >>> results = evaluate_degs(
+     ...     real_data, generated_data,
+     ...     real_obs, generated_obs,
+     ...     condition_columns=["perturbation"],
+     ...     pval_threshold=0.01,
+     ...     lfc_threshold=1.0,  # 2-fold change
+     ... )
      """
      evaluator = DEGEvaluator(
          real_data=real_data,
@@ -530,6 +810,9 @@ def evaluate_degs(
          deg_method=deg_method,
          pval_threshold=pval_threshold,
          lfc_threshold=lfc_threshold,
+         n_top_degs=n_top_degs,
+         min_degs=min_degs,
+         compute_all_genes=compute_all_genes,
          metrics=metrics,
          n_jobs=n_jobs,
          device=device,
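
The new keyword arguments are simply forwarded to `DEGEvaluator`, so the convenience function exposes the full 0.4.1 configuration surface. A hedged end-to-end sketch on synthetic data, assuming gengeneeval 0.4.1 is installed; shapes, column names, and the random data are illustrative, and with pure noise the DEG counts and metric values will be uninteresting:

```python
# Illustrative call showing the new n_top_degs / compute_all_genes parameters.
import numpy as np
import pandas as pd
from geneval import evaluate_degs

rng = np.random.default_rng(0)
n_genes = 200
real = rng.normal(size=(300, n_genes))
gen = rng.normal(size=(300, n_genes))
obs = pd.DataFrame({"perturbation": ["control"] * 150 + ["drugA"] * 150})

results = evaluate_degs(
    real_data=real,
    generated_data=gen,
    real_obs=obs,
    generated_obs=obs.copy(),
    condition_columns=["perturbation"],
    n_top_degs=50,           # new in 0.4.1
    compute_all_genes=True,  # new in 0.4.1
    device="cpu",
)
print(results.comparison_summary)
```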
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: gengeneeval
- Version: 0.4.0
+ Version: 0.4.1
  Summary: Comprehensive evaluation of generated gene expression data. Computes metrics between real and generated datasets with support for condition matching, DEG-focused evaluation, per-context analysis, train/test splits, memory-efficient lazy loading, CPU parallelization, GPU acceleration, and publication-quality visualizations.
  License: MIT
  License-File: LICENSE
@@ -256,6 +256,8 @@ GenEval supports **Differentially Expressed Genes (DEG)-focused evaluation**, co
256
256
  #### Key Features
257
257
 
258
258
  - **Fast DEG detection**: Vectorized Welch's t-test, Student's t-test, or Wilcoxon rank-sum
259
+ - **DEG vs all-genes comparison**: Compute metrics on both and compare
260
+ - **Flexible DEG selection**: Top N by significance, or threshold-based filtering
259
261
  - **Per-context evaluation**: Automatically evaluates each (covariate × perturbation) combination
260
262
  - **GPU acceleration**: DEG detection and metrics on GPU for large datasets
261
263
  - **Comprehensive reporting**: Aggregated and expanded results with visualizations
@@ -266,7 +268,7 @@ GenEval supports **Differentially Expressed Genes (DEG)-focused evaluation**, co
  from geneval import evaluate_degs
  import pandas as pd

- # Evaluate with DEG-focused metrics
+ # Evaluate with DEG-focused metrics (computes both DEG and all-genes by default)
  results = evaluate_degs(
      real_data=real_adata.X,  # (n_samples, n_genes)
      generated_data=gen_adata.X,
@@ -276,29 +278,62 @@ results = evaluate_degs(
      control_key="control",  # Value indicating control samples
      perturbation_column="perturbation",
      deg_method="welch",  # or "student", "wilcoxon", "logfc"
-     pval_threshold=0.05,
+     pval_threshold=0.05,  # Significance threshold
      lfc_threshold=0.5,  # log2 fold change threshold
+     compute_all_genes=True,  # Also compute metrics on all genes
      device="cuda",  # GPU acceleration
  )

- # Access results
- print(results.aggregated_metrics)  # Summary across all contexts
- print(results.expanded_metrics)  # Per-context metrics
+ # Compare DEG-only vs all-genes metrics
+ print(results.comparison_summary)
+ #          metric  deg_mean  all_mean  difference  ratio
+ #   wasserstein_1      5.34      0.69        4.65   7.74
+ #             mmd      1.14      0.13        1.02   9.00
+ 
+ # Access per-context results
+ print(results.expanded_metrics)  # Has deg_* and all_* columns
  print(results.deg_summary)  # DEG counts per context

  # Save results with plots
  results.save("deg_evaluation/")
  ```

+ #### DEG Selection Control
+ 
+ ```python
+ # Option 1: Top N most significant DEGs
+ results = evaluate_degs(
+     ...,
+     n_top_degs=50,  # Use only top 50 DEGs by adjusted p-value
+ )
+ 
+ # Option 2: Stricter thresholds
+ results = evaluate_degs(
+     ...,
+     pval_threshold=0.01,  # More stringent p-value
+     lfc_threshold=1.0,    # 2-fold change minimum
+ )
+ 
+ # Option 3: DEGs only (skip all-genes metrics for speed)
+ results = evaluate_degs(
+     ...,
+     compute_all_genes=False,
+ )
+ 
+ # Get DEG-only or all-genes metrics separately
+ deg_only = results.get_deg_only_metrics()
+ all_genes = results.get_all_genes_metrics()
+ ```
+ 
  #### Per-Context Evaluation

  When multiple condition columns are provided (e.g., `["cell_type", "perturbation"]`), GenEval evaluates **every combination** separately:

- | Context | n_DEGs | W1 (DEGs only) | MMD (DEGs only) |
- |---------|--------|----------------|-----------------|
- | TypeA_drug1 | 234 | 0.42 | 0.031 |
- | TypeA_drug2 | 189 | 0.38 | 0.027 |
- | TypeB_drug1 | 312 | 0.51 | 0.045 |
+ | Context | n_DEGs | deg_W1 | all_W1 | deg_MMD | all_MMD |
+ |---------|--------|--------|--------|---------|---------|
+ | TypeA_drug1 | 234 | 5.42 | 0.69 | 1.03 | 0.13 |
+ | TypeA_drug2 | 189 | 4.38 | 0.71 | 0.92 | 0.12 |
+ | TypeB_drug1 | 312 | 6.51 | 0.68 | 1.21 | 0.14 |

  If only `perturbation` column is provided, evaluation is done per-perturbation.

@@ -1,4 +1,4 @@
- geneval/__init__.py,sha256=1ENlptAErFX1ThLDuO8J5Hs0ko5gIxGGVq7PZUhBUKY,5418
+ geneval/__init__.py,sha256=UD-fl1x0J0VUTyktgvUzCqaU1kLaU2vmYALfdzm-TzQ,5418
  geneval/cli.py,sha256=0ai0IGyn3SSmEnfLRJhcr0brvUxuNZHE4IXod7jvosU,9977
  geneval/config.py,sha256=gkCjs_gzPWgUZNcmSR3Y70XQCAZ1m9AKLueaM-x8bvw,3729
  geneval/core.py,sha256=No0DP8bNR6LedfCWEedY9C5r_c4M14rvSPaGZqbxc94,1155
@@ -6,10 +6,10 @@ geneval/data/__init__.py,sha256=NQUPVpUnBIabrTH5TuRk0KE9S7sVO5QetZv-MCQmZuw,827
  geneval/data/gene_expression_datamodule.py,sha256=XiBIdf68JZ-3S-FaZsrQlBJA7qL9uUXo2C8y0r4an5M,8009
  geneval/data/lazy_loader.py,sha256=5fTRVjPjcWvYXV-uPWFUF2Nn9rHRdD8lygAUkCW8wOM,20677
  geneval/data/loader.py,sha256=zpRmwGZ4PJkB3rpXXRCMFtvMi4qvUrPkKmvIlGjfRpY,14555
- geneval/deg/__init__.py,sha256=joH816k_UWvu2qVhWb-fTbMQTmAhz4nUvt6yraziRek,1499
+ geneval/deg/__init__.py,sha256=iNKvtbumTA-A1usWhHIP1rbRVNkje5tN5x81FzD6CbI,1577
  geneval/deg/context.py,sha256=_9gnWnRqqCZUDlegV2sT_rQrw8OeP1TIE9NZjNcI0ig,9069
  geneval/deg/detection.py,sha256=gDdHOyFLOfl_B0xutS3KVFy53sreJ19N33B0RRI01wo,18119
- geneval/deg/evaluator.py,sha256=MiBT2GOXUwq9rxHVAnJOVSbybX0rVgTsSDvOeJtnanE,18570
+ geneval/deg/evaluator.py,sha256=uPduuWovUD6B_vie4RomH-F9MgrtaqQbtjmJlvEeDYM,30493
  geneval/deg/visualization.py,sha256=9lWW9vRH_FbkIjJrf1MPobU1Yu_CAh6aw60S7g2Qe2k,10448
  geneval/evaluator.py,sha256=WgdrgqOcGYT35k1keiFEIIRIj2CQaD2DsmBpq9hcLrI,13440
  geneval/evaluators/__init__.py,sha256=i11sHvhsjEAeI3Aw9zFTPmCYuqkGxzTHggAKehe3HQ0,160
@@ -33,8 +33,8 @@ geneval/utils/preprocessing.py,sha256=1Cij1O2dwDR6_zh5IEgLPq3jEmV8VfIRjfQrHiKe3M
  geneval/visualization/__init__.py,sha256=LN19jl5xV4WVJTePaOUHWvKZ_pgDFp1chhcklGkNtm8,792
  geneval/visualization/plots.py,sha256=3K94r3x5NjIUZ-hYVQIivO63VkLOvDWl-BLB_qL2pSY,15008
  geneval/visualization/visualizer.py,sha256=lX7K0j20nAsgdtOOdbxLdLKYAfovEp3hNAnZOjFTCq0,36670
- gengeneeval-0.4.0.dist-info/METADATA,sha256=R3GI2E_z6qC1olM0D3aPKrJ3yjQDf_9-GncDqvNhwMY,12879
- gengeneeval-0.4.0.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
- gengeneeval-0.4.0.dist-info/entry_points.txt,sha256=xTkwnNa2fP0w1uGVsafzRTaCeuBSWLlNO-1CN8uBSK0,43
- gengeneeval-0.4.0.dist-info/licenses/LICENSE,sha256=RDHgHDI4rSDq35R4CAC3npy86YUnmZ81ecO7aHfmmGA,1073
- gengeneeval-0.4.0.dist-info/RECORD,,
+ gengeneeval-0.4.1.dist-info/METADATA,sha256=1xoULyzbHjzOKmTEQcx5fRv7DXEMJwe2XqqKxQhCi2Q,14041
+ gengeneeval-0.4.1.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
+ gengeneeval-0.4.1.dist-info/entry_points.txt,sha256=xTkwnNa2fP0w1uGVsafzRTaCeuBSWLlNO-1CN8uBSK0,43
+ gengeneeval-0.4.1.dist-info/licenses/LICENSE,sha256=RDHgHDI4rSDq35R4CAC3npy86YUnmZ81ecO7aHfmmGA,1073
+ gengeneeval-0.4.1.dist-info/RECORD,,