gengeneeval 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geneval/__init__.py ADDED
@@ -0,0 +1,129 @@
1
+ """
2
+ GenEval: Comprehensive evaluation of generated gene expression data.
3
+
4
+ A modular, object-oriented framework for computing metrics between real
5
+ and generated gene expression datasets stored in AnnData (h5ad) format.
6
+
7
+ Features:
8
+ - Multiple distance and correlation metrics (per-gene and aggregate)
9
+ - Condition-based matching (perturbation, cell type, etc.)
10
+ - Train/test split support
11
+ - Publication-quality visualizations
12
+ - Command-line interface
13
+
14
+ Quick Start:
15
+ >>> from geneval import evaluate
16
+ >>> results = evaluate(
17
+ ... real_path="real.h5ad",
18
+ ... generated_path="generated.h5ad",
19
+ ... condition_columns=["perturbation"],
20
+ ... output_dir="output/"
21
+ ... )
22
+
23
+ CLI Usage:
24
+ $ geneval --real real.h5ad --generated generated.h5ad \\
25
+ --conditions perturbation cell_type --output results/
26
+ """
27
+
28
+ __version__ = "0.1.0"
29
+ __author__ = "GenEval Team"
30
+
31
+ # Main evaluation interface
32
+ from .evaluator import (
33
+ evaluate,
34
+ GeneEvalEvaluator,
35
+ MetricRegistry,
36
+ )
37
+
38
+ # Data loading
39
+ from .data.loader import (
40
+ GeneExpressionDataLoader,
41
+ load_data,
42
+ )
43
+
44
+ # Results
45
+ from .results import (
46
+ EvaluationResult,
47
+ SplitResult,
48
+ ConditionResult,
49
+ )
50
+
51
+ # Metrics
52
+ from .metrics.base_metric import (
53
+ BaseMetric,
54
+ MetricResult,
55
+ DistributionMetric,
56
+ CorrelationMetric,
57
+ )
58
+ from .metrics.correlation import (
59
+ PearsonCorrelation,
60
+ SpearmanCorrelation,
61
+ MeanPearsonCorrelation,
62
+ MeanSpearmanCorrelation,
63
+ )
64
+ from .metrics.distances import (
65
+ Wasserstein1Distance,
66
+ Wasserstein2Distance,
67
+ MMDDistance,
68
+ EnergyDistance,
69
+ MultivariateWasserstein,
70
+ MultivariateMMD,
71
+ )
72
+
73
+ # Visualization
74
+ from .visualization.visualizer import (
75
+ EvaluationVisualizer,
76
+ visualize,
77
+ )
78
+
79
+ # Legacy support
80
+ from .data.gene_expression_datamodule import GeneExpressionDataModule
81
+
82
+ # Testing utilities (for users to generate test data)
83
+ from .testing import (
84
+ MockDataGenerator,
85
+ MockMetricData,
86
+ create_test_data,
87
+ )
88
+
89
+ __all__ = [
90
+ # Version
91
+ "__version__",
92
+ # Main API
93
+ "evaluate",
94
+ "GeneEvalEvaluator",
95
+ "MetricRegistry",
96
+ # Data loading
97
+ "GeneExpressionDataLoader",
98
+ "load_data",
99
+ # Results
100
+ "EvaluationResult",
101
+ "SplitResult",
102
+ "ConditionResult",
103
+ # Base metrics
104
+ "BaseMetric",
105
+ "MetricResult",
106
+ "DistributionMetric",
107
+ "CorrelationMetric",
108
+ # Correlation metrics
109
+ "PearsonCorrelation",
110
+ "SpearmanCorrelation",
111
+ "MeanPearsonCorrelation",
112
+ "MeanSpearmanCorrelation",
113
+ # Distance metrics
114
+ "Wasserstein1Distance",
115
+ "Wasserstein2Distance",
116
+ "MMDDistance",
117
+ "EnergyDistance",
118
+ "MultivariateWasserstein",
119
+ "MultivariateMMD",
120
+ # Visualization
121
+ "EvaluationVisualizer",
122
+ "visualize",
123
+ # Testing utilities
124
+ "MockDataGenerator",
125
+ "MockMetricData",
126
+ "create_test_data",
127
+ # Legacy
128
+ "GeneExpressionDataModule",
129
+ ]
geneval/cli.py ADDED
@@ -0,0 +1,333 @@
1
+ """
2
+ Command-line interface for GenEval gene expression evaluation.
3
+
4
+ Provides comprehensive CLI for evaluating generated vs real gene expression data.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import argparse
9
+ import sys
10
+ from pathlib import Path
11
+ from typing import List, Optional
12
+
13
+
14
def create_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the ``geneval`` program.

    The parser exposes four argument groups: required inputs, optional
    evaluation settings, plotting options, and output/verbosity options.

    Returns
    -------
    argparse.ArgumentParser
        Fully configured parser ready for ``parse_args``.
    """
    description = """
GenEval: Comprehensive evaluation of generated gene expression data.

Computes metrics between real and generated datasets, matching samples
by condition columns (e.g., perturbation, cell type). Supports train/test
splits and generates publication-quality visualizations.

Metrics computed:
- Pearson and Spearman correlation
- Wasserstein-1 and Wasserstein-2 distance
- Maximum Mean Discrepancy (MMD)
- Energy distance
- Multivariate versions of distance metrics

All metrics are computed per-gene and aggregated.
"""
    parser = argparse.ArgumentParser(
        prog="geneval",
        description=description,
        # Raw formatter preserves the hand-written layout of the description.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # --- Required inputs -------------------------------------------------
    req = parser.add_argument_group("Required arguments")
    req.add_argument(
        "--real", "-r", type=str, required=True,
        help="Path to real data file (h5ad format)",
    )
    req.add_argument(
        "--generated", "-g", type=str, required=True,
        help="Path to generated data file (h5ad format)",
    )
    req.add_argument(
        "--conditions", "-c", type=str, nargs="+", required=True,
        help="Condition columns to match (e.g., perturbation cell_type)",
    )
    req.add_argument(
        "--output", "-o", type=str, required=True,
        help="Output directory for results and plots",
    )

    # --- Optional evaluation settings ------------------------------------
    opts = parser.add_argument_group("Optional arguments")
    opts.add_argument(
        "--split-column", "-s", type=str, default=None,
        help="Column indicating train/test split. If not provided, all data treated as one split.",
    )
    opts.add_argument(
        "--splits", type=str, nargs="+", default=None,
        help="Specific splits to evaluate (e.g., 'test' or 'train test'). Default: all splits.",
    )
    opts.add_argument(
        "--metrics", type=str, nargs="+", default=None,
        choices=[
            "pearson", "spearman", "mean_pearson", "mean_spearman",
            "wasserstein_1", "wasserstein_2", "mmd", "energy",
            "multivariate_wasserstein", "multivariate_mmd", "all",
        ],
        help="Metrics to compute. Default: all metrics.",
    )
    opts.add_argument(
        "--min-samples", type=int, default=2,
        help="Minimum samples per condition to include (default: 2)",
    )
    opts.add_argument(
        "--aggregate", type=str, default="mean",
        choices=["mean", "median", "std"],
        help="How to aggregate per-gene metrics (default: mean)",
    )

    # --- Plotting options -------------------------------------------------
    plot_group = parser.add_argument_group("Plotting options")
    plot_group.add_argument(
        "--no-plots", action="store_true",
        help="Skip plot generation",
    )
    plot_group.add_argument(
        "--plot-formats", type=str, nargs="+", default=["png", "pdf"],
        help="Output formats for plots (default: png pdf)",
    )
    plot_group.add_argument(
        "--dpi", type=int, default=150,
        help="Resolution for saved plots (default: 150)",
    )
    plot_group.add_argument(
        "--embedding", type=str, nargs="+", default=["pca"],
        choices=["pca", "umap", "both", "none"],
        help="Embedding methods for visualization (default: pca)",
    )

    # --- Output / verbosity options ---------------------------------------
    out_group = parser.add_argument_group("Output options")
    out_group.add_argument(
        "--verbose", "-v", action="store_true",
        help="Print detailed progress",
    )
    out_group.add_argument(
        "--quiet", "-q", action="store_true",
        help="Suppress all output except errors",
    )
    out_group.add_argument(
        "--save-per-gene", action="store_true",
        help="Save per-gene metric values (can be large files)",
    )

    return parser
154
+
155
+
156
def get_metric_classes(metric_names: Optional[List[str]] = None):
    """Resolve CLI metric names to their implementing classes.

    Parameters
    ----------
    metric_names : list of str, optional
        Names chosen on the command line. ``None`` or a list containing
        ``"all"`` selects every registered metric. Unknown names are
        silently skipped (the CLI already restricts choices upstream).

    Returns
    -------
    list
        Metric classes in the order requested.
    """
    # Imported lazily so that `geneval --help` stays fast.
    from .metrics.correlation import (
        PearsonCorrelation,
        SpearmanCorrelation,
        MeanPearsonCorrelation,
        MeanSpearmanCorrelation,
    )
    from .metrics.distances import (
        Wasserstein1Distance,
        Wasserstein2Distance,
        MMDDistance,
        EnergyDistance,
        MultivariateWasserstein,
        MultivariateMMD,
    )

    # Name -> class registry; insertion order defines the "all" ordering.
    registry = {
        "pearson": PearsonCorrelation,
        "spearman": SpearmanCorrelation,
        "mean_pearson": MeanPearsonCorrelation,
        "mean_spearman": MeanSpearmanCorrelation,
        "wasserstein_1": Wasserstein1Distance,
        "wasserstein_2": Wasserstein2Distance,
        "mmd": MMDDistance,
        "energy": EnergyDistance,
        "multivariate_wasserstein": MultivariateWasserstein,
        "multivariate_mmd": MultivariateMMD,
    }

    if metric_names is None or "all" in metric_names:
        return list(registry.values())

    # Preserve the caller's ordering (and any duplicates) exactly.
    return [registry[name] for name in metric_names if name in registry]
190
+
191
+
192
def main(args: Optional[List[str]] = None):
    """Run the full CLI pipeline: parse args, load data, evaluate, plot.

    Parameters
    ----------
    args : list of str, optional
        Argument vector to parse; ``None`` means ``sys.argv[1:]``.

    Returns
    -------
    The results object produced by ``GeneEvalEvaluator.evaluate``.

    Notes
    -----
    Exits the process with status 1 when either input file is missing or
    data loading raises; plot failures are reported but non-fatal.
    """
    parser = create_parser()
    parsed = parser.parse_args(args)

    # Set verbosity
    # NOTE(review): --verbose wins over --quiet when both flags are given
    # — confirm that precedence is intended.
    verbose = not parsed.quiet
    if parsed.verbose:
        verbose = True

    # Validate paths before doing any heavy imports or work.
    real_path = Path(parsed.real)
    gen_path = Path(parsed.generated)
    output_dir = Path(parsed.output)

    if not real_path.exists():
        print(f"Error: Real data file not found: {real_path}", file=sys.stderr)
        sys.exit(1)

    if not gen_path.exists():
        print(f"Error: Generated data file not found: {gen_path}", file=sys.stderr)
        sys.exit(1)

    # Import here to avoid slow startup (keeps `geneval --help` fast).
    from .data.loader import load_data
    from .evaluator import GeneEvalEvaluator
    from .visualization.visualizer import EvaluationVisualizer

    if verbose:
        print("=" * 60)
        print("GenEval: Gene Expression Evaluation")
        print("=" * 60)
        print(f"\nReal data: {real_path}")
        print(f"Generated data: {gen_path}")
        print(f"Conditions: {parsed.conditions}")
        print(f"Output: {output_dir}")
        if parsed.split_column:
            print(f"Split column: {parsed.split_column}")
        print()

    # Load data
    if verbose:
        print("Loading data...")

    try:
        loader = load_data(
            real_path=real_path,
            generated_path=gen_path,
            condition_columns=parsed.conditions,
            split_column=parsed.split_column,
            min_samples_per_condition=parsed.min_samples,
        )
    except Exception as e:
        # Broad catch is deliberate at this CLI boundary: report and exit.
        print(f"Error loading data: {e}", file=sys.stderr)
        sys.exit(1)

    if verbose:
        summary = loader.summary()
        print(f" Real: {summary['real']['n_samples']} samples x {summary['real']['n_genes']} genes")
        print(f" Generated: {summary['generated']['n_samples']} samples x {summary['generated']['n_genes']} genes")
        print(f" Common genes: {summary.get('n_common_genes', 'N/A')}")
        print(f" Splits: {summary.get('splits', ['all'])}")
        print()

    # Get metrics
    metric_classes = get_metric_classes(parsed.metrics)

    # Determine if multivariate metrics should be included
    include_multivariate = (
        parsed.metrics is None or
        "all" in parsed.metrics or
        any(m.startswith("multivariate") for m in (parsed.metrics or []))
    )

    # Create evaluator
    evaluator = GeneEvalEvaluator(
        data_loader=loader,
        metrics=metric_classes,
        aggregate_method=parsed.aggregate,
        include_multivariate=include_multivariate,
        verbose=verbose,
    )

    # Run evaluation
    if verbose:
        print("Running evaluation...")

    results = evaluator.evaluate(
        splits=parsed.splits,
        save_dir=output_dir,
    )

    # Generate plots
    if not parsed.no_plots:
        if verbose:
            print("\nGenerating visualizations...")

        plot_dir = output_dir / "plots"

        try:
            viz = EvaluationVisualizer(results, dpi=parsed.dpi)

            # Determine embedding methods
            embedding_methods = parsed.embedding
            if "none" in embedding_methods:
                embedding_methods = []
            elif "both" in embedding_methods:
                embedding_methods = ["pca", "umap"]

            # NOTE(review): embedding_methods only gates whether the loader
            # is passed below; confirm save_all itself applies the chosen
            # pca/umap methods — the selection is never forwarded.
            viz.save_all(
                output_dir=plot_dir,
                formats=parsed.plot_formats,
                data_loader=loader if embedding_methods else None,
            )
        except Exception as e:
            # Plot failures are non-fatal; evaluation already completed.
            print(f"Warning: Failed to generate some plots: {e}", file=sys.stderr)

    # Print final summary
    if verbose:
        print("\n" + "=" * 60)
        print("RESULTS SAVED")
        print("=" * 60)
        print(f"\nOutput directory: {output_dir}")
        print("\nFiles generated:")
        print(f" - summary.json: Aggregate metrics and metadata")
        print(f" - results.csv: Per-condition metrics")
        # NOTE(review): save_per_gene is only reflected in this summary
        # text and never passed to the evaluator — confirm the evaluator
        # actually honors it when writing files.
        if parsed.save_per_gene:
            print(f" - per_gene_*.csv: Per-gene metric values")
        if not parsed.no_plots:
            print(f" - plots/: Visualization figures")
        print()

    return results
325
+
326
+
327
def run():
    """Console-script entry point: wraps :func:`main` and discards its return value."""
    main()


if __name__ == "__main__":
    run()
geneval/config.py ADDED
@@ -0,0 +1,141 @@
1
+ """
2
+ Configuration settings for GenEval.
3
+
4
+ Provides centralized configuration for metrics, paths, and defaults.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ from pathlib import Path
9
+ from dataclasses import dataclass, field
10
+ from typing import List, Dict, Any, Optional
11
+
12
+
13
@dataclass
class MetricConfig:
    """Configuration for metric computation."""

    # Default metrics to compute when the caller does not specify any.
    default_metrics: List[str] = field(default_factory=lambda: [
        "pearson",
        "spearman",
        "mean_pearson",
        "mean_spearman",
        "wasserstein_1",
        "wasserstein_2",
        "mmd",
        "energy",
    ])

    # Whether to include multivariate metrics (in addition to per-gene ones)
    include_multivariate: bool = True

    # Aggregation method for per-gene metrics ("mean", "median", or "std",
    # matching the CLI --aggregate choices)
    aggregate_method: str = "mean"

    # Wasserstein parameters
    # NOTE(review): presumably an entropic-regularization blur for a
    # Sinkhorn-style solver — confirm against the distance implementation.
    wasserstein_blur: float = 0.01

    # MMD parameters
    mmd_kernel: str = "rbf"
    mmd_sigma: Optional[float] = None  # None = median heuristic
41
+
42
+
43
@dataclass
class DataConfig:
    """Configuration for data loading."""

    # Minimum samples per condition; conditions below this are excluded.
    min_samples_per_condition: int = 2

    # Default split column name looked up in the dataset metadata.
    default_split_column: str = "split"

    # Standard split values
    # Labels recognized as the training partition.
    train_split_values: List[str] = field(default_factory=lambda: ["train", "training"])
    # Labels recognized as the held-out partition (test or validation).
    test_split_values: List[str] = field(default_factory=lambda: ["test", "testing", "val", "validation"])
56
+
57
+
58
@dataclass
class PlotConfig:
    """Configuration for plotting."""

    # Figure resolution in dots per inch.
    dpi: int = 150

    # Default figure sizes — presumably (width, height) in inches per
    # matplotlib convention; confirm against the visualizer.
    figure_small: tuple[int, int] = (8, 6)
    figure_medium: tuple[int, int] = (12, 8)
    figure_large: tuple[int, int] = (16, 12)
    figure_wide: tuple[int, int] = (16, 6)

    # Style settings — "whitegrid"/"paper" look like seaborn style and
    # context names; confirm where they are applied.
    style: str = "whitegrid"
    context: str = "paper"
    font_scale: float = 1.2

    # Colors (hex) used to distinguish the two datasets in plots.
    real_color: str = "#1f77b4"  # Blue
    generated_color: str = "#ff7f0e"  # Orange

    # Output formats for saved figures.
    default_formats: List[str] = field(default_factory=lambda: ["png", "pdf"])
82
+
83
+
84
@dataclass
class Config:
    """
    Main configuration class for GenEval.

    Aggregates metric, data, and plotting settings together with a few
    top-level output options.
    """
    metrics: MetricConfig = field(default_factory=MetricConfig)
    data: DataConfig = field(default_factory=DataConfig)
    plot: PlotConfig = field(default_factory=PlotConfig)

    # Output settings
    output_dir: Path = Path("output/")
    log_dir: Path = Path("logs/")

    # Verbosity
    verbose: bool = True

    @classmethod
    def default(cls) -> "Config":
        """Return a configuration populated entirely with defaults."""
        return cls()

    def to_dict(self) -> Dict[str, Any]:
        """Serialize a JSON-friendly subset of the configuration to a dict."""
        # Each sub-config contributes its own section; only the commonly
        # persisted fields are included (not every attribute).
        metrics_section = {
            "default_metrics": self.metrics.default_metrics,
            "include_multivariate": self.metrics.include_multivariate,
            "aggregate_method": self.metrics.aggregate_method,
        }
        data_section = {
            "min_samples_per_condition": self.data.min_samples_per_condition,
            "default_split_column": self.data.default_split_column,
        }
        plot_section = {
            "dpi": self.plot.dpi,
            "style": self.plot.style,
            "default_formats": self.plot.default_formats,
        }
        return {
            "metrics": metrics_section,
            "data": data_section,
            "plot": plot_section,
            "output_dir": str(self.output_dir),
            "verbose": self.verbose,
        }
127
+
128
+
129
# Global default config instance (module-level singleton).
DEFAULT_CONFIG = Config.default()


def get_config() -> Config:
    """Get the current configuration.

    Prefer this accessor over importing ``DEFAULT_CONFIG`` directly:
    ``set_config`` rebinds the global, so a from-import would go stale.
    """
    return DEFAULT_CONFIG


def set_config(config: Config):
    """Set the global configuration used by subsequent ``get_config`` calls."""
    global DEFAULT_CONFIG
    DEFAULT_CONFIG = config
geneval/core.py ADDED
@@ -0,0 +1,41 @@
1
+ from abc import ABC, abstractmethod
2
+
3
class BaseEvaluator(ABC):
    """Common interface for evaluators in the gene expression
    evaluation system.

    The constructor simply records the two inputs on the instance;
    concrete subclasses must implement :meth:`evaluate`.
    """

    def __init__(self, data, output):
        # Stored as-is; subclasses decide how to interpret them.
        self.data = data
        self.output = output

    @abstractmethod
    def evaluate(self, *args, **kwargs):
        """Compute evaluation results from ``self.data`` and ``self.output``.

        Must be overridden by concrete subclasses.
        """
        ...
19
+
20
+
21
class GeneExpressionEvaluator(BaseEvaluator):
    """
    Evaluator for gene expression data.

    Computes various metrics between real and generated gene expression
    profiles, optionally adjusting for control conditions and covariates.

    Parameters
    ----------
    data : GeneExpressionDataModule
        The data module containing gene expression datasets.
    output : AnnData
        The generated gene expression data to evaluate.
    """

    def __init__(self, data, output):
        # No extra state beyond what the base class records.
        super().__init__(data, output)

    def evaluate(self, delta=False, plot=False, DEG=None):
        """Run the evaluation (placeholder — not yet implemented).

        Parameters are accepted but currently unused: ``delta``, ``plot``
        and ``DEG`` are reserved for the forthcoming implementation.
        """
        # TODO: evaluation logic (metrics, optional delta/DEG handling).
        pass
@@ -0,0 +1,23 @@
1
+ """
2
+ Data loading module for gene expression evaluation.
3
+
4
+ Provides data loaders for paired real and generated datasets.
5
+ """
6
+
7
+ from .loader import (
8
+ GeneExpressionDataLoader,
9
+ load_data,
10
+ DataLoaderError,
11
+ )
12
+ from .gene_expression_datamodule import (
13
+ GeneExpressionDataModule,
14
+ DataModuleError,
15
+ )
16
+
17
+ __all__ = [
18
+ "GeneExpressionDataLoader",
19
+ "load_data",
20
+ "DataLoaderError",
21
+ "GeneExpressionDataModule",
22
+ "DataModuleError",
23
+ ]