gengeneeval 0.2.1__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geneval/__init__.py CHANGED
@@ -7,8 +7,10 @@ and generated gene expression datasets stored in AnnData (h5ad) format.
7
7
  Features:
8
8
  - Multiple distance and correlation metrics (per-gene and aggregate)
9
9
  - Condition-based matching (perturbation, cell type, etc.)
10
+ - DEG-focused evaluation with per-context (covariate × perturbation) support
10
11
  - Train/test split support
11
12
  - Memory-efficient lazy loading for large datasets
13
+ - CPU parallelization and GPU acceleration
12
14
  - Publication-quality visualizations
13
15
  - Command-line interface
14
16
 
@@ -21,6 +23,17 @@ Quick Start:
21
23
  ... output_dir="output/"
22
24
  ... )
23
25
 
26
+ DEG-Focused Evaluation:
27
+ >>> from geneval import evaluate_degs
28
+ >>> results = evaluate_degs(
29
+ ... real_data, generated_data,
30
+ ... real_obs, generated_obs,
31
+ ... condition_columns=["cell_type", "perturbation"],
32
+ ... control_key="control",
33
+ ... deg_method="welch",
34
+ ... device="cuda", # GPU acceleration
35
+ ... )
36
+
24
37
  Memory-Efficient Mode (for large datasets):
25
38
  >>> from geneval import evaluate_lazy
26
39
  >>> results = evaluate_lazy(
@@ -36,7 +49,7 @@ CLI Usage:
36
49
  --conditions perturbation cell_type --output results/
37
50
  """
38
51
 
39
- __version__ = "0.2.1"
52
+ __version__ = "0.4.0"
40
53
  __author__ = "GenEval Team"
41
54
 
42
55
  # Main evaluation interface
@@ -101,6 +114,30 @@ from .metrics.reconstruction import (
101
114
  R2Score,
102
115
  )
103
116
 
117
+ # Accelerated computation
118
+ from .metrics.accelerated import (
119
+ AccelerationConfig,
120
+ ParallelMetricComputer,
121
+ get_available_backends,
122
+ compute_metrics_accelerated,
123
+ )
124
+
125
+ # DEG-focused evaluation
126
+ from .deg import (
127
+ DEGEvaluator,
128
+ DEGResult,
129
+ DEGEvaluationResult,
130
+ ContextEvaluator,
131
+ ContextResult,
132
+ compute_degs_fast,
133
+ compute_degs_gpu,
134
+ get_contexts,
135
+ plot_deg_distributions,
136
+ plot_context_heatmap,
137
+ create_deg_report,
138
+ )
139
+ from .deg.evaluator import evaluate_degs
140
+
104
141
  # Visualization
105
142
  from .visualization.visualizer import (
106
143
  EvaluationVisualizer,
@@ -161,6 +198,24 @@ __all__ = [
161
198
  "RMSEDistance",
162
199
  "MAEDistance",
163
200
  "R2Score",
201
+ # Acceleration
202
+ "AccelerationConfig",
203
+ "ParallelMetricComputer",
204
+ "get_available_backends",
205
+ "compute_metrics_accelerated",
206
+ # DEG evaluation
207
+ "DEGEvaluator",
208
+ "DEGResult",
209
+ "DEGEvaluationResult",
210
+ "ContextEvaluator",
211
+ "ContextResult",
212
+ "compute_degs_fast",
213
+ "compute_degs_gpu",
214
+ "evaluate_degs",
215
+ "get_contexts",
216
+ "plot_deg_distributions",
217
+ "plot_context_heatmap",
218
+ "create_deg_report",
164
219
  # Visualization
165
220
  "EvaluationVisualizer",
166
221
  "visualize",
@@ -0,0 +1,65 @@
1
+ """
2
+ Differentially Expressed Genes (DEG) module for GenGeneEval.
3
+
4
+ This module provides:
5
+ - Fast DEG detection using vectorized statistical tests
6
+ - Per-context evaluation (covariates × perturbations)
7
+ - DEG-focused metrics computation
8
+ - Integration with GPU acceleration
9
+
10
+ Example usage:
11
+ >>> from geneval.deg import DEGEvaluator, compute_degs_fast
12
+ >>> degs = compute_degs_fast(control_data, perturbed_data, method="welch")
13
+ >>> evaluator = DEGEvaluator(loader, deg_method="welch", pval_threshold=0.05)
14
+ >>> results = evaluator.evaluate()
15
+ """
16
+
17
+ from .detection import (
18
+ compute_degs_fast,
19
+ compute_degs_gpu,
20
+ compute_degs_auto,
21
+ DEGResult,
22
+ DEGMethod,
23
+ )
24
+ from .context import (
25
+ ContextEvaluator,
26
+ ContextResult,
27
+ get_contexts,
28
+ get_context_id,
29
+ filter_by_context,
30
+ )
31
+ from .evaluator import (
32
+ DEGEvaluator,
33
+ DEGEvaluationResult,
34
+ evaluate_degs,
35
+ )
36
+ from .visualization import (
37
+ plot_deg_distributions,
38
+ plot_context_heatmap,
39
+ plot_deg_counts,
40
+ create_deg_report,
41
+ )
42
+
43
+ __all__ = [
44
+ # Detection
45
+ "compute_degs_fast",
46
+ "compute_degs_gpu",
47
+ "compute_degs_auto",
48
+ "DEGResult",
49
+ "DEGMethod",
50
+ # Context
51
+ "ContextEvaluator",
52
+ "ContextResult",
53
+ "get_contexts",
54
+ "get_context_id",
55
+ "filter_by_context",
56
+ # Evaluator
57
+ "DEGEvaluator",
58
+ "DEGEvaluationResult",
59
+ "evaluate_degs",
60
+ # Visualization
61
+ "plot_deg_distributions",
62
+ "plot_context_heatmap",
63
+ "plot_deg_counts",
64
+ "create_deg_report",
65
+ ]
geneval/deg/context.py ADDED
@@ -0,0 +1,271 @@
1
+ """
2
+ Context-aware evaluation for gene expression data.
3
+
4
+ Supports per-context evaluation where context = covariates × perturbation.
5
+ If only perturbation column is given, evaluates per-perturbation.
6
+ If multiple condition columns are given, evaluates every combination.
7
+ """
8
+ from __future__ import annotations
9
+
10
+ from typing import Optional, List, Dict, Tuple, Union, Iterator
11
+ from dataclasses import dataclass, field
12
+ import numpy as np
13
+ import pandas as pd
14
+ from itertools import product
15
+
16
+
17
@dataclass
class ContextResult:
    """Evaluation output for one context, i.e. a single combination of
    covariate and perturbation values.

    Attributes
    ----------
    context_id : str
        Unique identifier for this context
    context_values : Dict[str, str]
        Value of each condition column in this context
    n_samples_real : int
        Number of real samples falling in this context
    n_samples_gen : int
        Number of generated samples falling in this context
    deg_result : Any, optional
        DEG detection result for this context (filled in by the evaluator)
    metrics : Dict[str, float]
        Aggregate metrics computed for this context
    per_gene_metrics : Dict[str, np.ndarray]
        Per-gene metric arrays computed for this context
    """
    context_id: str
    context_values: Dict[str, str]
    n_samples_real: int
    n_samples_gen: int
    deg_result: Optional["DEGResult"] = None  # Forward reference
    metrics: Dict[str, float] = field(default_factory=dict)
    per_gene_metrics: Dict[str, np.ndarray] = field(default_factory=dict)

    def __repr__(self) -> str:
        # "N/A" until a DEG result has been attached.
        n_degs = self.deg_result.n_degs if self.deg_result else "N/A"
        return (
            f"ContextResult(id='{self.context_id}', "
            f"n_real={self.n_samples_real}, n_gen={self.n_samples_gen}, "
            f"n_degs={n_degs})"
        )
50
+
51
+
52
def get_contexts(
    obs: pd.DataFrame,
    condition_columns: List[str],
    min_samples: int = 2,
) -> List[Dict[str, str]]:
    """
    Enumerate every sufficiently-populated combination of condition values.

    Parameters
    ----------
    obs : pd.DataFrame
        Observation metadata (adata.obs)
    condition_columns : List[str]
        Columns whose value combinations define contexts
    min_samples : int
        Minimum number of rows a combination must have to be retained

    Returns
    -------
    List[Dict[str, str]]
        One dict per retained context, mapping column name to value

    Raises
    ------
    ValueError
        If a requested column is absent from ``obs``.
    """
    if not condition_columns:
        # No condition columns: a single all-encompassing context.
        return [{}]

    # Collect the unique values per column, failing fast on the first
    # column that is missing from the metadata.
    value_lists = []
    for column in condition_columns:
        if column not in obs.columns:
            raise ValueError(f"Column '{column}' not found in obs")
        value_lists.append(obs[column].unique().tolist())

    # Walk the full cross-product and keep combinations with enough rows.
    kept: List[Dict[str, str]] = []
    for values in product(*value_lists):
        candidate = dict(zip(condition_columns, values))

        selected = np.ones(len(obs), dtype=bool)
        for column, value in candidate.items():
            selected &= (obs[column] == value).values

        if selected.sum() >= min_samples:
            kept.append(candidate)

    return kept
99
+
100
+
101
def get_context_id(context: Dict[str, str]) -> str:
    """Return a stable, unique string identifier for *context*.

    Keys are sorted so equal contexts always yield the same id; the
    empty context maps to "all".
    """
    if context:
        parts = [f"{key}={value}" for key, value in sorted(context.items())]
        return "_".join(parts)
    return "all"
106
+
107
+
108
def filter_by_context(
    data: np.ndarray,
    obs: pd.DataFrame,
    context: Dict[str, str],
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Select the rows of *data* whose metadata matches *context*.

    Parameters
    ----------
    data : np.ndarray
        Expression matrix (n_samples, n_genes)
    obs : pd.DataFrame
        Observation metadata aligned row-for-row with ``data``
    context : Dict[str, str]
        Column -> value pairs that a row must match to be kept

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        The matching rows of ``data`` and the boolean row mask used
    """
    keep = np.ones(len(obs), dtype=bool)

    # An empty context selects everything; hand back data untouched.
    if not context:
        return data, keep

    for column, value in context.items():
        keep &= (obs[column] == value).values

    return data[keep], keep
138
+
139
+
140
class ContextEvaluator:
    """
    Computes per-context evaluation groupings.

    A context is one combination of values across the condition columns;
    e.g. with condition_columns = ["cell_type", "perturbation"], every
    distinct (cell_type, perturbation) pair forms its own context. Only
    contexts present in both the real and generated data are kept.

    Parameters
    ----------
    real_data : np.ndarray
        Real expression matrix (n_samples, n_genes)
    generated_data : np.ndarray
        Generated expression matrix (n_samples, n_genes)
    real_obs : pd.DataFrame
        Metadata for the real samples
    generated_obs : pd.DataFrame
        Metadata for the generated samples
    condition_columns : List[str]
        Columns whose value combinations define contexts
    gene_names : np.ndarray, optional
        Gene names
    control_key : str, optional
        Value in the perturbation column marking control samples
        (used for DEG computation)
    perturbation_column : str, optional
        Perturbation column name; defaults to the first condition column
    """

    def __init__(
        self,
        real_data: np.ndarray,
        generated_data: np.ndarray,
        real_obs: pd.DataFrame,
        generated_obs: pd.DataFrame,
        condition_columns: List[str],
        gene_names: Optional[np.ndarray] = None,
        control_key: str = "control",
        perturbation_column: Optional[str] = None,
    ):
        self.real_data = real_data
        self.generated_data = generated_data
        self.real_obs = real_obs
        self.generated_obs = generated_obs
        self.condition_columns = condition_columns
        self.gene_names = gene_names
        self.control_key = control_key

        # Default the perturbation column to the first condition column.
        if perturbation_column is None:
            perturbation_column = condition_columns[0] if condition_columns else None
        self.perturbation_column = perturbation_column

        # Enumerate contexts observed in each dataset separately, then
        # retain only those present in both (matched by canonical id).
        self._real_contexts = get_contexts(real_obs, condition_columns)
        self._gen_contexts = get_contexts(generated_obs, condition_columns)

        real_ids = {get_context_id(ctx) for ctx in self._real_contexts}
        gen_ids = {get_context_id(ctx) for ctx in self._gen_contexts}
        shared_ids = real_ids & gen_ids
        self.contexts = [
            ctx for ctx in self._real_contexts if get_context_id(ctx) in shared_ids
        ]

        # Control slices are shared across perturbations, so cache them.
        self._control_cache: Dict[str, Tuple[np.ndarray, np.ndarray]] = {}

    def get_context_data(
        self,
        context: Dict[str, str],
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Return the (real, generated) expression rows matching *context*."""
        real_subset, _ = filter_by_context(self.real_data, self.real_obs, context)
        gen_subset, _ = filter_by_context(
            self.generated_data, self.generated_obs, context
        )
        return real_subset, gen_subset

    def get_control_data(
        self,
        context: Dict[str, str],
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Return the (real, generated) control rows matched to *context*.

        The control context is derived by swapping the perturbation value
        for ``control_key``; results are cached per control context.
        """
        if self.perturbation_column is None:
            raise ValueError("perturbation_column required for DEG computation")

        # Same covariates, but with the perturbation replaced by control.
        control_context = dict(context)
        control_context[self.perturbation_column] = self.control_key
        cache_key = get_context_id(control_context)

        cached = self._control_cache.get(cache_key)
        if cached is None:
            real_control, _ = filter_by_context(
                self.real_data, self.real_obs, control_context
            )
            gen_control, _ = filter_by_context(
                self.generated_data, self.generated_obs, control_context
            )
            cached = (real_control, gen_control)
            self._control_cache[cache_key] = cached
        return cached

    def iter_contexts(self) -> Iterator[Tuple[str, Dict[str, str], np.ndarray, np.ndarray]]:
        """Yield (context_id, context, real_data, gen_data) per context."""
        for ctx in self.contexts:
            real_subset, gen_subset = self.get_context_data(ctx)
            yield get_context_id(ctx), ctx, real_subset, gen_subset

    def is_control_context(self, context: Dict[str, str]) -> bool:
        """True when *context* corresponds to unperturbed (control) samples."""
        column = self.perturbation_column
        return column is not None and context.get(column) == self.control_key

    def get_perturbation_contexts(self) -> List[Dict[str, str]]:
        """Return only the non-control (perturbed) contexts."""
        return [ctx for ctx in self.contexts if not self.is_control_context(ctx)]

    def __len__(self) -> int:
        return len(self.contexts)

    def __repr__(self) -> str:
        return (
            f"ContextEvaluator(n_contexts={len(self.contexts)}, "
            f"condition_columns={self.condition_columns})"
        )