gengeneeval 0.1.1__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. {gengeneeval-0.1.1 → gengeneeval-0.2.0}/PKG-INFO +6 -2
  2. {gengeneeval-0.1.1 → gengeneeval-0.2.0}/README.md +5 -1
  3. {gengeneeval-0.1.1 → gengeneeval-0.2.0}/pyproject.toml +1 -1
  4. {gengeneeval-0.1.1 → gengeneeval-0.2.0}/src/geneval/__init__.py +7 -1
  5. {gengeneeval-0.1.1 → gengeneeval-0.2.0}/src/geneval/evaluator.py +4 -0
  6. {gengeneeval-0.1.1 → gengeneeval-0.2.0}/src/geneval/metrics/__init__.py +19 -0
  7. gengeneeval-0.2.0/src/geneval/metrics/reconstruction.py +243 -0
  8. {gengeneeval-0.1.1 → gengeneeval-0.2.0}/LICENSE +0 -0
  9. {gengeneeval-0.1.1 → gengeneeval-0.2.0}/src/geneval/cli.py +0 -0
  10. {gengeneeval-0.1.1 → gengeneeval-0.2.0}/src/geneval/config.py +0 -0
  11. {gengeneeval-0.1.1 → gengeneeval-0.2.0}/src/geneval/core.py +0 -0
  12. {gengeneeval-0.1.1 → gengeneeval-0.2.0}/src/geneval/data/__init__.py +0 -0
  13. {gengeneeval-0.1.1 → gengeneeval-0.2.0}/src/geneval/data/gene_expression_datamodule.py +0 -0
  14. {gengeneeval-0.1.1 → gengeneeval-0.2.0}/src/geneval/data/loader.py +0 -0
  15. {gengeneeval-0.1.1 → gengeneeval-0.2.0}/src/geneval/evaluators/__init__.py +0 -0
  16. {gengeneeval-0.1.1 → gengeneeval-0.2.0}/src/geneval/evaluators/base_evaluator.py +0 -0
  17. {gengeneeval-0.1.1 → gengeneeval-0.2.0}/src/geneval/evaluators/gene_expression_evaluator.py +0 -0
  18. {gengeneeval-0.1.1 → gengeneeval-0.2.0}/src/geneval/metrics/base_metric.py +0 -0
  19. {gengeneeval-0.1.1 → gengeneeval-0.2.0}/src/geneval/metrics/correlation.py +0 -0
  20. {gengeneeval-0.1.1 → gengeneeval-0.2.0}/src/geneval/metrics/distances.py +0 -0
  21. {gengeneeval-0.1.1 → gengeneeval-0.2.0}/src/geneval/metrics/metrics.py +0 -0
  22. {gengeneeval-0.1.1 → gengeneeval-0.2.0}/src/geneval/models/__init__.py +0 -0
  23. {gengeneeval-0.1.1 → gengeneeval-0.2.0}/src/geneval/models/base_model.py +0 -0
  24. {gengeneeval-0.1.1 → gengeneeval-0.2.0}/src/geneval/results.py +0 -0
  25. {gengeneeval-0.1.1 → gengeneeval-0.2.0}/src/geneval/testing.py +0 -0
  26. {gengeneeval-0.1.1 → gengeneeval-0.2.0}/src/geneval/utils/__init__.py +0 -0
  27. {gengeneeval-0.1.1 → gengeneeval-0.2.0}/src/geneval/utils/io.py +0 -0
  28. {gengeneeval-0.1.1 → gengeneeval-0.2.0}/src/geneval/utils/preprocessing.py +0 -0
  29. {gengeneeval-0.1.1 → gengeneeval-0.2.0}/src/geneval/visualization/__init__.py +0 -0
  30. {gengeneeval-0.1.1 → gengeneeval-0.2.0}/src/geneval/visualization/plots.py +0 -0
  31. {gengeneeval-0.1.1 → gengeneeval-0.2.0}/src/geneval/visualization/visualizer.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: gengeneeval
3
- Version: 0.1.1
3
+ Version: 0.2.0
4
4
  Summary: Comprehensive evaluation of generated gene expression data. Computes metrics between real and generated datasets with support for condition matching, train/test splits, and publication-quality visualizations.
5
5
  License: MIT
6
6
  License-File: LICENSE
@@ -42,7 +42,7 @@ Description-Content-Type: text/markdown
42
42
  [![PyPI version](https://badge.fury.io/py/gengeneeval.svg)](https://badge.fury.io/py/gengeneeval)
43
43
  [![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
44
44
  [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
45
- [![Tests](https://github.com/AndreaRubbi/GenGeneEval/actions/workflows/tests.yml/badge.svg)](https://github.com/AndreaRubbi/GenGeneEval/actions)
45
+ [![Tests](https://github.com/AndreaRubbi/GenGeneEval/actions/workflows/test.yml/badge.svg)](https://github.com/AndreaRubbi/GenGeneEval/actions)
46
46
 
47
47
  **Comprehensive evaluation of generated gene expression data against real datasets.**
48
48
 
@@ -55,6 +55,10 @@ All metrics are computed **per-gene** (returning a vector) and **aggregated**:
55
55
 
56
56
  | Metric | Description | Direction |
57
57
  |--------|-------------|-----------|
58
+ | **MSE** | Mean Squared Error | Lower is better |
59
+ | **RMSE** | Root Mean Squared Error | Lower is better |
60
+ | **MAE** | Mean Absolute Error | Lower is better |
61
+ | **R²** | Coefficient of Determination | Higher is better |
58
62
  | **Pearson Correlation** | Linear correlation between expression profiles | Higher is better |
59
63
  | **Spearman Correlation** | Rank correlation (robust to outliers) | Higher is better |
60
64
  | **Wasserstein-1** | Earth Mover's Distance (L1) | Lower is better |
@@ -3,7 +3,7 @@
3
3
  [![PyPI version](https://badge.fury.io/py/gengeneeval.svg)](https://badge.fury.io/py/gengeneeval)
4
4
  [![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
5
5
  [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
6
- [![Tests](https://github.com/AndreaRubbi/GenGeneEval/actions/workflows/tests.yml/badge.svg)](https://github.com/AndreaRubbi/GenGeneEval/actions)
6
+ [![Tests](https://github.com/AndreaRubbi/GenGeneEval/actions/workflows/test.yml/badge.svg)](https://github.com/AndreaRubbi/GenGeneEval/actions)
7
7
 
8
8
  **Comprehensive evaluation of generated gene expression data against real datasets.**
9
9
 
@@ -16,6 +16,10 @@ All metrics are computed **per-gene** (returning a vector) and **aggregated**:
16
16
 
17
17
  | Metric | Description | Direction |
18
18
  |--------|-------------|-----------|
19
+ | **MSE** | Mean Squared Error | Lower is better |
20
+ | **RMSE** | Root Mean Squared Error | Lower is better |
21
+ | **MAE** | Mean Absolute Error | Lower is better |
22
+ | **R²** | Coefficient of Determination | Higher is better |
19
23
  | **Pearson Correlation** | Linear correlation between expression profiles | Higher is better |
20
24
  | **Spearman Correlation** | Rank correlation (robust to outliers) | Higher is better |
21
25
  | **Wasserstein-1** | Earth Mover's Distance (L1) | Lower is better |
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "gengeneeval"
3
- version = "0.1.1"
3
+ version = "0.2.0"
4
4
  description = "Comprehensive evaluation of generated gene expression data. Computes metrics between real and generated datasets with support for condition matching, train/test splits, and publication-quality visualizations."
5
5
  authors = ["GenEval Team <geneval@example.com>"]
6
6
  license = "MIT"
@@ -25,7 +25,7 @@ CLI Usage:
25
25
  --conditions perturbation cell_type --output results/
26
26
  """
27
27
 
28
- __version__ = "0.1.1"
28
+ __version__ = "0.2.0"
29
29
  __author__ = "GenEval Team"
30
30
 
31
31
  # Main evaluation interface
@@ -69,6 +69,12 @@ from .metrics.distances import (
69
69
  MultivariateWasserstein,
70
70
  MultivariateMMD,
71
71
  )
72
+ from .metrics.reconstruction import (
73
+ MSEDistance,
74
+ RMSEDistance,
75
+ MAEDistance,
76
+ R2Score,
77
+ )
72
78
 
73
79
  # Visualization
74
80
  from .visualization.visualizer import (
@@ -27,11 +27,15 @@ from .metrics.distances import (
27
27
  MultivariateWasserstein,
28
28
  MultivariateMMD,
29
29
  )
30
+ from .metrics.reconstruction import (
31
+ MSEDistance,
32
+ )
30
33
  from .results import EvaluationResult, SplitResult, ConditionResult
31
34
 
32
35
 
33
36
  # Default metrics to compute
34
37
  DEFAULT_METRICS = [
38
+ MSEDistance,
35
39
  PearsonCorrelation,
36
40
  SpearmanCorrelation,
37
41
  MeanPearsonCorrelation,
@@ -2,6 +2,7 @@
2
2
  Metrics module for gene expression evaluation.
3
3
 
4
4
  Provides per-gene and aggregate metrics for comparing distributions:
5
+ - Reconstruction metrics (MSE, RMSE, MAE, R²)
5
6
  - Correlation metrics (Pearson, Spearman)
6
7
  - Distribution distances (Wasserstein, MMD, Energy)
7
8
  - Multivariate distances
@@ -27,13 +28,26 @@ from .distances import (
27
28
  MultivariateWasserstein,
28
29
  MultivariateMMD,
29
30
  )
31
+ from .reconstruction import (
32
+ MSEDistance,
33
+ RMSEDistance,
34
+ MAEDistance,
35
+ R2Score,
36
+ )
30
37
 
31
38
  # All available metrics
32
39
  ALL_METRICS = [
40
+ # Reconstruction
41
+ MSEDistance,
42
+ RMSEDistance,
43
+ MAEDistance,
44
+ R2Score,
45
+ # Correlation
33
46
  PearsonCorrelation,
34
47
  SpearmanCorrelation,
35
48
  MeanPearsonCorrelation,
36
49
  MeanSpearmanCorrelation,
50
+ # Distribution
37
51
  Wasserstein1Distance,
38
52
  Wasserstein2Distance,
39
53
  MMDDistance,
@@ -48,6 +62,11 @@ __all__ = [
48
62
  "MetricResult",
49
63
  "DistributionMetric",
50
64
  "CorrelationMetric",
65
+ # Reconstruction metrics
66
+ "MSEDistance",
67
+ "RMSEDistance",
68
+ "MAEDistance",
69
+ "R2Score",
51
70
  # Correlation metrics
52
71
  "PearsonCorrelation",
53
72
  "SpearmanCorrelation",
@@ -0,0 +1,243 @@
1
+ """
2
+ Reconstruction metrics for gene expression evaluation.
3
+
4
+ Provides MSE (Mean Squared Error) and related reconstruction quality metrics.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import numpy as np
9
+ from typing import Optional
10
+
11
+ from .base_metric import BaseMetric
12
+
13
+
14
def _ensure_2d(arr: np.ndarray) -> np.ndarray:
    """
    Coerce input to a 2D float64 array of shape (n_samples, n_genes).

    Parameters
    ----------
    arr : array-like
        Expression values. A scalar is treated as one sample of one gene,
        a 1D array as ``n`` samples of a single gene, and a 2D array is
        returned unchanged apart from the float64 conversion.

    Returns
    -------
    np.ndarray
        2D float64 array.

    Raises
    ------
    ValueError
        If ``arr`` has more than two dimensions.
    """
    arr = np.asarray(arr, dtype=np.float64)
    if arr.ndim == 0:
        # A bare scalar would otherwise break ``arr.shape[1]`` in callers.
        return arr.reshape(1, 1)
    if arr.ndim == 1:
        # Interpret a vector as a single gene observed across samples.
        return arr.reshape(-1, 1)
    if arr.ndim > 2:
        raise ValueError(
            "Expected at most 2 dimensions (samples x genes), "
            f"got array with shape {arr.shape}"
        )
    return arr
20
+
21
+
22
class MSEDistance(BaseMetric):
    """
    Mean Squared Error (MSE) between real and generated data.

    Two modes are supported:

    * **Paired** (sample sizes match and ``compare_means`` is False):
      row ``k`` of ``real`` is compared with row ``k`` of ``generated``,
      so rows are assumed to correspond one-to-one. Per gene, any pair in
      which either value is NaN is skipped.
    * **Mean profile** (``compare_means`` is True or the sample sizes
      differ): the squared difference between the per-gene means of the
      two datasets, with NaNs ignored when computing each mean.

    Lower values indicate better reconstruction.

    Parameters
    ----------
    compare_means : bool
        If True, always compare mean profiles regardless of sample sizes.
        If False, compute paired (sample-wise) MSE when sizes match.

    Examples
    --------
    >>> mse = MSEDistance()
    >>> result = mse.compute(real_data, generated_data, gene_names)
    >>> print(f"MSE: {result.aggregate_value:.4f}")
    """

    def __init__(self, compare_means: bool = False):
        super().__init__(
            name="mse",
            description="Mean Squared Error per gene",
            higher_is_better=False,
            requires_distribution=True,
        )
        self.compare_means = compare_means

    def compute_per_gene(
        self,
        real: np.ndarray,
        generated: np.ndarray,
    ) -> np.ndarray:
        """
        Compute MSE for each gene.

        Parameters
        ----------
        real : np.ndarray
            Real data, shape (n_samples_real, n_genes). 1D input is treated
            as a single gene.
        generated : np.ndarray
            Generated data, shape (n_samples_gen, n_genes).

        Returns
        -------
        np.ndarray
            MSE per gene, shape (n_genes,). A gene is NaN when it has no
            usable (non-NaN) values.

        Raises
        ------
        ValueError
            If ``real`` and ``generated`` have different numbers of genes.
        """
        real = _ensure_2d(real)
        generated = _ensure_2d(generated)
        if real.shape[1] != generated.shape[1]:
            # Without this check numpy would silently broadcast (e.g. 1 gene
            # against many) or the paired path would index out of range.
            raise ValueError(
                "real and generated must have the same number of genes, "
                f"got {real.shape[1]} and {generated.shape[1]}"
            )

        if self.compare_means or real.shape[0] != generated.shape[0]:
            # Rows cannot be paired: compare NaN-ignoring mean profiles so
            # this mode treats missing values like the paired mode does.
            with np.errstate(invalid="ignore", divide="ignore"):
                real_mean = np.nansum(real, axis=0) / (~np.isnan(real)).sum(axis=0)
                gen_mean = (
                    np.nansum(generated, axis=0)
                    / (~np.isnan(generated)).sum(axis=0)
                )
            return (real_mean - gen_mean) ** 2

        # Paired comparison: skip any (real, generated) pair containing a NaN.
        valid = ~(np.isnan(real) | np.isnan(generated))
        sq_err = np.where(valid, (real - generated) ** 2, 0.0)
        with np.errstate(invalid="ignore", divide="ignore"):
            # Genes with no valid pairs become 0 / 0 = NaN.
            return sq_err.sum(axis=0) / valid.sum(axis=0)
99
+
100
+
101
class RMSEDistance(BaseMetric):
    """
    Root Mean Squared Error (RMSE) between real and generated data.

    Computed as the square root of the per-gene ``MSEDistance`` result, so
    it is expressed in the same units as the original data. Lower values
    indicate better reconstruction.

    Parameters
    ----------
    compare_means : bool
        Forwarded to the underlying ``MSEDistance``: if True, always compare
        mean profiles; if False, compute paired MSE when sample sizes match.
    """

    def __init__(self, compare_means: bool = False):
        super().__init__(
            name="rmse",
            description="Root Mean Squared Error per gene",
            higher_is_better=False,
            requires_distribution=True,
        )
        self.compare_means = compare_means
        self._mse = MSEDistance(compare_means=compare_means)

    def compute_per_gene(
        self,
        real: np.ndarray,
        generated: np.ndarray,
    ) -> np.ndarray:
        """
        Compute RMSE for each gene.

        Parameters
        ----------
        real : np.ndarray
            Real data, shape (n_samples_real, n_genes).
        generated : np.ndarray
            Generated data, shape (n_samples_gen, n_genes).

        Returns
        -------
        np.ndarray
            RMSE per gene, shape (n_genes,); NaN wherever the MSE is NaN.
        """
        # Keep the delegate in sync: ``compare_means`` is a public attribute
        # and may have been changed after construction.
        self._mse.compare_means = self.compare_means
        return np.sqrt(self._mse.compute_per_gene(real, generated))
129
+
130
+
131
class MAEDistance(BaseMetric):
    """
    Mean Absolute Error (MAE) between real and generated data.

    Less sensitive to large individual deviations than MSE because errors
    are not squared. Two modes are supported:

    * **Paired** (sample sizes match and ``compare_means`` is False):
      row ``k`` of ``real`` is compared with row ``k`` of ``generated``;
      per gene, any pair in which either value is NaN is skipped.
    * **Mean profile** (otherwise): absolute difference of the per-gene
      means, with NaNs ignored when computing each mean.

    Lower values indicate better reconstruction.

    Parameters
    ----------
    compare_means : bool
        If True, always compare mean profiles regardless of sample sizes.
        If False, compute paired (sample-wise) MAE when sizes match.
    """

    def __init__(self, compare_means: bool = False):
        super().__init__(
            name="mae",
            description="Mean Absolute Error per gene",
            higher_is_better=False,
            requires_distribution=True,
        )
        self.compare_means = compare_means

    def compute_per_gene(
        self,
        real: np.ndarray,
        generated: np.ndarray,
    ) -> np.ndarray:
        """
        Compute MAE for each gene.

        Parameters
        ----------
        real : np.ndarray
            Real data, shape (n_samples_real, n_genes). 1D input is treated
            as a single gene.
        generated : np.ndarray
            Generated data, shape (n_samples_gen, n_genes).

        Returns
        -------
        np.ndarray
            MAE per gene, shape (n_genes,). A gene is NaN when it has no
            usable (non-NaN) values.

        Raises
        ------
        ValueError
            If ``real`` and ``generated`` have different numbers of genes.
        """
        real = _ensure_2d(real)
        generated = _ensure_2d(generated)
        if real.shape[1] != generated.shape[1]:
            # Prevent silent broadcasting / out-of-range indexing on mismatch.
            raise ValueError(
                "real and generated must have the same number of genes, "
                f"got {real.shape[1]} and {generated.shape[1]}"
            )

        if self.compare_means or real.shape[0] != generated.shape[0]:
            # Rows cannot be paired: compare NaN-ignoring mean profiles.
            with np.errstate(invalid="ignore", divide="ignore"):
                real_mean = np.nansum(real, axis=0) / (~np.isnan(real)).sum(axis=0)
                gen_mean = (
                    np.nansum(generated, axis=0)
                    / (~np.isnan(generated)).sum(axis=0)
                )
            return np.abs(real_mean - gen_mean)

        # Paired comparison: skip any (real, generated) pair containing a NaN.
        valid = ~(np.isnan(real) | np.isnan(generated))
        abs_err = np.where(valid, np.abs(real - generated), 0.0)
        with np.errstate(invalid="ignore", divide="ignore"):
            # Genes with no valid pairs become 0 / 0 = NaN.
            return abs_err.sum(axis=0) / valid.sum(axis=0)
179
+
180
+
181
class R2Score(BaseMetric):
    """
    Coefficient of Determination (R²) between real and generated data.

    Measures the proportion of variance in ``real`` explained by
    ``generated``. 1 is a perfect fit, 0 is no better than predicting the
    real mean, and negative values are worse than that.

    * **Paired** (equal sample sizes): per gene,
      ``1 - sum((real - gen)**2) / sum((real - mean(real))**2)``, using only
      rows where neither value is NaN.
    * **Mean profile** (sample sizes differ): per gene,
      ``1 - (mean(real) - mean(gen))**2 / var(real)`` with NaNs ignored and
      the population variance (ddof=0).

    When the real values of a gene have (near-)zero spread the ratio is
    undefined; the score is then 1.0 if the residual is also ~0 and 0.0
    otherwise, in both modes. Genes with no usable values score NaN.

    Higher values indicate better reconstruction.
    """

    def __init__(self):
        super().__init__(
            name="r2",
            description="R² (coefficient of determination) per gene",
            higher_is_better=True,
            requires_distribution=True,
        )

    def compute_per_gene(
        self,
        real: np.ndarray,
        generated: np.ndarray,
    ) -> np.ndarray:
        """
        Compute R² for each gene.

        Parameters
        ----------
        real : np.ndarray
            Real data, shape (n_samples_real, n_genes). 1D input is treated
            as a single gene.
        generated : np.ndarray
            Generated data, shape (n_samples_gen, n_genes).

        Returns
        -------
        np.ndarray
            R² per gene, shape (n_genes,).

        Raises
        ------
        ValueError
            If ``real`` and ``generated`` have different numbers of genes.
        """
        real = _ensure_2d(real)
        generated = _ensure_2d(generated)
        if real.shape[1] != generated.shape[1]:
            raise ValueError(
                "real and generated must have the same number of genes, "
                f"got {real.shape[1]} and {generated.shape[1]}"
            )
        eps = 1e-10

        with np.errstate(invalid="ignore", divide="ignore"):
            if real.shape[0] != generated.shape[0]:
                # Rows cannot be paired: compare mean profiles, normalised by
                # the variance of the real data. NaNs are ignored per array.
                real_ok = ~np.isnan(real)
                gen_ok = ~np.isnan(generated)
                n_real = real_ok.sum(axis=0)
                n_gen = gen_ok.sum(axis=0)
                real_mean = np.where(real_ok, real, 0.0).sum(axis=0) / n_real
                gen_mean = np.where(gen_ok, generated, 0.0).sum(axis=0) / n_gen
                centered = np.where(real_ok, real - real_mean, 0.0)
                ss_tot = (centered ** 2).sum(axis=0) / n_real
                ss_res = (real_mean - gen_mean) ** 2
                no_data = (n_real == 0) | (n_gen == 0)
            else:
                # Paired rows: drop any pair containing a NaN.
                valid = ~(np.isnan(real) | np.isnan(generated))
                n_valid = valid.sum(axis=0)
                real_mean = np.where(valid, real, 0.0).sum(axis=0) / n_valid
                centered = np.where(valid, real - real_mean, 0.0)
                ss_tot = (centered ** 2).sum(axis=0)
                ss_res = np.where(valid, (real - generated) ** 2, 0.0).sum(axis=0)
                no_data = n_valid == 0

            r2 = 1.0 - ss_res / ss_tot
            # Constant real values make the ratio meaningless; mirror
            # scikit-learn's ``r2_score(force_finite=True)`` convention.
            constant = ss_tot < eps
            r2 = np.where(constant, np.where(ss_res < eps, 1.0, 0.0), r2)

        r2[no_data] = np.nan
        return r2
File without changes