gengeneeval 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geneval/__init__.py +129 -0
- geneval/cli.py +333 -0
- geneval/config.py +141 -0
- geneval/core.py +41 -0
- geneval/data/__init__.py +23 -0
- geneval/data/gene_expression_datamodule.py +211 -0
- geneval/data/loader.py +437 -0
- geneval/evaluator.py +359 -0
- geneval/evaluators/__init__.py +4 -0
- geneval/evaluators/base_evaluator.py +178 -0
- geneval/evaluators/gene_expression_evaluator.py +218 -0
- geneval/metrics/__init__.py +65 -0
- geneval/metrics/base_metric.py +229 -0
- geneval/metrics/correlation.py +232 -0
- geneval/metrics/distances.py +516 -0
- geneval/metrics/metrics.py +134 -0
- geneval/models/__init__.py +1 -0
- geneval/models/base_model.py +53 -0
- geneval/results.py +334 -0
- geneval/testing.py +393 -0
- geneval/utils/__init__.py +1 -0
- geneval/utils/io.py +27 -0
- geneval/utils/preprocessing.py +82 -0
- geneval/visualization/__init__.py +38 -0
- geneval/visualization/plots.py +499 -0
- geneval/visualization/visualizer.py +1096 -0
- gengeneeval-0.1.0.dist-info/METADATA +172 -0
- gengeneeval-0.1.0.dist-info/RECORD +31 -0
- gengeneeval-0.1.0.dist-info/WHEEL +4 -0
- gengeneeval-0.1.0.dist-info/entry_points.txt +3 -0
- gengeneeval-0.1.0.dist-info/licenses/LICENSE +9 -0
geneval/__init__.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
"""
|
|
2
|
+
GenEval: Comprehensive evaluation of generated gene expression data.
|
|
3
|
+
|
|
4
|
+
A modular, object-oriented framework for computing metrics between real
|
|
5
|
+
and generated gene expression datasets stored in AnnData (h5ad) format.
|
|
6
|
+
|
|
7
|
+
Features:
|
|
8
|
+
- Multiple distance and correlation metrics (per-gene and aggregate)
|
|
9
|
+
- Condition-based matching (perturbation, cell type, etc.)
|
|
10
|
+
- Train/test split support
|
|
11
|
+
- Publication-quality visualizations
|
|
12
|
+
- Command-line interface
|
|
13
|
+
|
|
14
|
+
Quick Start:
|
|
15
|
+
>>> from geneval import evaluate
|
|
16
|
+
>>> results = evaluate(
|
|
17
|
+
... real_path="real.h5ad",
|
|
18
|
+
... generated_path="generated.h5ad",
|
|
19
|
+
... condition_columns=["perturbation"],
|
|
20
|
+
... output_dir="output/"
|
|
21
|
+
... )
|
|
22
|
+
|
|
23
|
+
CLI Usage:
|
|
24
|
+
$ geneval --real real.h5ad --generated generated.h5ad \\
|
|
25
|
+
--conditions perturbation cell_type --output results/
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
# Package metadata. __version__ is expected to match the released wheel
# version (gengeneeval 0.1.0).
__version__: str = "0.1.0"
__author__: str = "GenEval Team"
|
|
30
|
+
|
|
31
|
+
# Main evaluation interface
|
|
32
|
+
from .evaluator import (
|
|
33
|
+
evaluate,
|
|
34
|
+
GeneEvalEvaluator,
|
|
35
|
+
MetricRegistry,
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
# Data loading
|
|
39
|
+
from .data.loader import (
|
|
40
|
+
GeneExpressionDataLoader,
|
|
41
|
+
load_data,
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
# Results
|
|
45
|
+
from .results import (
|
|
46
|
+
EvaluationResult,
|
|
47
|
+
SplitResult,
|
|
48
|
+
ConditionResult,
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
# Metrics
|
|
52
|
+
from .metrics.base_metric import (
|
|
53
|
+
BaseMetric,
|
|
54
|
+
MetricResult,
|
|
55
|
+
DistributionMetric,
|
|
56
|
+
CorrelationMetric,
|
|
57
|
+
)
|
|
58
|
+
from .metrics.correlation import (
|
|
59
|
+
PearsonCorrelation,
|
|
60
|
+
SpearmanCorrelation,
|
|
61
|
+
MeanPearsonCorrelation,
|
|
62
|
+
MeanSpearmanCorrelation,
|
|
63
|
+
)
|
|
64
|
+
from .metrics.distances import (
|
|
65
|
+
Wasserstein1Distance,
|
|
66
|
+
Wasserstein2Distance,
|
|
67
|
+
MMDDistance,
|
|
68
|
+
EnergyDistance,
|
|
69
|
+
MultivariateWasserstein,
|
|
70
|
+
MultivariateMMD,
|
|
71
|
+
)
|
|
72
|
+
|
|
73
|
+
# Visualization
|
|
74
|
+
from .visualization.visualizer import (
|
|
75
|
+
EvaluationVisualizer,
|
|
76
|
+
visualize,
|
|
77
|
+
)
|
|
78
|
+
|
|
79
|
+
# Legacy support
|
|
80
|
+
from .data.gene_expression_datamodule import GeneExpressionDataModule
|
|
81
|
+
|
|
82
|
+
# Testing utilities (for users to generate test data)
|
|
83
|
+
from .testing import (
|
|
84
|
+
MockDataGenerator,
|
|
85
|
+
MockMetricData,
|
|
86
|
+
create_test_data,
|
|
87
|
+
)
|
|
88
|
+
|
|
89
|
+
# Public names exported by ``from geneval import *``; every entry is bound by
# one of the imports above. Order is kept stable for readability.
__all__ = [
    "__version__",
    # Main evaluation API
    "evaluate", "GeneEvalEvaluator", "MetricRegistry",
    # Data loading
    "GeneExpressionDataLoader", "load_data",
    # Result containers
    "EvaluationResult", "SplitResult", "ConditionResult",
    # Metric base classes
    "BaseMetric", "MetricResult", "DistributionMetric", "CorrelationMetric",
    # Correlation metrics
    "PearsonCorrelation", "SpearmanCorrelation",
    "MeanPearsonCorrelation", "MeanSpearmanCorrelation",
    # Distance metrics
    "Wasserstein1Distance", "Wasserstein2Distance", "MMDDistance",
    "EnergyDistance", "MultivariateWasserstein", "MultivariateMMD",
    # Visualization
    "EvaluationVisualizer", "visualize",
    # Testing utilities
    "MockDataGenerator", "MockMetricData", "create_test_data",
    # Legacy data module
    "GeneExpressionDataModule",
]
|
geneval/cli.py
ADDED
|
@@ -0,0 +1,333 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Command-line interface for GenEval gene expression evaluation.
|
|
3
|
+
|
|
4
|
+
Provides comprehensive CLI for evaluating generated vs real gene expression data.
|
|
5
|
+
"""
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import argparse
|
|
9
|
+
import sys
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from typing import List, Optional
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def create_parser() -> argparse.ArgumentParser:
    """Build the argument parser for the ``geneval`` command.

    Options are arranged in four help groups: required inputs, optional
    evaluation settings, plotting options and output options. The groups
    appear in ``--help`` in the order they are created here.

    Returns
    -------
    argparse.ArgumentParser
        Parser whose namespace exposes ``real``, ``generated``,
        ``conditions``, ``output``, ``split_column``, ``splits``,
        ``metrics``, ``min_samples``, ``aggregate``, ``no_plots``,
        ``plot_formats``, ``dpi``, ``embedding``, ``verbose``, ``quiet``
        and ``save_per_gene``.
    """
    description = """
GenEval: Comprehensive evaluation of generated gene expression data.

Computes metrics between real and generated datasets, matching samples
by condition columns (e.g., perturbation, cell type). Supports train/test
splits and generates publication-quality visualizations.

Metrics computed:
- Pearson and Spearman correlation
- Wasserstein-1 and Wasserstein-2 distance
- Maximum Mean Discrepancy (MMD)
- Energy distance
- Multivariate versions of distance metrics

All metrics are computed per-gene and aggregated.
"""
    parser = argparse.ArgumentParser(
        prog="geneval",
        description=description,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    metric_choices = [
        "pearson", "spearman", "mean_pearson", "mean_spearman",
        "wasserstein_1", "wasserstein_2", "mmd", "energy",
        "multivariate_wasserstein", "multivariate_mmd", "all",
    ]

    # Layout table: (group title, [(option strings, add_argument kwargs), ...]).
    layout = [
        ("Required arguments", [
            (("--real", "-r"), dict(
                type=str, required=True,
                help="Path to real data file (h5ad format)")),
            (("--generated", "-g"), dict(
                type=str, required=True,
                help="Path to generated data file (h5ad format)")),
            (("--conditions", "-c"), dict(
                type=str, nargs="+", required=True,
                help="Condition columns to match (e.g., perturbation cell_type)")),
            (("--output", "-o"), dict(
                type=str, required=True,
                help="Output directory for results and plots")),
        ]),
        ("Optional arguments", [
            (("--split-column", "-s"), dict(
                type=str, default=None,
                help="Column indicating train/test split. If not provided, all data treated as one split.")),
            (("--splits",), dict(
                type=str, nargs="+", default=None,
                help="Specific splits to evaluate (e.g., 'test' or 'train test'). Default: all splits.")),
            (("--metrics",), dict(
                type=str, nargs="+", default=None, choices=metric_choices,
                help="Metrics to compute. Default: all metrics.")),
            (("--min-samples",), dict(
                type=int, default=2,
                help="Minimum samples per condition to include (default: 2)")),
            (("--aggregate",), dict(
                type=str, default="mean", choices=["mean", "median", "std"],
                help="How to aggregate per-gene metrics (default: mean)")),
        ]),
        ("Plotting options", [
            (("--no-plots",), dict(
                action="store_true",
                help="Skip plot generation")),
            (("--plot-formats",), dict(
                type=str, nargs="+", default=["png", "pdf"],
                help="Output formats for plots (default: png pdf)")),
            (("--dpi",), dict(
                type=int, default=150,
                help="Resolution for saved plots (default: 150)")),
            (("--embedding",), dict(
                type=str, nargs="+", default=["pca"],
                choices=["pca", "umap", "both", "none"],
                help="Embedding methods for visualization (default: pca)")),
        ]),
        ("Output options", [
            (("--verbose", "-v"), dict(
                action="store_true",
                help="Print detailed progress")),
            (("--quiet", "-q"), dict(
                action="store_true",
                help="Suppress all output except errors")),
            (("--save-per-gene",), dict(
                action="store_true",
                help="Save per-gene metric values (can be large files)")),
        ]),
    ]

    for title, option_specs in layout:
        group = parser.add_argument_group(title)
        for option_strings, options in option_specs:
            group.add_argument(*option_strings, **options)

    return parser
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def get_metric_classes(metric_names: Optional[List[str]] = None):
    """Map metric names (as accepted by ``--metrics``) to metric classes.

    Parameters
    ----------
    metric_names : list of str, optional
        Requested metric names. ``None``, or any list containing ``"all"``,
        selects every registered metric.

    Returns
    -------
    list of type
        Metric classes in the order they were first requested. Repeated
        names are collapsed to a single entry. Names that are not registered
        are skipped and reported with a ``UserWarning``.
    """
    import warnings

    from .metrics.correlation import (
        PearsonCorrelation,
        SpearmanCorrelation,
        MeanPearsonCorrelation,
        MeanSpearmanCorrelation,
    )
    from .metrics.distances import (
        Wasserstein1Distance,
        Wasserstein2Distance,
        MMDDistance,
        EnergyDistance,
        MultivariateWasserstein,
        MultivariateMMD,
    )

    all_metrics = {
        "pearson": PearsonCorrelation,
        "spearman": SpearmanCorrelation,
        "mean_pearson": MeanPearsonCorrelation,
        "mean_spearman": MeanSpearmanCorrelation,
        "wasserstein_1": Wasserstein1Distance,
        "wasserstein_2": Wasserstein2Distance,
        "mmd": MMDDistance,
        "energy": EnergyDistance,
        "multivariate_wasserstein": MultivariateWasserstein,
        "multivariate_mmd": MultivariateMMD,
    }

    if metric_names is None or "all" in metric_names:
        return list(all_metrics.values())

    # dict.fromkeys keeps only the first occurrence of each name, in order.
    # A repeated name (e.g. ``--metrics mmd mmd``) therefore does not
    # schedule the same metric twice.
    requested = list(dict.fromkeys(metric_names))

    # The CLI restricts choices, but programmatic callers can pass anything.
    # Report typos instead of dropping them silently.
    unknown = [name for name in requested if name not in all_metrics]
    if unknown:
        warnings.warn(
            f"Ignoring unknown metric name(s): {', '.join(map(repr, unknown))}",
            stacklevel=2,
        )

    return [all_metrics[name] for name in requested if name in all_metrics]
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def main(args: Optional[List[str]] = None):
    """Run the GenEval command-line workflow.

    The workflow:

    1. Parse the arguments.
    2. Check that both input files exist.
    3. Create the output directory.
    4. Load the paired real/generated data.
    5. Compute metrics with ``GeneEvalEvaluator``.
    6. Unless ``--no-plots`` is given, save figures under ``<output>/plots``.

    Parameters
    ----------
    args : list of str, optional
        Argument vector to parse. ``None`` makes argparse read ``sys.argv[1:]``.

    Returns
    -------
    object
        Whatever ``GeneEvalEvaluator.evaluate`` returns.

    Raises
    ------
    SystemExit
        Status 1 in any of these cases:

        - an input file is missing;
        - the output directory cannot be created;
        - data loading fails;
        - evaluation fails, unless ``--verbose`` was given, in which case the
          original exception propagates with its traceback.

        argparse also exits on invalid arguments.
    """
    parser = create_parser()
    parsed = parser.parse_args(args)

    # Progress output is on by default. --quiet silences it, and --verbose
    # takes precedence when both flags are given.
    verbose = parsed.verbose or not parsed.quiet

    # Validate paths
    real_path = Path(parsed.real)
    gen_path = Path(parsed.generated)
    output_dir = Path(parsed.output)

    if not real_path.exists():
        print(f"Error: Real data file not found: {real_path}", file=sys.stderr)
        sys.exit(1)

    if not gen_path.exists():
        print(f"Error: Generated data file not found: {gen_path}", file=sys.stderr)
        sys.exit(1)

    # Create the output directory before the (potentially slow) data loading.
    # An invalid or unwritable location is then reported immediately instead
    # of after all metrics have been computed.
    try:
        output_dir.mkdir(parents=True, exist_ok=True)
    except OSError as e:
        print(f"Error: Cannot create output directory {output_dir}: {e}", file=sys.stderr)
        sys.exit(1)

    # Import here to avoid slow startup
    from .data.loader import load_data
    from .evaluator import GeneEvalEvaluator
    from .visualization.visualizer import EvaluationVisualizer

    if verbose:
        print("=" * 60)
        print("GenEval: Gene Expression Evaluation")
        print("=" * 60)
        print(f"\nReal data: {real_path}")
        print(f"Generated data: {gen_path}")
        print(f"Conditions: {parsed.conditions}")
        print(f"Output: {output_dir}")
        if parsed.split_column:
            print(f"Split column: {parsed.split_column}")
        print()

    # Load data
    if verbose:
        print("Loading data...")

    try:
        loader = load_data(
            real_path=real_path,
            generated_path=gen_path,
            condition_columns=parsed.conditions,
            split_column=parsed.split_column,
            min_samples_per_condition=parsed.min_samples,
        )
    except Exception as e:
        print(f"Error loading data: {e}", file=sys.stderr)
        sys.exit(1)

    if verbose:
        summary = loader.summary()
        print(f" Real: {summary['real']['n_samples']} samples x {summary['real']['n_genes']} genes")
        print(f" Generated: {summary['generated']['n_samples']} samples x {summary['generated']['n_genes']} genes")
        print(f" Common genes: {summary.get('n_common_genes', 'N/A')}")
        print(f" Splits: {summary.get('splits', ['all'])}")
        print()

    # Get metrics
    metric_classes = get_metric_classes(parsed.metrics)

    # Multivariate metrics run when all metrics are requested (explicitly or
    # by default) or when any multivariate_* metric is named.
    include_multivariate = (
        parsed.metrics is None
        or "all" in parsed.metrics
        or any(m.startswith("multivariate") for m in (parsed.metrics or []))
    )

    # Evaluation errors are reported the same way as loading errors (message +
    # exit status 1). With --verbose the traceback is kept for debugging.
    try:
        evaluator = GeneEvalEvaluator(
            data_loader=loader,
            metrics=metric_classes,
            aggregate_method=parsed.aggregate,
            include_multivariate=include_multivariate,
            verbose=verbose,
        )

        if verbose:
            print("Running evaluation...")

        results = evaluator.evaluate(
            splits=parsed.splits,
            save_dir=output_dir,
        )
    except Exception as e:
        if parsed.verbose:
            raise
        print(f"Error during evaluation: {e}", file=sys.stderr)
        sys.exit(1)

    # Generate plots
    if not parsed.no_plots:
        if verbose:
            print("\nGenerating visualizations...")

        plot_dir = output_dir / "plots"

        # --embedding accepts pca/umap plus the shorthands "both" and "none".
        embedding_methods = parsed.embedding
        if "none" in embedding_methods:
            embedding_methods = []
        elif "both" in embedding_methods:
            embedding_methods = ["pca", "umap"]

        # NOTE(review): only the emptiness of embedding_methods is used here.
        # It decides whether the loader is passed to save_all, so the
        # specific methods requested (pca vs umap) are never forwarded.
        # Confirm against EvaluationVisualizer.save_all's signature.
        try:
            viz = EvaluationVisualizer(results, dpi=parsed.dpi)
            viz.save_all(
                output_dir=plot_dir,
                formats=parsed.plot_formats,
                data_loader=loader if embedding_methods else None,
            )
        except Exception as e:
            # Plotting is best-effort: metrics are already saved at this point.
            print(f"Warning: Failed to generate some plots: {e}", file=sys.stderr)

    # Print final summary
    if verbose:
        print("\n" + "=" * 60)
        print("RESULTS SAVED")
        print("=" * 60)
        print(f"\nOutput directory: {output_dir}")
        print("\nFiles generated:")
        print(" - summary.json: Aggregate metrics and metadata")
        print(" - results.csv: Per-condition metrics")
        # NOTE(review): --save-per-gene is only consulted for this message.
        # It is not passed to GeneEvalEvaluator, so confirm the per-gene files
        # are actually written.
        if parsed.save_per_gene:
            print(" - per_gene_*.csv: Per-gene metric values")
        if not parsed.no_plots:
            print(" - plots/: Visualization figures")
        print()

    return results
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
def run() -> None:
    """Entry point for the ``geneval`` console script.

    Parses ``sys.argv`` via ``main`` and discards the evaluation result. On
    input errors, ``main`` itself terminates the process with ``sys.exit(1)``.
    """
    main(None)


if __name__ == "__main__":
    run()
|
geneval/config.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Configuration settings for GenEval.
|
|
3
|
+
|
|
4
|
+
Provides centralized configuration for metrics, paths, and defaults.
|
|
5
|
+
"""
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from typing import List, Dict, Any, Optional
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
class MetricConfig:
    """Configuration for metric computation.

    Field order defines the positional constructor order; keep it stable.
    """

    # Metric names computed by default, spelled like the CLI ``--metrics``
    # choices. The multivariate metrics are not listed here; presumably
    # ``include_multivariate`` controls them (TODO confirm).
    default_metrics: List[str] = field(default_factory=lambda: [
        "pearson",
        "spearman",
        "mean_pearson",
        "mean_spearman",
        "wasserstein_1",
        "wasserstein_2",
        "mmd",
        "energy",
    ])

    # Whether to include multivariate metrics.
    include_multivariate: bool = True

    # How per-gene metric values are reduced to a single number.
    aggregate_method: str = "mean"

    # Wasserstein parameters.
    # NOTE(review): presumably the entropic "blur" of a Sinkhorn-type solver;
    # confirm against the distance implementation.
    wasserstein_blur: float = 0.01

    # MMD parameters: kernel name and bandwidth.
    mmd_kernel: str = "rbf"
    mmd_sigma: Optional[float] = None  # None = median heuristic


@dataclass
class DataConfig:
    """Configuration for data loading."""

    # Minimum number of samples a condition needs to be included (same
    # default as the CLI's --min-samples).
    min_samples_per_condition: int = 2

    # Default name of the column holding the train/test split label.
    default_split_column: str = "split"

    # Label values treated as the training split and as held-out splits.
    train_split_values: List[str] = field(default_factory=lambda: ["train", "training"])
    test_split_values: List[str] = field(default_factory=lambda: ["test", "testing", "val", "validation"])


@dataclass
class PlotConfig:
    """Configuration for plotting."""

    # Figure DPI (same default as the CLI's --dpi).
    dpi: int = 150

    # Default figure sizes as (width, height).
    figure_small: tuple = (8, 6)
    figure_medium: tuple = (12, 8)
    figure_large: tuple = (16, 12)
    figure_wide: tuple = (16, 6)

    # Style settings. The values look like seaborn style/context names
    # (TODO confirm how the plotting code applies them).
    style: str = "whitegrid"
    context: str = "paper"
    font_scale: float = 1.2

    # Colors for the two datasets being compared.
    real_color: str = "#1f77b4"  # Blue
    generated_color: str = "#ff7f0e"  # Orange

    # Output formats (same default as the CLI's --plot-formats).
    default_formats: List[str] = field(default_factory=lambda: ["png", "pdf"])


@dataclass
class Config:
    """
    Main configuration class for GenEval.

    Combines the metric, data and plot settings with output locations and
    verbosity.
    """

    metrics: MetricConfig = field(default_factory=MetricConfig)
    data: DataConfig = field(default_factory=DataConfig)
    plot: PlotConfig = field(default_factory=PlotConfig)

    # Output settings
    output_dir: Path = Path("output/")
    log_dir: Path = Path("logs/")

    # Verbosity
    verbose: bool = True

    @classmethod
    def default(cls) -> "Config":
        """Return a configuration populated entirely with default values."""
        return cls()

    def to_dict(self) -> Dict[str, Any]:
        """Convert the configuration to a plain, JSON-friendly dictionary.

        Every field is included. The keys emitted by earlier versions keep
        their names and values; additional keys were added for fields that
        used to be omitted. Lists and tuples are copied into new lists, and
        paths are converted to strings. As a result, mutating the returned
        dictionary never alters this configuration.
        """
        metrics, data, plot = self.metrics, self.data, self.plot
        return {
            "metrics": {
                "default_metrics": list(metrics.default_metrics),
                "include_multivariate": metrics.include_multivariate,
                "aggregate_method": metrics.aggregate_method,
                "wasserstein_blur": metrics.wasserstein_blur,
                "mmd_kernel": metrics.mmd_kernel,
                "mmd_sigma": metrics.mmd_sigma,
            },
            "data": {
                "min_samples_per_condition": data.min_samples_per_condition,
                "default_split_column": data.default_split_column,
                "train_split_values": list(data.train_split_values),
                "test_split_values": list(data.test_split_values),
            },
            "plot": {
                "dpi": plot.dpi,
                "style": plot.style,
                "default_formats": list(plot.default_formats),
                "context": plot.context,
                "font_scale": plot.font_scale,
                "figure_small": list(plot.figure_small),
                "figure_medium": list(plot.figure_medium),
                "figure_large": list(plot.figure_large),
                "figure_wide": list(plot.figure_wide),
                "real_color": plot.real_color,
                "generated_color": plot.generated_color,
            },
            "output_dir": str(self.output_dir),
            "log_dir": str(self.log_dir),
            "verbose": self.verbose,
        }
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
# Process-wide configuration singleton, built from the dataclass defaults.
# set_config() rebinds this module attribute. A name bound earlier via
# ``from geneval.config import DEFAULT_CONFIG`` therefore keeps pointing at
# the previous object; call get_config() to always see the current one.
DEFAULT_CONFIG = Config.default()


def get_config() -> Config:
    """Return the configuration object currently installed for the process."""
    current = DEFAULT_CONFIG
    return current


def set_config(config: Config):
    """Install *config* as the process-wide configuration.

    The object is stored as-is (no copy and no validation), so later
    mutations of *config* are visible through get_config().
    """
    global DEFAULT_CONFIG
    DEFAULT_CONFIG = config
|
geneval/core.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
|
|
3
|
+
class BaseEvaluator(ABC):
    """
    Abstract base class for evaluators in the gene expression evaluation system.

    Parameters
    ----------
    data
        Input data the evaluation is based on; stored unchanged as ``self.data``.
    output
        Model output to evaluate; stored unchanged as ``self.output``.
    """

    def __init__(self, data, output):
        self.data, self.output = data, output

    @abstractmethod
    def evaluate(self, *args, **kwargs):
        """
        Evaluate the model performance based on the stored data and output.

        Concrete subclasses must override this; the base class cannot be
        instantiated until they do.
        """


class GeneExpressionEvaluator(BaseEvaluator):
    """
    Evaluator for gene expression data.

    Intended to compute metrics between real and generated gene expression
    profiles, optionally adjusting for control conditions and covariates.

    The constructor is inherited unchanged from BaseEvaluator.

    Parameters
    ----------
    data : GeneExpressionDataModule
        The data module containing gene expression datasets.
    output : AnnData
        The generated gene expression data to evaluate.
    """

    def evaluate(self, delta=False, plot=False, DEG=None):
        """
        Placeholder evaluation entry point.

        NOTE(review): not implemented yet. All arguments are ignored and the
        method returns None.
        """
        return None
|
geneval/data/__init__.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Data loading module for gene expression evaluation.
|
|
3
|
+
|
|
4
|
+
Provides data loaders for paired real and generated datasets.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from .loader import (
|
|
8
|
+
GeneExpressionDataLoader,
|
|
9
|
+
load_data,
|
|
10
|
+
DataLoaderError,
|
|
11
|
+
)
|
|
12
|
+
from .gene_expression_datamodule import (
|
|
13
|
+
GeneExpressionDataModule,
|
|
14
|
+
DataModuleError,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
# Names exported by ``from geneval.data import *``.
__all__ = [
    # From .loader: paired real/generated data loading
    "GeneExpressionDataLoader", "load_data", "DataLoaderError",
    # From .gene_expression_datamodule
    "GeneExpressionDataModule", "DataModuleError",
]
|