gengeneeval 0.1.0__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {gengeneeval-0.1.0 → gengeneeval-0.2.0}/PKG-INFO +6 -2
- {gengeneeval-0.1.0 → gengeneeval-0.2.0}/README.md +5 -1
- {gengeneeval-0.1.0 → gengeneeval-0.2.0}/pyproject.toml +1 -1
- {gengeneeval-0.1.0 → gengeneeval-0.2.0}/src/geneval/__init__.py +7 -1
- {gengeneeval-0.1.0 → gengeneeval-0.2.0}/src/geneval/evaluator.py +4 -0
- {gengeneeval-0.1.0 → gengeneeval-0.2.0}/src/geneval/metrics/__init__.py +19 -0
- {gengeneeval-0.1.0 → gengeneeval-0.2.0}/src/geneval/metrics/metrics.py +18 -0
- gengeneeval-0.2.0/src/geneval/metrics/reconstruction.py +243 -0
- {gengeneeval-0.1.0 → gengeneeval-0.2.0}/src/geneval/visualization/visualizer.py +5 -1
- {gengeneeval-0.1.0 → gengeneeval-0.2.0}/LICENSE +0 -0
- {gengeneeval-0.1.0 → gengeneeval-0.2.0}/src/geneval/cli.py +0 -0
- {gengeneeval-0.1.0 → gengeneeval-0.2.0}/src/geneval/config.py +0 -0
- {gengeneeval-0.1.0 → gengeneeval-0.2.0}/src/geneval/core.py +0 -0
- {gengeneeval-0.1.0 → gengeneeval-0.2.0}/src/geneval/data/__init__.py +0 -0
- {gengeneeval-0.1.0 → gengeneeval-0.2.0}/src/geneval/data/gene_expression_datamodule.py +0 -0
- {gengeneeval-0.1.0 → gengeneeval-0.2.0}/src/geneval/data/loader.py +0 -0
- {gengeneeval-0.1.0 → gengeneeval-0.2.0}/src/geneval/evaluators/__init__.py +0 -0
- {gengeneeval-0.1.0 → gengeneeval-0.2.0}/src/geneval/evaluators/base_evaluator.py +0 -0
- {gengeneeval-0.1.0 → gengeneeval-0.2.0}/src/geneval/evaluators/gene_expression_evaluator.py +0 -0
- {gengeneeval-0.1.0 → gengeneeval-0.2.0}/src/geneval/metrics/base_metric.py +0 -0
- {gengeneeval-0.1.0 → gengeneeval-0.2.0}/src/geneval/metrics/correlation.py +0 -0
- {gengeneeval-0.1.0 → gengeneeval-0.2.0}/src/geneval/metrics/distances.py +0 -0
- {gengeneeval-0.1.0 → gengeneeval-0.2.0}/src/geneval/models/__init__.py +0 -0
- {gengeneeval-0.1.0 → gengeneeval-0.2.0}/src/geneval/models/base_model.py +0 -0
- {gengeneeval-0.1.0 → gengeneeval-0.2.0}/src/geneval/results.py +0 -0
- {gengeneeval-0.1.0 → gengeneeval-0.2.0}/src/geneval/testing.py +0 -0
- {gengeneeval-0.1.0 → gengeneeval-0.2.0}/src/geneval/utils/__init__.py +0 -0
- {gengeneeval-0.1.0 → gengeneeval-0.2.0}/src/geneval/utils/io.py +0 -0
- {gengeneeval-0.1.0 → gengeneeval-0.2.0}/src/geneval/utils/preprocessing.py +0 -0
- {gengeneeval-0.1.0 → gengeneeval-0.2.0}/src/geneval/visualization/__init__.py +0 -0
- {gengeneeval-0.1.0 → gengeneeval-0.2.0}/src/geneval/visualization/plots.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: gengeneeval
|
|
3
|
-
Version: 0.1.0
|
|
3
|
+
Version: 0.2.0
|
|
4
4
|
Summary: Comprehensive evaluation of generated gene expression data. Computes metrics between real and generated datasets with support for condition matching, train/test splits, and publication-quality visualizations.
|
|
5
5
|
License: MIT
|
|
6
6
|
License-File: LICENSE
|
|
@@ -42,7 +42,7 @@ Description-Content-Type: text/markdown
|
|
|
42
42
|
[](https://badge.fury.io/py/gengeneeval)
|
|
43
43
|
[](https://www.python.org/downloads/)
|
|
44
44
|
[](https://opensource.org/licenses/MIT)
|
|
45
|
-
[](https://github.com/AndreaRubbi/GenGeneEval/actions)
|
|
46
46
|
|
|
47
47
|
**Comprehensive evaluation of generated gene expression data against real datasets.**
|
|
48
48
|
|
|
@@ -55,6 +55,10 @@ All metrics are computed **per-gene** (returning a vector) and **aggregated**:
|
|
|
55
55
|
|
|
56
56
|
| Metric | Description | Direction |
|
|
57
57
|
|--------|-------------|-----------|
|
|
58
|
+
| **MSE** | Mean Squared Error | Lower is better |
|
|
59
|
+
| **RMSE** | Root Mean Squared Error | Lower is better |
|
|
60
|
+
| **MAE** | Mean Absolute Error | Lower is better |
|
|
61
|
+
| **R²** | Coefficient of Determination | Higher is better |
|
|
58
62
|
| **Pearson Correlation** | Linear correlation between expression profiles | Higher is better |
|
|
59
63
|
| **Spearman Correlation** | Rank correlation (robust to outliers) | Higher is better |
|
|
60
64
|
| **Wasserstein-1** | Earth Mover's Distance (L1) | Lower is better |
|
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
[](https://badge.fury.io/py/gengeneeval)
|
|
4
4
|
[](https://www.python.org/downloads/)
|
|
5
5
|
[](https://opensource.org/licenses/MIT)
|
|
6
|
-
[](https://github.com/AndreaRubbi/GenGeneEval/actions)
|
|
7
7
|
|
|
8
8
|
**Comprehensive evaluation of generated gene expression data against real datasets.**
|
|
9
9
|
|
|
@@ -16,6 +16,10 @@ All metrics are computed **per-gene** (returning a vector) and **aggregated**:
|
|
|
16
16
|
|
|
17
17
|
| Metric | Description | Direction |
|
|
18
18
|
|--------|-------------|-----------|
|
|
19
|
+
| **MSE** | Mean Squared Error | Lower is better |
|
|
20
|
+
| **RMSE** | Root Mean Squared Error | Lower is better |
|
|
21
|
+
| **MAE** | Mean Absolute Error | Lower is better |
|
|
22
|
+
| **R²** | Coefficient of Determination | Higher is better |
|
|
19
23
|
| **Pearson Correlation** | Linear correlation between expression profiles | Higher is better |
|
|
20
24
|
| **Spearman Correlation** | Rank correlation (robust to outliers) | Higher is better |
|
|
21
25
|
| **Wasserstein-1** | Earth Mover's Distance (L1) | Lower is better |
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[tool.poetry]
|
|
2
2
|
name = "gengeneeval"
|
|
3
|
-
version = "0.1.0"
|
|
3
|
+
version = "0.2.0"
|
|
4
4
|
description = "Comprehensive evaluation of generated gene expression data. Computes metrics between real and generated datasets with support for condition matching, train/test splits, and publication-quality visualizations."
|
|
5
5
|
authors = ["GenEval Team <geneval@example.com>"]
|
|
6
6
|
license = "MIT"
|
|
@@ -25,7 +25,7 @@ CLI Usage:
|
|
|
25
25
|
--conditions perturbation cell_type --output results/
|
|
26
26
|
"""
|
|
27
27
|
|
|
28
|
-
__version__ = "0.1.0"
|
|
28
|
+
__version__ = "0.2.0"
|
|
29
29
|
__author__ = "GenEval Team"
|
|
30
30
|
|
|
31
31
|
# Main evaluation interface
|
|
@@ -69,6 +69,12 @@ from .metrics.distances import (
|
|
|
69
69
|
MultivariateWasserstein,
|
|
70
70
|
MultivariateMMD,
|
|
71
71
|
)
|
|
72
|
+
from .metrics.reconstruction import (
|
|
73
|
+
MSEDistance,
|
|
74
|
+
RMSEDistance,
|
|
75
|
+
MAEDistance,
|
|
76
|
+
R2Score,
|
|
77
|
+
)
|
|
72
78
|
|
|
73
79
|
# Visualization
|
|
74
80
|
from .visualization.visualizer import (
|
|
@@ -27,11 +27,15 @@ from .metrics.distances import (
|
|
|
27
27
|
MultivariateWasserstein,
|
|
28
28
|
MultivariateMMD,
|
|
29
29
|
)
|
|
30
|
+
from .metrics.reconstruction import (
|
|
31
|
+
MSEDistance,
|
|
32
|
+
)
|
|
30
33
|
from .results import EvaluationResult, SplitResult, ConditionResult
|
|
31
34
|
|
|
32
35
|
|
|
33
36
|
# Default metrics to compute
|
|
34
37
|
DEFAULT_METRICS = [
|
|
38
|
+
MSEDistance,
|
|
35
39
|
PearsonCorrelation,
|
|
36
40
|
SpearmanCorrelation,
|
|
37
41
|
MeanPearsonCorrelation,
|
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
Metrics module for gene expression evaluation.
|
|
3
3
|
|
|
4
4
|
Provides per-gene and aggregate metrics for comparing distributions:
|
|
5
|
+
- Reconstruction metrics (MSE, RMSE, MAE, R²)
|
|
5
6
|
- Correlation metrics (Pearson, Spearman)
|
|
6
7
|
- Distribution distances (Wasserstein, MMD, Energy)
|
|
7
8
|
- Multivariate distances
|
|
@@ -27,13 +28,26 @@ from .distances import (
|
|
|
27
28
|
MultivariateWasserstein,
|
|
28
29
|
MultivariateMMD,
|
|
29
30
|
)
|
|
31
|
+
from .reconstruction import (
|
|
32
|
+
MSEDistance,
|
|
33
|
+
RMSEDistance,
|
|
34
|
+
MAEDistance,
|
|
35
|
+
R2Score,
|
|
36
|
+
)
|
|
30
37
|
|
|
31
38
|
# All available metrics
|
|
32
39
|
ALL_METRICS = [
|
|
40
|
+
# Reconstruction
|
|
41
|
+
MSEDistance,
|
|
42
|
+
RMSEDistance,
|
|
43
|
+
MAEDistance,
|
|
44
|
+
R2Score,
|
|
45
|
+
# Correlation
|
|
33
46
|
PearsonCorrelation,
|
|
34
47
|
SpearmanCorrelation,
|
|
35
48
|
MeanPearsonCorrelation,
|
|
36
49
|
MeanSpearmanCorrelation,
|
|
50
|
+
# Distribution
|
|
37
51
|
Wasserstein1Distance,
|
|
38
52
|
Wasserstein2Distance,
|
|
39
53
|
MMDDistance,
|
|
@@ -48,6 +62,11 @@ __all__ = [
|
|
|
48
62
|
"MetricResult",
|
|
49
63
|
"DistributionMetric",
|
|
50
64
|
"CorrelationMetric",
|
|
65
|
+
# Reconstruction metrics
|
|
66
|
+
"MSEDistance",
|
|
67
|
+
"RMSEDistance",
|
|
68
|
+
"MAEDistance",
|
|
69
|
+
"R2Score",
|
|
51
70
|
# Correlation metrics
|
|
52
71
|
"PearsonCorrelation",
|
|
53
72
|
"SpearmanCorrelation",
|
|
@@ -7,6 +7,24 @@ from scipy.stats import pearsonr, spearmanr
|
|
|
7
7
|
import torch
|
|
8
8
|
from . import metric_MMD
|
|
9
9
|
|
|
10
|
+
|
|
11
|
+
def scanpy_preprocessing(adata: ad.AnnData) -> ad.AnnData:
    """Return a total-count-normalized, log1p-transformed copy of ``adata``.

    The caller's object is never modified; all work happens on a copy.

    Parameters
    ----------
    adata : ad.AnnData
        Expression data to preprocess.

    Returns
    -------
    ad.AnnData
        A new object after ``normalize_total`` (target sum 1e4) and ``log1p``.
    """
    processed = adata.copy()

    # Order matters: depth normalization first, then the log transform.
    sc.pp.normalize_total(processed, target_sum=1e4)
    sc.pp.log1p(processed)

    return processed
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def scanpy_pca(
    adata: ad.AnnData,
    n_comps: int = 50,
    n_top_genes: int = 2000,
) -> ad.AnnData:
    """Compute PCA on a copy of an AnnData object.

    When the data has more than ``n_top_genes`` variables, the copy is first
    subset to the ``n_top_genes`` most highly variable genes.

    Parameters
    ----------
    adata : ad.AnnData
        Expression data. The input object is not modified.
    n_comps : int
        Requested number of principal components. The value actually used is
        capped at ``n_vars - 1`` and ``n_obs - 1`` of the (possibly subset)
        data.
    n_top_genes : int
        Number of highly variable genes to keep before PCA. Defaults to 2000,
        the value that was previously hard-coded.

    Returns
    -------
    ad.AnnData
        A copy of ``adata`` (possibly gene-subset) with the PCA computed.
    """
    working = adata.copy()

    if working.n_vars > n_top_genes:
        # NOTE(review): scanpy documents flavor='seurat_v3' as expecting raw
        # count data -- confirm callers do not pass log-normalized input here.
        sc.pp.highly_variable_genes(
            working,
            n_top_genes=n_top_genes,
            flavor='seurat_v3',
            subset=True,
        )

    # The component count cannot exceed what the data dimensions allow.
    max_comps = min(n_comps, working.n_vars - 1, working.n_obs - 1)
    sc.tl.pca(working, n_comps=max_comps)
    return working
|
|
26
|
+
|
|
27
|
+
|
|
10
28
|
class Metric():
|
|
11
29
|
def __init__(self, name: str, fn):
|
|
12
30
|
self.name = name
|
|
@@ -0,0 +1,243 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Reconstruction metrics for gene expression evaluation.
|
|
3
|
+
|
|
4
|
+
Provides MSE (Mean Squared Error) and related reconstruction quality metrics.
|
|
5
|
+
"""
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
from typing import Optional
|
|
10
|
+
|
|
11
|
+
from .base_metric import BaseMetric
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _ensure_2d(arr: np.ndarray) -> np.ndarray:
    """Coerce input to a float64 array with at least two dimensions.

    Parameters
    ----------
    arr : np.ndarray
        Array-like input. A scalar becomes shape ``(1, 1)``; a 1D input of
        length ``n`` becomes a single column of shape ``(n, 1)``. Inputs that
        already have two or more dimensions keep their shape.

    Returns
    -------
    np.ndarray
        The data as ``float64``, laid out as samples x genes.
    """
    arr = np.asarray(arr, dtype=np.float64)
    # ``ndim < 2`` also covers 0-D scalars, which previously slipped through
    # unchanged and broke callers that index ``shape[1]``.
    if arr.ndim < 2:
        arr = arr.reshape(-1, 1)
    return arr
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class MSEDistance(BaseMetric):
    """
    Mean Squared Error (MSE) between real and generated distributions.

    Computes the average squared difference between samples. When sample
    sizes differ, compares mean expression profiles.

    NaN entries are ignored in both modes; a gene with no usable values
    yields NaN.

    Lower values indicate better reconstruction.

    Parameters
    ----------
    compare_means : bool
        If True, always compare mean profiles regardless of sample sizes.
        If False, compute sample-wise MSE when sizes match.

    Examples
    --------
    >>> mse = MSEDistance()
    >>> result = mse.compute(real_data, generated_data, gene_names)
    >>> print(f"MSE: {result.aggregate_value:.4f}")
    """

    def __init__(self, compare_means: bool = False):
        super().__init__(
            name="mse",
            description="Mean Squared Error per gene",
            higher_is_better=False,
            requires_distribution=True,
        )
        self.compare_means = compare_means

    @staticmethod
    def _masked_column_mean(values: np.ndarray, valid: np.ndarray) -> np.ndarray:
        """Per-column mean of ``values`` over entries where ``valid`` is True.

        Columns without a single valid entry evaluate to NaN (0 / 0); the
        NumPy floating-point warning for that division is suppressed.
        """
        counts = valid.sum(axis=0)
        totals = np.where(valid, values, 0.0).sum(axis=0)
        with np.errstate(invalid="ignore", divide="ignore"):
            return totals / counts

    def compute_per_gene(
        self,
        real: np.ndarray,
        generated: np.ndarray,
    ) -> np.ndarray:
        """
        Compute MSE for each gene.

        Parameters
        ----------
        real : np.ndarray
            Real data, shape (n_samples_real, n_genes)
        generated : np.ndarray
            Generated data, shape (n_samples_gen, n_genes)

        Returns
        -------
        np.ndarray
            MSE per gene, shape (n_genes,)

        Raises
        ------
        ValueError
            If the two inputs do not have the same number of genes.
        """
        real = _ensure_2d(real)
        generated = _ensure_2d(generated)

        # Without this check a gene-count mismatch would be silently
        # broadcast or truncated and produce meaningless numbers.
        if real.shape[1] != generated.shape[1]:
            raise ValueError(
                "real and generated must have the same number of genes, "
                f"got {real.shape[1]} and {generated.shape[1]}"
            )

        # Compare mean profiles when sample sizes differ or compare_means is True
        if self.compare_means or real.shape[0] != generated.shape[0]:
            # NaN-aware means, so one missing value does not invalidate the
            # gene (matches the NaN filtering of the sample-wise path).
            real_mean = self._masked_column_mean(real, ~np.isnan(real))
            gen_mean = self._masked_column_mean(generated, ~np.isnan(generated))
            return (real_mean - gen_mean) ** 2

        # Sample-wise MSE when sizes match: a pair is used only when both
        # the real and the generated value are present.
        valid = ~(np.isnan(real) | np.isnan(generated))
        return self._masked_column_mean((real - generated) ** 2, valid)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class RMSEDistance(BaseMetric):
    """
    Root Mean Squared Error (RMSE) between real and generated distributions.

    Square root of MSE, in the same units as the original data.
    Lower values indicate better reconstruction.

    Parameters
    ----------
    compare_means : bool
        Forwarded to the internal ``MSEDistance`` that performs the actual
        per-gene computation.
    """

    def __init__(self, compare_means: bool = False):
        super().__init__(
            name="rmse",
            description="Root Mean Squared Error per gene",
            higher_is_better=False,
            requires_distribution=True,
        )
        self.compare_means = compare_means
        self._mse = MSEDistance(compare_means=compare_means)

    def compute_per_gene(
        self,
        real: np.ndarray,
        generated: np.ndarray,
    ) -> np.ndarray:
        """
        Compute RMSE for each gene.

        Parameters
        ----------
        real : np.ndarray
            Real data, passed through to ``MSEDistance.compute_per_gene``.
        generated : np.ndarray
            Generated data, passed through to ``MSEDistance.compute_per_gene``.

        Returns
        -------
        np.ndarray
            Element-wise square root of the per-gene MSE.
        """
        # Re-sync the delegate so that changing ``self.compare_means`` after
        # construction takes effect, as it does for the sibling metrics that
        # read the attribute at compute time.
        self._mse.compare_means = self.compare_means
        return np.sqrt(self._mse.compute_per_gene(real, generated))
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class MAEDistance(BaseMetric):
    """
    Mean Absolute Error (MAE) between real and generated distributions.

    More robust to outliers than MSE.
    Lower values indicate better reconstruction.

    NaN entries are ignored in both modes; a gene with no usable values
    yields NaN.

    Parameters
    ----------
    compare_means : bool
        If True, always compare mean profiles regardless of sample sizes.
        If False, compute sample-wise MAE when sizes match.
    """

    def __init__(self, compare_means: bool = False):
        super().__init__(
            name="mae",
            description="Mean Absolute Error per gene",
            higher_is_better=False,
            requires_distribution=True,
        )
        self.compare_means = compare_means

    @staticmethod
    def _masked_column_mean(values: np.ndarray, valid: np.ndarray) -> np.ndarray:
        """Per-column mean of ``values`` over entries where ``valid`` is True.

        Columns without a single valid entry evaluate to NaN (0 / 0); the
        NumPy floating-point warning for that division is suppressed.
        """
        counts = valid.sum(axis=0)
        totals = np.where(valid, values, 0.0).sum(axis=0)
        with np.errstate(invalid="ignore", divide="ignore"):
            return totals / counts

    def compute_per_gene(
        self,
        real: np.ndarray,
        generated: np.ndarray,
    ) -> np.ndarray:
        """
        Compute MAE for each gene.

        Parameters
        ----------
        real : np.ndarray
            Real data, shape (n_samples_real, n_genes)
        generated : np.ndarray
            Generated data, shape (n_samples_gen, n_genes)

        Returns
        -------
        np.ndarray
            MAE per gene, shape (n_genes,)

        Raises
        ------
        ValueError
            If the two inputs do not have the same number of genes.
        """
        real = _ensure_2d(real)
        generated = _ensure_2d(generated)

        # Without this check a gene-count mismatch would be silently
        # broadcast or truncated and produce meaningless numbers.
        if real.shape[1] != generated.shape[1]:
            raise ValueError(
                "real and generated must have the same number of genes, "
                f"got {real.shape[1]} and {generated.shape[1]}"
            )

        # Compare mean profiles when sample sizes differ or compare_means is True
        if self.compare_means or real.shape[0] != generated.shape[0]:
            # NaN-aware means, so one missing value does not invalidate the
            # gene (matches the NaN filtering of the sample-wise path).
            real_mean = self._masked_column_mean(real, ~np.isnan(real))
            gen_mean = self._masked_column_mean(generated, ~np.isnan(generated))
            return np.abs(real_mean - gen_mean)

        # Sample-wise MAE when sizes match: a pair is used only when both
        # the real and the generated value are present.
        valid = ~(np.isnan(real) | np.isnan(generated))
        return self._masked_column_mean(np.abs(real - generated), valid)
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
class R2Score(BaseMetric):
    """
    Coefficient of Determination (R²) between real and generated data.

    Measures the proportion of variance explained. Values close to 1
    indicate good fit, 0 means no better than mean prediction.

    When sample sizes match, R² is computed sample-wise per gene. When they
    differ, a fallback compares mean profiles against the per-gene variance
    of the real data.

    In both modes a gene whose real values are (numerically) constant scores
    1.0 if the residual is also negligible and 0.0 otherwise.

    Higher values indicate better reconstruction.
    """

    # Sums of squares / variances below this threshold are treated as zero.
    _EPS = 1e-10

    def __init__(self):
        super().__init__(
            name="r2",
            description="R² (coefficient of determination) per gene",
            higher_is_better=True,
            requires_distribution=True,
        )

    @staticmethod
    def _masked_column_mean(values: np.ndarray, valid: np.ndarray) -> np.ndarray:
        """Per-column mean of ``values`` over entries where ``valid`` is True.

        Columns without a single valid entry evaluate to NaN (0 / 0); the
        NumPy floating-point warning for that division is suppressed.
        """
        counts = valid.sum(axis=0)
        totals = np.where(valid, values, 0.0).sum(axis=0)
        with np.errstate(invalid="ignore", divide="ignore"):
            return totals / counts

    @classmethod
    def _score(cls, ss_res: np.ndarray, ss_tot: np.ndarray) -> np.ndarray:
        """Convert residual and total spread into R², guarding constant genes.

        Where ``ss_tot`` is below the threshold the ratio is undefined, so the
        result is 1.0 for a negligible residual and 0.0 otherwise. NaN inputs
        propagate to NaN outputs in the non-constant case.
        """
        with np.errstate(invalid="ignore", divide="ignore"):
            constant = ss_tot < cls._EPS
            negligible = ss_res < cls._EPS
            # Replace near-zero denominators so the division is well defined;
            # those positions are overridden by the constant-gene rule below.
            ratio = ss_res / np.where(constant, 1.0, ss_tot)
        return np.where(constant, np.where(negligible, 1.0, 0.0), 1.0 - ratio)

    def compute_per_gene(
        self,
        real: np.ndarray,
        generated: np.ndarray,
    ) -> np.ndarray:
        """
        Compute R² for each gene.

        Parameters
        ----------
        real : np.ndarray
            Real data, shape (n_samples_real, n_genes)
        generated : np.ndarray
            Generated data, shape (n_samples_gen, n_genes)

        Returns
        -------
        np.ndarray
            R² per gene, shape (n_genes,). In the sample-wise mode a gene
            without any valid pair is NaN; in the mean-profile fallback an
            undefined score is reported as 0.0.

        Raises
        ------
        ValueError
            If the two inputs do not have the same number of genes.
        """
        real = _ensure_2d(real)
        generated = _ensure_2d(generated)

        # Without this check a gene-count mismatch would be silently
        # broadcast or truncated and produce meaningless numbers.
        if real.shape[1] != generated.shape[1]:
            raise ValueError(
                "real and generated must have the same number of genes, "
                f"got {real.shape[1]} and {generated.shape[1]}"
            )

        # R² only makes sense when sample sizes match
        if real.shape[0] != generated.shape[0]:
            # Fall back to mean comparison: squared mean difference relative
            # to the per-gene variance of the real data, ignoring NaNs.
            real_valid = ~np.isnan(real)
            real_mean = self._masked_column_mean(real, real_valid)
            gen_mean = self._masked_column_mean(generated, ~np.isnan(generated))
            real_var = self._masked_column_mean((real - real_mean) ** 2, real_valid)

            r2 = self._score((real_mean - gen_mean) ** 2, real_var)
            # Undefined scores (e.g. a gene with no usable values) keep the
            # fallback's established convention of 0.0.
            return np.nan_to_num(r2, nan=0.0)

        # Sample-wise R²: a pair is used only when both values are present.
        valid = ~(np.isnan(real) | np.isnan(generated))
        real_mean = self._masked_column_mean(real, valid)
        ss_tot = np.where(valid, (real - real_mean) ** 2, 0.0).sum(axis=0)
        ss_res = np.where(valid, (real - generated) ** 2, 0.0).sum(axis=0)

        r2 = self._score(ss_res, ss_tot)
        # Genes without a single valid pair have no defined score.
        r2[~valid.any(axis=0)] = np.nan
        return r2
|
|
@@ -10,12 +10,16 @@ Provides publication-quality plots for evaluation results:
|
|
|
10
10
|
"""
|
|
11
11
|
from __future__ import annotations
|
|
12
12
|
|
|
13
|
-
from typing import Dict, List, Optional, Tuple, Union, Any
|
|
13
|
+
from typing import Dict, List, Optional, Tuple, Union, Any, TYPE_CHECKING
|
|
14
14
|
from pathlib import Path
|
|
15
15
|
import numpy as np
|
|
16
16
|
import pandas as pd
|
|
17
17
|
import warnings
|
|
18
18
|
|
|
19
|
+
if TYPE_CHECKING:
|
|
20
|
+
from geneval.results import EvaluationResult
|
|
21
|
+
from geneval.data.loader import GeneExpressionDataLoader
|
|
22
|
+
|
|
19
23
|
try:
|
|
20
24
|
import matplotlib.pyplot as plt
|
|
21
25
|
import matplotlib.patches as mpatches
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|