gengeneeval 0.3.0__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39) hide show
  1. {gengeneeval-0.3.0 → gengeneeval-0.4.0}/PKG-INFO +90 -3
  2. {gengeneeval-0.3.0 → gengeneeval-0.4.0}/README.md +87 -0
  3. {gengeneeval-0.3.0 → gengeneeval-0.4.0}/pyproject.toml +3 -3
  4. {gengeneeval-0.3.0 → gengeneeval-0.4.0}/src/geneval/__init__.py +43 -1
  5. gengeneeval-0.4.0/src/geneval/deg/__init__.py +65 -0
  6. gengeneeval-0.4.0/src/geneval/deg/context.py +271 -0
  7. gengeneeval-0.4.0/src/geneval/deg/detection.py +578 -0
  8. gengeneeval-0.4.0/src/geneval/deg/evaluator.py +538 -0
  9. gengeneeval-0.4.0/src/geneval/deg/visualization.py +376 -0
  10. {gengeneeval-0.3.0 → gengeneeval-0.4.0}/LICENSE +0 -0
  11. {gengeneeval-0.3.0 → gengeneeval-0.4.0}/src/geneval/cli.py +0 -0
  12. {gengeneeval-0.3.0 → gengeneeval-0.4.0}/src/geneval/config.py +0 -0
  13. {gengeneeval-0.3.0 → gengeneeval-0.4.0}/src/geneval/core.py +0 -0
  14. {gengeneeval-0.3.0 → gengeneeval-0.4.0}/src/geneval/data/__init__.py +0 -0
  15. {gengeneeval-0.3.0 → gengeneeval-0.4.0}/src/geneval/data/gene_expression_datamodule.py +0 -0
  16. {gengeneeval-0.3.0 → gengeneeval-0.4.0}/src/geneval/data/lazy_loader.py +0 -0
  17. {gengeneeval-0.3.0 → gengeneeval-0.4.0}/src/geneval/data/loader.py +0 -0
  18. {gengeneeval-0.3.0 → gengeneeval-0.4.0}/src/geneval/evaluator.py +0 -0
  19. {gengeneeval-0.3.0 → gengeneeval-0.4.0}/src/geneval/evaluators/__init__.py +0 -0
  20. {gengeneeval-0.3.0 → gengeneeval-0.4.0}/src/geneval/evaluators/base_evaluator.py +0 -0
  21. {gengeneeval-0.3.0 → gengeneeval-0.4.0}/src/geneval/evaluators/gene_expression_evaluator.py +0 -0
  22. {gengeneeval-0.3.0 → gengeneeval-0.4.0}/src/geneval/lazy_evaluator.py +0 -0
  23. {gengeneeval-0.3.0 → gengeneeval-0.4.0}/src/geneval/metrics/__init__.py +0 -0
  24. {gengeneeval-0.3.0 → gengeneeval-0.4.0}/src/geneval/metrics/accelerated.py +0 -0
  25. {gengeneeval-0.3.0 → gengeneeval-0.4.0}/src/geneval/metrics/base_metric.py +0 -0
  26. {gengeneeval-0.3.0 → gengeneeval-0.4.0}/src/geneval/metrics/correlation.py +0 -0
  27. {gengeneeval-0.3.0 → gengeneeval-0.4.0}/src/geneval/metrics/distances.py +0 -0
  28. {gengeneeval-0.3.0 → gengeneeval-0.4.0}/src/geneval/metrics/metrics.py +0 -0
  29. {gengeneeval-0.3.0 → gengeneeval-0.4.0}/src/geneval/metrics/reconstruction.py +0 -0
  30. {gengeneeval-0.3.0 → gengeneeval-0.4.0}/src/geneval/models/__init__.py +0 -0
  31. {gengeneeval-0.3.0 → gengeneeval-0.4.0}/src/geneval/models/base_model.py +0 -0
  32. {gengeneeval-0.3.0 → gengeneeval-0.4.0}/src/geneval/results.py +0 -0
  33. {gengeneeval-0.3.0 → gengeneeval-0.4.0}/src/geneval/testing.py +0 -0
  34. {gengeneeval-0.3.0 → gengeneeval-0.4.0}/src/geneval/utils/__init__.py +0 -0
  35. {gengeneeval-0.3.0 → gengeneeval-0.4.0}/src/geneval/utils/io.py +0 -0
  36. {gengeneeval-0.3.0 → gengeneeval-0.4.0}/src/geneval/utils/preprocessing.py +0 -0
  37. {gengeneeval-0.3.0 → gengeneeval-0.4.0}/src/geneval/visualization/__init__.py +0 -0
  38. {gengeneeval-0.3.0 → gengeneeval-0.4.0}/src/geneval/visualization/plots.py +0 -0
  39. {gengeneeval-0.3.0 → gengeneeval-0.4.0}/src/geneval/visualization/visualizer.py +0 -0
@@ -1,10 +1,10 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: gengeneeval
3
- Version: 0.3.0
4
- Summary: Comprehensive evaluation of generated gene expression data. Computes metrics between real and generated datasets with support for condition matching, train/test splits, memory-efficient lazy loading, CPU parallelization, GPU acceleration, and publication-quality visualizations.
3
+ Version: 0.4.0
4
+ Summary: Comprehensive evaluation of generated gene expression data. Computes metrics between real and generated datasets with support for condition matching, DEG-focused evaluation, per-context analysis, train/test splits, memory-efficient lazy loading, CPU parallelization, GPU acceleration, and publication-quality visualizations.
5
5
  License: MIT
6
6
  License-File: LICENSE
7
- Keywords: gene expression,evaluation,metrics,single-cell,generative models,benchmarking,memory-efficient
7
+ Keywords: gene expression,evaluation,metrics,single-cell,generative models,benchmarking,memory-efficient,DEG,perturbation
8
8
  Author: GenEval Team
9
9
  Author-email: geneval@example.com
10
10
  Requires-Python: >=3.8,<4.0
@@ -78,6 +78,8 @@ All metrics are computed **per-gene** (returning a vector) and **aggregated**:
78
78
  - ✅ Condition-based matching (perturbation, cell type, etc.)
79
79
  - ✅ Train/test split support
80
80
  - ✅ Per-gene and aggregate metrics
81
+ - ✅ **DEG-focused evaluation** with per-context (covariate × perturbation) support
82
+ - ✅ **Fast DEG detection** via vectorized Welch's t-test, Student's t-test, Wilcoxon
81
83
  - ✅ **Memory-efficient lazy loading** for large datasets
82
84
  - ✅ **Batched evaluation** to avoid OOM errors
83
85
  - ✅ **CPU parallelization** via joblib (multi-core speedup)
@@ -247,6 +249,91 @@ print(f"MMD: {results['mmd'].aggregate_value:.4f}")
247
249
  | `device="mps"` | 5-20x | Apple Silicon Macs |
248
250
  | Vectorized NumPy | 2-5x | Automatic fallback |
249
251
 
252
+ ### DEG-Focused Evaluation
253
+
254
+ GenEval supports **Differentially Expressed Genes (DEG)-focused evaluation**, computing metrics only on biologically relevant DEGs rather than all genes. This provides more meaningful evaluation for perturbation prediction tasks.
255
+
256
+ #### Key Features
257
+
258
+ - **Fast DEG detection**: Vectorized Welch's t-test, Student's t-test, or Wilcoxon rank-sum
259
+ - **Per-context evaluation**: Automatically evaluates each (covariate × perturbation) combination
260
+ - **GPU acceleration**: DEG detection and metrics on GPU for large datasets
261
+ - **Comprehensive reporting**: Aggregated and expanded results with visualizations
262
+
263
+ #### Quick Start
264
+
265
+ ```python
266
+ from geneval import evaluate_degs
267
+ import pandas as pd
268
+
269
+ # Evaluate with DEG-focused metrics
270
+ results = evaluate_degs(
271
+ real_data=real_adata.X, # (n_samples, n_genes)
272
+ generated_data=gen_adata.X,
273
+ real_obs=real_adata.obs,
274
+ generated_obs=gen_adata.obs,
275
+ condition_columns=["cell_type", "perturbation"], # Context columns
276
+ control_key="control", # Value indicating control samples
277
+ perturbation_column="perturbation",
278
+ deg_method="welch", # or "student", "wilcoxon", "logfc"
279
+ pval_threshold=0.05,
280
+ lfc_threshold=0.5, # log2 fold change threshold
281
+ device="cuda", # GPU acceleration
282
+ )
283
+
284
+ # Access results
285
+ print(results.aggregated_metrics) # Summary across all contexts
286
+ print(results.expanded_metrics) # Per-context metrics
287
+ print(results.deg_summary) # DEG counts per context
288
+
289
+ # Save results with plots
290
+ results.save("deg_evaluation/")
291
+ ```
292
+
293
+ #### Per-Context Evaluation
294
+
295
+ When multiple condition columns are provided (e.g., `["cell_type", "perturbation"]`), GenEval evaluates **every combination** separately:
296
+
297
+ | Context | n_DEGs | W1 (DEGs only) | MMD (DEGs only) |
298
+ |---------|--------|----------------|-----------------|
299
+ | TypeA_drug1 | 234 | 0.42 | 0.031 |
300
+ | TypeA_drug2 | 189 | 0.38 | 0.027 |
301
+ | TypeB_drug1 | 312 | 0.51 | 0.045 |
302
+
303
+ If only `perturbation` column is provided, evaluation is done per-perturbation.
304
+
305
+ #### Available DEG Methods
306
+
307
+ | Method | Description | Speed |
308
+ |--------|-------------|-------|
309
+ | `welch` | Welch's t-test (unequal variance) | ⚡ Fast |
310
+ | `student` | Student's t-test (equal variance) | ⚡ Fast |
311
+ | `wilcoxon` | Wilcoxon rank-sum (non-parametric) | 🐢 Slower |
312
+ | `logfc` | Log fold change only (no p-value) | ⚡⚡ Fastest |
313
+
314
+ #### Visualization
315
+
316
+ ```python
317
+ from geneval.deg import (
318
+ plot_deg_distributions,
319
+ plot_context_heatmap,
320
+ plot_deg_counts,
321
+ create_deg_report,
322
+ )
323
+
324
+ # Distribution of metrics across contexts
325
+ plot_deg_distributions(results, save_path="dist.png")
326
+
327
+ # Heatmap: context × metric
328
+ plot_context_heatmap(results, save_path="heatmap.png")
329
+
330
+ # DEG counts per context (up/down regulated)
331
+ plot_deg_counts(results, save_path="deg_counts.png")
332
+
333
+ # Generate comprehensive report
334
+ create_deg_report(results, "report/", include_plots=True)
335
+ ```
336
+
250
337
  ## Expected Data Format
251
338
 
252
339
  GenEval expects AnnData (h5ad) files with:
@@ -38,6 +38,8 @@ All metrics are computed **per-gene** (returning a vector) and **aggregated**:
38
38
  - ✅ Condition-based matching (perturbation, cell type, etc.)
39
39
  - ✅ Train/test split support
40
40
  - ✅ Per-gene and aggregate metrics
41
+ - ✅ **DEG-focused evaluation** with per-context (covariate × perturbation) support
42
+ - ✅ **Fast DEG detection** via vectorized Welch's t-test, Student's t-test, Wilcoxon
41
43
  - ✅ **Memory-efficient lazy loading** for large datasets
42
44
  - ✅ **Batched evaluation** to avoid OOM errors
43
45
  - ✅ **CPU parallelization** via joblib (multi-core speedup)
@@ -207,6 +209,91 @@ print(f"MMD: {results['mmd'].aggregate_value:.4f}")
207
209
  | `device="mps"` | 5-20x | Apple Silicon Macs |
208
210
  | Vectorized NumPy | 2-5x | Automatic fallback |
209
211
 
212
+ ### DEG-Focused Evaluation
213
+
214
+ GenEval supports **Differentially Expressed Genes (DEG)-focused evaluation**, computing metrics only on biologically relevant DEGs rather than all genes. This provides more meaningful evaluation for perturbation prediction tasks.
215
+
216
+ #### Key Features
217
+
218
+ - **Fast DEG detection**: Vectorized Welch's t-test, Student's t-test, or Wilcoxon rank-sum
219
+ - **Per-context evaluation**: Automatically evaluates each (covariate × perturbation) combination
220
+ - **GPU acceleration**: DEG detection and metrics on GPU for large datasets
221
+ - **Comprehensive reporting**: Aggregated and expanded results with visualizations
222
+
223
+ #### Quick Start
224
+
225
+ ```python
226
+ from geneval import evaluate_degs
227
+ import pandas as pd
228
+
229
+ # Evaluate with DEG-focused metrics
230
+ results = evaluate_degs(
231
+ real_data=real_adata.X, # (n_samples, n_genes)
232
+ generated_data=gen_adata.X,
233
+ real_obs=real_adata.obs,
234
+ generated_obs=gen_adata.obs,
235
+ condition_columns=["cell_type", "perturbation"], # Context columns
236
+ control_key="control", # Value indicating control samples
237
+ perturbation_column="perturbation",
238
+ deg_method="welch", # or "student", "wilcoxon", "logfc"
239
+ pval_threshold=0.05,
240
+ lfc_threshold=0.5, # log2 fold change threshold
241
+ device="cuda", # GPU acceleration
242
+ )
243
+
244
+ # Access results
245
+ print(results.aggregated_metrics) # Summary across all contexts
246
+ print(results.expanded_metrics) # Per-context metrics
247
+ print(results.deg_summary) # DEG counts per context
248
+
249
+ # Save results with plots
250
+ results.save("deg_evaluation/")
251
+ ```
252
+
253
+ #### Per-Context Evaluation
254
+
255
+ When multiple condition columns are provided (e.g., `["cell_type", "perturbation"]`), GenEval evaluates **every combination** separately:
256
+
257
+ | Context | n_DEGs | W1 (DEGs only) | MMD (DEGs only) |
258
+ |---------|--------|----------------|-----------------|
259
+ | TypeA_drug1 | 234 | 0.42 | 0.031 |
260
+ | TypeA_drug2 | 189 | 0.38 | 0.027 |
261
+ | TypeB_drug1 | 312 | 0.51 | 0.045 |
262
+
263
+ If only `perturbation` column is provided, evaluation is done per-perturbation.
264
+
265
+ #### Available DEG Methods
266
+
267
+ | Method | Description | Speed |
268
+ |--------|-------------|-------|
269
+ | `welch` | Welch's t-test (unequal variance) | ⚡ Fast |
270
+ | `student` | Student's t-test (equal variance) | ⚡ Fast |
271
+ | `wilcoxon` | Wilcoxon rank-sum (non-parametric) | 🐢 Slower |
272
+ | `logfc` | Log fold change only (no p-value) | ⚡⚡ Fastest |
273
+
274
+ #### Visualization
275
+
276
+ ```python
277
+ from geneval.deg import (
278
+ plot_deg_distributions,
279
+ plot_context_heatmap,
280
+ plot_deg_counts,
281
+ create_deg_report,
282
+ )
283
+
284
+ # Distribution of metrics across contexts
285
+ plot_deg_distributions(results, save_path="dist.png")
286
+
287
+ # Heatmap: context × metric
288
+ plot_context_heatmap(results, save_path="heatmap.png")
289
+
290
+ # DEG counts per context (up/down regulated)
291
+ plot_deg_counts(results, save_path="deg_counts.png")
292
+
293
+ # Generate comprehensive report
294
+ create_deg_report(results, "report/", include_plots=True)
295
+ ```
296
+
210
297
  ## Expected Data Format
211
298
 
212
299
  GenEval expects AnnData (h5ad) files with:
@@ -1,13 +1,13 @@
1
1
  [tool.poetry]
2
2
  name = "gengeneeval"
3
- version = "0.3.0"
4
- description = "Comprehensive evaluation of generated gene expression data. Computes metrics between real and generated datasets with support for condition matching, train/test splits, memory-efficient lazy loading, CPU parallelization, GPU acceleration, and publication-quality visualizations."
3
+ version = "0.4.0"
4
+ description = "Comprehensive evaluation of generated gene expression data. Computes metrics between real and generated datasets with support for condition matching, DEG-focused evaluation, per-context analysis, train/test splits, memory-efficient lazy loading, CPU parallelization, GPU acceleration, and publication-quality visualizations."
5
5
  authors = ["GenEval Team <geneval@example.com>"]
6
6
  license = "MIT"
7
7
  readme = "README.md"
8
8
  homepage = "https://github.com/AndreaRubbi/GenGeneEval"
9
9
  repository = "https://github.com/AndreaRubbi/GenGeneEval"
10
- keywords = ["gene expression", "evaluation", "metrics", "single-cell", "generative models", "benchmarking", "memory-efficient"]
10
+ keywords = ["gene expression", "evaluation", "metrics", "single-cell", "generative models", "benchmarking", "memory-efficient", "DEG", "perturbation"]
11
11
  classifiers = [
12
12
  "Development Status :: 4 - Beta",
13
13
  "Intended Audience :: Science/Research",
@@ -7,8 +7,10 @@ and generated gene expression datasets stored in AnnData (h5ad) format.
7
7
  Features:
8
8
  - Multiple distance and correlation metrics (per-gene and aggregate)
9
9
  - Condition-based matching (perturbation, cell type, etc.)
10
+ - DEG-focused evaluation with per-context (covariate × perturbation) support
10
11
  - Train/test split support
11
12
  - Memory-efficient lazy loading for large datasets
13
+ - CPU parallelization and GPU acceleration
12
14
  - Publication-quality visualizations
13
15
  - Command-line interface
14
16
 
@@ -21,6 +23,17 @@ Quick Start:
21
23
  ... output_dir="output/"
22
24
  ... )
23
25
 
26
+ DEG-Focused Evaluation:
27
+ >>> from geneval import evaluate_degs
28
+ >>> results = evaluate_degs(
29
+ ... real_data, generated_data,
30
+ ... real_obs, generated_obs,
31
+ ... condition_columns=["cell_type", "perturbation"],
32
+ ... control_key="control",
33
+ ... deg_method="welch",
34
+ ... device="cuda", # GPU acceleration
35
+ ... )
36
+
24
37
  Memory-Efficient Mode (for large datasets):
25
38
  >>> from geneval import evaluate_lazy
26
39
  >>> results = evaluate_lazy(
@@ -36,7 +49,7 @@ CLI Usage:
36
49
  --conditions perturbation cell_type --output results/
37
50
  """
38
51
 
39
- __version__ = "0.3.0"
52
+ __version__ = "0.4.0"
40
53
  __author__ = "GenEval Team"
41
54
 
42
55
  # Main evaluation interface
@@ -109,6 +122,22 @@ from .metrics.accelerated import (
109
122
  compute_metrics_accelerated,
110
123
  )
111
124
 
125
+ # DEG-focused evaluation
126
+ from .deg import (
127
+ DEGEvaluator,
128
+ DEGResult,
129
+ DEGEvaluationResult,
130
+ ContextEvaluator,
131
+ ContextResult,
132
+ compute_degs_fast,
133
+ compute_degs_gpu,
134
+ get_contexts,
135
+ plot_deg_distributions,
136
+ plot_context_heatmap,
137
+ create_deg_report,
138
+ )
139
+ from .deg.evaluator import evaluate_degs
140
+
112
141
  # Visualization
113
142
  from .visualization.visualizer import (
114
143
  EvaluationVisualizer,
@@ -174,6 +203,19 @@ __all__ = [
174
203
  "ParallelMetricComputer",
175
204
  "get_available_backends",
176
205
  "compute_metrics_accelerated",
206
+ # DEG evaluation
207
+ "DEGEvaluator",
208
+ "DEGResult",
209
+ "DEGEvaluationResult",
210
+ "ContextEvaluator",
211
+ "ContextResult",
212
+ "compute_degs_fast",
213
+ "compute_degs_gpu",
214
+ "evaluate_degs",
215
+ "get_contexts",
216
+ "plot_deg_distributions",
217
+ "plot_context_heatmap",
218
+ "create_deg_report",
177
219
  # Visualization
178
220
  "EvaluationVisualizer",
179
221
  "visualize",
@@ -0,0 +1,65 @@
1
+ """
2
+ Differentially Expressed Genes (DEG) module for GenGeneEval.
3
+
4
+ This module provides:
5
+ - Fast DEG detection using vectorized statistical tests
6
+ - Per-context evaluation (covariates × perturbations)
7
+ - DEG-focused metrics computation
8
+ - Integration with GPU acceleration
9
+
10
+ Example usage:
11
+ >>> from geneval.deg import DEGEvaluator, compute_degs_fast
12
+ >>> degs = compute_degs_fast(control_data, perturbed_data, method="welch")
13
+ >>> evaluator = DEGEvaluator(loader, deg_method="welch", pval_threshold=0.05)
14
+ >>> results = evaluator.evaluate()
15
+ """
16
+
17
+ from .detection import (
18
+ compute_degs_fast,
19
+ compute_degs_gpu,
20
+ compute_degs_auto,
21
+ DEGResult,
22
+ DEGMethod,
23
+ )
24
+ from .context import (
25
+ ContextEvaluator,
26
+ ContextResult,
27
+ get_contexts,
28
+ get_context_id,
29
+ filter_by_context,
30
+ )
31
+ from .evaluator import (
32
+ DEGEvaluator,
33
+ DEGEvaluationResult,
34
+ evaluate_degs,
35
+ )
36
+ from .visualization import (
37
+ plot_deg_distributions,
38
+ plot_context_heatmap,
39
+ plot_deg_counts,
40
+ create_deg_report,
41
+ )
42
+
43
+ __all__ = [
44
+ # Detection
45
+ "compute_degs_fast",
46
+ "compute_degs_gpu",
47
+ "compute_degs_auto",
48
+ "DEGResult",
49
+ "DEGMethod",
50
+ # Context
51
+ "ContextEvaluator",
52
+ "ContextResult",
53
+ "get_contexts",
54
+ "get_context_id",
55
+ "filter_by_context",
56
+ # Evaluator
57
+ "DEGEvaluator",
58
+ "DEGEvaluationResult",
59
+ "evaluate_degs",
60
+ # Visualization
61
+ "plot_deg_distributions",
62
+ "plot_context_heatmap",
63
+ "plot_deg_counts",
64
+ "create_deg_report",
65
+ ]
@@ -0,0 +1,271 @@
1
+ """
2
+ Context-aware evaluation for gene expression data.
3
+
4
+ Supports per-context evaluation where context = covariates × perturbation.
5
+ If only perturbation column is given, evaluates per-perturbation.
6
+ If multiple condition columns are given, evaluates every combination.
7
+ """
8
+ from __future__ import annotations
9
+
10
+ from typing import Optional, List, Dict, Tuple, Union, Iterator
11
+ from dataclasses import dataclass, field
12
+ import numpy as np
13
+ import pandas as pd
14
+ from itertools import product
15
+
16
+
17
@dataclass
class ContextResult:
    """Evaluation output for a single context.

    A context is one combination of condition-column values, e.g. a
    specific cell type crossed with a specific perturbation.

    Attributes
    ----------
    context_id : str
        Unique identifier for this context
    context_values : Dict[str, str]
        Mapping of condition column -> value defining the context
    n_samples_real : int
        Number of real samples in this context
    n_samples_gen : int
        Number of generated samples in this context
    deg_result : DEGResult, optional
        DEG detection result for this context, if computed
    metrics : Dict[str, float]
        Scalar metrics computed for this context
    per_gene_metrics : Dict[str, np.ndarray]
        Per-gene metric vectors computed for this context
    """
    context_id: str
    context_values: Dict[str, str]
    n_samples_real: int
    n_samples_gen: int
    deg_result: Optional["DEGResult"] = None  # forward ref to detection.DEGResult
    metrics: Dict[str, float] = field(default_factory=dict)
    per_gene_metrics: Dict[str, np.ndarray] = field(default_factory=dict)

    def __repr__(self) -> str:
        # Show the DEG count when detection has run, otherwise a placeholder.
        deg_count = self.deg_result.n_degs if self.deg_result else "N/A"
        return (
            f"ContextResult(id='{self.context_id}', "
            f"n_real={self.n_samples_real}, n_gen={self.n_samples_gen}, "
            f"n_degs={deg_count})"
        )
50
+
51
+
52
def get_contexts(
    obs: pd.DataFrame,
    condition_columns: List[str],
    min_samples: int = 2,
) -> List[Dict[str, str]]:
    """
    Get all unique contexts (combinations of condition values).

    Combinations are enumerated in cartesian-product order of each
    column's unique values (order of first appearance); only those with
    at least ``min_samples`` matching rows are returned.

    Parameters
    ----------
    obs : pd.DataFrame
        Observation metadata (adata.obs)
    condition_columns : List[str]
        Columns to use for context definition
    min_samples : int
        Minimum samples required per context

    Returns
    -------
    List[Dict[str, str]]
        List of context dictionaries

    Raises
    ------
    ValueError
        If any of ``condition_columns`` is missing from ``obs``.
    """
    if len(condition_columns) == 0:
        # No conditioning: a single, all-inclusive context.
        return [{}]

    # Precompute one boolean mask per (column, value) pair. The original
    # approach re-ran pandas equality comparisons for every candidate
    # combination (O(|product| * n) pandas ops); here each comparison is
    # done once and combinations only AND cached numpy masks.
    unique_values = []
    column_masks = []
    for col in condition_columns:
        if col not in obs.columns:
            raise ValueError(f"Column '{col}' not found in obs")
        vals = obs[col].unique().tolist()
        unique_values.append(vals)
        column_masks.append([(obs[col] == v).values for v in vals])

    # Generate all combinations (by index, so values never need hashing)
    # and keep the sufficiently populated ones.
    contexts = []
    index_ranges = [range(len(vals)) for vals in unique_values]
    for idx in product(*index_ranges):
        mask = np.ones(len(obs), dtype=bool)
        for col_mask_list, i in zip(column_masks, idx):
            mask &= col_mask_list[i]

        if mask.sum() >= min_samples:
            combo = [unique_values[j][i] for j, i in enumerate(idx)]
            contexts.append(dict(zip(condition_columns, combo)))

    return contexts
99
+
100
+
101
def get_context_id(context: Dict[str, str]) -> str:
    """Build a deterministic identifier string for a context.

    Keys are sorted so equal mappings always yield the same ID; the
    empty context maps to the sentinel "all".
    """
    if len(context) == 0:
        return "all"
    parts = [f"{key}={value}" for key, value in sorted(context.items())]
    return "_".join(parts)
106
+
107
+
108
def filter_by_context(
    data: np.ndarray,
    obs: pd.DataFrame,
    context: Dict[str, str],
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Select the rows of ``data`` whose metadata matches ``context``.

    Parameters
    ----------
    data : np.ndarray
        Expression matrix (n_samples, n_genes)
    obs : pd.DataFrame
        Observation metadata aligned row-wise with ``data``
    context : Dict[str, str]
        Column -> value pairs a row must match simultaneously.
        An empty context selects everything.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        The matching rows and the boolean row mask that was applied
    """
    keep = np.ones(len(obs), dtype=bool)
    if not context:
        # Nothing to filter on: hand back the input untouched.
        return data, keep

    for column, expected in context.items():
        keep &= (obs[column] == expected).to_numpy()

    return data[keep], keep
138
+
139
+
140
class ContextEvaluator:
    """
    Evaluator that computes metrics per context.

    A context is the combination of all condition column values. With
    condition_columns = ["cell_type", "perturbation"], every unique
    (cell_type, perturbation) pair is its own context. Only contexts
    present in both the real and generated data are evaluated.

    Parameters
    ----------
    real_data : np.ndarray
        Real expression matrix (n_samples, n_genes)
    generated_data : np.ndarray
        Generated expression matrix (n_samples, n_genes)
    real_obs : pd.DataFrame
        Real data metadata
    generated_obs : pd.DataFrame
        Generated data metadata
    condition_columns : List[str]
        Columns defining contexts
    gene_names : np.ndarray, optional
        Gene names
    control_key : str, optional
        Value in perturbation column indicating control (for DEG computation)
    perturbation_column : str, optional
        Name of perturbation column (for DEG computation)
    """

    def __init__(
        self,
        real_data: np.ndarray,
        generated_data: np.ndarray,
        real_obs: pd.DataFrame,
        generated_obs: pd.DataFrame,
        condition_columns: List[str],
        gene_names: Optional[np.ndarray] = None,
        control_key: str = "control",
        perturbation_column: Optional[str] = None,
    ):
        self.real_data = real_data
        self.generated_data = generated_data
        self.real_obs = real_obs
        self.generated_obs = generated_obs
        self.condition_columns = condition_columns
        self.gene_names = gene_names
        self.control_key = control_key

        # Perturbation column: an explicit argument wins; otherwise fall
        # back to the first condition column (None when there are none).
        if perturbation_column is not None:
            self.perturbation_column = perturbation_column
        else:
            self.perturbation_column = condition_columns[0] if condition_columns else None

        # Enumerate contexts on each side, then keep only the shared ones.
        self._real_contexts = get_contexts(real_obs, condition_columns)
        self._gen_contexts = get_contexts(generated_obs, condition_columns)

        shared_ids = (
            {get_context_id(c) for c in self._real_contexts}
            & {get_context_id(c) for c in self._gen_contexts}
        )
        self.contexts = [
            ctx for ctx in self._real_contexts if get_context_id(ctx) in shared_ids
        ]

        # Control subsets are reused across contexts; cache by context id.
        self._control_cache: Dict[str, Tuple[np.ndarray, np.ndarray]] = {}

    def get_context_data(
        self,
        context: Dict[str, str],
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Return the (real, generated) expression rows belonging to ``context``."""
        real_subset, _ = filter_by_context(self.real_data, self.real_obs, context)
        gen_subset, _ = filter_by_context(self.generated_data, self.generated_obs, context)
        return real_subset, gen_subset

    def get_control_data(
        self,
        context: Dict[str, str],
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Get control data for DEG computation.

        Derives the matching control context by swapping the perturbation
        value for ``control_key``, then returns (real, generated) control
        rows, memoized per control-context id.
        """
        if self.perturbation_column is None:
            raise ValueError("perturbation_column required for DEG computation")

        # Same covariates, but with the perturbation replaced by control.
        control_context = {**context, self.perturbation_column: self.control_key}
        cache_key = get_context_id(control_context)

        if cache_key not in self._control_cache:
            real_ctrl, _ = filter_by_context(
                self.real_data, self.real_obs, control_context
            )
            gen_ctrl, _ = filter_by_context(
                self.generated_data, self.generated_obs, control_context
            )
            self._control_cache[cache_key] = (real_ctrl, gen_ctrl)

        return self._control_cache[cache_key]

    def iter_contexts(self) -> Iterator[Tuple[str, Dict[str, str], np.ndarray, np.ndarray]]:
        """Yield (context_id, context, real_rows, generated_rows) for each context."""
        for ctx in self.contexts:
            real_rows, gen_rows = self.get_context_data(ctx)
            yield get_context_id(ctx), ctx, real_rows, gen_rows

    def is_control_context(self, context: Dict[str, str]) -> bool:
        """True when this context's perturbation value equals the control key."""
        return (
            self.perturbation_column is not None
            and context.get(self.perturbation_column) == self.control_key
        )

    def get_perturbation_contexts(self) -> List[Dict[str, str]]:
        """Get only perturbation contexts (controls excluded)."""
        return [ctx for ctx in self.contexts if not self.is_control_context(ctx)]

    def __len__(self) -> int:
        return len(self.contexts)

    def __repr__(self) -> str:
        return (
            f"ContextEvaluator(n_contexts={len(self.contexts)}, "
            f"condition_columns={self.condition_columns})"
        )