gengeneeval 0.3.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geneval/__init__.py +43 -1
- geneval/deg/__init__.py +69 -0
- geneval/deg/context.py +271 -0
- geneval/deg/detection.py +578 -0
- geneval/deg/evaluator.py +821 -0
- geneval/deg/visualization.py +376 -0
- {gengeneeval-0.3.0.dist-info → gengeneeval-0.4.1.dist-info}/METADATA +125 -3
- {gengeneeval-0.3.0.dist-info → gengeneeval-0.4.1.dist-info}/RECORD +11 -6
- {gengeneeval-0.3.0.dist-info → gengeneeval-0.4.1.dist-info}/WHEEL +0 -0
- {gengeneeval-0.3.0.dist-info → gengeneeval-0.4.1.dist-info}/entry_points.txt +0 -0
- {gengeneeval-0.3.0.dist-info → gengeneeval-0.4.1.dist-info}/licenses/LICENSE +0 -0
geneval/__init__.py
CHANGED
|
@@ -7,8 +7,10 @@ and generated gene expression datasets stored in AnnData (h5ad) format.
|
|
|
7
7
|
Features:
|
|
8
8
|
- Multiple distance and correlation metrics (per-gene and aggregate)
|
|
9
9
|
- Condition-based matching (perturbation, cell type, etc.)
|
|
10
|
+
- DEG-focused evaluation with per-context (covariate × perturbation) support
|
|
10
11
|
- Train/test split support
|
|
11
12
|
- Memory-efficient lazy loading for large datasets
|
|
13
|
+
- CPU parallelization and GPU acceleration
|
|
12
14
|
- Publication-quality visualizations
|
|
13
15
|
- Command-line interface
|
|
14
16
|
|
|
@@ -21,6 +23,17 @@ Quick Start:
|
|
|
21
23
|
... output_dir="output/"
|
|
22
24
|
... )
|
|
23
25
|
|
|
26
|
+
DEG-Focused Evaluation:
|
|
27
|
+
>>> from geneval import evaluate_degs
|
|
28
|
+
>>> results = evaluate_degs(
|
|
29
|
+
... real_data, generated_data,
|
|
30
|
+
... real_obs, generated_obs,
|
|
31
|
+
... condition_columns=["cell_type", "perturbation"],
|
|
32
|
+
... control_key="control",
|
|
33
|
+
... deg_method="welch",
|
|
34
|
+
... device="cuda", # GPU acceleration
|
|
35
|
+
... )
|
|
36
|
+
|
|
24
37
|
Memory-Efficient Mode (for large datasets):
|
|
25
38
|
>>> from geneval import evaluate_lazy
|
|
26
39
|
>>> results = evaluate_lazy(
|
|
@@ -36,7 +49,7 @@ CLI Usage:
|
|
|
36
49
|
--conditions perturbation cell_type --output results/
|
|
37
50
|
"""
|
|
38
51
|
|
|
39
|
-
__version__ = "0.3.0"
|
|
52
|
+
__version__ = "0.4.1"
|
|
40
53
|
__author__ = "GenEval Team"
|
|
41
54
|
|
|
42
55
|
# Main evaluation interface
|
|
@@ -109,6 +122,22 @@ from .metrics.accelerated import (
|
|
|
109
122
|
compute_metrics_accelerated,
|
|
110
123
|
)
|
|
111
124
|
|
|
125
|
+
# DEG-focused evaluation
|
|
126
|
+
from .deg import (
|
|
127
|
+
DEGEvaluator,
|
|
128
|
+
DEGResult,
|
|
129
|
+
DEGEvaluationResult,
|
|
130
|
+
ContextEvaluator,
|
|
131
|
+
ContextResult,
|
|
132
|
+
compute_degs_fast,
|
|
133
|
+
compute_degs_gpu,
|
|
134
|
+
get_contexts,
|
|
135
|
+
plot_deg_distributions,
|
|
136
|
+
plot_context_heatmap,
|
|
137
|
+
create_deg_report,
|
|
138
|
+
)
|
|
139
|
+
from .deg.evaluator import evaluate_degs
|
|
140
|
+
|
|
112
141
|
# Visualization
|
|
113
142
|
from .visualization.visualizer import (
|
|
114
143
|
EvaluationVisualizer,
|
|
@@ -174,6 +203,19 @@ __all__ = [
|
|
|
174
203
|
"ParallelMetricComputer",
|
|
175
204
|
"get_available_backends",
|
|
176
205
|
"compute_metrics_accelerated",
|
|
206
|
+
# DEG evaluation
|
|
207
|
+
"DEGEvaluator",
|
|
208
|
+
"DEGResult",
|
|
209
|
+
"DEGEvaluationResult",
|
|
210
|
+
"ContextEvaluator",
|
|
211
|
+
"ContextResult",
|
|
212
|
+
"compute_degs_fast",
|
|
213
|
+
"compute_degs_gpu",
|
|
214
|
+
"evaluate_degs",
|
|
215
|
+
"get_contexts",
|
|
216
|
+
"plot_deg_distributions",
|
|
217
|
+
"plot_context_heatmap",
|
|
218
|
+
"create_deg_report",
|
|
177
219
|
# Visualization
|
|
178
220
|
"EvaluationVisualizer",
|
|
179
221
|
"visualize",
|
geneval/deg/__init__.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Differentially Expressed Genes (DEG) module for GenGeneEval.
|
|
3
|
+
|
|
4
|
+
This module provides:
|
|
5
|
+
- Fast DEG detection using vectorized statistical tests
|
|
6
|
+
- Per-context evaluation (covariates × perturbations)
|
|
7
|
+
- DEG-focused metrics computation
|
|
8
|
+
- Integration with GPU acceleration
|
|
9
|
+
|
|
10
|
+
Example usage:
|
|
11
|
+
>>> from geneval.deg import DEGEvaluator, compute_degs_fast
|
|
12
|
+
>>> degs = compute_degs_fast(control_data, perturbed_data, method="welch")
|
|
13
|
+
>>> evaluator = DEGEvaluator(loader, deg_method="welch", pval_threshold=0.05)
|
|
14
|
+
>>> results = evaluator.evaluate()
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from .detection import (
|
|
18
|
+
compute_degs_fast,
|
|
19
|
+
compute_degs_gpu,
|
|
20
|
+
compute_degs_auto,
|
|
21
|
+
DEGResult,
|
|
22
|
+
DEGMethod,
|
|
23
|
+
)
|
|
24
|
+
from .context import (
|
|
25
|
+
ContextEvaluator,
|
|
26
|
+
ContextResult,
|
|
27
|
+
get_contexts,
|
|
28
|
+
get_context_id,
|
|
29
|
+
filter_by_context,
|
|
30
|
+
)
|
|
31
|
+
from .evaluator import (
|
|
32
|
+
DEGEvaluator,
|
|
33
|
+
DEGEvaluationResult,
|
|
34
|
+
DEGSettings,
|
|
35
|
+
ContextMetrics,
|
|
36
|
+
evaluate_degs,
|
|
37
|
+
)
|
|
38
|
+
from .visualization import (
|
|
39
|
+
plot_deg_distributions,
|
|
40
|
+
plot_context_heatmap,
|
|
41
|
+
plot_deg_counts,
|
|
42
|
+
create_deg_report,
|
|
43
|
+
)
|
|
44
|
+
|
|
45
|
+
__all__ = [
|
|
46
|
+
# Detection
|
|
47
|
+
"compute_degs_fast",
|
|
48
|
+
"compute_degs_gpu",
|
|
49
|
+
"compute_degs_auto",
|
|
50
|
+
"DEGResult",
|
|
51
|
+
"DEGMethod",
|
|
52
|
+
# Context
|
|
53
|
+
"ContextEvaluator",
|
|
54
|
+
"ContextResult",
|
|
55
|
+
"get_contexts",
|
|
56
|
+
"get_context_id",
|
|
57
|
+
"filter_by_context",
|
|
58
|
+
# Evaluator
|
|
59
|
+
"DEGEvaluator",
|
|
60
|
+
"DEGEvaluationResult",
|
|
61
|
+
"DEGSettings",
|
|
62
|
+
"ContextMetrics",
|
|
63
|
+
"evaluate_degs",
|
|
64
|
+
# Visualization
|
|
65
|
+
"plot_deg_distributions",
|
|
66
|
+
"plot_context_heatmap",
|
|
67
|
+
"plot_deg_counts",
|
|
68
|
+
"create_deg_report",
|
|
69
|
+
]
|
geneval/deg/context.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Context-aware evaluation for gene expression data.
|
|
3
|
+
|
|
4
|
+
Supports per-context evaluation where context = covariates × perturbation.
|
|
5
|
+
If only perturbation column is given, evaluates per-perturbation.
|
|
6
|
+
If multiple condition columns are given, evaluates every combination.
|
|
7
|
+
"""
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from typing import Optional, List, Dict, Tuple, Union, Iterator
|
|
11
|
+
from dataclasses import dataclass, field
|
|
12
|
+
import numpy as np
|
|
13
|
+
import pandas as pd
|
|
14
|
+
from itertools import product
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class ContextResult:
    """Results for a single context (covariate × perturbation combination).

    Attributes
    ----------
    context_id : str
        Unique identifier for this context
    context_values : Dict[str, str]
        Values for each condition column
    n_samples_real : int
        Number of real samples in this context
    n_samples_gen : int
        Number of generated samples in this context
    deg_result : DEGResult, optional
        DEG detection result for this context; ``None`` (the default) when
        no result has been attached
    metrics : Dict[str, float]
        Computed metrics for this context
    per_gene_metrics : Dict[str, np.ndarray]
        Metric arrays keyed by metric name
        (presumably one value per gene -- TODO confirm against the producer)
    """
    context_id: str
    context_values: Dict[str, str]
    n_samples_real: int
    n_samples_gen: int
    deg_result: Optional["DEGResult"] = None  # Forward reference
    metrics: Dict[str, float] = field(default_factory=dict)
    per_gene_metrics: Dict[str, np.ndarray] = field(default_factory=dict)

    def __repr__(self) -> str:
        # Compare against None explicitly: a result object that happens to be
        # falsy (for example one exposing __len__ with zero entries) must
        # still report its own n_degs instead of being shown as 'N/A'.
        n_degs = "N/A" if self.deg_result is None else self.deg_result.n_degs
        return (
            f"ContextResult(id='{self.context_id}', "
            f"n_real={self.n_samples_real}, n_gen={self.n_samples_gen}, "
            f"n_degs={n_degs})"
        )
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def get_contexts(
    obs: pd.DataFrame,
    condition_columns: List[str],
    min_samples: int = 2,
) -> List[Dict[str, str]]:
    """
    Get all unique contexts (combinations of condition values).

    Every combination of the unique values of the condition columns is
    considered, in cartesian-product order (first column varies slowest,
    values in order of first appearance). A combination is kept when at
    least ``min_samples`` rows of ``obs`` carry exactly those values.

    Parameters
    ----------
    obs : pd.DataFrame
        Observation metadata (adata.obs)
    condition_columns : List[str]
        Columns to use for context definition. Repeated names are collapsed
        to their first occurrence.
    min_samples : int
        Minimum samples required per context

    Returns
    -------
    List[Dict[str, str]]
        List of context dictionaries mapping column name to value. A single
        empty dictionary is returned when no condition columns are given.

    Raises
    ------
    ValueError
        If a condition column is not present in ``obs``.
    """
    if not condition_columns:
        return [{}]

    # A repeated column adds no information and would only yield duplicated
    # contexts, so keep the first occurrence of each name.
    columns = list(dict.fromkeys(condition_columns))

    # Get unique values for each column
    unique_values = []
    for col in columns:
        if col not in obs.columns:
            raise ValueError(f"Column '{col}' not found in obs")
        unique_values.append(obs[col].unique().tolist())

    # Count the rows of every observed combination in a single pass instead
    # of re-scanning each column for every candidate combination.
    # Missing values are dropped by groupby, which matches an equality test
    # (a missing value never compares equal to anything).
    by = columns[0] if len(columns) == 1 else columns
    group_sizes = obs.groupby(by, sort=False, observed=True).size()
    if len(columns) == 1:
        counts = {(key,): int(n) for key, n in group_sizes.items()}
    else:
        counts = {tuple(key): int(n) for key, n in group_sizes.items()}

    # Walk the full product so the output order is independent of row order
    # within groups, and so that a non-positive min_samples still admits
    # combinations that never occur together.
    return [
        dict(zip(columns, combo))
        for combo in product(*unique_values)
        if counts.get(combo, 0) >= min_samples
    ]
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def get_context_id(context: Dict[str, str]) -> str:
    """Build a deterministic string identifier for a context.

    An empty context maps to ``"all"``. Otherwise each entry is rendered as
    ``key=value``, the entries are ordered by sorting the items, and the
    pieces are joined with underscores.
    """
    if context:
        pieces = []
        for name, value in sorted(context.items()):
            pieces.append(f"{name}={value}")
        return "_".join(pieces)
    return "all"
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def filter_by_context(
    data: np.ndarray,
    obs: pd.DataFrame,
    context: Dict[str, str],
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Select the rows of ``data`` whose metadata matches a context.

    Parameters
    ----------
    data : np.ndarray
        Expression matrix (n_samples, n_genes)
    obs : pd.DataFrame
        Observation metadata, one row per row of ``data``
    context : Dict[str, str]
        Column name -> required value; a row is kept only if it matches
        every entry

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        The selected rows and the boolean row mask that selected them. With
        an empty context the input array itself is returned together with
        an all-True mask.
    """
    keep = np.ones(len(obs), dtype=bool)

    if context:
        # Narrow the mask one condition at a time.
        for column, wanted in context.items():
            keep &= (obs[column] == wanted).values
        return data[keep], keep

    return data, keep
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class ContextEvaluator:
    """
    Per-context access to paired real and generated expression data.

    A context is one combination of values of the condition columns. For
    example, with condition_columns = ["cell_type", "perturbation"], each
    (cell_type, perturbation) pair is a context. Only contexts found in
    both the real and the generated metadata (each via ``get_contexts``
    with its default ``min_samples``) are kept in ``contexts``.

    Parameters
    ----------
    real_data : np.ndarray
        Real expression matrix (n_samples, n_genes)
    generated_data : np.ndarray
        Generated expression matrix (n_samples, n_genes)
    real_obs : pd.DataFrame
        Real data metadata
    generated_obs : pd.DataFrame
        Generated data metadata
    condition_columns : List[str]
        Columns defining contexts
    gene_names : np.ndarray, optional
        Gene names
    control_key : str, optional
        Value in perturbation column indicating control (for DEG computation)
    perturbation_column : str, optional
        Name of perturbation column (for DEG computation). When omitted, the
        first entry of ``condition_columns`` is used; with no condition
        columns it stays ``None``.
    """

    def __init__(
        self,
        real_data: np.ndarray,
        generated_data: np.ndarray,
        real_obs: pd.DataFrame,
        generated_obs: pd.DataFrame,
        condition_columns: List[str],
        gene_names: Optional[np.ndarray] = None,
        control_key: str = "control",
        perturbation_column: Optional[str] = None,
    ):
        self.real_data = real_data
        self.generated_data = generated_data
        self.real_obs = real_obs
        self.generated_obs = generated_obs
        self.condition_columns = condition_columns
        self.gene_names = gene_names
        self.control_key = control_key

        # An explicit perturbation column wins; otherwise fall back to the
        # first condition column when there is one.
        if perturbation_column is None and len(condition_columns) > 0:
            self.perturbation_column = condition_columns[0]
        else:
            self.perturbation_column = perturbation_column

        # Contexts present on each side.
        self._real_contexts = get_contexts(real_obs, condition_columns)
        self._gen_contexts = get_contexts(generated_obs, condition_columns)

        # Keep the real-side contexts (in their order) that the generated
        # side also provides.
        generated_ids = {get_context_id(ctx) for ctx in self._gen_contexts}
        self.contexts = [
            ctx for ctx in self._real_contexts
            if get_context_id(ctx) in generated_ids
        ]

        # Control subsets, keyed by the control context's id.
        self._control_cache: Dict[str, Tuple[np.ndarray, np.ndarray]] = {}

    def get_context_data(
        self,
        context: Dict[str, str],
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Return the (real, generated) rows that belong to ``context``."""
        real_subset = filter_by_context(self.real_data, self.real_obs, context)[0]
        gen_subset = filter_by_context(self.generated_data, self.generated_obs, context)[0]
        return real_subset, gen_subset

    def get_control_data(
        self,
        context: Dict[str, str],
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Return the (real, generated) control rows matching ``context``.

        The control context is ``context`` with the perturbation column set
        to ``control_key``; results are cached per control context id.

        Raises
        ------
        ValueError
            If no perturbation column is configured.
        """
        if self.perturbation_column is None:
            raise ValueError("perturbation_column required for DEG computation")

        # Same covariates as the request, perturbation swapped for control.
        control_context = context.copy()
        control_context.update({self.perturbation_column: self.control_key})
        cache_key = get_context_id(control_context)

        cached = self._control_cache.get(cache_key)
        if cached is not None:
            return cached

        controls = (
            filter_by_context(self.real_data, self.real_obs, control_context)[0],
            filter_by_context(self.generated_data, self.generated_obs, control_context)[0],
        )
        self._control_cache[cache_key] = controls
        return controls

    def iter_contexts(self) -> Iterator[Tuple[str, Dict[str, str], np.ndarray, np.ndarray]]:
        """Yield ``(context_id, context, real_rows, generated_rows)`` per context."""
        for ctx in self.contexts:
            ctx_id = get_context_id(ctx)
            real_subset, gen_subset = self.get_context_data(ctx)
            yield ctx_id, ctx, real_subset, gen_subset

    def is_control_context(self, context: Dict[str, str]) -> bool:
        """Tell whether ``context`` has the control value in the perturbation column."""
        column = self.perturbation_column
        return column is not None and context.get(column) == self.control_key

    def get_perturbation_contexts(self) -> List[Dict[str, str]]:
        """Return the contexts that are not control contexts."""
        perturbed = []
        for ctx in self.contexts:
            if self.is_control_context(ctx):
                continue
            perturbed.append(ctx)
        return perturbed

    def __len__(self) -> int:
        return len(self.contexts)

    def __repr__(self) -> str:
        n_contexts = len(self.contexts)
        return (
            f"ContextEvaluator(n_contexts={n_contexts}, "
            f"condition_columns={self.condition_columns})"
        )
|