gengeneeval 0.1.1__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,243 @@
1
+ """
2
+ Reconstruction metrics for gene expression evaluation.
3
+
4
+ Provides MSE (Mean Squared Error) and related reconstruction quality metrics.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import numpy as np
9
+ from typing import Optional
10
+
11
+ from .base_metric import BaseMetric
12
+
13
+
14
def _ensure_2d(arr: np.ndarray) -> np.ndarray:
    """
    Coerce the input to a float64 array with at least two dimensions.

    Parameters
    ----------
    arr : array-like
        Scalar, 1D, or 2D numeric data. Anything accepted by
        ``np.asarray(..., dtype=np.float64)`` works.

    Returns
    -------
    np.ndarray
        A float64 array. Scalars become shape ``(1, 1)`` and 1D inputs of
        length ``n`` become shape ``(n, 1)`` (a single column). Inputs that
        already have two or more dimensions are returned with their shape
        unchanged.
    """
    arr = np.asarray(arr, dtype=np.float64)
    # ``ndim < 2`` (rather than ``== 1``) also covers 0-d scalars, which would
    # otherwise slip through and break callers that index ``shape[1]``.
    if arr.ndim < 2:
        arr = arr.reshape(-1, 1)
    return arr
20
+
21
+
22
class MSEDistance(BaseMetric):
    """
    Mean Squared Error (MSE) between real and generated distributions.

    Computes the average squared difference between samples. When sample
    sizes differ, compares mean expression profiles.

    NaN entries are ignored in both modes; a gene with no usable values
    yields NaN.

    Lower values indicate better reconstruction.

    Parameters
    ----------
    compare_means : bool
        If True, always compare mean profiles regardless of sample sizes.
        If False, compute sample-wise MSE when sizes match.

    Examples
    --------
    >>> mse = MSEDistance()
    >>> result = mse.compute(real_data, generated_data, gene_names)
    >>> print(f"MSE: {result.aggregate_value:.4f}")
    """

    def __init__(self, compare_means: bool = False):
        super().__init__(
            name="mse",
            description="Mean Squared Error per gene",
            higher_is_better=False,
            requires_distribution=True,
        )
        self.compare_means = compare_means

    @staticmethod
    def _nan_mean_profile(values: np.ndarray) -> np.ndarray:
        """Return the per-column mean ignoring NaN (NaN for all-NaN columns)."""
        present = ~np.isnan(values)
        counts = present.sum(axis=0)
        totals = np.where(present, values, 0.0).sum(axis=0)
        # Divide by at least 1 to avoid 0/0 warnings; empty columns are
        # overwritten with NaN right after.
        profile = totals / np.maximum(counts, 1)
        profile[counts == 0] = np.nan
        return profile

    def compute_per_gene(
        self,
        real: np.ndarray,
        generated: np.ndarray,
    ) -> np.ndarray:
        """
        Compute MSE for each gene.

        Parameters
        ----------
        real : np.ndarray
            Real data, shape (n_samples_real, n_genes)
        generated : np.ndarray
            Generated data, shape (n_samples_gen, n_genes)

        Returns
        -------
        np.ndarray
            MSE per gene, shape (n_genes,). Entries are NaN for genes that
            have no valid (non-NaN) values to compare.

        Raises
        ------
        ValueError
            If ``real`` and ``generated`` have a different number of genes.
        """
        real = _ensure_2d(real)
        generated = _ensure_2d(generated)
        if real.shape[1] != generated.shape[1]:
            raise ValueError(
                "real and generated must have the same number of genes, "
                f"got {real.shape[1]} and {generated.shape[1]}"
            )

        # Compare mean profiles when sample sizes differ or compare_means is True
        if self.compare_means or real.shape[0] != generated.shape[0]:
            gap = self._nan_mean_profile(real) - self._nan_mean_profile(generated)
            return gap ** 2

        # Sample-wise MSE when sizes match: a pair only counts when both
        # values are present.
        paired = ~(np.isnan(real) | np.isnan(generated))
        counts = paired.sum(axis=0)
        squared = np.where(paired, (real - generated) ** 2, 0.0)
        mse = squared.sum(axis=0) / np.maximum(counts, 1)
        mse[counts == 0] = np.nan
        return mse
99
+
100
+
101
class RMSEDistance(BaseMetric):
    """
    Root Mean Squared Error (RMSE) between real and generated distributions.

    Square root of MSE, in the same units as the original data.
    Lower values indicate better reconstruction.

    Parameters
    ----------
    compare_means : bool
        Forwarded to the internal ``MSEDistance`` that does the actual
        computation.
    """

    def __init__(self, compare_means: bool = False):
        super().__init__(
            name="rmse",
            description="Root Mean Squared Error per gene",
            higher_is_better=False,
            requires_distribution=True,
        )
        self.compare_means = compare_means
        # RMSE is derived from MSE, so the per-gene work is delegated.
        self._mse = MSEDistance(compare_means=compare_means)

    def compute_per_gene(
        self,
        real: np.ndarray,
        generated: np.ndarray,
    ) -> np.ndarray:
        """
        Compute RMSE for each gene.

        Parameters
        ----------
        real : np.ndarray
            Real data, passed unchanged to ``MSEDistance.compute_per_gene``.
        generated : np.ndarray
            Generated data, passed unchanged to
            ``MSEDistance.compute_per_gene``.

        Returns
        -------
        np.ndarray
            Element-wise square root of the per-gene MSE.
        """
        # Keep the delegate in step with this instance's flag so that changing
        # ``compare_means`` after construction takes effect here as well.
        self._mse.compare_means = self.compare_means
        squared_errors = self._mse.compute_per_gene(real, generated)
        return np.sqrt(squared_errors)
129
+
130
+
131
class MAEDistance(BaseMetric):
    """
    Mean Absolute Error (MAE) between real and generated distributions.

    More robust to outliers than MSE.
    Lower values indicate better reconstruction.

    NaN entries are ignored in both modes; a gene with no usable values
    yields NaN.

    Parameters
    ----------
    compare_means : bool
        If True, always compare mean profiles regardless of sample sizes.
        If False, compute sample-wise MAE when sizes match.
    """

    def __init__(self, compare_means: bool = False):
        super().__init__(
            name="mae",
            description="Mean Absolute Error per gene",
            higher_is_better=False,
            requires_distribution=True,
        )
        self.compare_means = compare_means

    @staticmethod
    def _nan_mean_profile(values: np.ndarray) -> np.ndarray:
        """Return the per-column mean ignoring NaN (NaN for all-NaN columns)."""
        present = ~np.isnan(values)
        counts = present.sum(axis=0)
        totals = np.where(present, values, 0.0).sum(axis=0)
        # Divide by at least 1 to avoid 0/0 warnings; empty columns are
        # overwritten with NaN right after.
        profile = totals / np.maximum(counts, 1)
        profile[counts == 0] = np.nan
        return profile

    def compute_per_gene(
        self,
        real: np.ndarray,
        generated: np.ndarray,
    ) -> np.ndarray:
        """
        Compute MAE for each gene.

        Parameters
        ----------
        real : np.ndarray
            Real data, shape (n_samples_real, n_genes)
        generated : np.ndarray
            Generated data, shape (n_samples_gen, n_genes)

        Returns
        -------
        np.ndarray
            MAE per gene, shape (n_genes,). Entries are NaN for genes that
            have no valid (non-NaN) values to compare.

        Raises
        ------
        ValueError
            If ``real`` and ``generated`` have a different number of genes.
        """
        real = _ensure_2d(real)
        generated = _ensure_2d(generated)
        if real.shape[1] != generated.shape[1]:
            raise ValueError(
                "real and generated must have the same number of genes, "
                f"got {real.shape[1]} and {generated.shape[1]}"
            )

        # Compare mean profiles when sample sizes differ or compare_means is True
        if self.compare_means or real.shape[0] != generated.shape[0]:
            gap = self._nan_mean_profile(real) - self._nan_mean_profile(generated)
            return np.abs(gap)

        # Sample-wise MAE when sizes match: a pair only counts when both
        # values are present.
        paired = ~(np.isnan(real) | np.isnan(generated))
        counts = paired.sum(axis=0)
        absolute = np.where(paired, np.abs(real - generated), 0.0)
        mae = absolute.sum(axis=0) / np.maximum(counts, 1)
        mae[counts == 0] = np.nan
        return mae
179
+
180
+
181
class R2Score(BaseMetric):
    """
    Coefficient of Determination (R²) between real and generated data.

    Measures the proportion of variance explained. Values close to 1
    indicate good fit, 0 means no better than mean prediction.

    When the two inputs have different numbers of samples, a mean-profile
    approximation is used instead of the paired definition (see
    ``compute_per_gene``).

    Higher values indicate better reconstruction.
    """

    # Threshold below which a sum of squares / variance is treated as zero.
    _EPS = 1e-10

    def __init__(self):
        super().__init__(
            name="r2",
            description="R² (coefficient of determination) per gene",
            higher_is_better=True,
            requires_distribution=True,
        )

    @classmethod
    def _degenerate_score(cls, residual):
        """Score for a (near-)constant gene: 1.0 if it is matched, else 0.0."""
        return np.where(residual < cls._EPS, 1.0, 0.0)

    @classmethod
    def _r2_from_mean_profiles(
        cls,
        real: np.ndarray,
        generated: np.ndarray,
    ) -> np.ndarray:
        """R² approximation: 1 - (mean gap)² / var(real), per gene."""
        real_var = np.var(real, axis=0)
        sq_gap = (np.mean(real, axis=0) - np.mean(generated, axis=0)) ** 2

        # Constant genes get the same 1.0 / 0.0 convention as the paired
        # branch instead of being divided by a tiny epsilon, which would
        # produce scores on the order of -1e10.
        constant = real_var < cls._EPS
        safe_var = np.where(constant, 1.0, real_var)
        with np.errstate(invalid='ignore', divide='ignore'):
            r2 = 1 - sq_gap / safe_var
        r2 = np.where(constant, cls._degenerate_score(sq_gap), r2)
        # Genes whose score is undefined (NaN) are reported as 0.0 here.
        return np.nan_to_num(r2, nan=0.0)

    @classmethod
    def _r2_paired(
        cls,
        real: np.ndarray,
        generated: np.ndarray,
    ) -> np.ndarray:
        """Standard R² per gene over sample pairs where both values exist."""
        n_genes = real.shape[1]
        r2 = np.zeros(n_genes)

        for i in range(n_genes):
            r_vals = real[:, i]
            g_vals = generated[:, i]

            valid = ~(np.isnan(r_vals) | np.isnan(g_vals))
            if not valid.any():
                r2[i] = np.nan
                continue

            r_ok = r_vals[valid]
            ss_tot = np.sum((r_ok - np.mean(r_ok)) ** 2)
            ss_res = np.sum((r_ok - g_vals[valid]) ** 2)

            if ss_tot < cls._EPS:
                # No variance to explain: perfect only if residual is ~0.
                r2[i] = 1.0 if ss_res < cls._EPS else 0.0
            else:
                r2[i] = 1 - ss_res / ss_tot

        return r2

    def compute_per_gene(
        self,
        real: np.ndarray,
        generated: np.ndarray,
    ) -> np.ndarray:
        """
        Compute R² for each gene.

        Parameters
        ----------
        real : np.ndarray
            Real data, shape (n_samples_real, n_genes)
        generated : np.ndarray
            Generated data, shape (n_samples_gen, n_genes)

        Returns
        -------
        np.ndarray
            R² per gene, shape (n_genes,). With equal sample counts this is
            ``1 - SS_res / SS_tot`` over pairs where both values are present
            (NaN if a gene has no such pair). With unequal sample counts it
            is ``1 - (mean(real) - mean(generated))**2 / var(real)``, with
            undefined values reported as 0.0. In both cases a gene with
            (near-)zero variance scores 1.0 if the residual is also
            (near-)zero and 0.0 otherwise.

        Raises
        ------
        ValueError
            If ``real`` and ``generated`` have a different number of genes.
        """
        real = _ensure_2d(real)
        generated = _ensure_2d(generated)
        if real.shape[1] != generated.shape[1]:
            raise ValueError(
                "real and generated must have the same number of genes, "
                f"got {real.shape[1]} and {generated.shape[1]}"
            )

        # R² only makes sense when sample sizes match; otherwise fall back
        # to comparing mean profiles.
        if real.shape[0] != generated.shape[0]:
            return self._r2_from_mean_profiles(real, generated)
        return self._r2_paired(real, generated)
@@ -1,10 +1,10 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: gengeneeval
3
- Version: 0.1.1
4
- Summary: Comprehensive evaluation of generated gene expression data. Computes metrics between real and generated datasets with support for condition matching, train/test splits, and publication-quality visualizations.
3
+ Version: 0.2.1
4
+ Summary: Comprehensive evaluation of generated gene expression data. Computes metrics between real and generated datasets with support for condition matching, train/test splits, memory-efficient lazy loading, and publication-quality visualizations.
5
5
  License: MIT
6
6
  License-File: LICENSE
7
- Keywords: gene expression,evaluation,metrics,single-cell,generative models,benchmarking
7
+ Keywords: gene expression,evaluation,metrics,single-cell,generative models,benchmarking,memory-efficient
8
8
  Author: GenEval Team
9
9
  Author-email: geneval@example.com
10
10
  Requires-Python: >=3.8,<4.0
@@ -42,11 +42,11 @@ Description-Content-Type: text/markdown
42
42
  [![PyPI version](https://badge.fury.io/py/gengeneeval.svg)](https://badge.fury.io/py/gengeneeval)
43
43
  [![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
44
44
  [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
45
- [![Tests](https://github.com/AndreaRubbi/GenGeneEval/actions/workflows/tests.yml/badge.svg)](https://github.com/AndreaRubbi/GenGeneEval/actions)
45
+ [![Tests](https://github.com/AndreaRubbi/GenGeneEval/actions/workflows/test.yml/badge.svg)](https://github.com/AndreaRubbi/GenGeneEval/actions)
46
46
 
47
47
  **Comprehensive evaluation of generated gene expression data against real datasets.**
48
48
 
49
- GenEval is a modular, object-oriented Python framework for computing metrics between real and generated gene expression datasets stored in AnnData (h5ad) format. It supports condition-based matching, train/test splits, and generates publication-quality visualizations.
49
+ GenEval is a modular, object-oriented Python framework for computing metrics between real and generated gene expression datasets stored in AnnData (h5ad) format. It supports condition-based matching, train/test splits, and memory-efficient lazy loading for large datasets, and it generates publication-quality visualizations.
50
50
 
51
51
  ## Features
52
52
 
@@ -55,6 +55,10 @@ All metrics are computed **per-gene** (returning a vector) and **aggregated**:
55
55
 
56
56
  | Metric | Description | Direction |
57
57
  |--------|-------------|-----------|
58
+ | **MSE** | Mean Squared Error | Lower is better |
59
+ | **RMSE** | Root Mean Squared Error | Lower is better |
60
+ | **MAE** | Mean Absolute Error | Lower is better |
61
+ | **R²** | Coefficient of Determination | Higher is better |
58
62
  | **Pearson Correlation** | Linear correlation between expression profiles | Higher is better |
59
63
  | **Spearman Correlation** | Rank correlation (robust to outliers) | Higher is better |
60
64
  | **Wasserstein-1** | Earth Mover's Distance (L1) | Lower is better |
@@ -73,6 +77,8 @@ All metrics are computed **per-gene** (returning a vector) and **aggregated**:
73
77
  - ✅ Condition-based matching (perturbation, cell type, etc.)
74
78
  - ✅ Train/test split support
75
79
  - ✅ Per-gene and aggregate metrics
80
+ - ✅ **Memory-efficient lazy loading** for large datasets
81
+ - ✅ **Batched evaluation** to avoid OOM errors
76
82
  - ✅ Modular, extensible architecture
77
83
  - ✅ Command-line interface
78
84
  - ✅ Publication-quality visualizations
@@ -136,6 +142,37 @@ geneval --real real.h5ad --generated generated.h5ad \
136
142
  --output results/
137
143
  ```
138
144
 
145
+ ### Memory-Efficient Mode (for Large Datasets)
146
+
147
+ For datasets too large to fit in memory, use the lazy evaluation API:
148
+
149
+ ```python
150
+ from geneval import evaluate_lazy, load_data_lazy
151
+
152
+ # Memory-efficient evaluation (streams data one condition at a time)
153
+ results = evaluate_lazy(
154
+ real_path="large_real.h5ad",
155
+ generated_path="large_generated.h5ad",
156
+ condition_columns=["perturbation"],
157
+ batch_size=256, # Process in batches
158
+ use_backed=True, # Memory-mapped file access
159
+ output_dir="eval_output/",
160
+ save_per_condition=True, # Save each condition to disk
161
+ )
162
+
163
+ # Get summary statistics
164
+ print(results.get_summary())
165
+
166
+ # Or use the lazy loader directly for custom workflows
167
+ with load_data_lazy("real.h5ad", "gen.h5ad", ["perturbation"]) as loader:
168
+ print(f"Memory estimate: {loader.estimate_memory_usage()}")
169
+
170
+ # Process one condition at a time
171
+ for key, real, gen, info in loader.iterate_conditions():
172
+ # Your custom evaluation logic
173
+ pass
174
+ ```
175
+
139
176
  ## Expected Data Format
140
177
 
141
178
  GenEval expects AnnData (h5ad) files with:
@@ -1,19 +1,22 @@
1
- geneval/__init__.py,sha256=8D8KN9PSLbux6LMc6Ap1nz_WlQ4Jc6YNylPy_0udJ-Q,2872
1
+ geneval/__init__.py,sha256=_WxX5Kjk7y3u7mBZ5cf6ficy9SIT2FutZNcMe1fr9Ro,3989
2
2
  geneval/cli.py,sha256=0ai0IGyn3SSmEnfLRJhcr0brvUxuNZHE4IXod7jvosU,9977
3
3
  geneval/config.py,sha256=gkCjs_gzPWgUZNcmSR3Y70XQCAZ1m9AKLueaM-x8bvw,3729
4
4
  geneval/core.py,sha256=No0DP8bNR6LedfCWEedY9C5r_c4M14rvSPaGZqbxc94,1155
5
- geneval/data/__init__.py,sha256=nD3uWostZbYD3Yj_TOE44LvPDen-Vm3gN8ZH0QptPGw,450
5
+ geneval/data/__init__.py,sha256=NQUPVpUnBIabrTH5TuRk0KE9S7sVO5QetZv-MCQmZuw,827
6
6
  geneval/data/gene_expression_datamodule.py,sha256=XiBIdf68JZ-3S-FaZsrQlBJA7qL9uUXo2C8y0r4an5M,8009
7
+ geneval/data/lazy_loader.py,sha256=5fTRVjPjcWvYXV-uPWFUF2Nn9rHRdD8lygAUkCW8wOM,20677
7
8
  geneval/data/loader.py,sha256=zpRmwGZ4PJkB3rpXXRCMFtvMi4qvUrPkKmvIlGjfRpY,14555
8
- geneval/evaluator.py,sha256=grPudMng-CcnWwkxQGWM6RZ198Q-1THkR4MCXtadCdU,11545
9
+ geneval/evaluator.py,sha256=wZFzLo2PLHanjA-9L6C3xJBjMWXxPM63kU6usU4P7bs,11619
9
10
  geneval/evaluators/__init__.py,sha256=i11sHvhsjEAeI3Aw9zFTPmCYuqkGxzTHggAKehe3HQ0,160
10
11
  geneval/evaluators/base_evaluator.py,sha256=yJL568HdNofIcHgNOElSQMVlG9oRPTTDIZ7CmKccRqs,5967
11
12
  geneval/evaluators/gene_expression_evaluator.py,sha256=v8QL6tzOQ3QVXdPMM8tFHTTviZC3WsPRX4G0ShgeDUw,8743
12
- geneval/metrics/__init__.py,sha256=wk0CdFXvipfPqXWUMsRRz9CPiSVPG40Id4lyoSaLIkY,1417
13
+ geneval/lazy_evaluator.py,sha256=I_VvDolxPFGiW38eGPrjSoBOKICKyYN3GHbjJBAe5tg,13200
14
+ geneval/metrics/__init__.py,sha256=H5IXTKR-zoP_pGht6ioJfhLU7IHrSDQElMk0Cp4-JTw,1786
13
15
  geneval/metrics/base_metric.py,sha256=prbnB-Ap-P64m-2_TUrHxO3NFQaw-obVg1Tw4pjC5EY,6961
14
16
  geneval/metrics/correlation.py,sha256=jpYmaihWK89J1E5yQinGUJeB6pTZ21xPNHJi3XYyXJE,6987
15
17
  geneval/metrics/distances.py,sha256=9mWzbMbIBY1ckOd2a0l3by3aEFMQZL9bVMSeP44xzUg,16155
16
18
  geneval/metrics/metrics.py,sha256=s3ONmYTOk1ou6DUqCDOREQaqk3ajgpW_lZbZp1WW4aY,4747
19
+ geneval/metrics/reconstruction.py,sha256=phvQtB-CTFyfmkmj0oii4PChMmZtsRdfIRn_ISBq1w0,7387
17
20
  geneval/models/__init__.py,sha256=vJHXIhwzykjoqZ-vHQJnPwwjSUu9nnMyo7jGnWlTd94,42
18
21
  geneval/models/base_model.py,sha256=2QDtweYTgiovnksaRPBjNbIDu1l9l_WQMMFfeIX3GB8,1345
19
22
  geneval/results.py,sha256=iXSB0o0f1jQrCKjc-lbRfwBFGhspTDDJpQ2K2tM-XR4,11362
@@ -24,8 +27,8 @@ geneval/utils/preprocessing.py,sha256=1Cij1O2dwDR6_zh5IEgLPq3jEmV8VfIRjfQrHiKe3M
24
27
  geneval/visualization/__init__.py,sha256=LN19jl5xV4WVJTePaOUHWvKZ_pgDFp1chhcklGkNtm8,792
25
28
  geneval/visualization/plots.py,sha256=3K94r3x5NjIUZ-hYVQIivO63VkLOvDWl-BLB_qL2pSY,15008
26
29
  geneval/visualization/visualizer.py,sha256=lX7K0j20nAsgdtOOdbxLdLKYAfovEp3hNAnZOjFTCq0,36670
27
- gengeneeval-0.1.1.dist-info/METADATA,sha256=4WJtqiCK88ZIKGObjIXQ-Fp0PvmF6D5-YGpm0WciUNc,6041
28
- gengeneeval-0.1.1.dist-info/WHEEL,sha256=3ny-bZhpXrU6vSQ1UPG34FoxZBp3lVcvK0LkgUz6VLk,88
29
- gengeneeval-0.1.1.dist-info/entry_points.txt,sha256=xTkwnNa2fP0w1uGVsafzRTaCeuBSWLlNO-1CN8uBSK0,43
30
- gengeneeval-0.1.1.dist-info/licenses/LICENSE,sha256=RDHgHDI4rSDq35R4CAC3npy86YUnmZ81ecO7aHfmmGA,1073
31
- gengeneeval-0.1.1.dist-info/RECORD,,
30
+ gengeneeval-0.2.1.dist-info/METADATA,sha256=aRjsh5JUIcH8huIngsCi14eyE3-Vl_AFv9Uo1j5mciw,7497
31
+ gengeneeval-0.2.1.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
32
+ gengeneeval-0.2.1.dist-info/entry_points.txt,sha256=xTkwnNa2fP0w1uGVsafzRTaCeuBSWLlNO-1CN8uBSK0,43
33
+ gengeneeval-0.2.1.dist-info/licenses/LICENSE,sha256=RDHgHDI4rSDq35R4CAC3npy86YUnmZ81ecO7aHfmmGA,1073
34
+ gengeneeval-0.2.1.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: poetry-core 2.3.0
2
+ Generator: poetry-core 2.2.1
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any